Commit 2c41fd2f by hlu1 Committed by Haichen Shen

[Topi] Fast mode in take op (#3325)

parent d4ca627a
......@@ -101,7 +101,8 @@ struct TakeAttrs : public tvm::AttrsNode<TakeAttrs> {
TVM_ATTR_FIELD(mode).set_default("clip")
.describe("Specify how out-of-bound indices will behave."
"clip - clip to the range (default)"
"wrap - wrap around the indices");
"wrap - wrap around the indices"
"fast - no clip or wrap around (user must make sure indices are in-bound)");
}
};
......
......@@ -218,9 +218,10 @@ def take(data, indices, axis=None, mode="clip"):
the flattened input array is used.
mode : str, optional
Specifies how out-of-bound indices will behave [clip, wrap].
Specifies how out-of-bound indices will behave [clip, wrap, fast].
clip: clip to the range (default).
wrap: wrap around the indices.
fast: no clip or wrap around (user must make sure indices are in-bound).
Returns
-------
......
......@@ -269,7 +269,8 @@ def test_take():
func = relay.Function([x, indices], z)
x_data = np.random.uniform(low=-1, high=1, size=src_shape).astype(src_dtype)
ref_res = np.take(x_data, indices=indices_src, axis=axis, mode=mode)
np_mode = "raise" if mode == "fast" else mode
ref_res = np.take(x_data, indices=indices_src, axis=axis, mode=np_mode)
for target, ctx in ctx_list():
for kind in ["graph", "debug"]:
......@@ -291,6 +292,9 @@ def test_take():
verify_take((3,4), [-1, 2], axis=0, mode="wrap")
verify_take((3,4), [-1, 2], axis=1)
verify_take((3,4), [-1, 2], axis=1, mode="wrap")
verify_take((3,3,3), [[11,25]], mode="fast")
verify_take((3,4), [0, 2], axis=0, mode="fast")
verify_take((3,4), [0, 2], axis=1, mode="fast")
def test_split_infer_type():
......
......@@ -641,6 +641,13 @@ inline Tensor take(const Tensor& a,
auto idx = tvm::min(tvm::max(0, indices(out_index)), a_size - 1);
return a(UnravelIndex(idx, a_shape));
}, name, tag);
} else if (mode == "fast") {
LOG(WARNING) << "Fast mode segfaults when there are out-of-bounds indices. "
"Make sure input indices are in bound";
return compute(
out_shape, [&](const Array<Var>& out_index) {
return a(UnravelIndex(indices(out_index), a_shape));
}, name, tag);
} else { // mode == "wrap"
return compute(
out_shape, [&](const Array<Var>& out_index) {
......@@ -706,6 +713,25 @@ inline Tensor take(const Tensor& a,
}
return a(real_indices);
}, name, tag);
} else if (mode == "fast") {
LOG(WARNING) << "Fast mode segfaults when there are out-of-bounds indices. "
"Make sure input indices are in bound";
return compute(
out_shape, [&](const Array<Var>& out_index) {
Array<Expr> indices_position;
for (size_t j = axis; j < static_cast<size_t>(axis+indices_len); ++j) {
indices_position.push_back(out_index[j]);
}
Array<Expr> real_indices;
for (size_t j = 0; j < static_cast<size_t>(axis); ++j) {
real_indices.push_back(out_index[j]);
}
real_indices.push_back(indices(indices_position));
for (size_t j = axis + indices_len; j < out_index.size(); ++j) {
real_indices.push_back(out_index[j]);
}
return a(real_indices);
}, name, tag);
} else { // mode == "wrap"
return compute(
out_shape, [&](const Array<Var>& out_index) {
......
......@@ -265,6 +265,7 @@ def take(a, indices, axis=None, mode="clip"):
Specifies how out-of-bound indices will behave.
clip - clip to the range (default)
wrap - wrap around the indices
fast - no clip or wrap around (user must make sure indices are in-bound)
Returns
-------
......
......@@ -275,9 +275,11 @@ def verify_take(src_shape, indices_src, axis=None, mode="clip"):
data_npy = np.arange(shape_size, dtype=src_dtype).reshape((src_shape))
if axis is None:
out_npys = np.take(data_npy, indices_src, mode=mode)
np_mode = "raise" if mode == "fast" else mode
out_npys = np.take(data_npy, indices_src, mode=np_mode)
else:
out_npys = np.take(data_npy, indices_src, axis=axis, mode=mode)
np_mode = "raise" if mode == "fast" else mode
out_npys = np.take(data_npy, indices_src, axis=axis, mode=np_mode)
data_nd = tvm.nd.array(data_npy, ctx)
indices_nd = tvm.nd.array(indices_src, ctx)
out_nd = tvm.nd.empty(out_npys.shape, ctx=ctx, dtype=src_dtype)
......@@ -521,6 +523,9 @@ def test_take():
verify_take((3,4), [-1, 2], axis=0, mode="wrap")
verify_take((3,4), [-1, 2], axis=1)
verify_take((3,4), [-1, 2], axis=1, mode="wrap")
verify_take((3,3,3), [[11,25]], mode="fast")
verify_take((3,4), [0, 2], axis=0, mode="fast")
verify_take((3,4), [0, 2], axis=1, mode="fast")
def test_gather_nd():
for indices_dtype in ['int32', 'float32']:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment