Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
6783d373
Commit
6783d373
authored
Jan 16, 2019
by
Steven S. Lyubomirsky
Committed by
Tianqi Chen
Jan 16, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[Relay] Unifier hotfix (#2437)
parent
76188a43
Hide whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
1072 additions
and
254 deletions
+1072
-254
include/tvm/relay/pass.h
+68
-0
python/tvm/relay/ir_pass.py
+66
-2
src/relay/pass/gradient.cc
+15
-4
src/relay/pass/type_infer.cc
+77
-60
src/relay/pass/type_solver.cc
+308
-35
src/relay/pass/type_solver.h
+11
-20
src/relay/pass/util.cc
+175
-39
tests/cpp/relay_pass_type_infer_test.cc
+10
-6
tests/python/relay/test_pass_free_vars.py
+0
-41
tests/python/relay/test_pass_vars.py
+144
-0
tests/python/relay/test_type_infer.py
+34
-47
tests/python/relay/test_type_solver.py
+164
-0
No files found.
include/tvm/relay/pass.h
View file @
6783d373
...
@@ -108,6 +108,17 @@ bool AlphaEqual(const Type& t1, const Type& t2);
...
@@ -108,6 +108,17 @@ bool AlphaEqual(const Type& t1, const Type& t2);
*/
*/
bool
WellFormed
(
const
Expr
&
expr
);
bool
WellFormed
(
const
Expr
&
expr
);
/*! \brief Get all bound variables from expression expr.
*
* Bound variables are all variables that are declared in the expr.
* They only have meaning inside that expr, and can only be used in it.
*
* \param expr the expression.
*
* \return List of bound vars, in the PostDFS order in the expression.
*/
tvm
::
Array
<
Var
>
BoundVars
(
const
Expr
&
expr
);
/*! \brief Get free type parameters from expression expr.
/*! \brief Get free type parameters from expression expr.
*
*
* Free variables are variables that are not bound by a
* Free variables are variables that are not bound by a
...
@@ -119,6 +130,14 @@ bool WellFormed(const Expr& expr);
...
@@ -119,6 +130,14 @@ bool WellFormed(const Expr& expr);
*/
*/
tvm
::
Array
<
Var
>
FreeVars
(
const
Expr
&
expr
);
tvm
::
Array
<
Var
>
FreeVars
(
const
Expr
&
expr
);
/*! \brief Get all variables from expression expr.
*
* \param expr the expression.
*
* \return List of all vars, in the PostDFS order in the expression.
*/
tvm
::
Array
<
Var
>
AllVars
(
const
Expr
&
expr
);
/*! \brief Get free TypeVars from expression expr.
/*! \brief Get free TypeVars from expression expr.
*
*
* Free type parameters are type parameters that are not bound by a function
* Free type parameters are type parameters that are not bound by a function
...
@@ -130,6 +149,55 @@ tvm::Array<Var> FreeVars(const Expr& expr);
...
@@ -130,6 +149,55 @@ tvm::Array<Var> FreeVars(const Expr& expr);
*/
*/
tvm
::
Array
<
TypeVar
>
FreeTypeVars
(
const
Expr
&
expr
);
tvm
::
Array
<
TypeVar
>
FreeTypeVars
(
const
Expr
&
expr
);
/*! \brief Get free TypeVars from type t.
*
* Free type parameters are type parameters that are not bound by a function
* type in the context.
*
* \param t the type.
*
* \return List of free type vars, in the PostDFS order visited by type.
*/
tvm
::
Array
<
TypeVar
>
FreeTypeVars
(
const
Type
&
t
);
/*! \brief Get all bound type variables from expression expr.
*
* Bound variables are all type variables that are declared in the expr.
* They only have meaning inside that expr, and can only be used in it.
*
* \param expr the expression.
*
* \return List of bound type vars, in the PostDFS order in the expression.
*/
tvm
::
Array
<
TypeVar
>
BoundTypeVars
(
const
Expr
&
expr
);
/*! \brief Get all bound type variables from type t.
*
* Bound variables are all type variables that are declared in the type.
* They only have meaning inside that type, and can only be used in it.
*
* \param t the type
*
* \return List of bound type vars, in the PostDFS order visited by type.
*/
tvm
::
Array
<
TypeVar
>
BoundTypeVars
(
const
Type
&
t
);
/*! \brief Get all type variables in expression expr.
*
* \param expr the expression.
*
* \return List of type vars, in the PostDFS order in the expression.
*/
tvm
::
Array
<
TypeVar
>
AllTypeVars
(
const
Expr
&
expr
);
/*! \brief Get all type variables in type t.
*
* \param t the type.
*
* \return List of type vars, in the PostDFS order visited by type.
*/
tvm
::
Array
<
TypeVar
>
AllTypeVars
(
const
Type
&
t
);
/*! \brief Remove expressions which does not effect the program result.
/*! \brief Remove expressions which does not effect the program result.
*
*
* It will remove let bindings which are not referenced, and branches that will
* It will remove let bindings which are not referenced, and branches that will
...
...
python/tvm/relay/ir_pass.py
View file @
6783d373
...
@@ -158,6 +158,38 @@ def free_vars(expr):
...
@@ -158,6 +158,38 @@ def free_vars(expr):
return
_ir_pass
.
free_vars
(
expr
)
return
_ir_pass
.
free_vars
(
expr
)
def
bound_vars
(
expr
):
"""Get bound vars from expression expr in post-DFS order.
Parameters
----------
expr: tvm.relay.Expr
The input expression
Returns
-------
free : List[tvm.relay.Var]
The list of bound variables in post-DFS order.
"""
return
_ir_pass
.
bound_vars
(
expr
)
def
all_vars
(
expr
):
"""Get all vars from expression expr in post-DFS order.
Parameters
----------
expr: tvm.relay.Expr
The input expression
Returns
-------
free : List[tvm.relay.Var]
The list of all variables in post-DFS order.
"""
return
_ir_pass
.
all_vars
(
expr
)
def
free_type_vars
(
expr
):
def
free_type_vars
(
expr
):
"""Get free type variables from expression/type e
"""Get free type variables from expression/type e
...
@@ -168,12 +200,44 @@ def free_type_vars(expr):
...
@@ -168,12 +200,44 @@ def free_type_vars(expr):
Returns
Returns
-------
-------
free : List[tvm.relay.Type
Param
]
free : List[tvm.relay.Type
Var
]
The list of free type variables
The list of free type variables
in post-DFS order
"""
"""
return
_ir_pass
.
free_type_vars
(
expr
)
return
_ir_pass
.
free_type_vars
(
expr
)
def
bound_type_vars
(
expr
):
"""Get bound type variables from expression/type e
Parameters
----------
expr: Union[tvm.relay.Expr,tvm.relay.Type]
The input expression/type
Returns
-------
free : List[tvm.relay.TypeVar]
The list of bound type variables in post-DFS order
"""
return
_ir_pass
.
bound_type_vars
(
expr
)
def
all_type_vars
(
expr
):
"""Get all type variables from expression/type e
Parameters
----------
expr: Union[tvm.relay.Expr,tvm.relay.Type]
The input expression/type
Returns
-------
free : List[tvm.relay.TypeVar]
The list of all type variables in post-DFS order
"""
return
_ir_pass
.
all_type_vars
(
expr
)
def
simplify_inference
(
expr
):
def
simplify_inference
(
expr
):
""" Simplify the data-flow graph for inference phase.
""" Simplify the data-flow graph for inference phase.
...
...
src/relay/pass/gradient.cc
View file @
6783d373
...
@@ -205,14 +205,25 @@ Expr FirstOrderGradient(const Expr& re, const Module& mod) {
...
@@ -205,14 +205,25 @@ Expr FirstOrderGradient(const Expr& re, const Module& mod) {
});
});
return
Pair
(
res
.
foward
,
grad
);
return
Pair
(
res
.
foward
,
grad
);
});
});
// if type annotations are provided, we will construct a ret type;
// otherwise, leave it to be inferred
Type
ret_type
=
Type
();
std
::
vector
<
Type
>
vt
;
std
::
vector
<
Type
>
vt
;
bool
missing
=
!
f
->
ret_type
.
defined
();
for
(
const
auto
&
p
:
f
->
params
)
{
for
(
const
auto
&
p
:
f
->
params
)
{
if
(
missing
||
!
p
->
type_annotation
.
defined
())
{
missing
=
true
;
break
;
}
vt
.
push_back
(
p
->
type_annotation
);
vt
.
push_back
(
p
->
type_annotation
);
}
}
return
FunctionNode
::
make
(
f
->
params
,
body
,
if
(
!
missing
)
{
TupleTypeNode
::
make
({
f
->
ret_type
,
TupleTypeNode
::
make
({})}),
ret_type
=
TupleTypeNode
::
make
({
f
->
ret_type
,
TupleTypeNode
::
make
(
vt
)});
{});
}
return
FunctionNode
::
make
(
f
->
params
,
body
,
ret_type
,
{});
}
}
TVM_REGISTER_API
(
"relay._ir_pass.first_order_gradient"
)
TVM_REGISTER_API
(
"relay._ir_pass.first_order_gradient"
)
...
...
src/relay/pass/type_infer.cc
View file @
6783d373
...
@@ -56,31 +56,11 @@ bool TupleGetItemRel(const Array<Type>& types,
...
@@ -56,31 +56,11 @@ bool TupleGetItemRel(const Array<Type>& types,
return
true
;
return
true
;
}
}
bool
MakeTupleRel
(
const
Array
<
Type
>&
types
,
int
num_inputs
,
const
Attrs
&
attrs
,
const
TypeReporter
&
reporter
)
{
CHECK_EQ
(
static_cast
<
size_t
>
(
num_inputs
+
1
),
types
.
size
());
for
(
int
i
=
0
;
i
<
num_inputs
;
++
i
)
{
if
(
types
[
i
].
as
<
IncompleteTypeNode
>
())
return
false
;
}
Array
<
Type
>
fields
;
for
(
int
i
=
0
;
i
<
num_inputs
;
++
i
)
{
fields
.
push_back
(
types
[
i
]);
}
reporter
->
Assign
(
types
[
num_inputs
],
TupleTypeNode
::
make
(
fields
));
return
true
;
}
TVM_REGISTER_NODE_TYPE
(
TupleGetItemAttrs
);
TVM_REGISTER_NODE_TYPE
(
TupleGetItemAttrs
);
TVM_REGISTER_API
(
"tvm.relay.type_relation.TupleGetItem"
)
TVM_REGISTER_API
(
"tvm.relay.type_relation.TupleGetItem"
)
.
set_body_typed
<
bool
(
const
Array
<
Type
>&
,
int
,
const
Attrs
&
,
const
TypeReporter
&
)
>
(
.
set_body_typed
<
bool
(
const
Array
<
Type
>&
,
int
,
const
Attrs
&
,
const
TypeReporter
&
)
>
(
TupleGetItemRel
);
TupleGetItemRel
);
TVM_REGISTER_API
(
"tvm.relay.type_relation.MakeTuple"
)
.
set_body_typed
<
bool
(
const
Array
<
Type
>&
,
int
,
const
Attrs
&
,
const
TypeReporter
&
)
>
(
MakeTupleRel
);
struct
ResolvedTypeInfo
{
struct
ResolvedTypeInfo
{
explicit
ResolvedTypeInfo
(
Type
checked_type
,
Array
<
Type
>
type_args
)
explicit
ResolvedTypeInfo
(
Type
checked_type
,
Array
<
Type
>
type_args
)
:
checked_type
(
checked_type
),
type_args
(
type_args
)
{}
:
checked_type
(
checked_type
),
type_args
(
type_args
)
{}
...
@@ -120,6 +100,10 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
...
@@ -120,6 +100,10 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
// type inferencer will populate it up
// type inferencer will populate it up
std
::
unordered_map
<
Expr
,
ResolvedTypeInfo
,
NodeHash
,
NodeEqual
>
type_map_
;
std
::
unordered_map
<
Expr
,
ResolvedTypeInfo
,
NodeHash
,
NodeEqual
>
type_map_
;
// used to ensure we don't have free type vars hanging around
// (a temporary measure until we have proper generalization implemented)
Map
<
TypeVar
,
Type
>
instantiation_map_
;
// The solver used by the inferencer.
// The solver used by the inferencer.
TypeSolver
solver_
;
TypeSolver
solver_
;
// relation function
// relation function
...
@@ -140,6 +124,32 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
...
@@ -140,6 +124,32 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
return
Type
();
return
Type
();
}
}
}
}
// Substitutes every type var in t with a corresponding incomplete type.
// This is a temporary measure to ensure type vars behave until
// generalization is properly implemented.
Type
Instantiate
(
const
Type
&
t
)
{
if
(
!
t
.
defined
())
{
return
t
;
}
auto
*
ft
=
t
.
as
<
FuncTypeNode
>
();
if
(
ft
==
nullptr
)
{
return
Bind
(
t
,
instantiation_map_
);
}
for
(
auto
type_param
:
ft
->
type_params
)
{
instantiation_map_
.
Set
(
type_param
,
IncompleteTypeNode
::
make
(
TypeVarNode
::
Kind
::
kType
));
}
Type
ret_type
=
ft
->
ret_type
;
if
(
!
ret_type
.
defined
())
{
ret_type
=
IncompleteTypeNode
::
make
(
TypeVarNode
::
Kind
::
kType
);
}
auto
strip_tvs
=
FuncTypeNode
::
make
(
ft
->
arg_types
,
ret_type
,
{},
ft
->
type_constraints
);
return
Bind
(
strip_tvs
,
instantiation_map_
);
}
// Lazily get type for expr
// Lazily get type for expr
// will call visit to deduce it if it is not in the type_map_
// will call visit to deduce it if it is not in the type_map_
Type
GetType
(
const
Expr
&
expr
)
{
Type
GetType
(
const
Expr
&
expr
)
{
...
@@ -147,7 +157,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
...
@@ -147,7 +157,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
if
(
it
!=
type_map_
.
end
()
&&
it
->
second
.
checked_type
.
defined
())
{
if
(
it
!=
type_map_
.
end
()
&&
it
->
second
.
checked_type
.
defined
())
{
return
it
->
second
.
checked_type
;
return
it
->
second
.
checked_type
;
}
}
Type
ret
=
this
->
VisitExpr
(
expr
);
Type
ret
=
Instantiate
(
this
->
VisitExpr
(
expr
)
);
ResolvedTypeInfo
&
rti
=
type_map_
[
expr
];
ResolvedTypeInfo
&
rti
=
type_map_
[
expr
];
rti
.
checked_type
=
ret
;
rti
.
checked_type
=
ret
;
return
ret
;
return
ret
;
...
@@ -175,19 +185,11 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
...
@@ -175,19 +185,11 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
}
}
Type
VisitExpr_
(
const
TupleNode
*
op
)
final
{
Type
VisitExpr_
(
const
TupleNode
*
op
)
final
{
if
(
!
make_tuple_rel_
.
defined
())
{
make_tuple_rel_
=
TypeRelationFn
(
EnvFunc
::
Get
(
"tvm.relay.type_relation.MakeTuple"
).
node_
);
}
Array
<
Type
>
types
;
Array
<
Type
>
types
;
for
(
Expr
field
:
op
->
fields
)
{
for
(
Expr
field
:
op
->
fields
)
{
types
.
push_back
(
GetType
(
field
));
types
.
push_back
(
GetType
(
field
));
}
}
Type
rtype
=
IncompleteTypeNode
::
make
(
TypeVarNode
::
Kind
::
kType
);
return
TupleTypeNode
::
make
(
types
);
types
.
push_back
(
rtype
);
solver_
.
AddConstraint
(
TypeRelationNode
::
make
(
make_tuple_rel_
,
types
,
op
->
fields
.
size
(),
Attrs
()));
return
rtype
;
}
}
Type
VisitExpr_
(
const
TupleGetItemNode
*
op
)
final
{
Type
VisitExpr_
(
const
TupleGetItemNode
*
op
)
final
{
...
@@ -209,11 +211,17 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
...
@@ -209,11 +211,17 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
}
}
Type
VisitExpr_
(
const
LetNode
*
op
)
final
{
Type
VisitExpr_
(
const
LetNode
*
op
)
final
{
// if the definition is a function literal, permit recursion
bool
is_functional_literal
=
op
->
value
.
as
<
FunctionNode
>
()
!=
nullptr
;
if
(
is_functional_literal
)
{
type_map_
[
op
->
var
].
checked_type
=
IncompleteTypeNode
::
make
(
TypeVarNode
::
Kind
::
kType
);
}
Type
vtype
=
GetType
(
op
->
value
);
Type
vtype
=
GetType
(
op
->
value
);
if
(
op
->
var
->
type_annotation
.
defined
())
{
if
(
op
->
var
->
type_annotation
.
defined
())
{
vtype
=
Unify
(
vtype
,
op
->
var
->
type_annotation
,
op
->
span
);
vtype
=
Unify
(
vtype
,
op
->
var
->
type_annotation
,
op
->
span
);
}
}
CHECK
(
!
type_map_
.
count
(
op
->
var
));
CHECK
(
is_functional_literal
||
!
type_map_
.
count
(
op
->
var
));
// NOTE: no scoping is necessary because var are unique in program
// NOTE: no scoping is necessary because var are unique in program
type_map_
[
op
->
var
].
checked_type
=
vtype
;
type_map_
[
op
->
var
].
checked_type
=
vtype
;
return
GetType
(
op
->
body
);
return
GetType
(
op
->
body
);
...
@@ -252,16 +260,14 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
...
@@ -252,16 +260,14 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
return
rtype
;
return
rtype
;
}
}
//
instantiate the function type with fresh
//
substitute the type args in the function type
FuncType
Instantiate
(
const
FuncTypeNode
*
fn_ty
,
Array
<
Type
>*
ty_args
)
{
FuncType
Instantiate
FuncType
(
const
FuncTypeNode
*
fn_ty
,
const
Array
<
Type
>&
ty_args
)
{
tvm
::
Map
<
TypeVar
,
Type
>
subst_map
;
tvm
::
Map
<
TypeVar
,
Type
>
subst_map
;
// Build a subsitituion map up from the function type and type arguments.
// Build a subsitituion map up from the function type and type arguments.
// Eventually allow the type vars to be passed in.
// Eventually allow the type vars to be passed in.
for
(
auto
ty_param
:
fn_ty
->
type_params
)
{
for
(
size_t
i
=
0
;
i
<
fn_ty
->
type_params
.
size
();
i
++
)
{
IncompleteType
fresh
=
IncompleteTypeNode
::
make
(
ty_param
->
kind
);
subst_map
.
Set
(
fn_ty
->
type_params
[
i
],
ty_args
[
i
]);
subst_map
.
Set
(
ty_param
,
fresh
);
ty_args
->
push_back
(
fresh
);
}
}
Type
ret_type
=
fn_ty
->
ret_type
;
Type
ret_type
=
fn_ty
->
ret_type
;
...
@@ -296,13 +302,32 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
...
@@ -296,13 +302,32 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
Type
GeneralCall
(
const
CallNode
*
call
,
Array
<
Type
>
arg_types
)
{
Type
GeneralCall
(
const
CallNode
*
call
,
Array
<
Type
>
arg_types
)
{
Type
ftype
=
GetType
(
call
->
op
);
Type
ftype
=
GetType
(
call
->
op
);
auto
*
fn_ty_node
=
ftype
.
as
<
FuncTypeNode
>
();
auto
*
fn_ty_node
=
ftype
.
as
<
FuncTypeNode
>
();
auto
*
inc_ty_node
=
ftype
.
as
<
IncompleteTypeNode
>
();
CHECK
(
fn_ty_node
!=
nullptr
||
inc_ty_node
!=
nullptr
)
<<
"only expressions with function types can be called, found "
<<
ftype
<<
" at "
<<
call
->
span
;
// incomplete type => it must be a function taking the arg types
// with an unknown return type
if
(
inc_ty_node
!=
nullptr
)
{
Type
ret_type
=
IncompleteTypeNode
::
make
(
TypeVarNode
::
Kind
::
kType
);
Type
func_type
=
FuncTypeNode
::
make
(
arg_types
,
ret_type
,
{},
{});
Type
unified
=
this
->
Unify
(
ftype
,
func_type
,
call
->
span
);
fn_ty_node
=
unified
.
as
<
FuncTypeNode
>
();
}
CHECK
(
fn_ty_node
!=
nullptr
)
Array
<
Type
>
type_args
=
call
->
type_args
;
<<
"only expressions with function types can be called, found "
if
(
type_args
.
size
()
==
0
)
{
<<
ftype
<<
" at "
<<
call
->
span
;
for
(
size_t
i
=
0
;
i
<
fn_ty_node
->
type_params
.
size
();
i
++
)
{
type_args
.
push_back
(
IncompleteTypeNode
::
make
(
TypeVarNode
::
Kind
::
kType
));
Array
<
Type
>
type_args
;
}
FuncType
fn_ty
=
Instantiate
(
fn_ty_node
,
&
type_args
);
}
CHECK
(
type_args
.
size
()
==
fn_ty_node
->
type_params
.
size
())
<<
"Incorrect number of type args in "
<<
call
->
span
<<
": "
<<
"Expected "
<<
fn_ty_node
->
type_params
.
size
()
<<
"but got "
<<
type_args
.
size
();
FuncType
fn_ty
=
InstantiateFuncType
(
fn_ty_node
,
type_args
);
AddTypeArgs
(
GetRef
<
Call
>
(
call
),
type_args
);
AddTypeArgs
(
GetRef
<
Call
>
(
call
),
type_args
);
...
@@ -353,26 +378,17 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
...
@@ -353,26 +378,17 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
}
}
Type
VisitExpr_
(
const
FunctionNode
*
f
)
final
{
Type
VisitExpr_
(
const
FunctionNode
*
f
)
final
{
solver_
.
Solve
();
Array
<
Type
>
arg_types
;
for
(
auto
param
:
f
->
params
)
{
for
(
auto
param
:
f
->
params
)
{
GetType
(
param
);
arg_types
.
push_back
(
GetType
(
param
)
);
}
}
Type
rtype
=
GetType
(
f
->
body
);
Type
rtype
=
GetType
(
f
->
body
);
// Run solver using the currently known information
if
(
f
->
ret_type
.
defined
())
{
solver_
.
Solve
();
rtype
=
this
->
Unify
(
f
->
ret_type
,
rtype
,
f
->
span
);
// Trying to resolve
Array
<
Type
>
arg_types
;
for
(
size_t
i
=
0
;
i
<
f
->
params
.
size
();
++
i
)
{
Type
atype
=
solver_
.
Resolve
(
GetType
(
f
->
params
[
i
]));
CHECK
(
atype
.
as
<
IncompleteTypeNode
>
()
==
nullptr
)
<<
"Cannot resolve type of "
<<
i
<<
"-th parameter of function at"
<<
f
->
span
;
arg_types
.
push_back
(
atype
);
}
}
rtype
=
solver_
.
Resolve
(
rtype
);
auto
ret
=
FuncTypeNode
::
make
(
arg_types
,
rtype
,
f
->
type_params
,
{});
CHECK
(
rtype
.
as
<
IncompleteTypeNode
>
()
==
nullptr
)
return
solver_
.
Resolve
(
ret
);
<<
"Cannot resolve return type of function at"
<<
f
->
span
;
// do not support constraint lifting for now.
return
FuncTypeNode
::
make
(
arg_types
,
rtype
,
f
->
type_params
,
{});
}
}
};
};
...
@@ -380,7 +396,7 @@ class TypeInferencer::Resolver : public ExprMutator {
...
@@ -380,7 +396,7 @@ class TypeInferencer::Resolver : public ExprMutator {
public
:
public
:
Resolver
(
const
std
::
unordered_map
<
Expr
,
ResolvedTypeInfo
,
NodeHash
,
NodeEqual
>&
tmap
,
Resolver
(
const
std
::
unordered_map
<
Expr
,
ResolvedTypeInfo
,
NodeHash
,
NodeEqual
>&
tmap
,
TypeSolver
*
solver
)
TypeSolver
*
solver
)
:
tmap_
(
tmap
),
solver_
(
solver
)
{
:
tmap_
(
tmap
),
solver_
(
solver
)
{
}
}
Expr
VisitExpr_
(
const
VarNode
*
op
)
final
{
Expr
VisitExpr_
(
const
VarNode
*
op
)
final
{
...
@@ -525,6 +541,7 @@ Expr TypeInferencer::Infer(Expr expr) {
...
@@ -525,6 +541,7 @@ Expr TypeInferencer::Infer(Expr expr) {
GetType
(
expr
);
GetType
(
expr
);
// Step 1: Solve the constraints.
// Step 1: Solve the constraints.
solver_
.
Solve
();
solver_
.
Solve
();
// Step 2: Attach resolved types to checked_type field.
// Step 2: Attach resolved types to checked_type field.
auto
resolved_expr
=
Resolver
(
type_map_
,
&
solver_
).
VisitExpr
(
expr
);
auto
resolved_expr
=
Resolver
(
type_map_
,
&
solver_
).
VisitExpr
(
expr
);
CHECK
(
WellFormed
(
resolved_expr
));
CHECK
(
WellFormed
(
resolved_expr
));
...
...
src/relay/pass/type_solver.cc
View file @
6783d373
...
@@ -5,6 +5,7 @@
...
@@ -5,6 +5,7 @@
*/
*/
#include <string>
#include <string>
#include "type_solver.h"
#include "type_solver.h"
#include "../ir/type_functor.h"
namespace
tvm
{
namespace
tvm
{
namespace
relay
{
namespace
relay
{
...
@@ -38,9 +39,298 @@ class TypeSolver::Reporter : public TypeReporterNode {
...
@@ -38,9 +39,298 @@ class TypeSolver::Reporter : public TypeReporterNode {
TypeSolver
*
solver_
;
TypeSolver
*
solver_
;
};
};
class
TypeSolver
::
OccursChecker
:
public
TypeVisitor
{
public
:
explicit
OccursChecker
(
TypeSolver
*
solver
,
TypeNode
*
var
)
:
solver_
(
solver
),
var_
(
var
),
found_
(
false
)
{}
bool
Check
(
const
Type
&
t
)
{
VisitType
(
t
);
return
found_
;
}
void
VisitType_
(
const
IncompleteTypeNode
*
op
)
override
{
IncompleteType
t
=
GetRef
<
IncompleteType
>
(
op
);
TypeNode
*
node
=
solver_
->
GetTypeNode
(
t
);
found_
=
found_
||
(
var_
->
FindRoot
()
==
node
->
FindRoot
());
}
private
:
TypeSolver
*
solver_
;
TypeNode
*
var_
;
bool
found_
;
};
class
TypeSolver
::
Unifier
:
public
TypeFunctor
<
Type
(
const
Type
&
,
const
Type
&
)
>
{
public
:
explicit
Unifier
(
TypeSolver
*
solver
)
:
solver_
(
solver
)
{}
Type
Unify
(
const
Type
&
src
,
const
Type
&
dst
)
{
// Known limitation
// - handle shape pattern matching
TypeNode
*
lhs
=
solver_
->
GetTypeNode
(
dst
);
TypeNode
*
rhs
=
solver_
->
GetTypeNode
(
src
);
// do occur check so we don't create self-referencing structure
if
(
lhs
->
FindRoot
()
==
rhs
->
FindRoot
())
{
return
lhs
->
resolved_type
;
}
if
(
lhs
->
resolved_type
.
as
<
IncompleteTypeNode
>
())
{
CHECK
(
!
CheckOccurs
(
lhs
,
rhs
->
resolved_type
))
<<
"Incomplete type "
<<
lhs
->
resolved_type
<<
" occurs in "
<<
rhs
->
resolved_type
<<
", cannot unify"
;
solver_
->
MergeFromTo
(
lhs
,
rhs
);
return
rhs
->
resolved_type
;
}
else
if
(
rhs
->
resolved_type
.
as
<
IncompleteTypeNode
>
())
{
CHECK
(
!
CheckOccurs
(
rhs
,
lhs
->
resolved_type
))
<<
"Incomplete type "
<<
rhs
->
resolved_type
<<
" occurs in "
<<
lhs
->
resolved_type
<<
", cannot unify"
;
solver_
->
MergeFromTo
(
rhs
,
lhs
);
return
lhs
->
resolved_type
;
}
else
{
Type
resolved
=
this
->
VisitType
(
lhs
->
resolved_type
,
rhs
->
resolved_type
);
CHECK
(
resolved
.
defined
())
<<
"Unable to unify parent types: "
<<
lhs
->
resolved_type
<<
" and "
<<
rhs
->
resolved_type
;
TypeNode
*
top
=
solver_
->
GetTypeNode
(
resolved
);
solver_
->
MergeFromTo
(
lhs
,
top
);
solver_
->
MergeFromTo
(
rhs
,
top
);
return
resolved
;
}
}
// Checks whether lhs (taken to be a type var) occurs in t, meaning
// there is a recursive equality constraint, which should be rejected.
// N.b.: A tautology like ?a = ?a is okay and should be checked for
// *before* calling this method
bool
CheckOccurs
(
TypeNode
*
lhs
,
const
Type
&
t
)
{
OccursChecker
rc
(
solver_
,
lhs
);
return
rc
.
Check
(
t
);
}
// default: unify only if alpha-equal
Type
VisitTypeDefault_
(
const
Node
*
op
,
const
Type
&
tn
)
override
{
NodeRef
nr
=
GetRef
<
NodeRef
>
(
op
);
Type
t1
=
GetRef
<
Type
>
(
nr
.
as_derived
<
tvm
::
relay
::
TypeNode
>
());
if
(
!
AlphaEqual
(
t1
,
tn
))
{
return
Type
(
nullptr
);
}
return
t1
;
}
Type
VisitType_
(
const
TupleTypeNode
*
op
,
const
Type
&
tn
)
override
{
const
auto
*
ttn
=
tn
.
as
<
TupleTypeNode
>
();
if
(
!
ttn
||
op
->
fields
.
size
()
!=
ttn
->
fields
.
size
())
{
return
Type
(
nullptr
);
}
TupleType
tt1
=
GetRef
<
TupleType
>
(
op
);
TupleType
tt2
=
GetRef
<
TupleType
>
(
ttn
);
std
::
vector
<
Type
>
new_fields
;
for
(
size_t
i
=
0
;
i
<
tt1
->
fields
.
size
();
i
++
)
{
Type
field
=
Unify
(
tt1
->
fields
[
i
],
tt2
->
fields
[
i
]);
new_fields
.
push_back
(
field
);
}
return
TupleTypeNode
::
make
(
new_fields
);
}
Type
VisitType_
(
const
FuncTypeNode
*
op
,
const
Type
&
tn
)
override
{
const
auto
*
ftn
=
tn
.
as
<
FuncTypeNode
>
();
if
(
!
ftn
||
op
->
arg_types
.
size
()
!=
ftn
->
arg_types
.
size
()
||
op
->
type_params
.
size
()
!=
ftn
->
type_params
.
size
()
||
op
->
type_constraints
.
size
()
!=
ftn
->
type_constraints
.
size
())
{
return
Type
(
nullptr
);
}
// remap type vars so they match
Map
<
TypeVar
,
Type
>
subst_map
;
for
(
size_t
i
=
0
;
i
<
op
->
type_params
.
size
();
i
++
)
{
subst_map
.
Set
(
ftn
->
type_params
[
i
],
op
->
type_params
[
i
]);
}
auto
ft1
=
GetRef
<
FuncType
>
(
op
);
auto
ft2
=
Downcast
<
FuncType
>
(
Bind
(
GetRef
<
FuncType
>
(
ftn
),
subst_map
));
Type
ret_type
=
Unify
(
ft1
->
ret_type
,
ft2
->
ret_type
);
std
::
vector
<
Type
>
arg_types
;
for
(
size_t
i
=
0
;
i
<
ft1
->
arg_types
.
size
();
i
++
)
{
Type
arg_type
=
Unify
(
ft1
->
arg_types
[
i
],
ft2
->
arg_types
[
i
]);
arg_types
.
push_back
(
arg_type
);
}
std
::
vector
<
TypeConstraint
>
type_constraints
;
for
(
size_t
i
=
0
;
i
<
ft1
->
type_constraints
.
size
();
i
++
)
{
Type
unified_constraint
=
Unify
(
ft1
->
type_constraints
[
i
],
ft2
->
type_constraints
[
i
]);
const
auto
*
tcn
=
unified_constraint
.
as
<
TypeConstraintNode
>
();
CHECK
(
tcn
)
<<
"Two type constraints unified into a non-constraint?"
<<
ft1
->
type_constraints
[
i
]
<<
" and "
<<
ft2
->
type_constraints
[
i
];
type_constraints
.
push_back
(
GetRef
<
TypeConstraint
>
(
tcn
));
}
return
FuncTypeNode
::
make
(
arg_types
,
ret_type
,
ft1
->
type_params
,
type_constraints
);
}
private
:
TypeSolver
*
solver_
;
};
class
TypeSolver
::
Resolver
:
public
TypeMutator
{
public
:
explicit
Resolver
(
TypeSolver
*
solver
)
:
solver_
(
solver
)
{}
Type
Resolve
(
const
Type
&
t
)
{
if
(
!
t
.
defined
())
{
return
t
;
}
return
VisitType
(
t
);
}
Type
VisitType_
(
const
IncompleteTypeNode
*
op
)
override
{
auto
*
node
=
solver_
->
GetTypeNode
(
GetRef
<
IncompleteType
>
(
op
));
return
node
->
resolved_type
;
}
private
:
TypeSolver
*
solver_
;
};
// It ends up being more compact to simply have TypeFunctor<void(const Type&) than
// a TypeVisitor because we can use the default case to dispense with
// most of the overrides.
class
TypeSolver
::
Propagator
:
public
TypeFunctor
<
void
(
const
Type
&
)
>
{
public
:
explicit
Propagator
(
TypeSolver
*
solver
,
const
std
::
unordered_set
<
RelationNode
*>*
rels
)
:
solver_
(
solver
),
rels_
(
rels
)
{}
// adds the relation node to t and all child types of t
void
Propagate
(
const
Type
&
t
)
{
VisitType
(
t
);
}
void
UpdateRelSet
(
const
Type
&
t
)
{
TypeNode
*
tnode
=
solver_
->
GetTypeNode
(
t
);
for
(
auto
*
rel
:
*
rels_
)
{
tnode
->
rel_set
.
insert
(
rel
);
}
}
void
VisitTypeDefault_
(
const
Node
*
op
)
override
{
NodeRef
nr
=
GetRef
<
NodeRef
>
(
op
);
Type
t
=
GetRef
<
Type
>
(
nr
.
as_derived
<
tvm
::
relay
::
TypeNode
>
());
UpdateRelSet
(
t
);
}
void
VisitType_
(
const
TupleTypeNode
*
op
)
override
{
TupleType
tt
=
GetRef
<
TupleType
>
(
op
);
UpdateRelSet
(
tt
);
for
(
const
Type
&
t
:
tt
->
fields
)
{
Propagate
(
t
);
}
}
void
VisitType_
(
const
FuncTypeNode
*
op
)
override
{
FuncType
ft
=
GetRef
<
FuncType
>
(
op
);
UpdateRelSet
(
ft
);
Propagate
(
ft
->
ret_type
);
for
(
auto
arg_type
:
ft
->
arg_types
)
{
Propagate
(
arg_type
);
}
for
(
auto
type_param
:
ft
->
type_params
)
{
Propagate
(
type_param
);
}
for
(
auto
type_cs
:
ft
->
type_constraints
)
{
Propagate
(
type_cs
);
}
}
private
:
TypeSolver
*
solver_
;
const
std
::
unordered_set
<
RelationNode
*>*
rels_
;
};
// similarly, we use TypeFunctor<void(const Type&)> so we can use
// the default visitor case to avoid more overrides
class
TypeSolver
::
Merger
:
public
TypeFunctor
<
void
(
const
Type
&
)
>
{
public
:
explicit
Merger
(
TypeSolver
*
solver
)
:
solver_
(
solver
)
{}
// Merges src node to dst, ensures *all* type relations of all
// child nodes of src are transferred to dst.
void
Merge
(
TypeNode
*
src
,
TypeNode
*
dst
)
{
if
(
src
==
dst
)
return
;
dst_
=
dst
;
VisitType
(
src
->
resolved_type
);
// set parent at the end so later calls to GetTypeNode go back to src
src
->
parent
=
dst
;
// now propagate relations to child nodes, since change to
// a child node should update parent too
Propagator
prop
(
solver_
,
&
dst
->
rel_set
);
prop
.
Propagate
(
dst
->
resolved_type
);
}
// Transfers any relations linked to t to the stored dst.
// Any unresolved relations are added back to the queue, since
// there is now new information
void
TransferLinks
(
const
Type
&
t
)
{
TypeNode
*
src
=
solver_
->
GetTypeNode
(
t
);
if
(
src
==
dst_
)
return
;
for
(
auto
*
rel
:
src
->
rel_set
)
{
// if the relation is not yet resolved, add to queue
if
(
!
rel
->
resolved
)
{
solver_
->
AddToQueue
(
rel
);
dst_
->
rel_set
.
insert
(
rel
);
}
}
}
void
VisitTypeDefault_
(
const
Node
*
op
)
override
{
NodeRef
nr
=
GetRef
<
NodeRef
>
(
op
);
Type
t
=
GetRef
<
Type
>
(
nr
.
as_derived
<
tvm
::
relay
::
TypeNode
>
());
TransferLinks
(
t
);
}
void
VisitType_
(
const
TupleTypeNode
*
ttn
)
override
{
auto
tup
=
GetRef
<
TupleType
>
(
ttn
);
TransferLinks
(
tup
);
for
(
auto
field
:
tup
->
fields
)
{
VisitType
(
field
);
}
}
void
VisitType_
(
const
FuncTypeNode
*
ftn
)
override
{
auto
func
=
GetRef
<
FuncType
>
(
ftn
);
TransferLinks
(
func
);
VisitType
(
func
->
ret_type
);
for
(
auto
arg
:
func
->
arg_types
)
{
VisitType
(
arg
);
}
for
(
auto
param
:
func
->
type_params
)
{
VisitType
(
param
);
}
for
(
auto
constraint
:
func
->
type_constraints
)
{
VisitType
(
constraint
);
}
}
private
:
TypeSolver
*
solver_
;
TypeNode
*
dst_
;
};
// constructor
// constructor
TypeSolver
::
TypeSolver
()
TypeSolver
::
TypeSolver
()
:
reporter_
(
make_node
<
Reporter
>
(
this
))
{
:
reporter_
(
make_node
<
Reporter
>
(
this
))
{
}
}
// destructor
// destructor
...
@@ -54,31 +344,16 @@ TypeSolver::~TypeSolver() {
...
@@ -54,31 +344,16 @@ TypeSolver::~TypeSolver() {
}
}
}
}
// merge src type node to dst
void
TypeSolver
::
MergeFromTo
(
TypeNode
*
src
,
TypeNode
*
dst
)
{
Merger
merger
(
this
);
merger
.
Merge
(
src
,
dst
);
}
// Add equality constraint
// Add equality constraint
Type
TypeSolver
::
Unify
(
const
Type
&
dst
,
const
Type
&
src
)
{
Type
TypeSolver
::
Unify
(
const
Type
&
dst
,
const
Type
&
src
)
{
// Known limitation
Unifier
unifier
(
this
);
// - handle composite types whose component can be unknown.
return
unifier
.
Unify
(
dst
,
src
);
// - handle shape pattern matching
TypeNode
*
lhs
=
GetTypeNode
(
dst
);
TypeNode
*
rhs
=
GetTypeNode
(
src
);
// do occur check so we don't create self-referencing structure
if
(
lhs
->
FindRoot
()
==
rhs
->
FindRoot
())
{
return
lhs
->
resolved_type
;
}
if
(
lhs
->
resolved_type
.
as
<
IncompleteTypeNode
>
())
{
MergeFromTo
(
lhs
,
rhs
);
return
rhs
->
resolved_type
;
}
else
if
(
rhs
->
resolved_type
.
as
<
IncompleteTypeNode
>
())
{
MergeFromTo
(
rhs
,
lhs
);
return
lhs
->
resolved_type
;
}
else
{
lhs
->
parent
=
rhs
;
CHECK
(
AlphaEqual
(
lhs
->
resolved_type
,
rhs
->
resolved_type
))
<<
"Incompatible parent types in UF:"
<<
lhs
->
resolved_type
<<
" and "
<<
rhs
->
resolved_type
;
return
rhs
->
resolved_type
;
}
}
}
// Add type constraint to the solver.
// Add type constraint to the solver.
...
@@ -96,9 +371,9 @@ void TypeSolver::AddConstraint(const TypeConstraint& constraint) {
...
@@ -96,9 +371,9 @@ void TypeSolver::AddConstraint(const TypeConstraint& constraint) {
tlink
->
value
=
tnode
;
tlink
->
value
=
tnode
;
rnode
->
type_list
.
Push
(
tlink
);
rnode
->
type_list
.
Push
(
tlink
);
// insert type->relation node
// insert type->relation node
LinkNode
<
RelationNode
*>*
rlink
=
arena_
.
make
<
LinkNode
<
RelationNode
*>
>
()
;
std
::
unordered_set
<
RelationNode
*>
singleton
{
rnode
}
;
rlink
->
value
=
rnode
;
Propagator
prop
(
this
,
&
singleton
)
;
tnode
->
rel_list
.
Push
(
rlink
);
prop
.
Propagate
(
tnode
->
resolved_type
);
}
}
// add the relation to the working queue.
// add the relation to the working queue.
this
->
AddToQueue
(
rnode
);
this
->
AddToQueue
(
rnode
);
...
@@ -110,12 +385,10 @@ void TypeSolver::AddConstraint(const TypeConstraint& constraint) {
...
@@ -110,12 +385,10 @@ void TypeSolver::AddConstraint(const TypeConstraint& constraint) {
// Resolve a type in the solver context.
// Resolve a type in the solver context.
Type
TypeSolver
::
Resolve
(
const
Type
&
type
)
{
Type
TypeSolver
::
Resolve
(
const
Type
&
type
)
{
Resolver
resolver
(
this
);
auto
it
=
tmap_
.
find
(
type
);
auto
it
=
tmap_
.
find
(
type
);
if
(
it
!=
tmap_
.
end
())
{
Type
t
=
(
it
!=
tmap_
.
end
())
?
it
->
second
->
FindRoot
()
->
resolved_type
:
type
;
return
it
->
second
->
FindRoot
()
->
resolved_type
;
return
resolver
.
Resolve
(
t
);
}
else
{
return
type
;
}
}
}
bool
TypeSolver
::
Solve
()
{
bool
TypeSolver
::
Solve
()
{
...
@@ -128,7 +401,7 @@ bool TypeSolver::Solve() {
...
@@ -128,7 +401,7 @@ bool TypeSolver::Solve() {
// update the relation with given evidence.
// update the relation with given evidence.
Array
<
Type
>
args
;
Array
<
Type
>
args
;
for
(
auto
*
tlink
=
rnode
->
type_list
.
head
;
tlink
!=
nullptr
;
tlink
=
tlink
->
next
)
{
for
(
auto
*
tlink
=
rnode
->
type_list
.
head
;
tlink
!=
nullptr
;
tlink
=
tlink
->
next
)
{
args
.
push_back
(
tlink
->
value
->
FindRoot
()
->
resolved_type
);
args
.
push_back
(
Resolve
(
tlink
->
value
->
FindRoot
()
->
resolved_type
)
);
CHECK_LE
(
args
.
size
(),
rel
->
args
.
size
());
CHECK_LE
(
args
.
size
(),
rel
->
args
.
size
());
}
}
// call the function
// call the function
...
@@ -161,8 +434,8 @@ TVM_REGISTER_API("relay._ir_pass._test_type_solver")
...
@@ -161,8 +434,8 @@ TVM_REGISTER_API("relay._ir_pass._test_type_solver")
return
solver
->
Solve
();
return
solver
->
Solve
();
});
});
}
else
if
(
name
==
"Unify"
)
{
}
else
if
(
name
==
"Unify"
)
{
return
TypedPackedFunc
<
void
(
Type
,
Type
)
>
([
solver
](
Type
lhs
,
Type
rhs
)
{
return
TypedPackedFunc
<
Type
(
Type
,
Type
)
>
([
solver
](
Type
lhs
,
Type
rhs
)
{
solver
->
Unify
(
lhs
,
rhs
);
return
solver
->
Unify
(
lhs
,
rhs
);
});
});
}
else
if
(
name
==
"Resolve"
)
{
}
else
if
(
name
==
"Resolve"
)
{
return
TypedPackedFunc
<
Type
(
Type
)
>
([
solver
](
Type
t
)
{
return
TypedPackedFunc
<
Type
(
Type
)
>
([
solver
](
Type
t
)
{
...
...
src/relay/pass/type_solver.h
View file @
6783d373
...
@@ -18,6 +18,7 @@ namespace relay {
...
@@ -18,6 +18,7 @@ namespace relay {
using
common
::
LinkNode
;
using
common
::
LinkNode
;
using
common
::
LinkedList
;
using
common
::
LinkedList
;
/*!
/*!
* \brief Interface of type solver used in type inference.
* \brief Interface of type solver used in type inference.
*
*
...
@@ -65,6 +66,11 @@ class TypeSolver {
...
@@ -65,6 +66,11 @@ class TypeSolver {
Type
Unify
(
const
Type
&
lhs
,
const
Type
&
rhs
);
Type
Unify
(
const
Type
&
lhs
,
const
Type
&
rhs
);
private
:
private
:
class
OccursChecker
;
class
Unifier
;
class
Resolver
;
class
Propagator
;
class
Merger
;
class
Reporter
;
class
Reporter
;
struct
TypeNode
;
struct
TypeNode
;
struct
RelationNode
;
struct
RelationNode
;
...
@@ -77,15 +83,15 @@ class TypeSolver {
...
@@ -77,15 +83,15 @@ class TypeSolver {
* that can unifies the same types to the name resolved_type.
* that can unifies the same types to the name resolved_type.
*
*
* It also contains collection of links to related Relations,
* It also contains collection of links to related Relations,
* which is stored in rel_
lis
t.
* which is stored in rel_
se
t.
*/
*/
struct
TypeNode
{
struct
TypeNode
{
/*! \brief The final resolved type */
/*! \brief The final resolved type */
Type
resolved_type
;
Type
resolved_type
;
/*! \brief type node in the union find algorithm */
/*! \brief type node in the union find algorithm */
TypeNode
*
parent
{
nullptr
};
TypeNode
*
parent
{
nullptr
};
/*! \brief
lis
t of relations that is related to this type node */
/*! \brief
se
t of relations that is related to this type node */
LinkedList
<
RelationNode
*>
rel_lis
t
;
std
::
unordered_set
<
RelationNode
*>
rel_se
t
;
/*!
/*!
* \brief Find the root type node, perform path compression
* \brief Find the root type node, perform path compression
* \return The root type node.
* \return The root type node.
...
@@ -125,7 +131,7 @@ class TypeSolver {
...
@@ -125,7 +131,7 @@ class TypeSolver {
size_t
num_resolved_rels_
{
0
};
size_t
num_resolved_rels_
{
0
};
/*! \brief map from type node to types. */
/*! \brief map from type node to types. */
std
::
unordered_map
<
Type
,
TypeNode
*
,
NodeHash
,
NodeEqual
>
tmap_
;
std
::
unordered_map
<
Type
,
TypeNode
*
,
NodeHash
,
NodeEqual
>
tmap_
;
/*! \br
ei
f Internal queue to update the relation */
/*! \br
ie
f Internal queue to update the relation */
std
::
queue
<
RelationNode
*>
update_queue_
;
std
::
queue
<
RelationNode
*>
update_queue_
;
/*! \brief allocator of all the internal node obhect*/
/*! \brief allocator of all the internal node obhect*/
common
::
Arena
arena_
;
common
::
Arena
arena_
;
...
@@ -163,22 +169,7 @@ class TypeSolver {
...
@@ -163,22 +169,7 @@ class TypeSolver {
* \param src The source operand
* \param src The source operand
* \param dst The dst operand.
* \param dst The dst operand.
*/
*/
void
MergeFromTo
(
TypeNode
*
src
,
TypeNode
*
dst
)
{
void
MergeFromTo
(
TypeNode
*
src
,
TypeNode
*
dst
);
if
(
src
==
dst
)
return
;
src
->
parent
=
dst
;
// move the link to the to dst
for
(
auto
*
rlink
=
src
->
rel_list
.
head
;
rlink
!=
nullptr
;)
{
// store next pointer first before rlink get moved
auto
*
next
=
rlink
->
next
;
// if the relation is not yet resolved
// send the relation to the new
if
(
!
rlink
->
value
->
resolved
)
{
this
->
AddToQueue
(
rlink
->
value
);
dst
->
rel_list
.
Push
(
rlink
);
}
rlink
=
next
;
}
}
};
};
}
// namespace relay
}
// namespace relay
...
...
src/relay/pass/util.cc
View file @
6783d373
...
@@ -12,105 +12,211 @@
...
@@ -12,105 +12,211 @@
namespace
tvm
{
namespace
tvm
{
namespace
relay
{
namespace
relay
{
// FreeTypeVar
template
<
typename
T
>
class
FreeTypeVarTVisitor
:
public
TypeVisitor
{
struct
InsertionSet
{
std
::
unordered_set
<
T
,
NodeHash
,
NodeEqual
>
set
;
std
::
vector
<
T
>
data
;
void
Insert
(
const
T
&
t
)
{
if
(
set
.
count
(
t
)
==
0
)
{
set
.
insert
(
t
);
data
.
push_back
(
t
);
}
}
};
class
TypeVarTVisitor
:
public
TypeVisitor
{
public
:
public
:
Free
TypeVarTVisitor
(
TypeVarTVisitor
(
Array
<
TypeVar
>*
fre
e_vars
,
InsertionSet
<
TypeVar
>*
typ
e_vars
,
std
::
unordered_set
<
TypeVar
,
NodeHash
,
NodeEqual
>*
bound
_vars
)
InsertionSet
<
TypeVar
>*
bound_type
_vars
)
:
free_vars_
(
free_vars
),
bound_vars_
(
bound
_vars
)
{
}
:
type_vars_
(
type_vars
),
bound_type_vars_
(
bound_type
_vars
)
{
}
void
VisitType_
(
const
TypeVarNode
*
tp
)
final
{
void
VisitType_
(
const
TypeVarNode
*
tp
)
final
{
TypeVar
var
=
GetRef
<
TypeVar
>
(
tp
);
TypeVar
var
=
GetRef
<
TypeVar
>
(
tp
);
if
(
bound_vars_
->
count
(
var
)
==
0
)
{
type_vars_
->
Insert
(
var
);
free_vars_
->
push_back
(
var
);
}
}
}
void
VisitType_
(
const
FuncTypeNode
*
f
)
final
{
void
VisitType_
(
const
FuncTypeNode
*
f
)
final
{
for
(
auto
type_param
:
f
->
type_params
)
{
for
(
auto
type_param
:
f
->
type_params
)
{
bound_vars_
->
insert
(
type_param
);
type_vars_
->
Insert
(
type_param
);
bound_type_vars_
->
Insert
(
type_param
);
}
}
TypeVisitor
::
VisitType_
(
f
);
TypeVisitor
::
VisitType_
(
f
);
}
}
private
:
private
:
Array
<
TypeVar
>*
fre
e_vars_
;
InsertionSet
<
TypeVar
>*
typ
e_vars_
;
std
::
unordered_set
<
TypeVar
,
NodeHash
,
NodeEqual
>*
bound
_vars_
;
InsertionSet
<
TypeVar
>*
bound_type
_vars_
;
};
};
class
Free
TypeVarEVisitor
:
private
ExprVisitor
{
class
TypeVarEVisitor
:
private
ExprVisitor
{
public
:
public
:
Array
<
TypeVar
>
Find
(
const
Expr
&
expr
)
{
Array
<
TypeVar
>
CollectFree
()
{
this
->
VisitExpr
(
expr
);
Array
<
TypeVar
>
ret
;
return
free_vars_
;
for
(
const
auto
&
v
:
type_vars_
.
data
)
{
if
(
bound_type_vars_
.
set
.
count
(
v
)
==
0
)
{
ret
.
push_back
(
v
);
}
}
return
ret
;
}
Array
<
TypeVar
>
CollectBound
()
{
Array
<
TypeVar
>
ret
;
for
(
const
auto
&
v
:
bound_type_vars_
.
data
)
{
ret
.
push_back
(
v
);
}
return
ret
;
}
Array
<
TypeVar
>
CollectAll
()
{
Array
<
TypeVar
>
ret
;
for
(
const
auto
&
v
:
type_vars_
.
data
)
{
ret
.
push_back
(
v
);
}
return
ret
;
}
}
Array
<
TypeVar
>
Find
(
const
Type
&
type
)
{
Array
<
TypeVar
>
Free
(
const
Expr
&
expr
)
{
this
->
VisitType
(
type
);
VisitExpr
(
expr
);
return
free_vars_
;
return
CollectFree
();
}
Array
<
TypeVar
>
Free
(
const
Type
&
type
)
{
VisitType
(
type
);
return
CollectFree
();
}
Array
<
TypeVar
>
Bound
(
const
Expr
&
expr
)
{
VisitExpr
(
expr
);
return
CollectBound
();
}
Array
<
TypeVar
>
Bound
(
const
Type
&
type
)
{
VisitType
(
type
);
return
CollectBound
();
}
Array
<
TypeVar
>
All
(
const
Expr
&
expr
)
{
VisitExpr
(
expr
);
return
CollectAll
();
}
Array
<
TypeVar
>
All
(
const
Type
&
type
)
{
VisitType
(
type
);
return
CollectAll
();
}
}
void
VisitExpr_
(
const
FunctionNode
*
f
)
final
{
void
VisitExpr_
(
const
FunctionNode
*
f
)
final
{
for
(
const
auto
&
tp
:
f
->
type_params
)
{
for
(
const
auto
&
tp
:
f
->
type_params
)
{
bound_vars_
.
insert
(
tp
);
type_vars_
.
Insert
(
tp
);
bound_type_vars_
.
Insert
(
tp
);
}
}
ExprVisitor
::
VisitExpr_
(
f
);
ExprVisitor
::
VisitExpr_
(
f
);
}
}
void
VisitType
(
const
Type
&
t
)
final
{
void
VisitType
(
const
Type
&
t
)
final
{
FreeTypeVarTVisitor
(
&
free_vars_
,
&
bound
_vars_
)
TypeVarTVisitor
(
&
type_vars_
,
&
bound_type
_vars_
)
.
VisitType
(
t
);
.
VisitType
(
t
);
}
}
private
:
private
:
// The result list
InsertionSet
<
TypeVar
>
type_vars_
;
Array
<
TypeVar
>
free_vars_
;
InsertionSet
<
TypeVar
>
bound_type_vars_
;
std
::
unordered_set
<
TypeVar
,
NodeHash
,
NodeEqual
>
bound_vars_
;
};
};
class
Free
VarVisitor
:
protected
ExprVisitor
{
class
VarVisitor
:
protected
ExprVisitor
{
public
:
public
:
Array
<
Var
>
F
ind
(
const
Expr
&
expr
)
{
Array
<
Var
>
F
ree
(
const
Expr
&
expr
)
{
this
->
VisitExpr
(
expr
);
this
->
VisitExpr
(
expr
);
return
free_vars_
;
Array
<
Var
>
ret
;
for
(
const
auto
&
v
:
vars_
.
data
)
{
if
(
bound_vars_
.
set
.
count
(
v
)
==
0
)
{
ret
.
push_back
(
v
);
}
}
return
ret
;
}
}
void
VisitExpr_
(
const
VarNode
*
var
)
final
{
Array
<
Var
>
Bound
(
const
Expr
&
expr
)
{
if
(
bound_vars_
.
count
(
var
)
==
0
)
{
this
->
VisitExpr
(
expr
);
free_vars_
.
push_back
(
GetRef
<
Var
>
(
var
));
Array
<
Var
>
ret
;
for
(
const
auto
&
v
:
bound_vars_
.
data
)
{
ret
.
push_back
(
v
);
}
}
return
ret
;
}
Array
<
Var
>
All
(
const
Expr
&
expr
)
{
this
->
VisitExpr
(
expr
);
Array
<
Var
>
ret
;
for
(
const
auto
&
v
:
vars_
.
data
)
{
ret
.
push_back
(
v
);
}
return
ret
;
}
void
MarkBounded
(
const
Var
&
v
)
{
bound_vars_
.
Insert
(
v
);
vars_
.
Insert
(
v
);
}
void
VisitExpr_
(
const
VarNode
*
var
)
final
{
vars_
.
Insert
(
GetRef
<
Var
>
(
var
));
}
}
void
VisitExpr_
(
const
FunctionNode
*
op
)
final
{
void
VisitExpr_
(
const
FunctionNode
*
op
)
final
{
for
(
const
auto
&
param
:
op
->
params
)
{
for
(
const
auto
&
param
:
op
->
params
)
{
bound_vars_
.
insert
(
param
.
operator
->
()
);
MarkBounded
(
param
);
}
}
VisitExpr
(
op
->
body
);
VisitExpr
(
op
->
body
);
}
}
void
VisitExpr_
(
const
LetNode
*
op
)
final
{
void
VisitExpr_
(
const
LetNode
*
op
)
final
{
bound_vars_
.
insert
(
op
->
var
.
operator
->
()
);
MarkBounded
(
op
->
var
);
VisitExpr
(
op
->
value
);
VisitExpr
(
op
->
value
);
VisitExpr
(
op
->
body
);
VisitExpr
(
op
->
body
);
}
}
private
:
private
:
// The result list
InsertionSet
<
Var
>
vars_
;
Array
<
Var
>
free_vars_
;
InsertionSet
<
Var
>
bound_vars_
;
std
::
unordered_set
<
const
VarNode
*>
bound_vars_
;
};
};
tvm
::
Array
<
TypeVar
>
FreeTypeVars
(
const
Expr
&
expr
)
{
tvm
::
Array
<
TypeVar
>
FreeTypeVars
(
const
Expr
&
expr
)
{
return
FreeTypeVarEVisitor
().
Find
(
expr
);
return
TypeVarEVisitor
().
Free
(
expr
);
}
}
tvm
::
Array
<
TypeVar
>
FreeTypeVars
(
const
Type
&
type
)
{
tvm
::
Array
<
TypeVar
>
FreeTypeVars
(
const
Type
&
type
)
{
return
FreeTypeVarEVisitor
().
Find
(
type
);
return
TypeVarEVisitor
().
Free
(
type
);
}
tvm
::
Array
<
TypeVar
>
BoundTypeVars
(
const
Expr
&
expr
)
{
return
TypeVarEVisitor
().
Bound
(
expr
);
}
tvm
::
Array
<
TypeVar
>
BoundTypeVars
(
const
Type
&
type
)
{
return
TypeVarEVisitor
().
Bound
(
type
);
}
tvm
::
Array
<
TypeVar
>
AllTypeVars
(
const
Expr
&
expr
)
{
return
TypeVarEVisitor
().
All
(
expr
);
}
tvm
::
Array
<
TypeVar
>
AllTypeVars
(
const
Type
&
type
)
{
return
TypeVarEVisitor
().
All
(
type
);
}
}
tvm
::
Array
<
Var
>
FreeVars
(
const
Expr
&
expr
)
{
tvm
::
Array
<
Var
>
FreeVars
(
const
Expr
&
expr
)
{
return
FreeVarVisitor
().
Find
(
expr
);
return
VarVisitor
().
Free
(
expr
);
}
tvm
::
Array
<
Var
>
BoundVars
(
const
Expr
&
expr
)
{
return
VarVisitor
().
Bound
(
expr
);
}
tvm
::
Array
<
Var
>
AllVars
(
const
Expr
&
expr
)
{
return
VarVisitor
().
All
(
expr
);
}
}
TVM_REGISTER_API
(
"relay._ir_pass.free_vars"
)
TVM_REGISTER_API
(
"relay._ir_pass.free_vars"
)
...
@@ -118,16 +224,46 @@ TVM_REGISTER_API("relay._ir_pass.free_vars")
...
@@ -118,16 +224,46 @@ TVM_REGISTER_API("relay._ir_pass.free_vars")
*
ret
=
FreeVars
(
args
[
0
]);
*
ret
=
FreeVars
(
args
[
0
]);
});
});
TVM_REGISTER_API
(
"relay._ir_pass.bound_vars"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
BoundVars
(
args
[
0
]);
});
TVM_REGISTER_API
(
"relay._ir_pass.all_vars"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
AllVars
(
args
[
0
]);
});
TVM_REGISTER_API
(
"relay._ir_pass.free_type_vars"
)
TVM_REGISTER_API
(
"relay._ir_pass.free_type_vars"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
NodeRef
x
=
args
[
0
];
NodeRef
x
=
args
[
0
];
if
(
x
.
as
<
TypeNode
>
())
{
if
(
x
.
as
_derived
<
TypeNode
>
())
{
*
ret
=
FreeTypeVars
(
Downcast
<
Type
>
(
x
));
*
ret
=
FreeTypeVars
(
Downcast
<
Type
>
(
x
));
}
else
{
}
else
{
*
ret
=
FreeTypeVars
(
Downcast
<
Expr
>
(
x
));
*
ret
=
FreeTypeVars
(
Downcast
<
Expr
>
(
x
));
}
}
});
});
TVM_REGISTER_API
(
"relay._ir_pass.bound_type_vars"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
NodeRef
x
=
args
[
0
];
if
(
x
.
as_derived
<
TypeNode
>
())
{
*
ret
=
BoundTypeVars
(
Downcast
<
Type
>
(
x
));
}
else
{
*
ret
=
BoundTypeVars
(
Downcast
<
Expr
>
(
x
));
}
});
TVM_REGISTER_API
(
"relay._ir_pass.all_type_vars"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
NodeRef
x
=
args
[
0
];
if
(
x
.
as_derived
<
TypeNode
>
())
{
*
ret
=
AllTypeVars
(
Downcast
<
Type
>
(
x
));
}
else
{
*
ret
=
AllTypeVars
(
Downcast
<
Expr
>
(
x
));
}
});
/*!
/*!
* \brief Get reference counter of each internal ExprNode in body.
* \brief Get reference counter of each internal ExprNode in body.
* \param body The body expression.
* \param body The body expression.
...
...
tests/cpp/relay_pass_type_infer_test.cc
View file @
6783d373
...
@@ -6,13 +6,17 @@
...
@@ -6,13 +6,17 @@
TEST
(
Relay
,
SelfReference
)
{
TEST
(
Relay
,
SelfReference
)
{
using
namespace
tvm
;
using
namespace
tvm
;
auto
type_a
=
relay
::
TypeVarNode
::
make
(
"a"
,
relay
::
TypeVarNode
::
kType
);
auto
tensor_type
=
relay
::
TensorTypeNode
::
make
({},
::
tvm
::
Bool
());
auto
type_b
=
relay
::
TypeVarNode
::
make
(
"b"
,
relay
::
TypeVarNode
::
kType
);
auto
x
=
relay
::
VarNode
::
make
(
"x"
,
relay
::
Type
());
auto
x
=
relay
::
VarNode
::
make
(
"x"
,
type_a
);
auto
f
=
relay
::
FunctionNode
::
make
(
tvm
::
Array
<
relay
::
Var
>
{
x
},
x
,
relay
::
Type
(),
{});
auto
f
=
relay
::
FunctionNode
::
make
(
tvm
::
Array
<
relay
::
Var
>
{
x
},
x
,
type_b
,
Array
<
relay
::
TypeVar
>
{});
auto
fx
=
relay
::
CallNode
::
make
(
f
,
Array
<
relay
::
Expr
>
{
x
});
auto
y
=
relay
::
VarNode
::
make
(
"y"
,
tensor_type
);
auto
call
=
relay
::
CallNode
::
make
(
f
,
Array
<
relay
::
Expr
>
{
y
});
auto
fx
=
relay
::
FunctionNode
::
make
(
tvm
::
Array
<
relay
::
Var
>
{
y
},
call
,
relay
::
Type
(),
{});
auto
type_fx
=
relay
::
InferType
(
fx
,
relay
::
ModuleNode
::
make
(
Map
<
relay
::
GlobalVar
,
relay
::
Function
>
{}));
auto
type_fx
=
relay
::
InferType
(
fx
,
relay
::
ModuleNode
::
make
(
Map
<
relay
::
GlobalVar
,
relay
::
Function
>
{}));
CHECK_EQ
(
type_fx
->
checked_type
(),
type_a
);
auto
expected
=
relay
::
FuncTypeNode
::
make
(
tvm
::
Array
<
relay
::
Type
>
{
tensor_type
},
tensor_type
,
{},
{});
CHECK
(
AlphaEqual
(
type_fx
->
checked_type
(),
expected
));
}
}
int
main
(
int
argc
,
char
**
argv
)
{
int
main
(
int
argc
,
char
**
argv
)
{
...
...
tests/python/relay/test_pass_free_vars.py
deleted
100644 → 0
View file @
76188a43
import
tvm
from
tvm
import
relay
from
tvm.relay.ir_pass
import
free_vars
,
free_type_vars
def
test_free_vars
():
ty
=
relay
.
TensorType
([],
"int32"
)
x
=
relay
.
Var
(
"x"
,
ty
)
fvx
=
free_vars
(
x
)
assert
len
(
fvx
)
==
1
assert
fvx
[
0
]
==
x
v
=
relay
.
Constant
(
tvm
.
nd
.
array
(
10
))
let
=
relay
.
Let
(
x
,
v
,
x
)
fvx
=
free_vars
(
let
)
assert
len
(
free_vars
(
let
))
==
0
f
=
relay
.
Function
([
x
],
x
,
ty
)
assert
len
(
free_vars
(
f
))
==
0
def
test_tuple
():
t
=
relay
.
Var
(
't'
)
fv
=
free_vars
(
relay
.
Tuple
([
t
,
t
]))
assert
len
(
fv
)
==
1
assert
fv
[
0
]
==
t
fv
=
free_vars
(
relay
.
TupleGetItem
(
t
,
123
))
assert
len
(
fv
)
==
1
assert
fv
[
0
]
==
t
def
test_free_type_vars
():
tp
=
relay
.
TypeVar
(
""
)
ty
=
relay
.
TupleType
([
tp
,
relay
.
TensorType
([],
"int32"
)])
x
=
relay
.
Var
(
"x"
,
ty
)
y
=
relay
.
Var
(
"y"
)
let
=
relay
.
Let
(
x
,
y
,
x
)
fvl
=
free_vars
(
let
)
assert
len
(
fvl
)
==
1
assert
fvl
[
0
]
==
y
ftvl
=
free_type_vars
(
let
)
assert
len
(
ftvl
)
==
1
assert
ftvl
[
0
]
==
tp
tests/python/relay/test_pass_vars.py
0 → 100644
View file @
6783d373
import
tvm
from
tvm
import
relay
from
tvm.relay.ir_pass
import
(
free_vars
,
free_type_vars
,
bound_vars
,
bound_type_vars
,
all_vars
,
all_type_vars
)
def
assert_vars_match
(
actual
,
expected
):
assert
len
(
actual
)
==
len
(
expected
)
for
i
in
range
(
len
(
actual
)):
assert
actual
[
i
]
==
expected
[
i
]
def
test_free_vars
():
ty
=
relay
.
TensorType
([],
"int32"
)
x
=
relay
.
Var
(
"x"
,
ty
)
fvx
=
free_vars
(
x
)
assert
len
(
fvx
)
==
1
assert
fvx
[
0
]
==
x
v
=
relay
.
Constant
(
tvm
.
nd
.
array
(
10
))
let
=
relay
.
Let
(
x
,
v
,
x
)
fvx
=
free_vars
(
let
)
assert
len
(
free_vars
(
let
))
==
0
f
=
relay
.
Function
([
x
],
x
,
ty
)
assert
len
(
free_vars
(
f
))
==
0
def
test_free_vars_tuple
():
t
=
relay
.
Var
(
't'
)
fv
=
free_vars
(
relay
.
Tuple
([
t
,
t
]))
assert
len
(
fv
)
==
1
assert
fv
[
0
]
==
t
fv
=
free_vars
(
relay
.
TupleGetItem
(
t
,
123
))
assert
len
(
fv
)
==
1
assert
fv
[
0
]
==
t
def
test_free_type_vars
():
tp
=
relay
.
TypeVar
(
""
)
ty
=
relay
.
TupleType
([
tp
,
relay
.
TensorType
([],
"int32"
)])
x
=
relay
.
Var
(
"x"
,
ty
)
y
=
relay
.
Var
(
"y"
)
let
=
relay
.
Let
(
x
,
y
,
x
)
fvl
=
free_vars
(
let
)
assert
len
(
fvl
)
==
1
assert
fvl
[
0
]
==
y
ftvl
=
free_type_vars
(
let
)
assert
len
(
ftvl
)
==
1
assert
ftvl
[
0
]
==
tp
def
test_bound_vars
():
x
=
relay
.
Var
(
"x"
)
y
=
relay
.
Var
(
"y"
)
z
=
relay
.
Var
(
"z"
)
a
=
relay
.
Var
(
"a"
)
f1
=
relay
.
Function
([
x
,
y
,
z
],
relay
.
Let
(
a
,
x
,
relay
.
Tuple
([])))
assert_vars_match
(
bound_vars
(
f1
),
[
x
,
y
,
z
,
a
])
tup
=
relay
.
Tuple
([
x
,
y
,
z
,
a
])
assert
len
(
bound_vars
(
tup
))
==
0
f2
=
relay
.
Function
([
x
,
y
],
relay
.
Tuple
([
x
,
y
,
z
,
a
]))
assert_vars_match
(
bound_vars
(
f2
),
[
x
,
y
])
def
test_bound_type_vars
():
a
=
relay
.
TypeVar
(
"a"
)
b
=
relay
.
TypeVar
(
"b"
)
c
=
relay
.
TypeVar
(
"c"
)
ft1
=
relay
.
FuncType
([
a
],
b
,
[
a
,
b
])
bound_ft1
=
bound_type_vars
(
ft1
)
assert_vars_match
(
bound_type_vars
(
ft1
),
[
a
,
b
])
ft2
=
relay
.
FuncType
([],
c
,
[
a
])
assert_vars_match
(
bound_type_vars
(
ft2
),
[
a
])
tup_ty
=
relay
.
TupleType
([
a
,
b
,
c
])
assert
len
(
bound_type_vars
(
tup_ty
))
==
0
f1
=
relay
.
Function
([],
relay
.
Tuple
([]),
type_params
=
[
a
,
b
])
assert_vars_match
(
bound_type_vars
(
f1
),
[
a
,
b
])
f2
=
relay
.
Function
([],
relay
.
Tuple
([]),
c
)
assert
len
(
bound_type_vars
(
f2
))
==
0
x
=
relay
.
Var
(
"x"
,
a
)
let1
=
relay
.
Let
(
x
,
relay
.
Tuple
([]),
x
)
assert
len
(
bound_type_vars
(
let1
))
==
0
let2
=
relay
.
Let
(
x
,
relay
.
Function
([],
relay
.
Tuple
([]),
type_params
=
[
b
,
c
]),
x
)
assert_vars_match
(
bound_type_vars
(
let2
),
[
b
,
c
])
def
test_all_vars
():
x
=
relay
.
Var
(
"x"
)
y
=
relay
.
Var
(
"y"
)
z
=
relay
.
Var
(
"z"
)
f1
=
relay
.
Function
([
x
,
y
],
z
)
assert_vars_match
(
all_vars
(
f1
),
[
x
,
y
,
z
])
f2
=
relay
.
Function
([
x
],
relay
.
Let
(
y
,
relay
.
Tuple
([]),
z
))
assert_vars_match
(
all_vars
(
f2
),
[
x
,
y
,
z
])
f3
=
relay
.
Function
([
x
],
relay
.
Tuple
([
y
,
z
]))
assert_vars_match
(
all_vars
(
f3
),
[
x
,
y
,
z
])
tup
=
relay
.
Tuple
([
x
,
y
,
z
])
assert_vars_match
(
all_vars
(
tup
),
[
x
,
y
,
z
])
def
test_all_type_vars
():
a
=
relay
.
TypeVar
(
"a"
)
b
=
relay
.
TypeVar
(
"b"
)
c
=
relay
.
TypeVar
(
"c"
)
ft1
=
relay
.
FuncType
([
b
],
c
,
[
a
])
assert_vars_match
(
all_type_vars
(
ft1
),
[
a
,
b
,
c
])
ft2
=
relay
.
FuncType
([],
relay
.
TupleType
([
a
,
b
,
c
]),
[])
assert_vars_match
(
all_type_vars
(
ft2
),
[
a
,
b
,
c
])
w
=
relay
.
Var
(
"w"
)
x
=
relay
.
Var
(
"x"
,
a
)
y
=
relay
.
Var
(
"y"
,
b
)
z
=
relay
.
Var
(
"z"
,
c
)
f1
=
relay
.
Function
([
x
],
y
,
b
,
[
a
])
assert_vars_match
(
all_type_vars
(
f1
),
[
a
,
b
])
f2
=
relay
.
Function
([
x
],
relay
.
Let
(
y
,
x
,
z
))
assert_vars_match
(
all_type_vars
(
f2
),
[
a
,
b
,
c
])
f3
=
relay
.
Function
([],
relay
.
Tuple
([
x
,
y
,
z
]),
ret_type
=
relay
.
TupleType
([
a
,
b
,
c
]))
assert_vars_match
(
all_type_vars
(
f3
),
[
a
,
b
,
c
])
f4
=
relay
.
Function
([
w
],
relay
.
Tuple
([]),
type_params
=
[
a
,
b
,
c
])
assert_vars_match
(
all_type_vars
(
f4
),
[
a
,
b
,
c
])
f5
=
relay
.
Function
([
w
],
w
)
assert
len
(
all_type_vars
(
f5
))
==
0
tests/python/relay/test_type_infer.py
View file @
6783d373
...
@@ -23,7 +23,7 @@ def test_monomorphic_let():
...
@@ -23,7 +23,7 @@ def test_monomorphic_let():
x
=
sb
.
let
(
'x'
,
relay
.
const
(
1.0
,
"float64"
))
x
=
sb
.
let
(
'x'
,
relay
.
const
(
1.0
,
"float64"
))
sb
.
ret
(
x
)
sb
.
ret
(
x
)
xchecked
=
relay
.
ir_pass
.
infer_type
(
sb
.
get
())
xchecked
=
relay
.
ir_pass
.
infer_type
(
sb
.
get
())
assert
xchecked
.
checked_type
==
relay
.
scalar_type
(
"float64"
)
assert
xchecked
.
checked_type
==
relay
.
scalar_type
(
"float64"
)
def
test_single_op
():
def
test_single_op
():
...
@@ -41,14 +41,15 @@ def test_add_broadcast_op():
...
@@ -41,14 +41,15 @@ def test_add_broadcast_op():
return x + y;
return x + y;
}
}
"""
"""
pass
x
=
relay
.
var
(
'x'
,
shape
=
(
10
,
4
))
# x = relay.var('x', shape=(10, 4))
y
=
relay
.
var
(
'y'
,
shape
=
(
5
,
10
,
1
))
# y = relay.var('y', shape=(5, 10, 1))
z
=
x
+
y
# z = x + y
func
=
relay
.
Function
([
x
,
y
],
z
)
# func = relay.Function([x, y], z)
t1
=
relay
.
TensorType
((
10
,
4
),
'float32'
)
# ttype = relay.TensorType((5, 5, 5), 'float32')
t2
=
relay
.
TensorType
((
5
,
10
,
1
),
'float32'
)
# expected_ty = relay.FuncType([ttype, ttype], ttype)
t3
=
relay
.
TensorType
((
5
,
10
,
4
),
'float32'
)
# assert_has_type(func.to_func(), expected_ty)
expected_ty
=
relay
.
FuncType
([
t1
,
t2
],
t3
)
assert_has_type
(
func
,
expected_ty
)
def
test_dual_op
():
def
test_dual_op
():
...
@@ -110,24 +111,17 @@ def test_recursion():
...
@@ -110,24 +111,17 @@ def test_recursion():
assert
"
%3
= @f(
%1
,
%2
)"
in
mod
.
astext
()
assert
"
%3
= @f(
%1
,
%2
)"
in
mod
.
astext
()
assert
mod
[
f
]
.
checked_type
==
relay
.
FuncType
([
ti32
,
tf32
],
tf32
)
assert
mod
[
f
]
.
checked_type
==
relay
.
FuncType
([
ti32
,
tf32
],
tf32
)
# This currently fails and should pass under the type system.
#
# This test is to illustrate problem with our weak form of
# unification.
#
def
test_incomplete_call
():
def
test_incomplete_call
():
sb
=
ScopeBuilder
(
)
tt
=
relay
.
scalar_type
(
'int32'
)
x
=
relay
.
var
(
'x'
,
dtype
=
'int32'
)
x
=
relay
.
var
(
'x'
,
tt
)
f
=
relay
.
var
(
'f'
)
f
=
relay
.
var
(
'f'
)
func
=
relay
.
Function
([
x
,
f
],
relay
.
Call
(
f
,
[
x
]))
func
=
relay
.
Function
([
x
,
f
],
relay
.
Call
(
f
,
[
x
]),
tt
)
ft
=
relay
.
ir_pass
.
infer_type
(
func
)
f_type
=
relay
.
FuncType
([
tt
],
tt
)
assert
ft
.
checked_type
==
relay
.
FuncType
([
tt
,
f_type
],
tt
)
try
:
relay
.
ir_pass
.
infer_type
(
func
)
assert
False
except
tvm
.
TVMError
as
e
:
assert
True
def
test_tuple
():
def
test_tuple
():
tp
=
relay
.
TensorType
((
10
,))
tp
=
relay
.
TensorType
((
10
,))
...
@@ -136,6 +130,7 @@ def test_tuple():
...
@@ -136,6 +130,7 @@ def test_tuple():
assert
(
relay
.
ir_pass
.
infer_type
(
res
)
.
checked_type
==
assert
(
relay
.
ir_pass
.
infer_type
(
res
)
.
checked_type
==
relay
.
TupleType
([
tp
,
tp
]))
relay
.
TupleType
([
tp
,
tp
]))
def
test_free_expr
():
def
test_free_expr
():
x
=
relay
.
var
(
"x"
,
"float32"
)
x
=
relay
.
var
(
"x"
,
"float32"
)
y
=
relay
.
add
(
x
,
x
)
y
=
relay
.
add
(
x
,
x
)
...
@@ -161,38 +156,26 @@ def test_type_args():
...
@@ -161,38 +156,26 @@ def test_type_args():
assert
sh2
[
1
]
.
value
==
10
assert
sh2
[
1
]
.
value
==
10
def
test_self_reference
():
def
test_global_var_recursion
():
"""
Program:
def f(x) {
return x;
}
"""
a
=
relay
.
TypeVar
(
"a"
)
x
=
relay
.
var
(
"x"
,
a
)
sb
=
relay
.
ScopeBuilder
()
f
=
relay
.
Function
([
x
],
x
)
fx
=
relay
.
Call
(
f
,
[
x
])
assert
relay
.
ir_pass
.
infer_type
(
x
)
.
checked_type
==
a
assert
relay
.
ir_pass
.
infer_type
(
f
)
.
checked_type
==
relay
.
FuncType
([
a
],
a
)
assert
relay
.
ir_pass
.
infer_type
(
fx
)
.
checked_type
==
a
def
test_global_var_cow_issue
():
mod
=
relay
.
Module
({})
mod
=
relay
.
Module
({})
gv
=
relay
.
GlobalVar
(
"foo"
)
gv
=
relay
.
GlobalVar
(
"foo"
)
x
=
relay
.
var
(
'x'
,
shape
=
[])
x
=
relay
.
var
(
'x'
,
shape
=
[])
func
=
relay
.
Function
([
x
],
relay
.
Call
(
gv
,
[
x
]),
tt
=
relay
.
scalar_type
(
'float32'
)
relay
.
TensorType
([],
'float32'
))
func
=
relay
.
Function
([
x
],
relay
.
Call
(
gv
,
[
x
]),
tt
)
mod
[
gv
]
=
func
mod
[
gv
]
=
func
ft
=
relay
.
ir_pass
.
infer_type
(
gv
,
mod
)
assert
mod
[
ft
]
.
checked_type
==
relay
.
FuncType
([
tt
],
tt
)
def
test_equal
():
def
test_equal
():
i
=
relay
.
var
(
'i'
,
shape
=
[],
dtype
=
'int32'
)
i
=
relay
.
var
(
'i'
,
shape
=
[],
dtype
=
'int32'
)
eq
=
op
.
equal
(
i
,
relay
.
const
(
0
,
dtype
=
'int32'
))
eq
=
op
.
equal
(
i
,
relay
.
const
(
0
,
dtype
=
'int32'
))
# This should fail ....
func
=
relay
.
Function
([
i
],
eq
)
func
=
relay
.
Function
([
i
],
eq
,
ret_type
=
relay
.
TensorType
([],
'int32'
))
ft
=
relay
.
ir_pass
.
infer_type
(
func
)
assert
ft
.
checked_type
==
relay
.
FuncType
([
relay
.
scalar_type
(
'int32'
)],
relay
.
scalar_type
(
'bool'
))
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
...
@@ -204,8 +187,12 @@ if __name__ == "__main__":
...
@@ -204,8 +187,12 @@ if __name__ == "__main__":
test_decl
()
test_decl
()
test_recursion
()
test_recursion
()
test_tuple
()
test_tuple
()
test_generalized_tuple
()
test_incomplete_call
()
test_incomplete_call
()
test_generalized_call
()
test_call_with_type_args
()
test_free_expr
()
test_free_expr
()
test_type_args
()
test_type_args
()
test_self_reference
()
test_self_reference
()
test_global_var_cow_issue
()
test_global_var_recursion
()
test_equal
()
tests/python/relay/test_type_solver.py
View file @
6783d373
import
tvm
import
tvm
from
tvm
import
relay
from
tvm
import
relay
from
nose.tools
import
raises
def
make_rel
(
name
,
args
,
num_inputs
=
None
,
attrs
=
None
):
def
make_rel
(
name
,
args
,
num_inputs
=
None
,
attrs
=
None
):
...
@@ -48,7 +49,170 @@ def test_backward_solving():
...
@@ -48,7 +49,170 @@ def test_backward_solving():
assert
solver
.
Resolve
(
t3
)
==
relay
.
ty
.
TensorType
((
10
,
10
,
20
),
"float32"
)
assert
solver
.
Resolve
(
t3
)
==
relay
.
ty
.
TensorType
((
10
,
10
,
20
),
"float32"
)
def
test_unify_tuple
():
solver
=
make_solver
()
t1
=
relay
.
ty
.
IncompleteType
()
t2
=
relay
.
ty
.
IncompleteType
()
t3
=
relay
.
ty
.
TensorType
((
10
,
20
),
"float32"
)
tup1
=
relay
.
ty
.
TupleType
([
t1
,
t2
])
tup2
=
relay
.
ty
.
TupleType
([
t3
,
t3
])
unified
=
solver
.
Unify
(
tup1
,
tup2
)
assert
unified
==
tup2
def
test_unify_functype
():
solver
=
make_solver
()
t1
=
relay
.
ty
.
IncompleteType
()
t2
=
relay
.
ty
.
IncompleteType
()
t3
=
relay
.
ty
.
IncompleteType
()
unit
=
relay
.
ty
.
TupleType
([])
tensor1
=
relay
.
ty
.
TensorType
((
10
,
20
),
"float32"
)
tensor2
=
relay
.
ty
.
TensorType
((
10
,),
"float32"
)
ft1
=
relay
.
ty
.
FuncType
([
t1
,
t2
],
t3
)
ft2
=
relay
.
ty
.
FuncType
([
tensor1
,
tensor2
],
unit
)
unified
=
solver
.
Unify
(
ft1
,
ft2
)
assert
unified
==
ft2
def
test_recursive_unify
():
solver
=
make_solver
()
t1
=
relay
.
ty
.
IncompleteType
()
t2
=
relay
.
ty
.
IncompleteType
()
t3
=
relay
.
ty
.
IncompleteType
()
tensor1
=
relay
.
ty
.
TensorType
((
10
,
10
,
20
),
"float32"
)
tensor2
=
relay
.
ty
.
TensorType
((
10
,
20
),
"float32"
)
tensor3
=
relay
.
ty
.
TensorType
((
10
,),
"float32"
)
tup1
=
relay
.
ty
.
TupleType
([
relay
.
ty
.
TupleType
([
t1
,
t2
]),
t2
])
tup2
=
relay
.
ty
.
TupleType
([
relay
.
ty
.
TupleType
([
tensor1
,
tensor2
]),
tensor2
])
ft1
=
relay
.
ty
.
FuncType
([
tup1
,
t3
],
t3
)
ft2
=
relay
.
ty
.
FuncType
([
tup2
,
tensor3
],
tensor3
)
unified
=
solver
.
Unify
(
ft1
,
ft2
)
assert
unified
==
ft2
def
test_unify_vars_under_tuples
():
solver
=
make_solver
()
t1
=
relay
.
ty
.
IncompleteType
()
tup1
=
relay
.
ty
.
TupleType
([
t1
,
t1
])
unified
=
solver
.
Unify
(
tup1
,
tup1
)
assert
unified
==
tup1
t2
=
relay
.
ty
.
IncompleteType
()
tup2
=
relay
.
ty
.
TupleType
([
t2
,
t2
])
tup3
=
relay
.
ty
.
TupleType
([
t1
,
t2
])
tup4
=
relay
.
ty
.
TupleType
([
t2
,
t1
])
unified
=
solver
.
Unify
(
tup3
,
tup4
)
assert
(
unified
==
tup1
or
unified
==
tup2
)
def
test_binding_over_typevars
():
solver
=
make_solver
()
t1
=
relay
.
ty
.
IncompleteType
()
t2
=
relay
.
ty
.
IncompleteType
()
a
=
relay
.
ty
.
TypeVar
(
'a'
)
b
=
relay
.
ty
.
TypeVar
(
'b'
)
c
=
relay
.
ty
.
TypeVar
(
'c'
)
d
=
relay
.
ty
.
TypeVar
(
'd'
)
ft1
=
relay
.
ty
.
FuncType
([
t1
],
t2
,
[
c
,
d
])
ft2
=
relay
.
ty
.
FuncType
([
a
],
b
,
[
a
,
b
])
unified
=
solver
.
Unify
(
ft1
,
ft2
)
assert
(
unified
==
solver
.
Resolve
(
ft1
))
def
test_recursive_backward_solving
():
solver
=
make_solver
()
tensor1
=
relay
.
ty
.
TensorType
((
10
,
20
),
"float32"
)
tensor2
=
relay
.
ty
.
TensorType
((
10
,
1
,
1
),
"float32"
)
tensor3
=
relay
.
ty
.
TensorType
((
10
,),
"float32"
)
t1
=
relay
.
ty
.
IncompleteType
()
t2
=
relay
.
ty
.
IncompleteType
()
t3
=
relay
.
ty
.
IncompleteType
()
tup1
=
relay
.
ty
.
TupleType
([
relay
.
ty
.
TupleType
([
tensor1
,
tensor2
]),
tensor3
])
tup2
=
relay
.
ty
.
TupleType
([
relay
.
ty
.
TupleType
([
t1
,
t2
]),
t3
])
solver
.
gen_type
(
"Identity"
,
[
tup1
],
out
=
tup2
)
assert
solver
.
Solve
()
assert
solver
.
Resolve
(
tup2
)
==
tup1
def
test_backward_solving_after_child_update
():
solver
=
make_solver
()
tensor1
=
relay
.
ty
.
TensorType
((
10
,
20
),
"float32"
)
tensor2
=
relay
.
ty
.
TensorType
((
10
,
1
,
1
),
"float32"
)
t1
=
relay
.
ty
.
IncompleteType
()
t2
=
relay
.
ty
.
IncompleteType
()
t3
=
relay
.
ty
.
IncompleteType
()
tup1
=
relay
.
ty
.
TupleType
([
t1
,
t2
])
tup2
=
relay
.
ty
.
TupleType
([
t1
,
t3
])
tup_concrete
=
relay
.
ty
.
TupleType
([
tensor1
,
tensor2
])
t4
=
solver
.
gen_type
(
"Identity"
,
[
tup1
])
t5
=
solver
.
gen_type
(
"Identity"
,
[
tup2
])
solver
.
gen_type
(
"Identity"
,
[
t4
],
out
=
t5
)
assert
solver
.
Solve
()
assert
solver
.
Resolve
(
t3
)
==
t3
or
solver
.
Resolve
(
t3
)
==
t2
assert
solver
.
Resolve
(
t4
)
==
tup1
or
solver
.
Resolve
(
t4
)
==
tup2
assert
solver
.
Resolve
(
t5
)
==
tup1
or
solver
.
Resolve
(
t5
)
==
tup2
# updating the variables *inside* tup1 and tup2 should update t4 and t5
solver
.
gen_type
(
"Identity"
,
[
t1
],
out
=
tensor1
)
solver
.
gen_type
(
"Identity"
,
[
t2
],
out
=
tensor2
)
assert
solver
.
Solve
()
assert
solver
.
Resolve
(
t4
)
==
tup_concrete
assert
solver
.
Resolve
(
t5
)
==
tup_concrete
@raises
(
tvm
.
_ffi
.
base
.
TVMError
)
def
test_incompatible_tuple_unification
():
solver
=
make_solver
()
t1
=
relay
.
ty
.
IncompleteType
()
t2
=
relay
.
ty
.
IncompleteType
()
tensor1
=
relay
.
ty
.
TensorType
((
1
,
2
,
3
),
"float32"
)
tensor2
=
relay
.
ty
.
TensorType
((
2
,
3
),
"float32"
)
tensor3
=
relay
.
ty
.
TensorType
((
3
,),
"float32"
)
tup1
=
relay
.
ty
.
TupleType
([
relay
.
ty
.
TupleType
([
t1
,
t1
]),
t2
])
tup2
=
relay
.
ty
.
TupleType
([
relay
.
ty
.
TupleType
([
tensor1
,
tensor2
]),
tensor3
])
solver
.
Unify
(
tup1
,
tup2
)
@raises
(
tvm
.
_ffi
.
base
.
TVMError
)
def
test_bad_recursive_unification
():
solver
=
make_solver
()
t1
=
relay
.
ty
.
IncompleteType
()
solver
.
Unify
(
t1
,
relay
.
ty
.
TupleType
([
t1
,
t1
]))
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
test_bcast
()
test_bcast
()
test_backward_solving
()
test_backward_solving
()
test_unify_tuple
()
test_unify_functype
()
test_recursive_unify
()
test_unify_vars_under_tuples
()
test_recursive_backward_solving
()
test_backward_solving_after_child_update
()
test_incompatible_tuple_unification
()
test_bad_recursive_unification
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment