Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
3bfa5fc0
Commit
3bfa5fc0
authored
Oct 23, 2018
by
Jared Roesch
Committed by
Tianqi Chen
Oct 23, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[RELAY][TypeSystem] Add support for populating type args (#1962)
parent
3a1bb8c7
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
142 additions
and
30 deletions
+142
-30
include/tvm/relay/op.h
+30
-0
src/relay/ir/text_printer.cc
+17
-4
src/relay/pass/type_infer.cc
+78
-26
tests/python/relay/test_type_infer.py
+17
-0
No files found.
include/tvm/relay/op.h
View file @
3bfa5fc0
...
...
@@ -485,6 +485,36 @@ inline ValueType OpMap<ValueType>::get(const Op& op,
return
map_
.
get
<
ValueType
>
(
op
,
def_value
);
}
/*!
* \brief Check that an expression is a "primtive operator".
*
* Will return true if the expression is an operator which
* matches the form of primtive operators registered directly
* by the Relay codebase.
*
* That is the arguments are all type variables, and there is a single
* type relation applied to the input and output types.
*/
inline
bool
IsPrimitiveOp
(
const
Expr
&
expr
)
{
const
auto
*
op
=
expr
.
as
<
OpNode
>
();
if
(
!
op
)
{
return
false
;
}
const
auto
&
fn_ty
=
op
->
op_type
;
if
(
fn_ty
->
type_constraints
.
size
()
!=
1
)
return
false
;
const
TypeRelationNode
*
rel
=
fn_ty
->
type_constraints
[
0
].
as
<
TypeRelationNode
>
();
if
(
rel
==
nullptr
)
return
false
;
// validate if the type parameter matches up
for
(
size_t
i
=
0
;
i
<
fn_ty
->
type_params
.
size
();
++
i
)
{
if
(
!
fn_ty
->
type_params
[
i
].
same_as
(
rel
->
args
[
i
]))
return
false
;
}
return
true
;
}
}
// namespace relay
}
// namespace tvm
#endif // TVM_RELAY_OP_H_
src/relay/ir/text_printer.cc
View file @
3bfa5fc0
...
...
@@ -278,10 +278,7 @@ class TextPrinter :
}
TextValue
VisitExpr_
(
const
CallNode
*
op
)
final
{
// TODO(tqchen, M.K.): support generic call
// possibly through meta-data
CHECK_EQ
(
op
->
type_args
.
size
(),
0U
)
<<
"generic call not yet supported"
;
TextValue
call_op
=
GetValue
(
op
->
op
);
std
::
vector
<
TextValue
>
args
;
for
(
Expr
arg
:
op
->
args
)
{
...
...
@@ -289,7 +286,23 @@ class TextPrinter :
}
TextValue
id
=
this
->
AllocTempVar
();
this
->
PrintIndent
();
stream_
<<
id
<<
" = "
<<
call_op
<<
"("
;
stream_
<<
id
<<
" = "
<<
call_op
;
auto
type_args
=
op
->
type_args
;
if
(
!
IsPrimitiveOp
(
op
->
op
)
&&
type_args
.
size
()
>
0U
)
{
stream_
<<
"<"
;
for
(
size_t
i
=
0
;
i
<
op
->
type_args
.
size
();
++
i
)
{
this
->
PrintType
(
type_args
[
i
],
stream_
);
if
(
i
+
1
!=
type_args
.
size
())
{
stream_
<<
", "
;
}
}
stream_
<<
">"
;
}
stream_
<<
"("
;
for
(
size_t
i
=
0
;
i
<
args
.
size
();
++
i
)
{
stream_
<<
args
[
i
];
if
(
i
+
1
!=
args
.
size
())
{
...
...
src/relay/pass/type_infer.cc
View file @
3bfa5fc0
...
...
@@ -61,6 +61,17 @@ TVM_REGISTER_API("tvm.relay.type_relation.TupleGetItem")
.
set_body_typed
<
bool
(
const
Array
<
Type
>&
,
int
,
const
Attrs
&
,
const
TypeReporter
&
)
>
(
TupleGetItemRel
);
struct
ResolvedTypeInfo
{
explicit
ResolvedTypeInfo
(
Type
checked_type
,
Array
<
Type
>
type_args
)
:
checked_type
(
checked_type
),
type_args
(
type_args
)
{}
ResolvedTypeInfo
()
{}
Type
checked_type
;
// Only allocated when the expression is a call.
Array
<
Type
>
type_args
=
Array
<
Type
>
(
NodePtr
<
Node
>
(
nullptr
));
};
//
// The inference algorithm can roughly be devided into three stages:
// - Populate the constraints by visiting the expression (TypeInferencer.GetType)
...
...
@@ -87,7 +98,8 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
Environment
env_
;
// map from expression to checked type
// type inferencer will populate it up
std
::
unordered_map
<
Expr
,
Type
,
NodeHash
,
NodeEqual
>
type_map_
;
std
::
unordered_map
<
Expr
,
ResolvedTypeInfo
,
NodeHash
,
NodeEqual
>
type_map_
;
// The solver used by the inferencer.
TypeSolver
solver_
;
// relation function
...
...
@@ -111,11 +123,12 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
// will call visit to deduce it if it is not in the type_map_
Type
GetType
(
const
Expr
&
expr
)
{
auto
it
=
type_map_
.
find
(
expr
);
if
(
it
!=
type_map_
.
end
())
{
return
it
->
second
;
if
(
it
!=
type_map_
.
end
()
&&
it
->
second
.
checked_type
.
defined
()
)
{
return
it
->
second
.
checked_type
;
}
Type
ret
=
this
->
VisitExpr
(
expr
);
type_map_
[
expr
]
=
ret
;
ResolvedTypeInfo
&
rti
=
type_map_
[
expr
];
rti
.
checked_type
=
ret
;
return
ret
;
}
...
...
@@ -176,7 +189,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
}
CHECK
(
!
type_map_
.
count
(
op
->
var
));
// NOTE: no scoping is necessary because var are unique in program
type_map_
[
op
->
var
]
=
vtype
;
type_map_
[
op
->
var
]
.
checked_type
=
vtype
;
return
GetType
(
op
->
body
);
}
...
...
@@ -224,6 +237,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
subst_map
.
Set
(
ty_param
,
fresh
);
ty_args
->
push_back
(
fresh
);
}
Type
ret_type
=
fn_ty
->
ret_type
;
// If the function type is incomplete, place a new IncompleteType
...
...
@@ -234,6 +248,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
if
(
!
ret_type
.
defined
())
{
ret_type
=
IncompleteTypeNode
::
make
(
TypeVarNode
::
Kind
::
kType
);
}
Type
inst_ty
=
FuncTypeNode
::
make
(
fn_ty
->
arg_types
,
ret_type
,
{},
fn_ty
->
type_constraints
);
...
...
@@ -241,49 +256,74 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
return
Downcast
<
FuncType
>
(
inst_ty
);
}
void
AddTypeArgs
(
const
Expr
&
expr
,
Array
<
Type
>
type_args
)
{
auto
type_info
=
type_map_
.
find
(
expr
);
if
(
type_info
==
type_map_
.
end
())
{
type_map_
.
insert
({
expr
,
ResolvedTypeInfo
(
Type
(),
type_args
)});
}
else
{
CHECK
(
!
type_info
->
second
.
type_args
.
defined
());
type_info
->
second
.
type_args
=
type_args
;
}
}
// Handle general call node.
Type
GeneralCall
(
const
CallNode
*
op
,
Array
<
Type
>
arg_types
)
{
Type
ftype
=
GetType
(
op
->
op
);
Type
GeneralCall
(
const
CallNode
*
call
,
Array
<
Type
>
arg_types
)
{
Type
ftype
=
GetType
(
call
->
op
);
auto
*
fn_ty_node
=
ftype
.
as
<
FuncTypeNode
>
();
CHECK
(
fn_ty_node
!=
nullptr
)
<<
"only expressions with function types can be called, at "
<<
op
->
span
;
<<
call
->
span
;
Array
<
Type
>
type_args
;
FuncType
fn_ty
=
Instantiate
(
fn_ty_node
,
&
type_args
);
AddTypeArgs
(
GetRef
<
Call
>
(
call
),
type_args
);
size_t
type_arity
=
fn_ty
->
arg_types
.
size
();
size_t
number_of_args
=
arg_types
.
size
();
if
(
type_arity
!=
number_of_args
)
{
if
(
type_arity
<
number_of_args
)
{
LOG
(
FATAL
)
<<
"the function is provided too many arguments "
<<
op
->
span
;
LOG
(
FATAL
)
<<
"the function is provided too many arguments "
<<
call
->
span
;
}
else
{
LOG
(
FATAL
)
<<
"the function is provided too few arguments"
<<
op
->
span
;
LOG
(
FATAL
)
<<
"the function is provided too few arguments"
<<
call
->
span
;
}
}
for
(
size_t
i
=
0
;
i
<
fn_ty
->
arg_types
.
size
();
i
++
)
{
this
->
Unify
(
fn_ty
->
arg_types
[
i
],
arg_types
[
i
],
op
->
args
[
i
]
->
span
);
this
->
Unify
(
fn_ty
->
arg_types
[
i
],
arg_types
[
i
],
call
->
args
[
i
]
->
span
);
}
for
(
auto
cs
:
fn_ty
->
type_constraints
)
{
solver_
.
AddConstraint
(
cs
);
if
(
auto
tr
=
cs
.
as
<
TypeRelationNode
>
())
{
solver_
.
AddConstraint
(
TypeRelationNode
::
make
(
tr
->
func
,
tr
->
args
,
tr
->
num_inputs
,
call
->
attrs
));
}
else
{
solver_
.
AddConstraint
(
cs
);
}
}
return
fn_ty
->
ret_type
;
}
Type
VisitExpr_
(
const
CallNode
*
op
)
final
{
// Fast path: well-formed primitive op
Type
VisitExpr_
(
const
CallNode
*
call
)
final
{
Array
<
Type
>
arg_types
;
for
(
Expr
arg
:
op
->
args
)
{
for
(
Expr
arg
:
call
->
args
)
{
arg_types
.
push_back
(
GetType
(
arg
));
}
if
(
const
OpNode
*
opnode
=
op
->
op
.
as
<
OpNode
>
())
{
if
(
const
OpNode
*
opnode
=
call
->
op
.
as
<
OpNode
>
())
{
Type
rtype
=
PrimitiveCall
(
opnode
->
op_type
.
as
<
FuncTypeNode
>
(),
arg_types
,
op
->
attrs
);
if
(
rtype
.
defined
())
return
rtype
;
call
->
attrs
);
if
(
rtype
.
defined
())
{
AddTypeArgs
(
GetRef
<
Call
>
(
call
),
arg_types
);
return
rtype
;
}
}
return
GeneralCall
(
op
,
arg_types
);
return
GeneralCall
(
call
,
arg_types
);
}
Type
VisitExpr_
(
const
FunctionNode
*
f
)
final
{
...
...
@@ -312,7 +352,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
class
TypeInferencer
::
Resolver
:
public
ExprMutator
{
public
:
Resolver
(
const
std
::
unordered_map
<
Expr
,
Type
,
NodeHash
,
NodeEqual
>&
tmap
,
Resolver
(
const
std
::
unordered_map
<
Expr
,
ResolvedTypeInfo
,
NodeHash
,
NodeEqual
>&
tmap
,
TypeSolver
*
solver
)
:
tmap_
(
tmap
),
solver_
(
solver
)
{
}
...
...
@@ -362,7 +402,7 @@ class TypeInferencer::Resolver : public ExprMutator {
Expr
AttachCheckedType
(
const
T
*
op
)
{
auto
it
=
tmap_
.
find
(
GetRef
<
Expr
>
(
op
));
CHECK
(
it
!=
tmap_
.
end
());
Type
checked_type
=
solver_
->
Resolve
(
it
->
second
);
Type
checked_type
=
solver_
->
Resolve
(
it
->
second
.
checked_type
);
CHECK
(
checked_type
.
as
<
IncompleteTypeNode
>
()
==
nullptr
)
<<
"Cannot resolve type of "
<<
GetRef
<
Expr
>
(
op
)
<<
" at "
<<
op
->
span
;
...
...
@@ -376,25 +416,37 @@ class TypeInferencer::Resolver : public ExprMutator {
}
new_e
->
checked_type_
=
checked_type
;
}
if
(
it
->
second
.
type_args
.
defined
())
{
Call
call
=
Downcast
<
Call
>
(
new_e
);
const
CallNode
*
const_call_ref
=
call
.
operator
->
();
CallNode
*
call_ref
=
const_cast
<
CallNode
*>
(
const_call_ref
);
call_ref
->
type_args
=
it
->
second
.
type_args
;
for
(
size_t
i
=
0
;
i
<
call
->
type_args
.
size
();
i
++
)
{
call_ref
->
type_args
.
Set
(
i
,
solver_
->
Resolve
(
call
->
type_args
[
i
]));
}
}
return
new_e
;
}
Type
VisitType
(
const
Type
&
t
)
final
{
Type
VisitType
(
const
Type
&
t
)
final
{
return
solver_
->
Resolve
(
t
);
}
private
:
const
std
::
unordered_map
<
Expr
,
Type
,
NodeHash
,
NodeEqual
>&
tmap_
;
const
std
::
unordered_map
<
Expr
,
ResolvedTypeInfo
,
NodeHash
,
NodeEqual
>&
tmap_
;
TypeSolver
*
solver_
;
};
Expr
TypeInferencer
::
Infer
(
Expr
expr
)
{
//
step 0: populate the constraints
//
Step 0: Populate the constraints.
GetType
(
expr
);
//
step 1: solve the constraints
//
Step 1: Solve the constraints.
solver_
.
Solve
();
//
step 2: attach resolved types to checked_type field
//
Step 2: Attach resolved types to checked_type field.
return
Resolver
(
type_map_
,
&
solver_
).
VisitExpr
(
expr
);
}
...
...
tests/python/relay/test_type_infer.py
View file @
3bfa5fc0
...
...
@@ -91,6 +91,21 @@ def test_free_expr():
yy
=
relay
.
ir_pass
.
infer_type
(
y
)
assert
yy
.
checked_type
==
relay
.
scalar_type
(
"float32"
)
def
test_type_args
():
x
=
relay
.
var
(
"x"
,
shape
=
(
10
,
10
))
y
=
relay
.
var
(
"y"
,
shape
=
(
1
,
10
))
z
=
relay
.
add
(
x
,
y
)
ty_z
=
relay
.
ir_pass
.
infer_type
(
z
)
ty_args
=
ty_z
.
type_args
assert
len
(
ty_args
)
==
2
assert
ty_args
[
0
]
.
dtype
==
"float32"
assert
ty_args
[
1
]
.
dtype
==
"float32"
sh1
=
ty_args
[
0
]
.
shape
sh2
=
ty_args
[
1
]
.
shape
assert
sh1
[
0
]
.
value
==
10
assert
sh1
[
1
]
.
value
==
10
assert
sh2
[
0
]
.
value
==
1
assert
sh2
[
1
]
.
value
==
10
if
__name__
==
"__main__"
:
test_free_expr
()
...
...
@@ -100,3 +115,5 @@ if __name__ == "__main__":
test_decl
()
test_recursion
()
test_tuple
()
test_free_expr
()
test_type_args
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment