Commit 0cf3ddf7 by Siju, committed by Tianqi Chen

Relay reshape reshape_like compute and schedule (#2159)

parent ccd33332
...@@ -53,3 +53,11 @@ _reg.register_schedule("strided_slice", schedule_injective) ...@@ -53,3 +53,11 @@ _reg.register_schedule("strided_slice", schedule_injective)
# slice_like # slice_like
_reg.register_schedule("slice_like", schedule_injective) _reg.register_schedule("slice_like", schedule_injective)
_reg.register_pattern("slice_like", OpPattern.INJECTIVE) _reg.register_pattern("slice_like", OpPattern.INJECTIVE)
# reshape: pure data-movement op, so the generic injective schedule is reused
_reg.register_schedule("reshape", schedule_injective)
_reg.register_pattern("reshape", OpPattern.INJECTIVE)
# reshape_like: same as reshape, but the target shape comes from a second tensor
_reg.register_schedule("reshape_like", schedule_injective)
_reg.register_pattern("reshape_like", OpPattern.INJECTIVE)
...@@ -376,7 +376,15 @@ Example:: ...@@ -376,7 +376,15 @@ Example::
.set_attrs_type_key("relay.attrs.ReshapeAttrs") .set_attrs_type_key("relay.attrs.ReshapeAttrs")
.add_argument("data", "Tensor", "The input tensor.") .add_argument("data", "Tensor", "The input tensor.")
.set_support_level(3) .set_support_level(3)
.add_type_rel("Reshape", ReshapeRel); .add_type_rel("Reshape", ReshapeRel)
.set_attr<FTVMCompute>("FTVMCompute", [](const Attrs& attrs,
const Array<Tensor>& inputs,
const Type& out_type,
const Target& target) {
const auto* param = attrs.as<ReshapeAttrs>();
CHECK(param != nullptr);
return Array<Tensor>{ topi::reshape(inputs[0], param->newshape) };
});
/*! /*!
...@@ -431,7 +439,13 @@ the input array into an output array with the same shape as the second input arr ...@@ -431,7 +439,13 @@ the input array into an output array with the same shape as the second input arr
.add_argument("data", "Tensor", "The input tensor.") .add_argument("data", "Tensor", "The input tensor.")
.add_argument("shape_like", "Tensor", "Shape tensor.") .add_argument("shape_like", "Tensor", "Shape tensor.")
.set_support_level(3) .set_support_level(3)
.add_type_rel("ReshapeLike", ReshapeLikeRel); .add_type_rel("ReshapeLike", ReshapeLikeRel)
.set_attr<FTVMCompute>("FTVMCompute", [](const Attrs& attrs,
const Array<Tensor>& inputs,
const Type& out_type,
const Target& target) {
return Array<Tensor>{ topi::reshape(inputs[0], inputs[1]->shape) };
});
// Take // Take
......
...@@ -123,8 +123,28 @@ def test_reshape_infer_type(): ...@@ -123,8 +123,28 @@ def test_reshape_infer_type():
assert yy.checked_type == relay.TensorType( assert yy.checked_type == relay.TensorType(
(n, t, 2000), "float32") (n, t, 2000), "float32")
def test_reshape():
    """End-to-end test for relay.reshape: type inference plus numerical check.

    For each (input shape, output shape) pair, builds a reshape expression,
    verifies the inferred type matches the requested shape, then executes it
    on every available (target, ctx) pair and compares against numpy.reshape.
    """
    def verify_reshape(shape, oshape):
        # Use numpy's reshape of random data as the reference result.
        x_data = np.random.uniform(low=-1, high=1, size=shape).astype("float32")
        ref_res = np.reshape(x_data, oshape)
        x = relay.var("x", relay.TensorType(shape, "float32"))
        z = relay.reshape(x, newshape=ref_res.shape)
        zz = relay.ir_pass.infer_type(z)
        # The newshape attribute must be recorded in the printed IR.
        assert "newshape=" in z.astext()
        assert zz.checked_type == relay.ty.TensorType(oshape, "float32")
        func = relay.Function([x], z)
        # Run under both the graph runtime and the debug interpreter.
        for target, ctx in ctx_list():
            for kind in ["graph", "debug"]:
                intrp = relay.create_executor(kind, ctx=ctx, target=target)
                op_res = intrp.evaluate(func)(x_data)
                tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)
    # Both cases preserve the element count (24 and 28 respectively).
    verify_reshape((2, 3, 4), (8, 3))
    verify_reshape((4, 7), (2, 7, 2))
def test_reshape_like_infer_type():
# concrete shape # concrete shape
x = relay.var("x", relay.TensorType((1, 2, 3), "float32")) x = relay.var("x", relay.TensorType((1, 2, 3), "float32"))
y = relay.var("y", relay.TensorType((1,6), "float32")) y = relay.var("y", relay.TensorType((1,6), "float32"))
...@@ -141,6 +161,29 @@ def test_reshape_like(): ...@@ -141,6 +161,29 @@ def test_reshape_like():
assert zz.checked_type == relay.TensorType((1, 8, 8), "float32") assert zz.checked_type == relay.TensorType((1, 8, 8), "float32")
def test_reshape_like():
    """End-to-end test for relay.reshape_like: type inference plus numerics.

    reshape_like reshapes ``data`` to the shape of a second tensor, so the
    reference result is numpy.reshape of the data to the second tensor's shape.
    """
    def verify_reshape_like(shape, oshape):
        x_data = np.random.uniform(low=-1, high=1, size=shape).astype("float32")
        y_data = np.random.uniform(low=-1, high=1, size=oshape).astype("float32")
        ref_res = np.reshape(x_data, y_data.shape)
        x = relay.var("x", relay.TensorType(shape, "float32"))
        # Fixed: the original created this var with the duplicate name hint "x".
        y = relay.var("y", relay.TensorType(oshape, "float32"))
        z = relay.reshape_like(x, y)
        zz = relay.ir_pass.infer_type(z)
        # The inferred type must match the shape of the second argument.
        assert zz.checked_type == relay.ty.TensorType(ref_res.shape, "float32")
        func = relay.Function([x, y], z)
        # Run under both the graph runtime and the debug interpreter.
        for target, ctx in ctx_list():
            for kind in ["graph", "debug"]:
                intrp = relay.create_executor(kind, ctx=ctx, target=target)
                op_res = intrp.evaluate(func)(x_data, y_data)
                tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)
    # Both cases preserve the element count (24 and 28 respectively).
    verify_reshape_like((2, 3, 4), (1, 8, 3))
    verify_reshape_like((4, 7), (2, 7, 2))
def test_take_infer_type(): def test_take_infer_type():
def verify_take(dshape, indices_shape, oshape, axis=None): def verify_take(dshape, indices_shape, oshape, axis=None):
x = relay.var("x", relay.TensorType(dshape, "float32")) x = relay.var("x", relay.TensorType(dshape, "float32"))
...@@ -318,6 +361,8 @@ if __name__ == "__main__": ...@@ -318,6 +361,8 @@ if __name__ == "__main__":
test_clip() test_clip()
test_transpose_infer_type() test_transpose_infer_type()
test_reshape_infer_type() test_reshape_infer_type()
test_reshape()
test_reshape_like_infer_type()
test_reshape_like() test_reshape_like()
test_take_infer_type() test_take_infer_type()
test_full() test_full()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment