Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
70041c48
Commit
70041c48
authored
Jun 10, 2019
by
Zhi
Committed by
Jared Roesch
Jun 10, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[relay][vm] move vm opt passes to pass manager (#3323)
parent
8f219b95
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
150 additions
and
113 deletions
+150
-113
python/tvm/relay/backend/vm.py
+33
-19
src/relay/backend/vm/compiler.cc
+17
-7
src/relay/backend/vm/inline_primitives.cc
+49
-43
src/relay/backend/vm/lambda_lift.cc
+41
-39
src/relay/pass/pass_manager.cc
+10
-5
No files found.
python/tvm/relay/backend/vm.py
View file @
70041c48
...
...
@@ -20,24 +20,45 @@ The Relay Virtual Vachine.
Implements a Python interface to compiling and executing on the Relay VM.
"""
import
numpy
as
np
import
tvm
from
tvm._ffi.function
import
Object
import
numpy
as
np
from
..
import
ir_pass
from
..
import
transform
from
..backend.interpreter
import
Executor
from
..expr
import
GlobalVar
,
Function
,
Expr
from
..expr
import
GlobalVar
,
Expr
from
.
import
_vm
Object
=
Object
def
optimize
(
expr
,
mod
=
None
):
# TODO: We need to move this optimization code into the optimizer/pass manager
ck_expr
=
ir_pass
.
infer_type
(
expr
,
mod
=
mod
)
simplified_expr
=
ir_pass
.
simplify_inference
(
ck_expr
)
simplified_expr
=
ir_pass
.
infer_type
(
simplified_expr
,
mod
=
mod
)
fused_expr
=
ir_pass
.
fuse_ops
(
simplified_expr
,
mod
=
mod
)
ck_fused
=
ir_pass
.
infer_type
(
fused_expr
,
mod
=
mod
)
return
ck_fused
def
optimize
(
mod
):
"""Perform several optimizations on a module before executing it in the
Relay virtual machine.
Parameters
----------
mod : tvm.relay.Module
The module to optimize.
Returns
-------
ret : tvm.relay.Module
The optimized module.
"""
main_func
=
mod
[
mod
.
entry_func
]
opt_passes
=
[]
if
not
main_func
.
params
and
isinstance
(
main_func
.
body
,
GlobalVar
):
opt_passes
.
append
(
transform
.
EtaExpand
())
opt_passes
=
opt_passes
+
[
transform
.
SimplifyInference
(),
transform
.
FuseOps
(),
transform
.
InferType
()
]
seq
=
transform
.
Sequential
(
opt_passes
)
return
seq
(
mod
)
def
_convert
(
arg
,
cargs
):
if
isinstance
(
arg
,
np
.
ndarray
):
...
...
@@ -76,15 +97,8 @@ def _eval_vm(mod, ctx, *args):
args: List[tvm.NDArray, np.ndarray]
The arguments to evaluate.
"""
main_func
=
mod
[
mod
.
entry_func
]
if
not
main_func
.
params
and
isinstance
(
main_func
.
body
,
GlobalVar
):
main_func
=
ir_pass
.
eta_expand
(
main_func
.
body
,
mod
)
assert
isinstance
(
main_func
,
Function
)
main_func
=
optimize
(
mod
[
mod
.
entry_func
],
mod
)
mod
[
mod
.
entry_func
]
=
main_func
mod
=
optimize
(
mod
)
args
=
list
(
args
)
assert
isinstance
(
args
,
list
)
cargs
=
convert
(
args
)
...
...
src/relay/backend/vm/compiler.cc
View file @
70041c48
...
...
@@ -27,7 +27,7 @@
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/interpreter.h>
#include <tvm/logging.h>
#include <tvm/relay/
pass
.h>
#include <tvm/relay/
transform
.h>
#include <tvm/runtime/vm.h>
#include <iostream>
#include <unordered_map>
...
...
@@ -38,15 +38,22 @@
namespace
tvm
{
namespace
relay
{
namespace
transform
{
Pass
LambdaLift
();
Pass
InlinePrimitives
();
}
// namespace transform
namespace
vm
{
using
namespace
tvm
::
runtime
;
using
namespace
tvm
::
runtime
::
vm
;
using
namespace
relay
::
transform
;
// (@jroesch): VM passes, eventually declare as passes.
bool
IsClosure
(
const
Function
&
func
);
Module
LambdaLift
(
const
Module
&
module
);
Module
InlinePrimitives
(
const
Module
&
module
);
template
<
typename
T
,
typename
U
>
using
NodeMap
=
std
::
unordered_map
<
T
,
U
,
NodeHash
,
NodeEqual
>
;
...
...
@@ -560,10 +567,13 @@ VMFunction CompileFunc(VMCompilerContext* context, const GlobalVar& var, const F
}
Module
OptimizeModule
(
const
Module
&
mod
)
{
ToANormalForm
(
mod
->
entry_func
,
mod
);
InlinePrimitives
(
mod
);
LambdaLift
(
mod
);
return
InlinePrimitives
(
mod
);
transform
::
Sequential
seq
({
transform
::
ToANormalForm
(),
transform
::
InlinePrimitives
(),
transform
::
LambdaLift
(),
transform
::
InlinePrimitives
()});
auto
pass_ctx
=
transform
::
PassContext
::
Create
();
tvm
::
With
<
relay
::
transform
::
PassContext
>
ctx
(
pass_ctx
);
return
seq
(
mod
);
}
void
PopulateGlobalMap
(
GlobalMap
*
global_map
,
const
Module
&
mod
)
{
...
...
src/relay/backend/vm/inline_primitives.cc
View file @
70041c48
...
...
@@ -26,7 +26,7 @@
#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/logging.h>
#include <tvm/relay/
pass
.h>
#include <tvm/relay/
transform
.h>
#include <tvm/runtime/vm.h>
#include <iostream>
#include <vector>
...
...
@@ -37,6 +37,21 @@ namespace tvm {
namespace
relay
{
namespace
vm
{
// TODO(@jroesch): write verifier
/* This pass will eliminate primitives which have been lifted by the ANF
* transform inlining them directly into call sites.
*
* This makes VM related code generation easier as the call target is always
* a primitive function.
*
* let prim = fn(...) { ... };
* prim(...)
*
* will become:
*
* (fn(...) { ... })(...)
*/
struct
PrimitiveInliner
:
ExprMutator
{
Module
module_
;
std
::
unordered_map
<
Var
,
Expr
,
NodeHash
,
NodeEqual
>
var_map
;
...
...
@@ -92,55 +107,46 @@ struct PrimitiveInliner : ExprMutator {
}
}
Function
Inline
(
const
Function
&
func
)
{
DLOG
(
INFO
)
<<
"Before inlining primitives: "
<<
std
::
endl
<<
"func= "
<<
AsText
(
func
,
false
)
<<
std
::
endl
;
auto
inlined
=
FunctionNode
::
make
(
func
->
params
,
VisitExpr
(
func
->
body
),
func
->
ret_type
,
func
->
type_params
,
func
->
attrs
);
inlined
=
Downcast
<
Function
>
(
DeadCodeElimination
(
inlined
));
DLOG
(
INFO
)
<<
"After inlining primitives"
<<
std
::
endl
<<
"after_func= "
<<
AsText
(
inlined
,
false
)
<<
std
::
endl
;
return
inlined
;
Module
Inline
()
{
auto
gvar_funcs
=
module_
->
functions
;
for
(
auto
pair
:
gvar_funcs
)
{
auto
global
=
pair
.
first
;
auto
func
=
pair
.
second
;
DLOG
(
INFO
)
<<
"Before inlining primitives: "
<<
global
<<
std
::
endl
<<
AsText
(
func
,
false
);
func
=
FunctionNode
::
make
(
func
->
params
,
VisitExpr
(
func
->
body
),
func
->
ret_type
,
func
->
type_params
,
func
->
attrs
);
module_
->
Add
(
global
,
func
,
true
);
DLOG
(
INFO
)
<<
"After inlining primitives: "
<<
global
<<
std
::
endl
<<
AsText
(
func
,
false
);
}
return
module_
;
}
};
// TODO(@jroesch): write verifier
/* This pass will eliminate primitives which have been lifted by the ANF
* transform inlining them directly into call sites.
*
* This makes VM related code generation easier as the call target is always
* a primitive function.
*
* let prim = fn(...) { ... };
* prim(...)
*
* will become:
*
* (fn(...) { ... })(...)
*/
Module
InlinePrimitives
(
const
Module
&
module
)
{
PrimitiveInliner
inliner
(
module
);
}
// namespace vm
tvm
::
Map
<
GlobalVar
,
Function
>
updates
;
namespace
transform
{
// There is an ordering bug here.
for
(
auto
pair
:
module
->
functions
)
{
auto
global
=
pair
.
first
;
auto
func
=
pair
.
second
;
updates
.
Set
(
global
,
inliner
.
Inline
(
func
));
}
Pass
InlinePrimitives
()
{
runtime
::
TypedPackedFunc
<
Module
(
Module
,
PassContext
)
>
pass_func
=
[
=
](
Module
m
,
PassContext
pc
)
{
return
relay
::
vm
::
PrimitiveInliner
(
m
).
Inline
();
};
auto
inline_pass
=
CreateModulePass
(
pass_func
,
1
,
"Inline"
,
{});
// Eliminate dead code for each function after inlining.
return
Sequential
({
inline_pass
,
DeadCodeElimination
()},
"InlinePrimitives"
);
}
for
(
auto
pair
:
updates
)
{
module
->
Add
(
pair
.
first
,
pair
.
second
,
true
);
}
TVM_REGISTER_API
(
"relay._transform.InlinePrimitives"
)
.
set_body_typed
(
InlinePrimitives
);
return
module
;
}
}
// namespace transform
}
// namespace vm
}
// namespace relay
}
// namespace tvm
src/relay/backend/vm/lambda_lift.cc
View file @
70041c48
...
...
@@ -27,6 +27,7 @@
#include <tvm/relay/expr_functor.h>
#include <tvm/logging.h>
#include <tvm/relay/pass.h>
#include <tvm/relay/transform.h>
#include <tvm/runtime/vm.h>
#include <iostream>
#include <vector>
...
...
@@ -54,9 +55,14 @@ Function MarkClosure(const Function& func) {
return
FunctionSetAttr
(
func
,
kIsClosure
,
tvm
::
Integer
(
1
));
}
/* The goal of this class is to lift out any nested functions into top-level
* functions.
*
* We will lift a function out into a global which takes the set of the free
* vars and then return the new created function.
*/
struct
LambdaLifter
:
ExprMutator
{
Module
module_
;
std
::
vector
<
std
::
pair
<
GlobalVar
,
Function
>>
lifted_
;
explicit
LambdaLifter
(
const
Module
&
module
)
:
module_
(
module
)
{}
Expr
VisitExpr_
(
const
FunctionNode
*
func_node
)
final
{
...
...
@@ -71,8 +77,7 @@ struct LambdaLifter : ExprMutator {
auto
free_type_vars
=
FreeTypeVars
(
func
,
module_
);
auto
body
=
Downcast
<
Function
>
(
ExprMutator
::
VisitExpr_
(
func_node
));
// When performing this optimization there are two
// cases.
// When performing this optimization there are two cases.
//
// The first case in which we have no free variables
// we can just lift the function into the global
...
...
@@ -80,7 +85,7 @@ struct LambdaLifter : ExprMutator {
//
//
// The second case requires that we generate a special
// function w
it
h makes a distinction between allocating
// function w
hic
h makes a distinction between allocating
// a closure, and then the code for the closure.
//
// We represent a closure allocation by lifting the
...
...
@@ -92,7 +97,7 @@ struct LambdaLifter : ExprMutator {
// function marked as a closure is used to emit allocation
// code for the closure's environment.
//
// The "inner" function
is
should be used to generate the
// The "inner" function should be used to generate the
// code for the closure.
Function
lifted_func
;
if
(
free_vars
.
size
()
==
0
)
{
...
...
@@ -107,16 +112,16 @@ struct LambdaLifter : ExprMutator {
CHECK
(
lifted_func
.
defined
());
auto
name
=
GenerateName
(
lifted_func
);
auto
global
=
this
->
module_
->
GetGlobalVar
(
name
);
auto
global
=
module_
->
GetGlobalVar
(
name
);
lifted_
.
push_back
({
global
,
lifted_func
});
// Add the lifted function to the module.
module_
->
Add
(
global
,
lifted_func
);
if
(
free_vars
.
size
()
==
0
)
{
return
std
::
move
(
global
);
}
else
{
// If we need to allocate a closure
// we pass the variables in its environment
// here.
// If we need to allocate a closure,
// we pass the variables in its environment here.
Array
<
Expr
>
fvs
;
for
(
auto
fv
:
free_vars
)
{
fvs
.
push_back
(
fv
);
...
...
@@ -125,42 +130,39 @@ struct LambdaLifter : ExprMutator {
}
}
Function
Lift
(
const
Function
&
func
)
{
DLOG
(
INFO
)
<<
"Lifting: "
<<
AsText
(
func
,
false
)
<<
std
::
endl
;
return
FunctionNode
::
make
(
func
->
params
,
VisitExpr
(
func
->
body
),
func
->
ret_type
,
func
->
type_params
,
func
->
attrs
);
Module
Lift
()
{
// There is an ordering bug here.
auto
glob_funcs
=
module_
->
functions
;
for
(
auto
pair
:
glob_funcs
)
{
auto
func
=
pair
.
second
;
DLOG
(
INFO
)
<<
"Lifting "
<<
AsText
(
func
,
false
);
func
=
FunctionNode
::
make
(
func
->
params
,
VisitExpr
(
func
->
body
),
func
->
ret_type
,
func
->
type_params
,
func
->
attrs
);
module_
->
Add
(
pair
.
first
,
func
,
true
);
}
return
module_
;
}
};
/* The goal of this pass is to lift out any nested functions into top-level
* functions.
*
* We will lift the functions out into globals which take the set of the free vars
* and then return a function whcih has b
*/
Module
LambdaLift
(
const
Module
&
module
)
{
LambdaLifter
lifter
(
module
);
tvm
::
Map
<
GlobalVar
,
Function
>
updates
;
}
// namespace vm
// There is an ordering bug here.
for
(
auto
pair
:
module
->
functions
)
{
auto
global
=
pair
.
first
;
auto
func
=
pair
.
second
;
updates
.
Set
(
global
,
lifter
.
Lift
(
func
));
}
namespace
transform
{
for
(
auto
i
=
lifter
.
lifted_
.
begin
();
i
!=
lifter
.
lifted_
.
end
();
i
++
)
{
module
->
Add
(
i
->
first
,
i
->
second
);
}
Pass
LambdaLift
()
{
runtime
::
TypedPackedFunc
<
Module
(
Module
,
PassContext
)
>
pass_func
=
[
=
](
Module
m
,
PassContext
pc
)
{
return
relay
::
vm
::
LambdaLifter
(
m
).
Lift
();
};
return
CreateModulePass
(
pass_func
,
1
,
"LambdaLift"
,
{});
}
for
(
auto
pair
:
updates
)
{
module
->
Add
(
pair
.
first
,
pair
.
second
,
true
);
}
TVM_REGISTER_API
(
"relay._transform.LambdaLift"
)
.
set_body_typed
(
LambdaLift
);
return
module
;
}
}
// namespace transform
}
// namespace vm
}
// namespace relay
}
// namespace tvm
src/relay/pass/pass_manager.cc
View file @
70041c48
...
...
@@ -309,20 +309,24 @@ Module FunctionPassNode::operator()(const Module& mod,
const
PassContext
&
pass_ctx
)
const
{
const
PassInfo
&
pass_info
=
Info
();
CHECK
(
mod
.
defined
());
DLOG
(
INFO
)
<<
"Executing
module
pass : "
DLOG
(
INFO
)
<<
"Executing
function
pass : "
<<
pass_info
->
name
<<
" with opt level: "
<<
pass_info
->
opt_level
;
Module
updated_mod
=
mod
;
Module
new_mod
=
ModuleNode
::
make
({},
mod
->
type_definitions
);
// Execute the pass function and return a new module.
std
::
vector
<
std
::
pair
<
GlobalVar
,
Function
>
>
updates
;
for
(
const
auto
&
it
:
mod
->
functions
)
{
auto
updated_func
=
SkipFunction
(
it
.
second
)
?
it
.
second
:
pass_func
(
it
.
second
,
updated_mod
,
pass_ctx
);
new_mod
->
Add
(
it
.
first
,
updated_func
);
updates
.
push_back
({
it
.
first
,
updated_func
});
}
for
(
const
auto
&
pair
:
updates
)
{
updated_mod
->
Add
(
pair
.
first
,
pair
.
second
,
true
);
}
return
new
_mod
;
return
updated
_mod
;
}
// TODO(zhiics) Create an enum attribute for FunctionNode
...
...
@@ -539,7 +543,8 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
tvm
::
IRPrinter
*
p
)
{
p
->
stream
<<
"Pass context information: "
<<
"
\n
"
;
p
->
stream
<<
"
\t
opt_level: "
<<
node
->
opt_level
<<
"
\n
"
;
p
->
stream
<<
"
\t
fallback device: "
<<
runtime
::
DeviceName
(
node
->
opt_level
)
p
->
stream
<<
"
\t
fallback device: "
<<
runtime
::
DeviceName
(
node
->
fallback_device
)
<<
"
\n
"
;
p
->
stream
<<
"
\t
required passes: ["
<<
node
->
opt_level
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment