Commit ca0292d8 authored by Logan Weber, committed by Jared Roesch

[Relay] Add ADTs to text format (#3863)

* Getting closer to having ADT defs

* ADT defs working, probably

* Match parsing basically done

* came to earth in a silver chrome UFO

* match finished?

* All tests but newest are passing

* ADT constructors work

now cleanup?

* Cleanup round 1

* Cleanup round 2

* Cleanup round 3

* Cleanup round 4

* Cleanup round 6

* Cleanup round 7

* Lil grammar fix

* Remove ANTLR Java files

* Lint roller

* Lint roller

* Address feedback

* Test completeness in match test

* Remove unused imports

* Lint roller

* Switch to Rust-style ADT syntax

* Lil fix

* Add dummy `extern type` handler

* Add type arg to test

* Update prelude semantic version

* Repair test

* Fix graph var handling in match

* Revert 's/graph_equal/is_unifiable' change
parent a103c4ee
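The headline change is Rust-style ADT definitions and match expressions in the text format, along with a semantic-version bump to v0.0.4 and `def` now taking an @-prefixed global name. A minimal sketch of the new syntax end to end, assuming the `relay.fromtext` entry point edited later in this diff (the names List, Cons, Nil, and @sum are made up for illustration):

    from tvm import relay

    program = """
    v0.0.4
    type List[A] {
      Cons(A, List[A]),
      Nil,
    }
    def @sum(%xs: List[int32]) -> int32 {
      match (%xs) {
        Cons(%hd, %tl) => add(%hd, @sum(%tl)),
        Nil => 0,
      }
    }
    """
    mod = relay.fromtext(program)  # module should now contain the List ADT and @sum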
...@@ -165,6 +165,13 @@ class ModuleNode : public RelayNode { ...@@ -165,6 +165,13 @@ class ModuleNode : public RelayNode {
TVM_DLL TypeData LookupDef(const std::string& var) const; TVM_DLL TypeData LookupDef(const std::string& var) const;
/*! /*!
* \brief Check if a global type definition exists
* \param var The name of the global type definition.
* \return Whether the definition exists.
*/
TVM_DLL bool HasDef(const std::string& var) const;
/*!
* \brief Look up a constructor by its tag. * \brief Look up a constructor by its tag.
* \param tag The tag for the constructor. * \param tag The tag for the constructor.
* \return The constructor object. * \return The constructor object.
......
...@@ -21,13 +21,14 @@ from __future__ import absolute_import ...@@ -21,13 +21,14 @@ from __future__ import absolute_import
import sys import sys
from ast import literal_eval from ast import literal_eval
from typing import Any, Deque, Dict, List, Optional, TypeVar, Tuple, Union
from collections import deque from collections import deque
import tvm import tvm
from . import module from . import module
from .base import Span, SourceName from .base import Span, SourceName
from . import adt
from . import expr from . import expr
from . import ty from . import ty
from . import op from . import op
...@@ -53,8 +54,7 @@ sys.setrecursionlimit(10000) ...@@ -53,8 +54,7 @@ sys.setrecursionlimit(10000)
class ParseError(Exception): class ParseError(Exception):
"""Exception type for parse errors.""" """Exception type for parse errors."""
def __init__(self, message): def __init__(self, message: str) -> None:
# type: (str) -> None
super(ParseError, self).__init__() super(ParseError, self).__init__()
self.message = message self.message = message
...@@ -143,12 +143,11 @@ TYPE_PREFIXES = [ ...@@ -143,12 +143,11 @@ TYPE_PREFIXES = [
"bool", "bool",
] ]
T = ty.TypeVar("T") T = TypeVar("T")
# Scope = Deque[Tuple[str, T]] Scope = Deque[Tuple[str, T]]
# Scopes = Deque[Scope[T]] Scopes = Deque[Scope[T]]
def lookup(scopes, name): def lookup(scopes: Scopes[T], name: str) -> Optional[T]:
# type: (Scopes[T], str) -> Optional[T]
"""Look up `name` in `scopes`.""" """Look up `name` in `scopes`."""
for scope in scopes: for scope in scopes:
...@@ -185,95 +184,92 @@ def spanify(f): ...@@ -185,95 +184,92 @@ def spanify(f):
class ParseTreeToRelayIR(RelayVisitor): class ParseTreeToRelayIR(RelayVisitor):
"""Parse Relay text format into Relay IR.""" """Parse Relay text format into Relay IR."""
def __init__(self, source_name): def __init__(self, source_name: str) -> None:
# type: (str) -> None
self.source_name = source_name self.source_name = source_name
self.module = module.Module({}) # type: module.Module self.module = module.Module({}) # type: module.Module
# Adding an empty scope allows naked lets without pain. # Adding an empty scope allows naked lets without pain.
self.var_scopes = deque([deque()]) # type: Scopes[expr.Var] self.var_scopes = deque([deque()]) # type: Scopes[expr.Var]
self.global_var_scope = deque() # type: Scope[expr.GlobalVar] self.global_vars = {} # type: Scope[expr.GlobalVar]
self.type_param_scopes = deque([deque()]) # type: Scopes[ty.TypeVar] self.type_var_scopes = deque([deque()]) # type: Scopes[ty.TypeVar]
self.global_type_vars = {} # type: Scope[expr.GlobalVar]
self.graph_expr = [] # type: List[expr.Expr] self.graph_expr = [] # type: List[expr.Expr]
super(ParseTreeToRelayIR, self).__init__() super(ParseTreeToRelayIR, self).__init__()
def enter_var_scope(self): def enter_var_scope(self) -> None:
# type: () -> None
"""Enter a new Var scope so it can be popped off later.""" """Enter a new Var scope so it can be popped off later."""
self.var_scopes.appendleft(deque()) self.var_scopes.appendleft(deque())
def exit_var_scope(self): def exit_var_scope(self) -> Scope[expr.Var]:
# type: () -> Scope[expr.Var]
"""Pop off the current Var scope and return it.""" """Pop off the current Var scope and return it."""
return self.var_scopes.popleft() return self.var_scopes.popleft()
def mk_var(self, name, type_): def mk_var(self, name: str, typ: ty.Type = None):
# type: (str, ty.Type) -> expr.Var
"""Create a new Var and add it to the Var scope.""" """Create a new Var and add it to the Var scope."""
var = expr.Var(name, typ)
var = expr.Var(name, type_)
self.var_scopes[0].appendleft((name, var)) self.var_scopes[0].appendleft((name, var))
return var return var
def mk_global_var(self, name): def mk_global_var(self, name: str) -> expr.GlobalVar:
# type: (str) -> expr.GlobalVar
"""Create a new GlobalVar and add it to the GlobalVar scope.""" """Create a new GlobalVar and add it to the GlobalVar scope."""
if name in self.global_vars:
raise ParseError(f"duplicate global var \"{name}\"")
var = expr.GlobalVar(name) var = expr.GlobalVar(name)
self.global_var_scope.append((name, var)) self.global_vars[name] = var
return var return var
def enter_type_param_scope(self): def enter_type_param_scope(self) -> None:
# type: () -> None
"""Enter a new TypeVar scope so it can be popped off later.""" """Enter a new TypeVar scope so it can be popped off later."""
self.type_var_scopes.appendleft(deque())
self.type_param_scopes.appendleft(deque()) def exit_type_param_scope(self) -> Scope[ty.TypeVar]:
def exit_type_param_scope(self):
# type: () -> Scope[ty.TypeVar]
"""Pop off the current TypeVar scope and return it.""" """Pop off the current TypeVar scope and return it."""
return self.type_var_scopes.popleft()
return self.type_param_scopes.popleft() def mk_typ(self, name: str, kind: ty.Kind) -> ty.TypeVar:
def mk_typ(self, name, kind):
# (str, ty.Kind) -> ty.TypeVar
"""Create a new TypeVar and add it to the TypeVar scope.""" """Create a new TypeVar and add it to the TypeVar scope."""
typ = ty.TypeVar(name, kind) typ = ty.TypeVar(name, kind)
self.type_param_scopes[0].appendleft((name, typ)) self.type_var_scopes[0].appendleft((name, typ))
return typ
def mk_global_typ_var(self, name, kind):
# (str, ty.Kind) -> ty.GlobalTypeVar
"""Create a new TypeVar and add it to the TypeVar scope."""
typ = ty.GlobalTypeVar(name, kind)
self._check_existing_typ_expr(name, typ)
self.global_type_vars[name] = typ
return typ return typ
# TODO: rethink whether we should have type constructors mixed with type vars.
def mk_global_typ_cons(self, name, cons):
self._check_existing_typ_expr(name, cons)
self.global_type_vars[name] = cons
def _check_existing_typ_expr(self, name, new_expr):
if name in self.global_type_vars:
new_typ_name = self._type_expr_name(new_expr)
existing_typ_name = self._type_expr_name(self.global_type_vars[name])
raise ParseError(
f"{new_typ_name} `{name}` conflicts with existing {existing_typ_name}")
def _type_expr_name(self, e):
if isinstance(e, adt.Constructor):
return f"`{e.belong_to.var.name}` ADT constructor"
elif isinstance(e, ty.GlobalTypeVar):
if e.kind == ty.Kind.AdtHandle:
return f"ADT definition"
return "function definition"
def visitProjection(self, ctx): def visitProjection(self, ctx):
return expr.TupleGetItem(self.visit(ctx.expr()), self.visit(ctx.NAT())) return expr.TupleGetItem(self.visit(ctx.expr()), self.visit(ctx.NAT()))
def visitTerminal(self, node): def visitTerminal(self, node) -> Union[expr.Expr, int, float]:
# type: (TerminalNode) -> Union[expr.Expr, int, float]
"""Visit lexer tokens that aren't ignored or visited by other functions.""" """Visit lexer tokens that aren't ignored or visited by other functions."""
node_type = node.getSymbol().type node_type = node.getSymbol().type
node_text = node.getText() node_text = node.getText()
name = node_text[1:]
# variables
if node_type == RelayLexer.GLOBAL_VAR:
return lookup(deque([self.global_var_scope]), node_text[1:])
if node_type == RelayLexer.LOCAL_VAR:
# Remove the leading '%' and lookup the name.
var = lookup(self.var_scopes, name)
if var is None:
raise ParseError("Couldn't resolve `{}`.".format(name))
return var
if node_type == RelayLexer.GRAPH_VAR:
try:
return self.graph_expr[int(name)]
except IndexError:
raise ParseError("Couldn't resolve `{}`".format(name))
# data types
if node_type == RelayLexer.NAT: if node_type == RelayLexer.NAT:
return int(node_text) return int(node_text)
if node_type == RelayLexer.FLOAT: if node_type == RelayLexer.FLOAT:
...@@ -283,35 +279,67 @@ class ParseTreeToRelayIR(RelayVisitor): ...@@ -283,35 +279,67 @@ class ParseTreeToRelayIR(RelayVisitor):
return True return True
if node_text == "False": if node_text == "False":
return False return False
raise ParseError("Unrecognized BOOL_LIT: `{}`".format(node_text)) raise ParseError("unrecognized BOOL_LIT: `{}`".format(node_text))
if node_type == RelayLexer.QUOTED_STRING: if node_type == RelayLexer.QUOTED_STRING:
return literal_eval(node_text) return literal_eval(node_text)
raise ParseError(f"unhandled terminal \"{node_text}\" of type `{node_type}`")
def visitGeneralIdent(self, ctx):
name = ctx.getText()
# Look through all type prefixes for a match.
for type_prefix in TYPE_PREFIXES:
if name.startswith(type_prefix):
return ty.scalar_type(name)
# Next, look it up in the local then global type params.
type_param = lookup(self.type_var_scopes, name)
if type_param is None:
type_param = self.global_type_vars.get(name, None)
if type_param is not None:
return type_param
# Check if it's an operator.
op_name = ".".join([name.getText() for name in ctx.CNAME()])
if op_name in FUNC_OPS:
return FuncOp(FUNC_OPS[op_name])
return ExprOp(op.get(op_name))
raise ParseError("todo: `{}`".format(node_text)) def visitGlobalVar(self, ctx):
var_name = ctx.CNAME().getText()
global_var = self.global_vars.get(var_name, None)
if global_var is None:
raise ParseError(f"unbound global var `{var_name}`")
return global_var
def visit_list(self, ctx_list): def visitLocalVar(self, ctx):
# type: (List[ParserRuleContext]) -> List[Any] var_name = ctx.CNAME().getText()
local_var = lookup(self.var_scopes, var_name)
if local_var is None:
raise ParseError(f"unbound local var `{var_name}`")
return local_var
def visitGraphVar(self, ctx):
return self.graph_expr[int(ctx.NAT().getText())]
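Local, global, and graph variables now each get a dedicated visitor instead of being handled inside visitTerminal, and unresolved names raise ParseError ("unbound local var" / "unbound global var"). A small sketch of graph bindings in program text; sequential numbering is enforced further down in visitGraph:

    from tvm import relay

    program = """
    v0.0.4
    %0 = 1;
    %1 = add(%0, %0);
    add(%1, %1)
    """
    parsed = relay.fromtext(program)  # %0 and %1 refer back to shared expression nodes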
def visit_list(self, ctx_list) -> List[Any]:
""""Visit a list of contexts.""" """"Visit a list of contexts."""
# type: RelayParser.ContextParserRuleContext
assert isinstance(ctx_list, list) assert isinstance(ctx_list, list)
return [self.visit(ctx) for ctx in ctx_list] return [self.visit(ctx) for ctx in ctx_list]
def getType_(self, ctx): def getTypeExpr(self, ctx) -> Optional[ty.Type]:
# type: (Optional[RelayParser.Type_Context]) -> Optional[ty.Type]
"""Return a (possibly None) Relay type.""" """Return a (possibly None) Relay type."""
# type: : Optional[RelayParser.Type_Context]
if ctx is None: if ctx is None:
return None return None
return self.visit(ctx) return self.visit(ctx)
def visitProg(self, ctx): def visitProg(self, ctx: RelayParser.ProgContext) -> Union[expr.Expr, module.Module]:
self.meta = None self.meta = None
if ctx.METADATA(): if ctx.METADATA():
header, data = str(ctx.METADATA()).split('\n', 1) header, data = str(ctx.METADATA()).split("\n", 1)
assert header == "METADATA:" assert header == "METADATA:"
self.meta = tvm.load_json(data) self.meta = tvm.load_json(data)
# type: (RelayParser.ProgContext) -> Union[expr.Expr, module.Module]
if ctx.defn(): if ctx.defn():
self.visit_list(ctx.defn()) self.visit_list(ctx.defn())
return self.module return self.module
...@@ -322,37 +350,30 @@ class ParseTreeToRelayIR(RelayVisitor): ...@@ -322,37 +350,30 @@ class ParseTreeToRelayIR(RelayVisitor):
return self.module return self.module
# Exprs # Exprs
def visitOpIdent(self, ctx): def visitOpIdent(self, ctx) -> op.Op:
# type: (RelayParser.OpIdentContext) -> op.Op op_name = ".".join([name.getText() for name in ctx.CNAME()])
op_name = ctx.CNAME().getText()
if op_name in FUNC_OPS: if op_name in FUNC_OPS:
return FuncOp(FUNC_OPS[op_name]) return FuncOp(FUNC_OPS[op_name])
return ExprOp(op.get(op_name)) return ExprOp(op.get(op_name))
# pass through # pass through
def visitParen(self, ctx): def visitParen(self, ctx: RelayParser.ParenContext) -> expr.Expr:
# type: (RelayParser.ParenContext) -> expr.Expr
return self.visit(ctx.expr()) return self.visit(ctx.expr())
# pass through # pass through
def visitBody(self, ctx): def visitBody(self, ctx: RelayParser.BodyContext) -> expr.Expr:
# type: (RelayParser.BodyContext) -> expr.Expr
return self.visit(ctx.expr()) return self.visit(ctx.expr())
def visitScalarFloat(self, ctx): def visitScalarFloat(self, ctx: RelayParser.ScalarFloatContext) -> expr.Constant:
# type: (RelayParser.ScalarFloatContext) -> expr.Constant
return expr.const(self.visit(ctx.FLOAT())) return expr.const(self.visit(ctx.FLOAT()))
def visitScalarInt(self, ctx): def visitScalarInt(self, ctx: RelayParser.ScalarIntContext) -> expr.Constant:
# type: (RelayParser.ScalarIntContext) -> expr.Constant
return expr.const(self.visit(ctx.NAT())) return expr.const(self.visit(ctx.NAT()))
def visitScalarBool(self, ctx): def visitScalarBool(self, ctx: RelayParser.ScalarBoolContext) -> expr.Constant:
# type: (RelayParser.ScalarBoolContext) -> expr.Constant
return expr.const(self.visit(ctx.BOOL_LIT())) return expr.const(self.visit(ctx.BOOL_LIT()))
def visitNeg(self, ctx): def visitNeg(self, ctx: RelayParser.NegContext) -> Union[expr.Constant, expr.Call]:
# type: (RelayParser.NegContext) -> Union[expr.Constant, expr.Call]
val = self.visit(ctx.expr()) val = self.visit(ctx.expr())
if isinstance(val, expr.Constant) and val.data.asnumpy().ndim == 0: if isinstance(val, expr.Constant) and val.data.asnumpy().ndim == 0:
# fold Neg in for scalars # fold Neg in for scalars
...@@ -360,20 +381,18 @@ class ParseTreeToRelayIR(RelayVisitor): ...@@ -360,20 +381,18 @@ class ParseTreeToRelayIR(RelayVisitor):
return op.negative(val) return op.negative(val)
def visitTuple(self, ctx): def visitTuple(self, ctx: RelayParser.TupleContext) -> expr.Tuple:
# type: (RelayParser.TupleContext) -> expr.Tuple
tup = self.visit_list(ctx.expr()) tup = self.visit_list(ctx.expr())
return expr.Tuple(tup) return expr.Tuple(tup)
def visitLet(self, ctx): def visitLet(self, ctx: RelayParser.LetContext) -> expr.Let:
# type: (RelayParser.SeqContext) -> expr.Let
"""Desugar various sequence constructs to Relay Let nodes.""" """Desugar various sequence constructs to Relay Let nodes."""
if ctx.var() is None: if ctx.var() is None:
# anonymous identity # anonymous identity
ident = "_" ident = "_"
type_ = None typ = None
var = self.mk_var(ident, type_) var = self.mk_var(ident, typ)
else: else:
var = self.visitVar(ctx.var()) var = self.visitVar(ctx.var())
...@@ -385,66 +404,61 @@ class ParseTreeToRelayIR(RelayVisitor): ...@@ -385,66 +404,61 @@ class ParseTreeToRelayIR(RelayVisitor):
return expr.Let(var, value, body) return expr.Let(var, value, body)
def visitBinOp(self, ctx): def visitBinOp(self, ctx: RelayParser.BinOpContext) -> expr.Call:
# type: (RelayParser.BinOpContext) -> expr.Call
"""Desugar binary operators.""" """Desugar binary operators."""
arg0, arg1 = self.visit_list(ctx.expr()) arg0, arg1 = self.visit_list(ctx.expr())
relay_op = BINARY_OPS.get(ctx.op.type) relay_op = BINARY_OPS.get(ctx.op.type)
if relay_op is None: if relay_op is None:
raise ParseError("Unimplemented binary op.") raise ParseError("unimplemented binary op.")
return relay_op(arg0, arg1) return relay_op(arg0, arg1)
@spanify @spanify
def visitVar(self, ctx): def visitVar(self, ctx: RelayParser.VarContext) -> expr.Var:
# type: (RelayParser.VarContext) -> expr.Var
"""Visit a single variable.""" """Visit a single variable."""
ident = ctx.LOCAL_VAR() ident = ctx.localVar()
if ident is None: if ident is None:
raise ParseError("Only local ids may be used in vars.") raise ParseError("only local ids may be used in vars.")
type_ = self.getType_(ctx.type_()) typeExpr = self.getTypeExpr(ctx.typeExpr())
return self.mk_var(ident.getText()[1:], type_) return self.mk_var(ident.getText()[1:], typeExpr)
def visitVarList(self, ctx): def visitVarList(self, ctx: RelayParser.VarListContext) -> List[expr.Var]:
# type: (RelayParser.VarListContext) -> List[expr.Var]
return self.visit_list(ctx.var()) return self.visit_list(ctx.var())
# TODO: support a larger class of values than just Relay exprs # TODO: support a larger class of values than just Relay exprs
def visitAttr(self, ctx): def visitAttr(self, ctx: RelayParser.AttrContext) -> Tuple[str, expr.Expr]:
# type: (RelayParser.AttrContext) -> Tuple[str, expr.Expr]
return (ctx.CNAME().getText(), self.visit(ctx.expr())) return (ctx.CNAME().getText(), self.visit(ctx.expr()))
def visitArgNoAttr(self, ctx): def visitArgNoAttr(self, ctx: RelayParser.ArgNoAttrContext):
return (self.visit_list(ctx.varList().var()), None) return (self.visit_list(ctx.varList().var()), None)
def visitAttrSeq(self, ctx): def visitAttrSeq(self, ctx: RelayParser.AttrSeqContext) -> Dict[str, expr.Expr]:
# type: (RelayParser.AttrListContext) -> Dict[str, expr.Expr]
return dict(self.visit_list(ctx.attr())) return dict(self.visit_list(ctx.attr()))
def visitArgWithAttr(self, ctx): def visitArgWithAttr(self, ctx: RelayParser.AttrSeqContext) \
-> Tuple[List[expr.Var], Dict[str, expr.Expr]]:
return (self.visit_list(ctx.var()), self.visitAttrSeq(ctx.attrSeq())) return (self.visit_list(ctx.var()), self.visitAttrSeq(ctx.attrSeq()))
def visitArgList(self, def visitArgList(self, ctx: RelayParser.ArgListContext) \
ctx # type: RelayParser.ArgListContext -> Tuple[Optional[List[expr.Var]], Optional[Dict[str, expr.Expr]]]:
):
# type: (...) -> Tuple[Optional[List[expr.Var]], Optional[Dict[str, expr.Expr]]]
var_list = self.visit(ctx.varList()) if ctx.varList() else None var_list = self.visit(ctx.varList()) if ctx.varList() else None
attr_list = self.visit(ctx.attrList()) if ctx.attrList() else None attr_list = self.visit(ctx.attrList()) if ctx.attrList() else None
return (var_list, attr_list) return (var_list, attr_list)
def visitMeta(self, ctx): def visitMeta(self, ctx: RelayParser.MetaContext):
type_key = str(ctx.CNAME()) type_key = str(ctx.CNAME())
index = int(self.visit(ctx.NAT())) index = int(self.visit(ctx.NAT()))
return self.meta[type_key][index] return self.meta[type_key][index]
def mk_func(self, ctx): def mk_func(
# type: (Union[RelayParser.FuncContext, RelayParser.DefnContext]) -> expr.Function self,
ctx: Union[RelayParser.FuncContext, RelayParser.DefnContext]) \
-> expr.Function:
"""Construct a function from either a Func or Defn.""" """Construct a function from either a Func or Defn."""
# Enter var scope early to put params in scope. # Enter var scope early to put params in scope.
self.enter_var_scope() self.enter_var_scope()
# Capture type params in params. # Capture type params in params.
...@@ -452,7 +466,7 @@ class ParseTreeToRelayIR(RelayVisitor): ...@@ -452,7 +466,7 @@ class ParseTreeToRelayIR(RelayVisitor):
type_params = ctx.typeParamList() type_params = ctx.typeParamList()
if type_params is not None: if type_params is not None:
type_params = type_params.ident() type_params = type_params.generalIdent()
assert type_params assert type_params
for ty_param in type_params: for ty_param in type_params:
name = ty_param.getText() name = ty_param.getText()
...@@ -461,7 +475,7 @@ class ParseTreeToRelayIR(RelayVisitor): ...@@ -461,7 +475,7 @@ class ParseTreeToRelayIR(RelayVisitor):
var_list, attr_list = self.visit(ctx.argList()) var_list, attr_list = self.visit(ctx.argList())
if var_list is None: if var_list is None:
var_list = [] var_list = []
ret_type = self.getType_(ctx.type_()) ret_type = self.getTypeExpr(ctx.typeExpr())
body = self.visit(ctx.body()) body = self.visit(ctx.body())
# NB(@jroesch): you must stay in the type parameter scope until # NB(@jroesch): you must stay in the type parameter scope until
...@@ -476,41 +490,135 @@ class ParseTreeToRelayIR(RelayVisitor): ...@@ -476,41 +490,135 @@ class ParseTreeToRelayIR(RelayVisitor):
return expr.Function(var_list, body, ret_type, type_params, attrs) return expr.Function(var_list, body, ret_type, type_params, attrs)
@spanify @spanify
def visitFunc(self, ctx): def visitFunc(self, ctx: RelayParser.FuncContext) -> expr.Function:
# type: (RelayParser.FuncContext) -> expr.Function
return self.mk_func(ctx) return self.mk_func(ctx)
# TODO: how to set spans for definitions? # TODO: how to set spans for definitions?
# @spanify # @spanify
def visitDefn(self, ctx): def visitFuncDefn(self, ctx: RelayParser.DefnContext) -> None:
# type: (RelayParser.DefnContext) -> None ident_name = ctx.globalVar().getText()[1:]
ident = ctx.ident().GLOBAL_VAR()
if ident is None:
raise ParseError("Only global ids may be used in `def`s.")
ident_name = ident.getText()[1:]
ident = self.mk_global_var(ident_name) ident = self.mk_global_var(ident_name)
self.module[ident] = self.mk_func(ctx) self.module[ident] = self.mk_func(ctx)
def visitCallNoAttr(self, ctx): def handle_adt_header(
self,
ctx: Union[RelayParser.ExternAdtDefnContext, RelayParser.AdtDefnContext]):
"""Handles parsing of the name and type params of an ADT definition."""
adt_name = ctx.generalIdent().getText()
adt_var = self.mk_global_typ_var(adt_name, ty.Kind.AdtHandle)
# parse type params
type_params = ctx.typeParamList()
if type_params is None:
type_params = []
else:
type_params = [self.mk_typ(type_ident.getText(), ty.Kind.Type)
for type_ident in type_params.generalIdent()]
return adt_var, type_params
def visitExternAdtDefn(self, ctx: RelayParser.ExternAdtDefnContext):
# TODO(weberlo): update this handler once extern is implemented
self.enter_type_param_scope()
adt_var, type_params = self.handle_adt_header(ctx)
# update module being built
self.module[adt_var] = adt.TypeData(adt_var, type_params, [])
self.exit_type_param_scope()
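Per the TODO above, `extern type` is currently a stub that registers an ADT handle with an empty constructor list. A sketch of the accepted syntax (the type name is made up for illustration):

    from tvm import relay

    program = """
    v0.0.4
    extern type MyExternalType[A]
    """
    mod = relay.fromtext(program)  # MyExternalType is stored as a TypeData with no constructors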
def visitAdtDefn(self, ctx: RelayParser.AdtDefnContext):
self.enter_type_param_scope()
adt_var, type_params = self.handle_adt_header(ctx)
# parse constructors
adt_cons_defns = ctx.adtConsDefnList()
if adt_cons_defns is None:
adt_cons_defns = []
else:
adt_cons_defns = adt_cons_defns.adtConsDefn()
parsed_constructors = []
for cons_defn in adt_cons_defns:
inputs = [self.visit(inp) for inp in cons_defn.typeExpr()]
cons_defn_name = cons_defn.constructorName().getText()
cons_defn = adt.Constructor(cons_defn_name, inputs, adt_var)
self.mk_global_typ_cons(cons_defn_name, cons_defn)
parsed_constructors.append(cons_defn)
# update module being built
self.module[adt_var] = adt.TypeData(adt_var, type_params, parsed_constructors)
self.exit_type_param_scope()
def visitMatch(self, ctx: RelayParser.MatchContext):
match_type = ctx.matchType().getText()
if match_type == "match":
complete_match = True
elif match_type == "match?":
complete_match = False
else:
raise RuntimeError(f"unknown match type {match_type}")
# TODO: Will need some kind of type checking to know which ADT is being
# matched on.
match_data = self.visit(ctx.expr())
match_clauses = ctx.matchClauseList()
if match_clauses is None:
match_clauses = []
else:
match_clauses = match_clauses.matchClause()
parsed_clauses = []
for clause in match_clauses:
constructor_name = clause.constructorName().getText()
constructor = self.global_type_vars[constructor_name]
self.enter_var_scope()
patternList = clause.patternList()
if patternList is None:
patterns = []
else:
patterns = [self.visit(pattern) for pattern in patternList.pattern()]
clause_body = self.visit(clause.expr())
self.exit_var_scope()
# TODO: Do we need to pass `None` if it's a 0-arity cons, or is an empty list fine?
parsed_clauses.append(adt.Clause(
adt.PatternConstructor(
constructor,
patterns
),
clause_body
))
return adt.Match(match_data, parsed_clauses, complete=complete_match)
def visitPattern(self, ctx: RelayParser.PatternContext):
text = ctx.getText()
if text == "_":
return adt.PatternWildcard()
elif text.startswith("%"):
text = ctx.localVar().getText()
typ = ctx.typeExpr()
if typ is not None:
typ = self.visit(typ)
var = self.mk_var(text[1:], typ=typ)
return adt.PatternVar(var)
else:
raise ParseError(f"invalid pattern syntax \"{text}\"")
def visitCallNoAttr(self, ctx: RelayParser.CallNoAttrContext):
return (self.visit_list(ctx.exprList().expr()), None) return (self.visit_list(ctx.exprList().expr()), None)
def visitCallWithAttr(self, ctx): def visitCallWithAttr(self, ctx: RelayParser.CallWithAttrContext):
return (self.visit_list(ctx.expr()), self.visit(ctx.attrSeq())) return (self.visit_list(ctx.expr()), self.visit(ctx.attrSeq()))
def call(self, func, args, attrs, type_args): def call(self, func, args, attrs, type_args):
if isinstance(func, OpWrapper): if isinstance(func, OpWrapper):
return func(args, attrs, type_args) return func(args, attrs, type_args)
elif isinstance(func, adt.Constructor):
return func(*args)
return expr.Call(func, args, attrs, type_args) return expr.Call(func, args, attrs, type_args)
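The adt.Constructor branch above applies a constructor directly when its name appears in call position, so constructed values can be written inline in the text format. A sketch with an assumed Box type:

    from tvm import relay

    program = """
    v0.0.4
    type Box[A] {
      MkBox(A),
    }
    def @main() {
      MkBox(3)
    }
    """
    mod = relay.fromtext(program)  # the body of @main is a call to the MkBox constructor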
@spanify @spanify
def visitCall(self, ctx): def visitCall(self, ctx: RelayParser.CallContext):
# type: (RelayParser.CallContext) -> expr.Call # type: (RelayParser.CallContext) -> expr.Call
func = self.visit(ctx.expr()) func = self.visit(ctx.expr())
args, attrs = self.visit(ctx.callList()) args, attrs = self.visit(ctx.callList())
return self.call(func, args, attrs, []) res = self.call(func, args, attrs, [])
return res
@spanify @spanify
def visitIfElse(self, ctx): def visitIfElse(self, ctx: RelayParser.IfElseContext):
# type: (RelayParser.IfElseContext) -> expr.If # type: (RelayParser.IfElseContext) -> expr.If
"""Construct a Relay If node. Creates a new scope for each branch.""" """Construct a Relay If node. Creates a new scope for each branch."""
cond = self.visit(ctx.expr()) cond = self.visit(ctx.expr())
...@@ -526,10 +634,10 @@ class ParseTreeToRelayIR(RelayVisitor): ...@@ -526,10 +634,10 @@ class ParseTreeToRelayIR(RelayVisitor):
return expr.If(cond, true_branch, false_branch) return expr.If(cond, true_branch, false_branch)
@spanify @spanify
def visitGraph(self, ctx): def visitGraph(self, ctx: RelayParser.GraphContext):
# type: (RelayParser.GraphContext) -> expr.Expr # type: (RelayParser.GraphContext) -> expr.Expr
"""Visit a graph variable assignment.""" """Visit a graph variable assignment."""
graph_nid = int(ctx.GRAPH_VAR().getText()[1:]) graph_nid = int(ctx.graphVar().getText()[1:])
self.enter_var_scope() self.enter_var_scope()
value = self.visit(ctx.expr(0)) value = self.visit(ctx.expr(0))
...@@ -537,7 +645,7 @@ class ParseTreeToRelayIR(RelayVisitor): ...@@ -537,7 +645,7 @@ class ParseTreeToRelayIR(RelayVisitor):
if graph_nid != len(self.graph_expr): if graph_nid != len(self.graph_expr):
raise ParseError( raise ParseError(
"Expected new graph variable to be `%{}`,".format(len(self.graph_expr)) + \ "expected new graph variable to be `%{}`,".format(len(self.graph_expr)) + \
"but got `%{}`".format(graph_nid)) "but got `%{}`".format(graph_nid))
self.graph_expr.append(value) self.graph_expr.append(value)
...@@ -547,76 +655,47 @@ class ParseTreeToRelayIR(RelayVisitor): ...@@ -547,76 +655,47 @@ class ParseTreeToRelayIR(RelayVisitor):
# Types # Types
# pylint: disable=unused-argument # pylint: disable=unused-argument
def visitIncompleteType(self, ctx): def visitIncompleteType(self, ctx: RelayParser.IncompleteTypeContext):
# type (RelayParser.IncompleteTypeContext) -> None: # type (RelayParser.IncompleteTypeContext) -> None:
return None return None
def visitTypeIdent(self, ctx): def visitTypeCallType(self, ctx: RelayParser.TypeCallTypeContext):
# type: (RelayParser.TypeIdentContext) -> Union[ty.TensorType, str] func = self.visit(ctx.generalIdent())
''' args = [self.visit(arg) for arg in ctx.typeParamList().generalIdent()]
Handle type identifier. return ty.TypeCall(func, args)
'''
type_ident = ctx.CNAME().getText()
# Look through all type prefixes for a match
for type_prefix in TYPE_PREFIXES:
if type_ident.startswith(type_prefix):
return ty.scalar_type(type_ident)
type_param = lookup(self.type_param_scopes, type_ident)
if type_param is not None:
return type_param
raise ParseError("Unknown builtin type: {}".format(type_ident))
# def visitCallType(self, ctx):
# # type: (RelayParser.CallTypeContext) -> Union[expr.Expr, ty.TensorType]
# ident_type = ctx.identType().CNAME().getText()
# args = self.visit_list(ctx.type_())
# if not args:
# raise ParseError("Type-level functions must have arguments!")
# func_type = TYPE_FUNCS.get(ident_type)(args)
# if func_type is None:
# raise ParseError("Unknown type-level function: `{}`".format(ident_type))
# else:
# return func_type
def visitParensShape(self, ctx): def visitParensShape(self, ctx: RelayParser.ParensShapeContext):
# type: (RelayParser.ParensShapeContext) -> int # type: (RelayParser.ParensShapeContext) -> int
return self.visit(ctx.shape()) return self.visit(ctx.shape())
def visitShapeList(self, ctx): def visitShapeList(self, ctx: RelayParser.ShapeListContext):
# type: (RelayParser.ShapeListContext) -> List[int] # type: (RelayParser.ShapeListContext) -> List[int]
return self.visit_list(ctx.shape()) return self.visit_list(ctx.shape())
def visitTensor(self, ctx): def visitTensor(self, ctx: RelayParser.TensorContext):
return tuple(self.visit_list(ctx.expr())) return tuple(self.visit_list(ctx.expr()))
def visitTensorType(self, ctx): def visitTensorType(self, ctx: RelayParser.TensorTypeContext):
# type: (RelayParser.TensorTypeContext) -> ty.TensorType # type: (RelayParser.TensorTypeContext) -> ty.TensorType
"""Create a simple tensor type. No generics.""" """Create a simple tensor type. No generics."""
shape = self.visit(ctx.shapeList()) shape = self.visit(ctx.shapeList())
dtype = self.visit(ctx.type_()) dtype = self.visit(ctx.typeExpr())
if not isinstance(dtype, ty.TensorType): if not isinstance(dtype, ty.TensorType):
raise ParseError("Expected dtype to be a Relay base type.") raise ParseError("expected dtype to be a Relay base type.")
dtype = dtype.dtype dtype = dtype.dtype
return ty.TensorType(shape, dtype) return ty.TensorType(shape, dtype)
def visitTupleType(self, ctx): def visitTupleType(self, ctx: RelayParser.TupleTypeContext):
# type: (RelayParser.TupleTypeContext) -> ty.TupleType # type: (RelayParser.TupleTypeContext) -> ty.TupleType
return ty.TupleType(self.visit_list(ctx.type_())) return ty.TupleType(self.visit_list(ctx.typeExpr()))
def visitFuncType(self, ctx): def visitFuncType(self, ctx: RelayParser.FuncTypeContext):
# type: (RelayParser.FuncTypeContext) -> ty.FuncType # type: (RelayParser.FuncTypeContext) -> ty.FuncType
types = self.visit_list(ctx.type_()) types = self.visit_list(ctx.typeExpr())
arg_types = types[:-1] arg_types = types[:-1]
ret_type = types[-1] ret_type = types[-1]
...@@ -663,7 +742,7 @@ def fromtext(data, source_name=None): ...@@ -663,7 +742,7 @@ def fromtext(data, source_name=None):
# type: (str, str) -> Union[expr.Expr, module.Module] # type: (str, str) -> Union[expr.Expr, module.Module]
"""Parse a Relay program.""" """Parse a Relay program."""
if data == "": if data == "":
raise ParseError("Cannot parse the empty string.") raise ParseError("cannot parse the empty string.")
global __source_name_counter__ global __source_name_counter__
......
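fromtext remains the public entry point for the text format; the versioned header is still required and empty input is rejected up front:

    from tvm import relay

    relay.fromtext("v0.0.4\n()")  # parses to an empty tuple expression
    # relay.fromtext("") raises ParseError("cannot parse the empty string.")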
...@@ -17,11 +17,15 @@ ...@@ -17,11 +17,15 @@
* under the License. * under the License.
*/ */
// list = *, seq = ? /*
* NOTE: The `USE_ANTLR` option in `config.cmake` must be enabled in order for
* changes in this file to be reflected by the parser.
* NOTE: All upper-case rules are *lexer* rules and all camel-case rules are *parser* rules.
*/
grammar Relay; grammar Relay;
SEMVER: 'v0.0.3' ; SEMVER: 'v0.0.4' ;
// Lexing // Lexing
// comments // comments
...@@ -49,13 +53,8 @@ BOOL_LIT ...@@ -49,13 +53,8 @@ BOOL_LIT
| 'False' | 'False'
; ;
CNAME: ('_'|LETTER) ('_'|LETTER|DIGIT)* ('.' CNAME)*; CNAME: ('_'|LETTER) ('_'|LETTER|DIGIT)* ('.' CNAME)* ;
opIdent: CNAME ;
GLOBAL_VAR: '@' CNAME ;
LOCAL_VAR: '%' CNAME;
GRAPH_VAR: '%' NAT;
DATATYPE : 'int64';
// non-negative floats // non-negative floats
fragment PREFLOAT : NAT ('.' NAT)? EXP?; // 1.35, 1.35E-9, 0.3, 4.5, 1, 1e10 3e4 fragment PREFLOAT : NAT ('.' NAT)? EXP?; // 1.35, 1.35E-9, 0.3, 4.5, 1, 1e10 3e4
...@@ -74,7 +73,11 @@ METADATA: 'METADATA:' .*; ...@@ -74,7 +73,11 @@ METADATA: 'METADATA:' .*;
// A Relay program is a list of global definitions or an expression. // A Relay program is a list of global definitions or an expression.
prog: SEMVER (defn* | expr) METADATA? EOF ; prog: SEMVER (defn* | expr) METADATA? EOF ;
// option: 'set' ident BOOL_LIT ; // Covers both operator and type idents
generalIdent: CNAME ('.' CNAME)*;
globalVar: '@' CNAME ;
localVar: '%' ('_' | CNAME) ;
graphVar: '%' NAT ;
exprList: (expr (',' expr)*)?; exprList: (expr (',' expr)*)?;
callList callList
...@@ -85,7 +88,6 @@ callList ...@@ -85,7 +88,6 @@ callList
expr expr
// operators // operators
: '(' expr ')' # paren : '(' expr ')' # paren
| '{' expr '}' # paren
// function application // function application
| expr '(' callList ')' # call | expr '(' callList ')' # call
| '-' expr # neg | '-' expr # neg
...@@ -99,53 +101,74 @@ expr ...@@ -99,53 +101,74 @@ expr
| '(' ')' # tuple | '(' ')' # tuple
| '(' expr ',' ')' # tuple | '(' expr ',' ')' # tuple
| '(' expr (',' expr)+ ')' # tuple | '(' expr (',' expr)+ ')' # tuple
| expr '.' NAT # projection
| '[' (expr (',' expr)*)? ']' # tensor | '[' (expr (',' expr)*)? ']' # tensor
| 'if' '(' expr ')' body 'else' body # ifElse | 'if' '(' expr ')' body 'else' body # ifElse
| matchType '(' expr ')' '{' matchClauseList? '}' # match
| expr '.' NAT # projection
// sequencing // sequencing
| 'let' var '=' expr ';' expr # let | 'let' var '=' expr ';' expr # let
// sugar for let %_ = expr; expr // sugar for let %_ = expr; expr
| expr ';;' expr # let | expr ';;' expr # let
| GRAPH_VAR '=' expr ';' expr # graph | graphVar '=' expr ';' expr # graph
| ident # identExpr | ident # identExpr
| scalar # scalarExpr | scalar # scalarExpr
| meta # metaExpr | meta # metaExpr
| QUOTED_STRING # stringExpr | QUOTED_STRING # stringExpr
; ;
func: 'fn' typeParamList? '(' argList ')' ('->' type_)? body ; func: 'fn' typeParamList? '(' argList ')' ('->' typeExpr)? body ;
defn: 'def' ident typeParamList? '(' argList ')' ('->' type_)? body ; defn
: 'def' globalVar typeParamList? '(' argList ')' ('->' typeExpr)? body # funcDefn
| 'extern' 'type' generalIdent typeParamList? # externAdtDefn
| 'type' generalIdent typeParamList? '{' adtConsDefnList? '}' # adtDefn
;
constructorName: CNAME ;
adtConsDefnList: adtConsDefn (',' adtConsDefn)* ','? ;
adtConsDefn: constructorName ('(' typeExpr (',' typeExpr)* ')')? ;
matchClauseList: matchClause (',' matchClause)* ','? ;
matchClause: constructorName patternList? '=>' ('{' expr '}' | expr) ;
// complete or incomplete match, respectively
matchType : 'match' | 'match?' ;
patternList: '(' pattern (',' pattern)* ')';
pattern
: '_'
| localVar (':' typeExpr)?
;
adtCons: constructorName adtConsParamList? ;
adtConsParamList: '(' adtConsParam (',' adtConsParam)* ')' ;
adtConsParam: localVar | constructorName ;
argList argList
: varList # argNoAttr : varList # argNoAttr
| (var ',')* attrSeq # argWithAttr | (var ',')* attrSeq # argWithAttr
; ;
varList: (var (',' var)*)?; varList: (var (',' var)*)? ;
var: LOCAL_VAR (':' type_)?; var: localVar (':' typeExpr)? ;
attrSeq: attr (',' attr)*; attrSeq: attr (',' attr)* ;
attr: CNAME '=' expr ; attr: CNAME '=' expr ;
typeParamList typeExpr
: '[' ']'
| '[' ident (',' ident)* ']'
;
type_
: '(' ')' # tupleType : '(' ')' # tupleType
| '(' type_ ',' ')' # tupleType | '(' typeExpr ',' ')' # tupleType
| '(' type_ (',' type_)+ ')' # tupleType | '(' typeExpr (',' typeExpr)+ ')' # tupleType
| typeIdent # typeIdentType | generalIdent typeParamList # typeCallType
| 'Tensor' '[' shapeList ',' type_ ']' # tensorType | generalIdent # typeIdentType
| 'fn' typeParamList? '(' (type_ (',' type_)*)? ')' '->' type_ # funcType | 'Tensor' '[' shapeList ',' typeExpr ']' # tensorType
| 'fn' typeParamList? '(' (typeExpr (',' typeExpr)*)? ')' '->' typeExpr # funcType
| '_' # incompleteType | '_' # incompleteType
| NAT # intType
; ;
typeParamList: '[' generalIdent (',' generalIdent)* ']' ;
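The renamed typeExpr rule covers tuple types, type calls on ADTs, tensor types, function types, and the `_` hole. A sketch mixing a few of these in a signature (the function name and parameters are illustrative):

    from tvm import relay

    program = """
    v0.0.4
    def @f(%x: Tensor[(3, 4), float32], %p: (int32, bool)) -> _ {
      %x
    }
    """
    mod = relay.fromtext(program)  # `_` leaves the return type to be inferred later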
shapeList shapeList
: '(' shape (',' shape)+ ')' : '(' ')'
| '(' ')' | '(' shape (',' shape)+ ')'
| shape | shape
; ;
...@@ -157,12 +180,6 @@ shape ...@@ -157,12 +180,6 @@ shape
| NAT # intShape | NAT # intShape
; ;
typeIdent : CNAME;
// int8, int16, int32, int64
// uint8, uint16, uint32, uint64
// float16, float32, float64
// bool
body: '{' expr '}' ; body: '{' expr '}' ;
scalar scalar
...@@ -172,8 +189,8 @@ scalar ...@@ -172,8 +189,8 @@ scalar
; ;
ident ident
: opIdent : generalIdent
| GLOBAL_VAR | globalVar
| LOCAL_VAR | localVar
| GRAPH_VAR | graphVar
; ;
# Generated from /workspace/python/tvm/relay/grammar/Relay.g4 by ANTLR 4.7.1 # Generated from /Users/doobs/Code/repo/sampl/tvm/python/tvm/relay/grammar/Relay.g4 by ANTLR 4.7.2
from antlr4 import * from antlr4 import *
from io import StringIO from io import StringIO
from typing.io import TextIO from typing.io import TextIO
import sys import sys
def serializedATN(): def serializedATN():
with StringIO() as buf: with StringIO() as buf:
buf.write("\3\u608b\ua72a\u8133\ub9ed\u417c\u3be7\u7786\u5964\2/") buf.write("\3\u608b\ua72a\u8133\ub9ed\u417c\u3be7\u7786\u5964\2\62")
buf.write("\u014a\b\1\4\2\t\2\4\3\t\3\4\4\t\4\4\5\t\5\4\6\t\6\4\7") buf.write("\u0161\b\1\4\2\t\2\4\3\t\3\4\4\t\4\4\5\t\5\4\6\t\6\4\7")
buf.write("\t\7\4\b\t\b\4\t\t\t\4\n\t\n\4\13\t\13\4\f\t\f\4\r\t\r") buf.write("\t\7\4\b\t\b\4\t\t\t\4\n\t\n\4\13\t\13\4\f\t\f\4\r\t\r")
buf.write("\4\16\t\16\4\17\t\17\4\20\t\20\4\21\t\21\4\22\t\22\4\23") buf.write("\4\16\t\16\4\17\t\17\4\20\t\20\4\21\t\21\4\22\t\22\4\23")
buf.write("\t\23\4\24\t\24\4\25\t\25\4\26\t\26\4\27\t\27\4\30\t\30") buf.write("\t\23\4\24\t\24\4\25\t\25\4\26\t\26\4\27\t\27\4\30\t\30")
buf.write("\4\31\t\31\4\32\t\32\4\33\t\33\4\34\t\34\4\35\t\35\4\36") buf.write("\4\31\t\31\4\32\t\32\4\33\t\33\4\34\t\34\4\35\t\35\4\36")
buf.write("\t\36\4\37\t\37\4 \t \4!\t!\4\"\t\"\4#\t#\4$\t$\4%\t%") buf.write("\t\36\4\37\t\37\4 \t \4!\t!\4\"\t\"\4#\t#\4$\t$\4%\t%")
buf.write("\4&\t&\4\'\t\'\4(\t(\4)\t)\4*\t*\4+\t+\4,\t,\4-\t-\4.") buf.write("\4&\t&\4\'\t\'\4(\t(\4)\t)\4*\t*\4+\t+\4,\t,\4-\t-\4.")
buf.write("\t.\4/\t/\4\60\t\60\4\61\t\61\4\62\t\62\4\63\t\63\3\2") buf.write("\t.\4/\t/\4\60\t\60\4\61\t\61\4\62\t\62\4\63\t\63\4\64")
buf.write("\3\2\3\3\3\3\3\4\3\4\3\5\3\5\3\6\3\6\3\7\3\7\3\b\3\b\3") buf.write("\t\64\4\65\t\65\4\66\t\66\3\2\3\2\3\3\3\3\3\4\3\4\3\5")
buf.write("\t\3\t\3\n\3\n\3\n\3\13\3\13\3\13\3\13\3\13\3\f\3\f\3") buf.write("\3\5\3\6\3\6\3\7\3\7\3\b\3\b\3\t\3\t\3\n\3\n\3\13\3\13")
buf.write("\f\3\f\3\r\3\r\3\16\3\16\3\17\3\17\3\17\3\20\3\20\3\20") buf.write("\3\13\3\f\3\f\3\f\3\f\3\f\3\r\3\r\3\16\3\16\3\17\3\17")
buf.write("\3\21\3\21\3\21\3\22\3\22\3\22\3\22\3\23\3\23\3\24\3\24") buf.write("\3\17\3\17\3\20\3\20\3\21\3\21\3\22\3\22\3\22\3\23\3\23")
buf.write("\3\24\3\24\3\24\3\24\3\24\3\25\3\25\3\26\3\26\3\26\3\26") buf.write("\3\23\3\24\3\24\3\24\3\25\3\25\3\25\3\25\3\26\3\26\3\26")
buf.write("\3\26\3\27\3\27\3\27\3\27\3\27\3\27\3\27\3\30\3\30\3\30") buf.write("\3\26\3\26\3\26\3\26\3\27\3\27\3\27\3\27\3\27\3\30\3\30")
buf.write("\3\30\3\30\7\30\u00b1\n\30\f\30\16\30\u00b4\13\30\3\30") buf.write("\3\30\3\31\3\31\3\31\3\31\3\31\3\31\3\32\3\32\3\32\3\32")
buf.write("\3\30\3\30\3\30\3\30\3\31\6\31\u00bc\n\31\r\31\16\31\u00bd") buf.write("\3\32\3\32\3\32\3\33\3\33\3\34\3\34\3\34\3\34\3\34\3\34")
buf.write("\3\31\3\31\3\32\3\32\3\32\3\32\7\32\u00c6\n\32\f\32\16") buf.write("\3\34\3\35\3\35\3\35\3\35\3\35\3\36\3\36\3\36\3\36\3\36")
buf.write("\32\u00c9\13\32\3\32\3\32\3\32\3\32\3\33\3\33\3\33\3\34") buf.write("\3\36\3\36\3\37\3\37\3\37\3\37\3\37\7\37\u00d7\n\37\f")
buf.write("\3\34\3\34\7\34\u00d5\n\34\f\34\16\34\u00d8\13\34\3\34") buf.write("\37\16\37\u00da\13\37\3\37\3\37\3\37\3\37\3\37\3 \6 \u00e2")
buf.write("\3\34\3\35\3\35\3\36\3\36\3\37\3\37\3 \3 \3!\3!\3\"\3") buf.write("\n \r \16 \u00e3\3 \3 \3!\3!\3!\3!\7!\u00ec\n!\f!\16!")
buf.write("\"\3#\3#\3#\3$\3$\3$\3%\3%\3%\3&\3&\3&\3\'\3\'\3\'\3\'") buf.write("\u00ef\13!\3!\3!\3!\3!\3\"\3\"\3\"\3#\3#\3#\7#\u00fb\n")
buf.write("\3\'\3\'\3\'\3\'\3\'\5\'\u00fd\n\'\3(\3(\5(\u0101\n(\3") buf.write("#\f#\16#\u00fe\13#\3#\3#\3$\3$\3%\3%\3&\3&\3\'\3\'\3(")
buf.write("(\3(\3(\7(\u0106\n(\f(\16(\u0109\13(\3(\3(\7(\u010d\n") buf.write("\3(\3)\3)\3*\3*\3*\3+\3+\3+\3,\3,\3,\3-\3-\3-\3.\3.\3")
buf.write("(\f(\16(\u0110\13(\3)\3)\3)\3*\3*\3*\3+\3+\3+\3,\3,\3") buf.write(".\3.\3.\3.\3.\3.\3.\5.\u0123\n.\3/\3/\5/\u0127\n/\3/\3")
buf.write(",\3,\3,\3,\3-\3-\3-\5-\u0124\n-\3-\5-\u0127\n-\3.\3.\3") buf.write("/\3/\7/\u012c\n/\f/\16/\u012f\13/\3/\3/\7/\u0133\n/\f")
buf.write(".\3/\6/\u012d\n/\r/\16/\u012e\3\60\3\60\5\60\u0133\n\60") buf.write("/\16/\u0136\13/\3\60\3\60\3\60\5\60\u013b\n\60\3\60\5")
buf.write("\3\60\3\60\3\61\3\61\3\62\3\62\3\63\3\63\3\63\3\63\3\63") buf.write("\60\u013e\n\60\3\61\3\61\3\61\3\62\6\62\u0144\n\62\r\62")
buf.write("\3\63\3\63\3\63\3\63\3\63\3\63\7\63\u0146\n\63\f\63\16") buf.write("\16\62\u0145\3\63\3\63\5\63\u014a\n\63\3\63\3\63\3\64")
buf.write("\63\u0149\13\63\5\u00b2\u00c7\u00d6\2\64\3\3\5\4\7\5\t") buf.write("\3\64\3\65\3\65\3\66\3\66\3\66\3\66\3\66\3\66\3\66\3\66")
buf.write("\6\13\7\r\b\17\t\21\n\23\13\25\f\27\r\31\16\33\17\35\20") buf.write("\3\66\3\66\3\66\7\66\u015d\n\66\f\66\16\66\u0160\13\66")
buf.write("\37\21!\22#\23%\24\'\25)\26+\27-\30/\31\61\32\63\33\65") buf.write("\5\u00d8\u00ed\u00fc\2\67\3\3\5\4\7\5\t\6\13\7\r\b\17")
buf.write("\2\67\349\35;\36=\37? A!C\"E#G$I%K&M\'O(Q)S*U+W,Y\2[-") buf.write("\t\21\n\23\13\25\f\27\r\31\16\33\17\35\20\37\21!\22#\23")
buf.write("]._\2a\2c\2e/\3\2\b\5\2\13\f\17\17\"\"\4\2\f\f\17\17\4") buf.write("%\24\'\25)\26+\27-\30/\31\61\32\63\33\65\34\67\359\36")
buf.write("\2GGgg\4\2--//\4\2C\\c|\3\2\62;\2\u0155\2\3\3\2\2\2\2") buf.write(";\37= ?!A\"C\2E#G$I%K&M\'O(Q)S*U+W,Y-[.]/_\2a\60c\61e")
buf.write("\2g\2i\2k\62\3\2\b\5\2\13\f\17\17\"\"\4\2\f\f\17\17\4")
buf.write("\2GGgg\4\2--//\4\2C\\c|\3\2\62;\2\u016c\2\3\3\2\2\2\2")
buf.write("\5\3\2\2\2\2\7\3\2\2\2\2\t\3\2\2\2\2\13\3\2\2\2\2\r\3") buf.write("\5\3\2\2\2\2\7\3\2\2\2\2\t\3\2\2\2\2\13\3\2\2\2\2\r\3")
buf.write("\2\2\2\2\17\3\2\2\2\2\21\3\2\2\2\2\23\3\2\2\2\2\25\3\2") buf.write("\2\2\2\2\17\3\2\2\2\2\21\3\2\2\2\2\23\3\2\2\2\2\25\3\2")
buf.write("\2\2\2\27\3\2\2\2\2\31\3\2\2\2\2\33\3\2\2\2\2\35\3\2\2") buf.write("\2\2\2\27\3\2\2\2\2\31\3\2\2\2\2\33\3\2\2\2\2\35\3\2\2")
buf.write("\2\2\37\3\2\2\2\2!\3\2\2\2\2#\3\2\2\2\2%\3\2\2\2\2\'\3") buf.write("\2\2\37\3\2\2\2\2!\3\2\2\2\2#\3\2\2\2\2%\3\2\2\2\2\'\3")
buf.write("\2\2\2\2)\3\2\2\2\2+\3\2\2\2\2-\3\2\2\2\2/\3\2\2\2\2\61") buf.write("\2\2\2\2)\3\2\2\2\2+\3\2\2\2\2-\3\2\2\2\2/\3\2\2\2\2\61")
buf.write("\3\2\2\2\2\63\3\2\2\2\2\67\3\2\2\2\29\3\2\2\2\2;\3\2\2") buf.write("\3\2\2\2\2\63\3\2\2\2\2\65\3\2\2\2\2\67\3\2\2\2\29\3\2")
buf.write("\2\2=\3\2\2\2\2?\3\2\2\2\2A\3\2\2\2\2C\3\2\2\2\2E\3\2") buf.write("\2\2\2;\3\2\2\2\2=\3\2\2\2\2?\3\2\2\2\2A\3\2\2\2\2E\3")
buf.write("\2\2\2G\3\2\2\2\2I\3\2\2\2\2K\3\2\2\2\2M\3\2\2\2\2O\3") buf.write("\2\2\2\2G\3\2\2\2\2I\3\2\2\2\2K\3\2\2\2\2M\3\2\2\2\2O")
buf.write("\2\2\2\2Q\3\2\2\2\2S\3\2\2\2\2U\3\2\2\2\2W\3\2\2\2\2[") buf.write("\3\2\2\2\2Q\3\2\2\2\2S\3\2\2\2\2U\3\2\2\2\2W\3\2\2\2\2")
buf.write("\3\2\2\2\2]\3\2\2\2\2e\3\2\2\2\3g\3\2\2\2\5i\3\2\2\2\7") buf.write("Y\3\2\2\2\2[\3\2\2\2\2]\3\2\2\2\2a\3\2\2\2\2c\3\2\2\2")
buf.write("k\3\2\2\2\tm\3\2\2\2\13o\3\2\2\2\rq\3\2\2\2\17s\3\2\2") buf.write("\2k\3\2\2\2\3m\3\2\2\2\5o\3\2\2\2\7q\3\2\2\2\ts\3\2\2")
buf.write("\2\21u\3\2\2\2\23w\3\2\2\2\25z\3\2\2\2\27\177\3\2\2\2") buf.write("\2\13u\3\2\2\2\rw\3\2\2\2\17y\3\2\2\2\21{\3\2\2\2\23}")
buf.write("\31\u0083\3\2\2\2\33\u0085\3\2\2\2\35\u0087\3\2\2\2\37") buf.write("\3\2\2\2\25\177\3\2\2\2\27\u0082\3\2\2\2\31\u0087\3\2")
buf.write("\u008a\3\2\2\2!\u008d\3\2\2\2#\u0090\3\2\2\2%\u0094\3") buf.write("\2\2\33\u0089\3\2\2\2\35\u008b\3\2\2\2\37\u008f\3\2\2")
buf.write("\2\2\2\'\u0096\3\2\2\2)\u009d\3\2\2\2+\u009f\3\2\2\2-") buf.write("\2!\u0091\3\2\2\2#\u0093\3\2\2\2%\u0096\3\2\2\2\'\u0099")
buf.write("\u00a4\3\2\2\2/\u00ab\3\2\2\2\61\u00bb\3\2\2\2\63\u00c1") buf.write("\3\2\2\2)\u009c\3\2\2\2+\u00a0\3\2\2\2-\u00a7\3\2\2\2")
buf.write("\3\2\2\2\65\u00ce\3\2\2\2\67\u00d1\3\2\2\29\u00db\3\2") buf.write("/\u00ac\3\2\2\2\61\u00af\3\2\2\2\63\u00b5\3\2\2\2\65\u00bc")
buf.write("\2\2;\u00dd\3\2\2\2=\u00df\3\2\2\2?\u00e1\3\2\2\2A\u00e3") buf.write("\3\2\2\2\67\u00be\3\2\2\29\u00c5\3\2\2\2;\u00ca\3\2\2")
buf.write("\3\2\2\2C\u00e5\3\2\2\2E\u00e7\3\2\2\2G\u00ea\3\2\2\2") buf.write("\2=\u00d1\3\2\2\2?\u00e1\3\2\2\2A\u00e7\3\2\2\2C\u00f4")
buf.write("I\u00ed\3\2\2\2K\u00f0\3\2\2\2M\u00fc\3\2\2\2O\u0100\3") buf.write("\3\2\2\2E\u00f7\3\2\2\2G\u0101\3\2\2\2I\u0103\3\2\2\2")
buf.write("\2\2\2Q\u0111\3\2\2\2S\u0114\3\2\2\2U\u0117\3\2\2\2W\u011a") buf.write("K\u0105\3\2\2\2M\u0107\3\2\2\2O\u0109\3\2\2\2Q\u010b\3")
buf.write("\3\2\2\2Y\u0120\3\2\2\2[\u0128\3\2\2\2]\u012c\3\2\2\2") buf.write("\2\2\2S\u010d\3\2\2\2U\u0110\3\2\2\2W\u0113\3\2\2\2Y\u0116")
buf.write("_\u0130\3\2\2\2a\u0136\3\2\2\2c\u0138\3\2\2\2e\u013a\3") buf.write("\3\2\2\2[\u0122\3\2\2\2]\u0126\3\2\2\2_\u0137\3\2\2\2")
buf.write("\2\2\2gh\7.\2\2h\4\3\2\2\2ij\7*\2\2j\6\3\2\2\2kl\7+\2") buf.write("a\u013f\3\2\2\2c\u0143\3\2\2\2e\u0147\3\2\2\2g\u014d\3")
buf.write("\2l\b\3\2\2\2mn\7}\2\2n\n\3\2\2\2op\7\177\2\2p\f\3\2\2") buf.write("\2\2\2i\u014f\3\2\2\2k\u0151\3\2\2\2mn\7\60\2\2n\4\3\2")
buf.write("\2qr\7\60\2\2r\16\3\2\2\2st\7]\2\2t\20\3\2\2\2uv\7_\2") buf.write("\2\2op\7B\2\2p\6\3\2\2\2qr\7\'\2\2r\b\3\2\2\2st\7a\2\2")
buf.write("\2v\22\3\2\2\2wx\7k\2\2xy\7h\2\2y\24\3\2\2\2z{\7g\2\2") buf.write("t\n\3\2\2\2uv\7.\2\2v\f\3\2\2\2wx\7*\2\2x\16\3\2\2\2y")
buf.write("{|\7n\2\2|}\7u\2\2}~\7g\2\2~\26\3\2\2\2\177\u0080\7n\2") buf.write("z\7+\2\2z\20\3\2\2\2{|\7]\2\2|\22\3\2\2\2}~\7_\2\2~\24")
buf.write("\2\u0080\u0081\7g\2\2\u0081\u0082\7v\2\2\u0082\30\3\2") buf.write("\3\2\2\2\177\u0080\7k\2\2\u0080\u0081\7h\2\2\u0081\26")
buf.write("\2\2\u0083\u0084\7?\2\2\u0084\32\3\2\2\2\u0085\u0086\7") buf.write("\3\2\2\2\u0082\u0083\7g\2\2\u0083\u0084\7n\2\2\u0084\u0085")
buf.write("=\2\2\u0086\34\3\2\2\2\u0087\u0088\7=\2\2\u0088\u0089") buf.write("\7u\2\2\u0085\u0086\7g\2\2\u0086\30\3\2\2\2\u0087\u0088")
buf.write("\7=\2\2\u0089\36\3\2\2\2\u008a\u008b\7h\2\2\u008b\u008c") buf.write("\7}\2\2\u0088\32\3\2\2\2\u0089\u008a\7\177\2\2\u008a\34")
buf.write("\7p\2\2\u008c \3\2\2\2\u008d\u008e\7/\2\2\u008e\u008f") buf.write("\3\2\2\2\u008b\u008c\7n\2\2\u008c\u008d\7g\2\2\u008d\u008e")
buf.write("\7@\2\2\u008f\"\3\2\2\2\u0090\u0091\7f\2\2\u0091\u0092") buf.write("\7v\2\2\u008e\36\3\2\2\2\u008f\u0090\7?\2\2\u0090 \3\2")
buf.write("\7g\2\2\u0092\u0093\7h\2\2\u0093$\3\2\2\2\u0094\u0095") buf.write("\2\2\u0091\u0092\7=\2\2\u0092\"\3\2\2\2\u0093\u0094\7")
buf.write("\7<\2\2\u0095&\3\2\2\2\u0096\u0097\7V\2\2\u0097\u0098") buf.write("=\2\2\u0094\u0095\7=\2\2\u0095$\3\2\2\2\u0096\u0097\7")
buf.write("\7g\2\2\u0098\u0099\7p\2\2\u0099\u009a\7u\2\2\u009a\u009b") buf.write("h\2\2\u0097\u0098\7p\2\2\u0098&\3\2\2\2\u0099\u009a\7")
buf.write("\7q\2\2\u009b\u009c\7t\2\2\u009c(\3\2\2\2\u009d\u009e") buf.write("/\2\2\u009a\u009b\7@\2\2\u009b(\3\2\2\2\u009c\u009d\7")
buf.write("\7a\2\2\u009e*\3\2\2\2\u009f\u00a0\7o\2\2\u00a0\u00a1") buf.write("f\2\2\u009d\u009e\7g\2\2\u009e\u009f\7h\2\2\u009f*\3\2")
buf.write("\7g\2\2\u00a1\u00a2\7v\2\2\u00a2\u00a3\7c\2\2\u00a3,\3") buf.write("\2\2\u00a0\u00a1\7g\2\2\u00a1\u00a2\7z\2\2\u00a2\u00a3")
buf.write("\2\2\2\u00a4\u00a5\7x\2\2\u00a5\u00a6\7\62\2\2\u00a6\u00a7") buf.write("\7v\2\2\u00a3\u00a4\7g\2\2\u00a4\u00a5\7t\2\2\u00a5\u00a6")
buf.write("\7\60\2\2\u00a7\u00a8\7\62\2\2\u00a8\u00a9\7\60\2\2\u00a9") buf.write("\7p\2\2\u00a6,\3\2\2\2\u00a7\u00a8\7v\2\2\u00a8\u00a9")
buf.write("\u00aa\7\65\2\2\u00aa.\3\2\2\2\u00ab\u00ac\7\61\2\2\u00ac") buf.write("\7{\2\2\u00a9\u00aa\7r\2\2\u00aa\u00ab\7g\2\2\u00ab.\3")
buf.write("\u00ad\7,\2\2\u00ad\u00b2\3\2\2\2\u00ae\u00b1\5/\30\2") buf.write("\2\2\2\u00ac\u00ad\7?\2\2\u00ad\u00ae\7@\2\2\u00ae\60")
buf.write("\u00af\u00b1\13\2\2\2\u00b0\u00ae\3\2\2\2\u00b0\u00af") buf.write("\3\2\2\2\u00af\u00b0\7o\2\2\u00b0\u00b1\7c\2\2\u00b1\u00b2")
buf.write("\3\2\2\2\u00b1\u00b4\3\2\2\2\u00b2\u00b3\3\2\2\2\u00b2") buf.write("\7v\2\2\u00b2\u00b3\7e\2\2\u00b3\u00b4\7j\2\2\u00b4\62")
buf.write("\u00b0\3\2\2\2\u00b3\u00b5\3\2\2\2\u00b4\u00b2\3\2\2\2") buf.write("\3\2\2\2\u00b5\u00b6\7o\2\2\u00b6\u00b7\7c\2\2\u00b7\u00b8")
buf.write("\u00b5\u00b6\7,\2\2\u00b6\u00b7\7\61\2\2\u00b7\u00b8\3") buf.write("\7v\2\2\u00b8\u00b9\7e\2\2\u00b9\u00ba\7j\2\2\u00ba\u00bb")
buf.write("\2\2\2\u00b8\u00b9\b\30\2\2\u00b9\60\3\2\2\2\u00ba\u00bc") buf.write("\7A\2\2\u00bb\64\3\2\2\2\u00bc\u00bd\7<\2\2\u00bd\66\3")
buf.write("\t\2\2\2\u00bb\u00ba\3\2\2\2\u00bc\u00bd\3\2\2\2\u00bd") buf.write("\2\2\2\u00be\u00bf\7V\2\2\u00bf\u00c0\7g\2\2\u00c0\u00c1")
buf.write("\u00bb\3\2\2\2\u00bd\u00be\3\2\2\2\u00be\u00bf\3\2\2\2") buf.write("\7p\2\2\u00c1\u00c2\7u\2\2\u00c2\u00c3\7q\2\2\u00c3\u00c4")
buf.write("\u00bf\u00c0\b\31\2\2\u00c0\62\3\2\2\2\u00c1\u00c2\7\61") buf.write("\7t\2\2\u00c48\3\2\2\2\u00c5\u00c6\7o\2\2\u00c6\u00c7")
buf.write("\2\2\u00c2\u00c3\7\61\2\2\u00c3\u00c7\3\2\2\2\u00c4\u00c6") buf.write("\7g\2\2\u00c7\u00c8\7v\2\2\u00c8\u00c9\7c\2\2\u00c9:\3")
buf.write("\13\2\2\2\u00c5\u00c4\3\2\2\2\u00c6\u00c9\3\2\2\2\u00c7") buf.write("\2\2\2\u00ca\u00cb\7x\2\2\u00cb\u00cc\7\62\2\2\u00cc\u00cd")
buf.write("\u00c8\3\2\2\2\u00c7\u00c5\3\2\2\2\u00c8\u00ca\3\2\2\2") buf.write("\7\60\2\2\u00cd\u00ce\7\62\2\2\u00ce\u00cf\7\60\2\2\u00cf")
buf.write("\u00c9\u00c7\3\2\2\2\u00ca\u00cb\7\f\2\2\u00cb\u00cc\3") buf.write("\u00d0\7\66\2\2\u00d0<\3\2\2\2\u00d1\u00d2\7\61\2\2\u00d2")
buf.write("\2\2\2\u00cc\u00cd\b\32\2\2\u00cd\64\3\2\2\2\u00ce\u00cf") buf.write("\u00d3\7,\2\2\u00d3\u00d8\3\2\2\2\u00d4\u00d7\5=\37\2")
buf.write("\7^\2\2\u00cf\u00d0\7$\2\2\u00d0\66\3\2\2\2\u00d1\u00d6") buf.write("\u00d5\u00d7\13\2\2\2\u00d6\u00d4\3\2\2\2\u00d6\u00d5")
buf.write("\7$\2\2\u00d2\u00d5\5\65\33\2\u00d3\u00d5\n\3\2\2\u00d4") buf.write("\3\2\2\2\u00d7\u00da\3\2\2\2\u00d8\u00d9\3\2\2\2\u00d8")
buf.write("\u00d2\3\2\2\2\u00d4\u00d3\3\2\2\2\u00d5\u00d8\3\2\2\2") buf.write("\u00d6\3\2\2\2\u00d9\u00db\3\2\2\2\u00da\u00d8\3\2\2\2")
buf.write("\u00d6\u00d7\3\2\2\2\u00d6\u00d4\3\2\2\2\u00d7\u00d9\3") buf.write("\u00db\u00dc\7,\2\2\u00dc\u00dd\7\61\2\2\u00dd\u00de\3")
buf.write("\2\2\2\u00d8\u00d6\3\2\2\2\u00d9\u00da\7$\2\2\u00da8\3") buf.write("\2\2\2\u00de\u00df\b\37\2\2\u00df>\3\2\2\2\u00e0\u00e2")
buf.write("\2\2\2\u00db\u00dc\7,\2\2\u00dc:\3\2\2\2\u00dd\u00de\7") buf.write("\t\2\2\2\u00e1\u00e0\3\2\2\2\u00e2\u00e3\3\2\2\2\u00e3")
buf.write("\61\2\2\u00de<\3\2\2\2\u00df\u00e0\7-\2\2\u00e0>\3\2\2") buf.write("\u00e1\3\2\2\2\u00e3\u00e4\3\2\2\2\u00e4\u00e5\3\2\2\2")
buf.write("\2\u00e1\u00e2\7/\2\2\u00e2@\3\2\2\2\u00e3\u00e4\7>\2") buf.write("\u00e5\u00e6\b \2\2\u00e6@\3\2\2\2\u00e7\u00e8\7\61\2")
buf.write("\2\u00e4B\3\2\2\2\u00e5\u00e6\7@\2\2\u00e6D\3\2\2\2\u00e7") buf.write("\2\u00e8\u00e9\7\61\2\2\u00e9\u00ed\3\2\2\2\u00ea\u00ec")
buf.write("\u00e8\7>\2\2\u00e8\u00e9\7?\2\2\u00e9F\3\2\2\2\u00ea") buf.write("\13\2\2\2\u00eb\u00ea\3\2\2\2\u00ec\u00ef\3\2\2\2\u00ed")
buf.write("\u00eb\7@\2\2\u00eb\u00ec\7?\2\2\u00ecH\3\2\2\2\u00ed") buf.write("\u00ee\3\2\2\2\u00ed\u00eb\3\2\2\2\u00ee\u00f0\3\2\2\2")
buf.write("\u00ee\7?\2\2\u00ee\u00ef\7?\2\2\u00efJ\3\2\2\2\u00f0") buf.write("\u00ef\u00ed\3\2\2\2\u00f0\u00f1\7\f\2\2\u00f1\u00f2\3")
buf.write("\u00f1\7#\2\2\u00f1\u00f2\7?\2\2\u00f2L\3\2\2\2\u00f3") buf.write("\2\2\2\u00f2\u00f3\b!\2\2\u00f3B\3\2\2\2\u00f4\u00f5\7")
buf.write("\u00f4\7V\2\2\u00f4\u00f5\7t\2\2\u00f5\u00f6\7w\2\2\u00f6") buf.write("^\2\2\u00f5\u00f6\7$\2\2\u00f6D\3\2\2\2\u00f7\u00fc\7")
buf.write("\u00fd\7g\2\2\u00f7\u00f8\7H\2\2\u00f8\u00f9\7c\2\2\u00f9") buf.write("$\2\2\u00f8\u00fb\5C\"\2\u00f9\u00fb\n\3\2\2\u00fa\u00f8")
buf.write("\u00fa\7n\2\2\u00fa\u00fb\7u\2\2\u00fb\u00fd\7g\2\2\u00fc") buf.write("\3\2\2\2\u00fa\u00f9\3\2\2\2\u00fb\u00fe\3\2\2\2\u00fc")
buf.write("\u00f3\3\2\2\2\u00fc\u00f7\3\2\2\2\u00fdN\3\2\2\2\u00fe") buf.write("\u00fd\3\2\2\2\u00fc\u00fa\3\2\2\2\u00fd\u00ff\3\2\2\2")
buf.write("\u0101\7a\2\2\u00ff\u0101\5a\61\2\u0100\u00fe\3\2\2\2") buf.write("\u00fe\u00fc\3\2\2\2\u00ff\u0100\7$\2\2\u0100F\3\2\2\2")
buf.write("\u0100\u00ff\3\2\2\2\u0101\u0107\3\2\2\2\u0102\u0106\7") buf.write("\u0101\u0102\7,\2\2\u0102H\3\2\2\2\u0103\u0104\7\61\2")
buf.write("a\2\2\u0103\u0106\5a\61\2\u0104\u0106\5c\62\2\u0105\u0102") buf.write("\2\u0104J\3\2\2\2\u0105\u0106\7-\2\2\u0106L\3\2\2\2\u0107")
buf.write("\3\2\2\2\u0105\u0103\3\2\2\2\u0105\u0104\3\2\2\2\u0106") buf.write("\u0108\7/\2\2\u0108N\3\2\2\2\u0109\u010a\7>\2\2\u010a")
buf.write("\u0109\3\2\2\2\u0107\u0105\3\2\2\2\u0107\u0108\3\2\2\2") buf.write("P\3\2\2\2\u010b\u010c\7@\2\2\u010cR\3\2\2\2\u010d\u010e")
buf.write("\u0108\u010e\3\2\2\2\u0109\u0107\3\2\2\2\u010a\u010b\7") buf.write("\7>\2\2\u010e\u010f\7?\2\2\u010fT\3\2\2\2\u0110\u0111")
buf.write("\60\2\2\u010b\u010d\5O(\2\u010c\u010a\3\2\2\2\u010d\u0110") buf.write("\7@\2\2\u0111\u0112\7?\2\2\u0112V\3\2\2\2\u0113\u0114")
buf.write("\3\2\2\2\u010e\u010c\3\2\2\2\u010e\u010f\3\2\2\2\u010f") buf.write("\7?\2\2\u0114\u0115\7?\2\2\u0115X\3\2\2\2\u0116\u0117")
buf.write("P\3\2\2\2\u0110\u010e\3\2\2\2\u0111\u0112\7B\2\2\u0112") buf.write("\7#\2\2\u0117\u0118\7?\2\2\u0118Z\3\2\2\2\u0119\u011a")
buf.write("\u0113\5O(\2\u0113R\3\2\2\2\u0114\u0115\7\'\2\2\u0115") buf.write("\7V\2\2\u011a\u011b\7t\2\2\u011b\u011c\7w\2\2\u011c\u0123")
buf.write("\u0116\5O(\2\u0116T\3\2\2\2\u0117\u0118\7\'\2\2\u0118") buf.write("\7g\2\2\u011d\u011e\7H\2\2\u011e\u011f\7c\2\2\u011f\u0120")
buf.write("\u0119\5]/\2\u0119V\3\2\2\2\u011a\u011b\7k\2\2\u011b\u011c") buf.write("\7n\2\2\u0120\u0121\7u\2\2\u0121\u0123\7g\2\2\u0122\u0119")
buf.write("\7p\2\2\u011c\u011d\7v\2\2\u011d\u011e\78\2\2\u011e\u011f") buf.write("\3\2\2\2\u0122\u011d\3\2\2\2\u0123\\\3\2\2\2\u0124\u0127")
buf.write("\7\66\2\2\u011fX\3\2\2\2\u0120\u0123\5]/\2\u0121\u0122") buf.write("\7a\2\2\u0125\u0127\5g\64\2\u0126\u0124\3\2\2\2\u0126")
buf.write("\7\60\2\2\u0122\u0124\5]/\2\u0123\u0121\3\2\2\2\u0123") buf.write("\u0125\3\2\2\2\u0127\u012d\3\2\2\2\u0128\u012c\7a\2\2")
buf.write("\u0124\3\2\2\2\u0124\u0126\3\2\2\2\u0125\u0127\5_\60\2") buf.write("\u0129\u012c\5g\64\2\u012a\u012c\5i\65\2\u012b\u0128\3")
buf.write("\u0126\u0125\3\2\2\2\u0126\u0127\3\2\2\2\u0127Z\3\2\2") buf.write("\2\2\2\u012b\u0129\3\2\2\2\u012b\u012a\3\2\2\2\u012c\u012f")
buf.write("\2\u0128\u0129\5Y-\2\u0129\u012a\7h\2\2\u012a\\\3\2\2") buf.write("\3\2\2\2\u012d\u012b\3\2\2\2\u012d\u012e\3\2\2\2\u012e")
buf.write("\2\u012b\u012d\5c\62\2\u012c\u012b\3\2\2\2\u012d\u012e") buf.write("\u0134\3\2\2\2\u012f\u012d\3\2\2\2\u0130\u0131\7\60\2")
buf.write("\3\2\2\2\u012e\u012c\3\2\2\2\u012e\u012f\3\2\2\2\u012f") buf.write("\2\u0131\u0133\5]/\2\u0132\u0130\3\2\2\2\u0133\u0136\3")
buf.write("^\3\2\2\2\u0130\u0132\t\4\2\2\u0131\u0133\t\5\2\2\u0132") buf.write("\2\2\2\u0134\u0132\3\2\2\2\u0134\u0135\3\2\2\2\u0135^")
buf.write("\u0131\3\2\2\2\u0132\u0133\3\2\2\2\u0133\u0134\3\2\2\2") buf.write("\3\2\2\2\u0136\u0134\3\2\2\2\u0137\u013a\5c\62\2\u0138")
buf.write("\u0134\u0135\5]/\2\u0135`\3\2\2\2\u0136\u0137\t\6\2\2") buf.write("\u0139\7\60\2\2\u0139\u013b\5c\62\2\u013a\u0138\3\2\2")
buf.write("\u0137b\3\2\2\2\u0138\u0139\t\7\2\2\u0139d\3\2\2\2\u013a") buf.write("\2\u013a\u013b\3\2\2\2\u013b\u013d\3\2\2\2\u013c\u013e")
buf.write("\u013b\7O\2\2\u013b\u013c\7G\2\2\u013c\u013d\7V\2\2\u013d") buf.write("\5e\63\2\u013d\u013c\3\2\2\2\u013d\u013e\3\2\2\2\u013e")
buf.write("\u013e\7C\2\2\u013e\u013f\7F\2\2\u013f\u0140\7C\2\2\u0140") buf.write("`\3\2\2\2\u013f\u0140\5_\60\2\u0140\u0141\7h\2\2\u0141")
buf.write("\u0141\7V\2\2\u0141\u0142\7C\2\2\u0142\u0143\7<\2\2\u0143") buf.write("b\3\2\2\2\u0142\u0144\5i\65\2\u0143\u0142\3\2\2\2\u0144")
buf.write("\u0147\3\2\2\2\u0144\u0146\13\2\2\2\u0145\u0144\3\2\2") buf.write("\u0145\3\2\2\2\u0145\u0143\3\2\2\2\u0145\u0146\3\2\2\2")
buf.write("\2\u0146\u0149\3\2\2\2\u0147\u0145\3\2\2\2\u0147\u0148") buf.write("\u0146d\3\2\2\2\u0147\u0149\t\4\2\2\u0148\u014a\t\5\2")
buf.write("\3\2\2\2\u0148f\3\2\2\2\u0149\u0147\3\2\2\2\23\2\u00b0") buf.write("\2\u0149\u0148\3\2\2\2\u0149\u014a\3\2\2\2\u014a\u014b")
buf.write("\u00b2\u00bd\u00c7\u00d4\u00d6\u00fc\u0100\u0105\u0107") buf.write("\3\2\2\2\u014b\u014c\5c\62\2\u014cf\3\2\2\2\u014d\u014e")
buf.write("\u010e\u0123\u0126\u012e\u0132\u0147\3\b\2\2") buf.write("\t\6\2\2\u014eh\3\2\2\2\u014f\u0150\t\7\2\2\u0150j\3\2")
buf.write("\2\2\u0151\u0152\7O\2\2\u0152\u0153\7G\2\2\u0153\u0154")
buf.write("\7V\2\2\u0154\u0155\7C\2\2\u0155\u0156\7F\2\2\u0156\u0157")
buf.write("\7C\2\2\u0157\u0158\7V\2\2\u0158\u0159\7C\2\2\u0159\u015a")
buf.write("\7<\2\2\u015a\u015e\3\2\2\2\u015b\u015d\13\2\2\2\u015c")
buf.write("\u015b\3\2\2\2\u015d\u0160\3\2\2\2\u015e\u015c\3\2\2\2")
buf.write("\u015e\u015f\3\2\2\2\u015fl\3\2\2\2\u0160\u015e\3\2\2")
buf.write("\2\23\2\u00d6\u00d8\u00e3\u00ed\u00fa\u00fc\u0122\u0126")
buf.write("\u012b\u012d\u0134\u013a\u013d\u0145\u0149\u015e\3\b\2")
buf.write("\2")
return buf.getvalue() return buf.getvalue()
...@@ -178,62 +190,65 @@ class RelayLexer(Lexer): ...@@ -178,62 +190,65 @@ class RelayLexer(Lexer):
T__18 = 19 T__18 = 19
T__19 = 20 T__19 = 20
T__20 = 21 T__20 = 21
SEMVER = 22 T__21 = 22
COMMENT = 23 T__22 = 23
WS = 24 T__23 = 24
LINE_COMMENT = 25 T__24 = 25
QUOTED_STRING = 26 T__25 = 26
MUL = 27 T__26 = 27
DIV = 28 T__27 = 28
ADD = 29 SEMVER = 29
SUB = 30 COMMENT = 30
LT = 31 WS = 31
GT = 32 LINE_COMMENT = 32
LE = 33 QUOTED_STRING = 33
GE = 34 MUL = 34
EQ = 35 DIV = 35
NE = 36 ADD = 36
BOOL_LIT = 37 SUB = 37
CNAME = 38 LT = 38
GLOBAL_VAR = 39 GT = 39
LOCAL_VAR = 40 LE = 40
GRAPH_VAR = 41 GE = 41
DATATYPE = 42 EQ = 42
FLOAT = 43 NE = 43
NAT = 44 BOOL_LIT = 44
METADATA = 45 CNAME = 45
FLOAT = 46
NAT = 47
METADATA = 48
channelNames = [ u"DEFAULT_TOKEN_CHANNEL", u"HIDDEN" ] channelNames = [ u"DEFAULT_TOKEN_CHANNEL", u"HIDDEN" ]
modeNames = [ "DEFAULT_MODE" ] modeNames = [ "DEFAULT_MODE" ]
literalNames = [ "<INVALID>", literalNames = [ "<INVALID>",
"','", "'('", "')'", "'{'", "'}'", "'.'", "'['", "']'", "'if'", "'.'", "'@'", "'%'", "'_'", "','", "'('", "')'", "'['", "']'",
"'else'", "'let'", "'='", "';'", "';;'", "'fn'", "'->'", "'def'", "'if'", "'else'", "'{'", "'}'", "'let'", "'='", "';'", "';;'",
"':'", "'Tensor'", "'_'", "'meta'", "'v0.0.3'", "'*'", "'/'", "'fn'", "'->'", "'def'", "'extern'", "'type'", "'=>'", "'match'",
"'+'", "'-'", "'<'", "'>'", "'<='", "'>='", "'=='", "'!='", "'match?'", "':'", "'Tensor'", "'meta'", "'v0.0.4'", "'*'",
"'int64'" ] "'/'", "'+'", "'-'", "'<'", "'>'", "'<='", "'>='", "'=='", "'!='" ]
symbolicNames = [ "<INVALID>", symbolicNames = [ "<INVALID>",
"SEMVER", "COMMENT", "WS", "LINE_COMMENT", "QUOTED_STRING", "SEMVER", "COMMENT", "WS", "LINE_COMMENT", "QUOTED_STRING",
"MUL", "DIV", "ADD", "SUB", "LT", "GT", "LE", "GE", "EQ", "NE", "MUL", "DIV", "ADD", "SUB", "LT", "GT", "LE", "GE", "EQ", "NE",
"BOOL_LIT", "CNAME", "GLOBAL_VAR", "LOCAL_VAR", "GRAPH_VAR", "BOOL_LIT", "CNAME", "FLOAT", "NAT", "METADATA" ]
"DATATYPE", "FLOAT", "NAT", "METADATA" ]
ruleNames = [ "T__0", "T__1", "T__2", "T__3", "T__4", "T__5", "T__6", ruleNames = [ "T__0", "T__1", "T__2", "T__3", "T__4", "T__5", "T__6",
"T__7", "T__8", "T__9", "T__10", "T__11", "T__12", "T__13", "T__7", "T__8", "T__9", "T__10", "T__11", "T__12", "T__13",
"T__14", "T__15", "T__16", "T__17", "T__18", "T__19", "T__14", "T__15", "T__16", "T__17", "T__18", "T__19",
"T__20", "SEMVER", "COMMENT", "WS", "LINE_COMMENT", "ESCAPED_QUOTE", "T__20", "T__21", "T__22", "T__23", "T__24", "T__25",
"QUOTED_STRING", "MUL", "DIV", "ADD", "SUB", "LT", "GT", "T__26", "T__27", "SEMVER", "COMMENT", "WS", "LINE_COMMENT",
"LE", "GE", "EQ", "NE", "BOOL_LIT", "CNAME", "GLOBAL_VAR", "ESCAPED_QUOTE", "QUOTED_STRING", "MUL", "DIV", "ADD",
"LOCAL_VAR", "GRAPH_VAR", "DATATYPE", "PREFLOAT", "FLOAT", "SUB", "LT", "GT", "LE", "GE", "EQ", "NE", "BOOL_LIT",
"NAT", "EXP", "LETTER", "DIGIT", "METADATA" ] "CNAME", "PREFLOAT", "FLOAT", "NAT", "EXP", "LETTER",
"DIGIT", "METADATA" ]
grammarFileName = "Relay.g4" grammarFileName = "Relay.g4"
def __init__(self, input=None, output:TextIO = sys.stdout): def __init__(self, input=None, output:TextIO = sys.stdout):
super().__init__(input, output) super().__init__(input, output)
self.checkVersion("4.7.1") self.checkVersion("4.7.2")
self._interp = LexerATNSimulator(self, self.atn, self.decisionsToDFA, PredictionContextCache()) self._interp = LexerATNSimulator(self, self.atn, self.decisionsToDFA, PredictionContextCache())
self._actions = None self._actions = None
self._predicates = None self._predicates = None
......
This source diff could not be displayed because it is too large. You can view the blob instead.
# Generated from /workspace/python/tvm/relay/grammar/Relay.g4 by ANTLR 4.7.1 # Generated from /Users/doobs/Code/repo/sampl/tvm/python/tvm/relay/grammar/Relay.g4 by ANTLR 4.7.2
from antlr4 import * from antlr4 import *
if __name__ is not None and "." in __name__: if __name__ is not None and "." in __name__:
from .RelayParser import RelayParser from .RelayParser import RelayParser
...@@ -9,13 +9,28 @@ else: ...@@ -9,13 +9,28 @@ else:
class RelayVisitor(ParseTreeVisitor): class RelayVisitor(ParseTreeVisitor):
# Visit a parse tree produced by RelayParser#opIdent. # Visit a parse tree produced by RelayParser#prog.
def visitOpIdent(self, ctx:RelayParser.OpIdentContext): def visitProg(self, ctx:RelayParser.ProgContext):
return self.visitChildren(ctx) return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#prog. # Visit a parse tree produced by RelayParser#generalIdent.
def visitProg(self, ctx:RelayParser.ProgContext): def visitGeneralIdent(self, ctx:RelayParser.GeneralIdentContext):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#globalVar.
def visitGlobalVar(self, ctx:RelayParser.GlobalVarContext):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#localVar.
def visitLocalVar(self, ctx:RelayParser.LocalVarContext):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#graphVar.
def visitGraphVar(self, ctx:RelayParser.GraphVarContext):
return self.visitChildren(ctx) return self.visitChildren(ctx)
...@@ -44,6 +59,11 @@ class RelayVisitor(ParseTreeVisitor): ...@@ -44,6 +59,11 @@ class RelayVisitor(ParseTreeVisitor):
return self.visitChildren(ctx) return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#match.
def visitMatch(self, ctx:RelayParser.MatchContext):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#tensor. # Visit a parse tree produced by RelayParser#tensor.
def visitTensor(self, ctx:RelayParser.TensorContext): def visitTensor(self, ctx:RelayParser.TensorContext):
return self.visitChildren(ctx) return self.visitChildren(ctx)
...@@ -114,8 +134,73 @@ class RelayVisitor(ParseTreeVisitor): ...@@ -114,8 +134,73 @@ class RelayVisitor(ParseTreeVisitor):
return self.visitChildren(ctx) return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#defn. # Visit a parse tree produced by RelayParser#funcDefn.
def visitDefn(self, ctx:RelayParser.DefnContext): def visitFuncDefn(self, ctx:RelayParser.FuncDefnContext):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#externAdtDefn.
def visitExternAdtDefn(self, ctx:RelayParser.ExternAdtDefnContext):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#adtDefn.
def visitAdtDefn(self, ctx:RelayParser.AdtDefnContext):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#constructorName.
def visitConstructorName(self, ctx:RelayParser.ConstructorNameContext):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#adtConsDefnList.
def visitAdtConsDefnList(self, ctx:RelayParser.AdtConsDefnListContext):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#adtConsDefn.
def visitAdtConsDefn(self, ctx:RelayParser.AdtConsDefnContext):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#matchClauseList.
def visitMatchClauseList(self, ctx:RelayParser.MatchClauseListContext):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#matchClause.
def visitMatchClause(self, ctx:RelayParser.MatchClauseContext):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#matchType.
def visitMatchType(self, ctx:RelayParser.MatchTypeContext):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#patternList.
def visitPatternList(self, ctx:RelayParser.PatternListContext):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#pattern.
def visitPattern(self, ctx:RelayParser.PatternContext):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#adtCons.
def visitAdtCons(self, ctx:RelayParser.AdtConsContext):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#adtConsParamList.
def visitAdtConsParamList(self, ctx:RelayParser.AdtConsParamListContext):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#adtConsParam.
def visitAdtConsParam(self, ctx:RelayParser.AdtConsParamContext):
return self.visitChildren(ctx) return self.visitChildren(ctx)
...@@ -149,13 +234,13 @@ class RelayVisitor(ParseTreeVisitor): ...@@ -149,13 +234,13 @@ class RelayVisitor(ParseTreeVisitor):
return self.visitChildren(ctx) return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#typeParamList. # Visit a parse tree produced by RelayParser#tupleType.
def visitTypeParamList(self, ctx:RelayParser.TypeParamListContext): def visitTupleType(self, ctx:RelayParser.TupleTypeContext):
return self.visitChildren(ctx) return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#tupleType. # Visit a parse tree produced by RelayParser#typeCallType.
def visitTupleType(self, ctx:RelayParser.TupleTypeContext): def visitTypeCallType(self, ctx:RelayParser.TypeCallTypeContext):
return self.visitChildren(ctx) return self.visitChildren(ctx)
...@@ -179,8 +264,8 @@ class RelayVisitor(ParseTreeVisitor): ...@@ -179,8 +264,8 @@ class RelayVisitor(ParseTreeVisitor):
return self.visitChildren(ctx) return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#intType. # Visit a parse tree produced by RelayParser#typeParamList.
def visitIntType(self, ctx:RelayParser.IntTypeContext): def visitTypeParamList(self, ctx:RelayParser.TypeParamListContext):
return self.visitChildren(ctx) return self.visitChildren(ctx)
...@@ -209,11 +294,6 @@ class RelayVisitor(ParseTreeVisitor): ...@@ -209,11 +294,6 @@ class RelayVisitor(ParseTreeVisitor):
return self.visitChildren(ctx) return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#typeIdent.
def visitTypeIdent(self, ctx:RelayParser.TypeIdentContext):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#body. # Visit a parse tree produced by RelayParser#body.
def visitBody(self, ctx:RelayParser.BodyContext): def visitBody(self, ctx:RelayParser.BodyContext):
return self.visitChildren(ctx) return self.visitChildren(ctx)
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
* specific language governing permissions and limitations * specific language governing permissions and limitations
* under the License. * under the License.
*/ */
v0.0.3 v0.0.4
def @id[a](%x: a) -> a { def @id[a](%x: a) -> a {
%x %x
......
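Note on the version bump above: the text format's semantic version moves from v0.0.3 to v0.0.4, so every program handed to the parser now needs the new header. A minimal sketch of parsing the same @id definition through the Python API (relay.fromtext is the entry point the tests in this diff use):

from tvm import relay

# Illustrative standalone snippet: parse a v0.0.4 program containing the
# polymorphic identity function from the grammar test above.
mod = relay.fromtext("v0.0.4\ndef @id[a](%x: a) -> a { %x }")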
...@@ -70,7 +70,10 @@ class AlphaEqualHandler: ...@@ -70,7 +70,10 @@ class AlphaEqualHandler:
} }
if (lhsm->type_definitions.size() != rhsm->type_definitions.size()) return false; if (lhsm->type_definitions.size() != rhsm->type_definitions.size()) return false;
for (const auto& p : lhsm->type_definitions) { for (const auto& p : lhsm->type_definitions) {
if (!Equal(p.second, rhsm->LookupDef(p.first->var->name_hint))) return false; if (!rhsm->HasDef(p.first->var->name_hint) ||
!Equal(p.second, rhsm->LookupDef(p.first->var->name_hint))) {
return false;
}
} }
return true; return true;
} }
...@@ -288,7 +291,7 @@ class AlphaEqualHandler: ...@@ -288,7 +291,7 @@ class AlphaEqualHandler:
} }
bool VisitType_(const GlobalTypeVarNode* lhs, const Type& other) final { bool VisitType_(const GlobalTypeVarNode* lhs, const Type& other) final {
return GetRef<Type>(lhs) == other; return LeafNodeEqual(GetRef<NodeRef>(lhs), other);
} }
bool VisitType_(const TypeCallNode* lhs, const Type& other) final { bool VisitType_(const TypeCallNode* lhs, const Type& other) final {
...@@ -307,6 +310,26 @@ class AlphaEqualHandler: ...@@ -307,6 +310,26 @@ class AlphaEqualHandler:
return true; return true;
} }
bool VisitType_(const TypeDataNode* lhs, const Type& other) final {
const TypeDataNode* rhs = other.as<TypeDataNode>();
if (rhs == nullptr
|| lhs->type_vars.size() != rhs->type_vars.size()
|| !TypeEqual(lhs->header, rhs->header)) {
return false;
}
for (size_t i = 0; i < lhs->type_vars.size(); ++i) {
if (!TypeEqual(lhs->type_vars[i], rhs->type_vars[i])) {
return false;
}
}
for (size_t i = 0; i < lhs->constructors.size(); ++i) {
if (!ExprEqual(lhs->constructors[i], rhs->constructors[i])) {
return false;
}
}
return true;
}
// Expr equal checking. // Expr equal checking.
bool NDArrayEqual(const runtime::NDArray& lhs, bool NDArrayEqual(const runtime::NDArray& lhs,
const runtime::NDArray& rhs) { const runtime::NDArray& rhs) {
...@@ -485,7 +508,10 @@ class AlphaEqualHandler: ...@@ -485,7 +508,10 @@ class AlphaEqualHandler:
} }
bool VisitExpr_(const ConstructorNode* lhs, const Expr& other) final { bool VisitExpr_(const ConstructorNode* lhs, const Expr& other) final {
return GetRef<Expr>(lhs) == other; if (const ConstructorNode* rhs = other.as<ConstructorNode>()) {
return lhs->name_hint == rhs->name_hint;
}
return false;
} }
bool ClauseEqual(const Clause& lhs, const Clause& rhs) { bool ClauseEqual(const Clause& lhs, const Clause& rhs) {
...@@ -582,7 +608,7 @@ TVM_REGISTER_API("relay._make._alpha_equal") ...@@ -582,7 +608,7 @@ TVM_REGISTER_API("relay._make._alpha_equal")
TVM_REGISTER_API("relay._make._assert_alpha_equal") TVM_REGISTER_API("relay._make._assert_alpha_equal")
.set_body_typed<void(NodeRef, NodeRef)>([](NodeRef a, NodeRef b) { .set_body_typed<void(NodeRef, NodeRef)>([](NodeRef a, NodeRef b) {
bool alpha_equal = AlphaEqualHandler(false, true).Equal(a, b); bool alpha_equal = AlphaEqualHandler(false, true).Equal(a, b);
CHECK(alpha_equal) << AsText(a, true) << " and " << AsText(b, true) << " is not alpha equal"; CHECK(alpha_equal) << AsText(a, true) << " and " << AsText(b, true) << " are not alpha equal";
}); });
TVM_REGISTER_API("relay._make._graph_equal") TVM_REGISTER_API("relay._make._graph_equal")
...@@ -593,7 +619,7 @@ TVM_REGISTER_API("relay._make._graph_equal") ...@@ -593,7 +619,7 @@ TVM_REGISTER_API("relay._make._graph_equal")
TVM_REGISTER_API("relay._make._assert_graph_equal") TVM_REGISTER_API("relay._make._assert_graph_equal")
.set_body_typed<void(NodeRef, NodeRef)>([](NodeRef a, NodeRef b) { .set_body_typed<void(NodeRef, NodeRef)>([](NodeRef a, NodeRef b) {
bool graph_equal = AlphaEqualHandler(true, true).Equal(a, b); bool graph_equal = AlphaEqualHandler(true, true).Equal(a, b);
CHECK(graph_equal) << AsText(a, true) << " and " << AsText(b, true) << " is not graph equal"; CHECK(graph_equal) << AsText(a, true) << " and " << AsText(b, true) << " are not graph equal";
}); });
} // namespace relay } // namespace relay
......
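The new ConstructorNode and TypeDataNode cases above compare constructors by name_hint and type definitions field by field, which is what lets a module parsed from text be considered equal to a module built by hand through the Python API. A minimal sketch of that behaviour, mirroring the parses_as helper in the parser tests later in this diff (the Option ADT here is illustrative only):

from tvm import relay
from tvm.relay.analysis import graph_equal

# Hand-built module defining Option[A] with two constructors.
mod = relay.Module()
opt = relay.GlobalTypeVar("Option")
a = relay.TypeVar("A")
some = relay.Constructor("Some", [a], opt)
none = relay.Constructor("None", [], opt)
mod[opt] = relay.TypeData(opt, [a], [some, none])

# The same ADT parsed from the v0.0.4 text format.
parsed = relay.fromtext("v0.0.4\n" + """
type Option[A] {
  Some(A),
  None,
}
""")

# Constructors now compare by name and TypeData nodes structurally, so the
# two modules should be graph-equal even though they were built separately.
assert graph_equal(parsed, mod)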
...@@ -206,6 +206,11 @@ TypeData ModuleNode::LookupDef(const std::string& name) const { ...@@ -206,6 +206,11 @@ TypeData ModuleNode::LookupDef(const std::string& name) const {
return this->LookupDef(id); return this->LookupDef(id);
} }
bool ModuleNode::HasDef(const std::string& name) const {
auto it = global_type_var_map_.find(name);
return it != global_type_var_map_.end();
}
Constructor ModuleNode::LookupTag(const int32_t tag) { Constructor ModuleNode::LookupTag(const int32_t tag) {
auto it = constructor_tag_map_.find(tag); auto it = constructor_tag_map_.find(tag);
CHECK(it != constructor_tag_map_.end()) CHECK(it != constructor_tag_map_.end())
......
...@@ -44,6 +44,8 @@ ...@@ -44,6 +44,8 @@
namespace tvm { namespace tvm {
namespace relay { namespace relay {
static const char* kSemVer = "v0.0.4";
Doc Brace(const Doc& d, Doc Brace(const Doc& d,
const std::string& open = "{", const std::string& open = "{",
const std::string& close = "}", const std::string& close = "}",
...@@ -239,6 +241,8 @@ class PrettyPrinter : ...@@ -239,6 +241,8 @@ class PrettyPrinter :
return PrintExpr(Downcast<Expr>(node), meta, try_inline); return PrintExpr(Downcast<Expr>(node), meta, try_inline);
} else if (node.as_derived<TypeNode>()) { } else if (node.as_derived<TypeNode>()) {
return PrintType(Downcast<Type>(node), meta); return PrintType(Downcast<Type>(node), meta);
} else if (node.as_derived<PatternNode>()) {
return PrintPattern(Downcast<Pattern>(node), meta);
} else if (node.as_derived<ModuleNode>()) { } else if (node.as_derived<ModuleNode>()) {
return PrintMod(Downcast<Module>(node)); return PrintMod(Downcast<Module>(node));
} else { } else {
...@@ -313,7 +317,7 @@ class PrettyPrinter : ...@@ -313,7 +317,7 @@ class PrettyPrinter :
if (name.length() == 0 || !std::isalpha(name[0])) { if (name.length() == 0 || !std::isalpha(name[0])) {
name = "t" + name; name = "t" + name;
} }
Doc val = GetUniqueName("%" + name); Doc val = GetUniqueName(name);
memo_type_[var] = val; memo_type_[var] = val;
if (var->kind != kType) { if (var->kind != kType) {
val << ": " << Print(var->kind); val << ": " << Print(var->kind);
...@@ -347,13 +351,17 @@ class PrettyPrinter : ...@@ -347,13 +351,17 @@ class PrettyPrinter :
} }
bool IsUnique(const Expr& expr) { bool IsUnique(const Expr& expr) {
return !(dg_.expr_node.at(expr)->parents.head && auto it = dg_.expr_node.find(expr);
dg_.expr_node.at(expr)->parents.head->next); if (it == dg_.expr_node.end()) {
return true;
} else {
return !(it->second->parents.head && it->second->parents.head->next);
}
} }
bool AlwaysInline(const Expr& expr) { bool AlwaysInline(const Expr& expr) {
return expr.as<GlobalVarNode>() || expr.as<ConstantNode>() || return expr.as<GlobalVarNode>() || expr.as<ConstantNode>() || expr.as<OpNode>() ||
expr.as<OpNode>() || expr.as<VarNode>(); expr.as<VarNode>() || expr.as<ConstructorNode>();
} }
//------------------------------------ //------------------------------------
...@@ -380,9 +388,9 @@ class PrettyPrinter : ...@@ -380,9 +388,9 @@ class PrettyPrinter :
} else if (!inline_expr && expr.as<LetNode>()) { } else if (!inline_expr && expr.as<LetNode>()) {
// wrap GNFed let in brackets // wrap GNFed let in brackets
Doc body; Doc body;
printed_expr << "{"; printed_expr << "(";
printed_expr << Indent(2, body << PrintNewLine() << VisitExpr(expr)) << PrintNewLine(); printed_expr << Indent(2, body << PrintNewLine() << VisitExpr(expr)) << PrintNewLine();
printed_expr << "}"; printed_expr << ")";
} else { } else {
printed_expr = VisitExpr(expr); printed_expr = VisitExpr(expr);
} }
...@@ -483,13 +491,13 @@ class PrettyPrinter : ...@@ -483,13 +491,13 @@ class PrettyPrinter :
Doc doc; Doc doc;
doc << prefix; doc << prefix;
if (fn->type_params.size() > 0) { if (fn->type_params.size() > 0) {
doc << "<"; doc << "[";
std::vector<Doc> type_params; std::vector<Doc> type_params;
for (const TypeVar& tv : fn->type_params) { for (const TypeVar& tv : fn->type_params) {
type_params.push_back(AllocTypeVar(tv)); type_params.push_back(Doc(tv->var->name_hint));
} }
doc << PrintSep(type_params); doc << PrintSep(type_params);
doc << ">"; doc << "]";
} }
doc << "("; doc << "(";
std::vector<Doc> params; std::vector<Doc> params;
...@@ -510,6 +518,15 @@ class PrettyPrinter : ...@@ -510,6 +518,15 @@ class PrettyPrinter :
Doc PrintMod(const Module& mod) { Doc PrintMod(const Module& mod) {
Doc doc; Doc doc;
int counter = 0; int counter = 0;
// type definitions
for (const auto& kv : mod->type_definitions) {
if (counter++ != 0) {
doc << PrintNewLine();
}
doc << Print(kv.second);
doc << PrintNewLine();
}
// functions
for (const auto& kv : mod->functions) { for (const auto& kv : mod->functions) {
dg_ = DependencyGraph::Create(&arena_, kv.second); dg_ = DependencyGraph::Create(&arena_, kv.second);
...@@ -547,7 +564,12 @@ class PrettyPrinter : ...@@ -547,7 +564,12 @@ class PrettyPrinter :
for (const Doc& d : PrintCallAttrs(op->attrs, op->op)) { for (const Doc& d : PrintCallAttrs(op->attrs, op->op)) {
args.push_back(d); args.push_back(d);
} }
const auto* cons_node = op->op.as<ConstructorNode>();
if (cons_node) {
doc << cons_node->name_hint;
} else {
doc << Print(op->op); doc << Print(op->op);
}
return doc << "(" << PrintSep(args) << ")"; return doc << "(" << PrintSep(args) << ")";
} }
...@@ -570,27 +592,57 @@ class PrettyPrinter : ...@@ -570,27 +592,57 @@ class PrettyPrinter :
// TODO(jmp): Lots of code duplication here because PrintBody and PrintScope don't accept Docs. // TODO(jmp): Lots of code duplication here because PrintBody and PrintScope don't accept Docs.
Doc doc; Doc doc;
Doc body; Doc body;
doc << "match " << Print(op->data) << " "; doc << "match";
doc << "{"; if (!op->complete) {
std::vector<Doc> clauses; doc << "?";
}
doc << " (" << Print(op->data) << ") {";
std::vector<Doc> clause_docs;
for (const auto& clause : op->clauses) { for (const auto& clause : op->clauses) {
Doc clause_doc; Doc clause_doc;
clauses.push_back(clause_doc << Print(clause->lhs) << " -> " clause_doc << PrintPattern(clause->lhs, false) << " => ";
<< Print(clause->rhs)); Doc rhs_doc = PrintScope(clause->rhs);
if (clause->rhs.as<LetNode>()) {
// only add braces if there are multiple lines on the rhs
rhs_doc = Brace(rhs_doc);
} }
doc << Indent(2, body << PrintNewLine() << PrintSep(clauses, PrintNewLine())) << PrintNewLine(); clause_doc << rhs_doc << ",";
doc << "}"; clause_docs.push_back(clause_doc);
}
doc << Indent(2, body << PrintNewLine() << PrintSep(clause_docs, PrintNewLine()))
<< PrintNewLine() << "}";
return doc; return doc;
} }
Doc PrintPattern(const Pattern& pattern, bool meta) {
auto it = memo_pattern_.find(pattern);
if (it != memo_pattern_.end()) return it->second;
Doc printed_pattern;
if (meta) {
printed_pattern = meta_.GetMetaNode(GetRef<NodeRef>(pattern.get()));
} else {
printed_pattern = VisitPattern(pattern);
}
memo_pattern_[pattern] = printed_pattern;
return printed_pattern;
}
Doc VisitPattern_(const PatternConstructorNode* p) final { Doc VisitPattern_(const PatternConstructorNode* p) final {
Doc doc; Doc doc;
doc << p->constructor->name_hint << "("; doc << p->constructor->name_hint;
if (!p->patterns.empty()) {
doc << "(";
std::vector<Doc> pats; std::vector<Doc> pats;
for (const auto& pat : p->patterns) { for (const auto& pat : p->patterns) {
pats.push_back(Print(pat)); pats.push_back(Print(pat));
} }
return doc << PrintSep(pats) << ")"; doc << PrintSep(pats) << ")";
}
return doc;
}
Doc VisitPattern_(const PatternWildcardNode* pw) final {
return Doc("_");
} }
Doc VisitPattern_(const PatternVarNode* pv) final { Doc VisitPattern_(const PatternVarNode* pv) final {
...@@ -598,7 +650,17 @@ class PrettyPrinter : ...@@ -598,7 +650,17 @@ class PrettyPrinter :
} }
Doc VisitExpr_(const ConstructorNode* n) final { Doc VisitExpr_(const ConstructorNode* n) final {
return Doc(n->name_hint); Doc doc;
doc << n->name_hint;
if (n->inputs.size() != 0) {
doc << "(";
std::vector<Doc> inputs;
for (Type input : n->inputs) {
inputs.push_back(Print(input));
}
doc << PrintSep(inputs) << ")";
}
return doc;
} }
//------------------------------------ //------------------------------------
...@@ -623,7 +685,7 @@ class PrettyPrinter : ...@@ -623,7 +685,7 @@ class PrettyPrinter :
} }
Doc VisitType_(const TypeVarNode* node) final { Doc VisitType_(const TypeVarNode* node) final {
return AllocTypeVar(GetRef<TypeVar>(node)); return Doc(node->var->name_hint);
} }
Doc VisitType_(const GlobalTypeVarNode* node) final { Doc VisitType_(const GlobalTypeVarNode* node) final {
...@@ -675,13 +737,13 @@ class PrettyPrinter : ...@@ -675,13 +737,13 @@ class PrettyPrinter :
Doc doc; Doc doc;
doc << "fn "; doc << "fn ";
if (node->type_params.size() != 0) { if (node->type_params.size() != 0) {
doc << "<"; doc << "[";
std::vector<Doc> type_params; std::vector<Doc> type_params;
for (Type type_param : node->type_params) { for (Type type_param : node->type_params) {
type_params.push_back(Print(type_param)); type_params.push_back(Print(type_param));
} }
doc << PrintSep(type_params); doc << PrintSep(type_params);
doc << ">"; doc << "]";
} }
std::vector<Doc> arg_types; std::vector<Doc> arg_types;
for (Type arg_type : node->arg_types) { for (Type arg_type : node->arg_types) {
...@@ -695,6 +757,37 @@ class PrettyPrinter : ...@@ -695,6 +757,37 @@ class PrettyPrinter :
return doc << "ref(" << Print(node->value) << ")"; return doc << "ref(" << Print(node->value) << ")";
} }
Doc VisitType_(const TypeDataNode* node) final {
Doc doc;
doc << "type " << Print(node->header);
// type vars
if (node->type_vars.size() != 0) {
doc << "[";
std::vector<Doc> type_vars;
for (Type type_var : node->type_vars) {
type_vars.push_back(Print(type_var));
}
doc << PrintSep(type_vars) << "]";
}
doc << " ";
std::vector<Doc> constructor_docs;
for (Constructor constructor : node->constructors) {
constructor_docs.push_back(Print(constructor, /* meta */ false, /* try_inline */ true));
}
Doc separator;
separator << "," << PrintNewLine();
Doc adt_body;
adt_body << PrintSep(constructor_docs, separator);
// add trailing comma if there are any constructors
if (!constructor_docs.empty()) {
adt_body << ",";
}
doc << Brace(adt_body);
return doc;
}
//------------------------------------ //------------------------------------
// Overload of Attr printing functions // Overload of Attr printing functions
//------------------------------------ //------------------------------------
...@@ -758,6 +851,8 @@ class PrettyPrinter : ...@@ -758,6 +851,8 @@ class PrettyPrinter :
std::unordered_map<Expr, Doc, NodeHash, NodeEqual> memo_; std::unordered_map<Expr, Doc, NodeHash, NodeEqual> memo_;
/*! \brief Map from Type to Doc */ /*! \brief Map from Type to Doc */
std::unordered_map<Type, Doc, NodeHash, NodeEqual> memo_type_; std::unordered_map<Type, Doc, NodeHash, NodeEqual> memo_type_;
  /*! \brief Map from Pattern to Doc */
std::unordered_map<Pattern, Doc, NodeHash, NodeEqual> memo_pattern_;
/*! \brief name allocation map */ /*! \brief name allocation map */
std::unordered_map<std::string, int> name_alloc_map_; std::unordered_map<std::string, int> name_alloc_map_;
/*! \brief meta data context */ /*! \brief meta data context */
...@@ -861,7 +956,7 @@ std::string PrettyPrint_(const NodeRef& node, ...@@ -861,7 +956,7 @@ std::string PrettyPrint_(const NodeRef& node,
bool show_meta_data, bool show_meta_data,
runtime::TypedPackedFunc<std::string(Expr)> annotate) { runtime::TypedPackedFunc<std::string(Expr)> annotate) {
Doc doc; Doc doc;
doc << "v0.0.3" << PrintNewLine() doc << kSemVer << PrintNewLine()
<< PrettyPrinter(show_meta_data, annotate).PrintFinal(node); << PrettyPrinter(show_meta_data, annotate).PrintFinal(node);
return doc.str(); return doc.str();
} }
......
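Taken together, the printer changes above (the TypeDataNode case, Rust-style [...] type-parameter lists, constructor and match printing, and the kSemVer header) are what make printed modules containing ADTs parse back in. A short sketch of that round trip, reusing the List definition from the parser tests below:

from tvm import relay

code = "v0.0.4\n" + """
type List[A] {
  Cons(A, List[A]),
  Nil,
}
"""
mod = relay.fromtext(code)

# astext() now emits the "v0.0.4" header and a Rust-style
# "type List[A] { Cons(A, List[A]), Nil, }" block for the type definition.
printed = mod.astext()

# The printed text should parse back into an equivalent module, which is
# what the roundtrip() helper in test_ir_parser.py asserts.
reparsed = relay.fromtext(printed)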
...@@ -774,7 +774,6 @@ class TypeInferencer::Resolver : public ExprMutator, PatternMutator { ...@@ -774,7 +774,6 @@ class TypeInferencer::Resolver : public ExprMutator, PatternMutator {
bool update_missing_type_annotation_{true}; bool update_missing_type_annotation_{true};
}; };
Expr TypeInferencer::Infer(Expr expr) { Expr TypeInferencer::Infer(Expr expr) {
// Step 1: Populate the constraints. // Step 1: Populate the constraints.
GetType(expr); GetType(expr);
......
...@@ -16,14 +16,14 @@ ...@@ -16,14 +16,14 @@
# under the License. # under the License.
import tvm import tvm
from tvm import relay from tvm import relay
from tvm.relay.analysis import alpha_equal, assert_alpha_equal from tvm.relay.analysis import graph_equal, assert_graph_equal
from nose.tools import nottest, raises from nose.tools import nottest, raises
from numpy import isclose from numpy import isclose
from typing import Union from typing import Union
from functools import wraps from functools import wraps
raises_parse_error = raises(tvm._ffi.base.TVMError) raises_parse_error = raises(tvm._ffi.base.TVMError)
SEMVER = "v0.0.3" SEMVER = "v0.0.4"
BINARY_OPS = { BINARY_OPS = {
"*": relay.multiply, "*": relay.multiply,
...@@ -60,20 +60,29 @@ TYPES = { ...@@ -60,20 +60,29 @@ TYPES = {
"float16x4", "float16x4",
} }
LIST_DEFN = """
type List[A] {
Cons(A, List[A]),
Nil,
}
"""
def roundtrip(expr): def roundtrip(expr):
x = relay.fromtext(str(expr)) x = relay.fromtext(str(expr))
assert_alpha_equal(x, expr) assert_graph_equal(x, expr)
def parse_text(code): def parse_text(code):
x = relay.fromtext(SEMVER + "\n" + code) expr = relay.fromtext(SEMVER + "\n" + code)
roundtrip(x) roundtrip(expr)
return x return expr
def parses_as(code, expr): def parses_as(code, expr):
# type: (str, relay.Expr) -> bool # type: (str, relay.Expr) -> bool
return alpha_equal(parse_text(code), expr) parsed = parse_text(code)
result = graph_equal(parsed, expr)
return result
def get_scalar(x): def get_scalar(x):
# type: (relay.Constant) -> (Union[float, int, bool]) # type: (relay.Constant) -> (Union[float, int, bool])
...@@ -168,13 +177,13 @@ def test_bin_op(): ...@@ -168,13 +177,13 @@ def test_bin_op():
def test_parens(): def test_parens():
assert alpha_equal(parse_text("1 * 1 + 1"), parse_text("(1 * 1) + 1")) assert graph_equal(parse_text("1 * 1 + 1"), parse_text("(1 * 1) + 1"))
assert not alpha_equal(parse_text("1 * 1 + 1"), parse_text("1 * (1 + 1)")) assert not graph_equal(parse_text("1 * 1 + 1"), parse_text("1 * (1 + 1)"))
def test_op_assoc(): def test_op_assoc():
assert alpha_equal(parse_text("1 * 1 + 1 < 1 == 1"), parse_text("(((1 * 1) + 1) < 1) == 1")) assert graph_equal(parse_text("1 * 1 + 1 < 1 == 1"), parse_text("(((1 * 1) + 1) < 1) == 1"))
assert alpha_equal(parse_text("1 == 1 < 1 + 1 * 1"), parse_text("1 == (1 < (1 + (1 * 1)))")) assert graph_equal(parse_text("1 == 1 < 1 + 1 * 1"), parse_text("1 == (1 < (1 + (1 * 1)))"))
@nottest @nottest
...@@ -239,7 +248,7 @@ def test_seq(): ...@@ -239,7 +248,7 @@ def test_seq():
) )
assert parses_as( assert parses_as(
"let %_ = { 1 }; ()", "let %_ = 1; ()",
relay.Let( relay.Let(
X, X,
relay.const(1), relay.const(1),
...@@ -249,13 +258,13 @@ def test_seq(): ...@@ -249,13 +258,13 @@ def test_seq():
def test_graph(): def test_graph():
code = "%0 = (); %1 = 1; (%0, %0, %1)"
assert parses_as( assert parses_as(
"%0 = (); %1 = 1; (%0, %0, %1)", code,
relay.Tuple([UNIT, UNIT, relay.const(1)]) relay.Tuple([UNIT, UNIT, relay.const(1)])
) )
assert not parses_as( assert not parses_as(
"%0 = (); %1 = 1; (%0, %0, %1)", code,
relay.Tuple([relay.Tuple([]), relay.Tuple([]), relay.const(1)]) relay.Tuple([relay.Tuple([]), relay.Tuple([]), relay.const(1)])
) )
...@@ -632,6 +641,236 @@ def test_tuple_type(): ...@@ -632,6 +641,236 @@ def test_tuple_type():
) )
) )
def test_adt_defn():
mod = relay.Module()
glob_typ_var = relay.GlobalTypeVar("Ayy")
prog = relay.TypeData(
glob_typ_var,
[],
[relay.Constructor("Nil", [], glob_typ_var)])
mod[glob_typ_var] = prog
assert parses_as(
"""
type Ayy { Nil }
""",
mod
)
def test_empty_adt_defn():
mod = relay.Module()
glob_typ_var = relay.GlobalTypeVar("Ayy")
prog = relay.TypeData(glob_typ_var, [], [])
mod[glob_typ_var] = prog
assert parses_as(
"""
type Ayy { }
""",
mod
)
def test_multiple_cons_defn():
mod = relay.Module()
list_var = relay.GlobalTypeVar("List")
typ_var = relay.TypeVar("A")
prog = relay.TypeData(
list_var,
[typ_var],
[
relay.Constructor("Cons", [typ_var, list_var(typ_var)], list_var),
relay.Constructor("Nil", [], list_var),
])
mod[list_var] = prog
assert parses_as(LIST_DEFN, mod)
def test_multiple_type_param_defn():
glob_typ_var = relay.GlobalTypeVar("Either")
typ_var_a = relay.TypeVar("A")
typ_var_b = relay.TypeVar("B")
prog = relay.TypeData(
glob_typ_var,
[typ_var_a, typ_var_b],
[
relay.Constructor("Left", [typ_var_a], glob_typ_var),
relay.Constructor("Right", [typ_var_b], glob_typ_var),
])
mod = relay.Module()
mod[glob_typ_var] = prog
assert parses_as(
"""
type Either[A, B] {
Left(A),
Right(B),
}
""",
mod
)
def test_match():
# pair each match keyword with whether it specifies a complete match or not
match_keywords = [("match", True), ("match?", False)]
for (match_keyword, is_complete) in match_keywords:
mod = relay.Module()
list_var = relay.GlobalTypeVar("List")
typ_var = relay.TypeVar("A")
cons_constructor = relay.Constructor(
"Cons", [typ_var, list_var(typ_var)], list_var)
nil_constructor = relay.Constructor("Nil", [], list_var)
list_def = relay.TypeData(
list_var,
[typ_var],
[cons_constructor, nil_constructor])
mod[list_var] = list_def
length_var = relay.GlobalVar("length")
typ_var = relay.TypeVar("A")
input_type = list_var(typ_var)
input_var = relay.Var("xs", input_type)
rest_var = relay.Var("rest")
cons_case = relay.Let(
_,
UNIT,
relay.add(relay.const(1), relay.Call(length_var, [rest_var])))
body = relay.Match(input_var,
[relay.Clause(
relay.PatternConstructor(
cons_constructor,
[relay.PatternWildcard(), relay.PatternVar(rest_var)]),
cons_case),
relay.Clause(
relay.PatternConstructor(nil_constructor, []),
relay.const(0))],
complete=is_complete
)
length_func = relay.Function(
[input_var],
body,
int32,
[typ_var]
)
mod[length_var] = length_func
assert parses_as(
"""
%s
def @length[A](%%xs: List[A]) -> int32 {
%s (%%xs) {
Cons(_, %%rest) => {
();;
1 + @length(%%rest)
},
Nil => 0,
}
}
""" % (LIST_DEFN, match_keyword),
mod
)
def test_adt_cons_expr():
mod = relay.Module()
list_var = relay.GlobalTypeVar("List")
typ_var = relay.TypeVar("A")
cons_constructor = relay.Constructor(
"Cons", [typ_var, list_var(typ_var)], list_var)
nil_constructor = relay.Constructor("Nil", [], list_var)
list_def = relay.TypeData(
list_var,
[typ_var],
[cons_constructor, nil_constructor])
mod[list_var] = list_def
make_singleton_var = relay.GlobalVar("make_singleton")
input_var = relay.Var("x", int32)
make_singleton_func = relay.Function(
[input_var],
cons_constructor(input_var, nil_constructor()),
list_var(int32)
)
mod[make_singleton_var] = make_singleton_func
assert parses_as(
"""
%s
def @make_singleton(%%x: int32) -> List[int32] {
Cons(%%x, Nil())
}
""" % LIST_DEFN,
mod
)
@raises_parse_error
def test_duplicate_adt_defn():
parse_text(
"""
%s
type List[A] {
Cons(A, List[A]),
Nil,
}
""" % LIST_DEFN
)
@raises_parse_error
def test_duplicate_adt_cons():
parse_text(
"""
type Ayy { Lmao }
type Haha { Lmao }
"""
)
@raises_parse_error
def test_duplicate_adt_cons_defn():
parse_text(
"""
type Ayy { Lmao }
type Lmao { Ayy }
"""
)
@raises_parse_error
def test_duplicate_global_var():
parse_text(
"""
def @id[A](%x: A) -> A { x }
def @id[A](%x: A) -> A { x }
"""
)
def test_extern_adt_defn():
# TODO(weberlo): update this test once extern is implemented
mod = relay.Module()
extern_var = relay.GlobalTypeVar("T")
typ_var = relay.TypeVar("A")
extern_def = relay.TypeData(extern_var, [typ_var], [])
mod[extern_var] = extern_def
assert parses_as(
"""
extern type T[A]
""",
mod
)
if __name__ == "__main__": if __name__ == "__main__":
test_comments() test_comments()
test_int_literal() test_int_literal()
...@@ -655,3 +894,14 @@ if __name__ == "__main__": ...@@ -655,3 +894,14 @@ if __name__ == "__main__":
test_tensor_type() test_tensor_type()
test_function_type() test_function_type()
test_tuple_type() test_tuple_type()
test_adt_defn()
test_empty_adt_defn()
test_multiple_cons_defn()
test_multiple_type_param_defn()
test_match()
test_adt_cons_expr()
test_duplicate_adt_defn()
test_duplicate_adt_cons()
test_duplicate_adt_cons_defn()
test_duplicate_global_var()
test_extern_adt_defn()
...@@ -23,14 +23,14 @@ from tvm.relay.analysis import alpha_equal, assert_alpha_equal, assert_graph_equ ...@@ -23,14 +23,14 @@ from tvm.relay.analysis import alpha_equal, assert_alpha_equal, assert_graph_equ
do_print = [False] do_print = [False]
SEMVER = "v0.0.3\n" SEMVER = "v0.0.4\n"
def astext(p, graph_equal=False): def astext(p, unify_free_vars=False):
txt = p.astext() txt = p.astext()
if isinstance(p, Expr) and free_vars(p): if isinstance(p, Expr) and free_vars(p):
return txt return txt
x = relay.fromtext(txt) x = relay.fromtext(txt)
if graph_equal: if unify_free_vars:
assert_graph_equal(x, p) assert_graph_equal(x, p)
else: else:
assert_alpha_equal(x, p) assert_alpha_equal(x, p)
...@@ -78,7 +78,7 @@ def test_meta_data(): ...@@ -78,7 +78,7 @@ def test_meta_data():
padding=(1, 1), padding=(1, 1),
channels=2) channels=2)
f = relay.Function([x, w], z) f = relay.Function([x, w], z)
text = astext(f, graph_equal=True) text = astext(f, unify_free_vars=True)
text_no_meta = str(f) text_no_meta = str(f)
assert "channels=2" in text assert "channels=2" in text
assert "channels=2" in text_no_meta assert "channels=2" in text_no_meta
...@@ -122,7 +122,7 @@ def test_let_if_scope(): ...@@ -122,7 +122,7 @@ def test_let_if_scope():
f = relay.Function([x, y, cond], result) f = relay.Function([x, y, cond], result)
text = astext(f) text = astext(f)
assert text.count("{") == 4 assert text.count("{") == 3
assert "%cond: bool" in text assert "%cond: bool" in text
show(astext(f)) show(astext(f))
...@@ -218,14 +218,6 @@ def test_zeros(): ...@@ -218,14 +218,6 @@ def test_zeros():
x = relay.op.zeros([], "float32") x = relay.op.zeros([], "float32")
astext(x) astext(x)
def test_cast():
data = relay.var('data', dtype='float32')
fp16_cast = relay.cast(data, dtype='float16')
cast_func = relay.Function(relay.analysis.free_vars(fp16_cast), fp16_cast)
astext(cast_func)
if __name__ == "__main__": if __name__ == "__main__":
do_print[0] = True do_print[0] = True
test_lstm() test_lstm()
...@@ -247,4 +239,3 @@ if __name__ == "__main__": ...@@ -247,4 +239,3 @@ if __name__ == "__main__":
test_let_if_scope() test_let_if_scope()
test_variable_name() test_variable_name()
test_call_node_order() test_call_node_order()
test_cast()
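One consequence of the printer switching from braces to parentheses for graph-normal-form lets is visible in test_let_if_scope above: the expected "{" count drops from 4 to 3 because a let value no longer contributes a brace of its own. A small sketch of that effect (the exact count asserted here is an assumption based on the printer changes in this diff, not an existing test):

from tvm import relay

x = relay.var("x")
let_expr = relay.Let(x, relay.const(1), x)
f = relay.Function([], let_expr)
text = f.astext()

# Only the function body should contribute curly braces; the let itself is
# printed in scope form or wrapped in parentheses, not braces.
assert text.count("{") == 1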