Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
6d1dc4ae
Commit
6d1dc4ae
authored
Aug 02, 2018
by
Sergey Mironov
Committed by
Tianqi Chen
Aug 02, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[NNVM] Support argmax/argmin in tensorflow frontend (#1514)
parent
71cff3e8
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
79 additions
and
8 deletions
+79
-8
nnvm/python/nnvm/frontend/tensorflow.py
+48
-8
nnvm/tests/python/frontend/tensorflow/test_forward.py
+31
-0
No files found.
nnvm/python/nnvm/frontend/tensorflow.py
View file @
6d1dc4ae
...
...
@@ -91,6 +91,20 @@ def _rsqrt():
return
AttrCvt
(
op_name
=
"__pow_scalar__"
,
extras
=
{
'scalar'
:
-
0.5
})(
inputs
,
attr
)
return
_impl
def
_argx
(
func
,
func_name
):
""" A common wrapper for argmin and argmax operations """
def
_impl
(
inputs
,
attr
,
params
):
try
:
# In Tensorflow, `axis` argument is a Tensor, not attribute. We
# support the case where it inputs from a scalar constant.
axis_input_name
=
inputs
[
1
]
.
list_output_names
()[
0
]
axis_input_vlaue
=
params
[
axis_input_name
]
.
asnumpy
()[
0
]
except
(
IndexError
,
KeyError
):
raise
TypeError
(
\
"Unsupported argument for `{}` : `axis` should be a constant"
.
format
(
func_name
))
return
func
(
inputs
[
0
],
axis
=
axis_input_vlaue
,
keepdims
=
False
)
return
_impl
def
_elemwise
(
name
):
def
_impl
(
inputs
,
attr
,
*
args
):
assert
len
(
inputs
)
==
2
,
"Math op take 2 inputs, {} given"
.
format
(
len
(
inputs
))
...
...
@@ -664,6 +678,8 @@ _identity_list = []
# for 1 to N mapping(composed), use custom callable functions
# for N to 1 mapping, currently not supported(?)
_convert_map
=
{
'ArgMax'
:
_argx
(
_sym
.
argmax
,
'argmax'
),
'ArgMin'
:
_argx
(
_sym
.
argmin
,
'argmin'
),
'AvgPool'
:
_pooling
(
'avg_pool'
),
'BatchNormWithGlobalNormalization'
:
_batch_norm
(),
'BiasAdd'
:
_bias_add
(),
...
...
@@ -879,6 +895,28 @@ class RecurrentNetworks(object):
params
,
num_layers
)
return
sym
def
_parse_import_prerequisites
(
graph
):
""" Calculate the named preconditions from TensorFlow `graph`.
Return prerequisites for parsing:
a. Set of operator names which don't have their mapping in TVM, i.e.
which are not supported
"""
missing_operators
=
set
()
for
node
in
graph
.
node
:
if
node
.
op
==
"Placeholder"
:
pass
elif
node
.
op
==
"Const"
:
pass
else
:
if
any
([
node
.
op
in
t
for
t
in
[
_identity_list
,
_convert_map
,
_convert_map_rnn
]]):
pass
else
:
missing_operators
.
add
(
node
.
op
)
return
missing_operators
class
GraphProto
(
object
):
""" A helper class for handling nnvm graph copying from Tensorflow GraphDef.
Definition:
...
...
@@ -901,7 +939,7 @@ class GraphProto(object):
Follow the tensorflow graph definition to parse and convert it to NNVM.
Some of the assumptions listed below.
-> First
Const or Placeholder
node will be considered as graph input.
-> First
Placeholder or Const
node will be considered as graph input.
-> Rest all Const nodes are params.
-> Last node is assumed as graph output.
-> _output_shapes : Attribute should present in the tenserflow forzen graph.
...
...
@@ -910,6 +948,7 @@ class GraphProto(object):
-> CheckNumerics: No implementation as of now for this.
Just copies input to output.
TODO: Change algorithm to stop treating first 'Const' in a special way.
Parameters
----------
...
...
@@ -923,10 +962,6 @@ class GraphProto(object):
params : dict
A dict of name: tvm.nd.array pairs, used as pretrained weights
"""
# Parse throught all nodes and start extracting
# params aka Const nodes
# input nodes : First const node
# normal nodes : other normal nodes
try
:
from
tensorflow.python.framework
import
tensor_util
...
...
@@ -934,12 +969,18 @@ class GraphProto(object):
raise
ImportError
(
"Unable to import tensorflow which is required {}"
.
format
(
e
))
missing_operators
=
_parse_import_prerequisites
(
graph
)
if
missing_operators
:
raise
NotImplementedError
(
\
"The following operators are not implemented: {}"
.
format
(
missing_operators
))
# Parse the nodes to re-create TF graph using Symbol API of NNVM
for
node
in
graph
.
node
:
# Tensorflow doesn't have seperate list for params extraction.
# Operator name 'Const' is treated as a parameter to build NNVM params dict.
input_shapes
=
{}
if
node
.
op
==
"Placeholder"
:
# Assuming only one input graph with type 'Placeholder'
self
.
_input_node
=
node
.
name
self
.
_num_input
+=
1
...
...
@@ -954,7 +995,6 @@ class GraphProto(object):
raise
NotImplementedError
(
\
"Please freeze the graph with add_shapes=True"
)
elif
node
.
op
==
"Const"
:
# Assuming first Const node as Graph Input node
if
self
.
_input_node
==
''
:
self
.
_input_node
=
node
.
name
self
.
_num_input
+=
1
...
...
@@ -997,7 +1037,7 @@ class GraphProto(object):
# Pass the node name too in attr
attr
[
"_node_name"
]
=
node
.
name
#ToDo: Some of the tensorflow operators
maintain
internaly maintain
#ToDo: Some of the tensorflow operators internaly maintain
#execution layers and its output name will the layer number along with
#graph node name.eg: Node name:- 'Model/RNN/cell_0/RnnCell', but the
#output name will be 'Model/RNN/cell_0/RnnCell:0'. In this case,
...
...
nnvm/tests/python/frontend/tensorflow/test_forward.py
View file @
6d1dc4ae
...
...
@@ -404,6 +404,37 @@ def test_forward_sigmoid():
_test_sigmoid
(
np
.
random
.
uniform
(
size
=
(
3
,
4
,
4
,
3
))
.
astype
(
'float32'
))
#######################################################################
# Argmin/Argmax
# -------------
def
_test_argx
(
func
,
data
,
**
kwargs
):
with
tf
.
Graph
()
.
as_default
():
inp
=
constant_op
.
constant
(
data
,
shape
=
data
.
shape
,
dtype
=
data
.
dtype
,
name
=
"c0"
)
# pylint: disable=unused-variable
out
=
func
(
inp
,
name
=
"argx0"
,
**
kwargs
)
# pylint: enable=unused-variable
with
tf
.
Session
()
as
sess
:
graph_def
=
tf
.
graph_util
.
convert_variables_to_constants
(
sess
=
sess
,
input_graph_def
=
sess
.
graph
.
as_graph_def
(
add_shapes
=
True
),
output_node_names
=
[
"argx0"
])
tf_output
=
run_tf_graph
(
sess
,
data
,
input_node
=
"c0:0"
,
output_node
=
"argx0:0"
)
tvm_output
=
run_tvm_graph
(
graph_def
,
data
,
"c0"
,
tf_output
.
shape
,
output_dtype
=
'int32'
)
np
.
testing
.
assert_allclose
(
tf_output
,
tvm_output
,
atol
=
1e-5
,
rtol
=
1e-5
)
sess
.
close
()
def
test_argmin_argmax
():
for
axis
in
[
None
,
0
,
1
,
2
]:
data
=
np
.
random
.
uniform
(
size
=
(
8
,
4
,
9
))
.
astype
(
'float32'
)
_test_argx
(
tf
.
argmax
,
data
=
data
,
axis
=
axis
)
_test_argx
(
tf
.
argmin
,
data
=
data
,
axis
=
axis
)
#######################################################################
# Variable
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment