schedule.py 20.3 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""The computation schedule api of TVM."""
from __future__ import absolute_import as _abs
from ._ffi.base import string_types
from ._ffi.node import NodeBase, register_node
from ._ffi.node import convert_to_node as _convert_to_node
from ._ffi.function import _init_api, Function
from ._ffi.function import convert_to_tvm_func as _convert_tvm_func
from . import _api_internal
from . import tensor as _tensor
from . import expr as _expr
from . import container as _container


def convert(value):
    """Coerce a Python value into its TVM representation.

    Parameters
    ----------
    value : python value
        The value to be converted.

    Returns
    -------
    tvm_val : Node or Function
        The converted value in TVM.
    """
    if isinstance(value, (Function, NodeBase)):
        # Already a TVM object: pass it through untouched.
        return value
    # Callables become packed functions; everything else is converted
    # through the generic node conversion path.
    return _convert_tvm_func(value) if callable(value) else _convert_to_node(value)


@register_node
class Buffer(NodeBase):
    """Symbolic data buffer in TVM.

    Buffer provide a way to represent data layout
    specialization of data structure in TVM.

    Do not construct directly, use :any:`decl_buffer` instead.
    See the documentation of :any:`decl_buffer` for more details.

    See Also
    --------
    decl_buffer : Declare a buffer
    """
    # Bitmask flags accepted by access_ptr's access_mask argument.
    READ = 1
    WRITE = 2

    def access_ptr(self, access_mask, ptr_type="handle", content_lanes=1, offset=0):
        """Get an access pointer to the head of buffer.

        This is the recommended method to get buffer data
        address when interacting with external functions.

        Parameters
        ----------
        access_mask : int
            The access pattern MASK. Indicate whether the
            access will read or write to the data content.

        ptr_type : str, optional
            The data type of the result pointer. Do not specify
            unless we want to cast pointer to specific type.

        content_lanes: int, optional
            The number of lanes for the data type. This value
            is greater than one for vector types.

        offset: Expr, optional
            The offset of pointer. We can use it to offset by
            the number of elements from the address of ptr.

        Examples
        --------
        .. code-block:: python

          import tvm.schedule.Buffer
          # Get access ptr for read
          buffer.access_ptr("r")
          # Get access ptr for read/write with bitmask
          buffer.access_ptr(Buffer.READ | Buffer.WRITE)
          # Get access ptr for read/write with str flag
          buffer.access_ptr("rw")
          # Get access ptr for read with offset
          buffer.access_ptr("r", offset = 100)
        """
        # Accept "r"/"w"/"rw" string shorthands and fold them into the bitmask.
        if isinstance(access_mask, string_types):
            mask = 0
            for value in access_mask:
                if value == "r":
                    mask = mask | Buffer.READ
                elif value == "w":
                    mask = mask | Buffer.WRITE
                else:
                    raise ValueError("Unknown access_mask %s" % access_mask)
            access_mask = mask
        offset = convert(offset)
        return _api_internal._BufferAccessPtr(self, access_mask, ptr_type,
                                              content_lanes, offset)

    def vload(self, begin, dtype=None):
        """Generate an Expr that loads dtype from begin index.

        Parameters
        ----------
        begin : Array of Expr
            The beginning index in unit of Buffer.dtype

        dtype : str
            The data type to be loaded,
            can be vector type which have lanes that is multiple of Buffer.dtype

        Returns
        -------
        load : Expr
            The corresponding load expression.
        """
        # A scalar index is shorthand for a 1-D begin tuple.
        begin = (begin,) if isinstance(begin, (int, _expr.Expr)) else begin
        dtype = dtype if dtype else self.dtype
        return _api_internal._BufferVLoad(self, begin, dtype)

    def vstore(self, begin, value):
        """Generate a Stmt that store value into begin index.

        Parameters
        ----------
        begin : Array of Expr
            The beginning index in unit of Buffer.dtype

        value : Expr
            The value to be stored.

        Returns
        -------
        store : Stmt
            The corresponding store stmt.
        """
        # A scalar index is shorthand for a 1-D begin tuple.
        begin = (begin,) if isinstance(begin, (int, _expr.Expr)) else begin
        return _api_internal._BufferVStore(self, begin, value)


@register_node
class Split(NodeBase):
    """Split operation on axis."""


@register_node
class Fuse(NodeBase):
    """Fuse operation on axis."""


@register_node
class Singleton(NodeBase):
    """Singleton axis."""


@register_node
class IterVar(NodeBase, _expr.ExprOp):
    """Represent iteration variable.

    IterVar is normally created by Operation, to represent
    axis iterations in the computation.
    It can also be created by schedule primitives like :any:`tvm.schedule.Stage.split`.

    See Also
    --------
    tvm.thread_axis: Create thread axis IterVar.
    tvm.reduce_axis: Create reduce axis IterVar.
    """
    # Iteration kinds; values mirror the backing IterVarType enum — TODO confirm
    # against the C++ side before changing.
    DataPar = 0
    ThreadIndex = 1
    CommReduce = 2
    Ordered = 3
    DimInfo = 4
    Unrolled = 5
    Vectorized = 6
    Parallelized = 7
    Tensorized = 8


# Let the tensor module construct IterVar without importing this module.
_tensor.iter_var_cls = IterVar


def create_schedule(ops):
    """Create a schedule for list of ops

    Parameters
    ----------
    ops : list of Operations
        The source expression.

    Returns
    -------
    sch : schedule.Schedule
        The created schedule.
    """
    # Accept a single operation as a convenience.
    if not isinstance(ops, (list, _container.Array)):
        ops = [ops]
    return _api_internal._CreateSchedule(ops)



@register_node
class Schedule(NodeBase):
    """Schedule for all the stages."""
    def __getitem__(self, k):
        """Look up the Stage scheduled for operation (or tensor) `k`."""
        if isinstance(k, _tensor.Tensor):
            k = k.op
        if not isinstance(k, _tensor.Operation):
            raise ValueError("Expect schedule key to be Tensor or Operation")
        if k not in self.stage_map:
            raise ValueError("Cannot find the operation %s in schedule" % (str(k)))
        return self.stage_map[k]

    def normalize(self):
        """Build a normalized schedule from the current schedule.

        Insert necessary rebase to make certain iter var to start from 0.
        This is needed before bound inference and followup step.

        Returns
        -------
        sch : Schedule
            The normalized schedule.
        """
        return _api_internal._ScheduleNormalize(self)

    def create_group(self, outputs, inputs, include_inputs=False):
        """Create stage group by giving output and input boundary.

        The operators between outputs and inputs are placed as member of group.
        outputs are included in the group, while inputs are not included.

        Parameters
        ----------
        outputs : list of Tensors
            The outputs of the group.

        inputs : list of Tensors
            The inputs of the group.

        include_inputs : boolean, optional
            Whether include input operations in the group if they are used by outputs.

        Returns
        -------
        group : Stage
            A virtual stage represents the group, user can use compute_at to move
            the attachment point of the group.
        """
        # Allow single tensors for both boundaries.
        if isinstance(outputs, _tensor.Tensor):
            outputs = [outputs]
        if isinstance(inputs, _tensor.Tensor):
            inputs = [inputs]
        return _api_internal._ScheduleCreateGroup(
            self, outputs, inputs, include_inputs)

    def cache_read(self, tensor, scope, readers):
        """Create a cache read of original tensor for readers.

        This will mutate the body of the readers.
        A new cache stage will be created for the tensor.
        Call this before doing any split/fuse schedule.

        Parameters
        ----------
        tensor : Tensor
            The tensor to be cached.
        scope : str
            The scope of cached
        readers : list of Tensor or Operation
            The readers to read the cache.

        Returns
        -------
        cache : Tensor
            The created cache tensor.
        """
        if isinstance(readers, (_tensor.Tensor, _tensor.Operation)):
            readers = [readers]
        # The internal API expects operations, not tensors.
        readers = [t.op if isinstance(t, _tensor.Tensor) else t for t in readers]
        return _api_internal._ScheduleCacheRead(self, tensor, scope, readers)

    def cache_write(self, tensor, scope):
        """Create a cache write of original tensor, before storing into tensor.

        This will mutate the body of the tensor.
        A new cache stage will be created before feed into the tensor.

        This function can be used to support data layout transformation.
        If there is a split/fuse/reorder on the data parallel axis of tensor
        before cache_write is called. The intermediate cache stores
        the data in the layout as the iteration order of leave axis.
        The data will be transformed back to the original layout in the original tensor.
        User can further call compute_inline to inline the original layout and keep
        the data stored in the transformed layout.

        Parameters
        ----------
        tensor : Tensor, list or tuple
            The tensors to be feed to. All the tensors must be produced by one computeOp
        scope : str
            The scope of cached

        Returns
        -------
        cache : Tensor
            The created cache tensor.
        """
        return _api_internal._ScheduleCacheWrite(self, tensor, scope)

    def rfactor(self, tensor, axis, factor_axis=0):
        """ Factor a reduction axis in tensor's schedule to be an explicit axis.

        This will create a new stage that generated the new tensor with axis
        as the first dimension. The tensor's body will be rewritten as a reduction
        over the factored tensor.

        Parameters
        ----------
        tensor : Tensor
            The tensor to be factored.
        axis : IterVar
            The reduction axis in the schedule to be factored.
        factor_axis : int
            The position where the new axis is placed.

        Returns
        -------
        tfactor : Tensor or Array of Tensor
            The created factored tensor.
        """
        factored = _api_internal._ScheduleRFactor(self, tensor, axis, factor_axis)
        # Unwrap singleton results for convenience.
        return factored[0] if len(factored) == 1 else factored



@register_node
class Stage(NodeBase):
    """A Stage represents schedule for one operation."""
    def split(self, parent, factor=None, nparts=None):
        """Split the stage either by factor providing outer scope, or both.

        Exactly one of factor and nparts must be provided.

        Parameters
        ----------
        parent : IterVar
             The parent iter var.

        factor : Expr, optional
             The splitting factor

        nparts : Expr, optional
             The number of outer parts.

        Returns
        -------
        outer : IterVar
            The outer variable of iteration.

        inner : IterVar
            The inner variable of iteration.
        """
        if nparts is not None:
            if factor is not None:
                raise ValueError("Do not need to provide both outer and nparts")
            outer, inner = _api_internal._StageSplitByNParts(self, parent, nparts)
        else:
            if factor is None:
                raise ValueError("Either nparts or factor need to be provided")
            outer, inner = _api_internal._StageSplitByFactor(self, parent, factor)
        return outer, inner

    def fuse(self, *args):
        """Fuse multiple consecutive iteration variables into a single iteration variable.

        fused = fuse(...fuse(fuse(args[0], args[1]), args[2]),..., args[-1])
        The order is from outer to inner.

        Parameters
        ----------
        args : list of IterVars
            Itervars that proceeds each other

        Returns
        -------
        fused : IterVar
            The fused variable of iteration.
        """
        fused = _api_internal._StageFuse(self, args)
        return fused

    def set_scope(self, scope):
        """Set the thread scope of this stage

        Parameters
        ----------
        scope : str
            The thread scope of this stage
        """
        return _api_internal._StageSetScope(self, scope)

    def bind(self, ivar, thread_ivar):
        """Bind ivar to thread index thread_ivar

        Parameters
        ----------
        ivar : IterVar
            The iteration to be binded to thread.

        thread_ivar : IterVar
            The thread to be binded.
        """
        _api_internal._StageBind(self, ivar, thread_ivar)

    def env_threads(self, threads):
        """Mark threads to be launched at the outer scope of composed op.

        Parameters
        ----------
        threads : list of threads
            The threads to be launched.
        """
        # Accept a single IterVar as a convenience.
        if isinstance(threads, IterVar):
            threads = [threads]
        _api_internal._StageEnvThreads(self, threads)

    def set_store_predicate(self, predicate):
        """Set predicate under which store to the array can be performed.

        Use this when there are duplicated threads doing the same store and we only
        need one of them to do the store.

        Parameters
        ----------
        predicate : Expr
            The guard condition for store.
        """
        _api_internal._StageSetStorePredicate(self, predicate)

    def compute_at(self, parent, scope):
        """Attach the stage at parent's scope

        Parameters
        ----------
        parent : Stage
            The parent stage

        scope : IterVar
            The loop scope to be attached to.
        """
        _api_internal._StageComputeAt(self, parent, scope)

    def compute_inline(self):
        """Mark stage as inline, so it is computed inside its consumers."""
        _api_internal._StageComputeInline(self)

    def compute_root(self):
        """Attach the stage at the root of the schedule, and mark it as root."""
        _api_internal._StageComputeRoot(self)

    def reorder(self, *args):
        """reorder the arguments in the specified order.

        Parameters
        ----------
        args : list of IterVar
            The order to be ordered
        """
        _api_internal._StageReorder(self, args)

    def tile(self, x_parent, y_parent, x_factor, y_factor):
        """ Perform tiling on two dimensions

        The final loop order from outmost to inner most are
        [x_outer, y_outer, x_inner, y_inner]

        Parameters
        ----------
        x_parent : IterVar
            The original x dimension
        y_parent : IterVar
            The original y dimension
        x_factor : Expr
            The stride factor on x axis
        y_factor : Expr
            The stride factor on y axis

        Returns
        -------
        x_outer : IterVar
            Outer axis of x dimension
        y_outer : IterVar
            Outer axis of y dimension
        x_inner : IterVar
            Inner axis of x dimension
        y_inner : IterVar
            Inner axis of y dimension
        """
        x_outer, y_outer, x_inner, y_inner = _api_internal._StageTile(
            self, x_parent, y_parent, x_factor, y_factor)
        return x_outer, y_outer, x_inner, y_inner

    def vectorize(self, var):
        """Vectorize the iteration.

        Parameters
        ----------
        var : IterVar
            The iteration to be vectorize
        """
        _api_internal._StageVectorize(self, var)

    def tensorize(self, var, tensor_intrin):
        """Tensorize the computation enclosed by var with tensor_intrin

        Parameters
        ----------
        var : IterVar
            The iteration boundary of tensorization.

        tensor_intrin : TensorIntrin
            The tensor intrinsic used for computation.
        """
        _api_internal._StageTensorize(self, var, tensor_intrin)

    def unroll(self, var):
        """Unroll the iteration.

        Parameters
        ----------
        var : IterVar
            The iteration to be unrolled.
        """
        _api_internal._StageUnroll(self, var)

    def parallel(self, var):
        """Parallelize the iteration.

        Parameters
        ----------
        var : IterVar
            The iteration to be parallelized.
        """
        _api_internal._StageParallel(self, var)

    def pragma(self, var, pragma_type, pragma_value=None):
        """Annotate the iteration with pragma

        This will translate to a pragma_scope surrounding
        the corresponding loop generated.
        Useful to support experimental features and extensions.

        Parameters
        ----------
        var : IterVar
            The iteration to be annotated

        pragma_type : str
             The pragma string to be annotated

        pragma_value : Expr, optional
             The pragma value to pass along the pragma

        Note
        ----
        Most pragmas are advanced/experimental features
        and may subject to change. List of supported pragmas:

        - **debug_skip_region**

          Force skip the region marked by the axis and turn it into no-op.
          This is useful for debug purposes.

        - **parallel_launch_point**

          Specify to launch parallel threads outside the
          specified iteration loop. By default the threads
          launch at the point of parallel construct.
          This pragma moves the launching point to even outer scope.
          The threads are launched once and reused across multiple
          parallel constructs as BSP style program.

        - **parallel_barrier_when_finish**

          Insert a synchronization barrier between working threads
          after the specified loop iteration finishes.

        - **parallel_stride_pattern**

          Hint parallel loop to execute in strided pattern.
          :code:`for (int i = task_id; i < end; i += num_task)`

        """
        # String pragma values must be converted to TVM nodes explicitly.
        if isinstance(pragma_value, string_types):
            pragma_value = convert(pragma_value)
        _api_internal._StagePragma(self, var, pragma_type, pragma_value)

    def prefetch(self, tensor, var, offset):
        """Prefetch the specified variable

        Parameters
        ----------
        tensor : Tensor
            The tensor to be prefetched
        var : IterVar
            The loop point at which the prefetching is applied
        offset : Expr
            The number of iterations to be prefetched before actual execution
        """
        _api_internal._StagePrefetch(self, tensor, var, offset)

    def storage_align(self, axis, factor, offset):
        """Set alignment requirement for specific axis

        This ensures that stride[axis] == k * factor + offset for some k.
        This is useful to set memory layout to for more friendly memory
        access pattern. For example, we can set alignment to be
        factor=2, offset=1 to avoid bank conflict for thread access on
        higher dimension in GPU shared memory.

        Parameters
        ----------
        axis : IterVar
            The axis dimension to be aligned.
        factor : int
            The factor in alignment specification.
        offset : int
            The offset in the alignment specification.
        """
        _api_internal._StageStorageAlign(self, axis, factor, offset)

    def double_buffer(self):
        """Compute the current stage via double buffering.

        This can only be applied to intermediate stage.
        This will double the storage cost of the current stage.
        Can be useful to hide load latency.
        """
        _api_internal._StageDoubleBuffer(self)

    def opengl(self):
        """The special OpenGL schedule

        Maps each output element to a pixel.
        """
        _api_internal._StageOpenGL(self)

# Register the C++-backed schedule API functions into this module.
_init_api("tvm.schedule")