Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
f347b525
Commit
f347b525
authored
Nov 24, 2018
by
Yong Wu
Committed by
Yizhi Liu
Feb 08, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Get tags of saved model automatically
Remove exception trail in tf parser error message Fix lint Fix comments
parent
916576c0
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
77 additions
and
79 deletions
+77
-79
nnvm/python/nnvm/frontend/tensorflow.py
+50
-29
nnvm/python/nnvm/frontend/util/tensorflow_parser.py
+27
-50
No files found.
nnvm/python/nnvm/frontend/tensorflow.py
View file @
f347b525
...
@@ -3,6 +3,7 @@
...
@@ -3,6 +3,7 @@
from
__future__
import
absolute_import
as
_abs
from
__future__
import
absolute_import
as
_abs
from
__future__
import
print_function
from
__future__
import
print_function
import
warnings
# Numpy support
# Numpy support
import
numpy
as
np
import
numpy
as
np
...
@@ -303,7 +304,8 @@ def _conv(opname):
...
@@ -303,7 +304,8 @@ def _conv(opname):
def
_decode_image
():
def
_decode_image
():
def
_impl
(
inputs
,
attr
,
params
):
def
_impl
(
inputs
,
attr
,
params
):
# Image decode wrapper: Expecting user to feed decoded input to next layer drop this layer.
# Image decode wrapper: Expecting user to feed decoded input to next layer drop this layer.
print
(
"DecodeJpeg: It's a pass through, please handle preprocessing before input"
)
warnings
.
warn
(
"DecodeJpeg: It's a pass through, "
"please handle preprocessing before input"
)
return
inputs
[
0
]
return
inputs
[
0
]
return
_impl
return
_impl
...
@@ -938,8 +940,6 @@ _convert_map = {
...
@@ -938,8 +940,6 @@ _convert_map = {
'Split'
:
_split
(
False
),
'Split'
:
_split
(
False
),
'SplitV'
:
_split
(
True
),
'SplitV'
:
_split
(
True
),
'Unpack'
:
_unpack
(),
'Unpack'
:
_unpack
(),
'QueueDequeueManyV2'
:
_undef
(),
'FIFOQueueV2'
:
_undef
(),
}
}
# _convert_map_rnn defines maps of rnn operator name to
# _convert_map_rnn defines maps of rnn operator name to
...
@@ -1184,42 +1184,57 @@ class GraphProto(object):
...
@@ -1184,42 +1184,57 @@ class GraphProto(object):
if
missing_operators
:
if
missing_operators
:
raise
NotImplementedError
(
\
raise
NotImplementedError
(
\
"The following operators are not implemented: {}"
.
format
(
missing_operators
))
"The following operators are not implemented: {}"
.
format
(
missing_operators
))
for
node
in
graph
.
node
:
for
node
in
graph
.
node
:
if
node
.
op
==
'Placeholder'
:
if
node
.
op
==
'Placeholder'
:
self
.
_input_shapes
[
node
.
name
]
=
tensor_util
.
TensorShapeProtoToList
(
node
.
attr
[
'shape'
]
.
shape
)
if
shape
and
node
.
name
in
shape
:
self
.
_input_shapes
[
node
.
name
][
0
]
=
1
self
.
_input_shapes
[
node
.
name
]
=
list
(
shape
[
node
.
name
])
continue
self
.
_input_shapes
[
node
.
name
]
=
\
tensor_util
.
TensorShapeProtoToList
(
node
.
attr
[
'shape'
]
.
shape
)
for
idx
,
dim
in
enumerate
(
self
.
_input_shapes
[
node
.
name
]):
if
dim
<
0
:
self
.
_input_shapes
[
node
.
name
][
idx
]
=
1
warnings
.
warn
(
"Use 1 instead of -1 in shape of operator
%
s."
%
node
.
name
)
# Ignore user's input shape for Non placeholder
elif
node
.
op
==
'Const'
:
elif
node
.
op
==
'Const'
:
tensor_value
=
node
.
attr
[
'value'
]
.
tensor
tensor_value
=
node
.
attr
[
'value'
]
.
tensor
self
.
_input_shapes
[
node
.
name
]
=
tensor_util
.
TensorShapeProtoToList
(
tensor_value
.
tensor_shape
)
self
.
_input_shapes
[
node
.
name
]
=
\
tensor_util
.
TensorShapeProtoToList
(
tensor_value
.
tensor_shape
)
if
shape
and
node
.
name
in
shape
:
warnings
.
warn
(
"Ignore the passed shape. "
"Shape in graphdef will be used for operator
%
s."
%
node
.
name
)
final_op
=
None
final_op
=
None
# Parse the nodes to re-create TF graph using Symbol API of NNVM
# Parse the nodes to re-create TF graph using Symbol API of NNVM
for
node
in
graph
.
node
:
for
node
in
graph
.
node
:
# Tensorflow doesn't have sep
e
rate list for params extraction.
# Tensorflow doesn't have sep
a
rate list for params extraction.
# Operator name 'Const' is treated as a parameter to build NNVM params dict.
# Operator name 'Const' is treated as a parameter to build NNVM params dict.
input_shapes
=
{}
input_shapes
=
{}
input_0d_mismatch
=
set
()
input_0d_mismatch
=
set
()
attr
=
self
.
_parse_attr
(
node
.
attr
)
attr
=
self
.
_parse_attr
(
node
.
attr
)
#Variable converted to Const will not have only value attr
#
Variable converted to Const will not have only value attr
if
'value'
in
attr
and
node
.
op
==
'Const'
:
if
'value'
in
attr
and
node
.
op
==
'Const'
:
self
.
_output_shapes
[
node
.
name
]
=
[
self
.
_input_shapes
[
node
.
name
]]
self
.
_output_shapes
[
node
.
name
]
=
[
self
.
_input_shapes
[
node
.
name
]]
elif
shape
and
node
.
name
in
shape
:
# Give priority to user argument.
self
.
_output_shapes
[
node
.
name
]
=
[
shape
[
node
.
name
]]
elif
node
.
op
==
'Placeholder'
:
elif
node
.
op
==
'Placeholder'
:
self
.
_output_shapes
[
node
.
name
]
=
[
self
.
_input_shapes
[
node
.
name
]]
self
.
_output_shapes
[
node
.
name
]
=
[
self
.
_input_shapes
[
node
.
name
]]
elif
shape
and
node
.
name
in
shape
:
# Give priority to user argument.
self
.
_output_shapes
[
node
.
name
]
=
[
shape
[
node
.
name
]]
elif
'_output_shapes'
in
attr
:
elif
'_output_shapes'
in
attr
:
self
.
_output_shapes
[
node
.
name
]
=
\
self
.
_output_shapes
[
node
.
name
]
=
\
[
tensor_util
.
TensorShapeProtoToList
(
tshape
)
\
[
tensor_util
.
TensorShapeProtoToList
(
tshape
)
\
for
tshape
in
attr
[
'_output_shapes'
]]
for
tshape
in
attr
[
'_output_shapes'
]]
el
if
shap
e
:
el
s
e
:
# Keep the list indexable to avoid key error.
# Keep the list indexable to avoid key error.
# Actual value will be filled after node creation.
# Actual value will be filled after node creation.
# Will infer shapes if the graph is not frozen with add_shapes=True
self
.
_output_shapes
[
node
.
name
]
=
[
None
]
self
.
_output_shapes
[
node
.
name
]
=
[
None
]
else
:
self
.
_output_shapes
[
node
.
name
]
=
None
self
.
_outputs_are_0d
[
node
.
name
]
=
[
\
self
.
_outputs_are_0d
[
node
.
name
]
=
[
\
not
tshape
if
isinstance
(
tshape
,
list
)
else
False
\
not
tshape
if
isinstance
(
tshape
,
list
)
else
False
\
for
tshape
in
self
.
_output_shapes
[
node
.
name
]]
for
tshape
in
self
.
_output_shapes
[
node
.
name
]]
...
@@ -1241,7 +1256,7 @@ class GraphProto(object):
...
@@ -1241,7 +1256,7 @@ class GraphProto(object):
else
:
else
:
# Pass the parsed shapes instead
# Pass the parsed shapes instead
output_shapes
=
self
.
_output_shapes
[
node
.
name
]
attr
[
"_output_shapes"
]
=
output_shapes
=
self
.
_output_shapes
[
node
.
name
]
# Pass the node name too in attr
# Pass the node name too in attr
attr
[
"_node_name"
]
=
node
.
name
attr
[
"_node_name"
]
=
node
.
name
...
@@ -1282,7 +1297,7 @@ class GraphProto(object):
...
@@ -1282,7 +1297,7 @@ class GraphProto(object):
inputs
=
self
.
_fix_extranodes
(
node
.
op
,
attr
,
inputs
)
inputs
=
self
.
_fix_extranodes
(
node
.
op
,
attr
,
inputs
)
op
=
self
.
_convert_operator
(
node
.
op
,
inputs
,
attr
,
graph
)
op
=
self
.
_convert_operator
(
node
.
op
,
inputs
,
attr
,
graph
)
# Check i
s
op is converted to param
# Check i
f
op is converted to param
if
isinstance
(
op
,
np
.
ndarray
):
if
isinstance
(
op
,
np
.
ndarray
):
self
.
_params
[
node
.
name
]
=
tvm
.
nd
.
array
(
op
)
self
.
_params
[
node
.
name
]
=
tvm
.
nd
.
array
(
op
)
op
=
_sym
.
Variable
(
name
=
node
.
name
,
op
=
_sym
.
Variable
(
name
=
node
.
name
,
...
@@ -1291,19 +1306,25 @@ class GraphProto(object):
...
@@ -1291,19 +1306,25 @@ class GraphProto(object):
# Assuming only one output.
# Assuming only one output.
self
.
_nodes
[
node
.
name
]
=
op
self
.
_nodes
[
node
.
name
]
=
op
final_op
=
op
final_op
=
op
# Infer shapes if passed explicitely
node_output
=
self
.
_nodes
[
node
.
name
]
# Infer shapes even without specifying "add_shapes=True"
if
shape
:
if
output_shapes
==
[
None
]:
g
=
_graph
.
create
(
node_output
)
g
=
_graph
.
create
(
final_op
)
shape_dict
=
{
k
:
v
.
shape
for
k
,
v
in
self
.
_params
.
items
()}
self
.
_output_shapes
[
node
.
name
]
=
\
shape_dict
.
update
(
shape
)
list
(
graph_util
.
infer_shape
(
g
,
**
self
.
_input_shapes
))[
-
1
]
_
,
out_shapes
=
graph_util
.
infer_shape
(
g
,
**
shape_dict
)
self
.
_output_shapes
[
node
.
name
]
=
out_shapes
if
self
.
_output_shapes
[
node
.
name
]
and
shape
and
node
.
name
in
shape
:
elif
output_shapes
==
None
:
assert
self
.
_output_shapes
[
node
.
name
]
==
list
(
shape
[
node
.
name
])
g
=
_graph
.
create
(
node_output
)
self
.
_output_shapes
[
node
.
name
]
=
list
(
graph_util
.
infer_shape
(
g
,
**
self
.
_input_shapes
))[
-
1
]
# Infer shapes if passed explicitely
else
:
node_output
=
self
.
_nodes
[
node
.
name
]
self
.
_output_shapes
[
node
.
name
]
=
output_shapes
if
shape
and
(
not
self
.
_output_shapes
[
node
.
name
][
0
]
or
-
1
in
self
.
_output_shapes
[
node
.
name
][
0
]):
g
=
_graph
.
create
(
node_output
)
shape_dict
=
{
k
:
v
.
shape
for
k
,
v
in
self
.
_params
.
items
()}
shape_dict
.
update
(
shape
)
_
,
out_shapes
=
graph_util
.
infer_shape
(
g
,
**
shape_dict
)
self
.
_output_shapes
[
node
.
name
]
=
out_shapes
out
=
[]
out
=
[]
if
outputs
is
None
:
if
outputs
is
None
:
...
...
nnvm/python/nnvm/frontend/util/tensorflow_parser.py
View file @
f347b525
...
@@ -2,32 +2,13 @@
...
@@ -2,32 +2,13 @@
from
__future__
import
absolute_import
as
_abs
from
__future__
import
absolute_import
as
_abs
from
__future__
import
print_function
from
__future__
import
print_function
import
os
import
os
from
tensorflow.core.framework
import
graph_pb2
try
:
from
tvm.contrib
import
util
from
tensorflow.core.framework
import
graph_pb2
except
ImportError
as
e
:
from
nnvm.frontend.protobuf
import
graph_pb2
try
:
from
tempfile
import
TemporaryDirectory
except
ImportError
:
import
tempfile
import
shutil
class
TemporaryDirectory
(
object
):
def
__enter__
(
self
):
self
.
name
=
tempfile
.
mkdtemp
()
return
self
.
name
def
__exit__
(
self
,
exc
,
value
,
tb
):
shutil
.
rmtree
(
self
.
name
)
class
TFParser
(
object
):
class
TFParser
(
object
):
"""A Wrapper to handle tensorflow models parsing
"""A Wrapper to handle tensorflow models parsing
Works w/o installing tensorflow,
TensorFlow is needed
Protocol Buffer is needed
```
```
parser = TfParser(model_dir)
parser = TfParser(model_dir)
graph = parser.parse()
graph = parser.parse()
...
@@ -39,7 +20,7 @@ class TFParser(object):
...
@@ -39,7 +20,7 @@ class TFParser(object):
"""
"""
def
__init__
(
self
,
model_dir
):
def
__init__
(
self
,
model_dir
):
self
.
_tmp_dir
=
TemporaryDirectory
()
self
.
_tmp_dir
=
util
.
tempdir
()
self
.
_model_dir
=
model_dir
self
.
_model_dir
=
model_dir
self
.
_graph
=
graph_pb2
.
GraphDef
()
self
.
_graph
=
graph_pb2
.
GraphDef
()
...
@@ -51,21 +32,6 @@ class TFParser(object):
...
@@ -51,21 +32,6 @@ class TFParser(object):
"""Get Graph"""
"""Get Graph"""
return
self
.
_graph
return
self
.
_graph
def
_output_graph
(
self
):
import
logging
logging
.
basicConfig
(
level
=
logging
.
DEBUG
)
for
node
in
self
.
_get_graph
()
.
node
:
logging
.
info
(
"Name: {}"
.
format
(
node
.
name
))
logging
.
info
(
"
\t
op: {}"
.
format
(
node
.
op
))
for
input
in
node
.
input
:
logging
.
info
(
"
\t\t
input: {}"
.
format
(
input
))
logging
.
info
(
"
\t\t
device: {}"
.
format
(
node
.
device
))
logging
.
info
(
"
\t\t
AttrValue: "
)
for
key
in
node
.
attr
.
keys
():
logging
.
info
(
"
\t\t\t
key: {} => value: {}"
.
format
(
key
,
node
.
attr
[
key
]))
logging
.
info
(
node
.
attr
[
'shape'
]
.
shape
)
def
_load_pb_file
(
self
):
def
_load_pb_file
(
self
):
"""Load single pb file"""
"""Load single pb file"""
graph
=
self
.
_get_graph
()
graph
=
self
.
_get_graph
()
...
@@ -73,19 +39,30 @@ class TFParser(object):
...
@@ -73,19 +39,30 @@ class TFParser(object):
graph
.
ParseFromString
(
f
.
read
())
graph
.
ParseFromString
(
f
.
read
())
return
graph
return
graph
def
_get_output_names
(
self
,
model_path
):
def
_get_tag_set
(
self
):
"""Return the tag set of saved model, multiple metagraphs are not supported"""
try
:
from
tensorflow.contrib.saved_model.python.saved_model
import
reader
except
ImportError
:
raise
ImportError
(
"InputConfiguration: Unable to import saved_model.reader which is "
"required to get tag set from saved model."
)
tag_sets
=
reader
.
get_saved_model_tag_sets
(
self
.
_model_dir
)
return
tag_sets
[
0
]
def
_get_output_names
(
self
):
"""Return the concatenated output names"""
"""Return the concatenated output names"""
try
:
try
:
import
tensorflow
as
tf
import
tensorflow
as
tf
except
ImportError
as
e
:
except
ImportError
:
raise
ImportError
(
raise
ImportError
(
"InputConfiguration: Unable to import tensorflow which is "
"InputConfiguration: Unable to import tensorflow which is "
"required to restore from saved model.
{}"
.
format
(
e
)
)
"required to restore from saved model.
"
)
tags
=
self
.
_get_tag_set
()
with
tf
.
Session
()
as
sess
:
with
tf
.
Session
()
as
sess
:
meta_graph_def
=
tf
.
saved_model
.
loader
.
load
(
sess
,
meta_graph_def
=
tf
.
saved_model
.
loader
.
load
(
sess
,
[
tf
.
saved_model
.
tag_constants
.
SERVING
]
,
tags
,
model_path
)
self
.
_model_dir
)
output_names
=
set
()
output_names
=
set
()
for
k
in
meta_graph_def
.
signature_def
.
keys
():
for
k
in
meta_graph_def
.
signature_def
.
keys
():
outputs_tensor_info
=
meta_graph_def
.
signature_def
[
k
]
.
outputs
outputs_tensor_info
=
meta_graph_def
.
signature_def
[
k
]
.
outputs
...
@@ -97,19 +74,18 @@ class TFParser(object):
...
@@ -97,19 +74,18 @@ class TFParser(object):
def
_load_saved_model
(
self
):
def
_load_saved_model
(
self
):
"""Load the tensorflow saved model."""
"""Load the tensorflow saved model."""
try
:
try
:
import
tensorflow
as
tf
from
tensorflow.python.tools
import
freeze_graph
from
tensorflow.python.tools
import
freeze_graph
from
tensorflow.python.framework
import
ops
from
tensorflow.python.framework
import
ops
from
tensorflow.python.framework
import
graph_util
from
tensorflow.python.framework
import
graph_util
except
ImportError
as
e
:
except
ImportError
:
raise
ImportError
(
raise
ImportError
(
"InputConfiguration: Unable to import tensorflow which is "
"InputConfiguration: Unable to import tensorflow which is "
"required to restore from saved model.
{}"
.
format
(
e
)
)
"required to restore from saved model.
"
)
saved_model_dir
=
self
.
_model_dir
saved_model_dir
=
self
.
_model_dir
output_graph_filename
=
os
.
path
.
join
(
self
.
_tmp_dir
.
name
,
"neo
_frozen_model.pb"
)
output_graph_filename
=
self
.
_tmp_dir
.
relpath
(
"tf
_frozen_model.pb"
)
input_saved_model_dir
=
saved_model_dir
input_saved_model_dir
=
saved_model_dir
output_node_names
=
self
.
_get_output_names
(
self
.
_model_dir
)
output_node_names
=
self
.
_get_output_names
()
input_binary
=
False
input_binary
=
False
input_saver_def_path
=
False
input_saver_def_path
=
False
...
@@ -119,7 +95,7 @@ class TFParser(object):
...
@@ -119,7 +95,7 @@ class TFParser(object):
input_meta_graph
=
False
input_meta_graph
=
False
checkpoint_path
=
None
checkpoint_path
=
None
input_graph_filename
=
None
input_graph_filename
=
None
saved_model_tags
=
tf
.
saved_model
.
tag_constants
.
SERVING
saved_model_tags
=
","
.
join
(
self
.
_get_tag_set
())
freeze_graph
.
freeze_graph
(
input_graph_filename
,
input_saver_def_path
,
freeze_graph
.
freeze_graph
(
input_graph_filename
,
input_saver_def_path
,
input_binary
,
checkpoint_path
,
output_node_names
,
input_binary
,
checkpoint_path
,
output_node_names
,
...
@@ -145,6 +121,7 @@ class TFParser(object):
...
@@ -145,6 +121,7 @@ class TFParser(object):
file.
file.
"""
"""
graph
=
None
graph
=
None
if
os
.
path
.
isdir
(
self
.
_model_dir
):
if
os
.
path
.
isdir
(
self
.
_model_dir
):
ckpt
=
os
.
path
.
join
(
self
.
_model_dir
,
"checkpoint"
)
ckpt
=
os
.
path
.
join
(
self
.
_model_dir
,
"checkpoint"
)
if
not
os
.
path
.
isfile
(
ckpt
):
if
not
os
.
path
.
isfile
(
ckpt
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment