Commit 94acff30 by Siju Committed by Tianqi Chen

added int type axis for relay reduce ops (#2199)

parent d15477cd
...@@ -30,6 +30,7 @@ def argmax(data, axis=None, keepdims=False, exclude=False): ...@@ -30,6 +30,7 @@ def argmax(data, axis=None, keepdims=False, exclude=False):
result : relay.Expr result : relay.Expr
The computed result. The computed result.
""" """
axis = [axis] if isinstance(axis, int) else axis
return _make.argmax(data, axis, keepdims, exclude) return _make.argmax(data, axis, keepdims, exclude)
def argmin(data, axis=None, keepdims=False, exclude=False): def argmin(data, axis=None, keepdims=False, exclude=False):
...@@ -59,6 +60,7 @@ def argmin(data, axis=None, keepdims=False, exclude=False): ...@@ -59,6 +60,7 @@ def argmin(data, axis=None, keepdims=False, exclude=False):
result : relay.Expr result : relay.Expr
The computed result. The computed result.
""" """
axis = [axis] if isinstance(axis, int) else axis
return _make.argmin(data, axis, keepdims, exclude) return _make.argmin(data, axis, keepdims, exclude)
...@@ -89,6 +91,7 @@ def sum(data, axis=None, keepdims=False, exclude=False): ...@@ -89,6 +91,7 @@ def sum(data, axis=None, keepdims=False, exclude=False):
result : relay.Expr result : relay.Expr
The computed result. The computed result.
""" """
axis = [axis] if isinstance(axis, int) else axis
return _make.sum(data, axis, keepdims, exclude) return _make.sum(data, axis, keepdims, exclude)
...@@ -119,6 +122,7 @@ def max(data, axis=None, keepdims=False, exclude=False): ...@@ -119,6 +122,7 @@ def max(data, axis=None, keepdims=False, exclude=False):
result : relay.Expr result : relay.Expr
The computed result. The computed result.
""" """
axis = [axis] if isinstance(axis, int) else axis
return _make.max(data, axis, keepdims, exclude) return _make.max(data, axis, keepdims, exclude)
...@@ -149,6 +153,7 @@ def min(data, axis=None, keepdims=False, exclude=False): ...@@ -149,6 +153,7 @@ def min(data, axis=None, keepdims=False, exclude=False):
result : relay.Expr result : relay.Expr
The computed result. The computed result.
""" """
axis = [axis] if isinstance(axis, int) else axis
return _make.min(data, axis, keepdims, exclude) return _make.min(data, axis, keepdims, exclude)
...@@ -179,6 +184,7 @@ def mean(data, axis=None, keepdims=False, exclude=False): ...@@ -179,6 +184,7 @@ def mean(data, axis=None, keepdims=False, exclude=False):
result : relay.Expr result : relay.Expr
The computed result. The computed result.
""" """
axis = [axis] if isinstance(axis, int) else axis
return _make.mean(data, axis, keepdims, exclude) return _make.mean(data, axis, keepdims, exclude)
...@@ -209,4 +215,5 @@ def prod(data, axis=None, keepdims=False, exclude=False): ...@@ -209,4 +215,5 @@ def prod(data, axis=None, keepdims=False, exclude=False):
result : relay.Expr result : relay.Expr
The computed result. The computed result.
""" """
axis = [axis] if isinstance(axis, int) else axis
return _make.prod(data, axis, keepdims, exclude) return _make.prod(data, axis, keepdims, exclude)
...@@ -145,7 +145,7 @@ def verify_reduce(funcs, data, axis, keepdims, exclude, output, dtype="float32") ...@@ -145,7 +145,7 @@ def verify_reduce(funcs, data, axis, keepdims, exclude, output, dtype="float32")
elif ref_func in [np.max, np.min, np.mean, np.prod]: elif ref_func in [np.max, np.min, np.mean, np.prod]:
ref_res = ref_func(x_data + 0, axis=axis, keepdims=keepdims) ref_res = ref_func(x_data + 0, axis=axis, keepdims=keepdims)
else: #argmin/argmax else: #argmin/argmax
if axis and len(axis) > 1: if axis and not isinstance(axis, int) and len(axis) > 1 :
return return
ref_res = ref_func(x_data + 0, axis=axis, keepdims=keepdims) ref_res = ref_func(x_data + 0, axis=axis, keepdims=keepdims)
...@@ -164,7 +164,7 @@ def test_reduce_functions(): ...@@ -164,7 +164,7 @@ def test_reduce_functions():
return func(data, axis=axis) return func(data, axis=axis)
else: else:
if axis is not None: if axis is not None:
axis = axis[0] axis = axis if isinstance(axis, int) else axis[0]
out_shape = list(data.shape) out_shape = list(data.shape)
out_shape[axis] = 1 out_shape[axis] = 1
else: else:
...@@ -180,10 +180,11 @@ def test_reduce_functions(): ...@@ -180,10 +180,11 @@ def test_reduce_functions():
[relay.prod, np.prod], [relay.prod, np.prod],
[relay.argmin, _with_keepdims(np.argmin)], [relay.argmin, _with_keepdims(np.argmin)],
[relay.argmax, _with_keepdims(np.argmax)]]: [relay.argmax, _with_keepdims(np.argmax)]]:
verify_reduce(func, (d1, d2, d3, d4), (2,), True, False, (d1, d2, 1, d4)) verify_reduce(func, (d1, d2, d3, d4), 2, True, False, (d1, d2, 1, d4))
verify_reduce(func, (d1, d2, d3), (1,), True, False, (d1, 1, d3)) verify_reduce(func, (d1, d2, d3), 1, True, False, (d1, 1, d3))
verify_reduce(func, (d1, d2, d3), None, True, False, (1, 1, 1)) verify_reduce(func, (d1, d2, d3), None, True, False, (1, 1, 1))
verify_reduce(func, (d1, d2, d3), (0, 1), True, False, (1, 1, d3)) verify_reduce(func, (d1, d2, d3), (0, 1), True, False, (1, 1, d3))
verify_reduce(func, (2, 3, 4), 1, True, False, (2, 1, 4))
verify_reduce(func, (2, 3, 4), (1,), True, False, (2, 1, 4)) verify_reduce(func, (2, 3, 4), (1,), True, False, (2, 1, 4))
verify_reduce(func, (2, 3, 4), (0, 1, 2), False, False, ()) verify_reduce(func, (2, 3, 4), (0, 1, 2), False, False, ())
verify_reduce(func, (4, 4, 3), None, False, True, ()) verify_reduce(func, (4, 4, 3), None, False, True, ())
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment