# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

# pylint: disable=invalid-name, unused-argument
"""A parser for Relay's text format."""
from __future__ import absolute_import

import sys
from ast import literal_eval

from collections import deque

import tvm

from . import module
from .base import Span, SourceName
from . import expr
from . import ty
from . import op

PYTHON_VERSION = sys.version_info.major

# The generated lexer/parser/visitor modules are imported from the py3 grammar
# package; if they are missing, tell the user how to produce them.
try:
    from .grammar.py3.RelayVisitor import RelayVisitor
    from .grammar.py3.RelayParser import RelayParser
    from .grammar.py3.RelayLexer import RelayLexer
except ImportError:
    raise Exception("Couldn't find ANTLR parser. Try building with USE_ANTLR=ON.")

# The ANTLR runtime is a separate pip package; the hint below is specialised
# to the running interpreter's major version.
try:
    from antlr4 import InputStream, CommonTokenStream
    from antlr4.error.ErrorListener import ErrorListener
except ImportError:
    raise Exception("Couldn't find ANTLR runtime. " +
                    "Try running `pip{version} install antlr4-python{version}-runtime`."
                    .format(version=PYTHON_VERSION))

# NOTE(review): presumably raised because the visitor recurses once per level
# of parse-tree nesting -- confirm before changing the limit.
sys.setrecursionlimit(10000)

class ParseError(Exception):
    """Exception type for parse errors.

    Parameters
    ----------
    message : str
        Human-readable description of what could not be parsed. It is kept
        on `self.message` and rendered as ``ParseError(<message>)`` by both
        `repr` and `str`.
    """

    def __init__(self, message):
        # type: (str) -> None
        # Forward the message to the base class so that `args` carries it;
        # without this the exception cannot be pickled or copied, because
        # re-creating it calls `__init__` with the (empty) `args`.
        super(ParseError, self).__init__(message)
        self.message = message

    def __repr__(self):
        return "ParseError({})".format(self.message)

    def __str__(self):
        return repr(self)

class OpWrapper(object):
    """Marker base class for wrappers that overload `__call__` for an op.

    Subclasses decide how a parsed call (args, attrs, type_args) is turned
    into a Relay expression.
    """

class ExprOp(OpWrapper):
    """Default wrapper: calling it builds a plain `expr.Call` on the operator.

    The attrs are handed to `expr.Call` unchanged, so this wrapper does not
    handle attrs well.
    """
    def __init__(self, operator):
        self.operator = operator

    def __call__(self, args, attrs, type_args):
        try:
            call_node = expr.Call(self.operator, args, attrs, type_args)
        except Exception:
            # Report which operator/attrs combination could not be built.
            description = " ".join([str(self.operator), str(attrs)])
            raise Exception(description)
        return call_node

class FuncOp(OpWrapper):
    """Call the operator's Python function, passing the attrs as keyword arguments.

    Each attr value is first converted to a plain Python value (see `convert`).
    Tvm should provide this in the future, as this is pretty similar to what
    op.get is providing.
    """
    def __init__(self, operator):
        self.operator = operator

    def convert(self, v):
        """Turn a parsed attr value into a plain Python value.

        Tuples are converted element by element, Relay constants become the
        Python scalar they hold, and strings pass through untouched. Any
        other value raises an Exception carrying that value.
        """
        if isinstance(v, tuple):
            return tuple(self.convert(item) for item in v)
        if isinstance(v, expr.Constant):
            return v.data.asnumpy().item()
        if isinstance(v, str):
            return v
        raise Exception(v)

    def __call__(self, args, attrs, type_args):
        # `type_args` is accepted for interface parity with ExprOp but unused.
        kwargs = {}
        if attrs is not None:
            for key, value in attrs.items():
                kwargs[key] = self.convert(value)
        result = self.operator(*args, **kwargs)
        # Unwrap multi-output results into a plain tuple expression.
        return result.astuple() if isinstance(result, expr.TupleWrapper) else result

# Maps the token type of a binary operator (as numbered by the generated
# RelayParser) to the Relay op function used to desugar it in `visitBinOp`.
BINARY_OPS = {
    RelayParser.MUL: op.multiply,
    RelayParser.DIV: op.divide,
    RelayParser.ADD: op.add,
    RelayParser.SUB: op.subtract,
    RelayParser.LT:  op.less,
    RelayParser.GT:  op.greater,
    RelayParser.LE:  op.less_equal,
    RelayParser.GE:  op.greater_equal,
    RelayParser.EQ:  op.equal,
    RelayParser.NE:  op.not_equal,
}

# Operator names that `visitOpIdent` wraps in a FuncOp, i.e. that are invoked
# through their Python function with attrs as keyword arguments. Any name not
# listed here is resolved with `op.get` and wrapped in an ExprOp instead.
FUNC_OPS = {
    "nn.conv2d": op.nn.conv2d,
    "nn.batch_norm": op.nn.batch_norm,
    "nn.dense": op.nn.dense,
    "nn.bias_add": op.nn.bias_add,
    "nn.max_pool2d": op.nn.max_pool2d,
    "nn.global_max_pool2d": op.nn.global_max_pool2d,
    "nn.avg_pool2d": op.nn.avg_pool2d,
    "nn.global_avg_pool2d": op.nn.global_avg_pool2d,
    "nn.softmax": op.nn.softmax,
    "reshape": op.reshape,
    "nn.conv2d_transpose": op.nn.conv2d_transpose,
    "concatenate": op.concatenate,
    "nn.dropout": op.nn.dropout_raw,
    "zeros": op.zeros,
    "split": op.split,
}

# `visitTypeIdent` treats any type identifier that starts with one of these
# prefixes as a scalar type; the check is a plain `str.startswith`.
TYPE_PREFIXES = [
    "int",
    "uint",
    "float",
    "bool",
]

# Placeholder type variable; in the visible code it is only referenced from
# the `Scope`/`Scopes` aliases used in type comments below.
T = ty.TypeVar("T")
# Scope = Deque[Tuple[str, T]]
# Scopes = Deque[Scope[T]]

def lookup(scopes, name):
    # type: (Scopes[T], str) -> Optional[T]
    """Return the first value bound to `name` in `scopes`, or None.

    Scopes are searched in iteration order, and so are the `(key, value)`
    pairs inside each scope; the first pair whose key equals `name` wins.
    """
    bindings = (bound for frame in scopes for key, bound in frame if key == name)
    return next(bindings, None)

def spanify(f):
    """A decorator which attaches span information
       to the value returned by calling `f`.

       Intended for use with the below AST visiting
       methods. The idea is that after we do the work
       of constructing the AST we attach Span information.

       The wrapped callable must be invoked as `f(self, ctx, ...)` where
       `self` has a `source_name` attribute and `ctx` is an ANTLR parser
       context. The returned node (unwrapped from a TupleWrapper when
       necessary) gets a Span built from the line and column of the
       context's first token.
    """
    # Local import so this block does not depend on the file-level imports.
    import functools

    @functools.wraps(f)
    def _wrapper(*args, **kwargs):
        # Assumes 0th arg is self and gets source_name from object.
        sn = args[0].source_name
        # Assumes 1st arg is an ANTLR parser context.
        ctx = args[1]
        ast = f(*args, **kwargs)
        # `getSourceInterval()` yields (start, stop) *token indices*, not a
        # source position, so read line/column from the first token instead.
        start = getattr(ctx, "start", None)
        if start is not None:
            line, col = start.line, start.column
        else:
            # No start token recorded: keep the previous behaviour.
            line, col = ctx.getSourceInterval()
        sp = Span(sn, line, col)
        if isinstance(ast, tvm.relay.expr.TupleWrapper):
            ast = ast.astuple()
        ast.set_span(sp)
        return ast
    return _wrapper

# TODO(@jmp): Use https://stackoverflow.com/q/13889941
# to figure out how to get ANTLR4 to be more unhappy about syntax errors
class ParseTreeToRelayIR(RelayVisitor):
    """Parse Relay text format into Relay IR.

    Visits an ANTLR parse tree and builds the corresponding Relay
    expressions and types; `def`s are collected into `self.module`.

    Local-variable and type-parameter scopes are deques used as stacks:
    the innermost scope is at index 0 and, inside a scope, the newest
    binding comes first, so `lookup` resolves a name to its most recent
    binding.
    """

    def __init__(self, source_name):
        # type: (str) -> None
        self.source_name = source_name
        self.module = module.Module({})   # type: module.Module
        # Decoded METADATA section; stays None unless `visitProg` finds one.
        # Initialised here so `visitMeta` never sees a missing attribute.
        self.meta = None

        # Adding an empty scope allows naked lets without pain.
        self.var_scopes = deque([deque()])          # type: Scopes[expr.Var]
        self.global_var_scope = deque()             # type: Scope[expr.GlobalVar]
        self.type_param_scopes = deque([deque()])   # type: Scopes[ty.TypeVar]
        self.graph_expr = []                        # type: List[expr.Expr]

        super(ParseTreeToRelayIR, self).__init__()


    def enter_var_scope(self):
        # type: () -> None
        """Enter a new Var scope so it can be popped off later."""

        self.var_scopes.appendleft(deque())

    def exit_var_scope(self):
        # type: () -> Scope[expr.Var]
        """Pop off the current Var scope and return it."""

        return self.var_scopes.popleft()

    def mk_var(self, name, type_):
        # type: (str, ty.Type) -> expr.Var
        """Create a new Var and add it to the Var scope."""

        var = expr.Var(name, type_)
        self.var_scopes[0].appendleft((name, var))
        return var

    def mk_global_var(self, name):
        # type: (str) -> expr.GlobalVar
        """Create a new GlobalVar and add it to the GlobalVar scope."""

        var = expr.GlobalVar(name)
        self.global_var_scope.append((name, var))
        return var

    def enter_type_param_scope(self):
        # type: () -> None
        """Enter a new TypeVar scope so it can be popped off later."""

        self.type_param_scopes.appendleft(deque())

    def exit_type_param_scope(self):
        # type: () -> Scope[ty.TypeVar]
        """Pop off the current TypeVar scope and return it."""

        return self.type_param_scopes.popleft()

    def mk_typ(self, name, kind):
        # type: (str, ty.Kind) -> ty.TypeVar
        """Create a new TypeVar and add it to the TypeVar scope."""

        typ = ty.TypeVar(name, kind)
        self.type_param_scopes[0].appendleft((name, typ))
        return typ

    def visitProjection(self, ctx):
        """Build a TupleGetItem from the visited expression and NAT index."""
        return expr.TupleGetItem(self.visit(ctx.expr()), self.visit(ctx.NAT()))

    def visitTerminal(self, node):
        # type: (TerminalNode) -> Union[expr.Expr, int, float]
        """Visit lexer tokens that aren't ignored or visited by other functions.

        Raises
        ------
        ParseError
            If a variable cannot be resolved, a BOOL_LIT is neither `True`
            nor `False`, or the token type is not handled here.
        """

        node_type = node.getSymbol().type
        node_text = node.getText()
        # Variable tokens carry a one-character sigil that is not part of
        # the name; `name` is only meaningful in the variable branches.
        name = node_text[1:]

        # variables
        if node_type == RelayLexer.GLOBAL_VAR:
            global_var = lookup(deque([self.global_var_scope]), name)
            if global_var is None:
                # Fail here, like local vars do, instead of letting a None
                # leak into the expression being built.
                raise ParseError("Couldn't resolve `{}`.".format(name))
            return global_var
        if node_type == RelayLexer.LOCAL_VAR:
            # Remove the leading '%' and lookup the name.
            var = lookup(self.var_scopes, name)
            if var is None:
                raise ParseError("Couldn't resolve `{}`.".format(name))
            return var
        if node_type == RelayLexer.GRAPH_VAR:
            try:
                return self.graph_expr[int(name)]
            except IndexError:
                raise ParseError("Couldn't resolve `{}`".format(name))

        # data types
        if node_type == RelayLexer.NAT:
            return int(node_text)
        if node_type == RelayLexer.FLOAT:
            # The final character is dropped before conversion
            # (presumably a literal suffix -- confirm against the grammar).
            return float(node_text[:-1])
        if node_type == RelayLexer.BOOL_LIT:
            if node_text == "True":
                return True
            if node_text == "False":
                return False
            raise ParseError("Unrecognized BOOL_LIT: `{}`".format(node_text))
        if node_type == RelayLexer.QUOTED_STRING:
            # literal_eval only evaluates Python literals, never arbitrary code.
            return literal_eval(node_text)

        raise ParseError("todo: `{}`".format(node_text))

    def visit_list(self, ctx_list):
        # type: (List[ParserRuleContext]) -> List[Any]
        """Visit a list of contexts."""
        assert isinstance(ctx_list, list)

        return [self.visit(ctx) for ctx in ctx_list]

    def getType_(self, ctx):
        # type: (Optional[RelayParser.Type_Context]) -> Optional[ty.Type]
        """Return a (possibly None) Relay type."""

        if ctx is None:
            return None

        return self.visit(ctx)

    def visitProg(self, ctx):
        # type: (RelayParser.ProgContext) -> Union[expr.Expr, module.Module]
        """Visit a whole program.

        Loads the METADATA section when present, then returns the module
        if the program consists of definitions, the expression if it is a
        single expression, and the (empty) module otherwise.

        Raises
        ------
        ParseError
            If the METADATA section does not start with a `METADATA:` line.
        """
        self.meta = None
        if ctx.METADATA():
            # partition() tolerates a missing newline; the header check then
            # reports the problem instead of a tuple-unpacking error.
            header, _, data = str(ctx.METADATA()).partition('\n')
            if header != "METADATA:":
                # Explicit raise: an `assert` would vanish under `python -O`.
                raise ParseError(
                    "Malformed metadata section header: `{}`".format(header))
            self.meta = tvm.load_json(data)
        if ctx.defn():
            self.visit_list(ctx.defn())
            return self.module

        if ctx.expr():
            return self.visit(ctx.expr())

        return self.module

    # Exprs
    def visitOpIdent(self, ctx):
        # type: (RelayParser.OpIdentContext) -> op.Op
        """Wrap the named operator: FuncOp if listed in FUNC_OPS, else ExprOp."""
        op_name = ctx.CNAME().getText()
        if op_name in FUNC_OPS:
            return FuncOp(FUNC_OPS[op_name])
        return ExprOp(op.get(op_name))

    # pass through
    def visitParen(self, ctx):
        # type: (RelayParser.ParenContext) -> expr.Expr
        return self.visit(ctx.expr())

    # pass through
    def visitBody(self, ctx):
        # type: (RelayParser.BodyContext) -> expr.Expr
        return self.visit(ctx.expr())

    def visitScalarFloat(self, ctx):
        # type: (RelayParser.ScalarFloatContext) -> expr.Constant
        return expr.const(self.visit(ctx.FLOAT()))

    def visitScalarInt(self, ctx):
        # type: (RelayParser.ScalarIntContext) -> expr.Constant
        return expr.const(self.visit(ctx.NAT()))

    def visitScalarBool(self, ctx):
        # type: (RelayParser.ScalarBoolContext) -> expr.Constant
        return expr.const(self.visit(ctx.BOOL_LIT()))

    def visitNeg(self, ctx):
        # type: (RelayParser.NegContext) -> Union[expr.Constant, expr.Call]
        """Negate an expression, folding the sign into scalar constants."""
        val = self.visit(ctx.expr())
        if isinstance(val, expr.Constant) and val.data.asnumpy().ndim == 0:
            # fold Neg in for scalars
            return expr.const(-val.data.asnumpy().item())

        return op.negative(val)

    def visitTuple(self, ctx):
        # type: (RelayParser.TupleContext) -> expr.Tuple
        tup = self.visit_list(ctx.expr())
        return expr.Tuple(tup)

    def visitLet(self, ctx):
        # type: (RelayParser.SeqContext) -> expr.Let
        """Desugar various sequence constructs to Relay Let nodes."""

        if ctx.var() is None:
            # anonymous identity
            ident = "_"
            type_ = None
            var = self.mk_var(ident, type_)
        else:
            var = self.visitVar(ctx.var())

        # The bound variable is registered before the value is visited, so it
        # is visible in both the value and the body; bindings introduced
        # while visiting the value are discarded afterwards.
        self.enter_var_scope()
        value = self.visit(ctx.expr(0))
        self.exit_var_scope()

        body = self.visit(ctx.expr(1))

        return expr.Let(var, value, body)

    def visitBinOp(self, ctx):
        # type: (RelayParser.BinOpContext) -> expr.Call
        """Desugar binary operators."""
        arg0, arg1 = self.visit_list(ctx.expr())
        relay_op = BINARY_OPS.get(ctx.op.type)

        if relay_op is None:
            raise ParseError("Unimplemented binary op.")

        return relay_op(arg0, arg1)

    @spanify
    def visitVar(self, ctx):
        # type: (RelayParser.VarContext) -> expr.Var
        """Visit a single variable."""
        ident = ctx.LOCAL_VAR()

        if ident is None:
            raise ParseError("Only local ids may be used in vars.")

        type_ = self.getType_(ctx.type_())

        return self.mk_var(ident.getText()[1:], type_)

    def visitVarList(self, ctx):
        # type: (RelayParser.VarListContext) -> List[expr.Var]
        return self.visit_list(ctx.var())

    # TODO: support a larger class of values than just Relay exprs
    def visitAttr(self, ctx):
        # type: (RelayParser.AttrContext) -> Tuple[str, expr.Expr]
        return (ctx.CNAME().getText(), self.visit(ctx.expr()))

    def visitArgNoAttr(self, ctx):
        """Return `(vars, None)` for an argument list without attributes."""
        return (self.visit_list(ctx.varList().var()), None)

    def visitAttrSeq(self, ctx):
        # type: (RelayParser.AttrListContext) -> Dict[str, expr.Expr]
        return dict(self.visit_list(ctx.attr()))

    def visitArgWithAttr(self, ctx):
        """Return `(vars, attrs)` for an argument list that ends in attributes."""
        return (self.visit_list(ctx.var()), self.visitAttrSeq(ctx.attrSeq()))

    def visitArgList(self,
                     ctx    # type: RelayParser.ArgListContext
                    ):
        # type: (...) -> Tuple[Optional[List[expr.Var]], Optional[Dict[str, expr.Expr]]]
        # NOTE(review): `attrList` is not used anywhere else in this file (the
        # sibling visitors use `attrSeq`); confirm this accessor still exists
        # in the grammar or whether this method is unreachable.
        var_list = self.visit(ctx.varList()) if ctx.varList() else None
        attr_list = self.visit(ctx.attrList()) if ctx.attrList() else None
        return (var_list, attr_list)

    def visitMeta(self, ctx):
        """Resolve a metadata reference to the object it names.

        Raises
        ------
        ParseError
            If the program did not provide a METADATA section.
        """
        type_key = str(ctx.CNAME())
        index = int(self.visit(ctx.NAT()))
        if self.meta is None:
            raise ParseError(
                "Metadata reference `{}`[{}] found, but the program has no "
                "METADATA section.".format(type_key, index))
        return self.meta[type_key][index]

    def mk_func(self, ctx):
        # type: (Union[RelayParser.FuncContext, RelayParser.DefnContext]) -> expr.Function
        """Construct a function from either a Func or Defn."""

        # Enter var scope early to put params in scope.
        self.enter_var_scope()
        # Capture type params in params.
        self.enter_type_param_scope()
        type_params = ctx.typeParamList()

        if type_params is not None:
            type_params = type_params.ident()
            assert type_params
            for ty_param in type_params:
                name = ty_param.getText()
                self.mk_typ(name, ty.Kind.Type)

        var_list, attr_list = self.visit(ctx.argList())
        if var_list is None:
            var_list = []
        ret_type = self.getType_(ctx.type_())

        body = self.visit(ctx.body())
        # NB(@jroesch): you must stay in the type parameter scope until
        # after you exit the body, you can reference the type parameters
        # of your parent scopes.
        #
        # `mk_typ` prepends to the scope, so it is stored newest-first;
        # reverse it to hand the type parameters over in declaration order.
        type_params = [type_var
                       for _, type_var in reversed(self.exit_type_param_scope())]
        self.exit_var_scope()

        attrs = tvm.make.node("DictAttrs", **attr_list) if attr_list is not None else None
        return expr.Function(var_list, body, ret_type, type_params, attrs)

    @spanify
    def visitFunc(self, ctx):
        # type: (RelayParser.FuncContext) -> expr.Function
        return self.mk_func(ctx)

    # TODO: how to set spans for definitions?
    # @spanify
    def visitDefn(self, ctx):
        # type: (RelayParser.DefnContext) -> None
        """Register a global definition in `self.module`.

        The GlobalVar is created before the body is visited, so the
        definition can refer to itself.
        """
        ident = ctx.ident().GLOBAL_VAR()
        if ident is None:
            raise ParseError("Only global ids may be used in `def`s.")
        ident_name = ident.getText()[1:]
        ident = self.mk_global_var(ident_name)
        self.module[ident] = self.mk_func(ctx)

    def visitCallNoAttr(self, ctx):
        """Return `(args, None)` for a call without attributes."""
        return (self.visit_list(ctx.exprList().expr()), None)

    def visitCallWithAttr(self, ctx):
        """Return `(args, attrs)` for a call that ends in attributes."""
        return (self.visit_list(ctx.expr()), self.visit(ctx.attrSeq()))

    def call(self, func, args, attrs, type_args):
        """Apply `func`: OpWrappers build their own result, anything else becomes a Call."""
        if isinstance(func, OpWrapper):
            return func(args, attrs, type_args)
        return expr.Call(func, args, attrs, type_args)

    @spanify
    def visitCall(self, ctx):
        # type: (RelayParser.CallContext) -> expr.Call
        func = self.visit(ctx.expr())
        args, attrs = self.visit(ctx.callList())
        return self.call(func, args, attrs, [])

    @spanify
    def visitIfElse(self, ctx):
        # type: (RelayParser.IfElseContext) -> expr.If
        """Construct a Relay If node. Creates a new scope for each branch."""
        cond = self.visit(ctx.expr())

        self.enter_var_scope()
        true_branch = self.visit(ctx.body(0))
        self.exit_var_scope()

        self.enter_var_scope()
        false_branch = self.visit(ctx.body(1))
        self.exit_var_scope()

        return expr.If(cond, true_branch, false_branch)

    @spanify
    def visitGraph(self, ctx):
        # type: (RelayParser.GraphContext) -> expr.Expr
        """Visit a graph variable assignment.

        Graph variables must be introduced in order: the n-th assignment
        has to bind `%n`, otherwise a ParseError is raised.
        """
        graph_nid = int(ctx.GRAPH_VAR().getText()[1:])

        self.enter_var_scope()
        value = self.visit(ctx.expr(0))
        self.exit_var_scope()

        if graph_nid != len(self.graph_expr):
            raise ParseError(
                "Expected new graph variable to be `%{}`, ".format(len(self.graph_expr)) + \
                "but got `%{}`".format(graph_nid))
        self.graph_expr.append(value)

        kont = self.visit(ctx.expr(1))
        return kont

    # Types

    # pylint: disable=unused-argument
    def visitIncompleteType(self, ctx):
        # type: (RelayParser.IncompleteTypeContext) -> None
        return None

    def visitTypeIdent(self, ctx):
        # type: (RelayParser.TypeIdentContext) -> Union[ty.TensorType, str]
        '''
        Handle type identifier.

        Identifiers starting with one of TYPE_PREFIXES are treated as
        scalar types; anything else must be a type parameter in scope.
        '''
        type_ident = ctx.CNAME().getText()

        # Look through all type prefixes for a match
        for type_prefix in TYPE_PREFIXES:
            if type_ident.startswith(type_prefix):
                return ty.scalar_type(type_ident)

        type_param = lookup(self.type_param_scopes, type_ident)
        if type_param is not None:
            return type_param

        raise ParseError("Unknown builtin type: {}".format(type_ident))

    # def visitCallType(self, ctx):
    #     # type: (RelayParser.CallTypeContext) -> Union[expr.Expr, ty.TensorType]
    #     ident_type = ctx.identType().CNAME().getText()

    #     args = self.visit_list(ctx.type_())

    #     if not args:
    #         raise ParseError("Type-level functions must have arguments!")

    #     func_type = TYPE_FUNCS.get(ident_type)(args)

    #     if func_type is None:
    #         raise ParseError("Unknown type-level function: `{}`".format(ident_type))
    #     else:
    #         return func_type

    def visitParensShape(self, ctx):
        # type: (RelayParser.ParensShapeContext) -> int
        return self.visit(ctx.shape())

    def visitShapeList(self, ctx):
        # type: (RelayParser.ShapeListContext) -> List[int]
        return self.visit_list(ctx.shape())

    def visitTensor(self, ctx):
        """Return the visited element expressions as a Python tuple."""
        return tuple(self.visit_list(ctx.expr()))

    def visitTensorType(self, ctx):
        # type: (RelayParser.TensorTypeContext) -> ty.TensorType
        """Create a simple tensor type. No generics."""

        shape = self.visit(ctx.shapeList())
        dtype = self.visit(ctx.type_())

        if not isinstance(dtype, ty.TensorType):
            raise ParseError("Expected dtype to be a Relay base type.")

        # The element type was parsed as a scalar TensorType; keep its dtype only.
        dtype = dtype.dtype

        return ty.TensorType(shape, dtype)

    def visitTupleType(self, ctx):
        # type: (RelayParser.TupleTypeContext) -> ty.TupleType
        return ty.TupleType(self.visit_list(ctx.type_()))

    def visitFuncType(self, ctx):
        # type: (RelayParser.FuncTypeContext) -> ty.FuncType
        """Build a FuncType: the last visited type is the return type."""
        types = self.visit_list(ctx.type_())

        arg_types = types[:-1]
        ret_type = types[-1]

        return ty.FuncType(arg_types, ret_type, [], None)

def make_parser(data):
    # type: (str) -> RelayParser
    """Construct a RelayParser for the given program text.

    Both the lexer and the parser get their own StrictErrorListener, so
    any reported error aborts parsing with the program text attached.
    """
    lexer = RelayLexer(InputStream(data))
    lexer.addErrorListener(StrictErrorListener(data))

    parser = RelayParser(CommonTokenStream(lexer))
    parser.addErrorListener(StrictErrorListener(data))
    return parser

# Numeric suffix `fromtext` uses when it has to synthesize a default
# source name ("source_file<N>") for a program parsed without one.
__source_name_counter__ = 0

class StrictErrorListener(ErrorListener):
    """An ErrorListener that fails eagerly on every report and includes the program.

    Parameters
    ----------
    text : str
        The program being parsed; it is embedded in every raised message.
    """
    def __init__(self, text):
        super(StrictErrorListener, self).__init__()
        self.text = text

    def syntaxError(self, recognizer, offendingSymbol, line, column, msg, e):
        # Keep the historical "Syntax Error in:\n<program>" prefix and append
        # the position and ANTLR's own message, which used to be discarded.
        raise Exception("Syntax Error in:\n" + self.text +
                        "\nline {}:{} {}".format(line, column, msg))

    def reportAmbiguity(self, recognizer, dfa, startIndex, stopIndex, exact, ambigAlts, configs):
        raise Exception("Ambiguity Error in:\n" + self.text)

    def reportAttemptingFullContext(self,
                                    recognizer,
                                    dfa,
                                    startIndex,
                                    stopIndex,
                                    conflictingAlts,
                                    configs):
        raise Exception("Attempting Full Context in:\n" + self.text)

    def reportContextSensitivity(self, recognizer, dfa, startIndex, stopIndex, prediction, configs):
        raise Exception("Context Sensitivity in:\n" + self.text)

def fromtext(data, source_name=None):
    # type: (str, str) -> Union[expr.Expr, module.Module]
    """Parse a Relay program.

    Parameters
    ----------
    data : str
        The program text. Must not be empty.
    source_name : str, optional
        Name attached to the spans of the parsed program. A string is
        wrapped in a SourceName; when omitted, a fresh
        "source_file<N>" name is generated.

    Returns
    -------
    The value produced by visiting the program's parse tree with
    ParseTreeToRelayIR.

    Raises
    ------
    ParseError
        If `data` is the empty string.
    """
    if data == "":
        raise ParseError("Cannot parse the empty string.")

    global __source_name_counter__

    if source_name is None:
        source_name = "source_file{0}".format(__source_name_counter__)
        # Advance the counter so the next anonymous program gets a distinct
        # name; it was previously never incremented.
        __source_name_counter__ += 1

    if isinstance(source_name, str):
        source_name = SourceName(source_name)

    tree = make_parser(data).prog()
    return ParseTreeToRelayIR(source_name).visit(tree)