Commit 14181340 by Tianqi Chen Committed by GitHub

[RPC] More robust tracker protocol (#1085)

* [RPC] More robust tracker protocol

* fix normal rpc
parent 2e17e850
...@@ -34,6 +34,7 @@ class TrackerCode(object): ...@@ -34,6 +34,7 @@ class TrackerCode(object):
REQUEST = 4 REQUEST = 4
UPDATE_INFO = 5 UPDATE_INFO = 5
SUMMARY = 6 SUMMARY = 6
GET_PENDING_MATCHKEYS = 7
RPC_SESS_MASK = 128 RPC_SESS_MASK = 128
......
...@@ -230,7 +230,7 @@ class TrackerSession(object): ...@@ -230,7 +230,7 @@ class TrackerSession(object):
if value[0] != base.TrackerCode.SUCCESS: if value[0] != base.TrackerCode.SUCCESS:
raise RuntimeError("Invalid return value %s" % str(value)) raise RuntimeError("Invalid return value %s" % str(value))
url, port, matchkey = value[1] url, port, matchkey = value[1]
return connect(url, port, key + matchkey, session_timeout) return connect(url, port, matchkey, session_timeout)
except socket.error as err: except socket.error as err:
self.close() self.close()
last_err = err last_err = err
......
...@@ -14,6 +14,7 @@ import socket ...@@ -14,6 +14,7 @@ import socket
import multiprocessing import multiprocessing
import errno import errno
import struct import struct
import time
try: try:
import tornado import tornado
...@@ -45,6 +46,7 @@ class ForwardHandler(object): ...@@ -45,6 +46,7 @@ class ForwardHandler(object):
self.rpc_key = None self.rpc_key = None
self.match_key = None self.match_key = None
self.forward_proxy = None self.forward_proxy = None
self.alloc_time = None
def __del__(self): def __del__(self):
logging.info("Delete %s...", self.name()) logging.info("Delete %s...", self.name())
...@@ -237,6 +239,7 @@ class ProxyServerHandler(object): ...@@ -237,6 +239,7 @@ class ProxyServerHandler(object):
self.sock.fileno(), event_handler, self.loop.READ) self.sock.fileno(), event_handler, self.loop.READ)
self._client_pool = {} self._client_pool = {}
self._server_pool = {} self._server_pool = {}
self.timeout_alloc = 5
self.timeout_client = timeout_client self.timeout_client = timeout_client
self.timeout_server = timeout_server self.timeout_server = timeout_server
# tracker information # tracker information
...@@ -245,8 +248,12 @@ class ProxyServerHandler(object): ...@@ -245,8 +248,12 @@ class ProxyServerHandler(object):
self._tracker_conn = None self._tracker_conn = None
self._tracker_pending_puts = [] self._tracker_pending_puts = []
self._key_set = set() self._key_set = set()
self.update_tracker_period = 2
if tracker_addr: if tracker_addr:
logging.info("Tracker address:%s", str(tracker_addr)) logging.info("Tracker address:%s", str(tracker_addr))
def _callback():
self._update_tracker(True)
self.loop.call_later(self.update_tracker_period, _callback)
logging.info("RPCProxy: Websock port bind to %d", web_port) logging.info("RPCProxy: Websock port bind to %d", web_port)
def _on_event(self, _): def _on_event(self, _):
...@@ -271,7 +278,22 @@ class ProxyServerHandler(object): ...@@ -271,7 +278,22 @@ class ProxyServerHandler(object):
rhs.send_data(lhs.rpc_key.encode("utf-8")) rhs.send_data(lhs.rpc_key.encode("utf-8"))
logging.info("Pairup connect %s and %s", lhs.name(), rhs.name()) logging.info("Pairup connect %s and %s", lhs.name(), rhs.name())
def _update_tracker(self): def _regenerate_server_keys(self, keys):
"""Regenerate keys for server pool"""
keyset = set(self._server_pool.keys())
new_keys = []
# re-generate the server match key, so old information is invalidated.
for key in keys:
rpc_key, _ = key.split(":")
handle = self._server_pool[key]
del self._server_pool[key]
new_key = base.random_key(rpc_key + ":", keyset)
self._server_pool[new_key] = handle
keyset.add(new_key)
new_keys.append(new_key)
return new_keys
def _update_tracker(self, period_update=False):
"""Update information on tracker.""" """Update information on tracker."""
try: try:
if self._tracker_conn is None: if self._tracker_conn is None:
...@@ -285,13 +307,33 @@ class ProxyServerHandler(object): ...@@ -285,13 +307,33 @@ class ProxyServerHandler(object):
# just connect to tracker, need to update all keys # just connect to tracker, need to update all keys
self._tracker_pending_puts = self._server_pool.keys() self._tracker_pending_puts = self._server_pool.keys()
if self._tracker_conn and period_update:
# periodically update tracker information
# regenerate key if the key is not in tracker anymore
# and there is no in-coming connection after timeout_alloc
base.sendjson(self._tracker_conn, [TrackerCode.GET_PENDING_MATCHKEYS])
pending_keys = set(base.recvjson(self._tracker_conn))
update_keys = []
for k, v in self._server_pool.items():
if k not in pending_keys:
if v.alloc_time is None:
v.alloc_time = time.time()
elif time.time() - v.alloc_time > self.timeout_alloc:
update_keys.append(k)
v.alloc_time = None
if update_keys:
logging.info("RPCProxy: No incoming conn on %s, regenerate keys...",
str(update_keys))
new_keys = self._regenerate_server_keys(update_keys)
self._tracker_pending_puts += new_keys
need_update_info = False need_update_info = False
# report new connections # report new connections
for key in self._tracker_pending_puts: for key in self._tracker_pending_puts:
rpc_key, match_key = key.split(":") rpc_key = key.split(":")[0]
base.sendjson(self._tracker_conn, base.sendjson(self._tracker_conn,
[TrackerCode.PUT, rpc_key, [TrackerCode.PUT, rpc_key,
(self._listen_port, ":" + match_key)]) (self._listen_port, key)])
assert base.recvjson(self._tracker_conn) == TrackerCode.SUCCESS assert base.recvjson(self._tracker_conn) == TrackerCode.SUCCESS
if rpc_key not in self._key_set: if rpc_key not in self._key_set:
self._key_set.add(rpc_key) self._key_set.add(rpc_key)
...@@ -305,24 +347,17 @@ class ProxyServerHandler(object): ...@@ -305,24 +347,17 @@ class ProxyServerHandler(object):
assert base.recvjson(self._tracker_conn) == TrackerCode.SUCCESS assert base.recvjson(self._tracker_conn) == TrackerCode.SUCCESS
self._tracker_pending_puts = [] self._tracker_pending_puts = []
except (socket.error, IOError) as err: except (socket.error, IOError) as err:
retry_period = 5
logging.info( logging.info(
"Lost tracker connection: %s, try reconnect in %g sec", "Lost tracker connection: %s, try reconnect in %g sec",
str(err), retry_period) str(err), self.update_tracker_period)
self._tracker_conn.close() self._tracker_conn.close()
self._tracker_conn = None self._tracker_conn = None
new_pool = {} self._regenerate_server_keys(self._server_pool.keys())
keyset = set(self._server_pool.keys())
# re-generate the server match key, so old information is invalidated. if period_update:
for key, handle in self._server_pool.items():
rpc_key, _ = key.split(":")
key = base.random_key(rpc_key + ":", keyset)
new_pool[key] = handle
keyset.add(key)
self._server_pool = new_pool
def _callback(): def _callback():
self._update_tracker() self._update_tracker(True)
self.loop.call_later(retry_period, _callback) self.loop.call_later(self.update_tracker_period, _callback)
def _handler_ready_tracker_mode(self, handler): def _handler_ready_tracker_mode(self, handler):
"""tracker mode to handle handler ready.""" """tracker mode to handle handler ready."""
......
...@@ -6,7 +6,7 @@ Server is TCP based with the following protocol: ...@@ -6,7 +6,7 @@ Server is TCP based with the following protocol:
- Initial handshake to the peer - Initial handshake to the peer
- [RPC_MAGIC, keysize(int32), key-bytes] - [RPC_MAGIC, keysize(int32), key-bytes]
- The key is in format - The key is in format
- {server|client}:device-type[:matchkey] [-timeout=timeout] - {server|client}:device-type[:random-key] [-timeout=timeout]
""" """
from __future__ import absolute_import from __future__ import absolute_import
...@@ -75,7 +75,7 @@ def _parse_server_opt(opts): ...@@ -75,7 +75,7 @@ def _parse_server_opt(opts):
def _listen_loop(sock, port, rpc_key, tracker_addr): def _listen_loop(sock, port, rpc_key, tracker_addr):
"""Lisenting loop of the server master.""" """Lisenting loop of the server master."""
def _accept_conn(listen_sock, tracker_conn, ping_period=0.1): def _accept_conn(listen_sock, tracker_conn, ping_period=2):
"""Accept connection from the other places. """Accept connection from the other places.
Parameters Parameters
...@@ -89,22 +89,40 @@ def _listen_loop(sock, port, rpc_key, tracker_addr): ...@@ -89,22 +89,40 @@ def _listen_loop(sock, port, rpc_key, tracker_addr):
ping_period : float, optional ping_period : float, optional
ping tracker every k seconds if no connection is accepted. ping tracker every k seconds if no connection is accepted.
""" """
old_keyset = set()
# Report resource to tracker # Report resource to tracker
if tracker_conn: if tracker_conn:
matchkey = base.random_key(":") matchkey = base.random_key(rpc_key + ":")
base.sendjson(tracker_conn, base.sendjson(tracker_conn,
[TrackerCode.PUT, rpc_key, (port, matchkey)]) [TrackerCode.PUT, rpc_key, (port, matchkey)])
assert base.recvjson(tracker_conn) == TrackerCode.SUCCESS assert base.recvjson(tracker_conn) == TrackerCode.SUCCESS
else: else:
matchkey = "" matchkey = rpc_key
unmatch_period_count = 0
unmatch_timeout = 4
# Wait until we get a valid connection # Wait until we get a valid connection
while True: while True:
if tracker_conn: if tracker_conn:
trigger = select.select([listen_sock], [], [], ping_period) trigger = select.select([listen_sock], [], [], ping_period)
if not listen_sock in trigger[0]: if not listen_sock in trigger[0]:
base.sendjson(tracker_conn, [TrackerCode.PING]) base.sendjson(tracker_conn, [TrackerCode.GET_PENDING_MATCHKEYS])
assert base.recvjson(tracker_conn) == TrackerCode.SUCCESS pending_keys = base.recvjson(tracker_conn)
old_keyset.add(matchkey)
# if match key not in pending key set
# it means the key is aqquired by a client but not used.
if matchkey not in pending_keys:
unmatch_period_count += 1
else:
unmatch_period_count = 0
# regenerate match key if key is aqquired but not used for a while
if unmatch_period_count * ping_period > unmatch_timeout + ping_period:
logging.info("RPCServer: no incoming connections, regenerate key ...")
matchkey = base.random_key(rpc_key + ":", old_keyset)
base.sendjson(tracker_conn,
[TrackerCode.PUT, rpc_key, (port, matchkey)])
assert base.recvjson(tracker_conn) == TrackerCode.SUCCESS
unmatch_period_count = 0
continue continue
conn, addr = listen_sock.accept() conn, addr = listen_sock.accept()
magic = struct.unpack("@i", base.recvall(conn, 4))[0] magic = struct.unpack("@i", base.recvall(conn, 4))[0]
...@@ -114,7 +132,7 @@ def _listen_loop(sock, port, rpc_key, tracker_addr): ...@@ -114,7 +132,7 @@ def _listen_loop(sock, port, rpc_key, tracker_addr):
keylen = struct.unpack("@i", base.recvall(conn, 4))[0] keylen = struct.unpack("@i", base.recvall(conn, 4))[0]
key = py_str(base.recvall(conn, keylen)) key = py_str(base.recvall(conn, keylen))
arr = key.split() arr = key.split()
expect_header = "client:" + rpc_key + matchkey expect_header = "client:" + matchkey
server_key = "server:" + rpc_key server_key = "server:" + rpc_key
if arr[0] != expect_header: if arr[0] != expect_header:
conn.sendall(struct.pack("@i", base.RPC_CODE_MISMATCH)) conn.sendall(struct.pack("@i", base.RPC_CODE_MISMATCH))
......
...@@ -48,6 +48,8 @@ class TCPHandler(object): ...@@ -48,6 +48,8 @@ class TCPHandler(object):
def write_message(self, message, binary=True): def write_message(self, message, binary=True):
assert binary assert binary
if self._sock is None:
raise IOError("socket is already closed")
self._pending_write.append(message) self._pending_write.append(message)
self._update_write() self._update_write()
......
...@@ -92,7 +92,9 @@ class PriorityScheduler(Scheduler): ...@@ -92,7 +92,9 @@ class PriorityScheduler(Scheduler):
value = self._values.pop(0) value = self._values.pop(0)
item = heapq.heappop(self._requests) item = heapq.heappop(self._requests)
callback = item[-1] callback = item[-1]
if not callback(value): if callback(value[1:]):
value[0].pending_matchkeys.remove(value[-1])
else:
self._values.append(value) self._values.append(value)
def put(self, value): def put(self, value):
...@@ -124,6 +126,8 @@ class TCPEventHandler(tornado_util.TCPHandler): ...@@ -124,6 +126,8 @@ class TCPEventHandler(tornado_util.TCPHandler):
self._addr = addr self._addr = addr
self._init_req_nbytes = 4 self._init_req_nbytes = 4
self._info = {"addr": addr} self._info = {"addr": addr}
# list of pending match keys that has not been used.
self.pending_matchkeys = set()
self._tracker._connections.add(self) self._tracker._connections.add(self)
def name(self): def name(self):
...@@ -189,18 +193,27 @@ class TCPEventHandler(tornado_util.TCPHandler): ...@@ -189,18 +193,27 @@ class TCPEventHandler(tornado_util.TCPHandler):
if code == TrackerCode.PUT: if code == TrackerCode.PUT:
key = args[1] key = args[1]
port, matchkey = args[2] port, matchkey = args[2]
self._tracker.put(key, (self._addr[0], port, matchkey)) self.pending_matchkeys.add(matchkey)
self._tracker.put(key, (self, self._addr[0], port, matchkey))
self.ret_value(TrackerCode.SUCCESS) self.ret_value(TrackerCode.SUCCESS)
elif code == TrackerCode.REQUEST: elif code == TrackerCode.REQUEST:
key = args[1] key = args[1]
user = args[2] user = args[2]
priority = args[3] priority = args[3]
def _cb(value): def _cb(value):
self.ret_value([TrackerCode.SUCCESS, value]) # if the connection is already closed
if not self._sock:
return False
try:
self.ret_value([TrackerCode.SUCCESS, value])
except (socket.sock_error, IOError):
return False
return True return True
self._tracker.request(key, user, priority, _cb) self._tracker.request(key, user, priority, _cb)
elif code == TrackerCode.PING: elif code == TrackerCode.PING:
self.ret_value(TrackerCode.SUCCESS) self.ret_value(TrackerCode.SUCCESS)
elif code == TrackerCode.GET_PENDING_MATCHKEYS:
self.ret_value(list(self.pending_matchkeys))
elif code == TrackerCode.STOP: elif code == TrackerCode.STOP:
# safe stop tracker # safe stop tracker
if self._tracker._stop_key == args[1]: if self._tracker._stop_key == args[1]:
......
...@@ -23,6 +23,11 @@ def check_server_drop(): ...@@ -23,6 +23,11 @@ def check_server_drop():
tproxy = proxy.Proxy("localhost", 8881, tproxy = proxy.Proxy("localhost", 8881,
tracker_addr=("localhost", tserver.port)) tracker_addr=("localhost", tserver.port))
tclient = rpc.connect_tracker("localhost", tserver.port) tclient = rpc.connect_tracker("localhost", tserver.port)
server0 = rpc.Server(
"localhost", port=9099,
tracker_addr=("localhost", tserver.port),
key="abc")
server1 = rpc.Server( server1 = rpc.Server(
"localhost", port=9099, "localhost", port=9099,
tracker_addr=("localhost", tserver.port), tracker_addr=("localhost", tserver.port),
...@@ -34,6 +39,10 @@ def check_server_drop(): ...@@ -34,6 +39,10 @@ def check_server_drop():
"localhost", tproxy.port, is_proxy=True, "localhost", tproxy.port, is_proxy=True,
key="xyz1") key="xyz1")
# Fault tolerence to un-handled requested value
_put(tclient, [TrackerCode.REQUEST, "abc", "", 1])
_put(tclient, [TrackerCode.REQUEST, "xyz1", "", 1])
# Fault tolerence to stale worker value # Fault tolerence to stale worker value
_put(tclient, [TrackerCode.PUT, "xyz", (server1.port, "abc")]) _put(tclient, [TrackerCode.PUT, "xyz", (server1.port, "abc")])
_put(tclient, [TrackerCode.PUT, "xyz", (server1.port, "abcxxx")]) _put(tclient, [TrackerCode.PUT, "xyz", (server1.port, "abcxxx")])
...@@ -58,14 +67,21 @@ def check_server_drop(): ...@@ -58,14 +67,21 @@ def check_server_drop():
assert f1(10) == 11 assert f1(10) == 11
f1 = remote2.get_function("rpc.test2.addone") f1 = remote2.get_function("rpc.test2.addone")
assert f1(10) == 11 assert f1(10) == 11
except tvm.TVMError as e: except tvm.TVMError as e:
pass pass
remote3 = tclient.request("abc")
f1 = remote3.get_function("rpc.test2.addone")
remote3 = tclient.request("xyz1")
f1 = remote3.get_function("rpc.test2.addone")
assert f1(10) == 11
check_timeout(0.01, 0.1) check_timeout(0.01, 0.1)
check_timeout(2, 0) check_timeout(2, 0)
tserver.terminate() tserver.terminate()
server2.terminate() server0.terminate()
server1.terminate() server1.terminate()
server2.terminate()
server3.terminate() server3.terminate()
tproxy.terminate() tproxy.terminate()
except ImportError: except ImportError:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment