"""
Construct the necessary state for the TVM graph runtime
from a Relay expression.
"""
import warnings

from tvm._ffi.runtime_ctypes import TVMContext
from ..build_module import build as _tvm_build_module
from .. import nd as _nd, target as _target, autotvm
from ..contrib import graph_runtime as _graph_rt
from . import ir_pass
from . import expr
from .backend import interpreter as _interpreter
from .backend import graph_runtime_codegen as _graph_gen

# Map from optimization pass name to the opt_level at which the pass is enabled.
OPT_PASS_LEVEL = {
    "SimplifyInference": 0,
    "OpFusion": 1,
    "FoldConstant": 2,
    "CombineParallelConv2D": 3,
    "FoldScaleAxis": 3,
    "AlterOpLayout": 3,
    "CanonicalizeOps": 3,
}


class BuildConfig(object):
    """Configuration scope to set a build config option.

    Parameters
    ----------
    kwargs
        Keyword arguments of configurations to set.
    """
    current = None
    defaults = {
        "opt_level": 2,
        "add_pass": None,
        "fallback_device": None,
    }

    def __init__(self, **kwargs):
        self._old_scope = None
        for k, _ in kwargs.items():
            if k not in BuildConfig.defaults:
                raise ValueError("invalid argument %s, candidates are %s" %
                                 (k, BuildConfig.defaults.keys()))
        self._attr = kwargs

    def __getattr__(self, name):
        if name not in self._attr:
            return BuildConfig.defaults[name]
        return self._attr[name]

    def __enter__(self):
        # pylint: disable=protected-access
        self._old_scope = BuildConfig.current
        attr = BuildConfig.current._attr.copy()
        attr.update(self._attr)
        self._attr = attr
        BuildConfig.current = self
        return self

    def __exit__(self, ptype, value, trace):
        assert self._old_scope
        BuildConfig.current = self._old_scope

    def pass_enabled(self, pass_name):
        """Get whether pass is enabled.

        Parameters
        ----------
        pass_name : str
            The optimization pass name

        Returns
        -------
        enabled : bool
            Whether pass is enabled.
        """
        if self.add_pass and pass_name in self.add_pass:
            return True
        return self.opt_level >= OPT_PASS_LEVEL[pass_name]


BuildConfig.current = BuildConfig()


def build_config(**kwargs):
    """Configure the build behavior by setting config variables.

    Parameters
    ----------
    opt_level: int, default=2
        Optimization level. See OPT_PASS_LEVEL for level of each pass.

    add_pass: set of str
        Optimization pass to be added regardless of optimization level.

    fallback_device : str or tvm.TVMContext
        The fallback device. It is also used as the default device for
        operators without specified device during heterogeneous execution.

    Returns
    -------
    config: BuildConfig
        The build configuration
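
    Examples
    --------
    A minimal usage sketch (assuming this module's helpers are re-exported as
    ``tvm.relay`` and a Relay function ``func`` is already defined):

    .. code-block:: python

        with relay.build_config(opt_level=2, add_pass={"AlterOpLayout"}):
            graph_json, lib, params = relay.build(func, target="llvm")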
    """
    return BuildConfig(**kwargs)


def _bind_params_by_name(func, params):
    """Bind parameters of function by its name."""
    name_dict = {}
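    # Map each parameter name to its Var. A name that occurs more than once is
    # recorded as None so that binding it below raises an explicit error.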
    for arg in func.params:
        name = arg.name_hint
        if name in name_dict:
            name_dict[name] = None
        else:
            name_dict[name] = arg
    bind_dict = {}
    for k, v in params.items():
        if k not in name_dict:
            continue
        arg = name_dict[k]
        if arg is None:
            raise ValueError("Multiple args in the function have name %s" % k)
        bind_dict[arg] = expr.const(v)
    return expr.bind(func, bind_dict)


def optimize(func, target=None, params=None):
    """Perform target invariant optimizations.

    Parameters
    ----------
    func : tvm.relay.Function
        The input to optimization.

    target : Optional[Union[tvm.target.Target, Dict[int, tvm.target.Target]]]
        The optimization target. For heterogeneous compilation, it is a
        dictionary mapping device type to compilation target. For homogeneous
        compilation, it is a build target.

    params : Optional[Dict[str, tvm.nd.NDArray]]
        Input parameters to the graph that do not change
        during inference time. Used for constant folding.

    Returns
    -------
    opt_func : tvm.relay.Function
        The optimized version of the function.
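
    Examples
    --------
    A minimal sketch (assuming a Relay function ``func`` is already defined
    and ``relay`` refers to ``tvm.relay``):

    .. code-block:: python

        from tvm import target

        tgt = target.create("llvm")
        with relay.build_config(opt_level=2):
            opt_func = relay.build_module.optimize(func, tgt)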
    """
    cfg = BuildConfig.current

    # bind expressions
    if params:
        func = _bind_params_by_name(func, params)

    if cfg.pass_enabled("SimplifyInference"):
        func = ir_pass.infer_type(func)
        func = ir_pass.simplify_inference(func)

    if cfg.pass_enabled("CombineParallelConv2D"):
        func = ir_pass.infer_type(func)
        func = ir_pass.combine_parallel_conv2d(func)

    # The constant folding pass is necessary because the FoldScaleAxis pass
    # needs to check whether the scales are constant and positive.
    if cfg.pass_enabled("FoldConstant"):
        func = ir_pass.fold_constant(func)

    if cfg.pass_enabled("FoldScaleAxis"):
        func = ir_pass.infer_type(func)
        func = ir_pass.backward_fold_scale_axis(func)
        func = ir_pass.infer_type(func)
        func = ir_pass.forward_fold_scale_axis(func)
        func = ir_pass.fold_constant(func)

    if cfg.pass_enabled("CanonicalizeOps"):
        func = ir_pass.infer_type(func)
        func = ir_pass.canonicalize_ops(func)

    # FIXME(zhiics) Skip AlterOpLayout pass for heterogeneous compilation for
    # now. We probably need to pass target to this pass as well. Fix it in
    # a followup PR.
    if cfg.pass_enabled("AlterOpLayout"):
        if isinstance(target, _target.Target):
            func = ir_pass.infer_type(func)
            with target:
                func = ir_pass.alter_op_layout(func)
        elif isinstance(target, dict):
            warnings.warn("AlterOpLayout pass is not enabled for heterogeneous"
                          " execution yet.")

    if cfg.pass_enabled("FoldConstant"):
        func = ir_pass.fold_constant(func)

    return func


def build(func, target=None, target_host=None, params=None):
    """Build a function to run on TVM graph runtime.

    Parameters
    ----------
    func : relay.Function
        The function to build.

    target : str, :any:`tvm.target.Target`, or dict of str (i.e. device/context
    name) to str/tvm.target.Target, optional
        For heterogeneous compilation, it is a dictionary indicating context to
        target mapping. For homogeneous compilation, it is a build target.

    target_host : str or :any:`tvm.target.Target`, optional
        Host compilation target, if the target is a device target.
        When TVM compiles device-specific programs such as CUDA,
        we also need host (CPU) side code to interact with the driver
        and set up the dimensions and parameters correctly.
        target_host is used to specify the host-side codegen target.
        By default, llvm is used if it is enabled;
        otherwise a stackvm interpreter is used.

    params : dict of str to NDArray
        Input parameters to the graph that do not change
        during inference time. Used for constant folding.

    Returns
    -------
    graph_json : str
        The json string that can be accepted by graph runtime.

    mod : tvm.Module
        The module containing necessary libraries.

    params : dict
        The parameters of the final graph.
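
    Examples
    --------
    A minimal sketch (assuming ``relay`` refers to ``tvm.relay``, a Relay
    function ``func`` is already defined, and CUDA is available; the target
    strings are illustrative):

    .. code-block:: python

        with relay.build_config(opt_level=2):
            graph_json, lib, params = relay.build(
                func, target="cuda", target_host="llvm")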
    """
    target = target if target else _target.current_target()
    if target is None:
        raise ValueError("Target is not set in env or passed as argument.")

    if isinstance(target, dict):
        target, fallback_device = _update_heterogeneous_inputs(target)
    elif isinstance(target, (str, _target.Target)):
        target = _target.create(target)
    else:
        raise ValueError("target must be the type of str, tvm.target.Target, "
                         "or dict of device name to target")

    # If current dispatch context is fallback context (the default root context),
    # then load pre-tuned parameters from TopHub
    if isinstance(autotvm.DispatchContext.current, autotvm.FallbackContext):
        if isinstance(target, dict):
            tophub_context = autotvm.tophub.context(list(target.values()))
        else:
            tophub_context = autotvm.tophub.context(target)
    else:
        tophub_context = autotvm.util.EmptyContext()

    cfg = BuildConfig.current

    with tophub_context:
        func = optimize(func, target, params)
        # Annotate the ops for heterogeneous execution.
        if isinstance(target, dict):
            func, target = _run_device_annotation_passes(func, target,
                                                         fallback_device)
        # Fuse ops before running code gen
        func = ir_pass.infer_type(func)
        func = ir_pass.fuse_ops(func, cfg.opt_level)
        # Graph code generation
        func = ir_pass.infer_type(func)
        graph_gen = _graph_gen.GraphRuntimeCodegen(mod=None, target=target)
        graph_json, lowered_funcs, params = graph_gen.codegen(func)
        mod = _tvm_build_module(
            lowered_funcs, target=target, target_host=target_host)
    return graph_json, mod, params


def _update_heterogeneous_inputs(target):
    """Update the target and fallback device required for heterogeneous
    compilation. CPU is used as the fallback device if it wasn't provided.
    Meanwhile, a CPU device type and "llvm" pair will be added to the target
    dictionary in this case.

    Parameters
    ----------
    target : dict of str(i.e. device/context name) to str/tvm.target.Target.
        A dict contains context to target pairs.

    Returns
    -------
    device_target : dict of int to tvm.target.Target.
        The updated device type to target dict.

    fallback_device : int
        The updated fallback device type.
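
    Examples
    --------
    A rough sketch of the expected behavior (device type integers follow
    ``tvm.nd.context``; 1 is CPU and 2 is GPU):

    .. code-block:: python

        device_target, fallback = _update_heterogeneous_inputs(
            {"cpu": "llvm", "gpu": "cuda"})
        # device_target is roughly {1: llvm target, 2: cuda target} and
        # fallback is 1, i.e. the CPU device type.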
    """
    if not isinstance(target, dict):
        raise ValueError("target must be dict of device name to target for " +
                         "heterogeneous execution, but received %s."
                         % type(target))

    fallback_device = BuildConfig.current.fallback_device
    if fallback_device is None:
        # cpu is used as the default fallback device when heterogeneous
        # execution is needed, but no fallback device is provided.
        fallback_device = _nd.cpu(0).device_type
        target[fallback_device] = str(_target.create("llvm"))
    elif isinstance(fallback_device, str):
        fallback_device = _nd.context(fallback_device).device_type
    elif isinstance(fallback_device, TVMContext):
        fallback_device = fallback_device.device_type
    else:
        raise ValueError("fallback_device expects the type of str or " +
                         "TVMContext, but received %s." % type(fallback_device))

    device_target = {}
    for dev, tgt in target.items():
        device_target[_nd.context(dev).device_type] = _target.create(tgt)

    if fallback_device not in device_target:
        raise ValueError("%s is used as the default device, but the target "
                         "is not provided."
                         % _nd.context(fallback_device).device_name)
    return device_target, fallback_device


def _run_device_annotation_passes(func, target, fallback_device):
    """Execute the device annotation passes to update the input program and
    target information.

    Parameters
    ----------
    func: tvm.relay.Function
        The function on which the annotation passes will be executed.

    target : Dict[int, tvm.target.Target]
        A dict contains device type to target pairs.

    fallback_device : int
        The fallback device type.

    Returns
    -------
    target : Dict[int, tvm.target.Target]
        The updated device type to target dict.

    func : tvm.relay.Function
        The updated func.
    """
    func = ir_pass.infer_type(func)
    func = ir_pass.rewrite_annotated_ops(func, fallback_device)
    device_map = ir_pass.collect_device_info(func)
    # The expression to device type map will be empty if all or none of
    # the expressions in the `func` are annotated because this map is
    # obtained by propagating the device information in the device copy
    # operator. None of the above cases needs device copy operator.
    if not device_map:
        annotation_map = ir_pass.collect_device_annotation_ops(func)
        # No annotation.
        if not annotation_map:
            target = {0: target[fallback_device]}
        else:
            dev_type = next(iter(annotation_map.values()))
            # All annotated with the same device type.
            if all(val == dev_type for val in annotation_map.values()):
                target = {0: target[dev_type]}
            else:
                raise RuntimeError("Expressions in the function are "
                                   "annotated with various device types, "
                                   "but no device copy operators were "
                                   "found. Please check the "
                                   "RewriteAnnotation pass.")
    return func, target


class GraphExecutor(_interpreter.Executor):
    """Wrapper around Executor interface.

    This executor is used for debugging and testing purposes.

    Parameters
    ----------
    mod : :py:class:`~tvm.relay.module.Module`
        The module to support the execution.

    ctx : :py:class:`TVMContext`
        The runtime context to run the code on.

    target : :py:class:`Target`
        The target option to build the function.
    """

    def __init__(self, mod, ctx, target):
        self.mod = mod
        self.ctx = ctx
        self.target = target

    def _make_executor(self, func):
        graph_json, mod, params = build(func, target=self.target)
        gmodule = _graph_rt.create(graph_json, mod, self.ctx)
        if params:
            gmodule.set_input(**params)

        def _graph_wrapper(*args, **kwargs):
            args = self._convert_args(func, args, kwargs)
            # Create map of inputs.
            for i, arg in enumerate(args):
                gmodule.set_input(i, arg)
            # Run the module, and fetch the output.
            gmodule.run()
            # Make a copy so multiple invocations won't hurt perf.
            return gmodule.get_output(0).copyto(_nd.cpu(0))

        return _graph_wrapper


def create_executor(kind="debug",
                    mod=None,
                    ctx=None,
                    target="llvm"):
    """Factory function to create an executor.

    Parameters
    ----------
    kind : str
        The type of executor

    mod : :py:class:`~tvm.relay.module.Module`
        The Relay module containing a collection of functions.

    ctx : :py:class:`tvm.TVMContext`
        The context to execute the code.

    target : :py:class:`tvm.Target`
        The compilation target for the function.
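
    Examples
    --------
    A minimal usage sketch (assuming ``relay`` refers to ``tvm.relay`` and a
    Relay function ``func`` with its input values ``args`` is already defined):

    .. code-block:: python

        executor = relay.create_executor(kind="graph", target="llvm")
        result = executor.evaluate(func)(*args)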
    """
    if ctx is not None:
        assert ctx.device_type == _nd.context(str(target), 0).device_type
    else:
        ctx = _nd.context(str(target), 0)

    if isinstance(target, str):
        target = _target.create(target)
    if kind == "debug":
        return _interpreter.Interpreter(mod, ctx, target)
    if kind == "graph":
        return GraphExecutor(mod, ctx, target)
    raise RuntimeError("unknown execution kind: {0}".format(kind))