Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
cf9db7ea
Commit
cf9db7ea
authored
Jul 25, 2018
by
Sergey Mironov
Committed by
Tianqi Chen
Jul 25, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[NNVM] Add argmax and argmin operations from topi (#1462)
parent
0fddc352
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
112 additions
and
9 deletions
+112
-9
nnvm/python/nnvm/top/reduction.py
+8
-0
nnvm/src/top/tensor/reduce.cc
+57
-0
nnvm/tests/python/compiler/test_top_level4.py
+47
-9
No files found.
nnvm/python/nnvm/top/reduction.py
View file @
cf9db7ea
...
...
@@ -41,3 +41,11 @@ reg.register_schedule("min", _fschedule_reduce)
# collapse sum
reg
.
register_pattern
(
"collapse_sum"
,
OpPattern
.
COMM_REDUCE
)
reg
.
register_schedule
(
"collapse_sum"
,
_fschedule_reduce
)
# argmax
reg
.
register_pattern
(
"argmax"
,
OpPattern
.
COMM_REDUCE
)
reg
.
register_schedule
(
"argmax"
,
_fschedule_reduce
)
# argmin
reg
.
register_pattern
(
"argmin"
,
OpPattern
.
COMM_REDUCE
)
reg
.
register_schedule
(
"argmin"
,
_fschedule_reduce
)
nnvm/src/top/tensor/reduce.cc
View file @
cf9db7ea
...
...
@@ -262,5 +262,62 @@ NNVM_REGISTER_BASE_REDUCE_OP(collapse_sum)
return
Array
<
Tensor
>
{
topi
::
collapse_sum
(
inputs
[
0
],
inputs
[
1
]
->
shape
)
};
});
template
<
int
Type
>
inline
bool
InferFixedType
(
const
NodeAttrs
&
attrs
,
std
::
vector
<
int
>*
in_attrs
,
std
::
vector
<
int
>*
out_attrs
)
{
// Static type inference for argmax operation. Argmax return indices which
// should have Int32 type as shapes do.
CHECK_EQ
(
in_attrs
->
size
(),
1U
);
CHECK_EQ
(
out_attrs
->
size
(),
1U
);
NNVM_ASSIGN_OUTPUT_TYPE
(
attrs
,
*
out_attrs
,
0
,
static_cast
<
int
>
(
Type
));
return
true
;
}
NNVM_REGISTER_BASE_REDUCE_OP
(
argmax
)
.
describe
(
R"code(Creates an operation that finds the indices of the maximum
values over a given axis.
)code"
NNVM_ADD_FILELINE
)
.
add_argument
(
"data"
,
"Tensor"
,
"The input"
)
.
set_attr
<
FInferShape
>
(
"FInferShape"
,
ReduceShape
)
.
set_attr
<
FInferType
>
(
"FInferType"
,
InferFixedType
<
kInt32
>
)
.
set_attr
<
FCorrectLayout
>
(
"FCorrectLayout"
,
ElemwiseFixedLayoutUnknownOut
<
1
,
1
>
)
.
set_num_inputs
(
1
)
.
set_attr
<
FTVMCompute
>
(
"FTVMCompute"
,
[](
const
NodeAttrs
&
attrs
,
const
Array
<
Tensor
>&
inputs
,
const
Array
<
Tensor
>&
out_info
)
{
const
ReduceParam
&
param
=
nnvm
::
get
<
ReduceParam
>
(
attrs
.
parsed
);
TShape
r_axes
=
GetReduceAxes
(
inputs
[
0
]
->
shape
.
size
(),
param
.
axis
,
param
.
exclude
);
auto
axis
=
ShapeToArray
(
r_axes
);
return
Array
<
Tensor
>
{
topi
::
argmax
(
inputs
[
0
],
axis
,
param
.
keepdims
)
};
});
NNVM_REGISTER_BASE_REDUCE_OP
(
argmin
)
.
describe
(
R"code(Creates an operation that finds the indices of the minimum
values over a given axis.
)code"
NNVM_ADD_FILELINE
)
.
add_argument
(
"data"
,
"Tensor"
,
"The input"
)
.
set_attr
<
FInferShape
>
(
"FInferShape"
,
ReduceShape
)
.
set_attr
<
FInferType
>
(
"FInferType"
,
InferFixedType
<
kInt32
>
)
.
set_attr
<
FCorrectLayout
>
(
"FCorrectLayout"
,
ElemwiseFixedLayoutUnknownOut
<
1
,
1
>
)
.
set_num_inputs
(
1
)
.
set_attr
<
FTVMCompute
>
(
"FTVMCompute"
,
[](
const
NodeAttrs
&
attrs
,
const
Array
<
Tensor
>&
inputs
,
const
Array
<
Tensor
>&
out_info
)
{
const
ReduceParam
&
param
=
nnvm
::
get
<
ReduceParam
>
(
attrs
.
parsed
);
TShape
r_axes
=
GetReduceAxes
(
inputs
[
0
]
->
shape
.
size
(),
param
.
axis
,
param
.
exclude
);
auto
axis
=
ShapeToArray
(
r_axes
);
return
Array
<
Tensor
>
{
topi
::
argmin
(
inputs
[
0
],
axis
,
param
.
keepdims
)
};
});
}
// namespace top
}
// namespace nnvm
nnvm/tests/python/compiler/test_top_level4.py
View file @
cf9db7ea
...
...
@@ -71,21 +71,27 @@ def verify_transpose(dshape, axes):
out
=
m
.
get_output
(
0
,
tvm
.
nd
.
empty
(
out_np
.
shape
))
np
.
testing
.
assert_allclose
(
out
.
asnumpy
(),
out_np
,
atol
=
1e-5
,
rtol
=
1e-5
)
def
verify_reduce
(
dshape
,
fnp
,
fsym
,
**
kwargs
):
def
verify_reduce_explicit
(
dshape
,
data
,
result
,
fsym
,
oshape
=
None
,
otype
=
'float32'
,
**
kwargs
):
""" Verify reduce operations by comparign its result with `result` """
x
=
sym
.
Variable
(
"x"
)
y
=
fsym
(
x
+
1
,
**
kwargs
)
dtype
=
"float32"
y
=
fsym
(
x
+
0
,
**
kwargs
)
for
target
,
ctx
in
ctx_list
():
graph
,
lib
,
_
=
nnvm
.
compiler
.
build
(
y
,
target
,
{
"x"
:
dshape
})
m
=
graph_runtime
.
create
(
graph
,
lib
,
ctx
)
# set input
data
=
np
.
random
.
uniform
(
size
=
dshape
)
.
astype
(
dtype
)
out_np
=
fnp
(
data
+
1
,
**
kwargs
)
m
.
run
(
x
=
data
)
out
=
m
.
get_output
(
0
,
tvm
.
nd
.
empty
(
out_np
.
shape
))
np
.
testing
.
assert_allclose
(
out
.
asnumpy
(),
out_np
,
atol
=
1e-5
,
rtol
=
1e-5
)
# oshape set to None means do not test the shape-correctness
oshape
=
result
.
shape
if
oshape
is
None
else
oshape
out
=
m
.
get_output
(
0
,
tvm
.
nd
.
empty
(
oshape
,
dtype
=
otype
))
np
.
testing
.
assert_equal
(
out
.
asnumpy
()
.
shape
,
result
.
shape
)
np
.
testing
.
assert_allclose
(
out
.
asnumpy
(),
result
,
atol
=
1e-5
,
rtol
=
1e-5
)
def
verify_reduce
(
dshape
,
fnp
,
fsym
,
oshape
=
None
,
otype
=
'float32'
,
**
kwargs
):
""" Verify reduce operations by generating data at random and calling numpy
version as reference """
data
=
np
.
random
.
uniform
(
size
=
dshape
)
.
astype
(
otype
)
result
=
fnp
(
data
+
0
,
**
kwargs
)
verify_reduce_explicit
(
dshape
,
data
,
result
,
fsym
,
oshape
=
oshape
,
otype
=
otype
,
**
kwargs
)
def
verify_collapse
(
dshape
,
target_shape
,
fnp
):
x
=
sym
.
Variable
(
"x"
,
shape
=
dshape
)
...
...
@@ -109,11 +115,43 @@ def test_transpose():
def
test_reduce
():
def
_with_keepdims
(
func
):
""" Wrapper around numpy's argmax/argmin with `keepdims` argument supported """
def
wrapper
(
data
,
axis
=
None
,
keepdims
=
False
):
if
not
keepdims
:
return
func
(
data
,
axis
=
axis
)
else
:
if
axis
is
not
None
:
out_shape
=
list
(
data
.
shape
)
out_shape
[
axis
]
=
1
else
:
out_shape
=
[
1
for
_
in
range
(
len
(
data
.
shape
))]
return
func
(
data
,
axis
=
axis
)
.
reshape
(
out_shape
)
return
wrapper
verify_reduce
((
2
,
3
,
4
),
np
.
max
,
sym
.
max
,
axis
=
1
,
keepdims
=
True
)
verify_reduce
((
4
,
4
,
3
),
np
.
min
,
sym
.
min
,
keepdims
=
True
)
verify_reduce
((
4
,
4
,
3
),
np
.
sum
,
sym
.
sum
,
axis
=
(
0
,
2
))
verify_reduce
((
4
,
4
,
3
),
np
.
sum
,
sym
.
sum
)
data
=
np
.
array
([[[
1
,
2
],[
3
,
4
]],[[
3
,
44
],[
5
,
6
]]],
dtype
=
np
.
float32
)
verify_reduce_explicit
([
2
,
2
,
2
],
data
,
np
.
array
([[
1
,
1
],[
1
,
0
]]),
sym
.
argmax
,
otype
=
'int32'
,
axis
=
[
0
,
2
],
exclude
=
True
)
verify_reduce_explicit
([
2
,
2
,
2
],
data
,
np
.
array
([[
0
,
0
],[
0
,
1
]]),
sym
.
argmin
,
otype
=
'int32'
,
axis
=
[
0
,
2
],
exclude
=
True
)
shape
=
[
4
,
4
,
3
]
for
axis
in
[
None
,
0
,
1
,
2
]:
for
keepdims
in
[
True
,
False
]:
kwargs
=
{
'keepdims'
:
keepdims
}
if
axis
is
None
:
# FIXME: NNVM doesn't support setting `axis=None` explicitly.
kwargs
.
update
({
'oshape'
:
[
1
,
1
,
1
]
if
keepdims
else
[]
})
else
:
kwargs
.
update
({
'axis'
:
axis
})
kwargs
.
update
({
'oshape'
:
shape
[:
axis
]
+
[
1
]
+
shape
[
axis
+
1
:]
if
keepdims
else
shape
[:
axis
]
+
shape
[
axis
+
1
:]})
verify_reduce
(
shape
,
_with_keepdims
(
np
.
argmax
),
sym
.
argmax
,
otype
=
'int32'
,
**
kwargs
)
verify_reduce
(
shape
,
_with_keepdims
(
np
.
argmin
),
sym
.
argmin
,
otype
=
'int32'
,
**
kwargs
)
def
test_collapse
():
verify_collapse
((
2
,
3
,
4
),
(
1
,),
lambda
x
:
x
.
sum
())
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment