# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""Base definitions for RPC."""
# pylint: disable=invalid-name

from __future__ import absolute_import

import socket
import time
import json
import errno
import struct
import random
import logging

from .._ffi.function import _init_api
from .._ffi.base import py_str

# Magic header for RPC data plane
RPC_MAGIC = 0xff271
# magic header for RPC tracker(control plane)
RPC_TRACKER_MAGIC = 0x2f271
# sucess response
RPC_CODE_SUCCESS = RPC_MAGIC + 0
# duplicate key in proxy
RPC_CODE_DUPLICATE = RPC_MAGIC + 1
# cannot found matched key in server
RPC_CODE_MISMATCH = RPC_MAGIC + 2

logger = logging.getLogger('RPCServer')

class TrackerCode(object):
    """Enumeration code for the RPC tracker"""
    FAIL = -1
    SUCCESS = 0
    PING = 1
    STOP = 2
    PUT = 3
    REQUEST = 4
    UPDATE_INFO = 5
    SUMMARY = 6
    GET_PENDING_MATCHKEYS = 7

RPC_SESS_MASK = 128


def get_addr_family(addr):
    res = socket.getaddrinfo(addr[0], addr[1], 0, 0, socket.IPPROTO_TCP)
    return res[0][0]


def recvall(sock, nbytes):
    """Receive all nbytes from socket.

    Parameters
    ----------
    sock: Socket
       The socket

    nbytes : int
       Number of bytes to be received.
    """
    res = []
    nread = 0
    while nread < nbytes:
        chunk = sock.recv(min(nbytes - nread, 1024))
        if not chunk:
            raise IOError("connection reset")
        nread += len(chunk)
        res.append(chunk)
    return b"".join(res)


def sendjson(sock, data):
    """send a python value to remote via json

    Parameters
    ----------
    sock : Socket
        The socket

    data : object
        Python value to be sent.
    """
    data = json.dumps(data)
    sock.sendall(struct.pack("<i", len(data)))
    sock.sendall(data.encode("utf-8"))


def recvjson(sock):
    """receive python value from remote via json

    Parameters
    ----------
    sock : Socket
        The socket

    Returns
    -------
    value : object
        The value received.
    """
    size = struct.unpack("<i", recvall(sock, 4))[0]
    data = json.loads(py_str(recvall(sock, size)))
    return data


def random_key(prefix, cmap=None):
    """Generate a random key

    Parameters
    ----------
    prefix : str
        The string prefix

    cmap : dict
        Conflict map

    Returns
    -------
    key : str
        The generated random key
    """
    if cmap:
        while True:
            key = prefix + str(random.random())
            if key not in cmap:
                return key
    else:
        return prefix + str(random.random())


def connect_with_retry(addr, timeout=60, retry_period=5):
    """Connect to a TPC address with retry

    This function is only reliable to short period of server restart.

    Parameters
    ----------
    addr : tuple
        address tuple

    timeout : float
         Timeout during retry

    retry_period : float
         Number of seconds before we retry again.
    """
    tstart = time.time()
    while True:
        try:
            sock = socket.socket(get_addr_family(addr), socket.SOCK_STREAM)
            sock.connect(addr)
            return sock
        except socket.error as sock_err:
            if sock_err.args[0] not in (errno.ECONNREFUSED,):
                raise sock_err
            period = time.time() - tstart
            if period > timeout:
                raise RuntimeError(
                    "Failed to connect to server %s" % str(addr))
            logger.warning("Cannot connect to tracker %s, retry in %g secs...",
                           str(addr), retry_period)
            time.sleep(retry_period)


# Still use tvm.rpc for the foreign functions
_init_api("tvm.rpc", "tvm.rpc.base")