Commit d0eb2d3d by Lianmin Zheng Committed by Tianqi Chen

Add silent mode to rpc server and rpc tracker (#1268)

parent 558cf098
...@@ -120,7 +120,7 @@ def random_key(prefix, cmap=None): ...@@ -120,7 +120,7 @@ def random_key(prefix, cmap=None):
return prefix + str(random.random()) return prefix + str(random.random())
def connect_with_retry(addr, timeout=60, retry_period=5): def connect_with_retry(addr, timeout=60, retry_period=5, silent=False):
"""Connect to a TCP address with retry """Connect to a TCP address with retry
This function is only reliable for a short period of server restart. This function is only reliable for a short period of server restart.
...@@ -135,6 +135,9 @@ def connect_with_retry(addr, timeout=60, retry_period=5): ...@@ -135,6 +135,9 @@ def connect_with_retry(addr, timeout=60, retry_period=5):
retry_period : float retry_period : float
Number of seconds before we retry again. Number of seconds before we retry again.
silent : bool
Whether to run in silent mode.
""" """
tstart = time.time() tstart = time.time()
while True: while True:
...@@ -149,6 +152,7 @@ def connect_with_retry(addr, timeout=60, retry_period=5): ...@@ -149,6 +152,7 @@ def connect_with_retry(addr, timeout=60, retry_period=5):
if period > timeout: if period > timeout:
raise RuntimeError( raise RuntimeError(
"Failed to connect to server %s" % str(addr)) "Failed to connect to server %s" % str(addr))
if not silent:
logging.info("Cannot connect to tracker%s, retry in %g secs...", logging.info("Cannot connect to tracker%s, retry in %g secs...",
str(addr), retry_period) str(addr), retry_period)
time.sleep(retry_period) time.sleep(retry_period)
......
...@@ -536,7 +536,7 @@ def websocket_proxy_server(url, key=""): ...@@ -536,7 +536,7 @@ def websocket_proxy_server(url, key=""):
def _connect(key): def _connect(key):
conn = yield websocket.websocket_connect(url) conn = yield websocket.websocket_connect(url)
on_message = create_on_message(conn) on_message = create_on_message(conn)
temp = _server_env(None) temp = _server_env(None, None)
# Start connection # Start connection
conn.write_message(struct.pack('<i', base.RPC_MAGIC), binary=True) conn.write_message(struct.pack('<i', base.RPC_MAGIC), binary=True)
key = "server:" + key key = "server:" + key
......
...@@ -309,7 +309,6 @@ class TrackerServerHandler(object): ...@@ -309,7 +309,6 @@ class TrackerServerHandler(object):
def _tracker_server(listen_sock, stop_key): def _tracker_server(listen_sock, stop_key):
handler = TrackerServerHandler(listen_sock, stop_key) handler = TrackerServerHandler(listen_sock, stop_key)
handler.run() handler.run()
logging.info("Tracker Stop signal received, terminating...")
class Tracker(object): class Tracker(object):
...@@ -327,11 +326,19 @@ class Tracker(object): ...@@ -327,11 +326,19 @@ class Tracker(object):
port_end : int, optional port_end : int, optional
The end TCP port to search The end TCP port to search
silent : bool, optional
Whether to run in silent mode.
""" """
def __init__(self, def __init__(self,
host, host,
port=9190, port=9190,
port_end=9199): port_end=9199,
silent=False):
self.logger = logging.getLogger("RPCTracker")
if silent:
self.logger.disabled = True
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.port = None self.port = None
self.stop_key = base.random_key("tracker") self.stop_key = base.random_key("tracker")
...@@ -347,7 +354,7 @@ class Tracker(object): ...@@ -347,7 +354,7 @@ class Tracker(object):
raise sock_err raise sock_err
if not self.port: if not self.port:
raise ValueError("cannot bind to any port in [%d, %d)" % (port, port_end)) raise ValueError("cannot bind to any port in [%d, %d)" % (port, port_end))
logging.info("RPCTracker: bind to %s:%d", host, self.port) self.logger.info("bind to %s:%d", host, self.port)
sock.listen(1) sock.listen(1)
self.proc = multiprocessing.Process( self.proc = multiprocessing.Process(
target=_tracker_server, args=(sock, self.stop_key)) target=_tracker_server, args=(sock, self.stop_key))
...@@ -373,7 +380,7 @@ class Tracker(object): ...@@ -373,7 +380,7 @@ class Tracker(object):
self._stop_tracker() self._stop_tracker()
self.proc.join(1) self.proc.join(1)
if self.proc.is_alive(): if self.proc.is_alive():
logging.info("Terminating Tracker Server...") self.logger.info("Terminating Tracker Server...")
self.proc.terminate() self.proc.terminate()
self.proc = None self.proc = None
......
...@@ -27,7 +27,8 @@ def main(args): ...@@ -27,7 +27,8 @@ def main(args):
key=args.key, key=args.key,
tracker_addr=tracker_addr, tracker_addr=tracker_addr,
load_library=args.load_library, load_library=args.load_library,
custom_addr=args.custom_addr) custom_addr=args.custom_addr,
silent=args.silent)
server.proc.join() server.proc.join()
...@@ -51,6 +52,8 @@ if __name__ == "__main__": ...@@ -51,6 +52,8 @@ if __name__ == "__main__":
and ROCM compilers.") and ROCM compilers.")
parser.add_argument('--custom-addr', type=str, parser.add_argument('--custom-addr', type=str,
help="Custom IP Address to Report to RPC Tracker") help="Custom IP Address to Report to RPC Tracker")
parser.add_argument('--silent', action='store_true',
help="Whether run in silent mode.")
parser.set_defaults(fork=True) parser.set_defaults(fork=True)
args = parser.parse_args() args = parser.parse_args()
...@@ -62,6 +65,7 @@ if __name__ == "__main__": ...@@ -62,6 +65,7 @@ if __name__ == "__main__":
) )
multiprocessing.set_start_method('spawn') multiprocessing.set_start_method('spawn')
else: else:
logging.info("If you are running ROCM/Metal, \ if not args.silent:
fork with cause compiler internal error. Try to launch with arg ```--no-fork```") logging.info("If you are running ROCM/Metal, fork will cause "
"compiler internal error. Try to launch with arg ```--no-fork```")
main(args) main(args)
...@@ -11,7 +11,8 @@ from ..contrib.rpc.tracker import Tracker ...@@ -11,7 +11,8 @@ from ..contrib.rpc.tracker import Tracker
def main(args): def main(args):
"""Main function""" """Main function"""
tracker = Tracker(args.host, port=args.port) tracker = Tracker(args.host, port=args.port, port_end=args.port_end,
silent=args.silent)
tracker.proc.join() tracker.proc.join()
...@@ -21,10 +22,15 @@ if __name__ == "__main__": ...@@ -21,10 +22,15 @@ if __name__ == "__main__":
help='the hostname of the tracker') help='the hostname of the tracker')
parser.add_argument('--port', type=int, default=9190, parser.add_argument('--port', type=int, default=9190,
help='The port of the PRC') help='The port of the PRC')
parser.add_argument('--port-end', type=int, default=9199,
help='The end search port of the PRC')
parser.add_argument('--no-fork', dest='fork', action='store_false', parser.add_argument('--no-fork', dest='fork', action='store_false',
help="Use spawn mode to avoid fork. This option \ help="Use spawn mode to avoid fork. This option \
is able to avoid potential fork problems with Metal, OpenCL \ is able to avoid potential fork problems with Metal, OpenCL \
and ROCM compilers.") and ROCM compilers.")
parser.add_argument('--silent', action='store_true',
help="Whether run in silent mode.")
parser.set_defaults(fork=True) parser.set_defaults(fork=True)
args = parser.parse_args() args = parser.parse_args()
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
...@@ -35,6 +41,7 @@ if __name__ == "__main__": ...@@ -35,6 +41,7 @@ if __name__ == "__main__":
) )
multiprocessing.set_start_method('spawn') multiprocessing.set_start_method('spawn')
else: else:
logging.info("If you are running ROCM/Metal, \ if not args.silent:
fork with cause compiler internal error. Try to launch with arg ```--no-fork```") logging.info("If you are running ROCM/Metal, fork will cause "
"compiler internal error. Try to launch with arg ```--no-fork```")
main(args) main(args)
...@@ -20,7 +20,6 @@ TVM_REGISTER_GLOBAL("tvm.contrib.rpc.server.upload"). ...@@ -20,7 +20,6 @@ TVM_REGISTER_GLOBAL("tvm.contrib.rpc.server.upload").
set_body([](TVMArgs args, TVMRetValue *rv) { set_body([](TVMArgs args, TVMRetValue *rv) {
std::string file_name = RPCGetPath(args[0]); std::string file_name = RPCGetPath(args[0]);
std::string data = args[1]; std::string data = args[1];
LOG(INFO) << "Upload " << file_name << "... nbytes=" << data.length();
SaveBinaryToFile(file_name, data); SaveBinaryToFile(file_name, data);
}); });
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment