"""RPC interface for easy testing.

RPC enables connecting to a remote server, uploading files, and launching functions.
This is useful for cross-compilation and remote testing:
the compiler stack runs on the local server, while the RPC server
runs on a remote runtime that does not have a compiler available.

The test program compiles the program on the local server,
uploads it to the remote RPC server, runs it there, and gets the result
back to verify correctness.
"""
from __future__ import absolute_import

import os
import socket
import struct
import logging
import multiprocessing
import subprocess
import time
from . import util, cc, tar
from ..module import load as _load_module
from .._ffi.function import _init_api, register_func
from .._ffi.ndarray import context as _context
from .._ffi.base import py_str

# Magic number exchanged in the connection handshake. A peer replies with
# RPC_MAGIC on success, RPC_MAGIC + 1 when the key is already used at the
# proxy, and RPC_MAGIC + 2 when the key does not match.
RPC_MAGIC = 0xff271
# Remote contexts add (session table index + 1) * RPC_SESS_MASK to the
# local device_type, so the runtime can route them to the right session.
RPC_SESS_MASK = 128

def _server_env():
    """Register the server-side helper functions and return the work directory.

    A fresh temporary directory (``util.tempdir()``) is created, and the
    functions ``tvm.contrib.rpc.server.workpath`` and
    ``tvm.contrib.rpc.server.load_module`` are registered so that file
    names sent by the client are resolved relative to that directory.

    Returns
    -------
    temp : object returned by ``util.tempdir()``
        The work directory; the caller is responsible for removing it.
    """
    workdir = util.tempdir()

    # pylint: disable=unused-variable
    @register_func("tvm.contrib.rpc.server.workpath")
    def get_workpath(path):
        return workdir.relpath(path)

    @register_func("tvm.contrib.rpc.server.load_module", override=True)
    def load_module(file_name):
        """Load module from remote side."""
        lib_path = workdir.relpath(file_name)
        # Object files and tarballs of object files are first linked into
        # "<path>.so", and that shared library is what gets loaded.
        if lib_path.endswith(".o"):
            logging.info("Create shared library based on %s", lib_path)
            cc.create_shared(lib_path + ".so", lib_path)
            lib_path = lib_path + ".so"
        elif lib_path.endswith(".tar"):
            unpack_dir = util.tempdir()
            tar.untar(lib_path, unpack_dir.temp_dir)
            members = list(map(unpack_dir.relpath, unpack_dir.listdir()))
            cc.create_shared(lib_path + ".so", members)
            lib_path = lib_path + ".so"
        module = _load_module(lib_path)
        logging.info("load_module %s", lib_path)
        return module

    return workdir


def _serve_loop(sock, addr):
    """Serve one RPC session on ``sock``, then remove its work directory.

    Used as the target of the child processes started by the listening and
    proxy loops.

    Parameters
    ----------
    sock : socket.socket
        The connected socket of the session.

    addr : object
        The peer address; only used for logging.
    """
    sockfd = sock.fileno()
    temp = _server_env()
    try:
        _ServerLoop(sockfd)
    finally:
        # Clean up the session's work directory even if the loop raises.
        temp.remove()
    logging.info("Finish serving %s", addr)


def _recvall(sock, nbytes):
    """Receive exactly ``nbytes`` bytes from ``sock``.

    Parameters
    ----------
    sock : socket.socket
        The connected socket to read from.

    nbytes : int
        The number of bytes to read.

    Returns
    -------
    data : bytes
        The received bytes, ``nbytes`` long.

    Raises
    ------
    IOError
        If the peer closes the connection before ``nbytes`` bytes arrive.
    """
    res = []
    nread = 0
    while nread < nbytes:
        chunk = sock.recv(min(nbytes - nread, 1024))
        # recv() returns an empty result once the peer has closed the
        # connection; without this check the loop would spin forever.
        if not chunk:
            raise IOError("connection closed by peer after receiving %d of %d bytes"
                          % (nread, nbytes))
        nread += len(chunk)
        res.append(chunk)
    return b"".join(res)


def _listen_loop(sock, exclusive):
    """Listening loop: accept clients, check the handshake, serve each one.

    Every accepted connection must send ``RPC_MAGIC`` followed by a
    length-prefixed key starting with ``"client:"``. Accepted connections
    are served by ``_serve_loop`` in a new child process. Connections that
    fail the handshake are closed and never start a session.

    Parameters
    ----------
    sock : socket.socket
        A bound socket that is already listening.

    exclusive : bool
        If True, a previous session that is still running is terminated
        once a new connection has passed the handshake.
    """
    last_proc = None
    while True:
        conn, addr = sock.accept()
        logging.info("RPCServer: connection from %s", addr)

        try:
            magic = struct.unpack("@i", _recvall(conn, 4))[0]
            if magic != RPC_MAGIC:
                conn.close()
                continue
            keylen = struct.unpack("@i", _recvall(conn, 4))[0]
            key = py_str(_recvall(conn, keylen))
            if not key.startswith("client:"):
                # Report the key mismatch and drop the peer instead of
                # starting a session for it.
                conn.sendall(struct.pack("@i", RPC_MAGIC + 2))
                conn.close()
                continue
            conn.sendall(struct.pack("@i", RPC_MAGIC))
        except (IOError, ValueError) as err:
            # A broken or malformed handshake (peer disconnect, undecodable
            # key) must not take down the listener: drop this peer only.
            logging.info("RPCServer: handshake with %s failed: %s", addr, err)
            conn.close()
            continue

        # Only a connection that completed the handshake may evict the
        # running session, so stray connections cannot kill it.
        if exclusive and last_proc is not None and last_proc.is_alive():
            logging.info("Kill last call")
            last_proc.terminate()

        logging.info("Connection from %s", addr)
        process = multiprocessing.Process(target=_serve_loop, args=(conn, addr))
        # NOTE(review): "deamon" is not the Process.daemon attribute, so this
        # assignment has no effect and the child is non-daemonic. Left as-is
        # because correcting it would change shutdown behavior.
        process.deamon = True
        process.start()
        last_proc = process
        # The child has its own copy of the connection; close ours.
        conn.close()


def _connect_proxy_loop(addr, key):
    """Register as ``"server:" + key`` with an RPC proxy and serve sessions.

    Connects to the proxy, performs the handshake, serves the connection in
    a child process and waits for it to finish, then reconnects. This loops
    forever.

    Parameters
    ----------
    addr : tuple of (str, int)
        Host and port of the proxy.

    key : str
        The key used to identify this server at the proxy.

    Raises
    ------
    RuntimeError
        If the key is already used at the proxy, or ``addr`` does not
        answer with an RPC proxy handshake.
    """
    key = "server:" + key
    # The length prefix must count encoded bytes, not characters, so that
    # non-ASCII keys are framed correctly.
    key_bytes = key.encode("utf-8")
    while True:
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            sock.connect(addr)
            sock.sendall(struct.pack("@i", RPC_MAGIC))
            sock.sendall(struct.pack("@i", len(key_bytes)))
            sock.sendall(key_bytes)
            magic = struct.unpack("@i", _recvall(sock, 4))[0]
            if magic == RPC_MAGIC + 1:
                raise RuntimeError("key: %s has already been used in proxy" % key)
            elif magic == RPC_MAGIC + 2:
                logging.info("RPCProxy do not have matching client key %s", key)
            elif magic != RPC_MAGIC:
                raise RuntimeError("%s is not RPC Proxy" % str(addr))
            logging.info("RPCProxy connected to %s", str(addr))
            process = multiprocessing.Process(target=_serve_loop, args=(sock, addr))
            # NOTE(review): "deamon" is not the Process.daemon attribute, so
            # this assignment has no effect and the child is non-daemonic.
            process.deamon = True
            process.start()
        finally:
            # The child holds its own copy of the socket. Close ours here,
            # on success and on error, instead of relying on garbage
            # collection.
            sock.close()
        process.join()


def _popen(cmd):
    """Run ``cmd`` to completion and raise if it exits with a non-zero status.

    Parameters
    ----------
    cmd : list of str
        The command line, passed directly to ``subprocess.Popen`` (no shell).

    Raises
    ------
    RuntimeError
        If the command exits with a non-zero return code. The message
        contains the command's combined stdout/stderr output.
    """
    proc = subprocess.Popen(cmd,
                            stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
                            env=os.environ)
    (out, _) = proc.communicate()
    if proc.returncode != 0:
        # communicate() returns bytes on Python 3. Decode before
        # concatenating with the str message (str + bytes is a TypeError).
        if isinstance(out, bytes):
            out = out.decode("utf-8", "replace")
        msg = "Server invoke error:\n"
        msg += out
        raise RuntimeError(msg)


class Server(object):
    """Start RPC server on a separate process.

    This is a simple python implementation based on multi-processing.
    It is also possible to implement a similar C based server with
    TVM runtime which does not depend on the python.

    Parameters
    ----------
    host : str
        The host url of the server.

    port : int
        The port to bind to; the first port tried when searching.

    port_end : int, optional
        The end of the port search range (exclusive).

    is_proxy : bool, optional
        Whether the address specified is a proxy.
        If this is true, the host and port actually corresponds to the
        address of the proxy server.

    use_popen : bool, optional
        Whether to use Popen to start a fresh new process instead of fork.
        This is recommended to switch on if we want to do local RPC demonstration
        for GPU devices to avoid fork safety issues.
        Only host and port are forwarded to the spawned server in this mode.

    exclusive : bool, optional
        If this is enabled, the server will kill old connection
        when new connection comes. This can make sure the current call
        monopolizes the hardware resource.

    key : str, optional
        The key used to identify the server in Proxy connection.

    Raises
    ------
    RuntimeError
        If TVM was built without RPC support.

    ValueError
        If no port in ``[port, port_end)`` can be bound.
    """
    def __init__(self,
                 host,
                 port=9091,
                 port_end=9199,
                 is_proxy=False,
                 use_popen=False,
                 exclusive=False,
                 key=""):
        # Set first so that terminate()/__del__ are safe if __init__ fails.
        self.proc = None
        try:
            if _ServerLoop is None:
                raise RuntimeError("Please compile with USE_RPC=1")
        except NameError:
            raise RuntimeError("Please compile with USE_RPC=1")
        self.host = host
        self.port = port
        self.libs = []

        # NOTE: the "deamon" assignments below are misspelled and have no
        # effect. The listen/proxy processes must stay non-daemonic anyway,
        # because daemonic processes may not create child processes and
        # both loops start one per session.
        if use_popen:
            cmd = ["python",
                   "-m", "tvm.exec.rpc_server",
                   "--host=%s" % host,
                   "--port=%s" % port]
            self.proc = multiprocessing.Process(
                target=subprocess.check_call, args=(cmd,))
            self.proc.deamon = True
            self.proc.start()
            time.sleep(1)
        elif not is_proxy:
            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            self.port = None
            for my_port in range(port, port_end):
                try:
                    sock.bind((host, my_port))
                    self.port = my_port
                    break
                except socket.error as sock_err:
                    # 98 (Linux) / 48 (macOS) mean "address in use":
                    # try the next port in the range.
                    if sock_err.errno in [98, 48]:
                        continue
                    sock.close()
                    raise
            # Checked after the loop: every port in the range is taken.
            if self.port is None:
                sock.close()
                raise ValueError("cannot bind to any port in [%d, %d)" % (port, port_end))
            logging.info("RPCServer: bind to %s:%d", host, self.port)
            sock.listen(1)
            self.sock = sock
            self.proc = multiprocessing.Process(
                target=_listen_loop, args=(self.sock, exclusive))
            self.proc.deamon = True
            self.proc.start()
        else:
            self.proc = multiprocessing.Process(
                target=_connect_proxy_loop, args=((host, port), key))
            self.proc.deamon = True
            self.proc.start()

    def terminate(self):
        """Terminate the server process and close the listening socket.

        Safe to call more than once, and on a partially constructed object.
        """
        proc = getattr(self, "proc", None)
        if proc:
            proc.terminate()
            self.proc = None
        # The parent keeps its own copy of the listening socket; close it
        # so the port is released when the server is terminated.
        sock = getattr(self, "sock", None)
        if sock is not None:
            sock.close()
            self.sock = None

    def __del__(self):
        self.terminate()


class RPCSession(object):
    """RPC Client session module

    Do not directly create the object, call connect
    """
    # pylint: disable=invalid-name
    def __init__(self, sess):
        self._sess = sess
        # Session table index (from _SessTableIndex); context() encodes it
        # into the device_type of remote contexts.
        self._tbl_index = _SessTableIndex(sess)
        # Lazily fetched remote helper functions, keyed by short name.
        self._remote_funcs = {}

    def get_function(self, name):
        """Get function from the session.

        Parameters
        ----------
        name : str
            The name of the function

        Returns
        -------
        f : Function
            The result function.
        """
        return self._sess.get_function(name)

    def context(self, dev_type, dev_id=0):
        """Construct a remote context.

        Parameters
        ----------
        dev_type: int or str
            The local device type (e.g. 1 for CPU, as used by cpu()).

        dev_id: int, optional
            The device index.

        Returns
        -------
        ctx: TVMContext
            The corresponding encoded remote context.
        """
        ctx = _context(dev_type, dev_id)
        # Tag the device type with this session so calls are routed remotely.
        encode = (self._tbl_index + 1) * RPC_SESS_MASK
        ctx.device_type += encode
        ctx._rpc_sess = self
        return ctx

    def cpu(self, dev_id=0):
        """Construct remote CPU device."""
        return self.context(1, dev_id)

    def gpu(self, dev_id=0):
        """Construct remote GPU device."""
        return self.context(2, dev_id)

    def cl(self, dev_id=0):
        """Construct remote OpenCL device."""
        return self.context(4, dev_id)

    def metal(self, dev_id=0):
        """Construct remote Metal device."""
        return self.context(8, dev_id)

    def opengl(self, dev_id=0):
        """Construct remote OpenGL device."""
        return self.context(11, dev_id)

    def ext_dev(self, dev_id=0):
        """Construct remote extension device."""
        return self.context(12, dev_id)

    def upload(self, data, target=None):
        """Upload file to remote runtime temp folder

        Parameters
        ----------
        data : str or bytearray
            The file name or binary in local to upload.

        target : str, optional
            The path in remote. Defaults to the base name of ``data`` when
            ``data`` is a file name; required when ``data`` is a bytearray.

        Raises
        ------
        ValueError
            If ``data`` is a bytearray and ``target`` is not given.
        """
        if isinstance(data, bytearray):
            if not target:
                raise ValueError("target must present when file is a bytearray")
            blob = data
        else:
            # Read through a context manager so the file handle is closed.
            with open(data, "rb") as fin:
                blob = bytearray(fin.read())
            if not target:
                target = os.path.basename(data)

        if "upload" not in self._remote_funcs:
            self._remote_funcs["upload"] = self.get_function(
                "tvm.contrib.rpc.server.upload")
        self._remote_funcs["upload"](target, blob)

    def download(self, path):
        """Download file from remote temp folder.

        Parameters
        ----------
        path : str
            The relative location to remote temp folder.

        Returns
        -------
        blob : bytearray
            The result blob from the file.
        """
        if "download" not in self._remote_funcs:
            self._remote_funcs["download"] = self.get_function(
                "tvm.contrib.rpc.server.download")
        return self._remote_funcs["download"](path)

    def load_module(self, path):
        """Load a remote module, the file need to be uploaded first.

        Parameters
        ----------
        path : str
            The relative location to remote temp folder.

        Returns
        -------
        m : Module
            The remote module containing remote function.
        """
        return _LoadRemoteModule(self._sess, path)


def connect(url, port, key=""):
    """Open a client session to an RPC server.

    Parameters
    ----------
    url : str
        Host name or address of the server.

    port : int
        Port the server listens on.

    key : str, optional
        Extra key the server uses to match this client.

    Returns
    -------
    sess : RPCSession
        Wrapper around the newly connected low-level session.

    Raises
    ------
    RuntimeError
        If TVM was built without RPC support.
    """
    try:
        handle = _Connect(url, port, key)
    except NameError:
        # _Connect is missing when the RPC runtime was not compiled in.
        raise RuntimeError("Please compile with USE_RPC=1")
    else:
        return RPCSession(handle)

# Bind the "tvm.contrib.rpc" FFI functions into this module. The names used
# above without a local definition (_ServerLoop, _Connect, _SessTableIndex,
# _LoadRemoteModule) are presumably provided here. They are missing when TVM
# is built without RPC, which is what the NameError guards report as
# "Please compile with USE_RPC=1".
_init_api("tvm.contrib.rpc")