Commit 79fc6672 by Tianqi Chen Committed by GitHub

[RPC] Tracker status query (#1081)

parent 6bd8dbc7
......@@ -32,7 +32,8 @@ class TrackerCode(object):
STOP = 2
PUT = 3
REQUEST = 4
UPDATE_INFO = 5
SUMMARY = 6
RPC_SESS_MASK = 128
......
......@@ -4,6 +4,7 @@ from __future__ import absolute_import
import os
import socket
import struct
import time
from . import base
from ..._ffi.base import TVMError
......@@ -150,7 +151,6 @@ class TrackerSession(object):
def __init__(self, addr):
self._addr = addr
self._sock = None
self._max_request_retry = 5
self._connect()
def __del__(self):
......@@ -169,7 +169,38 @@ class TrackerSession(object):
self._sock.close()
self._sock = None
def request(self, key, priority=1, session_timeout=0):
def summary(self):
"""Get the summary dict of the tracker."""
base.sendjson(self._sock, [base.TrackerCode.SUMMARY])
value = base.recvjson(self._sock)
if value[0] != base.TrackerCode.SUCCESS:
raise RuntimeError("Invalid return value %s" % str(value))
return value[1]
def text_summary(self):
"""Get a text summary of the tracker."""
data = self.summary()
res = ""
res += "Server List\n"
res += "----------------------------\n"
res += "server-address\tkey\n"
res += "----------------------------\n"
for item in data["server_info"]:
addr = item["addr"]
res += addr[0] + ":" + str(addr[1])+ "\t"
res += item["key"] + "\n"
res += "----------------------------\n"
res += "\n"
res += "Queue Status\n"
res += "----------------------------\n"
res += "key\tfree\tpending\n"
res += "----------------------------\n"
for k, v in data["queue_info"].items():
res += "%s\t%d\t%g\n" % (k, v["free"], v["pending"])
res += "----------------------------\n"
return res
def request(self, key, priority=1, session_timeout=0, max_retry=5):
"""Request a new connection from the tracker.
Parameters
......@@ -184,8 +215,12 @@ class TrackerSession(object):
The duration of the session, allows server to kill
the connection when duration is longer than this value.
When duration is zero, it means the request must always be kept alive.
max_retry : int, optional
Maximum number of times to retry before give up.
"""
for _ in range(self._max_request_retry):
last_err = None
for _ in range(max_retry):
try:
if self._sock is None:
self._connect()
......@@ -196,10 +231,63 @@ class TrackerSession(object):
raise RuntimeError("Invalid return value %s" % str(value))
url, port, matchkey = value[1]
return connect(url, port, key + matchkey, session_timeout)
except socket.error:
except socket.error as err:
self.close()
except TVMError:
pass
last_err = err
except TVMError as err:
last_err = err
raise RuntimeError(
"Cannot request %s after %d retry, last_error:%s" % (
key, max_retry, str(last_err)))
def request_and_run(self,
key,
func,
priority=1,
session_timeout=0,
max_retry=2):
"""Request a resource from tracker and run the func.
This function safe-guard rare server node dropout during execution.
In such case, a new resource will be requested and func will be ran again.
Parameters
----------
key : str
The type key of the device.
func : function of session -> value
A stateless function
priority : int, optional
The priority of the request.
session_timeout : float, optional
The duration of the session, allows server to kill
the connection when duration is longer than this value.
When duration is zero, it means the request must always be kept alive.
max_retry : int, optional
Maximum number of times to retry the function before give up.
"""
last_err = None
for _ in range(max_retry):
try:
sess = self.request(key,
priority=priority,
session_timeout=session_timeout)
tstart = time.time()
return func(sess)
except TVMError as err:
duration = time.time() - tstart
# roughly estimate if the error is due to timeout termination
if session_timeout and duration >= session_timeout * 0.95:
raise RuntimeError(
"Session timeout when running %s" % func.__name__)
last_err = err
raise RuntimeError(
"Failed to run on %s after %d retry, last_error:%s" % (
key, max_retry, str(last_err)))
def connect(url, port, key="", session_timeout=0):
......
......@@ -137,6 +137,11 @@ def _listen_loop(sock, port, rpc_key, tracker_addr):
magic = struct.unpack("@i", base.recvall(tracker_conn, 4))[0]
if magic != base.RPC_TRACKER_MAGIC:
raise RuntimeError("%s is not RPC Tracker" % str(tracker_addr))
# report status of current queue
cinfo = {"key" : "server:" + rpc_key}
base.sendjson(tracker_conn,
[TrackerCode.UPDATE_INFO, cinfo])
assert base.recvjson(tracker_conn) == TrackerCode.SUCCESS
try:
# step 2: wait for in-coming connections
conn, addr, opts = _accept_conn(sock, tracker_conn)
......
......@@ -75,10 +75,15 @@ class Scheduler(object):
"""
raise NotImplementedError()
def summary(self):
"""Get summary information of the scheduler."""
raise NotImplementedError()
class PriorityScheduler(Scheduler):
"""Priority based scheduler, FIFO based on time"""
def __init__(self):
def __init__(self, key):
self._key = key
self._values = []
self._requests = []
......@@ -98,6 +103,11 @@ class PriorityScheduler(Scheduler):
heapq.heappush(self._requests, (-priority, time.time(), callback))
self._schedule()
def summary(self):
"""Get summary information of the scheduler."""
return {"free": len(self._values),
"pending": len(self._requests)}
class TCPEventHandler(tornado_util.TCPHandler):
"""Base asynchronize message handler.
......@@ -113,12 +123,17 @@ class TCPEventHandler(tornado_util.TCPHandler):
self._msg_size = 0
self._addr = addr
self._init_req_nbytes = 4
self._info = {"addr": addr}
self._tracker._connections.add(self)
def name(self):
"""name of connection"""
return "TCPSocket: %s" % str(self._addr)
def summary(self):
"""Summary of this connection"""
return self._info
def _init_conn(self, message):
"""Initialie the connection"""
if len(message) != 4:
......@@ -193,6 +208,12 @@ class TCPEventHandler(tornado_util.TCPHandler):
self._tracker.stop()
else:
self.ret_value(TrackerCode.FAIL)
elif code == TrackerCode.UPDATE_INFO:
self._info.update(args[1])
self.ret_value(TrackerCode.SUCCESS)
elif code == TrackerCode.SUMMARY:
status = self._tracker.summary()
self.ret_value([TrackerCode.SUCCESS, status])
else:
logging.info("Unknown tracker code %d", code)
self.close()
......@@ -230,8 +251,7 @@ class TrackerServerHandler(object):
def create_scheduler(self, key):
"""Create a new scheduler."""
_ = key
return PriorityScheduler()
return PriorityScheduler(key)
def put(self, key, value):
"""Report a new resource to the tracker."""
......@@ -252,6 +272,19 @@ class TrackerServerHandler(object):
self._sock.close()
self._ioloop.stop()
def summary(self):
"""Return a dict summarizing current status."""
qinfo = {}
for k, v in self._scheduler_map.items():
qinfo[k] = v.summary()
cinfo = []
# ignore client connections without key
for conn in self._connections:
res = conn.summary()
if res.get("key", "").startswith("server"):
cinfo.append(res)
return {"queue_info": qinfo, "server_info": cinfo}
def run(self):
"""Run the tracker server"""
self._ioloop.start()
......
"""Tool to query RPC tracker status"""
from __future__ import absolute_import
import logging
import argparse
import os
from ..contrib import rpc
def main():
"""Main funciton"""
parser = argparse.ArgumentParser()
parser.add_argument('--host', type=str, default="",
help='the hostname of the tracker')
parser.add_argument('--port', type=int, default=None,
help='The port of the PRC')
args = parser.parse_args()
logging.basicConfig(level=logging.INFO)
# default to local host or environment variable
if not args.host:
args.host = os.environ.get("TVM_TRACKER_HOST", "localhost")
if not args.port:
args.port = int(os.environ.get("TVM_TRACKER_PORT", "9190"))
conn = rpc.connect_tracker(args.host, args.port)
# pylint: disable=superfluous-parens
print("Tracker address %s:%d\n" % (args.host, args.port))
print("%s" % conn.text_summary())
if __name__ == "__main__":
main()
"""RPC web proxy, allows redirect to websocket based RPC servers(browsers)"""
"""Tool to start RPC tracker"""
from __future__ import absolute_import
import logging
......@@ -9,7 +9,7 @@ def main():
"""Main funciton"""
parser = argparse.ArgumentParser()
parser.add_argument('--host', type=str, default="0.0.0.0",
help='the hostname of the server')
help='the hostname of the tracker')
parser.add_argument('--port', type=int, default=9190,
help='The port of the PRC')
args = parser.parse_args()
......
......@@ -38,6 +38,15 @@ def check_server_drop():
# Fault tolerence server timeout
def check_timeout(timeout, sleeptime):
def myfunc(remote):
time.sleep(sleeptime)
f1 = remote.get_function("rpc.test2.addone")
assert f1(10) == 11
try:
tclient.request_and_run("xyz", myfunc, session_timeout=timeout)
except RuntimeError:
pass
print(tclient.text_summary())
try:
remote = tclient.request("xyz", priority=0, session_timeout=timeout)
remote2 = tclient.request("xyz", session_timeout=timeout)
......@@ -48,8 +57,11 @@ def check_server_drop():
assert f1(10) == 11
except tvm.TVMError as e:
pass
check_timeout(0.01, 0.1)
check_timeout(2, 0)
except ImportError:
print("Skip because tornado is not available")
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment