Commit 07cc21f1 by Lianmin Zheng Committed by Tianqi Chen

add exclusive mode for rpc server (#941)

parent 589a2651
...@@ -74,10 +74,16 @@ def _recvall(sock, nbytes): ...@@ -74,10 +74,16 @@ def _recvall(sock, nbytes):
return b"".join(res) return b"".join(res)
def _listen_loop(sock): def _listen_loop(sock, exclusive):
"""Lisenting loop""" """Lisenting loop"""
last_proc = None
while True: while True:
conn, addr = sock.accept() conn, addr = sock.accept()
if last_proc and last_proc.is_alive() and exclusive:
logging.info("Kill last call")
last_proc.terminate()
logging.info("RPCServer: connection from %s", addr) logging.info("RPCServer: connection from %s", addr)
magic = struct.unpack("@i", _recvall(conn, 4))[0] magic = struct.unpack("@i", _recvall(conn, 4))[0]
if magic != RPC_MAGIC: if magic != RPC_MAGIC:
...@@ -90,9 +96,11 @@ def _listen_loop(sock): ...@@ -90,9 +96,11 @@ def _listen_loop(sock):
else: else:
conn.sendall(struct.pack("@i", RPC_MAGIC)) conn.sendall(struct.pack("@i", RPC_MAGIC))
logging.info("Connection from %s", addr) logging.info("Connection from %s", addr)
process = multiprocessing.Process(target=_serve_loop, args=(conn, addr)) process = multiprocessing.Process(target=_serve_loop, args=(conn, addr))
process.deamon = True process.deamon = True
process.start() process.start()
last_proc = process
# close from our side. # close from our side.
conn.close() conn.close()
...@@ -158,6 +166,11 @@ class Server(object): ...@@ -158,6 +166,11 @@ class Server(object):
This is recommended to switch on if we want to do local RPC demonstration This is recommended to switch on if we want to do local RPC demonstration
for GPU devices to avoid fork safety issues. for GPU devices to avoid fork safety issues.
exclusive : bool, optional
If this is enabled, the server will kill old connection
when new connection comes. This can make sure the current call
monopolize the hardware resource.
key : str, optional key : str, optional
The key used to identify the server in Proxy connection. The key used to identify the server in Proxy connection.
""" """
...@@ -167,6 +180,7 @@ class Server(object): ...@@ -167,6 +180,7 @@ class Server(object):
port_end=9199, port_end=9199,
is_proxy=False, is_proxy=False,
use_popen=False, use_popen=False,
exclusive=False,
key=""): key=""):
self.host = host self.host = host
self.port = port self.port = port
...@@ -201,7 +215,7 @@ class Server(object): ...@@ -201,7 +215,7 @@ class Server(object):
sock.listen(1) sock.listen(1)
self.sock = sock self.sock = sock
self.proc = multiprocessing.Process( self.proc = multiprocessing.Process(
target=_listen_loop, args=(self.sock,)) target=_listen_loop, args=(self.sock, exclusive))
self.proc.deamon = True self.proc.deamon = True
self.proc.start() self.proc.start()
else: else:
...@@ -210,8 +224,6 @@ class Server(object): ...@@ -210,8 +224,6 @@ class Server(object):
self.proc.deamon = True self.proc.deamon = True
self.proc.start() self.proc.start()
def terminate(self): def terminate(self):
"""Terminate the server process""" """Terminate the server process"""
if self.proc: if self.proc:
...@@ -222,7 +234,6 @@ class Server(object): ...@@ -222,7 +234,6 @@ class Server(object):
self.terminate() self.terminate()
class RPCSession(object): class RPCSession(object):
"""RPC Client session module """RPC Client session module
......
...@@ -21,6 +21,9 @@ def main(): ...@@ -21,6 +21,9 @@ def main():
help="Whether to load executor runtime") help="Whether to load executor runtime")
parser.add_argument('--load-library', type=str, default="", parser.add_argument('--load-library', type=str, default="",
help="Additional library to load") help="Additional library to load")
parser.add_argument('--exclusive', action='store_true',
help="If this is enabled, the server will kill old connection"
"when new connection comes")
args = parser.parse_args() args = parser.parse_args()
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
...@@ -35,7 +38,7 @@ def main(): ...@@ -35,7 +38,7 @@ def main():
libs.append(ctypes.CDLL(file_name, ctypes.RTLD_GLOBAL)) libs.append(ctypes.CDLL(file_name, ctypes.RTLD_GLOBAL))
logging.info("Load additional library %s", file_name) logging.info("Load additional library %s", file_name)
server = rpc.Server(args.host, args.port, args.port_end) server = rpc.Server(args.host, args.port, args.port_end, exclusive=args.exclusive)
server.libs += libs server.libs += libs
server.proc.join() server.proc.join()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment