Commit 3746d902 by Haichen Shen Committed by Yizhi Liu

[Relay/TOPI][OP] Add clip and wrap mode support in take (#2858)

* Update take

* Add special case for canonical simplify and fix test cases

* Use lower case for wrap and clip

* Remove unnecessary lower

* Fix mxnet converter for take

* fix
parent 7cc9240a
Subproject commit 86351c40824dfc4cbb7447d70e5e63d9bd76eb90 Subproject commit 55ba1778fd264c7507953552d8e51212ed11f748
...@@ -75,10 +75,15 @@ struct ReshapeAttrs : public tvm::AttrsNode<ReshapeAttrs> { ...@@ -75,10 +75,15 @@ struct ReshapeAttrs : public tvm::AttrsNode<ReshapeAttrs> {
struct TakeAttrs : public tvm::AttrsNode<TakeAttrs> { struct TakeAttrs : public tvm::AttrsNode<TakeAttrs> {
Integer axis; Integer axis;
std::string mode;
TVM_DECLARE_ATTRS(TakeAttrs, "relay.attrs.TakeAttrs") { TVM_DECLARE_ATTRS(TakeAttrs, "relay.attrs.TakeAttrs") {
TVM_ATTR_FIELD(axis).set_default(NullValue<Integer>()) TVM_ATTR_FIELD(axis).set_default(NullValue<Integer>())
.describe("The axis over which to select values."); .describe("The axis over which to select values.");
// Out-of-bound handling mode for take. Adjacent string literals are joined
// verbatim by the compiler, so each piece carries its own trailing separator
// to keep the rendered description readable.
TVM_ATTR_FIELD(mode).set_default("clip")
.describe("Specify how out-of-bound indices will behave. "
"clip - clip to the range (default); "
"wrap - wrap around the indices.");
} }
}; };
......
...@@ -444,6 +444,15 @@ def _mx_tile(inputs, attrs): ...@@ -444,6 +444,15 @@ def _mx_tile(inputs, attrs):
return _op.tile(inputs[0], **new_attrs) return _op.tile(inputs[0], **new_attrs)
def _mx_take(inputs, attrs):
    """Convert an MXNet ``take`` symbol into a call to ``_op.take``.

    Parameters
    ----------
    inputs : list
        Exactly two entries: the data expression followed by the indices
        expression.
    attrs : object
        Attribute accessor; ``mode`` is read with fallback "clip" and
        ``axis`` with fallback 0.

    Returns
    -------
    The result of ``_op.take`` applied to the data, the indices cast to
    "int32", the axis and the mode.

    Raises
    ------
    RuntimeError
        If the requested mode is "raise", which is rejected here.
    """
    assert len(inputs) == 2
    oob_mode = attrs.get_str("mode", "clip")
    # "raise" is refused up front instead of being forwarded to _op.take.
    if oob_mode == "raise":
        raise RuntimeError("take doesn't support raise mode")
    take_axis = attrs.get_int("axis", 0)
    data = inputs[0]
    int_indices = inputs[1].astype("int32")
    return _op.take(data, int_indices, take_axis, oob_mode)
def _mx_reverse(inputs, attrs): def _mx_reverse(inputs, attrs):
assert len(inputs) == 1 assert len(inputs) == 1
new_attrs = {} new_attrs = {}
...@@ -749,6 +758,7 @@ _convert_map = { ...@@ -749,6 +758,7 @@ _convert_map = {
"_full" : _mx_full, "_full" : _mx_full,
"repeat" : _mx_repeat, "repeat" : _mx_repeat,
"tile" : _mx_tile, "tile" : _mx_tile,
"take" : _mx_take,
"reverse" : _mx_reverse, "reverse" : _mx_reverse,
"squeeze" : _mx_squeeze, "squeeze" : _mx_squeeze,
"broadcast_axis": _mx_broadcast_axis, "broadcast_axis": _mx_broadcast_axis,
......
...@@ -186,7 +186,7 @@ def reshape_like(data, shape_like): ...@@ -186,7 +186,7 @@ def reshape_like(data, shape_like):
return _make.reshape_like(data, shape_like) return _make.reshape_like(data, shape_like)
def take(data, indices, axis=None): def take(data, indices, axis=None, mode="clip"):
"""Take elements from an array along an axis. """Take elements from an array along an axis.
Parameters Parameters
...@@ -201,12 +201,17 @@ def take(data, indices, axis=None): ...@@ -201,12 +201,17 @@ def take(data, indices, axis=None):
The axis over which to select values. By default, The axis over which to select values. By default,
the flattened input array is used. the flattened input array is used.
mode : str, optional
Specifies how out-of-bound indices will behave.
clip - clip to the range (default)
wrap - wrap around the indices
Returns Returns
------- -------
ret : relay.Expr ret : relay.Expr
The computed result. The computed result.
""" """
return _make.take(data, indices, axis) return _make.take(data, indices, axis, mode)
def full(fill_value, shape=(), dtype=""): def full(fill_value, shape=(), dtype=""):
......
...@@ -753,24 +753,26 @@ Array<Tensor> TakeCompute(const Attrs& attrs, ...@@ -753,24 +753,26 @@ Array<Tensor> TakeCompute(const Attrs& attrs,
const auto* param = attrs.as<TakeAttrs>(); const auto* param = attrs.as<TakeAttrs>();
CHECK(param != nullptr); CHECK(param != nullptr);
if (!param->axis.defined()) { if (!param->axis.defined()) {
return Array<Tensor>{ topi::take(inputs[0], inputs[1]) }; return Array<Tensor>{ topi::take(inputs[0], inputs[1], param->mode) };
} else { } else {
return Array<Tensor>{ topi::take(inputs[0], inputs[1], param->axis) }; return Array<Tensor>{ topi::take(inputs[0], inputs[1], param->axis, param->mode) };
} }
} }
Expr MakeTake(Expr data, Expr MakeTake(Expr data,
Expr indices, Expr indices,
Integer axis) { Integer axis,
std::string mode) {
auto attrs = make_node<TakeAttrs>(); auto attrs = make_node<TakeAttrs>();
attrs->axis = std::move(axis); attrs->axis = std::move(axis);
attrs->mode = std::move(mode);
static const Op& op = Op::Get("take"); static const Op& op = Op::Get("take");
return CallNode::make(op, {data, indices}, Attrs(attrs), {}); return CallNode::make(op, {data, indices}, Attrs(attrs), {});
} }
TVM_REGISTER_API("relay.op._make.take") TVM_REGISTER_API("relay.op._make.take")
.set_body([](const TVMArgs& args, TVMRetValue* rv) { .set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 3>(MakeTake, args, rv); runtime::detail::unpack_call<Expr, 4>(MakeTake, args, rv);
}); });
RELAY_REGISTER_OP("take") RELAY_REGISTER_OP("take")
......
...@@ -464,7 +464,6 @@ def test_forward_embedding(): ...@@ -464,7 +464,6 @@ def test_forward_embedding():
verify((2, 2), (4, 5)) verify((2, 2), (4, 5))
verify((2, 3, 4), (4, 5)) verify((2, 3, 4), (4, 5))
def test_forward_smooth_l1(): def test_forward_smooth_l1():
data = mx.sym.var('data') data = mx.sym.var('data')
mx_sym = mx.sym.smooth_l1(data) mx_sym = mx.sym.smooth_l1(data)
...@@ -472,6 +471,26 @@ def test_forward_smooth_l1(): ...@@ -472,6 +471,26 @@ def test_forward_smooth_l1():
mx_sym = mx.sym.smooth_l1(data, scalar=1.0) mx_sym = mx.sym.smooth_l1(data, scalar=1.0)
verify_mxnet_frontend_impl(mx_sym, (3, 4), (3, 4)) verify_mxnet_frontend_impl(mx_sym, (3, 4), (3, 4))
def test_forward_take():
    """Check the MXNet ``take`` conversion against ``mx.nd.take``.

    Each case converts an ``mx.sym.take`` symbol with
    ``relay.frontend.from_mxnet``, evaluates it with the "graph" and "debug"
    executors on every (target, ctx) pair from ``ctx_list()``, and compares
    the output with the MXNet reference result.
    """
    def verify(shape, indices_src, axis, mode="clip"):
        # Random float32 data; indices are also fed as float32 arrays.
        data_np = np.random.uniform(size=shape).astype("float32")
        idx_np = np.array(indices_src, dtype="float32")
        expected = mx.nd.take(mx.nd.array(data_np), mx.nd.array(idx_np), axis, mode)
        symbol = mx.sym.take(mx.sym.var("x"), mx.sym.var("y"), axis, mode)
        shape_dict = {"x": shape, "y": idx_np.shape}
        converted, _ = relay.frontend.from_mxnet(symbol, shape_dict)
        for target, ctx in ctx_list():
            for kind in ["graph", "debug"]:
                executor = relay.create_executor(kind, ctx=ctx, target=target)
                actual = executor.evaluate(converted)(data_np, idx_np)
                tvm.testing.assert_allclose(actual.asnumpy(), expected.asnumpy())

    # (shape, indices, axis, mode), run in this order.
    cases = [
        ((2, 2), [[[1, 0], [0, 1]]], 0, "clip"),
        ((2, 2), [[[1, 0], [0, 1]]], 1, "clip"),
        ((4, 3, 5, 6), [[2, 1, 0, 0]], -2, "clip"),
        # Out-of-range indices exercise both clip and wrap handling.
        ((3, 4), [-1, 5], 0, "clip"),
        ((3, 4), [-1, 5], 0, "wrap"),
        ((3, 4), [-1, 5], 1, "clip"),
        ((3, 4), [-1, 5], 1, "wrap"),
    ]
    for shape, indices_src, axis, mode in cases:
        verify(shape, indices_src, axis, mode=mode)
if __name__ == '__main__': if __name__ == '__main__':
test_forward_mlp() test_forward_mlp()
test_forward_vgg() test_forward_vgg()
...@@ -507,3 +526,4 @@ if __name__ == '__main__': ...@@ -507,3 +526,4 @@ if __name__ == '__main__':
test_forward_full() test_forward_full()
test_forward_embedding() test_forward_embedding()
test_forward_smooth_l1() test_forward_smooth_l1()
test_forward_take()
...@@ -243,17 +243,17 @@ def test_take_infer_type(): ...@@ -243,17 +243,17 @@ def test_take_infer_type():
verify_take((d1, d2, d3, d4), (d5, d6), (d1, d2, d5, d6, d4), -2) verify_take((d1, d2, d3, d4), (d5, d6), (d1, d2, d5, d6, d4), -2)
def test_take(): def test_take():
def verify_take(src_shape, indices_src, axis=None): def verify_take(src_shape, indices_src, axis=None, mode="clip"):
src_dtype = "float32" src_dtype = "float32"
indices_dtype = "int32" indices_dtype = "int32"
indices_src = np.array(indices_src, dtype=indices_dtype) indices_src = np.array(indices_src, dtype=indices_dtype)
x = relay.var("x", relay.TensorType(src_shape, src_dtype)) x = relay.var("x", relay.TensorType(src_shape, src_dtype))
indices = relay.var("indices", relay.TensorType(indices_src.shape, indices_dtype)) indices = relay.var("indices", relay.TensorType(indices_src.shape, indices_dtype))
z = relay.take(x, indices, axis=axis) z = relay.take(x, indices, axis=axis, mode=mode)
func = relay.Function([x, indices], z) func = relay.Function([x, indices], z)
x_data = np.random.uniform(low=-1, high=1, size=src_shape).astype(src_dtype) x_data = np.random.uniform(low=-1, high=1, size=src_shape).astype(src_dtype)
ref_res = np.take(x_data, indices=indices_src, axis=axis) ref_res = np.take(x_data, indices=indices_src, axis=axis, mode=mode)
for target, ctx in ctx_list(): for target, ctx in ctx_list():
for kind in ["graph", "debug"]: for kind in ["graph", "debug"]:
...@@ -269,6 +269,12 @@ def test_take(): ...@@ -269,6 +269,12 @@ def test_take():
verify_take((2,2), [[[1,0],[0,1]]], 0) verify_take((2,2), [[[1,0],[0,1]]], 0)
verify_take((2,2), [[[1,0],[0,1]]], 1) verify_take((2,2), [[[1,0],[0,1]]], 1)
verify_take((4,3,5,6), [[2,1,0,0]], -2) verify_take((4,3,5,6), [[2,1,0,0]], -2)
verify_take((3,4), [-5, 20])
verify_take((3,4), [-5, 20], mode="wrap")
verify_take((3,4), [-1, 2], axis=0)
verify_take((3,4), [-1, 2], axis=0, mode="wrap")
verify_take((3,4), [-1, 2], axis=1)
verify_take((3,4), [-1, 2], axis=1, mode="wrap")
def test_split_infer_type(): def test_split_infer_type():
......
...@@ -39,6 +39,11 @@ def test_simplify_mod(): ...@@ -39,6 +39,11 @@ def test_simplify_mod():
stmt = tvm.ir_pass.CanonicalSimplify(body) stmt = tvm.ir_pass.CanonicalSimplify(body)
diff = tvm.ir_pass.CanonicalSimplify(stmt.body.body.value.index - (1 + i) % 16) diff = tvm.ir_pass.CanonicalSimplify(stmt.body.body.value.index - (1 + i) % 16)
assert diff.value == 0 assert diff.value == 0
# if we can't prove that j is non-negative, we can't prove that (j+16) % 16 is j%16
index = tvm.ir_pass.CanonicalSimplify((j + 16) % 16)
assert index != j
index = tvm.ir_pass.CanonicalSimplify((j + 16) % 16, {j: tvm.Range(0, 6)})
assert index == j
# if we can't prove that j+n*32 is non-negative, we can't prove that (j+n*32) % 16 is j%16 # if we can't prove that j+n*32 is non-negative, we can't prove that (j+n*32) % 16 is j%16
index = tvm.ir_pass.CanonicalSimplify( index = tvm.ir_pass.CanonicalSimplify(
(j + n * 32) % 16, {j: tvm.Range(0, 6)}) (j + n * 32) % 16, {j: tvm.Range(0, 6)})
......
...@@ -604,22 +604,29 @@ inline Array<Tensor> split_sections(const Tensor& x, ...@@ -604,22 +604,29 @@ inline Array<Tensor> split_sections(const Tensor& x,
*/ */
inline Tensor take(const Tensor& a, inline Tensor take(const Tensor& a,
const Tensor& indices, const Tensor& indices,
std::string mode = "clip",
std::string name = "tensor", std::string name = "tensor",
std::string tag = kInjective) { std::string tag = kInjective) {
Array<Expr> a_shape = a->shape; Array<Expr> a_shape = a->shape;
Array<Expr> out_shape; Array<Expr> out_shape = indices->shape;
for (size_t j = 0; j < indices->shape.size(); ++j) { Expr a_size = 1;
out_shape.push_back(indices->shape[j]); for (size_t i = 0; i < a_shape.size(); ++i) {
a_size = a_size * a_shape[i];
} }
return compute( if (mode == "clip") {
return compute(
out_shape, [&](const Array<Var>& out_index) { out_shape, [&](const Array<Var>& out_index) {
Array<Expr> indices_position; auto idx = tvm::min(tvm::max(0, indices(out_index)), a_size - 1);
for (size_t j = 0; j < indices->shape.size(); ++j) { return a(UnravelIndex(idx, a_shape));
indices_position.push_back(out_index[j]); }, name, tag);
} } else { // mode == "wrap"
return a(UnravelIndex(indices(indices_position), a_shape)); return compute(
out_shape, [&](const Array<Var>& out_index) {
auto idx = (indices(out_index) % a_size + a_size) % a_size;
return a(UnravelIndex(idx, a_shape));
}, name, tag); }, name, tag);
}
} }
/*! /*!
...@@ -637,12 +644,15 @@ inline Tensor take(const Tensor& a, ...@@ -637,12 +644,15 @@ inline Tensor take(const Tensor& a,
inline Tensor take(const Tensor& a, inline Tensor take(const Tensor& a,
const Tensor& indices, const Tensor& indices,
int axis, int axis,
std::string mode = "clip",
std::string name = "tensor", std::string name = "tensor",
std::string tag = kInjective) { std::string tag = kInjective) {
if (axis < 0) { if (axis < 0) {
axis += static_cast<int>(a->shape.size()); axis += static_cast<int>(a->shape.size());
} }
CHECK_GE(axis, 0) << "axis out of bounds";
CHECK_LT(axis, a->shape.size()) << "axis out of bounds"; CHECK_LT(axis, a->shape.size()) << "axis out of bounds";
auto axis_dim = a->shape[axis];
int indices_len = static_cast<int>(indices->shape.size()); int indices_len = static_cast<int>(indices->shape.size());
Array<Expr> out_shape; Array<Expr> out_shape;
...@@ -655,7 +665,27 @@ inline Tensor take(const Tensor& a, ...@@ -655,7 +665,27 @@ inline Tensor take(const Tensor& a,
out_shape.push_back(a->shape[i]); out_shape.push_back(a->shape[i]);
} }
} }
return compute( if (mode == "clip") {
return compute(
out_shape, [&](const Array<Var>& out_index) {
Array<Expr> indices_position;
for (size_t j = axis; j < static_cast<size_t>(axis+indices_len); ++j) {
indices_position.push_back(out_index[j]);
}
Array<Expr> real_indices;
for (size_t j = 0; j < static_cast<size_t>(axis); ++j) {
real_indices.push_back(out_index[j]);
}
auto idx = tvm::min(tvm::max(0, indices(indices_position)),
axis_dim - 1);
real_indices.push_back(idx);
for (size_t j = axis + indices_len; j < out_index.size(); ++j) {
real_indices.push_back(out_index[j]);
}
return a(real_indices);
}, name, tag);
} else { // mode == "wrap"
return compute(
out_shape, [&](const Array<Var>& out_index) { out_shape, [&](const Array<Var>& out_index) {
Array<Expr> indices_position; Array<Expr> indices_position;
for (size_t j = axis; j < static_cast<size_t>(axis+indices_len); ++j) { for (size_t j = axis; j < static_cast<size_t>(axis+indices_len); ++j) {
...@@ -665,12 +695,14 @@ inline Tensor take(const Tensor& a, ...@@ -665,12 +695,14 @@ inline Tensor take(const Tensor& a,
for (size_t j = 0; j < static_cast<size_t>(axis); ++j) { for (size_t j = 0; j < static_cast<size_t>(axis); ++j) {
real_indices.push_back(out_index[j]); real_indices.push_back(out_index[j]);
} }
real_indices.push_back(indices(indices_position)); auto idx = (indices(indices_position) % axis_dim + axis_dim) % axis_dim;
real_indices.push_back(idx);
for (size_t j = axis + indices_len; j < out_index.size(); ++j) { for (size_t j = axis + indices_len; j < out_index.size(); ++j) {
real_indices.push_back(out_index[j]); real_indices.push_back(out_index[j]);
} }
return a(real_indices); return a(real_indices);
}, name, tag); }, name, tag);
}
} }
/*! /*!
......
...@@ -228,7 +228,7 @@ def split(ary, indices_or_sections, axis=0): ...@@ -228,7 +228,7 @@ def split(ary, indices_or_sections, axis=0):
return cpp.split(ary, indices_or_sections, axis) return cpp.split(ary, indices_or_sections, axis)
def take(a, indices, axis=None): def take(a, indices, axis=None, mode="clip"):
"""Take elements from an array along an axis. """Take elements from an array along an axis.
Parameters Parameters
...@@ -243,13 +243,18 @@ def take(a, indices, axis=None): ...@@ -243,13 +243,18 @@ def take(a, indices, axis=None):
The axis over which to select values. By default, The axis over which to select values. By default,
the flattened input array is used. the flattened input array is used.
mode : str, optional
Specifies how out-of-bound indices will behave.
clip - clip to the range (default)
wrap - wrap around the indices
Returns Returns
------- -------
ret : tvm.Tensor ret : tvm.Tensor
""" """
if axis is None: if axis is None:
return cpp.take(a, indices) return cpp.take(a, indices, mode)
return cpp.take(a, indices, int(axis)) return cpp.take(a, indices, int(axis), mode)
def gather_nd(a, indices): def gather_nd(a, indices):
......
...@@ -297,11 +297,13 @@ TVM_REGISTER_GLOBAL("topi.layout_transform") ...@@ -297,11 +297,13 @@ TVM_REGISTER_GLOBAL("topi.layout_transform")
TVM_REGISTER_GLOBAL("topi.take") TVM_REGISTER_GLOBAL("topi.take")
.set_body([](TVMArgs args, TVMRetValue *rv) { .set_body([](TVMArgs args, TVMRetValue *rv) {
if (args.size() == 2) { if (args.size() == 3) {
*rv = take(args[0], args[1]); std::string mode = args[2];
*rv = take(args[0], args[1], mode);
} else { } else {
int axis = args[2]; int axis = args[2];
*rv = take(args[0], args[1], axis); std::string mode = args[3];
*rv = take(args[0], args[1], axis, mode);
} }
}); });
......
...@@ -232,16 +232,16 @@ def verify_flip(in_shape, axis): ...@@ -232,16 +232,16 @@ def verify_flip(in_shape, axis):
for device in ["llvm", "cuda", "opencl", "sdaccel", "aocl_sw_emu"]: for device in ["llvm", "cuda", "opencl", "sdaccel", "aocl_sw_emu"]:
check_device(device) check_device(device)
def verify_take(src_shape, indices_src, axis=None): def verify_take(src_shape, indices_src, axis=None, mode="clip"):
src_dtype = "float32" src_dtype = "float32"
indices_dtype = "int32" indices_dtype = "int32"
indices_src = np.array(indices_src, dtype=indices_dtype) indices_src = np.array(indices_src, dtype=indices_dtype)
A = tvm.placeholder(shape=src_shape, dtype=src_dtype, name="A") A = tvm.placeholder(shape=src_shape, dtype=src_dtype, name="A")
indices = tvm.placeholder(shape=indices_src.shape, dtype=indices_dtype, name="indices") indices = tvm.placeholder(shape=indices_src.shape, dtype=indices_dtype, name="indices")
if axis is None: if axis is None:
out_tensor = topi.take(a=A, indices=indices) out_tensor = topi.take(a=A, indices=indices, mode=mode)
else: else:
out_tensor = topi.take(a=A, indices=indices, axis=axis) out_tensor = topi.take(a=A, indices=indices, axis=axis, mode=mode)
def check_device(device): def check_device(device):
ctx = tvm.context(device, 0) ctx = tvm.context(device, 0)
...@@ -259,9 +259,9 @@ def verify_take(src_shape, indices_src, axis=None): ...@@ -259,9 +259,9 @@ def verify_take(src_shape, indices_src, axis=None):
data_npy = np.arange(shape_size, dtype=src_dtype).reshape((src_shape)) data_npy = np.arange(shape_size, dtype=src_dtype).reshape((src_shape))
if axis is None: if axis is None:
out_npys = np.take(data_npy, indices_src) out_npys = np.take(data_npy, indices_src, mode=mode)
else: else:
out_npys = np.take(data_npy, indices_src, axis=axis) out_npys = np.take(data_npy, indices_src, axis=axis, mode=mode)
data_nd = tvm.nd.array(data_npy, ctx) data_nd = tvm.nd.array(data_npy, ctx)
indices_nd = tvm.nd.array(indices_src, ctx) indices_nd = tvm.nd.array(indices_src, ctx)
out_nd = tvm.nd.empty(out_npys.shape, ctx=ctx, dtype=src_dtype) out_nd = tvm.nd.empty(out_npys.shape, ctx=ctx, dtype=src_dtype)
...@@ -498,6 +498,12 @@ def test_take(): ...@@ -498,6 +498,12 @@ def test_take():
verify_take((2,2), [[[1,0],[0,1]]], 0) verify_take((2,2), [[[1,0],[0,1]]], 0)
verify_take((2,2), [[[1,0],[0,1]]], 1) verify_take((2,2), [[[1,0],[0,1]]], 1)
verify_take((4,3,5,6), [[2,1,0,0]], -2) verify_take((4,3,5,6), [[2,1,0,0]], -2)
verify_take((3,4), [-5, 20])
verify_take((3,4), [-5, 20], mode="wrap")
verify_take((3,4), [-1, 2], axis=0)
verify_take((3,4), [-1, 2], axis=0, mode="wrap")
verify_take((3,4), [-1, 2], axis=1)
verify_take((3,4), [-1, 2], axis=1, mode="wrap")
def test_gather_nd(): def test_gather_nd():
for indices_dtype in ['int32', 'float32']: for indices_dtype in ['int32', 'float32']:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment