Commit 40ac2064 by Dayananda V Committed by Tianqi Chen

add take frontend (#1307)

parent 4503f77b
...@@ -48,6 +48,16 @@ struct SplitParam : public dmlc::Parameter<SplitParam> { ...@@ -48,6 +48,16 @@ struct SplitParam : public dmlc::Parameter<SplitParam> {
} }
}; };
// Attributes for the "take" operator.
struct TakeParam : public dmlc::Parameter<TakeParam> {
  // Axis along which to select values. Unset (None) means the input is
  // treated as flattened, following numpy.take semantics.
  dmlc::optional<int> axis;
  DMLC_DECLARE_PARAMETER(TakeParam) {
    DMLC_DECLARE_FIELD(axis).set_default(dmlc::optional<int>())
    .describe("the axis over which to select values.");
  }
};
struct StridedSliceParam : public dmlc::Parameter<StridedSliceParam> { struct StridedSliceParam : public dmlc::Parameter<StridedSliceParam> {
// numpy convention, only support indices, not support list. // numpy convention, only support indices, not support list.
Tuple<int64_t> begin; Tuple<int64_t> begin;
......
...@@ -61,6 +61,10 @@ reg.register_schedule("concatenate", _fschedule_injective) ...@@ -61,6 +61,10 @@ reg.register_schedule("concatenate", _fschedule_injective)
reg.register_pattern("split", OpPattern.INJECTIVE) reg.register_pattern("split", OpPattern.INJECTIVE)
reg.register_schedule("split", _fschedule_injective) reg.register_schedule("split", _fschedule_injective)
# take
# "take" is a pure gather (each output element reads one input element),
# so the generic injective schedule applies.
reg.register_pattern("take", OpPattern.INJECTIVE)
reg.register_schedule("take", _fschedule_injective)
# strided_slice # strided_slice
reg.register_pattern("strided_slice", OpPattern.INJECTIVE) reg.register_pattern("strided_slice", OpPattern.INJECTIVE)
reg.register_schedule("strided_slice", _fschedule_injective) reg.register_schedule("strided_slice", _fschedule_injective)
......
...@@ -1001,6 +1001,126 @@ Examples:: ...@@ -1001,6 +1001,126 @@ Examples::
return Array<Tensor>{ topi::flip(inputs[0], param.axis) }; return Array<Tensor>{ topi::flip(inputs[0], param.axis) };
}); });
// take
// Register TakeParam so its fields take part in attribute parsing/printing.
DMLC_REGISTER_PARAMETER(TakeParam);
inline bool TakeInferShape(const NodeAttrs& attrs,
std::vector<TShape>* in_shape,
std::vector<TShape>* out_shape) {
CHECK_EQ(in_shape->size(), 2U);
CHECK_EQ(out_shape->size(), 1U);
const TShape& dshape = (*in_shape)[0];
const TShape& indicesshape = (*in_shape)[1];
if (dshape.ndim() == 0) return false;
if (indicesshape.ndim() == 0) return false;
const TakeParam& param = nnvm::get<TakeParam>(attrs.parsed);
TShape oshape((!param.axis ? 0: dshape.ndim() - 1) + indicesshape.ndim());
if (!param.axis) {
for (size_t j = 0; j < indicesshape.ndim(); ++j) {
oshape[j] = indicesshape[j];
}
} else {
int axis = param.axis.value();
if (axis < 0) {
axis += dshape.ndim();
}
CHECK_LT(axis, dshape.ndim());
size_t posi = 0;
for (size_t i = 0; i < dshape.ndim(); ++i) {
if (static_cast<int>(i) == axis) {
for (size_t j = 0; j < indicesshape.ndim(); ++j) {
oshape[posi++] = indicesshape[j];
}
} else {
oshape[posi++] = dshape[i];
}
}
}
NNVM_ASSIGN_INPUT_SHAPE(attrs, *in_shape, 0, dshape);
NNVM_ASSIGN_INPUT_SHAPE(attrs, *in_shape, 1, indicesshape);
NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_shape, 0, oshape);
return dshape.Size() != 0;
}
// Infer dtypes for "take": the output inherits the data dtype and the
// indices input is pinned to int32.
inline bool TakeInferType(const NodeAttrs& attrs,
                          std::vector<int>* in_attrs,
                          std::vector<int>* out_attrs) {
  CHECK_EQ(in_attrs->size(), 2U);
  CHECK_EQ(out_attrs->size(), 1U);
  // NOTE(review): this hard-fails unless the indices dtype is already
  // kInt32 when this pass runs — presumably the frontend tags indices as
  // int32 before type inference; confirm against callers.
  CHECK_EQ((*in_attrs)[1], kInt32);
  NNVM_ASSIGN_INPUT_TYPE(attrs, *in_attrs, 0, (*in_attrs)[0]);
  NNVM_ASSIGN_INPUT_TYPE(attrs, *in_attrs, 1, static_cast<int>(kInt32));
  // Output dtype mirrors the data input.
  NNVM_ASSIGN_OUTPUT_TYPE(attrs, *out_attrs, 0, (*in_attrs)[0]);
  return true;
}
inline bool TakeCorrectLayout(const NodeAttrs& attrs,
std::vector<Layout> *ilayouts,
const std::vector<Layout> *last_ilayouts,
std::vector<Layout> *olayouts) {
CHECK_EQ(ilayouts->size(), last_ilayouts->size());
CHECK_EQ(olayouts->size(), 1U);
for (size_t i = 0; i < ilayouts->size(); ++i) {
const Layout& input = last_ilayouts->at(i).defined() ?
last_ilayouts->at(i) : ilayouts->at(i);
NNVM_ASSIGN_LAYOUT(*ilayouts, i, input);
}
return true;
}
// Register the "take" operator: description, arguments, attribute
// parsing, inference hooks, and the TOPI compute rule.
NNVM_REGISTER_OP(take)
.describe(R"code(Take elements from an array along an axis.
When axis is not None, this function does the same thing as 'fancy' indexing
(indexing arrays using arrays); however, it can be easier to use if you need
elements along a given axis.
**Note** that when axis is none the flattened input array is used.
Examples::
a = [[ 1, 2],
[ 3, 4]]
indices = [3, 0, 2]
take(a, indices) = [ 4, 1, 3]
a = [[ 1., 2.],
[ 3., 4.]]
indices = [1, 0]
take(a, indices, axis=1) = [[ 2., 1.],
[ 4., 3.]]
)code" NNVM_ADD_FILELINE)
.add_argument("data", "Tensor", "Array to be indexed")
.add_argument("indices", "Tensor", "The indices of the values to extract")
.add_arguments(TakeParam::__FIELDS__())
.set_attr_parser(ParamParser<TakeParam>)
.set_attr<FInferShape>("FInferShape", TakeInferShape)
.set_attr<FInferType>("FInferType", TakeInferType)
.set_attr<FCorrectLayout>("FCorrectLayout", TakeCorrectLayout)
.set_num_inputs(2)
.set_num_outputs(1)
.set_support_level(1)
// Lower to topi::take; the flattened overload is used when no axis is set.
.set_attr<FTVMCompute>(
  "FTVMCompute", [](const NodeAttrs& attrs,
                    const Array<Tensor>& inputs,
                    const Array<Tensor>& out_info) {
    const TakeParam& param = nnvm::get<TakeParam>(attrs.parsed);
    if (!param.axis) {
      return Array<Tensor>{
          topi::take(inputs[0], inputs[1]) };
    } else {
      return Array<Tensor>{
          topi::take(inputs[0], inputs[1], param.axis.value()) };
    }
});
// SliceLike // SliceLike
DMLC_REGISTER_PARAMETER(SliceLikeParam); DMLC_REGISTER_PARAMETER(SliceLikeParam);
......
...@@ -365,6 +365,40 @@ def test_strided_slice(): ...@@ -365,6 +365,40 @@ def test_strided_slice():
verify_strided_slice((3, 4, 3), [1, 1, 0], [4, 4]) verify_strided_slice((3, 4, 3), [1, 1, 0], [4, 4])
verify_strided_slice((3, 4, 3), [1, 1], [4, 4, 3]) verify_strided_slice((3, 4, 3), [1, 1], [4, 4, 3])
def verify_take(src_shape, indices_src, axis=None):
    """Compile sym.take for every available target and compare to np.take.

    Parameters
    ----------
    src_shape : tuple of int
        Shape of the data tensor (filled with arange values).
    indices_src : array-like of int
        Indices to gather; converted to int32 before compilation.
    axis : int, optional
        Axis to take along; None flattens the input (numpy semantics).
    """
    src_dtype = "float32"
    indices_dtype = "int32"
    indices_src = np.array(indices_src, dtype=indices_dtype)
    a = sym.Variable("a")
    indices = sym.Variable("indices")
    y = sym.take(a, indices, axis=axis)
    for target, ctx in ctx_list():
        # set input
        shape_dict = {"a": src_shape, "indices": indices_src.shape}
        type_dict = {"a": src_dtype, "indices": indices_dtype}
        graph, lib, _ = nnvm.compiler.build(y, target, shape=shape_dict, dtype=type_dict)
        m = graph_runtime.create(graph, lib, ctx)
        # np.prod replaces the original hand-rolled element-count loop
        # (np.prod(()) == 1, so a 0-d shape still works).
        shape_size = int(np.prod(src_shape))
        a_src = np.arange(shape_size, dtype=src_dtype).reshape(src_shape)
        out_np = np.take(a_src, indices_src, axis=axis)
        m.run(a=a_src, indices=indices_src)
        out = m.get_output(0, tvm.nd.empty(out_np.shape, dtype=src_dtype))
        np.testing.assert_allclose(out.asnumpy(), out_np, atol=1e-5, rtol=1e-5)
def test_take():
    # Flattened (axis=None) gathers: scalar, 1-D, and 2-D index arrays.
    verify_take((4,), [1])
    verify_take((4,), [[0,1,2,3]])
    verify_take((3,3,3), [[11,25]])  # flat indices into the 27-element array
    verify_take((4,), [[0,1],[2,3]])
    # Axis-wise gathers, including multi-dimensional indices and a
    # negative axis (numpy convention).
    verify_take((4,), [1], 0)
    verify_take((2,2), [[[1,0],[0,1]]], 0)
    verify_take((2,2), [[[1,0],[0,1]]], 1)
    verify_take((4,3,5,6), [[2,1,0,0]], -2)
def verify_squeeze(dshape, axis): def verify_squeeze(dshape, axis):
x = sym.Variable("x") x = sym.Variable("x")
if axis: if axis:
...@@ -481,6 +515,7 @@ if __name__ == "__main__": ...@@ -481,6 +515,7 @@ if __name__ == "__main__":
test_softmax() test_softmax()
test_squeeze() test_squeeze()
test_pad() test_pad()
test_take()
test_lrn() test_lrn()
test_l2_normalize() test_l2_normalize()
test_strided_slice() test_strided_slice()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment