Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
42608dda
Commit
42608dda
authored
May 30, 2018
by
tqchen
Committed by
Tianqi Chen
May 30, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[IO] Support cross-endian
parent
0f9dab98
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
145 additions
and
78 deletions
+145
-78
dmlc-core
+1
-1
include/tvm/runtime/serializer.h
+50
-0
nnvm/src/compiler/graph_runtime.cc
+40
-30
python/tvm/contrib/rpc/base.py
+2
-2
python/tvm/contrib/rpc/client.py
+2
-2
python/tvm/contrib/rpc/proxy.py
+14
-14
python/tvm/contrib/rpc/server.py
+11
-11
python/tvm/contrib/rpc/tracker.py
+6
-6
src/runtime/file_util.cc
+1
-0
src/runtime/graph/graph_runtime.cc
+18
-12
src/runtime/rpc/rpc_session.cc
+0
-0
No files found.
dmlc-core
@
9b3f9753
Subproject commit
d3f7fbb53e5b037c0f5bf6bd21871ccc720690cc
Subproject commit
9b3f9753ae81d657743c555e0cacc4e43f0bed2d
include/tvm/runtime/serializer.h
0 → 100644
View file @
42608dda
/*!
* Copyright (c) 2017 by Contributors
* \file tvm/runtime/serializer.h
* \brief Serializer extension to support TVM data types
* Include this file to enable serialization of DLDataType, DLContext
*/
#ifndef TVM_RUNTIME_SERIALIZER_H_
#define TVM_RUNTIME_SERIALIZER_H_
#include <dmlc/io.h>
#include <dmlc/serializer.h>
#include "./c_runtime_api.h"
namespace
dmlc
{
namespace
serializer
{
template
<>
struct
Handler
<
DLDataType
>
{
inline
static
void
Write
(
Stream
*
strm
,
const
DLDataType
&
dtype
)
{
Handler
<
uint8_t
>::
Write
(
strm
,
dtype
.
code
);
Handler
<
uint8_t
>::
Write
(
strm
,
dtype
.
bits
);
Handler
<
uint16_t
>::
Write
(
strm
,
dtype
.
lanes
);
}
inline
static
bool
Read
(
Stream
*
strm
,
DLDataType
*
dtype
)
{
if
(
!
Handler
<
uint8_t
>::
Read
(
strm
,
&
(
dtype
->
code
)))
return
false
;
if
(
!
Handler
<
uint8_t
>::
Read
(
strm
,
&
(
dtype
->
bits
)))
return
false
;
if
(
!
Handler
<
uint16_t
>::
Read
(
strm
,
&
(
dtype
->
lanes
)))
return
false
;
return
true
;
}
};
template
<>
struct
Handler
<
DLContext
>
{
inline
static
void
Write
(
Stream
*
strm
,
const
DLContext
&
ctx
)
{
int32_t
device_type
=
static_cast
<
int32_t
>
(
ctx
.
device_type
);
Handler
<
int32_t
>::
Write
(
strm
,
device_type
);
Handler
<
int32_t
>::
Write
(
strm
,
ctx
.
device_id
);
}
inline
static
bool
Read
(
Stream
*
strm
,
DLContext
*
ctx
)
{
int32_t
device_type
=
0
;
if
(
!
Handler
<
int32_t
>::
Read
(
strm
,
&
(
device_type
)))
return
false
;
ctx
->
device_type
=
static_cast
<
DLDeviceType
>
(
device_type
);
if
(
!
Handler
<
int32_t
>::
Read
(
strm
,
&
(
ctx
->
device_id
)))
return
false
;
return
true
;
}
};
}
// namespace serializer
}
// namespace dmlc
#endif // TVM_RUNTIME_SERIALIZER_H_
nnvm/src/compiler/graph_runtime.cc
View file @
42608dda
...
@@ -7,6 +7,7 @@
...
@@ -7,6 +7,7 @@
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/serializer.h>
#include "./graph_runtime.h"
#include "./graph_runtime.h"
namespace
nnvm
{
namespace
nnvm
{
...
@@ -38,46 +39,53 @@ NNVM_REGISTER_OP(tvm_op)
...
@@ -38,46 +39,53 @@ NNVM_REGISTER_OP(tvm_op)
bool
SaveDLTensor
(
dmlc
::
Stream
*
strm
,
DLTensor
*
tensor
)
{
bool
SaveDLTensor
(
dmlc
::
Stream
*
strm
,
DLTensor
*
tensor
)
{
uint64_t
header
=
kTVMNDArrayMagic
,
reserved
=
0
;
uint64_t
header
=
kTVMNDArrayMagic
,
reserved
=
0
;
strm
->
Write
(
&
header
,
sizeof
(
header
));
strm
->
Write
(
header
);
strm
->
Write
(
&
reserved
,
sizeof
(
reserved
));
strm
->
Write
(
reserved
);
strm
->
Write
(
tensor
->
ctx
);
strm
->
Write
(
&
tensor
->
ctx
,
sizeof
(
tensor
->
ctx
));
strm
->
Write
(
tensor
->
ndim
);
strm
->
Write
(
&
tensor
->
ndim
,
sizeof
(
tensor
->
ndim
));
strm
->
Write
(
tensor
->
dtype
);
strm
->
Write
(
&
tensor
->
dtype
,
sizeof
(
tensor
->
dtype
));
int
ndim
=
tensor
->
ndim
;
int
ndim
=
tensor
->
ndim
;
strm
->
Write
(
tensor
->
shape
,
sizeof
(
int64_t
)
*
ndim
);
strm
->
Write
Array
(
tensor
->
shape
,
ndim
);
int
type_
size
=
tensor
->
dtype
.
bits
/
8
;
int
type_
bytes
=
tensor
->
dtype
.
bits
/
8
;
int64_t
size
=
1
;
int64_t
num_elems
=
1
;
for
(
int
i
=
0
;
i
<
ndim
;
++
i
)
{
for
(
int
i
=
0
;
i
<
ndim
;
++
i
)
{
size
*=
tensor
->
shape
[
i
];
num_elems
*=
tensor
->
shape
[
i
];
}
}
int64_t
data_byte_size
=
type_size
*
size
;
int64_t
data_byte_size
=
type_bytes
*
num_elems
;
strm
->
Write
(
&
data_byte_size
,
sizeof
(
data_byte_size
));
strm
->
Write
(
data_byte_size
);
// handle endianness of data correctly.
if
(
DMLC_IO_NO_ENDIAN_SWAP
)
{
strm
->
Write
(
tensor
->
data
,
data_byte_size
);
strm
->
Write
(
tensor
->
data
,
data_byte_size
);
}
else
{
uint8_t
*
dptr
=
reinterpret_cast
<
uint8_t
*>
(
tensor
->
data
);
std
::
vector
<
uint8_t
>
bytes
(
dptr
,
dptr
+
data_byte_size
);
dmlc
::
ByteSwap
(
dmlc
::
BeginPtr
(
bytes
),
type_bytes
,
num_elems
);
strm
->
Write
(
dmlc
::
BeginPtr
(
bytes
),
data_byte_size
);
}
return
true
;
return
true
;
}
}
DLTensor
*
LoadDLTensor
(
dmlc
::
Stream
*
strm
)
{
DLTensor
*
LoadDLTensor
(
dmlc
::
Stream
*
strm
)
{
uint64_t
header
,
reserved
;
uint64_t
header
,
reserved
;
CHECK
(
strm
->
Read
(
&
header
,
sizeof
(
header
)
))
CHECK
(
strm
->
Read
(
&
header
))
<<
"Invalid DLTensor file format"
;
<<
"Invalid DLTensor file format"
;
CHECK
(
strm
->
Read
(
&
reserved
,
sizeof
(
reserved
)
))
CHECK
(
strm
->
Read
(
&
reserved
))
<<
"Invalid DLTensor file format"
;
<<
"Invalid DLTensor file format"
;
CHECK
(
header
==
kTVMNDArrayMagic
)
CHECK
(
header
==
kTVMNDArrayMagic
)
<<
"Invalid DLTensor file format"
;
<<
"Invalid DLTensor file format"
;
DLTensor
tensor
;
DLTensor
tensor
;
CHECK
(
strm
->
Read
(
&
tensor
.
ctx
,
sizeof
(
tensor
.
ctx
)))
CHECK
(
strm
->
Read
(
&
(
tensor
.
ctx
)))
<<
"Invalid DLTensor file format"
;
<<
"Invalid DLTensor file format"
;
CHECK
(
strm
->
Read
(
&
tensor
.
ndim
,
sizeof
(
tensor
.
ndim
)))
CHECK
(
strm
->
Read
(
&
(
tensor
.
ndim
)))
<<
"Invalid DLTensor file format"
;
<<
"Invalid DLTensor file format"
;
CHECK
(
strm
->
Read
(
&
tensor
.
dtype
,
sizeof
(
tensor
.
dtype
)))
CHECK
(
strm
->
Read
(
&
(
tensor
.
dtype
)))
<<
"Invalid DLTensor file format"
;
<<
"Invalid DLTensor file format"
;
std
::
vector
<
int64_t
>
shape
(
tensor
.
ndim
);
std
::
vector
<
int64_t
>
shape
(
tensor
.
ndim
);
CHECK
(
strm
->
Read
(
&
shape
[
0
],
sizeof
(
int64_t
)
*
tensor
.
ndim
))
if
(
tensor
.
ndim
!=
0
)
{
CHECK
(
strm
->
ReadArray
(
&
shape
[
0
],
tensor
.
ndim
))
<<
"Invalid DLTensor file format"
;
<<
"Invalid DLTensor file format"
;
}
DLTensor
*
ret
;
DLTensor
*
ret
;
CHECK_EQ
(
TVMArrayAlloc
(
shape
.
data
(),
CHECK_EQ
(
TVMArrayAlloc
(
shape
.
data
(),
tensor
.
ndim
,
tensor
.
ndim
,
...
@@ -87,18 +95,21 @@ DLTensor* LoadDLTensor(dmlc::Stream* strm) {
...
@@ -87,18 +95,21 @@ DLTensor* LoadDLTensor(dmlc::Stream* strm) {
static_cast
<
int
>
(
tensor
.
ctx
.
device_type
),
static_cast
<
int
>
(
tensor
.
ctx
.
device_type
),
tensor
.
ctx
.
device_id
,
tensor
.
ctx
.
device_id
,
&
ret
),
0
)
<<
TVMGetLastError
();
&
ret
),
0
)
<<
TVMGetLastError
();
int64_t
size
=
1
;
int64_t
num_elems
=
1
;
int
type_size
=
ret
->
dtype
.
bits
/
8
;
int
elem_bytes
=
(
ret
->
dtype
.
bits
+
7
)
/
8
;
for
(
int
i
=
0
;
i
<
ret
->
ndim
;
++
i
)
{
for
(
int
i
=
0
;
i
<
ret
->
ndim
;
++
i
)
{
size
*=
ret
->
shape
[
i
];
num_elems
*=
ret
->
shape
[
i
];
}
}
int64_t
data_byte_size
;
int64_t
data_byte_size
;
CHECK
(
strm
->
Read
(
&
data_byte_size
,
sizeof
(
data_byte_size
)
))
CHECK
(
strm
->
Read
(
&
data_byte_size
))
<<
"Invalid DLTensor file format"
;
<<
"Invalid DLTensor file format"
;
CHECK
(
data_byte_size
==
type_size
*
size
)
CHECK
(
data_byte_size
==
num_elems
*
elem_bytes
)
<<
"Invalid DLTensor file format"
;
<<
"Invalid DLTensor file format"
;
CHECK
(
strm
->
Read
(
ret
->
data
,
type_size
*
size
))
CHECK
(
strm
->
Read
(
ret
->
data
,
data_byte_
size
))
<<
"Invalid DLTensor file format"
;
<<
"Invalid DLTensor file format"
;
if
(
!
DMLC_IO_NO_ENDIAN_SWAP
)
{
dmlc
::
ByteSwap
(
ret
->
data
,
elem_bytes
,
num_elems
);
}
return
ret
;
return
ret
;
}
}
...
@@ -118,12 +129,12 @@ TVM_REGISTER_GLOBAL("nnvm.compiler._save_param_dict")
...
@@ -118,12 +129,12 @@ TVM_REGISTER_GLOBAL("nnvm.compiler._save_param_dict")
dmlc
::
MemoryStringStream
strm
(
&
bytes
);
dmlc
::
MemoryStringStream
strm
(
&
bytes
);
dmlc
::
Stream
*
fo
=
&
strm
;
dmlc
::
Stream
*
fo
=
&
strm
;
uint64_t
header
=
kTVMNDArrayListMagic
,
reserved
=
0
;
uint64_t
header
=
kTVMNDArrayListMagic
,
reserved
=
0
;
fo
->
Write
(
&
header
,
sizeof
(
header
)
);
fo
->
Write
(
header
);
fo
->
Write
(
&
reserved
,
sizeof
(
reserved
)
);
fo
->
Write
(
reserved
);
fo
->
Write
(
names
);
fo
->
Write
(
names
);
{
{
uint64_t
sz
=
static_cast
<
uint64_t
>
(
arrays
.
size
());
uint64_t
sz
=
static_cast
<
uint64_t
>
(
arrays
.
size
());
fo
->
Write
(
&
sz
,
sizeof
(
sz
)
);
fo
->
Write
(
sz
);
for
(
size_t
i
=
0
;
i
<
sz
;
++
i
)
{
for
(
size_t
i
=
0
;
i
<
sz
;
++
i
)
{
SaveDLTensor
(
fo
,
arrays
[
i
]);
SaveDLTensor
(
fo
,
arrays
[
i
]);
}
}
...
@@ -150,7 +161,6 @@ TVM_REGISTER_GLOBAL("nnvm.compiler._load_param_dict")
...
@@ -150,7 +161,6 @@ TVM_REGISTER_GLOBAL("nnvm.compiler._load_param_dict")
<<
"Invalid parameters file format"
;
<<
"Invalid parameters file format"
;
CHECK
(
strm
->
Read
(
&
reserved
))
CHECK
(
strm
->
Read
(
&
reserved
))
<<
"Invalid parameters file format"
;
<<
"Invalid parameters file format"
;
CHECK
(
strm
->
Read
(
&
names
))
CHECK
(
strm
->
Read
(
&
names
))
<<
"Invalid parameters file format"
;
<<
"Invalid parameters file format"
;
uint64_t
sz
;
uint64_t
sz
;
...
...
python/tvm/contrib/rpc/base.py
View file @
42608dda
...
@@ -73,7 +73,7 @@ def sendjson(sock, data):
...
@@ -73,7 +73,7 @@ def sendjson(sock, data):
Python value to be sent.
Python value to be sent.
"""
"""
data
=
json
.
dumps
(
data
)
data
=
json
.
dumps
(
data
)
sock
.
sendall
(
struct
.
pack
(
"
@
i"
,
len
(
data
)))
sock
.
sendall
(
struct
.
pack
(
"
<
i"
,
len
(
data
)))
sock
.
sendall
(
data
.
encode
(
"utf-8"
))
sock
.
sendall
(
data
.
encode
(
"utf-8"
))
...
@@ -90,7 +90,7 @@ def recvjson(sock):
...
@@ -90,7 +90,7 @@ def recvjson(sock):
value : object
value : object
The value received.
The value received.
"""
"""
size
=
struct
.
unpack
(
"
@
i"
,
recvall
(
sock
,
4
))[
0
]
size
=
struct
.
unpack
(
"
<
i"
,
recvall
(
sock
,
4
))[
0
]
data
=
json
.
loads
(
py_str
(
recvall
(
sock
,
size
)))
data
=
json
.
loads
(
py_str
(
recvall
(
sock
,
size
)))
return
data
return
data
...
...
python/tvm/contrib/rpc/client.py
View file @
42608dda
...
@@ -192,8 +192,8 @@ class TrackerSession(object):
...
@@ -192,8 +192,8 @@ class TrackerSession(object):
def
_connect
(
self
):
def
_connect
(
self
):
self
.
_sock
=
base
.
connect_with_retry
(
self
.
_addr
)
self
.
_sock
=
base
.
connect_with_retry
(
self
.
_addr
)
self
.
_sock
.
sendall
(
struct
.
pack
(
"
@
i"
,
base
.
RPC_TRACKER_MAGIC
))
self
.
_sock
.
sendall
(
struct
.
pack
(
"
<
i"
,
base
.
RPC_TRACKER_MAGIC
))
magic
=
struct
.
unpack
(
"
@
i"
,
base
.
recvall
(
self
.
_sock
,
4
))[
0
]
magic
=
struct
.
unpack
(
"
<
i"
,
base
.
recvall
(
self
.
_sock
,
4
))[
0
]
if
magic
!=
base
.
RPC_TRACKER_MAGIC
:
if
magic
!=
base
.
RPC_TRACKER_MAGIC
:
raise
RuntimeError
(
"
%
s is not RPC Tracker"
%
str
(
self
.
_addr
))
raise
RuntimeError
(
"
%
s is not RPC Tracker"
%
str
(
self
.
_addr
))
...
...
python/tvm/contrib/rpc/proxy.py
View file @
42608dda
...
@@ -58,14 +58,14 @@ class ForwardHandler(object):
...
@@ -58,14 +58,14 @@ class ForwardHandler(object):
def
_init_step
(
self
,
message
):
def
_init_step
(
self
,
message
):
if
self
.
_magic
is
None
:
if
self
.
_magic
is
None
:
assert
len
(
message
)
==
4
assert
len
(
message
)
==
4
self
.
_magic
=
struct
.
unpack
(
'
@
i'
,
message
)[
0
]
self
.
_magic
=
struct
.
unpack
(
'
<
i'
,
message
)[
0
]
if
self
.
_magic
!=
base
.
RPC_MAGIC
:
if
self
.
_magic
!=
base
.
RPC_MAGIC
:
logging
.
info
(
"Invalid RPC magic from
%
s"
,
self
.
name
())
logging
.
info
(
"Invalid RPC magic from
%
s"
,
self
.
name
())
self
.
close
()
self
.
close
()
self
.
_init_req_nbytes
=
4
self
.
_init_req_nbytes
=
4
elif
self
.
_rpc_key_length
is
None
:
elif
self
.
_rpc_key_length
is
None
:
assert
len
(
message
)
==
4
assert
len
(
message
)
==
4
self
.
_rpc_key_length
=
struct
.
unpack
(
'
@
i'
,
message
)[
0
]
self
.
_rpc_key_length
=
struct
.
unpack
(
'
<
i'
,
message
)[
0
]
self
.
_init_req_nbytes
=
self
.
_rpc_key_length
self
.
_init_req_nbytes
=
self
.
_rpc_key_length
elif
self
.
rpc_key
is
None
:
elif
self
.
rpc_key
is
None
:
assert
len
(
message
)
==
self
.
_rpc_key_length
assert
len
(
message
)
==
self
.
_rpc_key_length
...
@@ -269,12 +269,12 @@ class ProxyServerHandler(object):
...
@@ -269,12 +269,12 @@ class ProxyServerHandler(object):
lhs
.
forward_proxy
=
rhs
lhs
.
forward_proxy
=
rhs
rhs
.
forward_proxy
=
lhs
rhs
.
forward_proxy
=
lhs
lhs
.
send_data
(
struct
.
pack
(
'
@
i'
,
base
.
RPC_CODE_SUCCESS
))
lhs
.
send_data
(
struct
.
pack
(
'
<
i'
,
base
.
RPC_CODE_SUCCESS
))
lhs
.
send_data
(
struct
.
pack
(
'
@
i'
,
len
(
rhs
.
rpc_key
)))
lhs
.
send_data
(
struct
.
pack
(
'
<
i'
,
len
(
rhs
.
rpc_key
)))
lhs
.
send_data
(
rhs
.
rpc_key
.
encode
(
"utf-8"
))
lhs
.
send_data
(
rhs
.
rpc_key
.
encode
(
"utf-8"
))
rhs
.
send_data
(
struct
.
pack
(
'
@
i'
,
base
.
RPC_CODE_SUCCESS
))
rhs
.
send_data
(
struct
.
pack
(
'
<
i'
,
base
.
RPC_CODE_SUCCESS
))
rhs
.
send_data
(
struct
.
pack
(
'
@
i'
,
len
(
lhs
.
rpc_key
)))
rhs
.
send_data
(
struct
.
pack
(
'
<
i'
,
len
(
lhs
.
rpc_key
)))
rhs
.
send_data
(
lhs
.
rpc_key
.
encode
(
"utf-8"
))
rhs
.
send_data
(
lhs
.
rpc_key
.
encode
(
"utf-8"
))
logging
.
info
(
"Pairup connect
%
s and
%
s"
,
lhs
.
name
(),
rhs
.
name
())
logging
.
info
(
"Pairup connect
%
s and
%
s"
,
lhs
.
name
(),
rhs
.
name
())
...
@@ -299,8 +299,8 @@ class ProxyServerHandler(object):
...
@@ -299,8 +299,8 @@ class ProxyServerHandler(object):
if
self
.
_tracker_conn
is
None
:
if
self
.
_tracker_conn
is
None
:
self
.
_tracker_conn
=
socket
.
socket
(
socket
.
AF_INET
,
socket
.
SOCK_STREAM
)
self
.
_tracker_conn
=
socket
.
socket
(
socket
.
AF_INET
,
socket
.
SOCK_STREAM
)
self
.
_tracker_conn
.
connect
(
self
.
_tracker_addr
)
self
.
_tracker_conn
.
connect
(
self
.
_tracker_addr
)
self
.
_tracker_conn
.
sendall
(
struct
.
pack
(
"
@
i"
,
base
.
RPC_TRACKER_MAGIC
))
self
.
_tracker_conn
.
sendall
(
struct
.
pack
(
"
<
i"
,
base
.
RPC_TRACKER_MAGIC
))
magic
=
struct
.
unpack
(
"
@
i"
,
base
.
recvall
(
self
.
_tracker_conn
,
4
))[
0
]
magic
=
struct
.
unpack
(
"
<
i"
,
base
.
recvall
(
self
.
_tracker_conn
,
4
))[
0
]
if
magic
!=
base
.
RPC_TRACKER_MAGIC
:
if
magic
!=
base
.
RPC_TRACKER_MAGIC
:
self
.
loop
.
stop
()
self
.
loop
.
stop
()
raise
RuntimeError
(
"
%
s is not RPC Tracker"
%
str
(
self
.
_tracker_addr
))
raise
RuntimeError
(
"
%
s is not RPC Tracker"
%
str
(
self
.
_tracker_addr
))
...
@@ -371,7 +371,7 @@ class ProxyServerHandler(object):
...
@@ -371,7 +371,7 @@ class ProxyServerHandler(object):
if
handler
.
match_key
in
self
.
_server_pool
:
if
handler
.
match_key
in
self
.
_server_pool
:
self
.
_pair_up
(
self
.
_server_pool
.
pop
(
handler
.
match_key
),
handler
)
self
.
_pair_up
(
self
.
_server_pool
.
pop
(
handler
.
match_key
),
handler
)
else
:
else
:
handler
.
send_data
(
struct
.
pack
(
'
@
i'
,
base
.
RPC_CODE_MISMATCH
))
handler
.
send_data
(
struct
.
pack
(
'
<
i'
,
base
.
RPC_CODE_MISMATCH
))
handler
.
signal_close
()
handler
.
signal_close
()
def
_handler_ready_proxy_mode
(
self
,
handler
):
def
_handler_ready_proxy_mode
(
self
,
handler
):
...
@@ -395,12 +395,12 @@ class ProxyServerHandler(object):
...
@@ -395,12 +395,12 @@ class ProxyServerHandler(object):
logging
.
info
(
"Timeout client connection
%
s, cannot find match key=
%
s"
,
logging
.
info
(
"Timeout client connection
%
s, cannot find match key=
%
s"
,
handler
.
name
(),
key
)
handler
.
name
(),
key
)
pool_dst
.
pop
(
key
)
pool_dst
.
pop
(
key
)
handler
.
send_data
(
struct
.
pack
(
'
@
i'
,
base
.
RPC_CODE_MISMATCH
))
handler
.
send_data
(
struct
.
pack
(
'
<
i'
,
base
.
RPC_CODE_MISMATCH
))
handler
.
signal_close
()
handler
.
signal_close
()
self
.
loop
.
call_later
(
timeout
,
cleanup
)
self
.
loop
.
call_later
(
timeout
,
cleanup
)
else
:
else
:
logging
.
info
(
"Duplicate connection with same key=
%
s"
,
key
)
logging
.
info
(
"Duplicate connection with same key=
%
s"
,
key
)
handler
.
send_data
(
struct
.
pack
(
'
@
i'
,
base
.
RPC_CODE_DUPLICATE
))
handler
.
send_data
(
struct
.
pack
(
'
<
i'
,
base
.
RPC_CODE_DUPLICATE
))
handler
.
signal_close
()
handler
.
signal_close
()
def
handler_ready
(
self
,
handler
):
def
handler_ready
(
self
,
handler
):
...
@@ -538,13 +538,13 @@ def websocket_proxy_server(url, key=""):
...
@@ -538,13 +538,13 @@ def websocket_proxy_server(url, key=""):
on_message
=
create_on_message
(
conn
)
on_message
=
create_on_message
(
conn
)
temp
=
_server_env
(
None
)
temp
=
_server_env
(
None
)
# Start connecton
# Start connecton
conn
.
write_message
(
struct
.
pack
(
'
@
i'
,
base
.
RPC_MAGIC
),
binary
=
True
)
conn
.
write_message
(
struct
.
pack
(
'
<
i'
,
base
.
RPC_MAGIC
),
binary
=
True
)
key
=
"server:"
+
key
key
=
"server:"
+
key
conn
.
write_message
(
struct
.
pack
(
'
@
i'
,
len
(
key
)),
binary
=
True
)
conn
.
write_message
(
struct
.
pack
(
'
<
i'
,
len
(
key
)),
binary
=
True
)
conn
.
write_message
(
key
.
encode
(
"utf-8"
),
binary
=
True
)
conn
.
write_message
(
key
.
encode
(
"utf-8"
),
binary
=
True
)
msg
=
yield
conn
.
read_message
()
msg
=
yield
conn
.
read_message
()
assert
len
(
msg
)
>=
4
assert
len
(
msg
)
>=
4
magic
=
struct
.
unpack
(
'
@
i'
,
msg
[:
4
])[
0
]
magic
=
struct
.
unpack
(
'
<
i'
,
msg
[:
4
])[
0
]
if
magic
==
base
.
RPC_CODE_DUPLICATE
:
if
magic
==
base
.
RPC_CODE_DUPLICATE
:
raise
RuntimeError
(
"key:
%
s has already been used in proxy"
%
key
)
raise
RuntimeError
(
"key:
%
s has already been used in proxy"
%
key
)
elif
magic
==
base
.
RPC_CODE_MISMATCH
:
elif
magic
==
base
.
RPC_CODE_MISMATCH
:
...
...
python/tvm/contrib/rpc/server.py
View file @
42608dda
...
@@ -124,23 +124,23 @@ def _listen_loop(sock, port, rpc_key, tracker_addr, load_library):
...
@@ -124,23 +124,23 @@ def _listen_loop(sock, port, rpc_key, tracker_addr, load_library):
unmatch_period_count
=
0
unmatch_period_count
=
0
continue
continue
conn
,
addr
=
listen_sock
.
accept
()
conn
,
addr
=
listen_sock
.
accept
()
magic
=
struct
.
unpack
(
"
@
i"
,
base
.
recvall
(
conn
,
4
))[
0
]
magic
=
struct
.
unpack
(
"
<
i"
,
base
.
recvall
(
conn
,
4
))[
0
]
if
magic
!=
base
.
RPC_MAGIC
:
if
magic
!=
base
.
RPC_MAGIC
:
conn
.
close
()
conn
.
close
()
continue
continue
keylen
=
struct
.
unpack
(
"
@
i"
,
base
.
recvall
(
conn
,
4
))[
0
]
keylen
=
struct
.
unpack
(
"
<
i"
,
base
.
recvall
(
conn
,
4
))[
0
]
key
=
py_str
(
base
.
recvall
(
conn
,
keylen
))
key
=
py_str
(
base
.
recvall
(
conn
,
keylen
))
arr
=
key
.
split
()
arr
=
key
.
split
()
expect_header
=
"client:"
+
matchkey
expect_header
=
"client:"
+
matchkey
server_key
=
"server:"
+
rpc_key
server_key
=
"server:"
+
rpc_key
if
arr
[
0
]
!=
expect_header
:
if
arr
[
0
]
!=
expect_header
:
conn
.
sendall
(
struct
.
pack
(
"
@
i"
,
base
.
RPC_CODE_MISMATCH
))
conn
.
sendall
(
struct
.
pack
(
"
<
i"
,
base
.
RPC_CODE_MISMATCH
))
conn
.
close
()
conn
.
close
()
logging
.
info
(
"RPCServer: mismatch key from
%
s"
,
addr
)
logging
.
info
(
"RPCServer: mismatch key from
%
s"
,
addr
)
continue
continue
else
:
else
:
conn
.
sendall
(
struct
.
pack
(
"
@
i"
,
base
.
RPC_CODE_SUCCESS
))
conn
.
sendall
(
struct
.
pack
(
"
<
i"
,
base
.
RPC_CODE_SUCCESS
))
conn
.
sendall
(
struct
.
pack
(
"
@
i"
,
len
(
server_key
)))
conn
.
sendall
(
struct
.
pack
(
"
<
i"
,
len
(
server_key
)))
conn
.
sendall
(
server_key
.
encode
(
"utf-8"
))
conn
.
sendall
(
server_key
.
encode
(
"utf-8"
))
return
conn
,
addr
,
_parse_server_opt
(
arr
[
1
:])
return
conn
,
addr
,
_parse_server_opt
(
arr
[
1
:])
...
@@ -151,8 +151,8 @@ def _listen_loop(sock, port, rpc_key, tracker_addr, load_library):
...
@@ -151,8 +151,8 @@ def _listen_loop(sock, port, rpc_key, tracker_addr, load_library):
# step 1: setup tracker and report to tracker
# step 1: setup tracker and report to tracker
if
tracker_addr
and
tracker_conn
is
None
:
if
tracker_addr
and
tracker_conn
is
None
:
tracker_conn
=
base
.
connect_with_retry
(
tracker_addr
)
tracker_conn
=
base
.
connect_with_retry
(
tracker_addr
)
tracker_conn
.
sendall
(
struct
.
pack
(
"
@
i"
,
base
.
RPC_TRACKER_MAGIC
))
tracker_conn
.
sendall
(
struct
.
pack
(
"
<
i"
,
base
.
RPC_TRACKER_MAGIC
))
magic
=
struct
.
unpack
(
"
@
i"
,
base
.
recvall
(
tracker_conn
,
4
))[
0
]
magic
=
struct
.
unpack
(
"
<
i"
,
base
.
recvall
(
tracker_conn
,
4
))[
0
]
if
magic
!=
base
.
RPC_TRACKER_MAGIC
:
if
magic
!=
base
.
RPC_TRACKER_MAGIC
:
raise
RuntimeError
(
"
%
s is not RPC Tracker"
%
str
(
tracker_addr
))
raise
RuntimeError
(
"
%
s is not RPC Tracker"
%
str
(
tracker_addr
))
# report status of current queue
# report status of current queue
...
@@ -193,17 +193,17 @@ def _connect_proxy_loop(addr, key, load_library):
...
@@ -193,17 +193,17 @@ def _connect_proxy_loop(addr, key, load_library):
try
:
try
:
sock
=
socket
.
socket
(
socket
.
AF_INET
,
socket
.
SOCK_STREAM
)
sock
=
socket
.
socket
(
socket
.
AF_INET
,
socket
.
SOCK_STREAM
)
sock
.
connect
(
addr
)
sock
.
connect
(
addr
)
sock
.
sendall
(
struct
.
pack
(
"
@
i"
,
base
.
RPC_MAGIC
))
sock
.
sendall
(
struct
.
pack
(
"
<
i"
,
base
.
RPC_MAGIC
))
sock
.
sendall
(
struct
.
pack
(
"
@
i"
,
len
(
key
)))
sock
.
sendall
(
struct
.
pack
(
"
<
i"
,
len
(
key
)))
sock
.
sendall
(
key
.
encode
(
"utf-8"
))
sock
.
sendall
(
key
.
encode
(
"utf-8"
))
magic
=
struct
.
unpack
(
"
@
i"
,
base
.
recvall
(
sock
,
4
))[
0
]
magic
=
struct
.
unpack
(
"
<
i"
,
base
.
recvall
(
sock
,
4
))[
0
]
if
magic
==
base
.
RPC_CODE_DUPLICATE
:
if
magic
==
base
.
RPC_CODE_DUPLICATE
:
raise
RuntimeError
(
"key:
%
s has already been used in proxy"
%
key
)
raise
RuntimeError
(
"key:
%
s has already been used in proxy"
%
key
)
elif
magic
==
base
.
RPC_CODE_MISMATCH
:
elif
magic
==
base
.
RPC_CODE_MISMATCH
:
logging
.
info
(
"RPCProxy do not have matching client key
%
s"
,
key
)
logging
.
info
(
"RPCProxy do not have matching client key
%
s"
,
key
)
elif
magic
!=
base
.
RPC_CODE_SUCCESS
:
elif
magic
!=
base
.
RPC_CODE_SUCCESS
:
raise
RuntimeError
(
"
%
s is not RPC Proxy"
%
str
(
addr
))
raise
RuntimeError
(
"
%
s is not RPC Proxy"
%
str
(
addr
))
keylen
=
struct
.
unpack
(
"
@
i"
,
base
.
recvall
(
sock
,
4
))[
0
]
keylen
=
struct
.
unpack
(
"
<
i"
,
base
.
recvall
(
sock
,
4
))[
0
]
remote_key
=
py_str
(
base
.
recvall
(
sock
,
keylen
))
remote_key
=
py_str
(
base
.
recvall
(
sock
,
keylen
))
opts
=
_parse_server_opt
(
remote_key
.
split
()[
1
:])
opts
=
_parse_server_opt
(
remote_key
.
split
()[
1
:])
logging
.
info
(
"RPCProxy connected to
%
s"
,
str
(
addr
))
logging
.
info
(
"RPCProxy connected to
%
s"
,
str
(
addr
))
...
...
python/tvm/contrib/rpc/tracker.py
View file @
42608dda
...
@@ -143,11 +143,11 @@ class TCPEventHandler(tornado_util.TCPHandler):
...
@@ -143,11 +143,11 @@ class TCPEventHandler(tornado_util.TCPHandler):
if
len
(
message
)
!=
4
:
if
len
(
message
)
!=
4
:
logging
.
info
(
"Invalid connection from
%
s"
,
self
.
name
())
logging
.
info
(
"Invalid connection from
%
s"
,
self
.
name
())
self
.
close
()
self
.
close
()
magic
=
struct
.
unpack
(
'
@
i'
,
message
)[
0
]
magic
=
struct
.
unpack
(
'
<
i'
,
message
)[
0
]
if
magic
!=
RPC_TRACKER_MAGIC
:
if
magic
!=
RPC_TRACKER_MAGIC
:
logging
.
info
(
"Invalid magic from
%
s"
,
self
.
name
())
logging
.
info
(
"Invalid magic from
%
s"
,
self
.
name
())
self
.
close
()
self
.
close
()
self
.
write_message
(
struct
.
pack
(
'
@
i'
,
RPC_TRACKER_MAGIC
),
binary
=
True
)
self
.
write_message
(
struct
.
pack
(
'
<
i'
,
RPC_TRACKER_MAGIC
),
binary
=
True
)
self
.
_init_req_nbytes
=
0
self
.
_init_req_nbytes
=
0
def
on_message
(
self
,
message
):
def
on_message
(
self
,
message
):
...
@@ -168,7 +168,7 @@ class TCPEventHandler(tornado_util.TCPHandler):
...
@@ -168,7 +168,7 @@ class TCPEventHandler(tornado_util.TCPHandler):
while
True
:
while
True
:
if
self
.
_msg_size
==
0
:
if
self
.
_msg_size
==
0
:
if
len
(
self
.
_data
)
>=
4
:
if
len
(
self
.
_data
)
>=
4
:
self
.
_msg_size
=
struct
.
unpack
(
'
@
i'
,
self
.
_data
[:
4
])[
0
]
self
.
_msg_size
=
struct
.
unpack
(
'
<
i'
,
self
.
_data
[:
4
])[
0
]
else
:
else
:
return
return
if
self
.
_msg_size
!=
0
and
len
(
self
.
_data
)
>=
self
.
_msg_size
+
4
:
if
self
.
_msg_size
!=
0
and
len
(
self
.
_data
)
>=
self
.
_msg_size
+
4
:
...
@@ -184,7 +184,7 @@ class TCPEventHandler(tornado_util.TCPHandler):
...
@@ -184,7 +184,7 @@ class TCPEventHandler(tornado_util.TCPHandler):
"""return value to the output"""
"""return value to the output"""
data
=
json
.
dumps
(
data
)
data
=
json
.
dumps
(
data
)
self
.
write_message
(
self
.
write_message
(
struct
.
pack
(
'
@
i'
,
len
(
data
)),
binary
=
True
)
struct
.
pack
(
'
<
i'
,
len
(
data
)),
binary
=
True
)
self
.
write_message
(
data
.
encode
(
"utf-8"
),
binary
=
True
)
self
.
write_message
(
data
.
encode
(
"utf-8"
),
binary
=
True
)
def
call_handler
(
self
,
args
):
def
call_handler
(
self
,
args
):
...
@@ -355,8 +355,8 @@ class Tracker(object):
...
@@ -355,8 +355,8 @@ class Tracker(object):
def
_stop_tracker
(
self
):
def
_stop_tracker
(
self
):
sock
=
socket
.
socket
(
socket
.
AF_INET
,
socket
.
SOCK_STREAM
)
sock
=
socket
.
socket
(
socket
.
AF_INET
,
socket
.
SOCK_STREAM
)
sock
.
connect
((
self
.
host
,
self
.
port
))
sock
.
connect
((
self
.
host
,
self
.
port
))
sock
.
sendall
(
struct
.
pack
(
"
@
i"
,
base
.
RPC_TRACKER_MAGIC
))
sock
.
sendall
(
struct
.
pack
(
"
<
i"
,
base
.
RPC_TRACKER_MAGIC
))
magic
=
struct
.
unpack
(
"
@
i"
,
base
.
recvall
(
sock
,
4
))[
0
]
magic
=
struct
.
unpack
(
"
<
i"
,
base
.
recvall
(
sock
,
4
))[
0
]
assert
magic
==
base
.
RPC_TRACKER_MAGIC
assert
magic
==
base
.
RPC_TRACKER_MAGIC
base
.
sendjson
(
sock
,
[
TrackerCode
.
STOP
,
self
.
stop_key
])
base
.
sendjson
(
sock
,
[
TrackerCode
.
STOP
,
self
.
stop_key
])
assert
base
.
recvjson
(
sock
)
==
TrackerCode
.
SUCCESS
assert
base
.
recvjson
(
sock
)
==
TrackerCode
.
SUCCESS
...
...
src/runtime/file_util.cc
View file @
42608dda
...
@@ -4,6 +4,7 @@
...
@@ -4,6 +4,7 @@
*/
*/
#include <dmlc/json.h>
#include <dmlc/json.h>
#include <dmlc/logging.h>
#include <dmlc/logging.h>
#include <tvm/runtime/serializer.h>
#include <fstream>
#include <fstream>
#include "./file_util.h"
#include "./file_util.h"
...
...
src/runtime/graph/graph_runtime.cc
View file @
42608dda
...
@@ -4,6 +4,7 @@
...
@@ -4,6 +4,7 @@
*/
*/
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/serializer.h>
#include <dmlc/memory_io.h>
#include <dmlc/memory_io.h>
#include <dmlc/json.h>
#include <dmlc/json.h>
#include <numeric>
#include <numeric>
...
@@ -397,24 +398,25 @@ class GraphRuntime : public ModuleNode {
...
@@ -397,24 +398,25 @@ class GraphRuntime : public ModuleNode {
void
GraphRuntime
::
LoadDLTensor
(
dmlc
::
Stream
*
strm
,
DLTensor
*
dst
)
{
void
GraphRuntime
::
LoadDLTensor
(
dmlc
::
Stream
*
strm
,
DLTensor
*
dst
)
{
// always use strm->Read to maintain endianness conversion
uint64_t
header
,
reserved
;
uint64_t
header
,
reserved
;
CHECK
(
strm
->
Read
(
&
header
,
sizeof
(
header
)
))
CHECK
(
strm
->
Read
(
&
header
))
<<
"Invalid DLTensor file format"
;
<<
"Invalid DLTensor file format"
;
CHECK
(
strm
->
Read
(
&
reserved
,
sizeof
(
reserved
)
))
CHECK
(
strm
->
Read
(
&
reserved
))
<<
"Invalid DLTensor file format"
;
<<
"Invalid DLTensor file format"
;
CHECK
(
header
==
kTVMNDArrayMagic
)
CHECK
(
header
==
kTVMNDArrayMagic
)
<<
"Invalid DLTensor file format"
;
<<
"Invalid DLTensor file format"
;
DLTensor
tensor
;
DLTensor
tensor
;
CHECK
(
strm
->
Read
(
&
tensor
.
ctx
,
sizeof
(
tensor
.
ctx
)))
CHECK
(
strm
->
Read
(
&
(
tensor
.
ctx
)))
<<
"Invalid DLTensor file format"
;
<<
"Invalid DLTensor file format"
;
CHECK
(
strm
->
Read
(
&
tensor
.
ndim
,
sizeof
(
tensor
.
ndim
)))
CHECK
(
strm
->
Read
(
&
(
tensor
.
ndim
)))
<<
"Invalid DLTensor file format"
;
<<
"Invalid DLTensor file format"
;
CHECK
(
strm
->
Read
(
&
tensor
.
dtype
,
sizeof
(
tensor
.
dtype
)))
CHECK
(
strm
->
Read
(
&
(
tensor
.
dtype
)))
<<
"Invalid DLTensor file format"
;
<<
"Invalid DLTensor file format"
;
std
::
vector
<
int64_t
>
shape
(
tensor
.
ndim
);
std
::
vector
<
int64_t
>
shape
(
tensor
.
ndim
);
if
(
tensor
.
ndim
!=
0
)
{
if
(
tensor
.
ndim
!=
0
)
{
CHECK
(
strm
->
Read
(
&
shape
[
0
],
sizeof
(
int64_t
)
*
tensor
.
ndim
))
CHECK
(
strm
->
Read
Array
(
&
shape
[
0
],
tensor
.
ndim
))
<<
"Invalid DLTensor file format"
;
<<
"Invalid DLTensor file format"
;
}
}
CHECK_EQ
(
tensor
.
ndim
,
dst
->
ndim
)
<<
"param dimension mismatch"
;
CHECK_EQ
(
tensor
.
ndim
,
dst
->
ndim
)
<<
"param dimension mismatch"
;
...
@@ -425,18 +427,23 @@ void GraphRuntime::LoadDLTensor(dmlc::Stream* strm, DLTensor* dst) {
...
@@ -425,18 +427,23 @@ void GraphRuntime::LoadDLTensor(dmlc::Stream* strm, DLTensor* dst) {
CHECK_EQ
(
shape
[
i
],
dst
->
shape
[
i
])
<<
"param shape mismatch"
;
CHECK_EQ
(
shape
[
i
],
dst
->
shape
[
i
])
<<
"param shape mismatch"
;
}
}
size_t
bits
=
dst
->
dtype
.
bits
*
dst
->
dtype
.
lanes
;
size_t
bits
=
dst
->
dtype
.
bits
*
dst
->
dtype
.
lanes
;
size_t
size
=
(
bits
+
7
)
/
8
;
size_t
elem_bytes
=
(
bits
+
7
)
/
8
;
size_t
num_elems
=
1
;
for
(
int
i
=
0
;
i
<
dst
->
ndim
;
++
i
)
{
for
(
int
i
=
0
;
i
<
dst
->
ndim
;
++
i
)
{
size
*=
dst
->
shape
[
i
];
num_elems
*=
dst
->
shape
[
i
];
}
}
uint64_t
data_byte_size
;
uint64_t
data_byte_size
;
CHECK
(
strm
->
Read
(
&
data_byte_size
,
sizeof
(
data_byte_size
)
))
CHECK
(
strm
->
Read
(
&
data_byte_size
))
<<
"Invalid DLTensor file format"
;
<<
"Invalid DLTensor file format"
;
CHECK
(
data_byte_size
==
size
)
CHECK
_EQ
(
data_byte_size
,
elem_bytes
*
num_elems
)
<<
"Invalid DLTensor file format"
;
<<
"Invalid DLTensor file format"
;
std
::
vector
<
uint8_t
>
bytes
(
data_byte_size
+
1
);
std
::
vector
<
uint8_t
>
bytes
(
data_byte_size
+
1
);
CHECK
(
strm
->
Read
(
&
bytes
[
0
],
data_byte_size
))
CHECK
(
strm
->
Read
(
&
bytes
[
0
],
data_byte_size
))
<<
"Invalid DLTensor file format"
;
<<
"Invalid DLTensor file format"
;
// explicitly swap endian when necessary.
if
(
!
DMLC_IO_NO_ENDIAN_SWAP
)
{
dmlc
::
ByteSwap
(
&
bytes
[
0
],
elem_bytes
,
num_elems
);
}
TVM_CCALL
(
TVMArrayCopyFromBytes
(
dst
,
&
bytes
[
0
],
data_byte_size
));
TVM_CCALL
(
TVMArrayCopyFromBytes
(
dst
,
&
bytes
[
0
],
data_byte_size
));
}
}
...
@@ -453,9 +460,8 @@ void GraphRuntime::LoadParams(dmlc::Stream* strm) {
...
@@ -453,9 +460,8 @@ void GraphRuntime::LoadParams(dmlc::Stream* strm) {
CHECK
(
strm
->
Read
(
&
names
))
CHECK
(
strm
->
Read
(
&
names
))
<<
"Invalid parameters file format"
;
<<
"Invalid parameters file format"
;
uint64_t
sz
;
uint64_t
sz
;
strm
->
Read
(
&
sz
,
sizeof
(
sz
)
);
strm
->
Read
(
&
sz
);
size_t
size
=
static_cast
<
size_t
>
(
sz
);
size_t
size
=
static_cast
<
size_t
>
(
sz
);
CHECK
(
size
==
names
.
size
())
CHECK
(
size
==
names
.
size
())
<<
"Invalid parameters file format"
;
<<
"Invalid parameters file format"
;
for
(
size_t
i
=
0
;
i
<
size
;
++
i
)
{
for
(
size_t
i
=
0
;
i
<
size
;
++
i
)
{
...
...
src/runtime/rpc/rpc_session.cc
View file @
42608dda
This diff is collapsed.
Click to expand it.
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment