# pylint: disable=invalid-name, unused-import """A parser for Relay's text format.""" from __future__ import absolute_import import sys from collections import deque from typing import TypeVar, Deque, Tuple, Optional, Union, NamedTuple, List, Callable, Any, Dict import tvm from . import module from .base import Span, SourceName from . import expr from . import ty from . import op class ParseError(Exception): """Exception type for parse errors.""" def __init__(self, message): # type: (str) -> None super(ParseError, self).__init__() self.message = message PYTHON_VERSION = sys.version_info.major try: if PYTHON_VERSION == 2: from .grammar.py2.RelayVisitor import RelayVisitor from .grammar.py2.RelayParser import RelayParser from .grammar.py2.RelayLexer import RelayLexer else: from .grammar.py3.RelayVisitor import RelayVisitor from .grammar.py3.RelayParser import RelayParser from .grammar.py3.RelayLexer import RelayLexer except ImportError: raise ParseError("Couldn't find ANTLR parser. Try building with USE_ANTLR=ON.") try: from antlr4 import ParserRuleContext, InputStream, CommonTokenStream from antlr4.tree.Tree import TerminalNode except ImportError: raise ParseError("Couldn't find ANTLR runtime." + "Try running `pip{version} install antlr4-python{version}-runtime`." .format(version=PYTHON_VERSION)) BINARY_OPS = { RelayParser.MUL: op.multiply, RelayParser.DIV: op.divide, RelayParser.ADD: op.add, RelayParser.SUB: op.subtract, RelayParser.LT: op.less, RelayParser.GT: op.greater, RelayParser.LE: op.less_equal, RelayParser.GE: op.greater_equal, RelayParser.EQ: op.equal, RelayParser.NE: op.not_equal, } TYPE_PREFIXES = [ "int", "uint", "float", "bool", ] T = TypeVar("T") Scope = Deque[Tuple[str, T]] Scopes = Deque[Scope[T]] def lookup(scopes, name): # type: (Scopes[T], str) -> Optional[T] """Look up `name` in `scopes`.""" for scope in scopes: for key, val in scope: if key == name: return val return None def spanify(f): """A decorator which attaches span information to the value returned by calling `f`. Intended for use with the below AST visiting methods. The idea is that after we do the work of constructing the AST we attach Span information. """ def _wrapper(*args, **kwargs): # Assumes 0th arg is self and gets source_name from object. sn = args[0].source_name # Assumes 1st arg is an ANTLR parser context. ctx = args[1] ast = f(*args, **kwargs) line, col = ctx.getSourceInterval() sp = Span(sn, line, col) ast.set_span(sp) return ast return _wrapper # TODO(@jmp): Use https://stackoverflow.com/q/13889941 # to figure out how to get ANTLR4 to be more unhappy about syntax errors class ParseTreeToRelayIR(RelayVisitor): """Parse Relay text format into Relay IR.""" def __init__(self, source_name): # type: (str) -> None self.source_name = source_name self.module = module.Module({}) # type: module.Module # Adding an empty scope allows naked lets without pain. self.var_scopes = deque([deque()]) # type: Scopes[expr.Var] self.global_var_scope = deque() # type: Scope[expr.GlobalVar] self.type_param_scopes = deque([deque()]) # type: Scopes[ty.TypeVar] self.graph_expr = [] # type: List[expr.Expr] super(ParseTreeToRelayIR, self).__init__() def enter_var_scope(self): # type: () -> None """Enter a new Var scope so it can be popped off later.""" self.var_scopes.appendleft(deque()) def exit_var_scope(self): # type: () -> Scope[expr.Var] """Pop off the current Var scope and return it.""" return self.var_scopes.popleft() def mk_var(self, name, type_): # type: (str, ty.Type) -> expr.Var """Create a new Var and add it to the Var scope.""" var = expr.Var(name, type_) self.var_scopes[0].appendleft((name, var)) return var def mk_global_var(self, name): # type: (str) -> expr.GlobalVar """Create a new GlobalVar and add it to the GlobalVar scope.""" var = expr.GlobalVar(name) self.global_var_scope.append((name, var)) return var def enter_type_param_scope(self): # type: () -> None """Enter a new TypeVar scope so it can be popped off later.""" self.type_param_scopes.appendleft(deque()) def exit_type_param_scope(self): # type: () -> Scope[ty.TypeVar] """Pop off the current TypeVar scope and return it.""" return self.type_param_scopes.popleft() def mk_typ(self, name, kind): # (str, ty.Kind) -> ty.TypeVar """Create a new TypeVar and add it to the TypeVar scope.""" typ = ty.TypeVar(name, kind) self.type_param_scopes[0].appendleft((name, typ)) return typ def visitTerminal(self, node): # type: (TerminalNode) -> Union[expr.Expr, int, float] """Visit lexer tokens that aren't ignored or visited by other functions.""" node_type = node.getSymbol().type node_text = node.getText() name = node_text[1:] # variables if node_type == RelayLexer.GLOBAL_VAR: return lookup(deque([self.global_var_scope]), node_text[1:]) if node_type == RelayLexer.LOCAL_VAR: # Remove the leading '%' and lookup the name. var = lookup(self.var_scopes, name) if var is None: raise ParseError("Couldn't resolve `{}`.".format(name)) return var if node_type == RelayLexer.GRAPH_VAR: try: return self.graph_expr[int(name)] except IndexError: raise ParseError("Couldn't resolve `{}`".format(name)) # data types if node_type == RelayLexer.NAT: return int(node_text) if node_type == RelayLexer.FLOAT: return float(node_text) if node_type == RelayLexer.BOOL_LIT: if node_text == "True": return True if node_text == "False": return False raise ParseError("Unrecognized BOOL_LIT: `{}`".format(node_text)) raise ParseError("todo: {}".format(node_text)) def visit_list(self, ctx_list): # type: (List[ParserRuleContext]) -> List[Any] """"Visit a list of contexts.""" return [self.visit(ctx) for ctx in ctx_list] def getType_(self, ctx): # type: (Optional[RelayParser.Type_Context]) -> Optional[ty.Type] """Return a (possibly None) Relay type.""" if ctx is None: return None return self.visit(ctx) def visitProg(self, ctx): # type: (RelayParser.ProgContext) -> Union[expr.Expr, module.Module] if ctx.defn(): self.visit_list(ctx.defn()) return self.module return self.visit(ctx.expr()) # Exprs def visitOpIdent(self, ctx): # type: (RelayParser.OpIdentContext) -> op.Op return op.get(ctx.CNAME().getText()) # pass through def visitParens(self, ctx): # type: (RelayParser.ParensContext) -> expr.Expr return self.visit(ctx.expr()) # pass through def visitBody(self, ctx): # type: (RelayParser.BodyContext) -> expr.Expr return self.visit(ctx.expr()) def visitScalarFloat(self, ctx): # type: (RelayParser.ScalarFloatContext) -> expr.Constant return expr.const(self.visit(ctx.FLOAT())) def visitScalarInt(self, ctx): # type: (RelayParser.ScalarIntContext) -> expr.Constant return expr.const(self.visit(ctx.NAT())) def visitScalarBool(self, ctx): # type: (RelayParser.ScalarBoolContext) -> expr.Constant return expr.const(self.visit(ctx.BOOL_LIT())) def visitNeg(self, ctx): # type: (RelayParser.NegContext) -> Union[expr.Constant, expr.Call] val = self.visit(ctx.expr()) if isinstance(val, expr.Constant) and val.data.asnumpy().ndim == 0: # fold Neg in for scalars return expr.const(-val.data.asnumpy().item()) return op.negative(val) def visitTuple(self, ctx): # type: (RelayParser.TupleContext) -> expr.Tuple tup = self.visit_list(ctx.expr()) return expr.Tuple(tup) # Currently doesn't support mutable sequencing. def visitLet(self, ctx): # type: (RelayParser.SeqContext) -> expr.Let """Desugar various sequence constructs to Relay Let nodes.""" if ctx.MUT() is not None: raise ParseError("Mutation is currently unsupported.") if ctx.var() is None or ctx.var().ident() is None: # anonymous identity ident = "_" type_ = None else: local_var = ctx.var().ident().LOCAL_VAR() if local_var is None: raise ParseError("Only local ids may be used in `let`s.") ident = local_var.getText()[1:] type_ = self.getType_(ctx.var().type_()) var = self.mk_var(ident, type_) self.enter_var_scope() value = self.visit(ctx.expr(0)) self.exit_var_scope() body = self.visit(ctx.expr(1)) return expr.Let(var, value, body) def visitBinOp(self, ctx): # type: (RelayParser.BinOpContext) -> expr.Call """Desugar binary operators.""" arg0, arg1 = self.visit_list(ctx.expr()) relay_op = BINARY_OPS.get(ctx.op.type) if relay_op is None: raise ParseError("Unimplemented binary op.") return relay_op(arg0, arg1) @spanify def visitVar(self, ctx): # type: (RelayParser.VarContext) -> expr.Var """Visit a single variable.""" ident = ctx.ident().LOCAL_VAR() if ident is None: raise ParseError("Only local ids may be used in vars.") type_ = self.getType_(ctx.type_()) return self.mk_var(ident.getText()[1:], type_) def visitVarList(self, ctx): # type: (RelayParser.VarListContext) -> List[expr.Var] return self.visit_list(ctx.var()) # TODO: support a larger class of values than just Relay exprs def visitAttr(self, ctx): # type: (RelayParser.AttrContext) -> Tuple[str, expr.Expr] return (ctx.CNAME().getText(), self.visit(ctx.expr())) def visitAttrList(self, ctx): # type: (RelayParser.AttrListContext) -> Dict[str, expr.Expr] return dict(self.visit_list(ctx.attr())) def visitArgList(self, ctx # type: RelayParser.ArgListContext ): # type: (...) -> Tuple[Optional[List[expr.Var]], Optional[Dict[str, expr.Expr]]] var_list = self.visit(ctx.varList()) if ctx.varList() else None attr_list = self.visit(ctx.attrList()) if ctx.attrList() else None return (var_list, attr_list) def mk_func(self, ctx): # type: (Union[RelayParser.FuncContext, RelayParser.DefnContext]) -> expr.Function """Construct a function from either a Func or Defn.""" # Enter var scope early to put params in scope. self.enter_var_scope() # Capture type params in params. self.enter_type_param_scope() var_list, attr_list = self.visit(ctx.argList()) ret_type = self.getType_(ctx.type_()) type_params = list(self.exit_type_param_scope()) if type_params: _, type_params = zip(*type_params) body = self.visit(ctx.body()) self.exit_var_scope() attrs = tvm.make.node("DictAttrs", **attr_list) if attr_list is not None else None return expr.Function(var_list, body, ret_type, type_params, attrs) @spanify def visitFunc(self, ctx): # type: (RelayParser.FuncContext) -> expr.Function return self.mk_func(ctx) # TODO: how to set spans for definitions? # @spanify def visitDefn(self, ctx): # type: (RelayParser.DefnContext) -> None ident = ctx.ident().GLOBAL_VAR() if ident is None: raise ParseError("Only global ids may be used in `def`s.") ident_name = ident.getText()[1:] ident = self.mk_global_var(ident_name) self.module[ident] = self.mk_func(ctx) @spanify def visitCall(self, ctx): # type: (RelayParser.CallContext) -> expr.Call visited_exprs = self.visit_list(ctx.expr()) func = visited_exprs[0] args = visited_exprs[1:] return expr.Call(func, args, None, None) @spanify def visitIfElse(self, ctx): # type: (RelayParser.IfElseContext) -> expr.If """Construct a Relay If node. Creates a new scope for each branch.""" cond = self.visit(ctx.expr()) self.enter_var_scope() true_branch = self.visit(ctx.body(0)) self.exit_var_scope() self.enter_var_scope() false_branch = self.visit(ctx.body(1)) self.exit_var_scope() return expr.If(cond, true_branch, false_branch) @spanify def visitGraph(self, ctx): # type: (RelayParser.GraphContext) -> expr.Expr """Visit a graph variable assignment.""" if ctx.ident().GRAPH_VAR() is None: raise ParseError("Expected a graph var, but got `{}`".format(ctx.ident().getText())) graph_nid = int(ctx.ident().GRAPH_VAR().getText()[1:]) self.enter_var_scope() value = self.visit(ctx.expr(0)) self.exit_var_scope() if graph_nid != len(self.graph_expr): raise ParseError( "Expected new graph variable to be `%{}`,".format(len(self.graph_expr)) + \ "but got `%{}`".format(graph_nid)) self.graph_expr.append(value) kont = self.visit(ctx.expr(1)) return kont # Types # pylint: disable=unused-argument def visitIncompleteType(self, ctx): # type (RelayParser.IncompleteTypeContext) -> None: return None def visitIdentType(self, ctx): # type: (RelayParser.IdentTypeContext) -> Union[ty.TensorType, str] ident_type = ctx.CNAME().getText() # look through all type prefixes for a match for type_prefix in TYPE_PREFIXES: if ident_type.startswith(type_prefix): return ty.scalar_type(ident_type) raise ParseError("Unknown builtin type: {}".format(ident_type)) # def visitCallType(self, ctx): # # type: (RelayParser.CallTypeContext) -> Union[expr.Expr, ty.TensorType] # ident_type = ctx.identType().CNAME().getText() # args = self.visit_list(ctx.type_()) # if not args: # raise ParseError("Type-level functions must have arguments!") # func_type = TYPE_FUNCS.get(ident_type)(args) # if func_type is None: # raise ParseError("Unknown type-level function: `{}`".format(ident_type)) # else: # return func_type def visitParensShape(self, ctx): # type: (RelayParser.ParensShapeContext) -> int return self.visit(ctx.shape()) def visitShapeSeq(self, ctx): # type: (RelayParser.ShapeSeqContext) -> List[int] return self.visit_list(ctx.shape()) def visitTensorType(self, ctx): # type: (RelayParser.TensorTypeContext) -> ty.TensorType """Create a simple tensor type. No generics.""" shape = self.visit(ctx.shapeSeq()) dtype = self.visit(ctx.type_()) if not isinstance(dtype, ty.TensorType): raise ParseError("Expected dtype to be a Relay base type.") dtype = dtype.dtype return ty.TensorType(shape, dtype) def visitTupleType(self, ctx): # type: (RelayParser.TupleTypeContext) -> ty.TupleType return ty.TupleType(self.visit_list(ctx.type_())) def visitFuncType(self, ctx): # type: (RelayParser.FuncTypeContext) -> ty.FuncType types = self.visit_list(ctx.type_()) arg_types = types[:-1] ret_type = types[-1] return ty.FuncType(arg_types, ret_type, [], None) def make_parser(data): # type: (str) -> RelayParser """Construct a RelayParser a given data stream.""" input_stream = InputStream(data) lexer = RelayLexer(input_stream) token_stream = CommonTokenStream(lexer) return RelayParser(token_stream) __source_name_counter__ = 0 def fromtext(data, source_name=None): # type: (str, str) -> Union[expr.Expr, module.Module] """Parse a Relay program.""" global __source_name_counter__ if source_name is None: source_name = "source_file{0}".format(__source_name_counter__) if isinstance(source_name, str): source_name = SourceName(source_name) tree = make_parser(data).prog() return ParseTreeToRelayIR(source_name).visit(tree)