Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
f63b249d
Unverified
Commit
f63b249d
authored
Mar 05, 2020
by
Zhi
Committed by
GitHub
Mar 05, 2020
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
refactor build module to take IRModule (#4988)
parent
fe74b37a
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
63 additions
and
50 deletions
+63
-50
include/tvm/relay/transform.h
+10
-0
python/tvm/relay/build_module.py
+28
-30
src/relay/backend/build_module.cc
+22
-18
src/relay/backend/vm/compiler.cc
+0
-1
tests/cpp/relay_build_module_test.cc
+3
-1
No files found.
include/tvm/relay/transform.h
View file @
f63b249d
...
@@ -332,6 +332,16 @@ TVM_DLL Pass PartitionGraph();
...
@@ -332,6 +332,16 @@ TVM_DLL Pass PartitionGraph();
*/
*/
TVM_DLL
Pass
Inline
();
TVM_DLL
Pass
Inline
();
/*!
* \brief Remove the unused functions in the Relay IRModule.
*
* \param entry_functions The entry functions used to search the functions that
* are being used.
*
* \return The pass.
*/
TVM_DLL
Pass
RemoveUnusedFunctions
(
Array
<
tvm
::
PrimExpr
>
entry_functions
);
}
// namespace transform
}
// namespace transform
/*!
/*!
...
...
python/tvm/relay/build_module.py
View file @
f63b249d
...
@@ -62,7 +62,7 @@ def _convert_param_map(params):
...
@@ -62,7 +62,7 @@ def _convert_param_map(params):
class
BuildModule
(
object
):
class
BuildModule
(
object
):
"""Build a
Relay function
to run on TVM graph runtime. This class is used
"""Build a
n IR module
to run on TVM graph runtime. This class is used
to expose the `RelayBuildModule` APIs implemented in C++.
to expose the `RelayBuildModule` APIs implemented in C++.
"""
"""
def
__init__
(
self
):
def
__init__
(
self
):
...
@@ -74,12 +74,12 @@ class BuildModule(object):
...
@@ -74,12 +74,12 @@ class BuildModule(object):
self
.
_set_params_func
=
self
.
mod
[
"set_params"
]
self
.
_set_params_func
=
self
.
mod
[
"set_params"
]
self
.
_get_params_func
=
self
.
mod
[
"get_params"
]
self
.
_get_params_func
=
self
.
mod
[
"get_params"
]
def
build
(
self
,
func
,
target
=
None
,
target_host
=
None
,
params
=
None
):
def
build
(
self
,
mod
,
target
=
None
,
target_host
=
None
,
params
=
None
):
"""
"""
Parameters
Parameters
----------
----------
func: relay.Function
mod : :py:class:`~tvm.IRModule`
The
function
to build.
The
IRModule
to build.
target : str, :any:`tvm.target.Target`, or dict of str(i.e.
target : str, :any:`tvm.target.Target`, or dict of str(i.e.
device/context name) to str/tvm.target.Target, optional
device/context name) to str/tvm.target.Target, optional
...
@@ -115,8 +115,8 @@ class BuildModule(object):
...
@@ -115,8 +115,8 @@ class BuildModule(object):
# Setup the params.
# Setup the params.
if
params
:
if
params
:
self
.
_set_params
(
params
)
self
.
_set_params
(
params
)
# Build the
function
# Build the
IR module
self
.
_build
(
func
,
target
,
target_host
)
self
.
_build
(
mod
,
target
,
target_host
)
# Get artifacts
# Get artifacts
graph_json
=
self
.
get_json
()
graph_json
=
self
.
get_json
()
mod
=
self
.
get_module
()
mod
=
self
.
get_module
()
...
@@ -124,12 +124,12 @@ class BuildModule(object):
...
@@ -124,12 +124,12 @@ class BuildModule(object):
return
graph_json
,
mod
,
params
return
graph_json
,
mod
,
params
def
optimize
(
self
,
func
,
target
=
None
,
params
=
None
):
def
optimize
(
self
,
mod
,
target
=
None
,
params
=
None
):
"""
"""
Parameters
Parameters
----------
----------
func: relay.Function
mod : :py:class:`~tvm.IRModule`
The
function
to build.
The
IR module
to build.
target : str, :any:`tvm.target.Target`, or dict of str(i.e.
target : str, :any:`tvm.target.Target`, or dict of str(i.e.
device/context name) to str/tvm.target.Target, optional
device/context name) to str/tvm.target.Target, optional
...
@@ -142,7 +142,7 @@ class BuildModule(object):
...
@@ -142,7 +142,7 @@ class BuildModule(object):
Returns
Returns
-------
-------
mod :
tvm.IRModule
mod :
:py:class:`~tvm.IRModule`
The optimized relay module.
The optimized relay module.
params : dict
params : dict
...
@@ -153,7 +153,7 @@ class BuildModule(object):
...
@@ -153,7 +153,7 @@ class BuildModule(object):
# Setup the params.
# Setup the params.
if
params
:
if
params
:
self
.
_set_params
(
params
)
self
.
_set_params
(
params
)
mod
=
self
.
_optimize
(
func
,
target
)
mod
=
self
.
_optimize
(
mod
,
target
)
# Get artifacts
# Get artifacts
params
=
self
.
get_params
()
params
=
self
.
get_params
()
...
@@ -186,8 +186,8 @@ def build(mod, target=None, target_host=None, params=None):
...
@@ -186,8 +186,8 @@ def build(mod, target=None, target_host=None, params=None):
Parameters
Parameters
----------
----------
mod :
tvm.IRModule
mod :
:py:class:`~tvm.IRModule`
The module to build. Using relay.Function is deprecated.
The
IR
module to build. Using relay.Function is deprecated.
target : str, :any:`tvm.target.Target`, or dict of str(i.e. device/context
target : str, :any:`tvm.target.Target`, or dict of str(i.e. device/context
name) to str/tvm.target.Target, optional
name) to str/tvm.target.Target, optional
...
@@ -218,16 +218,15 @@ def build(mod, target=None, target_host=None, params=None):
...
@@ -218,16 +218,15 @@ def build(mod, target=None, target_host=None, params=None):
params : dict
params : dict
The parameters of the final graph.
The parameters of the final graph.
"""
"""
if
isinstance
(
mod
,
IRModule
):
if
not
isinstance
(
mod
,
(
IRModule
,
_expr
.
Function
)):
func
=
mod
[
"main"
]
raise
ValueError
(
"Type of input parameter mod must be tvm.IRModule"
)
elif
isinstance
(
mod
,
_expr
.
Function
):
func
=
mod
if
isinstance
(
mod
,
_expr
.
Function
):
mod
=
IRModule
.
from_expr
(
mod
)
warnings
.
warn
(
warnings
.
warn
(
"Please use input parameter mod (tvm.IRModule) "
"Please use input parameter mod (tvm.IRModule) "
"instead of deprecated parameter
func
(tvm.relay.expr.Function)"
,
"instead of deprecated parameter
mod
(tvm.relay.expr.Function)"
,
DeprecationWarning
)
DeprecationWarning
)
else
:
raise
ValueError
(
"Type of input parameter mod must be tvm.IRModule"
)
target
=
_update_target
(
target
)
target
=
_update_target
(
target
)
...
@@ -246,7 +245,7 @@ def build(mod, target=None, target_host=None, params=None):
...
@@ -246,7 +245,7 @@ def build(mod, target=None, target_host=None, params=None):
with
tophub_context
:
with
tophub_context
:
bld_mod
=
BuildModule
()
bld_mod
=
BuildModule
()
graph_json
,
mod
,
params
=
bld_mod
.
build
(
func
,
target
,
target_host
,
params
)
graph_json
,
mod
,
params
=
bld_mod
.
build
(
mod
,
target
,
target_host
,
params
)
return
graph_json
,
mod
,
params
return
graph_json
,
mod
,
params
...
@@ -255,7 +254,7 @@ def optimize(mod, target=None, params=None):
...
@@ -255,7 +254,7 @@ def optimize(mod, target=None, params=None):
Parameters
Parameters
----------
----------
mod :
tvm.IRModule
mod :
:py:class:`~tvm.IRModule`
The module to build. Using relay.Function is deprecated.
The module to build. Using relay.Function is deprecated.
target : str, :any:`tvm.target.Target`, or dict of str(i.e. device/context
target : str, :any:`tvm.target.Target`, or dict of str(i.e. device/context
...
@@ -269,22 +268,21 @@ def optimize(mod, target=None, params=None):
...
@@ -269,22 +268,21 @@ def optimize(mod, target=None, params=None):
Returns
Returns
-------
-------
mod :
tvm.IRModule
mod :
:py:class:`~tvm.IRModule`
The optimized relay module.
The optimized relay module.
params : dict
params : dict
The parameters of the final graph.
The parameters of the final graph.
"""
"""
if
isinstance
(
mod
,
IRModule
):
if
not
isinstance
(
mod
,
(
IRModule
,
_expr
.
Function
)):
func
=
mod
[
"main"
]
raise
ValueError
(
"Type of input parameter mod must be tvm.IRModule"
)
elif
isinstance
(
mod
,
_expr
.
Function
):
func
=
mod
if
isinstance
(
mod
,
_expr
.
Function
):
mod
=
IRModule
.
from_expr
(
mod
)
warnings
.
warn
(
warnings
.
warn
(
"Please use input parameter mod (tvm.IRModule) "
"Please use input parameter mod (tvm.IRModule) "
"instead of deprecated parameter func (tvm.relay.expr.Function)"
,
"instead of deprecated parameter func (tvm.relay.expr.Function)"
,
DeprecationWarning
)
DeprecationWarning
)
else
:
raise
ValueError
(
"Type of input parameter mod must be tvm.IRModule"
)
target
=
_update_target
(
target
)
target
=
_update_target
(
target
)
...
@@ -297,7 +295,7 @@ def optimize(mod, target=None, params=None):
...
@@ -297,7 +295,7 @@ def optimize(mod, target=None, params=None):
with
tophub_context
:
with
tophub_context
:
bld_mod
=
BuildModule
()
bld_mod
=
BuildModule
()
mod
,
params
=
bld_mod
.
optimize
(
func
,
target
,
params
)
mod
,
params
=
bld_mod
.
optimize
(
mod
,
target
,
params
)
return
mod
,
params
return
mod
,
params
...
...
src/relay/backend/build_module.cc
View file @
f63b249d
...
@@ -233,42 +233,46 @@ class RelayBuildModule : public runtime::ModuleNode {
...
@@ -233,42 +233,46 @@ class RelayBuildModule : public runtime::ModuleNode {
}
}
/*!
/*!
* \brief Build relay
function
for graph runtime
* \brief Build relay
IRModule
for graph runtime
*
*
* \param
func Relay Function
* \param
mod Relay IRModule
* \param target Target device
* \param target Target device
* \param target_host Host target device
* \param target_host Host target device
*/
*/
void
Build
(
Function
func
,
void
Build
(
IRModule
mod
,
const
TargetsMap
&
targets
,
const
TargetsMap
&
targets
,
const
tvm
::
Target
&
target_host
)
{
const
tvm
::
Target
&
target_host
)
{
targets_
=
targets
;
targets_
=
targets
;
target_host_
=
target_host
;
target_host_
=
target_host
;
BuildRelay
(
func
,
params_
);
BuildRelay
(
mod
,
params_
);
}
}
protected
:
protected
:
/*!
/*!
* \brief Optimize a Relay
Function
.
* \brief Optimize a Relay
IRModule
.
*
*
* \param
func The input Function
where optmization will be applied on.
* \param
relay_module The input IRModule
where optmization will be applied on.
* \param targets The device type to `Target` mapping.
* \param targets The device type to `Target` mapping.
* \param params The param name to value mapping.
* \param params The param name to value mapping.
*
*
* \return relay::
Module The updated Relay
module after optimization.
* \return relay::
IRModule The updated Relay IR
module after optimization.
*/
*/
IRModule
Optimize
(
IRModule
Optimize
(
Function
func
,
IRModule
relay_module
,
const
TargetsMap
&
targets
,
const
TargetsMap
&
targets
,
const
std
::
unordered_map
<
std
::
string
,
runtime
::
NDArray
>&
params
)
{
const
std
::
unordered_map
<
std
::
string
,
runtime
::
NDArray
>&
params
)
{
if
(
params
.
size
())
{
if
(
params
.
size
())
{
func
=
BindParamsByName
(
func
,
params
);
CHECK
(
relay_module
->
ContainGlobalVar
(
"main"
))
<<
"Missing the main entry function"
;
GlobalVar
main_glb_var
=
relay_module
->
GetGlobalVar
(
"main"
);
Function
main_func
=
Downcast
<
Function
>
(
relay_module
->
Lookup
(
main_glb_var
));
auto
new_main
=
BindParamsByName
(
main_func
,
params
);
relay_module
->
Update
(
main_glb_var
,
new_main
);
}
}
// Perform Module->Module optimizations.
IRModule
relay_module
=
IRModule
::
FromExpr
(
func
);
Array
<
Pass
>
pass_seqs
;
Array
<
Pass
>
pass_seqs
;
Array
<
tvm
::
PrimExpr
>
entry_functions
{
tvm
::
PrimExpr
{
"main"
}};
pass_seqs
.
push_back
(
transform
::
RemoveUnusedFunctions
(
entry_functions
));
// Run all dialect legalization passes.
// Run all dialect legalization passes.
pass_seqs
.
push_back
(
relay
::
qnn
::
transform
::
Legalize
());
pass_seqs
.
push_back
(
relay
::
qnn
::
transform
::
Legalize
());
...
@@ -418,18 +422,18 @@ class RelayBuildModule : public runtime::ModuleNode {
...
@@ -418,18 +422,18 @@ class RelayBuildModule : public runtime::ModuleNode {
}
}
/*!
/*!
* \brief Compile a Relay
function
to runtime module.
* \brief Compile a Relay
IR module
to runtime module.
*
*
* \param
func The Relay function
.
* \param
relay_module The Relay IR module
.
* \param params The parameters.
* \param params The parameters.
*/
*/
void
BuildRelay
(
void
BuildRelay
(
Function
func
,
IRModule
relay_module
,
const
std
::
unordered_map
<
std
::
string
,
tvm
::
runtime
::
NDArray
>&
params
)
{
const
std
::
unordered_map
<
std
::
string
,
tvm
::
runtime
::
NDArray
>&
params
)
{
//
Optimize input Relay Function and returns Relay Module
//
Relay IRModule -> IRModule optimizations.
IRModule
relay_module
=
Optimize
(
func
,
targets_
,
params
);
relay_module
=
Optimize
(
relay_module
,
targets_
,
params
);
// Get the updated function.
// Get the updated function.
func
=
Downcast
<
Function
>
(
relay_module
->
Lookup
(
"main"
));
auto
func
=
Downcast
<
Function
>
(
relay_module
->
Lookup
(
"main"
));
// Generate code for the updated function.
// Generate code for the updated function.
graph_codegen_
=
std
::
unique_ptr
<
GraphCodegen
>
(
new
GraphCodegen
());
graph_codegen_
=
std
::
unique_ptr
<
GraphCodegen
>
(
new
GraphCodegen
());
...
...
src/relay/backend/vm/compiler.cc
View file @
f63b249d
...
@@ -51,7 +51,6 @@ namespace transform {
...
@@ -51,7 +51,6 @@ namespace transform {
Pass
LambdaLift
();
Pass
LambdaLift
();
Pass
InlinePrimitives
();
Pass
InlinePrimitives
();
Pass
RemoveUnusedFunctions
(
Array
<
tvm
::
PrimExpr
>
entry_functions
);
Pass
ManifestAlloc
(
Target
target_host
)
{
Pass
ManifestAlloc
(
Target
target_host
)
{
auto
f
=
tvm
::
runtime
::
Registry
::
Get
(
"relay.transform.ManifestAlloc"
);
auto
f
=
tvm
::
runtime
::
Registry
::
Get
(
"relay.transform.ManifestAlloc"
);
...
...
tests/cpp/relay_build_module_test.cc
View file @
f63b249d
...
@@ -29,6 +29,7 @@
...
@@ -29,6 +29,7 @@
#include <topi/broadcast.h>
#include <topi/broadcast.h>
#include <topi/generic/injective.h>
#include <topi/generic/injective.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/ir/module.h>
#include <tvm/runtime/module.h>
#include <tvm/runtime/module.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/registry.h>
...
@@ -115,7 +116,8 @@ TEST(Relay, BuildModule) {
...
@@ -115,7 +116,8 @@ TEST(Relay, BuildModule) {
Map
<
tvm
::
Integer
,
tvm
::
Target
>
targets
;
Map
<
tvm
::
Integer
,
tvm
::
Target
>
targets
;
Target
llvm_tgt
=
Target
::
Create
(
"llvm"
);
Target
llvm_tgt
=
Target
::
Create
(
"llvm"
);
targets
.
Set
(
0
,
llvm_tgt
);
targets
.
Set
(
0
,
llvm_tgt
);
build_f
(
func
,
targets
,
llvm_tgt
);
auto
relay_mod
=
tvm
::
IRModule
::
FromExpr
(
func
);
build_f
(
relay_mod
,
targets
,
llvm_tgt
);
std
::
string
json
=
json_f
();
std
::
string
json
=
json_f
();
tvm
::
runtime
::
Module
mod
=
mod_f
();
tvm
::
runtime
::
Module
mod
=
mod_f
();
// run
// run
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment