Commit 968c539b by Tianqi Chen, committed by GitHub

[DOC/PERF] Reduction Tutorial and GEMM (#96)

* [PERF] Add gemm

* [DOC] Reduction tutorial
parent bfa7c4a5
@@ -36,7 +36,7 @@ class CodeGenCUDA : public CodeGenC {
  private:
   // magic number to add pragma unroll to it.
   // used to generate code that is compact but still unrolls.
-  int max_auto_unroll_{1025};
+  int max_auto_unroll_{64};
   // Whether global barrier is needed.
   bool need_global_barrier_{false};
   // Global barrier state
......
@@ -112,12 +112,12 @@ class ThreadAllreduceBuilder : public IRMutator {
       IterVar iv(attr->node.node_);
       e.scope = runtime::ThreadScope::make(iv->thread_tag);
       e.iv = iv;
-      CHECK(arith::GetConstInt(attr->value, &(e.extent)))
-          << "Need constant extent for thread group";
       CHECK_LE(e.scope.rank, 1);
       CHECK_GE(e.scope.dim_index, 0)
           << "vthread do not work with cross thread reduction";
       if (e.scope.rank == 1) {
+        CHECK(arith::GetConstInt(attr->value, &(e.extent)))
+            << "Need constant extent for reduce set " << iv;
         if (reduce_set.count(iv->var.get())) {
           vred.push_back(e);
           ++nmatch;
......
@@ -172,6 +172,9 @@ void RebaseNonZeroMinLoop(const Schedule& sch) {
     IterVar rebased = IterVarNode::make(
         Range(), iv->var.copy_with_suffix(""), iv->iter_type);
     s->relations.push_back(RebaseNode::make(iv, rebased));
+    if (s->iter_var_attrs.count(iv)) {
+      s->iter_var_attrs.Set(rebased, s->iter_var_attrs.at(iv));
+    }
     leaf_vars->data[idx] = rebased.node_;
     rebase_map[iv] = rebased;
   }
......
import tvm
import os
from tvm.addon import nvcc_compiler
import numpy as np

TASK="gemm"
USE_MANUAL_CODE = False


@tvm.register_func
def tvm_callback_cuda_compile(code):
    ptx = nvcc_compiler.compile_source(code, target="ptx", options=["-arch=sm_52"])
    return ptx


def write_code(code, fname):
    with open(fname, "w") as f:
        f.write(code)


@tvm.register_func
def tvm_callback_cuda_postproc(code):
    if not os.path.exists("perf"):
        os.mkdir("perf")
    write_code(code, "perf/%s_generated.cu" % TASK)
    if USE_MANUAL_CODE:
        code = open("perf/%s_manual.cu" % TASK).read()
    return code


def test_gemm():
    # graph
    nn = 2048
    n = tvm.var('n')
    n = tvm.convert(nn)
    m, l = n, n
    A = tvm.placeholder((l, n), name='A')
    B = tvm.placeholder((l, m), name='B')
    k = tvm.reduce_axis((0, l), name='k')
    C = tvm.compute(
        (m, n),
        lambda ii, jj: tvm.sum(A[k, jj] * B[k, ii], axis=k),
        name='C')
    # schedule
    s = tvm.create_schedule(C.op)
    AA = s.cache_read(A, "shared", [C])
    BB = s.cache_read(B, "shared", [C])
    AL = s.cache_read(AA, "local", [C])
    BL = s.cache_read(BB, "local", [C])
    CC = s.cache_write(C, "local")

    scale = 8
    num_thread = 8
    block_factor = scale * num_thread
    block_x = tvm.thread_axis("blockIdx.x")
    thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x")
    block_y = tvm.thread_axis("blockIdx.y")
    thread_y = tvm.thread_axis((0, num_thread), "threadIdx.y")
    thread_xz = tvm.thread_axis((0, 2), "vthread", name="vx")
    thread_yz = tvm.thread_axis((0, 2), "vthread", name="vy")

    by, yi = s[C].split(C.op.axis[0], factor=block_factor)
    bx, xi = s[C].split(C.op.axis[1], factor=block_factor)
    s[C].bind(by, block_y)
    s[C].bind(bx, block_x)
    s[C].reorder(by, bx, yi, xi)

    tyz, yi = s[C].split(yi, nparts=2)
    ty, yi = s[C].split(yi, nparts=num_thread)
    txz, xi = s[C].split(xi, nparts=2)
    tx, xi = s[C].split(xi, nparts=num_thread)
    s[C].bind(tyz, thread_yz)
    s[C].bind(txz, thread_xz)
    s[C].bind(ty, thread_y)
    s[C].bind(tx, thread_x)
    s[C].reorder(tyz, txz, ty, tx, yi, xi)
    s[CC].compute_at(s[C], tx)

    yo, xo = CC.op.axis
    ko, ki = s[CC].split(k, factor=8)
    kt, ki = s[CC].split(ki, factor=1)
    s[CC].reorder(ko, kt, ki, yo, xo)
    s[AA].compute_at(s[CC], ko)
    s[BB].compute_at(s[CC], ko)
    s[AL].compute_at(s[CC], kt)
    s[BL].compute_at(s[CC], kt)
    # Schedule for A's shared memory load
    ty, xi = s[AA].split(s[AA].op.axis[0], nparts=num_thread)
    _, xi = s[AA].split(s[AA].op.axis[1], factor=num_thread * 4)
    tx, xi = s[AA].split(xi, nparts=num_thread)
    s[AA].bind(ty, thread_y)
    s[AA].bind(tx, thread_x)
    s[AA].vectorize(xi)
    # Schedule for B's shared memory load
    ty, xi = s[BB].split(s[BB].op.axis[0], nparts=num_thread)
    _, xi = s[BB].split(s[BB].op.axis[1], factor=num_thread * 4)
    tx, xi = s[BB].split(xi, nparts=num_thread)
    s[BB].bind(ty, thread_y)
    s[BB].bind(tx, thread_x)
    s[BB].vectorize(xi)
    max_auto_unroll_step = 8

    # correctness
    def check_device(device, host="stackvm"):
        if not tvm.codegen.enabled(host):
            return
        if not tvm.codegen.enabled(device):
            return
        f = tvm.build(s, [A, B, C], device, host,
                      max_auto_unroll_step=max_auto_unroll_step)
        ctx = tvm.gpu(0) if device == "cuda" else tvm.cl(0)
        # launch the kernel.
        n, m, l = nn, nn, nn
        a_np = np.random.uniform(size=(n, l)).astype(A.dtype)
        b_np = np.random.uniform(size=(m, l)).astype(B.dtype)
        a = tvm.nd.array(a_np, ctx)
        b = tvm.nd.array(b_np, ctx)
        c = tvm.nd.array(np.zeros((n, m), dtype=C.dtype), ctx)
        for i in range(2):
            f(a, b, c)
        np.testing.assert_allclose(
            c.asnumpy(), np.dot(b_np.T, a_np), rtol=1e-5)

    check_device("cuda")
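

# Rough wall-clock timing helper: a sketch for illustration only, not part of
# the original test. It assumes `func` is a kernel built by tvm.build and
# a, b, c are the corresponding tvm.nd.array arguments (for example the ones
# created inside check_device above), with c used purely as output.
def time_gemm(func, a, b, c, n, repeat=10):
    import time
    func(a, b, c)            # warm-up launch
    c.asnumpy()              # copying back to host also synchronizes the device
    tic = time.time()
    for _ in range(repeat):
        func(a, b, c)
    c.asnumpy()              # synchronize again so every launch is counted
    toc = time.time()
    sec = (toc - tic) / repeat
    # an n x n x n GEMM performs roughly 2 * n^3 floating point operations
    print("average %g s per run, ~%g GFLOPS" % (sec, 2.0 * n ** 3 / sec / 1e9))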
if __name__ == "__main__":
    test_gemm()
@@ -17,8 +17,8 @@ import numpy as np
 # Quick knobs
 TASK="rnn_matexp"
 USE_MANUAL_CODE = False
-PERSIST_KERNEL = False
-DETECT_GLOBAL_BARRIER = True
+PERSIST_KERNEL = True
+DETECT_GLOBAL_BARRIER = PERSIST_KERNEL
 SKIP_CHECK = False
 @tvm.register_func
@@ -93,6 +93,7 @@ def rnn_matexp():
     if PERSIST_KERNEL:
         s[WhhL].compute_at(s[s_scan], thread_x)
+        s[WhhL].unroll(WhhL.op.axis[0])
     else:
         s[WhhL].compute_at(s[CLF], CLF.op.axis[3])
......
"""
Reduction
=========
**Author**: `Tianqi Chen <https://tqchen.github.io>`_

This is an introductory material on how to do reduction in TVM.
Associative reduction operators like sum/max/min are typical
building blocks of linear algebra operations.
In this tutorial, we will demonstrate how to describe and schedule
a reduction in TVM.
"""
from __future__ import absolute_import, print_function
import tvm
import numpy as np
######################################################################
# Describe Sum of Rows
# --------------------
# Assume we want to compute the sum of the rows of a matrix as our example.
# In numpy semantics this can be written as :code:`B = numpy.sum(A, axis=1)`.
#
# The following lines describe the row sum operation.
# To create a reduction formula, we declare a reduction axis using
# :any:`tvm.reduce_axis`. :any:`tvm.reduce_axis` takes in the range of the reduction.
# :any:`tvm.sum` takes in the expression to be reduced as well as the reduction
# axis, and computes the sum of the values over all k in the declared range.
#
# The equivalent C code is as follows:
#
# .. code-block:: c
#
#   for (int i = 0; i < n; ++i) {
#     B[i] = 0;
#     for (int k = 0; k < m; ++k) {
#       B[i] = B[i] + A[i][k];
#     }
#   }
#
n = tvm.var("n")
m = tvm.var("m")
A = tvm.placeholder((n, m), name='A')
k = tvm.reduce_axis((0, m), "k")
B = tvm.compute((n,), lambda i: tvm.sum(A[i, k], axis=k), name="B")
######################################################################
# Schedule the Reduction
# ----------------------
# There are several ways to schedule a reduction.
# Before doing anything, let us print out the IR code of default schedule.
#
s = tvm.create_schedule(B.op)
print(tvm.lower(s, [A, B], with_api_wrapper=False))
######################################################################
# You can see that the IR code is quite similar to the C code.
# The reduction axis is similar to a normal axis; it can be split.
#
# In the following code we split both the row axis of B as well
# as the reduction axis, by different factors. The result is a nested reduction.
#
ko, ki = s[B].split(B.op.reduce_axis[0], factor=16)
xo, xi = s[B].split(B.op.axis[0], factor=32)
print(tvm.lower(s, [A, B], with_api_wrapper=False))
######################################################################
# If we are building a GPU kernel, we can bind the rows of B to GPU threads.
s[B.op].bind(xo, tvm.thread_axis("blockIdx.x"))
s[B.op].bind(xi, tvm.thread_axis("threadIdx.x"))
print(tvm.lower(s, [A, B], with_api_wrapper=False))
######################################################################
# Reduction Factoring and Parallelization
# ---------------------------------------
# One problem with building a reduction is that we cannot simply
# parallelize over the reduction axis. We need to divide the computation
# of the reduction, store the local reduction results in a temporary
# array, and then do a reduction over the temp array.
#
# The rfactor primitive does such a rewrite of the computation.
# In the following schedule, the result of B is written to a temporary
# result B.rf. The factored dimension becomes the first dimension of B.rf.
#
s = tvm.create_schedule(B.op)
ko, ki = s[B].split(B.op.reduce_axis[0], factor=16)
BF = s.rfactor(B, ki)
print(tvm.lower(s, [A, B], with_api_wrapper=False))
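
######################################################################
# As a sanity check on this idea, here is a plain numpy sketch of the same
# rewrite. It is not part of the TVM pipeline, and the concrete sizes
# (4 rows, 64 columns, factor 16) are chosen only for illustration:
# partial[ki, i] plays the role of B.rf, holding the sum over ko of
# a_np[i, ko * 16 + ki], and summing the partials recovers the row sum.
#
a_np = np.random.uniform(size=(4, 64)).astype("float32")
partial = a_np.reshape(4, 64 // 16, 16).sum(axis=1).T
np.testing.assert_allclose(partial.sum(axis=0), a_np.sum(axis=1), rtol=1e-5)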
######################################################################
# The scheduled operator of B also gets rewritten to be a sum over
# the first axis of the reduced result of B.rf.
#
print(s[B].op.body)
######################################################################
# Cross Thread Reduction
# ----------------------
# We can now parallelize over the factored axis.
# Here the reduction axis of B is marked to be a thread.
# TVM allows a reduction axis to be marked as a thread if it is the only
# axis in the reduction and cross-thread reduction is possible on the device.
#
# This is indeed the case after the factoring.
# We can directly compute BF at the reduction axis as well.
# The final generated kernel will divide the rows by blockIdx.x and threadIdx.y,
# the columns by threadIdx.x, and finally do a cross-thread reduction over threadIdx.x.
#
xo, xi = s[B].split(s[B].op.axis[0], factor=32)
s[B.op].bind(xo, tvm.thread_axis("blockIdx.x"))
s[B.op].bind(xi, tvm.thread_axis("threadIdx.y"))
s[B].bind(s[B].op.reduce_axis[0], tvm.thread_axis("threadIdx.x"))
s[BF].compute_at(s[B], s[B].op.reduce_axis[0])
fcuda = tvm.build(s, [A, B], "cuda")
print(fcuda.imported_modules[0].get_source())
######################################################################
# Verify the correctness of the result kernel by comparing it to numpy.
#
nn = 128
ctx = tvm.gpu(0)
a = tvm.nd.array(np.random.uniform(size=(nn, nn)).astype(A.dtype), ctx)
b = tvm.nd.array(np.zeros(nn, dtype=B.dtype), ctx)
fcuda(a, b)
np.testing.assert_allclose(
    b.asnumpy(), np.sum(a.asnumpy(), axis=1), rtol=1e-4)
######################################################################
# Summary
# -------
# This tutorial provides a walkthrough of reduction schedules.
#
# - Describe the reduction with reduce_axis.
# - Use rfactor to factor out an axis if we need parallelism.
#
# These steps are collected into the small recap sketch below.
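
######################################################################
# The recap below is a sketch that only reuses the primitives shown in this
# tutorial; the helper name `schedule_row_sum` is ours, not a TVM API. It
# rebuilds the cross-thread-reduction schedule for the row sum in one place.
#
def schedule_row_sum(A, B):
    s = tvm.create_schedule(B.op)
    # split the reduction axis and factor the inner part into B.rf
    ko, ki = s[B].split(B.op.reduce_axis[0], factor=16)
    BF = s.rfactor(B, ki)
    # bind the rows of B to blocks/threads and the remaining reduction
    # axis to threadIdx.x for the cross-thread reduction
    xo, xi = s[B].split(s[B].op.axis[0], factor=32)
    s[B].bind(xo, tvm.thread_axis("blockIdx.x"))
    s[B].bind(xi, tvm.thread_axis("threadIdx.y"))
    s[B].bind(s[B].op.reduce_axis[0], tvm.thread_axis("threadIdx.x"))
    s[BF].compute_at(s[B], s[B].op.reduce_axis[0])
    return s

fcuda_recap = tvm.build(schedule_row_sum(A, B), [A, B], "cuda")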