# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=unused-import
"""The computation schedule API of TVM."""
import tvm._ffi
from tvm._ffi.base import string_types

from tvm.runtime import Object, convert
from tvm.ir import container as _container
from tvm.tir import IterVar, Buffer

from . import tensor as _tensor
from . import _ffi_api


@tvm._ffi.register_object
class Split(Object):
    """Split operation on axis."""


@tvm._ffi.register_object
class Fuse(Object):
    """Fuse operation on axis."""


@tvm._ffi.register_object
class Singleton(Object):
    """Singleton axis."""


def create_schedule(ops):
    """Create a schedule for list of ops

    Parameters
    ----------
    ops : Operation or list of Operations
        The source operations.

    Returns
    -------
    sch : schedule.Schedule
        The created schedule.
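
    Example
    -------
    A minimal sketch (the one-stage workload below is illustrative):

    .. code-block:: python

        A = tvm.placeholder((128,), name="A")
        B = tvm.compute((128,), lambda i: A[i] + 1.0, name="B")
        s = tvm.create_schedule(B.op)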
    """
    if not isinstance(ops, (list, _container.Array)):
        ops = [ops]
    return _ffi_api.CreateSchedule(ops)


@tvm._ffi.register_object
class Schedule(Object):
    """Schedule for all the stages."""
    def __getitem__(self, k):
        if isinstance(k, _tensor.Tensor):
            k = k.op
        if not isinstance(k, _tensor.Operation):
            raise ValueError("Expect schedule key to be Tensor or Operation")
        if k not in self.stage_map:
            raise ValueError("Cannot find the operation %s in schedule" % (str(k)))
        return self.stage_map[k]

    def normalize(self):
        """Build a normalized schedule from the current schedule.

        Insert necessary rebase to make certain iter vars start from 0.
        This is needed before bound inference and follow-up steps.

        Returns
        -------
        sch : Schedule
            The normalized schedule.
        """
        return _ffi_api.ScheduleNormalize(self)

    def create_group(self, outputs, inputs, include_inputs=False):
        """Create stage group by giving output and input boundary.

        The operators between outputs and inputs are placed as member of group.
        outputs are include in the group, while inputs are not included.

        Parameters
        ----------
        outputs : list of Tensors
            The outputs of the group.

        inputs : list of Tensors
            The inputs of the group.

        include_inputs : boolean, optional
            Whether include input operations in the group if they are used by outputs.

        Returns
        -------
        group : Stage
            A virtual stage that represents the group. The user can use compute_at
            to move the attachment point of the group.
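
        Example
        -------
        A minimal sketch (the three-stage workload below is illustrative):

        .. code-block:: python

            A = tvm.placeholder((64,), name="A")
            B = tvm.compute((64,), lambda i: A[i] + 1.0, name="B")
            C = tvm.compute((64,), lambda i: B[i] * 2.0, name="C")
            s = tvm.create_schedule(C.op)
            # Group the stage computing B, bounded by inputs [A] and outputs [B].
            g = s.create_group(outputs=B, inputs=A)
            g.compute_at(s[C], C.op.axis[0])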
        """
        if isinstance(outputs, _tensor.Tensor):
            outputs = [outputs]
        if isinstance(inputs, _tensor.Tensor):
            inputs = [inputs]
        return _ffi_api.ScheduleCreateGroup(
            self, outputs, inputs, include_inputs)

    def cache_read(self, tensor, scope, readers):
        """Create a cache read of original tensor for readers.

        This will mutate the body of the readers.
        A new cache stage will be created for the tensor.
        Call this before doing any split/fuse schedule.

        Parameters
        ----------
        tensor : Tensor
            The tensor to be cached.
        scope : str
            The scope of the cache.
        readers : list of Tensor or Operation
            The readers to read the cache.

        Returns
        -------
        cache : Tensor
            The created cache tensor.
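
        Example
        -------
        A minimal sketch (the one-stage workload below is illustrative):

        .. code-block:: python

            A = tvm.placeholder((64, 64), name="A")
            B = tvm.compute((64, 64), lambda i, j: A[i, j] * 2.0, name="B")
            s = tvm.create_schedule(B.op)
            # Stage the reads of A through a cache in "shared" scope.
            AA = s.cache_read(A, "shared", [B])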
        """
        if isinstance(readers, (_tensor.Tensor, _tensor.Operation)):
            readers = [readers]
        readers = [t.op if isinstance(t, _tensor.Tensor) else t for t in readers]
        return _ffi_api.ScheduleCacheRead(self, tensor, scope, readers)

    def cache_write(self, tensor, scope):
        """Create a cache write of original tensor, before storing into tensor.

        This will mutate the body of the tensor.
        A new cache stage will be created before feeding into the tensor.

        This function can be used to support data layout transformation.
        If there is a split/fuse/reorder on the data parallel axis of the tensor
        before cache_write is called, the intermediate cache stores
        the data in the layout of the iteration order of the leaf axes.
        The data will be transformed back to the original layout in the original tensor.
        The user can further call compute_inline to inline the original layout and keep
        the data stored in the transformed layout.

        Parameters
        ----------
        tensor : Tensor, list or tuple
            The tensors to be fed to. All the tensors must be produced by one compute op.
        scope : str
            The scope of the cache.

        Returns
        -------
        cache : Tensor
            The created cache tensor.
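
        Example
        -------
        A minimal sketch (the one-stage workload below is illustrative):

        .. code-block:: python

            A = tvm.placeholder((64, 64), name="A")
            B = tvm.compute((64, 64), lambda i, j: A[i, j] + 1.0, name="B")
            s = tvm.create_schedule(B.op)
            # Compute into a cache in "local" scope before writing out to B.
            BL = s.cache_write(B, "local")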
        """
        return _ffi_api.ScheduleCacheWrite(self, tensor, scope)

    def rfactor(self, tensor, axis, factor_axis=0):
        """Factor a reduction axis in tensor's schedule to be an explicit axis.

        This will create a new stage that generates the new tensor with axis
        as the first dimension. The tensor's body will be rewritten as a reduction
        over the factored tensor.

        Parameters
        ----------
        tensor : Tensor
            The tensor to be factored.
        axis : IterVar
            The reduction axis in the schedule to be factored.
        factor_axis : int
            The position where the new axis is placed.

        Returns
        -------
        tfactor : Tensor or Array of Tensor
            The created factored tensor.
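
        Example
        -------
        A minimal sketch (the reduction workload below is illustrative):

        .. code-block:: python

            n = 128
            A = tvm.placeholder((n,), name="A")
            k = tvm.reduce_axis((0, n), name="k")
            B = tvm.compute((1,), lambda i: tvm.sum(A[k], axis=k), name="B")
            s = tvm.create_schedule(B.op)
            ko, ki = s[B].split(k, factor=16)
            # Factor the outer reduction axis out into an explicit dimension.
            BF = s.rfactor(B, ko)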
        """
        factored = _ffi_api.ScheduleRFactor(self, tensor, axis, factor_axis)
        return factored[0] if len(factored) == 1 else factored


@tvm._ffi.register_object
class Stage(Object):
    """A Stage represents a schedule for one operation."""
    def split(self, parent, factor=None, nparts=None):
        """Split the parent iteration axis, either by factor, which defines the
        extent of the inner axis, or by nparts, which defines the number of
        outer parts.

        Parameters
        ----------
        parent : IterVar
            The parent iter var.

        factor : Expr, optional
            The splitting factor.

        nparts : Expr, optional
            The number of outer parts.

        Returns
        -------
        outer : IterVar
            The outer variable of iteration.

        inner : IterVar
            The inner variable of iteration.
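
        Example
        -------
        A minimal sketch (the one-stage workload below is illustrative):

        .. code-block:: python

            A = tvm.placeholder((128,), name="A")
            B = tvm.compute((128,), lambda i: A[i] * 2.0, name="B")
            s = tvm.create_schedule(B.op)
            # Split the data-parallel axis into 32-wide inner iterations.
            xo, xi = s[B].split(B.op.axis[0], factor=32)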
        """
        if nparts is not None:
            if factor is not None:
                raise ValueError("Do not need to provide both factor and nparts")
            outer, inner = _ffi_api.StageSplitByNParts(self, parent, nparts)
        else:
            if factor is None:
                raise ValueError("Either nparts or factor needs to be provided")
            outer, inner = _ffi_api.StageSplitByFactor(self, parent, factor)
        return outer, inner

    def fuse(self, *args):
        """Fuse multiple consecutive iteration variables into a single iteration variable.

        fused = fuse(...fuse(fuse(args[0], args[1]), args[2]),..., args[-1])
        The order is from outer to inner.

        Parameters
        ----------
        args : list of IterVars
            IterVars that directly follow one another in the loop nest.

        Returns
        -------
        fused : IterVar
            The fused variable of iteration.
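
        Example
        -------
        A minimal sketch (the one-stage workload below is illustrative):

        .. code-block:: python

            A = tvm.placeholder((32, 32), name="A")
            B = tvm.compute((32, 32), lambda i, j: A[i, j] + 1.0, name="B")
            s = tvm.create_schedule(B.op)
            # Fuse the two data-parallel axes into a single loop.
            fused = s[B].fuse(B.op.axis[0], B.op.axis[1])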
        """
        fused = _ffi_api.StageFuse(self, args)
        return fused

    def set_scope(self, scope):
        """Set the thread scope of this stage

        Parameters
        ----------
        scope : str
            The thread scope of this stage
        """
        return _ffi_api.StageSetScope(self, scope)

    def bind(self, ivar, thread_ivar):
        """Bind ivar to thread index thread_ivar

        Parameters
        ----------
        ivar : IterVar
            The iteration to be bound to the thread.

        thread_ivar : IterVar
            The thread axis to be bound to the iteration.
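
        Example
        -------
        A minimal sketch for a GPU schedule (the workload below is illustrative):

        .. code-block:: python

            A = tvm.placeholder((1024,), name="A")
            B = tvm.compute((1024,), lambda i: A[i] + 1.0, name="B")
            s = tvm.create_schedule(B.op)
            bx, tx = s[B].split(B.op.axis[0], factor=64)
            # Bind the outer axis to blocks and the inner axis to threads.
            s[B].bind(bx, tvm.thread_axis("blockIdx.x"))
            s[B].bind(tx, tvm.thread_axis("threadIdx.x"))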
        """
        _ffi_api.StageBind(self, ivar, thread_ivar)

    def env_threads(self, threads):
        """Mark threads to be launched at the outer scope of composed op.

        Parameters
        ----------
        threads : list of threads
            The threads to be launched.
        """
        if isinstance(threads, IterVar):
            threads = [threads]
        _ffi_api.StageEnvThreads(self, threads)

    def set_store_predicate(self, predicate):
        """Set predicate under which store to the array can be performed.

        Use this when there are duplicated threads doing the same store and we only
        need one of them to do the store.

        Parameters
        ----------
        predicate : Expr
            The guard condition for the store.
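
        Example
        -------
        A minimal sketch, assuming ``tx`` is a ``threadIdx.x`` axis already bound
        in a cross-thread reduction and ``B`` is the reduced tensor:

        .. code-block:: python

            # Let only thread 0 perform the final store.
            s[B].set_store_predicate(tx.var.equal(0))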
        """
        _ffi_api.StageSetStorePredicate(self, predicate)

    def compute_at(self, parent, scope):
        """Attach the stage at the parent's scope.

        Parameters
        ----------
        parent : Stage
            The parent stage.

        scope : IterVar
            The loop scope to be attached to.
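
        Example
        -------
        A minimal sketch (the two-stage workload below is illustrative):

        .. code-block:: python

            A = tvm.placeholder((64,), name="A")
            B = tvm.compute((64,), lambda i: A[i] + 1.0, name="B")
            C = tvm.compute((64,), lambda i: B[i] * 2.0, name="C")
            s = tvm.create_schedule(C.op)
            xo, xi = s[C].split(C.op.axis[0], factor=8)
            # Compute B inside the outer loop of C.
            s[B].compute_at(s[C], xo)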
        """
        _ffi_api.StageComputeAt(self, parent, scope)

    def compute_inline(self):
        """Mark the stage as inline.

        The computation body will be expanded in place at its consumers.
        """
        _ffi_api.StageComputeInline(self)

    def compute_root(self):
        """Attach the stage at the root of the schedule and mark it as root.
        """
        _ffi_api.StageComputeRoot(self)

    def reorder(self, *args):
        """Reorder the iteration axes in the specified order.

        Parameters
        ----------
        args : list of IterVar
            The axes in the desired order, from outermost to innermost.
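
        Example
        -------
        A minimal sketch (the one-stage workload below is illustrative):

        .. code-block:: python

            A = tvm.placeholder((32, 32), name="A")
            B = tvm.compute((32, 32), lambda i, j: A[i, j] + 1.0, name="B")
            s = tvm.create_schedule(B.op)
            i, j = B.op.axis
            # Make j the outer loop and i the inner loop.
            s[B].reorder(j, i)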
        """
        _ffi_api.StageReorder(self, args)

    def tile(self, x_parent, y_parent, x_factor, y_factor):
        """Perform tiling on two dimensions.

        The final loop order from outermost to innermost is
        [x_outer, y_outer, x_inner, y_inner].

        Parameters
        ----------
        x_parent : IterVar
            The original x dimension
        y_parent : IterVar
            The original y dimension
        x_factor : Expr
            The stride factor on x axis
        y_factor : Expr
            The stride factor on y axis

        Returns
        -------
        x_outer : IterVar
            Outer axis of x dimension
        y_outer : IterVar
            Outer axis of y dimension
        x_inner : IterVar
            Inner axis of x dimension
        y_inner : IterVar
            Inner axis of y dimension
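
        Example
        -------
        A minimal sketch (the one-stage workload below is illustrative):

        .. code-block:: python

            A = tvm.placeholder((128, 128), name="A")
            B = tvm.compute((128, 128), lambda i, j: A[i, j] * 2.0, name="B")
            s = tvm.create_schedule(B.op)
            # Tile the two axes into 32x32 blocks.
            xo, yo, xi, yi = s[B].tile(B.op.axis[0], B.op.axis[1], 32, 32)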
        """
        x_outer, y_outer, x_inner, y_inner = _ffi_api.StageTile(
            self, x_parent, y_parent, x_factor, y_factor)
        return x_outer, y_outer, x_inner, y_inner

    def vectorize(self, var):
        """Vectorize the iteration.

        Parameters
        ----------
        var : IterVar
            The iteration to be vectorized.
        """
        _ffi_api.StageVectorize(self, var)

    def tensorize(self, var, tensor_intrin):
        """Tensorize the computation enclosed by var with tensor_intrin

        Parameters
        ----------
        var : IterVar
            The iteration boundary of tensorization.

        tensor_intrin : TensorIntrin
            The tensor intrinsic used for computation.
        """
        _ffi_api.StageTensorize(self, var, tensor_intrin)

    def unroll(self, var):
        """Unroll the iteration.

        Parameters
        ----------
        var : IterVar
            The iteration to be unrolled.
        """
        _ffi_api.StageUnroll(self, var)

    def parallel(self, var):
        """Parallelize the iteration.

        Parameters
        ----------
        var : IterVar
            The iteration to be parallelized.
        """
        _ffi_api.StageParallel(self, var)

    def pragma(self, var, pragma_type, pragma_value=None):
        """Annotate the iteration with a pragma.

        This will translate to a pragma_scope surrounding
        the corresponding loop generated.
        Useful to support experimental features and extensions.

        Parameters
        ----------
        var : IterVar
            The iteration to be annotated

        pragma_type : str
            The pragma string to be annotated

        pragma_value : Expr, optional
             The pragma value to pass along the pragma

        Note
        ----
        Most pragmas are advanced/experimental features
        and may be subject to change. List of supported pragmas:

        - **debug_skip_region**

          Force skip the region marked by the axis and turn it into no-op.
          This is useful for debug purposes.

        - **parallel_launch_point**

          Specify to launch parallel threads outside the
          specified iteration loop. By default the threads
          launch at the point of the parallel construct.
          This pragma moves the launching point to an even outer scope.
          The threads are launched once and reused across multiple
          parallel constructs, as in a BSP-style program.

        - **parallel_barrier_when_finish**

          Insert a synchronization barrier between working threads
          after the specified loop iteration finishes.

        - **parallel_stride_pattern**

          Hint parallel loop to execute in strided pattern.
          :code:`for (int i = task_id; i < end; i += num_task)`
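
        Example
        -------
        A minimal sketch (the one-stage workload below is illustrative):

        .. code-block:: python

            A = tvm.placeholder((256,), name="A")
            B = tvm.compute((256,), lambda i: A[i] + 1.0, name="B")
            s = tvm.create_schedule(B.op)
            xo, xi = s[B].split(B.op.axis[0], factor=32)
            # Skip the marked region entirely, for debugging.
            s[B].pragma(xo, "debug_skip_region")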

        """
        if isinstance(pragma_value, string_types):
            pragma_value = convert(pragma_value)
        _ffi_api.StagePragma(self, var, pragma_type, pragma_value)

    def prefetch(self, tensor, var, offset):
        """Prefetch the specified variable

        Parameters
        ----------
        tensor : Tensor
            The tensor to be prefetched
        var : IterVar
            The loop point at which the prefetching is applied
        offset : Expr
            The number of iterations to be prefetched before actual execution
        """
        _ffi_api.StagePrefetch(self, tensor, var, offset)

    def storage_align(self, axis, factor, offset):
        """Set alignment requirement for specific axis

        This ensures that stride[axis] == k * factor + offset for some k.
        This is useful to set the memory layout for a more friendly memory
        access pattern. For example, we can set alignment to be
        factor=2, offset=1 to avoid bank conflicts for thread access on
        the higher dimension in GPU shared memory.

        Parameters
        ----------
        axis : IterVar
            The axis dimension to be aligned.
        factor : int
            The factor in alignment specification.
        offset : int
            The offset in the alignment specification.
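
        Example
        -------
        A minimal sketch (the shared-memory cache stage below is illustrative):

        .. code-block:: python

            A = tvm.placeholder((64, 64), name="A")
            B = tvm.compute((64, 64), lambda i, j: A[i, j] + 1.0, name="B")
            s = tvm.create_schedule(B.op)
            AA = s.cache_read(A, "shared", [B])
            # Request stride[axis] == 2 * k + 1 to stagger rows across banks.
            s[AA].storage_align(AA.op.axis[0], 2, 1)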
        """
        _ffi_api.StageStorageAlign(self, axis, factor, offset)

    def double_buffer(self):
        """Compute the current stage via double buffering.

        This can only be applied to an intermediate stage.
        This will double the storage cost of the current stage.
        Can be useful to hide load latency.
        """
        _ffi_api.StageDoubleBuffer(self)

    def opengl(self):
        """The special OpenGL schedule

        Maps each output element to a pixel.
        """
        _ffi_api.StageOpenGL(self)


@tvm._ffi.register_object
class SpecializedCondition(Object):
    """Specialized condition to enable op specialization."""
    def __init__(self, conditions):
        """Create a specialized condition.

        .. note::
            Conditions are represented in conjunctive normal form (CNF).
            Each condition should be a simple expression, e.g., n > 16,
            m % 8 == 0, etc., where n, m are tvm.Var that represent a
            dimension in the tensor shape.

        Parameters
        ----------
        conditions : List of tvm.Expr
            List of conditions in conjunctive normal form (CNF).
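
        Example
        -------
        A minimal sketch (``n`` is an illustrative shape variable):

        .. code-block:: python

            n = tvm.var("n")
            # Specialize schedules for shapes where n is large.
            with SpecializedCondition([n > 16]):
                pass  # build the specialized schedule here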
        """
        if not isinstance(conditions, (list, _container.Array)):
            conditions = [conditions]
        self.__init_handle_by_constructor__(
            _ffi_api.CreateSpecializedCondition, conditions)

    @staticmethod
    def current():
        """Returns the current specialized condition"""
        return _ffi_api.GetCurrentSpecialization()

    def __enter__(self):
        _ffi_api.EnterSpecializationScope(self)
        return self

    def __exit__(self, ptype, value, trace):
        _ffi_api.ExitSpecializationScope(self)


tvm._ffi._init_api("schedule", __name__)