Commit bfb811c7 by Animesh Jain Committed by Zhi

[QNN][TFLite] Parsing TFLite quantized models. (#3900)

parent 7530e043
...@@ -997,6 +997,46 @@ def test_forward_inception_v4_net(): ...@@ -997,6 +997,46 @@ def test_forward_inception_v4_net():
tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]), tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]),
rtol=1e-5, atol=1e-5) rtol=1e-5, atol=1e-5)
def test_forward_qnn_inception_v1_net():
"""Test the Quantized TFLite Inception model."""
# InceptionV1
tflite_model_file = tf_testing.get_workload_official(
"https://storage.googleapis.com/download.tensorflow.org/models/inception_v1_224_quant_20181026.tgz",
"inception_v1_224_quant.tflite")
with open(tflite_model_file, "rb") as f:
tflite_model_buf = f.read()
# Checking the labels because the requantize implementation is different between TFLite and
# Relay. This cause final output numbers to mismatch. So, testing accuracy via labels.
np.random.seed(0)
data = np.random.random_integers(low=0, high=128, size=(1, 224, 224, 3)).astype('uint8')
tflite_output = run_tflite_graph(tflite_model_buf, data)
tflite_predictions = np.squeeze(tflite_output)
tflite_sorted_labels = tflite_predictions.argsort()[-3:][::-1]
tvm_output = run_tvm_graph(tflite_model_buf, data, 'input')
tvm_predictions = np.squeeze(tvm_output)
tvm_sorted_labels = tvm_predictions.argsort()[-3:][::-1]
tvm.testing.assert_allclose(tvm_sorted_labels, tflite_sorted_labels)
def test_forward_qnn_mobilenet_v1_net():
"""Test the Quantized TFLite Mobilenet V1 model."""
# MobilenetV1
tflite_model_file = tf_testing.get_workload_official(
"https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224_quant.tgz",
"mobilenet_v1_1.0_224_quant.tflite")
with open(tflite_model_file, "rb") as f:
tflite_model_buf = f.read()
# Checking the labels because the requantize implementation is different between TFLite and
# Relay. This cause final output numbers to mismatch. So, testing accuracy via labels.
np.random.seed(0)
data = np.random.random_integers(low=0, high=128, size=(1, 224, 224, 3)).astype('uint8')
tflite_output = run_tflite_graph(tflite_model_buf, data)
tflite_predictions = np.squeeze(tflite_output)
tflite_sorted_labels = tflite_predictions.argsort()[-3:][::-1]
tvm_output = run_tvm_graph(tflite_model_buf, data, 'input')
tvm_predictions = np.squeeze(tvm_output)
tvm_sorted_labels = tvm_predictions.argsort()[-3:][::-1]
tvm.testing.assert_allclose(tvm_sorted_labels, tflite_sorted_labels)
####################################################################### #######################################################################
# SSD Mobilenet # SSD Mobilenet
# ------------- # -------------
...@@ -1069,3 +1109,8 @@ if __name__ == '__main__': ...@@ -1069,3 +1109,8 @@ if __name__ == '__main__':
test_forward_inception_v3_net() test_forward_inception_v3_net()
test_forward_inception_v4_net() test_forward_inception_v4_net()
test_forward_ssd_mobilenet_v1() test_forward_ssd_mobilenet_v1()
# End to End quantized
# TODO - MobilenetV2 fails for now. Remove when fixed.
test_forward_qnn_inception_v1_net()
test_forward_qnn_mobilenet_v1_net()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment