Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
d274e4b3
Commit
d274e4b3
authored
Jan 17, 2019
by
Jared Roesch
Committed by
Tianqi Chen
Jan 17, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[Relay][Parser] Improve Relay parser and pretty printing, including CMAKE (#2377)
parent
d0f83664
Show whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
381 additions
and
200 deletions
+381
-200
cmake/modules/ANTLR.cmake
+19
-5
include/tvm/relay/base.h
+3
-1
python/tvm/relay/_base.py
+5
-0
python/tvm/relay/_parser.py
+110
-19
python/tvm/relay/base.py
+10
-0
python/tvm/relay/grammar/Relay.g4
+34
-20
python/tvm/relay/parser.py
+9
-3
src/relay/ir/base.cc
+14
-0
tests/python/relay/test_ir_parser.py
+177
-152
No files found.
cmake/modules/ANTLR.cmake
View file @
d274e4b3
if
(
USE_ANTLR
)
if
(
EXISTS /usr/local/lib/antlr-4.7.1-complete.jar
)
set
(
ANTLR4
"/usr/local/lib/antlr-4.7.1-complete.jar"
)
file
(
GLOB_RECURSE ANTLR4
/usr/local/lib/antlr-*-complete.jar
/usr/local/Cellar/*antlr-*-complete.jar
)
if
(
DEFINED ANTLR4
)
# Get the first element of the list of antlr jars.
# Sort and reverse the list so the item selected is the highest
# version in lib or else in Cellar if no lib installation exists.
list
(
SORT ANTLR4
)
list
(
REVERSE ANTLR4
)
list
(
GET ANTLR4 0 ANTLR4
)
set
(
RELAY_PARSER_DIR
${
CMAKE_CURRENT_SOURCE_DIR
}
/python/tvm/relay/grammar
)
...
...
@@ -14,15 +22,21 @@ if(USE_ANTLR)
${
RELAY_PARSER_DIR
}
/py3/RelayParser.py
${
RELAY_PARSER_DIR
}
/py3/RelayLexer.py
)
set
(
JAVA_HOME $ENV{JAVA_HOME}
)
if
(
NOT DEFINED JAVA_HOME
)
# Hack to get system to search for Java itself.
set
(
JAVA_HOME
"/usr"
)
endif
()
# Generate ANTLR grammar for parsing.
add_custom_command
(
OUTPUT
${
RELAY_PARSER
}
COMMAND $
ENV
{JAVA_HOME}/bin/java -jar
${
ANTLR4
}
-visitor -no-listener -Dlanguage=Python2
${
RELAY_PARSER_DIR
}
/Relay.g4 -o
${
RELAY_PARSER_DIR
}
/py2
COMMAND $
ENV
{JAVA_HOME}/bin/java -jar
${
ANTLR4
}
-visitor -no-listener -Dlanguage=Python3
${
RELAY_PARSER_DIR
}
/Relay.g4 -o
${
RELAY_PARSER_DIR
}
/py3
COMMAND
${
JAVA_HOME
}
/bin/java -jar
${
ANTLR4
}
-visitor -no-listener -Dlanguage=Python2
${
RELAY_PARSER_DIR
}
/Relay.g4 -o
${
RELAY_PARSER_DIR
}
/py2
COMMAND
${
JAVA_HOME
}
/bin/java -jar
${
ANTLR4
}
-visitor -no-listener -Dlanguage=Python3
${
RELAY_PARSER_DIR
}
/Relay.g4 -o
${
RELAY_PARSER_DIR
}
/py3
DEPENDS
${
RELAY_PARSER_DIR
}
/Relay.g4
WORKING_DIRECTORY
${
RELAY_PARSER_DIR
}
)
add_custom_target
(
relay_parser ALL DEPENDS
${
RELAY_PARSER
}
)
else
()
message
(
FATAL_ERROR
"Can't find ANTLR4
!"
)
message
(
FATAL_ERROR
"Can't find ANTLR4
: ANTLR4="
${
ANTLR4
}
)
endif
()
endif
(
USE_ANTLR
)
include/tvm/relay/base.h
View file @
d274e4b3
...
...
@@ -108,7 +108,9 @@ class SourceName : public NodeRef {
* \brief access the internal node container
* \return the pointer to the internal node container
*/
inline
const
SourceNameNode
*
operator
->
()
const
;
inline
const
SourceNameNode
*
operator
->
()
const
{
return
static_cast
<
SourceNameNode
*>
(
this
->
node_
.
get
());
}
/*!
* \brief Get an SourceName for a given operator name.
...
...
python/tvm/relay/_base.py
0 → 100644
View file @
d274e4b3
# pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable
"""The interface of expr function exposed from C++."""
from
tvm._ffi.function
import
_init_api
_init_api
(
"relay._base"
,
__name__
)
python/tvm/relay/_parser.py
View file @
d274e4b3
...
...
@@ -6,13 +6,17 @@ from __future__ import absolute_import
import
sys
from
collections
import
deque
from
typing
import
TypeVar
,
Deque
,
Tuple
,
Optional
,
Union
,
NamedTuple
,
List
,
Callable
,
Any
from
typing
import
TypeVar
,
Deque
,
Tuple
,
Optional
,
Union
,
NamedTuple
,
List
,
Callable
,
Any
,
Dict
import
tvm
from
.
import
module
from
.base
import
Span
,
SourceName
from
.
import
expr
from
.
import
ty
from
.
import
op
class
ParseError
(
Exception
):
"""Exception type for parse errors."""
...
...
@@ -76,22 +80,46 @@ def lookup(scopes, name):
return
val
return
None
def
spanify
(
f
):
"""A decorator which attaches span information
to the value returned by calling `f`.
Intended for use with the below AST visiting
methods. The idea is that after we do the work
of constructing the AST we attach Span information.
"""
def
_wrapper
(
*
args
,
**
kwargs
):
# Assumes 0th arg is self and gets source_name from object.
sn
=
args
[
0
]
.
source_name
# Assumes 1st arg is an ANTLR parser context.
ctx
=
args
[
1
]
ast
=
f
(
*
args
,
**
kwargs
)
line
,
col
=
ctx
.
getSourceInterval
()
sp
=
Span
(
sn
,
line
,
col
)
ast
.
set_span
(
sp
)
return
ast
return
_wrapper
# TODO(@jmp): Use https://stackoverflow.com/q/13889941
# to figure out how to get ANTLR4 to be more unhappy about syntax errors
class
ParseTreeToRelayIR
(
RelayVisitor
):
"""Parse Relay text format into Relay IR."""
def
__init__
(
self
):
# type: () -> None
def
__init__
(
self
,
source_name
):
# type: (str) -> None
self
.
source_name
=
source_name
self
.
module
=
module
.
Module
({})
# type: module.Module
# Adding an empty scope allows naked lets without pain.
self
.
var_scopes
=
deque
([
deque
()])
# type: Scopes[expr.Var]
self
.
global_var_scope
=
deque
()
# type: Scope[expr.GlobalVar]
self
.
type_param_scopes
=
deque
([
deque
()])
# type: Scopes[ty.TypeVar]
self
.
graph_expr
=
[]
# type: List[expr.Expr]
super
(
ParseTreeToRelayIR
,
self
)
.
__init__
()
def
enter_var_scope
(
self
):
# type: () -> None
"""Enter a new Var scope so it can be popped off later."""
...
...
@@ -146,20 +174,25 @@ class ParseTreeToRelayIR(RelayVisitor):
node_type
=
node
.
getSymbol
()
.
type
node_text
=
node
.
getText
()
name
=
node_text
[
1
:]
# variables
if
node_type
==
RelayLexer
.
GLOBAL_VAR
:
return
lookup
(
[
self
.
global_var_scope
]
,
node_text
[
1
:])
return
lookup
(
deque
([
self
.
global_var_scope
])
,
node_text
[
1
:])
elif
node_type
==
RelayLexer
.
LOCAL_VAR
:
name
=
node_text
[
1
:]
# Remove the leading '%' and lookup the name.
var
=
lookup
(
self
.
var_scopes
,
name
)
if
var
is
None
:
raise
ParseError
(
"Couldn't resolve `{}`."
.
format
(
name
))
return
var
elif
node_type
==
RelayLexer
.
GRAPH_VAR
:
try
:
return
self
.
graph_expr
[
int
(
name
)]
except
IndexError
:
raise
ParseError
(
"Couldn't resolve `{}`"
.
format
(
name
))
# data types
elif
node_type
==
RelayLexer
.
IN
T
:
elif
node_type
==
RelayLexer
.
NA
T
:
return
int
(
node_text
)
elif
node_type
==
RelayLexer
.
FLOAT
:
return
float
(
node_text
)
...
...
@@ -190,7 +223,7 @@ class ParseTreeToRelayIR(RelayVisitor):
return
self
.
visit
(
ctx
)
def
visitProg
(
self
,
ctx
):
# type: (RelayParser.ProgContext) -> Union[expr.Expr,
env.Environment
]
# type: (RelayParser.ProgContext) -> Union[expr.Expr,
module.Module
]
if
ctx
.
defn
():
self
.
visit_list
(
ctx
.
defn
())
return
self
.
module
...
...
@@ -219,7 +252,7 @@ class ParseTreeToRelayIR(RelayVisitor):
def
visitScalarInt
(
self
,
ctx
):
# type: (RelayParser.ScalarIntContext) -> expr.Constant
return
expr
.
const
(
self
.
visit
(
ctx
.
IN
T
()))
return
expr
.
const
(
self
.
visit
(
ctx
.
NA
T
()))
def
visitScalarBool
(
self
,
ctx
):
# type: (RelayParser.ScalarBoolContext) -> expr.Constant
...
...
@@ -240,7 +273,7 @@ class ParseTreeToRelayIR(RelayVisitor):
return
expr
.
Tuple
(
tup
)
# Currently doesn't support mutable sequencing.
def
visit
Seq
(
self
,
ctx
):
def
visit
Let
(
self
,
ctx
):
# type: (RelayParser.SeqContext) -> expr.Let
"""Desugar various sequence constructs to Relay Let nodes."""
if
ctx
.
MUT
()
is
not
None
:
...
...
@@ -253,7 +286,7 @@ class ParseTreeToRelayIR(RelayVisitor):
else
:
local_var
=
ctx
.
var
()
.
ident
()
.
LOCAL_VAR
()
if
local_var
is
None
:
raise
ParseError
(
'Only local ids may be used in `let`s.'
)
raise
ParseError
(
"Only local ids may be used in `let`s."
)
ident
=
local_var
.
getText
()[
1
:]
type_
=
self
.
getType_
(
ctx
.
var
()
.
type_
())
...
...
@@ -278,12 +311,14 @@ class ParseTreeToRelayIR(RelayVisitor):
return
relay_op
(
arg0
,
arg1
)
@spanify
def
visitVar
(
self
,
ctx
):
# type: (RelayParser.VarContext) -> expr.Var
"""Visit a single variable."""
ident
=
ctx
.
ident
()
.
LOCAL_VAR
()
if
ident
is
None
:
raise
ParseError
(
'Only local ids may be used in params.'
)
raise
ParseError
(
"Only local ids may be used in vars."
)
type_
=
self
.
getType_
(
ctx
.
type_
())
...
...
@@ -293,15 +328,33 @@ class ParseTreeToRelayIR(RelayVisitor):
# type: (RelayParser.VarListContext) -> List[expr.Var]
return
self
.
visit_list
(
ctx
.
var
())
# TODO: support a larger class of values than just Relay exprs
def
visitAttr
(
self
,
ctx
):
# type: (RelayParser.AttrContext) -> Tuple[str, expr.Expr]
return
(
ctx
.
CNAME
()
.
getText
(),
self
.
visit
(
ctx
.
expr
()))
def
visitAttrList
(
self
,
ctx
):
# type: (RelayParser.AttrListContext) -> Dict[str, expr.Expr]
return
dict
(
self
.
visit_list
(
ctx
.
attr
()))
def
visitArgList
(
self
,
ctx
# type: RelayParser.ArgListContext
):
# type: (...) -> Tuple[Optional[List[expr.Var]], Optional[Dict[str, expr.Expr]]]
var_list
=
self
.
visit
(
ctx
.
varList
())
if
ctx
.
varList
()
else
None
attr_list
=
self
.
visit
(
ctx
.
attrList
())
if
ctx
.
attrList
()
else
None
return
(
var_list
,
attr_list
)
def
mk_func
(
self
,
ctx
):
# type: (Union[RelayParser.FuncContext, RelayParser.DefnContext]) -> Function
# type: (Union[RelayParser.FuncContext, RelayParser.DefnContext]) ->
expr.
Function
"""Construct a function from either a Func or Defn."""
# Enter var scope early to put params in scope.
self
.
enter_var_scope
()
# Capture type params in params.
self
.
enter_type_param_scope
()
var_list
=
self
.
visit
(
ctx
.
var
List
())
var_list
,
attr_list
=
self
.
visit
(
ctx
.
arg
List
())
ret_type
=
self
.
getType_
(
ctx
.
type_
())
type_params
=
list
(
self
.
exit_type_param_scope
())
...
...
@@ -311,22 +364,28 @@ class ParseTreeToRelayIR(RelayVisitor):
body
=
self
.
visit
(
ctx
.
body
())
self
.
exit_var_scope
()
return
expr
.
Function
(
var_list
,
body
,
ret_type
,
type_params
)
# type: ignor
e
attrs
=
tvm
.
make
.
node
(
"DictAttrs"
,
**
attr_list
)
if
attr_list
is
not
None
else
Non
e
return
expr
.
Function
(
var_list
,
body
,
ret_type
,
type_params
,
attrs
)
@spanify
def
visitFunc
(
self
,
ctx
):
# type: (RelayParser.FuncContext) -> expr.Function
return
self
.
mk_func
(
ctx
)
# TODO: how to set spans for definitions?
# @spanify
def
visitDefn
(
self
,
ctx
):
# type: (RelayParser.DefnContext) -> None
ident
=
ctx
.
ident
()
.
GLOBAL_VAR
()
if
ident
is
None
:
raise
ParseError
(
'Only global ids may be used in `def`s.'
)
raise
ParseError
(
"Only global ids may be used in `def`s."
)
ident_name
=
ident
.
getText
()[
1
:]
ident
=
self
.
mk_global_var
(
ident_name
)
self
.
module
[
ident
]
=
self
.
mk_func
(
ctx
)
@spanify
def
visitCall
(
self
,
ctx
):
# type: (RelayParser.CallContext) -> expr.Call
visited_exprs
=
self
.
visit_list
(
ctx
.
expr
())
...
...
@@ -336,6 +395,7 @@ class ParseTreeToRelayIR(RelayVisitor):
return
expr
.
Call
(
func
,
args
,
None
,
None
)
@spanify
def
visitIfElse
(
self
,
ctx
):
# type: (RelayParser.IfElseContext) -> expr.If
"""Construct a Relay If node. Creates a new scope for each branch."""
...
...
@@ -351,6 +411,27 @@ class ParseTreeToRelayIR(RelayVisitor):
return
expr
.
If
(
cond
,
true_branch
,
false_branch
)
@spanify
def
visitGraph
(
self
,
ctx
):
# type: (RelayParser.GraphContext) -> expr.Expr
"""Visit a graph variable assignment."""
if
ctx
.
ident
()
.
GRAPH_VAR
()
is
None
:
raise
ParseError
(
"Expected a graph var, but got `{}`"
.
format
(
ctx
.
ident
()
.
getText
()))
graph_nid
=
int
(
ctx
.
ident
()
.
GRAPH_VAR
()
.
getText
()[
1
:])
self
.
enter_var_scope
()
value
=
self
.
visit
(
ctx
.
expr
(
0
))
self
.
exit_var_scope
()
if
graph_nid
!=
len
(
self
.
graph_expr
):
raise
ParseError
(
"Expected new graph variable to be `
%
{}`,"
.
format
(
len
(
self
.
graph_expr
))
+
\
"but got `
%
{}`"
.
format
(
graph_nid
))
self
.
graph_expr
.
append
(
value
)
kont
=
self
.
visit
(
ctx
.
expr
(
1
))
return
kont
# Types
# pylint: disable=unused-argument
...
...
@@ -428,8 +509,18 @@ def make_parser(data):
token_stream
=
CommonTokenStream
(
lexer
)
return
RelayParser
(
token_stream
)
def
fromtext
(
data
):
# type: (str) -> Union[expr.Expr, env.Environment]
__source_name_counter__
=
0
def
fromtext
(
data
,
source_name
=
None
):
# type: (str, str) -> Union[expr.Expr, module.Module]
"""Parse a Relay program."""
global
__source_name_counter__
if
source_name
is
None
:
source_name
=
"source_file{0}"
.
format
(
__source_name_counter__
)
if
isinstance
(
source_name
,
str
):
source_name
=
SourceName
(
source_name
)
tree
=
make_parser
(
data
)
.
prog
()
return
ParseTreeToRelayIR
()
.
visit
(
tree
)
return
ParseTreeToRelayIR
(
source_name
)
.
visit
(
tree
)
python/tvm/relay/base.py
View file @
d274e4b3
...
...
@@ -4,6 +4,7 @@ from __future__ import absolute_import as _abs
from
.._ffi.node
import
NodeBase
,
register_node
as
_register_tvm_node
from
.
import
_make
from
.
import
_expr
from
.
import
_base
NodeBase
=
NodeBase
...
...
@@ -63,6 +64,9 @@ class RelayNode(NodeBase):
"""
return
_expr
.
RelayPrint
(
self
,
show_meta_data
,
annotate
)
def
set_span
(
self
,
span
):
_base
.
set_span
(
self
,
span
)
@register_relay_node
class
Span
(
RelayNode
):
...
...
@@ -71,6 +75,12 @@ class Span(RelayNode):
def
__init__
(
self
,
source
,
lineno
,
col_offset
):
self
.
__init_handle_by_constructor__
(
_make
.
Span
,
source
,
lineno
,
col_offset
)
@register_relay_node
class
SourceName
(
RelayNode
):
"""A identifier for a source location"""
def
__init__
(
self
,
name
):
self
.
__init_handle_by_constructor__
(
_make
.
SourceName
,
name
)
@register_relay_node
class
Id
(
NodeBase
):
...
...
python/tvm/relay/grammar/Relay.g4
View file @
d274e4b3
grammar Relay;
SEMVER: 'v0.0.1' ;
// Lexing
// comments
WS : [ \t\n\r]+ -> skip ;
...
...
@@ -20,7 +22,8 @@ NE: '!=' ;
opIdent: CNAME ;
GLOBAL_VAR: '@' CNAME ;
LOCAL_VAR: '%' CNAME ;
LOCAL_VAR: '%' CNAME;
GRAPH_VAR: '%' NAT;
MUT: 'mut' ;
...
...
@@ -31,13 +34,13 @@ BOOL_LIT
// non-negative floats
FLOAT
:
INT '.' IN
T EXP? // 1.35, 1.35E-9, 0.3, 4.5
|
IN
T EXP // 1e10 3e4
:
NAT '.' NA
T EXP? // 1.35, 1.35E-9, 0.3, 4.5
|
NA
T EXP // 1e10 3e4
;
// non-negative ints
IN
T: DIGIT+ ;
fragment EXP: [eE] [+\-]?
IN
T ; // \- since - means "range" inside [...]
NA
T: DIGIT+ ;
fragment EXP: [eE] [+\-]?
NA
T ; // \- since - means "range" inside [...]
CNAME: ('_'|LETTER) ('_'|LETTER|DIGIT)* ;
fragment LETTER: [a-zA-Z] ;
...
...
@@ -46,7 +49,7 @@ fragment DIGIT: [0-9] ;
// Parsing
// A Relay program is a list of global definitions or an expression.
prog: (defn* | expr) EOF ;
prog:
SEMVER
(defn* | expr) EOF ;
// option: 'set' ident BOOL_LIT ;
...
...
@@ -73,10 +76,11 @@ expr
| 'if' '(' expr ')' body 'else' body # ifElse
// sequencing
| 'let' MUT? var '=' expr ';' expr #
seq
| 'let' MUT? var '=' '{' expr '}' ';' expr #
seq
| 'let' MUT? var '=' expr ';' expr #
let
| 'let' MUT? var '=' '{' expr '}' ';' expr #
let
// sugar for let %_ = expr; expr
| expr ';' expr # seq
| expr ';' expr # let
| ident '=' expr ';' expr # graph
// mutable update
// | ident '=' expr # writeRef
...
...
@@ -84,16 +88,25 @@ expr
| ident # identExpr
| scalar # scalarExpr
// | expr '.'
INT
# project
// | expr '.'
NAT
# project
// | 'debug' # debug
;
func: 'fn' varList ('->' type_)? body ;
defn: 'def' ident varList ('->' type_)? body ;
func: 'fn' '(' argList ')' ('->' type_)? body ;
defn: 'def' ident '(' argList ')' ('->' type_)? body ;
argList
: varList
| attrList
| varList ',' attrList
;
varList:
'(' (var (',' var)*)? ')'
;
varList:
(var (',' var)*)?
;
var: ident (':' type_)? ;
attrList: (attr (',' attr)*)? ;
attr: CNAME '=' expr ;
// TODO(@jmp): for improved type annotations
// returnAnno: (ident ':')? type_ ;
...
...
@@ -110,7 +123,7 @@ type_
// | identType '[' (type_ (',' type_)*)? ']' # callType
| 'fn' '(' (type_ (',' type_)*)? ')' '->' type_ # funcType
| '_' # incompleteType
|
IN
T # intType
|
NA
T # intType
;
shapeSeq
...
...
@@ -123,20 +136,20 @@ shape
: '(' shape ')' # parensShape
// | type_ op=('*'|'/') type_ # binOpType
// | type_ op=('+'|'-') type_ # binOpType
|
IN
T # intShape
|
NA
T # intShape
;
identType: CNAME ;
//
Int8, Int16, Int32, I
nt64
//
UInt8, UInt16, UInt32, UI
nt64
//
Float16, Float32, F
loat64
//
B
ool
//
int8, int16, int32, i
nt64
//
uint8, uint16, uint32, ui
nt64
//
float16, float32, f
loat64
//
b
ool
body: '{' expr '}' ;
scalar
: FLOAT # scalarFloat
|
IN
T # scalarInt
|
NA
T # scalarInt
| BOOL_LIT # scalarBool
;
...
...
@@ -144,4 +157,5 @@ ident
: opIdent
| GLOBAL_VAR
| LOCAL_VAR
| GRAPH_VAR
;
python/tvm/relay/parser.py
View file @
d274e4b3
"""A parser for Relay's text format."""
from
__future__
import
absolute_import
from
..
import
register_func
def
enabled
():
"""Is the parser enabled/Can we import the parser?"""
"""Checks whether the parser is enabled, this allows users to
optionally support building the parser.
We use this check before importing the parser.
"""
try
:
# pylint: disable=unused-variable
from
tvm.relay
import
_parser
...
...
@@ -11,7 +16,8 @@ def enabled():
except
Exception
:
return
False
def
fromtext
(
data
):
@register_func
(
"relay.fromtext"
)
def
fromtext
(
data
,
source_name
=
None
):
"""Parse a Relay program."""
from
tvm.relay
import
_parser
return
_parser
.
fromtext
(
data
)
return
_parser
.
fromtext
(
data
,
source_name
)
src/relay/ir/base.cc
View file @
d274e4b3
...
...
@@ -32,6 +32,11 @@ SourceName SourceName::Get(const std::string& name) {
return
SourceName
(
GetSourceNameNode
(
name
));
}
TVM_REGISTER_API
(
"relay._make.SourceName"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
SourceName
::
Get
(
args
[
0
]);
});
TVM_STATIC_IR_FUNCTOR_REGISTER
(
IRPrinter
,
vtable
)
.
set_dispatch
<
SourceNameNode
>
([](
const
SourceNameNode
*
node
,
tvm
::
IRPrinter
*
p
)
{
p
->
stream
<<
"SourceName("
<<
node
->
name
<<
", "
<<
node
<<
")"
;
...
...
@@ -66,5 +71,14 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
TVM_REGISTER_NODE_TYPE
(
IdNode
);
TVM_REGISTER_API
(
"relay._base.set_span"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
NodeRef
node_ref
=
args
[
0
];
auto
rn
=
node_ref
.
as_derived
<
RelayNode
>
();
CHECK
(
rn
);
Span
sp
=
args
[
1
];
rn
->
span
=
sp
;
});
}
// namespace relay
}
// namespace tvm
tests/python/relay/test_ir_parser.py
View file @
d274e4b3
...
...
@@ -8,11 +8,12 @@ from numpy import isclose
from
typing
import
Union
from
functools
import
wraps
if
enabled
():
from
tvm.relay._parser
import
ParseError
raises_parse_error
=
raises
(
ParseError
)
raises_parse_error
=
raises
(
tvm
.
_ffi
.
base
.
TVMError
)
else
:
raises_parse_error
=
lambda
x
:
x
SEMVER
=
"v0.0.1"
BINARY_OPS
=
{
"*"
:
relay
.
multiply
,
"/"
:
relay
.
divide
,
...
...
@@ -48,6 +49,10 @@ TYPES = {
"float16x4"
,
}
def
parses_as
(
code
,
expr
):
# type: (str, relay.Expr) -> bool
return
alpha_equal
(
relay
.
fromtext
(
SEMVER
+
"
\n
"
+
code
),
expr
)
def
get_scalar
(
x
):
# type: (relay.Constant) -> (Union[float, int, bool])
return
x
.
data
.
asnumpy
()
.
item
()
...
...
@@ -74,80 +79,80 @@ def if_parser_enabled(func):
@if_parser_enabled
def
test_comments
():
assert
alpha_equal
(
relay
.
fromtext
(
"""
assert
parses_as
(
"""
// This is a line comment!
()
"""
)
,
"""
,
UNIT
)
assert
alpha_equal
(
relay
.
fromtext
(
"""
assert
parses_as
(
"""
/* This is a block comment!
This is still a block comment!
*/
()
"""
)
,
"""
,
UNIT
)
@if_parser_enabled
def
test_int_literal
():
assert
isinstance
(
relay
.
fromtext
(
"1"
),
relay
.
Constant
)
assert
isinstance
(
relay
.
fromtext
(
"1"
)
.
data
,
tvm
.
ndarray
.
NDArray
)
assert
isinstance
(
relay
.
fromtext
(
SEMVER
+
"1"
),
relay
.
Constant
)
assert
isinstance
(
relay
.
fromtext
(
SEMVER
+
"1"
)
.
data
,
tvm
.
ndarray
.
NDArray
)
assert
get_scalar
(
relay
.
fromtext
(
"1"
))
==
1
assert
get_scalar
(
relay
.
fromtext
(
"10"
))
==
10
assert
get_scalar
(
relay
.
fromtext
(
"0"
))
==
0
assert
get_scalar
(
relay
.
fromtext
(
"-100"
))
==
-
100
assert
get_scalar
(
relay
.
fromtext
(
"-05"
))
==
-
5
assert
get_scalar
(
relay
.
fromtext
(
SEMVER
+
"1"
))
==
1
assert
get_scalar
(
relay
.
fromtext
(
SEMVER
+
"10"
))
==
10
assert
get_scalar
(
relay
.
fromtext
(
SEMVER
+
"0"
))
==
0
assert
get_scalar
(
relay
.
fromtext
(
SEMVER
+
"-100"
))
==
-
100
assert
get_scalar
(
relay
.
fromtext
(
SEMVER
+
"-05"
))
==
-
5
@if_parser_enabled
def
test_float_literal
():
assert
get_scalar
(
relay
.
fromtext
(
"1.0"
))
==
1.0
assert
isclose
(
get_scalar
(
relay
.
fromtext
(
"1.56667"
)),
1.56667
)
assert
get_scalar
(
relay
.
fromtext
(
"0.0"
))
==
0.0
assert
get_scalar
(
relay
.
fromtext
(
"-10.0"
))
==
-
10.0
assert
get_scalar
(
relay
.
fromtext
(
SEMVER
+
"1.0"
))
==
1.0
assert
isclose
(
get_scalar
(
relay
.
fromtext
(
SEMVER
+
"1.56667"
)),
1.56667
)
assert
get_scalar
(
relay
.
fromtext
(
SEMVER
+
"0.0"
))
==
0.0
assert
get_scalar
(
relay
.
fromtext
(
SEMVER
+
"-10.0"
))
==
-
10.0
# scientific notation
assert
isclose
(
get_scalar
(
relay
.
fromtext
(
"1e-1"
)),
1e-1
)
assert
get_scalar
(
relay
.
fromtext
(
"1e+1"
))
==
1e+1
assert
isclose
(
get_scalar
(
relay
.
fromtext
(
"1E-1"
)),
1E-1
)
assert
get_scalar
(
relay
.
fromtext
(
"1E+1"
))
==
1E+1
assert
isclose
(
get_scalar
(
relay
.
fromtext
(
"1.0e-1"
)),
1.0e-1
)
assert
get_scalar
(
relay
.
fromtext
(
"1.0e+1"
))
==
1.0e+1
assert
isclose
(
get_scalar
(
relay
.
fromtext
(
"1.0E-1"
)),
1.0E-1
)
assert
get_scalar
(
relay
.
fromtext
(
"1.0E+1"
))
==
1.0E+1
assert
isclose
(
get_scalar
(
relay
.
fromtext
(
SEMVER
+
"1e-1"
)),
1e-1
)
assert
get_scalar
(
relay
.
fromtext
(
SEMVER
+
"1e+1"
))
==
1e+1
assert
isclose
(
get_scalar
(
relay
.
fromtext
(
SEMVER
+
"1E-1"
)),
1E-1
)
assert
get_scalar
(
relay
.
fromtext
(
SEMVER
+
"1E+1"
))
==
1E+1
assert
isclose
(
get_scalar
(
relay
.
fromtext
(
SEMVER
+
"1.0e-1"
)),
1.0e-1
)
assert
get_scalar
(
relay
.
fromtext
(
SEMVER
+
"1.0e+1"
))
==
1.0e+1
assert
isclose
(
get_scalar
(
relay
.
fromtext
(
SEMVER
+
"1.0E-1"
)),
1.0E-1
)
assert
get_scalar
(
relay
.
fromtext
(
SEMVER
+
"1.0E+1"
))
==
1.0E+1
@if_parser_enabled
def
test_bool_literal
():
assert
get_scalar
(
relay
.
fromtext
(
"True"
))
==
True
assert
get_scalar
(
relay
.
fromtext
(
"False"
))
==
False
assert
get_scalar
(
relay
.
fromtext
(
SEMVER
+
"True"
))
==
True
assert
get_scalar
(
relay
.
fromtext
(
SEMVER
+
"False"
))
==
False
@if_parser_enabled
def
test_negative
():
assert
isinstance
(
relay
.
fromtext
(
"let
%
x = 1; -
%
x"
)
.
body
,
relay
.
Call
)
assert
get_scalar
(
relay
.
fromtext
(
"--10"
))
==
10
assert
get_scalar
(
relay
.
fromtext
(
"---10"
))
==
-
10
assert
isinstance
(
relay
.
fromtext
(
SEMVER
+
"let
%
x = 1; -
%
x"
)
.
body
,
relay
.
Call
)
assert
get_scalar
(
relay
.
fromtext
(
SEMVER
+
"--10"
))
==
10
assert
get_scalar
(
relay
.
fromtext
(
SEMVER
+
"---10"
))
==
-
10
@if_parser_enabled
def
test_bin_op
():
for
bin_op
in
BINARY_OPS
.
keys
():
assert
alpha_equal
(
relay
.
fromtext
(
"1 {} 1"
.
format
(
bin_op
)
),
assert
parses_as
(
"1 {} 1"
.
format
(
bin_op
),
BINARY_OPS
.
get
(
bin_op
)(
relay
.
const
(
1
),
relay
.
const
(
1
))
)
@if_parser_enabled
def
test_parens
():
assert
alpha_equal
(
relay
.
fromtext
(
"1 * 1 + 1"
),
relay
.
fromtext
(
"(1 * 1) + 1"
))
assert
not
alpha_equal
(
relay
.
fromtext
(
"1 * 1 + 1"
),
relay
.
fromtext
(
"1 * (1 + 1)"
))
assert
alpha_equal
(
relay
.
fromtext
(
SEMVER
+
"1 * 1 + 1"
),
relay
.
fromtext
(
SEMVER
+
"(1 * 1) + 1"
))
assert
not
alpha_equal
(
relay
.
fromtext
(
SEMVER
+
"1 * 1 + 1"
),
relay
.
fromtext
(
SEMVER
+
"1 * (1 + 1)"
))
@if_parser_enabled
def
test_op_assoc
():
assert
alpha_equal
(
relay
.
fromtext
(
"1 * 1 + 1 < 1 == 1"
),
relay
.
fromtext
(
"(((1 * 1) + 1) < 1) == 1"
))
assert
alpha_equal
(
relay
.
fromtext
(
"1 == 1 < 1 + 1 * 1"
),
relay
.
fromtext
(
"1 == (1 < (1 + (1 * 1)))"
))
assert
alpha_equal
(
relay
.
fromtext
(
SEMVER
+
"1 * 1 + 1 < 1 == 1"
),
relay
.
fromtext
(
SEMVER
+
"(((1 * 1) + 1) < 1) == 1"
))
assert
alpha_equal
(
relay
.
fromtext
(
SEMVER
+
"1 == 1 < 1 + 1 * 1"
),
relay
.
fromtext
(
SEMVER
+
"1 == (1 < (1 + (1 * 1)))"
))
@nottest
@if_parser_enabled
...
...
@@ -159,24 +164,24 @@ def test_vars():
# assert temp_var.name == "1"
# var
var
=
relay
.
fromtext
(
"let
%
foo = ();
%
foo"
)
var
=
relay
.
fromtext
(
SEMVER
+
"let
%
foo = ();
%
foo"
)
assert
isinstance
(
var
.
body
,
relay
.
Var
)
assert
var
.
body
.
name_hint
==
"foo"
# global var
global_var
=
relay
.
fromtext
(
"@foo"
)
global_var
=
relay
.
fromtext
(
SEMVER
+
"@foo"
)
assert
isinstance
(
global_var
,
relay
.
GlobalVar
)
assert
global_var
.
name_hint
==
"foo"
# operator id
op
=
relay
.
fromtext
(
"foo"
)
op
=
relay
.
fromtext
(
SEMVER
+
"foo"
)
assert
isinstance
(
op
,
relay
.
Op
)
assert
op
.
name
==
"foo"
@if_parser_enabled
def
test_let
():
assert
alpha_equal
(
relay
.
fromtext
(
"let
%
x = 1; ()"
)
,
assert
parses_as
(
"let
%
x = 1; ()"
,
relay
.
Let
(
X
,
relay
.
const
(
1
),
...
...
@@ -184,18 +189,35 @@ def test_let():
)
)
assert
parses_as
(
"""
let
%
x = 1;
let
%
y = 2;
()
"""
,
relay
.
Let
(
X
,
relay
.
const
(
1
),
relay
.
Let
(
Y
,
relay
.
const
(
2
),
UNIT
)
)
)
@if_parser_enabled
def
test_seq
():
assert
alpha_equal
(
relay
.
fromtext
(
"(); ()"
)
,
assert
parses_as
(
"(); ()"
,
relay
.
Let
(
_
,
UNIT
,
UNIT
)
)
assert
alpha_equal
(
relay
.
fromtext
(
"let
%
_ = { 1 }; ()"
)
,
assert
parses_as
(
"let
%
_ = { 1 }; ()"
,
relay
.
Let
(
X
,
relay
.
const
(
1
),
...
...
@@ -203,31 +225,48 @@ def test_seq():
)
)
@if_parser_enabled
def
test_graph
():
assert
parses_as
(
"
%0
= ();
%1
= 1; (
%0
,
%0
,
%1
)"
,
relay
.
Tuple
([
UNIT
,
UNIT
,
relay
.
const
(
1
)])
)
assert
not
parses_as
(
"
%0
= ();
%1
= 1; (
%0
,
%0
,
%1
)"
,
relay
.
Tuple
([
relay
.
Tuple
([]),
relay
.
Tuple
([]),
relay
.
const
(
1
)])
)
@raises_parse_error
@if_parser_enabled
def
test_graph_wrong_order
():
relay
.
fromtext
(
SEMVER
+
"
%1
= ();
%1
"
)
@raises_parse_error
@if_parser_enabled
def
test_let_global_var
():
relay
.
fromtext
(
"let @x = 1; ()"
)
relay
.
fromtext
(
SEMVER
+
"let @x = 1; ()"
)
@raises_parse_error
@if_parser_enabled
def
test_let_op
():
relay
.
fromtext
(
"let x = 1; ()"
)
relay
.
fromtext
(
SEMVER
+
"let x = 1; ()"
)
@if_parser_enabled
def
test_tuple
():
assert
alpha_equal
(
relay
.
fromtext
(
"()"
)
,
relay
.
Tuple
([]))
assert
parses_as
(
"()"
,
relay
.
Tuple
([]))
assert
alpha_equal
(
relay
.
fromtext
(
"(0,)"
)
,
relay
.
Tuple
([
relay
.
const
(
0
)]))
assert
parses_as
(
"(0,)"
,
relay
.
Tuple
([
relay
.
const
(
0
)]))
assert
alpha_equal
(
relay
.
fromtext
(
"(0, 1)"
)
,
relay
.
Tuple
([
relay
.
const
(
0
),
relay
.
const
(
1
)]))
assert
parses_as
(
"(0, 1)"
,
relay
.
Tuple
([
relay
.
const
(
0
),
relay
.
const
(
1
)]))
assert
alpha_equal
(
relay
.
fromtext
(
"(0, 1, 2)"
)
,
relay
.
Tuple
([
relay
.
const
(
0
),
relay
.
const
(
1
),
relay
.
const
(
2
)]))
assert
parses_as
(
"(0, 1, 2)"
,
relay
.
Tuple
([
relay
.
const
(
0
),
relay
.
const
(
1
),
relay
.
const
(
2
)]))
@if_parser_enabled
def
test_func
():
# 0 args
assert
alpha_equal
(
relay
.
fromtext
(
"fn () { 0 }"
)
,
assert
parses_as
(
"fn () { 0 }"
,
relay
.
Function
(
[],
relay
.
const
(
0
),
...
...
@@ -237,8 +276,8 @@ def test_func():
)
# 1 arg
assert
alpha_equal
(
relay
.
fromtext
(
"fn (
%
x) {
%
x }"
)
,
assert
parses_as
(
"fn (
%
x) {
%
x }"
,
relay
.
Function
(
[
X
],
X
,
...
...
@@ -248,8 +287,8 @@ def test_func():
)
# 2 args
assert
alpha_equal
(
relay
.
fromtext
(
"fn (
%
x,
%
y) {
%
x +
%
y }"
)
,
assert
parses_as
(
"fn (
%
x,
%
y) {
%
x +
%
y }"
,
relay
.
Function
(
[
X
,
Y
],
relay
.
add
(
X
,
Y
),
...
...
@@ -259,8 +298,8 @@ def test_func():
)
# annotations
assert
alpha_equal
(
relay
.
fromtext
(
"fn (
%
x: int32) -> int32 {
%
x }"
)
,
assert
parses_as
(
"fn (
%
x: int32) -> int32 {
%
x }"
,
relay
.
Function
(
[
X_ANNO
],
X_ANNO
,
...
...
@@ -269,11 +308,17 @@ def test_func():
)
)
# attributes
assert
parses_as
(
"fn (n=5) { () }"
,
relay
.
Function
([],
UNIT
,
None
,
None
,
tvm
.
make
.
node
(
"DictAttrs"
,
n
=
relay
.
const
(
5
)))
)
# TODO(@jmp): Crashes if %x isn't annnotated.
# @nottest
@if_parser_enabled
def
test_defn
():
id_defn
=
relay
.
fromtext
(
SEMVER
+
"""
def @id(
%
x: int32) -> int32 {
%
x
...
...
@@ -284,6 +329,7 @@ def test_defn():
@if_parser_enabled
def
test_recursive_call
():
id_defn
=
relay
.
fromtext
(
SEMVER
+
"""
def @id(
%
x: int32) -> int32 {
@id(
%
x)
...
...
@@ -293,16 +339,14 @@ def test_recursive_call():
@if_parser_enabled
def
test_ifelse
():
assert
alpha_equal
(
relay
.
fromtext
(
assert
parses_as
(
"""
if (True) {
0
} else {
1
}
"""
),
"""
,
relay
.
If
(
relay
.
const
(
True
),
relay
.
const
(
0
),
...
...
@@ -314,6 +358,7 @@ def test_ifelse():
@if_parser_enabled
def
test_ifelse_scope
():
relay
.
fromtext
(
SEMVER
+
"""
if (True) {
let
%
x = ();
...
...
@@ -328,13 +373,11 @@ def test_ifelse_scope():
def
test_call
():
# select right function to call: simple ident case
id_func
=
relay
.
Var
(
"id"
)
assert
alpha_equal
(
relay
.
fromtext
(
assert
parses_as
(
"""
let
%
id = fn (
%
x) {
%
x };
10 *
%
id(10)
"""
),
"""
,
relay
.
Let
(
id_func
,
relay
.
Function
([
X
],
X
,
None
,
[]),
...
...
@@ -344,13 +387,11 @@ def test_call():
# 0 args
constant
=
relay
.
Var
(
"constant"
)
assert
alpha_equal
(
relay
.
fromtext
(
assert
parses_as
(
"""
let
%
constant = fn () { 0 };
%
constant()
"""
),
"""
,
relay
.
Let
(
constant
,
relay
.
Function
([],
relay
.
const
(
0
),
None
,
[]),
...
...
@@ -360,13 +401,11 @@ def test_call():
# 1 arg
id_var
=
relay
.
Var
(
"id"
)
assert
alpha_equal
(
relay
.
fromtext
(
assert
parses_as
(
"""
let
%
id = fn (
%
x) {
%
x };
%
id(1)
"""
),
"""
,
relay
.
Let
(
id_var
,
relay
.
Function
([
X
],
X
,
None
,
[]),
...
...
@@ -376,13 +415,11 @@ def test_call():
# 2 args
multiply
=
relay
.
Var
(
"multiply"
)
assert
alpha_equal
(
relay
.
fromtext
(
assert
parses_as
(
"""
let
%
multiply = fn (
%
x,
%
y) {
%
x *
%
y };
%
multiply(0, 0)
"""
),
"""
,
relay
.
Let
(
multiply
,
relay
.
Function
(
...
...
@@ -396,12 +433,10 @@ def test_call():
)
# anonymous function
assert
alpha_equal
(
relay
.
fromtext
(
assert
parses_as
(
"""
(fn (
%
x) {
%
x })(0)
"""
),
"""
,
relay
.
Call
(
relay
.
Function
(
[
X
],
...
...
@@ -415,45 +450,44 @@ def test_call():
)
)
# TODO(@jmp): re-enable after sequence parsing improvements
# curried function
curried_mult
=
relay
.
Var
(
"curried_mult"
)
alpha_equal
(
relay
.
fromtext
(
"""
let
%
curried_mult =
fn (
%
x) {
fn (
%
y) {
%
x *
%
y
}
};
%
curried_mult(0);
%
curried_mult(0)(0)
"""
),
relay
.
Let
(
curried_mult
,
relay
.
Function
(
[
X
],
relay
.
Function
(
[
Y
],
relay
.
multiply
(
X
,
Y
),
None
,
[]
),
None
,
[]
),
relay
.
Let
(
_
,
relay
.
Call
(
curried_mult
,
[
relay
.
const
(
0
)],
None
,
None
),
relay
.
Call
(
relay
.
Call
(
curried_mult
,
[
relay
.
const
(
0
)],
None
,
None
),
[
relay
.
const
(
0
)],
None
,
None
)
)
)
)
# curried_mult = relay.Var("curried_mult")
# assert parses_as(
# """
# let %curried_mult =
# fn (%x) {
# fn (%y) {
# %x * %y
# }
# };
# %curried_mult(0);
# %curried_mult(0)(0)
# """,
# relay.Let(
# curried_mult,
# relay.Function(
# [X],
# relay.Function(
# [Y],
# relay.multiply(X, Y),
# None,
# []
# ),
# None,
# []
# ),
# relay.Let(
# _,
# relay.Call(curried_mult, [relay.const(0)], None, None),
# relay.Call(relay.Call(curried_mult, [relay.const(0)], None, None), [relay.const(0)], None, None)
# )
# )
# )
# op
a
lpha_equal
(
relay
.
fromtext
(
"abs(1)"
)
,
a
ssert
parses_as
(
"abs(1)"
,
relay
.
Call
(
relay
.
op
.
get
(
"abs"
),
[
relay
.
const
(
1
)],
None
,
None
)
)
...
...
@@ -461,8 +495,8 @@ def test_call():
@if_parser_enabled
def
test_incomplete_type
():
assert
alpha_equal
(
relay
.
fromtext
(
"let
%
_ : _ = (); ()"
)
,
assert
parses_as
(
"let
%
_ : _ = (); ()"
,
relay
.
Let
(
_
,
UNIT
,
...
...
@@ -473,7 +507,7 @@ def test_incomplete_type():
@if_parser_enabled
def
test_builtin_types
():
for
builtin_type
in
TYPES
:
relay
.
fromtext
(
"let
%
_ : {} = (); ()"
.
format
(
builtin_type
))
relay
.
fromtext
(
SEMVER
+
"let
%
_ : {} = (); ()"
.
format
(
builtin_type
))
@nottest
@if_parser_enabled
...
...
@@ -482,8 +516,8 @@ def test_call_type():
@if_parser_enabled
def
test_tensor_type
():
assert
alpha_equal
(
relay
.
fromtext
(
"let
%
_ : Tensor[(), float32] = (); ()"
)
,
assert
parses_as
(
"let
%
_ : Tensor[(), float32] = (); ()"
,
relay
.
Let
(
relay
.
Var
(
"_"
,
relay
.
TensorType
((),
"float32"
)),
UNIT
,
...
...
@@ -491,8 +525,8 @@ def test_tensor_type():
)
)
assert
alpha_equal
(
relay
.
fromtext
(
"let
%
_ : Tensor[(1,), float32] = (); ()"
)
,
assert
parses_as
(
"let
%
_ : Tensor[(1,), float32] = (); ()"
,
relay
.
Let
(
relay
.
Var
(
"_"
,
relay
.
TensorType
((
1
,),
"float32"
)),
UNIT
,
...
...
@@ -500,8 +534,8 @@ def test_tensor_type():
)
)
assert
alpha_equal
(
relay
.
fromtext
(
"let
%
_ : Tensor[(1, 1), float32] = (); ()"
)
,
assert
parses_as
(
"let
%
_ : Tensor[(1, 1), float32] = (); ()"
,
relay
.
Let
(
relay
.
Var
(
"_"
,
relay
.
TensorType
((
1
,
1
),
"float32"
)),
UNIT
,
...
...
@@ -511,12 +545,10 @@ def test_tensor_type():
@if_parser_enabled
def
test_function_type
():
assert
alpha_equal
(
relay
.
fromtext
(
assert
parses_as
(
"""
let
%
_: fn () -> int32 = fn () -> int32 { 0 }; ()
"""
),
"""
,
relay
.
Let
(
relay
.
Var
(
"_"
,
relay
.
FuncType
([],
int32
,
[],
[])),
relay
.
Function
([],
relay
.
const
(
0
),
int32
,
[]),
...
...
@@ -524,12 +556,10 @@ def test_function_type():
)
)
assert
alpha_equal
(
relay
.
fromtext
(
assert
parses_as
(
"""
let
%
_: fn (int32) -> int32 = fn (
%
x: int32) -> int32 { 0 }; ()
"""
),
"""
,
relay
.
Let
(
relay
.
Var
(
"_"
,
relay
.
FuncType
([
int32
],
int32
,
[],
[])),
relay
.
Function
([
relay
.
Var
(
"x"
,
int32
)],
relay
.
const
(
0
),
int32
,
[]),
...
...
@@ -537,12 +567,10 @@ def test_function_type():
)
)
assert
alpha_equal
(
relay
.
fromtext
(
assert
parses_as
(
"""
let
%
_: fn (int32, int32) -> int32 = fn (
%
x: int32,
%
y: int32) -> int32 { 0 }; ()
"""
),
"""
,
relay
.
Let
(
relay
.
Var
(
"_"
,
relay
.
FuncType
([
int32
,
int32
],
int32
,
[],
[])),
relay
.
Function
([
relay
.
Var
(
"x"
,
int32
),
relay
.
Var
(
"y"
,
int32
)],
relay
.
const
(
0
),
int32
,
[]),
...
...
@@ -552,11 +580,10 @@ def test_function_type():
@if_parser_enabled
def
test_tuple_type
():
assert
alpha_equal
(
relay
.
fromtext
(
assert
parses_as
(
"""
let
%
_: () = (); ()
"""
)
,
"""
,
relay
.
Let
(
relay
.
Var
(
"_"
,
relay
.
TupleType
([])),
UNIT
,
...
...
@@ -564,11 +591,10 @@ def test_tuple_type():
)
)
assert
alpha_equal
(
relay
.
fromtext
(
assert
parses_as
(
"""
let
%
_: (int32,) = (0,); ()
"""
)
,
"""
,
relay
.
Let
(
relay
.
Var
(
"_"
,
relay
.
TupleType
([
int32
])),
relay
.
Tuple
([
relay
.
const
(
0
)]),
...
...
@@ -576,11 +602,10 @@ def test_tuple_type():
)
)
assert
alpha_equal
(
relay
.
fromtext
(
assert
parses_as
(
"""
let
%
_: (int32, int32) = (0, 1); ()
"""
)
,
"""
,
relay
.
Let
(
relay
.
Var
(
"_"
,
relay
.
TupleType
([
int32
,
int32
])),
relay
.
Tuple
([
relay
.
const
(
0
),
relay
.
const
(
1
)]),
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment