# test_module_load.py: tests for saving, exporting and loading TVM modules.
import ctypes
import os
import subprocess
import sys

import numpy as np
import tvm
from tvm.contrib import cc, util

# Source of a standalone script that test_dso_module_load writes to disk and
# runs in a separate interpreter.  It sets TVM_USE_RUNTIME_LIB and TVM_FFI
# *before* importing tvm (the environment has to be in place at import time),
# then loads the shared library given as argv[1] and runs it on a zero-filled
# array of dtype argv[2], asserting the result equals arange.
runtime_py = """
import os
import sys

os.environ["TVM_USE_RUNTIME_LIB"] = "1"
os.environ["TVM_FFI"] = "ctypes"
import tvm
import numpy as np
path_dso = sys.argv[1]
dtype = sys.argv[2]
ff = tvm.module.load(path_dso)
a = tvm.nd.array(np.zeros(10, dtype=dtype))
ff(a)
np.testing.assert_equal(a.asnumpy(), np.arange(a.shape[0]))
print("Finish runtime checking...")
"""

def test_dso_module_load():
    """Build a tiny LLVM kernel, save it in several forms and load it back.

    The kernel stores ``A[i] + 1`` into ``A[i + 1]``, so a zero-filled input
    ends up equal to ``arange``.  The module is saved as an object file, an
    ``.ll`` file and a ``.bc`` file.  The object file is linked into a shared
    library.  The shared library and the ``.ll`` file are then loaded in this
    process and checked; the ``.bc`` file is only written.  Finally, the
    shared library is checked again from a separate interpreter running
    ``runtime_py``.
    """
    if not tvm.module.enabled("llvm"):
        return
    dtype = 'int64'
    temp = util.tempdir()

    def save_object(names):
        """Build the "ramp" kernel for llvm and save it to every path in *names*."""
        n = tvm.var('n')
        Ab = tvm.decl_buffer((n, ), dtype)
        i = tvm.var('i')
        # Loop starting at i = 0 with extent n - 1:
        #     A[i + 1] = A[i] + 1
        stmt = tvm.make.For(
            i, 0, n - 1, 0, 0,
            tvm.make.Store(Ab.data,
                           tvm.make.Load(dtype, Ab.data, i) + 1,
                           i + 1))
        fapi = tvm.ir_pass.MakeAPI(stmt, "ramp", [Ab], 0, True)
        fapi = tvm.ir_pass.LowerTVMBuiltin(fapi)
        m = tvm.codegen.build_module(fapi, "llvm")
        # NOTE(review): the output format presumably follows each file's
        # extension (.o / .ll / .bc) -- confirm against Module.save.
        for name in names:
            m.save(name)

    path_obj = temp.relpath("test.o")
    path_ll = temp.relpath("test.ll")
    path_bc = temp.relpath("test.bc")
    path_dso = temp.relpath("test.so")
    save_object([path_obj, path_ll, path_bc])
    cc.create_shared(path_dso, [path_obj])

    f1 = tvm.module.load(path_dso)
    f2 = tvm.module.load(path_ll)
    a = tvm.nd.array(np.zeros(10, dtype=dtype))
    f1(a)
    np.testing.assert_equal(a.asnumpy(), np.arange(a.shape[0]))
    a = tvm.nd.array(np.zeros(10, dtype=dtype))
    f2(a)
    np.testing.assert_equal(a.asnumpy(), np.arange(a.shape[0]))

    path_runtime_py = temp.relpath("runtime.py")
    with open(path_runtime_py, "w") as fo:
        fo.write(runtime_py)

    # Run the check in a fresh interpreter.  Use the interpreter executing
    # this test rather than whatever "python" is first on PATH.  Pass the
    # arguments as a list (no shell), so temporary paths containing spaces or
    # shell metacharacters are passed through intact.
    subprocess.check_call(
        [sys.executable, path_runtime_py, path_dso, dtype])

def test_device_module_dump():
    """Export a GPU add-one kernel as a shared library and run it two ways.

    For each of cuda, opencl and metal (skipped when the device context does
    not exist), the kernel is built with an ``llvm -system-lib`` host target
    and exported to a shared library.  The library is loaded back and run.
    The same function is then looked up by name through
    ``tvm.module.system_lib()`` and run again.
    """
    # graph
    nn = 1024
    n = tvm.convert(nn)
    A = tvm.placeholder((n,), name='A')
    B = tvm.compute(A.shape, lambda *i: A(*i) + 1.0, name='B')
    s = tvm.create_schedule(B.op)
    # create iter var and assign them tags.
    num_thread = 8
    bx, tx = s[B].split(B.op.axis[0], factor=num_thread)
    s[B].bind(bx, tvm.thread_axis("blockIdx.x"))
    s[B].bind(tx, tvm.thread_axis("threadIdx.x"))

    def check_device(device):
        """Build, export, reload and run the kernel for *device*, if present."""
        ctx = tvm.context(device, 0)
        if not ctx.exist:
            print("Skip because %s is not enabled" % device)
            return
        temp = util.tempdir()
        name = "myadd_%s" % device
        f = tvm.build(s, [A, B], device, "llvm -system-lib", name=name)
        path_dso = temp.relpath("dev_lib.so")
        f.export_library(path_dso)

        f1 = tvm.module.load(path_dso)
        a = tvm.nd.array(np.random.uniform(size=nn).astype(A.dtype), ctx)
        b = tvm.nd.array(np.zeros(nn, dtype=A.dtype), ctx)
        f1(a, b)
        np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1)

        # Start again from a zeroed output.  Otherwise the assertion below
        # would re-read f1's result and pass even if the system-lib entry
        # point never wrote to b.
        b = tvm.nd.array(np.zeros(nn, dtype=A.dtype), ctx)
        f2 = tvm.module.system_lib()
        f2[name](a, b)
        np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1)

    check_device("cuda")
    check_device("opencl")
    check_device("metal")

def test_combine_module_llvm():
    """Test combining multiple LLVM modules into one shared library.

    Two add-one kernels (``myadd1`` and ``myadd2``) are saved as separate
    object files and linked into a single shared library.  The library is
    exercised in two ways:

    * ``check_llvm`` loads it with ``tvm.module.load`` and looks the functions
      up by name on the returned module.
    * ``check_system_lib`` builds with ``-system-lib``, loads the library via
      ``ctypes`` and looks the functions up on ``tvm.module.system_lib()``.
    """
    # graph
    nn = 12
    n = tvm.convert(nn)
    A = tvm.placeholder((n,), name='A')
    B = tvm.compute(A.shape, lambda *i: A(*i) + 1.0, name='B')
    s = tvm.create_schedule(B.op)

    def check_add_one(func):
        """Call ``func(a, b)`` on fresh CPU arrays and assert ``b == a + 1``.

        A new zero-filled output is allocated on every call, so each assertion
        checks only what *func* itself wrote, not a previous function's output.
        """
        ctx = tvm.cpu(0)
        a = tvm.nd.array(np.random.uniform(size=nn).astype(A.dtype), ctx)
        b = tvm.nd.array(np.zeros(nn, dtype=A.dtype), ctx)
        func(a, b)
        np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1)

    def check_llvm():
        """Link two llvm modules into one .so and call both via tvm.module.load."""
        if not tvm.module.enabled("llvm"):
            print("Skip because llvm is not enabled")
            return
        temp = util.tempdir()
        fadd1 = tvm.build(s, [A, B], "llvm", name="myadd1")
        fadd2 = tvm.build(s, [A, B], "llvm", name="myadd2")
        path1 = temp.relpath("myadd1.o")
        path2 = temp.relpath("myadd2.o")
        path_dso = temp.relpath("mylib.so")
        fadd1.save(path1)
        fadd2.save(path2)
        # create shared library with multiple functions
        cc.create_shared(path_dso, [path1, path2])
        m = tvm.module.load(path_dso)
        check_add_one(m['myadd1'])
        check_add_one(m['myadd2'])

    def check_system_lib():
        """Link two -system-lib modules into one .so and call both via system_lib()."""
        if not tvm.module.enabled("llvm"):
            print("Skip because llvm is not enabled")
            return
        temp = util.tempdir()
        fadd1 = tvm.build(s, [A, B], "llvm -system-lib", name="myadd1")
        fadd2 = tvm.build(s, [A, B], "llvm -system-lib", name="myadd2")
        path1 = temp.relpath("myadd1.o")
        path2 = temp.relpath("myadd2.o")
        path_dso = temp.relpath("mylib.so")
        fadd1.save(path1)
        fadd2.save(path2)
        cc.create_shared(path_dso, [path1, path2])
        # Load dll, will trigger system library registration.  The handle is
        # held in a local for the rest of this check even though it is not
        # used directly.
        dll = ctypes.CDLL(path_dso)
        # Load the system wide library
        mm = tvm.module.system_lib()
        check_add_one(mm['myadd1'])
        check_add_one(mm['myadd2'])

    check_system_lib()
    check_llvm()


if __name__ == "__main__":
    # Run the tests directly when the file is executed as a script
    # (outside a test runner).
    test_combine_module_llvm()
    test_device_module_dump()
    test_dso_module_load()