"""The computation schedule api of TVM."""
2
from __future__ import absolute_import as _abs
from ._ffi.base import string_types
from ._ffi.node import NodeBase, register_node
from ._ffi.node import convert_to_node as _convert_to_node
from ._ffi.function import _init_api, Function
from ._ffi.function import convert_to_tvm_func as _convert_tvm_func
from . import _api_internal
from . import tensor as _tensor
from . import expr as _expr
from . import container as _container

def convert(value):
    """Convert value to TVM node or function.

    Parameters
    ----------
    value : python value

    Returns
    -------
    tvm_val : Node or Function
        Converted value in TVM
    """
    if isinstance(value, (Function, NodeBase)):
        return value

    if callable(value):
        return _convert_tvm_func(value)

    return _convert_to_node(value)

@register_node
class Buffer(NodeBase):
    """Symbolic data buffer in TVM.

    Buffer provides a way to represent the data layout
    specialization of a data structure in TVM.

    Do not construct directly, use :any:`decl_buffer` instead.
    See the documentation of :any:`decl_buffer` for more details.

    See Also
    --------
    decl_buffer : Declare a buffer
    """
    READ = 1
    WRITE = 2

    def access_ptr(self, access_mask, ptr_type="handle", content_lanes=1, offset=0):
        """Get an access pointer to the head of buffer.

        This is the recommended method to get the buffer data
        address when interacting with external functions.

        Parameters
        ----------
        access_mask : int
            The access pattern mask, indicating whether the
            access will read or write to the data content.

        ptr_type : str, optional
            The data type of the result pointer. Do not specify
            unless you want to cast the pointer to a specific type.

        content_lanes: int, optional
            The number of lanes for the data type. This value
            is greater than one for vector types.

        offset: Expr, optional
            The offset of the pointer. It can be used to offset
            by a number of elements from the base address.

        Examples
        --------
        .. code-block:: python

          from tvm.schedule import Buffer
          # Get access ptr for read
          buffer.access_ptr("r")
          # Get access ptr for read/write with bitmask
          buffer.access_ptr(Buffer.READ | Buffer.WRITE)
          # Get access ptr for read/write with str flag
          buffer.access_ptr("rw")
          # Get access ptr for read with offset
          buffer.access_ptr("r", offset = 100)
        """
        if isinstance(access_mask, string_types):
            mask = 0
            for value in access_mask:
                if value == "r":
                    mask = mask | Buffer.READ
                elif value == "w":
                    mask = mask | Buffer.WRITE
                else:
                    raise ValueError("Unknown access_mask %s" % access_mask)
            access_mask = mask
        offset = convert(offset)
        return _api_internal._BufferAccessPtr(self, access_mask, ptr_type,
                                              content_lanes, offset)

    def vload(self, begin, dtype=None):
        """Generate an Expr that loads dtype from begin index.

        Parameters
        ----------
        begin : Array of Expr
            The beginning index in unit of Buffer.dtype

        dtype : str
            The data type to be loaded; it can be a vector type
            whose lanes are a multiple of those of Buffer.dtype.

        Returns
        -------
        load : Expr
            The corresponding load expression.
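
        Examples
        --------
        A minimal sketch; the buffer ``buf`` below is a hypothetical
        example, not part of this method:

        .. code-block:: python

          # hypothetical buffer declared elsewhere
          buf = tvm.decl_buffer((16,), "float32", name="buf")
          # load four float32 lanes starting at element 0
          load = buf.vload([0], "float32x4")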
        """
        begin = (begin,) if isinstance(begin, (int, _expr.Expr)) else begin
        dtype = dtype if dtype else self.dtype
        return _api_internal._BufferVLoad(self, begin, dtype)

    def vstore(self, begin, value):
        """Generate a Stmt that store value into begin index.

        Parameters
        ----------
        begin : Array of Expr
            The beginning index in unit of Buffer.dtype

        value : Expr
            The value to be stored.

        Returns
        -------
        store : Stmt
            The corresponding store stmt.
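
        Examples
        --------
        A minimal sketch; the buffer ``buf`` below is a hypothetical
        example, not part of this method:

        .. code-block:: python

          # hypothetical buffer declared elsewhere
          buf = tvm.decl_buffer((16,), "float32", name="buf")
          # store the constant 1.0 into element 0
          store = buf.vstore([0], tvm.const(1.0, "float32"))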
        """
        begin = (begin,) if isinstance(begin, (int, _expr.Expr)) else begin
        return _api_internal._BufferVStore(self, begin, value)


@register_node
class Split(NodeBase):
    """Split operation on axis."""
    pass


@register_node
class Fuse(NodeBase):
    """Fuse operation on axis."""
    pass


@register_node
class Singleton(NodeBase):
    """Singleton axis."""
    pass


@register_node
class IterVar(NodeBase, _expr.ExprOp):
    """Represent iteration variable.

    IterVar is normally created by Operation, to represent
    axis iterations in the computation.
    It can also be created by schedule primitives like :any:`tvm.schedule.Stage.split`.

    See Also
    --------
    tvm.thread_axis: Create thread axis IterVar.
    tvm.reduce_axis: Create reduce axis IterVar.
    """
    DataPar = 0
    ThreadIndex = 1
    CommReduce = 2
    Ordered = 3
    DimInfo = 4
    Unrolled = 5
    Vectorized = 6
    Parallelized = 7
    Tensorized = 8

_tensor.iter_var_cls = IterVar

def create_schedule(ops):
    """Create a schedule for list of ops

    Parameters
    ----------
    ops : Operation or list of Operations
        The source operations.

    Returns
    -------
    sch : schedule.Schedule
        The created schedule.
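
    Examples
    --------
    A minimal sketch; the tensors ``A`` and ``B`` are hypothetical:

    .. code-block:: python

      A = tvm.placeholder((16,), name="A")
      B = tvm.compute((16,), lambda i: A[i] + 1.0, name="B")
      s = tvm.create_schedule(B.op)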
    """
    if not isinstance(ops, (list, _container.Array)):
        ops = [ops]
    return _api_internal._CreateSchedule(ops)


@register_node
class Schedule(NodeBase):
    """Schedule for all the stages."""
    def __getitem__(self, k):
        if isinstance(k, _tensor.Tensor):
            k = k.op
        if not isinstance(k, _tensor.Operation):
            raise ValueError("Expect schedule key to be Tensor or Operation")
        if k not in self.stage_map:
            raise ValueError("Cannot find the operation %s in schedule" % (str(k)))
        return self.stage_map[k]

    def normalize(self):
        """Build a normalized schedule from the current schedule.

        Insert the necessary rebases to make certain iter vars start from 0.
        This is needed before bound inference and follow-up steps.

        Returns
        -------
        sch : Schedule
            The normalized schedule.
        """
        return _api_internal._ScheduleNormalize(self)

    def create_group(self, outputs, inputs, include_inputs=False):
        """Create stage group by giving output and input boundary.

        The operators between outputs and inputs are placed as members of the group.
        Outputs are included in the group, while inputs are not.

        Parameters
        ----------
        outputs : list of Tensors
            The outputs of the group.

        inputs : list of Tensors
            The inputs of the group.

        include_inputs : boolean, optional
            Whether to include input operations in the group if they are used by outputs.

        Returns
        -------
        group : Stage
            A virtual stage that represents the group; the user can use compute_at to move
            the attachment point of the group.
        """
        if isinstance(outputs, _tensor.Tensor):
            outputs = [outputs]
        if isinstance(inputs, _tensor.Tensor):
            inputs = [inputs]
        return _api_internal._ScheduleCreateGroup(
            self, outputs, inputs, include_inputs)

    def cache_read(self, tensor, scope, readers):
        """Create a cache read of original tensor for readers.

        This will mutate the body of the readers.
        A new cache stage will be created for the tensor.
        Call this before doing any split/fuse schedule.

        Parameters
        ----------
        tensor : Tensor
            The tensor to be cached.
        scope : str
            The scope of the cached data.
        readers : list of Tensor or Operation
            The readers to read the cache.

        Returns
        -------
        cache : Tensor
            The created cache tensor.
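
        Examples
        --------
        A minimal sketch; ``A``, ``B``, and ``s`` are assumed to come from
        a surrounding schedule in which B reads A:

        .. code-block:: python

          # hypothetical setup: s = tvm.create_schedule(B.op)
          AA = s.cache_read(A, "shared", [B])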
        """
        if isinstance(readers, (_tensor.Tensor, _tensor.Operation)):
            readers = [readers]
        readers = [t.op if isinstance(t, _tensor.Tensor) else t for t in readers]
        return _api_internal._ScheduleCacheRead(self, tensor, scope, readers)

    def cache_write(self, tensor, scope):
        """Create a cache write of original tensor, before storing into tensor.

        This will mutate the body of the tensor.
        A new cache stage will be created before feeding into the tensor.

        This function can be used to support data layout transformation.
        If there is a split/fuse/reorder on the data-parallel axis of the tensor
        before cache_write is called, the intermediate cache stores
        the data in the layout given by the iteration order of the leaf axes.
        The data will be transformed back to the original layout in the original tensor.
        The user can further call compute_inline to inline the original layout and keep
        the data stored in the transformed layout.

        Parameters
        ----------
        tensor : Tensor, list or tuple
            The tensors to be fed to. All the tensors must be produced by one ComputeOp.
        scope : str
            The scope of the cached data.

        Returns
        -------
        cache : Tensor
            The created cache tensor.
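
        Examples
        --------
        A minimal sketch; ``B`` and ``s`` are assumed to come from a
        surrounding schedule:

        .. code-block:: python

          # hypothetical setup: s = tvm.create_schedule(B.op)
          BL = s.cache_write(B, "local")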
        """
        return _api_internal._ScheduleCacheWrite(self, tensor, scope)

    def rfactor(self, tensor, axis, factor_axis=0):
        """ Factor a reduction axis in tensor's schedule to be an explicit axis.

        This will create a new stage that generated the new tensor with axis
        as the first dimension. The tensor's body will be rewritten as a reduction
        over the factored tensor.

        Parameters
        ----------
        tensor : Tensor
            The tensor to be factored.
        axis : IterVar
            The reduction axis in the schedule to be factored.
        factor_axis : int, optional
            The position where the new axis is placed.

        Returns
        -------
        tfactor : Tensor or Array of Tensor
            The created factored tensor.
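
        Examples
        --------
        A minimal sketch of factoring a sum reduction; ``A`` is a
        hypothetical 1-D input tensor:

        .. code-block:: python

          k = tvm.reduce_axis((0, 64), name="k")
          B = tvm.compute((1,), lambda i: tvm.sum(A[k], axis=k), name="B")
          s = tvm.create_schedule(B.op)
          ko, ki = s[B].split(k, factor=16)
          # factor the inner reduction axis into an explicit stage
          BF = s.rfactor(B, ki)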
        """
        factored = _api_internal._ScheduleRFactor(self, tensor, axis, factor_axis)
        return factored[0] if len(factored) == 1 else factored


@register_node
class Stage(NodeBase):
    """A Stage represents schedule for one operation."""
    def split(self, parent, factor=None, nparts=None):
        """Split the stage either by factor providing outer scope, or both

        Parameters
        ----------
        parent : IterVar
             The parent iter var.

        factor : Expr, optional
             The splitting factor

        nparts : Expr, optional
             The number of outer parts.

        Returns
        -------
        outer : IterVar
            The outer variable of iteration.

        inner : IterVar
            The inner variable of iteration.
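
        Examples
        --------
        A minimal sketch; ``B`` and ``s`` are assumed to come from a
        surrounding schedule:

        .. code-block:: python

          # split the first axis into 32-wide inner iterations
          xo, xi = s[B].split(B.op.axis[0], factor=32)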
        """
        if nparts is not None:
            if factor is not None:
                raise ValueError("Donot need to provide both outer and nparts")
            outer, inner = _api_internal._StageSplitByNParts(self, parent, nparts)
        else:
            if factor is None:
                raise ValueError("Either nparts or factor need to be provided")
            outer, inner = _api_internal._StageSplitByFactor(self, parent, factor)
        return outer, inner

    def fuse(self, *args):
        """Fuse multiple consecutive iteration variables into a single iteration variable.

        fused = fuse(...fuse(fuse(args[0], args[1]), args[2]),..., args[-1])
        The order is from outer to inner.

        Parameters
        ----------
        args : list of IterVars
            Consecutive IterVars, ordered from outer to inner.

        Returns
        -------
        fused : IterVar
            The fused variable of iteration.
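
        Examples
        --------
        A minimal sketch; ``B``, ``s``, ``xo``, and ``xi`` are assumed to
        come from a surrounding schedule:

        .. code-block:: python

          # hypothetical setup: xo, xi = s[B].split(B.op.axis[0], factor=32)
          fused = s[B].fuse(xo, xi)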
        """
        fused = _api_internal._StageFuse(self, args)
        return fused

    def set_scope(self, scope):
        """Set the thread scope of this stage

        Parameters
        ----------
        scope : str
            The thread scope of this stage
        """
        return _api_internal._StageSetScope(self, scope)

    def bind(self, ivar, thread_ivar):
        """Bind ivar to thread index thread_ivar

        Parameters
        ----------
        ivar : IterVar
            The iteration to be bound to the thread.

        thread_ivar : IterVar
            The thread axis to be bound to.
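
        Examples
        --------
        A minimal sketch; ``B``, ``s``, and ``xo`` are assumed to come
        from a surrounding schedule:

        .. code-block:: python

          # bind the outer axis to the GPU block index
          s[B].bind(xo, tvm.thread_axis("blockIdx.x"))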
        """
        _api_internal._StageBind(self, ivar, thread_ivar)

    def env_threads(self, threads):
        """Mark threads to be launched at the outer scope of composed op.

        Parameters
        ----------
        threads : IterVar or list of IterVar
            The threads to be launched.
        """
        if isinstance(threads, IterVar):
            threads = [threads]
        _api_internal._StageEnvThreads(self, threads)

    def set_store_predicate(self, predicate):
        """Set predicate under which store to the array can be performed.

        Use this when there are duplicated threads doing the same store and we only
        need one of them to do the store.

        Parameters
        ----------
        predicate : Expr
            The guard condition for the store.
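
        Examples
        --------
        A minimal sketch; ``B``, ``s``, and the thread axis ``tx`` are
        assumed to come from a surrounding schedule:

        .. code-block:: python

          # hypothetical setup: tx = tvm.thread_axis("threadIdx.x")
          # let only the first thread perform the store
          s[B].set_store_predicate(tx.var.equal(0))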
        """
        _api_internal._StageSetStorePredicate(self, predicate)

    def compute_at(self, parent, scope):
        """Attach the stage at parent's scope

        Parameters
        ----------
        parent : Stage
            The parent stage

        scope : IterVar
            The loop scope to be attached to.
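
        Examples
        --------
        A minimal sketch; ``B``, ``C``, and ``s`` are assumed to come
        from a surrounding schedule in which C consumes B:

        .. code-block:: python

          # compute B within the first axis of C's loop nest
          s[B].compute_at(s[C], C.op.axis[0])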
        """
        _api_internal._StageComputeAt(self, parent, scope)

    def compute_inline(self):
        """Mark stage as inline
tqchen committed
455 456 457

        Parameters
        ----------
458 459
        parent : Stage
            The parent stage
tqchen committed
460
        """
461
        _api_internal._StageComputeInline(self)

    def compute_root(self):
        """Attach the stage at parent, and mark it as root
tqchen committed
465 466 467

        Parameters
        ----------
468 469
        parent : Stage
            The parent stage
tqchen committed
470
        """
471
        _api_internal._StageComputeRoot(self)

    def reorder(self, *args):
        """reorder the arguments in the specified order.

        Parameters
        ----------
        args : list of IterVar
            The iteration variables, in the desired order.
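
        Examples
        --------
        A minimal sketch; ``C``, ``s``, and the split axes are assumed to
        come from a surrounding schedule:

        .. code-block:: python

          # hypothetical setup: xo, xi and yo, yi come from two splits
          s[C].reorder(xo, yo, xi, yi)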
        """
        _api_internal._StageReorder(self, args)

    def tile(self, x_parent, y_parent, x_factor, y_factor):
        """Perform tiling on two dimensions.

        The final loop order from outermost to innermost is
        [x_outer, y_outer, x_inner, y_inner].

        Parameters
        ----------
        x_parent : IterVar
            The original x dimension
        y_parent : IterVar
            The original y dimension
        x_factor : Expr
            The stride factor on x axis
        y_factor : Expr
            The stride factor on y axis

        Returns
        -------
        x_outer : IterVar
            Outer axis of x dimension
        y_outer : IterVar
            Outer axis of y dimension
        x_inner : IterVar
            Inner axis of x dimension
        y_inner : IterVar
            Inner axis of y dimension
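
        Examples
        --------
        A minimal sketch; ``C`` and ``s`` are assumed to come from a
        surrounding schedule over a 2-D tensor:

        .. code-block:: python

          xo, yo, xi, yi = s[C].tile(C.op.axis[0], C.op.axis[1],
                                     x_factor=8, y_factor=8)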
        """
        x_outer, y_outer, x_inner, y_inner = _api_internal._StageTile(
            self, x_parent, y_parent, x_factor, y_factor)
        return x_outer, y_outer, x_inner, y_inner

    def vectorize(self, var):
        """Vectorize the iteration.

        Parameters
        ----------
        var : IterVar
            The iteration to be vectorized.
        """
        _api_internal._StageVectorize(self, var)

    def tensorize(self, var, tensor_intrin):
        """Tensorize the computation enclosed by var with tensor_intrin

        Parameters
        ----------
        var : IterVar
            The iteration boundary of tensorization.

        tensor_intrin : TensorIntrin
            The tensor intrinsic used for computation.
        """
        _api_internal._StageTensorize(self, var, tensor_intrin)

    def unroll(self, var):
        """Unroll the iteration.

        Parameters
        ----------
        var : IterVar
            The iteration to be unrolled.
        """
        _api_internal._StageUnroll(self, var)

    def parallel(self, var):
        """Parallelize the iteration.

        Parameters
        ----------
        var : IterVar
            The iteration to be parallelized.
        """
        _api_internal._StageParallel(self, var)

    def pragma(self, var, pragma_type, pragma_value=None):
        """Annotate the iteration with pragma

        This will translate to a pragma_scope surrounding
        the corresponding loop generated.
        Useful to support experimental features and extensions.

        Parameters
        ----------
        var : IterVar
            The iteration to be annotated.

        pragma_type : str
             The pragma string to be annotated

        pragma_value : Expr, optional
             The pragma value to pass along the pragma

        Note
        ----
        Most pragmas are advanced/experimental features
        and may be subject to change. List of supported pragmas:

        - **debug_skip_region**

          Force skip the region marked by the axis and turn it into a no-op.
          This is useful for debugging purposes.

        - **parallel_launch_point**

          Specify to launch parallel threads outside the
          specified iteration loop. By default the threads
          launch at the point of the parallel construct.
          This pragma moves the launching point to an even outer scope.
          The threads are launched once and reused across multiple
          parallel constructs, as in a BSP-style program.

        - **parallel_barrier_when_finish**

          Insert a synchronization barrier between working threads
          after the specified loop iteration finishes.

        - **parallel_stride_pattern**

          Hint parallel loop to execute in strided pattern.
          :code:`for (int i = task_id; i < end; i += num_task)`
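
        Examples
        --------
        A minimal sketch; ``B``, ``s``, and ``xo`` are assumed to come
        from a surrounding schedule:

        .. code-block:: python

          # launch parallel worker threads at an outer scope
          s[B].pragma(xo, "parallel_launch_point")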

        """
        if isinstance(pragma_value, string_types):
            pragma_value = convert(pragma_value)
        _api_internal._StagePragma(self, var, pragma_type, pragma_value)

    def prefetch(self, tensor, var, offset):
        """Prefetch the specified variable

        Parameters
        ----------
        tensor : Tensor
            The tensor to be prefetched
        var : IterVar
            The loop point at which the prefetching is applied
        offset : Expr
            The number of iterations to be prefetched before actual execution
        """
        _api_internal._StagePrefetch(self, tensor, var, offset)

    def storage_align(self, axis, factor, offset):
        """Set alignment requirement for specific axis

        This ensures that stride[axis] == k * factor + offset for some k.
        This is useful to set the memory layout for a more friendly memory
        access pattern. For example, we can set alignment to
        factor=2, offset=1 to avoid bank conflicts for thread access on
        higher dimensions in GPU shared memory.

        Parameters
        ----------
        axis : IterVar
            The axis dimension to be aligned.
        factor : int
            The factor in alignment specification.
        offset : int
            The offset in the alignment specification.
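
        Examples
        --------
        A minimal sketch; ``B`` and ``s`` are assumed to come from a
        surrounding schedule:

        .. code-block:: python

          # ensure stride[axis] == k * 2 + 1 for some k
          s[B].storage_align(B.op.axis[0], factor=2, offset=1)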
        """
        _api_internal._StageStorageAlign(self, axis, factor, offset)

    def double_buffer(self):
        """Compute the current stage via double buffering.

        This can only be applied to an intermediate stage.
        This will double the storage cost of the current stage.
        Can be useful to hide load latency.
        """
        _api_internal._StageDoubleBuffer(self)

    def opengl(self):
        """The special OpenGL schedule

        Maps each output element to a pixel.
        """
        _api_internal._StageOpenGL(self)

_init_api("tvm.schedule")