Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
953ca1f6
Unverified
Commit
953ca1f6
authored
May 28, 2019
by
Tianqi Chen
Committed by
GitHub
May 28, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[C++] Cleanup transform API nits (#3253)
parent
a8275bdb
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
96 additions
and
76 deletions
+96
-76
include/tvm/relay/transform.h
+72
-35
src/relay/pass/pass_manager.cc
+24
-41
No files found.
include/tvm/relay/transform.h
View file @
953ca1f6
...
...
@@ -76,8 +76,8 @@ namespace transform {
class
PassContext
;
/*!
* \brief PassContextNode contains the information that a pass can rely on,
such as
* analysis results.
* \brief PassContextNode contains the information that a pass can rely on,
*
such as
analysis results.
*/
class
PassContextNode
:
public
RelayNode
{
public
:
...
...
@@ -110,32 +110,51 @@ class PassContextNode : public RelayNode {
TVM_DECLARE_NODE_TYPE_INFO
(
PassContextNode
,
RelayNode
);
};
/*!
* \brief PassContext that is used to configure the pass behavior.
*
* \code
*
* auto new_ctx = PassContext::Create();
* ctx->opt_level = 2;
* ctx->fallback_device = kDLCPU;
* With<PassContext> scope(ctx);
* // pass context in effect.
*
* \endcode
*/
class
PassContext
:
public
NodeRef
{
public
:
PassContext
()
{}
explicit
PassContext
(
tvm
::
NodePtr
<
Node
>
n
)
:
NodeRef
(
n
)
{}
/*
* \brief Constructor of a `PassContext` object.
*
* \param opt_level The optimization level that will be applied.
* \param fallback_device The fallback device used for heterogeneous
* execution.
* \param required_pass The passes that are required for a context to execute
* other passes.
* \param required_pass The passes that will be disabled during the
* optimization under a context.
explicit
PassContext
(
NodePtr
<::
tvm
::
Node
>
n
)
:
NodeRef
(
n
)
{}
/*!
* \brief const accessor.
* \return const access pointer.
*/
const
PassContextNode
*
operator
->
()
const
{
CHECK
(
node_
.
get
()
!=
nullptr
);
return
static_cast
<
const
PassContextNode
*>
(
node_
.
get
());
}
/*!
* \brief mutable accessor.
* \return mutable access pointer.
*/
PassContextNode
*
operator
->
()
{
CHECK
(
node_
.
get
()
!=
nullptr
);
return
static_cast
<
PassContextNode
*>
(
node_
.
get
());
}
/*!
* \brief Construct a PassContext containing the default configurations.
* \return The new PassContext.
*/
TVM_DLL
static
PassContext
Create
();
/*!
* \brief Get the default pass context in the current scope.
* \return The pass context.
*/
TVM_DLL
PassContext
(
int
opt_level
,
int
fallback_device
,
tvm
::
Array
<
tvm
::
Expr
>
required_pass
,
tvm
::
Array
<
tvm
::
Expr
>
disabled_pass
);
// Get the currently used pass context.
TVM_DLL
static
PassContext
Current
();
const
PassContextNode
*
operator
->
()
const
;
// accessor.
using
ContainerType
=
PassContextNode
;
class
Internal
;
...
...
@@ -204,25 +223,23 @@ class PassNode : public RelayNode {
virtual
PassInfo
Info
()
const
=
0
;
/*!
* \brief Execute the optimization pass using a functor. This functor
* internally uses a current pass context.
* \brief Transform mod using the default PassContext in the current scope.
*
* \param mod The module that an optimization pass runs on.
*
* \return The
updat
ed module.
* \return The
transform
ed module.
*/
Module
operator
()(
const
Module
&
mod
)
const
{
return
this
->
operator
()(
mod
,
PassContext
::
Current
());
}
/*!
* \brief
Execute the optimization pass
using a functor under a given pass context.
* \brief
Transform mod
using a functor under a given pass context.
*
* \param mod The module that an optimization pass runs on.
* \param pass_ctx The pass context that will be used to help the execution of
* optimizations.
* \param pass_ctx The pass context that can provide information for the optimization.
*
* \return The
updat
ed module.
* \return The
transform
ed module.
*/
virtual
Module
operator
()(
const
Module
&
mod
,
const
PassContext
&
pass_ctx
)
const
=
0
;
...
...
@@ -235,14 +252,34 @@ class PassNode : public RelayNode {
class
Pass
:
public
NodeRef
{
public
:
Pass
()
=
default
;
explicit
Pass
(
NodePtr
<
tvm
::
Node
>
p
)
:
NodeRef
(
p
)
{}
PassNode
*
operator
->
()
const
{
return
static_cast
<
PassNode
*>
(
this
->
node_
.
get
());
/*!
* \brief Transform mod using the default PassContext in the current scope.
*
* \param mod The module that an optimization pass runs on.
*
* \return The transformed module.
*/
Module
operator
()(
const
Module
&
mod
)
const
{
const
PassNode
*
node
=
operator
->
();
CHECK
(
node
!=
nullptr
);
return
node
->
operator
()(
mod
);
}
/*!
* \brief Transform mod using a functor under a given pass context.
*
* \param mod The module that an optimization pass runs on.
* \param pass_ctx The pass context that can provide information for the optimization.
*
* \return The transformed module.
*/
Module
operator
()(
const
Module
&
mod
,
const
PassContext
&
pass_ctx
)
const
{
const
PassNode
*
node
=
operator
->
();
CHECK
(
node
!=
nullptr
);
return
node
->
operator
()(
mod
,
pass_ctx
);
}
using
ContainerType
=
PassNode
;
TVM_DEFINE_NODE_REF_METHODS
(
Pass
,
NodeRef
,
PassNode
)
;
};
class
SequentialNode
;
...
...
src/relay/pass/pass_manager.cc
View file @
953ca1f6
...
...
@@ -74,21 +74,6 @@ class OptPassLevel {
}
};
PassContext
::
PassContext
(
int
opt_level
,
int
fallback_device
,
tvm
::
Array
<
tvm
::
Expr
>
required_pass
,
tvm
::
Array
<
tvm
::
Expr
>
disabled_pass
)
{
auto
ctx
=
make_node
<
PassContextNode
>
();
ctx
->
opt_level
=
opt_level
;
ctx
->
fallback_device
=
fallback_device
;
ctx
->
required_pass
=
std
::
move
(
required_pass
);
ctx
->
disabled_pass
=
std
::
move
(
disabled_pass
);
node_
=
std
::
move
(
ctx
);
}
const
PassContextNode
*
PassContext
::
operator
->
()
const
{
return
static_cast
<
const
PassContextNode
*>
(
node_
.
get
());
}
struct
RelayPassContextThreadLocalEntry
{
/*! \brief The default pass context. */
PassContext
default_context
;
...
...
@@ -129,6 +114,10 @@ PassContext PassContext::Current() {
}
}
PassContext
PassContext
::
Create
()
{
return
PassContext
(
make_node
<
PassContextNode
>
());
}
class
ModulePass
;
/*!
...
...
@@ -291,7 +280,7 @@ class SequentialNode : public PassNode {
*
* \return true if the pass is enabled. Otherwise, false.
*/
bool
pass_e
nabled
(
const
std
::
string
&
pass_name
)
const
;
bool
PassE
nabled
(
const
std
::
string
&
pass_name
)
const
;
/*!
* \brief Resolve the pass dependency. It globs all required passes by
...
...
@@ -353,9 +342,8 @@ ModulePass ModulePassNode::make(
Module
ModulePassNode
::
operator
()(
const
Module
&
mod
,
const
PassContext
&
pass_ctx
)
const
{
PassInfo
pass_info
=
Info
();
LOG
(
INFO
)
<<
"Executing module pass : "
<<
pass_info
.
operator
->
()
->
name
<<
" with opt level: "
<<
pass_info
.
operator
->
()
->
opt_level
<<
"
\n
"
;
DLOG
(
INFO
)
<<
"Executing module pass : "
<<
pass_info
->
name
<<
" with opt level: "
<<
pass_info
->
opt_level
<<
"
\n
"
;
CHECK
(
mod
.
defined
());
auto
updated_mod
=
pass_func
(
mod
,
pass_ctx
);
CHECK
(
updated_mod
.
defined
());
...
...
@@ -376,11 +364,10 @@ FunctionPass FunctionPassNode::make(
Module
FunctionPassNode
::
operator
()(
const
Module
&
mod
,
const
PassContext
&
pass_ctx
)
const
{
PassInfo
pass_info
=
Info
();
LOG
(
INFO
)
<<
"Executing function pass : "
<<
pass_info
.
operator
->
()
->
name
<<
" with opt level: "
<<
pass_info
.
operator
->
()
->
opt_level
<<
"
\n
"
;
CHECK
(
mod
.
defined
());
Module
new_mod
=
ModuleNode
::
make
({},
mod
->
type_definitions
);
DLOG
(
INFO
)
<<
"Executing module pass : "
<<
pass_info
->
name
<<
" with opt level: "
<<
pass_info
->
opt_level
<<
"
\n
"
;
// Execute the pass function and return a new module.
for
(
const
auto
&
it
:
mod
->
functions
)
{
auto
updated_func
=
SkipFunction
(
it
.
second
)
?
it
.
second
:
pass_func
(
it
.
second
,
mod
,
pass_ctx
);
...
...
@@ -448,12 +435,11 @@ std::unordered_set<std::string> SequentialNode::RequiredPasses(
return
ret
;
}
bool
SequentialNode
::
pass_e
nabled
(
const
std
::
string
&
pass_name
)
const
{
bool
SequentialNode
::
PassE
nabled
(
const
std
::
string
&
pass_name
)
const
{
PassContext
ctx
=
PassContext
::
Current
();
const
PassContextNode
*
ctx_node
=
ctx
.
operator
->
();
auto
required
=
RequiredPasses
(
ctx_node
->
required_pass
);
auto
disabled
=
DisabledPasses
(
ctx_node
->
required_pass
);
auto
required
=
RequiredPasses
(
ctx
->
required_pass
);
auto
disabled
=
DisabledPasses
(
ctx
->
required_pass
);
if
(
disabled
.
count
(
pass_name
))
{
return
false
;
...
...
@@ -462,7 +448,7 @@ bool SequentialNode::pass_enabled(const std::string& pass_name) const {
if
(
required
.
count
(
pass_name
))
{
return
true
;
}
return
ctx
_node
->
opt_level
>=
opt_pass_level
[
pass_name
];
return
ctx
->
opt_level
>=
opt_pass_level
[
pass_name
];
}
// TODO(zhiics): we currenlty only sequentially execute each pass in
...
...
@@ -470,15 +456,14 @@ bool SequentialNode::pass_enabled(const std::string& pass_name) const {
// ordering problem needed to be handled in the future.
Module
SequentialNode
::
operator
()(
const
Module
&
module
,
const
PassContext
&
pass_ctx
)
const
{
const
auto
*
ctx_node
=
pass_ctx
.
operator
->
();
int
opt_level
=
ctx_node
->
opt_level
;
auto
disabled
=
DisabledPasses
(
ctx_node
->
disabled_pass
);
int
opt_level
=
pass_ctx
->
opt_level
;
auto
disabled
=
DisabledPasses
(
pass_ctx
->
disabled_pass
);
Module
mod
=
module
;
for
(
const
Pass
&
pass
:
passes
)
{
CHECK
(
pass
.
defined
())
<<
"Found undefined pass for optimization."
;
PassInfo
info
=
pass
->
Info
();
const
auto
&
pass_name
=
info
.
operator
->
()
->
name
;
const
auto
&
pass_opt_level
=
info
.
operator
->
()
->
opt_level
;
const
auto
&
pass_name
=
info
->
name
;
const
auto
&
pass_opt_level
=
info
->
opt_level
;
// Skip the pass if its optimization level is higher that the one of in the
// pass context or if this pass is disabled.
if
(
pass_opt_level
>
opt_level
||
disabled
.
count
(
pass_name
))
{
...
...
@@ -540,14 +525,7 @@ TVM_REGISTER_API("relay._transform.CreateModulePass")
TVM_REGISTER_API
(
"relay._transform.RunPass"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
Pass
pass
=
args
[
0
];
Module
mod
=
args
[
1
];
CHECK
(
pass
.
defined
())
<<
"Running an undefined pass is not allowed."
<<
"
\n
"
;
const
auto
*
pn
=
pass
.
operator
->
();
*
ret
=
(
*
pn
)(
mod
);
*
ret
=
args
[
0
].
operator
Pass
()(
args
[
1
]);
});
TVM_STATIC_IR_FUNCTOR_REGISTER
(
IRPrinter
,
vtable
)
...
...
@@ -602,11 +580,16 @@ TVM_REGISTER_NODE_TYPE(PassContextNode);
TVM_REGISTER_API
(
"relay._transform.PassContext"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
auto
pctx
=
PassContext
::
Create
();
int
opt_level
=
args
[
0
];
int
fallback_device
=
args
[
1
];
tvm
::
Array
<
tvm
::
Expr
>
required
=
args
[
2
];
tvm
::
Array
<
tvm
::
Expr
>
disabled
=
args
[
3
];
*
ret
=
PassContext
(
opt_level
,
fallback_device
,
required
,
disabled
);
pctx
->
opt_level
=
opt_level
;
pctx
->
fallback_device
=
fallback_device
;
pctx
->
required_pass
=
std
::
move
(
required
);
pctx
->
disabled_pass
=
std
::
move
(
disabled
);
*
ret
=
pctx
;
});
TVM_STATIC_IR_FUNCTOR_REGISTER
(
IRPrinter
,
vtable
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment