# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Schedule Primitives in TVM
==========================
**Author**: `Ziheng Jiang <https://github.com/ZihengJiang>`_

TVM is a domain specific language for efficient kernel construction.

In this tutorial, we will show you how to schedule the computation by
various primitives provided by TVM.
"""
from __future__ import absolute_import, print_function

import tvm
import numpy as np

######################################################################
#
# There often exist several methods to compute the same result;
# however, different methods result in different locality and
# performance. TVM therefore asks the user to describe how to
# execute the computation. This description is called a **Schedule**.
#
# A **Schedule** is a set of transformations of a computation that
# rearranges the loops of the computation in the program.
#

# declare some variables for use later
n = tvm.var('n')
m = tvm.var('m')

######################################################################
# A schedule can be created from a list of ops. By default, the
# schedule computes each tensor serially in row-major order.

# declare a matrix element-wise multiply
A = tvm.placeholder((m, n), name='A')
B = tvm.placeholder((m, n), name='B')
C = tvm.compute((m, n), lambda i, j: A[i, j] * B[i, j], name='C')

s = tvm.create_schedule([C.op])
# lower will transform the computation from its definition into a
# real callable function. With the argument `simple_mode=True`, it
# returns a readable C-like statement; we use it here to print the
# schedule result.
print(tvm.lower(s, [A, B, C], simple_mode=True))

######################################################################
# A schedule is composed of multiple stages, and one
# **Stage** represents the schedule for one operation. We provide
# various methods to schedule every stage.
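
######################################################################
# As a small illustration added to the original text, the stage of an
# operation can be retrieved by indexing the schedule with its tensor;
# the exact printed representation may vary between TVM versions.

# the stage holding the loop nest of the element-wise multiply C
print(s[C])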

######################################################################
# split
# -----
# :code:`split` can split a specified axis into two axes by
# :code:`factor`.
A = tvm.placeholder((m,), name='A')
B = tvm.compute((m,), lambda i: A[i]*2, name='B')

s = tvm.create_schedule(B.op)
xo, xi = s[B].split(B.op.axis[0], factor=32)
print(tvm.lower(s, [A, B], simple_mode=True))

######################################################################
# You can also split an axis by :code:`nparts`, which splits the axis
# in the opposite direction of :code:`factor`: it fixes the extent of
# the outer loop instead of the inner one.
A = tvm.placeholder((m,), name='A')
B = tvm.compute((m,), lambda i: A[i], name='B')

s = tvm.create_schedule(B.op)
bx, tx = s[B].split(B.op.axis[0], nparts=32)
print(tvm.lower(s, [A, B], simple_mode=True))

######################################################################
# tile
# ----
# :code:`tile` helps you execute the computation tile by tile over
# two axes.
A = tvm.placeholder((m, n), name='A')
B = tvm.compute((m, n), lambda i, j: A[i, j], name='B')

s = tvm.create_schedule(B.op)
xo, yo, xi, yi = s[B].tile(B.op.axis[0], B.op.axis[1], x_factor=10, y_factor=5)
print(tvm.lower(s, [A, B], simple_mode=True))

######################################################################
# fuse
# ----
# :code:`fuse` can fuse two consecutive axes of one computation.
A = tvm.placeholder((m, n), name='A')
B = tvm.compute((m, n), lambda i, j: A[i, j], name='B')

s = tvm.create_schedule(B.op)
# tile to four axes first: (i.outer, j.outer, i.inner, j.inner)
xo, yo, xi, yi = s[B].tile(B.op.axis[0], B.op.axis[1], x_factor=10, y_factor=5)
# then fuse (i.inner, j.inner) into one axis: (i.inner.j.inner.fused)
fused = s[B].fuse(xi, yi)
print(tvm.lower(s, [A, B], simple_mode=True))

######################################################################
# reorder
# -------
# :code:`reorder` can reorder the axes in the specified order.
A = tvm.placeholder((m, n), name='A')
B = tvm.compute((m, n), lambda i, j: A[i, j], name='B')

s = tvm.create_schedule(B.op)
# tile to four axes first: (i.outer, j.outer, i.inner, j.inner)
xo, yo, xi, yi = s[B].tile(B.op.axis[0], B.op.axis[1], x_factor=10, y_factor=5)
# then reorder the axes: (i.inner, j.outer, i.outer, j.inner)
s[B].reorder(xi, yo, xo, yi)
print(tvm.lower(s, [A, B], simple_mode=True))

######################################################################
# bind
# ----
# :code:`bind` can bind a specified axis to a thread axis; this is
# often used in GPU programming.
A = tvm.placeholder((n,), name='A')
B = tvm.compute(A.shape, lambda i: A[i] * 2, name='B')

s = tvm.create_schedule(B.op)
bx, tx = s[B].split(B.op.axis[0], factor=64)
s[B].bind(bx, tvm.thread_axis("blockIdx.x"))
s[B].bind(tx, tvm.thread_axis("threadIdx.x"))
print(tvm.lower(s, [A, B], simple_mode=True))
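
######################################################################
# As a hedged illustration added here (not part of the original flow),
# the bound schedule above can be compiled into an actual GPU kernel
# with :code:`tvm.build` when a CUDA-enabled TVM build and a GPU are
# available; the function name :code:`mul2` and the array size are
# illustrative choices only.
if tvm.gpu(0).exist and tvm.module.enabled("cuda"):
    fmul2 = tvm.build(s, [A, B], "cuda", name="mul2")
    ctx = tvm.gpu(0)
    a = tvm.nd.array(np.random.uniform(size=1024).astype(A.dtype), ctx)
    b = tvm.nd.array(np.zeros(1024, dtype=B.dtype), ctx)
    fmul2(a, b)
    np.testing.assert_allclose(b.asnumpy(), a.asnumpy() * 2)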

######################################################################
# compute_at
# ----------
# For a schedule that consists of multiple operators, TVM will, by
# default, compute each tensor at the root, in its own separate loop
# nest.
A = tvm.placeholder((m,), name='A')
B = tvm.compute((m,), lambda i: A[i]+1, name='B')
C = tvm.compute((m,), lambda i: B[i]*2, name='C')

s = tvm.create_schedule(C.op)
print(tvm.lower(s, [A, B, C], simple_mode=True))

######################################################################
# :code:`compute_at` can move the computation of `B` into the first
# axis of the computation of `C`.
A = tvm.placeholder((m,), name='A')
B = tvm.compute((m,), lambda i: A[i]+1, name='B')
C = tvm.compute((m,), lambda i: B[i]*2, name='C')

s = tvm.create_schedule(C.op)
s[B].compute_at(s[C], C.op.axis[0])
print(tvm.lower(s, [A, B, C], simple_mode=True))

######################################################################
# compute_inline
# --------------
# :code:`compute_inline` can mark a stage as inline; the body of the
# computation is then expanded and inserted wherever the tensor is
# required.
A = tvm.placeholder((m,), name='A')
B = tvm.compute((m,), lambda i: A[i]+1, name='B')
C = tvm.compute((m,), lambda i: B[i]*2, name='C')

s = tvm.create_schedule(C.op)
s[B].compute_inline()
print(tvm.lower(s, [A, B, C], simple_mode=True))

######################################################################
# compute_root
# ------------
# :code:`compute_root` can move the computation of one stage back to the root.
A = tvm.placeholder((m,), name='A')
B = tvm.compute((m,), lambda i: A[i]+1, name='B')
C = tvm.compute((m,), lambda i: B[i]*2, name='C')

s = tvm.create_schedule(C.op)
s[B].compute_at(s[C], C.op.axis[0])
s[B].compute_root()
print(tvm.lower(s, [A, B, C], simple_mode=True))

######################################################################
# Summary
# -------
# This tutorial provides an introduction to schedule primitives in
# TVM, which permit users to schedule the computation easily and
# flexibly.
#
# In order to get a well-performing kernel implementation, the
# general workflow often is:
#
# - Describe your computation via a series of operations.
# - Try to schedule the computation with primitives.
# - Compile and run to see the performance difference (a minimal
#   build-and-run sketch is given below).
# - Adjust your schedule according to the running result.
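
######################################################################
# The snippet below is a hedged sketch of the last two workflow steps,
# added to the original tutorial for illustration: it compiles the
# final schedule above for the CPU and checks the result against numpy.
# The function name :code:`addmul` and the array size are illustrative
# choices, and the build is skipped when the LLVM backend is not
# enabled in this TVM build.
if tvm.module.enabled("llvm"):
    # compile the scheduled computation; B is an intermediate tensor,
    # but the caller also provides its buffer since it is in the arg list
    func = tvm.build(s, [A, B, C], "llvm", name="addmul")
    a = tvm.nd.array(np.random.uniform(size=128).astype(A.dtype))
    b = tvm.nd.array(np.zeros(128, dtype=B.dtype))
    c = tvm.nd.array(np.zeros(128, dtype=C.dtype))
    func(a, b, c)
    np.testing.assert_allclose(c.asnumpy(), (a.asnumpy() + 1) * 2)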