Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
9d20fa1b
Unverified
Commit
9d20fa1b
authored
Jan 11, 2019
by
Tianqi Chen
Committed by
GitHub
Jan 11, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[PASS][TENSOR] Use correct select semantics (#2394)
parent
98e761f8
Hide whitespace changes
Inline
Side-by-side
Showing
35 changed files
with
193 additions
and
125 deletions
+193
-125
docs/api/python/intrin.rst
+2
-0
docs/api/python/tvm.rst
+0
-2
include/tvm/ir_operator.h
+2
-2
python/tvm/api.py
+0
-22
python/tvm/expr.py
+8
-0
python/tvm/intrin.py
+36
-0
src/arithmetic/int_set.cc
+4
-2
src/lang/ir_operator.cc
+14
-4
src/pass/inject_copy_intrin.cc
+31
-7
tests/python/integration/test_reduce.py
+1
-1
tests/python/unittest/test_codegen_llvm.py
+3
-3
tests/python/unittest/test_pass_inject_copy_intrin.py
+4
-4
tests/python/unittest/test_pass_loop_partition.py
+2
-2
tests/python/unittest/test_pass_rewrite_unsafe_select.py
+6
-5
tests/python/unittest/test_schedule_schedule_ops.py
+4
-4
topi/include/topi/image/resize.h
+4
-4
topi/include/topi/nn.h
+8
-5
topi/include/topi/nn/dilate.h
+2
-1
topi/include/topi/reduction.h
+4
-4
topi/include/topi/transform.h
+6
-6
topi/python/topi/cuda/conv2d_transpose_nchw.py
+1
-1
topi/python/topi/cuda/nms.py
+4
-4
topi/python/topi/cuda/ssd/multibox.py
+12
-11
topi/python/topi/mali/conv2d.py
+4
-3
topi/python/topi/nn/dilate.py
+1
-1
topi/python/topi/nn/elemwise.py
+3
-2
topi/python/topi/nn/pad.py
+1
-1
topi/python/topi/nn/util.py
+2
-2
topi/python/topi/util.py
+3
-3
topi/python/topi/vision/nms.py
+4
-3
topi/python/topi/vision/rcnn/roi_align.py
+1
-1
topi/python/topi/vision/ssd/multibox.py
+12
-11
topi/tests/python/test_topi_conv2d_nchw.py
+1
-1
tutorials/language/tuple_inputs.py
+2
-2
tutorials/optimize/opt_conv_cuda.py
+1
-1
No files found.
docs/api/python/intrin.rst
View file @
9d20fa1b
...
@@ -11,6 +11,7 @@ tvm.intrin
...
@@ -11,6 +11,7 @@ tvm.intrin
tvm.call_extern
tvm.call_extern
tvm.call_llvm_intrin
tvm.call_llvm_intrin
tvm.register_intrin_rule
tvm.register_intrin_rule
tvm.if_then_else
tvm.exp
tvm.exp
tvm.log
tvm.log
tvm.floor
tvm.floor
...
@@ -26,6 +27,7 @@ tvm.intrin
...
@@ -26,6 +27,7 @@ tvm.intrin
.. autofunction:: tvm.call_extern
.. autofunction:: tvm.call_extern
.. autofunction:: tvm.call_llvm_intrin
.. autofunction:: tvm.call_llvm_intrin
.. autofunction:: tvm.register_intrin_rule
.. autofunction:: tvm.register_intrin_rule
.. autofunction:: tvm.if_then_else
.. autofunction:: tvm.exp
.. autofunction:: tvm.exp
.. autofunction:: tvm.log
.. autofunction:: tvm.log
.. autofunction:: tvm.floor
.. autofunction:: tvm.floor
...
...
docs/api/python/tvm.rst
View file @
9d20fa1b
...
@@ -15,7 +15,6 @@ The user facing API for computation declaration.
...
@@ -15,7 +15,6 @@ The user facing API for computation declaration.
tvm.extern
tvm.extern
tvm.decl_buffer
tvm.decl_buffer
tvm.reduce_axis
tvm.reduce_axis
tvm.select
tvm.thread_axis
tvm.thread_axis
tvm.comm_reducer
tvm.comm_reducer
tvm.sum
tvm.sum
...
@@ -34,7 +33,6 @@ The user facing API for computation declaration.
...
@@ -34,7 +33,6 @@ The user facing API for computation declaration.
.. autofunction:: tvm.extern
.. autofunction:: tvm.extern
.. autofunction:: tvm.decl_buffer
.. autofunction:: tvm.decl_buffer
.. autofunction:: tvm.reduce_axis
.. autofunction:: tvm.reduce_axis
.. autofunction:: tvm.select
.. autofunction:: tvm.thread_axis
.. autofunction:: tvm.thread_axis
.. autofunction:: tvm.comm_reducer
.. autofunction:: tvm.comm_reducer
.. autofunction:: tvm.sum
.. autofunction:: tvm.sum
...
...
include/tvm/ir_operator.h
View file @
9d20fa1b
...
@@ -392,7 +392,7 @@ TVM_DLL Expr operator^(Expr a, Expr b);
...
@@ -392,7 +392,7 @@ TVM_DLL Expr operator^(Expr a, Expr b);
*/
*/
TVM_DLL
Expr
operator
~
(
Expr
a
);
TVM_DLL
Expr
operator
~
(
Expr
a
);
/*!
/*!
* \brief
select result by condition
* \brief
Conditional expression.
*
*
* \param cond The condition
* \param cond The condition
* \param true_value The value when results are true.
* \param true_value The value when results are true.
...
@@ -401,7 +401,7 @@ TVM_DLL Expr operator~(Expr a);
...
@@ -401,7 +401,7 @@ TVM_DLL Expr operator~(Expr a);
* \note this function does eager constant folding for
* \note this function does eager constant folding for
* index types(int32, int64) when possible.
* index types(int32, int64) when possible.
*/
*/
TVM_DLL
Expr
select
(
Expr
cond
,
Expr
true_value
,
Expr
false_value
);
TVM_DLL
Expr
if_then_else
(
Expr
cond
,
Expr
true_value
,
Expr
false_value
);
/*!
/*!
* \brief Mark condition as likely.
* \brief Mark condition as likely.
* \param cond The condition
* \param cond The condition
...
...
python/tvm/api.py
View file @
9d20fa1b
...
@@ -669,28 +669,6 @@ def reduce_axis(dom, name="rv"):
...
@@ -669,28 +669,6 @@ def reduce_axis(dom, name="rv"):
return
_IterVar
(
dom
,
name
,
2
)
return
_IterVar
(
dom
,
name
,
2
)
def
select
(
cond
,
t
,
f
):
"""Construct a select branch.
Parameters
----------
cond : Expr
The condition
t : Expr
The result expression if cond is true.
f : Expr
The result expression if cond is false.
Returns
-------
node : Node
The tvm.expr.Select node
"""
return
_expr
.
Select
(
convert
(
cond
),
convert
(
t
),
convert
(
f
))
def
comm_reducer
(
fcombine
,
fidentity
,
name
=
"reduce"
):
def
comm_reducer
(
fcombine
,
fidentity
,
name
=
"reduce"
):
"""Create a commutative reducer for reduction.
"""Create a commutative reducer for reduction.
...
...
python/tvm/expr.py
View file @
9d20fa1b
...
@@ -624,6 +624,13 @@ class Not(LogicalExpr):
...
@@ -624,6 +624,13 @@ class Not(LogicalExpr):
class
Select
(
Expr
):
class
Select
(
Expr
):
"""Select node.
"""Select node.
Note
----
Select may compute both true_value and false_value.
Use :any:`tvm.if_then_else` instead if you want to
get a conditional expression that only evaluates
the correct branch.
Parameters
Parameters
----------
----------
condition : Expr
condition : Expr
...
@@ -634,6 +641,7 @@ class Select(Expr):
...
@@ -634,6 +641,7 @@ class Select(Expr):
false_value : Expr
false_value : Expr
The value to take when condition is false.
The value to take when condition is false.
"""
"""
def
__init__
(
self
,
condition
,
true_value
,
false_value
):
def
__init__
(
self
,
condition
,
true_value
,
false_value
):
self
.
__init_handle_by_constructor__
(
self
.
__init_handle_by_constructor__
(
...
...
python/tvm/intrin.py
View file @
9d20fa1b
...
@@ -393,6 +393,42 @@ def fmod(x, y):
...
@@ -393,6 +393,42 @@ def fmod(x, y):
"""
"""
return
call_pure_intrin
(
x
.
dtype
,
"fmod"
,
x
,
y
)
return
call_pure_intrin
(
x
.
dtype
,
"fmod"
,
x
,
y
)
def
if_then_else
(
cond
,
t
,
f
):
"""Conditional selection expression.
Parameters
----------
cond : Expr
The condition
t : Expr
The result expression if cond is true.
f : Expr
The result expression if cond is false.
Returns
-------
result : Node
The result of conditional expression.
Note
----
Unlike Select, if_then_else will not execute
the branch that does not satisfy the condition.
You can use it to guard against out of bound access.
Unlike Select, if_then_else cannot be vectorized
if some lanes in the vector have different conditions.
"""
t
=
convert
(
t
)
f
=
convert
(
f
)
cond
=
convert
(
cond
)
if
cond
.
dtype
!=
"bool"
:
raise
TypeError
(
"The condition's data type has to be bool"
)
return
call_pure_intrin
(
t
.
dtype
,
"tvm_if_then_else"
,
cond
,
t
,
f
)
# Intrinsic rule related code
# Intrinsic rule related code
def
register_intrin_rule
(
target
,
intrin
,
f
=
None
,
override
=
False
):
def
register_intrin_rule
(
target
,
intrin
,
f
=
None
,
override
=
False
):
"""Register an intrinsic function generation rule.
"""Register an intrinsic function generation rule.
...
...
src/arithmetic/int_set.cc
View file @
9d20fa1b
...
@@ -268,8 +268,9 @@ inline IntSet CombineInterval<Mul>(Interval a, Interval b) {
...
@@ -268,8 +268,9 @@ inline IntSet CombineInterval<Mul>(Interval a, Interval b) {
}
else
if
(
is_negative_const
(
b
.
min
))
{
}
else
if
(
is_negative_const
(
b
.
min
))
{
return
IntervalSet
::
make
(
e2
,
e1
);
return
IntervalSet
::
make
(
e2
,
e1
);
}
else
if
(
a
.
is_bounded
())
{
}
else
if
(
a
.
is_bounded
())
{
using
ir
::
Select
;
Expr
cmp
=
b
.
min
>=
make_zero
(
b
.
min
.
type
().
element_of
());
Expr
cmp
=
b
.
min
>=
make_zero
(
b
.
min
.
type
().
element_of
());
return
IntervalSet
::
make
(
select
(
cmp
,
e1
,
e2
),
select
(
cmp
,
e2
,
e1
));
return
IntervalSet
::
make
(
Select
::
make
(
cmp
,
e1
,
e2
),
Select
::
make
(
cmp
,
e2
,
e1
));
}
}
}
}
LOG
(
WARNING
)
<<
"Return Everything in CombineInterval Mul"
;
LOG
(
WARNING
)
<<
"Return Everything in CombineInterval Mul"
;
...
@@ -294,8 +295,9 @@ inline IntSet CombineInterval<Div>(Interval a, Interval b) {
...
@@ -294,8 +295,9 @@ inline IntSet CombineInterval<Div>(Interval a, Interval b) {
}
else
if
(
is_negative_const
(
b
.
min
))
{
}
else
if
(
is_negative_const
(
b
.
min
))
{
return
IntervalSet
::
make
(
e2
,
e1
);
return
IntervalSet
::
make
(
e2
,
e1
);
}
else
if
(
a
.
is_bounded
())
{
}
else
if
(
a
.
is_bounded
())
{
using
ir
::
Select
;
Expr
cmp
=
b
.
min
>=
make_zero
(
b
.
min
.
type
().
element_of
());
Expr
cmp
=
b
.
min
>=
make_zero
(
b
.
min
.
type
().
element_of
());
return
IntervalSet
::
make
(
select
(
cmp
,
e1
,
e2
),
select
(
cmp
,
e2
,
e1
));
return
IntervalSet
::
make
(
Select
::
make
(
cmp
,
e1
,
e2
),
Select
::
make
(
cmp
,
e2
,
e1
));
}
}
}
}
LOG
(
WARNING
)
<<
"Return Everything in CombineInterval Div"
;
LOG
(
WARNING
)
<<
"Return Everything in CombineInterval Div"
;
...
...
src/lang/ir_operator.cc
View file @
9d20fa1b
...
@@ -240,10 +240,11 @@ Expr max(Expr a, Expr b) {
...
@@ -240,10 +240,11 @@ Expr max(Expr a, Expr b) {
return
ir
::
Max
::
make
(
a
,
b
);
return
ir
::
Max
::
make
(
a
,
b
);
}
}
Expr
select
(
Expr
cond
,
Expr
true_value
,
Expr
false_value
)
{
Expr
if_then_else
(
Expr
cond
,
Expr
true_value
,
Expr
false_value
)
{
using
ir
::
IntImm
;
using
ir
::
IntImm
;
using
ir
::
UIntImm
;
using
ir
::
UIntImm
;
CHECK
(
cond
.
type
().
is_bool
());
CHECK
(
cond
.
type
()
==
Bool
(
1
))
<<
"if_then_else only accept a single condition"
;
BinaryOpMatchTypes
(
true_value
,
false_value
);
BinaryOpMatchTypes
(
true_value
,
false_value
);
if
(
const
UIntImm
*
op
=
cond
.
as
<
UIntImm
>
())
{
if
(
const
UIntImm
*
op
=
cond
.
as
<
UIntImm
>
())
{
if
(
op
->
value
!=
0
)
{
if
(
op
->
value
!=
0
)
{
...
@@ -258,7 +259,11 @@ Expr select(Expr cond, Expr true_value, Expr false_value) {
...
@@ -258,7 +259,11 @@ Expr select(Expr cond, Expr true_value, Expr false_value) {
return
false_value
;
return
false_value
;
}
}
}
}
return
ir
::
Select
::
make
(
cond
,
true_value
,
false_value
);
return
ir
::
Call
::
make
(
true_value
.
type
(),
ir
::
intrinsic
::
tvm_if_then_else
,
{
cond
,
true_value
,
false_value
},
ir
::
Call
::
PureIntrinsic
);
}
}
Expr
likely
(
Expr
cond
)
{
Expr
likely
(
Expr
cond
)
{
...
@@ -402,7 +407,12 @@ Expr pow(Expr x, Expr y) {
...
@@ -402,7 +407,12 @@ Expr pow(Expr x, Expr y) {
Expr
abs
(
Expr
x
)
{
Expr
abs
(
Expr
x
)
{
if
(
x
.
type
().
is_int
())
{
if
(
x
.
type
().
is_int
())
{
return
select
(
x
>=
make_zero
(
x
.
type
()),
x
,
-
x
);
using
ir
::
IntImm
;
const
IntImm
*
px
=
x
.
as
<
IntImm
>
();
if
(
px
)
{
return
ir
::
IntImm
::
make
(
x
.
type
(),
std
::
abs
(
px
->
value
));
}
return
ir
::
Select
::
make
(
x
>=
make_zero
(
x
.
type
()),
x
,
-
x
);
}
else
if
(
x
.
type
().
is_float
())
{
}
else
if
(
x
.
type
().
is_float
())
{
return
ir
::
Call
::
make
(
x
.
type
(),
"fabs"
,
{
x
},
ir
::
Call
::
PureIntrinsic
);
return
ir
::
Call
::
make
(
x
.
type
(),
"fabs"
,
{
x
},
ir
::
Call
::
PureIntrinsic
);
}
else
if
(
x
.
type
().
is_uint
())
{
}
else
if
(
x
.
type
().
is_uint
())
{
...
...
src/pass/inject_copy_intrin.cc
View file @
9d20fa1b
...
@@ -35,6 +35,26 @@ class CopyIntrinInjector : public IRMutator {
...
@@ -35,6 +35,26 @@ class CopyIntrinInjector : public IRMutator {
}
}
private
:
private
:
bool
MatchCondition
(
Expr
expr
,
Expr
*
cond
,
Expr
*
true_value
,
Expr
*
false_value
)
{
if
(
const
auto
*
op
=
expr
.
as
<
Select
>
())
{
*
cond
=
op
->
condition
;
*
true_value
=
op
->
true_value
;
*
false_value
=
op
->
false_value
;
return
true
;
}
else
if
(
const
auto
*
op
=
expr
.
as
<
Call
>
())
{
if
(
op
->
name
==
intrinsic
::
tvm_if_then_else
)
{
*
cond
=
op
->
args
[
0
];
*
true_value
=
op
->
args
[
1
];
*
false_value
=
op
->
args
[
2
];
return
true
;
}
}
return
false
;
}
bool
MatchCopyPattern
(
Stmt
stmt
,
Stmt
*
out
)
{
bool
MatchCopyPattern
(
Stmt
stmt
,
Stmt
*
out
)
{
Stmt
body
=
stmt
;
Stmt
body
=
stmt
;
bool
is_single_point_copy
=
false
;
bool
is_single_point_copy
=
false
;
...
@@ -48,16 +68,20 @@ class CopyIntrinInjector : public IRMutator {
...
@@ -48,16 +68,20 @@ class CopyIntrinInjector : public IRMutator {
}
}
const
Store
*
store
=
body
.
as
<
Store
>
();
const
Store
*
store
=
body
.
as
<
Store
>
();
if
(
store
==
nullptr
)
return
false
;
if
(
store
==
nullptr
)
return
false
;
const
Select
*
select
=
store
->
value
.
as
<
Select
>
();
Expr
sel_cond
,
sel_true_value
,
sel_false_value
;
bool
has_cond
=
MatchCondition
(
store
->
value
,
&
sel_cond
,
&
sel_true_value
,
&
sel_false_value
);
const
Cast
*
cast
=
store
->
value
.
as
<
Cast
>
();
const
Cast
*
cast
=
store
->
value
.
as
<
Cast
>
();
const
Load
*
load
=
store
->
value
.
as
<
Load
>
();
const
Load
*
load
=
store
->
value
.
as
<
Load
>
();
if
(
0
==
loops
.
size
())
{
if
(
0
==
loops
.
size
())
{
is_single_point_copy
=
true
;
is_single_point_copy
=
true
;
CHECK
(
select
==
nullptr
);
CHECK
(
!
has_cond
);
}
}
// for now only support true condition matching
// for now only support true condition matching
if
(
select
!=
nullptr
)
{
if
(
has_cond
)
{
load
=
sel
ect
->
true_value
.
as
<
Load
>
();
load
=
sel
_
true_value
.
as
<
Load
>
();
}
}
// cast can be part of the pattern
// cast can be part of the pattern
if
(
cast
!=
nullptr
)
{
if
(
cast
!=
nullptr
)
{
...
@@ -88,10 +112,10 @@ class CopyIntrinInjector : public IRMutator {
...
@@ -88,10 +112,10 @@ class CopyIntrinInjector : public IRMutator {
Array
<
Expr
>
pad_before
,
pad_after
;
Array
<
Expr
>
pad_before
,
pad_after
;
Expr
pad_value
;
Expr
pad_value
;
Expr
src_elem_offset
=
load_strides
[
loop_var_size
];
Expr
src_elem_offset
=
load_strides
[
loop_var_size
];
if
(
select
!=
nullptr
)
{
if
(
has_cond
)
{
Array
<
Expr
>
clip_bound
=
Array
<
Expr
>
clip_bound
=
arith
::
DetectClipBound
(
sel
ect
->
condition
,
loop_vars
);
arith
::
DetectClipBound
(
sel
_cond
,
loop_vars
);
pad_value
=
sel
ect
->
false_value
;
pad_value
=
sel
_
false_value
;
if
(
clip_bound
.
size
()
==
0
)
return
false
;
if
(
clip_bound
.
size
()
==
0
)
return
false
;
CHECK_EQ
(
src_shape
.
size
(),
loop_vars
.
size
());
CHECK_EQ
(
src_shape
.
size
(),
loop_vars
.
size
());
CHECK_EQ
(
clip_bound
.
size
(),
loop_vars
.
size
()
*
2
);
CHECK_EQ
(
clip_bound
.
size
(),
loop_vars
.
size
()
*
2
);
...
...
tests/python/integration/test_reduce.py
View file @
9d20fa1b
...
@@ -8,7 +8,7 @@ def test_reduce_prims():
...
@@ -8,7 +8,7 @@ def test_reduce_prims():
n
=
tvm
.
var
(
'n'
)
n
=
tvm
.
var
(
'n'
)
m
=
tvm
.
var
(
'm'
)
m
=
tvm
.
var
(
'm'
)
A
=
tvm
.
placeholder
((
n
,
m
),
name
=
'A'
)
A
=
tvm
.
placeholder
((
n
,
m
),
name
=
'A'
)
R
=
tvm
.
compute
((
n
,
),
lambda
i
:
tvm
.
s
elect
((
i
>
1
),
1
,
0
),
name
=
'R'
)
R
=
tvm
.
compute
((
n
,
),
lambda
i
:
tvm
.
expr
.
S
elect
((
i
>
1
),
1
,
0
),
name
=
'R'
)
k
=
tvm
.
reduce_axis
((
0
,
m
))
k
=
tvm
.
reduce_axis
((
0
,
m
))
B
=
tvm
.
compute
((
n
,),
lambda
i
:
reducer
(
A
[
i
,
k
],
axis
=
k
,
where
=
(
R
[
i
]
==
1
)),
name
=
'B'
)
B
=
tvm
.
compute
((
n
,),
lambda
i
:
reducer
(
A
[
i
,
k
],
axis
=
k
,
where
=
(
R
[
i
]
==
1
)),
name
=
'B'
)
# schedule
# schedule
...
...
tests/python/unittest/test_codegen_llvm.py
View file @
9d20fa1b
...
@@ -287,12 +287,12 @@ def test_multiple_func():
...
@@ -287,12 +287,12 @@ def test_multiple_func():
def
test_llvm_
select
():
def
test_llvm_
condition
():
def
check_llvm
(
n
,
offset
):
def
check_llvm
(
n
,
offset
):
if
not
tvm
.
module
.
enabled
(
"llvm"
):
if
not
tvm
.
module
.
enabled
(
"llvm"
):
return
return
A
=
tvm
.
placeholder
((
n
,
),
name
=
'A'
)
A
=
tvm
.
placeholder
((
n
,
),
name
=
'A'
)
C
=
tvm
.
compute
((
n
,),
lambda
i
:
tvm
.
select
(
i
>=
offset
,
A
[
i
],
0.0
),
name
=
'C'
)
C
=
tvm
.
compute
((
n
,),
lambda
i
:
tvm
.
if_then_else
(
i
>=
offset
,
A
[
i
],
0.0
),
name
=
'C'
)
s
=
tvm
.
create_schedule
(
C
.
op
)
s
=
tvm
.
create_schedule
(
C
.
op
)
# build and invoke the kernel.
# build and invoke the kernel.
f
=
tvm
.
build
(
s
,
[
A
,
C
],
"llvm"
)
f
=
tvm
.
build
(
s
,
[
A
,
C
],
"llvm"
)
...
@@ -462,7 +462,7 @@ if __name__ == "__main__":
...
@@ -462,7 +462,7 @@ if __name__ == "__main__":
test_rank_zero_bound_checkers
()
test_rank_zero_bound_checkers
()
test_llvm_bool
()
test_llvm_bool
()
test_llvm_persist_parallel
()
test_llvm_persist_parallel
()
test_llvm_
select
()
test_llvm_
condition
()
test_llvm_vadd_pipeline
()
test_llvm_vadd_pipeline
()
test_llvm_add_pipeline
()
test_llvm_add_pipeline
()
test_llvm_intrin
()
test_llvm_intrin
()
...
...
tests/python/unittest/test_pass_inject_copy_intrin.py
View file @
9d20fa1b
...
@@ -25,8 +25,8 @@ def test_copy_pad():
...
@@ -25,8 +25,8 @@ def test_copy_pad():
l
=
tvm
.
var
(
'l'
)
l
=
tvm
.
var
(
'l'
)
A
=
tvm
.
placeholder
((
m
,
l
),
name
=
'A'
)
A
=
tvm
.
placeholder
((
m
,
l
),
name
=
'A'
)
B
=
tvm
.
compute
((
m
+
2
,
l
),
lambda
i
,
j
:
B
=
tvm
.
compute
((
m
+
2
,
l
),
lambda
i
,
j
:
tvm
.
select
(
tvm
.
all
(
i
>=
1
,
i
<
m
+
1
),
tvm
.
if_then_else
(
tvm
.
all
(
i
>=
1
,
i
<
m
+
1
),
A
[
i
-
1
,
j
],
1.0
),
name
=
'B'
)
A
[
i
-
1
,
j
],
1.0
),
name
=
'B'
)
s
=
tvm
.
create_schedule
(
B
.
op
)
s
=
tvm
.
create_schedule
(
B
.
op
)
s
[
B
]
.
pragma
(
B
.
op
.
axis
[
0
],
"memcpy"
)
s
[
B
]
.
pragma
(
B
.
op
.
axis
[
0
],
"memcpy"
)
bounds
=
tvm
.
schedule
.
InferBound
(
s
)
bounds
=
tvm
.
schedule
.
InferBound
(
s
)
...
@@ -71,8 +71,8 @@ def test_copy_pad_split():
...
@@ -71,8 +71,8 @@ def test_copy_pad_split():
m
=
4
*
3
m
=
4
*
3
A
=
tvm
.
placeholder
((
m
,
),
name
=
"A"
)
A
=
tvm
.
placeholder
((
m
,
),
name
=
"A"
)
Apad
=
tvm
.
compute
((
m
+
2
,),
lambda
i
:
Apad
=
tvm
.
compute
((
m
+
2
,),
lambda
i
:
tvm
.
select
(
tvm
.
all
(
i
>=
1
,
i
<=
m
),
tvm
.
if_then_else
(
tvm
.
all
(
i
>=
1
,
i
<=
m
),
A
[
i
-
1
],
0.0
),
"Apad"
)
A
[
i
-
1
],
0.0
),
"Apad"
)
B
=
tvm
.
compute
((
m
,),
lambda
i
:
Apad
[
i
]
+
Apad
[
i
+
1
]
+
Apad
[
i
+
2
])
B
=
tvm
.
compute
((
m
,),
lambda
i
:
Apad
[
i
]
+
Apad
[
i
+
1
]
+
Apad
[
i
+
2
])
s
=
tvm
.
create_schedule
(
B
.
op
)
s
=
tvm
.
create_schedule
(
B
.
op
)
xo
,
xi
=
s
[
B
]
.
split
(
B
.
op
.
axis
[
0
],
factor
=
4
)
xo
,
xi
=
s
[
B
]
.
split
(
B
.
op
.
axis
[
0
],
factor
=
4
)
...
...
tests/python/unittest/test_pass_loop_partition.py
View file @
9d20fa1b
...
@@ -133,7 +133,7 @@ def test_vectorize():
...
@@ -133,7 +133,7 @@ def test_vectorize():
assert
(
x
.
var
.
name
not
in
str
(
body
.
condition
))
assert
(
x
.
var
.
name
not
in
str
(
body
.
condition
))
assert
(
any
(
collect_visit
(
body
.
then_case
,
lambda
x
:
isinstance
(
x
,
tvm
.
expr
.
Ramp
))))
assert
(
any
(
collect_visit
(
body
.
then_case
,
lambda
x
:
isinstance
(
x
,
tvm
.
expr
.
Ramp
))))
def
test_
select
():
def
test_
condition
():
ib
=
tvm
.
ir_builder
.
create
()
ib
=
tvm
.
ir_builder
.
create
()
m
=
tvm
.
var
(
'm'
)
m
=
tvm
.
var
(
'm'
)
n
=
tvm
.
var
(
'n'
)
n
=
tvm
.
var
(
'n'
)
...
@@ -335,7 +335,7 @@ if __name__ == "__main__":
...
@@ -335,7 +335,7 @@ if __name__ == "__main__":
test_multi_if
()
test_multi_if
()
test_thread_axis
()
test_thread_axis
()
test_vectorize
()
test_vectorize
()
test_
select
()
test_
condition
()
test_thread_axis2
()
test_thread_axis2
()
test_everything_during_deduction
()
test_everything_during_deduction
()
test_single_likely
()
test_single_likely
()
...
...
tests/python/unittest/test_pass_rewrite_unsafe_select.py
View file @
9d20fa1b
import
tvm
import
tvm
def
test_rewrite_
s
elect
():
def
test_rewrite_
S
elect
():
ib
=
tvm
.
ir_builder
.
create
()
ib
=
tvm
.
ir_builder
.
create
()
A
=
ib
.
allocate
(
"float32"
,
100
,
name
=
"A"
,
scope
=
"global"
)
A
=
ib
.
allocate
(
"float32"
,
100
,
name
=
"A"
,
scope
=
"global"
)
i
=
tvm
.
var
(
"i"
)
i
=
tvm
.
var
(
"i"
)
y
=
tvm
.
s
elect
(
i
>
1
,
A
[
i
-
1
],
1.0
)
y
=
tvm
.
expr
.
S
elect
(
i
>
1
,
A
[
i
-
1
],
1.0
)
yy
=
tvm
.
ir_pass
.
RewriteUnsafeSelect
(
tvm
.
make
.
Evaluate
(
y
))
.
value
yy
=
tvm
.
ir_pass
.
RewriteUnsafeSelect
(
tvm
.
make
.
Evaluate
(
y
))
.
value
z
=
tvm
.
select
(
tvm
.
select
(
i
>
1
,
A
[
i
-
1
],
1.0
)
>
0.0
,
A
[
i
],
0.1
)
z
=
tvm
.
expr
.
Select
(
tvm
.
expr
.
Select
(
i
>
1
,
A
[
i
-
1
],
1.0
)
>
0.0
,
A
[
i
],
0.1
)
zz
=
tvm
.
ir_pass
.
RewriteUnsafeSelect
(
tvm
.
make
.
Evaluate
(
z
))
.
value
zz
=
tvm
.
ir_pass
.
RewriteUnsafeSelect
(
tvm
.
make
.
Evaluate
(
z
))
.
value
a
=
tvm
.
s
elect
(
i
>
10
,
y
,
z
)
a
=
tvm
.
expr
.
S
elect
(
i
>
10
,
y
,
z
)
aa
=
tvm
.
ir_pass
.
RewriteUnsafeSelect
(
tvm
.
make
.
Evaluate
(
a
))
.
value
aa
=
tvm
.
ir_pass
.
RewriteUnsafeSelect
(
tvm
.
make
.
Evaluate
(
a
))
.
value
assert
yy
.
name
==
"tvm_if_then_else"
assert
yy
.
name
==
"tvm_if_then_else"
assert
zz
.
name
==
"tvm_if_then_else"
assert
zz
.
name
==
"tvm_if_then_else"
...
@@ -19,4 +20,4 @@ def test_rewrite_select():
...
@@ -19,4 +20,4 @@ def test_rewrite_select():
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
test_rewrite_
s
elect
()
test_rewrite_
S
elect
()
tests/python/unittest/test_schedule_schedule_ops.py
View file @
9d20fa1b
...
@@ -63,8 +63,8 @@ def test_schedule_scan():
...
@@ -63,8 +63,8 @@ def test_schedule_scan():
def
test_inline_multi_reduce
():
def
test_inline_multi_reduce
():
def
argmax_comp
(
x
,
y
):
def
argmax_comp
(
x
,
y
):
idx
=
tvm
.
s
elect
((
x
[
1
]
>=
y
[
1
]),
x
[
0
],
y
[
0
])
idx
=
tvm
.
expr
.
S
elect
((
x
[
1
]
>=
y
[
1
]),
x
[
0
],
y
[
0
])
val
=
tvm
.
s
elect
((
x
[
1
]
>=
y
[
1
]),
x
[
1
],
y
[
1
])
val
=
tvm
.
expr
.
S
elect
((
x
[
1
]
>=
y
[
1
]),
x
[
1
],
y
[
1
])
return
idx
,
val
return
idx
,
val
def
argmax_init
(
idx_typ
,
val_typ
):
def
argmax_init
(
idx_typ
,
val_typ
):
return
tvm
.
const
(
-
1
,
idx_typ
),
tvm
.
min_value
(
val_typ
)
return
tvm
.
const
(
-
1
,
idx_typ
),
tvm
.
min_value
(
val_typ
)
...
@@ -272,7 +272,7 @@ def test_schedule_cache_relayout4():
...
@@ -272,7 +272,7 @@ def test_schedule_cache_relayout4():
def
test_schedule_bound_condition
():
def
test_schedule_bound_condition
():
A
=
tvm
.
placeholder
((
64
,),
name
=
'A'
,
dtype
=
"float32"
)
A
=
tvm
.
placeholder
((
64
,),
name
=
'A'
,
dtype
=
"float32"
)
Apad
=
tvm
.
compute
((
66
,),
lambda
i
:
tvm
.
select
(
Apad
=
tvm
.
compute
((
66
,),
lambda
i
:
tvm
.
if_then_else
(
tvm
.
all
(
i
>
0
,
i
<
65
),
A
[
i
-
1
],
tvm
.
const
(
0.
,
"float32"
)),
name
=
'Apad'
)
tvm
.
all
(
i
>
0
,
i
<
65
),
A
[
i
-
1
],
tvm
.
const
(
0.
,
"float32"
)),
name
=
'Apad'
)
Apad2
=
tvm
.
compute
((
66
,),
lambda
i
:
Apad
[
i
]
*
2
,
name
=
'Apad2'
)
Apad2
=
tvm
.
compute
((
66
,),
lambda
i
:
Apad
[
i
]
*
2
,
name
=
'Apad2'
)
s
=
tvm
.
create_schedule
(
Apad2
.
op
)
s
=
tvm
.
create_schedule
(
Apad2
.
op
)
...
@@ -424,7 +424,7 @@ def test_loop_dep_reduce_cache_write():
...
@@ -424,7 +424,7 @@ def test_loop_dep_reduce_cache_write():
X
=
tvm
.
placeholder
(
shape
=
(
10
,),
name
=
"x"
)
X
=
tvm
.
placeholder
(
shape
=
(
10
,),
name
=
"x"
)
def
f
(
n
):
def
f
(
n
):
rv
=
tvm
.
reduce_axis
((
0
,
n
))
rv
=
tvm
.
reduce_axis
((
0
,
n
))
init
=
lambda
dtype
:
tvm
.
s
elect
(
n
>
1
,
tvm
.
const
(
0
,
dtype
),
n
.
astype
(
dtype
))
init
=
lambda
dtype
:
tvm
.
expr
.
S
elect
(
n
>
1
,
tvm
.
const
(
0
,
dtype
),
n
.
astype
(
dtype
))
sum
=
tvm
.
comm_reducer
(
lambda
x
,
y
:
tvm
.
max
(
x
+
y
,
n
.
astype
(
'float32'
)),
init
,
name
=
'sum'
)
sum
=
tvm
.
comm_reducer
(
lambda
x
,
y
:
tvm
.
max
(
x
+
y
,
n
.
astype
(
'float32'
)),
init
,
name
=
'sum'
)
return
sum
(
X
[
rv
],
axis
=
rv
)
return
sum
(
X
[
rv
],
axis
=
rv
)
Y
=
tvm
.
compute
(
X
.
shape
,
f
,
name
=
"y"
)
Y
=
tvm
.
compute
(
X
.
shape
,
f
,
name
=
"y"
)
...
...
topi/include/topi/image/resize.h
View file @
9d20fa1b
...
@@ -38,7 +38,7 @@ inline Expr bilinear_sample_nchw(const Tensor& input, const Array<Expr>& indices
...
@@ -38,7 +38,7 @@ inline Expr bilinear_sample_nchw(const Tensor& input, const Array<Expr>& indices
auto
yc
=
HalideIR
::
Internal
::
Cast
::
make
(
Int
(
32
),
tvm
::
ceil
(
in_y
));
auto
yc
=
HalideIR
::
Internal
::
Cast
::
make
(
Int
(
32
),
tvm
::
ceil
(
in_y
));
auto
y0
=
HalideIR
::
Internal
::
Cast
::
make
(
Int
(
32
),
tvm
::
floor
(
in_y
));
auto
y0
=
HalideIR
::
Internal
::
Cast
::
make
(
Int
(
32
),
tvm
::
floor
(
in_y
));
auto
y1
=
tvm
::
select
((
yc
>
max_y
),
max_y
,
yc
);
auto
y1
=
tvm
::
if_then_else
((
yc
>
max_y
),
max_y
,
yc
);
auto
y_lerp
=
in_y
-
yf
;
auto
y_lerp
=
in_y
-
yf
;
auto
in_x
=
indices
[
3
];
auto
in_x
=
indices
[
3
];
...
@@ -46,7 +46,7 @@ inline Expr bilinear_sample_nchw(const Tensor& input, const Array<Expr>& indices
...
@@ -46,7 +46,7 @@ inline Expr bilinear_sample_nchw(const Tensor& input, const Array<Expr>& indices
auto
xc
=
HalideIR
::
Internal
::
Cast
::
make
(
Int
(
32
),
tvm
::
ceil
(
in_x
));
auto
xc
=
HalideIR
::
Internal
::
Cast
::
make
(
Int
(
32
),
tvm
::
ceil
(
in_x
));
auto
x0
=
HalideIR
::
Internal
::
Cast
::
make
(
Int
(
32
),
tvm
::
floor
(
in_x
));
auto
x0
=
HalideIR
::
Internal
::
Cast
::
make
(
Int
(
32
),
tvm
::
floor
(
in_x
));
auto
x1
=
tvm
::
select
((
xc
>
max_x
),
max_x
,
xc
);
auto
x1
=
tvm
::
if_then_else
((
xc
>
max_x
),
max_x
,
xc
);
auto
x_lerp
=
in_x
-
xf
;
auto
x_lerp
=
in_x
-
xf
;
auto
A
=
input
(
indices
[
0
],
indices
[
1
],
y0
,
x0
);
auto
A
=
input
(
indices
[
0
],
indices
[
1
],
y0
,
x0
);
...
@@ -215,7 +215,7 @@ inline Tensor resize_bilinear_nhwc(const Tensor& input,
...
@@ -215,7 +215,7 @@ inline Tensor resize_bilinear_nhwc(const Tensor& input,
auto
yc
=
HalideIR
::
Internal
::
Cast
::
make
(
Int
(
32
),
tvm
::
ceil
(
in_y
));
auto
yc
=
HalideIR
::
Internal
::
Cast
::
make
(
Int
(
32
),
tvm
::
ceil
(
in_y
));
auto
y0
=
HalideIR
::
Internal
::
Cast
::
make
(
Int
(
32
),
tvm
::
floor
(
in_y
));
auto
y0
=
HalideIR
::
Internal
::
Cast
::
make
(
Int
(
32
),
tvm
::
floor
(
in_y
));
auto
y1
=
tvm
::
select
((
yc
>
other_y
),
other_y
,
yc
);
auto
y1
=
tvm
::
if_then_else
((
yc
>
other_y
),
other_y
,
yc
);
auto
y_lerp
=
in_y
-
yf
;
auto
y_lerp
=
in_y
-
yf
;
auto
in_x
=
indices
[
2
]
*
x_ratio
;
auto
in_x
=
indices
[
2
]
*
x_ratio
;
...
@@ -223,7 +223,7 @@ inline Tensor resize_bilinear_nhwc(const Tensor& input,
...
@@ -223,7 +223,7 @@ inline Tensor resize_bilinear_nhwc(const Tensor& input,
auto
xc
=
HalideIR
::
Internal
::
Cast
::
make
(
Int
(
32
),
tvm
::
ceil
(
in_x
));
auto
xc
=
HalideIR
::
Internal
::
Cast
::
make
(
Int
(
32
),
tvm
::
ceil
(
in_x
));
auto
x0
=
HalideIR
::
Internal
::
Cast
::
make
(
Int
(
32
),
tvm
::
floor
(
in_x
));
auto
x0
=
HalideIR
::
Internal
::
Cast
::
make
(
Int
(
32
),
tvm
::
floor
(
in_x
));
auto
x1
=
tvm
::
select
((
xc
>
other_x
),
other_x
,
xc
);
auto
x1
=
tvm
::
if_then_else
((
xc
>
other_x
),
other_x
,
xc
);
auto
x_lerp
=
in_x
-
xf
;
auto
x_lerp
=
in_x
-
xf
;
auto
A
=
input
(
indices
[
0
],
y0
,
x0
,
indices
[
3
]);
auto
A
=
input
(
indices
[
0
],
y0
,
x0
,
indices
[
3
]);
...
...
topi/include/topi/nn.h
View file @
9d20fa1b
...
@@ -75,7 +75,7 @@ inline tvm::Tensor leaky_relu(const tvm::Tensor& t,
...
@@ -75,7 +75,7 @@ inline tvm::Tensor leaky_relu(const tvm::Tensor& t,
[
&
](
const
tvm
::
Array
<
tvm
::
Var
>&
i
)
{
[
&
](
const
tvm
::
Array
<
tvm
::
Var
>&
i
)
{
auto
value
=
t
(
i
);
auto
value
=
t
(
i
);
auto
calpha
=
tvm
::
make_const
(
value
.
type
(),
alpha
);
auto
calpha
=
tvm
::
make_const
(
value
.
type
(),
alpha
);
return
tvm
::
select
(
value
>
0
,
value
,
value
*
calpha
);
return
tvm
::
ir
::
Select
::
make
(
value
>
0
,
value
,
value
*
calpha
);
},
},
name
,
name
,
tag
);
tag
);
...
@@ -106,9 +106,11 @@ inline tvm::Tensor prelu(const tvm::Tensor &x,
...
@@ -106,9 +106,11 @@ inline tvm::Tensor prelu(const tvm::Tensor &x,
return
tvm
::
compute
(
x
->
shape
,
return
tvm
::
compute
(
x
->
shape
,
[
&
](
const
tvm
::
Array
<
tvm
::
Var
>
&
indices
)
{
[
&
](
const
tvm
::
Array
<
tvm
::
Var
>
&
indices
)
{
return
tvm
::
select
(
x
(
indices
)
>
0
,
auto
xval
=
x
(
indices
);
x
(
indices
),
return
tvm
::
ir
::
Select
::
make
(
x
(
indices
)
*
slope
(
indices
[
axis
]));
xval
>
0
,
xval
,
xval
*
slope
(
indices
[
axis
]));
},
},
name
,
name
,
tag
);
tag
);
...
@@ -193,7 +195,8 @@ inline tvm::Tensor pad(const tvm::Tensor& t,
...
@@ -193,7 +195,8 @@ inline tvm::Tensor pad(const tvm::Tensor& t,
}
}
}
}
if
(
sel
.
size
()
!=
0
)
{
if
(
sel
.
size
()
!=
0
)
{
return
tvm
::
select
(
detail
::
Map
(
sel
,
tvm
::
ir
::
And
::
make
),
t
(
indices
),
pad_value
);
return
tvm
::
if_then_else
(
detail
::
Map
(
sel
,
tvm
::
ir
::
And
::
make
),
t
(
indices
),
pad_value
);
}
}
return
t
(
indices
);
return
t
(
indices
);
};
};
...
...
topi/include/topi/nn/dilate.h
View file @
9d20fa1b
...
@@ -76,7 +76,8 @@ inline Tensor dilate(const Tensor& x,
...
@@ -76,7 +76,8 @@ inline Tensor dilate(const Tensor& x,
}
}
if
(
not_zero
.
size
()
>
0
)
{
if
(
not_zero
.
size
()
>
0
)
{
auto
all_not_zero
=
all
(
not_zero
);
auto
all_not_zero
=
all
(
not_zero
);
return
tvm
::
select
(
all_not_zero
,
x
(
index_tuple
),
make_const
(
x
->
dtype
,
0
));
return
tvm
::
if_then_else
(
all_not_zero
,
x
(
index_tuple
),
make_const
(
x
->
dtype
,
0
));
}
}
return
x
(
index_tuple
);
return
x
(
index_tuple
);
},
name
,
tag
);
},
name
,
tag
);
...
...
topi/include/topi/reduction.h
View file @
9d20fa1b
...
@@ -411,8 +411,8 @@ inline Tensor argmin(const Tensor& data,
...
@@ -411,8 +411,8 @@ inline Tensor argmin(const Tensor& data,
bool
atleast1d
=
false
)
{
bool
atleast1d
=
false
)
{
auto
fcombine
=
[](
Array
<
Var
>
lhs
,
Array
<
Var
>
rhs
)
{
auto
fcombine
=
[](
Array
<
Var
>
lhs
,
Array
<
Var
>
rhs
)
{
Array
<
Expr
>
result
;
Array
<
Expr
>
result
;
result
.
push_back
(
tvm
::
select
(
lhs
[
1
]
<=
rhs
[
1
],
lhs
[
0
],
rhs
[
0
]));
// idx
result
.
push_back
(
tvm
::
ir
::
Select
::
make
(
lhs
[
1
]
<=
rhs
[
1
],
lhs
[
0
],
rhs
[
0
]));
// idx
result
.
push_back
(
tvm
::
select
(
lhs
[
1
]
<=
rhs
[
1
],
lhs
[
1
],
rhs
[
1
]));
// val
result
.
push_back
(
tvm
::
ir
::
Select
::
make
(
lhs
[
1
]
<=
rhs
[
1
],
lhs
[
1
],
rhs
[
1
]));
// val
return
result
;
return
result
;
};
};
auto
fidentity
=
[](
std
::
vector
<
Type
>
types
)
{
auto
fidentity
=
[](
std
::
vector
<
Type
>
types
)
{
...
@@ -445,8 +445,8 @@ inline Tensor argmax(const Tensor& data,
...
@@ -445,8 +445,8 @@ inline Tensor argmax(const Tensor& data,
bool
atleast1d
=
false
)
{
bool
atleast1d
=
false
)
{
auto
fcombine
=
[](
Array
<
Var
>
lhs
,
Array
<
Var
>
rhs
)
{
auto
fcombine
=
[](
Array
<
Var
>
lhs
,
Array
<
Var
>
rhs
)
{
Array
<
Expr
>
result
;
Array
<
Expr
>
result
;
result
.
push_back
(
tvm
::
select
(
lhs
[
1
]
>=
rhs
[
1
],
lhs
[
0
],
rhs
[
0
]));
// idx
result
.
push_back
(
tvm
::
ir
::
Select
::
make
(
lhs
[
1
]
>=
rhs
[
1
],
lhs
[
0
],
rhs
[
0
]));
// idx
result
.
push_back
(
tvm
::
select
(
lhs
[
1
]
>=
rhs
[
1
],
lhs
[
1
],
rhs
[
1
]));
// val
result
.
push_back
(
tvm
::
ir
::
Select
::
make
(
lhs
[
1
]
>=
rhs
[
1
],
lhs
[
1
],
rhs
[
1
]));
// val
return
result
;
return
result
;
};
};
auto
fidentity
=
[](
std
::
vector
<
Type
>
types
)
{
auto
fidentity
=
[](
std
::
vector
<
Type
>
types
)
{
...
...
topi/include/topi/transform.h
View file @
9d20fa1b
...
@@ -314,9 +314,9 @@ inline Tensor concatenate(const Array<Tensor>& inputs,
...
@@ -314,9 +314,9 @@ inline Tensor concatenate(const Array<Tensor>& inputs,
idx
.
push_back
(
indices
[
i
]);
idx
.
push_back
(
indices
[
i
]);
}
}
ret
=
tvm
::
select
(
ind
>=
0
,
ret
=
tvm
::
if_then_else
(
ind
>=
0
,
inputs
[
i
+
1
](
idx
),
inputs
[
i
+
1
](
idx
),
ret
);
ret
);
}
}
return
ret
;
return
ret
;
},
name
,
tag
);
},
name
,
tag
);
...
@@ -652,7 +652,7 @@ inline Tensor where(const Tensor& condition,
...
@@ -652,7 +652,7 @@ inline Tensor where(const Tensor& condition,
<<
condition
->
shape
.
size
()
<<
" vs "
<<
x
->
shape
.
size
();
<<
condition
->
shape
.
size
()
<<
" vs "
<<
x
->
shape
.
size
();
out
=
compute
(
out
=
compute
(
oshape
,
[
&
](
const
Array
<
Var
>&
indices
)
{
oshape
,
[
&
](
const
Array
<
Var
>&
indices
)
{
return
tvm
::
select
(
condition
(
indices
)
!=
0
,
x
(
indices
),
y
(
indices
));
return
tvm
::
ir
::
Select
::
make
(
condition
(
indices
)
!=
0
,
x
(
indices
),
y
(
indices
));
},
name
,
tag
);
},
name
,
tag
);
}
else
{
}
else
{
CHECK_EQ
(
topi
::
GetConstInt
(
condition
->
shape
[
0
]),
topi
::
GetConstInt
(
x
->
shape
[
0
]))
CHECK_EQ
(
topi
::
GetConstInt
(
condition
->
shape
[
0
]),
topi
::
GetConstInt
(
x
->
shape
[
0
]))
...
@@ -661,8 +661,8 @@ inline Tensor where(const Tensor& condition,
...
@@ -661,8 +661,8 @@ inline Tensor where(const Tensor& condition,
out
=
compute
(
out
=
compute
(
oshape
,
[
&
](
const
Array
<
Var
>&
indices
)
{
oshape
,
[
&
](
const
Array
<
Var
>&
indices
)
{
Array
<
Expr
>
condition_idx
{
indices
[
0
]};
Array
<
Expr
>
condition_idx
{
indices
[
0
]};
return
tvm
::
select
(
condition
(
condition_idx
)
!=
0
,
return
tvm
::
ir
::
Select
::
make
(
condition
(
condition_idx
)
!=
0
,
x
(
indices
),
y
(
indices
));
x
(
indices
),
y
(
indices
));
},
name
,
tag
);
},
name
,
tag
);
}
}
return
out
;
return
out
;
...
...
topi/python/topi/cuda/conv2d_transpose_nchw.py
View file @
9d20fa1b
...
@@ -72,7 +72,7 @@ def conv2d_transpose_nchw_cuda(cfg, Input, Filter, strides, padding, out_dtype):
...
@@ -72,7 +72,7 @@ def conv2d_transpose_nchw_cuda(cfg, Input, Filter, strides, padding, out_dtype):
index_tuple
.
append
(
indices
[
i
])
index_tuple
.
append
(
indices
[
i
])
if
not_zero
:
if
not_zero
:
not_zero
=
tvm
.
all
(
*
not_zero
)
not_zero
=
tvm
.
all
(
*
not_zero
)
return
tvm
.
select
(
not_zero
,
data
(
*
index_tuple
),
tvm
.
const
(
0.0
,
data
.
dtype
))
return
tvm
.
if_then_else
(
not_zero
,
data
(
*
index_tuple
),
tvm
.
const
(
0.0
,
data
.
dtype
))
return
data
(
*
index_tuple
)
return
data
(
*
index_tuple
)
# convolution stage
# convolution stage
...
...
topi/python/topi/cuda/nms.py
View file @
9d20fa1b
...
@@ -315,11 +315,11 @@ def sort_ir_out(data, index, new_index, loc, output, axis_mul_before, axis_mul_a
...
@@ -315,11 +315,11 @@ def sort_ir_out(data, index, new_index, loc, output, axis_mul_before, axis_mul_a
start
=
0
start
=
0
with
ib
.
else_scope
():
with
ib
.
else_scope
():
start
=
sizes
[
tid
-
1
]
start
=
sizes
[
tid
-
1
]
p_out
[
base_idx
+
k
*
axis_mul_after
]
=
tvm
.
select
(
p_out
[
base_idx
+
k
*
axis_mul_after
]
=
tvm
.
if_then_else
(
k
<
p_index
[
tid
],
index_new
[
k
+
start
],
k
)
k
<
p_index
[
tid
],
index_new
[
k
+
start
],
k
)
with
ib
.
else_scope
():
with
ib
.
else_scope
():
with
ib
.
if_scope
(
tid
<
data
.
shape
[
axis
]):
with
ib
.
if_scope
(
tid
<
data
.
shape
[
axis
]):
p_out
[
tid
]
=
tvm
.
select
(
tid
<
p_index
[
0
],
index_new
[
tid
],
tid
)
p_out
[
tid
]
=
tvm
.
if_then_else
(
tid
<
p_index
[
0
],
index_new
[
tid
],
tid
)
body
=
ib
.
get
()
body
=
ib
.
get
()
return
body
return
body
...
@@ -470,7 +470,7 @@ def nms_ir(data, sort_result, valid_count, out, nms_threshold, force_suppress, n
...
@@ -470,7 +470,7 @@ def nms_ir(data, sort_result, valid_count, out, nms_threshold, force_suppress, n
(
out_tensor
[
box_a_idx
+
3
]
-
out_tensor
[
box_a_idx
+
1
])
+
\
(
out_tensor
[
box_a_idx
+
3
]
-
out_tensor
[
box_a_idx
+
1
])
+
\
(
out_tensor
[
box_b_idx
+
2
]
-
out_tensor
[
box_b_idx
])
*
\
(
out_tensor
[
box_b_idx
+
2
]
-
out_tensor
[
box_b_idx
])
*
\
(
out_tensor
[
box_b_idx
+
3
]
-
out_tensor
[
box_b_idx
+
1
])
-
i
(
out_tensor
[
box_b_idx
+
3
]
-
out_tensor
[
box_b_idx
+
1
])
-
i
return
tvm
.
s
elect
(
u
<=
0.0
,
0.0
,
i
/
u
)
return
tvm
.
expr
.
S
elect
(
u
<=
0.0
,
0.0
,
i
/
u
)
max_threads
=
int
(
math
.
sqrt
(
max_threads
=
int
(
math
.
sqrt
(
tvm
.
target
.
current_target
(
allow_none
=
False
)
.
max_num_threads
))
tvm
.
target
.
current_target
(
allow_none
=
False
)
.
max_num_threads
))
...
@@ -506,7 +506,7 @@ def nms_ir(data, sort_result, valid_count, out, nms_threshold, force_suppress, n
...
@@ -506,7 +506,7 @@ def nms_ir(data, sort_result, valid_count, out, nms_threshold, force_suppress, n
tvm
.
all
(
nms_threshold_node
>
0
,
nms_threshold_node
<
1
,
tvm
.
all
(
nms_threshold_node
>
0
,
nms_threshold_node
<
1
,
p_valid_count
[
0
]
>
0
)):
p_valid_count
[
0
]
>
0
)):
# Reorder output
# Reorder output
nkeep
=
tvm
.
select
(
nkeep
=
tvm
.
if_then_else
(
tvm
.
all
(
nms_topk_node
>
0
,
nms_topk
<
p_valid_count
[
n
]),
tvm
.
all
(
nms_topk_node
>
0
,
nms_topk
<
p_valid_count
[
n
]),
nms_topk
,
p_valid_count
[
n
])
nms_topk
,
p_valid_count
[
n
])
with
ib
.
if_scope
(
i
<
nkeep
):
with
ib
.
if_scope
(
i
<
nkeep
):
...
...
topi/python/topi/cuda/ssd/multibox.py
View file @
9d20fa1b
...
@@ -77,13 +77,14 @@ def multibox_prior_ir(data, out, sizes, ratios, steps, offsets):
...
@@ -77,13 +77,14 @@ def multibox_prior_ir(data, out, sizes, ratios, steps, offsets):
center_w
=
(
j
+
offset_w
)
*
steps_w
center_w
=
(
j
+
offset_w
)
*
steps_w
for
k
in
range
(
num_sizes
+
num_ratios
-
1
):
for
k
in
range
(
num_sizes
+
num_ratios
-
1
):
w
=
tvm
.
select
(
k
<
num_sizes
,
w
=
tvm
.
if_then_else
(
k
<
num_sizes
,
size_ratio_concat
[
size_ratio_concat
[
k
]
*
in_height
/
in_width
/
2.0
,
k
]
*
in_height
/
in_width
/
2.0
,
size_ratio_concat
[
0
]
*
in_height
/
in_width
*
size_ratio_concat
[
0
]
*
in_height
/
in_width
*
math
.
sqrt
(
size_ratio_concat
[
k
+
1
])
/
2.0
)
math
.
sqrt
(
size_ratio_concat
[
k
+
1
])
/
2.0
)
h
=
tvm
.
select
(
k
<
num_sizes
,
size_ratio_concat
[
k
]
/
2.0
,
h
=
tvm
.
if_then_else
(
size_ratio_concat
[
0
]
/
math
.
sqrt
(
size_ratio_concat
[
k
+
1
])
/
2.0
)
k
<
num_sizes
,
size_ratio_concat
[
k
]
/
2.0
,
size_ratio_concat
[
0
]
/
math
.
sqrt
(
size_ratio_concat
[
k
+
1
])
/
2.0
)
count
=
(
i
*
in_width
*
(
num_sizes
+
num_ratios
-
1
)
+
count
=
(
i
*
in_width
*
(
num_sizes
+
num_ratios
-
1
)
+
j
*
(
num_sizes
+
num_ratios
-
1
)
+
k
)
*
4
j
*
(
num_sizes
+
num_ratios
-
1
)
+
k
)
*
4
p_out
[
count
]
=
center_w
-
w
p_out
[
count
]
=
center_w
-
w
...
@@ -278,10 +279,10 @@ def transform_loc_ir(loc_pred, anchor, temp_flag, temp_id, temp_score_in, \
...
@@ -278,10 +279,10 @@ def transform_loc_ir(loc_pred, anchor, temp_flag, temp_id, temp_score_in, \
oy
=
py
*
vy
*
ah
+
ay
oy
=
py
*
vy
*
ah
+
ay
ow
=
tvm
.
exp
(
pw
*
vw
)
*
aw
/
2.0
ow
=
tvm
.
exp
(
pw
*
vw
)
*
aw
/
2.0
oh
=
tvm
.
exp
(
ph
*
vh
)
*
ah
/
2.0
oh
=
tvm
.
exp
(
ph
*
vh
)
*
ah
/
2.0
return
tvm
.
select
(
clip
,
tvm
.
make
.
Max
(
0.0
,
tvm
.
make
.
Min
(
1.0
,
ox
-
ow
)),
ox
-
ow
),
\
return
tvm
.
if_then_else
(
clip
,
tvm
.
make
.
Max
(
0.0
,
tvm
.
make
.
Min
(
1.0
,
ox
-
ow
)),
ox
-
ow
),
\
tvm
.
select
(
clip
,
tvm
.
make
.
Max
(
0.0
,
tvm
.
make
.
Min
(
1.0
,
oy
-
oh
)),
oy
-
oh
),
\
tvm
.
if_then_else
(
clip
,
tvm
.
make
.
Max
(
0.0
,
tvm
.
make
.
Min
(
1.0
,
oy
-
oh
)),
oy
-
oh
),
\
tvm
.
select
(
clip
,
tvm
.
make
.
Max
(
0.0
,
tvm
.
make
.
Min
(
1.0
,
ox
+
ow
)),
ox
+
ow
),
\
tvm
.
if_then_else
(
clip
,
tvm
.
make
.
Max
(
0.0
,
tvm
.
make
.
Min
(
1.0
,
ox
+
ow
)),
ox
+
ow
),
\
tvm
.
select
(
clip
,
tvm
.
make
.
Max
(
0.0
,
tvm
.
make
.
Min
(
1.0
,
oy
+
oh
)),
oy
+
oh
)
tvm
.
if_then_else
(
clip
,
tvm
.
make
.
Max
(
0.0
,
tvm
.
make
.
Min
(
1.0
,
oy
+
oh
)),
oy
+
oh
)
max_threads
=
int
(
max_threads
=
int
(
tvm
.
target
.
current_target
(
allow_none
=
False
)
.
max_num_threads
)
tvm
.
target
.
current_target
(
allow_none
=
False
)
.
max_num_threads
)
...
...
topi/python/topi/mali/conv2d.py
View file @
9d20fa1b
...
@@ -296,9 +296,10 @@ def _decl_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dt
...
@@ -296,9 +296,10 @@ def _decl_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dt
# pack input tile
# pack input tile
input_tile
=
tvm
.
compute
((
CI
,
P_round
//
bnb
,
alpha
,
alpha
,
bnb
),
lambda
ci
,
b
,
eps
,
nu
,
bb
:
\
input_tile
=
tvm
.
compute
((
CI
,
P_round
//
bnb
,
alpha
,
alpha
,
bnb
),
lambda
ci
,
b
,
eps
,
nu
,
bb
:
\
tvm
.
select
(
b
*
bnb
+
bb
<
P
,
tvm
.
if_then_else
(
data_pad
[(
b
*
bnb
+
bb
)
//
(
nH
*
nW
)][
ci
][(
b
*
bnb
+
bb
)
//
nW
%
nH
*
m
+
eps
]
b
*
bnb
+
bb
<
P
,
[(
b
*
bnb
+
bb
)
%
nW
*
m
+
nu
],
tvm
.
const
(
0
,
data_pad
.
dtype
)),
name
=
'd'
)
data_pad
[(
b
*
bnb
+
bb
)
//
(
nH
*
nW
)][
ci
][(
b
*
bnb
+
bb
)
//
nW
%
nH
*
m
+
eps
]
[(
b
*
bnb
+
bb
)
%
nW
*
m
+
nu
],
tvm
.
const
(
0
,
data_pad
.
dtype
)),
name
=
'd'
)
# transform kernel
# transform kernel
if
pre_computed
:
if
pre_computed
:
...
...
topi/python/topi/nn/dilate.py
View file @
9d20fa1b
...
@@ -44,7 +44,7 @@ def dilate(data, strides, name="DilatedInput"):
...
@@ -44,7 +44,7 @@ def dilate(data, strides, name="DilatedInput"):
index_tuple
.
append
(
indices
[
i
])
index_tuple
.
append
(
indices
[
i
])
if
not_zero
:
if
not_zero
:
not_zero
=
tvm
.
all
(
*
not_zero
)
not_zero
=
tvm
.
all
(
*
not_zero
)
return
tvm
.
select
(
not_zero
,
data
(
*
index_tuple
),
tvm
.
const
(
0.0
,
data
.
dtype
))
return
tvm
.
if_then_else
(
not_zero
,
data
(
*
index_tuple
),
tvm
.
const
(
0.0
,
data
.
dtype
))
return
data
(
*
index_tuple
)
return
data
(
*
index_tuple
)
return
tvm
.
compute
(
out_shape
,
_dilate
,
name
=
name
)
return
tvm
.
compute
(
out_shape
,
_dilate
,
name
=
name
)
topi/python/topi/nn/elemwise.py
View file @
9d20fa1b
...
@@ -41,7 +41,7 @@ def leaky_relu(x, alpha):
...
@@ -41,7 +41,7 @@ def leaky_relu(x, alpha):
def
_compute
(
*
indices
):
def
_compute
(
*
indices
):
value
=
x
(
*
indices
)
value
=
x
(
*
indices
)
calpha
=
tvm
.
const
(
alpha
,
value
.
dtype
)
calpha
=
tvm
.
const
(
alpha
,
value
.
dtype
)
return
tvm
.
s
elect
(
value
>
0
,
value
,
value
*
calpha
)
return
tvm
.
expr
.
S
elect
(
value
>
0
,
value
,
value
*
calpha
)
return
tvm
.
compute
(
x
.
shape
,
_compute
)
return
tvm
.
compute
(
x
.
shape
,
_compute
)
@tvm.tag_scope
(
tag
=
tag
.
BROADCAST
)
@tvm.tag_scope
(
tag
=
tag
.
BROADCAST
)
...
@@ -74,5 +74,6 @@ def prelu(x, slope, axis=1):
...
@@ -74,5 +74,6 @@ def prelu(x, slope, axis=1):
assert
get_const_int
(
slope
.
shape
[
0
])
==
get_const_int
(
x
.
shape
[
axis
])
assert
get_const_int
(
slope
.
shape
[
0
])
==
get_const_int
(
x
.
shape
[
axis
])
def
_compute_channelwise
(
*
indices
):
def
_compute_channelwise
(
*
indices
):
return
tvm
.
select
(
x
(
*
indices
)
>
0
,
x
(
*
indices
),
x
(
*
indices
)
*
slope
(
indices
[
axis
]))
xval
=
x
(
*
indices
)
return
tvm
.
expr
.
Select
(
xval
>
0
,
xval
,
xval
*
slope
(
indices
[
axis
]))
return
tvm
.
compute
(
x
.
shape
,
_compute_channelwise
)
return
tvm
.
compute
(
x
.
shape
,
_compute_channelwise
)
topi/python/topi/nn/pad.py
View file @
9d20fa1b
...
@@ -55,6 +55,6 @@ def pad(data, pad_before, pad_after=None, pad_value=0.0, name="PadInput"):
...
@@ -55,6 +55,6 @@ def pad(data, pad_before, pad_after=None, pad_value=0.0, name="PadInput"):
not_zero
.
append
(
indices
[
i
]
<
data
.
shape
[
i
]
+
pad_before
[
i
])
not_zero
.
append
(
indices
[
i
]
<
data
.
shape
[
i
]
+
pad_before
[
i
])
if
not_zero
:
if
not_zero
:
not_zero
=
tvm
.
all
(
*
not_zero
)
not_zero
=
tvm
.
all
(
*
not_zero
)
return
tvm
.
select
(
not_zero
,
data
(
*
index_tuple
),
pad_value
)
return
tvm
.
if_then_else
(
not_zero
,
data
(
*
index_tuple
),
pad_value
)
return
data
(
*
index_tuple
)
return
data
(
*
index_tuple
)
return
tvm
.
compute
(
out_shape
,
_pad
,
name
=
name
)
return
tvm
.
compute
(
out_shape
,
_pad
,
name
=
name
)
topi/python/topi/nn/util.py
View file @
9d20fa1b
...
@@ -55,8 +55,8 @@ def infer_stride(data, kernel, out):
...
@@ -55,8 +55,8 @@ def infer_stride(data, kernel, out):
_
,
_
,
IH
,
IW
=
data
.
shape
_
,
_
,
IH
,
IW
=
data
.
shape
_
,
_
,
KH
,
KW
=
kernel
.
shape
_
,
_
,
KH
,
KW
=
kernel
.
shape
_
,
_
,
OH
,
OW
=
out
.
shape
_
,
_
,
OH
,
OW
=
out
.
shape
hstride
=
(
IH
-
KH
)
//
tvm
.
make
.
Max
(
OH
-
1
,
1
)
+
tvm
.
s
elect
(
OH
==
1
,
1
,
0
)
hstride
=
(
IH
-
KH
)
//
tvm
.
make
.
Max
(
OH
-
1
,
1
)
+
tvm
.
expr
.
S
elect
(
OH
==
1
,
1
,
0
)
wstride
=
(
IW
-
KW
)
//
tvm
.
make
.
Max
(
OW
-
1
,
1
)
+
tvm
.
s
elect
(
OW
==
1
,
1
,
0
)
wstride
=
(
IW
-
KW
)
//
tvm
.
make
.
Max
(
OW
-
1
,
1
)
+
tvm
.
expr
.
S
elect
(
OW
==
1
,
1
,
0
)
return
get_const_int
(
hstride
),
get_const_int
(
wstride
)
return
get_const_int
(
hstride
),
get_const_int
(
wstride
)
...
...
topi/python/topi/util.py
View file @
9d20fa1b
...
@@ -249,9 +249,9 @@ def const_matrix(matrix, name="const_matrix"):
...
@@ -249,9 +249,9 @@ def const_matrix(matrix, name="const_matrix"):
now
=
tvm
.
const
(
0.0
,
dtype
)
now
=
tvm
.
const
(
0.0
,
dtype
)
for
ii
in
range
(
row
):
for
ii
in
range
(
row
):
for
jj
in
range
(
col
):
for
jj
in
range
(
col
):
now
=
tvm
.
s
elect
(
tvm
.
all
(
i
%
row
==
ii
,
j
%
col
==
jj
),
now
=
tvm
.
expr
.
S
elect
(
tvm
.
all
(
i
%
row
==
ii
,
j
%
col
==
jj
),
tvm
.
const
(
matrix
[
ii
][
jj
],
dtype
),
tvm
.
const
(
matrix
[
ii
][
jj
],
dtype
),
now
)
now
)
return
now
return
now
return
tvm
.
compute
(
matrix
.
shape
,
select_array
,
name
=
name
)
return
tvm
.
compute
(
matrix
.
shape
,
select_array
,
name
=
name
)
topi/python/topi/vision/nms.py
View file @
9d20fa1b
...
@@ -47,7 +47,7 @@ def nms_ir(data, sort_result, valid_count, out, nms_threshold, force_suppress, n
...
@@ -47,7 +47,7 @@ def nms_ir(data, sort_result, valid_count, out, nms_threshold, force_suppress, n
(
out_tensor
[
box_a_idx
+
3
]
-
out_tensor
[
box_a_idx
+
1
])
+
\
(
out_tensor
[
box_a_idx
+
3
]
-
out_tensor
[
box_a_idx
+
1
])
+
\
(
out_tensor
[
box_b_idx
+
2
]
-
out_tensor
[
box_b_idx
])
*
\
(
out_tensor
[
box_b_idx
+
2
]
-
out_tensor
[
box_b_idx
])
*
\
(
out_tensor
[
box_b_idx
+
3
]
-
out_tensor
[
box_b_idx
+
1
])
-
i
(
out_tensor
[
box_b_idx
+
3
]
-
out_tensor
[
box_b_idx
+
1
])
-
i
return
tvm
.
s
elect
(
u
<=
0.0
,
0.0
,
i
/
u
)
return
tvm
.
expr
.
S
elect
(
u
<=
0.0
,
0.0
,
i
/
u
)
ib
=
tvm
.
ir_builder
.
create
()
ib
=
tvm
.
ir_builder
.
create
()
p_data
=
ib
.
buffer_ptr
(
data
)
p_data
=
ib
.
buffer_ptr
(
data
)
...
@@ -64,8 +64,9 @@ def nms_ir(data, sort_result, valid_count, out, nms_threshold, force_suppress, n
...
@@ -64,8 +64,9 @@ def nms_ir(data, sort_result, valid_count, out, nms_threshold, force_suppress, n
with
ib
.
if_scope
(
tvm
.
all
(
nms_threshold_node
>
0
,
nms_threshold_node
<
1
,
with
ib
.
if_scope
(
tvm
.
all
(
nms_threshold_node
>
0
,
nms_threshold_node
<
1
,
p_valid_count
[
0
]
>
0
)):
p_valid_count
[
0
]
>
0
)):
# Reorder output
# Reorder output
nkeep
=
tvm
.
select
(
tvm
.
all
(
nms_topk_node
>
0
,
nms_topk
<
p_valid_count
[
n
]),
nkeep
=
tvm
.
if_then_else
(
nms_topk
,
p_valid_count
[
n
])
tvm
.
all
(
nms_topk_node
>
0
,
nms_topk
<
p_valid_count
[
n
]),
nms_topk
,
p_valid_count
[
n
])
with
ib
.
for_range
(
0
,
nkeep
,
name
=
"l"
)
as
l
:
with
ib
.
for_range
(
0
,
nkeep
,
name
=
"l"
)
as
l
:
with
ib
.
for_range
(
0
,
6
,
name
=
"m"
)
as
m
:
with
ib
.
for_range
(
0
,
6
,
name
=
"m"
)
as
m
:
p_out
[(
n
*
num_anchors
*
6
p_out
[(
n
*
num_anchors
*
6
...
...
topi/python/topi/vision/rcnn/roi_align.py
View file @
9d20fa1b
...
@@ -47,7 +47,7 @@ def roi_align_nchw(data, rois, pooled_size, spatial_scale, sample_ratio=-1):
...
@@ -47,7 +47,7 @@ def roi_align_nchw(data, rois, pooled_size, spatial_scale, sample_ratio=-1):
y
=
tvm
.
max
(
y
,
0.0
)
y
=
tvm
.
max
(
y
,
0.0
)
x
=
tvm
.
max
(
x
,
0.0
)
x
=
tvm
.
max
(
x
,
0.0
)
val
=
bilinear_sample_nchw
(
data
,
(
i
,
c
,
y
,
x
),
height
-
1
,
width
-
1
)
val
=
bilinear_sample_nchw
(
data
,
(
i
,
c
,
y
,
x
),
height
-
1
,
width
-
1
)
return
tvm
.
select
(
outside
,
0.0
,
val
)
return
tvm
.
if_then_else
(
outside
,
0.0
,
val
)
def
_sample
(
i
,
c
,
ph
,
pw
):
def
_sample
(
i
,
c
,
ph
,
pw
):
roi
=
rois
[
i
]
roi
=
rois
[
i
]
...
...
topi/python/topi/vision/ssd/multibox.py
View file @
9d20fa1b
...
@@ -55,12 +55,13 @@ def multibox_prior_ir(data, out, sizes, ratios, steps, offsets):
...
@@ -55,12 +55,13 @@ def multibox_prior_ir(data, out, sizes, ratios, steps, offsets):
with
ib
.
for_range
(
0
,
in_width
,
name
=
"j"
)
as
j
:
with
ib
.
for_range
(
0
,
in_width
,
name
=
"j"
)
as
j
:
center_w
=
(
j
+
offset_w
)
*
steps_w
center_w
=
(
j
+
offset_w
)
*
steps_w
for
k
in
range
(
num_sizes
+
num_ratios
-
1
):
for
k
in
range
(
num_sizes
+
num_ratios
-
1
):
w
=
tvm
.
select
(
k
<
num_sizes
,
w
=
tvm
.
if_then_else
(
k
<
num_sizes
,
size_ratio_concat
[
k
]
*
in_height
/
in_width
/
2.0
,
size_ratio_concat
[
k
]
*
in_height
/
in_width
/
2.0
,
size_ratio_concat
[
0
]
*
in_height
/
in_width
*
size_ratio_concat
[
0
]
*
in_height
/
in_width
*
math
.
sqrt
(
size_ratio_concat
[
k
+
1
])
/
2.0
)
math
.
sqrt
(
size_ratio_concat
[
k
+
1
])
/
2.0
)
h
=
tvm
.
select
(
k
<
num_sizes
,
size_ratio_concat
[
k
]
/
2.0
,
h
=
tvm
.
if_then_else
(
size_ratio_concat
[
0
]
/
math
.
sqrt
(
size_ratio_concat
[
k
+
1
])
/
2.0
)
k
<
num_sizes
,
size_ratio_concat
[
k
]
/
2.0
,
size_ratio_concat
[
0
]
/
math
.
sqrt
(
size_ratio_concat
[
k
+
1
])
/
2.0
)
count
=
(
i
*
in_width
*
(
num_sizes
+
num_ratios
-
1
)
+
count
=
(
i
*
in_width
*
(
num_sizes
+
num_ratios
-
1
)
+
j
*
(
num_sizes
+
num_ratios
-
1
)
+
k
)
*
4
j
*
(
num_sizes
+
num_ratios
-
1
)
+
k
)
*
4
p_out
[
count
]
=
center_w
-
w
p_out
[
count
]
=
center_w
-
w
...
@@ -164,10 +165,10 @@ def transform_loc_ir(cls_prob, loc_pred, anchor, valid_count, out, clip, thresho
...
@@ -164,10 +165,10 @@ def transform_loc_ir(cls_prob, loc_pred, anchor, valid_count, out, clip, thresho
oy
=
py
*
vy
*
ah
+
ay
oy
=
py
*
vy
*
ah
+
ay
ow
=
tvm
.
exp
(
pw
*
vw
)
*
aw
/
2.0
ow
=
tvm
.
exp
(
pw
*
vw
)
*
aw
/
2.0
oh
=
tvm
.
exp
(
ph
*
vh
)
*
ah
/
2.0
oh
=
tvm
.
exp
(
ph
*
vh
)
*
ah
/
2.0
return
tvm
.
select
(
clip
,
tvm
.
max
(
0
,
tvm
.
min
(
1
,
ox
-
ow
)),
ox
-
ow
),
\
return
tvm
.
if_then_else
(
clip
,
tvm
.
max
(
0
,
tvm
.
min
(
1
,
ox
-
ow
)),
ox
-
ow
),
\
tvm
.
select
(
clip
,
tvm
.
max
(
0
,
tvm
.
min
(
1
,
oy
-
oh
)),
oy
-
oh
),
\
tvm
.
if_then_else
(
clip
,
tvm
.
max
(
0
,
tvm
.
min
(
1
,
oy
-
oh
)),
oy
-
oh
),
\
tvm
.
select
(
clip
,
tvm
.
max
(
0
,
tvm
.
min
(
1
,
ox
+
ow
)),
ox
+
ow
),
\
tvm
.
if_then_else
(
clip
,
tvm
.
max
(
0
,
tvm
.
min
(
1
,
ox
+
ow
)),
ox
+
ow
),
\
tvm
.
select
(
clip
,
tvm
.
max
(
0
,
tvm
.
min
(
1
,
oy
+
oh
)),
oy
+
oh
)
tvm
.
if_then_else
(
clip
,
tvm
.
max
(
0
,
tvm
.
min
(
1
,
oy
+
oh
)),
oy
+
oh
)
batch_size
=
cls_prob
.
shape
[
0
]
batch_size
=
cls_prob
.
shape
[
0
]
num_classes
=
cls_prob
.
shape
[
1
]
num_classes
=
cls_prob
.
shape
[
1
]
...
@@ -190,7 +191,7 @@ def transform_loc_ir(cls_prob, loc_pred, anchor, valid_count, out, clip, thresho
...
@@ -190,7 +191,7 @@ def transform_loc_ir(cls_prob, loc_pred, anchor, valid_count, out, clip, thresho
with
ib
.
for_range
(
0
,
num_classes
,
name
=
"j"
)
as
j
:
with
ib
.
for_range
(
0
,
num_classes
,
name
=
"j"
)
as
j
:
with
ib
.
if_scope
(
j
>
0
):
with
ib
.
if_scope
(
j
>
0
):
temp
=
p_cls_prob
[
n
*
num_anchors
*
num_classes
+
j
*
num_anchors
+
i
]
temp
=
p_cls_prob
[
n
*
num_anchors
*
num_classes
+
j
*
num_anchors
+
i
]
cls_id
[
0
]
=
tvm
.
select
(
temp
>
score
[
0
],
j
,
cls_id
[
0
])
cls_id
[
0
]
=
tvm
.
if_then_else
(
temp
>
score
[
0
],
j
,
cls_id
[
0
])
score
[
0
]
=
tvm
.
max
(
temp
,
score
[
0
])
score
[
0
]
=
tvm
.
max
(
temp
,
score
[
0
])
with
ib
.
if_scope
(
tvm
.
all
(
cls_id
[
0
]
>
0
,
score
[
0
]
<
threshold
)):
with
ib
.
if_scope
(
tvm
.
all
(
cls_id
[
0
]
>
0
,
score
[
0
]
<
threshold
)):
cls_id
[
0
]
=
0
cls_id
[
0
]
=
0
...
...
topi/tests/python/test_topi_conv2d_nchw.py
View file @
9d20fa1b
...
@@ -65,7 +65,7 @@ def verify_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, p
...
@@ -65,7 +65,7 @@ def verify_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, p
else
:
else
:
func
=
tvm
.
build
(
s
,
[
A
,
W
,
C
],
device
,
name
=
"relu_
%
d_
%
d_
%
d_
%
d_
%
d_
%
d_
%
d_
%
d"
%
(
batch
,
in_channel
,
in_size
,
num_filter
,
kernel
,
stride
,
padding
,
dilation
))
func
=
tvm
.
build
(
s
,
[
A
,
W
,
C
],
device
,
name
=
"relu_
%
d_
%
d_
%
d_
%
d_
%
d_
%
d_
%
d_
%
d"
%
(
batch
,
in_channel
,
in_size
,
num_filter
,
kernel
,
stride
,
padding
,
dilation
))
func
(
a
,
w
,
c
)
func
(
a
,
w
,
c
)
tvm
.
testing
.
assert_allclose
(
c
.
asnumpy
(),
c_np
,
rtol
=
1e-
5
)
tvm
.
testing
.
assert_allclose
(
c
.
asnumpy
(),
c_np
,
rtol
=
1e-
4
)
for
device
in
get_all_backend
():
for
device
in
get_all_backend
():
with
autotvm
.
tophub
.
context
(
device
):
# load tophub pre-tuned parameters
with
autotvm
.
tophub
.
context
(
device
):
# load tophub pre-tuned parameters
...
...
tutorials/language/tuple_inputs.py
View file @
9d20fa1b
...
@@ -45,8 +45,8 @@ print(tvm.lower(s, [A0, A1, B0, B1], simple_mode=True))
...
@@ -45,8 +45,8 @@ print(tvm.lower(s, [A0, A1, B0, B1], simple_mode=True))
# x and y are the operands of reduction, both of them is a tuple of index
# x and y are the operands of reduction, both of them is a tuple of index
# and value.
# and value.
def
fcombine
(
x
,
y
):
def
fcombine
(
x
,
y
):
lhs
=
tvm
.
s
elect
((
x
[
1
]
>=
y
[
1
]),
x
[
0
],
y
[
0
])
lhs
=
tvm
.
expr
.
S
elect
((
x
[
1
]
>=
y
[
1
]),
x
[
0
],
y
[
0
])
rhs
=
tvm
.
s
elect
((
x
[
1
]
>=
y
[
1
]),
x
[
1
],
y
[
1
])
rhs
=
tvm
.
expr
.
S
elect
((
x
[
1
]
>=
y
[
1
]),
x
[
1
],
y
[
1
])
return
lhs
,
rhs
return
lhs
,
rhs
# our identity element also need to be a tuple, so `fidentity` accepts
# our identity element also need to be a tuple, so `fidentity` accepts
...
...
tutorials/optimize/opt_conv_cuda.py
View file @
9d20fa1b
...
@@ -43,7 +43,7 @@ out_size = (in_size - kernel + 2*pad) // stride + 1
...
@@ -43,7 +43,7 @@ out_size = (in_size - kernel + 2*pad) // stride + 1
# Pad input
# Pad input
Apad
=
tvm
.
compute
(
Apad
=
tvm
.
compute
(
(
in_size
+
2
*
pad
,
in_size
+
2
*
pad
,
in_channel
,
batch
),
(
in_size
+
2
*
pad
,
in_size
+
2
*
pad
,
in_channel
,
batch
),
lambda
yy
,
xx
,
cc
,
nn
:
tvm
.
select
(
lambda
yy
,
xx
,
cc
,
nn
:
tvm
.
if_then_else
(
tvm
.
all
(
yy
>=
pad
,
yy
-
pad
<
in_size
,
tvm
.
all
(
yy
>=
pad
,
yy
-
pad
<
in_size
,
xx
>=
pad
,
xx
-
pad
<
in_size
),
xx
>=
pad
,
xx
-
pad
<
in_size
),
A
[
yy
-
pad
,
xx
-
pad
,
cc
,
nn
],
tvm
.
const
(
0.
,
"float32"
)),
A
[
yy
-
pad
,
xx
-
pad
,
cc
,
nn
],
tvm
.
const
(
0.
,
"float32"
)),
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment