Commit c8245e9a by Siva Committed by Yizhi Liu

[FRONTEND][TENSORFLOW] Enhancements. (#1923)

* [FRONTEND][TENSORFLOW] Enhancements.
	* Generalize the shape with explicite argument.
	* Supported entire range of mobilenet_v2 models.
	* Cast op updated to latest tensorflow.
	* Documentation updates.
	* CheckNumerics op handling without exception.
	* Test data from tensorflow official releases.

* 	* CI error.

* 	* self review

* 	* Enhanced reshape handling.

* 	* docs.

* 	* tutorials

* 	* review comments.

* 	* review.
parent 2ea7969b
# Tensorflow Frontend
Tensorflow frontend helps in importing tensorflow released model into TVM.
This document helps few steps while importing various different models from
[tensorflow research/slim](https://github.com/tensorflow/models/tree/master/research/slim).
Current frontend is tested with all versions of below models
- Inception (V1/V2/V3/V4)
- Resnet (All)
- Mobilenet (V1/V2 All)
- Vgg (16/19)
Tensorflow frontend expects a freezed protobuf format as input.
Not all models are released as freezed protobuf. Some of them are checkpoints (.ckpt).
Please refer to [export](https://github.com/tensorflow/models/tree/master/research/slim#exporting-the-inference-graph)
and [freeze](https://github.com/tensorflow/models/tree/master/research/slim#freezing-the-exported-graph)
instructions to generate protobuf from checkpoint.
## General Instructions
### Add Shapes:
While freezing of protobuf add additional option ```add_shapes=True``` to embed output shapes of each node into graph.
You may use ```nnvm.testing.tf.AddShapesToGraphDef``` from nnvm for the same.
Please refer to [tensorflow tutorial](https://github.com/dmlc/tvm/blob/master/tutorials/nnvm/from_tensorflow.py).
### Explicit Shape:
There might be situations where the add_shapes=True may not provide sufficient information about shape.
You may pass explicit dictionary of input shapes argument for ```from_tensorflow```.
Please refer to [test cases](https://github.com/dmlc/tvm/blob/master/nnvm/tests/python/frontend/tensorflow/test_forward.py#L36).
### GPU:
Most of these tensorflow models are released for CPU with NHWC layout.
To compile for GPU we need to pass extra argument ```layout='NCHW'``` for from_tensorflow.
This option will do a layout conversion before and after for neural network ops.
Remaining nnvm build options for GPU compilation remain as it is.
......@@ -9,7 +9,7 @@ import numpy as np
import tvm
from .. import symbol as _sym
from .. import graph as _graph
from .. compiler import graph_util
from .. compiler import graph_util, build_module
from .common import get_nnvm_op, AttrConverter as AttrConvert
__all__ = ['from_tensorflow']
......@@ -380,7 +380,7 @@ def _pack():
def _impl(inputs, attr, params):
axis = int(attr["axis"])
inputs_reshaped = [_sym.expand_dims(i, axis=axis, num_newaxis=1) for i in inputs]
return _sym.concatenate(*inputs_reshaped, axis=axis)
return _sym.concatenate(*inputs_reshaped, axis=axis, name=attr["_node_name"])
return _impl
......@@ -396,9 +396,19 @@ def _reshape():
extras={'shape':tuple(shape_arg.asnumpy())},
ignores=['Tshape'])(inputs, attr)
except KeyError:
# Shape operator is already pruned, hence
# try to infer shape by precompute prune if possible.
if all(in_node in params for in_node in inputs[1].list_input_names()):
graph = _graph.create(_sym.Group(inputs[1]))
params_pre = {k: params[k] for k in inputs[1].list_input_names()}
params_new = build_module._run_graph(graph, params_pre)
inputs.pop(1)
return AttrCvt(
op_name="reshape_like",
op_name="reshape",
extras={'shape':tuple(params_new[0].asnumpy().flatten())},
ignores=['Tshape'])(inputs, attr)
else:
raise RuntimeError("Reshape with dynamic shape input not supported yet.")
return _impl
def _bias_add():
......@@ -470,9 +480,7 @@ def _relu6():
def _shape():
def _impl(inputs, attr, params):
# Result of this operator is prominently used by reshape operator.
# Just pass the input as it is so that reshape_like can be used there.
return inputs[0]
return np.array(attr['_input_shapes'][inputs[0]][0], dtype='int32')
return _impl
def _fill():
......@@ -1031,28 +1039,33 @@ class GraphProto(object):
self._num_param = 0
self._num_rnn_layer = False
def from_tensorflow(self, graph, layout="NHWC"):
def from_tensorflow(self, graph, layout="NHWC", shape=None):
"""Construct nnvm nodes from tensorflow graph definition - GraphDef.
Follow the tensorflow graph definition to parse and convert it to NNVM.
Some of the assumptions listed below.
-> First Placeholder or Const node will be considered as graph input.
-> Rest all Const nodes are params.
-> All Placeholders are considered as graph input.
-> All Const nodes are params.
-> Last node is assumed as graph output.
-> _output_shapes : Attribute should present in the tenserflow forzen graph.
-> _output_shapes : Graph should be frozen with add_shapes=True.
Or user can pass input shape dictionaly optionally.
-> DecodeJpeg, ResizeBilinear: These are dummy operators.
Hence user should handle preprocessing outside.
-> CheckNumerics: No implementation as of now for this.
Just copies input to output.
TODO: Change algorithm to stop treating first 'Const' in a special way.
Parameters
----------
graph : tensorflow graph definition object
The loaded tensorflow GraphDef
layout : target layout to be used (Optional)
NCHW only supported now to enable NHWC models on GPU.
shape : Dictionary of input dimensions (Optional)
Graph level input shape dictionary.
Returns
-------
sym : nnvm.sym.Symbol
......@@ -1079,7 +1092,6 @@ class GraphProto(object):
# Operator name 'Const' is treated as a parameter to build NNVM params dict.
input_shapes = {}
attr = self._parse_attr(node.attr)
#Variable converted to Const will not have only value attr
......@@ -1092,6 +1104,10 @@ class GraphProto(object):
self._output_shapes[node.name] = \
[tensor_util.TensorShapeProtoToList(shape) \
for shape in attr['_output_shapes']]
elif shape:
# Keep the list indexable to avoid key error.
# Actual value will be filled after node creation.
self._output_shapes[node.name] = [None]
else:
raise NotImplementedError( \
"Please freeze the graph with add_shapes=True")
......@@ -1100,7 +1116,6 @@ class GraphProto(object):
self._nodes[node.name] = _sym.Variable(name=node.name,
shape=self._output_shapes[node.name][0])
#input_shapes[self._nodes[node.name]] = self._output_shapes[node.name]
elif node.op == "Const":
# All Const nodes are Param nodes, lets parse
self._num_param += 1
......@@ -1132,21 +1147,33 @@ class GraphProto(object):
node.input[0] = in_name
# Fill shapes for all inputs in a list
try:
inputs = [self._nodes[i] for i in node.input]
inputs = []
for i in node.input:
if i in self._nodes:
inputs.append(self._nodes[i])
input_shapes[self._nodes[i]] = self._output_shapes[i]
attr['_input_shapes'] = input_shapes
except KeyError:
# TODO: Need to find clean way to handle '^CheckNumerics'
pass
inputs = self._fix_extranodes(node.op, attr, inputs)
op = self._convert_operator(node.op, inputs, attr, graph)
# Check is op is converted to param
if isinstance(op, np.ndarray):
self._params[node.name] = tvm.nd.array(op)
op = _sym.Variable(name=node.name,
shape=self._params[node.name].shape)
# Assuming only one output.
self._nodes[node.name] = op
node_output = op
# Infer shapes if passed explicitely
node_output = self._nodes[node.name]
if shape:
g = _graph.create(node_output)
shape_dict = {k: v.shape for k, v in self._params.items()}
shape_dict.update(shape)
_, out_shapes = graph_util.infer_shape(g, **shape_dict)
self._output_shapes[node.name] = out_shapes
# Assume the final node is the output node
out = node_output
......@@ -1351,7 +1378,7 @@ class GraphProto(object):
return inputs
def from_tensorflow(graph, layout="NHWC"):
def from_tensorflow(graph, layout="NHWC", shape=None):
""" Load tensorflow graph which is a python tensorflow graph object into nnvm graph.
The companion parameters will be handled automatically.
......@@ -1369,5 +1396,5 @@ def from_tensorflow(graph, layout="NHWC"):
Dict of converted parameters stored in tvm.ndarray format
"""
g = GraphProto()
sym, params = g.from_tensorflow(graph, layout)
sym, params = g.from_tensorflow(graph, layout, shape)
return sym, params
......@@ -46,13 +46,15 @@ def ProcessGraphDefParam(graph_def):
return graph_def
def AddShapesToGraphDef(out_node):
def AddShapesToGraphDef(session, out_node):
""" Add shapes attribute to nodes of the graph.
Input graph here is the default graph in context.
Parameters
----------
out_node: String
session : tf.Session
Tensorflow session
out_node : String
Final output node of the graph.
Returns
......@@ -62,10 +64,9 @@ def AddShapesToGraphDef(out_node):
"""
with tf.Session() as sess:
graph_def = tf.graph_util.convert_variables_to_constants(
sess,
sess.graph.as_graph_def(add_shapes=True),
session,
session.graph.as_graph_def(add_shapes=True),
[out_node],
)
return graph_def
......@@ -135,7 +136,45 @@ class NodeLookup(object):
return ''
return self.node_lookup[node_id]
def get_workload(model_path):
def get_workload_official(model_url, model_sub_path, temp_dir):
""" Import workload from tensorflow official
Parameters
----------
model_url: str
URL from where it will be downloaded.
model_sub_path:
Sub path in extracted tar for the ftozen protobuf file.
temp_dir: TempDirectory
The temporary directory object to download the content.
Returns
-------
graph_def: graphdef
graph_def is the tensorflow workload for mobilenet.
"""
model_tar_name = os.path.basename(model_url)
from mxnet.gluon.utils import download
temp_path = temp_dir.relpath("./")
path_model = temp_path + model_tar_name
download(model_url, path_model)
import tarfile
if path_model.endswith("tgz") or path_model.endswith("gz"):
tar = tarfile.open(path_model)
tar.extractall(path=temp_path)
tar.close()
else:
raise RuntimeError('Could not decompress the file: ' + path_model)
return temp_path + model_sub_path
def get_workload(model_path, model_sub_path=None):
""" Import workload from frozen protobuf
Parameters
......@@ -143,6 +182,9 @@ def get_workload(model_path):
model_path: str
model_path on remote repository to download from.
model_sub_path: str
Model path in the compressed archive.
Returns
-------
graph_def: graphdef
......@@ -150,15 +192,16 @@ def get_workload(model_path):
"""
temp = util.tempdir()
if model_sub_path:
path_model = get_workload_official(model_path, model_sub_path, temp)
else:
repo_base = 'https://github.com/dmlc/web-data/raw/master/tensorflow/models/'
model_name = os.path.basename(model_path)
model_url = os.path.join(repo_base, model_path)
from mxnet.gluon.utils import download
temp = util.tempdir()
path_model = temp.relpath(model_name)
download(model_url, path_model)
# Creates graph from saved graph_def.pb.
......
......@@ -32,9 +32,8 @@ def run_tvm_graph(graph_def, input_data, input_node, num_output=1, target='llvm'
layout = None
if target == "cuda":
layout = "NCHW"
sym, params = nnvm.frontend.from_tensorflow(graph_def, layout=layout)
target_host = 'llvm'
if isinstance(input_data, list):
shape_dict = {}
dtype_dict = {}
......@@ -45,6 +44,7 @@ def run_tvm_graph(graph_def, input_data, input_node, num_output=1, target='llvm'
shape_dict = {input_node: input_data.shape}
dtype_dict = {input_node: input_data.dtype}
sym, params = nnvm.frontend.from_tensorflow(graph_def, layout=layout, shape=shape_dict)
graph, lib, params = nnvm.compiler.build(sym, target=target, target_host=target_host, shape=shape_dict,
dtype=dtype_dict, params=params)
......@@ -696,15 +696,20 @@ def test_forward_inception_v1():
# ---------
def test_forward_mobilenet():
'''test mobilenet model'''
# MobilenetV2
with tf.Graph().as_default():
graph_def = nnvm.testing.tf.get_workload("MobilenetV1/mobilenet_v1_1.0_224_frozen-with-shapes.pb")
graph_def = nnvm.testing.tf.get_workload(
"https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_1.4_224.tgz",
"mobilenet_v2_1.4_224_frozen.pb")
# Call the utility to import the graph definition into default graph.
graph_def = nnvm.testing.tf.ProcessGraphDefParam(graph_def)
data = np.random.uniform(size=(1, 224, 224, 3)).astype('float32')
out_node = 'MobilenetV1/Predictions/Reshape_1'
out_node = 'MobilenetV2/Predictions/Reshape_1'
with tf.Session() as sess:
# Add shapes to the graph.
graph_def = nnvm.testing.tf.AddShapesToGraphDef(sess, out_node)
tf_output = run_tf_graph(sess, data, 'input:0', out_node + ':0')
tvm_output = run_tvm_graph(graph_def, data, 'input')
tvm.testing.assert_allclose(np.squeeze(tvm_output), np.squeeze(tf_output), rtol=1e-5, atol=1e-5)
......
......@@ -32,13 +32,18 @@ repo_base = 'https://github.com/dmlc/web-data/raw/master/tensorflow/models/Incep
img_name = 'elephant-299.jpg'
image_url = os.path.join(repo_base, img_name)
# InceptionV1 model protobuf
######################################################################
# Tutorials
# ---------
# .. note::
#
# protobuf should be exported with :any:`add_shapes=True` option.
# Could use https://github.com/dmlc/web-data/tree/master/tensorflow/scripts/tf-to-nnvm.py
# to add shapes for existing models.
#
# Please refer docs/frontend/tensorflow.md for more details for various models
# from tensorflow.
model_name = 'classify_image_graph_def-with_shapes.pb'
model_url = os.path.join(repo_base, model_name)
......@@ -84,14 +89,15 @@ with tf.gfile.FastGFile(os.path.join("./", model_name), 'rb') as f:
# Call the utility to import the graph definition into default graph.
graph_def = nnvm.testing.tf.ProcessGraphDefParam(graph_def)
# Add shapes to the graph.
graph_def = nnvm.testing.tf.AddShapesToGraphDef('softmax')
with tf.Session() as sess:
graph_def = nnvm.testing.tf.AddShapesToGraphDef(sess, 'softmax')
######################################################################
# Decode image
# ------------
# .. note::
#
# tensorflow frontend import doesn't support preprocessing ops like JpegDecode
# tensorflow frontend import doesn't support preprocessing ops like JpegDecode.
# JpegDecode is bypassed (just return source node).
# Hence we supply decoded frame to TVM instead.
#
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment