# test_codegen_cross_llvm.py
"""Test cross compilation"""
import tvm
import os
import struct
from tvm.contrib import util, cc, rpc
import numpy as np

def test_llvm_add_pipeline():
    """Cross-compile a vector add with LLVM and check the emitted object files.

    A 1024-element ``C = A + B`` kernel is scheduled (split by 4, outer axis
    parallel, inner axis vectorized) and built for an i386 and an ARMv7 Linux
    target. Each saved object file must be an ELF file carrying the expected
    ``e_machine`` value. When ``TVM_RPC_ARM_HOST`` / ``TVM_RPC_ARM_PORT`` point
    at a reachable RPC server, the ARM object is additionally uploaded, run
    remotely and compared against numpy.
    """
    nn = 1024
    n = tvm.convert(nn)
    A = tvm.placeholder((n,), name='A')
    B = tvm.placeholder((n,), name='B')
    C = tvm.compute(A.shape, lambda *i: A(*i) + B(*i), name='C')
    s = tvm.create_schedule(C.op)
    xo, xi = s[C].split(C.op.axis[0], factor=4)
    s[C].parallel(xo)
    s[C].vectorize(xi)

    def verify_elf(path, e_machine):
        """Assert that ``path`` is an ELF file whose e_machine equals ``e_machine``.

        Only the first 20 bytes are read: the e_ident array plus the e_type
        and e_machine fields (e_machine lives at offset 0x12).
        """
        with open(path, "rb") as fi:
            header = fi.read(20)
        assert len(header) == 20, \
            "%s is too short to contain an ELF header" % path
        # Full magic is 0x7f 'E' 'L' 'F'; compare as bytes so the check
        # behaves the same on Python 2 and Python 3.
        assert header[:4] == b"\x7fELF", \
            "%s does not start with the ELF magic" % path
        # e_ident[EI_DATA]: 1 means little-endian, anything else is read as
        # big-endian (same rule as before).
        ei_data = struct.unpack('B', header[0x5:0x6])[0]
        endian = '<' if ei_data == 1 else '>'
        # e_machine is an unsigned 16-bit field, hence 'H' rather than 'h'.
        found = struct.unpack(endian + 'H', header[0x12:0x14])[0]
        assert found == e_machine, \
            "%s: e_machine is 0x%x, expected 0x%x" % (path, found, e_machine)

    def build_i386():
        """Build for i386 Linux and check the object file (e_machine 0x03, EM_386)."""
        if not tvm.module.enabled("llvm"):
            print("Skip because llvm is not enabled..")
            return
        temp = util.tempdir()
        target = "llvm -target=i386-pc-linux-gnu"
        f = tvm.build(s, [A, B, C], target)
        path = temp.relpath("myadd.o")
        f.save(path)
        verify_elf(path, 0x03)

    def build_arm():
        """Build for ARMv7 Linux, check the object (e_machine 0x28, EM_ARM), optionally run it over RPC."""
        target = "llvm -target=armv7-none-linux-gnueabihf"
        if not tvm.module.enabled(target):
            print("Skip because %s is not enabled.." % target)
            return
        temp = util.tempdir()
        f = tvm.build(s, [A, B, C], target)
        path = temp.relpath("myadd.o")
        f.save(path)
        verify_elf(path, 0x28)
        # Also exercise saving the module in assembly form.
        asm_path = temp.relpath("myadd.asm")
        f.save(asm_path)

        # Do a RPC verification, launch kernel on Arm Board if available.
        host = os.environ.get('TVM_RPC_ARM_HOST', None)
        if not host:
            return
        port = int(os.environ['TVM_RPC_ARM_PORT'])
        try:
            remote = rpc.connect(host, port)
        except tvm.TVMError as err:
            # Remote verification is best-effort: report why it was skipped
            # instead of silently swallowing the connection error.
            print("Skip remote verification, cannot connect to %s:%d: %s"
                  % (host, port, err))
            return
        if not remote:
            return

        remote.upload(path)
        farm = remote.load_module("myadd.o")
        ctx = remote.cpu(0)
        # Use the plain Python length `nn` here; `n` in the enclosing scope
        # is the converted TVM expression.
        a = tvm.nd.array(np.random.uniform(size=nn).astype(A.dtype), ctx)
        b = tvm.nd.array(np.random.uniform(size=nn).astype(A.dtype), ctx)
        c = tvm.nd.array(np.zeros(nn, dtype=C.dtype), ctx)
        farm(a, b, c)
        np.testing.assert_allclose(
            c.asnumpy(), a.asnumpy() + b.asnumpy())
        print("Verification finish on remote..")

    build_i386()
    build_arm()

# Allow running this file directly as a script, without a test runner.
if __name__ == "__main__":
    test_llvm_add_pipeline()