Commit c2b36154 by 雾雨魔理沙 Committed by Tianqi Chen

[Relay][Op]BroadcastToLike CollapseSumLike (#1886)

parent c51268c3
......@@ -123,6 +123,17 @@ This level enables additional math and transform operators.
tvm.relay.image.resize
**Level 10: Temporary Operators**
This level support backpropagation of broadcast operators. It is temporary.
.. autosummary::
:nosignatures:
tvm.relay.broadcast_to_like
tvm.relay.collapse_sum_like
Level 1 Definitions
-------------------
.. autofunction:: tvm.relay.log
......@@ -199,6 +210,13 @@ Level 4 Definitions
.. autofunction:: tvm.relay.prod
Level 5 Definitions
-------------------
.. autofunction:: tvm.relay.image.resize
Level 10 Definitions
--------------------
.. autofunction:: tvm.relay.broadcast_to_like
.. autofunction:: tvm.relay.collapse_sum_like
......@@ -242,3 +242,41 @@ def where(condition, x, y):
Note that the shape of condition, x, and y needs to be the same.
"""
return _make.where(condition, x, y)
def broadcast_to_like(data, broadcast_type):
"""Return an scalar value array with the same shape and type as the input array.
Parameters
----------
data : relay.Expr
The input tensor.
broadcast_type : relay.Expr
Provide the type to broadcast to.
Returns
-------
result : relay.Expr
The resulting tensor.
"""
return _make.broadcast_to_like(data, broadcast_type)
def collapse_sum_like(data, collapse_type):
"""Return an scalar value array with the same shape and type as the input array.
Parameters
----------
data : relay.Expr
The input tensor.
collapse_type : relay.Expr
Provide the type to collapse to.
Returns
-------
result : relay.Expr
The resulting tensor.
"""
return _make.collapse_sum_like(data, collapse_type)
......@@ -718,5 +718,66 @@ RELAY_REGISTER_OP("squeeze")
.set_support_level(3)
.add_type_rel("Squeeze", SqueezeRel);
// Have no idea how to assert the constraint.
// CollapseSumLike: <A, B> -> B where BroadCast(A, B) = A
bool CollapseSumLikeRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 3);
reporter->Assign(types[2], types[1]);
return true;
}
Expr MakeCollapseSumLike(Expr data,
Expr collapse_type) {
static const Op& op = Op::Get("collapse_sum_like");
return CallNode::make(op, {data, collapse_type}, Attrs(), {});
}
TVM_REGISTER_API("relay.op._make.collapse_sum_like")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 2>(MakeCollapseSumLike, args, rv);
});
RELAY_REGISTER_OP("collapse_sum_like")
.describe(R"code(Collapse the first input to match the shape of the second input.
)code" TVM_ADD_FILELINE)
.set_num_inputs(2)
.add_argument("data", "Tensor", "The input tensor.")
.add_argument("collapse_type", "Tensor", "Provide the type to collapse to.")
.set_support_level(10)
.add_type_rel("CollapseSumLike", CollapseSumLikeRel);
// BroadCastToLike: <A, B> -> B where BroadCast(A, B) = B
bool BroadCastToLikeRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 3);
reporter->Assign(types[2], types[1]);
return true;
}
Expr MakeBroadCastToLike(Expr data,
Expr broadcast_type) {
static const Op& op = Op::Get("broadcast_to_like");
return CallNode::make(op, {data, broadcast_type}, Attrs(), {});
}
TVM_REGISTER_API("relay.op._make.broadcast_to_like")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 2>(MakeBroadCastToLike, args, rv);
});
RELAY_REGISTER_OP("broadcast_to_like")
.describe(R"code(Broadcast the first input to match the shape of the second input.
)code" TVM_ADD_FILELINE)
.set_num_inputs(2)
.add_argument("data", "Tensor", "The input tensor.")
.add_argument("broadcast_type", "Tensor", "Provide the type to broadcast to.")
.set_support_level(10)
.add_type_rel("BroadCastToLike", BroadCastToLikeRel);
} // namespace relay
} // namespace tvm
""" Support level10 operator test cases.
"""
import tvm
from tvm import relay
def test_collapse_sum_like():
x = relay.Var("x", relay.ty.TensorType((3, 4, 5, 6), "int8"))
y = relay.Var("y", relay.ty.TensorType((4, 1, 6), "int8"))
z = relay.collapse_sum_like(x, y)
zz = relay.ir_pass.infer_type(z)
assert zz.checked_type == relay.ty.TensorType((4, 1, 6), "int8")
def test_broadcast_to_like():
x = relay.Var("x", relay.ty.TensorType((3, 4, 5, 6), "int8"))
y = relay.Var("y", relay.ty.TensorType((4, 1, 6), "int8"))
z = relay.broadcast_to_like(y, x)
zz = relay.ir_pass.infer_type(z)
assert zz.checked_type == relay.ty.TensorType((3, 4, 5, 6), "int8")
if __name__ == "__main__":
test_collapse_sum_like()
test_broadcast_to_like()
......@@ -461,3 +461,4 @@ if __name__ == "__main__":
test_let_alpha_equal()
test_if_alpha_equal()
test_op_alpha_equal()
test_var_alpha_equal()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment