Commit 1e9d014b by Altan Haan Committed by Tianqi Chen

[Relay] Fix reduce axis bug (#3422)

* fix relay reduce axis bug

* add tests for reduce bug
parent 7db5779f
......@@ -107,7 +107,7 @@ def sum(data, axis=None, keepdims=False, exclude=False):
result : relay.Expr
The computed result.
"""
axis = [axis] if axis and isinstance(axis, int) else axis
axis = [axis] if isinstance(axis, int) else axis
return _make.sum(data, axis, keepdims, exclude)
......@@ -159,7 +159,7 @@ def all(data, axis=None, keepdims=False, exclude=False):
# [False, True, False]]
"""
axis = [axis] if axis and isinstance(axis, int) else axis
axis = [axis] if isinstance(axis, int) else axis
return _make.all(data, axis, keepdims, exclude)
......
......@@ -202,7 +202,9 @@ def test_reduce_functions():
[relay.argmax, _with_keepdims(np.argmax)]]:
verify_reduce(func, (d1, d2, d3, d4), None, False, False, ())
verify_reduce(func, (d1, d2, d3, d4), 2, True, False, (d1, d2, 1, d4))
verify_reduce(func, (d1, d2, d3, d4), 0, True, False, (1, d2, d3, d4))
verify_reduce(func, (d1, d2, d3), 1, True, False, (d1, 1, d3))
verify_reduce(func, (d1, d2, d3), 0, True, False, (1, d2, d3))
verify_reduce(func, (d1, d2, d3), None, True, False, (1, 1, 1))
verify_reduce(func, (d1, d2, d3), (0, 1), True, False, (1, 1, d3))
verify_reduce(func, (2, 3, 4), 1, True, False, (2, 1, 4))
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment