Commit 2f859d71 by Siva Committed by Yizhi Liu

[RELAY][FRONTEND] Tensorflow frontend. (#2216)

* [RELAY][FRONTEND] Tensorflow frontend support.

* 	* LSTM removed for a while.

* 	* basic ops are good.

* 	* nn wip

* 	* wip

* 	* python2.7 corrections.

* * NN ops are good.

* * e2e models working good

* 	* all good except LSTM

* 	* rebase, tutorials and CI trigger.

* 	* CI errors.

* 	* enable opt_level=3

* 	* Docstrings cleanup. testing.tf utils moved to relay from nnvm.

* 	* tutorials update.

* 	* LSTM work good now.

* 	* Rebase

* 	* CI error

* 	* enable PTB.

* 	* rebase.

* 	* tutorials

* Update python/tvm/relay/frontend/tensorflow.py

Co-Authored-By: srkreddy1238 <sivar.b@huawei.com>

* 	* review comments.

* 	CI fix.

* 	* review comments.
parent 40f76825
...@@ -21,7 +21,7 @@ instructions to generate protobuf from checkpoint. ...@@ -21,7 +21,7 @@ instructions to generate protobuf from checkpoint.
### Add Shapes: ### Add Shapes:
While freezing of protobuf add additional option ```add_shapes=True``` to embed output shapes of each node into graph. While freezing of protobuf add additional option ```add_shapes=True``` to embed output shapes of each node into graph.
You may use ```nnvm.testing.tf.AddShapesToGraphDef``` from nnvm for the same. You may use ```tvm.relay.testing.tf.AddShapesToGraphDef``` from nnvm for the same.
Please refer to [tensorflow tutorial](https://github.com/dmlc/tvm/blob/master/tutorials/nnvm/from_tensorflow.py). Please refer to [tensorflow tutorial](https://github.com/dmlc/tvm/blob/master/tutorials/nnvm/from_tensorflow.py).
### Explicit Shape: ### Explicit Shape:
......
...@@ -21,7 +21,7 @@ from tensorflow.python.ops import variables ...@@ -21,7 +21,7 @@ from tensorflow.python.ops import variables
from tensorflow.python.ops import init_ops from tensorflow.python.ops import init_ops
from tensorflow.core.framework import graph_pb2 from tensorflow.core.framework import graph_pb2
import nnvm.testing.tf import tvm.relay.testing.tf as tf_testing
####################################################################### #######################################################################
# Generic run functions for TVM & tensorflow # Generic run functions for TVM & tensorflow
...@@ -784,9 +784,9 @@ def test_forward_pad(): ...@@ -784,9 +784,9 @@ def test_forward_pad():
def test_forward_inception_v3(): def test_forward_inception_v3():
'''test inception V3 model''' '''test inception V3 model'''
with tf.Graph().as_default(): with tf.Graph().as_default():
graph_def = nnvm.testing.tf.get_workload('InceptionV3/inception_v3_2016_08_28_frozen-with_shapes.pb') graph_def = tf_testing.get_workload('InceptionV3/inception_v3_2016_08_28_frozen-with_shapes.pb')
# Call the utility to import the graph definition into default graph. # Call the utility to import the graph definition into default graph.
graph_def = nnvm.testing.tf.ProcessGraphDefParam(graph_def) graph_def = tf_testing.ProcessGraphDefParam(graph_def)
data = np.random.uniform(size=(1, 299, 299, 3)).astype('float32') data = np.random.uniform(size=(1, 299, 299, 3)).astype('float32')
...@@ -801,9 +801,9 @@ def test_forward_inception_v3(): ...@@ -801,9 +801,9 @@ def test_forward_inception_v3():
def test_forward_inception_v1(): def test_forward_inception_v1():
'''test inception V1 model''' '''test inception V1 model'''
with tf.Graph().as_default(): with tf.Graph().as_default():
graph_def = nnvm.testing.tf.get_workload("InceptionV1/classify_image_graph_def-with_shapes.pb") graph_def = tf_testing.get_workload("InceptionV1/classify_image_graph_def-with_shapes.pb")
# Call the utility to import the graph definition into default graph. # Call the utility to import the graph definition into default graph.
graph_def = nnvm.testing.tf.ProcessGraphDefParam(graph_def) graph_def = tf_testing.ProcessGraphDefParam(graph_def)
# Build an image from random data. # Build an image from random data.
from PIL import Image from PIL import Image
...@@ -838,18 +838,18 @@ def test_forward_mobilenet(): ...@@ -838,18 +838,18 @@ def test_forward_mobilenet():
'''test mobilenet model''' '''test mobilenet model'''
# MobilenetV2 # MobilenetV2
with tf.Graph().as_default(): with tf.Graph().as_default():
graph_def = nnvm.testing.tf.get_workload( graph_def = tf_testing.get_workload(
"https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_1.4_224.tgz", "https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_1.4_224.tgz",
"mobilenet_v2_1.4_224_frozen.pb") "mobilenet_v2_1.4_224_frozen.pb")
# Call the utility to import the graph definition into default graph. # Call the utility to import the graph definition into default graph.
graph_def = nnvm.testing.tf.ProcessGraphDefParam(graph_def) graph_def = tf_testing.ProcessGraphDefParam(graph_def)
data = np.random.uniform(size=(1, 224, 224, 3)).astype('float32') data = np.random.uniform(size=(1, 224, 224, 3)).astype('float32')
out_node = 'MobilenetV2/Predictions/Reshape_1' out_node = 'MobilenetV2/Predictions/Reshape_1'
with tf.Session() as sess: with tf.Session() as sess:
# Add shapes to the graph. # Add shapes to the graph.
graph_def = nnvm.testing.tf.AddShapesToGraphDef(sess, out_node) graph_def = tf_testing.AddShapesToGraphDef(sess, out_node)
tf_output = run_tf_graph(sess, data, 'input:0', out_node + ':0') tf_output = run_tf_graph(sess, data, 'input:0', out_node + ':0')
tvm_output = run_tvm_graph(graph_def, data, 'input') tvm_output = run_tvm_graph(graph_def, data, 'input')
tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tf_output[0]), rtol=1e-5, atol=1e-5) tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tf_output[0]), rtol=1e-5, atol=1e-5)
...@@ -861,9 +861,9 @@ def test_forward_resnetv2(): ...@@ -861,9 +861,9 @@ def test_forward_resnetv2():
'''test resnet model''' '''test resnet model'''
if is_gpu_available(): if is_gpu_available():
with tf.Graph().as_default(): with tf.Graph().as_default():
graph_def = nnvm.testing.tf.get_workload("ResnetV2/resnet-20180601_resnet_v2_imagenet-shapes.pb") graph_def = tf_testing.get_workload("ResnetV2/resnet-20180601_resnet_v2_imagenet-shapes.pb")
# Call the utility to import the graph definition into default graph. # Call the utility to import the graph definition into default graph.
graph_def = nnvm.testing.tf.ProcessGraphDefParam(graph_def) graph_def = tf_testing.ProcessGraphDefParam(graph_def)
data = np.random.uniform(size=(128, 224, 224, 3)).astype('float32') data = np.random.uniform(size=(128, 224, 224, 3)).astype('float32')
out_node = 'ArgMax' out_node = 'ArgMax'
...@@ -879,7 +879,7 @@ def test_forward_resnetv2(): ...@@ -879,7 +879,7 @@ def test_forward_resnetv2():
dir(tf.contrib) dir(tf.contrib)
def test_forward_ptb(): def test_forward_ptb():
'''test ptb model''' '''test ptb model'''
config = nnvm.testing.tf.get_config() config = tf_testing.get_config()
num_steps = config.num_steps num_steps = config.num_steps
num_hidden = config.hidden_size num_hidden = config.hidden_size
num_layers = config.num_layers num_layers = config.num_layers
...@@ -936,7 +936,7 @@ def test_forward_ptb(): ...@@ -936,7 +936,7 @@ def test_forward_ptb():
"float32")).asnumpy() "float32")).asnumpy()
state_output = model.get_output(1, tvm.nd.empty(out_state_shape, state_output = model.get_output(1, tvm.nd.empty(out_state_shape,
"float32")).asnumpy() "float32")).asnumpy()
sample = nnvm.testing.tf.pick_from_weight(tvm_output[0]) sample = tf_testing.pick_from_weight(tvm_output[0])
return sample, state_output return sample, state_output
...@@ -956,10 +956,10 @@ def test_forward_ptb(): ...@@ -956,10 +956,10 @@ def test_forward_ptb():
return samples, state return samples, state
with tf.Graph().as_default(): with tf.Graph().as_default():
word_to_id, id_to_word, graph_def = nnvm.testing.tf.get_workload_ptb() word_to_id, id_to_word, graph_def = tf_testing.get_workload_ptb()
vocab_size = len(word_to_id) vocab_size = len(word_to_id)
# Call the utility to import the graph definition into default graph. # Call the utility to import the graph definition into default graph.
graph_def = nnvm.testing.tf.ProcessGraphDefParam(graph_def) graph_def = tf_testing.ProcessGraphDefParam(graph_def)
sess = tf.Session() sess = tf.Session()
#TVM graph module creation #TVM graph module creation
...@@ -975,7 +975,7 @@ def test_forward_ptb(): ...@@ -975,7 +975,7 @@ def test_forward_ptb():
for word in seed_for_sample], for word in seed_for_sample],
in_state, params, cnt_sample) in_state, params, cnt_sample)
tvm_sample_str = _pretty_print(tvm_samples, False, id_to_word) tvm_sample_str = _pretty_print(tvm_samples, False, id_to_word)
tf_samples, tf_state = nnvm.testing.tf.do_tf_sample(sess, tf_samples, tf_state = tf_testing.do_tf_sample(sess,
[word_to_id[word] for word in seed_for_sample], [word_to_id[word] for word in seed_for_sample],
in_state, cnt_sample) in_state, cnt_sample)
tf_sample_str = _pretty_print(tf_samples, False, id_to_word) tf_sample_str = _pretty_print(tf_samples, False, id_to_word)
......
...@@ -13,3 +13,4 @@ from .onnx import from_onnx ...@@ -13,3 +13,4 @@ from .onnx import from_onnx
from .tflite import from_tflite from .tflite import from_tflite
from .coreml import from_coreml from .coreml import from_coreml
from .caffe2 import from_caffe2 from .caffe2 import from_caffe2
from .tensorflow import from_tensorflow
...@@ -16,7 +16,7 @@ from tensorflow.python.ops import array_ops ...@@ -16,7 +16,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import variables from tensorflow.python.ops import variables
from tensorflow.contrib.lite.python import interpreter as interpreter_wrapper from tensorflow.contrib.lite.python import interpreter as interpreter_wrapper
import nnvm.testing.tf import tvm.relay.testing.tf as tf_testing
####################################################################### #######################################################################
# Generic run functions for TVM & TFLite # Generic run functions for TVM & TFLite
...@@ -344,7 +344,7 @@ def test_forward_mobilenet(): ...@@ -344,7 +344,7 @@ def test_forward_mobilenet():
'''test mobilenet v1 tflite model''' '''test mobilenet v1 tflite model'''
# MobilenetV1 # MobilenetV1
temp = util.tempdir() temp = util.tempdir()
tflite_model_file = nnvm.testing.tf.get_workload_official( tflite_model_file = tf_testing.get_workload_official(
"http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224.tgz", "http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224.tgz",
"mobilenet_v1_1.0_224.tflite", temp) "mobilenet_v1_1.0_224.tflite", temp)
tflite_model_buf = open(tflite_model_file, "rb").read() tflite_model_buf = open(tflite_model_file, "rb").read()
......
...@@ -42,6 +42,9 @@ python3 -m nose -v tests/python/frontend/onnx || exit -1 ...@@ -42,6 +42,9 @@ python3 -m nose -v tests/python/frontend/onnx || exit -1
echo "Running relay CoreML frondend test..." echo "Running relay CoreML frondend test..."
python3 -m nose -v tests/python/frontend/coreml || exit -1 python3 -m nose -v tests/python/frontend/coreml || exit -1
echo "Running relay Tensorflow frontend test..."
python3 -m nose -v tests/python/frontend/tensorflow || exit -1
echo "Running nnvm to relay frontend test..." echo "Running nnvm to relay frontend test..."
python3 -m nose -v tests/python/frontend/nnvm_to_relay || exit -1 python3 -m nose -v tests/python/frontend/nnvm_to_relay || exit -1
...@@ -50,4 +53,3 @@ python3 -m nose -v tests/python/frontend/tflite || exit -1 ...@@ -50,4 +53,3 @@ python3 -m nose -v tests/python/frontend/tflite || exit -1
echo "Running relay caffe2 frondend test..." echo "Running relay caffe2 frondend test..."
python3 -m nose -v tests/python/frontend/caffe2 || exit -1 python3 -m nose -v tests/python/frontend/caffe2 || exit -1
"""
Compile Tensorflow Models
=========================
This article is an introductory tutorial to deploy tensorflow models with TVM.
For us to begin with, tensorflow python module is required to be installed.
Please refer to https://www.tensorflow.org/install
"""
# tvm, relay
import tvm
from tvm import relay
# os and numpy
import numpy as np
import os.path
# Tensorflow imports
import tensorflow as tf
# Tensorflow utility functions
import tvm.relay.testing.tf as tf_testing
# Base location for model related files.
repo_base = 'https://github.com/dmlc/web-data/raw/master/tensorflow/models/InceptionV1/'
# Test image
img_name = 'elephant-299.jpg'
image_url = os.path.join(repo_base, img_name)
######################################################################
# Tutorials
# ---------
# Please refer docs/frontend/tensorflow.md for more details for various models
# from tensorflow.
model_name = 'classify_image_graph_def-with_shapes.pb'
model_url = os.path.join(repo_base, model_name)
# Image label map
map_proto = 'imagenet_2012_challenge_label_map_proto.pbtxt'
map_proto_url = os.path.join(repo_base, map_proto)
# Human readable text for labels
lable_map = 'imagenet_synset_to_human_label_map.txt'
lable_map_url = os.path.join(repo_base, lable_map)
# Target settings
# Use these commented settings to build for cuda.
#target = 'cuda'
#target_host = 'llvm'
#layout = "NCHW"
#ctx = tvm.gpu(0)
target = 'llvm'
target_host = 'llvm'
layout = None
ctx = tvm.cpu(0)
######################################################################
# Download required files
# -----------------------
# Download files listed above.
from mxnet.gluon.utils import download
download(image_url, img_name)
download(model_url, model_name)
download(map_proto_url, map_proto)
download(lable_map_url, lable_map)
######################################################################
# Import model
# ------------
# Creates tensorflow graph definition from protobuf file.
with tf.gfile.FastGFile(os.path.join("./", model_name), 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
graph = tf.import_graph_def(graph_def, name='')
# Call the utility to import the graph definition into default graph.
graph_def = tf_testing.ProcessGraphDefParam(graph_def)
# Add shapes to the graph.
with tf.Session() as sess:
graph_def = tf_testing.AddShapesToGraphDef(sess, 'softmax')
######################################################################
# Decode image
# ------------
# .. note::
#
# tensorflow frontend import doesn't support preprocessing ops like JpegDecode.
# JpegDecode is bypassed (just return source node).
# Hence we supply decoded frame to TVM instead.
#
from PIL import Image
image = Image.open(img_name).resize((299, 299))
x = np.array(image)
######################################################################
# Import the graph to Relay
# -------------------------
# Import tensorflow graph definition to relay frontend.
#
# Results:
# sym: relay expr for given tensorflow protobuf.
# params: params converted from tensorflow params (tensor protobuf).
shape_dict = {'DecodeJpeg/contents': x.shape}
dtype_dict = {'DecodeJpeg/contents': 'uint8'}
sym, params = relay.frontend.from_tensorflow(graph_def, layout=layout, shape=shape_dict)
print ("Tensorflow protobuf imported to relay frontend.")
######################################################################
# Relay Build
# -----------
# Compile the graph to llvm target with given input specification.
#
# Results:
# graph: Final graph after compilation.
# params: final params after compilation.
# lib: target library which can be deployed on target with tvm runtime.
with relay.build_config(opt_level=3):
graph, lib, params = relay.build(sym, target=target, target_host=target_host, params=params)
######################################################################
# Execute the portable graph on TVM
# ---------------------------------
# Now we can try deploying the compiled model on target.
from tvm.contrib import graph_runtime
dtype = 'uint8'
m = graph_runtime.create(graph, lib, ctx)
# set inputs
m.set_input('DecodeJpeg/contents', tvm.nd.array(x.astype(dtype)))
m.set_input(**params)
# execute
m.run()
# get outputs
tvm_output = m.get_output(0, tvm.nd.empty(((1, 1008)), 'float32'))
######################################################################
# Process the output
# ------------------
# Process the model output to human readable text for InceptionV1.
predictions = tvm_output.asnumpy()
predictions = np.squeeze(predictions)
# Creates node ID --> English string lookup.
node_lookup = tf_testing.NodeLookup(label_lookup_path=os.path.join("./", map_proto),
uid_lookup_path=os.path.join("./", lable_map))
# Print top 5 predictions from TVM output.
top_k = predictions.argsort()[-5:][::-1]
for node_id in top_k:
human_string = node_lookup.id_to_string(node_id)
score = predictions[node_id]
print('%s (score = %.5f)' % (human_string, score))
######################################################################
# Inference on tensorflow
# -----------------------
# Run the corresponding model on tensorflow
def create_graph():
"""Creates a graph from saved GraphDef file and returns a saver."""
# Creates graph from saved graph_def.pb.
with tf.gfile.FastGFile(model_name, 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
graph = tf.import_graph_def(graph_def, name='')
# Call the utility to import the graph definition into default graph.
graph_def = tf_testing.ProcessGraphDefParam(graph_def)
def run_inference_on_image(image):
"""Runs inference on an image.
Parameters
----------
image: String
Image file name.
Returns
-------
Nothing
"""
if not tf.gfile.Exists(image):
tf.logging.fatal('File does not exist %s', image)
image_data = tf.gfile.FastGFile(image, 'rb').read()
# Creates graph from saved GraphDef.
create_graph()
with tf.Session() as sess:
softmax_tensor = sess.graph.get_tensor_by_name('softmax:0')
predictions = sess.run(softmax_tensor,
{'DecodeJpeg/contents:0': image_data})
predictions = np.squeeze(predictions)
# Creates node ID --> English string lookup.
node_lookup = tf_testing.NodeLookup(label_lookup_path=os.path.join("./", map_proto),
uid_lookup_path=os.path.join("./", lable_map))
# Print top 5 predictions from tensorflow.
top_k = predictions.argsort()[-5:][::-1]
print ("===== TENSORFLOW RESULTS =======")
for node_id in top_k:
human_string = node_lookup.id_to_string(node_id)
score = predictions[node_id]
print('%s (score = %.5f)' % (human_string, score))
run_inference_on_image (img_name)
...@@ -23,7 +23,7 @@ from tensorflow.python.framework import dtypes ...@@ -23,7 +23,7 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_util from tensorflow.python.framework import tensor_util
# Tensorflow utility functions # Tensorflow utility functions
import nnvm.testing.tf import tvm.relay.testing.tf as tf_testing
# Base location for model related files. # Base location for model related files.
repo_base = 'https://github.com/dmlc/web-data/raw/master/tensorflow/models/InceptionV1/' repo_base = 'https://github.com/dmlc/web-data/raw/master/tensorflow/models/InceptionV1/'
...@@ -87,10 +87,10 @@ with tf.gfile.FastGFile(os.path.join("./", model_name), 'rb') as f: ...@@ -87,10 +87,10 @@ with tf.gfile.FastGFile(os.path.join("./", model_name), 'rb') as f:
graph_def.ParseFromString(f.read()) graph_def.ParseFromString(f.read())
graph = tf.import_graph_def(graph_def, name='') graph = tf.import_graph_def(graph_def, name='')
# Call the utility to import the graph definition into default graph. # Call the utility to import the graph definition into default graph.
graph_def = nnvm.testing.tf.ProcessGraphDefParam(graph_def) graph_def = tf_testing.ProcessGraphDefParam(graph_def)
# Add shapes to the graph. # Add shapes to the graph.
with tf.Session() as sess: with tf.Session() as sess:
graph_def = nnvm.testing.tf.AddShapesToGraphDef(sess, 'softmax') graph_def = tf_testing.AddShapesToGraphDef(sess, 'softmax')
###################################################################### ######################################################################
# Decode image # Decode image
...@@ -157,7 +157,7 @@ predictions = tvm_output.asnumpy() ...@@ -157,7 +157,7 @@ predictions = tvm_output.asnumpy()
predictions = np.squeeze(predictions) predictions = np.squeeze(predictions)
# Creates node ID --> English string lookup. # Creates node ID --> English string lookup.
node_lookup = nnvm.testing.tf.NodeLookup(label_lookup_path=os.path.join("./", map_proto), node_lookup = tf_testing.NodeLookup(label_lookup_path=os.path.join("./", map_proto),
uid_lookup_path=os.path.join("./", lable_map)) uid_lookup_path=os.path.join("./", lable_map))
# Print top 5 predictions from TVM output. # Print top 5 predictions from TVM output.
...@@ -180,7 +180,7 @@ def create_graph(): ...@@ -180,7 +180,7 @@ def create_graph():
graph_def.ParseFromString(f.read()) graph_def.ParseFromString(f.read())
graph = tf.import_graph_def(graph_def, name='') graph = tf.import_graph_def(graph_def, name='')
# Call the utility to import the graph definition into default graph. # Call the utility to import the graph definition into default graph.
graph_def = nnvm.testing.tf.ProcessGraphDefParam(graph_def) graph_def = tf_testing.ProcessGraphDefParam(graph_def)
def run_inference_on_image(image): def run_inference_on_image(image):
"""Runs inference on an image. """Runs inference on an image.
...@@ -209,7 +209,7 @@ def run_inference_on_image(image): ...@@ -209,7 +209,7 @@ def run_inference_on_image(image):
predictions = np.squeeze(predictions) predictions = np.squeeze(predictions)
# Creates node ID --> English string lookup. # Creates node ID --> English string lookup.
node_lookup = nnvm.testing.tf.NodeLookup(label_lookup_path=os.path.join("./", map_proto), node_lookup = tf_testing.NodeLookup(label_lookup_path=os.path.join("./", map_proto),
uid_lookup_path=os.path.join("./", lable_map)) uid_lookup_path=os.path.join("./", lable_map))
# Print top 5 predictions from tensorflow. # Print top 5 predictions from tensorflow.
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment