download.py 5.09 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
Yao Wang committed
17 18 19 20 21 22 23
"""Helper utility for downloading"""
from __future__ import print_function
from __future__ import absolute_import as _abs

import os
import sys
import time
24 25
import uuid
import shutil
Yao Wang committed
26

27
def _remote_file_size(url):
    """Return the size in bytes the server advertises for *url*.

    The ``Content-Length`` header of a streamed ``requests`` GET (which follows
    redirects) is used first. If that response carries no such header, the
    header reported by ``urllib`` (the client that performs the actual
    download in ``download``) is tried instead.

    Parameters
    ----------
    url : str
        Remote file url.

    Returns
    -------
    size : int or None
        The advertised size, or None when neither response reports one.
    """
    import urllib.request
    import requests

    res = requests.get(url, stream=True)
    try:
        length = res.headers.get('Content-Length')
    finally:
        # stream=True defers reading the body; close to release the connection.
        res.close()
    if length is None:
        with urllib.request.urlopen(url) as fallback:
            length = fallback.headers.get('Content-Length')
    return int(length) if length is not None else None


def download(url, path, overwrite=False, size_compare=False, verbose=1, retries=3):
    """Downloads the file from the internet.
    Set the input options correctly to overwrite or do the size comparison

    Parameters
    ----------
    url : str
        Download url.

    path : str
        Local file path to save downloaded file

    overwrite : bool, optional
        Whether to overwrite existing file

    size_compare : bool, optional
        Whether to compare the size of an existing file with the size the
        server reports, and download a fresh copy on mismatch. The comparison
        is skipped when the server reports no size.

    verbose: int, optional
        Verbose level. The download announcement and progress bar are shown
        only when ``verbose >= 1``.

    retries: int, optional
        Total number of download attempts, default at 3. Values below 1 still
        make a single attempt.

    Raises
    ------
    Exception
        The error of the last attempt is re-raised once every attempt failed.
    """
    import urllib.request

    if os.path.isfile(path) and not overwrite:
        if size_compare:
            file_size = os.path.getsize(path)
            url_file_size = _remote_file_size(url)
            if url_file_size is not None and url_file_size != file_size:
                print("exist file got corrupted, downloading %s file freshly..." % path)
                # Forward the caller's settings to the fresh download.
                download(url, path, overwrite=True, size_compare=False,
                         verbose=verbose, retries=retries)
                return
        print('File {} exists, skip.'.format(path))
        return

    if verbose >= 1:
        print('Downloading from url {} to {}'.format(url, path))

    # Stateful start time, read by the progress hook below.
    start_time = time.time()
    dirpath = os.path.dirname(path)
    if dirpath and not os.path.isdir(dirpath):
        os.makedirs(dirpath, exist_ok=True)
    # Download into a uniquely named file next to ``path`` and move it into
    # place only on success, so a partial download never appears at ``path``.
    tmp_path = os.path.join(dirpath, str(uuid.uuid4()))

    def _download_progress(count, block_size, total_size):
        """Show the download progress."""
        if count == 0:
            return
        duration = time.time() - start_time
        progress_size = int(count * block_size)
        # The elapsed time can be 0 with a coarse clock, and urlretrieve
        # passes total_size == -1 when the size is unknown.
        speed = int(progress_size / (1024 * duration)) if duration > 0 else 0
        if total_size > 0:
            percent = min(int(progress_size * 100 / total_size), 100)
        else:
            percent = 0
        sys.stdout.write("\r...%d%%, %.2f MB, %d KB/s, %d seconds passed" %
                         (percent, progress_size / (1024.0 * 1024), speed, duration))
        sys.stdout.flush()

    reporthook = _download_progress if verbose >= 1 else None
    attempts = max(retries, 1)
    try:
        for attempt in range(1, attempts + 1):
            try:
                urllib.request.urlretrieve(url, tmp_path, reporthook=reporthook)
                if reporthook is not None:
                    # Terminate the carriage-return progress line.
                    print("")
                shutil.move(tmp_path, path)
                return
            except Exception as err:  # pylint: disable=broad-except
                left = attempts - attempt
                if left == 0:
                    raise
                print("download failed due to {}, retrying, {} attempt{} left"
                      .format(repr(err), left, 's' if left > 1 else ''))
    finally:
        # After a successful move the temp file is gone; otherwise remove the
        # partial download, whatever caused the failure.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)


def _default_test_data_root():
    """Return the per-user test data folder, ``.tvm_test_data`` under home."""
    home_dir = os.path.expanduser('~')
    return os.path.join(home_dir, '.tvm_test_data')


TEST_DATA_ROOT_PATH = _default_test_data_root()
# Created at import time; a no-op when the folder already exists.
os.makedirs(TEST_DATA_ROOT_PATH, exist_ok=True)

125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151

def download_testdata(url, relpath, module=None):
    """Downloads the test data from the internet.

    Parameters
    ----------
    url : str
        Download url.

    relpath : str
        Relative file path.

    module : Union[str, list, tuple], optional
        Subdirectory paths under test data folder. A list or tuple is joined
        into a nested path; None or an empty list/tuple means the test data
        folder itself.

    Returns
    -------
    abspath : str
        Absolute file path of downloaded file

    Raises
    ------
    ValueError
        If ``module`` is not None, a str, a list or a tuple.
    """
    if module is None:
        module_path = ''
    elif isinstance(module, str):
        module_path = module
    elif isinstance(module, (list, tuple)):
        # os.path.join() needs at least one argument, so map () / [] to ''.
        module_path = os.path.join(*module) if module else ''
    else:
        # Format instead of concatenating: ``module`` is never a str here, so
        # ``"..." + module`` would raise TypeError rather than ValueError.
        raise ValueError("Unsupported module: {}".format(module))
    abspath = os.path.join(TEST_DATA_ROOT_PATH, module_path, relpath)
    download(url, abspath, overwrite=False, size_compare=True)
    return abspath