# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
from tvm import te
from tvm import relay
from tvm.relay.analysis import graph_equal, assert_graph_equal
from tvm.relay.analysis import alpha_equal, assert_alpha_equal
import pytest
from numpy import isclose
from typing import Union
from functools import wraps
raises_parse_error = pytest.mark.xfail(raises=tvm._ffi.base.TVMError)

SEMVER = "v0.0.4"

BINARY_OPS = {
    "*": relay.multiply,
    "/": relay.divide,
    "+": relay.add,
    "-": relay.subtract,
    "<": relay.less,
    ">": relay.greater,
    "<=": relay.less_equal,
    ">=": relay.greater_equal,
    "==": relay.equal,
    "!=": relay.not_equal,
}

TYPES = {
    "int8",
    "int16",
    "int32",
    "int64",

    "uint8",
    "uint16",
    "uint32",
    "uint64",

    "float16",
    "float32",
    "float64",

    "bool",

    "int8x4",
    "uint1x4",
    "float16x4",
}

LIST_DEFN = """
type List[A] {
    Cons(A, List[A]),
    Nil,
}
"""

def roundtrip(expr):
    x = relay.fromtext(expr.astext())
    assert_graph_equal(x, expr)


def parse_text(code):
    expr = relay.fromtext(SEMVER + "\n" + code)
    roundtrip(expr)
    return expr


def parses_as(code, expr):
    # type: (str, relay.Expr) -> bool
    parsed = parse_text(code)
    result = graph_equal(parsed, expr)
    return result

def get_scalar(x):
    # type: (relay.Constant) -> (Union[float, int, bool])
    return x.data.asnumpy().item()

int32 = relay.scalar_type("int32")

_ = relay.Var("_")
X = relay.Var("x")
Y = relay.Var("y")
X_ANNO = relay.Var("x", int32)
Y_ANNO = relay.Var("y", int32)

UNIT = relay.Tuple([])


def test_comments():
    assert parses_as(
        """
        // This is a line comment!
        ()
        """,
        UNIT
    )

    assert parses_as(
        """
        /* This is a block comment!
            This is still a block comment!
        */
        ()
        """,
        UNIT
    )

    assert parses_as(
        """
        /* This is a block comment!
           /*Block comment is recursive!*/
        */
        ()
        """,
        UNIT
    )


def test_int_literal():
    assert isinstance(parse_text("1"), relay.Constant)
    assert isinstance(parse_text("1").data, tvm.nd.NDArray)

    assert get_scalar(parse_text("1")) == 1
    assert get_scalar(parse_text("10")) == 10
    assert get_scalar(parse_text("0")) == 0
    assert get_scalar(parse_text("-100")) == -100
    assert get_scalar(parse_text("-05")) == -5


def test_float_literal():
    assert get_scalar(parse_text("1.0f")) == 1.0
    assert isclose(get_scalar(parse_text("1.56667f")), 1.56667)
    assert get_scalar(parse_text("0.0f")) == 0.0
    assert get_scalar(parse_text("-10.0f")) == -10.0

    # scientific notation
    assert isclose(get_scalar(parse_text("1e-1f")), 1e-1)
    assert get_scalar(parse_text("1e+1f")) == 1e+1
    assert isclose(get_scalar(parse_text("1E-1f")), 1E-1)
    assert get_scalar(parse_text("1E+1f")) == 1E+1
    assert isclose(get_scalar(parse_text("1.0e-1f")), 1.0e-1)
    assert get_scalar(parse_text("1.0e+1f")) == 1.0e+1
    assert isclose(get_scalar(parse_text("1.0E-1f")), 1.0E-1)
    assert get_scalar(parse_text("1.0E+1f")) == 1.0E+1


def test_bool_literal():
    assert get_scalar(parse_text("True")) == True
    assert get_scalar(parse_text("False")) == False


def test_negative():
    assert isinstance(parse_text("let %x = 1; -%x").body, relay.Call)
    assert get_scalar(parse_text("--10")) == 10
    assert get_scalar(parse_text("---10")) == -10


def test_bin_op():
    for bin_op in BINARY_OPS.keys():
        assert parses_as(
            "1 {} 1".format(bin_op),
            BINARY_OPS.get(bin_op)(relay.const(1), relay.const(1))
        )


def test_parens():
    assert graph_equal(parse_text("1 * 1 + 1"), parse_text("(1 * 1) + 1"))
    assert not graph_equal(parse_text("1 * 1 + 1"), parse_text("1 * (1 + 1)"))


def test_op_assoc():
    assert graph_equal(parse_text("1 * 1 + 1 < 1 == 1"), parse_text("(((1 * 1) + 1) < 1) == 1"))
    assert graph_equal(parse_text("1 == 1 < 1 + 1 * 1"), parse_text("1 == (1 < (1 + (1 * 1)))"))


@pytest.mark.skip
def test_vars():
    # temp vars won't work b/c they start with a digit
    # # temp var
    # temp_var = parse_text("%1")
    # assert isinstance(temp_var, relay.Var)
    # assert temp_var.name == "1"

    # var
    var = parse_text("let %foo = (); %foo")
    assert isinstance(var.body, relay.Var)
    assert var.body.name_hint == "foo"

    # global var
    global_var = parse_text("@foo")
    assert isinstance(global_var, relay.GlobalVar)
    assert global_var.name_hint == "foo"

    # operator id
    op = parse_text("foo")
    assert isinstance(op, relay.Op)
    assert op.name == "foo"


def test_let():
    assert parses_as(
        "let %x = 1; ()",
        relay.Let(
            X,
            relay.const(1),
            UNIT
        )
    )

    assert parses_as(
        """
        let %x = 1;
        let %y = 2;
        ()
        """,
        relay.Let(
            X,
            relay.const(1),
            relay.Let(
                Y,
                relay.const(2),
                UNIT
            )
        )
    )


def test_seq():
    assert parses_as(
        "();; ()",
        relay.Let(
            _,
            UNIT,
            UNIT)
    )

    assert parses_as(
        "let %_ = 1; ()",
        relay.Let(
            X,
            relay.const(1),
            UNIT
        )
    )


def test_graph():
    code = "%0 = (); %1 = 1; (%0, %0, %1)"
    assert parses_as(
        code,
        relay.Tuple([UNIT, UNIT, relay.const(1)])
    )
    assert not parses_as(
        code,
        relay.Tuple([relay.Tuple([]), relay.Tuple([]), relay.const(1)])
    )


@raises_parse_error
def test_graph_wrong_order():
    parse_text("%1 = (); %1")


@raises_parse_error
def test_let_global_var():
    parse_text("let @x = 1; ()")


@raises_parse_error
def test_let_op():
    parse_text("let x = 1; ()")


def test_tuple():
    assert parses_as("()", relay.Tuple([]))

    assert parses_as("(0,)", relay.Tuple([relay.const(0)]))

    assert parses_as("(0, 1)", relay.Tuple([relay.const(0), relay.const(1)]))

    assert parses_as("(0, 1, 2)", relay.Tuple([relay.const(0), relay.const(1), relay.const(2)]))


def test_func():
    # 0 args
    assert parses_as(
        "fn () { 0 }",
        relay.Function(
            [],
            relay.const(0),
            None,
            []
        )
    )

    # 1 arg
    assert parses_as(
        "fn (%x) { %x }",
        relay.Function(
            [X],
            X,
            None,
            []
        )
    )

    # 2 args
    assert parses_as(
        "fn (%x, %y) { %x + %y }",
        relay.Function(
            [X, Y],
            relay.add(X, Y),
            None,
            []
        )
    )

    # annotations
    assert parses_as(
        "fn (%x: int32) -> int32 { %x }",
        relay.Function(
            [X_ANNO],
            X_ANNO,
            int32,
            []
        )
    )

    # attributes
    assert parses_as(
        "fn (n=5) { () }",
        relay.Function([], UNIT, None, None, tvm.ir.make_node("DictAttrs", n=relay.const(5)))
    )


# TODO(@jmp): Crashes if %x isn't annnotated.
def test_defn():
    id_defn = parse_text(
        """
        def @id(%x: int32) -> int32 {
            %x
        }
        """)
    assert isinstance(id_defn, tvm.IRModule)


def test_recursive_call():
    id_defn = parse_text(
        """
        def @id(%x: int32) -> int32 {
            @id(%x)
        }
        """)
    assert isinstance(id_defn, tvm.IRModule)


def test_ifelse():
    assert parses_as(
        """
        if (True) {
            0
        } else {
            1
        }
        """,
        relay.If(
            relay.const(True),
            relay.const(0),
            relay.const(1)
        )
    )


@raises_parse_error
def test_ifelse_scope():
    parse_text(
        """
        if (True) {
            let %x = ();
            ()
        } else {
            %x
        }
        """
    )


def test_call():
    # select right function to call: simple ident case
    id_func = relay.Var("id")
    assert parses_as(
        """
        let %id = fn (%x) { %x };
        10 * %id(10)
        """,
        relay.Let(
            id_func,
            relay.Function([X], X, None, []),
            relay.multiply(relay.const(10), relay.Call(id_func, [relay.const(10)]))
        )
    )

    # 0 args
    constant = relay.Var("constant")
    assert parses_as(
        """
        let %constant = fn () { 0 };
        %constant()
        """,
        relay.Let(
            constant,
            relay.Function([], relay.const(0), None, []),
            relay.Call(constant, [], None, None)
        )
    )

    # 1 arg
    id_var = relay.Var("id")
    assert parses_as(
        """
        let %id = fn (%x) { %x };
        %id(1)
        """,
        relay.Let(
            id_var,
            relay.Function([X], X, None, []),
            relay.Call(id_var, [relay.const(1)], None, None)
        )
    )

    # 2 args
    multiply = relay.Var("multiply")
    assert parses_as(
        """
        let %multiply = fn (%x, %y) { %x * %y };
        %multiply(0, 0)
        """,
        relay.Let(
            multiply,
            relay.Function(
                [X, Y],
                relay.multiply(X, Y),
                None,
                []
            ),
            relay.Call(multiply, [relay.const(0), relay.const(0)], None, None)
        )
    )

    # anonymous function
    assert parses_as(
        """
        (fn (%x) { %x })(0)
        """,
        relay.Call(
            relay.Function(
                [X],
                X,
                None,
                []
            ),
            [relay.const(0)],
            None,
            None
        )
    )

    # TODO(@jmp): re-enable after sequence parsing improvements
    # curried function
    # curried_mult = relay.Var("curried_mult")
    # assert parses_as(
    #     """
    #     let %curried_mult =
    #         fn (%x) {
    #         fn (%y) {
    #             %x * %y
    #         }
    #         };
    #     %curried_mult(0);
    #     %curried_mult(0)(0)
    #     """,
    #     relay.Let(
    #         curried_mult,
    #         relay.Function(
    #             [X],
    #             relay.Function(
    #                 [Y],
    #                 relay.multiply(X, Y),
    #                 None,
    #                 []
    #             ),
    #             None,
    #             []
    #         ),
    #         relay.Let(
    #             _,
    #             relay.Call(curried_mult, [relay.const(0)], None, None),
    #             relay.Call(relay.Call(curried_mult, [relay.const(0)], None, None), [relay.const(0)], None, None)
    #         )
    #     )
    # )

    # op
    assert parses_as(
        "abs(1)",
        relay.Call(relay.op.get("abs"), [relay.const(1)], None, None)
    )

# Types


def test_incomplete_type():
    assert parses_as(
        "let %_ : _ = (); ()",
        relay.Let(
            _,
            UNIT,
            UNIT
        )
    )


def test_builtin_types():
    for builtin_type in TYPES:
        parse_text("let %_ : {} = (); ()".format(builtin_type))


def test_tensor_type():
    assert parses_as(
        "let %_ : Tensor[(), float32] = (); ()",
        relay.Let(
            relay.Var("_", relay.TensorType((), "float32")),
            UNIT,
            UNIT
        )
    )

    assert parses_as(
        "let %_ : Tensor[(1), float32] = (); ()",
        relay.Let(
            relay.Var("_", relay.TensorType((1,), "float32")),
            UNIT,
            UNIT
        )
    )

    assert parses_as(
        "let %_ : Tensor[(1, 1), float32] = (); ()",
        relay.Let(
            relay.Var("_", relay.TensorType((1, 1), "float32")),
            UNIT,
            UNIT
        )
    )


def test_function_type():
    assert parses_as(
        """
        let %_: fn () -> int32 = fn () -> int32 { 0 }; ()
        """,
        relay.Let(
            relay.Var("_", relay.FuncType([], int32, [], [])),
            relay.Function([], relay.const(0), int32, []),
            UNIT
        )
    )

    assert parses_as(
        """
        let %_: fn (int32) -> int32 = fn (%x: int32) -> int32 { 0 }; ()
        """,
        relay.Let(
            relay.Var("_", relay.FuncType([int32], int32, [], [])),
            relay.Function([relay.Var("x", int32)], relay.const(0), int32, []),
            UNIT
        )
    )

    assert parses_as(
        """
        let %_: fn (int32, int32) -> int32 = fn (%x: int32, %y: int32) -> int32 { 0 }; ()
        """,
        relay.Let(
            relay.Var("_", relay.FuncType([int32, int32], int32, [], [])),
            relay.Function([relay.Var("x", int32), relay.Var("y", int32)], relay.const(0), int32, []),
            UNIT
        )
    )


def test_tuple_type():
    assert parses_as(
        """
        let %_: () = (); ()
        """,
        relay.Let(
            relay.Var("_", relay.TupleType([])),
            UNIT,
            UNIT
        )
    )

    assert parses_as(
        """
        let %_: (int32,) = (0,); ()
        """,
        relay.Let(
            relay.Var("_", relay.TupleType([int32])),
            relay.Tuple([relay.const(0)]),
            UNIT
        )
    )

    assert parses_as(
        """
        let %_: (int32, int32) = (0, 1); ()
        """,
        relay.Let(
            relay.Var("_", relay.TupleType([int32, int32])),
            relay.Tuple([relay.const(0), relay.const(1)]),
            UNIT
        )
    )


def test_adt_defn():
    mod = tvm.IRModule()

    glob_typ_var = relay.GlobalTypeVar("Ayy")
    prog = relay.TypeData(
            glob_typ_var,
            [],
            [relay.Constructor("Nil", [], glob_typ_var)])
    mod[glob_typ_var] = prog
    assert parses_as(
        """
        type Ayy { Nil }
        """,
        mod
    )


def test_empty_adt_defn():
    mod = tvm.IRModule()

    glob_typ_var = relay.GlobalTypeVar("Ayy")
    prog = relay.TypeData(glob_typ_var, [], [])
    mod[glob_typ_var] = prog
    assert parses_as(
        """
        type Ayy { }
        """,
        mod
    )


def test_multiple_cons_defn():
    mod = tvm.IRModule()

    list_var = relay.GlobalTypeVar("List")
    typ_var = relay.TypeVar("A")
    prog = relay.TypeData(
            list_var,
            [typ_var],
            [
                relay.Constructor("Cons", [typ_var, list_var(typ_var)], list_var),
                relay.Constructor("Nil", [], list_var),
            ])
    mod[list_var] = prog
    assert parses_as(LIST_DEFN, mod)


def test_multiple_type_param_defn():
    glob_typ_var = relay.GlobalTypeVar("Either")
    typ_var_a = relay.TypeVar("A")
    typ_var_b = relay.TypeVar("B")
    prog = relay.TypeData(
            glob_typ_var,
            [typ_var_a, typ_var_b],
            [
                relay.Constructor("Left", [typ_var_a], glob_typ_var),
                relay.Constructor("Right", [typ_var_b], glob_typ_var),
            ])
    mod = tvm.IRModule()
    mod[glob_typ_var] = prog
    assert parses_as(
        """
        type Either[A, B] {
          Left(A),
          Right(B),
        }
        """,
        mod
    )


def test_match():
    # pair each match keyword with whether it specifies a complete match or not
    match_keywords = [("match", True), ("match?", False)]
    for (match_keyword, is_complete) in match_keywords:
        mod = tvm.IRModule()

        list_var = relay.GlobalTypeVar("List")
        typ_var = relay.TypeVar("A")
        cons_constructor = relay.Constructor(
            "Cons", [typ_var, list_var(typ_var)], list_var)
        nil_constructor = relay.Constructor("Nil", [], list_var)
        list_def = relay.TypeData(
            list_var,
            [typ_var],
            [cons_constructor, nil_constructor])
        mod[list_var] = list_def

        length_var = relay.GlobalVar("length")
        typ_var = relay.TypeVar("A")
        input_type = list_var(typ_var)
        input_var = relay.Var("xs", input_type)
        rest_var = relay.Var("rest")
        cons_case = relay.Let(
            _,
            UNIT,
            relay.add(relay.const(1), relay.Call(length_var, [rest_var])))
        body = relay.Match(input_var,
            [relay.Clause(
                relay.PatternConstructor(
                    cons_constructor,
                    [relay.PatternWildcard(), relay.PatternVar(rest_var)]),
                cons_case),
            relay.Clause(
                relay.PatternConstructor(nil_constructor, []),
                relay.const(0))],
            complete=is_complete
        )
        length_func = relay.Function(
            [input_var],
            body,
            int32,
            [typ_var]
        )
        mod[length_var] = length_func

        assert parses_as(
            """
            %s

            def @length[A](%%xs: List[A]) -> int32 {
              %s (%%xs) {
                Cons(_, %%rest) => {
                  ();;
                  1 + @length(%%rest)
                },
                Nil => 0,
              }
            }
            """ % (LIST_DEFN, match_keyword),
            mod
        )


def test_adt_cons_expr():
    mod = tvm.IRModule()

    list_var = relay.GlobalTypeVar("List")
    typ_var = relay.TypeVar("A")
    cons_constructor = relay.Constructor(
        "Cons", [typ_var, list_var(typ_var)], list_var)
    nil_constructor = relay.Constructor("Nil", [], list_var)
    list_def = relay.TypeData(
        list_var,
        [typ_var],
        [cons_constructor, nil_constructor])
    mod[list_var] = list_def

    make_singleton_var = relay.GlobalVar("make_singleton")
    input_var = relay.Var("x", int32)
    make_singleton_func = relay.Function(
        [input_var],
        cons_constructor(input_var, nil_constructor()),
        list_var(int32)
    )
    mod[make_singleton_var] = make_singleton_func

    assert parses_as(
        """
        %s

        def @make_singleton(%%x: int32) -> List[int32] {
          Cons(%%x, Nil)
        }
        """ % LIST_DEFN,
        mod
    )


@raises_parse_error
def test_duplicate_adt_defn():
    parse_text(
        """
        %s

        type List[A] {
          Cons(A, List[A]),
          Nil,
        }
        """ % LIST_DEFN
    )


@raises_parse_error
def test_duplicate_adt_cons():
    parse_text(
        """
        type Ayy { Lmao }
        type Haha { Lmao }
        """
    )


@raises_parse_error
def test_duplicate_adt_cons_defn():
    parse_text(
        """
        type Ayy { Lmao }
        type Lmao { Ayy }
        """
    )


@raises_parse_error
def test_duplicate_global_var():
    parse_text(
        """
        def @id[A](%x: A) -> A { x }
        def @id[A](%x: A) -> A { x }
        """
    )


def test_extern_adt_defn():
    # TODO(weberlo): update this test once extern is implemented
    mod = tvm.IRModule()

    extern_var = relay.GlobalTypeVar("T")
    typ_var = relay.TypeVar("A")
    extern_def = relay.TypeData(extern_var, [typ_var], [])
    mod[extern_var] = extern_def

    assert parses_as(
        """
        extern type T[A]
        """,
        mod
    )
def test_import_grad():
    mod = tvm.IRModule()
    mod.import_from_std("gradient.rly")

if __name__ == "__main__":
    test_comments()
    test_int_literal()
    test_float_literal()
    test_bool_literal()
    test_negative()
    test_bin_op()
    test_parens()
    test_op_assoc()
    test_let()
    test_seq()
    test_graph()
    test_tuple()
    test_func()
    test_defn()
    test_recursive_call()
    test_ifelse()
    test_call()
    test_incomplete_type()
    test_builtin_types()
    test_tensor_type()
    test_function_type()
    test_tuple_type()
    test_adt_defn()
    test_empty_adt_defn()
    test_multiple_cons_defn()
    test_multiple_type_param_defn()
    test_match()
    test_adt_cons_expr()
    test_duplicate_adt_defn()
    test_duplicate_adt_cons()
    test_duplicate_adt_cons_defn()
    test_duplicate_global_var()
    test_extern_adt_defn()
    test_import_grad()