Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
4cee98ba
Unverified
Commit
4cee98ba
authored
Jun 07, 2019
by
Tianqi Chen
Committed by
GitHub
Jun 07, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[PASS][RELAY] polish pass infra (#3319)
parent
ca017a38
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
45 additions
and
125 deletions
+45
-125
3rdparty/dmlc-core
+1
-1
src/relay/pass/pass_manager.cc
+44
-124
No files found.
dmlc-core
@
fbe142b2
Subproject commit
3943914eed66470bd010df581e29e4dca4f7df6f
Subproject commit
fbe142b267a8edd1f1188fa2140d88f7ae308661
src/relay/pass/pass_manager.cc
View file @
4cee98ba
...
...
@@ -37,47 +37,6 @@ namespace transform {
using
tvm
::
IRPrinter
;
namespace
{
// TODO(zhiics) Maybe we can use PackedFunc here so that parameters can be
// handled because we need to register the pass for Python invocation anyway.
Pass
GetPass
(
const
std
::
string
&
pass_name
)
{
if
(
pass_name
==
"InferType"
)
{
return
InferType
();
}
else
if
(
pass_name
==
"AlterOpLayout"
)
{
return
AlterOpLayout
();
}
else
if
(
pass_name
==
"CanonicalizeOps"
)
{
return
CanonicalizeOps
();
}
else
if
(
pass_name
==
"CombineParallelConv2d"
)
{
return
CombineParallelConv2D
();
}
else
if
(
pass_name
==
"DeadCodeElimination"
)
{
return
DeadCodeElimination
();
}
else
if
(
pass_name
==
"EliminateCommonSubexpr"
)
{
return
DeadCodeElimination
();
}
else
if
(
pass_name
==
"FoldConstant"
)
{
return
FoldConstant
();
}
else
if
(
pass_name
==
"BackwardFoldScaleAxis"
)
{
return
FoldScaleAxis
();
}
else
if
(
pass_name
==
"ForwardFoldScaleAxis"
)
{
return
FoldScaleAxis
();
}
else
if
(
pass_name
==
"FoldScaleAxis"
)
{
return
FoldScaleAxis
();
}
else
if
(
pass_name
==
"PartialEvaluate"
)
{
return
SimplifyInference
();
}
else
if
(
pass_name
==
"SimplifyInference"
)
{
return
SimplifyInference
();
}
else
if
(
pass_name
==
"ToANormalForm"
)
{
return
ToANormalForm
();
}
else
if
(
pass_name
==
"ToGraphNormalForm"
)
{
return
ToGraphNormalForm
();
}
else
{
LOG
(
FATAL
)
<<
pass_name
<<
" has not been registered yet."
<<
"
\n
"
;
return
Pass
(
nullptr
);
}
}
}
// namespace
struct
RelayPassContextThreadLocalEntry
{
/*! \brief The default pass context. */
PassContext
default_context
;
...
...
@@ -252,6 +211,7 @@ class SequentialNode : public PassNode {
/*! \brief A list of passes that used to compose a sequential pass. */
tvm
::
Array
<
Pass
>
passes
;
void
VisitAttrs
(
tvm
::
AttrVisitor
*
v
)
final
{
v
->
Visit
(
"pass_info"
,
&
pass_info
);
v
->
Visit
(
"passes"
,
&
passes
);
...
...
@@ -263,22 +223,13 @@ class SequentialNode : public PassNode {
PassInfo
Info
()
const
{
return
pass_info
;
}
/*!
* \brief Add a pass to the pass list.
*
* \param pass The candidate pass to be added.
*/
void
AddPass
(
const
Pass
&
pass
)
{
passes
.
push_back
(
pass
);
}
/*!
* \brief Check if a pass is enabled.
*
* \param
pass_name The name of an optimization/analysis pass
.
* \param
info The pass information
.
*
* \return true if the pass is enabled. Otherwise, false.
*/
bool
PassEnabled
(
const
std
::
string
&
pass_name
)
const
;
bool
PassEnabled
(
const
PassInfo
&
info
)
const
;
/*!
* \brief Resolve the pass dependency. It globs all required passes by
...
...
@@ -294,12 +245,6 @@ class SequentialNode : public PassNode {
*/
void
ResolveDependency
(
const
Module
&
mod
);
std
::
unordered_set
<
std
::
string
>
DisabledPasses
(
const
Array
<
tvm
::
Expr
>&
disabled
)
const
;
std
::
unordered_set
<
std
::
string
>
RequiredPasses
(
const
Array
<
tvm
::
Expr
>&
required
)
const
;
/*!
* \brief Perform optimizations on a series of passes. The aforementioned
* typical pass manager jobs could be done by it. This function could
...
...
@@ -317,7 +262,8 @@ class SequentialNode : public PassNode {
TVM_DECLARE_NODE_TYPE_INFO
(
SequentialNode
,
PassNode
);
};
PassInfo
PassInfoNode
::
make
(
int
opt_level
,
std
::
string
name
,
PassInfo
PassInfoNode
::
make
(
int
opt_level
,
std
::
string
name
,
tvm
::
Array
<
tvm
::
Expr
>
required
)
{
auto
pass_info
=
make_node
<
PassInfoNode
>
();
pass_info
->
opt_level
=
opt_level
;
...
...
@@ -338,23 +284,13 @@ ModulePass ModulePassNode::make(
// Module -> Module optimizations.
Module
ModulePassNode
::
operator
()(
const
Module
&
mod
,
const
PassContext
&
pass_ctx
)
const
{
PassInfo
pass_info
=
Info
();
DLOG
(
INFO
)
<<
"Executing module pass : "
<<
pass_info
->
name
<<
" with opt level: "
<<
pass_info
->
opt_level
<<
"
\n
"
;
const
PassInfo
&
pass_info
=
Info
();
DLOG
(
INFO
)
<<
"Executing module pass : "
<<
pass_info
->
name
<<
" with opt level: "
<<
pass_info
->
opt_level
;
CHECK
(
mod
.
defined
());
Module
updated_mod
=
mod
;
// Execute the required passes in a DFS way.
// TODO(zhiics) We may need to pass validation to detect the cyclic
// dependency.
for
(
const
auto
&
it
:
pass_info
->
required
)
{
const
auto
*
name
=
it
.
as
<
tvm
::
ir
::
StringImm
>
();
CHECK
(
name
);
auto
pass
=
GetPass
(
name
->
value
);
updated_mod
=
pass
(
updated_mod
,
pass_ctx
);
}
updated_mod
=
pass_func
(
updated_mod
,
pass_ctx
);
Module
updated_mod
=
pass_func
(
mod
,
pass_ctx
);
CHECK
(
updated_mod
.
defined
());
return
updated_mod
;
}
...
...
@@ -369,25 +305,15 @@ FunctionPass FunctionPassNode::make(
}
// Perform Module -> Module optimizations at the Function level.
// TODO(zhiics) Check and handle the required passes.
Module
FunctionPassNode
::
operator
()(
const
Module
&
mod
,
const
PassContext
&
pass_ctx
)
const
{
PassInfo
pass_info
=
Info
();
const
PassInfo
&
pass_info
=
Info
();
CHECK
(
mod
.
defined
());
DLOG
(
INFO
)
<<
"Executing module pass : "
<<
pass_info
->
name
<<
" with opt level: "
<<
pass_info
->
opt_level
<<
"
\n
"
;
DLOG
(
INFO
)
<<
"Executing module pass : "
<<
pass_info
->
name
<<
" with opt level: "
<<
pass_info
->
opt_level
;
Module
updated_mod
=
mod
;
// Execute the required passes in a DFS way.
// TODO(zhiics) We may need to pass validation to detect the cyclic
// dependency.
for
(
const
auto
&
it
:
pass_info
->
required
)
{
const
auto
*
name
=
it
.
as
<
tvm
::
ir
::
StringImm
>
();
CHECK
(
name
);
auto
pass
=
GetPass
(
name
->
value
);
updated_mod
=
pass
(
updated_mod
,
pass_ctx
);
}
Module
new_mod
=
ModuleNode
::
make
({},
mod
->
type_definitions
);
// Execute the pass function and return a new module.
for
(
const
auto
&
it
:
mod
->
functions
)
{
...
...
@@ -396,7 +322,6 @@ Module FunctionPassNode::operator()(const Module& mod,
:
pass_func
(
it
.
second
,
updated_mod
,
pass_ctx
);
new_mod
->
Add
(
it
.
first
,
updated_func
);
}
return
new_mod
;
}
...
...
@@ -436,47 +361,40 @@ void SequentialNode::ResolveDependency(const Module& mod) {
<<
"
\n
"
;
}
std
::
unordered_set
<
std
::
string
>
SequentialNode
::
DisabledPasses
(
const
Array
<
tvm
::
Expr
>&
disabled
)
const
{
std
::
unordered_set
<
std
::
string
>
ret
;
for
(
const
auto
&
it
:
disabled
)
{
const
auto
*
str
=
it
.
as
<
tvm
::
ir
::
StringImm
>
();
CHECK
(
str
)
<<
"Disabled pass name must be string."
;
ret
.
emplace
(
str
->
value
);
}
return
ret
;
}
std
::
unordered_set
<
std
::
string
>
SequentialNode
::
RequiredPasses
(
const
Array
<
tvm
::
Expr
>&
required
)
const
{
std
::
unordered_set
<
std
::
string
>
ret
;
for
(
const
auto
&
it
:
required
)
{
const
auto
*
str
=
it
.
as
<
tvm
::
ir
::
StringImm
>
();
CHECK
(
str
)
<<
"Required pass name must be string."
;
ret
.
emplace
(
str
->
value
);
// linearly scan the pass array to match pass_name
inline
bool
PassArrayContains
(
const
Array
<
tvm
::
Expr
>&
pass_array
,
const
std
::
string
&
pass_name
)
{
for
(
auto
x
:
pass_array
)
{
auto
*
str_name
=
x
.
as
<
ir
::
StringImm
>
();
CHECK
(
str_name
)
<<
"pass name must be str"
;
if
(
str_name
->
value
==
pass_name
)
return
true
;
}
return
ret
;
return
false
;
}
bool
SequentialNode
::
PassEnabled
(
const
std
::
string
&
pass_name
)
const
{
bool
SequentialNode
::
PassEnabled
(
const
PassInfo
&
info
)
const
{
PassContext
ctx
=
PassContext
::
Current
();
auto
required
=
RequiredPasses
(
ctx
->
required_pass
);
auto
disabled
=
DisabledPasses
(
ctx
->
disabled_pass
);
if
(
disabled
.
count
(
pass_name
))
{
if
(
PassArrayContains
(
ctx
->
disabled_pass
,
info
->
name
))
{
return
false
;
}
if
(
required
.
count
(
pass_
name
))
{
if
(
PassArrayContains
(
ctx
->
required_pass
,
info
->
name
))
{
return
true
;
}
const
Pass
pass
=
GetPass
(
pass_name
);
PassInfo
info
=
pass
->
Info
();
return
ctx
->
opt_level
>=
info
->
opt_level
;
}
Pass
GetPass
(
const
std
::
string
&
pass_name
)
{
using
tvm
::
runtime
::
Registry
;
std
::
string
fpass_name
=
"relay._transform."
+
pass_name
;
const
auto
*
f
=
Registry
::
Get
(
fpass_name
);
CHECK
(
f
!=
nullptr
)
<<
"Cannot find "
<<
fpass_name
<<
"to create the pass "
<<
pass_name
;
return
(
*
f
)();
}
// TODO(zhiics): we currenlty only sequentially execute each pass in
// a Sequential without the consideration of their orders. The phase
// ordering problem needs to be handled in the future.
...
...
@@ -485,13 +403,15 @@ Module SequentialNode::operator()(const Module& module,
Module
mod
=
module
;
for
(
const
Pass
&
pass
:
passes
)
{
CHECK
(
pass
.
defined
())
<<
"Found undefined pass for optimization."
;
PassInfo
info
=
pass
->
Info
();
const
auto
&
pass_name
=
info
->
name
;
// Execute the pass if it is enabled.
if
(
PassEnabled
(
pass_name
))
{
mod
=
pass
(
mod
,
pass_ctx
);
const
PassInfo
&
pass_info
=
pass
->
Info
();
if
(
!
PassEnabled
(
pass_info
))
continue
;
// resolve dependencies
for
(
const
auto
&
it
:
pass_info
->
required
)
{
const
auto
*
name
=
it
.
as
<
tvm
::
ir
::
StringImm
>
();
CHECK
(
name
);
mod
=
GetPass
(
name
->
value
)(
mod
,
pass_ctx
);
}
mod
=
pass
(
mod
,
pass_ctx
);
}
return
mod
;
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment