"""Matrix exponential example.

This is an example of matrix exponentiation,
which computes the following recurrence

```math
X[t] = dot(X[t-1], W)
```
"""
import tvm
import time
import os
import argparse
from tvm.contrib import nvcc
import numpy as np

# Quick knobs
TASK = "matexp"
USE_MANUAL_CODE = False  # if True, load perf/matexp_manual.cu instead of the generated kernel
PERSIST_KERNEL = True  # keep one kernel alive across all scan time steps
DETECT_GLOBAL_BARRIER = PERSIST_KERNEL  # a persistent kernel needs a global barrier between steps
SKIP_CHECK = False  # set True to skip the numpy correctness check

@tvm.register_func
def tvm_callback_cuda_compile(code):
    """Use nvcc compiler for better perf."""
    ptx = nvcc.compile_cuda(code, target="ptx")
    return ptx

def write_code(code, fname):
    with open(fname, "w") as f:
        f.write(code)

@tvm.register_func
def tvm_callback_cuda_postproc(code):
    if not os.path.exists("perf"):
        os.mkdir("perf")
    write_code(code, "perf/%s_generated.cu" % TASK)
    if USE_MANUAL_CODE:
        code = open("perf/%s_manual.cu" % TASK).read()
    return code

def rnn_matexp():
    n_num_step = 128
    n_num_hidden = 1152
    n_batch_size = 4
    detect_global_barrier = DETECT_GLOBAL_BARRIER

    num_step = tvm.var("num_step")
    num_hidden = tvm.convert(n_num_hidden)
    batch_size = tvm.convert(n_batch_size)
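    # Launch configuration; num_sm presumably matches the SM count of the
    # target GPU, and the thread counts are hand-tuned for this workload.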
    num_thread_y = 8
    num_thread_x = 16 * 3
    num_sm = 24

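    # Build the recurrence X[t] = X[t-1] * Whh with tvm.scan:
    # s_init seeds the state with ones, s_state is the state placeholder,
    # and s_update computes one time step as a matrix multiply.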
    Whh = tvm.placeholder((num_hidden, num_hidden), name="Whh")
    s_init = tvm.compute((1, batch_size, num_hidden),
                         lambda _, i, j: 1.0, name="init")
    s_state = tvm.placeholder((num_step, batch_size, num_hidden))
    kh = tvm.reduce_axis((0, num_hidden), name="kh")
    s_update = tvm.compute(
        (num_step, batch_size, num_hidden),
        lambda t, i, j: tvm.sum(s_state[t-1, i, kh] * Whh[kh, j], axis=kh),
        name="update")
    s_scan = tvm.scan(s_init, s_update, s_state)
    # schedule
    s = tvm.create_schedule(s_scan.op)
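    # Cache the previous state in shared memory (SS) and registers (SL),
    # and the weights in registers (WhhL). Split the reduction over the
    # hidden dimension and rfactor it so each thread_y produces a partial sum.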
    CL = s_update
    SS = s.cache_read(s_state, "shared", [CL])
    SL = s.cache_read(SS, "local", [CL])
    WhhL = s.cache_read(Whh, "local", [CL])
    ko, ki = s[CL].split(s[CL].op.reduce_axis[0], nparts=num_thread_y)
    CLF = s.rfactor(CL, ko)

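    # Thread organization: blocks over the SMs, threadIdx.x over output
    # columns, threadIdx.y over reduction partials. With PERSIST_KERNEL the
    # threads become environment threads of the whole scan, so a single
    # kernel launch iterates over all time steps.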
    block_x = tvm.thread_axis((0, num_sm), "blockIdx.x")
    thread_x = tvm.thread_axis((0, num_thread_x), "threadIdx.x")
    thread_y = tvm.thread_axis((0, num_thread_y), "threadIdx.y")
    if PERSIST_KERNEL:
        s[s_scan.op].env_threads([block_x, thread_y, thread_x])

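    # Spread the state initialization across blocks and threads along the
    # hidden dimension.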
    bx, xi = s[s_init].split(s_init.op.axis[2], nparts=num_sm)
    tx, xi = s[s_init].split(xi, nparts=num_thread_x)
    s[s_init].bind(bx, block_x)
    s[s_init].bind(tx, thread_x)

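    # Bind the update step: hidden dimension over blocks/threads, the outer
    # reduction over thread_y, and attach the rfactored partial sums there.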
    bx, xi = s[s_update].split(s[CL].op.axis[2], nparts=num_sm)
    tx, xi = s[s_update].split(xi, nparts=num_thread_x)
    s[s_update].bind(bx, block_x)
    s[s_update].bind(tx, thread_x)
    s[CL].bind(s[CL].op.reduce_axis[0], thread_y)
    s[CLF].compute_at(s[CL], s[CL].op.reduce_axis[0])
    # Duplicate store predicate: after the cross-thread reduction every
    # thread_y holds the result, so only thread_y == 0 stores it back.
    s[CL].set_store_predicate(thread_y.equal(0))

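    # With a persistent kernel the weights are loaded into registers once
    # per thread, outside the scan loop; otherwise they are re-read inside
    # the factored reduction.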
    if PERSIST_KERNEL:
        s[WhhL].compute_at(s[s_scan], thread_x)
        s[WhhL].unroll(WhhL.op.axis[0])
    else:
        s[WhhL].compute_at(s[CLF], CLF.op.axis[3])

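    # Split the factored reduction: the shared-memory tile (SS) is loaded
    # once per outer iteration kr, the register copy (SL) every 4 steps of ko.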
    kr, ki = s[CLF].split(CLF.op.reduce_axis[0], nparts=1)
    ko, ki = s[CLF].split(ki, factor=4)
    s[SS].compute_at(s[CLF], kr)
    s[SL].compute_at(s[CLF], ko)

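    # Cooperatively load the shared-memory tile: all thread_y x thread_x
    # threads stride over the hidden dimension together.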
    xo, xi = s[SS].split(SS.op.axis[2], factor=num_thread_x * num_thread_y * 3)
    ty, xi = s[SS].split(xi, nparts=num_thread_y)
    tx, xi = s[SS].split(xi, nparts=num_thread_x)
    s[SS].bind(ty, thread_y)
    s[SS].bind(tx, thread_x)

    def check_device(target):
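        # Build with global-barrier detection (needed when the persistent
        # kernel synchronizes between time steps) and automatic unrolling
        # of small loops.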
        with tvm.build_config(
                detect_global_barrier=detect_global_barrier,
                auto_unroll_max_step=128,
                unroll_explicit=False):
            f = tvm.build(s, [s_scan, Whh], target)
        ctx = tvm.gpu(0) if target == "cuda" else tvm.cl(0)
        # Prepare inputs on the device and launch the kernel.
        res_np = np.zeros(
            (n_num_step, n_batch_size, n_num_hidden)).astype("float32")
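        # Weights: left half of the columns set to 2/n, right half to zero,
        # which keeps the recurrence numerically bounded for the check.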
        Whh_np = np.zeros((n_num_hidden, n_num_hidden)).astype("float32")
        Whh_np[:] = 2.0 / n_num_hidden
        Whh_np[:, n_num_hidden//2:] = 0

        res_a = tvm.nd.array(res_np, ctx)
        Whh_a = tvm.nd.array(Whh_np, ctx)
        # Skip the first pass as it includes compilation/warm-up overhead.
        f(res_a, Whh_a)
        ctx.sync()
        # measure time cost of second step.
        tstart = time.time()
        f(res_a, Whh_a)
        ctx.sync()
        tgap = time.time() - tstart
        print("Time cost=%g" % tgap)
        # Correctness check against a float64 numpy reference.
        if not SKIP_CHECK:
            res_gpu = res_a.asnumpy()
            res_cmp = np.ones_like(res_np).astype("float64")
            Whh_np = Whh_np.astype("float64")
            for t in range(1, n_num_step):
                res_cmp[t][:] = np.dot(res_cmp[t - 1], Whh_np)
            for i in range(n_num_step):
                for j in range(n_num_hidden):
                    if abs(res_cmp[i, 0, j] - res_gpu[i, 0, j]) > 1e-5:
                        print("%d, %d: %g vs %g" % (i, j, res_cmp[i, 0, j], res_gpu[i, 0, j]))
            np.testing.assert_allclose(res_gpu, res_cmp, rtol=1e-3)
    check_device("cuda")

if __name__ == "__main__":
    rnn_matexp()