Commit 78e0871d by songqun Committed by Jared Roesch

[FRONTEND][TFLITE] Add FULLY_CONNECTED op into tflite frontend, support Inception V4 (#3019)

* Add FULLY_CONNECTED op into tflite frontend, support Inception V4

* Fix comment style in TF Lite tests.
parent e6ca91e1
...@@ -63,7 +63,8 @@ class OperatorConverter(object): ...@@ -63,7 +63,8 @@ class OperatorConverter(object):
'SQUEEZE': self.convert_squeeze, 'SQUEEZE': self.convert_squeeze,
'MAX_POOL_2D': self.convert_max_pool2d, 'MAX_POOL_2D': self.convert_max_pool2d,
'CONCATENATION': self.convert_concatenation, 'CONCATENATION': self.convert_concatenation,
'ADD': self.convert_add 'ADD': self.convert_add,
'FULLY_CONNECTED': self.convert_fully_connected,
} }
def check_unsupported_ops(self): def check_unsupported_ops(self):
...@@ -352,6 +353,71 @@ class OperatorConverter(object): ...@@ -352,6 +353,71 @@ class OperatorConverter(object):
out = _op.add(lhs_expr, rhs_expr) out = _op.add(lhs_expr, rhs_expr)
return out return out
def convert_fully_connected(self, op):
"""Convert TFLite fully connected"""
try:
from tflite.Operator import Operator
from tflite.FullyConnectedOptions import FullyConnectedOptions
from tflite.BuiltinOptions import BuiltinOptions
from tflite.TensorType import TensorType
from tflite.ActivationFunctionType import ActivationFunctionType
except ImportError:
raise ImportError("The tflite package must be installed")
assert isinstance(op, Operator)
input_tensors = self.get_input_tensors(op)
assert len(input_tensors) >= 2, "input tensors length should be >= 2"
input_tensor = input_tensors[0]
input_tensor_idx = input_tensor.tensor_idx
weight_tensor = input_tensors[1]
input_tensor_shape = input_tensor.tensor.ShapeAsNumpy()
weight_tensor_shape = weight_tensor.tensor.ShapeAsNumpy()
# reshape input tensor from N H W C to N H*W*C
input_size_per_batch = 1
for s in range(1, len(input_tensor_shape)):
input_size_per_batch *= input_tensor_shape[s]
assert input_size_per_batch == weight_tensor_shape[1], \
"input size and weight size are mismatched"
target_shape = tuple((input_tensor_shape[0], input_size_per_batch))
in_expr = self.get_expr(input_tensor_idx)
in_expr = _op.reshape(in_expr, target_shape)
assert op.BuiltinOptionsType() == BuiltinOptions.FullyConnectedOptions
op_options = op.BuiltinOptions()
fully_connected_options = FullyConnectedOptions()
fully_connected_options.Init(op_options.Bytes, op_options.Pos)
fused_activation_fn = fully_connected_options.FusedActivationFunction()
# weight tensor type should be UINT8 (quantization) or FLOAT32
weight_tensor_type = weight_tensor.tensor.Type()
assert weight_tensor_type in (TensorType.UINT8, TensorType.FLOAT32)
weight_tensor_type_str = self.get_tensor_type_str(weight_tensor_type)
weight_value = self.get_tensor_value(weight_tensor)
weight_expr = self.exp_tab.new_const(weight_value, dtype=weight_tensor_type_str)
out = _op.nn.dense(in_expr, weight_expr)
# if we have bias
if len(input_tensors) == 3:
bias_tensor = input_tensors[2]
bias_tensor_type = bias_tensor.tensor.Type()
# bias tensor type should be INT32 (quantization) or FLOAT32
assert bias_tensor_type in (TensorType.INT32, TensorType.FLOAT32)
bias_tensor_type_str = self.get_tensor_type_str(bias_tensor_type)
bias_expr = self.exp_tab.new_const(self.get_tensor_value(bias_tensor),
dtype=bias_tensor_type_str)
out = _op.nn.bias_add(out, bias_expr)
# If we have fused activations
if fused_activation_fn != ActivationFunctionType.NONE:
out = self.convert_fused_activation_function(out, fused_activation_fn)
return out
def convert_squeeze(self, op): def convert_squeeze(self, op):
"""Convert TFLite squeeze""" """Convert TFLite squeeze"""
try: try:
......
...@@ -459,12 +459,63 @@ def test_forward_softmax(): ...@@ -459,12 +459,63 @@ def test_forward_softmax():
""" Softmax """ """ Softmax """
_test_softmax(np.arange(6.0, dtype=np.float32).reshape((1, 6))) _test_softmax(np.arange(6.0, dtype=np.float32).reshape((1, 6)))
#######################################################################
# Fully Connected
# -------
def _test_fully_connected(tensor_in_sizes, filter_in_sizes, bias_in_size=None):
""" One iteration of fully connected """
total_size_1 = 1
total_size_2 = 1
for s in tensor_in_sizes:
total_size_1 *= s
for s in filter_in_sizes:
total_size_2 *= s
# Initializes the input tensor with array containing incrementing
# numbers from 1.
data_array = [f * 1.0 for f in range(1, total_size_1 + 1)]
filter_array = [f * 1.0 for f in range(1, total_size_2 + 1)]
assert int(total_size_1 / tensor_in_sizes[0]) == filter_in_sizes[0], \
"input size and filter size are mismatched"
with tf.Graph().as_default():
in_data = array_ops.placeholder(shape=tensor_in_sizes, dtype='float32')
in_filter = constant_op.constant(filter_array, shape=filter_in_sizes, dtype='float32')
# reshape N H W C into N H*W*C
in_data_reshape = array_ops.reshape(in_data, [tensor_in_sizes[0], -1])
out = math_ops.mat_mul(in_data_reshape, in_filter)
# if we have bias
if bias_in_size:
assert bias_in_size[0] == filter_in_sizes[1], "bias and filter size are mismatched"
bias_array = [f * 1.0 for f in range(1, bias_in_size[0] + 1)]
in_bias = constant_op.constant(bias_array, shape=bias_in_size, dtype='float32')
out = nn_ops.bias_add(out, in_bias)
tflite_data_array = np.reshape(data_array, tensor_in_sizes).astype('float32')
tvm_data_array = np.transpose(tflite_data_array, axes=(0, 3, 1, 2))
compare_tflite_with_tvm(tflite_data_array, tvm_data_array,
'Placeholder:0', [in_data], [out])
def test_forward_fully_connected():
""" Fully Connected """
_test_fully_connected([1, 1, 1, 150], [150, 100])
_test_fully_connected([1, 1, 1, 150], [150, 100], [100])
_test_fully_connected([5, 1, 1, 150], [150, 100])
_test_fully_connected([5, 1, 1, 150], [150, 100], [100])
####################################################################### #######################################################################
# Mobilenet # Mobilenet
# --------- # ---------
def test_forward_mobilenet_v1(): def test_forward_mobilenet_v1():
'''test mobilenet v1 tflite model''' """Test the Mobilenet V1 TF Lite model."""
# MobilenetV1 # MobilenetV1
tflite_model_file = tf_testing.get_workload_official( tflite_model_file = tf_testing.get_workload_official(
"http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224.tgz", "http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224.tgz",
...@@ -479,7 +530,7 @@ def test_forward_mobilenet_v1(): ...@@ -479,7 +530,7 @@ def test_forward_mobilenet_v1():
rtol=1e-5, atol=1e-5) rtol=1e-5, atol=1e-5)
def test_forward_mobilenet_v2(): def test_forward_mobilenet_v2():
'''test mobilenet v2 tflite model''' """Test the Mobilenet V2 TF Lite model."""
# MobilenetV2 # MobilenetV2
tflite_model_file = tf_testing.get_workload_official( tflite_model_file = tf_testing.get_workload_official(
"http://download.tensorflow.org/models/tflite_11_05_08/mobilenet_v2_1.0_224.tgz", "http://download.tensorflow.org/models/tflite_11_05_08/mobilenet_v2_1.0_224.tgz",
...@@ -494,11 +545,11 @@ def test_forward_mobilenet_v2(): ...@@ -494,11 +545,11 @@ def test_forward_mobilenet_v2():
rtol=1e-5, atol=1e-5) rtol=1e-5, atol=1e-5)
####################################################################### #######################################################################
# Inception V3 # Inception
# ------------ # ------------
def test_forward_inception_v3_net(): def test_forward_inception_v3_net():
'''test inception v3 tflite model''' """Test the Inception V3 TF Lite model."""
# InceptionV3 # InceptionV3
tflite_model_file = tf_testing.get_workload_official( tflite_model_file = tf_testing.get_workload_official(
"https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/inception_v3_2018_04_27.tgz", "https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/inception_v3_2018_04_27.tgz",
...@@ -512,6 +563,21 @@ def test_forward_inception_v3_net(): ...@@ -512,6 +563,21 @@ def test_forward_inception_v3_net():
tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]), tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]),
rtol=1e-5, atol=1e-5) rtol=1e-5, atol=1e-5)
def test_forward_inception_v4_net():
"""Test the Inception V4 TF Lite model."""
# InceptionV4
tflite_model_file = tf_testing.get_workload_official(
"https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/inception_v4_2018_04_27.tgz",
"inception_v4.tflite")
with open(tflite_model_file, "rb") as f:
tflite_model_buf = f.read()
data = np.random.uniform(size=(1, 299, 299, 3)).astype('float32')
tvm_data = np.transpose(data, axes=(0, 3, 1, 2))
tflite_output = run_tflite_graph(tflite_model_buf, data)
tvm_output = run_tvm_graph(tflite_model_buf, tvm_data, 'input')
tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]),
rtol=1e-5, atol=1e-5)
####################################################################### #######################################################################
# Main # Main
# ---- # ----
...@@ -525,6 +591,7 @@ if __name__ == '__main__': ...@@ -525,6 +591,7 @@ if __name__ == '__main__':
test_forward_convolution() test_forward_convolution()
test_forward_pooling() test_forward_pooling()
test_forward_softmax() test_forward_softmax()
test_forward_fully_connected()
# Math # Math
test_forward_add() test_forward_add()
...@@ -533,3 +600,4 @@ if __name__ == '__main__': ...@@ -533,3 +600,4 @@ if __name__ == '__main__':
test_forward_mobilenet_v1() test_forward_mobilenet_v1()
test_forward_mobilenet_v2() test_forward_mobilenet_v2()
test_forward_inception_v3_net() test_forward_inception_v3_net()
test_forward_inception_v4_net()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment