# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
from tvm.contrib import util
import numpy as np

def test_add_pipeline():
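    """Build D = A + B + 1 (A a 1-D tensor, B a scalar), lower it by hand
    through the IR passes, split it into host and device code, and verify
    the result on several GPU backends, both with a freshly built device
    module and with one saved to disk and reloaded.
    """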
    n = tvm.var('n')
    A = tvm.placeholder((n,), name='A')
    B = tvm.placeholder((), name='B')
    C = tvm.compute(A.shape, lambda *i: A(*i) + B(), name='C')
    D = tvm.compute(A.shape, lambda *i: C(*i) + 1, name='D')
    s = tvm.create_schedule(D.op)

    # GPU schedules have to split the axis and bind it to blockIdx and threadIdx
    num_thread = 256
    xo, xi = s[C].split(C.op.axis[0], factor=num_thread)
    s[C].bind(xi, tvm.thread_axis("threadIdx.x"))
    s[C].bind(xo, tvm.thread_axis("blockIdx.x"))

    xo, xi = s[D].split(D.op.axis[0], factor=num_thread)
    s[D].bind(xi, tvm.thread_axis("threadIdx.x"))
    s[D].bind(xo, tvm.thread_axis("blockIdx.x"))

    # compile to IR by hand: normalize the schedule, infer bounds, emit the loop nest
    s = s.normalize()
    bounds = tvm.schedule.InferBound(s)
    stmt = tvm.schedule.ScheduleOps(s, bounds)
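    # declare the buffers that back the tensors; these become the arguments of the generated function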
    Ab = tvm.decl_buffer(A.shape, A.dtype, name='A')
    Bb = tvm.decl_buffer(B.shape, B.dtype, name='B')
    Db = tvm.decl_buffer(D.shape, D.dtype, name='D')
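    # lowering passes: partition loops, flatten tensor access into 1-D buffer access,
    # simplify, then wrap the body into a PackedFunc API named "myadd"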
    stmt = tvm.ir_pass.LoopPartition(stmt, False)
    stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, B: Bb, D: Db}, 64)
    stmt = tvm.ir_pass.Simplify(stmt)
    fapi = tvm.ir_pass.MakeAPI(stmt, "myadd", [Ab, Bb, Db], 0, True)
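    # split into host and device functions: fsplits[0] is the host stub, fsplits[1:] are the device kernels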
    fsplits = list(tvm.ir_pass.SplitHostDevice(fapi))
    # lower the floordiv (use stackvm rules so it works for all targets)
    fsplits = [tvm.ir_pass.LowerIntrin(x, "stackvm") for x in fsplits]
    fsplits[0] = tvm.ir_pass.LowerTVMBuiltin(fsplits[0])

    def check_target(device, host="stackvm"):
        ctx = tvm.context(device, 0)
        if not ctx.exist:
            return
        if not tvm.module.enabled(host):
            return
        mhost = tvm.codegen.build_module(fsplits[0], host)
        mdev = tvm.codegen.build_module(fsplits[1:], device)
        mhost.import_module(mdev)
        code = mdev.get_source()  # exercise source extraction from the device module
        f = mhost.entry_func
        # launch the kernel.
        n = 1027
        a = tvm.nd.array(np.random.uniform(size=n).astype(Ab.dtype), ctx)
        b = tvm.nd.array(np.random.uniform(size=()).astype(Bb.dtype), ctx)
        d = tvm.nd.array(np.zeros(n, dtype=Db.dtype), ctx)
        f(a, b, d)
        tvm.testing.assert_allclose(
            d.asnumpy(), a.asnumpy() + b.asnumpy() + 1)

    def check_module_save(device, host="stackvm"):
        ctx = tvm.context(device, 0)
        if not ctx.exist:
            return
        if not tvm.module.enabled(host):
            return
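        # pick the on-disk format for the device module (ptx for cuda, hsaco for rocm)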
        if device == "cuda":
            fmt = "ptx"
        elif device == "rocm":
            fmt = "hsaco"
        else:
            fmt = device
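        # same build as check_target, but round-trip the device module through a file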
        mhost = tvm.codegen.build_module(fsplits[0], host)
        mdev = tvm.codegen.build_module(fsplits[1:], device)
        temp = util.tempdir()
        mpath = temp.relpath("test.%s" % fmt)
        mdev.save(mpath)
        mdev2 = tvm.module.load(mpath)
        mhost.import_module(mdev2)
        f = mhost.entry_func
        # launch the kernel.
        n = 1027
        a = tvm.nd.array(np.random.uniform(size=n).astype(Ab.dtype), ctx)
        b = tvm.nd.array(np.random.uniform(size=()).astype(Bb.dtype), ctx)
        d = tvm.nd.array(np.zeros(n, dtype=Db.dtype), ctx)
        f(a, b, d)
        tvm.testing.assert_allclose(
            d.asnumpy(), a.asnumpy() + b.asnumpy() + 1)

    check_target("cuda", host="stackvm")
    check_target("cuda", host="llvm")
    check_module_save("cuda", host="stackvm")
    check_target("nvptx", host="llvm")
    check_target("vulkan", host="llvm")
    check_module_save("vulkan", host="stackvm")
    check_target("rocm", host="llvm")
    check_module_save("rocm", host="llvm")


if __name__ == "__main__":
    test_add_pipeline()