# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

# pylint: disable=invalid-name, unused-argument
"""A parser for Relay's text format."""
from __future__ import absolute_import

import sys
from ast import literal_eval
from typing import Any, Deque, Dict, List, Optional, TypeVar, Tuple, Union
from collections import deque

import tvm

from . import module
from .base import Span, SourceName
from . import adt
from . import expr
from . import ty
from . import op

# Major version of the running interpreter; interpolated into the pip hint below.
PYTHON_VERSION = sys.version_info.major
# Generated lexer/parser/visitor classes. A failed import is reported as a
# build problem (the message points at the USE_ANTLR build option).
try:
    from .grammar.py3.RelayVisitor import RelayVisitor
    from .grammar.py3.RelayParser import RelayParser
    from .grammar.py3.RelayLexer import RelayLexer
except ImportError:
    raise Exception("Couldn't find ANTLR parser. Try building with USE_ANTLR=ON.")

# ANTLR runtime pieces used by make_parser and StrictErrorListener.
try:
    from antlr4 import InputStream, CommonTokenStream
    from antlr4.error.ErrorListener import ErrorListener
except ImportError:
    # The first literal ends with a space so that the two sentences are not
    # glued together ("runtime.Try running") in the final message.
    raise Exception("Couldn't find ANTLR runtime. " +
                    "Try running `pip{version} install antlr4-python{version}-runtime`."
                    .format(version=PYTHON_VERSION))

# Raise the interpreter's recursion limit for the whole process.
# NOTE(review): presumably needed because visiting deeply nested parse trees
# recurses heavily -- confirm.
sys.setrecursionlimit(10000)

class ParseError(Exception):
    """Exception type for parse errors.

    The description is stored in `message`; both `str()` and `repr()` render
    the error as ``ParseError(<message>)``.
    """

    def __init__(self, message: str) -> None:
        # Forward the message to Exception so that `args` is populated
        # instead of being left empty.
        super(ParseError, self).__init__(message)
        self.message = message

    def __repr__(self):
        return "ParseError({})".format(self.message)

    def __str__(self):
        return repr(self)

class OpWrapper:
    """Marker base class for objects that overload __call__ for an op.

    `ParseTreeToRelayIR.call` checks for this type to decide whether to
    invoke the wrapper directly.
    """

class ExprOp(OpWrapper):
    """Call an expr. The default, but does not handle attrs well.

    Wraps `operator` so that calling the wrapper builds an `expr.Call`.
    """
    def __init__(self, operator):
        self.operator = operator

    def __call__(self, args, attrs, type_args):
        """Build `expr.Call(operator, args, attrs, type_args)`.

        Any failure while constructing the call is re-raised as a generic
        Exception naming the operator and its attributes; the original error
        is chained as the cause so it is not lost.
        """
        try:
            return expr.Call(self.operator, args, attrs, type_args)
        except Exception as err:
            raise Exception("Operator {} is not registered. Its attributes are {}"
                            .format(self.operator, attrs)) from err

class FuncOp(OpWrapper):
    """Call a Python operator function, passing the attrs as keyword arguments.

    Attribute values are first converted to plain Python values (see `convert`).
    Tvm should provide this in the future, as this is pretty similar to what
    op.get is providing.
    """
    def __init__(self, operator):
        self.operator = operator

    def convert(self, v):
        """Turn an attribute value into a plain Python value.

        Tuples are converted element-wise, Relay constants are unwrapped to
        the Python scalar they hold, and strings pass through unchanged.
        Anything else raises an Exception carrying the offending value.
        """
        if isinstance(v, tuple):
            return tuple(self.convert(item) for item in v)
        if isinstance(v, expr.Constant):
            return v.data.asnumpy().item()
        if isinstance(v, str):
            return v
        raise Exception(v)

    def __call__(self, args, attrs, type_args):
        """Invoke the wrapped function with `args` and the converted attrs.

        `type_args` is accepted for signature parity with ExprOp but unused.
        """
        kwargs = {}
        if attrs is not None:
            for key, value in attrs.items():
                kwargs[key] = self.convert(value)
        result = self.operator(*args, **kwargs)
        # Multi-output ops come back wrapped; hand out the plain tuple expr.
        if isinstance(result, expr.TupleWrapper):
            return result.astuple()
        return result

# Maps the token type of a binary operator (as read from `ctx.op.type` in
# `visitBinOp`) to the Relay op function that builds the corresponding call.
BINARY_OPS = {
    RelayParser.MUL: op.multiply,
    RelayParser.DIV: op.divide,
    RelayParser.ADD: op.add,
    RelayParser.SUB: op.subtract,
    RelayParser.LT:  op.less,
    RelayParser.GT:  op.greater,
    RelayParser.LE:  op.less_equal,
    RelayParser.GE:  op.greater_equal,
    RelayParser.EQ:  op.equal,
    RelayParser.NE:  op.not_equal,
}

# Operator names that are called through their Python wrapper function
# (via `FuncOp`) rather than through a raw `expr.Call` (via `ExprOp`).
FUNC_OPS = {
    "nn.conv2d": op.nn.conv2d,
    "nn.batch_norm": op.nn.batch_norm,
    "nn.dense": op.nn.dense,
    "nn.bias_add": op.nn.bias_add,
    "nn.max_pool2d": op.nn.max_pool2d,
    "nn.global_max_pool2d": op.nn.global_max_pool2d,
    "nn.avg_pool2d": op.nn.avg_pool2d,
    "nn.global_avg_pool2d": op.nn.global_avg_pool2d,
    "nn.softmax": op.nn.softmax,
    "reshape": op.reshape,
    "nn.conv2d_transpose": op.nn.conv2d_transpose,
    "concatenate": op.concatenate,
    "nn.dropout": op.nn.dropout_raw,
    "zeros": op.zeros,
    "split": op.split,
    "cast": op.cast
}

# Identifier prefixes that `visitGeneralIdent` treats as scalar type names:
# any identifier starting with one of these is passed to `ty.scalar_type`.
TYPE_PREFIXES = [
    "int",
    "uint",
    "float",
    "bool",
]

# A scope is a deque of (name, value) pairs; scopes are stacked in a deque.
# `lookup` walks both in iteration order and returns the first match.
T = TypeVar("T")
Scope = Deque[Tuple[str, T]]
Scopes = Deque[Scope[T]]

def lookup(scopes: Scopes[T], name: str) -> Optional[T]:
    """Return the value bound to `name` in `scopes`, or None when unbound.

    Scopes are searched in iteration order, and entries within a scope in
    iteration order; the first entry whose key equals `name` wins.
    """
    matches = (val for scope in scopes for key, val in scope if key == name)
    return next(matches, None)

def spanify(f):
    """A decorator which attaches span information
       to the value returned by calling `f`.

       Intended for use with the below AST visiting
       methods. The idea is that after we do the work
       of constructing the AST we attach Span information.

       The wrapper keeps `f`'s name and docstring (via functools.wraps).
       A TupleWrapper result is converted to its tuple expr before the
       span is set.
    """
    # Local import keeps this block self-contained.
    import functools

    @functools.wraps(f)
    def _wrapper(*args, **kwargs):
        # Assumes 0th arg is self and gets source_name from object.
        sn = args[0].source_name
        # Assumes 1st arg is an ANTLR parser context.
        ctx = args[1]
        ast = f(*args, **kwargs)
        # NOTE(review): the pair from getSourceInterval() is used as
        # (line, col) here; in the ANTLR Python runtime it appears to be a
        # token-index interval -- confirm before relying on span positions.
        line, col = ctx.getSourceInterval()
        sp = Span(sn, line, col)
        if isinstance(ast, tvm.relay.expr.TupleWrapper):
            ast = ast.astuple()
        ast.set_span(sp)
        return ast
    return _wrapper

# TODO(@jmp): Use https://stackoverflow.com/q/13889941
# to figure out how to get ANTLR4 to be more unhappy about syntax errors
class ParseTreeToRelayIR(RelayVisitor):
    """Parse Relay text format into Relay IR."""

    def __init__(self, source_name: str) -> None:
        """Create an empty module plus the scope tables used while visiting.

        `source_name` is stored on the instance; the `spanify` decorator
        reads it when attaching spans to visited nodes.
        """
        self.source_name = source_name
        self.module = module.Module({})  # type: module.Module

        # Adding an empty scope allows naked lets without pain.
        self.var_scopes = deque([deque()])       # type: Scopes[expr.Var]
        # Globals are kept in plain dicts keyed by name, not in scope deques.
        self.global_vars = {}                    # type: Dict[str, expr.GlobalVar]
        self.type_var_scopes = deque([deque()])  # type: Scopes[ty.TypeVar]
        # Holds both global type vars and ADT constructors (see
        # mk_global_typ_var / mk_global_typ_cons).
        self.global_type_vars = {}               # type: Dict[str, Any]
        self.graph_expr = []                     # type: List[expr.Expr]
        # Metadata table read by visitMeta; visitProg fills it in when the
        # program carries a METADATA section. Defined here so the attribute
        # always exists.
        self.meta = None

        super(ParseTreeToRelayIR, self).__init__()


    def enter_var_scope(self) -> None:
        """Push a fresh, empty Var scope; it becomes the innermost scope."""
        fresh_scope = deque()
        self.var_scopes.appendleft(fresh_scope)

    def exit_var_scope(self) -> Scope[expr.Var]:
        """Remove the innermost Var scope and hand it back to the caller."""
        innermost = self.var_scopes.popleft()
        return innermost

    def mk_var(self, name: str, typ: ty.Type = None):
        """Build a Var named `name` and register it in the innermost Var scope."""
        new_var = expr.Var(name, typ)
        innermost = self.var_scopes[0]
        # Newest bindings go to the front, so `lookup` meets them first.
        innermost.appendleft((name, new_var))
        return new_var

    def mk_global_var(self, name: str) -> expr.GlobalVar:
        """Build a GlobalVar named `name` and record it in the global table.

        Raises ParseError when a global with that name already exists.
        """
        if name not in self.global_vars:
            new_global = expr.GlobalVar(name)
            self.global_vars[name] = new_global
            return new_global
        raise ParseError(f"duplicate global var \"{name}\"")

    def enter_type_param_scope(self) -> None:
        """Push a fresh, empty TypeVar scope; it becomes the innermost scope."""
        fresh_scope = deque()
        self.type_var_scopes.appendleft(fresh_scope)

    def exit_type_param_scope(self) -> Scope[ty.TypeVar]:
        """Remove the innermost TypeVar scope and hand it back to the caller."""
        innermost = self.type_var_scopes.popleft()
        return innermost

    def mk_typ(self, name: str, kind: ty.Kind) -> ty.TypeVar:
        """Build a TypeVar and register it in the innermost TypeVar scope."""
        type_var = ty.TypeVar(name, kind)
        innermost = self.type_var_scopes[0]
        # Appended at the back, so the scope keeps declaration order.
        innermost.append((name, type_var))
        return type_var

    def mk_global_typ_var(self, name: str, kind: ty.Kind) -> ty.GlobalTypeVar:
        """Build a GlobalTypeVar and record it in the global type table.

        Raises ParseError (via `_check_existing_typ_expr`) when `name` is
        already taken in that table.
        """
        global_typ = ty.GlobalTypeVar(name, kind)
        self._check_existing_typ_expr(name, global_typ)
        self.global_type_vars[name] = global_typ
        return global_typ

    # TODO(weberlo): rethink whether we should have type constructors mixed with type vars.
    def mk_global_typ_cons(self, name, cons):
        """Record the constructor `cons` under `name` in the global type table.

        Raises ParseError (via `_check_existing_typ_expr`) on a name clash.
        """
        self._check_existing_typ_expr(name, cons)
        table = self.global_type_vars
        table[name] = cons

    def _check_existing_typ_expr(self, name, new_expr):
        """Raise ParseError if `name` is already in the global type table.

        The message describes both the new and the existing entry.
        """
        if name not in self.global_type_vars:
            return
        new_typ_name = self._type_expr_name(new_expr)
        existing_typ_name = self._type_expr_name(self.global_type_vars[name])
        raise ParseError(
            f"{new_typ_name} `{name}` conflicts with existing {existing_typ_name}")

    def _type_expr_name(self, e):
        """Return a short description of `e` for use in error messages."""
        if isinstance(e, adt.Constructor):
            return f"`{e.belong_to.var.name}` ADT constructor"
        is_adt_handle = isinstance(e, ty.GlobalTypeVar) and e.kind == ty.Kind.AdtHandle
        return "ADT definition" if is_adt_handle else "function definition"

    def visitProjection(self, ctx):
        """Build a TupleGetItem from the visited expression and NAT index."""
        tuple_value = self.visit(ctx.expr())
        index = self.visit(ctx.NAT())
        return expr.TupleGetItem(tuple_value, index)

    def visitTerminal(self, node) -> Union[expr.Expr, int, float]:
        """Convert a lexer token into the Python value it denotes.

        Handles NAT, FLOAT, BOOL_LIT and QUOTED_STRING tokens; any other
        token type raises ParseError.
        """
        token_type = node.getSymbol().type
        text = node.getText()

        if token_type == RelayLexer.NAT:
            return int(text)
        if token_type == RelayLexer.FLOAT:
            # The last character of the token text is dropped before conversion.
            return float(text[:-1])
        if token_type == RelayLexer.BOOL_LIT:
            bool_values = {"True": True, "False": False}
            if text not in bool_values:
                raise ParseError("unrecognized BOOL_LIT: `{}`".format(text))
            return bool_values[text]
        if token_type == RelayLexer.QUOTED_STRING:
            return literal_eval(text)
        raise ParseError(f"unhandled terminal \"{text}\" of type `{token_type}`")

    def visitGeneralIdent(self, ctx):
        """Resolve a general identifier.

        Tried in order: scalar type name (by prefix), local type var, global
        type var / ADT constructor, and finally an operator name.
        """
        name = ctx.getText()
        # Any identifier starting with a known type prefix is a scalar type.
        if name.startswith(tuple(TYPE_PREFIXES)):
            return ty.scalar_type(name)

        # Local type params take precedence over global ones.
        type_expr = lookup(self.type_var_scopes, name)
        if type_expr is None:
            type_expr = self.global_type_vars.get(name, None)

        if type_expr is None:
            # Not a type: treat the dotted CNAME sequence as an operator name.
            op_name = ".".join([part.getText() for part in ctx.CNAME()])
            if op_name in FUNC_OPS:
                return FuncOp(FUNC_OPS[op_name])
            return ExprOp(op.get(op_name))

        # Zero-arity constructor calls fall into the general ident case, so
        # they are turned into a constructor call with no args.
        if isinstance(type_expr, adt.Constructor) and not type_expr.inputs:
            return expr.Call(type_expr, [])
        return type_expr

    def visitGlobalVar(self, ctx):
        """Return the GlobalVar for the CNAME; ParseError when unbound."""
        var_name = ctx.CNAME().getText()
        found = self.global_vars.get(var_name)
        if found is not None:
            return found
        raise ParseError(f"unbound global var `{var_name}`")

    def visitLocalVar(self, ctx):
        """Return the Var bound to the CNAME in the Var scopes; ParseError when unbound."""
        var_name = ctx.CNAME().getText()
        found = lookup(self.var_scopes, var_name)
        if found is not None:
            return found
        raise ParseError(f"unbound local var `{var_name}`")

    def visitGraphVar(self, ctx):
        """Return the previously recorded graph expression at the NAT index."""
        index = int(ctx.NAT().getText())
        return self.graph_expr[index]

    def visit_list(self, ctx_list) -> List[Any]:
        """Visit every context in `ctx_list` and return the results in order."""
        assert isinstance(ctx_list, list)

        results = []
        for item in ctx_list:
            results.append(self.visit(item))
        return results

    def getTypeExpr(self, ctx: Optional[RelayParser.TypeExprContext]) -> Optional[ty.Type]:
        """Visit an optional type context; a missing context yields None."""
        return None if ctx is None else self.visit(ctx)

    def visitProg(self, ctx: RelayParser.ProgContext) -> Union[expr.Expr, module.Module]:
        """Visit a whole program.

        Loads the optional METADATA section into `self.meta`, then returns
        the module when the program has definitions, the visited expression
        when it has one, and the (empty) module otherwise.
        """
        self.meta = None
        metadata = ctx.METADATA()
        if metadata:
            # First line is the section header, the rest is the JSON payload.
            header, data = str(metadata).split("\n", 1)
            assert header == "METADATA:"
            self.meta = tvm.load_json(data)

        definitions = ctx.defn()
        if definitions:
            self.visit_list(definitions)
            return self.module

        body = ctx.expr()
        if body:
            return self.visit(body)

        return self.module

    # Exprs
    def visitOpIdent(self, ctx) -> op.Op:
        """Resolve a dotted operator name to a callable op wrapper.

        Names listed in FUNC_OPS get a FuncOp; everything else is looked up
        with `op.get` and wrapped in an ExprOp.
        """
        parts = [cname.getText() for cname in ctx.CNAME()]
        op_name = ".".join(parts)
        if op_name not in FUNC_OPS:
            return ExprOp(op.get(op_name))
        return FuncOp(FUNC_OPS[op_name])

    # pass through
    def visitParen(self, ctx: RelayParser.ParenContext) -> expr.Expr:
        """Parentheses add no node: return the visited inner expression."""
        inner = ctx.expr()
        return self.visit(inner)

    # pass through
    def visitTypeParen(self, ctx: RelayParser.TypeParenContext) -> expr.Expr:
        """Parentheses add no node: return the visited inner type expression."""
        inner = ctx.typeExpr()
        return self.visit(inner)

    # pass through
    def visitBody(self, ctx: RelayParser.BodyContext) -> expr.Expr:
        """A body is just its expression: return the visited inner expression."""
        inner = ctx.expr()
        return self.visit(inner)

    def visitScalarFloat(self, ctx: RelayParser.ScalarFloatContext) -> expr.Constant:
        """Build a Relay constant from a FLOAT token."""
        value = self.visit(ctx.FLOAT())
        return expr.const(value)

    def visitScalarInt(self, ctx: RelayParser.ScalarIntContext) -> expr.Constant:
        """Build a Relay constant from a NAT token."""
        value = self.visit(ctx.NAT())
        return expr.const(value)

    def visitScalarBool(self, ctx: RelayParser.ScalarBoolContext) -> expr.Constant:
        """Build a Relay constant from a BOOL_LIT token."""
        value = self.visit(ctx.BOOL_LIT())
        return expr.const(value)

    def visitNeg(self, ctx: RelayParser.NegContext) -> Union[expr.Constant, expr.Call]:
        """Negate an expression, folding the sign into scalar constants."""
        operand = self.visit(ctx.expr())
        is_scalar_const = (isinstance(operand, expr.Constant)
                           and operand.data.asnumpy().ndim == 0)
        if not is_scalar_const:
            return op.negative(operand)
        # Scalar constant: emit a new constant holding the negated value.
        return expr.const(-operand.data.asnumpy().item())

    def visitTuple(self, ctx: RelayParser.TupleContext) -> expr.Tuple:
        """Build a Relay Tuple from the visited element expressions."""
        return expr.Tuple(self.visit_list(ctx.expr()))

    def visitLet(self, ctx: RelayParser.LetContext) -> expr.Let:
        """Desugar various sequence constructs to Relay Let nodes.

        When no variable is given, an untyped variable named "_" is bound.
        """
        var_ctx = ctx.var()
        if var_ctx is not None:
            var = self.visitVar(var_ctx)
        else:
            # No binder given: bind the value to an untyped "_" variable.
            var = self.mk_var("_", None)

        # The bound value is visited inside its own Var scope.
        self.enter_var_scope()
        value = self.visit(ctx.expr(0))
        self.exit_var_scope()

        body = self.visit(ctx.expr(1))

        return expr.Let(var, value, body)

    def visitBinOp(self, ctx: RelayParser.BinOpContext) -> expr.Call:
        """Desugar a binary operator into the matching Relay op call.

        Raises ParseError when the operator token has no entry in BINARY_OPS.
        """
        lhs, rhs = self.visit_list(ctx.expr())
        relay_op = BINARY_OPS.get(ctx.op.type)

        if relay_op is not None:
            return relay_op(lhs, rhs)

        raise ParseError("unimplemented binary op.")

    @spanify
    def visitVar(self, ctx: RelayParser.VarContext) -> expr.Var:
        """Visit a single variable and register it in the current Var scope.

        Raises ParseError when the variable is not a local identifier.
        """
        local = ctx.localVar()

        if local is None:
            raise ParseError("only local ids may be used in vars.")

        var_type = self.getTypeExpr(ctx.typeExpr())
        # The first character of the token text is not part of the name.
        var_name = local.getText()[1:]

        return self.mk_var(var_name, var_type)

    def visitVarList(self, ctx: RelayParser.VarListContext) -> List[expr.Var]:
        """Visit every variable in the list and return the results in order."""
        var_ctxs = ctx.var()
        return self.visit_list(var_ctxs)

    # TODO: support a larger class of values than just Relay exprs
    def visitAttr(self, ctx: RelayParser.AttrContext) -> Tuple[str, expr.Expr]:
        """Return a (name, value) pair for one attribute."""
        attr_name = ctx.CNAME().getText()
        attr_value = self.visit(ctx.expr())
        return (attr_name, attr_value)

    def visitArgNoAttr(self, ctx: RelayParser.ArgNoAttrContext):
        """Return (visited vars, None) for an argument list without attrs."""
        var_ctxs = ctx.varList().var()
        return (self.visit_list(var_ctxs), None)

    def visitAttrSeq(self, ctx: RelayParser.AttrSeqContext) -> Dict[str, expr.Expr]:
        """Collect the visited (name, value) attribute pairs into a dict."""
        pairs = self.visit_list(ctx.attr())
        return dict(pairs)

    def visitArgWithAttr(self, ctx: RelayParser.AttrSeqContext) \
        -> Tuple[List[expr.Var], Dict[str, expr.Expr]]:
        """Return (visited vars, attrs dict) for an argument list with attrs."""
        variables = self.visit_list(ctx.var())
        attrs = self.visitAttrSeq(ctx.attrSeq())
        return (variables, attrs)

    def visitArgList(self, ctx: RelayParser.ArgListContext) \
            -> Tuple[Optional[List[expr.Var]], Optional[Dict[str, expr.Expr]]]:
        """Return (vars, attrs); each part is None when absent from the context."""
        var_list = None
        if ctx.varList():
            var_list = self.visit(ctx.varList())
        attr_list = None
        if ctx.attrList():
            attr_list = self.visit(ctx.attrList())
        return (var_list, attr_list)

    def visitMeta(self, ctx: RelayParser.MetaContext):
        """Return the entry of `self.meta` selected by the CNAME key and NAT index."""
        type_key = str(ctx.CNAME())
        index = int(self.visit(ctx.NAT()))
        entries = self.meta[type_key]
        return entries[index]

    def mk_func(
            self,
            ctx: Union[RelayParser.FuncContext, RelayParser.DefnContext]) \
            -> expr.Function:
        """Construct a function from either a Func or Defn.

        Opens a Var scope and a type-param scope, registers the declared type
        params, visits arguments, return type and body, then closes both
        scopes and assembles the `expr.Function`.
        """
        # The Var scope is opened first so the params are in scope for the body.
        self.enter_var_scope()
        # Type params declared by this function are collected in their own scope.
        self.enter_type_param_scope()

        type_param_list = ctx.typeParamList()
        if type_param_list is not None:
            type_param_ctxs = type_param_list.typeExpr()
            assert type_param_ctxs
            for type_param_ctx in type_param_ctxs:
                self.mk_typ(type_param_ctx.getText(), ty.Kind.Type)

        var_list, attr_list = self.visit(ctx.argList())
        if var_list is None:
            var_list = []
        ret_type = self.getTypeExpr(ctx.typeExpr())

        body = self.visit(ctx.body())
        # NB(@jroesch): you must stay in the type parameter scope until
        # after you exit the body, you can reference the type parameters
        # of your parent scopes.
        declared = list(self.exit_type_param_scope())
        if declared:
            # Keep only the TypeVar of each (name, TypeVar) pair.
            type_params = tuple(type_var for _, type_var in declared)
        else:
            type_params = declared
        self.exit_var_scope()

        if attr_list is None:
            attrs = None
        else:
            attrs = tvm.make.node("DictAttrs", **attr_list)
        return expr.Function(var_list, body, ret_type, type_params, attrs)

    @spanify
    def visitFunc(self, ctx: RelayParser.FuncContext) -> expr.Function:
        """Build a function via `mk_func`; `spanify` attaches the span."""
        func = self.mk_func(ctx)
        return func

    # TODO: how to set spans for definitions?
    # @spanify
    def visitFuncDefn(self, ctx: RelayParser.DefnContext) -> None:
        """Register a global function definition in the module being built."""
        # The first character of the token text is not part of the name.
        global_name = ctx.globalVar().getText()[1:]
        # The GlobalVar is created before the body is visited.
        global_var = self.mk_global_var(global_name)
        self.module[global_var] = self.mk_func(ctx)

    def handle_adt_header(
            self,
            ctx: Union[RelayParser.ExternAdtDefnContext, RelayParser.AdtDefnContext]):
        """Parse the name and type params of an ADT definition.

        Returns the new global type var for the ADT together with the list of
        TypeVars created for its type params (empty when there are none).
        """
        adt_name = ctx.generalIdent().getText()
        adt_var = self.mk_global_typ_var(adt_name, ty.Kind.AdtHandle)

        type_params = []
        type_param_list = ctx.typeParamList()
        if type_param_list is not None:
            for type_ident in type_param_list.typeExpr():
                type_params.append(self.mk_typ(type_ident.getText(), ty.Kind.Type))
        return adt_var, type_params

    def visitExternAdtDefn(self, ctx: RelayParser.ExternAdtDefnContext):
        """Register an extern ADT as a TypeData with no constructors."""
        # TODO(weberlo): update this handler once extern is implemented
        self.enter_type_param_scope()
        adt_var, type_params = self.handle_adt_header(ctx)
        # Update the module being built.
        type_data = adt.TypeData(adt_var, type_params, [])
        self.module[adt_var] = type_data
        self.exit_type_param_scope()

    def visitAdtDefn(self, ctx: RelayParser.AdtDefnContext):
        """Register an ADT and its constructors in the module being built."""
        self.enter_type_param_scope()
        adt_var, type_params = self.handle_adt_header(ctx)

        # Collect the constructor definitions, if any.
        cons_list = ctx.adtConsDefnList()
        cons_ctxs = [] if cons_list is None else cons_list.adtConsDefn()

        parsed_constructors = []
        for cons_ctx in cons_ctxs:
            inputs = [self.visit(inp) for inp in cons_ctx.typeExpr()]
            cons_name = cons_ctx.constructorName().getText()
            constructor = adt.Constructor(cons_name, inputs, adt_var)
            self.mk_global_typ_cons(cons_name, constructor)
            parsed_constructors.append(constructor)

        # Update the module being built.
        self.module[adt_var] = adt.TypeData(adt_var, type_params, parsed_constructors)
        self.exit_type_param_scope()

    def visitMatch(self, ctx: RelayParser.MatchContext):
        """Build an adt.Match from a match expression.

        "match" yields a complete match and "match?" an incomplete one; any
        other match type raises RuntimeError. Each clause is visited in its
        own Var scope.
        """
        match_type = ctx.matchType().getText()
        completeness = {"match": True, "match?": False}
        if match_type not in completeness:
            raise RuntimeError(f"unknown match type {match_type}")
        complete_match = completeness[match_type]

        match_data = self.visit(ctx.expr())

        clause_list = ctx.matchClauseList()
        clause_ctxs = [] if clause_list is None else clause_list.matchClause()

        parsed_clauses = []
        for clause in clause_ctxs:
            # Pattern variables are only visible inside their own clause.
            self.enter_var_scope()
            pattern = self.visit(clause.pattern())
            clause_body = self.visit(clause.expr())
            self.exit_var_scope()
            parsed_clauses.append(adt.Clause(pattern, clause_body))
        return adt.Match(match_data, parsed_clauses, complete=complete_match)

    def visitWildcardPattern(self, ctx: RelayParser.WildcardPatternContext):
        """Build a wildcard pattern; `ctx` is not consulted."""
        wildcard = adt.PatternWildcard()
        return wildcard

    def visitVarPattern(self, ctx: RelayParser.VarPatternContext):
        """Build a variable pattern and register the new Var in the current scope."""
        text = ctx.localVar().getText()
        type_ctx = ctx.typeExpr()
        typ = None if type_ctx is None else self.visit(type_ctx)
        # The first character of the token text is not part of the name.
        bound_var = self.mk_var(text[1:], typ=typ)
        return adt.PatternVar(bound_var)

    def visitConstructorPattern(self, ctx: RelayParser.ConstructorPatternContext):
        """Build a constructor pattern from the named constructor and its sub-patterns.

        The constructor is read from the global type table by name.
        """
        constructor = self.global_type_vars[ctx.constructorName().getText()]
        patterns = []
        pattern_list = ctx.patternList()
        if pattern_list is not None:
            patterns = [self.visit(sub) for sub in pattern_list.pattern()]
        return adt.PatternConstructor(constructor, patterns)

    def visitTuplePattern(self, ctx: RelayParser.TuplePatternContext):
        """Build a tuple pattern from the visited sub-patterns."""
        sub_patterns = []
        for sub in ctx.patternList().pattern():
            sub_patterns.append(self.visit(sub))
        return adt.PatternTuple(sub_patterns)

    def visitCallNoAttr(self, ctx: RelayParser.CallNoAttrContext):
        """Return (visited args, None) for a call without attrs."""
        arg_ctxs = ctx.exprList().expr()
        return (self.visit_list(arg_ctxs), None)

    def visitCallWithAttr(self, ctx: RelayParser.CallWithAttrContext):
        """Return (visited args, visited attrs) for a call with attrs."""
        call_args = self.visit_list(ctx.expr())
        call_attrs = self.visit(ctx.attrSeq())
        return (call_args, call_attrs)

    def call(self, func, args, attrs, type_args):
        """Apply `func` to `args` in the way appropriate for its kind.

        Op wrappers are invoked directly, ADT constructors are called with
        the positional args only, and anything else becomes an `expr.Call`.
        """
        if isinstance(func, OpWrapper):
            return func(args, attrs, type_args)
        if isinstance(func, adt.Constructor):
            return func(*args)
        return expr.Call(func, args, attrs, type_args)

    @spanify
    def visitCall(self, ctx: RelayParser.CallContext) -> expr.Call:
        """Visit the callee and its call list, then build the call (no type args)."""
        callee = self.visit(ctx.expr())
        args, attrs = self.visit(ctx.callList())
        return self.call(callee, args, attrs, [])

    @spanify
    def visitIfElse(self, ctx: RelayParser.IfElseContext) -> expr.If:
        """Construct a Relay If node. Creates a new scope for each branch."""
        cond = self.visit(ctx.expr())

        branches = []
        for branch_index in (0, 1):
            # Each branch body is visited inside its own Var scope.
            self.enter_var_scope()
            branches.append(self.visit(ctx.body(branch_index)))
            self.exit_var_scope()

        true_branch, false_branch = branches
        return expr.If(cond, true_branch, false_branch)

    @spanify
    def visitGraph(self, ctx: RelayParser.GraphContext) -> expr.Expr:
        """Visit a graph variable assignment.

        The bound value is visited in its own Var scope and appended to
        `self.graph_expr`; the continuation expression is then visited and
        returned. Raises ParseError when the graph variable's number is not
        the next free index.
        """
        # The first character of the token text is not part of the number.
        graph_nid = int(ctx.graphVar().getText()[1:])

        self.enter_var_scope()
        value = self.visit(ctx.expr(0))
        self.exit_var_scope()

        # Graph variables must be introduced in order: 0, 1, 2, ...
        expected_nid = len(self.graph_expr)
        if graph_nid != expected_nid:
            raise ParseError(
                "expected new graph variable to be `%{}`, but got `%{}`"
                .format(expected_nid, graph_nid))
        self.graph_expr.append(value)

        kont = self.visit(ctx.expr(1))
        return kont

    # Types

    # pylint: disable=unused-argument
    def visitIncompleteType(self, ctx: RelayParser.IncompleteTypeContext) -> None:
        """An incomplete type carries no information: always yields None."""
        return

    def visitTypeCallType(self, ctx: RelayParser.TypeCallTypeContext):
        """Build a TypeCall from the visited identifier and type arguments."""
        func = self.visit(ctx.generalIdent())
        args = []
        for arg_ctx in ctx.typeParamList().typeExpr():
            args.append(self.visit(arg_ctx))
        return ty.TypeCall(func, args)

    def visitParensShape(self, ctx: RelayParser.ParensShapeContext) -> int:
        """Parentheses add nothing: return the visited inner shape."""
        inner = ctx.shape()
        return self.visit(inner)

    def visitShapeList(self, ctx: RelayParser.ShapeListContext) -> List[int]:
        """Visit every shape entry and return the results in order."""
        shape_ctxs = ctx.shape()
        return self.visit_list(shape_ctxs)

    def visitTensor(self, ctx: RelayParser.TensorContext):
        """Return the visited element expressions as a Python tuple."""
        elements = self.visit_list(ctx.expr())
        return tuple(elements)

    def visitTensorType(self, ctx: RelayParser.TensorTypeContext) -> ty.TensorType:
        """Create a simple tensor type. No generics.

        Raises ParseError when the element type is not itself a TensorType.
        """
        shape = self.visit(ctx.shapeList())
        elem_type = self.visit(ctx.typeExpr())

        if isinstance(elem_type, ty.TensorType):
            # Only the dtype of the visited element type is carried over.
            return ty.TensorType(shape, elem_type.dtype)

        raise ParseError("expected dtype to be a Relay base type.")

    def visitTupleType(self, ctx: RelayParser.TupleTypeContext) -> ty.TupleType:
        """Build a TupleType from the visited field types."""
        field_types = self.visit_list(ctx.typeExpr())
        return ty.TupleType(field_types)

    def visitFuncType(self, ctx: RelayParser.FuncTypeContext) -> ty.FuncType:
        """Build a FuncType: the last visited type is the return type, the rest are args."""
        signature = self.visit_list(ctx.typeExpr())

        ret_type = signature[-1]
        arg_types = signature[:-1]

        return ty.FuncType(arg_types, ret_type, [], None)

def make_parser(data: str) -> RelayParser:
    """Construct a RelayParser for the given source text.

    Both the lexer and the parser get a StrictErrorListener, so errors they
    report are raised as exceptions.
    """
    lexer = RelayLexer(InputStream(data))
    lexer.addErrorListener(StrictErrorListener(data))
    parser = RelayParser(CommonTokenStream(lexer))
    parser.addErrorListener(StrictErrorListener(data))
    return parser

# Module-level counter read by `fromtext` to build the default source name
# ("source_file<N>") when the caller does not supply one.
__source_name_counter__ = 0

class StrictErrorListener(ErrorListener):
    """This ErrorListener fails eagerly on every error and reports the program.

    Every callback raises an Exception whose message contains the full
    program text given at construction time.
    """
    def __init__(self, text):
        super(StrictErrorListener, self).__init__()
        self.text = text

    def syntaxError(self, recognizer, offendingSymbol, line, column, msg, e):
        # Keep the original prefix and program text, then add the location
        # and ANTLR's own message so the error can actually be located.
        raise Exception("Syntax Error in:\n" + self.text +
                        "\nat line {}, column {}: {}".format(line, column, msg))

    def reportAmbiguity(self, recognizer, dfa, startIndex, stopIndex, exact, ambigAlts, configs):
        raise Exception("Ambiguity Error in:\n" + self.text)

    def reportAttemptingFullContext(self,
                                    recognizer,
                                    dfa,
                                    startIndex,
                                    stopIndex,
                                    conflictingAlts,
                                    configs):
        raise Exception("Attempting Full Context in:\n" + self.text)

    def reportContextSensitivity(self, recognizer, dfa, startIndex, stopIndex, prediction, configs):
        raise Exception("Context Sensitivity in:\n" + self.text)

def fromtext(data: str, source_name: str = None) -> Union[expr.Expr, module.Module]:
    """Parse a Relay program.

    Parameters
    ----------
    data : str
        The program text. An empty string raises ParseError.

    source_name : str, optional
        Name attached to spans. When omitted, a fresh name of the form
        "source_file<N>" is generated from a module-level counter. A string
        is wrapped in a SourceName; any other value is passed through as is.

    Returns
    -------
    The result of visiting the parsed program with ParseTreeToRelayIR.
    """
    if data == "":
        raise ParseError("cannot parse the empty string.")

    global __source_name_counter__

    if source_name is None:
        source_name = "source_file{0}".format(__source_name_counter__)
        # Advance the counter so each anonymous program gets a distinct name;
        # previously it was never incremented and every name was source_file0.
        __source_name_counter__ += 1

    if isinstance(source_name, str):
        source_name = SourceName(source_name)

    tree = make_parser(data).prog()
    return ParseTreeToRelayIR(source_name).visit(tree)