Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
70041c48
Commit
70041c48
authored
Jun 10, 2019
by
Zhi
Committed by
Jared Roesch
Jun 10, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[relay][vm] move vm opt passes to pass manager (#3323)
parent
8f219b95
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
150 additions
and
113 deletions
+150
-113
python/tvm/relay/backend/vm.py
+33
-19
src/relay/backend/vm/compiler.cc
+17
-7
src/relay/backend/vm/inline_primitives.cc
+49
-43
src/relay/backend/vm/lambda_lift.cc
+41
-39
src/relay/pass/pass_manager.cc
+10
-5
No files found.
python/tvm/relay/backend/vm.py
View file @
70041c48
...
@@ -20,24 +20,45 @@ The Relay Virtual Vachine.
...
@@ -20,24 +20,45 @@ The Relay Virtual Vachine.
Implements a Python interface to compiling and executing on the Relay VM.
Implements a Python interface to compiling and executing on the Relay VM.
"""
"""
import
numpy
as
np
import
tvm
import
tvm
from
tvm._ffi.function
import
Object
from
tvm._ffi.function
import
Object
import
numpy
as
np
from
..
import
transform
from
..
import
ir_pass
from
..backend.interpreter
import
Executor
from
..backend.interpreter
import
Executor
from
..expr
import
GlobalVar
,
Function
,
Expr
from
..expr
import
GlobalVar
,
Expr
from
.
import
_vm
from
.
import
_vm
Object
=
Object
Object
=
Object
def
optimize
(
expr
,
mod
=
None
):
def
optimize
(
mod
):
# TODO: We need to move this optimization code into the optimizer/pass manager
"""Perform several optimizations on a module before executing it in the
ck_expr
=
ir_pass
.
infer_type
(
expr
,
mod
=
mod
)
Relay virtual machine.
simplified_expr
=
ir_pass
.
simplify_inference
(
ck_expr
)
simplified_expr
=
ir_pass
.
infer_type
(
simplified_expr
,
mod
=
mod
)
Parameters
fused_expr
=
ir_pass
.
fuse_ops
(
simplified_expr
,
mod
=
mod
)
----------
ck_fused
=
ir_pass
.
infer_type
(
fused_expr
,
mod
=
mod
)
mod : tvm.relay.Module
return
ck_fused
The module to optimize.
Returns
-------
ret : tvm.relay.Module
The optimized module.
"""
main_func
=
mod
[
mod
.
entry_func
]
opt_passes
=
[]
if
not
main_func
.
params
and
isinstance
(
main_func
.
body
,
GlobalVar
):
opt_passes
.
append
(
transform
.
EtaExpand
())
opt_passes
=
opt_passes
+
[
transform
.
SimplifyInference
(),
transform
.
FuseOps
(),
transform
.
InferType
()
]
seq
=
transform
.
Sequential
(
opt_passes
)
return
seq
(
mod
)
def
_convert
(
arg
,
cargs
):
def
_convert
(
arg
,
cargs
):
if
isinstance
(
arg
,
np
.
ndarray
):
if
isinstance
(
arg
,
np
.
ndarray
):
...
@@ -76,15 +97,8 @@ def _eval_vm(mod, ctx, *args):
...
@@ -76,15 +97,8 @@ def _eval_vm(mod, ctx, *args):
args: List[tvm.NDArray, np.ndarray]
args: List[tvm.NDArray, np.ndarray]
The arguments to evaluate.
The arguments to evaluate.
"""
"""
main_func
=
mod
[
mod
.
entry_func
]
if
not
main_func
.
params
and
isinstance
(
main_func
.
body
,
GlobalVar
):
main_func
=
ir_pass
.
eta_expand
(
main_func
.
body
,
mod
)
assert
isinstance
(
main_func
,
Function
)
main_func
=
optimize
(
mod
[
mod
.
entry_func
],
mod
)
mod
[
mod
.
entry_func
]
=
main_func
mod
=
optimize
(
mod
)
args
=
list
(
args
)
args
=
list
(
args
)
assert
isinstance
(
args
,
list
)
assert
isinstance
(
args
,
list
)
cargs
=
convert
(
args
)
cargs
=
convert
(
args
)
...
...
src/relay/backend/vm/compiler.cc
View file @
70041c48
...
@@ -27,7 +27,7 @@
...
@@ -27,7 +27,7 @@
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/interpreter.h>
#include <tvm/relay/interpreter.h>
#include <tvm/logging.h>
#include <tvm/logging.h>
#include <tvm/relay/
pass
.h>
#include <tvm/relay/
transform
.h>
#include <tvm/runtime/vm.h>
#include <tvm/runtime/vm.h>
#include <iostream>
#include <iostream>
#include <unordered_map>
#include <unordered_map>
...
@@ -38,15 +38,22 @@
...
@@ -38,15 +38,22 @@
namespace
tvm
{
namespace
tvm
{
namespace
relay
{
namespace
relay
{
namespace
transform
{
Pass
LambdaLift
();
Pass
InlinePrimitives
();
}
// namespace transform
namespace
vm
{
namespace
vm
{
using
namespace
tvm
::
runtime
;
using
namespace
tvm
::
runtime
;
using
namespace
tvm
::
runtime
::
vm
;
using
namespace
tvm
::
runtime
::
vm
;
using
namespace
relay
::
transform
;
// (@jroesch): VM passes, eventually declare as passes.
// (@jroesch): VM passes, eventually declare as passes.
bool
IsClosure
(
const
Function
&
func
);
bool
IsClosure
(
const
Function
&
func
);
Module
LambdaLift
(
const
Module
&
module
);
Module
InlinePrimitives
(
const
Module
&
module
);
template
<
typename
T
,
typename
U
>
template
<
typename
T
,
typename
U
>
using
NodeMap
=
std
::
unordered_map
<
T
,
U
,
NodeHash
,
NodeEqual
>
;
using
NodeMap
=
std
::
unordered_map
<
T
,
U
,
NodeHash
,
NodeEqual
>
;
...
@@ -560,10 +567,13 @@ VMFunction CompileFunc(VMCompilerContext* context, const GlobalVar& var, const F
...
@@ -560,10 +567,13 @@ VMFunction CompileFunc(VMCompilerContext* context, const GlobalVar& var, const F
}
}
Module
OptimizeModule
(
const
Module
&
mod
)
{
Module
OptimizeModule
(
const
Module
&
mod
)
{
ToANormalForm
(
mod
->
entry_func
,
mod
);
transform
::
Sequential
seq
({
transform
::
ToANormalForm
(),
InlinePrimitives
(
mod
);
transform
::
InlinePrimitives
(),
LambdaLift
(
mod
);
transform
::
LambdaLift
(),
return
InlinePrimitives
(
mod
);
transform
::
InlinePrimitives
()});
auto
pass_ctx
=
transform
::
PassContext
::
Create
();
tvm
::
With
<
relay
::
transform
::
PassContext
>
ctx
(
pass_ctx
);
return
seq
(
mod
);
}
}
void
PopulateGlobalMap
(
GlobalMap
*
global_map
,
const
Module
&
mod
)
{
void
PopulateGlobalMap
(
GlobalMap
*
global_map
,
const
Module
&
mod
)
{
...
...
src/relay/backend/vm/inline_primitives.cc
View file @
70041c48
...
@@ -26,7 +26,7 @@
...
@@ -26,7 +26,7 @@
#include <tvm/relay/expr.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/logging.h>
#include <tvm/logging.h>
#include <tvm/relay/
pass
.h>
#include <tvm/relay/
transform
.h>
#include <tvm/runtime/vm.h>
#include <tvm/runtime/vm.h>
#include <iostream>
#include <iostream>
#include <vector>
#include <vector>
...
@@ -37,6 +37,21 @@ namespace tvm {
...
@@ -37,6 +37,21 @@ namespace tvm {
namespace
relay
{
namespace
relay
{
namespace
vm
{
namespace
vm
{
// TODO(@jroesch): write verifier
/* This pass will eliminate primitives which have been lifted by the ANF
* transform inlining them directly into call sites.
*
* This makes VM related code generation easier as the call target is always
* a primitive function.
*
* let prim = fn(...) { ... };
* prim(...)
*
* will become:
*
* (fn(...) { ... })(...)
*/
struct
PrimitiveInliner
:
ExprMutator
{
struct
PrimitiveInliner
:
ExprMutator
{
Module
module_
;
Module
module_
;
std
::
unordered_map
<
Var
,
Expr
,
NodeHash
,
NodeEqual
>
var_map
;
std
::
unordered_map
<
Var
,
Expr
,
NodeHash
,
NodeEqual
>
var_map
;
...
@@ -92,55 +107,46 @@ struct PrimitiveInliner : ExprMutator {
...
@@ -92,55 +107,46 @@ struct PrimitiveInliner : ExprMutator {
}
}
}
}
Function
Inline
(
const
Function
&
func
)
{
Module
Inline
()
{
DLOG
(
INFO
)
<<
"Before inlining primitives: "
<<
std
::
endl
auto
gvar_funcs
=
module_
->
functions
;
<<
"func= "
<<
AsText
(
func
,
false
)
<<
std
::
endl
;
for
(
auto
pair
:
gvar_funcs
)
{
auto
global
=
pair
.
first
;
auto
inlined
=
FunctionNode
::
make
(
func
->
params
,
VisitExpr
(
func
->
body
),
func
->
ret_type
,
auto
func
=
pair
.
second
;
func
->
type_params
,
func
->
attrs
);
DLOG
(
INFO
)
<<
"Before inlining primitives: "
<<
global
<<
std
::
endl
<<
AsText
(
func
,
false
);
inlined
=
Downcast
<
Function
>
(
DeadCodeElimination
(
inlined
));
func
=
FunctionNode
::
make
(
func
->
params
,
DLOG
(
INFO
)
<<
"After inlining primitives"
<<
std
::
endl
VisitExpr
(
func
->
body
),
<<
"after_func= "
<<
AsText
(
inlined
,
false
)
<<
std
::
endl
;
func
->
ret_type
,
return
inlined
;
func
->
type_params
,
func
->
attrs
);
module_
->
Add
(
global
,
func
,
true
);
DLOG
(
INFO
)
<<
"After inlining primitives: "
<<
global
<<
std
::
endl
<<
AsText
(
func
,
false
);
}
return
module_
;
}
}
};
};
// TODO(@jroesch): write verifier
}
// namespace vm
/* This pass will eliminate primitives which have been lifted by the ANF
* transform inlining them directly into call sites.
*
* This makes VM related code generation easier as the call target is always
* a primitive function.
*
* let prim = fn(...) { ... };
* prim(...)
*
* will become:
*
* (fn(...) { ... })(...)
*/
Module
InlinePrimitives
(
const
Module
&
module
)
{
PrimitiveInliner
inliner
(
module
);
tvm
::
Map
<
GlobalVar
,
Function
>
updates
;
namespace
transform
{
// There is an ordering bug here.
Pass
InlinePrimitives
()
{
for
(
auto
pair
:
module
->
functions
)
{
runtime
::
TypedPackedFunc
<
Module
(
Module
,
PassContext
)
>
pass_func
=
auto
global
=
pair
.
first
;
[
=
](
Module
m
,
PassContext
pc
)
{
auto
func
=
pair
.
second
;
return
relay
::
vm
::
PrimitiveInliner
(
m
).
Inline
();
updates
.
Set
(
global
,
inliner
.
Inline
(
func
));
};
}
auto
inline_pass
=
CreateModulePass
(
pass_func
,
1
,
"Inline"
,
{});
// Eliminate dead code for each function after inlining.
return
Sequential
({
inline_pass
,
DeadCodeElimination
()},
"InlinePrimitives"
);
}
for
(
auto
pair
:
updates
)
{
TVM_REGISTER_API
(
"relay._transform.InlinePrimitives"
)
module
->
Add
(
pair
.
first
,
pair
.
second
,
true
);
.
set_body_typed
(
InlinePrimitives
);
}
return
module
;
}
// namespace transform
}
}
// namespace vm
}
// namespace relay
}
// namespace relay
}
// namespace tvm
}
// namespace tvm
src/relay/backend/vm/lambda_lift.cc
View file @
70041c48
...
@@ -27,6 +27,7 @@
...
@@ -27,6 +27,7 @@
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/logging.h>
#include <tvm/logging.h>
#include <tvm/relay/pass.h>
#include <tvm/relay/pass.h>
#include <tvm/relay/transform.h>
#include <tvm/runtime/vm.h>
#include <tvm/runtime/vm.h>
#include <iostream>
#include <iostream>
#include <vector>
#include <vector>
...
@@ -54,9 +55,14 @@ Function MarkClosure(const Function& func) {
...
@@ -54,9 +55,14 @@ Function MarkClosure(const Function& func) {
return
FunctionSetAttr
(
func
,
kIsClosure
,
tvm
::
Integer
(
1
));
return
FunctionSetAttr
(
func
,
kIsClosure
,
tvm
::
Integer
(
1
));
}
}
/* The goal of this class is to lift out any nested functions into top-level
* functions.
*
* We will lift a function out into a global which takes the set of the free
* vars and then return the new created function.
*/
struct
LambdaLifter
:
ExprMutator
{
struct
LambdaLifter
:
ExprMutator
{
Module
module_
;
Module
module_
;
std
::
vector
<
std
::
pair
<
GlobalVar
,
Function
>>
lifted_
;
explicit
LambdaLifter
(
const
Module
&
module
)
:
module_
(
module
)
{}
explicit
LambdaLifter
(
const
Module
&
module
)
:
module_
(
module
)
{}
Expr
VisitExpr_
(
const
FunctionNode
*
func_node
)
final
{
Expr
VisitExpr_
(
const
FunctionNode
*
func_node
)
final
{
...
@@ -71,8 +77,7 @@ struct LambdaLifter : ExprMutator {
...
@@ -71,8 +77,7 @@ struct LambdaLifter : ExprMutator {
auto
free_type_vars
=
FreeTypeVars
(
func
,
module_
);
auto
free_type_vars
=
FreeTypeVars
(
func
,
module_
);
auto
body
=
Downcast
<
Function
>
(
ExprMutator
::
VisitExpr_
(
func_node
));
auto
body
=
Downcast
<
Function
>
(
ExprMutator
::
VisitExpr_
(
func_node
));
// When performing this optimization there are two
// When performing this optimization there are two cases.
// cases.
//
//
// The first case in which we have no free variables
// The first case in which we have no free variables
// we can just lift the function into the global
// we can just lift the function into the global
...
@@ -80,7 +85,7 @@ struct LambdaLifter : ExprMutator {
...
@@ -80,7 +85,7 @@ struct LambdaLifter : ExprMutator {
//
//
//
//
// The second case requires that we generate a special
// The second case requires that we generate a special
// function w
it
h makes a distinction between allocating
// function w
hic
h makes a distinction between allocating
// a closure, and then the code for the closure.
// a closure, and then the code for the closure.
//
//
// We represent a closure allocation by lifting the
// We represent a closure allocation by lifting the
...
@@ -92,7 +97,7 @@ struct LambdaLifter : ExprMutator {
...
@@ -92,7 +97,7 @@ struct LambdaLifter : ExprMutator {
// function marked as a closure is used to emit allocation
// function marked as a closure is used to emit allocation
// code for the closure's environment.
// code for the closure's environment.
//
//
// The "inner" function
is
should be used to generate the
// The "inner" function should be used to generate the
// code for the closure.
// code for the closure.
Function
lifted_func
;
Function
lifted_func
;
if
(
free_vars
.
size
()
==
0
)
{
if
(
free_vars
.
size
()
==
0
)
{
...
@@ -107,16 +112,16 @@ struct LambdaLifter : ExprMutator {
...
@@ -107,16 +112,16 @@ struct LambdaLifter : ExprMutator {
CHECK
(
lifted_func
.
defined
());
CHECK
(
lifted_func
.
defined
());
auto
name
=
GenerateName
(
lifted_func
);
auto
name
=
GenerateName
(
lifted_func
);
auto
global
=
this
->
module_
->
GetGlobalVar
(
name
);
auto
global
=
module_
->
GetGlobalVar
(
name
);
lifted_
.
push_back
({
global
,
lifted_func
});
// Add the lifted function to the module.
module_
->
Add
(
global
,
lifted_func
);
if
(
free_vars
.
size
()
==
0
)
{
if
(
free_vars
.
size
()
==
0
)
{
return
std
::
move
(
global
);
return
std
::
move
(
global
);
}
else
{
}
else
{
// If we need to allocate a closure
// If we need to allocate a closure,
// we pass the variables in its environment
// we pass the variables in its environment here.
// here.
Array
<
Expr
>
fvs
;
Array
<
Expr
>
fvs
;
for
(
auto
fv
:
free_vars
)
{
for
(
auto
fv
:
free_vars
)
{
fvs
.
push_back
(
fv
);
fvs
.
push_back
(
fv
);
...
@@ -125,42 +130,39 @@ struct LambdaLifter : ExprMutator {
...
@@ -125,42 +130,39 @@ struct LambdaLifter : ExprMutator {
}
}
}
}
Function
Lift
(
const
Function
&
func
)
{
Module
Lift
()
{
DLOG
(
INFO
)
<<
"Lifting: "
<<
AsText
(
func
,
false
)
<<
std
::
endl
;
return
FunctionNode
::
make
(
func
->
params
,
VisitExpr
(
func
->
body
),
func
->
ret_type
,
func
->
type_params
,
func
->
attrs
);
}
};
/* The goal of this pass is to lift out any nested functions into top-level
* functions.
*
* We will lift the functions out into globals which take the set of the free vars
* and then return a function whcih has b
*/
Module
LambdaLift
(
const
Module
&
module
)
{
LambdaLifter
lifter
(
module
);
tvm
::
Map
<
GlobalVar
,
Function
>
updates
;
// There is an ordering bug here.
// There is an ordering bug here.
for
(
auto
pair
:
module
->
functions
)
{
auto
glob_funcs
=
module_
->
functions
;
auto
global
=
pair
.
first
;
for
(
auto
pair
:
glob_funcs
)
{
auto
func
=
pair
.
second
;
auto
func
=
pair
.
second
;
updates
.
Set
(
global
,
lifter
.
Lift
(
func
));
DLOG
(
INFO
)
<<
"Lifting "
<<
AsText
(
func
,
false
);
func
=
FunctionNode
::
make
(
func
->
params
,
VisitExpr
(
func
->
body
),
func
->
ret_type
,
func
->
type_params
,
func
->
attrs
);
module_
->
Add
(
pair
.
first
,
func
,
true
);
}
}
return
module_
;
for
(
auto
i
=
lifter
.
lifted_
.
begin
();
i
!=
lifter
.
lifted_
.
end
();
i
++
)
{
module
->
Add
(
i
->
first
,
i
->
second
);
}
}
};
for
(
auto
pair
:
updates
)
{
}
// namespace vm
module
->
Add
(
pair
.
first
,
pair
.
second
,
true
);
}
return
module
;
namespace
transform
{
Pass
LambdaLift
()
{
runtime
::
TypedPackedFunc
<
Module
(
Module
,
PassContext
)
>
pass_func
=
[
=
](
Module
m
,
PassContext
pc
)
{
return
relay
::
vm
::
LambdaLifter
(
m
).
Lift
();
};
return
CreateModulePass
(
pass_func
,
1
,
"LambdaLift"
,
{});
}
}
}
// namespace vm
TVM_REGISTER_API
(
"relay._transform.LambdaLift"
)
.
set_body_typed
(
LambdaLift
);
}
// namespace transform
}
// namespace relay
}
// namespace relay
}
// namespace tvm
}
// namespace tvm
src/relay/pass/pass_manager.cc
View file @
70041c48
...
@@ -309,20 +309,24 @@ Module FunctionPassNode::operator()(const Module& mod,
...
@@ -309,20 +309,24 @@ Module FunctionPassNode::operator()(const Module& mod,
const
PassContext
&
pass_ctx
)
const
{
const
PassContext
&
pass_ctx
)
const
{
const
PassInfo
&
pass_info
=
Info
();
const
PassInfo
&
pass_info
=
Info
();
CHECK
(
mod
.
defined
());
CHECK
(
mod
.
defined
());
DLOG
(
INFO
)
<<
"Executing
module
pass : "
DLOG
(
INFO
)
<<
"Executing
function
pass : "
<<
pass_info
->
name
<<
pass_info
->
name
<<
" with opt level: "
<<
" with opt level: "
<<
pass_info
->
opt_level
;
<<
pass_info
->
opt_level
;
Module
updated_mod
=
mod
;
Module
updated_mod
=
mod
;
Module
new_mod
=
ModuleNode
::
make
({},
mod
->
type_definitions
);
// Execute the pass function and return a new module.
// Execute the pass function and return a new module.
std
::
vector
<
std
::
pair
<
GlobalVar
,
Function
>
>
updates
;
for
(
const
auto
&
it
:
mod
->
functions
)
{
for
(
const
auto
&
it
:
mod
->
functions
)
{
auto
updated_func
=
SkipFunction
(
it
.
second
)
auto
updated_func
=
SkipFunction
(
it
.
second
)
?
it
.
second
?
it
.
second
:
pass_func
(
it
.
second
,
updated_mod
,
pass_ctx
);
:
pass_func
(
it
.
second
,
updated_mod
,
pass_ctx
);
new_mod
->
Add
(
it
.
first
,
updated_func
);
updates
.
push_back
({
it
.
first
,
updated_func
});
}
for
(
const
auto
&
pair
:
updates
)
{
updated_mod
->
Add
(
pair
.
first
,
pair
.
second
,
true
);
}
}
return
new
_mod
;
return
updated
_mod
;
}
}
// TODO(zhiics) Create an enum attribute for FunctionNode
// TODO(zhiics) Create an enum attribute for FunctionNode
...
@@ -539,7 +543,8 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
...
@@ -539,7 +543,8 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
tvm
::
IRPrinter
*
p
)
{
tvm
::
IRPrinter
*
p
)
{
p
->
stream
<<
"Pass context information: "
<<
"
\n
"
;
p
->
stream
<<
"Pass context information: "
<<
"
\n
"
;
p
->
stream
<<
"
\t
opt_level: "
<<
node
->
opt_level
<<
"
\n
"
;
p
->
stream
<<
"
\t
opt_level: "
<<
node
->
opt_level
<<
"
\n
"
;
p
->
stream
<<
"
\t
fallback device: "
<<
runtime
::
DeviceName
(
node
->
opt_level
)
p
->
stream
<<
"
\t
fallback device: "
<<
runtime
::
DeviceName
(
node
->
fallback_device
)
<<
"
\n
"
;
<<
"
\n
"
;
p
->
stream
<<
"
\t
required passes: ["
<<
node
->
opt_level
;
p
->
stream
<<
"
\t
required passes: ["
<<
node
->
opt_level
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment