Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
9d20fa1b
Unverified
Commit
9d20fa1b
authored
Jan 11, 2019
by
Tianqi Chen
Committed by
GitHub
Jan 11, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[PASS][TENSOR] Use correct select semantics (#2394)
parent
98e761f8
Show whitespace changes
Inline
Side-by-side
Showing
35 changed files
with
174 additions
and
106 deletions
+174
-106
docs/api/python/intrin.rst
+2
-0
docs/api/python/tvm.rst
+0
-2
include/tvm/ir_operator.h
+2
-2
python/tvm/api.py
+0
-22
python/tvm/expr.py
+8
-0
python/tvm/intrin.py
+36
-0
src/arithmetic/int_set.cc
+4
-2
src/lang/ir_operator.cc
+14
-4
src/pass/inject_copy_intrin.cc
+31
-7
tests/python/integration/test_reduce.py
+1
-1
tests/python/unittest/test_codegen_llvm.py
+3
-3
tests/python/unittest/test_pass_inject_copy_intrin.py
+2
-2
tests/python/unittest/test_pass_loop_partition.py
+2
-2
tests/python/unittest/test_pass_rewrite_unsafe_select.py
+6
-5
tests/python/unittest/test_schedule_schedule_ops.py
+4
-4
topi/include/topi/image/resize.h
+4
-4
topi/include/topi/nn.h
+8
-5
topi/include/topi/nn/dilate.h
+2
-1
topi/include/topi/reduction.h
+4
-4
topi/include/topi/transform.h
+3
-3
topi/python/topi/cuda/conv2d_transpose_nchw.py
+1
-1
topi/python/topi/cuda/nms.py
+4
-4
topi/python/topi/cuda/ssd/multibox.py
+7
-6
topi/python/topi/mali/conv2d.py
+2
-1
topi/python/topi/nn/dilate.py
+1
-1
topi/python/topi/nn/elemwise.py
+3
-2
topi/python/topi/nn/pad.py
+1
-1
topi/python/topi/nn/util.py
+2
-2
topi/python/topi/util.py
+1
-1
topi/python/topi/vision/nms.py
+3
-2
topi/python/topi/vision/rcnn/roi_align.py
+1
-1
topi/python/topi/vision/ssd/multibox.py
+8
-7
topi/tests/python/test_topi_conv2d_nchw.py
+1
-1
tutorials/language/tuple_inputs.py
+2
-2
tutorials/optimize/opt_conv_cuda.py
+1
-1
No files found.
docs/api/python/intrin.rst
View file @
9d20fa1b
...
...
@@ -11,6 +11,7 @@ tvm.intrin
tvm.call_extern
tvm.call_llvm_intrin
tvm.register_intrin_rule
tvm.if_then_else
tvm.exp
tvm.log
tvm.floor
...
...
@@ -26,6 +27,7 @@ tvm.intrin
.. autofunction:: tvm.call_extern
.. autofunction:: tvm.call_llvm_intrin
.. autofunction:: tvm.register_intrin_rule
.. autofunction:: tvm.if_then_else
.. autofunction:: tvm.exp
.. autofunction:: tvm.log
.. autofunction:: tvm.floor
...
...
docs/api/python/tvm.rst
View file @
9d20fa1b
...
...
@@ -15,7 +15,6 @@ The user facing API for computation declaration.
tvm.extern
tvm.decl_buffer
tvm.reduce_axis
tvm.select
tvm.thread_axis
tvm.comm_reducer
tvm.sum
...
...
@@ -34,7 +33,6 @@ The user facing API for computation declaration.
.. autofunction:: tvm.extern
.. autofunction:: tvm.decl_buffer
.. autofunction:: tvm.reduce_axis
.. autofunction:: tvm.select
.. autofunction:: tvm.thread_axis
.. autofunction:: tvm.comm_reducer
.. autofunction:: tvm.sum
...
...
include/tvm/ir_operator.h
View file @
9d20fa1b
...
...
@@ -392,7 +392,7 @@ TVM_DLL Expr operator^(Expr a, Expr b);
*/
TVM_DLL
Expr
operator
~
(
Expr
a
);
/*!
* \brief
select result by condition
* \brief
Conditional expression.
*
* \param cond The condition
* \param true_value The value when results are true.
...
...
@@ -401,7 +401,7 @@ TVM_DLL Expr operator~(Expr a);
* \note this function does eager constant folding for
* index types(int32, int64) when possible.
*/
TVM_DLL
Expr
select
(
Expr
cond
,
Expr
true_value
,
Expr
false_value
);
TVM_DLL
Expr
if_then_else
(
Expr
cond
,
Expr
true_value
,
Expr
false_value
);
/*!
* \brief Mark condition as likely.
* \param cond The condition
...
...
python/tvm/api.py
View file @
9d20fa1b
...
...
@@ -669,28 +669,6 @@ def reduce_axis(dom, name="rv"):
return
_IterVar
(
dom
,
name
,
2
)
def
select
(
cond
,
t
,
f
):
"""Construct a select branch.
Parameters
----------
cond : Expr
The condition
t : Expr
The result expression if cond is true.
f : Expr
The result expression if cond is false.
Returns
-------
node : Node
The tvm.expr.Select node
"""
return
_expr
.
Select
(
convert
(
cond
),
convert
(
t
),
convert
(
f
))
def
comm_reducer
(
fcombine
,
fidentity
,
name
=
"reduce"
):
"""Create a commutative reducer for reduction.
...
...
python/tvm/expr.py
View file @
9d20fa1b
...
...
@@ -624,6 +624,13 @@ class Not(LogicalExpr):
class
Select
(
Expr
):
"""Select node.
Note
----
Select may compute both true_value and false_value.
Use :any:`tvm.if_then_else` instead if you want to
get a conditional expression that only evaluates
the correct branch.
Parameters
----------
condition : Expr
...
...
@@ -634,6 +641,7 @@ class Select(Expr):
false_value : Expr
The value to take when condition is false.
"""
def
__init__
(
self
,
condition
,
true_value
,
false_value
):
self
.
__init_handle_by_constructor__
(
...
...
python/tvm/intrin.py
View file @
9d20fa1b
...
...
@@ -393,6 +393,42 @@ def fmod(x, y):
"""
return
call_pure_intrin
(
x
.
dtype
,
"fmod"
,
x
,
y
)
def
if_then_else
(
cond
,
t
,
f
):
"""Conditional selection expression.
Parameters
----------
cond : Expr
The condition
t : Expr
The result expression if cond is true.
f : Expr
The result expression if cond is false.
Returns
-------
result : Node
The result of conditional expression.
Note
----
Unlike Select, if_then_else will not execute
the branch that does not satisfy the condition.
You can use it to guard against out of bound access.
Unlike Select, if_then_else cannot be vectorized
if some lanes in the vector have different conditions.
"""
t
=
convert
(
t
)
f
=
convert
(
f
)
cond
=
convert
(
cond
)
if
cond
.
dtype
!=
"bool"
:
raise
TypeError
(
"The condition's data type has to be bool"
)
return
call_pure_intrin
(
t
.
dtype
,
"tvm_if_then_else"
,
cond
,
t
,
f
)
# Intrinsic rule related code
def
register_intrin_rule
(
target
,
intrin
,
f
=
None
,
override
=
False
):
"""Register an intrinsic function generation rule.
...
...
src/arithmetic/int_set.cc
View file @
9d20fa1b
...
...
@@ -268,8 +268,9 @@ inline IntSet CombineInterval<Mul>(Interval a, Interval b) {
}
else
if
(
is_negative_const
(
b
.
min
))
{
return
IntervalSet
::
make
(
e2
,
e1
);
}
else
if
(
a
.
is_bounded
())
{
using
ir
::
Select
;
Expr
cmp
=
b
.
min
>=
make_zero
(
b
.
min
.
type
().
element_of
());
return
IntervalSet
::
make
(
select
(
cmp
,
e1
,
e2
),
select
(
cmp
,
e2
,
e1
));
return
IntervalSet
::
make
(
Select
::
make
(
cmp
,
e1
,
e2
),
Select
::
make
(
cmp
,
e2
,
e1
));
}
}
LOG
(
WARNING
)
<<
"Return Everything in CombineInterval Mul"
;
...
...
@@ -294,8 +295,9 @@ inline IntSet CombineInterval<Div>(Interval a, Interval b) {
}
else
if
(
is_negative_const
(
b
.
min
))
{
return
IntervalSet
::
make
(
e2
,
e1
);
}
else
if
(
a
.
is_bounded
())
{
using
ir
::
Select
;
Expr
cmp
=
b
.
min
>=
make_zero
(
b
.
min
.
type
().
element_of
());
return
IntervalSet
::
make
(
select
(
cmp
,
e1
,
e2
),
select
(
cmp
,
e2
,
e1
));
return
IntervalSet
::
make
(
Select
::
make
(
cmp
,
e1
,
e2
),
Select
::
make
(
cmp
,
e2
,
e1
));
}
}
LOG
(
WARNING
)
<<
"Return Everything in CombineInterval Div"
;
...
...
src/lang/ir_operator.cc
View file @
9d20fa1b
...
...
@@ -240,10 +240,11 @@ Expr max(Expr a, Expr b) {
return
ir
::
Max
::
make
(
a
,
b
);
}
Expr
select
(
Expr
cond
,
Expr
true_value
,
Expr
false_value
)
{
Expr
if_then_else
(
Expr
cond
,
Expr
true_value
,
Expr
false_value
)
{
using
ir
::
IntImm
;
using
ir
::
UIntImm
;
CHECK
(
cond
.
type
().
is_bool
());
CHECK
(
cond
.
type
()
==
Bool
(
1
))
<<
"if_then_else only accept a single condition"
;
BinaryOpMatchTypes
(
true_value
,
false_value
);
if
(
const
UIntImm
*
op
=
cond
.
as
<
UIntImm
>
())
{
if
(
op
->
value
!=
0
)
{
...
...
@@ -258,7 +259,11 @@ Expr select(Expr cond, Expr true_value, Expr false_value) {
return
false_value
;
}
}
return
ir
::
Select
::
make
(
cond
,
true_value
,
false_value
);
return
ir
::
Call
::
make
(
true_value
.
type
(),
ir
::
intrinsic
::
tvm_if_then_else
,
{
cond
,
true_value
,
false_value
},
ir
::
Call
::
PureIntrinsic
);
}
Expr
likely
(
Expr
cond
)
{
...
...
@@ -402,7 +407,12 @@ Expr pow(Expr x, Expr y) {
Expr
abs
(
Expr
x
)
{
if
(
x
.
type
().
is_int
())
{
return
select
(
x
>=
make_zero
(
x
.
type
()),
x
,
-
x
);
using
ir
::
IntImm
;
const
IntImm
*
px
=
x
.
as
<
IntImm
>
();
if
(
px
)
{
return
ir
::
IntImm
::
make
(
x
.
type
(),
std
::
abs
(
px
->
value
));
}
return
ir
::
Select
::
make
(
x
>=
make_zero
(
x
.
type
()),
x
,
-
x
);
}
else
if
(
x
.
type
().
is_float
())
{
return
ir
::
Call
::
make
(
x
.
type
(),
"fabs"
,
{
x
},
ir
::
Call
::
PureIntrinsic
);
}
else
if
(
x
.
type
().
is_uint
())
{
...
...
src/pass/inject_copy_intrin.cc
View file @
9d20fa1b
...
...
@@ -35,6 +35,26 @@ class CopyIntrinInjector : public IRMutator {
}
private
:
bool
MatchCondition
(
Expr
expr
,
Expr
*
cond
,
Expr
*
true_value
,
Expr
*
false_value
)
{
if
(
const
auto
*
op
=
expr
.
as
<
Select
>
())
{
*
cond
=
op
->
condition
;
*
true_value
=
op
->
true_value
;
*
false_value
=
op
->
false_value
;
return
true
;
}
else
if
(
const
auto
*
op
=
expr
.
as
<
Call
>
())
{
if
(
op
->
name
==
intrinsic
::
tvm_if_then_else
)
{
*
cond
=
op
->
args
[
0
];
*
true_value
=
op
->
args
[
1
];
*
false_value
=
op
->
args
[
2
];
return
true
;
}
}
return
false
;
}
bool
MatchCopyPattern
(
Stmt
stmt
,
Stmt
*
out
)
{
Stmt
body
=
stmt
;
bool
is_single_point_copy
=
false
;
...
...
@@ -48,16 +68,20 @@ class CopyIntrinInjector : public IRMutator {
}
const
Store
*
store
=
body
.
as
<
Store
>
();
if
(
store
==
nullptr
)
return
false
;
const
Select
*
select
=
store
->
value
.
as
<
Select
>
();
Expr
sel_cond
,
sel_true_value
,
sel_false_value
;
bool
has_cond
=
MatchCondition
(
store
->
value
,
&
sel_cond
,
&
sel_true_value
,
&
sel_false_value
);
const
Cast
*
cast
=
store
->
value
.
as
<
Cast
>
();
const
Load
*
load
=
store
->
value
.
as
<
Load
>
();
if
(
0
==
loops
.
size
())
{
is_single_point_copy
=
true
;
CHECK
(
select
==
nullptr
);
CHECK
(
!
has_cond
);
}
// for now only support true condition matching
if
(
select
!=
nullptr
)
{
load
=
sel
ect
->
true_value
.
as
<
Load
>
();
if
(
has_cond
)
{
load
=
sel
_
true_value
.
as
<
Load
>
();
}
// cast can be part of the pattern
if
(
cast
!=
nullptr
)
{
...
...
@@ -88,10 +112,10 @@ class CopyIntrinInjector : public IRMutator {
Array
<
Expr
>
pad_before
,
pad_after
;
Expr
pad_value
;
Expr
src_elem_offset
=
load_strides
[
loop_var_size
];
if
(
select
!=
nullptr
)
{
if
(
has_cond
)
{
Array
<
Expr
>
clip_bound
=
arith
::
DetectClipBound
(
sel
ect
->
condition
,
loop_vars
);
pad_value
=
sel
ect
->
false_value
;
arith
::
DetectClipBound
(
sel
_cond
,
loop_vars
);
pad_value
=
sel
_
false_value
;
if
(
clip_bound
.
size
()
==
0
)
return
false
;
CHECK_EQ
(
src_shape
.
size
(),
loop_vars
.
size
());
CHECK_EQ
(
clip_bound
.
size
(),
loop_vars
.
size
()
*
2
);
...
...
tests/python/integration/test_reduce.py
View file @
9d20fa1b
...
...
@@ -8,7 +8,7 @@ def test_reduce_prims():
n
=
tvm
.
var
(
'n'
)
m
=
tvm
.
var
(
'm'
)
A
=
tvm
.
placeholder
((
n
,
m
),
name
=
'A'
)
R
=
tvm
.
compute
((
n
,
),
lambda
i
:
tvm
.
s
elect
((
i
>
1
),
1
,
0
),
name
=
'R'
)
R
=
tvm
.
compute
((
n
,
),
lambda
i
:
tvm
.
expr
.
S
elect
((
i
>
1
),
1
,
0
),
name
=
'R'
)
k
=
tvm
.
reduce_axis
((
0
,
m
))
B
=
tvm
.
compute
((
n
,),
lambda
i
:
reducer
(
A
[
i
,
k
],
axis
=
k
,
where
=
(
R
[
i
]
==
1
)),
name
=
'B'
)
# schedule
...
...
tests/python/unittest/test_codegen_llvm.py
View file @
9d20fa1b
...
...
@@ -287,12 +287,12 @@ def test_multiple_func():
def
test_llvm_
select
():
def
test_llvm_
condition
():
def
check_llvm
(
n
,
offset
):
if
not
tvm
.
module
.
enabled
(
"llvm"
):
return
A
=
tvm
.
placeholder
((
n
,
),
name
=
'A'
)
C
=
tvm
.
compute
((
n
,),
lambda
i
:
tvm
.
select
(
i
>=
offset
,
A
[
i
],
0.0
),
name
=
'C'
)
C
=
tvm
.
compute
((
n
,),
lambda
i
:
tvm
.
if_then_else
(
i
>=
offset
,
A
[
i
],
0.0
),
name
=
'C'
)
s
=
tvm
.
create_schedule
(
C
.
op
)
# build and invoke the kernel.
f
=
tvm
.
build
(
s
,
[
A
,
C
],
"llvm"
)
...
...
@@ -462,7 +462,7 @@ if __name__ == "__main__":
test_rank_zero_bound_checkers
()
test_llvm_bool
()
test_llvm_persist_parallel
()
test_llvm_
select
()
test_llvm_
condition
()
test_llvm_vadd_pipeline
()
test_llvm_add_pipeline
()
test_llvm_intrin
()
...
...
tests/python/unittest/test_pass_inject_copy_intrin.py
View file @
9d20fa1b
...
...
@@ -25,7 +25,7 @@ def test_copy_pad():
l
=
tvm
.
var
(
'l'
)
A
=
tvm
.
placeholder
((
m
,
l
),
name
=
'A'
)
B
=
tvm
.
compute
((
m
+
2
,
l
),
lambda
i
,
j
:
tvm
.
select
(
tvm
.
all
(
i
>=
1
,
i
<
m
+
1
),
tvm
.
if_then_else
(
tvm
.
all
(
i
>=
1
,
i
<
m
+
1
),
A
[
i
-
1
,
j
],
1.0
),
name
=
'B'
)
s
=
tvm
.
create_schedule
(
B
.
op
)
s
[
B
]
.
pragma
(
B
.
op
.
axis
[
0
],
"memcpy"
)
...
...
@@ -71,7 +71,7 @@ def test_copy_pad_split():
m
=
4
*
3
A
=
tvm
.
placeholder
((
m
,
),
name
=
"A"
)
Apad
=
tvm
.
compute
((
m
+
2
,),
lambda
i
:
tvm
.
select
(
tvm
.
all
(
i
>=
1
,
i
<=
m
),
tvm
.
if_then_else
(
tvm
.
all
(
i
>=
1
,
i
<=
m
),
A
[
i
-
1
],
0.0
),
"Apad"
)
B
=
tvm
.
compute
((
m
,),
lambda
i
:
Apad
[
i
]
+
Apad
[
i
+
1
]
+
Apad
[
i
+
2
])
s
=
tvm
.
create_schedule
(
B
.
op
)
...
...
tests/python/unittest/test_pass_loop_partition.py
View file @
9d20fa1b
...
...
@@ -133,7 +133,7 @@ def test_vectorize():
assert
(
x
.
var
.
name
not
in
str
(
body
.
condition
))
assert
(
any
(
collect_visit
(
body
.
then_case
,
lambda
x
:
isinstance
(
x
,
tvm
.
expr
.
Ramp
))))
def
test_
select
():
def
test_
condition
():
ib
=
tvm
.
ir_builder
.
create
()
m
=
tvm
.
var
(
'm'
)
n
=
tvm
.
var
(
'n'
)
...
...
@@ -335,7 +335,7 @@ if __name__ == "__main__":
test_multi_if
()
test_thread_axis
()
test_vectorize
()
test_
select
()
test_
condition
()
test_thread_axis2
()
test_everything_during_deduction
()
test_single_likely
()
...
...
tests/python/unittest/test_pass_rewrite_unsafe_select.py
View file @
9d20fa1b
import
tvm
def
test_rewrite_
s
elect
():
def
test_rewrite_
S
elect
():
ib
=
tvm
.
ir_builder
.
create
()
A
=
ib
.
allocate
(
"float32"
,
100
,
name
=
"A"
,
scope
=
"global"
)
i
=
tvm
.
var
(
"i"
)
y
=
tvm
.
s
elect
(
i
>
1
,
A
[
i
-
1
],
1.0
)
y
=
tvm
.
expr
.
S
elect
(
i
>
1
,
A
[
i
-
1
],
1.0
)
yy
=
tvm
.
ir_pass
.
RewriteUnsafeSelect
(
tvm
.
make
.
Evaluate
(
y
))
.
value
z
=
tvm
.
select
(
tvm
.
select
(
i
>
1
,
A
[
i
-
1
],
1.0
)
>
0.0
,
A
[
i
],
0.1
)
z
=
tvm
.
expr
.
Select
(
tvm
.
expr
.
Select
(
i
>
1
,
A
[
i
-
1
],
1.0
)
>
0.0
,
A
[
i
],
0.1
)
zz
=
tvm
.
ir_pass
.
RewriteUnsafeSelect
(
tvm
.
make
.
Evaluate
(
z
))
.
value
a
=
tvm
.
s
elect
(
i
>
10
,
y
,
z
)
a
=
tvm
.
expr
.
S
elect
(
i
>
10
,
y
,
z
)
aa
=
tvm
.
ir_pass
.
RewriteUnsafeSelect
(
tvm
.
make
.
Evaluate
(
a
))
.
value
assert
yy
.
name
==
"tvm_if_then_else"
assert
zz
.
name
==
"tvm_if_then_else"
...
...
@@ -19,4 +20,4 @@ def test_rewrite_select():
if
__name__
==
"__main__"
:
test_rewrite_
s
elect
()
test_rewrite_
S
elect
()
tests/python/unittest/test_schedule_schedule_ops.py
View file @
9d20fa1b
...
...
@@ -63,8 +63,8 @@ def test_schedule_scan():
def
test_inline_multi_reduce
():
def
argmax_comp
(
x
,
y
):
idx
=
tvm
.
s
elect
((
x
[
1
]
>=
y
[
1
]),
x
[
0
],
y
[
0
])
val
=
tvm
.
s
elect
((
x
[
1
]
>=
y
[
1
]),
x
[
1
],
y
[
1
])
idx
=
tvm
.
expr
.
S
elect
((
x
[
1
]
>=
y
[
1
]),
x
[
0
],
y
[
0
])
val
=
tvm
.
expr
.
S
elect
((
x
[
1
]
>=
y
[
1
]),
x
[
1
],
y
[
1
])
return
idx
,
val
def
argmax_init
(
idx_typ
,
val_typ
):
return
tvm
.
const
(
-
1
,
idx_typ
),
tvm
.
min_value
(
val_typ
)
...
...
@@ -272,7 +272,7 @@ def test_schedule_cache_relayout4():
def
test_schedule_bound_condition
():
A
=
tvm
.
placeholder
((
64
,),
name
=
'A'
,
dtype
=
"float32"
)
Apad
=
tvm
.
compute
((
66
,),
lambda
i
:
tvm
.
select
(
Apad
=
tvm
.
compute
((
66
,),
lambda
i
:
tvm
.
if_then_else
(
tvm
.
all
(
i
>
0
,
i
<
65
),
A
[
i
-
1
],
tvm
.
const
(
0.
,
"float32"
)),
name
=
'Apad'
)
Apad2
=
tvm
.
compute
((
66
,),
lambda
i
:
Apad
[
i
]
*
2
,
name
=
'Apad2'
)
s
=
tvm
.
create_schedule
(
Apad2
.
op
)
...
...
@@ -424,7 +424,7 @@ def test_loop_dep_reduce_cache_write():
X
=
tvm
.
placeholder
(
shape
=
(
10
,),
name
=
"x"
)
def
f
(
n
):
rv
=
tvm
.
reduce_axis
((
0
,
n
))
init
=
lambda
dtype
:
tvm
.
s
elect
(
n
>
1
,
tvm
.
const
(
0
,
dtype
),
n
.
astype
(
dtype
))
init
=
lambda
dtype
:
tvm
.
expr
.
S
elect
(
n
>
1
,
tvm
.
const
(
0
,
dtype
),
n
.
astype
(
dtype
))
sum
=
tvm
.
comm_reducer
(
lambda
x
,
y
:
tvm
.
max
(
x
+
y
,
n
.
astype
(
'float32'
)),
init
,
name
=
'sum'
)
return
sum
(
X
[
rv
],
axis
=
rv
)
Y
=
tvm
.
compute
(
X
.
shape
,
f
,
name
=
"y"
)
...
...
topi/include/topi/image/resize.h
View file @
9d20fa1b
...
...
@@ -38,7 +38,7 @@ inline Expr bilinear_sample_nchw(const Tensor& input, const Array<Expr>& indices
auto
yc
=
HalideIR
::
Internal
::
Cast
::
make
(
Int
(
32
),
tvm
::
ceil
(
in_y
));
auto
y0
=
HalideIR
::
Internal
::
Cast
::
make
(
Int
(
32
),
tvm
::
floor
(
in_y
));
auto
y1
=
tvm
::
select
((
yc
>
max_y
),
max_y
,
yc
);
auto
y1
=
tvm
::
if_then_else
((
yc
>
max_y
),
max_y
,
yc
);
auto
y_lerp
=
in_y
-
yf
;
auto
in_x
=
indices
[
3
];
...
...
@@ -46,7 +46,7 @@ inline Expr bilinear_sample_nchw(const Tensor& input, const Array<Expr>& indices
auto
xc
=
HalideIR
::
Internal
::
Cast
::
make
(
Int
(
32
),
tvm
::
ceil
(
in_x
));
auto
x0
=
HalideIR
::
Internal
::
Cast
::
make
(
Int
(
32
),
tvm
::
floor
(
in_x
));
auto
x1
=
tvm
::
select
((
xc
>
max_x
),
max_x
,
xc
);
auto
x1
=
tvm
::
if_then_else
((
xc
>
max_x
),
max_x
,
xc
);
auto
x_lerp
=
in_x
-
xf
;
auto
A
=
input
(
indices
[
0
],
indices
[
1
],
y0
,
x0
);
...
...
@@ -215,7 +215,7 @@ inline Tensor resize_bilinear_nhwc(const Tensor& input,
auto
yc
=
HalideIR
::
Internal
::
Cast
::
make
(
Int
(
32
),
tvm
::
ceil
(
in_y
));
auto
y0
=
HalideIR
::
Internal
::
Cast
::
make
(
Int
(
32
),
tvm
::
floor
(
in_y
));
auto
y1
=
tvm
::
select
((
yc
>
other_y
),
other_y
,
yc
);
auto
y1
=
tvm
::
if_then_else
((
yc
>
other_y
),
other_y
,
yc
);
auto
y_lerp
=
in_y
-
yf
;
auto
in_x
=
indices
[
2
]
*
x_ratio
;
...
...
@@ -223,7 +223,7 @@ inline Tensor resize_bilinear_nhwc(const Tensor& input,
auto
xc
=
HalideIR
::
Internal
::
Cast
::
make
(
Int
(
32
),
tvm
::
ceil
(
in_x
));
auto
x0
=
HalideIR
::
Internal
::
Cast
::
make
(
Int
(
32
),
tvm
::
floor
(
in_x
));
auto
x1
=
tvm
::
select
((
xc
>
other_x
),
other_x
,
xc
);
auto
x1
=
tvm
::
if_then_else
((
xc
>
other_x
),
other_x
,
xc
);
auto
x_lerp
=
in_x
-
xf
;
auto
A
=
input
(
indices
[
0
],
y0
,
x0
,
indices
[
3
]);
...
...
topi/include/topi/nn.h
View file @
9d20fa1b
...
...
@@ -75,7 +75,7 @@ inline tvm::Tensor leaky_relu(const tvm::Tensor& t,
[
&
](
const
tvm
::
Array
<
tvm
::
Var
>&
i
)
{
auto
value
=
t
(
i
);
auto
calpha
=
tvm
::
make_const
(
value
.
type
(),
alpha
);
return
tvm
::
select
(
value
>
0
,
value
,
value
*
calpha
);
return
tvm
::
ir
::
Select
::
make
(
value
>
0
,
value
,
value
*
calpha
);
},
name
,
tag
);
...
...
@@ -106,9 +106,11 @@ inline tvm::Tensor prelu(const tvm::Tensor &x,
return
tvm
::
compute
(
x
->
shape
,
[
&
](
const
tvm
::
Array
<
tvm
::
Var
>
&
indices
)
{
return
tvm
::
select
(
x
(
indices
)
>
0
,
x
(
indices
),
x
(
indices
)
*
slope
(
indices
[
axis
]));
auto
xval
=
x
(
indices
);
return
tvm
::
ir
::
Select
::
make
(
xval
>
0
,
xval
,
xval
*
slope
(
indices
[
axis
]));
},
name
,
tag
);
...
...
@@ -193,7 +195,8 @@ inline tvm::Tensor pad(const tvm::Tensor& t,
}
}
if
(
sel
.
size
()
!=
0
)
{
return
tvm
::
select
(
detail
::
Map
(
sel
,
tvm
::
ir
::
And
::
make
),
t
(
indices
),
pad_value
);
return
tvm
::
if_then_else
(
detail
::
Map
(
sel
,
tvm
::
ir
::
And
::
make
),
t
(
indices
),
pad_value
);
}
return
t
(
indices
);
};
...
...
topi/include/topi/nn/dilate.h
View file @
9d20fa1b
...
...
@@ -76,7 +76,8 @@ inline Tensor dilate(const Tensor& x,
}
if
(
not_zero
.
size
()
>
0
)
{
auto
all_not_zero
=
all
(
not_zero
);
return
tvm
::
select
(
all_not_zero
,
x
(
index_tuple
),
make_const
(
x
->
dtype
,
0
));
return
tvm
::
if_then_else
(
all_not_zero
,
x
(
index_tuple
),
make_const
(
x
->
dtype
,
0
));
}
return
x
(
index_tuple
);
},
name
,
tag
);
...
...
topi/include/topi/reduction.h
View file @
9d20fa1b
...
...
@@ -411,8 +411,8 @@ inline Tensor argmin(const Tensor& data,
bool
atleast1d
=
false
)
{
auto
fcombine
=
[](
Array
<
Var
>
lhs
,
Array
<
Var
>
rhs
)
{
Array
<
Expr
>
result
;
result
.
push_back
(
tvm
::
select
(
lhs
[
1
]
<=
rhs
[
1
],
lhs
[
0
],
rhs
[
0
]));
// idx
result
.
push_back
(
tvm
::
select
(
lhs
[
1
]
<=
rhs
[
1
],
lhs
[
1
],
rhs
[
1
]));
// val
result
.
push_back
(
tvm
::
ir
::
Select
::
make
(
lhs
[
1
]
<=
rhs
[
1
],
lhs
[
0
],
rhs
[
0
]));
// idx
result
.
push_back
(
tvm
::
ir
::
Select
::
make
(
lhs
[
1
]
<=
rhs
[
1
],
lhs
[
1
],
rhs
[
1
]));
// val
return
result
;
};
auto
fidentity
=
[](
std
::
vector
<
Type
>
types
)
{
...
...
@@ -445,8 +445,8 @@ inline Tensor argmax(const Tensor& data,
bool
atleast1d
=
false
)
{
auto
fcombine
=
[](
Array
<
Var
>
lhs
,
Array
<
Var
>
rhs
)
{
Array
<
Expr
>
result
;
result
.
push_back
(
tvm
::
select
(
lhs
[
1
]
>=
rhs
[
1
],
lhs
[
0
],
rhs
[
0
]));
// idx
result
.
push_back
(
tvm
::
select
(
lhs
[
1
]
>=
rhs
[
1
],
lhs
[
1
],
rhs
[
1
]));
// val
result
.
push_back
(
tvm
::
ir
::
Select
::
make
(
lhs
[
1
]
>=
rhs
[
1
],
lhs
[
0
],
rhs
[
0
]));
// idx
result
.
push_back
(
tvm
::
ir
::
Select
::
make
(
lhs
[
1
]
>=
rhs
[
1
],
lhs
[
1
],
rhs
[
1
]));
// val
return
result
;
};
auto
fidentity
=
[](
std
::
vector
<
Type
>
types
)
{
...
...
topi/include/topi/transform.h
View file @
9d20fa1b
...
...
@@ -314,7 +314,7 @@ inline Tensor concatenate(const Array<Tensor>& inputs,
idx
.
push_back
(
indices
[
i
]);
}
ret
=
tvm
::
select
(
ind
>=
0
,
ret
=
tvm
::
if_then_else
(
ind
>=
0
,
inputs
[
i
+
1
](
idx
),
ret
);
}
...
...
@@ -652,7 +652,7 @@ inline Tensor where(const Tensor& condition,
<<
condition
->
shape
.
size
()
<<
" vs "
<<
x
->
shape
.
size
();
out
=
compute
(
oshape
,
[
&
](
const
Array
<
Var
>&
indices
)
{
return
tvm
::
select
(
condition
(
indices
)
!=
0
,
x
(
indices
),
y
(
indices
));
return
tvm
::
ir
::
Select
::
make
(
condition
(
indices
)
!=
0
,
x
(
indices
),
y
(
indices
));
},
name
,
tag
);
}
else
{
CHECK_EQ
(
topi
::
GetConstInt
(
condition
->
shape
[
0
]),
topi
::
GetConstInt
(
x
->
shape
[
0
]))
...
...
@@ -661,7 +661,7 @@ inline Tensor where(const Tensor& condition,
out
=
compute
(
oshape
,
[
&
](
const
Array
<
Var
>&
indices
)
{
Array
<
Expr
>
condition_idx
{
indices
[
0
]};
return
tvm
::
select
(
condition
(
condition_idx
)
!=
0
,
return
tvm
::
ir
::
Select
::
make
(
condition
(
condition_idx
)
!=
0
,
x
(
indices
),
y
(
indices
));
},
name
,
tag
);
}
...
...
topi/python/topi/cuda/conv2d_transpose_nchw.py
View file @
9d20fa1b
...
...
@@ -72,7 +72,7 @@ def conv2d_transpose_nchw_cuda(cfg, Input, Filter, strides, padding, out_dtype):
index_tuple
.
append
(
indices
[
i
])
if
not_zero
:
not_zero
=
tvm
.
all
(
*
not_zero
)
return
tvm
.
select
(
not_zero
,
data
(
*
index_tuple
),
tvm
.
const
(
0.0
,
data
.
dtype
))
return
tvm
.
if_then_else
(
not_zero
,
data
(
*
index_tuple
),
tvm
.
const
(
0.0
,
data
.
dtype
))
return
data
(
*
index_tuple
)
# convolution stage
...
...
topi/python/topi/cuda/nms.py
View file @
9d20fa1b
...
...
@@ -315,11 +315,11 @@ def sort_ir_out(data, index, new_index, loc, output, axis_mul_before, axis_mul_a
start
=
0
with
ib
.
else_scope
():
start
=
sizes
[
tid
-
1
]
p_out
[
base_idx
+
k
*
axis_mul_after
]
=
tvm
.
select
(
p_out
[
base_idx
+
k
*
axis_mul_after
]
=
tvm
.
if_then_else
(
k
<
p_index
[
tid
],
index_new
[
k
+
start
],
k
)
with
ib
.
else_scope
():
with
ib
.
if_scope
(
tid
<
data
.
shape
[
axis
]):
p_out
[
tid
]
=
tvm
.
select
(
tid
<
p_index
[
0
],
index_new
[
tid
],
tid
)
p_out
[
tid
]
=
tvm
.
if_then_else
(
tid
<
p_index
[
0
],
index_new
[
tid
],
tid
)
body
=
ib
.
get
()
return
body
...
...
@@ -470,7 +470,7 @@ def nms_ir(data, sort_result, valid_count, out, nms_threshold, force_suppress, n
(
out_tensor
[
box_a_idx
+
3
]
-
out_tensor
[
box_a_idx
+
1
])
+
\
(
out_tensor
[
box_b_idx
+
2
]
-
out_tensor
[
box_b_idx
])
*
\
(
out_tensor
[
box_b_idx
+
3
]
-
out_tensor
[
box_b_idx
+
1
])
-
i
return
tvm
.
s
elect
(
u
<=
0.0
,
0.0
,
i
/
u
)
return
tvm
.
expr
.
S
elect
(
u
<=
0.0
,
0.0
,
i
/
u
)
max_threads
=
int
(
math
.
sqrt
(
tvm
.
target
.
current_target
(
allow_none
=
False
)
.
max_num_threads
))
...
...
@@ -506,7 +506,7 @@ def nms_ir(data, sort_result, valid_count, out, nms_threshold, force_suppress, n
tvm
.
all
(
nms_threshold_node
>
0
,
nms_threshold_node
<
1
,
p_valid_count
[
0
]
>
0
)):
# Reorder output
nkeep
=
tvm
.
select
(
nkeep
=
tvm
.
if_then_else
(
tvm
.
all
(
nms_topk_node
>
0
,
nms_topk
<
p_valid_count
[
n
]),
nms_topk
,
p_valid_count
[
n
])
with
ib
.
if_scope
(
i
<
nkeep
):
...
...
topi/python/topi/cuda/ssd/multibox.py
View file @
9d20fa1b
...
...
@@ -77,12 +77,13 @@ def multibox_prior_ir(data, out, sizes, ratios, steps, offsets):
center_w
=
(
j
+
offset_w
)
*
steps_w
for
k
in
range
(
num_sizes
+
num_ratios
-
1
):
w
=
tvm
.
select
(
k
<
num_sizes
,
w
=
tvm
.
if_then_else
(
k
<
num_sizes
,
size_ratio_concat
[
k
]
*
in_height
/
in_width
/
2.0
,
size_ratio_concat
[
0
]
*
in_height
/
in_width
*
math
.
sqrt
(
size_ratio_concat
[
k
+
1
])
/
2.0
)
h
=
tvm
.
select
(
k
<
num_sizes
,
size_ratio_concat
[
k
]
/
2.0
,
h
=
tvm
.
if_then_else
(
k
<
num_sizes
,
size_ratio_concat
[
k
]
/
2.0
,
size_ratio_concat
[
0
]
/
math
.
sqrt
(
size_ratio_concat
[
k
+
1
])
/
2.0
)
count
=
(
i
*
in_width
*
(
num_sizes
+
num_ratios
-
1
)
+
j
*
(
num_sizes
+
num_ratios
-
1
)
+
k
)
*
4
...
...
@@ -278,10 +279,10 @@ def transform_loc_ir(loc_pred, anchor, temp_flag, temp_id, temp_score_in, \
oy
=
py
*
vy
*
ah
+
ay
ow
=
tvm
.
exp
(
pw
*
vw
)
*
aw
/
2.0
oh
=
tvm
.
exp
(
ph
*
vh
)
*
ah
/
2.0
return
tvm
.
select
(
clip
,
tvm
.
make
.
Max
(
0.0
,
tvm
.
make
.
Min
(
1.0
,
ox
-
ow
)),
ox
-
ow
),
\
tvm
.
select
(
clip
,
tvm
.
make
.
Max
(
0.0
,
tvm
.
make
.
Min
(
1.0
,
oy
-
oh
)),
oy
-
oh
),
\
tvm
.
select
(
clip
,
tvm
.
make
.
Max
(
0.0
,
tvm
.
make
.
Min
(
1.0
,
ox
+
ow
)),
ox
+
ow
),
\
tvm
.
select
(
clip
,
tvm
.
make
.
Max
(
0.0
,
tvm
.
make
.
Min
(
1.0
,
oy
+
oh
)),
oy
+
oh
)
return
tvm
.
if_then_else
(
clip
,
tvm
.
make
.
Max
(
0.0
,
tvm
.
make
.
Min
(
1.0
,
ox
-
ow
)),
ox
-
ow
),
\
tvm
.
if_then_else
(
clip
,
tvm
.
make
.
Max
(
0.0
,
tvm
.
make
.
Min
(
1.0
,
oy
-
oh
)),
oy
-
oh
),
\
tvm
.
if_then_else
(
clip
,
tvm
.
make
.
Max
(
0.0
,
tvm
.
make
.
Min
(
1.0
,
ox
+
ow
)),
ox
+
ow
),
\
tvm
.
if_then_else
(
clip
,
tvm
.
make
.
Max
(
0.0
,
tvm
.
make
.
Min
(
1.0
,
oy
+
oh
)),
oy
+
oh
)
max_threads
=
int
(
tvm
.
target
.
current_target
(
allow_none
=
False
)
.
max_num_threads
)
...
...
topi/python/topi/mali/conv2d.py
View file @
9d20fa1b
...
...
@@ -296,7 +296,8 @@ def _decl_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dt
# pack input tile
input_tile
=
tvm
.
compute
((
CI
,
P_round
//
bnb
,
alpha
,
alpha
,
bnb
),
lambda
ci
,
b
,
eps
,
nu
,
bb
:
\
tvm
.
select
(
b
*
bnb
+
bb
<
P
,
tvm
.
if_then_else
(
b
*
bnb
+
bb
<
P
,
data_pad
[(
b
*
bnb
+
bb
)
//
(
nH
*
nW
)][
ci
][(
b
*
bnb
+
bb
)
//
nW
%
nH
*
m
+
eps
]
[(
b
*
bnb
+
bb
)
%
nW
*
m
+
nu
],
tvm
.
const
(
0
,
data_pad
.
dtype
)),
name
=
'd'
)
...
...
topi/python/topi/nn/dilate.py
View file @
9d20fa1b
...
...
@@ -44,7 +44,7 @@ def dilate(data, strides, name="DilatedInput"):
index_tuple
.
append
(
indices
[
i
])
if
not_zero
:
not_zero
=
tvm
.
all
(
*
not_zero
)
return
tvm
.
select
(
not_zero
,
data
(
*
index_tuple
),
tvm
.
const
(
0.0
,
data
.
dtype
))
return
tvm
.
if_then_else
(
not_zero
,
data
(
*
index_tuple
),
tvm
.
const
(
0.0
,
data
.
dtype
))
return
data
(
*
index_tuple
)
return
tvm
.
compute
(
out_shape
,
_dilate
,
name
=
name
)
topi/python/topi/nn/elemwise.py
View file @
9d20fa1b
...
...
@@ -41,7 +41,7 @@ def leaky_relu(x, alpha):
def
_compute
(
*
indices
):
value
=
x
(
*
indices
)
calpha
=
tvm
.
const
(
alpha
,
value
.
dtype
)
return
tvm
.
s
elect
(
value
>
0
,
value
,
value
*
calpha
)
return
tvm
.
expr
.
S
elect
(
value
>
0
,
value
,
value
*
calpha
)
return
tvm
.
compute
(
x
.
shape
,
_compute
)
@tvm.tag_scope
(
tag
=
tag
.
BROADCAST
)
...
...
@@ -74,5 +74,6 @@ def prelu(x, slope, axis=1):
assert
get_const_int
(
slope
.
shape
[
0
])
==
get_const_int
(
x
.
shape
[
axis
])
def
_compute_channelwise
(
*
indices
):
return
tvm
.
select
(
x
(
*
indices
)
>
0
,
x
(
*
indices
),
x
(
*
indices
)
*
slope
(
indices
[
axis
]))
xval
=
x
(
*
indices
)
return
tvm
.
expr
.
Select
(
xval
>
0
,
xval
,
xval
*
slope
(
indices
[
axis
]))
return
tvm
.
compute
(
x
.
shape
,
_compute_channelwise
)
topi/python/topi/nn/pad.py
View file @
9d20fa1b
...
...
@@ -55,6 +55,6 @@ def pad(data, pad_before, pad_after=None, pad_value=0.0, name="PadInput"):
not_zero
.
append
(
indices
[
i
]
<
data
.
shape
[
i
]
+
pad_before
[
i
])
if
not_zero
:
not_zero
=
tvm
.
all
(
*
not_zero
)
return
tvm
.
select
(
not_zero
,
data
(
*
index_tuple
),
pad_value
)
return
tvm
.
if_then_else
(
not_zero
,
data
(
*
index_tuple
),
pad_value
)
return
data
(
*
index_tuple
)
return
tvm
.
compute
(
out_shape
,
_pad
,
name
=
name
)
topi/python/topi/nn/util.py
View file @
9d20fa1b
...
...
@@ -55,8 +55,8 @@ def infer_stride(data, kernel, out):
_
,
_
,
IH
,
IW
=
data
.
shape
_
,
_
,
KH
,
KW
=
kernel
.
shape
_
,
_
,
OH
,
OW
=
out
.
shape
hstride
=
(
IH
-
KH
)
//
tvm
.
make
.
Max
(
OH
-
1
,
1
)
+
tvm
.
s
elect
(
OH
==
1
,
1
,
0
)
wstride
=
(
IW
-
KW
)
//
tvm
.
make
.
Max
(
OW
-
1
,
1
)
+
tvm
.
s
elect
(
OW
==
1
,
1
,
0
)
hstride
=
(
IH
-
KH
)
//
tvm
.
make
.
Max
(
OH
-
1
,
1
)
+
tvm
.
expr
.
S
elect
(
OH
==
1
,
1
,
0
)
wstride
=
(
IW
-
KW
)
//
tvm
.
make
.
Max
(
OW
-
1
,
1
)
+
tvm
.
expr
.
S
elect
(
OW
==
1
,
1
,
0
)
return
get_const_int
(
hstride
),
get_const_int
(
wstride
)
...
...
topi/python/topi/util.py
View file @
9d20fa1b
...
...
@@ -249,7 +249,7 @@ def const_matrix(matrix, name="const_matrix"):
now
=
tvm
.
const
(
0.0
,
dtype
)
for
ii
in
range
(
row
):
for
jj
in
range
(
col
):
now
=
tvm
.
s
elect
(
tvm
.
all
(
i
%
row
==
ii
,
j
%
col
==
jj
),
now
=
tvm
.
expr
.
S
elect
(
tvm
.
all
(
i
%
row
==
ii
,
j
%
col
==
jj
),
tvm
.
const
(
matrix
[
ii
][
jj
],
dtype
),
now
)
return
now
...
...
topi/python/topi/vision/nms.py
View file @
9d20fa1b
...
...
@@ -47,7 +47,7 @@ def nms_ir(data, sort_result, valid_count, out, nms_threshold, force_suppress, n
(
out_tensor
[
box_a_idx
+
3
]
-
out_tensor
[
box_a_idx
+
1
])
+
\
(
out_tensor
[
box_b_idx
+
2
]
-
out_tensor
[
box_b_idx
])
*
\
(
out_tensor
[
box_b_idx
+
3
]
-
out_tensor
[
box_b_idx
+
1
])
-
i
return
tvm
.
s
elect
(
u
<=
0.0
,
0.0
,
i
/
u
)
return
tvm
.
expr
.
S
elect
(
u
<=
0.0
,
0.0
,
i
/
u
)
ib
=
tvm
.
ir_builder
.
create
()
p_data
=
ib
.
buffer_ptr
(
data
)
...
...
@@ -64,7 +64,8 @@ def nms_ir(data, sort_result, valid_count, out, nms_threshold, force_suppress, n
with
ib
.
if_scope
(
tvm
.
all
(
nms_threshold_node
>
0
,
nms_threshold_node
<
1
,
p_valid_count
[
0
]
>
0
)):
# Reorder output
nkeep
=
tvm
.
select
(
tvm
.
all
(
nms_topk_node
>
0
,
nms_topk
<
p_valid_count
[
n
]),
nkeep
=
tvm
.
if_then_else
(
tvm
.
all
(
nms_topk_node
>
0
,
nms_topk
<
p_valid_count
[
n
]),
nms_topk
,
p_valid_count
[
n
])
with
ib
.
for_range
(
0
,
nkeep
,
name
=
"l"
)
as
l
:
with
ib
.
for_range
(
0
,
6
,
name
=
"m"
)
as
m
:
...
...
topi/python/topi/vision/rcnn/roi_align.py
View file @
9d20fa1b
...
...
@@ -47,7 +47,7 @@ def roi_align_nchw(data, rois, pooled_size, spatial_scale, sample_ratio=-1):
y
=
tvm
.
max
(
y
,
0.0
)
x
=
tvm
.
max
(
x
,
0.0
)
val
=
bilinear_sample_nchw
(
data
,
(
i
,
c
,
y
,
x
),
height
-
1
,
width
-
1
)
return
tvm
.
select
(
outside
,
0.0
,
val
)
return
tvm
.
if_then_else
(
outside
,
0.0
,
val
)
def
_sample
(
i
,
c
,
ph
,
pw
):
roi
=
rois
[
i
]
...
...
topi/python/topi/vision/ssd/multibox.py
View file @
9d20fa1b
...
...
@@ -55,11 +55,12 @@ def multibox_prior_ir(data, out, sizes, ratios, steps, offsets):
with
ib
.
for_range
(
0
,
in_width
,
name
=
"j"
)
as
j
:
center_w
=
(
j
+
offset_w
)
*
steps_w
for
k
in
range
(
num_sizes
+
num_ratios
-
1
):
w
=
tvm
.
select
(
k
<
num_sizes
,
w
=
tvm
.
if_then_else
(
k
<
num_sizes
,
size_ratio_concat
[
k
]
*
in_height
/
in_width
/
2.0
,
size_ratio_concat
[
0
]
*
in_height
/
in_width
*
math
.
sqrt
(
size_ratio_concat
[
k
+
1
])
/
2.0
)
h
=
tvm
.
select
(
k
<
num_sizes
,
size_ratio_concat
[
k
]
/
2.0
,
h
=
tvm
.
if_then_else
(
k
<
num_sizes
,
size_ratio_concat
[
k
]
/
2.0
,
size_ratio_concat
[
0
]
/
math
.
sqrt
(
size_ratio_concat
[
k
+
1
])
/
2.0
)
count
=
(
i
*
in_width
*
(
num_sizes
+
num_ratios
-
1
)
+
j
*
(
num_sizes
+
num_ratios
-
1
)
+
k
)
*
4
...
...
@@ -164,10 +165,10 @@ def transform_loc_ir(cls_prob, loc_pred, anchor, valid_count, out, clip, thresho
oy
=
py
*
vy
*
ah
+
ay
ow
=
tvm
.
exp
(
pw
*
vw
)
*
aw
/
2.0
oh
=
tvm
.
exp
(
ph
*
vh
)
*
ah
/
2.0
return
tvm
.
select
(
clip
,
tvm
.
max
(
0
,
tvm
.
min
(
1
,
ox
-
ow
)),
ox
-
ow
),
\
tvm
.
select
(
clip
,
tvm
.
max
(
0
,
tvm
.
min
(
1
,
oy
-
oh
)),
oy
-
oh
),
\
tvm
.
select
(
clip
,
tvm
.
max
(
0
,
tvm
.
min
(
1
,
ox
+
ow
)),
ox
+
ow
),
\
tvm
.
select
(
clip
,
tvm
.
max
(
0
,
tvm
.
min
(
1
,
oy
+
oh
)),
oy
+
oh
)
return
tvm
.
if_then_else
(
clip
,
tvm
.
max
(
0
,
tvm
.
min
(
1
,
ox
-
ow
)),
ox
-
ow
),
\
tvm
.
if_then_else
(
clip
,
tvm
.
max
(
0
,
tvm
.
min
(
1
,
oy
-
oh
)),
oy
-
oh
),
\
tvm
.
if_then_else
(
clip
,
tvm
.
max
(
0
,
tvm
.
min
(
1
,
ox
+
ow
)),
ox
+
ow
),
\
tvm
.
if_then_else
(
clip
,
tvm
.
max
(
0
,
tvm
.
min
(
1
,
oy
+
oh
)),
oy
+
oh
)
batch_size
=
cls_prob
.
shape
[
0
]
num_classes
=
cls_prob
.
shape
[
1
]
...
...
@@ -190,7 +191,7 @@ def transform_loc_ir(cls_prob, loc_pred, anchor, valid_count, out, clip, thresho
with
ib
.
for_range
(
0
,
num_classes
,
name
=
"j"
)
as
j
:
with
ib
.
if_scope
(
j
>
0
):
temp
=
p_cls_prob
[
n
*
num_anchors
*
num_classes
+
j
*
num_anchors
+
i
]
cls_id
[
0
]
=
tvm
.
select
(
temp
>
score
[
0
],
j
,
cls_id
[
0
])
cls_id
[
0
]
=
tvm
.
if_then_else
(
temp
>
score
[
0
],
j
,
cls_id
[
0
])
score
[
0
]
=
tvm
.
max
(
temp
,
score
[
0
])
with
ib
.
if_scope
(
tvm
.
all
(
cls_id
[
0
]
>
0
,
score
[
0
]
<
threshold
)):
cls_id
[
0
]
=
0
...
...
topi/tests/python/test_topi_conv2d_nchw.py
View file @
9d20fa1b
...
...
@@ -65,7 +65,7 @@ def verify_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, p
else
:
func
=
tvm
.
build
(
s
,
[
A
,
W
,
C
],
device
,
name
=
"relu_
%
d_
%
d_
%
d_
%
d_
%
d_
%
d_
%
d_
%
d"
%
(
batch
,
in_channel
,
in_size
,
num_filter
,
kernel
,
stride
,
padding
,
dilation
))
func
(
a
,
w
,
c
)
tvm
.
testing
.
assert_allclose
(
c
.
asnumpy
(),
c_np
,
rtol
=
1e-
5
)
tvm
.
testing
.
assert_allclose
(
c
.
asnumpy
(),
c_np
,
rtol
=
1e-
4
)
for
device
in
get_all_backend
():
with
autotvm
.
tophub
.
context
(
device
):
# load tophub pre-tuned parameters
...
...
tutorials/language/tuple_inputs.py
View file @
9d20fa1b
...
...
@@ -45,8 +45,8 @@ print(tvm.lower(s, [A0, A1, B0, B1], simple_mode=True))
# x and y are the operands of the reduction; each of them is a tuple of index
# and value.
def
fcombine
(
x
,
y
):
lhs
=
tvm
.
s
elect
((
x
[
1
]
>=
y
[
1
]),
x
[
0
],
y
[
0
])
rhs
=
tvm
.
s
elect
((
x
[
1
]
>=
y
[
1
]),
x
[
1
],
y
[
1
])
lhs
=
tvm
.
expr
.
S
elect
((
x
[
1
]
>=
y
[
1
]),
x
[
0
],
y
[
0
])
rhs
=
tvm
.
expr
.
S
elect
((
x
[
1
]
>=
y
[
1
]),
x
[
1
],
y
[
1
])
return
lhs
,
rhs
# our identity element also needs to be a tuple, so `fidentity` accepts
...
...
tutorials/optimize/opt_conv_cuda.py
View file @
9d20fa1b
...
...
@@ -43,7 +43,7 @@ out_size = (in_size - kernel + 2*pad) // stride + 1
# Pad input
Apad
=
tvm
.
compute
(
(
in_size
+
2
*
pad
,
in_size
+
2
*
pad
,
in_channel
,
batch
),
lambda
yy
,
xx
,
cc
,
nn
:
tvm
.
select
(
lambda
yy
,
xx
,
cc
,
nn
:
tvm
.
if_then_else
(
tvm
.
all
(
yy
>=
pad
,
yy
-
pad
<
in_size
,
xx
>=
pad
,
xx
-
pad
<
in_size
),
A
[
yy
-
pad
,
xx
-
pad
,
cc
,
nn
],
tvm
.
const
(
0.
,
"float32"
)),
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment