Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
e3206aa8
Commit
e3206aa8
authored
Mar 28, 2019
by
Haichen Shen
Committed by
Tianqi Chen
Mar 28, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[TEST] Cache test data (#2921)
parent
4ac64fc4
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
111 additions
and
134 deletions
+111
-134
nnvm/tests/python/frontend/coreml/model_zoo/__init__.py
+6
-15
nnvm/tests/python/frontend/darknet/test_forward.py
+15
-47
nnvm/tests/python/frontend/onnx/model_zoo/__init__.py
+4
-19
python/tvm/contrib/download.py
+61
-3
python/tvm/relay/testing/tf.py
+18
-31
tests/python/frontend/coreml/model_zoo/__init__.py
+5
-13
tests/python/frontend/tflite/test_forward.py
+2
-6
No files found.
nnvm/tests/python/frontend/coreml/model_zoo/__init__.py
View file @
e3206aa8
from
six.moves
import
urllib
import
os
from
PIL
import
Image
import
numpy
as
np
def
download
(
url
,
path
,
overwrite
=
False
):
if
os
.
path
.
exists
(
path
)
and
not
overwrite
:
return
print
(
'Downloading {} to {}.'
.
format
(
url
,
path
))
urllib
.
request
.
urlretrieve
(
url
,
path
)
from
tvm.contrib.download
import
download_testdata
def
get_mobilenet
():
url
=
'https://docs-assets.developer.apple.com/coreml/models/MobileNet.mlmodel'
dst
=
'mobilenet.mlmodel'
real_dst
=
os
.
path
.
abspath
(
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
dst
))
download
(
url
,
real_dst
)
return
os
.
path
.
abspath
(
real_dst
)
real_dst
=
download_testdata
(
url
,
dst
,
module
=
'coreml'
)
return
real_dst
def
get_resnet50
():
url
=
'https://docs-assets.developer.apple.com/coreml/models/Resnet50.mlmodel'
dst
=
'resnet50.mlmodel'
real_dst
=
os
.
path
.
abspath
(
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
dst
))
download
(
url
,
real_dst
)
return
os
.
path
.
abspath
(
real_dst
)
real_dst
=
download_testdata
(
url
,
dst
,
module
=
'coreml'
)
return
real_dst
def
get_cat_image
():
url
=
'https://gist.githubusercontent.com/zhreshold/bcda4716699ac97ea44f791c24310193/raw/fa7ef0e9c9a5daea686d6473a62aacd1a5885849/cat.png'
dst
=
'cat.png'
real_dst
=
os
.
path
.
abspath
(
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
dst
))
download
(
url
,
real_dst
)
real_dst
=
download_testdata
(
url
,
dst
,
module
=
'coreml'
)
img
=
Image
.
open
(
real_dst
)
.
resize
((
224
,
224
))
img
=
np
.
transpose
(
img
,
(
2
,
0
,
1
))[
np
.
newaxis
,
:]
return
np
.
asarray
(
img
)
nnvm/tests/python/frontend/darknet/test_forward.py
View file @
e3206aa8
...
...
@@ -12,44 +12,16 @@ import urllib
import
numpy
as
np
import
tvm
from
tvm.contrib
import
graph_runtime
from
tvm.contrib.download
import
download_testdata
from
nnvm
import
frontend
from
nnvm.testing.darknet
import
LAYERTYPE
from
nnvm.testing.darknet
import
__darknetffi__
import
nnvm.compiler
if
sys
.
version_info
>=
(
3
,):
import
urllib.request
as
urllib2
else
:
import
urllib2
def
_download
(
url
,
path
,
overwrite
=
False
,
sizecompare
=
False
):
''' Download from internet'''
if
os
.
path
.
isfile
(
path
)
and
not
overwrite
:
if
sizecompare
:
file_size
=
os
.
path
.
getsize
(
path
)
res_head
=
requests
.
head
(
url
)
res_get
=
requests
.
get
(
url
,
stream
=
True
)
if
'Content-Length'
not
in
res_head
.
headers
:
res_get
=
urllib2
.
urlopen
(
url
)
urlfile_size
=
int
(
res_get
.
headers
[
'Content-Length'
])
if
urlfile_size
!=
file_size
:
print
(
"exist file got corrupted, downloading"
,
path
,
" file freshly"
)
_download
(
url
,
path
,
True
,
False
)
return
print
(
'File {} exists, skip.'
.
format
(
path
))
return
print
(
'Downloading from url {} to {}'
.
format
(
url
,
path
))
try
:
urllib
.
request
.
urlretrieve
(
url
,
path
)
print
(
''
)
except
:
urllib
.
urlretrieve
(
url
,
path
)
DARKNET_LIB
=
'libdarknet2.0.so'
DARKNETLIB_URL
=
'https://github.com/siju-samuel/darknet/blob/master/lib/'
\
+
DARKNET_LIB
+
'?raw=true'
_download
(
DARKNETLIB_URL
,
DARKNET_LIB
)
LIB
=
__darknetffi__
.
dlopen
(
'./'
+
DARKNET_LIB
)
LIB
=
__darknetffi__
.
dlopen
(
download_testdata
(
DARKNETLIB_URL
,
DARKNET_LIB
,
module
=
'darknet'
))
def
_read_memory_buffer
(
shape
,
data
,
dtype
=
'float32'
):
length
=
1
...
...
@@ -82,6 +54,12 @@ def _get_tvm_output(net, data, build_dtype='float32'):
tvm_out
.
append
(
m
.
get_output
(
i
)
.
asnumpy
())
return
tvm_out
def
_load_net
(
cfg_url
,
cfg_name
,
weights_url
,
weights_name
):
cfg_path
=
download_testdata
(
cfg_url
,
cfg_name
,
module
=
'darknet'
)
weights_path
=
download_testdata
(
weights_url
,
weights_name
,
module
=
'darknet'
)
net
=
LIB
.
load_network
(
cfg_path
.
encode
(
'utf-8'
),
weights_path
.
encode
(
'utf-8'
),
0
)
return
net
def
test_forward
(
net
,
build_dtype
=
'float32'
):
'''Test network with given input image on both darknet and tvm'''
def
get_darknet_output
(
net
,
img
):
...
...
@@ -125,8 +103,8 @@ def test_forward(net, build_dtype='float32'):
test_image
=
'dog.jpg'
img_url
=
'https://github.com/siju-samuel/darknet/blob/master/data/'
+
test_image
+
'?raw=true'
_download
(
img_url
,
test_image
)
img
=
LIB
.
letterbox_image
(
LIB
.
load_image_color
(
test_image
.
encode
(
'utf-8'
),
0
,
0
),
net
.
w
,
net
.
h
)
img_path
=
download_testdata
(
img_url
,
test_image
,
module
=
'darknet'
)
img
=
LIB
.
letterbox_image
(
LIB
.
load_image_color
(
img_path
.
encode
(
'utf-8'
),
0
,
0
),
net
.
w
,
net
.
h
)
darknet_output
=
get_darknet_output
(
net
,
img
)
batch_size
=
1
data
=
np
.
empty
([
batch_size
,
img
.
c
,
img
.
h
,
img
.
w
],
dtype
)
...
...
@@ -167,9 +145,7 @@ def test_forward_extraction():
weights_name
=
model_name
+
'.weights'
cfg_url
=
'https://github.com/pjreddie/darknet/blob/master/cfg/'
+
cfg_name
+
'?raw=true'
weights_url
=
'http://pjreddie.com/media/files/'
+
weights_name
+
'?raw=true'
_download
(
cfg_url
,
cfg_name
)
_download
(
weights_url
,
weights_name
)
net
=
LIB
.
load_network
(
cfg_name
.
encode
(
'utf-8'
),
weights_name
.
encode
(
'utf-8'
),
0
)
net
=
_load_net
(
cfg_url
,
cfg_name
,
weights_url
,
weights_name
)
test_forward
(
net
)
LIB
.
free_network
(
net
)
...
...
@@ -180,9 +156,7 @@ def test_forward_alexnet():
weights_name
=
model_name
+
'.weights'
cfg_url
=
'https://github.com/pjreddie/darknet/blob/master/cfg/'
+
cfg_name
+
'?raw=true'
weights_url
=
'http://pjreddie.com/media/files/'
+
weights_name
+
'?raw=true'
_download
(
cfg_url
,
cfg_name
)
_download
(
weights_url
,
weights_name
)
net
=
LIB
.
load_network
(
cfg_name
.
encode
(
'utf-8'
),
weights_name
.
encode
(
'utf-8'
),
0
)
net
=
_load_net
(
cfg_url
,
cfg_name
,
weights_url
,
weights_name
)
test_forward
(
net
)
LIB
.
free_network
(
net
)
...
...
@@ -193,9 +167,7 @@ def test_forward_resnet50():
weights_name
=
model_name
+
'.weights'
cfg_url
=
'https://github.com/pjreddie/darknet/blob/master/cfg/'
+
cfg_name
+
'?raw=true'
weights_url
=
'http://pjreddie.com/media/files/'
+
weights_name
+
'?raw=true'
_download
(
cfg_url
,
cfg_name
)
_download
(
weights_url
,
weights_name
)
net
=
LIB
.
load_network
(
cfg_name
.
encode
(
'utf-8'
),
weights_name
.
encode
(
'utf-8'
),
0
)
net
=
_load_net
(
cfg_url
,
cfg_name
,
weights_url
,
weights_name
)
test_forward
(
net
)
LIB
.
free_network
(
net
)
...
...
@@ -206,9 +178,7 @@ def test_forward_yolov2():
weights_name
=
model_name
+
'.weights'
cfg_url
=
'https://github.com/pjreddie/darknet/blob/master/cfg/'
+
cfg_name
+
'?raw=true'
weights_url
=
'http://pjreddie.com/media/files/'
+
weights_name
+
'?raw=true'
_download
(
cfg_url
,
cfg_name
)
_download
(
weights_url
,
weights_name
)
net
=
LIB
.
load_network
(
cfg_name
.
encode
(
'utf-8'
),
weights_name
.
encode
(
'utf-8'
),
0
)
net
=
_load_net
(
cfg_url
,
cfg_name
,
weights_url
,
weights_name
)
build_dtype
=
{}
test_forward
(
net
,
build_dtype
)
LIB
.
free_network
(
net
)
...
...
@@ -220,9 +190,7 @@ def test_forward_yolov3():
weights_name
=
model_name
+
'.weights'
cfg_url
=
'https://github.com/pjreddie/darknet/blob/master/cfg/'
+
cfg_name
+
'?raw=true'
weights_url
=
'http://pjreddie.com/media/files/'
+
weights_name
+
'?raw=true'
_download
(
cfg_url
,
cfg_name
)
_download
(
weights_url
,
weights_name
)
net
=
LIB
.
load_network
(
cfg_name
.
encode
(
'utf-8'
),
weights_name
.
encode
(
'utf-8'
),
0
)
net
=
_load_net
(
cfg_url
,
cfg_name
,
weights_url
,
weights_name
)
build_dtype
=
{}
test_forward
(
net
,
build_dtype
)
LIB
.
free_network
(
net
)
...
...
nnvm/tests/python/frontend/onnx/model_zoo/__init__.py
View file @
e3206aa8
...
...
@@ -3,22 +3,7 @@ from __future__ import absolute_import as _abs
import
os
import
logging
from
.super_resolution
import
get_super_resolution
def
_download
(
url
,
filename
,
overwrite
=
False
):
if
os
.
path
.
isfile
(
filename
)
and
not
overwrite
:
logging
.
debug
(
'File
%
s existed, skip.'
,
filename
)
return
logging
.
debug
(
'Downloading from url
%
s to
%
s'
,
url
,
filename
)
try
:
import
urllib.request
urllib
.
request
.
urlretrieve
(
url
,
filename
)
except
:
import
urllib
urllib
.
urlretrieve
(
url
,
filename
)
def
_as_abs_path
(
fname
):
cur_dir
=
os
.
path
.
abspath
(
os
.
path
.
dirname
(
__file__
))
return
os
.
path
.
join
(
cur_dir
,
fname
)
from
tvm.contrib.download
import
download_testdata
URLS
=
{
...
...
@@ -30,9 +15,9 @@ URLS = {
# download and add paths
for
k
,
v
in
URLS
.
items
():
name
=
k
.
split
(
'.'
)[
0
]
path
=
_as_abs_path
(
k
)
_download
(
v
,
path
,
False
)
locals
()[
name
]
=
path
relpath
=
os
.
path
.
join
(
'onnx'
,
k
)
abspath
=
download_testdata
(
v
,
relpath
,
module
=
'onnx'
)
locals
()[
name
]
=
abs
path
# symbol for graph comparison
super_resolution_sym
=
get_super_resolution
()
python/tvm/contrib/download.py
View file @
e3206aa8
...
...
@@ -5,8 +5,10 @@ from __future__ import absolute_import as _abs
import
os
import
sys
import
time
import
uuid
import
shutil
def
download
(
url
,
path
,
overwrite
=
False
,
size_compare
=
False
,
verbose
=
1
):
def
download
(
url
,
path
,
overwrite
=
False
,
size_compare
=
False
,
verbose
=
1
,
retries
=
3
):
"""Downloads the file from the internet.
Set the input options correctly to overwrite or do the size comparison
...
...
@@ -53,6 +55,11 @@ def download(url, path, overwrite=False, size_compare=False, verbose=1):
# Stateful start time
start_time
=
time
.
time
()
dirpath
=
os
.
path
.
dirname
(
path
)
if
not
os
.
path
.
isdir
(
dirpath
):
os
.
makedirs
(
dirpath
)
random_uuid
=
str
(
uuid
.
uuid4
())
tempfile
=
os
.
path
.
join
(
dirpath
,
random_uuid
)
def
_download_progress
(
count
,
block_size
,
total_size
):
#pylint: disable=unused-argument
...
...
@@ -68,11 +75,62 @@ def download(url, path, overwrite=False, size_compare=False, verbose=1):
(
percent
,
progress_size
/
(
1024.0
*
1024
),
speed
,
duration
))
sys
.
stdout
.
flush
()
while
retries
>=
0
:
# Disable pyling too broad Exception
# pylint: disable=W0703
try
:
if
sys
.
version_info
>=
(
3
,):
urllib2
.
urlretrieve
(
url
,
path
,
reporthook
=
_download_progress
)
urllib2
.
urlretrieve
(
url
,
tempfile
,
reporthook
=
_download_progress
)
print
(
""
)
else
:
f
=
urllib2
.
urlopen
(
url
)
data
=
f
.
read
()
with
open
(
path
,
"wb"
)
as
code
:
with
open
(
tempfile
,
"wb"
)
as
code
:
code
.
write
(
data
)
shutil
.
move
(
tempfile
,
path
)
break
except
Exception
as
err
:
retries
-=
1
if
retries
==
0
:
os
.
remove
(
tempfile
)
raise
err
else
:
print
(
"download failed due to {}, retrying, {} attempt{} left"
.
format
(
repr
(
err
),
retries
,
's'
if
retries
>
1
else
''
))
TEST_DATA_ROOT_PATH
=
os
.
path
.
join
(
os
.
path
.
expanduser
(
'~'
),
'.tvm_test_data'
)
if
not
os
.
path
.
exists
(
TEST_DATA_ROOT_PATH
):
os
.
mkdir
(
TEST_DATA_ROOT_PATH
)
def
download_testdata
(
url
,
relpath
,
module
=
None
):
"""Downloads the test data from the internet.
Parameters
----------
url : str
Download url.
relpath : str
Relative file path.
module : Union[str, list, tuple], optional
Subdirectory paths under test data folder.
Returns
-------
abspath : str
Absolute file path of downloaded file
"""
global
TEST_DATA_ROOT_PATH
if
module
is
None
:
module_path
=
''
elif
isinstance
(
module
,
str
):
module_path
=
module
elif
isinstance
(
module
,
(
list
,
tuple
)):
module_path
=
os
.
path
.
join
(
*
module
)
else
:
raise
ValueError
(
"Unsupported module: "
+
module
)
abspath
=
os
.
path
.
join
(
TEST_DATA_ROOT_PATH
,
module_path
,
relpath
)
download
(
url
,
abspath
,
overwrite
=
False
,
size_compare
=
True
)
return
abspath
python/tvm/relay/testing/tf.py
View file @
e3206aa8
...
...
@@ -13,7 +13,7 @@ import numpy as np
import
tensorflow
as
tf
from
tensorflow.core.framework
import
graph_pb2
from
tvm.contrib
import
util
from
tvm.contrib
.download
import
download_testdata
######################################################################
# Some helper functions
...
...
@@ -136,7 +136,7 @@ class NodeLookup(object):
return
''
return
self
.
node_lookup
[
node_id
]
def
get_workload_official
(
model_url
,
model_sub_path
,
temp_dir
):
def
get_workload_official
(
model_url
,
model_sub_path
):
""" Import workload from tensorflow official
Parameters
...
...
@@ -158,21 +158,17 @@ def get_workload_official(model_url, model_sub_path, temp_dir):
"""
model_tar_name
=
os
.
path
.
basename
(
model_url
)
from
mxnet.gluon.utils
import
download
temp_path
=
temp_dir
.
relpath
(
"./"
)
path_model
=
temp_path
+
model_tar_name
download
(
model_url
,
path_model
)
model_path
=
download_testdata
(
model_url
,
model_tar_name
,
module
=
[
'tf'
,
'official'
])
dir_path
=
os
.
path
.
dirname
(
model_path
)
import
tarfile
if
path_model
.
endswith
(
"tgz"
)
or
path_model
.
endswith
(
"gz"
):
tar
=
tarfile
.
open
(
path_model
)
tar
.
extractall
(
path
=
temp
_path
)
if
model_path
.
endswith
(
"tgz"
)
or
model_path
.
endswith
(
"gz"
):
tar
=
tarfile
.
open
(
model_path
)
tar
.
extractall
(
path
=
dir
_path
)
tar
.
close
()
else
:
raise
RuntimeError
(
'Could not decompress the file: '
+
path_model
)
return
temp_path
+
model_sub_path
raise
RuntimeError
(
'Could not decompress the file: '
+
model_path
)
return
os
.
path
.
join
(
dir_path
,
model_sub_path
)
def
get_workload
(
model_path
,
model_sub_path
=
None
):
""" Import workload from frozen protobuf
...
...
@@ -192,24 +188,18 @@ def get_workload(model_path, model_sub_path=None):
"""
temp
=
util
.
tempdir
()
if
model_sub_path
:
path_model
=
get_workload_official
(
model_path
,
model_sub_path
,
temp
)
path_model
=
get_workload_official
(
model_path
,
model_sub_path
)
else
:
repo_base
=
'https://github.com/dmlc/web-data/raw/master/tensorflow/models/'
model_name
=
os
.
path
.
basename
(
model_path
)
model_url
=
os
.
path
.
join
(
repo_base
,
model_path
)
from
mxnet.gluon.utils
import
download
path_model
=
temp
.
relpath
(
model_name
)
download
(
model_url
,
path_model
)
path_model
=
download_testdata
(
model_url
,
model_path
,
module
=
'tf'
)
# Creates graph from saved graph_def.pb.
with
tf
.
gfile
.
FastGFile
(
path_model
,
'rb'
)
as
f
:
graph_def
=
tf
.
GraphDef
()
graph_def
.
ParseFromString
(
f
.
read
())
graph
=
tf
.
import_graph_def
(
graph_def
,
name
=
''
)
temp
.
remove
()
return
graph_def
#######################################################################
...
...
@@ -292,7 +282,7 @@ def do_tf_sample(session, data, in_states, num_samples):
def
_create_ptb_vocabulary
(
data_dir
):
"""Read the PTB sample data input to create vocabulary"""
data_path
=
data_dir
+
'simple-examples/data/'
data_path
=
os
.
path
.
join
(
data_dir
,
'simple-examples/data/'
)
file_name
=
'ptb.train.txt'
def
_read_words
(
filename
):
"""Read the data for creating vocabulary"""
...
...
@@ -341,13 +331,10 @@ def get_workload_ptb():
ptb_model_file
=
'RNN/ptb/ptb_model_with_lstmblockcell.pb'
import
tarfile
from
tvm.contrib.download
import
download
DATA_DIR
=
'./ptb_data/'
if
not
os
.
path
.
exists
(
DATA_DIR
):
os
.
mkdir
(
DATA_DIR
)
download
(
sample_url
,
DATA_DIR
+
sample_data_file
)
t
=
tarfile
.
open
(
DATA_DIR
+
sample_data_file
,
'r'
)
t
.
extractall
(
DATA_DIR
)
word_to_id
,
id_to_word
=
_create_ptb_vocabulary
(
DATA_DIR
)
file_path
=
download_testdata
(
sample_url
,
sample_data_file
,
module
=
[
'tf'
,
'ptb_data'
])
dir_path
=
os
.
path
.
dirname
(
file_path
)
t
=
tarfile
.
open
(
file_path
,
'r'
)
t
.
extractall
(
dir_path
)
word_to_id
,
id_to_word
=
_create_ptb_vocabulary
(
dir_path
)
return
word_to_id
,
id_to_word
,
get_workload
(
ptb_model_file
)
tests/python/frontend/coreml/model_zoo/__init__.py
View file @
e3206aa8
from
six.moves
import
urllib
import
os
from
PIL
import
Image
import
numpy
as
np
def
download
(
url
,
path
,
overwrite
=
False
):
if
os
.
path
.
exists
(
path
)
and
not
overwrite
:
return
print
(
'Downloading {} to {}.'
.
format
(
url
,
path
))
urllib
.
request
.
urlretrieve
(
url
,
path
)
from
tvm.contrib.download
import
download_testdata
def
get_mobilenet
():
url
=
'https://docs-assets.developer.apple.com/coreml/models/MobileNet.mlmodel'
dst
=
'mobilenet.mlmodel'
real_dst
=
os
.
path
.
abspath
(
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
dst
))
download
(
url
,
real_dst
)
real_dst
=
download_testdata
(
url
,
dst
,
module
=
'coreml'
)
return
os
.
path
.
abspath
(
real_dst
)
def
get_resnet50
():
url
=
'https://docs-assets.developer.apple.com/coreml/models/Resnet50.mlmodel'
dst
=
'resnet50.mlmodel'
real_dst
=
os
.
path
.
abspath
(
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
dst
))
download
(
url
,
real_dst
)
real_dst
=
download_testdata
(
url
,
dst
,
module
=
'coreml'
)
return
os
.
path
.
abspath
(
real_dst
)
def
get_cat_image
():
url
=
'https://gist.githubusercontent.com/zhreshold/bcda4716699ac97ea44f791c24310193/raw/fa7ef0e9c9a5daea686d6473a62aacd1a5885849/cat.png'
dst
=
'cat.png'
real_dst
=
os
.
path
.
abspath
(
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
dst
))
download
(
url
,
real_dst
)
real_dst
=
download_testdata
(
url
,
dst
,
module
=
'coreml'
)
img
=
Image
.
open
(
real_dst
)
.
resize
((
224
,
224
))
img
=
np
.
transpose
(
img
,
(
2
,
0
,
1
))[
np
.
newaxis
,
:]
return
np
.
asarray
(
img
)
\ No newline at end of file
tests/python/frontend/tflite/test_forward.py
View file @
e3206aa8
...
...
@@ -391,10 +391,9 @@ def test_forward_softmax():
def
test_forward_mobilenet
():
'''test mobilenet v1 tflite model'''
# MobilenetV1
temp
=
util
.
tempdir
()
tflite_model_file
=
tf_testing
.
get_workload_official
(
"http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224.tgz"
,
"mobilenet_v1_1.0_224.tflite"
,
temp
)
"mobilenet_v1_1.0_224.tflite"
)
with
open
(
tflite_model_file
,
"rb"
)
as
f
:
tflite_model_buf
=
f
.
read
()
data
=
np
.
random
.
uniform
(
size
=
(
1
,
224
,
224
,
3
))
.
astype
(
'float32'
)
...
...
@@ -403,7 +402,6 @@ def test_forward_mobilenet():
tvm_output
=
run_tvm_graph
(
tflite_model_buf
,
tvm_data
,
'input'
)
tvm
.
testing
.
assert_allclose
(
np
.
squeeze
(
tvm_output
[
0
]),
np
.
squeeze
(
tflite_output
[
0
]),
rtol
=
1e-5
,
atol
=
1e-5
)
temp
.
remove
()
#######################################################################
# Inception V3
...
...
@@ -412,10 +410,9 @@ def test_forward_mobilenet():
def
test_forward_inception_v3_net
():
'''test inception v3 tflite model'''
# InceptionV3
temp
=
util
.
tempdir
()
tflite_model_file
=
tf_testing
.
get_workload_official
(
"https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/inception_v3_2018_04_27.tgz"
,
"inception_v3.tflite"
,
temp
)
"inception_v3.tflite"
)
with
open
(
tflite_model_file
,
"rb"
)
as
f
:
tflite_model_buf
=
f
.
read
()
data
=
np
.
random
.
uniform
(
size
=
(
1
,
299
,
299
,
3
))
.
astype
(
'float32'
)
...
...
@@ -424,7 +421,6 @@ def test_forward_inception_v3_net():
tvm_output
=
run_tvm_graph
(
tflite_model_buf
,
tvm_data
,
'input'
)
tvm
.
testing
.
assert_allclose
(
np
.
squeeze
(
tvm_output
[
0
]),
np
.
squeeze
(
tflite_output
[
0
]),
rtol
=
1e-5
,
atol
=
1e-5
)
temp
.
remove
()
#######################################################################
# Main
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment