Commit e3206aa8 by Haichen Shen Committed by Tianqi Chen

[TEST] Cache test data (#2921)

parent 4ac64fc4
from six.moves import urllib
import os
from PIL import Image
import numpy as np
def download(url, path, overwrite=False):
if os.path.exists(path) and not overwrite:
return
print('Downloading {} to {}.'.format(url, path))
urllib.request.urlretrieve(url, path)
from tvm.contrib.download import download_testdata
def get_mobilenet():
url = 'https://docs-assets.developer.apple.com/coreml/models/MobileNet.mlmodel'
dst = 'mobilenet.mlmodel'
real_dst = os.path.abspath(os.path.join(os.path.dirname(__file__), dst))
download(url, real_dst)
return os.path.abspath(real_dst)
real_dst = download_testdata(url, dst, module='coreml')
return real_dst
def get_resnet50():
url = 'https://docs-assets.developer.apple.com/coreml/models/Resnet50.mlmodel'
dst = 'resnet50.mlmodel'
real_dst = os.path.abspath(os.path.join(os.path.dirname(__file__), dst))
download(url, real_dst)
return os.path.abspath(real_dst)
real_dst = download_testdata(url, dst, module='coreml')
return real_dst
def get_cat_image():
url = 'https://gist.githubusercontent.com/zhreshold/bcda4716699ac97ea44f791c24310193/raw/fa7ef0e9c9a5daea686d6473a62aacd1a5885849/cat.png'
dst = 'cat.png'
real_dst = os.path.abspath(os.path.join(os.path.dirname(__file__), dst))
download(url, real_dst)
real_dst = download_testdata(url, dst, module='coreml')
img = Image.open(real_dst).resize((224, 224))
img = np.transpose(img, (2, 0, 1))[np.newaxis, :]
return np.asarray(img)
......@@ -12,44 +12,16 @@ import urllib
import numpy as np
import tvm
from tvm.contrib import graph_runtime
from tvm.contrib.download import download_testdata
from nnvm import frontend
from nnvm.testing.darknet import LAYERTYPE
from nnvm.testing.darknet import __darknetffi__
import nnvm.compiler
if sys.version_info >= (3,):
import urllib.request as urllib2
else:
import urllib2
def _download(url, path, overwrite=False, sizecompare=False):
''' Download from internet'''
if os.path.isfile(path) and not overwrite:
if sizecompare:
file_size = os.path.getsize(path)
res_head = requests.head(url)
res_get = requests.get(url, stream=True)
if 'Content-Length' not in res_head.headers:
res_get = urllib2.urlopen(url)
urlfile_size = int(res_get.headers['Content-Length'])
if urlfile_size != file_size:
print("exist file got corrupted, downloading", path, " file freshly")
_download(url, path, True, False)
return
print('File {} exists, skip.'.format(path))
return
print('Downloading from url {} to {}'.format(url, path))
try:
urllib.request.urlretrieve(url, path)
print('')
except:
urllib.urlretrieve(url, path)
DARKNET_LIB = 'libdarknet2.0.so'
DARKNETLIB_URL = 'https://github.com/siju-samuel/darknet/blob/master/lib/' \
+ DARKNET_LIB + '?raw=true'
_download(DARKNETLIB_URL, DARKNET_LIB)
LIB = __darknetffi__.dlopen('./' + DARKNET_LIB)
LIB = __darknetffi__.dlopen(download_testdata(DARKNETLIB_URL, DARKNET_LIB, module='darknet'))
def _read_memory_buffer(shape, data, dtype='float32'):
length = 1
......@@ -82,6 +54,12 @@ def _get_tvm_output(net, data, build_dtype='float32'):
tvm_out.append(m.get_output(i).asnumpy())
return tvm_out
def _load_net(cfg_url, cfg_name, weights_url, weights_name):
cfg_path = download_testdata(cfg_url, cfg_name, module='darknet')
weights_path = download_testdata(weights_url, weights_name, module='darknet')
net = LIB.load_network(cfg_path.encode('utf-8'), weights_path.encode('utf-8'), 0)
return net
def test_forward(net, build_dtype='float32'):
'''Test network with given input image on both darknet and tvm'''
def get_darknet_output(net, img):
......@@ -125,8 +103,8 @@ def test_forward(net, build_dtype='float32'):
test_image = 'dog.jpg'
img_url = 'https://github.com/siju-samuel/darknet/blob/master/data/' + test_image +'?raw=true'
_download(img_url, test_image)
img = LIB.letterbox_image(LIB.load_image_color(test_image.encode('utf-8'), 0, 0), net.w, net.h)
img_path = download_testdata(img_url, test_image, module='darknet')
img = LIB.letterbox_image(LIB.load_image_color(img_path.encode('utf-8'), 0, 0), net.w, net.h)
darknet_output = get_darknet_output(net, img)
batch_size = 1
data = np.empty([batch_size, img.c, img.h, img.w], dtype)
......@@ -167,9 +145,7 @@ def test_forward_extraction():
weights_name = model_name + '.weights'
cfg_url = 'https://github.com/pjreddie/darknet/blob/master/cfg/' + cfg_name + '?raw=true'
weights_url = 'http://pjreddie.com/media/files/' + weights_name + '?raw=true'
_download(cfg_url, cfg_name)
_download(weights_url, weights_name)
net = LIB.load_network(cfg_name.encode('utf-8'), weights_name.encode('utf-8'), 0)
net = _load_net(cfg_url, cfg_name, weights_url, weights_name)
test_forward(net)
LIB.free_network(net)
......@@ -180,9 +156,7 @@ def test_forward_alexnet():
weights_name = model_name + '.weights'
cfg_url = 'https://github.com/pjreddie/darknet/blob/master/cfg/' + cfg_name + '?raw=true'
weights_url = 'http://pjreddie.com/media/files/' + weights_name + '?raw=true'
_download(cfg_url, cfg_name)
_download(weights_url, weights_name)
net = LIB.load_network(cfg_name.encode('utf-8'), weights_name.encode('utf-8'), 0)
net = _load_net(cfg_url, cfg_name, weights_url, weights_name)
test_forward(net)
LIB.free_network(net)
......@@ -193,9 +167,7 @@ def test_forward_resnet50():
weights_name = model_name + '.weights'
cfg_url = 'https://github.com/pjreddie/darknet/blob/master/cfg/' + cfg_name + '?raw=true'
weights_url = 'http://pjreddie.com/media/files/' + weights_name + '?raw=true'
_download(cfg_url, cfg_name)
_download(weights_url, weights_name)
net = LIB.load_network(cfg_name.encode('utf-8'), weights_name.encode('utf-8'), 0)
net = _load_net(cfg_url, cfg_name, weights_url, weights_name)
test_forward(net)
LIB.free_network(net)
......@@ -206,9 +178,7 @@ def test_forward_yolov2():
weights_name = model_name + '.weights'
cfg_url = 'https://github.com/pjreddie/darknet/blob/master/cfg/' + cfg_name + '?raw=true'
weights_url = 'http://pjreddie.com/media/files/' + weights_name + '?raw=true'
_download(cfg_url, cfg_name)
_download(weights_url, weights_name)
net = LIB.load_network(cfg_name.encode('utf-8'), weights_name.encode('utf-8'), 0)
net = _load_net(cfg_url, cfg_name, weights_url, weights_name)
build_dtype = {}
test_forward(net, build_dtype)
LIB.free_network(net)
......@@ -220,9 +190,7 @@ def test_forward_yolov3():
weights_name = model_name + '.weights'
cfg_url = 'https://github.com/pjreddie/darknet/blob/master/cfg/' + cfg_name + '?raw=true'
weights_url = 'http://pjreddie.com/media/files/' + weights_name + '?raw=true'
_download(cfg_url, cfg_name)
_download(weights_url, weights_name)
net = LIB.load_network(cfg_name.encode('utf-8'), weights_name.encode('utf-8'), 0)
net = _load_net(cfg_url, cfg_name, weights_url, weights_name)
build_dtype = {}
test_forward(net, build_dtype)
LIB.free_network(net)
......
......@@ -3,22 +3,7 @@ from __future__ import absolute_import as _abs
import os
import logging
from .super_resolution import get_super_resolution
def _download(url, filename, overwrite=False):
if os.path.isfile(filename) and not overwrite:
logging.debug('File %s existed, skip.', filename)
return
logging.debug('Downloading from url %s to %s', url, filename)
try:
import urllib.request
urllib.request.urlretrieve(url, filename)
except:
import urllib
urllib.urlretrieve(url, filename)
def _as_abs_path(fname):
cur_dir = os.path.abspath(os.path.dirname(__file__))
return os.path.join(cur_dir, fname)
from tvm.contrib.download import download_testdata
URLS = {
......@@ -30,9 +15,9 @@ URLS = {
# download and add paths
for k, v in URLS.items():
name = k.split('.')[0]
path = _as_abs_path(k)
_download(v, path, False)
locals()[name] = path
relpath = os.path.join('onnx', k)
abspath = download_testdata(v, relpath, module='onnx')
locals()[name] = abspath
# symbol for graph comparison
super_resolution_sym = get_super_resolution()
......@@ -5,8 +5,10 @@ from __future__ import absolute_import as _abs
import os
import sys
import time
import uuid
import shutil
def download(url, path, overwrite=False, size_compare=False, verbose=1):
def download(url, path, overwrite=False, size_compare=False, verbose=1, retries=3):
"""Downloads the file from the internet.
Set the input options correctly to overwrite or do the size comparison
......@@ -53,6 +55,11 @@ def download(url, path, overwrite=False, size_compare=False, verbose=1):
# Stateful start time
start_time = time.time()
dirpath = os.path.dirname(path)
if not os.path.isdir(dirpath):
os.makedirs(dirpath)
random_uuid = str(uuid.uuid4())
tempfile = os.path.join(dirpath, random_uuid)
def _download_progress(count, block_size, total_size):
#pylint: disable=unused-argument
......@@ -68,11 +75,62 @@ def download(url, path, overwrite=False, size_compare=False, verbose=1):
(percent, progress_size / (1024.0 * 1024), speed, duration))
sys.stdout.flush()
while retries >= 0:
# Disable pyling too broad Exception
# pylint: disable=W0703
try:
if sys.version_info >= (3,):
urllib2.urlretrieve(url, path, reporthook=_download_progress)
urllib2.urlretrieve(url, tempfile, reporthook=_download_progress)
print("")
else:
f = urllib2.urlopen(url)
data = f.read()
with open(path, "wb") as code:
with open(tempfile, "wb") as code:
code.write(data)
shutil.move(tempfile, path)
break
except Exception as err:
retries -= 1
if retries == 0:
os.remove(tempfile)
raise err
else:
print("download failed due to {}, retrying, {} attempt{} left"
.format(repr(err), retries, 's' if retries > 1 else ''))
TEST_DATA_ROOT_PATH = os.path.join(os.path.expanduser('~'), '.tvm_test_data')
if not os.path.exists(TEST_DATA_ROOT_PATH):
os.mkdir(TEST_DATA_ROOT_PATH)
def download_testdata(url, relpath, module=None):
"""Downloads the test data from the internet.
Parameters
----------
url : str
Download url.
relpath : str
Relative file path.
module : Union[str, list, tuple], optional
Subdirectory paths under test data folder.
Returns
-------
abspath : str
Absolute file path of downloaded file
"""
global TEST_DATA_ROOT_PATH
if module is None:
module_path = ''
elif isinstance(module, str):
module_path = module
elif isinstance(module, (list, tuple)):
module_path = os.path.join(*module)
else:
raise ValueError("Unsupported module: " + module)
abspath = os.path.join(TEST_DATA_ROOT_PATH, module_path, relpath)
download(url, abspath, overwrite=False, size_compare=True)
return abspath
......@@ -13,7 +13,7 @@ import numpy as np
import tensorflow as tf
from tensorflow.core.framework import graph_pb2
from tvm.contrib import util
from tvm.contrib.download import download_testdata
######################################################################
# Some helper functions
......@@ -136,7 +136,7 @@ class NodeLookup(object):
return ''
return self.node_lookup[node_id]
def get_workload_official(model_url, model_sub_path, temp_dir):
def get_workload_official(model_url, model_sub_path):
""" Import workload from tensorflow official
Parameters
......@@ -158,21 +158,17 @@ def get_workload_official(model_url, model_sub_path, temp_dir):
"""
model_tar_name = os.path.basename(model_url)
from mxnet.gluon.utils import download
temp_path = temp_dir.relpath("./")
path_model = temp_path + model_tar_name
download(model_url, path_model)
model_path = download_testdata(model_url, model_tar_name, module=['tf', 'official'])
dir_path = os.path.dirname(model_path)
import tarfile
if path_model.endswith("tgz") or path_model.endswith("gz"):
tar = tarfile.open(path_model)
tar.extractall(path=temp_path)
if model_path.endswith("tgz") or model_path.endswith("gz"):
tar = tarfile.open(model_path)
tar.extractall(path=dir_path)
tar.close()
else:
raise RuntimeError('Could not decompress the file: ' + path_model)
return temp_path + model_sub_path
raise RuntimeError('Could not decompress the file: ' + model_path)
return os.path.join(dir_path, model_sub_path)
def get_workload(model_path, model_sub_path=None):
""" Import workload from frozen protobuf
......@@ -192,24 +188,18 @@ def get_workload(model_path, model_sub_path=None):
"""
temp = util.tempdir()
if model_sub_path:
path_model = get_workload_official(model_path, model_sub_path, temp)
path_model = get_workload_official(model_path, model_sub_path)
else:
repo_base = 'https://github.com/dmlc/web-data/raw/master/tensorflow/models/'
model_name = os.path.basename(model_path)
model_url = os.path.join(repo_base, model_path)
from mxnet.gluon.utils import download
path_model = temp.relpath(model_name)
download(model_url, path_model)
path_model = download_testdata(model_url, model_path, module='tf')
# Creates graph from saved graph_def.pb.
with tf.gfile.FastGFile(path_model, 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
graph = tf.import_graph_def(graph_def, name='')
temp.remove()
return graph_def
#######################################################################
......@@ -292,7 +282,7 @@ def do_tf_sample(session, data, in_states, num_samples):
def _create_ptb_vocabulary(data_dir):
"""Read the PTB sample data input to create vocabulary"""
data_path = data_dir+'simple-examples/data/'
data_path = os.path.join(data_dir, 'simple-examples/data/')
file_name = 'ptb.train.txt'
def _read_words(filename):
"""Read the data for creating vocabulary"""
......@@ -341,13 +331,10 @@ def get_workload_ptb():
ptb_model_file = 'RNN/ptb/ptb_model_with_lstmblockcell.pb'
import tarfile
from tvm.contrib.download import download
DATA_DIR = './ptb_data/'
if not os.path.exists(DATA_DIR):
os.mkdir(DATA_DIR)
download(sample_url, DATA_DIR+sample_data_file)
t = tarfile.open(DATA_DIR+sample_data_file, 'r')
t.extractall(DATA_DIR)
word_to_id, id_to_word = _create_ptb_vocabulary(DATA_DIR)
file_path = download_testdata(sample_url, sample_data_file, module=['tf', 'ptb_data'])
dir_path = os.path.dirname(file_path)
t = tarfile.open(file_path, 'r')
t.extractall(dir_path)
word_to_id, id_to_word = _create_ptb_vocabulary(dir_path)
return word_to_id, id_to_word, get_workload(ptb_model_file)
from six.moves import urllib
import os
from PIL import Image
import numpy as np
def download(url, path, overwrite=False):
if os.path.exists(path) and not overwrite:
return
print('Downloading {} to {}.'.format(url, path))
urllib.request.urlretrieve(url, path)
from tvm.contrib.download import download_testdata
def get_mobilenet():
url = 'https://docs-assets.developer.apple.com/coreml/models/MobileNet.mlmodel'
dst = 'mobilenet.mlmodel'
real_dst = os.path.abspath(os.path.join(os.path.dirname(__file__), dst))
download(url, real_dst)
real_dst = download_testdata(url, dst, module='coreml')
return os.path.abspath(real_dst)
def get_resnet50():
url = 'https://docs-assets.developer.apple.com/coreml/models/Resnet50.mlmodel'
dst = 'resnet50.mlmodel'
real_dst = os.path.abspath(os.path.join(os.path.dirname(__file__), dst))
download(url, real_dst)
real_dst = download_testdata(url, dst, module='coreml')
return os.path.abspath(real_dst)
def get_cat_image():
url = 'https://gist.githubusercontent.com/zhreshold/bcda4716699ac97ea44f791c24310193/raw/fa7ef0e9c9a5daea686d6473a62aacd1a5885849/cat.png'
dst = 'cat.png'
real_dst = os.path.abspath(os.path.join(os.path.dirname(__file__), dst))
download(url, real_dst)
real_dst = download_testdata(url, dst, module='coreml')
img = Image.open(real_dst).resize((224, 224))
img = np.transpose(img, (2, 0, 1))[np.newaxis, :]
return np.asarray(img)
\ No newline at end of file
......@@ -391,10 +391,9 @@ def test_forward_softmax():
def test_forward_mobilenet():
'''test mobilenet v1 tflite model'''
# MobilenetV1
temp = util.tempdir()
tflite_model_file = tf_testing.get_workload_official(
"http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224.tgz",
"mobilenet_v1_1.0_224.tflite", temp)
"mobilenet_v1_1.0_224.tflite")
with open(tflite_model_file, "rb") as f:
tflite_model_buf = f.read()
data = np.random.uniform(size=(1, 224, 224, 3)).astype('float32')
......@@ -403,7 +402,6 @@ def test_forward_mobilenet():
tvm_output = run_tvm_graph(tflite_model_buf, tvm_data, 'input')
tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]),
rtol=1e-5, atol=1e-5)
temp.remove()
#######################################################################
# Inception V3
......@@ -412,10 +410,9 @@ def test_forward_mobilenet():
def test_forward_inception_v3_net():
'''test inception v3 tflite model'''
# InceptionV3
temp = util.tempdir()
tflite_model_file = tf_testing.get_workload_official(
"https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/inception_v3_2018_04_27.tgz",
"inception_v3.tflite", temp)
"inception_v3.tflite")
with open(tflite_model_file, "rb") as f:
tflite_model_buf = f.read()
data = np.random.uniform(size=(1, 299, 299, 3)).astype('float32')
......@@ -424,7 +421,6 @@ def test_forward_inception_v3_net():
tvm_output = run_tvm_graph(tflite_model_buf, tvm_data, 'input')
tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]),
rtol=1e-5, atol=1e-5)
temp.remove()
#######################################################################
# Main
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment