Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
f63b249d
Unverified
Commit
f63b249d
authored
Mar 05, 2020
by
Zhi
Committed by
GitHub
Mar 05, 2020
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
refactor build module to take IRModule (#4988)
parent
fe74b37a
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
63 additions
and
50 deletions
+63
-50
include/tvm/relay/transform.h
+10
-0
python/tvm/relay/build_module.py
+28
-30
src/relay/backend/build_module.cc
+22
-18
src/relay/backend/vm/compiler.cc
+0
-1
tests/cpp/relay_build_module_test.cc
+3
-1
No files found.
include/tvm/relay/transform.h
View file @
f63b249d
...
...
@@ -332,6 +332,16 @@ TVM_DLL Pass PartitionGraph();
*/
TVM_DLL
Pass
Inline
();
/*!
* \brief Remove the unused functions in the Relay IRModule.
*
* \param entry_functions The entry functions used to search the functions that
* are being used.
*
* \return The pass.
*/
TVM_DLL
Pass
RemoveUnusedFunctions
(
Array
<
tvm
::
PrimExpr
>
entry_functions
);
}
// namespace transform
/*!
...
...
python/tvm/relay/build_module.py
View file @
f63b249d
...
...
@@ -62,7 +62,7 @@ def _convert_param_map(params):
class
BuildModule
(
object
):
"""Build a
Relay function
to run on TVM graph runtime. This class is used
"""Build a
n IR module
to run on TVM graph runtime. This class is used
to expose the `RelayBuildModule` APIs implemented in C++.
"""
def
__init__
(
self
):
...
...
@@ -74,12 +74,12 @@ class BuildModule(object):
self
.
_set_params_func
=
self
.
mod
[
"set_params"
]
self
.
_get_params_func
=
self
.
mod
[
"get_params"
]
def
build
(
self
,
func
,
target
=
None
,
target_host
=
None
,
params
=
None
):
def
build
(
self
,
mod
,
target
=
None
,
target_host
=
None
,
params
=
None
):
"""
Parameters
----------
func: relay.Function
The
function
to build.
mod : :py:class:`~tvm.IRModule`
The
IRModule
to build.
target : str, :any:`tvm.target.Target`, or dict of str(i.e.
device/context name) to str/tvm.target.Target, optional
...
...
@@ -115,8 +115,8 @@ class BuildModule(object):
# Setup the params.
if
params
:
self
.
_set_params
(
params
)
# Build the
function
self
.
_build
(
func
,
target
,
target_host
)
# Build the
IR module
self
.
_build
(
mod
,
target
,
target_host
)
# Get artifacts
graph_json
=
self
.
get_json
()
mod
=
self
.
get_module
()
...
...
@@ -124,12 +124,12 @@ class BuildModule(object):
return
graph_json
,
mod
,
params
def
optimize
(
self
,
func
,
target
=
None
,
params
=
None
):
def
optimize
(
self
,
mod
,
target
=
None
,
params
=
None
):
"""
Parameters
----------
func: relay.Function
The
function
to build.
mod : :py:class:`~tvm.IRModule`
The
IR module
to build.
target : str, :any:`tvm.target.Target`, or dict of str(i.e.
device/context name) to str/tvm.target.Target, optional
...
...
@@ -142,7 +142,7 @@ class BuildModule(object):
Returns
-------
mod :
tvm.IRModule
mod :
:py:class:`~tvm.IRModule`
The optimized relay module.
params : dict
...
...
@@ -153,7 +153,7 @@ class BuildModule(object):
# Setup the params.
if
params
:
self
.
_set_params
(
params
)
mod
=
self
.
_optimize
(
func
,
target
)
mod
=
self
.
_optimize
(
mod
,
target
)
# Get artifacts
params
=
self
.
get_params
()
...
...
@@ -186,8 +186,8 @@ def build(mod, target=None, target_host=None, params=None):
Parameters
----------
mod :
tvm.IRModule
The module to build. Using relay.Function is deprecated.
mod :
:py:class:`~tvm.IRModule`
The
IR
module to build. Using relay.Function is deprecated.
target : str, :any:`tvm.target.Target`, or dict of str(i.e. device/context
name) to str/tvm.target.Target, optional
...
...
@@ -218,16 +218,15 @@ def build(mod, target=None, target_host=None, params=None):
params : dict
The parameters of the final graph.
"""
if
isinstance
(
mod
,
IRModule
):
func
=
mod
[
"main"
]
elif
isinstance
(
mod
,
_expr
.
Function
):
func
=
mod
if
not
isinstance
(
mod
,
(
IRModule
,
_expr
.
Function
)):
raise
ValueError
(
"Type of input parameter mod must be tvm.IRModule"
)
if
isinstance
(
mod
,
_expr
.
Function
):
mod
=
IRModule
.
from_expr
(
mod
)
warnings
.
warn
(
"Please use input parameter mod (tvm.IRModule) "
"instead of deprecated parameter
func
(tvm.relay.expr.Function)"
,
"instead of deprecated parameter
mod
(tvm.relay.expr.Function)"
,
DeprecationWarning
)
else
:
raise
ValueError
(
"Type of input parameter mod must be tvm.IRModule"
)
target
=
_update_target
(
target
)
...
...
@@ -246,7 +245,7 @@ def build(mod, target=None, target_host=None, params=None):
with
tophub_context
:
bld_mod
=
BuildModule
()
graph_json
,
mod
,
params
=
bld_mod
.
build
(
func
,
target
,
target_host
,
params
)
graph_json
,
mod
,
params
=
bld_mod
.
build
(
mod
,
target
,
target_host
,
params
)
return
graph_json
,
mod
,
params
...
...
@@ -255,7 +254,7 @@ def optimize(mod, target=None, params=None):
Parameters
----------
mod :
tvm.IRModule
mod :
:py:class:`~tvm.IRModule`
The module to build. Using relay.Function is deprecated.
target : str, :any:`tvm.target.Target`, or dict of str(i.e. device/context
...
...
@@ -269,22 +268,21 @@ def optimize(mod, target=None, params=None):
Returns
-------
mod :
tvm.IRModule
mod :
:py:class:`~tvm.IRModule`
The optimized relay module.
params : dict
The parameters of the final graph.
"""
if
isinstance
(
mod
,
IRModule
):
func
=
mod
[
"main"
]
elif
isinstance
(
mod
,
_expr
.
Function
):
func
=
mod
if
not
isinstance
(
mod
,
(
IRModule
,
_expr
.
Function
)):
raise
ValueError
(
"Type of input parameter mod must be tvm.IRModule"
)
if
isinstance
(
mod
,
_expr
.
Function
):
mod
=
IRModule
.
from_expr
(
mod
)
warnings
.
warn
(
"Please use input parameter mod (tvm.IRModule) "
"instead of deprecated parameter func (tvm.relay.expr.Function)"
,
DeprecationWarning
)
else
:
raise
ValueError
(
"Type of input parameter mod must be tvm.IRModule"
)
target
=
_update_target
(
target
)
...
...
@@ -297,7 +295,7 @@ def optimize(mod, target=None, params=None):
with
tophub_context
:
bld_mod
=
BuildModule
()
mod
,
params
=
bld_mod
.
optimize
(
func
,
target
,
params
)
mod
,
params
=
bld_mod
.
optimize
(
mod
,
target
,
params
)
return
mod
,
params
...
...
src/relay/backend/build_module.cc
View file @
f63b249d
...
...
@@ -233,42 +233,46 @@ class RelayBuildModule : public runtime::ModuleNode {
}
/*!
* \brief Build relay
function
for graph runtime
* \brief Build relay
IRModule
for graph runtime
*
* \param
func Relay Function
* \param
mod Relay IRModule
* \param target Target device
* \param target_host Host target device
*/
void
Build
(
Function
func
,
void
Build
(
IRModule
mod
,
const
TargetsMap
&
targets
,
const
tvm
::
Target
&
target_host
)
{
targets_
=
targets
;
target_host_
=
target_host
;
BuildRelay
(
func
,
params_
);
BuildRelay
(
mod
,
params_
);
}
protected
:
/*!
* \brief Optimize a Relay
Function
.
* \brief Optimize a Relay
IRModule
.
*
* \param
func The input Function
where optmization will be applied on.
* \param
relay_module The input IRModule
where optmization will be applied on.
* \param targets The device type to `Target` mapping.
* \param params The param name to value mapping.
*
* \return relay::
Module The updated Relay
module after optimization.
* \return relay::
IRModule The updated Relay IR
module after optimization.
*/
IRModule
Optimize
(
Function
func
,
IRModule
relay_module
,
const
TargetsMap
&
targets
,
const
std
::
unordered_map
<
std
::
string
,
runtime
::
NDArray
>&
params
)
{
if
(
params
.
size
())
{
func
=
BindParamsByName
(
func
,
params
);
CHECK
(
relay_module
->
ContainGlobalVar
(
"main"
))
<<
"Missing the main entry function"
;
GlobalVar
main_glb_var
=
relay_module
->
GetGlobalVar
(
"main"
);
Function
main_func
=
Downcast
<
Function
>
(
relay_module
->
Lookup
(
main_glb_var
));
auto
new_main
=
BindParamsByName
(
main_func
,
params
);
relay_module
->
Update
(
main_glb_var
,
new_main
);
}
// Perform Module->Module optimizations.
IRModule
relay_module
=
IRModule
::
FromExpr
(
func
);
Array
<
Pass
>
pass_seqs
;
Array
<
tvm
::
PrimExpr
>
entry_functions
{
tvm
::
PrimExpr
{
"main"
}};
pass_seqs
.
push_back
(
transform
::
RemoveUnusedFunctions
(
entry_functions
));
// Run all dialect legalization passes.
pass_seqs
.
push_back
(
relay
::
qnn
::
transform
::
Legalize
());
...
...
@@ -418,18 +422,18 @@ class RelayBuildModule : public runtime::ModuleNode {
}
/*!
* \brief Compile a Relay
function
to runtime module.
* \brief Compile a Relay
IR module
to runtime module.
*
* \param
func The Relay function
.
* \param
relay_module The Relay IR module
.
* \param params The parameters.
*/
void
BuildRelay
(
Function
func
,
IRModule
relay_module
,
const
std
::
unordered_map
<
std
::
string
,
tvm
::
runtime
::
NDArray
>&
params
)
{
//
Optimize input Relay Function and returns Relay Module
IRModule
relay_module
=
Optimize
(
func
,
targets_
,
params
);
//
Relay IRModule -> IRModule optimizations.
relay_module
=
Optimize
(
relay_module
,
targets_
,
params
);
// Get the updated function.
func
=
Downcast
<
Function
>
(
relay_module
->
Lookup
(
"main"
));
auto
func
=
Downcast
<
Function
>
(
relay_module
->
Lookup
(
"main"
));
// Generate code for the updated function.
graph_codegen_
=
std
::
unique_ptr
<
GraphCodegen
>
(
new
GraphCodegen
());
...
...
src/relay/backend/vm/compiler.cc
View file @
f63b249d
...
...
@@ -51,7 +51,6 @@ namespace transform {
Pass
LambdaLift
();
Pass
InlinePrimitives
();
Pass
RemoveUnusedFunctions
(
Array
<
tvm
::
PrimExpr
>
entry_functions
);
Pass
ManifestAlloc
(
Target
target_host
)
{
auto
f
=
tvm
::
runtime
::
Registry
::
Get
(
"relay.transform.ManifestAlloc"
);
...
...
tests/cpp/relay_build_module_test.cc
View file @
f63b249d
...
...
@@ -29,6 +29,7 @@
#include <topi/broadcast.h>
#include <topi/generic/injective.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/ir/module.h>
#include <tvm/runtime/module.h>
#include <tvm/runtime/registry.h>
...
...
@@ -115,7 +116,8 @@ TEST(Relay, BuildModule) {
Map
<
tvm
::
Integer
,
tvm
::
Target
>
targets
;
Target
llvm_tgt
=
Target
::
Create
(
"llvm"
);
targets
.
Set
(
0
,
llvm_tgt
);
build_f
(
func
,
targets
,
llvm_tgt
);
auto
relay_mod
=
tvm
::
IRModule
::
FromExpr
(
func
);
build_f
(
relay_mod
,
targets
,
llvm_tgt
);
std
::
string
json
=
json_f
();
tvm
::
runtime
::
Module
mod
=
mod_f
();
// run
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment