import tvm
import numpy as np
import tvm.testing


def test_gemm():
    # graph
    nn = 1024
    n = tvm.convert(nn)
    m = n
    l = n
    A = tvm.placeholder((n, l), name='A')
    B = tvm.placeholder((m, l), name='B')
    k = tvm.reduce_axis((0, l), name='k')
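    # C[ii, jj] = sum over k of A[ii, k] * B[jj, k], i.e. C = A @ B.T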
    C = tvm.compute(
        (n, m),
        lambda ii, jj: tvm.sum(A[ii, k] * B[jj, k], axis=k),
        name='C')
    # schedule
    s = tvm.create_schedule(C.op)
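    # Tiling plan: each thread block computes a block_factor x block_factor
    # (64x64) tile of C using an 8x8 grid of threads, so every thread is
    # responsible for a scale x scale (8x8) sub-tile.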
    scale = 8
    num_thread = 8
    block_factor = scale * num_thread
    block_x = tvm.thread_axis("blockIdx.x")
    thread_x = tvm.thread_axis("threadIdx.x")
    block_y = tvm.thread_axis("blockIdx.y")
    thread_y = tvm.thread_axis("threadIdx.y")

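    # Memory hierarchy: accumulate each thread's sub-tile in registers
    # ("local") and stage the A/B tiles through shared memory.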
    CC = s.cache_write(C, "local")
    AA = s.cache_read(A, "shared", [CC])
    BB = s.cache_read(B, "shared", [CC])
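    # Tile C over the GPU grid: outer tiles map to blockIdx.y/x, and each
    # block tile is then distributed over the threadIdx.y/x grid.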
    by, yi = s[C].split(C.op.axis[0], factor=block_factor)
    bx, xi = s[C].split(C.op.axis[1], factor=block_factor)
    s[C].reorder(by, bx, yi, xi)
    s[C].bind(by, block_y)
    s[C].bind(bx, block_x)
    ty, yi = s[C].split(yi, nparts=num_thread)
    tx, xi = s[C].split(xi, nparts=num_thread)
    s[C].reorder(ty, tx, yi, xi)
    s[C].bind(ty, thread_y)
    s[C].bind(tx, thread_x)
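    # Make the reduction loop outermost in the local stage so the shared
    # memory loads below can be attached to each iteration of k.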
    yo, xo = CC.op.axis
    s[CC].reorder(k, yo, xo)

    s[CC].compute_at(s[C], tx)
    s[AA].compute_at(s[CC], k)
    s[BB].compute_at(s[CC], k)
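    # Double buffering lets the fetch of the next k tile overlap with
    # computation on the current one.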
    s[AA].double_buffer()
    s[BB].double_buffer()
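    # Cooperative fetch: split the shared-tile rows twice by num_thread and
    # bind to the thread grid so all 64 threads take part in each load.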
    ty, xi = s[AA].split(s[AA].op.axis[0], nparts=num_thread)
    tx, xi = s[AA].split(xi, nparts=num_thread)
    s[AA].bind(ty, thread_y)
    s[AA].bind(tx, thread_x)

    ty, xi = s[BB].split(s[BB].op.axis[0], nparts=num_thread)
    tx, xi = s[BB].split(xi, nparts=num_thread)
    s[BB].bind(ty, thread_y)
    s[BB].bind(tx, thread_x)

    # lowering test
    s = s.normalize()

    # one line to build the function.
    def check_device(device):
        ctx = tvm.context(device, 0)
        if not ctx.exist:
            print("skip because %s is not enabled." % device)
            return

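        # build inside the target scope so target-specific lowering applies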
        with tvm.target.create(device):
            f = tvm.build(s, [A, B, C])

        # launch the kernel.
        n = nn
        m = n
        l = n
        a_np = np.random.uniform(size=(n, l)).astype(A.dtype)
        b_np = np.random.uniform(size=(m, l)).astype(B.dtype)
        a = tvm.nd.array(a_np, ctx)
        b = tvm.nd.array(b_np, ctx)
        c = tvm.nd.array(np.zeros((n, m), dtype=C.dtype), ctx)
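        # time_evaluator wraps f; number=1 means one run per measurement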
        ftimer = f.time_evaluator(f.entry_name, ctx, number=1)
        tcost = ftimer(a, b, c).mean
        print("%s: exec=%g sec/op" % (ctx, tcost))
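        # the compute defines C = A @ B.T, hence the transposed reference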
        tvm.testing.assert_allclose(
            c.asnumpy(), np.dot(a_np, b_np.T), rtol=1e-5)

    check_device("vulkan")
    check_device("nvptx -mcpu=sm_20")
    check_device("rocm")
    check_device("metal")
    check_device("opencl")
    check_device("cuda")

if __name__ == "__main__":
    test_gemm()