Commit 62a94c76 by Siju Committed by Tianqi Chen

[RELAY]reshape_like (#1950)

parent 4fbb7c89
...@@ -78,6 +78,7 @@ This level enables additional math and transform operators. ...@@ -78,6 +78,7 @@ This level enables additional math and transform operators.
tvm.relay.ones tvm.relay.ones
tvm.relay.ones_like tvm.relay.ones_like
tvm.relay.reshape tvm.relay.reshape
tvm.relay.reshape_like
tvm.relay.copy tvm.relay.copy
tvm.relay.transpose tvm.relay.transpose
tvm.relay.floor tvm.relay.floor
...@@ -189,6 +190,7 @@ Level 3 Definitions ...@@ -189,6 +190,7 @@ Level 3 Definitions
.. autofunction:: tvm.relay.abs .. autofunction:: tvm.relay.abs
.. autofunction:: tvm.relay.negative .. autofunction:: tvm.relay.negative
.. autofunction:: tvm.relay.reshape .. autofunction:: tvm.relay.reshape
.. autofunction:: tvm.relay.reshape_like
.. autofunction:: tvm.relay.copy .. autofunction:: tvm.relay.copy
.. autofunction:: tvm.relay.transpose .. autofunction:: tvm.relay.transpose
.. autofunction:: tvm.relay.take .. autofunction:: tvm.relay.take
......
...@@ -82,6 +82,11 @@ class TensorTypeNode : public BaseTensorTypeNode { ...@@ -82,6 +82,11 @@ class TensorTypeNode : public BaseTensorTypeNode {
v->Visit("span", &span); v->Visit("span", &span);
} }
/*! \brief Return product of elements in the shape.
* \return (d1 * d_2 ... * d_n) if shape is (d_1, d_2, ..., d_n) and 1 if shape size is zero.
*/
TVM_DLL IndexExpr Size() const;
TVM_DLL static TensorType make(Array<IndexExpr> shape, DataType dtype); TVM_DLL static TensorType make(Array<IndexExpr> shape, DataType dtype);
/*! \brief Construct an scalar containing elements of dtype. */ /*! \brief Construct an scalar containing elements of dtype. */
......
...@@ -142,6 +142,29 @@ def reshape(data, newshape): ...@@ -142,6 +142,29 @@ def reshape(data, newshape):
return _make.reshape(data, list(newshape)) return _make.reshape(data, list(newshape))
def reshape_like(data, shape_like):
"""Reshapes the input array by the size of another array.
For an input array with shape ``(d1, d2, ..., dk)``, `reshape_like` operation reshapes
the input array into an output array with the same shape as the second input array.
.. note::
Sizes for both array should be compatible.
Parameters
----------
data : relay.Expr
The input data to the operator.
shape_like : tuple of int
The new shape. Should be compatible with the original shape.
Returns
-------
ret : relay.Expr
The computed result.
"""
return _make.reshape_like(data, shape_like)
def take(data, indices, axis=None): def take(data, indices, axis=None):
"""Take elements from an array along an axis. """Take elements from an array along an axis.
......
...@@ -22,6 +22,18 @@ TensorType TensorTypeNode::Scalar(DataType dtype) { ...@@ -22,6 +22,18 @@ TensorType TensorTypeNode::Scalar(DataType dtype) {
return TensorTypeNode::make({}, dtype); return TensorTypeNode::make({}, dtype);
} }
IndexExpr TensorTypeNode::Size() const {
if (shape.size() == 0) {
return make_const(Int(64), 1);
}
IndexExpr size = shape[0];
for (size_t i = 1; i < shape.size(); ++i) {
size *= shape[i];
}
return size;
}
TVM_REGISTER_NODE_TYPE(TensorTypeNode); TVM_REGISTER_NODE_TYPE(TensorTypeNode);
TVM_REGISTER_API("relay._make.TensorType") TVM_REGISTER_API("relay._make.TensorType")
......
...@@ -377,6 +377,62 @@ Example:: ...@@ -377,6 +377,62 @@ Example::
.set_support_level(3) .set_support_level(3)
.add_type_rel("Reshape", ReshapeRel); .add_type_rel("Reshape", ReshapeRel);
/*!
* \brief ReshapeLikeRel User defined type constraint function.
* \param num_inputs Number of input types in the args.
* \param attrs The additional attributes of the operator.
* \param reporter The reporter to report solution to.
* \return False if the relation has not been resolved, it might be resolved later.
* True if this relation has been resolved.
*/
bool ReshapeLikeRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 3);
const auto* data = types[0].as<TensorTypeNode>();
if (data == nullptr) {
return false;
}
const auto* reshape_like = types[1].as<TensorTypeNode>();
if (reshape_like == nullptr) {
return false;
}
CHECK(reporter->AssertEQ(data->Size(), reshape_like->Size()))
<< "Reshape inputs size should be compatible.";
reporter->Assign(types[2], TensorTypeNode::make(reshape_like->shape, data->dtype));
return true;
}
Expr MakeReshapeLike(Expr data,
Expr shape_like) {
static const Op& op = Op::Get("reshape_like");
return CallNode::make(op, {data, shape_like}, Attrs(), {});
}
TVM_REGISTER_API("relay.op._make.reshape_like")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 2>(MakeReshapeLike, args, rv);
});
RELAY_REGISTER_OP("reshape_like")
.describe(R"code(Reshapes the input array by the size of another array.
For an input array with shape ``(d1, d2, ..., dk)``, `reshape_like` operation reshapes
the input array into an output array with the same shape as the second input array.
.. note::
Sizes for both array should be compatible.
)code" TVM_ADD_FILELINE)
.set_num_inputs(2)
.add_argument("data", "Tensor", "The input tensor.")
.add_argument("shape_like", "Tensor", "Shape tensor.")
.set_support_level(3)
.add_type_rel("ReshapeLike", ReshapeLikeRel);
// Take // Take
TVM_REGISTER_NODE_TYPE(TakeAttrs); TVM_REGISTER_NODE_TYPE(TakeAttrs);
......
...@@ -88,6 +88,22 @@ def test_reshape_infer_type(): ...@@ -88,6 +88,22 @@ def test_reshape_infer_type():
(n, t, 2000), "float32") (n, t, 2000), "float32")
def test_reshape_like():
# concrete shape
x = relay.var("x", relay.TensorType((1, 2, 3), "float32"))
y = relay.var("y", relay.TensorType((1,6), "float32"))
z = relay.reshape_like(x, y)
zz = relay.ir_pass.infer_type(z)
assert zz.checked_type == relay.TensorType((1, 6), "float32")
# symbolic shape
n, c, h, w = tvm.var("n"), 2, 3, tvm.var("w")
x = relay.var("x", relay.TensorType((n, c, h, w), "float32"))
y = relay.var("y", relay.TensorType((1, 8, 8), "float32"))
z = relay.reshape_like(x, y)
zz = relay.ir_pass.infer_type(z)
assert zz.checked_type == relay.TensorType((1, 8, 8), "float32")
def test_take_infer_type(): def test_take_infer_type():
def verify_take(dshape, indices_shape, oshape, axis=None): def verify_take(dshape, indices_shape, oshape, axis=None):
...@@ -187,6 +203,7 @@ if __name__ == "__main__": ...@@ -187,6 +203,7 @@ if __name__ == "__main__":
test_clip_type() test_clip_type()
test_transpose_infer_type() test_transpose_infer_type()
test_reshape_infer_type() test_reshape_infer_type()
test_reshape_like()
test_take_infer_type() test_take_infer_type()
test_full() test_full()
test_full_like() test_full_like()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment