Commit 2a8e0746 by Xingjian Shi Committed by Tianqi Chen

[TOPI]Support dim-0 tensor in topi broadcast/reduce (#731)

* support dim-0 tensor in topi ops

revert transform

* revert
parent 85e4058c
...@@ -107,10 +107,8 @@ def comm_reduce(data, axis=None, keepdims=False, func=tvm.sum, is_idx_reduce=Fal ...@@ -107,10 +107,8 @@ def comm_reduce(data, axis=None, keepdims=False, func=tvm.sum, is_idx_reduce=Fal
ret : tvm.Tensor ret : tvm.Tensor
""" """
ndim = len(data.shape) ndim = len(data.shape)
assert ndim != 0, "Reduce a dim-0 input is not supported!"
real_axis = _get_real_axis(ndim, axis) real_axis = _get_real_axis(ndim, axis)
if real_axis == list(range(ndim)) and keepdims is False:
raise ValueError("Currently we do not support all reduce + keepdims = False!"
" axis={}, keepdims={}".format(axis, keepdims))
reduce_axes = [tvm.reduce_axis((0, data.shape[i]), "k%d" %i) for i in real_axis] reduce_axes = [tvm.reduce_axis((0, data.shape[i]), "k%d" %i) for i in real_axis]
if keepdims: if keepdims:
target_shape = [1 if i in real_axis else data.shape[i] for i in range(ndim)] target_shape = [1 if i in real_axis else data.shape[i] for i in range(ndim)]
......
...@@ -89,12 +89,14 @@ def verify_broadcast_binary_ele(lhs_shape, rhs_shape, typ="add"): ...@@ -89,12 +89,14 @@ def verify_broadcast_binary_ele(lhs_shape, rhs_shape, typ="add"):
def test_broadcast_to(): def test_broadcast_to():
verify_broadcast_to_ele((1,), (10,)) verify_broadcast_to_ele((1,), (10,))
verify_broadcast_to_ele((), (10,))
verify_broadcast_to_ele((1, 1, 5, 4), (3, 4, 4, 4, 5, 4)) verify_broadcast_to_ele((1, 1, 5, 4), (3, 4, 4, 4, 5, 4))
verify_broadcast_to_ele((1, 128, 1, 32), (64, 128, 64, 32)) verify_broadcast_to_ele((1, 128, 1, 32), (64, 128, 64, 32))
def test_broadcast_binary(): def test_broadcast_binary():
verify_broadcast_binary_ele((5, 2, 3), (2, 1), typ="add") verify_broadcast_binary_ele((5, 2, 3), (2, 1), typ="add")
verify_broadcast_binary_ele((5, 2, 3), (), typ="add")
verify_broadcast_binary_ele((5, 64, 128), (2, 5, 64, 1), typ="mul") verify_broadcast_binary_ele((5, 64, 128), (2, 5, 64, 1), typ="mul")
verify_broadcast_binary_ele((2, 3, 1, 32), (64, 32), typ="div") verify_broadcast_binary_ele((2, 3, 1, 32), (64, 32), typ="div")
verify_broadcast_binary_ele((1, 32), (64, 32), typ="sub") verify_broadcast_binary_ele((1, 32), (64, 32), typ="sub")
......
...@@ -108,7 +108,10 @@ def test_reduce_map(): ...@@ -108,7 +108,10 @@ def test_reduce_map():
axis=None, axis=None,
keepdims=True, keepdims=True,
type="argmax") type="argmax")
verify_reduce_map_ele(in_shape=(31, 21, 15),
axis=None,
keepdims=False,
type="sum")
if __name__ == "__main__": if __name__ == "__main__":
test_reduce_map() test_reduce_map()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment