Commit a3cbefd8 by Xingjian Shi Committed by Tianqi Chen

[TOPI] add reshape, concatenate, split (#481)

* [TOPI]add reshape

* fix problems

* fix lint

* try to add concatenate

* fix lint and error

* fix doc

* fix error

* try to add split

* fix lint

* fix error

* fix lint
parent 7ec58675
......@@ -80,7 +80,7 @@ def _get_binary_op_bcast_shape(lhs_shape, rhs_shape):
@tvm.tag_scope(tag="broadcast_to")
@tvm.tag_scope(tag=tag.BROADCAST)
def broadcast_to(data, shape):
"""Broadcast the src to the target shape
......@@ -97,20 +97,20 @@ def broadcast_to(data, shape):
-------
ret : tvm.Tensor
"""
def _bcast_to_arg_eval(data, bcast_info, *args):
def _bcast_to_arg_eval(data, bcast_info, *indices):
indices_tuple = []
for i in range(len(args)):
for i, ind in enumerate(indices):
if bcast_info[i] == 0:
indices_tuple.append(args[i])
indices_tuple.append(ind)
elif bcast_info[i] == 1:
indices_tuple.append(0)
return data[tuple(indices_tuple)]
original_shape = data.shape
bcast_info = _get_bcast_info(original_shape=original_shape, target_shape=shape)
ret = tvm.compute([tvm.convert(ele) for ele in shape],
lambda *args: _bcast_to_arg_eval(data,
ret = tvm.compute(shape,
lambda *indices: _bcast_to_arg_eval(data,
bcast_info,
*args), name=data.name + "_broadcast")
*indices), name=data.name + "_broadcast")
return ret
......@@ -131,25 +131,25 @@ def broadcast_binary_op(lhs, rhs, func, name="bop"):
-------
ret : tvm.Tensor
"""
def _inner_arg_eval(lhs, rhs, lhs_bcast_info, rhs_bcast_info, func, *args):
def _inner_arg_eval(lhs, rhs, lhs_bcast_info, rhs_bcast_info, func, *indices):
lhs_indices = []
rhs_indices = []
for i in range(len(args)):
for i, ind in enumerate(indices):
if lhs_bcast_info[i] == 0:
lhs_indices.append(args[i])
lhs_indices.append(ind)
elif lhs_bcast_info[i] == 1:
lhs_indices.append(0)
if rhs_bcast_info[i] == 0:
rhs_indices.append(args[i])
rhs_indices.append(ind)
elif rhs_bcast_info[i] == 1:
rhs_indices.append(0)
return func(lhs[tuple(lhs_indices)], rhs[tuple(rhs_indices)])
ret_shape = _get_binary_op_bcast_shape(get_const_tuple(lhs.shape), get_const_tuple(rhs.shape))
lhs_bcast_info = _get_bcast_info(original_shape=lhs.shape, target_shape=ret_shape)
rhs_bcast_info = _get_bcast_info(original_shape=rhs.shape, target_shape=ret_shape)
ret = tvm.compute([tvm.convert(ele) for ele in ret_shape],
lambda *args: _inner_arg_eval(lhs, rhs, lhs_bcast_info, rhs_bcast_info,
func, *args),
ret = tvm.compute(ret_shape,
lambda *indices: _inner_arg_eval(lhs, rhs, lhs_bcast_info, rhs_bcast_info,
func, *indices),
name=lhs.name + "_" + rhs.name + "_" + name)
return ret
......
......@@ -30,7 +30,9 @@ def schedule_injective(outs):
s = tvm.create_schedule([x.op for x in outs])
tvm.schedule.AutoInlineInjective(s)
return _schedule_injective(outs[0].op, s)
for out in outs:
_schedule_injective(out.op, s)
return s
schedule_elemwise = schedule_injective
schedule_broadcast = schedule_injective
# pylint: disable=invalid-name,consider-using-enumerate
"""Injective transformation operators"""
from __future__ import absolute_import as _abs
import tvm
from . import tag
from .util import ravel_index, unravel_index, get_const_int
@tvm.tag_scope(tag=tag.BROADCAST)
def expand_dims(a, axis, num_newaxis=1):
......@@ -52,3 +54,109 @@ def transpose(a, axes=None):
idx[k] = indices[i]
return a(*idx)
return tvm.compute(new_shape, _compute)
@tvm.tag_scope(tag=tag.INJECTIVE)
def reshape(a, newshape):
"""Reshape the array
Parameters
----------
a : tvm.Tensor
The tensor to be reshaped
newshape : tuple of ints
The new shape
Returns
-------
ret : tvm.Tensor
"""
ndim = len(a.shape)
a_shape = [a.shape[i] for i in range(ndim)]
return tvm.compute(newshape,
lambda *indices: a(*unravel_index(ravel_index(indices, newshape), a_shape)))
@tvm.tag_scope(tag=tag.INJECTIVE)
def concatenate(a_tuple, axis=0):
"""Join a sequence of arrays along an existing axis.
Parameters
----------
a_tuple : tuple of tvm.Tensor
The arrays to concatenate
axis : int, optional
The axis along which the arrays will be joined. Default is 0.
Returns
-------
ret : tvm.Tensor
"""
assert isinstance(a_tuple, (list, tuple))
if axis < 0:
axis += len(a_tuple[0].shape)
assert axis < len(a_tuple[0].shape)
axis_sizes = [a_tuple[i].shape[axis] for i in range(len(a_tuple))]
out_shape = [a_tuple[0].shape[i] for i in range(0, axis)] + [sum(axis_sizes)]\
+ [a_tuple[0].shape[i] for i in range(axis + 1, len(a_tuple[0].shape))]
def _compute(*indices):
ret = a_tuple[0](*indices)
ind = indices[axis]
for i in range(len(a_tuple) - 1):
ind -= axis_sizes[i]
ret = tvm.select(ind >= 0,
a_tuple[i + 1](*(indices[0:axis] + (ind,) + indices[axis + 1:])),
ret)
return ret
return tvm.compute(out_shape, _compute)
@tvm.tag_scope(tag=tag.INJECTIVE)
def split(ary, indices_or_sections, axis=0):
"""Split an array into multiple sub-arrays.
Parameters
----------
ary : tvm.Tensor
indices_or_sections : int or 1-D array
axis : int
Returns
-------
ret : tuple of tvm.Tensor
"""
def _compute(begin, *indices):
real_indices = indices[:axis] + (indices[axis] + begin, ) + indices[axis + 1:]
return ary(*real_indices)
if axis < 0:
axis += len(ary.shape)
src_axis_size = get_const_int(ary.shape[axis])
if isinstance(indices_or_sections, int):
assert indices_or_sections > 0
assert src_axis_size % indices_or_sections == 0
seg_size = src_axis_size // indices_or_sections
begin_ids = [seg_size * i for i in range(indices_or_sections)]
elif isinstance(indices_or_sections, (tuple, list)):
assert tuple(indices_or_sections) == tuple(sorted(indices_or_sections)),\
"Should be sorted, recieved %s" %str(indices_or_sections)
begin_ids = [0] + list(indices_or_sections)
else:
raise NotImplementedError
out_shapes = []
for i in range(len(begin_ids)):
if i == len(begin_ids) - 1:
out_axis_size = src_axis_size - begin_ids[i]
else:
out_axis_size = begin_ids[i + 1] - begin_ids[i]
out_shapes.append([ary.shape[i] for i in range(axis)] + [out_axis_size] +\
[ary.shape[i] for i in range(axis + 1, len(ary.shape))])
# pylint: disable=cell-var-from-loop
return [tvm.compute(out_shape,
lambda *indices: _compute(begin_id, *indices), name="s%d" %i)
for i, (out_shape, begin_id) in enumerate(zip(out_shapes, begin_ids))]
# pylint: enable=cell-var-from-loop
......@@ -79,3 +79,52 @@ def simplify(expr):
The simplified output
"""
return tvm.ir_pass.Simplify(expr) if isinstance(expr, tvm.expr.Expr) else expr
def ravel_index(indices, shape):
"""Flatten the index tuple to 1D
Parameters
----------
indices : tuple of int or tvm.expr.IntImm
The input coordinates
shape : tuple of int
Shape of the tensor.
Returns
-------
idx : int or Expr
The index after flattening
"""
idx = None
for i, (shape_val, ind) in enumerate(zip(shape, indices)):
if i != 0:
idx = idx * shape_val + ind
else:
idx = ind
return idx
def unravel_index(idx, shape):
"""Convert the flattened ind to the coordinate array
Parameters
----------
idx : int or tvm.expr.IntImm
The 1D index
shape : tuple of int
Shape of the tensor
Returns
-------
indices : tuple of int or tvm.expr.IntImm
Corresponding coordinate of the 1D index
"""
indices = []
for i in range(len(shape) - 1, -1, -1):
indices.append(idx % shape[i])
idx = idx // shape[i]
indices = indices[::-1]
return indices
......@@ -47,6 +47,74 @@ def verify_tranpose(in_shape, axes):
check_device("metal")
def verify_reshape(src_shape, dst_shape):
A = tvm.placeholder(shape=src_shape, name="A")
B = topi.reshape(A, dst_shape)
s = topi.cuda.schedule_injective(B)
def check_device(device):
if not tvm.module.enabled(device):
print("Skip because %s is not enabled" % device)
return
ctx = tvm.gpu(0) if device == "cuda" else tvm.cl(0)
foo = tvm.build(s, [A, B], device, name="reshape")
data_npy = np.random.normal(size=src_shape).astype(A.dtype)
out_npy = np.reshape(data_npy, newshape=dst_shape)
data_nd = tvm.nd.array(data_npy, ctx)
out_nd = tvm.nd.empty(dst_shape, ctx=ctx, dtype=B.dtype)
foo(data_nd, out_nd)
np.testing.assert_allclose(out_nd.asnumpy(), out_npy)
check_device("cuda")
check_device("opencl")
check_device("metal")
def verify_concatenate(shapes, axis):
tensor_l = []
for i, shape in enumerate(shapes):
tensor_l.append(tvm.placeholder(shape, name="A" + str(i)))
out_tensor = topi.concatenate(a_tuple=tensor_l, axis=axis)
s = topi.cuda.schedule_injective(out_tensor)
def check_device(device):
if not tvm.module.enabled(device):
print("Skip because %s is not enabled" % device)
return
ctx = tvm.gpu(0) if device == "cuda" else tvm.cl(0)
foo = tvm.build(s, tensor_l + [out_tensor], device, name="concatenate")
data_npys = [np.random.normal(size=shape).astype(tensor_l[0].dtype) for shape in shapes]
out_npy = np.concatenate(data_npys, axis=axis)
data_nds = [tvm.nd.array(data_npy, ctx) for data_npy in data_npys]
out_nd = tvm.nd.empty(out_npy.shape, ctx=ctx, dtype=out_tensor.dtype)
foo(*(data_nds + [out_nd]))
np.testing.assert_allclose(out_nd.asnumpy(), out_npy)
check_device("cuda")
check_device("opencl")
check_device("metal")
def verify_split(src_shape, indices_or_sections, axis):
A = tvm.placeholder(shape=src_shape, name="A")
tensor_l = topi.split(A, indices_or_sections, axis=axis)
s = topi.cuda.schedule_injective(tensor_l)
def check_device(device):
if not tvm.module.enabled(device):
print("Skip because %s is not enabled" % device)
return
ctx = tvm.gpu(0) if device == "cuda" else tvm.cl(0)
foo = tvm.build(s, [A] + tensor_l, device, name="split")
data_npy = np.random.normal(size=src_shape).astype(A.dtype)
out_npys = np.split(data_npy, indices_or_sections, axis=axis)
data_nd = tvm.nd.array(data_npy, ctx)
out_nds = [tvm.nd.empty(out_npy.shape, ctx=ctx, dtype=tensor_l[0].dtype) for out_npy in out_npys]
foo(*([data_nd] + out_nds))
for out_nd, out_npy in zip(out_nds, out_npys):
np.testing.assert_allclose(out_nd.asnumpy(), out_npy)
check_device("cuda")
check_device("opencl")
check_device("metal")
def test_expand_dims():
verify_expand_dims((3, 10), (3, 10, 1, 1), 2, 2)
verify_expand_dims((3, 10), (1, 3, 10), -3, 1)
......@@ -58,6 +126,32 @@ def test_tranpose():
verify_tranpose((3, 10), None)
def test_reshape():
verify_reshape((1, 2, 3, 4), (2, 3, 4))
verify_reshape((4, 2, 3, 4), (2, 4, 12))
verify_reshape((4, 2, 3, 4), (2, 48))
verify_reshape((16, ), (2, 2, 2, 2))
def test_concatenate():
verify_concatenate([(2, 3, 4), (2, 2, 4), (2, 5, 4)], 1)
verify_concatenate([(1, 2, 4), (1, 2, 3), (1, 2, 7), (1, 2, 8), (1, 2, 1)], -1)
verify_concatenate([(5, 6, 7, 3),
(16, 6, 7, 3),
(12, 6, 7, 3),
(8, 6, 7, 3),
(2, 6, 7, 3)], 0)
def test_split():
verify_split((2, 12, 3), 3, 1)
verify_split((2, 12, 3), [2, 4], 1)
verify_split((10, 12, 24), [5, 7, 9], -1)
if __name__ == "__main__":
test_tranpose()
test_expand_dims()
test_reshape()
test_concatenate()
test_split()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment