server.py 14.8 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
17 18 19 20 21 22 23 24
"""RPC server implementation.

Note
----
Server is TCP based with the following protocol:
- Initial handshake to the peer
  - [RPC_MAGIC, keysize(int32), key-bytes]
- The key is in format
  - {server|client}:device-type[:random-key] [-timeout=timeout]
"""
27 28
# pylint: disable=invalid-name

29 30 31
from __future__ import absolute_import

import os
32
import ctypes
33 34 35 36 37 38 39
import socket
import select
import struct
import logging
import multiprocessing
import subprocess
import time
40
import sys
41
import signal
42

43 44 45 46 47
from .._ffi.function import register_func
from .._ffi.base import py_str
from .._ffi.libinfo import find_lib_path
from ..module import load as _load_module
from ..contrib import util
48 49 50
from . import base
from . base import TrackerCode

51 52 53
# Module-wide logger used throughout this file; Server(silent=True) raises
# its level to logging.ERROR.
logger = logging.getLogger('RPCServer')

def _server_env(load_library):
    """Set up the environment for one RPC serving session.

    Creates a temporary working directory, registers the
    ``tvm.rpc.server.workpath`` and ``tvm.rpc.server.load_module`` global
    functions so that they resolve paths inside it, and loads any extra
    shared libraries with ``RTLD_GLOBAL``.

    Parameters
    ----------
    load_library : str or None
        Colon-separated list of library names. Each is resolved with
        ``find_lib_path`` and loaded via ``ctypes.CDLL``.

    Returns
    -------
    temp
        The object returned by ``util.tempdir()``. The loaded ``CDLL``
        handles are attached to it as ``temp.libs``.
    """
    temp = util.tempdir()

    # pylint: disable=unused-variable
    # Both hooks use override=True, so calling this function again in the
    # same process replaces the earlier registration. Without it,
    # register_func is expected to reject the duplicate name (confirm
    # against the FFI layer). This matches load_module below.
    @register_func("tvm.rpc.server.workpath", override=True)
    def get_workpath(path):
        """Map a remote-relative path into the session temp directory."""
        return temp.relpath(path)

    @register_func("tvm.rpc.server.load_module", override=True)
    def load_module(file_name):
        """Load module from remote side."""
        path = temp.relpath(file_name)
        m = _load_module(path)
        logger.info("load_module %s", path)
        return m

    libs = []
    lib_names = load_library.split(":") if load_library else []
    for lib_name in lib_names:
        lib_path = find_lib_path(lib_name)[0]
        libs.append(ctypes.CDLL(lib_path, ctypes.RTLD_GLOBAL))
        logger.info("Load additional library %s", lib_path)
    # Presumably kept so the CDLL handles stay referenced for the session.
    temp.libs = libs
    return temp


def _serve_loop(sock, addr, load_library):
    """Serve one RPC session on an already-handshaken connection.

    Parameters
    ----------
    sock : socket.socket
        Connected socket; its file descriptor is handed to the native loop.
    addr : tuple
        Address associated with the connection, used only for logging.
    load_library : str or None
        Extra libraries to load, forwarded to ``_server_env``.
    """
    sockfd = sock.fileno()
    temp = _server_env(load_library)
    try:
        base._ServerLoop(sockfd)
    finally:
        # Remove the session temp dir even if the native loop raises.
        temp.remove()
    logger.info("Finish serving %s", addr)


def _parse_server_opt(opts):
    """Parse the option tokens that follow the key in a handshake.

    Parameters
    ----------
    opts : list of str
        Whitespace-separated tokens after the key, e.g. ``["-timeout=10"]``.

    Returns
    -------
    dict
        Contains ``"timeout"`` (float) when a ``-timeout=`` token is present;
        the last such token wins. Other tokens are ignored.
    """
    prefix = "-timeout="
    parsed = {}
    for token in opts:
        if not token.startswith(prefix):
            continue
        parsed["timeout"] = float(token[len(prefix):])
    return parsed


def _listen_loop(sock, port, rpc_key, tracker_addr, load_library, custom_addr):
    """Listening loop of the server master.

    Registers with the tracker when one is configured. It then accepts a
    client whose handshake key matches and serves it in a child process,
    killing the child if the client-requested timeout elapses. The loop
    never returns.

    Parameters
    ----------
    sock : socket.socket
        Bound, listening socket.
    port : int
        Port reported to the tracker.
    rpc_key : str
        Device key. Clients must present ``client:<matchkey>``.
    tracker_addr : tuple or None
        ``(host, port)`` of the tracker, or None to run without one.
    load_library : str or None
        Extra libraries forwarded to each serving process.
    custom_addr : str or None
        Custom address reported to the tracker.
    """
    def _check_tracker_ack(tracker_conn):
        """Raise RuntimeError unless the tracker replied SUCCESS."""
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        if base.recvjson(tracker_conn) != TrackerCode.SUCCESS:
            raise RuntimeError(
                "RPC Tracker %s did not acknowledge request" % str(tracker_addr))

    def _report_to_tracker(tracker_conn, matchkey):
        """Advertise ``(port, matchkey)`` under ``rpc_key`` to the tracker."""
        base.sendjson(tracker_conn,
                      [TrackerCode.PUT, rpc_key, (port, matchkey), custom_addr])
        _check_tracker_ack(tracker_conn)

    def _handshake(conn, addr, matchkey):
        """Validate a client's handshake on ``conn``.

        Returns the parsed client options (possibly an empty dict) on
        success. Returns None when the client is rejected; the caller
        must then close ``conn``.
        """
        magic = struct.unpack("<i", base.recvall(conn, 4))[0]
        if magic != base.RPC_MAGIC:
            return None
        keylen = struct.unpack("<i", base.recvall(conn, 4))[0]
        if keylen < 0:
            logger.warning("invalid key length %d from %s", keylen, addr)
            return None
        key_bytes = base.recvall(conn, keylen)
        try:
            arr = py_str(key_bytes).split()
        except ValueError:
            # Presumably py_str decodes UTF-8; UnicodeDecodeError is a
            # ValueError. Reject this client instead of killing the listener.
            logger.warning("undecodable key from %s", addr)
            return None
        # An empty key must be rejected; indexing arr[0] would crash the loop.
        if not arr or arr[0] != "client:" + matchkey:
            conn.sendall(struct.pack("<i", base.RPC_CODE_MISMATCH))
            logger.warning("mismatch key from %s", addr)
            return None
        try:
            opts = _parse_server_opt(arr[1:])
        except ValueError:
            logger.warning("invalid options %s from %s", arr[1:], addr)
            return None
        # Length prefix must count encoded bytes, not characters.
        server_key = ("server:" + rpc_key).encode("utf-8")
        conn.sendall(struct.pack("<i", base.RPC_CODE_SUCCESS))
        conn.sendall(struct.pack("<i", len(server_key)))
        conn.sendall(server_key)
        return opts

    def _accept_conn(listen_sock, tracker_conn, ping_period=2):
        """Accept connection from the other places.

        Parameters
        ----------
        listen_sock: Socket
            The socket used by listening process.

        tracker_conn : connection to tracker
            Tracker connection

        ping_period : float, optional
            ping tracker every k seconds if no connection is accepted.

        Returns
        -------
        (conn, addr, opts)
            The accepted socket, its address and the parsed client options.
        """
        old_keyset = set()
        # Report resource to tracker
        if tracker_conn:
            matchkey = base.random_key(rpc_key + ":")
            _report_to_tracker(tracker_conn, matchkey)
        else:
            matchkey = rpc_key

        unmatch_period_count = 0
        unmatch_timeout = 4
        # Wait until we get a valid connection
        while True:
            if tracker_conn:
                trigger = select.select([listen_sock], [], [], ping_period)
                if listen_sock not in trigger[0]:
                    base.sendjson(tracker_conn, [TrackerCode.GET_PENDING_MATCHKEYS])
                    pending_keys = base.recvjson(tracker_conn)
                    old_keyset.add(matchkey)
                    # if match key not in pending key set
                    # it means the key is acquired by a client but not used.
                    if matchkey not in pending_keys:
                        unmatch_period_count += 1
                    else:
                        unmatch_period_count = 0
                    # regenerate match key if key is acquired but not used for a while
                    if unmatch_period_count * ping_period > unmatch_timeout + ping_period:
                        logger.info("no incoming connections, regenerate key ...")
                        matchkey = base.random_key(rpc_key + ":", old_keyset)
                        _report_to_tracker(tracker_conn, matchkey)
                        unmatch_period_count = 0
                    continue
            conn, addr = listen_sock.accept()
            try:
                opts = _handshake(conn, addr, matchkey)
            except (socket.error, IOError) as err:
                # A client dropping mid-handshake says nothing about the
                # tracker connection; discard just this client.
                logger.warning("handshake with %s failed: %s", addr, str(err))
                opts = None
            if opts is None:
                conn.close()
                continue
            return conn, addr, opts

    # Server logic
    tracker_conn = None
    while True:
        try:
            # step 1: setup tracker and report to tracker
            if tracker_addr and tracker_conn is None:
                tracker_conn = base.connect_with_retry(tracker_addr)
                tracker_conn.sendall(struct.pack("<i", base.RPC_TRACKER_MAGIC))
                magic = struct.unpack("<i", base.recvall(tracker_conn, 4))[0]
                if magic != base.RPC_TRACKER_MAGIC:
                    raise RuntimeError("%s is not RPC Tracker" % str(tracker_addr))
                # report status of current queue
                cinfo = {"key": "server:" + rpc_key}
                base.sendjson(tracker_conn, [TrackerCode.UPDATE_INFO, cinfo])
                _check_tracker_ack(tracker_conn)

            # step 2: wait for in-coming connections
            conn, addr, opts = _accept_conn(sock, tracker_conn)
        except (socket.error, IOError):
            # retry when tracker is dropped
            if tracker_conn:
                tracker_conn.close()
                tracker_conn = None
            continue

        # step 3: serving
        logger.info("connection from %s", addr)
        # Deliberately not daemonic: a daemonic process may not start child
        # processes, which would restrict what an RPC session can run.
        server_proc = multiprocessing.Process(target=_serve_loop,
                                              args=(conn, addr, load_library))
        server_proc.start()
        # close from our side.
        conn.close()
        # wait until server process finish or timeout
        server_proc.join(opts.get("timeout", None))
        if server_proc.is_alive():
            logger.info("Timeout in RPC session, kill..")
            server_proc.terminate()


def _connect_proxy_loop(addr, key, load_library):
    """Serve RPC sessions through an RPC proxy.

    Repeatedly connects to the proxy at ``addr``, announcing ``server:<key>``.
    Each session is served in a child process, which is killed if the
    remote-requested timeout elapses. Socket errors are retried every
    ``retry_period`` seconds. More than ``max_retry`` consecutive failures
    raise RuntimeError.

    Parameters
    ----------
    addr : tuple
        ``(host, port)`` of the proxy.
    key : str
        Device key; sent as ``server:<key>``.
    load_library : str or None
        Extra libraries forwarded to each serving process.

    Raises
    ------
    RuntimeError
        On a duplicate key, a non-proxy peer, or too many consecutive retries.
    """
    key = "server:" + key
    retry_count = 0
    max_retry = 5
    retry_period = 5
    while True:
        sock = None
        try:
            sock = socket.socket(base.get_addr_family(addr), socket.SOCK_STREAM)
            sock.connect(addr)
            # Length prefix must count encoded bytes, not characters.
            key_bytes = key.encode("utf-8")
            sock.sendall(struct.pack("<i", base.RPC_MAGIC))
            sock.sendall(struct.pack("<i", len(key_bytes)))
            sock.sendall(key_bytes)
            magic = struct.unpack("<i", base.recvall(sock, 4))[0]
            if magic == base.RPC_CODE_DUPLICATE:
                raise RuntimeError("key: %s has already been used in proxy" % key)
            elif magic == base.RPC_CODE_MISMATCH:
                logger.warning("RPCProxy do not have matching client key %s", key)
            elif magic != base.RPC_CODE_SUCCESS:
                raise RuntimeError("%s is not RPC Proxy" % str(addr))
            keylen = struct.unpack("<i", base.recvall(sock, 4))[0]
            remote_key = py_str(base.recvall(sock, keylen))
            opts = _parse_server_opt(remote_key.split()[1:])
            logger.info("connected to %s", str(addr))
            # Deliberately not daemonic: a daemonic process may not start
            # child processes of its own.
            process = multiprocessing.Process(
                target=_serve_loop, args=(sock, addr, load_library))
            process.start()
            # The child owns the connection now; drop the parent's copy.
            sock.close()
            process.join(opts.get("timeout", None))
            if process.is_alive():
                logger.info("Timeout in RPC session, kill..")
                process.terminate()
            retry_count = 0
        except (socket.error, IOError) as err:
            retry_count += 1
            logger.warning("Error encountered %s, retry in %g sec", str(err), retry_period)
            if retry_count > max_retry:
                raise RuntimeError("Maximum retry error: last error: %s" % str(err))
            time.sleep(retry_period)
        finally:
            # Release the socket on every path; closing twice is a no-op.
            if sock is not None:
                sock.close()


def _popen(cmd):
    """Run ``cmd`` to completion, raising if it exits with a non-zero status.

    Parameters
    ----------
    cmd : list of str
        Command and arguments. They are executed without a shell, using the
        current ``os.environ``.

    Raises
    ------
    RuntimeError
        If the command returns non-zero. The message includes the combined
        stdout/stderr output.
    """
    proc = subprocess.Popen(cmd,
                            stdout=subprocess.PIPE,
                            stderr=subprocess.STDOUT,
                            env=os.environ)
    (out, _) = proc.communicate()
    if proc.returncode != 0:
        # communicate() yields bytes on Python 3. Appending bytes to the str
        # message would raise TypeError and hide the real failure.
        if isinstance(out, bytes):
            out = out.decode("utf-8", "replace")
        msg = "Server invoke error:\n"
        msg += out
        raise RuntimeError(msg)


class Server(object):
    """Start RPC server on a separate process.

    This is a simple python implementation based on multi-processing.
    It is also possible to implement a similar C based server with
    TVM runtime which does not depend on the python.

    Parameters
    ----------
    host : str
        The host url of the server.

    port : int
        The port to be bind to

    port_end : int, optional
        The end port to search (exclusive).

    is_proxy : bool, optional
        Whether the address specified is a proxy.
        If this is true, the host and port actually corresponds to the
        address of the proxy server.

    use_popen : bool, optional
        Whether to use Popen to start a fresh new process instead of fork.
        This is recommended to switch on if we want to do local RPC demonstration
        for GPU devices to avoid fork safety issues.

    tracker_addr: Tuple (str, int) , optional
        The address of RPC Tracker in tuple(host, port) format.
        If is not None, the server will register itself to the tracker.

    key : str, optional
        The key used to identify the device type in tracker.
        Required when ``use_popen`` and ``tracker_addr`` are both set.

    load_library : str, optional
        Colon-separated list of additional libraries to be loaded during execution.

    custom_addr: str, optional
        Custom IP Address to Report to RPC Tracker

    silent: bool, optional
        Whether run this server in silent mode.
    """
    def __init__(self,
                 host,
                 port=9091,
                 port_end=9199,
                 is_proxy=False,
                 use_popen=False,
                 tracker_addr=None,
                 key="",
                 load_library=None,
                 custom_addr=None,
                 silent=False):
        # Set first so terminate()/__del__ are safe if construction fails below.
        self.proc = None
        self.use_popen = use_popen
        # A missing attribute raises AttributeError, not NameError, so probe
        # with getattr instead of try/except.
        if getattr(base, "_ServerLoop", None) is None:
            raise RuntimeError("Please compile with USE_RPC=1")
        self.host = host
        self.port = port
        self.libs = []
        self.custom_addr = custom_addr

        if silent:
            logger.setLevel(logging.ERROR)

        if use_popen:
            cmd = [sys.executable,
                   "-m", "tvm.exec.rpc_server",
                   "--host=%s" % host,
                   "--port=%s" % port]
            if tracker_addr:
                # Explicit check instead of ``assert`` so it survives ``-O``.
                if not key:
                    raise ValueError("key must be provided when tracker_addr is set")
                cmd += ["--tracker=%s:%d" % tracker_addr,
                        "--key=%s" % key]
            if load_library:
                cmd += ["--load-library", load_library]
            if custom_addr:
                cmd += ["--custom-addr", custom_addr]
            if silent:
                cmd += ["--silent"]

            # preexec_fn is not thread safe and may result in deadlock.
            # python 3.2 introduced the start_new_session parameter as
            # an alternative to the common use case of
            # preexec_fn=os.setsid.  Once the minimum version of python
            # supported by TVM reaches python 3.2 this code can be
            # rewritten in favour of start_new_session.  In the
            # interim, stop the pylint diagnostic.
            #
            # pylint: disable=subprocess-popen-preexec-fn
            self.proc = subprocess.Popen(cmd, preexec_fn=os.setsid)
            time.sleep(0.5)
        elif not is_proxy:
            sock = socket.socket(base.get_addr_family((host, port)), socket.SOCK_STREAM)
            self.port = None
            for my_port in range(port, port_end):
                try:
                    sock.bind((host, my_port))
                except socket.error as sock_err:
                    # 98 / 48 are EADDRINUSE on Linux / macOS: try the next port.
                    if sock_err.errno in (98, 48):
                        continue
                    sock.close()
                    raise sock_err
                # Read back the bound port so that port=0 (OS-assigned)
                # reports the real port; otherwise this equals my_port.
                self.port = sock.getsockname()[1]
                break
            if self.port is None:
                sock.close()
                raise ValueError("cannot bind to any port in [%d, %d)" % (port, port_end))
            logger.info("bind to %s:%d", host, self.port)
            sock.listen(1)
            self.sock = sock
            # Not daemonic: _listen_loop starts a child process per session,
            # which a daemonic process is not allowed to do.
            self.proc = multiprocessing.Process(
                target=_listen_loop, args=(
                    self.sock, self.port, key, tracker_addr, load_library,
                    self.custom_addr))
            self.proc.start()
        else:
            # Not daemonic: _connect_proxy_loop also starts child processes.
            self.proc = multiprocessing.Process(
                target=_connect_proxy_loop, args=((host, port), key, load_library))
            self.proc.start()

    def terminate(self):
        """Terminate the server process.

        Safe to call more than once, and on a partially constructed instance.
        """
        proc = getattr(self, "proc", None)
        if proc is None:
            return
        if self.use_popen:
            # The child was started with preexec_fn=os.setsid, so signal its
            # whole process group to reach any processes it spawned.
            os.killpg(proc.pid, signal.SIGTERM)
        else:
            proc.terminate()
        self.proc = None

    def __del__(self):
        self.terminate()