Commit e3206aa8 by Haichen Shen, committed by Tianqi Chen

[TEST] Cache test data (#2921)

parent 4ac64fc4
-from six.moves import urllib
 import os
 from PIL import Image
 import numpy as np
+from tvm.contrib.download import download_testdata

-def download(url, path, overwrite=False):
-    if os.path.exists(path) and not overwrite:
-        return
-    print('Downloading {} to {}.'.format(url, path))
-    urllib.request.urlretrieve(url, path)

 def get_mobilenet():
     url = 'https://docs-assets.developer.apple.com/coreml/models/MobileNet.mlmodel'
     dst = 'mobilenet.mlmodel'
-    real_dst = os.path.abspath(os.path.join(os.path.dirname(__file__), dst))
-    download(url, real_dst)
-    return os.path.abspath(real_dst)
+    real_dst = download_testdata(url, dst, module='coreml')
+    return real_dst

 def get_resnet50():
     url = 'https://docs-assets.developer.apple.com/coreml/models/Resnet50.mlmodel'
     dst = 'resnet50.mlmodel'
-    real_dst = os.path.abspath(os.path.join(os.path.dirname(__file__), dst))
-    download(url, real_dst)
-    return os.path.abspath(real_dst)
+    real_dst = download_testdata(url, dst, module='coreml')
+    return real_dst

 def get_cat_image():
     url = 'https://gist.githubusercontent.com/zhreshold/bcda4716699ac97ea44f791c24310193/raw/fa7ef0e9c9a5daea686d6473a62aacd1a5885849/cat.png'
     dst = 'cat.png'
-    real_dst = os.path.abspath(os.path.join(os.path.dirname(__file__), dst))
-    download(url, real_dst)
+    real_dst = download_testdata(url, dst, module='coreml')
     img = Image.open(real_dst).resize((224, 224))
     img = np.transpose(img, (2, 0, 1))[np.newaxis, :]
     return np.asarray(img)
@@ -12,44 +12,16 @@ import urllib
 import numpy as np
 import tvm
 from tvm.contrib import graph_runtime
+from tvm.contrib.download import download_testdata
 from nnvm import frontend
 from nnvm.testing.darknet import LAYERTYPE
 from nnvm.testing.darknet import __darknetffi__
 import nnvm.compiler

-if sys.version_info >= (3,):
-    import urllib.request as urllib2
-else:
-    import urllib2

-def _download(url, path, overwrite=False, sizecompare=False):
-    ''' Download from internet'''
-    if os.path.isfile(path) and not overwrite:
-        if sizecompare:
-            file_size = os.path.getsize(path)
-            res_head = requests.head(url)
-            res_get = requests.get(url, stream=True)
-            if 'Content-Length' not in res_head.headers:
-                res_get = urllib2.urlopen(url)
-            urlfile_size = int(res_get.headers['Content-Length'])
-            if urlfile_size != file_size:
-                print("exist file got corrupted, downloading", path, " file freshly")
-                _download(url, path, True, False)
-                return
-        print('File {} exists, skip.'.format(path))
-        return
-    print('Downloading from url {} to {}'.format(url, path))
-    try:
-        urllib.request.urlretrieve(url, path)
-        print('')
-    except:
-        urllib.urlretrieve(url, path)

 DARKNET_LIB = 'libdarknet2.0.so'
 DARKNETLIB_URL = 'https://github.com/siju-samuel/darknet/blob/master/lib/' \
                  + DARKNET_LIB + '?raw=true'
-_download(DARKNETLIB_URL, DARKNET_LIB)
-LIB = __darknetffi__.dlopen('./' + DARKNET_LIB)
+LIB = __darknetffi__.dlopen(download_testdata(DARKNETLIB_URL, DARKNET_LIB, module='darknet'))

 def _read_memory_buffer(shape, data, dtype='float32'):
     length = 1
@@ -82,6 +54,12 @@ def _get_tvm_output(net, data, build_dtype='float32'):
         tvm_out.append(m.get_output(i).asnumpy())
     return tvm_out

+def _load_net(cfg_url, cfg_name, weights_url, weights_name):
+    cfg_path = download_testdata(cfg_url, cfg_name, module='darknet')
+    weights_path = download_testdata(weights_url, weights_name, module='darknet')
+    net = LIB.load_network(cfg_path.encode('utf-8'), weights_path.encode('utf-8'), 0)
+    return net

 def test_forward(net, build_dtype='float32'):
     '''Test network with given input image on both darknet and tvm'''
     def get_darknet_output(net, img):
@@ -125,8 +103,8 @@ def test_forward(net, build_dtype='float32'):
     test_image = 'dog.jpg'
     img_url = 'https://github.com/siju-samuel/darknet/blob/master/data/' + test_image +'?raw=true'
-    _download(img_url, test_image)
-    img = LIB.letterbox_image(LIB.load_image_color(test_image.encode('utf-8'), 0, 0), net.w, net.h)
+    img_path = download_testdata(img_url, test_image, module='darknet')
+    img = LIB.letterbox_image(LIB.load_image_color(img_path.encode('utf-8'), 0, 0), net.w, net.h)
     darknet_output = get_darknet_output(net, img)
     batch_size = 1
     data = np.empty([batch_size, img.c, img.h, img.w], dtype)
@@ -167,9 +145,7 @@ def test_forward_extraction():
     weights_name = model_name + '.weights'
     cfg_url = 'https://github.com/pjreddie/darknet/blob/master/cfg/' + cfg_name + '?raw=true'
     weights_url = 'http://pjreddie.com/media/files/' + weights_name + '?raw=true'
-    _download(cfg_url, cfg_name)
-    _download(weights_url, weights_name)
-    net = LIB.load_network(cfg_name.encode('utf-8'), weights_name.encode('utf-8'), 0)
+    net = _load_net(cfg_url, cfg_name, weights_url, weights_name)
     test_forward(net)
     LIB.free_network(net)
@@ -180,9 +156,7 @@ def test_forward_alexnet():
     weights_name = model_name + '.weights'
     cfg_url = 'https://github.com/pjreddie/darknet/blob/master/cfg/' + cfg_name + '?raw=true'
     weights_url = 'http://pjreddie.com/media/files/' + weights_name + '?raw=true'
-    _download(cfg_url, cfg_name)
-    _download(weights_url, weights_name)
-    net = LIB.load_network(cfg_name.encode('utf-8'), weights_name.encode('utf-8'), 0)
+    net = _load_net(cfg_url, cfg_name, weights_url, weights_name)
     test_forward(net)
     LIB.free_network(net)
@@ -193,9 +167,7 @@ def test_forward_resnet50():
     weights_name = model_name + '.weights'
     cfg_url = 'https://github.com/pjreddie/darknet/blob/master/cfg/' + cfg_name + '?raw=true'
     weights_url = 'http://pjreddie.com/media/files/' + weights_name + '?raw=true'
-    _download(cfg_url, cfg_name)
-    _download(weights_url, weights_name)
-    net = LIB.load_network(cfg_name.encode('utf-8'), weights_name.encode('utf-8'), 0)
+    net = _load_net(cfg_url, cfg_name, weights_url, weights_name)
     test_forward(net)
     LIB.free_network(net)
@@ -206,9 +178,7 @@ def test_forward_yolov2():
     weights_name = model_name + '.weights'
     cfg_url = 'https://github.com/pjreddie/darknet/blob/master/cfg/' + cfg_name + '?raw=true'
     weights_url = 'http://pjreddie.com/media/files/' + weights_name + '?raw=true'
-    _download(cfg_url, cfg_name)
-    _download(weights_url, weights_name)
-    net = LIB.load_network(cfg_name.encode('utf-8'), weights_name.encode('utf-8'), 0)
+    net = _load_net(cfg_url, cfg_name, weights_url, weights_name)
     build_dtype = {}
     test_forward(net, build_dtype)
     LIB.free_network(net)
@@ -220,9 +190,7 @@ def test_forward_yolov3():
     weights_name = model_name + '.weights'
     cfg_url = 'https://github.com/pjreddie/darknet/blob/master/cfg/' + cfg_name + '?raw=true'
     weights_url = 'http://pjreddie.com/media/files/' + weights_name + '?raw=true'
-    _download(cfg_url, cfg_name)
-    _download(weights_url, weights_name)
-    net = LIB.load_network(cfg_name.encode('utf-8'), weights_name.encode('utf-8'), 0)
+    net = _load_net(cfg_url, cfg_name, weights_url, weights_name)
     build_dtype = {}
     test_forward(net, build_dtype)
     LIB.free_network(net)
...
@@ -3,22 +3,7 @@ from __future__ import absolute_import as _abs
 import os
 import logging
 from .super_resolution import get_super_resolution
+from tvm.contrib.download import download_testdata

-def _download(url, filename, overwrite=False):
-    if os.path.isfile(filename) and not overwrite:
-        logging.debug('File %s existed, skip.', filename)
-        return
-    logging.debug('Downloading from url %s to %s', url, filename)
-    try:
-        import urllib.request
-        urllib.request.urlretrieve(url, filename)
-    except:
-        import urllib
-        urllib.urlretrieve(url, filename)

-def _as_abs_path(fname):
-    cur_dir = os.path.abspath(os.path.dirname(__file__))
-    return os.path.join(cur_dir, fname)

 URLS = {
@@ -30,9 +15,9 @@ URLS = {
 # download and add paths
 for k, v in URLS.items():
     name = k.split('.')[0]
-    path = _as_abs_path(k)
-    _download(v, path, False)
-    locals()[name] = path
+    relpath = os.path.join('onnx', k)
+    abspath = download_testdata(v, relpath, module='onnx')
+    locals()[name] = abspath

 # symbol for graph comparison
 super_resolution_sym = get_super_resolution()
@@ -5,8 +5,10 @@ from __future__ import absolute_import as _abs
 import os
 import sys
 import time
+import uuid
+import shutil

-def download(url, path, overwrite=False, size_compare=False, verbose=1):
+def download(url, path, overwrite=False, size_compare=False, verbose=1, retries=3):
     """Downloads the file from the internet.
     Set the input options correctly to overwrite or do the size comparison
@@ -53,6 +55,11 @@ def download(url, path, overwrite=False, size_compare=False, verbose=1):
     # Stateful start time
     start_time = time.time()
+    dirpath = os.path.dirname(path)
+    if not os.path.isdir(dirpath):
+        os.makedirs(dirpath)
+    random_uuid = str(uuid.uuid4())
+    tempfile = os.path.join(dirpath, random_uuid)

     def _download_progress(count, block_size, total_size):
         #pylint: disable=unused-argument
@@ -68,11 +75,62 @@ def download(url, path, overwrite=False, size_compare=False, verbose=1):
                          (percent, progress_size / (1024.0 * 1024), speed, duration))
         sys.stdout.flush()

-    if sys.version_info >= (3,):
-        urllib2.urlretrieve(url, path, reporthook=_download_progress)
-        print("")
-    else:
-        f = urllib2.urlopen(url)
-        data = f.read()
-        with open(path, "wb") as code:
-            code.write(data)
+    while retries >= 0:
+        # Disable pylint too broad Exception
+        # pylint: disable=W0703
+        try:
+            if sys.version_info >= (3,):
+                urllib2.urlretrieve(url, tempfile, reporthook=_download_progress)
+                print("")
+            else:
+                f = urllib2.urlopen(url)
+                data = f.read()
+                with open(tempfile, "wb") as code:
+                    code.write(data)
+            shutil.move(tempfile, path)
+            break
+        except Exception as err:
+            retries -= 1
+            if retries == 0:
+                os.remove(tempfile)
+                raise err
+            else:
+                print("download failed due to {}, retrying, {} attempt{} left"
+                      .format(repr(err), retries, 's' if retries > 1 else ''))

+TEST_DATA_ROOT_PATH = os.path.join(os.path.expanduser('~'), '.tvm_test_data')
+if not os.path.exists(TEST_DATA_ROOT_PATH):
+    os.mkdir(TEST_DATA_ROOT_PATH)

+def download_testdata(url, relpath, module=None):
+    """Downloads the test data from the internet.
+
+    Parameters
+    ----------
+    url : str
+        Download url.
+
+    relpath : str
+        Relative file path.
+
+    module : Union[str, list, tuple], optional
+        Subdirectory paths under test data folder.
+
+    Returns
+    -------
+    abspath : str
+        Absolute file path of downloaded file
+    """
+    global TEST_DATA_ROOT_PATH
+    if module is None:
+        module_path = ''
+    elif isinstance(module, str):
+        module_path = module
+    elif isinstance(module, (list, tuple)):
+        module_path = os.path.join(*module)
+    else:
+        raise ValueError("Unsupported module: " + module)
+    abspath = os.path.join(TEST_DATA_ROOT_PATH, module_path, relpath)
+    download(url, abspath, overwrite=False, size_compare=True)
+    return abspath
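The hunk above defines the new caching helper that the rest of the diff adopts. As a quick illustration only (not part of the commit; the URL, filename, and module names below are hypothetical), a test calls download_testdata and gets back a path under ~/.tvm_test_data/<module>/<relpath>; the file is fetched on the first call and reused afterwards, since download() is invoked with overwrite=False and size_compare=True:

    # Hypothetical usage sketch of download_testdata; the URL and names are illustrative only.
    from tvm.contrib.download import download_testdata

    url = 'https://example.com/data/sample.bin'  # placeholder URL, not from this commit
    path = download_testdata(url, 'sample.bin', module=['demo', 'data'])
    # path resolves to ~/.tvm_test_data/demo/data/sample.bin; the first call downloads,
    # later calls hit the cache and only re-download when the size check fails.
    with open(path, 'rb') as f:
        payload = f.read()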
@@ -13,7 +13,7 @@ import numpy as np
 import tensorflow as tf
 from tensorflow.core.framework import graph_pb2
-from tvm.contrib import util
+from tvm.contrib.download import download_testdata

 ######################################################################
 # Some helper functions
@@ -136,7 +136,7 @@ class NodeLookup(object):
             return ''
         return self.node_lookup[node_id]

-def get_workload_official(model_url, model_sub_path, temp_dir):
+def get_workload_official(model_url, model_sub_path):
     """ Import workload from tensorflow official

     Parameters
@@ -158,21 +158,17 @@ def get_workload_official(model_url, model_sub_path, temp_dir):
     """

     model_tar_name = os.path.basename(model_url)
-    from mxnet.gluon.utils import download
-    temp_path = temp_dir.relpath("./")
-    path_model = temp_path + model_tar_name
-    download(model_url, path_model)
+    model_path = download_testdata(model_url, model_tar_name, module=['tf', 'official'])
+    dir_path = os.path.dirname(model_path)

     import tarfile
-    if path_model.endswith("tgz") or path_model.endswith("gz"):
-        tar = tarfile.open(path_model)
-        tar.extractall(path=temp_path)
+    if model_path.endswith("tgz") or model_path.endswith("gz"):
+        tar = tarfile.open(model_path)
+        tar.extractall(path=dir_path)
         tar.close()
     else:
-        raise RuntimeError('Could not decompress the file: ' + path_model)
-    return temp_path + model_sub_path
+        raise RuntimeError('Could not decompress the file: ' + model_path)
+    return os.path.join(dir_path, model_sub_path)

 def get_workload(model_path, model_sub_path=None):
     """ Import workload from frozen protobuf
@@ -192,24 +188,18 @@ def get_workload(model_path, model_sub_path=None):
     """

-    temp = util.tempdir()
     if model_sub_path:
-        path_model = get_workload_official(model_path, model_sub_path, temp)
+        path_model = get_workload_official(model_path, model_sub_path)
     else:
         repo_base = 'https://github.com/dmlc/web-data/raw/master/tensorflow/models/'
-        model_name = os.path.basename(model_path)
         model_url = os.path.join(repo_base, model_path)
-        from mxnet.gluon.utils import download
-        path_model = temp.relpath(model_name)
-        download(model_url, path_model)
+        path_model = download_testdata(model_url, model_path, module='tf')

     # Creates graph from saved graph_def.pb.
     with tf.gfile.FastGFile(path_model, 'rb') as f:
         graph_def = tf.GraphDef()
         graph_def.ParseFromString(f.read())
         graph = tf.import_graph_def(graph_def, name='')
-    temp.remove()
     return graph_def

 #######################################################################
@@ -292,7 +282,7 @@ def do_tf_sample(session, data, in_states, num_samples):
 def _create_ptb_vocabulary(data_dir):
     """Read the PTB sample data input to create vocabulary"""
-    data_path = data_dir+'simple-examples/data/'
+    data_path = os.path.join(data_dir, 'simple-examples/data/')
     file_name = 'ptb.train.txt'
     def _read_words(filename):
         """Read the data for creating vocabulary"""
@@ -341,13 +331,10 @@ def get_workload_ptb():
     ptb_model_file = 'RNN/ptb/ptb_model_with_lstmblockcell.pb'

     import tarfile
-    from tvm.contrib.download import download
-    DATA_DIR = './ptb_data/'
-    if not os.path.exists(DATA_DIR):
-        os.mkdir(DATA_DIR)
-    download(sample_url, DATA_DIR+sample_data_file)
-    t = tarfile.open(DATA_DIR+sample_data_file, 'r')
-    t.extractall(DATA_DIR)
-    word_to_id, id_to_word = _create_ptb_vocabulary(DATA_DIR)
+    file_path = download_testdata(sample_url, sample_data_file, module=['tf', 'ptb_data'])
+    dir_path = os.path.dirname(file_path)
+    t = tarfile.open(file_path, 'r')
+    t.extractall(dir_path)
+    word_to_id, id_to_word = _create_ptb_vocabulary(dir_path)
     return word_to_id, id_to_word, get_workload(ptb_model_file)
-from six.moves import urllib
 import os
 from PIL import Image
 import numpy as np
+from tvm.contrib.download import download_testdata

-def download(url, path, overwrite=False):
-    if os.path.exists(path) and not overwrite:
-        return
-    print('Downloading {} to {}.'.format(url, path))
-    urllib.request.urlretrieve(url, path)

 def get_mobilenet():
     url = 'https://docs-assets.developer.apple.com/coreml/models/MobileNet.mlmodel'
     dst = 'mobilenet.mlmodel'
-    real_dst = os.path.abspath(os.path.join(os.path.dirname(__file__), dst))
-    download(url, real_dst)
+    real_dst = download_testdata(url, dst, module='coreml')
     return os.path.abspath(real_dst)

 def get_resnet50():
     url = 'https://docs-assets.developer.apple.com/coreml/models/Resnet50.mlmodel'
     dst = 'resnet50.mlmodel'
-    real_dst = os.path.abspath(os.path.join(os.path.dirname(__file__), dst))
-    download(url, real_dst)
+    real_dst = download_testdata(url, dst, module='coreml')
     return os.path.abspath(real_dst)

 def get_cat_image():
     url = 'https://gist.githubusercontent.com/zhreshold/bcda4716699ac97ea44f791c24310193/raw/fa7ef0e9c9a5daea686d6473a62aacd1a5885849/cat.png'
     dst = 'cat.png'
-    real_dst = os.path.abspath(os.path.join(os.path.dirname(__file__), dst))
-    download(url, real_dst)
+    real_dst = download_testdata(url, dst, module='coreml')
     img = Image.open(real_dst).resize((224, 224))
     img = np.transpose(img, (2, 0, 1))[np.newaxis, :]
     return np.asarray(img)
\ No newline at end of file
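For context, a minimal sketch of how a test might consume the helpers above (the import path is assumed and may differ in the actual test files; the compile-and-run step is omitted):

    # Assumed import location for the model_zoo helpers shown in this diff.
    from model_zoo import get_mobilenet, get_cat_image

    model_path = get_mobilenet()   # CoreML model cached under ~/.tvm_test_data/coreml/
    image = get_cat_image()        # batched NCHW numpy array resized to 224x224
    print(model_path, image.shape)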
@@ -391,10 +391,9 @@ def test_forward_softmax():
 def test_forward_mobilenet():
     '''test mobilenet v1 tflite model'''
     # MobilenetV1
-    temp = util.tempdir()
     tflite_model_file = tf_testing.get_workload_official(
         "http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224.tgz",
-        "mobilenet_v1_1.0_224.tflite", temp)
+        "mobilenet_v1_1.0_224.tflite")
     with open(tflite_model_file, "rb") as f:
         tflite_model_buf = f.read()
     data = np.random.uniform(size=(1, 224, 224, 3)).astype('float32')
@@ -403,7 +402,6 @@ def test_forward_mobilenet():
     tvm_output = run_tvm_graph(tflite_model_buf, tvm_data, 'input')
     tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]),
                                 rtol=1e-5, atol=1e-5)
-    temp.remove()

 #######################################################################
 # Inception V3
@@ -412,10 +410,9 @@ def test_forward_mobilenet():
 def test_forward_inception_v3_net():
     '''test inception v3 tflite model'''
     # InceptionV3
-    temp = util.tempdir()
     tflite_model_file = tf_testing.get_workload_official(
         "https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/inception_v3_2018_04_27.tgz",
-        "inception_v3.tflite", temp)
+        "inception_v3.tflite")
     with open(tflite_model_file, "rb") as f:
         tflite_model_buf = f.read()
     data = np.random.uniform(size=(1, 299, 299, 3)).astype('float32')
@@ -424,7 +421,6 @@ def test_forward_inception_v3_net():
     tvm_output = run_tvm_graph(tflite_model_buf, tvm_data, 'input')
     tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]),
                                 rtol=1e-5, atol=1e-5)
-    temp.remove()

 #######################################################################
 # Main
...