Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
f1782f3e
Commit
f1782f3e
authored
Oct 16, 2018
by
Yong Wu
Committed by
Yizhi Liu
Feb 08, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Add tf parser wrapper, infer shape automatically
parent
2da23bd8
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
51 additions
and
9 deletions
+51
-9
nnvm/python/nnvm/frontend/tensorflow.py
+19
-9
nnvm/python/nnvm/frontend/util/__init__.py
+0
-0
nnvm/python/nnvm/frontend/util/tensorflow_parser.py
+32
-0
No files found.
nnvm/python/nnvm/frontend/tensorflow.py
View file @
f1782f3e
...
...
@@ -1129,6 +1129,7 @@ class GraphProto(object):
self
.
_num_param
=
0
self
.
_num_rnn_layer
=
False
self
.
_outputs_are_0d
=
{}
self
.
_input_shapes
=
{}
def
from_tensorflow
(
self
,
graph
,
layout
=
"NHWC"
,
shape
=
None
,
outputs
=
None
):
"""Construct nnvm nodes from tensorflow graph definition - GraphDef.
...
...
@@ -1176,6 +1177,13 @@ class GraphProto(object):
if
missing_operators
:
raise
NotImplementedError
(
\
"The following operators are not implemented: {}"
.
format
(
missing_operators
))
for
node
in
graph
.
node
:
if
node
.
op
==
'Placeholder'
:
self
.
_input_shapes
[
node
.
name
]
=
tensor_util
.
TensorShapeProtoToList
(
node
.
attr
[
'shape'
]
.
shape
)
self
.
_input_shapes
[
node
.
name
][
0
]
=
1
elif
node
.
op
==
'Const'
:
tensor_value
=
node
.
attr
[
'value'
]
.
tensor
self
.
_input_shapes
[
node
.
name
]
=
tensor_util
.
TensorShapeProtoToList
(
tensor_value
.
tensor_shape
)
final_op
=
None
# Parse the nodes to re-create TF graph using Symbol API of NNVM
...
...
@@ -1189,10 +1197,9 @@ class GraphProto(object):
#Variable converted to Const will not have only value attr
if
'value'
in
attr
and
node
.
op
==
'Const'
:
tensor_value
=
attr
[
'value'
]
self
.
_output_shapes
[
node
.
name
]
=
\
[
tensor_util
.
TensorShapeProtoToList
(
\
tensor_value
.
tensor_shape
)]
self
.
_output_shapes
[
node
.
name
]
=
[
self
.
_input_shapes
[
node
.
name
]]
elif
node
.
op
==
'Placeholder'
:
self
.
_output_shapes
[
node
.
name
]
=
[
self
.
_input_shapes
[
node
.
name
]]
elif
shape
and
node
.
name
in
shape
:
# Give priority to user argument.
self
.
_output_shapes
[
node
.
name
]
=
[
shape
[
node
.
name
]]
...
...
@@ -1205,15 +1212,14 @@ class GraphProto(object):
# Actual value will be filled after node creation.
self
.
_output_shapes
[
node
.
name
]
=
[
None
]
else
:
raise
NotImplementedError
(
\
"Please freeze the graph with add_shapes=True"
)
self
.
_output_shapes
[
node
.
name
]
=
None
self
.
_outputs_are_0d
[
node
.
name
]
=
[
\
not
tshape
if
isinstance
(
tshape
,
list
)
else
False
\
for
tshape
in
self
.
_output_shapes
[
node
.
name
]]
if
node
.
op
==
"Placeholder"
:
self
.
_nodes
[
node
.
name
]
=
_sym
.
Variable
(
name
=
node
.
name
,
shape
=
self
.
_
output_shapes
[
node
.
name
][
0
])
shape
=
self
.
_
input_shapes
[
node
.
name
])
elif
node
.
op
==
"Const"
:
# All Const nodes are Param nodes, lets parse
...
...
@@ -1228,7 +1234,7 @@ class GraphProto(object):
else
:
# Pass the parsed shapes instead
attr
[
"_output_shapes"
]
=
self
.
_output_shapes
[
node
.
name
]
output_shapes
=
self
.
_output_shapes
[
node
.
name
]
# Pass the node name too in attr
attr
[
"_node_name"
]
=
node
.
name
...
...
@@ -1278,7 +1284,6 @@ class GraphProto(object):
# Assuming only one output.
self
.
_nodes
[
node
.
name
]
=
op
final_op
=
op
# Infer shapes if passed explicitely
node_output
=
self
.
_nodes
[
node
.
name
]
if
shape
:
...
...
@@ -1287,6 +1292,11 @@ class GraphProto(object):
shape_dict
.
update
(
shape
)
_
,
out_shapes
=
graph_util
.
infer_shape
(
g
,
**
shape_dict
)
self
.
_output_shapes
[
node
.
name
]
=
out_shapes
elif
output_shapes
==
None
:
g
=
_graph
.
create
(
node_output
)
self
.
_output_shapes
[
node
.
name
]
=
list
(
graph_util
.
infer_shape
(
g
,
**
self
.
_input_shapes
))[
-
1
]
else
:
self
.
_output_shapes
[
node
.
name
]
=
output_shapes
out
=
[]
if
outputs
is
None
:
...
...
nnvm/python/nnvm/frontend/util/__init__.py
0 → 100644
View file @
f1782f3e
nnvm/python/nnvm/frontend/util/tensorflow_parser.py
0 → 100644
View file @
f1782f3e
"""TF: Tensorflow parser"""
from
__future__
import
absolute_import
as
_abs
from
__future__
import
print_function
from
nnvm.frontend.protobuf
import
graph_pb2
class
TFParser
(
object
):
"""A Wrapper to handle tensorflow frozen model parsing
Works w/o installing tensorflow,
Protocol Buffer is needed
```
parser = TfParser(pb_file)
graph = parser.parse()
```
Parameters
----------
pb_file : tensorflow frozen pb file
The pb file should include both operations and tensors
"""
def
__init__
(
self
,
pb_file
):
self
.
_pb
=
pb_file
self
.
_graph
=
graph_pb2
.
GraphDef
()
def
_load_model
(
self
):
"""load frozen tensorflow model, return GraphDef """
with
open
(
self
.
_pb
,
"rb"
)
as
f
:
self
.
_graph
.
ParseFromString
(
f
.
read
())
def
parse
(
self
):
self
.
_load_model
()
return
self
.
_graph
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment