# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
import numpy as np

def test_get_global():
    targs = (10, 10.0, "hello")
    # register into global function table
    @tvm.register_func
    def my_packed_func(*args):
        assert(tuple(args) == targs)
        return 10
    # get it out from global function table
    f = tvm.get_global_func("my_packed_func")
    assert isinstance(f, tvm.Function)
    y = f(*targs)
    assert y == 10

def test_get_callback_with_node():
    x = tvm.convert(10)
    def test(y):
        assert y.handle != x.handle
        return y

    f2 = tvm.convert(test)
    # register into global function table
    @tvm.register_func
    def my_callback_with_node(y, f):
        assert y == x
        return f(y)

    # get it out from global function table
    f = tvm.get_global_func("my_callback_with_node")
    assert isinstance(f, tvm.Function)
    y = f(x, f2)
    assert(y.value == 10)


def test_return_func():
    def addy(y):
        def add(x):
            return tvm.convert(x + y)
        return add
    myf = tvm.convert(addy)
    f = myf(10)
    assert f(11).value == 21


def test_convert():
    # convert a function to tvm function
    targs = (10, 10.0, "hello", 10)
    def myfunc(*args):
        assert(tuple(args) == targs)

    f = tvm.convert(myfunc)
    assert isinstance(f, tvm.Function)

def test_byte_array():
    s = "hello"
    a = bytearray(s, encoding="ascii")

    def myfunc(ss):
        assert ss == a
    f = tvm.convert(myfunc)
    f(a)


def test_empty_array():
    def myfunc(ss):
        assert tuple(ss) == ()
    x = tvm.convert(())
    tvm.convert(myfunc)(x)


def test_ctx():
    def test_ctx_func(ctx):
        assert tvm.gpu(7) == ctx
        return tvm.cpu(0)
    x = test_ctx_func(tvm.gpu(7))
    assert x == tvm.cpu(0)
    x = tvm.opencl(10)
    x = tvm._api_internal._context_test(x, x.device_type, x.device_id)
    assert x == tvm.opencl(10)

def test_trace_default_action():
    n = 2
    x = tvm.placeholder((n,n,n), name="X", dtype="float32")
    y = tvm.compute(x.shape, lambda i, j, k: tvm.trace([i, j, k, x[i][j][k]]))
    s = tvm.create_schedule(y.op)
    f = tvm.build(s, [x, y], target="llvm")
    xnd = tvm.nd.array(np.ones((n,n,n), dtype=x.dtype))
    ynd = tvm.nd.array(np.zeros((n,n,n), dtype=y.dtype))
    f(xnd, ynd)

def test_trace_expr_assign():
    @tvm.register_func("tvm.trace_callback2")
    def trace_buffer(x):
        return

    def check_assign(dtype):
        n = 4
        x = tvm.placeholder((n,n,n), name="X", dtype=dtype)
        y = tvm.compute(x.shape, lambda i, j, k: tvm.trace([x[i][j][k]], "tvm.trace_callback2"))
        z = tvm.compute(x.shape, lambda i, j, k: tvm.trace([y[i][j][k]], "tvm.trace_callback2"))
        s = tvm.create_schedule(z.op)
        f = tvm.build(s, [x, y, z], "llvm")

        xnd = tvm.nd.array(np.ones((n,n,n), dtype=x.dtype))
        ynd = tvm.nd.array(np.zeros((n,n,n), dtype=y.dtype))
        znd = tvm.nd.array(np.zeros((n,n,n), dtype=z.dtype))
        f(xnd, ynd, znd)

        assert(np.array_equal(xnd.asnumpy(), np.ones((n,n,n))))
        assert(np.array_equal(ynd.asnumpy(), np.ones((n,n,n))))
        assert(np.array_equal(znd.asnumpy(), np.ones((n,n,n))))

    for t in ["float64", "float32", "int64", "int32"]:
        check_assign(t)

def test_trace_expr_sum_generated():
    @tvm.register_func("tvm.trace_callback3")
    def trace_buffer(x):
        return

    def check_expr_sum(dtype):
        n = 4
        a = tvm.placeholder((n,n,n), name="a", dtype=dtype)
        b = tvm.placeholder((n,n,n), name="b", dtype=dtype)
        c = tvm.compute(a.shape, lambda i, j, k: tvm.trace([a[i][j][k]],"tvm.trace_callback3")
                                         + tvm.trace([b[i][j][k]],"tvm.trace_callback3"))
        s = tvm.create_schedule(c.op)
        f = tvm.build(s, [a, b, c])
        xnd = tvm.nd.array(np.array(np.ones((n,n,n), dtype=a.dtype)))
        ynd = tvm.nd.array(np.array(np.ones((n,n,n), dtype=b.dtype)))
        znd = tvm.nd.array(np.zeros((n,n,n), dtype=c.dtype))
        f(xnd, ynd, znd)
        assert(np.array_equal(znd.asnumpy(), xnd.asnumpy() + ynd.asnumpy()))

    for t in ["float64", "float32", "int64", "int32"]:
        check_expr_sum(t)

def test_trace_expr_sum_args():
    @tvm.register_func("tvm.trace_silent")
    def silent(*args):
      return

    def check_expr_sum(dtype):
        n = 4
        a = tvm.placeholder((n,n,n), name="a", dtype=dtype)
        b = tvm.placeholder((n,n,n), name="b", dtype=dtype)
        e = tvm.placeholder((n,n,n), name="e", dtype=dtype)
        d = tvm.placeholder((n,n,n), name="d", dtype=dtype)

        c = tvm.compute(a.shape, lambda i, j, k: tvm.trace([i, j, k, a[i][j][k]], "tvm.trace_silent")
                                               + tvm.trace([i, j, k, b[i][j][k]], "tvm.trace_silent")
                                               + tvm.trace([i, j, k, d[i][j][k]], "tvm.trace_silent")
                                               + tvm.trace([i, j, k, e[i][j][k]], "tvm.trace_silent"))
        s = tvm.create_schedule(c.op)
        f = tvm.build(s, [a, b, d, e, c])
        a_nd = tvm.nd.array(np.array(np.ones((n,n,n), dtype=a.dtype)))
        b_nd = tvm.nd.array(np.array(np.ones((n,n,n), dtype=b.dtype)))
        d_nd = tvm.nd.array(np.array(np.ones((n,n,n), dtype=d.dtype)))
        e_nd = tvm.nd.array(np.array(np.ones((n,n,n), dtype=e.dtype)))
        c_nd = tvm.nd.array(np.zeros((n,n,n), dtype=c.dtype))
        f(a_nd, b_nd, d_nd, e_nd, c_nd)
        assert(np.array_equal(c_nd.asnumpy(), a_nd.asnumpy()
                                            + b_nd.asnumpy()
                                            + d_nd.asnumpy()
                                            + e_nd.asnumpy()))

    for t in ["float64", "float32", "int64", "int32"]:
        check_expr_sum(t)

def test_trace_expr_sum_custom():
    @tvm.register_func("tvm.trace_callback4")
    def trace_buffer(x):
        return

    def check_expr_sum_custom(dtype):
        n = 4
        a = tvm.placeholder((n,n), name="a", dtype=dtype)
        b = tvm.placeholder((n,n), name="b", dtype=dtype)
        c = tvm.compute(a.shape, lambda i,j: tvm.trace([a[i][j]], "tvm.trace_callback4")
                                         + tvm.trace([b[i][j]], "tvm.trace_callback4"))
        s = tvm.create_schedule(c.op)
        f = tvm.build(s, [a, b, c])
        npa = np.array([[1,0,0,0], [0,1,0,0],[0,0,1,0],[0,0,0,1]], dtype=a.dtype)
        npb = np.array([[1,0,0,0], [0,1,0,0],[0,0,1,0],[0,0,0,1]], dtype=a.dtype)
        xnd = tvm.nd.array(npa)
        ynd = tvm.nd.array(npb)
        znd = tvm.nd.array(np.zeros((n,n), dtype=c.dtype))
        f(xnd, ynd, znd)
        assert(np.array_equal(znd.asnumpy(), npa + npb))

    for t in ["float64", "float32", "int64", "int32"]:
        check_expr_sum_custom(t)

def test_trace_can_change_traced_value_int():
    @tvm.register_func("tvm.trace_change_int_first")
    def trace_buffer(x):
        return 13

    @tvm.register_func("tvm.trace_change_int_second")
    def trace_buffer(x):
        return 14

    def check_assign(dtype):
        n = 4
        x = tvm.placeholder((n,), name="X", dtype=dtype)
        y = tvm.compute(x.shape, lambda i: tvm.trace([x[i]], "tvm.trace_change_int_first"))
        z = tvm.compute(x.shape, lambda i: tvm.trace([y[i]], "tvm.trace_change_int_second"))
        s = tvm.create_schedule(z.op)
        f = tvm.build(s, [x, y, z], "llvm")

        xnd = tvm.nd.array(np.ones((n,), dtype=x.dtype))
        ynd = tvm.nd.array(np.zeros((n,), dtype=y.dtype))
        znd = tvm.nd.array(np.zeros((n,), dtype=z.dtype))
        f(xnd, ynd, znd)
        check_array_first = np.array([13, 13, 13, 13])
        check_array_second = np.array([14, 14, 14, 14])
        assert(np.array_equal(ynd.asnumpy(), check_array_first))
        assert(np.array_equal(znd.asnumpy(), check_array_second))

    for t in ["int64", "int32"]:
        check_assign(t)

def test_trace_can_change_traced_value_float():
    @tvm.register_func("tvm.trace_change_float_first")
    def trace_buffer(x):
        return 13.0

    @tvm.register_func("tvm.trace_change_float_second")
    def trace_buffer(x):
        return 14.0

    def check_assign(dtype):
        n = 4
        x = tvm.placeholder((n,), name="X", dtype=dtype)
        y = tvm.compute(x.shape, lambda i: tvm.trace([x[i]], "tvm.trace_change_float_first"))
        z = tvm.compute(x.shape, lambda i: tvm.trace([y[i]], "tvm.trace_change_float_second"))
        s = tvm.create_schedule(z.op)
        f = tvm.build(s, [x, y, z], "llvm")

        xnd = tvm.nd.array(np.ones((n,), dtype=x.dtype))
        ynd = tvm.nd.array(np.zeros((n,), dtype=y.dtype))
        znd = tvm.nd.array(np.zeros((n,), dtype=z.dtype))
        f(xnd, ynd, znd)
        check_array_first = np.array([13.0, 13.0, 13.0, 13.0])
        check_array_second = np.array([14.0, 14.0, 14.0, 14.0])
        assert(np.array_equal(ynd.asnumpy(), check_array_first))
        assert(np.array_equal(znd.asnumpy(), check_array_second))

    for t in ["float64", "float32"]:
        check_assign(t)

if __name__ == "__main__":
    test_empty_array()
    test_get_global()
    test_get_callback_with_node()
    test_convert()
    test_return_func()
    test_byte_array()
    test_ctx()
    test_trace_expr_assign()
    test_trace_expr_sum_generated()
    test_trace_expr_sum_custom()
    test_trace_expr_sum_args()
    test_trace_default_action()
    test_trace_can_change_traced_value_int()
    test_trace_can_change_traced_value_float()