Commit 02a8be10 by Siju Committed by Tianqi Chen

[RELAY]Reduce ops sum/max/min/mean/prod (#1927)

parent 7cd7dbff
...@@ -108,6 +108,11 @@ This level enables additional math and transform operators. ...@@ -108,6 +108,11 @@ This level enables additional math and transform operators.
tvm.relay.where tvm.relay.where
tvm.relay.argmax tvm.relay.argmax
tvm.relay.argmin tvm.relay.argmin
tvm.relay.sum
tvm.relay.max
tvm.relay.min
tvm.relay.mean
tvm.relay.prod
**Level 5: Vision/Image Operators** **Level 5: Vision/Image Operators**
...@@ -187,6 +192,11 @@ Level 4 Definitions ...@@ -187,6 +192,11 @@ Level 4 Definitions
.. autofunction:: tvm.relay.where .. autofunction:: tvm.relay.where
.. autofunction:: tvm.relay.argmax .. autofunction:: tvm.relay.argmax
.. autofunction:: tvm.relay.argmin .. autofunction:: tvm.relay.argmin
.. autofunction:: tvm.relay.sum
.. autofunction:: tvm.relay.max
.. autofunction:: tvm.relay.min
.. autofunction:: tvm.relay.mean
.. autofunction:: tvm.relay.prod
Level 5 Definitions Level 5 Definitions
......
...@@ -30,7 +30,6 @@ def argmax(data, axis=None, keepdims=False, exclude=False): ...@@ -30,7 +30,6 @@ def argmax(data, axis=None, keepdims=False, exclude=False):
result : relay.Expr result : relay.Expr
The computed result. The computed result.
""" """
return _make.argmax(data, axis, keepdims, exclude) return _make.argmax(data, axis, keepdims, exclude)
def argmin(data, axis=None, keepdims=False, exclude=False): def argmin(data, axis=None, keepdims=False, exclude=False):
...@@ -60,5 +59,154 @@ def argmin(data, axis=None, keepdims=False, exclude=False): ...@@ -60,5 +59,154 @@ def argmin(data, axis=None, keepdims=False, exclude=False):
result : relay.Expr result : relay.Expr
The computed result. The computed result.
""" """
return _make.argmin(data, axis, keepdims, exclude) return _make.argmin(data, axis, keepdims, exclude)
def sum(data, axis=None, keepdims=False, exclude=False):
"""Computes the sum of array elements over given axes.
Parameters
----------
data : relay.Expr
The input data
axis : None or int or tuple of int
Axis or axes along which a argmin operation is performed.
The default, axis=None, will find the indices of minimum element all of the elements of
the input array. If axis is negative it counts from the last to the first axis.
keepdims : bool
If this is set to True, the axes which are reduced are left in the result as dimensions
with size one.
With this option, the result will broadcast correctly against the input array.
exclude : bool
If `exclude` is true, reduction will be performed on the axes that are
NOT in axis instead.
Returns
-------
result : relay.Expr
The computed result.
"""
return _make.sum(data, axis, keepdims, exclude)
def max(data, axis=None, keepdims=False, exclude=False):
""" Computes the max of array elements over given axes.
Parameters
----------
data : relay.Expr
The input data
axis : None or int or tuple of int
Axis or axes along which a argmin operation is performed.
The default, axis=None, will find the indices of minimum element all of the elements of
the input array. If axis is negative it counts from the last to the first axis.
keepdims : bool
If this is set to True, the axes which are reduced are left in the result as dimensions
with size one.
With this option, the result will broadcast correctly against the input array.
exclude : bool
If `exclude` is true, reduction will be performed on the axes that are
NOT in axis instead.
Returns
-------
result : relay.Expr
The computed result.
"""
return _make.max(data, axis, keepdims, exclude)
def min(data, axis=None, keepdims=False, exclude=False):
"""Computes the min of array elements over given axes.
Parameters
----------
data : relay.Expr
The input data
axis : None or int or tuple of int
Axis or axes along which a argmin operation is performed.
The default, axis=None, will find the indices of minimum element all of the elements of
the input array. If axis is negative it counts from the last to the first axis.
keepdims : bool
If this is set to True, the axes which are reduced are left in the result as dimensions
with size one.
With this option, the result will broadcast correctly against the input array.
exclude : bool
If `exclude` is true, reduction will be performed on the axes that are
NOT in axis instead.
Returns
-------
result : relay.Expr
The computed result.
"""
return _make.min(data, axis, keepdims, exclude)
def mean(data, axis=None, keepdims=False, exclude=False):
"""Computes the mean of array elements over given axes.
Parameters
----------
data : relay.Expr
The input data
axis : None or int or tuple of int
Axis or axes along which a argmin operation is performed.
The default, axis=None, will find the indices of minimum element all of the elements of
the input array. If axis is negative it counts from the last to the first axis.
keepdims : bool
If this is set to True, the axes which are reduced are left in the result as dimensions
with size one.
With this option, the result will broadcast correctly against the input array.
exclude : bool
If `exclude` is true, reduction will be performed on the axes that are
NOT in axis instead.
Returns
-------
result : relay.Expr
The computed result.
"""
return _make.mean(data, axis, keepdims, exclude)
def prod(data, axis=None, keepdims=False, exclude=False):
"""Computes the products of array elements over given axes.
Parameters
----------
data : relay.Expr
The input data
axis : None or int or tuple of int
Axis or axes along which a argmin operation is performed.
The default, axis=None, will find the indices of minimum element all of the elements of
the input array. If axis is negative it counts from the last to the first axis.
keepdims : bool
If this is set to True, the axes which are reduced are left in the result as dimensions
with size one.
With this option, the result will broadcast correctly against the input array.
exclude : bool
If `exclude` is true, reduction will be performed on the axes that are
NOT in axis instead.
Returns
-------
result : relay.Expr
The computed result.
"""
return _make.prod(data, axis, keepdims, exclude)
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
#include <tvm/relay/op.h> #include <tvm/relay/op.h>
#include <numeric> #include <numeric>
#include <limits> #include <limits>
#include "../op_common.h"
#include "../type_relations.h" #include "../type_relations.h"
namespace tvm { namespace tvm {
...@@ -19,7 +20,7 @@ struct ReduceAttrs : public tvm::AttrsNode<ReduceAttrs> { ...@@ -19,7 +20,7 @@ struct ReduceAttrs : public tvm::AttrsNode<ReduceAttrs> {
bool exclude; bool exclude;
TVM_DECLARE_ATTRS(ReduceAttrs, "relay.attrs.ReduceAttrs") { TVM_DECLARE_ATTRS(ReduceAttrs, "relay.attrs.ReduceAttrs") {
TVM_ATTR_FIELD(axis).set_default(Array<IndexExpr>({})) TVM_ATTR_FIELD(axis).set_default(NullValue<Array<IndexExpr>>())
.describe(R"code(The axis or axes along which to perform the reduction. .describe(R"code(The axis or axes along which to perform the reduction.
The default, `axis=()`, will compute over all elements into a The default, `axis=()`, will compute over all elements into a
...@@ -158,10 +159,7 @@ bool ArgReduceRel(const Array<Type>& types, ...@@ -158,10 +159,7 @@ bool ArgReduceRel(const Array<Type>& types,
const auto* data = types[0].as<TensorTypeNode>(); const auto* data = types[0].as<TensorTypeNode>();
if (data == nullptr) return false; if (data == nullptr) return false;
CHECK(static_cast<int>(data->shape.size()) != 0); CHECK(static_cast<int>(data->shape.size()) != 0);
std::vector<IndexExpr> in_shape; std::vector<IndexExpr>&& in_shape = AsVector(data->shape);
for (auto i : data->shape) {
in_shape.push_back(i);
}
const ReduceAttrs* param = attrs.as<ReduceAttrs>(); const ReduceAttrs* param = attrs.as<ReduceAttrs>();
CHECK(param != nullptr); CHECK(param != nullptr);
...@@ -172,6 +170,31 @@ bool ArgReduceRel(const Array<Type>& types, ...@@ -172,6 +170,31 @@ bool ArgReduceRel(const Array<Type>& types,
return true; return true;
} }
/*!
* \brief ReduceRel Output type and shape relation evaluation function.
* \param num_inputs Number of input types in the args.
* \param attrs The additional attributes of the operator.
* \param reporter The reporter to report solution to.
* \return false if This relation cannot be resolved. true if this relation has been resolved.
*/
bool ReduceRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 2);
const auto* data = types[0].as<TensorTypeNode>();
if (data == nullptr) return false;
CHECK(static_cast<int>(data->shape.size()) != 0);
std::vector<IndexExpr>&& in_shape = AsVector(data->shape);
const ReduceAttrs* param = attrs.as<ReduceAttrs>();
CHECK(param != nullptr);
// assign output type and shape
auto oshape = ReduceShapeImpl(in_shape, param, reporter);
reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype));
return true;
}
#define RELAY_REGISTER_REDUCE_OP(OpName) \ #define RELAY_REGISTER_REDUCE_OP(OpName) \
TVM_REGISTER_API("relay.op._make." OpName) \ TVM_REGISTER_API("relay.op._make." OpName) \
...@@ -213,5 +236,88 @@ values over a given axis. ...@@ -213,5 +236,88 @@ values over a given axis.
.set_support_level(4) .set_support_level(4)
.add_type_rel("ArgReduce", ArgReduceRel); .add_type_rel("ArgReduce", ArgReduceRel);
RELAY_REGISTER_REDUCE_OP("sum")
.describe(R"code(Computes the sum of array elements over given axes.
Example::
data = [[[1,2],[2,3],[1,3]],
[[1,4],[4,3],[5,2]],
[[7,1],[7,2],[7,3]]]
sum(data, axis=1)
[[ 4. 8.]
[ 10. 9.]
[ 21. 6.]]
sum(data, axis=[1,2])
[ 12. 19. 27.]
)code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.ReduceAttrs")
.set_support_level(4)
.add_type_rel("Reduce", ReduceRel);
RELAY_REGISTER_REDUCE_OP("max")
.describe(R"code(Computes the max of array elements over given axes.
)code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.ReduceAttrs")
.set_support_level(4)
.add_type_rel("Reduce", ReduceRel);
RELAY_REGISTER_REDUCE_OP("min")
.describe(R"code(Computes the min of array elements over given axes.
)code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.ReduceAttrs")
.set_support_level(4)
.add_type_rel("Reduce", ReduceRel);
RELAY_REGISTER_REDUCE_OP("mean")
.describe(R"code(Computes the mean of array elements over given axes.
Example::
data = [[[1,2],[2,3],[1,3]],
[[1,4],[4,3],[5,2]],
[[7,1],[7,2],[7,3]]]
mean(data)
[3.22]
mean(data, axis=[1,2])
[ 2. 3.16666667 4.5]
)code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.ReduceAttrs")
.set_support_level(4)
.add_type_rel("Reduce", ReduceRel);
RELAY_REGISTER_REDUCE_OP("prod")
.describe(R"code(Computes the products of array elements over given axes.
Example::
data = [[[1,2],[2,3],[1,3]],
[[1,4],[4,3],[5,2]],
[[7,1],[7,2],[7,3]]]
mean(data, axis=1)
[35562240]
mean(data, axis=[1,2])
[ 36 480 2058]
)code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.ReduceAttrs")
.set_support_level(4)
.add_type_rel("Reduce", ReduceRel);
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
...@@ -46,27 +46,6 @@ def test_binary_int_broadcast(): ...@@ -46,27 +46,6 @@ def test_binary_int_broadcast():
assert zz.checked_type == relay.TensorType((5, 10, 4), "int32") assert zz.checked_type == relay.TensorType((5, 10, 4), "int32")
def test_arg_reduce():
for op in [relay.argmax, relay.argmin]:
n, c , h, w = 10, 20, 3, 4
x = relay.var("x", relay.ty.TensorType((n, c , h, w), "float32"))
z = relay.argmax(x, axis=(1,))
"axis=" in z.astext()
zz = relay.ir_pass.infer_type(z)
assert zz.checked_type == relay.ty.TensorType((n, h, w), "int32")
n, c , h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w")
x = relay.var("x", relay.ty.TensorType((n, c , h, w), "float32"))
z = relay.argmax(x, axis=(2,), keepdims=True)
zz = relay.ir_pass.infer_type(z)
assert zz.checked_type == relay.ty.TensorType((n, c , 1, w), "int32")
n, c , h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w")
x = relay.var("x", relay.ty.TensorType((n, c , h, w), "float32"))
z = relay.argmax(x, axis=(2,), keepdims=True, exclude=True)
zz = relay.ir_pass.infer_type(z)
assert zz.checked_type == relay.ty.TensorType((1, 1 , h, 1), "int32")
def test_where(): def test_where():
cond = relay.var("cond", relay.TensorType((3, 4), "float32")) cond = relay.var("cond", relay.TensorType((3, 4), "float32"))
x = relay.var("x", relay.TensorType((3, 4), "float32")) x = relay.var("x", relay.TensorType((3, 4), "float32"))
...@@ -76,9 +55,45 @@ def test_where(): ...@@ -76,9 +55,45 @@ def test_where():
assert zz.checked_type == relay.TensorType((3, 4), "float32") assert zz.checked_type == relay.TensorType((3, 4), "float32")
def verify_reduce(test_func, data, axis, keepdims, exclude, output):
x = relay.var("x", relay.TensorType(data, "float32"))
z = test_func(x, axis, keepdims, exclude)
zz = relay.ir_pass.infer_type(z)
if axis:
assert "axis=" in z.astext()
if keepdims:
assert "keepdims=" in z.astext()
if exclude:
assert "exclude=" in z.astext()
out_type = "int32" if test_func in [relay.argmin, relay.argmax] else "float32"
assert zz.checked_type == relay.ty.TensorType(output, out_type)
def test_reduce_functions():
d1, d2, d3, d4 = tvm.var("d1"), tvm.var("d2"), tvm.var("d3"), tvm.var("d4")
for func in [relay.sum,
relay.max,
relay.min,
relay.mean,
relay.prod,
relay.argmin,
relay.argmax]:
verify_reduce(func, (d1, d2, d3, d4), (2,), True, False, (d1, d2, 1, d4))
verify_reduce(func, (d1, d2, d3), (1,), True, False, (d1, 1, d3))
verify_reduce(func, (d1, d2, d3), None, True, False, (1, 1, 1))
verify_reduce(func, (d1, d2, d3), (0, 1), True, False, (1, 1, d3))
verify_reduce(func, (2, 3, 4), (1,), True, False, (2, 1, 4))
verify_reduce(func, (2, 3, 4), (0, 1, 2), False, False, ())
verify_reduce(func, (4, 4, 3), None, True, False, (1, 1, 1))
verify_reduce(func, (4, 4, 3), None, False, True, ())
verify_reduce(func, (4, 4, 3), (0, 2), False, False, (4,))
verify_reduce(func, (128, 24, 128), (0, 1), False, False, (128,))
verify_reduce(func, (128, 24, 128), (0, 2), False, False, (24,))
verify_reduce(func, (128, 24, 128), (0, 1), True, False, (1, 1, 128))
verify_reduce(func, (128, 24, 128), (0, 2), True, False, (1, 24, 1))
if __name__ == "__main__": if __name__ == "__main__":
test_binary_op() test_binary_op()
test_cmp_type() test_cmp_type()
test_binary_int_broadcast() test_binary_int_broadcast()
test_where() test_where()
test_arg_reduce() test_reduce_functions()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment