Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
4cee98ba
Unverified
Commit
4cee98ba
authored
Jun 07, 2019
by
Tianqi Chen
Committed by
GitHub
Jun 07, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[PASS][RELAY] polish pass infra (#3319)
parent
ca017a38
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
45 additions
and
125 deletions
+45
-125
3rdparty/dmlc-core
+1
-1
src/relay/pass/pass_manager.cc
+44
-124
No files found.
dmlc-core
@
fbe142b2
Subproject commit
3943914eed66470bd010df581e29e4dca4f7df6f
Subproject commit
fbe142b267a8edd1f1188fa2140d88f7ae308661
src/relay/pass/pass_manager.cc
View file @
4cee98ba
...
@@ -37,47 +37,6 @@ namespace transform {
...
@@ -37,47 +37,6 @@ namespace transform {
using
tvm
::
IRPrinter
;
using
tvm
::
IRPrinter
;
namespace
{
// TODO(zhiics) Maybe we can use PackedFunc here so that parameters can be
// handled because we need to register the pass for Python invocation anyway.
Pass
GetPass
(
const
std
::
string
&
pass_name
)
{
if
(
pass_name
==
"InferType"
)
{
return
InferType
();
}
else
if
(
pass_name
==
"AlterOpLayout"
)
{
return
AlterOpLayout
();
}
else
if
(
pass_name
==
"CanonicalizeOps"
)
{
return
CanonicalizeOps
();
}
else
if
(
pass_name
==
"CombineParallelConv2d"
)
{
return
CombineParallelConv2D
();
}
else
if
(
pass_name
==
"DeadCodeElimination"
)
{
return
DeadCodeElimination
();
}
else
if
(
pass_name
==
"EliminateCommonSubexpr"
)
{
return
DeadCodeElimination
();
}
else
if
(
pass_name
==
"FoldConstant"
)
{
return
FoldConstant
();
}
else
if
(
pass_name
==
"BackwardFoldScaleAxis"
)
{
return
FoldScaleAxis
();
}
else
if
(
pass_name
==
"ForwardFoldScaleAxis"
)
{
return
FoldScaleAxis
();
}
else
if
(
pass_name
==
"FoldScaleAxis"
)
{
return
FoldScaleAxis
();
}
else
if
(
pass_name
==
"PartialEvaluate"
)
{
return
SimplifyInference
();
}
else
if
(
pass_name
==
"SimplifyInference"
)
{
return
SimplifyInference
();
}
else
if
(
pass_name
==
"ToANormalForm"
)
{
return
ToANormalForm
();
}
else
if
(
pass_name
==
"ToGraphNormalForm"
)
{
return
ToGraphNormalForm
();
}
else
{
LOG
(
FATAL
)
<<
pass_name
<<
" has not been registered yet."
<<
"
\n
"
;
return
Pass
(
nullptr
);
}
}
}
// namespace
struct
RelayPassContextThreadLocalEntry
{
struct
RelayPassContextThreadLocalEntry
{
/*! \brief The default pass context. */
/*! \brief The default pass context. */
PassContext
default_context
;
PassContext
default_context
;
...
@@ -252,6 +211,7 @@ class SequentialNode : public PassNode {
...
@@ -252,6 +211,7 @@ class SequentialNode : public PassNode {
/*! \brief A list of passes that used to compose a sequential pass. */
/*! \brief A list of passes that used to compose a sequential pass. */
tvm
::
Array
<
Pass
>
passes
;
tvm
::
Array
<
Pass
>
passes
;
void
VisitAttrs
(
tvm
::
AttrVisitor
*
v
)
final
{
void
VisitAttrs
(
tvm
::
AttrVisitor
*
v
)
final
{
v
->
Visit
(
"pass_info"
,
&
pass_info
);
v
->
Visit
(
"pass_info"
,
&
pass_info
);
v
->
Visit
(
"passes"
,
&
passes
);
v
->
Visit
(
"passes"
,
&
passes
);
...
@@ -263,22 +223,13 @@ class SequentialNode : public PassNode {
...
@@ -263,22 +223,13 @@ class SequentialNode : public PassNode {
PassInfo
Info
()
const
{
return
pass_info
;
}
PassInfo
Info
()
const
{
return
pass_info
;
}
/*!
/*!
* \brief Add a pass to the pass list.
*
* \param pass The candidate pass to be added.
*/
void
AddPass
(
const
Pass
&
pass
)
{
passes
.
push_back
(
pass
);
}
/*!
* \brief Check if a pass is enabled.
* \brief Check if a pass is enabled.
*
*
* \param
pass_name The name of an optimization/analysis pass
.
* \param
info The pass information
.
*
*
* \return true if the pass is enabled. Otherwise, false.
* \return true if the pass is enabled. Otherwise, false.
*/
*/
bool
PassEnabled
(
const
std
::
string
&
pass_name
)
const
;
bool
PassEnabled
(
const
PassInfo
&
info
)
const
;
/*!
/*!
* \brief Resolve the pass dependency. It globs all required passes by
* \brief Resolve the pass dependency. It globs all required passes by
...
@@ -294,12 +245,6 @@ class SequentialNode : public PassNode {
...
@@ -294,12 +245,6 @@ class SequentialNode : public PassNode {
*/
*/
void
ResolveDependency
(
const
Module
&
mod
);
void
ResolveDependency
(
const
Module
&
mod
);
std
::
unordered_set
<
std
::
string
>
DisabledPasses
(
const
Array
<
tvm
::
Expr
>&
disabled
)
const
;
std
::
unordered_set
<
std
::
string
>
RequiredPasses
(
const
Array
<
tvm
::
Expr
>&
required
)
const
;
/*!
/*!
* \brief Perform optimizations on a series of passes. The aforementioned
* \brief Perform optimizations on a series of passes. The aforementioned
* typical pass manager jobs could be done by it. This function could
* typical pass manager jobs could be done by it. This function could
...
@@ -317,7 +262,8 @@ class SequentialNode : public PassNode {
...
@@ -317,7 +262,8 @@ class SequentialNode : public PassNode {
TVM_DECLARE_NODE_TYPE_INFO
(
SequentialNode
,
PassNode
);
TVM_DECLARE_NODE_TYPE_INFO
(
SequentialNode
,
PassNode
);
};
};
PassInfo
PassInfoNode
::
make
(
int
opt_level
,
std
::
string
name
,
PassInfo
PassInfoNode
::
make
(
int
opt_level
,
std
::
string
name
,
tvm
::
Array
<
tvm
::
Expr
>
required
)
{
tvm
::
Array
<
tvm
::
Expr
>
required
)
{
auto
pass_info
=
make_node
<
PassInfoNode
>
();
auto
pass_info
=
make_node
<
PassInfoNode
>
();
pass_info
->
opt_level
=
opt_level
;
pass_info
->
opt_level
=
opt_level
;
...
@@ -338,23 +284,13 @@ ModulePass ModulePassNode::make(
...
@@ -338,23 +284,13 @@ ModulePass ModulePassNode::make(
// Module -> Module optimizations.
// Module -> Module optimizations.
Module
ModulePassNode
::
operator
()(
const
Module
&
mod
,
Module
ModulePassNode
::
operator
()(
const
Module
&
mod
,
const
PassContext
&
pass_ctx
)
const
{
const
PassContext
&
pass_ctx
)
const
{
PassInfo
pass_info
=
Info
();
const
PassInfo
&
pass_info
=
Info
();
DLOG
(
INFO
)
<<
"Executing module pass : "
<<
pass_info
->
name
DLOG
(
INFO
)
<<
"Executing module pass : "
<<
" with opt level: "
<<
pass_info
->
opt_level
<<
"
\n
"
;
<<
pass_info
->
name
<<
" with opt level: "
<<
pass_info
->
opt_level
;
CHECK
(
mod
.
defined
());
CHECK
(
mod
.
defined
());
Module
updated_mod
=
mod
;
Module
updated_mod
=
pass_func
(
mod
,
pass_ctx
);
// Execute the required passes in a DFS way.
// TODO(zhiics) We may need to pass validation to detect the cyclic
// dependency.
for
(
const
auto
&
it
:
pass_info
->
required
)
{
const
auto
*
name
=
it
.
as
<
tvm
::
ir
::
StringImm
>
();
CHECK
(
name
);
auto
pass
=
GetPass
(
name
->
value
);
updated_mod
=
pass
(
updated_mod
,
pass_ctx
);
}
updated_mod
=
pass_func
(
updated_mod
,
pass_ctx
);
CHECK
(
updated_mod
.
defined
());
CHECK
(
updated_mod
.
defined
());
return
updated_mod
;
return
updated_mod
;
}
}
...
@@ -369,25 +305,15 @@ FunctionPass FunctionPassNode::make(
...
@@ -369,25 +305,15 @@ FunctionPass FunctionPassNode::make(
}
}
// Perform Module -> Module optimizations at the Function level.
// Perform Module -> Module optimizations at the Function level.
// TODO(zhiics) Check and handle the required passes.
Module
FunctionPassNode
::
operator
()(
const
Module
&
mod
,
Module
FunctionPassNode
::
operator
()(
const
Module
&
mod
,
const
PassContext
&
pass_ctx
)
const
{
const
PassContext
&
pass_ctx
)
const
{
PassInfo
pass_info
=
Info
();
const
PassInfo
&
pass_info
=
Info
();
CHECK
(
mod
.
defined
());
CHECK
(
mod
.
defined
());
DLOG
(
INFO
)
<<
"Executing module pass : "
<<
pass_info
->
name
DLOG
(
INFO
)
<<
"Executing module pass : "
<<
" with opt level: "
<<
pass_info
->
opt_level
<<
"
\n
"
;
<<
pass_info
->
name
<<
" with opt level: "
<<
pass_info
->
opt_level
;
Module
updated_mod
=
mod
;
Module
updated_mod
=
mod
;
// Execute the required passes in a DFS way.
// TODO(zhiics) We may need to pass validation to detect the cyclic
// dependency.
for
(
const
auto
&
it
:
pass_info
->
required
)
{
const
auto
*
name
=
it
.
as
<
tvm
::
ir
::
StringImm
>
();
CHECK
(
name
);
auto
pass
=
GetPass
(
name
->
value
);
updated_mod
=
pass
(
updated_mod
,
pass_ctx
);
}
Module
new_mod
=
ModuleNode
::
make
({},
mod
->
type_definitions
);
Module
new_mod
=
ModuleNode
::
make
({},
mod
->
type_definitions
);
// Execute the pass function and return a new module.
// Execute the pass function and return a new module.
for
(
const
auto
&
it
:
mod
->
functions
)
{
for
(
const
auto
&
it
:
mod
->
functions
)
{
...
@@ -396,7 +322,6 @@ Module FunctionPassNode::operator()(const Module& mod,
...
@@ -396,7 +322,6 @@ Module FunctionPassNode::operator()(const Module& mod,
:
pass_func
(
it
.
second
,
updated_mod
,
pass_ctx
);
:
pass_func
(
it
.
second
,
updated_mod
,
pass_ctx
);
new_mod
->
Add
(
it
.
first
,
updated_func
);
new_mod
->
Add
(
it
.
first
,
updated_func
);
}
}
return
new_mod
;
return
new_mod
;
}
}
...
@@ -436,47 +361,40 @@ void SequentialNode::ResolveDependency(const Module& mod) {
...
@@ -436,47 +361,40 @@ void SequentialNode::ResolveDependency(const Module& mod) {
<<
"
\n
"
;
<<
"
\n
"
;
}
}
std
::
unordered_set
<
std
::
string
>
SequentialNode
::
DisabledPasses
(
// linearly scan the pass array to match pass_name
const
Array
<
tvm
::
Expr
>&
disabled
)
const
{
inline
bool
PassArrayContains
(
const
Array
<
tvm
::
Expr
>&
pass_array
,
std
::
unordered_set
<
std
::
string
>
ret
;
const
std
::
string
&
pass_name
)
{
for
(
const
auto
&
it
:
disabled
)
{
for
(
auto
x
:
pass_array
)
{
const
auto
*
str
=
it
.
as
<
tvm
::
ir
::
StringImm
>
();
auto
*
str_name
=
x
.
as
<
ir
::
StringImm
>
();
CHECK
(
str
)
<<
"Disabled pass name must be string."
;
CHECK
(
str_name
)
<<
"pass name must be str"
;
ret
.
emplace
(
str
->
value
);
if
(
str_name
->
value
==
pass_name
)
return
true
;
}
return
ret
;
}
std
::
unordered_set
<
std
::
string
>
SequentialNode
::
RequiredPasses
(
const
Array
<
tvm
::
Expr
>&
required
)
const
{
std
::
unordered_set
<
std
::
string
>
ret
;
for
(
const
auto
&
it
:
required
)
{
const
auto
*
str
=
it
.
as
<
tvm
::
ir
::
StringImm
>
();
CHECK
(
str
)
<<
"Required pass name must be string."
;
ret
.
emplace
(
str
->
value
);
}
}
return
ret
;
return
false
;
}
}
bool
SequentialNode
::
PassEnabled
(
const
std
::
string
&
pass_name
)
const
{
bool
SequentialNode
::
PassEnabled
(
const
PassInfo
&
info
)
const
{
PassContext
ctx
=
PassContext
::
Current
();
PassContext
ctx
=
PassContext
::
Current
();
auto
required
=
RequiredPasses
(
ctx
->
required_pass
);
if
(
PassArrayContains
(
ctx
->
disabled_pass
,
info
->
name
))
{
auto
disabled
=
DisabledPasses
(
ctx
->
disabled_pass
);
if
(
disabled
.
count
(
pass_name
))
{
return
false
;
return
false
;
}
}
if
(
required
.
count
(
pass_
name
))
{
if
(
PassArrayContains
(
ctx
->
required_pass
,
info
->
name
))
{
return
true
;
return
true
;
}
}
const
Pass
pass
=
GetPass
(
pass_name
);
PassInfo
info
=
pass
->
Info
();
return
ctx
->
opt_level
>=
info
->
opt_level
;
return
ctx
->
opt_level
>=
info
->
opt_level
;
}
}
Pass
GetPass
(
const
std
::
string
&
pass_name
)
{
using
tvm
::
runtime
::
Registry
;
std
::
string
fpass_name
=
"relay._transform."
+
pass_name
;
const
auto
*
f
=
Registry
::
Get
(
fpass_name
);
CHECK
(
f
!=
nullptr
)
<<
"Cannot find "
<<
fpass_name
<<
"to create the pass "
<<
pass_name
;
return
(
*
f
)();
}
// TODO(zhiics): we currenlty only sequentially execute each pass in
// TODO(zhiics): we currenlty only sequentially execute each pass in
// a Sequential without the consideration of their orders. The phase
// a Sequential without the consideration of their orders. The phase
// ordering problem needs to be handled in the future.
// ordering problem needs to be handled in the future.
...
@@ -485,13 +403,15 @@ Module SequentialNode::operator()(const Module& module,
...
@@ -485,13 +403,15 @@ Module SequentialNode::operator()(const Module& module,
Module
mod
=
module
;
Module
mod
=
module
;
for
(
const
Pass
&
pass
:
passes
)
{
for
(
const
Pass
&
pass
:
passes
)
{
CHECK
(
pass
.
defined
())
<<
"Found undefined pass for optimization."
;
CHECK
(
pass
.
defined
())
<<
"Found undefined pass for optimization."
;
const
PassInfo
&
pass_info
=
pass
->
Info
();
PassInfo
info
=
pass
->
Info
();
if
(
!
PassEnabled
(
pass_info
))
continue
;
const
auto
&
pass_name
=
info
->
name
;
// resolve dependencies
// Execute the pass if it is enabled.
for
(
const
auto
&
it
:
pass_info
->
required
)
{
if
(
PassEnabled
(
pass_name
))
{
const
auto
*
name
=
it
.
as
<
tvm
::
ir
::
StringImm
>
();
mod
=
pass
(
mod
,
pass_ctx
);
CHECK
(
name
);
mod
=
GetPass
(
name
->
value
)(
mod
,
pass_ctx
);
}
}
mod
=
pass
(
mod
,
pass_ctx
);
}
}
return
mod
;
return
mod
;
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment