Commit 4001569e by Siju Committed by Tianqi Chen

[RELAY][OP]Reduction operator framework, argmax, argmin (#1865)

parent fd392677
...@@ -106,6 +106,8 @@ This level enables additional math and transform operators. ...@@ -106,6 +106,8 @@ This level enables additional math and transform operators.
tvm.relay.minimum tvm.relay.minimum
tvm.relay.pow tvm.relay.pow
tvm.relay.where tvm.relay.where
tvm.relay.argmax
tvm.relay.argmin
**Level 5: Vision/Image Operators** **Level 5: Vision/Image Operators**
...@@ -183,6 +185,8 @@ Level 4 Definitions ...@@ -183,6 +185,8 @@ Level 4 Definitions
.. autofunction:: tvm.relay.minimum .. autofunction:: tvm.relay.minimum
.. autofunction:: tvm.relay.pow .. autofunction:: tvm.relay.pow
.. autofunction:: tvm.relay.where .. autofunction:: tvm.relay.where
.. autofunction:: tvm.relay.argmax
.. autofunction:: tvm.relay.argmin
Level 5 Definitions Level 5 Definitions
......
...@@ -270,6 +270,13 @@ class TypeReporterNode : public Node { ...@@ -270,6 +270,13 @@ class TypeReporterNode : public Node {
*/ */
TVM_DLL virtual void Assign(const Type& dst, const Type& src) = 0; TVM_DLL virtual void Assign(const Type& dst, const Type& src) = 0;
/*! /*!
* \brief assert shape expression comparison.
* \param cond The condition of operation.
* \return false if assertation can be proven to have failed
* true if solver can still proceed.
*/
TVM_DLL virtual bool Assert(const IndexExpr& cond)= 0;
/*!
* \brief assert shape expression equals each other. * \brief assert shape expression equals each other.
* \param lhs The left operand. * \param lhs The left operand.
* \param rhs The right operand. * \param rhs The right operand.
......
...@@ -9,6 +9,7 @@ from . import ir_builder ...@@ -9,6 +9,7 @@ from . import ir_builder
# Root operators # Root operators
from .op import Op from .op import Op
from .op.reduce import *
from .op.tensor import * from .op.tensor import *
from .op.transform import * from .op.transform import *
from . import nn from . import nn
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
from .op import get, register, Op from .op import get, register, Op
# Operators # Operators
from .reduce import *
from .tensor import * from .tensor import *
from .transform import * from .transform import *
from . import nn from . import nn
......
"""Reduce operators."""
# pylint: disable=redefined-builtin
from . import _make
def argmax(data, axis=None, keepdims=False, exclude=False):
"""Returns the indices of the maximum values along an axis.
Parameters
----------
data : relay.Expr
The input data
axis : None or int or tuple of int
Axis or axes along which a argmin operation is performed.
The default, axis=None, will find the indices of maximum element all of the elements of
the input array. If axis is negative it counts from the last to the first axis.
keepdims : bool
If this is set to True, the axes which are reduced are left in the result as dimensions
with size one.
With this option, the result will broadcast correctly against the input array.
exclude : bool
If `exclude` is true, reduction will be performed on the axes that are
NOT in axis instead.
Returns
-------
result : relay.Expr
The computed result.
"""
return _make.argmax(data, axis, keepdims, exclude)
def argmin(data, axis=None, keepdims=False, exclude=False):
"""Returns the indices of the minimum values along an axis.
Parameters
----------
data : relay.Expr
The input data
axis : None or int or tuple of int
Axis or axes along which a argmin operation is performed.
The default, axis=None, will find the indices of minimum element all of the elements of
the input array. If axis is negative it counts from the last to the first axis.
keepdims : bool
If this is set to True, the axes which are reduced are left in the result as dimensions
with size one.
With this option, the result will broadcast correctly against the input array.
exclude : bool
If `exclude` is true, reduction will be performed on the axes that are
NOT in axis instead.
Returns
-------
result : relay.Expr
The computed result.
"""
return _make.argmin(data, axis, keepdims, exclude)
/*!
* Copyright (c) 2018 by Contributors
* \file reduce.cc
* \brief Reduction operators.
*/
#include <tvm/relay/expr.h>
#include <tvm/relay/op.h>
#include <numeric>
#include <limits>
#include "../type_relations.h"
namespace tvm {
namespace relay {
/*! \brief Attributes for Reduce operators */
struct ReduceAttrs : public tvm::AttrsNode<ReduceAttrs> {
Array<IndexExpr> axis;
bool keepdims;
bool exclude;
TVM_DECLARE_ATTRS(ReduceAttrs, "relay.attrs.ReduceAttrs") {
TVM_ATTR_FIELD(axis).set_default(Array<IndexExpr>({}))
.describe(R"code(The axis or axes along which to perform the reduction.
The default, `axis=()`, will compute over all elements into a
scalar array with shape `(1,)`.
If `axis` is int, a reduction is performed on a particular axis.
If `axis` is a tuple of ints, a reduction is performed on all the axes
specified in the tuple.
If `exclude` is true, reduction will be performed on the axes that are
NOT in axis instead.)code");
TVM_ATTR_FIELD(keepdims).set_default(false)
.describe("If this is set to `True`, the reduced axes are left "
"in the result as dimension with size one.");
TVM_ATTR_FIELD(exclude).set_default(false)
.describe("Whether to perform reduction on axis that are NOT in axis instead.");
}
};
/*!
* \brief GetReduceAxes, get the new axis from indim and other arguments
* \param indim Number of dimensions of input data.
* \param axis The input axis vector.
* \param exclude Whether 'axis' input given is the excluded axis.
* \return r_axes The new reduced axes of the output.
*/
inline std::vector<int64_t> GetReduceAxes(const uint32_t indim,
const Array<IndexExpr>& inaxis,
bool exclude) {
if (!inaxis.defined()) {
std::vector<int64_t> r_axes(indim);
std::iota(r_axes.begin(), r_axes.end(), 0);
return r_axes;
}
std::vector<int64_t> in_axes;
for (auto i : inaxis) {
const int64_t* k = as_const_int(i);
CHECK(k != nullptr) << "Reduce axis need to be constant, cannot be symbolic";
int64_t axis = k[0];
if (axis < 0) {
axis = axis + indim;
}
// Check out of bounds error
CHECK(axis >= 0)
<< "Axis out of bounds in reduce operator.";
CHECK(axis < indim)
<< "Axis out of bounds in reduce operator.";
in_axes.push_back(axis);
}
CHECK(in_axes[in_axes.size() - 1] < indim)
<< "Reduction axis " << in_axes[in_axes.size() - 1]
<< " exceeds input dimensions " << indim;
std::sort(in_axes.begin(), in_axes.end());
if (!exclude) {
return in_axes;
}
auto r_size = indim - in_axes.size();
std::vector<int64_t> r_axes(r_size);
for (uint32_t i = 0, j = 0, k = 0; i < indim; ++i) {
if (j < in_axes.size() && in_axes[j] == i) {
++j;
continue;
}
r_axes[k++] = i;
}
return r_axes;
}
/*!
* \brief ReduceShapeImpl get the outshape for the reduction operator
* \param in_shape Shape of input data.
* \param param ReduceAttrs details.
* \param reporter The reporter to report solution to.
* \return oshape Output shape inferred.
*/
inline std::vector<IndexExpr> ReduceShapeImpl(const std::vector<IndexExpr> &in_shape,
const ReduceAttrs* param,
const TypeReporter& reporter) {
uint32_t indim = in_shape.size();
auto r_axes = GetReduceAxes(indim, param->axis, param->exclude);
if (!r_axes.size()) {
return in_shape;
}
auto max_shape = make_const(Int(64), 1);
for (int64_t axis : r_axes) {
max_shape *= in_shape[axis];
}
CHECK(reporter->Assert(max_shape < make_const(Int(64), std::numeric_limits<int32_t>::max())))
<< "The maximum possible index of reduced shape cannot be more than int32 max.";
if (param->keepdims) {
std::vector<IndexExpr> oshape(in_shape);
for (unsigned i = 0, j = 0; i < indim; ++i) {
if (j >= r_axes.size() || !(r_axes[j] == i)) {
continue;
}
oshape[i] = 1;
++j;
}
return oshape;
} else {
auto osize = indim - r_axes.size();
std::vector<IndexExpr> oshape(osize);
for (unsigned i = 0, j = 0, k = 0; i < indim; ++i) {
if (j < r_axes.size() && (r_axes[j] == i)) {
++j;
continue;
}
oshape[k++] = in_shape[i];
}
return oshape;
}
}
/*!
* \brief ArgReduceRel Output type and shape relation evaluation function.
* \param num_inputs Number of input types in the args.
* \param attrs The additional attributes of the operator.
* \param reporter The reporter to report solution to.
* \return false if This relation cannot be resolved. true if this relation has been resolved.
*/
bool ArgReduceRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 2);
const auto* data = types[0].as<TensorTypeNode>();
if (data == nullptr) return false;
CHECK(static_cast<int>(data->shape.size()) != 0);
std::vector<IndexExpr> in_shape;
for (auto i : data->shape) {
in_shape.push_back(i);
}
const ReduceAttrs* param = attrs.as<ReduceAttrs>();
CHECK(param != nullptr);
// assign output type and shape
auto oshape = ReduceShapeImpl(in_shape, param, reporter);
reporter->Assign(types[1], TensorTypeNode::make(oshape, Int(32)));
return true;
}
#define RELAY_REGISTER_REDUCE_OP(OpName) \
TVM_REGISTER_API("relay.op._make." OpName) \
.set_body([](const TVMArgs& args, TVMRetValue* rv) { \
auto make_func = [](Expr data, \
Array<IndexExpr> axis, \
bool keepdims, \
bool exclude) { \
auto attrs = make_node<ReduceAttrs>(); \
attrs->axis = std::move(axis); \
attrs->keepdims = keepdims; \
attrs->exclude = exclude; \
static const Op& op = Op::Get(OpName); \
return CallNode::make(op, {data}, Attrs(attrs), {}); \
}; \
runtime::detail::unpack_call<Expr, 4>(make_func, args, rv); \
}); \
RELAY_REGISTER_OP(OpName) \
.set_num_inputs(1) \
.add_argument("data", "Tensor", "The input tensor.")
RELAY_REGISTER_REDUCE_OP("argmax")
.describe(R"code(Creates an operation that finds the indices of the maximum
values over a given axis.
)code" TVM_ADD_FILELINE)
.set_num_inputs(1)
.set_support_level(4)
.add_type_rel("ArgReduce", ArgReduceRel);
RELAY_REGISTER_REDUCE_OP("argmin")
.describe(R"code(Creates an operation that finds the indices of the minimum
values over a given axis.
)code" TVM_ADD_FILELINE)
.set_num_inputs(1)
.set_support_level(4)
.add_type_rel("ArgReduce", ArgReduceRel);
} // namespace relay
} // namespace tvm
...@@ -18,6 +18,13 @@ class TypeSolver::Reporter : public TypeReporterNode { ...@@ -18,6 +18,13 @@ class TypeSolver::Reporter : public TypeReporterNode {
solver_->Unify(dst, src); solver_->Unify(dst, src);
} }
bool Assert(const IndexExpr& cond) final {
if (const uint64_t* pdiff = as_const_uint(cond)) {
return pdiff[0];
}
return true;
}
bool AssertEQ(const IndexExpr& lhs, const IndexExpr& rhs) final { bool AssertEQ(const IndexExpr& lhs, const IndexExpr& rhs) final {
// early warning constant case. // early warning constant case.
IndexExpr diff = lhs - rhs; IndexExpr diff = lhs - rhs;
......
...@@ -93,6 +93,94 @@ def test_binary_broadcast(): ...@@ -93,6 +93,94 @@ def test_binary_broadcast():
ftype = func.checked_type ftype = func.checked_type
assert ftype.ret_type == relay.TensorType((5, 10, 4), "int32") assert ftype.ret_type == relay.TensorType((5, 10, 4), "int32")
def test_argmax():
ib = relay.ir_builder.IRBuilder()
n, c , h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w")
x = ib.param("x", relay.ty.TensorType((n, c , h, w), "float32"))
with ib.function(x) as func:
ib.ret(relay.argmax(x, axis=(1,)))
ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type
assert ftype.ret_type == relay.ty.TensorType((n, h, w), "int32")
ib = relay.ir_builder.IRBuilder()
n, c , h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w")
x = ib.param("x", relay.ty.TensorType((n, c , h, w), "float32"))
with ib.function(x) as func:
ib.ret(relay.argmax(x, axis=(2,), keepdims=True))
ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type
assert ftype.ret_type == relay.ty.TensorType((n, c , 1, w), "int32")
ib = relay.ir_builder.IRBuilder()
n, c , h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w")
x = ib.param("x", relay.ty.TensorType((n, c , h, w), "float32"))
with ib.function(x) as func:
ib.ret(relay.argmax(x, axis=(2,), keepdims=True, exclude=True))
ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type
assert ftype.ret_type == relay.ty.TensorType((1, 1 , h, 1), "int32")
def test_argmin():
ib = relay.ir_builder.IRBuilder()
n, c , h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w")
x = ib.param("x", relay.ty.TensorType((n, c , h, w), "float32"))
with ib.function(x) as func:
ib.ret(relay.argmax(x, axis=(1,)))
ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type
assert ftype.ret_type == relay.ty.TensorType((n, h, w), "int32")
ib = relay.ir_builder.IRBuilder()
n, c , h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w")
x = ib.param("x", relay.ty.TensorType((n, c , h, w), "float32"))
with ib.function(x) as func:
ib.ret(relay.argmin(x, axis=(2,), keepdims=True))
ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type
assert ftype.ret_type == relay.ty.TensorType((n, c , 1, w), "int32")
ib = relay.ir_builder.IRBuilder()
n, c , h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w")
x = ib.param("x", relay.ty.TensorType((n, c , h, w), "float32"))
with ib.function(x) as func:
ib.ret(relay.argmin(x, axis=(2,), keepdims=True, exclude=True))
ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type
assert ftype.ret_type == relay.ty.TensorType((1, 1 , h, 1), "int32")
ib = relay.ir_builder.IRBuilder()
n, c , h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w")
x = ib.param("x", relay.ty.TensorType((n, c , h, w), "float32"))
with ib.function(x) as func:
ib.ret(relay.argmin(x, axis=(2,1), keepdims=True, exclude=True))
ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type
assert ftype.ret_type == relay.ty.TensorType((1, c , h, 1), "int32")
ib = relay.ir_builder.IRBuilder()
n, c , h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w")
x = ib.param("x", relay.ty.TensorType((n, c , h, w), "float32"))
with ib.function(x) as func:
ib.ret(relay.argmin(x, axis=None, keepdims=True, exclude=True))
ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type
assert ftype.ret_type == relay.ty.TensorType((1, 1 , 1, 1), "int32")
def test_where(): def test_where():
ib = relay.ir_builder.IRBuilder() ib = relay.ir_builder.IRBuilder()
cond = ib.param("cond", relay.TensorType((3, 4), "float32")) cond = ib.param("cond", relay.TensorType((3, 4), "float32"))
...@@ -113,3 +201,5 @@ if __name__ == "__main__": ...@@ -113,3 +201,5 @@ if __name__ == "__main__":
test_binary_broadcast() test_binary_broadcast()
test_where() test_where()
test_multibox_prior() test_multibox_prior()
test_argmax()
test_argmin()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment