Commit d0eb2d3d by Lianmin Zheng Committed by Tianqi Chen

Add silent mode to rpc server and rpc tracker (#1268)

parent 558cf098
......@@ -120,7 +120,7 @@ def random_key(prefix, cmap=None):
return prefix + str(random.random())
def connect_with_retry(addr, timeout=60, retry_period=5):
def connect_with_retry(addr, timeout=60, retry_period=5, silent=False):
"""Connect to a TPC address with retry
This function is only reliable to short period of server restart.
......@@ -135,6 +135,9 @@ def connect_with_retry(addr, timeout=60, retry_period=5):
retry_period : float
Number of seconds before we retry again.
silent: bool
whether run in silent mode
"""
tstart = time.time()
while True:
......@@ -149,8 +152,9 @@ def connect_with_retry(addr, timeout=60, retry_period=5):
if period > timeout:
raise RuntimeError(
"Failed to connect to server %s" % str(addr))
logging.info("Cannot connect to tracker%s, retry in %g secs...",
str(addr), retry_period)
if not silent:
logging.info("Cannot connect to tracker%s, retry in %g secs...",
str(addr), retry_period)
time.sleep(retry_period)
......
......@@ -536,7 +536,7 @@ def websocket_proxy_server(url, key=""):
def _connect(key):
conn = yield websocket.websocket_connect(url)
on_message = create_on_message(conn)
temp = _server_env(None)
temp = _server_env(None, None)
# Start connecton
conn.write_message(struct.pack('<i', base.RPC_MAGIC), binary=True)
key = "server:" + key
......
......@@ -19,6 +19,7 @@ import logging
import multiprocessing
import subprocess
import time
import sys
from ..._ffi.function import register_func
from ..._ffi.base import py_str
......@@ -28,9 +29,12 @@ from .. import util
from . import base
from . base import TrackerCode
def _server_env(load_library):
def _server_env(load_library, logger):
"""Server environment function return temp dir"""
temp = util.tempdir()
if logger is None:
logger = logging.getLogger()
# pylint: disable=unused-variable
@register_func("tvm.contrib.rpc.server.workpath")
def get_workpath(path):
......@@ -41,7 +45,7 @@ def _server_env(load_library):
"""Load module from remote side."""
path = temp.relpath(file_name)
m = _load_module(path)
logging.info("load_module %s", path)
logger.info("load_module %s", path)
return m
libs = []
......@@ -49,18 +53,21 @@ def _server_env(load_library):
for file_name in load_library:
file_name = find_lib_path(file_name)[0]
libs.append(ctypes.CDLL(file_name, ctypes.RTLD_GLOBAL))
logging.info("Load additional library %s", file_name)
logger.info("Load additional library %s", file_name)
temp.libs = libs
return temp
def _serve_loop(sock, addr, load_library):
def _serve_loop(sock, addr, load_library, silent):
"""Server loop"""
logger = logging.getLogger("RPCServer")
if silent:
logger.disabled = True
sockfd = sock.fileno()
temp = _server_env(load_library)
temp = _server_env(load_library, logger)
base._ServerLoop(sockfd)
temp.remove()
logging.info("Finish serving %s", addr)
logger.info("Finish serving %s", addr)
def _parse_server_opt(opts):
......@@ -71,8 +78,12 @@ def _parse_server_opt(opts):
ret["timeout"] = float(kv[9:])
return ret
def _listen_loop(sock, port, rpc_key, tracker_addr, load_library, custom_addr):
"""Lisenting loop of the server master."""
def _listen_loop(sock, port, rpc_key, tracker_addr, load_library, custom_addr, silent):
"""Listening loop of the server master."""
logger = logging.getLogger("RPCServer")
if silent:
logger.disabled = True
def _accept_conn(listen_sock, tracker_conn, ping_period=2):
"""Accept connection from the other places.
......@@ -115,7 +126,7 @@ def _listen_loop(sock, port, rpc_key, tracker_addr, load_library, custom_addr):
unmatch_period_count = 0
# regenerate match key if key is acquired but not used for a while
if unmatch_period_count * ping_period > unmatch_timeout + ping_period:
logging.info("RPCServer: no incoming connections, regenerate key ...")
logger.info("no incoming connections, regenerate key ...")
matchkey = base.random_key(rpc_key + ":", old_keyset)
base.sendjson(tracker_conn,
[TrackerCode.PUT, rpc_key, (port, matchkey),
......@@ -136,7 +147,7 @@ def _listen_loop(sock, port, rpc_key, tracker_addr, load_library, custom_addr):
if arr[0] != expect_header:
conn.sendall(struct.pack("<i", base.RPC_CODE_MISMATCH))
conn.close()
logging.info("RPCServer: mismatch key from %s", addr)
logger.info("mismatch key from %s", addr)
continue
else:
conn.sendall(struct.pack("<i", base.RPC_CODE_SUCCESS))
......@@ -150,7 +161,7 @@ def _listen_loop(sock, port, rpc_key, tracker_addr, load_library, custom_addr):
try:
# step 1: setup tracker and report to tracker
if tracker_addr and tracker_conn is None:
tracker_conn = base.connect_with_retry(tracker_addr)
tracker_conn = base.connect_with_retry(tracker_addr, silent=silent)
tracker_conn.sendall(struct.pack("<i", base.RPC_TRACKER_MAGIC))
magic = struct.unpack("<i", base.recvall(tracker_conn, 4))[0]
if magic != base.RPC_TRACKER_MAGIC:
......@@ -169,10 +180,16 @@ def _listen_loop(sock, port, rpc_key, tracker_addr, load_library, custom_addr):
tracker_conn.close()
tracker_conn = None
continue
except RuntimeError as exc:
if silent:
return
else:
raise exc
# step 3: serving
logging.info("RPCServer: connection from %s", addr)
server_proc = multiprocessing.Process(target=_serve_loop, args=(conn, addr, load_library))
logger.info("connection from %s", addr)
server_proc = multiprocessing.Process(target=_serve_loop,
args=(conn, addr, load_library, silent))
server_proc.deamon = True
server_proc.start()
# close from our side.
......@@ -180,11 +197,14 @@ def _listen_loop(sock, port, rpc_key, tracker_addr, load_library, custom_addr):
# wait until server process finish or timeout
server_proc.join(opts.get("timeout", None))
if server_proc.is_alive():
logging.info("RPCServer: Timeout in RPC session, kill..")
logger.info("Timeout in RPC session, kill..")
server_proc.terminate()
def _connect_proxy_loop(addr, key, load_library):
def _connect_proxy_loop(addr, key, load_library, silent):
logger = logging.getLogger("RPCProxy")
if silent:
logger.disabled = True
key = "server:" + key
retry_count = 0
max_retry = 5
......@@ -200,26 +220,26 @@ def _connect_proxy_loop(addr, key, load_library):
if magic == base.RPC_CODE_DUPLICATE:
raise RuntimeError("key: %s has already been used in proxy" % key)
elif magic == base.RPC_CODE_MISMATCH:
logging.info("RPCProxy do not have matching client key %s", key)
logger.info("RPCProxy do not have matching client key %s", key)
elif magic != base.RPC_CODE_SUCCESS:
raise RuntimeError("%s is not RPC Proxy" % str(addr))
keylen = struct.unpack("<i", base.recvall(sock, 4))[0]
remote_key = py_str(base.recvall(sock, keylen))
opts = _parse_server_opt(remote_key.split()[1:])
logging.info("RPCProxy connected to %s", str(addr))
logger.info("connected to %s", str(addr))
process = multiprocessing.Process(
target=_serve_loop, args=(sock, addr, load_library))
target=_serve_loop, args=(sock, addr, load_library, silent))
process.deamon = True
process.start()
sock.close()
process.join(opts.get("timeout", None))
if process.is_alive():
logging.info("RPCProxyServer: Timeout in RPC session, kill..")
logger.info("Timeout in RPC session, kill..")
process.terminate()
retry_count = 0
except (socket.error, IOError) as err:
retry_count += 1
logging.info("Error encountered %s, retry in %g sec", str(err), retry_period)
logger.info("Error encountered %s, retry in %g sec", str(err), retry_period)
if retry_count > max_retry:
raise RuntimeError("Maximum retry error: last error: %s" % str(err))
time.sleep(retry_period)
......@@ -264,6 +284,9 @@ class Server(object):
This is recommended to switch on if we want to do local RPC demonstration
for GPU devices to avoid fork safety issues.
silent: bool, optional
Whether run this server in silent mode.
key : str, optional
The key used to identify the server in Proxy connection.
......@@ -276,6 +299,7 @@ class Server(object):
port_end=9199,
is_proxy=False,
use_popen=False,
silent=False,
tracker_addr=None,
key="",
load_library=None,
......@@ -290,8 +314,12 @@ class Server(object):
self.libs = []
self.custom_addr = custom_addr
self.logger = logging.getLogger("RPCServer")
if silent:
self.logger.disabled = True
if use_popen:
cmd = ["python",
cmd = [sys.executable,
"-m", "tvm.exec.rpc_server",
"--host=%s" % host,
"--port=%s" % port]
......@@ -303,11 +331,14 @@ class Server(object):
cmd += ["--load-library", load_library]
if custom_addr:
cmd += ["--custom-addr", custom_addr]
if silent:
cmd += ["--silent"]
self.proc = multiprocessing.Process(
target=subprocess.check_call, args=(cmd,))
self.proc.deamon = True
self.proc.start()
time.sleep(1)
time.sleep(0.5)
elif not is_proxy:
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.port = None
......@@ -321,19 +352,20 @@ class Server(object):
continue
else:
raise sock_err
if not self.port:
raise ValueError("cannot bind to any port in [%d, %d)" % (port, port_end))
logging.info("RPCServer: bind to %s:%d", host, self.port)
if not self.port:
raise ValueError("cannot bind to any port in [%d, %d)" % (port, port_end))
self.logger.info("bind to %s:%d", host, self.port)
sock.listen(1)
self.sock = sock
self.proc = multiprocessing.Process(
target=_listen_loop, args=(
self.sock, self.port, key, tracker_addr, load_library, self.custom_addr))
self.sock, self.port, key, tracker_addr, load_library,
self.custom_addr, silent))
self.proc.deamon = True
self.proc.start()
else:
self.proc = multiprocessing.Process(
target=_connect_proxy_loop, args=((host, port), key, load_library))
target=_connect_proxy_loop, args=((host, port), key, load_library, silent))
self.proc.deamon = True
self.proc.start()
......
......@@ -309,7 +309,6 @@ class TrackerServerHandler(object):
def _tracker_server(listen_sock, stop_key):
handler = TrackerServerHandler(listen_sock, stop_key)
handler.run()
logging.info("Tracker Stop signal received, terminating...")
class Tracker(object):
......@@ -327,11 +326,19 @@ class Tracker(object):
port_end : int, optional
The end TCP port to search
silent: bool, optional
Whether run in silent mode
"""
def __init__(self,
host,
port=9190,
port_end=9199):
port_end=9199,
silent=False):
self.logger = logging.getLogger("RPCTracker")
if silent:
self.logger.disabled = True
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.port = None
self.stop_key = base.random_key("tracker")
......@@ -347,7 +354,7 @@ class Tracker(object):
raise sock_err
if not self.port:
raise ValueError("cannot bind to any port in [%d, %d)" % (port, port_end))
logging.info("RPCTracker: bind to %s:%d", host, self.port)
self.logger.info("bind to %s:%d", host, self.port)
sock.listen(1)
self.proc = multiprocessing.Process(
target=_tracker_server, args=(sock, self.stop_key))
......@@ -373,7 +380,7 @@ class Tracker(object):
self._stop_tracker()
self.proc.join(1)
if self.proc.is_alive():
logging.info("Terminating Tracker Server...")
self.logger.info("Terminating Tracker Server...")
self.proc.terminate()
self.proc = None
......
......@@ -27,7 +27,8 @@ def main(args):
key=args.key,
tracker_addr=tracker_addr,
load_library=args.load_library,
custom_addr=args.custom_addr)
custom_addr=args.custom_addr,
silent=args.silent)
server.proc.join()
......@@ -51,6 +52,8 @@ if __name__ == "__main__":
and ROCM compilers.")
parser.add_argument('--custom-addr', type=str,
help="Custom IP Address to Report to RPC Tracker")
parser.add_argument('--silent', action='store_true',
help="Whether run in silent mode.")
parser.set_defaults(fork=True)
args = parser.parse_args()
......@@ -62,6 +65,7 @@ if __name__ == "__main__":
)
multiprocessing.set_start_method('spawn')
else:
logging.info("If you are running ROCM/Metal, \
fork with cause compiler internal error. Try to launch with arg ```--no-fork```")
if not args.silent:
logging.info("If you are running ROCM/Metal, fork will cause "
"compiler internal error. Try to launch with arg ```--no-fork```")
main(args)
......@@ -11,7 +11,8 @@ from ..contrib.rpc.tracker import Tracker
def main(args):
"""Main funciton"""
tracker = Tracker(args.host, port=args.port)
tracker = Tracker(args.host, port=args.port, port_end=args.port_end,
silent=args.silent)
tracker.proc.join()
......@@ -21,10 +22,15 @@ if __name__ == "__main__":
help='the hostname of the tracker')
parser.add_argument('--port', type=int, default=9190,
help='The port of the PRC')
parser.add_argument('--port-end', type=int, default=9199,
help='The end search port of the PRC')
parser.add_argument('--no-fork', dest='fork', action='store_false',
help="Use spawn mode to avoid fork. This option \
is able to avoid potential fork problems with Metal, OpenCL \
and ROCM compilers.")
parser.add_argument('--silent', action='store_true',
help="Whether run in silent mode.")
parser.set_defaults(fork=True)
args = parser.parse_args()
logging.basicConfig(level=logging.INFO)
......@@ -35,6 +41,7 @@ if __name__ == "__main__":
)
multiprocessing.set_start_method('spawn')
else:
logging.info("If you are running ROCM/Metal, \
fork with cause compiler internal error. Try to launch with arg ```--no-fork```")
if not args.silent:
logging.info("If you are running ROCM/Metal, fork will cause "
"compiler internal error. Try to launch with arg ```--no-fork```")
main(args)
......@@ -20,7 +20,6 @@ TVM_REGISTER_GLOBAL("tvm.contrib.rpc.server.upload").
set_body([](TVMArgs args, TVMRetValue *rv) {
std::string file_name = RPCGetPath(args[0]);
std::string data = args[1];
LOG(INFO) << "Upload " << file_name << "... nbytes=" << data.length();
SaveBinaryToFile(file_name, data);
});
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment