# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""Unit tests for relay pass manager."""
import numpy as np
import pytest

import tvm
from tvm import relay
from tvm.relay import ExprFunctor
from tvm.relay import Function, Call
from tvm.relay import analysis
from tvm.relay import transform as _transform
from tvm.relay.testing import ctx_list


def run_infer_type(expr):
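    """Run type inference on expr and return the type-checked expression."""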
    mod = relay.Module.from_expr(expr)
    mod = _transform.InferType()(mod)
    entry = mod["main"]
    return entry if isinstance(expr, relay.Function) else entry.body


def get_var_func():
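    """Create a GlobalVar and a function that computes abs(x)."""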
    shape = (5, 10)
    tp = relay.TensorType(shape, "float32")
    x = relay.var("x", tp)
    gv = relay.GlobalVar("myAbs")
    func = relay.Function([x], relay.abs(x))
    return gv, func


def extract_var_func(mod, name):
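    """Return the GlobalVar and function bound to name in the module."""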
    var = mod.get_global_var(name)
    func = mod[var]
    return var, func


def update_func(func):
    # Double the values of constants and variables.
    class DoubleValues(ExprFunctor):
        def __init__(self):
            ExprFunctor.__init__(self)

        def visit_constant(self, const):
            return relay.add(const, const)

        def visit_var(self, var):
            return relay.add(var, var)

        def visit_call(self, call):
            new_op = self.visit(call.op)
            new_args = [self.visit(arg) for arg in call.args]
            return Call(new_op, new_args, call.attrs)

        def visit_global_var(self, gvar):
            return gvar

        def visit_op(self, op):
            return op

        def visit_function(self, fn):
            new_body = self.visit(fn.body)
            return Function(
                list(fn.params), new_body, fn.ret_type, fn.type_params,
                fn.attrs)

    double_value = DoubleValues()
    return double_value.visit(func)


class OptTester:
    """A helper class for testing the pass manager."""

    def __init__(self, mod):
        if not isinstance(mod, relay.Module):
            raise TypeError("mod is expected to be of type relay.Module")
        self.mod = mod

    def analysis(self):
        """Perform analysis for the current module."""
        pass

    @staticmethod
    def transform(node, ctx=None):
        """Perform optimization on node."""
        if isinstance(node, relay.Module):
            # Add a function to the module and return an updated module.
            gv, func = get_var_func()
            mod = relay.Module({gv: func})
            mod.update(node)
            return mod
        if isinstance(node, relay.Function):
            return update_func(node)

        raise TypeError("Unsupported node type: %s" % type(node))


def get_rand(shape, dtype='float32'):
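    """Create a random NDArray with the given shape and dtype."""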
    return tvm.nd.array(np.random.rand(*shape).astype(dtype))


def check_func(func, ref_func):
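    """Check that func and ref_func are graph-equal after type inference."""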
    func = run_infer_type(func)
    ref_func = run_infer_type(ref_func)
    assert analysis.graph_equal(func, ref_func)


def test_module_pass():
    shape = (5, 10)
    dtype = 'float32'
    tp = relay.TensorType(shape, dtype)
    x = relay.var("x", tp)
    y = relay.var("y", tp)
    v_add = relay.GlobalVar("myAdd")
    func = relay.Function([x, y], x + y)
    mod = relay.Module({v_add: func})

    pass_name = "module_pass_test"
    opt_level = 0
    opt_tester = OptTester(mod)
    pass_ctx = None

    @_transform.module_pass(opt_level=opt_level, name=pass_name)
    def transform(expr, ctx):
        return opt_tester.transform(expr, ctx)

    def test_pass_registration():
        mod_pass = transform
        assert isinstance(mod_pass, _transform.ModulePass)
        pass_info = mod_pass.info
        assert pass_info.name == pass_name
        assert pass_info.opt_level == opt_level

    def test_pass_registration_no_decorator():
        def direct_transform(expr, ctx):
            return opt_tester.transform(expr, ctx)
        mod_pass = _transform.module_pass(direct_transform, opt_level=3)
        assert isinstance(mod_pass, _transform.ModulePass)
        pass_info = mod_pass.info
        assert pass_info.name == "direct_transform"
        assert pass_info.opt_level == 3

    def test_pass_run():
        module_pass = transform
        assert pass_name in module_pass.astext()

        updated_mod = module_pass(mod)
        assert isinstance(updated_mod, relay.Module)

        # Check the abs function in the updated module.
        v_abs, myabs = get_var_func()
        new_v_abs = updated_mod.get_global_var(v_abs.name_hint)
        new_abs = updated_mod[new_v_abs]
        check_func(new_abs, myabs)

        # Check the add function in the updated module.
        new_v_add = updated_mod.get_global_var(v_add.name_hint)
        new_add = updated_mod[new_v_add]
        check_func(new_add, func)

        # Check the add function in the python transformed module.
        ret = opt_tester.transform(mod, pass_ctx)
        transformed_v_add = ret.get_global_var(v_add.name_hint)
        transformed_add = mod[transformed_v_add]
        check_func(new_add, transformed_add)

        # Execute the add function.
        x_nd = get_rand(shape, dtype)
        y_nd = get_rand(shape, dtype)
        ref_res = x_nd.asnumpy() + y_nd.asnumpy()
        for target, ctx in ctx_list():
            exe1 = relay.create_executor("graph", ctx=ctx, target=target)
            exe2 = relay.create_executor("debug", ctx=ctx, target=target)
            res1 = exe1.evaluate(new_add)(x_nd, y_nd)
            tvm.testing.assert_allclose(res1.asnumpy(), ref_res, rtol=1e-5)
            res2 = exe2.evaluate(new_add)(x_nd, y_nd)
            tvm.testing.assert_allclose(res2.asnumpy(), ref_res, rtol=1e-5)

    test_pass_registration()
    test_pass_registration_no_decorator()
    test_pass_run()


def test_function_class_pass():
    @relay.transform.function_pass(opt_level=1)
    class TestReplaceFunc:
        """Simple test function to replace one argument to another."""
        def __init__(self, new_func):
            self.new_func = new_func

        def transform_function(self, func, mod, ctx):
            return self.new_func

    x = relay.var("x", shape=(10, 20))
    f1 = relay.Function([x], x)
    f2 = relay.Function([x], relay.log(x))
    fpass = TestReplaceFunc(f1)
    assert fpass.info.opt_level == 1
    assert fpass.info.name == "TestReplaceFunc"
    mod = relay.Module.from_expr(f2)
    mod = fpass(mod)
    # Wrap f1 in a module for comparison.
    mod2 = relay.Module.from_expr(f1)
    assert relay.alpha_equal(mod["main"], mod2["main"])


def test_function_pass():
    shape = (10, )
    dtype = 'float32'
    tp = relay.TensorType(shape, dtype)
    x = relay.var("x", tp)
    v_log = relay.GlobalVar("myLog")
    log = relay.Function([x], relay.log(x))
    mod = relay.Module({v_log: log})

    pass_name = "function_pass_test"
    opt_level = 1
    opt_tester = OptTester(mod)
    pass_ctx = None

    @_transform.function_pass(opt_level=opt_level, name=pass_name)
    def transform(expr, mod, ctx):
        return opt_tester.transform(expr, ctx)

    def get_ref_log():
        ref_log = relay.Function([x], relay.log(relay.add(x, x)))
        return ref_log

    def test_pass_registration():
        function_pass = transform
        assert isinstance(function_pass, _transform.FunctionPass)
        pass_info = function_pass.info
        assert pass_info.name == pass_name
        assert pass_info.opt_level == opt_level

    def test_pass_registration_no_decorator():
        def direct_transform(expr, ctx):
            return opt_tester.transform(expr, ctx)
        mod_pass = _transform.function_pass(direct_transform, opt_level=0)
        assert isinstance(mod_pass, _transform.FunctionPass)
        pass_info = mod_pass.info
        assert pass_info.name == "direct_transform"
        assert pass_info.opt_level == 0

    def test_pass_run():
        function_pass = transform
        assert pass_name in function_pass.astext()

        updated_mod = function_pass(mod)
        assert isinstance(updated_mod, relay.Module)

        # Check the log function in the updated module.
        new_v_log = updated_mod.get_global_var(v_log.name_hint)
        new_log = updated_mod[new_v_log]
        check_func(new_log, get_ref_log())

        # Check the log function in the python transformed function.
        ret = opt_tester.transform(log, pass_ctx)
        check_func(new_log, ret)

        # Execute the log function.
        x_nd = get_rand(shape, dtype)
        ref_res = np.log(x_nd.asnumpy() * 2)
        for target, ctx in ctx_list():
            exe1 = relay.create_executor("graph", ctx=ctx, target=target)
            exe2 = relay.create_executor("debug", ctx=ctx, target=target)
            res1 = exe1.evaluate(new_log)(x_nd)
            tvm.testing.assert_allclose(res1.asnumpy(), ref_res, rtol=1e-5)
            res2 = exe2.evaluate(new_log)(x_nd)
            tvm.testing.assert_allclose(res2.asnumpy(), ref_res, rtol=1e-5)

    test_pass_registration()
    test_pass_registration_no_decorator()
    test_pass_run()


def test_module_class_pass():
    @relay.transform.module_pass(opt_level=1)
    class TestPipeline:
        """Simple test function to replace one argument to another."""
        def __init__(self, new_mod, replace):
            self.new_mod = new_mod
            self.replace = replace

        def transform_module(self, mod, ctx):
            if self.replace:
                return self.new_mod
            return mod

    x = relay.var("x", shape=(10, 20))
    m1 = relay.Module.from_expr(relay.Function([x], x))
    m2 = relay.Module.from_expr(relay.Function([x], relay.log(x)))
    fpass = TestPipeline(m2, replace=True)
    assert fpass.info.name == "TestPipeline"
    mod3 = fpass(m1)
    assert mod3.same_as(m2)
    mod4 = TestPipeline(m2, replace=False)(m1)
    assert mod4.same_as(m1)


def test_pass_info():
    info = relay.transform.PassInfo(opt_level=1, name="xyz")
    assert info.opt_level == 1
    assert info.name == "xyz"


def test_sequential_pass():
    shape = (10, )
    dtype = 'float32'
    tp = relay.TensorType(shape, dtype)
    x = relay.var("x", tp)
    y = relay.var("y", tp)
    v_sub = relay.GlobalVar("mySub")
    sub = relay.Function([x, y], relay.subtract(x, y))

    z = relay.var("z", tp)
    v_log = relay.GlobalVar("myLog")
    log = relay.Function([z], relay.log(z))

    mod = relay.Module({v_sub: sub, v_log: log})

    def get_ref_log():
        ref_log = relay.Function([x], relay.log(relay.add(x, x)))
        return ref_log

    def get_ref_sub():
        ref_sub = relay.Function([x, y],
                                 relay.subtract(
                                     relay.add(x, x), relay.add(y, y)))
        return ref_sub

    def get_ref_abs():
        shape = (5, 10)
        tp = relay.TensorType(shape, "float32")
        a = relay.var("a", tp)
        ref_abs = relay.Function([a], relay.abs(relay.add(a, a)))
        return ref_abs

    # Register a module pass.
    opt_tester = OptTester(mod)
    pass_ctx = None

    @_transform.module_pass(opt_level=1)
    def mod_transform(expr, ctx):
        return opt_tester.transform(expr, ctx)

    module_pass = mod_transform

    # Register a function pass.
    @_transform.function_pass(opt_level=1)
    def func_transform(expr, mod, ctx):
        return opt_tester.transform(expr, ctx)

    function_pass = func_transform

    def test_pass_registration():
        passes = [module_pass, function_pass]
        opt_level = 2
        pass_name = "sequential"
        sequential = _transform.Sequential(passes=passes, opt_level=opt_level)
        pass_info = sequential.info
        assert pass_info.name == pass_name
        assert pass_info.opt_level == opt_level

    def test_no_pass():
        passes = []
        sequential = _transform.Sequential(opt_level=1, passes=passes)
        ret_mod = sequential(mod)
        mod_func = ret_mod[v_sub]
        check_func(sub, mod_func)

    def test_only_module_pass():
        passes = [module_pass]
        sequential = _transform.Sequential(opt_level=1, passes=passes)
        with relay.build_config(required_pass=["mod_transform"]):
            ret_mod = sequential(mod)
        # Check the subtract function.
        sub_var, new_sub = extract_var_func(ret_mod, v_sub.name_hint)
        check_func(new_sub, sub)

        # Check the abs function is added.
        abs_var, abs_func = get_var_func()
        abs_var, new_abs = extract_var_func(ret_mod, abs_var.name_hint)
        check_func(new_abs, abs_func)

    def test_only_function_pass():
        # Check the subtract function.
        passes = [function_pass]
        sequential = _transform.Sequential(opt_level=1, passes=passes)
        with relay.build_config(required_pass=["func_transform"]):
            ret_mod = sequential(mod)
        _, new_sub = extract_var_func(ret_mod, v_sub.name_hint)
        check_func(new_sub, get_ref_sub())

        # Check the log function.
        log_var, new_log = extract_var_func(ret_mod, v_log.name_hint)
        check_func(new_log, get_ref_log())

    def test_multiple_passes():
        # Reset the current module since mod has been polluted by the previous
        # function pass.
        mod = relay.Module({v_sub: sub, v_log: log})
        passes = [module_pass, function_pass]
        sequential = _transform.Sequential(opt_level=1, passes=passes)
        required = ["mod_transform", "func_transform"]
        with relay.build_config(required_pass=required):
            ret_mod = sequential(mod)

        # Check the abs function is added.
        abs_var, abs_func = get_var_func()
        abs_var, new_abs = extract_var_func(ret_mod, abs_var.name_hint)
        check_func(new_abs, get_ref_abs())

        # Check the subtract function is modified correctly.
        _, new_sub = extract_var_func(ret_mod, v_sub.name_hint)
        check_func(new_sub, get_ref_sub())

        # Check the log function is modified correctly.
        _, new_log = extract_var_func(ret_mod, v_log.name_hint)
        check_func(new_log, get_ref_log())

        # Execute the updated subtract function.
        x_nd = get_rand(shape, dtype)
        y_nd = get_rand(shape, dtype)
        ref_res = np.subtract(x_nd.asnumpy() * 2, y_nd.asnumpy() * 2)
        for target, ctx in ctx_list():
            exe1 = relay.create_executor("graph", ctx=ctx, target=target)
            exe2 = relay.create_executor("debug", ctx=ctx, target=target)
            res1 = exe1.evaluate(new_sub)(x_nd, y_nd)
            tvm.testing.assert_allclose(res1.asnumpy(), ref_res, rtol=1e-5)
            res2 = exe2.evaluate(new_sub)(x_nd, y_nd)
            tvm.testing.assert_allclose(res2.asnumpy(), ref_res, rtol=1e-5)

        # Execute the updated abs function.
        x_nd = get_rand((5, 10), dtype)
        ref_res = np.abs(x_nd.asnumpy() * 2)
        for target, ctx in ctx_list():
            exe1 = relay.create_executor("graph", ctx=ctx, target=target)
            exe2 = relay.create_executor("debug", ctx=ctx, target=target)
            res1 = exe1.evaluate(new_abs)(x_nd)
            tvm.testing.assert_allclose(res1.asnumpy(), ref_res, rtol=1e-5)
            res2 = exe2.evaluate(new_abs)(x_nd)
            tvm.testing.assert_allclose(res2.asnumpy(), ref_res, rtol=1e-5)

    test_pass_registration()
    test_no_pass()
    test_only_module_pass()
    test_only_function_pass()
    test_multiple_passes()


def test_sequential_with_scoping():
    shape = (1, 2, 3)
    c_data = np.array(shape).astype("float32")
    tp = relay.TensorType(shape, "float32")
    def before():
        c = relay.const(c_data)
        x = relay.var("x", tp)
        y = relay.add(c, c)
        y = relay.multiply(y, relay.const(2, "float32"))
        y = relay.add(x, y)
        z = relay.add(y, c)
        z1 = relay.add(y, c)
        z2 = relay.add(z, z1)
        return relay.Function([x], z2)

    def expected():
        x = relay.var("x", tp)
        c_folded = (c_data + c_data) * 2
        y = relay.add(x, relay.const(c_folded))
        z = relay.add(y, relay.const(c_data))
        z1 = relay.add(z, z)
        return relay.Function([x], z1)

    seq = _transform.Sequential([
        relay.transform.InferType(),
        relay.transform.FoldConstant(),
        relay.transform.EliminateCommonSubexpr(),
        relay.transform.AlterOpLayout()
    ])

    mod = relay.Module({"main": before()})
    with relay.build_config(opt_level=3):
        with tvm.target.create("llvm"):
            mod = seq(mod)

    zz = mod["main"]
    zexpected = run_infer_type(expected())
    assert analysis.alpha_equal(zz, zexpected)


def test_print_ir(capfd):
    shape = (1, 2, 3)
    tp = relay.TensorType(shape, "float32")
    x = relay.var("x", tp)
    y = relay.add(x, x)
    y = relay.multiply(y, relay.const(2, "float32"))
    func = relay.Function([x], y)

    seq = _transform.Sequential([
        relay.transform.InferType(),
        relay.transform.FoldConstant(),
        relay.transform.PrintIR(),
        relay.transform.DeadCodeElimination()
    ])

    mod = relay.Module({"main": func})
    with relay.build_config(opt_level=3):
        mod = seq(mod)

    out = capfd.readouterr().err

    assert "Dumping the module IR" in out
    assert "multiply" in out


if __name__ == "__main__":
    pytest.main()