Commit e6e9b371 by Siva, committed by Tianqi Chen

[RELAY][OP] Split (#1876)

parent cbf4fdbb
@@ -94,6 +94,7 @@ This level enables additional math and transform operators.
tvm.relay.full
tvm.relay.full_like
tvm.relay.cast
tvm.relay.split
**Level 4: Broadcast and Reductions**

@@ -198,6 +199,7 @@ Level 3 Definitions
.. autofunction:: tvm.relay.full
.. autofunction:: tvm.relay.full_like
.. autofunction:: tvm.relay.cast
.. autofunction:: tvm.relay.split
Level 4 Definitions
...
@@ -106,6 +106,22 @@ struct SqueezeAttrs : public tvm::AttrsNode<SqueezeAttrs> {
  }
};  // struct SqueezeAttrs
struct SplitAttrs : public tvm::AttrsNode<SplitAttrs> {
  NodeRef indices_or_sections;
  int axis;

  TVM_DECLARE_ATTRS(SplitAttrs, "relay.attrs.SplitAttrs") {
    TVM_ATTR_FIELD(indices_or_sections)
        .describe("Indices or sections to split into. Accepts an int or a tuple. "
                  "If indices_or_sections is an integer, the input will be divided equally "
                  "along the given axis. If such a split is not possible, an error is raised. "
                  "If indices_or_sections is a tuple of sorted integers, "
                  "the entries indicate where along the axis the array is split.");
    TVM_ATTR_FIELD(axis).set_default(0)
        .describe("The axis along which to split the input.");
  }
};  // struct SplitAttrs
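For context, these attributes mirror the semantics of numpy.split; a NumPy-only sketch of the two accepted forms (illustrative, not part of this diff):

import numpy as np

x = np.arange(12).reshape(3, 4)

# Integer form: divide the axis equally (legal because 4 % 2 == 0).
equal_parts = np.split(x, 2, axis=1)      # two arrays of shape (3, 2)

# Tuple form: sorted split points along the axis.
# Pieces cover [0:1], [1:3], [3:4] -> widths 1, 2, 1.
by_indices = np.split(x, (1, 3), axis=1)  # shapes (3, 1), (3, 2), (3, 1)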
}  // namespace relay
}  // namespace tvm
#endif  // TVM_RELAY_ATTRS_TRANSFORM_H_
@@ -427,7 +427,7 @@ along which to split the array.
      return Array<Tensor>{ topi::split(inputs[0], indices, param.axis) };
    }
  })
-.set_support_level(1);
+.set_support_level(3);

// cast
DMLC_REGISTER_PARAMETER(CastParam);
...
@@ -5,6 +5,7 @@ from __future__ import absolute_import
import numpy as _np
from .base import RelayNode, register_relay_node
from . import _make
from . import _expr
from . import ty as _ty
from .._ffi import base as _base
from .. import nd as _nd
@@ -284,6 +285,16 @@ class TupleWrapper(object):
        as an argument to an FFI function."""
        return self.tuple_value

    def astext(self):
        """Get the text format of the tuple expression.

        Returns
        -------
        text : str
            The text format of the tuple expression.
        """
        return _expr._text_print(self.tuple_value)
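A small, hypothetical usage sketch (assumes relay is imported and the split op defined later in this diff):

x = relay.var("x", relay.ty.TensorType((4, 2), "float32"))
y = relay.split(x, 2, axis=0)  # TupleWrapper over a two-field tuple
print(y.astext())              # text form of the underlying tuple expression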
    def __getitem__(self, index):
        if index >= len(self):
            raise IndexError("Tuple index out of range")
...
"""Transform operators.""" """Transform operators."""
from . import _make from . import _make
from ..expr import TupleWrapper
def expand_dims(data, axis, num_newaxis=1):
@@ -146,7 +147,7 @@ def take(data, indices, axis=None):
    Parameters
    ----------
-    a : relay.Expr
+    data : relay.Expr
        The source array.

    indices : relay.Expr
@@ -280,3 +281,35 @@ def collapse_sum_like(data, collapse_type):
        The resulting tensor.
    """
    return _make.collapse_sum_like(data, collapse_type)

def split(data, indices_or_sections, axis=0):
    """Split input tensor along axis by sections or indices.

    If indices_or_sections is an integer, the input will be divided equally
    along the given axis. If such a split is not possible, an error is raised.

    If indices_or_sections is a tuple of sorted integers,
    the entries indicate where along the axis the array is split.

    Parameters
    ----------
    data : relay.Expr
        The source array.

    indices_or_sections : int or tuple of int
        Indices or sections to split into. Accepts an int or a tuple.

    axis : int, optional
        The axis over which to split.

    Returns
    -------
    ret : relay.expr.TupleWrapper
        The computed result, a tuple of the split tensors.
    """
    if isinstance(indices_or_sections, int):
        ret_size = indices_or_sections
    else:
        ret_size = len(indices_or_sections) + 1
    return TupleWrapper(_make.split(data, indices_or_sections, axis), ret_size)
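For reference, a minimal usage sketch of this wrapper (variable names are illustrative; the shapes follow the tests added below):

x = relay.var("x", relay.ty.TensorType((5, 5, 2, 2), "float32"))

# Integer form: five equal pieces along axis 1.
y = relay.split(x, indices_or_sections=5, axis=1)  # TupleWrapper, len(y) == 5
first = y[0]       # relay.Expr selecting field 0 of the tuple
tup = y.astuple()  # the underlying Tuple expression, e.g. for infer_type

# Tuple form: split points 1 and 3 give widths 1, 2, 2 along axis 1.
z = relay.split(x, indices_or_sections=(1, 3), axis=1)  # len(z) == 3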
@@ -64,6 +64,7 @@ class AttrFunctor<R(const NodeRef& n, Args...)> {
  virtual R VisitAttr_(const ir::Add* op, Args... args) ATTR_FUNCTOR_DEFAULT;
  virtual R VisitAttr_(const ir::Sub* op, Args... args) ATTR_FUNCTOR_DEFAULT;
  virtual R VisitAttr_(const ir::Mul* op, Args... args) ATTR_FUNCTOR_DEFAULT;
  virtual R VisitAttr_(const ir::Div* op, Args... args) ATTR_FUNCTOR_DEFAULT;
  virtual R VisitAttr_(const ir::Mod* op, Args... args) ATTR_FUNCTOR_DEFAULT;
  virtual R VisitAttr_(const ir::Min* op, Args... args) ATTR_FUNCTOR_DEFAULT;
  virtual R VisitAttr_(const ir::Max* op, Args... args) ATTR_FUNCTOR_DEFAULT;
@@ -96,6 +97,7 @@ class AttrFunctor<R(const NodeRef& n, Args...)> {
    ATTR_FUNCTOR_DISPATCH(Add);
    ATTR_FUNCTOR_DISPATCH(Sub);
    ATTR_FUNCTOR_DISPATCH(Mul);
    ATTR_FUNCTOR_DISPATCH(Div);
    ATTR_FUNCTOR_DISPATCH(Min);
    ATTR_FUNCTOR_DISPATCH(Max);
    ATTR_FUNCTOR_DISPATCH(GE);
@@ -135,6 +137,7 @@ class AttrsEqualHandler :
  bool VisitAttr_(const ir::Add* lhs, const NodeRef& other) final;
  bool VisitAttr_(const ir::Sub* lhs, const NodeRef& other) final;
  bool VisitAttr_(const ir::Mul* lhs, const NodeRef& other) final;
  bool VisitAttr_(const ir::Div* lhs, const NodeRef& other) final;
  bool VisitAttr_(const ir::Mod* lhs, const NodeRef& other) final;
  bool VisitAttr_(const ir::Min* lhs, const NodeRef& other) final;
  bool VisitAttr_(const ir::Max* lhs, const NodeRef& other) final;
@@ -174,6 +177,7 @@ class AttrsHashHandler :
  size_t VisitAttr_(const ir::Add* op) final;
  size_t VisitAttr_(const ir::Sub* op) final;
  size_t VisitAttr_(const ir::Mul* op) final;
  size_t VisitAttr_(const ir::Div* op) final;
  size_t VisitAttr_(const ir::Mod* op) final;
  size_t VisitAttr_(const ir::Min* op) final;
  size_t VisitAttr_(const ir::Max* op) final;
...
@@ -132,6 +132,7 @@ bool AttrsEqualHandler::VisitAttr_(const StrMapNode* lhs, const NodeRef& other)
TVM_DEFINE_ATTRS_BINOP_EQUAL(Add);
TVM_DEFINE_ATTRS_BINOP_EQUAL(Sub);
TVM_DEFINE_ATTRS_BINOP_EQUAL(Mul);
TVM_DEFINE_ATTRS_BINOP_EQUAL(Div);
TVM_DEFINE_ATTRS_BINOP_EQUAL(Mod);
TVM_DEFINE_ATTRS_BINOP_EQUAL(Max);
TVM_DEFINE_ATTRS_BINOP_EQUAL(Min);
@@ -243,6 +244,7 @@ size_t AttrsHashHandler::VisitAttr_(const StrMapNode* lhs) {
TVM_DEFINE_ATTRS_BINOP_HASH(Add);
TVM_DEFINE_ATTRS_BINOP_HASH(Sub);
TVM_DEFINE_ATTRS_BINOP_HASH(Mul);
TVM_DEFINE_ATTRS_BINOP_HASH(Div);
TVM_DEFINE_ATTRS_BINOP_HASH(Mod);
TVM_DEFINE_ATTRS_BINOP_HASH(Max);
TVM_DEFINE_ATTRS_BINOP_HASH(Min);
...
@@ -6,12 +6,14 @@
#include <tvm/relay/op.h>
#include <tvm/relay/attrs/transform.h>
#include <tvm/ir_operator.h>
#include <tvm/ir.h>
#include <vector>
#include "../op_common.h"

namespace tvm {
namespace relay {
using ir::IntImm;
// relay.cast // relay.cast
TVM_REGISTER_NODE_TYPE(CastAttrs); TVM_REGISTER_NODE_TYPE(CastAttrs);
...@@ -834,5 +836,100 @@ RELAY_REGISTER_OP("broadcast_to_like") ...@@ -834,5 +836,100 @@ RELAY_REGISTER_OP("broadcast_to_like")
.set_support_level(10) .set_support_level(10)
.add_type_rel("BroadCastToLike", BroadCastToLikeRel); .add_type_rel("BroadCastToLike", BroadCastToLikeRel);
// Split
TVM_REGISTER_NODE_TYPE(SplitAttrs);

bool SplitRel(const Array<Type>& types,
              int num_inputs,
              const Attrs& attrs,
              const TypeReporter& reporter) {
  // `types` contains: [data, result]
  CHECK_EQ(types.size(), 2);
  const auto* data = types[0].as<TensorTypeNode>();
  CHECK(data != nullptr);
  CHECK_NE(data->shape.size(), 0) << "Input shape cannot be empty";
  const auto param = attrs.as<SplitAttrs>();
  CHECK(param != nullptr);
  auto axis = param->axis;
  if (axis < 0) {
    axis += data->shape.size();
  }
  CHECK_LT(axis, data->shape.size())
      << "axis should be within the input dimension range.";
  CHECK_GE(axis, 0)
      << "axis should be within the input dimension range.";
  if (const IntImm* sections = param->indices_or_sections.as<IntImm>()) {
    // Equal split: each of the `sections` outputs gets shape[axis] / sections.
    CHECK(reporter->Assert(data->shape[axis] %
                           sections->value == make_zero(Int(64))))
        << "indices_or_sections need to be able to divide input.shape[axis]";
    std::vector<Type> fields;
    for (int i = 0; i < sections->value; ++i) {
      std::vector<IndexExpr>&& oshape = AsVector(data->shape);
      oshape[axis] /= int32_t(sections->value);
      auto vec_type = TensorTypeNode::make(oshape, data->dtype);
      fields.push_back(vec_type);
    }
    reporter->Assign(types[1], TupleTypeNode::make(Array<Type>(fields)));
  } else {
    // Index split: each entry marks where the previous piece ends.
    auto indices = param->indices_or_sections.as<ArrayNode>()->data;
    auto begin = IndexExpr(make_zero(Int(32)));
    std::vector<Type> fields;
    for (size_t i = 0; i < indices.size(); ++i) {
      CHECK(reporter->Assert(IndexExpr(indices[i]) > begin))
          << "indices_or_sections need to be a sorted ascending list";
      std::vector<IndexExpr>&& oshape = AsVector(data->shape);
      oshape[axis] = IndexExpr(indices[i]) - begin;
      begin = IndexExpr(indices[i]);
      auto vec_type = TensorTypeNode::make(oshape, data->dtype);
      fields.push_back(vec_type);
    }
    // The last piece spans from the final index to the end of the axis.
    CHECK(reporter->Assert(begin < data->shape[axis]))
        << "The last index must lie within input.shape[axis]";
    std::vector<IndexExpr>&& oshape = AsVector(data->shape);
    oshape[axis] = data->shape[axis] - begin;
    auto vec_type = TensorTypeNode::make(oshape, data->dtype);
    fields.push_back(vec_type);
    reporter->Assign(types[1], TupleTypeNode::make(Array<Type>(fields)));
  }
  return true;
}
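The shape rule implemented by SplitRel can be summarized in plain Python; an illustrative sketch for static integer shapes (hypothetical helper, not part of the diff):

def split_out_shapes(shape, indices_or_sections, axis):
    """Mirror SplitRel's output-shape arithmetic for static shapes."""
    if isinstance(indices_or_sections, int):
        sections = indices_or_sections
        # the section count must divide input.shape[axis] evenly
        assert shape[axis] % sections == 0
        out = list(shape)
        out[axis] //= sections
        return [tuple(out)] * sections
    begin, shapes = 0, []
    for index in indices_or_sections:
        assert index > begin, "indices must be a sorted ascending list"
        out = list(shape)
        out[axis] = index - begin
        shapes.append(tuple(out))
        begin = index
    # the last piece takes whatever remains past the final index
    assert begin < shape[axis]
    out = list(shape)
    out[axis] = shape[axis] - begin
    shapes.append(tuple(out))
    return shapes

# split_out_shapes((5, 5, 2, 2), 5, axis=1) -> [(5, 1, 2, 2)] * 5
# split_out_shapes((8, 9), (2, 4, 7), axis=1) -> [(8, 2), (8, 2), (8, 3), (8, 2)]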
Expr MakeSplit(Expr data,
               NodeRef indices_or_sections,
               int axis) {
  auto attrs = make_node<SplitAttrs>();
  attrs->axis = axis;
  attrs->indices_or_sections = std::move(indices_or_sections);
  static const Op& op = Op::Get("split");
  return CallNode::make(op, {data}, Attrs(attrs), {});
}

TVM_REGISTER_API("relay.op._make.split")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
  if (args.type_codes[1] == kDLInt) {
    // A bare Python int arrives as kDLInt; wrap it as a 64-bit IntImm so
    // SplitRel can tell the "sections" form apart from an index tuple.
    *rv = MakeSplit(args[0], make_const(Int(64), int64_t(args[1])), args[2]);
  } else {
    *rv = MakeSplit(args[0], args[1], args[2]);
  }
});
RELAY_REGISTER_OP("split")
.describe(R"code(Splits an array along a particular axis into multiple sub-arrays.
Indices or sections to split into. Accepts an int or a tuple
If indices_or_sections is an integer, the input will be divided equally
along given axis. If such a split is not possible, an error is raised.
If indices_or_sections is a tuple of sorted integers,
the entries indicate where along axis the array is split.
)code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.SplitAttrs")
.set_num_inputs(1)
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(3)
.add_type_rel("Split", SplitRel);
}  // namespace relay
}  // namespace tvm
@@ -107,6 +107,38 @@ def test_take_infer_type():
    verify_take((d1, d2), (d3, d4, d5), (d1, d3, d4, d5), 1)
    verify_take((d1, d2, d3, d4), (d5, d6), (d1, d2, d5, d6, d4), -2)

def test_split_infer_type():
    def verify_split(dshape, indices_or_sections, ret_type, axis=None):
        x = relay.var("x", relay.ty.TensorType(dshape, "float32"))
        y = relay.split(x, indices_or_sections, axis=axis)
        y.astext()
        yy = relay.ir_pass.infer_type(y.astuple())
        assert yy.checked_type == ret_type

    d1, d2, d3, d4 = tvm.var("d1"), tvm.var("d2"), tvm.var("d3"), tvm.var("d4")
    axis = tvm.var("axis")
    verify_split((5, 5, 2, 2), 5,
                 relay.ty.TupleType(tvm.convert([
                     relay.ty.TensorType((5, 1, 2, 2), "float32"),
                     relay.ty.TensorType((5, 1, 2, 2), "float32"),
                     relay.ty.TensorType((5, 1, 2, 2), "float32"),
                     relay.ty.TensorType((5, 1, 2, 2), "float32"),
                     relay.ty.TensorType((5, 1, 2, 2), "float32")])),
                 axis=1)
    verify_split((d1, d2, d3, d4), 4,
                 relay.ty.TupleType(tvm.convert([
                     relay.ty.TensorType((d1, d2, d3/4, d4), "float32"),
                     relay.ty.TensorType((d1, d2, d3/4, d4), "float32"),
                     relay.ty.TensorType((d1, d2, d3/4, d4), "float32"),
                     relay.ty.TensorType((d1, d2, d3/4, d4), "float32")])),
                 axis=2)
    verify_split((d1, d2, d3, d4), (2, 4, 7),
                 relay.ty.TupleType(tvm.convert([
                     relay.ty.TensorType((d1, 2, d3, d4), "float32"),
                     relay.ty.TensorType((d1, 2, d3, d4), "float32"),
                     relay.ty.TensorType((d1, 3, d3, d4), "float32"),
                     relay.ty.TensorType((d1, (d2-7), d3, d4), "float32")])),
                 axis=1)
def test_full():
    # default settings: match input dtype
@@ -161,3 +193,4 @@ if __name__ == "__main__":
    test_infer_type_leaky_relu()
    test_squeeze_infer_type()
    test_squeeze_bad_axes_infer_type()
    test_split_infer_type()