Commit a527b58e by Jared Roesch Committed by Tianqi Chen

[Relay] Fixes to sum (#2439)

parent 967bcb3b
...@@ -12,8 +12,8 @@ def argmax(data, axis=None, keepdims=False, exclude=False): ...@@ -12,8 +12,8 @@ def argmax(data, axis=None, keepdims=False, exclude=False):
The input data The input data
axis : None or int or tuple of int axis : None or int or tuple of int
Axis or axes along which a argmin operation is performed. Axis or axes along which a argmax operation is performed.
The default, axis=None, will find the indices of maximum element all of the elements of The default, axis=None, will find the indices of the maximum element of the elements of
the input array. If axis is negative it counts from the last to the first axis. the input array. If axis is negative it counts from the last to the first axis.
keepdims : bool keepdims : bool
...@@ -73,14 +73,14 @@ def sum(data, axis=None, keepdims=False, exclude=False): ...@@ -73,14 +73,14 @@ def sum(data, axis=None, keepdims=False, exclude=False):
The input data The input data
axis : None or int or tuple of int axis : None or int or tuple of int
Axis or axes along which a argmin operation is performed. Axis or axes along which a sum is performed. The default, axis=None,
The default, axis=None, will find the indices of minimum element all of the elements of will sum all of the elements of the input array. If axis is
the input array. If axis is negative it counts from the last to the first axis. negative it counts from the last to the first axis.
keepdims : bool keepdims : bool
If this is set to True, the axes which are reduced are left in the result as dimensions If this is set to True, the axes which are reduced are left in the result as
with size one. dimensions with size one. With this option, the result will broadcast
With this option, the result will broadcast correctly against the input array. correctly against the input array.
exclude : bool exclude : bool
If `exclude` is true, reduction will be performed on the axes that are If `exclude` is true, reduction will be performed on the axes that are
...@@ -91,7 +91,7 @@ def sum(data, axis=None, keepdims=False, exclude=False): ...@@ -91,7 +91,7 @@ def sum(data, axis=None, keepdims=False, exclude=False):
result : relay.Expr result : relay.Expr
The computed result. The computed result.
""" """
axis = [axis] if isinstance(axis, int) else axis axis = [axis] if axis and isinstance(axis, int) else axis
return _make.sum(data, axis, keepdims, exclude) return _make.sum(data, axis, keepdims, exclude)
...@@ -104,9 +104,9 @@ def max(data, axis=None, keepdims=False, exclude=False): ...@@ -104,9 +104,9 @@ def max(data, axis=None, keepdims=False, exclude=False):
The input data The input data
axis : None or int or tuple of int axis : None or int or tuple of int
Axis or axes along which a argmin operation is performed. Axis or axes along which the max operation is performed.
The default, axis=None, will find the indices of minimum element all of the elements of The default, axis=None, will find the max element from all of the elements of the input
the input array. If axis is negative it counts from the last to the first axis. array. If axis is negative it counts from the last to the first axis.
keepdims : bool keepdims : bool
If this is set to True, the axes which are reduced are left in the result as dimensions If this is set to True, the axes which are reduced are left in the result as dimensions
...@@ -135,9 +135,10 @@ def min(data, axis=None, keepdims=False, exclude=False): ...@@ -135,9 +135,10 @@ def min(data, axis=None, keepdims=False, exclude=False):
The input data The input data
axis : None or int or tuple of int axis : None or int or tuple of int
Axis or axes along which a argmin operation is performed. Axis or axes along which a minimum operation is performed.
The default, axis=None, will find the indices of minimum element all of the elements of The default, axis=None, will find the minimum element from all
the input array. If axis is negative it counts from the last to the first axis. of the elements of the input array. If axis is negative it counts from
the last to the first axis.
keepdims : bool keepdims : bool
If this is set to True, the axes which are reduced are left in the result as dimensions If this is set to True, the axes which are reduced are left in the result as dimensions
...@@ -166,7 +167,7 @@ def mean(data, axis=None, keepdims=False, exclude=False): ...@@ -166,7 +167,7 @@ def mean(data, axis=None, keepdims=False, exclude=False):
The input data The input data
axis : None or int or tuple of int axis : None or int or tuple of int
Axis or axes along which a argmin operation is performed. Axis or axes along which a mean operation is performed.
The default, axis=None, will find the indices of minimum element all of the elements of The default, axis=None, will find the indices of minimum element all of the elements of
the input array. If axis is negative it counts from the last to the first axis. the input array. If axis is negative it counts from the last to the first axis.
...@@ -197,7 +198,7 @@ def prod(data, axis=None, keepdims=False, exclude=False): ...@@ -197,7 +198,7 @@ def prod(data, axis=None, keepdims=False, exclude=False):
The input data The input data
axis : None or int or tuple of int axis : None or int or tuple of int
Axis or axes along which a argmin operation is performed. Axis or axes along which a product is performed.
The default, axis=None, will find the indices of minimum element all of the elements of The default, axis=None, will find the indices of minimum element all of the elements of
the input array. If axis is negative it counts from the last to the first axis. the input array. If axis is negative it counts from the last to the first axis.
......
...@@ -180,6 +180,7 @@ def test_reduce_functions(): ...@@ -180,6 +180,7 @@ def test_reduce_functions():
[relay.prod, np.prod], [relay.prod, np.prod],
[relay.argmin, _with_keepdims(np.argmin)], [relay.argmin, _with_keepdims(np.argmin)],
[relay.argmax, _with_keepdims(np.argmax)]]: [relay.argmax, _with_keepdims(np.argmax)]]:
verify_reduce(func, (d1, d2, d3, d4), None, False, False, ())
verify_reduce(func, (d1, d2, d3, d4), 2, True, False, (d1, d2, 1, d4)) verify_reduce(func, (d1, d2, d3, d4), 2, True, False, (d1, d2, 1, d4))
verify_reduce(func, (d1, d2, d3), 1, True, False, (d1, 1, d3)) verify_reduce(func, (d1, d2, d3), 1, True, False, (d1, 1, d3))
verify_reduce(func, (d1, d2, d3), None, True, False, (1, 1, 1)) verify_reduce(func, (d1, d2, d3), None, True, False, (1, 1, 1))
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment