Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
52e55baa
Commit
52e55baa
authored
Dec 02, 2018
by
Josh Pollock
Committed by
Tianqi Chen
Dec 02, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[Relay] Parser Tests (#2209)
parent
d3bc59d2
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
570 additions
and
8 deletions
+570
-8
src/relay/ir/alpha_equal.cc
+6
-6
src/relay/ir/text_printer.cc
+2
-2
tests/python/relay/test_ir_parser.py
+562
-0
No files found.
src/relay/ir/alpha_equal.cc
View file @
52e55baa
...
...
@@ -26,7 +26,7 @@ class AlphaEqualHandler:
* Check equality of two nodes.
* \param lhs The left hand operand.
* \param rhs The right hand operand.
* \return
t
he compare result.
* \return
T
he compare result.
*/
bool
Equal
(
const
NodeRef
&
lhs
,
const
NodeRef
&
rhs
)
{
if
(
lhs
.
same_as
(
rhs
))
return
true
;
...
...
@@ -46,7 +46,7 @@ class AlphaEqualHandler:
* Check equality of two attributes.
* \param lhs The left hand operand.
* \param rhs The right hand operand.
* \return
t
he compare result.
* \return
T
he compare result.
*/
bool
AttrEqual
(
const
NodeRef
&
lhs
,
const
NodeRef
&
rhs
)
{
return
AttrsEqualHandler
::
Equal
(
lhs
,
rhs
);
...
...
@@ -55,7 +55,7 @@ class AlphaEqualHandler:
* Check equality of two types.
* \param lhs The left hand operand.
* \param rhs The right hand operand.
* \return
t
he compare result.
* \return
T
he compare result.
*/
bool
TypeEqual
(
const
Type
&
lhs
,
const
Type
&
rhs
)
{
if
(
lhs
.
same_as
(
rhs
))
return
true
;
...
...
@@ -72,7 +72,7 @@ class AlphaEqualHandler:
*
* \param lhs The left hand operand.
* \param rhs The right hand operand.
* \return
t
he compare result.
* \return
T
he compare result.
*/
bool
ExprEqual
(
const
Expr
&
lhs
,
const
Expr
&
rhs
)
{
if
(
lhs
.
same_as
(
rhs
))
return
true
;
...
...
@@ -94,7 +94,7 @@ class AlphaEqualHandler:
* \brief Check if data type equals each other.
* \param lhs The left hand operand.
* \param rhs The right hand operand.
* \return
t
he compare result.
* \return
T
he compare result.
*/
bool
DataTypeEqual
(
const
DataType
&
lhs
,
const
DataType
&
rhs
)
{
return
lhs
==
rhs
;
...
...
@@ -104,7 +104,7 @@ class AlphaEqualHandler:
* if map_free_var_ is set to true, try to map via equal node.
* \param lhs The left hand operand.
* \param rhs The right hand operand.
* \return
t
he compare result.
* \return
T
he compare result.
*/
bool
LeafNodeEqual
(
const
NodeRef
&
lhs
,
const
NodeRef
&
rhs
)
{
if
(
lhs
.
same_as
(
rhs
))
return
true
;
...
...
src/relay/ir/text_printer.cc
View file @
52e55baa
...
...
@@ -38,7 +38,7 @@ inline std::ostream& operator<<(std::ostream& os, const TextValue& val) { // NO
* It can be hard to design a text format for all the possible nodes
* as the set of nodes can grow when we do more extensions.
*
* Instead of trying to design readable text format for every node
s
,
* Instead of trying to design readable text format for every node,
* we support a meta-data section in the text format.
* We allow the text format to refer to a node in the meta-data section.
*
...
...
@@ -73,7 +73,7 @@ inline std::ostream& operator<<(std::ostream& os, const TextValue& val) { // NO
* \endcode
*
* Note that we store tvm.var("n") in the meta data section.
* Since it is stored in the index-0 in the meta-data sec
it
on,
* Since it is stored in the index-0 in the meta-data sec
ti
on,
* we print it as meta.Variable(0).
*
* The text parser can recover this object by loading from the corresponding
...
...
tests/python/relay/test_ir_parser.py
0 → 100644
View file @
52e55baa
import
tvm
from
tvm
import
relay
from
tvm.relay.parser
import
enabled
from
tvm.relay.ir_pass
import
alpha_equal
from
nose.tools
import
nottest
,
raises
from
numpy
import
isclose
from
typing
import
Union
from
functools
import
wraps
if
enabled
():
from
tvm.relay._parser
import
ParseError
raises_parse_error
=
raises
(
ParseError
)
else
:
raises_parse_error
=
lambda
x
:
x
BINARY_OPS
=
{
"*"
:
relay
.
multiply
,
"/"
:
relay
.
divide
,
"+"
:
relay
.
add
,
"-"
:
relay
.
subtract
,
"<"
:
relay
.
less
,
">"
:
relay
.
greater
,
"<="
:
relay
.
less_equal
,
">="
:
relay
.
greater_equal
,
"=="
:
relay
.
equal
,
"!="
:
relay
.
not_equal
,
}
TYPES
=
{
"int8"
,
"int16"
,
"int32"
,
"int64"
,
"uint8"
,
"uint16"
,
"uint32"
,
"uint64"
,
"float16"
,
"float32"
,
"float64"
,
"bool"
,
"int8x4"
,
"uint1x4"
,
"float16x4"
,
}
def
get_scalar
(
x
):
# type: (relay.Constant) -> (Union[float, int, bool])
return
x
.
data
.
asnumpy
()
.
item
()
int32
=
relay
.
scalar_type
(
"int32"
)
_
=
relay
.
Var
(
"_"
)
X
=
relay
.
Var
(
"x"
)
Y
=
relay
.
Var
(
"y"
)
X_ANNO
=
relay
.
Var
(
"x"
,
int32
)
Y_ANNO
=
relay
.
Var
(
"y"
,
int32
)
UNIT
=
relay
.
Tuple
([])
# decorator to determine if parser is enabled
def
if_parser_enabled
(
func
):
# https://stackoverflow.com/q/7727678
@wraps
(
func
)
def
wrapper
():
if
not
enabled
():
return
func
()
return
wrapper
@if_parser_enabled
def
test_comments
():
assert
alpha_equal
(
relay
.
fromtext
(
"""
// This is a line comment!
()
"""
),
UNIT
)
assert
alpha_equal
(
relay
.
fromtext
(
"""
/* This is a block comment!
This is still a block comment!
*/
()
"""
),
UNIT
)
@if_parser_enabled
def
test_int_literal
():
assert
isinstance
(
relay
.
fromtext
(
"1"
),
relay
.
Constant
)
assert
isinstance
(
relay
.
fromtext
(
"1"
)
.
data
,
tvm
.
ndarray
.
NDArray
)
assert
get_scalar
(
relay
.
fromtext
(
"1"
))
==
1
assert
get_scalar
(
relay
.
fromtext
(
"10"
))
==
10
assert
get_scalar
(
relay
.
fromtext
(
"0"
))
==
0
assert
get_scalar
(
relay
.
fromtext
(
"-100"
))
==
-
100
assert
get_scalar
(
relay
.
fromtext
(
"-05"
))
==
-
5
@if_parser_enabled
def
test_float_literal
():
assert
get_scalar
(
relay
.
fromtext
(
"1.0"
))
==
1.0
assert
isclose
(
get_scalar
(
relay
.
fromtext
(
"1.56667"
)),
1.56667
)
assert
get_scalar
(
relay
.
fromtext
(
"0.0"
))
==
0.0
assert
get_scalar
(
relay
.
fromtext
(
"-10.0"
))
==
-
10.0
# scientific notation
assert
isclose
(
get_scalar
(
relay
.
fromtext
(
"1e-1"
)),
1e-1
)
assert
get_scalar
(
relay
.
fromtext
(
"1e+1"
))
==
1e+1
assert
isclose
(
get_scalar
(
relay
.
fromtext
(
"1E-1"
)),
1E-1
)
assert
get_scalar
(
relay
.
fromtext
(
"1E+1"
))
==
1E+1
assert
isclose
(
get_scalar
(
relay
.
fromtext
(
"1.0e-1"
)),
1.0e-1
)
assert
get_scalar
(
relay
.
fromtext
(
"1.0e+1"
))
==
1.0e+1
assert
isclose
(
get_scalar
(
relay
.
fromtext
(
"1.0E-1"
)),
1.0E-1
)
assert
get_scalar
(
relay
.
fromtext
(
"1.0E+1"
))
==
1.0E+1
@if_parser_enabled
def
test_bool_literal
():
assert
get_scalar
(
relay
.
fromtext
(
"True"
))
==
True
assert
get_scalar
(
relay
.
fromtext
(
"False"
))
==
False
@if_parser_enabled
def
test_negative
():
assert
isinstance
(
relay
.
fromtext
(
"let
%
x = 1; -
%
x"
)
.
body
,
relay
.
Call
)
assert
get_scalar
(
relay
.
fromtext
(
"--10"
))
==
10
assert
get_scalar
(
relay
.
fromtext
(
"---10"
))
==
-
10
@if_parser_enabled
def
test_bin_op
():
for
bin_op
in
BINARY_OPS
.
keys
():
assert
alpha_equal
(
relay
.
fromtext
(
"1 {} 1"
.
format
(
bin_op
)),
BINARY_OPS
.
get
(
bin_op
)(
relay
.
const
(
1
),
relay
.
const
(
1
))
)
@if_parser_enabled
def
test_parens
():
assert
alpha_equal
(
relay
.
fromtext
(
"1 * 1 + 1"
),
relay
.
fromtext
(
"(1 * 1) + 1"
))
assert
not
alpha_equal
(
relay
.
fromtext
(
"1 * 1 + 1"
),
relay
.
fromtext
(
"1 * (1 + 1)"
))
@if_parser_enabled
def
test_op_assoc
():
assert
alpha_equal
(
relay
.
fromtext
(
"1 * 1 + 1 < 1 == 1"
),
relay
.
fromtext
(
"(((1 * 1) + 1) < 1) == 1"
))
assert
alpha_equal
(
relay
.
fromtext
(
"1 == 1 < 1 + 1 * 1"
),
relay
.
fromtext
(
"1 == (1 < (1 + (1 * 1)))"
))
@nottest
@if_parser_enabled
def
test_vars
():
# temp vars won't work b/c they start with a digit
# # temp var
# temp_var = relay.fromtext("%1")
# assert isinstance(temp_var, relay.Var)
# assert temp_var.name == "1"
# var
var
=
relay
.
fromtext
(
"let
%
foo = ();
%
foo"
)
assert
isinstance
(
var
.
body
,
relay
.
Var
)
assert
var
.
body
.
name_hint
==
"foo"
# global var
global_var
=
relay
.
fromtext
(
"@foo"
)
assert
isinstance
(
global_var
,
relay
.
GlobalVar
)
assert
global_var
.
name_hint
==
"foo"
# operator id
op
=
relay
.
fromtext
(
"foo"
)
assert
isinstance
(
op
,
relay
.
Op
)
assert
op
.
name
==
"foo"
@if_parser_enabled
def
test_let
():
assert
alpha_equal
(
relay
.
fromtext
(
"let
%
x = 1; ()"
),
relay
.
Let
(
X
,
relay
.
const
(
1
),
UNIT
)
)
@if_parser_enabled
def
test_seq
():
assert
alpha_equal
(
relay
.
fromtext
(
"(); ()"
),
relay
.
Let
(
_
,
UNIT
,
UNIT
)
)
assert
alpha_equal
(
relay
.
fromtext
(
"let
%
_ = { 1 }; ()"
),
relay
.
Let
(
X
,
relay
.
const
(
1
),
UNIT
)
)
@raises_parse_error
@if_parser_enabled
def
test_let_global_var
():
relay
.
fromtext
(
"let @x = 1; ()"
)
@raises_parse_error
@if_parser_enabled
def
test_let_op
():
relay
.
fromtext
(
"let x = 1; ()"
)
@if_parser_enabled
def
test_tuple
():
assert
alpha_equal
(
relay
.
fromtext
(
"()"
),
relay
.
Tuple
([]))
assert
alpha_equal
(
relay
.
fromtext
(
"(0,)"
),
relay
.
Tuple
([
relay
.
const
(
0
)]))
assert
alpha_equal
(
relay
.
fromtext
(
"(0, 1)"
),
relay
.
Tuple
([
relay
.
const
(
0
),
relay
.
const
(
1
)]))
assert
alpha_equal
(
relay
.
fromtext
(
"(0, 1, 2)"
),
relay
.
Tuple
([
relay
.
const
(
0
),
relay
.
const
(
1
),
relay
.
const
(
2
)]))
@if_parser_enabled
def
test_func
():
# 0 args
assert
alpha_equal
(
relay
.
fromtext
(
"fn () { 0 }"
),
relay
.
Function
(
[],
relay
.
const
(
0
),
None
,
[]
)
)
# 1 arg
assert
alpha_equal
(
relay
.
fromtext
(
"fn (
%
x) {
%
x }"
),
relay
.
Function
(
[
X
],
X
,
None
,
[]
)
)
# 2 args
assert
alpha_equal
(
relay
.
fromtext
(
"fn (
%
x,
%
y) {
%
x +
%
y }"
),
relay
.
Function
(
[
X
,
Y
],
relay
.
add
(
X
,
Y
),
None
,
[]
)
)
# annotations
assert
alpha_equal
(
relay
.
fromtext
(
"fn (
%
x: int32) -> int32 {
%
x }"
),
relay
.
Function
(
[
X_ANNO
],
X_ANNO
,
int32
,
[]
)
)
# TODO(@jmp): Crashes if %x isn't annnotated.
# @nottest
@if_parser_enabled
def
test_defn
():
id_defn
=
relay
.
fromtext
(
"""
def @id(
%
x: int32) -> int32 {
%
x
}
"""
)
assert
isinstance
(
id_defn
,
relay
.
Module
)
@if_parser_enabled
def
test_ifelse
():
assert
alpha_equal
(
relay
.
fromtext
(
"""
if (True) {
0
} else {
1
}
"""
),
relay
.
If
(
relay
.
const
(
True
),
relay
.
const
(
0
),
relay
.
const
(
1
)
)
)
@raises_parse_error
@if_parser_enabled
def
test_ifelse_scope
():
relay
.
fromtext
(
"""
if (True) {
let
%
x = ();
()
} else {
%
x
}
"""
)
@if_parser_enabled
def
test_call
():
# 0 args
constant
=
relay
.
Var
(
"constant"
)
assert
alpha_equal
(
relay
.
fromtext
(
"""
let
%
constant = fn () { 0 };
%
constant()
"""
),
relay
.
Let
(
constant
,
relay
.
Function
([],
relay
.
const
(
0
),
None
,
[]),
relay
.
Call
(
constant
,
[],
None
,
None
)
)
)
# 1 arg
id_var
=
relay
.
Var
(
"id"
)
assert
alpha_equal
(
relay
.
fromtext
(
"""
let
%
id = fn (
%
x) {
%
x };
%
id(1)
"""
),
relay
.
Let
(
id_var
,
relay
.
Function
([
X
],
X
,
None
,
[]),
relay
.
Call
(
id_var
,
[
relay
.
const
(
1
)],
None
,
None
)
)
)
# 2 args
multiply
=
relay
.
Var
(
"multiply"
)
assert
alpha_equal
(
relay
.
fromtext
(
"""
let
%
multiply = fn (
%
x,
%
y) {
%
x *
%
y };
%
multiply(0, 0)
"""
),
relay
.
Let
(
multiply
,
relay
.
Function
(
[
X
,
Y
],
relay
.
multiply
(
X
,
Y
),
None
,
[]
),
relay
.
Call
(
multiply
,
[
relay
.
const
(
0
),
relay
.
const
(
0
)],
None
,
None
)
)
)
# anonymous function
assert
alpha_equal
(
relay
.
fromtext
(
"""
(fn (
%
x) {
%
x })(0)
"""
),
relay
.
Call
(
relay
.
Function
(
[
X
],
X
,
None
,
[]
),
[
relay
.
const
(
0
)],
None
,
None
)
)
# curried function
curried_mult
=
relay
.
Var
(
"curried_mult"
)
alpha_equal
(
relay
.
fromtext
(
"""
let
%
curried_mult =
fn (
%
x) {
fn (
%
y) {
%
x *
%
y
}
};
%
curried_mult(0);
%
curried_mult(0)(0)
"""
),
relay
.
Let
(
curried_mult
,
relay
.
Function
(
[
X
],
relay
.
Function
(
[
Y
],
relay
.
multiply
(
X
,
Y
),
None
,
[]
),
None
,
[]
),
relay
.
Let
(
_
,
relay
.
Call
(
curried_mult
,
[
relay
.
const
(
0
)],
None
,
None
),
relay
.
Call
(
relay
.
Call
(
curried_mult
,
[
relay
.
const
(
0
)],
None
,
None
),
[
relay
.
const
(
0
)],
None
,
None
)
)
)
)
# op
alpha_equal
(
relay
.
fromtext
(
"abs(1)"
),
relay
.
Call
(
relay
.
op
.
get
(
"abs"
),
[
relay
.
const
(
1
)],
None
,
None
)
)
# Types
@if_parser_enabled
def
test_incomplete_type
():
assert
alpha_equal
(
relay
.
fromtext
(
"let
%
_ : _ = (); ()"
),
relay
.
Let
(
_
,
UNIT
,
UNIT
)
)
@if_parser_enabled
def
test_builtin_types
():
for
builtin_type
in
TYPES
:
relay
.
fromtext
(
"let
%
_ : {} = (); ()"
.
format
(
builtin_type
))
@nottest
@if_parser_enabled
def
test_call_type
():
assert
False
@if_parser_enabled
def
test_tensor_type
():
assert
alpha_equal
(
relay
.
fromtext
(
"let
%
_ : Tensor[(), float32] = (); ()"
),
relay
.
Let
(
relay
.
Var
(
"_"
,
relay
.
TensorType
((),
"float32"
)),
UNIT
,
UNIT
)
)
assert
alpha_equal
(
relay
.
fromtext
(
"let
%
_ : Tensor[(1,), float32] = (); ()"
),
relay
.
Let
(
relay
.
Var
(
"_"
,
relay
.
TensorType
((
1
,),
"float32"
)),
UNIT
,
UNIT
)
)
assert
alpha_equal
(
relay
.
fromtext
(
"let
%
_ : Tensor[(1, 1), float32] = (); ()"
),
relay
.
Let
(
relay
.
Var
(
"_"
,
relay
.
TensorType
((
1
,
1
),
"float32"
)),
UNIT
,
UNIT
)
)
@if_parser_enabled
def
test_function_type
():
assert
alpha_equal
(
relay
.
fromtext
(
"""
let
%
_: fn () -> int32 = fn () -> int32 { 0 }; ()
"""
),
relay
.
Let
(
relay
.
Var
(
"_"
,
relay
.
FuncType
([],
int32
,
[],
[])),
relay
.
Function
([],
relay
.
const
(
0
),
int32
,
[]),
UNIT
)
)
assert
alpha_equal
(
relay
.
fromtext
(
"""
let
%
_: fn (int32) -> int32 = fn (
%
x: int32) -> int32 { 0 }; ()
"""
),
relay
.
Let
(
relay
.
Var
(
"_"
,
relay
.
FuncType
([
int32
],
int32
,
[],
[])),
relay
.
Function
([
relay
.
Var
(
"x"
,
int32
)],
relay
.
const
(
0
),
int32
,
[]),
UNIT
)
)
assert
alpha_equal
(
relay
.
fromtext
(
"""
let
%
_: fn (int32, int32) -> int32 = fn (
%
x: int32,
%
y: int32) -> int32 { 0 }; ()
"""
),
relay
.
Let
(
relay
.
Var
(
"_"
,
relay
.
FuncType
([
int32
,
int32
],
int32
,
[],
[])),
relay
.
Function
([
relay
.
Var
(
"x"
,
int32
),
relay
.
Var
(
"y"
,
int32
)],
relay
.
const
(
0
),
int32
,
[]),
UNIT
)
)
@if_parser_enabled
def
test_tuple_type
():
assert
alpha_equal
(
relay
.
fromtext
(
"""
let
%
_: () = (); ()
"""
),
relay
.
Let
(
relay
.
Var
(
"_"
,
relay
.
TupleType
([])),
UNIT
,
UNIT
)
)
assert
alpha_equal
(
relay
.
fromtext
(
"""
let
%
_: (int32,) = (0,); ()
"""
),
relay
.
Let
(
relay
.
Var
(
"_"
,
relay
.
TupleType
([
int32
])),
relay
.
Tuple
([
relay
.
const
(
0
)]),
UNIT
)
)
assert
alpha_equal
(
relay
.
fromtext
(
"""
let
%
_: (int32, int32) = (0, 1); ()
"""
),
relay
.
Let
(
relay
.
Var
(
"_"
,
relay
.
TupleType
([
int32
,
int32
])),
relay
.
Tuple
([
relay
.
const
(
0
),
relay
.
const
(
1
)]),
UNIT
)
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment