Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
42608dda
Commit
42608dda
authored
May 30, 2018
by
tqchen
Committed by
Tianqi Chen
May 30, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[IO] Support cross-endian
parent
0f9dab98
Show whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
263 additions
and
156 deletions
+263
-156
dmlc-core
+1
-1
include/tvm/runtime/serializer.h
+50
-0
nnvm/src/compiler/graph_runtime.cc
+40
-30
python/tvm/contrib/rpc/base.py
+2
-2
python/tvm/contrib/rpc/client.py
+2
-2
python/tvm/contrib/rpc/proxy.py
+14
-14
python/tvm/contrib/rpc/server.py
+11
-11
python/tvm/contrib/rpc/tracker.py
+6
-6
src/runtime/file_util.cc
+1
-0
src/runtime/graph/graph_runtime.cc
+18
-12
src/runtime/rpc/rpc_session.cc
+118
-78
No files found.
dmlc-core
@
9b3f9753
Subproject commit
d3f7fbb53e5b037c0f5bf6bd21871ccc720690cc
Subproject commit
9b3f9753ae81d657743c555e0cacc4e43f0bed2d
include/tvm/runtime/serializer.h
0 → 100644
View file @
42608dda
/*!
* Copyright (c) 2017 by Contributors
* \file tvm/runtime/serializer.h
* \brief Serializer extension to support TVM data types
* Include this file to enable serialization of DLDataType, DLContext
*/
#ifndef TVM_RUNTIME_SERIALIZER_H_
#define TVM_RUNTIME_SERIALIZER_H_
#include <dmlc/io.h>
#include <dmlc/serializer.h>
#include "./c_runtime_api.h"
namespace
dmlc
{
namespace
serializer
{
template
<>
struct
Handler
<
DLDataType
>
{
inline
static
void
Write
(
Stream
*
strm
,
const
DLDataType
&
dtype
)
{
Handler
<
uint8_t
>::
Write
(
strm
,
dtype
.
code
);
Handler
<
uint8_t
>::
Write
(
strm
,
dtype
.
bits
);
Handler
<
uint16_t
>::
Write
(
strm
,
dtype
.
lanes
);
}
inline
static
bool
Read
(
Stream
*
strm
,
DLDataType
*
dtype
)
{
if
(
!
Handler
<
uint8_t
>::
Read
(
strm
,
&
(
dtype
->
code
)))
return
false
;
if
(
!
Handler
<
uint8_t
>::
Read
(
strm
,
&
(
dtype
->
bits
)))
return
false
;
if
(
!
Handler
<
uint16_t
>::
Read
(
strm
,
&
(
dtype
->
lanes
)))
return
false
;
return
true
;
}
};
template
<>
struct
Handler
<
DLContext
>
{
inline
static
void
Write
(
Stream
*
strm
,
const
DLContext
&
ctx
)
{
int32_t
device_type
=
static_cast
<
int32_t
>
(
ctx
.
device_type
);
Handler
<
int32_t
>::
Write
(
strm
,
device_type
);
Handler
<
int32_t
>::
Write
(
strm
,
ctx
.
device_id
);
}
inline
static
bool
Read
(
Stream
*
strm
,
DLContext
*
ctx
)
{
int32_t
device_type
=
0
;
if
(
!
Handler
<
int32_t
>::
Read
(
strm
,
&
(
device_type
)))
return
false
;
ctx
->
device_type
=
static_cast
<
DLDeviceType
>
(
device_type
);
if
(
!
Handler
<
int32_t
>::
Read
(
strm
,
&
(
ctx
->
device_id
)))
return
false
;
return
true
;
}
};
}
// namespace serializer
}
// namespace dmlc
#endif // TVM_RUNTIME_SERIALIZER_H_
nnvm/src/compiler/graph_runtime.cc
View file @
42608dda
...
@@ -7,6 +7,7 @@
...
@@ -7,6 +7,7 @@
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/serializer.h>
#include "./graph_runtime.h"
#include "./graph_runtime.h"
namespace
nnvm
{
namespace
nnvm
{
...
@@ -38,46 +39,53 @@ NNVM_REGISTER_OP(tvm_op)
...
@@ -38,46 +39,53 @@ NNVM_REGISTER_OP(tvm_op)
bool
SaveDLTensor
(
dmlc
::
Stream
*
strm
,
DLTensor
*
tensor
)
{
bool
SaveDLTensor
(
dmlc
::
Stream
*
strm
,
DLTensor
*
tensor
)
{
uint64_t
header
=
kTVMNDArrayMagic
,
reserved
=
0
;
uint64_t
header
=
kTVMNDArrayMagic
,
reserved
=
0
;
strm
->
Write
(
&
header
,
sizeof
(
header
));
strm
->
Write
(
header
);
strm
->
Write
(
&
reserved
,
sizeof
(
reserved
));
strm
->
Write
(
reserved
);
strm
->
Write
(
tensor
->
ctx
);
strm
->
Write
(
&
tensor
->
ctx
,
sizeof
(
tensor
->
ctx
));
strm
->
Write
(
tensor
->
ndim
);
strm
->
Write
(
&
tensor
->
ndim
,
sizeof
(
tensor
->
ndim
));
strm
->
Write
(
tensor
->
dtype
);
strm
->
Write
(
&
tensor
->
dtype
,
sizeof
(
tensor
->
dtype
));
int
ndim
=
tensor
->
ndim
;
int
ndim
=
tensor
->
ndim
;
strm
->
Write
(
tensor
->
shape
,
sizeof
(
int64_t
)
*
ndim
);
strm
->
Write
Array
(
tensor
->
shape
,
ndim
);
int
type_
size
=
tensor
->
dtype
.
bits
/
8
;
int
type_
bytes
=
tensor
->
dtype
.
bits
/
8
;
int64_t
size
=
1
;
int64_t
num_elems
=
1
;
for
(
int
i
=
0
;
i
<
ndim
;
++
i
)
{
for
(
int
i
=
0
;
i
<
ndim
;
++
i
)
{
size
*=
tensor
->
shape
[
i
];
num_elems
*=
tensor
->
shape
[
i
];
}
}
int64_t
data_byte_size
=
type_size
*
size
;
int64_t
data_byte_size
=
type_bytes
*
num_elems
;
strm
->
Write
(
&
data_byte_size
,
sizeof
(
data_byte_size
));
strm
->
Write
(
data_byte_size
);
// handle endianness of data correctly.
if
(
DMLC_IO_NO_ENDIAN_SWAP
)
{
strm
->
Write
(
tensor
->
data
,
data_byte_size
);
strm
->
Write
(
tensor
->
data
,
data_byte_size
);
}
else
{
uint8_t
*
dptr
=
reinterpret_cast
<
uint8_t
*>
(
tensor
->
data
);
std
::
vector
<
uint8_t
>
bytes
(
dptr
,
dptr
+
data_byte_size
);
dmlc
::
ByteSwap
(
dmlc
::
BeginPtr
(
bytes
),
type_bytes
,
num_elems
);
strm
->
Write
(
dmlc
::
BeginPtr
(
bytes
),
data_byte_size
);
}
return
true
;
return
true
;
}
}
DLTensor
*
LoadDLTensor
(
dmlc
::
Stream
*
strm
)
{
DLTensor
*
LoadDLTensor
(
dmlc
::
Stream
*
strm
)
{
uint64_t
header
,
reserved
;
uint64_t
header
,
reserved
;
CHECK
(
strm
->
Read
(
&
header
,
sizeof
(
header
)
))
CHECK
(
strm
->
Read
(
&
header
))
<<
"Invalid DLTensor file format"
;
<<
"Invalid DLTensor file format"
;
CHECK
(
strm
->
Read
(
&
reserved
,
sizeof
(
reserved
)
))
CHECK
(
strm
->
Read
(
&
reserved
))
<<
"Invalid DLTensor file format"
;
<<
"Invalid DLTensor file format"
;
CHECK
(
header
==
kTVMNDArrayMagic
)
CHECK
(
header
==
kTVMNDArrayMagic
)
<<
"Invalid DLTensor file format"
;
<<
"Invalid DLTensor file format"
;
DLTensor
tensor
;
DLTensor
tensor
;
CHECK
(
strm
->
Read
(
&
tensor
.
ctx
,
sizeof
(
tensor
.
ctx
)))
CHECK
(
strm
->
Read
(
&
(
tensor
.
ctx
)))
<<
"Invalid DLTensor file format"
;
<<
"Invalid DLTensor file format"
;
CHECK
(
strm
->
Read
(
&
tensor
.
ndim
,
sizeof
(
tensor
.
ndim
)))
CHECK
(
strm
->
Read
(
&
(
tensor
.
ndim
)))
<<
"Invalid DLTensor file format"
;
<<
"Invalid DLTensor file format"
;
CHECK
(
strm
->
Read
(
&
tensor
.
dtype
,
sizeof
(
tensor
.
dtype
)))
CHECK
(
strm
->
Read
(
&
(
tensor
.
dtype
)))
<<
"Invalid DLTensor file format"
;
<<
"Invalid DLTensor file format"
;
std
::
vector
<
int64_t
>
shape
(
tensor
.
ndim
);
std
::
vector
<
int64_t
>
shape
(
tensor
.
ndim
);
CHECK
(
strm
->
Read
(
&
shape
[
0
],
sizeof
(
int64_t
)
*
tensor
.
ndim
))
if
(
tensor
.
ndim
!=
0
)
{
CHECK
(
strm
->
ReadArray
(
&
shape
[
0
],
tensor
.
ndim
))
<<
"Invalid DLTensor file format"
;
<<
"Invalid DLTensor file format"
;
}
DLTensor
*
ret
;
DLTensor
*
ret
;
CHECK_EQ
(
TVMArrayAlloc
(
shape
.
data
(),
CHECK_EQ
(
TVMArrayAlloc
(
shape
.
data
(),
tensor
.
ndim
,
tensor
.
ndim
,
...
@@ -87,18 +95,21 @@ DLTensor* LoadDLTensor(dmlc::Stream* strm) {
...
@@ -87,18 +95,21 @@ DLTensor* LoadDLTensor(dmlc::Stream* strm) {
static_cast
<
int
>
(
tensor
.
ctx
.
device_type
),
static_cast
<
int
>
(
tensor
.
ctx
.
device_type
),
tensor
.
ctx
.
device_id
,
tensor
.
ctx
.
device_id
,
&
ret
),
0
)
<<
TVMGetLastError
();
&
ret
),
0
)
<<
TVMGetLastError
();
int64_t
size
=
1
;
int64_t
num_elems
=
1
;
int
type_size
=
ret
->
dtype
.
bits
/
8
;
int
elem_bytes
=
(
ret
->
dtype
.
bits
+
7
)
/
8
;
for
(
int
i
=
0
;
i
<
ret
->
ndim
;
++
i
)
{
for
(
int
i
=
0
;
i
<
ret
->
ndim
;
++
i
)
{
size
*=
ret
->
shape
[
i
];
num_elems
*=
ret
->
shape
[
i
];
}
}
int64_t
data_byte_size
;
int64_t
data_byte_size
;
CHECK
(
strm
->
Read
(
&
data_byte_size
,
sizeof
(
data_byte_size
)
))
CHECK
(
strm
->
Read
(
&
data_byte_size
))
<<
"Invalid DLTensor file format"
;
<<
"Invalid DLTensor file format"
;
CHECK
(
data_byte_size
==
type_size
*
size
)
CHECK
(
data_byte_size
==
num_elems
*
elem_bytes
)
<<
"Invalid DLTensor file format"
;
<<
"Invalid DLTensor file format"
;
CHECK
(
strm
->
Read
(
ret
->
data
,
type_size
*
size
))
CHECK
(
strm
->
Read
(
ret
->
data
,
data_byte_
size
))
<<
"Invalid DLTensor file format"
;
<<
"Invalid DLTensor file format"
;
if
(
!
DMLC_IO_NO_ENDIAN_SWAP
)
{
dmlc
::
ByteSwap
(
ret
->
data
,
elem_bytes
,
num_elems
);
}
return
ret
;
return
ret
;
}
}
...
@@ -118,12 +129,12 @@ TVM_REGISTER_GLOBAL("nnvm.compiler._save_param_dict")
...
@@ -118,12 +129,12 @@ TVM_REGISTER_GLOBAL("nnvm.compiler._save_param_dict")
dmlc
::
MemoryStringStream
strm
(
&
bytes
);
dmlc
::
MemoryStringStream
strm
(
&
bytes
);
dmlc
::
Stream
*
fo
=
&
strm
;
dmlc
::
Stream
*
fo
=
&
strm
;
uint64_t
header
=
kTVMNDArrayListMagic
,
reserved
=
0
;
uint64_t
header
=
kTVMNDArrayListMagic
,
reserved
=
0
;
fo
->
Write
(
&
header
,
sizeof
(
header
)
);
fo
->
Write
(
header
);
fo
->
Write
(
&
reserved
,
sizeof
(
reserved
)
);
fo
->
Write
(
reserved
);
fo
->
Write
(
names
);
fo
->
Write
(
names
);
{
{
uint64_t
sz
=
static_cast
<
uint64_t
>
(
arrays
.
size
());
uint64_t
sz
=
static_cast
<
uint64_t
>
(
arrays
.
size
());
fo
->
Write
(
&
sz
,
sizeof
(
sz
)
);
fo
->
Write
(
sz
);
for
(
size_t
i
=
0
;
i
<
sz
;
++
i
)
{
for
(
size_t
i
=
0
;
i
<
sz
;
++
i
)
{
SaveDLTensor
(
fo
,
arrays
[
i
]);
SaveDLTensor
(
fo
,
arrays
[
i
]);
}
}
...
@@ -150,7 +161,6 @@ TVM_REGISTER_GLOBAL("nnvm.compiler._load_param_dict")
...
@@ -150,7 +161,6 @@ TVM_REGISTER_GLOBAL("nnvm.compiler._load_param_dict")
<<
"Invalid parameters file format"
;
<<
"Invalid parameters file format"
;
CHECK
(
strm
->
Read
(
&
reserved
))
CHECK
(
strm
->
Read
(
&
reserved
))
<<
"Invalid parameters file format"
;
<<
"Invalid parameters file format"
;
CHECK
(
strm
->
Read
(
&
names
))
CHECK
(
strm
->
Read
(
&
names
))
<<
"Invalid parameters file format"
;
<<
"Invalid parameters file format"
;
uint64_t
sz
;
uint64_t
sz
;
...
...
python/tvm/contrib/rpc/base.py
View file @
42608dda
...
@@ -73,7 +73,7 @@ def sendjson(sock, data):
...
@@ -73,7 +73,7 @@ def sendjson(sock, data):
Python value to be sent.
Python value to be sent.
"""
"""
data
=
json
.
dumps
(
data
)
data
=
json
.
dumps
(
data
)
sock
.
sendall
(
struct
.
pack
(
"
@
i"
,
len
(
data
)))
sock
.
sendall
(
struct
.
pack
(
"
<
i"
,
len
(
data
)))
sock
.
sendall
(
data
.
encode
(
"utf-8"
))
sock
.
sendall
(
data
.
encode
(
"utf-8"
))
...
@@ -90,7 +90,7 @@ def recvjson(sock):
...
@@ -90,7 +90,7 @@ def recvjson(sock):
value : object
value : object
The value received.
The value received.
"""
"""
size
=
struct
.
unpack
(
"
@
i"
,
recvall
(
sock
,
4
))[
0
]
size
=
struct
.
unpack
(
"
<
i"
,
recvall
(
sock
,
4
))[
0
]
data
=
json
.
loads
(
py_str
(
recvall
(
sock
,
size
)))
data
=
json
.
loads
(
py_str
(
recvall
(
sock
,
size
)))
return
data
return
data
...
...
python/tvm/contrib/rpc/client.py
View file @
42608dda
...
@@ -192,8 +192,8 @@ class TrackerSession(object):
...
@@ -192,8 +192,8 @@ class TrackerSession(object):
def
_connect
(
self
):
def
_connect
(
self
):
self
.
_sock
=
base
.
connect_with_retry
(
self
.
_addr
)
self
.
_sock
=
base
.
connect_with_retry
(
self
.
_addr
)
self
.
_sock
.
sendall
(
struct
.
pack
(
"
@
i"
,
base
.
RPC_TRACKER_MAGIC
))
self
.
_sock
.
sendall
(
struct
.
pack
(
"
<
i"
,
base
.
RPC_TRACKER_MAGIC
))
magic
=
struct
.
unpack
(
"
@
i"
,
base
.
recvall
(
self
.
_sock
,
4
))[
0
]
magic
=
struct
.
unpack
(
"
<
i"
,
base
.
recvall
(
self
.
_sock
,
4
))[
0
]
if
magic
!=
base
.
RPC_TRACKER_MAGIC
:
if
magic
!=
base
.
RPC_TRACKER_MAGIC
:
raise
RuntimeError
(
"
%
s is not RPC Tracker"
%
str
(
self
.
_addr
))
raise
RuntimeError
(
"
%
s is not RPC Tracker"
%
str
(
self
.
_addr
))
...
...
python/tvm/contrib/rpc/proxy.py
View file @
42608dda
...
@@ -58,14 +58,14 @@ class ForwardHandler(object):
...
@@ -58,14 +58,14 @@ class ForwardHandler(object):
def
_init_step
(
self
,
message
):
def
_init_step
(
self
,
message
):
if
self
.
_magic
is
None
:
if
self
.
_magic
is
None
:
assert
len
(
message
)
==
4
assert
len
(
message
)
==
4
self
.
_magic
=
struct
.
unpack
(
'
@
i'
,
message
)[
0
]
self
.
_magic
=
struct
.
unpack
(
'
<
i'
,
message
)[
0
]
if
self
.
_magic
!=
base
.
RPC_MAGIC
:
if
self
.
_magic
!=
base
.
RPC_MAGIC
:
logging
.
info
(
"Invalid RPC magic from
%
s"
,
self
.
name
())
logging
.
info
(
"Invalid RPC magic from
%
s"
,
self
.
name
())
self
.
close
()
self
.
close
()
self
.
_init_req_nbytes
=
4
self
.
_init_req_nbytes
=
4
elif
self
.
_rpc_key_length
is
None
:
elif
self
.
_rpc_key_length
is
None
:
assert
len
(
message
)
==
4
assert
len
(
message
)
==
4
self
.
_rpc_key_length
=
struct
.
unpack
(
'
@
i'
,
message
)[
0
]
self
.
_rpc_key_length
=
struct
.
unpack
(
'
<
i'
,
message
)[
0
]
self
.
_init_req_nbytes
=
self
.
_rpc_key_length
self
.
_init_req_nbytes
=
self
.
_rpc_key_length
elif
self
.
rpc_key
is
None
:
elif
self
.
rpc_key
is
None
:
assert
len
(
message
)
==
self
.
_rpc_key_length
assert
len
(
message
)
==
self
.
_rpc_key_length
...
@@ -269,12 +269,12 @@ class ProxyServerHandler(object):
...
@@ -269,12 +269,12 @@ class ProxyServerHandler(object):
lhs
.
forward_proxy
=
rhs
lhs
.
forward_proxy
=
rhs
rhs
.
forward_proxy
=
lhs
rhs
.
forward_proxy
=
lhs
lhs
.
send_data
(
struct
.
pack
(
'
@
i'
,
base
.
RPC_CODE_SUCCESS
))
lhs
.
send_data
(
struct
.
pack
(
'
<
i'
,
base
.
RPC_CODE_SUCCESS
))
lhs
.
send_data
(
struct
.
pack
(
'
@
i'
,
len
(
rhs
.
rpc_key
)))
lhs
.
send_data
(
struct
.
pack
(
'
<
i'
,
len
(
rhs
.
rpc_key
)))
lhs
.
send_data
(
rhs
.
rpc_key
.
encode
(
"utf-8"
))
lhs
.
send_data
(
rhs
.
rpc_key
.
encode
(
"utf-8"
))
rhs
.
send_data
(
struct
.
pack
(
'
@
i'
,
base
.
RPC_CODE_SUCCESS
))
rhs
.
send_data
(
struct
.
pack
(
'
<
i'
,
base
.
RPC_CODE_SUCCESS
))
rhs
.
send_data
(
struct
.
pack
(
'
@
i'
,
len
(
lhs
.
rpc_key
)))
rhs
.
send_data
(
struct
.
pack
(
'
<
i'
,
len
(
lhs
.
rpc_key
)))
rhs
.
send_data
(
lhs
.
rpc_key
.
encode
(
"utf-8"
))
rhs
.
send_data
(
lhs
.
rpc_key
.
encode
(
"utf-8"
))
logging
.
info
(
"Pairup connect
%
s and
%
s"
,
lhs
.
name
(),
rhs
.
name
())
logging
.
info
(
"Pairup connect
%
s and
%
s"
,
lhs
.
name
(),
rhs
.
name
())
...
@@ -299,8 +299,8 @@ class ProxyServerHandler(object):
...
@@ -299,8 +299,8 @@ class ProxyServerHandler(object):
if
self
.
_tracker_conn
is
None
:
if
self
.
_tracker_conn
is
None
:
self
.
_tracker_conn
=
socket
.
socket
(
socket
.
AF_INET
,
socket
.
SOCK_STREAM
)
self
.
_tracker_conn
=
socket
.
socket
(
socket
.
AF_INET
,
socket
.
SOCK_STREAM
)
self
.
_tracker_conn
.
connect
(
self
.
_tracker_addr
)
self
.
_tracker_conn
.
connect
(
self
.
_tracker_addr
)
self
.
_tracker_conn
.
sendall
(
struct
.
pack
(
"
@
i"
,
base
.
RPC_TRACKER_MAGIC
))
self
.
_tracker_conn
.
sendall
(
struct
.
pack
(
"
<
i"
,
base
.
RPC_TRACKER_MAGIC
))
magic
=
struct
.
unpack
(
"
@
i"
,
base
.
recvall
(
self
.
_tracker_conn
,
4
))[
0
]
magic
=
struct
.
unpack
(
"
<
i"
,
base
.
recvall
(
self
.
_tracker_conn
,
4
))[
0
]
if
magic
!=
base
.
RPC_TRACKER_MAGIC
:
if
magic
!=
base
.
RPC_TRACKER_MAGIC
:
self
.
loop
.
stop
()
self
.
loop
.
stop
()
raise
RuntimeError
(
"
%
s is not RPC Tracker"
%
str
(
self
.
_tracker_addr
))
raise
RuntimeError
(
"
%
s is not RPC Tracker"
%
str
(
self
.
_tracker_addr
))
...
@@ -371,7 +371,7 @@ class ProxyServerHandler(object):
...
@@ -371,7 +371,7 @@ class ProxyServerHandler(object):
if
handler
.
match_key
in
self
.
_server_pool
:
if
handler
.
match_key
in
self
.
_server_pool
:
self
.
_pair_up
(
self
.
_server_pool
.
pop
(
handler
.
match_key
),
handler
)
self
.
_pair_up
(
self
.
_server_pool
.
pop
(
handler
.
match_key
),
handler
)
else
:
else
:
handler
.
send_data
(
struct
.
pack
(
'
@
i'
,
base
.
RPC_CODE_MISMATCH
))
handler
.
send_data
(
struct
.
pack
(
'
<
i'
,
base
.
RPC_CODE_MISMATCH
))
handler
.
signal_close
()
handler
.
signal_close
()
def
_handler_ready_proxy_mode
(
self
,
handler
):
def
_handler_ready_proxy_mode
(
self
,
handler
):
...
@@ -395,12 +395,12 @@ class ProxyServerHandler(object):
...
@@ -395,12 +395,12 @@ class ProxyServerHandler(object):
logging
.
info
(
"Timeout client connection
%
s, cannot find match key=
%
s"
,
logging
.
info
(
"Timeout client connection
%
s, cannot find match key=
%
s"
,
handler
.
name
(),
key
)
handler
.
name
(),
key
)
pool_dst
.
pop
(
key
)
pool_dst
.
pop
(
key
)
handler
.
send_data
(
struct
.
pack
(
'
@
i'
,
base
.
RPC_CODE_MISMATCH
))
handler
.
send_data
(
struct
.
pack
(
'
<
i'
,
base
.
RPC_CODE_MISMATCH
))
handler
.
signal_close
()
handler
.
signal_close
()
self
.
loop
.
call_later
(
timeout
,
cleanup
)
self
.
loop
.
call_later
(
timeout
,
cleanup
)
else
:
else
:
logging
.
info
(
"Duplicate connection with same key=
%
s"
,
key
)
logging
.
info
(
"Duplicate connection with same key=
%
s"
,
key
)
handler
.
send_data
(
struct
.
pack
(
'
@
i'
,
base
.
RPC_CODE_DUPLICATE
))
handler
.
send_data
(
struct
.
pack
(
'
<
i'
,
base
.
RPC_CODE_DUPLICATE
))
handler
.
signal_close
()
handler
.
signal_close
()
def
handler_ready
(
self
,
handler
):
def
handler_ready
(
self
,
handler
):
...
@@ -538,13 +538,13 @@ def websocket_proxy_server(url, key=""):
...
@@ -538,13 +538,13 @@ def websocket_proxy_server(url, key=""):
on_message
=
create_on_message
(
conn
)
on_message
=
create_on_message
(
conn
)
temp
=
_server_env
(
None
)
temp
=
_server_env
(
None
)
# Start connecton
# Start connecton
conn
.
write_message
(
struct
.
pack
(
'
@
i'
,
base
.
RPC_MAGIC
),
binary
=
True
)
conn
.
write_message
(
struct
.
pack
(
'
<
i'
,
base
.
RPC_MAGIC
),
binary
=
True
)
key
=
"server:"
+
key
key
=
"server:"
+
key
conn
.
write_message
(
struct
.
pack
(
'
@
i'
,
len
(
key
)),
binary
=
True
)
conn
.
write_message
(
struct
.
pack
(
'
<
i'
,
len
(
key
)),
binary
=
True
)
conn
.
write_message
(
key
.
encode
(
"utf-8"
),
binary
=
True
)
conn
.
write_message
(
key
.
encode
(
"utf-8"
),
binary
=
True
)
msg
=
yield
conn
.
read_message
()
msg
=
yield
conn
.
read_message
()
assert
len
(
msg
)
>=
4
assert
len
(
msg
)
>=
4
magic
=
struct
.
unpack
(
'
@
i'
,
msg
[:
4
])[
0
]
magic
=
struct
.
unpack
(
'
<
i'
,
msg
[:
4
])[
0
]
if
magic
==
base
.
RPC_CODE_DUPLICATE
:
if
magic
==
base
.
RPC_CODE_DUPLICATE
:
raise
RuntimeError
(
"key:
%
s has already been used in proxy"
%
key
)
raise
RuntimeError
(
"key:
%
s has already been used in proxy"
%
key
)
elif
magic
==
base
.
RPC_CODE_MISMATCH
:
elif
magic
==
base
.
RPC_CODE_MISMATCH
:
...
...
python/tvm/contrib/rpc/server.py
View file @
42608dda
...
@@ -124,23 +124,23 @@ def _listen_loop(sock, port, rpc_key, tracker_addr, load_library):
...
@@ -124,23 +124,23 @@ def _listen_loop(sock, port, rpc_key, tracker_addr, load_library):
unmatch_period_count
=
0
unmatch_period_count
=
0
continue
continue
conn
,
addr
=
listen_sock
.
accept
()
conn
,
addr
=
listen_sock
.
accept
()
magic
=
struct
.
unpack
(
"
@
i"
,
base
.
recvall
(
conn
,
4
))[
0
]
magic
=
struct
.
unpack
(
"
<
i"
,
base
.
recvall
(
conn
,
4
))[
0
]
if
magic
!=
base
.
RPC_MAGIC
:
if
magic
!=
base
.
RPC_MAGIC
:
conn
.
close
()
conn
.
close
()
continue
continue
keylen
=
struct
.
unpack
(
"
@
i"
,
base
.
recvall
(
conn
,
4
))[
0
]
keylen
=
struct
.
unpack
(
"
<
i"
,
base
.
recvall
(
conn
,
4
))[
0
]
key
=
py_str
(
base
.
recvall
(
conn
,
keylen
))
key
=
py_str
(
base
.
recvall
(
conn
,
keylen
))
arr
=
key
.
split
()
arr
=
key
.
split
()
expect_header
=
"client:"
+
matchkey
expect_header
=
"client:"
+
matchkey
server_key
=
"server:"
+
rpc_key
server_key
=
"server:"
+
rpc_key
if
arr
[
0
]
!=
expect_header
:
if
arr
[
0
]
!=
expect_header
:
conn
.
sendall
(
struct
.
pack
(
"
@
i"
,
base
.
RPC_CODE_MISMATCH
))
conn
.
sendall
(
struct
.
pack
(
"
<
i"
,
base
.
RPC_CODE_MISMATCH
))
conn
.
close
()
conn
.
close
()
logging
.
info
(
"RPCServer: mismatch key from
%
s"
,
addr
)
logging
.
info
(
"RPCServer: mismatch key from
%
s"
,
addr
)
continue
continue
else
:
else
:
conn
.
sendall
(
struct
.
pack
(
"
@
i"
,
base
.
RPC_CODE_SUCCESS
))
conn
.
sendall
(
struct
.
pack
(
"
<
i"
,
base
.
RPC_CODE_SUCCESS
))
conn
.
sendall
(
struct
.
pack
(
"
@
i"
,
len
(
server_key
)))
conn
.
sendall
(
struct
.
pack
(
"
<
i"
,
len
(
server_key
)))
conn
.
sendall
(
server_key
.
encode
(
"utf-8"
))
conn
.
sendall
(
server_key
.
encode
(
"utf-8"
))
return
conn
,
addr
,
_parse_server_opt
(
arr
[
1
:])
return
conn
,
addr
,
_parse_server_opt
(
arr
[
1
:])
...
@@ -151,8 +151,8 @@ def _listen_loop(sock, port, rpc_key, tracker_addr, load_library):
...
@@ -151,8 +151,8 @@ def _listen_loop(sock, port, rpc_key, tracker_addr, load_library):
# step 1: setup tracker and report to tracker
# step 1: setup tracker and report to tracker
if
tracker_addr
and
tracker_conn
is
None
:
if
tracker_addr
and
tracker_conn
is
None
:
tracker_conn
=
base
.
connect_with_retry
(
tracker_addr
)
tracker_conn
=
base
.
connect_with_retry
(
tracker_addr
)
tracker_conn
.
sendall
(
struct
.
pack
(
"
@
i"
,
base
.
RPC_TRACKER_MAGIC
))
tracker_conn
.
sendall
(
struct
.
pack
(
"
<
i"
,
base
.
RPC_TRACKER_MAGIC
))
magic
=
struct
.
unpack
(
"
@
i"
,
base
.
recvall
(
tracker_conn
,
4
))[
0
]
magic
=
struct
.
unpack
(
"
<
i"
,
base
.
recvall
(
tracker_conn
,
4
))[
0
]
if
magic
!=
base
.
RPC_TRACKER_MAGIC
:
if
magic
!=
base
.
RPC_TRACKER_MAGIC
:
raise
RuntimeError
(
"
%
s is not RPC Tracker"
%
str
(
tracker_addr
))
raise
RuntimeError
(
"
%
s is not RPC Tracker"
%
str
(
tracker_addr
))
# report status of current queue
# report status of current queue
...
@@ -193,17 +193,17 @@ def _connect_proxy_loop(addr, key, load_library):
...
@@ -193,17 +193,17 @@ def _connect_proxy_loop(addr, key, load_library):
try
:
try
:
sock
=
socket
.
socket
(
socket
.
AF_INET
,
socket
.
SOCK_STREAM
)
sock
=
socket
.
socket
(
socket
.
AF_INET
,
socket
.
SOCK_STREAM
)
sock
.
connect
(
addr
)
sock
.
connect
(
addr
)
sock
.
sendall
(
struct
.
pack
(
"
@
i"
,
base
.
RPC_MAGIC
))
sock
.
sendall
(
struct
.
pack
(
"
<
i"
,
base
.
RPC_MAGIC
))
sock
.
sendall
(
struct
.
pack
(
"
@
i"
,
len
(
key
)))
sock
.
sendall
(
struct
.
pack
(
"
<
i"
,
len
(
key
)))
sock
.
sendall
(
key
.
encode
(
"utf-8"
))
sock
.
sendall
(
key
.
encode
(
"utf-8"
))
magic
=
struct
.
unpack
(
"
@
i"
,
base
.
recvall
(
sock
,
4
))[
0
]
magic
=
struct
.
unpack
(
"
<
i"
,
base
.
recvall
(
sock
,
4
))[
0
]
if
magic
==
base
.
RPC_CODE_DUPLICATE
:
if
magic
==
base
.
RPC_CODE_DUPLICATE
:
raise
RuntimeError
(
"key:
%
s has already been used in proxy"
%
key
)
raise
RuntimeError
(
"key:
%
s has already been used in proxy"
%
key
)
elif
magic
==
base
.
RPC_CODE_MISMATCH
:
elif
magic
==
base
.
RPC_CODE_MISMATCH
:
logging
.
info
(
"RPCProxy do not have matching client key
%
s"
,
key
)
logging
.
info
(
"RPCProxy do not have matching client key
%
s"
,
key
)
elif
magic
!=
base
.
RPC_CODE_SUCCESS
:
elif
magic
!=
base
.
RPC_CODE_SUCCESS
:
raise
RuntimeError
(
"
%
s is not RPC Proxy"
%
str
(
addr
))
raise
RuntimeError
(
"
%
s is not RPC Proxy"
%
str
(
addr
))
keylen
=
struct
.
unpack
(
"
@
i"
,
base
.
recvall
(
sock
,
4
))[
0
]
keylen
=
struct
.
unpack
(
"
<
i"
,
base
.
recvall
(
sock
,
4
))[
0
]
remote_key
=
py_str
(
base
.
recvall
(
sock
,
keylen
))
remote_key
=
py_str
(
base
.
recvall
(
sock
,
keylen
))
opts
=
_parse_server_opt
(
remote_key
.
split
()[
1
:])
opts
=
_parse_server_opt
(
remote_key
.
split
()[
1
:])
logging
.
info
(
"RPCProxy connected to
%
s"
,
str
(
addr
))
logging
.
info
(
"RPCProxy connected to
%
s"
,
str
(
addr
))
...
...
python/tvm/contrib/rpc/tracker.py
View file @
42608dda
...
@@ -143,11 +143,11 @@ class TCPEventHandler(tornado_util.TCPHandler):
...
@@ -143,11 +143,11 @@ class TCPEventHandler(tornado_util.TCPHandler):
if
len
(
message
)
!=
4
:
if
len
(
message
)
!=
4
:
logging
.
info
(
"Invalid connection from
%
s"
,
self
.
name
())
logging
.
info
(
"Invalid connection from
%
s"
,
self
.
name
())
self
.
close
()
self
.
close
()
magic
=
struct
.
unpack
(
'
@
i'
,
message
)[
0
]
magic
=
struct
.
unpack
(
'
<
i'
,
message
)[
0
]
if
magic
!=
RPC_TRACKER_MAGIC
:
if
magic
!=
RPC_TRACKER_MAGIC
:
logging
.
info
(
"Invalid magic from
%
s"
,
self
.
name
())
logging
.
info
(
"Invalid magic from
%
s"
,
self
.
name
())
self
.
close
()
self
.
close
()
self
.
write_message
(
struct
.
pack
(
'
@
i'
,
RPC_TRACKER_MAGIC
),
binary
=
True
)
self
.
write_message
(
struct
.
pack
(
'
<
i'
,
RPC_TRACKER_MAGIC
),
binary
=
True
)
self
.
_init_req_nbytes
=
0
self
.
_init_req_nbytes
=
0
def
on_message
(
self
,
message
):
def
on_message
(
self
,
message
):
...
@@ -168,7 +168,7 @@ class TCPEventHandler(tornado_util.TCPHandler):
...
@@ -168,7 +168,7 @@ class TCPEventHandler(tornado_util.TCPHandler):
while
True
:
while
True
:
if
self
.
_msg_size
==
0
:
if
self
.
_msg_size
==
0
:
if
len
(
self
.
_data
)
>=
4
:
if
len
(
self
.
_data
)
>=
4
:
self
.
_msg_size
=
struct
.
unpack
(
'
@
i'
,
self
.
_data
[:
4
])[
0
]
self
.
_msg_size
=
struct
.
unpack
(
'
<
i'
,
self
.
_data
[:
4
])[
0
]
else
:
else
:
return
return
if
self
.
_msg_size
!=
0
and
len
(
self
.
_data
)
>=
self
.
_msg_size
+
4
:
if
self
.
_msg_size
!=
0
and
len
(
self
.
_data
)
>=
self
.
_msg_size
+
4
:
...
@@ -184,7 +184,7 @@ class TCPEventHandler(tornado_util.TCPHandler):
...
@@ -184,7 +184,7 @@ class TCPEventHandler(tornado_util.TCPHandler):
"""return value to the output"""
"""return value to the output"""
data
=
json
.
dumps
(
data
)
data
=
json
.
dumps
(
data
)
self
.
write_message
(
self
.
write_message
(
struct
.
pack
(
'
@
i'
,
len
(
data
)),
binary
=
True
)
struct
.
pack
(
'
<
i'
,
len
(
data
)),
binary
=
True
)
self
.
write_message
(
data
.
encode
(
"utf-8"
),
binary
=
True
)
self
.
write_message
(
data
.
encode
(
"utf-8"
),
binary
=
True
)
def
call_handler
(
self
,
args
):
def
call_handler
(
self
,
args
):
...
@@ -355,8 +355,8 @@ class Tracker(object):
...
@@ -355,8 +355,8 @@ class Tracker(object):
def
_stop_tracker
(
self
):
def
_stop_tracker
(
self
):
sock
=
socket
.
socket
(
socket
.
AF_INET
,
socket
.
SOCK_STREAM
)
sock
=
socket
.
socket
(
socket
.
AF_INET
,
socket
.
SOCK_STREAM
)
sock
.
connect
((
self
.
host
,
self
.
port
))
sock
.
connect
((
self
.
host
,
self
.
port
))
sock
.
sendall
(
struct
.
pack
(
"
@
i"
,
base
.
RPC_TRACKER_MAGIC
))
sock
.
sendall
(
struct
.
pack
(
"
<
i"
,
base
.
RPC_TRACKER_MAGIC
))
magic
=
struct
.
unpack
(
"
@
i"
,
base
.
recvall
(
sock
,
4
))[
0
]
magic
=
struct
.
unpack
(
"
<
i"
,
base
.
recvall
(
sock
,
4
))[
0
]
assert
magic
==
base
.
RPC_TRACKER_MAGIC
assert
magic
==
base
.
RPC_TRACKER_MAGIC
base
.
sendjson
(
sock
,
[
TrackerCode
.
STOP
,
self
.
stop_key
])
base
.
sendjson
(
sock
,
[
TrackerCode
.
STOP
,
self
.
stop_key
])
assert
base
.
recvjson
(
sock
)
==
TrackerCode
.
SUCCESS
assert
base
.
recvjson
(
sock
)
==
TrackerCode
.
SUCCESS
...
...
src/runtime/file_util.cc
View file @
42608dda
...
@@ -4,6 +4,7 @@
...
@@ -4,6 +4,7 @@
*/
*/
#include <dmlc/json.h>
#include <dmlc/json.h>
#include <dmlc/logging.h>
#include <dmlc/logging.h>
#include <tvm/runtime/serializer.h>
#include <fstream>
#include <fstream>
#include "./file_util.h"
#include "./file_util.h"
...
...
src/runtime/graph/graph_runtime.cc
View file @
42608dda
...
@@ -4,6 +4,7 @@
...
@@ -4,6 +4,7 @@
*/
*/
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/serializer.h>
#include <dmlc/memory_io.h>
#include <dmlc/memory_io.h>
#include <dmlc/json.h>
#include <dmlc/json.h>
#include <numeric>
#include <numeric>
...
@@ -397,24 +398,25 @@ class GraphRuntime : public ModuleNode {
...
@@ -397,24 +398,25 @@ class GraphRuntime : public ModuleNode {
void
GraphRuntime
::
LoadDLTensor
(
dmlc
::
Stream
*
strm
,
DLTensor
*
dst
)
{
void
GraphRuntime
::
LoadDLTensor
(
dmlc
::
Stream
*
strm
,
DLTensor
*
dst
)
{
// always use strm->Read to maintain endianness conversion
uint64_t
header
,
reserved
;
uint64_t
header
,
reserved
;
CHECK
(
strm
->
Read
(
&
header
,
sizeof
(
header
)
))
CHECK
(
strm
->
Read
(
&
header
))
<<
"Invalid DLTensor file format"
;
<<
"Invalid DLTensor file format"
;
CHECK
(
strm
->
Read
(
&
reserved
,
sizeof
(
reserved
)
))
CHECK
(
strm
->
Read
(
&
reserved
))
<<
"Invalid DLTensor file format"
;
<<
"Invalid DLTensor file format"
;
CHECK
(
header
==
kTVMNDArrayMagic
)
CHECK
(
header
==
kTVMNDArrayMagic
)
<<
"Invalid DLTensor file format"
;
<<
"Invalid DLTensor file format"
;
DLTensor
tensor
;
DLTensor
tensor
;
CHECK
(
strm
->
Read
(
&
tensor
.
ctx
,
sizeof
(
tensor
.
ctx
)))
CHECK
(
strm
->
Read
(
&
(
tensor
.
ctx
)))
<<
"Invalid DLTensor file format"
;
<<
"Invalid DLTensor file format"
;
CHECK
(
strm
->
Read
(
&
tensor
.
ndim
,
sizeof
(
tensor
.
ndim
)))
CHECK
(
strm
->
Read
(
&
(
tensor
.
ndim
)))
<<
"Invalid DLTensor file format"
;
<<
"Invalid DLTensor file format"
;
CHECK
(
strm
->
Read
(
&
tensor
.
dtype
,
sizeof
(
tensor
.
dtype
)))
CHECK
(
strm
->
Read
(
&
(
tensor
.
dtype
)))
<<
"Invalid DLTensor file format"
;
<<
"Invalid DLTensor file format"
;
std
::
vector
<
int64_t
>
shape
(
tensor
.
ndim
);
std
::
vector
<
int64_t
>
shape
(
tensor
.
ndim
);
if
(
tensor
.
ndim
!=
0
)
{
if
(
tensor
.
ndim
!=
0
)
{
CHECK
(
strm
->
Read
(
&
shape
[
0
],
sizeof
(
int64_t
)
*
tensor
.
ndim
))
CHECK
(
strm
->
Read
Array
(
&
shape
[
0
],
tensor
.
ndim
))
<<
"Invalid DLTensor file format"
;
<<
"Invalid DLTensor file format"
;
}
}
CHECK_EQ
(
tensor
.
ndim
,
dst
->
ndim
)
<<
"param dimension mismatch"
;
CHECK_EQ
(
tensor
.
ndim
,
dst
->
ndim
)
<<
"param dimension mismatch"
;
...
@@ -425,18 +427,23 @@ void GraphRuntime::LoadDLTensor(dmlc::Stream* strm, DLTensor* dst) {
...
@@ -425,18 +427,23 @@ void GraphRuntime::LoadDLTensor(dmlc::Stream* strm, DLTensor* dst) {
CHECK_EQ
(
shape
[
i
],
dst
->
shape
[
i
])
<<
"param shape mismatch"
;
CHECK_EQ
(
shape
[
i
],
dst
->
shape
[
i
])
<<
"param shape mismatch"
;
}
}
size_t
bits
=
dst
->
dtype
.
bits
*
dst
->
dtype
.
lanes
;
size_t
bits
=
dst
->
dtype
.
bits
*
dst
->
dtype
.
lanes
;
size_t
size
=
(
bits
+
7
)
/
8
;
size_t
elem_bytes
=
(
bits
+
7
)
/
8
;
size_t
num_elems
=
1
;
for
(
int
i
=
0
;
i
<
dst
->
ndim
;
++
i
)
{
for
(
int
i
=
0
;
i
<
dst
->
ndim
;
++
i
)
{
size
*=
dst
->
shape
[
i
];
num_elems
*=
dst
->
shape
[
i
];
}
}
uint64_t
data_byte_size
;
uint64_t
data_byte_size
;
CHECK
(
strm
->
Read
(
&
data_byte_size
,
sizeof
(
data_byte_size
)
))
CHECK
(
strm
->
Read
(
&
data_byte_size
))
<<
"Invalid DLTensor file format"
;
<<
"Invalid DLTensor file format"
;
CHECK
(
data_byte_size
==
size
)
CHECK
_EQ
(
data_byte_size
,
elem_bytes
*
num_elems
)
<<
"Invalid DLTensor file format"
;
<<
"Invalid DLTensor file format"
;
std
::
vector
<
uint8_t
>
bytes
(
data_byte_size
+
1
);
std
::
vector
<
uint8_t
>
bytes
(
data_byte_size
+
1
);
CHECK
(
strm
->
Read
(
&
bytes
[
0
],
data_byte_size
))
CHECK
(
strm
->
Read
(
&
bytes
[
0
],
data_byte_size
))
<<
"Invalid DLTensor file format"
;
<<
"Invalid DLTensor file format"
;
// explicitly swap endian when necessary.
if
(
!
DMLC_IO_NO_ENDIAN_SWAP
)
{
dmlc
::
ByteSwap
(
&
bytes
[
0
],
elem_bytes
,
num_elems
);
}
TVM_CCALL
(
TVMArrayCopyFromBytes
(
dst
,
&
bytes
[
0
],
data_byte_size
));
TVM_CCALL
(
TVMArrayCopyFromBytes
(
dst
,
&
bytes
[
0
],
data_byte_size
));
}
}
...
@@ -453,9 +460,8 @@ void GraphRuntime::LoadParams(dmlc::Stream* strm) {
...
@@ -453,9 +460,8 @@ void GraphRuntime::LoadParams(dmlc::Stream* strm) {
CHECK
(
strm
->
Read
(
&
names
))
CHECK
(
strm
->
Read
(
&
names
))
<<
"Invalid parameters file format"
;
<<
"Invalid parameters file format"
;
uint64_t
sz
;
uint64_t
sz
;
strm
->
Read
(
&
sz
,
sizeof
(
sz
)
);
strm
->
Read
(
&
sz
);
size_t
size
=
static_cast
<
size_t
>
(
sz
);
size_t
size
=
static_cast
<
size_t
>
(
sz
);
CHECK
(
size
==
names
.
size
())
CHECK
(
size
==
names
.
size
())
<<
"Invalid parameters file format"
;
<<
"Invalid parameters file format"
;
for
(
size_t
i
=
0
;
i
<
size
;
++
i
)
{
for
(
size_t
i
=
0
;
i
<
size
;
++
i
)
{
...
...
src/runtime/rpc/rpc_session.cc
View file @
42608dda
...
@@ -6,6 +6,7 @@
...
@@ -6,6 +6,7 @@
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/serializer.h>
#include <memory>
#include <memory>
#include <array>
#include <array>
#include <string>
#include <string>
...
@@ -44,7 +45,7 @@ struct RPCArgBuffer {
...
@@ -44,7 +45,7 @@ struct RPCArgBuffer {
};
};
// Event handler for RPC events.
// Event handler for RPC events.
class
RPCSession
::
EventHandler
{
class
RPCSession
::
EventHandler
:
public
dmlc
::
Stream
{
public
:
public
:
EventHandler
(
common
::
RingBuffer
*
reader
,
EventHandler
(
common
::
RingBuffer
*
reader
,
common
::
RingBuffer
*
writer
,
common
::
RingBuffer
*
writer
,
...
@@ -71,6 +72,15 @@ class RPCSession::EventHandler {
...
@@ -71,6 +72,15 @@ class RPCSession::EventHandler {
return
0
;
return
0
;
}
}
}
}
// Request number of bytes from reader.
void
RequestBytes
(
size_t
nbytes
)
{
pending_request_bytes_
+=
nbytes
;
reader_
->
Reserve
(
pending_request_bytes_
);
}
// Whether we are ready to handle next request.
bool
Ready
()
{
return
reader_
->
bytes_available
()
>=
pending_request_bytes_
;
}
bool
CanCleanShutdown
()
const
{
bool
CanCleanShutdown
()
const
{
return
state_
==
kRecvCode
;
return
state_
==
kRecvCode
;
}
}
...
@@ -86,12 +96,12 @@ class RPCSession::EventHandler {
...
@@ -86,12 +96,12 @@ class RPCSession::EventHandler {
case
kInitHeader
:
HandleInitHeader
();
break
;
case
kInitHeader
:
HandleInitHeader
();
break
;
case
kRecvCode
:
HandleRecvCode
();
break
;
case
kRecvCode
:
HandleRecvCode
();
break
;
case
kRecvCallHandle
:
{
case
kRecvCallHandle
:
{
this
->
Read
(
&
call_handle_
,
sizeof
(
call_handle_
));
CHECK
(
this
->
Read
(
&
call_handle_
));
this
->
SwitchToState
(
kRecvPackedSeqNumArgs
);
this
->
SwitchToState
(
kRecvPackedSeqNumArgs
);
break
;
break
;
}
}
case
kRecvPackedSeqNumArgs
:
{
case
kRecvPackedSeqNumArgs
:
{
this
->
Read
(
&
num_packed_args_
,
sizeof
(
num_packed_args_
));
CHECK
(
this
->
Read
(
&
num_packed_args_
));
arg_buf_
.
reset
(
new
RPCArgBuffer
());
arg_buf_
.
reset
(
new
RPCArgBuffer
());
arg_buf_
->
value
.
resize
(
num_packed_args_
);
arg_buf_
->
value
.
resize
(
num_packed_args_
);
arg_buf_
->
tcode
.
resize
(
num_packed_args_
);
arg_buf_
->
tcode
.
resize
(
num_packed_args_
);
...
@@ -100,7 +110,7 @@ class RPCSession::EventHandler {
...
@@ -100,7 +110,7 @@ class RPCSession::EventHandler {
}
}
case
kRecvPackedSeqTypeCode
:
{
case
kRecvPackedSeqTypeCode
:
{
if
(
num_packed_args_
!=
0
)
{
if
(
num_packed_args_
!=
0
)
{
this
->
Read
(
arg_buf_
->
tcode
.
data
(),
sizeof
(
int
)
*
num_packed_args_
);
this
->
Read
Array
(
arg_buf_
->
tcode
.
data
(),
num_packed_args_
);
}
}
arg_index_
=
0
;
arg_index_
=
0
;
arg_recv_stage_
=
0
;
arg_recv_stage_
=
0
;
...
@@ -164,8 +174,8 @@ class RPCSession::EventHandler {
...
@@ -164,8 +174,8 @@ class RPCSession::EventHandler {
}
}
// send Packed sequence to writer.
// send Packed sequence to writer.
void
SendPackedSeq
(
const
TVMValue
*
arg_values
,
const
int
*
type_codes
,
int
n
)
{
void
SendPackedSeq
(
const
TVMValue
*
arg_values
,
const
int
*
type_codes
,
int
n
)
{
writer_
->
Write
(
&
n
,
sizeof
(
n
)
);
this
->
Write
(
n
);
writer_
->
Write
(
type_codes
,
sizeof
(
int
)
*
n
);
this
->
WriteArray
(
type_codes
,
n
);
// Argument packing.
// Argument packing.
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
int
tcode
=
type_codes
[
i
];
int
tcode
=
type_codes
[
i
];
...
@@ -173,14 +183,20 @@ class RPCSession::EventHandler {
...
@@ -173,14 +183,20 @@ class RPCSession::EventHandler {
switch
(
tcode
)
{
switch
(
tcode
)
{
case
kDLInt
:
case
kDLInt
:
case
kDLUInt
:
case
kDLUInt
:
case
kDLFloat
:
case
kDLFloat
:
{
this
->
Write
<
int64_t
>
(
value
.
v_int64
);
break
;
}
case
kTVMType
:
{
case
kTVMType
:
{
writer_
->
Write
(
&
value
,
sizeof
(
TVMValue
));
this
->
Write
(
value
.
v_type
);
// padding
int32_t
padding
=
0
;
this
->
Write
<
int32_t
>
(
padding
);
break
;
break
;
}
}
case
kTVMContext
:
{
case
kTVMContext
:
{
value
.
v_ctx
=
StripSessMask
(
value
.
v_ctx
);
value
.
v_ctx
=
StripSessMask
(
value
.
v_ctx
);
writer_
->
Write
(
&
value
,
sizeof
(
TVMValue
)
);
this
->
Write
(
value
.
v_ctx
);
break
;
break
;
}
}
case
kFuncHandle
:
case
kFuncHandle
:
...
@@ -188,7 +204,7 @@ class RPCSession::EventHandler {
...
@@ -188,7 +204,7 @@ class RPCSession::EventHandler {
case
kHandle
:
{
case
kHandle
:
{
// always send handle in 64 bit.
// always send handle in 64 bit.
uint64_t
handle
=
reinterpret_cast
<
uint64_t
>
(
value
.
v_handle
);
uint64_t
handle
=
reinterpret_cast
<
uint64_t
>
(
value
.
v_handle
);
writer_
->
Write
(
&
handle
,
sizeof
(
uint64_t
)
);
this
->
Write
(
handle
);
break
;
break
;
}
}
case
kArrayHandle
:
{
case
kArrayHandle
:
{
...
@@ -196,11 +212,11 @@ class RPCSession::EventHandler {
...
@@ -196,11 +212,11 @@ class RPCSession::EventHandler {
TVMContext
ctx
=
StripSessMask
(
arr
->
ctx
);
TVMContext
ctx
=
StripSessMask
(
arr
->
ctx
);
uint64_t
data
=
reinterpret_cast
<
uint64_t
>
(
uint64_t
data
=
reinterpret_cast
<
uint64_t
>
(
static_cast
<
RemoteSpace
*>
(
arr
->
data
)
->
data
);
static_cast
<
RemoteSpace
*>
(
arr
->
data
)
->
data
);
writer_
->
Write
(
&
data
,
sizeof
(
uint64_t
)
);
this
->
Write
(
data
);
writer_
->
Write
(
&
ctx
,
sizeof
(
ctx
)
);
this
->
Write
(
ctx
);
writer_
->
Write
(
&
(
arr
->
ndim
),
sizeof
(
int
)
);
this
->
Write
(
arr
->
ndim
);
writer_
->
Write
(
&
(
arr
->
dtype
),
sizeof
(
DLDataType
)
);
this
->
Write
(
arr
->
dtype
);
writer_
->
Write
(
arr
->
shape
,
sizeof
(
int64_t
)
*
arr
->
ndim
);
this
->
WriteArray
(
arr
->
shape
,
arr
->
ndim
);
CHECK
(
arr
->
strides
==
nullptr
)
CHECK
(
arr
->
strides
==
nullptr
)
<<
"Donot support strided remote array"
;
<<
"Donot support strided remote array"
;
CHECK_EQ
(
arr
->
byte_offset
,
0
)
CHECK_EQ
(
arr
->
byte_offset
,
0
)
...
@@ -211,15 +227,15 @@ class RPCSession::EventHandler {
...
@@ -211,15 +227,15 @@ class RPCSession::EventHandler {
case
kStr
:
{
case
kStr
:
{
const
char
*
s
=
value
.
v_str
;
const
char
*
s
=
value
.
v_str
;
uint64_t
len
=
strlen
(
s
);
uint64_t
len
=
strlen
(
s
);
writer_
->
Write
(
&
len
,
sizeof
(
len
)
);
this
->
Write
(
len
);
writer_
->
Write
(
s
,
sizeof
(
char
)
*
len
);
this
->
WriteArray
(
s
,
len
);
break
;
break
;
}
}
case
kBytes
:
{
case
kBytes
:
{
TVMByteArray
*
bytes
=
static_cast
<
TVMByteArray
*>
(
arg_values
[
i
].
v_handle
);
TVMByteArray
*
bytes
=
static_cast
<
TVMByteArray
*>
(
arg_values
[
i
].
v_handle
);
uint64_t
len
=
bytes
->
size
;
uint64_t
len
=
bytes
->
size
;
writer_
->
Write
(
&
len
,
sizeof
(
len
)
);
this
->
Write
(
len
);
writer_
->
Write
(
bytes
->
data
,
sizeof
(
char
)
*
len
);
this
->
WriteArray
(
bytes
->
data
,
len
);
break
;
break
;
}
}
default
:
{
default
:
{
...
@@ -230,6 +246,23 @@ class RPCSession::EventHandler {
...
@@ -230,6 +246,23 @@ class RPCSession::EventHandler {
}
}
}
}
// Endian aware IO handling
using
Stream
::
Read
;
using
Stream
::
Write
;
using
Stream
::
ReadArray
;
using
Stream
::
WriteArray
;
inline
bool
Read
(
RPCCode
*
code
)
{
int
cdata
;
if
(
!
this
->
Read
(
&
cdata
))
return
false
;
*
code
=
static_cast
<
RPCCode
>
(
cdata
);
return
true
;
}
inline
void
Write
(
RPCCode
code
)
{
int
cdata
=
static_cast
<
int
>
(
code
);
this
->
Write
(
cdata
);
}
protected
:
protected
:
enum
State
{
enum
State
{
kInitHeader
,
kInitHeader
,
...
@@ -370,10 +403,22 @@ class RPCSession::EventHandler {
...
@@ -370,10 +403,22 @@ class RPCSession::EventHandler {
switch
(
tcode
)
{
switch
(
tcode
)
{
case
kDLInt
:
case
kDLInt
:
case
kDLUInt
:
case
kDLUInt
:
case
kDLFloat
:
case
kDLFloat
:
{
case
kTVMType
:
this
->
Read
<
int64_t
>
(
&
(
value
.
v_int64
));
++
arg_index_
;
this
->
SwitchToState
(
kRecvPackedSeqArg
);
break
;
}
case
kTVMType
:
{
this
->
Read
(
&
(
value
.
v_type
));
int32_t
padding
=
0
;
this
->
Read
<
int32_t
>
(
&
padding
);
++
arg_index_
;
this
->
SwitchToState
(
kRecvPackedSeqArg
);
break
;
}
case
kTVMContext
:
{
case
kTVMContext
:
{
this
->
Read
(
&
value
,
sizeof
(
TVMValue
));
this
->
Read
(
&
(
value
.
v_ctx
));
++
arg_index_
;
++
arg_index_
;
this
->
SwitchToState
(
kRecvPackedSeqArg
);
this
->
SwitchToState
(
kRecvPackedSeqArg
);
break
;
break
;
...
@@ -383,7 +428,7 @@ class RPCSession::EventHandler {
...
@@ -383,7 +428,7 @@ class RPCSession::EventHandler {
case
kHandle
:
{
case
kHandle
:
{
// always send handle in 64 bit.
// always send handle in 64 bit.
uint64_t
handle
;
uint64_t
handle
;
this
->
Read
(
&
handle
,
sizeof
(
handle
)
);
this
->
Read
(
&
handle
);
value
.
v_handle
=
reinterpret_cast
<
void
*>
(
handle
);
value
.
v_handle
=
reinterpret_cast
<
void
*>
(
handle
);
++
arg_index_
;
++
arg_index_
;
this
->
SwitchToState
(
kRecvPackedSeqArg
);
this
->
SwitchToState
(
kRecvPackedSeqArg
);
...
@@ -398,7 +443,7 @@ class RPCSession::EventHandler {
...
@@ -398,7 +443,7 @@ class RPCSession::EventHandler {
case
kStr
:
case
kStr
:
case
kBytes
:
{
case
kBytes
:
{
uint64_t
len
;
uint64_t
len
;
this
->
Read
(
&
len
,
sizeof
(
len
)
);
this
->
Read
(
&
len
);
temp_bytes_
.
reset
(
new
RPCByteArrayBuffer
());
temp_bytes_
.
reset
(
new
RPCByteArrayBuffer
());
temp_bytes_
->
data
.
resize
(
len
);
temp_bytes_
->
data
.
resize
(
len
);
arg_recv_stage_
=
1
;
arg_recv_stage_
=
1
;
...
@@ -409,12 +454,12 @@ class RPCSession::EventHandler {
...
@@ -409,12 +454,12 @@ class RPCSession::EventHandler {
case
kArrayHandle
:
{
case
kArrayHandle
:
{
temp_array_
.
reset
(
new
RPCDataArrayBuffer
());
temp_array_
.
reset
(
new
RPCDataArrayBuffer
());
uint64_t
handle
;
uint64_t
handle
;
this
->
Read
(
&
handle
,
sizeof
(
handle
)
);
this
->
Read
(
&
handle
);
DLTensor
&
tensor
=
temp_array_
->
tensor
;
DLTensor
&
tensor
=
temp_array_
->
tensor
;
tensor
.
data
=
reinterpret_cast
<
void
*>
(
handle
);
tensor
.
data
=
reinterpret_cast
<
void
*>
(
handle
);
this
->
Read
(
&
(
tensor
.
ctx
)
,
sizeof
(
TVMContext
)
);
this
->
Read
(
&
(
tensor
.
ctx
));
this
->
Read
(
&
(
tensor
.
ndim
)
,
sizeof
(
int
)
);
this
->
Read
(
&
(
tensor
.
ndim
));
this
->
Read
(
&
(
tensor
.
dtype
)
,
sizeof
(
DLDataType
)
);
this
->
Read
(
&
(
tensor
.
dtype
));
temp_array_
->
shape
.
resize
(
tensor
.
ndim
);
temp_array_
->
shape
.
resize
(
tensor
.
ndim
);
tensor
.
shape
=
temp_array_
->
shape
.
data
();
tensor
.
shape
=
temp_array_
->
shape
.
data
();
arg_recv_stage_
=
1
;
arg_recv_stage_
=
1
;
...
@@ -432,7 +477,7 @@ class RPCSession::EventHandler {
...
@@ -432,7 +477,7 @@ class RPCSession::EventHandler {
CHECK_EQ
(
arg_recv_stage_
,
1
);
CHECK_EQ
(
arg_recv_stage_
,
1
);
if
(
tcode
==
kStr
||
tcode
==
kBytes
)
{
if
(
tcode
==
kStr
||
tcode
==
kBytes
)
{
if
(
temp_bytes_
->
data
.
size
()
!=
0
)
{
if
(
temp_bytes_
->
data
.
size
()
!=
0
)
{
this
->
Read
(
&
(
temp_bytes_
->
data
[
0
]),
temp_bytes_
->
data
.
size
());
this
->
Read
Array
(
&
(
temp_bytes_
->
data
[
0
]),
temp_bytes_
->
data
.
size
());
}
}
if
(
tcode
==
kStr
)
{
if
(
tcode
==
kStr
)
{
value
.
v_str
=
temp_bytes_
->
data
.
c_str
();
value
.
v_str
=
temp_bytes_
->
data
.
c_str
();
...
@@ -445,7 +490,7 @@ class RPCSession::EventHandler {
...
@@ -445,7 +490,7 @@ class RPCSession::EventHandler {
}
else
{
}
else
{
CHECK_EQ
(
tcode
,
kArrayHandle
);
CHECK_EQ
(
tcode
,
kArrayHandle
);
DLTensor
&
tensor
=
temp_array_
->
tensor
;
DLTensor
&
tensor
=
temp_array_
->
tensor
;
this
->
Read
(
tensor
.
shape
,
tensor
.
ndim
*
sizeof
(
int64_t
)
);
this
->
Read
Array
(
tensor
.
shape
,
tensor
.
ndim
);
value
.
v_handle
=
&
tensor
;
value
.
v_handle
=
&
tensor
;
arg_buf_
->
temp_array
.
emplace_back
(
std
::
move
(
temp_array_
));
arg_buf_
->
temp_array
.
emplace_back
(
std
::
move
(
temp_array_
));
}
}
...
@@ -458,20 +503,20 @@ class RPCSession::EventHandler {
...
@@ -458,20 +503,20 @@ class RPCSession::EventHandler {
void
HandleInitHeader
()
{
void
HandleInitHeader
()
{
if
(
init_header_step_
==
0
)
{
if
(
init_header_step_
==
0
)
{
int32_t
len
;
int32_t
len
;
this
->
Read
(
&
len
,
sizeof
(
len
)
);
this
->
Read
(
&
len
);
remote_key_
->
resize
(
len
);
remote_key_
->
resize
(
len
);
init_header_step_
=
1
;
init_header_step_
=
1
;
this
->
RequestBytes
(
len
);
this
->
RequestBytes
(
len
);
return
;
return
;
}
else
{
}
else
{
CHECK_EQ
(
init_header_step_
,
1
);
CHECK_EQ
(
init_header_step_
,
1
);
this
->
Read
(
dmlc
::
BeginPtr
(
*
remote_key_
),
remote_key_
->
length
());
this
->
Read
Array
(
dmlc
::
BeginPtr
(
*
remote_key_
),
remote_key_
->
length
());
this
->
SwitchToState
(
kRecvCode
);
this
->
SwitchToState
(
kRecvCode
);
}
}
}
}
// Handler for read code.
// Handler for read code.
void
HandleRecvCode
()
{
void
HandleRecvCode
()
{
this
->
Read
(
&
code_
,
sizeof
(
code_
)
);
this
->
Read
(
&
code_
);
if
(
code_
>
RPCCode
::
kSystemFuncStart
)
{
if
(
code_
>
RPCCode
::
kSystemFuncStart
)
{
SwitchToState
(
kRecvPackedSeqNumArgs
);
SwitchToState
(
kRecvPackedSeqNumArgs
);
return
;
return
;
...
@@ -511,14 +556,14 @@ class RPCSession::EventHandler {
...
@@ -511,14 +556,14 @@ class RPCSession::EventHandler {
void
HandleCopyFromRemote
()
{
void
HandleCopyFromRemote
()
{
uint64_t
handle
,
offset
,
size
;
uint64_t
handle
,
offset
,
size
;
TVMContext
ctx
;
TVMContext
ctx
;
this
->
Read
(
&
handle
,
sizeof
(
handle
)
);
this
->
Read
(
&
handle
);
this
->
Read
(
&
offset
,
sizeof
(
offset
)
);
this
->
Read
(
&
offset
);
this
->
Read
(
&
size
,
sizeof
(
size
)
);
this
->
Read
(
&
size
);
this
->
Read
(
&
ctx
,
sizeof
(
ctx
)
);
this
->
Read
(
&
ctx
);
if
(
ctx
.
device_type
==
kDLCPU
)
{
if
(
ctx
.
device_type
==
kDLCPU
)
{
RPCCode
code
=
RPCCode
::
kCopyAck
;
RPCCode
code
=
RPCCode
::
kCopyAck
;
writer_
->
Write
(
&
code
,
sizeof
(
code
)
);
this
->
Write
(
code
);
writer_
->
Write
(
reinterpret_cast
<
char
*>
(
handle
)
+
offset
,
size
);
this
->
WriteArray
(
reinterpret_cast
<
char
*>
(
handle
)
+
offset
,
size
);
}
else
{
}
else
{
temp_data_
.
resize
(
size
+
1
);
temp_data_
.
resize
(
size
+
1
);
try
{
try
{
...
@@ -530,11 +575,11 @@ class RPCSession::EventHandler {
...
@@ -530,11 +575,11 @@ class RPCSession::EventHandler {
dmlc
::
BeginPtr
(
temp_data_
),
0
,
dmlc
::
BeginPtr
(
temp_data_
),
0
,
size
,
ctx
,
cpu_ctx
,
nullptr
);
size
,
ctx
,
cpu_ctx
,
nullptr
);
RPCCode
code
=
RPCCode
::
kCopyAck
;
RPCCode
code
=
RPCCode
::
kCopyAck
;
writer_
->
Write
(
&
code
,
sizeof
(
code
)
);
this
->
Write
(
code
);
writer_
->
Write
(
&
temp_data_
[
0
],
size
);
this
->
WriteArray
(
&
temp_data_
[
0
],
size
);
}
catch
(
const
std
::
runtime_error
&
e
)
{
}
catch
(
const
std
::
runtime_error
&
e
)
{
RPCCode
code
=
RPCCode
::
kException
;
RPCCode
code
=
RPCCode
::
kException
;
writer_
->
Write
(
&
code
,
sizeof
(
code
)
);
this
->
Write
(
code
);
TVMValue
ret_value
;
TVMValue
ret_value
;
ret_value
.
v_str
=
e
.
what
();
ret_value
.
v_str
=
e
.
what
();
int
ret_tcode
=
kStr
;
int
ret_tcode
=
kStr
;
...
@@ -548,10 +593,10 @@ class RPCSession::EventHandler {
...
@@ -548,10 +593,10 @@ class RPCSession::EventHandler {
// use static variable to persist state.
// use static variable to persist state.
// This only works if next stage is immediately after this.
// This only works if next stage is immediately after this.
if
(
arg_recv_stage_
==
0
)
{
if
(
arg_recv_stage_
==
0
)
{
this
->
Read
(
&
copy_handle_
,
sizeof
(
uint64_t
));
CHECK
(
this
->
Read
(
&
copy_handle_
));
this
->
Read
(
&
copy_offset_
,
sizeof
(
uint64_t
));
CHECK
(
this
->
Read
(
&
copy_offset_
));
this
->
Read
(
&
copy_size_
,
sizeof
(
uint64_t
));
CHECK
(
this
->
Read
(
&
copy_size_
));
this
->
Read
(
&
copy_ctx_
,
sizeof
(
TVMContext
));
CHECK
(
this
->
Read
(
&
copy_ctx_
));
arg_recv_stage_
=
1
;
arg_recv_stage_
=
1
;
CHECK_EQ
(
pending_request_bytes_
,
0U
);
CHECK_EQ
(
pending_request_bytes_
,
0U
);
this
->
RequestBytes
(
copy_size_
);
this
->
RequestBytes
(
copy_size_
);
...
@@ -563,11 +608,11 @@ class RPCSession::EventHandler {
...
@@ -563,11 +608,11 @@ class RPCSession::EventHandler {
RPCCode
code
=
RPCCode
::
kReturn
;
RPCCode
code
=
RPCCode
::
kReturn
;
std
::
string
errmsg
;
std
::
string
errmsg
;
if
(
copy_ctx_
.
device_type
==
kDLCPU
)
{
if
(
copy_ctx_
.
device_type
==
kDLCPU
)
{
this
->
Read
(
this
->
Read
Array
(
reinterpret_cast
<
char
*>
(
copy_handle_
)
+
copy_offset_
,
copy_size_
);
reinterpret_cast
<
char
*>
(
copy_handle_
)
+
copy_offset_
,
copy_size_
);
}
else
{
}
else
{
temp_data_
.
resize
(
copy_size_
+
1
);
temp_data_
.
resize
(
copy_size_
+
1
);
this
->
Read
(
&
temp_data_
[
0
],
copy_size_
);
this
->
Read
Array
(
&
temp_data_
[
0
],
copy_size_
);
try
{
try
{
TVMContext
cpu_ctx
;
TVMContext
cpu_ctx
;
cpu_ctx
.
device_type
=
kDLCPU
;
cpu_ctx
.
device_type
=
kDLCPU
;
...
@@ -583,7 +628,7 @@ class RPCSession::EventHandler {
...
@@ -583,7 +628,7 @@ class RPCSession::EventHandler {
ret_tcode
=
kStr
;
ret_tcode
=
kStr
;
}
}
}
}
writer_
->
Write
(
&
code
,
sizeof
(
code
)
);
this
->
Write
(
code
);
SendPackedSeq
(
&
ret_value
,
&
ret_tcode
,
1
);
SendPackedSeq
(
&
ret_value
,
&
ret_tcode
,
1
);
arg_recv_stage_
=
0
;
arg_recv_stage_
=
0
;
this
->
SwitchToState
(
kRecvCode
);
this
->
SwitchToState
(
kRecvCode
);
...
@@ -603,7 +648,7 @@ class RPCSession::EventHandler {
...
@@ -603,7 +648,7 @@ class RPCSession::EventHandler {
std
::
unique_ptr
<
RPCArgBuffer
>
args
=
std
::
move
(
arg_buf_
);
std
::
unique_ptr
<
RPCArgBuffer
>
args
=
std
::
move
(
arg_buf_
);
f
(
args
->
AsTVMArgs
(),
&
rv
);
f
(
args
->
AsTVMArgs
(),
&
rv
);
RPCCode
code
=
RPCCode
::
kReturn
;
RPCCode
code
=
RPCCode
::
kReturn
;
writer_
->
Write
(
&
code
,
sizeof
(
code
)
);
this
->
Write
(
code
);
if
(
rv
.
type_code
()
==
kStr
)
{
if
(
rv
.
type_code
()
==
kStr
)
{
ret_value
.
v_str
=
rv
.
ptr
<
std
::
string
>
()
->
c_str
();
ret_value
.
v_str
=
rv
.
ptr
<
std
::
string
>
()
->
c_str
();
ret_tcode
=
kStr
;
ret_tcode
=
kStr
;
...
@@ -630,7 +675,7 @@ class RPCSession::EventHandler {
...
@@ -630,7 +675,7 @@ class RPCSession::EventHandler {
}
}
}
catch
(
const
std
::
runtime_error
&
e
)
{
}
catch
(
const
std
::
runtime_error
&
e
)
{
RPCCode
code
=
RPCCode
::
kException
;
RPCCode
code
=
RPCCode
::
kException
;
writer_
->
Write
(
&
code
,
sizeof
(
code
)
);
this
->
Write
(
code
);
ret_value
.
v_str
=
e
.
what
();
ret_value
.
v_str
=
e
.
what
();
ret_tcode
=
kStr
;
ret_tcode
=
kStr
;
SendPackedSeq
(
&
ret_value
,
&
ret_tcode
,
1
);
SendPackedSeq
(
&
ret_value
,
&
ret_tcode
,
1
);
...
@@ -640,19 +685,14 @@ class RPCSession::EventHandler {
...
@@ -640,19 +685,14 @@ class RPCSession::EventHandler {
private
:
private
:
// Utility functions
// Utility functions
// Internal read function, update pending_request_bytes_
// Internal read function, update pending_request_bytes_
void
Read
(
void
*
data
,
size_t
size
)
{
size_t
Read
(
void
*
data
,
size_t
size
)
final
{
CHECK_LE
(
size
,
pending_request_bytes_
);
CHECK_LE
(
size
,
pending_request_bytes_
);
reader_
->
Read
(
data
,
size
);
reader_
->
Read
(
data
,
size
);
pending_request_bytes_
-=
size
;
pending_request_bytes_
-=
size
;
return
size
;
}
}
// Request number of bytes from reader.
void
Write
(
const
void
*
data
,
size_t
size
)
final
{
void
RequestBytes
(
size_t
nbytes
)
{
writer_
->
Write
(
data
,
size
);
pending_request_bytes_
+=
nbytes
;
reader_
->
Reserve
(
pending_request_bytes_
);
}
// Whether we are ready to handle next request.
bool
Ready
()
{
return
reader_
->
bytes_available
()
>=
pending_request_bytes_
;
}
}
// Number of pending bytes requests
// Number of pending bytes requests
size_t
pending_request_bytes_
;
size_t
pending_request_bytes_
;
...
@@ -766,7 +806,7 @@ RPCSession::~RPCSession() {
...
@@ -766,7 +806,7 @@ RPCSession::~RPCSession() {
void
RPCSession
::
Shutdown
()
{
void
RPCSession
::
Shutdown
()
{
if
(
channel_
!=
nullptr
)
{
if
(
channel_
!=
nullptr
)
{
RPCCode
code
=
RPCCode
::
kShutdown
;
RPCCode
code
=
RPCCode
::
kShutdown
;
writer_
.
Write
(
&
code
,
sizeof
(
code
)
);
handler_
->
Write
(
code
);
// flush all writing buffer to output channel.
// flush all writing buffer to output channel.
try
{
try
{
while
(
writer_
.
bytes_available
()
!=
0
)
{
while
(
writer_
.
bytes_available
()
!=
0
)
{
...
@@ -788,7 +828,6 @@ void RPCSession::ServerLoop() {
...
@@ -788,7 +828,6 @@ void RPCSession::ServerLoop() {
}
}
TVMRetValue
rv
;
TVMRetValue
rv
;
CHECK
(
HandleUntilReturnEvent
(
&
rv
,
false
,
nullptr
)
==
RPCCode
::
kShutdown
);
CHECK
(
HandleUntilReturnEvent
(
&
rv
,
false
,
nullptr
)
==
RPCCode
::
kShutdown
);
LOG
(
INFO
)
<<
"Shutdown..."
;
if
(
const
auto
*
f
=
Registry
::
Get
(
"tvm.contrib.rpc.server.shutdown"
))
{
if
(
const
auto
*
f
=
Registry
::
Get
(
"tvm.contrib.rpc.server.shutdown"
))
{
(
*
f
)();
(
*
f
)();
}
}
...
@@ -821,9 +860,9 @@ void RPCSession::CallFunc(void* h,
...
@@ -821,9 +860,9 @@ void RPCSession::CallFunc(void* h,
const
PackedFunc
*
fwrap
)
{
const
PackedFunc
*
fwrap
)
{
std
::
lock_guard
<
std
::
recursive_mutex
>
lock
(
mutex_
);
std
::
lock_guard
<
std
::
recursive_mutex
>
lock
(
mutex_
);
RPCCode
code
=
RPCCode
::
kCallFunc
;
RPCCode
code
=
RPCCode
::
kCallFunc
;
writer_
.
Write
(
&
code
,
sizeof
(
code
)
);
handler_
->
Write
(
code
);
uint64_t
handle
=
reinterpret_cast
<
uint64_t
>
(
h
);
uint64_t
handle
=
reinterpret_cast
<
uint64_t
>
(
h
);
writer_
.
Write
(
&
handle
,
sizeof
(
handle
)
);
handler_
->
Write
(
handle
);
handler_
->
SendPackedSeq
(
args
.
values
,
args
.
type_codes
,
args
.
num_args
);
handler_
->
SendPackedSeq
(
args
.
values
,
args
.
type_codes
,
args
.
num_args
);
code
=
HandleUntilReturnEvent
(
rv
,
true
,
fwrap
);
code
=
HandleUntilReturnEvent
(
rv
,
true
,
fwrap
);
CHECK
(
code
==
RPCCode
::
kReturn
)
<<
"code="
<<
static_cast
<
int
>
(
code
);
CHECK
(
code
==
RPCCode
::
kReturn
)
<<
"code="
<<
static_cast
<
int
>
(
code
);
...
@@ -838,15 +877,15 @@ void RPCSession::CopyToRemote(void* from,
...
@@ -838,15 +877,15 @@ void RPCSession::CopyToRemote(void* from,
std
::
lock_guard
<
std
::
recursive_mutex
>
lock
(
mutex_
);
std
::
lock_guard
<
std
::
recursive_mutex
>
lock
(
mutex_
);
ctx_to
=
handler_
->
StripSessMask
(
ctx_to
);
ctx_to
=
handler_
->
StripSessMask
(
ctx_to
);
RPCCode
code
=
RPCCode
::
kCopyToRemote
;
RPCCode
code
=
RPCCode
::
kCopyToRemote
;
writer_
.
Write
(
&
code
,
sizeof
(
code
)
);
handler_
->
Write
(
code
);
uint64_t
handle
=
reinterpret_cast
<
uint64_t
>
(
to
);
uint64_t
handle
=
reinterpret_cast
<
uint64_t
>
(
to
);
writer_
.
Write
(
&
handle
,
sizeof
(
handle
)
);
handler_
->
Write
(
handle
);
uint64_t
offset
=
static_cast
<
uint64_t
>
(
to_offset
);
uint64_t
offset
=
static_cast
<
uint64_t
>
(
to_offset
);
writer_
.
Write
(
&
offset
,
sizeof
(
offset
)
);
handler_
->
Write
(
offset
);
uint64_t
size
=
static_cast
<
uint64_t
>
(
data_size
);
uint64_t
size
=
static_cast
<
uint64_t
>
(
data_size
);
writer_
.
Write
(
&
size
,
sizeof
(
size
)
);
handler_
->
Write
(
size
);
writer_
.
Write
(
&
ctx_to
,
sizeof
(
ctx_to
)
);
handler_
->
Write
(
ctx_to
);
writer_
.
Write
(
reinterpret_cast
<
char
*>
(
from
)
+
from_offset
,
data_size
);
handler_
->
WriteArray
(
reinterpret_cast
<
char
*>
(
from
)
+
from_offset
,
data_size
);
TVMRetValue
rv
;
TVMRetValue
rv
;
CHECK
(
HandleUntilReturnEvent
(
&
rv
,
true
,
nullptr
)
==
RPCCode
::
kReturn
);
CHECK
(
HandleUntilReturnEvent
(
&
rv
,
true
,
nullptr
)
==
RPCCode
::
kReturn
);
}
}
...
@@ -860,26 +899,27 @@ void RPCSession::CopyFromRemote(void* from,
...
@@ -860,26 +899,27 @@ void RPCSession::CopyFromRemote(void* from,
std
::
lock_guard
<
std
::
recursive_mutex
>
lock
(
mutex_
);
std
::
lock_guard
<
std
::
recursive_mutex
>
lock
(
mutex_
);
ctx_from
=
handler_
->
StripSessMask
(
ctx_from
);
ctx_from
=
handler_
->
StripSessMask
(
ctx_from
);
RPCCode
code
=
RPCCode
::
kCopyFromRemote
;
RPCCode
code
=
RPCCode
::
kCopyFromRemote
;
writer_
.
Write
(
&
code
,
sizeof
(
code
)
);
handler_
->
Write
(
code
);
uint64_t
handle
=
reinterpret_cast
<
uint64_t
>
(
from
);
uint64_t
handle
=
reinterpret_cast
<
uint64_t
>
(
from
);
writer_
.
Write
(
&
handle
,
sizeof
(
handle
)
);
handler_
->
Write
(
handle
);
uint64_t
offset
=
static_cast
<
uint64_t
>
(
from_offset
);
uint64_t
offset
=
static_cast
<
uint64_t
>
(
from_offset
);
writer_
.
Write
(
&
offset
,
sizeof
(
offset
)
);
handler_
->
Write
(
offset
);
uint64_t
size
=
static_cast
<
uint64_t
>
(
data_size
);
uint64_t
size
=
static_cast
<
uint64_t
>
(
data_size
);
writer_
.
Write
(
&
size
,
sizeof
(
size
)
);
handler_
->
Write
(
size
);
writer_
.
Write
(
&
ctx_from
,
sizeof
(
ctx_from
)
);
handler_
->
Write
(
ctx_from
);
TVMRetValue
rv
;
TVMRetValue
rv
;
CHECK
(
HandleUntilReturnEvent
(
&
rv
,
true
,
nullptr
)
==
RPCCode
::
kCopyAck
);
CHECK
(
HandleUntilReturnEvent
(
&
rv
,
true
,
nullptr
)
==
RPCCode
::
kCopyAck
);
reader_
.
Reserve
(
data_size
);
reader_
.
Reserve
(
data_size
);
while
(
reader_
.
bytes_available
()
<
data_size
)
{
handler_
->
RequestBytes
(
data_size
);
size_t
bytes_needed
=
data_size
-
reader_
.
bytes_available
();
while
(
!
handler_
->
Ready
())
{
size_t
bytes_needed
=
handler_
->
BytesNeeded
();
reader_
.
WriteWithCallback
([
this
](
void
*
data
,
size_t
size
)
{
reader_
.
WriteWithCallback
([
this
](
void
*
data
,
size_t
size
)
{
size_t
n
=
channel_
->
Recv
(
data
,
size
);
size_t
n
=
channel_
->
Recv
(
data
,
size
);
CHECK_NE
(
n
,
0U
)
<<
"Channel closes before we get neded bytes"
;
CHECK_NE
(
n
,
0U
)
<<
"Channel closes before we get neded bytes"
;
return
n
;
return
n
;
},
bytes_needed
);
},
bytes_needed
);
}
}
reader_
.
Read
(
reinterpret_cast
<
char
*>
(
to
)
+
to_offset
,
data_size
);
handler_
->
ReadArray
(
reinterpret_cast
<
char
*>
(
to
)
+
to_offset
,
data_size
);
handler_
->
FinishCopyAck
();
handler_
->
FinishCopyAck
();
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment