Commit b8715008 by Siju Committed by Tianqi Chen

[Relay] where compute and schedule (#2179)

parent 10ce048c
...@@ -19,3 +19,4 @@ _reg.register_schedule("slice_like", schedule_injective) ...@@ -19,3 +19,4 @@ _reg.register_schedule("slice_like", schedule_injective)
_reg.register_schedule("split", schedule_injective) _reg.register_schedule("split", schedule_injective)
_reg.register_schedule("take", schedule_injective) _reg.register_schedule("take", schedule_injective)
_reg.register_schedule("transpose", schedule_injective) _reg.register_schedule("transpose", schedule_injective)
_reg.register_schedule("where", schedule_broadcast)
...@@ -857,6 +857,13 @@ Expr MakeWhere(const Expr& condition, const Expr& x, const Expr& y) { ...@@ -857,6 +857,13 @@ Expr MakeWhere(const Expr& condition, const Expr& x, const Expr& y) {
return CallNode::make(op, {condition, x, y}); return CallNode::make(op, {condition, x, y});
} }
// FTVMCompute for the "where" op: element-wise select between two tensors.
// inputs[0] is the condition, inputs[1] the values taken where the condition
// is non-zero, inputs[2] the values taken elsewhere. out_type and target are
// unused here; lowering is delegated entirely to topi::where.
Array<Tensor> WhereCompute(const Attrs& attrs,
                           const Array<Tensor>& inputs,
                           const Type& out_type,
                           const Target& target) {
  const Tensor& condition = inputs[0];
  const Tensor& on_true = inputs[1];
  const Tensor& on_false = inputs[2];
  Array<Tensor> result;
  result.push_back(topi::where(condition, on_true, on_false));
  return result;
}
TVM_REGISTER_API("relay.op._make.where") TVM_REGISTER_API("relay.op._make.where")
.set_body([](const TVMArgs& args, TVMRetValue* rv) { .set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 3>(MakeWhere, args, rv); runtime::detail::unpack_call<Expr, 3>(MakeWhere, args, rv);
...@@ -896,7 +903,9 @@ Examples:: ...@@ -896,7 +903,9 @@ Examples::
.add_argument("y", "Tensor", "Second array to be selected") .add_argument("y", "Tensor", "Second array to be selected")
.set_num_inputs(3) .set_num_inputs(3)
.set_support_level(4) .set_support_level(4)
.add_type_rel("Where", WhereRel); .add_type_rel("Where", WhereRel)
.set_attr<FTVMCompute>("FTVMCompute", WhereCompute)
.set_attr<TOpPattern>("TOpPattern", kBroadcast);
// Squeeze // Squeeze
......
...@@ -98,12 +98,25 @@ def test_binary_int_broadcast(): ...@@ -98,12 +98,25 @@ def test_binary_int_broadcast():
def test_where(): def test_where():
cond = relay.var("cond", relay.TensorType((3, 4), "float32")) shape = (3, 4)
x = relay.var("x", relay.TensorType((3, 4), "float32")) dtype = "float32"
y = relay.var("y", relay.TensorType((3, 4), "float32")) cond = relay.var("cond", relay.TensorType(shape, dtype))
x = relay.var("x", relay.TensorType(shape, dtype))
y = relay.var("y", relay.TensorType(shape, dtype))
z = relay.where(cond, x, y) z = relay.where(cond, x, y)
zz = relay.ir_pass.infer_type(z) zz = relay.ir_pass.infer_type(z)
assert zz.checked_type == relay.TensorType((3, 4), "float32") assert zz.checked_type == relay.TensorType(shape, dtype)
func = relay.Function([cond, x, y], z)
condition = np.random.uniform(low=-1, high=1, size=shape).astype(dtype)
x = np.random.uniform(size=shape).astype(dtype)
y = np.random.uniform(size=shape).astype(dtype)
ref_res = np.where(condition, x, y)
for target, ctx in ctx_list():
for kind in ["graph", "debug"]:
intrp = relay.create_executor(kind, ctx=ctx, target=target)
op_res = intrp.evaluate(func)(condition, x, y)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)
def verify_reduce(funcs, data, axis, keepdims, exclude, output, dtype="float32"): def verify_reduce(funcs, data, axis, keepdims, exclude, output, dtype="float32"):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment