Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
2973f8a6
Commit
2973f8a6
authored
Jul 18, 2019
by
雾雨魔理沙
Committed by
Tianqi Chen
Jul 18, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[Relay] parser/pretty printer roundtripping (#3536)
parent
e5efc632
Hide whitespace changes
Inline
Side-by-side
Showing
16 changed files
with
759 additions
and
429 deletions
+759
-429
python/tvm/relay/_parser.py
+168
-52
python/tvm/relay/analysis.py
+31
-0
python/tvm/relay/grammar/Relay.g4
+60
-65
python/tvm/relay/grammar/py3/RelayLexer.py
+177
-138
python/tvm/relay/grammar/py3/RelayParser.py
+0
-0
python/tvm/relay/grammar/py3/RelayVisitor.py
+67
-22
python/tvm/relay/op/nn/nn.py
+34
-12
python/tvm/relay/parser.py
+4
-1
python/tvm/relay/testing/densenet.py
+1
-1
python/tvm/relay/ty.py
+1
-1
src/relay/ir/alpha_equal.cc
+69
-36
src/relay/ir/doc.cc
+1
-1
src/relay/ir/doc.h
+11
-6
src/relay/ir/pretty_printer.cc
+70
-56
tests/python/relay/test_ir_parser.py
+15
-8
tests/python/relay/test_ir_text_printer.py
+50
-30
No files found.
python/tvm/relay/_parser.py
View file @
2973f8a6
...
...
@@ -15,14 +15,14 @@
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, unused-
impor
t
# pylint: disable=invalid-name, unused-
argumen
t
"""A parser for Relay's text format."""
from
__future__
import
absolute_import
import
sys
from
ast
import
literal_eval
from
collections
import
deque
from
typing
import
TypeVar
,
Deque
,
Tuple
,
Optional
,
Union
,
NamedTuple
,
List
,
Callable
,
Any
,
Dict
import
tvm
...
...
@@ -32,6 +32,23 @@ from . import expr
from
.
import
ty
from
.
import
op
PYTHON_VERSION
=
sys
.
version_info
.
major
try
:
from
.grammar.py3.RelayVisitor
import
RelayVisitor
from
.grammar.py3.RelayParser
import
RelayParser
from
.grammar.py3.RelayLexer
import
RelayLexer
except
ImportError
:
raise
Exeption
(
"Couldn't find ANTLR parser. Try building with USE_ANTLR=ON."
)
try
:
from
antlr4
import
InputStream
,
CommonTokenStream
from
antlr4.error.ErrorListener
import
ErrorListener
except
ImportError
:
raise
Exception
(
"Couldn't find ANTLR runtime."
+
"Try running `pip{version} install antlr4-python{version}-runtime`."
.
format
(
version
=
PYTHON_VERSION
))
sys
.
setrecursionlimit
(
10000
)
class
ParseError
(
Exception
):
"""Exception type for parse errors."""
...
...
@@ -41,21 +58,50 @@ class ParseError(Exception):
super
(
ParseError
,
self
)
.
__init__
()
self
.
message
=
message
PYTHON_VERSION
=
sys
.
version_info
.
major
try
:
from
.grammar.py3.RelayVisitor
import
RelayVisitor
from
.grammar.py3.RelayParser
import
RelayParser
from
.grammar.py3.RelayLexer
import
RelayLexer
except
ImportError
:
raise
ParseError
(
"Couldn't find ANTLR parser. Try building with USE_ANTLR=ON."
)
def
__repr__
(
self
):
return
"ParseError({})"
.
format
(
self
.
message
)
try
:
from
antlr4
import
ParserRuleContext
,
InputStream
,
CommonTokenStream
from
antlr4.tree.Tree
import
TerminalNode
except
ImportError
:
raise
ParseError
(
"Couldn't find ANTLR runtime."
+
"Try running `pip{version} install antlr4-python{version}-runtime`."
.
format
(
version
=
PYTHON_VERSION
))
def
__str__
(
self
):
return
repr
(
self
)
class
OpWrapper
:
"""Overload the __call__ for op."""
pass
class
ExprOp
(
OpWrapper
):
"""Call an expr. The default, but does not handle attrs well."""
def
__init__
(
self
,
operator
):
self
.
operator
=
operator
def
__call__
(
self
,
args
,
attrs
,
type_args
):
try
:
return
expr
.
Call
(
self
.
operator
,
args
,
attrs
,
type_args
)
except
Exception
:
raise
Exception
(
str
(
self
.
operator
)
+
" "
+
str
(
attrs
))
class
FuncOp
(
OpWrapper
):
"""Convert the attrs, call the python function with the attrs passed in as keyword arguments.
Tvm should provide this in the future, as this is pretty similar to what op.get is providing.
"""
def
__init__
(
self
,
operator
):
self
.
operator
=
operator
def
convert
(
self
,
v
):
if
isinstance
(
v
,
tuple
):
return
tuple
([
self
.
convert
(
x
)
for
x
in
v
])
if
isinstance
(
v
,
expr
.
Constant
):
return
v
.
data
.
asnumpy
()
.
item
()
if
isinstance
(
v
,
str
):
return
v
raise
Exception
(
v
)
def
__call__
(
self
,
args
,
attrs
,
type_args
):
if
attrs
is
None
:
attrs
=
{}
x
=
self
.
operator
(
*
args
,
**
{
k
:
self
.
convert
(
v
)
for
k
,
v
in
attrs
.
items
()})
if
isinstance
(
x
,
expr
.
TupleWrapper
):
x
=
x
.
astuple
()
return
x
BINARY_OPS
=
{
RelayParser
.
MUL
:
op
.
multiply
,
...
...
@@ -70,6 +116,24 @@ BINARY_OPS = {
RelayParser
.
NE
:
op
.
not_equal
,
}
FUNC_OPS
=
{
"nn.conv2d"
:
op
.
nn
.
conv2d
,
"nn.batch_norm"
:
op
.
nn
.
batch_norm
,
"nn.dense"
:
op
.
nn
.
dense
,
"nn.bias_add"
:
op
.
nn
.
bias_add
,
"nn.max_pool2d"
:
op
.
nn
.
max_pool2d
,
"nn.global_max_pool2d"
:
op
.
nn
.
global_max_pool2d
,
"nn.avg_pool2d"
:
op
.
nn
.
avg_pool2d
,
"nn.global_avg_pool2d"
:
op
.
nn
.
global_avg_pool2d
,
"nn.softmax"
:
op
.
nn
.
softmax
,
"reshape"
:
op
.
reshape
,
"nn.conv2d_transpose"
:
op
.
nn
.
conv2d_transpose
,
"concatenate"
:
op
.
concatenate
,
"nn.dropout"
:
op
.
nn
.
dropout_raw
,
"zeros"
:
op
.
zeros
,
"split"
:
op
.
split
,
}
TYPE_PREFIXES
=
[
"int"
,
"uint"
,
...
...
@@ -77,9 +141,9 @@ TYPE_PREFIXES = [
"bool"
,
]
T
=
TypeVar
(
"T"
)
Scope
=
Deque
[
Tuple
[
str
,
T
]]
Scopes
=
Deque
[
Scope
[
T
]]
T
=
ty
.
TypeVar
(
"T"
)
#
Scope = Deque[Tuple[str, T]]
#
Scopes = Deque[Scope[T]]
def
lookup
(
scopes
,
name
):
# type: (Scopes[T], str) -> Optional[T]
...
...
@@ -108,6 +172,8 @@ def spanify(f):
ast
=
f
(
*
args
,
**
kwargs
)
line
,
col
=
ctx
.
getSourceInterval
()
sp
=
Span
(
sn
,
line
,
col
)
if
isinstance
(
ast
,
tvm
.
relay
.
expr
.
TupleWrapper
):
ast
=
ast
.
astuple
()
ast
.
set_span
(
sp
)
return
ast
return
_wrapper
...
...
@@ -179,6 +245,9 @@ class ParseTreeToRelayIR(RelayVisitor):
self
.
type_param_scopes
[
0
]
.
appendleft
((
name
,
typ
))
return
typ
def
visitProjection
(
self
,
ctx
):
return
expr
.
TupleGetItem
(
self
.
visit
(
ctx
.
expr
()),
self
.
visit
(
ctx
.
NAT
()))
def
visitTerminal
(
self
,
node
):
# type: (TerminalNode) -> Union[expr.Expr, int, float]
"""Visit lexer tokens that aren't ignored or visited by other functions."""
...
...
@@ -213,12 +282,15 @@ class ParseTreeToRelayIR(RelayVisitor):
if
node_text
==
"False"
:
return
False
raise
ParseError
(
"Unrecognized BOOL_LIT: `{}`"
.
format
(
node_text
))
if
node_type
==
RelayLexer
.
QUOTED_STRING
:
return
literal_eval
(
node_text
)
raise
ParseError
(
"todo:
{}
"
.
format
(
node_text
))
raise
ParseError
(
"todo:
`{}`
"
.
format
(
node_text
))
def
visit_list
(
self
,
ctx_list
):
# type: (List[ParserRuleContext]) -> List[Any]
""""Visit a list of contexts."""
assert
isinstance
(
ctx_list
,
list
)
return
[
self
.
visit
(
ctx
)
for
ctx
in
ctx_list
]
...
...
@@ -232,6 +304,11 @@ class ParseTreeToRelayIR(RelayVisitor):
return
self
.
visit
(
ctx
)
def
visitProg
(
self
,
ctx
):
self
.
meta
=
None
if
ctx
.
METADATA
():
header
,
data
=
str
(
ctx
.
METADATA
())
.
split
(
'
\n
'
,
1
)
assert
header
==
"METADATA:"
self
.
meta
=
tvm
.
load_json
(
data
)
# type: (RelayParser.ProgContext) -> Union[expr.Expr, module.Module]
if
ctx
.
defn
():
self
.
visit_list
(
ctx
.
defn
())
...
...
@@ -245,11 +322,14 @@ class ParseTreeToRelayIR(RelayVisitor):
# Exprs
def
visitOpIdent
(
self
,
ctx
):
# type: (RelayParser.OpIdentContext) -> op.Op
return
op
.
get
(
ctx
.
CNAME
()
.
getText
())
op_name
=
ctx
.
CNAME
()
.
getText
()
if
op_name
in
FUNC_OPS
:
return
FuncOp
(
FUNC_OPS
[
op_name
])
return
ExprOp
(
op
.
get
(
op_name
))
# pass through
def
visitParen
s
(
self
,
ctx
):
# type: (RelayParser.Paren
s
Context) -> expr.Expr
def
visitParen
(
self
,
ctx
):
# type: (RelayParser.ParenContext) -> expr.Expr
return
self
.
visit
(
ctx
.
expr
())
# pass through
...
...
@@ -283,25 +363,17 @@ class ParseTreeToRelayIR(RelayVisitor):
tup
=
self
.
visit_list
(
ctx
.
expr
())
return
expr
.
Tuple
(
tup
)
# Currently doesn't support mutable sequencing.
def
visitLet
(
self
,
ctx
):
# type: (RelayParser.SeqContext) -> expr.Let
"""Desugar various sequence constructs to Relay Let nodes."""
if
ctx
.
MUT
()
is
not
None
:
raise
ParseError
(
"Mutation is currently unsupported."
)
if
ctx
.
var
()
is
None
or
ctx
.
var
()
.
ident
()
is
None
:
if
ctx
.
var
()
is
None
:
# anonymous identity
ident
=
"_"
type_
=
None
var
=
self
.
mk_var
(
ident
,
type_
)
else
:
local_var
=
ctx
.
var
()
.
ident
()
.
LOCAL_VAR
()
if
local_var
is
None
:
raise
ParseError
(
"Only local ids may be used in `let`s."
)
ident
=
local_var
.
getText
()[
1
:]
type_
=
self
.
getType_
(
ctx
.
var
()
.
type_
())
var
=
self
.
mk_var
(
ident
,
type_
)
var
=
self
.
visitVar
(
ctx
.
var
())
self
.
enter_var_scope
()
value
=
self
.
visit
(
ctx
.
expr
(
0
))
...
...
@@ -326,7 +398,7 @@ class ParseTreeToRelayIR(RelayVisitor):
def
visitVar
(
self
,
ctx
):
# type: (RelayParser.VarContext) -> expr.Var
"""Visit a single variable."""
ident
=
ctx
.
ident
()
.
LOCAL_VAR
()
ident
=
ctx
.
LOCAL_VAR
()
if
ident
is
None
:
raise
ParseError
(
"Only local ids may be used in vars."
)
...
...
@@ -344,19 +416,29 @@ class ParseTreeToRelayIR(RelayVisitor):
# type: (RelayParser.AttrContext) -> Tuple[str, expr.Expr]
return
(
ctx
.
CNAME
()
.
getText
(),
self
.
visit
(
ctx
.
expr
()))
def
visitAttrList
(
self
,
ctx
):
def
visitArgNoAttr
(
self
,
ctx
):
return
(
self
.
visit_list
(
ctx
.
varList
()
.
var
()),
None
)
def
visitAttrSeq
(
self
,
ctx
):
# type: (RelayParser.AttrListContext) -> Dict[str, expr.Expr]
return
dict
(
self
.
visit_list
(
ctx
.
attr
()))
def
visitArgWithAttr
(
self
,
ctx
):
return
(
self
.
visit_list
(
ctx
.
var
()),
self
.
visitAttrSeq
(
ctx
.
attrSeq
()))
def
visitArgList
(
self
,
ctx
# type: RelayParser.ArgListContext
):
# type: (...) -> Tuple[Optional[List[expr.Var]], Optional[Dict[str, expr.Expr]]]
var_list
=
self
.
visit
(
ctx
.
varList
())
if
ctx
.
varList
()
else
None
attr_list
=
self
.
visit
(
ctx
.
attrList
())
if
ctx
.
attrList
()
else
None
return
(
var_list
,
attr_list
)
def
visitMeta
(
self
,
ctx
):
type_key
=
str
(
ctx
.
CNAME
())
index
=
int
(
self
.
visit
(
ctx
.
NAT
()))
return
self
.
meta
[
type_key
][
index
]
def
mk_func
(
self
,
ctx
):
# type: (Union[RelayParser.FuncContext, RelayParser.DefnContext]) -> expr.Function
"""Construct a function from either a Func or Defn."""
...
...
@@ -365,7 +447,7 @@ class ParseTreeToRelayIR(RelayVisitor):
self
.
enter_var_scope
()
# Capture type params in params.
self
.
enter_type_param_scope
()
type_params
=
ctx
.
typeParam
Seq
()
type_params
=
ctx
.
typeParam
List
()
if
type_params
is
not
None
:
type_params
=
type_params
.
ident
()
...
...
@@ -405,18 +487,25 @@ class ParseTreeToRelayIR(RelayVisitor):
raise
ParseError
(
"Only global ids may be used in `def`s."
)
ident_name
=
ident
.
getText
()[
1
:]
ident
=
self
.
mk_global_var
(
ident_name
)
self
.
module
[
ident
]
=
self
.
mk_func
(
ctx
)
def
visitCallNoAttr
(
self
,
ctx
):
return
(
self
.
visit_list
(
ctx
.
exprList
()
.
expr
()),
None
)
def
visitCallWithAttr
(
self
,
ctx
):
return
(
self
.
visit_list
(
ctx
.
expr
()),
self
.
visit
(
ctx
.
attrSeq
()))
def
call
(
self
,
func
,
args
,
attrs
,
type_args
):
if
isinstance
(
func
,
OpWrapper
):
return
func
(
args
,
attrs
,
type_args
)
return
expr
.
Call
(
func
,
args
,
attrs
,
type_args
)
@spanify
def
visitCall
(
self
,
ctx
):
# type: (RelayParser.CallContext) -> expr.Call
visited_exprs
=
self
.
visit_list
(
ctx
.
expr
())
func
=
visited_exprs
[
0
]
args
=
visited_exprs
[
1
:]
return
expr
.
Call
(
func
,
args
,
None
,
None
)
func
=
self
.
visit
(
ctx
.
expr
())
args
,
attrs
=
self
.
visit
(
ctx
.
callList
())
return
self
.
call
(
func
,
args
,
attrs
,
[])
@spanify
def
visitIfElse
(
self
,
ctx
):
...
...
@@ -438,9 +527,7 @@ class ParseTreeToRelayIR(RelayVisitor):
def
visitGraph
(
self
,
ctx
):
# type: (RelayParser.GraphContext) -> expr.Expr
"""Visit a graph variable assignment."""
if
ctx
.
ident
()
.
GRAPH_VAR
()
is
None
:
raise
ParseError
(
"Expected a graph var, but got `{}`"
.
format
(
ctx
.
ident
()
.
getText
()))
graph_nid
=
int
(
ctx
.
ident
()
.
GRAPH_VAR
()
.
getText
()[
1
:])
graph_nid
=
int
(
ctx
.
GRAPH_VAR
()
.
getText
()[
1
:])
self
.
enter_var_scope
()
value
=
self
.
visit
(
ctx
.
expr
(
0
))
...
...
@@ -500,15 +587,18 @@ class ParseTreeToRelayIR(RelayVisitor):
# type: (RelayParser.ParensShapeContext) -> int
return
self
.
visit
(
ctx
.
shape
())
def
visitShape
Seq
(
self
,
ctx
):
# type: (RelayParser.Shape
Seq
Context) -> List[int]
def
visitShape
List
(
self
,
ctx
):
# type: (RelayParser.Shape
List
Context) -> List[int]
return
self
.
visit_list
(
ctx
.
shape
())
def
visitTensor
(
self
,
ctx
):
return
tuple
(
self
.
visit_list
(
ctx
.
expr
()))
def
visitTensorType
(
self
,
ctx
):
# type: (RelayParser.TensorTypeContext) -> ty.TensorType
"""Create a simple tensor type. No generics."""
shape
=
self
.
visit
(
ctx
.
shape
Seq
())
shape
=
self
.
visit
(
ctx
.
shape
List
())
dtype
=
self
.
visit
(
ctx
.
type_
())
if
not
isinstance
(
dtype
,
ty
.
TensorType
):
...
...
@@ -536,11 +626,37 @@ def make_parser(data):
"""Construct a RelayParser a given data stream."""
input_stream
=
InputStream
(
data
)
lexer
=
RelayLexer
(
input_stream
)
lexer
.
addErrorListener
(
StrictErrorListener
(
data
))
token_stream
=
CommonTokenStream
(
lexer
)
return
RelayParser
(
token_stream
)
p
=
RelayParser
(
token_stream
)
p
.
addErrorListener
(
StrictErrorListener
(
data
))
return
p
__source_name_counter__
=
0
class
StrictErrorListener
(
ErrorListener
):
"""This ErrorListener fail eagerly on all error, and report the program."""
def
__init__
(
self
,
text
):
self
.
text
=
text
def
syntaxError
(
self
,
recognizer
,
offendingSymbol
,
line
,
column
,
msg
,
e
):
raise
Exception
(
"Syntax Error in:
\n
"
+
self
.
text
)
def
reportAmbiguity
(
self
,
recognizer
,
dfa
,
startIndex
,
stopIndex
,
exact
,
ambigAlts
,
configs
):
raise
Exception
(
"Ambiguity Error in:
\n
"
+
self
.
text
)
def
reportAttemptingFullContext
(
self
,
recognizer
,
dfa
,
startIndex
,
stopIndex
,
conflictingAlts
,
configs
):
raise
Exception
(
"Attempting Full Context in:
\n
"
+
self
.
text
)
def
reportContextSensitivity
(
self
,
recognizer
,
dfa
,
startIndex
,
stopIndex
,
prediction
,
configs
):
raise
Exception
(
"Context Sensitivity in:
\n
"
+
self
.
text
)
def
fromtext
(
data
,
source_name
=
None
):
# type: (str, str) -> Union[expr.Expr, module.Module]
"""Parse a Relay program."""
...
...
python/tvm/relay/analysis.py
View file @
2973f8a6
...
...
@@ -224,6 +224,20 @@ def alpha_equal(lhs, rhs):
return
bool
(
_make
.
_alpha_equal
(
lhs
,
rhs
))
def
assert_alpha_equal
(
lhs
,
rhs
):
"""Assert that two Relay expr is structurally equivalent. (alpha equivalence).
Parameters
----------
lhs : tvm.relay.Expr
One of the input Expression.
rhs : tvm.relay.Expr
One of the input Expression.
"""
_make
.
_assert_alpha_equal
(
lhs
,
rhs
)
def
graph_equal
(
lhs
,
rhs
):
"""Compare two Relay expr for data-flow equivalence.
The difference between this and alpha-equality is that
...
...
@@ -246,6 +260,23 @@ def graph_equal(lhs, rhs):
return
bool
(
_make
.
_graph_equal
(
lhs
,
rhs
))
def
assert_graph_equal
(
lhs
,
rhs
):
"""Compare two Relay expr for data-flow equivalence.
The difference between this and alpha-equality is that
variables are not expected to match between lhs and rhs;
they are treated as sources and are mapped between each other.
Parameters
----------
lhs : tvm.relay.Expr
One of the input Expression.
rhs : tvm.relay.Expr
One of the input Expression.
"""
_make
.
_assert_graph_equal
(
lhs
,
rhs
)
def
collect_device_info
(
expr
):
"""Collect the device allocation map for the given expression. The device
ids are propagated from the `device_copy` operators.
...
...
python/tvm/relay/grammar/Relay.g4
View file @
2973f8a6
...
...
@@ -17,15 +17,20 @@
* under the License.
*/
// list = *, seq = ?
grammar Relay;
SEMVER: 'v0.0.3' ;
// Lexing
// comments
WS : [ \t\n\r]+ -> skip ;
LINE_COMMENT : '//' .*? '\n' -> skip ;
COMMENT : '/*' .*? '*/' -> skip ;
COMMENT : '/*' (COMMENT|.)*? '*/' -> skip;
WS : [ \t\n\r]+ -> skip;
LINE_COMMENT : '//' .*? '\n' -> skip;
fragment ESCAPED_QUOTE : '\\"';
QUOTED_STRING : '"' ( ESCAPED_QUOTE | ~('\n'|'\r') )*? '"';
// operators
MUL: '*' ;
...
...
@@ -39,18 +44,18 @@ GE: '>=' ;
EQ: '==' ;
NE: '!=' ;
opIdent: CNAME ;
GLOBAL_VAR: '@' CNAME ;
LOCAL_VAR: '%' CNAME;
GRAPH_VAR: '%' NAT;
MUT: 'mut' ;
BOOL_LIT
: 'True'
| 'False'
;
CNAME: ('_'|LETTER) ('_'|LETTER|DIGIT)* ('.' CNAME)*;
opIdent: CNAME ;
GLOBAL_VAR: '@' CNAME ;
LOCAL_VAR: '%' CNAME;
GRAPH_VAR: '%' NAT;
DATATYPE : 'int64';
// non-negative floats
fragment PREFLOAT : NAT ('.' NAT)? EXP?; // 1.35, 1.35E-9, 0.3, 4.5, 1, 1e10 3e4
...
...
@@ -60,109 +65,99 @@ FLOAT : PREFLOAT 'f';
NAT: DIGIT+ ;
fragment EXP: [eE] [+\-]? NAT ; // \- since - means "range" inside [...]
CNAME: ('_'|LETTER) ('_'|LETTER|DIGIT)* ;
fragment LETTER: [a-zA-Z] ;
fragment DIGIT: [0-9] ;
fragment LETTER: [a-zA-Z];
fragment DIGIT: [0-9];
METADATA: 'METADATA:' .*;
// Parsing
// A Relay program is a list of global definitions or an expression.
prog: SEMVER (defn* | expr) EOF ;
prog: SEMVER (defn* | expr)
METADATA?
EOF ;
// option: 'set' ident BOOL_LIT ;
exprList: (expr (',' expr)*)?;
callList
: exprList # callNoAttr
| (expr ',')* attrSeq # callWithAttr
;
expr
// operators
: '(' expr ')' # parens
: '(' expr ')' # paren
| '{' expr '}' # paren
// function application
| expr '('
(expr (',' expr)*)? ')'
# call
| expr '('
callList ')'
# call
| '-' expr # neg
| expr op=('*'|'/') expr # binOp
| expr op=('+'|'-') expr # binOp
| expr op=('<'|'>'|'<='|'>=') expr # binOp
| expr op=('=='|'!=') expr # binOp
// function definition
| func # funcExpr
// tuples and tensors
| '(' ')' # tuple
| '(' expr ',' ')' # tuple
| '(' expr (',' expr)+ ')' # tuple
| expr '.' NAT # projection
| '[' (expr (',' expr)*)? ']' # tensor
| 'if' '(' expr ')' body 'else' body # ifElse
// sequencing
| 'let' MUT? var '=' expr ';' expr # let
| 'let' MUT? var '=' '{' expr '}' ';' expr # let
| 'let' var '=' expr ';' expr # let
// sugar for let %_ = expr; expr
| expr ';' expr # let
| ident '=' expr ';' expr # graph
// mutable update
// | ident '=' expr # writeRef
// | expr '^' # readRef
| expr ';;' expr # let
| GRAPH_VAR '=' expr ';' expr # graph
| ident # identExpr
| scalar # scalarExpr
// | expr '.' NAT # project
// | 'debug' # debug
| meta # metaExpr
| QUOTED_STRING # stringExpr
;
func: 'fn' typeParam
Seq
? '(' argList ')' ('->' type_)? body ;
defn: 'def' ident typeParam
Seq
? '(' argList ')' ('->' type_)? body ;
func: 'fn' typeParam
List
? '(' argList ')' ('->' type_)? body ;
defn: 'def' ident typeParam
List
? '(' argList ')' ('->' type_)? body ;
argList
: varList
| attrList
| varList ',' attrList
: varList # argNoAttr
| (var ',')* attrSeq # argWithAttr
;
varList: (var (',' var)*)?
;
var:
ident (':' type_)?
;
varList: (var (',' var)*)?;
var:
LOCAL_VAR (':' type_)?
;
attr
List: (attr (',' attr)*)?
;
attr
Seq: attr (',' attr)*
;
attr: CNAME '=' expr ;
// TODO(@jmp): for improved type annotations
// returnAnno: (ident ':')? type_ ;
// relations: 'where' relation (',' relation)* ;
// relation: ident '(' (type_ (',' type_)*)? ')' ;
typeParamSeq
typeParamList
: '[' ']'
| '[' ident (',' ident)* ']'
;
type_
: '(' ')' # tupleType
| '(' type_ ',' ')' # tupleType
| '(' type_ (',' type_)+ ')' # tupleType
| typeIdent # typeIdentType
| 'Tensor' '[' shapeSeq ',' type_ ']' # tensorType
// currently unused
// | typeIdent '[' (type_ (',' type_)*)? ']' # callType
| 'fn' typeParamSeq? '(' (type_ (',' type_)*)? ')' '->' type_ # funcType
| '_' # incompleteType
| NAT # intType
: '(' ')' # tupleType
| '(' type_ ',' ')' # tupleType
| '(' type_ (',' type_)+ ')' # tupleType
| typeIdent # typeIdentType
| 'Tensor' '[' shapeList ',' type_ ']' # tensorType
| 'fn' typeParamList? '(' (type_ (',' type_)*)? ')' '->' type_ # funcType
| '_' # incompleteType
| NAT # intType
;
shape
Seq
: '(' ')'
| '('
shape ','
')'
|
'(' shape (',' shape)+ ')'
shape
List
: '('
shape (',' shape)+
')'
| '(' ')'
|
shape
;
meta : 'meta' '[' CNAME ']' '[' NAT ']';
shape
: '(' shape ')' # parensShape
// | type_ op=('*'|'/') type_ # binOpType
// | type_ op=('+'|'-') type_ # binOpType
| NAT # intShape
: meta # metaShape
| '(' shape ')' # parensShape
| NAT # intShape
;
typeIdent : CNAME
;
typeIdent : CNAME;
// int8, int16, int32, int64
// uint8, uint16, uint32, uint64
// float16, float32, float64
...
...
python/tvm/relay/grammar/py3/RelayLexer.py
View file @
2973f8a6
...
...
@@ -7,116 +7,147 @@ import sys
def
serializedATN
():
with
StringIO
()
as
buf
:
buf
.
write
(
"
\3\u608b\ua72a\u8133\ub9ed\u417c\u3be7\u7786\u5964\2
*
"
)
buf
.
write
(
"
\u01
0d
\b\1\4\2\t\2\4\3\t\3\4\4\t\4\4\5\t\5\4\6\t\6\4\7
"
)
buf
.
write
(
"
\3\u608b\ua72a\u8133\ub9ed\u417c\u3be7\u7786\u5964\2
/
"
)
buf
.
write
(
"
\u01
4a
\b\1\4\2\t\2\4\3\t\3\4\4\t\4\4\5\t\5\4\6\t\6\4\7
"
)
buf
.
write
(
"
\t\7\4\b\t\b\4\t\t\t\4\n\t\n\4\13\t\13\4\f\t\f\4\r\t\r
"
)
buf
.
write
(
"
\4\16\t\16\4\17\t\17\4\20\t\20\4\21\t\21\4\22\t\22\4\23
"
)
buf
.
write
(
"
\t\23\4\24\t\24\4\25\t\25\4\26\t\26\4\27\t\27\4\30\t\30
"
)
buf
.
write
(
"
\4\31\t\31\4\32\t\32\4\33\t\33\4\34\t\34\4\35\t\35\4\36
"
)
buf
.
write
(
"
\t\36\4\37\t\37\4
\t
\4
!
\t
!
\4\"\t\"\4
#
\t
#
\4
$
\t
$
\4
%
\t
%
"
)
buf
.
write
(
"
\4
&
\t
&
\4\'\t\'\4
(
\t
(
\4
)
\t
)
\4
*
\t
*
\4
+
\t
+
\4
,
\t
,
\4
-
\t
-
\3\2
"
)
buf
.
write
(
"
\3\2\3\3\3\3\3\4\3\4\3\5\3\5\3\6\3\6\3\7\3\7\3\7\3\b\3
"
)
buf
.
write
(
"
\b\3\b\3\b\3\b\3\t\3\t\3\t\3\t\3\n\3\n\3\13\3\13\3\f\3
"
)
buf
.
write
(
"
\f\3\r\3\r\3\16\3\16\3\16\3\17\3\17\3\17\3\20\3\20\3\20
"
)
buf
.
write
(
"
\3\20\3\21\3\21\3\22\3\22\3\22\3\22\3\22\3\22\3\22\3\23
"
)
buf
.
write
(
"
\3\23\3\24\3\24\3\24\3\24\3\24\3\24\3\24\3\25\6\25\u0097
"
)
buf
.
write
(
"
\n\25\r\25\16\25\u0098\3\25\3\25\3\26\3\26\3\26\3\26\7
"
)
buf
.
write
(
"
\26\u00a1\n\26\f\26\16\26\u00a4\13\26\3\26\3\26\3\26\3
"
)
buf
.
write
(
"
\26\3\27\3\27\3\27\3\27\7\27\u00ae\n\27\f\27\16\27\u00b1
"
)
buf
.
write
(
"
\13\27\3\27\3\27\3\27\3\27\3\27\3\30\3\30\3\31\3\31\3
"
)
buf
.
write
(
"
\32\3\32\3\33\3\33\3\34\3\34\3\35\3\35\3\36\3\36\3\36
"
)
buf
.
write
(
"
\3\37\3\37\3\37\3
\3
\3
\3
!
\3
!
\3
!
\3\"\3\"\3\"\3
#
\3
#
\3
"
)
buf
.
write
(
"#
\3
$
\3
$
\3
$
\3
%
\3
%
\3
%
\3
%
\3
&
\3
&
\3
&
\3
&
\3
&
\3
&
\3
&
\3
&
\3
&
\5
&
\u00e6
"
)
buf
.
write
(
"
\n
&
\3\'\3\'\3\'\5\'\u00eb\n\'\3\'\5\'\u00ee\n\'\3
(
\3
("
)
buf
.
write
(
"
\3
(
\3
)
\6
)
\u00f4\n
)
\r
)
\16
)
\u00f5\3
*
\3
*
\5
*
\u00fa\n
*
\3
*
\3
"
)
buf
.
write
(
"*
\3
+
\3
+
\5
+
\u0100\n
+
\3
+
\3
+
\3
+
\7
+
\u0105\n
+
\f
+
\16
+
\u0108
"
)
buf
.
write
(
"
\13
+
\3
,
\3
,
\3
-
\3
-
\4\u00a2\u00af\2
.
\3\3\5\4\7\5\t\6\13\7
"
)
buf
.
write
(
"
\r\b\17\t\21\n\23\13\25\f\27\r\31\16\33\17\35\20\37\21
"
)
buf
.
write
(
"!
\22
#
\23
%
\24\'\25
)
\26
+
\27
-
\30
/
\31\61\32\63\33\65\34\67
"
)
buf
.
write
(
"
\35
9
\36
;
\37
= ?!A
\"
C#E$G
%
I&K
\'
M
\2
O(Q)S
\2
U*W
\2
Y
\2\3\2\7
"
)
buf
.
write
(
"
\5\2\13\f\17\17\"\"\4\2
GGgg
\4\2
--//
\4\2
C
\\
c|
\3\2\62
;
\2
"
)
buf
.
write
(
"
\u0114\2\3\3\2\2\2\2\5\3\2\2\2\2\7\3\2\2\2\2\t\3\2\2\2
"
)
buf
.
write
(
"
\2\13\3\2\2\2\2\r\3\2\2\2\2\17\3\2\2\2\2\21\3\2\2\2\2
"
)
buf
.
write
(
"
\23\3\2\2\2\2\25\3\2\2\2\2\27\3\2\2\2\2\31\3\2\2\2\2\33
"
)
buf
.
write
(
"
\3\2\2\2\2\35\3\2\2\2\2\37\3\2\2\2\2
!
\3\2\2\2\2
#
\3\2\2
"
)
buf
.
write
(
"
\2\2
%
\3\2\2\2\2\'\3\2\2\2\2
)
\3\2\2\2\2
+
\3\2\2\2\2
-
\3\2
"
)
buf
.
write
(
"
\2\2\2
/
\3\2\2\2\2\61\3\2\2\2\2\63\3\2\2\2\2\65\3\2\2\2
"
)
buf
.
write
(
"
\2\67\3\2\2\2\2
9
\3\2\2\2\2
;
\3\2\2\2\2
=
\3\2\2\2\2
?
\3\2
"
)
buf
.
write
(
"
\2\2\2
A
\3\2\2\2\2
C
\3\2\2\2\2
E
\3\2\2\2\2
G
\3\2\2\2\2
I
\3
"
)
buf
.
write
(
"
\2\2\2\2
K
\3\2\2\2\2
O
\3\2\2\2\2
Q
\3\2\2\2\2
U
\3\2\2\2\3
["
)
buf
.
write
(
"
\3\2\2\2\5
]
\3\2\2\2\7
_
\3\2\2\2\t
a
\3\2\2\2\13
c
\3\2\2\2
"
)
buf
.
write
(
"
\r
e
\3\2\2\2\17
h
\3\2\2\2\21
m
\3\2\2\2\23
q
\3\2\2\2\25
s
\3
"
)
buf
.
write
(
"
\2\2\2\27
u
\3\2\2\2\31
w
\3\2\2\2\33
y
\3\2\2\2\35
|
\3\2\2\2
"
)
buf
.
write
(
"
\37\177\3\2\2\2
!
\u0083\3\2\2\2
#
\u0085\3\2\2\2
%
\u008c\3
"
)
buf
.
write
(
"
\2\2\2\'\u008e\3\2\2\2
)
\u0096\3\2\2\2
+
\u009c\3\2\2\2
-"
)
buf
.
write
(
"
\u00a9\3\2\2\2
/
\u00b7\3\2\2\2\61\u00b9\3\2\2\2\63\u00bb
"
)
buf
.
write
(
"
\3\2\2\2\65\u00bd\3\2\2\2\67\u00bf\3\2\2\2
9
\u00c1\3\2
"
)
buf
.
write
(
"
\2\2
;
\u00c3\3\2\2\2
=
\u00c6\3\2\2\2
?
\u00c9\3\2\2\2
A
\u00cc
"
)
buf
.
write
(
"
\3\2\2\2
C
\u00cf\3\2\2\2
E
\u00d2\3\2\2\2
G
\u00d5\3\2\2\2
"
)
buf
.
write
(
"I
\u00d8\3\2\2\2
K
\u00e5\3\2\2\2
M
\u00e7\3\2\2\2
O
\u00ef\3
"
)
buf
.
write
(
"
\2\2\2
Q
\u00f3\3\2\2\2
S
\u00f7\3\2\2\2
U
\u00ff\3\2\2\2
W
\u0109
"
)
buf
.
write
(
"
\3\2\2\2
Y
\u010b\3\2\2\2
[
\\\7
*
\2\2\\\4\3\2\2\2
]^
\7
+
\2\2
"
)
buf
.
write
(
"^
\6\3\2\2\2
_`
\7
.
\2\2
`
\b\3\2\2\2
ab
\7
]
\2\2
b
\n\3\2\2\2
cd"
)
buf
.
write
(
"
\7
_
\2\2
d
\f\3\2\2\2
ef
\7
k
\2\2
fg
\7
h
\2\2
g
\16\3\2\2\2
hi
\7
g"
)
buf
.
write
(
"
\2\2
ij
\7
n
\2\2
jk
\7
u
\2\2
kl
\7
g
\2\2
l
\20\3\2\2\2
mn
\7
n
\2\2
n"
)
buf
.
write
(
"o
\7
g
\2\2
op
\7
v
\2\2
p
\22\3\2\2\2
qr
\7
?
\2\2
r
\24\3\2\2\2
st
\7
"
)
buf
.
write
(
"=
\2\2
t
\26\3\2\2\2
uv
\7
}
\2\2
v
\30\3\2\2\2
wx
\7\177\2\2
x
\32
"
)
buf
.
write
(
"
\3\2\2\2
yz
\7
h
\2\2
z{
\7
p
\2\2
{
\34\3\2\2\2
|}
\7
/
\2\2
}~
\7
@
\2
"
)
buf
.
write
(
"
\2
~
\36\3\2\2\2\177\u0080\7
f
\2\2\u0080\u0081\7
g
\2\2\u0081
"
)
buf
.
write
(
"
\u0082\7
h
\2\2\u0082
\3\2\2\2\u0083\u0084\7
<
\2\2\u0084
"
)
buf
.
write
(
"
\"\3\2\2\2\u0085\u0086\7
V
\2\2\u0086\u0087\7
g
\2\2\u0087
"
)
buf
.
write
(
"
\u0088\7
p
\2\2\u0088\u0089\7
u
\2\2\u0089\u008a\7
q
\2\2\u008a
"
)
buf
.
write
(
"
\u008b\7
t
\2\2\u008b
$
\3\2\2\2\u008c\u008d\7
a
\2\2\u008d
"
)
buf
.
write
(
"&
\3\2\2\2\u008e\u008f\7
x
\2\2\u008f\u0090\7\62\2\2\u0090
"
)
buf
.
write
(
"
\u0091\7\60\2\2\u0091\u0092\7\62\2\2\u0092\u0093\7\60
"
)
buf
.
write
(
"
\2\2\u0093\u0094\7\65\2\2\u0094
(
\3\2\2\2\u0095\u0097\t
"
)
buf
.
write
(
"
\2\2\2\u0096\u0095\3\2\2\2\u0097\u0098\3\2\2\2\u0098\u0096
"
)
buf
.
write
(
"
\3\2\2\2\u0098\u0099\3\2\2\2\u0099\u009a\3\2\2\2\u009a
"
)
buf
.
write
(
"
\u009b\b\25\2\2\u009b
*
\3\2\2\2\u009c\u009d\7\61\2\2\u009d
"
)
buf
.
write
(
"
\u009e\7\61\2\2\u009e\u00a2\3\2\2\2\u009f\u00a1\13\2\2
"
)
buf
.
write
(
"
\2\u00a0\u009f\3\2\2\2\u00a1\u00a4\3\2\2\2\u00a2\u00a3
"
)
buf
.
write
(
"
\3\2\2\2\u00a2\u00a0\3\2\2\2\u00a3\u00a5\3\2\2\2\u00a4
"
)
buf
.
write
(
"
\u00a2\3\2\2\2\u00a5\u00a6\7\f\2\2\u00a6\u00a7\3\2\2\2
"
)
buf
.
write
(
"
\u00a7\u00a8\b\26\2\2\u00a8
,
\3\2\2\2\u00a9\u00aa\7\61
"
)
buf
.
write
(
"
\2\2\u00aa\u00ab\7
,
\2\2\u00ab\u00af\3\2\2\2\u00ac\u00ae
"
)
buf
.
write
(
"
\13\2\2\2\u00ad\u00ac\3\2\2\2\u00ae\u00b1\3\2\2\2\u00af
"
)
buf
.
write
(
"
\u00b0\3\2\2\2\u00af\u00ad\3\2\2\2\u00b0\u00b2\3\2\2\2
"
)
buf
.
write
(
"
\u00b1\u00af\3\2\2\2\u00b2\u00b3\7
,
\2\2\u00b3\u00b4\7
"
)
buf
.
write
(
"
\61\2\2\u00b4\u00b5\3\2\2\2\u00b5\u00b6\b\27\2\2\u00b6
"
)
buf
.
write
(
".
\3\2\2\2\u00b7\u00b8\7
,
\2\2\u00b8\60\3\2\2\2\u00b9\u00ba
"
)
buf
.
write
(
"
\7\61\2\2\u00ba\62\3\2\2\2\u00bb\u00bc\7
-
\2\2\u00bc\64
"
)
buf
.
write
(
"
\3\2\2\2\u00bd\u00be\7
/
\2\2\u00be\66\3\2\2\2\u00bf\u00c0
"
)
buf
.
write
(
"
\7
>
\2\2\u00c0
8
\3\2\2\2\u00c1\u00c2\7
@
\2\2\u00c2
:
\3\2\2
"
)
buf
.
write
(
"
\2\u00c3\u00c4\7
>
\2\2\u00c4\u00c5\7
?
\2\2\u00c5
<
\3\2\2
"
)
buf
.
write
(
"
\2\u00c6\u00c7\7
@
\2\2\u00c7\u00c8\7
?
\2\2\u00c8
>
\3\2\2
"
)
buf
.
write
(
"
\2\u00c9\u00ca\7
?
\2\2\u00ca\u00cb\7
?
\2\2\u00cb
@
\3\2\2
"
)
buf
.
write
(
"
\2\u00cc\u00cd\7
#
\2\2\u00cd\u00ce\7
?
\2\2\u00ce
B
\3\2\2
"
)
buf
.
write
(
"
\2\u00cf\u00d0\7
B
\2\2\u00d0\u00d1\5
U+
\2\u00d1
D
\3\2\2\2
"
)
buf
.
write
(
"
\u00d2\u00d3\7\'\2\2\u00d3\u00d4\5
U+
\2\u00d4
F
\3\2\2\2
"
)
buf
.
write
(
"
\u00d5\u00d6\7\'\2\2\u00d6\u00d7\5
Q)
\2\u00d7
H
\3\2\2\2
"
)
buf
.
write
(
"
\u00d8\u00d9\7
o
\2\2\u00d9\u00da\7
w
\2\2\u00da\u00db\7
v"
)
buf
.
write
(
"
\2\2\u00db
J
\3\2\2\2\u00dc\u00dd\7
V
\2\2\u00dd\u00de\7
t"
)
buf
.
write
(
"
\2\2\u00de\u00df\7
w
\2\2\u00df\u00e6\7
g
\2\2\u00e0\u00e1
"
)
buf
.
write
(
"
\7
H
\2\2\u00e1\u00e2\7
c
\2\2\u00e2\u00e3\7
n
\2\2\u00e3\u00e4
"
)
buf
.
write
(
"
\7
u
\2\2\u00e4\u00e6\7
g
\2\2\u00e5\u00dc\3\2\2\2\u00e5\u00e0
"
)
buf
.
write
(
"
\3\2\2\2\u00e6
L
\3\2\2\2\u00e7\u00ea\5
Q)
\2\u00e8\u00e9
"
)
buf
.
write
(
"
\7\60\2\2\u00e9\u00eb\5
Q)
\2\u00ea\u00e8\3\2\2\2\u00ea
"
)
buf
.
write
(
"
\u00eb\3\2\2\2\u00eb\u00ed\3\2\2\2\u00ec\u00ee\5
S*
\2\u00ed
"
)
buf
.
write
(
"
\u00ec\3\2\2\2\u00ed\u00ee\3\2\2\2\u00ee
N
\3\2\2\2\u00ef
"
)
buf
.
write
(
"
\u00f0\5
M
\'\2\u00f0\u00f1\7
h
\2\2\u00f1
P
\3\2\2\2\u00f2
"
)
buf
.
write
(
"
\u00f4\5
Y-
\2\u00f3\u00f2\3\2\2\2\u00f4\u00f5\3\2\2\2\u00f5
"
)
buf
.
write
(
"
\u00f3\3\2\2\2\u00f5\u00f6\3\2\2\2\u00f6
R
\3\2\2\2\u00f7
"
)
buf
.
write
(
"
\u00f9\t\3\2\2\u00f8\u00fa\t\4\2\2\u00f9\u00f8\3\2\2\2
"
)
buf
.
write
(
"
\u00f9\u00fa\3\2\2\2\u00fa\u00fb\3\2\2\2\u00fb\u00fc\5
"
)
buf
.
write
(
"Q)
\2\u00fc
T
\3\2\2\2\u00fd\u0100\7
a
\2\2\u00fe\u0100\5
W"
)
buf
.
write
(
",
\2\u00ff\u00fd\3\2\2\2\u00ff\u00fe\3\2\2\2\u0100\u0106
"
)
buf
.
write
(
"
\3\2\2\2\u0101\u0105\7
a
\2\2\u0102\u0105\5
W,
\2\u0103\u0105
"
)
buf
.
write
(
"
\5
Y-
\2\u0104\u0101\3\2\2\2\u0104\u0102\3\2\2\2\u0104\u0103
"
)
buf
.
write
(
"
\3\2\2\2\u0105\u0108\3\2\2\2\u0106\u0104\3\2\2\2\u0106
"
)
buf
.
write
(
"
\u0107\3\2\2\2\u0107
V
\3\2\2\2\u0108\u0106\3\2\2\2\u0109
"
)
buf
.
write
(
"
\u010a\t\5\2\2\u010a
X
\3\2\2\2\u010b\u010c\t\6\2\2\u010c
"
)
buf
.
write
(
"Z
\3\2\2\2\16\2\u0098\u00a2\u00af\u00e5\u00ea\u00ed\u00f5
"
)
buf
.
write
(
"
\u00f9\u00ff\u0104\u0106\3\b\2\2
"
)
buf
.
write
(
"
\4
&
\t
&
\4\'\t\'\4
(
\t
(
\4
)
\t
)
\4
*
\t
*
\4
+
\t
+
\4
,
\t
,
\4
-
\t
-
\4
."
)
buf
.
write
(
"
\t
.
\4
/
\t
/
\4\60\t\60\4\61\t\61\4\62\t\62\4\63\t\63\3\2
"
)
buf
.
write
(
"
\3\2\3\3\3\3\3\4\3\4\3\5\3\5\3\6\3\6\3\7\3\7\3\b\3\b\3
"
)
buf
.
write
(
"
\t\3\t\3\n\3\n\3\n\3\13\3\13\3\13\3\13\3\13\3\f\3\f\3
"
)
buf
.
write
(
"
\f\3\f\3\r\3\r\3\16\3\16\3\17\3\17\3\17\3\20\3\20\3\20
"
)
buf
.
write
(
"
\3\21\3\21\3\21\3\22\3\22\3\22\3\22\3\23\3\23\3\24\3\24
"
)
buf
.
write
(
"
\3\24\3\24\3\24\3\24\3\24\3\25\3\25\3\26\3\26\3\26\3\26
"
)
buf
.
write
(
"
\3\26\3\27\3\27\3\27\3\27\3\27\3\27\3\27\3\30\3\30\3\30
"
)
buf
.
write
(
"
\3\30\3\30\7\30\u00b1\n\30\f\30\16\30\u00b4\13\30\3\30
"
)
buf
.
write
(
"
\3\30\3\30\3\30\3\30\3\31\6\31\u00bc\n\31\r\31\16\31\u00bd
"
)
buf
.
write
(
"
\3\31\3\31\3\32\3\32\3\32\3\32\7\32\u00c6\n\32\f\32\16
"
)
buf
.
write
(
"
\32\u00c9\13\32\3\32\3\32\3\32\3\32\3\33\3\33\3\33\3\34
"
)
buf
.
write
(
"
\3\34\3\34\7\34\u00d5\n\34\f\34\16\34\u00d8\13\34\3\34
"
)
buf
.
write
(
"
\3\34\3\35\3\35\3\36\3\36\3\37\3\37\3
\3
\3
!
\3
!
\3\"\3
"
)
buf
.
write
(
"
\"\3
#
\3
#
\3
#
\3
$
\3
$
\3
$
\3
%
\3
%
\3
%
\3
&
\3
&
\3
&
\3\'\3\'\3\'\3\'
"
)
buf
.
write
(
"
\3\'\3\'\3\'\3\'\3\'\5\'\u00fd\n\'\3
(
\3
(
\5
(
\u0101\n
(
\3
"
)
buf
.
write
(
"(
\3
(
\3
(
\7
(
\u0106\n
(
\f
(
\16
(
\u0109\13
(
\3
(
\3
(
\7
(
\u010d\n
"
)
buf
.
write
(
"(
\f
(
\16
(
\u0110\13
(
\3
)
\3
)
\3
)
\3
*
\3
*
\3
*
\3
+
\3
+
\3
+
\3
,
\3
,
\3
"
)
buf
.
write
(
",
\3
,
\3
,
\3
,
\3
-
\3
-
\3
-
\5
-
\u0124\n
-
\3
-
\5
-
\u0127\n
-
\3
.
\3
.
\3
"
)
buf
.
write
(
".
\3
/
\6
/
\u012d\n
/
\r
/
\16
/
\u012e\3\60\3\60\5\60\u0133\n\60
"
)
buf
.
write
(
"
\3\60\3\60\3\61\3\61\3\62\3\62\3\63\3\63\3\63\3\63\3\63
"
)
buf
.
write
(
"
\3\63\3\63\3\63\3\63\3\63\3\63\7\63\u0146\n\63\f\63\16
"
)
buf
.
write
(
"
\63\u0149\13\63\5\u00b2\u00c7\u00d6\2\64\3\3\5\4\7\5\t
"
)
buf
.
write
(
"
\6\13\7\r\b\17\t\21\n\23\13\25\f\27\r\31\16\33\17\35\20
"
)
buf
.
write
(
"
\37\21
!
\22
#
\23
%
\24\'\25
)
\26
+
\27
-
\30
/
\31\61\32\63\33\65
"
)
buf
.
write
(
"
\2\67\34
9
\35
;
\36
=
\37
? A!C
\"
E#G$I
%
K&M
\'
O(Q)S*U+W,Y
\2
[-"
)
buf
.
write
(
"]._
\2
a
\2
c
\2
e/
\3\2\b\5\2\13\f\17\17\"\"\4\2\f\f\17\17\4
"
)
buf
.
write
(
"
\2
GGgg
\4\2
--//
\4\2
C
\\
c|
\3\2\62
;
\2\u0155\2\3\3\2\2\2\2
"
)
buf
.
write
(
"
\5\3\2\2\2\2\7\3\2\2\2\2\t\3\2\2\2\2\13\3\2\2\2\2\r\3
"
)
buf
.
write
(
"
\2\2\2\2\17\3\2\2\2\2\21\3\2\2\2\2\23\3\2\2\2\2\25\3\2
"
)
buf
.
write
(
"
\2\2\2\27\3\2\2\2\2\31\3\2\2\2\2\33\3\2\2\2\2\35\3\2\2
"
)
buf
.
write
(
"
\2\2\37\3\2\2\2\2
!
\3\2\2\2\2
#
\3\2\2\2\2
%
\3\2\2\2\2\'\3
"
)
buf
.
write
(
"
\2\2\2\2
)
\3\2\2\2\2
+
\3\2\2\2\2
-
\3\2\2\2\2
/
\3\2\2\2\2\61
"
)
buf
.
write
(
"
\3\2\2\2\2\63\3\2\2\2\2\67\3\2\2\2\2
9
\3\2\2\2\2
;
\3\2\2
"
)
buf
.
write
(
"
\2\2
=
\3\2\2\2\2
?
\3\2\2\2\2
A
\3\2\2\2\2
C
\3\2\2\2\2
E
\3\2
"
)
buf
.
write
(
"
\2\2\2
G
\3\2\2\2\2
I
\3\2\2\2\2
K
\3\2\2\2\2
M
\3\2\2\2\2
O
\3
"
)
buf
.
write
(
"
\2\2\2\2
Q
\3\2\2\2\2
S
\3\2\2\2\2
U
\3\2\2\2\2
W
\3\2\2\2\2
["
)
buf
.
write
(
"
\3\2\2\2\2
]
\3\2\2\2\2
e
\3\2\2\2\3
g
\3\2\2\2\5
i
\3\2\2\2\7
"
)
buf
.
write
(
"k
\3\2\2\2\t
m
\3\2\2\2\13
o
\3\2\2\2\r
q
\3\2\2\2\17
s
\3\2\2
"
)
buf
.
write
(
"
\2\21
u
\3\2\2\2\23
w
\3\2\2\2\25
z
\3\2\2\2\27\177\3\2\2\2
"
)
buf
.
write
(
"
\31\u0083\3\2\2\2\33\u0085\3\2\2\2\35\u0087\3\2\2\2\37
"
)
buf
.
write
(
"
\u008a\3\2\2\2
!
\u008d\3\2\2\2
#
\u0090\3\2\2\2
%
\u0094\3
"
)
buf
.
write
(
"
\2\2\2\'\u0096\3\2\2\2
)
\u009d\3\2\2\2
+
\u009f\3\2\2\2
-"
)
buf
.
write
(
"
\u00a4\3\2\2\2
/
\u00ab\3\2\2\2\61\u00bb\3\2\2\2\63\u00c1
"
)
buf
.
write
(
"
\3\2\2\2\65\u00ce\3\2\2\2\67\u00d1\3\2\2\2
9
\u00db\3\2
"
)
buf
.
write
(
"
\2\2
;
\u00dd\3\2\2\2
=
\u00df\3\2\2\2
?
\u00e1\3\2\2\2
A
\u00e3
"
)
buf
.
write
(
"
\3\2\2\2
C
\u00e5\3\2\2\2
E
\u00e7\3\2\2\2
G
\u00ea\3\2\2\2
"
)
buf
.
write
(
"I
\u00ed\3\2\2\2
K
\u00f0\3\2\2\2
M
\u00fc\3\2\2\2
O
\u0100\3
"
)
buf
.
write
(
"
\2\2\2
Q
\u0111\3\2\2\2
S
\u0114\3\2\2\2
U
\u0117\3\2\2\2
W
\u011a
"
)
buf
.
write
(
"
\3\2\2\2
Y
\u0120\3\2\2\2
[
\u0128\3\2\2\2
]
\u012c\3\2\2\2
"
)
buf
.
write
(
"_
\u0130\3\2\2\2
a
\u0136\3\2\2\2
c
\u0138\3\2\2\2
e
\u013a\3
"
)
buf
.
write
(
"
\2\2\2
gh
\7
.
\2\2
h
\4\3\2\2\2
ij
\7
*
\2\2
j
\6\3\2\2\2
kl
\7
+
\2
"
)
buf
.
write
(
"
\2
l
\b\3\2\2\2
mn
\7
}
\2\2
n
\n\3\2\2\2
op
\7\177\2\2
p
\f\3\2\2
"
)
buf
.
write
(
"
\2
qr
\7\60\2\2
r
\16\3\2\2\2
st
\7
]
\2\2
t
\20\3\2\2\2
uv
\7
_
\2
"
)
buf
.
write
(
"
\2
v
\22\3\2\2\2
wx
\7
k
\2\2
xy
\7
h
\2\2
y
\24\3\2\2\2
z{
\7
g
\2\2
"
)
buf
.
write
(
"{|
\7
n
\2\2
|}
\7
u
\2\2
}~
\7
g
\2\2
~
\26\3\2\2\2\177\u0080\7
n
\2
"
)
buf
.
write
(
"
\2\u0080\u0081\7
g
\2\2\u0081\u0082\7
v
\2\2\u0082\30\3\2
"
)
buf
.
write
(
"
\2\2\u0083\u0084\7
?
\2\2\u0084\32\3\2\2\2\u0085\u0086\7
"
)
buf
.
write
(
"=
\2\2\u0086\34\3\2\2\2\u0087\u0088\7
=
\2\2\u0088\u0089
"
)
buf
.
write
(
"
\7
=
\2\2\u0089\36\3\2\2\2\u008a\u008b\7
h
\2\2\u008b\u008c
"
)
buf
.
write
(
"
\7
p
\2\2\u008c
\3\2\2\2\u008d\u008e\7
/
\2\2\u008e\u008f
"
)
buf
.
write
(
"
\7
@
\2\2\u008f\"\3\2\2\2\u0090\u0091\7
f
\2\2\u0091\u0092
"
)
buf
.
write
(
"
\7
g
\2\2\u0092\u0093\7
h
\2\2\u0093
$
\3\2\2\2\u0094\u0095
"
)
buf
.
write
(
"
\7
<
\2\2\u0095
&
\3\2\2\2\u0096\u0097\7
V
\2\2\u0097\u0098
"
)
buf
.
write
(
"
\7
g
\2\2\u0098\u0099\7
p
\2\2\u0099\u009a\7
u
\2\2\u009a\u009b
"
)
buf
.
write
(
"
\7
q
\2\2\u009b\u009c\7
t
\2\2\u009c
(
\3\2\2\2\u009d\u009e
"
)
buf
.
write
(
"
\7
a
\2\2\u009e
*
\3\2\2\2\u009f\u00a0\7
o
\2\2\u00a0\u00a1
"
)
buf
.
write
(
"
\7
g
\2\2\u00a1\u00a2\7
v
\2\2\u00a2\u00a3\7
c
\2\2\u00a3
,
\3
"
)
buf
.
write
(
"
\2\2\2\u00a4\u00a5\7
x
\2\2\u00a5\u00a6\7\62\2\2\u00a6\u00a7
"
)
buf
.
write
(
"
\7\60\2\2\u00a7\u00a8\7\62\2\2\u00a8\u00a9\7\60\2\2\u00a9
"
)
buf
.
write
(
"
\u00aa\7\65\2\2\u00aa
.
\3\2\2\2\u00ab\u00ac\7\61\2\2\u00ac
"
)
buf
.
write
(
"
\u00ad\7
,
\2\2\u00ad\u00b2\3\2\2\2\u00ae\u00b1\5
/
\30\2
"
)
buf
.
write
(
"
\u00af\u00b1\13\2\2\2\u00b0\u00ae\3\2\2\2\u00b0\u00af
"
)
buf
.
write
(
"
\3\2\2\2\u00b1\u00b4\3\2\2\2\u00b2\u00b3\3\2\2\2\u00b2
"
)
buf
.
write
(
"
\u00b0\3\2\2\2\u00b3\u00b5\3\2\2\2\u00b4\u00b2\3\2\2\2
"
)
buf
.
write
(
"
\u00b5\u00b6\7
,
\2\2\u00b6\u00b7\7\61\2\2\u00b7\u00b8\3
"
)
buf
.
write
(
"
\2\2\2\u00b8\u00b9\b\30\2\2\u00b9\60\3\2\2\2\u00ba\u00bc
"
)
buf
.
write
(
"
\t\2\2\2\u00bb\u00ba\3\2\2\2\u00bc\u00bd\3\2\2\2\u00bd
"
)
buf
.
write
(
"
\u00bb\3\2\2\2\u00bd\u00be\3\2\2\2\u00be\u00bf\3\2\2\2
"
)
buf
.
write
(
"
\u00bf\u00c0\b\31\2\2\u00c0\62\3\2\2\2\u00c1\u00c2\7\61
"
)
buf
.
write
(
"
\2\2\u00c2\u00c3\7\61\2\2\u00c3\u00c7\3\2\2\2\u00c4\u00c6
"
)
buf
.
write
(
"
\13\2\2\2\u00c5\u00c4\3\2\2\2\u00c6\u00c9\3\2\2\2\u00c7
"
)
buf
.
write
(
"
\u00c8\3\2\2\2\u00c7\u00c5\3\2\2\2\u00c8\u00ca\3\2\2\2
"
)
buf
.
write
(
"
\u00c9\u00c7\3\2\2\2\u00ca\u00cb\7\f\2\2\u00cb\u00cc\3
"
)
buf
.
write
(
"
\2\2\2\u00cc\u00cd\b\32\2\2\u00cd\64\3\2\2\2\u00ce\u00cf
"
)
buf
.
write
(
"
\7
^
\2\2\u00cf\u00d0\7
$
\2\2\u00d0\66\3\2\2\2\u00d1\u00d6
"
)
buf
.
write
(
"
\7
$
\2\2\u00d2\u00d5\5\65\33\2\u00d3\u00d5\n\3\2\2\u00d4
"
)
buf
.
write
(
"
\u00d2\3\2\2\2\u00d4\u00d3\3\2\2\2\u00d5\u00d8\3\2\2\2
"
)
buf
.
write
(
"
\u00d6\u00d7\3\2\2\2\u00d6\u00d4\3\2\2\2\u00d7\u00d9\3
"
)
buf
.
write
(
"
\2\2\2\u00d8\u00d6\3\2\2\2\u00d9\u00da\7
$
\2\2\u00da
8
\3
"
)
buf
.
write
(
"
\2\2\2\u00db\u00dc\7
,
\2\2\u00dc
:
\3\2\2\2\u00dd\u00de\7
"
)
buf
.
write
(
"
\61\2\2\u00de
<
\3\2\2\2\u00df\u00e0\7
-
\2\2\u00e0
>
\3\2\2
"
)
buf
.
write
(
"
\2\u00e1\u00e2\7
/
\2\2\u00e2
@
\3\2\2\2\u00e3\u00e4\7
>
\2
"
)
buf
.
write
(
"
\2\u00e4
B
\3\2\2\2\u00e5\u00e6\7
@
\2\2\u00e6
D
\3\2\2\2\u00e7
"
)
buf
.
write
(
"
\u00e8\7
>
\2\2\u00e8\u00e9\7
?
\2\2\u00e9
F
\3\2\2\2\u00ea
"
)
buf
.
write
(
"
\u00eb\7
@
\2\2\u00eb\u00ec\7
?
\2\2\u00ec
H
\3\2\2\2\u00ed
"
)
buf
.
write
(
"
\u00ee\7
?
\2\2\u00ee\u00ef\7
?
\2\2\u00ef
J
\3\2\2\2\u00f0
"
)
buf
.
write
(
"
\u00f1\7
#
\2\2\u00f1\u00f2\7
?
\2\2\u00f2
L
\3\2\2\2\u00f3
"
)
buf
.
write
(
"
\u00f4\7
V
\2\2\u00f4\u00f5\7
t
\2\2\u00f5\u00f6\7
w
\2\2\u00f6
"
)
buf
.
write
(
"
\u00fd\7
g
\2\2\u00f7\u00f8\7
H
\2\2\u00f8\u00f9\7
c
\2\2\u00f9
"
)
buf
.
write
(
"
\u00fa\7
n
\2\2\u00fa\u00fb\7
u
\2\2\u00fb\u00fd\7
g
\2\2\u00fc
"
)
buf
.
write
(
"
\u00f3\3\2\2\2\u00fc\u00f7\3\2\2\2\u00fd
N
\3\2\2\2\u00fe
"
)
buf
.
write
(
"
\u0101\7
a
\2\2\u00ff\u0101\5
a
\61\2\u0100\u00fe\3\2\2\2
"
)
buf
.
write
(
"
\u0100\u00ff\3\2\2\2\u0101\u0107\3\2\2\2\u0102\u0106\7
"
)
buf
.
write
(
"a
\2\2\u0103\u0106\5
a
\61\2\u0104\u0106\5
c
\62\2\u0105\u0102
"
)
buf
.
write
(
"
\3\2\2\2\u0105\u0103\3\2\2\2\u0105\u0104\3\2\2\2\u0106
"
)
buf
.
write
(
"
\u0109\3\2\2\2\u0107\u0105\3\2\2\2\u0107\u0108\3\2\2\2
"
)
buf
.
write
(
"
\u0108\u010e\3\2\2\2\u0109\u0107\3\2\2\2\u010a\u010b\7
"
)
buf
.
write
(
"
\60\2\2\u010b\u010d\5
O(
\2\u010c\u010a\3\2\2\2\u010d\u0110
"
)
buf
.
write
(
"
\3\2\2\2\u010e\u010c\3\2\2\2\u010e\u010f\3\2\2\2\u010f
"
)
buf
.
write
(
"P
\3\2\2\2\u0110\u010e\3\2\2\2\u0111\u0112\7
B
\2\2\u0112
"
)
buf
.
write
(
"
\u0113\5
O(
\2\u0113
R
\3\2\2\2\u0114\u0115\7\'\2\2\u0115
"
)
buf
.
write
(
"
\u0116\5
O(
\2\u0116
T
\3\2\2\2\u0117\u0118\7\'\2\2\u0118
"
)
buf
.
write
(
"
\u0119\5
]/
\2\u0119
V
\3\2\2\2\u011a\u011b\7
k
\2\2\u011b\u011c
"
)
buf
.
write
(
"
\7
p
\2\2\u011c\u011d\7
v
\2\2\u011d\u011e\7
8
\2\2\u011e\u011f
"
)
buf
.
write
(
"
\7\66\2\2\u011f
X
\3\2\2\2\u0120\u0123\5
]/
\2\u0121\u0122
"
)
buf
.
write
(
"
\7\60\2\2\u0122\u0124\5
]/
\2\u0123\u0121\3\2\2\2\u0123
"
)
buf
.
write
(
"
\u0124\3\2\2\2\u0124\u0126\3\2\2\2\u0125\u0127\5
_
\60\2
"
)
buf
.
write
(
"
\u0126\u0125\3\2\2\2\u0126\u0127\3\2\2\2\u0127
Z
\3\2\2
"
)
buf
.
write
(
"
\2\u0128\u0129\5
Y-
\2\u0129\u012a\7
h
\2\2\u012a\\\3\2\2
"
)
buf
.
write
(
"
\2\u012b\u012d\5
c
\62\2\u012c\u012b\3\2\2\2\u012d\u012e
"
)
buf
.
write
(
"
\3\2\2\2\u012e\u012c\3\2\2\2\u012e\u012f\3\2\2\2\u012f
"
)
buf
.
write
(
"^
\3\2\2\2\u0130\u0132\t\4\2\2\u0131\u0133\t\5\2\2\u0132
"
)
buf
.
write
(
"
\u0131\3\2\2\2\u0132\u0133\3\2\2\2\u0133\u0134\3\2\2\2
"
)
buf
.
write
(
"
\u0134\u0135\5
]/
\2\u0135
`
\3\2\2\2\u0136\u0137\t\6\2\2
"
)
buf
.
write
(
"
\u0137
b
\3\2\2\2\u0138\u0139\t\7\2\2\u0139
d
\3\2\2\2\u013a
"
)
buf
.
write
(
"
\u013b\7
O
\2\2\u013b\u013c\7
G
\2\2\u013c\u013d\7
V
\2\2\u013d
"
)
buf
.
write
(
"
\u013e\7
C
\2\2\u013e\u013f\7
F
\2\2\u013f\u0140\7
C
\2\2\u0140
"
)
buf
.
write
(
"
\u0141\7
V
\2\2\u0141\u0142\7
C
\2\2\u0142\u0143\7
<
\2\2\u0143
"
)
buf
.
write
(
"
\u0147\3\2\2\2\u0144\u0146\13\2\2\2\u0145\u0144\3\2\2
"
)
buf
.
write
(
"
\2\u0146\u0149\3\2\2\2\u0147\u0145\3\2\2\2\u0147\u0148
"
)
buf
.
write
(
"
\3\2\2\2\u0148
f
\3\2\2\2\u0149\u0147\3\2\2\2\23\2\u00b0
"
)
buf
.
write
(
"
\u00b2\u00bd\u00c7\u00d4\u00d6\u00fc\u0100\u0105\u0107
"
)
buf
.
write
(
"
\u010e\u0123\u0126\u012e\u0132\u0147\3\b\2\2
"
)
return
buf
.
getvalue
()
...
...
@@ -144,51 +175,59 @@ class RelayLexer(Lexer):
T__15
=
16
T__16
=
17
T__17
=
18
SEMVER
=
19
WS
=
20
LINE_COMMENT
=
21
COMMENT
=
22
MUL
=
23
DIV
=
24
ADD
=
25
SUB
=
26
LT
=
27
GT
=
28
LE
=
29
GE
=
30
EQ
=
31
NE
=
32
GLOBAL_VAR
=
33
LOCAL_VAR
=
34
GRAPH_VAR
=
35
MUT
=
36
T__18
=
19
T__19
=
20
T__20
=
21
SEMVER
=
22
COMMENT
=
23
WS
=
24
LINE_COMMENT
=
25
QUOTED_STRING
=
26
MUL
=
27
DIV
=
28
ADD
=
29
SUB
=
30
LT
=
31
GT
=
32
LE
=
33
GE
=
34
EQ
=
35
NE
=
36
BOOL_LIT
=
37
FLOAT
=
38
NAT
=
39
CNAME
=
40
CNAME
=
38
GLOBAL_VAR
=
39
LOCAL_VAR
=
40
GRAPH_VAR
=
41
DATATYPE
=
42
FLOAT
=
43
NAT
=
44
METADATA
=
45
channelNames
=
[
u"DEFAULT_TOKEN_CHANNEL"
,
u"HIDDEN"
]
modeNames
=
[
"DEFAULT_MODE"
]
literalNames
=
[
"<INVALID>"
,
"'('"
,
"')'"
,
"','"
,
"'['"
,
"']'"
,
"'if'"
,
"'else'"
,
"'let'"
,
"'='"
,
"';'"
,
"'{'"
,
"'}'"
,
"'fn'"
,
"'->'"
,
"'def'"
,
"':'"
,
"'Tensor'"
,
"'_'"
,
"'v0.0.3'"
,
"'*'"
,
"'/'"
,
"'+'"
,
"'-'"
,
"'<'"
,
"'>'"
,
"'<='"
,
"'>='"
,
"'=='"
,
"'!='"
,
"'mut'"
]
"','"
,
"'('"
,
"')'"
,
"'{'"
,
"'}'"
,
"'.'"
,
"'['"
,
"']'"
,
"'if'"
,
"'else'"
,
"'let'"
,
"'='"
,
"';'"
,
"';;'"
,
"'fn'"
,
"'->'"
,
"'def'"
,
"':'"
,
"'Tensor'"
,
"'_'"
,
"'meta'"
,
"'v0.0.3'"
,
"'*'"
,
"'/'"
,
"'+'"
,
"'-'"
,
"'<'"
,
"'>'"
,
"'<='"
,
"'>='"
,
"'=='"
,
"'!='"
,
"'int64'"
]
symbolicNames
=
[
"<INVALID>"
,
"SEMVER"
,
"WS"
,
"LINE_COMMENT"
,
"COMMENT"
,
"MUL"
,
"DIV"
,
"ADD"
,
"SUB"
,
"LT"
,
"GT"
,
"LE"
,
"GE"
,
"EQ"
,
"NE"
,
"GLOBAL_VAR"
,
"LOCAL_VAR"
,
"GRAPH_VAR"
,
"MUT"
,
"BOOL_LIT"
,
"FLOAT"
,
"NAT"
,
"CNAME"
]
"SEMVER"
,
"COMMENT"
,
"WS"
,
"LINE_COMMENT"
,
"QUOTED_STRING"
,
"MUL"
,
"DIV"
,
"ADD"
,
"SUB"
,
"LT"
,
"GT"
,
"LE"
,
"GE"
,
"EQ"
,
"NE"
,
"BOOL_LIT"
,
"CNAME"
,
"GLOBAL_VAR"
,
"LOCAL_VAR"
,
"GRAPH_VAR"
,
"DATATYPE"
,
"FLOAT"
,
"NAT"
,
"METADATA"
]
ruleNames
=
[
"T__0"
,
"T__1"
,
"T__2"
,
"T__3"
,
"T__4"
,
"T__5"
,
"T__6"
,
"T__7"
,
"T__8"
,
"T__9"
,
"T__10"
,
"T__11"
,
"T__12"
,
"T__13"
,
"T__14"
,
"T__15"
,
"T__16"
,
"T__17"
,
"SEMVER"
,
"WS"
,
"LINE_COMMENT"
,
"COMMENT"
,
"MUL"
,
"DIV"
,
"ADD"
,
"SUB"
,
"LT"
,
"GT"
,
"LE"
,
"GE"
,
"EQ"
,
"NE"
,
"GLOBAL_VAR"
,
"LOCAL_VAR"
,
"GRAPH_VAR"
,
"MUT"
,
"BOOL_LIT"
,
"PREFLOAT"
,
"FLOAT"
,
"NAT"
,
"EXP"
,
"CNAME"
,
"LETTER"
,
"DIGIT"
]
"T__14"
,
"T__15"
,
"T__16"
,
"T__17"
,
"T__18"
,
"T__19"
,
"T__20"
,
"SEMVER"
,
"COMMENT"
,
"WS"
,
"LINE_COMMENT"
,
"ESCAPED_QUOTE"
,
"QUOTED_STRING"
,
"MUL"
,
"DIV"
,
"ADD"
,
"SUB"
,
"LT"
,
"GT"
,
"LE"
,
"GE"
,
"EQ"
,
"NE"
,
"BOOL_LIT"
,
"CNAME"
,
"GLOBAL_VAR"
,
"LOCAL_VAR"
,
"GRAPH_VAR"
,
"DATATYPE"
,
"PREFLOAT"
,
"FLOAT"
,
"NAT"
,
"EXP"
,
"LETTER"
,
"DIGIT"
,
"METADATA"
]
grammarFileName
=
"Relay.g4"
...
...
python/tvm/relay/grammar/py3/RelayParser.py
View file @
2973f8a6
This source diff could not be displayed because it is too large. You can
view the blob
instead.
python/tvm/relay/grammar/py3/RelayVisitor.py
View file @
2973f8a6
...
...
@@ -19,11 +19,51 @@ class RelayVisitor(ParseTreeVisitor):
return
self
.
visitChildren
(
ctx
)
# Visit a parse tree produced by RelayParser#exprList.
def
visitExprList
(
self
,
ctx
:
RelayParser
.
ExprListContext
):
return
self
.
visitChildren
(
ctx
)
# Visit a parse tree produced by RelayParser#callNoAttr.
def
visitCallNoAttr
(
self
,
ctx
:
RelayParser
.
CallNoAttrContext
):
return
self
.
visitChildren
(
ctx
)
# Visit a parse tree produced by RelayParser#callWithAttr.
def
visitCallWithAttr
(
self
,
ctx
:
RelayParser
.
CallWithAttrContext
):
return
self
.
visitChildren
(
ctx
)
# Visit a parse tree produced by RelayParser#funcExpr.
def
visitFuncExpr
(
self
,
ctx
:
RelayParser
.
FuncExprContext
):
return
self
.
visitChildren
(
ctx
)
# Visit a parse tree produced by RelayParser#metaExpr.
def
visitMetaExpr
(
self
,
ctx
:
RelayParser
.
MetaExprContext
):
return
self
.
visitChildren
(
ctx
)
# Visit a parse tree produced by RelayParser#tensor.
def
visitTensor
(
self
,
ctx
:
RelayParser
.
TensorContext
):
return
self
.
visitChildren
(
ctx
)
# Visit a parse tree produced by RelayParser#graph.
def
visitGraph
(
self
,
ctx
:
RelayParser
.
GraphContext
):
return
self
.
visitChildren
(
ctx
)
# Visit a parse tree produced by RelayParser#identExpr.
def
visitIdentExpr
(
self
,
ctx
:
RelayParser
.
IdentExprContext
):
return
self
.
visitChildren
(
ctx
)
# Visit a parse tree produced by RelayParser#stringExpr.
def
visitStringExpr
(
self
,
ctx
:
RelayParser
.
StringExprContext
):
return
self
.
visitChildren
(
ctx
)
# Visit a parse tree produced by RelayParser#call.
def
visitCall
(
self
,
ctx
:
RelayParser
.
CallContext
):
return
self
.
visitChildren
(
ctx
)
...
...
@@ -39,13 +79,8 @@ class RelayVisitor(ParseTreeVisitor):
return
self
.
visitChildren
(
ctx
)
# Visit a parse tree produced by RelayParser#parens.
def
visitParens
(
self
,
ctx
:
RelayParser
.
ParensContext
):
return
self
.
visitChildren
(
ctx
)
# Visit a parse tree produced by RelayParser#funcExpr.
def
visitFuncExpr
(
self
,
ctx
:
RelayParser
.
FuncExprContext
):
# Visit a parse tree produced by RelayParser#paren.
def
visitParen
(
self
,
ctx
:
RelayParser
.
ParenContext
):
return
self
.
visitChildren
(
ctx
)
...
...
@@ -59,8 +94,8 @@ class RelayVisitor(ParseTreeVisitor):
return
self
.
visitChildren
(
ctx
)
# Visit a parse tree produced by RelayParser#
tensor
.
def
visit
Tensor
(
self
,
ctx
:
RelayParser
.
Tensor
Context
):
# Visit a parse tree produced by RelayParser#
projection
.
def
visit
Projection
(
self
,
ctx
:
RelayParser
.
Projection
Context
):
return
self
.
visitChildren
(
ctx
)
...
...
@@ -69,11 +104,6 @@ class RelayVisitor(ParseTreeVisitor):
return
self
.
visitChildren
(
ctx
)
# Visit a parse tree produced by RelayParser#graph.
def
visitGraph
(
self
,
ctx
:
RelayParser
.
GraphContext
):
return
self
.
visitChildren
(
ctx
)
# Visit a parse tree produced by RelayParser#binOp.
def
visitBinOp
(
self
,
ctx
:
RelayParser
.
BinOpContext
):
return
self
.
visitChildren
(
ctx
)
...
...
@@ -89,8 +119,13 @@ class RelayVisitor(ParseTreeVisitor):
return
self
.
visitChildren
(
ctx
)
# Visit a parse tree produced by RelayParser#argList.
def
visitArgList
(
self
,
ctx
:
RelayParser
.
ArgListContext
):
# Visit a parse tree produced by RelayParser#argNoAttr.
def
visitArgNoAttr
(
self
,
ctx
:
RelayParser
.
ArgNoAttrContext
):
return
self
.
visitChildren
(
ctx
)
# Visit a parse tree produced by RelayParser#argWithAttr.
def
visitArgWithAttr
(
self
,
ctx
:
RelayParser
.
ArgWithAttrContext
):
return
self
.
visitChildren
(
ctx
)
...
...
@@ -104,8 +139,8 @@ class RelayVisitor(ParseTreeVisitor):
return
self
.
visitChildren
(
ctx
)
# Visit a parse tree produced by RelayParser#attr
List
.
def
visitAttr
List
(
self
,
ctx
:
RelayParser
.
AttrList
Context
):
# Visit a parse tree produced by RelayParser#attr
Seq
.
def
visitAttr
Seq
(
self
,
ctx
:
RelayParser
.
AttrSeq
Context
):
return
self
.
visitChildren
(
ctx
)
...
...
@@ -114,8 +149,8 @@ class RelayVisitor(ParseTreeVisitor):
return
self
.
visitChildren
(
ctx
)
# Visit a parse tree produced by RelayParser#typeParam
Seq
.
def
visitTypeParam
Seq
(
self
,
ctx
:
RelayParser
.
TypeParamSeq
Context
):
# Visit a parse tree produced by RelayParser#typeParam
List
.
def
visitTypeParam
List
(
self
,
ctx
:
RelayParser
.
TypeParamList
Context
):
return
self
.
visitChildren
(
ctx
)
...
...
@@ -149,8 +184,18 @@ class RelayVisitor(ParseTreeVisitor):
return
self
.
visitChildren
(
ctx
)
# Visit a parse tree produced by RelayParser#shapeSeq.
def
visitShapeSeq
(
self
,
ctx
:
RelayParser
.
ShapeSeqContext
):
# Visit a parse tree produced by RelayParser#shapeList.
def
visitShapeList
(
self
,
ctx
:
RelayParser
.
ShapeListContext
):
return
self
.
visitChildren
(
ctx
)
# Visit a parse tree produced by RelayParser#meta.
def
visitMeta
(
self
,
ctx
:
RelayParser
.
MetaContext
):
return
self
.
visitChildren
(
ctx
)
# Visit a parse tree produced by RelayParser#metaShape.
def
visitMetaShape
(
self
,
ctx
:
RelayParser
.
MetaShapeContext
):
return
self
.
visitChildren
(
ctx
)
...
...
python/tvm/relay/op/nn/nn.py
View file @
2973f8a6
...
...
@@ -66,34 +66,34 @@ def conv2d(data,
weight : tvm.relay.Expr
The weight expressions.
strides :
tuple of int, optional
strides :
Optional[Tuple[int]]
The strides of convolution.
padding :
tuple of int, optional
padding :
Optional[Tuple[int]]
The padding of convolution on both sides of inputs before convolution.
dilation :
tuple of int, optional
dilation :
Optional[Tuple[int]]
Specifies the dilation rate to be used for dilated convolution.
groups :
int, optional
groups :
Optional[int]
Number of groups for grouped convolution.
channels :
int, optional
channels :
Optional[int]
Number of output channels of this convolution.
kernel_size :
tuple of int, optional
kernel_size :
Optional[Tuple[int]]
The spatial of the convolution kernel.
data_layout :
str, optional
data_layout :
Optional[str]
Layout of the input.
kernel_layout :
str, optional
kernel_layout :
Optional[str]
Layout of the weight.
out_layout :
str, optional
out_layout :
Optional[str]
Layout of the output, by default, out_layout is the same as data_layout
out_dtype :
str, optional
out_dtype :
Optional[str]
Specifies the output data type for mixed precision conv2d.
Returns
...
...
@@ -691,8 +691,30 @@ def dropout(data, rate=0.5):
result : tvm.relay.Expr
The result of dropout
"""
result
=
_make
.
dropout
(
data
,
rate
)
return
TupleWrapper
(
result
,
2
)[
0
]
return
TupleWrapper
(
dropout_raw
(
data
,
rate
),
2
)[
0
]
def
dropout_raw
(
data
,
rate
=
0.5
):
"""Applies the dropout operation to the input array.
During training, each element of the input is set to zero with
probability ``p``. The whole array is rescaled by ``1/(1-p)``
to keep the expected sum of the input unchanged.
Parameters
----------
data : tvm.relay.Expr
The input data to the operator.
rate : float, optional (default=0.5)
The probability for an element to be reset to 0.
Returns
-------
result : tvm.relay.Expr
The result of dropout
"""
return
_make
.
dropout
(
data
,
rate
)
def
batch_norm
(
data
,
...
...
python/tvm/relay/parser.py
View file @
2973f8a6
...
...
@@ -23,4 +23,7 @@ from .. import register_func
def
fromtext
(
data
,
source_name
=
None
):
"""Parse a Relay program."""
from
tvm.relay
import
_parser
return
_parser
.
fromtext
(
data
,
source_name
)
x
=
_parser
.
fromtext
(
data
+
"
\n
"
,
source_name
)
if
x
is
None
:
raise
Exception
(
"cannot parse: "
,
data
)
return
x
python/tvm/relay/testing/densenet.py
View file @
2973f8a6
...
...
@@ -42,7 +42,7 @@ def _make_dense_block(data, num_layers, bn_size, growth_rate, index):
layer_out
=
data
for
i
in
range
(
num_layers
):
layer_out
=
_make_dense_layer
(
layer_out
,
growth_rate
,
bn_size
,
"
(
%
s,
%
s)
"
%
(
index
,
i
))
"
%
s_
%
s
"
%
(
index
,
i
))
return
layer_out
def
_make_transition
(
data
,
num_output_features
,
index
):
...
...
python/tvm/relay/ty.py
View file @
2973f8a6
...
...
@@ -29,7 +29,7 @@ class Type(RelayNode):
"""Compare two Relay types for structural equivalence using
alpha equivalence.
"""
return
bool
(
_make
.
_
type_
alpha_equal
(
self
,
other
))
return
bool
(
_make
.
_alpha_equal
(
self
,
other
))
def
__ne__
(
self
,
other
):
return
not
self
.
__eq__
(
other
)
...
...
src/relay/ir/alpha_equal.cc
View file @
2973f8a6
...
...
@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...
...
@@ -18,7 +18,7 @@
*/
/*!
* Copyright (c) 201
8
by Contributors
* Copyright (c) 201
9
by Contributors
* \file src/tvm/relay/ir/alpha_equal.cc
* \brief Alpha equality check by deep comparing two nodes.
*/
...
...
@@ -27,9 +27,10 @@
#include <tvm/relay/pattern_functor.h>
#include <tvm/runtime/ndarray.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/attrs/nn.h>
#include "type_functor.h"
#include "../../lang/attr_functor.h"
namespace
tvm
{
namespace
relay
{
...
...
@@ -40,8 +41,8 @@ class AlphaEqualHandler:
public
ExprFunctor
<
bool
(
const
Expr
&
,
const
Expr
&
)
>
,
public
PatternFunctor
<
bool
(
const
Pattern
&
,
const
Pattern
&
)
>
{
public
:
explicit
AlphaEqualHandler
(
bool
map_free_var
)
:
map_free_var_
(
map_free_var
)
{
}
explicit
AlphaEqualHandler
(
bool
map_free_var
,
bool
assert_mode
)
:
map_free_var_
(
map_free_var
)
,
assert_mode_
(
assert_mode
)
{
}
/*!
* Check equality of two nodes.
...
...
@@ -76,6 +77,9 @@ class AlphaEqualHandler:
return
AttrEqual
(
lhs
,
rhs
);
}
bool
DoubleEqual
(
double
l
,
double
r
)
{
return
true
;
}
/*!
* Check equality of two attributes.
* \param lhs The left hand operand.
...
...
@@ -83,18 +87,28 @@ class AlphaEqualHandler:
* \return The comparison result.
*/
bool
AttrEqual
(
const
NodeRef
&
lhs
,
const
NodeRef
&
rhs
)
{
if
(
&
lhs
==
&
rhs
)
return
true
;
auto
lhsd
=
lhs
.
as
<
DictAttrsNode
>
();
if
(
lhsd
)
{
auto
rhsd
=
lhs
.
as
<
DictAttrsNode
>
();
if
(
!
rhsd
)
return
false
;
if
(
lhsd
->
dict
.
size
()
!=
rhsd
->
dict
.
size
())
return
false
;
for
(
const
auto
&
k
:
lhsd
->
dict
)
{
if
(
!
Equal
(
k
.
second
,
rhsd
->
dict
[
k
.
first
]))
return
false
;
auto
compute
=
[
&
]()
{
if
(
&
lhs
==
&
rhs
)
return
true
;
if
(
auto
lhsd
=
lhs
.
as
<
DictAttrsNode
>
())
{
auto
rhsd
=
lhs
.
as
<
DictAttrsNode
>
();
if
(
!
rhsd
)
return
false
;
if
(
lhsd
->
dict
.
size
()
!=
rhsd
->
dict
.
size
())
return
false
;
for
(
const
auto
&
k
:
lhsd
->
dict
)
{
if
(
!
Equal
(
k
.
second
,
rhsd
->
dict
[
k
.
first
]))
return
false
;
}
return
true
;
}
return
true
;
}
return
AttrsEqualHandler
::
Equal
(
lhs
,
rhs
);
if
(
auto
lhsbn
=
lhs
.
as
<
BatchNormAttrs
>
())
{
auto
rhsbn
=
rhs
.
as
<
BatchNormAttrs
>
();
if
(
!
rhsbn
)
return
false
;
return
(
lhsbn
->
axis
==
rhsbn
->
axis
)
&&
DoubleEqual
(
lhsbn
->
epsilon
,
rhsbn
->
epsilon
)
&&
(
lhsbn
->
center
==
rhsbn
->
center
)
&&
(
lhsbn
->
scale
==
rhsbn
->
scale
);
}
return
AttrsEqualHandler
::
Equal
(
lhs
,
rhs
);
};
return
Compare
(
compute
(),
lhs
,
rhs
);
}
/*!
* Check equality of two types.
...
...
@@ -107,6 +121,13 @@ class AlphaEqualHandler:
if
(
!
lhs
.
defined
()
||
!
rhs
.
defined
())
return
false
;
return
this
->
VisitType
(
lhs
,
rhs
);
}
bool
Compare
(
bool
result
,
const
NodeRef
&
lhs
,
const
NodeRef
&
rhs
)
{
if
(
assert_mode_
)
{
CHECK
(
result
)
<<
"
\n
"
<<
AsText
(
lhs
,
true
)
<<
"
\n
is not equal to:
\n
"
<<
AsText
(
rhs
,
true
);
}
return
result
;
}
/*!
* Check equality of two expressions.
*
...
...
@@ -120,18 +141,21 @@ class AlphaEqualHandler:
* \return The comparison result.
*/
bool
ExprEqual
(
const
Expr
&
lhs
,
const
Expr
&
rhs
)
{
if
(
lhs
.
same_as
(
rhs
))
return
true
;
if
(
!
lhs
.
defined
()
||
!
rhs
.
defined
())
return
false
;
auto
it
=
equal_map_
.
find
(
lhs
);
if
(
it
!=
equal_map_
.
end
())
{
return
it
->
second
.
same_as
(
rhs
);
}
if
(
this
->
VisitExpr
(
lhs
,
rhs
))
{
equal_map_
[
lhs
]
=
rhs
;
return
true
;
}
else
{
return
false
;
}
auto
compute
=
[
&
]()
{
if
(
lhs
.
same_as
(
rhs
))
return
true
;
if
(
!
lhs
.
defined
()
||
!
rhs
.
defined
())
return
false
;
auto
it
=
equal_map_
.
find
(
lhs
);
if
(
it
!=
equal_map_
.
end
())
{
return
it
->
second
.
same_as
(
rhs
);
}
if
(
this
->
VisitExpr
(
lhs
,
rhs
))
{
equal_map_
[
lhs
]
=
rhs
;
return
true
;
}
else
{
return
false
;
}
};
return
Compare
(
compute
(),
lhs
,
rhs
);
}
protected
:
...
...
@@ -516,32 +540,41 @@ class AlphaEqualHandler:
private
:
// whether to map open terms.
bool
map_free_var_
;
// if in assert mode, must return true, and will throw error otherwise.
bool
assert_mode_
;
// renaming of NodeRef to indicate two nodes equals to each other
std
::
unordered_map
<
NodeRef
,
NodeRef
,
NodeHash
,
NodeEqual
>
equal_map_
;
};
bool
AlphaEqual
(
const
Type
&
lhs
,
const
Type
&
rhs
)
{
return
AlphaEqualHandler
(
false
).
TypeEqual
(
lhs
,
rhs
);
return
AlphaEqualHandler
(
false
,
false
).
TypeEqual
(
lhs
,
rhs
);
}
bool
AlphaEqual
(
const
Expr
&
lhs
,
const
Expr
&
rhs
)
{
return
AlphaEqualHandler
(
false
).
ExprEqual
(
lhs
,
rhs
);
return
AlphaEqualHandler
(
false
,
false
).
ExprEqual
(
lhs
,
rhs
);
}
// TODO(@jroesch): move to correct namespace?
TVM_REGISTER_API
(
"relay._make._alpha_equal"
)
.
set_body_typed
<
bool
(
NodeRef
,
NodeRef
)
>
([](
NodeRef
a
,
NodeRef
b
)
{
return
AlphaEqualHandler
(
false
).
Equal
(
a
,
b
);
return
AlphaEqualHandler
(
false
,
false
).
Equal
(
a
,
b
);
});
TVM_REGISTER_API
(
"relay._make._type_alpha_equal"
)
.
set_body_typed
<
bool
(
Type
,
Type
)
>
([](
Type
a
,
Type
b
)
{
return
AlphaEqualHandler
(
false
).
TypeEqual
(
a
,
b
);
TVM_REGISTER_API
(
"relay._make._assert_alpha_equal"
)
.
set_body_typed
<
void
(
NodeRef
,
NodeRef
)
>
([](
NodeRef
a
,
NodeRef
b
)
{
bool
alpha_equal
=
AlphaEqualHandler
(
false
,
true
).
Equal
(
a
,
b
);
CHECK
(
alpha_equal
)
<<
AsText
(
a
,
true
)
<<
" and "
<<
AsText
(
b
,
true
)
<<
" is not alpha equal"
;
});
TVM_REGISTER_API
(
"relay._make._graph_equal"
)
.
set_body_typed
<
bool
(
NodeRef
,
NodeRef
)
>
([](
NodeRef
a
,
NodeRef
b
)
{
return
AlphaEqualHandler
(
true
).
Equal
(
a
,
b
);
return
AlphaEqualHandler
(
true
,
false
).
Equal
(
a
,
b
);
});
TVM_REGISTER_API
(
"relay._make._assert_graph_equal"
)
.
set_body_typed
<
void
(
NodeRef
,
NodeRef
)
>
([](
NodeRef
a
,
NodeRef
b
)
{
bool
graph_equal
=
AlphaEqualHandler
(
true
,
true
).
Equal
(
a
,
b
);
CHECK
(
graph_equal
)
<<
AsText
(
a
,
true
)
<<
" and "
<<
AsText
(
b
,
true
)
<<
" is not graph equal"
;
});
}
// namespace relay
...
...
src/relay/ir/doc.cc
View file @
2973f8a6
...
...
@@ -89,7 +89,7 @@ std::string Doc::str() {
return
os
.
str
();
}
Doc
Print
Vec
(
const
std
::
vector
<
Doc
>&
vec
,
const
Doc
&
sep
)
{
Doc
Print
Sep
(
const
std
::
vector
<
Doc
>&
vec
,
const
Doc
&
sep
)
{
Doc
seq
;
if
(
vec
.
size
()
!=
0
)
{
seq
=
vec
[
0
];
...
...
src/relay/ir/doc.h
View file @
2973f8a6
...
...
@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...
...
@@ -46,7 +46,11 @@ using DocAtom = std::shared_ptr<DocAtomNode>;
struct
TextNode
:
DocAtomNode
{
std
::
string
str
;
explicit
TextNode
(
const
std
::
string
&
str
)
:
str
(
str
)
{}
explicit
TextNode
(
const
std
::
string
&
str
)
:
str
(
str
)
{
if
(
str
.
find_first_of
(
"
\t\n
"
)
!=
str
.
npos
)
{
LOG
(
WARNING
)
<<
"text node: '"
<<
str
<<
"' should not has tab or newline."
;
}
}
};
struct
LineNode
:
DocAtomNode
{
...
...
@@ -91,8 +95,8 @@ class Doc {
// DSL functions
// Render vectors of docs with a separator. e.g. Print
Vec
([1, 2, 3], f) -> 1f2f3
Doc
Print
Vec
(
const
std
::
vector
<
Doc
>&
vec
,
const
Doc
&
sep
=
Doc
(
", "
));
// Render vectors of docs with a separator. e.g. Print
Sep
([1, 2, 3], f) -> 1f2f3
Doc
Print
Sep
(
const
std
::
vector
<
Doc
>&
vec
,
const
Doc
&
sep
=
Doc
(
", "
));
// Print a constant bool value.
Doc
PrintBool
(
bool
value
);
// Print a data type.
...
...
@@ -116,7 +120,8 @@ Doc PrintConstScalar(DataType dtype, const T* data) {
}
else
if
(
dtype
==
Bool
())
{
return
PrintBool
(
data
[
0
]
!=
0
);
}
else
{
os
<<
dtype
<<
"("
<<
data
[
0
]
<<
")"
;
// todo(@M.K.) this is unsafe. fix.
os
<<
data
[
0
];
}
return
Doc
(
os
.
str
());
}
...
...
src/relay/ir/pretty_printer.cc
View file @
2973f8a6
...
...
@@ -32,6 +32,7 @@
* - Otherwise, inline if the node is at the end of a scope and is used at most once.
*/
#include <dmlc/json.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/module.h>
#include <tvm/relay/pattern_functor.h>
...
...
@@ -43,6 +44,17 @@
namespace
tvm
{
namespace
relay
{
Doc
Brace
(
const
Doc
&
d
,
const
std
::
string
&
open
=
"{"
,
const
std
::
string
&
close
=
"}"
,
int
indent
=
2
)
{
Doc
doc
;
doc
<<
open
;
doc
<<
Indent
(
indent
,
PrintNewLine
()
<<
d
)
<<
PrintNewLine
();
doc
<<
close
;
return
doc
;
}
/*!
* \brief Meta data context for PrettyPrinter.
*
...
...
@@ -108,8 +120,10 @@ class TextMetaDataContext {
if
(
it
!=
meta_repr_
.
end
())
{
return
it
->
second
;
}
std
::
string
type_key
=
node
->
type_key
();
CHECK
(
!
type_key
.
empty
());
Array
<
NodeRef
>&
mvector
=
meta_data_
[
node
->
type_key
()
];
meta_data_
[
type_key
];
int64_t
index
=
static_cast
<
int64_t
>
(
mvector
.
size
());
mvector
.
push_back
(
node
);
Doc
doc
;
...
...
@@ -117,14 +131,18 @@ class TextMetaDataContext {
meta_repr_
[
node
]
=
doc
;
return
meta_repr_
[
node
];
}
Doc
PrintKeyValue
(
const
std
::
string
&
str
,
const
Doc
&
v
)
const
{
return
Doc
(
"
\"
"
)
<<
str
<<
"
\"
: "
<<
v
;
}
/*!
* \brief Get the metadata section in json format.
* \return the meta data string.
*/
std
::
string
GetMetaSection
()
const
{
if
(
meta_data_
.
size
()
==
0
)
return
std
::
string
();
return
SaveJSON
(
Map
<
std
::
string
,
NodeRef
>
(
meta_data_
.
begin
(),
meta_data_
.
end
()));
Doc
GetMetaSection
()
const
{
if
(
meta_data_
.
size
()
==
0
)
return
Doc
();
return
Doc
(
SaveJSON
(
Map
<
std
::
string
,
NodeRef
>
(
meta_data_
.
begin
(),
meta_data_
.
end
())));
}
/*! \return whether the meta data context is empty. */
...
...
@@ -172,12 +190,11 @@ class PrettyPrinter :
}
// indent a new body
// TODO(jmp): indent should be an instance variable of the printer
Doc
PrintBody
(
const
NodeRef
&
node
,
int
indent
=
2
)
{
Doc
doc
;
Doc
body
;
doc
<<
"{"
;
doc
<<
Indent
(
indent
,
body
<<
"
\n
"
<<
PrintScope
(
node
))
<<
"
\n
"
;
doc
<<
Indent
(
indent
,
body
<<
PrintNewLine
()
<<
PrintScope
(
node
))
<<
PrintNewLine
()
;
doc
<<
"}"
;
return
doc
;
}
...
...
@@ -203,13 +220,12 @@ class PrettyPrinter :
Doc
doc
;
doc
<<
PrintScope
(
node
);
if
(
!
meta_
.
empty
())
{
doc
<<
PrintNewLine
();
if
(
show_meta_data_
)
{
std
::
string
meta_json
=
meta_
.
GetMetaSection
();
// append meta data in the end.
doc
<<
"
\n
"
<<
"/* meta data */"
<<
"
\n
"
<<
meta_json
;
doc
<<
"
METADATA:"
<<
PrintNewLine
()
<<
meta_
.
GetMetaSection
()
;
}
else
{
doc
<<
"
\n
"
<<
"// meta data omitted. you can use show_meta_data=True to include meta data"
;
doc
<<
"// meta data omitted. you can use show_meta_data=True to include meta data"
;
}
}
return
doc
;
...
...
@@ -361,7 +377,7 @@ class PrettyPrinter :
// wrap GNFed let in brackets
Doc
body
;
printed_expr
<<
"{"
;
printed_expr
<<
Indent
(
2
,
body
<<
"
\n
"
<<
VisitExpr
(
expr
))
<<
"
\n
"
;
printed_expr
<<
Indent
(
2
,
body
<<
PrintNewLine
()
<<
VisitExpr
(
expr
))
<<
PrintNewLine
()
;
printed_expr
<<
"}"
;
}
else
{
printed_expr
=
VisitExpr
(
expr
);
...
...
@@ -373,7 +389,7 @@ class PrettyPrinter :
if
(
expr
.
as
<
VarNode
>
())
{
// This is our first time visiting the var and we hit the VarNode case
// in the visitor. Thus the variable is free.
doc_stack_
.
back
()
<<
"free_var "
<<
printed_expr
<<
"
\n
"
;
doc_stack_
.
back
()
<<
"free_var "
<<
printed_expr
<<
PrintNewLine
()
;
// Memoization is done in AllocVar.
return
memo_
[
expr
];
}
else
if
(
inline_expr
)
{
...
...
@@ -422,7 +438,7 @@ class PrettyPrinter :
fields
.
push_back
(
Print
(
field
));
}
Doc
doc
;
doc
<<
"("
<<
Print
Vec
(
fields
);
doc
<<
"("
<<
Print
Sep
(
fields
);
// conform to python tuple format (1,)
if
(
op
->
fields
.
size
()
==
1
)
{
doc
<<
","
;
...
...
@@ -460,31 +476,31 @@ class PrettyPrinter :
}
Doc
PrintFunc
(
const
Doc
&
prefix
,
const
Function
&
fn
)
{
Doc
doc
;
doc
<<
prefix
;
if
(
fn
->
type_params
.
size
()
>
0
)
{
doc
<<
"<"
;
std
::
vector
<
Doc
>
type_params
;
for
(
const
TypeVar
&
tv
:
fn
->
type_params
)
{
type_params
.
push_back
(
AllocTypeVar
(
tv
));
}
doc
<<
PrintVec
(
type_params
);
doc
<<
">"
;
}
doc
<<
"("
;
std
::
vector
<
Doc
>
params
;
for
(
Var
param
:
fn
->
params
)
{
params
.
push_back
(
AllocVar
(
param
));
}
for
(
const
Doc
&
d
:
PrintFuncAttrs
(
fn
->
attrs
))
{
params
.
push_back
(
d
);
}
doc
<<
PrintVec
(
params
)
<<
") "
;
if
(
fn
->
ret_type
.
defined
())
{
doc
<<
"-> "
<<
Print
(
fn
->
ret_type
)
<<
" "
;
Doc
doc
;
doc
<<
prefix
;
if
(
fn
->
type_params
.
size
()
>
0
)
{
doc
<<
"<"
;
std
::
vector
<
Doc
>
type_params
;
for
(
const
TypeVar
&
tv
:
fn
->
type_params
)
{
type_params
.
push_back
(
AllocTypeVar
(
tv
));
}
doc
<<
PrintBody
(
fn
->
body
);
return
doc
;
doc
<<
PrintSep
(
type_params
);
doc
<<
">"
;
}
doc
<<
"("
;
std
::
vector
<
Doc
>
params
;
for
(
Var
param
:
fn
->
params
)
{
params
.
push_back
(
AllocVar
(
param
));
}
for
(
const
Doc
&
d
:
PrintFuncAttrs
(
fn
->
attrs
))
{
params
.
push_back
(
d
);
}
doc
<<
PrintSep
(
params
)
<<
") "
;
if
(
fn
->
ret_type
.
defined
())
{
doc
<<
"-> "
<<
Print
(
fn
->
ret_type
)
<<
" "
;
}
doc
<<
PrintBody
(
fn
->
body
);
return
doc
;
}
Doc
PrintMod
(
const
Module
&
mod
)
{
...
...
@@ -493,13 +509,13 @@ class PrettyPrinter :
for
(
const
auto
&
kv
:
mod
->
functions
)
{
dg_
=
DependencyGraph
::
Create
(
&
arena_
,
kv
.
second
);
std
::
ostringstream
os
;
if
(
counter
++
!=
0
)
{
doc
<<
"
\n
"
;
doc
<<
PrintNewLine
()
;
}
std
::
ostringstream
os
;
os
<<
"def @"
<<
kv
.
first
->
name_hint
;
doc
<<
PrintFunc
(
Doc
(
os
.
str
()),
kv
.
second
);
doc
<<
"
\n
"
;
doc
<<
PrintNewLine
()
;
}
return
doc
;
}
...
...
@@ -528,7 +544,7 @@ class PrettyPrinter :
args
.
push_back
(
d
);
}
doc
<<
Print
(
op
->
op
);
return
doc
<<
"("
<<
Print
Vec
(
args
)
<<
")"
;
return
doc
<<
"("
<<
Print
Sep
(
args
)
<<
")"
;
}
Doc
VisitExpr_
(
const
RefCreateNode
*
op
)
final
{
...
...
@@ -558,7 +574,7 @@ class PrettyPrinter :
clauses
.
push_back
(
clause_doc
<<
Print
(
clause
->
lhs
)
<<
" -> "
<<
Print
(
clause
->
rhs
));
}
doc
<<
Indent
(
2
,
body
<<
"
\n
"
<<
PrintVec
(
clauses
,
Doc
(
"
\n
"
)))
<<
"
\n
"
;
doc
<<
Indent
(
2
,
body
<<
PrintNewLine
()
<<
PrintSep
(
clauses
,
PrintNewLine
()))
<<
PrintNewLine
()
;
doc
<<
"}"
;
return
doc
;
}
...
...
@@ -570,7 +586,7 @@ class PrettyPrinter :
for
(
const
auto
&
pat
:
p
->
patterns
)
{
pats
.
push_back
(
Print
(
pat
));
}
return
doc
<<
Print
Vec
(
pats
)
<<
")"
;
return
doc
<<
Print
Sep
(
pats
)
<<
")"
;
}
Doc
VisitPattern_
(
const
PatternVarNode
*
pv
)
final
{
...
...
@@ -617,7 +633,7 @@ class PrettyPrinter :
args
.
push_back
(
PrintType
(
t
,
false
));
}
doc
<<
"["
;
doc
<<
Print
Vec
(
args
);
doc
<<
Print
Sep
(
args
);
doc
<<
"]"
;
return
doc
;
}
...
...
@@ -633,11 +649,7 @@ class PrettyPrinter :
for
(
NodeRef
shape
:
node
->
shape
)
{
shapes
.
push_back
(
PrintAttr
(
shape
));
}
doc
<<
PrintVec
(
shapes
);
// conform to python tuple format (1,)
if
(
node
->
shape
.
size
()
==
1
)
{
doc
<<
","
;
}
doc
<<
PrintSep
(
shapes
);
return
doc
<<
"), "
<<
PrintDType
(
node
->
dtype
)
<<
"]"
;
}
...
...
@@ -647,7 +659,7 @@ class PrettyPrinter :
fields
.
push_back
(
Print
(
field
));
}
Doc
doc
;
doc
<<
"("
<<
Print
Vec
(
fields
);
doc
<<
"("
<<
Print
Sep
(
fields
);
// conform to python tuple format (1,)
if
(
node
->
fields
.
size
()
==
1
)
{
doc
<<
","
;
...
...
@@ -664,14 +676,14 @@ class PrettyPrinter :
for
(
Type
type_param
:
node
->
type_params
)
{
type_params
.
push_back
(
Print
(
type_param
));
}
doc
<<
Print
Vec
(
type_params
);
doc
<<
Print
Sep
(
type_params
);
doc
<<
">"
;
}
std
::
vector
<
Doc
>
arg_types
;
for
(
Type
arg_type
:
node
->
arg_types
)
{
arg_types
.
push_back
(
Print
(
arg_type
));
}
return
doc
<<
"("
<<
Print
Vec
(
arg_types
)
<<
") -> "
<<
Print
(
node
->
ret_type
);
return
doc
<<
"("
<<
Print
Sep
(
arg_types
)
<<
") -> "
<<
Print
(
node
->
ret_type
);
}
Doc
VisitType_
(
const
RefTypeNode
*
node
)
final
{
...
...
@@ -710,7 +722,7 @@ class PrettyPrinter :
for
(
NodePtr
<
Node
>
val
:
op
->
data
)
{
arr_vals
.
push_back
(
PrintAttr
(
NodeRef
(
val
)));
}
doc
<<
Print
Vec
(
arr_vals
);
doc
<<
Print
Sep
(
arr_vals
);
doc
<<
"]"
;
return
doc
;
}
...
...
@@ -771,7 +783,9 @@ class PrettyPrinter::AttrPrinter : public AttrVisitor {
}
void
Visit
(
const
char
*
key
,
double
*
value
)
final
{
PrintKV
(
key
,
*
value
);
Doc
doc
;
doc
<<
key
<<
"="
<<
*
value
<<
"f"
;
docs
->
push_back
(
doc
);
}
void
Visit
(
const
char
*
key
,
int64_t
*
value
)
final
{
PrintKV
(
key
,
*
value
);
...
...
@@ -843,7 +857,7 @@ std::string PrettyPrint_(const NodeRef& node,
bool
show_meta_data
,
runtime
::
TypedPackedFunc
<
std
::
string
(
Expr
)
>
annotate
)
{
Doc
doc
;
doc
<<
"v0.0.3"
<<
"
\n
"
doc
<<
"v0.0.3"
<<
PrintNewLine
()
<<
PrettyPrinter
(
show_meta_data
,
annotate
).
PrintFinal
(
node
);
return
doc
.
str
();
}
...
...
tests/python/relay/test_ir_parser.py
View file @
2973f8a6
...
...
@@ -16,7 +16,7 @@
# under the License.
import
tvm
from
tvm
import
relay
from
tvm.relay.analysis
import
alpha_equal
from
tvm.relay.analysis
import
alpha_equal
,
assert_alpha_equal
from
nose.tools
import
nottest
,
raises
from
numpy
import
isclose
from
typing
import
Union
...
...
@@ -60,12 +60,9 @@ TYPES = {
"float16x4"
,
}
def
assert_alpha_equal
(
a
,
b
):
if
not
alpha_equal
(
a
,
b
):
raise
Exception
(
"lhs is: "
,
str
(
a
),
"rhs is: "
,
str
(
b
))
def
roundtrip
(
expr
):
assert_alpha_equal
(
relay
.
fromtext
(
str
(
expr
)),
expr
)
x
=
relay
.
fromtext
(
str
(
expr
))
assert_alpha_equal
(
x
,
expr
)
def
parse_text
(
code
):
...
...
@@ -112,6 +109,16 @@ def test_comments():
UNIT
)
assert
parses_as
(
"""
/* This is a block comment!
/*Block comment is recursive!*/
*/
()
"""
,
UNIT
)
def
test_int_literal
():
assert
isinstance
(
parse_text
(
"1"
),
relay
.
Constant
)
...
...
@@ -224,7 +231,7 @@ def test_let():
def
test_seq
():
assert
parses_as
(
"(); ()"
,
"();
;
()"
,
relay
.
Let
(
_
,
UNIT
,
...
...
@@ -538,7 +545,7 @@ def test_tensor_type():
)
assert
parses_as
(
"let
%
_ : Tensor[(1
,
), float32] = (); ()"
,
"let
%
_ : Tensor[(1), float32] = (); ()"
,
relay
.
Let
(
relay
.
Var
(
"_"
,
relay
.
TensorType
((
1
,),
"float32"
)),
UNIT
,
...
...
tests/python/relay/test_ir_text_printer.py
View file @
2973f8a6
...
...
@@ -15,14 +15,27 @@
# specific language governing permissions and limitations
# under the License.
import
tvm
from
tvm
import
relay
import
tvm.relay.testing
import
numpy
as
np
from
tvm
import
relay
from
tvm.relay
import
Expr
from
tvm.relay.analysis
import
alpha_equal
,
assert_alpha_equal
,
assert_graph_equal
,
free_vars
do_print
=
[
False
]
SEMVER
=
"v0.0.3
\n
"
def
astext
(
p
,
graph_equal
=
False
):
txt
=
p
.
astext
()
if
isinstance
(
p
,
Expr
)
and
free_vars
(
p
):
return
txt
x
=
relay
.
fromtext
(
txt
)
if
graph_equal
:
assert_graph_equal
(
x
,
p
)
else
:
assert_alpha_equal
(
x
,
p
)
return
txt
def
show
(
text
):
if
do_print
[
0
]:
print
(
"---------------------------"
)
...
...
@@ -35,8 +48,8 @@ def test_func():
z
=
relay
.
add
(
x
,
one
)
z
=
relay
.
add
(
z
,
z
)
f
=
relay
.
Function
([
x
,
y
],
z
)
show
(
z
.
astext
(
))
show
(
f
.
astext
(
))
show
(
astext
(
z
))
show
(
astext
(
f
))
def
test_env
():
...
...
@@ -47,7 +60,7 @@ def test_env():
f
=
relay
.
Function
([
x
,
y
],
z
)
env
=
relay
.
Module
()
env
[
"myf"
]
=
f
text
=
env
.
astext
(
)
text
=
astext
(
env
)
assert
"def @myf"
in
text
assert
"def @myf"
in
str
(
env
)
assert
"add(
%0
,
%0
) /* ty=float32 */"
in
text
...
...
@@ -65,7 +78,7 @@ def test_meta_data():
padding
=
(
1
,
1
),
channels
=
2
)
f
=
relay
.
Function
([
x
,
w
],
z
)
text
=
f
.
astext
(
)
text
=
astext
(
f
,
graph_equal
=
True
)
text_no_meta
=
str
(
f
)
assert
"channels=2"
in
text
assert
"channels=2"
in
text_no_meta
...
...
@@ -73,25 +86,22 @@ def test_meta_data():
assert
"meta[Variable][0]"
in
text_no_meta
assert
"type_key"
in
text
assert
"type_key"
not
in
text_no_meta
show
(
text
)
show
(
f
)
text
=
relay
.
const
([
1
,
2
,
3
])
.
astext
(
)
text
=
astext
(
relay
.
const
([
1
,
2
,
3
])
)
assert
"meta[relay.Constant][0]"
in
text
show
(
text
)
def
test_call_attrs
():
x
=
relay
.
var
(
"x"
)
# non default args
z
=
relay
.
nn
.
softmax
(
x
,
axis
=
2
)
assert
"axis=2"
in
z
.
astext
(
)
assert
"axis=2"
in
astext
(
z
)
# default args
z
=
relay
.
nn
.
softmax
(
x
)
assert
"softmax(
%
x)"
in
z
.
astext
(
)
assert
"softmax(
%
x)"
in
astext
(
z
)
# non default args
z
=
relay
.
expand_dims
(
x
,
axis
=
2
,
num_newaxis
=
2
)
assert
"num_newaxis=2"
in
z
.
astext
(
)
assert
"num_newaxis=2"
in
astext
(
z
)
def
test_let_if_scope
():
...
...
@@ -111,68 +121,72 @@ def test_let_if_scope():
result
=
sb
.
get
()
f
=
relay
.
Function
([
x
,
y
,
cond
],
result
)
text
=
f
.
astext
(
)
text
=
astext
(
f
)
assert
text
.
count
(
"{"
)
==
4
assert
"
%
cond: bool"
in
text
show
(
f
.
astext
(
))
show
(
astext
(
f
))
def
test_variable_name
():
# avoid pure number even if the namehint is pure number
v1
=
relay
.
var
(
"1"
)
assert
"
%
v1"
in
v1
.
astext
(
)
assert
"
%
v1"
in
astext
(
v1
)
def
test_mlp
():
net
,
params
=
tvm
.
relay
.
testing
.
mlp
.
get_workload
(
batch_size
=
1
)
net
.
astext
(
)
astext
(
net
)
def
test_resnet
():
net
,
params
=
tvm
.
relay
.
testing
.
resnet
.
get_workload
(
batch_size
=
1
)
net
.
astext
(
)
astext
(
net
)
def
test_mobilenet
():
net
,
params
=
tvm
.
relay
.
testing
.
mobilenet
.
get_workload
(
batch_size
=
1
)
net
.
astext
(
)
astext
(
net
)
def
test_dqn
():
net
,
params
=
tvm
.
relay
.
testing
.
dqn
.
get_workload
(
batch_size
=
1
)
net
.
astext
(
)
astext
(
net
)
def
test_dcgan
():
net
,
params
=
tvm
.
relay
.
testing
.
dcgan
.
get_workload
(
batch_size
=
1
)
net
.
astext
(
)
astext
(
net
)
def
test_lstm
():
net
,
params
=
tvm
.
relay
.
testing
.
lstm
.
get_workload
(
1
,
1
)
astext
(
net
)
net
,
params
=
tvm
.
relay
.
testing
.
lstm
.
get_workload
(
4
,
4
)
net
.
astext
(
)
astext
(
net
)
def
test_inception_v3
():
net
,
params
=
tvm
.
relay
.
testing
.
inception_v3
.
get_workload
(
batch_size
=
1
)
net
.
astext
(
)
astext
(
net
)
def
test_squeezenet
():
for
version
in
[
'1.0'
,
'1.1'
]:
net
,
params
=
tvm
.
relay
.
testing
.
squeezenet
.
get_workload
(
batch_size
=
1
,
version
=
version
)
net
.
astext
(
)
astext
(
net
)
def
test_vgg
():
net
,
params
=
tvm
.
relay
.
testing
.
vgg
.
get_workload
(
batch_size
=
1
)
net
.
astext
(
)
astext
(
net
)
def
test_densenet
():
net
,
params
=
tvm
.
relay
.
testing
.
densenet
.
get_workload
(
batch_size
=
1
)
net
.
astext
(
)
astext
(
net
)
def
test_call_node_order
():
x
=
relay
.
var
(
"x"
)
y
=
relay
.
var
(
"y"
)
assert
relay
.
Call
(
relay
.
Function
([
x
],
x
),
[
relay
.
Call
(
relay
.
Function
([
y
],
y
),
[
relay
.
const
(
1
)])])
.
astext
()
==
SEMVER
+
\
prog
=
relay
.
Call
(
relay
.
Function
([
x
],
x
),
[
relay
.
Call
(
relay
.
Function
([
y
],
y
),
[
relay
.
const
(
1
)])])
assert
astext
(
prog
)
==
SEMVER
+
\
(
"
%0
= fn (
%
y) {
\n
"
"
%
y
\n
"
"};
\n
"
...
...
@@ -185,17 +199,25 @@ def test_call_node_order():
def
test_let_inlining
():
tup
=
relay
.
Tuple
([
relay
.
const
(
0
),
relay
.
const
(
0
)])
x
=
relay
.
var
(
"x"
)
assert
relay
.
Let
(
x
,
tup
,
tup
)
.
astext
(
)
==
SEMVER
+
\
assert
astext
(
relay
.
Let
(
x
,
tup
,
tup
)
)
==
SEMVER
+
\
(
"
%0
= (0, 0);
\n
"
"let
%
x =
%0
;
\n
"
"
%0
"
)
assert
relay
.
Let
(
x
,
tup
,
x
)
.
astext
(
)
==
SEMVER
+
\
assert
astext
(
relay
.
Let
(
x
,
tup
,
x
)
)
==
SEMVER
+
\
(
"let
%
x = (0, 0);
\n
"
"
%
x"
)
def
test_zeros
():
x
=
relay
.
op
.
zeros
([],
"float32"
)
astext
(
x
)
if
__name__
==
"__main__"
:
do_print
[
0
]
=
True
test_lstm
()
test_zeros
()
test_meta_data
()
test_let_inlining
()
test_resnet
()
test_mobilenet
()
test_mlp
()
...
...
@@ -207,9 +229,7 @@ if __name__ == "__main__":
test_densenet
()
test_func
()
test_env
()
test_meta_data
()
test_call_attrs
()
test_let_if_scope
()
test_variable_name
()
test_call_node_order
()
test_let_inlining
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment