"""The computation schedule api of TVM."""
from __future__ import absolute_import as _abs

from ._ffi.base import string_types
from ._ffi.node import NodeBase, register_node
from ._ffi.function import _init_api
from . import _api_internal
from . import tensor as _tensor
from . import expr as _expr
from . import container as _container

11
@register_node
class Buffer(NodeBase):
    """Symbolic data buffer in TVM.

    Buffer provides a way to represent data layout
    specialization of data structure in TVM.

    Do not construct directly, use :any:`decl_buffer` instead.
    See the documentation of :any:`decl_buffer` for more details.

    See Also
    --------
    decl_buffer : Declare a buffer
    """
    # Bitmask flags accepted by access_ptr's access_mask argument.
    READ = 1
    WRITE = 2

    def access_ptr(self, access_mask, ptr_type="handle", content_lanes=1):
        """Get an access pointer to the head of buffer.

        This is the recommended method to get buffer data
        address when interacting with external functions.

        Parameters
        ----------
        access_mask : int or str
            The access pattern MASK. Indicate whether the
            access will read or write to the data content.
            May also be given as a string made of the characters
            "r" (read) and "w" (write), e.g. "rw".

        ptr_type : str, optional
            The data type of the result pointer. Do not specify
            unless we want to cast pointer to specific type.

        content_lanes: int, optional
            The number of lanes for the data type. This value
            is greater than one for vector types.

        Raises
        ------
        ValueError
            If access_mask is a string containing characters other
            than "r" and "w".

        Examples
        --------
        .. code-block:: python

          import tvm.schedule.Buffer
          # Get access ptr for read
          buffer.access_ptr("r")
          # Get access ptr for read/write with bitmask
          buffer.access_ptr(Buffer.READ | Buffer.WRITE)
          # Get access ptr for read/write with str flag
          buffer.access_ptr("rw")
        """
        # Translate the string form into the integer bitmask.
        if isinstance(access_mask, string_types):
            mask = 0
            for value in access_mask:
                if value == "r":
                    mask = mask | Buffer.READ
                elif value == "w":
                    mask = mask | Buffer.WRITE
                else:
                    raise ValueError("Unknown access_mask %s" % access_mask)
            access_mask = mask
        return _api_internal._BufferAccessPtr(self, access_mask, ptr_type,
                                              content_lanes)

    def vload(self, begin, dtype=None):
        """Generate an Expr that loads dtype from begin index.

        Parameters
        ----------
        begin : Array of Expr
            The beginning index in unit of Buffer.dtype

        dtype : str, optional
            The data type to be loaded,
            can be vector type which have lanes that is multiple of Buffer.dtype.
            Defaults to the buffer's own dtype.

        Returns
        -------
        load : Expr
            The corresponding load expression.
        """
        # Accept a scalar index for convenience; normalize to a tuple.
        begin = (begin,) if isinstance(begin, (int, _expr.Expr)) else begin
        dtype = dtype if dtype else self.dtype
        return _api_internal._BufferVLoad(self, begin, dtype)

    def vstore(self, begin, value):
        """Generate a Stmt that stores value into begin index.

        Parameters
        ----------
        begin : Array of Expr
            The beginning index in unit of Buffer.dtype

        value : Expr
            The value to be stored.

        Returns
        -------
        store : Stmt
            The corresponding store stmt.
        """
        # Accept a scalar index for convenience; normalize to a tuple.
        begin = (begin,) if isinstance(begin, (int, _expr.Expr)) else begin
        return _api_internal._BufferVStore(self, begin, value)

113

114
@register_node
class Split(NodeBase):
    """Split operation on axis."""

119

120
@register_node
class Fuse(NodeBase):
    """Fuse operation on axis."""

125

126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146
@register_node
class IterVar(NodeBase, _expr.ExprOp):
    """Represent iteration variable.

    IterVar is normally created by Operation, to represent
    axis iterations in the computation.
    It can also be created by schedule primitives like :any:`tvm.schedule.Stage.split`.

    See Also
    --------
    tvm.thread_axis: Create thread axis IterVar.
    tvm.reduce_axis: Create reduce axis IterVar.
    """
    # Iteration type codes. NOTE(review): these presumably mirror the
    # C++-side IterVarType enum — the numeric values must stay in sync
    # with it; confirm before reordering or inserting entries.
    DataPar = 0
    ThreadIndex = 1
    CommReduce = 2
    Ordered = 3
    DimInfo = 4
    Unrolled = 5
    Vectorized = 6
    Parallelized = 7
    Tensorized = 8


# Hand the IterVar class to the tensor module, presumably so it can
# construct IterVars without importing this module (avoids a cycle) —
# verify against tensor.py.
_tensor.iter_var_cls = IterVar

151 152 153 154 155 156 157 158 159 160 161 162 163
def create_schedule(ops):
    """Create a schedule for list of ops

    Parameters
    ----------
    ops : list of Operations
        The source expression.

    Returns
    -------
    sch : schedule.Schedule
        The created schedule.
    """
    # A single operation is accepted as a convenience; normalize to a list.
    ops = ops if isinstance(ops, (list, _container.Array)) else [ops]
    return _api_internal._CreateSchedule(ops)
167 168


169 170
@register_node
class Schedule(NodeBase):
    """Schedule for all the stages."""
    def __getitem__(self, k):
        """Look up the Stage scheduling operation k (or a tensor's op).

        Parameters
        ----------
        k : Tensor or Operation
            The lookup key; a Tensor is resolved to its producing Operation.

        Returns
        -------
        stage : Stage
            The stage scheduling the operation.

        Raises
        ------
        ValueError
            If k is neither Tensor nor Operation, or is not part of
            this schedule.
        """
        if isinstance(k, _tensor.Tensor):
            k = k.op
        if not isinstance(k, _tensor.Operation):
            raise ValueError("Expect schedule key to be Tensor or Operation")
        if k not in self.stage_map:
            raise ValueError("Cannot find the operation %s in schedule" % (str(k)))
        return self.stage_map[k]

    def normalize(self):
        """Build a normalized schedule from the current schedule.

        Insert necessary rebase to make certain iter var to start from 0.
        This is needed before bound inference and followup step.

        Returns
        -------
        sch : Schedule
            The normalized schedule.
        """
        return _api_internal._ScheduleNormalize(self)

    def create_group(self, outputs, inputs, include_inputs=False):
        """Create stage group by giving output and input boundary.

        The operators between outputs and inputs are placed as member of group.
        outputs are included in the group, while inputs are not included.

        Parameters
        ----------
        outputs : list of Tensors
            The outputs of the group.

        inputs : list of Tensors
            The inputs of the group.

        include_inputs : boolean, optional
            Whether to include input operations in the group if they are
            used by outputs.

        Returns
        -------
        group : Stage
            A virtual stage represents the group, user can use compute_at to move
            the attachment point of the group.
        """
        # Accept a single tensor on either boundary for convenience.
        if isinstance(outputs, _tensor.Tensor):
            outputs = [outputs]
        if isinstance(inputs, _tensor.Tensor):
            inputs = [inputs]
        return _api_internal._ScheduleCreateGroup(
            self, outputs, inputs, include_inputs)

    def cache_read(self, tensor, scope, readers):
        """Create a cache read of original tensor for readers.

        This will mutate the body of the readers.
        A new cache stage will be created for the tensor.
        Call this before doing any split/fuse schedule.

        Parameters
        ----------
        tensor : Tensor
            The tensor to be cached.
        scope : str
            The scope of the cache.
        readers : list of Tensor or Operation
            The readers to read the cache.

        Returns
        -------
        cache : Tensor
            The created cache tensor.
        """
        if isinstance(readers, (_tensor.Tensor, _tensor.Operation)):
            readers = [readers]
        # The internal API works on operations; unwrap any tensors.
        readers = [t.op if isinstance(t, _tensor.Tensor) else t for t in readers]
        return _api_internal._ScheduleCacheRead(self, tensor, scope, readers)

    def cache_write(self, tensor, scope):
        """Create a cache write of original tensor, before storing into tensor.

        This will mutate the body of the tensor.
        A new cache stage will be created before feeding into the tensor.

        This function can be used to support data layout transformation.
        If there is a split/fuse/reorder on the data parallel axis of tensor
        before cache_write is called, the intermediate cache stores
        the data in the layout as the iteration order of the leaf axes.
        The data will be transformed back to the original layout in the original tensor.
        User can further call compute_inline to inline the original layout and keep
        the data stored in the transformed layout.

        Parameters
        ----------
        tensor : Tensor
            The tensor to be fed into.
        scope : str
            The scope of the cache.

        Returns
        -------
        cache : Tensor
            The created cache tensor.
        """
        return _api_internal._ScheduleCacheWrite(self, tensor, scope)

    def rfactor(self, tensor, axis):
        """Factor a reduction axis in tensor's schedule to be an explicit axis.

        This will create a new stage that generates the new tensor with axis
        as the first dimension. The tensor's body will be rewritten as a reduction
        over the factored tensor.

        Parameters
        ----------
        tensor : Tensor
            The tensor to be factored.
        axis : IterVar
            The reduction axis in the schedule to be factored.

        Returns
        -------
        tfactor : Tensor or Array of Tensor
            The created factored tensor.
        """
        factored = _api_internal._ScheduleRFactor(self, tensor, axis)
        # Unwrap the common single-result case for caller convenience.
        return factored[0] if len(factored) == 1 else factored
299

300

301 302
@register_node
class Stage(NodeBase):
    """A Stage represents schedule for one operation."""
    def split(self, parent, factor=None, nparts=None):
        """Split the stage either by factor providing outer scope,
        or by nparts providing the number of outer parts.

        Exactly one of factor and nparts must be given.

        Parameters
        ----------
        parent : IterVar
             The parent iter var.

        factor : Expr, optional
             The splitting factor

        nparts : Expr, optional
             The number of outer parts.

        Returns
        -------
        outer : IterVar
            The outer variable of iteration.

        inner : IterVar
            The inner variable of iteration.

        Raises
        ------
        ValueError
            If both or neither of factor and nparts are provided.
        """
        if nparts is not None:
            if factor is not None:
                # The two split specifications are mutually exclusive.
                raise ValueError("Do not need to provide both factor and nparts")
            outer, inner = _api_internal._StageSplitByNParts(self, parent, nparts)
        else:
            if factor is None:
                raise ValueError("Either nparts or factor need to be provided")
            outer, inner = _api_internal._StageSplitByFactor(self, parent, factor)
        return outer, inner

    def fuse(self, *args):
        """Fuse multiple consecutive iteration variables into a single iteration variable.

        fused = fuse(...fuse(fuse(args[0], args[1]), args[2]),..., args[-1])
        The order is from outer to inner.

        Parameters
        ----------
        args : list of IterVars
            Itervars that proceed each other.

        Returns
        -------
        fused : IterVar
            The fused variable of iteration.
        """
        assert len(args) >= 1, "Length of the arguments must be >=1 for fuse."
        # Fold pairwise from the outermost axis inward.
        fused = args[0]
        for axis in args[1:]:
            fused = _api_internal._StageFuse(self, fused, axis)
        return fused

    def set_scope(self, scope):
        """Set the thread scope of this stage

        Parameters
        ----------
        scope : str
            The thread scope of this stage
        """
        return _api_internal._StageSetScope(self, scope)

    def bind(self, ivar, thread_ivar):
        """Bind ivar to thread index thread_ivar

        Parameters
        ----------
        ivar : IterVar
            The iteration to be binded to thread.

        thread_ivar : IterVar
            The thread to be binded.
        """
        _api_internal._StageBind(self, ivar, thread_ivar)

    def env_threads(self, threads):
        """Mark threads to be launched at the outer scope of composed op.

        Parameters
        ----------
        threads : list of threads
            The threads to be launched.
        """
        # Accept a single IterVar for convenience.
        if isinstance(threads, IterVar):
            threads = [threads]
        _api_internal._StageEnvThreads(self, threads)

    def set_store_predicate(self, predicate):
        """Set predicate under which store to the array can be performed.

        Use this when there are duplicated threads doing the same store and we only
        need one of them to do the store.

        Parameters
        ----------
        predicate : Expr
            The guard condition of store.
        """
        _api_internal._StageSetStorePredicate(self, predicate)

    def compute_at(self, parent, scope):
        """Attach the stage at parent's scope

        Parameters
        ----------
        parent : Stage
            The parent stage

        scope : IterVar
            The loop scope to be attached to.
        """
        _api_internal._StageComputeAt(self, parent, scope)

    def compute_inline(self):
        """Mark stage as inline"""
        _api_internal._StageComputeInline(self)

    def compute_root(self):
        """Attach the stage at parent, and mark it as root"""
        _api_internal._StageComputeRoot(self)

    def reorder(self, *args):
        """reorder the arguments in the specified order.

        Parameters
        ----------
        args : list of IterVar
            The order to be ordered
        """
        _api_internal._StageReorder(self, args)

    def tile(self, x_parent, y_parent, x_factor, y_factor):
        """Perform tiling on two dimensions

        The final loop order from outmost to inner most are
        [x_outer, y_outer, x_inner, y_inner]

        Parameters
        ----------
        x_parent : IterVar
            The original x dimension
        y_parent : IterVar
            The original y dimension
        x_factor : Expr
            The stride factor on x axis
        y_factor : Expr
            The stride factor on y axis

        Returns
        -------
        x_outer : IterVar
            Outer axis of x dimension
        y_outer : IterVar
            Outer axis of y dimension
        x_inner : IterVar
            Inner axis of x dimension
        y_inner : IterVar
            Inner axis of y dimension
        """
        x_outer, y_outer, x_inner, y_inner = _api_internal._StageTile(
            self, x_parent, y_parent, x_factor, y_factor)
        return x_outer, y_outer, x_inner, y_inner

    def vectorize(self, var):
        """Vectorize the iteration.

        Parameters
        ----------
        var : IterVar
            The iteration to be vectorize
        """
        _api_internal._StageVectorize(self, var)

    def tensorize(self, var, tensor_intrin):
        """Tensorize the computation enclosed by var with tensor_intrin

        Parameters
        ----------
        var : IterVar
            The iteration boundary of tensorization.

        tensor_intrin : TensorIntrin
            The tensor intrinsic used for computation.
        """
        _api_internal._StageTensorize(self, var, tensor_intrin)

    def unroll(self, var):
        """Unroll the iteration.

        Parameters
        ----------
        var : IterVar
            The iteration to be unrolled.
        """
        _api_internal._StageUnroll(self, var)

    def parallel(self, var):
        """Parallelize the iteration.

        Parameters
        ----------
        var : IterVar
            The iteration to be parallelized.
        """
        _api_internal._StageParallel(self, var)

    def pragma(self, var, pragma_type):
        """Annotate the iteration with pragma

        This will translate to a pragma_scope surrounding
        the corresponding loop generated.
        Useful to support experimental features and extensions.

        Parameters
        ----------
        var : IterVar
            The iteration to be annotated

        pragma_type : str
             The pragma string to be annotated

        Note
        ----
        Most pragmas are advanced/experimental features
        and may subject to change. List of supported pragmas:

        - **debug_skip_region**

          Force skip the region marked by the axis and turn it into no-op.
          This is useful for debug purposes.

        - **parallel_launch_point**

          Specify to launch parallel threads outside the
          specified iteration loop. By default the threads
          launch at the point of parallel construct.
          This pragma moves the launching point to even outer scope.
          The threads are launched once and reused across multiple
          parallel constructs as BSP style program.

        - **parallel_barrier_when_finish**

          Insert a synchronization barrier between working threads
          after the specified loop iteration finishes.

        - **parallel_stride_pattern**

          Hint parallel loop to execute in strided pattern.
          :code:`for (int i = task_id; i < end; i += num_task)`

        """
        _api_internal._StagePragma(self, var, pragma_type)

    def prefetch(self, tensor, var, offset):
        """Prefetch the specified variable

        Parameters
        ----------
        tensor : Tensor
            The tensor to be prefetched
        var : IterVar
            The loop point at which the prefetching is applied
        offset : Expr
            The number of iterations to be prefetched before actual execution
        """
        _api_internal._StagePrefetch(self, tensor, var, offset)

    def storage_align(self, axis, factor, offset):
        """Set alignment requirement for specific axis

        This ensures that stride[axis] == k * factor + offset for some k.
        This is useful to set memory layout to for more friendly memory
        access pattern. For example, we can set alignment to be
        factor=2, offset=1 to avoid bank conflict for thread access on
        higher dimension in GPU shared memory.

        Parameters
        ----------
        axis : IterVar
            The axis dimension to be aligned.
        factor : int
            The factor in alignment specification.
        offset : int
            The offset in the alignment specification.
        """
        _api_internal._StageStorageAlign(self, axis, factor, offset)

    def double_buffer(self):
        """Compute the current stage via double buffering.

        This can only be applied to intermediate stage.
        This will double the storage cost of the current stage.
        Can be useful to hide load latency.
        """
        _api_internal._StageDoubleBuffer(self)

614
_init_api("tvm.schedule")