Commit 2f859d71 by Siva, committed by Yizhi Liu

[RELAY][FRONTEND] Tensorflow frontend. (#2216)

* [RELAY][FRONTEND] Tensorflow frontend support.
* LSTM removed for a while.
* Basic ops are good.
* nn wip
* wip
* python2.7 corrections.
* NN ops are good.
* e2e models working good.
* All good except LSTM.
* Rebase, tutorials and CI trigger.
* CI errors.
* Enable opt_level=3.
* Docstrings cleanup. testing.tf utils moved to relay from nnvm.
* Tutorials update.
* LSTM works good now.
* Rebase.
* CI error.
* Enable PTB.
* Rebase.
* Tutorials.
* Update python/tvm/relay/frontend/tensorflow.py

Co-Authored-By: srkreddy1238 <sivar.b@huawei.com>

* Review comments.
* CI fix.
* Review comments.
parent 40f76825
@@ -21,7 +21,7 @@ instructions to generate protobuf from checkpoint.
### Add Shapes:
While freezing the protobuf, add the option ```add_shapes=True``` to embed the output shape of each node into the graph.
-You may use ```nnvm.testing.tf.AddShapesToGraphDef``` from nnvm for the same.
+You may use ```tvm.relay.testing.tf.AddShapesToGraphDef``` from relay for the same.
Please refer to the [tensorflow tutorial](https://github.com/dmlc/tvm/blob/master/tutorials/nnvm/from_tensorflow.py).
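For example, a minimal freezing sketch (illustrative only; `sess` and `out_node` are placeholders, not part of this repo):

```python
# Freeze the session graph with output shapes embedded in every node.
from tensorflow.python.framework import graph_util

graph_def = sess.graph.as_graph_def(add_shapes=True)
frozen_graph_def = graph_util.convert_variables_to_constants(
    sess, graph_def, [out_node])
```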
### Explicit Shape:
@@ -21,7 +21,7 @@ from tensorflow.python.ops import variables
from tensorflow.python.ops import init_ops
from tensorflow.core.framework import graph_pb2
-import nnvm.testing.tf
+import tvm.relay.testing.tf as tf_testing
#######################################################################
# Generic run functions for TVM & tensorflow
@@ -784,9 +784,9 @@ def test_forward_pad():
def test_forward_inception_v3():
'''test inception V3 model'''
with tf.Graph().as_default():
-graph_def = nnvm.testing.tf.get_workload('InceptionV3/inception_v3_2016_08_28_frozen-with_shapes.pb')
+graph_def = tf_testing.get_workload('InceptionV3/inception_v3_2016_08_28_frozen-with_shapes.pb')
# Call the utility to import the graph definition into default graph.
-graph_def = nnvm.testing.tf.ProcessGraphDefParam(graph_def)
+graph_def = tf_testing.ProcessGraphDefParam(graph_def)
data = np.random.uniform(size=(1, 299, 299, 3)).astype('float32')
@@ -801,9 +801,9 @@ def test_forward_inception_v3():
def test_forward_inception_v1():
'''test inception V1 model'''
with tf.Graph().as_default():
-graph_def = nnvm.testing.tf.get_workload("InceptionV1/classify_image_graph_def-with_shapes.pb")
+graph_def = tf_testing.get_workload("InceptionV1/classify_image_graph_def-with_shapes.pb")
# Call the utility to import the graph definition into default graph.
-graph_def = nnvm.testing.tf.ProcessGraphDefParam(graph_def)
+graph_def = tf_testing.ProcessGraphDefParam(graph_def)
# Build an image from random data.
from PIL import Image
@@ -838,18 +838,18 @@ def test_forward_mobilenet():
'''test mobilenet model'''
# MobilenetV2
with tf.Graph().as_default():
-graph_def = nnvm.testing.tf.get_workload(
+graph_def = tf_testing.get_workload(
"https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_1.4_224.tgz",
"mobilenet_v2_1.4_224_frozen.pb")
# Call the utility to import the graph definition into default graph.
-graph_def = nnvm.testing.tf.ProcessGraphDefParam(graph_def)
+graph_def = tf_testing.ProcessGraphDefParam(graph_def)
data = np.random.uniform(size=(1, 224, 224, 3)).astype('float32')
out_node = 'MobilenetV2/Predictions/Reshape_1'
with tf.Session() as sess:
# Add shapes to the graph.
-graph_def = nnvm.testing.tf.AddShapesToGraphDef(sess, out_node)
+graph_def = tf_testing.AddShapesToGraphDef(sess, out_node)
tf_output = run_tf_graph(sess, data, 'input:0', out_node + ':0')
tvm_output = run_tvm_graph(graph_def, data, 'input')
tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tf_output[0]), rtol=1e-5, atol=1e-5)
@@ -861,9 +861,9 @@ def test_forward_resnetv2():
'''test resnet model'''
if is_gpu_available():
with tf.Graph().as_default():
-graph_def = nnvm.testing.tf.get_workload("ResnetV2/resnet-20180601_resnet_v2_imagenet-shapes.pb")
+graph_def = tf_testing.get_workload("ResnetV2/resnet-20180601_resnet_v2_imagenet-shapes.pb")
# Call the utility to import the graph definition into default graph.
-graph_def = nnvm.testing.tf.ProcessGraphDefParam(graph_def)
+graph_def = tf_testing.ProcessGraphDefParam(graph_def)
data = np.random.uniform(size=(128, 224, 224, 3)).astype('float32')
out_node = 'ArgMax'
@@ -879,7 +879,7 @@ def test_forward_resnetv2():
dir(tf.contrib)
def test_forward_ptb():
'''test ptb model'''
-config = nnvm.testing.tf.get_config()
+config = tf_testing.get_config()
num_steps = config.num_steps
num_hidden = config.hidden_size
num_layers = config.num_layers
@@ -936,7 +936,7 @@ def test_forward_ptb():
"float32")).asnumpy()
state_output = model.get_output(1, tvm.nd.empty(out_state_shape,
"float32")).asnumpy()
-sample = nnvm.testing.tf.pick_from_weight(tvm_output[0])
+sample = tf_testing.pick_from_weight(tvm_output[0])
return sample, state_output
@@ -956,10 +956,10 @@ def test_forward_ptb():
return samples, state
with tf.Graph().as_default():
-word_to_id, id_to_word, graph_def = nnvm.testing.tf.get_workload_ptb()
+word_to_id, id_to_word, graph_def = tf_testing.get_workload_ptb()
vocab_size = len(word_to_id)
# Call the utility to import the graph definition into default graph.
-graph_def = nnvm.testing.tf.ProcessGraphDefParam(graph_def)
+graph_def = tf_testing.ProcessGraphDefParam(graph_def)
sess = tf.Session()
#TVM graph module creation
@@ -975,7 +975,7 @@ def test_forward_ptb():
for word in seed_for_sample],
in_state, params, cnt_sample)
tvm_sample_str = _pretty_print(tvm_samples, False, id_to_word)
-tf_samples, tf_state = nnvm.testing.tf.do_tf_sample(sess,
+tf_samples, tf_state = tf_testing.do_tf_sample(sess,
[word_to_id[word] for word in seed_for_sample],
in_state, cnt_sample)
tf_sample_str = _pretty_print(tf_samples, False, id_to_word)
@@ -13,3 +13,4 @@ from .onnx import from_onnx
from .tflite import from_tflite
from .coreml import from_coreml
from .caffe2 import from_caffe2
+from .tensorflow import from_tensorflow
# pylint: disable=import-self, invalid-name, unused-argument, too-many-lines, len-as-condition
"""TF: Tensorflow frontend."""
from __future__ import absolute_import as _abs
from __future__ import print_function
import logging
# Numpy support
import numpy as np
import tvm
from topi.util import get_const_tuple
from .. import ir_pass
from .. import expr as _expr
from .. import op as _op
__all__ = ['from_tensorflow']
def _get_relay_op(op_name):
try:
op = getattr(_op, op_name)
except AttributeError:
try:
op = getattr(_op.nn, op_name)
except AttributeError:
op = getattr(_op.image, op_name)
if not op:
raise RuntimeError("Unable to map op_name {} to relay".format(op_name))
return op
class AttrCvt(object):
"""Common attribute conveter. An AttrConverter instance is a callable:
```
attr_converter = AttrCvt(op_name, transforms={'a':'b', 'c':('d', 1)})
new_op_name, new_attr = attr_converter(attrs)
```
Parameters
----------
op_name : str or callable
If set as str, returned operator name is the str.
If set as callable, returned operator is the str returned by calling:
`op_name = func(attr)`
transforms : dict of `new_name, or (new_name, default_value, transform function)`
If only a new_name is provided, it's like renaming the attribute.
If a default_value is provided, the attribute is considered optional.
If a transform function is provided, the original attribute value is handled
by the transform function.
excludes : list
A list of excluded attributes that should `NOT` appear.
Raises NotImplementedError if one occurs.
disables : list
A list of attributes that are disabled in relay. Logs warnings.
ignores : list
A list of attributes that are ignored in relay. Debug level logging.
extras : dict
A series of additional attributes that should be added anyway to the returned
attribute dict.
custom_check : callable
A custom function that takes the attributes and returns True/False.
Raises RuntimeError if it does not return True.
"""
def __init__(self, op_name, transforms=None,
excludes=None, disables=None, ignores=None,
extras=None, custom_check=None):
self._op_name = op_name
self._transforms = transforms if transforms else {}
self._excludes = excludes if excludes else []
self._disables = disables if disables else []
self._ignores = ignores if ignores else []
self._extras = extras if extras else {}
self._custom_check = custom_check
def __call__(self, inputs, attrs, *args):
self._ignores.append('_output_shapes')
self._ignores.append('_input_shapes')
self._ignores.append('T')
self._ignores.append('use_cudnn_on_gpu')
self._ignores.append('_node_name')
self._ignores.append('is_training')
self._ignores.append('_target_layout')
# apply custom check
if self._custom_check:
func, msg = self._custom_check
if not func(attrs):
raise RuntimeError("Check failed: {}".format(msg))
# get new op_name
if isinstance(self._op_name, str):
op_name = self._op_name
else:
assert callable(self._op_name), "op_name can either be string or callable"
op_name = self._op_name(attrs)
# convert attributes
new_attrs = {}
for k in attrs.keys():
if k in self._excludes:
raise NotImplementedError("Attribute {} not supported yet.".format(k))
elif k in self._disables:
logging.warning("Attribute %s is disabled in relay.%s", k, op_name)
elif k in self._ignores:
logging.debug("Attribute %s is ignored in relay.%s", k, op_name)
elif k in self._transforms:
new_name, defaults, transform = self._parse_default(self._transforms[k])
if defaults is None:
new_attr = self._required_attr(attrs, k)
else:
new_attr = attrs.get(k, None)
if new_attr is None:
new_attrs[new_name] = defaults
else:
new_attrs[new_name] = transform(new_attr)
else:
# copy
new_attrs[k] = attrs[k]
# add extras
new_attrs.update(self._extras)
return _get_relay_op(op_name)(*inputs, **new_attrs)
def _parse_default(self, target):
"""Helper function to parse default values."""
if not isinstance(target, (list, tuple)):
k, v, t = target, None, lambda x: x
elif len(target) == 1:
k, v, t = target[0], None, lambda x: x
elif len(target) == 2:
k, v, t = target[0], target[1], lambda x: x
elif len(target) > 2:
k, v, t = target[0], target[1], target[2]
else:
k = None # should raise
if not isinstance(k, str):
msg = "{} is not a valid target, (name, default) expected.".format(target)
raise ValueError(msg)
return k, v, t
def _parse_bool(self, value):
"""Helper function to parse default boolean values."""
if isinstance(value, str):
return value.strip().lower() in ['true', '1', 't', 'y', 'yes']
return bool(value)
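# For example (illustrative): _parse_bool(' Yes ') -> True, while
# _parse_bool(0) falls through to bool(0) -> False.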
def _required_attr(self, attr, key):
"""Wrapper for getting required attributes."""
assert isinstance(attr, dict)
if key not in attr:
raise AttributeError("Required attribute {} not found.".format(key))
return attr[key]
def _get_pad_pair(input1d, kernel1d, stride1d):
if input1d % stride1d == 0:
pad = max(kernel1d - stride1d, 0)
else:
pad = max(kernel1d - (input1d % stride1d), 0)
pad_before = pad // 2
pad_after = pad - pad_before
return [pad_before, pad_after]
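# Worked example (illustrative, not from the original source): for SAME
# padding with input1d=10, kernel1d=3, stride1d=2, we have 10 % 2 == 0,
# so pad = max(3 - 2, 0) = 1 and the pair is [0, 1]; as in TensorFlow,
# any odd padding goes after the data.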
def _get_name_hint(node):
name = ''
if hasattr(node, "name_hint"):
name = node.name_hint
return name
def _math_name_picker(suffix):
def _impl(attr):
return 'broadcast_' + suffix
return _impl
def _dimension_picker(prefix, suffix=''):
def _impl(attr):
kernel = attr['kernel_shape']
if len(kernel) == 2:
return prefix + '2d' + suffix
else:
raise NotImplementedError("Only 2d kernel supported.")
return _impl
def _dimension_constraint():
def _dim_check(attrs):
if len(attrs['kernel_shape']) == 2:
return True
return False
return _dim_check, "Only 2d kernel supported."
def _infer_channels(inputs, params, transpose=False):
"""A hack for getting 'channles' or 'units' since tensorflow don't provide
these attributes. We check the shape of weights provided to get the number.
"""
out_type = ir_pass.infer_type(inputs)
out_shapes = [get_const_tuple(out_type.checked_type.shape)]
channels = out_shapes[0][0] if not transpose else out_shapes[0][1]
return channels
def _rsqrt():
def _impl(inputs, attr, *args):
inputs.append(tvm.relay.const(-0.5, attr['T'].name))
return AttrCvt(op_name="power")(inputs, attr)
return _impl
def _argx(func, func_name):
""" A common wrapper for argmin and argmax operations """
def _impl(inputs, attr, params):
try:
# In Tensorflow, the `axis` argument is a Tensor, not an attribute. We
# support the case where it comes from a scalar constant.
axis_input_name = inputs[1].name_hint
axis_input_value = [params[axis_input_name].asnumpy()[0]]
except (IndexError, KeyError):
raise TypeError( \
"Unsupported argument for `{}` : `axis` should be a constant".format(func_name))
return func(inputs[0], axis=axis_input_value, keepdims=False)
return _impl
def _elemwise(name):
def _impl(inputs, attr, *args):
assert len(inputs) == 2, "Math op takes 2 inputs, {} given".format(len(inputs))
return _get_relay_op(name)(*inputs)
return _impl
def _pooling(name):
def _impl(inputs, attr, params):
attr['data_format'] = attr['data_format'].decode("utf-8")
flip_layout = False
input_shape = attr['_input_shapes'][inputs[0]][0]
if attr['data_format'] == 'NHWC':
attr['kernel_shape'] = (attr['ksize'][1], attr['ksize'][2])
attr['strides'] = (attr['strides'][1], attr['strides'][2])
elif attr['data_format'] == 'NCHW':
attr['kernel_shape'] = (attr['ksize'][2], attr['ksize'][3])
attr['strides'] = (attr['strides'][2], attr['strides'][3])
else:
raise TypeError("Unsupported data_format type : {}".format(attr['data_format']))
if attr['_target_layout'] == "NCHW" and attr['data_format'] == "NHWC":
tmp_shape = attr['_input_shapes'][inputs[0]][0]
input_shape = [tmp_shape[ii] for ii in (0, 3, 1, 2)]
inputs[0] = _op.transpose(inputs[0], axes=(0, 3, 1, 2))
attr['data_format'] = "NCHW"
flip_layout = True
# Fix padding
attr['padding'] = attr['padding'].decode("utf-8")
if attr['padding'] == 'VALID':
attr['padding'] = [0, 0]
elif attr['padding'] == 'SAME':
stride_h, stride_w = attr['strides']
kernel_h, kernel_w = attr['kernel_shape']
if attr['data_format'] == 'NHWC':
in_h = input_shape[1]
in_w = input_shape[2]
else:
in_h = input_shape[2]
in_w = input_shape[3]
pad_v = _get_pad_pair(in_h, kernel_h, stride_h)
pad_h = _get_pad_pair(in_w, kernel_w, stride_w)
attr['padding'] = [pad_v[0], pad_h[0], pad_v[1], pad_h[1]]
else:
raise TypeError("Unsupported padding type : {}".format(attr['padding']))
if name == "avg_pool":
attr['count_include_pad'] = False
out = AttrCvt(
op_name=_dimension_picker(name),
transforms={
'kernel_shape':'pool_size',
'data_format':'layout'},
ignores=['ksize'],
extras={'ceil_mode': False},
custom_check=_dimension_constraint())(inputs, attr)
if flip_layout:
out = _op.transpose(out, axes=(0, 2, 3, 1))
return out
return _impl
def _conv(opname):
def _impl(inputs, attr, params):
attr['data_format'] = attr['data_format'].decode("utf-8")
flip_layout = False
# NCHW Layout require weights transpose
if attr['data_format'] == 'NCHW':
tmp_shape = attr['_input_shapes'][inputs[1]][0]
tmp_shape = [tmp_shape[ii] for ii in (3, 2, 0, 1)]
inputs[1] = _op.transpose(inputs[1], axes=(3, 2, 0, 1))
attr['_input_shapes'][inputs[1]] = [tmp_shape]
input_shape = attr['_input_shapes'][inputs[0]][0]
weights_shape = attr['_input_shapes'][inputs[1]][0]
if attr['_target_layout'] == "NCHW" and attr['data_format'] == "NHWC":
input_shape = [input_shape[ii] for ii in (0, 3, 1, 2)]
inputs[0] = _op.transpose(inputs[0], axes=(0, 3, 1, 2))
if opname == 'conv':
weights_shape = [weights_shape[ii] for ii in (3, 2, 0, 1)]
inputs[1] = _op.transpose(inputs[1], axes=(3, 2, 0, 1))
else:
weights_shape = [weights_shape[ii] for ii in (2, 3, 0, 1)]
inputs[1] = _op.transpose(inputs[1], axes=(2, 3, 0, 1))
attr['data_format'] = "NCHW"
attr['strides'] = [attr['strides'][ii] for ii in (0, 3, 1, 2)]
flip_layout = True
if attr['data_format'] == 'NHWC':
kernel_h, kernel_w, _, depth_mult = weights_shape
attr['kernel_shape'] = (weights_shape[0], weights_shape[1])
if opname == 'conv':
attr['channels'] = weights_shape[3]
else:
attr['channels'] = input_shape[3] * depth_mult
if 'dilations' in attr:
attr['dilations'] = (attr['dilations'][0], attr['dilations'][1])
attr['strides'] = (attr['strides'][1], attr['strides'][2])
elif attr['data_format'] == 'NCHW':
depth_mult, _, kernel_h, kernel_w = weights_shape
attr['kernel_shape'] = (weights_shape[2], weights_shape[3])
if opname == 'conv':
attr['channels'] = weights_shape[0]
else:
attr['channels'] = input_shape[0] * depth_mult
if attr['channels'] < 0:
attr['channels'] *= -1
if 'dilations' in attr:
attr['dilations'] = (attr['dilations'][2], attr['dilations'][3])
attr['strides'] = (attr['strides'][2], attr['strides'][3])
else:
raise TypeError("Unsupported data format type : {}".format(attr['data_format']))
if opname == 'depthwise':
attr['groups'] = attr['channels']
# Fix padding
attr['padding'] = attr['padding'].decode("utf-8")
if attr['padding'] == 'VALID':
attr['padding'] = [0, 0]
elif attr['padding'] == 'SAME':
stride_h, stride_w = attr['strides']
kernel_h, kernel_w = attr['kernel_shape']
if attr['data_format'] == 'NHWC':
in_h = input_shape[1]
in_w = input_shape[2]
else:
in_h = input_shape[2]
in_w = input_shape[3]
pad_v = _get_pad_pair(in_h, kernel_h, stride_h)
pad_h = _get_pad_pair(in_w, kernel_w, stride_w)
if attr['data_format'] == 'NHWC':
inputs[0] = _op.nn.pad(data=inputs[0],
pad_width=((0, 0),
(pad_v[0], pad_v[1]),
(pad_h[0], pad_h[1]),
(0, 0)))
else:
inputs[0] = _op.nn.pad(data=inputs[0],
pad_width=((0, 0),
(0, 0),
(pad_v[0], pad_v[1]),
(pad_h[0], pad_h[1])))
attr['padding'] = [0, 0]
else:
raise TypeError("Unsupported padding type : {}".format(attr['padding']))
if 'kernel_layout' not in attr:
if opname == 'conv':
attr['kernel_layout'] = 'HWIO' if attr['data_format'] == 'NHWC' else 'OIHW'
else:
attr['kernel_layout'] = 'HWOI' if attr['data_format'] == 'NHWC' else 'OIHW'
use_bias = len(inputs) == 3
channel_axis = 1 if attr['data_format'] == "NCHW" else 3
out = AttrCvt(
op_name=_dimension_picker('conv'),
transforms={
'kernel_shape': 'kernel_size',
'data_format': 'data_layout',
'dilations': ('dilation', (0, 0)),
'group': ('groups', 1)},
custom_check=_dimension_constraint())([inputs[0], inputs[1]], attr)
if use_bias:
out = _op.nn.bias_add(out, inputs[2], axis=channel_axis)
if flip_layout:
out = _op.transpose(out, axes=(0, 2, 3, 1))
return out
return _impl
def _decode_image():
def _impl(inputs, attr, params):
# Image decode wrapper: expecting the user to feed decoded input to the next layer, so drop this layer.
print("DecodeJpeg: It's a pass through, please handle preprocessing before input")
return inputs[0]
return _impl
def _cast():
def _impl(inputs, attr, params):
return inputs[0].astype(attr['DstT'].name)
return _impl
def _expand_dims():
def _impl(inputs, attr, params):
dim_input = inputs.pop(1)
axis = params[dim_input.name_hint]
params.pop(dim_input.name_hint)
return AttrCvt(op_name="expand_dims", ignores=['Tdim'],
extras={'axis': int(axis.asnumpy()[0])})(inputs, attr)
return _impl
def _resize_bilinear():
def _impl(inputs, attr, params):
attr['size'] = attr['_output_shapes'][0][1:3]
inputs.pop(1)
# NHWC
attr['layout'] = 'NHWC'
return AttrCvt(op_name="resize",
ignores=['Tdim'],
extras={'method': "BILINEAR"})(inputs, attr)
return _impl
def _check_numerics():
def _impl(inputs, attr, params):
# Make a copy node, assuming there is no need to verify
return AttrCvt(op_name="copy", ignores=['message'])(inputs, attr)
return _impl
def _matmul():
def _impl(inputs, attr, params):
channels = _infer_channels(inputs[1], params, not attr['transpose_b'])
if attr['transpose_a']:
inputs[0] = _op.transpose(inputs[0], axes=(1, 0))
if not attr['transpose_b']:
inputs[1] = _op.transpose(inputs[1], axes=(1, 0))
return AttrCvt(op_name="dense",
extras={'units': channels},
ignores=['transpose_a', 'transpose_b', 'T'])(inputs, attr)
return _impl
def _identity():
def _impl(inputs, attr, params):
return inputs[0]
return _impl
def _concatV2():
def _impl(inputs, attr, params):
pop_node = inputs.pop(len(inputs)-1)
axis = params[pop_node.name_hint]
params.pop(pop_node.name_hint)
return AttrCvt(
op_name="concatenate", ignores=['T', 'N', 'Tidx'],
extras={'axis': int(axis.asnumpy()[0])})([inputs], attr)
return _impl
def _concat():
def _impl(inputs, attr, params):
pop_node = inputs.pop(0)
axis = params[pop_node.name_hint]
params.pop(pop_node.name_hint)
return AttrCvt(
op_name="concatenate", ignores=['N'],
extras={'axis': int(axis.asnumpy()[0])})([inputs], attr)
return _impl
def _pack():
def _impl(inputs, attr, params):
axis = int(attr["axis"])
inputs_reshaped = [_op.expand_dims(i, axis=axis, num_newaxis=1) for i in inputs]
return _op.concatenate(inputs_reshaped, axis)
return _impl
def _reshape():
def _impl(inputs, attr, params):
try:
pop_node = inputs[1]
shape_arg = params.pop(pop_node.name_hint)
inputs.pop(1)
return AttrCvt(
op_name="reshape",
extras={'newshape':tuple(shape_arg.asnumpy())},
ignores=['Tshape'])(inputs, attr)
except KeyError:
# The Shape operator is already pruned; hence
# try to infer the shape by precompute prune if possible.
if all(in_node in params for in_node in inputs[1].list_input_names()):
func = _expr.Function(ir_pass.free_vars(inputs[1]), inputs[1])
with tvm.relay.build_config(opt_level=0):
graph, lib, params = tvm.relay.build(func, target="llvm", params=params)
ctx = tvm.context("llvm", 0)
from tvm.contrib import graph_runtime
m = graph_runtime.create(graph, lib, ctx)
m.set_input(**params)
m.run()
params_new = m.get_output(0)
inputs.pop(1)
return AttrCvt(
op_name="reshape",
extras={'newshape':tuple(params_new.asnumpy().flatten())},
ignores=['Tshape'])(inputs, attr)
else:
raise RuntimeError("Reshape with dynamic shape input not supported yet.")
return _impl
def _bias_add():
def _impl(inputs, attr, params):
return _op.add(inputs[0], inputs[1])
return _impl
def _squeeze():
def _impl(inputs, attr, params):
if len(attr['squeeze_dims']) == 0:
attr['squeeze_dims'] = None
return AttrCvt(
op_name="squeeze",
transforms={'squeeze_dims':'axis'},
ignores=['T'])(inputs, attr)
return _impl
def _fused_batch_norm():
def _impl(inputs, attr, params):
# Tensorflow: (data, gamma, beta, moving_mean, moving_variance)
# Relay: (data, gamma, beta, moving_mean, moving_variance)
axis = 3
need_cast = False
if 'data_format' in attr:
attr['data_format'] = attr['data_format'].decode("utf-8")
if attr['data_format'] == 'NCHW':
axis = 1
if 'U' in attr:
need_cast = True
inputs[0] = _op.cast(inputs[0], dtype=attr['U'].name)
out = AttrCvt(op_name='batch_norm',
transforms={'scale_after_normalization':'scale',
'variance_epsilon':'epsilon'},
extras={'axis': axis},
ignores=['data_format', 'U'],
disables=['momentum'])(inputs, attr)
if need_cast:
out = _op.cast(out, dtype=attr['T'].name)
return out
return _impl
def _batch_norm():
def _impl(inputs, attr, params):
# Rearrange inputs from
# (data, moving_mean, moving_variance, beta, gamma)
# to
# (data, gamma, beta, moving_mean, moving_var)
new_inputs = [inputs[0], inputs[4], inputs[3], inputs[1], inputs[2]]
axis = 3
if 'data_format' in attr:
attr['data_format'] = attr['data_format'].decode("utf-8")
if attr['data_format'] == 'NCHW':
axis = 1
return AttrCvt(
op_name='batch_norm',
transforms={'scale_after_normalization':'scale', 'variance_epsilon':'epsilon'},
extras={'axis': axis},
ignores=['data_format'],
disables=['momentum'])(new_inputs, attr)
return _impl
def _relu6():
def _impl(inputs, attr, params):
return _op.clip(inputs[0], a_min=0, a_max=6)
return _impl
def _shape():
def _impl(inputs, attr, params):
return np.array(attr['_input_shapes'][inputs[0]][0], dtype='int32')
return _impl
def _fill():
def _impl(inputs, attr, params):
fill_arg = params.pop(inputs.pop(1).name_hint)
return _op.full(tvm.relay.const(fill_arg.asnumpy()[0], attr['T'].name),
attr['_output_shapes'][0], attr['T'].name)
return _impl
def _lrn():
def _impl(inputs, attr, params):
attr_new = {}
depth_radius = attr.get('depth_radius', 5)
size = (depth_radius * 2) + 1
attr_new['axis'] = 3 # Fix axis, NHWC format
attr_new['size'] = size
attr_new['bias'] = attr.get('bias', 1)
attr_new['alpha'] = attr.get('alpha', 1) * size
attr_new['beta'] = attr.get('beta', 0.5)
return AttrCvt(op_name='lrn')(inputs, attr_new)
return _impl
def _sum():
def _impl(inputs, attr, params):
axis = params.pop(inputs[1].name_hint).asnumpy()
# convert to tuple to prevent an invalid parameter format error
axis = tuple(axis)
return AttrCvt(
op_name='sum',
extras={'axis': axis},
transforms={'keep_dims':'keepdims'},
ignores=['name', 'Tidx'])([inputs[0]], attr)
return _impl
def _square():
def _impl(inputs, attr, params):
return _op.multiply(inputs[0], inputs[0])
return _impl
def _gather_v2():
"Tensorflow now support only gatherv2"
def _impl(inputs, attr, params):
axis = params[inputs.pop(2).name_hint].asnumpy()[0]
new_input = []
new_input.append(inputs.pop(0))
new_input.append(inputs.pop(0))
return AttrCvt(op_name="take",
extras={'axis': tvm.const(axis, 'int32')},
ignores=['Tindices', 'Tparams', 'validate_indices', \
'Taxis', '_class'])(new_input, attr)
return _impl
def _infer_out_shapes(inputs, params):
"""A method to get the output shape of an intermediate node in the relay graph."""
out_type = ir_pass.infer_type(inputs)
out_shapes = [get_const_tuple(out_type.checked_type.shape)]
return out_shapes
def _stridedSlice():
def _impl(inputs, attr, params):
"""Strided Slice.
Operator description: https://www.tensorflow.org/api_docs/python/tf/strided_slice
Tensorflow mask validation: https://github.com/tensorflow/tensorflow/blob/master/
tensorflow/core/util/strided_slice_op.cc#L147-L368
"""
begin = params.pop(inputs[1].name_hint).asnumpy().tolist()
end = params.pop(inputs[2].name_hint).asnumpy().tolist()
stride = params.pop(inputs[3].name_hint).asnumpy().tolist()
begin_mask = int(attr.get('begin_mask', 0))
end_mask = int(attr.get('end_mask', 0))
ellipsis_mask = int(attr.get('ellipsis_mask', 0))
new_axis_mask = int(attr.get('new_axis_mask', 0))
shrink_axis_mask = int(attr.get('shrink_axis_mask', 0))
data_shape = attr['_input_shapes'][inputs[0]]
data_dim = len(data_shape[0])
stride_dim = len(stride)
def _transform_mask(stride_dim, ellipsis_mask):
"""Handle mask inputs to create new begin, end, stride and output shape"""
m_begin = [0] * data_dim
m_end = [0] * data_dim
m_stride = [0] * data_dim
fshape_indices = []
#Count new axes after ellipsis_mask; consider them while applying ellipsis_mask.
ellipsis_seen = False
new_axes_after_ellipsis = 0
for i in range(stride_dim):
mask = 1 << i
if ellipsis_seen and (mask & new_axis_mask) != 0:
new_axes_after_ellipsis += 1
if (mask & ellipsis_mask) != 0:
ellipsis_seen = True
if not ellipsis_seen:
#Used later for extending the stride attributes in the below loop.
ellipsis_mask |= (1 << stride_dim)
stride_dim += 1
final_index = 0
for index in range(stride_dim):
mask = 1 << index
if mask & ellipsis_mask:
#Identify the end index for applying ellipsis_mask
to_index = min(((data_dim - (stride_dim-index)) + 1 \
+ new_axes_after_ellipsis), data_dim)
for i in range(final_index, to_index):
m_begin[final_index] = 0
m_end[final_index] = data_shape[0][final_index]
m_stride[final_index] = 1
fshape_indices.append(final_index)
final_index += 1
elif mask &new_axis_mask:
fshape_indices.append(-1)
elif not mask & new_axis_mask:
if final_index == len(m_begin):
break
if mask & begin_mask:
m_begin[final_index] = data_shape[0][final_index] \
if stride[index] < 0 else 0
elif begin[index]:
m_begin[final_index] = begin[index]
if mask & end_mask:
m_end[final_index] = 0 if stride[index] < 0 \
else data_shape[0][final_index]
elif end[index]:
m_end[final_index] = end[index]
m_stride[final_index] = stride[index]
if mask & shrink_axis_mask:
#Tensorflow makes the axis with shrink_axis_mask dimension 1
m_begin[final_index] = data_shape[0][final_index] + begin[index] \
if begin[index] < 0 else begin[index]
m_end[final_index] = begin[index] + 1
m_stride[final_index] = 1
fshape_indices.append(-2)
else:
fshape_indices.append(final_index)
final_index += 1
return m_begin, m_end, m_stride, fshape_indices
fshape_indices = None
if begin_mask or end_mask or ellipsis_mask or new_axis_mask or shrink_axis_mask:
begin, end, stride, fshape_indices = _transform_mask(stride_dim, ellipsis_mask)
out = _op.strided_slice(inputs[0], begin=begin, end=end, strides=stride)
out_shape = _infer_out_shapes(out, params)[0]
if not fshape_indices:
fshape_indices = range(len(out_shape))
#Create final output shape.
final_output = []
for gather_index in fshape_indices:
if gather_index == -1:
final_output.append(1)
elif gather_index == -2:
pass
else:
final_output.append(out_shape[gather_index])
return _op.reshape(out, newshape=tuple(final_output))
return _impl
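# Worked example (illustrative): slicing x[1] on a (3, 4) tensor arrives as
# begin=[1], end=[2], stride=[1], shrink_axis_mask=1; _transform_mask extends
# the masks over the remaining axis, strided_slice produces shape (1, 4), and
# the final reshape drops the shrunk axis, giving shape (4,).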
def _pad(name):
def _impl(inputs, attr, params):
padlist_key = inputs[1].name_hint
if padlist_key in params:
padlist = params.pop(padlist_key).asnumpy()
else:
raise RuntimeError("Required parameter {} not fount.".format(padlist_key))
paddings = tuple([tuple(l) for l in padlist])
attr['pad_width'] = paddings
attr['pad_value'] = 0
new_inputs = [inputs[0]]
if name == 'PadV2':
constant_values = params.pop(inputs[2].name_hint).asnumpy()
attr['pad_value'] = constant_values[0]
return AttrCvt(
op_name='pad',
ignores=['Tpaddings'],)(new_inputs, attr)
return _impl
def _transpose():
def _impl(inputs, attr, params):
# If perm is not specified, axes is left empty;
# otherwise its value is taken from params
param_name = _get_name_hint(inputs[1])
if param_name in params:
axes = tuple(params.get(param_name).asnumpy())
else:
axes = None
return _op.transpose(inputs[0], axes=axes)
return _impl
def _rank():
def _impl(inputs, attr, params):
input_shapes = attr['_input_shapes'][inputs[0]]
assert len(inputs) == 1
name = attr["_node_name"]
params[name] = tvm.nd.array([len(input_shapes[0])])
return [_expr.var(name,
shape=params[name].shape,
dtype='int32')]
return _impl
def _range():
def _impl(inputs, attr, params):
start = params.pop(inputs[0].name_hint).asnumpy()[0]
limit = params.pop(inputs[1].name_hint).asnumpy()[0]
delta = params.pop(inputs[2].name_hint).asnumpy()[0]
name = attr["_node_name"]
params[name] = tvm.nd.array([start, limit, delta])
return [_expr.var(name,
shape=params[name].shape,
dtype='int32')]
return _impl
def _elu():
def _impl(inputs, attr, params):
alpha = tvm.relay.const(-1.0, attr['T'].name)
return alpha * _op.nn.relu(tvm.relay.const(1, attr['T'].name) \
- _op.exp(inputs[0])) + _op.nn.relu(inputs[0])
return _impl
def _selu():
def _impl(inputs, attr, params):
alpha = tvm.relay.const(-1.6732632423543772848170429916717, attr['T'].name)
gamma = tvm.relay.const(1.0507009873554804934193349852946, attr['T'].name)
return gamma * (alpha * _op.nn.relu(tvm.relay.const(1, attr['T'].name) \
- _op.exp(inputs[0])) + _op.nn.relu(inputs[0]))
return _impl
def _mean():
def _impl(inputs, attr, params):
axis = params.pop(inputs[1].name_hint)
return AttrCvt(op_name="mean", ignores=['Tdim', 'Tidx'],
transforms={'keep_dims': 'keepdims'},
extras={'axis': tuple(axis.asnumpy())})([inputs[0]], attr)
return _impl
def _broadcast(name):
def _impl(inputs, attr, params):
return AttrCvt(
op_name=name,
ignores=['name', 'Tidx']
)(inputs, attr)
return _impl
def _softmax():
def _impl(inputs, attr, params):
return AttrCvt(op_name='softmax',
transforms={'axis': ('axis', 1)})([inputs[0]], attr)
return _impl
# Compatible operators that do NOT require any conversion.
_identity_list = []
# _convert_map defines maps of name to converter functor (callable).
# For 1-to-1 mapping, use AttrCvt with just the new op name if nothing
# but the name is different; pass transforms if attributes need converting.
# For 1-to-N (composed) mapping, use custom callable functions.
# N-to-1 mapping is currently not supported.
_convert_map = {
'ArgMax' : _argx(_op.argmax, 'argmax'),
'ArgMin' : _argx(_op.argmin, 'argmin'),
'AvgPool' : _pooling('avg_pool'),
'BatchNormWithGlobalNormalization' : _batch_norm(),
'BiasAdd' : _bias_add(),
'Cast' : _cast(),
'Ceil' : AttrCvt('ceil'),
'CheckNumerics' : _check_numerics(),
'Concat' : _concat(),
'ConcatV2' : _concatV2(),
'Conv2D' : _conv('conv'),
'DecodeJpeg' : _decode_image(),
'Elu' : _elu(),
'ExpandDims' : _expand_dims(),
'Floor' : AttrCvt('floor'),
'Identity' : _identity(),
'MatMul' : _matmul(),
'MaxPool' : _pooling('max_pool'),
'Add' : _elemwise('add'),
'Sub' : _elemwise('subtract'),
'Mul' : _elemwise('multiply'),
'Maximum' : _elemwise('maximum'),
'Minimum' : _elemwise('minimum'),
'Sum' : _sum(),
'Square' : _square(),
'Pack' : _pack(),
'LeakyRelu' : AttrCvt('leaky_relu'),
'Relu' : AttrCvt('relu'),
'Reshape' : _reshape(),
'ResizeBilinear' : _resize_bilinear(),
'Selu' : _selu(),
'Softmax' : _softmax(),
'Rsqrt' : _rsqrt(),
'Squeeze' : _squeeze(),
'FusedBatchNorm' : _fused_batch_norm(),
'FusedBatchNormV2' : _fused_batch_norm(),
'Relu6' : _relu6(),
'DepthwiseConv2dNative' : _conv('depthwise'),
'Shape' : _shape(),
'Sigmoid' : AttrCvt('sigmoid'),
'Fill' : _fill(),
'GatherV2' : _gather_v2(),
'StridedSlice' : _stridedSlice(),
'LRN' : _lrn(),
'Pad' : _pad('Pad'),
'PadV2' : _pad('PadV2'),
'Range' : _range(),
'Rank' : _rank(),
'Transpose' : _transpose(),
'Tanh' : AttrCvt('tanh'),
'Mean' : _mean(),
'Less' : _broadcast('less'),
'Greater' : _broadcast('greater'),
'LessEqual' : _broadcast('less_equal'),
'GreaterEqual' : _broadcast('greater_equal'),
'Equal' : _broadcast('equal'),
'NotEqual' : _broadcast('not_equal'),
}
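# Illustration (hypothetical mapping, not part of this commit): a 1-to-1 op
# whose relay name differs only in capitalization could be registered as
#     'Exp' : AttrCvt('exp'),
# while ops that need input or attribute rework get a converter functor
# such as _elemwise('add') above.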
def _LSTMBlockCell():
def _impl(inputs, in_state_c, in_state_h, attr, params):
"""LSTM Block cell.
Calculations are described in: https://github.com/tensorflow/tensorflow/blob/
r1.8/tensorflow/contrib/rnn/python/ops/lstm_ops.py#L41-L114
Parameters
----------
inputs : relay.Expr
Input data
in_state_c: list of relay.Expr
Cell state input values for all the layers
in_state_h: list of relay.Expr
Hidden state input values for all the layers
attr : dict
Dict of operator attributes
params : dict
Dict of pretrained weights and biases
Returns
-------
sym : relay.Expr
Converted relay.Expr
output: relay.Expr
Output state value.
"""
in_data = inputs[0]
in_weight = inputs[3]
in_bias = inputs[7]
forget_bias = attr.pop('forget_bias')
input_shape = attr['_input_shapes'][inputs[0]]
weight_shape = attr['_input_shapes'][inputs[3]]
batch_size, input_size = input_shape[0][0], input_shape[0][1]
num_hidden_layers = weight_shape[0][1]
num_hidden = num_hidden_layers // 4
in_data = _op.reshape(in_data,
newshape=(batch_size, input_size))
ixh = _op.concatenate([in_data, in_state_h], axis=1)
in_weight = _op.transpose(in_weight, axes=None)
gates = _op.nn.dense(ixh, in_weight,
units=num_hidden_layers)
gates_bias = _op.add(gates, in_bias)
gate_list = _op.split(gates_bias, indices_or_sections=4, axis=1)
in_gate = _op.sigmoid(gate_list[0])
in_transform = _op.tanh(gate_list[1])
forget_gate = _op.sigmoid(gate_list[2])
forget_gate = _op.add(forget_gate,
tvm.relay.const(forget_bias, attr['T'].name))
out_gate = _op.sigmoid(gate_list[3])
next_c = _op.add(_op.multiply(forget_gate, in_state_c),
_op.multiply(in_gate, in_transform))
next_h = out_gate * _op.tanh(next_c)
out_state = _op.concatenate([next_c, next_h], axis=1)
out_state = _op.reshape(out_state,
newshape=(2, batch_size, num_hidden))
return next_h, out_state
return _impl
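# Gate math implemented by _LSTMBlockCell above (a summary matching the
# code; note the code adds forget_bias after the sigmoid):
#   i, g, f, o = split(dense(concat(x, h_prev), W^T) + b, 4)
#   next_c = (sigmoid(f) + forget_bias) * c_prev + sigmoid(i) * tanh(g)
#   next_h = sigmoid(o) * tanh(next_c)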
# _convert_map_rnn defines maps of rnn operator name to
# converter functor(callable) for 1 to 1 mapping.
_convert_map_rnn = {
'LSTMBlockCell' : _LSTMBlockCell(),
}
class RecurrentNetworks(object):
"""Recurrent network layer handlers.
Handle Layer operations.
TODO: Other layer operators such as RNN/GRU can also be handled here
Parameters
----------
nodes : list
list of graph nodes used for tensorflow parsing.
out_rnn : list
List of RecurrentNetwork outputs. This output will be appended to the
'head' nodes of the graph.
graph : tensorflow graph definition object
The loaded tensorflow GraphDef
convert_map : dict
Dict of name : callable, where name is the op's name that
requires conversion to relay; the callables are functions which
take attrs and return (new_op_name, new_attrs)
"""
def __init__(self, nodes, out_rnn, graph, convert_map):
self._graph = graph
self._convert_map = convert_map
self._nodes = nodes
self._out_rnn = out_rnn
self._cur_lstm_layer = 0
self._layer_name_list = []
self._recurrent_ops_layer_map = {
'LSTMBlockCell' : self._LSTMBlockCellLayer(),
}
def _LSTMBlockCellLayer(self):
"""LSTMBlockCell layer handler.
Parameters
----------
op_name : str
Operator name, e.g. LSTMBlockCell
layer_name : str list
Layer name is used for creating the state input placeholder.
inputs : relay.Expr
Input data
attrs : dict
Dict of operator attributes
params : dict
Dict of pretrained weights and biases
num_layers : int
Total number of LSTM layers present in the graph
Returns
-------
sym : relay.Expr
The returned relay Expr
"""
def _impl(op_name, layer_name, inputs, attrs, params, num_layers):
in_state_c_name = layer_name+'_c'
in_state_h_name = layer_name+'_h'
def _init_state(num_layers, batch_size, num_hidden):
"""Create the initial states for the first layer in the graph."""
in_state_c = [_expr.var(in_state_c_name,
shape=(num_layers, batch_size, num_hidden),
dtype='float32')]
in_state_h = [_expr.var(in_state_h_name,
shape=(num_layers, batch_size, num_hidden),
dtype='float32')]
return in_state_c, in_state_h
def _get_cur_input_state(in_state_c, in_state_h, num_layers,
layer, batch_size, num_hidden):
"""Select the appropriate states for the current layer"""
in_state_c_tup = _op.split(in_state_c[0],
indices_or_sections=num_layers, axis=0)
in_state_h_tup = _op.split(in_state_h[0],
indices_or_sections=num_layers, axis=0)
cur_in_state_c = _op.reshape(in_state_c_tup[layer],
newshape=(batch_size, num_hidden))
cur_in_state_h = _op.reshape(in_state_h_tup[layer],
newshape=(batch_size, num_hidden))
return cur_in_state_c, cur_in_state_h
def _LSTMBlockCellWrapper(inputs, attr, params,
num_layers, layer):
"""LSTM cell warapper to prepare the inputs"""
input_shape = attr['_input_shapes'][inputs[0]]
weight_shape = attr['_input_shapes'][inputs[3]]
batch_size = input_shape[0][0]
num_hidden = weight_shape[0][1] // 4
if layer == 0:
#Create initial state placeholders for the first layer
in_state_c, in_state_h = _init_state(num_layers,
batch_size, num_hidden)
else:
in_state_c = self._nodes[in_state_c_name]
in_state_h = self._nodes[in_state_h_name]
cur_in_state_c, cur_in_state_h = _get_cur_input_state( \
in_state_c, in_state_h,
num_layers, layer,
batch_size, num_hidden)
output, out_state = self._convert_map[op_name](inputs, cur_in_state_c,
cur_in_state_h,
attr, params)
return output, out_state, in_state_c, in_state_h
sym, cur_out_state, in_state_c, in_state_h = \
_LSTMBlockCellWrapper(inputs, attrs, params,
num_layers, self._cur_lstm_layer)
self._nodes[in_state_c_name] = in_state_c
self._nodes[in_state_h_name] = in_state_h
cur_out_state = _op.expand_dims(cur_out_state, axis=0, num_newaxis=1)
self._out_rnn.append(cur_out_state)
self._cur_lstm_layer += 1
return sym
return _impl
def process_op(self, op_name, inputs, attrs, params):
"""Process recurrent layer operators.
List '_recurrent_ops_layer_map' map each Layer based operators with its
layer handlers. Total number of layers are calculated to form the input
data shapes.
Parameters
----------
op_name : str
Operator name, such as LSTMBlockCell
inputs : relay.Expr
Input data
attrs : dict
Dict of operator attributes
params : dict
Dict of pretrained weights and biases
Returns
-------
sym : relay.Expr
Returns relay.Expr
"""
def _get_abs_layer_name(node):
"""Identify the layer name is already handled. Return the absolute name
"""
if not self._layer_name_list:
self._layer_name_list.append(node.name)
return node.name
for _name in self._layer_name_list:
if _name in node.name:
abs_name = _name
else:
self._layer_name_list.append(node.name)
abs_name = node.name
return abs_name
#Find the number of layers of this same operator node in the graph
#and also read the input names for the current op.
num_layers = 0
for _, node in enumerate(self._graph.node):
if node.op == op_name:
layer_name = _get_abs_layer_name(node)
num_layers += 1
sym = self._recurrent_ops_layer_map[op_name](op_name, layer_name, inputs, attrs,
params, num_layers)
return sym
class GraphProto(object):
""" A helper class for handling relay graph copying from Tensorflow GraphDef.
Definition:
https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/graph.proto
"""
def __init__(self):
self._nodes = {}
self._params = {}
self._output_shapes = {}
self._num_param = 0
self._num_rnn_layer = False
def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
"""Construct relay nodes from tensorflow graph definition - GraphDef.
Follow the tensorflow graph definition to parse and convert it to Relay.
Some of the assumptions listed below.
-> All Placeholders are considered as graph input.
-> All Const nodes are params.
-> Last node is assumed as graph output.
-> _output_shapes : Graph should be frozen with add_shapes=True.
Or user can pass input shape dictionaly optionally.
-> DecodeJpeg, ResizeBilinear: These are dummy operators.
Hence user should handle preprocessing outside.
-> CheckNumerics: No implementation as of now for this.
Just copies input to output.
Parameters
----------
graph : tensorflow graph definition object
The loaded tensorflow GraphDef
layout : target layout to be used (Optional)
Only NCHW is supported now, to enable NHWC models on GPU.
shape : Dictionary of input dimensions (Optional)
Graph level input shape dictionary.
Returns
-------
sym : relay.op
The returned relay operator
params : dict
A dict of name: tvm.nd.array pairs, used as pretrained weights
"""
try:
from tensorflow.python.framework import tensor_util
except ImportError as e:
raise ImportError(
"Unable to import tensorflow which is required {}".format(e))
missing_operators = self._parse_import_prerequisites(graph)
if missing_operators:
raise NotImplementedError( \
"The following operators are not implemented: {}".format(missing_operators))
# Parse the nodes to re-create TF graph using Relay operators.
for node in graph.node:
# Tensorflow doesn't have a separate list for params extraction.
# Operator name 'Const' is treated as a parameter to build params dict.
input_shapes = {}
attr = self._parse_attr(node.attr)
#A Variable converted to Const will not have only the 'value' attr; check the op type too.
if 'value' in attr and node.op == 'Const':
tensor_value = attr['value']
self._output_shapes[node.name] = \
[tensor_util.TensorShapeProtoToList( \
tensor_value.tensor_shape)]
elif '_output_shapes' in attr:
self._output_shapes[node.name] = \
[tensor_util.TensorShapeProtoToList(tshape) \
for tshape in attr['_output_shapes']]
elif shape:
# Keep the list indexable to avoid key error.
# Actual value will be filled after node creation.
self._output_shapes[node.name] = [None]
else:
raise NotImplementedError( \
"Please freeze the graph with add_shapes=True")
if node.op == "Placeholder":
self._output_shapes[node.name] = [shape[node.name]]
self._nodes[node.name] = [_expr.var(node.name,
shape=self._output_shapes[node.name][0],
dtype=attr['dtype'].name)]
elif node.op == "Const":
# All Const nodes are Param nodes; let's parse them
self._num_param += 1
for key, value in node.attr.items():
self._parse_param(key, value, node.name, shape)
if node.name not in self._nodes:
raise NotImplementedError( \
"Const {} couldn't be converted to Param.".format(node.name))
attr = self._parse_attr(node.attr)
else:
# Pass the parsed shapes instead
attr["_output_shapes"] = self._output_shapes[node.name]
# Pass the node name too in attr
attr["_node_name"] = node.name
# Pass the target layout
attr["_target_layout"] = layout
#TODO: Some tensorflow operators internally maintain execution
#layers and their output name will be the layer number along with the
#graph node name, e.g. node name: 'Model/RNN/cell_0/RnnCell', but the
#output name will be 'Model/RNN/cell_0/RnnCell:0'. In this case,
#the digit has to be ignored.
if ":" in node.input[0]:
in_name, _ = node.input[0].split(':')
node.input[0] = in_name
# Fill shapes for all inputs in a list
inputs = []
for i in node.input:
if i in self._nodes:
inputs.append(self._nodes[i][0])
input_shapes[self._nodes[i][0]] = self._output_shapes[i]
attr['_input_shapes'] = input_shapes
op = self._convert_operator(node.op, inputs, attr, graph)
# Check if op is converted to param
if isinstance(op, np.ndarray):
self._params[node.name] = tvm.nd.array(op)
op = [_expr.var(node.name,
shape=self._params[node.name].shape,
dtype=self._params[node.name].dtype)]
elif isinstance(op, (_expr.TupleWrapper, tuple, list)):
pass
elif isinstance(op, _expr.Expr):
op = [op]
else:
raise RuntimeError("unexpected type %s" % type(op))
self._nodes[node.name] = op
# Infer shapes even if they were passed explicitly
node_output = self._nodes[node.name]
out_type = ir_pass.infer_type(node_output[0])
self._output_shapes[node.name] = [get_const_tuple(out_type.checked_type.shape)]
out = []
if outputs is None:
out = op
else:
out = [self._nodes[out_name][0] for out_name in outputs]
#Also add the RNN outputs to the 'head' nodes of the relay graph
if self._num_rnn_layer:
if len(self._out_rnn) == 1:
out.append(self._out_rnn[0])
else:
out_rnn = _op.concatenate(self._out_rnn, axis=0)
out.append(out_rnn)
out = out[0] if len(out) == 1 else _expr.Tuple(out)
func = _expr.Function(ir_pass.free_vars(out), out)
return func, self._params
def _parse_import_prerequisites(self, graph):
""" Calculate the named preconditions from TensorFlow `graph`.
Return prerequisites for parsing:
a. Set of operator names which don't have their mapping in TVM, i.e.
which are not supported
"""
missing_operators = set()
for node in graph.node:
if node.op == "Placeholder":
pass
elif node.op == "Const":
pass
else:
if any([node.op in t for t in [_identity_list, _convert_map, _convert_map_rnn]]):
pass
else:
missing_operators.add(node.op)
return missing_operators
def _parse_param(self, key, value, name, shape):
try:
from tensorflow.python.framework import tensor_util
except ImportError as e:
raise ImportError(
"Unable to import tensorflow which is required {}".format(e))
if key == 'value':
np_array = tensor_util.MakeNdarray(value.tensor)
if np_array.dtype == np.dtype(object):
# Object types are generally tensorflow DT_STRING (DecodeJpeg op).
# Just leave it as placeholder.
self._nodes[name] = [_expr.var(name, shape=shape[name], dtype='uint8')]
return
array_ndim = len(np_array.shape)
if array_ndim == 0:
new_array = np.empty([1], dtype=np_array.dtype)
new_array[0] = np_array
self._params[name] = tvm.nd.array(new_array)
else:
self._params[name] = tvm.nd.array(np_array)
self._nodes[name] = [_expr.var(name,
shape=self._params[name].shape,
dtype=self._params[name].dtype)]
else:
if key != 'dtype' and key != '_output_shapes' and key != '_class':
raise NotImplementedError \
("Unsupported attribute {} for a Const (param) node.".format(key))
def _get_attr(self, buf):
"""Returns the value of the attr of this buf with the given `name`.
Args:
buf: attrvalue protobuf.
Returns:
The value of the attr, as a Python object.
Raises:
ValueError: If this op does not have an attr with the given `name`.
"""
fields = ["s", "i", "f", "b", "type", "shape", "tensor", "func"]
x = buf
ret = []
try:
from tensorflow.python.framework import dtypes
except ImportError as e:
raise ImportError(
"Unable to import tensorflow which is required {}".format(e))
# Treat an empty oneof value as an empty list.
if not x.WhichOneof("value"):
return ret
if x.HasField("list"):
for f in fields:
if getattr(x.list, f):
if f == "type":
ret += [dtypes.as_dtype(x) for x in list(getattr(x.list, f))]
else:
ret += list(getattr(x.list, f))
else:
for f in fields:
if x.HasField(f):
if f == "type":
ret = dtypes.as_dtype(getattr(x, f))
else:
ret = getattr(x, f)
return ret
def _parse_attr(self, attr_proto):
"""Convert a list of AttributeProto to a dict, with names as keys."""
attrs = {}
for key, value in attr_proto.items():
attrs[key] = self._get_attr(value)
return attrs
def _convert_rnn_operator(self, op_name, inputs,
attrs, params, graph, convert_map):
"""Convert RNN and its variant operators to Relay operators.
This converter reads the input states of each layer and
also maintains the output states of each layer in a list.
Parameters
----------
op_name : str
Operator name, such as LSTMBlockCell
inputs : list of relay.Expr
List of input symbols.
attrs : dict
Dict of operator attributes
params : dict
Dict of pretrained weights and biases
graph : Tensorflow graph object
The graph is used to find occurrences of the same operator,
to calculate the number of layers.
convert_map : dict
Dict of name : callable, where name is the op's name that
requires conversion to relay; the callables are functions which
take attrs and return (new_op_name, new_attrs)
Returns
-------
sym : relay.Expr
Converted relay.Expr
"""
if not self._num_rnn_layer:
self._out_rnn = []
self.rnn = RecurrentNetworks(self._nodes, self._out_rnn, graph, convert_map)
self._num_rnn_layer = True
sym = self.rnn.process_op(op_name, inputs, attrs, params)
return sym
def _convert_operator(self, op_name, inputs, attrs,
graph, identity_list=None, convert_map=None):
"""Convert from Tensorflow operator to relay operator.
The converter must specify conversions explicitly for incompatible names, and
apply handlers to operator attributes.
Parameters
----------
op_name : str
Operator name, such as Conv2D, AvgPool
inputs : list of relay.op
List of input symbols.
attrs : dict
Dict of operator attributes
identity_list : list
List of operators that don't require conversion
convert_map : dict
Dict of name : callable, where name is the op's name that
requires conversion to relay; the callables are functions which
take attrs and return (new_op_name, new_attrs)
Returns
-------
sym : relay.op
Converted relay operator
"""
identity_list = identity_list if identity_list else _identity_list
convert_map = convert_map if convert_map else _convert_map
convert_map_rnn = _convert_map_rnn
if op_name in identity_list:
sym = _get_relay_op(op_name)(*inputs, **attrs)
elif op_name in convert_map:
sym = convert_map[op_name](inputs, attrs, self._params)
elif op_name in convert_map_rnn:
sym = self._convert_rnn_operator(op_name, inputs, attrs,
self._params, graph,
convert_map_rnn)
else:
raise NotImplementedError("Operator {} not implemented.".format(op_name))
return sym
def from_tensorflow(graph, layout="NHWC", shape=None, outputs=None):
""" Load tensorflow graph which is a python tensorflow graph object into relay.
The companion parameters will be handled automatically.
Parameters
----------
graph : GraphDef object
Tensorflow GraphDef
Returns
-------
sym : relay.op
Compatible relay operator
params : dict of str to tvm.ndarray
Dict of converted parameters stored in tvm.ndarray format
"""
g = GraphProto()
sym, params = g.from_tensorflow(graph, layout, shape, outputs)
return sym, params
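# Usage sketch (illustrative only; the file name and the shape dict are
# assumptions, not part of this module):
#
#   import tensorflow as tf
#   from tvm import relay
#
#   with tf.gfile.GFile('frozen_graph.pb', 'rb') as f:
#       graph_def = tf.GraphDef()
#       graph_def.ParseFromString(f.read())
#   sym, params = relay.frontend.from_tensorflow(
#       graph_def, layout="NHWC", shape={'input': (1, 224, 224, 3)})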
# pylint: disable=import-self, invalid-name, unused-argument
"""
Tensorflow testcases
====================
This script tests tensorflow operators with Relay.
"""
from __future__ import print_function
import numpy as np
import tvm
from tvm import relay
import tensorflow as tf
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import graph_util
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.ops import init_ops
from tensorflow.core.framework import graph_pb2
import tvm.relay.testing.tf as tf_testing
#######################################################################
# Generic run functions for TVM & tensorflow
# ------------------------------------------
def convert_to_list(x):
if not isinstance(x, list):
x = [x]
return x
def run_tvm_graph(graph_def, input_data, input_node, num_output=1, target='llvm', out_names=None):
""" Generic function to compile on relay and execute on tvm """
input_data = convert_to_list(input_data)
input_node = convert_to_list(input_node)
layout = None
if target == "cuda":
layout = "NCHW"
target_host = 'llvm'
if isinstance(input_data, list):
shape_dict = {}
dtype_dict = {}
for i, e in enumerate(input_node):
shape_dict[e] = input_data[i].shape
dtype_dict[e] = input_data[i].dtype
else:
shape_dict = {input_node: input_data.shape}
dtype_dict = {input_node: input_data.dtype}
sym, params = relay.frontend.from_tensorflow(graph_def,
layout=layout,
shape=shape_dict,
outputs=out_names)
with relay.build_config(opt_level=3):
graph, lib, params = relay.build(sym, target, params=params)
ctx = tvm.context(target, 0)
from tvm.contrib import graph_runtime
m = graph_runtime.create(graph, lib, ctx)
# set inputs
for i, e in enumerate(input_node):
m.set_input(e, tvm.nd.array(input_data[i].astype(input_data[i].dtype)))
m.set_input(**params)
# execute
m.run()
# get outputs
assert out_names is None or num_output == len(out_names),"out_names: {} num_output: {}".format(
out_names, num_output)
tvm_output_list = []
for i in range(0, num_output):
tvm_output = m.get_output(i)
tvm_output_list.append(tvm_output.asnumpy())
return tvm_output_list
def run_tf_graph(sess, input_data, input_node, output_node):
""" Generic function to execute tensorflow """
input_data = convert_to_list(input_data)
input_node = convert_to_list(input_node)
output_node = convert_to_list(output_node)
tensor = [0] * len(output_node)
for i in range(len(output_node)):
tensor[i] = sess.graph.get_tensor_by_name(output_node[i])
input_dict = {}
for i, e in enumerate(input_node):
input_dict[e] = input_data[i]
output_data = sess.run(tensor, input_dict)
return output_data
def compare_tf_with_tvm(in_data, in_name, out_name, init_global_variables=False, no_gpu=False):
"""Generic function to generate and compare tensorflow and TVM output"""
out_name = convert_to_list(out_name)
out_node = [0]*len(out_name)
for i in range(len(out_name)):
out_node[i] = out_name[i].split(':')[0] if ":" in out_name[i] else out_name[i]
in_data = convert_to_list(in_data)
in_name = convert_to_list(in_name)
in_node = [0]*len(in_name)
for i in range(len(in_name)):
in_node[i] = in_name[i].split(':')[0] if ":" in in_name[i] else in_name[i]
with tf.Session() as sess:
if init_global_variables:
sess.run(variables.global_variables_initializer())
final_graph_def = tf.graph_util.convert_variables_to_constants(
sess,
sess.graph.as_graph_def(add_shapes=True),
out_node,
)
tf_output = run_tf_graph(sess, in_data, in_name, out_name)
for device in ["llvm", "cuda"]:
ctx = tvm.context(device, 0)
if not ctx.exist:
print("Skip because %s is not enabled" % device)
continue
if no_gpu and device == 'cuda':
continue
tvm_output = run_tvm_graph(final_graph_def, in_data, in_node, target=device)
# since the names from the tensorflow and relay runs are not exactly the same,
# only the first len(tf_output) outputs are compared
for i in range(len(tf_output)):
tvm.testing.assert_allclose(tf_output[i], tvm_output[i], atol=1e-5, rtol=1e-5)
sess.close()
def is_gpu_available():
from tensorflow.python.client import device_lib
local_device_protos = device_lib.list_local_devices()
gpu_list = [x.name for x in local_device_protos if x.device_type == 'GPU']
if len(gpu_list) > 0:
print("Tensorflow GPU:", gpu_list)
return True
else:
return False
#######################################################################
# Pooling
# -------
def _test_pooling_iteration(input_shape, **kwargs):
""" One iteration of pool operation with given shapes and attributes """
x = -np.arange(
np.prod(input_shape), dtype=np.float32).reshape(input_shape) - 1
with tf.Graph().as_default():
in_data = array_ops.placeholder(shape=input_shape, dtype='float32')
nn_ops.pool(in_data, **kwargs)
if kwargs['pooling_type'] == 'MAX':
out_name = 'max_pool:0'
else:
out_name = 'avg_pool:0'
compare_tf_with_tvm(x, 'Placeholder:0', out_name)
def _test_pooling(input_shape, **kwargs):
_test_pooling_iteration(input_shape, **kwargs)
if is_gpu_available():
input_shape = [input_shape[ii] for ii in (0, 3, 1, 2)]
kwargs['data_layout'] = 'NCHW'
_test_pooling_iteration(input_shape, **kwargs)
def test_forward_pooling():
""" Pooling """
for pool_type in ['AVG', 'MAX']:
_test_pooling(input_shape=[2, 9, 10, 2],
window_shape=[1, 1],
padding='SAME',
pooling_type=pool_type,
dilation_rate=[1, 1],
strides=[1, 1])
_test_pooling(input_shape=[2, 10, 9, 2],
window_shape=[1, 1],
padding='SAME',
pooling_type=pool_type,
dilation_rate=[1, 1],
strides=[1, 1])
_test_pooling(input_shape=[2, 9, 10, 2],
window_shape=[2, 1],
padding='SAME',
pooling_type=pool_type,
dilation_rate=[1, 1],
strides=[1, 1])
_test_pooling(input_shape=[2, 10, 9, 2],
window_shape=[2, 3],
padding='SAME',
pooling_type=pool_type,
dilation_rate=[1, 1],
strides=[2, 1])
#######################################################################
# Convolution
# -----------
def _test_convolution(tensor_in_sizes, filter_in_sizes,
dilations, strides, padding, data_format):
""" One iteration of convolution with given shapes and attributes """
total_size_1 = 1
total_size_2 = 1
for s in tensor_in_sizes:
total_size_1 *= s
for s in filter_in_sizes:
total_size_2 *= s
# Initializes the input tensor with array containing incrementing
# numbers from 1.
data_array = [f * 1.0 for f in range(1, total_size_1 + 1)]
filter_array = [f * 1.0 for f in range(1, total_size_2 + 1)]
with tf.Graph().as_default():
in_data = array_ops.placeholder(shape=tensor_in_sizes, dtype='float32')
in_filter = constant_op.constant(filter_array, shape=filter_in_sizes, dtype='float32')
        if data_format == 'NHWC':
            strides = [1] + strides + [1]
            dilations = [1] + dilations + [1]
        else:
            # NCHW: batch and channel dimensions come first
            strides = [1, 1] + strides
            dilations = [1, 1] + dilations
nn_ops.conv2d(in_data,
in_filter,
strides=strides,
padding=padding,
data_format=data_format)
compare_tf_with_tvm(np.reshape(data_array, tensor_in_sizes).astype('float32'),
'Placeholder:0', 'Conv2D:0')
def test_forward_convolution():
if is_gpu_available():
_test_convolution([4, 176, 8, 8], [1, 1, 176, 32], [1, 1], [1, 1], 'SAME', 'NCHW')
_test_convolution([4, 19, 17, 17], [3, 3, 19, 19], [1, 1], [2, 2], 'VALID', 'NCHW')
_test_convolution([4, 124, 17, 17], [1, 1, 124, 19], [1, 1], [1, 1], 'SAME', 'NCHW')
_test_convolution([4, 12, 17, 17], [3, 3, 12, 32], [1, 1], [2, 2], 'VALID', 'NCHW')
_test_convolution([4, 8, 8, 176], [1, 1, 176, 32], [1, 1], [1, 1], 'SAME', 'NHWC')
_test_convolution([4, 17, 17, 19], [3, 3, 19, 19], [1, 1], [2, 2], 'VALID', 'NHWC')
_test_convolution([4, 17, 17, 124], [1, 1, 124, 19], [1, 1], [1, 1], 'SAME', 'NHWC')
_test_convolution([4, 17, 17, 12], [3, 3, 12, 32], [1, 1], [2, 2], 'VALID', 'NHWC')
#######################################################################
# Reshape
# -------
def _test_reshape(data, out_shape):
""" One iteration of reshape operation with given data and out shape """
with tf.Graph().as_default():
in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
array_ops.reshape(in_data, out_shape)
compare_tf_with_tvm(data, 'Placeholder:0', 'Reshape:0')
def test_forward_reshape():
_test_reshape(np.arange(6.0), [2, 3])
_test_reshape(np.arange(6), [-1, 2])
_test_reshape(np.arange(6), [3, -1])
_test_reshape(np.arange(6), [-1])
#######################################################################
# Squeeze
# -------
def _test_squeeze(data, squeeze_dims=None):
""" One iteration of squeeze """
if squeeze_dims is None:
squeeze_dims = []
with tf.Graph().as_default():
in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
if squeeze_dims:
array_ops.squeeze(in_data, squeeze_dims)
else:
array_ops.squeeze(in_data)
compare_tf_with_tvm(data, 'Placeholder:0', 'Squeeze:0')
def test_forward_squeeze():
""" Squeeze """
# Nothing to squeeze.
_test_squeeze(np.arange(2).reshape((2)))
_test_squeeze(np.arange(6).reshape((2, 3)))
# Squeeze the middle element away.
_test_squeeze(np.arange(4).reshape((2, 1, 2)))
# Squeeze on both ends.
_test_squeeze(np.arange(6).reshape((1, 2, 1, 3, 1)))
# Positive squeeze dim index.
_test_squeeze(np.arange(6).reshape((1, 2, 1, 3, 1)), [0])
_test_squeeze(np.arange(6).reshape((1, 2, 1, 3, 1)), [2, 4])
_test_squeeze(np.arange(6).reshape((1, 2, 1, 3, 1)), [0, 4, 2])
# Negative squeeze dim index.
_test_squeeze(np.arange(6).reshape((1, 2, 1, 3, 1)), [-1])
_test_squeeze(np.arange(6).reshape((1, 2, 1, 3, 1)), [-3, -5])
_test_squeeze(np.arange(6).reshape((1, 2, 1, 3, 1)), [-3, -5, -1])
#######################################################################
# ConcatV2
# --------
def _test_concat_v2(data, dim):
""" One iteration of ConcatV2 """
with tf.Graph().as_default():
gen_array_ops._concat_v2(data, dim)
compare_tf_with_tvm(data, ['ConcatV2/values_0:0', 'ConcatV2/values_1:0'],
'ConcatV2:0')
def _test_forward_concat_v2():
t1 = np.array([])
t2 = np.array([])
_test_concat_v2([t1, t2], 0)
t1 = np.array([[1, 2, 3], [4, 5, 6]])
t2 = np.array([[7, 8, 9], [10, 11, 12]])
_test_concat_v2([t1, t2], 1)
#######################################################################
# Sigmoid
# -------
def _test_sigmoid(data):
""" One iteration of sigmoid """
with tf.Graph().as_default():
in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
        math_ops.sigmoid(in_data)
compare_tf_with_tvm(data, 'Placeholder:0', 'Sigmoid:0')
def test_forward_sigmoid():
""" Sigmoid """
_test_sigmoid(np.random.uniform(size=(3, 4, 4, 3)).astype('float32'))
#######################################################################
# Argmin/Argmax
# -------------
def _test_argx(func, data, **kwargs):
with tf.Graph().as_default():
inp = array_ops.placeholder(shape=data.shape, dtype=data.dtype, name="c0")
func(inp, name="argx0", output_type=tf.int32, **kwargs)
compare_tf_with_tvm(data, 'c0:0', 'argx0:0')
def test_forward_argminmax():
for axis in [None,0,1,2]:
data = np.random.uniform(size=(8,4,9)).astype('float32')
_test_argx(tf.argmax, data=data, axis=axis)
_test_argx(tf.argmin, data=data, axis=axis)
#######################################################################
# Reduce
# ------
def _test_reduce(func, data, **kwargs):
""" One iteration of a reduce operation"""
with tf.Graph().as_default():
inp = array_ops.placeholder(shape=data.shape, dtype=data.dtype, name="c0")
func(inp, name="reducex0", **kwargs)
compare_tf_with_tvm(data, 'c0:0', 'reducex0:0')
def test_forward_reduce():
data = np.random.uniform(size=(8,4,9)).astype('float32')
_test_reduce(tf.reduce_sum, data=data)
_test_reduce(tf.reduce_sum, data=data, axis=0)
_test_reduce(tf.reduce_sum, data=data, axis=(0,1))
#######################################################################
# Variable
# --------
def _test_variable(data):
""" One iteration of a variable """
tf.reset_default_graph()
input_op = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
input_tensor = array_ops.reshape(input_op, data.shape)
size = input_tensor.shape.dims[1]
with variable_scope.variable_scope("linear", reuse=None):
w = variable_scope.get_variable(
"w", shape=[size, size], dtype=input_tensor.dtype)
math_ops.matmul(input_tensor, w)
compare_tf_with_tvm(data, 'Placeholder:0', 'MatMul:0', init_global_variables=True)
def test_forward_variable():
"""Variable type op test"""
_test_variable(np.random.uniform(size=(32, 100)).astype('float32'))
#######################################################################
# StridedSlice
# ------------
def _test_stridedslice(ip_shape, begin, end, stride, dtype,
begin_mask=0, end_mask=0, new_axis_mask=0,
shrink_axis_mask=0, ellipsis_mask=0):
""" One iteration of a Stridedslice """
tf.reset_default_graph()
in_data = tf.placeholder(dtype, ip_shape, name="in_data")
tf.strided_slice(in_data, begin, end, stride, begin_mask=begin_mask,
end_mask=end_mask, new_axis_mask=new_axis_mask,
shrink_axis_mask=shrink_axis_mask,
ellipsis_mask=ellipsis_mask, name="strided_slice")
np_data = np.random.uniform(size=ip_shape).astype(dtype)
compare_tf_with_tvm(np_data, 'in_data:0', 'strided_slice:0')
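# Note: the *_mask arguments to strided_slice are bitmasks over positions in
# the slice spec. For example, shrink_axis_mask=2 (0b010) drops spec position 1
# from the result, new_axis_mask=5 (0b101) inserts length-1 axes at positions
# 0 and 2, and ellipsis_mask=2 treats position 1 as '...'. So on a (3, 4)
# input, strided_slice(x, [1, 0], [4, 4], [1, 1], shrink_axis_mask=2) yields
# shape (2,): column 0 of rows 1..2.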
def test_forward_stridedslice():
'''test StridedSlice'''
_test_stridedslice((3, 4, 3), [1, -1, 0], [4, -5, 3], [2, -1, 1], 'float32')
_test_stridedslice((3, 4, 3), [1, 0], [4, 3], [2, 1], 'float32', ellipsis_mask=8)
_test_stridedslice((3, 4, 3), [1, 0], [4, 2], [2, 1], 'float32', ellipsis_mask=2)
_test_stridedslice((3, 4, 5, 3), [1, 0], [4, 2], [2, 1], 'float32', ellipsis_mask=2)
_test_stridedslice((3, 4, 5, 3), [1, 0, 1], [4, 2, 2], [2, 1, 1], 'float32', ellipsis_mask=2)
_test_stridedslice((3, 4, 3), [1, 1, 0], [4, 4, 2], [2, 1, 1], 'float32', new_axis_mask=5)
_test_stridedslice((3, 4, 3), [1, 1, 1], [4, 4, 1], [2, 1, 1], 'float32', ellipsis_mask=2, new_axis_mask=4)
_test_stridedslice((6, 4, 5), [1, 1, 1], [6, 3, 4], [2, 1, 1], 'float32', ellipsis_mask=2, new_axis_mask=5)
_test_stridedslice((3, 4, 3), [1, 1, 2], [4, 4, 3], [2, 1, 1], 'float32', ellipsis_mask=4, new_axis_mask=2)
_test_stridedslice((3, 4, 3), [1, 1, 2], [4, 4, 3], [2, 1, 1], 'float32', ellipsis_mask=2, new_axis_mask=3)
_test_stridedslice((3, 4, 3), [1, 1, 0], [4, 4, 1], [2, 1, 1], 'float32', ellipsis_mask=2, new_axis_mask=3)
_test_stridedslice((3, 4, 3), [1, 1, 2], [4, 4, 3], [2, 1, 1], 'float32', ellipsis_mask=2, new_axis_mask=2)
_test_stridedslice((3,4), [1, 0], [4, 4], [1, 1], 'float32', shrink_axis_mask=2)
_test_stridedslice((3, 4, 3), [1, 1, 0], [4, 4, 3], [2, 1, 1], 'float32', shrink_axis_mask=2, new_axis_mask=2)
_test_stridedslice((3, 4, 3), [1, 1, 0], [4, 4, 3], [2, 1, 1], 'float32', shrink_axis_mask=1, new_axis_mask=2)
_test_stridedslice((3, 4, 3), [1, 1, 0], [4, 4, 3], [2, 1, 1], 'float32', shrink_axis_mask=2, new_axis_mask=1)
_test_stridedslice((3, 4, 5, 4, 5, 6), [0, 0], [2, 3], [1, 1], 'float32', shrink_axis_mask=5, new_axis_mask=1)
_test_stridedslice((3, 4, 5, 4, 5, 6), [0, 0, 1, 2, 1], [2, 3, 4, 5, 3], [1, 1, 2, 2, 1],
'float32', shrink_axis_mask=5, new_axis_mask=1, ellipsis_mask=2, begin_mask=8, end_mask=8)
_test_stridedslice((3, 4, 5, 4, 5, 6), [0, 0, 1, 2, 1], [2, 3, 4, 5, 3], [1, 1, 2, 2, 1],
'float32', shrink_axis_mask=8, new_axis_mask=1, ellipsis_mask=2, begin_mask=5, end_mask=5)
_test_stridedslice((3, 4, 5, 4, 5, 6), [0, 0, 1, 2, 1], [2, 3, 4, 5, 3], [1, 1, 2, 2, 1],
'float32', shrink_axis_mask=16, new_axis_mask=1, ellipsis_mask=2, begin_mask=5, end_mask=5)
_test_stridedslice((3, 4, 5, 4, 5, 6), [1, 2, 0, -3], [4, 5, 3, 3], [2, 2, 1, 1],
'float32', shrink_axis_mask=8, new_axis_mask=1, ellipsis_mask=2, begin_mask=5,
end_mask=8)
#######################################################################
# Gather
# ------
def _test_gather(ip_shape, indice_shape, indice_value, axis, dtype):
""" One iteration of a Gather """
tf.reset_default_graph()
in_data = tf.placeholder(dtype, ip_shape, name="in_data")
indices = tf.placeholder("int32", indice_shape, name="indices")
tf.gather(in_data, indices, axis=axis)
np_data = np.random.uniform(size=ip_shape).astype(dtype)
    def _fill_indices(indice_value):
        if isinstance(indice_value, int):
            return np.array([indice_value], dtype='int32')
        return np.asarray(indice_value, dtype='int32')
np_indices = _fill_indices(indice_value)
compare_tf_with_tvm([np_data, np_indices], ['in_data:0', 'indices:0'], 'GatherV2:0')
def test_forward_gather():
'''test gather layer'''
_test_gather((4,), (1,), 1, 0, 'int32')
_test_gather((4,), (1,), 1, 0, 'float32')
_test_gather((1,4), (1,), [0], 0, 'int32')
_test_gather((4,), (1,2,2), [[[1,0],[0,1]]], 0, 'float32')
_test_gather((2,2), (1,2,2), [[[1,0],[0,1]]], 0, 'int32')
_test_gather((2,2), (1,2,2), [[[1,0],[0,1]]], 1, 'int32')
_test_gather((2,2), (1,2,2), [[[1,0],[0,1]]], 0, 'float32')
_test_gather((3,3,3), (1,1,2), [[[1,0]]], 0, 'int32')
_test_gather((3,3,3), (1,1,2), [[[1,0]]], 2, 'int32')
_test_gather((4,3,5,6), (1,4), [[2,1,0,0]], 0, 'float32')
#######################################################################
# Multi Input to graph
# --------------------
def test_forward_multi_input():
with tf.Graph().as_default():
in1 = tf.placeholder(tf.int32, shape=[3, 3], name='in1')
in2 = tf.placeholder(tf.int32, shape=[3, 3], name='in2')
in3 = tf.placeholder(tf.int32, shape=[3, 3], name='in3')
in4 = tf.placeholder(tf.int32, shape=[3, 3], name='in4')
out1 = tf.add(in1, in2, name='out1')
out2 = tf.subtract(in3, in4, name='out2')
out = tf.multiply(out1, out2, name='out')
in_data = np.arange(9, dtype='int32').reshape([3, 3])
compare_tf_with_tvm([in_data, in_data, in_data, in_data],
['in1:0', 'in2:0', 'in3:0', 'in4:0'], 'out:0')
#######################################################################
# Multi Output to Graph
# ---------------------
def test_forward_multi_output():
with tf.Graph().as_default():
in1 = tf.placeholder(tf.int32, shape=[3, 3], name='in1')
in2 = tf.placeholder(tf.int32, shape=[3, 3], name='in2')
in3 = tf.placeholder(tf.int32, shape=[3, 3], name='in3')
in4 = tf.placeholder(tf.int32, shape=[3, 3], name='in4')
out1 = tf.add(in1, in2, name='out1')
out2 = tf.subtract(in3, in4, name='out2')
in_data = np.arange(9, dtype='int32').reshape([3, 3])
in_data = [in_data] * 4
in_name = ['in1:0', 'in2:0', 'in3:0', 'in4:0']
out_name = ['out1:0', 'out2:0']
        out_node = [name.split(':')[0] for name in out_name]
        in_node = [name.split(':')[0] for name in in_name]
with tf.Session() as sess:
final_graph_def = tf.graph_util.convert_variables_to_constants(
sess, sess.graph.as_graph_def(add_shapes=True), out_node,)
tf_output = run_tf_graph(sess, in_data, in_name, out_name)
tvm_output = run_tvm_graph(final_graph_def, in_data, in_node, target='llvm',
out_names=out_node, num_output=2)
for i in range(len(tf_output)):
tvm.testing.assert_allclose(tf_output[i], tvm_output[i], atol=1e-5, rtol=1e-5)
#######################################################################
# Resize Bilinear
# ---------------
def _test_resize_bilinear(in_shape, to_shape, align_corners):
""" One iteration of resize bilinear """
data = np.random.uniform(size=in_shape).astype('float32')
shape_data = np.array(to_shape).astype('int32')
with tf.Graph().as_default():
in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
shape_data = constant_op.constant(shape_data, shape=shape_data.shape, dtype=shape_data.dtype)
tf.image.resize_bilinear(in_data, shape_data, align_corners=align_corners)
compare_tf_with_tvm(data, 'Placeholder:0', 'ResizeBilinear:0')
def test_forward_resize_bilinear():
""" Resize Bilinear """
_test_resize_bilinear((4, 16, 32, 32), [50, 50], False)
_test_resize_bilinear((6, 32, 64, 64), [20, 20], True)
#######################################################################
# LSTM
# ----
def _test_lstm_cell(batch_size, num_hidden, num_layers, forget_bias, dtype):
""" One iteration of a LSTM cell """
tf.reset_default_graph()
input_size = num_hidden
input_data = np.full((batch_size, input_size), 1., dtype=dtype)
in_state_c = np.full((num_layers, batch_size, num_hidden), 0.1, dtype=dtype)
in_state_h = np.full((num_layers, batch_size, num_hidden), 0.1, dtype=dtype)
def _get_tensorflow_output():
with tf.Session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
m0 = array_ops.zeros([batch_size, num_hidden])
m1 = array_ops.zeros([batch_size, num_hidden])
x=tf.placeholder(shape=(batch_size, input_size), dtype=dtype)
g, ((out_m0, out_m1)) = \
tf.contrib.rnn.LSTMBlockCell(num_hidden,
forget_bias=forget_bias)(x, ((m0, m1)))
sess.run([variables.global_variables_initializer()])
res = sess.run([g, out_m0, out_m1], {
x.name: np.array([[1., 1.]]),
m0.name: 0.1 * np.ones([batch_size, num_hidden]),
m1.name: 0.1 * np.ones([batch_size, num_hidden]),
})
graph_def = sess.graph.as_graph_def(add_shapes=True)
final_graph_def = graph_util.convert_variables_to_constants(
sess,
graph_def,
['root/lstm_cell/LSTMBlockCell'])
return final_graph_def, res
graph_def, tf_out = _get_tensorflow_output()
tvm_output = run_tvm_graph(graph_def, [input_data, in_state_c, in_state_h],
['root/Placeholder', 'root/lstm_cell/LSTMBlockCell_c',
'root/lstm_cell/LSTMBlockCell_h'], num_output=2)
assert isinstance(tvm_output, list)
out = tvm_output[0]
out_state = tvm_output[1]
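    # run_tvm_graph packs the LSTM state into a single tensor; split it back
    # into the (c, h) halves before comparing with tensorflow.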
out_state_tup = np.split(out_state, indices_or_sections=2, axis=1)
out_state_c = np.reshape(out_state_tup[0], (batch_size, num_hidden))
out_state_h = np.reshape(out_state_tup[1], (batch_size, num_hidden))
tvm_out = [out, out_state_c, out_state_h]
tvm.testing.assert_allclose(tf_out[0], tvm_out[0], rtol=1e-3, atol=1e-3)
def test_forward_lstm():
'''test LSTM block cell'''
_test_lstm_cell(1, 2, 1, 0.0, 'float32')
#######################################################################
# Pack
# ----
def _test_pack(axis, shape, **kwargs):
a = np.arange(np.prod(shape), dtype=np.float32).reshape(shape)
b = np.arange(np.prod(shape), dtype=np.float32).reshape(shape)
with tf.Graph().as_default():
tf_a = array_ops.placeholder(shape=shape, dtype='float32', name='pl_a')
tf_b = array_ops.placeholder(shape=shape, dtype='float32', name='pl_b')
tf_c = tf.stack([tf_a,tf_b], axis=axis, **kwargs)
assert tf_c.op.op_def.name == 'Pack', "tf.stack() is expected to produce 'Pack' operation"
compare_tf_with_tvm([a,b], ['pl_a:0','pl_b:0'], 'stack:0')
def test_forward_pack():
for axis in range(-3,3):
_test_pack(axis, [3,2,1])
for axis in range(-1,1):
_test_pack(axis, [3])
_test_pack(0, [])
#######################################################################
# Pad
# ---
def _test_pad(input_shape, paddings, mode, **kwargs):
""" One iteration of pad operation with given shape"""
x = np.arange(np.prod(input_shape), dtype=np.float32).reshape(input_shape)
with tf.Graph().as_default():
in_data = array_ops.placeholder(shape=input_shape, dtype='float32')
pad_values = constant_op.constant(paddings)
pad = tf.pad(in_data, paddings=pad_values, mode=mode, **kwargs)
if mode == 'CONSTANT':
if 'constant_values' in kwargs:
out_name = 'PadV2:0'
else:
out_name = 'Pad:0'
compare_tf_with_tvm(x, 'Placeholder:0', out_name)
def test_forward_pad():
""" Pad """
_test_pad((2, 3), [[1,1], [2,2]], mode="CONSTANT")
_test_pad((2, 3), [[1,1], [2,2]], mode="CONSTANT", constant_values=1.0)
#######################################################################
# Inception V3
# ------------
def test_forward_inception_v3():
'''test inception V3 model'''
with tf.Graph().as_default():
graph_def = tf_testing.get_workload('InceptionV3/inception_v3_2016_08_28_frozen-with_shapes.pb')
# Call the utility to import the graph definition into default graph.
graph_def = tf_testing.ProcessGraphDefParam(graph_def)
data = np.random.uniform(size=(1, 299, 299, 3)).astype('float32')
with tf.Session() as sess:
tf_output = run_tf_graph(sess, data, 'input:0', 'InceptionV3/Predictions/Reshape_1:0')
tvm_output = run_tvm_graph(graph_def, data, 'input')
tvm.testing.assert_allclose(tf_output[0], tvm_output[0], rtol=1e-5, atol=1e-5)
#######################################################################
# Inception V1
# ------------
def test_forward_inception_v1():
'''test inception V1 model'''
with tf.Graph().as_default():
graph_def = tf_testing.get_workload("InceptionV1/classify_image_graph_def-with_shapes.pb")
# Call the utility to import the graph definition into default graph.
graph_def = tf_testing.ProcessGraphDefParam(graph_def)
# Build an image from random data.
from PIL import Image
from tvm.contrib import util
img_array = np.random.uniform(size=(1, 600, 600, 3)).astype("uint8")
img = Image.frombuffer('RGB', (600, 600), img_array.tostring(), 'raw', 'RGB', 0, 1)
temp = util.tempdir()
img_path = temp.relpath("tf-test.jpg")
        img.save(img_path)
        import os.path
        if not tf.gfile.Exists(img_path):
            tf.logging.fatal('File does not exist %s', img_path)
        data = tf.gfile.FastGFile(img_path, 'rb').read()
temp.remove()
# Extract tensorflow decoded image frame for tvm input
with tf.Session() as sess:
tvm_data = run_tf_graph(sess, data, 'DecodeJpeg/contents:0', 'DecodeJpeg:0')
with tf.Session() as sess:
tf_output = run_tf_graph(sess, data, 'DecodeJpeg/contents:0', 'softmax:0')
tvm_output = run_tvm_graph(graph_def, tvm_data, 'DecodeJpeg/contents')
tvm.testing.assert_allclose(tf_output[0], tvm_output[0], rtol=1e-5, atol=1e-5)
#######################################################################
# Mobilenet
# ---------
def test_forward_mobilenet():
'''test mobilenet model'''
# MobilenetV2
with tf.Graph().as_default():
graph_def = tf_testing.get_workload(
"https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_1.4_224.tgz",
"mobilenet_v2_1.4_224_frozen.pb")
# Call the utility to import the graph definition into default graph.
graph_def = tf_testing.ProcessGraphDefParam(graph_def)
data = np.random.uniform(size=(1, 224, 224, 3)).astype('float32')
out_node = 'MobilenetV2/Predictions/Reshape_1'
with tf.Session() as sess:
# Add shapes to the graph.
graph_def = tf_testing.AddShapesToGraphDef(sess, out_node)
tf_output = run_tf_graph(sess, data, 'input:0', out_node + ':0')
tvm_output = run_tvm_graph(graph_def, data, 'input')
tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tf_output[0]), rtol=1e-5, atol=1e-5)
#######################################################################
# ResnetV2
# ---------
def test_forward_resnetv2():
'''test resnet model'''
if is_gpu_available():
with tf.Graph().as_default():
graph_def = tf_testing.get_workload("ResnetV2/resnet-20180601_resnet_v2_imagenet-shapes.pb")
# Call the utility to import the graph definition into default graph.
graph_def = tf_testing.ProcessGraphDefParam(graph_def)
data = np.random.uniform(size=(128, 224, 224, 3)).astype('float32')
out_node = 'ArgMax'
with tf.Session() as sess:
tf_output = run_tf_graph(sess, data, 'input_tensor:0', out_node + ':0')
tvm_output = run_tvm_graph(graph_def, data, 'input_tensor', tf_output.shape, 'float32')
tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tf_output[0]), rtol=1e-5, atol=1e-5)
#######################################################################
# PTB
# ---
# Looking up tf.contrib forces its lazy load, which registers contrib ops such
# as LSTMBlockCell that the PTB model below depends on.
dir(tf.contrib)
def test_forward_ptb():
'''test ptb model'''
config = tf_testing.get_config()
num_steps = config.num_steps
num_hidden = config.hidden_size
num_layers = config.num_layers
batch_size = config.batch_size
vocab_size = config.vocab_size
out_sample_shape = (batch_size, vocab_size)
out_state_shape = (num_layers, 2, batch_size, num_hidden)
#Sample input
inpt = "we have no useful information on"
cnt_sample = 20
def _pretty_print(items, is_char_model, id2word):
if not is_char_model:
return ' '.join([id2word[x] for x in items])
else:
return ''.join([id2word[x] for x in items]).replace('_', ' ')
def _get_tvm_graph_module(graph_def):
        # Cell inputs 'c' and 'h' hold the state values for all layers
shape_dict = {'Model/Placeholder': (batch_size, num_steps),
'Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell_c':(num_layers, batch_size, num_hidden),
'Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell_h':(num_layers, batch_size, num_hidden)}
sym, params = relay.frontend.from_tensorflow(graph_def, shape=shape_dict)
dtype_dict = {'Model/Placeholder': 'int32',
'Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell_c':'float32',
'Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell_h':'float32'}
target = 'llvm'
with relay.build_config(opt_level=0):
graph, lib, params = relay.build(sym, target, params=params)
from tvm.contrib import graph_runtime
ctx = tvm.cpu(0)
return params, graph_runtime.create(graph, lib, ctx)
def _do_tvm_sample(model, data, in_states, params, num_samples):
"""Sampled from the model"""
samples = []
state = in_states
sample = None
def _get_sample(data, state):
input_data = np.full((batch_size, num_steps), data, dtype="int32")
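            # The packed state holds (c, h) for all layers; split and reshape
            # it back before setting the LSTMBlockCell state inputs.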
in_state_tup = np.split(state, indices_or_sections=2, axis=1)
in_state_c = np.reshape(in_state_tup[0], (num_layers, batch_size, num_hidden))
in_state_h = np.reshape(in_state_tup[1], (num_layers, batch_size, num_hidden))
model.set_input('Model/Placeholder', tvm.nd.array(input_data.astype("int32")))
model.set_input('Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell_c',
tvm.nd.array(in_state_c.astype("float32")))
model.set_input('Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell_h',
tvm.nd.array(in_state_h.astype("float32")))
model.set_input(**params)
model.run()
tvm_output = model.get_output(0, tvm.nd.empty(out_sample_shape,
"float32")).asnumpy()
state_output = model.get_output(1, tvm.nd.empty(out_state_shape,
"float32")).asnumpy()
sample = tf_testing.pick_from_weight(tvm_output[0])
return sample, state_output
for x in data:
sample, state = _get_sample(x, state)
if sample is not None:
samples.append(sample)
else:
samples.append(0)
k = 1
while k < num_samples:
sample, state = _get_sample(samples[-1], state)
samples.append(sample)
k += 1
return samples, state
with tf.Graph().as_default():
word_to_id, id_to_word, graph_def = tf_testing.get_workload_ptb()
vocab_size = len(word_to_id)
# Call the utility to import the graph definition into default graph.
graph_def = tf_testing.ProcessGraphDefParam(graph_def)
sess = tf.Session()
#TVM graph module creation
params, m = _get_tvm_graph_module(graph_def)
        # Create 10 predicted statements of 20 words each
cnt_stm = 0
while cnt_stm < 10:
cnt_stm += 1
in_state = np.full((num_layers, 2, batch_size, num_hidden), 0, dtype="float32")
seed_for_sample = inpt.split()
tvm_samples, tvm_state = _do_tvm_sample(m, [word_to_id[word] \
for word in seed_for_sample],
in_state, params, cnt_sample)
tvm_sample_str = _pretty_print(tvm_samples, False, id_to_word)
tf_samples, tf_state = tf_testing.do_tf_sample(sess,
[word_to_id[word] for word in seed_for_sample],
in_state, cnt_sample)
tf_sample_str = _pretty_print(tf_samples, False, id_to_word)
inpt = tvm_sample_str
tvm.testing.assert_allclose(tf_samples, tvm_samples, rtol=1e-5, atol=1e-5)
            assert tvm_sample_str == tf_sample_str
#######################################################################
# LRN (Local Response Normalization)
# ----------------------------------
def _test_lrn(ishape, size, axis, bias, alpha, beta):
""" testing local response normalization """
    # tensorflow's depth_radius counts neighbours on each side, so a window of
    # `size` corresponds to an integer radius of size // 2
    lrn_depth_radius = size // 2
inp_array = np.random.uniform(size=ishape).astype(np.float32)
with tf.Graph().as_default():
in1 = tf.placeholder(shape=inp_array.shape, dtype=inp_array.dtype, name="lrn0_data")
nn_ops.local_response_normalization(in1,
name="lrn",
depth_radius=lrn_depth_radius,
bias=bias,
alpha=alpha,
beta=beta)
compare_tf_with_tvm(inp_array, 'lrn0_data:0', 'lrn:0')
def test_forward_lrn():
_test_lrn((1, 3, 20, 20), 3, 1, 1.0, 1.0, 0.5)
#######################################################################
# l2_normalize
# ------------
def _test_l2_normalize(ishape, eps, axis):
""" testing l2 normalize (uses max, sum, square, sqrt frontend operators)"""
inp_array = np.random.uniform(size=ishape).astype(np.float32)
with tf.Graph().as_default():
in1 = tf.placeholder(shape=inp_array.shape, dtype=inp_array.dtype)
nn.l2_normalize(in1,
axis=axis,
epsilon=eps,
name=None,
dim=None)
compare_tf_with_tvm(inp_array, 'Placeholder:0', 'l2_normalize:0')
def test_forward_l2_normalize():
_test_l2_normalize((1, 3, 20, 20), 0.001, (0,))
#######################################################################
# transpose
# ---------
def _test_forward_transpose(ishape, axes=None):
data = np.random.uniform(size=ishape).astype(np.float32)
with tf.Graph().as_default():
in1 = tf.placeholder(shape=data.shape, dtype=data.dtype, name="transpose_data")
if axes is None:
tf.transpose(in1)
else:
tf.transpose(in1, perm=axes)
compare_tf_with_tvm(data, 'transpose_data:0', 'transpose:0')
def test_forward_transpose():
_test_forward_transpose((2, 3, 4), (1, 2, 0))
_test_forward_transpose((2, 3, 4))
_test_forward_transpose((7, 8, 8, 10))
_test_forward_transpose((2, 3, 4), (1, 2, 0))
_test_forward_transpose((2, 3, 4), (0, 1, 2))
_test_forward_transpose((2, 3, 4, 5), (3, 0, 1, 2))
def test_forward_ceil():
ishape = (1, 3, 10, 10)
inp_array = np.random.uniform(size=ishape).astype(np.float32)
with tf.Graph().as_default():
in1 = tf.placeholder(shape=inp_array.shape, dtype=inp_array.dtype)
tf.ceil(in1)
compare_tf_with_tvm(inp_array, 'Placeholder:0', 'Ceil:0')
def test_forward_floor():
ishape = (1, 3, 10, 10)
inp_array = np.random.uniform(size=ishape).astype(np.float32)
with tf.Graph().as_default():
in1 = tf.placeholder(shape=inp_array.shape, dtype=inp_array.dtype)
tf.floor(in1)
compare_tf_with_tvm(inp_array, 'Placeholder:0', 'Floor:0')
def test_forward_relu():
ishape = (1, 3, 10, 10)
inp_array = np.random.uniform(-5, 5, size=ishape).astype(np.float32)
with tf.Graph().as_default():
in1 = tf.placeholder(shape=inp_array.shape, dtype=inp_array.dtype)
tf.nn.relu(in1)
compare_tf_with_tvm(inp_array, 'Placeholder:0', 'Relu:0')
def test_forward_leaky_relu():
ishape = (1, 3, 10, 10)
inp_array = np.random.uniform(-5, 5, size=ishape).astype(np.float32)
with tf.Graph().as_default():
in1 = tf.placeholder(shape=inp_array.shape, dtype=inp_array.dtype)
tf.nn.leaky_relu(in1, alpha=0.4)
compare_tf_with_tvm(inp_array, 'Placeholder:0', 'LeakyRelu/mul:0')
def test_forward_elu():
ishape = (1, 3, 10, 10)
inp_array = np.random.uniform(-5, 5, size=ishape).astype(np.float32)
with tf.Graph().as_default():
in1 = tf.placeholder(shape=inp_array.shape, dtype=inp_array.dtype)
tf.nn.elu(in1)
compare_tf_with_tvm(inp_array, 'Placeholder:0', 'Elu:0')
def test_forward_selu():
ishape = (1, 3, 10, 10)
inp_array = np.random.uniform(-5, 5, size=ishape).astype(np.float32)
with tf.Graph().as_default():
in1 = tf.placeholder(shape=inp_array.shape, dtype=inp_array.dtype)
tf.nn.selu(in1)
compare_tf_with_tvm(inp_array, 'Placeholder:0', 'Selu:0')
def test_forward_tanh():
ishape = (1, 3, 10, 10)
inp_array = np.random.uniform(-5, 5, size=ishape).astype(np.float32)
with tf.Graph().as_default():
in1 = tf.placeholder(shape=inp_array.shape, dtype=inp_array.dtype)
tf.nn.tanh(in1)
compare_tf_with_tvm(inp_array, 'Placeholder:0', 'Tanh:0')
#######################################################################
# Mean
# ----
def test_forward_mean():
def check_mean(ishape, **kwargs):
inp_array = np.random.uniform(size=ishape).astype(np.float32)
with tf.Graph().as_default():
in1 = tf.placeholder(shape=inp_array.shape, dtype=inp_array.dtype)
tf.keras.backend.mean(in1, **kwargs)
compare_tf_with_tvm(inp_array, 'Placeholder:0', 'Mean:0', no_gpu=True)
check_mean((10, 8, 16, 32))
check_mean((10, 8, 16, 32), axis=(2,3))
check_mean((10, 8, 16, 32), axis=(1,2), keepdims=True)
#######################################################################
# Relational operators
# --------------------
def _test_forward_rel_op(data, func):
with tf.Graph().as_default():
in1 = tf.placeholder(shape=data[0].shape, dtype=data[0].dtype, name='in1')
in2 = tf.placeholder(shape=data[1].shape, dtype=data[1].dtype, name='in2')
op = func(in1, in2, name='op')
out = tf.cast(op, tf.int32, name='out1')
compare_tf_with_tvm([data[0], data[1]], ['in1:0', 'in2:0'], 'out1:0')
def test_forward_rel_ops():
t1 = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
t2 = np.array([[9, 8, 7], [6, 5, 4], [3, 2, 1]])
_test_forward_rel_op([t1, t2], math_ops.less)
_test_forward_rel_op([t1, t2], math_ops.greater)
_test_forward_rel_op([t1, t2], math_ops.less_equal)
_test_forward_rel_op([t1, t2], math_ops.greater_equal)
_test_forward_rel_op([t1, t2], math_ops.equal)
_test_forward_rel_op([t1, t2], math_ops.not_equal)
#######################################################################
# Main
# ----
if __name__ == '__main__':
# Transforms
test_forward_transpose()
test_forward_reshape()
test_forward_squeeze()
test_forward_pack()
test_forward_resize_bilinear()
test_forward_pad()
test_forward_gather()
test_forward_stridedslice()
# Activations
test_forward_sigmoid()
test_forward_relu()
test_forward_leaky_relu()
test_forward_elu()
test_forward_selu()
test_forward_tanh()
# Reductions
test_forward_argminmax()
test_forward_reduce()
test_forward_mean()
# General
test_forward_multi_input()
test_forward_multi_output()
test_forward_variable()
# NN
test_forward_convolution()
test_forward_pooling()
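    # _test_forward_concat_v2 exercises the private gen_array_ops._concat_v2
    # API, so it is only run against the tensorflow version it was written for.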
if tf.__version__ == '1.4.1':
_test_forward_concat_v2()
test_forward_lrn()
test_forward_l2_normalize()
# End to End
test_forward_inception_v3()
test_forward_inception_v1()
test_forward_mobilenet()
test_forward_resnetv2()
test_forward_ptb()
# RNN
test_forward_lstm()
# Elementwise
test_forward_ceil()
test_forward_floor()
# Relational ops
test_forward_rel_ops()
......@@ -16,7 +16,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import variables
from tensorflow.contrib.lite.python import interpreter as interpreter_wrapper
import nnvm.testing.tf
import tvm.relay.testing.tf as tf_testing
#######################################################################
# Generic run functions for TVM & TFLite
......@@ -344,7 +344,7 @@ def test_forward_mobilenet():
'''test mobilenet v1 tflite model'''
# MobilenetV1
temp = util.tempdir()
tflite_model_file = nnvm.testing.tf.get_workload_official(
tflite_model_file = tf_testing.get_workload_official(
"http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224.tgz",
"mobilenet_v1_1.0_224.tflite", temp)
tflite_model_buf = open(tflite_model_file, "rb").read()
......
......@@ -42,6 +42,9 @@ python3 -m nose -v tests/python/frontend/onnx || exit -1
echo "Running relay CoreML frondend test..."
python3 -m nose -v tests/python/frontend/coreml || exit -1
echo "Running relay Tensorflow frontend test..."
python3 -m nose -v tests/python/frontend/tensorflow || exit -1
echo "Running nnvm to relay frontend test..."
python3 -m nose -v tests/python/frontend/nnvm_to_relay || exit -1
......@@ -50,4 +53,3 @@ python3 -m nose -v tests/python/frontend/tflite || exit -1
echo "Running relay caffe2 frondend test..."
python3 -m nose -v tests/python/frontend/caffe2 || exit -1
"""
Compile Tensorflow Models
=========================
This article is an introductory tutorial to deploy tensorflow models with TVM.
To begin, the tensorflow python module must be installed.
Please refer to https://www.tensorflow.org/install
"""
# tvm, relay
import tvm
from tvm import relay
# os and numpy
import numpy as np
import os.path
# Tensorflow imports
import tensorflow as tf
# Tensorflow utility functions
import tvm.relay.testing.tf as tf_testing
# Base location for model related files.
repo_base = 'https://github.com/dmlc/web-data/raw/master/tensorflow/models/InceptionV1/'
# Test image
img_name = 'elephant-299.jpg'
image_url = os.path.join(repo_base, img_name)
######################################################################
# Tutorials
# ---------
# Please refer to docs/frontend/tensorflow.md for more details on various
# tensorflow models.
model_name = 'classify_image_graph_def-with_shapes.pb'
model_url = os.path.join(repo_base, model_name)
# Image label map
map_proto = 'imagenet_2012_challenge_label_map_proto.pbtxt'
map_proto_url = os.path.join(repo_base, map_proto)
# Human readable text for labels
label_map = 'imagenet_synset_to_human_label_map.txt'
label_map_url = os.path.join(repo_base, label_map)
# Target settings
# Use these commented settings to build for cuda.
#target = 'cuda'
#target_host = 'llvm'
#layout = "NCHW"
#ctx = tvm.gpu(0)
target = 'llvm'
target_host = 'llvm'
layout = None
ctx = tvm.cpu(0)
######################################################################
# Download required files
# -----------------------
# Download files listed above.
from mxnet.gluon.utils import download
download(image_url, img_name)
download(model_url, model_name)
download(map_proto_url, map_proto)
download(label_map_url, label_map)
######################################################################
# Import model
# ------------
# Creates tensorflow graph definition from protobuf file.
with tf.gfile.FastGFile(os.path.join("./", model_name), 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
graph = tf.import_graph_def(graph_def, name='')
# Call the utility to import the graph definition into default graph.
graph_def = tf_testing.ProcessGraphDefParam(graph_def)
# Add shapes to the graph.
with tf.Session() as sess:
graph_def = tf_testing.AddShapesToGraphDef(sess, 'softmax')
######################################################################
# Decode image
# ------------
# .. note::
#
#   tensorflow frontend import doesn't support preprocessing ops like JpegDecode.
#   JpegDecode is bypassed (it just returns the source node), so we supply the
#   already-decoded frame to TVM instead.
#
from PIL import Image
image = Image.open(img_name).resize((299, 299))
x = np.array(image)
######################################################################
# Import the graph to Relay
# -------------------------
# Import tensorflow graph definition to relay frontend.
#
# Results:
# sym: relay expr for given tensorflow protobuf.
# params: params converted from tensorflow params (tensor protobuf).
shape_dict = {'DecodeJpeg/contents': x.shape}
dtype_dict = {'DecodeJpeg/contents': 'uint8'}
sym, params = relay.frontend.from_tensorflow(graph_def, layout=layout, shape=shape_dict)
print ("Tensorflow protobuf imported to relay frontend.")
######################################################################
# Relay Build
# -----------
# Compile the graph to llvm target with given input specification.
#
# Results:
# graph: Final graph after compilation.
# params: final params after compilation.
# lib: target library which can be deployed on target with tvm runtime.
with relay.build_config(opt_level=3):
graph, lib, params = relay.build(sym, target=target, target_host=target_host, params=params)
######################################################################
# Execute the portable graph on TVM
# ---------------------------------
# Now we can try deploying the compiled model on target.
from tvm.contrib import graph_runtime
dtype = 'uint8'
m = graph_runtime.create(graph, lib, ctx)
# set inputs
m.set_input('DecodeJpeg/contents', tvm.nd.array(x.astype(dtype)))
m.set_input(**params)
# execute
m.run()
# get outputs
tvm_output = m.get_output(0, tvm.nd.empty((1, 1008), 'float32'))
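# The classify_image InceptionV1 graph produces 1008 class scores per image,
# which is why the output buffer above is shaped (1, 1008).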
######################################################################
# Process the output
# ------------------
# Process the model output to human readable text for InceptionV1.
predictions = tvm_output.asnumpy()
predictions = np.squeeze(predictions)
# Creates node ID --> English string lookup.
node_lookup = tf_testing.NodeLookup(label_lookup_path=os.path.join("./", map_proto),
                                    uid_lookup_path=os.path.join("./", label_map))
# Print top 5 predictions from TVM output.
top_k = predictions.argsort()[-5:][::-1]
for node_id in top_k:
human_string = node_lookup.id_to_string(node_id)
score = predictions[node_id]
print('%s (score = %.5f)' % (human_string, score))
######################################################################
# Inference on tensorflow
# -----------------------
# Run the corresponding model on tensorflow
def create_graph():
"""Creates a graph from saved GraphDef file and returns a saver."""
# Creates graph from saved graph_def.pb.
with tf.gfile.FastGFile(model_name, 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
graph = tf.import_graph_def(graph_def, name='')
# Call the utility to import the graph definition into default graph.
graph_def = tf_testing.ProcessGraphDefParam(graph_def)
def run_inference_on_image(image):
"""Runs inference on an image.
Parameters
----------
image: String
Image file name.
Returns
-------
Nothing
"""
if not tf.gfile.Exists(image):
tf.logging.fatal('File does not exist %s', image)
image_data = tf.gfile.FastGFile(image, 'rb').read()
# Creates graph from saved GraphDef.
create_graph()
with tf.Session() as sess:
softmax_tensor = sess.graph.get_tensor_by_name('softmax:0')
predictions = sess.run(softmax_tensor,
{'DecodeJpeg/contents:0': image_data})
predictions = np.squeeze(predictions)
# Creates node ID --> English string lookup.
        node_lookup = tf_testing.NodeLookup(label_lookup_path=os.path.join("./", map_proto),
                                            uid_lookup_path=os.path.join("./", label_map))
# Print top 5 predictions from tensorflow.
top_k = predictions.argsort()[-5:][::-1]
print ("===== TENSORFLOW RESULTS =======")
for node_id in top_k:
human_string = node_lookup.id_to_string(node_id)
score = predictions[node_id]
print('%s (score = %.5f)' % (human_string, score))
run_inference_on_image(img_name)
......@@ -23,7 +23,7 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_util
# Tensorflow utility functions
import nnvm.testing.tf
import tvm.relay.testing.tf as tf_testing
# Base location for model related files.
repo_base = 'https://github.com/dmlc/web-data/raw/master/tensorflow/models/InceptionV1/'
......@@ -87,10 +87,10 @@ with tf.gfile.FastGFile(os.path.join("./", model_name), 'rb') as f:
graph_def.ParseFromString(f.read())
graph = tf.import_graph_def(graph_def, name='')
# Call the utility to import the graph definition into default graph.
graph_def = nnvm.testing.tf.ProcessGraphDefParam(graph_def)
graph_def = tf_testing.ProcessGraphDefParam(graph_def)
# Add shapes to the graph.
with tf.Session() as sess:
graph_def = nnvm.testing.tf.AddShapesToGraphDef(sess, 'softmax')
graph_def = tf_testing.AddShapesToGraphDef(sess, 'softmax')
######################################################################
# Decode image
......@@ -157,7 +157,7 @@ predictions = tvm_output.asnumpy()
predictions = np.squeeze(predictions)
# Creates node ID --> English string lookup.
node_lookup = nnvm.testing.tf.NodeLookup(label_lookup_path=os.path.join("./", map_proto),
node_lookup = tf_testing.NodeLookup(label_lookup_path=os.path.join("./", map_proto),
uid_lookup_path=os.path.join("./", lable_map))
# Print top 5 predictions from TVM output.
......@@ -180,7 +180,7 @@ def create_graph():
graph_def.ParseFromString(f.read())
graph = tf.import_graph_def(graph_def, name='')
# Call the utility to import the graph definition into default graph.
graph_def = nnvm.testing.tf.ProcessGraphDefParam(graph_def)
graph_def = tf_testing.ProcessGraphDefParam(graph_def)
def run_inference_on_image(image):
"""Runs inference on an image.
......@@ -209,7 +209,7 @@ def run_inference_on_image(image):
predictions = np.squeeze(predictions)
# Creates node ID --> English string lookup.
node_lookup = nnvm.testing.tf.NodeLookup(label_lookup_path=os.path.join("./", map_proto),
node_lookup = tf_testing.NodeLookup(label_lookup_path=os.path.join("./", map_proto),
uid_lookup_path=os.path.join("./", lable_map))
# Print top 5 predictions from tensorflow.
......