Commit dc7ab96b by Xingjian Shi, committed by Tianqi Chen

[TOPI] add squeeze (#494)

* add squeeze

* should be squeeze
parent fd864c51
@@ -3,7 +3,7 @@
from __future__ import absolute_import as _abs
import tvm
from . import tag
from .util import ravel_index, unravel_index, get_const_int
from .util import ravel_index, unravel_index, get_const_int, get_const_tuple
@tvm.tag_scope(tag=tag.BROADCAST)
def expand_dims(a, axis, num_newaxis=1):
@@ -78,6 +78,57 @@ def reshape(a, newshape):
@tvm.tag_scope(tag=tag.INJECTIVE)
def squeeze(a, axis=None):
"""Remove single-dimensional entries from the shape of an array.
Parameters
----------
a : tvm.Tensor
axis : None or int or tuple of ints, optional
Selects a subset of the single-dimensional entries in the shape.
If an axis is selected with shape entry greater than one, an error is raised.
Returns
-------
squeezed : tvm.Tensor
"""
a_ndim = len(a.shape)
a_shape = get_const_tuple(a.shape)
if axis is None:
axis = []
for i, ele in enumerate(a_shape):
if ele == 1:
axis.append(i)
else:
if isinstance(axis, int):
axis = axis + a_ndim if axis < 0 else axis
assert a_shape[axis] == 1
axis = [axis]
else:
axis = [ele + a_ndim if ele < 0 else ele for ele in axis]
for ele in axis:
assert a_shape[ele] == 1
out_shape = []
search_axis = set(axis)
for i, a_dim in enumerate(a_shape):
if i not in search_axis:
out_shape.append(a_dim)
def _compute(*indices):
real_indices = []
flag = 0
for i in range(a_ndim):
if i not in search_axis:
real_indices.append(indices[i - flag])
else:
real_indices.append(0)
flag += 1
return a(*real_indices)
return tvm.compute(out_shape, _compute)
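
For reference, a minimal CPU usage sketch of the new topi.squeeze (editor's illustration, not part of this commit); the (1, 2, 1, 4) input shape, the "llvm" target, and the unoptimized default schedule are assumptions made for the example:

# Illustrative only: build and run squeeze on CPU with a default schedule.
# Assumes a TVM build with the "llvm" backend enabled.
import numpy as np
import tvm
import topi

A = tvm.placeholder((1, 2, 1, 4), name="A")
B = topi.squeeze(A, axis=None)   # axis=None drops every dim of size 1 -> shape (2, 4)
s = tvm.create_schedule(B.op)
f = tvm.build(s, [A, B], "llvm", name="squeeze")

ctx = tvm.cpu(0)
a_np = np.random.uniform(size=(1, 2, 1, 4)).astype(A.dtype)
a_nd = tvm.nd.array(a_np, ctx)
b_nd = tvm.nd.empty((2, 4), ctx=ctx, dtype=B.dtype)
f(a_nd, b_nd)
np.testing.assert_allclose(b_nd.asnumpy(), np.squeeze(a_np))
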
@tvm.tag_scope(tag=tag.INJECTIVE)
def concatenate(a_tuple, axis=0):
"""Join a sequence of arrays along an existing axis.
@@ -69,6 +69,28 @@ def verify_reshape(src_shape, dst_shape):
check_device("metal")
def verify_squeeze(src_shape, axis):
A = tvm.placeholder(shape=src_shape, name="A")
B = topi.squeeze(A, axis=axis)
s = topi.cuda.schedule_injective(B)
def check_device(device):
if not tvm.module.enabled(device):
print("Skip because %s is not enabled" % device)
return
ctx = tvm.gpu(0) if device == "cuda" else tvm.cl(0)
foo = tvm.build(s, [A, B], device, name="squeeze")
data_npy = np.random.normal(size=src_shape).astype(A.dtype)
out_npy = np.squeeze(data_npy, axis=axis)
data_nd = tvm.nd.array(data_npy, ctx)
out_nd = tvm.nd.empty(out_npy.shape, ctx=ctx, dtype=B.dtype)
foo(data_nd, out_nd)
np.testing.assert_allclose(out_nd.asnumpy(), out_npy)
check_device("cuda")
check_device("opencl")
check_device("metal")
def verify_concatenate(shapes, axis):
tensor_l = []
for i, shape in enumerate(shapes):
@@ -133,6 +155,12 @@ def test_reshape():
    verify_reshape((16, ), (2, 2, 2, 2))


def test_squeeze():
    verify_squeeze((1, 2, 3, 4), 0)
    verify_squeeze((1, 2, 1, 4), None)
    verify_squeeze((1, 1, 1, 4), (1, 2))


def test_concatenate():
    verify_concatenate([(2, 3, 4), (2, 2, 4), (2, 5, 4)], 1)
    verify_concatenate([(1, 2, 4), (1, 2, 3), (1, 2, 7), (1, 2, 8), (1, 2, 1)], -1)
@@ -152,6 +180,7 @@ if __name__ == "__main__":
    test_tranpose()
    test_expand_dims()
    test_reshape()
    test_squeeze()
    test_concatenate()
    test_split()