Commit a81ebd90 by Siva Committed by Tianqi Chen

[NNVM][FRONTEND] Tensorflow frontend support (#1188)

parent 7afeab07
......@@ -270,6 +270,10 @@ def build(graph, target=None, shape=None, dtype="float32",
# Apply optimization
with target:
graph = optimize(graph, shape, dtype, layout)
# Clear extra params without nodes.
_remove_noref_params(params, graph)
# Precompute prune
if params and cfg.pass_enabled("PrecomputePrune"):
graph, params = precompute_prune(graph, params)
......@@ -296,6 +300,24 @@ def build(graph, target=None, shape=None, dtype="float32",
params.update(init_var)
return graph, libmod, params
def _remove_noref_params(params, graph):
""" Helper to clear non referenced params
Parameters
----------
graph : Graph
The input graph
params: dict of str to ndarray
The parameter dictionary
"""
arg_list = set(graph.symbol.list_input_names())
if params:
param_keys = list(params.keys())
for key in param_keys:
if key not in arg_list:
params.pop(key)
def _run_graph(graph, params):
"""Helper utility to build and run and get outputs, only use cpu mode.
......
......@@ -5,3 +5,4 @@ from .onnx import from_onnx
from .coreml import from_coreml
from .keras import from_keras
from .darknet import from_darknet
from .tensorflow import from_tensorflow
# pylint: disable=invalid-name, unused-variable, unused-argument, no-init
"""
Tensorflow Model Helpers
========================
Some helper definitions for tensorflow models.
"""
import re
import os.path
import numpy as np
# Tensorflow imports
import tensorflow as tf
from tensorflow.core.framework import graph_pb2
######################################################################
# Some helper functions
# ---------------------
def ProcessGraphDefParam(graph_def):
"""Type-checks and possibly canonicalizes `graph_def`.
Parameters
----------
graph_def : Obj
tensorflow graph definition.
Returns
-------
graph_def : Obj
tensorflow graph devinition
"""
if not isinstance(graph_def, graph_pb2.GraphDef):
# `graph_def` could be a dynamically-created message, so try a duck-typed
# approach
try:
old_graph_def = graph_def
graph_def = graph_pb2.GraphDef()
graph_def.MergeFrom(old_graph_def)
except TypeError:
raise TypeError('graph_def must be a GraphDef proto.')
return graph_def
class NodeLookup(object):
"""Converts integer node ID's to human readable labels."""
def __init__(self,
label_lookup_path=None,
uid_lookup_path=None):
self.node_lookup = self.load(label_lookup_path, uid_lookup_path)
def load(self, label_lookup_path, uid_lookup_path):
"""Loads a human readable English name for each softmax node.
Parameters
----------
label_lookup_path: String
File containing String UID to integer node ID mapping .
uid_lookup_path: String
File containing String UID to human-readable string mapping.
Returns
-------
node_id_to_name: dict
dict from integer node ID to human-readable string.
"""
if not tf.gfile.Exists(uid_lookup_path):
tf.logging.fatal('File does not exist %s', uid_lookup_path)
if not tf.gfile.Exists(label_lookup_path):
tf.logging.fatal('File does not exist %s', label_lookup_path)
# Loads mapping from string UID to human-readable string
proto_as_ascii_lines = tf.gfile.GFile(uid_lookup_path).readlines()
uid_to_human = {}
p = re.compile(r'[n\d]*[ \S,]*')
for line in proto_as_ascii_lines:
parsed_items = p.findall(line)
uid = parsed_items[0]
human_string = parsed_items[2]
uid_to_human[uid] = human_string
# Loads mapping from string UID to integer node ID.
node_id_to_uid = {}
proto_as_ascii = tf.gfile.GFile(label_lookup_path).readlines()
for line in proto_as_ascii:
if line.startswith(' target_class:'):
target_class = int(line.split(': ')[1])
if line.startswith(' target_class_string:'):
target_class_string = line.split(': ')[1]
node_id_to_uid[target_class] = target_class_string[1:-2]
# Loads the final mapping of integer node ID to human-readable string
node_id_to_name = {}
for key, val in node_id_to_uid.items():
if val not in uid_to_human:
tf.logging.fatal('Failed to locate: %s', val)
name = uid_to_human[val]
node_id_to_name[key] = name
return node_id_to_name
def id_to_string(self, node_id):
if node_id not in self.node_lookup:
return ''
return self.node_lookup[node_id]
def read_normalized_tensor_from_image_file(file_name,
input_height=299,
input_width=299,
input_mean=0,
input_std=255):
""" Preprocessing of image
Parameters
----------
file_name: String
Image filename.
input_height: int
model input height.
input_width: int
model input width
input_mean: int
Mean to be substracted in normalization.
input_std: int
Standard deviation used in normalization.
Returns
-------
np_array: Numpy array
Normalized image data as a numpy array.
"""
input_name = "file_reader"
output_name = "normalized"
file_reader = tf.read_file(file_name, input_name)
image_reader = tf.image.decode_jpeg(file_reader, channels=3,
name='jpeg_reader')
float_caster = tf.cast(image_reader, tf.float32)
dims_expander = tf.expand_dims(float_caster, 0)
resized = tf.image.resize_bilinear(dims_expander, [input_height, input_width])
normalized = tf.divide(tf.subtract(resized, [input_mean]), [input_std])
tf.InteractiveSession()
np_array = normalized.eval()
return np_array
def get_workload_inception_v3():
""" Import Inception V3 workload from frozen protobuf
Parameters
----------
Nothing.
Returns
-------
(normalized, graph_def) : Tuple
normalized is normalized input for graph testing.
graph_def is the tensorflow workload for Inception V3.
"""
repo_base = 'https://github.com/dmlc/web-data/raw/master/tensorflow/models/InceptionV3/'
model_name = 'inception_v3_2016_08_28_frozen-with_shapes.pb'
model_url = os.path.join(repo_base, model_name)
image_name = 'elephant-299.jpg'
image_url = os.path.join(repo_base, image_name)
from mxnet.gluon.utils import download
download(model_url, model_name)
download(image_url, image_name)
normalized = read_normalized_tensor_from_image_file(os.path.join("./", image_name))
# Creates graph from saved graph_def.pb.
with tf.gfile.FastGFile(os.path.join("./", model_name), 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
graph = tf.import_graph_def(graph_def, name='')
return (normalized, graph_def)
def get_workload_inception_v1():
""" Import Inception V1 workload from frozen protobuf
Parameters
----------
Nothing.
Returns
-------
(image_data, tvm_data, graph_def) : Tuple
image_data is raw encoded image data for TF input.
tvm_data is the decoded image data for TVM input.
graph_def is the tensorflow workload for Inception V1.
"""
repo_base = 'https://github.com/dmlc/web-data/raw/master/tensorflow/models/InceptionV1/'
model_name = 'classify_image_graph_def-with_shapes.pb'
model_url = os.path.join(repo_base, model_name)
image_name = 'elephant-299.jpg'
image_url = os.path.join(repo_base, image_name)
from mxnet.gluon.utils import download
download(model_url, model_name)
download(image_url, image_name)
if not tf.gfile.Exists(os.path.join("./", image_name)):
tf.logging.fatal('File does not exist %s', image)
image_data = tf.gfile.FastGFile(os.path.join("./", image_name), 'rb').read()
# TVM doesn't handle decode, hence decode it.
from PIL import Image
tvm_data = Image.open(os.path.join("./", image_name)).resize((299, 299))
tvm_data = np.array(tvm_data)
# Creates graph from saved graph_def.pb.
with tf.gfile.FastGFile(os.path.join("./", model_name), 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
graph = tf.import_graph_def(graph_def, name='')
return (image_data, tvm_data, graph_def)
......@@ -52,6 +52,15 @@ reg.register_schedule("_assign", _fschedule_broadcast)
reg.register_pattern("copy", OpPattern.ELEMWISE)
reg.register_schedule("copy", _fschedule_broadcast)
# cast
@reg.register_compute("cast")
def compute_cast(attrs, inputs, _):
"""Compute definition of cast"""
dtype = attrs.get_string("dtype")
return topi.cast(inputs[0], dtype)
reg.register_pattern("cast", OpPattern.ELEMWISE)
reg.register_schedule("cast", _fschedule_broadcast)
# exp
reg.register_pattern("exp", OpPattern.ELEMWISE)
reg.register_schedule("exp", _fschedule_broadcast)
......
......@@ -18,3 +18,6 @@ python3 -m nose -v nnvm/tests/python/frontend/mxnet || exit -1
echo "Running Keras frontend test..."
python3 -m nose -v nnvm/tests/python/frontend/keras || exit -1
echo "Running Tensorflow frontend test..."
python3 -m nose -v nnvm/tests/python/frontend/tensorflow || exit -1
"""
Compile Tensorflow Models
=========================
This article is an introductory tutorial to deploy tensorflow models with NNVM.
For us to begin with, tensorflow module is required to be installed.
A quick solution is to install tensorlfow from
https://www.tensorflow.org/install/install_sources
"""
import nnvm
import tvm
import numpy as np
import os.path
# Tensorflow imports
import tensorflow as tf
from tensorflow.core.framework import graph_pb2
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_util
import nnvm.testing.tf
repo_base = 'https://github.com/dmlc/web-data/raw/master/tensorflow/models/InceptionV1/'
img_name = 'elephant-299.jpg'
image_url = os.path.join(repo_base, img_name)
model_name = 'classify_image_graph_def-with_shapes.pb'
model_url = os.path.join(repo_base, model_name)
map_proto = 'imagenet_2012_challenge_label_map_proto.pbtxt'
map_proto_url = os.path.join(repo_base, map_proto)
lable_map = 'imagenet_synset_to_human_label_map.txt'
lable_map_url = os.path.join(repo_base, lable_map)
######################################################################
# Download processed tensorflow model
# -----------------------------------
# In this section, we download a pretrained Tensorflow model and classify an image.
from mxnet.gluon.utils import download
download(image_url, img_name)
download(model_url, model_name)
download(map_proto_url, map_proto)
download(lable_map_url, lable_map)
######################################################################
# Creates graph from saved graph_def.pb.
# --------------------------------------
with tf.gfile.FastGFile(os.path.join(
"./", model_name), 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
graph = tf.import_graph_def(graph_def, name='')
# Call the utility to import the graph definition into default graph.
graph_def = nnvm.testing.tf.ProcessGraphDefParam(graph_def)
######################################################################
# Decode image
# ------------
from PIL import Image
image = Image.open(img_name).resize((299, 299))
def transform_image(image):
image = np.array(image)
return image
x = transform_image(image)
######################################################################
# Import the graph to NNVM
# ------------------------
sym, params = nnvm.frontend.from_tensorflow(graph_def)
######################################################################
# Now compile the graph through NNVM
import nnvm.compiler
target = 'llvm'
shape_dict = {'DecodeJpeg/contents': x.shape}
dtype_dict = {'DecodeJpeg/contents': 'uint8'}
graph, lib, params = nnvm.compiler.build(sym, target, shape_dict, dtype=dtype_dict, params=params)
######################################################################
# Execute the portable graph on TVM
# ---------------------------------
# Now, we would like to reproduce the same forward computation using TVM.
from tvm.contrib import graph_runtime
ctx = tvm.cpu(0)
dtype = 'uint8'
m = graph_runtime.create(graph, lib, ctx)
# set inputs
m.set_input('DecodeJpeg/contents', tvm.nd.array(x.astype(dtype)))
m.set_input(**params)
# execute
m.run()
# get outputs
tvm_output = m.get_output(0, tvm.nd.empty(((1, 1008)), 'float32'))
######################################################################
# Process the output to human readable
# ------------------------------------
predictions = tvm_output.asnumpy()
predictions = np.squeeze(predictions)
# Creates node ID --> English string lookup.
node_lookup = nnvm.testing.tf.NodeLookup(label_lookup_path=os.path.join("./", map_proto),
uid_lookup_path=os.path.join("./", lable_map))
top_k = predictions.argsort()[-5:][::-1]
for node_id in top_k:
human_string = node_lookup.id_to_string(node_id)
score = predictions[node_id]
print('%s (score = %.5f)' % (human_string, score))
######################################################################
# Run the same graph with tensorflow and dump output.
# ---------------------------------------------------
def create_graph():
"""Creates a graph from saved GraphDef file and returns a saver."""
# Creates graph from saved graph_def.pb.
with tf.gfile.FastGFile(model_name, 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
graph = tf.import_graph_def(graph_def, name='')
# Call the utility to import the graph definition into default graph.
graph_def = nnvm.testing.tf.ProcessGraphDefParam(graph_def)
def run_inference_on_image(image):
"""Runs inference on an image.
Parameters
----------
image: String
Image file name.
Returns
-------
Nothing
"""
if not tf.gfile.Exists(image):
tf.logging.fatal('File does not exist %s', image)
image_data = tf.gfile.FastGFile(image, 'rb').read()
# Creates graph from saved GraphDef.
create_graph()
with tf.Session() as sess:
softmax_tensor = sess.graph.get_tensor_by_name('softmax:0')
predictions = sess.run(softmax_tensor,
{'DecodeJpeg/contents:0': image_data})
predictions = np.squeeze(predictions)
# Creates node ID --> English string lookup.
node_lookup = nnvm.testing.tf.NodeLookup(label_lookup_path=os.path.join("./", map_proto),
uid_lookup_path=os.path.join("./", lable_map))
top_k = predictions.argsort()[-5:][::-1]
print ("===== TENSORFLOW RESULTS =======")
for node_id in top_k:
human_string = node_lookup.id_to_string(node_id)
score = predictions[node_id]
print('%s (score = %.5f)' % (human_string, score))
run_inference_on_image (img_name)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment