"""Base definitions for RPC."""
# pylint: disable=invalid-name

from __future__ import absolute_import

import socket
import time
import json
import errno
import struct
import random
import logging

from .._ffi.function import _init_api
from .._ffi.base import py_str

# Magic header for RPC data plane
RPC_MAGIC = 0xff271
# Magic header for RPC tracker (control plane)
RPC_TRACKER_MAGIC = 0x2f271
# Success response
RPC_CODE_SUCCESS = RPC_MAGIC + 0
# Duplicate key in proxy
RPC_CODE_DUPLICATE = RPC_MAGIC + 1
# Cannot find a matching key on the server
RPC_CODE_MISMATCH = RPC_MAGIC + 2

logger = logging.getLogger('RPCServer')


class TrackerCode(object):
    """Enumeration code for the RPC tracker.

    Codes are plain ``int`` class attributes.
    """
    FAIL = -1
    SUCCESS = 0
    PING = 1
    STOP = 2
    PUT = 3
    REQUEST = 4
    UPDATE_INFO = 5
    SUMMARY = 6
    GET_PENDING_MATCHKEYS = 7


# Session mask, 0x80 == 128 (bit 7 set).
# NOTE(review): how callers apply this mask is not visible in this module.
RPC_SESS_MASK = 0x80


def recvall(sock, nbytes):
    """Read exactly ``nbytes`` bytes from ``sock``.

    Data is pulled with ``sock.recv`` in pieces of at most 1024 bytes until
    the requested amount has been collected.

    Parameters
    ----------
    sock: Socket
       The socket

    nbytes : int
       Number of bytes to be received.

    Returns
    -------
    data : bytes
        The concatenation of all received pieces.

    Raises
    ------
    IOError
        If ``recv`` returns an empty chunk before ``nbytes`` bytes arrived.
    """
    pieces = []
    remaining = nbytes
    while remaining > 0:
        piece = sock.recv(min(remaining, 1024))
        if not piece:
            # An empty read means the peer closed the connection.
            raise IOError("connection reset")
        pieces.append(piece)
        remaining -= len(piece)
    return b"".join(pieces)


def sendjson(sock, data):
    """Send a Python value to the remote end as length-prefixed JSON.

    The message is a 4-byte little-endian signed length header followed by
    the UTF-8 encoded JSON text.

    Parameters
    ----------
    sock : Socket
        The socket

    data : object
        Python value to be sent; must be JSON serializable.
    """
    payload = json.dumps(data).encode("utf-8")
    # The header must carry the byte length of what is actually sent, not
    # the character length of the JSON string (they only coincide while
    # json.dumps escapes every non-ASCII character).
    sock.sendall(struct.pack("<i", len(payload)))
    sock.sendall(payload)


def recvjson(sock):
    """Receive a Python value from the remote end via JSON.

    Reads a 4-byte little-endian signed length header, then that many
    payload bytes. The payload goes through ``py_str`` (presumably a
    bytes -> str conversion) and is parsed as JSON.

    Parameters
    ----------
    sock : Socket
        The socket

    Returns
    -------
    value : object
        The value received.

    Raises
    ------
    IOError
        Propagated from ``recvall`` if the connection closes early.
    """
    (size,) = struct.unpack("<i", recvall(sock, 4))
    payload = recvall(sock, size)
    return json.loads(py_str(payload))


def random_key(prefix, cmap=None):
    """Generate a random key.

    The key is ``prefix`` followed by the string form of a
    ``random.random()`` draw. When ``cmap`` is non-empty, new draws are made
    until the key is not already present in ``cmap``.

    Parameters
    ----------
    prefix : str
        The string prefix

    cmap : dict, optional
        Conflict map; generated keys already contained in it are rejected.

    Returns
    -------
    key : str
        The generated random key
    """
    # NOTE(review): ``random`` is not a cryptographic source; switch to
    # ``secrets`` if these keys ever need to be unguessable.
    while True:
        key = prefix + str(random.random())
        # An empty or missing conflict map cannot contain the key.
        if not cmap or key not in cmap:
            return key


def connect_with_retry(addr, timeout=60, retry_period=5):
    """Connect to a TCP address with retry.

    Attempts that fail with ``ECONNREFUSED`` are retried every
    ``retry_period`` seconds until more than ``timeout`` seconds have passed
    since the first attempt. Any other socket error is raised immediately.
    This function is only reliable for short periods of server restart.

    Parameters
    ----------
    addr : tuple
        address tuple

    timeout : float
         Timeout during retry

    retry_period : float
         Number of seconds before we retry again.

    Returns
    -------
    sock : socket.socket
        The connected socket.

    Raises
    ------
    RuntimeError
        If the connection is still refused after ``timeout`` seconds.
    socket.error
        For any connection error other than ``ECONNREFUSED``.
    """
    tstart = time.time()
    while True:
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            sock.connect(addr)
            return sock
        except socket.error as sock_err:
            # Close the failed socket so each retry does not leak a descriptor.
            sock.close()
            if sock_err.errno != errno.ECONNREFUSED:
                raise
            period = time.time() - tstart
            if period > timeout:
                raise RuntimeError(
                    "Failed to connect to server %s" % str(addr))
            logger.warning("Cannot connect to tracker %s, retry in %g secs...",
                           str(addr), retry_period)
            time.sleep(retry_period)


# Still use tvm.rpc for the foreign functions: FFI functions are looked up
# under the "tvm.rpc" name while this module is "tvm.rpc.base".
# NOTE(review): the meaning of _init_api's two arguments is inferred from
# this call site only; confirm against .._ffi.function.
_init_api("tvm.rpc", "tvm.rpc.base")