Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
d274e4b3
Commit
d274e4b3
authored
Jan 17, 2019
by
Jared Roesch
Committed by
Tianqi Chen
Jan 17, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[Relay][Parser] Improve Relay parser and pretty printing, including CMAKE (#2377)
parent
d0f83664
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
400 additions
and
219 deletions
+400
-219
cmake/modules/ANTLR.cmake
+19
-5
include/tvm/relay/base.h
+3
-1
python/tvm/relay/_base.py
+5
-0
python/tvm/relay/_parser.py
+113
-22
python/tvm/relay/base.py
+10
-0
python/tvm/relay/grammar/Relay.g4
+35
-21
python/tvm/relay/parser.py
+9
-3
src/relay/ir/base.cc
+14
-0
tests/python/relay/test_ir_parser.py
+192
-167
No files found.
cmake/modules/ANTLR.cmake
View file @
d274e4b3
if
(
USE_ANTLR
)
if
(
USE_ANTLR
)
if
(
EXISTS /usr/local/lib/antlr-4.7.1-complete.jar
)
file
(
GLOB_RECURSE ANTLR4
set
(
ANTLR4
"/usr/local/lib/antlr-4.7.1-complete.jar"
)
/usr/local/lib/antlr-*-complete.jar
/usr/local/Cellar/*antlr-*-complete.jar
)
if
(
DEFINED ANTLR4
)
# Get the first element of the list of antlr jars.
# Sort and reverse the list so the item selected is the highest
# version in lib or else in Cellar if no lib installation exists.
list
(
SORT ANTLR4
)
list
(
REVERSE ANTLR4
)
list
(
GET ANTLR4 0 ANTLR4
)
set
(
RELAY_PARSER_DIR
set
(
RELAY_PARSER_DIR
${
CMAKE_CURRENT_SOURCE_DIR
}
/python/tvm/relay/grammar
)
${
CMAKE_CURRENT_SOURCE_DIR
}
/python/tvm/relay/grammar
)
...
@@ -14,15 +22,21 @@ if(USE_ANTLR)
...
@@ -14,15 +22,21 @@ if(USE_ANTLR)
${
RELAY_PARSER_DIR
}
/py3/RelayParser.py
${
RELAY_PARSER_DIR
}
/py3/RelayParser.py
${
RELAY_PARSER_DIR
}
/py3/RelayLexer.py
)
${
RELAY_PARSER_DIR
}
/py3/RelayLexer.py
)
set
(
JAVA_HOME $ENV{JAVA_HOME}
)
if
(
NOT DEFINED JAVA_HOME
)
# Hack to get system to search for Java itself.
set
(
JAVA_HOME
"/usr"
)
endif
()
# Generate ANTLR grammar for parsing.
# Generate ANTLR grammar for parsing.
add_custom_command
(
OUTPUT
${
RELAY_PARSER
}
add_custom_command
(
OUTPUT
${
RELAY_PARSER
}
COMMAND $
ENV
{JAVA_HOME}/bin/java -jar
${
ANTLR4
}
-visitor -no-listener -Dlanguage=Python2
${
RELAY_PARSER_DIR
}
/Relay.g4 -o
${
RELAY_PARSER_DIR
}
/py2
COMMAND
${
JAVA_HOME
}
/bin/java -jar
${
ANTLR4
}
-visitor -no-listener -Dlanguage=Python2
${
RELAY_PARSER_DIR
}
/Relay.g4 -o
${
RELAY_PARSER_DIR
}
/py2
COMMAND $
ENV
{JAVA_HOME}/bin/java -jar
${
ANTLR4
}
-visitor -no-listener -Dlanguage=Python3
${
RELAY_PARSER_DIR
}
/Relay.g4 -o
${
RELAY_PARSER_DIR
}
/py3
COMMAND
${
JAVA_HOME
}
/bin/java -jar
${
ANTLR4
}
-visitor -no-listener -Dlanguage=Python3
${
RELAY_PARSER_DIR
}
/Relay.g4 -o
${
RELAY_PARSER_DIR
}
/py3
DEPENDS
${
RELAY_PARSER_DIR
}
/Relay.g4
DEPENDS
${
RELAY_PARSER_DIR
}
/Relay.g4
WORKING_DIRECTORY
${
RELAY_PARSER_DIR
}
)
WORKING_DIRECTORY
${
RELAY_PARSER_DIR
}
)
add_custom_target
(
relay_parser ALL DEPENDS
${
RELAY_PARSER
}
)
add_custom_target
(
relay_parser ALL DEPENDS
${
RELAY_PARSER
}
)
else
()
else
()
message
(
FATAL_ERROR
"Can't find ANTLR4
!"
)
message
(
FATAL_ERROR
"Can't find ANTLR4
: ANTLR4="
${
ANTLR4
}
)
endif
()
endif
()
endif
(
USE_ANTLR
)
endif
(
USE_ANTLR
)
include/tvm/relay/base.h
View file @
d274e4b3
...
@@ -108,7 +108,9 @@ class SourceName : public NodeRef {
...
@@ -108,7 +108,9 @@ class SourceName : public NodeRef {
* \brief access the internal node container
* \brief access the internal node container
* \return the pointer to the internal node container
* \return the pointer to the internal node container
*/
*/
inline
const
SourceNameNode
*
operator
->
()
const
;
inline
const
SourceNameNode
*
operator
->
()
const
{
return
static_cast
<
SourceNameNode
*>
(
this
->
node_
.
get
());
}
/*!
/*!
* \brief Get an SourceName for a given operator name.
* \brief Get an SourceName for a given operator name.
...
...
python/tvm/relay/_base.py
0 → 100644
View file @
d274e4b3
# pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable
"""The interface of expr function exposed from C++."""
from
tvm._ffi.function
import
_init_api
_init_api
(
"relay._base"
,
__name__
)
python/tvm/relay/_parser.py
View file @
d274e4b3
...
@@ -6,13 +6,17 @@ from __future__ import absolute_import
...
@@ -6,13 +6,17 @@ from __future__ import absolute_import
import
sys
import
sys
from
collections
import
deque
from
collections
import
deque
from
typing
import
TypeVar
,
Deque
,
Tuple
,
Optional
,
Union
,
NamedTuple
,
List
,
Callable
,
Any
from
typing
import
TypeVar
,
Deque
,
Tuple
,
Optional
,
Union
,
NamedTuple
,
List
,
Callable
,
Any
,
Dict
import
tvm
from
.
import
module
from
.
import
module
from
.base
import
Span
,
SourceName
from
.
import
expr
from
.
import
expr
from
.
import
ty
from
.
import
ty
from
.
import
op
from
.
import
op
class
ParseError
(
Exception
):
class
ParseError
(
Exception
):
"""Exception type for parse errors."""
"""Exception type for parse errors."""
...
@@ -76,22 +80,46 @@ def lookup(scopes, name):
...
@@ -76,22 +80,46 @@ def lookup(scopes, name):
return
val
return
val
return
None
return
None
def
spanify
(
f
):
"""A decorator which attaches span information
to the value returned by calling `f`.
Intended for use with the below AST visiting
methods. The idea is that after we do the work
of constructing the AST we attach Span information.
"""
def
_wrapper
(
*
args
,
**
kwargs
):
# Assumes 0th arg is self and gets source_name from object.
sn
=
args
[
0
]
.
source_name
# Assumes 1st arg is an ANTLR parser context.
ctx
=
args
[
1
]
ast
=
f
(
*
args
,
**
kwargs
)
line
,
col
=
ctx
.
getSourceInterval
()
sp
=
Span
(
sn
,
line
,
col
)
ast
.
set_span
(
sp
)
return
ast
return
_wrapper
# TODO(@jmp): Use https://stackoverflow.com/q/13889941
# TODO(@jmp): Use https://stackoverflow.com/q/13889941
# to figure out how to get ANTLR4 to be more unhappy about syntax errors
# to figure out how to get ANTLR4 to be more unhappy about syntax errors
class
ParseTreeToRelayIR
(
RelayVisitor
):
class
ParseTreeToRelayIR
(
RelayVisitor
):
"""Parse Relay text format into Relay IR."""
"""Parse Relay text format into Relay IR."""
def
__init__
(
self
):
def
__init__
(
self
,
source_name
):
# type: () -> None
# type: (str) -> None
self
.
source_name
=
source_name
self
.
module
=
module
.
Module
({})
# type: module.Module
self
.
module
=
module
.
Module
({})
# type: module.Module
# Adding an empty scope allows naked lets without pain.
# Adding an empty scope allows naked lets without pain.
self
.
var_scopes
=
deque
([
deque
()])
# type: Scopes[expr.Var]
self
.
var_scopes
=
deque
([
deque
()])
# type: Scopes[expr.Var]
self
.
global_var_scope
=
deque
()
# type: Scope[expr.GlobalVar]
self
.
global_var_scope
=
deque
()
# type: Scope[expr.GlobalVar]
self
.
type_param_scopes
=
deque
([
deque
()])
# type: Scopes[ty.TypeVar]
self
.
type_param_scopes
=
deque
([
deque
()])
# type: Scopes[ty.TypeVar]
self
.
graph_expr
=
[]
# type: List[expr.Expr]
super
(
ParseTreeToRelayIR
,
self
)
.
__init__
()
super
(
ParseTreeToRelayIR
,
self
)
.
__init__
()
def
enter_var_scope
(
self
):
def
enter_var_scope
(
self
):
# type: () -> None
# type: () -> None
"""Enter a new Var scope so it can be popped off later."""
"""Enter a new Var scope so it can be popped off later."""
...
@@ -146,20 +174,25 @@ class ParseTreeToRelayIR(RelayVisitor):
...
@@ -146,20 +174,25 @@ class ParseTreeToRelayIR(RelayVisitor):
node_type
=
node
.
getSymbol
()
.
type
node_type
=
node
.
getSymbol
()
.
type
node_text
=
node
.
getText
()
node_text
=
node
.
getText
()
name
=
node_text
[
1
:]
# variables
# variables
if
node_type
==
RelayLexer
.
GLOBAL_VAR
:
if
node_type
==
RelayLexer
.
GLOBAL_VAR
:
return
lookup
(
[
self
.
global_var_scope
]
,
node_text
[
1
:])
return
lookup
(
deque
([
self
.
global_var_scope
])
,
node_text
[
1
:])
elif
node_type
==
RelayLexer
.
LOCAL_VAR
:
elif
node_type
==
RelayLexer
.
LOCAL_VAR
:
name
=
node_text
[
1
:]
# Remove the leading '%' and lookup the name.
var
=
lookup
(
self
.
var_scopes
,
name
)
var
=
lookup
(
self
.
var_scopes
,
name
)
if
var
is
None
:
if
var
is
None
:
raise
ParseError
(
"Couldn't resolve `{}`."
.
format
(
name
))
raise
ParseError
(
"Couldn't resolve `{}`."
.
format
(
name
))
return
var
return
var
elif
node_type
==
RelayLexer
.
GRAPH_VAR
:
try
:
return
self
.
graph_expr
[
int
(
name
)]
except
IndexError
:
raise
ParseError
(
"Couldn't resolve `{}`"
.
format
(
name
))
# data types
# data types
elif
node_type
==
RelayLexer
.
IN
T
:
elif
node_type
==
RelayLexer
.
NA
T
:
return
int
(
node_text
)
return
int
(
node_text
)
elif
node_type
==
RelayLexer
.
FLOAT
:
elif
node_type
==
RelayLexer
.
FLOAT
:
return
float
(
node_text
)
return
float
(
node_text
)
...
@@ -190,7 +223,7 @@ class ParseTreeToRelayIR(RelayVisitor):
...
@@ -190,7 +223,7 @@ class ParseTreeToRelayIR(RelayVisitor):
return
self
.
visit
(
ctx
)
return
self
.
visit
(
ctx
)
def
visitProg
(
self
,
ctx
):
def
visitProg
(
self
,
ctx
):
# type: (RelayParser.ProgContext) -> Union[expr.Expr,
env.Environment
]
# type: (RelayParser.ProgContext) -> Union[expr.Expr,
module.Module
]
if
ctx
.
defn
():
if
ctx
.
defn
():
self
.
visit_list
(
ctx
.
defn
())
self
.
visit_list
(
ctx
.
defn
())
return
self
.
module
return
self
.
module
...
@@ -219,7 +252,7 @@ class ParseTreeToRelayIR(RelayVisitor):
...
@@ -219,7 +252,7 @@ class ParseTreeToRelayIR(RelayVisitor):
def
visitScalarInt
(
self
,
ctx
):
def
visitScalarInt
(
self
,
ctx
):
# type: (RelayParser.ScalarIntContext) -> expr.Constant
# type: (RelayParser.ScalarIntContext) -> expr.Constant
return
expr
.
const
(
self
.
visit
(
ctx
.
IN
T
()))
return
expr
.
const
(
self
.
visit
(
ctx
.
NA
T
()))
def
visitScalarBool
(
self
,
ctx
):
def
visitScalarBool
(
self
,
ctx
):
# type: (RelayParser.ScalarBoolContext) -> expr.Constant
# type: (RelayParser.ScalarBoolContext) -> expr.Constant
...
@@ -240,7 +273,7 @@ class ParseTreeToRelayIR(RelayVisitor):
...
@@ -240,7 +273,7 @@ class ParseTreeToRelayIR(RelayVisitor):
return
expr
.
Tuple
(
tup
)
return
expr
.
Tuple
(
tup
)
# Currently doesn't support mutable sequencing.
# Currently doesn't support mutable sequencing.
def
visit
Seq
(
self
,
ctx
):
def
visit
Let
(
self
,
ctx
):
# type: (RelayParser.SeqContext) -> expr.Let
# type: (RelayParser.SeqContext) -> expr.Let
"""Desugar various sequence constructs to Relay Let nodes."""
"""Desugar various sequence constructs to Relay Let nodes."""
if
ctx
.
MUT
()
is
not
None
:
if
ctx
.
MUT
()
is
not
None
:
...
@@ -253,7 +286,7 @@ class ParseTreeToRelayIR(RelayVisitor):
...
@@ -253,7 +286,7 @@ class ParseTreeToRelayIR(RelayVisitor):
else
:
else
:
local_var
=
ctx
.
var
()
.
ident
()
.
LOCAL_VAR
()
local_var
=
ctx
.
var
()
.
ident
()
.
LOCAL_VAR
()
if
local_var
is
None
:
if
local_var
is
None
:
raise
ParseError
(
'Only local ids may be used in `let`s.'
)
raise
ParseError
(
"Only local ids may be used in `let`s."
)
ident
=
local_var
.
getText
()[
1
:]
ident
=
local_var
.
getText
()[
1
:]
type_
=
self
.
getType_
(
ctx
.
var
()
.
type_
())
type_
=
self
.
getType_
(
ctx
.
var
()
.
type_
())
...
@@ -278,12 +311,14 @@ class ParseTreeToRelayIR(RelayVisitor):
...
@@ -278,12 +311,14 @@ class ParseTreeToRelayIR(RelayVisitor):
return
relay_op
(
arg0
,
arg1
)
return
relay_op
(
arg0
,
arg1
)
@spanify
def
visitVar
(
self
,
ctx
):
def
visitVar
(
self
,
ctx
):
# type: (RelayParser.VarContext) -> expr.Var
# type: (RelayParser.VarContext) -> expr.Var
"""Visit a single variable."""
ident
=
ctx
.
ident
()
.
LOCAL_VAR
()
ident
=
ctx
.
ident
()
.
LOCAL_VAR
()
if
ident
is
None
:
if
ident
is
None
:
raise
ParseError
(
'Only local ids may be used in params.'
)
raise
ParseError
(
"Only local ids may be used in vars."
)
type_
=
self
.
getType_
(
ctx
.
type_
())
type_
=
self
.
getType_
(
ctx
.
type_
())
...
@@ -293,15 +328,33 @@ class ParseTreeToRelayIR(RelayVisitor):
...
@@ -293,15 +328,33 @@ class ParseTreeToRelayIR(RelayVisitor):
# type: (RelayParser.VarListContext) -> List[expr.Var]
# type: (RelayParser.VarListContext) -> List[expr.Var]
return
self
.
visit_list
(
ctx
.
var
())
return
self
.
visit_list
(
ctx
.
var
())
# TODO: support a larger class of values than just Relay exprs
def
visitAttr
(
self
,
ctx
):
# type: (RelayParser.AttrContext) -> Tuple[str, expr.Expr]
return
(
ctx
.
CNAME
()
.
getText
(),
self
.
visit
(
ctx
.
expr
()))
def
visitAttrList
(
self
,
ctx
):
# type: (RelayParser.AttrListContext) -> Dict[str, expr.Expr]
return
dict
(
self
.
visit_list
(
ctx
.
attr
()))
def
visitArgList
(
self
,
ctx
# type: RelayParser.ArgListContext
):
# type: (...) -> Tuple[Optional[List[expr.Var]], Optional[Dict[str, expr.Expr]]]
var_list
=
self
.
visit
(
ctx
.
varList
())
if
ctx
.
varList
()
else
None
attr_list
=
self
.
visit
(
ctx
.
attrList
())
if
ctx
.
attrList
()
else
None
return
(
var_list
,
attr_list
)
def
mk_func
(
self
,
ctx
):
def
mk_func
(
self
,
ctx
):
# type: (Union[RelayParser.FuncContext, RelayParser.DefnContext]) -> Function
# type: (Union[RelayParser.FuncContext, RelayParser.DefnContext]) ->
expr.
Function
"""Construct a function from either a Func or Defn."""
"""Construct a function from either a Func or Defn."""
# Enter var scope early to put params in scope.
# Enter var scope early to put params in scope.
self
.
enter_var_scope
()
self
.
enter_var_scope
()
# Capture type params in params.
# Capture type params in params.
self
.
enter_type_param_scope
()
self
.
enter_type_param_scope
()
var_list
=
self
.
visit
(
ctx
.
var
List
())
var_list
,
attr_list
=
self
.
visit
(
ctx
.
arg
List
())
ret_type
=
self
.
getType_
(
ctx
.
type_
())
ret_type
=
self
.
getType_
(
ctx
.
type_
())
type_params
=
list
(
self
.
exit_type_param_scope
())
type_params
=
list
(
self
.
exit_type_param_scope
())
...
@@ -311,22 +364,28 @@ class ParseTreeToRelayIR(RelayVisitor):
...
@@ -311,22 +364,28 @@ class ParseTreeToRelayIR(RelayVisitor):
body
=
self
.
visit
(
ctx
.
body
())
body
=
self
.
visit
(
ctx
.
body
())
self
.
exit_var_scope
()
self
.
exit_var_scope
()
return
expr
.
Function
(
var_list
,
body
,
ret_type
,
type_params
)
# type: ignore
attrs
=
tvm
.
make
.
node
(
"DictAttrs"
,
**
attr_list
)
if
attr_list
is
not
None
else
None
return
expr
.
Function
(
var_list
,
body
,
ret_type
,
type_params
,
attrs
)
@spanify
def
visitFunc
(
self
,
ctx
):
def
visitFunc
(
self
,
ctx
):
# type: (RelayParser.FuncContext) -> expr.Function
# type: (RelayParser.FuncContext) -> expr.Function
return
self
.
mk_func
(
ctx
)
return
self
.
mk_func
(
ctx
)
# TODO: how to set spans for definitions?
# @spanify
def
visitDefn
(
self
,
ctx
):
def
visitDefn
(
self
,
ctx
):
# type: (RelayParser.DefnContext) -> None
# type: (RelayParser.DefnContext) -> None
ident
=
ctx
.
ident
()
.
GLOBAL_VAR
()
ident
=
ctx
.
ident
()
.
GLOBAL_VAR
()
if
ident
is
None
:
if
ident
is
None
:
raise
ParseError
(
'Only global ids may be used in `def`s.'
)
raise
ParseError
(
"Only global ids may be used in `def`s."
)
ident_name
=
ident
.
getText
()[
1
:]
ident_name
=
ident
.
getText
()[
1
:]
ident
=
self
.
mk_global_var
(
ident_name
)
ident
=
self
.
mk_global_var
(
ident_name
)
self
.
module
[
ident
]
=
self
.
mk_func
(
ctx
)
self
.
module
[
ident
]
=
self
.
mk_func
(
ctx
)
@spanify
def
visitCall
(
self
,
ctx
):
def
visitCall
(
self
,
ctx
):
# type: (RelayParser.CallContext) -> expr.Call
# type: (RelayParser.CallContext) -> expr.Call
visited_exprs
=
self
.
visit_list
(
ctx
.
expr
())
visited_exprs
=
self
.
visit_list
(
ctx
.
expr
())
...
@@ -336,6 +395,7 @@ class ParseTreeToRelayIR(RelayVisitor):
...
@@ -336,6 +395,7 @@ class ParseTreeToRelayIR(RelayVisitor):
return
expr
.
Call
(
func
,
args
,
None
,
None
)
return
expr
.
Call
(
func
,
args
,
None
,
None
)
@spanify
def
visitIfElse
(
self
,
ctx
):
def
visitIfElse
(
self
,
ctx
):
# type: (RelayParser.IfElseContext) -> expr.If
# type: (RelayParser.IfElseContext) -> expr.If
"""Construct a Relay If node. Creates a new scope for each branch."""
"""Construct a Relay If node. Creates a new scope for each branch."""
...
@@ -351,6 +411,27 @@ class ParseTreeToRelayIR(RelayVisitor):
...
@@ -351,6 +411,27 @@ class ParseTreeToRelayIR(RelayVisitor):
return
expr
.
If
(
cond
,
true_branch
,
false_branch
)
return
expr
.
If
(
cond
,
true_branch
,
false_branch
)
@spanify
def
visitGraph
(
self
,
ctx
):
# type: (RelayParser.GraphContext) -> expr.Expr
"""Visit a graph variable assignment."""
if
ctx
.
ident
()
.
GRAPH_VAR
()
is
None
:
raise
ParseError
(
"Expected a graph var, but got `{}`"
.
format
(
ctx
.
ident
()
.
getText
()))
graph_nid
=
int
(
ctx
.
ident
()
.
GRAPH_VAR
()
.
getText
()[
1
:])
self
.
enter_var_scope
()
value
=
self
.
visit
(
ctx
.
expr
(
0
))
self
.
exit_var_scope
()
if
graph_nid
!=
len
(
self
.
graph_expr
):
raise
ParseError
(
"Expected new graph variable to be `
%
{}`,"
.
format
(
len
(
self
.
graph_expr
))
+
\
"but got `
%
{}`"
.
format
(
graph_nid
))
self
.
graph_expr
.
append
(
value
)
kont
=
self
.
visit
(
ctx
.
expr
(
1
))
return
kont
# Types
# Types
# pylint: disable=unused-argument
# pylint: disable=unused-argument
...
@@ -428,8 +509,18 @@ def make_parser(data):
...
@@ -428,8 +509,18 @@ def make_parser(data):
token_stream
=
CommonTokenStream
(
lexer
)
token_stream
=
CommonTokenStream
(
lexer
)
return
RelayParser
(
token_stream
)
return
RelayParser
(
token_stream
)
def
fromtext
(
data
):
__source_name_counter__
=
0
# type: (str) -> Union[expr.Expr, env.Environment]
def
fromtext
(
data
,
source_name
=
None
):
# type: (str, str) -> Union[expr.Expr, module.Module]
"""Parse a Relay program."""
"""Parse a Relay program."""
global
__source_name_counter__
if
source_name
is
None
:
source_name
=
"source_file{0}"
.
format
(
__source_name_counter__
)
if
isinstance
(
source_name
,
str
):
source_name
=
SourceName
(
source_name
)
tree
=
make_parser
(
data
)
.
prog
()
tree
=
make_parser
(
data
)
.
prog
()
return
ParseTreeToRelayIR
()
.
visit
(
tree
)
return
ParseTreeToRelayIR
(
source_name
)
.
visit
(
tree
)
python/tvm/relay/base.py
View file @
d274e4b3
...
@@ -4,6 +4,7 @@ from __future__ import absolute_import as _abs
...
@@ -4,6 +4,7 @@ from __future__ import absolute_import as _abs
from
.._ffi.node
import
NodeBase
,
register_node
as
_register_tvm_node
from
.._ffi.node
import
NodeBase
,
register_node
as
_register_tvm_node
from
.
import
_make
from
.
import
_make
from
.
import
_expr
from
.
import
_expr
from
.
import
_base
NodeBase
=
NodeBase
NodeBase
=
NodeBase
...
@@ -63,6 +64,9 @@ class RelayNode(NodeBase):
...
@@ -63,6 +64,9 @@ class RelayNode(NodeBase):
"""
"""
return
_expr
.
RelayPrint
(
self
,
show_meta_data
,
annotate
)
return
_expr
.
RelayPrint
(
self
,
show_meta_data
,
annotate
)
def
set_span
(
self
,
span
):
_base
.
set_span
(
self
,
span
)
@register_relay_node
@register_relay_node
class
Span
(
RelayNode
):
class
Span
(
RelayNode
):
...
@@ -71,6 +75,12 @@ class Span(RelayNode):
...
@@ -71,6 +75,12 @@ class Span(RelayNode):
def
__init__
(
self
,
source
,
lineno
,
col_offset
):
def
__init__
(
self
,
source
,
lineno
,
col_offset
):
self
.
__init_handle_by_constructor__
(
_make
.
Span
,
source
,
lineno
,
col_offset
)
self
.
__init_handle_by_constructor__
(
_make
.
Span
,
source
,
lineno
,
col_offset
)
@register_relay_node
class
SourceName
(
RelayNode
):
"""A identifier for a source location"""
def
__init__
(
self
,
name
):
self
.
__init_handle_by_constructor__
(
_make
.
SourceName
,
name
)
@register_relay_node
@register_relay_node
class
Id
(
NodeBase
):
class
Id
(
NodeBase
):
...
...
python/tvm/relay/grammar/Relay.g4
View file @
d274e4b3
grammar Relay;
grammar Relay;
SEMVER: 'v0.0.1' ;
// Lexing
// Lexing
// comments
// comments
WS : [ \t\n\r]+ -> skip ;
WS : [ \t\n\r]+ -> skip ;
...
@@ -20,7 +22,8 @@ NE: '!=' ;
...
@@ -20,7 +22,8 @@ NE: '!=' ;
opIdent: CNAME ;
opIdent: CNAME ;
GLOBAL_VAR: '@' CNAME ;
GLOBAL_VAR: '@' CNAME ;
LOCAL_VAR: '%' CNAME ;
LOCAL_VAR: '%' CNAME;
GRAPH_VAR: '%' NAT;
MUT: 'mut' ;
MUT: 'mut' ;
...
@@ -31,13 +34,13 @@ BOOL_LIT
...
@@ -31,13 +34,13 @@ BOOL_LIT
// non-negative floats
// non-negative floats
FLOAT
FLOAT
:
INT '.' IN
T EXP? // 1.35, 1.35E-9, 0.3, 4.5
:
NAT '.' NA
T EXP? // 1.35, 1.35E-9, 0.3, 4.5
|
IN
T EXP // 1e10 3e4
|
NA
T EXP // 1e10 3e4
;
;
// non-negative ints
// non-negative ints
IN
T: DIGIT+ ;
NA
T: DIGIT+ ;
fragment EXP: [eE] [+\-]?
IN
T ; // \- since - means "range" inside [...]
fragment EXP: [eE] [+\-]?
NA
T ; // \- since - means "range" inside [...]
CNAME: ('_'|LETTER) ('_'|LETTER|DIGIT)* ;
CNAME: ('_'|LETTER) ('_'|LETTER|DIGIT)* ;
fragment LETTER: [a-zA-Z] ;
fragment LETTER: [a-zA-Z] ;
...
@@ -46,7 +49,7 @@ fragment DIGIT: [0-9] ;
...
@@ -46,7 +49,7 @@ fragment DIGIT: [0-9] ;
// Parsing
// Parsing
// A Relay program is a list of global definitions or an expression.
// A Relay program is a list of global definitions or an expression.
prog: (defn* | expr) EOF ;
prog:
SEMVER
(defn* | expr) EOF ;
// option: 'set' ident BOOL_LIT ;
// option: 'set' ident BOOL_LIT ;
...
@@ -73,10 +76,11 @@ expr
...
@@ -73,10 +76,11 @@ expr
| 'if' '(' expr ')' body 'else' body # ifElse
| 'if' '(' expr ')' body 'else' body # ifElse
// sequencing
// sequencing
| 'let' MUT? var '=' expr ';' expr #
seq
| 'let' MUT? var '=' expr ';' expr #
let
| 'let' MUT? var '=' '{' expr '}' ';' expr #
seq
| 'let' MUT? var '=' '{' expr '}' ';' expr #
let
// sugar for let %_ = expr; expr
// sugar for let %_ = expr; expr
| expr ';' expr # seq
| expr ';' expr # let
| ident '=' expr ';' expr # graph
// mutable update
// mutable update
// | ident '=' expr # writeRef
// | ident '=' expr # writeRef
...
@@ -84,16 +88,25 @@ expr
...
@@ -84,16 +88,25 @@ expr
| ident # identExpr
| ident # identExpr
| scalar # scalarExpr
| scalar # scalarExpr
// | expr '.'
INT
# project
// | expr '.'
NAT
# project
// | 'debug'
# debug
// | 'debug' # debug
;
;
func: 'fn' varList ('->' type_)? body ;
func: 'fn' '(' argList ')' ('->' type_)? body ;
defn: 'def' ident varList ('->' type_)? body ;
defn: 'def' ident '(' argList ')' ('->' type_)? body ;
argList
: varList
| attrList
| varList ',' attrList
;
varList:
'(' (var (',' var)*)? ')'
;
varList:
(var (',' var)*)?
;
var: ident (':' type_)? ;
var: ident (':' type_)? ;
attrList: (attr (',' attr)*)? ;
attr: CNAME '=' expr ;
// TODO(@jmp): for improved type annotations
// TODO(@jmp): for improved type annotations
// returnAnno: (ident ':')? type_ ;
// returnAnno: (ident ':')? type_ ;
...
@@ -110,7 +123,7 @@ type_
...
@@ -110,7 +123,7 @@ type_
// | identType '[' (type_ (',' type_)*)? ']' # callType
// | identType '[' (type_ (',' type_)*)? ']' # callType
| 'fn' '(' (type_ (',' type_)*)? ')' '->' type_ # funcType
| 'fn' '(' (type_ (',' type_)*)? ')' '->' type_ # funcType
| '_' # incompleteType
| '_' # incompleteType
|
IN
T # intType
|
NA
T # intType
;
;
shapeSeq
shapeSeq
...
@@ -123,20 +136,20 @@ shape
...
@@ -123,20 +136,20 @@ shape
: '(' shape ')' # parensShape
: '(' shape ')' # parensShape
// | type_ op=('*'|'/') type_ # binOpType
// | type_ op=('*'|'/') type_ # binOpType
// | type_ op=('+'|'-') type_ # binOpType
// | type_ op=('+'|'-') type_ # binOpType
|
IN
T # intShape
|
NA
T # intShape
;
;
identType: CNAME ;
identType: CNAME ;
//
Int8, Int16, Int32, I
nt64
//
int8, int16, int32, i
nt64
//
UInt8, UInt16, UInt32, UI
nt64
//
uint8, uint16, uint32, ui
nt64
//
Float16, Float32, F
loat64
//
float16, float32, f
loat64
//
B
ool
//
b
ool
body: '{' expr '}' ;
body: '{' expr '}' ;
scalar
scalar
: FLOAT # scalarFloat
: FLOAT # scalarFloat
|
IN
T # scalarInt
|
NA
T # scalarInt
| BOOL_LIT # scalarBool
| BOOL_LIT # scalarBool
;
;
...
@@ -144,4 +157,5 @@ ident
...
@@ -144,4 +157,5 @@ ident
: opIdent
: opIdent
| GLOBAL_VAR
| GLOBAL_VAR
| LOCAL_VAR
| LOCAL_VAR
| GRAPH_VAR
;
;
python/tvm/relay/parser.py
View file @
d274e4b3
"""A parser for Relay's text format."""
"""A parser for Relay's text format."""
from
__future__
import
absolute_import
from
__future__
import
absolute_import
from
..
import
register_func
def
enabled
():
def
enabled
():
"""Is the parser enabled/Can we import the parser?"""
"""Checks whether the parser is enabled, this allows users to
optionally support building the parser.
We use this check before importing the parser.
"""
try
:
try
:
# pylint: disable=unused-variable
# pylint: disable=unused-variable
from
tvm.relay
import
_parser
from
tvm.relay
import
_parser
...
@@ -11,7 +16,8 @@ def enabled():
...
@@ -11,7 +16,8 @@ def enabled():
except
Exception
:
except
Exception
:
return
False
return
False
def
fromtext
(
data
):
@register_func
(
"relay.fromtext"
)
def
fromtext
(
data
,
source_name
=
None
):
"""Parse a Relay program."""
"""Parse a Relay program."""
from
tvm.relay
import
_parser
from
tvm.relay
import
_parser
return
_parser
.
fromtext
(
data
)
return
_parser
.
fromtext
(
data
,
source_name
)
src/relay/ir/base.cc
View file @
d274e4b3
...
@@ -32,6 +32,11 @@ SourceName SourceName::Get(const std::string& name) {
...
@@ -32,6 +32,11 @@ SourceName SourceName::Get(const std::string& name) {
return
SourceName
(
GetSourceNameNode
(
name
));
return
SourceName
(
GetSourceNameNode
(
name
));
}
}
TVM_REGISTER_API
(
"relay._make.SourceName"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
SourceName
::
Get
(
args
[
0
]);
});
TVM_STATIC_IR_FUNCTOR_REGISTER
(
IRPrinter
,
vtable
)
TVM_STATIC_IR_FUNCTOR_REGISTER
(
IRPrinter
,
vtable
)
.
set_dispatch
<
SourceNameNode
>
([](
const
SourceNameNode
*
node
,
tvm
::
IRPrinter
*
p
)
{
.
set_dispatch
<
SourceNameNode
>
([](
const
SourceNameNode
*
node
,
tvm
::
IRPrinter
*
p
)
{
p
->
stream
<<
"SourceName("
<<
node
->
name
<<
", "
<<
node
<<
")"
;
p
->
stream
<<
"SourceName("
<<
node
->
name
<<
", "
<<
node
<<
")"
;
...
@@ -66,5 +71,14 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
...
@@ -66,5 +71,14 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
TVM_REGISTER_NODE_TYPE
(
IdNode
);
TVM_REGISTER_NODE_TYPE
(
IdNode
);
TVM_REGISTER_API
(
"relay._base.set_span"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
NodeRef
node_ref
=
args
[
0
];
auto
rn
=
node_ref
.
as_derived
<
RelayNode
>
();
CHECK
(
rn
);
Span
sp
=
args
[
1
];
rn
->
span
=
sp
;
});
}
// namespace relay
}
// namespace relay
}
// namespace tvm
}
// namespace tvm
tests/python/relay/test_ir_parser.py
View file @
d274e4b3
...
@@ -8,11 +8,12 @@ from numpy import isclose
...
@@ -8,11 +8,12 @@ from numpy import isclose
from
typing
import
Union
from
typing
import
Union
from
functools
import
wraps
from
functools
import
wraps
if
enabled
():
if
enabled
():
from
tvm.relay._parser
import
ParseError
raises_parse_error
=
raises
(
tvm
.
_ffi
.
base
.
TVMError
)
raises_parse_error
=
raises
(
ParseError
)
else
:
else
:
raises_parse_error
=
lambda
x
:
x
raises_parse_error
=
lambda
x
:
x
SEMVER
=
"v0.0.1"
BINARY_OPS
=
{
BINARY_OPS
=
{
"*"
:
relay
.
multiply
,
"*"
:
relay
.
multiply
,
"/"
:
relay
.
divide
,
"/"
:
relay
.
divide
,
...
@@ -48,6 +49,10 @@ TYPES = {
...
@@ -48,6 +49,10 @@ TYPES = {
"float16x4"
,
"float16x4"
,
}
}
def
parses_as
(
code
,
expr
):
# type: (str, relay.Expr) -> bool
return
alpha_equal
(
relay
.
fromtext
(
SEMVER
+
"
\n
"
+
code
),
expr
)
def
get_scalar
(
x
):
def
get_scalar
(
x
):
# type: (relay.Constant) -> (Union[float, int, bool])
# type: (relay.Constant) -> (Union[float, int, bool])
return
x
.
data
.
asnumpy
()
.
item
()
return
x
.
data
.
asnumpy
()
.
item
()
...
@@ -74,80 +79,80 @@ def if_parser_enabled(func):
...
@@ -74,80 +79,80 @@ def if_parser_enabled(func):
@if_parser_enabled
@if_parser_enabled
def
test_comments
():
def
test_comments
():
assert
alpha_equal
(
assert
parses_as
(
relay
.
fromtext
(
"""
"""
// This is a line comment!
// This is a line comment!
()
()
"""
)
,
"""
,
UNIT
UNIT
)
)
assert
alpha_equal
(
assert
parses_as
(
relay
.
fromtext
(
"""
"""
/* This is a block comment!
/* This is a block comment!
This is still a block comment!
This is still a block comment!
*/
*/
()
()
"""
)
,
"""
,
UNIT
UNIT
)
)
@if_parser_enabled
@if_parser_enabled
def
test_int_literal
():
def
test_int_literal
():
assert
isinstance
(
relay
.
fromtext
(
"1"
),
relay
.
Constant
)
assert
isinstance
(
relay
.
fromtext
(
SEMVER
+
"1"
),
relay
.
Constant
)
assert
isinstance
(
relay
.
fromtext
(
"1"
)
.
data
,
tvm
.
ndarray
.
NDArray
)
assert
isinstance
(
relay
.
fromtext
(
SEMVER
+
"1"
)
.
data
,
tvm
.
ndarray
.
NDArray
)
assert
get_scalar
(
relay
.
fromtext
(
"1"
))
==
1
assert
get_scalar
(
relay
.
fromtext
(
SEMVER
+
"1"
))
==
1
assert
get_scalar
(
relay
.
fromtext
(
"10"
))
==
10
assert
get_scalar
(
relay
.
fromtext
(
SEMVER
+
"10"
))
==
10
assert
get_scalar
(
relay
.
fromtext
(
"0"
))
==
0
assert
get_scalar
(
relay
.
fromtext
(
SEMVER
+
"0"
))
==
0
assert
get_scalar
(
relay
.
fromtext
(
"-100"
))
==
-
100
assert
get_scalar
(
relay
.
fromtext
(
SEMVER
+
"-100"
))
==
-
100
assert
get_scalar
(
relay
.
fromtext
(
"-05"
))
==
-
5
assert
get_scalar
(
relay
.
fromtext
(
SEMVER
+
"-05"
))
==
-
5
@if_parser_enabled
@if_parser_enabled
def
test_float_literal
():
def
test_float_literal
():
assert
get_scalar
(
relay
.
fromtext
(
"1.0"
))
==
1.0
assert
get_scalar
(
relay
.
fromtext
(
SEMVER
+
"1.0"
))
==
1.0
assert
isclose
(
get_scalar
(
relay
.
fromtext
(
"1.56667"
)),
1.56667
)
assert
isclose
(
get_scalar
(
relay
.
fromtext
(
SEMVER
+
"1.56667"
)),
1.56667
)
assert
get_scalar
(
relay
.
fromtext
(
"0.0"
))
==
0.0
assert
get_scalar
(
relay
.
fromtext
(
SEMVER
+
"0.0"
))
==
0.0
assert
get_scalar
(
relay
.
fromtext
(
"-10.0"
))
==
-
10.0
assert
get_scalar
(
relay
.
fromtext
(
SEMVER
+
"-10.0"
))
==
-
10.0
# scientific notation
# scientific notation
assert
isclose
(
get_scalar
(
relay
.
fromtext
(
"1e-1"
)),
1e-1
)
assert
isclose
(
get_scalar
(
relay
.
fromtext
(
SEMVER
+
"1e-1"
)),
1e-1
)
assert
get_scalar
(
relay
.
fromtext
(
"1e+1"
))
==
1e+1
assert
get_scalar
(
relay
.
fromtext
(
SEMVER
+
"1e+1"
))
==
1e+1
assert
isclose
(
get_scalar
(
relay
.
fromtext
(
"1E-1"
)),
1E-1
)
assert
isclose
(
get_scalar
(
relay
.
fromtext
(
SEMVER
+
"1E-1"
)),
1E-1
)
assert
get_scalar
(
relay
.
fromtext
(
"1E+1"
))
==
1E+1
assert
get_scalar
(
relay
.
fromtext
(
SEMVER
+
"1E+1"
))
==
1E+1
assert
isclose
(
get_scalar
(
relay
.
fromtext
(
"1.0e-1"
)),
1.0e-1
)
assert
isclose
(
get_scalar
(
relay
.
fromtext
(
SEMVER
+
"1.0e-1"
)),
1.0e-1
)
assert
get_scalar
(
relay
.
fromtext
(
"1.0e+1"
))
==
1.0e+1
assert
get_scalar
(
relay
.
fromtext
(
SEMVER
+
"1.0e+1"
))
==
1.0e+1
assert
isclose
(
get_scalar
(
relay
.
fromtext
(
"1.0E-1"
)),
1.0E-1
)
assert
isclose
(
get_scalar
(
relay
.
fromtext
(
SEMVER
+
"1.0E-1"
)),
1.0E-1
)
assert
get_scalar
(
relay
.
fromtext
(
"1.0E+1"
))
==
1.0E+1
assert
get_scalar
(
relay
.
fromtext
(
SEMVER
+
"1.0E+1"
))
==
1.0E+1
@if_parser_enabled
@if_parser_enabled
def
test_bool_literal
():
def
test_bool_literal
():
assert
get_scalar
(
relay
.
fromtext
(
"True"
))
==
True
assert
get_scalar
(
relay
.
fromtext
(
SEMVER
+
"True"
))
==
True
assert
get_scalar
(
relay
.
fromtext
(
"False"
))
==
False
assert
get_scalar
(
relay
.
fromtext
(
SEMVER
+
"False"
))
==
False
@if_parser_enabled
@if_parser_enabled
def
test_negative
():
def
test_negative
():
assert
isinstance
(
relay
.
fromtext
(
"let
%
x = 1; -
%
x"
)
.
body
,
relay
.
Call
)
assert
isinstance
(
relay
.
fromtext
(
SEMVER
+
"let
%
x = 1; -
%
x"
)
.
body
,
relay
.
Call
)
assert
get_scalar
(
relay
.
fromtext
(
"--10"
))
==
10
assert
get_scalar
(
relay
.
fromtext
(
SEMVER
+
"--10"
))
==
10
assert
get_scalar
(
relay
.
fromtext
(
"---10"
))
==
-
10
assert
get_scalar
(
relay
.
fromtext
(
SEMVER
+
"---10"
))
==
-
10
@if_parser_enabled
@if_parser_enabled
def
test_bin_op
():
def
test_bin_op
():
for
bin_op
in
BINARY_OPS
.
keys
():
for
bin_op
in
BINARY_OPS
.
keys
():
assert
alpha_equal
(
assert
parses_as
(
relay
.
fromtext
(
"1 {} 1"
.
format
(
bin_op
)
),
"1 {} 1"
.
format
(
bin_op
),
BINARY_OPS
.
get
(
bin_op
)(
relay
.
const
(
1
),
relay
.
const
(
1
))
BINARY_OPS
.
get
(
bin_op
)(
relay
.
const
(
1
),
relay
.
const
(
1
))
)
)
@if_parser_enabled
@if_parser_enabled
def
test_parens
():
def
test_parens
():
assert
alpha_equal
(
relay
.
fromtext
(
"1 * 1 + 1"
),
relay
.
fromtext
(
"(1 * 1) + 1"
))
assert
alpha_equal
(
relay
.
fromtext
(
SEMVER
+
"1 * 1 + 1"
),
relay
.
fromtext
(
SEMVER
+
"(1 * 1) + 1"
))
assert
not
alpha_equal
(
relay
.
fromtext
(
"1 * 1 + 1"
),
relay
.
fromtext
(
"1 * (1 + 1)"
))
assert
not
alpha_equal
(
relay
.
fromtext
(
SEMVER
+
"1 * 1 + 1"
),
relay
.
fromtext
(
SEMVER
+
"1 * (1 + 1)"
))
@if_parser_enabled
@if_parser_enabled
def
test_op_assoc
():
def
test_op_assoc
():
assert
alpha_equal
(
relay
.
fromtext
(
"1 * 1 + 1 < 1 == 1"
),
relay
.
fromtext
(
"(((1 * 1) + 1) < 1) == 1"
))
assert
alpha_equal
(
relay
.
fromtext
(
SEMVER
+
"1 * 1 + 1 < 1 == 1"
),
relay
.
fromtext
(
SEMVER
+
"(((1 * 1) + 1) < 1) == 1"
))
assert
alpha_equal
(
relay
.
fromtext
(
"1 == 1 < 1 + 1 * 1"
),
relay
.
fromtext
(
"1 == (1 < (1 + (1 * 1)))"
))
assert
alpha_equal
(
relay
.
fromtext
(
SEMVER
+
"1 == 1 < 1 + 1 * 1"
),
relay
.
fromtext
(
SEMVER
+
"1 == (1 < (1 + (1 * 1)))"
))
@nottest
@nottest
@if_parser_enabled
@if_parser_enabled
...
@@ -159,24 +164,24 @@ def test_vars():
...
@@ -159,24 +164,24 @@ def test_vars():
# assert temp_var.name == "1"
# assert temp_var.name == "1"
# var
# var
var
=
relay
.
fromtext
(
"let
%
foo = ();
%
foo"
)
var
=
relay
.
fromtext
(
SEMVER
+
"let
%
foo = ();
%
foo"
)
assert
isinstance
(
var
.
body
,
relay
.
Var
)
assert
isinstance
(
var
.
body
,
relay
.
Var
)
assert
var
.
body
.
name_hint
==
"foo"
assert
var
.
body
.
name_hint
==
"foo"
# global var
# global var
global_var
=
relay
.
fromtext
(
"@foo"
)
global_var
=
relay
.
fromtext
(
SEMVER
+
"@foo"
)
assert
isinstance
(
global_var
,
relay
.
GlobalVar
)
assert
isinstance
(
global_var
,
relay
.
GlobalVar
)
assert
global_var
.
name_hint
==
"foo"
assert
global_var
.
name_hint
==
"foo"
# operator id
# operator id
op
=
relay
.
fromtext
(
"foo"
)
op
=
relay
.
fromtext
(
SEMVER
+
"foo"
)
assert
isinstance
(
op
,
relay
.
Op
)
assert
isinstance
(
op
,
relay
.
Op
)
assert
op
.
name
==
"foo"
assert
op
.
name
==
"foo"
@if_parser_enabled
@if_parser_enabled
def
test_let
():
def
test_let
():
assert
alpha_equal
(
assert
parses_as
(
relay
.
fromtext
(
"let
%
x = 1; ()"
)
,
"let
%
x = 1; ()"
,
relay
.
Let
(
relay
.
Let
(
X
,
X
,
relay
.
const
(
1
),
relay
.
const
(
1
),
...
@@ -184,18 +189,35 @@ def test_let():
...
@@ -184,18 +189,35 @@ def test_let():
)
)
)
)
assert
parses_as
(
"""
let
%
x = 1;
let
%
y = 2;
()
"""
,
relay
.
Let
(
X
,
relay
.
const
(
1
),
relay
.
Let
(
Y
,
relay
.
const
(
2
),
UNIT
)
)
)
@if_parser_enabled
@if_parser_enabled
def
test_seq
():
def
test_seq
():
assert
alpha_equal
(
assert
parses_as
(
relay
.
fromtext
(
"(); ()"
)
,
"(); ()"
,
relay
.
Let
(
relay
.
Let
(
_
,
_
,
UNIT
,
UNIT
,
UNIT
)
UNIT
)
)
)
assert
alpha_equal
(
assert
parses_as
(
relay
.
fromtext
(
"let
%
_ = { 1 }; ()"
)
,
"let
%
_ = { 1 }; ()"
,
relay
.
Let
(
relay
.
Let
(
X
,
X
,
relay
.
const
(
1
),
relay
.
const
(
1
),
...
@@ -203,31 +225,48 @@ def test_seq():
...
@@ -203,31 +225,48 @@ def test_seq():
)
)
)
)
@if_parser_enabled
def
test_graph
():
assert
parses_as
(
"
%0
= ();
%1
= 1; (
%0
,
%0
,
%1
)"
,
relay
.
Tuple
([
UNIT
,
UNIT
,
relay
.
const
(
1
)])
)
assert
not
parses_as
(
"
%0
= ();
%1
= 1; (
%0
,
%0
,
%1
)"
,
relay
.
Tuple
([
relay
.
Tuple
([]),
relay
.
Tuple
([]),
relay
.
const
(
1
)])
)
@raises_parse_error
@if_parser_enabled
def
test_graph_wrong_order
():
relay
.
fromtext
(
SEMVER
+
"
%1
= ();
%1
"
)
@raises_parse_error
@raises_parse_error
@if_parser_enabled
@if_parser_enabled
def
test_let_global_var
():
def
test_let_global_var
():
relay
.
fromtext
(
"let @x = 1; ()"
)
relay
.
fromtext
(
SEMVER
+
"let @x = 1; ()"
)
@raises_parse_error
@raises_parse_error
@if_parser_enabled
@if_parser_enabled
def
test_let_op
():
def
test_let_op
():
relay
.
fromtext
(
"let x = 1; ()"
)
relay
.
fromtext
(
SEMVER
+
"let x = 1; ()"
)
@if_parser_enabled
@if_parser_enabled
def
test_tuple
():
def
test_tuple
():
assert
alpha_equal
(
relay
.
fromtext
(
"()"
)
,
relay
.
Tuple
([]))
assert
parses_as
(
"()"
,
relay
.
Tuple
([]))
assert
alpha_equal
(
relay
.
fromtext
(
"(0,)"
)
,
relay
.
Tuple
([
relay
.
const
(
0
)]))
assert
parses_as
(
"(0,)"
,
relay
.
Tuple
([
relay
.
const
(
0
)]))
assert
alpha_equal
(
relay
.
fromtext
(
"(0, 1)"
)
,
relay
.
Tuple
([
relay
.
const
(
0
),
relay
.
const
(
1
)]))
assert
parses_as
(
"(0, 1)"
,
relay
.
Tuple
([
relay
.
const
(
0
),
relay
.
const
(
1
)]))
assert
alpha_equal
(
relay
.
fromtext
(
"(0, 1, 2)"
)
,
relay
.
Tuple
([
relay
.
const
(
0
),
relay
.
const
(
1
),
relay
.
const
(
2
)]))
assert
parses_as
(
"(0, 1, 2)"
,
relay
.
Tuple
([
relay
.
const
(
0
),
relay
.
const
(
1
),
relay
.
const
(
2
)]))
@if_parser_enabled
@if_parser_enabled
def
test_func
():
def
test_func
():
# 0 args
# 0 args
assert
alpha_equal
(
assert
parses_as
(
relay
.
fromtext
(
"fn () { 0 }"
)
,
"fn () { 0 }"
,
relay
.
Function
(
relay
.
Function
(
[],
[],
relay
.
const
(
0
),
relay
.
const
(
0
),
...
@@ -237,8 +276,8 @@ def test_func():
...
@@ -237,8 +276,8 @@ def test_func():
)
)
# 1 arg
# 1 arg
assert
alpha_equal
(
assert
parses_as
(
relay
.
fromtext
(
"fn (
%
x) {
%
x }"
)
,
"fn (
%
x) {
%
x }"
,
relay
.
Function
(
relay
.
Function
(
[
X
],
[
X
],
X
,
X
,
...
@@ -248,8 +287,8 @@ def test_func():
...
@@ -248,8 +287,8 @@ def test_func():
)
)
# 2 args
# 2 args
assert
alpha_equal
(
assert
parses_as
(
relay
.
fromtext
(
"fn (
%
x,
%
y) {
%
x +
%
y }"
)
,
"fn (
%
x,
%
y) {
%
x +
%
y }"
,
relay
.
Function
(
relay
.
Function
(
[
X
,
Y
],
[
X
,
Y
],
relay
.
add
(
X
,
Y
),
relay
.
add
(
X
,
Y
),
...
@@ -259,8 +298,8 @@ def test_func():
...
@@ -259,8 +298,8 @@ def test_func():
)
)
# annotations
# annotations
assert
alpha_equal
(
assert
parses_as
(
relay
.
fromtext
(
"fn (
%
x: int32) -> int32 {
%
x }"
)
,
"fn (
%
x: int32) -> int32 {
%
x }"
,
relay
.
Function
(
relay
.
Function
(
[
X_ANNO
],
[
X_ANNO
],
X_ANNO
,
X_ANNO
,
...
@@ -269,11 +308,17 @@ def test_func():
...
@@ -269,11 +308,17 @@ def test_func():
)
)
)
)
# attributes
assert
parses_as
(
"fn (n=5) { () }"
,
relay
.
Function
([],
UNIT
,
None
,
None
,
tvm
.
make
.
node
(
"DictAttrs"
,
n
=
relay
.
const
(
5
)))
)
# TODO(@jmp): Crashes if %x isn't annnotated.
# TODO(@jmp): Crashes if %x isn't annnotated.
# @nottest
@if_parser_enabled
@if_parser_enabled
def
test_defn
():
def
test_defn
():
id_defn
=
relay
.
fromtext
(
id_defn
=
relay
.
fromtext
(
SEMVER
+
"""
"""
def @id(
%
x: int32) -> int32 {
def @id(
%
x: int32) -> int32 {
%
x
%
x
...
@@ -284,6 +329,7 @@ def test_defn():
...
@@ -284,6 +329,7 @@ def test_defn():
@if_parser_enabled
@if_parser_enabled
def
test_recursive_call
():
def
test_recursive_call
():
id_defn
=
relay
.
fromtext
(
id_defn
=
relay
.
fromtext
(
SEMVER
+
"""
"""
def @id(
%
x: int32) -> int32 {
def @id(
%
x: int32) -> int32 {
@id(
%
x)
@id(
%
x)
...
@@ -293,16 +339,14 @@ def test_recursive_call():
...
@@ -293,16 +339,14 @@ def test_recursive_call():
@if_parser_enabled
@if_parser_enabled
def
test_ifelse
():
def
test_ifelse
():
assert
alpha_equal
(
assert
parses_as
(
relay
.
fromtext
(
"""
"""
if (True) {
if (True) {
0
0
} else {
} else {
1
1
}
}
"""
"""
,
),
relay
.
If
(
relay
.
If
(
relay
.
const
(
True
),
relay
.
const
(
True
),
relay
.
const
(
0
),
relay
.
const
(
0
),
...
@@ -314,6 +358,7 @@ def test_ifelse():
...
@@ -314,6 +358,7 @@ def test_ifelse():
@if_parser_enabled
@if_parser_enabled
def
test_ifelse_scope
():
def
test_ifelse_scope
():
relay
.
fromtext
(
relay
.
fromtext
(
SEMVER
+
"""
"""
if (True) {
if (True) {
let
%
x = ();
let
%
x = ();
...
@@ -328,13 +373,11 @@ def test_ifelse_scope():
...
@@ -328,13 +373,11 @@ def test_ifelse_scope():
def
test_call
():
def
test_call
():
# select right function to call: simple ident case
# select right function to call: simple ident case
id_func
=
relay
.
Var
(
"id"
)
id_func
=
relay
.
Var
(
"id"
)
assert
alpha_equal
(
assert
parses_as
(
relay
.
fromtext
(
"""
"""
let
%
id = fn (
%
x) {
%
x };
let
%
id = fn (
%
x) {
%
x };
10 *
%
id(10)
10 *
%
id(10)
"""
"""
,
),
relay
.
Let
(
relay
.
Let
(
id_func
,
id_func
,
relay
.
Function
([
X
],
X
,
None
,
[]),
relay
.
Function
([
X
],
X
,
None
,
[]),
...
@@ -344,13 +387,11 @@ def test_call():
...
@@ -344,13 +387,11 @@ def test_call():
# 0 args
# 0 args
constant
=
relay
.
Var
(
"constant"
)
constant
=
relay
.
Var
(
"constant"
)
assert
alpha_equal
(
assert
parses_as
(
relay
.
fromtext
(
"""
"""
let
%
constant = fn () { 0 };
let
%
constant = fn () { 0 };
%
constant()
%
constant()
"""
"""
,
),
relay
.
Let
(
relay
.
Let
(
constant
,
constant
,
relay
.
Function
([],
relay
.
const
(
0
),
None
,
[]),
relay
.
Function
([],
relay
.
const
(
0
),
None
,
[]),
...
@@ -360,13 +401,11 @@ def test_call():
...
@@ -360,13 +401,11 @@ def test_call():
# 1 arg
# 1 arg
id_var
=
relay
.
Var
(
"id"
)
id_var
=
relay
.
Var
(
"id"
)
assert
alpha_equal
(
assert
parses_as
(
relay
.
fromtext
(
"""
"""
let
%
id = fn (
%
x) {
%
x };
let
%
id = fn (
%
x) {
%
x };
%
id(1)
%
id(1)
"""
,
"""
),
relay
.
Let
(
relay
.
Let
(
id_var
,
id_var
,
relay
.
Function
([
X
],
X
,
None
,
[]),
relay
.
Function
([
X
],
X
,
None
,
[]),
...
@@ -376,13 +415,11 @@ def test_call():
...
@@ -376,13 +415,11 @@ def test_call():
# 2 args
# 2 args
multiply
=
relay
.
Var
(
"multiply"
)
multiply
=
relay
.
Var
(
"multiply"
)
assert
alpha_equal
(
assert
parses_as
(
relay
.
fromtext
(
"""
"""
let
%
multiply = fn (
%
x,
%
y) {
%
x *
%
y };
let
%
multiply = fn (
%
x,
%
y) {
%
x *
%
y };
%
multiply(0, 0)
%
multiply(0, 0)
"""
"""
,
),
relay
.
Let
(
relay
.
Let
(
multiply
,
multiply
,
relay
.
Function
(
relay
.
Function
(
...
@@ -396,12 +433,10 @@ def test_call():
...
@@ -396,12 +433,10 @@ def test_call():
)
)
# anonymous function
# anonymous function
assert
alpha_equal
(
assert
parses_as
(
relay
.
fromtext
(
"""
"""
(fn (
%
x) {
%
x })(0)
(fn (
%
x) {
%
x })(0)
"""
"""
,
),
relay
.
Call
(
relay
.
Call
(
relay
.
Function
(
relay
.
Function
(
[
X
],
[
X
],
...
@@ -415,45 +450,44 @@ def test_call():
...
@@ -415,45 +450,44 @@ def test_call():
)
)
)
)
# TODO(@jmp): re-enable after sequence parsing improvements
# curried function
# curried function
curried_mult
=
relay
.
Var
(
"curried_mult"
)
# curried_mult = relay.Var("curried_mult")
alpha_equal
(
# assert parses_as(
relay
.
fromtext
(
# """
"""
# let %curried_mult =
let
%
curried_mult =
# fn (%x) {
fn (
%
x) {
# fn (%y) {
fn (
%
y) {
# %x * %y
%
x *
%
y
# }
}
# };
};
# %curried_mult(0);
%
curried_mult(0);
# %curried_mult(0)(0)
%
curried_mult(0)(0)
# """,
"""
# relay.Let(
),
# curried_mult,
relay
.
Let
(
# relay.Function(
curried_mult
,
# [X],
relay
.
Function
(
# relay.Function(
[
X
],
# [Y],
relay
.
Function
(
# relay.multiply(X, Y),
[
Y
],
# None,
relay
.
multiply
(
X
,
Y
),
# []
None
,
# ),
[]
# None,
),
# []
None
,
# ),
[]
# relay.Let(
),
# _,
relay
.
Let
(
# relay.Call(curried_mult, [relay.const(0)], None, None),
_
,
# relay.Call(relay.Call(curried_mult, [relay.const(0)], None, None), [relay.const(0)], None, None)
relay
.
Call
(
curried_mult
,
[
relay
.
const
(
0
)],
None
,
None
),
# )
relay
.
Call
(
relay
.
Call
(
curried_mult
,
[
relay
.
const
(
0
)],
None
,
None
),
[
relay
.
const
(
0
)],
None
,
None
)
# )
)
# )
)
)
# op
# op
a
lpha_equal
(
a
ssert
parses_as
(
relay
.
fromtext
(
"abs(1)"
)
,
"abs(1)"
,
relay
.
Call
(
relay
.
op
.
get
(
"abs"
),
[
relay
.
const
(
1
)],
None
,
None
)
relay
.
Call
(
relay
.
op
.
get
(
"abs"
),
[
relay
.
const
(
1
)],
None
,
None
)
)
)
...
@@ -461,8 +495,8 @@ def test_call():
...
@@ -461,8 +495,8 @@ def test_call():
@if_parser_enabled
@if_parser_enabled
def
test_incomplete_type
():
def
test_incomplete_type
():
assert
alpha_equal
(
assert
parses_as
(
relay
.
fromtext
(
"let
%
_ : _ = (); ()"
)
,
"let
%
_ : _ = (); ()"
,
relay
.
Let
(
relay
.
Let
(
_
,
_
,
UNIT
,
UNIT
,
...
@@ -473,7 +507,7 @@ def test_incomplete_type():
...
@@ -473,7 +507,7 @@ def test_incomplete_type():
@if_parser_enabled
@if_parser_enabled
def
test_builtin_types
():
def
test_builtin_types
():
for
builtin_type
in
TYPES
:
for
builtin_type
in
TYPES
:
relay
.
fromtext
(
"let
%
_ : {} = (); ()"
.
format
(
builtin_type
))
relay
.
fromtext
(
SEMVER
+
"let
%
_ : {} = (); ()"
.
format
(
builtin_type
))
@nottest
@nottest
@if_parser_enabled
@if_parser_enabled
...
@@ -482,8 +516,8 @@ def test_call_type():
...
@@ -482,8 +516,8 @@ def test_call_type():
@if_parser_enabled
@if_parser_enabled
def
test_tensor_type
():
def
test_tensor_type
():
assert
alpha_equal
(
assert
parses_as
(
relay
.
fromtext
(
"let
%
_ : Tensor[(), float32] = (); ()"
)
,
"let
%
_ : Tensor[(), float32] = (); ()"
,
relay
.
Let
(
relay
.
Let
(
relay
.
Var
(
"_"
,
relay
.
TensorType
((),
"float32"
)),
relay
.
Var
(
"_"
,
relay
.
TensorType
((),
"float32"
)),
UNIT
,
UNIT
,
...
@@ -491,8 +525,8 @@ def test_tensor_type():
...
@@ -491,8 +525,8 @@ def test_tensor_type():
)
)
)
)
assert
alpha_equal
(
assert
parses_as
(
relay
.
fromtext
(
"let
%
_ : Tensor[(1,), float32] = (); ()"
)
,
"let
%
_ : Tensor[(1,), float32] = (); ()"
,
relay
.
Let
(
relay
.
Let
(
relay
.
Var
(
"_"
,
relay
.
TensorType
((
1
,),
"float32"
)),
relay
.
Var
(
"_"
,
relay
.
TensorType
((
1
,),
"float32"
)),
UNIT
,
UNIT
,
...
@@ -500,8 +534,8 @@ def test_tensor_type():
...
@@ -500,8 +534,8 @@ def test_tensor_type():
)
)
)
)
assert
alpha_equal
(
assert
parses_as
(
relay
.
fromtext
(
"let
%
_ : Tensor[(1, 1), float32] = (); ()"
)
,
"let
%
_ : Tensor[(1, 1), float32] = (); ()"
,
relay
.
Let
(
relay
.
Let
(
relay
.
Var
(
"_"
,
relay
.
TensorType
((
1
,
1
),
"float32"
)),
relay
.
Var
(
"_"
,
relay
.
TensorType
((
1
,
1
),
"float32"
)),
UNIT
,
UNIT
,
...
@@ -511,12 +545,10 @@ def test_tensor_type():
...
@@ -511,12 +545,10 @@ def test_tensor_type():
@if_parser_enabled
@if_parser_enabled
def
test_function_type
():
def
test_function_type
():
assert
alpha_equal
(
assert
parses_as
(
relay
.
fromtext
(
"""
"""
let
%
_: fn () -> int32 = fn () -> int32 { 0 }; ()
let
%
_: fn () -> int32 = fn () -> int32 { 0 }; ()
"""
,
"""
),
relay
.
Let
(
relay
.
Let
(
relay
.
Var
(
"_"
,
relay
.
FuncType
([],
int32
,
[],
[])),
relay
.
Var
(
"_"
,
relay
.
FuncType
([],
int32
,
[],
[])),
relay
.
Function
([],
relay
.
const
(
0
),
int32
,
[]),
relay
.
Function
([],
relay
.
const
(
0
),
int32
,
[]),
...
@@ -524,12 +556,10 @@ def test_function_type():
...
@@ -524,12 +556,10 @@ def test_function_type():
)
)
)
)
assert
alpha_equal
(
assert
parses_as
(
relay
.
fromtext
(
"""
"""
let
%
_: fn (int32) -> int32 = fn (
%
x: int32) -> int32 { 0 }; ()
let
%
_: fn (int32) -> int32 = fn (
%
x: int32) -> int32 { 0 }; ()
"""
,
"""
),
relay
.
Let
(
relay
.
Let
(
relay
.
Var
(
"_"
,
relay
.
FuncType
([
int32
],
int32
,
[],
[])),
relay
.
Var
(
"_"
,
relay
.
FuncType
([
int32
],
int32
,
[],
[])),
relay
.
Function
([
relay
.
Var
(
"x"
,
int32
)],
relay
.
const
(
0
),
int32
,
[]),
relay
.
Function
([
relay
.
Var
(
"x"
,
int32
)],
relay
.
const
(
0
),
int32
,
[]),
...
@@ -537,12 +567,10 @@ def test_function_type():
...
@@ -537,12 +567,10 @@ def test_function_type():
)
)
)
)
assert
alpha_equal
(
assert
parses_as
(
relay
.
fromtext
(
"""
"""
let
%
_: fn (int32, int32) -> int32 = fn (
%
x: int32,
%
y: int32) -> int32 { 0 }; ()
let
%
_: fn (int32, int32) -> int32 = fn (
%
x: int32,
%
y: int32) -> int32 { 0 }; ()
"""
,
"""
),
relay
.
Let
(
relay
.
Let
(
relay
.
Var
(
"_"
,
relay
.
FuncType
([
int32
,
int32
],
int32
,
[],
[])),
relay
.
Var
(
"_"
,
relay
.
FuncType
([
int32
,
int32
],
int32
,
[],
[])),
relay
.
Function
([
relay
.
Var
(
"x"
,
int32
),
relay
.
Var
(
"y"
,
int32
)],
relay
.
const
(
0
),
int32
,
[]),
relay
.
Function
([
relay
.
Var
(
"x"
,
int32
),
relay
.
Var
(
"y"
,
int32
)],
relay
.
const
(
0
),
int32
,
[]),
...
@@ -552,11 +580,10 @@ def test_function_type():
...
@@ -552,11 +580,10 @@ def test_function_type():
@if_parser_enabled
@if_parser_enabled
def
test_tuple_type
():
def
test_tuple_type
():
assert
alpha_equal
(
assert
parses_as
(
relay
.
fromtext
(
"""
"""
let
%
_: () = (); ()
let
%
_: () = (); ()
"""
)
,
"""
,
relay
.
Let
(
relay
.
Let
(
relay
.
Var
(
"_"
,
relay
.
TupleType
([])),
relay
.
Var
(
"_"
,
relay
.
TupleType
([])),
UNIT
,
UNIT
,
...
@@ -564,11 +591,10 @@ def test_tuple_type():
...
@@ -564,11 +591,10 @@ def test_tuple_type():
)
)
)
)
assert
alpha_equal
(
assert
parses_as
(
relay
.
fromtext
(
"""
"""
let
%
_: (int32,) = (0,); ()
let
%
_: (int32,) = (0,); ()
"""
)
,
"""
,
relay
.
Let
(
relay
.
Let
(
relay
.
Var
(
"_"
,
relay
.
TupleType
([
int32
])),
relay
.
Var
(
"_"
,
relay
.
TupleType
([
int32
])),
relay
.
Tuple
([
relay
.
const
(
0
)]),
relay
.
Tuple
([
relay
.
const
(
0
)]),
...
@@ -576,11 +602,10 @@ def test_tuple_type():
...
@@ -576,11 +602,10 @@ def test_tuple_type():
)
)
)
)
assert
alpha_equal
(
assert
parses_as
(
relay
.
fromtext
(
"""
"""
let
%
_: (int32, int32) = (0, 1); ()
let
%
_: (int32, int32) = (0, 1); ()
"""
)
,
"""
,
relay
.
Let
(
relay
.
Let
(
relay
.
Var
(
"_"
,
relay
.
TupleType
([
int32
,
int32
])),
relay
.
Var
(
"_"
,
relay
.
TupleType
([
int32
,
int32
])),
relay
.
Tuple
([
relay
.
const
(
0
),
relay
.
const
(
1
)]),
relay
.
Tuple
([
relay
.
const
(
0
),
relay
.
const
(
1
)]),
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment