Commit 59d8ba8f by Alexander Pivovarov Committed by Yao Wang

Add test_forward_ssd_mobilenet_v1 to tflite/test_forward (#3350)

parent 8a89177b
...@@ -163,13 +163,10 @@ def get_workload_official(model_url, model_sub_path): ...@@ -163,13 +163,10 @@ def get_workload_official(model_url, model_sub_path):
model_sub_path: model_sub_path:
Sub path in extracted tar for the ftozen protobuf file. Sub path in extracted tar for the ftozen protobuf file.
temp_dir: TempDirectory
The temporary directory object to download the content.
Returns Returns
------- -------
graph_def: graphdef model_path: str
graph_def is the tensorflow workload for mobilenet. Full path to saved model file
""" """
...@@ -200,7 +197,7 @@ def get_workload(model_path, model_sub_path=None): ...@@ -200,7 +197,7 @@ def get_workload(model_path, model_sub_path=None):
Returns Returns
------- -------
graph_def: graphdef graph_def: graphdef
graph_def is the tensorflow workload for mobilenet. graph_def is the tensorflow workload.
""" """
......
...@@ -599,6 +599,24 @@ def test_forward_inception_v4_net(): ...@@ -599,6 +599,24 @@ def test_forward_inception_v4_net():
rtol=1e-5, atol=1e-5) rtol=1e-5, atol=1e-5)
####################################################################### #######################################################################
# SSD Mobilenet
# -------------
def test_forward_ssd_mobilenet_v1():
"""Test the SSD Mobilenet V1 TF Lite model."""
# SSD MobilenetV1
tflite_model_file = tf_testing.get_workload_official(
"https://raw.githubusercontent.com/dmlc/web-data/master/tensorflow/models/object_detection/ssd_mobilenet_v1_coco_2018_01_28_nopp.tgz",
"ssd_mobilenet_v1_coco_2018_01_28_nopp.tflite")
with open(tflite_model_file, "rb") as f:
tflite_model_buf = f.read()
data = np.random.uniform(size=(1, 300, 300, 3)).astype('float32')
tflite_output = run_tflite_graph(tflite_model_buf, data)
tvm_output = run_tvm_graph(tflite_model_buf, data, 'normalized_input_image_tensor')
tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]),
rtol=1e-5, atol=1e-5)
#######################################################################
# Main # Main
# ---- # ----
if __name__ == '__main__': if __name__ == '__main__':
...@@ -623,3 +641,4 @@ if __name__ == '__main__': ...@@ -623,3 +641,4 @@ if __name__ == '__main__':
test_forward_mobilenet_v2() test_forward_mobilenet_v2()
test_forward_inception_v3_net() test_forward_inception_v3_net()
test_forward_inception_v4_net() test_forward_inception_v4_net()
test_forward_ssd_mobilenet_v1()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment