Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
916576c0
Commit
916576c0
authored
Nov 19, 2018
by
Zhi Chen
Committed by
Yizhi Liu
Feb 08, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Support TensorFlow saved model
TF parser: return the consistent error message to error handler
parent
f1782f3e
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
165 additions
and
14 deletions
+165
-14
nnvm/python/nnvm/frontend/tensorflow.py
+7
-0
nnvm/python/nnvm/frontend/util/tensorflow_parser.py
+158
-14
No files found.
nnvm/python/nnvm/frontend/tensorflow.py
View file @
916576c0
...
@@ -355,6 +355,11 @@ def _matmul():
...
@@ -355,6 +355,11 @@ def _matmul():
return
_impl
return
_impl
def
_undef
():
def
_impl
(
inputs
,
attr
,
params
):
return
_sym
.
__undef__
()
return
_impl
def
_identity
():
def
_identity
():
def
_impl
(
inputs
,
attr
,
params
):
def
_impl
(
inputs
,
attr
,
params
):
return
inputs
[
0
]
return
inputs
[
0
]
...
@@ -933,6 +938,8 @@ _convert_map = {
...
@@ -933,6 +938,8 @@ _convert_map = {
'Split'
:
_split
(
False
),
'Split'
:
_split
(
False
),
'SplitV'
:
_split
(
True
),
'SplitV'
:
_split
(
True
),
'Unpack'
:
_unpack
(),
'Unpack'
:
_unpack
(),
'QueueDequeueManyV2'
:
_undef
(),
'FIFOQueueV2'
:
_undef
(),
}
}
# _convert_map_rnn defines maps of rnn operator name to
# _convert_map_rnn defines maps of rnn operator name to
...
...
nnvm/python/nnvm/frontend/util/tensorflow_parser.py
View file @
916576c0
"""TF: Tensorflow parser"""
"""TF: Tensorflow parser"""
from
__future__
import
absolute_import
as
_abs
from
__future__
import
absolute_import
as
_abs
from
__future__
import
print_function
from
__future__
import
print_function
from
nnvm.frontend.protobuf
import
graph_pb2
import
os
try
:
from
tensorflow.core.framework
import
graph_pb2
except
ImportError
as
e
:
from
nnvm.frontend.protobuf
import
graph_pb2
try
:
from
tempfile
import
TemporaryDirectory
except
ImportError
:
import
tempfile
import
shutil
class
TemporaryDirectory
(
object
):
def
__enter__
(
self
):
self
.
name
=
tempfile
.
mkdtemp
()
return
self
.
name
def
__exit__
(
self
,
exc
,
value
,
tb
):
shutil
.
rmtree
(
self
.
name
)
class
TFParser
(
object
):
class
TFParser
(
object
):
"""A Wrapper to handle tensorflow
frozen model
parsing
"""A Wrapper to handle tensorflow
models
parsing
Works w/o installing tensorflow,
Works w/o installing tensorflow,
Protocol Buffer is needed
Protocol Buffer is needed
```
```
parser = TfParser(
pb_file
)
parser = TfParser(
model_dir
)
graph = parser.parse()
graph = parser.parse()
```
```
Parameters
Parameters
----------
----------
pb_file : tensorflow frozen pb file
model_dir : tensorflow frozen pb file or a directory that contains saved
The pb file should include both operations and tensors
model or checkpoints.
"""
"""
def
__init__
(
self
,
pb_file
):
def
__init__
(
self
,
model_dir
):
self
.
_pb
=
pb_file
self
.
_tmp_dir
=
TemporaryDirectory
()
self
.
_model_dir
=
model_dir
self
.
_graph
=
graph_pb2
.
GraphDef
()
self
.
_graph
=
graph_pb2
.
GraphDef
()
def
_load_model
(
self
):
def
_set_graph
(
self
,
graph
):
"""load frozen tensorflow model, return GraphDef """
"""Set Graph"""
with
open
(
self
.
_pb
,
"rb"
)
as
f
:
self
.
_graph
=
graph
self
.
_graph
.
ParseFromString
(
f
.
read
())
def
_get_graph
(
self
):
"""Get Graph"""
return
self
.
_graph
def
_output_graph
(
self
):
import
logging
logging
.
basicConfig
(
level
=
logging
.
DEBUG
)
for
node
in
self
.
_get_graph
()
.
node
:
logging
.
info
(
"Name: {}"
.
format
(
node
.
name
))
logging
.
info
(
"
\t
op: {}"
.
format
(
node
.
op
))
for
input
in
node
.
input
:
logging
.
info
(
"
\t\t
input: {}"
.
format
(
input
))
logging
.
info
(
"
\t\t
device: {}"
.
format
(
node
.
device
))
logging
.
info
(
"
\t\t
AttrValue: "
)
for
key
in
node
.
attr
.
keys
():
logging
.
info
(
"
\t\t\t
key: {} => value: {}"
.
format
(
key
,
node
.
attr
[
key
]))
logging
.
info
(
node
.
attr
[
'shape'
]
.
shape
)
def
_load_pb_file
(
self
):
"""Load single pb file"""
graph
=
self
.
_get_graph
()
with
open
(
self
.
_model_dir
,
"rb"
)
as
f
:
graph
.
ParseFromString
(
f
.
read
())
return
graph
def
_get_output_names
(
self
,
model_path
):
"""Return the concatenated output names"""
try
:
import
tensorflow
as
tf
except
ImportError
as
e
:
raise
ImportError
(
"InputConfiguration: Unable to import tensorflow which is "
"required to restore from saved model. {}"
.
format
(
e
))
with
tf
.
Session
()
as
sess
:
meta_graph_def
=
tf
.
saved_model
.
loader
.
load
(
sess
,
[
tf
.
saved_model
.
tag_constants
.
SERVING
],
model_path
)
output_names
=
set
()
for
k
in
meta_graph_def
.
signature_def
.
keys
():
outputs_tensor_info
=
meta_graph_def
.
signature_def
[
k
]
.
outputs
for
output_tensor
in
outputs_tensor_info
.
values
():
output_names
.
add
(
output_tensor
.
name
)
output_names
=
[
i
.
replace
(
":0"
,
""
)
for
i
in
output_names
]
return
","
.
join
(
output_names
)
def
_load_saved_model
(
self
):
"""Load the tensorflow saved model."""
try
:
import
tensorflow
as
tf
from
tensorflow.python.tools
import
freeze_graph
from
tensorflow.python.framework
import
ops
from
tensorflow.python.framework
import
graph_util
except
ImportError
as
e
:
raise
ImportError
(
"InputConfiguration: Unable to import tensorflow which is "
"required to restore from saved model. {}"
.
format
(
e
))
saved_model_dir
=
self
.
_model_dir
output_graph_filename
=
os
.
path
.
join
(
self
.
_tmp_dir
.
name
,
"neo_frozen_model.pb"
)
input_saved_model_dir
=
saved_model_dir
output_node_names
=
self
.
_get_output_names
(
self
.
_model_dir
)
input_binary
=
False
input_saver_def_path
=
False
restore_op_name
=
None
filename_tensor_name
=
None
clear_devices
=
True
input_meta_graph
=
False
checkpoint_path
=
None
input_graph_filename
=
None
saved_model_tags
=
tf
.
saved_model
.
tag_constants
.
SERVING
freeze_graph
.
freeze_graph
(
input_graph_filename
,
input_saver_def_path
,
input_binary
,
checkpoint_path
,
output_node_names
,
restore_op_name
,
filename_tensor_name
,
output_graph_filename
,
clear_devices
,
""
,
""
,
""
,
input_meta_graph
,
input_saved_model_dir
,
saved_model_tags
)
with
ops
.
Graph
()
.
as_default
():
output_graph_def
=
graph_pb2
.
GraphDef
()
with
open
(
output_graph_filename
,
"rb"
)
as
f
:
output_graph_def
.
ParseFromString
(
f
.
read
())
output_graph_def
=
graph_util
.
remove_training_nodes
(
output_graph_def
)
return
output_graph_def
def
_load_ckpt
(
self
):
"""TODO: Load checkpoint model."""
raise
RuntimeError
(
"InputConfiguration: Loading tf checkpoint model is "
"not supported yet."
)
def
parse
(
self
):
def
parse
(
self
):
self
.
_load_model
()
"""Parse tensorflow models: checkpoints, saved models, and single pb
return
self
.
_graph
file.
\ No newline at end of file
"""
graph
=
None
if
os
.
path
.
isdir
(
self
.
_model_dir
):
ckpt
=
os
.
path
.
join
(
self
.
_model_dir
,
"checkpoint"
)
if
not
os
.
path
.
isfile
(
ckpt
):
if
not
os
.
path
.
isdir
(
os
.
path
.
join
(
self
.
_model_dir
,
"variables"
)):
raise
RuntimeError
(
"InputConfiguration: Invalid model path."
)
graph
=
self
.
_load_saved_model
()
else
:
graph
=
self
.
_load_ckpt
()
elif
os
.
path
.
isfile
(
self
.
_model_dir
):
# Only .pb or .pbtxt is a valid suffix name.
if
self
.
_model_dir
.
endswith
(
".pb"
)
or
\
self
.
_model_dir
.
endswith
(
".pbtxt"
):
cur_dir
=
os
.
path
.
dirname
(
self
.
_model_dir
)
else
:
raise
RuntimeError
(
"InputConfiguration: Invalid model format."
)
# It is a saved model if `variables` directory is present at the
# same directory with the pb or pbtxt file.
if
os
.
path
.
isdir
(
os
.
path
.
join
(
cur_dir
,
"variables"
)):
self
.
_model_dir
=
cur_dir
graph
=
self
.
_load_saved_model
()
else
:
graph
=
self
.
_load_pb_file
()
else
:
raise
RuntimeError
(
"InputConfiguration: Unrecognized model "
"file or path."
)
self
.
_set_graph
(
graph
)
return
graph
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment