Unverified commit 9d20fa1b by Tianqi Chen, committed by GitHub

[PASS][TENSOR] Use correct select semantics (#2394)

parent 98e761f8
......@@ -11,6 +11,7 @@ tvm.intrin
tvm.call_extern
tvm.call_llvm_intrin
tvm.register_intrin_rule
tvm.if_then_else
tvm.exp
tvm.log
tvm.floor
......@@ -26,6 +27,7 @@ tvm.intrin
.. autofunction:: tvm.call_extern
.. autofunction:: tvm.call_llvm_intrin
.. autofunction:: tvm.register_intrin_rule
.. autofunction:: tvm.if_then_else
.. autofunction:: tvm.exp
.. autofunction:: tvm.log
.. autofunction:: tvm.floor
......
......@@ -15,7 +15,6 @@ The user facing API for computation declaration.
tvm.extern
tvm.decl_buffer
tvm.reduce_axis
tvm.select
tvm.thread_axis
tvm.comm_reducer
tvm.sum
......@@ -34,7 +33,6 @@ The user facing API for computation declaration.
.. autofunction:: tvm.extern
.. autofunction:: tvm.decl_buffer
.. autofunction:: tvm.reduce_axis
.. autofunction:: tvm.select
.. autofunction:: tvm.thread_axis
.. autofunction:: tvm.comm_reducer
.. autofunction:: tvm.sum
......
......@@ -392,7 +392,7 @@ TVM_DLL Expr operator^(Expr a, Expr b);
*/
TVM_DLL Expr operator~(Expr a);
/*!
* \brief select result by condition
* \brief Conditional expression.
*
* \param cond The condition
* \param true_value The value taken when the condition is true.
......@@ -401,7 +401,7 @@ TVM_DLL Expr operator~(Expr a);
* \note This function does eager constant folding for
* index types (int32, int64) when possible.
*/
TVM_DLL Expr select(Expr cond, Expr true_value, Expr false_value);
TVM_DLL Expr if_then_else(Expr cond, Expr true_value, Expr false_value);
/*!
* \brief Mark condition as likely.
* \param cond The condition
......
......@@ -669,28 +669,6 @@ def reduce_axis(dom, name="rv"):
return _IterVar(dom, name, 2)
def select(cond, t, f):
"""Construct a select branch.
Parameters
----------
cond : Expr
The condition
t : Expr
The result expression if cond is true.
f : Expr
The result expression if cond is false.
Returns
-------
node : Node
The tvm.expr.Select node
"""
return _expr.Select(convert(cond), convert(t), convert(f))
def comm_reducer(fcombine, fidentity, name="reduce"):
"""Create a commutative reducer for reduction.
......
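For downstream code, the migration is mechanical; a minimal sketch against the 0.x Python API (the names below are illustrative, not part of the commit):

    import tvm

    i = tvm.var("i")
    cond, t, f = i > 0, i, -i
    # Before this commit: y = tvm.select(cond, t, f)
    # Now the semantics are chosen explicitly:
    y_eager = tvm.expr.Select(cond, t, f)   # may evaluate both t and f
    y_lazy = tvm.if_then_else(cond, t, f)   # evaluates only the taken branch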
......@@ -624,6 +624,13 @@ class Not(LogicalExpr):
class Select(Expr):
"""Select node.
Note
----
Select may compute both true_value and false_value.
Use :any:`tvm.if_then_else` instead if you want a
conditional expression that evaluates only the
branch that is taken.
Parameters
----------
condition : Expr
......@@ -634,6 +641,7 @@ class Select(Expr):
false_value : Expr
The value to take when condition is false.
"""
def __init__(self, condition, true_value, false_value):
self.__init_handle_by_constructor__(
......
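To make the note concrete, a small sketch (illustrative only, assuming the 0.x Python API): when both values are always safe to compute, the eager Select is fine, and it remains vectorizable where if_then_else would not be under divergent lanes:

    import tvm

    n = 16
    A = tvm.placeholder((n,), name="A")
    B = tvm.placeholder((n,), name="B")
    # Both branches are in-bounds loads, so computing both eagerly is
    # harmless and the vectorizer can keep the Select per-lane.
    C = tvm.compute((n,), lambda i: tvm.expr.Select(A(i) > 0, A(i), B(i)),
                    name="C")
    s = tvm.create_schedule(C.op)
    s[C].vectorize(C.op.axis[0])
    print(tvm.lower(s, [A, B, C], simple_mode=True))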
......@@ -393,6 +393,42 @@ def fmod(x, y):
"""
return call_pure_intrin(x.dtype, "fmod", x, y)
def if_then_else(cond, t, f):
"""Conditional selection expression.
Parameters
----------
cond : Expr
The condition
t : Expr
The result expression if cond is true.
f : Expr
The result expression if cond is false.
Returns
-------
result : Node
The result of conditional expression.
Note
----
Unlike Select, if_then_else does not evaluate
the branch that is not taken.
You can use it to guard against out-of-bound accesses.
Unlike Select, if_then_else cannot be vectorized
when lanes of a vector take different branches.
"""
t = convert(t)
f = convert(f)
cond = convert(cond)
if cond.dtype != "bool":
raise TypeError("The condition's data type has to be bool")
return call_pure_intrin(t.dtype, "tvm_if_then_else", cond, t, f)
# Intrinsic rule related code
def register_intrin_rule(target, intrin, f=None, override=False):
"""Register an intrinsic function generation rule.
......
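A short usage sketch of the guard pattern the note describes (illustrative, in the spirit of the tests further down):

    import tvm

    n = tvm.var("n")
    A = tvm.placeholder((n,), name="A")
    # Only the taken branch is evaluated, so A[i - 1] is never read at i == 0.
    B = tvm.compute((n,), lambda i: tvm.if_then_else(i >= 1, A[i - 1], 0.0),
                    name="B")
    s = tvm.create_schedule(B.op)
    print(tvm.lower(s, [A, B], simple_mode=True))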
......@@ -268,8 +268,9 @@ inline IntSet CombineInterval<Mul>(Interval a, Interval b) {
} else if (is_negative_const(b.min)) {
return IntervalSet::make(e2, e1);
} else if (a.is_bounded()) {
using ir::Select;
Expr cmp = b.min >= make_zero(b.min.type().element_of());
return IntervalSet::make(select(cmp, e1, e2), select(cmp, e2, e1));
return IntervalSet::make(Select::make(cmp, e1, e2), Select::make(cmp, e2, e1));
}
}
LOG(WARNING) << "Return Everything in CombineInterval Mul";
......@@ -294,8 +295,9 @@ inline IntSet CombineInterval<Div>(Interval a, Interval b) {
} else if (is_negative_const(b.min)) {
return IntervalSet::make(e2, e1);
} else if (a.is_bounded()) {
using ir::Select;
Expr cmp = b.min >= make_zero(b.min.type().element_of());
return IntervalSet::make(select(cmp, e1, e2), select(cmp, e2, e1));
return IntervalSet::make(Select::make(cmp, e1, e2), Select::make(cmp, e2, e1));
}
}
LOG(WARNING) << "Return Everything in CombineInterval Div";
......
......@@ -240,10 +240,11 @@ Expr max(Expr a, Expr b) {
return ir::Max::make(a, b);
}
Expr select(Expr cond, Expr true_value, Expr false_value) {
Expr if_then_else(Expr cond, Expr true_value, Expr false_value) {
using ir::IntImm;
using ir::UIntImm;
CHECK(cond.type().is_bool());
CHECK(cond.type() == Bool(1))
<< "if_then_else only accept a single condition";
BinaryOpMatchTypes(true_value, false_value);
if (const UIntImm* op = cond.as<UIntImm>()) {
if (op->value != 0) {
......@@ -258,7 +259,11 @@ Expr select(Expr cond, Expr true_value, Expr false_value) {
return false_value;
}
}
return ir::Select::make(cond, true_value, false_value);
return ir::Call::make(
true_value.type(),
ir::intrinsic::tvm_if_then_else,
{cond, true_value, false_value},
ir::Call::PureIntrinsic);
}
Expr likely(Expr cond) {
......@@ -402,7 +407,12 @@ Expr pow(Expr x, Expr y) {
Expr abs(Expr x) {
if (x.type().is_int()) {
return select(x >= make_zero(x.type()), x, -x);
using ir::IntImm;
const IntImm* px = x.as<IntImm>();
if (px) {
return ir::IntImm::make(x.type(), std::abs(px->value));
}
return ir::Select::make(x >= make_zero(x.type()), x, -x);
} else if (x.type().is_float()) {
return ir::Call::make(x.type(), "fabs", {x}, ir::Call::PureIntrinsic);
} else if (x.type().is_uint()) {
......
......@@ -35,6 +35,26 @@ class CopyIntrinInjector : public IRMutator {
}
private:
bool MatchCondition(Expr expr,
Expr* cond,
Expr* true_value,
Expr* false_value) {
if (const auto* op = expr.as<Select>()) {
*cond = op->condition;
*true_value = op->true_value;
*false_value = op->false_value;
return true;
} else if (const auto* op = expr.as<Call>()) {
if (op->name == intrinsic::tvm_if_then_else) {
*cond = op->args[0];
*true_value = op->args[1];
*false_value = op->args[2];
return true;
}
}
return false;
}
bool MatchCopyPattern(Stmt stmt, Stmt *out) {
Stmt body = stmt;
bool is_single_point_copy = false;
......@@ -48,16 +68,20 @@ class CopyIntrinInjector : public IRMutator {
}
const Store* store = body.as<Store>();
if (store == nullptr) return false;
const Select* select = store->value.as<Select>();
Expr sel_cond, sel_true_value, sel_false_value;
bool has_cond = MatchCondition(store->value,
&sel_cond,
&sel_true_value,
&sel_false_value);
const Cast* cast = store->value.as<Cast>();
const Load* load = store->value.as<Load>();
if (0 == loops.size()) {
is_single_point_copy = true;
CHECK(select == nullptr);
CHECK(!has_cond);
}
// For now, only matching on the true branch is supported.
if (select != nullptr) {
load = select->true_value.as<Load>();
if (has_cond) {
load = sel_true_value.as<Load>();
}
// cast can be part of the pattern
if (cast != nullptr) {
......@@ -88,10 +112,10 @@ class CopyIntrinInjector : public IRMutator {
Array<Expr> pad_before, pad_after;
Expr pad_value;
Expr src_elem_offset = load_strides[loop_var_size];
if (select != nullptr) {
if (has_cond) {
Array<Expr> clip_bound =
arith::DetectClipBound(select->condition, loop_vars);
pad_value = select->false_value;
arith::DetectClipBound(sel_cond, loop_vars);
pad_value = sel_false_value;
if (clip_bound.size() == 0) return false;
CHECK_EQ(src_shape.size(), loop_vars.size());
CHECK_EQ(clip_bound.size(), loop_vars.size() * 2);
......
......@@ -8,7 +8,7 @@ def test_reduce_prims():
n = tvm.var('n')
m = tvm.var('m')
A = tvm.placeholder((n, m), name='A')
R = tvm.compute((n, ), lambda i: tvm.select((i > 1), 1, 0), name='R')
R = tvm.compute((n, ), lambda i: tvm.expr.Select((i > 1), 1, 0), name='R')
k = tvm.reduce_axis((0, m))
B = tvm.compute((n,), lambda i: reducer(A[i, k], axis=k, where=(R[i]==1)), name='B')
# schedule
......
......@@ -287,12 +287,12 @@ def test_multiple_func():
def test_llvm_select():
def test_llvm_condition():
def check_llvm(n, offset):
if not tvm.module.enabled("llvm"):
return
A = tvm.placeholder((n, ), name='A')
C = tvm.compute((n,), lambda i: tvm.select(i >= offset, A[i], 0.0), name='C')
C = tvm.compute((n,), lambda i: tvm.if_then_else(i >= offset, A[i], 0.0), name='C')
s = tvm.create_schedule(C.op)
# build and invoke the kernel.
f = tvm.build(s, [A, C], "llvm")
......@@ -462,7 +462,7 @@ if __name__ == "__main__":
test_rank_zero_bound_checkers()
test_llvm_bool()
test_llvm_persist_parallel()
test_llvm_select()
test_llvm_condition()
test_llvm_vadd_pipeline()
test_llvm_add_pipeline()
test_llvm_intrin()
......
......@@ -25,7 +25,7 @@ def test_copy_pad():
l = tvm.var('l')
A = tvm.placeholder((m, l), name='A')
B = tvm.compute((m + 2, l), lambda i, j:
tvm.select(tvm.all(i >= 1, i < m + 1),
tvm.if_then_else(tvm.all(i >= 1, i < m + 1),
A[i - 1, j], 1.0), name='B')
s = tvm.create_schedule(B.op)
s[B].pragma(B.op.axis[0], "memcpy")
......@@ -71,7 +71,7 @@ def test_copy_pad_split():
m = 4 * 3
A = tvm.placeholder((m, ), name="A")
Apad = tvm.compute((m + 2,), lambda i:
tvm.select(tvm.all(i >= 1, i <= m),
tvm.if_then_else(tvm.all(i >= 1, i <= m),
A[i - 1], 0.0), "Apad")
B = tvm.compute((m,), lambda i: Apad[i] + Apad[i + 1] + Apad[i + 2])
s = tvm.create_schedule(B.op)
......
......@@ -133,7 +133,7 @@ def test_vectorize():
assert(x.var.name not in str(body.condition))
assert(any(collect_visit(body.then_case, lambda x: isinstance(x, tvm.expr.Ramp))))
def test_select():
def test_condition():
ib = tvm.ir_builder.create()
m = tvm.var('m')
n = tvm.var('n')
......@@ -335,7 +335,7 @@ if __name__ == "__main__":
test_multi_if()
test_thread_axis()
test_vectorize()
test_select()
test_condition()
test_thread_axis2()
test_everything_during_deduction()
test_single_likely()
......
import tvm
def test_rewrite_select():
def test_rewrite_Select():
ib = tvm.ir_builder.create()
A = ib.allocate("float32", 100, name="A", scope="global")
i = tvm.var("i")
y = tvm.select(i > 1, A[i-1], 1.0)
y = tvm.expr.Select(i > 1, A[i-1], 1.0)
yy = tvm.ir_pass.RewriteUnsafeSelect(tvm.make.Evaluate(y)).value
z = tvm.select(tvm.select(i > 1, A[i-1], 1.0) > 0.0, A[i], 0.1)
z = tvm.expr.Select(
tvm.expr.Select(i > 1, A[i-1], 1.0) > 0.0, A[i], 0.1)
zz = tvm.ir_pass.RewriteUnsafeSelect(tvm.make.Evaluate(z)).value
a = tvm.select(i>10, y, z)
a = tvm.expr.Select(i>10, y, z)
aa = tvm.ir_pass.RewriteUnsafeSelect(tvm.make.Evaluate(a)).value
assert yy.name == "tvm_if_then_else"
assert zz.name == "tvm_if_then_else"
......@@ -19,4 +20,4 @@ def test_rewrite_select():
if __name__ == "__main__":
test_rewrite_select()
test_rewrite_Select()
......@@ -63,8 +63,8 @@ def test_schedule_scan():
def test_inline_multi_reduce():
def argmax_comp(x, y):
idx = tvm.select((x[1] >= y[1]), x[0], y[0])
val = tvm.select((x[1] >= y[1]), x[1], y[1])
idx = tvm.expr.Select((x[1] >= y[1]), x[0], y[0])
val = tvm.expr.Select((x[1] >= y[1]), x[1], y[1])
return idx, val
def argmax_init(idx_typ, val_typ):
return tvm.const(-1, idx_typ), tvm.min_value(val_typ)
......@@ -272,7 +272,7 @@ def test_schedule_cache_relayout4():
def test_schedule_bound_condition():
A = tvm.placeholder((64,), name='A', dtype="float32")
Apad = tvm.compute((66,), lambda i: tvm.select(
Apad = tvm.compute((66,), lambda i: tvm.if_then_else(
tvm.all(i>0, i < 65), A[i-1], tvm.const(0., "float32")), name='Apad')
Apad2 = tvm.compute((66,), lambda i: Apad[i]*2, name='Apad2')
s = tvm.create_schedule(Apad2.op)
......@@ -424,7 +424,7 @@ def test_loop_dep_reduce_cache_write():
X = tvm.placeholder(shape=(10,), name="x")
def f(n):
rv = tvm.reduce_axis((0, n))
init = lambda dtype: tvm.select(n > 1, tvm.const(0, dtype), n.astype(dtype))
init = lambda dtype: tvm.expr.Select(n > 1, tvm.const(0, dtype), n.astype(dtype))
sum = tvm.comm_reducer(lambda x, y: tvm.max(x + y, n.astype('float32')), init, name='sum')
return sum(X[rv], axis=rv)
Y = tvm.compute(X.shape, f, name="y")
......
......@@ -38,7 +38,7 @@ inline Expr bilinear_sample_nchw(const Tensor& input, const Array<Expr>& indices
auto yc = HalideIR::Internal::Cast::make(Int(32), tvm::ceil(in_y));
auto y0 = HalideIR::Internal::Cast::make(Int(32), tvm::floor(in_y));
auto y1 = tvm::select((yc > max_y), max_y, yc);
auto y1 = tvm::if_then_else((yc > max_y), max_y, yc);
auto y_lerp = in_y - yf;
auto in_x = indices[3];
......@@ -46,7 +46,7 @@ inline Expr bilinear_sample_nchw(const Tensor& input, const Array<Expr>& indices
auto xc = HalideIR::Internal::Cast::make(Int(32), tvm::ceil(in_x));
auto x0 = HalideIR::Internal::Cast::make(Int(32), tvm::floor(in_x));
auto x1 = tvm::select((xc > max_x), max_x, xc);
auto x1 = tvm::if_then_else((xc > max_x), max_x, xc);
auto x_lerp = in_x - xf;
auto A = input(indices[0], indices[1], y0, x0);
......@@ -215,7 +215,7 @@ inline Tensor resize_bilinear_nhwc(const Tensor& input,
auto yc = HalideIR::Internal::Cast::make(Int(32), tvm::ceil(in_y));
auto y0 = HalideIR::Internal::Cast::make(Int(32), tvm::floor(in_y));
auto y1 = tvm::select((yc > other_y), other_y, yc);
auto y1 = tvm::if_then_else((yc > other_y), other_y, yc);
auto y_lerp = in_y - yf;
auto in_x = indices[2] * x_ratio;
......@@ -223,7 +223,7 @@ inline Tensor resize_bilinear_nhwc(const Tensor& input,
auto xc = HalideIR::Internal::Cast::make(Int(32), tvm::ceil(in_x));
auto x0 = HalideIR::Internal::Cast::make(Int(32), tvm::floor(in_x));
auto x1 = tvm::select((xc > other_x), other_x, xc);
auto x1 = tvm::if_then_else((xc > other_x), other_x, xc);
auto x_lerp = in_x - xf;
auto A = input(indices[0], y0, x0, indices[3]);
......
......@@ -75,7 +75,7 @@ inline tvm::Tensor leaky_relu(const tvm::Tensor& t,
[&](const tvm::Array<tvm::Var>& i) {
auto value = t(i);
auto calpha = tvm::make_const(value.type(), alpha);
return tvm::select(value > 0, value, value * calpha);
return tvm::ir::Select::make(value > 0, value, value * calpha);
},
name,
tag);
......@@ -106,9 +106,11 @@ inline tvm::Tensor prelu(const tvm::Tensor &x,
return tvm::compute(x->shape,
[&](const tvm::Array<tvm::Var> &indices) {
return tvm::select(x(indices) > 0,
x(indices),
x(indices) * slope(indices[axis]));
auto xval = x(indices);
return tvm::ir::Select::make(
xval > 0,
xval,
xval * slope(indices[axis]));
},
name,
tag);
......@@ -193,7 +195,8 @@ inline tvm::Tensor pad(const tvm::Tensor& t,
}
}
if (sel.size() != 0) {
return tvm::select(detail::Map(sel, tvm::ir::And::make), t(indices), pad_value);
return tvm::if_then_else(
detail::Map(sel, tvm::ir::And::make), t(indices), pad_value);
}
return t(indices);
};
......
......@@ -76,7 +76,8 @@ inline Tensor dilate(const Tensor& x,
}
if (not_zero.size() > 0) {
auto all_not_zero = all(not_zero);
return tvm::select(all_not_zero, x(index_tuple), make_const(x->dtype, 0));
return tvm::if_then_else(
all_not_zero, x(index_tuple), make_const(x->dtype, 0));
}
return x(index_tuple);
}, name, tag);
......
......@@ -411,8 +411,8 @@ inline Tensor argmin(const Tensor& data,
bool atleast1d = false) {
auto fcombine = [](Array<Var> lhs, Array<Var> rhs) {
Array<Expr> result;
result.push_back(tvm::select(lhs[1] <= rhs[1], lhs[0], rhs[0])); // idx
result.push_back(tvm::select(lhs[1] <= rhs[1], lhs[1], rhs[1])); // val
result.push_back(tvm::ir::Select::make(lhs[1] <= rhs[1], lhs[0], rhs[0])); // idx
result.push_back(tvm::ir::Select::make(lhs[1] <= rhs[1], lhs[1], rhs[1])); // val
return result;
};
auto fidentity = [](std::vector<Type> types) {
......@@ -445,8 +445,8 @@ inline Tensor argmax(const Tensor& data,
bool atleast1d = false) {
auto fcombine = [](Array<Var> lhs, Array<Var> rhs) {
Array<Expr> result;
result.push_back(tvm::select(lhs[1] >= rhs[1], lhs[0], rhs[0])); // idx
result.push_back(tvm::select(lhs[1] >= rhs[1], lhs[1], rhs[1])); // val
result.push_back(tvm::ir::Select::make(lhs[1] >= rhs[1], lhs[0], rhs[0])); // idx
result.push_back(tvm::ir::Select::make(lhs[1] >= rhs[1], lhs[1], rhs[1])); // val
return result;
};
auto fidentity = [](std::vector<Type> types) {
......
......@@ -314,7 +314,7 @@ inline Tensor concatenate(const Array<Tensor>& inputs,
idx.push_back(indices[i]);
}
ret = tvm::select(ind >= 0,
ret = tvm::if_then_else(ind >= 0,
inputs[i + 1](idx),
ret);
}
......@@ -652,7 +652,7 @@ inline Tensor where(const Tensor& condition,
<< condition->shape.size() << " vs " << x->shape.size();
out = compute(
oshape, [&](const Array<Var>& indices) {
return tvm::select(condition(indices) != 0, x(indices), y(indices));
return tvm::ir::Select::make(condition(indices) != 0, x(indices), y(indices));
}, name, tag);
} else {
CHECK_EQ(topi::GetConstInt(condition->shape[0]), topi::GetConstInt(x->shape[0]))
......@@ -661,7 +661,7 @@ inline Tensor where(const Tensor& condition,
out = compute(
oshape, [&](const Array<Var>& indices) {
Array<Expr> condition_idx{indices[0]};
return tvm::select(condition(condition_idx) != 0,
return tvm::ir::Select::make(condition(condition_idx) != 0,
x(indices), y(indices));
}, name, tag);
}
......
......@@ -72,7 +72,7 @@ def conv2d_transpose_nchw_cuda(cfg, Input, Filter, strides, padding, out_dtype):
index_tuple.append(indices[i])
if not_zero:
not_zero = tvm.all(*not_zero)
return tvm.select(not_zero, data(*index_tuple), tvm.const(0.0, data.dtype))
return tvm.if_then_else(not_zero, data(*index_tuple), tvm.const(0.0, data.dtype))
return data(*index_tuple)
# convolution stage
......
......@@ -315,11 +315,11 @@ def sort_ir_out(data, index, new_index, loc, output, axis_mul_before, axis_mul_a
start = 0
with ib.else_scope():
start = sizes[tid-1]
p_out[base_idx + k * axis_mul_after] = tvm.select(
p_out[base_idx + k * axis_mul_after] = tvm.if_then_else(
k < p_index[tid], index_new[k+start], k)
with ib.else_scope():
with ib.if_scope(tid < data.shape[axis]):
p_out[tid] = tvm.select(tid < p_index[0], index_new[tid], tid)
p_out[tid] = tvm.if_then_else(tid < p_index[0], index_new[tid], tid)
body = ib.get()
return body
......@@ -470,7 +470,7 @@ def nms_ir(data, sort_result, valid_count, out, nms_threshold, force_suppress, n
(out_tensor[box_a_idx + 3] - out_tensor[box_a_idx + 1]) + \
(out_tensor[box_b_idx + 2] - out_tensor[box_b_idx]) * \
(out_tensor[box_b_idx + 3] - out_tensor[box_b_idx + 1]) - i
return tvm.select(u <= 0.0, 0.0, i / u)
return tvm.expr.Select(u <= 0.0, 0.0, i / u)
max_threads = int(math.sqrt(
tvm.target.current_target(allow_none=False).max_num_threads))
......@@ -506,7 +506,7 @@ def nms_ir(data, sort_result, valid_count, out, nms_threshold, force_suppress, n
tvm.all(nms_threshold_node > 0, nms_threshold_node < 1,
p_valid_count[0] > 0)):
# Reorder output
nkeep = tvm.select(
nkeep = tvm.if_then_else(
tvm.all(nms_topk_node > 0, nms_topk < p_valid_count[n]),
nms_topk, p_valid_count[n])
with ib.if_scope(i < nkeep):
......
......@@ -77,12 +77,13 @@ def multibox_prior_ir(data, out, sizes, ratios, steps, offsets):
center_w = (j + offset_w) * steps_w
for k in range(num_sizes + num_ratios - 1):
w = tvm.select(k < num_sizes,
w = tvm.if_then_else(k < num_sizes,
size_ratio_concat[
k] * in_height / in_width / 2.0,
size_ratio_concat[0] * in_height / in_width *
math.sqrt(size_ratio_concat[k + 1]) / 2.0)
h = tvm.select(k < num_sizes, size_ratio_concat[k] / 2.0,
h = tvm.if_then_else(
k < num_sizes, size_ratio_concat[k] / 2.0,
size_ratio_concat[0] / math.sqrt(size_ratio_concat[k + 1]) / 2.0)
count = (i * in_width * (num_sizes + num_ratios - 1) +
j * (num_sizes + num_ratios - 1) + k) * 4
......@@ -278,10 +279,10 @@ def transform_loc_ir(loc_pred, anchor, temp_flag, temp_id, temp_score_in, \
oy = py * vy * ah + ay
ow = tvm.exp(pw * vw) * aw / 2.0
oh = tvm.exp(ph * vh) * ah / 2.0
return tvm.select(clip, tvm.make.Max(0.0, tvm.make.Min(1.0, ox - ow)), ox - ow), \
tvm.select(clip, tvm.make.Max(0.0, tvm.make.Min(1.0, oy - oh)), oy - oh), \
tvm.select(clip, tvm.make.Max(0.0, tvm.make.Min(1.0, ox + ow)), ox + ow), \
tvm.select(clip, tvm.make.Max(0.0, tvm.make.Min(1.0, oy + oh)), oy + oh)
return tvm.if_then_else(clip, tvm.make.Max(0.0, tvm.make.Min(1.0, ox - ow)), ox - ow), \
tvm.if_then_else(clip, tvm.make.Max(0.0, tvm.make.Min(1.0, oy - oh)), oy - oh), \
tvm.if_then_else(clip, tvm.make.Max(0.0, tvm.make.Min(1.0, ox + ow)), ox + ow), \
tvm.if_then_else(clip, tvm.make.Max(0.0, tvm.make.Min(1.0, oy + oh)), oy + oh)
max_threads = int(
tvm.target.current_target(allow_none=False).max_num_threads)
......
......@@ -296,7 +296,8 @@ def _decl_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dt
# pack input tile
input_tile = tvm.compute((CI, P_round // bnb, alpha, alpha, bnb), lambda ci, b, eps, nu, bb: \
tvm.select(b * bnb + bb < P,
tvm.if_then_else(
b * bnb + bb < P,
data_pad[(b*bnb+bb) // (nH*nW)][ci][(b*bnb+bb) // nW % nH * m + eps]
[(b*bnb+bb) % nW * m + nu], tvm.const(0, data_pad.dtype)), name='d')
......
......@@ -44,7 +44,7 @@ def dilate(data, strides, name="DilatedInput"):
index_tuple.append(indices[i])
if not_zero:
not_zero = tvm.all(*not_zero)
return tvm.select(not_zero, data(*index_tuple), tvm.const(0.0, data.dtype))
return tvm.if_then_else(not_zero, data(*index_tuple), tvm.const(0.0, data.dtype))
return data(*index_tuple)
return tvm.compute(out_shape, _dilate, name=name)
......@@ -41,7 +41,7 @@ def leaky_relu(x, alpha):
def _compute(*indices):
value = x(*indices)
calpha = tvm.const(alpha, value.dtype)
return tvm.select(value > 0, value, value * calpha)
return tvm.expr.Select(value > 0, value, value * calpha)
return tvm.compute(x.shape, _compute)
@tvm.tag_scope(tag=tag.BROADCAST)
......@@ -74,5 +74,6 @@ def prelu(x, slope, axis=1):
assert get_const_int(slope.shape[0]) == get_const_int(x.shape[axis])
def _compute_channelwise(*indices):
return tvm.select(x(*indices) > 0, x(*indices), x(*indices) * slope(indices[axis]))
xval = x(*indices)
return tvm.expr.Select(xval > 0, xval, xval * slope(indices[axis]))
return tvm.compute(x.shape, _compute_channelwise)
......@@ -55,6 +55,6 @@ def pad(data, pad_before, pad_after=None, pad_value=0.0, name="PadInput"):
not_zero.append(indices[i] < data.shape[i] + pad_before[i])
if not_zero:
not_zero = tvm.all(*not_zero)
return tvm.select(not_zero, data(*index_tuple), pad_value)
return tvm.if_then_else(not_zero, data(*index_tuple), pad_value)
return data(*index_tuple)
return tvm.compute(out_shape, _pad, name=name)
......@@ -55,8 +55,8 @@ def infer_stride(data, kernel, out):
_, _, IH, IW = data.shape
_, _, KH, KW = kernel.shape
_, _, OH, OW = out.shape
hstride = (IH - KH) // tvm.make.Max(OH - 1, 1) + tvm.select(OH == 1, 1, 0)
wstride = (IW - KW) // tvm.make.Max(OW - 1, 1) + tvm.select(OW == 1, 1, 0)
hstride = (IH - KH) // tvm.make.Max(OH - 1, 1) + tvm.expr.Select(OH == 1, 1, 0)
wstride = (IW - KW) // tvm.make.Max(OW - 1, 1) + tvm.expr.Select(OW == 1, 1, 0)
return get_const_int(hstride), get_const_int(wstride)
......
......@@ -249,7 +249,7 @@ def const_matrix(matrix, name="const_matrix"):
now = tvm.const(0.0, dtype)
for ii in range(row):
for jj in range(col):
now = tvm.select(tvm.all(i % row == ii, j % col == jj),
now = tvm.expr.Select(tvm.all(i % row == ii, j % col == jj),
tvm.const(matrix[ii][jj], dtype),
now)
return now
......
......@@ -47,7 +47,7 @@ def nms_ir(data, sort_result, valid_count, out, nms_threshold, force_suppress, n
(out_tensor[box_a_idx + 3] - out_tensor[box_a_idx + 1]) + \
(out_tensor[box_b_idx + 2] - out_tensor[box_b_idx]) * \
(out_tensor[box_b_idx + 3] - out_tensor[box_b_idx + 1]) - i
return tvm.select(u <= 0.0, 0.0, i / u)
return tvm.expr.Select(u <= 0.0, 0.0, i / u)
ib = tvm.ir_builder.create()
p_data = ib.buffer_ptr(data)
......@@ -64,7 +64,8 @@ def nms_ir(data, sort_result, valid_count, out, nms_threshold, force_suppress, n
with ib.if_scope(tvm.all(nms_threshold_node > 0, nms_threshold_node < 1,
p_valid_count[0] > 0)):
# Reorder output
nkeep = tvm.select(tvm.all(nms_topk_node > 0, nms_topk < p_valid_count[n]),
nkeep = tvm.if_then_else(
tvm.all(nms_topk_node > 0, nms_topk < p_valid_count[n]),
nms_topk, p_valid_count[n])
with ib.for_range(0, nkeep, name="l") as l:
with ib.for_range(0, 6, name="m") as m:
......
......@@ -47,7 +47,7 @@ def roi_align_nchw(data, rois, pooled_size, spatial_scale, sample_ratio=-1):
y = tvm.max(y, 0.0)
x = tvm.max(x, 0.0)
val = bilinear_sample_nchw(data, (i, c, y, x), height - 1, width - 1)
return tvm.select(outside, 0.0, val)
return tvm.if_then_else(outside, 0.0, val)
def _sample(i, c, ph, pw):
roi = rois[i]
......
......@@ -55,11 +55,12 @@ def multibox_prior_ir(data, out, sizes, ratios, steps, offsets):
with ib.for_range(0, in_width, name="j") as j:
center_w = (j + offset_w) * steps_w
for k in range(num_sizes + num_ratios - 1):
w = tvm.select(k < num_sizes,
w = tvm.if_then_else(k < num_sizes,
size_ratio_concat[k] * in_height / in_width / 2.0,
size_ratio_concat[0] * in_height / in_width *
math.sqrt(size_ratio_concat[k + 1]) / 2.0)
h = tvm.select(k < num_sizes, size_ratio_concat[k] / 2.0,
h = tvm.if_then_else(
k < num_sizes, size_ratio_concat[k] / 2.0,
size_ratio_concat[0] / math.sqrt(size_ratio_concat[k + 1]) / 2.0)
count = (i * in_width * (num_sizes + num_ratios - 1) +
j * (num_sizes + num_ratios - 1) + k) * 4
......@@ -164,10 +165,10 @@ def transform_loc_ir(cls_prob, loc_pred, anchor, valid_count, out, clip, thresho
oy = py * vy * ah + ay
ow = tvm.exp(pw * vw) * aw / 2.0
oh = tvm.exp(ph * vh) * ah / 2.0
return tvm.select(clip, tvm.max(0, tvm.min(1, ox - ow)), ox - ow), \
tvm.select(clip, tvm.max(0, tvm.min(1, oy - oh)), oy - oh), \
tvm.select(clip, tvm.max(0, tvm.min(1, ox + ow)), ox + ow), \
tvm.select(clip, tvm.max(0, tvm.min(1, oy + oh)), oy + oh)
return tvm.if_then_else(clip, tvm.max(0, tvm.min(1, ox - ow)), ox - ow), \
tvm.if_then_else(clip, tvm.max(0, tvm.min(1, oy - oh)), oy - oh), \
tvm.if_then_else(clip, tvm.max(0, tvm.min(1, ox + ow)), ox + ow), \
tvm.if_then_else(clip, tvm.max(0, tvm.min(1, oy + oh)), oy + oh)
batch_size = cls_prob.shape[0]
num_classes = cls_prob.shape[1]
......@@ -190,7 +191,7 @@ def transform_loc_ir(cls_prob, loc_pred, anchor, valid_count, out, clip, thresho
with ib.for_range(0, num_classes, name="j") as j:
with ib.if_scope(j > 0):
temp = p_cls_prob[n * num_anchors * num_classes + j * num_anchors + i]
cls_id[0] = tvm.select(temp > score[0], j, cls_id[0])
cls_id[0] = tvm.if_then_else(temp > score[0], j, cls_id[0])
score[0] = tvm.max(temp, score[0])
with ib.if_scope(tvm.all(cls_id[0] > 0, score[0] < threshold)):
cls_id[0] = 0
......
......@@ -65,7 +65,7 @@ def verify_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, p
else:
func = tvm.build(s, [A, W, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation))
func(a, w, c)
tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5)
tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-4)
for device in get_all_backend():
with autotvm.tophub.context(device): # load tophub pre-tuned parameters
......
......@@ -45,8 +45,8 @@ print(tvm.lower(s, [A0, A1, B0, B1], simple_mode=True))
# x and y are the operands of the reduction; each of them is a tuple of index
# and value.
def fcombine(x, y):
lhs = tvm.select((x[1] >= y[1]), x[0], y[0])
rhs = tvm.select((x[1] >= y[1]), x[1], y[1])
lhs = tvm.expr.Select((x[1] >= y[1]), x[0], y[0])
rhs = tvm.expr.Select((x[1] >= y[1]), x[1], y[1])
return lhs, rhs
# our identity element also needs to be a tuple, so `fidentity` accepts
......
......@@ -43,7 +43,7 @@ out_size = (in_size - kernel + 2*pad) // stride + 1
# Pad input
Apad = tvm.compute(
(in_size + 2*pad, in_size + 2*pad, in_channel, batch),
lambda yy, xx, cc, nn: tvm.select(
lambda yy, xx, cc, nn: tvm.if_then_else(
tvm.all(yy >= pad, yy - pad < in_size,
xx >= pad, xx - pad < in_size),
A[yy - pad, xx - pad, cc, nn], tvm.const(0., "float32")),
......