Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
11dd933f
Unverified
Commit
11dd933f
authored
Aug 16, 2018
by
Tianqi Chen
Committed by
GitHub
Aug 16, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[RUNTIME] Enable return NDArray in RPC (#1610)
parent
9bcc3173
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
158 additions
and
16 deletions
+158
-16
include/tvm/runtime/ndarray.h
+1
-0
src/api/api_base.cc
+8
-0
src/runtime/rpc/rpc_module.cc
+53
-8
src/runtime/rpc/rpc_session.cc
+52
-7
src/runtime/rpc/rpc_session.h
+1
-0
tests/python/unittest/test_runtime_rpc.py
+43
-1
No files found.
include/tvm/runtime/ndarray.h
View file @
11dd933f
...
@@ -246,6 +246,7 @@ struct NDArray::Container {
...
@@ -246,6 +246,7 @@ struct NDArray::Container {
private
:
private
:
friend
class
NDArray
;
friend
class
NDArray
;
friend
class
RPCWrappedFunc
;
/*!
/*!
* \brief The shape container,
* \brief The shape container,
* can be used used for shape data.
* can be used used for shape data.
...
...
src/api/api_base.cc
View file @
11dd933f
...
@@ -37,6 +37,14 @@ TVM_REGISTER_API("_nop")
...
@@ -37,6 +37,14 @@ TVM_REGISTER_API("_nop")
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
});
});
// internal fucntion used for debug and testing purposes
TVM_REGISTER_API
(
"_ndarray_use_count"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
runtime
::
NDArray
nd
=
args
[
0
];
// substract the current one
*
ret
=
(
nd
.
use_count
()
-
1
);
});
TVM_REGISTER_API
(
"_TVMSetStream"
)
TVM_REGISTER_API
(
"_TVMSetStream"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
TVMSetStream
(
args
[
0
],
args
[
1
],
args
[
2
]);
TVMSetStream
(
args
[
0
],
args
[
1
],
args
[
2
]);
...
...
src/runtime/rpc/rpc_module.cc
View file @
11dd933f
...
@@ -12,13 +12,13 @@ namespace tvm {
...
@@ -12,13 +12,13 @@ namespace tvm {
namespace
runtime
{
namespace
runtime
{
// Wrapped remote function to packed func.
// Wrapped remote function to packed func.
struct
RPCWrappedFunc
{
class
RPCWrappedFunc
{
public
:
public
:
RPCWrappedFunc
(
void
*
handle
,
RPCWrappedFunc
(
void
*
handle
,
std
::
shared_ptr
<
RPCSession
>
sess
)
std
::
shared_ptr
<
RPCSession
>
sess
)
:
handle_
(
handle
),
sess_
(
sess
)
{
:
handle_
(
handle
),
sess_
(
sess
)
{
fwrap_
=
PackedFunc
([
sess
](
TVMArgs
args
,
TVMRetValue
*
rv
)
{
fwrap_
=
PackedFunc
([
sess
](
TVMArgs
args
,
TVMRetValue
*
rv
)
{
WrapRemote
(
sess
,
args
.
values
[
0
].
v_handle
,
args
.
type_codes
[
0
]
,
rv
);
WrapRemote
(
sess
,
args
,
rv
);
});
});
}
}
...
@@ -34,10 +34,47 @@ struct RPCWrappedFunc {
...
@@ -34,10 +34,47 @@ struct RPCWrappedFunc {
}
}
static
void
WrapRemote
(
std
::
shared_ptr
<
RPCSession
>
sess
,
static
void
WrapRemote
(
std
::
shared_ptr
<
RPCSession
>
sess
,
void
*
handle
,
TVMArgs
args
,
int
tcode
,
TVMRetValue
*
rv
);
TVMRetValue
*
rv
);
// deleter of RPC remote array
static
void
RemoteNDArrayDeleter
(
NDArray
::
Container
*
ptr
)
{
RemoteSpace
*
space
=
static_cast
<
RemoteSpace
*>
(
ptr
->
dl_tensor
.
data
);
space
->
sess
->
CallRemote
(
RPCCode
::
kNDArrayFree
,
ptr
->
manager_ctx
);
delete
space
;
delete
ptr
;
}
// wrap return value as remote NDArray.
static
NDArray
WrapRemoteNDArray
(
std
::
shared_ptr
<
RPCSession
>
sess
,
DLTensor
*
tensor
,
void
*
nd_handle
)
{
NDArray
::
Container
*
data
=
new
NDArray
::
Container
();
data
->
manager_ctx
=
nd_handle
;
data
->
deleter
=
RemoteNDArrayDeleter
;
RemoteSpace
*
space
=
new
RemoteSpace
();
space
->
sess
=
sess
;
space
->
data
=
tensor
->
data
;
data
->
dl_tensor
.
data
=
space
;
NDArray
ret
(
data
);
// RAII now in effect
data
->
shape_
=
std
::
vector
<
int64_t
>
(
tensor
->
shape
,
tensor
->
shape
+
tensor
->
ndim
);
data
->
dl_tensor
.
shape
=
dmlc
::
BeginPtr
(
data
->
shape_
);
data
->
dl_tensor
.
ndim
=
static_cast
<
int
>
(
data
->
shape_
.
size
());
// setup dtype
data
->
dl_tensor
.
dtype
=
tensor
->
dtype
;
// setup ctx, encode as remote session
data
->
dl_tensor
.
ctx
.
device_id
=
tensor
->
ctx
.
device_id
;
data
->
dl_tensor
.
ctx
.
device_type
=
static_cast
<
DLDeviceType
>
(
static_cast
<
int
>
(
tensor
->
ctx
.
device_type
)
+
kRPCSessMask
*
(
sess
->
table_index
()
+
1
));
// check strides.
CHECK
(
tensor
->
strides
==
nullptr
);
// setup byteoffset
data
->
dl_tensor
.
byte_offset
=
tensor
->
byte_offset
;
return
ret
;
}
private
:
private
:
PackedFunc
fwrap_
;
PackedFunc
fwrap_
;
void
*
handle_
{
nullptr
};
void
*
handle_
{
nullptr
};
...
@@ -126,20 +163,28 @@ class RPCModuleNode final : public ModuleNode {
...
@@ -126,20 +163,28 @@ class RPCModuleNode final : public ModuleNode {
};
};
void
RPCWrappedFunc
::
WrapRemote
(
std
::
shared_ptr
<
RPCSession
>
sess
,
void
RPCWrappedFunc
::
WrapRemote
(
std
::
shared_ptr
<
RPCSession
>
sess
,
void
*
handle
,
TVMArgs
args
,
int
tcode
,
TVMRetValue
*
rv
)
{
TVMRetValue
*
rv
)
{
void
*
handle
=
args
.
values
[
0
].
v_handle
;
int
tcode
=
args
.
type_codes
[
0
];
if
(
handle
==
nullptr
)
return
;
if
(
handle
==
nullptr
)
return
;
if
(
tcode
==
kFuncHandle
)
{
if
(
tcode
==
kFuncHandle
)
{
auto
wf
=
std
::
make_shared
<
RPCWrappedFunc
>
(
handle
,
sess
);
auto
wf
=
std
::
make_shared
<
RPCWrappedFunc
>
(
handle
,
sess
);
*
rv
=
PackedFunc
([
wf
](
TVMArgs
args
,
TVMRetValue
*
rv
)
{
*
rv
=
PackedFunc
([
wf
](
TVMArgs
args
,
TVMRetValue
*
rv
)
{
return
wf
->
operator
()(
args
,
rv
);
return
wf
->
operator
()(
args
,
rv
);
});
});
}
else
{
}
else
if
(
tcode
==
kModuleHandle
)
{
CHECK_EQ
(
tcode
,
kModuleHandle
);
std
::
shared_ptr
<
RPCModuleNode
>
n
=
std
::
shared_ptr
<
RPCModuleNode
>
n
=
std
::
make_shared
<
RPCModuleNode
>
(
handle
,
sess
);
std
::
make_shared
<
RPCModuleNode
>
(
handle
,
sess
);
*
rv
=
Module
(
n
);
*
rv
=
Module
(
n
);
}
else
if
(
tcode
==
kArrayHandle
||
tcode
==
kNDArrayContainer
)
{
CHECK_EQ
(
args
.
size
(),
2
);
DLTensor
*
tensor
=
args
[
0
];
void
*
nd_handle
=
args
[
1
];
*
rv
=
WrapRemoteNDArray
(
sess
,
tensor
,
nd_handle
);
}
else
{
LOG
(
FATAL
)
<<
"Cannot wrap tcode="
<<
tcode
;
}
}
}
}
...
...
src/runtime/rpc/rpc_session.cc
View file @
11dd933f
...
@@ -130,13 +130,16 @@ class RPCSession::EventHandler : public dmlc::Stream {
...
@@ -130,13 +130,16 @@ class RPCSession::EventHandler : public dmlc::Stream {
break
;
break
;
}
}
case
kReturnReceived
:
{
case
kReturnReceived
:
{
CHECK_EQ
(
arg_buf_
->
value
.
size
(),
1U
);
CHECK_GE
(
arg_buf_
->
value
.
size
(),
1U
);
TVMArgValue
argv
=
arg_buf_
->
AsTVMArgs
()[
0
];
TVMArgValue
argv
=
arg_buf_
->
AsTVMArgs
()[
0
];
if
(
argv
.
type_code
()
==
kFuncHandle
||
if
(
argv
.
type_code
()
==
kFuncHandle
||
argv
.
type_code
()
==
kModuleHandle
)
{
argv
.
type_code
()
==
kModuleHandle
||
argv
.
type_code
()
==
kArrayHandle
)
{
CHECK
(
fwrap
!=
nullptr
)
<<
"function/module wrapper not available"
;
CHECK
(
fwrap
!=
nullptr
)
<<
"function/module wrapper not available"
;
fwrap
->
CallPacked
(
arg_buf_
->
AsTVMArgs
(),
rv
);
fwrap
->
CallPacked
(
arg_buf_
->
AsTVMArgs
(),
rv
);
}
else
{
}
else
{
CHECK_EQ
(
arg_buf_
->
value
.
size
(),
1U
);
*
rv
=
argv
;
*
rv
=
argv
;
}
}
arg_buf_
.
reset
();
arg_buf_
.
reset
();
...
@@ -172,15 +175,22 @@ class RPCSession::EventHandler : public dmlc::Stream {
...
@@ -172,15 +175,22 @@ class RPCSession::EventHandler : public dmlc::Stream {
ctx
.
device_type
=
static_cast
<
DLDeviceType
>
(
dev_type
%
kRPCSessMask
);
ctx
.
device_type
=
static_cast
<
DLDeviceType
>
(
dev_type
%
kRPCSessMask
);
return
ctx
;
return
ctx
;
}
}
// send Packed sequence to writer.
// Send Packed sequence to writer.
void
SendPackedSeq
(
const
TVMValue
*
arg_values
,
const
int
*
type_codes
,
int
n
)
{
// return_ndarray is a special flag to handle returning of ndarray
// In this case, we return the shape, context and data of the array,
// as well as a customized PackedFunc that handles deletion of
// the array in the remote.
void
SendPackedSeq
(
const
TVMValue
*
arg_values
,
const
int
*
type_codes
,
int
n
,
bool
return_ndarray
=
false
)
{
this
->
Write
(
n
);
this
->
Write
(
n
);
// only handles .
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
int
tcode
=
type_codes
[
i
];
int
tcode
=
type_codes
[
i
];
if
(
tcode
==
kNDArrayContainer
)
tcode
=
kArrayHandle
;
if
(
tcode
==
kNDArrayContainer
)
tcode
=
kArrayHandle
;
this
->
Write
(
tcode
);
this
->
Write
(
tcode
);
}
}
// Argument packing.
// Argument packing.
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
int
tcode
=
type_codes
[
i
];
int
tcode
=
type_codes
[
i
];
...
@@ -215,9 +225,23 @@ class RPCSession::EventHandler : public dmlc::Stream {
...
@@ -215,9 +225,23 @@ class RPCSession::EventHandler : public dmlc::Stream {
case
kNDArrayContainer
:
case
kNDArrayContainer
:
case
kArrayHandle
:
{
case
kArrayHandle
:
{
DLTensor
*
arr
=
static_cast
<
DLTensor
*>
(
value
.
v_handle
);
DLTensor
*
arr
=
static_cast
<
DLTensor
*>
(
value
.
v_handle
);
TVMContext
ctx
=
StripSessMask
(
arr
->
ctx
);
TVMContext
ctx
;
uint64_t
data
=
reinterpret_cast
<
uint64_t
>
(
uint64_t
data
;
if
(
!
return_ndarray
)
{
// in the client mode
// ctx contains the remote table index
// the space is wrapped by an RemoteSpace
// that holds reference to the session.
ctx
=
StripSessMask
(
arr
->
ctx
);
data
=
reinterpret_cast
<
uint64_t
>
(
static_cast
<
RemoteSpace
*>
(
arr
->
data
)
->
data
);
static_cast
<
RemoteSpace
*>
(
arr
->
data
)
->
data
);
}
else
{
// When we return NDArray, we directly return
// the space and the context
// The client will be further wrapping
ctx
=
arr
->
ctx
;
data
=
reinterpret_cast
<
uint64_t
>
(
arr
->
data
);
}
this
->
Write
(
data
);
this
->
Write
(
data
);
this
->
Write
(
ctx
);
this
->
Write
(
ctx
);
this
->
Write
(
arr
->
ndim
);
this
->
Write
(
arr
->
ndim
);
...
@@ -701,6 +725,21 @@ class RPCSession::EventHandler : public dmlc::Stream {
...
@@ -701,6 +725,21 @@ class RPCSession::EventHandler : public dmlc::Stream {
<<
"Only server can send function and module handle back."
;
<<
"Only server can send function and module handle back."
;
rv
.
MoveToCHost
(
&
ret_value
,
&
ret_tcode
);
rv
.
MoveToCHost
(
&
ret_value
,
&
ret_tcode
);
SendPackedSeq
(
&
ret_value
,
&
ret_tcode
,
1
);
SendPackedSeq
(
&
ret_value
,
&
ret_tcode
,
1
);
}
else
if
(
rv
.
type_code
()
==
kNDArrayContainer
)
{
// always send handle in 64 bit.
CHECK
(
!
client_mode_
)
<<
"Only server can send NDArray back"
;
// We follow a special protocol to return NDArray to client side
// The first pack value is the NDArray handle as DLTensor
// The second pack value is a customized deleter that deletes the NDArray.
TVMValue
ret_value_pack
[
2
];
int
ret_tcode_pack
[
2
];
rv
.
MoveToCHost
(
&
ret_value_pack
[
0
],
&
ret_tcode_pack
[
0
]);
NDArray
::
Container
*
nd
=
static_cast
<
NDArray
::
Container
*>
(
ret_value_pack
[
0
].
v_handle
);
ret_value_pack
[
1
].
v_handle
=
nd
;
ret_tcode_pack
[
1
]
=
kHandle
;
SendPackedSeq
(
ret_value_pack
,
ret_tcode_pack
,
2
,
true
);
}
else
{
}
else
{
ret_value
=
rv
.
value
();
ret_value
=
rv
.
value
();
ret_tcode
=
rv
.
type_code
();
ret_tcode
=
rv
.
type_code
();
...
@@ -1090,6 +1129,11 @@ void RPCModuleGetSource(TVMArgs args, TVMRetValue *rv) {
...
@@ -1090,6 +1129,11 @@ void RPCModuleGetSource(TVMArgs args, TVMRetValue *rv) {
*
rv
=
(
*
static_cast
<
Module
*>
(
mhandle
))
->
GetSource
(
fmt
);
*
rv
=
(
*
static_cast
<
Module
*>
(
mhandle
))
->
GetSource
(
fmt
);
}
}
void
RPCNDArrayFree
(
TVMArgs
args
,
TVMRetValue
*
rv
)
{
void
*
handle
=
args
[
0
];
static_cast
<
NDArray
::
Container
*>
(
handle
)
->
DecRef
();
}
void
RPCGetTimeEvaluator
(
TVMArgs
args
,
TVMRetValue
*
rv
)
{
void
RPCGetTimeEvaluator
(
TVMArgs
args
,
TVMRetValue
*
rv
)
{
PackedFunc
*
pf
=
static_cast
<
PackedFunc
*>
(
args
[
0
].
operator
void
*
());
PackedFunc
*
pf
=
static_cast
<
PackedFunc
*>
(
args
[
0
].
operator
void
*
());
void
*
fhandle
=
new
PackedFunc
(
WrapTimeEvaluator
(
*
pf
,
args
[
1
],
args
[
2
],
args
[
3
]));
void
*
fhandle
=
new
PackedFunc
(
WrapTimeEvaluator
(
*
pf
,
args
[
1
],
args
[
2
],
args
[
3
]));
...
@@ -1138,6 +1182,7 @@ void RPCSession::EventHandler::HandlePackedCall() {
...
@@ -1138,6 +1182,7 @@ void RPCSession::EventHandler::HandlePackedCall() {
case
RPCCode
:
:
kModuleFree
:
CallHandler
(
RPCModuleFree
);
break
;
case
RPCCode
:
:
kModuleFree
:
CallHandler
(
RPCModuleFree
);
break
;
case
RPCCode
:
:
kModuleGetFunc
:
CallHandler
(
RPCModuleGetFunc
);
break
;
case
RPCCode
:
:
kModuleGetFunc
:
CallHandler
(
RPCModuleGetFunc
);
break
;
case
RPCCode
:
:
kModuleGetSource
:
CallHandler
(
RPCModuleGetSource
);
break
;
case
RPCCode
:
:
kModuleGetSource
:
CallHandler
(
RPCModuleGetSource
);
break
;
case
RPCCode
:
:
kNDArrayFree
:
CallHandler
(
RPCNDArrayFree
);
break
;
default
:
LOG
(
FATAL
)
<<
"Unknown event "
<<
static_cast
<
int
>
(
code_
);
default
:
LOG
(
FATAL
)
<<
"Unknown event "
<<
static_cast
<
int
>
(
code_
);
}
}
CHECK_EQ
(
state_
,
kRecvCode
);
CHECK_EQ
(
state_
,
kRecvCode
);
...
...
src/runtime/rpc/rpc_session.h
View file @
11dd933f
...
@@ -48,6 +48,7 @@ enum class RPCCode : int {
...
@@ -48,6 +48,7 @@ enum class RPCCode : int {
kModuleFree
,
kModuleFree
,
kModuleGetFunc
,
kModuleGetFunc
,
kModuleGetSource
,
kModuleGetSource
,
kNDArrayFree
};
};
/*!
/*!
...
...
tests/python/unittest/test_runtime_rpc.py
View file @
11dd933f
...
@@ -175,6 +175,7 @@ def test_rpc_return_func():
...
@@ -175,6 +175,7 @@ def test_rpc_return_func():
@tvm.register_func
(
"rpc.test.remote_func"
)
@tvm.register_func
(
"rpc.test.remote_func"
)
def
addone
(
x
):
def
addone
(
x
):
return
lambda
y
:
x
+
y
return
lambda
y
:
x
+
y
server
=
rpc
.
Server
(
"localhost"
,
key
=
"x1"
)
server
=
rpc
.
Server
(
"localhost"
,
key
=
"x1"
)
client
=
rpc
.
connect
(
server
.
host
,
server
.
port
,
key
=
"x1"
)
client
=
rpc
.
connect
(
server
.
host
,
server
.
port
,
key
=
"x1"
)
f1
=
client
.
get_function
(
"rpc.test.remote_func"
)
f1
=
client
.
get_function
(
"rpc.test.remote_func"
)
...
@@ -182,6 +183,46 @@ def test_rpc_return_func():
...
@@ -182,6 +183,46 @@ def test_rpc_return_func():
assert
fadd
(
12
)
==
22
assert
fadd
(
12
)
==
22
def
test_rpc_return_ndarray
():
# Use closure to check the ref counter correctness
nd
=
tvm
.
nd
.
array
(
np
.
zeros
(
10
)
.
astype
(
"float32"
))
@tvm.register_func
(
"rpc.test.remote_return_nd"
)
def
my_module
(
name
):
if
name
==
"get_arr"
:
return
lambda
:
nd
elif
name
==
"ref_count"
:
return
lambda
:
tvm
.
_api_internal
.
_ndarray_use_count
(
nd
)
elif
name
==
"get_elem"
:
return
lambda
idx
:
nd
.
asnumpy
()[
idx
]
elif
name
==
"get_arr_elem"
:
return
lambda
arr
,
idx
:
arr
.
asnumpy
()[
idx
]
# start server
server
=
rpc
.
Server
(
"localhost"
,
key
=
"x1"
)
client
=
rpc
.
connect
(
server
.
host
,
server
.
port
,
key
=
"x1"
)
m
=
client
.
get_function
(
"rpc.test.remote_return_nd"
)
get_arr
=
m
(
"get_arr"
)
ref_count
=
m
(
"ref_count"
)
get_elem
=
m
(
"get_elem"
)
get_arr_elem
=
m
(
"get_arr_elem"
)
# array test
def
run_arr_test
():
arr
=
get_arr
()
assert
ref_count
()
==
2
arr2
=
get_arr
()
assert
ref_count
()
==
3
assert
arr
.
context
==
client
.
cpu
(
0
)
arr
.
copyfrom
(
np
.
ones
(
10
)
.
astype
(
arr
.
dtype
))
assert
arr2
.
asnumpy
()[
0
]
==
1.0
assert
get_elem
(
0
)
==
1.0
assert
get_arr_elem
(
arr2
,
0
)
==
1.0
assert
ref_count
()
==
1
run_arr_test
()
# check recycle correctness
assert
ref_count
()
==
1
def
test_local_func
():
def
test_local_func
():
@tvm.register_func
(
"rpc.test.remote_func2"
)
@tvm.register_func
(
"rpc.test.remote_func2"
)
def
addone
(
x
):
def
addone
(
x
):
...
@@ -199,9 +240,10 @@ def test_local_func():
...
@@ -199,9 +240,10 @@ def test_local_func():
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
logging
.
basicConfig
(
level
=
logging
.
INFO
)
logging
.
basicConfig
(
level
=
logging
.
INFO
)
test_rpc_return_ndarray
()
test_rpc_return_func
()
test_bigendian_rpc
()
test_bigendian_rpc
()
test_rpc_remote_module
()
test_rpc_remote_module
()
test_rpc_return_func
()
test_rpc_file_exchange
()
test_rpc_file_exchange
()
test_rpc_array
()
test_rpc_array
()
test_rpc_simple
()
test_rpc_simple
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment