# test_codegen_device.py
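# Tests device code generation: the schedule is lowered to IR by hand
# (roughly the pipeline tvm.build runs internally), split into host and
# device functions, compiled separately, linked, and run on GPU backends.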
import tvm
from tvm.contrib import util
import numpy as np

def test_add_pipeline():
    n = tvm.var('n')
    A = tvm.placeholder((n,), name='A')
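    # B is a 0-d (scalar) placeholder broadcast-added to every element of A;
    # D computes (A + B) + 1, so the expected result is a + b + 1.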
    B = tvm.placeholder((), name='B')
    C = tvm.compute(A.shape, lambda *i: A(*i) + B(), name='C')
    D = tvm.compute(A.shape, lambda *i: C(*i) + 1, name='D')
    s = tvm.create_schedule(D.op)

    # A GPU schedule has to split each axis and bind the pieces to
    # blockIdx and threadIdx.
    num_thread = 256
    xo, xi = s[C].split(C.op.axis[0], factor=num_thread)
    s[C].bind(xi, tvm.thread_axis("threadIdx.x"))
    s[C].bind(xo, tvm.thread_axis("blockIdx.x"))

    xo, xi = s[D].split(D.op.axis[0], factor=num_thread)
    s[D].bind(xi, tvm.thread_axis("threadIdx.x"))
    s[D].bind(xo, tvm.thread_axis("blockIdx.x"))

    # Lower the schedule to IR, pass by pass.
    s = s.normalize()
    bounds = tvm.schedule.InferBound(s)
    stmt = tvm.schedule.ScheduleOps(s, bounds)
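    # Declare explicit buffers for the function arguments so buffer accesses
    # can be flattened to 1-D addressing below.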
    Ab = tvm.decl_buffer(A.shape, A.dtype, name='A')
    Bb = tvm.decl_buffer(B.shape, B.dtype, name='B')
    Db = tvm.decl_buffer(D.shape, D.dtype, name='D')
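    # Partition loops to peel off boundary iterations where possible, flatten
    # multi-dimensional buffer accesses, and simplify the resulting IR.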
    stmt = tvm.ir_pass.LoopPartition(stmt, False)
    stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, B: Bb, D: Db}, 64)
    stmt = tvm.ir_pass.Simplify(stmt)
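    # Package the statement as a callable function "myadd"(A, B, D).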
    fapi = tvm.ir_pass.MakeAPI(stmt, "myadd", [Ab, Bb, Db], 0, True)
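    # Split host and device code: fsplits[0] is the host entry function, the
    # remaining entries are device kernels. Builtin intrinsics (such as
    # packed kernel launches) are then lowered in the host part.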
    fsplits = [x for x in tvm.ir_pass.SplitHostDevice(fapi)]
    fsplits[0] = tvm.ir_pass.LowerTVMBuiltin(fsplits[0])

    def check_target(device, host="stackvm"):
        ctx = tvm.context(device, 0)
        if not ctx.exist:
            return
        if not tvm.module.enabled(host):
            return
        mhost = tvm.codegen.build_module(fsplits[0], host)
        mdev = tvm.codegen.build_module(fsplits[1:], device)
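        # Link the device module into the host module so the host entry
        # function can launch its kernels.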
        mhost.import_module(mdev)
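        # Retrieve the generated device source; it is only exercised here,
        # not validated.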
        code = mdev.get_source()
        f = mhost.entry_func
        # launch the kernel.
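        # 1027 is not a multiple of num_thread, so the guard for the
        # partially-filled last thread block is also covered.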
        n = 1027
        a = tvm.nd.array(np.random.uniform(size=n).astype(Ab.dtype), ctx)
        b = tvm.nd.array(np.random.uniform(size=()).astype(Bb.dtype), ctx)
        d = tvm.nd.array(np.zeros(n, dtype=Db.dtype), ctx)
        f(a, b, d)
        np.testing.assert_allclose(
            d.asnumpy(), a.asnumpy() + b.asnumpy() + 1)

    def check_module_save(device, host="stackvm"):
        ctx = tvm.context(device, 0)
        if not ctx.exist:
            return
        if not tvm.module.enabled(host):
            return
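        # CUDA device modules are saved as PTX; other backends use the
        # device name as the file suffix.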
        fmt = "ptx" if device == "cuda" else device
        mhost = tvm.codegen.build_module(fsplits[0], host)
        mdev = tvm.codegen.build_module(fsplits[1:], device)
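        # Round-trip the device module through disk to check that a saved
        # module can be reloaded and linked.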
        temp = util.tempdir()
        mpath = temp.relpath("test.%s" % fmt)
        mdev.save(mpath)
        mdev2 = tvm.module.load(mpath)
        mhost.import_module(mdev2)
        f = mhost.entry_func
        # launch the kernel.
        n = 1027
        a = tvm.nd.array(np.random.uniform(size=n).astype(Ab.dtype), ctx)
        b = tvm.nd.array(np.random.uniform(size=()).astype(Bb.dtype), ctx)
        d = tvm.nd.array(np.zeros(n, dtype=Db.dtype), ctx)
        f(a, b, d)
        np.testing.assert_allclose(
            d.asnumpy(), a.asnumpy() + b.asnumpy() + 1)

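    # Each check silently skips when the device or the host backend is not
    # available in this build.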
    check_target("cuda", host="stackvm")
    check_target("cuda", host="llvm")
    check_module_save("cuda", host="stackvm")
    check_target("nvptx", host="llvm")
    check_target("vulkan", host="llvm")
    check_target("rocm", host="llvm")
    check_module_save("vulkan", host="stackvm")


if __name__ == "__main__":
    test_add_pipeline()