Commit 45ef90c0 by Alexander Pivovarov Committed by Tianqi Chen

Add all parameters to from_tensorflow docs (#3321)

parent ce90f0d0
...@@ -1168,7 +1168,7 @@ class GraphProto(object): ...@@ -1168,7 +1168,7 @@ class GraphProto(object):
self._input_shapes = {} self._input_shapes = {}
def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None): def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
"""Construct nnvm nodes from tensorflow graph definition - GraphDef. """Construct nnvm nodes from tensorflow graph definition - GraphDef.
Follow the tensorflow graph definition to parse and convert it to NNVM. Follow the tensorflow graph definition to parse and convert it to NNVM.
Some of the assumptions listed below. Some of the assumptions listed below.
...@@ -1194,6 +1194,9 @@ class GraphProto(object): ...@@ -1194,6 +1194,9 @@ class GraphProto(object):
shape : Dictionary of input dimensions (Optional) shape : Dictionary of input dimensions (Optional)
Graph level input shape dictionary. Graph level input shape dictionary.
outputs : List of output tensor names (Optional)
if not specified then the last node is assumed as graph output.
Returns Returns
------- -------
sym : nnvm.sym.Symbol sym : nnvm.sym.Symbol
...@@ -1569,7 +1572,7 @@ class GraphProto(object): ...@@ -1569,7 +1572,7 @@ class GraphProto(object):
return inputs return inputs
def from_tensorflow(graph, layout="NHWC", shape=None, outputs=None): def from_tensorflow(graph, layout="NHWC", shape=None, outputs=None):
""" Load tensorflow graph which is a python tensorflow graph object into nnvm graph. """Load tensorflow graph which is a python tensorflow graph object into nnvm graph.
The companion parameters will be handled automatically. The companion parameters will be handled automatically.
Parameters Parameters
...@@ -1577,6 +1580,15 @@ def from_tensorflow(graph, layout="NHWC", shape=None, outputs=None): ...@@ -1577,6 +1580,15 @@ def from_tensorflow(graph, layout="NHWC", shape=None, outputs=None):
graph : GraphDef object graph : GraphDef object
Tensorflow GraphDef Tensorflow GraphDef
layout : target layout to be used (Optional)
NCHW only supported now to enable NHWC models on GPU.
shape : Dictionary of input dimensions (Optional)
Graph level input shape dictionary.
outputs : List of output tensor names (Optional)
if not specified then the last node is assumed as graph output.
Returns Returns
------- -------
sym : nnvm.Symbol sym : nnvm.Symbol
......
...@@ -1787,7 +1787,7 @@ class GraphProto(object): ...@@ -1787,7 +1787,7 @@ class GraphProto(object):
self._branches = {} self._branches = {}
def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None): def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
"""Construct relay nodes from tensorflow graph definition - GraphDef. """Construct relay nodes from tensorflow graph definition - GraphDef.
Follow the tensorflow graph definition to parse and convert it to Relay. Follow the tensorflow graph definition to parse and convert it to Relay.
Some of the assumptions listed below. Some of the assumptions listed below.
...@@ -1813,6 +1813,9 @@ class GraphProto(object): ...@@ -1813,6 +1813,9 @@ class GraphProto(object):
shape : Dictionary of input dimensions (Optional) shape : Dictionary of input dimensions (Optional)
Graph level input shape dictionary. Graph level input shape dictionary.
outputs : List of output tensor names (Optional)
if not specified then the last node is assumed as graph output.
Returns Returns
------- -------
sym : relay.op sym : relay.op
...@@ -2276,7 +2279,7 @@ class GraphProto(object): ...@@ -2276,7 +2279,7 @@ class GraphProto(object):
def from_tensorflow(graph, layout="NHWC", shape=None, outputs=None): def from_tensorflow(graph, layout="NHWC", shape=None, outputs=None):
""" Load tensorflow graph which is a python tensorflow graph object into relay. """Load tensorflow graph which is a python tensorflow graph object into relay.
The companion parameters will be handled automatically. The companion parameters will be handled automatically.
Parameters Parameters
...@@ -2284,6 +2287,15 @@ def from_tensorflow(graph, layout="NHWC", shape=None, outputs=None): ...@@ -2284,6 +2287,15 @@ def from_tensorflow(graph, layout="NHWC", shape=None, outputs=None):
graph : GraphDef object graph : GraphDef object
Tensorflow GraphDef Tensorflow GraphDef
layout : target layout to be used (Optional)
NCHW only supported now to enable NHWC models on GPU.
shape : Dictionary of input dimensions (Optional)
Graph level input shape dictionary.
outputs : List of output tensor names (Optional)
if not specified then the last node is assumed as graph output.
Returns Returns
------- -------
sym : relay.op sym : relay.op
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment