# test_target_codegen_device.py
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
import tvm.testing
from tvm import te
from tvm.contrib import util

import numpy as np


def test_large_uint_imm():
    """Check that a uint64 immediate above INT64_MAX survives device codegen.

    Builds a kernel whose every output element is the uint64 constant
    ``(1 << 63) + 123`` (top bit set) plus the uint64 constant ``3``. It runs
    the kernel on CUDA and Vulkan and verifies all outputs. A target whose
    device is not present on this machine is skipped.
    """
    value = (1 << 63) + 123
    other = tvm.tir.const(3, "uint64")
    n = 12
    num_thread = 2

    A = te.compute((n,), lambda *i: tvm.tir.const(value, "uint64") + other, name='A')
    s = te.create_schedule(A.op)
    xo, xi = s[A].split(A.op.axis[0], factor=num_thread)
    s[A].bind(xi, te.thread_axis("threadIdx.x"))
    s[A].bind(xo, te.thread_axis("blockIdx.x"))

    def check_target(device):
        ctx = tvm.context(device, 0)
        if not ctx.exist:
            # Device not present on this machine; nothing to test.
            return
        f = tvm.build(s, [A], device)
        # launch the kernel.
        a = tvm.nd.empty((n, ), dtype=A.dtype, ctx=ctx)
        f(a)
        # Compare every element, not just the first. Converting each uint64
        # to a Python int keeps the comparison exact above INT64_MAX.
        expected = value + 3
        assert all(int(x) == expected for x in a.asnumpy())

    check_target("cuda")
    check_target("vulkan")


def test_add_pipeline():
    """Lower an add pipeline by hand, split host/device code, and run it.

    The test computes ``D = A + B + 1``, where ``B`` is a 0-d tensor, through
    an intermediate stage ``C``. It lowers the schedule manually with the
    ``tvm.tir.ir_pass`` passes and splits the result with SplitHostDevice.
    The first function is built for the host; the remaining ones are built
    for the device. Results are checked for several device/host
    combinations, both directly and after a save/load round trip of the
    device module. A combination is skipped when its device or host runtime
    is unavailable.
    """
    n = te.size_var('n')
    A = te.placeholder((n,), name='A')
    B = te.placeholder((), name='B')
    C = te.compute(A.shape, lambda *i: A(*i) + B(), name='C')
    D = te.compute(A.shape, lambda *i: C(*i) + 1, name='D')
    s = te.create_schedule(D.op)

    # GPU schedules have to split by blockIdx and threadIdx; both stages
    # get the same schedule.
    num_thread = 256
    for stage in (C, D):
        xo, xi = s[stage].split(stage.op.axis[0], factor=num_thread)
        s[stage].bind(xi, te.thread_axis("threadIdx.x"))
        s[stage].bind(xo, te.thread_axis("blockIdx.x"))

    # compile to IR
    s = s.normalize()
    bounds = tvm.te.schedule.InferBound(s)
    stmt = tvm.te.schedule.ScheduleOps(s, bounds)
    Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name='A')
    Bb = tvm.tir.decl_buffer(B.shape, B.dtype, name='B')
    Db = tvm.tir.decl_buffer(D.shape, D.dtype, name='D')
    stmt = tvm.tir.ir_pass.LoopPartition(stmt, False)
    stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: Ab, B: Bb, D: Db}, 64)
    stmt = tvm.tir.ir_pass.Simplify(stmt)
    fapi = tvm.tir.ir_pass.MakeAPI(stmt, "myadd", [Ab, Bb, Db], 0, True)
    # fsplits[0] is built for the host below; fsplits[1:] for the device.
    fsplits = list(tvm.tir.ir_pass.SplitHostDevice(fapi))
    # Lower intrinsics such as floordiv with the stackvm rules
    # (so it works for all targets).
    fsplits = [tvm.tir.ir_pass.LowerIntrin(x, "stackvm") for x in fsplits]
    fsplits[0] = tvm.tir.ir_pass.LowerTVMBuiltin(fsplits[0])

    def verify_add(mhost, ctx):
        """Run the host entry on random inputs and check D == A + B + 1."""
        f = mhost.entry_func
        # launch the kernel.
        n = 1027
        a = tvm.nd.array(np.random.uniform(size=n).astype(Ab.dtype), ctx)
        b = tvm.nd.array(np.random.uniform(size=()).astype(Bb.dtype), ctx)
        d = tvm.nd.array(np.zeros(n, dtype=Db.dtype), ctx)
        f(a, b, d)
        tvm.testing.assert_allclose(
            d.asnumpy(), a.asnumpy() + b.asnumpy() + 1)

    def check_target(device, host="stackvm"):
        """Build host + device modules in memory and verify the result."""
        ctx = tvm.context(device, 0)
        if not ctx.exist or not tvm.runtime.enabled(host):
            return
        mhost = tvm.target.codegen.build_module(fsplits[0], host)
        mdev = tvm.target.codegen.build_module(fsplits[1:], device)
        mhost.import_module(mdev)
        # NOTE(review): the returned source was never used; the call is kept
        # in case it is meant as a smoke check that source is retrievable.
        mdev.get_source()
        verify_add(mhost, ctx)

    def check_module_save(device, host="stackvm"):
        """Save the device module to disk, reload it, and verify the result."""
        ctx = tvm.context(device, 0)
        if not ctx.exist or not tvm.runtime.enabled(host):
            return
        # File extension for the saved device module; defaults to the
        # device name when no special format is listed.
        fmt = {"cuda": "ptx", "rocm": "hsaco"}.get(device, device)
        mhost = tvm.target.codegen.build_module(fsplits[0], host)
        mdev = tvm.target.codegen.build_module(fsplits[1:], device)
        temp = util.tempdir()
        mpath = temp.relpath("test.%s" % fmt)
        mdev.save(mpath)
        mdev2 = tvm.runtime.load_module(mpath)
        mhost.import_module(mdev2)
        verify_add(mhost, ctx)

    check_target("cuda", host="stackvm")
    check_target("cuda", host="llvm")
    check_module_save("cuda", host="stackvm")
    check_target("nvptx", host="llvm")
    check_target("vulkan", host="llvm")
    check_module_save("vulkan", host="stackvm")
    check_target("rocm", host="llvm")
    check_module_save("rocm", host="llvm")


if __name__ == "__main__":
    # Run both device codegen tests when the file is executed directly.
    test_large_uint_imm()
    test_add_pipeline()