# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

import tvm
from tvm.contrib import util
import numpy as np

def test_add_pipeline():
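    # Build the elementwise pipeline D = A + B + 1, where B is a 0-d scalar
    # tensor, lower it by hand, and check the host/device split end to end.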
    n = tvm.var('n')
    A = tvm.placeholder((n,), name='A')
    B = tvm.placeholder((), name='B')
    C = tvm.compute(A.shape, lambda *i: A(*i) + B(), name='C')
    D = tvm.compute(A.shape, lambda *i: C(*i) + 1, name='D')
    s = tvm.create_schedule(D.op)

    # A GPU schedule has to split each axis and bind it to blockIdx.x/threadIdx.x.
    num_thread = 256
    xo, xi = s[C].split(C.op.axis[0], factor=num_thread)
    s[C].bind(xi, tvm.thread_axis("threadIdx.x"))
    s[C].bind(xo, tvm.thread_axis("blockIdx.x"))

    xo, xi = s[D].split(D.op.axis[0], factor=num_thread)
    s[D].bind(xi, tvm.thread_axis("threadIdx.x"))
    s[D].bind(xo, tvm.thread_axis("blockIdx.x"))

    # Lower the schedule to IR step by step (mirroring what tvm.lower does).
    s = s.normalize()
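    # Infer loop bounds for every stage and emit the initial loop nest.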
    bounds = tvm.schedule.InferBound(s)
    stmt = tvm.schedule.ScheduleOps(s, bounds)
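    # Declare explicit buffers for the tensors that cross the function boundary.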
    Ab = tvm.decl_buffer(A.shape, A.dtype, name='A')
    Bb = tvm.decl_buffer(B.shape, B.dtype, name='B')
    Db = tvm.decl_buffer(D.shape, D.dtype, name='D')
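    # Partition loops so boundary conditions can be split out of the main body.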
    stmt = tvm.ir_pass.LoopPartition(stmt, False)
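    # Flatten multi-dimensional tensor accesses into 1-D buffer addressing.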
    stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, B: Bb, D: Db}, 64)
    stmt = tvm.ir_pass.Simplify(stmt)
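    # Wrap the statement into a packed-API function named "myadd".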
    fapi = tvm.ir_pass.MakeAPI(stmt, "myadd", [Ab, Bb, Db], 0, True)
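    # Split into a host function (fsplits[0]) followed by the device kernels.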
    fsplits = [x for x in tvm.ir_pass.SplitHostDevice(fapi)]
    # lower the floordiv (use stackvm rules so it works for all targets)
    fsplits = [tvm.ir_pass.LowerIntrin(x, "stackvm") for x in fsplits]
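    # Only the host side needs its TVM builtin calls lowered to runtime calls.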
    fsplits[0] = tvm.ir_pass.LowerTVMBuiltin(fsplits[0])

    def check_target(device, host="stackvm"):
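        # Skip silently if either the device or the host backend is unavailable.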
        ctx = tvm.context(device, 0)
        if not ctx.exist:
            return
        if not tvm.module.enabled(host):
            return
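        # Build host and device code as separate modules, then link the
        # device module into the host module so the entry point can launch it.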
        mhost = tvm.codegen.build_module(fsplits[0], host)
        mdev = tvm.codegen.build_module(fsplits[1:], device)
        mhost.import_module(mdev)
        code = mdev.get_source()  # fetch device source to exercise codegen
        f = mhost.entry_func
        # launch the kernel.
        n = 1027
        a = tvm.nd.array(np.random.uniform(size=n).astype(Ab.dtype), ctx)
        b = tvm.nd.array(np.random.uniform(size=()).astype(Bb.dtype), ctx)
        d = tvm.nd.array(np.zeros(n, dtype=Db.dtype), ctx)
        f(a, b, d)
        tvm.testing.assert_allclose(
            d.asnumpy(), a.asnumpy() + b.asnumpy() + 1)

    def check_module_save(device, host="stackvm"):
        ctx = tvm.context(device, 0)
        if not ctx.exist:
            return
        if not tvm.module.enabled(host):
            return
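        # Pick the on-disk format that matches the device binary: PTX for
        # CUDA, an HSA code object for ROCm, the device name otherwise.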
        if device == "cuda":
            fmt = "ptx"
        elif device == "rocm":
            fmt = "hsaco"
        else:
            fmt = device
        mhost = tvm.codegen.build_module(fsplits[0], host)
        mdev = tvm.codegen.build_module(fsplits[1:], device)
        temp = util.tempdir()
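        # Round-trip the device module through disk and re-link the
        # reloaded copy into the host module.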
        mpath = temp.relpath("test.%s" % fmt)
        mdev.save(mpath)
        mdev2 = tvm.module.load(mpath)
        mhost.import_module(mdev2)
        f = mhost.entry_func
        # launch the kernel.
        n = 1027
        a = tvm.nd.array(np.random.uniform(size=n).astype(Ab.dtype), ctx)
        b = tvm.nd.array(np.random.uniform(size=()).astype(Bb.dtype), ctx)
        d = tvm.nd.array(np.zeros(n, dtype=Db.dtype), ctx)
        f(a, b, d)
        tvm.testing.assert_allclose(
            d.asnumpy(), a.asnumpy() + b.asnumpy() + 1)

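    # Exercise each device/host pair; unavailable backends are skipped silently.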
    check_target("cuda", host="stackvm")
    check_target("cuda", host="llvm")
    check_module_save("cuda", host="stackvm")
    check_target("nvptx", host="llvm")
    check_target("vulkan", host="llvm")
    check_module_save("vulkan", host="stackvm")
    check_target("rocm", host="llvm")
    check_module_save("rocm", host="llvm")


if __name__ == "__main__":
    test_add_pipeline()