Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
f347b525
Commit
f347b525
authored
Nov 24, 2018
by
Yong Wu
Committed by
Yizhi Liu
Feb 08, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Get tags of saved model automatically
Remove exception trail in tf parser error message Fix lint Fix comments
parent
916576c0
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
69 additions
and
71 deletions
+69
-71
nnvm/python/nnvm/frontend/tensorflow.py
+42
-21
nnvm/python/nnvm/frontend/util/tensorflow_parser.py
+27
-50
No files found.
nnvm/python/nnvm/frontend/tensorflow.py
View file @
f347b525
...
...
@@ -3,6 +3,7 @@
from
__future__
import
absolute_import
as
_abs
from
__future__
import
print_function
import
warnings
# Numpy support
import
numpy
as
np
...
...
@@ -303,7 +304,8 @@ def _conv(opname):
def
_decode_image
():
def
_impl
(
inputs
,
attr
,
params
):
# Image decode wrapper: Expecting user to feed decoded input to next layer drop this layer.
print
(
"DecodeJpeg: It's a pass through, please handle preprocessing before input"
)
warnings
.
warn
(
"DecodeJpeg: It's a pass through, "
"please handle preprocessing before input"
)
return
inputs
[
0
]
return
_impl
...
...
@@ -938,8 +940,6 @@ _convert_map = {
'Split'
:
_split
(
False
),
'SplitV'
:
_split
(
True
),
'Unpack'
:
_unpack
(),
'QueueDequeueManyV2'
:
_undef
(),
'FIFOQueueV2'
:
_undef
(),
}
# _convert_map_rnn defines maps of rnn operator name to
...
...
@@ -1184,42 +1184,57 @@ class GraphProto(object):
if
missing_operators
:
raise
NotImplementedError
(
\
"The following operators are not implemented: {}"
.
format
(
missing_operators
))
for
node
in
graph
.
node
:
if
node
.
op
==
'Placeholder'
:
self
.
_input_shapes
[
node
.
name
]
=
tensor_util
.
TensorShapeProtoToList
(
node
.
attr
[
'shape'
]
.
shape
)
self
.
_input_shapes
[
node
.
name
][
0
]
=
1
if
shape
and
node
.
name
in
shape
:
self
.
_input_shapes
[
node
.
name
]
=
list
(
shape
[
node
.
name
])
continue
self
.
_input_shapes
[
node
.
name
]
=
\
tensor_util
.
TensorShapeProtoToList
(
node
.
attr
[
'shape'
]
.
shape
)
for
idx
,
dim
in
enumerate
(
self
.
_input_shapes
[
node
.
name
]):
if
dim
<
0
:
self
.
_input_shapes
[
node
.
name
][
idx
]
=
1
warnings
.
warn
(
"Use 1 instead of -1 in shape of operator
%
s."
%
node
.
name
)
# Ignore user's input shape for Non placeholder
elif
node
.
op
==
'Const'
:
tensor_value
=
node
.
attr
[
'value'
]
.
tensor
self
.
_input_shapes
[
node
.
name
]
=
tensor_util
.
TensorShapeProtoToList
(
tensor_value
.
tensor_shape
)
self
.
_input_shapes
[
node
.
name
]
=
\
tensor_util
.
TensorShapeProtoToList
(
tensor_value
.
tensor_shape
)
if
shape
and
node
.
name
in
shape
:
warnings
.
warn
(
"Ignore the passed shape. "
"Shape in graphdef will be used for operator
%
s."
%
node
.
name
)
final_op
=
None
# Parse the nodes to re-create TF graph using Symbol API of NNVM
for
node
in
graph
.
node
:
# Tensorflow doesn't have sep
e
rate list for params extraction.
# Tensorflow doesn't have sep
a
rate list for params extraction.
# Operator name 'Const' is treated as a parameter to build NNVM params dict.
input_shapes
=
{}
input_0d_mismatch
=
set
()
attr
=
self
.
_parse_attr
(
node
.
attr
)
#Variable converted to Const will not have only value attr
#
Variable converted to Const will not have only value attr
if
'value'
in
attr
and
node
.
op
==
'Const'
:
self
.
_output_shapes
[
node
.
name
]
=
[
self
.
_input_shapes
[
node
.
name
]]
elif
node
.
op
==
'Placeholder'
:
self
.
_output_shapes
[
node
.
name
]
=
[
self
.
_input_shapes
[
node
.
name
]]
elif
shape
and
node
.
name
in
shape
:
# Give priority to user argument.
self
.
_output_shapes
[
node
.
name
]
=
[
shape
[
node
.
name
]]
elif
node
.
op
==
'Placeholder'
:
self
.
_output_shapes
[
node
.
name
]
=
[
self
.
_input_shapes
[
node
.
name
]]
elif
'_output_shapes'
in
attr
:
self
.
_output_shapes
[
node
.
name
]
=
\
[
tensor_util
.
TensorShapeProtoToList
(
tshape
)
\
for
tshape
in
attr
[
'_output_shapes'
]]
el
if
shap
e
:
el
s
e
:
# Keep the list indexable to avoid key error.
# Actual value will be filled after node creation.
# Will infer shapes if the graph is not frozen with add_shapes=True
self
.
_output_shapes
[
node
.
name
]
=
[
None
]
else
:
self
.
_output_shapes
[
node
.
name
]
=
None
self
.
_outputs_are_0d
[
node
.
name
]
=
[
\
not
tshape
if
isinstance
(
tshape
,
list
)
else
False
\
for
tshape
in
self
.
_output_shapes
[
node
.
name
]]
...
...
@@ -1241,7 +1256,7 @@ class GraphProto(object):
else
:
# Pass the parsed shapes instead
output_shapes
=
self
.
_output_shapes
[
node
.
name
]
attr
[
"_output_shapes"
]
=
output_shapes
=
self
.
_output_shapes
[
node
.
name
]
# Pass the node name too in attr
attr
[
"_node_name"
]
=
node
.
name
...
...
@@ -1282,7 +1297,7 @@ class GraphProto(object):
inputs
=
self
.
_fix_extranodes
(
node
.
op
,
attr
,
inputs
)
op
=
self
.
_convert_operator
(
node
.
op
,
inputs
,
attr
,
graph
)
# Check i
s
op is converted to param
# Check i
f
op is converted to param
if
isinstance
(
op
,
np
.
ndarray
):
self
.
_params
[
node
.
name
]
=
tvm
.
nd
.
array
(
op
)
op
=
_sym
.
Variable
(
name
=
node
.
name
,
...
...
@@ -1291,19 +1306,25 @@ class GraphProto(object):
# Assuming only one output.
self
.
_nodes
[
node
.
name
]
=
op
final_op
=
op
# Infer shapes even without specifying "add_shapes=True"
if
output_shapes
==
[
None
]:
g
=
_graph
.
create
(
final_op
)
self
.
_output_shapes
[
node
.
name
]
=
\
list
(
graph_util
.
infer_shape
(
g
,
**
self
.
_input_shapes
))[
-
1
]
if
self
.
_output_shapes
[
node
.
name
]
and
shape
and
node
.
name
in
shape
:
assert
self
.
_output_shapes
[
node
.
name
]
==
list
(
shape
[
node
.
name
])
# Infer shapes if passed explicitely
node_output
=
self
.
_nodes
[
node
.
name
]
if
shape
:
if
shape
and
(
not
self
.
_output_shapes
[
node
.
name
][
0
]
or
-
1
in
self
.
_output_shapes
[
node
.
name
][
0
]):
g
=
_graph
.
create
(
node_output
)
shape_dict
=
{
k
:
v
.
shape
for
k
,
v
in
self
.
_params
.
items
()}
shape_dict
.
update
(
shape
)
_
,
out_shapes
=
graph_util
.
infer_shape
(
g
,
**
shape_dict
)
self
.
_output_shapes
[
node
.
name
]
=
out_shapes
elif
output_shapes
==
None
:
g
=
_graph
.
create
(
node_output
)
self
.
_output_shapes
[
node
.
name
]
=
list
(
graph_util
.
infer_shape
(
g
,
**
self
.
_input_shapes
))[
-
1
]
else
:
self
.
_output_shapes
[
node
.
name
]
=
output_shapes
out
=
[]
if
outputs
is
None
:
...
...
nnvm/python/nnvm/frontend/util/tensorflow_parser.py
View file @
f347b525
...
...
@@ -2,32 +2,13 @@
from
__future__
import
absolute_import
as
_abs
from
__future__
import
print_function
import
os
try
:
from
tensorflow.core.framework
import
graph_pb2
except
ImportError
as
e
:
from
nnvm.frontend.protobuf
import
graph_pb2
try
:
from
tempfile
import
TemporaryDirectory
except
ImportError
:
import
tempfile
import
shutil
class
TemporaryDirectory
(
object
):
def
__enter__
(
self
):
self
.
name
=
tempfile
.
mkdtemp
()
return
self
.
name
def
__exit__
(
self
,
exc
,
value
,
tb
):
shutil
.
rmtree
(
self
.
name
)
from
tensorflow.core.framework
import
graph_pb2
from
tvm.contrib
import
util
class
TFParser
(
object
):
"""A Wrapper to handle tensorflow models parsing
Works w/o installing tensorflow,
Protocol Buffer is needed
TensorFlow is needed
```
parser = TfParser(model_dir)
graph = parser.parse()
...
...
@@ -39,7 +20,7 @@ class TFParser(object):
"""
def
__init__
(
self
,
model_dir
):
self
.
_tmp_dir
=
TemporaryDirectory
()
self
.
_tmp_dir
=
util
.
tempdir
()
self
.
_model_dir
=
model_dir
self
.
_graph
=
graph_pb2
.
GraphDef
()
...
...
@@ -51,21 +32,6 @@ class TFParser(object):
"""Get Graph"""
return
self
.
_graph
def
_output_graph
(
self
):
import
logging
logging
.
basicConfig
(
level
=
logging
.
DEBUG
)
for
node
in
self
.
_get_graph
()
.
node
:
logging
.
info
(
"Name: {}"
.
format
(
node
.
name
))
logging
.
info
(
"
\t
op: {}"
.
format
(
node
.
op
))
for
input
in
node
.
input
:
logging
.
info
(
"
\t\t
input: {}"
.
format
(
input
))
logging
.
info
(
"
\t\t
device: {}"
.
format
(
node
.
device
))
logging
.
info
(
"
\t\t
AttrValue: "
)
for
key
in
node
.
attr
.
keys
():
logging
.
info
(
"
\t\t\t
key: {} => value: {}"
.
format
(
key
,
node
.
attr
[
key
]))
logging
.
info
(
node
.
attr
[
'shape'
]
.
shape
)
def
_load_pb_file
(
self
):
"""Load single pb file"""
graph
=
self
.
_get_graph
()
...
...
@@ -73,19 +39,30 @@ class TFParser(object):
graph
.
ParseFromString
(
f
.
read
())
return
graph
def
_get_output_names
(
self
,
model_path
):
def
_get_tag_set
(
self
):
"""Return the tag set of saved model, multiple metagraphs are not supported"""
try
:
from
tensorflow.contrib.saved_model.python.saved_model
import
reader
except
ImportError
:
raise
ImportError
(
"InputConfiguration: Unable to import saved_model.reader which is "
"required to get tag set from saved model."
)
tag_sets
=
reader
.
get_saved_model_tag_sets
(
self
.
_model_dir
)
return
tag_sets
[
0
]
def
_get_output_names
(
self
):
"""Return the concatenated output names"""
try
:
import
tensorflow
as
tf
except
ImportError
as
e
:
except
ImportError
:
raise
ImportError
(
"InputConfiguration: Unable to import tensorflow which is "
"required to restore from saved model.
{}"
.
format
(
e
)
)
"required to restore from saved model.
"
)
tags
=
self
.
_get_tag_set
()
with
tf
.
Session
()
as
sess
:
meta_graph_def
=
tf
.
saved_model
.
loader
.
load
(
sess
,
[
tf
.
saved_model
.
tag_constants
.
SERVING
]
,
model_path
)
tags
,
self
.
_model_dir
)
output_names
=
set
()
for
k
in
meta_graph_def
.
signature_def
.
keys
():
outputs_tensor_info
=
meta_graph_def
.
signature_def
[
k
]
.
outputs
...
...
@@ -97,19 +74,18 @@ class TFParser(object):
def
_load_saved_model
(
self
):
"""Load the tensorflow saved model."""
try
:
import
tensorflow
as
tf
from
tensorflow.python.tools
import
freeze_graph
from
tensorflow.python.framework
import
ops
from
tensorflow.python.framework
import
graph_util
except
ImportError
as
e
:
except
ImportError
:
raise
ImportError
(
"InputConfiguration: Unable to import tensorflow which is "
"required to restore from saved model.
{}"
.
format
(
e
)
)
"required to restore from saved model.
"
)
saved_model_dir
=
self
.
_model_dir
output_graph_filename
=
os
.
path
.
join
(
self
.
_tmp_dir
.
name
,
"neo
_frozen_model.pb"
)
output_graph_filename
=
self
.
_tmp_dir
.
relpath
(
"tf
_frozen_model.pb"
)
input_saved_model_dir
=
saved_model_dir
output_node_names
=
self
.
_get_output_names
(
self
.
_model_dir
)
output_node_names
=
self
.
_get_output_names
()
input_binary
=
False
input_saver_def_path
=
False
...
...
@@ -119,7 +95,7 @@ class TFParser(object):
input_meta_graph
=
False
checkpoint_path
=
None
input_graph_filename
=
None
saved_model_tags
=
tf
.
saved_model
.
tag_constants
.
SERVING
saved_model_tags
=
","
.
join
(
self
.
_get_tag_set
())
freeze_graph
.
freeze_graph
(
input_graph_filename
,
input_saver_def_path
,
input_binary
,
checkpoint_path
,
output_node_names
,
...
...
@@ -145,6 +121,7 @@ class TFParser(object):
file.
"""
graph
=
None
if
os
.
path
.
isdir
(
self
.
_model_dir
):
ckpt
=
os
.
path
.
join
(
self
.
_model_dir
,
"checkpoint"
)
if
not
os
.
path
.
isfile
(
ckpt
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment