# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

import tvm
from tvm.contrib import cc, util
import ctypes
import os
import sys
import numpy as np
import subprocess

# Source text of a standalone script. test_dso_module_load writes it to a
# temporary file and runs it in a separate Python process, passing the path of
# a compiled shared library and a dtype name on the command line.
#
# The two environment variables are assigned before ``import tvm`` so that
# they are already in place when the package is imported.
# NOTE(review): presumably they select the runtime-only library and the ctypes
# FFI backend -- confirm against the tvm package.
runtime_py = """
import os
import sys

os.environ["TVM_USE_RUNTIME_LIB"] = "1"
os.environ["TVM_FFI"] = "ctypes"
import tvm
import numpy as np
path_dso = sys.argv[1]
dtype = sys.argv[2]
ff = tvm.module.load(path_dso)
a = tvm.nd.array(np.zeros(10, dtype=dtype))
ff(a)
np.testing.assert_equal(a.asnumpy(), np.arange(a.shape[0]))
print("Finish runtime checking...")
"""

def test_dso_module_load():
    """Check that an LLVM-built module can be saved, linked and loaded back.

    A small "ramp" function is built with the llvm backend and saved as an
    object file, an LLVM IR file and a bitcode file. The object file is linked
    into a shared library. The function is then loaded and run from the shared
    library and from the ``.ll`` file in this process, and from the shared
    library in a separate Python process driven by ``runtime_py``.

    The test returns without doing anything when llvm is not enabled.
    """
    if not tvm.module.enabled("llvm"):
        return
    dtype = 'int64'
    temp = util.tempdir()

    def save_object(names):
        """Build the ramp function and save the module once per path in names."""
        n = tvm.var('n')
        Ab = tvm.decl_buffer((n, ), dtype)
        i = tvm.var('i')
        # for i in 0 to n-1:
        #     Ab[i + 1] = Ab[i] + 1
        stmt = tvm.make.For(
            i, 0, n - 1, 0, 0,
            tvm.make.Store(Ab.data,
                           tvm.make.Load(dtype, Ab.data, i) + 1,
                           i + 1))
        fapi = tvm.ir_pass.MakeAPI(stmt, "ramp", [Ab], 0, True)
        fapi = tvm.ir_pass.LowerTVMBuiltin(fapi)
        m = tvm.codegen.build_module(fapi, "llvm")
        for name in names:
            m.save(name)

    def check_ramp(func):
        """Run func on a zero-filled array and expect 0, 1, 2, ... back."""
        a = tvm.nd.array(np.zeros(10, dtype=dtype))
        func(a)
        np.testing.assert_equal(a.asnumpy(), np.arange(a.shape[0]))

    path_obj = temp.relpath("test.o")
    path_ll = temp.relpath("test.ll")
    path_bc = temp.relpath("test.bc")
    path_dso = temp.relpath("test.so")
    save_object([path_obj, path_ll, path_bc])
    cc.create_shared(path_dso, [path_obj])

    # Load both artifacts first, then run them, in the same order as before.
    f1 = tvm.module.load(path_dso)
    f2 = tvm.module.load(path_ll)
    check_ramp(f1)
    check_ramp(f2)

    path_runtime_py = temp.relpath("runtime.py")
    with open(path_runtime_py, "w") as fo:
        fo.write(runtime_py)

    # Pass the arguments as a list instead of a formatted shell string so that
    # temporary paths containing spaces or shell metacharacters are handled
    # correctly, and use the interpreter running this test rather than
    # whatever "python3" resolves to on PATH.
    subprocess.check_call(
        [sys.executable, path_runtime_py, path_dso, dtype])


def test_device_module_dump():
    """Check that device modules can be exported, reloaded and executed.

    An element-wise ``B = A + 1.0`` kernel over 1024 elements is scheduled
    with its axis split by 8 and bound to ``blockIdx.x`` / ``threadIdx.x``.
    For each of cuda, vulkan, opencl and metal the kernel is built, exported
    to a shared library, loaded back and run; it is also built with the
    "stackvm" host target and run directly. Devices whose context does not
    exist are skipped with a message.
    """
    size = 1024
    # graph
    n = tvm.convert(size)
    A = tvm.placeholder((n,), name='A')
    B = tvm.compute(A.shape, lambda *i: A(*i) + 1.0, name='B')
    s = tvm.create_schedule(B.op)
    # create iter var and assign them tags.
    num_thread = 8
    bx, tx = s[B].split(B.op.axis[0], factor=num_thread)
    s[B].bind(bx, tvm.thread_axis("blockIdx.x"))
    s[B].bind(tx, tvm.thread_axis("threadIdx.x"))

    def check_device(device):
        """Build for device, round-trip through a shared library, and run."""
        ctx = tvm.context(device, 0)
        if not ctx.exist:
            print("Skip because %s is not enabled" % device)
            return
        temp = util.tempdir()
        name = "myadd_%s" % device
        # The system-lib variant is only built on darwin and linux; on win32
        # a plain llvm host target is used and the system_lib check is skipped.
        if sys.platform == "darwin" or sys.platform.startswith('linux'):
            f = tvm.build(s, [A, B], device, "llvm -system-lib", name=name)
        elif sys.platform == "win32":
            f = tvm.build(s, [A, B], device, "llvm", name=name)
        else:
            raise ValueError("Unsupported platform")

        path_dso = temp.relpath("dev_lib.so")
        f.export_library(path_dso)

        f1 = tvm.module.load(path_dso)
        a = tvm.nd.array(np.random.uniform(size=size).astype(A.dtype), ctx)
        b = tvm.nd.array(np.zeros(size, dtype=A.dtype), ctx)
        f1(a, b)
        np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1)
        if sys.platform != "win32":
            # The same function must also be reachable by name through the
            # system library module.
            f2 = tvm.module.system_lib()
            f2[name](a, b)
            np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1)

    def check_stackvm(device):
        """Build for device with the stackvm host target and run directly.

        The export/load round trip is not exercised here; the module returned
        by tvm.build is called as-is.
        """
        ctx = tvm.context(device, 0)
        if not ctx.exist:
            print("Skip because %s is not enabled" % device)
            return
        name = "myadd_%s" % device
        f = tvm.build(s, [A, B], device, "stackvm", name=name)
        a = tvm.nd.array(np.random.uniform(size=size).astype(A.dtype), ctx)
        b = tvm.nd.array(np.zeros(size, dtype=A.dtype), ctx)
        f(a, b)
        np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1)

    for device in ["cuda", "vulkan", "opencl", "metal"]:
        check_device(device)
        check_stackvm(device)


def test_combine_module_llvm():
    """Test combining multiple modules into one shared library.

    The same ``B = A + 1.0`` computation over 12 elements is built twice under
    the names "myadd1" and "myadd2", saved as two object files and linked into
    a single shared library. Both functions are then looked up and run, once
    through ``tvm.module.load`` and (except on win32) once through the system
    library module. Each check is skipped with a message when llvm is not
    enabled.
    """
    # graph
    nn = 12
    n = tvm.convert(nn)
    A = tvm.placeholder((n,), name='A')
    B = tvm.compute(A.shape, lambda *i: A(*i) + 1.0, name='B')
    s = tvm.create_schedule(B.op)

    def build_shared_lib(temp, target):
        """Build myadd1/myadd2 for target and link them; return the lib path."""
        fadd1 = tvm.build(s, [A, B], target, name="myadd1")
        fadd2 = tvm.build(s, [A, B], target, name="myadd2")
        path1 = temp.relpath("myadd1.o")
        path2 = temp.relpath("myadd2.o")
        path_dso = temp.relpath("mylib.so")
        fadd1.save(path1)
        fadd2.save(path2)
        # create shared library with multiple functions
        cc.create_shared(path_dso, [path1, path2])
        return path_dso

    def run_and_verify(module, ctx):
        """Call myadd1 then myadd2 from module and check b == a + 1 each time."""
        a = tvm.nd.array(np.random.uniform(size=nn).astype(A.dtype), ctx)
        b = tvm.nd.array(np.zeros(nn, dtype=A.dtype), ctx)
        for func_name in ('myadd1', 'myadd2'):
            module[func_name](a, b)
            np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1)

    def check_llvm():
        """Load the combined library with tvm.module.load and run it."""
        ctx = tvm.cpu(0)
        if not tvm.module.enabled("llvm"):
            print("Skip because llvm is not enabled")
            return
        temp = util.tempdir()
        path_dso = build_shared_lib(temp, "llvm")
        m = tvm.module.load(path_dso)
        run_and_verify(m, ctx)

    def check_system_lib():
        """Load the combined library via ctypes and run it via system_lib."""
        ctx = tvm.cpu(0)
        if not tvm.module.enabled("llvm"):
            print("Skip because llvm is not enabled")
            return
        temp = util.tempdir()
        path_dso = build_shared_lib(temp, "llvm -system-lib")
        # Load dll, will trigger system library registration
        dll = ctypes.CDLL(path_dso)
        # Load the system wide library
        mm = tvm.module.system_lib()
        run_and_verify(mm, ctx)

    if sys.platform != "win32":
        check_system_lib()
    check_llvm()


if __name__ == "__main__":
    # Run every test in this file when it is executed as a script.
    test_combine_module_llvm()
    test_device_module_dump()
    test_dso_module_load()