Commit c1c32758 by Pariksheet Pinjari Committed by Tianqi Chen

[TOPI] Slice operator (#1165)

parent a9313787
......@@ -17,11 +17,14 @@ List of operators
topi.clip
topi.cast
topi.transpose
topi.flip
topi.strided_slice
topi.expand_dims
topi.reshape
topi.squeeze
topi.concatenate
topi.split
topi.take
topi.full
topi.full_like
topi.greater
......@@ -72,11 +75,14 @@ topi
.. autofunction:: topi.clip
.. autofunction:: topi.cast
.. autofunction:: topi.transpose
.. autofunction:: topi.flip
.. autofunction:: topi.strided_slice
.. autofunction:: topi.expand_dims
.. autofunction:: topi.reshape
.. autofunction:: topi.squeeze
.. autofunction:: topi.concatenate
.. autofunction:: topi.split
.. autofunction:: topi.take
.. autofunction:: topi.full
.. autofunction:: topi.full_like
.. autofunction:: topi.greater
......
......@@ -67,6 +67,24 @@ inline std::vector<int> GetConstIntValues(Array<Expr> exprs, const std::string&
}
/*!
* \brief Get the value of all the constant integer expressions in the given array
*
* \param exprs The array of expressions to get the values of
* \param var_name The name to be used when logging an error in the event that any
* of the expressions are not constant integers.
*
* \return A vector of the int64_t values
*/
inline std::vector<int64_t> GetConstInt64Values(Array<Expr> exprs, const std::string& var_name) {
std::vector<int64_t> result;
for (auto expr : exprs) {
CHECK(IsConstInt(expr)) << "All elements of " << var_name << " must be constant integers";
result.push_back(GetConstInt(expr));
}
return result;
}
/*!
* \brief Check weather the two expressions are equal or not, if not simplify the expressions and check again
* \note This is stronger equality check than tvm::ir::Equal
*
......
......@@ -366,6 +366,87 @@ inline Array<Tensor> split(const Tensor& x,
}
/*!
* \brief strided_slice of a tensor
*
* \param x The input tensor
* \param begin The indices to begin with in the slicing
* \param end Indicies indicating end of the slice
* \param strides Specifies the stride values, it can be negative
* in that case, the input tensor will be reversed in that particular axis
* \param name The name of the operation
* \param tag The tag to mark the operation
*
* \return A Tensor whose op member is the split operation
*/
inline Tensor strided_slice(const Tensor& x,
const Array<Expr>& begin,
const Array<Expr>& end,
const Array<Expr>& strides,
std::string name = "tensor",
std::string tag = kInjective) {
size_t src_tensor_dim = static_cast<size_t>(x->shape.size());
std::vector<int64_t> begin_vec = GetConstInt64Values(begin, "begin");
std::vector<int64_t> end_vec = GetConstInt64Values(end, "end");
std::vector<int64_t> stride_vec = GetConstInt64Values(strides, "strides");
// in case user has not provided begin indices for all the axes,
// then inflate it with default value = 0
for (size_t i = begin_vec.size(); i < src_tensor_dim; ++i) {
begin_vec.push_back(0);
}
// in case user has not provided end indices for all the axes,
// then inflate it with default value = input_tensor.shape[axis]
for (size_t i = end_vec.size(); i < src_tensor_dim; ++i) {
end_vec.push_back(GetConstInt(x->shape[i]));
}
// in case user has not provided stride values,
// then inflate it with default value = 1
for (size_t i = stride_vec.size(); i < src_tensor_dim; ++i) {
stride_vec.push_back(1);
}
Array<Expr> out_shape;
Array<Expr> begin_expr;
Array<Expr> strides_expr;
for (size_t i = 0; i < src_tensor_dim; ++i) {
int64_t begin_range = stride_vec[i] < 0 ? -1 : 0;
int64_t dim_i = GetConstInt(x->shape[i]);
int64_t end_range = stride_vec[i] < 0 ? dim_i - 1 : dim_i;
// transform negative indices to positive value, clips on the correct range
auto index_canonicalization = [dim_i, begin_range, end_range](int64_t index) {
if (index < 0) {
index += dim_i;
}
return std::min(std::max(index, begin_range), end_range);
};
int64_t begin_i = index_canonicalization(begin_vec[i]);
int64_t end_i = index_canonicalization(end_vec[i]);
int interval = std::abs(end_i - begin_i);
int slice_size = static_cast<int>((interval
+ std::abs(stride_vec[i]) - 1) / std::abs(stride_vec[i]));
CHECK(stride_vec[i] < 0 ? (end_i < begin_i) : (begin_i < end_i))
<< ": Input [Begin=" << begin_vec[i] << ", End=" << end_vec[i]
<< "] is invalid for axis=" << i;
begin_expr.push_back(make_const(begin[0].type(), begin_i));
strides_expr.push_back(make_const((strides.size() != 0 ? strides[0].type() : begin[0].type()),
stride_vec[i]));
out_shape.push_back(slice_size);
}
return compute(
out_shape, [&](const Array<Var>& indices) {
Array<Expr> real_indices;
for (size_t i = 0; i < src_tensor_dim; ++i) {
real_indices.push_back(indices[i] * strides_expr[i] + begin_expr[i]);
}
return x(real_indices);
}, name, tag);
}
/*!
* \brief Split a tensor into a number of sub-tensors
*
* \param x The input tensor
......
......@@ -130,6 +130,32 @@ def flip(a, axis=0):
return cpp.flip(a, axis)
@tvm.tag_scope(tag=tag.INJECTIVE)
def strided_slice(a, begin, end, strides=None):
"""Slice of an array.
Parameters
----------
a : tvm.Tensor
The tensor to be sliced.
begin: list of int
The indices to begin with in the slicing.
end: list of int
Indicies indicating end of the slice.
strides: list of int, optional
Specifies the stride values, it can be negative
in that case, the input tensor will be reversed
in that particular axis.
Returns
-------
ret : tvm.Tensor
"""
return cpp.strided_slice(a, begin, end, strides)
@tvm.tag_scope(tag=tag.INJECTIVE)
def reshape(a, newshape):
"""Reshape the array
......
......@@ -280,6 +280,11 @@ TVM_REGISTER_GLOBAL("topi.take")
}
});
TVM_REGISTER_GLOBAL("topi.strided_slice")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = strided_slice(args[0], args[1], args[2], args[3]);
});
/* Ops from nn/batch_norm.h */
TVM_REGISTER_GLOBAL("topi.nn.batch_norm_inference")
.set_body([](TVMArgs args, TVMRetValue *rv) {
......
......@@ -246,6 +246,40 @@ def verify_take(src_shape, indices_src, axis=None):
for device in ["llvm", "opencl"]:
check_device(device)
def verify_strided_slice(in_shape, begin, end, stride=None):
stride = stride if stride else [1, 1, 1]
A = tvm.placeholder(shape=in_shape, name="A")
B = topi.strided_slice(A, begin, end, stride) + 1
def test_forward(x, begin, end, stride):
return x[begin[0]:end[0]:stride[0],
begin[1]:end[1]:stride[1], begin[2]:end[2]:stride[2]] + 1
def check_device(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
with tvm.target.create(device):
s = topi.generic.schedule_injective(B)
foo = tvm.build(s, [A, B], device, name="stride_slice")
x_np = np.random.uniform(size=in_shape).astype(A.dtype)
out_npy = test_forward(x_np, begin, end, stride)
data_nd = tvm.nd.array(x_np, ctx)
out_nd = tvm.nd.empty(out_npy.shape, ctx=ctx, dtype=A.dtype)
foo(data_nd, out_nd)
np.testing.assert_allclose(out_nd.asnumpy(), out_npy)
for device in ["llvm", "opencl"]:
check_device(device)
def test_strided_slice():
verify_strided_slice((3, 4, 3), [0, 0, 0], [4, -5, 4], [1, -1, 2])
verify_strided_slice((3, 4, 3), [1, 1, 0], [4, 4, 3], [2, 1, 1])
verify_strided_slice((3, 4, 3), [1, -1, 0], [4, -5, 3], [2, -1, 1])
verify_strided_slice((3, 4, 3), [1, 0, 0], [2, 2, 3], [1, 1, 2])
verify_strided_slice((3, 4, 3), [1, -1, 0], [2, -3, 3], [1, -1, 1])
verify_strided_slice((3, 4, 3), [1, 1, 0], [4, 4, 3])
def test_expand_dims():
verify_expand_dims((3, 10), (3, 10, 1, 1), 2, 2)
......@@ -322,3 +356,4 @@ if __name__ == "__main__":
test_flip()
test_expand_like()
test_take()
test_strided_slice()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment