Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
32af4d28
Unverified
Commit
32af4d28
authored
Oct 01, 2018
by
Tianqi Chen
Committed by
GitHub
Oct 01, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[IR] eager constant folding in operator overloading (#1789)
parent
3455c8a5
Hide whitespace changes
Inline
Side-by-side
Showing
29 changed files
with
1106 additions
and
196 deletions
+1106
-196
include/tvm/buffer.h
+1
-0
include/tvm/expr.h
+0
-10
include/tvm/ir.h
+0
-2
include/tvm/ir_operator.h
+552
-37
include/tvm/tensor.h
+1
-0
nnvm/src/top/tensor/reduce.cc
+1
-1
python/tvm/api.py
+6
-6
python/tvm/expr.py
+8
-8
python/tvm/generic.py
+4
-4
src/api/api_ir.cc
+60
-39
src/arithmetic/compute_expr.h
+7
-54
src/arithmetic/detect_linear_equation.cc
+2
-2
src/codegen/codegen_cuda.cc
+1
-1
src/codegen/verilog/verilog_ir.cc
+1
-1
src/lang/expr.cc
+1
-0
src/lang/ir_operator.cc
+401
-1
src/pass/ir_util.h
+3
-2
src/pass/split_pipeline.cc
+1
-2
src/pass/storage_rewrite.cc
+1
-1
src/pass/vectorize_loop.cc
+0
-1
tests/cpp/ir_mutator_test.cc
+1
-0
tests/python/unittest/test_arith_intset.py
+4
-5
tests/python/unittest/test_lang_basic.py
+1
-1
tests/python/unittest/test_lang_operator.py
+35
-0
tests/python/unittest/test_lang_reflection.py
+1
-1
tests/python/unittest/test_pass_simplify.py
+0
-1
topi/include/topi/elemwise.h
+3
-6
topi/include/topi/nn/pooling.h
+5
-5
topi/python/topi/vision/ssd/multibox.py
+5
-5
No files found.
include/tvm/buffer.h
View file @
32af4d28
...
@@ -10,6 +10,7 @@
...
@@ -10,6 +10,7 @@
#include "base.h"
#include "base.h"
#include "expr.h"
#include "expr.h"
#include "ir_operator.h"
#include "node/container.h"
#include "node/container.h"
namespace
tvm
{
namespace
tvm
{
...
...
include/tvm/expr.h
View file @
32af4d28
...
@@ -7,7 +7,6 @@
...
@@ -7,7 +7,6 @@
#define TVM_EXPR_H_
#define TVM_EXPR_H_
#include <ir/Expr.h>
#include <ir/Expr.h>
#include <ir/IROperator.h>
#include <ir/IRPrinter.h>
#include <ir/IRPrinter.h>
#include <string>
#include <string>
#include <algorithm>
#include <algorithm>
...
@@ -34,15 +33,6 @@ using HalideIR::Internal::Stmt;
...
@@ -34,15 +33,6 @@ using HalideIR::Internal::Stmt;
using
HalideIR
::
Internal
::
IRPrinter
;
using
HalideIR
::
Internal
::
IRPrinter
;
using
HalideIR
::
Internal
::
Variable
;
using
HalideIR
::
Internal
::
Variable
;
using
HalideIR
::
Internal
::
make_const
;
using
HalideIR
::
Internal
::
make_zero
;
using
HalideIR
::
Internal
::
make_one
;
using
HalideIR
::
Internal
::
as_const_int
;
using
HalideIR
::
Internal
::
as_const_uint
;
using
HalideIR
::
Internal
::
const_true
;
using
HalideIR
::
Internal
::
const_false
;
using
HalideIR
::
Internal
::
is_no_op
;
inline
Type
TVMShapeIndexType
()
{
inline
Type
TVMShapeIndexType
()
{
if
(
std
::
is_signed
<
tvm_index_t
>::
value
)
{
if
(
std
::
is_signed
<
tvm_index_t
>::
value
)
{
return
Int
(
sizeof
(
tvm_index_t
)
*
8
);
return
Int
(
sizeof
(
tvm_index_t
)
*
8
);
...
...
include/tvm/ir.h
View file @
32af4d28
...
@@ -495,8 +495,6 @@ using HalideIR::Internal::Block;
...
@@ -495,8 +495,6 @@ using HalideIR::Internal::Block;
using
HalideIR
::
Internal
::
IfThenElse
;
using
HalideIR
::
Internal
::
IfThenElse
;
using
HalideIR
::
Internal
::
Evaluate
;
using
HalideIR
::
Internal
::
Evaluate
;
using
HalideIR
::
Internal
::
Shuffle
;
using
HalideIR
::
Internal
::
Shuffle
;
// ir functions
using
HalideIR
::
Internal
::
is_const_power_of_two_integer
;
/*!
/*!
* \brief Create a type annotation expression
* \brief Create a type annotation expression
...
...
include/tvm/ir_operator.h
View file @
32af4d28
/*!
/*!
* Copyright (c) 201
7
by Contributors
* Copyright (c) 201
8
by Contributors
* \file tvm/ir_operator.h
* \file tvm/ir_operator.h
* \brief Common operators of Expr
* \brief Common operators defined for Expr.
*
* \note Most of the operator defined here perform simple constant folding
* when the type is int32 or int64 for simplifying the index expressions.
*/
*/
#ifndef TVM_IR_OPERATOR_H_
#ifndef TVM_IR_OPERATOR_H_
#define TVM_IR_OPERATOR_H_
#define TVM_IR_OPERATOR_H_
#include <algorithm>
#include <algorithm>
#include <type_traits>
#include "expr.h"
#include "expr.h"
#include "ir.h"
#include "ir.h"
namespace
tvm
{
namespace
tvm
{
/*!
* \brief Make a const value with certain data type.
* \param t The target type.
* \param value The input value
* \return the result expression.
* \tparam ValueType The constant value type
*/
template
<
typename
ValueType
,
typename
=
typename
std
::
enable_if
<
std
::
is_pod
<
ValueType
>::
value
>::
type
>
inline
Expr
make_const
(
Type
t
,
ValueType
value
);
/*!
* \brief Make a const zero expr.
* \param t The target type.
* \return the result expression.
*/
inline
Expr
make_zero
(
Type
t
);
/*!
* \brief Make a constant true expression.
* \param lanes The number of lanes in the bool
* \return The result expression.
*/
inline
Expr
const_true
(
int
lanes
=
1
)
{
return
make_const
(
UInt
(
1
,
lanes
),
1
);
}
/*!
* \brief Make a constant false expression.
* \param lanes The number of lanes in the bool
* \return The result expression.
*/
inline
Expr
const_false
(
int
lanes
=
1
)
{
return
make_const
(
UInt
(
1
,
lanes
),
0
);
}
/*!
* \brief Get x as constant int expression.
* \param x The expression
* \return the address to the int expression,
* return nullptr, if x is not IntImm.
*/
inline
const
int64_t
*
as_const_int
(
const
Expr
&
x
)
{
if
(
!
x
.
defined
())
return
nullptr
;
if
(
const
ir
::
IntImm
*
op
=
x
.
as
<
ir
::
IntImm
>
())
{
return
&
(
op
->
value
);
}
else
{
return
nullptr
;
}
}
/*!
* \brief Get x as constant uint expression.
* \param x The expression
* \return the address to the int expression,
* return nullptr, if x is not UIntImm.
*/
inline
const
uint64_t
*
as_const_uint
(
const
Expr
&
x
)
{
if
(
!
x
.
defined
())
return
nullptr
;
if
(
const
ir
::
UIntImm
*
op
=
x
.
as
<
ir
::
UIntImm
>
())
{
return
&
(
op
->
value
);
}
else
{
return
nullptr
;
}
}
/*!
* \brief Check whether x is a constant integer expression.
* \param x The input argument
* \param value the value to be compared against.
* \return whether x is constant expression.
*/
inline
bool
is_const_int
(
const
Expr
&
x
,
int64_t
value
);
/*!
* \brief Check whether stmt is nop.
* \param stmt The input statement
* \return whether stmt is nop
*/
inline
bool
is_no_op
(
const
Stmt
&
stmt
);
/*!
* \brief Check whether x is a constant integer 1
* \param x The input argument.
* \note This only return true for integer types.
* \return whether x is constant 1
*/
inline
bool
is_one
(
const
Expr
&
x
)
{
return
is_const_int
(
x
,
1
);
}
using
HalideIR
::
likely
;
/*!
using
HalideIR
::
likely_if_innermost
;
* \brief Check whether x is a constant integer 0
// functions
* \param x The input argument
using
HalideIR
::
cast
;
* \return whether x is constant 0
using
HalideIR
::
min
;
* \note This only return true for integer types.
using
HalideIR
::
max
;
*/
using
HalideIR
::
select
;
inline
bool
is_zero
(
const
Expr
&
x
)
{
return
is_const_int
(
x
,
0
);
}
/*!
* \brief Check whether x is a constant.
* \note This only return true for integer types.
* \return whether x is constant
*/
inline
bool
is_const
(
const
Expr
&
x
);
/*!
* \brief Check whether x is a constant power of two
* If x is power of two, write the power to the shift.
*
* \param x The input expression.
* \param shift The output shift if x is power of two.
* \return whether x is constant power of two
*/
TVM_DLL
bool
is_const_power_of_two_integer
(
const
Expr
&
x
,
int
*
shift
);
/*!
* \brief cast value to type.
*
* \param t the target type.
* \param value The value
* \return The result expression.
* \note This function may return value if the type is the same.
*/
TVM_DLL
Expr
cast
(
const
Type
&
t
,
Expr
value
);
/*!
* \brief perform reinterpret cast value to type.
*
* \param t the target type.
* \param value The value
* \return The result expression.
* \note This function may return value if the type is the same.
*/
TVM_DLL
Expr
reinterpret
(
const
Type
&
t
,
Expr
value
);
/*!
* \brief add operator
*
* \param a left operand
* \param b right operand
* \return The result expression.
* \note this function does eager constant folding for
* index types(int32, int64) when possible.
*/
TVM_DLL
Expr
operator
+
(
Expr
a
,
Expr
b
);
/*!
* \brief subtraction operator
*
* \param a left operand
* \param b right operand
* \return The result expression.
* \note this function does eager constant folding for
* index types(int32, int64) when possible.
*/
TVM_DLL
Expr
operator
-
(
Expr
a
,
Expr
b
);
/*!
* \brief negation.
*
* \param a input.
* \return The result expression.
* \note this function does eager constant folding for
* index types(int32, int64) when possible.
*/
TVM_DLL
Expr
operator
-
(
Expr
a
);
/*!
* \brief multiplication operator
*
* \param a left operand
* \param b right operand
* \return The result expression.
* \note this function does eager constant folding for
* index types(int32, int64) when possible.
*/
TVM_DLL
Expr
operator
*
(
Expr
a
,
Expr
b
);
/*!
* \brief division operator
*
* \param a left operand
* \param b right operand
* \return The result expression.
* \note this function does eager constant folding for
* index types(int32, int64) when possible.
*/
TVM_DLL
Expr
operator
/
(
Expr
a
,
Expr
b
);
/*!
* \brief mod operator
*
* \param a left operand
* \param b right operand
* \return The result expression.
* \note this function does eager constant folding for
* index types(int32, int64) when possible.
*/
TVM_DLL
Expr
operator
%
(
Expr
a
,
Expr
b
);
/*!
* \brief left shift operator
*
* \param a left operand
* \param b right operand
* \return The result expression.
* \note this function does eager constant folding for
* index types(int32, int64) when possible.
*/
TVM_DLL
Expr
operator
<<
(
Expr
a
,
Expr
b
);
/*!
* \brief right shift operator
*
* \param a left operand
* \param b right operand
* \return The result expression.
* \note this function does eager constant folding for
* index types(int32, int64) when possible.
*/
TVM_DLL
Expr
operator
>>
(
Expr
a
,
Expr
b
);
/*!
* \brief greater
*
* \param a left operand
* \param b right operand
* \return The result expression.
* \note this function does eager constant folding for
* index types(int32, int64) when possible.
*/
TVM_DLL
Expr
operator
>
(
Expr
a
,
Expr
b
);
/*!
* \brief greater_equal
*
* \param a left operand
* \param b right operand
* \return The result expression.
* \note this function does eager constant folding for
* index types(int32, int64) when possible.
*/
TVM_DLL
Expr
operator
>=
(
Expr
a
,
Expr
b
);
/*!
* \brief less
*
* \param a left operand
* \param b right operand
* \return The result expression.
* \note this function does eager constant folding for
* index types(int32, int64) when possible.
*/
TVM_DLL
Expr
operator
<
(
Expr
a
,
Expr
b
);
/*!
* \brief less_equal
*
* \param a left operand
* \param b right operand
* \return The result expression.
* \note this function does eager constant folding for
* index types(int32, int64) when possible.
*/
TVM_DLL
Expr
operator
<=
(
Expr
a
,
Expr
b
);
/*!
* \brief equal
*
* \param a left operand
* \param b right operand
* \return The result expression.
* \note this function does eager constant folding for
* index types(int32, int64) when possible.
*/
TVM_DLL
Expr
operator
==
(
Expr
a
,
Expr
b
);
/*!
* \brief not_equal
*
* \param a left operand
* \param b right operand
* \return The result expression.
* \note this function does eager constant folding for
* index types(int32, int64) when possible.
*/
TVM_DLL
Expr
operator
!=
(
Expr
a
,
Expr
b
);
/*!
* \brief and
*
* \param a left operand
* \param b right operand
* \return The result expression.
* \note This operator does eager constant folding.
*/
TVM_DLL
Expr
operator
&&
(
Expr
a
,
Expr
b
);
/*!
* \brief or
*
* \param a left operand
* \param b right operand
* \return The result expression.
* \note This operator does eager constant folding.
*/
TVM_DLL
Expr
operator
||
(
Expr
a
,
Expr
b
);
/*!
* \brief not
*
* \param a left operand
* \return The result expression.
* \note This operator does eager constant folding.
*/
TVM_DLL
Expr
operator
!
(
Expr
a
);
/*!
* \brief take maximum of two values
*
* \param a left operand
* \param b right operand
* \return The result expression.
* \note this function does eager constant folding for
* index types(int32, int64) when possible.
*/
TVM_DLL
Expr
max
(
Expr
a
,
Expr
b
);
/*!
* \brief take minimum of two values
*
* \param a left operand
* \param b right operand
* \return The result expression.
* \note this function does eager constant folding for
* index types(int32, int64) when possible.
*/
TVM_DLL
Expr
min
(
Expr
a
,
Expr
b
);
/*!
* \brief right shift
*
* \param a left operand
* \param b right operand
* \return The result expression.
* \note this function does eager constant folding for
* index types(int32, int64) when possible.
*/
TVM_DLL
Expr
operator
>>
(
Expr
a
,
Expr
b
);
/*!
* \brief left shift
*
* \param a left operand
* \param b right operand
* \return The result expression.
* \note this function does eager constant folding for
* index types(int32, int64) when possible.
*/
TVM_DLL
Expr
operator
<<
(
Expr
a
,
Expr
b
);
/*!
* \brief take bitwise and of two values
*
* \param a left operand
* \param b right operand
* \return The result expression.
* \note this function does eager constant folding for
* index types(int32, int64) when possible.
*/
TVM_DLL
Expr
operator
&
(
Expr
a
,
Expr
b
);
/*!
* \brief take bitwise or of two values
*
* \param a left operand
* \param b right operand
* \return The result expression.
* \note this function does eager constant folding for
* index types(int32, int64) when possible.
*/
TVM_DLL
Expr
operator
|
(
Expr
a
,
Expr
b
);
/*!
* \brief take bitwise xor of two values
*
* \param a left operand
* \param b right operand
* \return The result expression.
* \note this function does eager constant folding for
* index types(int32, int64) when possible.
*/
TVM_DLL
Expr
operator
^
(
Expr
a
,
Expr
b
);
/*!
* \brief take bitwise negation of two values
*
* \param a the input expression.
* \return The result expression.
* \note this function does eager constant folding for
* index types(int32, int64) when possible.
*/
TVM_DLL
Expr
operator
~
(
Expr
a
);
/*!
* \brief select result by condition
*
* \param cond The condition
* \param true_value The value when results are true.
* \param false_value The value when results are false.
* \return The result expression.
* \note this function does eager constant folding for
* index types(int32, int64) when possible.
*/
TVM_DLL
Expr
select
(
Expr
cond
,
Expr
true_value
,
Expr
false_value
);
/*!
* \brief Mark condition as likely.
* \param cond The condition
* \return The marked expression.
*/
TVM_DLL
Expr
likely
(
Expr
cond
);
/*!
* \brief Calculate power(x, y)
* \param x The left operand.
* \param y The right operand.
*/
TVM_DLL
Expr
pow
(
Expr
x
,
Expr
y
);
/*!
* \brief Calculate absolute value of x.
* \param x The input data
*
* \return The aboslute value of input data x
*/
TVM_DLL
Expr
abs
(
Expr
x
);
/*!
/*!
* \brief sum of of source expression over axis
* \brief sum of of source expression over axis
...
@@ -48,13 +450,12 @@ TVM_DLL Expr min(Expr source, Array<IterVar> axis);
...
@@ -48,13 +450,12 @@ TVM_DLL Expr min(Expr source, Array<IterVar> axis);
*/
*/
TVM_DLL
Expr
prod
(
Expr
source
,
Array
<
IterVar
>
axis
);
TVM_DLL
Expr
prod
(
Expr
source
,
Array
<
IterVar
>
axis
);
//
Unary i
ntrinsic operators
//
I
ntrinsic operators
#define TVM_DECLARE_INTRIN_UNARY(OpName) \
#define TVM_DECLARE_INTRIN_UNARY(OpName) \
inline Expr OpName(Expr x) { \
inline Expr OpName(Expr x) { \
return ir::Call::make(x.type(), #OpName, {x}, ir::Call::PureIntrinsic); \
return ir::Call::make(x.type(), #OpName, {x}, ir::Call::PureIntrinsic); \
} \
} \
TVM_DECLARE_INTRIN_UNARY
(
exp
);
TVM_DECLARE_INTRIN_UNARY
(
exp
);
TVM_DECLARE_INTRIN_UNARY
(
tanh
);
TVM_DECLARE_INTRIN_UNARY
(
tanh
);
TVM_DECLARE_INTRIN_UNARY
(
sigmoid
);
TVM_DECLARE_INTRIN_UNARY
(
sigmoid
);
...
@@ -64,38 +465,152 @@ TVM_DECLARE_INTRIN_UNARY(floor);
...
@@ -64,38 +465,152 @@ TVM_DECLARE_INTRIN_UNARY(floor);
TVM_DECLARE_INTRIN_UNARY
(
ceil
);
TVM_DECLARE_INTRIN_UNARY
(
ceil
);
TVM_DECLARE_INTRIN_UNARY
(
round
);
TVM_DECLARE_INTRIN_UNARY
(
round
);
TVM_DECLARE_INTRIN_UNARY
(
trunc
);
TVM_DECLARE_INTRIN_UNARY
(
trunc
);
TVM_DECLARE_INTRIN_UNARY
(
popcount
);
/*!
* \brief Calculate power(x, y)
// Implementation details after this
* \param x The left operand.
inline
bool
is_const
(
const
Expr
&
x
)
{
* \param y The right operand.
if
(
x
.
as
<
ir
::
IntImm
>
()
||
x
.
as
<
ir
::
UIntImm
>
())
{
*/
return
true
;
inline
Expr
pow
(
Expr
x
,
Expr
y
)
{
}
else
if
(
const
auto
*
op
=
x
.
as
<
ir
::
Broadcast
>
())
{
match_types
(
x
,
y
);
const
Expr
&
val
=
op
->
value
;
CHECK
(
x
.
type
().
is_float
())
<<
"power only applies to float"
;
if
(
val
.
as
<
ir
::
IntImm
>
()
||
val
.
as
<
ir
::
UIntImm
>
())
{
return
ir
::
Call
::
make
(
x
.
type
(),
"pow"
,
{
x
,
y
},
ir
::
Call
::
PureIntrinsic
);
return
true
;
}
}
return
false
;
}
}
/*!
inline
bool
is_positive_const
(
const
Expr
&
a
)
{
* \brief Calculate absolute value of x, elementwise
if
(
const
ir
::
IntImm
*
op
=
a
.
as
<
ir
::
IntImm
>
())
{
* \param x The input data
return
op
->
value
>
0
;
*
}
else
if
(
const
ir
::
UIntImm
*
op
=
a
.
as
<
ir
::
UIntImm
>
())
{
* \return The aboslute value of input data x
return
op
->
value
>
0
;
*/
inline
Expr
abs
(
Expr
x
)
{
if
(
x
.
type
().
is_int
())
{
return
select
(
x
>=
make_zero
(
x
.
type
()),
x
,
-
x
);
}
else
if
(
x
.
type
().
is_float
())
{
return
ir
::
Call
::
make
(
x
.
type
(),
"fabs"
,
{
x
},
ir
::
Call
::
PureIntrinsic
);
}
else
if
(
x
.
type
().
is_uint
())
{
return
x
;
}
else
{
}
else
{
LOG
(
WARNING
)
<<
"Warning: Data type "
<<
x
.
type
()
return
false
;
<<
" not supported for absolute op. Skipping absolute op..."
;
return
x
;
}
}
}
}
}
// namespace tvm
inline
bool
is_negative_const
(
const
Expr
&
a
)
{
if
(
const
ir
::
IntImm
*
op
=
a
.
as
<
ir
::
IntImm
>
())
{
return
op
->
value
<
0
;
}
else
{
return
false
;
}
}
inline
bool
is_const_int
(
const
Expr
&
x
,
int64_t
value
)
{
if
(
const
auto
*
op
=
x
.
as
<
ir
::
IntImm
>
())
{
return
op
->
value
==
value
;
}
else
if
(
const
auto
*
op
=
x
.
as
<
ir
::
UIntImm
>
())
{
return
op
->
value
==
static_cast
<
uint64_t
>
(
value
);
}
else
if
(
const
auto
*
op
=
x
.
as
<
ir
::
Broadcast
>
())
{
const
Expr
&
val
=
op
->
value
;
if
(
const
auto
*
opv
=
val
.
as
<
ir
::
IntImm
>
())
{
return
opv
->
value
==
value
;
}
else
if
(
const
auto
*
opv
=
val
.
as
<
ir
::
UIntImm
>
())
{
return
opv
->
value
==
static_cast
<
uint64_t
>
(
value
);
}
}
return
false
;
}
inline
bool
is_no_op
(
const
Stmt
&
stmt
)
{
if
(
!
stmt
.
defined
())
return
true
;
if
(
const
auto
*
op
=
stmt
.
as
<
ir
::
Evaluate
>
())
{
return
is_const
(
op
->
value
);
}
return
false
;
}
template
<
typename
ValueType
>
inline
Expr
MakeConstScalar
(
Type
t
,
ValueType
value
)
{
if
(
t
.
is_int
())
return
ir
::
IntImm
::
make
(
t
,
static_cast
<
int64_t
>
(
value
));
if
(
t
.
is_uint
())
return
ir
::
UIntImm
::
make
(
t
,
static_cast
<
uint64_t
>
(
value
));
if
(
t
.
is_float
())
return
ir
::
FloatImm
::
make
(
t
,
static_cast
<
double
>
(
value
));
LOG
(
FATAL
)
<<
"cannot make const for type "
<<
t
;
return
Expr
();
}
template
<
typename
ValueType
,
typename
>
inline
Expr
make_const
(
Type
t
,
ValueType
value
)
{
if
(
t
.
lanes
()
==
1
)
{
return
MakeConstScalar
(
t
,
value
);
}
else
{
return
ir
::
Broadcast
::
make
(
MakeConstScalar
(
t
.
element_of
(),
value
),
t
.
lanes
());
}
}
inline
Expr
make_zero
(
Type
t
)
{
if
(
t
.
is_handle
())
{
return
reinterpret
(
t
,
make_const
(
UInt
(
64
),
0
));
}
return
make_const
(
t
,
0
);
}
// additional const expression overloading
#define TVM_DEFINE_ASSIGN_OP_OVERLOAD(Name, OpFunc) \
inline Expr Name(Expr& a, Expr b) { \
a = OpFunc(a, b); \
return a; \
}
#define TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(Name) \
inline Expr Name(const Expr& a, float b) { \
return Name(a, Expr(b)); \
} \
inline Expr Name(float a, const Expr& b) { \
return Name(Expr(a), b); \
} \
inline Expr Name(int a, const Expr& b) { \
return Name(make_const(b.type(), a), b); \
} \
inline Expr Name(const Expr& a, int b) { \
return Name(a, make_const(a.type(), b)); \
}
#define TVM_DEFINE_LOGICAL_OP_CONST_VAL_OVERLOAD(Name) \
inline Expr Name(const Expr& a, bool b) { \
return Name(a, Expr(b)); \
} \
inline Expr Name(bool a, const Expr& b) { \
return Name(Expr(a), b); \
}
#define TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(Name) \
inline Expr Name(const Expr& a, int b) { \
return Name(a, make_const(a.type(), b)); \
} \
inline Expr Name(int a, const Expr& b) { \
return Name(make_const(b.type(), a), b); \
}
TVM_DEFINE_ASSIGN_OP_OVERLOAD
(
operator
+=
,
operator
+
);
TVM_DEFINE_ASSIGN_OP_OVERLOAD
(
operator
-=
,
operator
-
);
TVM_DEFINE_ASSIGN_OP_OVERLOAD
(
operator
*=
,
operator
*
);
TVM_DEFINE_ASSIGN_OP_OVERLOAD
(
operator
/=
,
operator
/
);
TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD
(
operator
+
);
TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD
(
operator
-
);
TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD
(
operator
*
);
TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD
(
operator
/
);
TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD
(
max
);
TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD
(
min
);
TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD
(
operator
>
);
// NOLINT(*)
TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD
(
operator
>=
);
TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD
(
operator
<
);
// NOLINT(*)
TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD
(
operator
<=
);
// integer related ops
TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD
(
operator
%
);
TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD
(
operator
>>
);
// NOLINT(*)
TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD
(
operator
<<
);
// NOLINT(*)
TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD
(
operator
&
);
TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD
(
operator
|
);
TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD
(
operator
^
);
// logical ops
TVM_DEFINE_LOGICAL_OP_CONST_VAL_OVERLOAD
(
operator
&&
);
TVM_DEFINE_LOGICAL_OP_CONST_VAL_OVERLOAD
(
operator
||
);
}
// namespace tvm
#endif // TVM_IR_OPERATOR_H_
#endif // TVM_IR_OPERATOR_H_
include/tvm/tensor.h
View file @
32af4d28
...
@@ -13,6 +13,7 @@
...
@@ -13,6 +13,7 @@
#include "base.h"
#include "base.h"
#include "expr.h"
#include "expr.h"
#include "ir_operator.h"
#include "arithmetic.h"
#include "arithmetic.h"
#include "node/container.h"
#include "node/container.h"
...
...
nnvm/src/top/tensor/reduce.cc
View file @
32af4d28
...
@@ -354,7 +354,7 @@ Example::
...
@@ -354,7 +354,7 @@ Example::
if
(
!
r_axes
.
ndim
())
return
Array
<
Tensor
>
{
topi
::
identity
(
inputs
[
0
])
};
if
(
!
r_axes
.
ndim
())
return
Array
<
Tensor
>
{
topi
::
identity
(
inputs
[
0
])
};
auto
axis
=
ShapeToArray
(
r_axes
);
auto
axis
=
ShapeToArray
(
r_axes
);
Expr
count
=
make_
one
(
inputs
[
0
]
->
dtype
);
Expr
count
=
make_
const
(
inputs
[
0
]
->
dtype
,
1
);
for
(
auto
&
i
:
r_axes
)
{
for
(
auto
&
i
:
r_axes
)
{
count
*=
inputs
[
0
]
->
shape
[
i
];
count
*=
inputs
[
0
]
->
shape
[
i
];
}
}
...
...
python/tvm/api.py
View file @
32af4d28
...
@@ -156,9 +156,9 @@ def any(*args):
...
@@ -156,9 +156,9 @@ def any(*args):
raise
ValueError
(
"Any must take at least 1 argument"
)
raise
ValueError
(
"Any must take at least 1 argument"
)
if
len
(
args
)
==
1
:
if
len
(
args
)
==
1
:
return
args
[
0
]
return
args
[
0
]
ret
=
_
expr
.
Or
(
args
[
0
],
args
[
1
])
ret
=
_
make
.
_Op
Or
(
args
[
0
],
args
[
1
])
for
i
in
range
(
2
,
len
(
args
)):
for
i
in
range
(
2
,
len
(
args
)):
ret
=
_
expr
.
Or
(
ret
,
args
[
i
])
ret
=
_
make
.
_Op
Or
(
ret
,
args
[
i
])
return
ret
return
ret
...
@@ -180,9 +180,9 @@ def all(*args):
...
@@ -180,9 +180,9 @@ def all(*args):
raise
ValueError
(
"Any must take at least 1 argument"
)
raise
ValueError
(
"Any must take at least 1 argument"
)
if
len
(
args
)
==
1
:
if
len
(
args
)
==
1
:
return
args
[
0
]
return
args
[
0
]
ret
=
_
expr
.
And
(
args
[
0
],
args
[
1
])
ret
=
_
make
.
_Op
And
(
args
[
0
],
args
[
1
])
for
i
in
range
(
2
,
len
(
args
)):
for
i
in
range
(
2
,
len
(
args
)):
ret
=
_
expr
.
And
(
ret
,
args
[
i
])
ret
=
_
make
.
_Op
And
(
ret
,
args
[
i
])
return
ret
return
ret
...
@@ -773,5 +773,5 @@ def comm_reducer(fcombine, fidentity, name="reduce"):
...
@@ -773,5 +773,5 @@ def comm_reducer(fcombine, fidentity, name="reduce"):
_init_api
(
"tvm.api"
)
_init_api
(
"tvm.api"
)
#pylint: disable=unnecessary-lambda
#pylint: disable=unnecessary-lambda
sum
=
comm_reducer
(
lambda
x
,
y
:
x
+
y
,
lambda
t
:
const
(
0
,
dtype
=
t
),
name
=
"sum"
)
sum
=
comm_reducer
(
lambda
x
,
y
:
x
+
y
,
lambda
t
:
const
(
0
,
dtype
=
t
),
name
=
"sum"
)
min
=
comm_reducer
(
lambda
x
,
y
:
_
expr
.
Min
(
x
,
y
),
max_value
,
name
=
'min'
)
min
=
comm_reducer
(
lambda
x
,
y
:
_
make
.
_Op
Min
(
x
,
y
),
max_value
,
name
=
'min'
)
max
=
comm_reducer
(
lambda
x
,
y
:
_
expr
.
Max
(
x
,
y
),
min_value
,
name
=
'max'
)
max
=
comm_reducer
(
lambda
x
,
y
:
_
make
.
_Op
Max
(
x
,
y
),
min_value
,
name
=
'max'
)
python/tvm/expr.py
View file @
32af4d28
...
@@ -60,7 +60,7 @@ class ExprOp(object):
...
@@ -60,7 +60,7 @@ class ExprOp(object):
return
self
.
__rdiv__
(
other
)
return
self
.
__rdiv__
(
other
)
def
__mod__
(
self
,
other
):
def
__mod__
(
self
,
other
):
return
_make
.
Mod
(
self
,
other
)
return
_make
.
_Op
Mod
(
self
,
other
)
def
__neg__
(
self
):
def
__neg__
(
self
):
neg_one
=
_api_internal
.
_const
(
-
1
,
self
.
dtype
)
neg_one
=
_api_internal
.
_const
(
-
1
,
self
.
dtype
)
...
@@ -85,10 +85,10 @@ class ExprOp(object):
...
@@ -85,10 +85,10 @@ class ExprOp(object):
return
_make
.
Call
(
self
.
dtype
,
"bitwise_not"
,
[
self
],
Call
.
PureIntrinsic
,
None
,
0
)
return
_make
.
Call
(
self
.
dtype
,
"bitwise_not"
,
[
self
],
Call
.
PureIntrinsic
,
None
,
0
)
def
__lt__
(
self
,
other
):
def
__lt__
(
self
,
other
):
return
_make
.
LT
(
self
,
other
)
return
_make
.
_Op
LT
(
self
,
other
)
def
__le__
(
self
,
other
):
def
__le__
(
self
,
other
):
return
_make
.
LE
(
self
,
other
)
return
_make
.
_Op
LE
(
self
,
other
)
def
__eq__
(
self
,
other
):
def
__eq__
(
self
,
other
):
return
EqualOp
(
self
,
other
)
return
EqualOp
(
self
,
other
)
...
@@ -97,10 +97,10 @@ class ExprOp(object):
...
@@ -97,10 +97,10 @@ class ExprOp(object):
return
NotEqualOp
(
self
,
other
)
return
NotEqualOp
(
self
,
other
)
def
__gt__
(
self
,
other
):
def
__gt__
(
self
,
other
):
return
_make
.
GT
(
self
,
other
)
return
_make
.
_Op
GT
(
self
,
other
)
def
__ge__
(
self
,
other
):
def
__ge__
(
self
,
other
):
return
_make
.
GE
(
self
,
other
)
return
_make
.
_Op
GE
(
self
,
other
)
def
__nonzero__
(
self
):
def
__nonzero__
(
self
):
raise
ValueError
(
"Cannot use and / or / not operator to Expr, hint: "
+
raise
ValueError
(
"Cannot use and / or / not operator to Expr, hint: "
+
...
@@ -122,7 +122,7 @@ class ExprOp(object):
...
@@ -122,7 +122,7 @@ class ExprOp(object):
ret : Expr
ret : Expr
The equality expression.
The equality expression.
"""
"""
return
_make
.
EQ
(
self
,
other
)
return
_make
.
_Op
EQ
(
self
,
other
)
def
astype
(
self
,
dtype
):
def
astype
(
self
,
dtype
):
"""Cast the expression to other type.
"""Cast the expression to other type.
...
@@ -169,7 +169,7 @@ class EqualOp(NodeGeneric, ExprOp):
...
@@ -169,7 +169,7 @@ class EqualOp(NodeGeneric, ExprOp):
def
asnode
(
self
):
def
asnode
(
self
):
"""Convert node."""
"""Convert node."""
return
_make
.
EQ
(
self
.
a
,
self
.
b
)
return
_make
.
_Op
EQ
(
self
.
a
,
self
.
b
)
class
NotEqualOp
(
NodeGeneric
,
ExprOp
):
class
NotEqualOp
(
NodeGeneric
,
ExprOp
):
...
@@ -201,7 +201,7 @@ class NotEqualOp(NodeGeneric, ExprOp):
...
@@ -201,7 +201,7 @@ class NotEqualOp(NodeGeneric, ExprOp):
def
asnode
(
self
):
def
asnode
(
self
):
"""Convert node."""
"""Convert node."""
return
_make
.
NE
(
self
.
a
,
self
.
b
)
return
_make
.
_Op
NE
(
self
.
a
,
self
.
b
)
class
Expr
(
ExprOp
,
NodeBase
):
class
Expr
(
ExprOp
,
NodeBase
):
...
...
python/tvm/generic.py
View file @
32af4d28
...
@@ -24,7 +24,7 @@ def add(lhs, rhs):
...
@@ -24,7 +24,7 @@ def add(lhs, rhs):
op : tvm.Expr
op : tvm.Expr
The result Expr of add operaton.
The result Expr of add operaton.
"""
"""
return
_make
.
Add
(
lhs
,
rhs
)
return
_make
.
_Op
Add
(
lhs
,
rhs
)
def
subtract
(
lhs
,
rhs
):
def
subtract
(
lhs
,
rhs
):
...
@@ -42,7 +42,7 @@ def subtract(lhs, rhs):
...
@@ -42,7 +42,7 @@ def subtract(lhs, rhs):
op : tvm.Expr
op : tvm.Expr
The result Expr of subtract operaton.
The result Expr of subtract operaton.
"""
"""
return
_make
.
Sub
(
lhs
,
rhs
)
return
_make
.
_Op
Sub
(
lhs
,
rhs
)
def
multiply
(
lhs
,
rhs
):
def
multiply
(
lhs
,
rhs
):
...
@@ -60,7 +60,7 @@ def multiply(lhs, rhs):
...
@@ -60,7 +60,7 @@ def multiply(lhs, rhs):
op : tvm.Expr
op : tvm.Expr
The result Expr of multiply operaton.
The result Expr of multiply operaton.
"""
"""
return
_make
.
Mul
(
lhs
,
rhs
)
return
_make
.
_Op
Mul
(
lhs
,
rhs
)
def
divide
(
lhs
,
rhs
):
def
divide
(
lhs
,
rhs
):
...
@@ -78,7 +78,7 @@ def divide(lhs, rhs):
...
@@ -78,7 +78,7 @@ def divide(lhs, rhs):
op : tvm.Expr
op : tvm.Expr
The result Expr of divide operaton.
The result Expr of divide operaton.
"""
"""
return
_make
.
Div
(
lhs
,
rhs
)
return
_make
.
_Op
Div
(
lhs
,
rhs
)
def
cast
(
src
,
dtype
):
def
cast
(
src
,
dtype
):
...
...
src/api/api_ir.cc
View file @
32af4d28
...
@@ -5,7 +5,7 @@
...
@@ -5,7 +5,7 @@
*/
*/
#include <tvm/expr.h>
#include <tvm/expr.h>
#include <tvm/ir.h>
#include <tvm/ir.h>
#include <
ir/IRO
perator.h>
#include <
tvm/ir_o
perator.h>
#include <tvm/api_registry.h>
#include <tvm/api_registry.h>
#include <tvm/ir_operator.h>
#include <tvm/ir_operator.h>
...
@@ -117,6 +117,50 @@ TVM_REGISTER_API("make.CommReducer")
...
@@ -117,6 +117,50 @@ TVM_REGISTER_API("make.CommReducer")
*ret = Node::make(args[0], args[1], args[2], args[3], args[4]); \
*ret = Node::make(args[0], args[1], args[2], args[3], args[4]); \
}) \
}) \
REGISTER_MAKE5
(
Reduce
);
REGISTER_MAKE4
(
AttrStmt
);
REGISTER_MAKE2
(
IntImm
);
REGISTER_MAKE2
(
UIntImm
);
REGISTER_MAKE2
(
FloatImm
);
REGISTER_MAKE1
(
StringImm
);
REGISTER_MAKE2
(
Add
);
REGISTER_MAKE2
(
Sub
);
REGISTER_MAKE2
(
Mul
);
REGISTER_MAKE2
(
Div
);
REGISTER_MAKE2
(
Mod
);
REGISTER_MAKE2
(
Min
);
REGISTER_MAKE2
(
Max
);
REGISTER_MAKE2
(
EQ
);
REGISTER_MAKE2
(
NE
);
REGISTER_MAKE2
(
LT
);
REGISTER_MAKE2
(
LE
);
REGISTER_MAKE2
(
GT
);
REGISTER_MAKE2
(
GE
);
REGISTER_MAKE2
(
And
);
REGISTER_MAKE2
(
Or
);
REGISTER_MAKE1
(
Not
);
REGISTER_MAKE3
(
Select
);
REGISTER_MAKE3
(
Ramp
);
REGISTER_MAKE2
(
Cast
);
REGISTER_MAKE2
(
Broadcast
);
REGISTER_MAKE2
(
Shuffle
);
REGISTER_MAKE3
(
Let
);
REGISTER_MAKE3
(
LetStmt
);
REGISTER_MAKE3
(
AssertStmt
);
REGISTER_MAKE3
(
ProducerConsumer
);
REGISTER_MAKE5
(
Allocate
);
REGISTER_MAKE4
(
Provide
);
REGISTER_MAKE4
(
Prefetch
);
REGISTER_MAKE1
(
Free
);
REGISTER_MAKE2
(
Block
);
REGISTER_MAKE3
(
IfThenElse
);
REGISTER_MAKE1
(
Evaluate
);
// operator overloading, smarter than make
#define REGISTER_MAKE_BINARY_OP(Node, Func) \
#define REGISTER_MAKE_BINARY_OP(Node, Func) \
TVM_REGISTER_API("make."#Node) \
TVM_REGISTER_API("make."#Node) \
.set_body([](TVMArgs args, TVMRetValue *ret) { \
.set_body([](TVMArgs args, TVMRetValue *ret) { \
...
@@ -138,50 +182,27 @@ TVM_REGISTER_API("make.CommReducer")
...
@@ -138,50 +182,27 @@ TVM_REGISTER_API("make.CommReducer")
} \
} \
})
})
REGISTER_MAKE5
(
Reduce
);
REGISTER_MAKE4
(
AttrStmt
);
REGISTER_MAKE2
(
IntImm
);
REGISTER_MAKE_BINARY_OP
(
_OpAdd
,
operator
+
);
REGISTER_MAKE2
(
UIntImm
);
REGISTER_MAKE_BINARY_OP
(
_OpSub
,
operator
-
);
REGISTER_MAKE2
(
FloatImm
);
REGISTER_MAKE_BINARY_OP
(
_OpMul
,
operator
*
);
REGISTER_MAKE1
(
StringImm
);
REGISTER_MAKE_BINARY_OP
(
_OpDiv
,
operator
/
);
REGISTER_MAKE_BINARY_OP
(
Add
,
operator
+
);
REGISTER_MAKE_BINARY_OP
(
_OpMod
,
operator
%
);
REGISTER_MAKE_BINARY_OP
(
Sub
,
operator
-
);
REGISTER_MAKE_BINARY_OP
(
_OpMin
,
min
);
REGISTER_MAKE_BINARY_OP
(
Mul
,
operator
*
);
REGISTER_MAKE_BINARY_OP
(
_OpMax
,
max
);
REGISTER_MAKE_BINARY_OP
(
Div
,
operator
/
);
REGISTER_MAKE_BINARY_OP
(
_OpEQ
,
operator
==
);
REGISTER_MAKE_BINARY_OP
(
Mod
,
operator
%
);
REGISTER_MAKE_BINARY_OP
(
_OpNE
,
operator
!=
);
REGISTER_MAKE_BINARY_OP
(
Min
,
min
);
REGISTER_MAKE_BINARY_OP
(
_OpLT
,
operator
<
);
// NOLINT(*)
REGISTER_MAKE_BINARY_OP
(
Max
,
max
);
REGISTER_MAKE_BINARY_OP
(
_OpLE
,
operator
<=
);
// NOLINT(*)
REGISTER_MAKE_BINARY_OP
(
EQ
,
operator
==
);
REGISTER_MAKE_BINARY_OP
(
_OpGT
,
operator
>
);
// NOLINT(*)
REGISTER_MAKE_BINARY_OP
(
NE
,
operator
!=
);
REGISTER_MAKE_BINARY_OP
(
_OpGE
,
operator
>=
);
REGISTER_MAKE_BINARY_OP
(
LT
,
operator
<
);
// NOLINT(*)
REGISTER_MAKE_BINARY_OP
(
_OpAnd
,
operator
&&
);
REGISTER_MAKE_BINARY_OP
(
LE
,
operator
<=
);
// NOLINT(*)
REGISTER_MAKE_BINARY_OP
(
_OpOr
,
operator
||
);
REGISTER_MAKE_BINARY_OP
(
GT
,
operator
>
);
// NOLINT(*)
REGISTER_MAKE_BINARY_OP
(
GE
,
operator
>=
);
REGISTER_MAKE_BINARY_OP
(
And
,
operator
&&
);
REGISTER_MAKE_BINARY_OP
(
Or
,
operator
||
);
REGISTER_MAKE_BIT_OP
(
bitwise_and
,
operator
&
);
REGISTER_MAKE_BIT_OP
(
bitwise_and
,
operator
&
);
REGISTER_MAKE_BIT_OP
(
bitwise_or
,
operator
|
);
REGISTER_MAKE_BIT_OP
(
bitwise_or
,
operator
|
);
REGISTER_MAKE_BIT_OP
(
bitwise_xor
,
operator
^
);
REGISTER_MAKE_BIT_OP
(
bitwise_xor
,
operator
^
);
REGISTER_MAKE_BIT_OP
(
left_shift
,
operator
<<
);
// NOLINT(*)
REGISTER_MAKE_BIT_OP
(
left_shift
,
operator
<<
);
// NOLINT(*)
REGISTER_MAKE_BIT_OP
(
right_shift
,
operator
>>
);
REGISTER_MAKE_BIT_OP
(
right_shift
,
operator
>>
);
REGISTER_MAKE1
(
Not
);
REGISTER_MAKE3
(
Select
);
REGISTER_MAKE3
(
Ramp
);
REGISTER_MAKE2
(
Cast
);
REGISTER_MAKE2
(
Broadcast
);
REGISTER_MAKE2
(
Shuffle
);
REGISTER_MAKE3
(
Let
);
REGISTER_MAKE3
(
LetStmt
);
REGISTER_MAKE3
(
AssertStmt
);
REGISTER_MAKE3
(
ProducerConsumer
);
REGISTER_MAKE5
(
Allocate
);
REGISTER_MAKE4
(
Provide
);
REGISTER_MAKE4
(
Prefetch
);
REGISTER_MAKE1
(
Free
);
REGISTER_MAKE2
(
Block
);
REGISTER_MAKE3
(
IfThenElse
);
REGISTER_MAKE1
(
Evaluate
);
}
// namespace ir
}
// namespace ir
}
// namespace tvm
}
// namespace tvm
src/arithmetic/compute_expr.h
View file @
32af4d28
...
@@ -14,10 +14,6 @@
...
@@ -14,10 +14,6 @@
namespace
tvm
{
namespace
tvm
{
namespace
arith
{
namespace
arith
{
using
HalideIR
::
Internal
::
add_would_overflow
;
using
HalideIR
::
Internal
::
sub_would_overflow
;
using
HalideIR
::
Internal
::
mul_would_overflow
;
/*!
/*!
* \brief Compute the expression with the given binary op.
* \brief Compute the expression with the given binary op.
* \param lhs The left operand
* \param lhs The left operand
...
@@ -42,23 +38,9 @@ template<typename Op>
...
@@ -42,23 +38,9 @@ template<typename Op>
inline
Expr
ComputeReduce
(
inline
Expr
ComputeReduce
(
const
Array
<
Expr
>&
values
,
Expr
empty_value
);
const
Array
<
Expr
>&
values
,
Expr
empty_value
);
template
<
typename
T
>
inline
bool
GetConst
(
Expr
e
,
int64_t
*
out
)
{
inline
bool
GetConst
(
Expr
e
,
T
*
out
);
template
<>
inline
bool
GetConst
<
int64_t
>
(
Expr
e
,
int64_t
*
out
)
{
if
(
e
.
type
().
is_vector
())
return
false
;
const
int64_t
*
v
=
as_const_int
(
e
);
if
(
v
)
{
*
out
=
*
v
;
return
true
;
}
else
{
return
false
;
}
}
template
<>
inline
bool
GetConst
<
uint64_t
>
(
Expr
e
,
uint64_t
*
out
)
{
if
(
e
.
type
().
is_vector
())
return
false
;
if
(
e
.
type
().
is_vector
())
return
false
;
const
uint64_t
*
v
=
as_const_u
int
(
e
);
const
int64_t
*
v
=
as_const_
int
(
e
);
if
(
v
)
{
if
(
v
)
{
*
out
=
*
v
;
return
true
;
*
out
=
*
v
;
return
true
;
}
else
{
}
else
{
...
@@ -69,66 +51,37 @@ inline bool GetConst<uint64_t>(Expr e, uint64_t *out) {
...
@@ -69,66 +51,37 @@ inline bool GetConst<uint64_t>(Expr e, uint64_t *out) {
// get a small constant int
// get a small constant int
inline
bool
GetConstInt
(
Expr
e
,
int
*
out
)
{
inline
bool
GetConstInt
(
Expr
e
,
int
*
out
)
{
int64_t
v1
=
0
;
int64_t
v1
=
0
;
uint64_t
v2
=
0
;
if
(
GetConst
(
e
,
&
v1
))
{
if
(
GetConst
(
e
,
&
v1
))
{
if
(
v1
>
static_cast
<
int64_t
>
(
if
(
v1
>
static_cast
<
int64_t
>
(
std
::
numeric_limits
<
int
>::
max
()))
return
false
;
std
::
numeric_limits
<
int
>::
max
()))
return
false
;
*
out
=
static_cast
<
int
>
(
v1
);
return
true
;
*
out
=
static_cast
<
int
>
(
v1
);
return
true
;
}
}
if
(
GetConst
(
e
,
&
v2
))
{
if
(
v2
>
static_cast
<
uint64_t
>
(
std
::
numeric_limits
<
int
>::
max
()))
return
false
;
*
out
=
static_cast
<
int
>
(
v2
);
return
true
;
}
return
false
;
return
false
;
}
}
#define TVM_CONST_PROPAGATION(OP_NAME, OP) \
int64_t ia = 0, ib = 0; \
if (GetConst(a, &ia) && GetConst(b, &ib)) { \
if (OP_NAME ## _would_overflow(a.type().bits(), ia, ib)) { \
LOG(FATAL) << "signed int overflow"; \
} \
return ir::IntImm::make(a.type(), ia OP ib); \
} \
uint64_t ua = 0, ub = 0; \
if (GetConst(a, &ua) && GetConst(b, &ub)) { \
return ir::UIntImm::make(a.type(), ua OP ub); \
} \
template
<>
template
<>
inline
Expr
ComputeExpr
<
ir
::
Add
>
(
Expr
a
,
Expr
b
)
{
inline
Expr
ComputeExpr
<
ir
::
Add
>
(
Expr
a
,
Expr
b
)
{
if
(
is_zero
(
a
))
return
b
;
return
a
+
b
;
if
(
is_zero
(
b
))
return
a
;
TVM_CONST_PROPAGATION
(
add
,
+
);
return
ir
::
Add
::
make
(
a
,
b
);
}
}
template
<>
template
<>
inline
Expr
ComputeExpr
<
ir
::
Sub
>
(
Expr
a
,
Expr
b
)
{
inline
Expr
ComputeExpr
<
ir
::
Sub
>
(
Expr
a
,
Expr
b
)
{
if
(
is_zero
(
b
))
return
a
;
return
a
-
b
;
TVM_CONST_PROPAGATION
(
sub
,
-
);
return
ir
::
Sub
::
make
(
a
,
b
);
}
}
template
<>
template
<>
inline
Expr
ComputeExpr
<
ir
::
Mul
>
(
Expr
a
,
Expr
b
)
{
inline
Expr
ComputeExpr
<
ir
::
Mul
>
(
Expr
a
,
Expr
b
)
{
if
(
is_one
(
a
))
return
b
;
return
a
*
b
;
if
(
is_one
(
b
))
return
a
;
TVM_CONST_PROPAGATION
(
mul
,
*
);
return
ir
::
Mul
::
make
(
a
,
b
);
}
}
template
<>
template
<>
inline
Expr
ComputeExpr
<
ir
::
Div
>
(
Expr
a
,
Expr
b
)
{
inline
Expr
ComputeExpr
<
ir
::
Div
>
(
Expr
a
,
Expr
b
)
{
if
(
is_one
(
b
))
return
a
;
return
a
/
b
;
return
ir
::
Div
::
make
(
a
,
b
);
}
}
template
<>
template
<>
inline
Expr
ComputeExpr
<
ir
::
Mod
>
(
Expr
a
,
Expr
b
)
{
inline
Expr
ComputeExpr
<
ir
::
Mod
>
(
Expr
a
,
Expr
b
)
{
if
(
is_zero
(
a
))
return
make_zero
(
a
.
type
());
return
a
%
b
;
return
ir
::
Mod
::
make
(
a
,
b
);
}
}
template
<>
template
<>
...
...
src/arithmetic/detect_linear_equation.cc
View file @
32af4d28
...
@@ -194,7 +194,7 @@ bool DetectClipBound(
...
@@ -194,7 +194,7 @@ bool DetectClipBound(
if
(
!
LinearEqDetector
(
var
).
Detect
(
canonical
,
&
ret
))
return
false
;
if
(
!
LinearEqDetector
(
var
).
Detect
(
canonical
,
&
ret
))
return
false
;
ret
.
coeff
=
Simplify
(
ret
.
coeff
);
ret
.
coeff
=
Simplify
(
ret
.
coeff
);
IntervalEntry
&
p
=
(
*
bmap
)[
var
.
get
()];
IntervalEntry
&
p
=
(
*
bmap
)[
var
.
get
()];
if
(
is_
one
(
ret
.
coeff
))
{
if
(
is_
const_int
(
ret
.
coeff
,
1
))
{
// var + shift >=0 -> var >= -shift
// var + shift >=0 -> var >= -shift
if
(
p
.
min_value
.
defined
())
{
if
(
p
.
min_value
.
defined
())
{
p
.
min_value
=
ir
::
Max
::
make
(
p
.
min_value
,
-
ret
.
base
);
p
.
min_value
=
ir
::
Max
::
make
(
p
.
min_value
,
-
ret
.
base
);
...
@@ -203,7 +203,7 @@ bool DetectClipBound(
...
@@ -203,7 +203,7 @@ bool DetectClipBound(
}
}
return
true
;
return
true
;
}
}
if
(
is_const
(
ret
.
coeff
,
-
1
))
{
if
(
is_const
_int
(
ret
.
coeff
,
-
1
))
{
// -var + shift >=0 -> var <= shift
// -var + shift >=0 -> var <= shift
if
(
p
.
max_value
.
defined
())
{
if
(
p
.
max_value
.
defined
())
{
p
.
max_value
=
ir
::
Min
::
make
(
p
.
max_value
,
ret
.
base
);
p
.
max_value
=
ir
::
Min
::
make
(
p
.
max_value
,
ret
.
base
);
...
...
src/codegen/codegen_cuda.cc
View file @
32af4d28
...
@@ -42,7 +42,7 @@ std::string CodeGenCUDA::Finish() {
...
@@ -42,7 +42,7 @@ std::string CodeGenCUDA::Finish() {
}
}
void
CodeGenCUDA
::
VisitStmt_
(
const
ir
::
For
*
op
)
{
void
CodeGenCUDA
::
VisitStmt_
(
const
ir
::
For
*
op
)
{
CHECK
(
is_
zero
(
op
->
min
));
CHECK
(
is_
const_int
(
op
->
min
,
0
));
if
(
op
->
for_type
==
ir
::
ForType
::
Unrolled
)
{
if
(
op
->
for_type
==
ir
::
ForType
::
Unrolled
)
{
PrintIndent
();
PrintIndent
();
stream
<<
"#pragma unroll
\n
"
;
stream
<<
"#pragma unroll
\n
"
;
...
...
src/codegen/verilog/verilog_ir.cc
View file @
32af4d28
...
@@ -195,7 +195,7 @@ class PipelineExtractor: public IRVisitor {
...
@@ -195,7 +195,7 @@ class PipelineExtractor: public IRVisitor {
ChannelEntry
&
cb
=
cmap_
.
at
(
ch
->
handle_var
.
get
());
ChannelEntry
&
cb
=
cmap_
.
at
(
ch
->
handle_var
.
get
());
trigger
->
signal_index
=
static_cast
<
int
>
(
cb
.
node
->
ctrl_signals
.
size
());
trigger
->
signal_index
=
static_cast
<
int
>
(
cb
.
node
->
ctrl_signals
.
size
());
// Grab the advance constant size.
// Grab the advance constant size.
int
trigger_size
;
int
trigger_size
=
0
;
if
(
attr
->
attr_key
==
attr
::
pipeline_stage_scope
)
{
if
(
attr
->
attr_key
==
attr
::
pipeline_stage_scope
)
{
cb
.
node
->
ctrl_signals
.
push_back
(
cb
.
node
->
ctrl_signals
.
push_back
(
ControlSignalNode
::
make
(
kComputeFinish
,
0
));
ControlSignalNode
::
make
(
kComputeFinish
,
0
));
...
...
src/lang/expr.cc
View file @
32af4d28
...
@@ -5,6 +5,7 @@
...
@@ -5,6 +5,7 @@
#include <tvm/base.h>
#include <tvm/base.h>
#include <tvm/expr.h>
#include <tvm/expr.h>
#include <tvm/ir.h>
#include <tvm/ir.h>
#include <tvm/ir_operator.h>
#include <ir/IRPrinter.h>
#include <ir/IRPrinter.h>
#include <memory>
#include <memory>
...
...
src/lang/ir_operator.cc
View file @
32af4d28
...
@@ -8,6 +8,406 @@
...
@@ -8,6 +8,406 @@
namespace
tvm
{
namespace
tvm
{
/*!
* \brief Check whether type is used to represent index.
*
* Index types are frequently used in shape computation
* and need to be aggressively constant-folded.
*
* \param type The type to represent index.
* \return the checked result.
*/
inline
bool
IsIndexType
(
const
Type
&
type
)
{
return
type
.
is_int
()
&&
type
.
lanes
()
==
1
&&
(
type
.
bits
()
==
32
||
type
.
bits
()
==
64
);
}
// simple cast that only checks if type matches and cast
inline
Expr
SimpleCast
(
const
Type
&
t
,
Expr
value
)
{
if
(
value
.
type
()
==
t
)
return
value
;
return
ir
::
Cast
::
make
(
t
,
value
);
}
// The public function with a quick checking path.
void
BinaryOpMatchTypes
(
Expr
&
lhs
,
Expr
&
rhs
)
{
// NOLINT(*)
if
(
lhs
.
type
()
==
rhs
.
type
())
return
;
Type
ltype
=
lhs
.
type
();
Type
rtype
=
rhs
.
type
();
if
(
ltype
.
lanes
()
==
1
&&
rtype
.
lanes
()
!=
1
)
{
lhs
=
ir
::
Broadcast
::
make
(
lhs
,
rtype
.
lanes
());
}
else
if
(
rtype
.
lanes
()
==
1
&&
ltype
.
lanes
()
!=
1
)
{
rhs
=
ir
::
Broadcast
::
make
(
rhs
,
ltype
.
lanes
());
}
else
{
CHECK
(
ltype
.
lanes
()
==
rtype
.
lanes
())
<<
"Cannot match type "
<<
ltype
<<
" vs "
<<
rtype
;
}
if
(
lhs
.
type
()
==
rhs
.
type
())
return
;
// Only do very simple type coversion
// int->float, int(32)->int(64)
// require the types to be relatively consistent
// This will the reduce amount code generated by operators
// and also help user to find potential type conversion problems.
if
(
!
lhs
.
type
().
is_float
()
&&
rhs
.
type
().
is_float
())
{
// int->float
lhs
=
ir
::
Cast
::
make
(
rhs
.
type
(),
lhs
);
}
else
if
(
lhs
.
type
().
is_float
()
&&
!
rhs
.
type
().
is_float
())
{
// int->float
rhs
=
ir
::
Cast
::
make
(
lhs
.
type
(),
rhs
);
}
else
if
((
lhs
.
type
().
is_int
()
&&
rhs
.
type
().
is_int
())
||
(
lhs
.
type
().
is_uint
()
&&
rhs
.
type
().
is_uint
()))
{
// promote int to higher bits
if
(
lhs
.
type
().
bits
()
<
rhs
.
type
().
bits
())
{
lhs
=
ir
::
Cast
::
make
(
rhs
.
type
(),
lhs
);
}
else
{
rhs
=
ir
::
Cast
::
make
(
lhs
.
type
(),
rhs
);
}
}
else
if
((
lhs
.
type
().
is_int
()
&&
rhs
.
type
().
is_uint
())
||
(
lhs
.
type
().
is_uint
()
&&
rhs
.
type
().
is_int
()))
{
int
bits
=
std
::
max
(
lhs
.
type
().
bits
(),
rhs
.
type
().
bits
());
lhs
=
SimpleCast
(
Int
(
bits
,
lhs
.
type
().
lanes
()),
lhs
);
rhs
=
SimpleCast
(
Int
(
bits
,
rhs
.
type
().
lanes
()),
rhs
);
}
else
{
LOG
(
FATAL
)
<<
"Cannot match type "
<<
ltype
<<
" vs "
<<
rtype
;
}
}
template
<
typename
ValueType
>
inline
bool
ConstPowerHelper
(
ValueType
val
,
int
*
shift
)
{
if
(
val
<=
0
)
return
false
;
shift
[
0
]
=
0
;
while
(
val
!=
0
)
{
if
(
val
&
1
)
{
return
(
val
==
1
);
}
++
shift
[
0
];
val
=
val
>>
1
;
}
return
true
;
}
bool
is_const_power_of_two_integer
(
const
Expr
&
x
,
int
*
shift
)
{
if
(
const
auto
*
op
=
x
.
as
<
ir
::
IntImm
>
())
{
return
ConstPowerHelper
(
op
->
value
,
shift
);
}
else
if
(
const
auto
*
op
=
x
.
as
<
ir
::
UIntImm
>
())
{
return
ConstPowerHelper
(
op
->
value
,
shift
);
}
else
{
return
false
;
}
}
Expr
cast
(
const
Type
&
t
,
Expr
value
)
{
using
ir
::
IntImm
;
if
(
value
.
type
()
==
t
)
return
value
;
// const fold IntImm as they are used in index computations
if
(
t
.
lanes
()
==
1
)
{
if
(
const
IntImm
*
op
=
value
.
as
<
IntImm
>
())
{
return
make_const
(
t
,
op
->
value
);
}
return
ir
::
Cast
::
make
(
t
,
value
);
}
else
{
if
(
value
.
type
().
lanes
()
==
1
)
{
// manually unroll cast
Type
vtype
=
t
.
element_of
();
if
(
value
.
type
()
!=
vtype
)
{
if
(
const
IntImm
*
op
=
value
.
as
<
IntImm
>
())
{
value
=
make_const
(
vtype
,
op
->
value
);
}
else
{
value
=
ir
::
Cast
::
make
(
vtype
,
value
);
}
}
return
ir
::
Broadcast
::
make
(
value
,
t
.
lanes
());
}
else
{
CHECK
(
value
.
type
().
lanes
()
==
t
.
lanes
());
return
ir
::
Cast
::
make
(
t
,
value
);
}
}
}
Expr
reinterpret
(
const
Type
&
t
,
Expr
value
)
{
if
(
value
.
type
()
==
t
)
return
value
;
return
ir
::
Call
::
make
(
t
,
ir
::
Call
::
reinterpret
,
{
value
},
ir
::
Call
::
PureIntrinsic
);
}
#define TVM_CONST_PROPAGATION(BODY) \
using ir::IntImm; \
using ir::UIntImm; \
const IntImm* pa = a.as<IntImm>(); \
const IntImm* pb = b.as<IntImm>(); \
const Type& ta = a.type(); \
const Type& tb = b.type(); \
if (IsIndexType(ta) && IsIndexType(tb)) { \
BODY; \
} \
BinaryOpMatchTypes(a, b);
Expr
operator
+
(
Expr
a
,
Expr
b
)
{
TVM_CONST_PROPAGATION
({
Type
rtype
=
ta
.
bits
()
>=
tb
.
bits
()
?
ta
:
tb
;
if
(
pa
&&
pb
)
return
IntImm
::
make
(
rtype
,
pa
->
value
+
pb
->
value
);
if
(
pa
&&
pa
->
value
==
0
)
return
SimpleCast
(
rtype
,
b
);
if
(
pb
&&
pb
->
value
==
0
)
return
SimpleCast
(
rtype
,
a
);
});
return
ir
::
Add
::
make
(
a
,
b
);
}
Expr
operator
-
(
Expr
a
)
{
using
ir
::
IntImm
;
const
IntImm
*
pa
=
a
.
as
<
IntImm
>
();
if
(
pa
)
{
return
ir
::
IntImm
::
make
(
a
.
type
(),
-
pa
->
value
);
}
return
make_zero
(
a
.
type
())
-
a
;
}
Expr
operator
-
(
Expr
a
,
Expr
b
)
{
TVM_CONST_PROPAGATION
({
Type
rtype
=
ta
.
bits
()
>=
tb
.
bits
()
?
ta
:
tb
;
if
(
pa
&&
pb
)
return
IntImm
::
make
(
rtype
,
pa
->
value
-
pb
->
value
);
if
(
pb
&&
pb
->
value
==
0
)
return
SimpleCast
(
rtype
,
a
);
});
return
ir
::
Sub
::
make
(
a
,
b
);
}
Expr
operator
*
(
Expr
a
,
Expr
b
)
{
TVM_CONST_PROPAGATION
({
Type
rtype
=
ta
.
bits
()
>=
tb
.
bits
()
?
ta
:
tb
;
if
(
pa
&&
pb
)
return
IntImm
::
make
(
rtype
,
pa
->
value
*
pb
->
value
);
if
(
pa
)
{
if
(
pa
->
value
==
1
)
return
SimpleCast
(
rtype
,
b
);
if
(
pa
->
value
==
0
)
return
SimpleCast
(
rtype
,
a
);
}
if
(
pb
)
{
if
(
pb
->
value
==
1
)
return
SimpleCast
(
rtype
,
a
);
if
(
pb
->
value
==
0
)
return
SimpleCast
(
rtype
,
b
);
}
});
return
ir
::
Mul
::
make
(
a
,
b
);
}
Expr
operator
/
(
Expr
a
,
Expr
b
)
{
TVM_CONST_PROPAGATION
({
Type
rtype
=
ta
.
bits
()
>=
tb
.
bits
()
?
ta
:
tb
;
// due to division and mod can have different modes
// only constant fold positive number where rule is fixed.
if
(
pa
&&
pb
&&
pa
->
value
>=
0
&&
pb
->
value
>
0
)
{
return
IntImm
::
make
(
rtype
,
pa
->
value
/
pb
->
value
);
}
if
(
pa
)
{
if
(
pa
->
value
==
0
)
return
SimpleCast
(
rtype
,
a
);
}
if
(
pb
)
{
if
(
pb
->
value
==
1
)
return
SimpleCast
(
rtype
,
a
);
CHECK_NE
(
pb
->
value
,
0
)
<<
"Divide by zero"
;
}
});
return
ir
::
Div
::
make
(
a
,
b
);
}
Expr
operator
%
(
Expr
a
,
Expr
b
)
{
TVM_CONST_PROPAGATION
({
Type
rtype
=
ta
.
bits
()
>=
tb
.
bits
()
?
ta
:
tb
;
// due to division and mod can have different modes
// only constant fold positive number where rule is fixed.
if
(
pa
&&
pb
&&
pa
->
value
>=
0
&&
pb
->
value
>
0
)
{
return
IntImm
::
make
(
rtype
,
pa
->
value
%
pb
->
value
);
}
if
(
pa
)
{
if
(
pa
->
value
==
0
)
return
SimpleCast
(
rtype
,
a
);
}
if
(
pb
)
{
if
(
pb
->
value
==
1
)
return
make_zero
(
rtype
);
CHECK_NE
(
pb
->
value
,
0
)
<<
"Divide by zero"
;
}
});
return
ir
::
Mod
::
make
(
a
,
b
);
}
Expr
min
(
Expr
a
,
Expr
b
)
{
TVM_CONST_PROPAGATION
({
Type
rtype
=
ta
.
bits
()
>=
tb
.
bits
()
?
ta
:
tb
;
if
(
pa
&&
pb
)
return
IntImm
::
make
(
rtype
,
std
::
min
(
pa
->
value
,
pb
->
value
));
});
return
ir
::
Min
::
make
(
a
,
b
);
}
Expr
max
(
Expr
a
,
Expr
b
)
{
TVM_CONST_PROPAGATION
({
Type
rtype
=
ta
.
bits
()
>=
tb
.
bits
()
?
ta
:
tb
;
if
(
pa
&&
pb
)
return
IntImm
::
make
(
rtype
,
std
::
max
(
pa
->
value
,
pb
->
value
));
});
return
ir
::
Max
::
make
(
a
,
b
);
}
Expr
select
(
Expr
cond
,
Expr
true_value
,
Expr
false_value
)
{
using
ir
::
IntImm
;
using
ir
::
UIntImm
;
CHECK
(
cond
.
type
().
is_bool
());
BinaryOpMatchTypes
(
true_value
,
false_value
);
if
(
const
UIntImm
*
op
=
cond
.
as
<
UIntImm
>
())
{
if
(
op
->
value
!=
0
)
{
return
true_value
;
}
else
{
return
false_value
;
}
}
else
if
(
const
IntImm
*
op
=
cond
.
as
<
IntImm
>
())
{
if
(
op
->
value
!=
0
)
{
return
true_value
;
}
else
{
return
false_value
;
}
}
return
ir
::
Select
::
make
(
cond
,
true_value
,
false_value
);
}
Expr
likely
(
Expr
cond
)
{
if
(
is_const
(
cond
))
return
cond
;
return
ir
::
Call
::
make
(
cond
.
type
(),
ir
::
Call
::
likely
,
{
cond
},
ir
::
Call
::
PureIntrinsic
);
}
Expr
operator
>
(
Expr
a
,
Expr
b
)
{
TVM_CONST_PROPAGATION
({
if
(
pa
&&
pb
)
return
UIntImm
::
make
(
UInt
(
1
),
pa
->
value
>
pb
->
value
);
});
return
ir
::
GT
::
make
(
a
,
b
);
}
Expr
operator
>=
(
Expr
a
,
Expr
b
)
{
TVM_CONST_PROPAGATION
({
if
(
pa
&&
pb
)
return
UIntImm
::
make
(
UInt
(
1
),
pa
->
value
>=
pb
->
value
);
});
return
ir
::
GE
::
make
(
a
,
b
);
}
Expr
operator
<
(
Expr
a
,
Expr
b
)
{
TVM_CONST_PROPAGATION
({
if
(
pa
&&
pb
)
return
UIntImm
::
make
(
UInt
(
1
),
pa
->
value
<
pb
->
value
);
});
return
ir
::
LT
::
make
(
a
,
b
);
}
Expr
operator
<=
(
Expr
a
,
Expr
b
)
{
TVM_CONST_PROPAGATION
({
if
(
pa
&&
pb
)
return
UIntImm
::
make
(
UInt
(
1
),
pa
->
value
<=
pb
->
value
);
});
return
ir
::
LE
::
make
(
a
,
b
);
}
Expr
operator
==
(
Expr
a
,
Expr
b
)
{
TVM_CONST_PROPAGATION
({
if
(
pa
&&
pb
)
return
UIntImm
::
make
(
UInt
(
1
),
pa
->
value
==
pb
->
value
);
});
return
ir
::
EQ
::
make
(
a
,
b
);
}
Expr
operator
!=
(
Expr
a
,
Expr
b
)
{
TVM_CONST_PROPAGATION
({
if
(
pa
&&
pb
)
return
UIntImm
::
make
(
UInt
(
1
),
pa
->
value
!=
pb
->
value
);
});
return
ir
::
NE
::
make
(
a
,
b
);
}
Expr
operator
&&
(
Expr
a
,
Expr
b
)
{
using
ir
::
UIntImm
;
const
UIntImm
*
pa
=
a
.
as
<
UIntImm
>
();
const
UIntImm
*
pb
=
b
.
as
<
UIntImm
>
();
if
(
pa
&&
pb
)
{
return
UIntImm
::
make
(
UInt
(
1
),
pa
->
value
&&
pb
->
value
);
}
return
ir
::
And
::
make
(
a
,
b
);
}
Expr
operator
||
(
Expr
a
,
Expr
b
)
{
using
ir
::
UIntImm
;
const
UIntImm
*
pa
=
a
.
as
<
UIntImm
>
();
const
UIntImm
*
pb
=
b
.
as
<
UIntImm
>
();
if
(
pa
&&
pb
)
{
return
UIntImm
::
make
(
UInt
(
1
),
pa
->
value
||
pb
->
value
);
}
return
ir
::
Or
::
make
(
a
,
b
);
}
Expr
operator
!
(
Expr
a
)
{
using
ir
::
UIntImm
;
const
UIntImm
*
pa
=
a
.
as
<
UIntImm
>
();
if
(
pa
)
{
return
UIntImm
::
make
(
UInt
(
1
),
!
(
pa
->
value
));
}
return
ir
::
Not
::
make
(
a
);
}
Expr
operator
>>
(
Expr
a
,
Expr
b
)
{
TVM_CONST_PROPAGATION
({
Type
rtype
=
ta
.
bits
()
>=
tb
.
bits
()
?
ta
:
tb
;
if
(
pa
&&
pb
)
return
IntImm
::
make
(
rtype
,
(
pa
->
value
>>
pb
->
value
));
if
(
pb
)
{
if
(
pb
->
value
==
0
)
return
SimpleCast
(
rtype
,
a
);
}
});
return
ir
::
Call
::
make
(
a
.
type
(),
ir
::
Call
::
shift_right
,
{
a
,
b
},
ir
::
Call
::
PureIntrinsic
);
}
Expr
operator
<<
(
Expr
a
,
Expr
b
)
{
TVM_CONST_PROPAGATION
({
Type
rtype
=
ta
.
bits
()
>=
tb
.
bits
()
?
ta
:
tb
;
if
(
pa
&&
pb
)
return
IntImm
::
make
(
rtype
,
(
pa
->
value
<<
pb
->
value
));
if
(
pb
)
{
if
(
pb
->
value
==
0
)
return
SimpleCast
(
rtype
,
a
);
}
});
return
ir
::
Call
::
make
(
a
.
type
(),
ir
::
Call
::
shift_left
,
{
a
,
b
},
ir
::
Call
::
PureIntrinsic
);
}
Expr
operator
&
(
Expr
a
,
Expr
b
)
{
TVM_CONST_PROPAGATION
({
Type
rtype
=
ta
.
bits
()
>=
tb
.
bits
()
?
ta
:
tb
;
if
(
pa
&&
pb
)
return
IntImm
::
make
(
rtype
,
(
pa
->
value
&
pb
->
value
));
});
return
ir
::
Call
::
make
(
a
.
type
(),
ir
::
Call
::
bitwise_and
,
{
a
,
b
},
ir
::
Call
::
PureIntrinsic
);
}
Expr
operator
|
(
Expr
a
,
Expr
b
)
{
TVM_CONST_PROPAGATION
({
Type
rtype
=
ta
.
bits
()
>=
tb
.
bits
()
?
ta
:
tb
;
if
(
pa
&&
pb
)
return
IntImm
::
make
(
rtype
,
(
pa
->
value
|
pb
->
value
));
});
return
ir
::
Call
::
make
(
a
.
type
(),
ir
::
Call
::
bitwise_or
,
{
a
,
b
},
ir
::
Call
::
PureIntrinsic
);
}
Expr
operator
^
(
Expr
a
,
Expr
b
)
{
TVM_CONST_PROPAGATION
({
Type
rtype
=
ta
.
bits
()
>=
tb
.
bits
()
?
ta
:
tb
;
if
(
pa
&&
pb
)
return
IntImm
::
make
(
rtype
,
(
pa
->
value
^
pb
->
value
));
});
return
ir
::
Call
::
make
(
a
.
type
(),
ir
::
Call
::
bitwise_xor
,
{
a
,
b
},
ir
::
Call
::
PureIntrinsic
);
}
Expr
operator
~
(
Expr
a
)
{
CHECK
(
a
.
type
().
is_int
()
||
a
.
type
().
is_uint
());
return
ir
::
Call
::
make
(
a
.
type
(),
ir
::
Call
::
bitwise_not
,
{
a
},
ir
::
Call
::
PureIntrinsic
);
}
Expr
pow
(
Expr
x
,
Expr
y
)
{
BinaryOpMatchTypes
(
x
,
y
);
CHECK
(
x
.
type
().
is_float
())
<<
"power only applies to float"
;
return
ir
::
Call
::
make
(
x
.
type
(),
"pow"
,
{
x
,
y
},
ir
::
Call
::
PureIntrinsic
);
}
Expr
abs
(
Expr
x
)
{
if
(
x
.
type
().
is_int
())
{
return
select
(
x
>=
make_zero
(
x
.
type
()),
x
,
-
x
);
}
else
if
(
x
.
type
().
is_float
())
{
return
ir
::
Call
::
make
(
x
.
type
(),
"fabs"
,
{
x
},
ir
::
Call
::
PureIntrinsic
);
}
else
if
(
x
.
type
().
is_uint
())
{
return
x
;
}
else
{
LOG
(
FATAL
)
<<
"Data type "
<<
x
.
type
()
<<
" not supported for absolute op. Skipping absolute op..."
;
return
x
;
}
}
Expr
sum
(
Expr
source
,
Array
<
IterVar
>
rdom
)
{
Expr
sum
(
Expr
source
,
Array
<
IterVar
>
rdom
)
{
Var
x
(
"x"
,
source
.
type
()),
y
(
"y"
,
source
.
type
());
Var
x
(
"x"
,
source
.
type
()),
y
(
"y"
,
source
.
type
());
Expr
result
=
ir
::
Add
::
make
(
x
,
y
);
Expr
result
=
ir
::
Add
::
make
(
x
,
y
);
...
@@ -38,7 +438,7 @@ Expr min(Expr source, Array<IterVar> rdom) {
...
@@ -38,7 +438,7 @@ Expr min(Expr source, Array<IterVar> rdom) {
Expr
prod
(
Expr
source
,
Array
<
IterVar
>
rdom
)
{
Expr
prod
(
Expr
source
,
Array
<
IterVar
>
rdom
)
{
Var
x
(
"x"
,
source
.
type
()),
y
(
"y"
,
source
.
type
());
Var
x
(
"x"
,
source
.
type
()),
y
(
"y"
,
source
.
type
());
Expr
result
=
ir
::
Mul
::
make
(
x
,
y
);
Expr
result
=
ir
::
Mul
::
make
(
x
,
y
);
Expr
identity_element
=
make_
one
(
source
.
type
()
);
Expr
identity_element
=
make_
const
(
source
.
type
(),
1
);
ir
::
CommReducer
combiner
=
ir
::
CommReducer
combiner
=
ir
::
CommReducerNode
::
make
({
x
},
{
y
},
{
result
},
{
identity_element
});
ir
::
CommReducerNode
::
make
({
x
},
{
y
},
{
result
},
{
identity_element
});
return
ir
::
Reduce
::
make
(
combiner
,
{
source
},
rdom
,
make_const
(
Bool
(
1
),
true
),
0
);
return
ir
::
Reduce
::
make
(
combiner
,
{
source
},
rdom
,
make_const
(
Bool
(
1
),
true
),
0
);
...
...
src/pass/ir_util.h
View file @
32af4d28
...
@@ -7,6 +7,7 @@
...
@@ -7,6 +7,7 @@
#define TVM_PASS_IR_UTIL_H_
#define TVM_PASS_IR_UTIL_H_
#include <tvm/ir.h>
#include <tvm/ir.h>
#include <tvm/ir_operator.h>
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/device_api.h>
#include <vector>
#include <vector>
...
@@ -75,7 +76,7 @@ inline Expr TVMStructGet(
...
@@ -75,7 +76,7 @@ inline Expr TVMStructGet(
Array
<
Expr
>
args
=
{
Array
<
Expr
>
args
=
{
handle
,
handle
,
make_const
(
Int
(
32
),
index
),
make_const
(
Int
(
32
),
index
),
make_const
(
Int
(
32
),
kind
)};
make_const
(
Int
(
32
),
static_cast
<
int
>
(
kind
)
)};
return
Call
::
make
(
dtype
,
intrinsic
::
tvm_struct_get
,
args
,
Call
::
PureIntrinsic
);
return
Call
::
make
(
dtype
,
intrinsic
::
tvm_struct_get
,
args
,
Call
::
PureIntrinsic
);
}
}
...
@@ -125,7 +126,7 @@ inline Stmt TVMStructSet(
...
@@ -125,7 +126,7 @@ inline Stmt TVMStructSet(
Array
<
Expr
>
args
=
{
Array
<
Expr
>
args
=
{
handle
,
handle
,
make_const
(
Int
(
32
),
index
),
make_const
(
Int
(
32
),
index
),
make_const
(
Int
(
32
),
kind
),
make_const
(
Int
(
32
),
static_cast
<
int
>
(
kind
)
),
value
};
value
};
return
Evaluate
::
make
(
return
Evaluate
::
make
(
Call
::
make
(
Int
(
32
),
intrinsic
::
tvm_struct_set
,
args
,
Call
::
Intrinsic
));
Call
::
make
(
Int
(
32
),
intrinsic
::
tvm_struct_set
,
args
,
Call
::
Intrinsic
));
...
...
src/pass/split_pipeline.cc
View file @
32af4d28
...
@@ -102,9 +102,8 @@ class MarkChannelAccess : public IRMutator {
...
@@ -102,9 +102,8 @@ class MarkChannelAccess : public IRMutator {
}
else
{
}
else
{
alloc_size
=
op
->
extents
[
0
];
alloc_size
=
op
->
extents
[
0
];
for
(
size_t
i
=
1
;
i
<
op
->
extents
.
size
();
++
i
)
{
for
(
size_t
i
=
1
;
i
<
op
->
extents
.
size
();
++
i
)
{
alloc_size
*=
op
->
extents
[
i
];
alloc_size
=
alloc_size
*
op
->
extents
[
i
];
}
}
alloc_size
=
ir
::
Simplify
(
alloc_size
);
}
}
if
(
rw
.
write_count
)
{
if
(
rw
.
write_count
)
{
...
...
src/pass/storage_rewrite.cc
View file @
32af4d28
...
@@ -578,7 +578,7 @@ class StoragePlanRewriter : public IRMutator {
...
@@ -578,7 +578,7 @@ class StoragePlanRewriter : public IRMutator {
combo_size
=
combo_size
/
type_bits
;
combo_size
=
combo_size
/
type_bits
;
// round up for can not divided
// round up for can not divided
if
(
!
divided
)
{
if
(
!
divided
)
{
combo_size
+=
make_const
(
Int
(
32
),
1
);
combo_size
=
combo_size
+
make_const
(
Int
(
32
),
1
);
}
}
combo_size
=
ir
::
Simplify
(
combo_size
);
combo_size
=
ir
::
Simplify
(
combo_size
);
e
->
new_alloc
=
Allocate
::
make
(
e
->
new_alloc
=
Allocate
::
make
(
...
...
src/pass/vectorize_loop.cc
View file @
32af4d28
...
@@ -437,7 +437,6 @@ class LoopVectorizer : public IRMutator {
...
@@ -437,7 +437,6 @@ class LoopVectorizer : public IRMutator {
Stmt
Mutate_
(
const
For
*
op
,
const
Stmt
&
s
)
final
{
Stmt
Mutate_
(
const
For
*
op
,
const
Stmt
&
s
)
final
{
if
(
op
->
for_type
==
ForType
::
Vectorized
)
{
if
(
op
->
for_type
==
ForType
::
Vectorized
)
{
CHECK
(
is_zero
(
op
->
min
));
CHECK
(
is_zero
(
op
->
min
));
CHECK
(
is_positive_const
(
op
->
extent
));
int
lanes
=
0
;
int
lanes
=
0
;
bool
succ
=
arith
::
GetConstInt
(
op
->
extent
,
&
lanes
);
bool
succ
=
arith
::
GetConstInt
(
op
->
extent
,
&
lanes
);
if
(
!
succ
||
lanes
<
1
)
{
if
(
!
succ
||
lanes
<
1
)
{
...
...
tests/cpp/ir_mutator_test.cc
View file @
32af4d28
#include <dmlc/logging.h>
#include <dmlc/logging.h>
#include <gtest/gtest.h>
#include <gtest/gtest.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_operator.h>
namespace
{
namespace
{
using
namespace
tvm
::
ir
;
using
namespace
tvm
::
ir
;
...
...
tests/python/unittest/test_arith_intset.py
View file @
32af4d28
...
@@ -35,7 +35,7 @@ def test_deduce():
...
@@ -35,7 +35,7 @@ def test_deduce():
e1
=
(
a
*
4
+
b
<
c
)
e1
=
(
a
*
4
+
b
<
c
)
res1
=
tvm
.
arith
.
DeduceBound
(
a
,
e1
,
{
b
:
b_s
,
c
:
c_s
,
d
:
d_s
},
{})
res1
=
tvm
.
arith
.
DeduceBound
(
a
,
e1
,
{
b
:
b_s
,
c
:
c_s
,
d
:
d_s
},
{})
ans1
=
(((
c
-
b
)
+
-
1
)
/
4
)
ans1
=
(((
c
-
b
)
+
-
1
)
/
4
)
assert
str
(
tvm
.
ir_pass
.
Simplify
(
res1
.
max
()))
==
str
(
ans1
)
assert
str
(
tvm
.
ir_pass
.
Simplify
(
res1
.
max
()))
==
str
(
ans1
)
e2
=
(
tvm
.
max
(
5
,
a
*
4
)
<
0
)
e2
=
(
tvm
.
max
(
5
,
a
*
4
)
<
0
)
...
@@ -63,7 +63,7 @@ def test_check():
...
@@ -63,7 +63,7 @@ def test_check():
assert
res1
.
is_nothing
()
assert
res1
.
is_nothing
()
# multiple compare operators
# multiple compare operators
res2
=
tvm
.
arith
.
DeduceBound
(
a
,
(
a
+
b
>
3
)
>
c
,
{
b
:
b_s
,
c
:
c_s
},
{})
res2
=
tvm
.
arith
.
DeduceBound
(
a
,
(
a
+
b
>
3
)
.
astype
(
c
.
dtype
)
>
c
,
{
b
:
b_s
,
c
:
c_s
},
{})
assert
res2
.
is_nothing
()
assert
res2
.
is_nothing
()
# multiple target variable
# multiple target variable
...
@@ -88,11 +88,11 @@ def test_deduce_basic():
...
@@ -88,11 +88,11 @@ def test_deduce_basic():
res1
=
tvm
.
arith
.
DeduceBound
(
a
,
e0
<=
17
,
{
b
:
b_s
},
{
b
:
b_s
})
res1
=
tvm
.
arith
.
DeduceBound
(
a
,
e0
<=
17
,
{
b
:
b_s
},
{
b
:
b_s
})
[
x
,
y
]
=
[
res1
.
max
(),
b_s
.
max
()]
if
coff
>
0
else
[
res1
.
min
(),
b_s
.
min
()]
[
x
,
y
]
=
[
res1
.
max
(),
b_s
.
max
()]
if
coff
>
0
else
[
res1
.
min
(),
b_s
.
min
()]
assert
(
tvm
.
ir_pass
.
Simplify
((
x
*
coff
+
3
+
y
)
<=
17
))
.
value
==
1
assert
(
tvm
.
ir_pass
.
Simplify
((
x
*
coff
+
3
+
y
)
<=
17
))
.
value
==
1
res1
=
tvm
.
arith
.
DeduceBound
(
a
,
e0
>=
17
,
{
b
:
b_s
},
{
b
:
b_s
})
res1
=
tvm
.
arith
.
DeduceBound
(
a
,
e0
>=
17
,
{
b
:
b_s
},
{
b
:
b_s
})
[
x
,
y
]
=
[
res1
.
max
(),
b_s
.
max
()]
if
coff
<
0
else
[
res1
.
min
(),
b_s
.
min
()]
[
x
,
y
]
=
[
res1
.
max
(),
b_s
.
max
()]
if
coff
<
0
else
[
res1
.
min
(),
b_s
.
min
()]
assert
(
tvm
.
ir_pass
.
Simplify
((
x
*
coff
+
3
+
y
)
>=
17
))
.
value
==
1
assert
(
tvm
.
ir_pass
.
Simplify
((
x
*
coff
+
3
+
y
)
>=
17
))
.
value
==
1
test_basic
(
0
,
4
,
4
)
test_basic
(
0
,
4
,
4
)
test_basic
(
1
,
5
,
4
)
test_basic
(
1
,
5
,
4
)
test_basic
(
2
,
6
,
4
)
test_basic
(
2
,
6
,
4
)
...
@@ -137,4 +137,3 @@ if __name__ == "__main__":
...
@@ -137,4 +137,3 @@ if __name__ == "__main__":
test_check
()
test_check
()
test_deduce_basic
()
test_deduce_basic
()
test_deduce_complex
()
test_deduce_complex
()
tests/python/unittest/test_lang_basic.py
View file @
32af4d28
...
@@ -8,7 +8,7 @@ def test_const():
...
@@ -8,7 +8,7 @@ def test_const():
def
test_make
():
def
test_make
():
x
=
tvm
.
const
(
1
)
x
=
tvm
.
const
(
1
)
y
=
tvm
.
make
.
IntImm
(
'int32'
,
1
)
y
=
tvm
.
var
(
"x"
)
z
=
x
+
y
z
=
x
+
y
assert
isinstance
(
tvm
.
max
(
x
,
y
),
tvm
.
expr
.
Max
)
assert
isinstance
(
tvm
.
max
(
x
,
y
),
tvm
.
expr
.
Max
)
assert
isinstance
(
tvm
.
min
(
x
,
y
),
tvm
.
expr
.
Min
)
assert
isinstance
(
tvm
.
min
(
x
,
y
),
tvm
.
expr
.
Min
)
...
...
tests/python/unittest/test_lang_operator.py
0 → 100644
View file @
32af4d28
import
tvm
def
test_const_fold
():
def
check
(
f
,
*
args
):
x
=
f
(
*
[
tvm
.
const
(
x
)
for
x
in
args
])
y
=
f
(
*
args
)
if
not
isinstance
(
x
,
(
tvm
.
expr
.
IntImm
,
tvm
.
expr
.
UIntImm
))
or
x
.
value
!=
int
(
y
):
raise
ValueError
(
"check error:
%
s vs
%
s "
%
(
x
,
y
))
check
(
lambda
x
,
y
:
x
+
y
,
3
,
4
)
check
(
lambda
x
,
y
:
x
*
y
,
3
,
12
)
check
(
lambda
x
,
y
:
x
*
y
-
10
,
3
,
12
)
check
(
lambda
x
,
y
:
x
-
y
%
10
,
3
,
12
)
check
(
lambda
x
,
y
:
x
//
y
+
10
,
100
,
12
)
check
(
lambda
x
,
y
:
x
&
y
+
10
,
112
,
128
)
check
(
lambda
x
,
y
:
x
>
y
,
112
,
128
)
check
(
lambda
x
,
y
:
x
<
y
,
112
,
128
)
check
(
lambda
x
,
y
:
x
<=
y
,
112
,
128
)
check
(
lambda
x
,
y
:
x
>=
y
,
112
,
128
)
check
(
lambda
x
,
y
:
(
x
|
y
)
^
10
,
112
,
128
)
def
test_const_fold2
():
x
=
tvm
.
var
(
"x"
)
assert
(
x
+
0
)
.
same_as
(
x
)
assert
(
0
+
x
)
.
same_as
(
x
)
assert
(
x
-
0
)
.
same_as
(
x
)
assert
(
x
%
1
)
.
value
==
0
assert
(
x
*
1
)
.
same_as
(
x
)
assert
(
1
*
x
)
.
same_as
(
x
)
assert
isinstance
((
1
/
x
),
tvm
.
expr
.
Div
)
if
__name__
==
"__main__"
:
test_const_fold
()
test_const_fold2
()
tests/python/unittest/test_lang_reflection.py
View file @
32af4d28
...
@@ -15,7 +15,7 @@ def test_make_smap():
...
@@ -15,7 +15,7 @@ def test_make_smap():
# save load json
# save load json
x
=
tvm
.
const
(
1
)
x
=
tvm
.
const
(
1
)
y
=
tvm
.
const
(
10
)
y
=
tvm
.
const
(
10
)
z
=
x
+
y
z
=
tvm
.
expr
.
Add
(
x
,
y
)
smap
=
tvm
.
convert
({
"z"
:
z
,
"x"
:
x
})
smap
=
tvm
.
convert
({
"z"
:
z
,
"x"
:
x
})
json_str
=
tvm
.
save_json
(
tvm
.
convert
([
smap
]))
json_str
=
tvm
.
save_json
(
tvm
.
convert
([
smap
]))
arr
=
tvm
.
load_json
(
json_str
)
arr
=
tvm
.
load_json
(
json_str
)
...
...
tests/python/unittest/test_pass_simplify.py
View file @
32af4d28
...
@@ -53,7 +53,6 @@ def test_canonical():
...
@@ -53,7 +53,6 @@ def test_canonical():
assert
(
tvm
.
ir_pass
.
Equal
(
ret1
,
ret2
))
assert
(
tvm
.
ir_pass
.
Equal
(
ret1
,
ret2
))
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
test_modular
()
test_bound
()
test_bound
()
test_basic
()
test_basic
()
test_simplify
()
test_simplify
()
...
...
topi/include/topi/elemwise.h
View file @
32af4d28
...
@@ -163,7 +163,7 @@ inline Tensor full(const Array<Expr>& shape,
...
@@ -163,7 +163,7 @@ inline Tensor full(const Array<Expr>& shape,
const
Expr
fill_value
,
const
Expr
fill_value
,
std
::
string
name
=
"tensor"
,
std
::
string
name
=
"tensor"
,
std
::
string
tag
=
kElementWise
)
{
std
::
string
tag
=
kElementWise
)
{
Expr
ev
=
lossless_
cast
(
dtype
,
fill_value
);
Expr
ev
=
cast
(
dtype
,
fill_value
);
if
(
!
ev
.
defined
())
{
if
(
!
ev
.
defined
())
{
LOG
(
ERROR
)
<<
"Can't cast fill_value to "
<<
dtype
;
LOG
(
ERROR
)
<<
"Can't cast fill_value to "
<<
dtype
;
}
}
...
@@ -173,7 +173,7 @@ inline Tensor full(const Array<Expr>& shape,
...
@@ -173,7 +173,7 @@ inline Tensor full(const Array<Expr>& shape,
}
}
/*!
/*!
* \brief Creates an operation that construct a tensor with same shape as input tensor,
* \brief Creates an operation that construct a tensor with same shape as input tensor,
* then fill a tensor with fill_value
* then fill a tensor with fill_value
*
*
* \param x The input tensor
* \param x The input tensor
...
@@ -187,10 +187,7 @@ inline Tensor full_like(const Tensor& x,
...
@@ -187,10 +187,7 @@ inline Tensor full_like(const Tensor& x,
const
Expr
fill_value
,
const
Expr
fill_value
,
std
::
string
name
=
"tensor"
,
std
::
string
name
=
"tensor"
,
std
::
string
tag
=
kElementWise
)
{
std
::
string
tag
=
kElementWise
)
{
Expr
ev
=
lossless_cast
(
x
->
dtype
,
fill_value
);
Expr
ev
=
cast
(
x
->
dtype
,
fill_value
);
if
(
!
ev
.
defined
())
{
LOG
(
ERROR
)
<<
"Can't cast fill_value to "
<<
x
->
dtype
;
}
return
compute
(
x
->
shape
,
[
&
](
const
Array
<
Var
>&
i
)
{
return
compute
(
x
->
shape
,
[
&
](
const
Array
<
Var
>&
i
)
{
return
ev
;
return
ev
;
},
name
,
tag
);
},
name
,
tag
);
...
...
topi/include/topi/nn/pooling.h
View file @
32af4d28
...
@@ -94,10 +94,10 @@ inline Tensor pool_impl(const Tensor& x,
...
@@ -94,10 +94,10 @@ inline Tensor pool_impl(const Tensor& x,
out_shape
.
Set
(
height_axis
,
out_height
);
out_shape
.
Set
(
height_axis
,
out_height
);
out_shape
.
Set
(
width_axis
,
out_width
);
out_shape
.
Set
(
width_axis
,
out_width
);
const
int64_t
*
padding_h0
=
HalideIR
::
Internal
::
as_const_int
(
pad_top
);
const
int64_t
*
padding_h0
=
as_const_int
(
pad_top
);
const
int64_t
*
padding_w0
=
HalideIR
::
Internal
::
as_const_int
(
pad_left
);
const
int64_t
*
padding_w0
=
as_const_int
(
pad_left
);
const
int64_t
*
padding_h1
=
HalideIR
::
Internal
::
as_const_int
(
pad_bottom
);
const
int64_t
*
padding_h1
=
as_const_int
(
pad_bottom
);
const
int64_t
*
padding_w1
=
HalideIR
::
Internal
::
as_const_int
(
pad_right
);
const
int64_t
*
padding_w1
=
as_const_int
(
pad_right
);
const
bool
do_pad
=
((
padding_h0
&&
*
padding_h0
)
||
(
padding_w0
&&
*
padding_w0
))
||
const
bool
do_pad
=
((
padding_h0
&&
*
padding_h0
)
||
(
padding_w0
&&
*
padding_w0
))
||
((
padding_h1
&&
*
padding_h1
)
||
(
padding_w1
&&
*
padding_w1
));
((
padding_h1
&&
*
padding_h1
)
||
(
padding_w1
&&
*
padding_w1
));
...
@@ -192,7 +192,7 @@ inline bool find_height_width(const std::string& layout,
...
@@ -192,7 +192,7 @@ inline bool find_height_width(const std::string& layout,
* Since pooling does not care about the factor size of dimensions
* Since pooling does not care about the factor size of dimensions
* other than `H` and `W`, one can pass `NCHWc` as well.
* other than `H` and `W`, one can pass `NCHWc` as well.
* \param count_include_pad Whether include padding in the calculation when pool_type is 'avg'
* \param count_include_pad Whether include padding in the calculation when pool_type is 'avg'
*
*
*
*
* \return The output tensor in the same layout
* \return The output tensor in the same layout
*/
*/
...
...
topi/python/topi/vision/ssd/multibox.py
View file @
32af4d28
...
@@ -164,10 +164,10 @@ def transform_loc_ir(cls_prob, loc_pred, anchor, valid_count, out, clip, thresho
...
@@ -164,10 +164,10 @@ def transform_loc_ir(cls_prob, loc_pred, anchor, valid_count, out, clip, thresho
oy
=
py
*
vy
*
ah
+
ay
oy
=
py
*
vy
*
ah
+
ay
ow
=
tvm
.
exp
(
pw
*
vw
)
*
aw
/
2.0
ow
=
tvm
.
exp
(
pw
*
vw
)
*
aw
/
2.0
oh
=
tvm
.
exp
(
ph
*
vh
)
*
ah
/
2.0
oh
=
tvm
.
exp
(
ph
*
vh
)
*
ah
/
2.0
return
tvm
.
select
(
clip
,
tvm
.
ma
ke
.
Max
(
0
,
tvm
.
make
.
M
in
(
1
,
ox
-
ow
)),
ox
-
ow
),
\
return
tvm
.
select
(
clip
,
tvm
.
ma
x
(
0
,
tvm
.
m
in
(
1
,
ox
-
ow
)),
ox
-
ow
),
\
tvm
.
select
(
clip
,
tvm
.
ma
ke
.
Max
(
0
,
tvm
.
make
.
M
in
(
1
,
oy
-
oh
)),
oy
-
oh
),
\
tvm
.
select
(
clip
,
tvm
.
ma
x
(
0
,
tvm
.
m
in
(
1
,
oy
-
oh
)),
oy
-
oh
),
\
tvm
.
select
(
clip
,
tvm
.
ma
ke
.
Max
(
0
,
tvm
.
make
.
M
in
(
1
,
ox
+
ow
)),
ox
+
ow
),
\
tvm
.
select
(
clip
,
tvm
.
ma
x
(
0
,
tvm
.
m
in
(
1
,
ox
+
ow
)),
ox
+
ow
),
\
tvm
.
select
(
clip
,
tvm
.
ma
ke
.
Max
(
0
,
tvm
.
make
.
M
in
(
1
,
oy
+
oh
)),
oy
+
oh
)
tvm
.
select
(
clip
,
tvm
.
ma
x
(
0
,
tvm
.
m
in
(
1
,
oy
+
oh
)),
oy
+
oh
)
batch_size
=
cls_prob
.
shape
[
0
]
batch_size
=
cls_prob
.
shape
[
0
]
num_classes
=
cls_prob
.
shape
[
1
]
num_classes
=
cls_prob
.
shape
[
1
]
...
@@ -191,7 +191,7 @@ def transform_loc_ir(cls_prob, loc_pred, anchor, valid_count, out, clip, thresho
...
@@ -191,7 +191,7 @@ def transform_loc_ir(cls_prob, loc_pred, anchor, valid_count, out, clip, thresho
with
ib
.
if_scope
(
j
>
0
):
with
ib
.
if_scope
(
j
>
0
):
temp
=
p_cls_prob
[
n
*
num_anchors
*
num_classes
+
j
*
num_anchors
+
i
]
temp
=
p_cls_prob
[
n
*
num_anchors
*
num_classes
+
j
*
num_anchors
+
i
]
cls_id
[
0
]
=
tvm
.
select
(
temp
>
score
[
0
],
j
,
cls_id
[
0
])
cls_id
[
0
]
=
tvm
.
select
(
temp
>
score
[
0
],
j
,
cls_id
[
0
])
score
[
0
]
=
tvm
.
ma
ke
.
Ma
x
(
temp
,
score
[
0
])
score
[
0
]
=
tvm
.
max
(
temp
,
score
[
0
])
with
ib
.
if_scope
(
tvm
.
all
(
cls_id
[
0
]
>
0
,
score
[
0
]
<
threshold
)):
with
ib
.
if_scope
(
tvm
.
all
(
cls_id
[
0
]
>
0
,
score
[
0
]
<
threshold
)):
cls_id
[
0
]
=
0
cls_id
[
0
]
=
0
# [id, prob, xmin, ymin, xmax, ymax]
# [id, prob, xmin, ymin, xmax, ymax]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment