Commit 42608dda by tqchen Committed by Tianqi Chen

[IO] Support cross-endian

parent 0f9dab98
Subproject commit d3f7fbb53e5b037c0f5bf6bd21871ccc720690cc
Subproject commit 9b3f9753ae81d657743c555e0cacc4e43f0bed2d
/*!
* Copyright (c) 2017 by Contributors
* \file tvm/runtime/serializer.h
* \brief Serializer extension to support TVM data types
* Include this file to enable serialization of DLDataType, DLContext
*/
#ifndef TVM_RUNTIME_SERIALIZER_H_
#define TVM_RUNTIME_SERIALIZER_H_
#include <dmlc/io.h>
#include <dmlc/serializer.h>
#include "./c_runtime_api.h"
namespace dmlc {
namespace serializer {
template<>
struct Handler<DLDataType> {
inline static void Write(Stream *strm, const DLDataType& dtype) {
Handler<uint8_t>::Write(strm, dtype.code);
Handler<uint8_t>::Write(strm, dtype.bits);
Handler<uint16_t>::Write(strm, dtype.lanes);
}
inline static bool Read(Stream *strm, DLDataType* dtype) {
if (!Handler<uint8_t>::Read(strm, &(dtype->code))) return false;
if (!Handler<uint8_t>::Read(strm, &(dtype->bits))) return false;
if (!Handler<uint16_t>::Read(strm, &(dtype->lanes))) return false;
return true;
}
};
template<>
struct Handler<DLContext> {
inline static void Write(Stream *strm, const DLContext& ctx) {
int32_t device_type = static_cast<int32_t>(ctx.device_type);
Handler<int32_t>::Write(strm, device_type);
Handler<int32_t>::Write(strm, ctx.device_id);
}
inline static bool Read(Stream *strm, DLContext* ctx) {
int32_t device_type = 0;
if (!Handler<int32_t>::Read(strm, &(device_type))) return false;
ctx->device_type = static_cast<DLDeviceType>(device_type);
if (!Handler<int32_t>::Read(strm, &(ctx->device_id))) return false;
return true;
}
};
} // namespace serializer
} // namespace dmlc
#endif // TVM_RUNTIME_SERIALIZER_H_
......@@ -7,6 +7,7 @@
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/serializer.h>
#include "./graph_runtime.h"
namespace nnvm {
......@@ -38,46 +39,53 @@ NNVM_REGISTER_OP(tvm_op)
bool SaveDLTensor(dmlc::Stream* strm, DLTensor* tensor) {
uint64_t header = kTVMNDArrayMagic, reserved = 0;
strm->Write(&header, sizeof(header));
strm->Write(&reserved, sizeof(reserved));
strm->Write(&tensor->ctx, sizeof(tensor->ctx));
strm->Write(&tensor->ndim, sizeof(tensor->ndim));
strm->Write(&tensor->dtype, sizeof(tensor->dtype));
strm->Write(header);
strm->Write(reserved);
strm->Write(tensor->ctx);
strm->Write(tensor->ndim);
strm->Write(tensor->dtype);
int ndim = tensor->ndim;
strm->Write(tensor->shape, sizeof(int64_t) * ndim);
strm->WriteArray(tensor->shape, ndim);
int type_size = tensor->dtype.bits / 8;
int64_t size = 1;
int type_bytes = tensor->dtype.bits / 8;
int64_t num_elems = 1;
for (int i = 0; i < ndim; ++i) {
size *= tensor->shape[i];
num_elems *= tensor->shape[i];
}
int64_t data_byte_size = type_size * size;
strm->Write(&data_byte_size, sizeof(data_byte_size));
int64_t data_byte_size = type_bytes * num_elems;
strm->Write(data_byte_size);
// handle endianness of data correctly.
if (DMLC_IO_NO_ENDIAN_SWAP) {
strm->Write(tensor->data, data_byte_size);
} else {
uint8_t* dptr = reinterpret_cast<uint8_t*>(tensor->data);
std::vector<uint8_t> bytes(dptr, dptr + data_byte_size);
dmlc::ByteSwap(dmlc::BeginPtr(bytes), type_bytes, num_elems);
strm->Write(dmlc::BeginPtr(bytes), data_byte_size);
}
return true;
}
DLTensor* LoadDLTensor(dmlc::Stream* strm) {
uint64_t header, reserved;
CHECK(strm->Read(&header, sizeof(header)))
CHECK(strm->Read(&header))
<< "Invalid DLTensor file format";
CHECK(strm->Read(&reserved, sizeof(reserved)))
CHECK(strm->Read(&reserved))
<< "Invalid DLTensor file format";
CHECK(header == kTVMNDArrayMagic)
<< "Invalid DLTensor file format";
DLTensor tensor;
CHECK(strm->Read(&tensor.ctx, sizeof(tensor.ctx)))
CHECK(strm->Read(&(tensor.ctx)))
<< "Invalid DLTensor file format";
CHECK(strm->Read(&tensor.ndim, sizeof(tensor.ndim)))
CHECK(strm->Read(&(tensor.ndim)))
<< "Invalid DLTensor file format";
CHECK(strm->Read(&tensor.dtype, sizeof(tensor.dtype)))
CHECK(strm->Read(&(tensor.dtype)))
<< "Invalid DLTensor file format";
std::vector<int64_t> shape(tensor.ndim);
CHECK(strm->Read(&shape[0], sizeof(int64_t) * tensor.ndim))
if (tensor.ndim != 0) {
CHECK(strm->ReadArray(&shape[0], tensor.ndim))
<< "Invalid DLTensor file format";
}
DLTensor* ret;
CHECK_EQ(TVMArrayAlloc(shape.data(),
tensor.ndim,
......@@ -87,18 +95,21 @@ DLTensor* LoadDLTensor(dmlc::Stream* strm) {
static_cast<int>(tensor.ctx.device_type),
tensor.ctx.device_id,
&ret), 0) << TVMGetLastError();
int64_t size = 1;
int type_size = ret->dtype.bits / 8;
int64_t num_elems = 1;
int elem_bytes = (ret->dtype.bits + 7) / 8;
for (int i = 0; i < ret->ndim; ++i) {
size *= ret->shape[i];
num_elems *= ret->shape[i];
}
int64_t data_byte_size;
CHECK(strm->Read(&data_byte_size, sizeof(data_byte_size)))
CHECK(strm->Read(&data_byte_size))
<< "Invalid DLTensor file format";
CHECK(data_byte_size == type_size * size)
CHECK(data_byte_size == num_elems * elem_bytes)
<< "Invalid DLTensor file format";
CHECK(strm->Read(ret->data, type_size * size))
CHECK(strm->Read(ret->data, data_byte_size))
<< "Invalid DLTensor file format";
if (!DMLC_IO_NO_ENDIAN_SWAP) {
dmlc::ByteSwap(ret->data, elem_bytes, num_elems);
}
return ret;
}
......@@ -118,12 +129,12 @@ TVM_REGISTER_GLOBAL("nnvm.compiler._save_param_dict")
dmlc::MemoryStringStream strm(&bytes);
dmlc::Stream* fo = &strm;
uint64_t header = kTVMNDArrayListMagic, reserved = 0;
fo->Write(&header, sizeof(header));
fo->Write(&reserved, sizeof(reserved));
fo->Write(header);
fo->Write(reserved);
fo->Write(names);
{
uint64_t sz = static_cast<uint64_t>(arrays.size());
fo->Write(&sz, sizeof(sz));
fo->Write(sz);
for (size_t i = 0; i < sz; ++i) {
SaveDLTensor(fo, arrays[i]);
}
......@@ -150,7 +161,6 @@ TVM_REGISTER_GLOBAL("nnvm.compiler._load_param_dict")
<< "Invalid parameters file format";
CHECK(strm->Read(&reserved))
<< "Invalid parameters file format";
CHECK(strm->Read(&names))
<< "Invalid parameters file format";
uint64_t sz;
......
......@@ -73,7 +73,7 @@ def sendjson(sock, data):
Python value to be sent.
"""
data = json.dumps(data)
sock.sendall(struct.pack("@i", len(data)))
sock.sendall(struct.pack("<i", len(data)))
sock.sendall(data.encode("utf-8"))
......@@ -90,7 +90,7 @@ def recvjson(sock):
value : object
The value received.
"""
size = struct.unpack("@i", recvall(sock, 4))[0]
size = struct.unpack("<i", recvall(sock, 4))[0]
data = json.loads(py_str(recvall(sock, size)))
return data
......
......@@ -192,8 +192,8 @@ class TrackerSession(object):
def _connect(self):
self._sock = base.connect_with_retry(self._addr)
self._sock.sendall(struct.pack("@i", base.RPC_TRACKER_MAGIC))
magic = struct.unpack("@i", base.recvall(self._sock, 4))[0]
self._sock.sendall(struct.pack("<i", base.RPC_TRACKER_MAGIC))
magic = struct.unpack("<i", base.recvall(self._sock, 4))[0]
if magic != base.RPC_TRACKER_MAGIC:
raise RuntimeError("%s is not RPC Tracker" % str(self._addr))
......
......@@ -58,14 +58,14 @@ class ForwardHandler(object):
def _init_step(self, message):
if self._magic is None:
assert len(message) == 4
self._magic = struct.unpack('@i', message)[0]
self._magic = struct.unpack('<i', message)[0]
if self._magic != base.RPC_MAGIC:
logging.info("Invalid RPC magic from %s", self.name())
self.close()
self._init_req_nbytes = 4
elif self._rpc_key_length is None:
assert len(message) == 4
self._rpc_key_length = struct.unpack('@i', message)[0]
self._rpc_key_length = struct.unpack('<i', message)[0]
self._init_req_nbytes = self._rpc_key_length
elif self.rpc_key is None:
assert len(message) == self._rpc_key_length
......@@ -269,12 +269,12 @@ class ProxyServerHandler(object):
lhs.forward_proxy = rhs
rhs.forward_proxy = lhs
lhs.send_data(struct.pack('@i', base.RPC_CODE_SUCCESS))
lhs.send_data(struct.pack('@i', len(rhs.rpc_key)))
lhs.send_data(struct.pack('<i', base.RPC_CODE_SUCCESS))
lhs.send_data(struct.pack('<i', len(rhs.rpc_key)))
lhs.send_data(rhs.rpc_key.encode("utf-8"))
rhs.send_data(struct.pack('@i', base.RPC_CODE_SUCCESS))
rhs.send_data(struct.pack('@i', len(lhs.rpc_key)))
rhs.send_data(struct.pack('<i', base.RPC_CODE_SUCCESS))
rhs.send_data(struct.pack('<i', len(lhs.rpc_key)))
rhs.send_data(lhs.rpc_key.encode("utf-8"))
logging.info("Pairup connect %s and %s", lhs.name(), rhs.name())
......@@ -299,8 +299,8 @@ class ProxyServerHandler(object):
if self._tracker_conn is None:
self._tracker_conn = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self._tracker_conn.connect(self._tracker_addr)
self._tracker_conn.sendall(struct.pack("@i", base.RPC_TRACKER_MAGIC))
magic = struct.unpack("@i", base.recvall(self._tracker_conn, 4))[0]
self._tracker_conn.sendall(struct.pack("<i", base.RPC_TRACKER_MAGIC))
magic = struct.unpack("<i", base.recvall(self._tracker_conn, 4))[0]
if magic != base.RPC_TRACKER_MAGIC:
self.loop.stop()
raise RuntimeError("%s is not RPC Tracker" % str(self._tracker_addr))
......@@ -371,7 +371,7 @@ class ProxyServerHandler(object):
if handler.match_key in self._server_pool:
self._pair_up(self._server_pool.pop(handler.match_key), handler)
else:
handler.send_data(struct.pack('@i', base.RPC_CODE_MISMATCH))
handler.send_data(struct.pack('<i', base.RPC_CODE_MISMATCH))
handler.signal_close()
def _handler_ready_proxy_mode(self, handler):
......@@ -395,12 +395,12 @@ class ProxyServerHandler(object):
logging.info("Timeout client connection %s, cannot find match key=%s",
handler.name(), key)
pool_dst.pop(key)
handler.send_data(struct.pack('@i', base.RPC_CODE_MISMATCH))
handler.send_data(struct.pack('<i', base.RPC_CODE_MISMATCH))
handler.signal_close()
self.loop.call_later(timeout, cleanup)
else:
logging.info("Duplicate connection with same key=%s", key)
handler.send_data(struct.pack('@i', base.RPC_CODE_DUPLICATE))
handler.send_data(struct.pack('<i', base.RPC_CODE_DUPLICATE))
handler.signal_close()
def handler_ready(self, handler):
......@@ -538,13 +538,13 @@ def websocket_proxy_server(url, key=""):
on_message = create_on_message(conn)
temp = _server_env(None)
# Start connecton
conn.write_message(struct.pack('@i', base.RPC_MAGIC), binary=True)
conn.write_message(struct.pack('<i', base.RPC_MAGIC), binary=True)
key = "server:" + key
conn.write_message(struct.pack('@i', len(key)), binary=True)
conn.write_message(struct.pack('<i', len(key)), binary=True)
conn.write_message(key.encode("utf-8"), binary=True)
msg = yield conn.read_message()
assert len(msg) >= 4
magic = struct.unpack('@i', msg[:4])[0]
magic = struct.unpack('<i', msg[:4])[0]
if magic == base.RPC_CODE_DUPLICATE:
raise RuntimeError("key: %s has already been used in proxy" % key)
elif magic == base.RPC_CODE_MISMATCH:
......
......@@ -124,23 +124,23 @@ def _listen_loop(sock, port, rpc_key, tracker_addr, load_library):
unmatch_period_count = 0
continue
conn, addr = listen_sock.accept()
magic = struct.unpack("@i", base.recvall(conn, 4))[0]
magic = struct.unpack("<i", base.recvall(conn, 4))[0]
if magic != base.RPC_MAGIC:
conn.close()
continue
keylen = struct.unpack("@i", base.recvall(conn, 4))[0]
keylen = struct.unpack("<i", base.recvall(conn, 4))[0]
key = py_str(base.recvall(conn, keylen))
arr = key.split()
expect_header = "client:" + matchkey
server_key = "server:" + rpc_key
if arr[0] != expect_header:
conn.sendall(struct.pack("@i", base.RPC_CODE_MISMATCH))
conn.sendall(struct.pack("<i", base.RPC_CODE_MISMATCH))
conn.close()
logging.info("RPCServer: mismatch key from %s", addr)
continue
else:
conn.sendall(struct.pack("@i", base.RPC_CODE_SUCCESS))
conn.sendall(struct.pack("@i", len(server_key)))
conn.sendall(struct.pack("<i", base.RPC_CODE_SUCCESS))
conn.sendall(struct.pack("<i", len(server_key)))
conn.sendall(server_key.encode("utf-8"))
return conn, addr, _parse_server_opt(arr[1:])
......@@ -151,8 +151,8 @@ def _listen_loop(sock, port, rpc_key, tracker_addr, load_library):
# step 1: setup tracker and report to tracker
if tracker_addr and tracker_conn is None:
tracker_conn = base.connect_with_retry(tracker_addr)
tracker_conn.sendall(struct.pack("@i", base.RPC_TRACKER_MAGIC))
magic = struct.unpack("@i", base.recvall(tracker_conn, 4))[0]
tracker_conn.sendall(struct.pack("<i", base.RPC_TRACKER_MAGIC))
magic = struct.unpack("<i", base.recvall(tracker_conn, 4))[0]
if magic != base.RPC_TRACKER_MAGIC:
raise RuntimeError("%s is not RPC Tracker" % str(tracker_addr))
# report status of current queue
......@@ -193,17 +193,17 @@ def _connect_proxy_loop(addr, key, load_library):
try:
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.connect(addr)
sock.sendall(struct.pack("@i", base.RPC_MAGIC))
sock.sendall(struct.pack("@i", len(key)))
sock.sendall(struct.pack("<i", base.RPC_MAGIC))
sock.sendall(struct.pack("<i", len(key)))
sock.sendall(key.encode("utf-8"))
magic = struct.unpack("@i", base.recvall(sock, 4))[0]
magic = struct.unpack("<i", base.recvall(sock, 4))[0]
if magic == base.RPC_CODE_DUPLICATE:
raise RuntimeError("key: %s has already been used in proxy" % key)
elif magic == base.RPC_CODE_MISMATCH:
logging.info("RPCProxy do not have matching client key %s", key)
elif magic != base.RPC_CODE_SUCCESS:
raise RuntimeError("%s is not RPC Proxy" % str(addr))
keylen = struct.unpack("@i", base.recvall(sock, 4))[0]
keylen = struct.unpack("<i", base.recvall(sock, 4))[0]
remote_key = py_str(base.recvall(sock, keylen))
opts = _parse_server_opt(remote_key.split()[1:])
logging.info("RPCProxy connected to %s", str(addr))
......
......@@ -143,11 +143,11 @@ class TCPEventHandler(tornado_util.TCPHandler):
if len(message) != 4:
logging.info("Invalid connection from %s", self.name())
self.close()
magic = struct.unpack('@i', message)[0]
magic = struct.unpack('<i', message)[0]
if magic != RPC_TRACKER_MAGIC:
logging.info("Invalid magic from %s", self.name())
self.close()
self.write_message(struct.pack('@i', RPC_TRACKER_MAGIC), binary=True)
self.write_message(struct.pack('<i', RPC_TRACKER_MAGIC), binary=True)
self._init_req_nbytes = 0
def on_message(self, message):
......@@ -168,7 +168,7 @@ class TCPEventHandler(tornado_util.TCPHandler):
while True:
if self._msg_size == 0:
if len(self._data) >= 4:
self._msg_size = struct.unpack('@i', self._data[:4])[0]
self._msg_size = struct.unpack('<i', self._data[:4])[0]
else:
return
if self._msg_size != 0 and len(self._data) >= self._msg_size + 4:
......@@ -184,7 +184,7 @@ class TCPEventHandler(tornado_util.TCPHandler):
"""return value to the output"""
data = json.dumps(data)
self.write_message(
struct.pack('@i', len(data)), binary=True)
struct.pack('<i', len(data)), binary=True)
self.write_message(data.encode("utf-8"), binary=True)
def call_handler(self, args):
......@@ -355,8 +355,8 @@ class Tracker(object):
def _stop_tracker(self):
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.connect((self.host, self.port))
sock.sendall(struct.pack("@i", base.RPC_TRACKER_MAGIC))
magic = struct.unpack("@i", base.recvall(sock, 4))[0]
sock.sendall(struct.pack("<i", base.RPC_TRACKER_MAGIC))
magic = struct.unpack("<i", base.recvall(sock, 4))[0]
assert magic == base.RPC_TRACKER_MAGIC
base.sendjson(sock, [TrackerCode.STOP, self.stop_key])
assert base.recvjson(sock) == TrackerCode.SUCCESS
......
......@@ -4,6 +4,7 @@
*/
#include <dmlc/json.h>
#include <dmlc/logging.h>
#include <tvm/runtime/serializer.h>
#include <fstream>
#include "./file_util.h"
......
......@@ -4,6 +4,7 @@
*/
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/serializer.h>
#include <dmlc/memory_io.h>
#include <dmlc/json.h>
#include <numeric>
......@@ -397,24 +398,25 @@ class GraphRuntime : public ModuleNode {
void GraphRuntime::LoadDLTensor(dmlc::Stream* strm, DLTensor* dst) {
// always use strm->Read to maintain endianness conversion
uint64_t header, reserved;
CHECK(strm->Read(&header, sizeof(header)))
CHECK(strm->Read(&header))
<< "Invalid DLTensor file format";
CHECK(strm->Read(&reserved, sizeof(reserved)))
CHECK(strm->Read(&reserved))
<< "Invalid DLTensor file format";
CHECK(header == kTVMNDArrayMagic)
<< "Invalid DLTensor file format";
DLTensor tensor;
CHECK(strm->Read(&tensor.ctx, sizeof(tensor.ctx)))
CHECK(strm->Read(&(tensor.ctx)))
<< "Invalid DLTensor file format";
CHECK(strm->Read(&tensor.ndim, sizeof(tensor.ndim)))
CHECK(strm->Read(&(tensor.ndim)))
<< "Invalid DLTensor file format";
CHECK(strm->Read(&tensor.dtype, sizeof(tensor.dtype)))
CHECK(strm->Read(&(tensor.dtype)))
<< "Invalid DLTensor file format";
std::vector<int64_t> shape(tensor.ndim);
if (tensor.ndim != 0) {
CHECK(strm->Read(&shape[0], sizeof(int64_t) * tensor.ndim))
CHECK(strm->ReadArray(&shape[0], tensor.ndim))
<< "Invalid DLTensor file format";
}
CHECK_EQ(tensor.ndim, dst->ndim) << "param dimension mismatch";
......@@ -425,18 +427,23 @@ void GraphRuntime::LoadDLTensor(dmlc::Stream* strm, DLTensor* dst) {
CHECK_EQ(shape[i], dst->shape[i]) << "param shape mismatch";
}
size_t bits = dst->dtype.bits * dst->dtype.lanes;
size_t size = (bits + 7) / 8;
size_t elem_bytes = (bits + 7) / 8;
size_t num_elems = 1;
for (int i = 0; i < dst->ndim; ++i) {
size *= dst->shape[i];
num_elems *= dst->shape[i];
}
uint64_t data_byte_size;
CHECK(strm->Read(&data_byte_size, sizeof(data_byte_size)))
CHECK(strm->Read(&data_byte_size))
<< "Invalid DLTensor file format";
CHECK(data_byte_size == size)
CHECK_EQ(data_byte_size, elem_bytes * num_elems)
<< "Invalid DLTensor file format";
std::vector<uint8_t> bytes(data_byte_size + 1);
CHECK(strm->Read(&bytes[0], data_byte_size))
<< "Invalid DLTensor file format";
// explicitly swap endian when necessary.
if (!DMLC_IO_NO_ENDIAN_SWAP) {
dmlc::ByteSwap(&bytes[0], elem_bytes, num_elems);
}
TVM_CCALL(TVMArrayCopyFromBytes(dst, &bytes[0], data_byte_size));
}
......@@ -453,9 +460,8 @@ void GraphRuntime::LoadParams(dmlc::Stream* strm) {
CHECK(strm->Read(&names))
<< "Invalid parameters file format";
uint64_t sz;
strm->Read(&sz, sizeof(sz));
strm->Read(&sz);
size_t size = static_cast<size_t>(sz);
CHECK(size == names.size())
<< "Invalid parameters file format";
for (size_t i = 0; i < size; ++i) {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment