Commit 4fb58115 by Pariksheet Pinjari Committed by Tianqi Chen

Strided_slice added in NNVM (#1318)

parent 2aa1f054
......@@ -48,6 +48,22 @@ struct SplitParam : public dmlc::Parameter<SplitParam> {
}
};
struct StridedSliceParam : public dmlc::Parameter<StridedSliceParam> {
// numpy convention, only support indices, not support list.
Tuple<int64_t> begin;
Tuple<int64_t> end;
Tuple<int64_t> stride;
DMLC_DECLARE_PARAMETER(StridedSliceParam) {
DMLC_DECLARE_FIELD(begin)
.describe("Indices for begin of slice");
DMLC_DECLARE_FIELD(end)
.describe("Indices for end of the slice");
DMLC_DECLARE_FIELD(stride).set_default(Tuple<int64_t>())
.describe("Stride values of the slice");
}
};
enum TypeFlag {
kFloat32 = 0,
kFloat64 = 1,
......
......@@ -61,6 +61,10 @@ reg.register_schedule("concatenate", _fschedule_injective)
reg.register_pattern("split", OpPattern.INJECTIVE)
reg.register_schedule("split", _fschedule_injective)
# strided_slice
reg.register_pattern("strided_slice", OpPattern.INJECTIVE)
reg.register_schedule("strided_slice", _fschedule_injective)
# slice_like
reg.register_pattern("slice_like", OpPattern.INJECTIVE)
reg.register_schedule("slice_like", _fschedule_injective)
......@@ -829,6 +829,119 @@ Examples::
};
});
// strided_slice
DMLC_REGISTER_PARAMETER(StridedSliceParam);
inline void StridedSliceParamParser(nnvm::NodeAttrs* attrs) {
StridedSliceParam param;
param.Init(attrs->dict);
attrs->parsed = std::move(param);
}
inline bool StridedSliceInferShape(const NodeAttrs& attrs,
std::vector<TShape>* in_shape,
std::vector<TShape>* out_shape) {
const StridedSliceParam& param = nnvm::get<StridedSliceParam>(attrs.parsed);
const TShape& dshape = (*in_shape)[0];
if (dshape.ndim() == 0) return false;
TShape oshape = dshape;
dim_t num_axis = dshape.ndim();
std::vector<int64_t> begin_vec;
std::copy(param.begin.begin(), param.begin.end(), std::back_inserter(begin_vec));
for (dim_t i = begin_vec.size(); i < num_axis; ++i) {
begin_vec.push_back(0);
}
std::vector<int64_t> end_vec;
std::copy(param.end.begin(), param.end.end(), std::back_inserter(end_vec));
for (dim_t i = end_vec.size(); i < num_axis; ++i) {
end_vec.push_back(dshape[i]);
}
std::vector<int64_t> stride_vec;
std::copy(param.stride.begin(), param.stride.end(), std::back_inserter(stride_vec));
for (dim_t i = stride_vec.size(); i < num_axis; ++i) {
stride_vec.push_back(1);
}
for (dim_t i = 0; i < num_axis; ++i) {
int64_t begin_range = stride_vec[i] < 0 ? -1 : 0;
int64_t end_range = stride_vec[i] < 0 ? dshape[i] - 1 : dshape[i];
int64_t begin = begin_vec[i] < 0 ? dshape[i] + begin_vec[i] : begin_vec[i];
int64_t end = end_vec[i] < 0 ? dshape[i] + end_vec[i] : end_vec[i];
begin = std::min(std::max(begin, begin_range), end_range);
end = std::min(std::max(end, begin_range), end_range);
int interval = std::abs(end - begin);
int slice_size = static_cast<int>((interval
+ std::abs(stride_vec[i]) - 1) / std::abs(stride_vec[i]));
CHECK(stride_vec[i] < 0 ? (end < begin) : (begin < end))
<< ": Input [Begin=" << begin_vec[i] << ", End=" << end_vec[i]
<< "] is invalid for axis=" << i;
oshape[i] = slice_size;
}
NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_shape, 0, oshape);
return true;
}
NNVM_REGISTER_OP(strided_slice)
.describe(R"code(Strided slice of an array.
Examples::
x = [[ 1., 4., 7., 10.],
[ 2., 5., 8., 11.],
[ 3., 6., 9., 12.]]
strided_slice(x, begin=[0, 1], end=[2, 4], stride=[1, 1]) = [[ 4., 7., 10.],
[ 5., 8., 11.]]
x = [[[ 1., 2.],
[ 3., 4.]],
[[ 5., 6.],
[ 7., 8.]]]
strided_slice(x, begin=[0, 0], end=[2, 2]) = [[[ 1., 2.],
[ 3., 4.]],
[[ 5., 6.],
[ 7., 8.]]]
)code" NNVM_ADD_FILELINE)
.add_argument("data", "Tensor", "Array to be sliced")
.add_arguments(StridedSliceParam::__FIELDS__())
.set_attr_parser(StridedSliceParamParser)
.set_attr<FInferShape>("FInferShape", StridedSliceInferShape)
.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
.set_attr<FCorrectLayout>("FCorrectLayout", ElemwiseArbitraryLayout<1, 1>)
.set_num_inputs(1)
.set_num_outputs(1)
.set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs,
const Array<Tensor>& out_info) {
const StridedSliceParam& param = nnvm::get<StridedSliceParam>(attrs.parsed);
Array<Expr> begin;
Array<Expr> end;
Array<Expr> stride;
for (int64_t i : param.begin) {
begin.push_back(tvm::make_const(tvm::Int(32), i));
}
for (int64_t i : param.end) {
end.push_back(tvm::make_const(tvm::Int(32), i));
}
for (int64_t i : param.stride) {
stride.push_back(tvm::make_const(tvm::Int(32), i));
}
return Array<Tensor>{ topi::strided_slice(inputs[0], begin, end, stride) };
})
.set_support_level(1);
// Flip
DMLC_REGISTER_PARAMETER(FlipParam);
......
......@@ -329,6 +329,41 @@ def test_split():
verify_split((5, 3), [3], axis=0)
verify_split((5, 9, 3), [3, 4], axis=1)
def verify_strided_slice(ishape, begin, end, strideinp=None):
stride = strideinp if strideinp else [1, 1, 1]
x = sym.Variable("x")
if strideinp:
y = sym.strided_slice(x, begin = begin, end = end, stride = stride) + 1
else:
y = sym.strided_slice(x, begin = begin, end = end) + 1
x_np = np.random.uniform(size=ishape).astype("float32")
for i in range(len(begin), 3):
begin.append(0)
for i in range(len(end), 3):
end.append(ishape[i])
def test_forward(x, begin, end, stride):
return x[begin[0]:end[0]:stride[0],
begin[1]:end[1]:stride[1], begin[2]:end[2]:stride[2]] + 1
for target, ctx in ctx_list():
# set input
graph, lib, _ = nnvm.compiler.build(y, target, {"x": ishape})
m = graph_runtime.create(graph, lib, ctx)
m.run(x=x_np)
res = test_forward(x_np, begin, end, stride)
out = m.get_output(0, tvm.nd.empty(res.shape))
np.testing.assert_allclose(out.asnumpy(), res, atol=1e-5, rtol=1e-5)
def test_strided_slice():
verify_strided_slice((3, 4, 3), [0, 0, 0], [4, -5, 4], [1, -1, 2])
verify_strided_slice((3, 4, 3), [1, 1, 0], [4, 4, 3], [2, 1, 1])
verify_strided_slice((3, 4, 3), [1, -1, 0], [4, -5, 3], [2, -1, 1])
verify_strided_slice((3, 4, 3), [1, 0, 0], [2, 2, 3], [1, 1, 2])
verify_strided_slice((3, 4, 3), [1, -1, 0], [2, -3, 3], [1, -1, 1])
verify_strided_slice((3, 4, 3), [1, 1, 0], [4, 4, 3])
verify_strided_slice((3, 4, 3), [1, 1, 0], [4, 1000, 3])
verify_strided_slice((3, 4, 3), [1, 1, 0], [4, 4])
verify_strided_slice((3, 4, 3), [1, 1], [4, 4, 3])
def verify_squeeze(dshape, axis):
x = sym.Variable("x")
......@@ -448,3 +483,4 @@ if __name__ == "__main__":
test_pad()
test_lrn()
test_l2_normalize()
test_strided_slice()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment