"""RPC Tracker, tracks and distributes the TVM RPC resources.

This module implements the tracker server logic.

Note
----
Tracker is a TCP-based API with the following protocol:
- Initial handshake to the peer
  - RPC_TRACKER_MAGIC
- Normal message: [size(int32), json-data]
- Each message is initiated by the client, and the tracker replies with a json.

List of available APIs:

- PING: check if tracker is alive
  - input: [TrackerCode.PING]
  - return: TrackerCode.SUCCESS
- PUT: report resource to tracker
  - input: [TrackerCode.PUT, key, [port, match-key], custom-addr]
  - return: TrackerCode.SUCCESS
  - note: match-key is a randomly generated key that identifies the resource during connection;
    custom-addr is the address to report, or null to use the peer's address.
- REQUEST: request a new resource from tracker
  - input: [TrackerCode.REQUEST, key, user, priority]
  - return: [TrackerCode.SUCCESS, [url, port, match-key]]
"""
# pylint: disable=invalid-name

import heapq
import time
import logging
import socket
import multiprocessing
import errno
import struct
import json

try:
    from tornado import ioloop
    from . import tornado_util
except ImportError as error_msg:
    raise ImportError(
        "RPCTracker module requires tornado package %s. Try 'pip install tornado'." % error_msg)

from .._ffi.base import py_str

from . import base
from .base import RPC_TRACKER_MAGIC, TrackerCode

# Module-wide logger shared by the handlers and the Tracker front-end.
logger = logging.getLogger("RPCTracker")


class Scheduler(object):
    """Abstract interface of a resource scheduler.

    Subclasses must provide ``put``, ``request`` and ``summary``;
    ``remove`` is optional and does nothing by default.
    """

    def put(self, value):
        """Make a resource available to the scheduler.

        Implementations may fire pending request callbacks from here.

        Parameters
        ----------
        value : object
            Resource that becomes available for scheduling.
        """
        raise NotImplementedError

    def request(self, user, priority, callback):
        """Ask the scheduler for a resource.

        Parameters
        ----------
        user : str
            Name of the user making the request.

        priority : int
            Priority of the job.

        callback : function: value->bool
            Invoked with a resource once one is ready; it returns True
            when the resource was consumed.
        """
        raise NotImplementedError

    def remove(self, value):
        """Drop a resource from the scheduler.

        The default implementation is a no-op.

        Parameters
        ----------
        value : object
            Resource to drop.
        """

    def summary(self):
        """Return summary information about the scheduler."""
        raise NotImplementedError


class PriorityScheduler(Scheduler):
    """Priority based scheduler, FIFO based on time.

    Pending requests are served highest ``priority`` first; requests of
    equal priority are served in arrival order. Free resources are handed
    out in the order they were put.

    Resource values are expected to be tuples whose first element carries a
    ``pending_matchkeys`` set and whose last element is the match key; a
    request callback receives ``value[1:]``.
    """

    def __init__(self, key):
        self._key = key
        # Free resources, oldest first.
        self._values = []
        # Heap of (-priority, timestamp, sequence, callback).
        self._requests = []
        # Tie-breaker for the heap: without it, two requests with the same
        # priority and timestamp make heapq compare the callbacks, which
        # raises TypeError for functions.
        self._seq = 0

    def _schedule(self):
        """Hand free resources to pending requests while both exist."""
        while self._requests and self._values:
            value = self._values.pop(0)
            callback = heapq.heappop(self._requests)[-1]
            if callback(value[1:]):
                # Consumed: the match key is no longer pending. discard()
                # tolerates a key that was already cleared (e.g. the same
                # match key put twice).
                value[0].pending_matchkeys.discard(value[-1])
            else:
                # The callback declined the resource; keep it for the next
                # request (the declined request itself is dropped).
                self._values.append(value)

    def put(self, value):
        """Add a free resource and try to serve pending requests."""
        self._values.append(value)
        self._schedule()

    def request(self, user, priority, callback):
        """Queue a request and try to serve it right away."""
        self._seq += 1
        heapq.heappush(self._requests,
                       (-priority, time.time(), self._seq, callback))
        self._schedule()

    def remove(self, value):
        """Drop a free resource; a no-op if it is not currently free."""
        if value in self._values:
            self._values.remove(value)
            self._schedule()

    def summary(self):
        """Get summary information of the scheduler."""
        return {"free": len(self._values),
                "pending": len(self._requests)}


class TCPEventHandler(tornado_util.TCPHandler):
    """Asynchronous message handler for one tracker connection.

    The tracker and client follow a simple message protocol.
    The message is in form [nbytes(int32)] [json-str].
    All the information is packed in json-str.
    """
    def __init__(self, tracker, sock, addr):
        super(TCPEventHandler, self).__init__(sock)
        # Bytes received but not yet parsed into a complete message.
        self._data = bytearray()
        self._tracker = tracker
        # Payload size of the message being assembled; 0 = header not read yet.
        self._msg_size = 0
        self._addr = addr
        # Bytes expected for the initial handshake; 0 once it has completed.
        self._init_req_nbytes = 4
        self._info = {"addr": addr}
        # list of pending match keys that has not been used.
        self.pending_matchkeys = set()
        self._tracker._connections.add(self)
        # Resource values this connection PUT, so the tracker can withdraw
        # them when the connection closes.
        self.put_values = []

    def name(self):
        """name of connection"""
        return "TCPSocket: %s" % str(self._addr)

    def summary(self):
        """Summary of this connection"""
        return self._info

    def _init_conn(self, message):
        """Validate the handshake magic and reply with our own.

        On a malformed handshake the connection is closed and nothing else
        happens.
        NOTE(review): assumes the 4-byte magic arrives as its own read.
        """
        if len(message) != 4:
            logger.warning("Invalid connection from %s", self.name())
            self.close()
            return
        magic = struct.unpack('<i', message)[0]
        if magic != RPC_TRACKER_MAGIC:
            logger.warning("Invalid magic from %s", self.name())
            self.close()
            return
        self.write_message(struct.pack('<i', RPC_TRACKER_MAGIC), binary=True)
        self._init_req_nbytes = 0

    def on_message(self, message):
        """Callback when a message is received.

        Buffers ``message`` and dispatches every complete
        ``[nbytes(int32)] [json-str]`` frame to ``call_handler``.

        Parameters
        ----------
        message : bytes
            The bytes received
        """
        assert isinstance(message, bytes)
        if self._init_req_nbytes:
            self._init_conn(message)
            return

        self._data += message

        while True:
            if self._msg_size == 0:
                if len(self._data) >= 4:
                    self._msg_size = struct.unpack('<i', self._data[:4])[0]
                else:
                    return
            if self._msg_size != 0 and len(self._data) >= self._msg_size + 4:
                msg = py_str(bytes(self._data[4:4 + self._msg_size]))
                del self._data[:4 + self._msg_size]
                self._msg_size = 0
                self.call_handler(json.loads(msg))
            else:
                return

    def ret_value(self, data):
        """Send ``data`` to the peer as a length-prefixed JSON message."""
        payload = json.dumps(data).encode("utf-8")
        # The length prefix counts bytes, so measure the encoded payload.
        self.write_message(struct.pack('<i', len(payload)), binary=True)
        self.write_message(payload, binary=True)

    def call_handler(self, args):
        """Event handler when json request arrives.

        ``args[0]`` is the TrackerCode selecting the request; the remaining
        elements are that request's arguments.
        """
        code = args[0]
        if code == TrackerCode.PUT:
            key = args[1]
            port, matchkey = args[2]
            self.pending_matchkeys.add(matchkey)
            # Optional custom address (from rpc server); fall back to the
            # peer address when it is missing or null.
            custom_addr = args[3] if len(args) > 3 else None
            if custom_addr is not None:
                value = (self, custom_addr, port, matchkey)
            else:
                value = (self, self._addr[0], port, matchkey)
            self._tracker.put(key, value)
            self.put_values.append(value)
            self.ret_value(TrackerCode.SUCCESS)
        elif code == TrackerCode.REQUEST:
            key = args[1]
            user = args[2]
            priority = args[3]
            def _cb(value):
                # if the connection is already closed
                if not self._sock:
                    return False
                try:
                    self.ret_value([TrackerCode.SUCCESS, value])
                except (socket.error, IOError):
                    return False
                return True
            self._tracker.request(key, user, priority, _cb)
        elif code == TrackerCode.PING:
            self.ret_value(TrackerCode.SUCCESS)
        elif code == TrackerCode.GET_PENDING_MATCHKEYS:
            self.ret_value(list(self.pending_matchkeys))
        elif code == TrackerCode.STOP:
            # safe stop tracker
            if self._tracker._stop_key == args[1]:
                self.ret_value(TrackerCode.SUCCESS)
                self._tracker.stop()
            else:
                self.ret_value(TrackerCode.FAIL)
        elif code == TrackerCode.UPDATE_INFO:
            self._info.update(args[1])
            self.ret_value(TrackerCode.SUCCESS)
        elif code == TrackerCode.SUMMARY:
            status = self._tracker.summary()
            self.ret_value([TrackerCode.SUCCESS, status])
        else:
            logger.warning("Unknown tracker code %d", code)
            self.close()

    def on_close(self):
        """Notify the tracker that this connection has closed."""
        self._tracker.close(self)

    def on_error(self, err):
        """Log ``err`` and close the connection."""
        logger.warning("%s: Error in RPC Tracker: %s", self.name(), err)
        self.close()


class TrackerServerHandler(object):
    """Tracker that tracks the resources."""
    def __init__(self, sock, stop_key):
        self._scheduler_map = {}
        self._sock = sock
        self._sock.setblocking(0)
        self._ioloop = ioloop.IOLoop.current()
        self._stop_key = stop_key
        self._connections = set()
        def _event_handler(_, events):
            self._on_event(events)
        self._ioloop.add_handler(
            self._sock.fileno(), _event_handler, self._ioloop.READ)

    def _on_event(self, _):
        """Accept all pending connections on the listening socket."""
        while True:
            try:
                conn, addr = self._sock.accept()
                TCPEventHandler(self, conn, addr)
            except socket.error as err:
                if err.args[0] in (errno.EAGAIN, errno.EWOULDBLOCK):
                    # Nothing left to accept.
                    break
                if err.args[0] == errno.ECONNABORTED:
                    # The peer gave up before we accepted; try the next one.
                    continue
                # Other errors (e.g. running out of file descriptors) would
                # just repeat; report and return to the IO loop instead of
                # spinning here forever.
                logger.warning("Error accepting connection: %s", err)
                break

    def create_scheduler(self, key):
        """Create a new scheduler."""
        return PriorityScheduler(key)

    def put(self, key, value):
        """Report a new resource to the tracker."""
        if key not in self._scheduler_map:
            self._scheduler_map[key] = self.create_scheduler(key)
        self._scheduler_map[key].put(value)

    def request(self, key, user, priority, callback):
        """Request a new resource."""
        if key not in self._scheduler_map:
            self._scheduler_map[key] = self.create_scheduler(key)
        self._scheduler_map[key].request(user, priority, callback)

    def close(self, conn):
        """Forget a closed connection and withdraw the resources it put."""
        # discard(): closing the same connection twice must not raise.
        self._connections.discard(conn)
        # Withdraw the connection's resources from every scheduler. The
        # schedulers in this module treat removing a value they do not hold
        # as a no-op, so this does not depend on the connection having sent a
        # well-formed 'server:<key>' info key.
        for scheduler in list(self._scheduler_map.values()):
            for value in conn.put_values:
                scheduler.remove(value)

    def stop(self):
        """Safely stop tracker."""
        for conn in list(self._connections):
            conn.close()
        self._sock.close()
        self._ioloop.stop()

    def summary(self):
        """Return a dict summarizing current status."""
        qinfo = {}
        for k, v in self._scheduler_map.items():
            qinfo[k] = v.summary()
        cinfo = []
        # ignore client connections without key
        for conn in self._connections:
            res = conn.summary()
            if res.get("key", "").startswith("server"):
                cinfo.append(res)
        return {"queue_info": qinfo, "server_info": cinfo}

    def run(self):
        """Run the tracker server"""
        self._ioloop.start()


def _tracker_server(listen_sock, stop_key):
    """Process entry point: serve ``listen_sock`` until a STOP with ``stop_key``."""
    TrackerServerHandler(listen_sock, stop_key).run()


class Tracker(object):
    """Start RPC tracker on a separate process.

    Python implementation based on multi-processing.

    Parameters
    ----------
    host : str
        The host url of the server.

    port : int
        The first TCP port to try to bind to.

    port_end : int, optional
        The end of the port search range (exclusive).

    silent: bool, optional
        Whether run in silent mode (only warnings and above are logged).
    """
    def __init__(self,
                 host,
                 port=9190,
                 port_end=9199,
                 silent=False):
        if silent:
            logger.setLevel(logging.WARNING)

        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self.port = None
        self.stop_key = base.random_key("tracker")
        for my_port in range(port, port_end):
            try:
                sock.bind((host, my_port))
            except socket.error as sock_err:
                # "Address already in use": try the next port. 98 (Linux) and
                # 48 (macOS) are kept alongside the portable errno constant.
                if sock_err.errno in (errno.EADDRINUSE, 98, 48):
                    continue
                sock.close()
                raise
            # Record the port actually bound (differs from my_port only for
            # port 0, where the OS picks an ephemeral port).
            self.port = sock.getsockname()[1]
            break
        if self.port is None:
            sock.close()
            raise ValueError("cannot bind to any port in [%d, %d)" % (port, port_end))
        logger.info("bind to %s:%d", host, self.port)
        try:
            sock.listen(1)
            self.proc = multiprocessing.Process(
                target=_tracker_server, args=(sock, self.stop_key))
            self.proc.start()
            self.host = host
        finally:
            # close the socket on this process; the tracker process serves it
            sock.close()

    def _stop_tracker(self):
        """Send a STOP request carrying ``stop_key`` to the tracker process.

        Raises
        ------
        RuntimeError
            If the peer does not answer the handshake or the STOP request
            as expected.
        """
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            sock.connect((self.host, self.port))
            sock.sendall(struct.pack("<i", base.RPC_TRACKER_MAGIC))
            magic = struct.unpack("<i", base.recvall(sock, 4))[0]
            if magic != base.RPC_TRACKER_MAGIC:
                raise RuntimeError("%s:%d is not an RPC tracker" % (self.host, self.port))
            base.sendjson(sock, [TrackerCode.STOP, self.stop_key])
            if base.recvjson(sock) != TrackerCode.SUCCESS:
                raise RuntimeError("RPC tracker refused the stop request")
        finally:
            sock.close()

    def terminate(self):
        """Terminate the server process.

        Tries a graceful STOP first and falls back to killing the process.
        Safe to call repeatedly, and on an instance whose ``__init__`` failed.
        """
        proc = getattr(self, "proc", None)
        if proc is None:
            return
        if proc.is_alive():
            try:
                self._stop_tracker()
            except Exception as err:  # pylint: disable=broad-except
                # Cleanup path: a failed graceful stop must not leave the
                # process running, so fall through to a hard terminate.
                logger.warning("Failed to stop tracker gracefully: %s", err)
            proc.join(1)
        if proc.is_alive():
            logger.info("Terminating Tracker Server...")
            proc.terminate()
        self.proc = None

    def __del__(self):
        self.terminate()