Commit 94acff30 by Siju Committed by Tianqi Chen

added int type axis for relay reduce ops (#2199)

parent d15477cd
......@@ -30,6 +30,7 @@ def argmax(data, axis=None, keepdims=False, exclude=False):
result : relay.Expr
The computed result.
"""
axis = [axis] if isinstance(axis, int) else axis
return _make.argmax(data, axis, keepdims, exclude)
def argmin(data, axis=None, keepdims=False, exclude=False):
......@@ -59,6 +60,7 @@ def argmin(data, axis=None, keepdims=False, exclude=False):
result : relay.Expr
The computed result.
"""
axis = [axis] if isinstance(axis, int) else axis
return _make.argmin(data, axis, keepdims, exclude)
......@@ -89,6 +91,7 @@ def sum(data, axis=None, keepdims=False, exclude=False):
result : relay.Expr
The computed result.
"""
axis = [axis] if isinstance(axis, int) else axis
return _make.sum(data, axis, keepdims, exclude)
......@@ -119,6 +122,7 @@ def max(data, axis=None, keepdims=False, exclude=False):
result : relay.Expr
The computed result.
"""
axis = [axis] if isinstance(axis, int) else axis
return _make.max(data, axis, keepdims, exclude)
......@@ -149,6 +153,7 @@ def min(data, axis=None, keepdims=False, exclude=False):
result : relay.Expr
The computed result.
"""
axis = [axis] if isinstance(axis, int) else axis
return _make.min(data, axis, keepdims, exclude)
......@@ -179,6 +184,7 @@ def mean(data, axis=None, keepdims=False, exclude=False):
result : relay.Expr
The computed result.
"""
axis = [axis] if isinstance(axis, int) else axis
return _make.mean(data, axis, keepdims, exclude)
......@@ -209,4 +215,5 @@ def prod(data, axis=None, keepdims=False, exclude=False):
result : relay.Expr
The computed result.
"""
axis = [axis] if isinstance(axis, int) else axis
return _make.prod(data, axis, keepdims, exclude)
......@@ -145,7 +145,7 @@ def verify_reduce(funcs, data, axis, keepdims, exclude, output, dtype="float32")
elif ref_func in [np.max, np.min, np.mean, np.prod]:
ref_res = ref_func(x_data + 0, axis=axis, keepdims=keepdims)
else: #argmin/argmax
if axis and len(axis) > 1:
if axis and not isinstance(axis, int) and len(axis) > 1 :
return
ref_res = ref_func(x_data + 0, axis=axis, keepdims=keepdims)
......@@ -164,7 +164,7 @@ def test_reduce_functions():
return func(data, axis=axis)
else:
if axis is not None:
axis = axis[0]
axis = axis if isinstance(axis, int) else axis[0]
out_shape = list(data.shape)
out_shape[axis] = 1
else:
......@@ -180,10 +180,11 @@ def test_reduce_functions():
[relay.prod, np.prod],
[relay.argmin, _with_keepdims(np.argmin)],
[relay.argmax, _with_keepdims(np.argmax)]]:
verify_reduce(func, (d1, d2, d3, d4), (2,), True, False, (d1, d2, 1, d4))
verify_reduce(func, (d1, d2, d3), (1,), True, False, (d1, 1, d3))
verify_reduce(func, (d1, d2, d3, d4), 2, True, False, (d1, d2, 1, d4))
verify_reduce(func, (d1, d2, d3), 1, True, False, (d1, 1, d3))
verify_reduce(func, (d1, d2, d3), None, True, False, (1, 1, 1))
verify_reduce(func, (d1, d2, d3), (0, 1), True, False, (1, 1, d3))
verify_reduce(func, (2, 3, 4), 1, True, False, (2, 1, 4))
verify_reduce(func, (2, 3, 4), (1,), True, False, (2, 1, 4))
verify_reduce(func, (2, 3, 4), (0, 1, 2), False, False, ())
verify_reduce(func, (4, 4, 3), None, False, True, ())
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment