Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
3a1bb8c7
Commit
3a1bb8c7
authored
Oct 23, 2018
by
Steven S. Lyubomirsky
Committed by
Tianqi Chen
Oct 23, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[Relay] Serialization round-trip tests (#1968)
parent
975d0d44
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
61 additions
and
7 deletions
+61
-7
python/tvm/relay/ir_pass.py
+22
-0
src/relay/ir/alpha_equal.cc
+1
-1
tests/python/relay/test_ir_nodes.py
+38
-6
No files found.
python/tvm/relay/ir_pass.py
View file @
3a1bb8c7
...
...
@@ -141,3 +141,25 @@ def alpha_equal(lhs, rhs):
True iff lhs is alpha equal to rhs.
"""
return
bool
(
_make
.
_alpha_equal
(
lhs
,
rhs
))
def
graph_equal
(
lhs
,
rhs
):
"""Compare two Relay expr for data-flow equivalence.
The difference between this and alpha-equality is that
variables are not expected to match between lhs and rhs;
they are treated as sources and are mapped between each other.
Parameters
----------
lhs: tvm.relay.Expr
One of the input Expression.
rhs: tvm.relay.Expr
One of the input Expression.
Returns
-------
result: bool
True iff lhs is data-flow equivalent to rhs.
"""
return
bool
(
_make
.
_graph_equal
(
lhs
,
rhs
))
src/relay/ir/alpha_equal.cc
View file @
3a1bb8c7
...
...
@@ -183,7 +183,7 @@ class AlphaEqualHandler:
bool
VisitType_
(
const
TypeRelationNode
*
lhs
,
const
Type
&
other
)
final
{
if
(
const
TypeRelationNode
*
rhs
=
other
.
as
<
TypeRelationNode
>
())
{
if
(
!
lhs
->
func
.
same_as
(
rhs
->
func
)
)
return
false
;
if
(
lhs
->
func
->
name
!=
rhs
->
func
->
name
)
return
false
;
if
(
lhs
->
num_inputs
!=
rhs
->
num_inputs
)
return
false
;
if
(
!
this
->
AttrEqual
(
lhs
->
attrs
,
rhs
->
attrs
))
return
false
;
if
(
lhs
->
args
.
size
()
!=
rhs
->
args
.
size
())
return
false
;
...
...
tests/python/relay/test_ir_nodes.py
View file @
3a1bb8c7
...
...
@@ -2,6 +2,14 @@
import
tvm
from
tvm
import
relay
from
tvm.expr
import
*
from
tvm.relay.ir_pass
import
graph_equal
def
check_json_roundtrip
(
node
):
json_str
=
tvm
.
save_json
(
node
)
back
=
tvm
.
load_json
(
json_str
)
assert
graph_equal
(
back
,
node
)
def
test_bad_constructor
():
try
:
...
...
@@ -21,6 +29,13 @@ def test_span():
assert
isinstance
(
span
,
relay
.
base
.
Span
)
str
(
span
)
# span is not a node so we can't use graph_equal
# to test the round trip
back
=
tvm
.
load_json
(
tvm
.
save_json
(
span
))
assert
back
.
source
==
span
.
source
assert
back
.
lineno
==
span
.
lineno
assert
back
.
col_offset
==
span
.
col_offset
# Types
def
test_tensor_type
():
...
...
@@ -31,6 +46,7 @@ def test_tensor_type():
assert
tt
.
shape
==
shape
assert
tt
.
span
==
None
str
(
tt
)
check_json_roundtrip
(
tt
)
def
test_type_param
():
...
...
@@ -38,21 +54,23 @@ def test_type_param():
assert
tp
.
kind
==
relay
.
Kind
.
Type
# assert tp.span # TODO allow us to set span
str
(
tp
)
check_json_roundtrip
(
tp
)
def
test_func_type
():
type_params
=
tvm
.
convert
([])
type_constraints
=
tvm
.
convert
([])
# TODO: fill me in
arg_types
=
tvm
.
convert
([])
ret_type
=
None
ret_type
=
relay
.
TensorType
((
1
,
2
,
3
),
'float32'
)
tf
=
relay
.
FuncType
(
arg_types
,
ret_type
,
type_params
,
type_constraints
)
assert
tf
.
type_params
==
type_params
assert
tf
.
type_constraints
==
type_constraints
assert
tf
.
arg_types
==
arg_types
assert
tf
.
ret_type
==
ret_type
assert
tf
.
span
==
None
# TODO make sure we can set
# TODO make sure we can set
span
str
(
tf
)
check_json_roundtrip
(
tf
)
def
test_tuple_type
():
...
...
@@ -63,13 +81,15 @@ def test_tuple_type():
tup_ty
=
relay
.
TupleType
(
fields
)
assert
tup_ty
.
fields
==
fields
str
(
tup_ty
)
check_json_roundtrip
(
tup_ty
)
def
test_type_relation
():
tp
=
relay
.
TypeVar
(
'tp'
,
relay
.
Kind
.
Type
)
tf
=
relay
.
FuncType
(
tvm
.
convert
([]),
None
,
tvm
.
convert
([]),
tvm
.
convert
([]))
tt
=
relay
.
TensorType
(
tvm
.
convert
([
1
,
2
,
3
]),
'float32'
)
args
=
tvm
.
convert
([
t
f
,
tt
,
tp
])
args
=
tvm
.
convert
([
t
p
,
tf
,
tt
])
num_inputs
=
2
func
=
tvm
.
get_env_func
(
"tvm.relay.type_relation.Broadcast"
)
...
...
@@ -78,6 +98,8 @@ def test_type_relation():
tr
=
relay
.
TypeRelation
(
func
,
args
,
num_inputs
,
attrs
)
assert
tr
.
args
==
args
assert
tr
.
num_inputs
==
num_inputs
str
(
tr
)
check_json_roundtrip
(
tr
)
def
test_constant
():
...
...
@@ -86,6 +108,7 @@ def test_constant():
assert
const
.
data
==
arr
assert
const
.
span
==
None
str
(
const
)
check_json_roundtrip
(
const
)
def
test_tuple
():
...
...
@@ -94,6 +117,7 @@ def test_tuple():
assert
tup
.
fields
==
fields
assert
tup
.
span
==
None
str
(
tup
)
check_json_roundtrip
(
tup
)
def
test_local_var
():
...
...
@@ -103,6 +127,7 @@ def test_local_var():
assert
lv
.
type_annotation
is
None
# assert lv.span == None todo(@jroesch): what do we do about spans
str
(
lv
)
check_json_roundtrip
(
lv
)
t1
=
relay
.
ty
.
TensorType
((),
"float"
)
lv
=
relay
.
Var
(
name_hint
,
t1
)
...
...
@@ -116,20 +141,22 @@ def test_global_var():
gv
.
name_hint
==
name_hint
# assert lv.span == None todo(@jroesch): what do we do about spans
str
(
gv
)
check_json_roundtrip
(
gv
)
def
test_function
():
param_names
=
[
'a'
,
'b'
,
'c'
,
'd'
]
params
=
tvm
.
convert
([
relay
.
Var
(
n
)
for
n
in
param_names
])
ret_type
=
None
body
=
None
ret_type
=
relay
.
TupleType
(
tvm
.
convert
([]))
body
=
relay
.
Tuple
(
tvm
.
convert
([]))
type_params
=
tvm
.
convert
([])
fn
=
relay
.
Function
(
params
,
ret_type
,
body
,
type_params
)
fn
=
relay
.
Function
(
params
,
body
,
ret_type
,
type_params
)
assert
fn
.
params
==
params
assert
fn
.
body
==
body
assert
fn
.
type_params
==
type_params
assert
fn
.
span
==
None
str
(
fn
)
check_json_roundtrip
(
fn
)
def
test_call
():
...
...
@@ -141,6 +168,7 @@ def test_call():
assert
call
.
args
==
args
assert
call
.
span
==
None
str
(
call
)
check_json_roundtrip
(
call
)
def
test_let
():
...
...
@@ -156,6 +184,7 @@ def test_let():
assert
let
.
body
==
lv
assert
let
.
span
==
None
str
(
let
)
check_json_roundtrip
(
let
)
def
test_if
():
...
...
@@ -168,6 +197,7 @@ def test_if():
assert
ife
.
false_branch
==
right
assert
ife
.
span
==
None
str
(
ife
)
check_json_roundtrip
(
ife
)
def
test_tuple_get_item
():
...
...
@@ -176,6 +206,8 @@ def test_tuple_get_item():
assert
get
.
tuple_value
==
tup
assert
get
.
index
==
1
str
(
get
)
check_json_roundtrip
(
get
)
if
__name__
==
"__main__"
:
test_bad_constructor
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment