Commit b8715008 by Siju Committed by Tianqi Chen

[Relay]where compute and schedule (#2179)

parent 10ce048c
......@@ -19,3 +19,4 @@ _reg.register_schedule("slice_like", schedule_injective)
_reg.register_schedule("split", schedule_injective)
_reg.register_schedule("take", schedule_injective)
_reg.register_schedule("transpose", schedule_injective)
_reg.register_schedule("where", schedule_broadcast)
......@@ -857,6 +857,13 @@ Expr MakeWhere(const Expr& condition, const Expr& x, const Expr& y) {
return CallNode::make(op, {condition, x, y});
}
Array<Tensor> WhereCompute(const Attrs& attrs,
const Array<Tensor>& inputs,
const Type& out_type,
const Target& target) {
return { topi::where(inputs[0], inputs[1], inputs[2]) };
}
TVM_REGISTER_API("relay.op._make.where")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 3>(MakeWhere, args, rv);
......@@ -896,7 +903,9 @@ Examples::
.add_argument("y", "Tensor", "Second array to be selected")
.set_num_inputs(3)
.set_support_level(4)
.add_type_rel("Where", WhereRel);
.add_type_rel("Where", WhereRel)
.set_attr<FTVMCompute>("FTVMCompute", WhereCompute)
.set_attr<TOpPattern>("TOpPattern", kBroadcast);
// Squeeze
......
......@@ -98,12 +98,25 @@ def test_binary_int_broadcast():
def test_where():
cond = relay.var("cond", relay.TensorType((3, 4), "float32"))
x = relay.var("x", relay.TensorType((3, 4), "float32"))
y = relay.var("y", relay.TensorType((3, 4), "float32"))
shape = (3, 4)
dtype = "float32"
cond = relay.var("cond", relay.TensorType(shape, dtype))
x = relay.var("x", relay.TensorType(shape, dtype))
y = relay.var("y", relay.TensorType(shape, dtype))
z = relay.where(cond, x, y)
zz = relay.ir_pass.infer_type(z)
assert zz.checked_type == relay.TensorType((3, 4), "float32")
assert zz.checked_type == relay.TensorType(shape, dtype)
func = relay.Function([cond, x, y], z)
condition = np.random.uniform(low=-1, high=1, size=shape).astype(dtype)
x = np.random.uniform(size=shape).astype(dtype)
y = np.random.uniform(size=shape).astype(dtype)
ref_res = np.where(condition, x, y)
for target, ctx in ctx_list():
for kind in ["graph", "debug"]:
intrp = relay.create_executor(kind, ctx=ctx, target=target)
op_res = intrp.evaluate(func)(condition, x, y)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)
def verify_reduce(funcs, data, axis, keepdims, exclude, output, dtype="float32"):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment