# test_codegen_llvm.py
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
import ctypes
import math

import numpy as np

import tvm
import topi
from tvm.contrib import util, clang


def test_llvm_intrin():
    """Build a function that emits a ``prefetch`` intrinsic call through LLVM.

    The call receives the address of ``A[0]`` followed by the constants
    ``0, 3, 1``. The test only checks that ``tvm.build`` accepts the
    resulting function.
    """
    ib = tvm.ir_builder.create()
    A = ib.pointer("float32", name="A")
    # NOTE(review): 0, 3, 1 presumably map to LLVM prefetch's
    # rw / locality / cache-type operands -- confirm.
    args = [
        tvm.call_pure_intrin("handle", "tvm_address_of", A[0]),
        0, 3, 1
    ]
    ib.emit(tvm.make.Evaluate(
        tvm.make.Call(
            "int32", "prefetch", args, tvm.expr.Call.Intrinsic, None, 0)))
    body = ib.get()
    func = tvm.ir_pass.MakeAPI(body, "prefetch", [A], 0, True)
    tvm.build(func, None, "llvm")


def test_llvm_import():
    """Check that clang-generated LLVM IR can be linked in via ``import_llvm``.

    A small C++ function ``my_add`` is compiled with clang into ``temp.ll``.
    It is attached to the schedule through the ``import_llvm`` pragma, once as
    the file path and once as the returned IR text. The kernel computes
    ``B[i] = my_add(A[i], 1.0)``. The test skips when LLVM or clang is
    unavailable.
    """
    # extern "C" is necessary to get the correct signature
    cc_code = """
    extern "C" float my_add(float x, float y) {
      return x + y;
    }
    """
    n = 10
    A = tvm.placeholder((n,), name='A')
    B = tvm.compute((n,), lambda *i:
                    tvm.call_pure_extern("float32", "my_add", A(*i), 1.0),
                    name='B')

    def check_llvm(use_file):
        if not tvm.module.enabled("llvm"):
            return
        if not clang.find_clang(required=False):
            print("skip because clang is not available")
            return
        temp = util.tempdir()
        ll_path = temp.relpath("temp.ll")
        ll_code = clang.create_llvm(cc_code, output=ll_path)
        s = tvm.create_schedule(B.op)
        # The pragma is exercised with both a file path and the IR source text.
        pragma_value = ll_path if use_file else ll_code
        s[B].pragma(s[B].op.axis[0], "import_llvm", pragma_value)
        # BUILD and invoke the kernel.
        f = tvm.build(s, [A, B], "llvm")
        ctx = tvm.cpu(0)
        # launch the kernel.
        a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), ctx)
        b = tvm.nd.array(np.random.uniform(size=n).astype(B.dtype), ctx)
        f(a, b)
        tvm.testing.assert_allclose(
            b.asnumpy(), a.asnumpy() + 1.0)

    check_llvm(use_file=True)
    check_llvm(use_file=False)


def test_llvm_lookup_intrin():
    """Build a call to an LLVM intrinsic referenced by name (``llvm.ctpop.i8``).

    The test only checks that ``tvm.build`` succeeds for the lowered function.
    """
    ib = tvm.ir_builder.create()
    A = ib.pointer("uint8x8", name="A")
    # NOTE(review): the uint32 constant 1 presumably tells call_llvm_intrin how
    # many of the following args form the intrinsic signature -- confirm.
    x = tvm.call_llvm_intrin("uint8x8", "llvm.ctpop.i8", tvm.const(1, 'uint32'), A)
    ib.emit(x)
    body = ib.get()
    func = tvm.ir_pass.MakeAPI(body, "ctpop", [A], 1, True)
    tvm.build(func, None, "llvm")


def test_llvm_add_pipeline():
    """Vector add using parallel-launch pragmas and a buffer with an element offset.

    ``C = A + B`` is computed through the intermediate stages ``AA``, ``BB``
    and ``T``. ``C``'s axis is split by 4 (inner part vectorized), and the
    outer part is split again by 13. The resulting loops carry the
    ``parallel_launch_point``, ``parallel_stride_pattern`` and
    ``parallel_barrier_when_finish`` pragmas. ``A`` is bound to a buffer
    whose element offset is a symbolic variable.
    """
    nn = 1024
    n = tvm.convert(nn)
    A = tvm.placeholder((n,), name='A')
    B = tvm.placeholder((n,), name='B')
    # NOTE(review): these copy stages reuse the placeholder names 'A'/'B'.
    AA = tvm.compute((n,), lambda *i: A(*i), name='A')
    BB = tvm.compute((n,), lambda *i: B(*i), name='B')
    T = tvm.compute(A.shape, lambda *i: AA(*i) + BB(*i), name='T')
    C = tvm.compute(A.shape, lambda *i: T(*i), name='C')
    s = tvm.create_schedule(C.op)
    xo, xi = s[C].split(C.op.axis[0], factor=4)
    xo1, xo2 = s[C].split(xo, factor=13)
    s[C].parallel(xo2)
    s[C].pragma(xo1, "parallel_launch_point")
    s[C].pragma(xo2, "parallel_stride_pattern")
    s[C].pragma(xo2, "parallel_barrier_when_finish")
    s[C].vectorize(xi)

    def check_llvm():
        if not tvm.module.enabled("llvm"):
            return
        # Specifically allow offset to test codepath when offset is available
        Ab = tvm.decl_buffer(
            A.shape, A.dtype,
            elem_offset=tvm.var('Aoffset'),
            offset_factor=8,
            name='A')
        binds = {A: Ab}
        # BUILD and invoke the kernel.
        f = tvm.build(s, [A, B, C], "llvm", binds=binds)
        ctx = tvm.cpu(0)
        # launch the kernel.
        a = tvm.nd.array(np.random.uniform(size=nn).astype(A.dtype), ctx)
        b = tvm.nd.array(np.random.uniform(size=nn).astype(B.dtype), ctx)
        c = tvm.nd.array(np.zeros(nn, dtype=C.dtype), ctx)
        f(a, b, c)
        tvm.testing.assert_allclose(
            c.asnumpy(), a.asnumpy() + b.asnumpy())

    with tvm.build_config(offset_factor=4):
        check_llvm()


def test_llvm_persist_parallel():
    """Parallel stage ``B`` nested inside ``C``'s parallel launch point.

    ``B = A + 1`` is computed at ``C``'s outer loop ``xo1`` and its own axis
    is parallel with a finish barrier. ``C = sqrt(B) * 2 + 2`` parallelizes
    its inner loop ``xi`` with the stride-pattern pragma, and ``xo1`` is
    marked as the parallel launch point.
    """
    n = 128
    A = tvm.placeholder((n,), name='A')
    B = tvm.compute(A.shape, lambda *i: A(*i) + 1, name='B')
    C = tvm.compute(A.shape, lambda *i: tvm.sqrt(B(*i)) * 2 + 2, name='C')
    s = tvm.create_schedule(C.op)
    xo, xi = s[C].split(C.op.axis[0], factor=8)
    xo1, xo2 = s[C].split(xo, nparts=1)
    s[B].compute_at(s[C], xo1)
    s[B].parallel(s[B].op.axis[0])
    s[B].pragma(s[B].op.axis[0], "parallel_barrier_when_finish")
    s[C].parallel(xi)
    s[C].pragma(xo1, "parallel_launch_point")
    s[C].pragma(xi, "parallel_stride_pattern")

    def check_llvm():
        if not tvm.module.enabled("llvm"):
            return
        # BUILD and invoke the kernel.
        f = tvm.build(s, [A, C], "llvm")
        ctx = tvm.cpu(0)
        # launch the kernel.
        a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), ctx)
        c = tvm.nd.array(np.zeros(n, dtype=C.dtype), ctx)
        f(a, c)
        tvm.testing.assert_allclose(c.asnumpy(),
                                    np.sqrt(a.asnumpy() + 1) * 2 + 2,
                                    rtol=1e-5)

    check_llvm()


def test_llvm_flip_pipeline():
    """Reverse the tail of an array with a parallel, vectorized schedule.

    For each ``(nn, base)`` case, ``A`` holds ``nn + base`` elements and
    ``C[i] = A[nn + base - i - 1]``. That is, ``C`` is the last ``nn``
    elements of ``A`` in reverse order.
    """
    def check_llvm(nn, base):
        if not tvm.module.enabled("llvm"):
            return
        n = tvm.convert(nn)
        A = tvm.placeholder((n + base,), name='A')
        C = tvm.compute((n,), lambda i: A(nn + base - i - 1), name='C')
        s = tvm.create_schedule(C.op)
        xo, xi = s[C].split(C.op.axis[0], factor=4)
        s[C].parallel(xo)
        s[C].vectorize(xi)
        # build and invoke the kernel.
        f = tvm.build(s, [A, C], "llvm")
        ctx = tvm.cpu(0)
        # launch the kernel.
        a = tvm.nd.array(np.random.uniform(size=(nn + base)).astype(A.dtype), ctx)
        c = tvm.nd.array(np.zeros(nn, dtype=C.dtype), ctx)
        f(a, c)
        tvm.testing.assert_allclose(
            c.asnumpy(), a.asnumpy()[::-1][:nn])

    check_llvm(4, 0)
    check_llvm(128, 8)
    check_llvm(3, 0)
    check_llvm(128, 1)


def test_llvm_vadd_pipeline():
    """Add 1 to an array whose element dtype is itself a float32 vector.

    ``A`` has ``n`` elements of dtype ``float32x<lanes>``. ``B`` copies ``A``
    and is computed at ``C``'s outer parallel loop, and ``C = B + 1``. Both
    stages vectorize an inner loop split by a factor of 2.
    """
    def check_llvm(n, lanes):
        if not tvm.module.enabled("llvm"):
            return
        A = tvm.placeholder((n,), name='A', dtype="float32x%d" % lanes)
        B = tvm.compute((n,), lambda i: A[i], name='B')
        C = tvm.compute((n,), lambda i: B[i] + tvm.const(1, A.dtype), name='C')
        s = tvm.create_schedule(C.op)
        xo, xi = s[C].split(C.op.axis[0], nparts=2)
        _, xi = s[C].split(xi, factor=2)
        s[C].parallel(xo)
        s[C].vectorize(xi)
        s[B].compute_at(s[C], xo)
        _, xi = s[B].split(B.op.axis[0], factor=2)
        s[B].vectorize(xi)
        # build and invoke the kernel.
        f = tvm.build(s, [A, C], "llvm")
        ctx = tvm.cpu(0)
        # launch the kernel. Each vector element is filled from one row of a
        # (n, lanes) NumPy array.
        a = tvm.nd.empty((n,), A.dtype, ctx).copyfrom(
            np.random.uniform(size=(n, lanes)))
        c = tvm.nd.empty((n,), C.dtype, ctx)
        f(a, c)
        tvm.testing.assert_allclose(
            c.asnumpy(), a.asnumpy() + 1)

    check_llvm(64, 2)
    check_llvm(512, 2)


def test_llvm_madd_pipeline():
    """Add 1 to a 2-D slice of ``A`` that starts at row ``base``.

    ``A`` has shape ``(nn + base, stride)`` and ``C[i, j] = A[base + i, j] + 1``
    for ``i < nn``. The last case is built with ``restricted_func=False``.
    """
    def check_llvm(nn, base, stride):
        if not tvm.module.enabled("llvm"):
            return
        n = tvm.convert(nn)
        A = tvm.placeholder((n + base, stride), name='A')
        C = tvm.compute((n, stride), lambda i, j: A(base + i, j) + 1, name='C')
        s = tvm.create_schedule(C.op)
        xo, xi = s[C].split(C.op.axis[0], factor=4)
        s[C].parallel(xo)
        s[C].vectorize(xi)
        # build and invoke the kernel.
        f = tvm.build(s, [A, C], "llvm")
        ctx = tvm.cpu(0)
        # launch the kernel.
        a = tvm.nd.array(np.random.uniform(size=(nn + base, stride)).astype(A.dtype), ctx)
        c = tvm.nd.array(np.zeros((nn, stride), dtype=C.dtype), ctx)
        f(a, c)
        tvm.testing.assert_allclose(
            c.asnumpy(), a.asnumpy()[base:] + 1)

    check_llvm(64, 0, 2)
    check_llvm(4, 0, 1)
    with tvm.build_config(restricted_func=False):
        check_llvm(4, 0, 3)


def test_llvm_temp_space():
    """Chain two elementwise stages: ``B = A + 1`` and then ``C = B + 1``.

    The default schedule is used, so ``B`` is an intermediate result between
    the two stages. This presumably exercises temporary-buffer allocation, as
    the test name suggests -- confirm.
    """
    nn = 1024
    n = tvm.convert(nn)
    A = tvm.placeholder((n,), name='A')
    B = tvm.compute(A.shape, lambda i: A(i) + 1, name='B')
    C = tvm.compute(A.shape, lambda i: B(i) + 1, name='C')
    s = tvm.create_schedule(C.op)

    def check_llvm():
        if not tvm.module.enabled("llvm"):
            return
        # build and invoke the kernel.
        f = tvm.build(s, [A, C], "llvm")
        ctx = tvm.cpu(0)
        # launch the kernel.
        a = tvm.nd.array(np.random.uniform(size=nn).astype(A.dtype), ctx)
        c = tvm.nd.array(np.zeros(nn, dtype=C.dtype), ctx)
        f(a, c)
        tvm.testing.assert_allclose(
            c.asnumpy(), a.asnumpy() + 1 + 1)

    check_llvm()


def test_multiple_func():
    """Build two named functions into one LLVM module and check both.

    ``fadd1`` and ``fadd2`` are lowered from the same ``C = A + B``
    schedule. Each is called on a freshly zeroed output array. That way the
    second check cannot pass on data left behind by the first kernel.
    """
    nn = 1024
    n = tvm.convert(nn)
    A = tvm.placeholder((n,), name='A')
    B = tvm.placeholder((n,), name='B')
    C = tvm.compute(A.shape, lambda *i: A(*i) + B(*i), name='C')
    s = tvm.create_schedule(C.op)
    xo, xi = s[C].split(C.op.axis[0], factor=4)
    s[C].parallel(xo)
    s[C].vectorize(xi)

    def check_llvm():
        if not tvm.module.enabled("llvm"):
            return
        # build two functions (passed to tvm.build as [fadd2, fadd1])
        lowered_fadd1 = tvm.lower(s, [A, B, C], name="fadd1")
        lowered_fadd2 = tvm.lower(s, [A, B, C], name="fadd2")
        m = tvm.build([lowered_fadd2, lowered_fadd1], "llvm")
        ctx = tvm.cpu(0)
        # launch the kernel.
        a = tvm.nd.array(np.random.uniform(size=nn).astype(A.dtype), ctx)
        b = tvm.nd.array(np.random.uniform(size=nn).astype(B.dtype), ctx)
        expected = a.asnumpy() + b.asnumpy()
        for func_name in ("fadd1", "fadd2"):
            # Fresh zeroed output per call; reusing ``c`` would let a kernel
            # that never writes its output pass on the previous result.
            c = tvm.nd.array(np.zeros(nn, dtype=C.dtype), ctx)
            m[func_name](a, b, c)
            tvm.testing.assert_allclose(c.asnumpy(), expected)

    check_llvm()


def test_llvm_condition():
    """Check ``if_then_else`` codegen: ``C[i] = A[i] if i >= offset else 0.0``."""
    def check_llvm(n, offset):
        if not tvm.module.enabled("llvm"):
            return
        A = tvm.placeholder((n, ), name='A')
        C = tvm.compute((n,), lambda i: tvm.if_then_else(i >= offset, A[i], 0.0), name='C')
        s = tvm.create_schedule(C.op)
        # build and invoke the kernel.
        f = tvm.build(s, [A, C], "llvm")
        ctx = tvm.cpu(0)
        # launch the kernel.
        a = tvm.nd.array(np.random.uniform(size=(n,)).astype(A.dtype), ctx)
        c = tvm.nd.empty((n,), C.dtype, ctx)
        f(a, c)
        # Reference: copy of A with the first ``offset`` entries zeroed.
        c_np = a.asnumpy()
        c_np[:offset] = 0
        tvm.testing.assert_allclose(c.asnumpy(), c_np)

    check_llvm(64, 8)


def test_llvm_bool():
    """Check a boolean comparison cast to float: ``C[i] = float(A[i] == 1)``.

    ``A`` is filled with random int32 values in {0, 1}.
    """
    def check_llvm(n):
        if not tvm.module.enabled("llvm"):
            return
        A = tvm.placeholder((n, ), name='A', dtype="int32")
        C = tvm.compute((n,), lambda i: A[i].equal(1).astype("float"), name='C')
        s = tvm.create_schedule(C.op)
        # build and invoke the kernel.
        f = tvm.build(s, [A, C], "llvm")
        ctx = tvm.cpu(0)
        # launch the kernel.
        a = tvm.nd.array(np.random.randint(0, 2, size=(n,)).astype(A.dtype), ctx)
        c = tvm.nd.empty((n,), C.dtype, ctx)
        f(a, c)
        c_np = (a.asnumpy() == 1).astype(C.dtype)
        tvm.testing.assert_allclose(c.asnumpy(), c_np)

    check_llvm(64)


def test_rank_zero():
    """Reduce into rank-0 tensors: ``C = sum_k A[k] * scale`` and ``D = C + 1``.

    ``scale`` is itself a rank-0 placeholder. ``A`` and ``scale`` are filled
    with random values in {0, 1}.
    """
    def check_llvm(n):
        if not tvm.module.enabled("llvm"):
            return
        A = tvm.placeholder((n, ), name='A')
        scale = tvm.placeholder((), name='scale')
        k = tvm.reduce_axis((0, n), name="k")
        C = tvm.compute((), lambda: tvm.sum(A[k] * scale(), axis=k), name="C")
        D = tvm.compute((), lambda: C() + 1)
        s = tvm.create_schedule(D.op)
        # build and invoke the kernel.
        f = tvm.build(s, [A, scale, D], "llvm")
        ctx = tvm.cpu(0)
        # launch the kernel.
        a = tvm.nd.array(np.random.randint(0, 2, size=(n,)).astype(A.dtype), ctx)
        sc = tvm.nd.array(
            np.random.randint(0, 2, size=()).astype(scale.dtype), ctx)
        d = tvm.nd.empty((), D.dtype, ctx)
        f(a, sc, d)
        d_np = np.sum(a.asnumpy()) * sc.asnumpy() + 1
        tvm.testing.assert_allclose(d.asnumpy(), d_np)

    check_llvm(64)


def test_rank_zero_bound_checkers():
    """Same rank-0 reduction as ``test_rank_zero``, with bound checkers enabled.

    Everything, including the build and the kernel launch, runs inside
    ``tvm.build_config(instrument_bound_checkers=True)``.
    """
    def check_llvm(n):
        if not tvm.module.enabled("llvm"):
            return
        with tvm.build_config(instrument_bound_checkers=True):
            A = tvm.placeholder((n, ), name='A')
            scale = tvm.placeholder((), name='scale')
            k = tvm.reduce_axis((0, n), name="k")
            C = tvm.compute((), lambda: tvm.sum(A[k] * scale(), axis=k), name="C")
            D = tvm.compute((), lambda: C() + 1)
            s = tvm.create_schedule(D.op)
            # build and invoke the kernel.
            f = tvm.build(s, [A, scale, D], "llvm")
            ctx = tvm.cpu(0)
            # launch the kernel.
            a = tvm.nd.array(np.random.randint(0, 2, size=(n,)).astype(A.dtype), ctx)
            sc = tvm.nd.array(
                np.random.randint(0, 2, size=()).astype(scale.dtype), ctx)
            d = tvm.nd.empty((), D.dtype, ctx)
            f(a, sc, d)
            d_np = np.sum(a.asnumpy()) * sc.asnumpy() + 1
            tvm.testing.assert_allclose(d.asnumpy(), d_np)

    check_llvm(64)


def test_alignment():
    """Check the alignment annotations in the generated LLVM IR.

    ``B = A * 3`` is split by 8 and the inner loop is vectorized. Every IR
    line that mentions both ``align`` and ``4 x float`` must carry
    ``align 32``.
    """
    n = tvm.convert(1024)
    A = tvm.placeholder((n,), name='A')
    B = tvm.compute(A.shape, lambda i: A[i] * 3, name='B')
    s = tvm.create_schedule(B.op)
    _, tx = s[B].split(B.op.axis[0], factor=8)
    s[B].vectorize(tx)
    f = tvm.build(s, [A, B], "llvm")

    # NOTE(review): the schedule vectorizes by 8, and the substring
    # "4 x float" does not match "<8 x float>". Confirm which vector width
    # this check is meant to cover.
    for line in f.get_source().split("\n"):
        if "align" in line and "4 x float" in line:
            assert "align 32" in line


def test_llvm_div():
    """Check that the semantics of div and mod is the same as in C/C++"""
    def check_op(start, end, divisor, dtype, op, ref_fn):
        # Evaluate ``op(cast(start + i), divisor)`` in ``dtype`` for every i in
        # [0, end - start), then compare with ``ref_fn(x, divisor)`` on Python
        # ints x in [start, end).
        T = tvm.compute((end - start,),
                        lambda i: op(tvm.expr.Cast(dtype, (start + i)), tvm.const(divisor, dtype)))
        s = tvm.create_schedule([T.op])
        f = tvm.build(s, [T], "llvm")
        a = tvm.nd.empty((end - start,), dtype)
        f(a)
        ref = [ref_fn(x, divisor) for x in range(start, end)]
        tvm.testing.assert_allclose(a.asnumpy(), ref)

    def check_div(start, end, divisor, dtype):
        # Reference truncates the quotient toward zero, as C/C++ does.
        check_op(start, end, divisor, dtype, tvm.div,
                 lambda x, d: int(float(x) / d))

    def check_mod(start, end, divisor, dtype):
        # Reference remainder takes the sign of the dividend (math.fmod).
        check_op(start, end, divisor, dtype, tvm.truncmod,
                 lambda x, d: int(math.fmod(x, d)))

    def check_llvm(start, end, divisor, dtype):
        check_div(start, end, divisor, dtype)
        check_mod(start, end, divisor, dtype)

    for d in range(-5, 6):
        if d == 0:
            continue
        # Note that 11 (and not e.g. 10) is used to avoid issues with the simplifier
        check_llvm(-11, 11, d, 'int32')
        check_llvm(-11, 11, d, 'int8')
        if d > 0:
            check_llvm(123, 133, d, 'uint8')
            check_llvm(0, 256, d, 'uint8')


def test_llvm_fp_math():
    """Check float32 reciprocal and sigmoid at extreme inputs.

    ``1.0 / (1e37 * 100)`` and ``sigmoid(-1000)`` must both come out as
    zeros. Each is tested for array lengths 4, 8 and 16, with all
    reciprocal cases run before the sigmoid cases.
    """
    def check_llvm_reciprocal(n):
        A = tvm.placeholder((n,), name='A')
        B = tvm.compute((n,), lambda i: tvm.div(1.0, (1e+37 * A[i])), name='B')

        s = tvm.create_schedule(B.op)
        f = tvm.build(s, [A, B], "llvm")

        a = tvm.nd.array(np.full((n,), 100, 'float32'))
        b = tvm.nd.empty((n,), 'float32')
        f(a, b)
        tvm.testing.assert_allclose(b.asnumpy(), np.zeros((n,), 'float32'))

    def check_llvm_sigmoid(n):
        A = tvm.placeholder((n,), name='A')
        B = tvm.compute((n,), lambda i: tvm.sigmoid(A[i]), name='B')

        s = tvm.create_schedule(B.op)
        f = tvm.build(s, [A, B], "llvm")

        a = tvm.nd.array(np.full((n,), -1000, 'float32'))
        b = tvm.nd.empty((n,), 'float32')
        f(a, b)
        tvm.testing.assert_allclose(b.asnumpy(), np.zeros((n,), 'float32'))

    for n in (4, 8, 16):
        check_llvm_reciprocal(n)
    for n in (4, 8, 16):
        check_llvm_sigmoid(n)


def test_dwarf_debug_information():
    """Check that DWARF debug info names both functions of a two-function module.

    Runs only when LLVM is enabled and its major version is 5 or 6.
    ``check_llvm_object`` saves an object file and inspects it with each of
    ``dwarfdump``, ``gobjdump`` and (off Darwin) ``objdump`` found on PATH.
    ``check_llvm_ir`` inspects the textual LLVM IR for an aarch64 Linux target
    and an x86_64 Darwin target.
    """
    import re
    import shutil
    import subprocess
    import sys

    nn = 1024
    n = tvm.convert(nn)
    A = tvm.placeholder((n,), name='A')
    B = tvm.placeholder((n,), name='B')
    C = tvm.compute(A.shape, lambda *i: A(*i) + B(*i), name='C')
    s = tvm.create_schedule(C.op)
    xo, xi = s[C].split(C.op.axis[0], factor=4)
    s[C].parallel(xo)
    s[C].vectorize(xi)

    def debug_info_testable():
        # LLVM must be enabled and its major version must lie in [5, 6].
        if not tvm.module.enabled("llvm"):
            return False
        return 5 <= tvm.codegen.llvm_version_major() <= 6

    def lower_two_funcs():
        # Lower the same schedule under two names; build order is [fadd2, fadd1].
        fadd1 = tvm.lower(s, [A, B, C], name="fadd1")
        fadd2 = tvm.lower(s, [A, B, C], name="fadd2")
        return [fadd2, fadd1]

    def check_llvm_object():
        if not debug_info_testable():
            return
        m = tvm.build(lower_two_funcs(), "llvm")
        temp = util.tempdir()
        o_path = temp.relpath("temp.o")
        m.save(o_path)

        # str() of the bytes output is its repr. A tab therefore appears as a
        # backslash followed by 't', and newlines are escaped too, so the
        # whole dump is a single line for the patterns below.

        # Try the dwarfdump utility (OS X)
        if shutil.which("dwarfdump"):
            output = str(subprocess.check_output(["dwarfdump", o_path]))
            assert re.search(r'DW_AT_name\\t\("fadd1"\)', output)
            assert re.search(r'DW_AT_name\\t\("fadd2"\)', output)

        # Try gobjdump (OS X)
        if shutil.which("gobjdump"):
            output = str(subprocess.check_output(["gobjdump", "--dwarf", o_path]))
            assert re.search(r"DW_AT_name.*fadd1", output)
            assert re.search(r"DW_AT_name.*fadd2", output)

        # Try objdump (Linux) - Darwin objdump has different DWARF syntax.
        if shutil.which("objdump") and sys.platform != 'darwin':
            output = str(subprocess.check_output(["objdump", "--dwarf", o_path]))
            assert re.search(r"DW_AT_name.*fadd1", output)
            assert re.search(r"DW_AT_name.*fadd2", output)

    def check_llvm_ir():
        if not debug_info_testable():
            return
        funcs = lower_two_funcs()
        m = tvm.build(funcs, target="llvm -target=aarch64-linux-gnu")
        ll = m.get_source("ll")

        # On non-Darwin OS, don't explicitly specify DWARF version.
        assert not re.search(r'"Dwarf Version"', ll)
        assert re.search(r"llvm.dbg.value", ll)

        # Try Darwin, require DWARF-2
        m = tvm.build(funcs, target="llvm -target=x86_64-apple-darwin-macho")
        ll = m.get_source("ll")
        assert re.search(r'i32 4, !"Dwarf Version", i32 2', ll)
        assert re.search(r"llvm.dbg.value", ll)

    check_llvm_object()
    check_llvm_ir()


def test_llvm_shuffle():
    """Hand-vectorize ``c[x] = a[x] + b[7 - x]`` using a Shuffle to reverse ``b``.

    A custom lowering pass replaces the ``For`` loop with a single 8-lane
    store. The ``a`` load uses a ramp index, and the ``b`` load uses a Shuffle
    of that ramp with lanes in reverse order. With ``a = 1..8`` and
    ``b = 8..1``, the result equals ``2 * a``.
    """
    a = tvm.placeholder((8, ), 'int32')
    b = tvm.placeholder((8, ), 'int32')
    c = tvm.compute((8, ), lambda x: a[x] + b[7-x])
    sch = tvm.create_schedule(c.op)

    def my_vectorize(stmt):

        def vectorizer(op):
            # Assumes the loop body is the single Store of ``a[x] + b[7 - x]``
            # (an Add of two Loads) -- TODO confirm if lowering changes.
            store = op.body
            value = store.value
            idx = tvm.make.Ramp(tvm.const(0, 'int32'), tvm.const(1, 'int32'), 8)
            all_ones = tvm.const(1, 'int32x8')
            b_idx = tvm.make.Shuffle([idx], [tvm.const(i, 'int32') for i in range(7, -1, -1)])
            new_a = tvm.make.Load('int32x8', value.a.buffer_var, idx, all_ones)
            new_b = tvm.make.Load('int32x8', value.b.buffer_var, b_idx, all_ones)
            return tvm.make.Store(store.buffer_var, new_a + new_b, idx, all_ones)

        return tvm.ir_pass.IRTransform(stmt, None, vectorizer, ['For'])

    with tvm.build_config(add_lower_pass=[(1, my_vectorize)]):
        # Lowering in simple mode exercises the custom pass on its own.
        tvm.lower(sch, [a, b, c], simple_mode=True)
        module = tvm.build(sch, [a, b, c])
        a_ = tvm.ndarray.array(np.arange(1, 9, dtype='int32'))
        b_ = tvm.ndarray.array(np.arange(8, 0, -1, dtype='int32'))
        c_ = tvm.ndarray.array(np.zeros((8, ), dtype='int32'))
        module(a_, b_, c_)
        tvm.testing.assert_allclose(c_.asnumpy(), (a_.asnumpy() * 2).astype('int32'))


if __name__ == "__main__":
    # Run every test in this module when executed as a script.
    test_llvm_import()
    test_alignment()
    test_rank_zero()
    test_rank_zero_bound_checkers()
    test_llvm_bool()
    test_llvm_persist_parallel()
    test_llvm_condition()
    test_llvm_vadd_pipeline()
    test_llvm_add_pipeline()
    test_llvm_intrin()
    test_multiple_func()
    test_llvm_flip_pipeline()
    test_llvm_madd_pipeline()
    test_llvm_temp_space()
    test_llvm_lookup_intrin()
    test_llvm_div()
    test_llvm_fp_math()
    test_dwarf_debug_information()
    test_llvm_shuffle()