Commit 40f76825 by Haichen Shen Committed by Yizhi Liu

[Relay][OP] Add reverse_reshape (#2503)

* Enable reverse in reshape

* Fix lint and typo

* Put reverse_reshape into a separate op

* Fix pylint
parent 3a75b13d
...@@ -149,6 +149,7 @@ This level support backpropagation of broadcast operators. It is temporary. ...@@ -149,6 +149,7 @@ This level support backpropagation of broadcast operators. It is temporary.
tvm.relay.layout_transform tvm.relay.layout_transform
tvm.relay.device_copy tvm.relay.device_copy
tvm.relay.annotation.on_device tvm.relay.annotation.on_device
tvm.relay.reverse_reshape
Level 1 Definitions Level 1 Definitions
...@@ -257,4 +258,5 @@ Level 10 Definitions ...@@ -257,4 +258,5 @@ Level 10 Definitions
.. autofunction:: tvm.relay.slice_like .. autofunction:: tvm.relay.slice_like
.. autofunction:: tvm.relay.layout_transform .. autofunction:: tvm.relay.layout_transform
.. autofunction:: tvm.relay.device_copy .. autofunction:: tvm.relay.device_copy
.. autofunction:: tvm.relay.annotation.on_device .. autofunction:: tvm.relay.annotation.on_device
\ No newline at end of file .. autofunction:: tvm.relay.reverse_reshape
...@@ -63,9 +63,13 @@ struct TransposeAttrs : public tvm::AttrsNode<TransposeAttrs> { ...@@ -63,9 +63,13 @@ struct TransposeAttrs : public tvm::AttrsNode<TransposeAttrs> {
/*! \brief Attributes used in reshape operators */ /*! \brief Attributes used in reshape operators */
struct ReshapeAttrs : public tvm::AttrsNode<ReshapeAttrs> { struct ReshapeAttrs : public tvm::AttrsNode<ReshapeAttrs> {
Array<Integer> newshape; Array<Integer> newshape;
bool reverse;
TVM_DECLARE_ATTRS(ReshapeAttrs, "relay.attrs.ReshapeAttrs") { TVM_DECLARE_ATTRS(ReshapeAttrs, "relay.attrs.ReshapeAttrs") {
TVM_ATTR_FIELD(newshape) TVM_ATTR_FIELD(newshape)
.describe("The new shape. Should be compatible with the original shape."); .describe("The new shape. Should be compatible with the original shape.");
TVM_ATTR_FIELD(reverse)
.describe("Infer the special values from right to left if true")
.set_default(false);
} }
}; // struct ReshapeAttrs }; // struct ReshapeAttrs
......
...@@ -22,9 +22,10 @@ def _rename(new_op): ...@@ -22,9 +22,10 @@ def _rename(new_op):
def _reshape(inputs, attrs): def _reshape(inputs, attrs):
if attrs.get_bool("reverse", False):
raise RuntimeError("reshape do not support option reverse")
shape = attrs.get_int_tuple("shape") shape = attrs.get_int_tuple("shape")
reverse = attrs.get_bool("reverse", False)
if reverse:
return _op.reverse_reshape(inputs[0], newshape=shape)
return _op.reshape(inputs[0], newshape=shape) return _op.reshape(inputs[0], newshape=shape)
......
...@@ -26,6 +26,7 @@ _reg.register_schedule("split", schedule_injective) ...@@ -26,6 +26,7 @@ _reg.register_schedule("split", schedule_injective)
_reg.register_schedule("take", schedule_injective) _reg.register_schedule("take", schedule_injective)
_reg.register_schedule("transpose", schedule_injective) _reg.register_schedule("transpose", schedule_injective)
_reg.register_schedule("where", schedule_broadcast) _reg.register_schedule("where", schedule_broadcast)
_reg.register_schedule("_contrib_reverse_reshape", schedule_injective)
# layout_transform # layout_transform
_reg.register_schedule("layout_transform", schedule_injective) _reg.register_schedule("layout_transform", schedule_injective)
......
...@@ -141,7 +141,7 @@ def reshape(data, newshape): ...@@ -141,7 +141,7 @@ def reshape(data, newshape):
Example:: Example::
- data.shape = (2,3,4), newshape = (-4,1,2,-2), result.shape =(1,2,3,4) - data.shape = (2,3,4), newshape = (-4,1,2,-2), result.shape = (1,2,3,4)
- data.shape = (2,3,4), newshape = (2,-4,-1,3,-2), result.shape = (2,1,3,4) - data.shape = (2,3,4), newshape = (2,-4,-1,3,-2), result.shape = (2,1,3,4)
Parameters Parameters
...@@ -449,3 +449,34 @@ def layout_transform(data, src_layout, dst_layout): ...@@ -449,3 +449,34 @@ def layout_transform(data, src_layout, dst_layout):
The transformed tensor. The transformed tensor.
""" """
return _make.layout_transform(data, src_layout, dst_layout) return _make.layout_transform(data, src_layout, dst_layout)
def reverse_reshape(data, newshape):
"""Reshapes the input array where the special values are inferred from
right to left.
Example::
The special values have the same semantics as :py:class:`tvm.relay.reshape`.
The difference is that special values are inferred from right to left. It
can be explained in the example below::
- data.shape = (10,5,4), newshape = (-1,0), reshape results in (40,5)
- data.shape = (10,5,4), newshape = (-1,0), reverse_reshape results in (40,5)
Parameters
----------
data : relay.Expr
The input data to the operator.
newshape : Union[int, Tuple[int], List[int]]
The new shape. Should be compatible with the original shape.
Returns
-------
result : relay.Expr
The reshaped result.
"""
if isinstance(newshape, int):
newshape = [newshape]
return _make._contrib_reverse_reshape(data, list(newshape))
...@@ -382,20 +382,29 @@ bool ReshapeRel(const Array<Type>& types, ...@@ -382,20 +382,29 @@ bool ReshapeRel(const Array<Type>& types,
} }
const auto* param = attrs.as<ReshapeAttrs>(); const auto* param = attrs.as<ReshapeAttrs>();
Array<IndexExpr> data_shape;
Array<Integer> newshape;
if (param->reverse) {
data_shape.assign(data->shape.rbegin(), data->shape.rend());
newshape.assign(param->newshape.rbegin(), param->newshape.rend());
} else {
data_shape = data->shape;
newshape = param->newshape;
}
Array<IndexExpr> oshape; Array<IndexExpr> oshape;
size_t src_idx = 0; size_t src_idx = 0;
int infer_idx = -1; int infer_idx = -1;
for (size_t i = 0; i < param->newshape.size(); ++i) { for (size_t i = 0; i < newshape.size(); ++i) {
int svalue = param->newshape[i]->value; int svalue = newshape[i]->value;
// special flag handling for shape inference. // special flag handling for shape inference.
if (svalue > 0) { if (svalue > 0) {
oshape.push_back(param->newshape[i]); oshape.push_back(newshape[i]);
++src_idx; ++src_idx;
} else if (svalue == 0) { } else if (svalue == 0) {
// keep same // keep same
CHECK_LT(src_idx, data->shape.size()); CHECK_LT(src_idx, data_shape.size());
oshape.push_back(data->shape[src_idx++]); oshape.push_back(data_shape[src_idx++]);
} else if (svalue == -1) { } else if (svalue == -1) {
// inference based on rest // inference based on rest
CHECK_LT(infer_idx, 0) CHECK_LT(infer_idx, 0)
...@@ -405,42 +414,51 @@ bool ReshapeRel(const Array<Type>& types, ...@@ -405,42 +414,51 @@ bool ReshapeRel(const Array<Type>& types,
++src_idx; ++src_idx;
} else if (svalue == -2) { } else if (svalue == -2) {
// copy all remaining dims from source // copy all remaining dims from source
while (src_idx < data->shape.size()) { while (src_idx < data_shape.size()) {
oshape.push_back(data->shape[src_idx++]); oshape.push_back(data_shape[src_idx++]);
} }
} else if (svalue == -3) { } else if (svalue == -3) {
// merge two dims from source // merge two dims from source
CHECK_LT(src_idx + 1, data->shape.size()); CHECK_LT(src_idx + 1, data_shape.size());
IndexExpr d1 = data->shape[src_idx++]; IndexExpr d1 = data_shape[src_idx++];
IndexExpr d2 = data->shape[src_idx++]; IndexExpr d2 = data_shape[src_idx++];
oshape.push_back(d1 * d2); oshape.push_back(d1 * d2);
} else if (svalue == -4) { } else if (svalue == -4) {
// split the source dim s into two dims // split the source dim s into two dims
// read the left dim and then the right dim (either can be -1) // read the left dim and then the right dim (either can be -1)
CHECK_LT(i + 2, param->newshape.size()); CHECK_LT(i + 2, newshape.size());
CHECK_LT(src_idx, data->shape.size()); CHECK_LT(src_idx, data_shape.size());
IndexExpr d0 = data->shape[src_idx++]; IndexExpr d0 = data_shape[src_idx++];
Integer d1 = param->newshape[++i]; Integer d1 = newshape[++i];
Integer d2 = param->newshape[++i]; Integer d2 = newshape[++i];
if (d1->value == -1) { if (d1->value == -1) {
CHECK(d2->value != -1) CHECK(d2->value != -1)
<< "Split dims cannot both be -1."; << "Split dims cannot both be -1.";
oshape.push_back(d0 / d2); oshape.push_back(d0 / d2);
oshape.push_back(d2); oshape.push_back(d2);
} else { } else {
CHECK_EQ(d2->value, -1);
oshape.push_back(d1); oshape.push_back(d1);
oshape.push_back(d0 / d1); if (d2->value == -1) {
oshape.push_back(d0 / d1);
} else {
oshape.push_back(d2);
}
} }
} }
} }
if (infer_idx >= 0) { if (infer_idx >= 0) {
IndexExpr new_size = arith::ComputeReduce<tvm::ir::Mul>(oshape, 1); IndexExpr new_size = arith::ComputeReduce<tvm::ir::Mul>(oshape, 1);
IndexExpr old_size = arith::ComputeReduce<tvm::ir::Mul>(data->shape, 1); IndexExpr old_size = arith::ComputeReduce<tvm::ir::Mul>(data_shape, 1);
oshape.Set(infer_idx, old_size / new_size); oshape.Set(infer_idx, old_size / new_size);
} }
reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype));
if (param->reverse) {
reporter->Assign(types[1], TensorTypeNode::make(
Array<IndexExpr>(oshape.rbegin(), oshape.rend()), data->dtype));
} else {
reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype));
}
return true; return true;
} }
...@@ -457,6 +475,7 @@ Expr MakeReshape(Expr data, ...@@ -457,6 +475,7 @@ Expr MakeReshape(Expr data,
Array<Integer> newshape) { Array<Integer> newshape) {
auto attrs = make_node<ReshapeAttrs>(); auto attrs = make_node<ReshapeAttrs>();
attrs->newshape = std::move(newshape); attrs->newshape = std::move(newshape);
attrs->reverse = false;
static const Op& op = Op::Get("reshape"); static const Op& op = Op::Get("reshape");
return CallNode::make(op, {data}, Attrs(attrs), {}); return CallNode::make(op, {data}, Attrs(attrs), {});
} }
...@@ -1699,5 +1718,43 @@ the input array by output[n, c, h, w, C] = data[n, C*16+c, h, w] ...@@ -1699,5 +1718,43 @@ the input array by output[n, c, h, w, C] = data[n, C*16+c, h, w]
.set_support_level(5) .set_support_level(5)
.set_attr<FTVMCompute>("FTVMCompute", LayoutTransformCompute); .set_attr<FTVMCompute>("FTVMCompute", LayoutTransformCompute);
/* relay._contrib_reverse_reshape */
Expr MakeReverseReshape(Expr data,
Array<Integer> newshape) {
auto attrs = make_node<ReshapeAttrs>();
attrs->newshape = std::move(newshape);
attrs->reverse = true;
static const Op& op = Op::Get("_contrib_reverse_reshape");
return CallNode::make(op, {data}, Attrs(attrs), {});
}
TVM_REGISTER_API("relay.op._make._contrib_reverse_reshape")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 2>(MakeReverseReshape, args, rv);
});
RELAY_REGISTER_OP("_contrib_reverse_reshape")
.describe(R"code(Reshapes the input array where the special values are inferred from
right to left.
Example::
The special values have the same semantics as reshape. The difference is that
special values are inferred from right to left. It can be explained in the
example below::
- data.shape = (10,5,4), newshape = (-1,0), reshape results in (40,5)
- data.shape = (10,5,4), newshape = (-1,0), reverse_reshape results in (40,5)
)code" TVM_ADD_FILELINE)
.set_num_inputs(1)
.set_attrs_type_key("relay.attrs.ReshapeAttrs")
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(10)
.add_type_rel("Reshape", ReshapeRel)
.set_attr<FTVMCompute>("FTVMCompute", ReshapeCompute)
.set_attr<TOpPattern>("TOpPattern", kInjective);
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
...@@ -121,8 +121,31 @@ def test_slice_like(): ...@@ -121,8 +121,31 @@ def test_slice_like():
axes=(2, 3), axes=(2, 3),
output=(1, 3, 112, 112)) output=(1, 3, 112, 112))
def test_reverse_reshape():
def verify_reverse_reshape(shape, newshape, oshape):
x = relay.var("x", relay.TensorType(shape, "float32"))
z = relay.reverse_reshape(x, newshape=newshape)
zz = relay.ir_pass.infer_type(z)
print(zz.checked_type)
assert "newshape=" in z.astext()
assert zz.checked_type == relay.ty.TensorType(oshape, "float32")
func = relay.Function([x], z)
x_data = np.random.uniform(low=-1, high=1, size=shape).astype("float32")
ref_res = np.reshape(x_data, oshape)
for target, ctx in ctx_list():
for kind in ["graph", "debug"]:
intrp = relay.create_executor(kind, ctx=ctx, target=target)
op_res = intrp.evaluate(func)(x_data)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)
verify_reverse_reshape((2, 3, 4), (4, 0, 2), (4, 3, 2))
verify_reverse_reshape((2, 3, 4), (2, 0, 0), (2, 3, 4))
verify_reverse_reshape((2, 3, 4), (0, -1), (3, 8))
verify_reverse_reshape((2, 3, 4), (-1, 0), (6, 4))
verify_reverse_reshape((2, 3, 4), (0, -3), (2, 12))
if __name__ == "__main__": if __name__ == "__main__":
test_collapse_sum_like() test_collapse_sum_like()
test_broadcast_to_like() test_broadcast_to_like()
test_slice_like() test_slice_like()
test_reverse_reshape()
...@@ -152,25 +152,36 @@ def test_reshape_infer_type(): ...@@ -152,25 +152,36 @@ def test_reshape_infer_type():
(n, t, 2000), "float32") (n, t, 2000), "float32")
def test_reshape(): def test_reshape():
def verify_reshape(shape, oshape): def verify_reshape(shape, newshape, oshape):
x_data = np.random.uniform(low=-1, high=1, size=shape).astype("float32")
ref_res = np.reshape(x_data, oshape)
x = relay.var("x", relay.TensorType(shape, "float32")) x = relay.var("x", relay.TensorType(shape, "float32"))
z = relay.reshape(x, newshape=ref_res.shape) z = relay.reshape(x, newshape=newshape)
zz = relay.ir_pass.infer_type(z) zz = relay.ir_pass.infer_type(z)
assert "newshape=" in z.astext() assert "newshape=" in z.astext()
assert zz.checked_type == relay.ty.TensorType(oshape, "float32") assert zz.checked_type == relay.ty.TensorType(oshape, "float32")
func = relay.Function([x], z) func = relay.Function([x], z)
x_data = np.random.uniform(low=-1, high=1, size=shape).astype("float32")
ref_res = np.reshape(x_data, oshape)
for target, ctx in ctx_list(): for target, ctx in ctx_list():
for kind in ["graph", "debug"]: for kind in ["graph", "debug"]:
intrp = relay.create_executor(kind, ctx=ctx, target=target) intrp = relay.create_executor(kind, ctx=ctx, target=target)
op_res = intrp.evaluate(func)(x_data) op_res = intrp.evaluate(func)(x_data)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5) tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)
verify_reshape((2, 3, 4), (8, 3)) verify_reshape((2, 3, 4), (8, 3), (8, 3))
verify_reshape((4, 7), (2, 7, 2)) verify_reshape((4, 7), (2, 7, 2), (2, 7, 2))
verify_reshape((2, 3, 4), (4, 0, 2), (4, 3, 2))
verify_reshape((2, 3, 4), (2, 0, 0), (2, 3, 4))
verify_reshape((2, 3, 4), (0, -1), (2, 12))
verify_reshape((2, 3, 4), (-1, 0), (8, 3))
verify_reshape((2, 3, 4), (2, -2), (2, 3, 4))
verify_reshape((2, 3, 4), (-2, 1, 1), (2, 3, 4, 1, 1))
verify_reshape((2, 3, 4), (-3, 4), (6, 4))
verify_reshape((2, 3, 4, 5), (-3, -3), (6, 20))
verify_reshape((2, 3, 4), (0, -3), (2, 12))
verify_reshape((2, 3, 4), (-3, -2), (6, 4))
verify_reshape((2, 3, 4), (-4, 1, 2, -2), (1, 2, 3, 4))
verify_reshape((2, 3, 4), (2, -4, -1, 3, -2), (2, 1, 3, 4))
def test_reshape_like_infer_type(): def test_reshape_like_infer_type():
# concrete shape # concrete shape
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment