# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
from tvm import te
from tvm import relay
from tvm.relay import Function, transform
from tvm.relay.analysis import alpha_equal, graph_equal, free_vars, assert_alpha_equal
from tvm.relay.op import log, add, equal, subtract
from tvm.relay.testing import inception_v3

import pytest

class env:
    """Shared IR fixture: a handful of types, constants, and typed variables
    used by every DCE test in this file."""

    def __init__(self):
        self.shape = tvm.runtime.convert([1, 2, 3])
        self.tt = relay.TensorType(self.shape, "float32")
        self.int32 = relay.TensorType([], "int32")
        self.float32 = relay.TensorType([], "float32")
        # Scalar float constants.
        self.one, self.two, self.three = (
            relay.const(v) for v in (1.0, 2.0, 3.0))
        # Float32 variables a..e.
        for name in "abcde":
            setattr(self, name, relay.Var(name, self.float32))
        # Int32 variables x..z.
        for name in "xyz":
            setattr(self, name, relay.Var(name, self.int32))


e = env()  # module-level fixture shared by all tests below


def run_opt_pass(expr, opt_pass):
    """Apply ``opt_pass`` to ``expr`` via a temporary module.

    Returns the transformed ``main`` function when ``expr`` was already a
    Function, otherwise just its body (the transformed expression).
    """
    assert isinstance(opt_pass, transform.Pass)
    mod = opt_pass(tvm.IRModule.from_expr(expr))
    main = mod["main"]
    if isinstance(expr, relay.Function):
        return main
    return main.body


def test_let():
    """A let whose bound variable is never used is removed entirely."""
    orig = relay.Let(e.x, e.y, e.z)
    orig = run_opt_pass(orig, transform.DeadCodeElimination())
    # Use assert_alpha_equal (as test_inline/test_recursion do) so a failure
    # reports the differing IR instead of a bare `assert False`.
    assert_alpha_equal(Function(free_vars(orig), orig), Function([e.z], e.z))


def test_used_let():
    """A let whose bound variable IS used must be preserved by DCE."""
    before = relay.Let(e.c, e.one, e.c + e.c)
    after = run_opt_pass(before, transform.DeadCodeElimination())
    unchanged = relay.Let(e.c, e.one, e.c + e.c)
    assert alpha_equal(Function([e.c], after), Function([e.c], unchanged))

def test_inline():
    """With inlining enabled, chained let-bindings collapse to their source."""
    chained = relay.Let(e.a, e.b, relay.Let(e.c, e.d, e.c))
    inlined = run_opt_pass(chained, transform.DeadCodeElimination(True))
    expected = Function([e.d], e.d)
    assert_alpha_equal(Function(free_vars(inlined), inlined), expected)


def test_chain_unused_let():
    """Both lets bind variables that are never read, so both disappear."""
    chained = relay.Let(e.a, e.b, relay.Let(e.c, e.d, e.e))
    result = run_opt_pass(chained, transform.DeadCodeElimination())
    expected = Function([e.e], e.e)
    assert alpha_equal(Function(free_vars(result), result), expected)


def use_f(func):
    """Build ``let f = <recursive log-reduction>; func(f)``.

    ``f(n, data)`` returns ``data`` when ``n == 0`` and otherwise recurses
    as ``f(n - 1, log(data))``.  ``func`` receives the bound variable ``f``
    and produces the let body.
    """
    f = relay.Var("f")
    n = relay.Var("n", e.int32)
    data = relay.Var("data", e.float32)
    recursive_call = relay.Call(
        f, [subtract(n, relay.const(1)), log(data)])
    body = relay.If(equal(n, relay.const(0)), data, recursive_call)
    fn = relay.Function([n, data], body, e.float32, [])
    return relay.Let(f, fn, func(f))

# make sure we dont infinite loop
def test_recursion():
    """
    Program:
       let f(n: i32, data: f32) -> f32 = {
          if (n == 0) {
              return data;
          } else {
              return f(n - 1, log(data));
          }
       }
       f(2, 10000);
    """
    # Everything in the program is live, so DCE must keep it all —
    # this also checks that the recursive binding does not loop forever.
    program = use_f(
        lambda f: relay.Call(f, [relay.const(2), relay.const(10000.0)]))
    after_dce = run_opt_pass(program, transform.DeadCodeElimination())
    # DCE also type-infers, so compare against the type-inferred original.
    typed = run_opt_pass(program, transform.InferType())
    assert_alpha_equal(after_dce, typed)

def test_recursion_dead():
    """The recursive function is never called, so DCE strips it (and the
    inner unused let) down to the constant body."""
    body = relay.Let(e.a, e.one, e.three)
    dced = run_opt_pass(use_f(lambda _f: body),
                        transform.DeadCodeElimination())
    assert alpha_equal(dced, e.three)


def test_op_let():
    """A dead let nested inside an operator argument is eliminated."""
    expr = add(relay.Let(e.a, e.one, e.three), e.two)
    dced = run_opt_pass(expr, transform.DeadCodeElimination())
    assert alpha_equal(dced, add(e.three, e.two))


def test_tuple_get_item():
    """TupleGetItem is preserved; a dead let wrapping the tuple is removed."""
    tup_ty = relay.TupleType([e.float32, e.float32])
    t = relay.Var('t', tup_ty)
    a = relay.Var('a')
    g = relay.TupleGetItem(t, 0)
    expected = Function(free_vars(g), g)
    # No dead code: the expression survives unchanged.
    dced = run_opt_pass(g, transform.DeadCodeElimination())
    assert alpha_equal(Function(free_vars(dced), dced), expected)
    # A dead binding around the tuple source is stripped.
    wrapped = relay.TupleGetItem(relay.Let(a, e.one, t), 0)
    dced = run_opt_pass(wrapped, transform.DeadCodeElimination())
    assert alpha_equal(Function(free_vars(dced), dced), expected)


@pytest.mark.timeout(timeout=10, method="thread")
def test_complexity():
    """DCE finishes within the timeout on a large real network (Inception v3),
    guarding against accidental exponential behavior."""
    net = inception_v3.get_net(1, 1000, (3, 299, 299), 'float32')
    run_opt_pass(net, transform.DeadCodeElimination())


if __name__ == "__main__":
    # Run every test in the same order as before when executed as a script.
    for _case in (
            test_let,
            test_used_let,
            test_inline,
            test_chain_unused_let,
            test_recursion,
            test_recursion_dead,
            test_op_let,
            test_tuple_get_item,
            test_complexity,
    ):
        _case()