Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
54dbcc28
Commit
54dbcc28
authored
Sep 10, 2019
by
雾雨魔理沙
Committed by
Wuwei Lin
Sep 10, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[Relay] fix exponential blowup in interpreter (#3559)
parent
5bff6cce
Show whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
73 additions
and
30 deletions
+73
-30
include/tvm/relay/feature.h
+2
-2
python/tvm/relay/backend/interpreter.py
+1
-0
src/relay/backend/interpreter.cc
+3
-0
src/relay/ir/alpha_equal.cc
+1
-1
src/relay/pass/feature.cc
+14
-6
src/relay/pass/type_infer.cc
+1
-12
src/relay/pass/type_solver.cc
+22
-8
tests/python/relay/test_feature.py
+2
-1
tests/python/relay/test_pass_to_cps.py
+14
-0
tests/python/relay/test_type_infer.py
+13
-0
No files found.
include/tvm/relay/feature.h
View file @
54dbcc28
...
@@ -81,13 +81,13 @@ class FeatureSet {
...
@@ -81,13 +81,13 @@ class FeatureSet {
return
ret
;
return
ret
;
}
}
/*! \brief A set that contain all the Feature. */
/*! \brief A set that contain all the Feature. */
static
FeatureSet
All
Feature
()
{
static
FeatureSet
All
()
{
FeatureSet
fs
;
FeatureSet
fs
;
fs
.
bs_
.
flip
();
fs
.
bs_
.
flip
();
return
fs
;
return
fs
;
}
}
/*! \brief The empty set. Contain no Feature. */
/*! \brief The empty set. Contain no Feature. */
static
FeatureSet
No
Feature
()
{
static
FeatureSet
No
()
{
FeatureSet
fs
;
FeatureSet
fs
;
return
fs
;
return
fs
;
}
}
...
...
python/tvm/relay/backend/interpreter.py
View file @
54dbcc28
...
@@ -280,6 +280,7 @@ class Interpreter(Executor):
...
@@ -280,6 +280,7 @@ class Interpreter(Executor):
"""
"""
seq
=
transform
.
Sequential
([
transform
.
SimplifyInference
(),
seq
=
transform
.
Sequential
([
transform
.
SimplifyInference
(),
transform
.
FuseOps
(
0
),
transform
.
FuseOps
(
0
),
transform
.
ToANormalForm
(),
transform
.
InferType
()])
transform
.
InferType
()])
return
seq
(
self
.
mod
)
return
seq
(
self
.
mod
)
...
...
src/relay/backend/interpreter.cc
View file @
54dbcc28
...
@@ -29,6 +29,7 @@
...
@@ -29,6 +29,7 @@
#include <tvm/relay/interpreter.h>
#include <tvm/relay/interpreter.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/attrs/debug.h>
#include <tvm/relay/attrs/debug.h>
#include <tvm/relay/feature.h>
#include "compile_engine.h"
#include "compile_engine.h"
namespace
tvm
{
namespace
tvm
{
...
@@ -761,6 +762,8 @@ CreateInterpreter(
...
@@ -761,6 +762,8 @@ CreateInterpreter(
Target
target
)
{
Target
target
)
{
auto
intrp
=
std
::
make_shared
<
Interpreter
>
(
mod
,
context
,
target
);
auto
intrp
=
std
::
make_shared
<
Interpreter
>
(
mod
,
context
,
target
);
auto
packed
=
[
intrp
](
Expr
expr
)
{
auto
packed
=
[
intrp
](
Expr
expr
)
{
auto
f
=
DetectFeature
(
expr
);
CHECK
(
f
.
is_subset_of
(
FeatureSet
::
All
()
-
fGraph
));
return
intrp
->
Eval
(
expr
);
return
intrp
->
Eval
(
expr
);
};
};
return
TypedPackedFunc
<
Value
(
Expr
)
>
(
packed
);
return
TypedPackedFunc
<
Value
(
Expr
)
>
(
packed
);
...
...
src/relay/ir/alpha_equal.cc
View file @
54dbcc28
...
@@ -120,7 +120,7 @@ class AlphaEqualHandler:
...
@@ -120,7 +120,7 @@ class AlphaEqualHandler:
* \return the comparison result.
* \return the comparison result.
*/
*/
bool
TypeEqual
(
const
Type
&
lhs
,
const
Type
&
rhs
)
{
bool
TypeEqual
(
const
Type
&
lhs
,
const
Type
&
rhs
)
{
auto
compute
=
[
&
](){
auto
compute
=
[
&
]()
{
if
(
lhs
.
same_as
(
rhs
))
return
true
;
if
(
lhs
.
same_as
(
rhs
))
return
true
;
if
(
!
lhs
.
defined
()
||
!
rhs
.
defined
())
return
false
;
if
(
!
lhs
.
defined
()
||
!
rhs
.
defined
())
return
false
;
return
this
->
VisitType
(
lhs
,
rhs
);
return
this
->
VisitType
(
lhs
,
rhs
);
...
...
src/relay/pass/feature.cc
View file @
54dbcc28
...
@@ -34,13 +34,15 @@ namespace relay {
...
@@ -34,13 +34,15 @@ namespace relay {
FeatureSet
DetectFeature
(
const
Expr
&
expr
)
{
FeatureSet
DetectFeature
(
const
Expr
&
expr
)
{
if
(
!
expr
.
defined
())
{
if
(
!
expr
.
defined
())
{
return
FeatureSet
::
No
Feature
();
return
FeatureSet
::
No
();
}
}
struct
FeatureDetector
:
ExprVisitor
{
struct
FeatureDetector
:
ExprVisitor
{
std
::
unordered_set
<
Expr
,
NodeHash
,
NodeEqual
>
visited_
;
std
::
unordered_set
<
Expr
,
NodeHash
,
NodeEqual
>
visited_
;
FeatureSet
fs
=
FeatureSet
::
NoFeature
();
FeatureSet
fs
=
FeatureSet
::
No
();
void
VisitExpr
(
const
Expr
&
expr
)
final
{
void
VisitExpr
(
const
Expr
&
expr
)
final
{
if
(
visited_
.
count
(
expr
)
==
0
)
{
if
(
visited_
.
count
(
expr
)
==
0
)
{
visited_
.
insert
(
expr
);
ExprVisitor
::
VisitExpr
(
expr
);
ExprVisitor
::
VisitExpr
(
expr
);
}
else
{
}
else
{
if
(
!
IsAtomic
(
expr
))
{
if
(
!
IsAtomic
(
expr
))
{
...
@@ -52,15 +54,20 @@ FeatureSet DetectFeature(const Expr& expr) {
...
@@ -52,15 +54,20 @@ FeatureSet DetectFeature(const Expr& expr) {
void VisitExpr_(const CONSTRUCT_NAME##Node* op) final { \
void VisitExpr_(const CONSTRUCT_NAME##Node* op) final { \
STMT \
STMT \
fs += f##CONSTRUCT_NAME; \
fs += f##CONSTRUCT_NAME; \
ExprVisitor::VisitExpr_(op); \
}
}
#define DETECT_DEFAULT_CONSTRUCT(CONSTRUCT_NAME) DETECT_CONSTRUCT(CONSTRUCT_NAME, {})
#define DETECT_DEFAULT_CONSTRUCT(CONSTRUCT_NAME) DETECT_CONSTRUCT(CONSTRUCT_NAME, { \
ExprVisitor::VisitExpr_(op); \
})
DETECT_DEFAULT_CONSTRUCT
(
Var
)
DETECT_DEFAULT_CONSTRUCT
(
Var
)
DETECT_DEFAULT_CONSTRUCT
(
GlobalVar
)
DETECT_DEFAULT_CONSTRUCT
(
GlobalVar
)
DETECT_DEFAULT_CONSTRUCT
(
Constant
)
DETECT_DEFAULT_CONSTRUCT
(
Constant
)
DETECT_DEFAULT_CONSTRUCT
(
Tuple
)
DETECT_DEFAULT_CONSTRUCT
(
Tuple
)
DETECT_DEFAULT_CONSTRUCT
(
TupleGetItem
)
DETECT_DEFAULT_CONSTRUCT
(
TupleGetItem
)
DETECT_DEFAULT_CONSTRUCT
(
Function
)
DETECT_CONSTRUCT
(
Function
,
{
if
(
!
op
->
IsPrimitive
())
{
ExprVisitor
::
VisitExpr_
(
op
);
}
})
DETECT_DEFAULT_CONSTRUCT
(
Op
)
DETECT_DEFAULT_CONSTRUCT
(
Op
)
DETECT_DEFAULT_CONSTRUCT
(
Call
)
DETECT_DEFAULT_CONSTRUCT
(
Call
)
DETECT_CONSTRUCT
(
Let
,
{
DETECT_CONSTRUCT
(
Let
,
{
...
@@ -69,6 +76,7 @@ FeatureSet DetectFeature(const Expr& expr) {
...
@@ -69,6 +76,7 @@ FeatureSet DetectFeature(const Expr& expr) {
fs
+=
fLetRec
;
fs
+=
fLetRec
;
}
}
}
}
ExprVisitor
::
VisitExpr_
(
op
);
})
})
DETECT_DEFAULT_CONSTRUCT
(
If
)
DETECT_DEFAULT_CONSTRUCT
(
If
)
DETECT_DEFAULT_CONSTRUCT
(
RefCreate
)
DETECT_DEFAULT_CONSTRUCT
(
RefCreate
)
...
@@ -83,7 +91,7 @@ FeatureSet DetectFeature(const Expr& expr) {
...
@@ -83,7 +91,7 @@ FeatureSet DetectFeature(const Expr& expr) {
}
}
FeatureSet
DetectFeature
(
const
Module
&
mod
)
{
FeatureSet
DetectFeature
(
const
Module
&
mod
)
{
FeatureSet
fs
=
FeatureSet
::
No
Feature
();
FeatureSet
fs
=
FeatureSet
::
No
();
if
(
mod
.
defined
())
{
if
(
mod
.
defined
())
{
for
(
const
auto
&
f
:
mod
->
functions
)
{
for
(
const
auto
&
f
:
mod
->
functions
)
{
fs
+=
DetectFeature
(
f
.
second
);
fs
+=
DetectFeature
(
f
.
second
);
...
...
src/relay/pass/type_infer.cc
View file @
54dbcc28
...
@@ -139,19 +139,8 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
...
@@ -139,19 +139,8 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
// Perform unification on two types and report the error at the expression
// Perform unification on two types and report the error at the expression
// or the span of the expression.
// or the span of the expression.
Type
Unify
(
const
Type
&
t1
,
const
Type
&
t2
,
const
NodeRef
&
expr
)
{
Type
Unify
(
const
Type
&
t1
,
const
Type
&
t2
,
const
NodeRef
&
expr
)
{
// TODO(tqchen, jroesch): propagate span to solver
try
{
try
{
// instantiate higher-order func types when unifying because
return
solver_
.
Unify
(
t1
,
t2
,
expr
);
// we only allow polymorphism at the top level
Type
first
=
t1
;
Type
second
=
t2
;
if
(
auto
*
ft1
=
t1
.
as
<
FuncTypeNode
>
())
{
first
=
InstantiateFuncType
(
ft1
);
}
if
(
auto
*
ft2
=
t2
.
as
<
FuncTypeNode
>
())
{
second
=
InstantiateFuncType
(
ft2
);
}
return
solver_
.
Unify
(
first
,
second
,
expr
);
}
catch
(
const
dmlc
::
Error
&
e
)
{
}
catch
(
const
dmlc
::
Error
&
e
)
{
this
->
ReportFatalError
(
this
->
ReportFatalError
(
expr
,
expr
,
...
...
src/relay/pass/type_solver.cc
View file @
54dbcc28
...
@@ -289,30 +289,44 @@ class TypeSolver::Unifier : public TypeFunctor<Type(const Type&, const Type&)> {
...
@@ -289,30 +289,44 @@ class TypeSolver::Unifier : public TypeFunctor<Type(const Type&, const Type&)> {
const
auto
*
ftn
=
tn
.
as
<
FuncTypeNode
>
();
const
auto
*
ftn
=
tn
.
as
<
FuncTypeNode
>
();
if
(
!
ftn
if
(
!
ftn
||
op
->
arg_types
.
size
()
!=
ftn
->
arg_types
.
size
()
||
op
->
arg_types
.
size
()
!=
ftn
->
arg_types
.
size
()
||
op
->
type_params
.
size
()
!=
ftn
->
type_params
.
size
()
||
op
->
type_constraints
.
size
()
!=
ftn
->
type_constraints
.
size
())
{
||
op
->
type_constraints
.
size
()
!=
ftn
->
type_constraints
.
size
())
{
return
Type
(
nullptr
);
return
Type
(
nullptr
);
}
}
// without loss of generality, suppose op->type_params.size() >= ftn->type_params.size().
if
(
op
->
type_params
.
size
()
<
ftn
->
type_params
.
size
())
{
return
VisitType_
(
ftn
,
GetRef
<
FuncType
>
(
op
));
}
// remap type vars so they match
// remap type vars so they match
Map
<
TypeVar
,
Type
>
subst_map
;
Map
<
TypeVar
,
Type
>
subst_map
;
for
(
size_t
i
=
0
;
i
<
op
->
type_params
.
size
();
i
++
)
{
tvm
::
Array
<
TypeVar
>
ft_type_params
;
subst_map
.
Set
(
ftn
->
type_params
[
i
],
op
->
type_params
[
i
]);
for
(
size_t
i
=
0
;
i
<
ftn
->
type_params
.
size
();
++
i
)
{
subst_map
.
Set
(
op
->
type_params
[
i
],
ftn
->
type_params
[
i
]);
ft_type_params
.
push_back
(
op
->
type_params
[
i
]);
}
for
(
size_t
i
=
ftn
->
type_params
.
size
();
i
<
op
->
type_params
.
size
();
++
i
)
{
subst_map
.
Set
(
op
->
type_params
[
i
],
IncompleteTypeNode
::
make
(
kType
));
}
}
auto
ft1
=
GetRef
<
FuncType
>
(
op
);
FuncType
ft
=
FuncTypeNode
::
make
(
op
->
arg_types
,
auto
ft2
=
Downcast
<
FuncType
>
(
Bind
(
GetRef
<
FuncType
>
(
ftn
),
subst_map
));
op
->
ret_type
,
ft_type_params
,
op
->
type_constraints
);
auto
ft1
=
Downcast
<
FuncType
>
(
Bind
(
ft
,
subst_map
));
auto
ft2
=
GetRef
<
FuncType
>
(
ftn
);
Type
ret_type
=
Unify
(
ft1
->
ret_type
,
ft2
->
ret_type
);
Type
ret_type
=
Unify
(
ft1
->
ret_type
,
ft2
->
ret_type
);
std
::
vector
<
Type
>
arg_types
;
std
::
vector
<
Type
>
arg_types
;
for
(
size_t
i
=
0
;
i
<
ft
1
->
arg_types
.
size
();
i
++
)
{
for
(
size_t
i
=
0
;
i
<
ft
2
->
arg_types
.
size
();
++
i
)
{
Type
arg_type
=
Unify
(
ft1
->
arg_types
[
i
],
ft2
->
arg_types
[
i
]);
Type
arg_type
=
Unify
(
ft1
->
arg_types
[
i
],
ft2
->
arg_types
[
i
]);
arg_types
.
push_back
(
arg_type
);
arg_types
.
push_back
(
arg_type
);
}
}
std
::
vector
<
TypeConstraint
>
type_constraints
;
std
::
vector
<
TypeConstraint
>
type_constraints
;
for
(
size_t
i
=
0
;
i
<
ft1
->
type_constraints
.
size
();
i
++
)
{
for
(
size_t
i
=
0
;
i
<
ft1
->
type_constraints
.
size
();
++
i
)
{
Type
unified_constraint
=
Unify
(
ft1
->
type_constraints
[
i
],
Type
unified_constraint
=
Unify
(
ft1
->
type_constraints
[
i
],
ft2
->
type_constraints
[
i
]);
ft2
->
type_constraints
[
i
]);
const
auto
*
tcn
=
unified_constraint
.
as
<
TypeConstraintNode
>
();
const
auto
*
tcn
=
unified_constraint
.
as
<
TypeConstraintNode
>
();
...
@@ -321,7 +335,7 @@ class TypeSolver::Unifier : public TypeFunctor<Type(const Type&, const Type&)> {
...
@@ -321,7 +335,7 @@ class TypeSolver::Unifier : public TypeFunctor<Type(const Type&, const Type&)> {
type_constraints
.
push_back
(
GetRef
<
TypeConstraint
>
(
tcn
));
type_constraints
.
push_back
(
GetRef
<
TypeConstraint
>
(
tcn
));
}
}
return
FuncTypeNode
::
make
(
arg_types
,
ret_type
,
ft
1
->
type_params
,
type_constraints
);
return
FuncTypeNode
::
make
(
arg_types
,
ret_type
,
ft
2
->
type_params
,
type_constraints
);
}
}
Type
VisitType_
(
const
RefTypeNode
*
op
,
const
Type
&
tn
)
final
{
Type
VisitType_
(
const
RefTypeNode
*
op
,
const
Type
&
tn
)
final
{
...
...
tests/python/relay/test_feature.py
View file @
54dbcc28
...
@@ -63,7 +63,8 @@ def test_ad():
...
@@ -63,7 +63,8 @@ def test_ad():
Feature
.
fLet
,
Feature
.
fLet
,
Feature
.
fRefCreate
,
Feature
.
fRefCreate
,
Feature
.
fRefRead
,
Feature
.
fRefRead
,
Feature
.
fRefWrite
Feature
.
fRefWrite
,
Feature
.
fGraph
])
])
...
...
tests/python/relay/test_pass_to_cps.py
View file @
54dbcc28
...
@@ -30,6 +30,20 @@ def rand(dtype='float32', *shape):
...
@@ -30,6 +30,20 @@ def rand(dtype='float32', *shape):
return
tvm
.
nd
.
array
(
np
.
random
.
rand
(
*
shape
)
.
astype
(
dtype
))
return
tvm
.
nd
.
array
(
np
.
random
.
rand
(
*
shape
)
.
astype
(
dtype
))
def
test_id
():
x
=
relay
.
var
(
"x"
,
shape
=
[])
id
=
run_infer_type
(
relay
.
Function
([
x
],
x
))
id_cps
=
run_infer_type
(
to_cps
(
id
))
def
test_double
():
t
=
relay
.
TypeVar
(
"t"
)
x
=
relay
.
var
(
"x"
,
t
)
f
=
relay
.
var
(
"f"
,
relay
.
FuncType
([
t
],
t
))
double
=
run_infer_type
(
relay
.
Function
([
f
,
x
],
f
(
f
(
x
)),
t
,
[
t
]))
double_cps
=
run_infer_type
(
to_cps
(
double
))
# make sure cps work for recursion.
# make sure cps work for recursion.
def
test_recursion
():
def
test_recursion
():
mod
=
relay
.
Module
()
mod
=
relay
.
Module
()
...
...
tests/python/relay/test_type_infer.py
View file @
54dbcc28
...
@@ -19,6 +19,7 @@
...
@@ -19,6 +19,7 @@
"""
"""
from
tvm
import
relay
from
tvm
import
relay
from
tvm.relay
import
op
,
transform
,
analysis
from
tvm.relay
import
op
,
transform
,
analysis
from
tvm.relay.analysis
import
assert_alpha_equal
def
run_infer_type
(
expr
,
mod
=
None
):
def
run_infer_type
(
expr
,
mod
=
None
):
...
@@ -349,6 +350,17 @@ def test_adt_match_type_annotations():
...
@@ -349,6 +350,17 @@ def test_adt_match_type_annotations():
assert
ft
.
checked_type
==
relay
.
FuncType
([
tt
],
relay
.
TupleType
([]))
assert
ft
.
checked_type
==
relay
.
FuncType
([
tt
],
relay
.
TupleType
([]))
def
test_let_polymorphism
():
id
=
relay
.
Var
(
"id"
)
xt
=
relay
.
TypeVar
(
"xt"
)
x
=
relay
.
Var
(
"x"
,
xt
)
body
=
relay
.
Tuple
([
id
(
relay
.
const
(
1
)),
id
(
relay
.
Tuple
([]))])
body
=
relay
.
Let
(
id
,
relay
.
Function
([
x
],
x
,
xt
,
[
xt
]),
body
)
body
=
run_infer_type
(
body
)
int32
=
relay
.
TensorType
((),
"int32"
)
assert_alpha_equal
(
body
.
checked_type
,
relay
.
TupleType
([
int32
,
relay
.
TupleType
([])]))
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
test_free_expr
()
test_free_expr
()
test_dual_op
()
test_dual_op
()
...
@@ -366,3 +378,4 @@ if __name__ == "__main__":
...
@@ -366,3 +378,4 @@ if __name__ == "__main__":
test_constructor_type
()
test_constructor_type
()
test_constructor_call
()
test_constructor_call
()
test_adt_match
()
test_adt_match
()
test_let_polymorphism
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment