_parser.py 27.6 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
17

18
# pylint: disable=invalid-name, unused-argument
19 20 21 22
"""A parser for Relay's text format."""
from __future__ import absolute_import

import sys
23
from ast import literal_eval
24
from collections import deque
25

26 27 28 29 30 31 32 33 34 35 36 37 38
from typing import Any, Dict, List, Optional, TypeVar, Tuple, Union, MutableSequence, T
try:
    # typing.Deque only exists from Python 3.5.4 / 3.6.1
    # https://bugs.python.org/issue29011
    from typing import Deque
except ImportError:
    # Older 3.5 releases: fall back to the closest subscriptable generic so that
    # annotations such as ``Deque[Tuple[str, T]]`` further down still evaluate.
    # (The previous fallback referenced names that were never imported when the
    # combined import failed, so it could not work.)
    Deque = MutableSequence  # type: ignore

39
import tvm
40 41

from . import module
42
from .base import Span, SourceName
43
from . import adt
44 45 46 47
from . import expr
from . import ty
from . import op

48 49 50 51 52 53
PYTHON_VERSION = sys.version_info.major
try:
    from .grammar.py3.RelayVisitor import RelayVisitor
    from .grammar.py3.RelayParser import RelayParser
    from .grammar.py3.RelayLexer import RelayLexer
except ImportError as err:
    raise Exception("Couldn't find ANTLR parser. Try building with USE_ANTLR=ON.") from err

try:
    from antlr4 import InputStream, CommonTokenStream
    from antlr4.error.ErrorListener import ErrorListener
except ImportError as err:
    raise Exception("Couldn't find ANTLR runtime. "
                    "Try running `pip{version} install antlr4-python{version}-runtime`."
                    .format(version=PYTHON_VERSION)) from err

# The tree visitor below recurses once per nested parse-tree node, so deeply
# nested programs can exceed Python's default recursion limit.
sys.setrecursionlimit(10000)
65

66 67 68
class ParseError(Exception):
    """Exception type for parse errors.

    Attributes:
        message: human-readable description of what failed to parse.
    """

    def __init__(self, message: str) -> None:
        # Forward the message to Exception so `args` is populated; without it the
        # exception cannot be re-created by pickle (which calls cls(*args)).
        super(ParseError, self).__init__(message)
        self.message = message

    def __repr__(self):
        return "ParseError({})".format(self.message)

    def __str__(self):
        return repr(self)

class OpWrapper:
    """Marker base class for objects that stand in for a Relay operator.

    Concrete subclasses are invoked as ``wrapper(args, attrs, type_args)`` and
    produce the corresponding call expression.
    """

class ExprOp(OpWrapper):
    """Call an expr. The default, but does not handle attrs well."""

    def __init__(self, operator):
        self.operator = operator

    def __call__(self, args, attrs, type_args):
        """Build ``expr.Call(self.operator, args, attrs, type_args)``.

        Raises:
            Exception: if the call node cannot be constructed. The underlying
                error is chained as ``__cause__`` so it is not lost.
        """
        try:
            return expr.Call(self.operator, args, attrs, type_args)
        except Exception as err:
            raise Exception("Operator {} is not registered. Its attributes are {}"
                            .format(self.operator, attrs)) from err
94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117

class FuncOp(OpWrapper):
    """Convert the attrs, call the python function with the attrs passed in as keyword arguments.
    Tvm should provide this in the future, as this is pretty similar to what op.get is providing.
    """
    def __init__(self, operator):
        self.operator = operator

    def convert(self, v):
        """Convert a parsed attribute value into a plain Python value.

        Strings are returned unchanged, tuples are converted element-wise, and
        Relay constants are unwrapped to a Python scalar via numpy ``item()``.

        Raises:
            Exception: for any other kind of value, naming the operator and value.
        """
        if isinstance(v, str):
            return v
        if isinstance(v, tuple):
            return tuple(self.convert(x) for x in v)
        if isinstance(v, expr.Constant):
            return v.data.asnumpy().item()
        raise Exception("unsupported attribute value for operator {}: {!r}"
                        .format(self.operator, v))

    def __call__(self, args, attrs, type_args):
        """Call the wrapped builder with the converted attrs as keyword arguments.

        `type_args` is accepted for interface compatibility with other OpWrappers
        but is not used. A TupleWrapper result is flattened to a Relay tuple.
        """
        if attrs is None:
            attrs = {}
        kwargs = {k: self.convert(v) for k, v in attrs.items()}
        x = self.operator(*args, **kwargs)
        if isinstance(x, expr.TupleWrapper):
            x = x.astuple()
        return x
118 119 120 121 122 123 124 125 126 127 128 129 130 131

# Lexer token type of an infix operator -> Relay builder for that operator.
BINARY_OPS = {
    # arithmetic
    RelayParser.MUL: op.multiply,
    RelayParser.DIV: op.divide,
    RelayParser.ADD: op.add,
    RelayParser.SUB: op.subtract,
    # comparison
    RelayParser.LT: op.less,
    RelayParser.GT: op.greater,
    RelayParser.LE: op.less_equal,
    RelayParser.GE: op.greater_equal,
    RelayParser.EQ: op.equal,
    RelayParser.NE: op.not_equal,
}

132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147
# Operators whose Python builder functions are called directly (through FuncOp),
# so that attributes written in the text format become keyword arguments.
FUNC_OPS = {
    "nn.conv2d": op.nn.conv2d,
    "nn.batch_norm": op.nn.batch_norm,
    "nn.dense": op.nn.dense,
    "nn.bias_add": op.nn.bias_add,
    "nn.max_pool2d": op.nn.max_pool2d,
    "nn.global_max_pool2d": op.nn.global_max_pool2d,
    "nn.avg_pool2d": op.nn.avg_pool2d,
    "nn.global_avg_pool2d": op.nn.global_avg_pool2d,
    "nn.softmax": op.nn.softmax,
    "reshape": op.reshape,
    "nn.conv2d_transpose": op.nn.conv2d_transpose,
    "concatenate": op.concatenate,
    "nn.dropout": op.nn.dropout_raw,
    "zeros": op.zeros,
    "split": op.split,
    "cast": op.cast,
}

151 152 153 154 155 156 157
# Identifier prefixes of scalar dtype names (e.g. `int32`, `uint8`, `float16`,
# `bool`); identifiers built from these are resolved with `ty.scalar_type`.
TYPE_PREFIXES = [
    "int",
    "uint",
    "float",
    "bool",
]

T = TypeVar("T")
# One lexical scope: (name, value) bindings, searched front to back by `lookup`.
Scope = Deque[Tuple[str, T]]
# A stack of scopes; index 0 is the innermost scope.
Scopes = Deque[Scope[T]]
161

162
def lookup(scopes: Scopes[T], name: str) -> Optional[T]:
    """Look up `name` in `scopes`.

    Scopes are searched in iteration order and, within a scope, bindings are
    searched front to back; the first binding whose key equals `name` wins.

    Returns:
        The bound value, or None if `name` is not bound in any scope.
    """
    for scope in scopes:
        for key, val in scope:
            if key == name:
                return val
    return None

171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187
def spanify(f):
    """A decorator which attaches span information
       to the value returned by calling `f`.

       Intended for use with the below AST visiting
       methods. The idea is that after we do the work
       of constructing the AST we attach Span information.

       The span records the line and column of the first token of the ANTLR
       context passed as the decorated method's first positional argument.
    """
    import functools

    @functools.wraps(f)
    def _wrapper(*args, **kwargs):
        # Assumes 0th arg is self and gets source_name from object.
        sn = args[0].source_name
        # Assumes 1st arg is an ANTLR parser context.
        ctx = args[1]
        ast = f(*args, **kwargs)
        # `getSourceInterval()` returns token *indices*, not a (line, column)
        # pair; take the position from the context's first token instead.
        start = ctx.start
        if start is not None:
            line, col = start.line, start.column
        else:
            # No token was matched: keep the previous (invalid-interval) values.
            line, col = ctx.getSourceInterval()
        sp = Span(sn, line, col)
        if isinstance(ast, tvm.relay.expr.TupleWrapper):
            ast = ast.astuple()
        ast.set_span(sp)
        return ast
    return _wrapper

194 195 196 197 198
# TODO(@jmp): Use https://stackoverflow.com/q/13889941
# to figure out how to get ANTLR4 to be more unhappy about syntax errors
class ParseTreeToRelayIR(RelayVisitor):
    """Parse Relay text format into Relay IR.

    An ANTLR visitor over the tree produced by `RelayParser`. Local variables and
    type parameters live in stacks of scopes (index 0 is innermost); global vars,
    global type vars and ADT constructors live in flat dicts; `%N` graph bindings
    are kept in `graph_expr` in definition order.
    """

    def __init__(self, source_name: str) -> None:
        self.source_name = source_name
        self.module = module.Module({})  # type: module.Module

        # Adding an empty scope allows naked lets without pain.
        self.var_scopes = deque([deque()])       # type: Scopes[expr.Var]
        self.global_vars = {}                    # type: Dict[str, expr.GlobalVar]
        self.type_var_scopes = deque([deque()])  # type: Scopes[ty.TypeVar]
        # Holds both ADT handles (ty.GlobalTypeVar) and ADT constructors.
        self.global_type_vars = {}               # type: Dict[str, Any]
        self.graph_expr = []                     # type: List[expr.Expr]
        # Parsed METADATA section, if any; set by visitProg and read by visitMeta.
        self.meta = None

        super(ParseTreeToRelayIR, self).__init__()

    def enter_var_scope(self) -> None:
        """Enter a new Var scope so it can be popped off later."""
        self.var_scopes.appendleft(deque())

    def exit_var_scope(self) -> Scope[expr.Var]:
        """Pop off the current Var scope and return it."""
        return self.var_scopes.popleft()

    def mk_var(self, name: str, typ: ty.Type = None):
        """Create a new Var and add it to the innermost Var scope.

        The binding is pushed on the front of the scope, so a later binding of
        the same name shadows an earlier one during `lookup`.
        """
        var = expr.Var(name, typ)
        self.var_scopes[0].appendleft((name, var))
        return var

    def mk_global_var(self, name: str) -> expr.GlobalVar:
        """Create a new GlobalVar and add it to the GlobalVar scope.

        Raises:
            ParseError: if a global var with this name already exists.
        """
        if name in self.global_vars:
            raise ParseError("duplicate global var \"{0}\"".format(name))
        var = expr.GlobalVar(name)
        self.global_vars[name] = var
        return var

    def enter_type_param_scope(self) -> None:
        """Enter a new TypeVar scope so it can be popped off later."""
        self.type_var_scopes.appendleft(deque())

    def exit_type_param_scope(self) -> Scope[ty.TypeVar]:
        """Pop off the current TypeVar scope and return it."""
        return self.type_var_scopes.popleft()

    def mk_typ(self, name: str, kind: ty.Kind) -> ty.TypeVar:
        """Create a new TypeVar and add it to the innermost TypeVar scope.

        Appended (not prepended) so the scope keeps declaration order; mk_func
        turns that scope into the function's ordered type parameters.
        """
        typ = ty.TypeVar(name, kind)
        self.type_var_scopes[0].append((name, typ))
        return typ

    def mk_global_typ_var(self, name: str, kind: ty.Kind) -> ty.GlobalTypeVar:
        """Create a new GlobalTypeVar and register it in the global type scope.

        Raises:
            ParseError: if `name` is already a global type or ADT constructor.
        """
        typ = ty.GlobalTypeVar(name, kind)
        self._check_existing_typ_expr(name, typ)
        self.global_type_vars[name] = typ
        return typ

    # TODO(weberlo): rethink whether we should have type constructors mixed with type vars.
    def mk_global_typ_cons(self, name, cons):
        """Register ADT constructor `cons` under `name` in the global type scope."""
        self._check_existing_typ_expr(name, cons)
        self.global_type_vars[name] = cons

    def _check_existing_typ_expr(self, name, new_expr):
        """Raise ParseError if `name` is already bound in the global type scope."""
        if name in self.global_type_vars:
            new_typ_name = self._type_expr_name(new_expr)
            existing_typ_name = self._type_expr_name(self.global_type_vars[name])
            raise ParseError(
                "{0} `{1}` conflicts with existing {2}".format(
                    new_typ_name, name, existing_typ_name))

    def _type_expr_name(self, e):
        """Describe the kind of global type expression `e`, for error messages."""
        if isinstance(e, adt.Constructor):
            return "`{0}` ADT constructor".format(e.belong_to.var.name)
        if isinstance(e, ty.GlobalTypeVar) and e.kind == ty.Kind.AdtHandle:
            return "ADT definition"
        return "function definition"

    def visitProjection(self, ctx):
        return expr.TupleGetItem(self.visit(ctx.expr()), self.visit(ctx.NAT()))

    def visitTerminal(self, node) -> Union[expr.Expr, int, float]:
        """Visit lexer tokens that aren't ignored or visited by other functions."""
        node_type = node.getSymbol().type
        node_text = node.getText()

        if node_type == RelayLexer.NAT:
            return int(node_text)
        if node_type == RelayLexer.FLOAT:
            # Drops the token's final character (presumably a literal suffix such
            # as `f` -- confirm against the grammar).
            return float(node_text[:-1])
        if node_type == RelayLexer.BOOL_LIT:
            if node_text == "True":
                return True
            if node_text == "False":
                return False
            raise ParseError("unrecognized BOOL_LIT: `{}`".format(node_text))
        if node_type == RelayLexer.QUOTED_STRING:
            return literal_eval(node_text)
        raise ParseError("unhandled terminal \"{0}\" of type `{1}`".format(node_text, node_type))

    @staticmethod
    def _is_scalar_type_name(name):
        """Return True if `name` spells a scalar dtype such as `int32`, `uint8`,
        `float16x4` or `bool`: a TYPE_PREFIXES entry followed only by optional
        bit-width digits and an optional `x<lanes>` suffix.

        A bare prefix match would also claim identifiers like `boolean` or
        `int_list`, which are not dtypes.
        """
        for type_prefix in TYPE_PREFIXES:
            if name.startswith(type_prefix):
                bits, sep, lanes = name[len(type_prefix):].partition("x")
                if (not bits or bits.isdigit()) and (not sep or lanes.isdigit()):
                    return True
        return False

    @staticmethod
    def _make_op_wrapper(ctx):
        """Wrap the operator named by the dotted CNAMEs of `ctx` in an OpWrapper.

        Operators listed in FUNC_OPS are called through their Python builder
        (FuncOp); all others are looked up with `op.get` and called via ExprOp.
        """
        op_name = ".".join([name.getText() for name in ctx.CNAME()])
        if op_name in FUNC_OPS:
            return FuncOp(FUNC_OPS[op_name])
        return ExprOp(op.get(op_name))

    def visitGeneralIdent(self, ctx):
        """Resolve an identifier to a scalar type, a type var/constructor, or an operator."""
        name = ctx.getText()
        # Scalar dtype names take precedence.
        if self._is_scalar_type_name(name):
            return ty.scalar_type(name)
        # Next, look it up in the local then global type params.
        type_expr = lookup(self.type_var_scopes, name)
        if type_expr is None:
            type_expr = self.global_type_vars.get(name, None)
        if type_expr is not None:
            # Zero-arity constructor calls fall into the general ident case, so in that case,
            # we construct a constructor call with no args.
            if isinstance(type_expr, adt.Constructor) and not type_expr.inputs:
                type_expr = expr.Call(type_expr, [])
            return type_expr
        # Otherwise treat it as an operator.
        return self._make_op_wrapper(ctx)

    def visitGlobalVar(self, ctx):
        var_name = ctx.CNAME().getText()
        global_var = self.global_vars.get(var_name, None)
        if global_var is None:
            raise ParseError("unbound global var `{0}`".format(var_name))
        return global_var

    def visitLocalVar(self, ctx):
        var_name = ctx.CNAME().getText()
        local_var = lookup(self.var_scopes, var_name)
        if local_var is None:
            raise ParseError("unbound local var `{0}`".format(var_name))
        return local_var

    def visitGraphVar(self, ctx):
        return self.graph_expr[int(ctx.NAT().getText())]

    def visit_list(self, ctx_list) -> List[Any]:
        """Visit a list of contexts."""
        assert isinstance(ctx_list, list)

        return [self.visit(ctx) for ctx in ctx_list]

    def getTypeExpr(self, ctx: Optional[RelayParser.TypeExprContext]) -> Optional[ty.Type]:
        """Return a (possibly None) Relay type."""
        if ctx is None:
            return None

        return self.visit(ctx)

    def visitProg(self, ctx: RelayParser.ProgContext) -> Union[expr.Expr, module.Module]:
        """Visit a whole program: load METADATA, then definitions or a single expr."""
        self.meta = None
        if ctx.METADATA():
            header, data = str(ctx.METADATA()).split("\n", 1)
            assert header == "METADATA:"
            self.meta = tvm.load_json(data)
        if ctx.defn():
            self.visit_list(ctx.defn())
            return self.module

        if ctx.expr():
            return self.visit(ctx.expr())

        return self.module

    # Exprs
    def visitOpIdent(self, ctx) -> OpWrapper:
        return self._make_op_wrapper(ctx)

    # pass through
    def visitParen(self, ctx: RelayParser.ParenContext) -> expr.Expr:
        return self.visit(ctx.expr())

    # pass through
    def visitTypeParen(self, ctx: RelayParser.TypeParenContext) -> ty.Type:
        return self.visit(ctx.typeExpr())

    # pass through
    def visitBody(self, ctx: RelayParser.BodyContext) -> expr.Expr:
        return self.visit(ctx.expr())

    def visitScalarFloat(self, ctx: RelayParser.ScalarFloatContext) -> expr.Constant:
        return expr.const(self.visit(ctx.FLOAT()))

    def visitScalarInt(self, ctx: RelayParser.ScalarIntContext) -> expr.Constant:
        return expr.const(self.visit(ctx.NAT()))

    def visitScalarBool(self, ctx: RelayParser.ScalarBoolContext) -> expr.Constant:
        return expr.const(self.visit(ctx.BOOL_LIT()))

    def visitNeg(self, ctx: RelayParser.NegContext) -> Union[expr.Constant, expr.Call]:
        val = self.visit(ctx.expr())
        if isinstance(val, expr.Constant) and val.data.asnumpy().ndim == 0:
            # fold Neg in for scalars
            return expr.const(-val.data.asnumpy().item())

        return op.negative(val)

    def visitTuple(self, ctx: RelayParser.TupleContext) -> expr.Tuple:
        tup = self.visit_list(ctx.expr())
        return expr.Tuple(tup)

    def visitLet(self, ctx: RelayParser.LetContext) -> expr.Let:
        """Desugar various sequence constructs to Relay Let nodes."""

        if ctx.var() is None:
            # anonymous identity
            var = self.mk_var("_", None)
        else:
            var = self.visitVar(ctx.var())

        # The bound value gets its own scope; the var itself was created in the
        # enclosing scope above, so it stays visible in the body.
        self.enter_var_scope()
        value = self.visit(ctx.expr(0))
        self.exit_var_scope()

        body = self.visit(ctx.expr(1))

        return expr.Let(var, value, body)

    def visitBinOp(self, ctx: RelayParser.BinOpContext) -> expr.Call:
        """Desugar binary operators."""
        arg0, arg1 = self.visit_list(ctx.expr())
        relay_op = BINARY_OPS.get(ctx.op.type)

        if relay_op is None:
            raise ParseError("unimplemented binary op.")

        return relay_op(arg0, arg1)

    @spanify
    def visitVar(self, ctx: RelayParser.VarContext) -> expr.Var:
        """Visit a single variable."""
        ident = ctx.localVar()

        if ident is None:
            raise ParseError("only local ids may be used in vars.")

        typeExpr = self.getTypeExpr(ctx.typeExpr())

        # Strip the leading sigil (first character) of the local var's text.
        return self.mk_var(ident.getText()[1:], typeExpr)

    def visitVarList(self, ctx: RelayParser.VarListContext) -> List[expr.Var]:
        return self.visit_list(ctx.var())

    # TODO: support a larger class of values than just Relay exprs
    def visitAttr(self, ctx: RelayParser.AttrContext) -> Tuple[str, expr.Expr]:
        return (ctx.CNAME().getText(), self.visit(ctx.expr()))

    def visitArgNoAttr(self, ctx: RelayParser.ArgNoAttrContext):
        return (self.visit_list(ctx.varList().var()), None)

    def visitAttrSeq(self, ctx: RelayParser.AttrSeqContext) -> Dict[str, expr.Expr]:
        return dict(self.visit_list(ctx.attr()))

    def visitArgWithAttr(self, ctx: RelayParser.AttrSeqContext) \
        -> Tuple[List[expr.Var], Dict[str, expr.Expr]]:
        return (self.visit_list(ctx.var()), self.visitAttrSeq(ctx.attrSeq()))

    def visitArgList(self, ctx: RelayParser.ArgListContext) \
            -> Tuple[Optional[List[expr.Var]], Optional[Dict[str, expr.Expr]]]:
        var_list = self.visit(ctx.varList()) if ctx.varList() else None
        attr_list = self.visit(ctx.attrList()) if ctx.attrList() else None
        return (var_list, attr_list)

    def visitMeta(self, ctx: RelayParser.MetaContext):
        """Return entry `index` of kind `type_key` from the program's METADATA section."""
        type_key = str(ctx.CNAME())
        index = int(self.visit(ctx.NAT()))
        return self.meta[type_key][index]

    def mk_func(
            self,
            ctx: Union[RelayParser.FuncContext, RelayParser.DefnContext]) \
            -> expr.Function:
        """Construct a function from either a Func or Defn."""
        # Enter var scope early to put params in scope.
        self.enter_var_scope()
        # Capture type params in params.
        self.enter_type_param_scope()
        type_params = ctx.typeParamList()

        if type_params is not None:
            type_params = type_params.typeExpr()
            assert type_params
            for ty_param in type_params:
                name = ty_param.getText()
                self.mk_typ(name, ty.Kind.Type)

        var_list, attr_list = self.visit(ctx.argList())
        if var_list is None:
            var_list = []
        ret_type = self.getTypeExpr(ctx.typeExpr())

        body = self.visit(ctx.body())
        # NB(@jroesch): you must stay in the type parameter scope until
        # after you exit the body, you can reference the type parameters
        # of your parent scopes.
        type_params = list(self.exit_type_param_scope())
        if type_params:
            _, type_params = zip(*type_params)
        self.exit_var_scope()

        attrs = tvm.make.node("DictAttrs", **attr_list) if attr_list is not None else None
        return expr.Function(var_list, body, ret_type, type_params, attrs)

    @spanify
    def visitFunc(self, ctx: RelayParser.FuncContext) -> expr.Function:
        return self.mk_func(ctx)

    # TODO: how to set spans for definitions?
    # @spanify
    def visitFuncDefn(self, ctx: RelayParser.DefnContext) -> None:
        """Add a global function definition to the module being built."""
        # Strip the leading sigil (first character) of the global var's text.
        ident_name = ctx.globalVar().getText()[1:]
        ident = self.mk_global_var(ident_name)
        func = self.mk_func(ctx)
        self.module[ident] = func

    def handle_adt_header(
            self,
            ctx: Union[RelayParser.ExternAdtDefnContext, RelayParser.AdtDefnContext]):
        """Handles parsing of the name and type params of an ADT definition."""
        adt_name = ctx.generalIdent().getText()
        adt_var = self.mk_global_typ_var(adt_name, ty.Kind.AdtHandle)
        # parse type params
        type_params = ctx.typeParamList()
        if type_params is None:
            type_params = []
        else:
            type_params = [self.mk_typ(type_ident.getText(), ty.Kind.Type)
                           for type_ident in type_params.typeExpr()]
        return adt_var, type_params

    def visitExternAdtDefn(self, ctx: RelayParser.ExternAdtDefnContext):
        # TODO(weberlo): update this handler once extern is implemented
        self.enter_type_param_scope()
        adt_var, type_params = self.handle_adt_header(ctx)
        # update module being built
        self.module[adt_var] = adt.TypeData(adt_var, type_params, [])
        self.exit_type_param_scope()

    def visitAdtDefn(self, ctx: RelayParser.AdtDefnContext):
        self.enter_type_param_scope()
        adt_var, type_params = self.handle_adt_header(ctx)
        # parse constructors
        adt_cons_defns = ctx.adtConsDefnList()
        if adt_cons_defns is None:
            adt_cons_defns = []
        else:
            adt_cons_defns = adt_cons_defns.adtConsDefn()
        parsed_constructors = []
        for cons_defn in adt_cons_defns:
            inputs = [self.visit(inp) for inp in cons_defn.typeExpr()]
            cons_defn_name = cons_defn.constructorName().getText()
            constructor = adt.Constructor(cons_defn_name, inputs, adt_var)
            self.mk_global_typ_cons(cons_defn_name, constructor)
            parsed_constructors.append(constructor)
        # update module being built
        self.module[adt_var] = adt.TypeData(adt_var, type_params, parsed_constructors)
        self.exit_type_param_scope()

    def visitMatch(self, ctx: RelayParser.MatchContext):
        match_type = ctx.matchType().getText()
        if match_type == "match":
            complete_match = True
        elif match_type == "match?":
            complete_match = False
        else:
            raise RuntimeError("unknown match type {0}".format(match_type))

        match_data = self.visit(ctx.expr())
        match_clauses = ctx.matchClauseList()
        if match_clauses is None:
            match_clauses = []
        else:
            match_clauses = match_clauses.matchClause()
        parsed_clauses = []
        for clause in match_clauses:
            # Pattern vars are only visible inside their own clause.
            self.enter_var_scope()
            pattern = self.visit(clause.pattern())
            clause_body = self.visit(clause.expr())
            self.exit_var_scope()
            parsed_clauses.append(adt.Clause(pattern, clause_body))
        return adt.Match(match_data, parsed_clauses, complete=complete_match)

    def visitWildcardPattern(self, ctx: RelayParser.WildcardPatternContext):
        return adt.PatternWildcard()

    def visitVarPattern(self, ctx: RelayParser.VarPatternContext):
        text = ctx.localVar().getText()
        typ = ctx.typeExpr()
        if typ is not None:
            typ = self.visit(typ)
        var = self.mk_var(text[1:], typ=typ)
        return adt.PatternVar(var)

    def visitConstructorPattern(self, ctx: RelayParser.ConstructorPatternContext):
        constructor_name = ctx.constructorName().getText()
        constructor = self.global_type_vars[constructor_name]
        pattern_list = ctx.patternList()
        if pattern_list is None:
            patterns = []
        else:
            patterns = [self.visit(pattern) for pattern in pattern_list.pattern()]
        return adt.PatternConstructor(constructor, patterns)

    def visitTuplePattern(self, ctx: RelayParser.TuplePatternContext):
        return adt.PatternTuple([self.visit(pattern) for pattern in ctx.patternList().pattern()])

    def visitCallNoAttr(self, ctx: RelayParser.CallNoAttrContext):
        return (self.visit_list(ctx.exprList().expr()), None)

    def visitCallWithAttr(self, ctx: RelayParser.CallWithAttrContext):
        return (self.visit_list(ctx.expr()), self.visit(ctx.attrSeq()))

    def call(self, func, args, attrs, type_args):
        """Build a call of `func`, dispatching on OpWrappers and ADT constructors."""
        if isinstance(func, OpWrapper):
            return func(args, attrs, type_args)
        if isinstance(func, adt.Constructor):
            return func(*args)
        return expr.Call(func, args, attrs, type_args)

    @spanify
    def visitCall(self, ctx: RelayParser.CallContext) -> expr.Call:
        func = self.visit(ctx.expr())
        args, attrs = self.visit(ctx.callList())
        return self.call(func, args, attrs, [])

    @spanify
    def visitIfElse(self, ctx: RelayParser.IfElseContext) -> expr.If:
        """Construct a Relay If node. Creates a new scope for each branch."""
        cond = self.visit(ctx.expr())

        self.enter_var_scope()
        true_branch = self.visit(ctx.body(0))
        self.exit_var_scope()

        self.enter_var_scope()
        false_branch = self.visit(ctx.body(1))
        self.exit_var_scope()

        return expr.If(cond, true_branch, false_branch)

    @spanify
    def visitGraph(self, ctx: RelayParser.GraphContext) -> expr.Expr:
        """Visit a graph variable assignment."""
        graph_nid = int(ctx.graphVar().getText()[1:])

        self.enter_var_scope()
        value = self.visit(ctx.expr(0))
        self.exit_var_scope()

        # Graph vars must be defined in order: %0, %1, %2, ...
        if graph_nid != len(self.graph_expr):
            raise ParseError(
                "expected new graph variable to be `%{}`, ".format(len(self.graph_expr)) +
                "but got `%{}`".format(graph_nid))
        self.graph_expr.append(value)

        kont = self.visit(ctx.expr(1))
        return kont

    # Types

    # pylint: disable=unused-argument
    def visitIncompleteType(self, ctx: RelayParser.IncompleteTypeContext) -> None:
        return None

    def visitTypeCallType(self, ctx: RelayParser.TypeCallTypeContext):
        func = self.visit(ctx.generalIdent())
        args = [self.visit(arg) for arg in ctx.typeParamList().typeExpr()]
        return ty.TypeCall(func, args)

    def visitParensShape(self, ctx: RelayParser.ParensShapeContext) -> int:
        return self.visit(ctx.shape())

    def visitShapeList(self, ctx: RelayParser.ShapeListContext) -> List[int]:
        return self.visit_list(ctx.shape())

    def visitTensor(self, ctx: RelayParser.TensorContext):
        return tuple(self.visit_list(ctx.expr()))

    def visitTensorType(self, ctx: RelayParser.TensorTypeContext) -> ty.TensorType:
        """Create a simple tensor type. No generics."""

        shape = self.visit(ctx.shapeList())
        dtype = self.visit(ctx.typeExpr())

        if not isinstance(dtype, ty.TensorType):
            raise ParseError("expected dtype to be a Relay base type.")

        dtype = dtype.dtype

        return ty.TensorType(shape, dtype)

    def visitTupleType(self, ctx: RelayParser.TupleTypeContext) -> ty.TupleType:
        return ty.TupleType(self.visit_list(ctx.typeExpr()))

    def visitFuncType(self, ctx: RelayParser.FuncTypeContext) -> ty.FuncType:
        """Build a function type; the last type expression is the return type."""
        types = self.visit_list(ctx.typeExpr())

        arg_types = types[:-1]
        ret_type = types[-1]

        return ty.FuncType(arg_types, ret_type, [], None)

711
def make_parser(data: str) -> RelayParser:
    """Construct a RelayParser for the given program text.

    Both the lexer and the parser get a StrictErrorListener, so lexing or
    parsing problems raise an exception carrying the program text.
    """
    input_stream = InputStream(data)
    lexer = RelayLexer(input_stream)
    lexer.addErrorListener(StrictErrorListener(data))
    token_stream = CommonTokenStream(lexer)
    p = RelayParser(token_stream)
    p.addErrorListener(StrictErrorListener(data))
    return p
720

721 722
__source_name_counter__ = 0


class StrictErrorListener(ErrorListener):
    """This ErrorListener fails eagerly on all errors, and reports the program."""
    def __init__(self, text):
        # Full program text, echoed in every error message.
        self.text = text

    def syntaxError(self, recognizer, offendingSymbol, line, column, msg, e):
        # Keep the "Syntax Error in:\n<program>" prefix, but append ANTLR's
        # location and description so the failure can actually be located.
        raise Exception("Syntax Error in:\n" + self.text +
                        "\nat line {}:{}: {}".format(line, column, msg))

    def reportAmbiguity(self, recognizer, dfa, startIndex, stopIndex, exact, ambigAlts, configs):
        raise Exception("Ambiguity Error in:\n" + self.text)

    def reportAttemptingFullContext(self,
                                    recognizer,
                                    dfa,
                                    startIndex,
                                    stopIndex,
                                    conflictingAlts,
                                    configs):
        raise Exception("Attempting Full Context in:\n" + self.text)

    def reportContextSensitivity(self, recognizer, dfa, startIndex, stopIndex, prediction, configs):
        raise Exception("Context Sensitivity in:\n" + self.text)

746
def fromtext(data: str, source_name: str = None) -> Union[expr.Expr, module.Module]:
    """Parse a Relay program.

    Args:
        data: program text; must be non-empty.
        source_name: name recorded in spans. When None, a fresh
            `source_file<N>` name is generated for each call.

    Returns:
        The parsed expression, or the module when the program has definitions.

    Raises:
        ParseError: if `data` is empty or the program is malformed.
    """
    global __source_name_counter__

    if data == "":
        raise ParseError("cannot parse the empty string.")

    if source_name is None:
        source_name = "source_file{0}".format(__source_name_counter__)
        # Advance the counter so each anonymous program gets a distinct name;
        # previously it was never incremented and every one was `source_file0`.
        __source_name_counter__ += 1

    if isinstance(source_name, str):
        source_name = SourceName(source_name)

    tree = make_parser(data).prog()
    return ParseTreeToRelayIR(source_name).visit(tree)