# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""Helper utility for downloading"""
import os
import sys
import time
import uuid
import shutil
import urllib.request


def _remote_file_size(url):
    """Return the size in bytes that the server reports for ``url``.

    A HEAD request decides which client performs the follow-up GET: if the
    HEAD response carries ``Content-Length`` the size is read from a streamed
    ``requests`` GET, otherwise from ``urllib.request.urlopen``. The response
    is always closed so no connection is left dangling.

    Raises
    ------
    KeyError
        If the chosen response has no ``Content-Length`` header.
    """
    # pylint: disable=import-outside-toplevel
    import requests

    res_head = requests.head(url)
    if 'Content-Length' in res_head.headers:
        # stream=True: only the headers are needed, never the body.
        response = requests.get(url, stream=True)
    else:
        response = urllib.request.urlopen(url)
    try:
        return int(response.headers['Content-Length'])
    finally:
        response.close()


def download(url, path, overwrite=False, size_compare=False, verbose=1, retries=3):
    """Downloads the file from the internet.
    Set the input options correctly to overwrite or do the size comparison

    The data is first written to a uniquely named temporary file next to
    ``path`` and moved into place only after the transfer succeeded, so a
    failed download never leaves a partial file at ``path``.

    Parameters
    ----------
    url : str
        Download url.

    path : str
        Local file path to save downloaded file

    overwrite : bool, optional
        Whether to overwrite existing file

    size_compare : bool, optional
        Whether to do size compare to check downloaded file.
        Only consulted when ``path`` already exists and ``overwrite`` is
        False; on a size mismatch the file is downloaded again.

    verbose: int, optional
        Verbose level

    retries: int, optional
        Number of time to retry download, default at 3.
        At least one attempt is always made; values below 1 behave like 1.

    Raises
    ------
    Exception
        Whatever the last failed attempt raised, once no attempts are left.
    """
    if os.path.isfile(path) and not overwrite:
        if not size_compare or _remote_file_size(url) == os.path.getsize(path):
            print('File {} exists, skip.'.format(path))
            return
        # Size mismatch: fall through and fetch a fresh copy, keeping the
        # caller's verbose/retries settings.
        print("exist file got corrupted, downloading %s file freshly..." % path)

    if verbose >= 1:
        print('Downloading from url {} to {}'.format(url, path))

    # Stateful start time
    start_time = time.time()
    dirpath = os.path.dirname(path)
    if dirpath and not os.path.isdir(dirpath):
        os.makedirs(dirpath)
    temp_path = os.path.join(dirpath, str(uuid.uuid4()))

    def _download_progress(count, block_size, total_size):
        """Show the download progress."""
        if count == 0:
            return
        duration = time.time() - start_time
        progress_size = int(count * block_size)
        # A coarse clock can report zero elapsed time on the first block.
        speed = int(progress_size / (1024 * duration)) if duration > 0 else 0
        if total_size > 0:
            percent = min(int(progress_size * 100 / total_size), 100)
            sys.stdout.write("\r...%d%%, %.2f MB, %d KB/s, %d seconds passed" %
                             (percent, progress_size / (1024.0 * 1024), speed, duration))
        else:
            # urlretrieve reports -1 when the server sends no Content-Length,
            # so no percentage can be computed.
            sys.stdout.write("\r...%.2f MB, %d KB/s, %d seconds passed" %
                             (progress_size / (1024.0 * 1024), speed, duration))
        sys.stdout.flush()

    attempts_left = retries
    while True:
        # Any failure (network, filesystem, ...) is retried.
        # pylint: disable=broad-except
        try:
            urllib.request.urlretrieve(url, temp_path, reporthook=_download_progress)
            print("")
            shutil.move(temp_path, path)
            break
        except Exception as err:
            attempts_left -= 1
            # "<= 0" rather than "== 0": with retries=0 the counter goes
            # straight to -1 and the error must still be propagated.
            if attempts_left <= 0:
                if os.path.exists(temp_path):
                    os.remove(temp_path)
                raise
            print("download failed due to {}, retrying, {} attempt{} left"
                  .format(repr(err), attempts_left, 's' if attempts_left > 1 else ''))


# Root directory for cached test data: taken from the TEST_DATA_ROOT_PATH
# environment variable when set, otherwise ~/.tvm_test_data. It is created
# at import time.
TEST_DATA_ROOT_PATH = os.environ.get(
    "TEST_DATA_ROOT_PATH",
    os.path.join(os.path.expanduser('~'), '.tvm_test_data'))
os.makedirs(TEST_DATA_ROOT_PATH, exist_ok=True)


def download_testdata(url, relpath, module=None):
    """Downloads the test data from the internet.

    The file is stored under ``TEST_DATA_ROOT_PATH/<module>/<relpath>`` and
    is not downloaded again if it already exists there.

    Parameters
    ----------
    url : str
        Download url.

    relpath : str
        Relative file path.

    module : Union[str, list, tuple], optional
        Subdirectory paths under test data folder.

    Returns
    -------
    abspath : str
        Absolute file path of downloaded file

    Raises
    ------
    ValueError
        If ``module`` is not None, a str, a list or a tuple.
    """
    if module is None:
        module_path = ''
    elif isinstance(module, str):
        module_path = module
    elif isinstance(module, (list, tuple)):
        module_path = os.path.join(*module)
    else:
        # Format instead of concatenating: ``module`` is by definition not a
        # str here, so "+" would raise TypeError instead of this ValueError.
        raise ValueError("Unsupported module: {}".format(module))
    abspath = os.path.join(TEST_DATA_ROOT_PATH, module_path, relpath)
    download(url, abspath, overwrite=False, size_compare=False)
    return abspath