Commit 2c41fd2f by hlu1 Committed by Haichen Shen

[Topi] Fast mode in take op (#3325)

parent d4ca627a
...@@ -101,7 +101,8 @@ struct TakeAttrs : public tvm::AttrsNode<TakeAttrs> { ...@@ -101,7 +101,8 @@ struct TakeAttrs : public tvm::AttrsNode<TakeAttrs> {
TVM_ATTR_FIELD(mode).set_default("clip") TVM_ATTR_FIELD(mode).set_default("clip")
.describe("Specify how out-of-bound indices will behave." .describe("Specify how out-of-bound indices will behave."
"clip - clip to the range (default)" "clip - clip to the range (default)"
"wrap - wrap around the indices"); "wrap - wrap around the indices"
"fast - no clip or wrap around (user must make sure indices are in-bound)");
} }
}; };
......
...@@ -218,9 +218,10 @@ def take(data, indices, axis=None, mode="clip"): ...@@ -218,9 +218,10 @@ def take(data, indices, axis=None, mode="clip"):
the flattened input array is used. the flattened input array is used.
mode : str, optional mode : str, optional
Specifies how out-of-bound indices will behave [clip, wrap]. Specifies how out-of-bound indices will behave [clip, wrap, fast].
clip: clip to the range (default). clip: clip to the range (default).
wrap: wrap around the indices. wrap: wrap around the indices.
fast: no clip or wrap around (user must make sure indices are in-bound).
Returns Returns
------- -------
......
...@@ -269,7 +269,8 @@ def test_take(): ...@@ -269,7 +269,8 @@ def test_take():
func = relay.Function([x, indices], z) func = relay.Function([x, indices], z)
x_data = np.random.uniform(low=-1, high=1, size=src_shape).astype(src_dtype) x_data = np.random.uniform(low=-1, high=1, size=src_shape).astype(src_dtype)
ref_res = np.take(x_data, indices=indices_src, axis=axis, mode=mode) np_mode = "raise" if mode == "fast" else mode
ref_res = np.take(x_data, indices=indices_src, axis=axis, mode=np_mode)
for target, ctx in ctx_list(): for target, ctx in ctx_list():
for kind in ["graph", "debug"]: for kind in ["graph", "debug"]:
...@@ -291,6 +292,9 @@ def test_take(): ...@@ -291,6 +292,9 @@ def test_take():
verify_take((3,4), [-1, 2], axis=0, mode="wrap") verify_take((3,4), [-1, 2], axis=0, mode="wrap")
verify_take((3,4), [-1, 2], axis=1) verify_take((3,4), [-1, 2], axis=1)
verify_take((3,4), [-1, 2], axis=1, mode="wrap") verify_take((3,4), [-1, 2], axis=1, mode="wrap")
verify_take((3,3,3), [[11,25]], mode="fast")
verify_take((3,4), [0, 2], axis=0, mode="fast")
verify_take((3,4), [0, 2], axis=1, mode="fast")
def test_split_infer_type(): def test_split_infer_type():
......
...@@ -641,6 +641,13 @@ inline Tensor take(const Tensor& a, ...@@ -641,6 +641,13 @@ inline Tensor take(const Tensor& a,
auto idx = tvm::min(tvm::max(0, indices(out_index)), a_size - 1); auto idx = tvm::min(tvm::max(0, indices(out_index)), a_size - 1);
return a(UnravelIndex(idx, a_shape)); return a(UnravelIndex(idx, a_shape));
}, name, tag); }, name, tag);
} else if (mode == "fast") {
LOG(WARNING) << "Fast mode segfaults when there are out-of-bounds indices. "
"Make sure input indices are in bound";
return compute(
out_shape, [&](const Array<Var>& out_index) {
return a(UnravelIndex(indices(out_index), a_shape));
}, name, tag);
} else { // mode == "wrap" } else { // mode == "wrap"
return compute( return compute(
out_shape, [&](const Array<Var>& out_index) { out_shape, [&](const Array<Var>& out_index) {
...@@ -706,6 +713,25 @@ inline Tensor take(const Tensor& a, ...@@ -706,6 +713,25 @@ inline Tensor take(const Tensor& a,
} }
return a(real_indices); return a(real_indices);
}, name, tag); }, name, tag);
} else if (mode == "fast") {
LOG(WARNING) << "Fast mode segfaults when there are out-of-bounds indices. "
"Make sure input indices are in bound";
return compute(
out_shape, [&](const Array<Var>& out_index) {
Array<Expr> indices_position;
for (size_t j = axis; j < static_cast<size_t>(axis+indices_len); ++j) {
indices_position.push_back(out_index[j]);
}
Array<Expr> real_indices;
for (size_t j = 0; j < static_cast<size_t>(axis); ++j) {
real_indices.push_back(out_index[j]);
}
real_indices.push_back(indices(indices_position));
for (size_t j = axis + indices_len; j < out_index.size(); ++j) {
real_indices.push_back(out_index[j]);
}
return a(real_indices);
}, name, tag);
} else { // mode == "wrap" } else { // mode == "wrap"
return compute( return compute(
out_shape, [&](const Array<Var>& out_index) { out_shape, [&](const Array<Var>& out_index) {
......
...@@ -265,6 +265,7 @@ def take(a, indices, axis=None, mode="clip"): ...@@ -265,6 +265,7 @@ def take(a, indices, axis=None, mode="clip"):
Specifies how out-of-bound indices will behave. Specifies how out-of-bound indices will behave.
clip - clip to the range (default) clip - clip to the range (default)
wrap - wrap around the indices wrap - wrap around the indices
fast - no clip or wrap around (user must make sure indices are in-bound)
Returns Returns
------- -------
......
...@@ -275,9 +275,11 @@ def verify_take(src_shape, indices_src, axis=None, mode="clip"): ...@@ -275,9 +275,11 @@ def verify_take(src_shape, indices_src, axis=None, mode="clip"):
data_npy = np.arange(shape_size, dtype=src_dtype).reshape((src_shape)) data_npy = np.arange(shape_size, dtype=src_dtype).reshape((src_shape))
if axis is None: if axis is None:
out_npys = np.take(data_npy, indices_src, mode=mode) np_mode = "raise" if mode == "fast" else mode
out_npys = np.take(data_npy, indices_src, mode=np_mode)
else: else:
out_npys = np.take(data_npy, indices_src, axis=axis, mode=mode) np_mode = "raise" if mode == "fast" else mode
out_npys = np.take(data_npy, indices_src, axis=axis, mode=np_mode)
data_nd = tvm.nd.array(data_npy, ctx) data_nd = tvm.nd.array(data_npy, ctx)
indices_nd = tvm.nd.array(indices_src, ctx) indices_nd = tvm.nd.array(indices_src, ctx)
out_nd = tvm.nd.empty(out_npys.shape, ctx=ctx, dtype=src_dtype) out_nd = tvm.nd.empty(out_npys.shape, ctx=ctx, dtype=src_dtype)
...@@ -521,6 +523,9 @@ def test_take(): ...@@ -521,6 +523,9 @@ def test_take():
verify_take((3,4), [-1, 2], axis=0, mode="wrap") verify_take((3,4), [-1, 2], axis=0, mode="wrap")
verify_take((3,4), [-1, 2], axis=1) verify_take((3,4), [-1, 2], axis=1)
verify_take((3,4), [-1, 2], axis=1, mode="wrap") verify_take((3,4), [-1, 2], axis=1, mode="wrap")
verify_take((3,3,3), [[11,25]], mode="fast")
verify_take((3,4), [0, 2], axis=0, mode="fast")
verify_take((3,4), [0, 2], axis=1, mode="fast")
def test_gather_nd(): def test_gather_nd():
for indices_dtype in ['int32', 'float32']: for indices_dtype in ['int32', 'float32']:
......
Markdown is supported
0% Try again or attach a new file
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment