Commit 2a967044 authored by Jared Roesch, committed by Tianqi Chen

[RELAY] Add broadcast_to operator (#2276)

parent 7bc990ad
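
At a high level, the commit wires a new broadcast_to operator through every Relay layer touched below: the docs operator list, the Python frontend, schedule registration, the C++ type relation and compute, and a test. Its semantics mirror numpy.broadcast_to, which the new test also uses as its reference; a plain NumPy refresher of that behavior (not part of the diff):

import numpy as np

x = np.zeros((4, 1, 6), dtype="float32")  # size-1 axes can be stretched
y = np.broadcast_to(x, (3, 4, 5, 6))      # and leading axes may be added
assert y.shape == (3, 4, 5, 6)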
...@@ -124,6 +124,7 @@ This level enables additional math and transform operators.
   tvm.relay.mean
   tvm.relay.prod
   tvm.relay.strided_slice
   tvm.relay.broadcast_to
**Level 5: Vision/Image Operators**
...
...@@ -11,6 +11,7 @@ schedule_broadcast = _reg.schedule_injective
_reg.register_schedule("collapse_sum_like", _schedule_reduce)
_reg.register_schedule("broadcast_to", schedule_broadcast)
_reg.register_schedule("broadcast_to_like", schedule_broadcast)
_reg.register_schedule("expand_dims", schedule_broadcast)
_reg.register_schedule("squeeze", schedule_injective)
...
...@@ -267,6 +267,24 @@ def where(condition, x, y):
    """
    return _make.where(condition, x, y)
def broadcast_to(data, shape):
    """Return a tensor with the input data broadcast to the provided shape,
    keeping the input dtype.

    Parameters
    ----------
    data : relay.Expr
        The input tensor.

    shape : tuple of int
        The shape to broadcast to.

    Returns
    -------
    result : relay.Expr
        The resulting tensor.
    """
    return _make.broadcast_to(data, shape)
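
For context, a minimal usage sketch of the new frontend function (not part of the diff; assumes the Relay Python API as of this commit, where type inference lives in relay.ir_pass):

import tvm
from tvm import relay

x = relay.var("x", shape=(4, 1, 6), dtype="float32")
z = relay.broadcast_to(x, shape=(3, 4, 5, 6))
# Type inference resolves the output type from the shape attribute.
zz = relay.ir_pass.infer_type(z)
print(zz.checked_type)  # Tensor[(3, 4, 5, 6), float32]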
def broadcast_to_like(data, broadcast_type):
    """Return a tensor with the input data broadcast to the shape of the given array,
...
...@@ -258,8 +258,7 @@ bool GlobalPool2DRel(const Array<Type>& types,
                     const TypeReporter& reporter) {
  CHECK_EQ(types.size(), 2);
  const auto* data = types[0].as<TensorTypeNode>();
  if (data == nullptr) { return false; }
  CHECK(data != nullptr);
  const auto dshape = data->shape;
  CHECK_NE(dshape.size(), 0);
  CHECK_GE(dshape.size(), 2U)
...
...@@ -1084,6 +1084,52 @@ RELAY_REGISTER_OP("collapse_sum_like")
.set_attr<FTVMCompute>("FTVMCompute", CollapseSumLikeCompute)
.set_attr<TOpPattern>("TOpPattern", kCommReduce);
// BroadCastTo: <A> -> B where B is given by the shape attribute
// and BroadCast(A, B) = B
bool BroadCastToRel(const Array<Type>& types,
                    int num_inputs,
                    const Attrs& attrs,
                    const TypeReporter& reporter) {
  CHECK_EQ(types.size(), 2);
  auto ioattrs = attrs.as<InitOpAttrs>();
  CHECK(ioattrs);
  auto intt = types[0].as<TensorTypeNode>();
  if (intt == nullptr) { return false; }
  auto type = TensorTypeNode::make(ioattrs->shape, intt->dtype);
  reporter->Assign(types[1], type);
  return true;
}
Expr MakeBroadCastTo(Expr data, Array<IndexExpr> shape) {
  static const Op& op = Op::Get("broadcast_to");
  auto attrs = make_node<InitOpAttrs>();
  attrs->shape = std::move(shape);
  return CallNode::make(op, {data}, Attrs(attrs), {});
}
Array<Tensor> BroadCastToCompute(const Attrs& attrs,
                                 const Array<Tensor>& inputs,
                                 const Type& out_type,
                                 const Target& target) {
  auto ioattrs = attrs.as<InitOpAttrs>();
  CHECK(ioattrs != nullptr);
  return { topi::broadcast_to(inputs[0], ioattrs->shape) };
}
TVM_REGISTER_API("relay.op._make.broadcast_to")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
    runtime::detail::unpack_call<Expr, 2>(MakeBroadCastTo, args, rv);
  });
RELAY_REGISTER_OP("broadcast_to")
.describe(R"code(Broadcast the first input to match the shape argument.
)code" TVM_ADD_FILELINE)
.set_num_inputs(1)
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(4)
.add_type_rel("BroadCastTo", BroadCastToRel)
.set_attr<FTVMCompute>("FTVMCompute", BroadCastToCompute)
.set_attr<TOpPattern>("TOpPattern", kBroadcast);
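
The FTVMCompute registration above simply forwards to the existing TOPI broadcast kernel. A rough standalone sketch of that underlying computation through the 0.5-era Python TOPI binding (tvm.placeholder/tvm.create_schedule names assumed from that API, not part of the diff):

import numpy as np
import tvm
import topi

A = tvm.placeholder((4, 1, 6), name="A")
B = topi.broadcast_to(A, (3, 4, 5, 6))  # same kernel the C++ compute calls
s = tvm.create_schedule(B.op)
f = tvm.build(s, [A, B], "llvm")
a = tvm.nd.array(np.random.uniform(size=(4, 1, 6)).astype("float32"))
b = tvm.nd.array(np.zeros((3, 4, 5, 6), dtype="float32"))
f(a, b)
np.testing.assert_allclose(b.asnumpy(), np.broadcast_to(a.asnumpy(), (3, 4, 5, 6)))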
// BroadCastToLike: <A, B> -> B where BroadCast(A, B) = B
bool BroadCastToLikeRel(const Array<Type>& types,
                        int num_inputs,
...
...@@ -25,6 +25,24 @@ def test_collapse_sum_like():
    op_res = intrp.evaluate(func)(x, y)
    tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)

def test_broadcast_to():
    shape = (4, 1, 6)
    shape_like = (3, 4, 5, 6)
    dtype = "float32"
    x = relay.Var("x", relay.ty.TensorType(shape, dtype))
    z = relay.broadcast_to(x, shape=shape_like)
    zz = relay.ir_pass.infer_type(z)
    assert zz.checked_type == relay.ty.TensorType(shape_like, dtype)
    func = relay.Function([x], z)
    x = np.random.uniform(size=shape).astype(dtype)
    ref_res = np.broadcast_to(x, shape_like)
    for target, ctx in ctx_list():
        for kind in ["graph", "debug"]:
            intrp = relay.create_executor(kind, ctx=ctx, target=target)
            op_res = intrp.evaluate(func)(x)
            tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)

def test_broadcast_to_like():
    shape = (4, 1, 6)
    shape_like = (3, 4, 5, 6)
...