# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""RPC server implementation.

Note
----
Server is TCP based with the following protocol:
- Initial handshake to the peer
  - [RPC_MAGIC, keysize(int32), key-bytes]
- The key is in format
   - {server|client}:device-type[:random-key] [-timeout=timeout]
"""
# pylint: disable=invalid-name
import os
import ctypes
import socket
import select
import struct
import logging
import multiprocessing
import subprocess
import time
import sys
import signal
import platform

import tvm._ffi
from tvm._ffi.base import py_str
from tvm._ffi.libinfo import find_lib_path
from tvm.runtime.module import load_module as _load_module
from tvm.contrib import util

from . import base
from . base import TrackerCode

# Module-level logger used by the RPC server functions in this file.
logger = logging.getLogger('RPCServer')

def _server_env(load_library, work_path=None):
    """Set up the serving environment and return its working directory.

    Registers the ``tvm.rpc.server.workpath`` and
    ``tvm.rpc.server.load_module`` global functions so that both resolve
    paths relative to the working directory, then loads any additional
    shared libraries.

    Parameters
    ----------
    load_library : str or None
        Colon-separated names of extra libraries to load. Each name is
        resolved with ``find_lib_path`` and opened with ``RTLD_GLOBAL``.

    work_path : object, optional
        Working directory object to use; it must provide ``relpath``.
        When it is omitted (or falsy) a new ``util.tempdir()`` is created.

    Returns
    -------
    temp : object
        The working directory object. The handles of the loaded libraries
        are stored on it as ``temp.libs``.
    """
    temp = work_path if work_path else util.tempdir()

    # Both registrations use override=True so that calling this function
    # again in the same process replaces the callbacks instead of failing
    # on the first one.
    # pylint: disable=unused-variable
    @tvm._ffi.register_func("tvm.rpc.server.workpath", override=True)
    def get_workpath(path):
        return temp.relpath(path)

    @tvm._ffi.register_func("tvm.rpc.server.load_module", override=True)
    def load_module(file_name):
        """Load module from remote side."""
        path = temp.relpath(file_name)
        m = _load_module(path)
        logger.info("load_module %s", path)
        return m

    libs = []
    load_library = load_library.split(":") if load_library else []
    for file_name in load_library:
        file_name = find_lib_path(file_name)[0]
        libs.append(ctypes.CDLL(file_name, ctypes.RTLD_GLOBAL))
        logger.info("Load additional library %s", file_name)
    # Store the library handles on the working directory object.
    temp.libs = libs
    return temp

def _serve_loop(sock, addr, load_library, work_path=None):
    """Serve one RPC session on an already established connection.

    Parameters
    ----------
    sock : socket.socket
        Connected socket; its file descriptor is handed to
        ``base._ServerLoop``.

    addr : object
        Peer address, used only for logging.

    load_library : str or None
        Forwarded to ``_server_env``.

    work_path : object, optional
        Working directory supplied by the caller. When it is omitted a
        temporary one is created by ``_server_env`` and removed again
        before this function returns.
    """
    sockfd = sock.fileno()
    temp = _server_env(load_library, work_path)
    try:
        base._ServerLoop(sockfd)
    finally:
        # Remove the directory only when it was created here, and do so
        # even if the server loop raised, so it does not leak.
        if not work_path:
            temp.remove()
    logger.info("Finish serving %s", addr)

def _parse_server_opt(opts):
    """Pick the recognised options out of the option tokens of a client key.

    Parameters
    ----------
    opts : iterable of str
        Option tokens, e.g. ``["-timeout=10"]``.

    Returns
    -------
    ret : dict
        ``{"timeout": float}`` when at least one ``-timeout=`` token is
        present (the last one wins), otherwise an empty dict.
    """
    flag = "-timeout="
    timeouts = [float(token[len(flag):]) for token in opts if token.startswith(flag)]
    return {"timeout": timeouts[-1]} if timeouts else {}

def _listen_loop(sock, port, rpc_key, tracker_addr, load_library, custom_addr):
    """Listening loop of the server master.

    Connects to the tracker when one is configured, waits for a client
    whose key matches, and serves each accepted session in a separate
    process. The loop only ends by raising.

    Parameters
    ----------
    sock : socket.socket
        Listening socket on which clients connect.

    port : int
        Port that is reported to the tracker.

    rpc_key : str
        Key of this server; used to build the match key expected from
        clients and the ``server:`` key sent back to them.

    tracker_addr : object
        Address handed to ``base.connect_with_retry``. A falsy value
        disables tracker registration.

    load_library : str or None
        Forwarded to ``_serve_loop``.

    custom_addr : str or None
        Sent to the tracker together with the port and match key.

    Raises
    ------
    RuntimeError
        If the peer at ``tracker_addr`` does not answer with the tracker
        magic number, or if the tracker does not acknowledge a request.
    """
    def _expect_success(tracker_conn, request):
        """Consume the tracker's reply and raise unless it is SUCCESS."""
        # The reply is read outside of an ``assert`` so that it is still
        # consumed and checked when Python runs with -O.
        if base.recvjson(tracker_conn) != TrackerCode.SUCCESS:
            raise RuntimeError("RPC Tracker did not acknowledge %s" % request)

    def _accept_conn(listen_sock, tracker_conn, ping_period=2):
        """Accept connection from the other places.

        Parameters
        ----------
        listen_sock: Socket
            The socket used by listening process.

        tracker_conn : connection to tracker
            Tracker connection

        ping_period : float, optional
            ping tracker every k seconds if no connection is accepted.

        Returns
        -------
        conn, addr, opts
            The accepted connection, the peer address and the options
            parsed from the client key.
        """
        old_keyset = set()
        # Report resource to tracker
        if tracker_conn:
            matchkey = base.random_key(rpc_key + ":")
            base.sendjson(tracker_conn,
                          [TrackerCode.PUT, rpc_key, (port, matchkey), custom_addr])
            _expect_success(tracker_conn, "PUT")
        else:
            matchkey = rpc_key

        unmatch_period_count = 0
        unmatch_timeout = 4
        # Wait until we get a valid connection
        while True:
            if tracker_conn:
                trigger = select.select([listen_sock], [], [], ping_period)
                if listen_sock not in trigger[0]:
                    base.sendjson(tracker_conn, [TrackerCode.GET_PENDING_MATCHKEYS])
                    pending_keys = base.recvjson(tracker_conn)
                    old_keyset.add(matchkey)
                    # if match key not in pending key set
                    # it means the key is acquired by a client but not used.
                    if matchkey not in pending_keys:
                        unmatch_period_count += 1
                    else:
                        unmatch_period_count = 0
                    # regenerate match key if key is acquired but not used for a while
                    if unmatch_period_count * ping_period > unmatch_timeout + ping_period:
                        logger.info("no incoming connections, regenerate key ...")
                        matchkey = base.random_key(rpc_key + ":", old_keyset)
                        base.sendjson(tracker_conn,
                                      [TrackerCode.PUT, rpc_key, (port, matchkey),
                                       custom_addr])
                        _expect_success(tracker_conn, "PUT")
                        unmatch_period_count = 0
                    continue
            conn, addr = listen_sock.accept()
            magic = struct.unpack("<i", base.recvall(conn, 4))[0]
            if magic != base.RPC_MAGIC:
                conn.close()
                continue
            keylen = struct.unpack("<i", base.recvall(conn, 4))[0]
            key = py_str(base.recvall(conn, keylen))
            arr = key.split()
            expect_header = "client:" + matchkey

            # The key comes from the network: an empty key or an option
            # value that cannot be parsed is rejected like any other
            # mismatch instead of terminating the listening loop.
            opts = None
            if arr and arr[0] == expect_header:
                try:
                    opts = _parse_server_opt(arr[1:])
                except ValueError:
                    opts = None
            if opts is None:
                conn.sendall(struct.pack("<i", base.RPC_CODE_MISMATCH))
                conn.close()
                logger.warning("mismatch key from %s", addr)
                continue

            # Encode first so that the announced length is the number of
            # bytes actually sent, also for non-ASCII keys.
            server_key = ("server:" + rpc_key).encode("utf-8")
            conn.sendall(struct.pack("<i", base.RPC_CODE_SUCCESS))
            conn.sendall(struct.pack("<i", len(server_key)))
            conn.sendall(server_key)
            return conn, addr, opts

    # Server logic
    tracker_conn = None
    while True:
        try:
            # step 1: setup tracker and report to tracker
            if tracker_addr and tracker_conn is None:
                tracker_conn = base.connect_with_retry(tracker_addr)
                tracker_conn.sendall(struct.pack("<i", base.RPC_TRACKER_MAGIC))
                magic = struct.unpack("<i", base.recvall(tracker_conn, 4))[0]
                if magic != base.RPC_TRACKER_MAGIC:
                    raise RuntimeError("%s is not RPC Tracker" % str(tracker_addr))
                # report status of current queue
                cinfo = {"key" : "server:" + rpc_key}
                base.sendjson(tracker_conn,
                              [TrackerCode.UPDATE_INFO, cinfo])
                _expect_success(tracker_conn, "UPDATE_INFO")

            # step 2: wait for in-coming connections
            conn, addr, opts = _accept_conn(sock, tracker_conn)
        except (socket.error, IOError):
            # retry when tracker is dropped
            if tracker_conn:
                tracker_conn.close()
                tracker_conn = None
            continue
        # Any RuntimeError raised above propagates to the caller.

        # step 3: serving
        work_path = util.tempdir()
        logger.info("connection from %s", addr)
        server_proc = multiprocessing.Process(target=_serve_loop,
                                              args=(conn, addr, load_library, work_path))
        # NOTE(review): "deamon" is a misspelling of "daemon", so this
        # assignment has no effect. It is kept as is on purpose: a daemonic
        # multiprocessing process is not allowed to create children, while
        # the code below expects the worker may have some.
        server_proc.deamon = True
        server_proc.start()
        # close from our side.
        conn.close()
        # wait until server process finish or timeout
        server_proc.join(opts.get("timeout", None))
        if server_proc.is_alive():
            logger.info("Timeout in RPC session, kill..")
            # pylint: disable=import-outside-toplevel
            import psutil
            parent = psutil.Process(server_proc.pid)
            # terminate worker childs
            for child in parent.children(recursive=True):
                child.terminate()
            # terminate the worker
            server_proc.terminate()
        work_path.remove()


def _connect_proxy_loop(addr, key, load_library):
    """Connect to an RPC proxy and serve the sessions it hands over.

    Each iteration opens a new connection to the proxy, performs the
    handshake and serves the session in a separate process. The loop only
    ends by raising.

    Parameters
    ----------
    addr : tuple
        Address of the proxy, passed to ``socket.connect``.

    key : str
        Key of this server; it is sent to the proxy as ``server:<key>``.

    load_library : str or None
        Forwarded to ``_serve_loop``.

    Raises
    ------
    RuntimeError
        If the proxy reports the key as already used, if the peer does not
        answer like an RPC proxy, or after more than ``max_retry``
        consecutive socket errors.
    """
    key = "server:" + key
    # Encode once so that the announced length is the number of bytes
    # actually sent, also for non-ASCII keys.
    key_bytes = key.encode("utf-8")
    retry_count = 0
    max_retry = 5
    retry_period = 5
    while True:
        sock = None
        try:
            sock = socket.socket(base.get_addr_family(addr), socket.SOCK_STREAM)
            sock.connect(addr)
            sock.sendall(struct.pack("<i", base.RPC_MAGIC))
            sock.sendall(struct.pack("<i", len(key_bytes)))
            sock.sendall(key_bytes)
            magic = struct.unpack("<i", base.recvall(sock, 4))[0]
            if magic == base.RPC_CODE_DUPLICATE:
                raise RuntimeError("key: %s has already been used in proxy" % key)

            if magic == base.RPC_CODE_MISMATCH:
                logger.warning("RPCProxy do not have matching client key %s", key)
            elif magic != base.RPC_CODE_SUCCESS:
                raise RuntimeError("%s is not RPC Proxy" % str(addr))
            keylen = struct.unpack("<i", base.recvall(sock, 4))[0]
            remote_key = py_str(base.recvall(sock, keylen))
            opts = _parse_server_opt(remote_key.split()[1:])
            logger.info("connected to %s", str(addr))
            process = multiprocessing.Process(
                target=_serve_loop, args=(sock, addr, load_library))
            # NOTE(review): "deamon" is a misspelling of "daemon", so this
            # assignment has no effect. It is left unchanged because making
            # the worker daemonic would forbid it from creating children.
            process.deamon = True
            process.start()
            # The worker has its own handle now; close ours right away.
            sock.close()
            process.join(opts.get("timeout", None))
            if process.is_alive():
                logger.info("Timeout in RPC session, kill..")
                process.terminate()
            retry_count = 0
        except (socket.error, IOError) as err:
            retry_count += 1
            logger.warning("Error encountered %s, retry in %g sec", str(err), retry_period)
            if retry_count > max_retry:
                raise RuntimeError("Maximum retry error: last error: %s" % str(err))
            time.sleep(retry_period)
        finally:
            # Every attempt creates a new socket; release it on all paths
            # (closing an already closed socket is harmless).
            if sock is not None:
                sock.close()

def _popen(cmd):
    """Run a command to completion and raise if it fails.

    Parameters
    ----------
    cmd : list of str
        Command line, passed to ``subprocess.Popen``. The child inherits
        ``os.environ``; its stderr is merged into stdout.

    Raises
    ------
    RuntimeError
        If the command exits with a non-zero return code. The message
        contains the captured output.
    """
    proc = subprocess.Popen(cmd,
                            stdout=subprocess.PIPE,
                            stderr=subprocess.STDOUT,
                            env=os.environ)
    (out, _) = proc.communicate()
    if proc.returncode != 0:
        msg = "Server invoke error:\n"
        # The pipe is opened in binary mode, so ``out`` is bytes and has to
        # be decoded before it can be appended to the message.
        msg += out.decode("utf-8", errors="replace")
        raise RuntimeError(msg)


class Server(object):
    """Start RPC server on a separate process.

    This is a simple python implementation based on multi-processing.
    It is also possible to implement a similar C based server with
    TVM runtime which does not depend on the python.

    Parameters
    ----------
    host : str
        The host url of the server.

    port : int
        The port to be bind to

    port_end : int, optional
        The end port to search

    is_proxy : bool, optional
        Whether the address specified is a proxy.
        If this is true, the host and port actually corresponds to the
        address of the proxy server.

    use_popen : bool, optional
        Whether to use Popen to start a fresh new process instead of fork.
        This is recommended to switch on if we want to do local RPC demonstration
        for GPU devices to avoid fork safety issues.

    tracker_addr: Tuple (str, int) , optional
        The address of RPC Tracker in tuple(host, port) format.
        If is not None, the server will register itself to the tracker.

    key : str, optional
        The key used to identify the device type in tracker.

    load_library : str, optional
        List of additional libraries to be loaded during execution.

    custom_addr: str, optional
        Custom IP Address to Report to RPC Tracker

    silent: bool, optional
        Whether run this server in silent mode.

    Raises
    ------
    RuntimeError
        If ``base._ServerLoop`` is not available.

    ValueError
        If no port in ``[port, port_end)`` can be bound.
    """
    def __init__(self,
                 host,
                 port=9091,
                 port_end=9199,
                 is_proxy=False,
                 use_popen=False,
                 tracker_addr=None,
                 key="",
                 load_library=None,
                 custom_addr=None,
                 silent=False):
        # Set before anything can raise: __del__ calls terminate(), which
        # reads both attributes even when construction fails half way.
        self.proc = None
        self.use_popen = use_popen
        # A missing module attribute raises AttributeError (not NameError),
        # so look it up with getattr.
        if getattr(base, "_ServerLoop", None) is None:
            raise RuntimeError("Please compile with USE_RPC=1")
        self.host = host
        self.port = port
        self.libs = []
        self.custom_addr = custom_addr

        if silent:
            logger.setLevel(logging.ERROR)

        if use_popen:
            cmd = [sys.executable,
                   "-m", "tvm.exec.rpc_server",
                   "--host=%s" % host,
                   "--port=%s" % port]
            if tracker_addr:
                assert key
                cmd += ["--tracker=%s:%d" % tracker_addr,
                        "--key=%s" % key]
            if load_library:
                cmd += ["--load-library", load_library]
            if custom_addr:
                cmd += ["--custom-addr", custom_addr]
            if silent:
                cmd += ["--silent"]

            # prexec_fn is not thread safe and may result in deadlock.
            # python 3.2 introduced the start_new_session parameter as
            # an alternative to the common use case of
            # prexec_fn=os.setsid.  Once the minimum version of python
            # supported by TVM reaches python 3.2 this code can be
            # rewritten in favour of start_new_session.  In the
            # interim, stop the pylint diagnostic.
            #
            # pylint: disable=subprocess-popen-preexec-fn
            if platform.system() == "Windows":
                self.proc = subprocess.Popen(cmd, creationflags=subprocess.CREATE_NEW_PROCESS_GROUP)
            else:
                self.proc = subprocess.Popen(cmd, preexec_fn=os.setsid)
            time.sleep(0.5)
        elif not is_proxy:
            sock = socket.socket(base.get_addr_family((host, port)), socket.SOCK_STREAM)
            self.port = None
            for my_port in range(port, port_end):
                try:
                    sock.bind((host, my_port))
                    self.port = my_port
                    break
                except socket.error as sock_err:
                    # 98 and 48 are the "address already in use" codes this
                    # loop skips over; anything else is a real error.
                    if sock_err.errno in [98, 48]:
                        continue
                    sock.close()
                    raise
            if not self.port:
                # Do not leak the unbound socket when giving up.
                sock.close()
                raise ValueError("cannot bind to any port in [%d, %d)" % (port, port_end))
            logger.info("bind to %s:%d", host, self.port)
            sock.listen(1)
            self.sock = sock
            self.proc = multiprocessing.Process(
                target=_listen_loop, args=(
                    self.sock, self.port, key, tracker_addr, load_library,
                    self.custom_addr))
            # NOTE(review): "deamon" is a misspelling of "daemon", so this
            # assignment has no effect. It is left unchanged because
            # _listen_loop starts a child process per session, which a
            # daemonic multiprocessing process is not allowed to do.
            self.proc.deamon = True
            self.proc.start()
        else:
            self.proc = multiprocessing.Process(
                target=_connect_proxy_loop, args=((host, port), key, load_library))
            # NOTE(review): same misspelling as above, kept for the same
            # reason (_connect_proxy_loop also starts child processes).
            self.proc.deamon = True
            self.proc.start()

    def terminate(self):
        """Terminate the server process"""
        # getattr keeps this safe on an instance whose __init__ did not
        # get as far as creating the process.
        proc = getattr(self, "proc", None)
        if not proc:
            return
        if getattr(self, "use_popen", False):
            # The Popen child was started in its own process group/session,
            # so signal the group rather than the single process.
            if platform.system() == "Windows":
                os.kill(proc.pid, signal.CTRL_C_EVENT)
            else:
                os.killpg(proc.pid, signal.SIGTERM)
        else:
            proc.terminate()
        self.proc = None

    def __del__(self):
        self.terminate()