Commit 4001569e
Authored Oct 18, 2018 by Siju; committed by Tianqi Chen, Oct 18, 2018
[RELAY][OP]Reduction operator framework, argmax, argmin (#1865)
Parent: fd392677
Showing 8 changed files with 391 additions and 0 deletions
docs/langref/relay_op.rst             +4   -0
include/tvm/relay/type.h              +7   -0
python/tvm/relay/__init__.py          +1   -0
python/tvm/relay/op/__init__.py       +1   -0
python/tvm/relay/op/reduce.py         +64  -0
src/relay/op/tensor/reduce.cc         +217 -0
src/relay/pass/type_solver.cc         +7   -0
tests/python/relay/test_op_level4.py  +90  -0
docs/langref/relay_op.rst

...
@@ -106,6 +106,8 @@ This level enables additional math and transform operators.
    tvm.relay.minimum
    tvm.relay.pow
    tvm.relay.where
+   tvm.relay.argmax
+   tvm.relay.argmin

 **Level 5: Vision/Image Operators**
...
@@ -183,6 +185,8 @@ Level 4 Definitions
 .. autofunction:: tvm.relay.minimum
 .. autofunction:: tvm.relay.pow
 .. autofunction:: tvm.relay.where
+.. autofunction:: tvm.relay.argmax
+.. autofunction:: tvm.relay.argmin

 Level 5 Definitions
...
include/tvm/relay/type.h

...
@@ -270,6 +270,13 @@ class TypeReporterNode : public Node {
    */
   TVM_DLL virtual void Assign(const Type& dst, const Type& src) = 0;
+  /*!
+   * \brief Assert shape expression comparison.
+   * \param cond The condition of operation.
+   * \return false if the assertion can be proven to have failed,
+   *         true if the solver can still proceed.
+   */
+  TVM_DLL virtual bool Assert(const IndexExpr& cond) = 0;
   /*!
    * \brief assert shape expression equals each other.
    * \param lhs The left operand.
    * \param rhs The right operand.
...
python/tvm/relay/__init__.py

...
@@ -9,6 +9,7 @@ from . import ir_builder
 # Root operators
 from .op import Op
+from .op.reduce import *
 from .op.tensor import *
 from .op.transform import *
 from . import nn
...
python/tvm/relay/op/__init__.py

...
@@ -4,6 +4,7 @@
 from .op import get, register, Op
 # Operators
+from .reduce import *
 from .tensor import *
 from .transform import *
 from . import nn
...
python/tvm/relay/op/reduce.py (new file, mode 100644)

"""Reduce operators."""
# pylint: disable=redefined-builtin
from . import _make


def argmax(data, axis=None, keepdims=False, exclude=False):
    """Returns the indices of the maximum values along an axis.

    Parameters
    ----------
    data : relay.Expr
        The input data

    axis : None or int or tuple of int
        Axis or axes along which an argmax operation is performed.
        The default, axis=None, will find the index of the maximum element
        over all of the elements of the input array. If axis is negative
        it counts from the last to the first axis.

    keepdims : bool
        If this is set to True, the axes which are reduced are left in the
        result as dimensions with size one. With this option, the result
        will broadcast correctly against the input array.

    exclude : bool
        If `exclude` is true, reduction will be performed on the axes that
        are NOT in axis instead.

    Returns
    -------
    result : relay.Expr
        The computed result.
    """
    return _make.argmax(data, axis, keepdims, exclude)


def argmin(data, axis=None, keepdims=False, exclude=False):
    """Returns the indices of the minimum values along an axis.

    Parameters
    ----------
    data : relay.Expr
        The input data

    axis : None or int or tuple of int
        Axis or axes along which an argmin operation is performed.
        The default, axis=None, will find the index of the minimum element
        over all of the elements of the input array. If axis is negative
        it counts from the last to the first axis.

    keepdims : bool
        If this is set to True, the axes which are reduced are left in the
        result as dimensions with size one. With this option, the result
        will broadcast correctly against the input array.

    exclude : bool
        If `exclude` is true, reduction will be performed on the axes that
        are NOT in axis instead.

    Returns
    -------
    result : relay.Expr
        The computed result.
    """
    return _make.argmin(data, axis, keepdims, exclude)
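Both wrappers simply forward to the C++ `_make` bindings registered below. For intuition, their semantics track `numpy.argmax`/`numpy.argmin`; here is a minimal NumPy sketch of the single-axis behavior (illustrative only: `ref_argmax` is a hypothetical helper, not part of this commit, and tuple axes and `exclude` are omitted for brevity):

import numpy as np

def ref_argmax(data, axis=None, keepdims=False):
    # axis=None: index into the flattened array, a single output element.
    if axis is None:
        idx = np.argmax(data)
        return np.reshape(idx, (1,) * data.ndim) if keepdims else idx
    axis = axis % data.ndim            # negative axes count from the end
    idx = np.argmax(data, axis=axis)   # drops the reduced axis
    return np.expand_dims(idx, axis) if keepdims else idx

x = np.random.rand(5, 10, 4).astype("float32")
assert ref_argmax(x, axis=1).shape == (5, 4)
assert ref_argmax(x, axis=-2, keepdims=True).shape == (5, 1, 4)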
src/relay/op/tensor/reduce.cc (new file, mode 100644)

/*!
 * Copyright (c) 2018 by Contributors
 * \file reduce.cc
 * \brief Reduction operators.
 */
#include <tvm/relay/expr.h>
#include <tvm/relay/op.h>
#include <numeric>
#include <limits>
#include "../type_relations.h"

namespace tvm {
namespace relay {
/*! \brief Attributes for Reduce operators */
struct ReduceAttrs : public tvm::AttrsNode<ReduceAttrs> {
  Array<IndexExpr> axis;
  bool keepdims;
  bool exclude;

  TVM_DECLARE_ATTRS(ReduceAttrs, "relay.attrs.ReduceAttrs") {
    TVM_ATTR_FIELD(axis).set_default(Array<IndexExpr>({}))
      .describe(R"code(The axis or axes along which to perform the reduction.

      The default, `axis=()`, will compute over all elements into a
      scalar array with shape `(1,)`.

      If `axis` is int, a reduction is performed on a particular axis.

      If `axis` is a tuple of ints, a reduction is performed on all the axes
      specified in the tuple.

      If `exclude` is true, reduction will be performed on the axes that are
      NOT in axis instead.)code");

    TVM_ATTR_FIELD(keepdims).set_default(false)
      .describe("If this is set to `True`, the reduced axes are left "
                "in the result as dimension with size one.");

    TVM_ATTR_FIELD(exclude).set_default(false)
      .describe("Whether to perform reduction on axis that are NOT in axis instead.");
  }
};
/*!
 * \brief GetReduceAxes, get the new axis from indim and other arguments
 * \param indim Number of dimensions of input data.
 * \param inaxis The input axis vector.
 * \param exclude Whether 'axis' input given is the excluded axis.
 * \return r_axes The new reduced axes of the output.
 */
inline std::vector<int64_t> GetReduceAxes(const uint32_t indim,
                                          const Array<IndexExpr>& inaxis,
                                          bool exclude) {
  if (!inaxis.defined()) {
    // No axis given: reduce over every dimension.
    std::vector<int64_t> r_axes(indim);
    std::iota(r_axes.begin(), r_axes.end(), 0);
    return r_axes;
  }

  std::vector<int64_t> in_axes;
  for (auto i : inaxis) {
    const int64_t* k = as_const_int(i);
    CHECK(k != nullptr) << "Reduce axis need to be constant, cannot be symbolic";
    int64_t axis = k[0];
    if (axis < 0) {
      axis = axis + indim;
    }

    // Check out of bounds error
    CHECK(axis >= 0) << "Axis out of bounds in reduce operator.";
    CHECK(axis < indim) << "Axis out of bounds in reduce operator.";
    in_axes.push_back(axis);
  }

  CHECK(in_axes[in_axes.size() - 1] < indim)
    << "Reduction axis " << in_axes[in_axes.size() - 1]
    << " exceeds input dimensions " << indim;

  std::sort(in_axes.begin(), in_axes.end());

  if (!exclude) {
    return in_axes;
  }

  // exclude = true: reduce over the complement of the given axes.
  auto r_size = indim - in_axes.size();
  std::vector<int64_t> r_axes(r_size);
  for (uint32_t i = 0, j = 0, k = 0; i < indim; ++i) {
    if (j < in_axes.size() && in_axes[j] == i) {
      ++j;
      continue;
    }
    r_axes[k++] = i;
  }
  return r_axes;
}
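// For intuition (illustrative values, not lines from this commit): with
// indim = 4,
//   GetReduceAxes(4, {2},  /*exclude=*/false) -> {2}
//   GetReduceAxes(4, {-1}, /*exclude=*/false) -> {3}          (negative axis wraps)
//   GetReduceAxes(4, {2},  /*exclude=*/true)  -> {0, 1, 3}    (complement)
//   GetReduceAxes(4, <undefined>, ...)        -> {0, 1, 2, 3} (reduce everything)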
/*!
 * \brief ReduceShapeImpl get the outshape for the reduction operator
 * \param in_shape Shape of input data.
 * \param param ReduceAttrs details.
 * \param reporter The reporter to report solution to.
 * \return oshape Output shape inferred.
 */
inline std::vector<IndexExpr> ReduceShapeImpl(const std::vector<IndexExpr>& in_shape,
                                              const ReduceAttrs* param,
                                              const TypeReporter& reporter) {
  uint32_t indim = in_shape.size();
  auto r_axes = GetReduceAxes(indim, param->axis, param->exclude);
  if (!r_axes.size()) {
    return in_shape;
  }

  // Guard against index overflow: the flattened extent of the reduced axes
  // must fit in int32, since the result indices are stored as int32.
  auto max_shape = make_const(Int(64), 1);
  for (int64_t axis : r_axes) {
    max_shape *= in_shape[axis];
  }
  CHECK(reporter->Assert(max_shape < make_const(Int(64), std::numeric_limits<int32_t>::max())))
    << "The maximum possible index of reduced shape cannot be more than int32 max.";

  if (param->keepdims) {
    // Keep reduced axes in the output with extent 1.
    std::vector<IndexExpr> oshape(in_shape);
    for (unsigned i = 0, j = 0; i < indim; ++i) {
      if (j >= r_axes.size() || !(r_axes[j] == i)) {
        continue;
      }
      oshape[i] = 1;
      ++j;
    }
    return oshape;
  } else {
    // Drop reduced axes from the output shape.
    auto osize = indim - r_axes.size();
    std::vector<IndexExpr> oshape(osize);
    for (unsigned i = 0, j = 0, k = 0; i < indim; ++i) {
      if (j < r_axes.size() && (r_axes[j] == i)) {
        ++j;
        continue;
      }
      oshape[k++] = in_shape[i];
    }
    return oshape;
  }
}
/*!
 * \brief ArgReduceRel Output type and shape relation evaluation function.
 * \param num_inputs Number of input types in the args.
 * \param attrs The additional attributes of the operator.
 * \param reporter The reporter to report solution to.
 * \return false if this relation cannot be resolved, true if it has been resolved.
 */
bool ArgReduceRel(const Array<Type>& types,
                  int num_inputs,
                  const Attrs& attrs,
                  const TypeReporter& reporter) {
  CHECK_EQ(types.size(), 2);
  const auto* data = types[0].as<TensorTypeNode>();
  if (data == nullptr) return false;
  CHECK(static_cast<int>(data->shape.size()) != 0);
  std::vector<IndexExpr> in_shape;
  for (auto i : data->shape) {
    in_shape.push_back(i);
  }

  const ReduceAttrs* param = attrs.as<ReduceAttrs>();
  CHECK(param != nullptr);

  // assign output type and shape: arg-reductions always yield int32 indices.
  auto oshape = ReduceShapeImpl(in_shape, param, reporter);
  reporter->Assign(types[1], TensorTypeNode::make(oshape, Int(32)));
  return true;
}
#define RELAY_REGISTER_REDUCE_OP(OpName) \
TVM_REGISTER_API("relay.op._make." OpName) \
.set_body([](const TVMArgs& args, TVMRetValue* rv) { \
auto make_func = [](Expr data, \
Array<IndexExpr> axis, \
bool keepdims, \
bool exclude) { \
auto attrs = make_node<ReduceAttrs>(); \
attrs->axis = std::move(axis); \
attrs->keepdims = keepdims; \
attrs->exclude = exclude; \
static const Op& op = Op::Get(OpName); \
return CallNode::make(op, {data}, Attrs(attrs), {}); \
}; \
runtime::detail::unpack_call<Expr, 4>(make_func, args, rv); \
}); \
RELAY_REGISTER_OP(OpName) \
.set_num_inputs(1) \
.add_argument("data", "Tensor", "The input tensor.")
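// Note: each RELAY_REGISTER_REDUCE_OP invocation does two things at once. It
// exposes a packed function under relay.op._make.<OpName> (what the Python
// wrappers in reduce.py call), and it opens the operator registration, which
// the chained .describe()/.set_support_level()/.add_type_rel() calls below
// complete. Adding another reduce operator to this framework only requires
// another invocation plus its Python wrapper.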
RELAY_REGISTER_REDUCE_OP("argmax")
.describe(R"code(Creates an operation that finds the indices of the maximum
values over a given axis.

)code" TVM_ADD_FILELINE)
.set_num_inputs(1)
.set_support_level(4)
.add_type_rel("ArgReduce", ArgReduceRel);


RELAY_REGISTER_REDUCE_OP("argmin")
.describe(R"code(Creates an operation that finds the indices of the minimum
values over a given axis.

)code" TVM_ADD_FILELINE)
.set_num_inputs(1)
.set_support_level(4)
.add_type_rel("ArgReduce", ArgReduceRel);

}  // namespace relay
}  // namespace tvm
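Taken together, GetReduceAxes and ReduceShapeImpl determine the output shape that ArgReduceRel assigns. The following Python sketch of that shape rule uses `reduce_shape` as a hypothetical stand-in (not code from this commit), cross-checked against the shapes the tests below assert:

def reduce_shape(in_shape, axis=None, keepdims=False, exclude=False):
    ndim = len(in_shape)
    if axis is None:
        r_axes = list(range(ndim))                # reduce over every axis
    else:
        r_axes = sorted(a % ndim for a in axis)   # normalize negative axes
        if exclude:
            r_axes = [i for i in range(ndim) if i not in r_axes]
    if keepdims:
        return tuple(1 if i in r_axes else d for i, d in enumerate(in_shape))
    return tuple(d for i, d in enumerate(in_shape) if i not in r_axes)

# Shapes asserted in tests/python/relay/test_op_level4.py:
assert reduce_shape(("n", "c", "h", "w"), axis=(1,)) == ("n", "h", "w")
assert reduce_shape(("n", "c", "h", "w"), axis=(2,), keepdims=True) == ("n", "c", 1, "w")
assert reduce_shape(("n", "c", "h", "w"), axis=(2,), keepdims=True, exclude=True) == (1, 1, "h", 1)
assert reduce_shape(("n", "c", "h", "w"), axis=None, keepdims=True, exclude=True) == (1, 1, 1, 1)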
src/relay/pass/type_solver.cc

...
@@ -18,6 +18,13 @@ class TypeSolver::Reporter : public TypeReporterNode {
     solver_->Unify(dst, src);
   }
 
+  bool Assert(const IndexExpr& cond) final {
+    if (const uint64_t* pdiff = as_const_uint(cond)) {
+      return pdiff[0];
+    }
+    return true;
+  }
+
   bool AssertEQ(const IndexExpr& lhs, const IndexExpr& rhs) final {
     // early warning constant case.
     IndexExpr diff = lhs - rhs;
...
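One subtlety: Assert only fails when the condition folds to a constant false; a still-symbolic condition lets the solver proceed, which is why ReduceShapeImpl can assert an overflow bound on shapes containing free variables. A rough Python sketch of that decision (illustrative only, with a bool-or-None argument standing in for as_const_uint succeeding or failing):

def assert_cond(cond):
    # cond is True/False when constant-folded, None when still symbolic.
    if cond is not None:
        return bool(cond)      # provably true or provably false
    return True                # symbolic: let the solver proceed

assert assert_cond(True) and assert_cond(None)
assert not assert_cond(False)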
tests/python/relay/test_op_level4.py

...
@@ -93,6 +93,94 @@ def test_binary_broadcast():
    ftype = func.checked_type
    assert ftype.ret_type == relay.TensorType((5, 10, 4), "int32")

def test_argmax():
    ib = relay.ir_builder.IRBuilder()
    n, c, h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w")
    x = ib.param("x", relay.ty.TensorType((n, c, h, w), "float32"))
    with ib.function(x) as func:
        ib.ret(relay.argmax(x, axis=(1,)))
    ib.ret(func)
    func = relay.ir_pass.infer_type(ib.env, func.to_func())
    ftype = func.checked_type
    assert ftype.ret_type == relay.ty.TensorType((n, h, w), "int32")

    ib = relay.ir_builder.IRBuilder()
    n, c, h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w")
    x = ib.param("x", relay.ty.TensorType((n, c, h, w), "float32"))
    with ib.function(x) as func:
        ib.ret(relay.argmax(x, axis=(2,), keepdims=True))
    ib.ret(func)
    func = relay.ir_pass.infer_type(ib.env, func.to_func())
    ftype = func.checked_type
    assert ftype.ret_type == relay.ty.TensorType((n, c, 1, w), "int32")

    ib = relay.ir_builder.IRBuilder()
    n, c, h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w")
    x = ib.param("x", relay.ty.TensorType((n, c, h, w), "float32"))
    with ib.function(x) as func:
        ib.ret(relay.argmax(x, axis=(2,), keepdims=True, exclude=True))
    ib.ret(func)
    func = relay.ir_pass.infer_type(ib.env, func.to_func())
    ftype = func.checked_type
    assert ftype.ret_type == relay.ty.TensorType((1, 1, h, 1), "int32")

def test_argmin():
    ib = relay.ir_builder.IRBuilder()
    n, c, h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w")
    x = ib.param("x", relay.ty.TensorType((n, c, h, w), "float32"))
    with ib.function(x) as func:
        ib.ret(relay.argmin(x, axis=(1,)))
    ib.ret(func)
    func = relay.ir_pass.infer_type(ib.env, func.to_func())
    ftype = func.checked_type
    assert ftype.ret_type == relay.ty.TensorType((n, h, w), "int32")

    ib = relay.ir_builder.IRBuilder()
    n, c, h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w")
    x = ib.param("x", relay.ty.TensorType((n, c, h, w), "float32"))
    with ib.function(x) as func:
        ib.ret(relay.argmin(x, axis=(2,), keepdims=True))
    ib.ret(func)
    func = relay.ir_pass.infer_type(ib.env, func.to_func())
    ftype = func.checked_type
    assert ftype.ret_type == relay.ty.TensorType((n, c, 1, w), "int32")

    ib = relay.ir_builder.IRBuilder()
    n, c, h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w")
    x = ib.param("x", relay.ty.TensorType((n, c, h, w), "float32"))
    with ib.function(x) as func:
        ib.ret(relay.argmin(x, axis=(2,), keepdims=True, exclude=True))
    ib.ret(func)
    func = relay.ir_pass.infer_type(ib.env, func.to_func())
    ftype = func.checked_type
    assert ftype.ret_type == relay.ty.TensorType((1, 1, h, 1), "int32")

    ib = relay.ir_builder.IRBuilder()
    n, c, h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w")
    x = ib.param("x", relay.ty.TensorType((n, c, h, w), "float32"))
    with ib.function(x) as func:
        ib.ret(relay.argmin(x, axis=(2, 1), keepdims=True, exclude=True))
    ib.ret(func)
    func = relay.ir_pass.infer_type(ib.env, func.to_func())
    ftype = func.checked_type
    assert ftype.ret_type == relay.ty.TensorType((1, c, h, 1), "int32")

    ib = relay.ir_builder.IRBuilder()
    n, c, h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w")
    x = ib.param("x", relay.ty.TensorType((n, c, h, w), "float32"))
    with ib.function(x) as func:
        ib.ret(relay.argmin(x, axis=None, keepdims=True, exclude=True))
    ib.ret(func)
    func = relay.ir_pass.infer_type(ib.env, func.to_func())
    ftype = func.checked_type
    assert ftype.ret_type == relay.ty.TensorType((1, 1, 1, 1), "int32")

def test_where():
    ib = relay.ir_builder.IRBuilder()
    cond = ib.param("cond", relay.TensorType((3, 4), "float32"))
...
@@ -113,3 +201,5 @@ if __name__ == "__main__":
     test_binary_broadcast()
     test_where()
     test_multibox_prior()
+    test_argmax()
+    test_argmin()