# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
from tvm.contrib import cc, util
import ctypes
import os
import sys
import numpy as np
import subprocess

# Source of a standalone script that test_dso_module_load writes to a temp
# file and runs in a separate Python process.
# The script sets TVM_USE_RUNTIME_LIB and TVM_FFI in the environment before
# importing tvm, loads the module whose path is given as argv[1], calls it on
# a 10-element zero array of the dtype given as argv[2], and asserts the
# array then equals np.arange(10).
# NOTE(review): presumably TVM_USE_RUNTIME_LIB=1 makes tvm load only its
# runtime library -- confirm against the tvm package.
runtime_py = """
import os
import sys

os.environ["TVM_USE_RUNTIME_LIB"] = "1"
os.environ["TVM_FFI"] = "ctypes"
import tvm
import numpy as np
path_dso = sys.argv[1]
dtype = sys.argv[2]
ff = tvm.module.load(path_dso)
a = tvm.nd.array(np.zeros(10, dtype=dtype))
ff(a)
np.testing.assert_equal(a.asnumpy(), np.arange(a.shape[0]))
print("Finish runtime checking...")
"""

def test_dso_module_load():
    """Save an LLVM-built kernel, reload it in several ways and run it.

    The "ramp" kernel is saved as an object file, LLVM IR text and LLVM
    bitcode.  The object file is linked into a shared library; the shared
    library and the ``.ll`` file are each loaded back and executed in this
    process, and the shared library is executed once more from a separate
    Python process running the ``runtime_py`` script.

    The test returns silently when the llvm module is not enabled.

    Raises
    ------
    AssertionError
        If a loaded module does not turn a zero array into ``np.arange``.
    subprocess.CalledProcessError
        If the child process running ``runtime_py`` exits with non-zero status.
    """
    if not tvm.module.enabled("llvm"):
        return
    dtype = 'int64'
    # Keep the tempdir object referenced for the whole test; all files created
    # below live inside it.
    temp = util.tempdir()

    def save_object(names):
        """Build the "ramp" kernel with llvm and save it to every path in names."""
        n = tvm.var('n')
        Ab = tvm.decl_buffer((n, ), dtype)
        i = tvm.var('i')
        # Loop whose body stores Ab[i] + 1 into Ab[i + 1].  The bounds passed
        # are 0 and n - 1 (presumably min and extent, so i runs 0 .. n - 2 and
        # the last write hits Ab[n - 1] -- confirm against tvm.make.For).
        stmt = tvm.make.For(
            i, 0, n - 1, 0, 0,
            tvm.make.Store(Ab.data,
                           tvm.make.Load(dtype, Ab.data, i) + 1,
                           i + 1))
        fapi = tvm.ir_pass.MakeAPI(stmt, "ramp", [Ab], 0, True)
        fapi = tvm.ir_pass.LowerTVMBuiltin(fapi)
        m = tvm.codegen.build_module(fapi, "llvm")
        for name in names:
            m.save(name)

    path_obj = temp.relpath("test.o")
    path_ll = temp.relpath("test.ll")
    path_bc = temp.relpath("test.bc")
    path_dso = temp.relpath("test.so")
    save_object([path_obj, path_ll, path_bc])
    cc.create_shared(path_dso, [path_obj])

    f1 = tvm.module.load(path_dso)
    f2 = tvm.module.load(path_ll)
    a = tvm.nd.array(np.zeros(10, dtype=dtype))
    f1(a)
    np.testing.assert_equal(a.asnumpy(), np.arange(a.shape[0]))
    a = tvm.nd.array(np.zeros(10, dtype=dtype))
    f2(a)
    np.testing.assert_equal(a.asnumpy(), np.arange(a.shape[0]))

    path_runtime_py = temp.relpath("runtime.py")
    with open(path_runtime_py, "w") as fo:
        fo.write(runtime_py)

    # Pass the command as an argument list without a shell so that temp paths
    # containing spaces or shell metacharacters are handed over verbatim.
    # Prefer the interpreter running this test (it is the one known to have
    # tvm importable); fall back to "python3" if it cannot be determined.
    python_exe = sys.executable or "python3"
    subprocess.check_call([python_exe, path_runtime_py, path_dso, dtype])


def test_device_module_dump():
    """Run an element-wise "add one" kernel on each device backend.

    A 1024-element kernel is scheduled with its single axis split by 8 and
    bound to blockIdx.x / threadIdx.x.  For every device in the list at the
    bottom, the kernel is built, exported to a shared library, reloaded and
    executed, and then built and executed again with the "stackvm" host
    target.  A device whose context does not exist is skipped with a message.
    """
    # graph
    n = tvm.convert(1024)
    A = tvm.placeholder((n,), name='A')
    B = tvm.compute(A.shape, lambda *idx: A(*idx) + 1.0, name='B')
    s = tvm.create_schedule(B.op)
    # Split the only axis and tag the two resulting iter vars.
    split_factor = 8
    outer, inner = s[B].split(B.op.axis[0], factor=split_factor)
    s[B].bind(outer, tvm.thread_axis("blockIdx.x"))
    s[B].bind(inner, tvm.thread_axis("threadIdx.x"))

    def check_device(device):
        """Build for device, round-trip through a shared library and run."""
        ctx = tvm.context(device, 0)
        if not ctx.exist:
            print("Skip because %s is not enabled" % device)
            return
        temp = util.tempdir()
        name = "myadd_%s" % device

        # Pick the host target by platform; only these three are handled.
        platform = sys.platform
        if platform == "darwin" or platform.startswith('linux'):
            target_host = "llvm -system-lib"
        elif platform == "win32":
            target_host = "llvm"
        else:
            raise ValueError("Unsupported platform")
        built = tvm.build(s, [A, B], device, target_host, name=name)

        lib_path = temp.relpath("dev_lib.so")
        built.export_library(lib_path)
        loaded = tvm.module.load(lib_path)

        a = tvm.nd.array(np.random.uniform(size=1024).astype(A.dtype), ctx)
        b = tvm.nd.array(np.zeros(1024, dtype=A.dtype), ctx)
        loaded(a, b)
        np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1)

        if platform == "win32":
            return
        # Off Windows, also call the function through the system library.
        syslib = tvm.module.system_lib()
        syslib[name](a, b)
        np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1)

    def check_stackvm(device):
        """Build for device with the stackvm host target and run it directly."""
        ctx = tvm.context(device, 0)
        if not ctx.exist:
            print("Skip because %s is not enabled" % device)
            return
        temp = util.tempdir()
        name = "myadd_%s" % device
        built = tvm.build(s, [A, B], device, "stackvm", name=name)
        # The export/reload round trip is not exercised for stackvm; the path
        # is computed but the built module is called as-is.
        path_dso = temp.relpath("dev_lib.stackvm")
        a = tvm.nd.array(np.random.uniform(size=1024).astype(A.dtype), ctx)
        b = tvm.nd.array(np.zeros(1024, dtype=A.dtype), ctx)
        built(a, b)
        np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1)

    for device in ("cuda", "vulkan", "opencl", "metal"):
        check_device(device)
        check_stackvm(device)

def test_combine_module_llvm():
    """Link two separately built functions into one shared library and call both.

    The same "add one" schedule is built twice under the names myadd1 and
    myadd2, each saved as an object file, and the two objects are linked into
    a single shared library.  The functions are then reached either through
    tvm.module.load or, except on win32, through the system library.
    """
    # graph
    nn = 12
    n = tvm.convert(nn)
    A = tvm.placeholder((n,), name='A')
    B = tvm.compute(A.shape, lambda *idx: A(*idx) + 1.0, name='B')
    s = tvm.create_schedule(B.op)
    func_names = ("myadd1", "myadd2")

    def check_llvm():
        """Load the combined library as a module and run both functions."""
        ctx = tvm.cpu(0)
        if not tvm.module.enabled("llvm"):
            print("Skip because llvm is not enabled")
            return
        temp = util.tempdir()
        built = [tvm.build(s, [A, B], "llvm", name=fname)
                 for fname in func_names]
        obj_paths = [temp.relpath(fname + ".o") for fname in func_names]
        path_dso = temp.relpath("mylib.so")
        for mod, obj_path in zip(built, obj_paths):
            mod.save(obj_path)
        # create shared library with multiple functions
        cc.create_shared(path_dso, obj_paths)
        m = tvm.module.load(path_dso)
        loaded_funcs = [m[fname] for fname in func_names]
        a = tvm.nd.array(np.random.uniform(size=nn).astype(A.dtype), ctx)
        b = tvm.nd.array(np.zeros(nn, dtype=A.dtype), ctx)
        for func in loaded_funcs:
            func(a, b)
            np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1)

    def check_system_lib():
        """Load the combined library via ctypes and call through system_lib."""
        ctx = tvm.cpu(0)
        if not tvm.module.enabled("llvm"):
            print("Skip because llvm is not enabled")
            return
        temp = util.tempdir()
        built = [tvm.build(s, [A, B], "llvm -system-lib", name=fname)
                 for fname in func_names]
        obj_paths = [temp.relpath(fname + ".o") for fname in func_names]
        path_dso = temp.relpath("mylib.so")
        for mod, obj_path in zip(built, obj_paths):
            mod.save(obj_path)
        cc.create_shared(path_dso, obj_paths)
        # Loading the dll triggers system library registration.
        dll_handle = ctypes.CDLL(path_dso)
        # Fetch the system wide library.
        mm = tvm.module.system_lib()
        a = tvm.nd.array(np.random.uniform(size=nn).astype(A.dtype), ctx)
        b = tvm.nd.array(np.zeros(nn, dtype=A.dtype), ctx)
        for fname in func_names:
            mm[fname](a, b)
            np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1)

    if sys.platform != "win32":
        check_system_lib()
    check_llvm()



if __name__ == "__main__":
    # Run every test in this file, in a fixed order, when executed as a script.
    for run_test in (test_combine_module_llvm,
                     test_device_module_dump,
                     test_dso_module_load):
        run_test()