Commit 6bd8dbc7 by Tianqi Chen Committed by GitHub

[RPC] Refactor, introduce tracker (#1080)

* [RPC] Refactor, introduce tracker

* [RPC] Change RPC hand shake convention, always get remote key.

* fix lint
parent c24e82e3
......@@ -280,7 +280,7 @@ lib/libtvm_runtime.${SHARED_LIBRARY_SUFFIX}: $(RUNTIME_DEP)
lib/libtvm_web_runtime.bc: web/web_runtime.cc
@mkdir -p build/web
@mkdir -p $(@D)
$(CXX) $(CFLAGS) -MM -MT lib/libtvm_web_runtime.bc $< >build/web/web_runtime.d
emcc $(EMCC_FLAGS) -MM -MT lib/libtvm_web_runtime.bc $< >build/web/web_runtime.d
emcc $(EMCC_FLAGS) -o $@ web/web_runtime.cc
lib/libtvm_web_runtime.js: lib/libtvm_web_runtime.bc
......
......@@ -31,9 +31,12 @@ using FEventHandler = std::function<int(const std::string& in_bytes, int event_f
*
* \param outputStream The output stream used to send outputs.
* \param name The name of the server.
* \param remote_key The remote key
* \return The event handler.
*/
FEventHandler CreateServerEventHandler(NSOutputStream *outputStream, std::string name);
FEventHandler CreateServerEventHandler(NSOutputStream *outputStream,
std::string name,
std::string remote_key);
} // namespace runtime
} // namespace tvm
......
......@@ -8,6 +8,7 @@
#include "../../src/runtime/cpu_device_api.cc"
#include "../../src/runtime/workspace_pool.cc"
#include "../../src/runtime/thread_pool.cc"
#include "../../src/runtime/threading_backend.cc"
#include "../../src/runtime/module_util.cc"
#include "../../src/runtime/system_lib_module.cc"
#include "../../src/runtime/module.cc"
......@@ -59,9 +60,10 @@ class NSStreamChannel final : public RPCChannel {
NSOutputStream* stream_;
};
FEventHandler CreateServerEventHandler(NSOutputStream *outputStream, std::string name) {
FEventHandler CreateServerEventHandler(
NSOutputStream *outputStream, std::string name, std::string remote_key) {
std::unique_ptr<NSStreamChannel> ch(new NSStreamChannel(outputStream));
std::shared_ptr<RPCSession> sess = RPCSession::Create(std::move(ch), name);
std::shared_ptr<RPCSession> sess = RPCSession::Create(std::move(ch), name, remote_key);
return [sess](const std::string& in_bytes, int flag) {
return sess->ServerEventHandler(in_bytes, flag);
};
......
......@@ -143,7 +143,7 @@
[outputStream_ scheduleInRunLoop:[NSRunLoop currentRunLoop] forMode:NSDefaultRunLoopMode];
[outputStream_ open];
[inputStream_ open];
handler_ = tvm::runtime::CreateServerEventHandler(outputStream_, key_);
handler_ = tvm::runtime::CreateServerEventHandler(outputStream_, key_, "%toinit");
CHECK(handler_ != nullptr);
self.infoText.text = @"";
self.statusLabel.text = @"Connecting...";
......@@ -169,7 +169,6 @@
}
- (IBAction)disconnect:(id)sender {
[self close];
}
......
......@@ -74,6 +74,9 @@ public class ConnectProxyServerProcessor implements ServerProcessor {
} else if (magic != RPC.RPC_MAGIC) {
throw new RuntimeException(address + " is not RPC Proxy");
}
// Get key from remote
int keylen = Utils.wrapBytes(Utils.recvAll(in, 4)).getInt();
String remoteKey = Utils.decodeToStr(Utils.recvAll(in, keylen));
System.err.println("RPCProxy connected to " + address);
final int sockFd = socketFileDescriptorGetter.get(currSocket);
......
......@@ -60,7 +60,12 @@ public class StandaloneServerProcessor implements ServerProcessor {
out.write(Utils.toBytes(RPC.RPC_MAGIC + 2));
} else {
out.write(Utils.toBytes(RPC.RPC_MAGIC));
// send server key to the client
String serverKey = "server:java";
out.write(Utils.toBytes(serverKey.length()));
out.write(Utils.toBytes(serverKey));
}
System.err.println("Connection from " + socket.getRemoteSocketAddress().toString());
final int sockFd = socketFileDescriptorGetter.get(socket);
if (sockFd != -1) {
......
import time
from tvm.contrib import rpc_proxy
from tvm.contrib.rpc import proxy
def start_proxy_server(port, timeout):
prox = rpc_proxy.Proxy("localhost", port=port, port_end=port+1)
prox = proxy.Proxy("localhost", port=port, port_end=port+1)
if timeout > 0:
import time
time.sleep(timeout)
......@@ -17,4 +17,3 @@ if __name__ == "__main__":
port = int(sys.argv[1])
timeout = 0 if len(sys.argv) == 2 else float(sys.argv[2])
start_proxy_server(port, timeout)
"""Minimum graph runtime that executes graph containing TVM PackedFunc."""
from . import rpc
from .._ffi.base import string_types
from .._ffi.function import get_global_func
from .rpc import base as rpc_base
from .. import ndarray as nd
......@@ -33,12 +33,12 @@ def create(graph_json_str, libmod, ctx):
raise ValueError("Type %s is not supported" % type(graph_json_str))
device_type = ctx.device_type
device_id = ctx.device_id
if device_type >= rpc.RPC_SESS_MASK:
if device_type >= rpc_base.RPC_SESS_MASK:
assert libmod.type_key == "rpc"
assert rpc._SessTableIndex(libmod) == ctx._rpc_sess._tbl_index
hmod = rpc._ModuleHandle(libmod)
assert rpc_base._SessTableIndex(libmod) == ctx._rpc_sess._tbl_index
hmod = rpc_base._ModuleHandle(libmod)
fcreate = ctx._rpc_sess.get_function("tvm.graph_runtime.remote_create")
device_type = device_type % rpc.RPC_SESS_MASK
device_type = device_type % rpc_base.RPC_SESS_MASK
return GraphModule(fcreate(graph_json_str, hmod, device_type, device_id), ctx)
fcreate = get_global_func("tvm.graph_runtime.create")
return GraphModule(fcreate(graph_json_str, libmod, device_type, device_id), ctx)
......
"""Lightweight TVM RPC module.
RPC enables connect to a remote server, upload and launch functions.
This is useful for cross-compilation and remote testing:
the compiler stack runs on the local server, while the RPC server
runs on a remote runtime which does not have a compiler available.
The test program compiles the program on the local server,
uploads and runs it on the remote RPC server, and gets the result back to verify correctness.
"""
from .server import Server
from .client import RPCSession, connect, connect_tracker
"""Base definitions for RPC."""
from __future__ import absolute_import
import socket
import time
import json
import errno
import struct
import random
import logging
from ..._ffi.function import _init_api
from ..._ffi.base import py_str
# Magic header for RPC data plane
RPC_MAGIC = 0xff271
# magic header for RPC tracker(control plane)
RPC_TRACKER_MAGIC = 0x2f271
# success response
RPC_CODE_SUCCESS = RPC_MAGIC + 0
# duplicate key in proxy
RPC_CODE_DUPLICATE = RPC_MAGIC + 1
# cannot find matched key in server
RPC_CODE_MISMATCH = RPC_MAGIC + 2
class TrackerCode(object):
    """Enumeration code for the RPC tracker control-plane messages."""
    # The request failed on the tracker side.
    FAIL = -1
    # The request succeeded.
    SUCCESS = 0
    # Liveness check of the tracker.
    PING = 1
    # Ask the tracker to stop.
    STOP = 2
    # Register a server resource into the tracker pool.
    PUT = 3
    # Request a server resource from the tracker pool.
    REQUEST = 4

# Mask used to encode an RPC session table index into a remote device type
# (see RPCSession.context, which adds (table_index + 1) * RPC_SESS_MASK).
RPC_SESS_MASK = 128
def recvall(sock, nbytes):
    """Receive exactly ``nbytes`` bytes from a socket.

    Parameters
    ----------
    sock: Socket
        The socket to read from.

    nbytes : int
        Number of bytes to be received.

    Returns
    -------
    data : bytes
        The received payload of exactly ``nbytes`` bytes.

    Raises
    ------
    IOError
        If the connection is reset before all bytes arrive.
    """
    buf = bytearray()
    while len(buf) < nbytes:
        # Read at most 1KB per call to bound each recv.
        chunk = sock.recv(min(nbytes - len(buf), 1024))
        if not chunk:
            raise IOError("connection reset")
        buf.extend(chunk)
    return bytes(buf)
def sendjson(sock, data):
    """Send a python value to the remote end, encoded as JSON.

    The wire format is a 4-byte native-endian length header followed by
    the UTF-8 encoded JSON text (ASCII-safe, since json.dumps escapes
    non-ASCII characters by default).

    Parameters
    ----------
    sock : Socket
        The socket to write to.

    data : object
        Python value to be sent; must be JSON-serializable.
    """
    text = json.dumps(data)
    header = struct.pack("@i", len(text))
    sock.sendall(header)
    sock.sendall(text.encode("utf-8"))
def recvjson(sock):
    """Receive a python value sent by :func:`sendjson` from the remote end.

    Parameters
    ----------
    sock : Socket
        The socket to read from.

    Returns
    -------
    value : object
        The decoded python value.
    """
    # 4-byte native-endian length header, then the JSON payload.
    header = recvall(sock, 4)
    size = struct.unpack("@i", header)[0]
    payload = recvall(sock, size)
    return json.loads(py_str(payload))
def random_key():
    """Generate a random key string.

    Returns
    -------
    key : str
        The string form of a uniform random float in [0, 1).
    """
    return str(random.random())
def connect_with_retry(addr, timeout=60, retry_period=5):
    """Connect to a TCP address with retry.

    This function is only reliable for a short period of server restart.

    Parameters
    ----------
    addr : tuple
        The (host, port) address tuple.

    timeout : float
        Timeout during retry, in seconds.

    retry_period : float
        Number of seconds before we retry again.

    Returns
    -------
    sock : socket.socket
        The connected socket.

    Raises
    ------
    RuntimeError
        If the connection cannot be established within ``timeout`` seconds.
    """
    tstart = time.time()
    while True:
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            sock.connect(addr)
            return sock
        except socket.error as sock_err:
            # Close the failed socket before retrying; the original code
            # leaked one file descriptor per retry.
            sock.close()
            if sock_err.args[0] not in (errno.ECONNREFUSED,):
                raise sock_err
            period = time.time() - tstart
            if period > timeout:
                raise RuntimeError(
                    "Failed to connect to server %s" % str(addr))
            logging.info("Cannot connect to tracker%s, retry in %g secs...",
                         str(addr), retry_period)
            time.sleep(retry_period)
# Still use tvm.contrib.rpc for the foreign functions
_init_api("tvm.contrib.rpc", "tvm.contrib.rpc.base")
"""RPC client tools"""
from __future__ import absolute_import
import os
import socket
import struct
from . import base
from ..._ffi.base import TVMError
from ..._ffi.ndarray import context as _context
class RPCSession(object):
    """RPC client session module.

    Do not directly create the object; call :func:`connect` instead.
    """
    # pylint: disable=invalid-name
    def __init__(self, sess):
        self._sess = sess
        self._tbl_index = base._SessTableIndex(sess)
        # Cache of remote system functions, keyed by short name.
        self._remote_funcs = {}

    def get_function(self, name):
        """Get function from the session.

        Parameters
        ----------
        name : str
            The name of the function

        Returns
        -------
        f : Function
            The result function.
        """
        return self._sess.get_function(name)

    def context(self, dev_type, dev_id=0):
        """Construct a remote context.

        Parameters
        ----------
        dev_type: int or str
            The device type of the remote device.

        dev_id: int, optional
            The device id.

        Returns
        -------
        ctx: TVMContext
            The corresponding encoded remote context.
        """
        ctx = _context(dev_type, dev_id)
        # Encode the session table index into the device type so remote
        # contexts can be told apart from local ones (see RPC_SESS_MASK).
        encode = (self._tbl_index + 1) * base.RPC_SESS_MASK
        ctx.device_type += encode
        ctx._rpc_sess = self
        return ctx

    def cpu(self, dev_id=0):
        """Construct remote CPU device."""
        return self.context(1, dev_id)

    def gpu(self, dev_id=0):
        """Construct remote GPU device."""
        return self.context(2, dev_id)

    def cl(self, dev_id=0):
        """Construct remote OpenCL device."""
        return self.context(4, dev_id)

    def metal(self, dev_id=0):
        """Construct remote Metal device."""
        return self.context(8, dev_id)

    def opengl(self, dev_id=0):
        """Construct remote OpenGL device."""
        return self.context(11, dev_id)

    def ext_dev(self, dev_id=0):
        """Construct remote extension device."""
        return self.context(12, dev_id)

    def upload(self, data, target=None):
        """Upload file to remote runtime temp folder.

        Parameters
        ----------
        data : str or bytearray
            The file name or binary in local to upload.

        target : str, optional
            The path in remote
        """
        if isinstance(data, bytearray):
            if not target:
                raise ValueError("target must present when file is a bytearray")
            blob = data
        else:
            # Use a context manager so the local file handle is always closed;
            # the original code leaked the handle.
            with open(data, "rb") as f:
                blob = bytearray(f.read())
            if not target:
                target = os.path.basename(data)
        if "upload" not in self._remote_funcs:
            self._remote_funcs["upload"] = self.get_function(
                "tvm.contrib.rpc.server.upload")
        self._remote_funcs["upload"](target, blob)

    def download(self, path):
        """Download file from remote temp folder.

        Parameters
        ----------
        path : str
            The relative location to remote temp folder.

        Returns
        -------
        blob : bytearray
            The result blob from the file.
        """
        if "download" not in self._remote_funcs:
            self._remote_funcs["download"] = self.get_function(
                "tvm.contrib.rpc.server.download")
        return self._remote_funcs["download"](path)

    def load_module(self, path):
        """Load a remote module, the file need to be uploaded first.

        Parameters
        ----------
        path : str
            The relative location to remote temp folder.

        Returns
        -------
        m : Module
            The remote module containing remote function.
        """
        return base._LoadRemoteModule(self._sess, path)
class TrackerSession(object):
    """Tracker client session.

    Parameters
    ----------
    addr : tuple
        The (host, port) address tuple of the tracker.
    """
    def __init__(self, addr):
        self._addr = addr
        self._sock = None
        self._max_request_retry = 5
        self._connect()

    def __del__(self):
        self.close()

    def _connect(self):
        """Connect to the tracker and perform the magic handshake."""
        self._sock = base.connect_with_retry(self._addr)
        self._sock.sendall(struct.pack("@i", base.RPC_TRACKER_MAGIC))
        magic = struct.unpack("@i", base.recvall(self._sock, 4))[0]
        if magic != base.RPC_TRACKER_MAGIC:
            raise RuntimeError("%s is not RPC Tracker" % str(self._addr))

    def close(self):
        """Close the tracker connection."""
        if self._sock:
            self._sock.close()
            self._sock = None

    def request(self, key, priority=1, session_timeout=0):
        """Request a new connection from the tracker.

        Parameters
        ----------
        key : str
            The type key of the device.

        priority : int, optional
            The priority of the request.

        session_timeout : float, optional
            The duration of the session, allows server to kill
            the connection when duration is longer than this value.
            When duration is zero, it means the request must always be kept alive.

        Returns
        -------
        sess : RPCSession
            The connected session to the matched server.

        Raises
        ------
        RuntimeError
            If no server can be obtained after all retries.
        """
        for _ in range(self._max_request_retry):
            try:
                if self._sock is None:
                    self._connect()
                base.sendjson(self._sock,
                              [base.TrackerCode.REQUEST, key, "", priority])
                value = base.recvjson(self._sock)
                if value[0] != base.TrackerCode.SUCCESS:
                    raise RuntimeError("Invalid return value %s" % str(value))
                url, port, matchkey = value[1]
                return connect(url, port, key + matchkey, session_timeout)
            except socket.error:
                # The tracker connection dropped; reconnect on the next attempt.
                self.close()
            except TVMError:
                # The matched server failed; ask the tracker again.
                pass
        # The original code fell off the loop and returned None silently;
        # raise explicitly so callers see the failure immediately.
        raise RuntimeError(
            "Failed to request a server with key=%s from tracker %s after %d retries"
            % (key, str(self._addr), self._max_request_retry))
def connect(url, port, key="", session_timeout=0):
    """Connect to RPC Server

    Parameters
    ----------
    url : str
        The url of the host

    port : int
        The port to connect to

    key : str, optional
        Additional key to match server

    session_timeout : float, optional
        The duration of the session, allows server to kill
        the connection when duration is longer than this value.
        When duration is zero, it means the request must always be kept alive.

    Returns
    -------
    sess : RPCSession
        The connected session.
    """
    full_key = key
    if session_timeout:
        # Encode the requested session duration into the match key.
        full_key = "%s -timeout=%s" % (key, str(session_timeout))
    try:
        sess = base._Connect(url, port, full_key)
    except NameError:
        raise RuntimeError("Please compile with USE_RPC=1")
    return RPCSession(sess)
def connect_tracker(url, port):
    """Connect to a RPC tracker

    Parameters
    ----------
    url : str
        The url of the host

    port : int
        The port to connect to

    Returns
    -------
    sess : TrackerSession
        The connected tracker session.
    """
    tracker_addr = (url, port)
    return TrackerSession(tracker_addr)
......@@ -14,17 +14,20 @@ import socket
import multiprocessing
import errno
import struct
try:
import tornado
from tornado import gen
from tornado import websocket
from tornado import ioloop
from tornado import websocket
from . import tornado_util
except ImportError as error_msg:
raise ImportError("RPCProxy module requires tornado package %s" % error_msg)
from . import rpc
from .rpc import RPC_MAGIC, _server_env
from .._ffi.base import py_str
from . import base
from .base import RPC_MAGIC, RPC_CODE_DUPLICATE, RPC_CODE_SUCCESS, RPC_CODE_MISMATCH
from .server import _server_env
from ..._ffi.base import py_str
class ForwardHandler(object):
"""Forward handler to forward the message."""
......@@ -102,102 +105,38 @@ class ForwardHandler(object):
"""on close event"""
assert not self._done
logging.info("RPCProxy:on_close %s ...", self.name())
self._done = True
self.forward_proxy = None
if self.rpc_key:
key = self.rpc_key[6:]
key = self.rpc_key[7:]
if ProxyServerHandler.current._client_pool.get(key, None) == self:
ProxyServerHandler.current._client_pool.pop(key)
if ProxyServerHandler.current._server_pool.get(key, None) == self:
ProxyServerHandler.current._server_pool.pop(key)
self._done = True
self.forward_proxy = None
class TCPHandler(ForwardHandler):
class TCPHandler(tornado_util.TCPHandler, ForwardHandler):
"""Event driven TCP handler."""
def __init__(self, sock, addr):
super(TCPHandler, self).__init__(sock)
self._init_handler()
self.sock = sock
assert self.sock
self.addr = addr
self.loop = ioloop.IOLoop.current()
self.sock.setblocking(0)
self.pending_write = []
self._signal_close = False
def event_handler(_, events):
self._on_event(events)
ioloop.IOLoop.current().add_handler(
self.sock.fileno(), event_handler, self.loop.READ | self.loop.ERROR)
def name(self):
return "TCPSocket: %s:%s" % (str(self.addr), self.rpc_key)
def send_data(self, message, binary=True):
assert binary
self.pending_write.append(message)
self._on_write()
def _on_write(self):
while self.pending_write:
try:
msg = self.pending_write[0]
nsend = self.sock.send(msg)
if nsend != len(msg):
self.pending_write[0] = msg[nsend:]
else:
del self.pending_write[0]
except socket.error as err:
if err.args[0] in (errno.EAGAIN, errno.EWOULDBLOCK):
break
else:
self.on_error(err)
if self.pending_write:
self.loop.update_handler(
self.sock.fileno(), self.loop.READ | self.loop.ERROR | self.loop.WRITE)
else:
if self._signal_close:
self.close()
else:
self.loop.update_handler(
self.sock.fileno(), self.loop.READ | self.loop.ERROR)
self.write_message(message, True)
def _on_read(self):
try:
msg = bytes(self.sock.recv(4096))
if msg:
self.on_data(msg)
return True
else:
self.close_pair()
except socket.error as err:
if err.args[0] in (errno.EAGAIN, errno.EWOULDBLOCK):
pass
else:
self.on_error(e)
return False
def _on_event(self, events):
if (events & self.loop.ERROR) or (events & self.loop.READ):
if self._on_read() and (events & self.loop.WRITE):
self._on_write()
elif events & self.loop.WRITE:
self._on_write()
def signal_close(self):
if not self.pending_write:
self.close()
else:
self._signal_close = True
def on_message(self, message):
self.on_data(message)
def close(self):
if self.sock is not None:
logging.info("%s Close socket..", self.name())
try:
ioloop.IOLoop.current().remove_handler(self.sock.fileno())
self.sock.close()
except socket.error:
pass
self.sock = None
self.on_close_event()
def on_close(self):
if self.forward_proxy:
self.forward_proxy.signal_close()
self.forward_proxy = None
logging.info("%s Close socket..", self.name())
self.on_close_event()
class WebSocketHandler(websocket.WebSocketHandler, ForwardHandler):
......@@ -207,7 +146,7 @@ class WebSocketHandler(websocket.WebSocketHandler, ForwardHandler):
self._init_handler()
def name(self):
return "WebSocketProxy"
return "WebSocketProxy: %s" % (self.rpc_key)
def on_message(self, message):
self.on_data(message)
......@@ -234,12 +173,16 @@ class WebSocketHandler(websocket.WebSocketHandler, ForwardHandler):
class RequestHandler(tornado.web.RequestHandler):
"""Handles html request."""
def __init__(self, *args, **kwargs):
self.page = open(kwargs.pop("file_path")).read()
web_port = kwargs.pop("rpc_web_port", None)
if web_port:
self.page = self.page.replace(
"ws://localhost:9190/ws",
"ws://localhost:%d/ws" % web_port)
file_path = kwargs.pop("file_path")
if file_path.endswith("html"):
self.page = open(file_path).read()
web_port = kwargs.pop("rpc_web_port", None)
if web_port:
self.page = self.page.replace(
"ws://localhost:9190/ws",
"ws://localhost:%d/ws" % web_port)
else:
self.page = open(file_path, "rb").read()
super(RequestHandler, self).__init__(*args, **kwargs)
def data_received(self, _):
......@@ -303,14 +246,22 @@ class ProxyServerHandler(object):
def _pair_up(self, lhs, rhs):
lhs.forward_proxy = rhs
rhs.forward_proxy = lhs
lhs.send_data(struct.pack('@i', RPC_MAGIC))
rhs.send_data(struct.pack('@i', RPC_MAGIC))
lhs.send_data(struct.pack('@i', RPC_CODE_SUCCESS))
lhs.send_data(struct.pack('@i', len(rhs.rpc_key)))
lhs.send_data(rhs.rpc_key.encode("utf-8"))
rhs.send_data(struct.pack('@i', RPC_CODE_SUCCESS))
rhs.send_data(struct.pack('@i', len(lhs.rpc_key)))
rhs.send_data(lhs.rpc_key.encode("utf-8"))
logging.info("Pairup connect %s and %s", lhs.name(), rhs.name())
def handler_ready(self, handler):
"""Report handler to be ready."""
logging.info("Handler ready %s", handler.name())
key = handler.rpc_key[6:]
key = handler.rpc_key[7:]
if handler.rpc_key.startswith("server:"):
pool_src, pool_dst = self._client_pool, self._server_pool
timeout = self.timeout_server
......@@ -329,18 +280,19 @@ class ProxyServerHandler(object):
logging.info("Timeout client connection %s, cannot find match key=%s",
handler.name(), key)
pool_dst.pop(key)
handler.send_data(struct.pack('@i', RPC_MAGIC + 2))
handler.send_data(struct.pack('@i', RPC_CODE_MISMATCH))
handler.signal_close()
self.loop.call_later(timeout, cleanup)
else:
logging.info("Duplicate connection with same key=%s", key)
handler.send_data(struct.pack('@i', RPC_MAGIC + 1))
handler.send_data(struct.pack('@i', RPC_CODE_DUPLICATE))
handler.signal_close()
def run(self):
"""Run the proxy server"""
ioloop.IOLoop.current().start()
def _proxy_server(listen_sock,
web_port,
timeout_client,
......@@ -446,7 +398,8 @@ def websocket_proxy_server(url, key=""):
data = bytes(data)
conn.write_message(data, binary=True)
return len(data)
on_message = rpc._CreateEventDrivenServer(_fsend, "WebSocketProxyServer")
on_message = base._CreateEventDrivenServer(
_fsend, "WebSocketProxyServer", "%toinit")
return on_message
@gen.coroutine
......@@ -462,14 +415,16 @@ def websocket_proxy_server(url, key=""):
msg = yield conn.read_message()
assert len(msg) >= 4
magic = struct.unpack('@i', msg[:4])[0]
if magic == RPC_MAGIC + 1:
if magic == RPC_CODE_DUPLICATE:
raise RuntimeError("key: %s has already been used in proxy" % key)
elif magic == RPC_MAGIC + 2:
elif magic == RPC_CODE_MISMATCH:
logging.info("RPCProxy do not have matching client key %s", key)
elif magic != RPC_MAGIC:
elif magic != RPC_CODE_SUCCESS:
raise RuntimeError("%s is not RPC Proxy" % url)
logging.info("Connection established")
msg = msg[4:]
logging.info("Connection established with remote")
if msg:
on_message(bytearray(msg), 3)
......
"""Utilities used in tornado."""
import socket
import errno
from tornado import ioloop
class TCPHandler(object):
    """TCP socket handler backed tornado event loop.

    Subclasses are expected to provide ``on_message``, ``on_close`` and
    ``on_error`` callbacks; this base class drives the non-blocking
    read/write state machine on the current tornado IOLoop.

    Parameters
    ----------
    sock : Socket
        The TCP socket, will set it to non-blocking mode.
    """
    def __init__(self, sock):
        self._sock = sock
        self._ioloop = ioloop.IOLoop.current()
        self._sock.setblocking(0)
        # Outgoing messages not yet fully written to the socket.
        self._pending_write = []
        # When True, close as soon as the write queue drains.
        self._signal_close = False
        # Bridge tornado's (fd, events) callback signature to the
        # single-argument method of the same name on this instance.
        def _event_handler(_, events):
            self._event_handler(events)
        self._ioloop.add_handler(
            self._sock.fileno(), _event_handler,
            self._ioloop.READ | self._ioloop.ERROR)

    def signal_close(self):
        """Signal the handler to close.

        The handler will be closed after the existing
        pending message are sent to the peer.
        """
        if not self._pending_write:
            self.close()
        else:
            self._signal_close = True

    def close(self):
        """Close the socket and fire on_close (idempotent)."""
        if self._sock is not None:
            try:
                self._ioloop.remove_handler(self._sock.fileno())
                self._sock.close()
            except socket.error:
                pass
            self._sock = None
            self.on_close()

    def write_message(self, message, binary=True):
        # Only binary payloads are supported; queue and try to flush now.
        assert binary
        self._pending_write.append(message)
        self._update_write()

    def _event_handler(self, events):
        """Central event handler: dispatch READ/WRITE/ERROR events."""
        if (events & self._ioloop.ERROR) or (events & self._ioloop.READ):
            # Only attempt the write half if the read half did not close us.
            if self._update_read() and (events & self._ioloop.WRITE):
                self._update_write()
        elif events & self._ioloop.WRITE:
            self._update_write()

    def _update_write(self):
        """Flush as much of the pending write queue as the socket accepts."""
        while self._pending_write:
            try:
                msg = self._pending_write[0]
                nsend = self._sock.send(msg)
                if nsend != len(msg):
                    # Partial write: keep the unsent tail at the queue head.
                    self._pending_write[0] = msg[nsend:]
                else:
                    self._pending_write.pop(0)
            except socket.error as err:
                if err.args[0] in (errno.EAGAIN, errno.EWOULDBLOCK):
                    # Socket buffer full; wait for the next WRITE event.
                    break
                else:
                    # NOTE(review): on_error does not break the loop here —
                    # assumes the callback tears down the handler; confirm.
                    self.on_error(err)
        if self._pending_write:
            # Still data queued: also watch for writability.
            self._ioloop.update_handler(
                self._sock.fileno(), self._ioloop.READ | self._ioloop.ERROR | self._ioloop.WRITE)
        else:
            if self._signal_close:
                # Queue drained after signal_close(): finish closing now.
                self.close()
            else:
                self._ioloop.update_handler(
                    self._sock.fileno(), self._ioloop.READ | self._ioloop.ERROR)

    def _update_read(self):
        """Handle a READ event; return True if the connection stays open."""
        try:
            msg = bytes(self._sock.recv(4096))
            if msg:
                self.on_message(msg)
                return True
            else:
                # normal close, remote is closed
                self.close()
        except socket.error as err:
            if err.args[0] in (errno.EAGAIN, errno.EWOULDBLOCK):
                # Spurious wakeup; nothing to read yet.
                pass
            else:
                self.on_error(err)
        return False
......@@ -4,7 +4,7 @@ from __future__ import absolute_import
import logging
import argparse
import os
from ..contrib.rpc_proxy import Proxy
from ..contrib.rpc.proxy import Proxy
def find_example_resource():
"""Find resource examples."""
......
......@@ -17,13 +17,14 @@ def main():
help='The port of the PRC')
parser.add_argument('--port-end', type=int, default=9199,
help='The end search port of the PRC')
parser.add_argument('--key', type=str, default="",
help="RPC key used to identify the connection type.")
parser.add_argument('--with-executor', type=bool, default=False,
help="Whether to load executor runtime")
parser.add_argument('--load-library', type=str, default="",
help="Additional library to load")
parser.add_argument('--exclusive', action='store_true',
help="If this is enabled, the server will kill old connection"
"when new connection comes")
parser.add_argument('--tracker', type=str, default="",
help="Report to RPC tracker")
args = parser.parse_args()
logging.basicConfig(level=logging.INFO)
......@@ -38,7 +39,21 @@ def main():
libs.append(ctypes.CDLL(file_name, ctypes.RTLD_GLOBAL))
logging.info("Load additional library %s", file_name)
server = rpc.Server(args.host, args.port, args.port_end, exclusive=args.exclusive)
if args.tracker:
url, port = args.tracker.split(":")
port = int(port)
tracker_addr = (url, port)
if not args.key:
raise RuntimeError(
"Need key to present type of resource when tracker is available")
else:
tracker_addr = None
server = rpc.Server(args.host,
args.port,
args.port_end,
key=args.key,
tracker_addr=tracker_addr)
server.libs += libs
server.proc.join()
......
"""RPC web proxy, allows redirect to websocket based RPC servers(browsers)"""
from __future__ import absolute_import
import logging
import argparse
from ..contrib.rpc.tracker import Tracker
def main():
"""Main funciton"""
parser = argparse.ArgumentParser()
parser.add_argument('--host', type=str, default="0.0.0.0",
help='the hostname of the server')
parser.add_argument('--port', type=int, default=9190,
help='The port of the PRC')
args = parser.parse_args()
logging.basicConfig(level=logging.INFO)
tracker = Tracker(args.host, port=args.port)
tracker.proc.join()
if __name__ == "__main__":
main()
......@@ -32,9 +32,12 @@ class CallbackChannel final : public RPCChannel {
PackedFunc fsend_;
};
PackedFunc CreateEventDrivenServer(PackedFunc fsend, std::string name) {
PackedFunc CreateEventDrivenServer(PackedFunc fsend,
std::string name,
std::string remote_key) {
std::unique_ptr<CallbackChannel> ch(new CallbackChannel(fsend));
std::shared_ptr<RPCSession> sess = RPCSession::Create(std::move(ch), name);
std::shared_ptr<RPCSession> sess =
RPCSession::Create(std::move(ch), name, remote_key);
return PackedFunc([sess](TVMArgs args, TVMRetValue* rv) {
int ret = sess->ServerEventHandler(args[0], args[1]);
*rv = ret;
......@@ -43,7 +46,7 @@ PackedFunc CreateEventDrivenServer(PackedFunc fsend, std::string name) {
TVM_REGISTER_GLOBAL("contrib.rpc._CreateEventDrivenServer")
.set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = CreateEventDrivenServer(args[0], args[1]);
*rv = CreateEventDrivenServer(args[0], args[1], args[2]);
});
} // namespace runtime
} // namespace tvm
......@@ -49,6 +49,7 @@ class RPCModuleNode final : public ModuleNode {
~RPCModuleNode() {
if (module_handle_ != nullptr) {
sess_->CallRemote(RPCCode::kModuleFree, module_handle_);
module_handle_ = nullptr;
}
}
......@@ -198,5 +199,6 @@ TVM_REGISTER_GLOBAL("contrib.rpc._SessTableIndex")
CHECK_EQ(tkey, "rpc");
*rv = static_cast<RPCModuleNode*>(m.operator->())->sess()->table_index();
});
} // namespace runtime
} // namespace tvm
......@@ -49,11 +49,19 @@ class RPCSession::EventHandler {
EventHandler(common::RingBuffer* reader,
common::RingBuffer* writer,
int rpc_sess_table_index,
std::string name)
: reader_(reader), writer_(writer),
std::string name,
std::string* remote_key)
: reader_(reader),
writer_(writer),
rpc_sess_table_index_(rpc_sess_table_index),
name_(name) {
name_(name),
remote_key_(remote_key) {
this->Clear();
if (*remote_key == "%toinit") {
state_ = kInitHeader;
remote_key_->resize(0);
pending_request_bytes_ = sizeof(int32_t);
}
}
// Bytes needed to fulfill current request
size_t BytesNeeded() {
......@@ -75,6 +83,7 @@ class RPCSession::EventHandler {
std::swap(client_mode_, client_mode);
while (this->Ready()) {
switch (state_) {
case kInitHeader: HandleInitHeader(); break;
case kRecvCode: HandleRecvCode(); break;
case kRecvCallHandle: {
this->Read(&call_handle_, sizeof(call_handle_));
......@@ -223,6 +232,7 @@ class RPCSession::EventHandler {
protected:
enum State {
kInitHeader,
kRecvCode,
kRecvCallHandle,
kRecvPackedSeqNumArgs,
......@@ -240,6 +250,8 @@ class RPCSession::EventHandler {
RPCCode code_;
// Handle for the remote function call.
uint64_t call_handle_;
// Initialize remote header
bool init_header_step_{0};
// Number of packed arguments.
int num_packed_args_;
// Current argument index.
......@@ -266,6 +278,10 @@ class RPCSession::EventHandler {
<< "state=" << state;
state_ = state;
switch (state) {
case kInitHeader: {
LOG(FATAL) << "cannot switch to init header";
break;
}
case kRecvCode: {
this->RequestBytes(sizeof(RPCCode));
break;
......@@ -438,6 +454,21 @@ class RPCSession::EventHandler {
this->SwitchToState(kRecvPackedSeqArg);
}
}
// handler for initial header read
void HandleInitHeader() {
if (init_header_step_ == 0) {
int32_t len;
this->Read(&len, sizeof(len));
remote_key_->resize(len);
init_header_step_ = 1;
this->RequestBytes(len);
return;
} else {
CHECK_EQ(init_header_step_, 1);
this->Read(dmlc::BeginPtr(*remote_key_), remote_key_->length());
this->SwitchToState(kRecvCode);
}
}
// Handler for read code.
void HandleRecvCode() {
this->Read(&code_, sizeof(code_));
......@@ -633,6 +664,8 @@ class RPCSession::EventHandler {
int rpc_sess_table_index_;
// Name of session.
std::string name_;
// remote key
std::string* remote_key_;
};
struct RPCSessTable {
......@@ -699,7 +732,8 @@ RPCCode RPCSession::HandleUntilReturnEvent(
void RPCSession::Init() {
// Event handler
handler_ = std::make_shared<EventHandler>(&reader_, &writer_, table_index_, name_);
handler_ = std::make_shared<EventHandler>(
&reader_, &writer_, table_index_, name_, &remote_key_);
// Quick function to call remote.
call_remote_ = PackedFunc([this](TVMArgs args, TVMRetValue* rv) {
handler_->SendPackedSeq(args.values, args.type_codes, args.num_args);
......@@ -709,10 +743,13 @@ void RPCSession::Init() {
}
std::shared_ptr<RPCSession> RPCSession::Create(
std::unique_ptr<RPCChannel> channel, std::string name) {
std::unique_ptr<RPCChannel> channel,
std::string name,
std::string remote_key) {
std::shared_ptr<RPCSession> sess = std::make_shared<RPCSession>();
sess->channel_ = std::move(channel);
sess->name_ = std::move(name);
sess->remote_key_ = std::move(remote_key);
sess->table_index_ = RPCSessTable::Global()->Insert(sess);
sess->Init();
return sess;
......
......@@ -171,12 +171,15 @@ class RPCSession {
/*!
* \brief Create a RPC session with given channel.
* \param channel The communication channel.
* \param name The name of the session, used for debug
* \return The session.
* \param name The local name of the session, used for debug
* \param remote_key The remote key of the session
* if remote_key equals "%toinit", we need to re-intialize
* it by event handler.
*/
static std::shared_ptr<RPCSession> Create(
std::unique_ptr<RPCChannel> channel,
std::string name);
std::string name,
std::string remote_key);
/*!
* \brief Try get session from the global session table by table index.
* \param table_index The table index of the session.
......@@ -208,6 +211,8 @@ class RPCSession {
int table_index_{0};
// The name of the session.
std::string name_;
// The remote key
std::string remote_key_;
};
/*!
......
......@@ -68,7 +68,14 @@ RPCConnect(std::string url, int port, std::string key) {
sock.Close();
LOG(FATAL) << "URL " << url << ":" << port << " is not TVM RPC server";
}
return RPCSession::Create(std::unique_ptr<SockChannel>(new SockChannel(sock)), key);
CHECK_EQ(sock.RecvAll(&keylen, sizeof(keylen)), sizeof(keylen));
std::string remote_key;
if (keylen != 0) {
remote_key.resize(keylen);
CHECK_EQ(sock.RecvAll(&remote_key[0], keylen), keylen);
}
return RPCSession::Create(
std::unique_ptr<SockChannel>(new SockChannel(sock)), key, remote_key);
}
Module RPCClientConnect(std::string url, int port, std::string key) {
......@@ -80,7 +87,7 @@ void RPCServerLoop(int sockfd) {
static_cast<common::TCPSocket::SockType>(sockfd));
RPCSession::Create(
std::unique_ptr<SockChannel>(new SockChannel(sock)),
"SockServerLoop")->ServerLoop();
"SockServerLoop", "")->ServerLoop();
}
TVM_REGISTER_GLOBAL("contrib.rpc._Connect")
......
......@@ -17,9 +17,9 @@ def rpc_proxy_check():
"""
try:
from tvm.contrib import rpc_proxy
from tvm.contrib.rpc import proxy
web_port = 8888
prox = rpc_proxy.Proxy("localhost", web_port=web_port)
prox = proxy.Proxy("localhost", web_port=web_port)
def check():
if not tvm.module.enabled("rpc"):
return
......@@ -30,7 +30,7 @@ def rpc_proxy_check():
def addone(name, x):
return "%s:%d" % (name, x)
server = multiprocessing.Process(
target=rpc_proxy.websocket_proxy_server,
target=proxy.websocket_proxy_server,
args=("ws://localhost:%d/ws" % web_port,"x1"))
# Need to make sure that the connection start after proxy comes up
time.sleep(0.1)
......
import tvm
import logging
import numpy as np
import time
import multiprocessing
from tvm.contrib import rpc
def check_server_drop():
    """Integration test for RPC tracker fault tolerance.

    Exercises two failure modes against a live tracker:
      1. Stale worker entries: PUT requests that register outdated
         (port, matchkey) pairs must not break later requests.
      2. Session timeout: a client request whose session times out
         (sleep longer than the timeout) should raise ``tvm.TVMError``
         rather than wedging the tracker.

    Requires tornado (used by ``tvm.contrib.rpc.tracker``); the whole
    test is skipped with a message when the import fails.
    """
    try:
        # Tracker support needs tornado; an ImportError here skips the test.
        from tvm.contrib.rpc import tracker, base
        from tvm.contrib.rpc.base import TrackerCode

        # Remote-callable function the test sessions will invoke.
        @tvm.register_func("rpc.test2.addone")
        def addone(x):
            return x + 1

        def _put(tclient, value):
            # Send a raw JSON tracker command and consume the reply.
            # Uses the client's private socket directly to inject
            # hand-crafted (possibly stale) PUT messages.
            base.sendjson(tclient._sock, value)
            base.recvjson(tclient._sock)

        # Start a tracker and connect a client plus two servers that
        # both register under the key "xyz".
        # NOTE(review): both servers pass port=9099; presumably
        # rpc.Server searches for a free port from there — confirm.
        tserver = tracker.Tracker("localhost", 8888)
        tclient = rpc.connect_tracker("localhost", tserver.port)
        server1 = rpc.Server(
            "localhost", port=9099,
            tracker_addr=("localhost", tserver.port),
            key="xyz")
        server2 = rpc.Server(
            "localhost", port=9099,
            tracker_addr=("localhost", tserver.port),
            key="xyz")
        # Fault tolerance to stale worker values: inject PUT entries with
        # outdated match keys; the tracker must cope with them.
        _put(tclient, [TrackerCode.PUT, "xyz", (server1.port, "abc")])
        _put(tclient, [TrackerCode.PUT, "xyz", (server1.port, "abcxxx")])
        _put(tclient, [TrackerCode.PUT, "xyz", (server2.port, "abcxxx11")])
        # Fault tolerance to server session timeout.
        def check_timeout(timeout, sleeptime):
            # Request two sessions with the given timeout, idle for
            # `sleeptime`, then call the remote function.  When
            # sleeptime > timeout the session may have expired and the
            # resulting TVMError is deliberately swallowed — the point
            # is that the tracker survives, not that the call succeeds.
            try:
                remote = tclient.request("xyz", priority=0, session_timeout=timeout)
                remote2 = tclient.request("xyz", session_timeout=timeout)
                time.sleep(sleeptime)
                f1 = remote.get_function("rpc.test2.addone")
                assert f1(10) == 11
                f1 = remote2.get_function("rpc.test2.addone")
                assert f1(10) == 11
            except tvm.TVMError as e:
                pass
        check_timeout(0.01, 0.1)  # sleep exceeds timeout: expect expiry path
        check_timeout(2, 0)       # generous timeout: calls should succeed
    except ImportError:
        print("Skip because tornado is not available")
if __name__ == "__main__":
    # Enable INFO logging so tracker/server activity is visible,
    # then run the tracker fault-tolerance test directly.
    logging.basicConfig(level=logging.INFO)
    check_server_drop()
......@@ -18,7 +18,7 @@ def test_rpc_simple():
def remotethrow(name):
raise ValueError("%s" % name)
server = rpc.Server("localhost")
server = rpc.Server("localhost", key="x1")
client = rpc.connect(server.host, server.port, key="x1")
f1 = client.get_function("rpc.test.addone")
assert f1(10) == 11
......@@ -41,7 +41,6 @@ def test_rpc_array():
np.testing.assert_equal(y.asnumpy(), x)
server = rpc.Server("localhost")
remote = rpc.connect(server.host, server.port)
print("second connect")
r_cpu = tvm.nd.array(x, remote.cpu(0))
assert str(r_cpu.context).startswith("remote")
np.testing.assert_equal(r_cpu.asnumpy(), x)
......@@ -141,7 +140,7 @@ def test_rpc_return_func():
@tvm.register_func("rpc.test.remote_func")
def addone(x):
return lambda y: x+y
server = rpc.Server("localhost")
server = rpc.Server("localhost", key="x1")
client = rpc.connect(server.host, server.port, key="x1")
f1 = client.get_function("rpc.test.remote_func")
fadd = f1(10)
......
......@@ -851,6 +851,7 @@ var tvm_runtime = tvm_runtime || {};
}
// Node js, import websocket
var bkey = StringToUint8Array("server:" + key);
bkey = bkey.slice(0, bkey.length - 1);
var server_name = "WebSocketRPCServer[" + key + "]";
var RPC_MAGIC = 0xff271;
function checkEndian() {
......@@ -895,7 +896,7 @@ var tvm_runtime = tvm_runtime || {};
} else {
return new TVMConstant(0, "int32");
}
} , server_name);
} , server_name, "%toinit");
function on_open(event) {
var intbuf = new Int32Array(1);
......@@ -912,6 +913,7 @@ var tvm_runtime = tvm_runtime || {};
var msg = new Uint8Array(event.data);
CHECK(msg.length >= 4, "Need message header to be bigger than 4");
var magic = new Int32Array(event.data)[0];
if (magic == RPC_MAGIC + 1) {
throwError("key: " + key + " has already been used in proxy");
} else if (magic == RPC_MAGIC + 2) {
......@@ -1014,7 +1016,7 @@ var tvm_runtime = tvm_runtime || {};
/**
* Load parameters from serialized byte array of parameter dict.
*
*
* @param {Uint8Array} params The serialized parameter dict.
*/
"load_params" : function(params) {
......@@ -1024,7 +1026,7 @@ var tvm_runtime = tvm_runtime || {};
/**
* Load parameters from serialized base64 string of parameter dict.
*
*
* @param {string} base64_params The serialized parameter dict.
*/
"load_base64_params" : function(base64_params) {
......@@ -1046,7 +1048,7 @@ var tvm_runtime = tvm_runtime || {};
/**
* Get index-th output to out.
*
*
* @param {NDArray} out The output array container.
* @return {NDArray} The output array container.
*/
......@@ -1076,7 +1078,7 @@ var tvm_runtime = tvm_runtime || {};
var tvm_graph_module = fcreate(graph_json_str, libmod,
new TVMConstant(ctx.device_type, "int32"),
new TVMConstant(ctx.device_id, "int32"));
return new GraphModule(tvm_graph_module, ctx);
};
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment