Commit 2a967044 by Jared Roesch Committed by Tianqi Chen

[RELAY] Add broadcast_to operator (#2276)

parent 7bc990ad
......@@ -124,6 +124,7 @@ This level enables additional math and transform operators.
tvm.relay.mean
tvm.relay.prod
tvm.relay.strided_slice
tvm.relay.broadcast_to
**Level 5: Vision/Image Operators**
......
......@@ -11,6 +11,7 @@ schedule_broadcast = _reg.schedule_injective
_reg.register_schedule("collapse_sum_like", _schedule_reduce)
_reg.register_schedule("broadcast_to", schedule_broadcast)
_reg.register_schedule("broadcast_to_like", schedule_broadcast)
_reg.register_schedule("expand_dims", schedule_broadcast)
_reg.register_schedule("squeeze", schedule_injective)
......
......@@ -267,6 +267,24 @@ def where(condition, x, y):
"""
return _make.where(condition, x, y)
def broadcast_to(data, shape):
"""Return an scalar value array with the same type, broadcast to
the provided shape.
Parameters
----------
data : relay.Expr
The input tensor.
shape : shape
Provide the shape to broadcast to.
Returns
-------
result : relay.Expr
The resulting tensor.
"""
return _make.broadcast_to(data, shape)
def broadcast_to_like(data, broadcast_type):
"""Return an scalar value array with the same shape and type as the input array.
......
......@@ -258,8 +258,7 @@ bool GlobalPool2DRel(const Array<Type>& types,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 2);
const auto* data = types[0].as<TensorTypeNode>();
CHECK(data != nullptr);
if (data == nullptr) { return false; }
const auto dshape = data->shape;
CHECK_NE(dshape.size(), 0);
CHECK_GE(dshape.size(), 2U)
......
......@@ -1084,6 +1084,52 @@ RELAY_REGISTER_OP("collapse_sum_like")
.set_attr<FTVMCompute>("FTVMCompute", CollapseSumLikeCompute)
.set_attr<TOpPattern>("TOpPattern", kCommReduce);
// BroadCastTo: <A, B> -> B where BroadCast(A, B) = B
bool BroadCastToRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 2);
auto ioattrs = attrs.as<InitOpAttrs>();
CHECK(ioattrs);
auto intt = types[0].as<TensorTypeNode>();
if (intt == nullptr) { return false; }
auto type = TensorTypeNode::make(ioattrs->shape, intt->dtype);
reporter->Assign(types[1], type);
return true;
}
Expr MakeBroadCastTo(Expr data, Array<IndexExpr> shape) {
static const Op& op = Op::Get("broadcast_to");
auto attrs = make_node<InitOpAttrs>();
attrs->shape = std::move(shape);
return CallNode::make(op, {data}, Attrs(attrs), {});
}
Array<Tensor> BroadCastToCompute(const Attrs& attrs,
const Array<Tensor>& inputs,
const Type& out_type,
const Target& target) {
auto ioattrs = attrs.as<InitOpAttrs>();
CHECK(ioattrs != nullptr);
return { topi::broadcast_to(inputs[0], ioattrs->shape) };
}
TVM_REGISTER_API("relay.op._make.broadcast_to")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 2>(MakeBroadCastTo, args, rv);
});
RELAY_REGISTER_OP("broadcast_to")
.describe(R"code(Broadcast the first input to match the shape argument.
)code" TVM_ADD_FILELINE)
.set_num_inputs(1)
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(4)
.add_type_rel("BroadCastTo", BroadCastToRel)
.set_attr<FTVMCompute>("FTVMCompute", BroadCastToCompute)
.set_attr<TOpPattern>("TOpPattern", kBroadcast);
// BroadCastToLike: <A, B> -> B where BroadCast(A, B) = B
bool BroadCastToLikeRel(const Array<Type>& types,
int num_inputs,
......
......@@ -25,6 +25,24 @@ def test_collapse_sum_like():
op_res = intrp.evaluate(func)(x, y)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)
def test_broadcast_to():
shape = (4, 1, 6)
shape_like = (3, 4, 5, 6)
dtype = "float32"
x = relay.Var("x", relay.ty.TensorType(shape , dtype))
z = relay.broadcast_to(x, shape=shape_like)
zz = relay.ir_pass.infer_type(z)
assert zz.checked_type == relay.ty.TensorType(shape_like, dtype)
func = relay.Function([x], z)
x = np.random.uniform(size=shape).astype(dtype)
ref_res = np.broadcast_to(x, shape_like)
for target, ctx in ctx_list():
for kind in ["graph", "debug"]:
intrp = relay.create_executor(kind, ctx=ctx, target=target)
op_res = intrp.evaluate(func)(x)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)
def test_broadcast_to_like():
shape = (4, 1, 6)
shape_like = (3, 4, 5, 6)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment