Commit 331f6fd0 by Alexander Pivovarov Committed by Yizhi Liu

Fix TFLite RESHAPE assert (#4320)

parent 26eb4053
...@@ -265,7 +265,7 @@ class OperatorConverter(object): ...@@ -265,7 +265,7 @@ class OperatorConverter(object):
assert isinstance(op, Operator) assert isinstance(op, Operator)
input_tensors = self.get_input_tensors(op) input_tensors = self.get_input_tensors(op)
assert len(input_tensors) == 2, "input tensors length should be 2" assert input_tensors, "input tensors should not be empty"
input_tensor = input_tensors[0] input_tensor = input_tensors[0]
input_tensor_idx = input_tensor.tensor_idx input_tensor_idx = input_tensor.tensor_idx
......
...@@ -38,6 +38,7 @@ try: ...@@ -38,6 +38,7 @@ try:
except ImportError: except ImportError:
from tensorflow.contrib import lite as interpreter_wrapper from tensorflow.contrib import lite as interpreter_wrapper
from tvm.contrib.download import download_testdata
import tvm.relay.testing.tf as tf_testing import tvm.relay.testing.tf as tf_testing
from packaging import version as package_version from packaging import version as package_version
...@@ -1138,6 +1139,25 @@ def test_forward_ssd_mobilenet_v1(): ...@@ -1138,6 +1139,25 @@ def test_forward_ssd_mobilenet_v1():
rtol=1e-5, atol=1e-5) rtol=1e-5, atol=1e-5)
####################################################################### #######################################################################
# MediaPipe
# -------------
def test_forward_mediapipe_hand_landmark():
"""Test MediaPipe 2D hand landmark TF Lite model."""
# MediaPipe 2D hand landmark TF
tflite_model_file = download_testdata(
"https://github.com/google/mediapipe/raw/master/mediapipe/models/hand_landmark.tflite",
"hand_landmark.tflite")
with open(tflite_model_file, "rb") as f:
tflite_model_buf = f.read()
data = np.random.uniform(size=(1, 256, 256, 3)).astype('float32')
tflite_output = run_tflite_graph(tflite_model_buf, data)
tvm_output = run_tvm_graph(tflite_model_buf, data, 'input_1', num_output=2)
for i in range(2):
tvm.testing.assert_allclose(np.squeeze(tvm_output[i]), np.squeeze(tflite_output[i]),
rtol=1e-5, atol=1e-5)
#######################################################################
# Main # Main
# ---- # ----
if __name__ == '__main__': if __name__ == '__main__':
...@@ -1192,6 +1212,7 @@ if __name__ == '__main__': ...@@ -1192,6 +1212,7 @@ if __name__ == '__main__':
test_forward_inception_v3_net() test_forward_inception_v3_net()
test_forward_inception_v4_net() test_forward_inception_v4_net()
test_forward_ssd_mobilenet_v1() test_forward_ssd_mobilenet_v1()
test_forward_mediapipe_hand_landmark()
# End to End quantized # End to End quantized
test_forward_qnn_inception_v1_net() test_forward_qnn_inception_v1_net()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment