Unverified Commit 608e9458 by Samuel Committed by GitHub

[TFLITE] Hard Swish & MobilenetV3 model testing (#5239)

* [TFLITE] Hard Swish & MobilenetV3 model testing

* CI Failure addressed
parent 00a84813
...@@ -84,6 +84,7 @@ class OperatorConverter(object): ...@@ -84,6 +84,7 @@ class OperatorConverter(object):
'FULLY_CONNECTED': self.convert_fully_connected, 'FULLY_CONNECTED': self.convert_fully_connected,
'GREATER_EQUAL': self.convert_greater_equal, 'GREATER_EQUAL': self.convert_greater_equal,
'GREATER': self.convert_greater, 'GREATER': self.convert_greater,
'HARD_SWISH': self.convert_hard_swish,
'L2_NORMALIZATION': self.convert_l2_normalization, 'L2_NORMALIZATION': self.convert_l2_normalization,
'LESS_EQUAL': self.convert_less_equal, 'LESS_EQUAL': self.convert_less_equal,
'LESS': self.convert_less, 'LESS': self.convert_less,
...@@ -595,6 +596,42 @@ class OperatorConverter(object): ...@@ -595,6 +596,42 @@ class OperatorConverter(object):
return out return out
def convert_hard_swish(self, op):
    """Convert a TFLite HARD_SWISH operator to Relay.

    Implements hard_swish(x) = x * relu6(x + 3) / 6. Quantized
    tensors are dequantized before the computation and the result is
    requantized to the output tensor's quantization parameters.
    """
    try:
        from tflite.Operator import Operator
    except ImportError:
        raise ImportError("The tflite package must be installed")
    assert isinstance(op, Operator)

    input_tensors = self.get_input_tensors(op)
    assert len(input_tensors) == 1, "input tensors length should be 1"
    input_tensor = input_tensors[0]
    in_expr = self.get_expr(input_tensor.tensor_idx)

    output_tensors = self.get_output_tensors(op)
    assert len(output_tensors) == 1, "output tensors length should be 1"
    output_tensor = output_tensors[0]

    # Work in floating point: dequantize a quantized input up front.
    if input_tensor.qnn_params:
        in_expr = self.dequantize(in_expr, input_tensor)

    # hard_swish(x) = x * clip(x + 3, 0, 6) / 6
    shifted = _op.tensor.clip(in_expr + relay.const(3.0), 0.0, 6.0)
    out = in_expr * shifted / relay.const(6.0)

    # Return to integer dtype if the original operator was quantized.
    if output_tensor.qnn_params:
        out = self.quantize(out, output_tensor)

    return out
def convert_concatenation(self, op): def convert_concatenation(self, op):
"""Convert TFLite concatenation""" """Convert TFLite concatenation"""
try: try:
......
...@@ -1626,6 +1626,26 @@ def test_forward_mobilenet_v2(): ...@@ -1626,6 +1626,26 @@ def test_forward_mobilenet_v2():
rtol=1e-5, atol=1e-5) rtol=1e-5, atol=1e-5)
####################################################################### #######################################################################
# Mobilenet V3
# ------------
def test_forward_mobilenet_v3():
    """Test the Mobilenet V3 TF Lite model."""
    # In MobilenetV3, some ops are not supported before tf 1.15 fbs schema
    if package_version.parse(tf.VERSION) < package_version.parse('1.15.0'):
        return

    model_path = tf_testing.get_workload_official(
        "https://storage.googleapis.com/mobilenet_v3/checkpoints/v3-large_224_1.0_float.tgz",
        "v3-large_224_1.0_float/v3-large_224_1.0_float.tflite")
    with open(model_path, "rb") as model_file:
        model_buf = model_file.read()

    # Random float input is adequate for elementwise closeness checking.
    input_data = np.random.uniform(size=(1, 224, 224, 3)).astype('float32')
    reference_out = run_tflite_graph(model_buf, input_data)
    actual_out = run_tvm_graph(model_buf, input_data, 'input')
    tvm.testing.assert_allclose(np.squeeze(actual_out[0]),
                                np.squeeze(reference_out[0]),
                                rtol=1e-5, atol=1e-5)
#######################################################################
# Inception # Inception
# --------- # ---------
...@@ -1724,6 +1744,35 @@ def test_forward_qnn_mobilenet_v2_net(): ...@@ -1724,6 +1744,35 @@ def test_forward_qnn_mobilenet_v2_net():
tvm.testing.assert_allclose(tvm_sorted_labels, tflite_sorted_labels) tvm.testing.assert_allclose(tvm_sorted_labels, tflite_sorted_labels)
####################################################################### #######################################################################
# Mobilenet V3 Quantized
# ----------------------
def test_forward_qnn_mobilenet_v3_net():
    """Test the Quantized TFLite Mobilenet V3 model."""
    # In MobilenetV3, some ops are not supported before tf 1.15 fbs schema
    if package_version.parse(tf.VERSION) < package_version.parse('1.15.0'):
        return

    model_path = tf_testing.get_workload_official(
        "https://storage.googleapis.com/mobilenet_v3/checkpoints/v3-large_224_1.0_uint8.tgz",
        "v3-large_224_1.0_uint8/v3-large_224_1.0_uint8.tflite")
    with open(model_path, "rb") as model_file:
        model_buf = model_file.read()

    # The requantize implementation differs between TFLite and Relay, so
    # raw output numbers do not match exactly. Compare top-3 predicted
    # labels on a real image instead of asserting elementwise closeness.
    input_data = get_real_image(224, 224)

    reference_out = run_tflite_graph(model_buf, input_data)
    reference_top3 = np.squeeze(reference_out).argsort()[-3:][::-1]

    actual_out = run_tvm_graph(model_buf, input_data, 'input')
    actual_top3 = np.squeeze(actual_out).argsort()[-3:][::-1]

    tvm.testing.assert_allclose(actual_top3, reference_top3)
#######################################################################
# SSD Mobilenet # SSD Mobilenet
# ------------- # -------------
...@@ -1831,6 +1880,7 @@ if __name__ == '__main__': ...@@ -1831,6 +1880,7 @@ if __name__ == '__main__':
# End to End # End to End
test_forward_mobilenet_v1() test_forward_mobilenet_v1()
test_forward_mobilenet_v2() test_forward_mobilenet_v2()
test_forward_mobilenet_v3()
test_forward_inception_v3_net() test_forward_inception_v3_net()
test_forward_inception_v4_net() test_forward_inception_v4_net()
test_forward_ssd_mobilenet_v1() test_forward_ssd_mobilenet_v1()
...@@ -1840,3 +1890,4 @@ if __name__ == '__main__': ...@@ -1840,3 +1890,4 @@ if __name__ == '__main__':
test_forward_qnn_inception_v1_net() test_forward_qnn_inception_v1_net()
test_forward_qnn_mobilenet_v1_net() test_forward_qnn_mobilenet_v1_net()
test_forward_qnn_mobilenet_v2_net() test_forward_qnn_mobilenet_v2_net()
test_forward_qnn_mobilenet_v3_net()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment