"""
Get Started with TVM
====================
**Author**: `Tianqi Chen <https://tqchen.github.io>`_

This is an introduction tutorial to TVM.
TVM is a domain specific language for efficient kernel construction.

In this tutorial, we will demonstrate the basic workflow in TVM.
"""
from __future__ import absolute_import, print_function

import tvm
import numpy as np

# Global declarations of environment.

tgt_host = "llvm"
# Change it to the respective GPU backend if a GPU is enabled, e.g. "cuda" or "opencl".
tgt = "cuda"

######################################################################
# Vector Add Example
# ------------------
# In this tutorial, we will use a vector addition example to demonstrate
# the workflow.
#

######################################################################
# Describe the Computation
# ------------------------
# As a first step, we need to describe our computation.
# TVM adopts tensor semantics, with each intermediate result
# represented as a multi-dimensional array. The user needs to describe
# the computation rule that generates the tensors.
#
# We first define a symbolic variable n to represent the shape.
# We then define two placeholder Tensors, A and B, with the given shape (n,).
#
# We then describe the result tensor C, with a compute operation.
# The compute function takes the shape of the tensor, as well as a lambda function
# that describes the computation rule for each position of the tensor.
#
# No computation happens during this phase, as we are only declaring how
# the computation should be done.
#
n = tvm.var("n")
A = tvm.placeholder((n,), name='A')
B = tvm.placeholder((n,), name='B')
C = tvm.compute(A.shape, lambda i: A[i] + B[i], name="C")
print(type(C))

######################################################################
# Schedule the Computation
# ------------------------
# While the above lines describe the computation rule, we can compute
# C in many ways since the axis of C can be computed in a data parallel manner.
# TVM asks the user to provide a description of the computation called a schedule.
#
# A schedule is a set of transformations that rewrite the loops of
# the computation in the program.
#
# After we construct the schedule, by default the schedule computes
# C in a serial manner in a row-major order.
#
# .. code-block:: c
#
#   for (int i = 0; i < n; ++i) {
#     C[i] = A[i] + B[i];
#   }
#
s = tvm.create_schedule(C.op)
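
######################################################################
# As an optional sanity check, we can print the lowered IR to confirm the
# default serial loop described above. This is a sketch that assumes this
# TVM version exposes :code:`tvm.lower` with the :code:`simple_mode` flag.
print(tvm.lower(s, [A, B, C], simple_mode=True))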

######################################################################
# We use the split construct to split the first axis of C.
# This splits the original iteration axis into a product of
# two iterations. This is equivalent to the following code.
#
# .. code-block:: c
#
#   for (int bx = 0; bx < ceil(n / 64); ++bx) {
#     for (int tx = 0; tx < 64; ++tx) {
#       int i = bx * 64 + tx;
#       if (i < n) {
#         C[i] = A[i] + B[i];
#       }
#     }
#   }
#
bx, tx = s[C].split(C.op.axis[0], factor=64)

######################################################################
# Finally we bind the iteration axes bx and tx to threads in the GPU
# compute grid. These are GPU specific constructs that allow us
# to generate code that runs on the GPU.
#
if tgt == "cuda":
    s[C].bind(bx, tvm.thread_axis("blockIdx.x"))
    s[C].bind(tx, tvm.thread_axis("threadIdx.x"))

######################################################################
# Compilation
# -----------
# After we have finished specifying the schedule, we can compile it
# into a TVM function. By default TVM compiles into a type-erased
# function that can be directly called from the python side.
#
# In the following line, we use tvm.build to create a function.
# The build function takes the schedule, the desired signature of the
# function (including the inputs and outputs) as well as the target language
# we want to compile to.
#
# The result of compilation, fadd, is a GPU device function (if a GPU is
# involved) as well as a host wrapper that calls into the GPU function.
# fadd is the generated host wrapper function; it contains a reference
# to the generated device function internally.
#
fadd = tvm.build(s, [A, B, C], tgt, target_host=tgt_host, name="myadd")

######################################################################
# Run the Function
# ----------------
# The compiled TVM function is designed to be a concise C API
# that can be invoked from any language.
#
# We provide a minimal array API in python to aid quick testing and prototyping.
# The array API is based on the `DLPack <https://github.com/dmlc/dlpack>`_ standard.
#
# - We first create a GPU context.
# - Then tvm.nd.array copies the data to the GPU.
# - fadd runs the actual computation.
# - asnumpy() copies the GPU array back to the CPU, so we can use it to verify correctness.
#
ctx = tvm.context(tgt, 0)

n = 1024
a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), ctx)
b = tvm.nd.array(np.random.uniform(size=n).astype(B.dtype), ctx)
c = tvm.nd.array(np.zeros(n, dtype=C.dtype), ctx)
fadd(a, b, c)
np.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy())

######################################################################
# Inspect the Generated Code
# --------------------------
# You can inspect the generated code in TVM. The result of tvm.build
# is a tvm Module. fadd is the host module that contains the host wrapper;
# it also contains a device module for the CUDA (GPU) function.
#
# The following code fetches the device module and prints its source code.
#
if tgt == "cuda":
    dev_module = fadd.imported_modules[0]
    print("-----GPU code-----")
    print(dev_module.get_source())
else:
    print(fadd.get_source())

######################################################################
# .. note:: Code Specialization
#
#   As you may have noticed, during the declaration, A, B and C all
#   take the same shape argument n. TVM will take advantage of this
#   to pass only a single shape argument to the kernel, as you will find in
#   the printed device code. This is one form of specialization.
#
#   On the host side, TVM will automatically generate check code
#   that checks the constraints in the parameters. So if you pass
#   arrays with different shapes into fadd, an error will be raised.
#
#   We can do more specializations. For example, we can write
#   :code:`n = tvm.convert(1024)` instead of :code:`n = tvm.var("n")`
#   in the computation declaration. The generated function will then
#   only take vectors of length 1024.
#
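
######################################################################
# The following is a small illustrative sketch, not part of the original
# tutorial flow. It first shows that passing an array of a different length
# to fadd trips the generated host-side shape check, and then builds a
# fixed-size variant using :code:`tvm.convert(1024)`. The names n_fixed,
# A_fixed, B_fixed, C_fixed, s_fixed and fadd_fixed are hypothetical and
# introduced only for this sketch.

# Mismatched shape: the host-side check is expected to raise an error.
try:
    wrong = tvm.nd.array(np.zeros(n // 2, dtype=C.dtype), ctx)
    fadd(a, b, wrong)
except Exception as err:
    print("shape check raised an error as expected:", type(err).__name__)

# Fixed-size specialization: the generated function only accepts
# vectors of length 1024.
n_fixed = tvm.convert(1024)
A_fixed = tvm.placeholder((n_fixed,), name='A')
B_fixed = tvm.placeholder((n_fixed,), name='B')
C_fixed = tvm.compute(A_fixed.shape, lambda i: A_fixed[i] + B_fixed[i], name="C")
s_fixed = tvm.create_schedule(C_fixed.op)
# Build for the host target only, so this sketch does not require a GPU schedule.
fadd_fixed = tvm.build(s_fixed, [A_fixed, B_fixed, C_fixed], tgt_host, name="myadd_fixed")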

######################################################################
# Save Compiled Module
# --------------------
# Besides runtime compilation, we can save the compiled modules into
# a file and load them back later. This is called ahead-of-time compilation.
#
# The code below performs the following steps:
#
# - It saves the compiled host module into an object file.
# - Then it saves the device module into a ptx file.
# - cc.create_shared calls an environment compiler (gcc) to create a shared library.
#
from tvm.contrib import cc
from tvm.contrib import util

temp = util.tempdir()
fadd.save(temp.relpath("myadd.o"))
if tgt == "cuda":
    fadd.imported_modules[0].save(temp.relpath("myadd.ptx"))
cc.create_shared(temp.relpath("myadd.so"), [temp.relpath("myadd.o")])
print(temp.listdir())

######################################################################
# .. note:: Module Storage Format
#
#   The CPU (host) module is directly saved as a shared library (.so).
#   There can be multiple customized formats for the device code.
#   In our example, the device code is stored in ptx, as well as a
#   metadata json file. They can be loaded and linked separately via import.
#

######################################################################
# Load Compiled Module
# --------------------
# We can load the compiled module from the file system and run the code.
# The following code loads the host and device modules separately and
# re-links them together. We can verify that the newly loaded function works.
#
fadd1 = tvm.module.load(temp.relpath("myadd.so"))
if tgt == "cuda":
    fadd1_dev = tvm.module.load(temp.relpath("myadd.ptx"))
    fadd1.import_module(fadd1_dev)
fadd1(a, b, c)
np.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy())

######################################################################
# Pack Everything into One Library
# --------------------------------
# In the above example, we store the device and host code separately.
# TVM also supports exporting everything as one shared library.
# Under the hood, we pack the device modules into binary blobs and link
# them together with the host code.
# Currently we support packing of Metal, OpenCL and CUDA modules.
#
fadd.export_library(temp.relpath("myadd_pack.so"))
fadd2 = tvm.module.load(temp.relpath("myadd_pack.so"))
fadd2(a, b, c)
np.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy())

######################################################################
# .. note:: Runtime API and Thread-Safety
#
#   The compiled modules of TVM do not depend on the TVM compiler.
#   Instead, they only depend on a minimal runtime library.
#   The TVM runtime library wraps the device drivers and provides
#   thread-safe and device agnostic calls into the compiled functions.
#
#   This means you can call the compiled TVM function from any thread,
#   on any GPU.
#
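
######################################################################
# As a minimal sketch of this property (reusing the arrays a and b, the
# context ctx, and the compiled function fadd from above), we can invoke the
# same compiled function from several Python threads and check that every
# result is correct. The helper _run_fadd_from_thread is hypothetical and
# introduced only for this sketch.
import threading

def _run_fadd_from_thread():
    # Each thread writes into its own output buffer and verifies the result.
    out = tvm.nd.array(np.zeros(n, dtype=C.dtype), ctx)
    fadd(a, b, out)
    np.testing.assert_allclose(out.asnumpy(), a.asnumpy() + b.asnumpy())

threads = [threading.Thread(target=_run_fadd_from_thread) for _ in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()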

######################################################################
# Generate OpenCL Code
# --------------------
# TVM provides code generation features for multiple backends:
# we can also generate OpenCL code, or LLVM code that runs on CPU backends.
#
# The following code block generates OpenCL code, creates arrays on the
# OpenCL device, and verifies the correctness of the code.
#
if tgt == "opencl":
    fadd_cl = tvm.build(s, [A, B, C], "opencl", name="myadd")
    print("------opencl code------")
    print(fadd_cl.imported_modules[0].get_source())
    ctx = tvm.cl(0)
    n = 1024
    a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), ctx)
    b = tvm.nd.array(np.random.uniform(size=n).astype(B.dtype), ctx)
    c = tvm.nd.array(np.zeros(n, dtype=C.dtype), ctx)
    fadd_cl(a, b, c)
    np.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy())

######################################################################
# Summary
# -------
# This tutorial provides a walk-through of the TVM workflow using
# a vector add example. The general workflow is:
#
# - Describe your computation via a series of operations.
# - Describe how we want to compute it using schedule primitives.
# - Compile to the target function we want.
# - Optionally, save the function to be loaded later.
#
# You are more than welcome to check out other examples and
# tutorials to learn more about the supported operations, schedule primitives
# and other features in TVM.
#