Commit 5d9647e2 by Xingjian Shi Committed by Tianqi Chen

fix squeeze to output (1,) if all axes are squeezed. E.g squeeze((1,1,1...), None) case (#498)

parent e44b38e3
......@@ -114,6 +114,8 @@ def squeeze(a, axis=None):
for i, a_dim in enumerate(a_shape):
if i not in search_axis:
out_shape.append(a_dim)
if not out_shape:
out_shape.append(1)
def _compute(*indices):
real_indices = []
flag = 0
......
......@@ -82,7 +82,11 @@ def verify_squeeze(src_shape, axis):
data_npy = np.random.normal(size=src_shape).astype(A.dtype)
out_npy = np.squeeze(data_npy, axis=axis)
data_nd = tvm.nd.array(data_npy, ctx)
out_nd = tvm.nd.empty(out_npy.shape, ctx=ctx, dtype=B.dtype)
if out_npy.shape == ():
out_nd_shape = (1,)
else:
out_nd_shape = out_npy.shape
out_nd = tvm.nd.empty(out_nd_shape, ctx=ctx, dtype=B.dtype)
foo(data_nd, out_nd)
np.testing.assert_allclose(out_nd.asnumpy(), out_npy)
......@@ -159,6 +163,7 @@ def test_squeeze():
verify_squeeze((1, 2, 3, 4), 0)
verify_squeeze((1, 2, 1, 4), None)
verify_squeeze((1, 1, 1, 4), (1, 2))
verify_squeeze((1, 1, 1, 1), None)
def test_concatenate():
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment