build_module.py 20.7 KB
Newer Older
1
"""The build utils in python.
2

3 4
This module provides the functions to transform schedule to
LoweredFunc and compiled Module.
5
"""
6
from __future__ import absolute_import as _abs
7 8
import warnings

xqdan committed
9
from ._ffi.function import Function
10
from ._ffi.node import NodeBase, register_node
11
from . import api
12
from . import _api_internal
13 14 15 16
from . import tensor
from . import schedule
from . import expr
from . import ir_pass
17
from . import stmt as _stmt
18
from . import container
19
from . import module
20
from . import codegen
21
from . import ndarray
22
from . import target as _target
23
from . import make
24

25
class DumpIR(object):
    """Dump the IR emitted by each lowering pass to a file.

    With it, you can dump ir just like gcc/llvm.

    How to use:
    -----------
    .. code-block:: python

        with tvm.build_config(dump_pass_ir=True)
            run()
    """
    # Nesting depth of active build_config scopes; only the outermost
    # scope installs and removes the pass decorators.
    scope_level = 0

    def __init__(self):
        self._pass_id = 0
        self._recover_list = []

    def decorate(self, func):
        """ decorate the pass function"""
        def dump(*args, **kwargs):
            """dump function"""
            retv = func(*args, **kwargs)
            if not isinstance(retv, (_stmt.Stmt, container.LoweredFunc, container.Array)):
                return retv
            fname = func.func_name if hasattr(func, 'func_name') else func.__name__
            pname = str(self._pass_id) + "_" + fname + "_ir.cc"
            with open(pname, "a") as f:
                out = retv.body if isinstance(retv, container.LoweredFunc) else retv
                f.write(str(out))
                if isinstance(retv, container.Array):
                    for x in retv:
                        out = x.body if isinstance(x, container.LoweredFunc) else x
                        f.write("---------%s\n%s\n-----------\n"%(x.name, str(out)))
                self._pass_id += 1
            return retv
        return dump

    def decorate_irpass(self):
        """decorate ir_pass and ScheduleOps"""
        self._old_sgpass = schedule.ScheduleOps
        schedule.ScheduleOps = self.decorate(schedule.ScheduleOps)
        vset = vars(ir_pass)

        def make_recover(key, value):
            # Bind key/value per entry.  A plain closure over the loop
            # variables would late-bind and make every recover() restore
            # only the last (key, value) pair seen by the loop.
            def recover():
                vset[key] = value
            return recover

        # Snapshot items() before mutating the module dict.
        for k, v in list(vset.items()):
            self._recover_list.append(make_recover(k, v))
            vset[k] = self.decorate(v) if isinstance(v, Function) else v

    def decorate_custompass(self, custom_pass):
        """decorate given list of custom passes, and return decorated passes"""
        custom_pass = custom_pass if custom_pass else []
        pass_list = []
        for idx, x in enumerate(custom_pass):
            x[1].__name__ = "custom{}_phase{}".format(idx, x[0])
            pass_list += [(x[0], self.decorate(x[1]))]
        return pass_list

    def enter(self):
        """only decorate outermost nest"""
        if DumpIR.scope_level > 0:
            return
        self.decorate_irpass()
        self._pass_id = 0
        DumpIR.scope_level += 1

    def exit(self):
        """recover outermost nest"""
        if DumpIR.scope_level > 1:
            return
        # recover decorated functions
        for f in self._recover_list:
            f()
        schedule.ScheduleOps = self._old_sgpass
        DumpIR.scope_level -= 1

101

102 103
@register_node
class BuildConfig(NodeBase):
    """Configuration scope to set a build config option.

    Note
    ----
    This object is backed by node system in C++, with arguments that can be
    exchanged between python and C++.

    Do not construct directly, use build_config instead.

    The fields that are backed by the C++ node are immutable once an instance
    is constructed. See _node_defaults for the fields.
    """

    _node_defaults = {
        "auto_unroll_max_step": 0,
        "auto_unroll_max_depth": 8,
        "auto_unroll_max_extent": 0,
        "unroll_explicit": True,
        "detect_global_barrier": False,
        "partition_const_loop": False,
        "offset_factor": 0,
        "data_alignment": -1,
        "restricted_func": True,
        "double_buffer_split_loop": 1,
        "dump_pass_ir": False,
        "instrument_bound_checkers": False,
        "disable_select_rewriting": False
    }
    _dump_ir = DumpIR()

    # pylint: disable=no-member
    def __init__(self, handle):
        """Initialize the function with handle

        Parameters
        ----------
        handle : SymbolHandle
            the handle to the underlying C++ Symbol
        """
        super(BuildConfig, self).__init__(handle)
        self.handle = handle

    @property
    def add_lower_pass(self):
        """Return the registered lowering passes as (phase, func) pairs."""
        # First call (no index) yields the count; subsequent calls fetch the
        # phase (True) or the function (False) for each index.
        num_passes = _api_internal._BuildConfigGetAddLowerPassInfo(self)
        return [
            (_api_internal._BuildConfigGetAddLowerPassInfo(self, idx, True),
             _api_internal._BuildConfigGetAddLowerPassInfo(self, idx, False))
            for idx in range(num_passes)
        ]

    @add_lower_pass.setter
    def add_lower_pass(self, value):
        # Flatten [(phase, func), ...] into [phase, func, phase, func, ...]
        # before handing it to the C++ side.
        flat_args = []
        for item in value:
            flat_args.extend((item[0], item[1]))
        _api_internal._BuildConfigSetAddLowerPass(self, *flat_args)

    def __enter__(self):
        """Push this config onto the C++ scope stack."""
        # pylint: disable=protected-access
        _api_internal._EnterBuildConfigScope(self)
        if self.dump_pass_ir:
            BuildConfig._dump_ir.enter()
        return self

    def __exit__(self, ptype, value, trace):
        """Pop this config from the C++ scope stack."""
        if self.dump_pass_ir:
            BuildConfig._dump_ir.exit()
        _api_internal._ExitBuildConfigScope()

    def __setattr__(self, name, value):
        # Node-backed fields are immutable once the node is constructed.
        if name in BuildConfig._node_defaults:
            raise AttributeError(
                "'%s' object cannot set attribute '%s'" % (str(type(self)), name))
        return super(BuildConfig, self).__setattr__(name, value)
180

181

182
def current_build_config():
    """Return the BuildConfig at the top of the current context stack."""
    cfg = _api_internal._GetCurrentBuildConfig()
    return cfg

186

187 188 189 190 191 192
def build_config(**kwargs):
    """Configure the build behavior by setting config variables.

    Parameters
    ----------
    auto_unroll_max_step: int, default=0
        Threshold of number of steps in the loop to be automatically unrolled.
        This takes inner loop count into consideration.

    auto_unroll_max_depth: int, default=8
        The maximum nested level of loops that can be automatically unrolled.

    auto_unroll_max_extent: int, default=0
        The maximum extent of a loop that can be automatically unrolled.
        Zero means no limit.

    unroll_explicit: bool, default=True
        Whether explicitly unroll the loop, if set false, the unroll hint will
        be passed to the CodeGen phase, which may generate pragma unroll hint.
        Set this to be true if CodeGen support unroll pragma and
        when we want to be more readable.

    detect_global_barrier: bool, default=False
        Whether detect global barrier.

    partition_const_loop: bool, default=False
        Whether partition const loop

    data_alignment: int, optional
        The alignment of data pointer in bytes.
        If -1 is passed, the alignment will be set to TVM's internal default.

    offset_factor: int, default=0
        The factor used in default buffer declaration.
        If specified as 0, offset field is not used.

    restricted_func: bool, default=True
        Whether build restricted function.
        That is each buffer argument to the function are guaranteed
        not to overlap. This enables more optimization.
        Corresponds to restricted keyword in C99

    double_buffer_split_loop: int, default=1
        Whether split the loop with factor. If it is zero, no splitting will happen.
        If it is bigger than one, the logic will do a split with factor equals the integer
        and unroll the inner loop. This allows the buffer fetching won't contain condition.

    add_lower_pass: list of tuple (phase, function(Stmt->Stmt)), default=None
        phase contains an integer on which optimization pass we apply the pass.
        Additional lowering passes to be applied before make_api.

    dump_pass_ir: dump ir of each pass into file idx_passname_ir.cc, default=False

    instrument_bound_checkers: bool, default=False
        Whether instrument bound checkers.

    disable_select_rewriting: bool, default=False
        Whether disable rewriting of Select expressions.

    Returns
    -------
    config: BuildConfig
        The build configuration
    """
    # Start from the declared defaults and overlay user-provided values.
    node_args = {k: kwargs.get(k, v)
                 for k, v in BuildConfig._node_defaults.items()}
    config = make.node("BuildConfig", **node_args)

    # add_lower_pass is not a node field; it goes through the property setter.
    if "add_lower_pass" in kwargs:
        config.add_lower_pass = kwargs["add_lower_pass"]

    return config
249

250 251 252 253 254 255 256 257
def get_binds(args, binds=None):
    """Internal function to get binds and arg_list given arguments.

    Parameters
    ----------
    args : list of Buffer or Tensor or Var
        The argument lists to the function.

    binds : dict of :any:`Tensor` to :any:`Buffer`, optional
        Dictionary that maps the Tensor to Buffer which specified the data layout
        requirement of the function. By default, a new compact buffer is created
        for each tensor in the argument.

    Returns
    -------
    binds: dict
        The bind specification

    arg_list: list
        The list of symbolic buffers of arguments.
    """
    # Work on a copy so the caller's dict is never mutated.
    binds = dict(binds) if binds is not None else {}
    cfg = current_build_config()
    arg_list = []
    for arg in args:
        if isinstance(arg, tensor.Tensor):
            if arg in binds:
                arg_list.append(binds[arg])
            else:
                # Declare a fresh compact buffer for an unbound tensor,
                # honoring the alignment settings of the active config.
                buf = api.decl_buffer(
                    arg.shape,
                    dtype=arg.dtype,
                    name=arg.name,
                    data_alignment=cfg.data_alignment,
                    offset_factor=cfg.offset_factor)
                binds[arg] = buf
                arg_list.append(buf)
        elif isinstance(arg, (schedule.Buffer, expr.Var)):
            arg_list.append(arg)
        else:
            raise ValueError("args must be Tensor, Buffer or Var")
    return binds, arg_list
293

294

295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313
def form_body(sch):
    """According to the given schedule, form the raw body

    Parameters
    ----------
    sch : tvm.schedule.Schedule
        The given scheduler to form the raw body

    Returns
    -------
    The body formed according to the given schedule
    """
    # The schedule must be normalized before bounds can be inferred.
    normalized = sch.normalize()
    bounds = schedule.InferBound(normalized)
    body = schedule.ScheduleOps(normalized, bounds)
    return ir_pass.InjectPrefetch(body)


314
def lower(sch,
          args,
          name="default_function",
          binds=None,
          simple_mode=False):
    """Lowering step before build into target.

    Parameters
    ----------
    sch : tvm.schedule.Schedule
        The schedule to be built

    args : list of Buffer or Tensor or Var
        The argument lists to the function.

    name : str, optional
        The name of result function.

    binds : dict of :any:`Tensor` to :any:`Buffer`, optional
        Dictionary that maps the Tensor to Buffer which specified the data layout
        requirement of the function. By default, a new compact buffer is created
        for each tensor in the argument.

    simple_mode : bool, optional
        Whether only output simple and compact statement, this will skip
        LoopPartition, api wrapper generation and Unrolling.

    Returns
    -------
    f : LoweredFunc or Stmt
       The result function. If simple_mode=True,
       then the Stmt before make api is returned.

    Raises
    ------
    ValueError
        If sch is not a Schedule.
    """
    binds, arg_list = get_binds(args, binds)
    cfg = current_build_config()
    add_lower_pass = cfg.add_lower_pass if cfg.add_lower_pass else []
    if cfg.dump_pass_ir:
        add_lower_pass = BuildConfig._dump_ir.decorate_custompass(add_lower_pass)
    # Group user-supplied passes by the phase after which they run.
    lower_phase0 = [x[1] for x in add_lower_pass if x[0] == 0]
    lower_phase1 = [x[1] for x in add_lower_pass if x[0] == 1]
    lower_phase2 = [x[1] for x in add_lower_pass if x[0] == 2]
    lower_phase3 = [x[1] for x in add_lower_pass if x[0] > 2]

    # Phase 0
    if isinstance(sch, schedule.Schedule):
        stmt = form_body(sch)
    else:
        # Previously a non-Schedule input fell through to a confusing
        # NameError on `stmt`; fail early with a clear message instead.
        raise ValueError(
            "sch must be a tvm.schedule.Schedule, got %s" % type(sch))

    for f in lower_phase0:
        stmt = f(stmt)
    # Phase 1
    stmt = ir_pass.StorageFlatten(stmt, binds, 64, cfg.instrument_bound_checkers)
    stmt = ir_pass.CanonicalSimplify(stmt)
    for f in lower_phase1:
        stmt = f(stmt)
    # Phase 2
    if not simple_mode:
        stmt = ir_pass.LoopPartition(stmt, cfg.partition_const_loop)
    stmt = ir_pass.VectorizeLoop(stmt)
    stmt = ir_pass.InjectVirtualThread(stmt)
    stmt = ir_pass.InjectDoubleBuffer(stmt, cfg.double_buffer_split_loop)
    stmt = ir_pass.StorageRewrite(stmt)
    stmt = ir_pass.UnrollLoop(
        stmt,
        cfg.auto_unroll_max_step,
        cfg.auto_unroll_max_depth,
        cfg.auto_unroll_max_extent,
        cfg.unroll_explicit)
    for f in lower_phase2:
        stmt = f(stmt)
    # Phase 3
    stmt = ir_pass.Simplify(stmt)
    stmt = ir_pass.LowerStorageAccessInfo(stmt)
    stmt = ir_pass.RemoveNoOp(stmt)
    if not cfg.disable_select_rewriting:
        stmt = ir_pass.RewriteUnsafeSelect(stmt)
    for f in lower_phase3:
        stmt = f(stmt)
    # Instrument BoundCheckers
    if cfg.instrument_bound_checkers:
        stmt = ir_pass.InstrumentBoundCheckers(stmt)
    if simple_mode:
        return stmt
    return ir_pass.MakeAPI(stmt, name, arg_list, 0, cfg.restricted_func)
397

398 399 400 401

def _build_for_device(flist, target, target_host):
    """Build the lowered functions for a device with the given compilation
    target.

    Parameters
    ----------
    flist : list of LoweredFunc
        The schedule to be built.

    target : str or :any:`tvm.target.Target`
        The target and option of the compilation.

    target_host : str or :any:`tvm.target.Target`
        The host compilation target.

    Returns
    -------
    fhost : list of LoweredFunc
        A list of lowered functions for the host.

    mdev : tvm.module
        A module that contains device code.
    """
    target = _target.create(target)
    device_type = ndarray.context(target.target_name, 0).device_type
    # warp_size is a property of the target, invariant across functions.
    warp_size = target.thread_warp_size
    fhost = []
    fdevice = []
    for func in flist:
        if not ir_pass.VerifyMemory(func, device_type):
            raise ValueError(
                "Direct host side access to device memory is detected in %s. "
                "Did you forget to bind?" % func.name)
        if func.func_type == container.LoweredFunc.MixedFunc:
            if current_build_config().detect_global_barrier:
                func = ir_pass.ThreadSync(func, "global")
            func = ir_pass.ThreadSync(func, "shared")
            func = ir_pass.ThreadSync(func, "warp")
            func = ir_pass.LowerThreadAllreduce(func, warp_size)
            # First split is the host-side stub, the rest are device kernels.
            fsplits = list(ir_pass.SplitHostDevice(func))
            fhost.append(fsplits[0])
            fdevice.extend(fsplits[1:])
        elif func.func_type == container.LoweredFunc.HostFunc:
            fhost.append(func)
        elif func.func_type == container.LoweredFunc.DeviceFunc:
            fdevice.append(func)
        else:
            raise ValueError("unknown function type %d" % func.func_type)

    fdevice = [ir_pass.LowerWarpMemory(f, warp_size) for f in fdevice]

    if "gpu" in target.keys and not fdevice:
        warnings.warn(
            "Specified target %s, but cannot find device code, did you do "
            "bind?" % target)

    fhost = [ir_pass.BindDeviceType(x, device_type) for x in fhost]
    fhost = [ir_pass.LowerTVMBuiltin(x) for x in fhost]

    # A CPU build whose host target equals the device target must not have
    # produced any device functions.
    if device_type == ndarray.cpu(0).device_type and target_host == target:
        assert not fdevice

    target_host = _target.create(target_host)
    fdevice = [ir_pass.LowerIntrin(x, target.target_name) for x in fdevice]
    fhost = [ir_pass.LowerIntrin(x, target_host.target_name) for x in fhost]
    fhost = [ir_pass.CombineContextCall(x) for x in fhost]
    mdev = codegen.build_module(fdevice, str(target)) if fdevice else None

    return fhost, mdev


def build(inputs,
          args=None,
          target=None,
          target_host=None,
          name="default_function",
          binds=None):
    """Build a function with arguments as signature. Code will be generated
    for devices coupled with target information.

    Parameters
    ----------
    inputs : tvm.Schedule, LoweredFunc, or dict of target to LoweredFunc list
        The schedule to be built

    args : list of Buffer or Tensor or Var, optional
        The argument lists to the function.

    target : str or :any:`tvm.target.Target`, optional
        The target and option of the compilation.

    target_host : str or :any:`tvm.target.Target` optional
        Host compilation target, if target is device.
        When TVM compiles device specific program such as CUDA,
        we also need host(CPU) side code to interact with the driver
        setup the dimensions and parameters correctly.
        target_host is used to specify the host side codegen target.
        By default, llvm is used if it is enabled,
        otherwise a stackvm interpreter is used.

    name : str, optional
        The name of result function.

    binds : dict, optional
        Dictionary that maps the binding of symbolic buffer to Tensor.
        By default, a new buffer is created for each tensor in the argument.

    Returns
    -------
    ret : tvm.module
        A module that combines both host and device code.

    Examples
    ________
    There are two typical example uses of this function depending on the type
    of the argument `inputs`:
    1. it is a list of lowered functions:

    .. code-block:: python

        n = 2
        A = tvm.placeholder((n,), name='A')
        B = tvm.placeholder((n,), name='B')
        C = tvm.compute(A.shape, lambda *i: A(*i) + B(*i), name='C')
        s = tvm.create_schedule(C.op)
        f = tvm.lower(s, [A, B, C], name="test_add")
        m = tvm.build(f, target="llvm")

    2. it is a dict of compilation target to list of lowered functions:

    .. code-block:: python

        n = 2
        A = tvm.placeholder((n,), name='A')
        B = tvm.placeholder((n,), name='B')
        C = tvm.compute(A.shape, lambda *i: A(*i) + B(*i), name='C')
        s1 = tvm.create_schedule(C.op)
        s2 = topi.cpp.cuda.schedule_injective("cuda", [C])
        f1 = tvm.lower(s1, [A, B, C], name="test_add1")
        f2 = tvm.lower(s2, [A, B, C], name="test_add2")
        m = tvm.build({"llvm": [f1], "cuda": [f2]}, target_host="llvm")

    Note
    ----
    See the note on :any:`tvm.target` on target string format.
    """
    # Normalize `inputs` into a flat list of LoweredFunc; the dict form
    # (target -> function list) is handled separately below.
    if isinstance(inputs, schedule.Schedule):
        if args is None:
            raise ValueError("args must be given for build from schedule")
        flist = lower(inputs, args,
                      name=name,
                      binds=binds)
        if isinstance(flist, container.LoweredFunc):
            flist = [flist]
    elif isinstance(inputs, container.LoweredFunc):
        if args:
            raise ValueError("args must be done when build from LoweredFunc.")
        flist = [inputs]
    elif isinstance(inputs, (list, tuple, container.Array)):
        flist = inputs
    elif not isinstance(inputs, (dict, container.Map)):
        raise ValueError("inputs must be Schedule, LoweredFunc, list of "
                         "LoweredFunc, or dict of target to list of "
                         "LoweredFunc.")

    # Build the target -> function-list mapping.  Non-dict inputs use the
    # explicit `target` argument, falling back to the current target scope
    # and finally to "llvm".
    if not isinstance(inputs, (dict, container.Map)):
        target = _target.current_target() if target is None else target
        target = target if target else "llvm"
        target_flist = {target: flist}
    else:
        target_flist = inputs

    # Validate each entry: keys must be targets and values LoweredFuncs
    # with names unique within each target.
    for tar, flist in target_flist.items():
        if not isinstance(tar, (str, _target.Target)):
            raise ValueError("The key of inputs must be str or "
                             "_target.Target when inputs is dict.")
        fname_set = set()
        for x in flist:
            if not isinstance(x, container.LoweredFunc):
                raise ValueError("inputs must be Schedule, LoweredFunc, list "
                                 "of LoweredFunc, or dict of str to list of "
                                 "LoweredFunc.")
            if x.name in fname_set:
                raise ValueError("Duplicate function name %s" % x.name)
            fname_set.add(x.name)

    # Pick a host target: prefer the first CPU target among the compilation
    # targets, otherwise default to llvm (or stackvm when llvm is disabled).
    if not target_host:
        for tar, _ in target_flist.items():
            tar = _target.create(tar)
            device_type = ndarray.context(tar.target_name, 0).device_type
            if device_type == ndarray.cpu(0).device_type:
                target_host = tar
                break
    if not target_host:
        target_host = "llvm" if module.enabled("llvm") else "stackvm"

    fhost_all = []
    device_modules = []
    for tar, flist in target_flist.items():
        fhost, mdev = _build_for_device(flist, tar, target_host)
        # Save the current lowered functions of the host and the device module.
        fhost_all += fhost
        device_modules.append(mdev)

    # Generate a unified host module.
    mhost = codegen.build_module(fhost_all, str(target_host))

    # Import all modules.
    for mdev in device_modules:
        if mdev:
            mhost.import_module(mdev)
    return mhost