server.py 12.6 KB
Newer Older
1 2 3 4 5 6 7 8
"""RPC server implementation.

Note
----
Server is TCP based with the following protocol:
- Initial handshake to the peer
  - [RPC_MAGIC, keysize(int32), key-bytes]
- The key is in format
   - {server|client}:device-type[:random-key] [-timeout=timeout]
"""
from __future__ import absolute_import

import os
14
import ctypes
15 16 17 18 19 20 21 22 23 24
import socket
import select
import struct
import logging
import multiprocessing
import subprocess
import time

from ..._ffi.function import register_func
from ..._ffi.base import py_str
25
from ..._ffi.libinfo import find_lib_path
26
from ...module import load as _load_module
27
from .. import util
28 29 30
from . import base
from . base import TrackerCode

31
def _server_env(load_library):
    """Prepare the serving environment and return its temporary directory.

    Registers the ``tvm.contrib.rpc.server.workpath`` and
    ``tvm.contrib.rpc.server.load_module`` functions, both of which resolve
    paths relative to a fresh ``util.tempdir()``, then loads every library
    named in ``load_library``.

    Parameters
    ----------
    load_library : str or None
        Colon-separated library names; a falsy value loads nothing.

    Returns
    -------
    temp : object returned by ``util.tempdir()``
        The handles of the loaded libraries are stored on its ``libs``
        attribute.
    """
    workdir = util.tempdir()

    # pylint: disable=unused-variable
    @register_func("tvm.contrib.rpc.server.workpath")
    def get_workpath(path):
        return workdir.relpath(path)

    @register_func("tvm.contrib.rpc.server.load_module", override=True)
    def load_module(file_name):
        """Load module from remote side."""
        full_path = workdir.relpath(file_name)
        loaded = _load_module(full_path)
        logging.info("load_module %s", full_path)
        return loaded

    lib_names = load_library.split(":") if load_library else []
    handles = []
    for name in lib_names:
        resolved = find_lib_path(name)[0]
        # RTLD_GLOBAL is passed as the CDLL mode, as in the original call.
        handles.append(ctypes.CDLL(resolved, ctypes.RTLD_GLOBAL))
        logging.info("Load additional library %s", resolved)
    # Store the handles on the temp dir object that is returned to the caller.
    workdir.libs = handles
    return workdir


57
def _serve_loop(sock, addr, load_library):
    """Serve one RPC session on ``sock`` until the native server loop returns.

    Parameters
    ----------
    sock : socket.socket
        Connected socket of the session; only its file descriptor is used.

    addr : object
        Peer address, used for logging only.

    load_library : str or None
        Forwarded to ``_server_env``.
    """
    sockfd = sock.fileno()
    temp = _server_env(load_library)
    try:
        base._ServerLoop(sockfd)
    finally:
        # Always drop the session's temporary directory, even when the
        # server loop raises; otherwise it is left behind on disk.
        temp.remove()
    logging.info("Finish serving %s", addr)


def _parse_server_opt(opts):
    # parse client options
    ret = {}
    for kv in opts:
        if kv.startswith("-timeout="):
            ret["timeout"] = float(kv[9:])
    return ret


75
def _listen_loop(sock, port, rpc_key, tracker_addr, load_library):
    """Listening loop of the server master.

    Runs forever: optionally registers with a tracker, waits for a client
    that presents the expected key, then serves that session in a child
    process (killing it if the client-requested timeout expires).

    Parameters
    ----------
    sock : socket.socket
        The bound, listening socket.

    port : int
        The port reported to the tracker.

    rpc_key : str
        The key of this server.

    tracker_addr : object or None
        Address handed to ``base.connect_with_retry``; a falsy value
        disables tracker reporting.

    load_library : str or None
        Forwarded to ``_serve_loop``.
    """
    def _expect_tracker_success(tracker_conn):
        """Read one tracker reply and fail unless it is SUCCESS."""
        # The read must not live inside an ``assert``: under ``python -O``
        # assert statements are removed, the reply would stay unread and
        # every later message on this connection would be misinterpreted.
        if base.recvjson(tracker_conn) != TrackerCode.SUCCESS:
            raise RuntimeError("RPC tracker did not acknowledge the request")

    def _accept_conn(listen_sock, tracker_conn, ping_period=2):
        """Accept connection from the other places.

        Parameters
        ----------
        listen_sock: Socket
            The socket used by listening process.

        tracker_conn : connection to tracker
            Tracker connection

        ping_period : float, optional
            ping tracker every k seconds if no connection is accepted.

        Returns
        -------
        conn, addr, opts
            The accepted connection, the peer address and the options
            parsed from the remainder of the client key.
        """
        old_keyset = set()
        # Report resource to tracker
        if tracker_conn:
            matchkey = base.random_key(rpc_key + ":")
            base.sendjson(tracker_conn,
                          [TrackerCode.PUT, rpc_key, (port, matchkey)])
            _expect_tracker_success(tracker_conn)
        else:
            matchkey = rpc_key

        # Encode once: the length prefix must count bytes, not characters,
        # otherwise a non-ASCII key announces the wrong payload size.
        server_key = ("server:" + rpc_key).encode("utf-8")

        unmatch_period_count = 0
        unmatch_timeout = 4
        # Wait until we get a valid connection
        while True:
            if tracker_conn:
                trigger = select.select([listen_sock], [], [], ping_period)
                if listen_sock not in trigger[0]:
                    base.sendjson(tracker_conn, [TrackerCode.GET_PENDING_MATCHKEYS])
                    pending_keys = base.recvjson(tracker_conn)
                    old_keyset.add(matchkey)
                    # if match key not in pending key set
                    # it means the key is acquired by a client but not used.
                    if matchkey not in pending_keys:
                        unmatch_period_count += 1
                    else:
                        unmatch_period_count = 0
                    # regenerate match key if key is acquired but not used for a while
                    if unmatch_period_count * ping_period > unmatch_timeout + ping_period:
                        logging.info("RPCServer: no incoming connections, regenerate key ...")
                        matchkey = base.random_key(rpc_key + ":", old_keyset)
                        base.sendjson(tracker_conn,
                                      [TrackerCode.PUT, rpc_key, (port, matchkey)])
                        _expect_tracker_success(tracker_conn)
                        unmatch_period_count = 0
                    continue
            conn, addr = listen_sock.accept()
            magic = struct.unpack("<i", base.recvall(conn, 4))[0]
            if magic != base.RPC_MAGIC:
                conn.close()
                continue
            keylen = struct.unpack("<i", base.recvall(conn, 4))[0]
            key = py_str(base.recvall(conn, keylen))
            arr = key.split()
            expect_header = "client:" + matchkey
            # An empty / whitespace-only key yields an empty list; treat it as
            # a mismatch instead of letting arr[0] raise IndexError and take
            # down the whole listening process.
            if not arr or arr[0] != expect_header:
                conn.sendall(struct.pack("<i", base.RPC_CODE_MISMATCH))
                conn.close()
                logging.info("RPCServer: mismatch key from %s", addr)
                continue
            conn.sendall(struct.pack("<i", base.RPC_CODE_SUCCESS))
            conn.sendall(struct.pack("<i", len(server_key)))
            conn.sendall(server_key)
            return conn, addr, _parse_server_opt(arr[1:])

    # Server logic
    tracker_conn = None
    while True:
        try:
            # step 1: setup tracker and report to tracker
            if tracker_addr and tracker_conn is None:
                tracker_conn = base.connect_with_retry(tracker_addr)
                tracker_conn.sendall(struct.pack("<i", base.RPC_TRACKER_MAGIC))
                magic = struct.unpack("<i", base.recvall(tracker_conn, 4))[0]
                if magic != base.RPC_TRACKER_MAGIC:
                    raise RuntimeError("%s is not RPC Tracker" % str(tracker_addr))
                # report status of current queue
                cinfo = {"key" : "server:" + rpc_key}
                base.sendjson(tracker_conn,
                              [TrackerCode.UPDATE_INFO, cinfo])
                _expect_tracker_success(tracker_conn)

            # step 2: wait for in-coming connections
            conn, addr, opts = _accept_conn(sock, tracker_conn)
        except (socket.error, IOError):
            # retry when tracker is dropped
            if tracker_conn:
                tracker_conn.close()
                tracker_conn = None
            continue

        # step 3: serving
        logging.info("RPCServer: connection from %s", addr)
        server_proc = multiprocessing.Process(target=_serve_loop, args=(conn, addr, load_library))
        # NOTE(review): `deamon` is a misspelling of `daemon`, so this only
        # sets an unused attribute and the child is not made daemonic. Kept
        # as-is to avoid changing process lifetime semantics.
        server_proc.deamon = True
        server_proc.start()
        # close from our side.
        conn.close()
        # wait until server process finish or timeout
        server_proc.join(opts.get("timeout", None))
        if server_proc.is_alive():
            logging.info("RPCServer: Timeout in RPC session, kill..")
            server_proc.terminate()


187
def _connect_proxy_loop(addr, key, load_library):
    """Keep connecting to an RPC proxy and serve the sessions it hands over.

    Each iteration opens a TCP connection to ``addr``, performs the
    handshake with the key ``"server:" + key`` and serves the session in a
    child process, which is killed if the timeout requested in the remote
    key expires.

    Parameters
    ----------
    addr : tuple
        Address passed to ``socket.connect``.

    key : str
        The key identifying this server; ``"server:"`` is prepended.

    load_library : str or None
        Forwarded to ``_serve_loop``.

    Raises
    ------
    RuntimeError
        If the proxy reports a duplicate key, answers with an unexpected
        code, or more than ``max_retry`` consecutive attempts fail with a
        socket / IO error.
    """
    key = "server:" + key
    # Encode once: the length prefix must count bytes, not characters,
    # otherwise a non-ASCII key announces the wrong payload size.
    key_bytes = key.encode("utf-8")
    retry_count = 0
    max_retry = 5
    retry_period = 5
    while True:
        sock = None
        try:
            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            sock.connect(addr)
            sock.sendall(struct.pack("<i", base.RPC_MAGIC))
            sock.sendall(struct.pack("<i", len(key_bytes)))
            sock.sendall(key_bytes)
            magic = struct.unpack("<i", base.recvall(sock, 4))[0]
            if magic == base.RPC_CODE_DUPLICATE:
                raise RuntimeError("key: %s has already been used in proxy" % key)
            elif magic == base.RPC_CODE_MISMATCH:
                # Only logged; execution continues with the reads below,
                # exactly as before.
                logging.info("RPCProxy do not have matching client key %s", key)
            elif magic != base.RPC_CODE_SUCCESS:
                raise RuntimeError("%s is not RPC Proxy" % str(addr))
            keylen = struct.unpack("<i", base.recvall(sock, 4))[0]
            remote_key = py_str(base.recvall(sock, keylen))
            opts = _parse_server_opt(remote_key.split()[1:])
            logging.info("RPCProxy connected to %s", str(addr))
            process = multiprocessing.Process(
                target=_serve_loop, args=(sock, addr, load_library))
            # NOTE(review): `deamon` is a misspelling of `daemon`, so this
            # only sets an unused attribute. Kept as-is to avoid changing
            # process lifetime semantics.
            process.deamon = True
            process.start()
            # The child owns its copy of the socket; close ours before the
            # potentially long join.
            sock.close()
            process.join(opts.get("timeout", None))
            if process.is_alive():
                logging.info("RPCProxyServer: Timeout in RPC session, kill..")
                process.terminate()
            retry_count = 0
        except (socket.error, IOError) as err:
            # Release the socket of the failed attempt; previously it was
            # leaked on every retry.
            if sock is not None:
                sock.close()
            retry_count += 1
            logging.info("Error encountered %s, retry in %g sec", str(err), retry_period)
            if retry_count > max_retry:
                raise RuntimeError("Maximum retry error: last error: %s" % str(err))
            time.sleep(retry_period)
226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268

def _popen(cmd):
    proc = subprocess.Popen(cmd,
                            stdout=subprocess.PIPE,
                            stderr=subprocess.STDOUT,
                            env=os.environ)
    (out, _) = proc.communicate()
    if proc.returncode != 0:
        msg = "Server invoke error:\n"
        msg += out
        raise RuntimeError(msg)


class Server(object):
    """Start RPC server on a separate process.

    This is a simple python implementation based on multi-processing.
    It is also possible to implement a similar C based server with
    TVM runtime which does not depend on the python.

    Parameters
    ----------
    host : str
        The host url of the server.

    port : int
        The port to be bind to

    port_end : int, optional
        The end port to search

    is_proxy : bool, optional
        Whether the address specified is a proxy.
        If this is true, the host and port actually corresponds to the
        address of the proxy server.

    use_popen : bool, optional
        Whether to use Popen to start a fresh new process instead of fork.
        This is recommended to switch on if we want to do local RPC demonstration
        for GPU devices to avoid fork safety issues.

    tracker_addr : tuple (str, int), optional
        The (host, port) address of the RPC tracker to report to.
        When combined with ``use_popen``, ``key`` must be non-empty.

    key : str, optional
        The key used to identify the server in Proxy connection.

    load_library : str, optional
        List of additional libraries to be loaded during execution.

    Raises
    ------
    RuntimeError
        If ``base._ServerLoop`` is ``None`` or cannot be resolved.

    ValueError
        If no port in ``[port, port_end)`` can be bound.
    """
    def __init__(self,
                 host,
                 port=9091,
                 port_end=9199,
                 is_proxy=False,
                 use_popen=False,
                 tracker_addr=None,
                 key="",
                 load_library=None):
        # Set first so that terminate() / __del__ stay safe when the
        # constructor fails before a process has been created.
        self.proc = None
        try:
            if base._ServerLoop is None:
                raise RuntimeError("Please compile with USE_RPC=1")
        except NameError:
            raise RuntimeError("Please compile with USE_RPC=1")
        self.host = host
        self.port = port
        self.libs = []

        # NOTE(review): every `deamon = True` below is a misspelling of
        # `daemon`, so the processes are NOT daemonic. Do not simply correct
        # it: multiprocessing forbids daemonic processes from creating
        # children, and _listen_loop / _connect_proxy_loop start a child
        # process per session.
        if use_popen:
            cmd = ["python",
                   "-m", "tvm.exec.rpc_server",
                   "--host=%s" % host,
                   "--port=%s" % port]
            if tracker_addr:
                assert key
                cmd += ["--tracker=%s:%d" % tracker_addr,
                        "--key=%s" % key]
            if load_library:
                # NOTE(review): "--load-libary" looks like a misspelling of
                # "--load-library"; verify against the options accepted by
                # tvm.exec.rpc_server before changing it.
                cmd += ["--load-libary", load_library]
            self.proc = multiprocessing.Process(
                target=subprocess.check_call, args=(cmd,))
            self.proc.deamon = True
            self.proc.start()
            time.sleep(1)
        elif not is_proxy:
            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            self.port = None
            for my_port in range(port, port_end):
                try:
                    sock.bind((host, my_port))
                    self.port = my_port
                    break
                except socket.error as sock_err:
                    # 98 / 48 are "address already in use" on Linux / macOS:
                    # try the next port in the range.
                    if sock_err.errno in [98, 48]:
                        continue
                    sock.close()
                    raise
            # This check used to sit inside the loop body after the
            # try/except, where it could never run; an exhausted range went
            # unnoticed and the server listened on an unbound socket.
            if self.port is None:
                sock.close()
                raise ValueError("cannot bind to any port in [%d, %d)" % (port, port_end))
            logging.info("RPCServer: bind to %s:%d", host, self.port)
            sock.listen(1)
            self.sock = sock
            self.proc = multiprocessing.Process(
                target=_listen_loop, args=(
                    self.sock, self.port, key, tracker_addr, load_library))
            self.proc.deamon = True
            self.proc.start()
        else:
            self.proc = multiprocessing.Process(
                target=_connect_proxy_loop, args=((host, port), key, load_library))
            self.proc.deamon = True
            self.proc.start()

    def terminate(self):
        """Terminate the server process"""
        # getattr: the attribute is missing when the instance was never
        # initialised, and this method is also reached through __del__.
        proc = getattr(self, "proc", None)
        if proc:
            proc.terminate()
            self.proc = None

    def __del__(self):
        self.terminate()