Commit 2f859d71 authored by Siva, committed by Yizhi Liu

[RELAY][FRONTEND] Tensorflow frontend. (#2216)

* [RELAY][FRONTEND] Tensorflow frontend support.
* LSTM removed for a while.
* Basic ops are good.
* NN ops WIP.
* WIP.
* python2.7 corrections.
* NN ops are good.
* End-to-end models working well.
* All good except LSTM.
* Rebase, tutorials and CI trigger.
* CI errors.
* Enable opt_level=3.
* Docstrings cleanup; testing.tf utils moved from nnvm to relay.
* Tutorials update.
* LSTM works well now.
* Rebase.
* CI error.
* Enable PTB.
* Rebase.
* Tutorials.
* Update python/tvm/relay/frontend/tensorflow.py

Co-Authored-By: srkreddy1238 <sivar.b@huawei.com>

* Review comments.
* CI fix.
* Review comments.
parent 40f76825
@@ -21,7 +21,7 @@ instructions to generate protobuf from checkpoint.
 ### Add Shapes:
 While freezing of protobuf add additional option ```add_shapes=True``` to embed output shapes of each node into graph.
-You may use ```nnvm.testing.tf.AddShapesToGraphDef``` from nnvm for the same.
+You may use ```tvm.relay.testing.tf.AddShapesToGraphDef``` from tvm.relay for the same.
 Please refer to [tensorflow tutorial](https://github.com/dmlc/tvm/blob/master/tutorials/nnvm/from_tensorflow.py).
 ### Explicit Shape:
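
For reference, a minimal sketch of freezing a checkpoint with `add_shapes=True` as the doc above describes. The checkpoint path `model.ckpt` and output node `softmax` are placeholders, not part of this commit:

```python
import tensorflow as tf
from tensorflow.python.framework import graph_util

with tf.Session() as sess:
    # Restore the trained model ('model.ckpt' is a placeholder path).
    saver = tf.train.import_meta_graph('model.ckpt.meta')
    saver.restore(sess, 'model.ckpt')
    # add_shapes=True embeds each node's output shape into the GraphDef.
    graph_def = sess.graph.as_graph_def(add_shapes=True)
    # Freeze variables into constants; 'softmax' is an assumed output node.
    frozen = graph_util.convert_variables_to_constants(sess, graph_def, ['softmax'])
    with tf.gfile.GFile('frozen_model.pb', 'wb') as f:
        f.write(frozen.SerializeToString())
```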
@@ -21,7 +21,7 @@ from tensorflow.python.ops import variables
 from tensorflow.python.ops import init_ops
 from tensorflow.core.framework import graph_pb2
-import nnvm.testing.tf
+import tvm.relay.testing.tf as tf_testing
 #######################################################################
 # Generic run functions for TVM & tensorflow
@@ -784,9 +784,9 @@ def test_forward_pad():
 def test_forward_inception_v3():
     '''test inception V3 model'''
     with tf.Graph().as_default():
-        graph_def = nnvm.testing.tf.get_workload('InceptionV3/inception_v3_2016_08_28_frozen-with_shapes.pb')
+        graph_def = tf_testing.get_workload('InceptionV3/inception_v3_2016_08_28_frozen-with_shapes.pb')
         # Call the utility to import the graph definition into default graph.
-        graph_def = nnvm.testing.tf.ProcessGraphDefParam(graph_def)
+        graph_def = tf_testing.ProcessGraphDefParam(graph_def)
 
         data = np.random.uniform(size=(1, 299, 299, 3)).astype('float32')
@@ -801,9 +801,9 @@ def test_forward_inception_v3():
 def test_forward_inception_v1():
     '''test inception V1 model'''
     with tf.Graph().as_default():
-        graph_def = nnvm.testing.tf.get_workload("InceptionV1/classify_image_graph_def-with_shapes.pb")
+        graph_def = tf_testing.get_workload("InceptionV1/classify_image_graph_def-with_shapes.pb")
         # Call the utility to import the graph definition into default graph.
-        graph_def = nnvm.testing.tf.ProcessGraphDefParam(graph_def)
+        graph_def = tf_testing.ProcessGraphDefParam(graph_def)
 
         # Build an image from random data.
         from PIL import Image
@@ -838,18 +838,18 @@ def test_forward_mobilenet():
     '''test mobilenet model'''
     # MobilenetV2
     with tf.Graph().as_default():
-        graph_def = nnvm.testing.tf.get_workload(
+        graph_def = tf_testing.get_workload(
             "https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_1.4_224.tgz",
             "mobilenet_v2_1.4_224_frozen.pb")
         # Call the utility to import the graph definition into default graph.
-        graph_def = nnvm.testing.tf.ProcessGraphDefParam(graph_def)
+        graph_def = tf_testing.ProcessGraphDefParam(graph_def)
         data = np.random.uniform(size=(1, 224, 224, 3)).astype('float32')
         out_node = 'MobilenetV2/Predictions/Reshape_1'
         with tf.Session() as sess:
             # Add shapes to the graph.
-            graph_def = nnvm.testing.tf.AddShapesToGraphDef(sess, out_node)
+            graph_def = tf_testing.AddShapesToGraphDef(sess, out_node)
             tf_output = run_tf_graph(sess, data, 'input:0', out_node + ':0')
             tvm_output = run_tvm_graph(graph_def, data, 'input')
             tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tf_output[0]), rtol=1e-5, atol=1e-5)
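
For context, the `run_tvm_graph` helper used in these tests compiles the GraphDef through the new relay frontend and executes it; a rough sketch of what such a helper can look like, assembled from the APIs used elsewhere in this commit (this is an illustration, not the test file's actual implementation):

```python
import tvm
from tvm import relay
from tvm.contrib import graph_runtime

def run_tvm_graph_sketch(graph_def, input_data, input_node, target='llvm'):
    """Illustrative helper: import a frozen GraphDef into relay and run it on TVM."""
    shape_dict = {input_node: input_data.shape}
    sym, params = relay.frontend.from_tensorflow(graph_def, shape=shape_dict)
    with relay.build_config(opt_level=3):
        graph, lib, params = relay.build(sym, target=target, params=params)
    m = graph_runtime.create(graph, lib, tvm.cpu(0))
    m.set_input(input_node, tvm.nd.array(input_data))
    m.set_input(**params)
    m.run()
    return m.get_output(0).asnumpy()
```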
@@ -861,9 +861,9 @@ def test_forward_resnetv2():
     '''test resnet model'''
     if is_gpu_available():
         with tf.Graph().as_default():
-            graph_def = nnvm.testing.tf.get_workload("ResnetV2/resnet-20180601_resnet_v2_imagenet-shapes.pb")
+            graph_def = tf_testing.get_workload("ResnetV2/resnet-20180601_resnet_v2_imagenet-shapes.pb")
             # Call the utility to import the graph definition into default graph.
-            graph_def = nnvm.testing.tf.ProcessGraphDefParam(graph_def)
+            graph_def = tf_testing.ProcessGraphDefParam(graph_def)
             data = np.random.uniform(size=(128, 224, 224, 3)).astype('float32')
             out_node = 'ArgMax'
@@ -879,7 +879,7 @@ def test_forward_resnetv2():
 dir(tf.contrib)
 def test_forward_ptb():
     '''test ptb model'''
-    config = nnvm.testing.tf.get_config()
+    config = tf_testing.get_config()
     num_steps = config.num_steps
     num_hidden = config.hidden_size
     num_layers = config.num_layers
@@ -936,7 +936,7 @@ def test_forward_ptb():
                                                  "float32")).asnumpy()
         state_output = model.get_output(1, tvm.nd.empty(out_state_shape,
                                                         "float32")).asnumpy()
-        sample = nnvm.testing.tf.pick_from_weight(tvm_output[0])
+        sample = tf_testing.pick_from_weight(tvm_output[0])
 
         return sample, state_output
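
`pick_from_weight` selects the next word id from the model's output weights; a plausible numpy sketch of such weighted sampling (an assumption for illustration, not necessarily the utility's actual implementation):

```python
import numpy as np

def pick_from_weight_sketch(weight):
    """Sample an index with probability proportional to non-negative weights."""
    weight = np.asarray(weight, dtype='float64')
    cumsum = np.cumsum(weight)
    # Draw a uniform sample in [0, total) and find the bucket it falls into.
    draw = np.random.uniform(0, cumsum[-1])
    return int(np.searchsorted(cumsum, draw))
```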
@@ -956,10 +956,10 @@ def test_forward_ptb():
         return samples, state
 
     with tf.Graph().as_default():
-        word_to_id, id_to_word, graph_def = nnvm.testing.tf.get_workload_ptb()
+        word_to_id, id_to_word, graph_def = tf_testing.get_workload_ptb()
         vocab_size = len(word_to_id)
         # Call the utility to import the graph definition into default graph.
-        graph_def = nnvm.testing.tf.ProcessGraphDefParam(graph_def)
+        graph_def = tf_testing.ProcessGraphDefParam(graph_def)
         sess = tf.Session()
 
     #TVM graph module creation
@@ -975,7 +975,7 @@ def test_forward_ptb():
                                        for word in seed_for_sample],
                                       in_state, params, cnt_sample)
     tvm_sample_str = _pretty_print(tvm_samples, False, id_to_word)
-    tf_samples, tf_state = nnvm.testing.tf.do_tf_sample(sess,
+    tf_samples, tf_state = tf_testing.do_tf_sample(sess,
             [word_to_id[word] for word in seed_for_sample],
             in_state, cnt_sample)
     tf_sample_str = _pretty_print(tf_samples, False, id_to_word)
@@ -13,3 +13,4 @@ from .onnx import from_onnx
 from .tflite import from_tflite
 from .coreml import from_coreml
 from .caffe2 import from_caffe2
+from .tensorflow import from_tensorflow
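
With this registration in place the importer becomes reachable as `relay.frontend.from_tensorflow`; a minimal usage sketch, matching the call made in the tutorial below (the input name and shape here are assumptions):

```python
from tvm import relay

# graph_def is a frozen tensorflow GraphDef; 'input' and its shape are assumed.
shape_dict = {'input': (1, 224, 224, 3)}
sym, params = relay.frontend.from_tensorflow(graph_def, layout=None, shape=shape_dict)
```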
@@ -16,7 +16,7 @@ from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import variables
 from tensorflow.contrib.lite.python import interpreter as interpreter_wrapper
-import nnvm.testing.tf
+import tvm.relay.testing.tf as tf_testing
 #######################################################################
 # Generic run functions for TVM & TFLite
@@ -344,7 +344,7 @@ def test_forward_mobilenet():
     '''test mobilenet v1 tflite model'''
     # MobilenetV1
     temp = util.tempdir()
-    tflite_model_file = nnvm.testing.tf.get_workload_official(
+    tflite_model_file = tf_testing.get_workload_official(
         "http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224.tgz",
         "mobilenet_v1_1.0_224.tflite", temp)
     tflite_model_buf = open(tflite_model_file, "rb").read()
@@ -42,6 +42,9 @@ python3 -m nose -v tests/python/frontend/onnx || exit -1
 echo "Running relay CoreML frondend test..."
 python3 -m nose -v tests/python/frontend/coreml || exit -1
 
+echo "Running relay Tensorflow frontend test..."
+python3 -m nose -v tests/python/frontend/tensorflow || exit -1
+
 echo "Running nnvm to relay frontend test..."
 python3 -m nose -v tests/python/frontend/nnvm_to_relay || exit -1

@@ -50,4 +53,3 @@ python3 -m nose -v tests/python/frontend/tflite || exit -1
 
 echo "Running relay caffe2 frondend test..."
 python3 -m nose -v tests/python/frontend/caffe2 || exit -1
"""
Compile Tensorflow Models
=========================
This article is an introductory tutorial to deploy tensorflow models with TVM.
To begin, the tensorflow python module must be installed.
Please refer to https://www.tensorflow.org/install
"""
# tvm, relay
import tvm
from tvm import relay
# os and numpy
import numpy as np
import os.path
# Tensorflow imports
import tensorflow as tf
# Tensorflow utility functions
import tvm.relay.testing.tf as tf_testing
# Base location for model related files.
repo_base = 'https://github.com/dmlc/web-data/raw/master/tensorflow/models/InceptionV1/'
# Test image
img_name = 'elephant-299.jpg'
image_url = os.path.join(repo_base, img_name)
######################################################################
# Tutorials
# ---------
# Please refer to docs/frontend/tensorflow.md for more details on various
# tensorflow models.
model_name = 'classify_image_graph_def-with_shapes.pb'
model_url = os.path.join(repo_base, model_name)
# Image label map
map_proto = 'imagenet_2012_challenge_label_map_proto.pbtxt'
map_proto_url = os.path.join(repo_base, map_proto)
# Human readable text for labels
lable_map = 'imagenet_synset_to_human_label_map.txt'
lable_map_url = os.path.join(repo_base, lable_map)
# Target settings
# Use these commented settings to build for cuda.
#target = 'cuda'
#target_host = 'llvm'
#layout = "NCHW"
#ctx = tvm.gpu(0)
target = 'llvm'
target_host = 'llvm'
layout = None
ctx = tvm.cpu(0)
######################################################################
# Download required files
# -----------------------
# Download files listed above.
from mxnet.gluon.utils import download
download(image_url, img_name)
download(model_url, model_name)
download(map_proto_url, map_proto)
download(lable_map_url, lable_map)
######################################################################
# Import model
# ------------
# Creates tensorflow graph definition from protobuf file.
with tf.gfile.FastGFile(os.path.join("./", model_name), 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    graph = tf.import_graph_def(graph_def, name='')
    # Call the utility to import the graph definition into default graph.
    graph_def = tf_testing.ProcessGraphDefParam(graph_def)
    # Add shapes to the graph.
    with tf.Session() as sess:
        graph_def = tf_testing.AddShapesToGraphDef(sess, 'softmax')
######################################################################
# Decode image
# ------------
# .. note::
#
# tensorflow frontend import doesn't support preprocessing ops like JpegDecode.
# JpegDecode is bypassed (just return source node).
# Hence we supply decoded frame to TVM instead.
#
from PIL import Image
image = Image.open(img_name).resize((299, 299))
x = np.array(image)
######################################################################
# Import the graph to Relay
# -------------------------
# Import tensorflow graph definition to relay frontend.
#
# Results:
# sym: relay expr for given tensorflow protobuf.
# params: params converted from tensorflow params (tensor protobuf).
shape_dict = {'DecodeJpeg/contents': x.shape}
dtype_dict = {'DecodeJpeg/contents': 'uint8'}
sym, params = relay.frontend.from_tensorflow(graph_def, layout=layout, shape=shape_dict)
print ("Tensorflow protobuf imported to relay frontend.")
######################################################################
# Relay Build
# -----------
# Compile the graph to llvm target with given input specification.
#
# Results:
# graph: Final graph after compilation.
# params: final params after compilation.
# lib: target library which can be deployed on target with tvm runtime.
with relay.build_config(opt_level=3):
    graph, lib, params = relay.build(sym, target=target, target_host=target_host, params=params)
######################################################################
# Execute the portable graph on TVM
# ---------------------------------
# Now we can try deploying the compiled model on target.
from tvm.contrib import graph_runtime
dtype = 'uint8'
m = graph_runtime.create(graph, lib, ctx)
# set inputs
m.set_input('DecodeJpeg/contents', tvm.nd.array(x.astype(dtype)))
m.set_input(**params)
# execute
m.run()
# get outputs
tvm_output = m.get_output(0, tvm.nd.empty((1, 1008), 'float32'))
######################################################################
# Process the output
# ------------------
# Process the model output to human readable text for InceptionV1.
predictions = tvm_output.asnumpy()
predictions = np.squeeze(predictions)
# Creates node ID --> English string lookup.
node_lookup = tf_testing.NodeLookup(label_lookup_path=os.path.join("./", map_proto),
                                    uid_lookup_path=os.path.join("./", lable_map))

# Print top 5 predictions from TVM output.
top_k = predictions.argsort()[-5:][::-1]
for node_id in top_k:
    human_string = node_lookup.id_to_string(node_id)
    score = predictions[node_id]
    print('%s (score = %.5f)' % (human_string, score))
######################################################################
# Inference on tensorflow
# -----------------------
# Run the corresponding model on tensorflow
def create_graph():
    """Creates a graph from saved GraphDef file and returns a saver."""
    # Creates graph from saved graph_def.pb.
    with tf.gfile.FastGFile(model_name, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        graph = tf.import_graph_def(graph_def, name='')
        # Call the utility to import the graph definition into default graph.
        graph_def = tf_testing.ProcessGraphDefParam(graph_def)
def run_inference_on_image(image):
    """Runs inference on an image.

    Parameters
    ----------
    image: String
        Image file name.

    Returns
    -------
        Nothing
    """
    if not tf.gfile.Exists(image):
        tf.logging.fatal('File does not exist %s', image)
    image_data = tf.gfile.FastGFile(image, 'rb').read()

    # Creates graph from saved GraphDef.
    create_graph()

    with tf.Session() as sess:
        softmax_tensor = sess.graph.get_tensor_by_name('softmax:0')
        predictions = sess.run(softmax_tensor,
                               {'DecodeJpeg/contents:0': image_data})
        predictions = np.squeeze(predictions)

        # Creates node ID --> English string lookup.
        node_lookup = tf_testing.NodeLookup(label_lookup_path=os.path.join("./", map_proto),
                                            uid_lookup_path=os.path.join("./", lable_map))

        # Print top 5 predictions from tensorflow.
        top_k = predictions.argsort()[-5:][::-1]
        print("===== TENSORFLOW RESULTS =======")
        for node_id in top_k:
            human_string = node_lookup.id_to_string(node_id)
            score = predictions[node_id]
            print('%s (score = %.5f)' % (human_string, score))

run_inference_on_image(img_name)
@@ -23,7 +23,7 @@ from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import tensor_util
 # Tensorflow utility functions
-import nnvm.testing.tf
+import tvm.relay.testing.tf as tf_testing
 
 # Base location for model related files.
 repo_base = 'https://github.com/dmlc/web-data/raw/master/tensorflow/models/InceptionV1/'
@@ -87,10 +87,10 @@ with tf.gfile.FastGFile(os.path.join("./", model_name), 'rb') as f:
     graph_def.ParseFromString(f.read())
     graph = tf.import_graph_def(graph_def, name='')
     # Call the utility to import the graph definition into default graph.
-    graph_def = nnvm.testing.tf.ProcessGraphDefParam(graph_def)
+    graph_def = tf_testing.ProcessGraphDefParam(graph_def)
     # Add shapes to the graph.
     with tf.Session() as sess:
-        graph_def = nnvm.testing.tf.AddShapesToGraphDef(sess, 'softmax')
+        graph_def = tf_testing.AddShapesToGraphDef(sess, 'softmax')
 
 ######################################################################
 # Decode image
@@ -157,7 +157,7 @@ predictions = tvm_output.asnumpy()
 predictions = np.squeeze(predictions)
 
 # Creates node ID --> English string lookup.
-node_lookup = nnvm.testing.tf.NodeLookup(label_lookup_path=os.path.join("./", map_proto),
+node_lookup = tf_testing.NodeLookup(label_lookup_path=os.path.join("./", map_proto),
                                     uid_lookup_path=os.path.join("./", lable_map))
 
 # Print top 5 predictions from TVM output.
@@ -180,7 +180,7 @@ def create_graph():
     graph_def.ParseFromString(f.read())
     graph = tf.import_graph_def(graph_def, name='')
     # Call the utility to import the graph definition into default graph.
-    graph_def = nnvm.testing.tf.ProcessGraphDefParam(graph_def)
+    graph_def = tf_testing.ProcessGraphDefParam(graph_def)
 
 def run_inference_on_image(image):
     """Runs inference on an image.
@@ -209,7 +209,7 @@ def run_inference_on_image(image):
         predictions = np.squeeze(predictions)
 
         # Creates node ID --> English string lookup.
-        node_lookup = nnvm.testing.tf.NodeLookup(label_lookup_path=os.path.join("./", map_proto),
+        node_lookup = tf_testing.NodeLookup(label_lookup_path=os.path.join("./", map_proto),
                                             uid_lookup_path=os.path.join("./", lable_map))
 
         # Print top 5 predictions from tensorflow.