Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
b3f3ab55
Commit
b3f3ab55
authored
Jul 03, 2019
by
雾雨魔理沙
Committed by
Tianqi Chen
Jul 03, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[Relay] Fix PE (#3482)
parent
287078c3
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
101 additions
and
67 deletions
+101
-67
include/tvm/relay/module.h
+1
-1
src/relay/ir/expr_functor.cc
+20
-6
src/relay/ir/type_functor.cc
+5
-1
src/relay/ir/type_functor.h
+1
-0
src/relay/pass/let_list.h
+1
-1
src/relay/pass/partial_eval.cc
+60
-56
src/relay/pass/type_infer.cc
+1
-0
src/relay/pass/util.cc
+10
-0
tests/python/relay/test_pass_partial_eval.py
+2
-2
No files found.
include/tvm/relay/module.h
View file @
b3f3ab55
...
...
@@ -55,7 +55,7 @@ struct Module;
* The functional style allows users to construct custom
* environments easily, for example each thread can store
* a Module while auto-tuning.
*
*
/
*/
class
ModuleNode
:
public
RelayNode
{
public
:
...
...
src/relay/ir/expr_functor.cc
View file @
b3f3ab55
...
...
@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...
...
@@ -18,7 +18,7 @@
*/
/*!
* Copyright (c) 201
8
by Contributors
* Copyright (c) 201
9
by Contributors
* \file src/tvm/relay/expr_mutator.cc
* \brief A wrapper around ExprFunctor which functionally updates the AST.
*
...
...
@@ -26,6 +26,7 @@
* the cost of using functional updates.
*/
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/pattern_functor.h>
#include "type_functor.h"
namespace
tvm
{
...
...
@@ -353,7 +354,7 @@ TVM_REGISTER_API("relay._analysis.post_order_visit")
});
// Implement bind.
class
ExprBinder
:
public
ExprMutator
{
class
ExprBinder
:
public
ExprMutator
,
PatternMutator
{
public
:
explicit
ExprBinder
(
const
tvm
::
Map
<
Var
,
Expr
>&
args_map
)
:
args_map_
(
args_map
)
{
...
...
@@ -383,13 +384,26 @@ class ExprBinder : public ExprMutator {
}
}
Pattern
VisitPattern
(
const
Pattern
&
p
)
final
{
return
PatternMutator
::
VisitPattern
(
p
);
}
Clause
VisitClause
(
const
Clause
&
c
)
final
{
Pattern
pat
=
VisitPattern
(
c
->
lhs
);
return
ClauseNode
::
make
(
pat
,
VisitExpr
(
c
->
rhs
));
}
Var
VisitVar
(
const
Var
&
v
)
final
{
return
Downcast
<
Var
>
(
VisitExpr
(
v
));
}
private
:
const
tvm
::
Map
<
Var
,
Expr
>&
args_map_
;
};
Expr
Bind
(
const
Expr
&
expr
,
const
tvm
::
Map
<
Var
,
Expr
>&
args_map
)
{
if
(
const
FunctionNode
*
func
=
expr
.
as
<
FunctionNode
>
())
{
Expr
new_body
=
ExprBinder
(
args_map
).
Mutate
(
func
->
body
);
Expr
new_body
=
ExprBinder
(
args_map
).
VisitExpr
(
func
->
body
);
Array
<
Var
>
new_params
;
for
(
Var
param
:
func
->
params
)
{
if
(
!
args_map
.
count
(
param
))
{
...
...
@@ -406,7 +420,7 @@ Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& args_map) {
func
->
type_params
,
func
->
attrs
);
}
else
{
return
ExprBinder
(
args_map
).
Mutate
(
expr
);
return
ExprBinder
(
args_map
).
VisitExpr
(
expr
);
}
}
...
...
src/relay/ir/type_functor.cc
View file @
b3f3ab55
...
...
@@ -92,6 +92,10 @@ void TypeVisitor::VisitType_(const TypeDataNode* op) {
}
}
Type
TypeMutator
::
VisitType
(
const
Type
&
t
)
{
return
t
.
defined
()
?
TypeFunctor
<
Type
(
const
Type
&
)
>::
VisitType
(
t
)
:
t
;
}
// Type Mutator.
Array
<
Type
>
TypeMutator
::
MutateArray
(
Array
<
Type
>
arr
)
{
// The array will do copy on write
...
...
@@ -221,7 +225,7 @@ class TypeBinder : public TypeMutator {
};
Type
Bind
(
const
Type
&
type
,
const
tvm
::
Map
<
TypeVar
,
Type
>&
args_map
)
{
return
type
.
defined
()
?
TypeBinder
(
args_map
).
VisitType
(
type
)
:
type
;
return
TypeBinder
(
args_map
).
VisitType
(
type
)
;
}
}
// namespace relay
...
...
src/relay/ir/type_functor.h
View file @
b3f3ab55
...
...
@@ -139,6 +139,7 @@ class TypeVisitor : public TypeFunctor<void(const Type& n)> {
// Mutator that transform a type to another one.
class
TypeMutator
:
public
TypeFunctor
<
Type
(
const
Type
&
n
)
>
{
public
:
Type
VisitType
(
const
Type
&
t
)
override
;
Type
VisitType_
(
const
TypeVarNode
*
op
)
override
;
Type
VisitType_
(
const
TensorTypeNode
*
op
)
override
;
Type
VisitType_
(
const
IncompleteTypeNode
*
op
)
override
;
...
...
src/relay/pass/let_list.h
View file @
b3f3ab55
...
...
@@ -48,7 +48,7 @@ class LetList {
public
:
~
LetList
()
{
if
(
lets_
.
size
()
>
0
&&
!
used_
)
{
std
::
cout
<<
"Warning: letlist not used"
<<
std
::
endl
;
LOG
(
WARNING
)
<<
"letlist not used"
;
}
}
/*!
...
...
src/relay/pass/partial_eval.cc
View file @
b3f3ab55
...
...
@@ -64,7 +64,7 @@
* 3: The generated code reuses bindings (although they are not shadowed),
* so we have to deduplicate them.
*
* 4: In the generated code, multiple VarNode might have same Id.
* 4: In the generated code,
as it call TypeSubst,
multiple VarNode might have same Id.
* While it is permitted, most pass use NodeHash for Var,
* and having multiple VarNode for same Id break them.
* Thus we remap them to a single Id for now.
...
...
@@ -216,9 +216,9 @@ Static MkSRef() {
}
using
Func
=
std
::
function
<
PStatic
(
const
std
::
vector
<
PStatic
>&
,
const
Attrs
&
,
const
Array
<
Type
>&
,
LetList
*
)
>
;
const
Attrs
&
,
const
Array
<
Type
>&
,
LetList
*
)
>
;
struct
SFuncNode
:
StaticNode
{
Func
func
;
...
...
@@ -256,6 +256,7 @@ class Environment {
void
Insert
(
const
Var
&
v
,
const
PStatic
&
ps
)
{
CHECK
(
ps
.
defined
());
CHECK_EQ
(
env_
.
back
().
locals
.
count
(
v
),
0
);
env_
.
back
().
locals
[
v
]
=
ps
;
}
...
...
@@ -287,12 +288,17 @@ class Environment {
/*!
* \brief As our store require rollback, we implement it as a frame.
* every time we need to copy the store, a new frame is insert.
* every time we roll back, a frame is popped.
*
* Every time we need to copy the store, a new frame is insert.
* Every time we roll back, a frame is popped.
*/
struct
StoreFrame
{
std
::
unordered_map
<
const
SRefNode
*
,
PStatic
>
store
;
/*! \brief on unknown effect, history_valid is set to true to signal above frame is outdated */
/*!
* \brief On unknown effect, history_valid is set to true to signal above frame is outdated.
*
* It only outdate the frame above it, but not the current frame.
*/
bool
history_valid
=
true
;
explicit
StoreFrame
(
const
std
::
unordered_map
<
const
SRefNode
*
,
PStatic
>&
store
)
:
store
(
store
)
{
}
StoreFrame
()
=
default
;
...
...
@@ -310,6 +316,7 @@ class Store {
}
void
Insert
(
const
SRefNode
*
r
,
const
PStatic
&
ps
)
{
CHECK
(
r
);
store_
.
back
().
store
[
r
]
=
ps
;
}
...
...
@@ -317,19 +324,21 @@ class Store {
PStatic
Lookup
(
const
SRefNode
*
r
)
{
auto
rit
=
store_
.
rbegin
();
while
(
rit
!=
store_
.
rend
())
{
if
(
!
rit
->
history_valid
)
{
return
PStatic
();
}
if
(
rit
->
store
.
find
(
r
)
!=
rit
->
store
.
end
())
{
return
rit
->
store
.
find
(
r
)
->
second
;
}
if
(
!
rit
->
history_valid
)
{
return
PStatic
();
}
++
rit
;
}
return
PStatic
();
}
void
Invalidate
()
{
store_
.
back
().
history_valid
=
false
;
StoreFrame
sf
;
sf
.
history_valid
=
false
;
store_
.
push_back
(
sf
);
}
private
:
...
...
@@ -341,6 +350,10 @@ class Store {
store_
->
store_
.
push_back
(
StoreFrame
());
}
~
StoreFrameContext
()
{
// push one history valid frame off.
while
(
!
store_
->
store_
.
back
().
history_valid
)
{
store_
->
store_
.
pop_back
();
}
store_
->
store_
.
pop_back
();
}
};
...
...
@@ -442,13 +455,7 @@ Function AsFunc(const Expr& e) {
class
PartialEvaluator
:
public
ExprFunctor
<
PStatic
(
const
Expr
&
e
,
LetList
*
ll
)
>
,
public
PatternFunctor
<
MatchStatus
(
const
Pattern
&
,
const
PStatic
&
)
>
{
public
:
PartialEvaluator
(
const
tvm
::
Array
<
Var
>&
free_vars
,
const
Module
&
mod
)
:
mod_
(
mod
)
{
for
(
const
Var
&
v
:
free_vars
)
{
env_
.
Insert
(
v
,
NoStatic
(
v
));
}
}
PartialEvaluator
(
const
Module
&
mod
)
:
mod_
(
mod
)
{
}
PStatic
VisitExpr
(
const
Expr
&
e
,
LetList
*
ll
)
final
{
PStatic
ret
=
ExprFunctor
<
PStatic
(
const
Expr
&
,
LetList
*
)
>::
VisitExpr
(
e
,
ll
);
...
...
@@ -484,23 +491,23 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
return
env_
.
Lookup
(
GetRef
<
Var
>
(
op
));
}
PStatic
Visit
Expr_
(
const
GlobalVarNode
*
op
,
LetList
*
ll
)
final
{
GlobalVar
gv
=
GetRef
<
GlobalVar
>
(
op
);
PStatic
Visit
GlobalVar
(
const
GlobalVar
&
gv
)
{
CHECK
(
mod_
.
defined
()
);
if
(
gv_map_
.
count
(
gv
)
==
0
)
{
if
(
mod_
.
defined
())
{
Function
func
=
mod_
->
Lookup
(
gv
);
InitializeFuncId
(
func
);
Func
f
=
VisitFuncStatic
(
func
,
gv
);
gv_map_
.
insert
({
gv
,
HasStatic
(
MkSFunc
(
f
),
gv
)});
func
=
AsFunc
(
PostProcess
(
VisitFuncDynamic
(
func
,
f
)));
mod_
->
Update
(
gv
,
func
);
}
else
{
gv_map_
.
insert
({
gv
,
NoStatic
(
gv
)});
}
Function
func
=
mod_
->
Lookup
(
gv
);
InitializeFuncId
(
func
);
Func
f
=
VisitFuncStatic
(
func
,
gv
);
gv_map_
.
insert
({
gv
,
HasStatic
(
MkSFunc
(
f
),
gv
)});
func
=
AsFunc
(
PostProcess
(
VisitFuncDynamic
(
func
,
f
)));
mod_
->
Update
(
gv
,
func
);
}
return
gv_map_
.
at
(
gv
);
}
PStatic
VisitExpr_
(
const
GlobalVarNode
*
op
,
LetList
*
ll
)
final
{
return
VisitGlobalVar
(
GetRef
<
GlobalVar
>
(
op
));
}
PStatic
VisitExpr_
(
const
LetNode
*
op
,
LetList
*
ll
)
final
{
env_
.
Insert
(
op
->
var
,
VisitExpr
(
op
->
value
,
ll
));
return
VisitExpr
(
op
->
body
,
ll
);
...
...
@@ -629,7 +636,7 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
subst
.
Set
(
func
->
type_params
[
i
],
type_args
[
i
]);
}
for
(
size_t
i
=
type_args
.
size
();
i
<
func
->
type_params
.
size
();
++
i
)
{
subst
.
Set
(
func
->
type_params
[
i
],
Type
(
));
subst
.
Set
(
func
->
type_params
[
i
],
IncompleteTypeNode
::
make
(
kType
));
}
std
::
vector
<
Time
>
args_time
;
for
(
const
auto
&
v
:
pv
)
{
...
...
@@ -672,22 +679,22 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
};
}
Expr
VisitFuncDynamic
(
const
Function
&
func
,
const
Func
&
f
)
{
return
store_
.
Extend
<
Expr
>
([
&
]()
{
store_
.
Invalidate
();
return
FunctionNode
::
make
(
func
->
params
,
LetList
::
With
([
&
](
LetList
*
ll
)
{
std
::
vector
<
PStatic
>
pv
;
for
(
const
auto
&
v
:
func
->
params
)
{
pv
.
push_back
(
NoStatic
(
v
));
}
tvm
::
Array
<
Type
>
type_args
;
for
(
const
auto
&
tp
:
func
->
type_params
)
{
type_args
.
push_back
(
tp
);
}
return
f
(
pv
,
Attrs
(),
type_args
,
ll
)
->
dynamic
;
}),
func
->
ret_type
,
func
->
type_params
,
func
->
attrs
);
});
store_
.
Invalidate
();
return
FunctionNode
::
make
(
func
->
params
,
LetList
::
With
([
&
](
LetList
*
ll
)
{
std
::
vector
<
PStatic
>
pv
;
for
(
const
auto
&
v
:
func
->
params
)
{
pv
.
push_back
(
NoStatic
(
v
));
}
tvm
::
Array
<
Type
>
type_args
;
for
(
const
auto
&
tp
:
func
->
type_params
)
{
type_args
.
push_back
(
tp
);
}
return
f
(
pv
,
Attrs
(),
type_args
,
ll
)
->
dynamic
;
}),
func
->
ret_type
,
func
->
type_params
,
func
->
attrs
);
});
}
PStatic
VisitFunc
(
const
Function
&
func
,
LetList
*
ll
)
{
...
...
@@ -1012,17 +1019,14 @@ Expr PostProcess(const Expr& e) {
Module
PartialEval
(
const
Module
&
m
)
{
CHECK
(
m
->
entry_func
.
defined
());
auto
func
=
m
->
Lookup
(
m
->
entry_func
);
Expr
ret
=
TransformF
([
&
](
const
Expr
&
e
)
{
return
LetList
::
With
([
&
](
LetList
*
ll
)
{
relay
::
partial_eval
::
PartialEvaluator
pe
(
FreeVars
(
e
),
m
);
pe
.
InitializeFuncId
(
e
);
return
relay
::
partial_eval
::
PostProcess
(
pe
.
VisitExpr
(
e
,
ll
)
->
dynamic
);
});
},
func
);
CHECK
(
ret
->
is_type
<
FunctionNode
>
());
m
->
Update
(
m
->
entry_func
,
Downcast
<
Function
>
(
ret
));
relay
::
partial_eval
::
PartialEvaluator
pe
(
m
);
std
::
vector
<
GlobalVar
>
gvs
;
for
(
const
auto
&
p
:
m
->
functions
)
{
gvs
.
push_back
(
p
.
first
);
}
for
(
const
auto
&
gv
:
gvs
)
{
pe
.
VisitGlobalVar
(
gv
);
}
return
m
;
}
...
...
src/relay/pass/type_infer.cc
View file @
b3f3ab55
...
...
@@ -172,6 +172,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
return
it
->
second
.
checked_type
;
}
Type
ret
=
this
->
VisitExpr
(
expr
);
CHECK
(
ret
.
defined
());
KindCheck
(
ret
,
mod_
);
ResolvedTypeInfo
&
rti
=
type_map_
[
expr
];
rti
.
checked_type
=
ret
;
...
...
src/relay/pass/util.cc
View file @
b3f3ab55
...
...
@@ -425,6 +425,16 @@ Expr TypeSubst(const Expr& expr, const tvm::Map<TypeVar, Type>& subst_map) {
Var
VisitVar
(
const
Var
&
v
)
final
{
return
Downcast
<
Var
>
(
VisitExpr
(
v
));
}
Pattern
VisitPattern
(
const
Pattern
&
p
)
final
{
return
PatternMutator
::
VisitPattern
(
p
);
}
Clause
VisitClause
(
const
Clause
&
c
)
final
{
Pattern
pat
=
VisitPattern
(
c
->
lhs
);
return
ClauseNode
::
make
(
pat
,
VisitExpr
(
c
->
rhs
));
}
private
:
const
tvm
::
Map
<
TypeVar
,
Type
>&
subst_map_
;
};
...
...
tests/python/relay/test_pass_partial_eval.py
View file @
b3f3ab55
...
...
@@ -307,10 +307,10 @@ def test_double():
if
__name__
==
'__main__'
:
test_
empty_ad
()
test_
ref
()
test_tuple
()
test_empty_ad
()
test_const_inline
()
test_ref
()
test_ad
()
test_if_ref
()
test_function_invalidate
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment