"""
Reduction
=========
**Author**: `Tianqi Chen <https://tqchen.github.io>`_

This is introductory material on how to do reduction in TVM.
Associative reduction operators like sum/max/min are typical
building blocks of linear algebra operations.

In this tutorial, we will demonstrate how to do reduction in TVM.
"""
from __future__ import absolute_import, print_function

import numpy as np
import tvm
import tvm.testing

######################################################################
# Describe Sum of Rows
# --------------------
# Assume we want to compute the sum of rows as our example.
# In numpy semantics this can be written as :code:`B = numpy.sum(A, axis=1)`
#
# The following lines describe the row sum operation.
# To create a reduction formula, we declare a reduction axis using
# :any:`tvm.reduce_axis`. :any:`tvm.reduce_axis` takes in the range of reductions.
# :any:`tvm.sum` takes in the expression to be reduced as well as the reduction
# axis, and computes the sum of the values over all k in the declared range.
#
# The equivalent C code is as follows:
#
# .. code-block:: c
#
#   for (int i = 0; i < n; ++i) {
#     B[i] = 0;
#     for (int k = 0; k < m; ++k) {
#       B[i] = B[i] + A[i][k];
#     }
#   }
#
n = tvm.var("n")
m = tvm.var("m")
A = tvm.placeholder((n, m), name='A')
# k iterates over the columns of A; it is summed away, so B has shape (n,).
k = tvm.reduce_axis((0, m), name="k")
B = tvm.compute((n,), lambda i: tvm.sum(A[i, k], axis=k), name="B")

######################################################################
# Schedule the Reduction
# ----------------------
# There are several ways to schedule a reduction.
# Before doing anything, let us print out the IR code of the default schedule.
#
s = tvm.create_schedule(B.op)
print(tvm.lower(s, [A, B], simple_mode=True))

######################################################################
# You can find that the IR code is quite like the C code.
# The reduction axis is similar to a normal axis; it can be split.
#
# In the following code we split both the row axis of B as well as the
# reduction axis, using different factors. The result is a nested reduction.
#
ko, ki = s[B].split(B.op.reduce_axis[0], factor=16)
xo, xi = s[B].split(B.op.axis[0], factor=32)
print(tvm.lower(s, [A, B], simple_mode=True))

######################################################################
# If we are building a GPU kernel, we can bind the rows of B to GPU threads.
# The outer row axis maps to blocks and the inner row axis to threads.
s[B].bind(xo, tvm.thread_axis("blockIdx.x"))
s[B].bind(xi, tvm.thread_axis("threadIdx.x"))
print(tvm.lower(s, [A, B], simple_mode=True))

######################################################################
# Reduction Factoring and Parallelization
# ---------------------------------------
# One problem of building a reduction is that we cannot simply
# parallelize over the reduction axis. We need to divide the computation
# of the reduction, store the partial reduction results in a temporary
# array, and then do a reduction over that temporary array.
#
# The rfactor primitive performs such a rewrite of the computation.
# In the following schedule, the result of B is written to a temporary
# result B.rf. The factored dimension becomes the first dimension of B.rf.
#
# Start again from a fresh schedule so that the splits and thread
# bindings applied above do not carry over.
s = tvm.create_schedule(B.op)
ko, ki = s[B].split(B.op.reduce_axis[0], factor=16)
# Factor out the inner split axis ki; BF holds the partial results.
BF = s.rfactor(B, ki)
print(tvm.lower(s, [A, B], simple_mode=True))

######################################################################
# The scheduled operator of B also gets rewritten to be a sum over
# the first axis of the factored result B.rf.
#
# Note that we read the operator through the schedule (s[B].op), because
# rfactor replaced B's operator inside the schedule.
print(s[B].op.body)

######################################################################
# Cross Thread Reduction
# ----------------------
# We can now parallelize over the factored axis.
# Here the reduction axis of B is marked to be a thread.
# TVM allows a reduction axis to be marked as a thread if it is the only
# axis in the reduction and cross-thread reduction is possible on the device.
#
# This is indeed the case after the factoring.
# We can directly compute BF at the reduction axis as well.
# The final generated kernel will divide the rows by blockIdx.x and
# threadIdx.y, divide the columns by threadIdx.x, and finally do a
# cross-thread reduction over threadIdx.x.
#
xo, xi = s[B].split(s[B].op.axis[0], factor=32)
s[B].bind(xo, tvm.thread_axis("blockIdx.x"))
s[B].bind(xi, tvm.thread_axis("threadIdx.y"))
tx = tvm.thread_axis("threadIdx.x")
s[B].bind(s[B].op.reduce_axis[0], tx)
s[BF].compute_at(s[B], s[B].op.reduce_axis[0])
# Restrict the store of B to threadIdx.x == 0, so that only one thread
# along the reduction threads writes each output element.
s[B].set_store_predicate(tx.var.equal(0))
fcuda = tvm.build(s, [A, B], "cuda")
# Print the generated CUDA device source.
print(fcuda.imported_modules[0].get_source())

######################################################################
# Verify the correctness of the resulting kernel by comparing it to numpy.
#
nn = 128
ctx = tvm.gpu(0)
a = tvm.nd.array(np.random.uniform(size=(nn, nn)).astype(A.dtype), ctx)
b = tvm.nd.array(np.zeros(nn, dtype=B.dtype), ctx)
fcuda(a, b)
# The row sums computed on the GPU must match numpy's reference result.
tvm.testing.assert_allclose(
    b.asnumpy(), np.sum(a.asnumpy(), axis=1), rtol=1e-4)

######################################################################
# Describe Convolution via 2D Reduction
# -------------------------------------
# In TVM, we can describe convolution via 2D reduction in a simple way.
# Here is an example for 2D convolution with filter size = [3, 3] and
# strides = [1, 1].
#
n = tvm.var('n')
Input = tvm.placeholder((n, n), name='Input')
Filter = tvm.placeholder((3, 3), name='Filter')
di = tvm.reduce_axis((0, 3), name='di')
dj = tvm.reduce_axis((0, 3), name='dj')
# No padding is applied, so each spatial dimension shrinks by
# (filter size - 1) = 2. The reduction runs over both filter axes at once.
Output = tvm.compute(
    (n - 2, n - 2),
    lambda i, j: tvm.sum(Input[i + di, j + dj] * Filter[di, dj], axis=[di, dj]),
    name='Output')
s = tvm.create_schedule(Output.op)
print(tvm.lower(s, [Input, Filter, Output], simple_mode=True))

######################################################################
# .. _general-reduction:
#
# Define General Commutative Reduction Operation
# ----------------------------------------------
# Besides the built-in reduction operations like :any:`tvm.sum`,
# :any:`tvm.min` and :any:`tvm.max`, you can also define your own
# commutative reduction operation with :any:`tvm.comm_reducer`.
#

n = tvm.var('n')
m = tvm.var('m')
# The first lambda combines two values; the second builds the initial
# value (1, the multiplicative identity) for the given dtype t.
product = tvm.comm_reducer(lambda x, y: x*y,
                           lambda t: tvm.const(1, dtype=t), name="product")
A = tvm.placeholder((n, m), name='A')
k = tvm.reduce_axis((0, m), name='k')
B = tvm.compute((n,), lambda i: product(A[i, k], axis=k), name='B')
######################################################################
# .. note::
#
#   Sometimes we would like to perform a reduction that involves multiple
#   values, like :code:`argmax`, which can be done with tuple inputs.
#   See :ref:`reduction-with-tuple-inputs` for more detail.

######################################################################
# Summary
# -------
# This tutorial provides a walkthrough of reduction scheduling.
#
# - Describe a reduction with reduce_axis.
# - Use rfactor to factor out an axis if we need parallelism.
# - Define a new reduction operation with :any:`tvm.comm_reducer`.