# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

import numpy as np
import tvm
from tvm import te
from tvm import relay
from tvm.relay.analysis import alpha_equal, assert_alpha_equal
from tvm.relay.prelude import Prelude
from tvm.relay import op, create_executor, transform
from tvm.relay import Var, TypeVar, TupleGetItem, Let, Function, const, RefRead, RefWrite, RefCreate
from tvm.relay import TensorType, Tuple, If, Clause, PatternConstructor, PatternVar, Match
from tvm.relay import GlobalVar, Call
from tvm.relay.transform import gradient
from tvm.relay.testing import add_nat_definitions, make_nat_expr, run_infer_type

def check_eval(expr, expected_result, mod=None, rtol=1e-07):
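    """Evaluate expr on the LLVM backend and compare the result against expected_result."""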
    ctx = tvm.context("llvm", 0)
    intrp = create_executor(mod=mod, ctx=ctx, target="llvm")

    result = intrp.evaluate(expr)
    np.testing.assert_allclose(result.asnumpy(), expected_result, rtol=rtol)


def run_opt_pass(expr, passes):
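    """Run the given pass(es) on expr at opt_level=3 and return the transformed expression."""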
    passes = passes if isinstance(passes, list) else [passes]
    mod = tvm.IRModule.from_expr(expr)
    seq = transform.Sequential(passes)
    with transform.PassContext(opt_level=3):
        mod = seq(mod)
    entry = mod["main"]
    return entry if isinstance(expr, relay.Function) else entry.body


def tipe(expr):
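    """Partially evaluate expr, then re-run type inference."""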
    return run_opt_pass(expr, [transform.PartialEvaluate(),
                               transform.InferType()])


def dcpe(expr, mod=None, grad=False):
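    """Run partial evaluation followed by dead code elimination.

    If grad is set, the gradient of expr is taken first; if mod is given, the
    passes run at module level so that global definitions remain visible.
    """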
    passes = [transform.PartialEvaluate(),
              transform.DeadCodeElimination(inline_once=True)]
    if grad:
        expr = gradient(run_infer_type(expr))
    if mod:
        assert isinstance(expr, Function)
        mod["main"] = expr
        seq = transform.Sequential(passes)
        mod = seq(mod)
        return mod["main"]
    return run_opt_pass(expr, passes)


def test_tuple():
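    """Projecting a known field out of a tuple should reduce to the field itself."""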
    t = TypeVar("t")
    x = Var("x", t)
    body = TupleGetItem(relay.Tuple([relay.const(4.0), x]), 1)
    f = Function([x], body, None, [t])
    expected = relay.Function([x], x, None, [t])
    expected = run_opt_pass(expected, transform.InferType())
    assert alpha_equal(dcpe(f), expected)


def test_const_inline():
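    """Applying a function to a constant argument should fold to a constant."""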
    t = relay.TensorType([], "float32")
    d = Var("d", t)
    double = Function([d], d + d)
    orig = double(const(4.0))
    assert alpha_equal(dcpe(orig), const(8.0))


def test_ref():
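    """Reads and writes on a locally created reference should be eliminated."""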
    t = relay.TensorType([], "float32")
    d = relay.Var("d", t)
    r = relay.Var("r", relay.RefType(t))
    x = relay.Var("x")
    body = relay.RefRead(r)
    body = Let(x, RefWrite(r, RefRead(r) * RefRead(r)), body)
    body = Let(r, RefCreate(d), body)
    square = Function([d], body)
    expected = run_opt_pass(Function([d], d * d), transform.InferType())
    assert alpha_equal(dcpe(square), expected)


def test_empty_ad():
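    """The gradient of the identity function simplifies to a ones-like adjoint."""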
    shape = (10, 10)
    dtype = "float32"
    t = TensorType(shape, dtype)
    d = Var("d", t)
    f = Function([d], d)
    g = dcpe(f, grad=True)
    expected = Function([d], Tuple([d, Tuple([op.ones_like(d)])]))
    expected = run_opt_pass(expected, transform.InferType())
    assert alpha_equal(g, expected)


def test_ad():
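    """The gradient of d * d partially evaluates to the expected closed form."""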
    shape = (10, 10)
    dtype = "float32"
    t = TensorType(shape, dtype)
    d = Var("d", t)
    f = Function([d], d * d)
    g = dcpe(f, grad=True)
    m = d * d
    x = relay.Var("x")
    o = op.ones_like(x)
    x1 = relay.Var("x1")
    grad = op.zeros_like(d) + op.collapse_sum_like(x1 * d, d) + op.collapse_sum_like(x1 * d, d)
    body = Tuple([x, Tuple([grad])])
    body = relay.Let(x1, o, body)
    expected = Function([d], relay.Let(x, m, body))
    expected = run_opt_pass(expected, transform.InferType())
    assert_alpha_equal(g, expected)


def test_if_ref():
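    """A reference doubled through a closure called in both If branches still evaluates to 2."""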
    shape = ()
    dtype = "bool"
    t = TensorType(shape, dtype)
    d = Var("d", t)
    r = Var("r")
    update = Function([], RefWrite(r, RefRead(r) + RefRead(r)))
    u = Var("u")
    body = If(d, u(), u())
    eff = Var("eff")
    body = Let(eff, body, RefRead(r))
    f = Function([d], Let(r, RefCreate(const(1)), Let(u, update, body)))
    pe_f = tipe(f)
    ex = create_executor()
    f_res = ex.evaluate(f)(const(True))
    pe_f_res = ex.evaluate(pe_f)(const(True))
    np.testing.assert_allclose(f_res.asnumpy(), 2 * np.ones_like(f_res.asnumpy()))
    np.testing.assert_allclose(pe_f_res.asnumpy(), 2 * np.ones_like(pe_f_res.asnumpy()))


def test_function_invalidate():
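    """After a reference write, calling the obscured fetch closure must observe the new value."""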
    shape = ()
    dtype = "bool"
    t = TensorType(shape, dtype)
    d = Var("d", t)
    r = Var("r")
    fetch = Function([], RefRead(r))
    fet = Var("fetch")
    fet_obscured = Var("fetch_obscured")
    u = Var("u")
    body = If(d, fet_obscured(), fet_obscured())
    body = Let(u, RefWrite(r, const(1)), body)
    body = Let(fet_obscured, If(d, fet, fet), body)
    body = Let(fet, fetch, body)
    body = Let(r, RefCreate(const(0)), body)
    f = Function([d], body)
    pe_f = tipe(f)
    ex = create_executor()
    f_res = ex.evaluate(f)(const(True))
    pe_f_res = ex.evaluate(pe_f)(const(True))
    np.testing.assert_allclose(f_res.asnumpy(), np.ones_like(f_res.asnumpy()))
    np.testing.assert_allclose(pe_f_res.asnumpy(), np.ones_like(pe_f_res.asnumpy()))


def test_head_cons():
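    """hd(cons(x, nil)) should reduce to x."""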
    mod = tvm.IRModule()
    p = Prelude(mod)
    hd = p.hd
    t = TypeVar("t")
    x = Var("x", t)
    body = hd(p.cons(x, p.nil()))
    f = Function([x], body, None, [t])
    res = dcpe(f, mod)
    assert alpha_equal(res, Function([x], x, t, [t]))


def test_map():
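    """Mapping the identity function over a list should leave the list unchanged."""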
    mod = tvm.IRModule()
    p = Prelude(mod)
    f = GlobalVar("f")
    t = TypeVar("t")
    a = Var("a", t)
    mod[f] = Function([a], a, t, [t])
    orig = p.map(f, p.cons(const(1), p.cons(const(2), p.cons(const(3), p.nil()))))
    expected = p.cons((const(1)), p.cons((const(2)), p.cons((const(3)), p.nil())))
    expected = Function([], expected)
    mod["main"] = expected
    expected = mod["main"]
    orig = Function([], orig)
    res = dcpe(orig, mod=mod)
    assert alpha_equal(res.body, expected.body)


def test_loop():
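    """A non-terminating loop is left as a call rather than unrolled."""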
    mod = tvm.IRModule()
    t = TypeVar("t")
    x = Var("x", t)
    loop = GlobalVar("loop")
    mod[loop] = Function([x], loop(x), t, [t])
    expected = Call(loop, [const(1)])
    mod["main"] = Function([], expected)
    expected = mod["main"].body
    call = Function([], loop(const(1)))
    res = dcpe(call, mod=mod)
    assert alpha_equal(res.body, expected)


def test_swap_loop():
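    """A loop that endlessly swaps its arguments is not reduced; the original call is preserved."""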
    mod = tvm.IRModule()
    p = Prelude(mod)
    add_nat_definitions(p)
    nat = p.nat()
    x = Var("x", nat)
    y = Var("y", nat)
    loop = GlobalVar("loop")
    mod[loop] = Function([x, y], loop(y, x), nat)
    prog = loop(make_nat_expr(p, 1), make_nat_expr(p, 2))
    res = Function([], prog)
    res = dcpe(res, mod=mod)
    assert alpha_equal(prog, res.body)


def test_abs_diff():
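    """Absolute difference on Peano naturals evaluates to a constant nat."""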
    # TODO(@M.K.): refactor using tuple pattern (not yet implemented)
    mod = tvm.IRModule()
    p = Prelude(mod)
    add_nat_definitions(p)
    nat = p.nat()
    x = Var("x", nat)
    y = Var("y", nat)
    xp = Var("x'", nat)
    yp = Var("y'", nat)
    diff = GlobalVar("diff")
    y_z_case = Clause(PatternConstructor(p.z, []), x)
    y_s_case = Clause(PatternConstructor(p.s, [PatternVar(yp)]), diff(yp, xp))
    x_z_case = Clause(PatternConstructor(p.z, []), y)
    x_s_case = Clause(PatternConstructor(p.s, [PatternVar(xp)]), Match(y, [y_z_case, y_s_case]))
    mod[diff] = Function([x, y], Match(x, [x_z_case, x_s_case]))
    orig = diff(make_nat_expr(p, 7), make_nat_expr(p, 3))
    orig = Function([], orig)
    res = dcpe(orig, mod=mod)
    assert alpha_equal(res.body, make_nat_expr(p, 4))


def test_match_nat_id():
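    """A match-based nat identity applied to a literal reduces to that literal."""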
    mod = tvm.IRModule()
    p = Prelude(mod)
    add_nat_definitions(p)
    nat = p.nat()
    x = Var("x", nat)
    y = Var("y", nat)
    nat_id = GlobalVar("nat_id")
    z_case = Clause(PatternConstructor(p.z, []), p.z())
    s_case = Clause(PatternConstructor(p.s, [PatternVar(y)]), p.s(y))
    mod[nat_id] = Function([x], Match(x, [z_case, s_case]))
    orig = nat_id(make_nat_expr(p, 3))
    orig = Function([], orig)
    res = dcpe(orig, mod=mod)
    assert alpha_equal(res.body, make_nat_expr(p, 3))


def test_nat_id():
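    """The plain nat identity applied to a literal reduces to that literal."""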
    mod = tvm.IRModule()
    p = Prelude(mod)
    add_nat_definitions(p)
    nat = p.nat()
    x = Var("x", nat)
    y = Var("y", nat)
    nat_id = GlobalVar("nat_id")
    mod[nat_id] = Function([x], x)
    orig = nat_id(make_nat_expr(p, 3))
    orig = Function([], orig)
    res = dcpe(orig, mod=mod)
    assert alpha_equal(res.body, make_nat_expr(p, 3))


def test_global_match_nat_id():
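    """Matching directly on a nat literal reduces to the same literal."""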
    mod = tvm.IRModule()
    p = Prelude(mod)
    add_nat_definitions(p)
    nat = p.nat()
    x = Var("x", nat)
    z_case = Clause(PatternConstructor(p.z, []), p.z())
    s_case = Clause(PatternConstructor(p.s, [PatternVar(x)]), p.s(x))
    orig = Match(make_nat_expr(p, 3), [z_case, s_case])
    orig = Function([], orig)
    res = dcpe(orig, mod=mod)
    assert alpha_equal(res.body, make_nat_expr(p, 3))


def test_double():
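    """double(3) on Peano naturals reduces to 6."""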
    mod = tvm.IRModule()
    p = Prelude(mod)
    add_nat_definitions(p)
    orig = p.double(make_nat_expr(p, 3))
    orig = Function([], orig)
    res = dcpe(orig, mod=mod)
    assert alpha_equal(res.body, make_nat_expr(p, 6))


def test_concat():
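    """Concatenation of two unknown tensors cannot be reduced and is left unchanged."""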
    t = relay.TensorType([10], "float32")
    x = Var("x", t)
    y = Var("x", t)
    orig = run_infer_type(Function([x, y], op.concatenate([x, y], axis=0)))
    assert_alpha_equal(dcpe(orig), orig)


def test_triangle_number():
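    """A recursive triangle-number computation on a constant folds to const(55)."""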
    t = relay.TensorType([], "int32")
    x = Var("x", t)
    f_var = Var("f")
    f = Function([x], If(op.equal(x, const(0)), const(0), x + f_var(x - const(1))))
    orig = run_infer_type(Let(f_var, f, f_var(const(10))))
    assert_alpha_equal(dcpe(orig), const(55))


def test_nat_update():
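    """PartialEvaluate should run on the nat prelude after ToANormalForm without error."""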
    m = tvm.IRModule()
    p = Prelude(m)
    add_nat_definitions(p)
    m = transform.ToANormalForm()(m)
    transform.PartialEvaluate()(m)


def test_tuple_match():
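    """Matching a tuple pattern against a literal tuple reduces to the computed sum."""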
    a = relay.Var("a")
    b = relay.Var("b")
    clause = relay.Clause(relay.PatternTuple([relay.PatternVar(a), relay.PatternVar(b)]), a + b)
    x = relay.Match(relay.Tuple([relay.const(1), relay.const(1)]), [clause])
    assert_alpha_equal(dcpe(x), const(2))


if __name__ == "__main__":
    test_nat_update()
    test_ref()
    test_tuple()
    test_empty_ad()
    test_const_inline()
    test_ad()
    test_if_ref()
    test_function_invalidate()
    test_head_cons()
    test_map()
    test_loop()
    test_swap_loop()
    test_abs_diff()
    test_double()
    test_nat_id()
    test_global_match_nat_id()
    test_match_nat_id()
    test_concat()
    test_triangle_number()
    test_tuple_match()