Commit f7eff095 by Yong Wu Committed by Siva

[relay][frontend] TensorFlow saved model support (#2586)

* [relay][frontend] TensorFlow saved model support

* Add Examples section

* keep one copy of tensorflow_parser in relay
parent 19194e97
...@@ -4,6 +4,7 @@ from __future__ import absolute_import as _abs ...@@ -4,6 +4,7 @@ from __future__ import absolute_import as _abs
from __future__ import print_function from __future__ import print_function
import logging import logging
import warnings
# Numpy support # Numpy support
import numpy as np import numpy as np
...@@ -410,7 +411,7 @@ def _conv(opname): ...@@ -410,7 +411,7 @@ def _conv(opname):
def _decode_image(): def _decode_image():
def _impl(inputs, attr, params): def _impl(inputs, attr, params):
# Image decode wrapper: Expecting user to feed decoded input to next layer drop this layer. # Image decode wrapper: Expecting user to feed decoded input to next layer drop this layer.
print("DecodeJpeg: It's a pass through, please handle preprocessing before input") warnings.warn("DecodeJpeg: It's a pass through, please handle preprocessing before input")
return inputs[0] return inputs[0]
return _impl return _impl
...@@ -1178,6 +1179,7 @@ class GraphProto(object): ...@@ -1178,6 +1179,7 @@ class GraphProto(object):
def __init__(self): def __init__(self):
self._nodes = {} self._nodes = {}
self._params = {} self._params = {}
self._input_shapes = {}
self._output_shapes = {} self._output_shapes = {}
self._num_param = 0 self._num_param = 0
self._num_rnn_layer = False self._num_rnn_layer = False
...@@ -1229,36 +1231,55 @@ class GraphProto(object): ...@@ -1229,36 +1231,55 @@ class GraphProto(object):
raise NotImplementedError( \ raise NotImplementedError( \
"The following operators are not implemented: {}".format(missing_operators)) "The following operators are not implemented: {}".format(missing_operators))
for node in graph.node:
if node.op == 'Placeholder':
if shape and node.name in shape:
self._input_shapes[node.name] = list(shape[node.name])
continue
self._input_shapes[node.name] = \
tensor_util.TensorShapeProtoToList(node.attr['shape'].shape)
for idx, dim in enumerate(self._input_shapes[node.name]):
if dim < 0:
self._input_shapes[node.name][idx] = 1
warnings.warn("Use 1 instead of -1 in shape of operator %s."
% node.name)
# Ignore user's input shape for Non placeholder
elif node.op == 'Const':
tensor_value = node.attr['value'].tensor
self._input_shapes[node.name] = \
tensor_util.TensorShapeProtoToList(tensor_value.tensor_shape)
if shape and node.name in shape:
warnings.warn("Ignore the passed shape. Shape in graphdef "
"will be used for operator %s." % node.name)
# Parse the nodes to re-create TF graph using Relay operators. # Parse the nodes to re-create TF graph using Relay operators.
for node in graph.node: for node in graph.node:
# Tensorflow doesn't have seperate list for params extraction. # Tensorflow doesn't have separate list for params extraction.
# Operator name 'Const' is treated as a parameter to build params dict. # Operator name 'Const' is treated as a parameter to build params dict.
input_shapes = {} input_shapes = {}
attr = self._parse_attr(node.attr) attr = self._parse_attr(node.attr)
#Variable converted to Const will not have only value attr # Variable converted to Const will not have only value attr
if 'value' in attr and node.op == 'Const': if 'value' in attr and node.op == 'Const':
tensor_value = attr['value'] self._output_shapes[node.name] = [self._input_shapes[node.name]]
self._output_shapes[node.name] = \ elif shape and node.name in shape:
[tensor_util.TensorShapeProtoToList( \ # Give priority to user argument.
tensor_value.tensor_shape)] self._output_shapes[node.name] = [shape[node.name]]
elif '_output_shapes' in attr: elif '_output_shapes' in attr:
self._output_shapes[node.name] = \ self._output_shapes[node.name] = \
[tensor_util.TensorShapeProtoToList(tshape) \ [tensor_util.TensorShapeProtoToList(tshape) \
for tshape in attr['_output_shapes']] for tshape in attr['_output_shapes']]
elif shape: else:
# Keep the list indexable to avoid key error. # Keep the list indexable to avoid key error.
# Actual value will be filled after node creation. # Actual value will be filled after node creation.
self._output_shapes[node.name] = [None] self._output_shapes[node.name] = [None]
else:
raise NotImplementedError( \
"Please freeze the graph with add_shapes=True")
if node.op == "Placeholder": if node.op == "Placeholder":
self._output_shapes[node.name] = [shape[node.name]] self._output_shapes[node.name] = [self._input_shapes[node.name]]
self._nodes[node.name] = [_expr.var(node.name, self._nodes[node.name] = [_expr.var(node.name,
shape=self._output_shapes[node.name][0], shape=self._input_shapes[node.name],
dtype=attr['dtype'].name)] dtype=attr['dtype'].name)]
elif node.op == "Const": elif node.op == "Const":
...@@ -1274,7 +1295,7 @@ class GraphProto(object): ...@@ -1274,7 +1295,7 @@ class GraphProto(object):
else: else:
# Pass the parsed shapes instead # Pass the parsed shapes instead
attr["_output_shapes"] = self._output_shapes[node.name] attr["_output_shapes"] = output_shapes = self._output_shapes[node.name]
# Pass the node name too in attr # Pass the node name too in attr
attr["_node_name"] = node.name attr["_node_name"] = node.name
...@@ -1301,7 +1322,7 @@ class GraphProto(object): ...@@ -1301,7 +1322,7 @@ class GraphProto(object):
op = self._convert_operator(node.op, inputs, attr, graph) op = self._convert_operator(node.op, inputs, attr, graph)
# Check is op is converted to param # Check if op is converted to param
if isinstance(op, np.ndarray): if isinstance(op, np.ndarray):
self._params[node.name] = tvm.nd.array(op) self._params[node.name] = tvm.nd.array(op)
op = [_expr.var(node.name, op = [_expr.var(node.name,
...@@ -1317,6 +1338,14 @@ class GraphProto(object): ...@@ -1317,6 +1338,14 @@ class GraphProto(object):
self._nodes[node.name] = op self._nodes[node.name] = op
# Infer shapes even without specifying "add_shapes=True"
if output_shapes == [None]:
out_type = ir_pass.infer_type(self._nodes[node.name][0])
self._output_shapes[node.name] = [get_const_tuple(out_type.checked_type.shape)]
if self._output_shapes[node.name] and shape and node.name in shape:
assert self._output_shapes[node.name] == list(shape[node.name])
# Infer shapes if passed explicitely # Infer shapes if passed explicitely
node_output = self._nodes[node.name] node_output = self._nodes[node.name]
out_type = ir_pass.infer_type(node_output[0]) out_type = ir_pass.infer_type(node_output[0])
......
...@@ -7,16 +7,21 @@ from tvm.contrib import util ...@@ -7,16 +7,21 @@ from tvm.contrib import util
class TFParser(object): class TFParser(object):
"""A Wrapper to handle tensorflow models parsing """
TensorFlow is needed A Wrapper to handle tensorflow models parsing, TensorFlow is needed
```
parser = TfParser(model_dir)
graph = parser.parse()
```
Parameters Parameters
---------- ----------
model_dir : tensorflow frozen pb file or a directory that contains saved model_dir : tensorflow frozen pb file or a directory that contains saved
model or checkpoints. model or checkpoints.
Examples
--------
.. code-block:: python
parser = TfParser(model_dir)
graph = parser.parse()
# graph is related graphdef of the model
""" """
def __init__(self, model_dir): def __init__(self, model_dir):
...@@ -115,13 +120,16 @@ class TFParser(object): ...@@ -115,13 +120,16 @@ class TFParser(object):
"""TODO: Load checkpoint model.""" """TODO: Load checkpoint model."""
raise RuntimeError("InputConfiguration: Loading tf checkpoint model is " raise RuntimeError("InputConfiguration: Loading tf checkpoint model is "
"not supported yet.") "not supported yet.")
# pylint: disable=unreachable
return 0
def parse(self): def parse(self):
"""Parse tensorflow models: checkpoints, saved models, and single pb
file.
""" """
Parse tensorflow models: checkpoints, saved models, and single frozen pb file.
Returns
-------
GraphDef of the passed model
"""
graph = None graph = None
if os.path.isdir(self._model_dir): if os.path.isdir(self._model_dir):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment