Unverified commit 9d20fa1b by Tianqi Chen, committed by GitHub

[PASS][TENSOR] Use correct select semantics (#2394)

parent 98e761f8
@@ -11,6 +11,7 @@ tvm.intrin
    tvm.call_extern
    tvm.call_llvm_intrin
    tvm.register_intrin_rule
+   tvm.if_then_else
    tvm.exp
    tvm.log
    tvm.floor
@@ -26,6 +27,7 @@ tvm.intrin
 .. autofunction:: tvm.call_extern
 .. autofunction:: tvm.call_llvm_intrin
 .. autofunction:: tvm.register_intrin_rule
+.. autofunction:: tvm.if_then_else
 .. autofunction:: tvm.exp
 .. autofunction:: tvm.log
 .. autofunction:: tvm.floor
...
@@ -15,7 +15,6 @@ The user facing API for computation declaration.
    tvm.extern
    tvm.decl_buffer
    tvm.reduce_axis
-   tvm.select
    tvm.thread_axis
    tvm.comm_reducer
    tvm.sum
@@ -34,7 +33,6 @@ The user facing API for computation declaration.
 .. autofunction:: tvm.extern
 .. autofunction:: tvm.decl_buffer
 .. autofunction:: tvm.reduce_axis
-.. autofunction:: tvm.select
 .. autofunction:: tvm.thread_axis
 .. autofunction:: tvm.comm_reducer
 .. autofunction:: tvm.sum
...
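Taken together, the two documentation changes retire tvm.select in favor of tvm.if_then_else. A minimal sketch of the new call, assuming the TVM Python API of this commit (the names and shapes are illustrative):

    import tvm

    n = tvm.var("n")
    A = tvm.placeholder((n,), name="A")
    # Where code previously wrote tvm.select(i > 0, A[i - 1], 0.0), the new
    # intrinsic evaluates only the taken branch, so A[i - 1] is never read at i == 0.
    B = tvm.compute((n,), lambda i: tvm.if_then_else(i > 0, A[i - 1], 0.0), name="B")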
@@ -392,7 +392,7 @@ TVM_DLL Expr operator^(Expr a, Expr b);
  */
 TVM_DLL Expr operator~(Expr a);
 /*!
- * \brief select result by condition
+ * \brief Conditional expression.
  *
  * \param cond The condition
  * \param true_value The value when results are true.
@@ -401,7 +401,7 @@ TVM_DLL Expr operator~(Expr a);
  * \note this function does eager constant folding for
  *       index types(int32, int64) when possible.
  */
-TVM_DLL Expr select(Expr cond, Expr true_value, Expr false_value);
+TVM_DLL Expr if_then_else(Expr cond, Expr true_value, Expr false_value);
 /*!
  * \brief Mark condition as likely.
  * \param cond The condition
...
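The renamed helper keeps the note about eager constant folding for index types, so a constant condition folds to the chosen branch before any IR is emitted on the C++ side. From user code the condition must be a boolean expression, typically a comparison; a hedged sketch with illustrative names:

    import tvm

    x = tvm.var("x")
    # Comparisons produce the boolean condition that if_then_else expects.
    clipped = tvm.if_then_else(x > 255, 255, x)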
@@ -669,28 +669,6 @@ def reduce_axis(dom, name="rv"):
     return _IterVar(dom, name, 2)
 
-def select(cond, t, f):
-    """Construct a select branch.
-
-    Parameters
-    ----------
-    cond : Expr
-        The condition
-    t : Expr
-        The result expression if cond is true.
-    f : Expr
-        The result expression if cond is false.
-
-    Returns
-    -------
-    node : Node
-        The tvm.expr.Select node
-    """
-    return _expr.Select(convert(cond), convert(t), convert(f))
-
 def comm_reducer(fcombine, fidentity, name="reduce"):
     """Create a commutative reducer for reduction.
...
@@ -624,6 +624,13 @@ class Not(LogicalExpr):
 class Select(Expr):
     """Select node.
 
+    Note
+    ----
+    Select may compute both true_value and false_value.
+    Use :any:`tvm.if_then_else` instead if you want to
+    get a conditional expression that only evaluates
+    the correct branch.
+
     Parameters
     ----------
     condition : Expr
@@ -634,6 +641,7 @@ class Select(Expr):
     false_value : Expr
         The value to take when condition is false.
     """
+
     def __init__(self, condition, true_value, false_value):
         self.__init_handle_by_constructor__(
...
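The contrast the new note draws shows up directly in how the two forms are built (a short sketch; the scalar names are illustrative):

    import tvm

    i = tvm.var("i")
    a = tvm.var("a", dtype="float32")
    # Select is a plain IR node: both arms may be evaluated, so both must be
    # safe, which in exchange lets the backend vectorize it freely.
    s = tvm.expr.Select(i > 1, a, a * 2.0)
    # if_then_else lowers to the tvm_if_then_else intrinsic and evaluates
    # only the arm picked by the condition.
    c = tvm.if_then_else(i > 1, a, a * 2.0)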
@@ -393,6 +393,42 @@ def fmod(x, y):
     """
     return call_pure_intrin(x.dtype, "fmod", x, y)
 
+def if_then_else(cond, t, f):
+    """Conditional selection expression.
+
+    Parameters
+    ----------
+    cond : Expr
+        The condition
+    t : Expr
+        The result expression if cond is true.
+    f : Expr
+        The result expression if cond is false.
+
+    Returns
+    -------
+    result : Node
+        The result of the conditional expression.
+
+    Note
+    ----
+    Unlike Select, if_then_else will not execute
+    the branch that does not satisfy the condition.
+    You can use it to guard against out-of-bound access.
+    Unlike Select, if_then_else cannot be vectorized
+    if some lanes in the vector have different conditions.
+    """
+    t = convert(t)
+    f = convert(f)
+    cond = convert(cond)
+    if cond.dtype != "bool":
+        raise TypeError("The condition's data type has to be bool")
+    return call_pure_intrin(t.dtype, "tvm_if_then_else", cond, t, f)
+
 # Intrinsic rule related code
 def register_intrin_rule(target, intrin, f=None, override=False):
     """Register an intrinsic function generation rule.
...
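The dtype check above is the user-visible contract of the new function: a non-boolean condition is rejected before any IR is built. A quick sketch of both outcomes (illustrative names):

    import tvm

    i = tvm.var("i")
    x = tvm.var("x", dtype="float32")
    e = tvm.if_then_else(i >= 0, x, -x)   # fine: the comparison is bool
    try:
        tvm.if_then_else(i, x, -x)        # rejected: i is int32, not bool
    except TypeError as err:
        print(err)  # The condition's data type has to be bool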
@@ -268,8 +268,9 @@ inline IntSet CombineInterval<Mul>(Interval a, Interval b) {
   } else if (is_negative_const(b.min)) {
     return IntervalSet::make(e2, e1);
   } else if (a.is_bounded()) {
+    using ir::Select;
     Expr cmp = b.min >= make_zero(b.min.type().element_of());
-    return IntervalSet::make(select(cmp, e1, e2), select(cmp, e2, e1));
+    return IntervalSet::make(Select::make(cmp, e1, e2), Select::make(cmp, e2, e1));
   }
 }
 LOG(WARNING) << "Return Everything in CombineInterval Mul";
@@ -294,8 +295,9 @@ inline IntSet CombineInterval<Div>(Interval a, Interval b) {
   } else if (is_negative_const(b.min)) {
     return IntervalSet::make(e2, e1);
   } else if (a.is_bounded()) {
+    using ir::Select;
     Expr cmp = b.min >= make_zero(b.min.type().element_of());
-    return IntervalSet::make(select(cmp, e1, e2), select(cmp, e2, e1));
+    return IntervalSet::make(Select::make(cmp, e1, e2), Select::make(cmp, e2, e1));
   }
 }
 LOG(WARNING) << "Return Everything in CombineInterval Div";
...
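Both interval bounds computed here are always safe to evaluate, so the arithmetic code wants a genuine Select node; it can no longer go through the old select() helper, whose successor now emits the tvm_if_then_else intrinsic instead.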
@@ -240,10 +240,11 @@ Expr max(Expr a, Expr b) {
   return ir::Max::make(a, b);
 }
 
-Expr select(Expr cond, Expr true_value, Expr false_value) {
+Expr if_then_else(Expr cond, Expr true_value, Expr false_value) {
   using ir::IntImm;
   using ir::UIntImm;
-  CHECK(cond.type().is_bool());
+  CHECK(cond.type() == Bool(1))
+      << "if_then_else only accept a single condition";
   BinaryOpMatchTypes(true_value, false_value);
   if (const UIntImm* op = cond.as<UIntImm>()) {
     if (op->value != 0) {
@@ -258,7 +259,11 @@ Expr select(Expr cond, Expr true_value, Expr false_value) {
       return false_value;
     }
   }
-  return ir::Select::make(cond, true_value, false_value);
+  return ir::Call::make(
+      true_value.type(),
+      ir::intrinsic::tvm_if_then_else,
+      {cond, true_value, false_value},
+      ir::Call::PureIntrinsic);
 }
 
 Expr likely(Expr cond) {
@@ -402,7 +407,12 @@ Expr pow(Expr x, Expr y) {
 Expr abs(Expr x) {
   if (x.type().is_int()) {
-    return select(x >= make_zero(x.type()), x, -x);
+    using ir::IntImm;
+    const IntImm* px = x.as<IntImm>();
+    if (px) {
+      return ir::IntImm::make(x.type(), std::abs(px->value));
+    }
+    return ir::Select::make(x >= make_zero(x.type()), x, -x);
   } else if (x.type().is_float()) {
     return ir::Call::make(x.type(), "fabs", {x}, ir::Call::PureIntrinsic);
   } else if (x.type().is_uint()) {
...
@@ -35,6 +35,26 @@ class CopyIntrinInjector : public IRMutator {
   }
 
  private:
+  bool MatchCondition(Expr expr,
+                      Expr* cond,
+                      Expr* true_value,
+                      Expr* false_value) {
+    if (const auto* op = expr.as<Select>()) {
+      *cond = op->condition;
+      *true_value = op->true_value;
+      *false_value = op->false_value;
+      return true;
+    } else if (const auto* op = expr.as<Call>()) {
+      if (op->name == intrinsic::tvm_if_then_else) {
+        *cond = op->args[0];
+        *true_value = op->args[1];
+        *false_value = op->args[2];
+        return true;
+      }
+    }
+    return false;
+  }
+
   bool MatchCopyPattern(Stmt stmt, Stmt *out) {
     Stmt body = stmt;
     bool is_single_point_copy = false;
@@ -48,16 +68,20 @@ class CopyIntrinInjector : public IRMutator {
     }
     const Store* store = body.as<Store>();
     if (store == nullptr) return false;
-    const Select* select = store->value.as<Select>();
+    Expr sel_cond, sel_true_value, sel_false_value;
+    bool has_cond = MatchCondition(store->value,
+                                   &sel_cond,
+                                   &sel_true_value,
+                                   &sel_false_value);
     const Cast* cast = store->value.as<Cast>();
     const Load* load = store->value.as<Load>();
     if (0 == loops.size()) {
       is_single_point_copy = true;
-      CHECK(select == nullptr);
+      CHECK(!has_cond);
     }
     // for now only support true condition matching
-    if (select != nullptr) {
-      load = select->true_value.as<Load>();
+    if (has_cond) {
+      load = sel_true_value.as<Load>();
     }
     // cast can be part of the pattern
     if (cast != nullptr) {
@@ -88,10 +112,10 @@ class CopyIntrinInjector : public IRMutator {
     Array<Expr> pad_before, pad_after;
     Expr pad_value;
     Expr src_elem_offset = load_strides[loop_var_size];
-    if (select != nullptr) {
+    if (has_cond) {
       Array<Expr> clip_bound =
-          arith::DetectClipBound(select->condition, loop_vars);
-      pad_value = select->false_value;
+          arith::DetectClipBound(sel_cond, loop_vars);
+      pad_value = sel_false_value;
       if (clip_bound.size() == 0) return false;
       CHECK_EQ(src_shape.size(), loop_vars.size());
       CHECK_EQ(clip_bound.size(), loop_vars.size() * 2);
...
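With MatchCondition in place, the injector accepts the guarded-copy pattern in either spelling, a Select node or the tvm_if_then_else intrinsic. The shape it matches is the padded copy exercised by the tests further below, roughly (a sketch using the same API as those tests):

    import tvm

    m = tvm.var("m")
    l = tvm.var("l")
    A = tvm.placeholder((m, l), name="A")
    B = tvm.compute((m + 2, l),
                    lambda i, j: tvm.if_then_else(tvm.all(i >= 1, i < m + 1),
                                                  A[i - 1, j], 1.0),
                    name="B")
    s = tvm.create_schedule(B.op)
    # The memcpy pragma asks CopyIntrinInjector to match the guarded store;
    # the false branch (1.0) becomes the pad value.
    s[B].pragma(B.op.axis[0], "memcpy")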
@@ -8,7 +8,7 @@ def test_reduce_prims():
         n = tvm.var('n')
         m = tvm.var('m')
         A = tvm.placeholder((n, m), name='A')
-        R = tvm.compute((n, ), lambda i: tvm.select((i > 1), 1, 0), name='R')
+        R = tvm.compute((n, ), lambda i: tvm.expr.Select((i > 1), 1, 0), name='R')
         k = tvm.reduce_axis((0, m))
         B = tvm.compute((n,), lambda i: reducer(A[i, k], axis=k, where=(R[i]==1)), name='B')
         # schedule
...
@@ -287,12 +287,12 @@ def test_multiple_func():
 
-def test_llvm_select():
+def test_llvm_condition():
     def check_llvm(n, offset):
         if not tvm.module.enabled("llvm"):
             return
         A = tvm.placeholder((n, ), name='A')
-        C = tvm.compute((n,), lambda i: tvm.select(i >= offset, A[i], 0.0), name='C')
+        C = tvm.compute((n,), lambda i: tvm.if_then_else(i >= offset, A[i], 0.0), name='C')
         s = tvm.create_schedule(C.op)
         # build and invoke the kernel.
         f = tvm.build(s, [A, C], "llvm")
@@ -462,7 +462,7 @@ if __name__ == "__main__":
     test_rank_zero_bound_checkers()
     test_llvm_bool()
     test_llvm_persist_parallel()
-    test_llvm_select()
+    test_llvm_condition()
     test_llvm_vadd_pipeline()
     test_llvm_add_pipeline()
     test_llvm_intrin()
...
@@ -25,8 +25,8 @@ def test_copy_pad():
     l = tvm.var('l')
     A = tvm.placeholder((m, l), name='A')
     B = tvm.compute((m + 2, l), lambda i, j:
-                    tvm.select(tvm.all(i >= 1, i < m + 1),
-                               A[i - 1, j], 1.0), name='B')
+                    tvm.if_then_else(tvm.all(i >= 1, i < m + 1),
+                                     A[i - 1, j], 1.0), name='B')
     s = tvm.create_schedule(B.op)
     s[B].pragma(B.op.axis[0], "memcpy")
     bounds = tvm.schedule.InferBound(s)
@@ -71,8 +71,8 @@ def test_copy_pad_split():
     m = 4 * 3
     A = tvm.placeholder((m, ), name="A")
     Apad = tvm.compute((m + 2,), lambda i:
-                       tvm.select(tvm.all(i >= 1, i <= m),
-                                  A[i - 1], 0.0), "Apad")
+                       tvm.if_then_else(tvm.all(i >= 1, i <= m),
+                                        A[i - 1], 0.0), "Apad")
     B = tvm.compute((m,), lambda i: Apad[i] + Apad[i + 1] + Apad[i + 2])
     s = tvm.create_schedule(B.op)
     xo, xi = s[B].split(B.op.axis[0], factor=4)
...
@@ -133,7 +133,7 @@ def test_vectorize():
     assert(x.var.name not in str(body.condition))
     assert(any(collect_visit(body.then_case, lambda x: isinstance(x, tvm.expr.Ramp))))
 
-def test_select():
+def test_condition():
     ib = tvm.ir_builder.create()
     m = tvm.var('m')
     n = tvm.var('n')
@@ -335,7 +335,7 @@ if __name__ == "__main__":
     test_multi_if()
     test_thread_axis()
     test_vectorize()
-    test_select()
+    test_condition()
     test_thread_axis2()
     test_everything_during_deduction()
     test_single_likely()
...
 import tvm
 
-def test_rewrite_select():
+def test_rewrite_Select():
     ib = tvm.ir_builder.create()
     A = ib.allocate("float32", 100, name="A", scope="global")
     i = tvm.var("i")
-    y = tvm.select(i > 1, A[i-1], 1.0)
+    y = tvm.expr.Select(i > 1, A[i-1], 1.0)
     yy = tvm.ir_pass.RewriteUnsafeSelect(tvm.make.Evaluate(y)).value
-    z = tvm.select(tvm.select(i > 1, A[i-1], 1.0) > 0.0, A[i], 0.1)
+    z = tvm.expr.Select(
+        tvm.expr.Select(i > 1, A[i-1], 1.0) > 0.0, A[i], 0.1)
     zz = tvm.ir_pass.RewriteUnsafeSelect(tvm.make.Evaluate(z)).value
-    a = tvm.select(i>10, y, z)
+    a = tvm.expr.Select(i>10, y, z)
     aa = tvm.ir_pass.RewriteUnsafeSelect(tvm.make.Evaluate(a)).value
     assert yy.name == "tvm_if_then_else"
     assert zz.name == "tvm_if_then_else"
@@ -19,4 +20,4 @@ def test_rewrite_select():
 if __name__ == "__main__":
-    test_rewrite_select()
+    test_rewrite_Select()
...
@@ -63,8 +63,8 @@ def test_schedule_scan():
 def test_inline_multi_reduce():
     def argmax_comp(x, y):
-        idx = tvm.select((x[1] >= y[1]), x[0], y[0])
-        val = tvm.select((x[1] >= y[1]), x[1], y[1])
+        idx = tvm.expr.Select((x[1] >= y[1]), x[0], y[0])
+        val = tvm.expr.Select((x[1] >= y[1]), x[1], y[1])
         return idx, val
     def argmax_init(idx_typ, val_typ):
         return tvm.const(-1, idx_typ), tvm.min_value(val_typ)
@@ -272,7 +272,7 @@ def test_schedule_cache_relayout4():
 def test_schedule_bound_condition():
     A = tvm.placeholder((64,), name='A', dtype="float32")
-    Apad = tvm.compute((66,), lambda i: tvm.select(
+    Apad = tvm.compute((66,), lambda i: tvm.if_then_else(
         tvm.all(i>0, i < 65), A[i-1], tvm.const(0., "float32")), name='Apad')
     Apad2 = tvm.compute((66,), lambda i: Apad[i]*2, name='Apad2')
     s = tvm.create_schedule(Apad2.op)
@@ -424,7 +424,7 @@ def test_loop_dep_reduce_cache_write():
     X = tvm.placeholder(shape=(10,), name="x")
     def f(n):
         rv = tvm.reduce_axis((0, n))
-        init = lambda dtype: tvm.select(n > 1, tvm.const(0, dtype), n.astype(dtype))
+        init = lambda dtype: tvm.expr.Select(n > 1, tvm.const(0, dtype), n.astype(dtype))
         sum = tvm.comm_reducer(lambda x, y: tvm.max(x + y, n.astype('float32')), init, name='sum')
         return sum(X[rv], axis=rv)
     Y = tvm.compute(X.shape, f, name="y")
...
@@ -38,7 +38,7 @@ inline Expr bilinear_sample_nchw(const Tensor& input, const Array<Expr>& indices
   auto yc = HalideIR::Internal::Cast::make(Int(32), tvm::ceil(in_y));
   auto y0 = HalideIR::Internal::Cast::make(Int(32), tvm::floor(in_y));
-  auto y1 = tvm::select((yc > max_y), max_y, yc);
+  auto y1 = tvm::if_then_else((yc > max_y), max_y, yc);
   auto y_lerp = in_y - yf;
 
   auto in_x = indices[3];
@@ -46,7 +46,7 @@ inline Expr bilinear_sample_nchw(const Tensor& input, const Array<Expr>& indices
   auto xc = HalideIR::Internal::Cast::make(Int(32), tvm::ceil(in_x));
   auto x0 = HalideIR::Internal::Cast::make(Int(32), tvm::floor(in_x));
-  auto x1 = tvm::select((xc > max_x), max_x, xc);
+  auto x1 = tvm::if_then_else((xc > max_x), max_x, xc);
   auto x_lerp = in_x - xf;
 
   auto A = input(indices[0], indices[1], y0, x0);
@@ -215,7 +215,7 @@ inline Tensor resize_bilinear_nhwc(const Tensor& input,
   auto yc = HalideIR::Internal::Cast::make(Int(32), tvm::ceil(in_y));
   auto y0 = HalideIR::Internal::Cast::make(Int(32), tvm::floor(in_y));
-  auto y1 = tvm::select((yc > other_y), other_y, yc);
+  auto y1 = tvm::if_then_else((yc > other_y), other_y, yc);
   auto y_lerp = in_y - yf;
 
   auto in_x = indices[2] * x_ratio;
@@ -223,7 +223,7 @@ inline Tensor resize_bilinear_nhwc(const Tensor& input,
   auto xc = HalideIR::Internal::Cast::make(Int(32), tvm::ceil(in_x));
   auto x0 = HalideIR::Internal::Cast::make(Int(32), tvm::floor(in_x));
-  auto x1 = tvm::select((xc > other_x), other_x, xc);
+  auto x1 = tvm::if_then_else((xc > other_x), other_x, xc);
   auto x_lerp = in_x - xf;
 
   auto A = input(indices[0], y0, x0, indices[3]);
...
@@ -75,7 +75,7 @@ inline tvm::Tensor leaky_relu(const tvm::Tensor& t,
       [&](const tvm::Array<tvm::Var>& i) {
         auto value = t(i);
         auto calpha = tvm::make_const(value.type(), alpha);
-        return tvm::select(value > 0, value, value * calpha);
+        return tvm::ir::Select::make(value > 0, value, value * calpha);
       },
       name,
       tag);
@@ -106,9 +106,11 @@ inline tvm::Tensor prelu(const tvm::Tensor &x,
   return tvm::compute(x->shape,
                       [&](const tvm::Array<tvm::Var> &indices) {
-                        return tvm::select(x(indices) > 0,
-                                           x(indices),
-                                           x(indices) * slope(indices[axis]));
+                        auto xval = x(indices);
+                        return tvm::ir::Select::make(
+                            xval > 0,
+                            xval,
+                            xval * slope(indices[axis]));
                       },
                       name,
                       tag);
@@ -193,7 +195,8 @@ inline tvm::Tensor pad(const tvm::Tensor& t,
       }
     }
     if (sel.size() != 0) {
-      return tvm::select(detail::Map(sel, tvm::ir::And::make), t(indices), pad_value);
+      return tvm::if_then_else(
+          detail::Map(sel, tvm::ir::And::make), t(indices), pad_value);
     }
     return t(indices);
   };
...
@@ -76,7 +76,8 @@ inline Tensor dilate(const Tensor& x,
     }
     if (not_zero.size() > 0) {
       auto all_not_zero = all(not_zero);
-      return tvm::select(all_not_zero, x(index_tuple), make_const(x->dtype, 0));
+      return tvm::if_then_else(
+          all_not_zero, x(index_tuple), make_const(x->dtype, 0));
     }
     return x(index_tuple);
   }, name, tag);
...
@@ -411,8 +411,8 @@ inline Tensor argmin(const Tensor& data,
                      bool atleast1d = false) {
   auto fcombine = [](Array<Var> lhs, Array<Var> rhs) {
     Array<Expr> result;
-    result.push_back(tvm::select(lhs[1] <= rhs[1], lhs[0], rhs[0]));  // idx
-    result.push_back(tvm::select(lhs[1] <= rhs[1], lhs[1], rhs[1]));  // val
+    result.push_back(tvm::ir::Select::make(lhs[1] <= rhs[1], lhs[0], rhs[0]));  // idx
+    result.push_back(tvm::ir::Select::make(lhs[1] <= rhs[1], lhs[1], rhs[1]));  // val
     return result;
   };
   auto fidentity = [](std::vector<Type> types) {
@@ -445,8 +445,8 @@ inline Tensor argmax(const Tensor& data,
                      bool atleast1d = false) {
   auto fcombine = [](Array<Var> lhs, Array<Var> rhs) {
     Array<Expr> result;
-    result.push_back(tvm::select(lhs[1] >= rhs[1], lhs[0], rhs[0]));  // idx
-    result.push_back(tvm::select(lhs[1] >= rhs[1], lhs[1], rhs[1]));  // val
+    result.push_back(tvm::ir::Select::make(lhs[1] >= rhs[1], lhs[0], rhs[0]));  // idx
+    result.push_back(tvm::ir::Select::make(lhs[1] >= rhs[1], lhs[1], rhs[1]));  // val
     return result;
   };
   auto fidentity = [](std::vector<Type> types) {
...
@@ -314,9 +314,9 @@ inline Tensor concatenate(const Array<Tensor>& inputs,
         idx.push_back(indices[i]);
       }
-      ret = tvm::select(ind >= 0,
-                        inputs[i + 1](idx),
-                        ret);
+      ret = tvm::if_then_else(ind >= 0,
+                              inputs[i + 1](idx),
+                              ret);
     }
     return ret;
   }, name, tag);
@@ -652,7 +652,7 @@ inline Tensor where(const Tensor& condition,
         << condition->shape.size() << " vs " << x->shape.size();
     out = compute(
       oshape, [&](const Array<Var>& indices) {
-        return tvm::select(condition(indices) != 0, x(indices), y(indices));
+        return tvm::ir::Select::make(condition(indices) != 0, x(indices), y(indices));
       }, name, tag);
   } else {
     CHECK_EQ(topi::GetConstInt(condition->shape[0]), topi::GetConstInt(x->shape[0]))
@@ -661,8 +661,8 @@ inline Tensor where(const Tensor& condition,
     out = compute(
       oshape, [&](const Array<Var>& indices) {
         Array<Expr> condition_idx{indices[0]};
-        return tvm::select(condition(condition_idx) != 0,
-                           x(indices), y(indices));
+        return tvm::ir::Select::make(condition(condition_idx) != 0,
+                                     x(indices), y(indices));
       }, name, tag);
   }
   return out;
...
@@ -72,7 +72,7 @@ def conv2d_transpose_nchw_cuda(cfg, Input, Filter, strides, padding, out_dtype):
             index_tuple.append(indices[i])
         if not_zero:
             not_zero = tvm.all(*not_zero)
-            return tvm.select(not_zero, data(*index_tuple), tvm.const(0.0, data.dtype))
+            return tvm.if_then_else(not_zero, data(*index_tuple), tvm.const(0.0, data.dtype))
         return data(*index_tuple)
 
     # convolution stage
...
@@ -315,11 +315,11 @@ def sort_ir_out(data, index, new_index, loc, output, axis_mul_before, axis_mul_a
                     start = 0
                 with ib.else_scope():
                     start = sizes[tid-1]
-                p_out[base_idx + k * axis_mul_after] = tvm.select(
+                p_out[base_idx + k * axis_mul_after] = tvm.if_then_else(
                     k < p_index[tid], index_new[k+start], k)
     with ib.else_scope():
         with ib.if_scope(tid < data.shape[axis]):
-            p_out[tid] = tvm.select(tid < p_index[0], index_new[tid], tid)
+            p_out[tid] = tvm.if_then_else(tid < p_index[0], index_new[tid], tid)
     body = ib.get()
     return body
@@ -470,7 +470,7 @@ def nms_ir(data, sort_result, valid_count, out, nms_threshold, force_suppress, n
             (out_tensor[box_a_idx + 3] - out_tensor[box_a_idx + 1]) + \
             (out_tensor[box_b_idx + 2] - out_tensor[box_b_idx]) * \
             (out_tensor[box_b_idx + 3] - out_tensor[box_b_idx + 1]) - i
-        return tvm.select(u <= 0.0, 0.0, i / u)
+        return tvm.expr.Select(u <= 0.0, 0.0, i / u)
 
     max_threads = int(math.sqrt(
         tvm.target.current_target(allow_none=False).max_num_threads))
@@ -506,7 +506,7 @@ def nms_ir(data, sort_result, valid_count, out, nms_threshold, force_suppress, n
                 tvm.all(nms_threshold_node > 0, nms_threshold_node < 1,
                         p_valid_count[0] > 0)):
             # Reorder output
-            nkeep = tvm.select(
+            nkeep = tvm.if_then_else(
                 tvm.all(nms_topk_node > 0, nms_topk < p_valid_count[n]),
                 nms_topk, p_valid_count[n])
             with ib.if_scope(i < nkeep):
...
@@ -77,13 +77,14 @@ def multibox_prior_ir(data, out, sizes, ratios, steps, offsets):
             center_w = (j + offset_w) * steps_w
             for k in range(num_sizes + num_ratios - 1):
-                w = tvm.select(k < num_sizes,
-                               size_ratio_concat[
-                                   k] * in_height / in_width / 2.0,
-                               size_ratio_concat[0] * in_height / in_width *
-                               math.sqrt(size_ratio_concat[k + 1]) / 2.0)
-                h = tvm.select(k < num_sizes, size_ratio_concat[k] / 2.0,
-                               size_ratio_concat[0] / math.sqrt(size_ratio_concat[k + 1]) / 2.0)
+                w = tvm.if_then_else(k < num_sizes,
+                                     size_ratio_concat[
+                                         k] * in_height / in_width / 2.0,
+                                     size_ratio_concat[0] * in_height / in_width *
+                                     math.sqrt(size_ratio_concat[k + 1]) / 2.0)
+                h = tvm.if_then_else(
+                    k < num_sizes, size_ratio_concat[k] / 2.0,
+                    size_ratio_concat[0] / math.sqrt(size_ratio_concat[k + 1]) / 2.0)
                 count = (i * in_width * (num_sizes + num_ratios - 1) +
                          j * (num_sizes + num_ratios - 1) + k) * 4
                 p_out[count] = center_w - w
@@ -278,10 +279,10 @@ def transform_loc_ir(loc_pred, anchor, temp_flag, temp_id, temp_score_in, \
         oy = py * vy * ah + ay
         ow = tvm.exp(pw * vw) * aw / 2.0
         oh = tvm.exp(ph * vh) * ah / 2.0
-        return tvm.select(clip, tvm.make.Max(0.0, tvm.make.Min(1.0, ox - ow)), ox - ow), \
-            tvm.select(clip, tvm.make.Max(0.0, tvm.make.Min(1.0, oy - oh)), oy - oh), \
-            tvm.select(clip, tvm.make.Max(0.0, tvm.make.Min(1.0, ox + ow)), ox + ow), \
-            tvm.select(clip, tvm.make.Max(0.0, tvm.make.Min(1.0, oy + oh)), oy + oh)
+        return tvm.if_then_else(clip, tvm.make.Max(0.0, tvm.make.Min(1.0, ox - ow)), ox - ow), \
+            tvm.if_then_else(clip, tvm.make.Max(0.0, tvm.make.Min(1.0, oy - oh)), oy - oh), \
+            tvm.if_then_else(clip, tvm.make.Max(0.0, tvm.make.Min(1.0, ox + ow)), ox + ow), \
+            tvm.if_then_else(clip, tvm.make.Max(0.0, tvm.make.Min(1.0, oy + oh)), oy + oh)
 
     max_threads = int(
         tvm.target.current_target(allow_none=False).max_num_threads)
...
@@ -296,9 +296,10 @@ def _decl_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dt
     # pack input tile
     input_tile = tvm.compute((CI, P_round // bnb, alpha, alpha, bnb), lambda ci, b, eps, nu, bb: \
-        tvm.select(b * bnb + bb < P,
-                   data_pad[(b*bnb+bb) // (nH*nW)][ci][(b*bnb+bb) // nW % nH * m + eps]
-                   [(b*bnb+bb) % nW * m + nu], tvm.const(0, data_pad.dtype)), name='d')
+        tvm.if_then_else(
+            b * bnb + bb < P,
+            data_pad[(b*bnb+bb) // (nH*nW)][ci][(b*bnb+bb) // nW % nH * m + eps]
+            [(b*bnb+bb) % nW * m + nu], tvm.const(0, data_pad.dtype)), name='d')
 
     # transform kernel
     if pre_computed:
...
@@ -44,7 +44,7 @@ def dilate(data, strides, name="DilatedInput"):
             index_tuple.append(indices[i])
         if not_zero:
             not_zero = tvm.all(*not_zero)
-            return tvm.select(not_zero, data(*index_tuple), tvm.const(0.0, data.dtype))
+            return tvm.if_then_else(not_zero, data(*index_tuple), tvm.const(0.0, data.dtype))
         return data(*index_tuple)
 
     return tvm.compute(out_shape, _dilate, name=name)
@@ -41,7 +41,7 @@ def leaky_relu(x, alpha):
     def _compute(*indices):
         value = x(*indices)
         calpha = tvm.const(alpha, value.dtype)
-        return tvm.select(value > 0, value, value * calpha)
+        return tvm.expr.Select(value > 0, value, value * calpha)
     return tvm.compute(x.shape, _compute)
 
 @tvm.tag_scope(tag=tag.BROADCAST)
@@ -74,5 +74,6 @@ def prelu(x, slope, axis=1):
     assert get_const_int(slope.shape[0]) == get_const_int(x.shape[axis])
 
     def _compute_channelwise(*indices):
-        return tvm.select(x(*indices) > 0, x(*indices), x(*indices) * slope(indices[axis]))
+        xval = x(*indices)
+        return tvm.expr.Select(xval > 0, xval, xval * slope(indices[axis]))
 
     return tvm.compute(x.shape, _compute_channelwise)
@@ -55,6 +55,6 @@ def pad(data, pad_before, pad_after=None, pad_value=0.0, name="PadInput"):
             not_zero.append(indices[i] < data.shape[i] + pad_before[i])
         if not_zero:
             not_zero = tvm.all(*not_zero)
-            return tvm.select(not_zero, data(*index_tuple), pad_value)
+            return tvm.if_then_else(not_zero, data(*index_tuple), pad_value)
         return data(*index_tuple)
     return tvm.compute(out_shape, _pad, name=name)
@@ -55,8 +55,8 @@ def infer_stride(data, kernel, out):
     _, _, IH, IW = data.shape
     _, _, KH, KW = kernel.shape
     _, _, OH, OW = out.shape
-    hstride = (IH - KH) // tvm.make.Max(OH - 1, 1) + tvm.select(OH == 1, 1, 0)
-    wstride = (IW - KW) // tvm.make.Max(OW - 1, 1) + tvm.select(OW == 1, 1, 0)
+    hstride = (IH - KH) // tvm.make.Max(OH - 1, 1) + tvm.expr.Select(OH == 1, 1, 0)
+    wstride = (IW - KW) // tvm.make.Max(OW - 1, 1) + tvm.expr.Select(OW == 1, 1, 0)
     return get_const_int(hstride), get_const_int(wstride)
...
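The topi Python changes follow the rule stated in the new docstrings: keep tvm.expr.Select where both arms are always safe to evaluate (elementwise ops such as leaky_relu, which then still vectorize), and switch to tvm.if_then_else where one arm may read out of bounds (dilate, pad). A condensed sketch of the two styles, with illustrative names:

    import tvm

    n = tvm.var("n")
    X = tvm.placeholder((n,), name="X")
    # Both arms are in-bounds arithmetic: Select is fine and vectorizes.
    relu = tvm.compute((n,), lambda i: tvm.expr.Select(X[i] > 0, X[i], X[i] * 0.1))
    # The true arm reads X[i - 1], out of bounds at i == 0: if_then_else
    # guarantees the untaken arm is never evaluated.
    pad = tvm.compute((n + 1,), lambda i: tvm.if_then_else(i >= 1, X[i - 1], 0.0))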
@@ -249,9 +249,9 @@ def const_matrix(matrix, name="const_matrix"):
         now = tvm.const(0.0, dtype)
         for ii in range(row):
             for jj in range(col):
-                now = tvm.select(tvm.all(i % row == ii, j % col == jj),
-                                 tvm.const(matrix[ii][jj], dtype),
-                                 now)
+                now = tvm.expr.Select(tvm.all(i % row == ii, j % col == jj),
+                                      tvm.const(matrix[ii][jj], dtype),
+                                      now)
         return now
     return tvm.compute(matrix.shape, select_array, name=name)
@@ -47,7 +47,7 @@ def nms_ir(data, sort_result, valid_count, out, nms_threshold, force_suppress, n
             (out_tensor[box_a_idx + 3] - out_tensor[box_a_idx + 1]) + \
             (out_tensor[box_b_idx + 2] - out_tensor[box_b_idx]) * \
             (out_tensor[box_b_idx + 3] - out_tensor[box_b_idx + 1]) - i
-        return tvm.select(u <= 0.0, 0.0, i / u)
+        return tvm.expr.Select(u <= 0.0, 0.0, i / u)
 
     ib = tvm.ir_builder.create()
     p_data = ib.buffer_ptr(data)
@@ -64,8 +64,9 @@ def nms_ir(data, sort_result, valid_count, out, nms_threshold, force_suppress, n
     with ib.if_scope(tvm.all(nms_threshold_node > 0, nms_threshold_node < 1,
                              p_valid_count[0] > 0)):
         # Reorder output
-        nkeep = tvm.select(tvm.all(nms_topk_node > 0, nms_topk < p_valid_count[n]),
-                           nms_topk, p_valid_count[n])
+        nkeep = tvm.if_then_else(
+            tvm.all(nms_topk_node > 0, nms_topk < p_valid_count[n]),
+            nms_topk, p_valid_count[n])
         with ib.for_range(0, nkeep, name="l") as l:
             with ib.for_range(0, 6, name="m") as m:
                 p_out[(n * num_anchors * 6
...
@@ -47,7 +47,7 @@ def roi_align_nchw(data, rois, pooled_size, spatial_scale, sample_ratio=-1):
         y = tvm.max(y, 0.0)
         x = tvm.max(x, 0.0)
         val = bilinear_sample_nchw(data, (i, c, y, x), height - 1, width - 1)
-        return tvm.select(outside, 0.0, val)
+        return tvm.if_then_else(outside, 0.0, val)
 
     def _sample(i, c, ph, pw):
         roi = rois[i]
...
@@ -55,12 +55,13 @@ def multibox_prior_ir(data, out, sizes, ratios, steps, offsets):
         with ib.for_range(0, in_width, name="j") as j:
             center_w = (j + offset_w) * steps_w
             for k in range(num_sizes + num_ratios - 1):
-                w = tvm.select(k < num_sizes,
-                               size_ratio_concat[k] * in_height / in_width / 2.0,
-                               size_ratio_concat[0] * in_height / in_width *
-                               math.sqrt(size_ratio_concat[k + 1]) / 2.0)
-                h = tvm.select(k < num_sizes, size_ratio_concat[k] / 2.0,
-                               size_ratio_concat[0] / math.sqrt(size_ratio_concat[k + 1]) / 2.0)
+                w = tvm.if_then_else(k < num_sizes,
+                                     size_ratio_concat[k] * in_height / in_width / 2.0,
+                                     size_ratio_concat[0] * in_height / in_width *
+                                     math.sqrt(size_ratio_concat[k + 1]) / 2.0)
+                h = tvm.if_then_else(
+                    k < num_sizes, size_ratio_concat[k] / 2.0,
+                    size_ratio_concat[0] / math.sqrt(size_ratio_concat[k + 1]) / 2.0)
                 count = (i * in_width * (num_sizes + num_ratios - 1) +
                          j * (num_sizes + num_ratios - 1) + k) * 4
                 p_out[count] = center_w - w
@@ -164,10 +165,10 @@ def transform_loc_ir(cls_prob, loc_pred, anchor, valid_count, out, clip, thresho
         oy = py * vy * ah + ay
         ow = tvm.exp(pw * vw) * aw / 2.0
         oh = tvm.exp(ph * vh) * ah / 2.0
-        return tvm.select(clip, tvm.max(0, tvm.min(1, ox - ow)), ox - ow), \
-            tvm.select(clip, tvm.max(0, tvm.min(1, oy - oh)), oy - oh), \
-            tvm.select(clip, tvm.max(0, tvm.min(1, ox + ow)), ox + ow), \
-            tvm.select(clip, tvm.max(0, tvm.min(1, oy + oh)), oy + oh)
+        return tvm.if_then_else(clip, tvm.max(0, tvm.min(1, ox - ow)), ox - ow), \
+            tvm.if_then_else(clip, tvm.max(0, tvm.min(1, oy - oh)), oy - oh), \
+            tvm.if_then_else(clip, tvm.max(0, tvm.min(1, ox + ow)), ox + ow), \
+            tvm.if_then_else(clip, tvm.max(0, tvm.min(1, oy + oh)), oy + oh)
 
     batch_size = cls_prob.shape[0]
     num_classes = cls_prob.shape[1]
@@ -190,7 +191,7 @@ def transform_loc_ir(cls_prob, loc_pred, anchor, valid_count, out, clip, thresho
         with ib.for_range(0, num_classes, name="j") as j:
             with ib.if_scope(j > 0):
                 temp = p_cls_prob[n * num_anchors * num_classes + j * num_anchors + i]
-                cls_id[0] = tvm.select(temp > score[0], j, cls_id[0])
+                cls_id[0] = tvm.if_then_else(temp > score[0], j, cls_id[0])
                 score[0] = tvm.max(temp, score[0])
         with ib.if_scope(tvm.all(cls_id[0] > 0, score[0] < threshold)):
             cls_id[0] = 0
...
@@ -65,7 +65,7 @@ def verify_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, p
         else:
             func = tvm.build(s, [A, W, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation))
         func(a, w, c)
-        tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5)
+        tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-4)
 
     for device in get_all_backend():
         with autotvm.tophub.context(device):  # load tophub pre-tuned parameters
...
@@ -45,8 +45,8 @@ print(tvm.lower(s, [A0, A1, B0, B1], simple_mode=True))
 # x and y are the operands of reduction, both of them is a tuple of index
 # and value.
 def fcombine(x, y):
-    lhs = tvm.select((x[1] >= y[1]), x[0], y[0])
-    rhs = tvm.select((x[1] >= y[1]), x[1], y[1])
+    lhs = tvm.expr.Select((x[1] >= y[1]), x[0], y[0])
+    rhs = tvm.expr.Select((x[1] >= y[1]), x[1], y[1])
     return lhs, rhs
 
 # our identity element also need to be a tuple, so `fidentity` accepts
...
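Combined with an identity element, the tuple-valued fcombine above forms a complete argmax reducer; a compact sketch following the same pattern as the updated tests (names are illustrative):

    import tvm

    def fcombine(x, y):
        lhs = tvm.expr.Select((x[1] >= y[1]), x[0], y[0])
        rhs = tvm.expr.Select((x[1] >= y[1]), x[1], y[1])
        return lhs, rhs

    def fidentity(idx_typ, val_typ):
        return tvm.const(-1, idx_typ), tvm.min_value(val_typ)

    argmax = tvm.comm_reducer(fcombine, fidentity, name="argmax")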
@@ -43,7 +43,7 @@ out_size = (in_size - kernel + 2*pad) // stride + 1
 # Pad input
 Apad = tvm.compute(
     (in_size + 2*pad, in_size + 2*pad, in_channel, batch),
-    lambda yy, xx, cc, nn: tvm.select(
+    lambda yy, xx, cc, nn: tvm.if_then_else(
         tvm.all(yy >= pad, yy - pad < in_size,
                 xx >= pad, xx - pad < in_size),
         A[yy - pad, xx - pad, cc, nn], tvm.const(0., "float32")),
...