Commit baa04599 by Thierry Moreau Committed by GitHub

[CONTRIB] TVM download utility based on urllib2/urlib.request (#1313)

moving nnvm/testing/download.py to python/tvm/contrib/download.py to be used as a general TVM download utility
parent 61370e4b
# pylint: disable=invalid-name, no-member, import-error, no-name-in-module, global-variable-undefined, bare-except
"""Helper utility for downloading"""
from __future__ import print_function
from __future__ import absolute_import as _abs
......@@ -6,7 +5,6 @@ from __future__ import absolute_import as _abs
import os
import sys
import time
import urllib
import requests
if sys.version_info >= (3,):
......@@ -14,21 +12,6 @@ if sys.version_info >= (3,):
else:
import urllib2
def _download_progress(count, block_size, total_size):
"""Show the download progress.
"""
global start_time
if count == 0:
start_time = time.time()
return
duration = time.time() - start_time
progress_size = int(count * block_size)
speed = int(progress_size / (1024 * duration))
percent = int(count * block_size * 100 / total_size)
sys.stdout.write("\r...%d%%, %d MB, %d KB/s, %d seconds passed" %
(percent, progress_size / (1024 * 1024), speed, duration))
sys.stdout.flush()
def download(url, path, overwrite=False, size_compare=False):
"""Downloads the file from the internet.
Set the input options correctly to overwrite or do the size comparison
......@@ -62,8 +45,29 @@ def download(url, path, overwrite=False, size_compare=False):
print('File {} exists, skip.'.format(path))
return
print('Downloading from url {} to {}'.format(url, path))
try:
urllib.request.urlretrieve(url, path, reporthook=_download_progress)
print('')
except:
urllib.urlretrieve(url, path, reporthook=_download_progress)
# Stateful start time
start_time = time.time()
def _download_progress(count, block_size, total_size):
#pylint: disable=unused-argument
"""Show the download progress.
"""
if count == 0:
return
duration = time.time() - start_time
progress_size = int(count * block_size)
speed = int(progress_size / (1024 * duration))
percent = min(int(count * block_size * 100 / total_size), 100)
sys.stdout.write("\r...%d%%, %d MB, %d KB/s, %d seconds passed" %
(percent, progress_size / (1024 * 1024), speed, duration))
sys.stdout.flush()
if sys.version_info >= (3,):
urllib2.urlretrieve(url, path, reporthook=_download_progress)
print("")
else:
f = urllib2.urlopen(url)
data = f.read()
with open(path, "wb") as code:
code.write(data)
......@@ -16,7 +16,7 @@ import numpy as np
from nnvm import compiler
from nnvm.frontend import from_mxnet
from nnvm.testing.download import download
from tvm.contrib.download import download
from tvm.contrib import graph_runtime
from mxnet.model import load_checkpoint
......
......@@ -24,7 +24,7 @@ import tvm
import os
from ctypes import *
from nnvm.testing.download import download
from tvm.contrib.download import download
from nnvm.testing.darknet import __darknetffi__
######################################################################
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment