Commit ca0292d8 by Logan Weber, committed by Jared Roesch

[Relay] Add ADTs to text format (#3863)

* Getting closer to having ADT defs

* ADT defs working probably

* Match parsing basically done

* came to earth in a silver chrome UFO

* match finished?

* All tests but newest are passing

* ADT constructors work

now cleanup?

* Cleanup round 1

* Cleanup round 2

* Cleanup round 3

* Cleanup round 4

* Cleanup round 6

* Cleanup round 7

* Lil grammar fix

* Remove ANTLR Java files

* Lint roller

* Lint roller

* Address feedback

* Test completeness in match test

* Remove unused imports

* Lint roller

* Switch to Rust-style ADT syntax

* Lil fix

* Add dummy `extern type` handler

* Add type arg to test

* Update prelude semantic version

* Repair test

* Fix graph var handling in match

* Revert 's/graph_equal/is_unifiable' change
parent a103c4ee
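
For orientation, a minimal sketch of the Rust-style ADT syntax this commit adds, fed through the updated `fromtext` entry point. The `List`/`Cons`/`Nil` names and the `@head_or_zero` function are illustrative, not taken from this diff.

import tvm.relay as relay

# Illustrative program exercising the new `type` and `match` constructs.
mod = relay.fromtext("""
v0.0.4
type List[A] {
  Cons(A, List[A]),
  Nil,
}
def @head_or_zero(%xs: List[int32]) -> int32 {
  match (%xs) {
    Cons(%x, _) => %x,
    Nil => 0,
  }
}
""")
print(mod)
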
......@@ -264,7 +264,7 @@ class MatchNode : public ExprNode {
/*! \brief The match node clauses. */
tvm::Array<Clause> clauses;
/*! \brief Should this match be complete (cover all cases)?
/*! \brief Should this match be complete (cover all cases)?
* If yes, the type checker will generate an error if there are any missing cases.
*/
bool complete;
......
......@@ -165,6 +165,13 @@ class ModuleNode : public RelayNode {
TVM_DLL TypeData LookupDef(const std::string& var) const;
/*!
* \brief Check if a global type definition exists
* \param var The name of the global type definition.
* \return Whether the definition exists.
*/
TVM_DLL bool HasDef(const std::string& var) const;
/*!
* \brief Look up a constructor by its tag.
* \param tag The tag for the constructor.
* \return The constructor object.
......
......@@ -21,13 +21,14 @@ from __future__ import absolute_import
import sys
from ast import literal_eval
from typing import Any, Deque, Dict, List, Optional, TypeVar, Tuple, Union
from collections import deque
import tvm
from . import module
from .base import Span, SourceName
from . import adt
from . import expr
from . import ty
from . import op
......@@ -53,8 +54,7 @@ sys.setrecursionlimit(10000)
class ParseError(Exception):
"""Exception type for parse errors."""
def __init__(self, message):
# type: (str) -> None
def __init__(self, message: str) -> None:
super(ParseError, self).__init__()
self.message = message
......@@ -143,12 +143,11 @@ TYPE_PREFIXES = [
"bool",
]
T = ty.TypeVar("T")
# Scope = Deque[Tuple[str, T]]
# Scopes = Deque[Scope[T]]
T = TypeVar("T")
Scope = Deque[Tuple[str, T]]
Scopes = Deque[Scope[T]]
def lookup(scopes, name):
# type: (Scopes[T], str) -> Optional[T]
def lookup(scopes: Scopes[T], name: str) -> Optional[T]:
"""Look up `name` in `scopes`."""
for scope in scopes:
......@@ -185,95 +184,92 @@ def spanify(f):
class ParseTreeToRelayIR(RelayVisitor):
"""Parse Relay text format into Relay IR."""
def __init__(self, source_name):
# type: (str) -> None
def __init__(self, source_name: str) -> None:
self.source_name = source_name
self.module = module.Module({}) # type: module.Module
self.module = module.Module({}) # type: module.Module
# Adding an empty scope allows naked lets without pain.
self.var_scopes = deque([deque()]) # type: Scopes[expr.Var]
self.global_var_scope = deque() # type: Scope[expr.GlobalVar]
self.type_param_scopes = deque([deque()]) # type: Scopes[ty.TypeVar]
self.graph_expr = [] # type: List[expr.Expr]
self.var_scopes = deque([deque()]) # type: Scopes[expr.Var]
self.global_vars = {} # type: Scope[expr.GlobalVar]
self.type_var_scopes = deque([deque()]) # type: Scopes[ty.TypeVar]
self.global_type_vars = {} # type: Scope[expr.GlobalVar]
self.graph_expr = [] # type: List[expr.Expr]
super(ParseTreeToRelayIR, self).__init__()
def enter_var_scope(self):
# type: () -> None
def enter_var_scope(self) -> None:
"""Enter a new Var scope so it can be popped off later."""
self.var_scopes.appendleft(deque())
def exit_var_scope(self):
# type: () -> Scope[expr.Var]
def exit_var_scope(self) -> Scope[expr.Var]:
"""Pop off the current Var scope and return it."""
return self.var_scopes.popleft()
def mk_var(self, name, type_):
# type: (str, ty.Type) -> expr.Var
def mk_var(self, name: str, typ: ty.Type = None):
"""Create a new Var and add it to the Var scope."""
var = expr.Var(name, type_)
var = expr.Var(name, typ)
self.var_scopes[0].appendleft((name, var))
return var
def mk_global_var(self, name):
# type: (str) -> expr.GlobalVar
def mk_global_var(self, name: str) -> expr.GlobalVar:
"""Create a new GlobalVar and add it to the GlobalVar scope."""
if name in self.global_vars:
raise ParseError(f"duplicate global var \"{name}\"")
var = expr.GlobalVar(name)
self.global_var_scope.append((name, var))
self.global_vars[name] = var
return var
def enter_type_param_scope(self):
# type: () -> None
def enter_type_param_scope(self) -> None:
"""Enter a new TypeVar scope so it can be popped off later."""
self.type_var_scopes.appendleft(deque())
self.type_param_scopes.appendleft(deque())
def exit_type_param_scope(self):
# type: () -> Scope[ty.TypeVar]
def exit_type_param_scope(self) -> Scope[ty.TypeVar]:
"""Pop off the current TypeVar scope and return it."""
return self.type_var_scopes.popleft()
return self.type_param_scopes.popleft()
def mk_typ(self, name, kind):
# (str, ty.Kind) -> ty.TypeVar
def mk_typ(self, name: str, kind: ty.Kind) -> ty.TypeVar:
"""Create a new TypeVar and add it to the TypeVar scope."""
typ = ty.TypeVar(name, kind)
self.type_param_scopes[0].appendleft((name, typ))
self.type_var_scopes[0].appendleft((name, typ))
return typ
def mk_global_typ_var(self, name, kind):
# (str, ty.Kind) -> ty.GlobalTypeVar
"""Create a new TypeVar and add it to the TypeVar scope."""
typ = ty.GlobalTypeVar(name, kind)
self._check_existing_typ_expr(name, typ)
self.global_type_vars[name] = typ
return typ
# TODO: rethink whether we should have type constructors mixed with type vars.
def mk_global_typ_cons(self, name, cons):
self._check_existing_typ_expr(name, cons)
self.global_type_vars[name] = cons
def _check_existing_typ_expr(self, name, new_expr):
if name in self.global_type_vars:
new_typ_name = self._type_expr_name(new_expr)
existing_typ_name = self._type_expr_name(self.global_type_vars[name])
raise ParseError(
f"{new_typ_name} `{name}` conflicts with existing {existing_typ_name}")
def _type_expr_name(self, e):
if isinstance(e, adt.Constructor):
return f"`{e.belong_to.var.name}` ADT constructor"
elif isinstance(e, ty.GlobalTypeVar):
if e.kind == ty.Kind.AdtHandle:
return f"ADT definition"
return "function definition"
def visitProjection(self, ctx):
return expr.TupleGetItem(self.visit(ctx.expr()), self.visit(ctx.NAT()))
def visitTerminal(self, node):
# type: (TerminalNode) -> Union[expr.Expr, int, float]
def visitTerminal(self, node) -> Union[expr.Expr, int, float]:
"""Visit lexer tokens that aren't ignored or visited by other functions."""
node_type = node.getSymbol().type
node_text = node.getText()
name = node_text[1:]
# variables
if node_type == RelayLexer.GLOBAL_VAR:
return lookup(deque([self.global_var_scope]), node_text[1:])
if node_type == RelayLexer.LOCAL_VAR:
# Remove the leading '%' and lookup the name.
var = lookup(self.var_scopes, name)
if var is None:
raise ParseError("Couldn't resolve `{}`.".format(name))
return var
if node_type == RelayLexer.GRAPH_VAR:
try:
return self.graph_expr[int(name)]
except IndexError:
raise ParseError("Couldn't resolve `{}`".format(name))
# data types
if node_type == RelayLexer.NAT:
return int(node_text)
if node_type == RelayLexer.FLOAT:
......@@ -283,35 +279,67 @@ class ParseTreeToRelayIR(RelayVisitor):
return True
if node_text == "False":
return False
raise ParseError("Unrecognized BOOL_LIT: `{}`".format(node_text))
raise ParseError("unrecognized BOOL_LIT: `{}`".format(node_text))
if node_type == RelayLexer.QUOTED_STRING:
return literal_eval(node_text)
raise ParseError(f"unhandled terminal \"{node_text}\" of type `{node_type}`")
def visitGeneralIdent(self, ctx):
name = ctx.getText()
# Look through all type prefixes for a match.
for type_prefix in TYPE_PREFIXES:
if name.startswith(type_prefix):
return ty.scalar_type(name)
# Next, look it up in the local then global type params.
type_param = lookup(self.type_var_scopes, name)
if type_param is None:
type_param = self.global_type_vars.get(name, None)
if type_param is not None:
return type_param
# Check if it's an operator.
op_name = ".".join([name.getText() for name in ctx.CNAME()])
if op_name in FUNC_OPS:
return FuncOp(FUNC_OPS[op_name])
return ExprOp(op.get(op_name))
def visitGlobalVar(self, ctx):
var_name = ctx.CNAME().getText()
global_var = self.global_vars.get(var_name, None)
if global_var is None:
raise ParseError(f"unbound global var `{var_name}`")
return global_var
def visitLocalVar(self, ctx):
var_name = ctx.CNAME().getText()
local_var = lookup(self.var_scopes, var_name)
if local_var is None:
raise ParseError(f"unbound local var `{var_name}`")
return local_var
raise ParseError("todo: `{}`".format(node_text))
def visitGraphVar(self, ctx):
return self.graph_expr[int(ctx.NAT().getText())]
def visit_list(self, ctx_list):
# type: (List[ParserRuleContext]) -> List[Any]
def visit_list(self, ctx_list) -> List[Any]:
""""Visit a list of contexts."""
# type: RelayParser.ContextParserRuleContext
assert isinstance(ctx_list, list)
return [self.visit(ctx) for ctx in ctx_list]
def getType_(self, ctx):
# type: (Optional[RelayParser.Type_Context]) -> Optional[ty.Type]
def getTypeExpr(self, ctx) -> Optional[ty.Type]:
"""Return a (possibly None) Relay type."""
# type: : Optional[RelayParser.Type_Context]
if ctx is None:
return None
return self.visit(ctx)
def visitProg(self, ctx):
def visitProg(self, ctx: RelayParser.ProgContext) -> Union[expr.Expr, module.Module]:
self.meta = None
if ctx.METADATA():
header, data = str(ctx.METADATA()).split('\n', 1)
header, data = str(ctx.METADATA()).split("\n", 1)
assert header == "METADATA:"
self.meta = tvm.load_json(data)
# type: (RelayParser.ProgContext) -> Union[expr.Expr, module.Module]
if ctx.defn():
self.visit_list(ctx.defn())
return self.module
......@@ -322,37 +350,30 @@ class ParseTreeToRelayIR(RelayVisitor):
return self.module
# Exprs
def visitOpIdent(self, ctx):
# type: (RelayParser.OpIdentContext) -> op.Op
op_name = ctx.CNAME().getText()
def visitOpIdent(self, ctx) -> op.Op:
op_name = ".".join([name.getText() for name in ctx.CNAME()])
if op_name in FUNC_OPS:
return FuncOp(FUNC_OPS[op_name])
return ExprOp(op.get(op_name))
# pass through
def visitParen(self, ctx):
# type: (RelayParser.ParenContext) -> expr.Expr
def visitParen(self, ctx: RelayParser.ParenContext) -> expr.Expr:
return self.visit(ctx.expr())
# pass through
def visitBody(self, ctx):
# type: (RelayParser.BodyContext) -> expr.Expr
def visitBody(self, ctx: RelayParser.BodyContext) -> expr.Expr:
return self.visit(ctx.expr())
def visitScalarFloat(self, ctx):
# type: (RelayParser.ScalarFloatContext) -> expr.Constant
def visitScalarFloat(self, ctx: RelayParser.ScalarFloatContext) -> expr.Constant:
return expr.const(self.visit(ctx.FLOAT()))
def visitScalarInt(self, ctx):
# type: (RelayParser.ScalarIntContext) -> expr.Constant
def visitScalarInt(self, ctx: RelayParser.ScalarIntContext) -> expr.Constant:
return expr.const(self.visit(ctx.NAT()))
def visitScalarBool(self, ctx):
# type: (RelayParser.ScalarBoolContext) -> expr.Constant
def visitScalarBool(self, ctx: RelayParser.ScalarBoolContext) -> expr.Constant:
return expr.const(self.visit(ctx.BOOL_LIT()))
def visitNeg(self, ctx):
# type: (RelayParser.NegContext) -> Union[expr.Constant, expr.Call]
def visitNeg(self, ctx: RelayParser.NegContext) -> Union[expr.Constant, expr.Call]:
val = self.visit(ctx.expr())
if isinstance(val, expr.Constant) and val.data.asnumpy().ndim == 0:
# fold Neg in for scalars
......@@ -360,20 +381,18 @@ class ParseTreeToRelayIR(RelayVisitor):
return op.negative(val)
def visitTuple(self, ctx):
# type: (RelayParser.TupleContext) -> expr.Tuple
def visitTuple(self, ctx: RelayParser.TupleContext) -> expr.Tuple:
tup = self.visit_list(ctx.expr())
return expr.Tuple(tup)
def visitLet(self, ctx):
# type: (RelayParser.SeqContext) -> expr.Let
def visitLet(self, ctx: RelayParser.LetContext) -> expr.Let:
"""Desugar various sequence constructs to Relay Let nodes."""
if ctx.var() is None:
# anonymous identity
ident = "_"
type_ = None
var = self.mk_var(ident, type_)
typ = None
var = self.mk_var(ident, typ)
else:
var = self.visitVar(ctx.var())
......@@ -385,66 +404,61 @@ class ParseTreeToRelayIR(RelayVisitor):
return expr.Let(var, value, body)
def visitBinOp(self, ctx):
# type: (RelayParser.BinOpContext) -> expr.Call
def visitBinOp(self, ctx: RelayParser.BinOpContext) -> expr.Call:
"""Desugar binary operators."""
arg0, arg1 = self.visit_list(ctx.expr())
relay_op = BINARY_OPS.get(ctx.op.type)
if relay_op is None:
raise ParseError("Unimplemented binary op.")
raise ParseError("unimplemented binary op.")
return relay_op(arg0, arg1)
@spanify
def visitVar(self, ctx):
# type: (RelayParser.VarContext) -> expr.Var
def visitVar(self, ctx: RelayParser.VarContext) -> expr.Var:
"""Visit a single variable."""
ident = ctx.LOCAL_VAR()
ident = ctx.localVar()
if ident is None:
raise ParseError("Only local ids may be used in vars.")
raise ParseError("only local ids may be used in vars.")
type_ = self.getType_(ctx.type_())
typeExpr = self.getTypeExpr(ctx.typeExpr())
return self.mk_var(ident.getText()[1:], type_)
return self.mk_var(ident.getText()[1:], typeExpr)
def visitVarList(self, ctx):
# type: (RelayParser.VarListContext) -> List[expr.Var]
def visitVarList(self, ctx: RelayParser.VarListContext) -> List[expr.Var]:
return self.visit_list(ctx.var())
# TODO: support a larger class of values than just Relay exprs
def visitAttr(self, ctx):
# type: (RelayParser.AttrContext) -> Tuple[str, expr.Expr]
def visitAttr(self, ctx: RelayParser.AttrContext) -> Tuple[str, expr.Expr]:
return (ctx.CNAME().getText(), self.visit(ctx.expr()))
def visitArgNoAttr(self, ctx):
def visitArgNoAttr(self, ctx: RelayParser.ArgNoAttrContext):
return (self.visit_list(ctx.varList().var()), None)
def visitAttrSeq(self, ctx):
# type: (RelayParser.AttrListContext) -> Dict[str, expr.Expr]
def visitAttrSeq(self, ctx: RelayParser.AttrSeqContext) -> Dict[str, expr.Expr]:
return dict(self.visit_list(ctx.attr()))
def visitArgWithAttr(self, ctx):
def visitArgWithAttr(self, ctx: RelayParser.AttrSeqContext) \
-> Tuple[List[expr.Var], Dict[str, expr.Expr]]:
return (self.visit_list(ctx.var()), self.visitAttrSeq(ctx.attrSeq()))
def visitArgList(self,
ctx # type: RelayParser.ArgListContext
):
# type: (...) -> Tuple[Optional[List[expr.Var]], Optional[Dict[str, expr.Expr]]]
def visitArgList(self, ctx: RelayParser.ArgListContext) \
-> Tuple[Optional[List[expr.Var]], Optional[Dict[str, expr.Expr]]]:
var_list = self.visit(ctx.varList()) if ctx.varList() else None
attr_list = self.visit(ctx.attrList()) if ctx.attrList() else None
return (var_list, attr_list)
def visitMeta(self, ctx):
def visitMeta(self, ctx: RelayParser.MetaContext):
type_key = str(ctx.CNAME())
index = int(self.visit(ctx.NAT()))
return self.meta[type_key][index]
def mk_func(self, ctx):
# type: (Union[RelayParser.FuncContext, RelayParser.DefnContext]) -> expr.Function
def mk_func(
self,
ctx: Union[RelayParser.FuncContext, RelayParser.DefnContext]) \
-> expr.Function:
"""Construct a function from either a Func or Defn."""
# Enter var scope early to put params in scope.
self.enter_var_scope()
# Capture type params in params.
......@@ -452,7 +466,7 @@ class ParseTreeToRelayIR(RelayVisitor):
type_params = ctx.typeParamList()
if type_params is not None:
type_params = type_params.ident()
type_params = type_params.generalIdent()
assert type_params
for ty_param in type_params:
name = ty_param.getText()
......@@ -461,7 +475,7 @@ class ParseTreeToRelayIR(RelayVisitor):
var_list, attr_list = self.visit(ctx.argList())
if var_list is None:
var_list = []
ret_type = self.getType_(ctx.type_())
ret_type = self.getTypeExpr(ctx.typeExpr())
body = self.visit(ctx.body())
# NB(@jroesch): you must stay in the type parameter scope until
......@@ -476,41 +490,135 @@ class ParseTreeToRelayIR(RelayVisitor):
return expr.Function(var_list, body, ret_type, type_params, attrs)
@spanify
def visitFunc(self, ctx):
# type: (RelayParser.FuncContext) -> expr.Function
def visitFunc(self, ctx: RelayParser.FuncContext) -> expr.Function:
return self.mk_func(ctx)
# TODO: how to set spans for definitions?
# @spanify
def visitDefn(self, ctx):
# type: (RelayParser.DefnContext) -> None
ident = ctx.ident().GLOBAL_VAR()
if ident is None:
raise ParseError("Only global ids may be used in `def`s.")
ident_name = ident.getText()[1:]
def visitFuncDefn(self, ctx: RelayParser.DefnContext) -> None:
ident_name = ctx.globalVar().getText()[1:]
ident = self.mk_global_var(ident_name)
self.module[ident] = self.mk_func(ctx)
def visitCallNoAttr(self, ctx):
def handle_adt_header(
self,
ctx: Union[RelayParser.ExternAdtDefnContext, RelayParser.AdtDefnContext]):
"""Handles parsing of the name and type params of an ADT definition."""
adt_name = ctx.generalIdent().getText()
adt_var = self.mk_global_typ_var(adt_name, ty.Kind.AdtHandle)
# parse type params
type_params = ctx.typeParamList()
if type_params is None:
type_params = []
else:
type_params = [self.mk_typ(type_ident.getText(), ty.Kind.Type)
for type_ident in type_params.generalIdent()]
return adt_var, type_params
def visitExternAdtDefn(self, ctx: RelayParser.ExternAdtDefnContext):
# TODO(weberlo): update this handler once extern is implemented
self.enter_type_param_scope()
adt_var, type_params = self.handle_adt_header(ctx)
# update module being built
self.module[adt_var] = adt.TypeData(adt_var, type_params, [])
self.exit_type_param_scope()
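
As the TODO notes, `extern type` is only a placeholder for now: it registers a TypeData with no constructors. A hedged sketch using an illustrative `MyBuffer` name:

import tvm.relay as relay

mod = relay.fromtext("""
v0.0.4
extern type MyBuffer[A]
""")
# The module now carries a TypeData for MyBuffer with type var A and an
# empty constructor list, matching the dummy handler above.
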
def visitAdtDefn(self, ctx: RelayParser.AdtDefnContext):
self.enter_type_param_scope()
adt_var, type_params = self.handle_adt_header(ctx)
# parse constructors
adt_cons_defns = ctx.adtConsDefnList()
if adt_cons_defns is None:
adt_cons_defns = []
else:
adt_cons_defns = adt_cons_defns.adtConsDefn()
parsed_constructors = []
for cons_defn in adt_cons_defns:
inputs = [self.visit(inp) for inp in cons_defn.typeExpr()]
cons_defn_name = cons_defn.constructorName().getText()
cons_defn = adt.Constructor(cons_defn_name, inputs, adt_var)
self.mk_global_typ_cons(cons_defn_name, cons_defn)
parsed_constructors.append(cons_defn)
# update module being built
self.module[adt_var] = adt.TypeData(adt_var, type_params, parsed_constructors)
self.exit_type_param_scope()
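
For reference, a hedged sketch of the objects this handler builds for a definition such as `type Box[A] { Full(A), Empty, }`; the names are illustrative.

from tvm.relay import adt, ty

box = ty.GlobalTypeVar("Box", ty.Kind.AdtHandle)  # handle_adt_header
a = ty.TypeVar("A", ty.Kind.Type)                 # type parameter
full = adt.Constructor("Full", [a], box)          # one per adtConsDefn
empty = adt.Constructor("Empty", [], box)
box_data = adt.TypeData(box, [a], [full, empty])  # stored as module[box]
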
def visitMatch(self, ctx: RelayParser.MatchContext):
match_type = ctx.matchType().getText()
if match_type == "match":
complete_match = True
elif match_type == "match?":
complete_match = False
else:
raise RuntimeError(f"unknown match type {match_type}")
# TODO: Will need some kind of type checking to know which ADT is being
# matched on.
match_data = self.visit(ctx.expr())
match_clauses = ctx.matchClauseList()
if match_clauses is None:
match_clauses = []
else:
match_clauses = match_clauses.matchClause()
parsed_clauses = []
for clause in match_clauses:
constructor_name = clause.constructorName().getText()
constructor = self.global_type_vars[constructor_name]
self.enter_var_scope()
patternList = clause.patternList()
if patternList is None:
patterns = []
else:
patterns = [self.visit(pattern) for pattern in patternList.pattern()]
clause_body = self.visit(clause.expr())
self.exit_var_scope()
# TODO: Do we need to pass `None` if it's a 0-arity cons, or is an empty list fine?
parsed_clauses.append(adt.Clause(
adt.PatternConstructor(
constructor,
patterns
),
clause_body
))
return adt.Match(match_data, parsed_clauses, complete=complete_match)
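
The matchType text selects the `complete` flag: `match` yields `adt.Match(..., complete=True)` while `match?` skips the exhaustiveness error. A hedged sketch of a partial match; the `Box`/`Full`/`@unwrap` names are illustrative.

import tvm.relay as relay

partial = relay.fromtext("""
v0.0.4
type Box[A] { Full(A), Empty, }
def @unwrap(%b: Box[int32]) -> int32 {
  match? (%b) {
    Full(%v) => %v,
  }
}
""")
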
def visitPattern(self, ctx: RelayParser.PatternContext):
text = ctx.getText()
if text == "_":
return adt.PatternWildcard()
elif text.startswith("%"):
text = ctx.localVar().getText()
typ = ctx.typeExpr()
if typ is not None:
typ = self.visit(typ)
var = self.mk_var(text[1:], typ=typ)
return adt.PatternVar(var)
else:
raise ParseError(f"invalid pattern syntax \"{text}\"")
def visitCallNoAttr(self, ctx: RelayParser.CallNoAttrContext):
return (self.visit_list(ctx.exprList().expr()), None)
def visitCallWithAttr(self, ctx):
def visitCallWithAttr(self, ctx: RelayParser.CallWithAttrContext):
return (self.visit_list(ctx.expr()), self.visit(ctx.attrSeq()))
def call(self, func, args, attrs, type_args):
if isinstance(func, OpWrapper):
return func(args, attrs, type_args)
elif isinstance(func, adt.Constructor):
return func(*args)
return expr.Call(func, args, attrs, type_args)
@spanify
def visitCall(self, ctx):
def visitCall(self, ctx: RelayParser.CallContext):
# type: (RelayParser.CallContext) -> expr.Call
func = self.visit(ctx.expr())
args, attrs = self.visit(ctx.callList())
return self.call(func, args, attrs, [])
res = self.call(func, args, attrs, [])
return res
@spanify
def visitIfElse(self, ctx):
def visitIfElse(self, ctx: RelayParser.IfElseContext):
# type: (RelayParser.IfElseContext) -> expr.If
"""Construct a Relay If node. Creates a new scope for each branch."""
cond = self.visit(ctx.expr())
......@@ -526,10 +634,10 @@ class ParseTreeToRelayIR(RelayVisitor):
return expr.If(cond, true_branch, false_branch)
@spanify
def visitGraph(self, ctx):
def visitGraph(self, ctx: RelayParser.GraphContext):
# type: (RelayParser.GraphContext) -> expr.Expr
"""Visit a graph variable assignment."""
graph_nid = int(ctx.GRAPH_VAR().getText()[1:])
graph_nid = int(ctx.graphVar().getText()[1:])
self.enter_var_scope()
value = self.visit(ctx.expr(0))
......@@ -537,7 +645,7 @@ class ParseTreeToRelayIR(RelayVisitor):
if graph_nid != len(self.graph_expr):
raise ParseError(
"Expected new graph variable to be `%{}`,".format(len(self.graph_expr)) + \
"expected new graph variable to be `%{}`,".format(len(self.graph_expr)) + \
"but got `%{}`".format(graph_nid))
self.graph_expr.append(value)
......@@ -547,76 +655,47 @@ class ParseTreeToRelayIR(RelayVisitor):
# Types
# pylint: disable=unused-argument
def visitIncompleteType(self, ctx):
def visitIncompleteType(self, ctx: RelayParser.IncompleteTypeContext):
# type (RelayParser.IncompleteTypeContext) -> None:
return None
def visitTypeIdent(self, ctx):
# type: (RelayParser.TypeIdentContext) -> Union[ty.TensorType, str]
'''
Handle type identifier.
'''
type_ident = ctx.CNAME().getText()
# Look through all type prefixes for a match
for type_prefix in TYPE_PREFIXES:
if type_ident.startswith(type_prefix):
return ty.scalar_type(type_ident)
type_param = lookup(self.type_param_scopes, type_ident)
if type_param is not None:
return type_param
raise ParseError("Unknown builtin type: {}".format(type_ident))
# def visitCallType(self, ctx):
# # type: (RelayParser.CallTypeContext) -> Union[expr.Expr, ty.TensorType]
# ident_type = ctx.identType().CNAME().getText()
# args = self.visit_list(ctx.type_())
# if not args:
# raise ParseError("Type-level functions must have arguments!")
# func_type = TYPE_FUNCS.get(ident_type)(args)
# if func_type is None:
# raise ParseError("Unknown type-level function: `{}`".format(ident_type))
# else:
# return func_type
def visitTypeCallType(self, ctx: RelayParser.TypeCallTypeContext):
func = self.visit(ctx.generalIdent())
args = [self.visit(arg) for arg in ctx.typeParamList().generalIdent()]
return ty.TypeCall(func, args)
def visitParensShape(self, ctx):
def visitParensShape(self, ctx: RelayParser.ParensShapeContext):
# type: (RelayParser.ParensShapeContext) -> int
return self.visit(ctx.shape())
def visitShapeList(self, ctx):
def visitShapeList(self, ctx: RelayParser.ShapeListContext):
# type: (RelayParser.ShapeListContext) -> List[int]
return self.visit_list(ctx.shape())
def visitTensor(self, ctx):
def visitTensor(self, ctx: RelayParser.TensorContext):
return tuple(self.visit_list(ctx.expr()))
def visitTensorType(self, ctx):
def visitTensorType(self, ctx: RelayParser.TensorTypeContext):
# type: (RelayParser.TensorTypeContext) -> ty.TensorType
"""Create a simple tensor type. No generics."""
shape = self.visit(ctx.shapeList())
dtype = self.visit(ctx.type_())
dtype = self.visit(ctx.typeExpr())
if not isinstance(dtype, ty.TensorType):
raise ParseError("Expected dtype to be a Relay base type.")
raise ParseError("expected dtype to be a Relay base type.")
dtype = dtype.dtype
return ty.TensorType(shape, dtype)
def visitTupleType(self, ctx):
def visitTupleType(self, ctx: RelayParser.TupleTypeContext):
# type: (RelayParser.TupleTypeContext) -> ty.TupleType
return ty.TupleType(self.visit_list(ctx.type_()))
return ty.TupleType(self.visit_list(ctx.typeExpr()))
def visitFuncType(self, ctx):
def visitFuncType(self, ctx: RelayParser.FuncTypeContext):
# type: (RelayParser.FuncTypeContext) -> ty.FuncType
types = self.visit_list(ctx.type_())
types = self.visit_list(ctx.typeExpr())
arg_types = types[:-1]
ret_type = types[-1]
......@@ -663,7 +742,7 @@ def fromtext(data, source_name=None):
# type: (str, str) -> Union[expr.Expr, module.Module]
"""Parse a Relay program."""
if data == "":
raise ParseError("Cannot parse the empty string.")
raise ParseError("cannot parse the empty string.")
global __source_name_counter__
......
......@@ -17,11 +17,15 @@
* under the License.
*/
// list = *, seq = ?
/*
* NOTE: The `USE_ANTLR` option in `config.cmake` must be enabled in order for
* changes in this file to be reflected by the parser.
* NOTE: All upper-case rules are *lexer* rules and all camel-case rules are *parser* rules.
*/
grammar Relay;
SEMVER: 'v0.0.3' ;
SEMVER: 'v0.0.4' ;
// Lexing
// comments
......@@ -49,13 +53,8 @@ BOOL_LIT
| 'False'
;
CNAME: ('_'|LETTER) ('_'|LETTER|DIGIT)* ('.' CNAME)*;
opIdent: CNAME ;
GLOBAL_VAR: '@' CNAME ;
LOCAL_VAR: '%' CNAME;
GRAPH_VAR: '%' NAT;
CNAME: ('_'|LETTER) ('_'|LETTER|DIGIT)* ('.' CNAME)* ;
DATATYPE : 'int64';
// non-negative floats
fragment PREFLOAT : NAT ('.' NAT)? EXP?; // 1.35, 1.35E-9, 0.3, 4.5, 1, 1e10 3e4
......@@ -74,106 +73,124 @@ METADATA: 'METADATA:' .*;
// A Relay program is a list of global definitions or an expression.
prog: SEMVER (defn* | expr) METADATA? EOF ;
// option: 'set' ident BOOL_LIT ;
// Covers both operator and type idents
generalIdent: CNAME ('.' CNAME)*;
globalVar: '@' CNAME ;
localVar: '%' ('_' | CNAME) ;
graphVar: '%' NAT ;
exprList: (expr (',' expr)*)?;
callList
: exprList # callNoAttr
| (expr ',')* attrSeq # callWithAttr
: exprList # callNoAttr
| (expr ',')* attrSeq # callWithAttr
;
expr
// operators
: '(' expr ')' # paren
| '{' expr '}' # paren
: '(' expr ')' # paren
// function application
| expr '(' callList ')' # call
| '-' expr # neg
| expr op=('*'|'/') expr # binOp
| expr op=('+'|'-') expr # binOp
| expr op=('<'|'>'|'<='|'>=') expr # binOp
| expr op=('=='|'!=') expr # binOp
| expr '(' callList ')' # call
| '-' expr # neg
| expr op=('*'|'/') expr # binOp
| expr op=('+'|'-') expr # binOp
| expr op=('<'|'>'|'<='|'>=') expr # binOp
| expr op=('=='|'!=') expr # binOp
// function definition
| func # funcExpr
| func # funcExpr
// tuples and tensors
| '(' ')' # tuple
| '(' expr ',' ')' # tuple
| '(' expr (',' expr)+ ')' # tuple
| expr '.' NAT # projection
| '[' (expr (',' expr)*)? ']' # tensor
| 'if' '(' expr ')' body 'else' body # ifElse
| '(' ')' # tuple
| '(' expr ',' ')' # tuple
| '(' expr (',' expr)+ ')' # tuple
| '[' (expr (',' expr)*)? ']' # tensor
| 'if' '(' expr ')' body 'else' body # ifElse
| matchType '(' expr ')' '{' matchClauseList? '}' # match
| expr '.' NAT # projection
// sequencing
| 'let' var '=' expr ';' expr # let
| 'let' var '=' expr ';' expr # let
// sugar for let %_ = expr; expr
| expr ';;' expr # let
| GRAPH_VAR '=' expr ';' expr # graph
| ident # identExpr
| scalar # scalarExpr
| meta # metaExpr
| QUOTED_STRING # stringExpr
| expr ';;' expr # let
| graphVar '=' expr ';' expr # graph
| ident # identExpr
| scalar # scalarExpr
| meta # metaExpr
| QUOTED_STRING # stringExpr
;
func: 'fn' typeParamList? '(' argList ')' ('->' typeExpr)? body ;
defn
: 'def' globalVar typeParamList? '(' argList ')' ('->' typeExpr)? body # funcDefn
| 'extern' 'type' generalIdent typeParamList? # externAdtDefn
| 'type' generalIdent typeParamList? '{' adtConsDefnList? '}' # adtDefn
;
constructorName: CNAME ;
adtConsDefnList: adtConsDefn (',' adtConsDefn)* ','? ;
adtConsDefn: constructorName ('(' typeExpr (',' typeExpr)* ')')? ;
matchClauseList: matchClause (',' matchClause)* ','? ;
matchClause: constructorName patternList? '=>' ('{' expr '}' | expr) ;
// complete or incomplete match, respectively
matchType : 'match' | 'match?' ;
patternList: '(' pattern (',' pattern)* ')';
pattern
: '_'
| localVar (':' typeExpr)?
;
func: 'fn' typeParamList? '(' argList ')' ('->' type_)? body ;
defn: 'def' ident typeParamList? '(' argList ')' ('->' type_)? body ;
adtCons: constructorName adtConsParamList? ;
adtConsParamList: '(' adtConsParam (',' adtConsParam)* ')' ;
adtConsParam: localVar | constructorName ;
argList
: varList # argNoAttr
| (var ',')* attrSeq # argWithAttr
: varList # argNoAttr
| (var ',')* attrSeq # argWithAttr
;
varList: (var (',' var)*)?;
var: LOCAL_VAR (':' type_)?;
varList: (var (',' var)*)? ;
var: localVar (':' typeExpr)? ;
attrSeq: attr (',' attr)*;
attrSeq: attr (',' attr)* ;
attr: CNAME '=' expr ;
typeParamList
: '[' ']'
| '[' ident (',' ident)* ']'
typeExpr
: '(' ')' # tupleType
| '(' typeExpr ',' ')' # tupleType
| '(' typeExpr (',' typeExpr)+ ')' # tupleType
| generalIdent typeParamList # typeCallType
| generalIdent # typeIdentType
| 'Tensor' '[' shapeList ',' typeExpr ']' # tensorType
| 'fn' typeParamList? '(' (typeExpr (',' typeExpr)*)? ')' '->' typeExpr # funcType
| '_' # incompleteType
;
type_
: '(' ')' # tupleType
| '(' type_ ',' ')' # tupleType
| '(' type_ (',' type_)+ ')' # tupleType
| typeIdent # typeIdentType
| 'Tensor' '[' shapeList ',' type_ ']' # tensorType
| 'fn' typeParamList? '(' (type_ (',' type_)*)? ')' '->' type_ # funcType
| '_' # incompleteType
| NAT # intType
;
typeParamList: '[' generalIdent (',' generalIdent)* ']' ;
shapeList
: '(' shape (',' shape)+ ')'
| '(' ')'
: '(' ')'
| '(' shape (',' shape)+ ')'
| shape
;
meta : 'meta' '[' CNAME ']' '[' NAT ']';
shape
: meta # metaShape
| '(' shape ')' # parensShape
| NAT # intShape
: meta # metaShape
| '(' shape ')' # parensShape
| NAT # intShape
;
typeIdent : CNAME;
// int8, int16, int32, int64
// uint8, uint16, uint32, uint64
// float16, float32, float64
// bool
body: '{' expr '}' ;
scalar
: FLOAT # scalarFloat
| NAT # scalarInt
| BOOL_LIT # scalarBool
: FLOAT # scalarFloat
| NAT # scalarInt
| BOOL_LIT # scalarBool
;
ident
: opIdent
| GLOBAL_VAR
| LOCAL_VAR
| GRAPH_VAR
: generalIdent
| globalVar
| localVar
| graphVar
;
# Generated from /workspace/python/tvm/relay/grammar/Relay.g4 by ANTLR 4.7.1
# Generated from /Users/doobs/Code/repo/sampl/tvm/python/tvm/relay/grammar/Relay.g4 by ANTLR 4.7.2
from antlr4 import *
from io import StringIO
from typing.io import TextIO
import sys
def serializedATN():
with StringIO() as buf:
buf.write("\3\u608b\ua72a\u8133\ub9ed\u417c\u3be7\u7786\u5964\2/")
buf.write("\u014a\b\1\4\2\t\2\4\3\t\3\4\4\t\4\4\5\t\5\4\6\t\6\4\7")
buf.write("\3\u608b\ua72a\u8133\ub9ed\u417c\u3be7\u7786\u5964\2\62")
buf.write("\u0161\b\1\4\2\t\2\4\3\t\3\4\4\t\4\4\5\t\5\4\6\t\6\4\7")
buf.write("\t\7\4\b\t\b\4\t\t\t\4\n\t\n\4\13\t\13\4\f\t\f\4\r\t\r")
buf.write("\4\16\t\16\4\17\t\17\4\20\t\20\4\21\t\21\4\22\t\22\4\23")
buf.write("\t\23\4\24\t\24\4\25\t\25\4\26\t\26\4\27\t\27\4\30\t\30")
buf.write("\4\31\t\31\4\32\t\32\4\33\t\33\4\34\t\34\4\35\t\35\4\36")
buf.write("\t\36\4\37\t\37\4 \t \4!\t!\4\"\t\"\4#\t#\4$\t$\4%\t%")
buf.write("\4&\t&\4\'\t\'\4(\t(\4)\t)\4*\t*\4+\t+\4,\t,\4-\t-\4.")
buf.write("\t.\4/\t/\4\60\t\60\4\61\t\61\4\62\t\62\4\63\t\63\3\2")
buf.write("\3\2\3\3\3\3\3\4\3\4\3\5\3\5\3\6\3\6\3\7\3\7\3\b\3\b\3")
buf.write("\t\3\t\3\n\3\n\3\n\3\13\3\13\3\13\3\13\3\13\3\f\3\f\3")
buf.write("\f\3\f\3\r\3\r\3\16\3\16\3\17\3\17\3\17\3\20\3\20\3\20")
buf.write("\3\21\3\21\3\21\3\22\3\22\3\22\3\22\3\23\3\23\3\24\3\24")
buf.write("\3\24\3\24\3\24\3\24\3\24\3\25\3\25\3\26\3\26\3\26\3\26")
buf.write("\3\26\3\27\3\27\3\27\3\27\3\27\3\27\3\27\3\30\3\30\3\30")
buf.write("\3\30\3\30\7\30\u00b1\n\30\f\30\16\30\u00b4\13\30\3\30")
buf.write("\3\30\3\30\3\30\3\30\3\31\6\31\u00bc\n\31\r\31\16\31\u00bd")
buf.write("\3\31\3\31\3\32\3\32\3\32\3\32\7\32\u00c6\n\32\f\32\16")
buf.write("\32\u00c9\13\32\3\32\3\32\3\32\3\32\3\33\3\33\3\33\3\34")
buf.write("\3\34\3\34\7\34\u00d5\n\34\f\34\16\34\u00d8\13\34\3\34")
buf.write("\3\34\3\35\3\35\3\36\3\36\3\37\3\37\3 \3 \3!\3!\3\"\3")
buf.write("\"\3#\3#\3#\3$\3$\3$\3%\3%\3%\3&\3&\3&\3\'\3\'\3\'\3\'")
buf.write("\3\'\3\'\3\'\3\'\3\'\5\'\u00fd\n\'\3(\3(\5(\u0101\n(\3")
buf.write("(\3(\3(\7(\u0106\n(\f(\16(\u0109\13(\3(\3(\7(\u010d\n")
buf.write("(\f(\16(\u0110\13(\3)\3)\3)\3*\3*\3*\3+\3+\3+\3,\3,\3")
buf.write(",\3,\3,\3,\3-\3-\3-\5-\u0124\n-\3-\5-\u0127\n-\3.\3.\3")
buf.write(".\3/\6/\u012d\n/\r/\16/\u012e\3\60\3\60\5\60\u0133\n\60")
buf.write("\3\60\3\60\3\61\3\61\3\62\3\62\3\63\3\63\3\63\3\63\3\63")
buf.write("\3\63\3\63\3\63\3\63\3\63\3\63\7\63\u0146\n\63\f\63\16")
buf.write("\63\u0149\13\63\5\u00b2\u00c7\u00d6\2\64\3\3\5\4\7\5\t")
buf.write("\6\13\7\r\b\17\t\21\n\23\13\25\f\27\r\31\16\33\17\35\20")
buf.write("\37\21!\22#\23%\24\'\25)\26+\27-\30/\31\61\32\63\33\65")
buf.write("\2\67\349\35;\36=\37? A!C\"E#G$I%K&M\'O(Q)S*U+W,Y\2[-")
buf.write("]._\2a\2c\2e/\3\2\b\5\2\13\f\17\17\"\"\4\2\f\f\17\17\4")
buf.write("\2GGgg\4\2--//\4\2C\\c|\3\2\62;\2\u0155\2\3\3\2\2\2\2")
buf.write("\t.\4/\t/\4\60\t\60\4\61\t\61\4\62\t\62\4\63\t\63\4\64")
buf.write("\t\64\4\65\t\65\4\66\t\66\3\2\3\2\3\3\3\3\3\4\3\4\3\5")
buf.write("\3\5\3\6\3\6\3\7\3\7\3\b\3\b\3\t\3\t\3\n\3\n\3\13\3\13")
buf.write("\3\13\3\f\3\f\3\f\3\f\3\f\3\r\3\r\3\16\3\16\3\17\3\17")
buf.write("\3\17\3\17\3\20\3\20\3\21\3\21\3\22\3\22\3\22\3\23\3\23")
buf.write("\3\23\3\24\3\24\3\24\3\25\3\25\3\25\3\25\3\26\3\26\3\26")
buf.write("\3\26\3\26\3\26\3\26\3\27\3\27\3\27\3\27\3\27\3\30\3\30")
buf.write("\3\30\3\31\3\31\3\31\3\31\3\31\3\31\3\32\3\32\3\32\3\32")
buf.write("\3\32\3\32\3\32\3\33\3\33\3\34\3\34\3\34\3\34\3\34\3\34")
buf.write("\3\34\3\35\3\35\3\35\3\35\3\35\3\36\3\36\3\36\3\36\3\36")
buf.write("\3\36\3\36\3\37\3\37\3\37\3\37\3\37\7\37\u00d7\n\37\f")
buf.write("\37\16\37\u00da\13\37\3\37\3\37\3\37\3\37\3\37\3 \6 \u00e2")
buf.write("\n \r \16 \u00e3\3 \3 \3!\3!\3!\3!\7!\u00ec\n!\f!\16!")
buf.write("\u00ef\13!\3!\3!\3!\3!\3\"\3\"\3\"\3#\3#\3#\7#\u00fb\n")
buf.write("#\f#\16#\u00fe\13#\3#\3#\3$\3$\3%\3%\3&\3&\3\'\3\'\3(")
buf.write("\3(\3)\3)\3*\3*\3*\3+\3+\3+\3,\3,\3,\3-\3-\3-\3.\3.\3")
buf.write(".\3.\3.\3.\3.\3.\3.\5.\u0123\n.\3/\3/\5/\u0127\n/\3/\3")
buf.write("/\3/\7/\u012c\n/\f/\16/\u012f\13/\3/\3/\7/\u0133\n/\f")
buf.write("/\16/\u0136\13/\3\60\3\60\3\60\5\60\u013b\n\60\3\60\5")
buf.write("\60\u013e\n\60\3\61\3\61\3\61\3\62\6\62\u0144\n\62\r\62")
buf.write("\16\62\u0145\3\63\3\63\5\63\u014a\n\63\3\63\3\63\3\64")
buf.write("\3\64\3\65\3\65\3\66\3\66\3\66\3\66\3\66\3\66\3\66\3\66")
buf.write("\3\66\3\66\3\66\7\66\u015d\n\66\f\66\16\66\u0160\13\66")
buf.write("\5\u00d8\u00ed\u00fc\2\67\3\3\5\4\7\5\t\6\13\7\r\b\17")
buf.write("\t\21\n\23\13\25\f\27\r\31\16\33\17\35\20\37\21!\22#\23")
buf.write("%\24\'\25)\26+\27-\30/\31\61\32\63\33\65\34\67\359\36")
buf.write(";\37= ?!A\"C\2E#G$I%K&M\'O(Q)S*U+W,Y-[.]/_\2a\60c\61e")
buf.write("\2g\2i\2k\62\3\2\b\5\2\13\f\17\17\"\"\4\2\f\f\17\17\4")
buf.write("\2GGgg\4\2--//\4\2C\\c|\3\2\62;\2\u016c\2\3\3\2\2\2\2")
buf.write("\5\3\2\2\2\2\7\3\2\2\2\2\t\3\2\2\2\2\13\3\2\2\2\2\r\3")
buf.write("\2\2\2\2\17\3\2\2\2\2\21\3\2\2\2\2\23\3\2\2\2\2\25\3\2")
buf.write("\2\2\2\27\3\2\2\2\2\31\3\2\2\2\2\33\3\2\2\2\2\35\3\2\2")
buf.write("\2\2\37\3\2\2\2\2!\3\2\2\2\2#\3\2\2\2\2%\3\2\2\2\2\'\3")
buf.write("\2\2\2\2)\3\2\2\2\2+\3\2\2\2\2-\3\2\2\2\2/\3\2\2\2\2\61")
buf.write("\3\2\2\2\2\63\3\2\2\2\2\67\3\2\2\2\29\3\2\2\2\2;\3\2\2")
buf.write("\2\2=\3\2\2\2\2?\3\2\2\2\2A\3\2\2\2\2C\3\2\2\2\2E\3\2")
buf.write("\2\2\2G\3\2\2\2\2I\3\2\2\2\2K\3\2\2\2\2M\3\2\2\2\2O\3")
buf.write("\2\2\2\2Q\3\2\2\2\2S\3\2\2\2\2U\3\2\2\2\2W\3\2\2\2\2[")
buf.write("\3\2\2\2\2]\3\2\2\2\2e\3\2\2\2\3g\3\2\2\2\5i\3\2\2\2\7")
buf.write("k\3\2\2\2\tm\3\2\2\2\13o\3\2\2\2\rq\3\2\2\2\17s\3\2\2")
buf.write("\2\21u\3\2\2\2\23w\3\2\2\2\25z\3\2\2\2\27\177\3\2\2\2")
buf.write("\31\u0083\3\2\2\2\33\u0085\3\2\2\2\35\u0087\3\2\2\2\37")
buf.write("\u008a\3\2\2\2!\u008d\3\2\2\2#\u0090\3\2\2\2%\u0094\3")
buf.write("\2\2\2\'\u0096\3\2\2\2)\u009d\3\2\2\2+\u009f\3\2\2\2-")
buf.write("\u00a4\3\2\2\2/\u00ab\3\2\2\2\61\u00bb\3\2\2\2\63\u00c1")
buf.write("\3\2\2\2\65\u00ce\3\2\2\2\67\u00d1\3\2\2\29\u00db\3\2")
buf.write("\2\2;\u00dd\3\2\2\2=\u00df\3\2\2\2?\u00e1\3\2\2\2A\u00e3")
buf.write("\3\2\2\2C\u00e5\3\2\2\2E\u00e7\3\2\2\2G\u00ea\3\2\2\2")
buf.write("I\u00ed\3\2\2\2K\u00f0\3\2\2\2M\u00fc\3\2\2\2O\u0100\3")
buf.write("\2\2\2Q\u0111\3\2\2\2S\u0114\3\2\2\2U\u0117\3\2\2\2W\u011a")
buf.write("\3\2\2\2Y\u0120\3\2\2\2[\u0128\3\2\2\2]\u012c\3\2\2\2")
buf.write("_\u0130\3\2\2\2a\u0136\3\2\2\2c\u0138\3\2\2\2e\u013a\3")
buf.write("\2\2\2gh\7.\2\2h\4\3\2\2\2ij\7*\2\2j\6\3\2\2\2kl\7+\2")
buf.write("\2l\b\3\2\2\2mn\7}\2\2n\n\3\2\2\2op\7\177\2\2p\f\3\2\2")
buf.write("\2qr\7\60\2\2r\16\3\2\2\2st\7]\2\2t\20\3\2\2\2uv\7_\2")
buf.write("\2v\22\3\2\2\2wx\7k\2\2xy\7h\2\2y\24\3\2\2\2z{\7g\2\2")
buf.write("{|\7n\2\2|}\7u\2\2}~\7g\2\2~\26\3\2\2\2\177\u0080\7n\2")
buf.write("\2\u0080\u0081\7g\2\2\u0081\u0082\7v\2\2\u0082\30\3\2")
buf.write("\2\2\u0083\u0084\7?\2\2\u0084\32\3\2\2\2\u0085\u0086\7")
buf.write("=\2\2\u0086\34\3\2\2\2\u0087\u0088\7=\2\2\u0088\u0089")
buf.write("\7=\2\2\u0089\36\3\2\2\2\u008a\u008b\7h\2\2\u008b\u008c")
buf.write("\7p\2\2\u008c \3\2\2\2\u008d\u008e\7/\2\2\u008e\u008f")
buf.write("\7@\2\2\u008f\"\3\2\2\2\u0090\u0091\7f\2\2\u0091\u0092")
buf.write("\7g\2\2\u0092\u0093\7h\2\2\u0093$\3\2\2\2\u0094\u0095")
buf.write("\7<\2\2\u0095&\3\2\2\2\u0096\u0097\7V\2\2\u0097\u0098")
buf.write("\7g\2\2\u0098\u0099\7p\2\2\u0099\u009a\7u\2\2\u009a\u009b")
buf.write("\7q\2\2\u009b\u009c\7t\2\2\u009c(\3\2\2\2\u009d\u009e")
buf.write("\7a\2\2\u009e*\3\2\2\2\u009f\u00a0\7o\2\2\u00a0\u00a1")
buf.write("\7g\2\2\u00a1\u00a2\7v\2\2\u00a2\u00a3\7c\2\2\u00a3,\3")
buf.write("\2\2\2\u00a4\u00a5\7x\2\2\u00a5\u00a6\7\62\2\2\u00a6\u00a7")
buf.write("\7\60\2\2\u00a7\u00a8\7\62\2\2\u00a8\u00a9\7\60\2\2\u00a9")
buf.write("\u00aa\7\65\2\2\u00aa.\3\2\2\2\u00ab\u00ac\7\61\2\2\u00ac")
buf.write("\u00ad\7,\2\2\u00ad\u00b2\3\2\2\2\u00ae\u00b1\5/\30\2")
buf.write("\u00af\u00b1\13\2\2\2\u00b0\u00ae\3\2\2\2\u00b0\u00af")
buf.write("\3\2\2\2\u00b1\u00b4\3\2\2\2\u00b2\u00b3\3\2\2\2\u00b2")
buf.write("\u00b0\3\2\2\2\u00b3\u00b5\3\2\2\2\u00b4\u00b2\3\2\2\2")
buf.write("\u00b5\u00b6\7,\2\2\u00b6\u00b7\7\61\2\2\u00b7\u00b8\3")
buf.write("\2\2\2\u00b8\u00b9\b\30\2\2\u00b9\60\3\2\2\2\u00ba\u00bc")
buf.write("\t\2\2\2\u00bb\u00ba\3\2\2\2\u00bc\u00bd\3\2\2\2\u00bd")
buf.write("\u00bb\3\2\2\2\u00bd\u00be\3\2\2\2\u00be\u00bf\3\2\2\2")
buf.write("\u00bf\u00c0\b\31\2\2\u00c0\62\3\2\2\2\u00c1\u00c2\7\61")
buf.write("\2\2\u00c2\u00c3\7\61\2\2\u00c3\u00c7\3\2\2\2\u00c4\u00c6")
buf.write("\13\2\2\2\u00c5\u00c4\3\2\2\2\u00c6\u00c9\3\2\2\2\u00c7")
buf.write("\u00c8\3\2\2\2\u00c7\u00c5\3\2\2\2\u00c8\u00ca\3\2\2\2")
buf.write("\u00c9\u00c7\3\2\2\2\u00ca\u00cb\7\f\2\2\u00cb\u00cc\3")
buf.write("\2\2\2\u00cc\u00cd\b\32\2\2\u00cd\64\3\2\2\2\u00ce\u00cf")
buf.write("\7^\2\2\u00cf\u00d0\7$\2\2\u00d0\66\3\2\2\2\u00d1\u00d6")
buf.write("\7$\2\2\u00d2\u00d5\5\65\33\2\u00d3\u00d5\n\3\2\2\u00d4")
buf.write("\u00d2\3\2\2\2\u00d4\u00d3\3\2\2\2\u00d5\u00d8\3\2\2\2")
buf.write("\u00d6\u00d7\3\2\2\2\u00d6\u00d4\3\2\2\2\u00d7\u00d9\3")
buf.write("\2\2\2\u00d8\u00d6\3\2\2\2\u00d9\u00da\7$\2\2\u00da8\3")
buf.write("\2\2\2\u00db\u00dc\7,\2\2\u00dc:\3\2\2\2\u00dd\u00de\7")
buf.write("\61\2\2\u00de<\3\2\2\2\u00df\u00e0\7-\2\2\u00e0>\3\2\2")
buf.write("\2\u00e1\u00e2\7/\2\2\u00e2@\3\2\2\2\u00e3\u00e4\7>\2")
buf.write("\2\u00e4B\3\2\2\2\u00e5\u00e6\7@\2\2\u00e6D\3\2\2\2\u00e7")
buf.write("\u00e8\7>\2\2\u00e8\u00e9\7?\2\2\u00e9F\3\2\2\2\u00ea")
buf.write("\u00eb\7@\2\2\u00eb\u00ec\7?\2\2\u00ecH\3\2\2\2\u00ed")
buf.write("\u00ee\7?\2\2\u00ee\u00ef\7?\2\2\u00efJ\3\2\2\2\u00f0")
buf.write("\u00f1\7#\2\2\u00f1\u00f2\7?\2\2\u00f2L\3\2\2\2\u00f3")
buf.write("\u00f4\7V\2\2\u00f4\u00f5\7t\2\2\u00f5\u00f6\7w\2\2\u00f6")
buf.write("\u00fd\7g\2\2\u00f7\u00f8\7H\2\2\u00f8\u00f9\7c\2\2\u00f9")
buf.write("\u00fa\7n\2\2\u00fa\u00fb\7u\2\2\u00fb\u00fd\7g\2\2\u00fc")
buf.write("\u00f3\3\2\2\2\u00fc\u00f7\3\2\2\2\u00fdN\3\2\2\2\u00fe")
buf.write("\u0101\7a\2\2\u00ff\u0101\5a\61\2\u0100\u00fe\3\2\2\2")
buf.write("\u0100\u00ff\3\2\2\2\u0101\u0107\3\2\2\2\u0102\u0106\7")
buf.write("a\2\2\u0103\u0106\5a\61\2\u0104\u0106\5c\62\2\u0105\u0102")
buf.write("\3\2\2\2\u0105\u0103\3\2\2\2\u0105\u0104\3\2\2\2\u0106")
buf.write("\u0109\3\2\2\2\u0107\u0105\3\2\2\2\u0107\u0108\3\2\2\2")
buf.write("\u0108\u010e\3\2\2\2\u0109\u0107\3\2\2\2\u010a\u010b\7")
buf.write("\60\2\2\u010b\u010d\5O(\2\u010c\u010a\3\2\2\2\u010d\u0110")
buf.write("\3\2\2\2\u010e\u010c\3\2\2\2\u010e\u010f\3\2\2\2\u010f")
buf.write("P\3\2\2\2\u0110\u010e\3\2\2\2\u0111\u0112\7B\2\2\u0112")
buf.write("\u0113\5O(\2\u0113R\3\2\2\2\u0114\u0115\7\'\2\2\u0115")
buf.write("\u0116\5O(\2\u0116T\3\2\2\2\u0117\u0118\7\'\2\2\u0118")
buf.write("\u0119\5]/\2\u0119V\3\2\2\2\u011a\u011b\7k\2\2\u011b\u011c")
buf.write("\7p\2\2\u011c\u011d\7v\2\2\u011d\u011e\78\2\2\u011e\u011f")
buf.write("\7\66\2\2\u011fX\3\2\2\2\u0120\u0123\5]/\2\u0121\u0122")
buf.write("\7\60\2\2\u0122\u0124\5]/\2\u0123\u0121\3\2\2\2\u0123")
buf.write("\u0124\3\2\2\2\u0124\u0126\3\2\2\2\u0125\u0127\5_\60\2")
buf.write("\u0126\u0125\3\2\2\2\u0126\u0127\3\2\2\2\u0127Z\3\2\2")
buf.write("\2\u0128\u0129\5Y-\2\u0129\u012a\7h\2\2\u012a\\\3\2\2")
buf.write("\2\u012b\u012d\5c\62\2\u012c\u012b\3\2\2\2\u012d\u012e")
buf.write("\3\2\2\2\u012e\u012c\3\2\2\2\u012e\u012f\3\2\2\2\u012f")
buf.write("^\3\2\2\2\u0130\u0132\t\4\2\2\u0131\u0133\t\5\2\2\u0132")
buf.write("\u0131\3\2\2\2\u0132\u0133\3\2\2\2\u0133\u0134\3\2\2\2")
buf.write("\u0134\u0135\5]/\2\u0135`\3\2\2\2\u0136\u0137\t\6\2\2")
buf.write("\u0137b\3\2\2\2\u0138\u0139\t\7\2\2\u0139d\3\2\2\2\u013a")
buf.write("\u013b\7O\2\2\u013b\u013c\7G\2\2\u013c\u013d\7V\2\2\u013d")
buf.write("\u013e\7C\2\2\u013e\u013f\7F\2\2\u013f\u0140\7C\2\2\u0140")
buf.write("\u0141\7V\2\2\u0141\u0142\7C\2\2\u0142\u0143\7<\2\2\u0143")
buf.write("\u0147\3\2\2\2\u0144\u0146\13\2\2\2\u0145\u0144\3\2\2")
buf.write("\2\u0146\u0149\3\2\2\2\u0147\u0145\3\2\2\2\u0147\u0148")
buf.write("\3\2\2\2\u0148f\3\2\2\2\u0149\u0147\3\2\2\2\23\2\u00b0")
buf.write("\u00b2\u00bd\u00c7\u00d4\u00d6\u00fc\u0100\u0105\u0107")
buf.write("\u010e\u0123\u0126\u012e\u0132\u0147\3\b\2\2")
buf.write("\3\2\2\2\2\63\3\2\2\2\2\65\3\2\2\2\2\67\3\2\2\2\29\3\2")
buf.write("\2\2\2;\3\2\2\2\2=\3\2\2\2\2?\3\2\2\2\2A\3\2\2\2\2E\3")
buf.write("\2\2\2\2G\3\2\2\2\2I\3\2\2\2\2K\3\2\2\2\2M\3\2\2\2\2O")
buf.write("\3\2\2\2\2Q\3\2\2\2\2S\3\2\2\2\2U\3\2\2\2\2W\3\2\2\2\2")
buf.write("Y\3\2\2\2\2[\3\2\2\2\2]\3\2\2\2\2a\3\2\2\2\2c\3\2\2\2")
buf.write("\2k\3\2\2\2\3m\3\2\2\2\5o\3\2\2\2\7q\3\2\2\2\ts\3\2\2")
buf.write("\2\13u\3\2\2\2\rw\3\2\2\2\17y\3\2\2\2\21{\3\2\2\2\23}")
buf.write("\3\2\2\2\25\177\3\2\2\2\27\u0082\3\2\2\2\31\u0087\3\2")
buf.write("\2\2\33\u0089\3\2\2\2\35\u008b\3\2\2\2\37\u008f\3\2\2")
buf.write("\2!\u0091\3\2\2\2#\u0093\3\2\2\2%\u0096\3\2\2\2\'\u0099")
buf.write("\3\2\2\2)\u009c\3\2\2\2+\u00a0\3\2\2\2-\u00a7\3\2\2\2")
buf.write("/\u00ac\3\2\2\2\61\u00af\3\2\2\2\63\u00b5\3\2\2\2\65\u00bc")
buf.write("\3\2\2\2\67\u00be\3\2\2\29\u00c5\3\2\2\2;\u00ca\3\2\2")
buf.write("\2=\u00d1\3\2\2\2?\u00e1\3\2\2\2A\u00e7\3\2\2\2C\u00f4")
buf.write("\3\2\2\2E\u00f7\3\2\2\2G\u0101\3\2\2\2I\u0103\3\2\2\2")
buf.write("K\u0105\3\2\2\2M\u0107\3\2\2\2O\u0109\3\2\2\2Q\u010b\3")
buf.write("\2\2\2S\u010d\3\2\2\2U\u0110\3\2\2\2W\u0113\3\2\2\2Y\u0116")
buf.write("\3\2\2\2[\u0122\3\2\2\2]\u0126\3\2\2\2_\u0137\3\2\2\2")
buf.write("a\u013f\3\2\2\2c\u0143\3\2\2\2e\u0147\3\2\2\2g\u014d\3")
buf.write("\2\2\2i\u014f\3\2\2\2k\u0151\3\2\2\2mn\7\60\2\2n\4\3\2")
buf.write("\2\2op\7B\2\2p\6\3\2\2\2qr\7\'\2\2r\b\3\2\2\2st\7a\2\2")
buf.write("t\n\3\2\2\2uv\7.\2\2v\f\3\2\2\2wx\7*\2\2x\16\3\2\2\2y")
buf.write("z\7+\2\2z\20\3\2\2\2{|\7]\2\2|\22\3\2\2\2}~\7_\2\2~\24")
buf.write("\3\2\2\2\177\u0080\7k\2\2\u0080\u0081\7h\2\2\u0081\26")
buf.write("\3\2\2\2\u0082\u0083\7g\2\2\u0083\u0084\7n\2\2\u0084\u0085")
buf.write("\7u\2\2\u0085\u0086\7g\2\2\u0086\30\3\2\2\2\u0087\u0088")
buf.write("\7}\2\2\u0088\32\3\2\2\2\u0089\u008a\7\177\2\2\u008a\34")
buf.write("\3\2\2\2\u008b\u008c\7n\2\2\u008c\u008d\7g\2\2\u008d\u008e")
buf.write("\7v\2\2\u008e\36\3\2\2\2\u008f\u0090\7?\2\2\u0090 \3\2")
buf.write("\2\2\u0091\u0092\7=\2\2\u0092\"\3\2\2\2\u0093\u0094\7")
buf.write("=\2\2\u0094\u0095\7=\2\2\u0095$\3\2\2\2\u0096\u0097\7")
buf.write("h\2\2\u0097\u0098\7p\2\2\u0098&\3\2\2\2\u0099\u009a\7")
buf.write("/\2\2\u009a\u009b\7@\2\2\u009b(\3\2\2\2\u009c\u009d\7")
buf.write("f\2\2\u009d\u009e\7g\2\2\u009e\u009f\7h\2\2\u009f*\3\2")
buf.write("\2\2\u00a0\u00a1\7g\2\2\u00a1\u00a2\7z\2\2\u00a2\u00a3")
buf.write("\7v\2\2\u00a3\u00a4\7g\2\2\u00a4\u00a5\7t\2\2\u00a5\u00a6")
buf.write("\7p\2\2\u00a6,\3\2\2\2\u00a7\u00a8\7v\2\2\u00a8\u00a9")
buf.write("\7{\2\2\u00a9\u00aa\7r\2\2\u00aa\u00ab\7g\2\2\u00ab.\3")
buf.write("\2\2\2\u00ac\u00ad\7?\2\2\u00ad\u00ae\7@\2\2\u00ae\60")
buf.write("\3\2\2\2\u00af\u00b0\7o\2\2\u00b0\u00b1\7c\2\2\u00b1\u00b2")
buf.write("\7v\2\2\u00b2\u00b3\7e\2\2\u00b3\u00b4\7j\2\2\u00b4\62")
buf.write("\3\2\2\2\u00b5\u00b6\7o\2\2\u00b6\u00b7\7c\2\2\u00b7\u00b8")
buf.write("\7v\2\2\u00b8\u00b9\7e\2\2\u00b9\u00ba\7j\2\2\u00ba\u00bb")
buf.write("\7A\2\2\u00bb\64\3\2\2\2\u00bc\u00bd\7<\2\2\u00bd\66\3")
buf.write("\2\2\2\u00be\u00bf\7V\2\2\u00bf\u00c0\7g\2\2\u00c0\u00c1")
buf.write("\7p\2\2\u00c1\u00c2\7u\2\2\u00c2\u00c3\7q\2\2\u00c3\u00c4")
buf.write("\7t\2\2\u00c48\3\2\2\2\u00c5\u00c6\7o\2\2\u00c6\u00c7")
buf.write("\7g\2\2\u00c7\u00c8\7v\2\2\u00c8\u00c9\7c\2\2\u00c9:\3")
buf.write("\2\2\2\u00ca\u00cb\7x\2\2\u00cb\u00cc\7\62\2\2\u00cc\u00cd")
buf.write("\7\60\2\2\u00cd\u00ce\7\62\2\2\u00ce\u00cf\7\60\2\2\u00cf")
buf.write("\u00d0\7\66\2\2\u00d0<\3\2\2\2\u00d1\u00d2\7\61\2\2\u00d2")
buf.write("\u00d3\7,\2\2\u00d3\u00d8\3\2\2\2\u00d4\u00d7\5=\37\2")
buf.write("\u00d5\u00d7\13\2\2\2\u00d6\u00d4\3\2\2\2\u00d6\u00d5")
buf.write("\3\2\2\2\u00d7\u00da\3\2\2\2\u00d8\u00d9\3\2\2\2\u00d8")
buf.write("\u00d6\3\2\2\2\u00d9\u00db\3\2\2\2\u00da\u00d8\3\2\2\2")
buf.write("\u00db\u00dc\7,\2\2\u00dc\u00dd\7\61\2\2\u00dd\u00de\3")
buf.write("\2\2\2\u00de\u00df\b\37\2\2\u00df>\3\2\2\2\u00e0\u00e2")
buf.write("\t\2\2\2\u00e1\u00e0\3\2\2\2\u00e2\u00e3\3\2\2\2\u00e3")
buf.write("\u00e1\3\2\2\2\u00e3\u00e4\3\2\2\2\u00e4\u00e5\3\2\2\2")
buf.write("\u00e5\u00e6\b \2\2\u00e6@\3\2\2\2\u00e7\u00e8\7\61\2")
buf.write("\2\u00e8\u00e9\7\61\2\2\u00e9\u00ed\3\2\2\2\u00ea\u00ec")
buf.write("\13\2\2\2\u00eb\u00ea\3\2\2\2\u00ec\u00ef\3\2\2\2\u00ed")
buf.write("\u00ee\3\2\2\2\u00ed\u00eb\3\2\2\2\u00ee\u00f0\3\2\2\2")
buf.write("\u00ef\u00ed\3\2\2\2\u00f0\u00f1\7\f\2\2\u00f1\u00f2\3")
buf.write("\2\2\2\u00f2\u00f3\b!\2\2\u00f3B\3\2\2\2\u00f4\u00f5\7")
buf.write("^\2\2\u00f5\u00f6\7$\2\2\u00f6D\3\2\2\2\u00f7\u00fc\7")
buf.write("$\2\2\u00f8\u00fb\5C\"\2\u00f9\u00fb\n\3\2\2\u00fa\u00f8")
buf.write("\3\2\2\2\u00fa\u00f9\3\2\2\2\u00fb\u00fe\3\2\2\2\u00fc")
buf.write("\u00fd\3\2\2\2\u00fc\u00fa\3\2\2\2\u00fd\u00ff\3\2\2\2")
buf.write("\u00fe\u00fc\3\2\2\2\u00ff\u0100\7$\2\2\u0100F\3\2\2\2")
buf.write("\u0101\u0102\7,\2\2\u0102H\3\2\2\2\u0103\u0104\7\61\2")
buf.write("\2\u0104J\3\2\2\2\u0105\u0106\7-\2\2\u0106L\3\2\2\2\u0107")
buf.write("\u0108\7/\2\2\u0108N\3\2\2\2\u0109\u010a\7>\2\2\u010a")
buf.write("P\3\2\2\2\u010b\u010c\7@\2\2\u010cR\3\2\2\2\u010d\u010e")
buf.write("\7>\2\2\u010e\u010f\7?\2\2\u010fT\3\2\2\2\u0110\u0111")
buf.write("\7@\2\2\u0111\u0112\7?\2\2\u0112V\3\2\2\2\u0113\u0114")
buf.write("\7?\2\2\u0114\u0115\7?\2\2\u0115X\3\2\2\2\u0116\u0117")
buf.write("\7#\2\2\u0117\u0118\7?\2\2\u0118Z\3\2\2\2\u0119\u011a")
buf.write("\7V\2\2\u011a\u011b\7t\2\2\u011b\u011c\7w\2\2\u011c\u0123")
buf.write("\7g\2\2\u011d\u011e\7H\2\2\u011e\u011f\7c\2\2\u011f\u0120")
buf.write("\7n\2\2\u0120\u0121\7u\2\2\u0121\u0123\7g\2\2\u0122\u0119")
buf.write("\3\2\2\2\u0122\u011d\3\2\2\2\u0123\\\3\2\2\2\u0124\u0127")
buf.write("\7a\2\2\u0125\u0127\5g\64\2\u0126\u0124\3\2\2\2\u0126")
buf.write("\u0125\3\2\2\2\u0127\u012d\3\2\2\2\u0128\u012c\7a\2\2")
buf.write("\u0129\u012c\5g\64\2\u012a\u012c\5i\65\2\u012b\u0128\3")
buf.write("\2\2\2\u012b\u0129\3\2\2\2\u012b\u012a\3\2\2\2\u012c\u012f")
buf.write("\3\2\2\2\u012d\u012b\3\2\2\2\u012d\u012e\3\2\2\2\u012e")
buf.write("\u0134\3\2\2\2\u012f\u012d\3\2\2\2\u0130\u0131\7\60\2")
buf.write("\2\u0131\u0133\5]/\2\u0132\u0130\3\2\2\2\u0133\u0136\3")
buf.write("\2\2\2\u0134\u0132\3\2\2\2\u0134\u0135\3\2\2\2\u0135^")
buf.write("\3\2\2\2\u0136\u0134\3\2\2\2\u0137\u013a\5c\62\2\u0138")
buf.write("\u0139\7\60\2\2\u0139\u013b\5c\62\2\u013a\u0138\3\2\2")
buf.write("\2\u013a\u013b\3\2\2\2\u013b\u013d\3\2\2\2\u013c\u013e")
buf.write("\5e\63\2\u013d\u013c\3\2\2\2\u013d\u013e\3\2\2\2\u013e")
buf.write("`\3\2\2\2\u013f\u0140\5_\60\2\u0140\u0141\7h\2\2\u0141")
buf.write("b\3\2\2\2\u0142\u0144\5i\65\2\u0143\u0142\3\2\2\2\u0144")
buf.write("\u0145\3\2\2\2\u0145\u0143\3\2\2\2\u0145\u0146\3\2\2\2")
buf.write("\u0146d\3\2\2\2\u0147\u0149\t\4\2\2\u0148\u014a\t\5\2")
buf.write("\2\u0149\u0148\3\2\2\2\u0149\u014a\3\2\2\2\u014a\u014b")
buf.write("\3\2\2\2\u014b\u014c\5c\62\2\u014cf\3\2\2\2\u014d\u014e")
buf.write("\t\6\2\2\u014eh\3\2\2\2\u014f\u0150\t\7\2\2\u0150j\3\2")
buf.write("\2\2\u0151\u0152\7O\2\2\u0152\u0153\7G\2\2\u0153\u0154")
buf.write("\7V\2\2\u0154\u0155\7C\2\2\u0155\u0156\7F\2\2\u0156\u0157")
buf.write("\7C\2\2\u0157\u0158\7V\2\2\u0158\u0159\7C\2\2\u0159\u015a")
buf.write("\7<\2\2\u015a\u015e\3\2\2\2\u015b\u015d\13\2\2\2\u015c")
buf.write("\u015b\3\2\2\2\u015d\u0160\3\2\2\2\u015e\u015c\3\2\2\2")
buf.write("\u015e\u015f\3\2\2\2\u015fl\3\2\2\2\u0160\u015e\3\2\2")
buf.write("\2\23\2\u00d6\u00d8\u00e3\u00ed\u00fa\u00fc\u0122\u0126")
buf.write("\u012b\u012d\u0134\u013a\u013d\u0145\u0149\u015e\3\b\2")
buf.write("\2")
return buf.getvalue()
......@@ -178,62 +190,65 @@ class RelayLexer(Lexer):
T__18 = 19
T__19 = 20
T__20 = 21
SEMVER = 22
COMMENT = 23
WS = 24
LINE_COMMENT = 25
QUOTED_STRING = 26
MUL = 27
DIV = 28
ADD = 29
SUB = 30
LT = 31
GT = 32
LE = 33
GE = 34
EQ = 35
NE = 36
BOOL_LIT = 37
CNAME = 38
GLOBAL_VAR = 39
LOCAL_VAR = 40
GRAPH_VAR = 41
DATATYPE = 42
FLOAT = 43
NAT = 44
METADATA = 45
T__21 = 22
T__22 = 23
T__23 = 24
T__24 = 25
T__25 = 26
T__26 = 27
T__27 = 28
SEMVER = 29
COMMENT = 30
WS = 31
LINE_COMMENT = 32
QUOTED_STRING = 33
MUL = 34
DIV = 35
ADD = 36
SUB = 37
LT = 38
GT = 39
LE = 40
GE = 41
EQ = 42
NE = 43
BOOL_LIT = 44
CNAME = 45
FLOAT = 46
NAT = 47
METADATA = 48
channelNames = [ u"DEFAULT_TOKEN_CHANNEL", u"HIDDEN" ]
modeNames = [ "DEFAULT_MODE" ]
literalNames = [ "<INVALID>",
"','", "'('", "')'", "'{'", "'}'", "'.'", "'['", "']'", "'if'",
"'else'", "'let'", "'='", "';'", "';;'", "'fn'", "'->'", "'def'",
"':'", "'Tensor'", "'_'", "'meta'", "'v0.0.3'", "'*'", "'/'",
"'+'", "'-'", "'<'", "'>'", "'<='", "'>='", "'=='", "'!='",
"'int64'" ]
"'.'", "'@'", "'%'", "'_'", "','", "'('", "')'", "'['", "']'",
"'if'", "'else'", "'{'", "'}'", "'let'", "'='", "';'", "';;'",
"'fn'", "'->'", "'def'", "'extern'", "'type'", "'=>'", "'match'",
"'match?'", "':'", "'Tensor'", "'meta'", "'v0.0.4'", "'*'",
"'/'", "'+'", "'-'", "'<'", "'>'", "'<='", "'>='", "'=='", "'!='" ]
symbolicNames = [ "<INVALID>",
"SEMVER", "COMMENT", "WS", "LINE_COMMENT", "QUOTED_STRING",
"MUL", "DIV", "ADD", "SUB", "LT", "GT", "LE", "GE", "EQ", "NE",
"BOOL_LIT", "CNAME", "GLOBAL_VAR", "LOCAL_VAR", "GRAPH_VAR",
"DATATYPE", "FLOAT", "NAT", "METADATA" ]
"BOOL_LIT", "CNAME", "FLOAT", "NAT", "METADATA" ]
ruleNames = [ "T__0", "T__1", "T__2", "T__3", "T__4", "T__5", "T__6",
"T__7", "T__8", "T__9", "T__10", "T__11", "T__12", "T__13",
"T__14", "T__15", "T__16", "T__17", "T__18", "T__19",
"T__20", "SEMVER", "COMMENT", "WS", "LINE_COMMENT", "ESCAPED_QUOTE",
"QUOTED_STRING", "MUL", "DIV", "ADD", "SUB", "LT", "GT",
"LE", "GE", "EQ", "NE", "BOOL_LIT", "CNAME", "GLOBAL_VAR",
"LOCAL_VAR", "GRAPH_VAR", "DATATYPE", "PREFLOAT", "FLOAT",
"NAT", "EXP", "LETTER", "DIGIT", "METADATA" ]
"T__20", "T__21", "T__22", "T__23", "T__24", "T__25",
"T__26", "T__27", "SEMVER", "COMMENT", "WS", "LINE_COMMENT",
"ESCAPED_QUOTE", "QUOTED_STRING", "MUL", "DIV", "ADD",
"SUB", "LT", "GT", "LE", "GE", "EQ", "NE", "BOOL_LIT",
"CNAME", "PREFLOAT", "FLOAT", "NAT", "EXP", "LETTER",
"DIGIT", "METADATA" ]
grammarFileName = "Relay.g4"
def __init__(self, input=None, output:TextIO = sys.stdout):
super().__init__(input, output)
self.checkVersion("4.7.1")
self.checkVersion("4.7.2")
self._interp = LexerATNSimulator(self, self.atn, self.decisionsToDFA, PredictionContextCache())
self._actions = None
self._predicates = None
......
This source diff could not be displayed because it is too large. You can view the blob instead.
# Generated from /workspace/python/tvm/relay/grammar/Relay.g4 by ANTLR 4.7.1
# Generated from /Users/doobs/Code/repo/sampl/tvm/python/tvm/relay/grammar/Relay.g4 by ANTLR 4.7.2
from antlr4 import *
if __name__ is not None and "." in __name__:
from .RelayParser import RelayParser
......@@ -9,13 +9,28 @@ else:
class RelayVisitor(ParseTreeVisitor):
# Visit a parse tree produced by RelayParser#opIdent.
def visitOpIdent(self, ctx:RelayParser.OpIdentContext):
# Visit a parse tree produced by RelayParser#prog.
def visitProg(self, ctx:RelayParser.ProgContext):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#prog.
def visitProg(self, ctx:RelayParser.ProgContext):
# Visit a parse tree produced by RelayParser#generalIdent.
def visitGeneralIdent(self, ctx:RelayParser.GeneralIdentContext):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#globalVar.
def visitGlobalVar(self, ctx:RelayParser.GlobalVarContext):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#localVar.
def visitLocalVar(self, ctx:RelayParser.LocalVarContext):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#graphVar.
def visitGraphVar(self, ctx:RelayParser.GraphVarContext):
return self.visitChildren(ctx)
......@@ -44,6 +59,11 @@ class RelayVisitor(ParseTreeVisitor):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#match.
def visitMatch(self, ctx:RelayParser.MatchContext):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#tensor.
def visitTensor(self, ctx:RelayParser.TensorContext):
return self.visitChildren(ctx)
......@@ -114,8 +134,73 @@ class RelayVisitor(ParseTreeVisitor):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#defn.
def visitDefn(self, ctx:RelayParser.DefnContext):
# Visit a parse tree produced by RelayParser#funcDefn.
def visitFuncDefn(self, ctx:RelayParser.FuncDefnContext):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#externAdtDefn.
def visitExternAdtDefn(self, ctx:RelayParser.ExternAdtDefnContext):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#adtDefn.
def visitAdtDefn(self, ctx:RelayParser.AdtDefnContext):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#constructorName.
def visitConstructorName(self, ctx:RelayParser.ConstructorNameContext):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#adtConsDefnList.
def visitAdtConsDefnList(self, ctx:RelayParser.AdtConsDefnListContext):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#adtConsDefn.
def visitAdtConsDefn(self, ctx:RelayParser.AdtConsDefnContext):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#matchClauseList.
def visitMatchClauseList(self, ctx:RelayParser.MatchClauseListContext):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#matchClause.
def visitMatchClause(self, ctx:RelayParser.MatchClauseContext):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#matchType.
def visitMatchType(self, ctx:RelayParser.MatchTypeContext):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#patternList.
def visitPatternList(self, ctx:RelayParser.PatternListContext):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#pattern.
def visitPattern(self, ctx:RelayParser.PatternContext):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#adtCons.
def visitAdtCons(self, ctx:RelayParser.AdtConsContext):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#adtConsParamList.
def visitAdtConsParamList(self, ctx:RelayParser.AdtConsParamListContext):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#adtConsParam.
def visitAdtConsParam(self, ctx:RelayParser.AdtConsParamContext):
return self.visitChildren(ctx)
......@@ -149,13 +234,13 @@ class RelayVisitor(ParseTreeVisitor):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#typeParamList.
def visitTypeParamList(self, ctx:RelayParser.TypeParamListContext):
# Visit a parse tree produced by RelayParser#tupleType.
def visitTupleType(self, ctx:RelayParser.TupleTypeContext):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#tupleType.
def visitTupleType(self, ctx:RelayParser.TupleTypeContext):
# Visit a parse tree produced by RelayParser#typeCallType.
def visitTypeCallType(self, ctx:RelayParser.TypeCallTypeContext):
return self.visitChildren(ctx)
......@@ -179,8 +264,8 @@ class RelayVisitor(ParseTreeVisitor):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#intType.
def visitIntType(self, ctx:RelayParser.IntTypeContext):
# Visit a parse tree produced by RelayParser#typeParamList.
def visitTypeParamList(self, ctx:RelayParser.TypeParamListContext):
return self.visitChildren(ctx)
......@@ -209,11 +294,6 @@ class RelayVisitor(ParseTreeVisitor):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#typeIdent.
def visitTypeIdent(self, ctx:RelayParser.TypeIdentContext):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#body.
def visitBody(self, ctx:RelayParser.BodyContext):
return self.visitChildren(ctx)
......
......@@ -16,7 +16,7 @@
* specific language governing permissions and limitations
* under the License.
*/
v0.0.3
v0.0.4
def @id[a](%x: a) -> a {
%x
......
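For context, a minimal sketch of what the version bump means for users of the Python frontend. This is a hypothetical snippet (not part of the commit) that mirrors the conventions of the parser tests further below; `relay.fromtext` is the entry point those tests already exercise:

    from tvm import relay

    SEMVER = "v0.0.4"

    # Text programs must now begin with the v0.0.4 header; the old v0.0.3
    # header is no longer accepted. The prelude's identity function parses as before.
    mod = relay.fromtext(SEMVER + "\n" + "def @id[a](%x: a) -> a { %x }")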
......@@ -70,7 +70,10 @@ class AlphaEqualHandler:
}
if (lhsm->type_definitions.size() != rhsm->type_definitions.size()) return false;
for (const auto& p : lhsm->type_definitions) {
if (!Equal(p.second, rhsm->LookupDef(p.first->var->name_hint))) return false;
if (!rhsm->HasDef(p.first->var->name_hint) ||
!Equal(p.second, rhsm->LookupDef(p.first->var->name_hint))) {
return false;
}
}
return true;
}
......@@ -288,7 +291,7 @@ class AlphaEqualHandler:
}
bool VisitType_(const GlobalTypeVarNode* lhs, const Type& other) final {
return GetRef<Type>(lhs) == other;
return LeafNodeEqual(GetRef<NodeRef>(lhs), other);
}
bool VisitType_(const TypeCallNode* lhs, const Type& other) final {
......@@ -307,6 +310,26 @@ class AlphaEqualHandler:
return true;
}
bool VisitType_(const TypeDataNode* lhs, const Type& other) final {
const TypeDataNode* rhs = other.as<TypeDataNode>();
if (rhs == nullptr
|| lhs->type_vars.size() != rhs->type_vars.size()
|| lhs->constructors.size() != rhs->constructors.size()
|| !TypeEqual(lhs->header, rhs->header)) {
return false;
}
for (size_t i = 0; i < lhs->type_vars.size(); ++i) {
if (!TypeEqual(lhs->type_vars[i], rhs->type_vars[i])) {
return false;
}
}
for (size_t i = 0; i < lhs->constructors.size(); ++i) {
if (!ExprEqual(lhs->constructors[i], rhs->constructors[i])) {
return false;
}
}
return true;
}
// Expr equal checking.
bool NDArrayEqual(const runtime::NDArray& lhs,
const runtime::NDArray& rhs) {
......@@ -485,7 +508,10 @@ class AlphaEqualHandler:
}
bool VisitExpr_(const ConstructorNode* lhs, const Expr& other) final {
return GetRef<Expr>(lhs) == other;
if (const ConstructorNode* rhs = other.as<ConstructorNode>()) {
return lhs->name_hint == rhs->name_hint;
}
return false;
}
bool ClauseEqual(const Clause& lhs, const Clause& rhs) {
......@@ -582,7 +608,7 @@ TVM_REGISTER_API("relay._make._alpha_equal")
TVM_REGISTER_API("relay._make._assert_alpha_equal")
.set_body_typed<void(NodeRef, NodeRef)>([](NodeRef a, NodeRef b) {
bool alpha_equal = AlphaEqualHandler(false, true).Equal(a, b);
CHECK(alpha_equal) << AsText(a, true) << " and " << AsText(b, true) << " is not alpha equal";
CHECK(alpha_equal) << AsText(a, true) << " and " << AsText(b, true) << " are not alpha equal";
});
TVM_REGISTER_API("relay._make._graph_equal")
......@@ -593,7 +619,7 @@ TVM_REGISTER_API("relay._make._graph_equal")
TVM_REGISTER_API("relay._make._assert_graph_equal")
.set_body_typed<void(NodeRef, NodeRef)>([](NodeRef a, NodeRef b) {
bool graph_equal = AlphaEqualHandler(true, true).Equal(a, b);
CHECK(graph_equal) << AsText(a, true) << " and " << AsText(b, true) << " is not graph equal";
CHECK(graph_equal) << AsText(a, true) << " and " << AsText(b, true) << " are not graph equal";
});
} // namespace relay
......
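The net effect of the equality changes above is that ADTs survive a print/parse round trip: type definitions are looked up by name on both sides, and constructors compare by `name_hint` rather than by pointer identity. A minimal sketch of the intended behaviour, assuming the Python bindings already used by the tests below:

    from tvm import relay
    from tvm.relay.analysis import alpha_equal

    # Two distinct Constructor nodes with the same name now compare equal,
    # which is what lets a re-parsed ADT match the original definition.
    list_var = relay.GlobalTypeVar("List")
    nil_a = relay.Constructor("Nil", [], list_var)
    nil_b = relay.Constructor("Nil", [], list_var)
    assert alpha_equal(nil_a, nil_b)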
......@@ -206,6 +206,11 @@ TypeData ModuleNode::LookupDef(const std::string& name) const {
return this->LookupDef(id);
}
bool ModuleNode::HasDef(const std::string& name) const {
auto it = global_type_var_map_.find(name);
return it != global_type_var_map_.end();
}
Constructor ModuleNode::LookupTag(const int32_t tag) {
auto it = constructor_tag_map_.find(tag);
CHECK(it != constructor_tag_map_.end())
......
......@@ -44,6 +44,8 @@
namespace tvm {
namespace relay {
static const char* kSemVer = "v0.0.4";
Doc Brace(const Doc& d,
const std::string& open = "{",
const std::string& close = "}",
......@@ -239,6 +241,8 @@ class PrettyPrinter :
return PrintExpr(Downcast<Expr>(node), meta, try_inline);
} else if (node.as_derived<TypeNode>()) {
return PrintType(Downcast<Type>(node), meta);
} else if (node.as_derived<PatternNode>()) {
return PrintPattern(Downcast<Pattern>(node), meta);
} else if (node.as_derived<ModuleNode>()) {
return PrintMod(Downcast<Module>(node));
} else {
......@@ -313,7 +317,7 @@ class PrettyPrinter :
if (name.length() == 0 || !std::isalpha(name[0])) {
name = "t" + name;
}
Doc val = GetUniqueName("%" + name);
Doc val = GetUniqueName(name);
memo_type_[var] = val;
if (var->kind != kType) {
val << ": " << Print(var->kind);
......@@ -347,13 +351,17 @@ class PrettyPrinter :
}
bool IsUnique(const Expr& expr) {
return !(dg_.expr_node.at(expr)->parents.head &&
dg_.expr_node.at(expr)->parents.head->next);
auto it = dg_.expr_node.find(expr);
if (it == dg_.expr_node.end()) {
return true;
} else {
return !(it->second->parents.head && it->second->parents.head->next);
}
}
bool AlwaysInline(const Expr& expr) {
return expr.as<GlobalVarNode>() || expr.as<ConstantNode>() ||
expr.as<OpNode>() || expr.as<VarNode>();
return expr.as<GlobalVarNode>() || expr.as<ConstantNode>() || expr.as<OpNode>() ||
expr.as<VarNode>() || expr.as<ConstructorNode>();
}
//------------------------------------
......@@ -380,9 +388,9 @@ class PrettyPrinter :
} else if (!inline_expr && expr.as<LetNode>()) {
// wrap GNFed let in brackets
Doc body;
printed_expr << "{";
printed_expr << "(";
printed_expr << Indent(2, body << PrintNewLine() << VisitExpr(expr)) << PrintNewLine();
printed_expr << "}";
printed_expr << ")";
} else {
printed_expr = VisitExpr(expr);
}
......@@ -483,13 +491,13 @@ class PrettyPrinter :
Doc doc;
doc << prefix;
if (fn->type_params.size() > 0) {
doc << "<";
doc << "[";
std::vector<Doc> type_params;
for (const TypeVar& tv : fn->type_params) {
type_params.push_back(AllocTypeVar(tv));
type_params.push_back(Doc(tv->var->name_hint));
}
doc << PrintSep(type_params);
doc << ">";
doc << "]";
}
doc << "(";
std::vector<Doc> params;
......@@ -510,6 +518,15 @@ class PrettyPrinter :
Doc PrintMod(const Module& mod) {
Doc doc;
int counter = 0;
// type definitions
for (const auto& kv : mod->type_definitions) {
if (counter++ != 0) {
doc << PrintNewLine();
}
doc << Print(kv.second);
doc << PrintNewLine();
}
// functions
for (const auto& kv : mod->functions) {
dg_ = DependencyGraph::Create(&arena_, kv.second);
......@@ -547,7 +564,12 @@ class PrettyPrinter :
for (const Doc& d : PrintCallAttrs(op->attrs, op->op)) {
args.push_back(d);
}
doc << Print(op->op);
const auto* cons_node = op->op.as<ConstructorNode>();
if (cons_node) {
doc << cons_node->name_hint;
} else {
doc << Print(op->op);
}
return doc << "(" << PrintSep(args) << ")";
}
......@@ -570,27 +592,57 @@ class PrettyPrinter :
// TODO(jmp): Lots of code duplication here because PrintBody and PrintScope don't accept Docs.
Doc doc;
Doc body;
doc << "match " << Print(op->data) << " ";
doc << "{";
std::vector<Doc> clauses;
doc << "match";
if (!op->complete) {
doc << "?";
}
doc << " (" << Print(op->data) << ") {";
std::vector<Doc> clause_docs;
for (const auto& clause : op->clauses) {
Doc clause_doc;
clauses.push_back(clause_doc << Print(clause->lhs) << " -> "
<< Print(clause->rhs));
clause_doc << PrintPattern(clause->lhs, false) << " => ";
Doc rhs_doc = PrintScope(clause->rhs);
if (clause->rhs.as<LetNode>()) {
// only add braces if there are multiple lines on the rhs
rhs_doc = Brace(rhs_doc);
}
clause_doc << rhs_doc << ",";
clause_docs.push_back(clause_doc);
}
doc << Indent(2, body << PrintNewLine() << PrintSep(clauses, PrintNewLine())) << PrintNewLine();
doc << "}";
doc << Indent(2, body << PrintNewLine() << PrintSep(clause_docs, PrintNewLine()))
<< PrintNewLine() << "}";
return doc;
}
Doc PrintPattern(const Pattern& pattern, bool meta) {
auto it = memo_pattern_.find(pattern);
if (it != memo_pattern_.end()) return it->second;
Doc printed_pattern;
if (meta) {
printed_pattern = meta_.GetMetaNode(GetRef<NodeRef>(pattern.get()));
} else {
printed_pattern = VisitPattern(pattern);
}
memo_pattern_[pattern] = printed_pattern;
return printed_pattern;
}
Doc VisitPattern_(const PatternConstructorNode* p) final {
Doc doc;
doc << p->constructor->name_hint << "(";
std::vector<Doc> pats;
for (const auto& pat : p->patterns) {
pats.push_back(Print(pat));
doc << p->constructor->name_hint;
if (!p->patterns.empty()) {
doc << "(";
std::vector<Doc> pats;
for (const auto& pat : p->patterns) {
pats.push_back(Print(pat));
}
doc << PrintSep(pats) << ")";
}
return doc << PrintSep(pats) << ")";
return doc;
}
Doc VisitPattern_(const PatternWildcardNode* pw) final {
return Doc("_");
}
Doc VisitPattern_(const PatternVarNode* pv) final {
......@@ -598,7 +650,17 @@ class PrettyPrinter :
}
Doc VisitExpr_(const ConstructorNode* n) final {
return Doc(n->name_hint);
Doc doc;
doc << n->name_hint;
if (n->inputs.size() != 0) {
doc << "(";
std::vector<Doc> inputs;
for (Type input : n->inputs) {
inputs.push_back(Print(input));
}
doc << PrintSep(inputs) << ")";
}
return doc;
}
//------------------------------------
......@@ -623,7 +685,7 @@ class PrettyPrinter :
}
Doc VisitType_(const TypeVarNode* node) final {
return AllocTypeVar(GetRef<TypeVar>(node));
return Doc(node->var->name_hint);
}
Doc VisitType_(const GlobalTypeVarNode* node) final {
......@@ -675,13 +737,13 @@ class PrettyPrinter :
Doc doc;
doc << "fn ";
if (node->type_params.size() != 0) {
doc << "<";
doc << "[";
std::vector<Doc> type_params;
for (Type type_param : node->type_params) {
type_params.push_back(Print(type_param));
}
doc << PrintSep(type_params);
doc << ">";
doc << "]";
}
std::vector<Doc> arg_types;
for (Type arg_type : node->arg_types) {
......@@ -695,6 +757,37 @@ class PrettyPrinter :
return doc << "ref(" << Print(node->value) << ")";
}
Doc VisitType_(const TypeDataNode* node) final {
Doc doc;
doc << "type " << Print(node->header);
// type vars
if (node->type_vars.size() != 0) {
doc << "[";
std::vector<Doc> type_vars;
for (Type type_var : node->type_vars) {
type_vars.push_back(Print(type_var));
}
doc << PrintSep(type_vars) << "]";
}
doc << " ";
std::vector<Doc> constructor_docs;
for (Constructor constructor : node->constructors) {
constructor_docs.push_back(Print(constructor, /* meta */ false, /* try_inline */ true));
}
Doc separator;
separator << "," << PrintNewLine();
Doc adt_body;
adt_body << PrintSep(constructor_docs, separator);
// add trailing comma if there are any constructors
if (!constructor_docs.empty()) {
adt_body << ",";
}
doc << Brace(adt_body);
return doc;
}
//------------------------------------
// Overload of Attr printing functions
//------------------------------------
......@@ -758,6 +851,8 @@ class PrettyPrinter :
std::unordered_map<Expr, Doc, NodeHash, NodeEqual> memo_;
/*! \brief Map from Type to Doc */
std::unordered_map<Type, Doc, NodeHash, NodeEqual> memo_type_;
/*! \brief Map from Pattern to Doc */
std::unordered_map<Pattern, Doc, NodeHash, NodeEqual> memo_pattern_;
/*! \brief name allocation map */
std::unordered_map<std::string, int> name_alloc_map_;
/*! \brief meta data context */
......@@ -861,7 +956,7 @@ std::string PrettyPrint_(const NodeRef& node,
bool show_meta_data,
runtime::TypedPackedFunc<std::string(Expr)> annotate) {
Doc doc;
doc << "v0.0.3" << PrintNewLine()
doc << kSemVer << PrintNewLine()
<< PrettyPrinter(show_meta_data, annotate).PrintFinal(node);
return doc.str();
}
......
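Taken together, the printer changes emit ADT definitions, constructor calls, and `match`/`match?` expressions in the same Rust-style syntax the parser accepts, so modules containing ADTs round-trip through text. A minimal sketch of that round trip, hypothetical but modelled on the List definition used in the tests below:

    from tvm import relay

    # Build a module holding a List ADT, print it, and parse the text back.
    mod = relay.Module()
    list_var = relay.GlobalTypeVar("List")
    a = relay.TypeVar("A")
    mod[list_var] = relay.TypeData(
        list_var, [a],
        [relay.Constructor("Cons", [a, list_var(a)], list_var),
         relay.Constructor("Nil", [], list_var)])

    text = mod.astext()  # begins with the "v0.0.4" header and contains "type List[A]"
    reparsed = relay.fromtext(text)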
......@@ -774,7 +774,6 @@ class TypeInferencer::Resolver : public ExprMutator, PatternMutator {
bool update_missing_type_annotation_{true};
};
Expr TypeInferencer::Infer(Expr expr) {
// Step 1: Populate the constraints.
GetType(expr);
......
......@@ -16,14 +16,14 @@
# under the License.
import tvm
from tvm import relay
from tvm.relay.analysis import alpha_equal, assert_alpha_equal
from tvm.relay.analysis import graph_equal, assert_graph_equal
from nose.tools import nottest, raises
from numpy import isclose
from typing import Union
from functools import wraps
raises_parse_error = raises(tvm._ffi.base.TVMError)
SEMVER = "v0.0.3"
SEMVER = "v0.0.4"
BINARY_OPS = {
"*": relay.multiply,
......@@ -60,20 +60,29 @@ TYPES = {
"float16x4",
}
LIST_DEFN = """
type List[A] {
Cons(A, List[A]),
Nil,
}
"""
def roundtrip(expr):
x = relay.fromtext(str(expr))
assert_alpha_equal(x, expr)
assert_graph_equal(x, expr)
def parse_text(code):
x = relay.fromtext(SEMVER + "\n" + code)
roundtrip(x)
return x
expr = relay.fromtext(SEMVER + "\n" + code)
roundtrip(expr)
return expr
def parses_as(code, expr):
# type: (str, relay.Expr) -> bool
return alpha_equal(parse_text(code), expr)
parsed = parse_text(code)
result = graph_equal(parsed, expr)
return result
def get_scalar(x):
# type: (relay.Constant) -> (Union[float, int, bool])
......@@ -168,13 +177,13 @@ def test_bin_op():
def test_parens():
assert alpha_equal(parse_text("1 * 1 + 1"), parse_text("(1 * 1) + 1"))
assert not alpha_equal(parse_text("1 * 1 + 1"), parse_text("1 * (1 + 1)"))
assert graph_equal(parse_text("1 * 1 + 1"), parse_text("(1 * 1) + 1"))
assert not graph_equal(parse_text("1 * 1 + 1"), parse_text("1 * (1 + 1)"))
def test_op_assoc():
assert alpha_equal(parse_text("1 * 1 + 1 < 1 == 1"), parse_text("(((1 * 1) + 1) < 1) == 1"))
assert alpha_equal(parse_text("1 == 1 < 1 + 1 * 1"), parse_text("1 == (1 < (1 + (1 * 1)))"))
assert graph_equal(parse_text("1 * 1 + 1 < 1 == 1"), parse_text("(((1 * 1) + 1) < 1) == 1"))
assert graph_equal(parse_text("1 == 1 < 1 + 1 * 1"), parse_text("1 == (1 < (1 + (1 * 1)))"))
@nottest
......@@ -239,7 +248,7 @@ def test_seq():
)
assert parses_as(
"let %_ = { 1 }; ()",
"let %_ = 1; ()",
relay.Let(
X,
relay.const(1),
......@@ -249,13 +258,13 @@ def test_seq():
def test_graph():
code = "%0 = (); %1 = 1; (%0, %0, %1)"
assert parses_as(
"%0 = (); %1 = 1; (%0, %0, %1)",
code,
relay.Tuple([UNIT, UNIT, relay.const(1)])
)
assert not parses_as(
"%0 = (); %1 = 1; (%0, %0, %1)",
code,
relay.Tuple([relay.Tuple([]), relay.Tuple([]), relay.const(1)])
)
......@@ -632,6 +641,236 @@ def test_tuple_type():
)
)
def test_adt_defn():
mod = relay.Module()
glob_typ_var = relay.GlobalTypeVar("Ayy")
prog = relay.TypeData(
glob_typ_var,
[],
[relay.Constructor("Nil", [], glob_typ_var)])
mod[glob_typ_var] = prog
assert parses_as(
"""
type Ayy { Nil }
""",
mod
)
def test_empty_adt_defn():
mod = relay.Module()
glob_typ_var = relay.GlobalTypeVar("Ayy")
prog = relay.TypeData(glob_typ_var, [], [])
mod[glob_typ_var] = prog
assert parses_as(
"""
type Ayy { }
""",
mod
)
def test_multiple_cons_defn():
mod = relay.Module()
list_var = relay.GlobalTypeVar("List")
typ_var = relay.TypeVar("A")
prog = relay.TypeData(
list_var,
[typ_var],
[
relay.Constructor("Cons", [typ_var, list_var(typ_var)], list_var),
relay.Constructor("Nil", [], list_var),
])
mod[list_var] = prog
assert parses_as(LIST_DEFN, mod)
def test_multiple_type_param_defn():
glob_typ_var = relay.GlobalTypeVar("Either")
typ_var_a = relay.TypeVar("A")
typ_var_b = relay.TypeVar("B")
prog = relay.TypeData(
glob_typ_var,
[typ_var_a, typ_var_b],
[
relay.Constructor("Left", [typ_var_a], glob_typ_var),
relay.Constructor("Right", [typ_var_b], glob_typ_var),
])
mod = relay.Module()
mod[glob_typ_var] = prog
assert parses_as(
"""
type Either[A, B] {
Left(A),
Right(B),
}
""",
mod
)
def test_match():
# pair each match keyword with whether it specifies a complete match or not
match_keywords = [("match", True), ("match?", False)]
for (match_keyword, is_complete) in match_keywords:
mod = relay.Module()
list_var = relay.GlobalTypeVar("List")
typ_var = relay.TypeVar("A")
cons_constructor = relay.Constructor(
"Cons", [typ_var, list_var(typ_var)], list_var)
nil_constructor = relay.Constructor("Nil", [], list_var)
list_def = relay.TypeData(
list_var,
[typ_var],
[cons_constructor, nil_constructor])
mod[list_var] = list_def
length_var = relay.GlobalVar("length")
typ_var = relay.TypeVar("A")
input_type = list_var(typ_var)
input_var = relay.Var("xs", input_type)
rest_var = relay.Var("rest")
cons_case = relay.Let(
_,
UNIT,
relay.add(relay.const(1), relay.Call(length_var, [rest_var])))
body = relay.Match(input_var,
[relay.Clause(
relay.PatternConstructor(
cons_constructor,
[relay.PatternWildcard(), relay.PatternVar(rest_var)]),
cons_case),
relay.Clause(
relay.PatternConstructor(nil_constructor, []),
relay.const(0))],
complete=is_complete
)
length_func = relay.Function(
[input_var],
body,
int32,
[typ_var]
)
mod[length_var] = length_func
assert parses_as(
"""
%s
def @length[A](%%xs: List[A]) -> int32 {
%s (%%xs) {
Cons(_, %%rest) => {
();;
1 + @length(%%rest)
},
Nil => 0,
}
}
""" % (LIST_DEFN, match_keyword),
mod
)
def test_adt_cons_expr():
mod = relay.Module()
list_var = relay.GlobalTypeVar("List")
typ_var = relay.TypeVar("A")
cons_constructor = relay.Constructor(
"Cons", [typ_var, list_var(typ_var)], list_var)
nil_constructor = relay.Constructor("Nil", [], list_var)
list_def = relay.TypeData(
list_var,
[typ_var],
[cons_constructor, nil_constructor])
mod[list_var] = list_def
make_singleton_var = relay.GlobalVar("make_singleton")
input_var = relay.Var("x", int32)
make_singleton_func = relay.Function(
[input_var],
cons_constructor(input_var, nil_constructor()),
list_var(int32)
)
mod[make_singleton_var] = make_singleton_func
assert parses_as(
"""
%s
def @make_singleton(%%x: int32) -> List[int32] {
Cons(%%x, Nil())
}
""" % LIST_DEFN,
mod
)
@raises_parse_error
def test_duplicate_adt_defn():
parse_text(
"""
%s
type List[A] {
Cons(A, List[A]),
Nil,
}
""" % LIST_DEFN
)
@raises_parse_error
def test_duplicate_adt_cons():
parse_text(
"""
type Ayy { Lmao }
type Haha { Lmao }
"""
)
@raises_parse_error
def test_duplicate_adt_cons_defn():
parse_text(
"""
type Ayy { Lmao }
type Lmao { Ayy }
"""
)
@raises_parse_error
def test_duplicate_global_var():
parse_text(
"""
def @id[A](%x: A) -> A { x }
def @id[A](%x: A) -> A { x }
"""
)
def test_extern_adt_defn():
# TODO(weberlo): update this test once extern is implemented
mod = relay.Module()
extern_var = relay.GlobalTypeVar("T")
typ_var = relay.TypeVar("A")
extern_def = relay.TypeData(extern_var, [typ_var], [])
mod[extern_var] = extern_def
assert parses_as(
"""
extern type T[A]
""",
mod
)
if __name__ == "__main__":
test_comments()
test_int_literal()
......@@ -655,3 +894,14 @@ if __name__ == "__main__":
test_tensor_type()
test_function_type()
test_tuple_type()
test_adt_defn()
test_empty_adt_defn()
test_multiple_cons_defn()
test_multiple_type_param_defn()
test_match()
test_adt_cons_expr()
test_duplicate_adt_defn()
test_duplicate_adt_cons()
test_duplicate_adt_cons_defn()
test_duplicate_global_var()
test_extern_adt_defn()
......@@ -23,14 +23,14 @@ from tvm.relay.analysis import alpha_equal, assert_alpha_equal, assert_graph_equ
do_print = [False]
SEMVER = "v0.0.3\n"
SEMVER = "v0.0.4\n"
def astext(p, graph_equal=False):
def astext(p, unify_free_vars=False):
txt = p.astext()
if isinstance(p, Expr) and free_vars(p):
return txt
x = relay.fromtext(txt)
if graph_equal:
if unify_free_vars:
assert_graph_equal(x, p)
else:
assert_alpha_equal(x, p)
......@@ -78,7 +78,7 @@ def test_meta_data():
padding=(1, 1),
channels=2)
f = relay.Function([x, w], z)
text = astext(f, graph_equal=True)
text = astext(f, unify_free_vars=True)
text_no_meta = str(f)
assert "channels=2" in text
assert "channels=2" in text_no_meta
......@@ -122,7 +122,7 @@ def test_let_if_scope():
f = relay.Function([x, y, cond], result)
text = astext(f)
assert text.count("{") == 4
assert text.count("{") == 3
assert "%cond: bool" in text
show(astext(f))
......@@ -218,14 +218,6 @@ def test_zeros():
x = relay.op.zeros([], "float32")
astext(x)
def test_cast():
data = relay.var('data', dtype='float32')
fp16_cast = relay.cast(data, dtype='float16')
cast_func = relay.Function(relay.analysis.free_vars(fp16_cast), fp16_cast)
astext(cast_func)
if __name__ == "__main__":
do_print[0] = True
test_lstm()
......@@ -247,4 +239,3 @@ if __name__ == "__main__":
test_let_if_scope()
test_variable_name()
test_call_node_order()
test_cast()