"""RPC Tracker, tracks and distributes the TVM RPC resources.

This module implements the tracker server logic.

Note
----
Tracker is a TCP based rest api with the following protocol:
- Initial handshake to the peer
  - RPC_TRACKER_MAGIC
- Normal message: [size(int32), json-data]
- Each message is initiated by the client, and the tracker replies with a json.

List of available APIs:

- PING: check if tracker is alive
  - input: [TrackerCode.PING]
  - return: TrackerCode.SUCCESS
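- PUT: report resource to tracker
  - input: [TrackerCode.PUT, key, [port, match-key], custom-addr]
  - return: TrackerCode.SUCCESS
  - note: match-key is a randomly generated key that identifies the resource
    during connection; custom-addr may be null, in which case the tracker
    uses the address of the connected peer.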
- REQUEST: request a new resource from tracker
  - input: [TrackerCode.REQUEST, key, user, priority]
  - return: [TrackerCode.SUCCESS, [url, port, match-key]]
"""
26 27
# pylint: disable=invalid-name

28 29 30 31 32 33 34 35 36 37 38 39 40 41
import heapq
import time
import logging
import socket
import multiprocessing
import errno
import struct
import json

try:
    from tornado import ioloop
    from . import tornado_util
except ImportError as error_msg:
    raise ImportError(
42
        "RPCTracker module requires tornado package %s. Try 'pip install tornado'." % error_msg)
43

44
from .._ffi.base import py_str
45 46 47
from . import base
from .base import RPC_TRACKER_MAGIC, TrackerCode

48
logger = logging.getLogger("RPCTracker")
49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80

class Scheduler(object):
    """Abstract interface of scheduler.

    Subclasses must implement every method; the base versions only
    raise ``NotImplementedError``.
    """
    def put(self, value):
        """Push a resource into the scheduler.

        This function can trigger callbacks in the scheduler.

        Parameters
        ----------
        value : object
            The resource to be put in the scheduler.
        """
        raise NotImplementedError()

    def request(self, user, priority, callback):
        """Request a resource.

        Parameters
        ----------
        user : str
            The user who is requesting the resource.

        priority : int
            The job priority

        callback : function: value->bool
            Callback function to receive a resource when ready;
            returns True if the resource is consumed.
        """
        raise NotImplementedError()

    def summary(self):
        """Get summary information of the scheduler."""
        raise NotImplementedError()

85 86 87

class PriorityScheduler(Scheduler):
    """Priority based scheduler, FIFO among requests of equal priority.

    Parameters
    ----------
    key : str
        The resource key this scheduler serves.
    """
    def __init__(self, key):
        self._key = key
        # Monotonic sequence number used as the heap tie-breaker: it orders
        # equal-priority requests by arrival and guarantees heapq never has
        # to compare two callbacks (which raises TypeError on Python 3).
        self._request_cnt = 0
        # free resources, oldest first
        self._values = []
        # heap of (-priority, sequence number, callback)
        self._requests = []

    def _schedule(self):
        """Hand free resources to pending requests until either runs out.

        A resource is a sequence whose first element is the owning
        connection and whose last element is the match key; the callback
        only receives ``value[1:]``.  When the callback reports that it did
        not consume the resource, that request is dropped and the resource
        goes back to the end of the free list.
        """
        while self._requests and self._values:
            value = self._values.pop(0)
            callback = heapq.heappop(self._requests)[-1]
            if callback(value[1:]):
                # the match key is now in use, no longer pending
                value[0].pending_matchkeys.remove(value[-1])
            else:
                self._values.append(value)

    def put(self, value):
        """Add a free resource and try to match it with a pending request."""
        self._values.append(value)
        self._schedule()

    def request(self, user, priority, callback):
        """Queue a request and try to match it with a free resource.

        Parameters
        ----------
        user : str
            The user who is requesting the resource (currently unused).

        priority : int
            The job priority; larger values are served first.

        callback : function: value->bool
            Receives the resource when ready; returns True if consumed.
        """
        heapq.heappush(
            self._requests, (-priority, self._request_cnt, callback))
        self._request_cnt += 1
        self._schedule()

    def summary(self):
        """Get summary information of the scheduler."""
        return {"free": len(self._values),
                "pending": len(self._requests)}

116 117 118 119 120 121 122 123 124 125 126 127 128 129 130

class TCPEventHandler(tornado_util.TCPHandler):
    """Base asynchronous message handler.

    The tracker and client follow a simple message protocol.
    The message is in form [nbytes(int32)] [json-str].
    All the information is packed in json-str.

    Parameters
    ----------
    tracker : object
        The tracker this connection registers itself with and forwards
        the decoded requests to.

    sock : socket.socket
        The connected socket, handed to the base handler.

    addr : tuple
        The peer address; ``addr[0]`` is used as the resource host when a
        PUT message carries no custom address.
    """
    def __init__(self, tracker, sock, addr):
        super(TCPEventHandler, self).__init__(sock)
        # bytes received but not yet decoded into a full message
        self._data = bytearray()
        self._tracker = tracker
        # payload size of the message being assembled, 0 = header not read
        self._msg_size = 0
        self._addr = addr
        # non-zero until the magic handshake has been completed
        self._init_req_nbytes = 4
        self._info = {"addr": addr}
        # list of pending match keys that has not been used.
        self.pending_matchkeys = set()
        self._tracker._connections.add(self)

    def name(self):
        """name of connection"""
        return "TCPSocket: %s" % str(self._addr)

    def summary(self):
        """Summary of this connection"""
        return self._info

    def _init_conn(self, message):
        """Initialize the connection.

        Validates the 4-byte magic sent by the peer and echoes it back.
        The connection is closed, and nothing is sent, when the handshake
        is malformed.
        """
        if len(message) != 4:
            logger.warning("Invalid connection from %s", self.name())
            self.close()
            # do not try to unpack a buffer of the wrong length
            return
        magic = struct.unpack('<i', message)[0]
        if magic != RPC_TRACKER_MAGIC:
            logger.warning("Invalid magic from %s", self.name())
            self.close()
            # do not reply on a connection that was just closed
            return
        self.write_message(struct.pack('<i', RPC_TRACKER_MAGIC), binary=True)
        self._init_req_nbytes = 0

    def on_message(self, message):
        """Callback when a message is received.

        Parameters
        ----------
        message : bytes
            The bytes received
        """
        assert isinstance(message, bytes)
        if self._init_req_nbytes:
            self._init_conn(message)
            return

        self._data += message

        # decode as many complete [size, json] frames as are buffered
        while True:
            if self._msg_size == 0:
                if len(self._data) < 4:
                    return
                self._msg_size = struct.unpack('<i', self._data[:4])[0]
                if self._msg_size <= 0:
                    # A non-positive size can never frame a JSON payload;
                    # it would stall the parser or mis-slice the buffer.
                    logger.warning(
                        "Invalid message size from %s", self.name())
                    self._msg_size = 0
                    del self._data[:]
                    self.close()
                    return
            if len(self._data) < self._msg_size + 4:
                return
            msg = py_str(bytes(self._data[4:4 + self._msg_size]))
            del self._data[:4 + self._msg_size]
            self._msg_size = 0
            self.call_handler(json.loads(msg))

    def ret_value(self, data):
        """return value to the output"""
        payload = json.dumps(data).encode("utf-8")
        # the size prefix counts the encoded bytes actually sent
        self.write_message(struct.pack('<i', len(payload)), binary=True)
        self.write_message(payload, binary=True)

    def call_handler(self, args):
        """Event handler when json request arrives.

        Parameters
        ----------
        args : list
            The decoded request; ``args[0]`` is the tracker code and the
            remaining elements depend on the code.
        """
        code = args[0]
        if code == TrackerCode.PUT:
            key = args[1]
            port, matchkey = args[2]
            self.pending_matchkeys.add(matchkey)
            # got custom address (from rpc server); the field is optional
            custom_addr = args[3] if len(args) > 3 else None
            host = custom_addr if custom_addr is not None else self._addr[0]
            self._tracker.put(key, (self, host, port, matchkey))
            self.ret_value(TrackerCode.SUCCESS)
        elif code == TrackerCode.REQUEST:
            key = args[1]
            user = args[2]
            priority = args[3]
            def _cb(value):
                # if the connection is already closed
                if not self._sock:
                    return False
                try:
                    self.ret_value([TrackerCode.SUCCESS, value])
                except (socket.error, IOError):
                    return False
                return True
            self._tracker.request(key, user, priority, _cb)
        elif code == TrackerCode.PING:
            self.ret_value(TrackerCode.SUCCESS)
        elif code == TrackerCode.GET_PENDING_MATCHKEYS:
            self.ret_value(list(self.pending_matchkeys))
        elif code == TrackerCode.STOP:
            # safe stop tracker
            if self._tracker._stop_key == args[1]:
                self.ret_value(TrackerCode.SUCCESS)
                self._tracker.stop()
            else:
                self.ret_value(TrackerCode.FAIL)
        elif code == TrackerCode.UPDATE_INFO:
            self._info.update(args[1])
            self.ret_value(TrackerCode.SUCCESS)
        elif code == TrackerCode.SUMMARY:
            status = self._tracker.summary()
            self.ret_value([TrackerCode.SUCCESS, status])
        else:
            logger.warning("Unknown tracker code %d", code)
            self.close()

    def on_close(self):
        """Unregister this connection from the tracker."""
        # discard: a repeated close must not raise KeyError
        self._tracker._connections.discard(self)

    def on_error(self, err):
        """Log the error and close the connection."""
        logger.warning("%s: Error in RPC Tracker: %s", self.name(), err)
        self.close()


class TrackerServerHandler(object):
    """Tracker that tracks the resources.

    Parameters
    ----------
    sock : socket.socket
        The listening socket.  It is switched to non-blocking mode and
        registered for read events on the current IOLoop.

    stop_key : str
        The key a client has to present with TrackerCode.STOP in order to
        stop the tracker.
    """
    def __init__(self, sock, stop_key):
        # resource key -> scheduler serving that key
        self._scheduler_map = {}
        self._sock = sock
        self._sock.setblocking(0)
        self._ioloop = ioloop.IOLoop.current()
        self._stop_key = stop_key
        # live connections, maintained by the connection handlers
        self._connections = set()
        def _event_handler(_, events):
            self._on_event(events)
        self._ioloop.add_handler(
            self._sock.fileno(), _event_handler, self._ioloop.READ)

    def _on_event(self, _):
        """Accept every connection currently waiting on the listen socket."""
        while True:
            try:
                conn, addr = self._sock.accept()
                TCPEventHandler(self, conn, addr)
            except socket.error as err:
                if err.args[0] in (errno.EAGAIN, errno.EWOULDBLOCK):
                    # nothing left to accept
                    break
                if err.args[0] == errno.ECONNABORTED:
                    # that peer went away, others may still be waiting
                    continue
                # Any other failure is unlikely to clear by retrying at once;
                # leave the loop instead of spinning and blocking the IOLoop.
                logger.warning("Failed to accept connection: %s", err)
                break

    def create_scheduler(self, key):
        """Create a new scheduler."""
        return PriorityScheduler(key)

    def _get_scheduler(self, key):
        """Return the scheduler of key, creating it on first use."""
        if key not in self._scheduler_map:
            self._scheduler_map[key] = self.create_scheduler(key)
        return self._scheduler_map[key]

    def put(self, key, value):
        """Report a new resource to the tracker."""
        self._get_scheduler(key).put(value)

    def request(self, key, user, priority, callback):
        """Request a new resource."""
        self._get_scheduler(key).request(user, priority, callback)

    def stop(self):
        """Safely stop tracker."""
        # iterate over a copy: closing a connection unregisters it
        for conn in list(self._connections):
            conn.close()
        self._sock.close()
        self._ioloop.stop()

    def summary(self):
        """Return a dict summarizing current status."""
        qinfo = {k: v.summary() for k, v in self._scheduler_map.items()}
        cinfo = []
        # ignore client connections without key
        for conn in self._connections:
            res = conn.summary()
            if res.get("key", "").startswith("server"):
                cinfo.append(res)
        return {"queue_info": qinfo, "server_info": cinfo}

    def run(self):
        """Run the tracker server"""
        self._ioloop.start()

def _tracker_server(listen_sock, stop_key):
    """Build a TrackerServerHandler on listen_sock and run it."""
    TrackerServerHandler(listen_sock, stop_key).run()


class Tracker(object):
    """Start RPC tracker on a separate process.

    Python implementation based on multi-processing.

    Parameters
    ----------
    host : str
        The host url of the server.

    port : int
        The TCP port to be bind to

    port_end : int, optional
        The end TCP port to search (exclusive)

    silent: bool, optional
        Whether run in silent mode

    Raises
    ------
    ValueError
        If every port in ``[port, port_end)`` is already in use.
    """
    def __init__(self,
                 host,
                 port=9190,
                 port_end=9199,
                 silent=False):
        # Set first so terminate()/__del__ stay safe if construction fails.
        self.proc = None
        if silent:
            logger.setLevel(logging.WARNING)

        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self.port = None
        self.host = host
        self.stop_key = base.random_key("tracker")
        try:
            for my_port in range(port, port_end):
                try:
                    sock.bind((host, my_port))
                except socket.error as sock_err:
                    if sock_err.errno == errno.EADDRINUSE:
                        continue
                    raise
                # ask the socket, so an OS-assigned port (port 0) is reported
                self.port = sock.getsockname()[1]
                break
            if self.port is None:
                raise ValueError(
                    "cannot bind to any port in [%d, %d)" % (port, port_end))
            logger.info("bind to %s:%d", host, self.port)
            sock.listen(1)
            self.proc = multiprocessing.Process(
                target=_tracker_server, args=(sock, self.stop_key))
            self.proc.start()
        finally:
            # close the socket on this process, also when setup failed
            sock.close()

    def _stop_tracker(self):
        """Ask the tracker process to stop through its own protocol."""
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            sock.connect((self.host, self.port))
            sock.sendall(struct.pack("<i", base.RPC_TRACKER_MAGIC))
            magic = struct.unpack("<i", base.recvall(sock, 4))[0]
            assert magic == base.RPC_TRACKER_MAGIC
            base.sendjson(sock, [TrackerCode.STOP, self.stop_key])
            # receive outside the assert so the reply is still read under -O
            reply = base.recvjson(sock)
            assert reply == TrackerCode.SUCCESS
        finally:
            sock.close()

    def terminate(self):
        """Terminate the server process"""
        proc = getattr(self, "proc", None)
        try:
            if proc and proc.is_alive():
                self._stop_tracker()
                proc.join(1)
        finally:
            # Fall back to a forced stop when the graceful one failed or
            # timed out, so the child process is never left running.
            if proc and proc.is_alive():
                logger.info("Terminating Tracker Server...")
                proc.terminate()
            self.proc = None

    def __del__(self):
        self.terminate()