Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
89a88c57
Commit
89a88c57
authored
May 24, 2019
by
雾雨魔理沙
Committed by
Tianqi Chen
May 24, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[Relay] Start porting pass to the pass manager (#3191)
parent
7e648417
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
328 additions
and
69 deletions
+328
-69
include/tvm/relay/pass.h
+87
-52
include/tvm/relay/transform.h
+114
-6
src/relay/pass/dead_code.cc
+12
-0
src/relay/pass/device_annotation.cc
+12
-0
src/relay/pass/fold_constant.cc
+12
-0
src/relay/pass/forward_rewrite.cc
+31
-0
src/relay/pass/fuse_ops.cc
+14
-0
src/relay/pass/partial_eval.cc
+12
-0
src/relay/pass/pass_manager.cc
+8
-9
src/relay/pass/to_a_normal_form.cc
+12
-0
src/relay/pass/to_graph_normal_form.cc
+12
-0
tests/python/relay/test_pass_manager.py
+2
-2
No files found.
include/tvm/relay/pass.h
View file @
89a88c57
This diff is collapsed.
Click to expand it.
include/tvm/relay/transform.h
View file @
89a88c57
...
...
@@ -61,6 +61,7 @@
#include <tvm/relay/error.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/module.h>
#include <tvm/relay/op_attr_types.h>
#include <string>
#include <unordered_map>
#include <vector>
...
...
@@ -198,7 +199,7 @@ class Pass;
*/
class
PassNode
:
public
RelayNode
{
public
:
/*
/*
!
* \brief Get the pass information/meta data. */
virtual
PassInfo
Info
()
const
=
0
;
...
...
@@ -300,11 +301,118 @@ Pass CreateModulePass(
*
* \return The created function pass.
*/
Pass
CreateFunctionPass
(
const
runtime
::
TypedPackedFunc
<
Function
(
Function
,
PassContext
)
>&
pass_func
,
int
opt_level
,
const
std
::
string
&
name
,
const
tvm
::
Array
<
tvm
::
Expr
>&
required
);
TVM_DLL
Pass
CreateFunctionPass
(
const
runtime
::
TypedPackedFunc
<
Function
(
Function
,
Module
,
PassContext
)
>&
pass_func
,
int
opt_level
,
const
std
::
string
&
name
,
const
tvm
::
Array
<
tvm
::
Expr
>&
required
);
/*! \brief Remove expressions which does not effect the program result.
*
* It will remove let bindings which are not referenced,
* and inline let bindings that are only used once.
*
* For example, this pass should turn `let a = 1 in 2` into `2`,
* as the value of the expression does not depend on a.
*
* As another example, `let a = 1 in a` will be optimized into 1.
*
* \return the pass.
*/
TVM_DLL
Pass
DeadCodeElimination
();
/*!
* \brief Fold constant expressions.
*
* \return The pass.
*/
TVM_DLL
Pass
FoldConstant
();
/*!
* \brief Fuse operations into expr into seperate functions.
*
* \param fuse_opt_level Optimization level. If it is -1 it will be inferred from pass context.
*
* \return The pass.
*/
TVM_DLL
Pass
FuseOps
(
int
fuse_opt_level
=
-
1
);
/*!
* \brief Apply rewrite rules to rewrite the expr in post DFS order.
*
* \param rewrite_map_attr_name The Op's attr name which corresponds to the rewrite
* rule function.
* \param fcontext Additional callback to provide context argument for each call node.
* \param fmulti_ref_trigger Transformation function to be called when
* an Expr consumed by multiple callers.
*
* \return The pass.
*/
TVM_DLL
Pass
ForwardRewrite
(
const
std
::
string
&
rewrite_map_attr_name
,
std
::
function
<
NodeRef
(
const
Call
&
)
>
fcontext
=
nullptr
,
std
::
function
<
Expr
(
const
Expr
&
)
>
fmulti_ref_trigger
=
nullptr
);
/*!
* \brief Apply rewrite rules to rewrite the expr in post DFS order.
*
* \param rewrite_func The rewrite func that will apply to all operators.
* \param fcontext Additional callback to provide context argument for each call node.
* \param fmulti_ref_trigger Transformation function to be called when
* an Expr consumed by multiple callers.
*
* \return The pass.
*/
TVM_DLL
Pass
ForwardRewrite
(
const
FForwardRewrite
&
rewrite_func
,
std
::
function
<
NodeRef
(
const
Call
&
)
>
fcontext
=
nullptr
,
std
::
function
<
Expr
(
const
Expr
&
)
>
fmulti_ref_trigger
=
nullptr
);
/*!
* \brief Rewrite the annotated program.
*
* \param fallback_device The fallback device which is the default device for
* operators without annotation.
*
* \return The pass.
*/
TVM_DLL
Pass
RewriteAnnotatedOps
(
int
fallback_device
);
/*!
* \brief turn a dataflow graph into Administrative Normal Form, or A-Normal Form (ANF).
*
* It will turn an expression that is in a graph form (with sharing implicit),
* to an expression with explicit sharing (A-Normal Form).
*
* The scope of the root expression is the global scope.
*
* The scope of any non root expression is the least common ancestor of all it's scope.
*
* Values are ordered by post-DFS order in each scope.
*
* \return The pass.
*/
TVM_DLL
Pass
ToANormalForm
();
/*!
* \brief Remove let binding and directly share via pointer instead.
*
* It will remove all let binding,
* and turn all of the variable bound by let into direct pointer reference.
*
* \return the expression in graph normal form.
*/
TVM_DLL
Pass
ToGraphNormalForm
();
/*!
* \brief Aggressive constant propagation/constant folding/inlining.
*
* It will do as much computation in compile time as possible.
* It has two benefit: remove runtime overhead, and allow more optimization (typically fusion).
* As a side effect, code size will explode.
*
* \return the optimized expression.
*/
TVM_DLL
Pass
PartialEval
();
}
// namespace transform
}
// namespace relay
...
...
src/relay/pass/dead_code.cc
View file @
89a88c57
...
...
@@ -151,5 +151,17 @@ Expr DeadCodeElimination(const Expr& e) {
TVM_REGISTER_API
(
"relay._ir_pass.dead_code_elimination"
)
.
set_body_typed
(
DeadCodeElimination
);
namespace
transform
{
Pass
DeadCodeElimination
()
{
runtime
::
TypedPackedFunc
<
Function
(
Function
,
Module
,
PassContext
)
>
pass_func
=
[
=
](
Function
f
,
Module
m
,
PassContext
pc
)
{
return
Downcast
<
Function
>
(
DeadCodeElimination
(
f
));
};
return
CreateFunctionPass
(
pass_func
,
1
,
"dead_code_elimination"
,
{});
}
}
// namespace transform
}
// namespace relay
}
// namespace tvm
src/relay/pass/device_annotation.cc
View file @
89a88c57
...
...
@@ -550,6 +550,18 @@ TVM_REGISTER_API("relay._ir_pass.RewriteDeviceAnnotation")
TVM_REGISTER_API
(
"relay._ir_pass.CollectDeviceAnnotationOps"
)
.
set_body_typed
(
CollectDeviceAnnotationOps
);
namespace
transform
{
Pass
RewriteAnnotatedOps
(
int
fallback_device
)
{
runtime
::
TypedPackedFunc
<
Function
(
Function
,
Module
,
PassContext
)
>
pass_func
=
[
=
](
Function
f
,
Module
m
,
PassContext
pc
)
{
return
Downcast
<
Function
>
(
RewriteAnnotatedOps
(
f
,
fallback_device
));
};
return
CreateFunctionPass
(
pass_func
,
1
,
"rewrite_annotated_ops"
,
{});
}
}
// namespace transform
}
// namespace relay
}
// namespace tvm
src/relay/pass/fold_constant.cc
View file @
89a88c57
...
...
@@ -215,5 +215,17 @@ Expr FoldConstant(const Expr& expr) {
TVM_REGISTER_API
(
"relay._ir_pass.FoldConstant"
)
.
set_body_typed
(
FoldConstant
);
namespace
transform
{
Pass
FoldConstant
()
{
runtime
::
TypedPackedFunc
<
Function
(
Function
,
Module
,
PassContext
)
>
pass_func
=
[
=
](
Function
f
,
Module
m
,
PassContext
pc
)
{
return
Downcast
<
Function
>
(
FoldConstant
(
f
));
};
return
CreateFunctionPass
(
pass_func
,
1
,
"fold_constant"
,
{});
}
}
// namespace transform
}
// namespace relay
}
// namespace tvm
src/relay/pass/forward_rewrite.cc
View file @
89a88c57
...
...
@@ -206,6 +206,37 @@ Expr ForwardRewrite(const Expr& expr,
return
ForwardRewriter
(
&
rewrite_func
,
fcontext
,
fmulti_ref_trigger
).
Rewrite
(
expr
);
}
namespace
transform
{
using
std
::
function
;
Pass
ForwardRewrite
(
const
std
::
string
&
rewrite_map_attr_name
,
function
<
NodeRef
(
const
Call
&
)
>
fcontext
,
function
<
Expr
(
const
Expr
&
)
>
fmulti_ref_trigger
)
{
runtime
::
TypedPackedFunc
<
Function
(
Function
,
Module
,
PassContext
)
>
pass_func
=
[
=
](
Function
f
,
Module
m
,
PassContext
pc
)
{
return
Downcast
<
Function
>
(
ForwardRewrite
(
f
,
rewrite_map_attr_name
,
fcontext
,
fmulti_ref_trigger
));
};
return
CreateFunctionPass
(
pass_func
,
1
,
"forward_rewrite"
,
{});
}
Pass
ForwardRewrite
(
const
FForwardRewrite
&
rewrite_func
,
function
<
NodeRef
(
const
Call
&
)
>
fcontext
,
function
<
Expr
(
const
Expr
&
)
>
fmulti_ref_trigger
)
{
runtime
::
TypedPackedFunc
<
Function
(
Function
,
Module
,
PassContext
)
>
pass_func
=
[
=
](
Function
f
,
Module
m
,
PassContext
pc
)
{
return
Downcast
<
Function
>
(
ForwardRewrite
(
f
,
rewrite_func
,
fcontext
,
fmulti_ref_trigger
));
};
return
CreateFunctionPass
(
pass_func
,
1
,
"forward_rewrite"
,
{});
}
}
// namespace transform
}
// namespace relay
}
// namespace tvm
src/relay/pass/fuse_ops.cc
View file @
89a88c57
...
...
@@ -964,5 +964,19 @@ Expr FuseOps(const Expr& expr, int fuse_opt_level, const Module& module) {
TVM_REGISTER_API
(
"relay._ir_pass.FuseOps"
)
.
set_body_typed
(
FuseOps
);
namespace
transform
{
Pass
FuseOps
(
int
fuse_opt_level
)
{
runtime
::
TypedPackedFunc
<
Function
(
Function
,
Module
,
PassContext
)
>
pass_func
=
[
=
](
Function
f
,
Module
m
,
PassContext
pc
)
{
int
opt_level
=
fuse_opt_level
==
-
1
?
pc
->
opt_level
:
fuse_opt_level
;
return
Downcast
<
Function
>
(
FuseOps
(
f
,
opt_level
,
m
));
};
return
CreateFunctionPass
(
pass_func
,
1
,
"fuse_ops"
,
{});
}
}
// namespace transform
}
// namespace relay
}
// namespace tvm
src/relay/pass/partial_eval.cc
View file @
89a88c57
...
...
@@ -801,5 +801,17 @@ TVM_REGISTER_API("relay._ir_pass.partial_evaluate")
*
ret
=
PartialEval
(
args
[
0
]);
});
namespace
transform
{
Pass
PartialEval
()
{
runtime
::
TypedPackedFunc
<
Function
(
Function
,
Module
,
PassContext
)
>
pass_func
=
[
=
](
Function
f
,
Module
m
,
PassContext
pc
)
{
return
Downcast
<
Function
>
(
PartialEval
(
f
));
};
return
CreateFunctionPass
(
pass_func
,
1
,
"partial_eval"
,
{});
}
}
// namespace transform
}
// namespace relay
}
// namespace tvm
src/relay/pass/pass_manager.cc
View file @
89a88c57
...
...
@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...
...
@@ -201,7 +201,7 @@ class FunctionPassNode : public PassNode {
* `pass_func` and let it run on a given module. The same `pass_func` will
* then be applied on each function in the module.
*/
runtime
::
TypedPackedFunc
<
Function
(
Function
,
PassContext
)
>
pass_func
;
runtime
::
TypedPackedFunc
<
Function
(
Function
,
Module
,
PassContext
)
>
pass_func
;
FunctionPassNode
()
=
default
;
...
...
@@ -225,7 +225,7 @@ class FunctionPassNode : public PassNode {
PassInfo
Info
()
const
{
return
pass_info
;
}
TVM_DLL
static
FunctionPass
make
(
runtime
::
TypedPackedFunc
<
Function
(
Function
,
PassContext
)
>
pass_func
,
runtime
::
TypedPackedFunc
<
Function
(
Function
,
Module
,
PassContext
)
>
pass_func
,
PassInfo
pass_info
);
static
constexpr
const
char
*
_type_key
=
"relay.FunctionPass"
;
...
...
@@ -363,7 +363,7 @@ Module ModulePassNode::operator()(const Module& mod,
}
FunctionPass
FunctionPassNode
::
make
(
runtime
::
TypedPackedFunc
<
Function
(
Function
,
PassContext
)
>
pass_func
,
runtime
::
TypedPackedFunc
<
Function
(
Function
,
Module
,
PassContext
)
>
pass_func
,
PassInfo
pass_info
)
{
auto
n
=
make_node
<
FunctionPassNode
>
();
n
->
pass_func
=
std
::
move
(
pass_func
);
...
...
@@ -383,8 +383,7 @@ Module FunctionPassNode::operator()(const Module& mod,
// Execute the pass function and return a new module.
for
(
const
auto
&
it
:
mod
->
functions
)
{
auto
updated_func
=
SkipFunction
(
it
.
second
)
?
it
.
second
:
pass_func
(
it
.
second
,
pass_ctx
);
auto
updated_func
=
SkipFunction
(
it
.
second
)
?
it
.
second
:
pass_func
(
it
.
second
,
mod
,
pass_ctx
);
new_mod
->
Add
(
it
.
first
,
updated_func
);
}
...
...
@@ -501,7 +500,7 @@ Pass CreateModulePass(
}
Pass
CreateFunctionPass
(
const
runtime
::
TypedPackedFunc
<
Function
(
Function
,
PassContext
)
>&
pass_func
,
const
runtime
::
TypedPackedFunc
<
Function
(
Function
,
Module
,
PassContext
)
>&
pass_func
,
int
opt_level
,
const
std
::
string
&
name
,
const
tvm
::
Array
<
tvm
::
Expr
>&
required
)
{
...
...
@@ -589,7 +588,7 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
tvm
::
IRPrinter
*
p
)
{
const
PassInfoNode
*
seq_pn
=
node
->
Info
().
operator
->
();
p
->
stream
<<
"Run Sequential pass: "
<<
seq_pn
->
name
<<
" at the optimization level
. "
<<
seq_pn
->
opt_level
;
<<
" at the optimization level
"
<<
seq_pn
->
opt_level
<<
". "
;
p
->
stream
<<
"The passes will be executed are: ["
;
for
(
const
auto
&
it
:
node
->
passes
)
{
const
PassNode
*
pn
=
it
.
operator
->
();
...
...
src/relay/pass/to_a_normal_form.cc
View file @
89a88c57
...
...
@@ -333,5 +333,17 @@ Expr ToANormalForm(const Expr& e, const Module& m) {
TVM_REGISTER_API
(
"relay._ir_pass.to_a_normal_form"
)
.
set_body_typed
(
static_cast
<
Expr
(
*
)(
const
Expr
&
,
const
Module
&
)
>
(
ToANormalForm
));
namespace
transform
{
Pass
ToANormalForm
()
{
runtime
::
TypedPackedFunc
<
Function
(
Function
,
Module
,
PassContext
)
>
pass_func
=
[
=
](
Function
f
,
Module
m
,
PassContext
pc
)
{
return
Downcast
<
Function
>
(
ToANormalForm
(
f
,
m
));
};
return
CreateFunctionPass
(
pass_func
,
1
,
"to_a_normal_form"
,
{});
}
}
// namespace transform
}
// namespace relay
}
// namespace tvm
src/relay/pass/to_graph_normal_form.cc
View file @
89a88c57
...
...
@@ -79,5 +79,17 @@ Expr ToGraphNormalForm(const Expr& e) {
TVM_REGISTER_API
(
"relay._ir_pass.to_graph_normal_form"
)
.
set_body_typed
(
ToGraphNormalForm
);
namespace
transform
{
Pass
ToGraphNormalForm
()
{
runtime
::
TypedPackedFunc
<
Function
(
Function
,
Module
,
PassContext
)
>
pass_func
=
[
=
](
Function
f
,
Module
m
,
PassContext
pc
)
{
return
Downcast
<
Function
>
(
ToGraphNormalForm
(
f
));
};
return
CreateFunctionPass
(
pass_func
,
1
,
"to_graph_normal_form"
,
{});
}
}
// namespace transform
}
// namespace relay
}
// namespace tvm
tests/python/relay/test_pass_manager.py
View file @
89a88c57
...
...
@@ -204,7 +204,7 @@ def test_function_pass():
pass_ctx
=
None
@_transform.function_pass
(
opt_level
=
opt_level
,
name
=
pass_name
)
def
transform
(
expr
,
ctx
):
def
transform
(
expr
,
mod
,
ctx
):
return
opt_tester
.
transform
(
expr
,
ctx
)
def
get_ref_log
():
...
...
@@ -303,7 +303,7 @@ def test_sequential_pass():
# Register a function pass.
@_transform.function_pass
(
opt_level
=
1
)
def
func_transform
(
expr
,
ctx
):
def
func_transform
(
expr
,
mod
,
ctx
):
return
opt_tester
.
transform
(
expr
,
ctx
)
function_pass
=
func_transform
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment