# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""RPC Tracker, tracks and distributes the TVM RPC resources.

This module implements the tracker server logic.

Note
----
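The tracker speaks a simple TCP-based request/response protocol:
- Initial handshake: the client sends RPC_TRACKER_MAGIC (little-endian int32)
  and the tracker replies with the same magic.
- Normal message: [size(int32), json-data]
- Each message is initiated by the client, and the tracker replies with a JSON
  message (the reply to a REQUEST is sent once a matching resource is available).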

List of available APIs:

- PING: check if tracker is alive
  - input: [TrackerCode.PING]
  - return: TrackerCode.SUCCESS
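- PUT: report resource to tracker
  - input: [TrackerCode.PUT, key, [port, match-key], custom-addr(optional)]
  - return: TrackerCode.SUCCESS
  - note: match-key is a randomly generated identifier of the resource during connection.
    If custom-addr is omitted or null, the peer address of the connection is used.
- REQUEST: request a new resource from tracker
  - input: [TrackerCode.REQUEST, key, user, priority]
  - return: [TrackerCode.SUCCESS, [url, port, match-key]]
- GET_PENDING_MATCHKEYS: list match keys put by this connection that are not yet handed out
  - input: [TrackerCode.GET_PENDING_MATCHKEYS]
  - return: list of match-keys
- UPDATE_INFO: merge a dict into this connection's summary info
  - input: [TrackerCode.UPDATE_INFO, info-dict]
  - return: TrackerCode.SUCCESS
- SUMMARY: query the tracker status
  - input: [TrackerCode.SUMMARY]
  - return: [TrackerCode.SUCCESS, {"queue_info": ..., "server_info": ...}]
- STOP: stop the tracker
  - input: [TrackerCode.STOP, stop-key]
  - return: TrackerCode.SUCCESS, or TrackerCode.FAIL if stop-key does not match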
"""
# pylint: disable=invalid-name

import heapq
import time
import logging
import socket
import multiprocessing
import errno
import struct
import json

try:
    from tornado import ioloop
    from . import tornado_util
except ImportError as error_msg:
    raise ImportError(
        "RPCTracker module requires tornado package %s. Try 'pip install tornado'." % error_msg)

from .._ffi.base import py_str
from . import base
from .base import RPC_TRACKER_MAGIC, TrackerCode

# Logger shared by the connection handlers and the Tracker launcher in this module.
logger = logging.getLogger("RPCTracker")

class Scheduler(object):
    """Abstract interface of a tracker scheduler.

    A scheduler pairs resources handed to it through ``put`` with requests
    registered through ``request``.
    """

    def put(self, value):
        """Add a resource to the scheduler.

        Implementations may fire pending request callbacks from here.

        Parameters
        ----------
        value : object
            The resource being added.
        """
        raise NotImplementedError

    def request(self, user, priority, callback):
        """Register a request for a resource.

        Parameters
        ----------
        user : str
            Name of the requesting user.

        priority : int
            Priority of the job.

        callback : function: value->bool
            Invoked with a resource once one is available; it must return
            True when it consumed the resource.
        """
        raise NotImplementedError

    def remove(self, value):
        """Withdraw a resource from the scheduler.

        The base implementation does nothing; schedulers that store resources
        override it.

        Parameters
        ----------
        value: object
            The resource to withdraw.
        """
        return None

    def summary(self):
        """Return summary information about the scheduler."""
        raise NotImplementedError


class PriorityScheduler(Scheduler):
    """Priority based scheduler, FIFO based on time.

    Requests with a higher ``priority`` are served first; among requests with
    equal priority, earlier requests are served first.

    Parameters
    ----------
    key : str
        The resource key this scheduler serves (stored only).
    """
    def __init__(self, key):
        self._key = key
        # Queued resources. Each value is a tuple whose first element owns a
        # ``pending_matchkeys`` set and whose last element is the match key,
        # e.g. (conn, addr, port, matchkey) as put by the connection handler.
        self._values = []
        # Min-heap of (-priority, arrival time, sequence number, callback).
        self._requests = []
        # Strictly increasing tie-breaker: without it, two requests with the
        # same priority and the same time.time() would make heapq compare the
        # callbacks, which raises TypeError (functions are not orderable).
        self._seq = 0

    def _schedule(self):
        """Hand queued resources to queued requests while both are available."""
        while self._requests and self._values:
            value = self._values.pop(0)
            item = heapq.heappop(self._requests)
            callback = item[-1]
            if callback(value[1:]):
                # discard rather than remove: a connection may have put the
                # same match key more than once.
                value[0].pending_matchkeys.discard(value[-1])
            else:
                # The requester could not take it; keep the resource queued.
                self._values.append(value)

    def put(self, value):
        self._values.append(value)
        self._schedule()

    def request(self, user, priority, callback):
        self._seq += 1
        heapq.heappush(self._requests, (-priority, time.time(), self._seq, callback))
        self._schedule()

    def remove(self, value):
        if value in self._values:
            self._values.remove(value)
            self._schedule()

    def summary(self):
        """Get summary information of the scheduler."""
        return {"free": len(self._values),
                "pending": len(self._requests)}


class TCPEventHandler(tornado_util.TCPHandler):
    """Base asynchronous message handler.

    The tracker and client follows a simple message protocol.
    The message is in form [nbytes(int32)] [json-str].
    All the information is packed in json-str

    Parameters
    ----------
    tracker : TrackerServerHandler
        The tracker that accepted this connection.
    sock : socket.socket
        The accepted connection socket.
    addr : tuple
        The peer address returned by ``accept``.
    """
    def __init__(self, tracker, sock, addr):
        super(TCPEventHandler, self).__init__(sock)
        self._data = bytearray()
        self._tracker = tracker
        self._msg_size = 0
        self._addr = addr
        # non-zero until the 4-byte magic handshake has been processed
        self._init_req_nbytes = 4
        self._info = {"addr": addr}
        # list of pending match keys that has not been used.
        self.pending_matchkeys = set()
        self._tracker._connections.add(self)
        # every value this connection PUT, so the tracker can withdraw them on close
        self.put_values = []

    def name(self):
        """name of connection"""
        return "TCPSocket: %s" % str(self._addr)

    def summary(self):
        """Summary of this connection"""
        return self._info

    def _init_conn(self, message):
        """Initialize the connection: check the client magic and echo ours back."""
        if len(message) != 4:
            logger.warning("Invalid connection from %s", self.name())
            self.close()
            # stop here: unpacking a non-4-byte message would raise struct.error
            return
        magic = struct.unpack('<i', message)[0]
        if magic != RPC_TRACKER_MAGIC:
            logger.warning("Invalid magic from %s", self.name())
            self.close()
            # stop here: do not acknowledge a peer we just rejected
            return
        self.write_message(struct.pack('<i', RPC_TRACKER_MAGIC), binary=True)
        self._init_req_nbytes = 0

    def on_message(self, message):
        """Callback when a message is received.

        Parameters
        ----------
        message : bytearray
            The bytes received
        """
        assert isinstance(message, bytes)
        if self._init_req_nbytes:
            self._init_conn(message)
            return

        self._data += message

        # Dispatch every complete [size][json] frame currently buffered.
        while True:
            if self._msg_size == 0:
                if len(self._data) >= 4:
                    self._msg_size = struct.unpack('<i', self._data[:4])[0]
                else:
                    return
            if self._msg_size != 0 and len(self._data) >= self._msg_size + 4:
                msg = py_str(bytes(self._data[4:4 + self._msg_size]))
                del self._data[:4 + self._msg_size]
                self._msg_size = 0
                self.call_handler(json.loads(msg))
            else:
                return

    def ret_value(self, data):
        """return value to the output"""
        payload = json.dumps(data).encode("utf-8")
        # Length prefix counts the encoded bytes; header and body are sent in
        # one write so a failure cannot leave a half-written frame behind.
        self.write_message(struct.pack('<i', len(payload)) + payload, binary=True)

    def call_handler(self, args):
        """Event handler when json request arrives."""
        code = args[0]
        if code == TrackerCode.PUT:
            key = args[1]
            port, matchkey = args[2]
            self.pending_matchkeys.add(matchkey)
            # got custom address (from rpc server)
            if len(args) >= 4 and args[3] is not None:
                value = (self, args[3], port, matchkey)
            else:
                value = (self, self._addr[0], port, matchkey)
            self._tracker.put(key, value)
            self.put_values.append(value)
            self.ret_value(TrackerCode.SUCCESS)
        elif code == TrackerCode.REQUEST:
            key = args[1]
            user = args[2]
            priority = args[3]
            def _cb(value):
                # if the connection is already closed
                if not self._sock:
                    return False
                try:
                    self.ret_value([TrackerCode.SUCCESS, value])
                except (socket.error, IOError):
                    return False
                return True
            self._tracker.request(key, user, priority, _cb)
        elif code == TrackerCode.PING:
            self.ret_value(TrackerCode.SUCCESS)
        elif code == TrackerCode.GET_PENDING_MATCHKEYS:
            self.ret_value(list(self.pending_matchkeys))
        elif code == TrackerCode.STOP:
            # safe stop tracker
            if self._tracker._stop_key == args[1]:
                self.ret_value(TrackerCode.SUCCESS)
                self._tracker.stop()
            else:
                self.ret_value(TrackerCode.FAIL)
        elif code == TrackerCode.UPDATE_INFO:
            self._info.update(args[1])
            self.ret_value(TrackerCode.SUCCESS)
        elif code == TrackerCode.SUMMARY:
            status = self._tracker.summary()
            self.ret_value([TrackerCode.SUCCESS, status])
        else:
            logger.warning("Unknown tracker code %d", code)
            self.close()

    def on_close(self):
        self._tracker.close(self)

    def on_error(self, err):
        logger.warning("%s: Error in RPC Tracker: %s", self.name(), err)
        self.close()


class TrackerServerHandler(object):
    """Tracker that tracks the resources.

    Runs the accept loop on the listening socket, creates one TCPEventHandler
    per accepted connection and keeps one scheduler per resource key.

    Parameters
    ----------
    sock : socket.socket
        The bound, listening socket; it is switched to non-blocking mode.
    stop_key : str
        Value a STOP request must carry to shut the tracker down.
    """
    def __init__(self, sock, stop_key):
        self._scheduler_map = {}
        self._sock = sock
        self._sock.setblocking(0)
        self._ioloop = ioloop.IOLoop.current()
        self._stop_key = stop_key
        self._connections = set()
        def _event_handler(_, events):
            self._on_event(events)
        self._ioloop.add_handler(
            self._sock.fileno(), _event_handler, self._ioloop.READ)

    def _on_event(self, _):
        """Accept every connection currently pending on the listening socket."""
        while True:
            try:
                conn, addr = self._sock.accept()
                TCPEventHandler(self, conn, addr)
            except socket.error as err:
                code = err.args[0] if err.args else None
                if code in (errno.EAGAIN, errno.EWOULDBLOCK):
                    break
                if code == errno.ECONNABORTED:
                    # peer went away before we accepted it; try the next one
                    continue
                # Any other error (e.g. EMFILE) would repeat forever and starve
                # the IOLoop; give control back and retry on the next event.
                logger.warning("Error accepting connection: %s", err)
                break

    def create_scheduler(self, key):
        """Create a new scheduler."""
        return PriorityScheduler(key)

    def _get_scheduler(self, key):
        """Return the scheduler for ``key``, creating it on first use."""
        if key not in self._scheduler_map:
            self._scheduler_map[key] = self.create_scheduler(key)
        return self._scheduler_map[key]

    def put(self, key, value):
        """Report a new resource to the tracker."""
        self._get_scheduler(key).put(value)

    def request(self, key, user, priority, callback):
        """Request a new resource."""
        self._get_scheduler(key).request(user, priority, callback)

    def close(self, conn):
        """Forget a closed connection and withdraw every resource it put.

        Values are withdrawn from all schedulers instead of the one derived
        from the connection's ``info['key']``: that key may be missing, may
        contain ':' itself, or may not name an existing scheduler. Removing a
        value that a scheduler does not hold is a no-op.
        """
        # discard: tolerate close() being reported more than once
        self._connections.discard(conn)
        for value in conn.put_values:
            for scheduler in self._scheduler_map.values():
                scheduler.remove(value)

    def stop(self):
        """Safely stop tracker."""
        for conn in list(self._connections):
            conn.close()
        self._sock.close()
        self._ioloop.stop()

    def summary(self):
        """Return a dict summarizing current status."""
        qinfo = {}
        for k, v in self._scheduler_map.items():
            qinfo[k] = v.summary()
        cinfo = []
        # ignore client connections without key
        for conn in self._connections:
            res = conn.summary()
            if res.get("key", "").startswith("server"):
                cinfo.append(res)
        return {"queue_info": qinfo, "server_info": cinfo}

    def run(self):
        """Run the tracker server"""
        self._ioloop.start()

def _tracker_server(listen_sock, stop_key):
    """Process entry point: serve tracker requests on ``listen_sock`` until stopped."""
    TrackerServerHandler(listen_sock, stop_key).run()


class Tracker(object):
    """Start RPC tracker on a separate process.

    Python implementation based on multi-processing.

    Parameters
    ----------
    host : str
        The host url of the server.

    port : int
        The first TCP port to try to bind to; 0 lets the OS choose a free port.

    port_end : int, optional
        The end (exclusive) of the TCP port range to search.

    silent: bool, optional
        Whether run in silent mode
    """
    def __init__(self,
                 host,
                 port=9190,
                 port_end=9199,
                 silent=False):
        # Set before anything can fail so terminate()/__del__ stay safe.
        self.proc = None
        self.host = host
        self.port = None
        if silent:
            logger.setLevel(logging.WARNING)

        self.stop_key = base.random_key("tracker")
        sock = socket.socket(base.get_addr_family((host, port)), socket.SOCK_STREAM)
        try:
            for my_port in range(port, port_end):
                try:
                    sock.bind((host, my_port))
                except socket.error as sock_err:
                    # address already in use (98 on Linux, 48 on macOS)
                    if sock_err.errno in (errno.EADDRINUSE, 98, 48):
                        continue
                    raise
                # ask the socket, so that port 0 yields the OS-assigned port
                self.port = sock.getsockname()[1]
                break
            if self.port is None:
                raise ValueError("cannot bind to any port in [%d, %d)" % (port, port_end))
            logger.info("bind to %s:%d", host, self.port)
            sock.listen(1)
            self.proc = multiprocessing.Process(
                target=_tracker_server, args=(sock, self.stop_key))
            self.proc.start()
        finally:
            # close the socket on this process (also on failure, to avoid a leak)
            sock.close()

    def _stop_tracker(self):
        """Send a STOP request carrying ``stop_key`` to the tracker process."""
        sock = socket.socket(base.get_addr_family((self.host, self.port)), socket.SOCK_STREAM)
        try:
            sock.connect((self.host, self.port))
            sock.sendall(struct.pack("<i", base.RPC_TRACKER_MAGIC))
            magic = struct.unpack("<i", base.recvall(sock, 4))[0]
            if magic != base.RPC_TRACKER_MAGIC:
                raise RuntimeError("%s:%d is not an RPC Tracker" % (self.host, self.port))
            base.sendjson(sock, [TrackerCode.STOP, self.stop_key])
            if base.recvjson(sock) != TrackerCode.SUCCESS:
                raise RuntimeError("RPC Tracker rejected the stop request")
        finally:
            sock.close()

    def terminate(self):
        """Terminate the server process.

        A graceful STOP request is tried first; if it fails or the process
        is still alive after one second, the process is terminated.
        """
        # getattr: __init__ may have failed before self.proc was assigned
        proc = getattr(self, "proc", None)
        if proc is None:
            return
        if proc.is_alive():
            try:
                self._stop_tracker()
            except Exception as err:  # pylint: disable=broad-except
                # best effort: fall through to a hard terminate below
                logger.warning("Failed to stop tracker gracefully: %s", err)
            else:
                proc.join(1)
        if proc.is_alive():
            logger.info("Terminating Tracker Server...")
            proc.terminate()
        self.proc = None

    def __del__(self):
        self.terminate()