Commit 5ba24773 by eqy Committed by Tianqi Chen

support custom IP address from rpc server to tracker (PUT) (#1243)

parent a1974012
......@@ -71,8 +71,7 @@ def _parse_server_opt(opts):
ret["timeout"] = float(kv[9:])
return ret
def _listen_loop(sock, port, rpc_key, tracker_addr, load_library):
def _listen_loop(sock, port, rpc_key, tracker_addr, load_library, custom_addr):
"""Lisenting loop of the server master."""
def _accept_conn(listen_sock, tracker_conn, ping_period=2):
"""Accept connection from the other places.
......@@ -93,7 +92,7 @@ def _listen_loop(sock, port, rpc_key, tracker_addr, load_library):
if tracker_conn:
matchkey = base.random_key(rpc_key + ":")
base.sendjson(tracker_conn,
[TrackerCode.PUT, rpc_key, (port, matchkey)])
[TrackerCode.PUT, rpc_key, (port, matchkey), custom_addr])
assert base.recvjson(tracker_conn) == TrackerCode.SUCCESS
else:
matchkey = rpc_key
......@@ -109,17 +108,18 @@ def _listen_loop(sock, port, rpc_key, tracker_addr, load_library):
pending_keys = base.recvjson(tracker_conn)
old_keyset.add(matchkey)
# if match key not in pending key set
# it means the key is aqquired by a client but not used.
# it means the key is acquired by a client but not used.
if matchkey not in pending_keys:
unmatch_period_count += 1
else:
unmatch_period_count = 0
# regenerate match key if key is aqquired but not used for a while
# regenerate match key if key is acquired but not used for a while
if unmatch_period_count * ping_period > unmatch_timeout + ping_period:
logging.info("RPCServer: no incoming connections, regenerate key ...")
matchkey = base.random_key(rpc_key + ":", old_keyset)
base.sendjson(tracker_conn,
[TrackerCode.PUT, rpc_key, (port, matchkey)])
[TrackerCode.PUT, rpc_key, (port, matchkey),
custom_addr])
assert base.recvjson(tracker_conn) == TrackerCode.SUCCESS
unmatch_period_count = 0
continue
......@@ -278,7 +278,8 @@ class Server(object):
use_popen=False,
tracker_addr=None,
key="",
load_library=None):
load_library=None,
custom_addr=None):
try:
if base._ServerLoop is None:
raise RuntimeError("Please compile with USE_RPC=1")
......@@ -287,6 +288,7 @@ class Server(object):
self.host = host
self.port = port
self.libs = []
self.custom_addr = custom_addr
if use_popen:
cmd = ["python",
......@@ -298,7 +300,9 @@ class Server(object):
cmd += ["--tracker=%s:%d" % tracker_addr,
"--key=%s" % key]
if load_library:
cmd += ["--load-libary", load_library]
cmd += ["--load-library", load_library]
if custom_addr:
cmd += ["--custom-addr", custom_addr]
self.proc = multiprocessing.Process(
target=subprocess.check_call, args=(cmd,))
self.proc.deamon = True
......@@ -324,7 +328,7 @@ class Server(object):
self.sock = sock
self.proc = multiprocessing.Process(
target=_listen_loop, args=(
self.sock, self.port, key, tracker_addr, load_library))
self.sock, self.port, key, tracker_addr, load_library, self.custom_addr))
self.proc.deamon = True
self.proc.start()
else:
......
......@@ -194,7 +194,11 @@ class TCPEventHandler(tornado_util.TCPHandler):
key = args[1]
port, matchkey = args[2]
self.pending_matchkeys.add(matchkey)
self._tracker.put(key, (self, self._addr[0], port, matchkey))
# got custom address (from rpc server)
if args[3] is not None:
self._tracker.put(key, (self, args[3], port, matchkey))
else:
self._tracker.put(key, (self, self._addr[0], port, matchkey))
self.ret_value(TrackerCode.SUCCESS)
elif code == TrackerCode.REQUEST:
key = args[1]
......
......@@ -8,9 +8,8 @@ import sys
import logging
from ..contrib import rpc
def main(args):
"""Main funciton"""
"""Main function"""
if args.tracker:
url, port = args.tracker.split(":")
......@@ -27,7 +26,8 @@ def main(args):
args.port_end,
key=args.key,
tracker_addr=tracker_addr,
load_library=args.load_library)
load_library=args.load_library,
custom_addr=args.custom_addr)
server.proc.join()
......@@ -49,6 +49,9 @@ if __name__ == "__main__":
help="Use spawn mode to avoid fork. This option \
is able to avoid potential fork problems with Metal, OpenCL \
and ROCM compilers.")
parser.add_argument('--custom-addr', type=str,
help="Custom IP Address to Report to RPC Tracker")
parser.set_defaults(fork=True)
args = parser.parse_args()
logging.basicConfig(level=logging.INFO)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment