Commit 42608dda by tqchen Committed by Tianqi Chen

[IO] Support cross-endian

parent 0f9dab98
Subproject commit d3f7fbb53e5b037c0f5bf6bd21871ccc720690cc Subproject commit 9b3f9753ae81d657743c555e0cacc4e43f0bed2d
/*!
* Copyright (c) 2017 by Contributors
* \file tvm/runtime/serializer.h
* \brief Serializer extension to support TVM data types
* Include this file to enable serialization of DLDataType, DLContext
*/
#ifndef TVM_RUNTIME_SERIALIZER_H_
#define TVM_RUNTIME_SERIALIZER_H_
#include <dmlc/io.h>
#include <dmlc/serializer.h>
#include "./c_runtime_api.h"
namespace dmlc {
namespace serializer {
template<>
struct Handler<DLDataType> {
inline static void Write(Stream *strm, const DLDataType& dtype) {
Handler<uint8_t>::Write(strm, dtype.code);
Handler<uint8_t>::Write(strm, dtype.bits);
Handler<uint16_t>::Write(strm, dtype.lanes);
}
inline static bool Read(Stream *strm, DLDataType* dtype) {
if (!Handler<uint8_t>::Read(strm, &(dtype->code))) return false;
if (!Handler<uint8_t>::Read(strm, &(dtype->bits))) return false;
if (!Handler<uint16_t>::Read(strm, &(dtype->lanes))) return false;
return true;
}
};
template<>
struct Handler<DLContext> {
inline static void Write(Stream *strm, const DLContext& ctx) {
int32_t device_type = static_cast<int32_t>(ctx.device_type);
Handler<int32_t>::Write(strm, device_type);
Handler<int32_t>::Write(strm, ctx.device_id);
}
inline static bool Read(Stream *strm, DLContext* ctx) {
int32_t device_type = 0;
if (!Handler<int32_t>::Read(strm, &(device_type))) return false;
ctx->device_type = static_cast<DLDeviceType>(device_type);
if (!Handler<int32_t>::Read(strm, &(ctx->device_id))) return false;
return true;
}
};
} // namespace serializer
} // namespace dmlc
#endif // TVM_RUNTIME_SERIALIZER_H_
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
#include <tvm/runtime/packed_func.h> #include <tvm/runtime/packed_func.h>
#include <tvm/runtime/c_runtime_api.h> #include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <tvm/runtime/serializer.h>
#include "./graph_runtime.h" #include "./graph_runtime.h"
namespace nnvm { namespace nnvm {
...@@ -38,46 +39,53 @@ NNVM_REGISTER_OP(tvm_op) ...@@ -38,46 +39,53 @@ NNVM_REGISTER_OP(tvm_op)
bool SaveDLTensor(dmlc::Stream* strm, DLTensor* tensor) { bool SaveDLTensor(dmlc::Stream* strm, DLTensor* tensor) {
uint64_t header = kTVMNDArrayMagic, reserved = 0; uint64_t header = kTVMNDArrayMagic, reserved = 0;
strm->Write(&header, sizeof(header)); strm->Write(header);
strm->Write(&reserved, sizeof(reserved)); strm->Write(reserved);
strm->Write(tensor->ctx);
strm->Write(&tensor->ctx, sizeof(tensor->ctx)); strm->Write(tensor->ndim);
strm->Write(&tensor->ndim, sizeof(tensor->ndim)); strm->Write(tensor->dtype);
strm->Write(&tensor->dtype, sizeof(tensor->dtype));
int ndim = tensor->ndim; int ndim = tensor->ndim;
strm->Write(tensor->shape, sizeof(int64_t) * ndim); strm->WriteArray(tensor->shape, ndim);
int type_size = tensor->dtype.bits / 8; int type_bytes = tensor->dtype.bits / 8;
int64_t size = 1; int64_t num_elems = 1;
for (int i = 0; i < ndim; ++i) { for (int i = 0; i < ndim; ++i) {
size *= tensor->shape[i]; num_elems *= tensor->shape[i];
} }
int64_t data_byte_size = type_size * size; int64_t data_byte_size = type_bytes * num_elems;
strm->Write(&data_byte_size, sizeof(data_byte_size)); strm->Write(data_byte_size);
// handle endianness of data correctly.
if (DMLC_IO_NO_ENDIAN_SWAP) {
strm->Write(tensor->data, data_byte_size); strm->Write(tensor->data, data_byte_size);
} else {
uint8_t* dptr = reinterpret_cast<uint8_t*>(tensor->data);
std::vector<uint8_t> bytes(dptr, dptr + data_byte_size);
dmlc::ByteSwap(dmlc::BeginPtr(bytes), type_bytes, num_elems);
strm->Write(dmlc::BeginPtr(bytes), data_byte_size);
}
return true; return true;
} }
DLTensor* LoadDLTensor(dmlc::Stream* strm) { DLTensor* LoadDLTensor(dmlc::Stream* strm) {
uint64_t header, reserved; uint64_t header, reserved;
CHECK(strm->Read(&header, sizeof(header))) CHECK(strm->Read(&header))
<< "Invalid DLTensor file format"; << "Invalid DLTensor file format";
CHECK(strm->Read(&reserved, sizeof(reserved))) CHECK(strm->Read(&reserved))
<< "Invalid DLTensor file format"; << "Invalid DLTensor file format";
CHECK(header == kTVMNDArrayMagic) CHECK(header == kTVMNDArrayMagic)
<< "Invalid DLTensor file format"; << "Invalid DLTensor file format";
DLTensor tensor; DLTensor tensor;
CHECK(strm->Read(&tensor.ctx, sizeof(tensor.ctx))) CHECK(strm->Read(&(tensor.ctx)))
<< "Invalid DLTensor file format"; << "Invalid DLTensor file format";
CHECK(strm->Read(&tensor.ndim, sizeof(tensor.ndim))) CHECK(strm->Read(&(tensor.ndim)))
<< "Invalid DLTensor file format"; << "Invalid DLTensor file format";
CHECK(strm->Read(&tensor.dtype, sizeof(tensor.dtype))) CHECK(strm->Read(&(tensor.dtype)))
<< "Invalid DLTensor file format"; << "Invalid DLTensor file format";
std::vector<int64_t> shape(tensor.ndim); std::vector<int64_t> shape(tensor.ndim);
CHECK(strm->Read(&shape[0], sizeof(int64_t) * tensor.ndim)) if (tensor.ndim != 0) {
CHECK(strm->ReadArray(&shape[0], tensor.ndim))
<< "Invalid DLTensor file format"; << "Invalid DLTensor file format";
}
DLTensor* ret; DLTensor* ret;
CHECK_EQ(TVMArrayAlloc(shape.data(), CHECK_EQ(TVMArrayAlloc(shape.data(),
tensor.ndim, tensor.ndim,
...@@ -87,18 +95,21 @@ DLTensor* LoadDLTensor(dmlc::Stream* strm) { ...@@ -87,18 +95,21 @@ DLTensor* LoadDLTensor(dmlc::Stream* strm) {
static_cast<int>(tensor.ctx.device_type), static_cast<int>(tensor.ctx.device_type),
tensor.ctx.device_id, tensor.ctx.device_id,
&ret), 0) << TVMGetLastError(); &ret), 0) << TVMGetLastError();
int64_t size = 1; int64_t num_elems = 1;
int type_size = ret->dtype.bits / 8; int elem_bytes = (ret->dtype.bits + 7) / 8;
for (int i = 0; i < ret->ndim; ++i) { for (int i = 0; i < ret->ndim; ++i) {
size *= ret->shape[i]; num_elems *= ret->shape[i];
} }
int64_t data_byte_size; int64_t data_byte_size;
CHECK(strm->Read(&data_byte_size, sizeof(data_byte_size))) CHECK(strm->Read(&data_byte_size))
<< "Invalid DLTensor file format"; << "Invalid DLTensor file format";
CHECK(data_byte_size == type_size * size) CHECK(data_byte_size == num_elems * elem_bytes)
<< "Invalid DLTensor file format"; << "Invalid DLTensor file format";
CHECK(strm->Read(ret->data, type_size * size)) CHECK(strm->Read(ret->data, data_byte_size))
<< "Invalid DLTensor file format"; << "Invalid DLTensor file format";
if (!DMLC_IO_NO_ENDIAN_SWAP) {
dmlc::ByteSwap(ret->data, elem_bytes, num_elems);
}
return ret; return ret;
} }
...@@ -118,12 +129,12 @@ TVM_REGISTER_GLOBAL("nnvm.compiler._save_param_dict") ...@@ -118,12 +129,12 @@ TVM_REGISTER_GLOBAL("nnvm.compiler._save_param_dict")
dmlc::MemoryStringStream strm(&bytes); dmlc::MemoryStringStream strm(&bytes);
dmlc::Stream* fo = &strm; dmlc::Stream* fo = &strm;
uint64_t header = kTVMNDArrayListMagic, reserved = 0; uint64_t header = kTVMNDArrayListMagic, reserved = 0;
fo->Write(&header, sizeof(header)); fo->Write(header);
fo->Write(&reserved, sizeof(reserved)); fo->Write(reserved);
fo->Write(names); fo->Write(names);
{ {
uint64_t sz = static_cast<uint64_t>(arrays.size()); uint64_t sz = static_cast<uint64_t>(arrays.size());
fo->Write(&sz, sizeof(sz)); fo->Write(sz);
for (size_t i = 0; i < sz; ++i) { for (size_t i = 0; i < sz; ++i) {
SaveDLTensor(fo, arrays[i]); SaveDLTensor(fo, arrays[i]);
} }
...@@ -150,7 +161,6 @@ TVM_REGISTER_GLOBAL("nnvm.compiler._load_param_dict") ...@@ -150,7 +161,6 @@ TVM_REGISTER_GLOBAL("nnvm.compiler._load_param_dict")
<< "Invalid parameters file format"; << "Invalid parameters file format";
CHECK(strm->Read(&reserved)) CHECK(strm->Read(&reserved))
<< "Invalid parameters file format"; << "Invalid parameters file format";
CHECK(strm->Read(&names)) CHECK(strm->Read(&names))
<< "Invalid parameters file format"; << "Invalid parameters file format";
uint64_t sz; uint64_t sz;
......
...@@ -73,7 +73,7 @@ def sendjson(sock, data): ...@@ -73,7 +73,7 @@ def sendjson(sock, data):
Python value to be sent. Python value to be sent.
""" """
data = json.dumps(data) data = json.dumps(data)
sock.sendall(struct.pack("@i", len(data))) sock.sendall(struct.pack("<i", len(data)))
sock.sendall(data.encode("utf-8")) sock.sendall(data.encode("utf-8"))
...@@ -90,7 +90,7 @@ def recvjson(sock): ...@@ -90,7 +90,7 @@ def recvjson(sock):
value : object value : object
The value received. The value received.
""" """
size = struct.unpack("@i", recvall(sock, 4))[0] size = struct.unpack("<i", recvall(sock, 4))[0]
data = json.loads(py_str(recvall(sock, size))) data = json.loads(py_str(recvall(sock, size)))
return data return data
......
...@@ -192,8 +192,8 @@ class TrackerSession(object): ...@@ -192,8 +192,8 @@ class TrackerSession(object):
def _connect(self): def _connect(self):
self._sock = base.connect_with_retry(self._addr) self._sock = base.connect_with_retry(self._addr)
self._sock.sendall(struct.pack("@i", base.RPC_TRACKER_MAGIC)) self._sock.sendall(struct.pack("<i", base.RPC_TRACKER_MAGIC))
magic = struct.unpack("@i", base.recvall(self._sock, 4))[0] magic = struct.unpack("<i", base.recvall(self._sock, 4))[0]
if magic != base.RPC_TRACKER_MAGIC: if magic != base.RPC_TRACKER_MAGIC:
raise RuntimeError("%s is not RPC Tracker" % str(self._addr)) raise RuntimeError("%s is not RPC Tracker" % str(self._addr))
......
...@@ -58,14 +58,14 @@ class ForwardHandler(object): ...@@ -58,14 +58,14 @@ class ForwardHandler(object):
def _init_step(self, message): def _init_step(self, message):
if self._magic is None: if self._magic is None:
assert len(message) == 4 assert len(message) == 4
self._magic = struct.unpack('@i', message)[0] self._magic = struct.unpack('<i', message)[0]
if self._magic != base.RPC_MAGIC: if self._magic != base.RPC_MAGIC:
logging.info("Invalid RPC magic from %s", self.name()) logging.info("Invalid RPC magic from %s", self.name())
self.close() self.close()
self._init_req_nbytes = 4 self._init_req_nbytes = 4
elif self._rpc_key_length is None: elif self._rpc_key_length is None:
assert len(message) == 4 assert len(message) == 4
self._rpc_key_length = struct.unpack('@i', message)[0] self._rpc_key_length = struct.unpack('<i', message)[0]
self._init_req_nbytes = self._rpc_key_length self._init_req_nbytes = self._rpc_key_length
elif self.rpc_key is None: elif self.rpc_key is None:
assert len(message) == self._rpc_key_length assert len(message) == self._rpc_key_length
...@@ -269,12 +269,12 @@ class ProxyServerHandler(object): ...@@ -269,12 +269,12 @@ class ProxyServerHandler(object):
lhs.forward_proxy = rhs lhs.forward_proxy = rhs
rhs.forward_proxy = lhs rhs.forward_proxy = lhs
lhs.send_data(struct.pack('@i', base.RPC_CODE_SUCCESS)) lhs.send_data(struct.pack('<i', base.RPC_CODE_SUCCESS))
lhs.send_data(struct.pack('@i', len(rhs.rpc_key))) lhs.send_data(struct.pack('<i', len(rhs.rpc_key)))
lhs.send_data(rhs.rpc_key.encode("utf-8")) lhs.send_data(rhs.rpc_key.encode("utf-8"))
rhs.send_data(struct.pack('@i', base.RPC_CODE_SUCCESS)) rhs.send_data(struct.pack('<i', base.RPC_CODE_SUCCESS))
rhs.send_data(struct.pack('@i', len(lhs.rpc_key))) rhs.send_data(struct.pack('<i', len(lhs.rpc_key)))
rhs.send_data(lhs.rpc_key.encode("utf-8")) rhs.send_data(lhs.rpc_key.encode("utf-8"))
logging.info("Pairup connect %s and %s", lhs.name(), rhs.name()) logging.info("Pairup connect %s and %s", lhs.name(), rhs.name())
...@@ -299,8 +299,8 @@ class ProxyServerHandler(object): ...@@ -299,8 +299,8 @@ class ProxyServerHandler(object):
if self._tracker_conn is None: if self._tracker_conn is None:
self._tracker_conn = socket.socket(socket.AF_INET, socket.SOCK_STREAM) self._tracker_conn = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self._tracker_conn.connect(self._tracker_addr) self._tracker_conn.connect(self._tracker_addr)
self._tracker_conn.sendall(struct.pack("@i", base.RPC_TRACKER_MAGIC)) self._tracker_conn.sendall(struct.pack("<i", base.RPC_TRACKER_MAGIC))
magic = struct.unpack("@i", base.recvall(self._tracker_conn, 4))[0] magic = struct.unpack("<i", base.recvall(self._tracker_conn, 4))[0]
if magic != base.RPC_TRACKER_MAGIC: if magic != base.RPC_TRACKER_MAGIC:
self.loop.stop() self.loop.stop()
raise RuntimeError("%s is not RPC Tracker" % str(self._tracker_addr)) raise RuntimeError("%s is not RPC Tracker" % str(self._tracker_addr))
...@@ -371,7 +371,7 @@ class ProxyServerHandler(object): ...@@ -371,7 +371,7 @@ class ProxyServerHandler(object):
if handler.match_key in self._server_pool: if handler.match_key in self._server_pool:
self._pair_up(self._server_pool.pop(handler.match_key), handler) self._pair_up(self._server_pool.pop(handler.match_key), handler)
else: else:
handler.send_data(struct.pack('@i', base.RPC_CODE_MISMATCH)) handler.send_data(struct.pack('<i', base.RPC_CODE_MISMATCH))
handler.signal_close() handler.signal_close()
def _handler_ready_proxy_mode(self, handler): def _handler_ready_proxy_mode(self, handler):
...@@ -395,12 +395,12 @@ class ProxyServerHandler(object): ...@@ -395,12 +395,12 @@ class ProxyServerHandler(object):
logging.info("Timeout client connection %s, cannot find match key=%s", logging.info("Timeout client connection %s, cannot find match key=%s",
handler.name(), key) handler.name(), key)
pool_dst.pop(key) pool_dst.pop(key)
handler.send_data(struct.pack('@i', base.RPC_CODE_MISMATCH)) handler.send_data(struct.pack('<i', base.RPC_CODE_MISMATCH))
handler.signal_close() handler.signal_close()
self.loop.call_later(timeout, cleanup) self.loop.call_later(timeout, cleanup)
else: else:
logging.info("Duplicate connection with same key=%s", key) logging.info("Duplicate connection with same key=%s", key)
handler.send_data(struct.pack('@i', base.RPC_CODE_DUPLICATE)) handler.send_data(struct.pack('<i', base.RPC_CODE_DUPLICATE))
handler.signal_close() handler.signal_close()
def handler_ready(self, handler): def handler_ready(self, handler):
...@@ -538,13 +538,13 @@ def websocket_proxy_server(url, key=""): ...@@ -538,13 +538,13 @@ def websocket_proxy_server(url, key=""):
on_message = create_on_message(conn) on_message = create_on_message(conn)
temp = _server_env(None) temp = _server_env(None)
# Start connecton # Start connecton
conn.write_message(struct.pack('@i', base.RPC_MAGIC), binary=True) conn.write_message(struct.pack('<i', base.RPC_MAGIC), binary=True)
key = "server:" + key key = "server:" + key
conn.write_message(struct.pack('@i', len(key)), binary=True) conn.write_message(struct.pack('<i', len(key)), binary=True)
conn.write_message(key.encode("utf-8"), binary=True) conn.write_message(key.encode("utf-8"), binary=True)
msg = yield conn.read_message() msg = yield conn.read_message()
assert len(msg) >= 4 assert len(msg) >= 4
magic = struct.unpack('@i', msg[:4])[0] magic = struct.unpack('<i', msg[:4])[0]
if magic == base.RPC_CODE_DUPLICATE: if magic == base.RPC_CODE_DUPLICATE:
raise RuntimeError("key: %s has already been used in proxy" % key) raise RuntimeError("key: %s has already been used in proxy" % key)
elif magic == base.RPC_CODE_MISMATCH: elif magic == base.RPC_CODE_MISMATCH:
......
...@@ -124,23 +124,23 @@ def _listen_loop(sock, port, rpc_key, tracker_addr, load_library): ...@@ -124,23 +124,23 @@ def _listen_loop(sock, port, rpc_key, tracker_addr, load_library):
unmatch_period_count = 0 unmatch_period_count = 0
continue continue
conn, addr = listen_sock.accept() conn, addr = listen_sock.accept()
magic = struct.unpack("@i", base.recvall(conn, 4))[0] magic = struct.unpack("<i", base.recvall(conn, 4))[0]
if magic != base.RPC_MAGIC: if magic != base.RPC_MAGIC:
conn.close() conn.close()
continue continue
keylen = struct.unpack("@i", base.recvall(conn, 4))[0] keylen = struct.unpack("<i", base.recvall(conn, 4))[0]
key = py_str(base.recvall(conn, keylen)) key = py_str(base.recvall(conn, keylen))
arr = key.split() arr = key.split()
expect_header = "client:" + matchkey expect_header = "client:" + matchkey
server_key = "server:" + rpc_key server_key = "server:" + rpc_key
if arr[0] != expect_header: if arr[0] != expect_header:
conn.sendall(struct.pack("@i", base.RPC_CODE_MISMATCH)) conn.sendall(struct.pack("<i", base.RPC_CODE_MISMATCH))
conn.close() conn.close()
logging.info("RPCServer: mismatch key from %s", addr) logging.info("RPCServer: mismatch key from %s", addr)
continue continue
else: else:
conn.sendall(struct.pack("@i", base.RPC_CODE_SUCCESS)) conn.sendall(struct.pack("<i", base.RPC_CODE_SUCCESS))
conn.sendall(struct.pack("@i", len(server_key))) conn.sendall(struct.pack("<i", len(server_key)))
conn.sendall(server_key.encode("utf-8")) conn.sendall(server_key.encode("utf-8"))
return conn, addr, _parse_server_opt(arr[1:]) return conn, addr, _parse_server_opt(arr[1:])
...@@ -151,8 +151,8 @@ def _listen_loop(sock, port, rpc_key, tracker_addr, load_library): ...@@ -151,8 +151,8 @@ def _listen_loop(sock, port, rpc_key, tracker_addr, load_library):
# step 1: setup tracker and report to tracker # step 1: setup tracker and report to tracker
if tracker_addr and tracker_conn is None: if tracker_addr and tracker_conn is None:
tracker_conn = base.connect_with_retry(tracker_addr) tracker_conn = base.connect_with_retry(tracker_addr)
tracker_conn.sendall(struct.pack("@i", base.RPC_TRACKER_MAGIC)) tracker_conn.sendall(struct.pack("<i", base.RPC_TRACKER_MAGIC))
magic = struct.unpack("@i", base.recvall(tracker_conn, 4))[0] magic = struct.unpack("<i", base.recvall(tracker_conn, 4))[0]
if magic != base.RPC_TRACKER_MAGIC: if magic != base.RPC_TRACKER_MAGIC:
raise RuntimeError("%s is not RPC Tracker" % str(tracker_addr)) raise RuntimeError("%s is not RPC Tracker" % str(tracker_addr))
# report status of current queue # report status of current queue
...@@ -193,17 +193,17 @@ def _connect_proxy_loop(addr, key, load_library): ...@@ -193,17 +193,17 @@ def _connect_proxy_loop(addr, key, load_library):
try: try:
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.connect(addr) sock.connect(addr)
sock.sendall(struct.pack("@i", base.RPC_MAGIC)) sock.sendall(struct.pack("<i", base.RPC_MAGIC))
sock.sendall(struct.pack("@i", len(key))) sock.sendall(struct.pack("<i", len(key)))
sock.sendall(key.encode("utf-8")) sock.sendall(key.encode("utf-8"))
magic = struct.unpack("@i", base.recvall(sock, 4))[0] magic = struct.unpack("<i", base.recvall(sock, 4))[0]
if magic == base.RPC_CODE_DUPLICATE: if magic == base.RPC_CODE_DUPLICATE:
raise RuntimeError("key: %s has already been used in proxy" % key) raise RuntimeError("key: %s has already been used in proxy" % key)
elif magic == base.RPC_CODE_MISMATCH: elif magic == base.RPC_CODE_MISMATCH:
logging.info("RPCProxy do not have matching client key %s", key) logging.info("RPCProxy do not have matching client key %s", key)
elif magic != base.RPC_CODE_SUCCESS: elif magic != base.RPC_CODE_SUCCESS:
raise RuntimeError("%s is not RPC Proxy" % str(addr)) raise RuntimeError("%s is not RPC Proxy" % str(addr))
keylen = struct.unpack("@i", base.recvall(sock, 4))[0] keylen = struct.unpack("<i", base.recvall(sock, 4))[0]
remote_key = py_str(base.recvall(sock, keylen)) remote_key = py_str(base.recvall(sock, keylen))
opts = _parse_server_opt(remote_key.split()[1:]) opts = _parse_server_opt(remote_key.split()[1:])
logging.info("RPCProxy connected to %s", str(addr)) logging.info("RPCProxy connected to %s", str(addr))
......
...@@ -143,11 +143,11 @@ class TCPEventHandler(tornado_util.TCPHandler): ...@@ -143,11 +143,11 @@ class TCPEventHandler(tornado_util.TCPHandler):
if len(message) != 4: if len(message) != 4:
logging.info("Invalid connection from %s", self.name()) logging.info("Invalid connection from %s", self.name())
self.close() self.close()
magic = struct.unpack('@i', message)[0] magic = struct.unpack('<i', message)[0]
if magic != RPC_TRACKER_MAGIC: if magic != RPC_TRACKER_MAGIC:
logging.info("Invalid magic from %s", self.name()) logging.info("Invalid magic from %s", self.name())
self.close() self.close()
self.write_message(struct.pack('@i', RPC_TRACKER_MAGIC), binary=True) self.write_message(struct.pack('<i', RPC_TRACKER_MAGIC), binary=True)
self._init_req_nbytes = 0 self._init_req_nbytes = 0
def on_message(self, message): def on_message(self, message):
...@@ -168,7 +168,7 @@ class TCPEventHandler(tornado_util.TCPHandler): ...@@ -168,7 +168,7 @@ class TCPEventHandler(tornado_util.TCPHandler):
while True: while True:
if self._msg_size == 0: if self._msg_size == 0:
if len(self._data) >= 4: if len(self._data) >= 4:
self._msg_size = struct.unpack('@i', self._data[:4])[0] self._msg_size = struct.unpack('<i', self._data[:4])[0]
else: else:
return return
if self._msg_size != 0 and len(self._data) >= self._msg_size + 4: if self._msg_size != 0 and len(self._data) >= self._msg_size + 4:
...@@ -184,7 +184,7 @@ class TCPEventHandler(tornado_util.TCPHandler): ...@@ -184,7 +184,7 @@ class TCPEventHandler(tornado_util.TCPHandler):
"""return value to the output""" """return value to the output"""
data = json.dumps(data) data = json.dumps(data)
self.write_message( self.write_message(
struct.pack('@i', len(data)), binary=True) struct.pack('<i', len(data)), binary=True)
self.write_message(data.encode("utf-8"), binary=True) self.write_message(data.encode("utf-8"), binary=True)
def call_handler(self, args): def call_handler(self, args):
...@@ -355,8 +355,8 @@ class Tracker(object): ...@@ -355,8 +355,8 @@ class Tracker(object):
def _stop_tracker(self): def _stop_tracker(self):
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.connect((self.host, self.port)) sock.connect((self.host, self.port))
sock.sendall(struct.pack("@i", base.RPC_TRACKER_MAGIC)) sock.sendall(struct.pack("<i", base.RPC_TRACKER_MAGIC))
magic = struct.unpack("@i", base.recvall(sock, 4))[0] magic = struct.unpack("<i", base.recvall(sock, 4))[0]
assert magic == base.RPC_TRACKER_MAGIC assert magic == base.RPC_TRACKER_MAGIC
base.sendjson(sock, [TrackerCode.STOP, self.stop_key]) base.sendjson(sock, [TrackerCode.STOP, self.stop_key])
assert base.recvjson(sock) == TrackerCode.SUCCESS assert base.recvjson(sock) == TrackerCode.SUCCESS
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
*/ */
#include <dmlc/json.h> #include <dmlc/json.h>
#include <dmlc/logging.h> #include <dmlc/logging.h>
#include <tvm/runtime/serializer.h>
#include <fstream> #include <fstream>
#include "./file_util.h" #include "./file_util.h"
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
*/ */
#include <tvm/runtime/packed_func.h> #include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <tvm/runtime/serializer.h>
#include <dmlc/memory_io.h> #include <dmlc/memory_io.h>
#include <dmlc/json.h> #include <dmlc/json.h>
#include <numeric> #include <numeric>
...@@ -397,24 +398,25 @@ class GraphRuntime : public ModuleNode { ...@@ -397,24 +398,25 @@ class GraphRuntime : public ModuleNode {
void GraphRuntime::LoadDLTensor(dmlc::Stream* strm, DLTensor* dst) { void GraphRuntime::LoadDLTensor(dmlc::Stream* strm, DLTensor* dst) {
// always use strm->Read to maintain endianness conversion
uint64_t header, reserved; uint64_t header, reserved;
CHECK(strm->Read(&header, sizeof(header))) CHECK(strm->Read(&header))
<< "Invalid DLTensor file format"; << "Invalid DLTensor file format";
CHECK(strm->Read(&reserved, sizeof(reserved))) CHECK(strm->Read(&reserved))
<< "Invalid DLTensor file format"; << "Invalid DLTensor file format";
CHECK(header == kTVMNDArrayMagic) CHECK(header == kTVMNDArrayMagic)
<< "Invalid DLTensor file format"; << "Invalid DLTensor file format";
DLTensor tensor; DLTensor tensor;
CHECK(strm->Read(&tensor.ctx, sizeof(tensor.ctx))) CHECK(strm->Read(&(tensor.ctx)))
<< "Invalid DLTensor file format"; << "Invalid DLTensor file format";
CHECK(strm->Read(&tensor.ndim, sizeof(tensor.ndim))) CHECK(strm->Read(&(tensor.ndim)))
<< "Invalid DLTensor file format"; << "Invalid DLTensor file format";
CHECK(strm->Read(&tensor.dtype, sizeof(tensor.dtype))) CHECK(strm->Read(&(tensor.dtype)))
<< "Invalid DLTensor file format"; << "Invalid DLTensor file format";
std::vector<int64_t> shape(tensor.ndim); std::vector<int64_t> shape(tensor.ndim);
if (tensor.ndim != 0) { if (tensor.ndim != 0) {
CHECK(strm->Read(&shape[0], sizeof(int64_t) * tensor.ndim)) CHECK(strm->ReadArray(&shape[0], tensor.ndim))
<< "Invalid DLTensor file format"; << "Invalid DLTensor file format";
} }
CHECK_EQ(tensor.ndim, dst->ndim) << "param dimension mismatch"; CHECK_EQ(tensor.ndim, dst->ndim) << "param dimension mismatch";
...@@ -425,18 +427,23 @@ void GraphRuntime::LoadDLTensor(dmlc::Stream* strm, DLTensor* dst) { ...@@ -425,18 +427,23 @@ void GraphRuntime::LoadDLTensor(dmlc::Stream* strm, DLTensor* dst) {
CHECK_EQ(shape[i], dst->shape[i]) << "param shape mismatch"; CHECK_EQ(shape[i], dst->shape[i]) << "param shape mismatch";
} }
size_t bits = dst->dtype.bits * dst->dtype.lanes; size_t bits = dst->dtype.bits * dst->dtype.lanes;
size_t size = (bits + 7) / 8; size_t elem_bytes = (bits + 7) / 8;
size_t num_elems = 1;
for (int i = 0; i < dst->ndim; ++i) { for (int i = 0; i < dst->ndim; ++i) {
size *= dst->shape[i]; num_elems *= dst->shape[i];
} }
uint64_t data_byte_size; uint64_t data_byte_size;
CHECK(strm->Read(&data_byte_size, sizeof(data_byte_size))) CHECK(strm->Read(&data_byte_size))
<< "Invalid DLTensor file format"; << "Invalid DLTensor file format";
CHECK(data_byte_size == size) CHECK_EQ(data_byte_size, elem_bytes * num_elems)
<< "Invalid DLTensor file format"; << "Invalid DLTensor file format";
std::vector<uint8_t> bytes(data_byte_size + 1); std::vector<uint8_t> bytes(data_byte_size + 1);
CHECK(strm->Read(&bytes[0], data_byte_size)) CHECK(strm->Read(&bytes[0], data_byte_size))
<< "Invalid DLTensor file format"; << "Invalid DLTensor file format";
// explicitly swap endian when necessary.
if (!DMLC_IO_NO_ENDIAN_SWAP) {
dmlc::ByteSwap(&bytes[0], elem_bytes, num_elems);
}
TVM_CCALL(TVMArrayCopyFromBytes(dst, &bytes[0], data_byte_size)); TVM_CCALL(TVMArrayCopyFromBytes(dst, &bytes[0], data_byte_size));
} }
...@@ -453,9 +460,8 @@ void GraphRuntime::LoadParams(dmlc::Stream* strm) { ...@@ -453,9 +460,8 @@ void GraphRuntime::LoadParams(dmlc::Stream* strm) {
CHECK(strm->Read(&names)) CHECK(strm->Read(&names))
<< "Invalid parameters file format"; << "Invalid parameters file format";
uint64_t sz; uint64_t sz;
strm->Read(&sz, sizeof(sz)); strm->Read(&sz);
size_t size = static_cast<size_t>(sz); size_t size = static_cast<size_t>(sz);
CHECK(size == names.size()) CHECK(size == names.size())
<< "Invalid parameters file format"; << "Invalid parameters file format";
for (size_t i = 0; i < size; ++i) { for (size_t i = 0; i < size; ++i) {
......
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
#include <tvm/runtime/packed_func.h> #include <tvm/runtime/packed_func.h>
#include <tvm/runtime/device_api.h> #include <tvm/runtime/device_api.h>
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <tvm/runtime/serializer.h>
#include <memory> #include <memory>
#include <array> #include <array>
#include <string> #include <string>
...@@ -44,7 +45,7 @@ struct RPCArgBuffer { ...@@ -44,7 +45,7 @@ struct RPCArgBuffer {
}; };
// Event handler for RPC events. // Event handler for RPC events.
class RPCSession::EventHandler { class RPCSession::EventHandler : public dmlc::Stream {
public: public:
EventHandler(common::RingBuffer* reader, EventHandler(common::RingBuffer* reader,
common::RingBuffer* writer, common::RingBuffer* writer,
...@@ -71,6 +72,15 @@ class RPCSession::EventHandler { ...@@ -71,6 +72,15 @@ class RPCSession::EventHandler {
return 0; return 0;
} }
} }
// Request number of bytes from reader.
void RequestBytes(size_t nbytes) {
pending_request_bytes_ += nbytes;
reader_->Reserve(pending_request_bytes_);
}
// Whether we are ready to handle next request.
bool Ready() {
return reader_->bytes_available() >= pending_request_bytes_;
}
bool CanCleanShutdown() const { bool CanCleanShutdown() const {
return state_ == kRecvCode; return state_ == kRecvCode;
} }
...@@ -86,12 +96,12 @@ class RPCSession::EventHandler { ...@@ -86,12 +96,12 @@ class RPCSession::EventHandler {
case kInitHeader: HandleInitHeader(); break; case kInitHeader: HandleInitHeader(); break;
case kRecvCode: HandleRecvCode(); break; case kRecvCode: HandleRecvCode(); break;
case kRecvCallHandle: { case kRecvCallHandle: {
this->Read(&call_handle_, sizeof(call_handle_)); CHECK(this->Read(&call_handle_));
this->SwitchToState(kRecvPackedSeqNumArgs); this->SwitchToState(kRecvPackedSeqNumArgs);
break; break;
} }
case kRecvPackedSeqNumArgs: { case kRecvPackedSeqNumArgs: {
this->Read(&num_packed_args_, sizeof(num_packed_args_)); CHECK(this->Read(&num_packed_args_));
arg_buf_.reset(new RPCArgBuffer()); arg_buf_.reset(new RPCArgBuffer());
arg_buf_->value.resize(num_packed_args_); arg_buf_->value.resize(num_packed_args_);
arg_buf_->tcode.resize(num_packed_args_); arg_buf_->tcode.resize(num_packed_args_);
...@@ -100,7 +110,7 @@ class RPCSession::EventHandler { ...@@ -100,7 +110,7 @@ class RPCSession::EventHandler {
} }
case kRecvPackedSeqTypeCode: { case kRecvPackedSeqTypeCode: {
if (num_packed_args_ != 0) { if (num_packed_args_ != 0) {
this->Read(arg_buf_->tcode.data(), sizeof(int) * num_packed_args_); this->ReadArray(arg_buf_->tcode.data(), num_packed_args_);
} }
arg_index_ = 0; arg_index_ = 0;
arg_recv_stage_ = 0; arg_recv_stage_ = 0;
...@@ -164,8 +174,8 @@ class RPCSession::EventHandler { ...@@ -164,8 +174,8 @@ class RPCSession::EventHandler {
} }
// send Packed sequence to writer. // send Packed sequence to writer.
void SendPackedSeq(const TVMValue* arg_values, const int* type_codes, int n) { void SendPackedSeq(const TVMValue* arg_values, const int* type_codes, int n) {
writer_->Write(&n, sizeof(n)); this->Write(n);
writer_->Write(type_codes, sizeof(int) * n); this->WriteArray(type_codes, n);
// Argument packing. // Argument packing.
for (int i = 0; i < n; ++i) { for (int i = 0; i < n; ++i) {
int tcode = type_codes[i]; int tcode = type_codes[i];
...@@ -173,14 +183,20 @@ class RPCSession::EventHandler { ...@@ -173,14 +183,20 @@ class RPCSession::EventHandler {
switch (tcode) { switch (tcode) {
case kDLInt: case kDLInt:
case kDLUInt: case kDLUInt:
case kDLFloat: case kDLFloat: {
this->Write<int64_t>(value.v_int64);
break;
}
case kTVMType: { case kTVMType: {
writer_->Write(&value, sizeof(TVMValue)); this->Write(value.v_type);
// padding
int32_t padding = 0;
this->Write<int32_t>(padding);
break; break;
} }
case kTVMContext: { case kTVMContext: {
value.v_ctx = StripSessMask(value.v_ctx); value.v_ctx = StripSessMask(value.v_ctx);
writer_->Write(&value, sizeof(TVMValue)); this->Write(value.v_ctx);
break; break;
} }
case kFuncHandle: case kFuncHandle:
...@@ -188,7 +204,7 @@ class RPCSession::EventHandler { ...@@ -188,7 +204,7 @@ class RPCSession::EventHandler {
case kHandle: { case kHandle: {
// always send handle in 64 bit. // always send handle in 64 bit.
uint64_t handle = reinterpret_cast<uint64_t>(value.v_handle); uint64_t handle = reinterpret_cast<uint64_t>(value.v_handle);
writer_->Write(&handle, sizeof(uint64_t)); this->Write(handle);
break; break;
} }
case kArrayHandle: { case kArrayHandle: {
...@@ -196,11 +212,11 @@ class RPCSession::EventHandler { ...@@ -196,11 +212,11 @@ class RPCSession::EventHandler {
TVMContext ctx = StripSessMask(arr->ctx); TVMContext ctx = StripSessMask(arr->ctx);
uint64_t data = reinterpret_cast<uint64_t>( uint64_t data = reinterpret_cast<uint64_t>(
static_cast<RemoteSpace*>(arr->data)->data); static_cast<RemoteSpace*>(arr->data)->data);
writer_->Write(&data, sizeof(uint64_t)); this->Write(data);
writer_->Write(&ctx, sizeof(ctx)); this->Write(ctx);
writer_->Write(&(arr->ndim), sizeof(int)); this->Write(arr->ndim);
writer_->Write(&(arr->dtype), sizeof(DLDataType)); this->Write(arr->dtype);
writer_->Write(arr->shape, sizeof(int64_t) * arr->ndim); this->WriteArray(arr->shape, arr->ndim);
CHECK(arr->strides == nullptr) CHECK(arr->strides == nullptr)
<< "Donot support strided remote array"; << "Donot support strided remote array";
CHECK_EQ(arr->byte_offset, 0) CHECK_EQ(arr->byte_offset, 0)
...@@ -211,15 +227,15 @@ class RPCSession::EventHandler { ...@@ -211,15 +227,15 @@ class RPCSession::EventHandler {
case kStr: { case kStr: {
const char* s = value.v_str; const char* s = value.v_str;
uint64_t len = strlen(s); uint64_t len = strlen(s);
writer_->Write(&len, sizeof(len)); this->Write(len);
writer_->Write(s, sizeof(char) * len); this->WriteArray(s, len);
break; break;
} }
case kBytes: { case kBytes: {
TVMByteArray* bytes = static_cast<TVMByteArray*>(arg_values[i].v_handle); TVMByteArray* bytes = static_cast<TVMByteArray*>(arg_values[i].v_handle);
uint64_t len = bytes->size; uint64_t len = bytes->size;
writer_->Write(&len, sizeof(len)); this->Write(len);
writer_->Write(bytes->data, sizeof(char) * len); this->WriteArray(bytes->data, len);
break; break;
} }
default: { default: {
...@@ -230,6 +246,23 @@ class RPCSession::EventHandler { ...@@ -230,6 +246,23 @@ class RPCSession::EventHandler {
} }
} }
// Endian aware IO handling
using Stream::Read;
using Stream::Write;
using Stream::ReadArray;
using Stream::WriteArray;
inline bool Read(RPCCode* code) {
int cdata;
if (!this->Read(&cdata)) return false;
*code = static_cast<RPCCode>(cdata);
return true;
}
inline void Write(RPCCode code) {
int cdata = static_cast<int>(code);
this->Write(cdata);
}
protected: protected:
enum State { enum State {
kInitHeader, kInitHeader,
...@@ -370,10 +403,22 @@ class RPCSession::EventHandler { ...@@ -370,10 +403,22 @@ class RPCSession::EventHandler {
switch (tcode) { switch (tcode) {
case kDLInt: case kDLInt:
case kDLUInt: case kDLUInt:
case kDLFloat: case kDLFloat: {
case kTVMType: this->Read<int64_t>(&(value.v_int64));
++arg_index_;
this->SwitchToState(kRecvPackedSeqArg);
break;
}
case kTVMType: {
this->Read(&(value.v_type));
int32_t padding = 0;
this->Read<int32_t>(&padding);
++arg_index_;
this->SwitchToState(kRecvPackedSeqArg);
break;
}
case kTVMContext: { case kTVMContext: {
this->Read(&value, sizeof(TVMValue)); this->Read(&(value.v_ctx));
++arg_index_; ++arg_index_;
this->SwitchToState(kRecvPackedSeqArg); this->SwitchToState(kRecvPackedSeqArg);
break; break;
...@@ -383,7 +428,7 @@ class RPCSession::EventHandler { ...@@ -383,7 +428,7 @@ class RPCSession::EventHandler {
case kHandle: { case kHandle: {
// always send handle in 64 bit. // always send handle in 64 bit.
uint64_t handle; uint64_t handle;
this->Read(&handle, sizeof(handle)); this->Read(&handle);
value.v_handle = reinterpret_cast<void*>(handle); value.v_handle = reinterpret_cast<void*>(handle);
++arg_index_; ++arg_index_;
this->SwitchToState(kRecvPackedSeqArg); this->SwitchToState(kRecvPackedSeqArg);
...@@ -398,7 +443,7 @@ class RPCSession::EventHandler { ...@@ -398,7 +443,7 @@ class RPCSession::EventHandler {
case kStr: case kStr:
case kBytes: { case kBytes: {
uint64_t len; uint64_t len;
this->Read(&len, sizeof(len)); this->Read(&len);
temp_bytes_.reset( new RPCByteArrayBuffer()); temp_bytes_.reset( new RPCByteArrayBuffer());
temp_bytes_->data.resize(len); temp_bytes_->data.resize(len);
arg_recv_stage_ = 1; arg_recv_stage_ = 1;
...@@ -409,12 +454,12 @@ class RPCSession::EventHandler { ...@@ -409,12 +454,12 @@ class RPCSession::EventHandler {
case kArrayHandle: { case kArrayHandle: {
temp_array_.reset(new RPCDataArrayBuffer()); temp_array_.reset(new RPCDataArrayBuffer());
uint64_t handle; uint64_t handle;
this->Read(&handle, sizeof(handle)); this->Read(&handle);
DLTensor& tensor = temp_array_->tensor; DLTensor& tensor = temp_array_->tensor;
tensor.data = reinterpret_cast<void*>(handle); tensor.data = reinterpret_cast<void*>(handle);
this->Read(&(tensor.ctx), sizeof(TVMContext)); this->Read(&(tensor.ctx));
this->Read(&(tensor.ndim), sizeof(int)); this->Read(&(tensor.ndim));
this->Read(&(tensor.dtype), sizeof(DLDataType)); this->Read(&(tensor.dtype));
temp_array_->shape.resize(tensor.ndim); temp_array_->shape.resize(tensor.ndim);
tensor.shape = temp_array_->shape.data(); tensor.shape = temp_array_->shape.data();
arg_recv_stage_ = 1; arg_recv_stage_ = 1;
...@@ -432,7 +477,7 @@ class RPCSession::EventHandler { ...@@ -432,7 +477,7 @@ class RPCSession::EventHandler {
CHECK_EQ(arg_recv_stage_, 1); CHECK_EQ(arg_recv_stage_, 1);
if (tcode == kStr || tcode == kBytes) { if (tcode == kStr || tcode == kBytes) {
if (temp_bytes_->data.size() != 0) { if (temp_bytes_->data.size() != 0) {
this->Read(&(temp_bytes_->data[0]), temp_bytes_->data.size()); this->ReadArray(&(temp_bytes_->data[0]), temp_bytes_->data.size());
} }
if (tcode == kStr) { if (tcode == kStr) {
value.v_str = temp_bytes_->data.c_str(); value.v_str = temp_bytes_->data.c_str();
...@@ -445,7 +490,7 @@ class RPCSession::EventHandler { ...@@ -445,7 +490,7 @@ class RPCSession::EventHandler {
} else { } else {
CHECK_EQ(tcode, kArrayHandle); CHECK_EQ(tcode, kArrayHandle);
DLTensor& tensor = temp_array_->tensor; DLTensor& tensor = temp_array_->tensor;
this->Read(tensor.shape, tensor.ndim * sizeof(int64_t)); this->ReadArray(tensor.shape, tensor.ndim);
value.v_handle = &tensor; value.v_handle = &tensor;
arg_buf_->temp_array.emplace_back(std::move(temp_array_)); arg_buf_->temp_array.emplace_back(std::move(temp_array_));
} }
...@@ -458,20 +503,20 @@ class RPCSession::EventHandler { ...@@ -458,20 +503,20 @@ class RPCSession::EventHandler {
void HandleInitHeader() { void HandleInitHeader() {
if (init_header_step_ == 0) { if (init_header_step_ == 0) {
int32_t len; int32_t len;
this->Read(&len, sizeof(len)); this->Read(&len);
remote_key_->resize(len); remote_key_->resize(len);
init_header_step_ = 1; init_header_step_ = 1;
this->RequestBytes(len); this->RequestBytes(len);
return; return;
} else { } else {
CHECK_EQ(init_header_step_, 1); CHECK_EQ(init_header_step_, 1);
this->Read(dmlc::BeginPtr(*remote_key_), remote_key_->length()); this->ReadArray(dmlc::BeginPtr(*remote_key_), remote_key_->length());
this->SwitchToState(kRecvCode); this->SwitchToState(kRecvCode);
} }
} }
// Handler for read code. // Handler for read code.
void HandleRecvCode() { void HandleRecvCode() {
this->Read(&code_, sizeof(code_)); this->Read(&code_);
if (code_ > RPCCode::kSystemFuncStart) { if (code_ > RPCCode::kSystemFuncStart) {
SwitchToState(kRecvPackedSeqNumArgs); SwitchToState(kRecvPackedSeqNumArgs);
return; return;
...@@ -511,14 +556,14 @@ class RPCSession::EventHandler { ...@@ -511,14 +556,14 @@ class RPCSession::EventHandler {
void HandleCopyFromRemote() { void HandleCopyFromRemote() {
uint64_t handle, offset, size; uint64_t handle, offset, size;
TVMContext ctx; TVMContext ctx;
this->Read(&handle, sizeof(handle)); this->Read(&handle);
this->Read(&offset, sizeof(offset)); this->Read(&offset);
this->Read(&size, sizeof(size)); this->Read(&size);
this->Read(&ctx, sizeof(ctx)); this->Read(&ctx);
if (ctx.device_type == kDLCPU) { if (ctx.device_type == kDLCPU) {
RPCCode code = RPCCode::kCopyAck; RPCCode code = RPCCode::kCopyAck;
writer_->Write(&code, sizeof(code)); this->Write(code);
writer_->Write(reinterpret_cast<char*>(handle) + offset, size); this->WriteArray(reinterpret_cast<char*>(handle) + offset, size);
} else { } else {
temp_data_.resize(size + 1); temp_data_.resize(size + 1);
try { try {
...@@ -530,11 +575,11 @@ class RPCSession::EventHandler { ...@@ -530,11 +575,11 @@ class RPCSession::EventHandler {
dmlc::BeginPtr(temp_data_), 0, dmlc::BeginPtr(temp_data_), 0,
size, ctx, cpu_ctx, nullptr); size, ctx, cpu_ctx, nullptr);
RPCCode code = RPCCode::kCopyAck; RPCCode code = RPCCode::kCopyAck;
writer_->Write(&code, sizeof(code)); this->Write(code);
writer_->Write(&temp_data_[0], size); this->WriteArray(&temp_data_[0], size);
} catch (const std::runtime_error &e) { } catch (const std::runtime_error &e) {
RPCCode code = RPCCode::kException; RPCCode code = RPCCode::kException;
writer_->Write(&code, sizeof(code)); this->Write(code);
TVMValue ret_value; TVMValue ret_value;
ret_value.v_str = e.what(); ret_value.v_str = e.what();
int ret_tcode = kStr; int ret_tcode = kStr;
...@@ -548,10 +593,10 @@ class RPCSession::EventHandler { ...@@ -548,10 +593,10 @@ class RPCSession::EventHandler {
// use static variable to persist state. // use static variable to persist state.
// This only works if next stage is immediately after this. // This only works if next stage is immediately after this.
if (arg_recv_stage_ == 0) { if (arg_recv_stage_ == 0) {
this->Read(&copy_handle_, sizeof(uint64_t)); CHECK(this->Read(&copy_handle_));
this->Read(&copy_offset_, sizeof(uint64_t)); CHECK(this->Read(&copy_offset_));
this->Read(&copy_size_, sizeof(uint64_t)); CHECK(this->Read(&copy_size_));
this->Read(&copy_ctx_, sizeof(TVMContext)); CHECK(this->Read(&copy_ctx_));
arg_recv_stage_ = 1; arg_recv_stage_ = 1;
CHECK_EQ(pending_request_bytes_, 0U); CHECK_EQ(pending_request_bytes_, 0U);
this->RequestBytes(copy_size_); this->RequestBytes(copy_size_);
...@@ -563,11 +608,11 @@ class RPCSession::EventHandler { ...@@ -563,11 +608,11 @@ class RPCSession::EventHandler {
RPCCode code = RPCCode::kReturn; RPCCode code = RPCCode::kReturn;
std::string errmsg; std::string errmsg;
if (copy_ctx_.device_type == kDLCPU) { if (copy_ctx_.device_type == kDLCPU) {
this->Read( this->ReadArray(
reinterpret_cast<char*>(copy_handle_) + copy_offset_, copy_size_); reinterpret_cast<char*>(copy_handle_) + copy_offset_, copy_size_);
} else { } else {
temp_data_.resize(copy_size_ + 1); temp_data_.resize(copy_size_ + 1);
this->Read(&temp_data_[0], copy_size_); this->ReadArray(&temp_data_[0], copy_size_);
try { try {
TVMContext cpu_ctx; TVMContext cpu_ctx;
cpu_ctx.device_type = kDLCPU; cpu_ctx.device_type = kDLCPU;
...@@ -583,7 +628,7 @@ class RPCSession::EventHandler { ...@@ -583,7 +628,7 @@ class RPCSession::EventHandler {
ret_tcode = kStr; ret_tcode = kStr;
} }
} }
writer_->Write(&code, sizeof(code)); this->Write(code);
SendPackedSeq(&ret_value, &ret_tcode, 1); SendPackedSeq(&ret_value, &ret_tcode, 1);
arg_recv_stage_ = 0; arg_recv_stage_ = 0;
this->SwitchToState(kRecvCode); this->SwitchToState(kRecvCode);
...@@ -603,7 +648,7 @@ class RPCSession::EventHandler { ...@@ -603,7 +648,7 @@ class RPCSession::EventHandler {
std::unique_ptr<RPCArgBuffer> args = std::move(arg_buf_); std::unique_ptr<RPCArgBuffer> args = std::move(arg_buf_);
f(args->AsTVMArgs(), &rv); f(args->AsTVMArgs(), &rv);
RPCCode code = RPCCode::kReturn; RPCCode code = RPCCode::kReturn;
writer_->Write(&code, sizeof(code)); this->Write(code);
if (rv.type_code() == kStr) { if (rv.type_code() == kStr) {
ret_value.v_str = rv.ptr<std::string>()->c_str(); ret_value.v_str = rv.ptr<std::string>()->c_str();
ret_tcode = kStr; ret_tcode = kStr;
...@@ -630,7 +675,7 @@ class RPCSession::EventHandler { ...@@ -630,7 +675,7 @@ class RPCSession::EventHandler {
} }
} catch (const std::runtime_error& e) { } catch (const std::runtime_error& e) {
RPCCode code = RPCCode::kException; RPCCode code = RPCCode::kException;
writer_->Write(&code, sizeof(code)); this->Write(code);
ret_value.v_str = e.what(); ret_value.v_str = e.what();
ret_tcode = kStr; ret_tcode = kStr;
SendPackedSeq(&ret_value, &ret_tcode, 1); SendPackedSeq(&ret_value, &ret_tcode, 1);
...@@ -640,19 +685,14 @@ class RPCSession::EventHandler { ...@@ -640,19 +685,14 @@ class RPCSession::EventHandler {
private: private:
// Utility functions // Utility functions
// Internal read function, update pending_request_bytes_ // Internal read function, update pending_request_bytes_
void Read(void* data, size_t size) { size_t Read(void* data, size_t size) final {
CHECK_LE(size, pending_request_bytes_); CHECK_LE(size, pending_request_bytes_);
reader_->Read(data, size); reader_->Read(data, size);
pending_request_bytes_ -= size; pending_request_bytes_ -= size;
return size;
} }
// Request number of bytes from reader. void Write(const void* data, size_t size) final {
void RequestBytes(size_t nbytes) { writer_->Write(data, size);
pending_request_bytes_ += nbytes;
reader_->Reserve(pending_request_bytes_);
}
// Whether we are ready to handle next request.
bool Ready() {
return reader_->bytes_available() >= pending_request_bytes_;
} }
// Number of pending bytes requests // Number of pending bytes requests
size_t pending_request_bytes_; size_t pending_request_bytes_;
...@@ -766,7 +806,7 @@ RPCSession::~RPCSession() { ...@@ -766,7 +806,7 @@ RPCSession::~RPCSession() {
void RPCSession::Shutdown() { void RPCSession::Shutdown() {
if (channel_ != nullptr) { if (channel_ != nullptr) {
RPCCode code = RPCCode::kShutdown; RPCCode code = RPCCode::kShutdown;
writer_.Write(&code, sizeof(code)); handler_->Write(code);
// flush all writing buffer to output channel. // flush all writing buffer to output channel.
try { try {
while (writer_.bytes_available() != 0) { while (writer_.bytes_available() != 0) {
...@@ -788,7 +828,6 @@ void RPCSession::ServerLoop() { ...@@ -788,7 +828,6 @@ void RPCSession::ServerLoop() {
} }
TVMRetValue rv; TVMRetValue rv;
CHECK(HandleUntilReturnEvent(&rv, false, nullptr) == RPCCode::kShutdown); CHECK(HandleUntilReturnEvent(&rv, false, nullptr) == RPCCode::kShutdown);
LOG(INFO) << "Shutdown...";
if (const auto* f = Registry::Get("tvm.contrib.rpc.server.shutdown")) { if (const auto* f = Registry::Get("tvm.contrib.rpc.server.shutdown")) {
(*f)(); (*f)();
} }
...@@ -821,9 +860,9 @@ void RPCSession::CallFunc(void* h, ...@@ -821,9 +860,9 @@ void RPCSession::CallFunc(void* h,
const PackedFunc* fwrap) { const PackedFunc* fwrap) {
std::lock_guard<std::recursive_mutex> lock(mutex_); std::lock_guard<std::recursive_mutex> lock(mutex_);
RPCCode code = RPCCode::kCallFunc; RPCCode code = RPCCode::kCallFunc;
writer_.Write(&code, sizeof(code)); handler_->Write(code);
uint64_t handle = reinterpret_cast<uint64_t>(h); uint64_t handle = reinterpret_cast<uint64_t>(h);
writer_.Write(&handle, sizeof(handle)); handler_->Write(handle);
handler_->SendPackedSeq(args.values, args.type_codes, args.num_args); handler_->SendPackedSeq(args.values, args.type_codes, args.num_args);
code = HandleUntilReturnEvent(rv, true, fwrap); code = HandleUntilReturnEvent(rv, true, fwrap);
CHECK(code == RPCCode::kReturn) << "code=" << static_cast<int>(code); CHECK(code == RPCCode::kReturn) << "code=" << static_cast<int>(code);
...@@ -838,15 +877,15 @@ void RPCSession::CopyToRemote(void* from, ...@@ -838,15 +877,15 @@ void RPCSession::CopyToRemote(void* from,
std::lock_guard<std::recursive_mutex> lock(mutex_); std::lock_guard<std::recursive_mutex> lock(mutex_);
ctx_to = handler_->StripSessMask(ctx_to); ctx_to = handler_->StripSessMask(ctx_to);
RPCCode code = RPCCode::kCopyToRemote; RPCCode code = RPCCode::kCopyToRemote;
writer_.Write(&code, sizeof(code)); handler_->Write(code);
uint64_t handle = reinterpret_cast<uint64_t>(to); uint64_t handle = reinterpret_cast<uint64_t>(to);
writer_.Write(&handle, sizeof(handle)); handler_->Write(handle);
uint64_t offset = static_cast<uint64_t>(to_offset); uint64_t offset = static_cast<uint64_t>(to_offset);
writer_.Write(&offset, sizeof(offset)); handler_->Write(offset);
uint64_t size = static_cast<uint64_t>(data_size); uint64_t size = static_cast<uint64_t>(data_size);
writer_.Write(&size, sizeof(size)); handler_->Write(size);
writer_.Write(&ctx_to, sizeof(ctx_to)); handler_->Write(ctx_to);
writer_.Write(reinterpret_cast<char*>(from) + from_offset, data_size); handler_->WriteArray(reinterpret_cast<char*>(from) + from_offset, data_size);
TVMRetValue rv; TVMRetValue rv;
CHECK(HandleUntilReturnEvent(&rv, true, nullptr) == RPCCode::kReturn); CHECK(HandleUntilReturnEvent(&rv, true, nullptr) == RPCCode::kReturn);
} }
...@@ -860,26 +899,27 @@ void RPCSession::CopyFromRemote(void* from, ...@@ -860,26 +899,27 @@ void RPCSession::CopyFromRemote(void* from,
std::lock_guard<std::recursive_mutex> lock(mutex_); std::lock_guard<std::recursive_mutex> lock(mutex_);
ctx_from = handler_->StripSessMask(ctx_from); ctx_from = handler_->StripSessMask(ctx_from);
RPCCode code = RPCCode::kCopyFromRemote; RPCCode code = RPCCode::kCopyFromRemote;
writer_.Write(&code, sizeof(code)); handler_->Write(code);
uint64_t handle = reinterpret_cast<uint64_t>(from); uint64_t handle = reinterpret_cast<uint64_t>(from);
writer_.Write(&handle, sizeof(handle)); handler_->Write(handle);
uint64_t offset = static_cast<uint64_t>(from_offset); uint64_t offset = static_cast<uint64_t>(from_offset);
writer_.Write(&offset, sizeof(offset)); handler_->Write(offset);
uint64_t size = static_cast<uint64_t>(data_size); uint64_t size = static_cast<uint64_t>(data_size);
writer_.Write(&size, sizeof(size)); handler_->Write(size);
writer_.Write(&ctx_from, sizeof(ctx_from)); handler_->Write(ctx_from);
TVMRetValue rv; TVMRetValue rv;
CHECK(HandleUntilReturnEvent(&rv, true, nullptr) == RPCCode::kCopyAck); CHECK(HandleUntilReturnEvent(&rv, true, nullptr) == RPCCode::kCopyAck);
reader_.Reserve(data_size); reader_.Reserve(data_size);
while (reader_.bytes_available() < data_size) { handler_->RequestBytes(data_size);
size_t bytes_needed = data_size - reader_.bytes_available(); while (!handler_->Ready()) {
size_t bytes_needed = handler_->BytesNeeded();
reader_.WriteWithCallback([this](void* data, size_t size) { reader_.WriteWithCallback([this](void* data, size_t size) {
size_t n = channel_->Recv(data, size); size_t n = channel_->Recv(data, size);
CHECK_NE(n, 0U) << "Channel closes before we get neded bytes"; CHECK_NE(n, 0U) << "Channel closes before we get neded bytes";
return n; return n;
}, bytes_needed); }, bytes_needed);
} }
reader_.Read(reinterpret_cast<char*>(to) + to_offset, data_size); handler_->ReadArray(reinterpret_cast<char*>(to) + to_offset, data_size);
handler_->FinishCopyAck(); handler_->FinishCopyAck();
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment