Commit d0eb2d3d by Lianmin Zheng Committed by Tianqi Chen

Add silent mode to rpc server and rpc tracker (#1268)

parent 558cf098
...@@ -120,7 +120,7 @@ def random_key(prefix, cmap=None): ...@@ -120,7 +120,7 @@ def random_key(prefix, cmap=None):
return prefix + str(random.random()) return prefix + str(random.random())
def connect_with_retry(addr, timeout=60, retry_period=5): def connect_with_retry(addr, timeout=60, retry_period=5, silent=False):
"""Connect to a TCP address with retry """Connect to a TCP address with retry
This function is only reliable for short periods of server restart. This function is only reliable for short periods of server restart.
...@@ -135,6 +135,9 @@ def connect_with_retry(addr, timeout=60, retry_period=5): ...@@ -135,6 +135,9 @@ def connect_with_retry(addr, timeout=60, retry_period=5):
retry_period : float retry_period : float
Number of seconds before we retry again. Number of seconds before we retry again.
silent: bool
Whether to run in silent mode.
""" """
tstart = time.time() tstart = time.time()
while True: while True:
...@@ -149,6 +152,7 @@ def connect_with_retry(addr, timeout=60, retry_period=5): ...@@ -149,6 +152,7 @@ def connect_with_retry(addr, timeout=60, retry_period=5):
if period > timeout: if period > timeout:
raise RuntimeError( raise RuntimeError(
"Failed to connect to server %s" % str(addr)) "Failed to connect to server %s" % str(addr))
if not silent:
logging.info("Cannot connect to tracker%s, retry in %g secs...", logging.info("Cannot connect to tracker%s, retry in %g secs...",
str(addr), retry_period) str(addr), retry_period)
time.sleep(retry_period) time.sleep(retry_period)
......
...@@ -536,7 +536,7 @@ def websocket_proxy_server(url, key=""): ...@@ -536,7 +536,7 @@ def websocket_proxy_server(url, key=""):
def _connect(key): def _connect(key):
conn = yield websocket.websocket_connect(url) conn = yield websocket.websocket_connect(url)
on_message = create_on_message(conn) on_message = create_on_message(conn)
temp = _server_env(None) temp = _server_env(None, None)
# Start connection # Start connection
conn.write_message(struct.pack('<i', base.RPC_MAGIC), binary=True) conn.write_message(struct.pack('<i', base.RPC_MAGIC), binary=True)
key = "server:" + key key = "server:" + key
......
...@@ -19,6 +19,7 @@ import logging ...@@ -19,6 +19,7 @@ import logging
import multiprocessing import multiprocessing
import subprocess import subprocess
import time import time
import sys
from ..._ffi.function import register_func from ..._ffi.function import register_func
from ..._ffi.base import py_str from ..._ffi.base import py_str
...@@ -28,9 +29,12 @@ from .. import util ...@@ -28,9 +29,12 @@ from .. import util
from . import base from . import base
from . base import TrackerCode from . base import TrackerCode
def _server_env(load_library): def _server_env(load_library, logger):
"""Server environment function return temp dir""" """Server environment function return temp dir"""
temp = util.tempdir() temp = util.tempdir()
if logger is None:
logger = logging.getLogger()
# pylint: disable=unused-variable # pylint: disable=unused-variable
@register_func("tvm.contrib.rpc.server.workpath") @register_func("tvm.contrib.rpc.server.workpath")
def get_workpath(path): def get_workpath(path):
...@@ -41,7 +45,7 @@ def _server_env(load_library): ...@@ -41,7 +45,7 @@ def _server_env(load_library):
"""Load module from remote side.""" """Load module from remote side."""
path = temp.relpath(file_name) path = temp.relpath(file_name)
m = _load_module(path) m = _load_module(path)
logging.info("load_module %s", path) logger.info("load_module %s", path)
return m return m
libs = [] libs = []
...@@ -49,18 +53,21 @@ def _server_env(load_library): ...@@ -49,18 +53,21 @@ def _server_env(load_library):
for file_name in load_library: for file_name in load_library:
file_name = find_lib_path(file_name)[0] file_name = find_lib_path(file_name)[0]
libs.append(ctypes.CDLL(file_name, ctypes.RTLD_GLOBAL)) libs.append(ctypes.CDLL(file_name, ctypes.RTLD_GLOBAL))
logging.info("Load additional library %s", file_name) logger.info("Load additional library %s", file_name)
temp.libs = libs temp.libs = libs
return temp return temp
def _serve_loop(sock, addr, load_library): def _serve_loop(sock, addr, load_library, silent):
"""Server loop""" """Server loop"""
logger = logging.getLogger("RPCServer")
if silent:
logger.disabled = True
sockfd = sock.fileno() sockfd = sock.fileno()
temp = _server_env(load_library) temp = _server_env(load_library, logger)
base._ServerLoop(sockfd) base._ServerLoop(sockfd)
temp.remove() temp.remove()
logging.info("Finish serving %s", addr) logger.info("Finish serving %s", addr)
def _parse_server_opt(opts): def _parse_server_opt(opts):
...@@ -71,8 +78,12 @@ def _parse_server_opt(opts): ...@@ -71,8 +78,12 @@ def _parse_server_opt(opts):
ret["timeout"] = float(kv[9:]) ret["timeout"] = float(kv[9:])
return ret return ret
def _listen_loop(sock, port, rpc_key, tracker_addr, load_library, custom_addr): def _listen_loop(sock, port, rpc_key, tracker_addr, load_library, custom_addr, silent):
"""Lisenting loop of the server master.""" """Listening loop of the server master."""
logger = logging.getLogger("RPCServer")
if silent:
logger.disabled = True
def _accept_conn(listen_sock, tracker_conn, ping_period=2): def _accept_conn(listen_sock, tracker_conn, ping_period=2):
"""Accept connection from the other places. """Accept connection from the other places.
...@@ -115,7 +126,7 @@ def _listen_loop(sock, port, rpc_key, tracker_addr, load_library, custom_addr): ...@@ -115,7 +126,7 @@ def _listen_loop(sock, port, rpc_key, tracker_addr, load_library, custom_addr):
unmatch_period_count = 0 unmatch_period_count = 0
# regenerate match key if key is acquired but not used for a while # regenerate match key if key is acquired but not used for a while
if unmatch_period_count * ping_period > unmatch_timeout + ping_period: if unmatch_period_count * ping_period > unmatch_timeout + ping_period:
logging.info("RPCServer: no incoming connections, regenerate key ...") logger.info("no incoming connections, regenerate key ...")
matchkey = base.random_key(rpc_key + ":", old_keyset) matchkey = base.random_key(rpc_key + ":", old_keyset)
base.sendjson(tracker_conn, base.sendjson(tracker_conn,
[TrackerCode.PUT, rpc_key, (port, matchkey), [TrackerCode.PUT, rpc_key, (port, matchkey),
...@@ -136,7 +147,7 @@ def _listen_loop(sock, port, rpc_key, tracker_addr, load_library, custom_addr): ...@@ -136,7 +147,7 @@ def _listen_loop(sock, port, rpc_key, tracker_addr, load_library, custom_addr):
if arr[0] != expect_header: if arr[0] != expect_header:
conn.sendall(struct.pack("<i", base.RPC_CODE_MISMATCH)) conn.sendall(struct.pack("<i", base.RPC_CODE_MISMATCH))
conn.close() conn.close()
logging.info("RPCServer: mismatch key from %s", addr) logger.info("mismatch key from %s", addr)
continue continue
else: else:
conn.sendall(struct.pack("<i", base.RPC_CODE_SUCCESS)) conn.sendall(struct.pack("<i", base.RPC_CODE_SUCCESS))
...@@ -150,7 +161,7 @@ def _listen_loop(sock, port, rpc_key, tracker_addr, load_library, custom_addr): ...@@ -150,7 +161,7 @@ def _listen_loop(sock, port, rpc_key, tracker_addr, load_library, custom_addr):
try: try:
# step 1: setup tracker and report to tracker # step 1: setup tracker and report to tracker
if tracker_addr and tracker_conn is None: if tracker_addr and tracker_conn is None:
tracker_conn = base.connect_with_retry(tracker_addr) tracker_conn = base.connect_with_retry(tracker_addr, silent=silent)
tracker_conn.sendall(struct.pack("<i", base.RPC_TRACKER_MAGIC)) tracker_conn.sendall(struct.pack("<i", base.RPC_TRACKER_MAGIC))
magic = struct.unpack("<i", base.recvall(tracker_conn, 4))[0] magic = struct.unpack("<i", base.recvall(tracker_conn, 4))[0]
if magic != base.RPC_TRACKER_MAGIC: if magic != base.RPC_TRACKER_MAGIC:
...@@ -169,10 +180,16 @@ def _listen_loop(sock, port, rpc_key, tracker_addr, load_library, custom_addr): ...@@ -169,10 +180,16 @@ def _listen_loop(sock, port, rpc_key, tracker_addr, load_library, custom_addr):
tracker_conn.close() tracker_conn.close()
tracker_conn = None tracker_conn = None
continue continue
except RuntimeError as exc:
if silent:
return
else:
raise exc
# step 3: serving # step 3: serving
logging.info("RPCServer: connection from %s", addr) logger.info("connection from %s", addr)
server_proc = multiprocessing.Process(target=_serve_loop, args=(conn, addr, load_library)) server_proc = multiprocessing.Process(target=_serve_loop,
args=(conn, addr, load_library, silent))
server_proc.deamon = True server_proc.deamon = True
server_proc.start() server_proc.start()
# close from our side. # close from our side.
...@@ -180,11 +197,14 @@ def _listen_loop(sock, port, rpc_key, tracker_addr, load_library, custom_addr): ...@@ -180,11 +197,14 @@ def _listen_loop(sock, port, rpc_key, tracker_addr, load_library, custom_addr):
# wait until server process finish or timeout # wait until server process finish or timeout
server_proc.join(opts.get("timeout", None)) server_proc.join(opts.get("timeout", None))
if server_proc.is_alive(): if server_proc.is_alive():
logging.info("RPCServer: Timeout in RPC session, kill..") logger.info("Timeout in RPC session, kill..")
server_proc.terminate() server_proc.terminate()
def _connect_proxy_loop(addr, key, load_library): def _connect_proxy_loop(addr, key, load_library, silent):
logger = logging.getLogger("RPCProxy")
if silent:
logger.disabled = True
key = "server:" + key key = "server:" + key
retry_count = 0 retry_count = 0
max_retry = 5 max_retry = 5
...@@ -200,26 +220,26 @@ def _connect_proxy_loop(addr, key, load_library): ...@@ -200,26 +220,26 @@ def _connect_proxy_loop(addr, key, load_library):
if magic == base.RPC_CODE_DUPLICATE: if magic == base.RPC_CODE_DUPLICATE:
raise RuntimeError("key: %s has already been used in proxy" % key) raise RuntimeError("key: %s has already been used in proxy" % key)
elif magic == base.RPC_CODE_MISMATCH: elif magic == base.RPC_CODE_MISMATCH:
logging.info("RPCProxy do not have matching client key %s", key) logger.info("RPCProxy do not have matching client key %s", key)
elif magic != base.RPC_CODE_SUCCESS: elif magic != base.RPC_CODE_SUCCESS:
raise RuntimeError("%s is not RPC Proxy" % str(addr)) raise RuntimeError("%s is not RPC Proxy" % str(addr))
keylen = struct.unpack("<i", base.recvall(sock, 4))[0] keylen = struct.unpack("<i", base.recvall(sock, 4))[0]
remote_key = py_str(base.recvall(sock, keylen)) remote_key = py_str(base.recvall(sock, keylen))
opts = _parse_server_opt(remote_key.split()[1:]) opts = _parse_server_opt(remote_key.split()[1:])
logging.info("RPCProxy connected to %s", str(addr)) logger.info("connected to %s", str(addr))
process = multiprocessing.Process( process = multiprocessing.Process(
target=_serve_loop, args=(sock, addr, load_library)) target=_serve_loop, args=(sock, addr, load_library, silent))
process.deamon = True process.deamon = True
process.start() process.start()
sock.close() sock.close()
process.join(opts.get("timeout", None)) process.join(opts.get("timeout", None))
if process.is_alive(): if process.is_alive():
logging.info("RPCProxyServer: Timeout in RPC session, kill..") logger.info("Timeout in RPC session, kill..")
process.terminate() process.terminate()
retry_count = 0 retry_count = 0
except (socket.error, IOError) as err: except (socket.error, IOError) as err:
retry_count += 1 retry_count += 1
logging.info("Error encountered %s, retry in %g sec", str(err), retry_period) logger.info("Error encountered %s, retry in %g sec", str(err), retry_period)
if retry_count > max_retry: if retry_count > max_retry:
raise RuntimeError("Maximum retry error: last error: %s" % str(err)) raise RuntimeError("Maximum retry error: last error: %s" % str(err))
time.sleep(retry_period) time.sleep(retry_period)
...@@ -264,6 +284,9 @@ class Server(object): ...@@ -264,6 +284,9 @@ class Server(object):
This is recommended to switch on if we want to do local RPC demonstration This is recommended to switch on if we want to do local RPC demonstration
for GPU devices to avoid fork safety issues. for GPU devices to avoid fork safety issues.
silent: bool, optional
Whether to run this server in silent mode.
key : str, optional key : str, optional
The key used to identify the server in Proxy connection. The key used to identify the server in Proxy connection.
...@@ -276,6 +299,7 @@ class Server(object): ...@@ -276,6 +299,7 @@ class Server(object):
port_end=9199, port_end=9199,
is_proxy=False, is_proxy=False,
use_popen=False, use_popen=False,
silent=False,
tracker_addr=None, tracker_addr=None,
key="", key="",
load_library=None, load_library=None,
...@@ -290,8 +314,12 @@ class Server(object): ...@@ -290,8 +314,12 @@ class Server(object):
self.libs = [] self.libs = []
self.custom_addr = custom_addr self.custom_addr = custom_addr
self.logger = logging.getLogger("RPCServer")
if silent:
self.logger.disabled = True
if use_popen: if use_popen:
cmd = ["python", cmd = [sys.executable,
"-m", "tvm.exec.rpc_server", "-m", "tvm.exec.rpc_server",
"--host=%s" % host, "--host=%s" % host,
"--port=%s" % port] "--port=%s" % port]
...@@ -303,11 +331,14 @@ class Server(object): ...@@ -303,11 +331,14 @@ class Server(object):
cmd += ["--load-library", load_library] cmd += ["--load-library", load_library]
if custom_addr: if custom_addr:
cmd += ["--custom-addr", custom_addr] cmd += ["--custom-addr", custom_addr]
if silent:
cmd += ["--silent"]
self.proc = multiprocessing.Process( self.proc = multiprocessing.Process(
target=subprocess.check_call, args=(cmd,)) target=subprocess.check_call, args=(cmd,))
self.proc.deamon = True self.proc.deamon = True
self.proc.start() self.proc.start()
time.sleep(1) time.sleep(0.5)
elif not is_proxy: elif not is_proxy:
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.port = None self.port = None
...@@ -323,17 +354,18 @@ class Server(object): ...@@ -323,17 +354,18 @@ class Server(object):
raise sock_err raise sock_err
if not self.port: if not self.port:
raise ValueError("cannot bind to any port in [%d, %d)" % (port, port_end)) raise ValueError("cannot bind to any port in [%d, %d)" % (port, port_end))
logging.info("RPCServer: bind to %s:%d", host, self.port) self.logger.info("bind to %s:%d", host, self.port)
sock.listen(1) sock.listen(1)
self.sock = sock self.sock = sock
self.proc = multiprocessing.Process( self.proc = multiprocessing.Process(
target=_listen_loop, args=( target=_listen_loop, args=(
self.sock, self.port, key, tracker_addr, load_library, self.custom_addr)) self.sock, self.port, key, tracker_addr, load_library,
self.custom_addr, silent))
self.proc.deamon = True self.proc.deamon = True
self.proc.start() self.proc.start()
else: else:
self.proc = multiprocessing.Process( self.proc = multiprocessing.Process(
target=_connect_proxy_loop, args=((host, port), key, load_library)) target=_connect_proxy_loop, args=((host, port), key, load_library, silent))
self.proc.deamon = True self.proc.deamon = True
self.proc.start() self.proc.start()
......
...@@ -309,7 +309,6 @@ class TrackerServerHandler(object): ...@@ -309,7 +309,6 @@ class TrackerServerHandler(object):
def _tracker_server(listen_sock, stop_key): def _tracker_server(listen_sock, stop_key):
handler = TrackerServerHandler(listen_sock, stop_key) handler = TrackerServerHandler(listen_sock, stop_key)
handler.run() handler.run()
logging.info("Tracker Stop signal received, terminating...")
class Tracker(object): class Tracker(object):
...@@ -327,11 +326,19 @@ class Tracker(object): ...@@ -327,11 +326,19 @@ class Tracker(object):
port_end : int, optional port_end : int, optional
The end TCP port to search The end TCP port to search
silent: bool, optional
Whether to run in silent mode.
""" """
def __init__(self, def __init__(self,
host, host,
port=9190, port=9190,
port_end=9199): port_end=9199,
silent=False):
self.logger = logging.getLogger("RPCTracker")
if silent:
self.logger.disabled = True
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.port = None self.port = None
self.stop_key = base.random_key("tracker") self.stop_key = base.random_key("tracker")
...@@ -347,7 +354,7 @@ class Tracker(object): ...@@ -347,7 +354,7 @@ class Tracker(object):
raise sock_err raise sock_err
if not self.port: if not self.port:
raise ValueError("cannot bind to any port in [%d, %d)" % (port, port_end)) raise ValueError("cannot bind to any port in [%d, %d)" % (port, port_end))
logging.info("RPCTracker: bind to %s:%d", host, self.port) self.logger.info("bind to %s:%d", host, self.port)
sock.listen(1) sock.listen(1)
self.proc = multiprocessing.Process( self.proc = multiprocessing.Process(
target=_tracker_server, args=(sock, self.stop_key)) target=_tracker_server, args=(sock, self.stop_key))
...@@ -373,7 +380,7 @@ class Tracker(object): ...@@ -373,7 +380,7 @@ class Tracker(object):
self._stop_tracker() self._stop_tracker()
self.proc.join(1) self.proc.join(1)
if self.proc.is_alive(): if self.proc.is_alive():
logging.info("Terminating Tracker Server...") self.logger.info("Terminating Tracker Server...")
self.proc.terminate() self.proc.terminate()
self.proc = None self.proc = None
......
...@@ -27,7 +27,8 @@ def main(args): ...@@ -27,7 +27,8 @@ def main(args):
key=args.key, key=args.key,
tracker_addr=tracker_addr, tracker_addr=tracker_addr,
load_library=args.load_library, load_library=args.load_library,
custom_addr=args.custom_addr) custom_addr=args.custom_addr,
silent=args.silent)
server.proc.join() server.proc.join()
...@@ -51,6 +52,8 @@ if __name__ == "__main__": ...@@ -51,6 +52,8 @@ if __name__ == "__main__":
and ROCM compilers.") and ROCM compilers.")
parser.add_argument('--custom-addr', type=str, parser.add_argument('--custom-addr', type=str,
help="Custom IP Address to Report to RPC Tracker") help="Custom IP Address to Report to RPC Tracker")
parser.add_argument('--silent', action='store_true',
help="Whether run in silent mode.")
parser.set_defaults(fork=True) parser.set_defaults(fork=True)
args = parser.parse_args() args = parser.parse_args()
...@@ -62,6 +65,7 @@ if __name__ == "__main__": ...@@ -62,6 +65,7 @@ if __name__ == "__main__":
) )
multiprocessing.set_start_method('spawn') multiprocessing.set_start_method('spawn')
else: else:
logging.info("If you are running ROCM/Metal, \ if not args.silent:
fork with cause compiler internal error. Try to launch with arg ```--no-fork```") logging.info("If you are running ROCM/Metal, fork will cause "
"compiler internal error. Try to launch with arg ```--no-fork```")
main(args) main(args)
...@@ -11,7 +11,8 @@ from ..contrib.rpc.tracker import Tracker ...@@ -11,7 +11,8 @@ from ..contrib.rpc.tracker import Tracker
def main(args): def main(args):
"""Main function""" """Main function"""
tracker = Tracker(args.host, port=args.port) tracker = Tracker(args.host, port=args.port, port_end=args.port_end,
silent=args.silent)
tracker.proc.join() tracker.proc.join()
...@@ -21,10 +22,15 @@ if __name__ == "__main__": ...@@ -21,10 +22,15 @@ if __name__ == "__main__":
help='the hostname of the tracker') help='the hostname of the tracker')
parser.add_argument('--port', type=int, default=9190, parser.add_argument('--port', type=int, default=9190,
help='The port of the PRC') help='The port of the PRC')
parser.add_argument('--port-end', type=int, default=9199,
help='The end search port of the PRC')
parser.add_argument('--no-fork', dest='fork', action='store_false', parser.add_argument('--no-fork', dest='fork', action='store_false',
help="Use spawn mode to avoid fork. This option \ help="Use spawn mode to avoid fork. This option \
is able to avoid potential fork problems with Metal, OpenCL \ is able to avoid potential fork problems with Metal, OpenCL \
and ROCM compilers.") and ROCM compilers.")
parser.add_argument('--silent', action='store_true',
help="Whether run in silent mode.")
parser.set_defaults(fork=True) parser.set_defaults(fork=True)
args = parser.parse_args() args = parser.parse_args()
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
...@@ -35,6 +41,7 @@ if __name__ == "__main__": ...@@ -35,6 +41,7 @@ if __name__ == "__main__":
) )
multiprocessing.set_start_method('spawn') multiprocessing.set_start_method('spawn')
else: else:
logging.info("If you are running ROCM/Metal, \ if not args.silent:
fork with cause compiler internal error. Try to launch with arg ```--no-fork```") logging.info("If you are running ROCM/Metal, fork will cause "
"compiler internal error. Try to launch with arg ```--no-fork```")
main(args) main(args)
...@@ -20,7 +20,6 @@ TVM_REGISTER_GLOBAL("tvm.contrib.rpc.server.upload"). ...@@ -20,7 +20,6 @@ TVM_REGISTER_GLOBAL("tvm.contrib.rpc.server.upload").
set_body([](TVMArgs args, TVMRetValue *rv) { set_body([](TVMArgs args, TVMRetValue *rv) {
std::string file_name = RPCGetPath(args[0]); std::string file_name = RPCGetPath(args[0]);
std::string data = args[1]; std::string data = args[1];
LOG(INFO) << "Upload " << file_name << "... nbytes=" << data.length();
SaveBinaryToFile(file_name, data); SaveBinaryToFile(file_name, data);
}); });
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment