Unverified Commit 1f2c8156 by Tianqi Chen Committed by GitHub

[RELAY][OP] strided_slice (#2094)

parent 4369b7f6
@@ -123,6 +123,7 @@ This level enables additional math and transform operators.
   tvm.relay.min
   tvm.relay.mean
   tvm.relay.prod
   tvm.relay.strided_slice

**Level 5: Vision/Image Operators**

@@ -227,6 +228,7 @@ Level 4 Definitions
.. autofunction:: tvm.relay.min
.. autofunction:: tvm.relay.mean
.. autofunction:: tvm.relay.prod
.. autofunction:: tvm.relay.strided_slice
...
@@ -123,6 +123,21 @@ struct SplitAttrs : public tvm::AttrsNode<SplitAttrs> {
  }
};

/*! \brief Attributes for StridedSlice operator */
struct StridedSliceAttrs : public tvm::AttrsNode<StridedSliceAttrs> {
  Array<Integer> begin;
  Array<Integer> end;
  Array<Integer> strides;

  TVM_DECLARE_ATTRS(StridedSliceAttrs, "relay.attrs.StridedSliceAttrs") {
    TVM_ATTR_FIELD(begin)
        .describe("Indices for begin of slice, begin index is inclusive");
    TVM_ATTR_FIELD(end)
        .describe("Indices for end of slice, end index is exclusive");
    TVM_ATTR_FIELD(strides).set_default(Array<Integer>({}))
        .describe("Stride values of the slice");
  }
};

}  // namespace relay
}  // namespace tvm
#endif  // TVM_RELAY_ATTRS_TRANSFORM_H_
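For readers skimming the attrs definition above: `begin` is inclusive and `end` is exclusive, matching NumPy basic slicing. A minimal NumPy sketch of the intended semantics (illustrative only, not part of the commit):

```python
import numpy as np

x = np.arange(12, dtype="float32").reshape(3, 4)

# begin inclusive, end exclusive, one stride per axis;
# equivalent to x[0:2:1, 1:4:1]
y = x[tuple(slice(b, e, s) for b, e, s in [(0, 2, 1), (1, 4, 1)])]
assert y.shape == (2, 3)

# a negative stride reverses that axis, like x[::-1]
z = x[tuple([slice(None, None, -1)])]
assert (z == x[::-1]).all()
```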
@@ -980,23 +980,25 @@ Examples::
                    const Array<Tensor>& inputs,
                    const Array<Tensor>& out_info) {
    const StridedSliceParam& param = nnvm::get<StridedSliceParam>(attrs.parsed);
-   Array<Expr> begin;
-   Array<Expr> end;
-   Array<Expr> stride;
+   Array<Integer> begin;
+   Array<Integer> end;
+   Array<Integer> stride;
    for (int64_t i : param.begin) {
-     begin.push_back(tvm::make_const(tvm::Int(32), i));
+     begin.push_back(static_cast<int>(i));
    }
    for (int64_t i : param.end) {
-     end.push_back(tvm::make_const(tvm::Int(32), i));
+     end.push_back(static_cast<int>(i));
    }
    for (int64_t i : param.stride) {
-     stride.push_back(tvm::make_const(tvm::Int(32), i));
+     stride.push_back(static_cast<int>(i));
    }
-   return Array<Tensor>{ topi::strided_slice(inputs[0], begin, end, stride) };
+   return Array<Tensor>{
+     topi::strided_slice(inputs[0], begin, end, stride)
+   };
  })
.set_support_level(1);
@@ -1210,6 +1212,15 @@ inline bool SliceLikeShape(const nnvm::NodeAttrs& attrs,
  return true;
}

// Adapter function to make int array.
Array<Integer> GetIntArray(Array<Expr> arr) {
  for (size_t i = 0; i < arr.size(); ++i) {
    CHECK(!arr[i].defined() || arr[i].as<IntImm>())
        << "Expect an int array";
  }
  return Array<Integer>(arr.node_);
}

NNVM_REGISTER_OP(slice_like)
.describe(R"code(Slice the first input respect to the second input.
)code" NNVM_ADD_FILELINE)

@@ -1261,7 +1272,10 @@ NNVM_REGISTER_OP(slice_like)
      }
    }
    return Array<Tensor>{
-     topi::strided_slice(inputs[0], begin_idx, end_idx, strides)
+     topi::strided_slice(inputs[0],
+                         GetIntArray(begin_idx),
+                         GetIntArray(end_idx),
+                         GetIntArray(strides))
    };
  })
.set_attr<FListInputNames>("FListInputNames", [](const NodeAttrs& attrs) {
...
@@ -56,6 +56,8 @@ def convert_to_node(value):
        return _api_internal._Map(*vlist)
    elif isinstance(value, NodeGeneric):
        return value.asnode()
    elif value is None:
        return None
    else:
        raise ValueError("don't know how to convert type %s to node" % type(value))
...
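The `None` passthrough added above is what lets Python callers leave individual `begin`/`end` entries unspecified: a `None` inside a converted array reaches the C++ side as an undefined node, which the shape logic treats as "use the default bound". A hedged illustration of the resulting call pattern (mirrors the Relay test later in this diff; not additional commit code):

```python
from tvm import relay

x = relay.var("x", relay.TensorType((10, 4), "float32"))
# None entries mean "default bound for this axis"
z = relay.strided_slice(x, begin=[None, 1], end=[None, 3], strides=None)
```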
@@ -13,6 +13,7 @@ from . import vision
# operator registry
from . import _tensor
from . import _transform
from ..expr import Expr
from ..base import register_relay_node
...
#pylint: disable=invalid-name, unused-argument
"""Backend compiler related feature registration"""
from __future__ import absolute_import
from . import op as _reg
from .op import schedule_injective
# strided_slice
_reg.register_schedule("strided_slice", schedule_injective)
@@ -334,3 +334,30 @@ def split(data, indices_or_sections, axis=0):
    else:
        ret_size = len(indices_or_sections) + 1
    return TupleWrapper(_make.split(data, indices_or_sections, axis), ret_size)

def strided_slice(data, begin, end, strides=None):
    """Strided slice of an array.

    Parameters
    ----------
    data : relay.Expr
        The source array to be sliced.

    begin : list of int
        The indices to begin with in the slicing.

    end : list of int
        Indices indicating the end of the slice.

    strides : list of int, optional
        Specifies the stride values; a stride can be negative, in which
        case the input tensor will be reversed in that particular axis.

    Returns
    -------
    ret : relay.Expr
        The computed result.
    """
    strides = strides or []
    return _make.strided_slice(data, list(begin), list(end), list(strides))
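A quick usage sketch for the new Python API (assumes a build that includes this patch; the shape follows the inference rule implemented below):

```python
from tvm import relay

x = relay.var("x", relay.TensorType((3, 4), "float32"))
z = relay.strided_slice(x, begin=[0, 1], end=[2, 4], strides=[1, 1])
func = relay.ir_pass.infer_type(relay.Function([x], z))
print(func.body.checked_type)  # Tensor[(2, 3), float32]
```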
@@ -47,7 +47,11 @@ TVM_REGISTER_API("_Array")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    std::vector<NodePtr<Node> > data;
    for (int i = 0; i < args.size(); ++i) {
-     data.push_back(args[i].node_sptr());
+     if (args[i].type_code() != kNull) {
+       data.push_back(args[i].node_sptr());
+     } else {
+       data.push_back(NodePtr<Node>(nullptr));
+     }
    }
    auto node = make_node<ArrayNode>();
    node->data = std::move(data);
...
@@ -403,7 +403,11 @@ class TextPrinter :
  * \param os The output stream.
  */
  void PrintAttr(const NodeRef& value, std::ostream& os) {  // NOLINT(*)
-   this->VisitAttr(value, os);
+   if (value.defined()) {
+     this->VisitAttr(value, os);
+   } else {
+     os << "None";
+   }
  }
  //------------------------------------
  // Overload of Attr printing functions
...
@@ -7,6 +7,7 @@
#include <tvm/relay/attrs/transform.h>
#include <tvm/ir_operator.h>
#include <tvm/ir.h>
#include <topi/transform.h>
#include <vector>
#include "../op_common.h"
@@ -890,6 +891,173 @@ RELAY_REGISTER_OP("broadcast_to_like")
.set_support_level(10)
.add_type_rel("BroadCastToLike", BroadCastToLikeRel);

// strided_slice
TVM_REGISTER_NODE_TYPE(StridedSliceAttrs);

bool StridedSliceRel(const Array<Type>& types,
                     int num_inputs,
                     const Attrs& attrs,
                     const TypeReporter& reporter) {
  CHECK_EQ(types.size(), 2);
  const auto* data = types[0].as<TensorTypeNode>();
  if (data == nullptr) return false;

  const StridedSliceAttrs* param = attrs.as<StridedSliceAttrs>();
  CHECK(param != nullptr);

  auto dshape = data->shape;
  auto num_axis = dshape.size();

  std::vector<int64_t> stride_vec;
  for (Integer i : param->strides) {
    CHECK(i.defined());
    stride_vec.push_back(i->value);
  }
  for (size_t i = stride_vec.size(); i < num_axis; ++i) {
    stride_vec.push_back(1);
  }
  const int64_t max_range = std::numeric_limits<int64_t>::max();

  std::vector<int64_t> begin_vec;
  for (size_t i = 0; i < param->begin.size(); ++i) {
    if (!param->begin[i].defined()) {
      // value=None
      begin_vec.push_back(stride_vec[i] > 0 ? 0 : max_range);
    } else {
      begin_vec.push_back(param->begin[i]->value);
    }
  }
  for (size_t i = begin_vec.size(); i < num_axis; ++i) {
    begin_vec.push_back(stride_vec[i] > 0 ? 0 : max_range);
  }

  std::vector<int64_t> end_vec;
  for (size_t i = 0; i < param->end.size(); ++i) {
    // allow end to be None
    if (!param->end[i].defined()) {
      end_vec.push_back(stride_vec[i] < 0 ? 0 : max_range);
    } else {
      end_vec.push_back(param->end[i]->value);
    }
  }
  for (size_t i = end_vec.size(); i < num_axis; ++i) {
    end_vec.push_back(stride_vec[i] < 0 ? 0 : max_range);
  }

  std::vector<IndexExpr> oshape(dshape.size());
  for (size_t i = 0; i < num_axis; ++i) {
    int64_t stride_v = stride_vec[i];
    int64_t begin_v = begin_vec[i];
    int64_t end_v = end_vec[i];

    if ((stride_v == 1 && begin_v == 0 && end_v == max_range) ||
        (stride_v == -1 && begin_v == max_range && end_v == 0)) {
      // Quick path, do not slice this dimension.
      oshape[i] = dshape[i];
      continue;
    }
    // Normal path, require the shape to be concrete integer.
    // Require concrete integer as symbolic inference of min/max
    // can get complicated and is not very helpful.
    const int64_t* p_dim_size = as_const_int(dshape[i]);
    CHECK(p_dim_size)
        << "strided_slice requires sliced dimension to be concrete int";
    int64_t dim_size = p_dim_size[0];
    begin_v = (begin_v < 0) ? dim_size + begin_v : begin_v;
    end_v = (end_v < 0) ? dim_size + end_v : end_v;

    int64_t slice_range, step;
    if (stride_v < 0) {
      if (end_v < -1) end_v = -1;
      CHECK_LT(end_v, begin_v)
          << "strided_slice get empty slice at axis " << i;
      begin_v = std::min(dim_size - 1, begin_v);
      slice_range = begin_v - end_v;
      step = -stride_v;
    } else {
      if (begin_v < 0) begin_v = 0;
      CHECK_GE(stride_v, 0);
      CHECK_LT(begin_v, end_v)
          << "strided_slice get empty slice at axis " << i;
      end_v = std::min(dim_size, end_v);
      slice_range = end_v - begin_v;
      step = stride_v;
    }
    oshape[i] = make_const(dshape[i].type(), (slice_range + step - 1) / step);
  }
  reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype));
  return true;
}
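The final `make_const` line encodes the whole shape rule: with a positive stride the extent is `ceil((end - begin) / stride)`; with a negative stride, `ceil((begin - end) / -stride)` after clamping. A small Python restatement of the same arithmetic (illustrative; `slice_extent` is a hypothetical helper name):

```python
def slice_extent(dim_size, begin, end, stride):
    # normalize negative indices, as in StridedSliceRel
    begin = dim_size + begin if begin < 0 else begin
    end = dim_size + end if end < 0 else end
    if stride < 0:
        end = max(end, -1)
        begin = min(dim_size - 1, begin)
        slice_range, step = begin - end, -stride
    else:
        begin = max(begin, 0)
        end = min(dim_size, end)
        slice_range, step = end - begin, stride
    return (slice_range + step - 1) // step

assert slice_extent(4, 1, 4, 2) == 2     # indices 1 and 3
assert slice_extent(4, -1, -5, -1) == 4  # whole axis, reversed
```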
// Positional relay function to create StridedSlice operator used by frontend FFI.
Expr MakeStridedSlice(Expr data,
                      Array<Integer> begin,
                      Array<Integer> end,
                      Array<Integer> strides) {
  auto attrs = make_node<StridedSliceAttrs>();
  attrs->begin = std::move(begin);
  attrs->end = std::move(end);
  attrs->strides = std::move(strides);
  static const Op& op = Op::Get("strided_slice");
  return CallNode::make(op, {data}, Attrs(attrs), {});
}

Array<Tensor> StridedSliceCompute(const Attrs& attrs,
                                  const Array<Tensor>& inputs,
                                  const Type& out_type,
                                  const Target& target) {
  const StridedSliceAttrs* param = attrs.as<StridedSliceAttrs>();
  CHECK(param != nullptr);
  return Array<Tensor>{
    topi::strided_slice(inputs[0], param->begin, param->end, param->strides)
  };
}

TVM_REGISTER_API("relay.op._make.strided_slice")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
    runtime::detail::unpack_call<Expr, 4>(MakeStridedSlice, args, rv);
});

RELAY_REGISTER_OP("strided_slice")
.describe(R"code(Strided slice of an array.

Examples::

  x = [[ 1.,  4.,  7., 10.],
       [ 2.,  5.,  8., 11.],
       [ 3.,  6.,  9., 12.]]

  strided_slice(x, begin=[0, 1], end=[2, 4], stride=[1, 1]) = [[ 4.,  7., 10.],
                                                               [ 5.,  8., 11.]]

  x = [[[ 1.,  2.],
        [ 3.,  4.]],

       [[ 5.,  6.],
        [ 7.,  8.]]]

  strided_slice(x, begin=[0, 0], end=[2, 2]) = [[[ 1.,  2.],
                                                 [ 3.,  4.]],

                                                [[ 5.,  6.],
                                                 [ 7.,  8.]]]
)code" TVM_ADD_FILELINE)
.set_num_inputs(1)
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(4)
.set_attrs_type_key("relay.attrs.StridedSliceAttrs")
.add_type_rel("StridedSlice", StridedSliceRel)
.set_attr<FTVMCompute>("FTVMCompute", StridedSliceCompute)
.set_attr<TOpPattern>("TOpPattern", kInjective);

// Split
TVM_REGISTER_NODE_TYPE(SplitAttrs);
...
@@ -2,7 +2,7 @@ import tvm
import numpy as np
from tvm import relay
from tvm.relay.testing import ctx_list
import topi.testing

def test_binary_op():
    def check_binary_op(opfunc, ref):

@@ -142,7 +142,43 @@ def test_reduce_functions():
    verify_reduce(func, (128, 24, 128), (0, 1), True, False, (1, 1, 128))
    verify_reduce(func, (128, 24, 128), (0, 2), True, False, (1, 24, 1))
def test_strided_slice():
    def verify(dshape, begin, end, strides, output, test_ref=True):
        x = relay.var("x", relay.TensorType(dshape, "float32"))
        z = relay.strided_slice(x, begin=begin, end=end, strides=strides)
        func = relay.Function([x], z)
        func = relay.ir_pass.infer_type(func)
        text = func.astext()
        assert "begin=" in text
        assert "end=" in text
        if output:
            assert func.body.checked_type == relay.ty.TensorType(output, "float32")
        if not test_ref:
            return
        x_data = np.random.uniform(size=dshape).astype("float32")
        ref_res = topi.testing.strided_slice_python(
            x_data, begin, end, strides)
        for target, ctx in ctx_list():
            intrp = relay.create_executor("graph", ctx=ctx, target=target)
            op_res = intrp.evaluate(func)(x_data)
            tvm.testing.assert_allclose(op_res.asnumpy(), ref_res)

    d1, d2, d3, d4 = tvm.var("d1"), tvm.var("d2"), tvm.var("d3"), tvm.var("d4")
    verify((d1, d2, 3), [None, None, 1], [None, None, 2], None, (d1, d2, 1), False)
    verify((3, 4, 3), [0, 0, 0], [4, -5, 4], [1, -1, 2], (3, 1, 2))
    verify((3, 4, 3), [1, 1, 0], [4, 4, 3], [2, 1, 1], (1, 3, 3))
    verify((3, 4, 3), [1, -1, 0], [4, -5, 3], [2, -1, 1], (1, 4, 3))
    verify((3, 4, 3), [1, 0, 0], [2, 2, 3], [1, 1, 2], (1, 2, 2))
    verify((3, 4, 3), [1, -1, 0], [2, -3, 3], [1, -1, 1], (1, 2, 3))
    verify((3, 4, 3), [1, 1, 0], [4, 4, 3], None, (2, 3, 3))
    verify((3, 4, 3), [1, 1, 0], [4, 1000, 3], None, (2, 3, 3))
    verify((3, 4, 3), [1, 1, 0], [4, 4], None, (2, 3, 3))
    verify((3, 4, 3), [1, 1], [4, 4, 3], None, (2, 3, 3))
if __name__ == "__main__":
    test_strided_slice()
    test_binary_op()
    test_cmp_type()
    test_binary_int_broadcast()
...
@@ -10,6 +10,7 @@
#include <vector>
#include <iterator>
#include <algorithm>
#include <limits>
#include "topi/tags.h"
#include "topi/detail/ravel_unravel.h"

@@ -403,31 +404,51 @@ inline Array<Tensor> split(const Tensor& x,
 * \return A Tensor whose op member is the split operation
 */
inline Tensor strided_slice(const Tensor& x,
-                           const Array<Expr>& begin,
-                           const Array<Expr>& end,
-                           const Array<Expr>& strides,
+                           const Array<Integer>& begin,
+                           const Array<Integer>& end,
+                           const Array<Integer>& strides,
                            std::string name = "tensor",
                            std::string tag = kInjective) {
  size_t src_tensor_dim = static_cast<size_t>(x->shape.size());
-  std::vector<int64_t> begin_vec = GetConstInt64Values(begin, "begin");
-  std::vector<int64_t> end_vec = GetConstInt64Values(end, "end");
-  std::vector<int64_t> stride_vec = GetConstInt64Values(strides, "strides");
-  // in case user has not provided begin indices for all the axes,
-  // then inflate it with default value = 0
-  for (size_t i = begin_vec.size(); i < src_tensor_dim; ++i) {
-    begin_vec.push_back(0);
-  }
-  // in case user has not provided end indices for all the axes,
-  // then inflate it with default value = input_tensor.shape[axis]
-  for (size_t i = end_vec.size(); i < src_tensor_dim; ++i) {
-    end_vec.push_back(GetConstInt(x->shape[i]));
-  }
-  // in case user has not provided stride values,
-  // then inflate it with default value = 1
+  // Setup the ranges.
+  // NOTE: this code duplicates the shape inference logic in relay.op;
+  // consider refactoring in the future.
+  std::vector<int64_t> stride_vec;
+  for (Integer i : strides) {
+    CHECK(i.defined());
+    stride_vec.push_back(i->value);
+  }
  for (size_t i = stride_vec.size(); i < src_tensor_dim; ++i) {
    stride_vec.push_back(1);
  }
+  const int64_t max_range = std::numeric_limits<int64_t>::max();
+  std::vector<int64_t> begin_vec;
+  for (size_t i = 0; i < begin.size(); ++i) {
+    if (!begin[i].defined()) {
+      // value=None
+      begin_vec.push_back(stride_vec[i] > 0 ? 0 : max_range);
+    } else {
+      begin_vec.push_back(begin[i]->value);
+    }
+  }
+  for (size_t i = begin_vec.size(); i < src_tensor_dim; ++i) {
+    begin_vec.push_back(stride_vec[i] > 0 ? 0 : max_range);
+  }
+  std::vector<int64_t> end_vec;
+  for (size_t i = 0; i < end.size(); ++i) {
+    // allow end to be None
+    if (!end[i].defined()) {
+      end_vec.push_back(stride_vec[i] < 0 ? 0 : max_range);
+    } else {
+      end_vec.push_back(end[i]->value);
+    }
+  }
+  for (size_t i = end_vec.size(); i < src_tensor_dim; ++i) {
+    end_vec.push_back(stride_vec[i] < 0 ? 0 : max_range);
+  }
+  // Compute
  Array<Expr> out_shape;
  Array<Expr> begin_expr;
  Array<Expr> strides_expr;
...
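Both the Relay type relation and this TOPI kernel default missing or `None` bounds the same way: `begin` falls back to `0` for a positive stride and `INT64_MAX` for a negative one, and `end` does the opposite. A compact Python restatement of that defaulting rule (illustrative only; `default_bounds` is a hypothetical name):

```python
INT64_MAX = 2**63 - 1

def default_bounds(begin, end, strides, ndim):
    strides = list(strides) + [1] * (ndim - len(strides))
    def fill(vals, pos_default, neg_default):
        padded = list(vals) + [None] * (ndim - len(vals))
        return [(pos_default if s > 0 else neg_default) if v is None else v
                for v, s in zip(padded, strides)]
    return fill(begin, 0, INT64_MAX), fill(end, INT64_MAX, 0), strides

b, e, s = default_bounds([None, 1], [None], [1, -1], ndim=2)
assert (b, e, s) == ([0, 1], [INT64_MAX, 0], [1, -1])
```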
@@ -19,3 +19,4 @@ from .shortcut_python import shortcut_python
from .lrn_python import lrn_python
from .l2_normalize_python import l2_normalize_python
from .gather_nd_python import gather_nd_python
from .strided_slice_python import strided_slice_python
"""gather_nd in python"""
def strided_slice_python(data, begin, end, strides):
"""Python version of strided slice operator.
Parameters
----------
data : numpy.ndarray
Input data
begin : list
Begining of the slices.
end : list
End of the slices.
strides : list
The stride of each slice.
Returns
-------
result : numpy.ndarray
The sliced result.
"""
strides = [] if strides is None else strides
slices = []
for i in range(len(data.shape)):
slices.append(slice(
begin[i] if i < len(begin) else None,
end[i] if i < len(end) else None,
strides[i] if i < len(strides) else None))
return data[tuple(slices)]
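A quick sanity check of the reference helper (illustrative; matches the first example in the operator's describe block):

```python
import numpy as np

x = np.array([[1., 4., 7., 10.],
              [2., 5., 8., 11.],
              [3., 6., 9., 12.]])
out = strided_slice_python(x, [0, 1], [2, 4], [1, 1])
assert (out == x[0:2, 1:4]).all()  # [[4, 7, 10], [5, 8, 11]]
```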
@@ -249,13 +249,11 @@ def verify_take(src_shape, indices_src, axis=None):
    for device in ["llvm", "opencl", "sdaccel", "aocl_sw_emu"]:
        check_device(device)

-def verify_strided_slice(in_shape, begin, end, stride=None):
-    stride = stride if stride else [1, 1, 1]
+def verify_strided_slice(in_shape, begin, end, strides=None):
    A = tvm.placeholder(shape=in_shape, name="A")
-    B = topi.strided_slice(A, begin, end, stride) + 1
-    def test_forward(x, begin, end, stride):
-        return x[begin[0]:end[0]:stride[0],
-                 begin[1]:end[1]:stride[1], begin[2]:end[2]:stride[2]] + 1
+    strides = [1, 1, 1] if strides is None else strides
+    B = topi.strided_slice(A, begin, end, strides) + 1
    def check_device(device):
        ctx = tvm.context(device, 0)
        if not ctx.exist:

@@ -267,7 +265,8 @@ def verify_strided_slice(in_shape, begin, end, strides=None):
        foo = tvm.build(s, [A, B], device, name="stride_slice")
        x_np = np.random.uniform(size=in_shape).astype(A.dtype)
-        out_npy = test_forward(x_np, begin, end, stride)
+        out_npy = topi.testing.strided_slice_python(
+            x_np, begin, end, strides) + 1
        data_nd = tvm.nd.array(x_np, ctx)
        out_nd = tvm.nd.empty(out_npy.shape, ctx=ctx, dtype=A.dtype)
        foo(data_nd, out_nd)

@@ -298,7 +297,7 @@ def verify_gather_nd(src_shape, indices_src, indices_dtype):
        shape_size = shape_size * src_shape[i]
    data_npy = np.arange(shape_size, dtype=src_dtype).reshape((src_shape))
    out_npys = topi.testing.gather_nd_python(data_npy, indices_src)
    data_nd = tvm.nd.array(data_npy, ctx)
    indices_nd = tvm.nd.array(indices_src, ctx)
    out_nd = tvm.nd.empty(out_npys.shape, ctx=ctx, dtype=src_dtype)

@@ -412,6 +411,7 @@ def test_gather_nd():
                     indices_dtype)

if __name__ == "__main__":
+    test_strided_slice()
    test_concatenate()
    test_tranpose()
    test_expand_dims()

@@ -421,5 +421,4 @@ if __name__ == "__main__":
    test_flip()
    test_expand_like()
    test_take()
-    test_strided_slice()
    test_gather_nd()