Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
279a8eba
Unverified
Commit
279a8eba
authored
Dec 03, 2019
by
Tianqi Chen
Committed by
GitHub
Dec 03, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[RUNTIME][RPC] Update RPC runtime to allow remote module as arg (#4462)
parent
77bdd5f7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
91 additions
and
67 deletions
+91
-67
python/tvm/contrib/debugger/debug_runtime.py
+6
-15
python/tvm/contrib/graph_runtime.py
+3
-4
src/runtime/graph/debug/graph_runtime_debug.cc
+0
-15
src/runtime/graph/graph_runtime.cc
+0
-15
src/runtime/rpc/rpc_module.cc
+23
-1
src/runtime/rpc/rpc_session.cc
+47
-17
src/runtime/rpc/rpc_session.h
+12
-0
No files found.
python/tvm/contrib/debugger/debug_runtime.py
View file @
279a8eba
...
...
@@ -23,7 +23,6 @@ from tvm._ffi.base import string_types
from
tvm._ffi.function
import
get_global_func
from
tvm.contrib
import
graph_runtime
from
tvm.ndarray
import
array
from
tvm.rpc
import
base
as
rpc_base
from
.
import
debug_result
_DUMP_ROOT_PREFIX
=
"tvmdbg_"
...
...
@@ -60,25 +59,17 @@ def create(graph_json_str, libmod, ctx, dump_root=None):
except
AttributeError
:
raise
ValueError
(
"Type
%
s is not supported"
%
type
(
graph_json_str
))
try
:
fcreate
=
get_global_func
(
"tvm.graph_runtime_debug.create"
)
ctx
,
num_rpc_ctx
,
device_type_id
=
graph_runtime
.
get_device_ctx
(
libmod
,
ctx
)
if
num_rpc_ctx
==
len
(
ctx
):
fcreate
=
ctx
[
0
]
.
_rpc_sess
.
get_function
(
"tvm.graph_runtime_debug.create"
)
else
:
fcreate
=
get_global_func
(
"tvm.graph_runtime_debug.create"
)
except
ValueError
:
raise
ValueError
(
"Please set '(USE_GRAPH_RUNTIME_DEBUG ON)' in "
"config.cmake and rebuild TVM to enable debug mode"
)
ctx
,
num_rpc_ctx
,
device_type_id
=
graph_runtime
.
get_device_ctx
(
libmod
,
ctx
)
if
num_rpc_ctx
==
len
(
ctx
):
libmod
=
rpc_base
.
_ModuleHandle
(
libmod
)
try
:
fcreate
=
ctx
[
0
]
.
_rpc_sess
.
get_function
(
"tvm.graph_runtime_debug.remote_create"
)
except
ValueError
:
raise
ValueError
(
"Please set '(USE_GRAPH_RUNTIME_DEBUG ON)' in "
"config.cmake and rebuild TVM to enable debug mode"
)
func_obj
=
fcreate
(
graph_json_str
,
libmod
,
*
device_type_id
)
return
GraphModuleDebug
(
func_obj
,
ctx
,
graph_json_str
,
dump_root
)
...
...
python/tvm/contrib/graph_runtime.py
View file @
279a8eba
...
...
@@ -51,11 +51,10 @@ def create(graph_json_str, libmod, ctx):
ctx
,
num_rpc_ctx
,
device_type_id
=
get_device_ctx
(
libmod
,
ctx
)
if
num_rpc_ctx
==
len
(
ctx
):
hmod
=
rpc_base
.
_ModuleHandle
(
libmod
)
fcreate
=
ctx
[
0
]
.
_rpc_sess
.
get_function
(
"tvm.graph_runtime.remote_create"
)
return
GraphModule
(
fcreate
(
graph_json_str
,
hmod
,
*
device_type_id
)
)
fcreate
=
ctx
[
0
]
.
_rpc_sess
.
get_function
(
"tvm.graph_runtime.create"
)
else
:
fcreate
=
get_global_func
(
"tvm.graph_runtime.create"
)
fcreate
=
get_global_func
(
"tvm.graph_runtime.create"
)
return
GraphModule
(
fcreate
(
graph_json_str
,
libmod
,
*
device_type_id
))
def
get_device_ctx
(
libmod
,
ctx
):
...
...
src/runtime/graph/debug/graph_runtime_debug.cc
View file @
279a8eba
...
...
@@ -27,7 +27,6 @@
#include <chrono>
#include <sstream>
#include "../graph_runtime.h"
#include "../../object_internal.h"
namespace
tvm
{
namespace
runtime
{
...
...
@@ -220,19 +219,5 @@ TVM_REGISTER_GLOBAL("tvm.graph_runtime_debug.create")
<<
args
.
num_args
;
*
rv
=
GraphRuntimeDebugCreate
(
args
[
0
],
args
[
1
],
GetAllContext
(
args
));
});
TVM_REGISTER_GLOBAL
(
"tvm.graph_runtime_debug.remote_create"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
rv
)
{
CHECK_GE
(
args
.
num_args
,
4
)
<<
"The expected number of arguments for "
"graph_runtime.remote_create is "
"at least 4, but it has "
<<
args
.
num_args
;
void
*
mhandle
=
args
[
1
];
ModuleNode
*
mnode
=
ObjectInternal
::
GetModuleNode
(
mhandle
);
const
auto
&
contexts
=
GetAllContext
(
args
);
*
rv
=
GraphRuntimeDebugCreate
(
args
[
0
],
GetRef
<
Module
>
(
mnode
),
contexts
);
});
}
// namespace runtime
}
// namespace tvm
src/runtime/graph/graph_runtime.cc
View file @
279a8eba
...
...
@@ -36,7 +36,6 @@
#include <vector>
#include "graph_runtime.h"
#include "../object_internal.h"
namespace
tvm
{
namespace
runtime
{
...
...
@@ -511,19 +510,5 @@ TVM_REGISTER_GLOBAL("tvm.graph_runtime.create")
const
auto
&
contexts
=
GetAllContext
(
args
);
*
rv
=
GraphRuntimeCreate
(
args
[
0
],
args
[
1
],
contexts
);
});
TVM_REGISTER_GLOBAL
(
"tvm.graph_runtime.remote_create"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
rv
)
{
CHECK_GE
(
args
.
num_args
,
4
)
<<
"The expected number of arguments for "
"graph_runtime.remote_create is "
"at least 4, but it has "
<<
args
.
num_args
;
void
*
mhandle
=
args
[
1
];
ModuleNode
*
mnode
=
ObjectInternal
::
GetModuleNode
(
mhandle
);
const
auto
&
contexts
=
GetAllContext
(
args
);
*
rv
=
GraphRuntimeCreate
(
args
[
0
],
GetRef
<
Module
>
(
mnode
),
contexts
);
});
}
// namespace runtime
}
// namespace tvm
src/runtime/rpc/rpc_module.cc
View file @
279a8eba
...
...
@@ -41,7 +41,7 @@ class RPCWrappedFunc {
}
void
operator
()(
TVMArgs
args
,
TVMRetValue
*
rv
)
const
{
sess_
->
CallFunc
(
handle_
,
args
,
rv
,
&
fwrap_
);
sess_
->
CallFunc
(
handle_
,
args
,
rv
,
UnwrapRemote
,
&
fwrap_
);
}
~
RPCWrappedFunc
()
{
try
{
...
...
@@ -55,6 +55,9 @@ class RPCWrappedFunc {
TVMArgs
args
,
TVMRetValue
*
rv
);
static
void
*
UnwrapRemote
(
int
rpc_sess_table_index
,
const
TVMArgValue
&
arg
);
// deleter of RPC remote array
static
void
RemoteNDArrayDeleter
(
NDArray
::
Container
*
ptr
)
{
RemoteSpace
*
space
=
static_cast
<
RemoteSpace
*>
(
ptr
->
dl_tensor
.
data
);
...
...
@@ -181,6 +184,25 @@ class RPCModuleNode final : public ModuleNode {
PackedFunc
fwrap_
;
};
void
*
RPCWrappedFunc
::
UnwrapRemote
(
int
rpc_sess_table_index
,
const
TVMArgValue
&
arg
)
{
if
(
arg
.
type_code
()
==
kModuleHandle
)
{
Module
mod
=
arg
;
std
::
string
tkey
=
mod
->
type_key
();
CHECK_EQ
(
tkey
,
"rpc"
)
<<
"ValueError: Cannot pass a non-RPC module to remote"
;
auto
*
rmod
=
static_cast
<
RPCModuleNode
*>
(
mod
.
operator
->
());
CHECK_EQ
(
rmod
->
sess
()
->
table_index
(),
rpc_sess_table_index
)
<<
"ValueError: Cannot pass in module into a different remote session"
;
return
rmod
->
module_handle
();
}
else
{
LOG
(
FATAL
)
<<
"ValueError: Cannot pass type "
<<
runtime
::
TypeCode2Str
(
arg
.
type_code
())
<<
" as an argument to the remote"
;
return
nullptr
;
}
}
void
RPCWrappedFunc
::
WrapRemote
(
std
::
shared_ptr
<
RPCSession
>
sess
,
TVMArgs
args
,
TVMRetValue
*
rv
)
{
...
...
src/runtime/rpc/rpc_session.cc
View file @
279a8eba
...
...
@@ -202,23 +202,33 @@ class RPCSession::EventHandler : public dmlc::Stream {
return
ctx
;
}
// Send Packed sequence to writer.
//
// client_mode: whether we are in client mode.
//
// funwrap: auxiliary function to unwrap remote Object
// when it is provided, we need to unwrap objects.
//
// return_ndarray is a special flag to handle returning of ndarray
// In this case, we return the shape, context and data of the array,
// as well as a customized PackedFunc that handles deletion of
// the array in the remote.
void
SendPackedSeq
(
const
TVMValue
*
arg_values
,
const
int
*
type_codes
,
int
n
,
int
num_args
,
bool
client_mode
,
FUnwrapRemoteObject
funwrap
=
nullptr
,
bool
return_ndarray
=
false
)
{
this
->
Write
(
n
);
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
std
::
swap
(
client_mode_
,
client_mode
);
this
->
Write
(
num_args
);
for
(
int
i
=
0
;
i
<
num_args
;
++
i
)
{
int
tcode
=
type_codes
[
i
];
if
(
tcode
==
kNDArrayContainer
)
tcode
=
kArrayHandle
;
this
->
Write
(
tcode
);
}
// Argument packing.
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
for
(
int
i
=
0
;
i
<
n
um_args
;
++
i
)
{
int
tcode
=
type_codes
[
i
];
TVMValue
value
=
arg_values
[
i
];
switch
(
tcode
)
{
...
...
@@ -241,7 +251,23 @@ class RPCSession::EventHandler : public dmlc::Stream {
break
;
}
case
kFuncHandle
:
case
kModuleHandle
:
case
kModuleHandle
:
{
// always send handle in 64 bit.
uint64_t
handle
;
// allow pass module as argument to remote.
if
(
funwrap
!=
nullptr
)
{
void
*
remote_handle
=
(
*
funwrap
)(
rpc_sess_table_index_
,
runtime
::
TVMArgValue
(
value
,
tcode
));
handle
=
reinterpret_cast
<
uint64_t
>
(
remote_handle
);
}
else
{
CHECK
(
!
client_mode_
)
<<
"Cannot directly pass remote object as argument"
;
handle
=
reinterpret_cast
<
uint64_t
>
(
value
.
v_handle
);
}
this
->
Write
(
handle
);
break
;
}
case
kHandle
:
{
// always send handle in 64 bit.
uint64_t
handle
=
reinterpret_cast
<
uint64_t
>
(
value
.
v_handle
);
...
...
@@ -300,6 +326,7 @@ class RPCSession::EventHandler : public dmlc::Stream {
}
}
}
std
::
swap
(
client_mode_
,
client_mode
);
}
// Endian aware IO handling
...
...
@@ -430,11 +457,11 @@ class RPCSession::EventHandler : public dmlc::Stream {
case
kHandle
:
case
kStr
:
case
kBytes
:
case
kModuleHandle
:
case
kTVMContext
:
{
this
->
RequestBytes
(
sizeof
(
TVMValue
));
break
;
}
case
kFuncHandle
:
case
kModuleHandle
:
{
case
kFuncHandle
:
{
CHECK
(
client_mode_
)
<<
"Only client can receive remote functions"
;
this
->
RequestBytes
(
sizeof
(
TVMValue
));
break
;
...
...
@@ -656,7 +683,7 @@ class RPCSession::EventHandler : public dmlc::Stream {
TVMValue
ret_value
;
ret_value
.
v_str
=
e
.
what
();
int
ret_tcode
=
kStr
;
SendPackedSeq
(
&
ret_value
,
&
ret_tcode
,
1
);
SendPackedSeq
(
&
ret_value
,
&
ret_tcode
,
1
,
false
);
}
}
this
->
SwitchToState
(
kRecvCode
);
...
...
@@ -711,7 +738,7 @@ class RPCSession::EventHandler : public dmlc::Stream {
}
}
this
->
Write
(
code
);
SendPackedSeq
(
&
ret_value
,
&
ret_tcode
,
1
);
SendPackedSeq
(
&
ret_value
,
&
ret_tcode
,
1
,
false
);
arg_recv_stage_
=
0
;
this
->
SwitchToState
(
kRecvCode
);
}
...
...
@@ -734,7 +761,7 @@ class RPCSession::EventHandler : public dmlc::Stream {
if
(
rv
.
type_code
()
==
kStr
)
{
ret_value
.
v_str
=
rv
.
ptr
<
std
::
string
>
()
->
c_str
();
ret_tcode
=
kStr
;
SendPackedSeq
(
&
ret_value
,
&
ret_tcode
,
1
);
SendPackedSeq
(
&
ret_value
,
&
ret_tcode
,
1
,
false
);
}
else
if
(
rv
.
type_code
()
==
kBytes
)
{
std
::
string
*
bytes
=
rv
.
ptr
<
std
::
string
>
();
TVMByteArray
arr
;
...
...
@@ -742,14 +769,14 @@ class RPCSession::EventHandler : public dmlc::Stream {
arr
.
size
=
bytes
->
length
();
ret_value
.
v_handle
=
&
arr
;
ret_tcode
=
kBytes
;
SendPackedSeq
(
&
ret_value
,
&
ret_tcode
,
1
);
SendPackedSeq
(
&
ret_value
,
&
ret_tcode
,
1
,
false
);
}
else
if
(
rv
.
type_code
()
==
kFuncHandle
||
rv
.
type_code
()
==
kModuleHandle
)
{
// always send handle in 64 bit.
CHECK
(
!
client_mode_
)
<<
"Only server can send function and module handle back."
;
rv
.
MoveToCHost
(
&
ret_value
,
&
ret_tcode
);
SendPackedSeq
(
&
ret_value
,
&
ret_tcode
,
1
);
SendPackedSeq
(
&
ret_value
,
&
ret_tcode
,
1
,
false
);
}
else
if
(
rv
.
type_code
()
==
kNDArrayContainer
)
{
// always send handle in 64 bit.
CHECK
(
!
client_mode_
)
...
...
@@ -764,18 +791,18 @@ class RPCSession::EventHandler : public dmlc::Stream {
NDArray
::
Container
*
nd
=
static_cast
<
NDArray
::
Container
*>
(
ret_value_pack
[
0
].
v_handle
);
ret_value_pack
[
1
].
v_handle
=
nd
;
ret_tcode_pack
[
1
]
=
kHandle
;
SendPackedSeq
(
ret_value_pack
,
ret_tcode_pack
,
2
,
true
);
SendPackedSeq
(
ret_value_pack
,
ret_tcode_pack
,
2
,
false
,
nullptr
,
true
);
}
else
{
ret_value
=
rv
.
value
();
ret_tcode
=
rv
.
type_code
();
SendPackedSeq
(
&
ret_value
,
&
ret_tcode
,
1
);
SendPackedSeq
(
&
ret_value
,
&
ret_tcode
,
1
,
false
);
}
}
catch
(
const
std
::
runtime_error
&
e
)
{
RPCCode
code
=
RPCCode
::
kException
;
this
->
Write
(
code
);
ret_value
.
v_str
=
e
.
what
();
ret_tcode
=
kStr
;
SendPackedSeq
(
&
ret_value
,
&
ret_tcode
,
1
);
SendPackedSeq
(
&
ret_value
,
&
ret_tcode
,
1
,
false
);
}
}
...
...
@@ -873,7 +900,7 @@ void RPCSession::Init() {
&
reader_
,
&
writer_
,
table_index_
,
name_
,
&
remote_key_
);
// Quick function to call remote.
call_remote_
=
PackedFunc
([
this
](
TVMArgs
args
,
TVMRetValue
*
rv
)
{
handler_
->
SendPackedSeq
(
args
.
values
,
args
.
type_codes
,
args
.
num_args
);
handler_
->
SendPackedSeq
(
args
.
values
,
args
.
type_codes
,
args
.
num_args
,
true
);
RPCCode
code
=
HandleUntilReturnEvent
(
rv
,
true
,
nullptr
);
CHECK
(
code
==
RPCCode
::
kReturn
)
<<
"code="
<<
static_cast
<
int
>
(
code
);
});
...
...
@@ -954,13 +981,16 @@ int RPCSession::ServerEventHandler(const std::string& bytes, int event_flag) {
void
RPCSession
::
CallFunc
(
void
*
h
,
TVMArgs
args
,
TVMRetValue
*
rv
,
FUnwrapRemoteObject
funwrap
,
const
PackedFunc
*
fwrap
)
{
std
::
lock_guard
<
std
::
recursive_mutex
>
lock
(
mutex_
);
RPCCode
code
=
RPCCode
::
kCallFunc
;
handler_
->
Write
(
code
);
uint64_t
handle
=
reinterpret_cast
<
uint64_t
>
(
h
);
handler_
->
Write
(
handle
);
handler_
->
SendPackedSeq
(
args
.
values
,
args
.
type_codes
,
args
.
num_args
);
handler_
->
SendPackedSeq
(
args
.
values
,
args
.
type_codes
,
args
.
num_args
,
true
,
funwrap
);
code
=
HandleUntilReturnEvent
(
rv
,
true
,
fwrap
);
CHECK
(
code
==
RPCCode
::
kReturn
)
<<
"code="
<<
static_cast
<
int
>
(
code
);
}
...
...
src/runtime/rpc/rpc_session.h
View file @
279a8eba
...
...
@@ -91,6 +91,16 @@ enum class RPCCode : int {
};
/*!
* \brief Function that unwraps a remote object to its handle.
* \param rpc_sess_table_index RPC session table index for validation.
* \param obj Handle to the object argument.
* \return The corresponding handle.
*/
typedef
void
*
(
*
FUnwrapRemoteObject
)(
int
rpc_sess_table_index
,
const
TVMArgValue
&
obj
);
/*!
* \brief Abstract channel interface used to create RPCSession.
*/
class
RPCChannel
{
...
...
@@ -144,11 +154,13 @@ class RPCSession {
* \param handle The function handle
* \param args The arguments
* \param rv The return value.
* \param funpwrap Function that takes a remote object and returns the raw handle.
* \param fwrap Wrapper function to turn Function/Module handle into real return.
*/
void
CallFunc
(
RPCFuncHandle
handle
,
TVMArgs
args
,
TVMRetValue
*
rv
,
FUnwrapRemoteObject
funwrap
,
const
PackedFunc
*
fwrap
);
/*!
* \brief Copy bytes into remote array content.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment