Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
3d62cf7c
Commit
3d62cf7c
authored
Oct 05, 2018
by
Steven S. Lyubomirsky
Committed by
Tianqi Chen
Oct 05, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[Relay] More type alpha equality test coverage (#1823)
parent
5bf1cbda
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
204 additions
and
11 deletions
+204
-11
python/tvm/relay/__init__.py
+1
-0
src/relay/pass/alpha_eq.cc
+41
-2
tests/python/relay/test_ir_nodes.py
+2
-2
tests/python/relay/test_pass_alpha_eq.py
+160
-7
No files found.
python/tvm/relay/__init__.py
View file @
3d62cf7c
...
...
@@ -25,6 +25,7 @@ TypeParam = ty.TypeParam
# Re-export the type nodes from relay.ty at the package level so users
# can write e.g. relay.FuncType instead of relay.ty.FuncType.
TypeConstraint = ty.TypeConstraint
FuncType = ty.FuncType
TypeRelation = ty.TypeRelation
IncompleteType = ty.IncompleteType

# Expr
Constant = expr.Constant
...
...
src/relay/pass/alpha_eq.cc
View file @
3d62cf7c
...
...
@@ -88,11 +88,23 @@ struct TypeAlphaEq : TypeVisitor<const Type&> {
void
VisitType_
(
const
FuncTypeNode
*
op
,
const
Type
&
t2
)
final
{
if
(
const
FuncTypeNode
*
ta2
=
t2
.
as
<
FuncTypeNode
>
())
{
if
(
op
->
arg_types
.
size
()
!=
ta2
->
arg_types
.
size
())
{
if
(
op
->
arg_types
.
size
()
!=
ta2
->
arg_types
.
size
()
||
op
->
type_params
.
size
()
!=
ta2
->
type_params
.
size
()
||
op
->
type_constraints
.
size
()
!=
ta2
->
type_constraints
.
size
())
{
equal
=
false
;
return
;
}
// must visit params first so they are appropriate entered
// into equality map
for
(
size_t
i
=
0
;
i
<
op
->
type_params
.
size
();
i
++
)
{
eq_map
.
Set
(
op
->
type_params
[
i
],
ta2
->
type_params
[
i
]);
this
->
VisitType
(
op
->
type_params
[
i
],
ta2
->
type_params
[
i
]);
if
(
!
equal
)
{
return
;
}
}
for
(
size_t
i
=
0
;
i
<
op
->
arg_types
.
size
();
i
++
)
{
this
->
VisitType
(
op
->
arg_types
[
i
],
ta2
->
arg_types
[
i
]);
if
(
!
equal
)
{
...
...
@@ -101,6 +113,16 @@ struct TypeAlphaEq : TypeVisitor<const Type&> {
}
this
->
VisitType
(
op
->
ret_type
,
ta2
->
ret_type
);
if
(
!
equal
)
{
return
;
}
for
(
size_t
i
=
0
;
i
<
op
->
type_constraints
.
size
();
i
++
)
{
this
->
VisitType
(
op
->
type_constraints
[
i
],
ta2
->
type_constraints
[
i
]);
if
(
!
equal
)
{
return
;
}
}
}
else
{
equal
=
false
;
}
...
...
@@ -108,7 +130,24 @@ struct TypeAlphaEq : TypeVisitor<const Type&> {
void
VisitType_
(
const
TypeRelationNode
*
tr1
,
const
Type
&
t2
)
final
{
if
(
const
TypeRelationNode
*
tr2
=
t2
.
as
<
TypeRelationNode
>
())
{
equal
=
tr1
==
tr2
;
if
(
tr1
->
func
!=
tr2
->
func
||
tr1
->
num_inputs
!=
tr2
->
num_inputs
||
tr1
->
attrs
!=
tr2
->
attrs
)
{
equal
=
false
;
return
;
}
if
(
tr1
->
args
.
size
()
!=
tr2
->
args
.
size
())
{
equal
=
false
;
return
;
}
for
(
size_t
i
=
0
;
i
<
tr1
->
args
.
size
();
i
++
)
{
this
->
VisitType
(
tr1
->
args
[
i
],
tr2
->
args
[
i
]);
if
(
!
equal
)
{
return
;
}
}
}
else
{
equal
=
false
;
}
...
...
tests/python/relay/test_ir_nodes.py
View file @
3d62cf7c
...
...
@@ -65,8 +65,8 @@ def test_type_relation():
args
=
tvm
.
convert
([
tf
,
tt
,
tp
])
num_inputs
=
2
func
=
None
attrs
=
None
func
=
tvm
.
get_env_func
(
"tvm.relay.type_relation.Broadcast"
)
attrs
=
tvm
.
make
.
node
(
"attrs.TestAttrs"
,
name
=
"attr"
,
padding
=
(
3
,
4
))
tr
=
relay
.
TypeRelation
(
func
,
args
,
num_inputs
,
attrs
)
assert
tr
.
args
==
args
...
...
tests/python/relay/test_pass_alpha_eq.py
View file @
3d62cf7c
import
tvm
from
tvm
import
relay
def test_tensor_type_alpha_eq():
    """Tensor types are alpha-equal iff shape and dtype match."""
    t1 = relay.TensorType((3, 4), "float32")
    t2 = relay.TensorType((3, 4), "float32")
    t3 = relay.TensorType((3, 4, 5), "float32")
    assert t1 == t2
    assert t1 != t3

    # scalar (rank-0) tensors compare equal as well
    t1 = relay.TensorType((), "float32")
    t2 = relay.TensorType((), "float32")
    assert t1 == t2
def test_incomplete_type_alpha_eq():
    """Incomplete (hole) types compare by identity, never structurally."""
    t1 = relay.IncompleteType(relay.Kind.Shape)
    t2 = relay.IncompleteType(relay.Kind.Type)
    t3 = relay.IncompleteType(relay.Kind.Type)

    # only equal when there is pointer equality
    assert t2 == t2
    assert t1 == t1
    assert t1 != t2
    assert t2 != t3
def test_type_param_alpha_eq():
    """Type params are equal only via identity or the alpha-eq map."""
    t1 = relay.TypeParam("v1", relay.Kind.Type)
    t2 = relay.TypeParam("v2", relay.Kind.Shape)
    t3 = relay.TypeParam("v3", relay.Kind.Type)

    # only pointer equality and eq_map allow equal params
    assert t1 == t1
    assert t2 == t2
    assert t1 != t2  # different kind
    assert t1 != t3  # not in eq_map

    # function types are the only way to put type params
    # in the eq map
    ft1 = relay.FuncType(tvm.convert([]), t1, tvm.convert([t1]),
                         tvm.convert([]))
    ft2 = relay.FuncType(tvm.convert([]), t3, tvm.convert([t3]),
                         tvm.convert([]))
    # actually an invalid type because t2 is wrong kind
    ft3 = relay.FuncType(tvm.convert([]), t2, tvm.convert([t2]),
                         tvm.convert([]))

    assert ft1 == ft2
    assert ft1 != ft3  # kinds still do not match
def test_func_type_alpha_eq():
    """FuncTypes are alpha-equal up to consistent renaming of type params;
    arg order, relation identity, and component counts must all match."""
    t1 = relay.TensorType((1, 2), "float32")
    t2 = relay.TensorType((1, 2, 3), "float32")

    tp1 = relay.TypeParam("v1", relay.Kind.Type)
    tp2 = relay.TypeParam("v2", relay.Kind.Type)
    tp3 = relay.TypeParam("v3", relay.Kind.Shape)
    tp4 = relay.TypeParam("v3", relay.Kind.Shape)

    broadcast = tvm.get_env_func("tvm.relay.type_relation.Broadcast")
    identity = tvm.get_env_func("tvm.relay.type_relation.Identity")

    tr1 = relay.TypeRelation(broadcast, tvm.convert([tp1, tp3]), 1, None)
    tr2 = relay.TypeRelation(broadcast, tvm.convert([tp2, tp4]), 1, None)
    tr3 = relay.TypeRelation(identity, tvm.convert([tp1, tp3]), 1, None)

    ft = relay.FuncType(tvm.convert([t1, t2]), tp1,
                        tvm.convert([tp1, tp3]),
                        tvm.convert([tr1]))

    # renaming the bound type params consistently preserves equality
    translate_vars = relay.FuncType(tvm.convert([t1, t2]), tp1,
                                    tvm.convert([tp2, tp4]),
                                    tvm.convert([tr2]))
    assert ft == translate_vars

    different_args = relay.FuncType(tvm.convert([t1]), tp1,
                                    tvm.convert([tp1, tp3]),
                                    tvm.convert([tr1]))
    assert ft != different_args

    different_order = relay.FuncType(tvm.convert([t2, t1]), tp1,
                                     tvm.convert([tp1, tp3]),
                                     tvm.convert([tr1]))
    assert ft != different_order

    no_rel = relay.FuncType(tvm.convert([t1, t2]), tp1,
                            tvm.convert([tp1, tp3]),
                            tvm.convert([]))
    assert ft != no_rel

    more_vars = relay.FuncType(tvm.convert([t1, t2]), tp2,
                               tvm.convert([tp1, tp2, tp3]),
                               tvm.convert([tr1]))
    assert ft != more_vars

    all_the_vars = relay.FuncType(tvm.convert([t1, t2]), tp1,
                                  tvm.convert([tp1, tp2, tp3, tp4]),
                                  tvm.convert([tr1, tr2]))
    assert ft != all_the_vars

    different_rel = relay.FuncType(tvm.convert([t1, t2]), tp1,
                                   tvm.convert([tp1, tp3]),
                                   tvm.convert([tr3]))
    assert ft != different_rel

    more_rels = relay.FuncType(tvm.convert([t1, t2]), tp1,
                               tvm.convert([tp1, tp3]),
                               tvm.convert([tr1, tr3]))
    assert ft != more_rels
def test_tuple_type_alpha_eq():
    """Tuple types compare element-wise, respecting order."""
    t1 = relay.TensorType((1, 2, 3), "float32")
    t2 = relay.TensorType((1, 2, 3, 4), "float32")
    tp1 = relay.TypeParam("v1", relay.Kind.Type)
    tp2 = relay.TypeParam("v2", relay.Kind.Type)

    tup1 = relay.TupleType(tvm.convert([t1, t2, tp1]))
    tup2 = relay.TupleType(tvm.convert([t1, t2, tp1]))
    tup3 = relay.TupleType(tvm.convert([t2, t1, tp1]))
    tup4 = relay.TupleType(tvm.convert([t1, t2, tp2]))

    # as long as types are alpha-equal and in same order,
    # tuples should be alpha-equal
    assert tup1 == tup2
    assert tup1 != tup3
    assert tup1 != tup4
def test_type_relation_alpha_eq():
    """TypeRelations require identical func/attrs pointers and matching,
    ordered argument lists and input counts."""
    t1 = relay.TensorType((1, 2), "float32")
    t2 = relay.TensorType((1, 2, 3), "float32")
    t3 = relay.TensorType((1, 2, 3, 4), "float32")

    # functions are compared only by pointer equality so
    # we need to be sure to use the same pointers
    broadcast = tvm.get_env_func("tvm.relay.type_relation.Broadcast")
    identity = tvm.get_env_func("tvm.relay.type_relation.Identity")

    # attrs are also compared only by pointer equality
    attr1 = tvm.make.node("attrs.TestAttrs", name="attr", padding=(3, 4))
    attr2 = tvm.make.node("attrs.TestAttrs", name="attr", padding=(3, 4))

    tr = relay.TypeRelation(broadcast, tvm.convert([t1, t2]), 1, attr1)
    same = relay.TypeRelation(broadcast, tvm.convert([t1, t2]), 1, attr1)
    diff_func = relay.TypeRelation(identity, tvm.convert([t1, t2]), 1, attr1)
    diff_order = relay.TypeRelation(broadcast, tvm.convert([t2, t1]), 1, attr1)
    diff_args = relay.TypeRelation(broadcast, tvm.convert([t2, t3]), 1, attr1)
    diff_attr = relay.TypeRelation(broadcast, tvm.convert([t1, t2]), 1, attr2)
    bigger = relay.TypeRelation(identity, tvm.convert([t1, t3, t2]), 2, attr1)
    diff_num_inputs = relay.TypeRelation(identity, tvm.convert([t1, t3, t2]),
                                         1, attr2)

    # func, number of args, input count, and order should be the same
    assert tr == same
    assert tr != diff_func
    assert tr != diff_order
    assert tr != diff_args
    assert tr != diff_attr
    assert tr != bigger
    assert bigger != diff_num_inputs
if __name__ == "__main__":
    test_tensor_type_alpha_eq()
    test_incomplete_type_alpha_eq()
    test_type_param_alpha_eq()
    test_func_type_alpha_eq()
    test_tuple_type_alpha_eq()
    test_type_relation_alpha_eq()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment