Commit 77869913 by Ashutosh Parkhi Committed by Tianqi Chen

tensorflow frontend supports user given outputs (#1913)

parent 2c231b5a
...@@ -1039,7 +1039,7 @@ class GraphProto(object): ...@@ -1039,7 +1039,7 @@ class GraphProto(object):
self._num_param = 0 self._num_param = 0
self._num_rnn_layer = False self._num_rnn_layer = False
def from_tensorflow(self, graph, layout="NHWC", shape=None): def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
"""Construct nnvm nodes from tensorflow graph definition - GraphDef. """Construct nnvm nodes from tensorflow graph definition - GraphDef.
Follow the tensorflow graph definition to parse and convert it to NNVM. Follow the tensorflow graph definition to parse and convert it to NNVM.
...@@ -1086,6 +1086,7 @@ class GraphProto(object): ...@@ -1086,6 +1086,7 @@ class GraphProto(object):
raise NotImplementedError( \ raise NotImplementedError( \
"The following operators are not implemented: {}".format(missing_operators)) "The following operators are not implemented: {}".format(missing_operators))
final_op = None
# Parse the nodes to re-create TF graph using Symbol API of NNVM # Parse the nodes to re-create TF graph using Symbol API of NNVM
for node in graph.node: for node in graph.node:
# Tensorflow doesn't have seperate list for params extraction. # Tensorflow doesn't have seperate list for params extraction.
...@@ -1165,6 +1166,7 @@ class GraphProto(object): ...@@ -1165,6 +1166,7 @@ class GraphProto(object):
# Assuming only one output. # Assuming only one output.
self._nodes[node.name] = op self._nodes[node.name] = op
final_op = op
# Infer shapes if passed explicitely # Infer shapes if passed explicitely
node_output = self._nodes[node.name] node_output = self._nodes[node.name]
...@@ -1175,13 +1177,16 @@ class GraphProto(object): ...@@ -1175,13 +1177,16 @@ class GraphProto(object):
_, out_shapes = graph_util.infer_shape(g, **shape_dict) _, out_shapes = graph_util.infer_shape(g, **shape_dict)
self._output_shapes[node.name] = out_shapes self._output_shapes[node.name] = out_shapes
# Assume the final node is the output node out = []
out = node_output if outputs is None:
out.append(final_op)
else:
out = [self._nodes[out_name] for out_name in outputs]
#Add the RNN outputs also with 'head' nodes of the nnvm graph #Add the RNN outputs also with 'head' nodes of the nnvm graph
if self._num_rnn_layer: if self._num_rnn_layer:
out_rnn = _sym.concatenate(*self._out_rnn, axis=0) out_rnn = _sym.concatenate(*self._out_rnn, axis=0)
out = [out, out_rnn] out.append(out_rnn)
if isinstance(out, list): if isinstance(out, list):
out = _sym.Group(out) out = _sym.Group(out)
...@@ -1378,7 +1383,7 @@ class GraphProto(object): ...@@ -1378,7 +1383,7 @@ class GraphProto(object):
return inputs return inputs
def from_tensorflow(graph, layout="NHWC", shape=None): def from_tensorflow(graph, layout="NHWC", shape=None, outputs=None):
""" Load tensorflow graph which is a python tensorflow graph object into nnvm graph. """ Load tensorflow graph which is a python tensorflow graph object into nnvm graph.
The companion parameters will be handled automatically. The companion parameters will be handled automatically.
...@@ -1396,5 +1401,5 @@ def from_tensorflow(graph, layout="NHWC", shape=None): ...@@ -1396,5 +1401,5 @@ def from_tensorflow(graph, layout="NHWC", shape=None):
Dict of converted parameters stored in tvm.ndarray format Dict of converted parameters stored in tvm.ndarray format
""" """
g = GraphProto() g = GraphProto()
sym, params = g.from_tensorflow(graph, layout, shape) sym, params = g.from_tensorflow(graph, layout, shape, outputs)
return sym, params return sym, params
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment