Commit 6d88c987 by abergeron Committed by Tianqi Chen

[TOPI][Relay][OP] Add a strided_set operation. (#4303)

parent e3eff20d
...@@ -48,6 +48,7 @@ _reg.register_schedule("cast", schedule_injective) ...@@ -48,6 +48,7 @@ _reg.register_schedule("cast", schedule_injective)
_reg.register_schedule("cast_like", schedule_injective) _reg.register_schedule("cast_like", schedule_injective)
_reg.register_schedule("reinterpret", schedule_injective) _reg.register_schedule("reinterpret", schedule_injective)
_reg.register_schedule("strided_slice", schedule_injective) _reg.register_schedule("strided_slice", schedule_injective)
_reg.register_schedule("strided_set", schedule_injective)
_reg.register_schedule("slice_like", schedule_injective) _reg.register_schedule("slice_like", schedule_injective)
_reg.register_schedule("split", schedule_injective) _reg.register_schedule("split", schedule_injective)
_reg.register_schedule("take", schedule_injective) _reg.register_schedule("take", schedule_injective)
...@@ -304,6 +305,11 @@ def compute_argwhere(attrs, inputs, output_type, _): ...@@ -304,6 +305,11 @@ def compute_argwhere(attrs, inputs, output_type, _):
new_output_type = tvm.relay.ty.TensorType(output_shape, "int32") new_output_type = tvm.relay.ty.TensorType(output_shape, "int32")
return [topi.argwhere(new_output_type, inputs[0])] return [topi.argwhere(new_output_type, inputs[0])]
@_reg.register_compute("strided_set")
def compute_strided_set(attrs, inputs, output_type, _):
"""Compute definition of strided_set"""
return [topi.strided_set(inputs[0], inputs[1], inputs[2], inputs[3], inputs[4])]
@script @script
def _layout_transform_shape_func(data_shape, def _layout_transform_shape_func(data_shape,
out_layout_len, out_layout_len,
......
...@@ -631,6 +631,36 @@ def strided_slice(data, begin, end, strides=None): ...@@ -631,6 +631,36 @@ def strided_slice(data, begin, end, strides=None):
return _make.strided_slice(data, list(begin), list(end), list(strides)) return _make.strided_slice(data, list(begin), list(end), list(strides))
def strided_set(data, v, begin, end, strides=None):
"""Strided set of an array.
Parameters
----------
data : relay.Expr
The source array to be sliced.
v : relay.Expr
The data to be set.
begin: relay.Expr
The indices to begin with in the slicing.
end: relay.Expr
Indices indicating end of the slice.
strides: relay.Expr, optional
Specifies the stride values, it can be negative in that case,
the input tensor will be reversed in that particular axis.
Returns
-------
ret : relay.Expr
The computed result.
"""
strides = strides or const([1], dtype="int32")
return _make.strided_set(data, v, begin, end, strides)
def slice_like(data, shape_like, axes=None): def slice_like(data, shape_like, axes=None):
"""Slice the first input with respect to the second input. """Slice the first input with respect to the second input.
......
...@@ -2049,6 +2049,54 @@ Examples:: ...@@ -2049,6 +2049,54 @@ Examples::
.set_attr<TOpPattern>("TOpPattern", kInjective) .set_attr<TOpPattern>("TOpPattern", kInjective)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", StridedSliceInferCorrectLayout); .set_attr<FInferCorrectLayout>("FInferCorrectLayout", StridedSliceInferCorrectLayout);
// strided_set
bool StridedSetRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 6);
reporter->Assign(types[5], types[0]);
return true;
}
Expr MakeStridedSet(Expr data,
Expr v,
Expr begin,
Expr end,
Expr strides) {
static const Op& op = Op::Get("strided_set");
return CallNode::make(op, {data, v, begin, end, strides}, {});
}
TVM_REGISTER_API("relay.op._make.strided_set")
.set_body_typed(MakeStridedSet);
RELAY_REGISTER_OP("strided_set")
.describe(R"code(Strided set of an array.
Example::
x = [[ 1., 4., 7., 10.],
[ 2., 5., 8., 11.],
[ 3., 6., 9., 12.]]
v = [[ 11., 22., 33.]
[ 44., 55., 66.]]
strided_set(x, v, begin=[0, 1], end=[2, 4], stride=[1, 1]) = \
[[ 1., 11., 22., 33.],
[ 2., 44., 55., 66.],
[ 3., 6., 9., 12.]]
)code" TVM_ADD_FILELINE)
.set_num_inputs(5)
.add_argument("data", "Tensor", "The input tensor.")
.add_argument("v", "Tensor", "The data to set.")
.add_argument("begin", "Tensor", "Indices for the start of the slice.")
.add_argument("end", "Tensor", "Indices indicating the end of the slice.")
.add_argument("strides", "Tensor", "The strides values.")
.set_support_level(4)
.set_attr<TOpPattern>("TOpPattern", kInjective)
.add_type_rel("StridedSet", StridedSetRel);
// relay.split // relay.split
TVM_REGISTER_NODE_TYPE(SplitAttrs); TVM_REGISTER_NODE_TYPE(SplitAttrs);
......
...@@ -300,8 +300,48 @@ def test_strided_slice(): ...@@ -300,8 +300,48 @@ def test_strided_slice():
verify((3, 4, 3), [1, 1], [4, 4, 3], None, (2, 3, 3)) verify((3, 4, 3), [1, 1], [4, 4, 3], None, (2, 3, 3))
def test_strided_set():
def verify(dshape, begin, end, strides, vshape, test_ref=True):
x = relay.var("x", relay.TensorType(dshape, "float32"))
v = relay.var("v", relay.TensorType(vshape, "float32"))
begin_c = relay.const(begin, dtype="int32")
end_c = relay.const(end, dtype="int32")
if strides:
strides_c = relay.const(strides, dtype="int32")
z = relay.strided_set(x, v, begin=begin_c, end=end_c, strides=strides_c)
else:
z = relay.strided_set(x, v, begin=begin_c, end=end_c)
func = relay.Function([x, v], z)
func = run_infer_type(func)
text = func.astext()
assert "strided_set" in text
print(text)
assert func.body.checked_type == relay.ty.TensorType(dshape, "float32")
if not test_ref:
return
x_data = np.random.uniform(size=dshape).astype("float32")
v_data = np.random.uniform(size=vshape).astype("float32")
ref_res = topi.testing.strided_set_python(
x_data, v_data, begin, end, strides)
for target, ctx in ctx_list():
intrp = relay.create_executor("graph", ctx=ctx, target=target)
op_res = intrp.evaluate(func)(x_data, v_data)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res)
verify((3, 4, 3), [0, 0, 0], [4, -5, 4], [1, -1, 2], (3, 1, 2))
verify((3, 4, 3), [1, 1, 0], [4, 4, 3], [2, 1, 1], (1, 3, 3))
verify((3, 4, 3), [1, -1, 0], [4, -5, 3], [2, -1, 1], (1, 4, 3))
verify((3, 4, 3), [1, 0, 0], [2, 2, 3], [1, 1, 2], (1, 2, 2))
verify((3, 4, 3), [1, -1, 0], [2, -3, 3], [1, -1, 1], (1, 2, 3))
verify((3, 4, 3), [1, 1, 0], [4, 4, 3], None, (2, 3, 3))
verify((3, 4, 3), [1, 1, 0], [4, 1000, 3], None, (2, 3, 3))
verify((3, 4, 3), [1, 1, 0], [4, 4], None, (2, 3, 3))
verify((3, 4, 3), [1, 1], [4, 4, 3], None, (2, 3, 3))
if __name__ == "__main__": if __name__ == "__main__":
test_strided_slice() test_strided_slice()
test_strided_set()
test_binary_op() test_binary_op()
test_cmp_type() test_cmp_type()
test_binary_int_broadcast() test_binary_int_broadcast()
......
...@@ -37,7 +37,7 @@ from .roi_pool_python import roi_pool_nchw_python ...@@ -37,7 +37,7 @@ from .roi_pool_python import roi_pool_nchw_python
from .lrn_python import lrn_python from .lrn_python import lrn_python
from .l2_normalize_python import l2_normalize_python from .l2_normalize_python import l2_normalize_python
from .gather_nd_python import gather_nd_python from .gather_nd_python import gather_nd_python
from .strided_slice_python import strided_slice_python from .strided_slice_python import strided_slice_python, strided_set_python
from .batch_matmul import batch_matmul from .batch_matmul import batch_matmul
from .slice_axis_python import slice_axis_python from .slice_axis_python import slice_axis_python
from .sequence_mask_python import sequence_mask from .sequence_mask_python import sequence_mask
......
...@@ -14,7 +14,8 @@ ...@@ -14,7 +14,8 @@
# KIND, either express or implied. See the License for the # KIND, either express or implied. See the License for the
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
"""gather_nd in python""" """strided_slice/set in python"""
def strided_slice_python(data, begin, end, strides): def strided_slice_python(data, begin, end, strides):
"""Python version of strided slice operator. """Python version of strided slice operator.
...@@ -46,3 +47,40 @@ def strided_slice_python(data, begin, end, strides): ...@@ -46,3 +47,40 @@ def strided_slice_python(data, begin, end, strides):
end[i] if i < len(end) else None, end[i] if i < len(end) else None,
strides[i] if i < len(strides) else None)) strides[i] if i < len(strides) else None))
return data[tuple(slices)] return data[tuple(slices)]
def strided_set_python(data, v, begin, end, strides):
"""Python version of strided slice operator.
Parameters
----------
data : numpy.ndarray
Input data
v : numpy.ndarray
Value data
begin : list
Begining of the slices.
end : list
End of the slices.
strides : list
The stride of each slice.
Returns
-------
result : numpy.ndarray
The updated result.
"""
strides = [] if strides is None else strides
slices = []
res = data.copy()
for i in range(len(data.shape)):
slices.append(slice(
begin[i] if i < len(begin) else None,
end[i] if i < len(end) else None,
strides[i] if i < len(strides) else None))
res[tuple(slices)] = v
return res
...@@ -20,6 +20,8 @@ from __future__ import absolute_import as _abs ...@@ -20,6 +20,8 @@ from __future__ import absolute_import as _abs
import tvm import tvm
import topi import topi
from . import cpp from . import cpp
from . import tag
from .util import within_index, make_idx
def expand_dims(a, axis, num_newaxis=1): def expand_dims(a, axis, num_newaxis=1):
...@@ -155,6 +157,97 @@ def strided_slice(a, begin, end, strides=None): ...@@ -155,6 +157,97 @@ def strided_slice(a, begin, end, strides=None):
strides = [] strides = []
return cpp.strided_slice(a, begin, end, strides) return cpp.strided_slice(a, begin, end, strides)
@tvm.tag_scope(tag=tag.INJECTIVE+",strided_set")
def strided_set(a, v, begin, end, strides=None):
"""Set slice of an array.
Parameters
----------
a : tvm.Tensor
The tensor to be sliced.
v : tvm.Tensor
The values to set
begin: tvm.Tensor
The indices to begin with in the slicing.
end: tvm.Tensor
Indicies indicating end of the slice.
strides: tvm.Tensor, optional
Specifies the stride values, it can be negative
in that case, the input tensor will be reversed
in that particular axis.
Returns
-------
ret : tvm.Tensor
"""
n = len(a.shape)
if len(begin.shape) != 1:
raise ValueError("begin should be a vector")
if not begin.dtype == 'int32':
raise TypeError("begin should be int32")
if len(end.shape) != 1:
raise ValueError("end should be a vector")
if not end.dtype == 'int32':
raise TypeError("end should be int32")
if strides is not None:
if len(strides.shape) != 1:
raise ValueError("strides should be a vector")
if not strides.dtype == 'int32':
raise TypeError("strides should be int32")
def _max(a, b):
return tvm.expr.Select(a > b, a, b)
if strides is None:
strides = [tvm.const(1, 'int32')] * n
else:
strides = [tvm.if_then_else(strides.shape[0] > i,
strides[i],
tvm.const(1, 'int32'))
for i in range(n)]
begin = [tvm.if_then_else(begin.shape[0] > i,
begin[i],
tvm.expr.Select(strides[i] > 0,
tvm.const(0, 'int32'),
a.shape[i]))
for i in range(n)]
end = [tvm.if_then_else(end.shape[0] > i,
end[i],
tvm.expr.Select(strides[i] > 0,
a.shape[i] + 1,
-(a.shape[i] + 1)))
for i in range(n)]
# Convert negative indexes
for i in range(n):
begin[i] = tvm.if_then_else(begin[i] < 0,
begin[i] + a.shape[i],
begin[i])
end[i] = tvm.if_then_else(end[i] < 0,
end[i] + a.shape[i],
end[i])
def _select(*indices):
from_val = []
index_tuple = []
for i in range(n):
from_val.append(
within_index(begin[i], end[i], strides[i], indices[i]))
index_tuple.append(
make_idx(begin[i], end[i], strides[i], a.shape[i], indices[i]))
return tvm.if_then_else(tvm.all(*from_val),
v(*index_tuple),
a(*indices))
return tvm.compute(a.shape, _select, name="strided_set")
def reshape(a, newshape): def reshape(a, newshape):
"""Reshape the array """Reshape the array
......
...@@ -345,3 +345,75 @@ def get_shape(src_shape, src_layout, dst_layout): ...@@ -345,3 +345,75 @@ def get_shape(src_shape, src_layout, dst_layout):
tvm.convert([i for i in range(len(src_layout))])) tvm.convert([i for i in range(len(src_layout))]))
return get_const_tuple(tuple([src_shape[i.value] for i in dst_indices])) return get_const_tuple(tuple([src_shape[i.value] for i in dst_indices]))
def within_index(b, e, s, i):
"""Return a boolean value that indicates if i is within the given index.
Parameter
---------
b : Expr
beginning of the index
e : Expr
end of the index
s : Expr
strides of index
i : Expr
array position
Returns
-------
selected: Expr
bool expression that is True is the array position would be selected
by the index and False otherwise
"""
bc = tvm.expr.Select(s < 0, i <= e, i < b)
ec = tvm.expr.Select(s < 0, i > b, i >= e)
ss = tvm.if_then_else(s < 0,
((i - e) + (e % tvm.abs(s)) + 1) % tvm.abs(s),
(i - b) % s)
return tvm.expr.Select(tvm.expr.Or(bc, ec), tvm.const(False), ss.equal(0))
def make_idx(b, e, s, z, i):
"""Return the array position in the selection that corresponds to an
array position in the full array.
The returned value is only meaningful if within_index() returns True
for the same set of parameters.
Parameter
---------
b : Expr
beginning of the index
e : Expr
end of the index
s : Expr
strides of index
z : Expr
size of the indexed dimension
i : Expr
array position
Returns
-------
postion: Expr
int expression that corresponds to an array position in the selection.
"""
bc = tvm.expr.Select(s < 0, i <= e, i < b)
ec = tvm.expr.Select(s < 0, i > b, i >= e)
# Clamp to array size
b = tvm.expr.Select(z < b, z - 1, b)
ss = tvm.if_then_else(s < 0,
(b - i) // tvm.abs(s),
(i - b) // s)
return tvm.if_then_else(tvm.expr.Or(bc, ec), 88, ss)
...@@ -342,6 +342,52 @@ def verify_strided_slice(in_shape, begin, end, strides=None): ...@@ -342,6 +342,52 @@ def verify_strided_slice(in_shape, begin, end, strides=None):
for device in ["llvm", "opencl", "sdaccel", "aocl_sw_emu"]: for device in ["llvm", "opencl", "sdaccel", "aocl_sw_emu"]:
check_device(device) check_device(device)
def verify_strided_set(in_shape, v_shape, begin, end, strides=None):
A = tvm.placeholder(shape=in_shape, name="A")
V = tvm.placeholder(shape=v_shape, name="V")
b = tvm.placeholder(shape=(len(begin),), name="b", dtype='int32')
e = tvm.placeholder(shape=(len(end),), name="e", dtype='int32')
if strides is not None:
st = tvm.placeholder(shape=(len(strides),), name="st", dtype='int32')
B = topi.strided_set(A, V, b, e, st) + 1
else:
B = topi.strided_set(A, V, b, e) + 1
def check_device(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
with tvm.target.create(device):
s = topi.generic.schedule_injective(B)
if strides is not None:
foo = tvm.build(s, [A, V, b, e, st, B], device, name="stride_set")
s_np = np.asarray(strides).astype('int32')
s_nd = tvm.nd.array(s_np, ctx)
else:
foo = tvm.build(s, [A, V, b, e, B], device, name="stride_set")
x_np = np.random.uniform(size=in_shape).astype(A.dtype)
v_np = np.random.uniform(size=v_shape).astype(V.dtype)
b_np = np.asarray(begin).astype('int32')
e_np = np.asarray(end).astype('int32')
out_npy = topi.testing.strided_set_python(
x_np, v_np, begin, end, strides) + 1
data_nd = tvm.nd.array(x_np, ctx)
v_nd = tvm.nd.array(v_np, ctx)
b_nd = tvm.nd.array(b_np, ctx)
e_nd = tvm.nd.array(e_np, ctx)
out_nd = tvm.nd.empty(out_npy.shape, ctx=ctx, dtype=A.dtype)
if strides is not None:
foo(data_nd, v_nd, b_nd, e_nd, s_nd, out_nd)
else:
foo(data_nd, v_nd, b_nd, e_nd, out_nd)
tvm.testing.assert_allclose(out_nd.asnumpy(), out_npy)
for device in ["llvm", "opencl", "sdaccel", "aocl_sw_emu"]:
check_device(device)
def verify_gather_nd(src_shape, indices_src, indices_dtype): def verify_gather_nd(src_shape, indices_src, indices_dtype):
src_dtype = "float32" src_dtype = "float32"
indices_src = np.array(indices_src, dtype=indices_dtype) indices_src = np.array(indices_src, dtype=indices_dtype)
...@@ -510,6 +556,17 @@ def test_strided_slice(): ...@@ -510,6 +556,17 @@ def test_strided_slice():
verify_strided_slice((3, 4, 3), [1, -1, 0], [2, -3, 3], [1, -1, 1]) verify_strided_slice((3, 4, 3), [1, -1, 0], [2, -3, 3], [1, -1, 1])
verify_strided_slice((3, 4, 3), [1, 1, 0], [4, 4, 3]) verify_strided_slice((3, 4, 3), [1, 1, 0], [4, 4, 3])
def test_strided_set():
verify_strided_set((3, 4, 3), (3, 2, 2), [0, 3, 0], [4, 1, 4], [1, -1, 2])
verify_strided_set((3, 4, 3), (3, 1, 2), [0, 0, 0], [4, -5, 4], [1, -1, 2])
verify_strided_set((3, 4, 3), (1, 3, 3), [1, 1, 0], [4, 4, 3], [2, 1, 1])
verify_strided_set((3, 4, 3), (1, 4, 3), [1, -1, 0], [4, -5, 3], [2, -1, 1])
verify_strided_set((3, 4, 3), (1, 2, 2), [1, 0, 0], [2, 2, 3], [1, 1, 2])
verify_strided_set((3, 4, 3), (1, 2, 3), [1, -1, 0], [2, -3, 3], [1, -1, 1])
verify_strided_set((3, 4, 3), (1, 2, 3), [1, 1, 0], [2, 3, 3], [1])
verify_strided_set((3, 4, 3), (2, 3, 3), [1, 1, 0], [4, 4, 3])
verify_strided_set((3, 4, 3), (2, 3, 3), [1, 1], [4, 4, 3])
def test_expand_dims(): def test_expand_dims():
verify_expand_dims((3, 10), (3, 10, 1, 1), 2, 2) verify_expand_dims((3, 10), (3, 10, 1, 1), 2, 2)
verify_expand_dims((3, 10), (1, 3, 10), -3, 1) verify_expand_dims((3, 10), (1, 3, 10), -3, 1)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment