Commit 42608dda by tqchen Committed by Tianqi Chen

[IO] Support cross-endian

parent 0f9dab98
Subproject commit d3f7fbb53e5b037c0f5bf6bd21871ccc720690cc
Subproject commit 9b3f9753ae81d657743c555e0cacc4e43f0bed2d
/*!
* Copyright (c) 2017 by Contributors
* \file tvm/runtime/serializer.h
* \brief Serializer extension to support TVM data types
* Include this file to enable serialization of DLDataType, DLContext
*/
#ifndef TVM_RUNTIME_SERIALIZER_H_
#define TVM_RUNTIME_SERIALIZER_H_
#include <dmlc/io.h>
#include <dmlc/serializer.h>
#include "./c_runtime_api.h"
namespace dmlc {
namespace serializer {
template<>
struct Handler<DLDataType> {
  /*!
   * \brief Serialize a DLDataType field by field.
   *
   * Each scalar member is routed through its primitive Handler so the
   * stream can perform endian conversion on every field individually
   * (a raw memcpy of the struct would not be cross-endian safe).
   */
  inline static void Write(Stream *strm, const DLDataType& dtype) {
    Handler<uint8_t>::Write(strm, dtype.code);
    Handler<uint8_t>::Write(strm, dtype.bits);
    Handler<uint16_t>::Write(strm, dtype.lanes);
  }
  /*!
   * \brief Deserialize a DLDataType written by Write.
   * \return true on success, false if the stream ran out of data.
   */
  inline static bool Read(Stream *strm, DLDataType* dtype) {
    // Short-circuit: stop at the first field that fails to read,
    // mirroring the write order (code, bits, lanes).
    return Handler<uint8_t>::Read(strm, &(dtype->code)) &&
           Handler<uint8_t>::Read(strm, &(dtype->bits)) &&
           Handler<uint16_t>::Read(strm, &(dtype->lanes));
  }
};
template<>
struct Handler<DLContext> {
  /*!
   * \brief Serialize a DLContext.
   *
   * The device_type enum is widened to a fixed-size int32_t before
   * writing so the on-disk representation is independent of the
   * compiler's enum storage size, and each field goes through the
   * int32_t Handler for endian-aware output.
   */
  inline static void Write(Stream *strm, const DLContext& ctx) {
    Handler<int32_t>::Write(strm, static_cast<int32_t>(ctx.device_type));
    Handler<int32_t>::Write(strm, ctx.device_id);
  }
  /*!
   * \brief Deserialize a DLContext written by Write.
   * \return true on success, false if the stream ran out of data.
   */
  inline static bool Read(Stream *strm, DLContext* ctx) {
    // Read the widened device-type code first, then narrow it back
    // to the enum type.
    int32_t dev_type_code = 0;
    if (!Handler<int32_t>::Read(strm, &dev_type_code)) return false;
    ctx->device_type = static_cast<DLDeviceType>(dev_type_code);
    return Handler<int32_t>::Read(strm, &(ctx->device_id));
  }
};
} // namespace serializer
} // namespace dmlc
#endif // TVM_RUNTIME_SERIALIZER_H_
......@@ -7,6 +7,7 @@
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/serializer.h>
#include "./graph_runtime.h"
namespace nnvm {
......@@ -38,46 +39,53 @@ NNVM_REGISTER_OP(tvm_op)
bool SaveDLTensor(dmlc::Stream* strm, DLTensor* tensor) {
uint64_t header = kTVMNDArrayMagic, reserved = 0;
strm->Write(&header, sizeof(header));
strm->Write(&reserved, sizeof(reserved));
strm->Write(&tensor->ctx, sizeof(tensor->ctx));
strm->Write(&tensor->ndim, sizeof(tensor->ndim));
strm->Write(&tensor->dtype, sizeof(tensor->dtype));
strm->Write(header);
strm->Write(reserved);
strm->Write(tensor->ctx);
strm->Write(tensor->ndim);
strm->Write(tensor->dtype);
int ndim = tensor->ndim;
strm->Write(tensor->shape, sizeof(int64_t) * ndim);
strm->WriteArray(tensor->shape, ndim);
int type_size = tensor->dtype.bits / 8;
int64_t size = 1;
int type_bytes = tensor->dtype.bits / 8;
int64_t num_elems = 1;
for (int i = 0; i < ndim; ++i) {
size *= tensor->shape[i];
num_elems *= tensor->shape[i];
}
int64_t data_byte_size = type_size * size;
strm->Write(&data_byte_size, sizeof(data_byte_size));
int64_t data_byte_size = type_bytes * num_elems;
strm->Write(data_byte_size);
// handle endianness of data correctly.
if (DMLC_IO_NO_ENDIAN_SWAP) {
strm->Write(tensor->data, data_byte_size);
} else {
uint8_t* dptr = reinterpret_cast<uint8_t*>(tensor->data);
std::vector<uint8_t> bytes(dptr, dptr + data_byte_size);
dmlc::ByteSwap(dmlc::BeginPtr(bytes), type_bytes, num_elems);
strm->Write(dmlc::BeginPtr(bytes), data_byte_size);
}
return true;
}
DLTensor* LoadDLTensor(dmlc::Stream* strm) {
uint64_t header, reserved;
CHECK(strm->Read(&header, sizeof(header)))
CHECK(strm->Read(&header))
<< "Invalid DLTensor file format";
CHECK(strm->Read(&reserved, sizeof(reserved)))
CHECK(strm->Read(&reserved))
<< "Invalid DLTensor file format";
CHECK(header == kTVMNDArrayMagic)
<< "Invalid DLTensor file format";
DLTensor tensor;
CHECK(strm->Read(&tensor.ctx, sizeof(tensor.ctx)))
CHECK(strm->Read(&(tensor.ctx)))
<< "Invalid DLTensor file format";
CHECK(strm->Read(&tensor.ndim, sizeof(tensor.ndim)))
CHECK(strm->Read(&(tensor.ndim)))
<< "Invalid DLTensor file format";
CHECK(strm->Read(&tensor.dtype, sizeof(tensor.dtype)))
CHECK(strm->Read(&(tensor.dtype)))
<< "Invalid DLTensor file format";
std::vector<int64_t> shape(tensor.ndim);
CHECK(strm->Read(&shape[0], sizeof(int64_t) * tensor.ndim))
if (tensor.ndim != 0) {
CHECK(strm->ReadArray(&shape[0], tensor.ndim))
<< "Invalid DLTensor file format";
}
DLTensor* ret;
CHECK_EQ(TVMArrayAlloc(shape.data(),
tensor.ndim,
......@@ -87,18 +95,21 @@ DLTensor* LoadDLTensor(dmlc::Stream* strm) {
static_cast<int>(tensor.ctx.device_type),
tensor.ctx.device_id,
&ret), 0) << TVMGetLastError();
int64_t size = 1;
int type_size = ret->dtype.bits / 8;
int64_t num_elems = 1;
int elem_bytes = (ret->dtype.bits + 7) / 8;
for (int i = 0; i < ret->ndim; ++i) {
size *= ret->shape[i];
num_elems *= ret->shape[i];
}
int64_t data_byte_size;
CHECK(strm->Read(&data_byte_size, sizeof(data_byte_size)))
CHECK(strm->Read(&data_byte_size))
<< "Invalid DLTensor file format";
CHECK(data_byte_size == type_size * size)
CHECK(data_byte_size == num_elems * elem_bytes)
<< "Invalid DLTensor file format";
CHECK(strm->Read(ret->data, type_size * size))
CHECK(strm->Read(ret->data, data_byte_size))
<< "Invalid DLTensor file format";
if (!DMLC_IO_NO_ENDIAN_SWAP) {
dmlc::ByteSwap(ret->data, elem_bytes, num_elems);
}
return ret;
}
......@@ -118,12 +129,12 @@ TVM_REGISTER_GLOBAL("nnvm.compiler._save_param_dict")
dmlc::MemoryStringStream strm(&bytes);
dmlc::Stream* fo = &strm;
uint64_t header = kTVMNDArrayListMagic, reserved = 0;
fo->Write(&header, sizeof(header));
fo->Write(&reserved, sizeof(reserved));
fo->Write(header);
fo->Write(reserved);
fo->Write(names);
{
uint64_t sz = static_cast<uint64_t>(arrays.size());
fo->Write(&sz, sizeof(sz));
fo->Write(sz);
for (size_t i = 0; i < sz; ++i) {
SaveDLTensor(fo, arrays[i]);
}
......@@ -150,7 +161,6 @@ TVM_REGISTER_GLOBAL("nnvm.compiler._load_param_dict")
<< "Invalid parameters file format";
CHECK(strm->Read(&reserved))
<< "Invalid parameters file format";
CHECK(strm->Read(&names))
<< "Invalid parameters file format";
uint64_t sz;
......
......@@ -73,7 +73,7 @@ def sendjson(sock, data):
Python value to be sent.
"""
data = json.dumps(data)
sock.sendall(struct.pack("@i", len(data)))
sock.sendall(struct.pack("<i", len(data)))
sock.sendall(data.encode("utf-8"))
......@@ -90,7 +90,7 @@ def recvjson(sock):
value : object
The value received.
"""
size = struct.unpack("@i", recvall(sock, 4))[0]
size = struct.unpack("<i", recvall(sock, 4))[0]
data = json.loads(py_str(recvall(sock, size)))
return data
......
......@@ -192,8 +192,8 @@ class TrackerSession(object):
def _connect(self):
self._sock = base.connect_with_retry(self._addr)
self._sock.sendall(struct.pack("@i", base.RPC_TRACKER_MAGIC))
magic = struct.unpack("@i", base.recvall(self._sock, 4))[0]
self._sock.sendall(struct.pack("<i", base.RPC_TRACKER_MAGIC))
magic = struct.unpack("<i", base.recvall(self._sock, 4))[0]
if magic != base.RPC_TRACKER_MAGIC:
raise RuntimeError("%s is not RPC Tracker" % str(self._addr))
......
......@@ -58,14 +58,14 @@ class ForwardHandler(object):
def _init_step(self, message):
if self._magic is None:
assert len(message) == 4
self._magic = struct.unpack('@i', message)[0]
self._magic = struct.unpack('<i', message)[0]
if self._magic != base.RPC_MAGIC:
logging.info("Invalid RPC magic from %s", self.name())
self.close()
self._init_req_nbytes = 4
elif self._rpc_key_length is None:
assert len(message) == 4
self._rpc_key_length = struct.unpack('@i', message)[0]
self._rpc_key_length = struct.unpack('<i', message)[0]
self._init_req_nbytes = self._rpc_key_length
elif self.rpc_key is None:
assert len(message) == self._rpc_key_length
......@@ -269,12 +269,12 @@ class ProxyServerHandler(object):
lhs.forward_proxy = rhs
rhs.forward_proxy = lhs
lhs.send_data(struct.pack('@i', base.RPC_CODE_SUCCESS))
lhs.send_data(struct.pack('@i', len(rhs.rpc_key)))
lhs.send_data(struct.pack('<i', base.RPC_CODE_SUCCESS))
lhs.send_data(struct.pack('<i', len(rhs.rpc_key)))
lhs.send_data(rhs.rpc_key.encode("utf-8"))
rhs.send_data(struct.pack('@i', base.RPC_CODE_SUCCESS))
rhs.send_data(struct.pack('@i', len(lhs.rpc_key)))
rhs.send_data(struct.pack('<i', base.RPC_CODE_SUCCESS))
rhs.send_data(struct.pack('<i', len(lhs.rpc_key)))
rhs.send_data(lhs.rpc_key.encode("utf-8"))
logging.info("Pairup connect %s and %s", lhs.name(), rhs.name())
......@@ -299,8 +299,8 @@ class ProxyServerHandler(object):
if self._tracker_conn is None:
self._tracker_conn = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self._tracker_conn.connect(self._tracker_addr)
self._tracker_conn.sendall(struct.pack("@i", base.RPC_TRACKER_MAGIC))
magic = struct.unpack("@i", base.recvall(self._tracker_conn, 4))[0]
self._tracker_conn.sendall(struct.pack("<i", base.RPC_TRACKER_MAGIC))
magic = struct.unpack("<i", base.recvall(self._tracker_conn, 4))[0]
if magic != base.RPC_TRACKER_MAGIC:
self.loop.stop()
raise RuntimeError("%s is not RPC Tracker" % str(self._tracker_addr))
......@@ -371,7 +371,7 @@ class ProxyServerHandler(object):
if handler.match_key in self._server_pool:
self._pair_up(self._server_pool.pop(handler.match_key), handler)
else:
handler.send_data(struct.pack('@i', base.RPC_CODE_MISMATCH))
handler.send_data(struct.pack('<i', base.RPC_CODE_MISMATCH))
handler.signal_close()
def _handler_ready_proxy_mode(self, handler):
......@@ -395,12 +395,12 @@ class ProxyServerHandler(object):
logging.info("Timeout client connection %s, cannot find match key=%s",
handler.name(), key)
pool_dst.pop(key)
handler.send_data(struct.pack('@i', base.RPC_CODE_MISMATCH))
handler.send_data(struct.pack('<i', base.RPC_CODE_MISMATCH))
handler.signal_close()
self.loop.call_later(timeout, cleanup)
else:
logging.info("Duplicate connection with same key=%s", key)
handler.send_data(struct.pack('@i', base.RPC_CODE_DUPLICATE))
handler.send_data(struct.pack('<i', base.RPC_CODE_DUPLICATE))
handler.signal_close()
def handler_ready(self, handler):
......@@ -538,13 +538,13 @@ def websocket_proxy_server(url, key=""):
on_message = create_on_message(conn)
temp = _server_env(None)
# Start connection
conn.write_message(struct.pack('@i', base.RPC_MAGIC), binary=True)
conn.write_message(struct.pack('<i', base.RPC_MAGIC), binary=True)
key = "server:" + key
conn.write_message(struct.pack('@i', len(key)), binary=True)
conn.write_message(struct.pack('<i', len(key)), binary=True)
conn.write_message(key.encode("utf-8"), binary=True)
msg = yield conn.read_message()
assert len(msg) >= 4
magic = struct.unpack('@i', msg[:4])[0]
magic = struct.unpack('<i', msg[:4])[0]
if magic == base.RPC_CODE_DUPLICATE:
raise RuntimeError("key: %s has already been used in proxy" % key)
elif magic == base.RPC_CODE_MISMATCH:
......
......@@ -124,23 +124,23 @@ def _listen_loop(sock, port, rpc_key, tracker_addr, load_library):
unmatch_period_count = 0
continue
conn, addr = listen_sock.accept()
magic = struct.unpack("@i", base.recvall(conn, 4))[0]
magic = struct.unpack("<i", base.recvall(conn, 4))[0]
if magic != base.RPC_MAGIC:
conn.close()
continue
keylen = struct.unpack("@i", base.recvall(conn, 4))[0]
keylen = struct.unpack("<i", base.recvall(conn, 4))[0]
key = py_str(base.recvall(conn, keylen))
arr = key.split()
expect_header = "client:" + matchkey
server_key = "server:" + rpc_key
if arr[0] != expect_header:
conn.sendall(struct.pack("@i", base.RPC_CODE_MISMATCH))
conn.sendall(struct.pack("<i", base.RPC_CODE_MISMATCH))
conn.close()
logging.info("RPCServer: mismatch key from %s", addr)
continue
else:
conn.sendall(struct.pack("@i", base.RPC_CODE_SUCCESS))
conn.sendall(struct.pack("@i", len(server_key)))
conn.sendall(struct.pack("<i", base.RPC_CODE_SUCCESS))
conn.sendall(struct.pack("<i", len(server_key)))
conn.sendall(server_key.encode("utf-8"))
return conn, addr, _parse_server_opt(arr[1:])
......@@ -151,8 +151,8 @@ def _listen_loop(sock, port, rpc_key, tracker_addr, load_library):
# step 1: setup tracker and report to tracker
if tracker_addr and tracker_conn is None:
tracker_conn = base.connect_with_retry(tracker_addr)
tracker_conn.sendall(struct.pack("@i", base.RPC_TRACKER_MAGIC))
magic = struct.unpack("@i", base.recvall(tracker_conn, 4))[0]
tracker_conn.sendall(struct.pack("<i", base.RPC_TRACKER_MAGIC))
magic = struct.unpack("<i", base.recvall(tracker_conn, 4))[0]
if magic != base.RPC_TRACKER_MAGIC:
raise RuntimeError("%s is not RPC Tracker" % str(tracker_addr))
# report status of current queue
......@@ -193,17 +193,17 @@ def _connect_proxy_loop(addr, key, load_library):
try:
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.connect(addr)
sock.sendall(struct.pack("@i", base.RPC_MAGIC))
sock.sendall(struct.pack("@i", len(key)))
sock.sendall(struct.pack("<i", base.RPC_MAGIC))
sock.sendall(struct.pack("<i", len(key)))
sock.sendall(key.encode("utf-8"))
magic = struct.unpack("@i", base.recvall(sock, 4))[0]
magic = struct.unpack("<i", base.recvall(sock, 4))[0]
if magic == base.RPC_CODE_DUPLICATE:
raise RuntimeError("key: %s has already been used in proxy" % key)
elif magic == base.RPC_CODE_MISMATCH:
logging.info("RPCProxy do not have matching client key %s", key)
elif magic != base.RPC_CODE_SUCCESS:
raise RuntimeError("%s is not RPC Proxy" % str(addr))
keylen = struct.unpack("@i", base.recvall(sock, 4))[0]
keylen = struct.unpack("<i", base.recvall(sock, 4))[0]
remote_key = py_str(base.recvall(sock, keylen))
opts = _parse_server_opt(remote_key.split()[1:])
logging.info("RPCProxy connected to %s", str(addr))
......
......@@ -143,11 +143,11 @@ class TCPEventHandler(tornado_util.TCPHandler):
if len(message) != 4:
logging.info("Invalid connection from %s", self.name())
self.close()
magic = struct.unpack('@i', message)[0]
magic = struct.unpack('<i', message)[0]
if magic != RPC_TRACKER_MAGIC:
logging.info("Invalid magic from %s", self.name())
self.close()
self.write_message(struct.pack('@i', RPC_TRACKER_MAGIC), binary=True)
self.write_message(struct.pack('<i', RPC_TRACKER_MAGIC), binary=True)
self._init_req_nbytes = 0
def on_message(self, message):
......@@ -168,7 +168,7 @@ class TCPEventHandler(tornado_util.TCPHandler):
while True:
if self._msg_size == 0:
if len(self._data) >= 4:
self._msg_size = struct.unpack('@i', self._data[:4])[0]
self._msg_size = struct.unpack('<i', self._data[:4])[0]
else:
return
if self._msg_size != 0 and len(self._data) >= self._msg_size + 4:
......@@ -184,7 +184,7 @@ class TCPEventHandler(tornado_util.TCPHandler):
"""return value to the output"""
data = json.dumps(data)
self.write_message(
struct.pack('@i', len(data)), binary=True)
struct.pack('<i', len(data)), binary=True)
self.write_message(data.encode("utf-8"), binary=True)
def call_handler(self, args):
......@@ -355,8 +355,8 @@ class Tracker(object):
def _stop_tracker(self):
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.connect((self.host, self.port))
sock.sendall(struct.pack("@i", base.RPC_TRACKER_MAGIC))
magic = struct.unpack("@i", base.recvall(sock, 4))[0]
sock.sendall(struct.pack("<i", base.RPC_TRACKER_MAGIC))
magic = struct.unpack("<i", base.recvall(sock, 4))[0]
assert magic == base.RPC_TRACKER_MAGIC
base.sendjson(sock, [TrackerCode.STOP, self.stop_key])
assert base.recvjson(sock) == TrackerCode.SUCCESS
......
......@@ -4,6 +4,7 @@
*/
#include <dmlc/json.h>
#include <dmlc/logging.h>
#include <tvm/runtime/serializer.h>
#include <fstream>
#include "./file_util.h"
......
......@@ -4,6 +4,7 @@
*/
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/serializer.h>
#include <dmlc/memory_io.h>
#include <dmlc/json.h>
#include <numeric>
......@@ -397,24 +398,25 @@ class GraphRuntime : public ModuleNode {
void GraphRuntime::LoadDLTensor(dmlc::Stream* strm, DLTensor* dst) {
// always use strm->Read to maintain endianness conversion
uint64_t header, reserved;
CHECK(strm->Read(&header, sizeof(header)))
CHECK(strm->Read(&header))
<< "Invalid DLTensor file format";
CHECK(strm->Read(&reserved, sizeof(reserved)))
CHECK(strm->Read(&reserved))
<< "Invalid DLTensor file format";
CHECK(header == kTVMNDArrayMagic)
<< "Invalid DLTensor file format";
DLTensor tensor;
CHECK(strm->Read(&tensor.ctx, sizeof(tensor.ctx)))
CHECK(strm->Read(&(tensor.ctx)))
<< "Invalid DLTensor file format";
CHECK(strm->Read(&tensor.ndim, sizeof(tensor.ndim)))
CHECK(strm->Read(&(tensor.ndim)))
<< "Invalid DLTensor file format";
CHECK(strm->Read(&tensor.dtype, sizeof(tensor.dtype)))
CHECK(strm->Read(&(tensor.dtype)))
<< "Invalid DLTensor file format";
std::vector<int64_t> shape(tensor.ndim);
if (tensor.ndim != 0) {
CHECK(strm->Read(&shape[0], sizeof(int64_t) * tensor.ndim))
CHECK(strm->ReadArray(&shape[0], tensor.ndim))
<< "Invalid DLTensor file format";
}
CHECK_EQ(tensor.ndim, dst->ndim) << "param dimension mismatch";
......@@ -425,18 +427,23 @@ void GraphRuntime::LoadDLTensor(dmlc::Stream* strm, DLTensor* dst) {
CHECK_EQ(shape[i], dst->shape[i]) << "param shape mismatch";
}
size_t bits = dst->dtype.bits * dst->dtype.lanes;
size_t size = (bits + 7) / 8;
size_t elem_bytes = (bits + 7) / 8;
size_t num_elems = 1;
for (int i = 0; i < dst->ndim; ++i) {
size *= dst->shape[i];
num_elems *= dst->shape[i];
}
uint64_t data_byte_size;
CHECK(strm->Read(&data_byte_size, sizeof(data_byte_size)))
CHECK(strm->Read(&data_byte_size))
<< "Invalid DLTensor file format";
CHECK(data_byte_size == size)
CHECK_EQ(data_byte_size, elem_bytes * num_elems)
<< "Invalid DLTensor file format";
std::vector<uint8_t> bytes(data_byte_size + 1);
CHECK(strm->Read(&bytes[0], data_byte_size))
<< "Invalid DLTensor file format";
// explicitly swap endian when necessary.
if (!DMLC_IO_NO_ENDIAN_SWAP) {
dmlc::ByteSwap(&bytes[0], elem_bytes, num_elems);
}
TVM_CCALL(TVMArrayCopyFromBytes(dst, &bytes[0], data_byte_size));
}
......@@ -453,9 +460,8 @@ void GraphRuntime::LoadParams(dmlc::Stream* strm) {
CHECK(strm->Read(&names))
<< "Invalid parameters file format";
uint64_t sz;
strm->Read(&sz, sizeof(sz));
strm->Read(&sz);
size_t size = static_cast<size_t>(sz);
CHECK(size == names.size())
<< "Invalid parameters file format";
for (size_t i = 0; i < size; ++i) {
......
......@@ -6,6 +6,7 @@
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/serializer.h>
#include <memory>
#include <array>
#include <string>
......@@ -44,7 +45,7 @@ struct RPCArgBuffer {
};
// Event handler for RPC events.
class RPCSession::EventHandler {
class RPCSession::EventHandler : public dmlc::Stream {
public:
EventHandler(common::RingBuffer* reader,
common::RingBuffer* writer,
......@@ -71,6 +72,15 @@ class RPCSession::EventHandler {
return 0;
}
}
// Request number of bytes from reader.
void RequestBytes(size_t nbytes) {
pending_request_bytes_ += nbytes;
reader_->Reserve(pending_request_bytes_);
}
// Whether we are ready to handle next request.
bool Ready() {
return reader_->bytes_available() >= pending_request_bytes_;
}
bool CanCleanShutdown() const {
return state_ == kRecvCode;
}
......@@ -86,12 +96,12 @@ class RPCSession::EventHandler {
case kInitHeader: HandleInitHeader(); break;
case kRecvCode: HandleRecvCode(); break;
case kRecvCallHandle: {
this->Read(&call_handle_, sizeof(call_handle_));
CHECK(this->Read(&call_handle_));
this->SwitchToState(kRecvPackedSeqNumArgs);
break;
}
case kRecvPackedSeqNumArgs: {
this->Read(&num_packed_args_, sizeof(num_packed_args_));
CHECK(this->Read(&num_packed_args_));
arg_buf_.reset(new RPCArgBuffer());
arg_buf_->value.resize(num_packed_args_);
arg_buf_->tcode.resize(num_packed_args_);
......@@ -100,7 +110,7 @@ class RPCSession::EventHandler {
}
case kRecvPackedSeqTypeCode: {
if (num_packed_args_ != 0) {
this->Read(arg_buf_->tcode.data(), sizeof(int) * num_packed_args_);
this->ReadArray(arg_buf_->tcode.data(), num_packed_args_);
}
arg_index_ = 0;
arg_recv_stage_ = 0;
......@@ -164,8 +174,8 @@ class RPCSession::EventHandler {
}
// send Packed sequence to writer.
void SendPackedSeq(const TVMValue* arg_values, const int* type_codes, int n) {
writer_->Write(&n, sizeof(n));
writer_->Write(type_codes, sizeof(int) * n);
this->Write(n);
this->WriteArray(type_codes, n);
// Argument packing.
for (int i = 0; i < n; ++i) {
int tcode = type_codes[i];
......@@ -173,14 +183,20 @@ class RPCSession::EventHandler {
switch (tcode) {
case kDLInt:
case kDLUInt:
case kDLFloat:
case kDLFloat: {
this->Write<int64_t>(value.v_int64);
break;
}
case kTVMType: {
writer_->Write(&value, sizeof(TVMValue));
this->Write(value.v_type);
// padding
int32_t padding = 0;
this->Write<int32_t>(padding);
break;
}
case kTVMContext: {
value.v_ctx = StripSessMask(value.v_ctx);
writer_->Write(&value, sizeof(TVMValue));
this->Write(value.v_ctx);
break;
}
case kFuncHandle:
......@@ -188,7 +204,7 @@ class RPCSession::EventHandler {
case kHandle: {
// always send handle in 64 bit.
uint64_t handle = reinterpret_cast<uint64_t>(value.v_handle);
writer_->Write(&handle, sizeof(uint64_t));
this->Write(handle);
break;
}
case kArrayHandle: {
......@@ -196,11 +212,11 @@ class RPCSession::EventHandler {
TVMContext ctx = StripSessMask(arr->ctx);
uint64_t data = reinterpret_cast<uint64_t>(
static_cast<RemoteSpace*>(arr->data)->data);
writer_->Write(&data, sizeof(uint64_t));
writer_->Write(&ctx, sizeof(ctx));
writer_->Write(&(arr->ndim), sizeof(int));
writer_->Write(&(arr->dtype), sizeof(DLDataType));
writer_->Write(arr->shape, sizeof(int64_t) * arr->ndim);
this->Write(data);
this->Write(ctx);
this->Write(arr->ndim);
this->Write(arr->dtype);
this->WriteArray(arr->shape, arr->ndim);
CHECK(arr->strides == nullptr)
<< "Donot support strided remote array";
CHECK_EQ(arr->byte_offset, 0)
......@@ -211,15 +227,15 @@ class RPCSession::EventHandler {
case kStr: {
const char* s = value.v_str;
uint64_t len = strlen(s);
writer_->Write(&len, sizeof(len));
writer_->Write(s, sizeof(char) * len);
this->Write(len);
this->WriteArray(s, len);
break;
}
case kBytes: {
TVMByteArray* bytes = static_cast<TVMByteArray*>(arg_values[i].v_handle);
uint64_t len = bytes->size;
writer_->Write(&len, sizeof(len));
writer_->Write(bytes->data, sizeof(char) * len);
this->Write(len);
this->WriteArray(bytes->data, len);
break;
}
default: {
......@@ -230,6 +246,23 @@ class RPCSession::EventHandler {
}
}
// Endian aware IO handling
using Stream::Read;
using Stream::Write;
using Stream::ReadArray;
using Stream::WriteArray;
inline bool Read(RPCCode* code) {
int cdata;
if (!this->Read(&cdata)) return false;
*code = static_cast<RPCCode>(cdata);
return true;
}
inline void Write(RPCCode code) {
int cdata = static_cast<int>(code);
this->Write(cdata);
}
protected:
enum State {
kInitHeader,
......@@ -370,10 +403,22 @@ class RPCSession::EventHandler {
switch (tcode) {
case kDLInt:
case kDLUInt:
case kDLFloat:
case kTVMType:
case kDLFloat: {
this->Read<int64_t>(&(value.v_int64));
++arg_index_;
this->SwitchToState(kRecvPackedSeqArg);
break;
}
case kTVMType: {
this->Read(&(value.v_type));
int32_t padding = 0;
this->Read<int32_t>(&padding);
++arg_index_;
this->SwitchToState(kRecvPackedSeqArg);
break;
}
case kTVMContext: {
this->Read(&value, sizeof(TVMValue));
this->Read(&(value.v_ctx));
++arg_index_;
this->SwitchToState(kRecvPackedSeqArg);
break;
......@@ -383,7 +428,7 @@ class RPCSession::EventHandler {
case kHandle: {
// always send handle in 64 bit.
uint64_t handle;
this->Read(&handle, sizeof(handle));
this->Read(&handle);
value.v_handle = reinterpret_cast<void*>(handle);
++arg_index_;
this->SwitchToState(kRecvPackedSeqArg);
......@@ -398,7 +443,7 @@ class RPCSession::EventHandler {
case kStr:
case kBytes: {
uint64_t len;
this->Read(&len, sizeof(len));
this->Read(&len);
temp_bytes_.reset( new RPCByteArrayBuffer());
temp_bytes_->data.resize(len);
arg_recv_stage_ = 1;
......@@ -409,12 +454,12 @@ class RPCSession::EventHandler {
case kArrayHandle: {
temp_array_.reset(new RPCDataArrayBuffer());
uint64_t handle;
this->Read(&handle, sizeof(handle));
this->Read(&handle);
DLTensor& tensor = temp_array_->tensor;
tensor.data = reinterpret_cast<void*>(handle);
this->Read(&(tensor.ctx), sizeof(TVMContext));
this->Read(&(tensor.ndim), sizeof(int));
this->Read(&(tensor.dtype), sizeof(DLDataType));
this->Read(&(tensor.ctx));
this->Read(&(tensor.ndim));
this->Read(&(tensor.dtype));
temp_array_->shape.resize(tensor.ndim);
tensor.shape = temp_array_->shape.data();
arg_recv_stage_ = 1;
......@@ -432,7 +477,7 @@ class RPCSession::EventHandler {
CHECK_EQ(arg_recv_stage_, 1);
if (tcode == kStr || tcode == kBytes) {
if (temp_bytes_->data.size() != 0) {
this->Read(&(temp_bytes_->data[0]), temp_bytes_->data.size());
this->ReadArray(&(temp_bytes_->data[0]), temp_bytes_->data.size());
}
if (tcode == kStr) {
value.v_str = temp_bytes_->data.c_str();
......@@ -445,7 +490,7 @@ class RPCSession::EventHandler {
} else {
CHECK_EQ(tcode, kArrayHandle);
DLTensor& tensor = temp_array_->tensor;
this->Read(tensor.shape, tensor.ndim * sizeof(int64_t));
this->ReadArray(tensor.shape, tensor.ndim);
value.v_handle = &tensor;
arg_buf_->temp_array.emplace_back(std::move(temp_array_));
}
......@@ -458,20 +503,20 @@ class RPCSession::EventHandler {
void HandleInitHeader() {
if (init_header_step_ == 0) {
int32_t len;
this->Read(&len, sizeof(len));
this->Read(&len);
remote_key_->resize(len);
init_header_step_ = 1;
this->RequestBytes(len);
return;
} else {
CHECK_EQ(init_header_step_, 1);
this->Read(dmlc::BeginPtr(*remote_key_), remote_key_->length());
this->ReadArray(dmlc::BeginPtr(*remote_key_), remote_key_->length());
this->SwitchToState(kRecvCode);
}
}
// Handler for read code.
void HandleRecvCode() {
this->Read(&code_, sizeof(code_));
this->Read(&code_);
if (code_ > RPCCode::kSystemFuncStart) {
SwitchToState(kRecvPackedSeqNumArgs);
return;
......@@ -511,14 +556,14 @@ class RPCSession::EventHandler {
void HandleCopyFromRemote() {
uint64_t handle, offset, size;
TVMContext ctx;
this->Read(&handle, sizeof(handle));
this->Read(&offset, sizeof(offset));
this->Read(&size, sizeof(size));
this->Read(&ctx, sizeof(ctx));
this->Read(&handle);
this->Read(&offset);
this->Read(&size);
this->Read(&ctx);
if (ctx.device_type == kDLCPU) {
RPCCode code = RPCCode::kCopyAck;
writer_->Write(&code, sizeof(code));
writer_->Write(reinterpret_cast<char*>(handle) + offset, size);
this->Write(code);
this->WriteArray(reinterpret_cast<char*>(handle) + offset, size);
} else {
temp_data_.resize(size + 1);
try {
......@@ -530,11 +575,11 @@ class RPCSession::EventHandler {
dmlc::BeginPtr(temp_data_), 0,
size, ctx, cpu_ctx, nullptr);
RPCCode code = RPCCode::kCopyAck;
writer_->Write(&code, sizeof(code));
writer_->Write(&temp_data_[0], size);
this->Write(code);
this->WriteArray(&temp_data_[0], size);
} catch (const std::runtime_error &e) {
RPCCode code = RPCCode::kException;
writer_->Write(&code, sizeof(code));
this->Write(code);
TVMValue ret_value;
ret_value.v_str = e.what();
int ret_tcode = kStr;
......@@ -548,10 +593,10 @@ class RPCSession::EventHandler {
// use static variable to persist state.
// This only works if next stage is immediately after this.
if (arg_recv_stage_ == 0) {
this->Read(&copy_handle_, sizeof(uint64_t));
this->Read(&copy_offset_, sizeof(uint64_t));
this->Read(&copy_size_, sizeof(uint64_t));
this->Read(&copy_ctx_, sizeof(TVMContext));
CHECK(this->Read(&copy_handle_));
CHECK(this->Read(&copy_offset_));
CHECK(this->Read(&copy_size_));
CHECK(this->Read(&copy_ctx_));
arg_recv_stage_ = 1;
CHECK_EQ(pending_request_bytes_, 0U);
this->RequestBytes(copy_size_);
......@@ -563,11 +608,11 @@ class RPCSession::EventHandler {
RPCCode code = RPCCode::kReturn;
std::string errmsg;
if (copy_ctx_.device_type == kDLCPU) {
this->Read(
this->ReadArray(
reinterpret_cast<char*>(copy_handle_) + copy_offset_, copy_size_);
} else {
temp_data_.resize(copy_size_ + 1);
this->Read(&temp_data_[0], copy_size_);
this->ReadArray(&temp_data_[0], copy_size_);
try {
TVMContext cpu_ctx;
cpu_ctx.device_type = kDLCPU;
......@@ -583,7 +628,7 @@ class RPCSession::EventHandler {
ret_tcode = kStr;
}
}
writer_->Write(&code, sizeof(code));
this->Write(code);
SendPackedSeq(&ret_value, &ret_tcode, 1);
arg_recv_stage_ = 0;
this->SwitchToState(kRecvCode);
......@@ -603,7 +648,7 @@ class RPCSession::EventHandler {
std::unique_ptr<RPCArgBuffer> args = std::move(arg_buf_);
f(args->AsTVMArgs(), &rv);
RPCCode code = RPCCode::kReturn;
writer_->Write(&code, sizeof(code));
this->Write(code);
if (rv.type_code() == kStr) {
ret_value.v_str = rv.ptr<std::string>()->c_str();
ret_tcode = kStr;
......@@ -630,7 +675,7 @@ class RPCSession::EventHandler {
}
} catch (const std::runtime_error& e) {
RPCCode code = RPCCode::kException;
writer_->Write(&code, sizeof(code));
this->Write(code);
ret_value.v_str = e.what();
ret_tcode = kStr;
SendPackedSeq(&ret_value, &ret_tcode, 1);
......@@ -640,19 +685,14 @@ class RPCSession::EventHandler {
private:
// Utility functions
// Internal read function, update pending_request_bytes_
void Read(void* data, size_t size) {
size_t Read(void* data, size_t size) final {
CHECK_LE(size, pending_request_bytes_);
reader_->Read(data, size);
pending_request_bytes_ -= size;
return size;
}
// Request number of bytes from reader.
void RequestBytes(size_t nbytes) {
pending_request_bytes_ += nbytes;
reader_->Reserve(pending_request_bytes_);
}
// Whether we are ready to handle next request.
bool Ready() {
return reader_->bytes_available() >= pending_request_bytes_;
void Write(const void* data, size_t size) final {
writer_->Write(data, size);
}
// Number of pending bytes requests
size_t pending_request_bytes_;
......@@ -766,7 +806,7 @@ RPCSession::~RPCSession() {
void RPCSession::Shutdown() {
if (channel_ != nullptr) {
RPCCode code = RPCCode::kShutdown;
writer_.Write(&code, sizeof(code));
handler_->Write(code);
// flush all writing buffer to output channel.
try {
while (writer_.bytes_available() != 0) {
......@@ -788,7 +828,6 @@ void RPCSession::ServerLoop() {
}
TVMRetValue rv;
CHECK(HandleUntilReturnEvent(&rv, false, nullptr) == RPCCode::kShutdown);
LOG(INFO) << "Shutdown...";
if (const auto* f = Registry::Get("tvm.contrib.rpc.server.shutdown")) {
(*f)();
}
......@@ -821,9 +860,9 @@ void RPCSession::CallFunc(void* h,
const PackedFunc* fwrap) {
std::lock_guard<std::recursive_mutex> lock(mutex_);
RPCCode code = RPCCode::kCallFunc;
writer_.Write(&code, sizeof(code));
handler_->Write(code);
uint64_t handle = reinterpret_cast<uint64_t>(h);
writer_.Write(&handle, sizeof(handle));
handler_->Write(handle);
handler_->SendPackedSeq(args.values, args.type_codes, args.num_args);
code = HandleUntilReturnEvent(rv, true, fwrap);
CHECK(code == RPCCode::kReturn) << "code=" << static_cast<int>(code);
......@@ -838,15 +877,15 @@ void RPCSession::CopyToRemote(void* from,
std::lock_guard<std::recursive_mutex> lock(mutex_);
ctx_to = handler_->StripSessMask(ctx_to);
RPCCode code = RPCCode::kCopyToRemote;
writer_.Write(&code, sizeof(code));
handler_->Write(code);
uint64_t handle = reinterpret_cast<uint64_t>(to);
writer_.Write(&handle, sizeof(handle));
handler_->Write(handle);
uint64_t offset = static_cast<uint64_t>(to_offset);
writer_.Write(&offset, sizeof(offset));
handler_->Write(offset);
uint64_t size = static_cast<uint64_t>(data_size);
writer_.Write(&size, sizeof(size));
writer_.Write(&ctx_to, sizeof(ctx_to));
writer_.Write(reinterpret_cast<char*>(from) + from_offset, data_size);
handler_->Write(size);
handler_->Write(ctx_to);
handler_->WriteArray(reinterpret_cast<char*>(from) + from_offset, data_size);
TVMRetValue rv;
CHECK(HandleUntilReturnEvent(&rv, true, nullptr) == RPCCode::kReturn);
}
......@@ -860,26 +899,27 @@ void RPCSession::CopyFromRemote(void* from,
std::lock_guard<std::recursive_mutex> lock(mutex_);
ctx_from = handler_->StripSessMask(ctx_from);
RPCCode code = RPCCode::kCopyFromRemote;
writer_.Write(&code, sizeof(code));
handler_->Write(code);
uint64_t handle = reinterpret_cast<uint64_t>(from);
writer_.Write(&handle, sizeof(handle));
handler_->Write(handle);
uint64_t offset = static_cast<uint64_t>(from_offset);
writer_.Write(&offset, sizeof(offset));
handler_->Write(offset);
uint64_t size = static_cast<uint64_t>(data_size);
writer_.Write(&size, sizeof(size));
writer_.Write(&ctx_from, sizeof(ctx_from));
handler_->Write(size);
handler_->Write(ctx_from);
TVMRetValue rv;
CHECK(HandleUntilReturnEvent(&rv, true, nullptr) == RPCCode::kCopyAck);
reader_.Reserve(data_size);
while (reader_.bytes_available() < data_size) {
size_t bytes_needed = data_size - reader_.bytes_available();
handler_->RequestBytes(data_size);
while (!handler_->Ready()) {
size_t bytes_needed = handler_->BytesNeeded();
reader_.WriteWithCallback([this](void* data, size_t size) {
size_t n = channel_->Recv(data, size);
CHECK_NE(n, 0U) << "Channel closes before we get neded bytes";
return n;
}, bytes_needed);
}
reader_.Read(reinterpret_cast<char*>(to) + to_offset, data_size);
handler_->ReadArray(reinterpret_cast<char*>(to) + to_offset, data_size);
handler_->FinishCopyAck();
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment