Commit 3d67ea17 by Tianqi Chen Committed by GitHub

[RPC] support tracker in proxy (#1082)

parent 79fc6672
...@@ -94,9 +94,29 @@ def recvjson(sock): ...@@ -94,9 +94,29 @@ def recvjson(sock):
return data return data
def random_key(): def random_key(prefix, cmap=None):
"""Generate a random key n""" """Generate a random key
return str(random.random())
Parameters
----------
prefix : str
The string prefix
cmap : dict
Conflict map
Returns
-------
key : str
The generated random key
"""
if cmap:
while True:
key = prefix + str(random.random())
if key not in cmap:
return key
else:
return prefix + str(random.random())
def connect_with_retry(addr, timeout=60, retry_period=5): def connect_with_retry(addr, timeout=60, retry_period=5):
......
...@@ -91,7 +91,7 @@ def _listen_loop(sock, port, rpc_key, tracker_addr): ...@@ -91,7 +91,7 @@ def _listen_loop(sock, port, rpc_key, tracker_addr):
""" """
# Report resource to tracker # Report resource to tracker
if tracker_conn: if tracker_conn:
matchkey = ":" + base.random_key() matchkey = base.random_key(":")
base.sendjson(tracker_conn, base.sendjson(tracker_conn,
[TrackerCode.PUT, rpc_key, (port, matchkey)]) [TrackerCode.PUT, rpc_key, (port, matchkey)])
assert base.recvjson(tracker_conn) == TrackerCode.SUCCESS assert base.recvjson(tracker_conn) == TrackerCode.SUCCESS
...@@ -130,6 +130,7 @@ def _listen_loop(sock, port, rpc_key, tracker_addr): ...@@ -130,6 +130,7 @@ def _listen_loop(sock, port, rpc_key, tracker_addr):
# Server logic # Server logic
tracker_conn = None tracker_conn = None
while True: while True:
try:
# step 1: setup tracker and report to tracker # step 1: setup tracker and report to tracker
if tracker_addr and tracker_conn is None: if tracker_addr and tracker_conn is None:
tracker_conn = base.connect_with_retry(tracker_addr) tracker_conn = base.connect_with_retry(tracker_addr)
...@@ -142,7 +143,7 @@ def _listen_loop(sock, port, rpc_key, tracker_addr): ...@@ -142,7 +143,7 @@ def _listen_loop(sock, port, rpc_key, tracker_addr):
base.sendjson(tracker_conn, base.sendjson(tracker_conn,
[TrackerCode.UPDATE_INFO, cinfo]) [TrackerCode.UPDATE_INFO, cinfo])
assert base.recvjson(tracker_conn) == TrackerCode.SUCCESS assert base.recvjson(tracker_conn) == TrackerCode.SUCCESS
try:
# step 2: wait for in-coming connections # step 2: wait for in-coming connections
conn, addr, opts = _accept_conn(sock, tracker_conn) conn, addr, opts = _accept_conn(sock, tracker_conn)
except (socket.error, IOError): except (socket.error, IOError):
...@@ -161,13 +162,17 @@ def _listen_loop(sock, port, rpc_key, tracker_addr): ...@@ -161,13 +162,17 @@ def _listen_loop(sock, port, rpc_key, tracker_addr):
# wait until server process finish or timeout # wait until server process finish or timeout
server_proc.join(opts.get("timeout", None)) server_proc.join(opts.get("timeout", None))
if server_proc.is_alive(): if server_proc.is_alive():
logging.info("Timeout in RPC session, kill..") logging.info("RPCServer: Timeout in RPC session, kill..")
server_proc.terminate() server_proc.terminate()
def _connect_proxy_loop(addr, key): def _connect_proxy_loop(addr, key):
key = "server:" + key key = "server:" + key
retry_count = 0
max_retry = 5
retry_period = 5
while True: while True:
try:
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.connect(addr) sock.connect(addr)
sock.sendall(struct.pack("@i", base.RPC_MAGIC)) sock.sendall(struct.pack("@i", base.RPC_MAGIC))
...@@ -183,7 +188,6 @@ def _connect_proxy_loop(addr, key): ...@@ -183,7 +188,6 @@ def _connect_proxy_loop(addr, key):
keylen = struct.unpack("@i", base.recvall(sock, 4))[0] keylen = struct.unpack("@i", base.recvall(sock, 4))[0]
remote_key = py_str(base.recvall(sock, keylen)) remote_key = py_str(base.recvall(sock, keylen))
opts = _parse_server_opt(remote_key.split()[1:]) opts = _parse_server_opt(remote_key.split()[1:])
logging.info("RPCProxy connected to %s", str(addr)) logging.info("RPCProxy connected to %s", str(addr))
process = multiprocessing.Process( process = multiprocessing.Process(
target=_serve_loop, args=(sock, addr)) target=_serve_loop, args=(sock, addr))
...@@ -192,9 +196,15 @@ def _connect_proxy_loop(addr, key): ...@@ -192,9 +196,15 @@ def _connect_proxy_loop(addr, key):
sock.close() sock.close()
process.join(opts.get("timeout", None)) process.join(opts.get("timeout", None))
if process.is_alive(): if process.is_alive():
logging.info("Timeout in RPC session, kill..") logging.info("RPCProxyServer: Timeout in RPC session, kill..")
process.terminate() process.terminate()
retry_count = 0
except (socket.error, IOError) as err:
retry_count += 1
logging.info("Error encountered %s, retry in %g sec", str(err), retry_period)
if retry_count > max_retry:
raise RuntimeError("Maximum retry error: last error: %s" % str(err))
time.sleep(retry_period)
def _popen(cmd): def _popen(cmd):
proc = subprocess.Popen(cmd, proc = subprocess.Popen(cmd,
......
...@@ -317,7 +317,7 @@ class Tracker(object): ...@@ -317,7 +317,7 @@ class Tracker(object):
port_end=9199): port_end=9199):
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.port = None self.port = None
self.stop_key = base.random_key() self.stop_key = base.random_key("tracker")
for my_port in range(port, port_end): for my_port in range(port, port_end):
try: try:
sock.bind((host, my_port)) sock.bind((host, my_port))
......
...@@ -29,19 +29,35 @@ def main(): ...@@ -29,19 +29,35 @@ def main():
help='the hostname of the server') help='the hostname of the server')
parser.add_argument('--port', type=int, default=9090, parser.add_argument('--port', type=int, default=9090,
help='The port of the PRC') help='The port of the PRC')
parser.add_argument('--web-port', type=int, default=9190, parser.add_argument('--web-port', type=int, default=8888,
help='The port of the http/websocket server') help='The port of the http/websocket server')
parser.add_argument('--example-rpc', type=bool, default=False, parser.add_argument('--example-rpc', type=bool, default=False,
help='Whether to switch on example rpc mode') help='Whether to switch on example rpc mode')
parser.add_argument('--tracker', type=str, default="",
help="Report to RPC tracker")
args = parser.parse_args() args = parser.parse_args()
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
if args.tracker:
url, port = args.tracker.split(":")
port = int(port)
tracker_addr = (url, port)
else:
tracker_addr = None
if args.example_rpc: if args.example_rpc:
index, js_files = find_example_resource() index, js_files = find_example_resource()
prox = Proxy(args.host, port=args.port, prox = Proxy(args.host,
web_port=args.web_port, index_page=index, port=args.port,
resource_files=js_files) web_port=args.web_port,
index_page=index,
resource_files=js_files,
tracker_addr=tracker_addr)
else: else:
prox = Proxy(args.host, port=args.port, web_port=args.web_port) prox = Proxy(args.host,
port=args.port,
web_port=args.web_port,
tracker_addr=tracker_addr)
prox.proc.join() prox.proc.join()
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -8,7 +8,7 @@ from tvm.contrib import rpc ...@@ -8,7 +8,7 @@ from tvm.contrib import rpc
def check_server_drop(): def check_server_drop():
"""test when server drops""" """test when server drops"""
try: try:
from tvm.contrib.rpc import tracker, base from tvm.contrib.rpc import tracker, proxy, base
from tvm.contrib.rpc.base import TrackerCode from tvm.contrib.rpc.base import TrackerCode
@tvm.register_func("rpc.test2.addone") @tvm.register_func("rpc.test2.addone")
...@@ -20,21 +20,24 @@ def check_server_drop(): ...@@ -20,21 +20,24 @@ def check_server_drop():
base.recvjson(tclient._sock) base.recvjson(tclient._sock)
tserver = tracker.Tracker("localhost", 8888) tserver = tracker.Tracker("localhost", 8888)
tproxy = proxy.Proxy("localhost", 8881,
tracker_addr=("localhost", tserver.port))
tclient = rpc.connect_tracker("localhost", tserver.port) tclient = rpc.connect_tracker("localhost", tserver.port)
server1 = rpc.Server( server1 = rpc.Server(
"localhost", port=9099, "localhost", port=9099,
tracker_addr=("localhost", tserver.port), tracker_addr=("localhost", tserver.port),
key="xyz") key="xyz")
server2 = rpc.Server( server2 = rpc.Server(
"localhost", port=9099, "localhost", tproxy.port, is_proxy=True,
tracker_addr=("localhost", tserver.port),
key="xyz") key="xyz")
server3 = rpc.Server(
"localhost", tproxy.port, is_proxy=True,
key="xyz1")
# Fault tolerence to stale worker value # Fault tolerence to stale worker value
_put(tclient, [TrackerCode.PUT, "xyz", (server1.port, "abc")]) _put(tclient, [TrackerCode.PUT, "xyz", (server1.port, "abc")])
_put(tclient, [TrackerCode.PUT, "xyz", (server1.port, "abcxxx")]) _put(tclient, [TrackerCode.PUT, "xyz", (server1.port, "abcxxx")])
_put(tclient, [TrackerCode.PUT, "xyz", (server2.port, "abcxxx11")]) _put(tclient, [TrackerCode.PUT, "xyz", (tproxy.port, "abcxxx11")])
# Fault tolerence server timeout # Fault tolerence server timeout
def check_timeout(timeout, sleeptime): def check_timeout(timeout, sleeptime):
...@@ -60,8 +63,11 @@ def check_server_drop(): ...@@ -60,8 +63,11 @@ def check_server_drop():
check_timeout(0.01, 0.1) check_timeout(0.01, 0.1)
check_timeout(2, 0) check_timeout(2, 0)
tserver.terminate()
server2.terminate()
server1.terminate()
server3.terminate()
tproxy.terminate()
except ImportError: except ImportError:
print("Skip because tornado is not available") print("Skip because tornado is not available")
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment