Commit 3746d902 by Haichen Shen Committed by Yizhi Liu

[Relay/TOPI][OP] Add clip and wrap mode support in take (#2858)

* Update take

* Add special case for canonical simplify and fix test cases

* Use lower case for wrap and clip

* remove unnecssary lower

* Fix mxnet converter for take

* fix
parent 7cc9240a
Subproject commit 86351c40824dfc4cbb7447d70e5e63d9bd76eb90
Subproject commit 55ba1778fd264c7507953552d8e51212ed11f748
......@@ -75,10 +75,15 @@ struct ReshapeAttrs : public tvm::AttrsNode<ReshapeAttrs> {
struct TakeAttrs : public tvm::AttrsNode<TakeAttrs> {
Integer axis;
std::string mode;
TVM_DECLARE_ATTRS(TakeAttrs, "relay.attrs.TakeAttrs") {
TVM_ATTR_FIELD(axis).set_default(NullValue<Integer>())
.describe("The axis over which to select values.");
TVM_ATTR_FIELD(mode).set_default("clip")
.describe("Specify how out-of-bound indices will behave."
"clip - clip to the range (default)"
"wrap - wrap around the indices");
}
};
......
......@@ -444,6 +444,15 @@ def _mx_tile(inputs, attrs):
return _op.tile(inputs[0], **new_attrs)
def _mx_take(inputs, attrs):
assert len(inputs) == 2
mode = attrs.get_str("mode", "clip")
if mode == "raise":
raise RuntimeError("take doesn't support raise mode")
axis = attrs.get_int("axis", 0)
return _op.take(inputs[0], inputs[1].astype("int32"), axis, mode)
def _mx_reverse(inputs, attrs):
assert len(inputs) == 1
new_attrs = {}
......@@ -749,6 +758,7 @@ _convert_map = {
"_full" : _mx_full,
"repeat" : _mx_repeat,
"tile" : _mx_tile,
"take" : _mx_take,
"reverse" : _mx_reverse,
"squeeze" : _mx_squeeze,
"broadcast_axis": _mx_broadcast_axis,
......
......@@ -186,7 +186,7 @@ def reshape_like(data, shape_like):
return _make.reshape_like(data, shape_like)
def take(data, indices, axis=None):
def take(data, indices, axis=None, mode="clip"):
"""Take elements from an array along an axis.
Parameters
......@@ -201,12 +201,17 @@ def take(data, indices, axis=None):
The axis over which to select values. By default,
the flattened input array is used.
mode : str, optional
Specifies how out-of-bound indices will behave.
clip - clip to the range (default)
wrap - wrap around the indices
Returns
-------
ret : relay.Expr
The computed result.
"""
return _make.take(data, indices, axis)
return _make.take(data, indices, axis, mode)
def full(fill_value, shape=(), dtype=""):
......
......@@ -753,24 +753,26 @@ Array<Tensor> TakeCompute(const Attrs& attrs,
const auto* param = attrs.as<TakeAttrs>();
CHECK(param != nullptr);
if (!param->axis.defined()) {
return Array<Tensor>{ topi::take(inputs[0], inputs[1]) };
return Array<Tensor>{ topi::take(inputs[0], inputs[1], param->mode) };
} else {
return Array<Tensor>{ topi::take(inputs[0], inputs[1], param->axis) };
return Array<Tensor>{ topi::take(inputs[0], inputs[1], param->axis, param->mode) };
}
}
Expr MakeTake(Expr data,
Expr indices,
Integer axis) {
Integer axis,
std::string mode) {
auto attrs = make_node<TakeAttrs>();
attrs->axis = std::move(axis);
attrs->mode = std::move(mode);
static const Op& op = Op::Get("take");
return CallNode::make(op, {data, indices}, Attrs(attrs), {});
}
TVM_REGISTER_API("relay.op._make.take")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 3>(MakeTake, args, rv);
runtime::detail::unpack_call<Expr, 4>(MakeTake, args, rv);
});
RELAY_REGISTER_OP("take")
......
......@@ -464,7 +464,6 @@ def test_forward_embedding():
verify((2, 2), (4, 5))
verify((2, 3, 4), (4, 5))
def test_forward_smooth_l1():
data = mx.sym.var('data')
mx_sym = mx.sym.smooth_l1(data)
......@@ -472,6 +471,26 @@ def test_forward_smooth_l1():
mx_sym = mx.sym.smooth_l1(data, scalar=1.0)
verify_mxnet_frontend_impl(mx_sym, (3, 4), (3, 4))
def test_forward_take():
def verify(shape, indices_src, axis, mode="clip"):
x_np = np.random.uniform(size=shape).astype("float32")
indices_np = np.array(indices_src, dtype="float32")
ref_res = mx.nd.take(mx.nd.array(x_np), mx.nd.array(indices_np), axis, mode)
mx_sym = mx.sym.take(mx.sym.var("x"), mx.sym.var("y"), axis, mode)
new_sym, _ = relay.frontend.from_mxnet(mx_sym, {"x": shape, "y": indices_np.shape})
for target, ctx in ctx_list():
for kind in ["graph", "debug"]:
intrp = relay.create_executor(kind, ctx=ctx, target=target)
op_res = intrp.evaluate(new_sym)(x_np, indices_np)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())
verify((2,2), [[[1,0],[0,1]]], 0)
verify((2,2), [[[1,0],[0,1]]], 1)
verify((4,3,5,6), [[2,1,0,0]], -2)
verify((3,4), [-1, 5], 0)
verify((3,4), [-1, 5], 0, mode="wrap")
verify((3,4), [-1, 5], 1)
verify((3,4), [-1, 5], 1, mode="wrap")
if __name__ == '__main__':
test_forward_mlp()
test_forward_vgg()
......@@ -507,3 +526,4 @@ if __name__ == '__main__':
test_forward_full()
test_forward_embedding()
test_forward_smooth_l1()
test_forward_take()
......@@ -243,17 +243,17 @@ def test_take_infer_type():
verify_take((d1, d2, d3, d4), (d5, d6), (d1, d2, d5, d6, d4), -2)
def test_take():
def verify_take(src_shape, indices_src, axis=None):
def verify_take(src_shape, indices_src, axis=None, mode="clip"):
src_dtype = "float32"
indices_dtype = "int32"
indices_src = np.array(indices_src, dtype=indices_dtype)
x = relay.var("x", relay.TensorType(src_shape, src_dtype))
indices = relay.var("indices", relay.TensorType(indices_src.shape, indices_dtype))
z = relay.take(x, indices, axis=axis)
z = relay.take(x, indices, axis=axis, mode=mode)
func = relay.Function([x, indices], z)
x_data = np.random.uniform(low=-1, high=1, size=src_shape).astype(src_dtype)
ref_res = np.take(x_data, indices=indices_src, axis=axis)
ref_res = np.take(x_data, indices=indices_src, axis=axis, mode=mode)
for target, ctx in ctx_list():
for kind in ["graph", "debug"]:
......@@ -269,6 +269,12 @@ def test_take():
verify_take((2,2), [[[1,0],[0,1]]], 0)
verify_take((2,2), [[[1,0],[0,1]]], 1)
verify_take((4,3,5,6), [[2,1,0,0]], -2)
verify_take((3,4), [-5, 20])
verify_take((3,4), [-5, 20], mode="wrap")
verify_take((3,4), [-1, 2], axis=0)
verify_take((3,4), [-1, 2], axis=0, mode="wrap")
verify_take((3,4), [-1, 2], axis=1)
verify_take((3,4), [-1, 2], axis=1, mode="wrap")
def test_split_infer_type():
......
......@@ -39,6 +39,11 @@ def test_simplify_mod():
stmt = tvm.ir_pass.CanonicalSimplify(body)
diff = tvm.ir_pass.CanonicalSimplify(stmt.body.body.value.index - (1 + i) % 16)
assert diff.value == 0
# if we can't prove that j is non-negative, we can't prove that (j+16) % 16 is j%16
index = tvm.ir_pass.CanonicalSimplify((j + 16) % 16)
assert index != j
index = tvm.ir_pass.CanonicalSimplify((j + 16) % 16, {j: tvm.Range(0, 6)})
assert index == j
# if we can't prove that j+n*32 is non-negative, we can't prove that (j+n*32) % 16 is j%16
index = tvm.ir_pass.CanonicalSimplify(
(j + n * 32) % 16, {j: tvm.Range(0, 6)})
......
......@@ -604,22 +604,29 @@ inline Array<Tensor> split_sections(const Tensor& x,
*/
inline Tensor take(const Tensor& a,
const Tensor& indices,
std::string mode = "clip",
std::string name = "tensor",
std::string tag = kInjective) {
Array<Expr> a_shape = a->shape;
Array<Expr> out_shape;
for (size_t j = 0; j < indices->shape.size(); ++j) {
out_shape.push_back(indices->shape[j]);
Array<Expr> out_shape = indices->shape;
Expr a_size = 1;
for (size_t i = 0; i < a_shape.size(); ++i) {
a_size = a_size * a_shape[i];
}
return compute(
if (mode == "clip") {
return compute(
out_shape, [&](const Array<Var>& out_index) {
Array<Expr> indices_position;
for (size_t j = 0; j < indices->shape.size(); ++j) {
indices_position.push_back(out_index[j]);
}
return a(UnravelIndex(indices(indices_position), a_shape));
auto idx = tvm::min(tvm::max(0, indices(out_index)), a_size - 1);
return a(UnravelIndex(idx, a_shape));
}, name, tag);
} else { // mode == "wrap"
return compute(
out_shape, [&](const Array<Var>& out_index) {
auto idx = (indices(out_index) % a_size + a_size) % a_size;
return a(UnravelIndex(idx, a_shape));
}, name, tag);
}
}
/*!
......@@ -637,12 +644,15 @@ inline Tensor take(const Tensor& a,
inline Tensor take(const Tensor& a,
const Tensor& indices,
int axis,
std::string mode = "clip",
std::string name = "tensor",
std::string tag = kInjective) {
if (axis < 0) {
axis += static_cast<int>(a->shape.size());
}
CHECK_GE(axis, 0) << "axis out of bounds";
CHECK_LT(axis, a->shape.size()) << "axis out of bounds";
auto axis_dim = a->shape[axis];
int indices_len = static_cast<int>(indices->shape.size());
Array<Expr> out_shape;
......@@ -655,7 +665,27 @@ inline Tensor take(const Tensor& a,
out_shape.push_back(a->shape[i]);
}
}
return compute(
if (mode == "clip") {
return compute(
out_shape, [&](const Array<Var>& out_index) {
Array<Expr> indices_position;
for (size_t j = axis; j < static_cast<size_t>(axis+indices_len); ++j) {
indices_position.push_back(out_index[j]);
}
Array<Expr> real_indices;
for (size_t j = 0; j < static_cast<size_t>(axis); ++j) {
real_indices.push_back(out_index[j]);
}
auto idx = tvm::min(tvm::max(0, indices(indices_position)),
axis_dim - 1);
real_indices.push_back(idx);
for (size_t j = axis + indices_len; j < out_index.size(); ++j) {
real_indices.push_back(out_index[j]);
}
return a(real_indices);
}, name, tag);
} else { // mode == "wrap"
return compute(
out_shape, [&](const Array<Var>& out_index) {
Array<Expr> indices_position;
for (size_t j = axis; j < static_cast<size_t>(axis+indices_len); ++j) {
......@@ -665,12 +695,14 @@ inline Tensor take(const Tensor& a,
for (size_t j = 0; j < static_cast<size_t>(axis); ++j) {
real_indices.push_back(out_index[j]);
}
real_indices.push_back(indices(indices_position));
auto idx = (indices(indices_position) % axis_dim + axis_dim) % axis_dim;
real_indices.push_back(idx);
for (size_t j = axis + indices_len; j < out_index.size(); ++j) {
real_indices.push_back(out_index[j]);
}
return a(real_indices);
}, name, tag);
}
}
/*!
......
......@@ -228,7 +228,7 @@ def split(ary, indices_or_sections, axis=0):
return cpp.split(ary, indices_or_sections, axis)
def take(a, indices, axis=None):
def take(a, indices, axis=None, mode="clip"):
"""Take elements from an array along an axis.
Parameters
......@@ -243,13 +243,18 @@ def take(a, indices, axis=None):
The axis over which to select values. By default,
the flattened input array is used.
mode : str, optional
Specifies how out-of-bound indices will behave.
clip - clip to the range (default)
wrap - wrap around the indices
Returns
-------
ret : tvm.Tensor
"""
if axis is None:
return cpp.take(a, indices)
return cpp.take(a, indices, int(axis))
return cpp.take(a, indices, mode)
return cpp.take(a, indices, int(axis), mode)
def gather_nd(a, indices):
......
......@@ -297,11 +297,13 @@ TVM_REGISTER_GLOBAL("topi.layout_transform")
TVM_REGISTER_GLOBAL("topi.take")
.set_body([](TVMArgs args, TVMRetValue *rv) {
if (args.size() == 2) {
*rv = take(args[0], args[1]);
if (args.size() == 3) {
std::string mode = args[2];
*rv = take(args[0], args[1], mode);
} else {
int axis = args[2];
*rv = take(args[0], args[1], axis);
std::string mode = args[3];
*rv = take(args[0], args[1], axis, mode);
}
});
......
......@@ -232,16 +232,16 @@ def verify_flip(in_shape, axis):
for device in ["llvm", "cuda", "opencl", "sdaccel", "aocl_sw_emu"]:
check_device(device)
def verify_take(src_shape, indices_src, axis=None):
def verify_take(src_shape, indices_src, axis=None, mode="clip"):
src_dtype = "float32"
indices_dtype = "int32"
indices_src = np.array(indices_src, dtype=indices_dtype)
A = tvm.placeholder(shape=src_shape, dtype=src_dtype, name="A")
indices = tvm.placeholder(shape=indices_src.shape, dtype=indices_dtype, name="indices")
if axis is None:
out_tensor = topi.take(a=A, indices=indices)
out_tensor = topi.take(a=A, indices=indices, mode=mode)
else:
out_tensor = topi.take(a=A, indices=indices, axis=axis)
out_tensor = topi.take(a=A, indices=indices, axis=axis, mode=mode)
def check_device(device):
ctx = tvm.context(device, 0)
......@@ -259,9 +259,9 @@ def verify_take(src_shape, indices_src, axis=None):
data_npy = np.arange(shape_size, dtype=src_dtype).reshape((src_shape))
if axis is None:
out_npys = np.take(data_npy, indices_src)
out_npys = np.take(data_npy, indices_src, mode=mode)
else:
out_npys = np.take(data_npy, indices_src, axis=axis)
out_npys = np.take(data_npy, indices_src, axis=axis, mode=mode)
data_nd = tvm.nd.array(data_npy, ctx)
indices_nd = tvm.nd.array(indices_src, ctx)
out_nd = tvm.nd.empty(out_npys.shape, ctx=ctx, dtype=src_dtype)
......@@ -498,6 +498,12 @@ def test_take():
verify_take((2,2), [[[1,0],[0,1]]], 0)
verify_take((2,2), [[[1,0],[0,1]]], 1)
verify_take((4,3,5,6), [[2,1,0,0]], -2)
verify_take((3,4), [-5, 20])
verify_take((3,4), [-5, 20], mode="wrap")
verify_take((3,4), [-1, 2], axis=0)
verify_take((3,4), [-1, 2], axis=0, mode="wrap")
verify_take((3,4), [-1, 2], axis=1)
verify_take((3,4), [-1, 2], axis=1, mode="wrap")
def test_gather_nd():
for indices_dtype in ['int32', 'float32']:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment