"""
Schedule Primitives in TVM
==========================
**Author**: `Ziheng Jiang <https://github.com/ZihengJiang>`_

TVM is a domain-specific language for efficient kernel construction.

In this tutorial, we will show you how to schedule computations using
the various primitives provided by TVM.
"""
from __future__ import absolute_import, print_function

import tvm
import numpy as np

######################################################################
#
# There often exist several methods to compute the same result;
# however, different methods result in different locality and
# performance. So TVM asks the user to describe how to execute the
# computation; this description is called a **Schedule**.
#
# A **Schedule** is a set of transformations that rewrite the loops of
# the computations in the program.

# declare some variables for use later
n = tvm.var('n')
m = tvm.var('m')

######################################################################
# A schedule can be created from a list of ops. By default the
# schedule computes the tensors serially in row-major order.

# declare a matrix element-wise multiply
A = tvm.placeholder((m, n), name='A')
B = tvm.placeholder((m, n), name='B')
C = tvm.compute((m, n), lambda i, j: A[i, j] * B[i, j], name='C')

s = tvm.create_schedule([C.op])
# lower will transform the computation from its definition to a real
# callable function. With the argument `simple_mode=True`, it returns
# a readable C-like statement; we use it here to print the schedule
# result.
print(tvm.lower(s, [A, B, C], simple_mode=True))
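
######################################################################
# The default schedule visits :code:`C` serially in row-major order.
# For intuition, it corresponds to the pure Python loop nest below (an
# illustrative sketch of the loop structure only, not the actual
# lowered IR text, which varies across TVM versions).

def default_schedule_sketch(a, b):
    # serial, row-major traversal: i over rows, j over columns
    out = np.empty_like(a)
    for i in range(a.shape[0]):
        for j in range(a.shape[1]):
            out[i, j] = a[i, j] * b[i, j]
    return out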

######################################################################
# A schedule is composed of multiple stages, where each **Stage**
# represents the schedule for one operation. We provide various
# methods to schedule every stage.

######################################################################
# split
# -----
# :code:`split` can split a specified axis into two axes by
# :code:`factor`.
A = tvm.placeholder((m,), name='A')
B = tvm.compute((m,), lambda i: A[i]*2, name='B')

s = tvm.create_schedule(B.op)
xo, xi = s[B].split(B.op.axis[0], factor=32)
print(tvm.lower(s, [A, B], simple_mode=True))
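
######################################################################
# After the split, the single loop over :code:`i` becomes a two-level
# nest. Conceptually it behaves like the sketch below (illustrative
# only; the real lowered code also guards the tail iterations when
# the extent is not divisible by the factor).

def split_factor_sketch(a, factor=32):
    out = np.empty_like(a)
    extent = a.shape[0]
    for i_outer in range((extent + factor - 1) // factor):  # ceil(extent / factor)
        for i_inner in range(factor):                       # fixed inner extent
            i = i_outer * factor + i_inner
            if i < extent:                                  # tail guard
                out[i] = a[i] * 2
    return out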

######################################################################
# You can also split an axis by :code:`nparts`, which splits the axis
# in the opposite way to :code:`factor`: it fixes the number of outer
# iterations instead of the inner extent.
A = tvm.placeholder((m,), name='A')
B = tvm.compute((m,), lambda i: A[i], name='B')

s = tvm.create_schedule(B.op)
bx, tx = s[B].split(B.op.axis[0], nparts=32)
print(tvm.lower(s, [A, B], simple_mode=True))
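
######################################################################
# With :code:`nparts` the outer loop gets a fixed extent of 32 and the
# inner extent is derived from the shape, roughly as in this sketch
# (illustrative only):

def split_nparts_sketch(a, nparts=32):
    out = np.empty_like(a)
    extent = a.shape[0]
    factor = (extent + nparts - 1) // nparts  # inner extent derived from the shape
    for i_outer in range(nparts):             # outer extent fixed to nparts
        for i_inner in range(factor):
            i = i_outer * factor + i_inner
            if i < extent:                    # tail guard
                out[i] = a[i]
    return out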

######################################################################
# tile
# ----
# :code:`tile` helps you execute the computation tile by tile over two
# axes.
A = tvm.placeholder((m, n), name='A')
B = tvm.compute((m, n), lambda i, j: A[i, j], name='B')

s = tvm.create_schedule(B.op)
xo, yo, xi, yi = s[B].tile(B.op.axis[0], B.op.axis[1], x_factor=10, y_factor=5)
print(tvm.lower(s, [A, B], simple_mode=True))
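
######################################################################
# Tiling yields a four-level nest: the two outer loops step over
# 10x5 tiles and the two inner loops visit the elements within a
# tile, as this illustrative sketch shows:

def tile_sketch(a, x_factor=10, y_factor=5):
    out = np.empty_like(a)
    rows, cols = a.shape
    for i_outer in range((rows + x_factor - 1) // x_factor):
        for j_outer in range((cols + y_factor - 1) // y_factor):
            for i_inner in range(x_factor):
                for j_inner in range(y_factor):
                    i = i_outer * x_factor + i_inner
                    j = j_outer * y_factor + j_inner
                    if i < rows and j < cols:  # tail guards on both axes
                        out[i, j] = a[i, j]
    return out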

######################################################################
# fuse
# ----
# :code:`fuse` can fuse two consecutive axes of one computation.
A = tvm.placeholder((m, n), name='A')
B = tvm.compute((m, n), lambda i, j: A[i, j], name='B')

s = tvm.create_schedule(B.op)
# tile to four axes first: (i.outer, j.outer, i.inner, j.inner)
xo, yo, xi, yi = s[B].tile(B.op.axis[0], B.op.axis[1], x_factor=10, y_factor=5)
# then fuse (i.inner, j.inner) into one axis: (i.inner.j.inner.fused)
fused = s[B].fuse(xi, yi)
print(tvm.lower(s, [A, B], simple_mode=True))
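
######################################################################
# After fusion the two inner loops collapse into a single loop of
# extent x_factor * y_factor; the original indices are recovered by
# integer division and modulo, as in this illustrative sketch:

def fuse_sketch(a, x_factor=10, y_factor=5):
    out = np.empty_like(a)
    rows, cols = a.shape
    for i_outer in range((rows + x_factor - 1) // x_factor):
        for j_outer in range((cols + y_factor - 1) // y_factor):
            for fused in range(x_factor * y_factor):      # one fused inner loop
                i = i_outer * x_factor + fused // y_factor
                j = j_outer * y_factor + fused % y_factor
                if i < rows and j < cols:
                    out[i, j] = a[i, j]
    return out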

######################################################################
# reorder
# -------
# :code:`reorder` can reorder the axes in the specified order.
A = tvm.placeholder((m, n), name='A')
B = tvm.compute((m, n), lambda i, j: A[i, j], name='B')

s = tvm.create_schedule(B.op)
# tile to four axes first: (i.outer, j.outer, i.inner, j.inner)
xo, yo, xi, yi = s[B].tile(B.op.axis[0], B.op.axis[1], x_factor=10, y_factor=5)
# then reorder the axises: (i.inner, j.outer, i.outer, j.inner)
s[B].reorder(xi, yo, xo, yi)
print(tvm.lower(s, [A, B], simple_mode=True))
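
######################################################################
# The reordered nest iterates the axes outermost-to-innermost as
# (i.inner, j.outer, i.outer, j.inner), as sketched below
# (illustrative only):

def reorder_sketch(a, x_factor=10, y_factor=5):
    out = np.empty_like(a)
    rows, cols = a.shape
    for i_inner in range(x_factor):                               # hoisted outermost
        for j_outer in range((cols + y_factor - 1) // y_factor):
            for i_outer in range((rows + x_factor - 1) // x_factor):
                for j_inner in range(y_factor):                   # stays innermost
                    i = i_outer * x_factor + i_inner
                    j = j_outer * y_factor + j_inner
                    if i < rows and j < cols:
                        out[i, j] = a[i, j]
    return out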

######################################################################
# bind
# ----
# :code:`bind` can bind a specified axis to a thread axis; this is
# often used in GPU programming.
A = tvm.placeholder((n,), name='A')
B = tvm.compute(A.shape, lambda i: A[i] * 2, name='B')

s = tvm.create_schedule(B.op)
bx, tx = s[B].split(B.op.axis[0], factor=64)
s[B].bind(bx, tvm.thread_axis("blockIdx.x"))
s[B].bind(tx, tvm.thread_axis("threadIdx.x"))
print(tvm.lower(s, [A, B], simple_mode=True))
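
######################################################################
# With the axes bound, each CUDA block handles 64 elements and each
# thread computes one of them. Below is a minimal sketch of building
# and running the kernel, assuming a CUDA-enabled TVM build and an
# available GPU (the kernel name :code:`mul2` is just an example):

if tvm.module.enabled("cuda"):
    fmul2 = tvm.build(s, [A, B], "cuda", name="mul2")
    ctx = tvm.gpu(0)
    a_np = np.random.uniform(size=1024).astype(A.dtype)
    a = tvm.nd.array(a_np, ctx)
    b = tvm.nd.array(np.zeros(1024, dtype=B.dtype), ctx)
    fmul2(a, b)  # launched with 1024/64 blocks of 64 threads
    np.testing.assert_allclose(b.asnumpy(), a_np * 2)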

######################################################################
# compute_at
# ----------
# For a schedule that consists of multiple operators, TVM will compute
# the tensors at the root, i.e. separately, by default.
A = tvm.placeholder((m,), name='A')
B = tvm.compute((m,), lambda i: A[i]+1, name='B')
C = tvm.compute((m,), lambda i: B[i]*2, name='C')

s = tvm.create_schedule(C.op)
print(tvm.lower(s, [A, B, C], simple_mode=True))
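
######################################################################
# In this default (root) schedule, B is materialized completely
# before C starts: the lowered code contains two separate loops, as
# this illustrative sketch shows.

def root_sketch(a):
    b = np.empty_like(a)
    c = np.empty_like(a)
    for i in range(a.shape[0]):  # first loop: produce all of B
        b[i] = a[i] + 1
    for i in range(a.shape[0]):  # second loop: consume B to produce C
        c[i] = b[i] * 2
    return c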

######################################################################
# :code:`compute_at` can move the computation of `B` into the first
# axis of the computation of `C`.
A = tvm.placeholder((m,), name='A')
B = tvm.compute((m,), lambda i: A[i]+1, name='B')
C = tvm.compute((m,), lambda i: B[i]*2, name='C')

s = tvm.create_schedule(C.op)
s[B].compute_at(s[C], C.op.axis[0])
print(tvm.lower(s, [A, B, C], simple_mode=True))
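
######################################################################
# Now each element of B is produced inside C's loop, right before the
# element of C that consumes it, roughly as sketched below
# (illustrative only; in the lowered code B remains a buffer that is
# realized inside the loop).

def compute_at_sketch(a):
    c = np.empty_like(a)
    for i in range(a.shape[0]):
        b_i = a[i] + 1  # B's element computed inside C's loop body
        c[i] = b_i * 2
    return c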

######################################################################
# compute_inline
# --------------
# :code:`compute_inline` can mark one stage as inline; the body of the
# computation will then be expanded and inserted at the location where
# the tensor is required.
A = tvm.placeholder((m,), name='A')
B = tvm.compute((m,), lambda i: A[i]+1, name='B')
C = tvm.compute((m,), lambda i: B[i]*2, name='C')

s = tvm.create_schedule(C.op)
s[B].compute_inline()
print(tvm.lower(s, [A, B, C], simple_mode=True))
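
######################################################################
# After inlining, B disappears as a separate stage and its body is
# substituted directly into C (illustrative sketch):

def inline_sketch(a):
    c = np.empty_like(a)
    for i in range(a.shape[0]):
        c[i] = (a[i] + 1) * 2  # B's body expanded in place
    return c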

######################################################################
# compute_root
# ------------
# :code:`compute_root` can move the computation of one stage back to
# the root.
A = tvm.placeholder((m,), name='A')
B = tvm.compute((m,), lambda i: A[i]+1, name='B')
C = tvm.compute((m,), lambda i: B[i]*2, name='C')

s = tvm.create_schedule(C.op)
s[B].compute_at(s[C], C.op.axis[0])
s[B].compute_root()
print(tvm.lower(s, [A, B, C], simple_mode=True))
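
######################################################################
# The lowered code is now back to two separate loops: B is fully
# materialized first and C is computed from it afterwards, just like
# the default (root) schedule sketched earlier.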

######################################################################
# Summary
# -------
# This tutorial provides an introduction to schedule primitives in
# TVM, which permit users to schedule computations easily and
# flexibly.
#
# In order to get a kernel implementation with good performance, the
# general workflow often is:
#
# - Describe your computation via series of operations.
# - Try to schedule the computation with primitives.
# - Compile and run to see the performance difference.
# - Adjust your schedule according to the running results.