Commit 8fe38eef by Siju Committed by Tianqi Chen

[Relay]collapse_sum and broadcast_to compute & schedule (#2180)

parent f522b0f8
...@@ -2,13 +2,15 @@ ...@@ -2,13 +2,15 @@
# pylint: disable=invalid-name # pylint: disable=invalid-name
from __future__ import absolute_import from __future__ import absolute_import
from . import op as _reg from . import op as _reg
from ._reduce import _schedule_reduce
schedule_injective = _reg.schedule_injective schedule_injective = _reg.schedule_injective
schedule_broadcast = _reg.schedule_injective schedule_broadcast = _reg.schedule_injective
_reg.register_schedule("collapse_sum_like", _schedule_reduce)
_reg.register_schedule("squeeze", schedule_injective) _reg.register_schedule("broadcast_to_like", schedule_broadcast)
_reg.register_schedule("expand_dims", schedule_broadcast) _reg.register_schedule("expand_dims", schedule_broadcast)
_reg.register_schedule("squeeze", schedule_injective)
_reg.register_schedule("reshape", schedule_injective) _reg.register_schedule("reshape", schedule_injective)
_reg.register_schedule("reshape_like", schedule_injective) _reg.register_schedule("reshape_like", schedule_injective)
_reg.register_schedule("full", schedule_injective) _reg.register_schedule("full", schedule_injective)
......
...@@ -9,6 +9,8 @@ ...@@ -9,6 +9,8 @@
#include <tvm/ir.h> #include <tvm/ir.h>
#include <topi/transform.h> #include <topi/transform.h>
#include <topi/elemwise.h> #include <topi/elemwise.h>
#include <topi/broadcast.h>
#include <topi/reduction.h>
#include <vector> #include <vector>
#include "../op_common.h" #include "../op_common.h"
#include "../../../arithmetic/compute_expr.h" #include "../../../arithmetic/compute_expr.h"
...@@ -1017,6 +1019,15 @@ Expr MakeCollapseSumLike(Expr data, ...@@ -1017,6 +1019,15 @@ Expr MakeCollapseSumLike(Expr data,
return CallNode::make(op, {data, collapse_type}, Attrs(), {}); return CallNode::make(op, {data, collapse_type}, Attrs(), {});
} }
Array<Tensor> CollapseSumLikeCompute(const Attrs& attrs,
const Array<Tensor>& inputs,
const Type& out_type,
const Target& target) {
const auto* out_ttype = out_type.as<TensorTypeNode>();
CHECK(out_ttype != nullptr);
return { topi::collapse_sum(inputs[0], out_ttype->shape) };
}
TVM_REGISTER_API("relay.op._make.collapse_sum_like") TVM_REGISTER_API("relay.op._make.collapse_sum_like")
.set_body([](const TVMArgs& args, TVMRetValue* rv) { .set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 2>(MakeCollapseSumLike, args, rv); runtime::detail::unpack_call<Expr, 2>(MakeCollapseSumLike, args, rv);
...@@ -1029,7 +1040,9 @@ RELAY_REGISTER_OP("collapse_sum_like") ...@@ -1029,7 +1040,9 @@ RELAY_REGISTER_OP("collapse_sum_like")
.add_argument("data", "Tensor", "The input tensor.") .add_argument("data", "Tensor", "The input tensor.")
.add_argument("collapse_type", "Tensor", "Provide the type to collapse to.") .add_argument("collapse_type", "Tensor", "Provide the type to collapse to.")
.set_support_level(10) .set_support_level(10)
.add_type_rel("CollapseSumLike", CollapseSumLikeRel); .add_type_rel("CollapseSumLike", CollapseSumLikeRel)
.set_attr<FTVMCompute>("FTVMCompute", CollapseSumLikeCompute)
.set_attr<TOpPattern>("TOpPattern", kCommReduce);
// BroadCastToLike: <A, B> -> B where BroadCast(A, B) = B // BroadCastToLike: <A, B> -> B where BroadCast(A, B) = B
bool BroadCastToLikeRel(const Array<Type>& types, bool BroadCastToLikeRel(const Array<Type>& types,
...@@ -1047,6 +1060,15 @@ Expr MakeBroadCastToLike(Expr data, ...@@ -1047,6 +1060,15 @@ Expr MakeBroadCastToLike(Expr data,
return CallNode::make(op, {data, broadcast_type}, Attrs(), {}); return CallNode::make(op, {data, broadcast_type}, Attrs(), {});
} }
Array<Tensor> BroadCastToLikeCompute(const Attrs& attrs,
const Array<Tensor>& inputs,
const Type& out_type,
const Target& target) {
const auto* out_ttype = out_type.as<TensorTypeNode>();
CHECK(out_ttype != nullptr);
return { topi::broadcast_to(inputs[0], out_ttype->shape) };
}
TVM_REGISTER_API("relay.op._make.broadcast_to_like") TVM_REGISTER_API("relay.op._make.broadcast_to_like")
.set_body([](const TVMArgs& args, TVMRetValue* rv) { .set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 2>(MakeBroadCastToLike, args, rv); runtime::detail::unpack_call<Expr, 2>(MakeBroadCastToLike, args, rv);
...@@ -1059,7 +1081,9 @@ RELAY_REGISTER_OP("broadcast_to_like") ...@@ -1059,7 +1081,9 @@ RELAY_REGISTER_OP("broadcast_to_like")
.add_argument("data", "Tensor", "The input tensor.") .add_argument("data", "Tensor", "The input tensor.")
.add_argument("broadcast_type", "Tensor", "Provide the type to broadcast to.") .add_argument("broadcast_type", "Tensor", "Provide the type to broadcast to.")
.set_support_level(10) .set_support_level(10)
.add_type_rel("BroadCastToLike", BroadCastToLikeRel); .add_type_rel("BroadCastToLike", BroadCastToLikeRel)
.set_attr<FTVMCompute>("FTVMCompute", BroadCastToLikeCompute)
.set_attr<TOpPattern>("TOpPattern", kBroadcast);
// strided_slice // strided_slice
......
...@@ -6,19 +6,44 @@ from tvm import relay ...@@ -6,19 +6,44 @@ from tvm import relay
from tvm.relay.testing import ctx_list from tvm.relay.testing import ctx_list
def test_collapse_sum_like(): def test_collapse_sum_like():
x = relay.Var("x", relay.ty.TensorType((3, 4, 5, 6), "int8")) shape = (3, 4, 5, 6)
y = relay.Var("y", relay.ty.TensorType((4, 1, 6), "int8")) shape_like = (4, 5, 6)
dtype = "float32"
x = relay.Var("x", relay.ty.TensorType(shape , dtype))
y = relay.Var("y", relay.ty.TensorType(shape_like, dtype))
z = relay.collapse_sum_like(x, y) z = relay.collapse_sum_like(x, y)
zz = relay.ir_pass.infer_type(z) zz = relay.ir_pass.infer_type(z)
assert zz.checked_type == relay.ty.TensorType((4, 1, 6), "int8") assert zz.checked_type == relay.ty.TensorType(shape_like, dtype)
func = relay.Function([x, y], z)
x = np.random.uniform(size=shape).astype(dtype)
y = np.random.uniform(size=shape_like).astype(dtype)
ref_res = np.sum(x, 0)
for target, ctx in ctx_list():
for kind in ["graph", "debug"]:
intrp = relay.create_executor(kind, ctx=ctx, target=target)
op_res = intrp.evaluate(func)(x, y)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)
def test_broadcast_to_like(): def test_broadcast_to_like():
x = relay.Var("x", relay.ty.TensorType((3, 4, 5, 6), "int8")) shape = (4, 1, 6)
y = relay.Var("y", relay.ty.TensorType((4, 1, 6), "int8")) shape_like = (3, 4, 5, 6)
z = relay.broadcast_to_like(y, x) dtype = "float32"
x = relay.Var("x", relay.ty.TensorType(shape , dtype))
y = relay.Var("y", relay.ty.TensorType(shape_like, dtype))
z = relay.broadcast_to_like(x, y)
zz = relay.ir_pass.infer_type(z) zz = relay.ir_pass.infer_type(z)
assert zz.checked_type == relay.ty.TensorType((3, 4, 5, 6), "int8") assert zz.checked_type == relay.ty.TensorType(shape_like, dtype)
func = relay.Function([x, y], z)
x = np.random.uniform(size=shape).astype(dtype)
y = np.random.uniform(size=shape_like).astype(dtype)
ref_res = np.broadcast_to(x, shape_like)
for target, ctx in ctx_list():
for kind in ["graph", "debug"]:
intrp = relay.create_executor(kind, ctx=ctx, target=target)
op_res = intrp.evaluate(func)(x, y)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)
def np_slice_like(np_data, np_shape_like, axis=None): def np_slice_like(np_data, np_shape_like, axis=None):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment