Commit 246a38a1 by Siju Committed by Tianqi Chen

[RELAY]full, full_like compute and schedule (#2170)

parent 7880b50c
...@@ -11,6 +11,8 @@ _reg.register_schedule("squeeze", schedule_injective) ...@@ -11,6 +11,8 @@ _reg.register_schedule("squeeze", schedule_injective)
_reg.register_schedule("expand_dims", schedule_broadcast) _reg.register_schedule("expand_dims", schedule_broadcast)
_reg.register_schedule("reshape", schedule_injective) _reg.register_schedule("reshape", schedule_injective)
_reg.register_schedule("reshape_like", schedule_injective) _reg.register_schedule("reshape_like", schedule_injective)
_reg.register_schedule("full", schedule_injective)
_reg.register_schedule("full_like", schedule_injective)
_reg.register_schedule("cast", schedule_broadcast) _reg.register_schedule("cast", schedule_broadcast)
_reg.register_schedule("strided_slice", schedule_injective) _reg.register_schedule("strided_slice", schedule_injective)
_reg.register_schedule("slice_like", schedule_injective) _reg.register_schedule("slice_like", schedule_injective)
......
...@@ -673,6 +673,14 @@ bool FullRel(const Array<Type>& types, ...@@ -673,6 +673,14 @@ bool FullRel(const Array<Type>& types,
return true; return true;
} }
Array<Tensor> FullCompute(const Attrs& attrs,
const Array<Tensor>& inputs,
const Type& out_type,
const Target& target) {
const auto* out_ttype = out_type.as<TensorTypeNode>();
return { topi::full(out_ttype->shape, out_ttype->dtype, inputs[0]()) };
}
Expr MakeFull(Expr fill_value, Expr MakeFull(Expr fill_value,
Array<IndexExpr> shape, Array<IndexExpr> shape,
DataType dtype) { DataType dtype) {
...@@ -696,7 +704,9 @@ RELAY_REGISTER_OP("full") ...@@ -696,7 +704,9 @@ RELAY_REGISTER_OP("full")
.set_num_inputs(1) .set_num_inputs(1)
.add_argument("fill_value", "double", "The value to fill.") .add_argument("fill_value", "double", "The value to fill.")
.set_support_level(3) .set_support_level(3)
.add_type_rel("Full", FullRel); .add_type_rel("Full", FullRel)
.set_attr<FTVMCompute>("FTVMCompute", FullCompute)
.set_attr<TOpPattern>("TOpPattern", kElemWise);
bool InitOpRel(const Array<Type>& types, bool InitOpRel(const Array<Type>& types,
int num_inputs, int num_inputs,
...@@ -777,6 +787,13 @@ bool FullLikeRel(const Array<Type>& types, ...@@ -777,6 +787,13 @@ bool FullLikeRel(const Array<Type>& types,
return true; return true;
} }
Array<Tensor> FullLikeCompute(const Attrs& attrs,
const Array<Tensor>& inputs,
const Type& out_type,
const Target& target) {
return { topi::full_like(inputs[0], inputs[1]()) };
}
Expr MakeFullLike(Expr data, Expr MakeFullLike(Expr data,
Expr fill_value) { Expr fill_value) {
static const Op& op = Op::Get("full_like"); static const Op& op = Op::Get("full_like");
...@@ -797,7 +814,9 @@ and type as the input array. ...@@ -797,7 +814,9 @@ and type as the input array.
.add_argument("data", "Tensor", "The input tensor.") .add_argument("data", "Tensor", "The input tensor.")
.add_argument("fill_value", "double", "Scalar value to fill.") .add_argument("fill_value", "double", "Scalar value to fill.")
.set_support_level(3) .set_support_level(3)
.add_type_rel("FullLike", FullLikeRel); .add_type_rel("FullLike", FullLikeRel)
.set_attr<FTVMCompute>("FTVMCompute", FullLikeCompute)
.set_attr<TOpPattern>("TOpPattern", kElemWise);
// where operator // where operator
bool WhereRel(const Array<Type>& types, bool WhereRel(const Array<Type>& types,
......
...@@ -293,7 +293,7 @@ def test_split_infer_type(): ...@@ -293,7 +293,7 @@ def test_split_infer_type():
relay.ty.TensorType((d1, (d2-7), d3, d4), "float32")])), relay.ty.TensorType((d1, (d2-7), d3, d4), "float32")])),
axis=1) axis=1)
def test_full(): def test_full_infer_type():
# default settings: match input dtype # default settings: match input dtype
x = relay.var("x", relay.TensorType((), "int8")) x = relay.var("x", relay.TensorType((), "int8"))
y = relay.full(x, ()) y = relay.full(x, ())
...@@ -308,7 +308,22 @@ def test_full(): ...@@ -308,7 +308,22 @@ def test_full():
assert yy.checked_type == relay.TensorType((1, 2), "int8") assert yy.checked_type == relay.TensorType((1, 2), "int8")
def test_full_like(): def test_full():
def verify_full(fill_value, src_shape, dtype):
x = relay.var("x", relay.scalar_type(dtype))
z = relay.full(x, src_shape, dtype)
func = relay.Function([x], z)
ref_res = np.full(src_shape, fill_value)
for target, ctx in ctx_list():
for kind in ["graph", "debug"]:
intrp = relay.create_executor(kind, ctx=ctx, target=target)
op_res = intrp.evaluate(func)(fill_value)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)
verify_full(4, (1, 3, 4, 4), "int32")
verify_full(4.0, (1, 4), "float32")
def test_full_like_infer_type():
# concrete shape # concrete shape
base = relay.var("base", relay.TensorType((1, 2, 3), "float32")) base = relay.var("base", relay.TensorType((1, 2, 3), "float32"))
fill = relay.var("fill", relay.TensorType((), "float32")) fill = relay.var("fill", relay.TensorType((), "float32"))
...@@ -324,6 +339,26 @@ def test_full_like(): ...@@ -324,6 +339,26 @@ def test_full_like():
yy = relay.ir_pass.infer_type(y) yy = relay.ir_pass.infer_type(y)
assert yy.checked_type == relay.TensorType((n, c, h, w), "float32") assert yy.checked_type == relay.TensorType((n, c, h, w), "float32")
def test_full_like():
def verify_full_like(base, fill_value, dtype):
x_data = np.random.uniform(low=-1, high=1, size=base).astype(dtype)
x = relay.var("x", relay.TensorType(base, dtype))
y = relay.var("y", relay.scalar_type(dtype))
z = relay.full_like(x, y)
func = relay.Function([x, y], z)
ref_res = np.full_like(x_data, fill_value)
for target, ctx in ctx_list():
for kind in ["graph", "debug"]:
intrp = relay.create_executor(kind, ctx=ctx, target=target)
op_res = intrp.evaluate(func)(x_data, fill_value)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)
verify_full_like((1, 3, 4, 4), 4, "int32")
verify_full_like((1, 1), 44.0, "float32")
def test_infer_type_leaky_relu(): def test_infer_type_leaky_relu():
n, c , h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w") n, c , h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w")
x = relay.var("x", relay.TensorType((n, c, h, w), "float32")) x = relay.var("x", relay.TensorType((n, c, h, w), "float32"))
...@@ -412,7 +447,9 @@ if __name__ == "__main__": ...@@ -412,7 +447,9 @@ if __name__ == "__main__":
test_reshape_like() test_reshape_like()
test_take_infer_type() test_take_infer_type()
test_take() test_take()
test_full_infer_type()
test_full() test_full()
test_full_like_infer_type()
test_full_like() test_full_like()
test_infer_type_leaky_relu() test_infer_type_leaky_relu()
test_infer_type_prelu() test_infer_type_prelu()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment