Commit 02a8be10 by Siju Committed by Tianqi Chen

[RELAY]Reduce ops sum/max/min/mean/prod (#1927)

parent 7cd7dbff
......@@ -108,6 +108,11 @@ This level enables additional math and transform operators.
tvm.relay.where
tvm.relay.argmax
tvm.relay.argmin
tvm.relay.sum
tvm.relay.max
tvm.relay.min
tvm.relay.mean
tvm.relay.prod
**Level 5: Vision/Image Operators**
......@@ -187,6 +192,11 @@ Level 4 Definitions
.. autofunction:: tvm.relay.where
.. autofunction:: tvm.relay.argmax
.. autofunction:: tvm.relay.argmin
.. autofunction:: tvm.relay.sum
.. autofunction:: tvm.relay.max
.. autofunction:: tvm.relay.min
.. autofunction:: tvm.relay.mean
.. autofunction:: tvm.relay.prod
Level 5 Definitions
......
......@@ -30,7 +30,6 @@ def argmax(data, axis=None, keepdims=False, exclude=False):
result : relay.Expr
The computed result.
"""
return _make.argmax(data, axis, keepdims, exclude)
def argmin(data, axis=None, keepdims=False, exclude=False):
......@@ -60,5 +59,154 @@ def argmin(data, axis=None, keepdims=False, exclude=False):
result : relay.Expr
The computed result.
"""
return _make.argmin(data, axis, keepdims, exclude)
def sum(data, axis=None, keepdims=False, exclude=False):
    """Computes the sum of array elements over given axes.

    Parameters
    ----------
    data : relay.Expr
        The input data

    axis : None or int or tuple of int
        Axis or axes along which a sum is performed.
        The default, axis=None, will sum all of the elements of
        the input array. If axis is negative it counts from the last to the first axis.

    keepdims : bool
        If this is set to True, the axes which are reduced are left in the result as dimensions
        with size one.
        With this option, the result will broadcast correctly against the input array.

    exclude : bool
        If `exclude` is true, reduction will be performed on the axes that are
        NOT in axis instead.

    Returns
    -------
    result : relay.Expr
        The computed result.
    """
    return _make.sum(data, axis, keepdims, exclude)
def max(data, axis=None, keepdims=False, exclude=False):
    """Computes the max of array elements over given axes.

    Parameters
    ----------
    data : relay.Expr
        The input data

    axis : None or int or tuple of int
        Axis or axes along which a max operation is performed.
        The default, axis=None, will find the maximum over all of the elements of
        the input array. If axis is negative it counts from the last to the first axis.

    keepdims : bool
        If this is set to True, the axes which are reduced are left in the result as dimensions
        with size one.
        With this option, the result will broadcast correctly against the input array.

    exclude : bool
        If `exclude` is true, reduction will be performed on the axes that are
        NOT in axis instead.

    Returns
    -------
    result : relay.Expr
        The computed result.
    """
    return _make.max(data, axis, keepdims, exclude)
def min(data, axis=None, keepdims=False, exclude=False):
    """Computes the min of array elements over given axes.

    Parameters
    ----------
    data : relay.Expr
        The input data

    axis : None or int or tuple of int
        Axis or axes along which a min operation is performed.
        The default, axis=None, will find the minimum over all of the elements of
        the input array. If axis is negative it counts from the last to the first axis.

    keepdims : bool
        If this is set to True, the axes which are reduced are left in the result as dimensions
        with size one.
        With this option, the result will broadcast correctly against the input array.

    exclude : bool
        If `exclude` is true, reduction will be performed on the axes that are
        NOT in axis instead.

    Returns
    -------
    result : relay.Expr
        The computed result.
    """
    return _make.min(data, axis, keepdims, exclude)
def mean(data, axis=None, keepdims=False, exclude=False):
    """Computes the mean of array elements over given axes.

    Parameters
    ----------
    data : relay.Expr
        The input data

    axis : None or int or tuple of int
        Axis or axes along which a mean operation is performed.
        The default, axis=None, will compute the mean of all of the elements of
        the input array. If axis is negative it counts from the last to the first axis.

    keepdims : bool
        If this is set to True, the axes which are reduced are left in the result as dimensions
        with size one.
        With this option, the result will broadcast correctly against the input array.

    exclude : bool
        If `exclude` is true, reduction will be performed on the axes that are
        NOT in axis instead.

    Returns
    -------
    result : relay.Expr
        The computed result.
    """
    return _make.mean(data, axis, keepdims, exclude)
def prod(data, axis=None, keepdims=False, exclude=False):
    """Computes the products of array elements over given axes.

    Parameters
    ----------
    data : relay.Expr
        The input data

    axis : None or int or tuple of int
        Axis or axes along which a product is performed.
        The default, axis=None, will multiply all of the elements of
        the input array. If axis is negative it counts from the last to the first axis.

    keepdims : bool
        If this is set to True, the axes which are reduced are left in the result as dimensions
        with size one.
        With this option, the result will broadcast correctly against the input array.

    exclude : bool
        If `exclude` is true, reduction will be performed on the axes that are
        NOT in axis instead.

    Returns
    -------
    result : relay.Expr
        The computed result.
    """
    return _make.prod(data, axis, keepdims, exclude)
......@@ -7,6 +7,7 @@
#include <tvm/relay/op.h>
#include <numeric>
#include <limits>
#include "../op_common.h"
#include "../type_relations.h"
namespace tvm {
......@@ -19,7 +20,7 @@ struct ReduceAttrs : public tvm::AttrsNode<ReduceAttrs> {
bool exclude;
TVM_DECLARE_ATTRS(ReduceAttrs, "relay.attrs.ReduceAttrs") {
TVM_ATTR_FIELD(axis).set_default(Array<IndexExpr>({}))
TVM_ATTR_FIELD(axis).set_default(NullValue<Array<IndexExpr>>())
.describe(R"code(The axis or axes along which to perform the reduction.
The default, `axis=()`, will compute over all elements into a
......@@ -158,10 +159,7 @@ bool ArgReduceRel(const Array<Type>& types,
const auto* data = types[0].as<TensorTypeNode>();
if (data == nullptr) return false;
CHECK(static_cast<int>(data->shape.size()) != 0);
std::vector<IndexExpr> in_shape;
for (auto i : data->shape) {
in_shape.push_back(i);
}
std::vector<IndexExpr>&& in_shape = AsVector(data->shape);
const ReduceAttrs* param = attrs.as<ReduceAttrs>();
CHECK(param != nullptr);
......@@ -172,6 +170,31 @@ bool ArgReduceRel(const Array<Type>& types,
return true;
}
/*!
 * \brief ReduceRel Output type and shape relation evaluation function.
 * \param types The type array: types[0] is the input tensor type,
 *        types[1] is the output type to be solved for.
 * \param num_inputs Number of input types in the args.
 * \param attrs The additional attributes of the operator.
 * \param reporter The reporter to report solution to.
 * \return false if This relation cannot be resolved. true if this relation has been resolved.
 */
bool ReduceRel(const Array<Type>& types,
               int num_inputs,
               const Attrs& attrs,
               const TypeReporter& reporter) {
  CHECK_EQ(types.size(), 2);
  const auto* data = types[0].as<TensorTypeNode>();
  // Input type is not resolved yet; ask the solver to try again later.
  if (data == nullptr) return false;
  CHECK(static_cast<int>(data->shape.size()) != 0);
  std::vector<IndexExpr>&& in_shape = AsVector(data->shape);
  const ReduceAttrs* param = attrs.as<ReduceAttrs>();
  CHECK(param != nullptr);
  // assign output type and shape; dtype is preserved from the input tensor.
  auto oshape = ReduceShapeImpl(in_shape, param, reporter);
  reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype));
  return true;
}
#define RELAY_REGISTER_REDUCE_OP(OpName) \
TVM_REGISTER_API("relay.op._make." OpName) \
......@@ -213,5 +236,88 @@ values over a given axis.
.set_support_level(4)
.add_type_rel("ArgReduce", ArgReduceRel);
// Register relay.sum: reduces by summation over the axes in ReduceAttrs,
// sharing the generic ReduceRel type relation.
RELAY_REGISTER_REDUCE_OP("sum")
.describe(R"code(Computes the sum of array elements over given axes.

Example::

  data = [[[1,2],[2,3],[1,3]],
          [[1,4],[4,3],[5,2]],
          [[7,1],[7,2],[7,3]]]

  sum(data, axis=1)
  [[ 4.  8.]
   [ 10.  9.]
   [ 21.  6.]]

  sum(data, axis=[1,2])
  [ 12. 19. 27.]

)code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.ReduceAttrs")
.set_support_level(4)
.add_type_rel("Reduce", ReduceRel);
// Register relay.max: reduces by taking the maximum over the given axes.
RELAY_REGISTER_REDUCE_OP("max")
.describe(R"code(Computes the max of array elements over given axes.

)code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.ReduceAttrs")
.set_support_level(4)
.add_type_rel("Reduce", ReduceRel);
// Register relay.min: reduces by taking the minimum over the given axes.
RELAY_REGISTER_REDUCE_OP("min")
.describe(R"code(Computes the min of array elements over given axes.

)code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.ReduceAttrs")
.set_support_level(4)
.add_type_rel("Reduce", ReduceRel);
// Register relay.mean: reduces by averaging over the given axes.
RELAY_REGISTER_REDUCE_OP("mean")
.describe(R"code(Computes the mean of array elements over given axes.

Example::

  data = [[[1,2],[2,3],[1,3]],
          [[1,4],[4,3],[5,2]],
          [[7,1],[7,2],[7,3]]]

  mean(data)
  [3.22]

  mean(data, axis=[1,2])
  [ 2. 3.16666667 4.5]

)code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.ReduceAttrs")
.set_support_level(4)
.add_type_rel("Reduce", ReduceRel);
// Register relay.prod: reduces by multiplication over the given axes.
RELAY_REGISTER_REDUCE_OP("prod")
.describe(R"code(Computes the products of array elements over given axes.

Example::

  data = [[[1,2],[2,3],[1,3]],
          [[1,4],[4,3],[5,2]],
          [[7,1],[7,2],[7,3]]]

  prod(data)
  [35562240]

  prod(data, axis=[1,2])
  [ 36 480 2058]

)code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.ReduceAttrs")
.set_support_level(4)
.add_type_rel("Reduce", ReduceRel);
} // namespace relay
} // namespace tvm
......@@ -46,27 +46,6 @@ def test_binary_int_broadcast():
assert zz.checked_type == relay.TensorType((5, 10, 4), "int32")
def test_arg_reduce():
    """Check type inference for argmax and argmin on static and symbolic shapes."""
    for op in [relay.argmax, relay.argmin]:
        # Static shape: reducing over axis 1 drops that dimension; result is int32 indices.
        n, c, h, w = 10, 20, 3, 4
        x = relay.var("x", relay.ty.TensorType((n, c, h, w), "float32"))
        z = op(x, axis=(1,))  # was hard-coded relay.argmax, so argmin was never exercised
        assert "axis=" in z.astext()  # was a bare expression with no assert (a no-op)
        zz = relay.ir_pass.infer_type(z)
        assert zz.checked_type == relay.ty.TensorType((n, h, w), "int32")

        # Symbolic shape with keepdims: the reduced axis stays with size 1.
        n, c, h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w")
        x = relay.var("x", relay.ty.TensorType((n, c, h, w), "float32"))
        z = op(x, axis=(2,), keepdims=True)
        zz = relay.ir_pass.infer_type(z)
        assert zz.checked_type == relay.ty.TensorType((n, c, 1, w), "int32")

        # exclude=True reduces every axis except axis 2.
        n, c, h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w")
        x = relay.var("x", relay.ty.TensorType((n, c, h, w), "float32"))
        z = op(x, axis=(2,), keepdims=True, exclude=True)
        zz = relay.ir_pass.infer_type(z)
        assert zz.checked_type == relay.ty.TensorType((1, 1, h, 1), "int32")
def test_where():
cond = relay.var("cond", relay.TensorType((3, 4), "float32"))
x = relay.var("x", relay.TensorType((3, 4), "float32"))
......@@ -76,9 +55,45 @@ def test_where():
assert zz.checked_type == relay.TensorType((3, 4), "float32")
def verify_reduce(test_func, data, axis, keepdims, exclude, output):
    """Build a reduce call, run type inference, and verify the inferred type.

    `data` is the input shape, `output` the expected result shape; the other
    arguments are forwarded to the reduce operator under test.
    """
    inp = relay.var("x", relay.TensorType(data, "float32"))
    call = test_func(inp, axis, keepdims, exclude)
    typed = relay.ir_pass.infer_type(call)
    printed = call.astext()
    # Every explicitly-set attribute must show up in the printed IR.
    for value, marker in ((axis, "axis="),
                          (keepdims, "keepdims="),
                          (exclude, "exclude=")):
        if value:
            assert marker in printed
    # argmin/argmax produce index tensors; all other reductions keep float32.
    ret_dtype = "int32" if test_func in [relay.argmin, relay.argmax] else "float32"
    assert typed.checked_type == relay.ty.TensorType(output, ret_dtype)
def test_reduce_functions():
    """Run shape/type-inference checks across every Relay reduce operator."""
    d1, d2, d3, d4 = tvm.var("d1"), tvm.var("d2"), tvm.var("d3"), tvm.var("d4")
    reduce_ops = (relay.sum,
                  relay.max,
                  relay.min,
                  relay.mean,
                  relay.prod,
                  relay.argmin,
                  relay.argmax)
    # Each case: (input shape, axis, keepdims, exclude, expected output shape).
    cases = (
        ((d1, d2, d3, d4), (2,), True, False, (d1, d2, 1, d4)),
        ((d1, d2, d3), (1,), True, False, (d1, 1, d3)),
        ((d1, d2, d3), None, True, False, (1, 1, 1)),
        ((d1, d2, d3), (0, 1), True, False, (1, 1, d3)),
        ((2, 3, 4), (1,), True, False, (2, 1, 4)),
        ((2, 3, 4), (0, 1, 2), False, False, ()),
        ((4, 4, 3), None, True, False, (1, 1, 1)),
        ((4, 4, 3), None, False, True, ()),
        ((4, 4, 3), (0, 2), False, False, (4,)),
        ((128, 24, 128), (0, 1), False, False, (128,)),
        ((128, 24, 128), (0, 2), False, False, (24,)),
        ((128, 24, 128), (0, 1), True, False, (1, 1, 128)),
        ((128, 24, 128), (0, 2), True, False, (1, 24, 1)),
    )
    for func in reduce_ops:
        for shape, axis, keepdims, exclude, expected in cases:
            verify_reduce(func, shape, axis, keepdims, exclude, expected)
if __name__ == "__main__":
test_binary_op()
test_cmp_type()
test_binary_int_broadcast()
test_where()
test_arg_reduce()
test_reduce_functions()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment