"""
Template dispatcher module.

A dispatcher is a function that can contain multiple behaviors.
Its specific behavior is controlled by a DispatchContext.

DispatchContext is used in two ways, usually via different implementations
of the DispatchContext base class.

- During search, we can use it to pass the current proposal from the tuner.
- During evaluation, we can use it to pick the best policy.
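
Example
-------
A minimal sketch of evaluation-time usage (the log file name is
hypothetical, and ``tvm`` is assumed to be imported by the caller)::

    with ApplyHistoryBest('tuning.log'):
        with tvm.target.create('llvm'):
            ...  # template queries now return the best tuned configs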
"""
# pylint: disable=invalid-name

from __future__ import absolute_import as _abs

import logging

import numpy as np
from decorator import decorate

from tvm import target as _target

from .space import FallbackConfigEntity

logger = logging.getLogger('autotvm')


class DispatchContext(object):
    """
    Base class of dispatch context.

    DispatchContext enables the target- and workload-specific
    dispatch mechanism for templates.
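
    Example
    -------
    A minimal custom context (a sketch; ``my_cfg`` stands for a
    ConfigSpace/ConfigEntity created elsewhere)::

        class ApplyFixed(DispatchContext):
            def _query_inside(self, target, workload):
                return my_cfg  # answer every query with one config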
    """
    current = None

    def __init__(self):
        self._old_ctx = DispatchContext.current

    def query(self, target, workload):
        """
        Query the context to get the specific config for a template.
        If the result cannot be found in this context, the query is
        forwarded to the enclosing (parent) context.

        Parameters
        ----------
        target : Target
            The current target.
        workload : Workload
            The current workload.

        Returns
        -------
        cfg : ConfigSpace
            The specific configuration.
        """
        ret = self._query_inside(target, workload)
        if ret is None:
            ret = self._old_ctx.query(target, workload)
        return ret

    def _query_inside(self, target, workload):
        """
        Query the context to get the specific config for a template.
        This function only queries the config inside this context.

        Parameters
        ----------
        target : Target
            The current target.
        workload : Workload
            The current workload.

        Returns
        -------
        cfg : ConfigSpace
            The specific configuration.
        """
        raise NotImplementedError()

    def __enter__(self):
        self._old_ctx = DispatchContext.current
        DispatchContext.current = self
        return self

    def __exit__(self, ptype, value, trace):
        DispatchContext.current = self._old_ctx


def dispatcher(fworkload):
    """Wrap a workload dispatcher function.

    Parameters
    ----------
    fworkload : function
        The function that extracts the workload from the call arguments.

    Returns
    -------
    fdispatcher : function
        A wrapped dispatcher function, which will
        dispatch based on DispatchContext and
        the current workload.
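
    Examples
    --------
    A sketch of the intended usage (``conv2d`` and the template bodies
    here are hypothetical)::

        @dispatcher
        def conv2d(data, kernel):
            return ('conv2d', data.shape, kernel.shape)  # the workload

        @conv2d.register('direct')
        def conv2d_direct(cfg, data, kernel):
            ...  # build a schedule using cfg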
    """
    dispatch_dict = {}
    func_name = fworkload.__name__

    def register(key, func=None, override=False):
        """Register template function.

        Parameters
        ----------
        key : str or List of str
            The template key to identify the template
            under this dispatcher.
        func : function
            The function to be registered.
            The first argument of the function is always
            cfg returned by DispatchContext,
            the rest arguments are the same as the fworkload.
        override : bool
            Whether to override an existing registration.

        Returns
        -------
        The registration decorator if ``func`` is not given,
        otherwise ``func`` itself after registration.
        """
        if isinstance(key, str):
            key = [key]

        def _do_reg(myf):
            for x in key:
                if x in dispatch_dict and not override:
                    raise ValueError(
                        "Key %s is already registered for %s" % (x, func_name))
                dispatch_dict[x] = myf
            return myf

        if func:
            return _do_reg(func)
        return _do_reg

    def dispatch_func(func, *args, **kwargs):
        """The wrapped dispatch function"""
        tgt = _target.current_target()
        workload = func(*args, **kwargs)
        cfg = DispatchContext.current.query(tgt, workload)
        if cfg.is_fallback and not cfg.template_key:
            # first try 'direct' template
            if 'direct' in dispatch_dict:
                return dispatch_dict['direct'](cfg, *args, **kwargs)
            # otherwise, fall back to an arbitrary registered template
            for v in dispatch_dict.values():
                return v(cfg, *args, **kwargs)
        else:
            return dispatch_dict[cfg.template_key](cfg, *args, **kwargs)

    fdecorate = decorate(fworkload, dispatch_func)
    fdecorate.register = register
    return fdecorate


class ApplyConfig(DispatchContext):
    """Apply a deterministic config entity for all queries.

    Parameters
    ----------
    config : ConfigSpace or ConfigEntity
        The specific configuration we care about.
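
    Example
    -------
    A sketch (``task`` stands for a previously created autotvm task)::

        with ApplyConfig(task.config_space.get(0)):
            ...  # every query returns this config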
    """
    def __init__(self, config):
        super(ApplyConfig, self).__init__()
        self._config = config
        self.workload = None

    def _query_inside(self, target, workload):
        """Override query"""
        self.workload = workload
        return self._config


class ApplyHistoryBest(DispatchContext):
    """
    Apply the history best config

    Parameters
    ----------
    records : str or iterator of (MeasureInput, MeasureResult)
        Collection of tuning records.
        If it is a str, it should be the filename of a records log file.
        Each row of this file is an encoded record pair.
        Otherwise, it is an iterator over record pairs.
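
    Example
    -------
    A sketch, assuming ``conv2d_tuning.log`` was produced by a previous
    tuning run::

        with ApplyHistoryBest('conv2d_tuning.log'):
            with tvm.target.create('cuda'):
                ...  # queries return the best measured configs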
    """
    def __init__(self, records):
        super(ApplyHistoryBest, self).__init__()

        self.best_by_targetkey = {}
        self.best_by_model = {}

        if records:
            self.load(records)

    def load(self, records):
        """Load records to this dispatch context

        Parameters
        ----------
        records : str or iterator of (MeasureInput, MeasureResult)
            Collection of tuning records.
            If it is a str, it should be the filename of a records log file.
            Each row of this file is an encoded record pair.
            Otherwise, it is an iterator over record pairs.
        """
        from ..record import load_from_file

        if isinstance(records, str):
            records = load_from_file(records)
        if not records:
            return

        best_by_targetkey = self.best_by_targetkey
        best_by_model = self.best_by_model

        counter = 0
        for inp, res in records:
            counter += 1
            if res.error_no != 0:
                continue

            # use target keys in tvm target system as key to build best map
            for k in inp.target.keys:
                key = (k, inp.task.workload)
                if key not in best_by_targetkey:
                    best_by_targetkey[key] = (inp, res)
                else:
                    _, other_res = best_by_targetkey[key]
                    if np.mean(other_res.costs) > np.mean(res.costs):
                        best_by_targetkey[key] = (inp, res)

            # use model as key to build best map
            for opt in inp.target.options:
                if opt.startswith("-model"):
                    model = opt[7:]
                    key = (model, inp.task.workload)
                    if key not in best_by_model:
                        best_by_model[key] = (inp, res)
                    else:
                        _, other_res = best_by_model[key]
                        if np.mean(other_res.costs) > np.mean(res.costs):
                            best_by_model[key] = (inp, res)
                    break

        logger.debug("Finish loading %d records", counter)

    def _query_inside(self, target, workload):
        if target is None:
            raise RuntimeError("Need a target context to find the history best. "
                               "Hint: If your target is llvm, use `with tvm.target.create('llvm'):` "
                               "above the dispatcher call. Do the same for other targets.")

        # first try matching by model
        for opt in target.options:
            if opt.startswith("-model"):
                model = opt[7:]
                key = (model, workload)
                if key in self.best_by_model:
                    return self.best_by_model[key][0].config

        # then try matching by target key
        for k in target.keys:
            key = (k, workload)
            if key in self.best_by_targetkey:
                return self.best_by_targetkey[key][0].config

        return None


class FallbackContext(DispatchContext):
    """
    A fallback dispatch context.

    Any tunable template can be called under this context.
    This is the root context.
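
    Example
    -------
    A sketch: as the root context, a query with no tuning information
    yields a ``FallbackConfigEntity`` (``target`` and ``workload`` are
    stand-ins for real objects)::

        cfg = DispatchContext.current.query(target, workload)
        assert cfg.is_fallback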
    """

    def __init__(self):
        super(FallbackContext, self).__init__()
        self.memory = {}
        self.silent = False

        # a set to prevent printing duplicated messages
        self.messages = set()

    def _query_inside(self, target, workload):
        key = (str(target), workload)
        if key in self.memory:
            return self.memory[key]

        if not self.silent:
            msg = "Cannot find config for target=%s, workload=%s. A fallback configuration "\
                  "is used, which may bring great performance regression." % (target, workload)
            if msg not in self.messages:
                self.messages.add(msg)
                logger.warning(msg)
        cfg = FallbackConfigEntity()

        # cache this config
        self.memory[key] = cfg
        return cfg

    def clear_cache(self, target, workload):
        """Clear fallback cache. Pass the same argument as _query_inside to this function
        to clean the cache.

        Parameters
        ----------
        target : Target
            The current target.
        workload : Workload
            The current workload.
        """
        key = (str(target), workload)
        if key in self.memory:
            del self.memory[key]


DispatchContext.current = FallbackContext()


def clear_fallback_cache(target, workload):
    """Clear fallback cache. Pass the same argument as _query_inside to this function
    to clean the cache.

    Parameters
    ----------
    target : Target
        The current target.
    workload : Workload
        The current workload.

    Note
    ----
    This is used in alter_op_layout to clear the invalid cache entries created
    before the topi compute function is called.
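
    Example
    -------
    A sketch (``target`` and ``workload`` are the same objects used by an
    earlier query that produced an unwanted fallback entry)::

        clear_fallback_cache(target, workload)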
    """
    context = DispatchContext.current
    while not isinstance(context, FallbackContext):
        context = context._old_ctx
    context.clear_cache(target, workload)