Commit ae21eddf by Leyuan Wang, committed by Haichen Shen

[Relay][OP] Gather_nd exposed to relay (#2945)

* gather_nd added

* gather_nd test added

* more test added

* fix lint

* fix build error

* fix lint

* comments addressed
parent fbd1c164
...@@ -92,6 +92,7 @@ This level enables additional math and transform operators.
   tvm.relay.zeros_like
   tvm.relay.ones
   tvm.relay.ones_like
   tvm.relay.gather_nd
   tvm.relay.full
   tvm.relay.full_like
   tvm.relay.cast
...@@ -225,6 +226,7 @@ Level 3 Definitions
.. autofunction:: tvm.relay.zeros_like
.. autofunction:: tvm.relay.ones
.. autofunction:: tvm.relay.ones_like
.. autofunction:: tvm.relay.gather_nd
.. autofunction:: tvm.relay.full
.. autofunction:: tvm.relay.full_like
.. autofunction:: tvm.relay.cast
...
...@@ -646,6 +646,7 @@ _identity_list = [
    "zeros_like",
    "ones_like",
    "where",
    "gather_nd",
]

_convert_map = {
...@@ -782,7 +783,6 @@ _convert_map = {
    # TODO(tvm-tvm): support all operators.
    #
    # "broadcast_to",
    # "gather_nd",
    # "Crop" : _crop_like,
}
...
...@@ -32,6 +32,7 @@ _reg.register_schedule("where", schedule_broadcast)
_reg.register_schedule("stack", schedule_injective)
_reg.register_schedule("concatenate", schedule_injective)
_reg.register_schedule("_contrib_reverse_reshape", schedule_injective)
_reg.register_schedule("gather_nd", schedule_injective)
# layout_transform
_reg.register_schedule("layout_transform", schedule_injective)
...
...@@ -651,3 +651,36 @@ def reverse_reshape(data, newshape):
    if isinstance(newshape, int):
        newshape = [newshape]
    return _make._contrib_reverse_reshape(data, list(newshape))

def gather_nd(data, indices):
    """Gather elements or slices from data into a tensor whose shape is
    defined by indices.

    Parameters
    ----------
    data : relay.Expr
        The input data to the operator.

    indices : relay.Expr
        The indices of the values to gather.

    Returns
    -------
    ret : relay.Expr
        The computed result.

    Examples
    --------
    .. code-block:: python

        data = [[0, 1], [2, 3]]
        indices = [[1, 1, 0], [0, 1, 0]]
        relay.gather_nd(data, indices) = [2, 3, 0]

        data = [[[1, 2], [3, 4]], [[5, 6], [7, 8]]]
        indices = [[0, 1], [1, 0]]
        relay.gather_nd(data, indices) = [[3, 4], [5, 6]]
    """
    return _make.gather_nd(data, indices)
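
For context beyond the docstring, here is a minimal usage sketch that builds and evaluates the new operator through the Relay interpreter, following the same API the tests in this commit use (it assumes an LLVM-enabled TVM build; the variable names are illustrative):

    import numpy as np
    import tvm
    from tvm import relay

    # Wrap gather_nd in a small Relay function.
    x = relay.var("x", relay.TensorType((2, 2), "float32"))
    y = relay.var("y", relay.TensorType((2, 3), "int32"))
    func = relay.Function([x, y], relay.gather_nd(x, y))

    x_data = np.array([[0, 1], [2, 3]], dtype="float32")
    y_data = np.array([[1, 1, 0], [0, 1, 0]], dtype="int32")

    # Evaluate on CPU with the debug interpreter; prints [2. 3. 0.]
    intrp = relay.create_executor("debug", ctx=tvm.cpu(0), target="llvm")
    print(intrp.evaluate(func)(x_data, y_data).asnumpy())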
...@@ -2122,5 +2122,75 @@ example below::
.set_attr<FTVMCompute>("FTVMCompute", ReshapeCompute)
.set_attr<TOpPattern>("TOpPattern", kInjective);
// gather_nd operator
bool GatherNDRel(const Array<Type>& types,
                 int num_inputs,
                 const Attrs& attrs,
                 const TypeReporter& reporter) {
  // `types` contains: [data, indices, result]
  CHECK_EQ(types.size(), 3);
  const auto* data = types[0].as<TensorTypeNode>();
  const auto* indices = types[1].as<TensorTypeNode>();
  if (data == nullptr) {
    CHECK(types[0].as<IncompleteTypeNode>())
        << "GatherND: expect input data type to be TensorType but get "
        << types[0];
    return false;
  }
  if (indices == nullptr) {
    CHECK(types[1].as<IncompleteTypeNode>())
        << "GatherND: expect indices type to be TensorType but get "
        << types[1];
    return false;
  }
  const size_t ndim = data->shape.size();
  // M is the leading dimension of indices, per the documented semantics.
  const IntImm* mdim = indices->shape[0].as<IntImm>();
  const size_t kdim = indices->shape.size() - 1;
  CHECK(size_t(mdim->value) <= ndim)
      << "GatherND: the first dimension of indices must not exceed the rank of data.";
  // Output shape: (Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1}).
  Array<IndexExpr> oshape;
  for (size_t i = 1; i < kdim + 1; ++i)
    oshape.push_back(indices->shape[i]);
  for (size_t i = mdim->value; i < ndim; ++i)
    oshape.push_back(data->shape[i]);
  reporter->Assign(types[2], TensorTypeNode::make(oshape, data->dtype));
  return true;
}
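
In words: the first dimension of indices gives the number M of leading data axes being indexed, and the output shape is the trailing dimensions of indices followed by the unindexed trailing dimensions of data. A short sketch of that shape rule in plain Python (the helper name is ours, for illustration only):

    # Mirror of the shape rule in GatherNDRel (illustrative, not part of the patch).
    def gather_nd_out_shape(data_shape, indices_shape):
        mdim = indices_shape[0]          # M: number of data axes being indexed
        assert mdim <= len(data_shape)   # mirrors the CHECK above
        # (Y_0, ..., Y_{K-1}) followed by (X_M, ..., X_{N-1})
        return tuple(indices_shape[1:]) + tuple(data_shape[mdim:])

    assert gather_nd_out_shape((2, 2), (2, 3)) == (3,)
    assert gather_nd_out_shape((2, 2, 2), (2, 2)) == (2,)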
Array<Tensor> GatherNDCompute(const Attrs& attrs,
                              const Array<Tensor>& inputs,
                              const Type& out_type,
                              const Target& target) {
  return { topi::gather_nd(inputs[0], inputs[1]) };
}
Expr MakeGatherND(Expr data,
                  Expr indices) {
  static const Op& op = Op::Get("gather_nd");
  return CallNode::make(op, {data, indices}, {});
}

TVM_REGISTER_API("relay.op._make.gather_nd")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
  runtime::detail::unpack_call<Expr, 2>(MakeGatherND, args, rv);
});
RELAY_REGISTER_OP("gather_nd")
.describe(R"code(Gather elements or slices from data and store to
a tensor whose shape is defined by indices.
Given data with shape (X_0, X_1, ..., X_{N-1}) and indices with
shape (M, Y_0, ..., Y_{K-1}), the output will have shape
(Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1}), where M <= N. If M == N,
output shape will simply be (Y_0, ..., Y_{K-1}).
)code" TVM_ADD_FILELINE)
.set_num_inputs(2)
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(3)
.add_type_rel("GatherND", GatherNDRel)
.set_attr<FTVMCompute>("FTVMCompute", GatherNDCompute)
.set_attr<TOpPattern>("TOpPattern", kInjective);
} // namespace relay
} // namespace tvm
...@@ -491,6 +491,21 @@ def test_forward_take():
    verify((3,4), [-1, 5], 1)
    verify((3,4), [-1, 5], 1, mode="wrap")

def test_forward_gather_nd():
    def verify(xshape, yshape, y_data):
        x_data = np.random.uniform(size=xshape).astype("float32")
        ref_res = mx.nd.gather_nd(mx.nd.array(x_data), mx.nd.array(y_data))
        mx_sym = mx.sym.gather_nd(mx.sym.var("x_data"), mx.sym.var("y_data"))
        new_sym, _ = relay.frontend.from_mxnet(mx_sym,
                                               {"x_data": xshape, "y_data": yshape},
                                               {"x_data": "float32", "y_data": "int32"})
        for target, ctx in ctx_list():
            for kind in ["graph", "debug"]:
                intrp = relay.create_executor(kind, ctx=ctx, target=target)
                op_res = intrp.evaluate(new_sym)(x_data, y_data)
                tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())
    verify((2, 2), (2, 3), [[1, 1, 0], [0, 1, 0]])
    verify((2, 2, 2), (2, 2), [[0, 1], [1, 0]])
if __name__ == '__main__':
    test_forward_mlp()
    test_forward_vgg()
...@@ -527,3 +542,4 @@ if __name__ == '__main__':
    test_forward_embedding()
    test_forward_smooth_l1()
    test_forward_take()
    test_forward_gather_nd()
...@@ -553,7 -553,6 @@ def test_stack():
    verify_stack([(2, 2, 3, 4), (2, 2, 3, 4), (2, 2, 3, 4), (2, 2, 3, 4)], -1)

def test_reverse():
    def verify_reverse(dshape, axis):
        x = relay.var("x", relay.TensorType(dshape, "float32"))
...@@ -573,6 +572,25 @@ def test_reverse():
    verify_reverse((2, 3, 4), -1)

def test_gather_nd():
    def verify_gather_nd(xshape, yshape, y_data):
        x = relay.var("x", relay.TensorType(xshape, "float32"))
        y = relay.var("y", relay.TensorType(yshape, "int32"))
        z = relay.gather_nd(x, y)
        func = relay.Function([x, y], z)
        x_data = np.random.uniform(size=xshape).astype("float32")
        # NumPy reference: index the leading axes of x_data with y_data.
        ref_res = x_data[tuple(y_data)]
        for target, ctx in ctx_list():
            for kind in ["graph", "debug"]:
                intrp = relay.create_executor(kind, ctx=ctx, target=target)
                op_res = intrp.evaluate(func)(x_data, y_data)
                tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)
    verify_gather_nd((2, 2), (2, 3), [[1, 1, 0], [0, 1, 0]])
    verify_gather_nd((2, 2, 2), (2, 2), [[0, 1], [1, 0]])
if __name__ == "__main__": if __name__ == "__main__":
test_cast() test_cast()
test_zeros_ones() test_zeros_ones()
...@@ -601,3 +619,4 @@ if __name__ == "__main__": ...@@ -601,3 +619,4 @@ if __name__ == "__main__":
test_stack() test_stack()
test_tile() test_tile()
test_repeat() test_repeat()
    test_gather_nd()