# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""
TopHub: Tensor Operator Hub
To get the best performance, we typically need auto-tuning for the specific devices.
TVM releases pre-tuned parameters in TopHub for some common networks and hardware targets.
21
TVM will download these parameters for you when you call relay.build.
22
"""
# pylint: disable=invalid-name

import logging
import os
import sys

from .task import ApplyHistoryBest
from .. import target as _target
from ..contrib.download import download
from .record import load_from_file
from .util import EmptyContext

# environment variable to read TopHub location
AUTOTVM_TOPHUB_LOC_VAR = "TOPHUB_LOCATION"

# default location of TopHub
AUTOTVM_TOPHUB_DEFAULT_LOC = "https://raw.githubusercontent.com/uwsampl/tvm-distro/master/tophub"

# value of AUTOTVM_TOPHUB_LOC_VAR that disables reading from TopHub
AUTOTVM_TOPHUB_NONE_LOC = "NONE"

# root path to store TopHub files
AUTOTVM_TOPHUB_ROOT_PATH = os.path.join(os.path.expanduser('~'), ".tvm", "tophub")

# the version of each package
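# (the version is embedded in the log file name, so bumping it triggers a fresh download)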
PACKAGE_VERSION = {
    'arm_cpu':          "v0.06",
    'llvm':             "v0.04",

    'cuda':             "v0.08",
    'rocm':             "v0.04",
    'opencl':           "v0.04",
    'mali':             "v0.06",
    'intel_graphics':   "v0.02",

    'vta':              "v0.08",
}

logger = logging.getLogger('autotvm')

def _alias(name):
    """Convert a package alias to its canonical name."""
    table = {
        'vtacpu': 'vta',

        'metal': 'opencl',
        'vulkan': 'opencl',
        'nvptx': 'cuda',
    }
    return table.get(name, name)

def _get_tophub_location():
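    """Return the TopHub location, honoring the environment variable
    named by AUTOTVM_TOPHUB_LOC_VAR."""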
    location = os.getenv(AUTOTVM_TOPHUB_LOC_VAR, None)
    return AUTOTVM_TOPHUB_DEFAULT_LOC if location is None else location

def context(target, extra_files=None):
    """Return the dispatch context with pre-tuned parameters.
    This function will load the corresponding *.log files in AUTOTVM_TOPHUB_ROOT_PATH.
    If they cannot be found there, it will download them from the TopHub github repo.
    Users can also add their own files via the argument `extra_files`.

    Parameters
    ----------
    target: Target or List of Target
        The compilation target
    extra_files: list of str, optional
        Extra log files to load
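
    Examples
    --------
    A minimal sketch (`mod` and `params` are assumed to be an existing relay
    module and its parameters; the exact relay.build signature may differ
    across TVM versions)::

        with autotvm.tophub.context("cuda"):
            graph, lib, params = relay.build(mod, target="cuda", params=params)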
    """
    tophub_location = _get_tophub_location()
    if tophub_location == AUTOTVM_TOPHUB_NONE_LOC:
        return EmptyContext()

    best_context = ApplyHistoryBest([])

    targets = target if isinstance(target, (list, tuple)) else [target]
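
    # for each target, try device-specific packages (from the -device option)
    # before falling back to the package named after the target itself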

    for tgt in targets:
        if isinstance(tgt, str):
            tgt = _target.create(tgt)

        possible_names = []
        for opt in tgt.options:
            if opt.startswith("-device"):
                device = _alias(opt[8:])
                possible_names.append(device)
        possible_names.append(tgt.target_name)

        all_packages = list(PACKAGE_VERSION.keys())
        for name in possible_names:
            name = _alias(name)
            if name in all_packages:
                if not check_backend(tophub_location, name):
                    continue

                filename = "%s_%s.log" % (name, PACKAGE_VERSION[name])
                best_context.load(os.path.join(AUTOTVM_TOPHUB_ROOT_PATH, filename))
                break  # only load one file to avoid fallback template mismatch problems

    if extra_files:
        for filename in extra_files:
            best_context.load(filename)

    return best_context


def check_backend(tophub_location, backend):
    """Check whether pre-tuned parameters exist for the given backend.
    If not, download them.

    Parameters
    ----------
    tophub_location: str
        The location to download TopHub parameters from
    backend: str
        The name of the backend.

    Returns
    -------
    success: bool
        Whether the check is successful.
    """
    backend = _alias(backend)
    assert backend in PACKAGE_VERSION, 'Cannot find backend "%s" in TopHub' % backend
    version = PACKAGE_VERSION[backend]
    package_name = "%s_%s.log" % (backend, version)
    if os.path.isfile(os.path.join(AUTOTVM_TOPHUB_ROOT_PATH, package_name)):
        return True

    # pylint: disable=import-outside-toplevel
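    # pick the module that provides URLError (urllib.request on Python 3, urllib2 on Python 2)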
    if sys.version_info >= (3,):
        import urllib.request as urllib2
    else:
        import urllib2
    try:
        download_package(tophub_location, package_name)
        return True
    except urllib2.URLError as e:
        logger.warning("Failed to download tophub package for %s: %s", backend, e)
        return False


def download_package(tophub_location, package_name):
    """Download pre-tuned parameters of operators for a backend

    Parameters
    ----------
    tophub_location: str
        The location to download TopHub parameters from
    package_name: str
        The name of the package
    """
    rootpath = AUTOTVM_TOPHUB_ROOT_PATH

    if not os.path.isdir(rootpath):
        # create the directory and any missing parents
        os.makedirs(rootpath)

    download_url = "{0}/{1}".format(tophub_location, package_name)
    logger.info("Download pre-tuned parameters package from %s", download_url)
    download(download_url, os.path.join(rootpath, package_name), True, verbose=0)


# global cache for load_reference_log
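# keyed by (backend, model, workload_name)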
REFERENCE_LOG_CACHE = {}

def load_reference_log(backend, model, workload_name):
    """ Load reference log from TopHub to support fallback in template.
    Template will use these reference logs to choose fallback config.

    Parameters
    ----------
    backend: str
        The backend name
    model: str
201
        The name of the device model
202 203 204 205 206 207 208 209 210 211
    workload_name: str
        The name of the workload. (The first item in the workload tuple)
    """

    backend = _alias(backend)
    version = PACKAGE_VERSION[backend]
    package_name = "%s_%s.log" % (backend, version)
    filename = os.path.join(AUTOTVM_TOPHUB_ROOT_PATH, package_name)

    global REFERENCE_LOG_CACHE
    key = (backend, model, workload_name)

    if key not in REFERENCE_LOG_CACHE:
        tmp = []
        # Download the config file from tophub if it does not exist.
        if not os.path.exists(filename):
            tophub_location = _get_tophub_location()
            download_package(tophub_location, package_name)
        if os.path.isfile(filename): # in case download failed
            found = False
            counts = {}
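            # first pass: count records per device model and check whether
            # the requested model appears in the log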
            for inp, _ in load_from_file(filename):
                counts[inp.target.model] = counts.get(inp.target.model, 0) + 1
                if model == inp.target.model:
                    found = True
                    break
            # if the device model is not found, use the device model with the most tuned workloads
            if not found and counts:
                model = max(counts.items(), key=lambda k: k[1])[0]
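
            # second pass: collect all records matching the chosen model and workload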

            for inp, res in load_from_file(filename):
                if model == inp.target.model and inp.task.workload[0] == workload_name:
                    tmp.append((inp, res))
        REFERENCE_LOG_CACHE[key] = tmp

    return REFERENCE_LOG_CACHE[key]