Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
3455c8a5
Commit
3455c8a5
authored
Oct 01, 2018
by
Steven S. Lyubomirsky
Committed by
Tianqi Chen
Oct 01, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[Relay] Incorporate TypeRelations into more tests (#1792)
parent
106991d2
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
110 additions
and
7 deletions
+110
-7
python/tvm/relay/__init__.py
+1
-0
python/tvm/relay/ty.pyi
+23
-0
src/relay/pass/kind_check.cc
+11
-4
tests/python/relay/test_check_kind.py
+58
-3
tests/python/relay/test_ir_nodes.py
+17
-0
No files found.
python/tvm/relay/__init__.py
View file @
3455c8a5
...
...
@@ -21,6 +21,7 @@ Kind = ty.Kind
TypeParam
=
ty
.
TypeParam
TypeConstraint
=
ty
.
TypeConstraint
FuncType
=
ty
.
FuncType
TypeRelation
=
ty
.
TypeRelation
# Expr
Constant
=
expr
.
Constant
...
...
python/tvm/relay/ty.pyi
View file @
3455c8a5
...
...
@@ -158,3 +158,26 @@ class IncompleteType(Type):
def __init__(self, kind):
self.__init_handle_by_constructor__(_make.IncompleteType, kind)
@register_relay_node
class TypeRelation(TypeConstraint):
"""Type relation in relay.
Parameters
----------
func : EnvFunc
User defined relation function.
args : list of types
List of types to the func.
num_inputs: int
Number of input arguments in args,
this act as a hint for type inference.
attrs : Attrs
The attribute attached to the relation information
"""
def __init__(self, func, args, num_inputs, attrs):
self.__init_handle_by_constructor__(_make.TypeRelation,
func, args, num_inputs, attrs)
src/relay/pass/kind_check.cc
View file @
3455c8a5
...
...
@@ -45,8 +45,7 @@ struct KindChecker : TypeVisitor<> {
return
true
;
}
return
t
.
as
<
TensorTypeNode
>
()
||
t
.
as
<
BaseTensorTypeNode
>
()
||
t
.
as
<
TupleTypeNode
>
()
||
t
.
as
<
FuncTypeNode
>
();
return
t
.
as_derived
<
BaseTensorTypeNode
>
()
||
t
.
as
<
TupleTypeNode
>
()
||
t
.
as
<
FuncTypeNode
>
();
}
void
VisitType_
(
const
TupleTypeNode
*
op
)
override
{
...
...
@@ -61,8 +60,9 @@ struct KindChecker : TypeVisitor<> {
}
void
VisitType_
(
const
FuncTypeNode
*
op
)
override
{
// func types should only take normal types for arguments
// and only return a normal type
// Func types should only take normal types for arguments
// and only return a normal type. They should also have
// well-formed constraints
for
(
const
Type
&
t
:
op
->
arg_types
)
{
this
->
VisitType
(
t
);
valid
=
valid
&&
IsTypeKind
(
t
);
...
...
@@ -71,6 +71,13 @@ struct KindChecker : TypeVisitor<> {
}
}
for
(
const
TypeConstraint
&
tc
:
op
->
type_constraints
)
{
this
->
VisitType
(
tc
);
if
(
!
valid
)
{
return
;
}
}
this
->
VisitType
(
op
->
ret_type
);
valid
=
valid
&&
IsTypeKind
(
op
->
ret_type
);
}
...
...
tests/python/relay/test_check_kind.py
View file @
3455c8a5
...
...
@@ -2,7 +2,7 @@ import tvm
from
tvm
import
relay
from
tvm.relay.ir_pass
import
check_kind
def
test_tuple_kind
s
():
def
test_tuple_kind
():
# only contain type kinds
tp
=
relay
.
TypeParam
(
'tp'
,
relay
.
Kind
.
Type
)
tt
=
relay
.
TensorType
(
tvm
.
convert
([
1
,
2
,
3
]),
'float32'
)
...
...
@@ -12,6 +12,7 @@ def test_tuple_kinds():
tup_ty
=
relay
.
TupleType
(
fields
)
assert
check_kind
(
tup_ty
)
def
test_func_kind
():
# only contain type kinds
tp1
=
relay
.
TypeParam
(
'tp1'
,
relay
.
Kind
.
Type
)
...
...
@@ -21,15 +22,29 @@ def test_func_kind():
dtype
=
'float32'
tensor_type
=
relay
.
TensorType
(
shape
,
dtype
)
tr
=
relay
.
TypeRelation
(
None
,
tvm
.
convert
([
tensor_type
,
tp1
])
,
1
,
None
)
type_params
=
tvm
.
convert
([
tp1
,
tp2
])
type_constraints
=
tvm
.
convert
([])
type_constraints
=
tvm
.
convert
([
tr
])
arg_types
=
tvm
.
convert
([
tp1
,
tensor_type
])
ret_type
=
relay
.
TupleType
(
tvm
.
convert
([
tp2
,
tensor_type
]))
tf
=
relay
.
FuncType
(
arg_types
,
ret_type
,
type_params
,
type_constraints
)
assert
check_kind
(
tf
)
def
test_invalid_tuple_kinds
():
def
test_relation_kind
():
# only have type kinds for arguments
tp
=
relay
.
TypeParam
(
'tp'
,
relay
.
Kind
.
Type
)
tt
=
relay
.
TensorType
(
tvm
.
convert
([
1
,
2
,
3
]),
'float32'
)
tf
=
relay
.
FuncType
(
tvm
.
convert
([]),
tt
,
tvm
.
convert
([]),
tvm
.
convert
([]))
args
=
tvm
.
convert
([
tf
,
tt
,
tp
])
tr
=
relay
.
TypeRelation
(
None
,
args
,
2
,
None
)
assert
check_kind
(
tr
)
def
test_invalid_tuple_kind
():
tp1
=
relay
.
TypeParam
(
'tp1'
,
relay
.
Kind
.
Shape
)
tp2
=
relay
.
TypeParam
(
'tp2'
,
relay
.
Kind
.
BaseType
)
tp3
=
relay
.
TypeParam
(
'tp3'
,
relay
.
Kind
.
ShapeVar
)
...
...
@@ -38,6 +53,7 @@ def test_invalid_tuple_kinds():
tup_ty
=
relay
.
TupleType
(
fields
)
assert
not
check_kind
(
tup_ty
)
def
test_invalid_func_kind
():
tp1
=
relay
.
TypeParam
(
'tp1'
,
relay
.
Kind
.
Shape
)
tp2
=
relay
.
TypeParam
(
'tp2'
,
relay
.
Kind
.
BaseType
)
...
...
@@ -51,16 +67,29 @@ def test_invalid_func_kind():
tf
=
relay
.
FuncType
(
arg_types
,
ret_type
,
type_params
,
type_constraints
)
assert
not
check_kind
(
tf
)
def
test_invalid_relation_kind
():
tp1
=
relay
.
TypeParam
(
'tp1'
,
relay
.
Kind
.
Shape
)
tp2
=
relay
.
TypeParam
(
'tp2'
,
relay
.
Kind
.
BaseType
)
tp3
=
relay
.
TypeParam
(
'tp3'
,
relay
.
Kind
.
ShapeVar
)
args
=
tvm
.
convert
([
tp1
,
tp2
,
tp3
])
tr
=
relay
.
TypeRelation
(
None
,
args
,
2
,
None
)
assert
not
check_kind
(
tr
)
def
test_func_with_invalid_ret_type
():
tp1
=
relay
.
TypeParam
(
'tp1'
,
relay
.
Kind
.
Type
)
tp2
=
relay
.
TypeParam
(
'tp2'
,
relay
.
Kind
.
Shape
)
tf
=
relay
.
FuncType
(
tvm
.
convert
([
tp1
]),
tp2
,
tvm
.
convert
([
tp1
,
tp2
]),
tvm
.
convert
([]))
def
test_func_with_invalid_arg_types
():
tp1
=
relay
.
TypeParam
(
'tp1'
,
relay
.
Kind
.
Shape
)
tp2
=
relay
.
TypeParam
(
'tp2'
,
relay
.
Kind
.
Type
)
tf
=
relay
.
FuncType
(
tvm
.
convert
([
tp1
]),
tp2
,
tvm
.
convert
([
tp1
,
tp2
]),
tvm
.
convert
([]))
def
test_func_with_invalid_tuple
():
tp1
=
relay
.
TypeParam
(
'tp1'
,
relay
.
Kind
.
Shape
)
...
...
@@ -69,6 +98,18 @@ def test_func_with_invalid_tuple():
tf
=
relay
.
FuncType
(
tvm
.
convert
([]),
ret_type
,
tvm
.
convert
([
tp1
]),
tvm
.
convert
([]))
assert
not
check_kind
(
tf
)
def
test_func_with_invalid_relation
():
tp1
=
relay
.
TypeParam
(
'tp1'
,
relay
.
Kind
.
Type
)
tp2
=
relay
.
TypeParam
(
'tp2'
,
relay
.
Kind
.
Shape
)
tp3
=
relay
.
TypeParam
(
'tp3'
,
relay
.
Kind
.
ShapeVar
)
tr
=
relay
.
TypeRelation
(
None
,
tvm
.
convert
([
tp2
,
tp3
]),
1
,
None
)
tf
=
relay
.
FuncType
(
tvm
.
convert
([
tp1
]),
tp1
,
tvm
.
convert
([
tp1
,
tp2
,
tp3
]),
tvm
.
convert
([
tr
]))
assert
not
check_kind
(
tf
)
def
test_tuple_with_invalid_func
():
tensor_type
=
relay
.
TensorType
(
tvm
.
convert
([
1
,
2
,
3
]),
'float32'
)
...
...
@@ -77,3 +118,17 @@ def test_tuple_with_invalid_func():
tup_ty
=
relay
.
TupleType
(
tvm
.
convert
([
tensor_type
,
tf
]))
assert
not
check_kind
(
tup_ty
)
if
__name__
==
"__main__"
:
test_tuple_kind
()
test_func_kind
()
test_relation_kind
()
test_invalid_tuple_kind
()
test_invalid_func_kind
()
test_invalid_relation_kind
()
test_func_with_invalid_ret_type
()
test_func_with_invalid_arg_types
()
test_func_with_invalid_tuple
()
test_func_with_invalid_relation
()
test_tuple_with_invalid_func
()
tests/python/relay/test_ir_nodes.py
View file @
3455c8a5
...
...
@@ -58,6 +58,21 @@ def test_tuple_type():
assert
tup_ty
.
fields
==
fields
def
test_type_relation
():
tp
=
relay
.
TypeParam
(
'tp'
,
relay
.
Kind
.
Type
)
tf
=
relay
.
FuncType
(
tvm
.
convert
([]),
None
,
tvm
.
convert
([]),
tvm
.
convert
([]))
tt
=
relay
.
TensorType
(
tvm
.
convert
([
1
,
2
,
3
]),
'float32'
)
args
=
tvm
.
convert
([
tf
,
tt
,
tp
])
num_inputs
=
2
func
=
None
attrs
=
None
tr
=
relay
.
TypeRelation
(
func
,
args
,
num_inputs
,
attrs
)
assert
tr
.
args
==
args
assert
tr
.
num_inputs
==
num_inputs
def
test_constant
():
arr
=
tvm
.
nd
.
array
(
10
)
const
=
relay
.
Constant
(
arr
)
...
...
@@ -158,6 +173,8 @@ if __name__ == "__main__":
test_tensor_type
()
test_type_param
()
test_func_type
()
test_tuple_type
()
test_type_relation
()
test_constant
()
test_tuple
()
test_local_var
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment