tracker.py 13.8 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41
"""RPC Tracker, tracks and distributes the TVM RPC resources.

This module implements the tracker server logic.

Note
----
Tracker is a TCP based rest api with the following protocol:
- Initial handshake to the peer
  - RPC_TRACKER_MAGIC
- Normal message: [size(int32), json-data]
- Each message is initiated by the client, and the tracker replies with a json.

List of available APIs:

- PING: check if tracker is alive
  - input: [TrackerCode.PING]
  - return: TrackerCode.SUCCESS
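- PUT: report resource to tracker
  - input: [TrackerCode.PUT, key, [port, match-key], custom-addr (optional)]
  - return: TrackerCode.SUCCESS
  - note: match-key is a randomly generated key that identifies the resource during connection.
- REQUEST: request a new resource from tracker
  - input: [TrackerCode.REQUEST, key, user, priority]
  - return: [TrackerCode.SUCCESS, [url, port, match-key]]
- GET_PENDING_MATCHKEYS: list match-keys reported on this connection that are not yet consumed
  - input: [TrackerCode.GET_PENDING_MATCHKEYS]
  - return: list of match-keys
- STOP: safely stop the tracker
  - input: [TrackerCode.STOP, stop-key]
  - return: TrackerCode.SUCCESS, or TrackerCode.FAIL if the stop-key does not match
- UPDATE_INFO: merge a dict into this connection's info
  - input: [TrackerCode.UPDATE_INFO, info-dict]
  - return: TrackerCode.SUCCESS
- SUMMARY: query the tracker status
  - input: [TrackerCode.SUMMARY]
  - return: [TrackerCode.SUCCESS, summary-dict]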
"""
42 43
# pylint: disable=invalid-name

44 45 46 47 48 49 50 51 52 53 54 55 56 57
import heapq
import time
import logging
import socket
import multiprocessing
import errno
import struct
import json

try:
    from tornado import ioloop
    from . import tornado_util
except ImportError as error_msg:
    raise ImportError(
58
        "RPCTracker module requires tornado package %s. Try 'pip install tornado'." % error_msg)
59

60
from .._ffi.base import py_str
61 62 63
from . import base
from .base import RPC_TRACKER_MAGIC, TrackerCode

64
# Logger shared by the tracker code in this module.
logger = logging.getLogger(name="RPCTracker")
65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96

class Scheduler(object):
    """Abstract interface of scheduler."""
    def put(self, value):
        """Push a resource into the scheduler.

        This function can trigger callbacks in the scheduler.

        Parameters
        ----------
        value : object
            The resource to be put in the scheduler.
        """
        raise NotImplementedError()

    def request(self, user, priority, callback):
        """Request a resource.

        Parameters
        ----------
        user : str
            The user who is requesting the resource.

        priority : int
            The job priority

        callback : function: value->bool
            Callback function to receive an resource when ready
            returns True if the resource is consumed.
        """
        raise NotImplementedError()

    def remove(self, value):
        """Remove a resource in the scheduler.

        The base implementation is a no-op; schedulers that hold
        resources override it.

        Parameters
        ----------
        value: object
            The resource to remove
        """

    def summary(self):
        """Get summary information of the scheduler."""
        raise NotImplementedError()


class PriorityScheduler(Scheduler):
    """Priority based scheduler, FIFO based on time.

    Requests with a higher ``priority`` are served first; requests with
    equal priority are served in arrival order. Free resources are handed
    out in the order they were put.

    A resource value is a tuple ``(conn, addr, port, matchkey)`` (as built
    by TCPEventHandler for PUT); ``conn.pending_matchkeys`` is updated when
    the resource is consumed, and the request callback receives
    ``value[1:]``.
    """

    def __init__(self, key):
        self._key = key
        self._values = []
        self._requests = []
        # Monotonic sequence number used as a heap tie-breaker. Without it,
        # two requests with equal priority and equal time.time() make heapq
        # compare the callbacks themselves, which raises TypeError.
        self._request_seq = 0

    def _schedule(self):
        """Match pending requests with free resources while both exist."""
        while self._requests and self._values:
            value = self._values.pop(0)
            item = heapq.heappop(self._requests)
            callback = item[-1]
            if callback(value[1:]):
                # discard rather than remove: the same match-key may have
                # been reported more than once but is tracked once in the set.
                value[0].pending_matchkeys.discard(value[-1])
            else:
                # The callback declined the resource (e.g. the requester's
                # connection is gone); keep it for the next request.
                self._values.append(value)

    def put(self, value):
        """Add a free resource and try to serve pending requests."""
        self._values.append(value)
        self._schedule()

    def request(self, user, priority, callback):
        """Queue a request and try to serve it immediately.

        ``user`` is accepted for interface compatibility but not used.
        """
        self._request_seq += 1
        heapq.heappush(self._requests,
                       (-priority, time.time(), self._request_seq, callback))
        self._schedule()

    def remove(self, value):
        """Withdraw a free resource; values not held here are ignored."""
        if value in self._values:
            self._values.remove(value)
            self._schedule()

    def summary(self):
        """Get summary information of the scheduler."""
        return {"free": len(self._values),
                "pending": len(self._requests)}

147 148 149 150 151 152 153 154 155 156 157 158 159 160 161

class TCPEventHandler(tornado_util.TCPHandler):
    """Base asynchronous message handler.

    The tracker and client follows a simple message protocol.
    The message is in form [nbytes(int32)] [json-str].
    All the information is packed in json-str
    """
    def __init__(self, tracker, sock, addr):
        super(TCPEventHandler, self).__init__(sock)
        self._data = bytearray()
        self._tracker = tracker
        self._msg_size = 0
        self._addr = addr
        # Non-zero until the 4-byte magic handshake has been received.
        self._init_req_nbytes = 4
        self._info = {"addr": addr}
        # list of pending match keys that has not been used.
        self.pending_matchkeys = set()
        self._tracker._connections.add(self)
        # Resource values this connection reported via PUT; the tracker
        # withdraws them from its schedulers when the connection closes.
        self.put_values = []

    def name(self):
        """name of connection"""
        return "TCPSocket: %s" % str(self._addr)

    def summary(self):
        """Summary of this connection"""
        return self._info

    def _init_conn(self, message):
        """Initialize the connection from the client's handshake.

        A handshake that is not exactly the 4-byte RPC_TRACKER_MAGIC closes
        the connection and stops; otherwise the magic is echoed back and the
        handler switches to size-prefixed JSON messages.
        """
        if len(message) != 4:
            logger.warning("Invalid connection from %s", self.name())
            self.close()
            return
        magic = struct.unpack('<i', message)[0]
        if magic != RPC_TRACKER_MAGIC:
            logger.warning("Invalid magic from %s", self.name())
            self.close()
            return
        self.write_message(struct.pack('<i', RPC_TRACKER_MAGIC), binary=True)
        self._init_req_nbytes = 0

    def on_message(self, message):
        """Callback when a message is received.

        Bytes are buffered until a full [size][json] frame is available;
        each complete frame is decoded and dispatched to ``call_handler``.
        Malformed frames (non-positive size, invalid UTF-8 or JSON) close
        the connection.

        Parameters
        ----------
        message : bytearray
            The bytes received
        """
        assert isinstance(message, bytes)
        if self._init_req_nbytes:
            self._init_conn(message)
            return

        self._data += message

        while True:
            if self._msg_size == 0:
                if len(self._data) < 4:
                    return
                self._msg_size = struct.unpack('<i', self._data[:4])[0]
                if self._msg_size <= 0:
                    # A non-positive size can never be satisfied: a zero size
                    # would stall the stream forever and a negative one would
                    # slice backwards, so reject the peer.
                    logger.warning("Invalid message size %d from %s",
                                   self._msg_size, self.name())
                    self.close()
                    return
            if len(self._data) < self._msg_size + 4:
                return
            raw = bytes(self._data[4:4 + self._msg_size])
            del self._data[:4 + self._msg_size]
            self._msg_size = 0
            try:
                args = json.loads(py_str(raw))
            except ValueError:
                # UnicodeDecodeError and JSON decode errors are ValueErrors.
                logger.warning("Invalid json message from %s", self.name())
                self.close()
                return
            self.call_handler(args)

    def ret_value(self, data):
        """Send ``data`` to the peer as a size-prefixed JSON message."""
        payload = json.dumps(data).encode("utf-8")
        # The size prefix must count encoded bytes, not characters.
        self.write_message(struct.pack('<i', len(payload)), binary=True)
        self.write_message(payload, binary=True)

    def call_handler(self, args):
        """Event handler when json request arrives.

        ``args[0]`` is a TrackerCode; the remaining items depend on it.
        """
        code = args[0]
        if code == TrackerCode.PUT:
            key = args[1]
            port, matchkey = args[2]
            self.pending_matchkeys.add(matchkey)
            # got custom address (from rpc server)
            if len(args) >= 4 and args[3] is not None:
                value = (self, args[3], port, matchkey)
            else:
                value = (self, self._addr[0], port, matchkey)
            self._tracker.put(key, value)
            self.put_values.append(value)
            self.ret_value(TrackerCode.SUCCESS)
        elif code == TrackerCode.REQUEST:
            key = args[1]
            user = args[2]
            priority = args[3]

            def _cb(value):
                # if the connection is already closed
                if not self._sock:
                    return False
                try:
                    self.ret_value([TrackerCode.SUCCESS, value])
                except (socket.error, IOError):
                    return False
                return True

            self._tracker.request(key, user, priority, _cb)
        elif code == TrackerCode.PING:
            self.ret_value(TrackerCode.SUCCESS)
        elif code == TrackerCode.GET_PENDING_MATCHKEYS:
            self.ret_value(list(self.pending_matchkeys))
        elif code == TrackerCode.STOP:
            # safe stop tracker
            if self._tracker._stop_key == args[1]:
                self.ret_value(TrackerCode.SUCCESS)
                self._tracker.stop()
            else:
                self.ret_value(TrackerCode.FAIL)
        elif code == TrackerCode.UPDATE_INFO:
            self._info.update(args[1])
            self.ret_value(TrackerCode.SUCCESS)
        elif code == TrackerCode.SUMMARY:
            status = self._tracker.summary()
            self.ret_value([TrackerCode.SUCCESS, status])
        else:
            # %s, not %d: the code comes from the peer and may not be an int.
            logger.warning("Unknown tracker code %s", code)
            self.close()

    def on_close(self):
        """Unregister this connection from the tracker."""
        self._tracker.close(self)

    def on_error(self, err):
        """Log the error and close the connection."""
        logger.warning("%s: Error in RPC Tracker: %s", self.name(), err)
        self.close()


class TrackerServerHandler(object):
    """Tracker that tracks the resources."""
    def __init__(self, sock, stop_key):
        self._scheduler_map = {}
        self._sock = sock
        self._sock.setblocking(0)
        self._ioloop = ioloop.IOLoop.current()
        self._stop_key = stop_key
        self._connections = set()

        def _event_handler(_, events):
            self._on_event(events)

        self._ioloop.add_handler(
            self._sock.fileno(), _event_handler, self._ioloop.READ)

    def _on_event(self, _):
        """Accept every pending connection on the listening socket."""
        while True:
            try:
                conn, addr = self._sock.accept()
                TCPEventHandler(self, conn, addr)
            except socket.error as err:
                code = err.args[0] if err.args else None
                if code in (errno.EAGAIN, errno.EWOULDBLOCK):
                    break
                if code in (errno.ECONNABORTED, errno.EINTR):
                    # Transient: the next accept() may succeed.
                    continue
                # Any other failure (e.g. EMFILE) would repeat on every
                # retry and spin forever, starving the IO loop; log it and
                # wait for the next readiness event instead.
                logger.warning("Failed to accept connection: %s", err)
                break

    def create_scheduler(self, key):
        """Create a new scheduler."""
        return PriorityScheduler(key)

    def _get_scheduler(self, key):
        """Return the scheduler for ``key``, creating it on first use."""
        if key not in self._scheduler_map:
            self._scheduler_map[key] = self.create_scheduler(key)
        return self._scheduler_map[key]

    def put(self, key, value):
        """Report a new resource to the tracker."""
        self._get_scheduler(key).put(value)

    def request(self, key, user, priority, callback):
        """Request a new resource."""
        self._get_scheduler(key).request(user, priority, callback)

    def close(self, conn):
        """Forget a closed connection and withdraw the resources it reported.

        Safe to call more than once for the same connection.
        """
        self._connections.discard(conn)
        # The PUT key each value was registered under is not stored with the
        # value, and the connection's info key ("server:<name>") need not
        # match a scheduler key, so ask every scheduler to drop the value.
        for value in conn.put_values:
            for scheduler in self._scheduler_map.values():
                scheduler.remove(value)

    def stop(self):
        """Safely stop tracker."""
        for conn in list(self._connections):
            conn.close()
        self._sock.close()
        self._ioloop.stop()

    def summary(self):
        """Return a dict summarizing current status.

        ``queue_info`` maps each scheduler key to its summary;
        ``server_info`` lists the info dicts of connections whose key starts
        with "server".
        """
        qinfo = {key: sched.summary()
                 for key, sched in self._scheduler_map.items()}
        cinfo = []
        # ignore client connections without key
        for conn in self._connections:
            res = conn.summary()
            if res.get("key", "").startswith("server"):
                cinfo.append(res)
        return {"queue_info": qinfo, "server_info": cinfo}

    def run(self):
        """Run the tracker server"""
        self._ioloop.start()

def _tracker_server(listen_sock, stop_key):
    """Process entry point: serve on ``listen_sock`` until the IO loop stops."""
    TrackerServerHandler(listen_sock, stop_key).run()


class Tracker(object):
    """Start RPC tracker on a separate process.

    Python implementation based on multi-processing.

    Parameters
    ----------
    host : str
        The host url of the server.

    port : int
        The first TCP port to try to bind to.

    port_end : int, optional
        The end TCP port to search (exclusive).

    silent: bool, optional
        Whether run in silent mode
    """
    def __init__(self,
                 host,
                 port=9190,
                 port_end=9199,
                 silent=False):
        if silent:
            logger.setLevel(logging.WARNING)

        # Set before anything can fail so terminate()/__del__ stay safe.
        self.proc = None
        self.host = host
        self.port = None
        # "Address already in use": 98 on Linux, 48 on macOS, plus the
        # platform's own EADDRINUSE and Windows' WSAEADDRINUSE.
        addr_in_use = {98, 48, errno.EADDRINUSE,
                       getattr(errno, "WSAEADDRINUSE", errno.EADDRINUSE)}
        sock = socket.socket(base.get_addr_family((host, port)), socket.SOCK_STREAM)
        self.stop_key = base.random_key("tracker")
        try:
            for my_port in range(port, port_end):
                try:
                    sock.bind((host, my_port))
                except socket.error as sock_err:
                    if sock_err.errno in addr_in_use:
                        continue
                    raise
                # Read the port back so that binding to port 0 reports the
                # port actually assigned.
                self.port = sock.getsockname()[1]
                break
            if self.port is None:
                raise ValueError("cannot bind to any port in [%d, %d)" % (port, port_end))
            logger.info("bind to %s:%d", host, self.port)
            sock.listen(1)
            self.proc = multiprocessing.Process(
                target=_tracker_server, args=(sock, self.stop_key))
            self.proc.start()
        finally:
            # The tracker process was handed the socket via Process args;
            # close this process's copy, and never leak it on failure.
            sock.close()

    def _stop_tracker(self):
        """Ask the tracker process to shut down with a STOP request."""
        sock = socket.socket(base.get_addr_family((self.host, self.port)), socket.SOCK_STREAM)
        try:
            sock.connect((self.host, self.port))
            sock.sendall(struct.pack("<i", base.RPC_TRACKER_MAGIC))
            magic = struct.unpack("<i", base.recvall(sock, 4))[0]
            assert magic == base.RPC_TRACKER_MAGIC
            base.sendjson(sock, [TrackerCode.STOP, self.stop_key])
            assert base.recvjson(sock) == TrackerCode.SUCCESS
        finally:
            sock.close()

    def terminate(self):
        """Terminate the server process.

        Tries a graceful STOP first; if that fails or the process is still
        alive after one second, the process is terminated forcibly.
        """
        proc = getattr(self, "proc", None)
        if proc:
            if proc.is_alive():
                try:
                    self._stop_tracker()
                except (socket.error, IOError, AssertionError) as err:
                    # Fall through to the forced terminate below.
                    logger.warning("Failed to stop tracker gracefully: %s", err)
                proc.join(1)
            if proc.is_alive():
                logger.info("Terminating Tracker Server...")
                proc.terminate()
            self.proc = None

    def __del__(self):
        self.terminate()