Commit 58888b21 by Pariksheet Pinjari, committed by Tianqi Chen

[TOPI] add take (#1158)

parent bd988658
@@ -398,5 +398,86 @@ inline Array<Tensor> split_sections(const Tensor& x,
  return split(x, split_indices, axis, name, tag);
}

/*!
 * \brief Take elements from a flattened input array when axis is None.
 *
 * \param a The source array.
 * \param indices The indices of the values to extract.
 * \param name The name of the operation.
 * \param tag The tag to mark the operation.
 *
 * \return A Tensor whose op member is the take operation
 */
inline Tensor take(const Tensor& a,
                   const Tensor& indices,
                   std::string name = "tensor",
                   std::string tag = kInjective) {
  Array<Expr> a_shape = a->shape;
  Array<Expr> out_shape;
  for (size_t j = 0; j < indices->shape.size(); ++j) {
    out_shape.push_back(indices->shape[j]);
  }
  return compute(
      out_shape, [&](const Array<Var>& out_index) {
        Array<Expr> indices_position;
        for (size_t j = 0; j < indices->shape.size(); ++j) {
          indices_position.push_back(out_index[j]);
        }
        return a(UnavelIndex(indices(indices_position), a_shape));
      }, name, tag);
}

/*!
 * \brief Take elements from an array along an axis.
 *
 * \param a The source array.
 * \param indices The indices of the values to extract.
 * \param axis The axis over which to select values. By default,
 * the flattened input array is used.
 * \param name The name of the operation.
 * \param tag The tag to mark the operation.
 *
 * \return A Tensor whose op member is the take operation
 */
inline Tensor take(const Tensor& a,
                   const Tensor& indices,
                   int axis,
                   std::string name = "tensor",
                   std::string tag = kInjective) {
  if (axis < 0) {
    axis += static_cast<int>(a->shape.size());
  }
  CHECK_LT(axis, a->shape.size()) << "axis out of bounds";
  int indices_len = static_cast<int>(indices->shape.size());
  Array<Expr> out_shape;
  for (size_t i = 0; i < a->shape.size(); ++i) {
    if (axis == static_cast<int>(i)) {
      for (size_t j = 0; j < indices->shape.size(); ++j) {
        out_shape.push_back(indices->shape[j]);
      }
    } else {
      out_shape.push_back(a->shape[i]);
    }
  }
  return compute(
      out_shape, [&](const Array<Var>& out_index) {
        Array<Expr> indices_position;
        for (size_t j = axis; j < static_cast<size_t>(axis + indices_len); ++j) {
          indices_position.push_back(out_index[j]);
        }
        Array<Expr> real_indices;
        for (size_t j = 0; j < static_cast<size_t>(axis); ++j) {
          real_indices.push_back(out_index[j]);
        }
        real_indices.push_back(indices(indices_position));
        for (size_t j = axis + indices_len; j < out_index.size(); ++j) {
          real_indices.push_back(out_index[j]);
        }
        return a(real_indices);
      }, name, tag);
}
} // namespace topi
#endif // TOPI_TRANSFORM_H_
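
The axis overload splices the shape of indices into a's shape in place of dimension axis, so the result has rank a.ndim + indices.ndim - 1. A hedged NumPy check of that shape rule (illustration only, not part of this patch):

import numpy as np

a = np.arange(2 * 3 * 4, dtype="float32").reshape(2, 3, 4)
indices = np.array([[2, 1], [0, 0]], dtype="int32")
out = np.take(a, indices, axis=1)
# shape rule: a.shape[:axis] + indices.shape + a.shape[axis+1:]
assert out.shape == (2, 2, 2, 4)
assert (out[:, 0, 0, :] == a[:, 2, :]).all()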
@@ -286,3 +286,28 @@ def split(ary, indices_or_sections, axis=0):
                       lambda *indices: _compute(begin_id, *indices), name="s%d" % i)
            for i, (out_shape, begin_id) in enumerate(zip(out_shapes, begin_ids))]
    # pylint: enable=cell-var-from-loop

@tvm.tag_scope(tag=tag.INJECTIVE)
def take(a, indices, axis=None):
    """Take elements from an array along an axis.

    Parameters
    ----------
    a : tvm.Tensor
        The source array.

    indices : tvm.Tensor
        The indices of the values to extract.

    axis : int, optional
        The axis over which to select values. By default,
        the flattened input array is used.

    Returns
    -------
    ret : tvm.Tensor
    """
    if axis is None:
        return cpp.take(a, indices)
    return cpp.take(a, indices, int(axis))
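
A short usage sketch for the wrapper above, assuming a TVM build of this vintage with the LLVM backend enabled (the placeholder names are illustrative):

import tvm
import topi

A = tvm.placeholder((4, 3), name="A")
idx = tvm.placeholder((2,), dtype="int32", name="idx")
T = topi.take(A, idx, axis=0)  # rows of A selected by idx; shape (2, 3)
s = tvm.create_schedule(T.op)
f = tvm.build(s, [A, idx, T], "llvm", name="take")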
@@ -270,6 +270,16 @@ TVM_REGISTER_GLOBAL("topi.split")
  }
  });
TVM_REGISTER_GLOBAL("topi.take")
.set_body([](TVMArgs args, TVMRetValue *rv) {
if (args.size() == 2) {
*rv = take(args[0], args[1]);
} else {
int axis = args[2];
*rv = take(args[0], args[1], axis);
}
});
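
The registration above dispatches on argument count, which is how the Python wrapper selects an overload simply by passing or omitting axis. Reusing the placeholders from the previous sketch (hedged, for illustration):

# Both calls reach the same PackedFunc registered as "topi.take";
# arity picks the C++ overload.
flat = topi.cpp.take(A, idx)     # 2 args -> flattened take
rows = topi.cpp.take(A, idx, 0)  # 3 args -> take along axis 0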
/* Ops from nn/batch_norm.h */
TVM_REGISTER_GLOBAL("topi.nn.batch_norm_inference")
.set_body([](TVMArgs args, TVMRetValue *rv) {
......
@@ -207,6 +207,46 @@ def verify_flip(in_shape, axis):
    for device in ["llvm", "cuda", "opencl"]:
        check_device(device)

def verify_take(src_shape, indices_src, axis=None):
    src_dtype = "float32"
    indices_dtype = "int32"
    indices_src = np.array(indices_src, dtype=indices_dtype)
    A = tvm.placeholder(shape=src_shape, dtype=src_dtype, name="A")
    indices = tvm.placeholder(shape=indices_src.shape, dtype=indices_dtype, name="indices")
    if axis is None:
        out_tensor = topi.take(a=A, indices=indices)
    else:
        out_tensor = topi.take(a=A, indices=indices, axis=axis)

    def check_device(device):
        ctx = tvm.context(device, 0)
        if not ctx.exist:
            print("Skip because %s is not enabled" % device)
            return
        print("Running on target: %s" % device)
        with tvm.target.create(device):
            s = topi.generic.schedule_injective(out_tensor)

        foo = tvm.build(s, [A, indices, out_tensor], device, name="take")
        data_npy = np.arange(np.prod(src_shape), dtype=src_dtype).reshape(src_shape)
        if axis is None:
            out_npys = np.take(data_npy, indices_src)
        else:
            out_npys = np.take(data_npy, indices_src, axis=axis)
        data_nd = tvm.nd.array(data_npy, ctx)
        indices_nd = tvm.nd.array(indices_src, ctx)
        out_nd = tvm.nd.empty(out_npys.shape, ctx=ctx, dtype=src_dtype)
        foo(data_nd, indices_nd, out_nd)
        np.testing.assert_allclose(out_nd.asnumpy(), out_npys)

    for device in ["llvm", "opencl"]:
        check_device(device)

def test_expand_dims():
    verify_expand_dims((3, 10), (3, 10, 1, 1), 2, 2)
    verify_expand_dims((3, 10), (1, 3, 10), -3, 1)
@@ -262,6 +302,15 @@ def test_expand_like():
    verify_expand_like((3, 4), (3, 5, 4), [1])
    verify_expand_like((5, 7), (5, 6, 7, 8), [1, 3])

def test_take():
    verify_take((4,), [1])
    verify_take((4,), [[0, 1, 2, 3]])
    verify_take((3, 3, 3), [[11, 25]])
    verify_take((4,), [[0, 1], [2, 3]])
    verify_take((4,), [1], 0)
    verify_take((2, 2), [[[1, 0], [0, 1]]], 0)
    verify_take((2, 2), [[[1, 0], [0, 1]]], 1)
    verify_take((4, 3, 5, 6), [[2, 1, 0, 0]], -2)

if __name__ == "__main__":
    test_concatenate()
@@ -272,3 +321,4 @@ if __name__ == "__main__":
    test_split()
    test_flip()
    test_expand_like()
    test_take()
@@ -167,6 +167,45 @@ def verify_split(src_shape, indices_or_sections, axis):
    for device in ["llvm", "nvptx", "cuda", "opencl", "metal", "rocm"]:
        check_device(device)

def verify_take(src_shape, indices_src, axis=None):
    src_dtype = "float32"
    indices_dtype = "int32"
    indices_src = np.array(indices_src, dtype=indices_dtype)
    A = tvm.placeholder(shape=src_shape, dtype=src_dtype, name="A")
    indices = tvm.placeholder(shape=indices_src.shape, dtype=indices_dtype, name="indices")
    if axis is None:
        out_tensor = topi.cpp.take(A, indices)
    else:
        out_tensor = topi.cpp.take(A, indices, axis)

    def check_device(device):
        ctx = tvm.context(device, 0)
        if not ctx.exist:
            print("Skip because %s is not enabled" % device)
            return
        print("Running on target: %s" % device)
        with tvm.target.create(device):
            s = topi.generic.schedule_injective(out_tensor)

        foo = tvm.build(s, [A, indices, out_tensor], device, name="take")
        data_npy = np.arange(np.prod(src_shape), dtype=src_dtype).reshape(src_shape)
        if axis is None:
            out_npys = np.take(data_npy, indices_src)
        else:
            out_npys = np.take(data_npy, indices_src, axis=axis)
        data_nd = tvm.nd.array(data_npy, ctx)
        indices_nd = tvm.nd.array(indices_src, ctx)
        out_nd = tvm.nd.empty(out_npys.shape, ctx=ctx, dtype=src_dtype)
        foo(data_nd, indices_nd, out_nd)
        np.testing.assert_allclose(out_nd.asnumpy(), out_npys)

    for device in ["llvm", "opencl"]:
        check_device(device)

def test_expand_dims():
    verify_expand_dims((3, 10), (3, 10, 1, 1), 2, 2)
@@ -209,6 +248,16 @@ def test_split():
    verify_split((2, 12, 3), [2, 4], 1)
    verify_split((10, 12, 24), [5, 7, 9], -1)

def test_take():
    verify_take((4,), [1])
    verify_take((4,), [[0, 1, 2, 3]])
    verify_take((3, 3, 3), [[11, 25]])
    verify_take((4,), [[0, 1], [2, 3]])
    verify_take((4,), [1], 0)
    verify_take((2, 2), [[[1, 0], [0, 1]]], 0)
    verify_take((2, 2), [[[1, 0], [0, 1]]], 1)
    verify_take((4, 3, 5, 6), [[2, 1, 0, 0]], -2)

if __name__ == "__main__":
    test_concatenate()
    test_tranpose()
@@ -216,3 +265,4 @@ if __name__ == "__main__":
    test_reshape()
    test_squeeze()
    test_split()
    test_take()
\ No newline at end of file