Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
7c3ec7df
Commit
7c3ec7df
authored
Sep 21, 2018
by
Zhi
Committed by
Tianqi Chen
Sep 21, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Heterogeneous Runtime (#1695)
parent
7beafddd
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
70 additions
and
32 deletions
+70
-32
python/tvm/build_module.py
+26
-6
python/tvm/contrib/graph_runtime.py
+44
-26
src/runtime/graph/graph_runtime.cc
+0
-0
tests/python/unittest/test_runtime_heterogeneous.py
+0
-0
No files found.
python/tvm/build_module.py
View file @
7c3ec7df
...
@@ -384,8 +384,14 @@ def build(sch,
...
@@ -384,8 +384,14 @@ def build(sch,
target
=
None
,
target
=
None
,
target_host
=
None
,
target_host
=
None
,
name
=
"default_function"
,
name
=
"default_function"
,
binds
=
None
):
binds
=
None
,
"""Build a function with arguments as signiture.
postpone_host_codegen
=
False
):
"""Build a function with arguments as signature. Code will be generated
for a device specified by the target. For homogeneous execution, a module
that contains both host and device code is returned. For heterogeneous
execution, a list of lowered functions for the host and a module containing
device code are returned, but actual code generation for the host module is
postponed after code generation is finished for all devices.
Parameters
Parameters
----------
----------
...
@@ -414,10 +420,18 @@ def build(sch,
...
@@ -414,10 +420,18 @@ def build(sch,
Dictionary that maps the binding of symbolic buffer to Tensor.
Dictionary that maps the binding of symbolic buffer to Tensor.
By default, a new buffer is created for each tensor in the argument.
By default, a new buffer is created for each tensor in the argument.
postpone_host_codegen : bool, optional
A bool value that indicates if code generation for the host module
should be postponed. This variable is set to be true for heterogeneous
execution. Otherwise, it is defaulted to false.
Returns
Returns
-------
-------
f : Function, or pair of functions
ret : tvm.module, or (list of LoweredFunc, tvm.module) tuple
The result function.
A module that combines both host and device code is returned when
postpone_host_codegen is not set. Otherwise, a list of lowered
functions for the host and a module contains only device code are
returned.
Note
Note
----
----
...
@@ -498,9 +512,15 @@ def build(sch,
...
@@ -498,9 +512,15 @@ def build(sch,
fdevice
=
[
ir_pass
.
LowerIntrin
(
x
,
target_device
.
target_name
)
for
x
in
fdevice
]
fdevice
=
[
ir_pass
.
LowerIntrin
(
x
,
target_device
.
target_name
)
for
x
in
fdevice
]
fhost
=
[
ir_pass
.
LowerIntrin
(
x
,
target_host
.
target_name
)
for
x
in
fhost
]
fhost
=
[
ir_pass
.
LowerIntrin
(
x
,
target_host
.
target_name
)
for
x
in
fhost
]
fhost
=
[
ir_pass
.
CombineContextCall
(
x
)
for
x
in
fhost
]
fhost
=
[
ir_pass
.
CombineContextCall
(
x
)
for
x
in
fhost
]
mhost
=
codegen
.
build_module
(
fhost
,
str
(
target_host
))
# Append fhost to the device module and return the updated module. All
# device modules will be imported to the host module after all of them are
# collected.
mdev
=
codegen
.
build_module
(
fdevice
,
str
(
target_device
))
if
fdevice
else
None
if
postpone_host_codegen
:
return
fhost
,
mdev
mhost
=
codegen
.
build_module
(
fhost
,
str
(
target_host
))
if
fdevice
:
if
fdevice
:
mdev
=
codegen
.
build_module
(
fdevice
,
str
(
target_device
))
mhost
.
import_module
(
mdev
)
mhost
.
import_module
(
mdev
)
return
mhost
return
mhost
python/tvm/contrib/graph_runtime.py
View file @
7c3ec7df
...
@@ -3,26 +3,24 @@ import numpy as np
...
@@ -3,26 +3,24 @@ import numpy as np
from
.._ffi.base
import
string_types
from
.._ffi.base
import
string_types
from
.._ffi.function
import
get_global_func
from
.._ffi.function
import
get_global_func
from
.._ffi.runtime_ctypes
import
TVMContext
from
..rpc
import
base
as
rpc_base
from
..rpc
import
base
as
rpc_base
from
..
import
ndarray
as
nd
def
create
(
graph_json_str
,
libmod
,
ctx
):
def
create
(
graph_json_str
,
libmod
,
ctx
):
"""Create a runtime executor module given a graph and module.
"""Create a runtime executor module given a graph and module.
Parameters
Parameters
----------
----------
graph_json_str : str or graph class
graph_json_str : str or graph class
The graph to be deployed in json format output by nnvm graph.
The graph to be deployed in json format output by nnvm graph.
The graph can only contain one operator(tvm_op) that
The graph can only contain one operator(tvm_op) that
points to the name of PackedFunc in the libmod.
points to the name of PackedFunc in the libmod.
libmod : tvm.Module
libmod : tvm.Module
The module of the corresponding function
The module of the corresponding function
ctx : TVMContext or list of TVMContext
ctx : TVMContext
The context to deploy the module. It can be local or remote when there
The context to deploy the module, can be local or remote.
is only one TVMContext. Otherwise, the first context in the list will
be used as this purpose. All context should be given for heterogeneous
execution.
Returns
Returns
-------
-------
graph_module : GraphModule
graph_module : GraphModule
...
@@ -33,17 +31,42 @@ def create(graph_json_str, libmod, ctx):
...
@@ -33,17 +31,42 @@ def create(graph_json_str, libmod, ctx):
graph_json_str
=
graph_json_str
.
_tvm_graph_json
()
graph_json_str
=
graph_json_str
.
_tvm_graph_json
()
except
AttributeError
:
except
AttributeError
:
raise
ValueError
(
"Type
%
s is not supported"
%
type
(
graph_json_str
))
raise
ValueError
(
"Type
%
s is not supported"
%
type
(
graph_json_str
))
device_type
=
ctx
.
device_type
if
isinstance
(
ctx
,
TVMContext
):
device_id
=
ctx
.
device_id
ctx
=
[
ctx
]
elif
not
isinstance
(
ctx
,
(
list
,
tuple
)):
raise
ValueError
(
"ctx has to be the type of TVMContext or a list of "
"TVMCTVMContext"
)
for
cur_ctx
in
ctx
:
if
not
isinstance
(
cur_ctx
,
TVMContext
):
raise
ValueError
(
"ctx has to be the type of TVMContext or a list "
"of TVMContext"
)
# device_type_id[0], device_type_id[1] are used as the primary/fallback
# context type and id. All other ones are used as device context for
# heterogeneous execution.
num_rpc_ctx
=
0
device_type_id
=
[]
for
cur_ctx
in
ctx
:
device_type
=
cur_ctx
.
device_type
if
device_type
>=
rpc_base
.
RPC_SESS_MASK
:
if
device_type
>=
rpc_base
.
RPC_SESS_MASK
:
assert
libmod
.
type_key
==
"rpc"
assert
libmod
.
type_key
==
"rpc"
assert
rpc_base
.
_SessTableIndex
(
libmod
)
==
ctx
.
_rpc_sess
.
_tbl_index
assert
rpc_base
.
_SessTableIndex
(
libmod
)
==
cur_ctx
.
_rpc_sess
.
_tbl_index
num_rpc_ctx
+=
1
device_type
=
cur_ctx
.
device_type
%
rpc_base
.
RPC_SESS_MASK
device_type_id
.
append
(
device_type
)
device_type_id
.
append
(
cur_ctx
.
device_id
)
if
0
<
num_rpc_ctx
<
len
(
ctx
):
raise
ValueError
(
"Either all or none of the contexts should be rpc."
)
if
num_rpc_ctx
==
len
(
ctx
):
hmod
=
rpc_base
.
_ModuleHandle
(
libmod
)
hmod
=
rpc_base
.
_ModuleHandle
(
libmod
)
fcreate
=
ctx
.
_rpc_sess
.
get_function
(
"tvm.graph_runtime.remote_create"
)
fcreate
=
ctx
[
0
]
.
_rpc_sess
.
get_function
(
"tvm.graph_runtime.remote_create"
)
device_type
=
device_type
%
rpc_base
.
RPC_SESS_MASK
return
GraphModule
(
fcreate
(
graph_json_str
,
hmod
,
*
device_type_id
))
return
GraphModule
(
fcreate
(
graph_json_str
,
hmod
,
device_type
,
device_id
),
ctx
)
fcreate
=
get_global_func
(
"tvm.graph_runtime.create"
)
fcreate
=
get_global_func
(
"tvm.graph_runtime.create"
)
return
GraphModule
(
fcreate
(
graph_json_str
,
libmod
,
device_type
,
device_id
),
ctx
)
return
GraphModule
(
fcreate
(
graph_json_str
,
libmod
,
*
device_type_id
)
)
class
GraphModule
(
object
):
class
GraphModule
(
object
):
...
@@ -58,18 +81,13 @@ class GraphModule(object):
...
@@ -58,18 +81,13 @@ class GraphModule(object):
module : Module
module : Module
The interal tvm module that holds the actual graph functions.
The interal tvm module that holds the actual graph functions.
ctx : TVMContext
The context this module is under
Attributes
Attributes
----------
----------
module : Module
module : Module
The interal tvm module that holds the actual graph functions.
The interal tvm module that holds the actual graph functions.
ctx : TVMContext
The context this module is under
"""
"""
def
__init__
(
self
,
module
,
ctx
):
def
__init__
(
self
,
module
):
self
.
module
=
module
self
.
module
=
module
self
.
_set_input
=
module
[
"set_input"
]
self
.
_set_input
=
module
[
"set_input"
]
self
.
_run
=
module
[
"run"
]
self
.
_run
=
module
[
"run"
]
...
@@ -81,7 +99,6 @@ class GraphModule(object):
...
@@ -81,7 +99,6 @@ class GraphModule(object):
except
AttributeError
:
except
AttributeError
:
pass
pass
self
.
_load_params
=
module
[
"load_params"
]
self
.
_load_params
=
module
[
"load_params"
]
self
.
ctx
=
ctx
def
set_input
(
self
,
key
=
None
,
value
=
None
,
**
params
):
def
set_input
(
self
,
key
=
None
,
value
=
None
,
**
params
):
"""Set inputs to the module via kwargs
"""Set inputs to the module via kwargs
...
@@ -98,14 +115,14 @@ class GraphModule(object):
...
@@ -98,14 +115,14 @@ class GraphModule(object):
Additonal arguments
Additonal arguments
"""
"""
if
key
:
if
key
:
self
.
_
set_input
(
key
,
nd
.
array
(
value
,
ctx
=
self
.
ctx
)
)
self
.
_
get_input
(
key
)
.
copyfrom
(
value
)
if
params
:
if
params
:
# upload big arrays first to avoid memory issue in rpc mode
# upload big arrays first to avoid memory issue in rpc mode
keys
=
list
(
params
.
keys
())
keys
=
list
(
params
.
keys
())
keys
.
sort
(
key
=
lambda
x
:
-
np
.
prod
(
params
[
x
]
.
shape
))
keys
.
sort
(
key
=
lambda
x
:
-
np
.
prod
(
params
[
x
]
.
shape
))
for
k
in
keys
:
for
k
in
keys
:
self
.
_
set_input
(
k
,
nd
.
array
(
params
[
k
],
ctx
=
self
.
ctx
)
)
self
.
_
get_input
(
k
)
.
copyfrom
(
params
[
k
]
)
def
run
(
self
,
**
input_dict
):
def
run
(
self
,
**
input_dict
):
"""Run forward execution of the graph
"""Run forward execution of the graph
...
@@ -177,7 +194,8 @@ class GraphModule(object):
...
@@ -177,7 +194,8 @@ class GraphModule(object):
if
hasattr
(
self
,
'_debug_get_output'
):
if
hasattr
(
self
,
'_debug_get_output'
):
self
.
_debug_get_output
(
node
,
out
)
self
.
_debug_get_output
(
node
,
out
)
else
:
else
:
raise
RuntimeError
(
"Please compile runtime with USE_GRAPH_RUNTIME_DEBUG = 0"
)
raise
RuntimeError
(
"Please compile runtime with USE_GRAPH_RUNTIME_DEBUG = 0"
)
return
out
return
out
def
load_params
(
self
,
params_bytes
):
def
load_params
(
self
,
params_bytes
):
...
...
src/runtime/graph/graph_runtime.cc
View file @
7c3ec7df
This diff is collapsed.
Click to expand it.
tests/python/unittest/test_runtime_heterogeneous.py
0 → 100644
View file @
7c3ec7df
This diff is collapsed.
Click to expand it.
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment