test_codegen_llvm.py 12.6 KB
Newer Older
1
import tvm
2
from tvm.contrib import util, clang
3
import numpy as np
4 5 6 7 8 9 10 11 12 13 14 15
import ctypes

def test_llvm_intrin():
    """Build a function whose body is a single ``prefetch`` intrinsic call.

    Only compilation for the "llvm" target is exercised; the resulting module
    is not executed.
    """
    # Skip when LLVM is not enabled, consistent with the other tests here.
    if not tvm.module.enabled("llvm"):
        return
    ib = tvm.ir_builder.create()
    A = ib.pointer("float32", name="A")
    # Intrinsic arguments: the address of A[0] followed by three int constants.
    args = [
        tvm.call_pure_intrin("handle", "tvm_address_of", A[0]),
        0, 3, 1
    ]
    ib.emit(tvm.make.Evaluate(
        tvm.make.Call(
            "int32", "prefetch", args, tvm.expr.Call.Intrinsic, None, 0)))
    body = ib.get()
    func = tvm.ir_pass.MakeAPI(body, "prefetch", [A], 0, True)
    # The build result is intentionally discarded; success is the test.
    tvm.build(func, None, "llvm")
20

21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54

def test_llvm_import():
    """Link a clang-compiled C helper into a kernel through ``import_llvm``.

    The pragma is exercised twice: once with the path of the generated
    ``temp.ll`` file and once with the value returned by
    ``clang.create_llvm``. Each variant is skipped when LLVM or clang is
    unavailable.
    """
    # extern "C" is necessary to get the correct signature
    cc_code = """
    extern "C" float my_add(float x, float y) {
      return x + y;
    }
    """
    n = 10
    A = tvm.placeholder((n,), name='A')
    B = tvm.compute(
        (n,),
        lambda *idx: tvm.call_pure_extern("float32", "my_add", A(*idx), 1.0),
        name='B')

    def check_llvm(use_file):
        if not tvm.module.enabled("llvm"):
            return
        if not clang.find_clang(required=False):
            print("skip because clang is not available")
            return
        workdir = util.tempdir()
        ll_path = workdir.relpath("temp.ll")
        ll_code = clang.create_llvm(cc_code, output=ll_path)
        sch = tvm.create_schedule(B.op)
        # Hand the pragma either the file path or the create_llvm result.
        llvm_source = ll_path if use_file else ll_code
        sch[B].pragma(sch[B].op.axis[0], "import_llvm", llvm_source)
        # Compile the kernel, then run it on the CPU.
        kernel = tvm.build(sch, [A, B], "llvm")
        ctx = tvm.cpu(0)
        a_nd = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), ctx)
        b_nd = tvm.nd.array(np.random.uniform(size=n).astype(B.dtype), ctx)
        kernel(a_nd, b_nd)
        expected = a_nd.asnumpy() + 1.0
        tvm.testing.assert_allclose(b_nd.asnumpy(), expected)

    for use_file in (True, False):
        check_llvm(use_file=use_file)



62 63 64 65 66 67 68 69 70 71
def test_llvm_lookup_intrin():
    """Build a function that calls the ``llvm.ctpop.i8`` intrinsic by name.

    Only compilation for the "llvm" target is exercised; the resulting module
    is not executed.
    """
    # Skip when LLVM is not enabled, consistent with the other tests here.
    if not tvm.module.enabled("llvm"):
        return
    ib = tvm.ir_builder.create()
    A = ib.pointer("uint8x8", name="A")
    x = tvm.call_llvm_intrin("uint8x8", "llvm.ctpop.i8", tvm.const(1, 'uint32'), A)
    ib.emit(x)
    body = ib.get()
    func = tvm.ir_pass.MakeAPI(body, "ctpop", [A], 1, True)
    # The build result is intentionally discarded; success is the test.
    tvm.build(func, None, "llvm")

72
def test_llvm_add_pipeline():
    """Add two vectors through several compute stages and compare with numpy.

    The output stage is split twice, parallelised, vectorised and annotated
    with the parallel pragmas. Placeholder ``A`` is bound to a buffer that has
    a symbolic element offset, and the build runs under
    ``build_config(offset_factor=4)``.
    """
    nn = 1024
    n = tvm.convert(nn)
    A = tvm.placeholder((n,), name='A')
    B = tvm.placeholder((n,), name='B')
    # Copy stages reuse the placeholder names 'A' and 'B'.
    AA = tvm.compute((n,), lambda *idx: A(*idx), name='A')
    BB = tvm.compute((n,), lambda *idx: B(*idx), name='B')
    T = tvm.compute(A.shape, lambda *idx: AA(*idx) + BB(*idx), name='T')
    C = tvm.compute(A.shape, lambda *idx: T(*idx), name='C')

    s = tvm.create_schedule(C.op)
    outer, inner = s[C].split(C.op.axis[0], factor=4)
    launch_axis, par_axis = s[C].split(outer, factor=13)
    s[C].parallel(par_axis)
    s[C].pragma(launch_axis, "parallel_launch_point")
    s[C].pragma(par_axis, "parallel_stride_pattern")
    s[C].pragma(par_axis, "parallel_barrier_when_finish")
    s[C].vectorize(inner)

    def check_llvm():
        if not tvm.module.enabled("llvm"):
            return
        # Specifically allow offset to test codepath when offset is available
        a_buffer = tvm.decl_buffer(
            A.shape, A.dtype,
            elem_offset=tvm.var('Aoffset'),
            offset_factor=8,
            name='A')
        # Compile with A bound to the offset-carrying buffer.
        kernel = tvm.build(s, [A, B, C], "llvm", binds={A: a_buffer})
        ctx = tvm.cpu(0)
        # Prepare inputs and a zeroed output, then run.
        a_nd = tvm.nd.array(np.random.uniform(size=nn).astype(A.dtype), ctx)
        b_nd = tvm.nd.array(np.random.uniform(size=nn).astype(B.dtype), ctx)
        c_nd = tvm.nd.array(np.zeros(nn, dtype=C.dtype), ctx)
        kernel(a_nd, b_nd, c_nd)
        expected = a_nd.asnumpy() + b_nd.asnumpy()
        tvm.testing.assert_allclose(c_nd.asnumpy(), expected)

    with tvm.build_config(offset_factor=4):
        check_llvm()
114 115


116 117 118 119
def test_llvm_persist_parallel():
    """Run a two-stage kernel in which both stages are marked parallel.

    Stage ``B`` is computed at the outer loop of ``C``; the result is compared
    against ``sqrt(a + 1) * 2 + 2`` computed with numpy.
    """
    n = 128
    A = tvm.placeholder((n,), name='A')
    B = tvm.compute(A.shape, lambda *idx: A(*idx) + 1, name='B')
    C = tvm.compute(A.shape, lambda *idx: tvm.sqrt(B(*idx)) * 2 + 2, name='C')

    s = tvm.create_schedule(C.op)
    outer, inner = s[C].split(C.op.axis[0], factor=8)
    launch_axis, _ = s[C].split(outer, nparts=1)
    s[B].compute_at(s[C], launch_axis)
    s[B].parallel(s[B].op.axis[0])
    s[B].pragma(s[B].op.axis[0], "parallel_barrier_when_finish")
    s[C].parallel(inner)
    s[C].pragma(launch_axis, "parallel_launch_point")
    s[C].pragma(inner, "parallel_stride_pattern")

    def check_llvm():
        if not tvm.module.enabled("llvm"):
            return
        # Compile, then run on the CPU.
        kernel = tvm.build(s, [A, C], "llvm")
        ctx = tvm.cpu(0)
        a_nd = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), ctx)
        c_nd = tvm.nd.array(np.zeros(n, dtype=C.dtype), ctx)
        kernel(a_nd, c_nd)
        expected = np.sqrt(a_nd.asnumpy() + 1) * 2 + 2
        tvm.testing.assert_allclose(c_nd.asnumpy(), expected, rtol=1e-5)

    check_llvm()


148 149
def test_llvm_flip_pipeline():
    """Read the input back-to-front and compare with a reversed numpy slice.

    ``check_llvm(nn, base)`` builds a kernel producing ``nn`` outputs from an
    input of ``nn + base`` elements, with a split, parallel and vectorised
    schedule.
    """
    def check_llvm(nn, base):
        if not tvm.module.enabled("llvm"):
            return
        n_expr = tvm.convert(nn)
        # The shape is passed as a bare expression rather than a tuple.
        A = tvm.placeholder(n_expr + base, name='A')
        C = tvm.compute((n_expr,), lambda i: A(nn + base - i - 1), name='C')
        s = tvm.create_schedule(C.op)
        outer, inner = s[C].split(C.op.axis[0], factor=4)
        s[C].parallel(outer)
        s[C].vectorize(inner)
        # Compile, then run on the CPU.
        kernel = tvm.build(s, [A, C], "llvm")
        ctx = tvm.cpu(0)
        a_nd = tvm.nd.array(
            np.random.uniform(size=(nn + base)).astype(A.dtype), ctx)
        c_nd = tvm.nd.array(np.zeros(nn, dtype=C.dtype), ctx)
        kernel(a_nd, c_nd)
        expected = a_nd.asnumpy()[::-1][:nn]
        tvm.testing.assert_allclose(c_nd.asnumpy(), expected)

    for nn, base in ((4, 0), (128, 8), (3, 0), (128, 1)):
        check_llvm(nn, base)


175 176 177 178 179 180 181 182
def test_llvm_vadd_pipeline():
    """Add one to every element of a vector-typed (``float32xN``) input.

    ``check_llvm(n, lanes)`` uses ``n`` elements of ``lanes`` lanes each and
    compares the kernel output with ``a + 1`` from numpy.
    """
    def check_llvm(n, lanes):
        if not tvm.module.enabled("llvm"):
            return
        A = tvm.placeholder((n,), name='A', dtype="float32x%d" % lanes)
        B = tvm.compute((n,), lambda i: A[i], name='B')
        C = tvm.compute((n,), lambda i: B[i] + tvm.const(1, A.dtype), name='C')
        s = tvm.create_schedule(C.op)

        # Schedule C: two outer parts run in parallel, innermost vectorised.
        c_outer, c_rest = s[C].split(C.op.axis[0], nparts=2)
        _, c_inner = s[C].split(c_rest, factor=2)
        s[C].parallel(c_outer)
        s[C].vectorize(c_inner)

        # Schedule B: computed at C's outer axis, inner half vectorised.
        s[B].compute_at(s[C], c_outer)
        _, b_inner = s[B].split(B.op.axis[0], factor=2)
        s[B].vectorize(b_inner)

        # Compile, then run on the CPU.
        kernel = tvm.build(s, [A, C], "llvm")
        ctx = tvm.cpu(0)
        host_data = np.random.uniform(size=(n, lanes))
        a_nd = tvm.nd.empty((n,), A.dtype).copyfrom(host_data)
        c_nd = tvm.nd.empty((n,), C.dtype, ctx)
        kernel(a_nd, c_nd)
        expected = a_nd.asnumpy() + 1
        tvm.testing.assert_allclose(c_nd.asnumpy(), expected)

    for n in (64, 512):
        check_llvm(n, 2)
202 203


204 205
def test_llvm_madd_pipeline():
    """Add one to a 2-D input, starting at row ``base``, and check with numpy.

    The final case runs under ``build_config(restricted_func=False)``.
    """
    def check_llvm(nn, base, stride):
        if not tvm.module.enabled("llvm"):
            return
        n_expr = tvm.convert(nn)
        A = tvm.placeholder((n_expr + base, stride), name='A')
        C = tvm.compute(
            (n_expr, stride), lambda i, j: A(base + i, j) + 1, name='C')
        s = tvm.create_schedule(C.op)
        outer, inner = s[C].split(C.op.axis[0], factor=4)
        s[C].parallel(outer)
        s[C].vectorize(inner)
        # Compile, then run on the CPU.
        kernel = tvm.build(s, [A, C], "llvm")
        ctx = tvm.cpu(0)
        a_nd = tvm.nd.array(
            np.random.uniform(size=(nn + base, stride)).astype(A.dtype), ctx)
        c_nd = tvm.nd.array(np.zeros((nn, stride), dtype=C.dtype), ctx)
        kernel(a_nd, c_nd)
        expected = a_nd.asnumpy()[base:] + 1
        tvm.testing.assert_allclose(c_nd.asnumpy(), expected)

    for case in ((64, 0, 2), (4, 0, 1)):
        check_llvm(*case)
    with tvm.build_config(restricted_func=False):
        check_llvm(4, 0, 3)
229

230

231 232 233 234 235 236
def test_llvm_temp_space():
    """Chain two ``+ 1`` stages with a default schedule and check ``a + 2``.

    ``B`` is an intermediate stage that is not passed to ``tvm.build``.
    """
    nn = 1024
    n = tvm.convert(nn)
    A = tvm.placeholder((n,), name='A')
    B = tvm.compute(A.shape, lambda i: A(i) + 1, name='B')
    C = tvm.compute(A.shape, lambda i: B(i) + 1, name='C')
    s = tvm.create_schedule(C.op)

    def check_llvm():
        if not tvm.module.enabled("llvm"):
            return
        # Compile, then run on the CPU.
        kernel = tvm.build(s, [A, C], "llvm")
        ctx = tvm.cpu(0)
        a_nd = tvm.nd.array(np.random.uniform(size=nn).astype(A.dtype), ctx)
        c_nd = tvm.nd.array(np.zeros(nn, dtype=C.dtype), ctx)
        kernel(a_nd, c_nd)
        expected = a_nd.asnumpy() + 1 + 1
        tvm.testing.assert_allclose(c_nd.asnumpy(), expected)

    check_llvm()
253

254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279
def test_multiple_func():
    """Build one module holding two lowered functions and run each of them.

    The same schedule is lowered twice under the names ``fadd1`` and
    ``fadd2``; both are looked up in the built module by name and checked
    against ``a + b`` from numpy.
    """
    nn = 1024
    n = tvm.convert(nn)
    A = tvm.placeholder((n,), name='A')
    B = tvm.placeholder((n,), name='B')
    C = tvm.compute(A.shape, lambda *i: A(*i) + B(*i), name='C')
    s = tvm.create_schedule(C.op)
    xo, xi = s[C].split(C.op.axis[0], factor=4)
    s[C].parallel(xo)
    s[C].vectorize(xi)

    def check_llvm():
        if not tvm.module.enabled("llvm"):
            return
        # Lower the schedule twice; each local is named after its function.
        lowered_fadd1 = tvm.lower(s, [A, B, C], name="fadd1")
        lowered_fadd2 = tvm.lower(s, [A, B, C], name="fadd2")
        # Keep the original list order (fadd2 first) when building.
        m = tvm.build([lowered_fadd2, lowered_fadd1], "llvm")
        ctx = tvm.cpu(0)
        a = tvm.nd.array(np.random.uniform(size=nn).astype(A.dtype), ctx)
        b = tvm.nd.array(np.random.uniform(size=nn).astype(B.dtype), ctx)
        expected = a.asnumpy() + b.asnumpy()
        for fadd in (m['fadd1'], m['fadd2']):
            # Use a fresh zeroed output per function so a stale result from
            # the previous call cannot satisfy the assertion.
            c = tvm.nd.array(np.zeros(nn, dtype=C.dtype), ctx)
            fadd(a, b, c)
            tvm.testing.assert_allclose(c.asnumpy(), expected)

    check_llvm()


288

289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304
def test_llvm_select():
    """Check ``tvm.select``: elements below ``offset`` become 0, the rest copy A."""
    def check_llvm(n, offset):
        if not tvm.module.enabled("llvm"):
            return
        A = tvm.placeholder((n, ), name='A')
        C = tvm.compute(
            (n,), lambda i: tvm.select(i >= offset, A[i], 0.0), name='C')
        s = tvm.create_schedule(C.op)
        # Compile, then run on the CPU.
        kernel = tvm.build(s, [A, C], "llvm")
        ctx = tvm.cpu(0)
        a_nd = tvm.nd.array(
            np.random.uniform(size=(n,)).astype(A.dtype), ctx)
        c_nd = tvm.nd.empty((n,), A.dtype, ctx)
        kernel(a_nd, c_nd)
        # Reference result: the input with its first `offset` entries zeroed.
        expected = a_nd.asnumpy()
        expected[:offset] = 0
        tvm.testing.assert_allclose(c_nd.asnumpy(), expected)

    check_llvm(64, 8)


309 310 311 312 313 314 315 316 317 318 319 320 321 322 323
def test_llvm_bool():
    """Check that an equality comparison cast to float matches numpy's ``== 1``."""
    def check_llvm(n):
        if not tvm.module.enabled("llvm"):
            return
        A = tvm.placeholder((n, ), name='A', dtype="int32")
        C = tvm.compute(
            (n,), lambda i: A[i].equal(1).astype("float"), name='C')
        s = tvm.create_schedule(C.op)
        # Compile, then run on the CPU.
        kernel = tvm.build(s, [A, C], "llvm")
        ctx = tvm.cpu(0)
        # Input values are drawn from {0, 1}.
        a_nd = tvm.nd.array(
            np.random.randint(0, 2, size=(n,)).astype(A.dtype), ctx)
        c_nd = tvm.nd.empty((n,), C.dtype, ctx)
        kernel(a_nd, c_nd)
        expected = a_nd.asnumpy() == 1
        tvm.testing.assert_allclose(c_nd.asnumpy(), expected)

    check_llvm(64)


328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347
def test_rank_zero():
    """Exercise tensors with shape ``()`` as an input, an intermediate and an output.

    The kernel computes ``sum(A[k] * scale) + 1`` and is compared with the
    same expression evaluated by numpy.
    """
    def check_llvm(n):
        if not tvm.module.enabled("llvm"):
            return
        A = tvm.placeholder((n, ), name='A')
        scale = tvm.placeholder((), name='scale')
        k = tvm.reduce_axis((0, n), name="k")
        C = tvm.compute((), lambda: tvm.sum(A[k] * scale, axis=k), name="C")
        D = tvm.compute((), lambda: C + 1)
        s = tvm.create_schedule(D.op)
        # Compile, then run on the CPU.
        kernel = tvm.build(s, [A, scale, D], "llvm")
        ctx = tvm.cpu(0)
        a_nd = tvm.nd.array(
            np.random.randint(0, 2, size=(n,)).astype(A.dtype), ctx)
        scale_nd = tvm.nd.array(
            np.random.randint(0, 2, size=()).astype(scale.dtype), ctx)
        d_nd = tvm.nd.empty((), D.dtype, ctx)
        kernel(a_nd, scale_nd, d_nd)
        expected = np.sum(a_nd.asnumpy()) * scale_nd.asnumpy() + 1
        tvm.testing.assert_allclose(d_nd.asnumpy(), expected)

    check_llvm(64)


352 353 354 355 356 357 358 359 360 361 362 363 364 365
def test_alignment():
    """Inspect the generated source of a vectorised kernel for ``align 32``.

    Every line of ``f.get_source()`` that mentions both "align" and
    "4 x float" must also contain "align 32".
    """
    # Skip when LLVM is not enabled, consistent with the other tests here.
    if not tvm.module.enabled("llvm"):
        return
    n = tvm.convert(1024)
    A = tvm.placeholder((n,), name='A')
    B = tvm.compute(A.shape, lambda i: A[i] * 3, name='B')
    s = tvm.create_schedule(B.op)
    bx, tx = s[B].split(B.op.axis[0], factor=8)
    s[B].vectorize(tx)
    f = tvm.build(s, [A, B], "llvm")

    for line in f.get_source().split("\n"):
        if "align" in line and "4 x float" in line:
            assert "align 32" in line


366
if __name__ == "__main__":
    # Run the tests listed below, in this order, when executed as a script.
    for run_test in (
            test_llvm_import,
            test_alignment,
            test_rank_zero,
            test_llvm_bool,
            test_llvm_persist_parallel,
            test_llvm_select,
            test_llvm_vadd_pipeline,
            test_llvm_add_pipeline,
            test_llvm_intrin,
            test_multiple_func,
            test_llvm_flip_pipeline,
            test_llvm_madd_pipeline,
            test_llvm_temp_space,
            test_llvm_lookup_intrin,
    ):
        run_test()