Commit f1782f3e by Yong Wu Committed by Yizhi Liu

Add tf parser wrapper, infer shape automatically

parent 2da23bd8
......@@ -1129,6 +1129,7 @@ class GraphProto(object):
self._num_param = 0
self._num_rnn_layer = False
self._outputs_are_0d = {}
self._input_shapes = {}
def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
"""Construct nnvm nodes from tensorflow graph definition - GraphDef.
......@@ -1176,6 +1177,13 @@ class GraphProto(object):
if missing_operators:
raise NotImplementedError( \
"The following operators are not implemented: {}".format(missing_operators))
for node in graph.node:
if node.op == 'Placeholder':
self._input_shapes[node.name] = tensor_util.TensorShapeProtoToList(node.attr['shape'].shape)
self._input_shapes[node.name][0] = 1
elif node.op == 'Const':
tensor_value = node.attr['value'].tensor
self._input_shapes[node.name] = tensor_util.TensorShapeProtoToList(tensor_value.tensor_shape)
final_op = None
# Parse the nodes to re-create TF graph using Symbol API of NNVM
......@@ -1189,10 +1197,9 @@ class GraphProto(object):
#Variable converted to Const will not have only value attr
if 'value' in attr and node.op == 'Const':
tensor_value = attr['value']
self._output_shapes[node.name] = \
[tensor_util.TensorShapeProtoToList( \
tensor_value.tensor_shape)]
self._output_shapes[node.name] = [self._input_shapes[node.name]]
elif node.op == 'Placeholder':
self._output_shapes[node.name] = [self._input_shapes[node.name]]
elif shape and node.name in shape:
# Give priority to user argument.
self._output_shapes[node.name] = [shape[node.name]]
......@@ -1205,15 +1212,14 @@ class GraphProto(object):
# Actual value will be filled after node creation.
self._output_shapes[node.name] = [None]
else:
raise NotImplementedError( \
"Please freeze the graph with add_shapes=True")
self._output_shapes[node.name] = None
self._outputs_are_0d[node.name] = [ \
not tshape if isinstance(tshape, list) else False \
for tshape in self._output_shapes[node.name]]
if node.op == "Placeholder":
self._nodes[node.name] = _sym.Variable(name=node.name,
shape=self._output_shapes[node.name][0])
shape=self._input_shapes[node.name])
elif node.op == "Const":
# All Const nodes are Param nodes, lets parse
......@@ -1228,7 +1234,7 @@ class GraphProto(object):
else:
# Pass the parsed shapes instead
attr["_output_shapes"] = self._output_shapes[node.name]
output_shapes = self._output_shapes[node.name]
# Pass the node name too in attr
attr["_node_name"] = node.name
......@@ -1278,7 +1284,6 @@ class GraphProto(object):
# Assuming only one output.
self._nodes[node.name] = op
final_op = op
# Infer shapes if passed explicitely
node_output = self._nodes[node.name]
if shape:
......@@ -1287,6 +1292,11 @@ class GraphProto(object):
shape_dict.update(shape)
_, out_shapes = graph_util.infer_shape(g, **shape_dict)
self._output_shapes[node.name] = out_shapes
elif output_shapes == None:
g = _graph.create(node_output)
self._output_shapes[node.name] = list(graph_util.infer_shape(g, **self._input_shapes))[-1]
else:
self._output_shapes[node.name] = output_shapes
out = []
if outputs is None:
......
"""TF: Tensorflow parser"""
from __future__ import absolute_import as _abs
from __future__ import print_function
from nnvm.frontend.protobuf import graph_pb2
class TFParser(object):
"""A Wrapper to handle tensorflow frozen model parsing
Works w/o installing tensorflow,
Protocol Buffer is needed
```
parser = TfParser(pb_file)
graph = parser.parse()
```
Parameters
----------
pb_file : tensorflow frozen pb file
The pb file should include both operations and tensors
"""
def __init__(self, pb_file):
self._pb = pb_file
self._graph = graph_pb2.GraphDef()
def _load_model(self):
"""load frozen tensorflow model, return GraphDef """
with open(self._pb, "rb") as f:
self._graph.ParseFromString(f.read())
def parse(self):
self._load_model()
return self._graph
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment