Commit 79fc6672 by Tianqi Chen Committed by GitHub

[RPC] Tracker status query (#1081)

parent 6bd8dbc7
...@@ -32,7 +32,8 @@ class TrackerCode(object): ...@@ -32,7 +32,8 @@ class TrackerCode(object):
STOP = 2 STOP = 2
PUT = 3 PUT = 3
REQUEST = 4 REQUEST = 4
UPDATE_INFO = 5
SUMMARY = 6
RPC_SESS_MASK = 128 RPC_SESS_MASK = 128
......
...@@ -4,6 +4,7 @@ from __future__ import absolute_import ...@@ -4,6 +4,7 @@ from __future__ import absolute_import
import os import os
import socket import socket
import struct import struct
import time
from . import base from . import base
from ..._ffi.base import TVMError from ..._ffi.base import TVMError
...@@ -150,7 +151,6 @@ class TrackerSession(object): ...@@ -150,7 +151,6 @@ class TrackerSession(object):
def __init__(self, addr): def __init__(self, addr):
self._addr = addr self._addr = addr
self._sock = None self._sock = None
self._max_request_retry = 5
self._connect() self._connect()
def __del__(self): def __del__(self):
...@@ -169,7 +169,38 @@ class TrackerSession(object): ...@@ -169,7 +169,38 @@ class TrackerSession(object):
self._sock.close() self._sock.close()
self._sock = None self._sock = None
def request(self, key, priority=1, session_timeout=0): def summary(self):
"""Get the summary dict of the tracker."""
base.sendjson(self._sock, [base.TrackerCode.SUMMARY])
value = base.recvjson(self._sock)
if value[0] != base.TrackerCode.SUCCESS:
raise RuntimeError("Invalid return value %s" % str(value))
return value[1]
def text_summary(self):
"""Get a text summary of the tracker."""
data = self.summary()
res = ""
res += "Server List\n"
res += "----------------------------\n"
res += "server-address\tkey\n"
res += "----------------------------\n"
for item in data["server_info"]:
addr = item["addr"]
res += addr[0] + ":" + str(addr[1])+ "\t"
res += item["key"] + "\n"
res += "----------------------------\n"
res += "\n"
res += "Queue Status\n"
res += "----------------------------\n"
res += "key\tfree\tpending\n"
res += "----------------------------\n"
for k, v in data["queue_info"].items():
res += "%s\t%d\t%g\n" % (k, v["free"], v["pending"])
res += "----------------------------\n"
return res
def request(self, key, priority=1, session_timeout=0, max_retry=5):
"""Request a new connection from the tracker. """Request a new connection from the tracker.
Parameters Parameters
...@@ -184,8 +215,12 @@ class TrackerSession(object): ...@@ -184,8 +215,12 @@ class TrackerSession(object):
The duration of the session, allows server to kill The duration of the session, allows server to kill
the connection when duration is longer than this value. the connection when duration is longer than this value.
When duration is zero, it means the request must always be kept alive. When duration is zero, it means the request must always be kept alive.
max_retry : int, optional
Maximum number of times to retry before give up.
""" """
for _ in range(self._max_request_retry): last_err = None
for _ in range(max_retry):
try: try:
if self._sock is None: if self._sock is None:
self._connect() self._connect()
...@@ -196,10 +231,63 @@ class TrackerSession(object): ...@@ -196,10 +231,63 @@ class TrackerSession(object):
raise RuntimeError("Invalid return value %s" % str(value)) raise RuntimeError("Invalid return value %s" % str(value))
url, port, matchkey = value[1] url, port, matchkey = value[1]
return connect(url, port, key + matchkey, session_timeout) return connect(url, port, key + matchkey, session_timeout)
except socket.error: except socket.error as err:
self.close() self.close()
except TVMError: last_err = err
pass except TVMError as err:
last_err = err
raise RuntimeError(
"Cannot request %s after %d retry, last_error:%s" % (
key, max_retry, str(last_err)))
def request_and_run(self,
key,
func,
priority=1,
session_timeout=0,
max_retry=2):
"""Request a resource from tracker and run the func.
This function safe-guard rare server node dropout during execution.
In such case, a new resource will be requested and func will be ran again.
Parameters
----------
key : str
The type key of the device.
func : function of session -> value
A stateless function
priority : int, optional
The priority of the request.
session_timeout : float, optional
The duration of the session, allows server to kill
the connection when duration is longer than this value.
When duration is zero, it means the request must always be kept alive.
max_retry : int, optional
Maximum number of times to retry the function before give up.
"""
last_err = None
for _ in range(max_retry):
try:
sess = self.request(key,
priority=priority,
session_timeout=session_timeout)
tstart = time.time()
return func(sess)
except TVMError as err:
duration = time.time() - tstart
# roughly estimate if the error is due to timeout termination
if session_timeout and duration >= session_timeout * 0.95:
raise RuntimeError(
"Session timeout when running %s" % func.__name__)
last_err = err
raise RuntimeError(
"Failed to run on %s after %d retry, last_error:%s" % (
key, max_retry, str(last_err)))
def connect(url, port, key="", session_timeout=0): def connect(url, port, key="", session_timeout=0):
......
...@@ -137,6 +137,11 @@ def _listen_loop(sock, port, rpc_key, tracker_addr): ...@@ -137,6 +137,11 @@ def _listen_loop(sock, port, rpc_key, tracker_addr):
magic = struct.unpack("@i", base.recvall(tracker_conn, 4))[0] magic = struct.unpack("@i", base.recvall(tracker_conn, 4))[0]
if magic != base.RPC_TRACKER_MAGIC: if magic != base.RPC_TRACKER_MAGIC:
raise RuntimeError("%s is not RPC Tracker" % str(tracker_addr)) raise RuntimeError("%s is not RPC Tracker" % str(tracker_addr))
# report status of current queue
cinfo = {"key" : "server:" + rpc_key}
base.sendjson(tracker_conn,
[TrackerCode.UPDATE_INFO, cinfo])
assert base.recvjson(tracker_conn) == TrackerCode.SUCCESS
try: try:
# step 2: wait for in-coming connections # step 2: wait for in-coming connections
conn, addr, opts = _accept_conn(sock, tracker_conn) conn, addr, opts = _accept_conn(sock, tracker_conn)
......
...@@ -75,10 +75,15 @@ class Scheduler(object): ...@@ -75,10 +75,15 @@ class Scheduler(object):
""" """
raise NotImplementedError() raise NotImplementedError()
def summary(self):
"""Get summary information of the scheduler."""
raise NotImplementedError()
class PriorityScheduler(Scheduler): class PriorityScheduler(Scheduler):
"""Priority based scheduler, FIFO based on time""" """Priority based scheduler, FIFO based on time"""
def __init__(self): def __init__(self, key):
self._key = key
self._values = [] self._values = []
self._requests = [] self._requests = []
...@@ -98,6 +103,11 @@ class PriorityScheduler(Scheduler): ...@@ -98,6 +103,11 @@ class PriorityScheduler(Scheduler):
heapq.heappush(self._requests, (-priority, time.time(), callback)) heapq.heappush(self._requests, (-priority, time.time(), callback))
self._schedule() self._schedule()
def summary(self):
"""Get summary information of the scheduler."""
return {"free": len(self._values),
"pending": len(self._requests)}
class TCPEventHandler(tornado_util.TCPHandler): class TCPEventHandler(tornado_util.TCPHandler):
"""Base asynchronize message handler. """Base asynchronize message handler.
...@@ -113,12 +123,17 @@ class TCPEventHandler(tornado_util.TCPHandler): ...@@ -113,12 +123,17 @@ class TCPEventHandler(tornado_util.TCPHandler):
self._msg_size = 0 self._msg_size = 0
self._addr = addr self._addr = addr
self._init_req_nbytes = 4 self._init_req_nbytes = 4
self._info = {"addr": addr}
self._tracker._connections.add(self) self._tracker._connections.add(self)
def name(self): def name(self):
"""name of connection""" """name of connection"""
return "TCPSocket: %s" % str(self._addr) return "TCPSocket: %s" % str(self._addr)
def summary(self):
"""Summary of this connection"""
return self._info
def _init_conn(self, message): def _init_conn(self, message):
"""Initialie the connection""" """Initialie the connection"""
if len(message) != 4: if len(message) != 4:
...@@ -193,6 +208,12 @@ class TCPEventHandler(tornado_util.TCPHandler): ...@@ -193,6 +208,12 @@ class TCPEventHandler(tornado_util.TCPHandler):
self._tracker.stop() self._tracker.stop()
else: else:
self.ret_value(TrackerCode.FAIL) self.ret_value(TrackerCode.FAIL)
elif code == TrackerCode.UPDATE_INFO:
self._info.update(args[1])
self.ret_value(TrackerCode.SUCCESS)
elif code == TrackerCode.SUMMARY:
status = self._tracker.summary()
self.ret_value([TrackerCode.SUCCESS, status])
else: else:
logging.info("Unknown tracker code %d", code) logging.info("Unknown tracker code %d", code)
self.close() self.close()
...@@ -230,8 +251,7 @@ class TrackerServerHandler(object): ...@@ -230,8 +251,7 @@ class TrackerServerHandler(object):
def create_scheduler(self, key): def create_scheduler(self, key):
"""Create a new scheduler.""" """Create a new scheduler."""
_ = key return PriorityScheduler(key)
return PriorityScheduler()
def put(self, key, value): def put(self, key, value):
"""Report a new resource to the tracker.""" """Report a new resource to the tracker."""
...@@ -252,6 +272,19 @@ class TrackerServerHandler(object): ...@@ -252,6 +272,19 @@ class TrackerServerHandler(object):
self._sock.close() self._sock.close()
self._ioloop.stop() self._ioloop.stop()
def summary(self):
"""Return a dict summarizing current status."""
qinfo = {}
for k, v in self._scheduler_map.items():
qinfo[k] = v.summary()
cinfo = []
# ignore client connections without key
for conn in self._connections:
res = conn.summary()
if res.get("key", "").startswith("server"):
cinfo.append(res)
return {"queue_info": qinfo, "server_info": cinfo}
def run(self): def run(self):
"""Run the tracker server""" """Run the tracker server"""
self._ioloop.start() self._ioloop.start()
......
"""Tool to query RPC tracker status"""
from __future__ import absolute_import
import logging
import argparse
import os
from ..contrib import rpc
def main():
"""Main funciton"""
parser = argparse.ArgumentParser()
parser.add_argument('--host', type=str, default="",
help='the hostname of the tracker')
parser.add_argument('--port', type=int, default=None,
help='The port of the PRC')
args = parser.parse_args()
logging.basicConfig(level=logging.INFO)
# default to local host or environment variable
if not args.host:
args.host = os.environ.get("TVM_TRACKER_HOST", "localhost")
if not args.port:
args.port = int(os.environ.get("TVM_TRACKER_PORT", "9190"))
conn = rpc.connect_tracker(args.host, args.port)
# pylint: disable=superfluous-parens
print("Tracker address %s:%d\n" % (args.host, args.port))
print("%s" % conn.text_summary())
if __name__ == "__main__":
main()
"""RPC web proxy, allows redirect to websocket based RPC servers(browsers)""" """Tool to start RPC tracker"""
from __future__ import absolute_import from __future__ import absolute_import
import logging import logging
...@@ -9,7 +9,7 @@ def main(): ...@@ -9,7 +9,7 @@ def main():
"""Main funciton""" """Main funciton"""
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--host', type=str, default="0.0.0.0", parser.add_argument('--host', type=str, default="0.0.0.0",
help='the hostname of the server') help='the hostname of the tracker')
parser.add_argument('--port', type=int, default=9190, parser.add_argument('--port', type=int, default=9190,
help='The port of the PRC') help='The port of the PRC')
args = parser.parse_args() args = parser.parse_args()
......
...@@ -38,6 +38,15 @@ def check_server_drop(): ...@@ -38,6 +38,15 @@ def check_server_drop():
# Fault tolerence server timeout # Fault tolerence server timeout
def check_timeout(timeout, sleeptime): def check_timeout(timeout, sleeptime):
def myfunc(remote):
time.sleep(sleeptime)
f1 = remote.get_function("rpc.test2.addone")
assert f1(10) == 11
try:
tclient.request_and_run("xyz", myfunc, session_timeout=timeout)
except RuntimeError:
pass
print(tclient.text_summary())
try: try:
remote = tclient.request("xyz", priority=0, session_timeout=timeout) remote = tclient.request("xyz", priority=0, session_timeout=timeout)
remote2 = tclient.request("xyz", session_timeout=timeout) remote2 = tclient.request("xyz", session_timeout=timeout)
...@@ -48,8 +57,11 @@ def check_server_drop(): ...@@ -48,8 +57,11 @@ def check_server_drop():
assert f1(10) == 11 assert f1(10) == 11
except tvm.TVMError as e: except tvm.TVMError as e:
pass pass
check_timeout(0.01, 0.1) check_timeout(0.01, 0.1)
check_timeout(2, 0) check_timeout(2, 0)
except ImportError: except ImportError:
print("Skip because tornado is not available") print("Skip because tornado is not available")
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment