Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
42608dda
Commit
42608dda
authored
May 30, 2018
by
tqchen
Committed by
Tianqi Chen
May 30, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[IO] Support cross-endian
parent
0f9dab98
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
147 additions
and
80 deletions
+147
-80
dmlc-core
+1
-1
include/tvm/runtime/serializer.h
+50
-0
nnvm/src/compiler/graph_runtime.cc
+42
-32
python/tvm/contrib/rpc/base.py
+2
-2
python/tvm/contrib/rpc/client.py
+2
-2
python/tvm/contrib/rpc/proxy.py
+14
-14
python/tvm/contrib/rpc/server.py
+11
-11
python/tvm/contrib/rpc/tracker.py
+6
-6
src/runtime/file_util.cc
+1
-0
src/runtime/graph/graph_runtime.cc
+18
-12
src/runtime/rpc/rpc_session.cc
+0
-0
No files found.
dmlc-core
@
9b3f9753
Subproject commit
d3f7fbb53e5b037c0f5bf6bd21871ccc720690cc
Subproject commit
9b3f9753ae81d657743c555e0cacc4e43f0bed2d
include/tvm/runtime/serializer.h
0 → 100644
View file @
42608dda
/*!
* Copyright (c) 2017 by Contributors
* \file tvm/runtime/serializer.h
* \brief Serializer extension to support TVM data types
* Include this file to enable serialization of DLDataType, DLContext
*/
#ifndef TVM_RUNTIME_SERIALIZER_H_
#define TVM_RUNTIME_SERIALIZER_H_
#include <dmlc/io.h>
#include <dmlc/serializer.h>
#include "./c_runtime_api.h"
namespace
dmlc
{
namespace
serializer
{
template
<>
struct
Handler
<
DLDataType
>
{
inline
static
void
Write
(
Stream
*
strm
,
const
DLDataType
&
dtype
)
{
Handler
<
uint8_t
>::
Write
(
strm
,
dtype
.
code
);
Handler
<
uint8_t
>::
Write
(
strm
,
dtype
.
bits
);
Handler
<
uint16_t
>::
Write
(
strm
,
dtype
.
lanes
);
}
inline
static
bool
Read
(
Stream
*
strm
,
DLDataType
*
dtype
)
{
if
(
!
Handler
<
uint8_t
>::
Read
(
strm
,
&
(
dtype
->
code
)))
return
false
;
if
(
!
Handler
<
uint8_t
>::
Read
(
strm
,
&
(
dtype
->
bits
)))
return
false
;
if
(
!
Handler
<
uint16_t
>::
Read
(
strm
,
&
(
dtype
->
lanes
)))
return
false
;
return
true
;
}
};
template
<>
struct
Handler
<
DLContext
>
{
inline
static
void
Write
(
Stream
*
strm
,
const
DLContext
&
ctx
)
{
int32_t
device_type
=
static_cast
<
int32_t
>
(
ctx
.
device_type
);
Handler
<
int32_t
>::
Write
(
strm
,
device_type
);
Handler
<
int32_t
>::
Write
(
strm
,
ctx
.
device_id
);
}
inline
static
bool
Read
(
Stream
*
strm
,
DLContext
*
ctx
)
{
int32_t
device_type
=
0
;
if
(
!
Handler
<
int32_t
>::
Read
(
strm
,
&
(
device_type
)))
return
false
;
ctx
->
device_type
=
static_cast
<
DLDeviceType
>
(
device_type
);
if
(
!
Handler
<
int32_t
>::
Read
(
strm
,
&
(
ctx
->
device_id
)))
return
false
;
return
true
;
}
};
}
// namespace serializer
}
// namespace dmlc
#endif // TVM_RUNTIME_SERIALIZER_H_
nnvm/src/compiler/graph_runtime.cc
View file @
42608dda
...
...
@@ -7,6 +7,7 @@
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/serializer.h>
#include "./graph_runtime.h"
namespace
nnvm
{
...
...
@@ -38,46 +39,53 @@ NNVM_REGISTER_OP(tvm_op)
bool
SaveDLTensor
(
dmlc
::
Stream
*
strm
,
DLTensor
*
tensor
)
{
uint64_t
header
=
kTVMNDArrayMagic
,
reserved
=
0
;
strm
->
Write
(
&
header
,
sizeof
(
header
));
strm
->
Write
(
&
reserved
,
sizeof
(
reserved
));
strm
->
Write
(
&
tensor
->
ctx
,
sizeof
(
tensor
->
ctx
));
strm
->
Write
(
&
tensor
->
ndim
,
sizeof
(
tensor
->
ndim
));
strm
->
Write
(
&
tensor
->
dtype
,
sizeof
(
tensor
->
dtype
));
strm
->
Write
(
header
);
strm
->
Write
(
reserved
);
strm
->
Write
(
tensor
->
ctx
);
strm
->
Write
(
tensor
->
ndim
);
strm
->
Write
(
tensor
->
dtype
);
int
ndim
=
tensor
->
ndim
;
strm
->
Write
(
tensor
->
shape
,
sizeof
(
int64_t
)
*
ndim
);
strm
->
Write
Array
(
tensor
->
shape
,
ndim
);
int
type_
size
=
tensor
->
dtype
.
bits
/
8
;
int64_t
size
=
1
;
int
type_
bytes
=
tensor
->
dtype
.
bits
/
8
;
int64_t
num_elems
=
1
;
for
(
int
i
=
0
;
i
<
ndim
;
++
i
)
{
size
*=
tensor
->
shape
[
i
];
num_elems
*=
tensor
->
shape
[
i
];
}
int64_t
data_byte_size
=
type_bytes
*
num_elems
;
strm
->
Write
(
data_byte_size
);
// handle endianness of data correctly.
if
(
DMLC_IO_NO_ENDIAN_SWAP
)
{
strm
->
Write
(
tensor
->
data
,
data_byte_size
);
}
else
{
uint8_t
*
dptr
=
reinterpret_cast
<
uint8_t
*>
(
tensor
->
data
);
std
::
vector
<
uint8_t
>
bytes
(
dptr
,
dptr
+
data_byte_size
);
dmlc
::
ByteSwap
(
dmlc
::
BeginPtr
(
bytes
),
type_bytes
,
num_elems
);
strm
->
Write
(
dmlc
::
BeginPtr
(
bytes
),
data_byte_size
);
}
int64_t
data_byte_size
=
type_size
*
size
;
strm
->
Write
(
&
data_byte_size
,
sizeof
(
data_byte_size
));
strm
->
Write
(
tensor
->
data
,
data_byte_size
);
return
true
;
}
DLTensor
*
LoadDLTensor
(
dmlc
::
Stream
*
strm
)
{
uint64_t
header
,
reserved
;
CHECK
(
strm
->
Read
(
&
header
,
sizeof
(
header
)
))
CHECK
(
strm
->
Read
(
&
header
))
<<
"Invalid DLTensor file format"
;
CHECK
(
strm
->
Read
(
&
reserved
,
sizeof
(
reserved
)
))
CHECK
(
strm
->
Read
(
&
reserved
))
<<
"Invalid DLTensor file format"
;
CHECK
(
header
==
kTVMNDArrayMagic
)
<<
"Invalid DLTensor file format"
;
DLTensor
tensor
;
CHECK
(
strm
->
Read
(
&
tensor
.
ctx
,
sizeof
(
tensor
.
ctx
)))
CHECK
(
strm
->
Read
(
&
(
tensor
.
ctx
)))
<<
"Invalid DLTensor file format"
;
CHECK
(
strm
->
Read
(
&
tensor
.
ndim
,
sizeof
(
tensor
.
ndim
)))
CHECK
(
strm
->
Read
(
&
(
tensor
.
ndim
)))
<<
"Invalid DLTensor file format"
;
CHECK
(
strm
->
Read
(
&
tensor
.
dtype
,
sizeof
(
tensor
.
dtype
)))
CHECK
(
strm
->
Read
(
&
(
tensor
.
dtype
)))
<<
"Invalid DLTensor file format"
;
std
::
vector
<
int64_t
>
shape
(
tensor
.
ndim
);
CHECK
(
strm
->
Read
(
&
shape
[
0
],
sizeof
(
int64_t
)
*
tensor
.
ndim
))
<<
"Invalid DLTensor file format"
;
if
(
tensor
.
ndim
!=
0
)
{
CHECK
(
strm
->
ReadArray
(
&
shape
[
0
],
tensor
.
ndim
))
<<
"Invalid DLTensor file format"
;
}
DLTensor
*
ret
;
CHECK_EQ
(
TVMArrayAlloc
(
shape
.
data
(),
tensor
.
ndim
,
...
...
@@ -87,18 +95,21 @@ DLTensor* LoadDLTensor(dmlc::Stream* strm) {
static_cast
<
int
>
(
tensor
.
ctx
.
device_type
),
tensor
.
ctx
.
device_id
,
&
ret
),
0
)
<<
TVMGetLastError
();
int64_t
size
=
1
;
int
type_size
=
ret
->
dtype
.
bits
/
8
;
int64_t
num_elems
=
1
;
int
elem_bytes
=
(
ret
->
dtype
.
bits
+
7
)
/
8
;
for
(
int
i
=
0
;
i
<
ret
->
ndim
;
++
i
)
{
size
*=
ret
->
shape
[
i
];
num_elems
*=
ret
->
shape
[
i
];
}
int64_t
data_byte_size
;
CHECK
(
strm
->
Read
(
&
data_byte_size
,
sizeof
(
data_byte_size
)
))
CHECK
(
strm
->
Read
(
&
data_byte_size
))
<<
"Invalid DLTensor file format"
;
CHECK
(
data_byte_size
==
type_size
*
size
)
CHECK
(
data_byte_size
==
num_elems
*
elem_bytes
)
<<
"Invalid DLTensor file format"
;
CHECK
(
strm
->
Read
(
ret
->
data
,
type_size
*
size
))
CHECK
(
strm
->
Read
(
ret
->
data
,
data_byte_
size
))
<<
"Invalid DLTensor file format"
;
if
(
!
DMLC_IO_NO_ENDIAN_SWAP
)
{
dmlc
::
ByteSwap
(
ret
->
data
,
elem_bytes
,
num_elems
);
}
return
ret
;
}
...
...
@@ -118,12 +129,12 @@ TVM_REGISTER_GLOBAL("nnvm.compiler._save_param_dict")
dmlc
::
MemoryStringStream
strm
(
&
bytes
);
dmlc
::
Stream
*
fo
=
&
strm
;
uint64_t
header
=
kTVMNDArrayListMagic
,
reserved
=
0
;
fo
->
Write
(
&
header
,
sizeof
(
header
)
);
fo
->
Write
(
&
reserved
,
sizeof
(
reserved
)
);
fo
->
Write
(
header
);
fo
->
Write
(
reserved
);
fo
->
Write
(
names
);
{
uint64_t
sz
=
static_cast
<
uint64_t
>
(
arrays
.
size
());
fo
->
Write
(
&
sz
,
sizeof
(
sz
)
);
fo
->
Write
(
sz
);
for
(
size_t
i
=
0
;
i
<
sz
;
++
i
)
{
SaveDLTensor
(
fo
,
arrays
[
i
]);
}
...
...
@@ -150,7 +161,6 @@ TVM_REGISTER_GLOBAL("nnvm.compiler._load_param_dict")
<<
"Invalid parameters file format"
;
CHECK
(
strm
->
Read
(
&
reserved
))
<<
"Invalid parameters file format"
;
CHECK
(
strm
->
Read
(
&
names
))
<<
"Invalid parameters file format"
;
uint64_t
sz
;
...
...
python/tvm/contrib/rpc/base.py
View file @
42608dda
...
...
@@ -73,7 +73,7 @@ def sendjson(sock, data):
Python value to be sent.
"""
data
=
json
.
dumps
(
data
)
sock
.
sendall
(
struct
.
pack
(
"
@
i"
,
len
(
data
)))
sock
.
sendall
(
struct
.
pack
(
"
<
i"
,
len
(
data
)))
sock
.
sendall
(
data
.
encode
(
"utf-8"
))
...
...
@@ -90,7 +90,7 @@ def recvjson(sock):
value : object
The value received.
"""
size
=
struct
.
unpack
(
"
@
i"
,
recvall
(
sock
,
4
))[
0
]
size
=
struct
.
unpack
(
"
<
i"
,
recvall
(
sock
,
4
))[
0
]
data
=
json
.
loads
(
py_str
(
recvall
(
sock
,
size
)))
return
data
...
...
python/tvm/contrib/rpc/client.py
View file @
42608dda
...
...
@@ -192,8 +192,8 @@ class TrackerSession(object):
def
_connect
(
self
):
self
.
_sock
=
base
.
connect_with_retry
(
self
.
_addr
)
self
.
_sock
.
sendall
(
struct
.
pack
(
"
@
i"
,
base
.
RPC_TRACKER_MAGIC
))
magic
=
struct
.
unpack
(
"
@
i"
,
base
.
recvall
(
self
.
_sock
,
4
))[
0
]
self
.
_sock
.
sendall
(
struct
.
pack
(
"
<
i"
,
base
.
RPC_TRACKER_MAGIC
))
magic
=
struct
.
unpack
(
"
<
i"
,
base
.
recvall
(
self
.
_sock
,
4
))[
0
]
if
magic
!=
base
.
RPC_TRACKER_MAGIC
:
raise
RuntimeError
(
"
%
s is not RPC Tracker"
%
str
(
self
.
_addr
))
...
...
python/tvm/contrib/rpc/proxy.py
View file @
42608dda
...
...
@@ -58,14 +58,14 @@ class ForwardHandler(object):
def
_init_step
(
self
,
message
):
if
self
.
_magic
is
None
:
assert
len
(
message
)
==
4
self
.
_magic
=
struct
.
unpack
(
'
@
i'
,
message
)[
0
]
self
.
_magic
=
struct
.
unpack
(
'
<
i'
,
message
)[
0
]
if
self
.
_magic
!=
base
.
RPC_MAGIC
:
logging
.
info
(
"Invalid RPC magic from
%
s"
,
self
.
name
())
self
.
close
()
self
.
_init_req_nbytes
=
4
elif
self
.
_rpc_key_length
is
None
:
assert
len
(
message
)
==
4
self
.
_rpc_key_length
=
struct
.
unpack
(
'
@
i'
,
message
)[
0
]
self
.
_rpc_key_length
=
struct
.
unpack
(
'
<
i'
,
message
)[
0
]
self
.
_init_req_nbytes
=
self
.
_rpc_key_length
elif
self
.
rpc_key
is
None
:
assert
len
(
message
)
==
self
.
_rpc_key_length
...
...
@@ -269,12 +269,12 @@ class ProxyServerHandler(object):
lhs
.
forward_proxy
=
rhs
rhs
.
forward_proxy
=
lhs
lhs
.
send_data
(
struct
.
pack
(
'
@
i'
,
base
.
RPC_CODE_SUCCESS
))
lhs
.
send_data
(
struct
.
pack
(
'
@
i'
,
len
(
rhs
.
rpc_key
)))
lhs
.
send_data
(
struct
.
pack
(
'
<
i'
,
base
.
RPC_CODE_SUCCESS
))
lhs
.
send_data
(
struct
.
pack
(
'
<
i'
,
len
(
rhs
.
rpc_key
)))
lhs
.
send_data
(
rhs
.
rpc_key
.
encode
(
"utf-8"
))
rhs
.
send_data
(
struct
.
pack
(
'
@
i'
,
base
.
RPC_CODE_SUCCESS
))
rhs
.
send_data
(
struct
.
pack
(
'
@
i'
,
len
(
lhs
.
rpc_key
)))
rhs
.
send_data
(
struct
.
pack
(
'
<
i'
,
base
.
RPC_CODE_SUCCESS
))
rhs
.
send_data
(
struct
.
pack
(
'
<
i'
,
len
(
lhs
.
rpc_key
)))
rhs
.
send_data
(
lhs
.
rpc_key
.
encode
(
"utf-8"
))
logging
.
info
(
"Pairup connect
%
s and
%
s"
,
lhs
.
name
(),
rhs
.
name
())
...
...
@@ -299,8 +299,8 @@ class ProxyServerHandler(object):
if
self
.
_tracker_conn
is
None
:
self
.
_tracker_conn
=
socket
.
socket
(
socket
.
AF_INET
,
socket
.
SOCK_STREAM
)
self
.
_tracker_conn
.
connect
(
self
.
_tracker_addr
)
self
.
_tracker_conn
.
sendall
(
struct
.
pack
(
"
@
i"
,
base
.
RPC_TRACKER_MAGIC
))
magic
=
struct
.
unpack
(
"
@
i"
,
base
.
recvall
(
self
.
_tracker_conn
,
4
))[
0
]
self
.
_tracker_conn
.
sendall
(
struct
.
pack
(
"
<
i"
,
base
.
RPC_TRACKER_MAGIC
))
magic
=
struct
.
unpack
(
"
<
i"
,
base
.
recvall
(
self
.
_tracker_conn
,
4
))[
0
]
if
magic
!=
base
.
RPC_TRACKER_MAGIC
:
self
.
loop
.
stop
()
raise
RuntimeError
(
"
%
s is not RPC Tracker"
%
str
(
self
.
_tracker_addr
))
...
...
@@ -371,7 +371,7 @@ class ProxyServerHandler(object):
if
handler
.
match_key
in
self
.
_server_pool
:
self
.
_pair_up
(
self
.
_server_pool
.
pop
(
handler
.
match_key
),
handler
)
else
:
handler
.
send_data
(
struct
.
pack
(
'
@
i'
,
base
.
RPC_CODE_MISMATCH
))
handler
.
send_data
(
struct
.
pack
(
'
<
i'
,
base
.
RPC_CODE_MISMATCH
))
handler
.
signal_close
()
def
_handler_ready_proxy_mode
(
self
,
handler
):
...
...
@@ -395,12 +395,12 @@ class ProxyServerHandler(object):
logging
.
info
(
"Timeout client connection
%
s, cannot find match key=
%
s"
,
handler
.
name
(),
key
)
pool_dst
.
pop
(
key
)
handler
.
send_data
(
struct
.
pack
(
'
@
i'
,
base
.
RPC_CODE_MISMATCH
))
handler
.
send_data
(
struct
.
pack
(
'
<
i'
,
base
.
RPC_CODE_MISMATCH
))
handler
.
signal_close
()
self
.
loop
.
call_later
(
timeout
,
cleanup
)
else
:
logging
.
info
(
"Duplicate connection with same key=
%
s"
,
key
)
handler
.
send_data
(
struct
.
pack
(
'
@
i'
,
base
.
RPC_CODE_DUPLICATE
))
handler
.
send_data
(
struct
.
pack
(
'
<
i'
,
base
.
RPC_CODE_DUPLICATE
))
handler
.
signal_close
()
def
handler_ready
(
self
,
handler
):
...
...
@@ -538,13 +538,13 @@ def websocket_proxy_server(url, key=""):
on_message
=
create_on_message
(
conn
)
temp
=
_server_env
(
None
)
# Start connecton
conn
.
write_message
(
struct
.
pack
(
'
@
i'
,
base
.
RPC_MAGIC
),
binary
=
True
)
conn
.
write_message
(
struct
.
pack
(
'
<
i'
,
base
.
RPC_MAGIC
),
binary
=
True
)
key
=
"server:"
+
key
conn
.
write_message
(
struct
.
pack
(
'
@
i'
,
len
(
key
)),
binary
=
True
)
conn
.
write_message
(
struct
.
pack
(
'
<
i'
,
len
(
key
)),
binary
=
True
)
conn
.
write_message
(
key
.
encode
(
"utf-8"
),
binary
=
True
)
msg
=
yield
conn
.
read_message
()
assert
len
(
msg
)
>=
4
magic
=
struct
.
unpack
(
'
@
i'
,
msg
[:
4
])[
0
]
magic
=
struct
.
unpack
(
'
<
i'
,
msg
[:
4
])[
0
]
if
magic
==
base
.
RPC_CODE_DUPLICATE
:
raise
RuntimeError
(
"key:
%
s has already been used in proxy"
%
key
)
elif
magic
==
base
.
RPC_CODE_MISMATCH
:
...
...
python/tvm/contrib/rpc/server.py
View file @
42608dda
...
...
@@ -124,23 +124,23 @@ def _listen_loop(sock, port, rpc_key, tracker_addr, load_library):
unmatch_period_count
=
0
continue
conn
,
addr
=
listen_sock
.
accept
()
magic
=
struct
.
unpack
(
"
@
i"
,
base
.
recvall
(
conn
,
4
))[
0
]
magic
=
struct
.
unpack
(
"
<
i"
,
base
.
recvall
(
conn
,
4
))[
0
]
if
magic
!=
base
.
RPC_MAGIC
:
conn
.
close
()
continue
keylen
=
struct
.
unpack
(
"
@
i"
,
base
.
recvall
(
conn
,
4
))[
0
]
keylen
=
struct
.
unpack
(
"
<
i"
,
base
.
recvall
(
conn
,
4
))[
0
]
key
=
py_str
(
base
.
recvall
(
conn
,
keylen
))
arr
=
key
.
split
()
expect_header
=
"client:"
+
matchkey
server_key
=
"server:"
+
rpc_key
if
arr
[
0
]
!=
expect_header
:
conn
.
sendall
(
struct
.
pack
(
"
@
i"
,
base
.
RPC_CODE_MISMATCH
))
conn
.
sendall
(
struct
.
pack
(
"
<
i"
,
base
.
RPC_CODE_MISMATCH
))
conn
.
close
()
logging
.
info
(
"RPCServer: mismatch key from
%
s"
,
addr
)
continue
else
:
conn
.
sendall
(
struct
.
pack
(
"
@
i"
,
base
.
RPC_CODE_SUCCESS
))
conn
.
sendall
(
struct
.
pack
(
"
@
i"
,
len
(
server_key
)))
conn
.
sendall
(
struct
.
pack
(
"
<
i"
,
base
.
RPC_CODE_SUCCESS
))
conn
.
sendall
(
struct
.
pack
(
"
<
i"
,
len
(
server_key
)))
conn
.
sendall
(
server_key
.
encode
(
"utf-8"
))
return
conn
,
addr
,
_parse_server_opt
(
arr
[
1
:])
...
...
@@ -151,8 +151,8 @@ def _listen_loop(sock, port, rpc_key, tracker_addr, load_library):
# step 1: setup tracker and report to tracker
if
tracker_addr
and
tracker_conn
is
None
:
tracker_conn
=
base
.
connect_with_retry
(
tracker_addr
)
tracker_conn
.
sendall
(
struct
.
pack
(
"
@
i"
,
base
.
RPC_TRACKER_MAGIC
))
magic
=
struct
.
unpack
(
"
@
i"
,
base
.
recvall
(
tracker_conn
,
4
))[
0
]
tracker_conn
.
sendall
(
struct
.
pack
(
"
<
i"
,
base
.
RPC_TRACKER_MAGIC
))
magic
=
struct
.
unpack
(
"
<
i"
,
base
.
recvall
(
tracker_conn
,
4
))[
0
]
if
magic
!=
base
.
RPC_TRACKER_MAGIC
:
raise
RuntimeError
(
"
%
s is not RPC Tracker"
%
str
(
tracker_addr
))
# report status of current queue
...
...
@@ -193,17 +193,17 @@ def _connect_proxy_loop(addr, key, load_library):
try
:
sock
=
socket
.
socket
(
socket
.
AF_INET
,
socket
.
SOCK_STREAM
)
sock
.
connect
(
addr
)
sock
.
sendall
(
struct
.
pack
(
"
@
i"
,
base
.
RPC_MAGIC
))
sock
.
sendall
(
struct
.
pack
(
"
@
i"
,
len
(
key
)))
sock
.
sendall
(
struct
.
pack
(
"
<
i"
,
base
.
RPC_MAGIC
))
sock
.
sendall
(
struct
.
pack
(
"
<
i"
,
len
(
key
)))
sock
.
sendall
(
key
.
encode
(
"utf-8"
))
magic
=
struct
.
unpack
(
"
@
i"
,
base
.
recvall
(
sock
,
4
))[
0
]
magic
=
struct
.
unpack
(
"
<
i"
,
base
.
recvall
(
sock
,
4
))[
0
]
if
magic
==
base
.
RPC_CODE_DUPLICATE
:
raise
RuntimeError
(
"key:
%
s has already been used in proxy"
%
key
)
elif
magic
==
base
.
RPC_CODE_MISMATCH
:
logging
.
info
(
"RPCProxy do not have matching client key
%
s"
,
key
)
elif
magic
!=
base
.
RPC_CODE_SUCCESS
:
raise
RuntimeError
(
"
%
s is not RPC Proxy"
%
str
(
addr
))
keylen
=
struct
.
unpack
(
"
@
i"
,
base
.
recvall
(
sock
,
4
))[
0
]
keylen
=
struct
.
unpack
(
"
<
i"
,
base
.
recvall
(
sock
,
4
))[
0
]
remote_key
=
py_str
(
base
.
recvall
(
sock
,
keylen
))
opts
=
_parse_server_opt
(
remote_key
.
split
()[
1
:])
logging
.
info
(
"RPCProxy connected to
%
s"
,
str
(
addr
))
...
...
python/tvm/contrib/rpc/tracker.py
View file @
42608dda
...
...
@@ -143,11 +143,11 @@ class TCPEventHandler(tornado_util.TCPHandler):
if
len
(
message
)
!=
4
:
logging
.
info
(
"Invalid connection from
%
s"
,
self
.
name
())
self
.
close
()
magic
=
struct
.
unpack
(
'
@
i'
,
message
)[
0
]
magic
=
struct
.
unpack
(
'
<
i'
,
message
)[
0
]
if
magic
!=
RPC_TRACKER_MAGIC
:
logging
.
info
(
"Invalid magic from
%
s"
,
self
.
name
())
self
.
close
()
self
.
write_message
(
struct
.
pack
(
'
@
i'
,
RPC_TRACKER_MAGIC
),
binary
=
True
)
self
.
write_message
(
struct
.
pack
(
'
<
i'
,
RPC_TRACKER_MAGIC
),
binary
=
True
)
self
.
_init_req_nbytes
=
0
def
on_message
(
self
,
message
):
...
...
@@ -168,7 +168,7 @@ class TCPEventHandler(tornado_util.TCPHandler):
while
True
:
if
self
.
_msg_size
==
0
:
if
len
(
self
.
_data
)
>=
4
:
self
.
_msg_size
=
struct
.
unpack
(
'
@
i'
,
self
.
_data
[:
4
])[
0
]
self
.
_msg_size
=
struct
.
unpack
(
'
<
i'
,
self
.
_data
[:
4
])[
0
]
else
:
return
if
self
.
_msg_size
!=
0
and
len
(
self
.
_data
)
>=
self
.
_msg_size
+
4
:
...
...
@@ -184,7 +184,7 @@ class TCPEventHandler(tornado_util.TCPHandler):
"""return value to the output"""
data
=
json
.
dumps
(
data
)
self
.
write_message
(
struct
.
pack
(
'
@
i'
,
len
(
data
)),
binary
=
True
)
struct
.
pack
(
'
<
i'
,
len
(
data
)),
binary
=
True
)
self
.
write_message
(
data
.
encode
(
"utf-8"
),
binary
=
True
)
def
call_handler
(
self
,
args
):
...
...
@@ -355,8 +355,8 @@ class Tracker(object):
def
_stop_tracker
(
self
):
sock
=
socket
.
socket
(
socket
.
AF_INET
,
socket
.
SOCK_STREAM
)
sock
.
connect
((
self
.
host
,
self
.
port
))
sock
.
sendall
(
struct
.
pack
(
"
@
i"
,
base
.
RPC_TRACKER_MAGIC
))
magic
=
struct
.
unpack
(
"
@
i"
,
base
.
recvall
(
sock
,
4
))[
0
]
sock
.
sendall
(
struct
.
pack
(
"
<
i"
,
base
.
RPC_TRACKER_MAGIC
))
magic
=
struct
.
unpack
(
"
<
i"
,
base
.
recvall
(
sock
,
4
))[
0
]
assert
magic
==
base
.
RPC_TRACKER_MAGIC
base
.
sendjson
(
sock
,
[
TrackerCode
.
STOP
,
self
.
stop_key
])
assert
base
.
recvjson
(
sock
)
==
TrackerCode
.
SUCCESS
...
...
src/runtime/file_util.cc
View file @
42608dda
...
...
@@ -4,6 +4,7 @@
*/
#include <dmlc/json.h>
#include <dmlc/logging.h>
#include <tvm/runtime/serializer.h>
#include <fstream>
#include "./file_util.h"
...
...
src/runtime/graph/graph_runtime.cc
View file @
42608dda
...
...
@@ -4,6 +4,7 @@
*/
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/serializer.h>
#include <dmlc/memory_io.h>
#include <dmlc/json.h>
#include <numeric>
...
...
@@ -397,24 +398,25 @@ class GraphRuntime : public ModuleNode {
void
GraphRuntime
::
LoadDLTensor
(
dmlc
::
Stream
*
strm
,
DLTensor
*
dst
)
{
// always use strm->Read to maintain endianness conversion
uint64_t
header
,
reserved
;
CHECK
(
strm
->
Read
(
&
header
,
sizeof
(
header
)
))
CHECK
(
strm
->
Read
(
&
header
))
<<
"Invalid DLTensor file format"
;
CHECK
(
strm
->
Read
(
&
reserved
,
sizeof
(
reserved
)
))
CHECK
(
strm
->
Read
(
&
reserved
))
<<
"Invalid DLTensor file format"
;
CHECK
(
header
==
kTVMNDArrayMagic
)
<<
"Invalid DLTensor file format"
;
DLTensor
tensor
;
CHECK
(
strm
->
Read
(
&
tensor
.
ctx
,
sizeof
(
tensor
.
ctx
)))
CHECK
(
strm
->
Read
(
&
(
tensor
.
ctx
)))
<<
"Invalid DLTensor file format"
;
CHECK
(
strm
->
Read
(
&
tensor
.
ndim
,
sizeof
(
tensor
.
ndim
)))
CHECK
(
strm
->
Read
(
&
(
tensor
.
ndim
)))
<<
"Invalid DLTensor file format"
;
CHECK
(
strm
->
Read
(
&
tensor
.
dtype
,
sizeof
(
tensor
.
dtype
)))
CHECK
(
strm
->
Read
(
&
(
tensor
.
dtype
)))
<<
"Invalid DLTensor file format"
;
std
::
vector
<
int64_t
>
shape
(
tensor
.
ndim
);
if
(
tensor
.
ndim
!=
0
)
{
CHECK
(
strm
->
Read
(
&
shape
[
0
],
sizeof
(
int64_t
)
*
tensor
.
ndim
))
CHECK
(
strm
->
Read
Array
(
&
shape
[
0
],
tensor
.
ndim
))
<<
"Invalid DLTensor file format"
;
}
CHECK_EQ
(
tensor
.
ndim
,
dst
->
ndim
)
<<
"param dimension mismatch"
;
...
...
@@ -425,18 +427,23 @@ void GraphRuntime::LoadDLTensor(dmlc::Stream* strm, DLTensor* dst) {
CHECK_EQ
(
shape
[
i
],
dst
->
shape
[
i
])
<<
"param shape mismatch"
;
}
size_t
bits
=
dst
->
dtype
.
bits
*
dst
->
dtype
.
lanes
;
size_t
size
=
(
bits
+
7
)
/
8
;
size_t
elem_bytes
=
(
bits
+
7
)
/
8
;
size_t
num_elems
=
1
;
for
(
int
i
=
0
;
i
<
dst
->
ndim
;
++
i
)
{
size
*=
dst
->
shape
[
i
];
num_elems
*=
dst
->
shape
[
i
];
}
uint64_t
data_byte_size
;
CHECK
(
strm
->
Read
(
&
data_byte_size
,
sizeof
(
data_byte_size
)
))
CHECK
(
strm
->
Read
(
&
data_byte_size
))
<<
"Invalid DLTensor file format"
;
CHECK
(
data_byte_size
==
size
)
CHECK
_EQ
(
data_byte_size
,
elem_bytes
*
num_elems
)
<<
"Invalid DLTensor file format"
;
std
::
vector
<
uint8_t
>
bytes
(
data_byte_size
+
1
);
CHECK
(
strm
->
Read
(
&
bytes
[
0
],
data_byte_size
))
<<
"Invalid DLTensor file format"
;
// explicitly swap endian when necessary.
if
(
!
DMLC_IO_NO_ENDIAN_SWAP
)
{
dmlc
::
ByteSwap
(
&
bytes
[
0
],
elem_bytes
,
num_elems
);
}
TVM_CCALL
(
TVMArrayCopyFromBytes
(
dst
,
&
bytes
[
0
],
data_byte_size
));
}
...
...
@@ -453,9 +460,8 @@ void GraphRuntime::LoadParams(dmlc::Stream* strm) {
CHECK
(
strm
->
Read
(
&
names
))
<<
"Invalid parameters file format"
;
uint64_t
sz
;
strm
->
Read
(
&
sz
,
sizeof
(
sz
)
);
strm
->
Read
(
&
sz
);
size_t
size
=
static_cast
<
size_t
>
(
sz
);
CHECK
(
size
==
names
.
size
())
<<
"Invalid parameters file format"
;
for
(
size_t
i
=
0
;
i
<
size
;
++
i
)
{
...
...
src/runtime/rpc/rpc_session.cc
View file @
42608dda
This diff is collapsed.
Click to expand it.
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment