Commit e68874d6 by Sunwoong Joo Committed by Tianqi Chen

[Relay][Frontend] Adding ADD operator to tflite frontend for compiling the MobileNetV2 (#2919)

parent eb82e7b7
......@@ -258,6 +258,9 @@ class ExprTable(object):
if name not in self.exprs:
self.exprs[name] = expr
def has_expr(self, name):
return True if name in self.exprs else False
def set_padding(self, paddings):
self.paddings = paddings
self.in_padding = True
......
......@@ -46,7 +46,8 @@ class OperatorConverter(object):
'SOFTMAX': self.convert_softmax,
'SQUEEZE': self.convert_squeeze,
'MAX_POOL_2D': self.convert_max_pool2d,
"CONCATENATION": self.convert_concatenation
'CONCATENATION': self.convert_concatenation,
'ADD': self.convert_add
}
def check_unsupported_ops(self):
......@@ -292,6 +293,49 @@ class OperatorConverter(object):
out = self.convert_fused_activation_function(out, fused_activation_fn)
return out
def convert_add(self, op):
"""Convert TFLite add"""
try:
from tflite.Operator import Operator
except ImportError:
raise ImportError("The tflite package must be installed")
assert isinstance(op, Operator)
input_tensors = self.get_input_tensors(op)
assert len(input_tensors) == 2, "input tensors length should be 2"
lhs_tensor = input_tensors[0]
lhs_expr = self.get_expr(lhs_tensor.tensor_idx)
rhs_tensor = input_tensors[1]
if self.has_expr(rhs_tensor.tensor_idx):
# In most cases, we can assume that TOCO fuses ADD operators
# with constants - it means both will be tensors.
rhs_expr = self.get_expr(rhs_tensor.tensor_idx)
else:
# However, in some corner cases, the ADD operator is not fused,
# we can receive as constant.
rhs_type_str = self.get_tensor_type_str(rhs_tensor.tensor.Type())
rhs_expr = self.exp_tab.new_const(self.get_tensor_value(rhs_tensor),
dtype=rhs_type_str)
# In this case, we have to be careful about formatting.
input_shape_length = len(rhs_tensor.tensor.ShapeAsNumpy())
if input_shape_length in (1, 2):
pass
elif input_shape_length == 3:
# N H*W C to N C H*W
rhs_expr = _op.transpose(rhs_expr, axes=(0, 2, 1))
elif input_shape_length == 4:
# N H W C to N C H W
rhs_expr = _op.transpose(rhs_expr, axes=(0, 3, 1, 2))
else:
msg = 'Input shape length {} for operator ADD is not valid.'
raise tvm.error.OpAttributeInvalid(msg.format(input_shape_length))
out = _op.add(lhs_expr, rhs_expr)
return out
def convert_squeeze(self, op):
"""Convert TFLite squeeze"""
try:
......@@ -554,6 +598,9 @@ class OperatorConverter(object):
def get_expr(self, input_tensor_idx):
return self.exp_tab.get_expr(get_tensor_name(self.subgraph, input_tensor_idx))
def has_expr(self, input_tensor_idx):
return self.exp_tab.has_expr(get_tensor_name(self.subgraph, input_tensor_idx))
def build_str_map(obj):
"""Build string map of TFLite enum int value
......
......@@ -11,6 +11,8 @@ from tvm import relay
from tvm.contrib import util
import tensorflow as tf
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import variables
......@@ -99,7 +101,7 @@ def run_tflite_graph(tflite_model_buf, input_data):
def compare_tflite_with_tvm(tflite_in_data, tvm_in_data, in_name, input_tensors,
output_tensors, output_need_transpose_nchw=False,
output_tensors, output_need_transpose=False,
init_global_variables=False):
"""Generic function to generate and compare TFLite and TVM output"""
tflite_in_data = convert_to_list(tflite_in_data)
......@@ -126,9 +128,19 @@ def compare_tflite_with_tvm(tflite_in_data, tvm_in_data, in_name, input_tensors,
tvm_output = run_tvm_graph(tflite_model_buffer, tvm_in_data, in_node, target=device)
for i in range(len(tflite_output)):
if output_need_transpose_nchw:
if output_need_transpose:
dim = len(tvm_output[i].shape)
if dim == 3:
# N C H*W to N H*W C
axes = (0, 2, 1)
elif dim == 4:
# N C H W to N H W C
axes = (0, 2, 3, 1)
else:
raise NotImplementedError("Not support input shape {} of transpose : ".
format(str(dim)))
tvm.testing.assert_allclose(tflite_output[i],
np.transpose(tvm_output[i], axes=(0, 2, 3, 1)),
np.transpose(tvm_output[i], axes=axes),
atol=1e-5, rtol=1e-5)
else:
tvm.testing.assert_allclose(tflite_output[i], tvm_output[i],
......@@ -152,7 +164,7 @@ def _test_pooling_iteration(input_shape, **kwargs):
out = nn_ops.pool(in_data, **kwargs)
compare_tflite_with_tvm(x, tvm_data, 'Placeholder:0', [in_data], [out],
output_need_transpose_nchw=True)
output_need_transpose=True)
def _test_pooling(input_shape, **kwargs):
......@@ -236,7 +248,7 @@ def _test_convolution(tensor_in_sizes, filter_in_sizes,
# TFLite output is NHWC, TVM is NCHW, we need transpose
compare_tflite_with_tvm(tflite_data_array, tvm_data_array,
'Placeholder:0', [in_data], [out],
output_need_transpose_nchw=True)
output_need_transpose=True)
def test_forward_convolution():
......@@ -331,6 +343,53 @@ def test_forward_concatenation():
#######################################################################
# Add
# ---
def _test_add(data):
""" One iteration of add """
assert len(data) == 2
need_transpose = False
if len(data[0].shape) == 1 or len(data[0].shape) == 2:
tvm_data = data
elif len(data[0].shape) == 3:
need_transpose = True
tvm_data = [np.transpose(d, axes=(0, 2, 1)) for d in data]
elif len(data[0].shape) == 4:
need_transpose = True
tvm_data = [np.transpose(d, axes=(0, 3, 1, 2)) for d in data]
else:
raise NotImplementedError("Not support input shape {} of add : ".
format(str(len(data.shape))))
# Test with two tensors
with tf.Graph().as_default():
in_data = [array_ops.placeholder(shape=data[0].shape, dtype=data[0].dtype, name='in_0'),
array_ops.placeholder(shape=data[1].shape, dtype=data[1].dtype, name='in_1')]
out = math_ops.add(in_data[0], in_data[1])
compare_tflite_with_tvm(data, tvm_data, ['in_0:0','in_1:0'],
in_data, [out], need_transpose)
# Test with tensor and constant
with tf.Graph().as_default():
in_data = [array_ops.placeholder(shape=data[0].shape, dtype=data[0].dtype, name='in')]
out = math_ops.add(in_data[0], ops.convert_to_tensor(data[1], dtype=data[1].dtype))
compare_tflite_with_tvm([data[0]], [tvm_data[0]], ['in:0'],
in_data, [out], need_transpose)
def test_forward_add():
""" Add """
_test_add([np.arange(6.0, dtype=np.float32).reshape((2, 1, 1, 3)),
np.arange(6.0, dtype=np.float32).reshape((2, 1, 1, 3))])
_test_add([np.arange(6.0, dtype=np.float32).reshape((2, 1, 3)),
np.arange(6.0, dtype=np.float32).reshape((2, 1, 3))])
_test_add([np.arange(3.0, dtype=np.float32).reshape((1, 3)),
np.arange(3.0, dtype=np.float32).reshape((1, 3))])
#######################################################################
# Squeeze
# -------
......@@ -388,7 +447,7 @@ def test_forward_softmax():
# Mobilenet
# ---------
def test_forward_mobilenet():
def test_forward_mobilenet_v1():
'''test mobilenet v1 tflite model'''
# MobilenetV1
tflite_model_file = tf_testing.get_workload_official(
......@@ -403,6 +462,21 @@ def test_forward_mobilenet():
tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]),
rtol=1e-5, atol=1e-5)
def test_forward_mobilenet_v2():
'''test mobilenet v2 tflite model'''
# MobilenetV2
tflite_model_file = tf_testing.get_workload_official(
"http://download.tensorflow.org/models/tflite_11_05_08/mobilenet_v2_1.0_224.tgz",
"mobilenet_v2_1.0_224.tflite")
with open(tflite_model_file, "rb") as f:
tflite_model_buf = f.read()
data = np.random.uniform(size=(1, 224, 224, 3)).astype('float32')
tvm_data = np.transpose(data, axes=(0, 3, 1, 2))
tflite_output = run_tflite_graph(tflite_model_buf, data)
tvm_output = run_tvm_graph(tflite_model_buf, tvm_data, 'input')
tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]),
rtol=1e-5, atol=1e-5)
#######################################################################
# Inception V3
# ------------
......@@ -436,6 +510,10 @@ if __name__ == '__main__':
test_forward_pooling()
test_forward_softmax()
# Math
test_forward_add()
# End to End
test_forward_mobilenet()
test_forward_mobilenet_v1()
test_forward_mobilenet_v2()
test_forward_inception_v3_net()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment