Commit add1f90e by Haichen Shen Committed by Tianqi Chen

[NNVM/TOPI][OP] gather_nd (#2041)

parent 2005f852
......@@ -30,6 +30,7 @@ List of operators
topi.concatenate
topi.split
topi.take
topi.gather_nd
topi.full
topi.full_like
topi.nn.relu
......@@ -103,6 +104,7 @@ topi
.. autofunction:: topi.concatenate
.. autofunction:: topi.split
.. autofunction:: topi.take
.. autofunction:: topi.gather_nd
.. autofunction:: topi.full
.. autofunction:: topi.full_like
.. autofunction:: topi.max
......
......@@ -61,6 +61,7 @@ This level enables fully connected multi-layer perceptron.
nnvm.symbol.flip
nnvm.symbol.lrn
nnvm.symbol.where
nnvm.symbol.gather_nd
**Level 2: Convolutions**
......@@ -197,6 +198,7 @@ Detailed Definitions
.. autofunction:: nnvm.symbol.flip
.. autofunction:: nnvm.symbol.lrn
.. autofunction:: nnvm.symbol.where
.. autofunction:: nnvm.symbol.gather_nd
.. autofunction:: nnvm.symbol.conv2d
.. autofunction:: nnvm.symbol.conv2d_transpose
......
......@@ -290,7 +290,7 @@ _identity_list = ['__add_scalar__', '__add_symbol__', '__div_scalar__',
'elemwise_div', 'elemwise_mul', 'elemwise_sub', 'exp',
'flatten', 'log', 'log_softmax', 'max', 'min', 'negative',
'ones_like', 'relu', 'sigmoid', 'slice_like', 'softmax',
'sum', 'tanh', 'transpose', 'zeros_like']
'sum', 'tanh', 'transpose', 'zeros_like', 'gather_nd']
_convert_map = {
'_copy' : _rename('copy'),
......
......@@ -86,3 +86,7 @@ reg.register_schedule("slice_like", _fschedule_injective)
# where
reg.register_pattern("where", OpPattern.INJECTIVE)
reg.register_schedule("where", _fschedule_injective)
# gather_nd
reg.register_pattern("gather_nd", OpPattern.INJECTIVE)
reg.register_schedule("gather_nd", _fschedule_injective)
......@@ -1003,7 +1003,7 @@ Examples::
[ 3, 4]]
flip(x) = [[ 3., 4.],
[ 1., 2.]]
[ 1., 2.]]
x = [[[ 1., 2.],
[ 3., 4.]],
......@@ -1012,16 +1012,16 @@ Examples::
[ 7., 8.]]]
flip(x) = [[[ 5., 6.],
[ 7., 8.]],
[ 7., 8.]],
[[ 1., 2.],
[ 3., 4.]]]
[[ 1., 2.],
[ 3., 4.]]]
flip(x, axis=1) = [[[ 3., 4.],
[ 1., 2.]],
[ 1., 2.]],
[[ 7., 8.],
[ 5., 6.]]]
[[ 7., 8.],
[ 5., 6.]]]
)code" NNVM_ADD_FILELINE)
.add_argument("data", "Tensor", "Source input")
.add_arguments(FlipParam::__FIELDS__())
......@@ -1353,5 +1353,107 @@ Examples::
})
.set_support_level(4);
// gather_nd
inline bool GatherNDInferShape(const nnvm::NodeAttrs& attrs,
std::vector<TShape>* in_attrs,
std::vector<TShape>* out_attrs) {
CHECK_EQ(in_attrs->size(), 2U);
CHECK_EQ(out_attrs->size(), 1U);
const TShape& data_shape = in_attrs->at(0);
const TShape& indices_shape = in_attrs->at(1);
CHECK_GT(indices_shape.ndim(), 1) << "indices must have at least 2 dimensions";
CHECK_LE(indices_shape[0], data_shape.ndim()) <<
"dim 0 of indices must be no more than rank of data";
std::vector<dim_t> oshape;
for (size_t i = 1; i < indices_shape.ndim(); ++i) {
oshape.push_back(indices_shape[i]);
}
for (size_t i = indices_shape[0]; i < data_shape.ndim(); ++i) {
oshape.push_back(data_shape[i]);
}
if (oshape.size() == 0) {
oshape.push_back(1);
}
NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_attrs, 0,
TShape(oshape.begin(), oshape.end()));
return true;
}
inline bool GatherNDInferType(const NodeAttrs &attrs,
std::vector<int> *in_attrs,
std::vector<int> *out_attrs) {
CHECK_EQ(in_attrs->size(), 2U);
CHECK_EQ(out_attrs->size(), 1U);
NNVM_ASSIGN_OUTPUT_TYPE(attrs, *out_attrs, 0, (*in_attrs)[0]);
return true;
}
inline bool GatherNDCorrectLayout(const NodeAttrs& attrs,
std::vector<Layout> *ilayouts,
const std::vector<Layout> *last_ilayouts,
std::vector<Layout> *olayouts) {
CHECK_EQ(ilayouts->size(), last_ilayouts->size());
CHECK_EQ(olayouts->size(), 1U);
for (size_t i = 0; i < ilayouts->size(); ++i) {
const Layout& input = last_ilayouts->at(i).defined() ?
last_ilayouts->at(i) : ilayouts->at(i);
NNVM_ASSIGN_LAYOUT(*ilayouts, i, input);
}
return true;
}
NNVM_REGISTER_OP(gather_nd)
.describe(R"code(
Gather elements or slices from ``data`` into a tensor specified by ``indices``.
The shape of output tensor is inferred from ``indices``. Given ``data`` with
shape ``(X0, X1, ..., X_{N-1})`` and ``indices`` with shape ``(Y_0, ...,
Y_{M-1})``, the output will have shape ``(Y_1, ..., Y_{M-1}, X_{Y_0}, ...,
X_{N-1})`` when ``Y_0 < N``, or ``(Y_1, ..., Y_{M-1})`` when ``Y_0 == N``. The
operator is invalid when ``Y_0 > N``.
The element in output is defined as follows::
output[y_1, ..., y_{M-1}, x_{Y_0}, ..., x_{N-1}] = data[indices[0, y_1, ..., y_{M-1}],
...,
indices[Y_0-1, y_1, ..., y_{M-1}],
x_{Y_0}, ..., x_{N-1}]
Examples::
data = [[0, 1], [2, 3]]
indices = [[1], [0]]
gather_nd(data, indices) = [2]
data = [[0, 1], [2, 3]]
indices = [[1, 1, 0], [0, 1, 0]]
gather_nd(data, indices) = [2, 3, 0]
data = [[[1, 2], [3, 4]], [[5, 6], [7, 8]]]
indices = [[0, 1], [1, 0]]
gather_nd(data, indices) = [[3, 4], [5, 6]]
)code" NNVM_ADD_FILELINE)
.add_argument("data", "Tensor", "Input data.")
.add_argument("indices", "Tensor", "Indices of data")
.set_num_inputs(2)
.set_num_outputs(1)
.set_attr<FInferShape>("FInferShape", GatherNDInferShape)
.set_attr<FInferType>("FInferType", GatherNDInferType)
.set_attr<FCorrectLayout>("FCorrectLayout", GatherNDCorrectLayout)
.set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs,
const Array<Tensor>& out_info) {
return Array<Tensor>{
topi::gather_nd(inputs[0], inputs[1]) };
})
.set_attr<FListInputNames>("FListInputNames", [](const NodeAttrs& attrs) {
return std::vector<std::string>{"data", "indices"};
})
.set_support_level(3);
} // namespace top
} // namespace nnvm
......@@ -533,6 +533,36 @@ def test_l2_normalize():
verify_l2_normalize((1, 3, 20, 20), 0.001, (1,))
verify_l2_normalize((1, 3, 20, 20), 0.001, (1, 2))
def verify_gather_nd(src_shape, indices_src):
src_dtype = "float32"
indices_dtype = "int32"
indices_src = np.array(indices_src, dtype=indices_dtype)
a = sym.Variable("a", shape=src_shape)
indices = sym.Variable("indices", shape=indices_src.shape)
y = sym.gather_nd(a, indices)
def forward(a, indices):
return topi.testing.gather_nd_python(a, indices)
a_src = np.arange(np.prod(src_shape), dtype=src_dtype).reshape(src_shape)
check_function(y, forward,
dtype={'a': src_dtype, 'indices': indices_dtype},
values={'a': a_src, 'indices': indices_src})
def test_gather_nd():
verify_gather_nd((4,), [[1]])
verify_gather_nd((4,), [[1, 3, 2]])
verify_gather_nd((2, 3), [[1]])
verify_gather_nd((2, 3), [[1], [0]])
verify_gather_nd((2, 3), [[1, 0], [0, 2]])
verify_gather_nd((2, 3, 4), [[1, 0], [0, 2]])
verify_gather_nd((2, 3, 4), [[1, 0], [0, 2], [3, 1]])
verify_gather_nd((2, 3, 4), [[[1, 0], [0, 1]], [[0, 2], [1, 2]],
[[3, 1], [0, 2]]])
verify_gather_nd((2, 3, 4, 5), [[1, 0], [0, 2]])
verify_gather_nd((2, 3, 4, 5), [[1, 0], [2, 1], [3, 2], [4, 2]])
if __name__ == "__main__":
test_check_function()
test_split()
......@@ -556,3 +586,4 @@ if __name__ == "__main__":
test_lrn()
test_l2_normalize()
test_strided_slice()
test_gather_nd()
......@@ -356,6 +356,26 @@ def test_reduce():
check((4, 5, 10), (1, 5, 1), axis=(0, 2), keepdims=True)
def test_gather_nd():
def check(data_shape, indices_shape, out_shape):
x = sym.Variable("x", shape=data_shape)
indices = sym.Variable("indices", shape=indices_shape)
y = sym.gather_nd(x, indices, name="y")
sdict = infer_shape(y)
assert(tuple(sdict["y"][0]) == tuple(out_shape))
check((4,), (1, 1), (1,))
check((4,), (1, 3), (3,))
check((2, 3), (1, 1), (1, 3))
check((2, 3), (2, 1), (1,))
check((2, 3), (2, 5, 6), (5, 6))
check((2, 3, 4), (1, 1), (1, 3, 4))
check((2, 3, 4), (2, 1), (1, 4))
check((2, 3, 4), (2, 5), (5, 4))
check((2, 3, 4), (2, 5, 6), (5, 6, 4))
check((2, 3, 4, 5), (2, 6, 7), (6, 7, 4, 5))
if __name__ == "__main__":
test_conv2d_packed()
test_expand_dims()
......@@ -376,3 +396,4 @@ if __name__ == "__main__":
test_transpose()
test_prelu()
test_squeeze()
test_gather_nd()
......@@ -640,6 +640,60 @@ inline Tensor where(const Tensor& condition,
}
/*!
* \brief Gather elements from a n-dimension array.
*
* \param data The source array.
* \param indices The indices of the values to extract.
* \param name The name of the operation.
* \param tag The tag to mark the operation.
*
* \return A Tensor whose op member is the gather_nd operation
*/
inline Tensor gather_nd(const Tensor& data,
const Tensor& indices,
std::string name = "tensor",
std::string tag = kInjective) {
size_t ndim_d = data->shape.size();
size_t ndim_i = indices->shape.size();
CHECK_GT(ndim_i, 1) << "indices tensor must have at least 2 dimensions";
size_t indices_dim0 = static_cast<size_t>(GetConstInt(indices->shape[0]));
CHECK_LE(indices_dim0, ndim_d) << "dim 0 of indices tensor must be no more "
<< "than dimensions of data tensor";
Array<Expr> out_shape;
for (size_t i = 1; i < ndim_i; ++i) {
out_shape.push_back(indices->shape[i]);
}
for (size_t i = indices_dim0; i < ndim_d; ++i) {
out_shape.push_back(data->shape[i]);
}
if (out_shape.size() == 0) {
out_shape.push_back(make_const(Int(32), 1));
}
return compute(
out_shape, [&](const Array<Var>& out_index) {
Array<Expr> indices_position;
indices_position.push_back(0);
for (size_t i = 0; i < ndim_i - 1; ++i) {
indices_position.push_back(out_index[i]);
}
Array<Expr> real_indices;
for (size_t i = 0; i < indices_dim0; ++i) {
indices_position.Set(0, make_const(Int(32), i));
if (indices->dtype.is_int()) {
real_indices.push_back(indices(indices_position));
} else {
real_indices.push_back(
tvm::cast(tvm::Int(32), indices(indices_position)));
}
}
for (size_t i = ndim_i - 1; i < out_index.size(); ++i) {
real_indices.push_back(out_index[i]);
}
return data(real_indices);
}, name, tag);
}
/*!
* \brief Creates an operation that calculates a matrix multiplication
* (row-major notation):
* A(i, k) * B(k, j), if trans_a == trans_b
......
......@@ -18,3 +18,4 @@ from .region_python import region_python
from .shortcut_python import shortcut_python
from .lrn_python import lrn_python
from .l2_normalize_python import l2_normalize_python
from .gather_nd_python import gather_nd_python
# pylint: disable=invalid-name, line-too-long, unused-variable, too-many-locals
"""gather_nd in python"""
import numpy as np
def gather_nd_python(a_np, indices_np):
""" Python version of GatherND operator
Parameters
----------
a_np : numpy.ndarray
Numpy array
indices_np : numpy.ndarray
Numpy array
Returns
-------
b_np : numpy.ndarray
Numpy array
"""
a_shape = a_np.shape
indices_np = indices_np.astype('int32')
indices_shape = indices_np.shape
assert len(indices_shape) > 1
assert indices_shape[0] <= len(a_shape)
b_shape = list(indices_shape[1:])
for i in range(indices_shape[0], len(a_shape)):
b_shape.append(a_shape[i])
b_np = np.zeros(b_shape)
for idx in np.ndindex(*indices_shape[1:]):
a_idx = []
for i in range(indices_shape[0]):
indices_pos = tuple([i] + list(idx))
a_idx.append(indices_np[indices_pos])
b_np[idx] = a_np[tuple(a_idx)]
return b_np
......@@ -240,6 +240,24 @@ def take(a, indices, axis=None):
return cpp.take(a, indices, int(axis))
def gather_nd(a, indices):
"""Gather elements from a n-dimension array..
Parameters
----------
a : tvm.Tensor
The source array.
indices : tvm.Tensor
The indices of the values to extract.
Returns
-------
ret : tvm.Tensor
"""
return cpp.gather_nd(a, indices)
def matmul(a, b, transp_a=False, transp_b=False):
"""
Creates an operation that calculates a matrix multiplication (row-major notation):
......
......@@ -291,6 +291,11 @@ TVM_REGISTER_GLOBAL("topi.where")
*rv = where(args[0], args[1], args[2]);
});
TVM_REGISTER_GLOBAL("topi.gather_nd")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = gather_nd(args[0], args[1]);
});
TVM_REGISTER_GLOBAL("topi.matmul")
.set_body([](TVMArgs args, TVMRetValue *rv) {
switch ( args.size() ) {
......
......@@ -2,6 +2,7 @@
import numpy as np
import tvm
import topi
import topi.testing
from common import get_all_backend
......@@ -275,6 +276,38 @@ def verify_strided_slice(in_shape, begin, end, stride=None):
for device in ["llvm", "opencl", "sdaccel", "aocl_sw_emu"]:
check_device(device)
def verify_gather_nd(src_shape, indices_src, indices_dtype):
src_dtype = "float32"
indices_src = np.array(indices_src, dtype=indices_dtype)
A = tvm.placeholder(shape=src_shape, dtype=src_dtype, name="A")
indices = tvm.placeholder(shape=indices_src.shape, dtype=indices_dtype, name="indices")
out_tensor = topi.gather_nd(a=A, indices=indices)
def check_device(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
with tvm.target.create(device):
s = topi.generic.schedule_injective(out_tensor)
func = tvm.build(s, [A, indices, out_tensor] , device, name="take")
shape_size = 1
for i in range(len(src_shape)):
shape_size = shape_size * src_shape[i]
data_npy = np.arange(shape_size, dtype=src_dtype).reshape((src_shape))
out_npys = topi.testing.gather_nd_python(data_npy, indices_src)
data_nd = tvm.nd.array(data_npy, ctx)
indices_nd = tvm.nd.array(indices_src, ctx)
out_nd = tvm.nd.empty(out_npys.shape, ctx=ctx, dtype=src_dtype)
func(data_nd, indices_nd, out_nd)
tvm.testing.assert_allclose(out_nd.asnumpy(), out_npys)
for device in get_all_backend():
check_device(device)
def test_strided_slice():
verify_strided_slice((3, 4, 3), [0, 0, 0], [4, -5, 4], [1, -1, 2])
verify_strided_slice((3, 4, 3), [1, 1, 0], [4, 4, 3], [2, 1, 1])
......@@ -363,6 +396,21 @@ def test_take():
verify_take((2,2), [[[1,0],[0,1]]], 1)
verify_take((4,3,5,6), [[2,1,0,0]], -2)
def test_gather_nd():
for indices_dtype in ['int32', 'float32']:
verify_gather_nd((4,), [[1.8]], indices_dtype)
verify_gather_nd((4,), [[1, 3, 2]], indices_dtype)
verify_gather_nd((2, 3), [[1]], indices_dtype)
verify_gather_nd((2, 3), [[1], [0]], indices_dtype)
verify_gather_nd((2, 3), [[1, 0], [0, 2]], indices_dtype)
verify_gather_nd((2, 3, 4), [[1, 0], [0, 2]], indices_dtype)
verify_gather_nd((2, 3, 4), [[1, 0], [0, 2], [3, 1]], indices_dtype)
verify_gather_nd((2, 3, 4), [[[1, 0], [0, 1]], [[0, 2], [1, 2]],
[[3, 1], [0, 2]]], indices_dtype)
verify_gather_nd((2, 3, 4, 5), [[1, 0], [0, 2]], indices_dtype)
verify_gather_nd((2, 3, 4, 5), [[1, 0], [2, 1], [3, 2], [4, 2]],
indices_dtype)
if __name__ == "__main__":
test_concatenate()
test_tranpose()
......@@ -374,3 +422,4 @@ if __name__ == "__main__":
test_expand_like()
test_take()
test_strided_slice()
test_gather_nd()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment