Commit ae21eddf by Leyuan Wang Committed by Haichen Shen

[Relay][OP] Gather_nd exposed to relay (#2945)

* gather_nd added

* gather_nd test added

* more test added

* fix lint

* fix build error

* fix lint

* comments addressed
parent fbd1c164
......@@ -92,6 +92,7 @@ This level enables additional math and transform operators.
tvm.relay.zeros_like
tvm.relay.ones
tvm.relay.ones_like
tvm.relay.gather_nd
tvm.relay.full
tvm.relay.full_like
tvm.relay.cast
......@@ -225,6 +226,7 @@ Level 3 Definitions
.. autofunction:: tvm.relay.zeros_like
.. autofunction:: tvm.relay.ones
.. autofunction:: tvm.relay.ones_like
.. autofunction:: tvm.relay.gather_nd
.. autofunction:: tvm.relay.full
.. autofunction:: tvm.relay.full_like
.. autofunction:: tvm.relay.cast
......
......@@ -646,6 +646,7 @@ _identity_list = [
"zeros_like",
"ones_like",
"where",
"gather_nd",
]
_convert_map = {
......@@ -782,7 +783,6 @@ _convert_map = {
# TODO(tvm-tvm): support all operators.
#
# "broadcast_to",
# "gather_nd",
# "Crop" : _crop_like,
}
......
......@@ -32,6 +32,7 @@ _reg.register_schedule("where", schedule_broadcast)
_reg.register_schedule("stack", schedule_injective)
_reg.register_schedule("concatenate", schedule_injective)
_reg.register_schedule("_contrib_reverse_reshape", schedule_injective)
_reg.register_schedule("gather_nd", schedule_injective)
# layout_transform
_reg.register_schedule("layout_transform", schedule_injective)
......
......@@ -651,3 +651,36 @@ def reverse_reshape(data, newshape):
if isinstance(newshape, int):
newshape = [newshape]
return _make._contrib_reverse_reshape(data, list(newshape))
def gather_nd(data, indices):
"""Gather elements or slices from data and store to a tensor whose shape is
defined by indices.
Parameters
----------
data : relay.Expr
The input data to the operator.
indices : relay.Expr
The shape of output tensor.
Returns
-------
ret : relay.Expr
The computed result.
Examples
--------
.. code-block:: python
data = [[0, 1], [2, 3]]
indices = [[1, 1, 0], [0, 1, 0]]
relay.gather_nd(data, indices) = [2, 3, 0]
data = [[[1, 2], [3, 4]], [[5, 6], [7, 8]]]
indices = [[0, 1], [1, 0]]
relay.gather_nd(data, indices) = [[3, 4], [5, 6]]
"""
return _make.gather_nd(data, indices)
......@@ -2122,5 +2122,75 @@ example below::
.set_attr<FTVMCompute>("FTVMCompute", ReshapeCompute)
.set_attr<TOpPattern>("TOpPattern", kInjective);
// gather_nd operator
bool GatherNDRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter) {
// `types` contains: [data, indices, result]
CHECK_EQ(types.size(), 3);
const auto* data = types[0].as<TensorTypeNode>();
const auto* indices = types[1].as<TensorTypeNode>();
if (data == nullptr) {
CHECK(types[0].as<IncompleteTypeNode>())
<< "GatherND: expect input data type to be TensorType but get "
<< types[0];
return false;
}
if (indices == nullptr) {
CHECK(types[1].as<IncompleteTypeNode>())
<< "GatherND: expect indices type to be TensorType but get "
<< types[1];
return false;
}
const size_t ndim = data->shape.size();
const IntImm* mdim = data->shape[0].as<IntImm>();
const size_t kdim = indices->shape.size() - 1;
CHECK(size_t(mdim->value) <= ndim)
<< "GatherND: indices shape does satisfy.";
Array<IndexExpr> oshape;
for (size_t i = 1; i < kdim + 1; ++i)
oshape.push_back(indices->shape[i]);
for (size_t i = mdim->value; i < ndim; ++i)
oshape.push_back(data->shape[i]);
reporter->Assign(types[2], TensorTypeNode::make(oshape, data->dtype));
return true;
}
Array<Tensor> GatherNDCompute(const Attrs& attrs,
const Array<Tensor>& inputs,
const Type& out_type,
const Target& target) {
return { topi::gather_nd(inputs[0], inputs[1]) };
}
Expr MakeGatherND(Expr data,
Expr indices) {
static const Op& op = Op::Get("gather_nd");
return CallNode::make(op, {data, indices}, {});
}
TVM_REGISTER_API("relay.op._make.gather_nd")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 2>(MakeGatherND, args, rv);
});
RELAY_REGISTER_OP("gather_nd")
.describe(R"code(Gather elements or slices from data and store to
a tensor whose shape is defined by indices.
Given data with shape (X_0, X_1, ..., X_{N-1}) and indices with
shape (M, Y_0, ..., Y_{K-1}), the output will have shape
(Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1}), where M <= N. If M == N,
output shape will simply be (Y_0, ..., Y_{K-1}).
)code" TVM_ADD_FILELINE)
.set_num_inputs(2)
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(3)
.add_type_rel("GatherND", GatherNDRel)
.set_attr<FTVMCompute>("FTVMCompute", GatherNDCompute)
.set_attr<TOpPattern>("TOpPattern", kInjective);
} // namespace relay
} // namespace tvm
......@@ -491,6 +491,21 @@ def test_forward_take():
verify((3,4), [-1, 5], 1)
verify((3,4), [-1, 5], 1, mode="wrap")
def test_forward_gather_nd():
def verify(xshape, yshape, y_data):
x_data = np.random.uniform(size=xshape).astype("float32")
ref_res = mx.nd.gather_nd(mx.nd.array(x_data), mx.nd.array(y_data))
mx_sym = mx.sym.gather_nd(mx.sym.var("x_data"), mx.sym.var("y_data"))
new_sym, _ = relay.frontend.from_mxnet(mx_sym, {"x_data": xshape, "y_data": yshape}, {"x_data": "float32", "y_data": "int32"})
for target, ctx in ctx_list():
for kind in ["graph", "debug"]:
intrp = relay.create_executor(kind, ctx=ctx, target=target)
op_res = intrp.evaluate(new_sym)(x_data, y_data)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())
verify((2, 2), (2, 3), [[1, 1, 0], [0, 1, 0]])
verify((2, 2, 2), (2, 2), [[0, 1], [1, 0]])
if __name__ == '__main__':
test_forward_mlp()
test_forward_vgg()
......@@ -527,3 +542,4 @@ if __name__ == '__main__':
test_forward_embedding()
test_forward_smooth_l1()
test_forward_take()
test_forward_gather_nd()
......@@ -553,7 +553,6 @@ def test_stack():
verify_stack([(2, 2, 3, 4), (2, 2, 3, 4), (2, 2, 3, 4), (2, 2, 3, 4)], -1)
def test_reverse():
def verify_reverse(dshape, axis):
x = relay.var("x", relay.TensorType(dshape, "float32"))
......@@ -573,6 +572,25 @@ def test_reverse():
verify_reverse((2, 3, 4), -1)
def test_gather_nd():
def verify_gather_nd(xshape, yshape, y_data):
x = relay.var("x", relay.TensorType(xshape, "float32"))
y = relay.var("y", relay.TensorType(yshape, "int32"))
z = relay.gather_nd(x, y)
func = relay.Function([x, y], z)
x_data = np.random.uniform(size=xshape).astype("float32")
ref_res = x_data[y_data]
for target, ctx in ctx_list():
for kind in ["graph", "debug"]:
intrp = relay.create_executor(kind, ctx=ctx, target=target)
op_res = intrp.evaluate(func)(x_data, y_data)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)
verify_gather_nd((2, 2), (2, 3), [[1, 1, 0], [0, 1, 0]])
verify_gather_nd((2, 2, 2), (2, 2), [[0, 1], [1, 0]])
if __name__ == "__main__":
test_cast()
test_zeros_ones()
......@@ -601,3 +619,4 @@ if __name__ == "__main__":
test_stack()
test_tile()
test_repeat()
test_gather_nd()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment