Commit c8245e9a by Siva Committed by Yizhi Liu

[FRONTEND][TENSORFLOW] Enhancements. (#1923)

* [FRONTEND][TENSORFLOW] Enhancements.
	* Generalize the shape with explicit argument.
	* Supported entire range of mobilenet_v2 models.
	* Cast op updated to latest tensorflow.
	* Documentation updates.
	* CheckNumerics op handling without exception.
	* Test data from tensorflow official releases.

* 	* CI error.

* 	* self review

* 	* Enhanced reshape handling.

* 	* docs.

* 	* tutorials

* 	* review comments.

* 	* review.
parent 2ea7969b
# Tensorflow Frontend
The Tensorflow frontend imports released tensorflow models into TVM.
This document describes a few steps to follow while importing models from
[tensorflow research/slim](https://github.com/tensorflow/models/tree/master/research/slim).
The current frontend is tested with all versions of the models below:
- Inception (V1/V2/V3/V4)
- Resnet (All)
- Mobilenet (V1/V2 All)
- Vgg (16/19)
The Tensorflow frontend expects a frozen protobuf as input.
Not all models are released as frozen protobufs. Some of them are released as checkpoints (.ckpt).
Please refer to [export](https://github.com/tensorflow/models/tree/master/research/slim#exporting-the-inference-graph)
and [freeze](https://github.com/tensorflow/models/tree/master/research/slim#freezing-the-exported-graph)
instructions to generate protobuf from checkpoint.
## General Instructions
### Add Shapes:
While freezing the protobuf, add the option ```add_shapes=True``` to embed the output shapes of each node into the graph.
You may use ```nnvm.testing.tf.AddShapesToGraphDef``` from nnvm for this.
Please refer to the [tensorflow tutorial](https://github.com/dmlc/tvm/blob/master/tutorials/nnvm/from_tensorflow.py).
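As a rough sketch of this step, the snippet below mirrors what ```nnvm.testing.tf.AddShapesToGraphDef``` does and adds a small check for nodes that still lack the ```_output_shapes``` attribute. The names ```freeze_with_shapes``` and ```nodes_missing_output_shapes``` are hypothetical helpers, not part of nnvm, and the TensorFlow 1.x session API is assumed. The stand-in graph at the end only exercises the check without TensorFlow installed.

```python
from types import SimpleNamespace


def freeze_with_shapes(session, out_node):
    """Freeze variables to constants and embed per-node output shapes.

    Hypothetical helper; assumes the TensorFlow 1.x API.
    """
    # Imported lazily so the rest of this sketch runs without TensorFlow.
    import tensorflow as tf
    return tf.graph_util.convert_variables_to_constants(
        session,
        session.graph.as_graph_def(add_shapes=True),
        [out_node],
    )


def nodes_missing_output_shapes(graph_def):
    """Return names of nodes that carry no '_output_shapes' attribute."""
    return [node.name for node in graph_def.node
            if "_output_shapes" not in node.attr]


# Stand-in for a GraphDef (assumption: only .node, .name and .attr are used).
fake_graph = SimpleNamespace(node=[
    SimpleNamespace(name="input", attr={"_output_shapes": [[1, 224, 224, 3]]}),
    SimpleNamespace(name="conv", attr={}),
])
missing = nodes_missing_output_shapes(fake_graph)
```

With a real model, ```freeze_with_shapes(sess, out_node)``` would be called inside the session that holds the imported graph, and an empty result from the check suggests every node has its shapes embedded.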
### Explicit Shape:
There might be situations where ```add_shapes=True``` does not provide sufficient shape information.
In that case you may pass an explicit dictionary of input shapes through the ```shape``` argument of ```from_tensorflow```.
Please refer to the [test cases](https://github.com/dmlc/tvm/blob/master/nnvm/tests/python/frontend/tensorflow/test_forward.py#L36).
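A minimal sketch of passing explicit shapes, assuming the graph input is a placeholder named ```input```. The helpers ```build_shape_dict``` and ```import_with_shapes``` are hypothetical names used for illustration; only the ```from_tensorflow(graph_def, layout=..., shape=...)``` call comes from the frontend itself.

```python
def build_shape_dict(input_nodes, input_shapes):
    """Map each graph input name to its shape tuple."""
    if isinstance(input_nodes, str):
        # Single input: wrap name and shape so the loop below is uniform.
        input_nodes, input_shapes = [input_nodes], [input_shapes]
    return {name: tuple(shape)
            for name, shape in zip(input_nodes, input_shapes)}


def import_with_shapes(graph_def, shape_dict, layout="NHWC"):
    """Import a GraphDef, letting the frontend infer shapes from shape_dict."""
    # Imported lazily so the dictionary helper works without nnvm installed.
    import nnvm
    return nnvm.frontend.from_tensorflow(graph_def, layout=layout,
                                         shape=shape_dict)


# 'input' is an assumed placeholder name; use the name from your own graph.
shape_dict = build_shape_dict("input", (1, 224, 224, 3))
```

The same dictionary can then be passed as the ```shape``` argument of ```nnvm.compiler.build```, as the test helper ```run_tvm_graph``` does.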
### GPU:
Most of these tensorflow models are released for CPU with NHWC layout.
To compile for GPU, pass the extra argument ```layout='NCHW'``` to ```from_tensorflow```.
This option inserts a layout conversion before and after the neural network ops.
The remaining nnvm build options for GPU compilation stay the same.
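To make the conversion concrete, the sketch below shows the axis permutations between NHWC and NCHW applied to a shape tuple, plus a layout choice per target similar to the one in the test helper. ```pick_layout``` and ```permute``` are hypothetical names; which ops the frontend actually wraps with a conversion is decided inside the frontend and is not modelled here.

```python
# Axis orders for converting between the two layouts.
NHWC_TO_NCHW = (0, 3, 1, 2)
NCHW_TO_NHWC = (0, 2, 3, 1)


def pick_layout(target):
    """Choose the layout argument for from_tensorflow based on the target."""
    return "NCHW" if target == "cuda" else "NHWC"


def permute(shape, axes):
    """Reorder a shape tuple the way a transpose with these axes would."""
    return tuple(shape[a] for a in axes)


nhwc_shape = (1, 224, 224, 3)
nchw_shape = permute(nhwc_shape, NHWC_TO_NCHW)
round_trip = permute(nchw_shape, NCHW_TO_NHWC)
```

The graph input itself stays in NHWC, since the conversion happens around the ops inside the graph, so the input data fed at runtime does not need to be transposed by hand.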
...@@ -9,7 +9,7 @@ import numpy as np ...@@ -9,7 +9,7 @@ import numpy as np
import tvm import tvm
from .. import symbol as _sym from .. import symbol as _sym
from .. import graph as _graph from .. import graph as _graph
from .. compiler import graph_util from .. compiler import graph_util, build_module
from .common import get_nnvm_op, AttrConverter as AttrConvert from .common import get_nnvm_op, AttrConverter as AttrConvert
__all__ = ['from_tensorflow'] __all__ = ['from_tensorflow']
...@@ -380,7 +380,7 @@ def _pack(): ...@@ -380,7 +380,7 @@ def _pack():
def _impl(inputs, attr, params): def _impl(inputs, attr, params):
axis = int(attr["axis"]) axis = int(attr["axis"])
inputs_reshaped = [_sym.expand_dims(i, axis=axis, num_newaxis=1) for i in inputs] inputs_reshaped = [_sym.expand_dims(i, axis=axis, num_newaxis=1) for i in inputs]
return _sym.concatenate(*inputs_reshaped, axis=axis) return _sym.concatenate(*inputs_reshaped, axis=axis, name=attr["_node_name"])
return _impl return _impl
...@@ -396,9 +396,19 @@ def _reshape(): ...@@ -396,9 +396,19 @@ def _reshape():
extras={'shape':tuple(shape_arg.asnumpy())}, extras={'shape':tuple(shape_arg.asnumpy())},
ignores=['Tshape'])(inputs, attr) ignores=['Tshape'])(inputs, attr)
except KeyError: except KeyError:
# Shape operator is already pruned, hence
# try to infer shape by precompute prune if possible.
if all(in_node in params for in_node in inputs[1].list_input_names()):
graph = _graph.create(_sym.Group(inputs[1]))
params_pre = {k: params[k] for k in inputs[1].list_input_names()}
params_new = build_module._run_graph(graph, params_pre)
inputs.pop(1)
return AttrCvt( return AttrCvt(
op_name="reshape_like", op_name="reshape",
extras={'shape':tuple(params_new[0].asnumpy().flatten())},
ignores=['Tshape'])(inputs, attr) ignores=['Tshape'])(inputs, attr)
else:
raise RuntimeError("Reshape with dynamic shape input not supported yet.")
return _impl return _impl
def _bias_add(): def _bias_add():
...@@ -470,9 +480,7 @@ def _relu6(): ...@@ -470,9 +480,7 @@ def _relu6():
def _shape(): def _shape():
def _impl(inputs, attr, params): def _impl(inputs, attr, params):
# Result of this operator is prominently used by reshape operator. return np.array(attr['_input_shapes'][inputs[0]][0], dtype='int32')
# Just pass the input as it is so that reshape_like can be used there.
return inputs[0]
return _impl return _impl
def _fill(): def _fill():
...@@ -1031,28 +1039,33 @@ class GraphProto(object): ...@@ -1031,28 +1039,33 @@ class GraphProto(object):
self._num_param = 0 self._num_param = 0
self._num_rnn_layer = False self._num_rnn_layer = False
def from_tensorflow(self, graph, layout="NHWC"): def from_tensorflow(self, graph, layout="NHWC", shape=None):
"""Construct nnvm nodes from tensorflow graph definition - GraphDef. """Construct nnvm nodes from tensorflow graph definition - GraphDef.
Follow the tensorflow graph definition to parse and convert it to NNVM. Follow the tensorflow graph definition to parse and convert it to NNVM.
Some of the assumptions listed below. Some of the assumptions listed below.
-> First Placeholder or Const node will be considered as graph input. -> All Placeholders are considered as graph input.
-> Rest all Const nodes are params. -> All Const nodes are params.
-> Last node is assumed as graph output. -> Last node is assumed as graph output.
-> _output_shapes : Attribute should present in the tenserflow forzen graph. -> _output_shapes : Graph should be frozen with add_shapes=True.
Or user can pass input shape dictionary optionally.
-> DecodeJpeg, ResizeBilinear: These are dummy operators. -> DecodeJpeg, ResizeBilinear: These are dummy operators.
Hence user should handle preprocessing outside. Hence user should handle preprocessing outside.
-> CheckNumerics: No implementation as of now for this. -> CheckNumerics: No implementation as of now for this.
Just copies input to output. Just copies input to output.
TODO: Change algorithm to stop treating first 'Const' in a special way.
Parameters Parameters
---------- ----------
graph : tensorflow graph definition object graph : tensorflow graph definition object
The loaded tensorflow GraphDef The loaded tensorflow GraphDef
layout : target layout to be used (Optional)
NCHW only supported now to enable NHWC models on GPU.
shape : Dictionary of input dimensions (Optional)
Graph level input shape dictionary.
Returns Returns
------- -------
sym : nnvm.sym.Symbol sym : nnvm.sym.Symbol
...@@ -1079,7 +1092,6 @@ class GraphProto(object): ...@@ -1079,7 +1092,6 @@ class GraphProto(object):
# Operator name 'Const' is treated as a parameter to build NNVM params dict. # Operator name 'Const' is treated as a parameter to build NNVM params dict.
input_shapes = {} input_shapes = {}
attr = self._parse_attr(node.attr) attr = self._parse_attr(node.attr)
#Variable converted to Const will not have only value attr #Variable converted to Const will not have only value attr
...@@ -1092,6 +1104,10 @@ class GraphProto(object): ...@@ -1092,6 +1104,10 @@ class GraphProto(object):
self._output_shapes[node.name] = \ self._output_shapes[node.name] = \
[tensor_util.TensorShapeProtoToList(shape) \ [tensor_util.TensorShapeProtoToList(shape) \
for shape in attr['_output_shapes']] for shape in attr['_output_shapes']]
elif shape:
# Keep the list indexable to avoid key error.
# Actual value will be filled after node creation.
self._output_shapes[node.name] = [None]
else: else:
raise NotImplementedError( \ raise NotImplementedError( \
"Please freeze the graph with add_shapes=True") "Please freeze the graph with add_shapes=True")
...@@ -1100,7 +1116,6 @@ class GraphProto(object): ...@@ -1100,7 +1116,6 @@ class GraphProto(object):
self._nodes[node.name] = _sym.Variable(name=node.name, self._nodes[node.name] = _sym.Variable(name=node.name,
shape=self._output_shapes[node.name][0]) shape=self._output_shapes[node.name][0])
#input_shapes[self._nodes[node.name]] = self._output_shapes[node.name]
elif node.op == "Const": elif node.op == "Const":
# All Const nodes are Param nodes, lets parse # All Const nodes are Param nodes, lets parse
self._num_param += 1 self._num_param += 1
...@@ -1132,21 +1147,33 @@ class GraphProto(object): ...@@ -1132,21 +1147,33 @@ class GraphProto(object):
node.input[0] = in_name node.input[0] = in_name
# Fill shapes for all inputs in a list # Fill shapes for all inputs in a list
try: inputs = []
inputs = [self._nodes[i] for i in node.input]
for i in node.input: for i in node.input:
if i in self._nodes:
inputs.append(self._nodes[i])
input_shapes[self._nodes[i]] = self._output_shapes[i] input_shapes[self._nodes[i]] = self._output_shapes[i]
attr['_input_shapes'] = input_shapes attr['_input_shapes'] = input_shapes
except KeyError:
# TODO: Need to find clean way to handle '^CheckNumerics'
pass
inputs = self._fix_extranodes(node.op, attr, inputs) inputs = self._fix_extranodes(node.op, attr, inputs)
op = self._convert_operator(node.op, inputs, attr, graph) op = self._convert_operator(node.op, inputs, attr, graph)
# Check if op is converted to param
if isinstance(op, np.ndarray):
self._params[node.name] = tvm.nd.array(op)
op = _sym.Variable(name=node.name,
shape=self._params[node.name].shape)
# Assuming only one output. # Assuming only one output.
self._nodes[node.name] = op self._nodes[node.name] = op
node_output = op
# Infer shapes if passed explicitly
node_output = self._nodes[node.name]
if shape:
g = _graph.create(node_output)
shape_dict = {k: v.shape for k, v in self._params.items()}
shape_dict.update(shape)
_, out_shapes = graph_util.infer_shape(g, **shape_dict)
self._output_shapes[node.name] = out_shapes
# Assume the final node is the output node # Assume the final node is the output node
out = node_output out = node_output
...@@ -1351,7 +1378,7 @@ class GraphProto(object): ...@@ -1351,7 +1378,7 @@ class GraphProto(object):
return inputs return inputs
def from_tensorflow(graph, layout="NHWC"): def from_tensorflow(graph, layout="NHWC", shape=None):
""" Load tensorflow graph which is a python tensorflow graph object into nnvm graph. """ Load tensorflow graph which is a python tensorflow graph object into nnvm graph.
The companion parameters will be handled automatically. The companion parameters will be handled automatically.
...@@ -1369,5 +1396,5 @@ def from_tensorflow(graph, layout="NHWC"): ...@@ -1369,5 +1396,5 @@ def from_tensorflow(graph, layout="NHWC"):
Dict of converted parameters stored in tvm.ndarray format Dict of converted parameters stored in tvm.ndarray format
""" """
g = GraphProto() g = GraphProto()
sym, params = g.from_tensorflow(graph, layout) sym, params = g.from_tensorflow(graph, layout, shape)
return sym, params return sym, params
...@@ -46,13 +46,15 @@ def ProcessGraphDefParam(graph_def): ...@@ -46,13 +46,15 @@ def ProcessGraphDefParam(graph_def):
return graph_def return graph_def
def AddShapesToGraphDef(out_node): def AddShapesToGraphDef(session, out_node):
""" Add shapes attribute to nodes of the graph. """ Add shapes attribute to nodes of the graph.
Input graph here is the default graph in context. Input graph here is the default graph in context.
Parameters Parameters
---------- ----------
out_node: String session : tf.Session
Tensorflow session
out_node : String
Final output node of the graph. Final output node of the graph.
Returns Returns
...@@ -62,10 +64,9 @@ def AddShapesToGraphDef(out_node): ...@@ -62,10 +64,9 @@ def AddShapesToGraphDef(out_node):
""" """
with tf.Session() as sess:
graph_def = tf.graph_util.convert_variables_to_constants( graph_def = tf.graph_util.convert_variables_to_constants(
sess, session,
sess.graph.as_graph_def(add_shapes=True), session.graph.as_graph_def(add_shapes=True),
[out_node], [out_node],
) )
return graph_def return graph_def
...@@ -135,7 +136,45 @@ class NodeLookup(object): ...@@ -135,7 +136,45 @@ class NodeLookup(object):
return '' return ''
return self.node_lookup[node_id] return self.node_lookup[node_id]
def get_workload(model_path): def get_workload_official(model_url, model_sub_path, temp_dir):
""" Import workload from tensorflow official
Parameters
----------
model_url: str
URL from where it will be downloaded.
model_sub_path:
Sub path in extracted tar for the frozen protobuf file.
temp_dir: TempDirectory
The temporary directory object to download the content.
Returns
-------
graph_def: graphdef
graph_def is the tensorflow workload for mobilenet.
"""
model_tar_name = os.path.basename(model_url)
from mxnet.gluon.utils import download
temp_path = temp_dir.relpath("./")
path_model = temp_path + model_tar_name
download(model_url, path_model)
import tarfile
if path_model.endswith("tgz") or path_model.endswith("gz"):
tar = tarfile.open(path_model)
tar.extractall(path=temp_path)
tar.close()
else:
raise RuntimeError('Could not decompress the file: ' + path_model)
return temp_path + model_sub_path
def get_workload(model_path, model_sub_path=None):
""" Import workload from frozen protobuf """ Import workload from frozen protobuf
Parameters Parameters
...@@ -143,6 +182,9 @@ def get_workload(model_path): ...@@ -143,6 +182,9 @@ def get_workload(model_path):
model_path: str model_path: str
model_path on remote repository to download from. model_path on remote repository to download from.
model_sub_path: str
Model path in the compressed archive.
Returns Returns
------- -------
graph_def: graphdef graph_def: graphdef
...@@ -150,15 +192,16 @@ def get_workload(model_path): ...@@ -150,15 +192,16 @@ def get_workload(model_path):
""" """
temp = util.tempdir()
if model_sub_path:
path_model = get_workload_official(model_path, model_sub_path, temp)
else:
repo_base = 'https://github.com/dmlc/web-data/raw/master/tensorflow/models/' repo_base = 'https://github.com/dmlc/web-data/raw/master/tensorflow/models/'
model_name = os.path.basename(model_path) model_name = os.path.basename(model_path)
model_url = os.path.join(repo_base, model_path) model_url = os.path.join(repo_base, model_path)
from mxnet.gluon.utils import download from mxnet.gluon.utils import download
temp = util.tempdir()
path_model = temp.relpath(model_name) path_model = temp.relpath(model_name)
download(model_url, path_model) download(model_url, path_model)
# Creates graph from saved graph_def.pb. # Creates graph from saved graph_def.pb.
......
...@@ -32,9 +32,8 @@ def run_tvm_graph(graph_def, input_data, input_node, num_output=1, target='llvm' ...@@ -32,9 +32,8 @@ def run_tvm_graph(graph_def, input_data, input_node, num_output=1, target='llvm'
layout = None layout = None
if target == "cuda": if target == "cuda":
layout = "NCHW" layout = "NCHW"
sym, params = nnvm.frontend.from_tensorflow(graph_def, layout=layout)
target_host = 'llvm' target_host = 'llvm'
if isinstance(input_data, list): if isinstance(input_data, list):
shape_dict = {} shape_dict = {}
dtype_dict = {} dtype_dict = {}
...@@ -45,6 +44,7 @@ def run_tvm_graph(graph_def, input_data, input_node, num_output=1, target='llvm' ...@@ -45,6 +44,7 @@ def run_tvm_graph(graph_def, input_data, input_node, num_output=1, target='llvm'
shape_dict = {input_node: input_data.shape} shape_dict = {input_node: input_data.shape}
dtype_dict = {input_node: input_data.dtype} dtype_dict = {input_node: input_data.dtype}
sym, params = nnvm.frontend.from_tensorflow(graph_def, layout=layout, shape=shape_dict)
graph, lib, params = nnvm.compiler.build(sym, target=target, target_host=target_host, shape=shape_dict, graph, lib, params = nnvm.compiler.build(sym, target=target, target_host=target_host, shape=shape_dict,
dtype=dtype_dict, params=params) dtype=dtype_dict, params=params)
...@@ -696,15 +696,20 @@ def test_forward_inception_v1(): ...@@ -696,15 +696,20 @@ def test_forward_inception_v1():
# --------- # ---------
def test_forward_mobilenet(): def test_forward_mobilenet():
'''test mobilenet model''' '''test mobilenet model'''
# MobilenetV2
with tf.Graph().as_default(): with tf.Graph().as_default():
graph_def = nnvm.testing.tf.get_workload("MobilenetV1/mobilenet_v1_1.0_224_frozen-with-shapes.pb") graph_def = nnvm.testing.tf.get_workload(
"https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_1.4_224.tgz",
"mobilenet_v2_1.4_224_frozen.pb")
# Call the utility to import the graph definition into default graph. # Call the utility to import the graph definition into default graph.
graph_def = nnvm.testing.tf.ProcessGraphDefParam(graph_def) graph_def = nnvm.testing.tf.ProcessGraphDefParam(graph_def)
data = np.random.uniform(size=(1, 224, 224, 3)).astype('float32') data = np.random.uniform(size=(1, 224, 224, 3)).astype('float32')
out_node = 'MobilenetV1/Predictions/Reshape_1' out_node = 'MobilenetV2/Predictions/Reshape_1'
with tf.Session() as sess: with tf.Session() as sess:
# Add shapes to the graph.
graph_def = nnvm.testing.tf.AddShapesToGraphDef(sess, out_node)
tf_output = run_tf_graph(sess, data, 'input:0', out_node + ':0') tf_output = run_tf_graph(sess, data, 'input:0', out_node + ':0')
tvm_output = run_tvm_graph(graph_def, data, 'input') tvm_output = run_tvm_graph(graph_def, data, 'input')
tvm.testing.assert_allclose(np.squeeze(tvm_output), np.squeeze(tf_output), rtol=1e-5, atol=1e-5) tvm.testing.assert_allclose(np.squeeze(tvm_output), np.squeeze(tf_output), rtol=1e-5, atol=1e-5)
......
...@@ -32,13 +32,18 @@ repo_base = 'https://github.com/dmlc/web-data/raw/master/tensorflow/models/Incep ...@@ -32,13 +32,18 @@ repo_base = 'https://github.com/dmlc/web-data/raw/master/tensorflow/models/Incep
img_name = 'elephant-299.jpg' img_name = 'elephant-299.jpg'
image_url = os.path.join(repo_base, img_name) image_url = os.path.join(repo_base, img_name)
# InceptionV1 model protobuf ######################################################################
# Tutorials
# ---------
# .. note:: # .. note::
# #
# protobuf should be exported with :any:`add_shapes=True` option. # protobuf should be exported with :any:`add_shapes=True` option.
# Could use https://github.com/dmlc/web-data/tree/master/tensorflow/scripts/tf-to-nnvm.py # Could use https://github.com/dmlc/web-data/tree/master/tensorflow/scripts/tf-to-nnvm.py
# to add shapes for existing models. # to add shapes for existing models.
# #
# Please refer to docs/frontend/tensorflow.md for more details for various models
# from tensorflow.
model_name = 'classify_image_graph_def-with_shapes.pb' model_name = 'classify_image_graph_def-with_shapes.pb'
model_url = os.path.join(repo_base, model_name) model_url = os.path.join(repo_base, model_name)
...@@ -84,14 +89,15 @@ with tf.gfile.FastGFile(os.path.join("./", model_name), 'rb') as f: ...@@ -84,14 +89,15 @@ with tf.gfile.FastGFile(os.path.join("./", model_name), 'rb') as f:
# Call the utility to import the graph definition into default graph. # Call the utility to import the graph definition into default graph.
graph_def = nnvm.testing.tf.ProcessGraphDefParam(graph_def) graph_def = nnvm.testing.tf.ProcessGraphDefParam(graph_def)
# Add shapes to the graph. # Add shapes to the graph.
graph_def = nnvm.testing.tf.AddShapesToGraphDef('softmax') with tf.Session() as sess:
graph_def = nnvm.testing.tf.AddShapesToGraphDef(sess, 'softmax')
###################################################################### ######################################################################
# Decode image # Decode image
# ------------ # ------------
# .. note:: # .. note::
# #
# tensorflow frontend import doesn't support preprocessing ops like JpegDecode # tensorflow frontend import doesn't support preprocessing ops like JpegDecode.
# JpegDecode is bypassed (just return source node). # JpegDecode is bypassed (just return source node).
# Hence we supply decoded frame to TVM instead. # Hence we supply decoded frame to TVM instead.
# #
......