Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
42608dda
Commit
42608dda
authored
May 30, 2018
by
tqchen
Committed by
Tianqi Chen
May 30, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[IO] Support cross-endian
parent
0f9dab98
Show whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
263 additions
and
156 deletions
+263
-156
dmlc-core
+1
-1
include/tvm/runtime/serializer.h
+50
-0
nnvm/src/compiler/graph_runtime.cc
+40
-30
python/tvm/contrib/rpc/base.py
+2
-2
python/tvm/contrib/rpc/client.py
+2
-2
python/tvm/contrib/rpc/proxy.py
+14
-14
python/tvm/contrib/rpc/server.py
+11
-11
python/tvm/contrib/rpc/tracker.py
+6
-6
src/runtime/file_util.cc
+1
-0
src/runtime/graph/graph_runtime.cc
+18
-12
src/runtime/rpc/rpc_session.cc
+118
-78
No files found.
dmlc-core
@
9b3f9753
Subproject commit
d3f7fbb53e5b037c0f5bf6bd21871ccc720690cc
Subproject commit
9b3f9753ae81d657743c555e0cacc4e43f0bed2d
include/tvm/runtime/serializer.h
0 → 100644
View file @
42608dda
/*!
* Copyright (c) 2017 by Contributors
* \file tvm/runtime/serializer.h
* \brief Serializer extension to support TVM data types
* Include this file to enable serialization of DLDataType, DLContext
*/
#ifndef TVM_RUNTIME_SERIALIZER_H_
#define TVM_RUNTIME_SERIALIZER_H_
#include <dmlc/io.h>
#include <dmlc/serializer.h>
#include "./c_runtime_api.h"
namespace
dmlc
{
namespace
serializer
{
template
<>
struct
Handler
<
DLDataType
>
{
inline
static
void
Write
(
Stream
*
strm
,
const
DLDataType
&
dtype
)
{
Handler
<
uint8_t
>::
Write
(
strm
,
dtype
.
code
);
Handler
<
uint8_t
>::
Write
(
strm
,
dtype
.
bits
);
Handler
<
uint16_t
>::
Write
(
strm
,
dtype
.
lanes
);
}
inline
static
bool
Read
(
Stream
*
strm
,
DLDataType
*
dtype
)
{
if
(
!
Handler
<
uint8_t
>::
Read
(
strm
,
&
(
dtype
->
code
)))
return
false
;
if
(
!
Handler
<
uint8_t
>::
Read
(
strm
,
&
(
dtype
->
bits
)))
return
false
;
if
(
!
Handler
<
uint16_t
>::
Read
(
strm
,
&
(
dtype
->
lanes
)))
return
false
;
return
true
;
}
};
template
<>
struct
Handler
<
DLContext
>
{
inline
static
void
Write
(
Stream
*
strm
,
const
DLContext
&
ctx
)
{
int32_t
device_type
=
static_cast
<
int32_t
>
(
ctx
.
device_type
);
Handler
<
int32_t
>::
Write
(
strm
,
device_type
);
Handler
<
int32_t
>::
Write
(
strm
,
ctx
.
device_id
);
}
inline
static
bool
Read
(
Stream
*
strm
,
DLContext
*
ctx
)
{
int32_t
device_type
=
0
;
if
(
!
Handler
<
int32_t
>::
Read
(
strm
,
&
(
device_type
)))
return
false
;
ctx
->
device_type
=
static_cast
<
DLDeviceType
>
(
device_type
);
if
(
!
Handler
<
int32_t
>::
Read
(
strm
,
&
(
ctx
->
device_id
)))
return
false
;
return
true
;
}
};
}
// namespace serializer
}
// namespace dmlc
#endif // TVM_RUNTIME_SERIALIZER_H_
nnvm/src/compiler/graph_runtime.cc
View file @
42608dda
...
...
@@ -7,6 +7,7 @@
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/serializer.h>
#include "./graph_runtime.h"
namespace
nnvm
{
...
...
@@ -38,46 +39,53 @@ NNVM_REGISTER_OP(tvm_op)
bool
SaveDLTensor
(
dmlc
::
Stream
*
strm
,
DLTensor
*
tensor
)
{
uint64_t
header
=
kTVMNDArrayMagic
,
reserved
=
0
;
strm
->
Write
(
&
header
,
sizeof
(
header
));
strm
->
Write
(
&
reserved
,
sizeof
(
reserved
));
strm
->
Write
(
&
tensor
->
ctx
,
sizeof
(
tensor
->
ctx
));
strm
->
Write
(
&
tensor
->
ndim
,
sizeof
(
tensor
->
ndim
));
strm
->
Write
(
&
tensor
->
dtype
,
sizeof
(
tensor
->
dtype
));
strm
->
Write
(
header
);
strm
->
Write
(
reserved
);
strm
->
Write
(
tensor
->
ctx
);
strm
->
Write
(
tensor
->
ndim
);
strm
->
Write
(
tensor
->
dtype
);
int
ndim
=
tensor
->
ndim
;
strm
->
Write
(
tensor
->
shape
,
sizeof
(
int64_t
)
*
ndim
);
strm
->
Write
Array
(
tensor
->
shape
,
ndim
);
int
type_
size
=
tensor
->
dtype
.
bits
/
8
;
int64_t
size
=
1
;
int
type_
bytes
=
tensor
->
dtype
.
bits
/
8
;
int64_t
num_elems
=
1
;
for
(
int
i
=
0
;
i
<
ndim
;
++
i
)
{
size
*=
tensor
->
shape
[
i
];
num_elems
*=
tensor
->
shape
[
i
];
}
int64_t
data_byte_size
=
type_size
*
size
;
strm
->
Write
(
&
data_byte_size
,
sizeof
(
data_byte_size
));
int64_t
data_byte_size
=
type_bytes
*
num_elems
;
strm
->
Write
(
data_byte_size
);
// handle endianness of data correctly.
if
(
DMLC_IO_NO_ENDIAN_SWAP
)
{
strm
->
Write
(
tensor
->
data
,
data_byte_size
);
}
else
{
uint8_t
*
dptr
=
reinterpret_cast
<
uint8_t
*>
(
tensor
->
data
);
std
::
vector
<
uint8_t
>
bytes
(
dptr
,
dptr
+
data_byte_size
);
dmlc
::
ByteSwap
(
dmlc
::
BeginPtr
(
bytes
),
type_bytes
,
num_elems
);
strm
->
Write
(
dmlc
::
BeginPtr
(
bytes
),
data_byte_size
);
}
return
true
;
}
DLTensor
*
LoadDLTensor
(
dmlc
::
Stream
*
strm
)
{
uint64_t
header
,
reserved
;
CHECK
(
strm
->
Read
(
&
header
,
sizeof
(
header
)
))
CHECK
(
strm
->
Read
(
&
header
))
<<
"Invalid DLTensor file format"
;
CHECK
(
strm
->
Read
(
&
reserved
,
sizeof
(
reserved
)
))
CHECK
(
strm
->
Read
(
&
reserved
))
<<
"Invalid DLTensor file format"
;
CHECK
(
header
==
kTVMNDArrayMagic
)
<<
"Invalid DLTensor file format"
;
DLTensor
tensor
;
CHECK
(
strm
->
Read
(
&
tensor
.
ctx
,
sizeof
(
tensor
.
ctx
)))
CHECK
(
strm
->
Read
(
&
(
tensor
.
ctx
)))
<<
"Invalid DLTensor file format"
;
CHECK
(
strm
->
Read
(
&
tensor
.
ndim
,
sizeof
(
tensor
.
ndim
)))
CHECK
(
strm
->
Read
(
&
(
tensor
.
ndim
)))
<<
"Invalid DLTensor file format"
;
CHECK
(
strm
->
Read
(
&
tensor
.
dtype
,
sizeof
(
tensor
.
dtype
)))
CHECK
(
strm
->
Read
(
&
(
tensor
.
dtype
)))
<<
"Invalid DLTensor file format"
;
std
::
vector
<
int64_t
>
shape
(
tensor
.
ndim
);
CHECK
(
strm
->
Read
(
&
shape
[
0
],
sizeof
(
int64_t
)
*
tensor
.
ndim
))
if
(
tensor
.
ndim
!=
0
)
{
CHECK
(
strm
->
ReadArray
(
&
shape
[
0
],
tensor
.
ndim
))
<<
"Invalid DLTensor file format"
;
}
DLTensor
*
ret
;
CHECK_EQ
(
TVMArrayAlloc
(
shape
.
data
(),
tensor
.
ndim
,
...
...
@@ -87,18 +95,21 @@ DLTensor* LoadDLTensor(dmlc::Stream* strm) {
static_cast
<
int
>
(
tensor
.
ctx
.
device_type
),
tensor
.
ctx
.
device_id
,
&
ret
),
0
)
<<
TVMGetLastError
();
int64_t
size
=
1
;
int
type_size
=
ret
->
dtype
.
bits
/
8
;
int64_t
num_elems
=
1
;
int
elem_bytes
=
(
ret
->
dtype
.
bits
+
7
)
/
8
;
for
(
int
i
=
0
;
i
<
ret
->
ndim
;
++
i
)
{
size
*=
ret
->
shape
[
i
];
num_elems
*=
ret
->
shape
[
i
];
}
int64_t
data_byte_size
;
CHECK
(
strm
->
Read
(
&
data_byte_size
,
sizeof
(
data_byte_size
)
))
CHECK
(
strm
->
Read
(
&
data_byte_size
))
<<
"Invalid DLTensor file format"
;
CHECK
(
data_byte_size
==
type_size
*
size
)
CHECK
(
data_byte_size
==
num_elems
*
elem_bytes
)
<<
"Invalid DLTensor file format"
;
CHECK
(
strm
->
Read
(
ret
->
data
,
type_size
*
size
))
CHECK
(
strm
->
Read
(
ret
->
data
,
data_byte_
size
))
<<
"Invalid DLTensor file format"
;
if
(
!
DMLC_IO_NO_ENDIAN_SWAP
)
{
dmlc
::
ByteSwap
(
ret
->
data
,
elem_bytes
,
num_elems
);
}
return
ret
;
}
...
...
@@ -118,12 +129,12 @@ TVM_REGISTER_GLOBAL("nnvm.compiler._save_param_dict")
dmlc
::
MemoryStringStream
strm
(
&
bytes
);
dmlc
::
Stream
*
fo
=
&
strm
;
uint64_t
header
=
kTVMNDArrayListMagic
,
reserved
=
0
;
fo
->
Write
(
&
header
,
sizeof
(
header
)
);
fo
->
Write
(
&
reserved
,
sizeof
(
reserved
)
);
fo
->
Write
(
header
);
fo
->
Write
(
reserved
);
fo
->
Write
(
names
);
{
uint64_t
sz
=
static_cast
<
uint64_t
>
(
arrays
.
size
());
fo
->
Write
(
&
sz
,
sizeof
(
sz
)
);
fo
->
Write
(
sz
);
for
(
size_t
i
=
0
;
i
<
sz
;
++
i
)
{
SaveDLTensor
(
fo
,
arrays
[
i
]);
}
...
...
@@ -150,7 +161,6 @@ TVM_REGISTER_GLOBAL("nnvm.compiler._load_param_dict")
<<
"Invalid parameters file format"
;
CHECK
(
strm
->
Read
(
&
reserved
))
<<
"Invalid parameters file format"
;
CHECK
(
strm
->
Read
(
&
names
))
<<
"Invalid parameters file format"
;
uint64_t
sz
;
...
...
python/tvm/contrib/rpc/base.py
View file @
42608dda
...
...
@@ -73,7 +73,7 @@ def sendjson(sock, data):
Python value to be sent.
"""
data
=
json
.
dumps
(
data
)
sock
.
sendall
(
struct
.
pack
(
"
@
i"
,
len
(
data
)))
sock
.
sendall
(
struct
.
pack
(
"
<
i"
,
len
(
data
)))
sock
.
sendall
(
data
.
encode
(
"utf-8"
))
...
...
@@ -90,7 +90,7 @@ def recvjson(sock):
value : object
The value received.
"""
size
=
struct
.
unpack
(
"
@
i"
,
recvall
(
sock
,
4
))[
0
]
size
=
struct
.
unpack
(
"
<
i"
,
recvall
(
sock
,
4
))[
0
]
data
=
json
.
loads
(
py_str
(
recvall
(
sock
,
size
)))
return
data
...
...
python/tvm/contrib/rpc/client.py
View file @
42608dda
...
...
@@ -192,8 +192,8 @@ class TrackerSession(object):
def
_connect
(
self
):
self
.
_sock
=
base
.
connect_with_retry
(
self
.
_addr
)
self
.
_sock
.
sendall
(
struct
.
pack
(
"
@
i"
,
base
.
RPC_TRACKER_MAGIC
))
magic
=
struct
.
unpack
(
"
@
i"
,
base
.
recvall
(
self
.
_sock
,
4
))[
0
]
self
.
_sock
.
sendall
(
struct
.
pack
(
"
<
i"
,
base
.
RPC_TRACKER_MAGIC
))
magic
=
struct
.
unpack
(
"
<
i"
,
base
.
recvall
(
self
.
_sock
,
4
))[
0
]
if
magic
!=
base
.
RPC_TRACKER_MAGIC
:
raise
RuntimeError
(
"
%
s is not RPC Tracker"
%
str
(
self
.
_addr
))
...
...
python/tvm/contrib/rpc/proxy.py
View file @
42608dda
...
...
@@ -58,14 +58,14 @@ class ForwardHandler(object):
def
_init_step
(
self
,
message
):
if
self
.
_magic
is
None
:
assert
len
(
message
)
==
4
self
.
_magic
=
struct
.
unpack
(
'
@
i'
,
message
)[
0
]
self
.
_magic
=
struct
.
unpack
(
'
<
i'
,
message
)[
0
]
if
self
.
_magic
!=
base
.
RPC_MAGIC
:
logging
.
info
(
"Invalid RPC magic from
%
s"
,
self
.
name
())
self
.
close
()
self
.
_init_req_nbytes
=
4
elif
self
.
_rpc_key_length
is
None
:
assert
len
(
message
)
==
4
self
.
_rpc_key_length
=
struct
.
unpack
(
'
@
i'
,
message
)[
0
]
self
.
_rpc_key_length
=
struct
.
unpack
(
'
<
i'
,
message
)[
0
]
self
.
_init_req_nbytes
=
self
.
_rpc_key_length
elif
self
.
rpc_key
is
None
:
assert
len
(
message
)
==
self
.
_rpc_key_length
...
...
@@ -269,12 +269,12 @@ class ProxyServerHandler(object):
lhs
.
forward_proxy
=
rhs
rhs
.
forward_proxy
=
lhs
lhs
.
send_data
(
struct
.
pack
(
'
@
i'
,
base
.
RPC_CODE_SUCCESS
))
lhs
.
send_data
(
struct
.
pack
(
'
@
i'
,
len
(
rhs
.
rpc_key
)))
lhs
.
send_data
(
struct
.
pack
(
'
<
i'
,
base
.
RPC_CODE_SUCCESS
))
lhs
.
send_data
(
struct
.
pack
(
'
<
i'
,
len
(
rhs
.
rpc_key
)))
lhs
.
send_data
(
rhs
.
rpc_key
.
encode
(
"utf-8"
))
rhs
.
send_data
(
struct
.
pack
(
'
@
i'
,
base
.
RPC_CODE_SUCCESS
))
rhs
.
send_data
(
struct
.
pack
(
'
@
i'
,
len
(
lhs
.
rpc_key
)))
rhs
.
send_data
(
struct
.
pack
(
'
<
i'
,
base
.
RPC_CODE_SUCCESS
))
rhs
.
send_data
(
struct
.
pack
(
'
<
i'
,
len
(
lhs
.
rpc_key
)))
rhs
.
send_data
(
lhs
.
rpc_key
.
encode
(
"utf-8"
))
logging
.
info
(
"Pairup connect
%
s and
%
s"
,
lhs
.
name
(),
rhs
.
name
())
...
...
@@ -299,8 +299,8 @@ class ProxyServerHandler(object):
if
self
.
_tracker_conn
is
None
:
self
.
_tracker_conn
=
socket
.
socket
(
socket
.
AF_INET
,
socket
.
SOCK_STREAM
)
self
.
_tracker_conn
.
connect
(
self
.
_tracker_addr
)
self
.
_tracker_conn
.
sendall
(
struct
.
pack
(
"
@
i"
,
base
.
RPC_TRACKER_MAGIC
))
magic
=
struct
.
unpack
(
"
@
i"
,
base
.
recvall
(
self
.
_tracker_conn
,
4
))[
0
]
self
.
_tracker_conn
.
sendall
(
struct
.
pack
(
"
<
i"
,
base
.
RPC_TRACKER_MAGIC
))
magic
=
struct
.
unpack
(
"
<
i"
,
base
.
recvall
(
self
.
_tracker_conn
,
4
))[
0
]
if
magic
!=
base
.
RPC_TRACKER_MAGIC
:
self
.
loop
.
stop
()
raise
RuntimeError
(
"
%
s is not RPC Tracker"
%
str
(
self
.
_tracker_addr
))
...
...
@@ -371,7 +371,7 @@ class ProxyServerHandler(object):
if
handler
.
match_key
in
self
.
_server_pool
:
self
.
_pair_up
(
self
.
_server_pool
.
pop
(
handler
.
match_key
),
handler
)
else
:
handler
.
send_data
(
struct
.
pack
(
'
@
i'
,
base
.
RPC_CODE_MISMATCH
))
handler
.
send_data
(
struct
.
pack
(
'
<
i'
,
base
.
RPC_CODE_MISMATCH
))
handler
.
signal_close
()
def
_handler_ready_proxy_mode
(
self
,
handler
):
...
...
@@ -395,12 +395,12 @@ class ProxyServerHandler(object):
logging
.
info
(
"Timeout client connection
%
s, cannot find match key=
%
s"
,
handler
.
name
(),
key
)
pool_dst
.
pop
(
key
)
handler
.
send_data
(
struct
.
pack
(
'
@
i'
,
base
.
RPC_CODE_MISMATCH
))
handler
.
send_data
(
struct
.
pack
(
'
<
i'
,
base
.
RPC_CODE_MISMATCH
))
handler
.
signal_close
()
self
.
loop
.
call_later
(
timeout
,
cleanup
)
else
:
logging
.
info
(
"Duplicate connection with same key=
%
s"
,
key
)
handler
.
send_data
(
struct
.
pack
(
'
@
i'
,
base
.
RPC_CODE_DUPLICATE
))
handler
.
send_data
(
struct
.
pack
(
'
<
i'
,
base
.
RPC_CODE_DUPLICATE
))
handler
.
signal_close
()
def
handler_ready
(
self
,
handler
):
...
...
@@ -538,13 +538,13 @@ def websocket_proxy_server(url, key=""):
on_message
=
create_on_message
(
conn
)
temp
=
_server_env
(
None
)
# Start connecton
conn
.
write_message
(
struct
.
pack
(
'
@
i'
,
base
.
RPC_MAGIC
),
binary
=
True
)
conn
.
write_message
(
struct
.
pack
(
'
<
i'
,
base
.
RPC_MAGIC
),
binary
=
True
)
key
=
"server:"
+
key
conn
.
write_message
(
struct
.
pack
(
'
@
i'
,
len
(
key
)),
binary
=
True
)
conn
.
write_message
(
struct
.
pack
(
'
<
i'
,
len
(
key
)),
binary
=
True
)
conn
.
write_message
(
key
.
encode
(
"utf-8"
),
binary
=
True
)
msg
=
yield
conn
.
read_message
()
assert
len
(
msg
)
>=
4
magic
=
struct
.
unpack
(
'
@
i'
,
msg
[:
4
])[
0
]
magic
=
struct
.
unpack
(
'
<
i'
,
msg
[:
4
])[
0
]
if
magic
==
base
.
RPC_CODE_DUPLICATE
:
raise
RuntimeError
(
"key:
%
s has already been used in proxy"
%
key
)
elif
magic
==
base
.
RPC_CODE_MISMATCH
:
...
...
python/tvm/contrib/rpc/server.py
View file @
42608dda
...
...
@@ -124,23 +124,23 @@ def _listen_loop(sock, port, rpc_key, tracker_addr, load_library):
unmatch_period_count
=
0
continue
conn
,
addr
=
listen_sock
.
accept
()
magic
=
struct
.
unpack
(
"
@
i"
,
base
.
recvall
(
conn
,
4
))[
0
]
magic
=
struct
.
unpack
(
"
<
i"
,
base
.
recvall
(
conn
,
4
))[
0
]
if
magic
!=
base
.
RPC_MAGIC
:
conn
.
close
()
continue
keylen
=
struct
.
unpack
(
"
@
i"
,
base
.
recvall
(
conn
,
4
))[
0
]
keylen
=
struct
.
unpack
(
"
<
i"
,
base
.
recvall
(
conn
,
4
))[
0
]
key
=
py_str
(
base
.
recvall
(
conn
,
keylen
))
arr
=
key
.
split
()
expect_header
=
"client:"
+
matchkey
server_key
=
"server:"
+
rpc_key
if
arr
[
0
]
!=
expect_header
:
conn
.
sendall
(
struct
.
pack
(
"
@
i"
,
base
.
RPC_CODE_MISMATCH
))
conn
.
sendall
(
struct
.
pack
(
"
<
i"
,
base
.
RPC_CODE_MISMATCH
))
conn
.
close
()
logging
.
info
(
"RPCServer: mismatch key from
%
s"
,
addr
)
continue
else
:
conn
.
sendall
(
struct
.
pack
(
"
@
i"
,
base
.
RPC_CODE_SUCCESS
))
conn
.
sendall
(
struct
.
pack
(
"
@
i"
,
len
(
server_key
)))
conn
.
sendall
(
struct
.
pack
(
"
<
i"
,
base
.
RPC_CODE_SUCCESS
))
conn
.
sendall
(
struct
.
pack
(
"
<
i"
,
len
(
server_key
)))
conn
.
sendall
(
server_key
.
encode
(
"utf-8"
))
return
conn
,
addr
,
_parse_server_opt
(
arr
[
1
:])
...
...
@@ -151,8 +151,8 @@ def _listen_loop(sock, port, rpc_key, tracker_addr, load_library):
# step 1: setup tracker and report to tracker
if
tracker_addr
and
tracker_conn
is
None
:
tracker_conn
=
base
.
connect_with_retry
(
tracker_addr
)
tracker_conn
.
sendall
(
struct
.
pack
(
"
@
i"
,
base
.
RPC_TRACKER_MAGIC
))
magic
=
struct
.
unpack
(
"
@
i"
,
base
.
recvall
(
tracker_conn
,
4
))[
0
]
tracker_conn
.
sendall
(
struct
.
pack
(
"
<
i"
,
base
.
RPC_TRACKER_MAGIC
))
magic
=
struct
.
unpack
(
"
<
i"
,
base
.
recvall
(
tracker_conn
,
4
))[
0
]
if
magic
!=
base
.
RPC_TRACKER_MAGIC
:
raise
RuntimeError
(
"
%
s is not RPC Tracker"
%
str
(
tracker_addr
))
# report status of current queue
...
...
@@ -193,17 +193,17 @@ def _connect_proxy_loop(addr, key, load_library):
try
:
sock
=
socket
.
socket
(
socket
.
AF_INET
,
socket
.
SOCK_STREAM
)
sock
.
connect
(
addr
)
sock
.
sendall
(
struct
.
pack
(
"
@
i"
,
base
.
RPC_MAGIC
))
sock
.
sendall
(
struct
.
pack
(
"
@
i"
,
len
(
key
)))
sock
.
sendall
(
struct
.
pack
(
"
<
i"
,
base
.
RPC_MAGIC
))
sock
.
sendall
(
struct
.
pack
(
"
<
i"
,
len
(
key
)))
sock
.
sendall
(
key
.
encode
(
"utf-8"
))
magic
=
struct
.
unpack
(
"
@
i"
,
base
.
recvall
(
sock
,
4
))[
0
]
magic
=
struct
.
unpack
(
"
<
i"
,
base
.
recvall
(
sock
,
4
))[
0
]
if
magic
==
base
.
RPC_CODE_DUPLICATE
:
raise
RuntimeError
(
"key:
%
s has already been used in proxy"
%
key
)
elif
magic
==
base
.
RPC_CODE_MISMATCH
:
logging
.
info
(
"RPCProxy do not have matching client key
%
s"
,
key
)
elif
magic
!=
base
.
RPC_CODE_SUCCESS
:
raise
RuntimeError
(
"
%
s is not RPC Proxy"
%
str
(
addr
))
keylen
=
struct
.
unpack
(
"
@
i"
,
base
.
recvall
(
sock
,
4
))[
0
]
keylen
=
struct
.
unpack
(
"
<
i"
,
base
.
recvall
(
sock
,
4
))[
0
]
remote_key
=
py_str
(
base
.
recvall
(
sock
,
keylen
))
opts
=
_parse_server_opt
(
remote_key
.
split
()[
1
:])
logging
.
info
(
"RPCProxy connected to
%
s"
,
str
(
addr
))
...
...
python/tvm/contrib/rpc/tracker.py
View file @
42608dda
...
...
@@ -143,11 +143,11 @@ class TCPEventHandler(tornado_util.TCPHandler):
if
len
(
message
)
!=
4
:
logging
.
info
(
"Invalid connection from
%
s"
,
self
.
name
())
self
.
close
()
magic
=
struct
.
unpack
(
'
@
i'
,
message
)[
0
]
magic
=
struct
.
unpack
(
'
<
i'
,
message
)[
0
]
if
magic
!=
RPC_TRACKER_MAGIC
:
logging
.
info
(
"Invalid magic from
%
s"
,
self
.
name
())
self
.
close
()
self
.
write_message
(
struct
.
pack
(
'
@
i'
,
RPC_TRACKER_MAGIC
),
binary
=
True
)
self
.
write_message
(
struct
.
pack
(
'
<
i'
,
RPC_TRACKER_MAGIC
),
binary
=
True
)
self
.
_init_req_nbytes
=
0
def
on_message
(
self
,
message
):
...
...
@@ -168,7 +168,7 @@ class TCPEventHandler(tornado_util.TCPHandler):
while
True
:
if
self
.
_msg_size
==
0
:
if
len
(
self
.
_data
)
>=
4
:
self
.
_msg_size
=
struct
.
unpack
(
'
@
i'
,
self
.
_data
[:
4
])[
0
]
self
.
_msg_size
=
struct
.
unpack
(
'
<
i'
,
self
.
_data
[:
4
])[
0
]
else
:
return
if
self
.
_msg_size
!=
0
and
len
(
self
.
_data
)
>=
self
.
_msg_size
+
4
:
...
...
@@ -184,7 +184,7 @@ class TCPEventHandler(tornado_util.TCPHandler):
"""return value to the output"""
data
=
json
.
dumps
(
data
)
self
.
write_message
(
struct
.
pack
(
'
@
i'
,
len
(
data
)),
binary
=
True
)
struct
.
pack
(
'
<
i'
,
len
(
data
)),
binary
=
True
)
self
.
write_message
(
data
.
encode
(
"utf-8"
),
binary
=
True
)
def
call_handler
(
self
,
args
):
...
...
@@ -355,8 +355,8 @@ class Tracker(object):
def
_stop_tracker
(
self
):
sock
=
socket
.
socket
(
socket
.
AF_INET
,
socket
.
SOCK_STREAM
)
sock
.
connect
((
self
.
host
,
self
.
port
))
sock
.
sendall
(
struct
.
pack
(
"
@
i"
,
base
.
RPC_TRACKER_MAGIC
))
magic
=
struct
.
unpack
(
"
@
i"
,
base
.
recvall
(
sock
,
4
))[
0
]
sock
.
sendall
(
struct
.
pack
(
"
<
i"
,
base
.
RPC_TRACKER_MAGIC
))
magic
=
struct
.
unpack
(
"
<
i"
,
base
.
recvall
(
sock
,
4
))[
0
]
assert
magic
==
base
.
RPC_TRACKER_MAGIC
base
.
sendjson
(
sock
,
[
TrackerCode
.
STOP
,
self
.
stop_key
])
assert
base
.
recvjson
(
sock
)
==
TrackerCode
.
SUCCESS
...
...
src/runtime/file_util.cc
View file @
42608dda
...
...
@@ -4,6 +4,7 @@
*/
#include <dmlc/json.h>
#include <dmlc/logging.h>
#include <tvm/runtime/serializer.h>
#include <fstream>
#include "./file_util.h"
...
...
src/runtime/graph/graph_runtime.cc
View file @
42608dda
...
...
@@ -4,6 +4,7 @@
*/
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/serializer.h>
#include <dmlc/memory_io.h>
#include <dmlc/json.h>
#include <numeric>
...
...
@@ -397,24 +398,25 @@ class GraphRuntime : public ModuleNode {
void
GraphRuntime
::
LoadDLTensor
(
dmlc
::
Stream
*
strm
,
DLTensor
*
dst
)
{
// always use strm->Read to maintain endianness conversion
uint64_t
header
,
reserved
;
CHECK
(
strm
->
Read
(
&
header
,
sizeof
(
header
)
))
CHECK
(
strm
->
Read
(
&
header
))
<<
"Invalid DLTensor file format"
;
CHECK
(
strm
->
Read
(
&
reserved
,
sizeof
(
reserved
)
))
CHECK
(
strm
->
Read
(
&
reserved
))
<<
"Invalid DLTensor file format"
;
CHECK
(
header
==
kTVMNDArrayMagic
)
<<
"Invalid DLTensor file format"
;
DLTensor
tensor
;
CHECK
(
strm
->
Read
(
&
tensor
.
ctx
,
sizeof
(
tensor
.
ctx
)))
CHECK
(
strm
->
Read
(
&
(
tensor
.
ctx
)))
<<
"Invalid DLTensor file format"
;
CHECK
(
strm
->
Read
(
&
tensor
.
ndim
,
sizeof
(
tensor
.
ndim
)))
CHECK
(
strm
->
Read
(
&
(
tensor
.
ndim
)))
<<
"Invalid DLTensor file format"
;
CHECK
(
strm
->
Read
(
&
tensor
.
dtype
,
sizeof
(
tensor
.
dtype
)))
CHECK
(
strm
->
Read
(
&
(
tensor
.
dtype
)))
<<
"Invalid DLTensor file format"
;
std
::
vector
<
int64_t
>
shape
(
tensor
.
ndim
);
if
(
tensor
.
ndim
!=
0
)
{
CHECK
(
strm
->
Read
(
&
shape
[
0
],
sizeof
(
int64_t
)
*
tensor
.
ndim
))
CHECK
(
strm
->
Read
Array
(
&
shape
[
0
],
tensor
.
ndim
))
<<
"Invalid DLTensor file format"
;
}
CHECK_EQ
(
tensor
.
ndim
,
dst
->
ndim
)
<<
"param dimension mismatch"
;
...
...
@@ -425,18 +427,23 @@ void GraphRuntime::LoadDLTensor(dmlc::Stream* strm, DLTensor* dst) {
CHECK_EQ
(
shape
[
i
],
dst
->
shape
[
i
])
<<
"param shape mismatch"
;
}
size_t
bits
=
dst
->
dtype
.
bits
*
dst
->
dtype
.
lanes
;
size_t
size
=
(
bits
+
7
)
/
8
;
size_t
elem_bytes
=
(
bits
+
7
)
/
8
;
size_t
num_elems
=
1
;
for
(
int
i
=
0
;
i
<
dst
->
ndim
;
++
i
)
{
size
*=
dst
->
shape
[
i
];
num_elems
*=
dst
->
shape
[
i
];
}
uint64_t
data_byte_size
;
CHECK
(
strm
->
Read
(
&
data_byte_size
,
sizeof
(
data_byte_size
)
))
CHECK
(
strm
->
Read
(
&
data_byte_size
))
<<
"Invalid DLTensor file format"
;
CHECK
(
data_byte_size
==
size
)
CHECK
_EQ
(
data_byte_size
,
elem_bytes
*
num_elems
)
<<
"Invalid DLTensor file format"
;
std
::
vector
<
uint8_t
>
bytes
(
data_byte_size
+
1
);
CHECK
(
strm
->
Read
(
&
bytes
[
0
],
data_byte_size
))
<<
"Invalid DLTensor file format"
;
// explicitly swap endian when necessary.
if
(
!
DMLC_IO_NO_ENDIAN_SWAP
)
{
dmlc
::
ByteSwap
(
&
bytes
[
0
],
elem_bytes
,
num_elems
);
}
TVM_CCALL
(
TVMArrayCopyFromBytes
(
dst
,
&
bytes
[
0
],
data_byte_size
));
}
...
...
@@ -453,9 +460,8 @@ void GraphRuntime::LoadParams(dmlc::Stream* strm) {
CHECK
(
strm
->
Read
(
&
names
))
<<
"Invalid parameters file format"
;
uint64_t
sz
;
strm
->
Read
(
&
sz
,
sizeof
(
sz
)
);
strm
->
Read
(
&
sz
);
size_t
size
=
static_cast
<
size_t
>
(
sz
);
CHECK
(
size
==
names
.
size
())
<<
"Invalid parameters file format"
;
for
(
size_t
i
=
0
;
i
<
size
;
++
i
)
{
...
...
src/runtime/rpc/rpc_session.cc
View file @
42608dda
...
...
@@ -6,6 +6,7 @@
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/serializer.h>
#include <memory>
#include <array>
#include <string>
...
...
@@ -44,7 +45,7 @@ struct RPCArgBuffer {
};
// Event handler for RPC events.
class
RPCSession
::
EventHandler
{
class
RPCSession
::
EventHandler
:
public
dmlc
::
Stream
{
public
:
EventHandler
(
common
::
RingBuffer
*
reader
,
common
::
RingBuffer
*
writer
,
...
...
@@ -71,6 +72,15 @@ class RPCSession::EventHandler {
return
0
;
}
}
// Request number of bytes from reader.
void
RequestBytes
(
size_t
nbytes
)
{
pending_request_bytes_
+=
nbytes
;
reader_
->
Reserve
(
pending_request_bytes_
);
}
// Whether we are ready to handle next request.
bool
Ready
()
{
return
reader_
->
bytes_available
()
>=
pending_request_bytes_
;
}
bool
CanCleanShutdown
()
const
{
return
state_
==
kRecvCode
;
}
...
...
@@ -86,12 +96,12 @@ class RPCSession::EventHandler {
case
kInitHeader
:
HandleInitHeader
();
break
;
case
kRecvCode
:
HandleRecvCode
();
break
;
case
kRecvCallHandle
:
{
this
->
Read
(
&
call_handle_
,
sizeof
(
call_handle_
));
CHECK
(
this
->
Read
(
&
call_handle_
));
this
->
SwitchToState
(
kRecvPackedSeqNumArgs
);
break
;
}
case
kRecvPackedSeqNumArgs
:
{
this
->
Read
(
&
num_packed_args_
,
sizeof
(
num_packed_args_
));
CHECK
(
this
->
Read
(
&
num_packed_args_
));
arg_buf_
.
reset
(
new
RPCArgBuffer
());
arg_buf_
->
value
.
resize
(
num_packed_args_
);
arg_buf_
->
tcode
.
resize
(
num_packed_args_
);
...
...
@@ -100,7 +110,7 @@ class RPCSession::EventHandler {
}
case
kRecvPackedSeqTypeCode
:
{
if
(
num_packed_args_
!=
0
)
{
this
->
Read
(
arg_buf_
->
tcode
.
data
(),
sizeof
(
int
)
*
num_packed_args_
);
this
->
Read
Array
(
arg_buf_
->
tcode
.
data
(),
num_packed_args_
);
}
arg_index_
=
0
;
arg_recv_stage_
=
0
;
...
...
@@ -164,8 +174,8 @@ class RPCSession::EventHandler {
}
// send Packed sequence to writer.
void
SendPackedSeq
(
const
TVMValue
*
arg_values
,
const
int
*
type_codes
,
int
n
)
{
writer_
->
Write
(
&
n
,
sizeof
(
n
)
);
writer_
->
Write
(
type_codes
,
sizeof
(
int
)
*
n
);
this
->
Write
(
n
);
this
->
WriteArray
(
type_codes
,
n
);
// Argument packing.
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
int
tcode
=
type_codes
[
i
];
...
...
@@ -173,14 +183,20 @@ class RPCSession::EventHandler {
switch
(
tcode
)
{
case
kDLInt
:
case
kDLUInt
:
case
kDLFloat
:
case
kDLFloat
:
{
this
->
Write
<
int64_t
>
(
value
.
v_int64
);
break
;
}
case
kTVMType
:
{
writer_
->
Write
(
&
value
,
sizeof
(
TVMValue
));
this
->
Write
(
value
.
v_type
);
// padding
int32_t
padding
=
0
;
this
->
Write
<
int32_t
>
(
padding
);
break
;
}
case
kTVMContext
:
{
value
.
v_ctx
=
StripSessMask
(
value
.
v_ctx
);
writer_
->
Write
(
&
value
,
sizeof
(
TVMValue
)
);
this
->
Write
(
value
.
v_ctx
);
break
;
}
case
kFuncHandle
:
...
...
@@ -188,7 +204,7 @@ class RPCSession::EventHandler {
case
kHandle
:
{
// always send handle in 64 bit.
uint64_t
handle
=
reinterpret_cast
<
uint64_t
>
(
value
.
v_handle
);
writer_
->
Write
(
&
handle
,
sizeof
(
uint64_t
)
);
this
->
Write
(
handle
);
break
;
}
case
kArrayHandle
:
{
...
...
@@ -196,11 +212,11 @@ class RPCSession::EventHandler {
TVMContext
ctx
=
StripSessMask
(
arr
->
ctx
);
uint64_t
data
=
reinterpret_cast
<
uint64_t
>
(
static_cast
<
RemoteSpace
*>
(
arr
->
data
)
->
data
);
writer_
->
Write
(
&
data
,
sizeof
(
uint64_t
)
);
writer_
->
Write
(
&
ctx
,
sizeof
(
ctx
)
);
writer_
->
Write
(
&
(
arr
->
ndim
),
sizeof
(
int
)
);
writer_
->
Write
(
&
(
arr
->
dtype
),
sizeof
(
DLDataType
)
);
writer_
->
Write
(
arr
->
shape
,
sizeof
(
int64_t
)
*
arr
->
ndim
);
this
->
Write
(
data
);
this
->
Write
(
ctx
);
this
->
Write
(
arr
->
ndim
);
this
->
Write
(
arr
->
dtype
);
this
->
WriteArray
(
arr
->
shape
,
arr
->
ndim
);
CHECK
(
arr
->
strides
==
nullptr
)
<<
"Donot support strided remote array"
;
CHECK_EQ
(
arr
->
byte_offset
,
0
)
...
...
@@ -211,15 +227,15 @@ class RPCSession::EventHandler {
case
kStr
:
{
const
char
*
s
=
value
.
v_str
;
uint64_t
len
=
strlen
(
s
);
writer_
->
Write
(
&
len
,
sizeof
(
len
)
);
writer_
->
Write
(
s
,
sizeof
(
char
)
*
len
);
this
->
Write
(
len
);
this
->
WriteArray
(
s
,
len
);
break
;
}
case
kBytes
:
{
TVMByteArray
*
bytes
=
static_cast
<
TVMByteArray
*>
(
arg_values
[
i
].
v_handle
);
uint64_t
len
=
bytes
->
size
;
writer_
->
Write
(
&
len
,
sizeof
(
len
)
);
writer_
->
Write
(
bytes
->
data
,
sizeof
(
char
)
*
len
);
this
->
Write
(
len
);
this
->
WriteArray
(
bytes
->
data
,
len
);
break
;
}
default
:
{
...
...
@@ -230,6 +246,23 @@ class RPCSession::EventHandler {
}
}
// Endian aware IO handling
using
Stream
::
Read
;
using
Stream
::
Write
;
using
Stream
::
ReadArray
;
using
Stream
::
WriteArray
;
inline
bool
Read
(
RPCCode
*
code
)
{
int
cdata
;
if
(
!
this
->
Read
(
&
cdata
))
return
false
;
*
code
=
static_cast
<
RPCCode
>
(
cdata
);
return
true
;
}
inline
void
Write
(
RPCCode
code
)
{
int
cdata
=
static_cast
<
int
>
(
code
);
this
->
Write
(
cdata
);
}
protected
:
enum
State
{
kInitHeader
,
...
...
@@ -370,10 +403,22 @@ class RPCSession::EventHandler {
switch
(
tcode
)
{
case
kDLInt
:
case
kDLUInt
:
case
kDLFloat
:
case
kTVMType
:
case
kDLFloat
:
{
this
->
Read
<
int64_t
>
(
&
(
value
.
v_int64
));
++
arg_index_
;
this
->
SwitchToState
(
kRecvPackedSeqArg
);
break
;
}
case
kTVMType
:
{
this
->
Read
(
&
(
value
.
v_type
));
int32_t
padding
=
0
;
this
->
Read
<
int32_t
>
(
&
padding
);
++
arg_index_
;
this
->
SwitchToState
(
kRecvPackedSeqArg
);
break
;
}
case
kTVMContext
:
{
this
->
Read
(
&
value
,
sizeof
(
TVMValue
));
this
->
Read
(
&
(
value
.
v_ctx
));
++
arg_index_
;
this
->
SwitchToState
(
kRecvPackedSeqArg
);
break
;
...
...
@@ -383,7 +428,7 @@ class RPCSession::EventHandler {
case
kHandle
:
{
// always send handle in 64 bit.
uint64_t
handle
;
this
->
Read
(
&
handle
,
sizeof
(
handle
)
);
this
->
Read
(
&
handle
);
value
.
v_handle
=
reinterpret_cast
<
void
*>
(
handle
);
++
arg_index_
;
this
->
SwitchToState
(
kRecvPackedSeqArg
);
...
...
@@ -398,7 +443,7 @@ class RPCSession::EventHandler {
case
kStr
:
case
kBytes
:
{
uint64_t
len
;
this
->
Read
(
&
len
,
sizeof
(
len
)
);
this
->
Read
(
&
len
);
temp_bytes_
.
reset
(
new
RPCByteArrayBuffer
());
temp_bytes_
->
data
.
resize
(
len
);
arg_recv_stage_
=
1
;
...
...
@@ -409,12 +454,12 @@ class RPCSession::EventHandler {
case
kArrayHandle
:
{
temp_array_
.
reset
(
new
RPCDataArrayBuffer
());
uint64_t
handle
;
this
->
Read
(
&
handle
,
sizeof
(
handle
)
);
this
->
Read
(
&
handle
);
DLTensor
&
tensor
=
temp_array_
->
tensor
;
tensor
.
data
=
reinterpret_cast
<
void
*>
(
handle
);
this
->
Read
(
&
(
tensor
.
ctx
)
,
sizeof
(
TVMContext
)
);
this
->
Read
(
&
(
tensor
.
ndim
)
,
sizeof
(
int
)
);
this
->
Read
(
&
(
tensor
.
dtype
)
,
sizeof
(
DLDataType
)
);
this
->
Read
(
&
(
tensor
.
ctx
));
this
->
Read
(
&
(
tensor
.
ndim
));
this
->
Read
(
&
(
tensor
.
dtype
));
temp_array_
->
shape
.
resize
(
tensor
.
ndim
);
tensor
.
shape
=
temp_array_
->
shape
.
data
();
arg_recv_stage_
=
1
;
...
...
@@ -432,7 +477,7 @@ class RPCSession::EventHandler {
CHECK_EQ
(
arg_recv_stage_
,
1
);
if
(
tcode
==
kStr
||
tcode
==
kBytes
)
{
if
(
temp_bytes_
->
data
.
size
()
!=
0
)
{
this
->
Read
(
&
(
temp_bytes_
->
data
[
0
]),
temp_bytes_
->
data
.
size
());
this
->
Read
Array
(
&
(
temp_bytes_
->
data
[
0
]),
temp_bytes_
->
data
.
size
());
}
if
(
tcode
==
kStr
)
{
value
.
v_str
=
temp_bytes_
->
data
.
c_str
();
...
...
@@ -445,7 +490,7 @@ class RPCSession::EventHandler {
}
else
{
CHECK_EQ
(
tcode
,
kArrayHandle
);
DLTensor
&
tensor
=
temp_array_
->
tensor
;
this
->
Read
(
tensor
.
shape
,
tensor
.
ndim
*
sizeof
(
int64_t
)
);
this
->
Read
Array
(
tensor
.
shape
,
tensor
.
ndim
);
value
.
v_handle
=
&
tensor
;
arg_buf_
->
temp_array
.
emplace_back
(
std
::
move
(
temp_array_
));
}
...
...
@@ -458,20 +503,20 @@ class RPCSession::EventHandler {
void
HandleInitHeader
()
{
if
(
init_header_step_
==
0
)
{
int32_t
len
;
this
->
Read
(
&
len
,
sizeof
(
len
)
);
this
->
Read
(
&
len
);
remote_key_
->
resize
(
len
);
init_header_step_
=
1
;
this
->
RequestBytes
(
len
);
return
;
}
else
{
CHECK_EQ
(
init_header_step_
,
1
);
this
->
Read
(
dmlc
::
BeginPtr
(
*
remote_key_
),
remote_key_
->
length
());
this
->
Read
Array
(
dmlc
::
BeginPtr
(
*
remote_key_
),
remote_key_
->
length
());
this
->
SwitchToState
(
kRecvCode
);
}
}
// Handler for read code.
void
HandleRecvCode
()
{
this
->
Read
(
&
code_
,
sizeof
(
code_
)
);
this
->
Read
(
&
code_
);
if
(
code_
>
RPCCode
::
kSystemFuncStart
)
{
SwitchToState
(
kRecvPackedSeqNumArgs
);
return
;
...
...
@@ -511,14 +556,14 @@ class RPCSession::EventHandler {
void
HandleCopyFromRemote
()
{
uint64_t
handle
,
offset
,
size
;
TVMContext
ctx
;
this
->
Read
(
&
handle
,
sizeof
(
handle
)
);
this
->
Read
(
&
offset
,
sizeof
(
offset
)
);
this
->
Read
(
&
size
,
sizeof
(
size
)
);
this
->
Read
(
&
ctx
,
sizeof
(
ctx
)
);
this
->
Read
(
&
handle
);
this
->
Read
(
&
offset
);
this
->
Read
(
&
size
);
this
->
Read
(
&
ctx
);
if
(
ctx
.
device_type
==
kDLCPU
)
{
RPCCode
code
=
RPCCode
::
kCopyAck
;
writer_
->
Write
(
&
code
,
sizeof
(
code
)
);
writer_
->
Write
(
reinterpret_cast
<
char
*>
(
handle
)
+
offset
,
size
);
this
->
Write
(
code
);
this
->
WriteArray
(
reinterpret_cast
<
char
*>
(
handle
)
+
offset
,
size
);
}
else
{
temp_data_
.
resize
(
size
+
1
);
try
{
...
...
@@ -530,11 +575,11 @@ class RPCSession::EventHandler {
dmlc
::
BeginPtr
(
temp_data_
),
0
,
size
,
ctx
,
cpu_ctx
,
nullptr
);
RPCCode
code
=
RPCCode
::
kCopyAck
;
writer_
->
Write
(
&
code
,
sizeof
(
code
)
);
writer_
->
Write
(
&
temp_data_
[
0
],
size
);
this
->
Write
(
code
);
this
->
WriteArray
(
&
temp_data_
[
0
],
size
);
}
catch
(
const
std
::
runtime_error
&
e
)
{
RPCCode
code
=
RPCCode
::
kException
;
writer_
->
Write
(
&
code
,
sizeof
(
code
)
);
this
->
Write
(
code
);
TVMValue
ret_value
;
ret_value
.
v_str
=
e
.
what
();
int
ret_tcode
=
kStr
;
...
...
@@ -548,10 +593,10 @@ class RPCSession::EventHandler {
// use static variable to persist state.
// This only works if next stage is immediately after this.
if
(
arg_recv_stage_
==
0
)
{
this
->
Read
(
&
copy_handle_
,
sizeof
(
uint64_t
));
this
->
Read
(
&
copy_offset_
,
sizeof
(
uint64_t
));
this
->
Read
(
&
copy_size_
,
sizeof
(
uint64_t
));
this
->
Read
(
&
copy_ctx_
,
sizeof
(
TVMContext
));
CHECK
(
this
->
Read
(
&
copy_handle_
));
CHECK
(
this
->
Read
(
&
copy_offset_
));
CHECK
(
this
->
Read
(
&
copy_size_
));
CHECK
(
this
->
Read
(
&
copy_ctx_
));
arg_recv_stage_
=
1
;
CHECK_EQ
(
pending_request_bytes_
,
0U
);
this
->
RequestBytes
(
copy_size_
);
...
...
@@ -563,11 +608,11 @@ class RPCSession::EventHandler {
RPCCode
code
=
RPCCode
::
kReturn
;
std
::
string
errmsg
;
if
(
copy_ctx_
.
device_type
==
kDLCPU
)
{
this
->
Read
(
this
->
Read
Array
(
reinterpret_cast
<
char
*>
(
copy_handle_
)
+
copy_offset_
,
copy_size_
);
}
else
{
temp_data_
.
resize
(
copy_size_
+
1
);
this
->
Read
(
&
temp_data_
[
0
],
copy_size_
);
this
->
Read
Array
(
&
temp_data_
[
0
],
copy_size_
);
try
{
TVMContext
cpu_ctx
;
cpu_ctx
.
device_type
=
kDLCPU
;
...
...
@@ -583,7 +628,7 @@ class RPCSession::EventHandler {
ret_tcode
=
kStr
;
}
}
writer_
->
Write
(
&
code
,
sizeof
(
code
)
);
this
->
Write
(
code
);
SendPackedSeq
(
&
ret_value
,
&
ret_tcode
,
1
);
arg_recv_stage_
=
0
;
this
->
SwitchToState
(
kRecvCode
);
...
...
@@ -603,7 +648,7 @@ class RPCSession::EventHandler {
std
::
unique_ptr
<
RPCArgBuffer
>
args
=
std
::
move
(
arg_buf_
);
f
(
args
->
AsTVMArgs
(),
&
rv
);
RPCCode
code
=
RPCCode
::
kReturn
;
writer_
->
Write
(
&
code
,
sizeof
(
code
)
);
this
->
Write
(
code
);
if
(
rv
.
type_code
()
==
kStr
)
{
ret_value
.
v_str
=
rv
.
ptr
<
std
::
string
>
()
->
c_str
();
ret_tcode
=
kStr
;
...
...
@@ -630,7 +675,7 @@ class RPCSession::EventHandler {
}
}
catch
(
const
std
::
runtime_error
&
e
)
{
RPCCode
code
=
RPCCode
::
kException
;
writer_
->
Write
(
&
code
,
sizeof
(
code
)
);
this
->
Write
(
code
);
ret_value
.
v_str
=
e
.
what
();
ret_tcode
=
kStr
;
SendPackedSeq
(
&
ret_value
,
&
ret_tcode
,
1
);
...
...
@@ -640,19 +685,14 @@ class RPCSession::EventHandler {
private
:
// Utility functions
// Internal read function, update pending_request_bytes_
void
Read
(
void
*
data
,
size_t
size
)
{
size_t
Read
(
void
*
data
,
size_t
size
)
final
{
CHECK_LE
(
size
,
pending_request_bytes_
);
reader_
->
Read
(
data
,
size
);
pending_request_bytes_
-=
size
;
return
size
;
}
// Request number of bytes from reader.
void
RequestBytes
(
size_t
nbytes
)
{
pending_request_bytes_
+=
nbytes
;
reader_
->
Reserve
(
pending_request_bytes_
);
}
// Whether we are ready to handle next request.
bool
Ready
()
{
return
reader_
->
bytes_available
()
>=
pending_request_bytes_
;
void
Write
(
const
void
*
data
,
size_t
size
)
final
{
writer_
->
Write
(
data
,
size
);
}
// Number of pending bytes requests
size_t
pending_request_bytes_
;
...
...
@@ -766,7 +806,7 @@ RPCSession::~RPCSession() {
void
RPCSession
::
Shutdown
()
{
if
(
channel_
!=
nullptr
)
{
RPCCode
code
=
RPCCode
::
kShutdown
;
writer_
.
Write
(
&
code
,
sizeof
(
code
)
);
handler_
->
Write
(
code
);
// flush all writing buffer to output channel.
try
{
while
(
writer_
.
bytes_available
()
!=
0
)
{
...
...
@@ -788,7 +828,6 @@ void RPCSession::ServerLoop() {
}
TVMRetValue
rv
;
CHECK
(
HandleUntilReturnEvent
(
&
rv
,
false
,
nullptr
)
==
RPCCode
::
kShutdown
);
LOG
(
INFO
)
<<
"Shutdown..."
;
if
(
const
auto
*
f
=
Registry
::
Get
(
"tvm.contrib.rpc.server.shutdown"
))
{
(
*
f
)();
}
...
...
@@ -821,9 +860,9 @@ void RPCSession::CallFunc(void* h,
const
PackedFunc
*
fwrap
)
{
std
::
lock_guard
<
std
::
recursive_mutex
>
lock
(
mutex_
);
RPCCode
code
=
RPCCode
::
kCallFunc
;
writer_
.
Write
(
&
code
,
sizeof
(
code
)
);
handler_
->
Write
(
code
);
uint64_t
handle
=
reinterpret_cast
<
uint64_t
>
(
h
);
writer_
.
Write
(
&
handle
,
sizeof
(
handle
)
);
handler_
->
Write
(
handle
);
handler_
->
SendPackedSeq
(
args
.
values
,
args
.
type_codes
,
args
.
num_args
);
code
=
HandleUntilReturnEvent
(
rv
,
true
,
fwrap
);
CHECK
(
code
==
RPCCode
::
kReturn
)
<<
"code="
<<
static_cast
<
int
>
(
code
);
...
...
@@ -838,15 +877,15 @@ void RPCSession::CopyToRemote(void* from,
std
::
lock_guard
<
std
::
recursive_mutex
>
lock
(
mutex_
);
ctx_to
=
handler_
->
StripSessMask
(
ctx_to
);
RPCCode
code
=
RPCCode
::
kCopyToRemote
;
writer_
.
Write
(
&
code
,
sizeof
(
code
)
);
handler_
->
Write
(
code
);
uint64_t
handle
=
reinterpret_cast
<
uint64_t
>
(
to
);
writer_
.
Write
(
&
handle
,
sizeof
(
handle
)
);
handler_
->
Write
(
handle
);
uint64_t
offset
=
static_cast
<
uint64_t
>
(
to_offset
);
writer_
.
Write
(
&
offset
,
sizeof
(
offset
)
);
handler_
->
Write
(
offset
);
uint64_t
size
=
static_cast
<
uint64_t
>
(
data_size
);
writer_
.
Write
(
&
size
,
sizeof
(
size
)
);
writer_
.
Write
(
&
ctx_to
,
sizeof
(
ctx_to
)
);
writer_
.
Write
(
reinterpret_cast
<
char
*>
(
from
)
+
from_offset
,
data_size
);
handler_
->
Write
(
size
);
handler_
->
Write
(
ctx_to
);
handler_
->
WriteArray
(
reinterpret_cast
<
char
*>
(
from
)
+
from_offset
,
data_size
);
TVMRetValue
rv
;
CHECK
(
HandleUntilReturnEvent
(
&
rv
,
true
,
nullptr
)
==
RPCCode
::
kReturn
);
}
...
...
@@ -860,26 +899,27 @@ void RPCSession::CopyFromRemote(void* from,
std
::
lock_guard
<
std
::
recursive_mutex
>
lock
(
mutex_
);
ctx_from
=
handler_
->
StripSessMask
(
ctx_from
);
RPCCode
code
=
RPCCode
::
kCopyFromRemote
;
writer_
.
Write
(
&
code
,
sizeof
(
code
)
);
handler_
->
Write
(
code
);
uint64_t
handle
=
reinterpret_cast
<
uint64_t
>
(
from
);
writer_
.
Write
(
&
handle
,
sizeof
(
handle
)
);
handler_
->
Write
(
handle
);
uint64_t
offset
=
static_cast
<
uint64_t
>
(
from_offset
);
writer_
.
Write
(
&
offset
,
sizeof
(
offset
)
);
handler_
->
Write
(
offset
);
uint64_t
size
=
static_cast
<
uint64_t
>
(
data_size
);
writer_
.
Write
(
&
size
,
sizeof
(
size
)
);
writer_
.
Write
(
&
ctx_from
,
sizeof
(
ctx_from
)
);
handler_
->
Write
(
size
);
handler_
->
Write
(
ctx_from
);
TVMRetValue
rv
;
CHECK
(
HandleUntilReturnEvent
(
&
rv
,
true
,
nullptr
)
==
RPCCode
::
kCopyAck
);
reader_
.
Reserve
(
data_size
);
while
(
reader_
.
bytes_available
()
<
data_size
)
{
size_t
bytes_needed
=
data_size
-
reader_
.
bytes_available
();
handler_
->
RequestBytes
(
data_size
);
while
(
!
handler_
->
Ready
())
{
size_t
bytes_needed
=
handler_
->
BytesNeeded
();
reader_
.
WriteWithCallback
([
this
](
void
*
data
,
size_t
size
)
{
size_t
n
=
channel_
->
Recv
(
data
,
size
);
CHECK_NE
(
n
,
0U
)
<<
"Channel closes before we get neded bytes"
;
return
n
;
},
bytes_needed
);
}
reader_
.
Read
(
reinterpret_cast
<
char
*>
(
to
)
+
to_offset
,
data_size
);
handler_
->
ReadArray
(
reinterpret_cast
<
char
*>
(
to
)
+
to_offset
,
data_size
);
handler_
->
FinishCopyAck
();
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment