Commit 42608dda by tqchen Committed by Tianqi Chen

[IO] Support cross-endian

parent 0f9dab98
Subproject commit d3f7fbb53e5b037c0f5bf6bd21871ccc720690cc Subproject commit 9b3f9753ae81d657743c555e0cacc4e43f0bed2d
/*!
* Copyright (c) 2017 by Contributors
* \file tvm/runtime/serializer.h
* \brief Serializer extension to support TVM data types
* Include this file to enable serialization of DLDataType, DLContext
*/
#ifndef TVM_RUNTIME_SERIALIZER_H_
#define TVM_RUNTIME_SERIALIZER_H_
#include <dmlc/io.h>
#include <dmlc/serializer.h>
#include "./c_runtime_api.h"
namespace dmlc {
namespace serializer {
template<>
struct Handler<DLDataType> {
inline static void Write(Stream *strm, const DLDataType& dtype) {
Handler<uint8_t>::Write(strm, dtype.code);
Handler<uint8_t>::Write(strm, dtype.bits);
Handler<uint16_t>::Write(strm, dtype.lanes);
}
inline static bool Read(Stream *strm, DLDataType* dtype) {
if (!Handler<uint8_t>::Read(strm, &(dtype->code))) return false;
if (!Handler<uint8_t>::Read(strm, &(dtype->bits))) return false;
if (!Handler<uint16_t>::Read(strm, &(dtype->lanes))) return false;
return true;
}
};
template<>
struct Handler<DLContext> {
inline static void Write(Stream *strm, const DLContext& ctx) {
int32_t device_type = static_cast<int32_t>(ctx.device_type);
Handler<int32_t>::Write(strm, device_type);
Handler<int32_t>::Write(strm, ctx.device_id);
}
inline static bool Read(Stream *strm, DLContext* ctx) {
int32_t device_type = 0;
if (!Handler<int32_t>::Read(strm, &(device_type))) return false;
ctx->device_type = static_cast<DLDeviceType>(device_type);
if (!Handler<int32_t>::Read(strm, &(ctx->device_id))) return false;
return true;
}
};
} // namespace serializer
} // namespace dmlc
#endif // TVM_RUNTIME_SERIALIZER_H_
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
#include <tvm/runtime/packed_func.h> #include <tvm/runtime/packed_func.h>
#include <tvm/runtime/c_runtime_api.h> #include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <tvm/runtime/serializer.h>
#include "./graph_runtime.h" #include "./graph_runtime.h"
namespace nnvm { namespace nnvm {
...@@ -38,46 +39,53 @@ NNVM_REGISTER_OP(tvm_op) ...@@ -38,46 +39,53 @@ NNVM_REGISTER_OP(tvm_op)
bool SaveDLTensor(dmlc::Stream* strm, DLTensor* tensor) { bool SaveDLTensor(dmlc::Stream* strm, DLTensor* tensor) {
uint64_t header = kTVMNDArrayMagic, reserved = 0; uint64_t header = kTVMNDArrayMagic, reserved = 0;
strm->Write(&header, sizeof(header)); strm->Write(header);
strm->Write(&reserved, sizeof(reserved)); strm->Write(reserved);
strm->Write(tensor->ctx);
strm->Write(&tensor->ctx, sizeof(tensor->ctx)); strm->Write(tensor->ndim);
strm->Write(&tensor->ndim, sizeof(tensor->ndim)); strm->Write(tensor->dtype);
strm->Write(&tensor->dtype, sizeof(tensor->dtype));
int ndim = tensor->ndim; int ndim = tensor->ndim;
strm->Write(tensor->shape, sizeof(int64_t) * ndim); strm->WriteArray(tensor->shape, ndim);
int type_size = tensor->dtype.bits / 8; int type_bytes = tensor->dtype.bits / 8;
int64_t size = 1; int64_t num_elems = 1;
for (int i = 0; i < ndim; ++i) { for (int i = 0; i < ndim; ++i) {
size *= tensor->shape[i]; num_elems *= tensor->shape[i];
}
int64_t data_byte_size = type_bytes * num_elems;
strm->Write(data_byte_size);
// handle endianness of data correctly.
if (DMLC_IO_NO_ENDIAN_SWAP) {
strm->Write(tensor->data, data_byte_size);
} else {
uint8_t* dptr = reinterpret_cast<uint8_t*>(tensor->data);
std::vector<uint8_t> bytes(dptr, dptr + data_byte_size);
dmlc::ByteSwap(dmlc::BeginPtr(bytes), type_bytes, num_elems);
strm->Write(dmlc::BeginPtr(bytes), data_byte_size);
} }
int64_t data_byte_size = type_size * size;
strm->Write(&data_byte_size, sizeof(data_byte_size));
strm->Write(tensor->data, data_byte_size);
return true; return true;
} }
DLTensor* LoadDLTensor(dmlc::Stream* strm) { DLTensor* LoadDLTensor(dmlc::Stream* strm) {
uint64_t header, reserved; uint64_t header, reserved;
CHECK(strm->Read(&header, sizeof(header))) CHECK(strm->Read(&header))
<< "Invalid DLTensor file format"; << "Invalid DLTensor file format";
CHECK(strm->Read(&reserved, sizeof(reserved))) CHECK(strm->Read(&reserved))
<< "Invalid DLTensor file format"; << "Invalid DLTensor file format";
CHECK(header == kTVMNDArrayMagic) CHECK(header == kTVMNDArrayMagic)
<< "Invalid DLTensor file format"; << "Invalid DLTensor file format";
DLTensor tensor; DLTensor tensor;
CHECK(strm->Read(&tensor.ctx, sizeof(tensor.ctx))) CHECK(strm->Read(&(tensor.ctx)))
<< "Invalid DLTensor file format"; << "Invalid DLTensor file format";
CHECK(strm->Read(&tensor.ndim, sizeof(tensor.ndim))) CHECK(strm->Read(&(tensor.ndim)))
<< "Invalid DLTensor file format"; << "Invalid DLTensor file format";
CHECK(strm->Read(&tensor.dtype, sizeof(tensor.dtype))) CHECK(strm->Read(&(tensor.dtype)))
<< "Invalid DLTensor file format"; << "Invalid DLTensor file format";
std::vector<int64_t> shape(tensor.ndim); std::vector<int64_t> shape(tensor.ndim);
CHECK(strm->Read(&shape[0], sizeof(int64_t) * tensor.ndim)) if (tensor.ndim != 0) {
<< "Invalid DLTensor file format"; CHECK(strm->ReadArray(&shape[0], tensor.ndim))
<< "Invalid DLTensor file format";
}
DLTensor* ret; DLTensor* ret;
CHECK_EQ(TVMArrayAlloc(shape.data(), CHECK_EQ(TVMArrayAlloc(shape.data(),
tensor.ndim, tensor.ndim,
...@@ -87,18 +95,21 @@ DLTensor* LoadDLTensor(dmlc::Stream* strm) { ...@@ -87,18 +95,21 @@ DLTensor* LoadDLTensor(dmlc::Stream* strm) {
static_cast<int>(tensor.ctx.device_type), static_cast<int>(tensor.ctx.device_type),
tensor.ctx.device_id, tensor.ctx.device_id,
&ret), 0) << TVMGetLastError(); &ret), 0) << TVMGetLastError();
int64_t size = 1; int64_t num_elems = 1;
int type_size = ret->dtype.bits / 8; int elem_bytes = (ret->dtype.bits + 7) / 8;
for (int i = 0; i < ret->ndim; ++i) { for (int i = 0; i < ret->ndim; ++i) {
size *= ret->shape[i]; num_elems *= ret->shape[i];
} }
int64_t data_byte_size; int64_t data_byte_size;
CHECK(strm->Read(&data_byte_size, sizeof(data_byte_size))) CHECK(strm->Read(&data_byte_size))
<< "Invalid DLTensor file format"; << "Invalid DLTensor file format";
CHECK(data_byte_size == type_size * size) CHECK(data_byte_size == num_elems * elem_bytes)
<< "Invalid DLTensor file format"; << "Invalid DLTensor file format";
CHECK(strm->Read(ret->data, type_size * size)) CHECK(strm->Read(ret->data, data_byte_size))
<< "Invalid DLTensor file format"; << "Invalid DLTensor file format";
if (!DMLC_IO_NO_ENDIAN_SWAP) {
dmlc::ByteSwap(ret->data, elem_bytes, num_elems);
}
return ret; return ret;
} }
...@@ -118,12 +129,12 @@ TVM_REGISTER_GLOBAL("nnvm.compiler._save_param_dict") ...@@ -118,12 +129,12 @@ TVM_REGISTER_GLOBAL("nnvm.compiler._save_param_dict")
dmlc::MemoryStringStream strm(&bytes); dmlc::MemoryStringStream strm(&bytes);
dmlc::Stream* fo = &strm; dmlc::Stream* fo = &strm;
uint64_t header = kTVMNDArrayListMagic, reserved = 0; uint64_t header = kTVMNDArrayListMagic, reserved = 0;
fo->Write(&header, sizeof(header)); fo->Write(header);
fo->Write(&reserved, sizeof(reserved)); fo->Write(reserved);
fo->Write(names); fo->Write(names);
{ {
uint64_t sz = static_cast<uint64_t>(arrays.size()); uint64_t sz = static_cast<uint64_t>(arrays.size());
fo->Write(&sz, sizeof(sz)); fo->Write(sz);
for (size_t i = 0; i < sz; ++i) { for (size_t i = 0; i < sz; ++i) {
SaveDLTensor(fo, arrays[i]); SaveDLTensor(fo, arrays[i]);
} }
...@@ -150,7 +161,6 @@ TVM_REGISTER_GLOBAL("nnvm.compiler._load_param_dict") ...@@ -150,7 +161,6 @@ TVM_REGISTER_GLOBAL("nnvm.compiler._load_param_dict")
<< "Invalid parameters file format"; << "Invalid parameters file format";
CHECK(strm->Read(&reserved)) CHECK(strm->Read(&reserved))
<< "Invalid parameters file format"; << "Invalid parameters file format";
CHECK(strm->Read(&names)) CHECK(strm->Read(&names))
<< "Invalid parameters file format"; << "Invalid parameters file format";
uint64_t sz; uint64_t sz;
......
...@@ -73,7 +73,7 @@ def sendjson(sock, data): ...@@ -73,7 +73,7 @@ def sendjson(sock, data):
Python value to be sent. Python value to be sent.
""" """
data = json.dumps(data) data = json.dumps(data)
sock.sendall(struct.pack("@i", len(data))) sock.sendall(struct.pack("<i", len(data)))
sock.sendall(data.encode("utf-8")) sock.sendall(data.encode("utf-8"))
...@@ -90,7 +90,7 @@ def recvjson(sock): ...@@ -90,7 +90,7 @@ def recvjson(sock):
value : object value : object
The value received. The value received.
""" """
size = struct.unpack("@i", recvall(sock, 4))[0] size = struct.unpack("<i", recvall(sock, 4))[0]
data = json.loads(py_str(recvall(sock, size))) data = json.loads(py_str(recvall(sock, size)))
return data return data
......
...@@ -192,8 +192,8 @@ class TrackerSession(object): ...@@ -192,8 +192,8 @@ class TrackerSession(object):
def _connect(self): def _connect(self):
self._sock = base.connect_with_retry(self._addr) self._sock = base.connect_with_retry(self._addr)
self._sock.sendall(struct.pack("@i", base.RPC_TRACKER_MAGIC)) self._sock.sendall(struct.pack("<i", base.RPC_TRACKER_MAGIC))
magic = struct.unpack("@i", base.recvall(self._sock, 4))[0] magic = struct.unpack("<i", base.recvall(self._sock, 4))[0]
if magic != base.RPC_TRACKER_MAGIC: if magic != base.RPC_TRACKER_MAGIC:
raise RuntimeError("%s is not RPC Tracker" % str(self._addr)) raise RuntimeError("%s is not RPC Tracker" % str(self._addr))
......
...@@ -58,14 +58,14 @@ class ForwardHandler(object): ...@@ -58,14 +58,14 @@ class ForwardHandler(object):
def _init_step(self, message): def _init_step(self, message):
if self._magic is None: if self._magic is None:
assert len(message) == 4 assert len(message) == 4
self._magic = struct.unpack('@i', message)[0] self._magic = struct.unpack('<i', message)[0]
if self._magic != base.RPC_MAGIC: if self._magic != base.RPC_MAGIC:
logging.info("Invalid RPC magic from %s", self.name()) logging.info("Invalid RPC magic from %s", self.name())
self.close() self.close()
self._init_req_nbytes = 4 self._init_req_nbytes = 4
elif self._rpc_key_length is None: elif self._rpc_key_length is None:
assert len(message) == 4 assert len(message) == 4
self._rpc_key_length = struct.unpack('@i', message)[0] self._rpc_key_length = struct.unpack('<i', message)[0]
self._init_req_nbytes = self._rpc_key_length self._init_req_nbytes = self._rpc_key_length
elif self.rpc_key is None: elif self.rpc_key is None:
assert len(message) == self._rpc_key_length assert len(message) == self._rpc_key_length
...@@ -269,12 +269,12 @@ class ProxyServerHandler(object): ...@@ -269,12 +269,12 @@ class ProxyServerHandler(object):
lhs.forward_proxy = rhs lhs.forward_proxy = rhs
rhs.forward_proxy = lhs rhs.forward_proxy = lhs
lhs.send_data(struct.pack('@i', base.RPC_CODE_SUCCESS)) lhs.send_data(struct.pack('<i', base.RPC_CODE_SUCCESS))
lhs.send_data(struct.pack('@i', len(rhs.rpc_key))) lhs.send_data(struct.pack('<i', len(rhs.rpc_key)))
lhs.send_data(rhs.rpc_key.encode("utf-8")) lhs.send_data(rhs.rpc_key.encode("utf-8"))
rhs.send_data(struct.pack('@i', base.RPC_CODE_SUCCESS)) rhs.send_data(struct.pack('<i', base.RPC_CODE_SUCCESS))
rhs.send_data(struct.pack('@i', len(lhs.rpc_key))) rhs.send_data(struct.pack('<i', len(lhs.rpc_key)))
rhs.send_data(lhs.rpc_key.encode("utf-8")) rhs.send_data(lhs.rpc_key.encode("utf-8"))
logging.info("Pairup connect %s and %s", lhs.name(), rhs.name()) logging.info("Pairup connect %s and %s", lhs.name(), rhs.name())
...@@ -299,8 +299,8 @@ class ProxyServerHandler(object): ...@@ -299,8 +299,8 @@ class ProxyServerHandler(object):
if self._tracker_conn is None: if self._tracker_conn is None:
self._tracker_conn = socket.socket(socket.AF_INET, socket.SOCK_STREAM) self._tracker_conn = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self._tracker_conn.connect(self._tracker_addr) self._tracker_conn.connect(self._tracker_addr)
self._tracker_conn.sendall(struct.pack("@i", base.RPC_TRACKER_MAGIC)) self._tracker_conn.sendall(struct.pack("<i", base.RPC_TRACKER_MAGIC))
magic = struct.unpack("@i", base.recvall(self._tracker_conn, 4))[0] magic = struct.unpack("<i", base.recvall(self._tracker_conn, 4))[0]
if magic != base.RPC_TRACKER_MAGIC: if magic != base.RPC_TRACKER_MAGIC:
self.loop.stop() self.loop.stop()
raise RuntimeError("%s is not RPC Tracker" % str(self._tracker_addr)) raise RuntimeError("%s is not RPC Tracker" % str(self._tracker_addr))
...@@ -371,7 +371,7 @@ class ProxyServerHandler(object): ...@@ -371,7 +371,7 @@ class ProxyServerHandler(object):
if handler.match_key in self._server_pool: if handler.match_key in self._server_pool:
self._pair_up(self._server_pool.pop(handler.match_key), handler) self._pair_up(self._server_pool.pop(handler.match_key), handler)
else: else:
handler.send_data(struct.pack('@i', base.RPC_CODE_MISMATCH)) handler.send_data(struct.pack('<i', base.RPC_CODE_MISMATCH))
handler.signal_close() handler.signal_close()
def _handler_ready_proxy_mode(self, handler): def _handler_ready_proxy_mode(self, handler):
...@@ -395,12 +395,12 @@ class ProxyServerHandler(object): ...@@ -395,12 +395,12 @@ class ProxyServerHandler(object):
logging.info("Timeout client connection %s, cannot find match key=%s", logging.info("Timeout client connection %s, cannot find match key=%s",
handler.name(), key) handler.name(), key)
pool_dst.pop(key) pool_dst.pop(key)
handler.send_data(struct.pack('@i', base.RPC_CODE_MISMATCH)) handler.send_data(struct.pack('<i', base.RPC_CODE_MISMATCH))
handler.signal_close() handler.signal_close()
self.loop.call_later(timeout, cleanup) self.loop.call_later(timeout, cleanup)
else: else:
logging.info("Duplicate connection with same key=%s", key) logging.info("Duplicate connection with same key=%s", key)
handler.send_data(struct.pack('@i', base.RPC_CODE_DUPLICATE)) handler.send_data(struct.pack('<i', base.RPC_CODE_DUPLICATE))
handler.signal_close() handler.signal_close()
def handler_ready(self, handler): def handler_ready(self, handler):
...@@ -538,13 +538,13 @@ def websocket_proxy_server(url, key=""): ...@@ -538,13 +538,13 @@ def websocket_proxy_server(url, key=""):
on_message = create_on_message(conn) on_message = create_on_message(conn)
temp = _server_env(None) temp = _server_env(None)
# Start connecton # Start connecton
conn.write_message(struct.pack('@i', base.RPC_MAGIC), binary=True) conn.write_message(struct.pack('<i', base.RPC_MAGIC), binary=True)
key = "server:" + key key = "server:" + key
conn.write_message(struct.pack('@i', len(key)), binary=True) conn.write_message(struct.pack('<i', len(key)), binary=True)
conn.write_message(key.encode("utf-8"), binary=True) conn.write_message(key.encode("utf-8"), binary=True)
msg = yield conn.read_message() msg = yield conn.read_message()
assert len(msg) >= 4 assert len(msg) >= 4
magic = struct.unpack('@i', msg[:4])[0] magic = struct.unpack('<i', msg[:4])[0]
if magic == base.RPC_CODE_DUPLICATE: if magic == base.RPC_CODE_DUPLICATE:
raise RuntimeError("key: %s has already been used in proxy" % key) raise RuntimeError("key: %s has already been used in proxy" % key)
elif magic == base.RPC_CODE_MISMATCH: elif magic == base.RPC_CODE_MISMATCH:
......
...@@ -124,23 +124,23 @@ def _listen_loop(sock, port, rpc_key, tracker_addr, load_library): ...@@ -124,23 +124,23 @@ def _listen_loop(sock, port, rpc_key, tracker_addr, load_library):
unmatch_period_count = 0 unmatch_period_count = 0
continue continue
conn, addr = listen_sock.accept() conn, addr = listen_sock.accept()
magic = struct.unpack("@i", base.recvall(conn, 4))[0] magic = struct.unpack("<i", base.recvall(conn, 4))[0]
if magic != base.RPC_MAGIC: if magic != base.RPC_MAGIC:
conn.close() conn.close()
continue continue
keylen = struct.unpack("@i", base.recvall(conn, 4))[0] keylen = struct.unpack("<i", base.recvall(conn, 4))[0]
key = py_str(base.recvall(conn, keylen)) key = py_str(base.recvall(conn, keylen))
arr = key.split() arr = key.split()
expect_header = "client:" + matchkey expect_header = "client:" + matchkey
server_key = "server:" + rpc_key server_key = "server:" + rpc_key
if arr[0] != expect_header: if arr[0] != expect_header:
conn.sendall(struct.pack("@i", base.RPC_CODE_MISMATCH)) conn.sendall(struct.pack("<i", base.RPC_CODE_MISMATCH))
conn.close() conn.close()
logging.info("RPCServer: mismatch key from %s", addr) logging.info("RPCServer: mismatch key from %s", addr)
continue continue
else: else:
conn.sendall(struct.pack("@i", base.RPC_CODE_SUCCESS)) conn.sendall(struct.pack("<i", base.RPC_CODE_SUCCESS))
conn.sendall(struct.pack("@i", len(server_key))) conn.sendall(struct.pack("<i", len(server_key)))
conn.sendall(server_key.encode("utf-8")) conn.sendall(server_key.encode("utf-8"))
return conn, addr, _parse_server_opt(arr[1:]) return conn, addr, _parse_server_opt(arr[1:])
...@@ -151,8 +151,8 @@ def _listen_loop(sock, port, rpc_key, tracker_addr, load_library): ...@@ -151,8 +151,8 @@ def _listen_loop(sock, port, rpc_key, tracker_addr, load_library):
# step 1: setup tracker and report to tracker # step 1: setup tracker and report to tracker
if tracker_addr and tracker_conn is None: if tracker_addr and tracker_conn is None:
tracker_conn = base.connect_with_retry(tracker_addr) tracker_conn = base.connect_with_retry(tracker_addr)
tracker_conn.sendall(struct.pack("@i", base.RPC_TRACKER_MAGIC)) tracker_conn.sendall(struct.pack("<i", base.RPC_TRACKER_MAGIC))
magic = struct.unpack("@i", base.recvall(tracker_conn, 4))[0] magic = struct.unpack("<i", base.recvall(tracker_conn, 4))[0]
if magic != base.RPC_TRACKER_MAGIC: if magic != base.RPC_TRACKER_MAGIC:
raise RuntimeError("%s is not RPC Tracker" % str(tracker_addr)) raise RuntimeError("%s is not RPC Tracker" % str(tracker_addr))
# report status of current queue # report status of current queue
...@@ -193,17 +193,17 @@ def _connect_proxy_loop(addr, key, load_library): ...@@ -193,17 +193,17 @@ def _connect_proxy_loop(addr, key, load_library):
try: try:
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.connect(addr) sock.connect(addr)
sock.sendall(struct.pack("@i", base.RPC_MAGIC)) sock.sendall(struct.pack("<i", base.RPC_MAGIC))
sock.sendall(struct.pack("@i", len(key))) sock.sendall(struct.pack("<i", len(key)))
sock.sendall(key.encode("utf-8")) sock.sendall(key.encode("utf-8"))
magic = struct.unpack("@i", base.recvall(sock, 4))[0] magic = struct.unpack("<i", base.recvall(sock, 4))[0]
if magic == base.RPC_CODE_DUPLICATE: if magic == base.RPC_CODE_DUPLICATE:
raise RuntimeError("key: %s has already been used in proxy" % key) raise RuntimeError("key: %s has already been used in proxy" % key)
elif magic == base.RPC_CODE_MISMATCH: elif magic == base.RPC_CODE_MISMATCH:
logging.info("RPCProxy do not have matching client key %s", key) logging.info("RPCProxy do not have matching client key %s", key)
elif magic != base.RPC_CODE_SUCCESS: elif magic != base.RPC_CODE_SUCCESS:
raise RuntimeError("%s is not RPC Proxy" % str(addr)) raise RuntimeError("%s is not RPC Proxy" % str(addr))
keylen = struct.unpack("@i", base.recvall(sock, 4))[0] keylen = struct.unpack("<i", base.recvall(sock, 4))[0]
remote_key = py_str(base.recvall(sock, keylen)) remote_key = py_str(base.recvall(sock, keylen))
opts = _parse_server_opt(remote_key.split()[1:]) opts = _parse_server_opt(remote_key.split()[1:])
logging.info("RPCProxy connected to %s", str(addr)) logging.info("RPCProxy connected to %s", str(addr))
......
...@@ -143,11 +143,11 @@ class TCPEventHandler(tornado_util.TCPHandler): ...@@ -143,11 +143,11 @@ class TCPEventHandler(tornado_util.TCPHandler):
if len(message) != 4: if len(message) != 4:
logging.info("Invalid connection from %s", self.name()) logging.info("Invalid connection from %s", self.name())
self.close() self.close()
magic = struct.unpack('@i', message)[0] magic = struct.unpack('<i', message)[0]
if magic != RPC_TRACKER_MAGIC: if magic != RPC_TRACKER_MAGIC:
logging.info("Invalid magic from %s", self.name()) logging.info("Invalid magic from %s", self.name())
self.close() self.close()
self.write_message(struct.pack('@i', RPC_TRACKER_MAGIC), binary=True) self.write_message(struct.pack('<i', RPC_TRACKER_MAGIC), binary=True)
self._init_req_nbytes = 0 self._init_req_nbytes = 0
def on_message(self, message): def on_message(self, message):
...@@ -168,7 +168,7 @@ class TCPEventHandler(tornado_util.TCPHandler): ...@@ -168,7 +168,7 @@ class TCPEventHandler(tornado_util.TCPHandler):
while True: while True:
if self._msg_size == 0: if self._msg_size == 0:
if len(self._data) >= 4: if len(self._data) >= 4:
self._msg_size = struct.unpack('@i', self._data[:4])[0] self._msg_size = struct.unpack('<i', self._data[:4])[0]
else: else:
return return
if self._msg_size != 0 and len(self._data) >= self._msg_size + 4: if self._msg_size != 0 and len(self._data) >= self._msg_size + 4:
...@@ -184,7 +184,7 @@ class TCPEventHandler(tornado_util.TCPHandler): ...@@ -184,7 +184,7 @@ class TCPEventHandler(tornado_util.TCPHandler):
"""return value to the output""" """return value to the output"""
data = json.dumps(data) data = json.dumps(data)
self.write_message( self.write_message(
struct.pack('@i', len(data)), binary=True) struct.pack('<i', len(data)), binary=True)
self.write_message(data.encode("utf-8"), binary=True) self.write_message(data.encode("utf-8"), binary=True)
def call_handler(self, args): def call_handler(self, args):
...@@ -355,8 +355,8 @@ class Tracker(object): ...@@ -355,8 +355,8 @@ class Tracker(object):
def _stop_tracker(self): def _stop_tracker(self):
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.connect((self.host, self.port)) sock.connect((self.host, self.port))
sock.sendall(struct.pack("@i", base.RPC_TRACKER_MAGIC)) sock.sendall(struct.pack("<i", base.RPC_TRACKER_MAGIC))
magic = struct.unpack("@i", base.recvall(sock, 4))[0] magic = struct.unpack("<i", base.recvall(sock, 4))[0]
assert magic == base.RPC_TRACKER_MAGIC assert magic == base.RPC_TRACKER_MAGIC
base.sendjson(sock, [TrackerCode.STOP, self.stop_key]) base.sendjson(sock, [TrackerCode.STOP, self.stop_key])
assert base.recvjson(sock) == TrackerCode.SUCCESS assert base.recvjson(sock) == TrackerCode.SUCCESS
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
*/ */
#include <dmlc/json.h> #include <dmlc/json.h>
#include <dmlc/logging.h> #include <dmlc/logging.h>
#include <tvm/runtime/serializer.h>
#include <fstream> #include <fstream>
#include "./file_util.h" #include "./file_util.h"
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
*/ */
#include <tvm/runtime/packed_func.h> #include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <tvm/runtime/serializer.h>
#include <dmlc/memory_io.h> #include <dmlc/memory_io.h>
#include <dmlc/json.h> #include <dmlc/json.h>
#include <numeric> #include <numeric>
...@@ -397,24 +398,25 @@ class GraphRuntime : public ModuleNode { ...@@ -397,24 +398,25 @@ class GraphRuntime : public ModuleNode {
void GraphRuntime::LoadDLTensor(dmlc::Stream* strm, DLTensor* dst) { void GraphRuntime::LoadDLTensor(dmlc::Stream* strm, DLTensor* dst) {
// always use strm->Read to maintain endianness conversion
uint64_t header, reserved; uint64_t header, reserved;
CHECK(strm->Read(&header, sizeof(header))) CHECK(strm->Read(&header))
<< "Invalid DLTensor file format"; << "Invalid DLTensor file format";
CHECK(strm->Read(&reserved, sizeof(reserved))) CHECK(strm->Read(&reserved))
<< "Invalid DLTensor file format"; << "Invalid DLTensor file format";
CHECK(header == kTVMNDArrayMagic) CHECK(header == kTVMNDArrayMagic)
<< "Invalid DLTensor file format"; << "Invalid DLTensor file format";
DLTensor tensor; DLTensor tensor;
CHECK(strm->Read(&tensor.ctx, sizeof(tensor.ctx))) CHECK(strm->Read(&(tensor.ctx)))
<< "Invalid DLTensor file format"; << "Invalid DLTensor file format";
CHECK(strm->Read(&tensor.ndim, sizeof(tensor.ndim))) CHECK(strm->Read(&(tensor.ndim)))
<< "Invalid DLTensor file format"; << "Invalid DLTensor file format";
CHECK(strm->Read(&tensor.dtype, sizeof(tensor.dtype))) CHECK(strm->Read(&(tensor.dtype)))
<< "Invalid DLTensor file format"; << "Invalid DLTensor file format";
std::vector<int64_t> shape(tensor.ndim); std::vector<int64_t> shape(tensor.ndim);
if (tensor.ndim != 0) { if (tensor.ndim != 0) {
CHECK(strm->Read(&shape[0], sizeof(int64_t) * tensor.ndim)) CHECK(strm->ReadArray(&shape[0], tensor.ndim))
<< "Invalid DLTensor file format"; << "Invalid DLTensor file format";
} }
CHECK_EQ(tensor.ndim, dst->ndim) << "param dimension mismatch"; CHECK_EQ(tensor.ndim, dst->ndim) << "param dimension mismatch";
...@@ -425,18 +427,23 @@ void GraphRuntime::LoadDLTensor(dmlc::Stream* strm, DLTensor* dst) { ...@@ -425,18 +427,23 @@ void GraphRuntime::LoadDLTensor(dmlc::Stream* strm, DLTensor* dst) {
CHECK_EQ(shape[i], dst->shape[i]) << "param shape mismatch"; CHECK_EQ(shape[i], dst->shape[i]) << "param shape mismatch";
} }
size_t bits = dst->dtype.bits * dst->dtype.lanes; size_t bits = dst->dtype.bits * dst->dtype.lanes;
size_t size = (bits + 7) / 8; size_t elem_bytes = (bits + 7) / 8;
size_t num_elems = 1;
for (int i = 0; i < dst->ndim; ++i) { for (int i = 0; i < dst->ndim; ++i) {
size *= dst->shape[i]; num_elems *= dst->shape[i];
} }
uint64_t data_byte_size; uint64_t data_byte_size;
CHECK(strm->Read(&data_byte_size, sizeof(data_byte_size))) CHECK(strm->Read(&data_byte_size))
<< "Invalid DLTensor file format"; << "Invalid DLTensor file format";
CHECK(data_byte_size == size) CHECK_EQ(data_byte_size, elem_bytes * num_elems)
<< "Invalid DLTensor file format"; << "Invalid DLTensor file format";
std::vector<uint8_t> bytes(data_byte_size + 1); std::vector<uint8_t> bytes(data_byte_size + 1);
CHECK(strm->Read(&bytes[0], data_byte_size)) CHECK(strm->Read(&bytes[0], data_byte_size))
<< "Invalid DLTensor file format"; << "Invalid DLTensor file format";
// explicitly swap endian when necessary.
if (!DMLC_IO_NO_ENDIAN_SWAP) {
dmlc::ByteSwap(&bytes[0], elem_bytes, num_elems);
}
TVM_CCALL(TVMArrayCopyFromBytes(dst, &bytes[0], data_byte_size)); TVM_CCALL(TVMArrayCopyFromBytes(dst, &bytes[0], data_byte_size));
} }
...@@ -453,9 +460,8 @@ void GraphRuntime::LoadParams(dmlc::Stream* strm) { ...@@ -453,9 +460,8 @@ void GraphRuntime::LoadParams(dmlc::Stream* strm) {
CHECK(strm->Read(&names)) CHECK(strm->Read(&names))
<< "Invalid parameters file format"; << "Invalid parameters file format";
uint64_t sz; uint64_t sz;
strm->Read(&sz, sizeof(sz)); strm->Read(&sz);
size_t size = static_cast<size_t>(sz); size_t size = static_cast<size_t>(sz);
CHECK(size == names.size()) CHECK(size == names.size())
<< "Invalid parameters file format"; << "Invalid parameters file format";
for (size_t i = 0; i < size; ++i) { for (size_t i = 0; i < size; ++i) {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment