Commit b68d9dc0 by Siva Committed by Tianqi Chen

[RELAY][OP] take (#1863)

parent 64d3393e
...@@ -73,6 +73,7 @@ This level enables additional math and transform operators. ...@@ -73,6 +73,7 @@ This level enables additional math and transform operators.
tvm.relay.round tvm.relay.round
tvm.relay.abs tvm.relay.abs
tvm.relay.negative tvm.relay.negative
tvm.relay.take
...@@ -143,6 +144,7 @@ Level 3 Definitions ...@@ -143,6 +144,7 @@ Level 3 Definitions
.. autofunction:: tvm.relay.reshape .. autofunction:: tvm.relay.reshape
.. autofunction:: tvm.relay.copy .. autofunction:: tvm.relay.copy
.. autofunction:: tvm.relay.transpose .. autofunction:: tvm.relay.transpose
.. autofunction:: tvm.relay.take
Level 3 Definitions Level 3 Definitions
------------------- -------------------
......
...@@ -59,6 +59,15 @@ struct ReshapeAttrs : public tvm::AttrsNode<ReshapeAttrs> { ...@@ -59,6 +59,15 @@ struct ReshapeAttrs : public tvm::AttrsNode<ReshapeAttrs> {
} }
}; // struct ReshapeAttrs }; // struct ReshapeAttrs
/*! \brief Attributes used in the take operator. */
struct TakeAttrs : public tvm::AttrsNode<TakeAttrs> {
// Axis to gather along; left undefined (NullValue) to select from the
// flattened input instead.
IndexExpr axis;
TVM_DECLARE_ATTRS(TakeAttrs, "relay.attrs.TakeAttrs") {
TVM_ATTR_FIELD(axis).set_default(NullValue<IndexExpr>())
.describe("The axis over which to select values.");
}
};
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
#endif // TVM_RELAY_ATTRS_TRANSFORM_H_ #endif // TVM_RELAY_ATTRS_TRANSFORM_H_
...@@ -1135,7 +1135,7 @@ Examples:: ...@@ -1135,7 +1135,7 @@ Examples::
.set_attr<FCorrectLayout>("FCorrectLayout", TakeCorrectLayout) .set_attr<FCorrectLayout>("FCorrectLayout", TakeCorrectLayout)
.set_num_inputs(2) .set_num_inputs(2)
.set_num_outputs(1) .set_num_outputs(1)
.set_support_level(1) .set_support_level(3)
.set_attr<FTVMCompute>( .set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs, "FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs, const Array<Tensor>& inputs,
......
...@@ -116,3 +116,26 @@ def reshape(data, newshape): ...@@ -116,3 +116,26 @@ def reshape(data, newshape):
if isinstance(newshape, int): if isinstance(newshape, int):
newshape = [newshape] newshape = [newshape]
return _make.reshape(data, list(newshape)) return _make.reshape(data, list(newshape))
def take(data, indices, axis=None):
    """Take elements from an array along an axis.

    Parameters
    ----------
    data : relay.Expr
        The source array.

    indices : relay.Expr
        The indices of the values to extract.

    axis : int, optional
        The axis over which to select values. By default,
        the flattened input array is used.

    Returns
    -------
    ret : relay.Expr
        The computed result.
    """
    return _make.take(data, indices, axis)
...@@ -315,5 +315,94 @@ Example:: ...@@ -315,5 +315,94 @@ Example::
.set_support_level(3) .set_support_level(3)
.add_type_rel("Reshape", ReshapeRel); .add_type_rel("Reshape", ReshapeRel);
// Take
TVM_REGISTER_NODE_TYPE(TakeAttrs);
bool TakeRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter) {
// `types` contains: [data, indices, result]
CHECK_EQ(types.size(), 3);
const auto* data = types[0].as<TensorTypeNode>();
CHECK(data != nullptr);
const auto* indices = types[1].as<TensorTypeNode>();
CHECK(indices != nullptr);
const auto param = attrs.as<TakeAttrs>();
CHECK(param != nullptr);
if (!param->axis.defined()) {
std::vector<IndexExpr>&& oshape = AsVector(indices->shape);
reporter->Assign(types[2], TensorTypeNode::make(oshape, data->dtype));
return true;
}
std::vector<IndexExpr> oshape;
const auto ndim_data = static_cast<int>(data->shape.size());
const auto ndim_indices = static_cast<int>(indices->shape.size());
auto axis = (*as_const_int(param->axis));
if (axis < 0) axis += ndim_data;
CHECK_LE(axis, ndim_data)
<< "axis should be with in data shape"
<< ", but got = " << axis;
oshape.reserve(ndim_data - 1 + ndim_indices);
for (int i = 0; i < axis; ++i) {
oshape.emplace_back(data->shape[i]);
}
for (int i = 0; i < ndim_indices; ++i) {
oshape.emplace_back(indices->shape[i]);
}
for (int i = axis+1; i < ndim_data; ++i) {
oshape.emplace_back(data->shape[i]);
}
reporter->Assign(types[2], TensorTypeNode::make(oshape, data->dtype));
return true;
}
// Build a call to the "take" operator from data, indices, and the
// (possibly undefined) selection axis.
Expr MakeTake(Expr data,
              Expr indices,
              IndexExpr axis) {
  static const Op& take_op = Op::Get("take");
  auto take_attrs = make_node<TakeAttrs>();
  take_attrs->axis = axis;
  return CallNode::make(take_op, {data, indices}, Attrs(take_attrs), {});
}
// Expose MakeTake to the frontend as relay.op._make.take;
// unpack_call forwards the 3 packed arguments (data, indices, axis).
TVM_REGISTER_API("relay.op._make.take")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 3>(MakeTake, args, rv);
});
// Operator registration for "take".
// Support level is 3, matching the docs (take is listed under
// "Level 3 Definitions") and the nnvm registration in this change.
RELAY_REGISTER_OP("take")
.describe(R"code(Take elements from an array along an axis.
When axis is not None, this function does the same thing as 'fancy' indexing
(indexing arrays using arrays); however, it can be easier to use if you need
elements along a given axis.
**Note** that when axis is none the flattened input array is used.
Examples::
  a = [[ 1, 2],
       [ 3, 4]]
  indices = [3, 0, 2]
  take(a, indices) = [ 4, 1, 3]

  a = [[ 1., 2.],
       [ 3., 4.]]
  indices = [1, 0]
  take(a, indices, axis=1) = [[ 2., 1.],
                              [ 4., 3.]]
)code" TVM_ADD_FILELINE)
.set_num_inputs(2)
.add_argument("data", "Tensor", "The input tensor.")
.add_argument("indices", "Tensor", "The indices tensor.")
.set_support_level(3)
.add_type_rel("Take", TakeRel);
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
...@@ -91,6 +91,27 @@ def test_single_op(): ...@@ -91,6 +91,27 @@ def test_single_op():
tvm.relay.round, tvm.relay.abs, tvm.relay.negative]: tvm.relay.round, tvm.relay.abs, tvm.relay.negative]:
check_single_op(opfunc) check_single_op(opfunc)
def test_take_infer_type():
    """Type inference for take: output shape with and without an axis."""
    def verify_take(dshape, indices_shape, oshape, axis=None):
        # Build take(x, indices) inside a function and check the inferred
        # return type against the expected output shape.
        builder = relay.ir_builder.IRBuilder()
        data = builder.param("x", relay.ty.TensorType(dshape, "float32"))
        idx = builder.param("indices",
                            relay.ty.TensorType(indices_shape, "int32"))
        with builder.function(data, idx) as func:
            builder.ret(relay.take(data.var, idx.var, axis=axis))
        builder.ret(func)
        typed = relay.ir_pass.infer_type(builder.env, func.to_func())
        expected = relay.ty.TensorType(oshape, "float32")
        assert typed.checked_type.ret_type == expected

    d1, d2, d3 = tvm.var("d1"), tvm.var("d2"), tvm.var("d3")
    d4, d5, d6 = tvm.var("d4"), tvm.var("d5"), tvm.var("d6")
    verify_take((d1,), (1,), (1,), 0)
    verify_take((4,), (d1, d2), (d1, d2))
    verify_take((3, 3, 3), (1, d2), (1, d2))
    verify_take((d1, d2), (d3, d4, d5), (d3, d4, d5, d2), 0)
    verify_take((d1, d2), (d3, d4, d5), (d1, d3, d4, d5), 1)
    verify_take((d1, d2, d3, d4), (d5, d6), (d1, d2, d5, d6, d4), -2)
if __name__ == "__main__": if __name__ == "__main__":
test_single_op() test_single_op()
...@@ -99,3 +120,4 @@ if __name__ == "__main__": ...@@ -99,3 +120,4 @@ if __name__ == "__main__":
test_copy_infer_type() test_copy_infer_type()
test_transpose_infer_type() test_transpose_infer_type()
test_reshape_infer_type() test_reshape_infer_type()
test_take_infer_type()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment