Commit 6ad7ce8b by 在原佐为 Committed by Siva

Add CONCATENATION to tflite frontend, support Inception V3 (#2643)

* Add CONCATENATION to tflite frontend

* Fix typo

* Fix codestyle

* Fix code style

* Simplify convert map

* Update
parent a1b86100
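
For context, a minimal sketch of how this frontend is driven end to end (mirroring run_tvm_graph in the tests below; the file name is an assumption, and the input shape is given in NCHW because the converter transposes TFLite's NHWC layout):

    import tflite.Model
    from tvm import relay

    # Parse the TFLite flatbuffer and convert it to Relay. The 'input' name and
    # the (1, 3, 299, 299) NCHW shape follow the Inception V3 test added below.
    with open("inception_v3.tflite", "rb") as f:
        tflite_model = tflite.Model.Model.GetRootAsModel(f.read(), 0)

    func, params = relay.frontend.from_tflite(
        tflite_model,
        shape_dict={'input': (1, 3, 299, 299)},
        dtype_dict={'input': 'float32'})
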
@@ -35,6 +35,8 @@ class OperatorConverter(object):
        self.builtin_op_code = build_str_map(BuiltinOperator())
        self.activation_fn_type = build_str_map(ActivationFunctionType())
        self.builtin_options = build_str_map(BuiltinOptions())
        # Add more operators
        self.convert_map = {
            'CONV_2D': self.convert_conv2d,
            'DEPTHWISE_CONV_2D': self.convert_depthwise_conv2d,
@@ -43,7 +45,7 @@ class OperatorConverter(object):
            'SOFTMAX': self.convert_softmax,
            'SQUEEZE': self.convert_squeeze,
            'MAX_POOL_2D': self.convert_max_pool2d,
            # Add more operators
            "CONCATENATION": self.convert_concatenation
        }

    def check_unsupported_ops(self):
@@ -245,6 +247,48 @@ class OperatorConverter(object):
        return out

    def convert_concatenation(self, op):
        """Convert TFLite concatenation"""
        try:
            from tflite.Operator import Operator
            from tflite.ConcatenationOptions import ConcatenationOptions
            from tflite.BuiltinOptions import BuiltinOptions
            from tflite.ActivationFunctionType import ActivationFunctionType
        except ImportError:
            raise ImportError("The tflite package must be installed")

        assert isinstance(op, Operator)
        input_tensors = self.get_input_tensors(op)
        assert len(input_tensors) >= 1, "input tensors should not be empty"
        in_exprs = [self.get_expr(input_tensor.tensor_idx) for input_tensor in input_tensors]

        output_tensors = self.get_output_tensors(op)
        assert len(output_tensors) == 1, "output tensors length should be 1"

        assert op.BuiltinOptionsType() == BuiltinOptions.ConcatenationOptions
        op_options = op.BuiltinOptions()
        concatenation_options = ConcatenationOptions()
        concatenation_options.Init(op_options.Bytes, op_options.Pos)
        concatenation_axis = concatenation_options.Axis()
        fused_activation_fn = concatenation_options.FusedActivationFunction()
        input_shape_length = len(input_tensors[0].tensor.ShapeAsNumpy())

        # TFLite is N H W C, our layout is N C H W
        if input_shape_length <= 4:
            axis_convert_map = [0] + list(range(2, input_shape_length)) + [1]
            concatenation_axis = axis_convert_map[concatenation_axis]
        else:
            raise NotImplementedError(
                "Input shape length {} is not supported for concatenation".format(
                    input_shape_length))

        # concatenate with the converted (N C H W) axis
        out = _op.concatenate(in_exprs, axis=concatenation_axis)

        # fused activation function, if any
        if fused_activation_fn != ActivationFunctionType.NONE:
            out = self.convert_fused_activation_function(out, fused_activation_fn)
        return out
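
    # Worked example of the axis remap above: for a 4-D NHWC input,
    # axis_convert_map == [0, 2, 3, 1], so a TFLite concat over channels
    # (NHWC axis 3) becomes axis 1 in NCHW, and a concat over height
    # (NHWC axis 1) becomes axis 2.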
    def convert_squeeze(self, op):
        """Convert TFLite squeeze"""
        try:
            ...
@@ -284,6 +284,53 @@ def test_forward_reshape():
#######################################################################
# Concatenation
# -------------
def _test_concatenation(data, axis):
    """ One iteration of concatenation """

    assert len(data) >= 1
    need_transpose = False
    if len(data[0].shape) == 1 or len(data[0].shape) == 2:
        tvm_data = data
    elif len(data[0].shape) == 3:
        # need_transpose = True
        tvm_data = [np.transpose(d, axes=(0, 2, 1)) for d in data]
    elif len(data[0].shape) == 4:
        need_transpose = True
        # NHWC -> NCHW
        tvm_data = [np.transpose(d, axes=(0, 3, 1, 2)) for d in data]
    else:
        raise NotImplementedError(
            "Input shape length {} is not supported for concatenation".format(
                len(data[0].shape)))

    with tf.Graph().as_default():
        in_data = [
            array_ops.placeholder(shape=tensor.shape, dtype=tensor.dtype, name="in_{}".format(idx))
            for idx, tensor in enumerate(data)]
        out = array_ops.concat(in_data, axis=axis)
        name = ["in_{}:0".format(idx) for idx in range(len(data))]

        compare_tflite_with_tvm(data, tvm_data, name, in_data, [out], need_transpose)
def test_forward_concatenation():

    _test_concatenation(
        [np.arange(6).reshape((1, 2, 1, 3)),
         np.arange(6).reshape((1, 2, 1, 3))], 1)

    _test_concatenation(
        [np.arange(6).reshape((3, 2)),
         np.arange(6).reshape((3, 2))], 1)

    _test_concatenation(
        [np.arange(6).reshape((2, 1, 1, 3)),
         np.arange(6).reshape((2, 1, 1, 3)),
         np.arange(6).reshape((2, 1, 1, 3))], 1)
#######################################################################
# Squeeze
# -------
@@ -340,6 +387,7 @@ def test_forward_softmax():
#######################################################################
# Mobilenet
# ---------
def test_forward_mobilenet():
    '''test mobilenet v1 tflite model'''
    # MobilenetV1
@@ -347,19 +395,43 @@ def test_forward_mobilenet():
    tflite_model_file = tf_testing.get_workload_official(
        "http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224.tgz",
        "mobilenet_v1_1.0_224.tflite", temp)
    with open(tflite_model_file, "rb") as f:
        tflite_model_buf = f.read()
    data = np.random.uniform(size=(1, 224, 224, 3)).astype('float32')
    tvm_data = np.transpose(data, axes=(0, 3, 1, 2))
    tflite_output = run_tflite_graph(tflite_model_buf, data)
    tvm_output = run_tvm_graph(tflite_model_buf, tvm_data, 'input')
    tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]),
                                rtol=1e-5, atol=1e-5)
    temp.remove()
#######################################################################
# Inception V3
# ------------
def test_forward_inception_v3_net():
    '''test inception v3 tflite model'''
    # InceptionV3
    temp = util.tempdir()
    tflite_model_file = tf_testing.get_workload_official(
        "https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/inception_v3_2018_04_27.tgz",
        "inception_v3.tflite", temp)
    with open(tflite_model_file, "rb") as f:
        tflite_model_buf = f.read()
    data = np.random.uniform(size=(1, 299, 299, 3)).astype('float32')
    tvm_data = np.transpose(data, axes=(0, 3, 1, 2))
    tflite_output = run_tflite_graph(tflite_model_buf, data)
    tvm_output = run_tvm_graph(tflite_model_buf, tvm_data, 'input')
    tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]),
                                rtol=1e-5, atol=1e-5)
    temp.remove()
#######################################################################
# Main
# ----
if __name__ == '__main__':
    # Transforms
    test_forward_concatenation()
    test_forward_reshape()
    test_forward_squeeze()
@@ -370,3 +442,4 @@ if __name__ == '__main__':
    # End to End
    test_forward_mobilenet()
    test_forward_inception_v3_net()