Commit 77869913 by Ashutosh Parkhi Committed by Tianqi Chen

tensorflow frontend supports user given outputs (#1913)

parent 2c231b5a
......@@ -1039,7 +1039,7 @@ class GraphProto(object):
self._num_param = 0
self._num_rnn_layer = False
def from_tensorflow(self, graph, layout="NHWC", shape=None):
def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
"""Construct nnvm nodes from tensorflow graph definition - GraphDef.
Follow the tensorflow graph definition to parse and convert it to NNVM.
......@@ -1086,6 +1086,7 @@ class GraphProto(object):
raise NotImplementedError( \
"The following operators are not implemented: {}".format(missing_operators))
final_op = None
# Parse the nodes to re-create TF graph using Symbol API of NNVM
for node in graph.node:
# Tensorflow doesn't have seperate list for params extraction.
......@@ -1165,6 +1166,7 @@ class GraphProto(object):
# Assuming only one output.
self._nodes[node.name] = op
final_op = op
# Infer shapes if passed explicitely
node_output = self._nodes[node.name]
......@@ -1175,13 +1177,16 @@ class GraphProto(object):
_, out_shapes = graph_util.infer_shape(g, **shape_dict)
self._output_shapes[node.name] = out_shapes
# Assume the final node is the output node
out = node_output
out = []
if outputs is None:
out.append(final_op)
else:
out = [self._nodes[out_name] for out_name in outputs]
#Add the RNN outputs also with 'head' nodes of the nnvm graph
if self._num_rnn_layer:
out_rnn = _sym.concatenate(*self._out_rnn, axis=0)
out = [out, out_rnn]
out.append(out_rnn)
if isinstance(out, list):
out = _sym.Group(out)
......@@ -1378,7 +1383,7 @@ class GraphProto(object):
return inputs
def from_tensorflow(graph, layout="NHWC", shape=None):
def from_tensorflow(graph, layout="NHWC", shape=None, outputs=None):
""" Load tensorflow graph which is a python tensorflow graph object into nnvm graph.
The companion parameters will be handled automatically.
......@@ -1396,5 +1401,5 @@ def from_tensorflow(graph, layout="NHWC", shape=None):
Dict of converted parameters stored in tvm.ndarray format
"""
g = GraphProto()
sym, params = g.from_tensorflow(graph, layout, shape)
sym, params = g.from_tensorflow(graph, layout, shape, outputs)
return sym, params
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment