Commit ee805806 by Haichen Shen Committed by Leyuan Wang

[Relay/TOPI][Op] Add shape op in Relay and TOPI (#2749)

* Add shapeof op in topi

* Add relay shape_of op

* Add constant folding for shape_of

* Allow shape op to specify dtype

* Add mxnet converter for shape_array

* lint

* lint

* Add doc
parent 4d09fc4e
...@@ -75,6 +75,7 @@ List of operators ...@@ -75,6 +75,7 @@ List of operators
topi.stack topi.stack
topi.repeat topi.repeat
topi.tile topi.tile
topi.shape
topi.layout_transform topi.layout_transform
topi.image.resize topi.image.resize
...@@ -136,6 +137,7 @@ topi ...@@ -136,6 +137,7 @@ topi
.. autofunction:: topi.stack .. autofunction:: topi.stack
.. autofunction:: topi.repeat .. autofunction:: topi.repeat
.. autofunction:: topi.tile .. autofunction:: topi.tile
.. autofunction:: topi.shape
.. autofunction:: topi.layout_transform .. autofunction:: topi.layout_transform
topi.nn topi.nn
......
...@@ -155,6 +155,7 @@ This level support backpropagation of broadcast operators. It is temporary. ...@@ -155,6 +155,7 @@ This level support backpropagation of broadcast operators. It is temporary.
tvm.relay.broadcast_to_like tvm.relay.broadcast_to_like
tvm.relay.collapse_sum_like tvm.relay.collapse_sum_like
tvm.relay.slice_like tvm.relay.slice_like
tvm.relay.shape_of
tvm.relay.layout_transform tvm.relay.layout_transform
tvm.relay.device_copy tvm.relay.device_copy
tvm.relay.annotation.on_device tvm.relay.annotation.on_device
...@@ -275,6 +276,7 @@ Level 10 Definitions ...@@ -275,6 +276,7 @@ Level 10 Definitions
.. autofunction:: tvm.relay.broadcast_to_like .. autofunction:: tvm.relay.broadcast_to_like
.. autofunction:: tvm.relay.collapse_sum_like .. autofunction:: tvm.relay.collapse_sum_like
.. autofunction:: tvm.relay.slice_like .. autofunction:: tvm.relay.slice_like
.. autofunction:: tvm.relay.shape_of
.. autofunction:: tvm.relay.layout_transform .. autofunction:: tvm.relay.layout_transform
.. autofunction:: tvm.relay.device_copy .. autofunction:: tvm.relay.device_copy
.. autofunction:: tvm.relay.annotation.on_device .. autofunction:: tvm.relay.annotation.on_device
......
...@@ -226,6 +226,7 @@ struct ClipAttrs : public tvm::AttrsNode<ClipAttrs> { ...@@ -226,6 +226,7 @@ struct ClipAttrs : public tvm::AttrsNode<ClipAttrs> {
} }
}; };
/*! \brief Attributes for LayoutTransform operator */
struct LayoutTransformAttrs : public tvm::AttrsNode<LayoutTransformAttrs> { struct LayoutTransformAttrs : public tvm::AttrsNode<LayoutTransformAttrs> {
std::string src_layout; std::string src_layout;
std::string dst_layout; std::string dst_layout;
...@@ -238,6 +239,17 @@ struct LayoutTransformAttrs : public tvm::AttrsNode<LayoutTransformAttrs> { ...@@ -238,6 +239,17 @@ struct LayoutTransformAttrs : public tvm::AttrsNode<LayoutTransformAttrs> {
} }
}; };
/*!
 * \brief Attributes for ShapeOf operator
 *
 * Holds the single option of the op: the element type of the tensor that
 * carries the shape.
 */
struct ShapeOfAttrs : public tvm::AttrsNode<ShapeOfAttrs> {
  /*! \brief Element type requested for the resulting shape tensor. */
  DataType dtype;

  TVM_DECLARE_ATTRS(ShapeOfAttrs, "relay.attrs.ShapeOfAttrs") {
    // The declared default is the null DataType, i.e. no concrete type is
    // chosen here; callers are expected to fill the field in.
    TVM_ATTR_FIELD(dtype)
        .describe("Target data type")
        .set_default(NullValue<DataType>());
  }
};
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
#endif // TVM_RELAY_ATTRS_TRANSFORM_H_ #endif // TVM_RELAY_ATTRS_TRANSFORM_H_
...@@ -495,6 +495,19 @@ def _mx_l2_normalize(inputs, attrs): ...@@ -495,6 +495,19 @@ def _mx_l2_normalize(inputs, attrs):
return _op.nn.l2_normalize(inputs[0], **new_attrs) return _op.nn.l2_normalize(inputs[0], **new_attrs)
def _mx_shape_array(inputs, attrs):
assert len(inputs) == 1
if attrs.get_int("lhs_begin", None) is not None:
raise RuntimeError("shape_array doesn't support lhs_begin")
if attrs.get_int("lhs_end", None) is not None:
raise RuntimeError("shape_array doesn't support lhs_end")
if attrs.get_int("rhs_begin", None) is not None:
raise RuntimeError("shape_array doesn't support rhs_begin")
if attrs.get_int("rhs_end", None) is not None:
raise RuntimeError("shape_array doesn't support rhs_end")
return _op.shape_of(inputs[0], dtype='int64')
# Note: due to attribute conversion constraint # Note: due to attribute conversion constraint
# ops in the identity set must be attribute free # ops in the identity set must be attribute free
_identity_list = [ _identity_list = [
...@@ -621,6 +634,7 @@ _convert_map = { ...@@ -621,6 +634,7 @@ _convert_map = {
"tile" : _mx_tile, "tile" : _mx_tile,
"reverse" : _mx_reverse, "reverse" : _mx_reverse,
"BlockGrad" : _mx_BlockGrad, "BlockGrad" : _mx_BlockGrad,
"shape_array" : _mx_shape_array,
"SoftmaxOutput" : _mx_softmax_output, "SoftmaxOutput" : _mx_softmax_output,
"SoftmaxActivation" : _mx_softmax_activation, "SoftmaxActivation" : _mx_softmax_activation,
# vision # vision
......
...@@ -40,6 +40,7 @@ register_schedule("maximum", schedule_injective) ...@@ -40,6 +40,7 @@ register_schedule("maximum", schedule_injective)
register_schedule("minimum", schedule_injective) register_schedule("minimum", schedule_injective)
register_schedule("right_shift", schedule_injective) register_schedule("right_shift", schedule_injective)
register_schedule("left_shift", schedule_injective) register_schedule("left_shift", schedule_injective)
register_schedule("shape_of", schedule_injective)
# zeros # zeros
@register_compute("zeros") @register_compute("zeros")
......
...@@ -713,3 +713,22 @@ def device_copy(data, src_dev, dst_dev): ...@@ -713,3 +713,22 @@ def device_copy(data, src_dev, dst_dev):
raise ValueError("dst_dev is expected to be the type of TVMContext or " raise ValueError("dst_dev is expected to be the type of TVMContext or "
"str, but received %s" % (type(dst_dev))) "str, but received %s" % (type(dst_dev)))
return _make.device_copy(data, src_dev, dst_dev) return _make.device_copy(data, src_dev, dst_dev)
def shape_of(data, dtype="int32"):
    """Get shape of a tensor.

    Parameters
    ----------
    data : tvm.relay.Expr
        The input tensor.

    dtype : str, optional
        The target data type of the returned shape tensor.
        Defaults to "int32".

    Returns
    -------
    result : tvm.relay.Expr
        The shape tensor.
    """
    return _make.shape_of(data, dtype)
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
#include <tvm/relay/op.h> #include <tvm/relay/op.h>
#include <tvm/relay/attrs/transform.h> #include <tvm/relay/attrs/transform.h>
#include <topi/elemwise.h> #include <topi/elemwise.h>
#include <topi/transform.h>
#include "../type_relations.h" #include "../type_relations.h"
#include "../op_common.h" #include "../op_common.h"
...@@ -189,5 +190,56 @@ RELAY_REGISTER_UNARY_OP("logical_not") ...@@ -189,5 +190,56 @@ RELAY_REGISTER_UNARY_OP("logical_not")
.set_support_level(4) .set_support_level(4)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::logical_not)); .set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::logical_not));
// shape_of
// Register the attribute node type used by the shape_of operator below.
TVM_REGISTER_NODE_TYPE(ShapeOfAttrs);
/*!
 * \brief Type relation of shape_of.
 *
 * The output is a rank-1 tensor whose single extent equals the rank of the
 * input tensor and whose element type is ShapeOfAttrs::dtype.
 *
 * \param types The input type followed by the output type.
 * \param num_inputs Number of inputs; must be 1.
 * \param attrs Attributes of the call; must be ShapeOfAttrs.
 * \param reporter Reporter used to assign the output type.
 * \return true once the output type has been assigned, false when the input
 *         type is not (yet) a tensor type and the relation cannot be solved.
 */
bool ShapeOfRel(const Array<Type>& types,
                int num_inputs,
                const Attrs& attrs,
                const TypeReporter& reporter) {
  CHECK_EQ(num_inputs, 1);
  const auto* tt = types[0].as<TensorTypeNode>();
  if (tt == nullptr) {
    // The input type has not been resolved to a tensor type at this point.
    // Report "not solved" instead of aborting so the relation can be
    // revisited once more type information is available.
    return false;
  }
  const auto* param = attrs.as<ShapeOfAttrs>();
  CHECK(param != nullptr);
  // One output element per input dimension.
  auto vector_out = tvm::Integer(tt->shape.size());
  reporter->Assign(types[1], TensorTypeNode::make({ vector_out }, param->dtype));
  return true;
}
/*!
 * \brief Compute function of shape_of.
 *
 * Forwards the single input tensor and the dtype stored in ShapeOfAttrs to
 * topi::shape and wraps the result in a one-element array.
 */
Array<Tensor> ShapeOfCompute(const Attrs& attrs,
                             const Array<Tensor>& inputs,
                             const Type& out_type,
                             const Target& target) {
  CHECK_EQ(inputs.size(), 1);
  const auto* param = attrs.as<ShapeOfAttrs>();
  CHECK(param != nullptr);
  const Tensor shape_tensor = topi::shape(inputs[0], param->dtype);
  return Array<Tensor>{shape_tensor};
}
// Registers the function "relay.op._make.shape_of", which builds a call to
// the shape_of operator from an input expression and a target dtype.
TVM_REGISTER_API("relay.op._make.shape_of")
.set_body_typed<Expr(Expr, DataType)>([](Expr data, DataType dtype) {
    auto attrs = make_node<ShapeOfAttrs>();
    // The requested dtype travels with the call through its attributes.
    attrs->dtype = dtype;
    // Looked up once and reused for every call of this function.
    static const Op& op = Op::Get("shape_of");
    return CallNode::make(op, {data}, Attrs(attrs), {});
  });
// Operator registration for shape_of: one tensor input, attributes of type
// ShapeOfAttrs, typed by ShapeOfRel and computed by ShapeOfCompute.
RELAY_REGISTER_OP("shape_of")
.describe(R"code(Returns a tensor representing the shape of a tensor.
)code" TVM_ADD_FILELINE)
.set_num_inputs(1)
.set_attrs_type_key("relay.attrs.ShapeOfAttrs")
.add_argument("data", "Tensor", "The input tensor.")
.add_type_rel("ShapeOf", ShapeOfRel)
// Declared stateless and injective.
.set_attr<TOpIsStateful>("TOpIsStateful", false)
.set_attr<TOpPattern>("TOpPattern", kInjective)
// NOTE(review): the result of shape_of depends on the order of the input's
// dimensions; confirm that accepting an arbitrary layout is intended here.
.set_attr<FInferCorrectLayout>("FInferCorrectLayout",
                               ElemwiseArbitraryLayout)
.set_support_level(10)
.set_attr<FTVMCompute>("FTVMCompute", ShapeOfCompute);
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
#include <tvm/relay/expr_functor.h> #include <tvm/relay/expr_functor.h>
#include <tvm/relay/op_attr_types.h> #include <tvm/relay/op_attr_types.h>
#include <tvm/relay/interpreter.h> #include <tvm/relay/interpreter.h>
#include <tvm/relay/attrs/transform.h>
namespace tvm { namespace tvm {
namespace relay { namespace relay {
...@@ -71,6 +72,7 @@ class ConstantFolder : public ExprMutator { ...@@ -71,6 +72,7 @@ class ConstantFolder : public ExprMutator {
Expr VisitExpr_(const CallNode* call) final { Expr VisitExpr_(const CallNode* call) final {
static auto op_stateful = Op::GetAttr<TOpIsStateful>("TOpIsStateful"); static auto op_stateful = Op::GetAttr<TOpIsStateful>("TOpIsStateful");
auto origin_args = call->args;
Expr res = ExprMutator::VisitExpr_(call); Expr res = ExprMutator::VisitExpr_(call);
call = res.as<CallNode>(); call = res.as<CallNode>();
// We don't constant fold function with zero arguments. // We don't constant fold function with zero arguments.
...@@ -81,6 +83,10 @@ class ConstantFolder : public ExprMutator { ...@@ -81,6 +83,10 @@ class ConstantFolder : public ExprMutator {
if (op == nullptr) return res; if (op == nullptr) return res;
// skip stateful ops. // skip stateful ops.
if (op_stateful.get(GetRef<Op>(op), false)) return res; if (op_stateful.get(GetRef<Op>(op), false)) return res;
// Try to evaluate shape_of op
if (call->op.same_as(Op::Get("shape_of"))) {
return EvaluateShapeOf(res, origin_args, call->attrs);
}
bool all_const_args = true; bool all_const_args = true;
for (Expr arg : call->args) { for (Expr arg : call->args) {
if (!checker_.Check(arg)) { if (!checker_.Check(arg)) {
...@@ -132,6 +138,42 @@ class ConstantFolder : public ExprMutator { ...@@ -132,6 +138,42 @@ class ConstantFolder : public ExprMutator {
expr = InferType(expr, Module(nullptr)); expr = InferType(expr, Module(nullptr));
return ValueToExpr(executor_(expr)); return ValueToExpr(executor_(expr));
} }
// Evaluate shape_of op
/*!
 * \brief Try to replace a shape_of call by a constant.
 *
 * The shape is read from the first argument: directly from a constant, or
 * from its checked type when one is attached. Folding happens only when the
 * argument is a tensor whose every dimension is an integer immediate that
 * fits into int32; in all other cases the call is returned unchanged.
 *
 * \param expr The shape_of call to return when folding is not possible.
 * \param args Arguments of the call whose shape is inspected (VisitExpr_
 *             passes the arguments captured before the call was mutated).
 * \param attrs Attributes of the call; must be ShapeOfAttrs.
 * \return The folded constant cast to the requested dtype, or \p expr.
 */
Expr EvaluateShapeOf(Expr expr, Array<Expr> args, Attrs attrs) {
  Expr input = args[0];
  const auto* param = attrs.as<ShapeOfAttrs>();
  CHECK(param != nullptr);
  tvm::Array<IndexExpr> ishape;
  if (const ConstantNode* op = input.as<ConstantNode>()) {
    ishape = op->tensor_type()->shape;
  } else if (input->checked_type_.defined()) {
    // The checked type may be something other than a tensor type; in that
    // case there is no shape to read and the call is left alone.
    const auto* ttype = input->checked_type().as<TensorTypeNode>();
    if (ttype == nullptr) {
      return expr;
    }
    ishape = ttype->shape;
  } else {
    return expr;
  }
  // Get the constant shape
  DLContext ctx;
  ctx.device_type = kDLCPU;
  ctx.device_id = 0;
  auto val = runtime::NDArray::Empty(
      {static_cast<int64_t>(ishape.size())}, Type2TVMType(Int(32)), ctx);
  int32_t* dims = static_cast<int32_t*>(val->data);
  using ::tvm::ir::IntImm;
  for (size_t i = 0; i < ishape.size(); ++i) {
    const IntImm* dim = ishape[i].as<IntImm>();
    if (dim == nullptr) {
      // Non-constant dimension: nothing to fold.
      return expr;
    }
    // The staging buffer holds int32 values. Refuse to fold an extent that
    // does not survive the narrowing instead of storing a truncated value.
    const int64_t extent = dim->value;
    const int32_t narrowed = static_cast<int32_t>(extent);
    if (narrowed != extent) {
      return expr;
    }
    dims[i] = narrowed;
  }
  Expr shape = ValueToExpr(TensorValueNode::make(val));
  // Cast the constant into correct dtype
  auto cast_attrs = make_node<CastAttrs>();
  cast_attrs->dtype = param->dtype;
  static const Op& cast_op = Op::Get("cast");
  Expr ret = CallNode::make(cast_op, {shape}, Attrs(cast_attrs), {});
  return ConstEvaluate(ret);
}
}; };
......
...@@ -380,6 +380,22 @@ def test_forward_l2_normalize(): ...@@ -380,6 +380,22 @@ def test_forward_l2_normalize():
verify_mxnet_frontend_impl(mx_sym, (2, 3, 4, 5), (2, 3, 4, 5)) verify_mxnet_frontend_impl(mx_sym, (2, 3, 4, 5), (2, 3, 4, 5))
def test_forward_shape_array():
    """Check the converted ``shape_array`` against MXNet's own result."""
    def verify(shape):
        data_np = np.random.uniform(size=shape).astype("float32")
        expected = mx.nd.shape_array(mx.nd.array(data_np))
        mx_sym = mx.sym.shape_array(mx.sym.var("x"))
        relay_func, _ = relay.frontend.from_mxnet(mx_sym, {"x": shape})
        for target, ctx in ctx_list():
            for kind in ["debug"]:
                executor = relay.create_executor(kind, ctx=ctx, target=target)
                actual = executor.evaluate(relay_func)(data_np)
                tvm.testing.assert_allclose(actual.asnumpy(), expected.asnumpy())

    for shape in [(1,), (3, 4, 5), (3, 4, 5, 6)]:
        verify(shape)
if __name__ == '__main__': if __name__ == '__main__':
test_forward_mlp() test_forward_mlp()
test_forward_vgg() test_forward_vgg()
...@@ -409,3 +425,4 @@ if __name__ == '__main__': ...@@ -409,3 +425,4 @@ if __name__ == '__main__':
test_forward_slice_like() test_forward_slice_like()
test_forward_slice_axis() test_forward_slice_axis()
test_forward_l2_normalize() test_forward_l2_normalize()
test_forward_shape_array()
...@@ -177,6 +177,20 @@ def test_batch_matmul(): ...@@ -177,6 +177,20 @@ def test_batch_matmul():
verify_batch_matmul((5, 16, 32), (5, 20, 32), (5, 16, 20)) verify_batch_matmul((5, 16, 32), (5, 20, 32), (5, 16, 20))
verify_batch_matmul((30, 16, 32), (30, 20, 32), (30, 16, 20)) verify_batch_matmul((30, 16, 32), (30, 20, 32), (30, 16, 20))
def test_shape_of():
    """Evaluate ``shape_of`` with the interpreter and compare with the input shape."""
    shape = (10, 5, 12)
    data = relay.var("x", shape=shape)
    func = relay.ir_pass.infer_type(
        relay.Function([data], relay.op.shape_of(data)))
    x_data = np.random.rand(*shape).astype('float32')
    # Because using graph executor, this op will be optimized after
    # constant folding pass, here we only test with interpreter
    executor_kinds = ["debug"]
    for target, ctx in ctx_list():
        for kind in executor_kinds:
            intrp = relay.create_executor(kind, ctx=ctx, target=target)
            result = intrp.evaluate(func)(x_data)
            tvm.testing.assert_allclose(result.asnumpy(),
                                        np.array(shape).astype('int32'))
if __name__ == "__main__": if __name__ == "__main__":
test_collapse_sum_like() test_collapse_sum_like()
...@@ -184,3 +198,4 @@ if __name__ == "__main__": ...@@ -184,3 +198,4 @@ if __name__ == "__main__":
test_slice_like() test_slice_like()
test_reverse_reshape() test_reverse_reshape()
test_batch_matmul() test_batch_matmul()
test_shape_of()
...@@ -95,8 +95,34 @@ def test_fold_concat(): ...@@ -95,8 +95,34 @@ def test_fold_concat():
assert relay.ir_pass.graph_equal(zz, zexpected) assert relay.ir_pass.graph_equal(zz, zexpected)
def test_fold_shape_of():
    """Constant folding of ``shape_of`` before and after type inference."""
    c_shape = (8, 9, 10)

    def before(dtype):
        lhs = relay.var("x", shape=c_shape, dtype="float32")
        rhs = relay.var("y", shape=c_shape, dtype="float32")
        return relay.Function([lhs, rhs], relay.shape_of(lhs + rhs, dtype))

    def expected(dtype):
        lhs = relay.var("x", shape=c_shape, dtype="float32")
        rhs = relay.var("y", shape=c_shape, dtype="float32")
        shape_const = relay.const(np.array(c_shape).astype(dtype), dtype=dtype)
        return relay.ir_pass.infer_type(relay.Function([lhs, rhs], shape_const))

    for dtype in ["int32", "float32"]:
        untyped = before(dtype)
        # Folding the function as built is expected to leave it unchanged.
        unfolded = relay.ir_pass.fold_constant(untyped)
        assert relay.ir_pass.graph_equal(unfolded, untyped)
        # After type inference the call is expected to fold to a constant.
        typed = relay.ir_pass.infer_type(untyped)
        folded = relay.ir_pass.fold_constant(typed)
        reference = expected(dtype)
        assert relay.ir_pass.graph_equal(folded, reference)
if __name__ == "__main__": if __name__ == "__main__":
test_fold_const() test_fold_const()
test_fold_let() test_fold_let()
test_fold_tuple() test_fold_tuple()
test_fold_concat() test_fold_concat()
test_fold_shape_of()
...@@ -1081,5 +1081,28 @@ inline Tensor layout_transform(const Tensor& src, ...@@ -1081,5 +1081,28 @@ inline Tensor layout_transform(const Tensor& src,
}, name, tag); }, name, tag);
} }
/*!
 * \brief Build a rank-1 tensor holding the shape of the input tensor.
 * \param src the input tensor.
 * \param dtype the element type the shape values are cast to.
 * \param name output tensor name.
 * \param tag output tensor tag.
 * \return Tensor of input shape.
 */
inline Tensor shape(const Tensor& src,
                    Type dtype,
                    const std::string name = "shape",
                    const std::string tag = kInjective) {
  int rank = static_cast<int>(src->shape.size());
  Array<Expr> out_shape{rank};
  // Element `k` of the output is selected from the input extents through a
  // chain of conditionals, starting from 0.
  auto select_extent = [&](const Array<Var>& indices) {
    auto position = indices[0];
    Expr selected = 0;
    for (int axis = 0; axis < rank; ++axis) {
      selected = tvm::if_then_else(position == axis, src->shape[axis], selected);
    }
    return tvm::cast(dtype, selected);
  };
  return compute(out_shape, select_extent, name, tag);
}
} // namespace topi } // namespace topi
#endif // TOPI_TRANSFORM_H_ #endif // TOPI_TRANSFORM_H_
...@@ -393,3 +393,22 @@ def layout_transform(array, src_layout, dst_layout): ...@@ -393,3 +393,22 @@ def layout_transform(array, src_layout, dst_layout):
the destination layout. the destination layout.
""" """
return cpp.layout_transform(array, src_layout, dst_layout) return cpp.layout_transform(array, src_layout, dst_layout)
def shape(array, dtype="int32"):
    """Get the shape of input array

    Parameters
    ----------
    array : tvm.Tensor
        The source tensor.

    dtype : str, optional
        The target data type.

    Returns
    -------
    result : tvm.Tensor
        The resulting tensor.
    """
    return cpp.shape(array, dtype)
...@@ -271,6 +271,11 @@ TVM_REGISTER_GLOBAL("topi.stack") ...@@ -271,6 +271,11 @@ TVM_REGISTER_GLOBAL("topi.stack")
*rv = stack(args[0], args[1]); *rv = stack(args[0], args[1]);
}); });
// Global function "topi.shape": forwards its first two packed arguments to
// shape() and stores the result in the return value.
TVM_REGISTER_GLOBAL("topi.shape")
.set_body([](TVMArgs args, TVMRetValue *rv) {
  *rv = shape(args[0], args[1]);
  });
TVM_REGISTER_GLOBAL("topi.split") TVM_REGISTER_GLOBAL("topi.split")
.set_body([](TVMArgs args, TVMRetValue *rv) { .set_body([](TVMArgs args, TVMRetValue *rv) {
if (args[1].type_code() == kDLInt || args[1].type_code() == kDLUInt) { if (args[1].type_code() == kDLInt || args[1].type_code() == kDLUInt) {
...@@ -278,7 +283,7 @@ TVM_REGISTER_GLOBAL("topi.split") ...@@ -278,7 +283,7 @@ TVM_REGISTER_GLOBAL("topi.split")
} else { } else {
*rv = split(args[0], args[1], args[2]); *rv = split(args[0], args[1], args[2]);
} }
}); });
TVM_REGISTER_GLOBAL("topi.layout_transform") TVM_REGISTER_GLOBAL("topi.layout_transform")
.set_body([](TVMArgs args, TVMRetValue *rv) { .set_body([](TVMArgs args, TVMRetValue *rv) {
......
...@@ -564,6 +564,33 @@ def test_layout_transform(): ...@@ -564,6 +564,33 @@ def test_layout_transform():
check_device(backend) check_device(backend)
def test_shape():
    """Build ``topi.shape`` for every backend and compare with the input shape."""
    in_shape = (8, 7, 13)
    dtype = "int32"
    A = tvm.placeholder(shape=in_shape, dtype="float32", name="A")
    B = topi.shape(A, dtype)

    input_np = np.random.uniform(size=in_shape).astype(A.dtype)
    expected_np = np.asarray(in_shape).astype(dtype)

    def check_device(device):
        ctx = tvm.context(device, 0)
        if not ctx.exist:
            print("Skip because %s is not enabled" % device)
            return
        tvm_input = tvm.nd.array(input_np, ctx)
        tvm_output = tvm.nd.empty(expected_np.shape, ctx=ctx, dtype=dtype)
        print("Running on target: %s" % device)
        with tvm.target.create(device):
            s = topi.generic.schedule_injective(B)
        f = tvm.build(s, [A, B], device, name="shape")
        f(tvm_input, tvm_output)
        tvm.testing.assert_allclose(tvm_output.asnumpy(), expected_np)

    for backend in get_all_backend():
        check_device(backend)
if __name__ == "__main__": if __name__ == "__main__":
test_strided_slice() test_strided_slice()
test_concatenate() test_concatenate()
...@@ -581,3 +608,4 @@ if __name__ == "__main__": ...@@ -581,3 +608,4 @@ if __name__ == "__main__":
test_layout_transform() test_layout_transform()
test_repeat() test_repeat()
test_tile() test_tile()
test_shape()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment