Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
279a8eba
Unverified
Commit
279a8eba
authored
Dec 03, 2019
by
Tianqi Chen
Committed by
GitHub
Dec 03, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[RUNTIME][RPC] Update RPC runtime to allow remote module as arg (#4462)
parent
77bdd5f7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
91 additions
and
67 deletions
+91
-67
python/tvm/contrib/debugger/debug_runtime.py
+6
-15
python/tvm/contrib/graph_runtime.py
+3
-4
src/runtime/graph/debug/graph_runtime_debug.cc
+0
-15
src/runtime/graph/graph_runtime.cc
+0
-15
src/runtime/rpc/rpc_module.cc
+23
-1
src/runtime/rpc/rpc_session.cc
+47
-17
src/runtime/rpc/rpc_session.h
+12
-0
No files found.
python/tvm/contrib/debugger/debug_runtime.py
View file @
279a8eba
...
@@ -23,7 +23,6 @@ from tvm._ffi.base import string_types
...
@@ -23,7 +23,6 @@ from tvm._ffi.base import string_types
from
tvm._ffi.function
import
get_global_func
from
tvm._ffi.function
import
get_global_func
from
tvm.contrib
import
graph_runtime
from
tvm.contrib
import
graph_runtime
from
tvm.ndarray
import
array
from
tvm.ndarray
import
array
from
tvm.rpc
import
base
as
rpc_base
from
.
import
debug_result
from
.
import
debug_result
_DUMP_ROOT_PREFIX
=
"tvmdbg_"
_DUMP_ROOT_PREFIX
=
"tvmdbg_"
...
@@ -60,25 +59,17 @@ def create(graph_json_str, libmod, ctx, dump_root=None):
...
@@ -60,25 +59,17 @@ def create(graph_json_str, libmod, ctx, dump_root=None):
except
AttributeError
:
except
AttributeError
:
raise
ValueError
(
"Type
%
s is not supported"
%
type
(
graph_json_str
))
raise
ValueError
(
"Type
%
s is not supported"
%
type
(
graph_json_str
))
try
:
try
:
fcreate
=
get_global_func
(
"tvm.graph_runtime_debug.create"
)
ctx
,
num_rpc_ctx
,
device_type_id
=
graph_runtime
.
get_device_ctx
(
libmod
,
ctx
)
if
num_rpc_ctx
==
len
(
ctx
):
fcreate
=
ctx
[
0
]
.
_rpc_sess
.
get_function
(
"tvm.graph_runtime_debug.create"
)
else
:
fcreate
=
get_global_func
(
"tvm.graph_runtime_debug.create"
)
except
ValueError
:
except
ValueError
:
raise
ValueError
(
raise
ValueError
(
"Please set '(USE_GRAPH_RUNTIME_DEBUG ON)' in "
"Please set '(USE_GRAPH_RUNTIME_DEBUG ON)' in "
"config.cmake and rebuild TVM to enable debug mode"
"config.cmake and rebuild TVM to enable debug mode"
)
)
ctx
,
num_rpc_ctx
,
device_type_id
=
graph_runtime
.
get_device_ctx
(
libmod
,
ctx
)
if
num_rpc_ctx
==
len
(
ctx
):
libmod
=
rpc_base
.
_ModuleHandle
(
libmod
)
try
:
fcreate
=
ctx
[
0
]
.
_rpc_sess
.
get_function
(
"tvm.graph_runtime_debug.remote_create"
)
except
ValueError
:
raise
ValueError
(
"Please set '(USE_GRAPH_RUNTIME_DEBUG ON)' in "
"config.cmake and rebuild TVM to enable debug mode"
)
func_obj
=
fcreate
(
graph_json_str
,
libmod
,
*
device_type_id
)
func_obj
=
fcreate
(
graph_json_str
,
libmod
,
*
device_type_id
)
return
GraphModuleDebug
(
func_obj
,
ctx
,
graph_json_str
,
dump_root
)
return
GraphModuleDebug
(
func_obj
,
ctx
,
graph_json_str
,
dump_root
)
...
...
python/tvm/contrib/graph_runtime.py
View file @
279a8eba
...
@@ -51,11 +51,10 @@ def create(graph_json_str, libmod, ctx):
...
@@ -51,11 +51,10 @@ def create(graph_json_str, libmod, ctx):
ctx
,
num_rpc_ctx
,
device_type_id
=
get_device_ctx
(
libmod
,
ctx
)
ctx
,
num_rpc_ctx
,
device_type_id
=
get_device_ctx
(
libmod
,
ctx
)
if
num_rpc_ctx
==
len
(
ctx
):
if
num_rpc_ctx
==
len
(
ctx
):
hmod
=
rpc_base
.
_ModuleHandle
(
libmod
)
fcreate
=
ctx
[
0
]
.
_rpc_sess
.
get_function
(
"tvm.graph_runtime.create"
)
fcreate
=
ctx
[
0
]
.
_rpc_sess
.
get_function
(
"tvm.graph_runtime.remote_create"
)
else
:
return
GraphModule
(
fcreate
(
graph_json_str
,
hmod
,
*
device_type_id
)
)
fcreate
=
get_global_func
(
"tvm.graph_runtime.create"
)
fcreate
=
get_global_func
(
"tvm.graph_runtime.create"
)
return
GraphModule
(
fcreate
(
graph_json_str
,
libmod
,
*
device_type_id
))
return
GraphModule
(
fcreate
(
graph_json_str
,
libmod
,
*
device_type_id
))
def
get_device_ctx
(
libmod
,
ctx
):
def
get_device_ctx
(
libmod
,
ctx
):
...
...
src/runtime/graph/debug/graph_runtime_debug.cc
View file @
279a8eba
...
@@ -27,7 +27,6 @@
...
@@ -27,7 +27,6 @@
#include <chrono>
#include <chrono>
#include <sstream>
#include <sstream>
#include "../graph_runtime.h"
#include "../graph_runtime.h"
#include "../../object_internal.h"
namespace
tvm
{
namespace
tvm
{
namespace
runtime
{
namespace
runtime
{
...
@@ -220,19 +219,5 @@ TVM_REGISTER_GLOBAL("tvm.graph_runtime_debug.create")
...
@@ -220,19 +219,5 @@ TVM_REGISTER_GLOBAL("tvm.graph_runtime_debug.create")
<<
args
.
num_args
;
<<
args
.
num_args
;
*
rv
=
GraphRuntimeDebugCreate
(
args
[
0
],
args
[
1
],
GetAllContext
(
args
));
*
rv
=
GraphRuntimeDebugCreate
(
args
[
0
],
args
[
1
],
GetAllContext
(
args
));
});
});
TVM_REGISTER_GLOBAL
(
"tvm.graph_runtime_debug.remote_create"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
rv
)
{
CHECK_GE
(
args
.
num_args
,
4
)
<<
"The expected number of arguments for "
"graph_runtime.remote_create is "
"at least 4, but it has "
<<
args
.
num_args
;
void
*
mhandle
=
args
[
1
];
ModuleNode
*
mnode
=
ObjectInternal
::
GetModuleNode
(
mhandle
);
const
auto
&
contexts
=
GetAllContext
(
args
);
*
rv
=
GraphRuntimeDebugCreate
(
args
[
0
],
GetRef
<
Module
>
(
mnode
),
contexts
);
});
}
// namespace runtime
}
// namespace runtime
}
// namespace tvm
}
// namespace tvm
src/runtime/graph/graph_runtime.cc
View file @
279a8eba
...
@@ -36,7 +36,6 @@
...
@@ -36,7 +36,6 @@
#include <vector>
#include <vector>
#include "graph_runtime.h"
#include "graph_runtime.h"
#include "../object_internal.h"
namespace
tvm
{
namespace
tvm
{
namespace
runtime
{
namespace
runtime
{
...
@@ -511,19 +510,5 @@ TVM_REGISTER_GLOBAL("tvm.graph_runtime.create")
...
@@ -511,19 +510,5 @@ TVM_REGISTER_GLOBAL("tvm.graph_runtime.create")
const
auto
&
contexts
=
GetAllContext
(
args
);
const
auto
&
contexts
=
GetAllContext
(
args
);
*
rv
=
GraphRuntimeCreate
(
args
[
0
],
args
[
1
],
contexts
);
*
rv
=
GraphRuntimeCreate
(
args
[
0
],
args
[
1
],
contexts
);
});
});
TVM_REGISTER_GLOBAL
(
"tvm.graph_runtime.remote_create"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
rv
)
{
CHECK_GE
(
args
.
num_args
,
4
)
<<
"The expected number of arguments for "
"graph_runtime.remote_create is "
"at least 4, but it has "
<<
args
.
num_args
;
void
*
mhandle
=
args
[
1
];
ModuleNode
*
mnode
=
ObjectInternal
::
GetModuleNode
(
mhandle
);
const
auto
&
contexts
=
GetAllContext
(
args
);
*
rv
=
GraphRuntimeCreate
(
args
[
0
],
GetRef
<
Module
>
(
mnode
),
contexts
);
});
}
// namespace runtime
}
// namespace runtime
}
// namespace tvm
}
// namespace tvm
src/runtime/rpc/rpc_module.cc
View file @
279a8eba
...
@@ -41,7 +41,7 @@ class RPCWrappedFunc {
...
@@ -41,7 +41,7 @@ class RPCWrappedFunc {
}
}
void
operator
()(
TVMArgs
args
,
TVMRetValue
*
rv
)
const
{
void
operator
()(
TVMArgs
args
,
TVMRetValue
*
rv
)
const
{
sess_
->
CallFunc
(
handle_
,
args
,
rv
,
&
fwrap_
);
sess_
->
CallFunc
(
handle_
,
args
,
rv
,
UnwrapRemote
,
&
fwrap_
);
}
}
~
RPCWrappedFunc
()
{
~
RPCWrappedFunc
()
{
try
{
try
{
...
@@ -55,6 +55,9 @@ class RPCWrappedFunc {
...
@@ -55,6 +55,9 @@ class RPCWrappedFunc {
TVMArgs
args
,
TVMArgs
args
,
TVMRetValue
*
rv
);
TVMRetValue
*
rv
);
static
void
*
UnwrapRemote
(
int
rpc_sess_table_index
,
const
TVMArgValue
&
arg
);
// deleter of RPC remote array
// deleter of RPC remote array
static
void
RemoteNDArrayDeleter
(
NDArray
::
Container
*
ptr
)
{
static
void
RemoteNDArrayDeleter
(
NDArray
::
Container
*
ptr
)
{
RemoteSpace
*
space
=
static_cast
<
RemoteSpace
*>
(
ptr
->
dl_tensor
.
data
);
RemoteSpace
*
space
=
static_cast
<
RemoteSpace
*>
(
ptr
->
dl_tensor
.
data
);
...
@@ -181,6 +184,25 @@ class RPCModuleNode final : public ModuleNode {
...
@@ -181,6 +184,25 @@ class RPCModuleNode final : public ModuleNode {
PackedFunc
fwrap_
;
PackedFunc
fwrap_
;
};
};
void
*
RPCWrappedFunc
::
UnwrapRemote
(
int
rpc_sess_table_index
,
const
TVMArgValue
&
arg
)
{
if
(
arg
.
type_code
()
==
kModuleHandle
)
{
Module
mod
=
arg
;
std
::
string
tkey
=
mod
->
type_key
();
CHECK_EQ
(
tkey
,
"rpc"
)
<<
"ValueError: Cannot pass a non-RPC module to remote"
;
auto
*
rmod
=
static_cast
<
RPCModuleNode
*>
(
mod
.
operator
->
());
CHECK_EQ
(
rmod
->
sess
()
->
table_index
(),
rpc_sess_table_index
)
<<
"ValueError: Cannot pass in module into a different remote session"
;
return
rmod
->
module_handle
();
}
else
{
LOG
(
FATAL
)
<<
"ValueError: Cannot pass type "
<<
runtime
::
TypeCode2Str
(
arg
.
type_code
())
<<
" as an argument to the remote"
;
return
nullptr
;
}
}
void
RPCWrappedFunc
::
WrapRemote
(
std
::
shared_ptr
<
RPCSession
>
sess
,
void
RPCWrappedFunc
::
WrapRemote
(
std
::
shared_ptr
<
RPCSession
>
sess
,
TVMArgs
args
,
TVMArgs
args
,
TVMRetValue
*
rv
)
{
TVMRetValue
*
rv
)
{
...
...
src/runtime/rpc/rpc_session.cc
View file @
279a8eba
...
@@ -202,23 +202,33 @@ class RPCSession::EventHandler : public dmlc::Stream {
...
@@ -202,23 +202,33 @@ class RPCSession::EventHandler : public dmlc::Stream {
return
ctx
;
return
ctx
;
}
}
// Send Packed sequence to writer.
// Send Packed sequence to writer.
//
// client_mode: whether we are in client mode.
//
// funwrap: auxiliary function to unwrap remote Object
// when it is provided, we need to unwrap objects.
//
// return_ndarray is a special flag to handle returning of ndarray
// return_ndarray is a special flag to handle returning of ndarray
// In this case, we return the shape, context and data of the array,
// In this case, we return the shape, context and data of the array,
// as well as a customized PackedFunc that handles deletion of
// as well as a customized PackedFunc that handles deletion of
// the array in the remote.
// the array in the remote.
void
SendPackedSeq
(
const
TVMValue
*
arg_values
,
void
SendPackedSeq
(
const
TVMValue
*
arg_values
,
const
int
*
type_codes
,
const
int
*
type_codes
,
int
n
,
int
num_args
,
bool
client_mode
,
FUnwrapRemoteObject
funwrap
=
nullptr
,
bool
return_ndarray
=
false
)
{
bool
return_ndarray
=
false
)
{
this
->
Write
(
n
);
std
::
swap
(
client_mode_
,
client_mode
);
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
this
->
Write
(
num_args
);
for
(
int
i
=
0
;
i
<
num_args
;
++
i
)
{
int
tcode
=
type_codes
[
i
];
int
tcode
=
type_codes
[
i
];
if
(
tcode
==
kNDArrayContainer
)
tcode
=
kArrayHandle
;
if
(
tcode
==
kNDArrayContainer
)
tcode
=
kArrayHandle
;
this
->
Write
(
tcode
);
this
->
Write
(
tcode
);
}
}
// Argument packing.
// Argument packing.
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
for
(
int
i
=
0
;
i
<
n
um_args
;
++
i
)
{
int
tcode
=
type_codes
[
i
];
int
tcode
=
type_codes
[
i
];
TVMValue
value
=
arg_values
[
i
];
TVMValue
value
=
arg_values
[
i
];
switch
(
tcode
)
{
switch
(
tcode
)
{
...
@@ -241,7 +251,23 @@ class RPCSession::EventHandler : public dmlc::Stream {
...
@@ -241,7 +251,23 @@ class RPCSession::EventHandler : public dmlc::Stream {
break
;
break
;
}
}
case
kFuncHandle
:
case
kFuncHandle
:
case
kModuleHandle
:
case
kModuleHandle
:
{
// always send handle in 64 bit.
uint64_t
handle
;
// allow pass module as argument to remote.
if
(
funwrap
!=
nullptr
)
{
void
*
remote_handle
=
(
*
funwrap
)(
rpc_sess_table_index_
,
runtime
::
TVMArgValue
(
value
,
tcode
));
handle
=
reinterpret_cast
<
uint64_t
>
(
remote_handle
);
}
else
{
CHECK
(
!
client_mode_
)
<<
"Cannot directly pass remote object as argument"
;
handle
=
reinterpret_cast
<
uint64_t
>
(
value
.
v_handle
);
}
this
->
Write
(
handle
);
break
;
}
case
kHandle
:
{
case
kHandle
:
{
// always send handle in 64 bit.
// always send handle in 64 bit.
uint64_t
handle
=
reinterpret_cast
<
uint64_t
>
(
value
.
v_handle
);
uint64_t
handle
=
reinterpret_cast
<
uint64_t
>
(
value
.
v_handle
);
...
@@ -300,6 +326,7 @@ class RPCSession::EventHandler : public dmlc::Stream {
...
@@ -300,6 +326,7 @@ class RPCSession::EventHandler : public dmlc::Stream {
}
}
}
}
}
}
std
::
swap
(
client_mode_
,
client_mode
);
}
}
// Endian aware IO handling
// Endian aware IO handling
...
@@ -430,11 +457,11 @@ class RPCSession::EventHandler : public dmlc::Stream {
...
@@ -430,11 +457,11 @@ class RPCSession::EventHandler : public dmlc::Stream {
case
kHandle
:
case
kHandle
:
case
kStr
:
case
kStr
:
case
kBytes
:
case
kBytes
:
case
kModuleHandle
:
case
kTVMContext
:
{
case
kTVMContext
:
{
this
->
RequestBytes
(
sizeof
(
TVMValue
));
break
;
this
->
RequestBytes
(
sizeof
(
TVMValue
));
break
;
}
}
case
kFuncHandle
:
case
kFuncHandle
:
{
case
kModuleHandle
:
{
CHECK
(
client_mode_
)
CHECK
(
client_mode_
)
<<
"Only client can receive remote functions"
;
<<
"Only client can receive remote functions"
;
this
->
RequestBytes
(
sizeof
(
TVMValue
));
break
;
this
->
RequestBytes
(
sizeof
(
TVMValue
));
break
;
...
@@ -656,7 +683,7 @@ class RPCSession::EventHandler : public dmlc::Stream {
...
@@ -656,7 +683,7 @@ class RPCSession::EventHandler : public dmlc::Stream {
TVMValue
ret_value
;
TVMValue
ret_value
;
ret_value
.
v_str
=
e
.
what
();
ret_value
.
v_str
=
e
.
what
();
int
ret_tcode
=
kStr
;
int
ret_tcode
=
kStr
;
SendPackedSeq
(
&
ret_value
,
&
ret_tcode
,
1
);
SendPackedSeq
(
&
ret_value
,
&
ret_tcode
,
1
,
false
);
}
}
}
}
this
->
SwitchToState
(
kRecvCode
);
this
->
SwitchToState
(
kRecvCode
);
...
@@ -711,7 +738,7 @@ class RPCSession::EventHandler : public dmlc::Stream {
...
@@ -711,7 +738,7 @@ class RPCSession::EventHandler : public dmlc::Stream {
}
}
}
}
this
->
Write
(
code
);
this
->
Write
(
code
);
SendPackedSeq
(
&
ret_value
,
&
ret_tcode
,
1
);
SendPackedSeq
(
&
ret_value
,
&
ret_tcode
,
1
,
false
);
arg_recv_stage_
=
0
;
arg_recv_stage_
=
0
;
this
->
SwitchToState
(
kRecvCode
);
this
->
SwitchToState
(
kRecvCode
);
}
}
...
@@ -734,7 +761,7 @@ class RPCSession::EventHandler : public dmlc::Stream {
...
@@ -734,7 +761,7 @@ class RPCSession::EventHandler : public dmlc::Stream {
if
(
rv
.
type_code
()
==
kStr
)
{
if
(
rv
.
type_code
()
==
kStr
)
{
ret_value
.
v_str
=
rv
.
ptr
<
std
::
string
>
()
->
c_str
();
ret_value
.
v_str
=
rv
.
ptr
<
std
::
string
>
()
->
c_str
();
ret_tcode
=
kStr
;
ret_tcode
=
kStr
;
SendPackedSeq
(
&
ret_value
,
&
ret_tcode
,
1
);
SendPackedSeq
(
&
ret_value
,
&
ret_tcode
,
1
,
false
);
}
else
if
(
rv
.
type_code
()
==
kBytes
)
{
}
else
if
(
rv
.
type_code
()
==
kBytes
)
{
std
::
string
*
bytes
=
rv
.
ptr
<
std
::
string
>
();
std
::
string
*
bytes
=
rv
.
ptr
<
std
::
string
>
();
TVMByteArray
arr
;
TVMByteArray
arr
;
...
@@ -742,14 +769,14 @@ class RPCSession::EventHandler : public dmlc::Stream {
...
@@ -742,14 +769,14 @@ class RPCSession::EventHandler : public dmlc::Stream {
arr
.
size
=
bytes
->
length
();
arr
.
size
=
bytes
->
length
();
ret_value
.
v_handle
=
&
arr
;
ret_value
.
v_handle
=
&
arr
;
ret_tcode
=
kBytes
;
ret_tcode
=
kBytes
;
SendPackedSeq
(
&
ret_value
,
&
ret_tcode
,
1
);
SendPackedSeq
(
&
ret_value
,
&
ret_tcode
,
1
,
false
);
}
else
if
(
rv
.
type_code
()
==
kFuncHandle
||
}
else
if
(
rv
.
type_code
()
==
kFuncHandle
||
rv
.
type_code
()
==
kModuleHandle
)
{
rv
.
type_code
()
==
kModuleHandle
)
{
// always send handle in 64 bit.
// always send handle in 64 bit.
CHECK
(
!
client_mode_
)
CHECK
(
!
client_mode_
)
<<
"Only server can send function and module handle back."
;
<<
"Only server can send function and module handle back."
;
rv
.
MoveToCHost
(
&
ret_value
,
&
ret_tcode
);
rv
.
MoveToCHost
(
&
ret_value
,
&
ret_tcode
);
SendPackedSeq
(
&
ret_value
,
&
ret_tcode
,
1
);
SendPackedSeq
(
&
ret_value
,
&
ret_tcode
,
1
,
false
);
}
else
if
(
rv
.
type_code
()
==
kNDArrayContainer
)
{
}
else
if
(
rv
.
type_code
()
==
kNDArrayContainer
)
{
// always send handle in 64 bit.
// always send handle in 64 bit.
CHECK
(
!
client_mode_
)
CHECK
(
!
client_mode_
)
...
@@ -764,18 +791,18 @@ class RPCSession::EventHandler : public dmlc::Stream {
...
@@ -764,18 +791,18 @@ class RPCSession::EventHandler : public dmlc::Stream {
NDArray
::
Container
*
nd
=
static_cast
<
NDArray
::
Container
*>
(
ret_value_pack
[
0
].
v_handle
);
NDArray
::
Container
*
nd
=
static_cast
<
NDArray
::
Container
*>
(
ret_value_pack
[
0
].
v_handle
);
ret_value_pack
[
1
].
v_handle
=
nd
;
ret_value_pack
[
1
].
v_handle
=
nd
;
ret_tcode_pack
[
1
]
=
kHandle
;
ret_tcode_pack
[
1
]
=
kHandle
;
SendPackedSeq
(
ret_value_pack
,
ret_tcode_pack
,
2
,
true
);
SendPackedSeq
(
ret_value_pack
,
ret_tcode_pack
,
2
,
false
,
nullptr
,
true
);
}
else
{
}
else
{
ret_value
=
rv
.
value
();
ret_value
=
rv
.
value
();
ret_tcode
=
rv
.
type_code
();
ret_tcode
=
rv
.
type_code
();
SendPackedSeq
(
&
ret_value
,
&
ret_tcode
,
1
);
SendPackedSeq
(
&
ret_value
,
&
ret_tcode
,
1
,
false
);
}
}
}
catch
(
const
std
::
runtime_error
&
e
)
{
}
catch
(
const
std
::
runtime_error
&
e
)
{
RPCCode
code
=
RPCCode
::
kException
;
RPCCode
code
=
RPCCode
::
kException
;
this
->
Write
(
code
);
this
->
Write
(
code
);
ret_value
.
v_str
=
e
.
what
();
ret_value
.
v_str
=
e
.
what
();
ret_tcode
=
kStr
;
ret_tcode
=
kStr
;
SendPackedSeq
(
&
ret_value
,
&
ret_tcode
,
1
);
SendPackedSeq
(
&
ret_value
,
&
ret_tcode
,
1
,
false
);
}
}
}
}
...
@@ -873,7 +900,7 @@ void RPCSession::Init() {
...
@@ -873,7 +900,7 @@ void RPCSession::Init() {
&
reader_
,
&
writer_
,
table_index_
,
name_
,
&
remote_key_
);
&
reader_
,
&
writer_
,
table_index_
,
name_
,
&
remote_key_
);
// Quick function to call remote.
// Quick function to call remote.
call_remote_
=
PackedFunc
([
this
](
TVMArgs
args
,
TVMRetValue
*
rv
)
{
call_remote_
=
PackedFunc
([
this
](
TVMArgs
args
,
TVMRetValue
*
rv
)
{
handler_
->
SendPackedSeq
(
args
.
values
,
args
.
type_codes
,
args
.
num_args
);
handler_
->
SendPackedSeq
(
args
.
values
,
args
.
type_codes
,
args
.
num_args
,
true
);
RPCCode
code
=
HandleUntilReturnEvent
(
rv
,
true
,
nullptr
);
RPCCode
code
=
HandleUntilReturnEvent
(
rv
,
true
,
nullptr
);
CHECK
(
code
==
RPCCode
::
kReturn
)
<<
"code="
<<
static_cast
<
int
>
(
code
);
CHECK
(
code
==
RPCCode
::
kReturn
)
<<
"code="
<<
static_cast
<
int
>
(
code
);
});
});
...
@@ -954,13 +981,16 @@ int RPCSession::ServerEventHandler(const std::string& bytes, int event_flag) {
...
@@ -954,13 +981,16 @@ int RPCSession::ServerEventHandler(const std::string& bytes, int event_flag) {
void
RPCSession
::
CallFunc
(
void
*
h
,
void
RPCSession
::
CallFunc
(
void
*
h
,
TVMArgs
args
,
TVMArgs
args
,
TVMRetValue
*
rv
,
TVMRetValue
*
rv
,
FUnwrapRemoteObject
funwrap
,
const
PackedFunc
*
fwrap
)
{
const
PackedFunc
*
fwrap
)
{
std
::
lock_guard
<
std
::
recursive_mutex
>
lock
(
mutex_
);
std
::
lock_guard
<
std
::
recursive_mutex
>
lock
(
mutex_
);
RPCCode
code
=
RPCCode
::
kCallFunc
;
RPCCode
code
=
RPCCode
::
kCallFunc
;
handler_
->
Write
(
code
);
handler_
->
Write
(
code
);
uint64_t
handle
=
reinterpret_cast
<
uint64_t
>
(
h
);
uint64_t
handle
=
reinterpret_cast
<
uint64_t
>
(
h
);
handler_
->
Write
(
handle
);
handler_
->
Write
(
handle
);
handler_
->
SendPackedSeq
(
args
.
values
,
args
.
type_codes
,
args
.
num_args
);
handler_
->
SendPackedSeq
(
args
.
values
,
args
.
type_codes
,
args
.
num_args
,
true
,
funwrap
);
code
=
HandleUntilReturnEvent
(
rv
,
true
,
fwrap
);
code
=
HandleUntilReturnEvent
(
rv
,
true
,
fwrap
);
CHECK
(
code
==
RPCCode
::
kReturn
)
<<
"code="
<<
static_cast
<
int
>
(
code
);
CHECK
(
code
==
RPCCode
::
kReturn
)
<<
"code="
<<
static_cast
<
int
>
(
code
);
}
}
...
...
src/runtime/rpc/rpc_session.h
View file @
279a8eba
...
@@ -91,6 +91,16 @@ enum class RPCCode : int {
...
@@ -91,6 +91,16 @@ enum class RPCCode : int {
};
};
/*!
/*!
* \brief Function that unwraps a remote object to its handle.
* \param rpc_sess_table_index RPC session table index for validation.
* \param obj Handle to the object argument.
* \return The corresponding handle.
*/
typedef
void
*
(
*
FUnwrapRemoteObject
)(
int
rpc_sess_table_index
,
const
TVMArgValue
&
obj
);
/*!
* \brief Abstract channel interface used to create RPCSession.
* \brief Abstract channel interface used to create RPCSession.
*/
*/
class
RPCChannel
{
class
RPCChannel
{
...
@@ -144,11 +154,13 @@ class RPCSession {
...
@@ -144,11 +154,13 @@ class RPCSession {
* \param handle The function handle
* \param handle The function handle
* \param args The arguments
* \param args The arguments
* \param rv The return value.
* \param rv The return value.
* \param funpwrap Function that takes a remote object and returns the raw handle.
* \param fwrap Wrapper function to turn Function/Module handle into real return.
* \param fwrap Wrapper function to turn Function/Module handle into real return.
*/
*/
void
CallFunc
(
RPCFuncHandle
handle
,
void
CallFunc
(
RPCFuncHandle
handle
,
TVMArgs
args
,
TVMArgs
args
,
TVMRetValue
*
rv
,
TVMRetValue
*
rv
,
FUnwrapRemoteObject
funwrap
,
const
PackedFunc
*
fwrap
);
const
PackedFunc
*
fwrap
);
/*!
/*!
* \brief Copy bytes into remote array content.
* \brief Copy bytes into remote array content.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment