# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""Base definitions for RPC."""
# pylint: disable=invalid-name

20 21 22 23 24 25 26 27 28 29
from __future__ import absolute_import

import socket
import time
import json
import errno
import struct
import random
import logging

30 31
from .._ffi.function import _init_api
from .._ffi.base import py_str
32 33 34 35 36 37 38 39 40 41 42 43

# Magic header identifying the RPC data plane.
RPC_MAGIC = 0xFF271
# Magic header identifying the RPC tracker (control plane).
RPC_TRACKER_MAGIC = 0x2F271
# Response codes below are expressed as offsets from RPC_MAGIC.
# Success response.
RPC_CODE_SUCCESS = RPC_MAGIC
# Duplicate key in proxy.
RPC_CODE_DUPLICATE = 1 + RPC_MAGIC
# No matching key could be found in the server.
RPC_CODE_MISMATCH = 2 + RPC_MAGIC

44
# Module-level logger, obtained under the logger name 'RPCServer'.
logger = logging.getLogger('RPCServer')
45 46 47 48 49 50 51 52 53

class TrackerCode(object):
    """Enumeration code for the RPC tracker.

    The codes are plain integer class attributes.
    """
    FAIL = -1
    SUCCESS = 0
    PING = 1
    STOP = 2
    PUT = 3
    REQUEST = 4
    UPDATE_INFO = 5
    SUMMARY = 6
    GET_PENDING_MATCHKEYS = 7
57 58 59 60

# NOTE(review): not referenced in the visible part of this file; presumably a
# mask/offset applied to RPC session indices by callers -- confirm.
RPC_SESS_MASK = 128


61 62 63 64 65
def get_addr_family(addr):
    """Return the socket address family for an address tuple.

    Parameters
    ----------
    addr : tuple
        Address tuple; element 0 is used as the host and element 1 as
        the port.

    Returns
    -------
    family : int
        Address family of the first entry returned by
        ``socket.getaddrinfo`` for the TCP protocol.
    """
    host, port = addr[0], addr[1]
    candidates = socket.getaddrinfo(host, port, 0, 0, socket.IPPROTO_TCP)
    # Only the first resolved entry is consulted.
    family = candidates[0][0]
    return family


66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99
def recvall(sock, nbytes):
    """Read exactly nbytes from a socket.

    Parameters
    ----------
    sock: Socket
       The socket to read from.

    nbytes : int
       Number of bytes to be received.

    Returns
    -------
    data : bytes
       The received bytes, concatenated in arrival order.

    Raises
    ------
    IOError
       If the socket returns an empty chunk before nbytes were read.
    """
    pieces = []
    remaining = nbytes
    while remaining > 0:
        # Each read asks for at most 1024 bytes.
        piece = sock.recv(min(remaining, 1024))
        if not piece:
            raise IOError("connection reset")
        pieces.append(piece)
        remaining -= len(piece)
    return b"".join(pieces)


def sendjson(sock, data):
    """Send a python value to remote via json.

    The value is serialized with ``json.dumps``, encoded as UTF-8 and
    written as a 4-byte little-endian signed length prefix followed by
    the encoded payload.

    Parameters
    ----------
    sock : Socket
        The socket

    data : object
        Python value to be sent.
    """
    payload = json.dumps(data).encode("utf-8")
    # The prefix counts encoded bytes (not characters), since that is the
    # unit the receiving side reads from the socket.
    sock.sendall(struct.pack("<i", len(payload)))
    sock.sendall(payload)


def recvjson(sock):
    """Receive a python value from remote via json.

    Reads a 4-byte little-endian signed length prefix, then that many
    bytes of payload, and decodes the payload as JSON.

    Parameters
    ----------
    sock : Socket
        The socket

    Returns
    -------
    value : object
        The value received.
    """
    (size,) = struct.unpack("<i", recvall(sock, 4))
    payload = recvall(sock, size)
    return json.loads(py_str(payload))


122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144
def random_key(prefix, cmap=None):
    """Generate a random key that starts with the given prefix.

    Parameters
    ----------
    prefix : str
        The string prefix

    cmap : dict
        Conflict map; when given and non-empty, keys already present in
        it are never returned.

    Returns
    -------
    key : str
        The generated random key
    """
    candidate = prefix + str(random.random())
    if not cmap:
        # Nothing to conflict with, so the first draw is accepted.
        return candidate
    while candidate in cmap:
        candidate = prefix + str(random.random())
    return candidate
145 146


147
def connect_with_retry(addr, timeout=60, retry_period=5):
    """Connect to a TCP address with retry.

    This function is only reliable for a short period of server restart.

    Parameters
    ----------
    addr : tuple
        address tuple

    timeout : float
         Timeout during retry

    retry_period : float
         Number of seconds before we retry again.

    Returns
    -------
    sock : socket.socket
        The connected socket.

    Raises
    ------
    socket.error
        If an attempt fails with anything other than ECONNREFUSED.

    RuntimeError
        If the connection is still refused after more than ``timeout``
        seconds have elapsed since the first attempt.
    """
    tstart = time.time()
    while True:
        sock = None
        try:
            sock = socket.socket(get_addr_family(addr), socket.SOCK_STREAM)
            sock.connect(addr)
            return sock
        except socket.error as sock_err:
            # Close the socket of the failed attempt so that repeated
            # retries do not accumulate open descriptors.
            if sock is not None:
                sock.close()
            # Only a refused connection is retried; anything else is
            # re-raised unchanged.
            if sock_err.args[0] not in (errno.ECONNREFUSED,):
                raise
            period = time.time() - tstart
            if period > timeout:
                raise RuntimeError(
                    "Failed to connect to server %s" % str(addr))
            logger.warning("Cannot connect to tracker %s, retry in %g secs...",
                           str(addr), retry_period)
            time.sleep(retry_period)


181 182
# Still use tvm.rpc for the foreign functions
# NOTE(review): presumably exposes the foreign functions registered under the
# "tvm.rpc" prefix in the "tvm.rpc.base" module -- confirm against
# _ffi.function._init_api.
_init_api("tvm.rpc", "tvm.rpc.base")