Commit 2973f8a6 by 雾雨魔理沙 Committed by Tianqi Chen

[Relay] parser/pretty printer roundtripping (#3536)

parent e5efc632
......@@ -15,14 +15,14 @@
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, unused-import
# pylint: disable=invalid-name, unused-argument
"""A parser for Relay's text format."""
from __future__ import absolute_import
import sys
from ast import literal_eval
from collections import deque
from typing import TypeVar, Deque, Tuple, Optional, Union, NamedTuple, List, Callable, Any, Dict
import tvm
......@@ -32,31 +32,77 @@ from . import expr
from . import ty
from . import op
class ParseError(Exception):
"""Exception type for parse errors."""
def __init__(self, message):
# type: (str) -> None
super(ParseError, self).__init__()
self.message = message
PYTHON_VERSION = sys.version_info.major
try:
from .grammar.py3.RelayVisitor import RelayVisitor
from .grammar.py3.RelayParser import RelayParser
from .grammar.py3.RelayLexer import RelayLexer
except ImportError:
raise ParseError("Couldn't find ANTLR parser. Try building with USE_ANTLR=ON.")
raise Exeption("Couldn't find ANTLR parser. Try building with USE_ANTLR=ON.")
try:
from antlr4 import ParserRuleContext, InputStream, CommonTokenStream
from antlr4.tree.Tree import TerminalNode
from antlr4 import InputStream, CommonTokenStream
from antlr4.error.ErrorListener import ErrorListener
except ImportError:
raise ParseError("Couldn't find ANTLR runtime." +
raise Exception("Couldn't find ANTLR runtime." +
"Try running `pip{version} install antlr4-python{version}-runtime`."
.format(version=PYTHON_VERSION))
sys.setrecursionlimit(10000)
class ParseError(Exception):
"""Exception type for parse errors."""
def __init__(self, message):
# type: (str) -> None
super(ParseError, self).__init__()
self.message = message
def __repr__(self):
return "ParseError({})".format(self.message)
def __str__(self):
return repr(self)
class OpWrapper:
"""Overload the __call__ for op."""
pass
class ExprOp(OpWrapper):
"""Call an expr. The default, but does not handle attrs well."""
def __init__(self, operator):
self.operator = operator
def __call__(self, args, attrs, type_args):
try:
return expr.Call(self.operator, args, attrs, type_args)
except Exception:
raise Exception(str(self.operator) + " " + str(attrs))
class FuncOp(OpWrapper):
"""Convert the attrs, call the python function with the attrs passed in as keyword arguments.
Tvm should provide this in the future, as this is pretty similar to what op.get is providing.
"""
def __init__(self, operator):
self.operator = operator
def convert(self, v):
if isinstance(v, tuple):
return tuple([self.convert(x) for x in v])
if isinstance(v, expr.Constant):
return v.data.asnumpy().item()
if isinstance(v, str):
return v
raise Exception(v)
def __call__(self, args, attrs, type_args):
if attrs is None:
attrs = {}
x = self.operator(*args, **{k: self.convert(v) for k, v in attrs.items()})
if isinstance(x, expr.TupleWrapper):
x = x.astuple()
return x
BINARY_OPS = {
RelayParser.MUL: op.multiply,
RelayParser.DIV: op.divide,
......@@ -70,6 +116,24 @@ BINARY_OPS = {
RelayParser.NE: op.not_equal,
}
FUNC_OPS = {
"nn.conv2d": op.nn.conv2d,
"nn.batch_norm": op.nn.batch_norm,
"nn.dense": op.nn.dense,
"nn.bias_add": op.nn.bias_add,
"nn.max_pool2d": op.nn.max_pool2d,
"nn.global_max_pool2d": op.nn.global_max_pool2d,
"nn.avg_pool2d": op.nn.avg_pool2d,
"nn.global_avg_pool2d": op.nn.global_avg_pool2d,
"nn.softmax": op.nn.softmax,
"reshape": op.reshape,
"nn.conv2d_transpose": op.nn.conv2d_transpose,
"concatenate": op.concatenate,
"nn.dropout": op.nn.dropout_raw,
"zeros": op.zeros,
"split": op.split,
}
TYPE_PREFIXES = [
"int",
"uint",
......@@ -77,9 +141,9 @@ TYPE_PREFIXES = [
"bool",
]
T = TypeVar("T")
Scope = Deque[Tuple[str, T]]
Scopes = Deque[Scope[T]]
T = ty.TypeVar("T")
# Scope = Deque[Tuple[str, T]]
# Scopes = Deque[Scope[T]]
def lookup(scopes, name):
# type: (Scopes[T], str) -> Optional[T]
......@@ -108,6 +172,8 @@ def spanify(f):
ast = f(*args, **kwargs)
line, col = ctx.getSourceInterval()
sp = Span(sn, line, col)
if isinstance(ast, tvm.relay.expr.TupleWrapper):
ast = ast.astuple()
ast.set_span(sp)
return ast
return _wrapper
......@@ -179,6 +245,9 @@ class ParseTreeToRelayIR(RelayVisitor):
self.type_param_scopes[0].appendleft((name, typ))
return typ
def visitProjection(self, ctx):
return expr.TupleGetItem(self.visit(ctx.expr()), self.visit(ctx.NAT()))
def visitTerminal(self, node):
# type: (TerminalNode) -> Union[expr.Expr, int, float]
"""Visit lexer tokens that aren't ignored or visited by other functions."""
......@@ -213,12 +282,15 @@ class ParseTreeToRelayIR(RelayVisitor):
if node_text == "False":
return False
raise ParseError("Unrecognized BOOL_LIT: `{}`".format(node_text))
if node_type == RelayLexer.QUOTED_STRING:
return literal_eval(node_text)
raise ParseError("todo: {}".format(node_text))
raise ParseError("todo: `{}`".format(node_text))
def visit_list(self, ctx_list):
# type: (List[ParserRuleContext]) -> List[Any]
""""Visit a list of contexts."""
assert isinstance(ctx_list, list)
return [self.visit(ctx) for ctx in ctx_list]
......@@ -232,6 +304,11 @@ class ParseTreeToRelayIR(RelayVisitor):
return self.visit(ctx)
def visitProg(self, ctx):
self.meta = None
if ctx.METADATA():
header, data = str(ctx.METADATA()).split('\n', 1)
assert header == "METADATA:"
self.meta = tvm.load_json(data)
# type: (RelayParser.ProgContext) -> Union[expr.Expr, module.Module]
if ctx.defn():
self.visit_list(ctx.defn())
......@@ -245,11 +322,14 @@ class ParseTreeToRelayIR(RelayVisitor):
# Exprs
def visitOpIdent(self, ctx):
# type: (RelayParser.OpIdentContext) -> op.Op
return op.get(ctx.CNAME().getText())
op_name = ctx.CNAME().getText()
if op_name in FUNC_OPS:
return FuncOp(FUNC_OPS[op_name])
return ExprOp(op.get(op_name))
# pass through
def visitParens(self, ctx):
# type: (RelayParser.ParensContext) -> expr.Expr
def visitParen(self, ctx):
# type: (RelayParser.ParenContext) -> expr.Expr
return self.visit(ctx.expr())
# pass through
......@@ -283,25 +363,17 @@ class ParseTreeToRelayIR(RelayVisitor):
tup = self.visit_list(ctx.expr())
return expr.Tuple(tup)
# Currently doesn't support mutable sequencing.
def visitLet(self, ctx):
# type: (RelayParser.SeqContext) -> expr.Let
"""Desugar various sequence constructs to Relay Let nodes."""
if ctx.MUT() is not None:
raise ParseError("Mutation is currently unsupported.")
if ctx.var() is None or ctx.var().ident() is None:
if ctx.var() is None:
# anonymous identity
ident = "_"
type_ = None
else:
local_var = ctx.var().ident().LOCAL_VAR()
if local_var is None:
raise ParseError("Only local ids may be used in `let`s.")
ident = local_var.getText()[1:]
type_ = self.getType_(ctx.var().type_())
var = self.mk_var(ident, type_)
else:
var = self.visitVar(ctx.var())
self.enter_var_scope()
value = self.visit(ctx.expr(0))
......@@ -326,7 +398,7 @@ class ParseTreeToRelayIR(RelayVisitor):
def visitVar(self, ctx):
# type: (RelayParser.VarContext) -> expr.Var
"""Visit a single variable."""
ident = ctx.ident().LOCAL_VAR()
ident = ctx.LOCAL_VAR()
if ident is None:
raise ParseError("Only local ids may be used in vars.")
......@@ -344,19 +416,29 @@ class ParseTreeToRelayIR(RelayVisitor):
# type: (RelayParser.AttrContext) -> Tuple[str, expr.Expr]
return (ctx.CNAME().getText(), self.visit(ctx.expr()))
def visitAttrList(self, ctx):
def visitArgNoAttr(self, ctx):
return (self.visit_list(ctx.varList().var()), None)
def visitAttrSeq(self, ctx):
# type: (RelayParser.AttrListContext) -> Dict[str, expr.Expr]
return dict(self.visit_list(ctx.attr()))
def visitArgWithAttr(self, ctx):
return (self.visit_list(ctx.var()), self.visitAttrSeq(ctx.attrSeq()))
def visitArgList(self,
ctx # type: RelayParser.ArgListContext
):
# type: (...) -> Tuple[Optional[List[expr.Var]], Optional[Dict[str, expr.Expr]]]
var_list = self.visit(ctx.varList()) if ctx.varList() else None
attr_list = self.visit(ctx.attrList()) if ctx.attrList() else None
return (var_list, attr_list)
def visitMeta(self, ctx):
type_key = str(ctx.CNAME())
index = int(self.visit(ctx.NAT()))
return self.meta[type_key][index]
def mk_func(self, ctx):
# type: (Union[RelayParser.FuncContext, RelayParser.DefnContext]) -> expr.Function
"""Construct a function from either a Func or Defn."""
......@@ -365,7 +447,7 @@ class ParseTreeToRelayIR(RelayVisitor):
self.enter_var_scope()
# Capture type params in params.
self.enter_type_param_scope()
type_params = ctx.typeParamSeq()
type_params = ctx.typeParamList()
if type_params is not None:
type_params = type_params.ident()
......@@ -405,18 +487,25 @@ class ParseTreeToRelayIR(RelayVisitor):
raise ParseError("Only global ids may be used in `def`s.")
ident_name = ident.getText()[1:]
ident = self.mk_global_var(ident_name)
self.module[ident] = self.mk_func(ctx)
def visitCallNoAttr(self, ctx):
return (self.visit_list(ctx.exprList().expr()), None)
def visitCallWithAttr(self, ctx):
return (self.visit_list(ctx.expr()), self.visit(ctx.attrSeq()))
def call(self, func, args, attrs, type_args):
if isinstance(func, OpWrapper):
return func(args, attrs, type_args)
return expr.Call(func, args, attrs, type_args)
@spanify
def visitCall(self, ctx):
# type: (RelayParser.CallContext) -> expr.Call
visited_exprs = self.visit_list(ctx.expr())
func = visited_exprs[0]
args = visited_exprs[1:]
return expr.Call(func, args, None, None)
func = self.visit(ctx.expr())
args, attrs = self.visit(ctx.callList())
return self.call(func, args, attrs, [])
@spanify
def visitIfElse(self, ctx):
......@@ -438,9 +527,7 @@ class ParseTreeToRelayIR(RelayVisitor):
def visitGraph(self, ctx):
# type: (RelayParser.GraphContext) -> expr.Expr
"""Visit a graph variable assignment."""
if ctx.ident().GRAPH_VAR() is None:
raise ParseError("Expected a graph var, but got `{}`".format(ctx.ident().getText()))
graph_nid = int(ctx.ident().GRAPH_VAR().getText()[1:])
graph_nid = int(ctx.GRAPH_VAR().getText()[1:])
self.enter_var_scope()
value = self.visit(ctx.expr(0))
......@@ -500,15 +587,18 @@ class ParseTreeToRelayIR(RelayVisitor):
# type: (RelayParser.ParensShapeContext) -> int
return self.visit(ctx.shape())
def visitShapeSeq(self, ctx):
# type: (RelayParser.ShapeSeqContext) -> List[int]
def visitShapeList(self, ctx):
# type: (RelayParser.ShapeListContext) -> List[int]
return self.visit_list(ctx.shape())
def visitTensor(self, ctx):
return tuple(self.visit_list(ctx.expr()))
def visitTensorType(self, ctx):
# type: (RelayParser.TensorTypeContext) -> ty.TensorType
"""Create a simple tensor type. No generics."""
shape = self.visit(ctx.shapeSeq())
shape = self.visit(ctx.shapeList())
dtype = self.visit(ctx.type_())
if not isinstance(dtype, ty.TensorType):
......@@ -536,11 +626,37 @@ def make_parser(data):
"""Construct a RelayParser a given data stream."""
input_stream = InputStream(data)
lexer = RelayLexer(input_stream)
lexer.addErrorListener(StrictErrorListener(data))
token_stream = CommonTokenStream(lexer)
return RelayParser(token_stream)
p = RelayParser(token_stream)
p.addErrorListener(StrictErrorListener(data))
return p
__source_name_counter__ = 0
class StrictErrorListener(ErrorListener):
"""This ErrorListener fail eagerly on all error, and report the program."""
def __init__(self, text):
self.text = text
def syntaxError(self, recognizer, offendingSymbol, line, column, msg, e):
raise Exception("Syntax Error in:\n" + self.text)
def reportAmbiguity(self, recognizer, dfa, startIndex, stopIndex, exact, ambigAlts, configs):
raise Exception("Ambiguity Error in:\n" + self.text)
def reportAttemptingFullContext(self,
recognizer,
dfa,
startIndex,
stopIndex,
conflictingAlts,
configs):
raise Exception("Attempting Full Context in:\n" + self.text)
def reportContextSensitivity(self, recognizer, dfa, startIndex, stopIndex, prediction, configs):
raise Exception("Context Sensitivity in:\n" + self.text)
def fromtext(data, source_name=None):
# type: (str, str) -> Union[expr.Expr, module.Module]
"""Parse a Relay program."""
......
......@@ -224,6 +224,20 @@ def alpha_equal(lhs, rhs):
return bool(_make._alpha_equal(lhs, rhs))
def assert_alpha_equal(lhs, rhs):
"""Assert that two Relay expr is structurally equivalent. (alpha equivalence).
Parameters
----------
lhs : tvm.relay.Expr
One of the input Expression.
rhs : tvm.relay.Expr
One of the input Expression.
"""
_make._assert_alpha_equal(lhs, rhs)
def graph_equal(lhs, rhs):
"""Compare two Relay expr for data-flow equivalence.
The difference between this and alpha-equality is that
......@@ -246,6 +260,23 @@ def graph_equal(lhs, rhs):
return bool(_make._graph_equal(lhs, rhs))
def assert_graph_equal(lhs, rhs):
"""Compare two Relay expr for data-flow equivalence.
The difference between this and alpha-equality is that
variables are not expected to match between lhs and rhs;
they are treated as sources and are mapped between each other.
Parameters
----------
lhs : tvm.relay.Expr
One of the input Expression.
rhs : tvm.relay.Expr
One of the input Expression.
"""
_make._assert_graph_equal(lhs, rhs)
def collect_device_info(expr):
"""Collect the device allocation map for the given expression. The device
ids are propagated from the `device_copy` operators.
......
......@@ -17,15 +17,20 @@
* under the License.
*/
// list = *, seq = ?
grammar Relay;
SEMVER: 'v0.0.3' ;
// Lexing
// comments
WS : [ \t\n\r]+ -> skip ;
LINE_COMMENT : '//' .*? '\n' -> skip ;
COMMENT : '/*' .*? '*/' -> skip ;
COMMENT : '/*' (COMMENT|.)*? '*/' -> skip;
WS : [ \t\n\r]+ -> skip;
LINE_COMMENT : '//' .*? '\n' -> skip;
fragment ESCAPED_QUOTE : '\\"';
QUOTED_STRING : '"' ( ESCAPED_QUOTE | ~('\n'|'\r') )*? '"';
// operators
MUL: '*' ;
......@@ -39,18 +44,18 @@ GE: '>=' ;
EQ: '==' ;
NE: '!=' ;
opIdent: CNAME ;
GLOBAL_VAR: '@' CNAME ;
LOCAL_VAR: '%' CNAME;
GRAPH_VAR: '%' NAT;
MUT: 'mut' ;
BOOL_LIT
: 'True'
| 'False'
;
CNAME: ('_'|LETTER) ('_'|LETTER|DIGIT)* ('.' CNAME)*;
opIdent: CNAME ;
GLOBAL_VAR: '@' CNAME ;
LOCAL_VAR: '%' CNAME;
GRAPH_VAR: '%' NAT;
DATATYPE : 'int64';
// non-negative floats
fragment PREFLOAT : NAT ('.' NAT)? EXP?; // 1.35, 1.35E-9, 0.3, 4.5, 1, 1e10 3e4
......@@ -60,78 +65,69 @@ FLOAT : PREFLOAT 'f';
NAT: DIGIT+ ;
fragment EXP: [eE] [+\-]? NAT ; // \- since - means "range" inside [...]
CNAME: ('_'|LETTER) ('_'|LETTER|DIGIT)* ;
fragment LETTER: [a-zA-Z] ;
fragment DIGIT: [0-9] ;
fragment LETTER: [a-zA-Z];
fragment DIGIT: [0-9];
METADATA: 'METADATA:' .*;
// Parsing
// A Relay program is a list of global definitions or an expression.
prog: SEMVER (defn* | expr) EOF ;
prog: SEMVER (defn* | expr) METADATA? EOF ;
// option: 'set' ident BOOL_LIT ;
exprList: (expr (',' expr)*)?;
callList
: exprList # callNoAttr
| (expr ',')* attrSeq # callWithAttr
;
expr
// operators
: '(' expr ')' # parens
: '(' expr ')' # paren
| '{' expr '}' # paren
// function application
| expr '(' (expr (',' expr)*)? ')' # call
| expr '(' callList ')' # call
| '-' expr # neg
| expr op=('*'|'/') expr # binOp
| expr op=('+'|'-') expr # binOp
| expr op=('<'|'>'|'<='|'>=') expr # binOp
| expr op=('=='|'!=') expr # binOp
// function definition
| func # funcExpr
// tuples and tensors
| '(' ')' # tuple
| '(' expr ',' ')' # tuple
| '(' expr (',' expr)+ ')' # tuple
| expr '.' NAT # projection
| '[' (expr (',' expr)*)? ']' # tensor
| 'if' '(' expr ')' body 'else' body # ifElse
// sequencing
| 'let' MUT? var '=' expr ';' expr # let
| 'let' MUT? var '=' '{' expr '}' ';' expr # let
| 'let' var '=' expr ';' expr # let
// sugar for let %_ = expr; expr
| expr ';' expr # let
| ident '=' expr ';' expr # graph
// mutable update
// | ident '=' expr # writeRef
// | expr '^' # readRef
| expr ';;' expr # let
| GRAPH_VAR '=' expr ';' expr # graph
| ident # identExpr
| scalar # scalarExpr
// | expr '.' NAT # project
// | 'debug' # debug
| meta # metaExpr
| QUOTED_STRING # stringExpr
;
func: 'fn' typeParamSeq? '(' argList ')' ('->' type_)? body ;
defn: 'def' ident typeParamSeq? '(' argList ')' ('->' type_)? body ;
func: 'fn' typeParamList? '(' argList ')' ('->' type_)? body ;
defn: 'def' ident typeParamList? '(' argList ')' ('->' type_)? body ;
argList
: varList
| attrList
| varList ',' attrList
: varList # argNoAttr
| (var ',')* attrSeq # argWithAttr
;
varList: (var (',' var)*)? ;
var: ident (':' type_)? ;
varList: (var (',' var)*)?;
var: LOCAL_VAR (':' type_)?;
attrList: (attr (',' attr)*)? ;
attrSeq: attr (',' attr)*;
attr: CNAME '=' expr ;
// TODO(@jmp): for improved type annotations
// returnAnno: (ident ':')? type_ ;
// relations: 'where' relation (',' relation)* ;
// relation: ident '(' (type_ (',' type_)*)? ')' ;
typeParamSeq
typeParamList
: '[' ']'
| '[' ident (',' ident)* ']'
;
......@@ -141,28 +137,27 @@ type_
| '(' type_ ',' ')' # tupleType
| '(' type_ (',' type_)+ ')' # tupleType
| typeIdent # typeIdentType
| 'Tensor' '[' shapeSeq ',' type_ ']' # tensorType
// currently unused
// | typeIdent '[' (type_ (',' type_)*)? ']' # callType
| 'fn' typeParamSeq? '(' (type_ (',' type_)*)? ')' '->' type_ # funcType
| 'Tensor' '[' shapeList ',' type_ ']' # tensorType
| 'fn' typeParamList? '(' (type_ (',' type_)*)? ')' '->' type_ # funcType
| '_' # incompleteType
| NAT # intType
;
shapeSeq
: '(' ')'
| '(' shape ',' ')'
| '(' shape (',' shape)+ ')'
shapeList
: '(' shape (',' shape)+ ')'
| '(' ')'
| shape
;
meta : 'meta' '[' CNAME ']' '[' NAT ']';
shape
: '(' shape ')' # parensShape
// | type_ op=('*'|'/') type_ # binOpType
// | type_ op=('+'|'-') type_ # binOpType
: meta # metaShape
| '(' shape ')' # parensShape
| NAT # intShape
;
typeIdent : CNAME ;
typeIdent : CNAME;
// int8, int16, int32, int64
// uint8, uint16, uint32, uint64
// float16, float32, float64
......
......@@ -7,116 +7,147 @@ import sys
def serializedATN():
with StringIO() as buf:
buf.write("\3\u608b\ua72a\u8133\ub9ed\u417c\u3be7\u7786\u5964\2*")
buf.write("\u010d\b\1\4\2\t\2\4\3\t\3\4\4\t\4\4\5\t\5\4\6\t\6\4\7")
buf.write("\3\u608b\ua72a\u8133\ub9ed\u417c\u3be7\u7786\u5964\2/")
buf.write("\u014a\b\1\4\2\t\2\4\3\t\3\4\4\t\4\4\5\t\5\4\6\t\6\4\7")
buf.write("\t\7\4\b\t\b\4\t\t\t\4\n\t\n\4\13\t\13\4\f\t\f\4\r\t\r")
buf.write("\4\16\t\16\4\17\t\17\4\20\t\20\4\21\t\21\4\22\t\22\4\23")
buf.write("\t\23\4\24\t\24\4\25\t\25\4\26\t\26\4\27\t\27\4\30\t\30")
buf.write("\4\31\t\31\4\32\t\32\4\33\t\33\4\34\t\34\4\35\t\35\4\36")
buf.write("\t\36\4\37\t\37\4 \t \4!\t!\4\"\t\"\4#\t#\4$\t$\4%\t%")
buf.write("\4&\t&\4\'\t\'\4(\t(\4)\t)\4*\t*\4+\t+\4,\t,\4-\t-\3\2")
buf.write("\3\2\3\3\3\3\3\4\3\4\3\5\3\5\3\6\3\6\3\7\3\7\3\7\3\b\3")
buf.write("\b\3\b\3\b\3\b\3\t\3\t\3\t\3\t\3\n\3\n\3\13\3\13\3\f\3")
buf.write("\f\3\r\3\r\3\16\3\16\3\16\3\17\3\17\3\17\3\20\3\20\3\20")
buf.write("\3\20\3\21\3\21\3\22\3\22\3\22\3\22\3\22\3\22\3\22\3\23")
buf.write("\3\23\3\24\3\24\3\24\3\24\3\24\3\24\3\24\3\25\6\25\u0097")
buf.write("\n\25\r\25\16\25\u0098\3\25\3\25\3\26\3\26\3\26\3\26\7")
buf.write("\26\u00a1\n\26\f\26\16\26\u00a4\13\26\3\26\3\26\3\26\3")
buf.write("\26\3\27\3\27\3\27\3\27\7\27\u00ae\n\27\f\27\16\27\u00b1")
buf.write("\13\27\3\27\3\27\3\27\3\27\3\27\3\30\3\30\3\31\3\31\3")
buf.write("\32\3\32\3\33\3\33\3\34\3\34\3\35\3\35\3\36\3\36\3\36")
buf.write("\3\37\3\37\3\37\3 \3 \3 \3!\3!\3!\3\"\3\"\3\"\3#\3#\3")
buf.write("#\3$\3$\3$\3%\3%\3%\3%\3&\3&\3&\3&\3&\3&\3&\3&\3&\5&\u00e6")
buf.write("\n&\3\'\3\'\3\'\5\'\u00eb\n\'\3\'\5\'\u00ee\n\'\3(\3(")
buf.write("\3(\3)\6)\u00f4\n)\r)\16)\u00f5\3*\3*\5*\u00fa\n*\3*\3")
buf.write("*\3+\3+\5+\u0100\n+\3+\3+\3+\7+\u0105\n+\f+\16+\u0108")
buf.write("\13+\3,\3,\3-\3-\4\u00a2\u00af\2.\3\3\5\4\7\5\t\6\13\7")
buf.write("\r\b\17\t\21\n\23\13\25\f\27\r\31\16\33\17\35\20\37\21")
buf.write("!\22#\23%\24\'\25)\26+\27-\30/\31\61\32\63\33\65\34\67")
buf.write("\359\36;\37= ?!A\"C#E$G%I&K\'M\2O(Q)S\2U*W\2Y\2\3\2\7")
buf.write("\5\2\13\f\17\17\"\"\4\2GGgg\4\2--//\4\2C\\c|\3\2\62;\2")
buf.write("\u0114\2\3\3\2\2\2\2\5\3\2\2\2\2\7\3\2\2\2\2\t\3\2\2\2")
buf.write("\2\13\3\2\2\2\2\r\3\2\2\2\2\17\3\2\2\2\2\21\3\2\2\2\2")
buf.write("\23\3\2\2\2\2\25\3\2\2\2\2\27\3\2\2\2\2\31\3\2\2\2\2\33")
buf.write("\3\2\2\2\2\35\3\2\2\2\2\37\3\2\2\2\2!\3\2\2\2\2#\3\2\2")
buf.write("\2\2%\3\2\2\2\2\'\3\2\2\2\2)\3\2\2\2\2+\3\2\2\2\2-\3\2")
buf.write("\2\2\2/\3\2\2\2\2\61\3\2\2\2\2\63\3\2\2\2\2\65\3\2\2\2")
buf.write("\2\67\3\2\2\2\29\3\2\2\2\2;\3\2\2\2\2=\3\2\2\2\2?\3\2")
buf.write("\2\2\2A\3\2\2\2\2C\3\2\2\2\2E\3\2\2\2\2G\3\2\2\2\2I\3")
buf.write("\2\2\2\2K\3\2\2\2\2O\3\2\2\2\2Q\3\2\2\2\2U\3\2\2\2\3[")
buf.write("\3\2\2\2\5]\3\2\2\2\7_\3\2\2\2\ta\3\2\2\2\13c\3\2\2\2")
buf.write("\re\3\2\2\2\17h\3\2\2\2\21m\3\2\2\2\23q\3\2\2\2\25s\3")
buf.write("\2\2\2\27u\3\2\2\2\31w\3\2\2\2\33y\3\2\2\2\35|\3\2\2\2")
buf.write("\37\177\3\2\2\2!\u0083\3\2\2\2#\u0085\3\2\2\2%\u008c\3")
buf.write("\2\2\2\'\u008e\3\2\2\2)\u0096\3\2\2\2+\u009c\3\2\2\2-")
buf.write("\u00a9\3\2\2\2/\u00b7\3\2\2\2\61\u00b9\3\2\2\2\63\u00bb")
buf.write("\3\2\2\2\65\u00bd\3\2\2\2\67\u00bf\3\2\2\29\u00c1\3\2")
buf.write("\2\2;\u00c3\3\2\2\2=\u00c6\3\2\2\2?\u00c9\3\2\2\2A\u00cc")
buf.write("\3\2\2\2C\u00cf\3\2\2\2E\u00d2\3\2\2\2G\u00d5\3\2\2\2")
buf.write("I\u00d8\3\2\2\2K\u00e5\3\2\2\2M\u00e7\3\2\2\2O\u00ef\3")
buf.write("\2\2\2Q\u00f3\3\2\2\2S\u00f7\3\2\2\2U\u00ff\3\2\2\2W\u0109")
buf.write("\3\2\2\2Y\u010b\3\2\2\2[\\\7*\2\2\\\4\3\2\2\2]^\7+\2\2")
buf.write("^\6\3\2\2\2_`\7.\2\2`\b\3\2\2\2ab\7]\2\2b\n\3\2\2\2cd")
buf.write("\7_\2\2d\f\3\2\2\2ef\7k\2\2fg\7h\2\2g\16\3\2\2\2hi\7g")
buf.write("\2\2ij\7n\2\2jk\7u\2\2kl\7g\2\2l\20\3\2\2\2mn\7n\2\2n")
buf.write("o\7g\2\2op\7v\2\2p\22\3\2\2\2qr\7?\2\2r\24\3\2\2\2st\7")
buf.write("=\2\2t\26\3\2\2\2uv\7}\2\2v\30\3\2\2\2wx\7\177\2\2x\32")
buf.write("\3\2\2\2yz\7h\2\2z{\7p\2\2{\34\3\2\2\2|}\7/\2\2}~\7@\2")
buf.write("\2~\36\3\2\2\2\177\u0080\7f\2\2\u0080\u0081\7g\2\2\u0081")
buf.write("\u0082\7h\2\2\u0082 \3\2\2\2\u0083\u0084\7<\2\2\u0084")
buf.write("\"\3\2\2\2\u0085\u0086\7V\2\2\u0086\u0087\7g\2\2\u0087")
buf.write("\u0088\7p\2\2\u0088\u0089\7u\2\2\u0089\u008a\7q\2\2\u008a")
buf.write("\u008b\7t\2\2\u008b$\3\2\2\2\u008c\u008d\7a\2\2\u008d")
buf.write("&\3\2\2\2\u008e\u008f\7x\2\2\u008f\u0090\7\62\2\2\u0090")
buf.write("\u0091\7\60\2\2\u0091\u0092\7\62\2\2\u0092\u0093\7\60")
buf.write("\2\2\u0093\u0094\7\65\2\2\u0094(\3\2\2\2\u0095\u0097\t")
buf.write("\2\2\2\u0096\u0095\3\2\2\2\u0097\u0098\3\2\2\2\u0098\u0096")
buf.write("\3\2\2\2\u0098\u0099\3\2\2\2\u0099\u009a\3\2\2\2\u009a")
buf.write("\u009b\b\25\2\2\u009b*\3\2\2\2\u009c\u009d\7\61\2\2\u009d")
buf.write("\u009e\7\61\2\2\u009e\u00a2\3\2\2\2\u009f\u00a1\13\2\2")
buf.write("\2\u00a0\u009f\3\2\2\2\u00a1\u00a4\3\2\2\2\u00a2\u00a3")
buf.write("\3\2\2\2\u00a2\u00a0\3\2\2\2\u00a3\u00a5\3\2\2\2\u00a4")
buf.write("\u00a2\3\2\2\2\u00a5\u00a6\7\f\2\2\u00a6\u00a7\3\2\2\2")
buf.write("\u00a7\u00a8\b\26\2\2\u00a8,\3\2\2\2\u00a9\u00aa\7\61")
buf.write("\2\2\u00aa\u00ab\7,\2\2\u00ab\u00af\3\2\2\2\u00ac\u00ae")
buf.write("\13\2\2\2\u00ad\u00ac\3\2\2\2\u00ae\u00b1\3\2\2\2\u00af")
buf.write("\u00b0\3\2\2\2\u00af\u00ad\3\2\2\2\u00b0\u00b2\3\2\2\2")
buf.write("\u00b1\u00af\3\2\2\2\u00b2\u00b3\7,\2\2\u00b3\u00b4\7")
buf.write("\61\2\2\u00b4\u00b5\3\2\2\2\u00b5\u00b6\b\27\2\2\u00b6")
buf.write(".\3\2\2\2\u00b7\u00b8\7,\2\2\u00b8\60\3\2\2\2\u00b9\u00ba")
buf.write("\7\61\2\2\u00ba\62\3\2\2\2\u00bb\u00bc\7-\2\2\u00bc\64")
buf.write("\3\2\2\2\u00bd\u00be\7/\2\2\u00be\66\3\2\2\2\u00bf\u00c0")
buf.write("\7>\2\2\u00c08\3\2\2\2\u00c1\u00c2\7@\2\2\u00c2:\3\2\2")
buf.write("\2\u00c3\u00c4\7>\2\2\u00c4\u00c5\7?\2\2\u00c5<\3\2\2")
buf.write("\2\u00c6\u00c7\7@\2\2\u00c7\u00c8\7?\2\2\u00c8>\3\2\2")
buf.write("\2\u00c9\u00ca\7?\2\2\u00ca\u00cb\7?\2\2\u00cb@\3\2\2")
buf.write("\2\u00cc\u00cd\7#\2\2\u00cd\u00ce\7?\2\2\u00ceB\3\2\2")
buf.write("\2\u00cf\u00d0\7B\2\2\u00d0\u00d1\5U+\2\u00d1D\3\2\2\2")
buf.write("\u00d2\u00d3\7\'\2\2\u00d3\u00d4\5U+\2\u00d4F\3\2\2\2")
buf.write("\u00d5\u00d6\7\'\2\2\u00d6\u00d7\5Q)\2\u00d7H\3\2\2\2")
buf.write("\u00d8\u00d9\7o\2\2\u00d9\u00da\7w\2\2\u00da\u00db\7v")
buf.write("\2\2\u00dbJ\3\2\2\2\u00dc\u00dd\7V\2\2\u00dd\u00de\7t")
buf.write("\2\2\u00de\u00df\7w\2\2\u00df\u00e6\7g\2\2\u00e0\u00e1")
buf.write("\7H\2\2\u00e1\u00e2\7c\2\2\u00e2\u00e3\7n\2\2\u00e3\u00e4")
buf.write("\7u\2\2\u00e4\u00e6\7g\2\2\u00e5\u00dc\3\2\2\2\u00e5\u00e0")
buf.write("\3\2\2\2\u00e6L\3\2\2\2\u00e7\u00ea\5Q)\2\u00e8\u00e9")
buf.write("\7\60\2\2\u00e9\u00eb\5Q)\2\u00ea\u00e8\3\2\2\2\u00ea")
buf.write("\u00eb\3\2\2\2\u00eb\u00ed\3\2\2\2\u00ec\u00ee\5S*\2\u00ed")
buf.write("\u00ec\3\2\2\2\u00ed\u00ee\3\2\2\2\u00eeN\3\2\2\2\u00ef")
buf.write("\u00f0\5M\'\2\u00f0\u00f1\7h\2\2\u00f1P\3\2\2\2\u00f2")
buf.write("\u00f4\5Y-\2\u00f3\u00f2\3\2\2\2\u00f4\u00f5\3\2\2\2\u00f5")
buf.write("\u00f3\3\2\2\2\u00f5\u00f6\3\2\2\2\u00f6R\3\2\2\2\u00f7")
buf.write("\u00f9\t\3\2\2\u00f8\u00fa\t\4\2\2\u00f9\u00f8\3\2\2\2")
buf.write("\u00f9\u00fa\3\2\2\2\u00fa\u00fb\3\2\2\2\u00fb\u00fc\5")
buf.write("Q)\2\u00fcT\3\2\2\2\u00fd\u0100\7a\2\2\u00fe\u0100\5W")
buf.write(",\2\u00ff\u00fd\3\2\2\2\u00ff\u00fe\3\2\2\2\u0100\u0106")
buf.write("\3\2\2\2\u0101\u0105\7a\2\2\u0102\u0105\5W,\2\u0103\u0105")
buf.write("\5Y-\2\u0104\u0101\3\2\2\2\u0104\u0102\3\2\2\2\u0104\u0103")
buf.write("\3\2\2\2\u0105\u0108\3\2\2\2\u0106\u0104\3\2\2\2\u0106")
buf.write("\u0107\3\2\2\2\u0107V\3\2\2\2\u0108\u0106\3\2\2\2\u0109")
buf.write("\u010a\t\5\2\2\u010aX\3\2\2\2\u010b\u010c\t\6\2\2\u010c")
buf.write("Z\3\2\2\2\16\2\u0098\u00a2\u00af\u00e5\u00ea\u00ed\u00f5")
buf.write("\u00f9\u00ff\u0104\u0106\3\b\2\2")
buf.write("\4&\t&\4\'\t\'\4(\t(\4)\t)\4*\t*\4+\t+\4,\t,\4-\t-\4.")
buf.write("\t.\4/\t/\4\60\t\60\4\61\t\61\4\62\t\62\4\63\t\63\3\2")
buf.write("\3\2\3\3\3\3\3\4\3\4\3\5\3\5\3\6\3\6\3\7\3\7\3\b\3\b\3")
buf.write("\t\3\t\3\n\3\n\3\n\3\13\3\13\3\13\3\13\3\13\3\f\3\f\3")
buf.write("\f\3\f\3\r\3\r\3\16\3\16\3\17\3\17\3\17\3\20\3\20\3\20")
buf.write("\3\21\3\21\3\21\3\22\3\22\3\22\3\22\3\23\3\23\3\24\3\24")
buf.write("\3\24\3\24\3\24\3\24\3\24\3\25\3\25\3\26\3\26\3\26\3\26")
buf.write("\3\26\3\27\3\27\3\27\3\27\3\27\3\27\3\27\3\30\3\30\3\30")
buf.write("\3\30\3\30\7\30\u00b1\n\30\f\30\16\30\u00b4\13\30\3\30")
buf.write("\3\30\3\30\3\30\3\30\3\31\6\31\u00bc\n\31\r\31\16\31\u00bd")
buf.write("\3\31\3\31\3\32\3\32\3\32\3\32\7\32\u00c6\n\32\f\32\16")
buf.write("\32\u00c9\13\32\3\32\3\32\3\32\3\32\3\33\3\33\3\33\3\34")
buf.write("\3\34\3\34\7\34\u00d5\n\34\f\34\16\34\u00d8\13\34\3\34")
buf.write("\3\34\3\35\3\35\3\36\3\36\3\37\3\37\3 \3 \3!\3!\3\"\3")
buf.write("\"\3#\3#\3#\3$\3$\3$\3%\3%\3%\3&\3&\3&\3\'\3\'\3\'\3\'")
buf.write("\3\'\3\'\3\'\3\'\3\'\5\'\u00fd\n\'\3(\3(\5(\u0101\n(\3")
buf.write("(\3(\3(\7(\u0106\n(\f(\16(\u0109\13(\3(\3(\7(\u010d\n")
buf.write("(\f(\16(\u0110\13(\3)\3)\3)\3*\3*\3*\3+\3+\3+\3,\3,\3")
buf.write(",\3,\3,\3,\3-\3-\3-\5-\u0124\n-\3-\5-\u0127\n-\3.\3.\3")
buf.write(".\3/\6/\u012d\n/\r/\16/\u012e\3\60\3\60\5\60\u0133\n\60")
buf.write("\3\60\3\60\3\61\3\61\3\62\3\62\3\63\3\63\3\63\3\63\3\63")
buf.write("\3\63\3\63\3\63\3\63\3\63\3\63\7\63\u0146\n\63\f\63\16")
buf.write("\63\u0149\13\63\5\u00b2\u00c7\u00d6\2\64\3\3\5\4\7\5\t")
buf.write("\6\13\7\r\b\17\t\21\n\23\13\25\f\27\r\31\16\33\17\35\20")
buf.write("\37\21!\22#\23%\24\'\25)\26+\27-\30/\31\61\32\63\33\65")
buf.write("\2\67\349\35;\36=\37? A!C\"E#G$I%K&M\'O(Q)S*U+W,Y\2[-")
buf.write("]._\2a\2c\2e/\3\2\b\5\2\13\f\17\17\"\"\4\2\f\f\17\17\4")
buf.write("\2GGgg\4\2--//\4\2C\\c|\3\2\62;\2\u0155\2\3\3\2\2\2\2")
buf.write("\5\3\2\2\2\2\7\3\2\2\2\2\t\3\2\2\2\2\13\3\2\2\2\2\r\3")
buf.write("\2\2\2\2\17\3\2\2\2\2\21\3\2\2\2\2\23\3\2\2\2\2\25\3\2")
buf.write("\2\2\2\27\3\2\2\2\2\31\3\2\2\2\2\33\3\2\2\2\2\35\3\2\2")
buf.write("\2\2\37\3\2\2\2\2!\3\2\2\2\2#\3\2\2\2\2%\3\2\2\2\2\'\3")
buf.write("\2\2\2\2)\3\2\2\2\2+\3\2\2\2\2-\3\2\2\2\2/\3\2\2\2\2\61")
buf.write("\3\2\2\2\2\63\3\2\2\2\2\67\3\2\2\2\29\3\2\2\2\2;\3\2\2")
buf.write("\2\2=\3\2\2\2\2?\3\2\2\2\2A\3\2\2\2\2C\3\2\2\2\2E\3\2")
buf.write("\2\2\2G\3\2\2\2\2I\3\2\2\2\2K\3\2\2\2\2M\3\2\2\2\2O\3")
buf.write("\2\2\2\2Q\3\2\2\2\2S\3\2\2\2\2U\3\2\2\2\2W\3\2\2\2\2[")
buf.write("\3\2\2\2\2]\3\2\2\2\2e\3\2\2\2\3g\3\2\2\2\5i\3\2\2\2\7")
buf.write("k\3\2\2\2\tm\3\2\2\2\13o\3\2\2\2\rq\3\2\2\2\17s\3\2\2")
buf.write("\2\21u\3\2\2\2\23w\3\2\2\2\25z\3\2\2\2\27\177\3\2\2\2")
buf.write("\31\u0083\3\2\2\2\33\u0085\3\2\2\2\35\u0087\3\2\2\2\37")
buf.write("\u008a\3\2\2\2!\u008d\3\2\2\2#\u0090\3\2\2\2%\u0094\3")
buf.write("\2\2\2\'\u0096\3\2\2\2)\u009d\3\2\2\2+\u009f\3\2\2\2-")
buf.write("\u00a4\3\2\2\2/\u00ab\3\2\2\2\61\u00bb\3\2\2\2\63\u00c1")
buf.write("\3\2\2\2\65\u00ce\3\2\2\2\67\u00d1\3\2\2\29\u00db\3\2")
buf.write("\2\2;\u00dd\3\2\2\2=\u00df\3\2\2\2?\u00e1\3\2\2\2A\u00e3")
buf.write("\3\2\2\2C\u00e5\3\2\2\2E\u00e7\3\2\2\2G\u00ea\3\2\2\2")
buf.write("I\u00ed\3\2\2\2K\u00f0\3\2\2\2M\u00fc\3\2\2\2O\u0100\3")
buf.write("\2\2\2Q\u0111\3\2\2\2S\u0114\3\2\2\2U\u0117\3\2\2\2W\u011a")
buf.write("\3\2\2\2Y\u0120\3\2\2\2[\u0128\3\2\2\2]\u012c\3\2\2\2")
buf.write("_\u0130\3\2\2\2a\u0136\3\2\2\2c\u0138\3\2\2\2e\u013a\3")
buf.write("\2\2\2gh\7.\2\2h\4\3\2\2\2ij\7*\2\2j\6\3\2\2\2kl\7+\2")
buf.write("\2l\b\3\2\2\2mn\7}\2\2n\n\3\2\2\2op\7\177\2\2p\f\3\2\2")
buf.write("\2qr\7\60\2\2r\16\3\2\2\2st\7]\2\2t\20\3\2\2\2uv\7_\2")
buf.write("\2v\22\3\2\2\2wx\7k\2\2xy\7h\2\2y\24\3\2\2\2z{\7g\2\2")
buf.write("{|\7n\2\2|}\7u\2\2}~\7g\2\2~\26\3\2\2\2\177\u0080\7n\2")
buf.write("\2\u0080\u0081\7g\2\2\u0081\u0082\7v\2\2\u0082\30\3\2")
buf.write("\2\2\u0083\u0084\7?\2\2\u0084\32\3\2\2\2\u0085\u0086\7")
buf.write("=\2\2\u0086\34\3\2\2\2\u0087\u0088\7=\2\2\u0088\u0089")
buf.write("\7=\2\2\u0089\36\3\2\2\2\u008a\u008b\7h\2\2\u008b\u008c")
buf.write("\7p\2\2\u008c \3\2\2\2\u008d\u008e\7/\2\2\u008e\u008f")
buf.write("\7@\2\2\u008f\"\3\2\2\2\u0090\u0091\7f\2\2\u0091\u0092")
buf.write("\7g\2\2\u0092\u0093\7h\2\2\u0093$\3\2\2\2\u0094\u0095")
buf.write("\7<\2\2\u0095&\3\2\2\2\u0096\u0097\7V\2\2\u0097\u0098")
buf.write("\7g\2\2\u0098\u0099\7p\2\2\u0099\u009a\7u\2\2\u009a\u009b")
buf.write("\7q\2\2\u009b\u009c\7t\2\2\u009c(\3\2\2\2\u009d\u009e")
buf.write("\7a\2\2\u009e*\3\2\2\2\u009f\u00a0\7o\2\2\u00a0\u00a1")
buf.write("\7g\2\2\u00a1\u00a2\7v\2\2\u00a2\u00a3\7c\2\2\u00a3,\3")
buf.write("\2\2\2\u00a4\u00a5\7x\2\2\u00a5\u00a6\7\62\2\2\u00a6\u00a7")
buf.write("\7\60\2\2\u00a7\u00a8\7\62\2\2\u00a8\u00a9\7\60\2\2\u00a9")
buf.write("\u00aa\7\65\2\2\u00aa.\3\2\2\2\u00ab\u00ac\7\61\2\2\u00ac")
buf.write("\u00ad\7,\2\2\u00ad\u00b2\3\2\2\2\u00ae\u00b1\5/\30\2")
buf.write("\u00af\u00b1\13\2\2\2\u00b0\u00ae\3\2\2\2\u00b0\u00af")
buf.write("\3\2\2\2\u00b1\u00b4\3\2\2\2\u00b2\u00b3\3\2\2\2\u00b2")
buf.write("\u00b0\3\2\2\2\u00b3\u00b5\3\2\2\2\u00b4\u00b2\3\2\2\2")
buf.write("\u00b5\u00b6\7,\2\2\u00b6\u00b7\7\61\2\2\u00b7\u00b8\3")
buf.write("\2\2\2\u00b8\u00b9\b\30\2\2\u00b9\60\3\2\2\2\u00ba\u00bc")
buf.write("\t\2\2\2\u00bb\u00ba\3\2\2\2\u00bc\u00bd\3\2\2\2\u00bd")
buf.write("\u00bb\3\2\2\2\u00bd\u00be\3\2\2\2\u00be\u00bf\3\2\2\2")
buf.write("\u00bf\u00c0\b\31\2\2\u00c0\62\3\2\2\2\u00c1\u00c2\7\61")
buf.write("\2\2\u00c2\u00c3\7\61\2\2\u00c3\u00c7\3\2\2\2\u00c4\u00c6")
buf.write("\13\2\2\2\u00c5\u00c4\3\2\2\2\u00c6\u00c9\3\2\2\2\u00c7")
buf.write("\u00c8\3\2\2\2\u00c7\u00c5\3\2\2\2\u00c8\u00ca\3\2\2\2")
buf.write("\u00c9\u00c7\3\2\2\2\u00ca\u00cb\7\f\2\2\u00cb\u00cc\3")
buf.write("\2\2\2\u00cc\u00cd\b\32\2\2\u00cd\64\3\2\2\2\u00ce\u00cf")
buf.write("\7^\2\2\u00cf\u00d0\7$\2\2\u00d0\66\3\2\2\2\u00d1\u00d6")
buf.write("\7$\2\2\u00d2\u00d5\5\65\33\2\u00d3\u00d5\n\3\2\2\u00d4")
buf.write("\u00d2\3\2\2\2\u00d4\u00d3\3\2\2\2\u00d5\u00d8\3\2\2\2")
buf.write("\u00d6\u00d7\3\2\2\2\u00d6\u00d4\3\2\2\2\u00d7\u00d9\3")
buf.write("\2\2\2\u00d8\u00d6\3\2\2\2\u00d9\u00da\7$\2\2\u00da8\3")
buf.write("\2\2\2\u00db\u00dc\7,\2\2\u00dc:\3\2\2\2\u00dd\u00de\7")
buf.write("\61\2\2\u00de<\3\2\2\2\u00df\u00e0\7-\2\2\u00e0>\3\2\2")
buf.write("\2\u00e1\u00e2\7/\2\2\u00e2@\3\2\2\2\u00e3\u00e4\7>\2")
buf.write("\2\u00e4B\3\2\2\2\u00e5\u00e6\7@\2\2\u00e6D\3\2\2\2\u00e7")
buf.write("\u00e8\7>\2\2\u00e8\u00e9\7?\2\2\u00e9F\3\2\2\2\u00ea")
buf.write("\u00eb\7@\2\2\u00eb\u00ec\7?\2\2\u00ecH\3\2\2\2\u00ed")
buf.write("\u00ee\7?\2\2\u00ee\u00ef\7?\2\2\u00efJ\3\2\2\2\u00f0")
buf.write("\u00f1\7#\2\2\u00f1\u00f2\7?\2\2\u00f2L\3\2\2\2\u00f3")
buf.write("\u00f4\7V\2\2\u00f4\u00f5\7t\2\2\u00f5\u00f6\7w\2\2\u00f6")
buf.write("\u00fd\7g\2\2\u00f7\u00f8\7H\2\2\u00f8\u00f9\7c\2\2\u00f9")
buf.write("\u00fa\7n\2\2\u00fa\u00fb\7u\2\2\u00fb\u00fd\7g\2\2\u00fc")
buf.write("\u00f3\3\2\2\2\u00fc\u00f7\3\2\2\2\u00fdN\3\2\2\2\u00fe")
buf.write("\u0101\7a\2\2\u00ff\u0101\5a\61\2\u0100\u00fe\3\2\2\2")
buf.write("\u0100\u00ff\3\2\2\2\u0101\u0107\3\2\2\2\u0102\u0106\7")
buf.write("a\2\2\u0103\u0106\5a\61\2\u0104\u0106\5c\62\2\u0105\u0102")
buf.write("\3\2\2\2\u0105\u0103\3\2\2\2\u0105\u0104\3\2\2\2\u0106")
buf.write("\u0109\3\2\2\2\u0107\u0105\3\2\2\2\u0107\u0108\3\2\2\2")
buf.write("\u0108\u010e\3\2\2\2\u0109\u0107\3\2\2\2\u010a\u010b\7")
buf.write("\60\2\2\u010b\u010d\5O(\2\u010c\u010a\3\2\2\2\u010d\u0110")
buf.write("\3\2\2\2\u010e\u010c\3\2\2\2\u010e\u010f\3\2\2\2\u010f")
buf.write("P\3\2\2\2\u0110\u010e\3\2\2\2\u0111\u0112\7B\2\2\u0112")
buf.write("\u0113\5O(\2\u0113R\3\2\2\2\u0114\u0115\7\'\2\2\u0115")
buf.write("\u0116\5O(\2\u0116T\3\2\2\2\u0117\u0118\7\'\2\2\u0118")
buf.write("\u0119\5]/\2\u0119V\3\2\2\2\u011a\u011b\7k\2\2\u011b\u011c")
buf.write("\7p\2\2\u011c\u011d\7v\2\2\u011d\u011e\78\2\2\u011e\u011f")
buf.write("\7\66\2\2\u011fX\3\2\2\2\u0120\u0123\5]/\2\u0121\u0122")
buf.write("\7\60\2\2\u0122\u0124\5]/\2\u0123\u0121\3\2\2\2\u0123")
buf.write("\u0124\3\2\2\2\u0124\u0126\3\2\2\2\u0125\u0127\5_\60\2")
buf.write("\u0126\u0125\3\2\2\2\u0126\u0127\3\2\2\2\u0127Z\3\2\2")
buf.write("\2\u0128\u0129\5Y-\2\u0129\u012a\7h\2\2\u012a\\\3\2\2")
buf.write("\2\u012b\u012d\5c\62\2\u012c\u012b\3\2\2\2\u012d\u012e")
buf.write("\3\2\2\2\u012e\u012c\3\2\2\2\u012e\u012f\3\2\2\2\u012f")
buf.write("^\3\2\2\2\u0130\u0132\t\4\2\2\u0131\u0133\t\5\2\2\u0132")
buf.write("\u0131\3\2\2\2\u0132\u0133\3\2\2\2\u0133\u0134\3\2\2\2")
buf.write("\u0134\u0135\5]/\2\u0135`\3\2\2\2\u0136\u0137\t\6\2\2")
buf.write("\u0137b\3\2\2\2\u0138\u0139\t\7\2\2\u0139d\3\2\2\2\u013a")
buf.write("\u013b\7O\2\2\u013b\u013c\7G\2\2\u013c\u013d\7V\2\2\u013d")
buf.write("\u013e\7C\2\2\u013e\u013f\7F\2\2\u013f\u0140\7C\2\2\u0140")
buf.write("\u0141\7V\2\2\u0141\u0142\7C\2\2\u0142\u0143\7<\2\2\u0143")
buf.write("\u0147\3\2\2\2\u0144\u0146\13\2\2\2\u0145\u0144\3\2\2")
buf.write("\2\u0146\u0149\3\2\2\2\u0147\u0145\3\2\2\2\u0147\u0148")
buf.write("\3\2\2\2\u0148f\3\2\2\2\u0149\u0147\3\2\2\2\23\2\u00b0")
buf.write("\u00b2\u00bd\u00c7\u00d4\u00d6\u00fc\u0100\u0105\u0107")
buf.write("\u010e\u0123\u0126\u012e\u0132\u0147\3\b\2\2")
return buf.getvalue()
......@@ -144,51 +175,59 @@ class RelayLexer(Lexer):
T__15 = 16
T__16 = 17
T__17 = 18
SEMVER = 19
WS = 20
LINE_COMMENT = 21
COMMENT = 22
MUL = 23
DIV = 24
ADD = 25
SUB = 26
LT = 27
GT = 28
LE = 29
GE = 30
EQ = 31
NE = 32
GLOBAL_VAR = 33
LOCAL_VAR = 34
GRAPH_VAR = 35
MUT = 36
T__18 = 19
T__19 = 20
T__20 = 21
SEMVER = 22
COMMENT = 23
WS = 24
LINE_COMMENT = 25
QUOTED_STRING = 26
MUL = 27
DIV = 28
ADD = 29
SUB = 30
LT = 31
GT = 32
LE = 33
GE = 34
EQ = 35
NE = 36
BOOL_LIT = 37
FLOAT = 38
NAT = 39
CNAME = 40
CNAME = 38
GLOBAL_VAR = 39
LOCAL_VAR = 40
GRAPH_VAR = 41
DATATYPE = 42
FLOAT = 43
NAT = 44
METADATA = 45
channelNames = [ u"DEFAULT_TOKEN_CHANNEL", u"HIDDEN" ]
modeNames = [ "DEFAULT_MODE" ]
literalNames = [ "<INVALID>",
"'('", "')'", "','", "'['", "']'", "'if'", "'else'", "'let'",
"'='", "';'", "'{'", "'}'", "'fn'", "'->'", "'def'", "':'",
"'Tensor'", "'_'", "'v0.0.3'", "'*'", "'/'", "'+'", "'-'", "'<'",
"'>'", "'<='", "'>='", "'=='", "'!='", "'mut'" ]
"','", "'('", "')'", "'{'", "'}'", "'.'", "'['", "']'", "'if'",
"'else'", "'let'", "'='", "';'", "';;'", "'fn'", "'->'", "'def'",
"':'", "'Tensor'", "'_'", "'meta'", "'v0.0.3'", "'*'", "'/'",
"'+'", "'-'", "'<'", "'>'", "'<='", "'>='", "'=='", "'!='",
"'int64'" ]
symbolicNames = [ "<INVALID>",
"SEMVER", "WS", "LINE_COMMENT", "COMMENT", "MUL", "DIV", "ADD",
"SUB", "LT", "GT", "LE", "GE", "EQ", "NE", "GLOBAL_VAR", "LOCAL_VAR",
"GRAPH_VAR", "MUT", "BOOL_LIT", "FLOAT", "NAT", "CNAME" ]
"SEMVER", "COMMENT", "WS", "LINE_COMMENT", "QUOTED_STRING",
"MUL", "DIV", "ADD", "SUB", "LT", "GT", "LE", "GE", "EQ", "NE",
"BOOL_LIT", "CNAME", "GLOBAL_VAR", "LOCAL_VAR", "GRAPH_VAR",
"DATATYPE", "FLOAT", "NAT", "METADATA" ]
ruleNames = [ "T__0", "T__1", "T__2", "T__3", "T__4", "T__5", "T__6",
"T__7", "T__8", "T__9", "T__10", "T__11", "T__12", "T__13",
"T__14", "T__15", "T__16", "T__17", "SEMVER", "WS", "LINE_COMMENT",
"COMMENT", "MUL", "DIV", "ADD", "SUB", "LT", "GT", "LE",
"GE", "EQ", "NE", "GLOBAL_VAR", "LOCAL_VAR", "GRAPH_VAR",
"MUT", "BOOL_LIT", "PREFLOAT", "FLOAT", "NAT", "EXP",
"CNAME", "LETTER", "DIGIT" ]
"T__14", "T__15", "T__16", "T__17", "T__18", "T__19",
"T__20", "SEMVER", "COMMENT", "WS", "LINE_COMMENT", "ESCAPED_QUOTE",
"QUOTED_STRING", "MUL", "DIV", "ADD", "SUB", "LT", "GT",
"LE", "GE", "EQ", "NE", "BOOL_LIT", "CNAME", "GLOBAL_VAR",
"LOCAL_VAR", "GRAPH_VAR", "DATATYPE", "PREFLOAT", "FLOAT",
"NAT", "EXP", "LETTER", "DIGIT", "METADATA" ]
grammarFileName = "Relay.g4"
......
This source diff could not be displayed because it is too large. You can view the blob instead.
......@@ -19,11 +19,51 @@ class RelayVisitor(ParseTreeVisitor):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#exprList.
def visitExprList(self, ctx:RelayParser.ExprListContext):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#callNoAttr.
def visitCallNoAttr(self, ctx:RelayParser.CallNoAttrContext):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#callWithAttr.
def visitCallWithAttr(self, ctx:RelayParser.CallWithAttrContext):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#funcExpr.
def visitFuncExpr(self, ctx:RelayParser.FuncExprContext):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#metaExpr.
def visitMetaExpr(self, ctx:RelayParser.MetaExprContext):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#tensor.
def visitTensor(self, ctx:RelayParser.TensorContext):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#graph.
def visitGraph(self, ctx:RelayParser.GraphContext):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#identExpr.
def visitIdentExpr(self, ctx:RelayParser.IdentExprContext):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#stringExpr.
def visitStringExpr(self, ctx:RelayParser.StringExprContext):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#call.
def visitCall(self, ctx:RelayParser.CallContext):
return self.visitChildren(ctx)
......@@ -39,13 +79,8 @@ class RelayVisitor(ParseTreeVisitor):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#parens.
def visitParens(self, ctx:RelayParser.ParensContext):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#funcExpr.
def visitFuncExpr(self, ctx:RelayParser.FuncExprContext):
# Visit a parse tree produced by RelayParser#paren.
def visitParen(self, ctx:RelayParser.ParenContext):
return self.visitChildren(ctx)
......@@ -59,8 +94,8 @@ class RelayVisitor(ParseTreeVisitor):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#tensor.
def visitTensor(self, ctx:RelayParser.TensorContext):
# Visit a parse tree produced by RelayParser#projection.
def visitProjection(self, ctx:RelayParser.ProjectionContext):
return self.visitChildren(ctx)
......@@ -69,11 +104,6 @@ class RelayVisitor(ParseTreeVisitor):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#graph.
def visitGraph(self, ctx:RelayParser.GraphContext):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#binOp.
def visitBinOp(self, ctx:RelayParser.BinOpContext):
return self.visitChildren(ctx)
......@@ -89,8 +119,13 @@ class RelayVisitor(ParseTreeVisitor):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#argList.
def visitArgList(self, ctx:RelayParser.ArgListContext):
# Visit a parse tree produced by RelayParser#argNoAttr.
def visitArgNoAttr(self, ctx:RelayParser.ArgNoAttrContext):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#argWithAttr.
def visitArgWithAttr(self, ctx:RelayParser.ArgWithAttrContext):
return self.visitChildren(ctx)
......@@ -104,8 +139,8 @@ class RelayVisitor(ParseTreeVisitor):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#attrList.
def visitAttrList(self, ctx:RelayParser.AttrListContext):
# Visit a parse tree produced by RelayParser#attrSeq.
def visitAttrSeq(self, ctx:RelayParser.AttrSeqContext):
return self.visitChildren(ctx)
......@@ -114,8 +149,8 @@ class RelayVisitor(ParseTreeVisitor):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#typeParamSeq.
def visitTypeParamSeq(self, ctx:RelayParser.TypeParamSeqContext):
# Visit a parse tree produced by RelayParser#typeParamList.
def visitTypeParamList(self, ctx:RelayParser.TypeParamListContext):
return self.visitChildren(ctx)
......@@ -149,8 +184,18 @@ class RelayVisitor(ParseTreeVisitor):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#shapeSeq.
def visitShapeSeq(self, ctx:RelayParser.ShapeSeqContext):
# Visit a parse tree produced by RelayParser#shapeList.
def visitShapeList(self, ctx:RelayParser.ShapeListContext):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#meta.
def visitMeta(self, ctx:RelayParser.MetaContext):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#metaShape.
def visitMetaShape(self, ctx:RelayParser.MetaShapeContext):
return self.visitChildren(ctx)
......
......@@ -66,34 +66,34 @@ def conv2d(data,
weight : tvm.relay.Expr
The weight expressions.
strides : tuple of int, optional
strides : Optional[Tuple[int]]
The strides of convolution.
padding : tuple of int, optional
padding : Optional[Tuple[int]]
The padding of convolution on both sides of inputs before convolution.
dilation : tuple of int, optional
dilation : Optional[Tuple[int]]
Specifies the dilation rate to be used for dilated convolution.
groups : int, optional
groups : Optional[int]
Number of groups for grouped convolution.
channels : int, optional
channels : Optional[int]
Number of output channels of this convolution.
kernel_size : tuple of int, optional
kernel_size : Optional[Tuple[int]]
The spatial of the convolution kernel.
data_layout : str, optional
data_layout : Optional[str]
Layout of the input.
kernel_layout : str, optional
kernel_layout : Optional[str]
Layout of the weight.
out_layout : str, optional
out_layout : Optional[str]
Layout of the output, by default, out_layout is the same as data_layout
out_dtype : str, optional
out_dtype : Optional[str]
Specifies the output data type for mixed precision conv2d.
Returns
......@@ -691,8 +691,30 @@ def dropout(data, rate=0.5):
result : tvm.relay.Expr
The result of dropout
"""
result = _make.dropout(data, rate)
return TupleWrapper(result, 2)[0]
return TupleWrapper(dropout_raw(data, rate), 2)[0]
def dropout_raw(data, rate=0.5):
"""Applies the dropout operation to the input array.
During training, each element of the input is set to zero with
probability ``p``. The whole array is rescaled by ``1/(1-p)``
to keep the expected sum of the input unchanged.
Parameters
----------
data : tvm.relay.Expr
The input data to the operator.
rate : float, optional (default=0.5)
The probability for an element to be reset to 0.
Returns
-------
result : tvm.relay.Expr
The result of dropout
"""
return _make.dropout(data, rate)
def batch_norm(data,
......
......@@ -23,4 +23,7 @@ from .. import register_func
def fromtext(data, source_name=None):
"""Parse a Relay program."""
from tvm.relay import _parser
return _parser.fromtext(data, source_name)
x = _parser.fromtext(data + "\n", source_name)
if x is None:
raise Exception("cannot parse: ", data)
return x
......@@ -42,7 +42,7 @@ def _make_dense_block(data, num_layers, bn_size, growth_rate, index):
layer_out = data
for i in range(num_layers):
layer_out = _make_dense_layer(layer_out, growth_rate, bn_size,
"(%s, %s)" % (index, i))
"%s_%s" % (index, i))
return layer_out
def _make_transition(data, num_output_features, index):
......
......@@ -29,7 +29,7 @@ class Type(RelayNode):
"""Compare two Relay types for structural equivalence using
alpha equivalence.
"""
return bool(_make._type_alpha_equal(self, other))
return bool(_make._alpha_equal(self, other))
def __ne__(self, other):
return not self.__eq__(other)
......
......@@ -18,7 +18,7 @@
*/
/*!
* Copyright (c) 2018 by Contributors
* Copyright (c) 2019 by Contributors
* \file src/tvm/relay/ir/alpha_equal.cc
* \brief Alpha equality check by deep comparing two nodes.
*/
......@@ -27,9 +27,10 @@
#include <tvm/relay/pattern_functor.h>
#include <tvm/runtime/ndarray.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/attrs/nn.h>
#include "type_functor.h"
#include "../../lang/attr_functor.h"
namespace tvm {
namespace relay {
......@@ -40,8 +41,8 @@ class AlphaEqualHandler:
public ExprFunctor<bool(const Expr&, const Expr&)>,
public PatternFunctor<bool(const Pattern&, const Pattern&)> {
public:
explicit AlphaEqualHandler(bool map_free_var)
: map_free_var_(map_free_var) { }
explicit AlphaEqualHandler(bool map_free_var, bool assert_mode)
: map_free_var_(map_free_var), assert_mode_(assert_mode) { }
/*!
* Check equality of two nodes.
......@@ -76,6 +77,9 @@ class AlphaEqualHandler:
return AttrEqual(lhs, rhs);
}
bool DoubleEqual(double l, double r) {
return true;
}
/*!
* Check equality of two attributes.
* \param lhs The left hand operand.
......@@ -83,9 +87,9 @@ class AlphaEqualHandler:
* \return The comparison result.
*/
bool AttrEqual(const NodeRef& lhs, const NodeRef& rhs) {
auto compute = [&]() {
if (&lhs == &rhs) return true;
auto lhsd = lhs.as<DictAttrsNode>();
if (lhsd) {
if (auto lhsd = lhs.as<DictAttrsNode>()) {
auto rhsd = lhs.as<DictAttrsNode>();
if (!rhsd) return false;
if (lhsd->dict.size() != rhsd->dict.size()) return false;
......@@ -94,7 +98,17 @@ class AlphaEqualHandler:
}
return true;
}
if (auto lhsbn = lhs.as<BatchNormAttrs>()) {
auto rhsbn = rhs.as<BatchNormAttrs>();
if (!rhsbn) return false;
return (lhsbn->axis == rhsbn->axis)
&& DoubleEqual(lhsbn->epsilon, rhsbn->epsilon)
&& (lhsbn->center == rhsbn->center)
&& (lhsbn->scale == rhsbn->scale);
}
return AttrsEqualHandler::Equal(lhs, rhs);
};
return Compare(compute(), lhs, rhs);
}
/*!
* Check equality of two types.
......@@ -107,6 +121,13 @@ class AlphaEqualHandler:
if (!lhs.defined() || !rhs.defined()) return false;
return this->VisitType(lhs, rhs);
}
bool Compare(bool result, const NodeRef& lhs, const NodeRef& rhs) {
if (assert_mode_) {
CHECK(result) << "\n" << AsText(lhs, true) << "\nis not equal to:\n" << AsText(rhs, true);
}
return result;
}
/*!
* Check equality of two expressions.
*
......@@ -120,6 +141,7 @@ class AlphaEqualHandler:
* \return The comparison result.
*/
bool ExprEqual(const Expr& lhs, const Expr& rhs) {
auto compute = [&]() {
if (lhs.same_as(rhs)) return true;
if (!lhs.defined() || !rhs.defined()) return false;
auto it = equal_map_.find(lhs);
......@@ -132,6 +154,8 @@ class AlphaEqualHandler:
} else {
return false;
}
};
return Compare(compute(), lhs, rhs);
}
protected:
......@@ -516,32 +540,41 @@ class AlphaEqualHandler:
private:
// whether to map open terms.
bool map_free_var_;
// if in assert mode, must return true, and will throw error otherwise.
bool assert_mode_;
// renaming of NodeRef to indicate two nodes equals to each other
std::unordered_map<NodeRef, NodeRef, NodeHash, NodeEqual> equal_map_;
};
bool AlphaEqual(const Type& lhs, const Type& rhs) {
return AlphaEqualHandler(false).TypeEqual(lhs, rhs);
return AlphaEqualHandler(false, false).TypeEqual(lhs, rhs);
}
bool AlphaEqual(const Expr& lhs, const Expr& rhs) {
return AlphaEqualHandler(false).ExprEqual(lhs, rhs);
return AlphaEqualHandler(false, false).ExprEqual(lhs, rhs);
}
// TODO(@jroesch): move to correct namespace?
TVM_REGISTER_API("relay._make._alpha_equal")
.set_body_typed<bool(NodeRef, NodeRef)>([](NodeRef a, NodeRef b) {
return AlphaEqualHandler(false).Equal(a, b);
return AlphaEqualHandler(false, false).Equal(a, b);
});
TVM_REGISTER_API("relay._make._type_alpha_equal")
.set_body_typed<bool(Type, Type)>([](Type a, Type b) {
return AlphaEqualHandler(false).TypeEqual(a, b);
TVM_REGISTER_API("relay._make._assert_alpha_equal")
.set_body_typed<void(NodeRef, NodeRef)>([](NodeRef a, NodeRef b) {
bool alpha_equal = AlphaEqualHandler(false, true).Equal(a, b);
CHECK(alpha_equal) << AsText(a, true) << " and " << AsText(b, true) << " is not alpha equal";
});
TVM_REGISTER_API("relay._make._graph_equal")
.set_body_typed<bool(NodeRef, NodeRef)>([](NodeRef a, NodeRef b) {
return AlphaEqualHandler(true).Equal(a, b);
return AlphaEqualHandler(true, false).Equal(a, b);
});
TVM_REGISTER_API("relay._make._assert_graph_equal")
.set_body_typed<void(NodeRef, NodeRef)>([](NodeRef a, NodeRef b) {
bool graph_equal = AlphaEqualHandler(true, true).Equal(a, b);
CHECK(graph_equal) << AsText(a, true) << " and " << AsText(b, true) << " is not graph equal";
});
} // namespace relay
......
......@@ -89,7 +89,7 @@ std::string Doc::str() {
return os.str();
}
Doc PrintVec(const std::vector<Doc>& vec, const Doc& sep) {
Doc PrintSep(const std::vector<Doc>& vec, const Doc& sep) {
Doc seq;
if (vec.size() != 0) {
seq = vec[0];
......
......@@ -46,7 +46,11 @@ using DocAtom = std::shared_ptr<DocAtomNode>;
struct TextNode : DocAtomNode {
std::string str;
explicit TextNode(const std::string& str) : str(str) {}
explicit TextNode(const std::string& str) : str(str) {
if (str.find_first_of("\t\n") != str.npos) {
LOG(WARNING) << "text node: '" << str << "' should not has tab or newline.";
}
}
};
struct LineNode : DocAtomNode {
......@@ -91,8 +95,8 @@ class Doc {
// DSL functions
// Render vectors of docs with a separator. e.g. PrintVec([1, 2, 3], f) -> 1f2f3
Doc PrintVec(const std::vector<Doc>& vec, const Doc& sep = Doc(", "));
// Render vectors of docs with a separator. e.g. PrintSep([1, 2, 3], f) -> 1f2f3
Doc PrintSep(const std::vector<Doc>& vec, const Doc& sep = Doc(", "));
// Print a constant bool value.
Doc PrintBool(bool value);
// Print a data type.
......@@ -116,7 +120,8 @@ Doc PrintConstScalar(DataType dtype, const T* data) {
} else if (dtype == Bool()) {
return PrintBool(data[0] != 0);
} else {
os << dtype << "(" << data[0] << ")";
// todo(@M.K.) this is unsafe. fix.
os << data[0];
}
return Doc(os.str());
}
......
......@@ -32,6 +32,7 @@
* - Otherwise, inline if the node is at the end of a scope and is used at most once.
*/
#include <dmlc/json.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/module.h>
#include <tvm/relay/pattern_functor.h>
......@@ -43,6 +44,17 @@
namespace tvm {
namespace relay {
Doc Brace(const Doc& d,
const std::string& open = "{",
const std::string& close = "}",
int indent = 2) {
Doc doc;
doc << open;
doc << Indent(indent, PrintNewLine() << d) << PrintNewLine();
doc << close;
return doc;
}
/*!
* \brief Meta data context for PrettyPrinter.
*
......@@ -108,8 +120,10 @@ class TextMetaDataContext {
if (it != meta_repr_.end()) {
return it->second;
}
std::string type_key = node->type_key();
CHECK(!type_key.empty());
Array<NodeRef>& mvector =
meta_data_[node->type_key()];
meta_data_[type_key];
int64_t index = static_cast<int64_t>(mvector.size());
mvector.push_back(node);
Doc doc;
......@@ -117,14 +131,18 @@ class TextMetaDataContext {
meta_repr_[node] = doc;
return meta_repr_[node];
}
Doc PrintKeyValue(const std::string& str, const Doc& v) const {
return Doc("\"") << str << "\": " << v;
}
/*!
* \brief Get the metadata section in json format.
* \return the meta data string.
*/
std::string GetMetaSection() const {
if (meta_data_.size() == 0) return std::string();
return SaveJSON(Map<std::string, NodeRef>(
meta_data_.begin(), meta_data_.end()));
Doc GetMetaSection() const {
if (meta_data_.size() == 0) return Doc();
return Doc(SaveJSON(Map<std::string, NodeRef>(meta_data_.begin(), meta_data_.end())));
}
/*! \return whether the meta data context is empty. */
......@@ -172,12 +190,11 @@ class PrettyPrinter :
}
// indent a new body
// TODO(jmp): indent should be an instance variable of the printer
Doc PrintBody(const NodeRef& node, int indent = 2) {
Doc doc;
Doc body;
doc << "{";
doc << Indent(indent, body << "\n" << PrintScope(node)) << "\n";
doc << Indent(indent, body << PrintNewLine() << PrintScope(node)) << PrintNewLine();
doc << "}";
return doc;
}
......@@ -203,13 +220,12 @@ class PrettyPrinter :
Doc doc;
doc << PrintScope(node);
if (!meta_.empty()) {
doc << PrintNewLine();
if (show_meta_data_) {
std::string meta_json = meta_.GetMetaSection();
// append meta data in the end.
doc << "\n" << "/* meta data */" << "\n" << meta_json;
doc << "METADATA:" << PrintNewLine() << meta_.GetMetaSection();
} else {
doc << "\n"
<< "// meta data omitted. you can use show_meta_data=True to include meta data";
doc << "// meta data omitted. you can use show_meta_data=True to include meta data";
}
}
return doc;
......@@ -361,7 +377,7 @@ class PrettyPrinter :
// wrap GNFed let in brackets
Doc body;
printed_expr << "{";
printed_expr << Indent(2, body << "\n" << VisitExpr(expr)) << "\n";
printed_expr << Indent(2, body << PrintNewLine() << VisitExpr(expr)) << PrintNewLine();
printed_expr << "}";
} else {
printed_expr = VisitExpr(expr);
......@@ -373,7 +389,7 @@ class PrettyPrinter :
if (expr.as<VarNode>()) {
// This is our first time visiting the var and we hit the VarNode case
// in the visitor. Thus the variable is free.
doc_stack_.back() << "free_var " << printed_expr << "\n";
doc_stack_.back() << "free_var " << printed_expr << PrintNewLine();
// Memoization is done in AllocVar.
return memo_[expr];
} else if (inline_expr) {
......@@ -422,7 +438,7 @@ class PrettyPrinter :
fields.push_back(Print(field));
}
Doc doc;
doc << "(" << PrintVec(fields);
doc << "(" << PrintSep(fields);
// conform to python tuple format (1,)
if (op->fields.size() == 1) {
doc << ",";
......@@ -468,7 +484,7 @@ class PrettyPrinter :
for (const TypeVar& tv : fn->type_params) {
type_params.push_back(AllocTypeVar(tv));
}
doc << PrintVec(type_params);
doc << PrintSep(type_params);
doc << ">";
}
doc << "(";
......@@ -479,7 +495,7 @@ class PrettyPrinter :
for (const Doc& d : PrintFuncAttrs(fn->attrs)) {
params.push_back(d);
}
doc << PrintVec(params) << ") ";
doc << PrintSep(params) << ") ";
if (fn->ret_type.defined()) {
doc << "-> " << Print(fn->ret_type) << " ";
}
......@@ -493,13 +509,13 @@ class PrettyPrinter :
for (const auto& kv : mod->functions) {
dg_ = DependencyGraph::Create(&arena_, kv.second);
std::ostringstream os;
if (counter++ != 0) {
doc << "\n";
doc << PrintNewLine();
}
std::ostringstream os;
os << "def @" << kv.first->name_hint;
doc << PrintFunc(Doc(os.str()), kv.second);
doc << "\n";
doc << PrintNewLine();
}
return doc;
}
......@@ -528,7 +544,7 @@ class PrettyPrinter :
args.push_back(d);
}
doc << Print(op->op);
return doc << "(" << PrintVec(args) << ")";
return doc << "(" << PrintSep(args) << ")";
}
Doc VisitExpr_(const RefCreateNode* op) final {
......@@ -558,7 +574,7 @@ class PrettyPrinter :
clauses.push_back(clause_doc << Print(clause->lhs) << " -> "
<< Print(clause->rhs));
}
doc << Indent(2, body << "\n" << PrintVec(clauses, Doc("\n"))) << "\n";
doc << Indent(2, body << PrintNewLine() << PrintSep(clauses, PrintNewLine())) << PrintNewLine();
doc << "}";
return doc;
}
......@@ -570,7 +586,7 @@ class PrettyPrinter :
for (const auto& pat : p->patterns) {
pats.push_back(Print(pat));
}
return doc << PrintVec(pats) << ")";
return doc << PrintSep(pats) << ")";
}
Doc VisitPattern_(const PatternVarNode* pv) final {
......@@ -617,7 +633,7 @@ class PrettyPrinter :
args.push_back(PrintType(t, false));
}
doc << "[";
doc << PrintVec(args);
doc << PrintSep(args);
doc << "]";
return doc;
}
......@@ -633,11 +649,7 @@ class PrettyPrinter :
for (NodeRef shape : node->shape) {
shapes.push_back(PrintAttr(shape));
}
doc << PrintVec(shapes);
// conform to python tuple format (1,)
if (node->shape.size() == 1) {
doc << ",";
}
doc << PrintSep(shapes);
return doc << "), " << PrintDType(node->dtype) << "]";
}
......@@ -647,7 +659,7 @@ class PrettyPrinter :
fields.push_back(Print(field));
}
Doc doc;
doc << "(" << PrintVec(fields);
doc << "(" << PrintSep(fields);
// conform to python tuple format (1,)
if (node->fields.size() == 1) {
doc << ",";
......@@ -664,14 +676,14 @@ class PrettyPrinter :
for (Type type_param : node->type_params) {
type_params.push_back(Print(type_param));
}
doc << PrintVec(type_params);
doc << PrintSep(type_params);
doc << ">";
}
std::vector<Doc> arg_types;
for (Type arg_type : node->arg_types) {
arg_types.push_back(Print(arg_type));
}
return doc << "(" << PrintVec(arg_types) << ") -> " << Print(node->ret_type);
return doc << "(" << PrintSep(arg_types) << ") -> " << Print(node->ret_type);
}
Doc VisitType_(const RefTypeNode* node) final {
......@@ -710,7 +722,7 @@ class PrettyPrinter :
for (NodePtr<Node> val : op->data) {
arr_vals.push_back(PrintAttr(NodeRef(val)));
}
doc << PrintVec(arr_vals);
doc << PrintSep(arr_vals);
doc << "]";
return doc;
}
......@@ -771,7 +783,9 @@ class PrettyPrinter::AttrPrinter : public AttrVisitor {
}
void Visit(const char* key, double* value) final {
PrintKV(key, *value);
Doc doc;
doc << key << "=" << *value << "f";
docs->push_back(doc);
}
void Visit(const char* key, int64_t* value) final {
PrintKV(key, *value);
......@@ -843,7 +857,7 @@ std::string PrettyPrint_(const NodeRef& node,
bool show_meta_data,
runtime::TypedPackedFunc<std::string(Expr)> annotate) {
Doc doc;
doc << "v0.0.3" << "\n"
doc << "v0.0.3" << PrintNewLine()
<< PrettyPrinter(show_meta_data, annotate).PrintFinal(node);
return doc.str();
}
......
......@@ -16,7 +16,7 @@
# under the License.
import tvm
from tvm import relay
from tvm.relay.analysis import alpha_equal
from tvm.relay.analysis import alpha_equal, assert_alpha_equal
from nose.tools import nottest, raises
from numpy import isclose
from typing import Union
......@@ -60,12 +60,9 @@ TYPES = {
"float16x4",
}
def assert_alpha_equal(a, b):
if not alpha_equal(a, b):
raise Exception("lhs is: ", str(a), "rhs is: ", str(b))
def roundtrip(expr):
assert_alpha_equal(relay.fromtext(str(expr)), expr)
x = relay.fromtext(str(expr))
assert_alpha_equal(x, expr)
def parse_text(code):
......@@ -112,6 +109,16 @@ def test_comments():
UNIT
)
assert parses_as(
"""
/* This is a block comment!
/*Block comment is recursive!*/
*/
()
""",
UNIT
)
def test_int_literal():
assert isinstance(parse_text("1"), relay.Constant)
......@@ -224,7 +231,7 @@ def test_let():
def test_seq():
assert parses_as(
"(); ()",
"();; ()",
relay.Let(
_,
UNIT,
......@@ -538,7 +545,7 @@ def test_tensor_type():
)
assert parses_as(
"let %_ : Tensor[(1,), float32] = (); ()",
"let %_ : Tensor[(1), float32] = (); ()",
relay.Let(
relay.Var("_", relay.TensorType((1,), "float32")),
UNIT,
......
......@@ -15,14 +15,27 @@
# specific language governing permissions and limitations
# under the License.
import tvm
from tvm import relay
import tvm.relay.testing
import numpy as np
from tvm import relay
from tvm.relay import Expr
from tvm.relay.analysis import alpha_equal, assert_alpha_equal, assert_graph_equal, free_vars
do_print = [False]
SEMVER = "v0.0.3\n"
def astext(p, graph_equal=False):
txt = p.astext()
if isinstance(p, Expr) and free_vars(p):
return txt
x = relay.fromtext(txt)
if graph_equal:
assert_graph_equal(x, p)
else:
assert_alpha_equal(x, p)
return txt
def show(text):
if do_print[0]:
print("---------------------------")
......@@ -35,8 +48,8 @@ def test_func():
z = relay.add(x, one)
z = relay.add(z, z)
f = relay.Function([x, y], z)
show(z.astext())
show(f.astext())
show(astext(z))
show(astext(f))
def test_env():
......@@ -47,7 +60,7 @@ def test_env():
f = relay.Function([x, y], z)
env = relay.Module()
env["myf"] = f
text = env.astext()
text = astext(env)
assert "def @myf" in text
assert "def @myf" in str(env)
assert "add(%0, %0) /* ty=float32 */" in text
......@@ -65,7 +78,7 @@ def test_meta_data():
padding=(1, 1),
channels=2)
f = relay.Function([x, w], z)
text = f.astext()
text = astext(f, graph_equal=True)
text_no_meta = str(f)
assert "channels=2" in text
assert "channels=2" in text_no_meta
......@@ -73,25 +86,22 @@ def test_meta_data():
assert "meta[Variable][0]" in text_no_meta
assert "type_key" in text
assert "type_key" not in text_no_meta
show(text)
show(f)
text = relay.const([1,2,3]).astext()
text = astext(relay.const([1,2,3]))
assert "meta[relay.Constant][0]" in text
show(text)
def test_call_attrs():
x = relay.var("x")
# non default args
z = relay.nn.softmax(x, axis=2)
assert "axis=2" in z.astext()
assert "axis=2" in astext(z)
# default args
z = relay.nn.softmax(x)
assert "softmax(%x)" in z.astext()
assert "softmax(%x)" in astext(z)
# non default args
z = relay.expand_dims(x, axis=2, num_newaxis=2)
assert "num_newaxis=2" in z.astext()
assert "num_newaxis=2" in astext(z)
def test_let_if_scope():
......@@ -111,68 +121,72 @@ def test_let_if_scope():
result = sb.get()
f = relay.Function([x, y, cond], result)
text = f.astext()
text = astext(f)
assert text.count("{") == 4
assert "%cond: bool" in text
show(f.astext())
show(astext(f))
def test_variable_name():
# avoid pure number even if the namehint is pure number
v1 = relay.var("1")
assert "%v1" in v1.astext()
assert "%v1" in astext(v1)
def test_mlp():
net, params = tvm.relay.testing.mlp.get_workload(batch_size=1)
net.astext()
astext(net)
def test_resnet():
net, params = tvm.relay.testing.resnet.get_workload(batch_size=1)
net.astext()
astext(net)
def test_mobilenet():
net, params = tvm.relay.testing.mobilenet.get_workload(batch_size=1)
net.astext()
astext(net)
def test_dqn():
net, params = tvm.relay.testing.dqn.get_workload(batch_size=1)
net.astext()
astext(net)
def test_dcgan():
net, params = tvm.relay.testing.dcgan.get_workload(batch_size=1)
net.astext()
astext(net)
def test_lstm():
net, params = tvm.relay.testing.lstm.get_workload(1, 1)
astext(net)
net, params = tvm.relay.testing.lstm.get_workload(4, 4)
net.astext()
astext(net)
def test_inception_v3():
net, params = tvm.relay.testing.inception_v3.get_workload(batch_size=1)
net.astext()
astext(net)
def test_squeezenet():
for version in ['1.0', '1.1']:
net, params = tvm.relay.testing.squeezenet.get_workload(batch_size=1, version=version)
net.astext()
astext(net)
def test_vgg():
net, params = tvm.relay.testing.vgg.get_workload(batch_size=1)
net.astext()
astext(net)
def test_densenet():
net, params = tvm.relay.testing.densenet.get_workload(batch_size=1)
net.astext()
astext(net)
def test_call_node_order():
x = relay.var("x")
y = relay.var("y")
assert relay.Call(relay.Function([x], x), [relay.Call(relay.Function([y], y), [relay.const(1)])]).astext() == SEMVER + \
prog = relay.Call(relay.Function([x], x), [relay.Call(relay.Function([y], y), [relay.const(1)])])
assert astext(prog) == SEMVER + \
("%0 = fn (%y) {\n"
" %y\n"
"};\n"
......@@ -185,17 +199,25 @@ def test_call_node_order():
def test_let_inlining():
tup = relay.Tuple([relay.const(0), relay.const(0)])
x = relay.var("x")
assert relay.Let(x, tup, tup).astext() == SEMVER + \
assert astext(relay.Let(x, tup, tup)) == SEMVER + \
("%0 = (0, 0);\n"
"let %x = %0;\n"
"%0")
assert relay.Let(x, tup, x).astext() == SEMVER + \
assert astext(relay.Let(x, tup, x)) == SEMVER + \
("let %x = (0, 0);\n"
"%x")
def test_zeros():
x = relay.op.zeros([], "float32")
astext(x)
if __name__ == "__main__":
do_print[0] = True
test_lstm()
test_zeros()
test_meta_data()
test_let_inlining()
test_resnet()
test_mobilenet()
test_mlp()
......@@ -207,9 +229,7 @@ if __name__ == "__main__":
test_densenet()
test_func()
test_env()
test_meta_data()
test_call_attrs()
test_let_if_scope()
test_variable_name()
test_call_node_order()
test_let_inlining()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment