Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
953ca1f6
Unverified
Commit
953ca1f6
authored
May 28, 2019
by
Tianqi Chen
Committed by
GitHub
May 28, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[C++] Cleanup transform API nits (#3253)
parent
a8275bdb
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
96 additions
and
76 deletions
+96
-76
include/tvm/relay/transform.h
+72
-35
src/relay/pass/pass_manager.cc
+24
-41
No files found.
include/tvm/relay/transform.h
View file @
953ca1f6
...
@@ -76,8 +76,8 @@ namespace transform {
...
@@ -76,8 +76,8 @@ namespace transform {
class
PassContext
;
class
PassContext
;
/*!
/*!
* \brief PassContextNode contains the information that a pass can rely on,
such as
* \brief PassContextNode contains the information that a pass can rely on,
* analysis results.
*
such as
analysis results.
*/
*/
class
PassContextNode
:
public
RelayNode
{
class
PassContextNode
:
public
RelayNode
{
public
:
public
:
...
@@ -110,32 +110,51 @@ class PassContextNode : public RelayNode {
...
@@ -110,32 +110,51 @@ class PassContextNode : public RelayNode {
TVM_DECLARE_NODE_TYPE_INFO
(
PassContextNode
,
RelayNode
);
TVM_DECLARE_NODE_TYPE_INFO
(
PassContextNode
,
RelayNode
);
};
};
/*!
* \brief PassContext that is used to configure the pass behavior.
*
* \code
*
* auto new_ctx = PassContext::Create();
* ctx->opt_level = 2;
* ctx->fallback_device = kDLCPU;
* With<PassContext> scope(ctx);
* // pass context in effect.
*
* \endcode
*/
class
PassContext
:
public
NodeRef
{
class
PassContext
:
public
NodeRef
{
public
:
public
:
PassContext
()
{}
PassContext
()
{}
explicit
PassContext
(
tvm
::
NodePtr
<
Node
>
n
)
:
NodeRef
(
n
)
{}
explicit
PassContext
(
NodePtr
<::
tvm
::
Node
>
n
)
:
NodeRef
(
n
)
{}
/*!
/*
* \brief const accessor.
* \brief Constructor of a `PassContext` object.
* \return const access pointer.
*
*/
* \param opt_level The optimization level that will be applied.
const
PassContextNode
*
operator
->
()
const
{
* \param fallback_device The fallback device used for heterogeneous
CHECK
(
node_
.
get
()
!=
nullptr
);
* execution.
return
static_cast
<
const
PassContextNode
*>
(
node_
.
get
());
* \param required_pass The passes that are required for a context to execute
}
* other passes.
/*!
* \param required_pass The passes that will be disabled during the
* \brief mutable accessor.
* optimization under a context.
* \return mutable access pointer.
*/
PassContextNode
*
operator
->
()
{
CHECK
(
node_
.
get
()
!=
nullptr
);
return
static_cast
<
PassContextNode
*>
(
node_
.
get
());
}
/*!
* \brief Construct a PassContext containing the default configurations.
* \return The new PassContext.
*/
TVM_DLL
static
PassContext
Create
();
/*!
* \brief Get the default pass context in the current scope.
* \return The pass context.
*/
*/
TVM_DLL
PassContext
(
int
opt_level
,
int
fallback_device
,
tvm
::
Array
<
tvm
::
Expr
>
required_pass
,
tvm
::
Array
<
tvm
::
Expr
>
disabled_pass
);
// Get the currently used pass context.
TVM_DLL
static
PassContext
Current
();
TVM_DLL
static
PassContext
Current
();
const
PassContextNode
*
operator
->
()
const
;
// accessor.
using
ContainerType
=
PassContextNode
;
using
ContainerType
=
PassContextNode
;
class
Internal
;
class
Internal
;
...
@@ -204,25 +223,23 @@ class PassNode : public RelayNode {
...
@@ -204,25 +223,23 @@ class PassNode : public RelayNode {
virtual
PassInfo
Info
()
const
=
0
;
virtual
PassInfo
Info
()
const
=
0
;
/*!
/*!
* \brief Execute the optimization pass using a functor. This functor
* \brief Transform mod using the default PassContext in the current scope.
* internally uses a current pass context.
*
*
* \param mod The module that an optimization pass runs on.
* \param mod The module that an optimization pass runs on.
*
*
* \return The
updat
ed module.
* \return The
transform
ed module.
*/
*/
Module
operator
()(
const
Module
&
mod
)
const
{
Module
operator
()(
const
Module
&
mod
)
const
{
return
this
->
operator
()(
mod
,
PassContext
::
Current
());
return
this
->
operator
()(
mod
,
PassContext
::
Current
());
}
}
/*!
/*!
* \brief
Execute the optimization pass
using a functor under a given pass context.
* \brief
Transform mod
using a functor under a given pass context.
*
*
* \param mod The module that an optimization pass runs on.
* \param mod The module that an optimization pass runs on.
* \param pass_ctx The pass context that will be used to help the execution of
* \param pass_ctx The pass context that can provide information for the optimization.
* optimizations.
*
*
* \return The
updat
ed module.
* \return The
transform
ed module.
*/
*/
virtual
Module
operator
()(
const
Module
&
mod
,
virtual
Module
operator
()(
const
Module
&
mod
,
const
PassContext
&
pass_ctx
)
const
=
0
;
const
PassContext
&
pass_ctx
)
const
=
0
;
...
@@ -235,14 +252,34 @@ class PassNode : public RelayNode {
...
@@ -235,14 +252,34 @@ class PassNode : public RelayNode {
class
Pass
:
public
NodeRef
{
class
Pass
:
public
NodeRef
{
public
:
public
:
Pass
()
=
default
;
/*!
explicit
Pass
(
NodePtr
<
tvm
::
Node
>
p
)
:
NodeRef
(
p
)
{}
* \brief Transform mod using the default PassContext in the current scope.
*
PassNode
*
operator
->
()
const
{
* \param mod The module that an optimization pass runs on.
return
static_cast
<
PassNode
*>
(
this
->
node_
.
get
());
*
* \return The transformed module.
*/
Module
operator
()(
const
Module
&
mod
)
const
{
const
PassNode
*
node
=
operator
->
();
CHECK
(
node
!=
nullptr
);
return
node
->
operator
()(
mod
);
}
/*!
* \brief Transform mod using a functor under a given pass context.
*
* \param mod The module that an optimization pass runs on.
* \param pass_ctx The pass context that can provide information for the optimization.
*
* \return The transformed module.
*/
Module
operator
()(
const
Module
&
mod
,
const
PassContext
&
pass_ctx
)
const
{
const
PassNode
*
node
=
operator
->
();
CHECK
(
node
!=
nullptr
);
return
node
->
operator
()(
mod
,
pass_ctx
);
}
}
using
ContainerType
=
PassNode
;
TVM_DEFINE_NODE_REF_METHODS
(
Pass
,
NodeRef
,
PassNode
)
;
};
};
class
SequentialNode
;
class
SequentialNode
;
...
...
src/relay/pass/pass_manager.cc
View file @
953ca1f6
...
@@ -74,21 +74,6 @@ class OptPassLevel {
...
@@ -74,21 +74,6 @@ class OptPassLevel {
}
}
};
};
PassContext
::
PassContext
(
int
opt_level
,
int
fallback_device
,
tvm
::
Array
<
tvm
::
Expr
>
required_pass
,
tvm
::
Array
<
tvm
::
Expr
>
disabled_pass
)
{
auto
ctx
=
make_node
<
PassContextNode
>
();
ctx
->
opt_level
=
opt_level
;
ctx
->
fallback_device
=
fallback_device
;
ctx
->
required_pass
=
std
::
move
(
required_pass
);
ctx
->
disabled_pass
=
std
::
move
(
disabled_pass
);
node_
=
std
::
move
(
ctx
);
}
const
PassContextNode
*
PassContext
::
operator
->
()
const
{
return
static_cast
<
const
PassContextNode
*>
(
node_
.
get
());
}
struct
RelayPassContextThreadLocalEntry
{
struct
RelayPassContextThreadLocalEntry
{
/*! \brief The default pass context. */
/*! \brief The default pass context. */
PassContext
default_context
;
PassContext
default_context
;
...
@@ -129,6 +114,10 @@ PassContext PassContext::Current() {
...
@@ -129,6 +114,10 @@ PassContext PassContext::Current() {
}
}
}
}
PassContext
PassContext
::
Create
()
{
return
PassContext
(
make_node
<
PassContextNode
>
());
}
class
ModulePass
;
class
ModulePass
;
/*!
/*!
...
@@ -291,7 +280,7 @@ class SequentialNode : public PassNode {
...
@@ -291,7 +280,7 @@ class SequentialNode : public PassNode {
*
*
* \return true if the pass is enabled. Otherwise, false.
* \return true if the pass is enabled. Otherwise, false.
*/
*/
bool
pass_e
nabled
(
const
std
::
string
&
pass_name
)
const
;
bool
PassE
nabled
(
const
std
::
string
&
pass_name
)
const
;
/*!
/*!
* \brief Resolve the pass dependency. It globs all required passes by
* \brief Resolve the pass dependency. It globs all required passes by
...
@@ -353,9 +342,8 @@ ModulePass ModulePassNode::make(
...
@@ -353,9 +342,8 @@ ModulePass ModulePassNode::make(
Module
ModulePassNode
::
operator
()(
const
Module
&
mod
,
Module
ModulePassNode
::
operator
()(
const
Module
&
mod
,
const
PassContext
&
pass_ctx
)
const
{
const
PassContext
&
pass_ctx
)
const
{
PassInfo
pass_info
=
Info
();
PassInfo
pass_info
=
Info
();
LOG
(
INFO
)
<<
"Executing module pass : "
<<
pass_info
.
operator
->
()
->
name
DLOG
(
INFO
)
<<
"Executing module pass : "
<<
pass_info
->
name
<<
" with opt level: "
<<
pass_info
.
operator
->
()
->
opt_level
<<
"
\n
"
;
<<
" with opt level: "
<<
pass_info
->
opt_level
<<
"
\n
"
;
CHECK
(
mod
.
defined
());
CHECK
(
mod
.
defined
());
auto
updated_mod
=
pass_func
(
mod
,
pass_ctx
);
auto
updated_mod
=
pass_func
(
mod
,
pass_ctx
);
CHECK
(
updated_mod
.
defined
());
CHECK
(
updated_mod
.
defined
());
...
@@ -376,11 +364,10 @@ FunctionPass FunctionPassNode::make(
...
@@ -376,11 +364,10 @@ FunctionPass FunctionPassNode::make(
Module
FunctionPassNode
::
operator
()(
const
Module
&
mod
,
Module
FunctionPassNode
::
operator
()(
const
Module
&
mod
,
const
PassContext
&
pass_ctx
)
const
{
const
PassContext
&
pass_ctx
)
const
{
PassInfo
pass_info
=
Info
();
PassInfo
pass_info
=
Info
();
LOG
(
INFO
)
<<
"Executing function pass : "
<<
pass_info
.
operator
->
()
->
name
<<
" with opt level: "
<<
pass_info
.
operator
->
()
->
opt_level
<<
"
\n
"
;
CHECK
(
mod
.
defined
());
CHECK
(
mod
.
defined
());
Module
new_mod
=
ModuleNode
::
make
({},
mod
->
type_definitions
);
Module
new_mod
=
ModuleNode
::
make
({},
mod
->
type_definitions
);
DLOG
(
INFO
)
<<
"Executing module pass : "
<<
pass_info
->
name
<<
" with opt level: "
<<
pass_info
->
opt_level
<<
"
\n
"
;
// Execute the pass function and return a new module.
// Execute the pass function and return a new module.
for
(
const
auto
&
it
:
mod
->
functions
)
{
for
(
const
auto
&
it
:
mod
->
functions
)
{
auto
updated_func
=
SkipFunction
(
it
.
second
)
?
it
.
second
:
pass_func
(
it
.
second
,
mod
,
pass_ctx
);
auto
updated_func
=
SkipFunction
(
it
.
second
)
?
it
.
second
:
pass_func
(
it
.
second
,
mod
,
pass_ctx
);
...
@@ -448,12 +435,11 @@ std::unordered_set<std::string> SequentialNode::RequiredPasses(
...
@@ -448,12 +435,11 @@ std::unordered_set<std::string> SequentialNode::RequiredPasses(
return
ret
;
return
ret
;
}
}
bool
SequentialNode
::
pass_e
nabled
(
const
std
::
string
&
pass_name
)
const
{
bool
SequentialNode
::
PassE
nabled
(
const
std
::
string
&
pass_name
)
const
{
PassContext
ctx
=
PassContext
::
Current
();
PassContext
ctx
=
PassContext
::
Current
();
const
PassContextNode
*
ctx_node
=
ctx
.
operator
->
();
auto
required
=
RequiredPasses
(
ctx
->
required_pass
);
auto
required
=
RequiredPasses
(
ctx_node
->
required_pass
);
auto
disabled
=
DisabledPasses
(
ctx
->
required_pass
);
auto
disabled
=
DisabledPasses
(
ctx_node
->
required_pass
);
if
(
disabled
.
count
(
pass_name
))
{
if
(
disabled
.
count
(
pass_name
))
{
return
false
;
return
false
;
...
@@ -462,7 +448,7 @@ bool SequentialNode::pass_enabled(const std::string& pass_name) const {
...
@@ -462,7 +448,7 @@ bool SequentialNode::pass_enabled(const std::string& pass_name) const {
if
(
required
.
count
(
pass_name
))
{
if
(
required
.
count
(
pass_name
))
{
return
true
;
return
true
;
}
}
return
ctx
_node
->
opt_level
>=
opt_pass_level
[
pass_name
];
return
ctx
->
opt_level
>=
opt_pass_level
[
pass_name
];
}
}
// TODO(zhiics): we currenlty only sequentially execute each pass in
// TODO(zhiics): we currenlty only sequentially execute each pass in
...
@@ -470,15 +456,14 @@ bool SequentialNode::pass_enabled(const std::string& pass_name) const {
...
@@ -470,15 +456,14 @@ bool SequentialNode::pass_enabled(const std::string& pass_name) const {
// ordering problem needed to be handled in the future.
// ordering problem needed to be handled in the future.
Module
SequentialNode
::
operator
()(
const
Module
&
module
,
Module
SequentialNode
::
operator
()(
const
Module
&
module
,
const
PassContext
&
pass_ctx
)
const
{
const
PassContext
&
pass_ctx
)
const
{
const
auto
*
ctx_node
=
pass_ctx
.
operator
->
();
int
opt_level
=
pass_ctx
->
opt_level
;
int
opt_level
=
ctx_node
->
opt_level
;
auto
disabled
=
DisabledPasses
(
pass_ctx
->
disabled_pass
);
auto
disabled
=
DisabledPasses
(
ctx_node
->
disabled_pass
);
Module
mod
=
module
;
Module
mod
=
module
;
for
(
const
Pass
&
pass
:
passes
)
{
for
(
const
Pass
&
pass
:
passes
)
{
CHECK
(
pass
.
defined
())
<<
"Found undefined pass for optimization."
;
CHECK
(
pass
.
defined
())
<<
"Found undefined pass for optimization."
;
PassInfo
info
=
pass
->
Info
();
PassInfo
info
=
pass
->
Info
();
const
auto
&
pass_name
=
info
.
operator
->
()
->
name
;
const
auto
&
pass_name
=
info
->
name
;
const
auto
&
pass_opt_level
=
info
.
operator
->
()
->
opt_level
;
const
auto
&
pass_opt_level
=
info
->
opt_level
;
// Skip the pass if its optimization level is higher that the one of in the
// Skip the pass if its optimization level is higher that the one of in the
// pass context or if this pass is disabled.
// pass context or if this pass is disabled.
if
(
pass_opt_level
>
opt_level
||
disabled
.
count
(
pass_name
))
{
if
(
pass_opt_level
>
opt_level
||
disabled
.
count
(
pass_name
))
{
...
@@ -540,14 +525,7 @@ TVM_REGISTER_API("relay._transform.CreateModulePass")
...
@@ -540,14 +525,7 @@ TVM_REGISTER_API("relay._transform.CreateModulePass")
TVM_REGISTER_API
(
"relay._transform.RunPass"
)
TVM_REGISTER_API
(
"relay._transform.RunPass"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
Pass
pass
=
args
[
0
];
*
ret
=
args
[
0
].
operator
Pass
()(
args
[
1
]);
Module
mod
=
args
[
1
];
CHECK
(
pass
.
defined
())
<<
"Running an undefined pass is not allowed."
<<
"
\n
"
;
const
auto
*
pn
=
pass
.
operator
->
();
*
ret
=
(
*
pn
)(
mod
);
});
});
TVM_STATIC_IR_FUNCTOR_REGISTER
(
IRPrinter
,
vtable
)
TVM_STATIC_IR_FUNCTOR_REGISTER
(
IRPrinter
,
vtable
)
...
@@ -602,11 +580,16 @@ TVM_REGISTER_NODE_TYPE(PassContextNode);
...
@@ -602,11 +580,16 @@ TVM_REGISTER_NODE_TYPE(PassContextNode);
TVM_REGISTER_API
(
"relay._transform.PassContext"
)
TVM_REGISTER_API
(
"relay._transform.PassContext"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
auto
pctx
=
PassContext
::
Create
();
int
opt_level
=
args
[
0
];
int
opt_level
=
args
[
0
];
int
fallback_device
=
args
[
1
];
int
fallback_device
=
args
[
1
];
tvm
::
Array
<
tvm
::
Expr
>
required
=
args
[
2
];
tvm
::
Array
<
tvm
::
Expr
>
required
=
args
[
2
];
tvm
::
Array
<
tvm
::
Expr
>
disabled
=
args
[
3
];
tvm
::
Array
<
tvm
::
Expr
>
disabled
=
args
[
3
];
*
ret
=
PassContext
(
opt_level
,
fallback_device
,
required
,
disabled
);
pctx
->
opt_level
=
opt_level
;
pctx
->
fallback_device
=
fallback_device
;
pctx
->
required_pass
=
std
::
move
(
required
);
pctx
->
disabled_pass
=
std
::
move
(
disabled
);
*
ret
=
pctx
;
});
});
TVM_STATIC_IR_FUNCTOR_REGISTER
(
IRPrinter
,
vtable
)
TVM_STATIC_IR_FUNCTOR_REGISTER
(
IRPrinter
,
vtable
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment