Commit 28499304 by Siju Committed by Tianqi Chen

[RELAY]Slice_like support (#2014)

parent 401ffe13
...@@ -143,6 +143,7 @@ This level support backpropagation of broadcast operators. It is temporary. ...@@ -143,6 +143,7 @@ This level support backpropagation of broadcast operators. It is temporary.
tvm.relay.broadcast_to_like tvm.relay.broadcast_to_like
tvm.relay.collapse_sum_like tvm.relay.collapse_sum_like
tvm.relay.slice_like
Level 1 Definitions Level 1 Definitions
...@@ -231,7 +232,6 @@ Level 4 Definitions ...@@ -231,7 +232,6 @@ Level 4 Definitions
.. autofunction:: tvm.relay.strided_slice .. autofunction:: tvm.relay.strided_slice
Level 5 Definitions Level 5 Definitions
------------------- -------------------
.. autofunction:: tvm.relay.image.resize .. autofunction:: tvm.relay.image.resize
...@@ -241,3 +241,4 @@ Level 10 Definitions ...@@ -241,3 +241,4 @@ Level 10 Definitions
-------------------- --------------------
.. autofunction:: tvm.relay.broadcast_to_like .. autofunction:: tvm.relay.broadcast_to_like
.. autofunction:: tvm.relay.collapse_sum_like .. autofunction:: tvm.relay.collapse_sum_like
.. autofunction:: tvm.relay.slice_like
...@@ -138,6 +138,19 @@ struct StridedSliceAttrs : public tvm::AttrsNode<StridedSliceAttrs> { ...@@ -138,6 +138,19 @@ struct StridedSliceAttrs : public tvm::AttrsNode<StridedSliceAttrs> {
.describe("Stride values of the slice"); .describe("Stride values of the slice");
} }
}; };
struct SliceLikeAttrs : public tvm::AttrsNode<SliceLikeAttrs> {
Array<Integer> axes;
TVM_DECLARE_ATTRS(SliceLikeAttrs, "relay.attrs.SliceLikeAttrs") {
TVM_ATTR_FIELD(axes)
.describe("List of axes on which input data will be sliced according to the "
"corresponding size of the second input. By default will slice "
"on all axes. Negative axes mean counting in reverse.");
}
};
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
#endif // TVM_RELAY_ATTRS_TRANSFORM_H_ #endif // TVM_RELAY_ATTRS_TRANSFORM_H_
...@@ -2,7 +2,11 @@ ...@@ -2,7 +2,11 @@
"""Backend compiler related feature registration""" """Backend compiler related feature registration"""
from __future__ import absolute_import from __future__ import absolute_import
from . import op as _reg from . import op as _reg
from .op import schedule_injective from .op import schedule_injective, OpPattern
# strided_slice # strided_slice
_reg.register_schedule("strided_slice", schedule_injective) _reg.register_schedule("strided_slice", schedule_injective)
# slice_like
_reg.register_schedule("slice_like", schedule_injective)
_reg.register_pattern("slice_like", OpPattern.INJECTIVE)
...@@ -361,3 +361,29 @@ def strided_slice(data, begin, end, strides=None): ...@@ -361,3 +361,29 @@ def strided_slice(data, begin, end, strides=None):
""" """
strides = strides or [] strides = strides or []
return _make.strided_slice(data, list(begin), list(end), list(strides)) return _make.strided_slice(data, list(begin), list(end), list(strides))
def slice_like(data, shape_like, axes=None):
"""Slice the first input with respect to the second input.
For an input array with shape ``(d1, d2, ..., dk)``, `slice_like` operation slices the
the input array corresponding size of second array. By default will slice on all axes.
Parameters
----------
data : tvm.relay.Expr
The source array.
shape_like : tvm.relay.Expr
The new shape.
axes : Optional[Tuple[int]]
List of axes on which input data will be sliced according to the corresponding size of
the second input. By default will slice on all axes. Negative axes mean counting in reverse.
Returns
-------
result : relay.Expr
The computed result.
"""
return _make.slice_like(data, shape_like, axes)
...@@ -1153,5 +1153,152 @@ the entries indicate where along axis the array is split. ...@@ -1153,5 +1153,152 @@ the entries indicate where along axis the array is split.
.set_support_level(3) .set_support_level(3)
.add_type_rel("Split", SplitRel); .add_type_rel("Split", SplitRel);
TVM_REGISTER_NODE_TYPE(SliceLikeAttrs);
/*!
* \brief SliceLikeRel User defined type constraint function.
* \param num_inputs Number of input types in the args.
* \param attrs The additional attributes of the operator.
* \param reporter The reporter to report solution to.
* \return False if the relation has not been resolved, it might be resolved later.
* True if this relation has been resolved.
*/
bool SliceLikeRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 3);
const auto* data = types[0].as<TensorTypeNode>();
if (data == nullptr) {
return false;
}
const auto* target = types[1].as<TensorTypeNode>();
if (target == nullptr) {
return false;
}
const auto param = attrs.as<SliceLikeAttrs>();
CHECK(param != nullptr);
const Array<IndexExpr> dshape = data->shape;
const Array<IndexExpr> target_shape = target->shape;
std::vector<IndexExpr>&& oshape = AsVector(dshape);
if (!param->axes.defined()) {
for (size_t i = 0; i < dshape.size(); ++i) {
if (i < target_shape.size()) {
oshape[i] = target_shape[i];
CHECK(reporter->Assert(oshape[i] <= dshape[i]))
<< "End index of axis " << i << " exceeds input shape: "
<< oshape[i] << " vs " << dshape[i];
}
}
} else {
CHECK(param->axes.size() != 0) << "Axes cannot be empty.";
for (Integer val : param->axes) {
int axis = val->value;
if (axis < 0) {
axis += dshape.size();
}
CHECK(axis < static_cast<int>(target_shape.size()))
<< "Axis " << axis << " exceeds dimension "
<< target_shape.size() << " of target_shape.";
oshape[axis] = target_shape[axis];
CHECK(reporter->Assert(oshape[axis] <= dshape[axis]))
<< "End index of axis " << axis << " exceeds input shape: "
<< oshape[axis] << " vs " << dshape[axis];
}
}
reporter->Assign(types[2], TensorTypeNode::make(oshape, data->dtype));
return true;
}
Expr MakeSliceLike(Expr data,
Expr shape_like,
Array<Integer> axes) {
auto attrs = make_node<SliceLikeAttrs>();
attrs->axes = std::move(axes);
static const Op& op = Op::Get("slice_like");
return CallNode::make(op, {data, shape_like}, Attrs(attrs), {});
}
// Adapter function to make int array.
Array<Integer> GetIntArray(Array<IndexExpr> arr) {
for (size_t i = 0; i < arr.size(); ++i) {
CHECK(!arr[i].defined() || arr[i].as<IntImm>())
<< "Expect an int array";
}
return Array<Integer>(arr.node_);
}
template<typename AttrType>
Array<Tensor> SliceLikeCompute(const Attrs& attrs,
const Array<Tensor>& inputs,
const Type& out_type,
const Target& target) {
const auto* param = attrs.as<AttrType>();
CHECK(param != nullptr);
Array<IndexExpr> src_shape = inputs[0]->shape;
Array<IndexExpr> target_shape = inputs[1]->shape;
Array<IndexExpr> begin_idx, end_idx, strides;
for (size_t i = 0; i < src_shape.size(); ++i) {
begin_idx.push_back(0);
strides.push_back(1);
}
end_idx = Array<IndexExpr>(src_shape);
if (!param->axes.defined()) {
for (size_t i = 0; i < src_shape.size(); ++i) {
if (i < target_shape.size()) {
end_idx.Set(i, target_shape[i]);
CHECK_LE(topi::GetConstInt(end_idx[i]),
topi::GetConstInt(src_shape[i]))
<< "End index of axis " << i << " exceeds input shape: "
<< topi::GetConstInt(end_idx[i]) << " vs "
<< topi::GetConstInt(src_shape[i]);
}
}
} else {
for (int axis : param->axes) {
if (axis < 0) {
axis = static_cast<int>(src_shape.size()) + axis;
}
end_idx.Set(axis, target_shape[axis]);
CHECK_LE(topi::GetConstInt(end_idx[axis]),
topi::GetConstInt(src_shape[axis]))
<< "End index of axis " << axis << " exceeds input shape: "
<< topi::GetConstInt(end_idx[axis]) << " vs "
<< topi::GetConstInt(src_shape[axis]);
}
}
return Array<Tensor>{
topi::strided_slice(inputs[0],
GetIntArray(begin_idx),
GetIntArray(end_idx),
GetIntArray(strides))
};
}
TVM_REGISTER_API("relay.op._make.slice_like")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 3>(MakeSliceLike, args, rv);
});
RELAY_REGISTER_OP("slice_like")
.describe(R"code(Slice the first input respect to the second input.
)code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.SlicelikeAttrs")
.set_num_inputs(2)
.add_argument("data", "Tensor", "The input tensor.")
.add_argument("shape_like", "Tensor", "Shape tensor.")
.set_support_level(10)
.add_type_rel("SliceLike", SliceLikeRel)
.set_attr<FTVMCompute>("FTVMCompute", SliceLikeCompute<SliceLikeAttrs>);
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
""" Support level10 operator test cases. """ Support level10 operator test cases.
""" """
import numpy as np
import tvm import tvm
from tvm import relay from tvm import relay
from tvm.relay.testing import ctx_list
def test_collapse_sum_like(): def test_collapse_sum_like():
x = relay.Var("x", relay.ty.TensorType((3, 4, 5, 6), "int8")) x = relay.Var("x", relay.ty.TensorType((3, 4, 5, 6), "int8"))
...@@ -18,6 +20,66 @@ def test_broadcast_to_like(): ...@@ -18,6 +20,66 @@ def test_broadcast_to_like():
zz = relay.ir_pass.infer_type(z) zz = relay.ir_pass.infer_type(z)
assert zz.checked_type == relay.ty.TensorType((3, 4, 5, 6), "int8") assert zz.checked_type == relay.ty.TensorType((3, 4, 5, 6), "int8")
def np_slice_like(np_data, np_shape_like, axis=None):
begin_idx = [0 for _ in np_data.shape]
end_idx = list(np_data.shape)
if axis:
for i in axis:
if i < 0:
i = len(np_data.shape) + i
end_idx[i] = np_shape_like.shape[i]
else:
for i in range(len(np_data.shape)):
if i < len(np_shape_like.shape):
end_idx[i] = np_shape_like.shape[i]
slice_idx = []
for b, e in zip(begin_idx, end_idx):
slice_idx.append(slice(b, e))
np_result = np_data[tuple(slice_idx)]
return np_result
def verify_slice_like(data, slice_like, axes, output, dtype="float32"):
x = relay.var("data", relay.TensorType(data, dtype))
y = relay.var("slice_like", relay.TensorType(slice_like, dtype))
z = relay.slice_like(x, y, axes)
zz = relay.ir_pass.infer_type(z)
if axes:
assert "axes" in z.astext()
assert zz.checked_type == relay.ty.TensorType(output, dtype)
if all(isinstance(v, int) == 0 for v in data) or \
all(isinstance(v, int) == 0 for v in slice_like):
return
func = relay.Function([x, y], z)
x_data = np.random.uniform(size=data).astype(dtype)
y_data = np.random.uniform(size=slice_like).astype(dtype)
ref_res = np_slice_like(x_data, y_data, axes)
for target, ctx in ctx_list():
for kind in ["graph", "debug"]:
intrp = relay.create_executor(kind, ctx=ctx, target=target)
op_res = intrp.evaluate(func)(x_data, y_data)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)
def test_slice_like():
d1, d2, d3, d4 = tvm.var("d1"), tvm.var("d2"), tvm.var("d3"), tvm.var("d4")
verify_slice_like(data=(d1, d2, d3), slice_like=(1, 2, 3), axes=None, output=(1, 2, 3))
verify_slice_like(data=(1, 2, 3), slice_like=(d1, d2, d3), axes=None, output=(d1, d2, d3))
verify_slice_like(data=(d2, d3, d4), slice_like=(d1, d2, d3), axes=(1,2), output=(d2, d2, d3))
verify_slice_like(data=(3, 4, 5), slice_like=(1, 2, 3), axes=None, output=(1, 2, 3))
verify_slice_like(data=(3, 4, 5), slice_like=(1, 2), axes=None, output=(1, 2, 5))
verify_slice_like(data=(3, 4, 5), slice_like=(1, 2, 3), axes=(1, 2), output=(3, 2, 3))
verify_slice_like(data=(3, 4, 5), slice_like=(1, 2, 3), axes=(-1, -3), output=(1, 4, 3))
verify_slice_like(data=(1, 3, 224, 224),
slice_like=(1, 3, 112, 112),
axes=(2, 3),
output=(1, 3, 112, 112))
if __name__ == "__main__": if __name__ == "__main__":
test_collapse_sum_like() test_collapse_sum_like()
test_broadcast_to_like() test_broadcast_to_like()
test_slice_like()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment