Commit bd988658 by Pariksheet Pinjari Committed by Tianqi Chen

[TOPI] flip (#1161)

parent b1c690bd
...@@ -107,6 +107,46 @@ inline Tensor transpose(const Tensor& x, ...@@ -107,6 +107,46 @@ inline Tensor transpose(const Tensor& x,
}, name, tag); }, name, tag);
} }
/*!
* \brief flip/reverse elements of an array in a particular axis
*
* \param x The input tensor
* \param axis The axis along which the tensors will be reveresed
* (allows negative indices)
* \param name The name of the operation
* \param tag The tag to mark the operation
*
* \return A Tensor whose op member is the reverse operation
*/
inline Tensor flip(const Tensor& x,
int axis = 0,
std::string name = "tensor",
std::string tag = kInjective) {
size_t src_tensor_dim = x->shape.size();
int axis_inp = axis;
if (axis < 0) {
axis = static_cast<int>(x->shape.size()) + axis;
}
CHECK((0 <= axis) && (axis < static_cast<int>(x->shape.size())))
<< "axis=" << axis_inp << " is invalid for the "
<< static_cast<int>(x->shape.size()) << "-dimensional input tensor";
// Reverse the Input Tensor in the axis specified
return compute(
x->shape, [&](const Array<Var>& indices) {
Array<Expr> real_indices;
for (size_t i = 0; i < src_tensor_dim; ++i) {
if (i == static_cast<size_t>(axis)) {
real_indices.push_back(x->shape[i] - indices[i] - 1);
} else {
real_indices.push_back(indices[i]);
}
}
return x(real_indices);
}, name, tag);
}
/*! /*!
* \brief Reshape a tensor * \brief Reshape a tensor
......
...@@ -5,6 +5,7 @@ import tvm ...@@ -5,6 +5,7 @@ import tvm
import topi import topi
from . import tag from . import tag
from .util import ravel_index, unravel_index, get_const_int, get_const_tuple from .util import ravel_index, unravel_index, get_const_int, get_const_tuple
from . import cpp
@tvm.tag_scope(tag=tag.BROADCAST) @tvm.tag_scope(tag=tag.BROADCAST)
def expand_dims(a, axis, num_newaxis=1): def expand_dims(a, axis, num_newaxis=1):
...@@ -110,6 +111,23 @@ def transpose(a, axes=None): ...@@ -110,6 +111,23 @@ def transpose(a, axes=None):
return a(*idx) return a(*idx)
return tvm.compute(new_shape, _compute) return tvm.compute(new_shape, _compute)
@tvm.tag_scope(tag=tag.INJECTIVE)
def flip(a, axis=0):
"""Flip/reverse elements of an array in a particular axis.
Parameters
----------
a : tvm.Tensor
The tensor to be expanded.
axis : int, optional
The axis along which the tensors will be reveresed.
Returns
-------
ret : tvm.Tensor
"""
return cpp.flip(a, axis)
@tvm.tag_scope(tag=tag.INJECTIVE) @tvm.tag_scope(tag=tag.INJECTIVE)
def reshape(a, newshape): def reshape(a, newshape):
......
...@@ -241,6 +241,11 @@ TVM_REGISTER_GLOBAL("topi.transpose") ...@@ -241,6 +241,11 @@ TVM_REGISTER_GLOBAL("topi.transpose")
*rv = transpose(args[0], args[1]); *rv = transpose(args[0], args[1]);
}); });
TVM_REGISTER_GLOBAL("topi.flip")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = flip(args[0], args[1]);
});
TVM_REGISTER_GLOBAL("topi.reshape") TVM_REGISTER_GLOBAL("topi.reshape")
.set_body([](TVMArgs args, TVMRetValue *rv) { .set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = reshape(args[0], args[1]); *rv = reshape(args[0], args[1]);
......
...@@ -184,6 +184,28 @@ def verify_expand_like(in_shape, out_shape, axis): ...@@ -184,6 +184,28 @@ def verify_expand_like(in_shape, out_shape, axis):
for device in ["llvm"]: for device in ["llvm"]:
check_device(device) check_device(device)
def verify_flip(in_shape, axis):
A = tvm.placeholder(shape=in_shape, name="A")
B = topi.flip(A, axis) + 1
def check_device(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
with tvm.target.create(device):
s = topi.generic.schedule_injective(B)
foo = tvm.build(s, [A, B], device, name="reverse")
x_np = np.random.uniform(size=in_shape).astype(A.dtype)
out_npy = np.flip(x_np, axis) + 1
data_nd = tvm.nd.array(x_np, ctx)
out_nd = tvm.nd.empty(out_npy.shape, ctx=ctx, dtype=A.dtype)
foo(data_nd, out_nd)
np.testing.assert_allclose(out_nd.asnumpy(), out_npy)
for device in ["llvm", "cuda", "opencl"]:
check_device(device)
def test_expand_dims(): def test_expand_dims():
verify_expand_dims((3, 10), (3, 10, 1, 1), 2, 2) verify_expand_dims((3, 10), (3, 10, 1, 1), 2, 2)
...@@ -226,6 +248,13 @@ def test_split(): ...@@ -226,6 +248,13 @@ def test_split():
verify_split((2, 12, 3), [2, 4], 1) verify_split((2, 12, 3), [2, 4], 1)
verify_split((10, 12, 24), [5, 7, 9], -1) verify_split((10, 12, 24), [5, 7, 9], -1)
def test_flip():
verify_flip((3, 4, 3), 1)
verify_flip((3, 4, 3), 0)
verify_flip((3, 4, 3), 2)
verify_flip((3, 4, 3), -1)
verify_flip((3, 4, 3), -3)
verify_flip((3, 4, 3), -2)
def test_expand_like(): def test_expand_like():
verify_expand_like((3,), (2, 3), [0]) verify_expand_like((3,), (2, 3), [0])
...@@ -241,4 +270,5 @@ if __name__ == "__main__": ...@@ -241,4 +270,5 @@ if __name__ == "__main__":
test_reshape() test_reshape()
test_squeeze() test_squeeze()
test_split() test_split()
test_flip()
test_expand_like() test_expand_like()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment