"""RPC proxy, allows both client/server to connect and match connection. In normal RPC, client directly connect to server's IP address. Sometimes this cannot be done when server do not have a static address. RPCProxy allows both client and server connect to the proxy server, the proxy server will forward the message between the client and server. """ # pylint: disable=unused-variable, unused-argument from __future__ import absolute_import import os import logging import socket import multiprocessing import errno import struct import time try: import tornado from tornado import gen from tornado import websocket from tornado import ioloop from . import tornado_util except ImportError as error_msg: raise ImportError("RPCProxy module requires tornado package %s" % error_msg) from . import base from .base import TrackerCode from .server import _server_env from .._ffi.base import py_str class ForwardHandler(object): """Forward handler to forward the message.""" def _init_handler(self): """Initialize handler.""" self._init_message = bytes() self._init_req_nbytes = 4 self._magic = None self.timeout = None self._rpc_key_length = None self._done = False self._proxy = ProxyServerHandler.current assert self._proxy self.rpc_key = None self.match_key = None self.forward_proxy = None self.alloc_time = None def __del__(self): logging.info("Delete %s...", self.name()) def name(self): """Name of this connection.""" return "RPCConnection" def _init_step(self, message): if self._magic is None: assert len(message) == 4 self._magic = struct.unpack('<i', message)[0] if self._magic != base.RPC_MAGIC: logging.info("Invalid RPC magic from %s", self.name()) self.close() self._init_req_nbytes = 4 elif self._rpc_key_length is None: assert len(message) == 4 self._rpc_key_length = struct.unpack('<i', message)[0] self._init_req_nbytes = self._rpc_key_length elif self.rpc_key is None: assert len(message) == self._rpc_key_length self.rpc_key = py_str(message) # match key is used to do the matching self.match_key = self.rpc_key[7:].split()[0] self.on_start() else: assert False def on_start(self): """Event when the initialization is completed""" self._proxy.handler_ready(self) def on_data(self, message): """on data""" assert isinstance(message, bytes) if self.forward_proxy: self.forward_proxy.send_data(message) else: while message and self._init_req_nbytes > len(self._init_message): nbytes = self._init_req_nbytes - len(self._init_message) self._init_message += message[:nbytes] message = message[nbytes:] if self._init_req_nbytes == len(self._init_message): temp = self._init_message self._init_req_nbytes = 0 self._init_message = bytes() self._init_step(temp) if message: logging.info("Invalid RPC protocol, too many bytes %s", self.name()) self.close() def on_error(self, err): logging.info("%s: Error in RPC %s", self.name(), err) self.close_pair() def close_pair(self): if self.forward_proxy: self.forward_proxy.signal_close() self.forward_proxy = None self.close() def on_close_event(self): """on close event""" assert not self._done logging.info("RPCProxy:on_close %s ...", self.name()) if self.match_key: key = self.match_key if self._proxy._client_pool.get(key, None) == self: self._proxy._client_pool.pop(key) if self._proxy._server_pool.get(key, None) == self: self._proxy._server_pool.pop(key) self._done = True self.forward_proxy = None class TCPHandler(tornado_util.TCPHandler, ForwardHandler): """Event driven TCP handler.""" def __init__(self, sock, addr): super(TCPHandler, self).__init__(sock) self._init_handler() self.addr = addr def name(self): return "TCPSocketProxy:%s:%s" % (str(self.addr[0]), self.rpc_key) def send_data(self, message, binary=True): self.write_message(message, True) def on_message(self, message): self.on_data(message) def on_close(self): if self.forward_proxy: self.forward_proxy.signal_close() self.forward_proxy = None logging.info("%s Close socket..", self.name()) self.on_close_event() class WebSocketHandler(websocket.WebSocketHandler, ForwardHandler): """Handler for websockets.""" def __init__(self, *args, **kwargs): super(WebSocketHandler, self).__init__(*args, **kwargs) self._init_handler() def name(self): return "WebSocketProxy:%s" % (self.rpc_key) def on_message(self, message): self.on_data(message) def data_received(self, _): raise NotImplementedError() def send_data(self, message): try: self.write_message(message, True) except websocket.WebSocketClosedError as err: self.on_error(err) def on_close(self): if self.forward_proxy: self.forward_proxy.signal_close() self.forward_proxy = None self.on_close_event() def signal_close(self): self.close() class RequestHandler(tornado.web.RequestHandler): """Handles html request.""" def __init__(self, *args, **kwargs): file_path = kwargs.pop("file_path") if file_path.endswith("html"): self.page = open(file_path).read() web_port = kwargs.pop("rpc_web_port", None) if web_port: self.page = self.page.replace( "ws://localhost:9190/ws", "ws://localhost:%d/ws" % web_port) else: self.page = open(file_path, "rb").read() super(RequestHandler, self).__init__(*args, **kwargs) def data_received(self, _): pass def get(self, *args, **kwargs): self.write(self.page) class ProxyServerHandler(object): """Internal proxy server handler class.""" current = None def __init__(self, sock, listen_port, web_port, timeout_client, timeout_server, tracker_addr, index_page=None, resource_files=None): assert ProxyServerHandler.current is None ProxyServerHandler.current = self if web_port: handlers = [ (r"/ws", WebSocketHandler), ] if index_page: handlers.append( (r"/", RequestHandler, {"file_path": index_page, "rpc_web_port": web_port})) logging.info("Serving RPC index html page at http://localhost:%d", web_port) resource_files = resource_files if resource_files else [] for fname in resource_files: basename = os.path.basename(fname) pair = (r"/%s" % basename, RequestHandler, {"file_path": fname}) handlers.append(pair) logging.info(pair) self.app = tornado.web.Application(handlers) self.app.listen(web_port) self.sock = sock self.sock.setblocking(0) self.loop = ioloop.IOLoop.current() def event_handler(_, events): self._on_event(events) self.loop.add_handler( self.sock.fileno(), event_handler, self.loop.READ) self._client_pool = {} self._server_pool = {} self.timeout_alloc = 5 self.timeout_client = timeout_client self.timeout_server = timeout_server # tracker information self._listen_port = listen_port self._tracker_addr = tracker_addr self._tracker_conn = None self._tracker_pending_puts = [] self._key_set = set() self.update_tracker_period = 2 if tracker_addr: logging.info("Tracker address:%s", str(tracker_addr)) def _callback(): self._update_tracker(True) self.loop.call_later(self.update_tracker_period, _callback) logging.info("RPCProxy: Websock port bind to %d", web_port) def _on_event(self, _): while True: try: conn, addr = self.sock.accept() TCPHandler(conn, addr) except socket.error as err: if err.args[0] in (errno.EAGAIN, errno.EWOULDBLOCK): break def _pair_up(self, lhs, rhs): lhs.forward_proxy = rhs rhs.forward_proxy = lhs lhs.send_data(struct.pack('<i', base.RPC_CODE_SUCCESS)) lhs.send_data(struct.pack('<i', len(rhs.rpc_key))) lhs.send_data(rhs.rpc_key.encode("utf-8")) rhs.send_data(struct.pack('<i', base.RPC_CODE_SUCCESS)) rhs.send_data(struct.pack('<i', len(lhs.rpc_key))) rhs.send_data(lhs.rpc_key.encode("utf-8")) logging.info("Pairup connect %s and %s", lhs.name(), rhs.name()) def _regenerate_server_keys(self, keys): """Regenerate keys for server pool""" keyset = set(self._server_pool.keys()) new_keys = [] # re-generate the server match key, so old information is invalidated. for key in keys: rpc_key, _ = key.split(":") handle = self._server_pool[key] del self._server_pool[key] new_key = base.random_key(rpc_key + ":", keyset) self._server_pool[new_key] = handle keyset.add(new_key) new_keys.append(new_key) return new_keys def _update_tracker(self, period_update=False): """Update information on tracker.""" try: if self._tracker_conn is None: self._tracker_conn = socket.socket(socket.AF_INET, socket.SOCK_STREAM) self._tracker_conn.connect(self._tracker_addr) self._tracker_conn.sendall(struct.pack("<i", base.RPC_TRACKER_MAGIC)) magic = struct.unpack("<i", base.recvall(self._tracker_conn, 4))[0] if magic != base.RPC_TRACKER_MAGIC: self.loop.stop() raise RuntimeError("%s is not RPC Tracker" % str(self._tracker_addr)) # just connect to tracker, need to update all keys self._tracker_pending_puts = self._server_pool.keys() if self._tracker_conn and period_update: # periodically update tracker information # regenerate key if the key is not in tracker anymore # and there is no in-coming connection after timeout_alloc base.sendjson(self._tracker_conn, [TrackerCode.GET_PENDING_MATCHKEYS]) pending_keys = set(base.recvjson(self._tracker_conn)) update_keys = [] for k, v in self._server_pool.items(): if k not in pending_keys: if v.alloc_time is None: v.alloc_time = time.time() elif time.time() - v.alloc_time > self.timeout_alloc: update_keys.append(k) v.alloc_time = None if update_keys: logging.info("RPCProxy: No incoming conn on %s, regenerate keys...", str(update_keys)) new_keys = self._regenerate_server_keys(update_keys) self._tracker_pending_puts += new_keys need_update_info = False # report new connections for key in self._tracker_pending_puts: rpc_key = key.split(":")[0] base.sendjson(self._tracker_conn, [TrackerCode.PUT, rpc_key, (self._listen_port, key), None]) assert base.recvjson(self._tracker_conn) == TrackerCode.SUCCESS if rpc_key not in self._key_set: self._key_set.add(rpc_key) need_update_info = True if need_update_info: keylist = "[" + ",".join(self._key_set) + "]" cinfo = {"key": "server:proxy" + keylist} base.sendjson(self._tracker_conn, [TrackerCode.UPDATE_INFO, cinfo]) assert base.recvjson(self._tracker_conn) == TrackerCode.SUCCESS self._tracker_pending_puts = [] except (socket.error, IOError) as err: logging.info( "Lost tracker connection: %s, try reconnect in %g sec", str(err), self.update_tracker_period) self._tracker_conn.close() self._tracker_conn = None self._regenerate_server_keys(self._server_pool.keys()) if period_update: def _callback(): self._update_tracker(True) self.loop.call_later(self.update_tracker_period, _callback) def _handler_ready_tracker_mode(self, handler): """tracker mode to handle handler ready.""" if handler.rpc_key.startswith("server:"): key = base.random_key(handler.match_key + ":", self._server_pool) handler.match_key = key self._server_pool[key] = handler self._tracker_pending_puts.append(key) self._update_tracker() else: if handler.match_key in self._server_pool: self._pair_up(self._server_pool.pop(handler.match_key), handler) else: handler.send_data(struct.pack('<i', base.RPC_CODE_MISMATCH)) handler.signal_close() def _handler_ready_proxy_mode(self, handler): """Normal proxy mode when handler is ready.""" if handler.rpc_key.startswith("server:"): pool_src, pool_dst = self._client_pool, self._server_pool timeout = self.timeout_server else: pool_src, pool_dst = self._server_pool, self._client_pool timeout = self.timeout_client key = handler.match_key if key in pool_src: self._pair_up(pool_src.pop(key), handler) return elif key not in pool_dst: pool_dst[key] = handler def cleanup(): """Cleanup client connection if timeout""" if pool_dst.get(key, None) == handler: logging.info("Timeout client connection %s, cannot find match key=%s", handler.name(), key) pool_dst.pop(key) handler.send_data(struct.pack('<i', base.RPC_CODE_MISMATCH)) handler.signal_close() self.loop.call_later(timeout, cleanup) else: logging.info("Duplicate connection with same key=%s", key) handler.send_data(struct.pack('<i', base.RPC_CODE_DUPLICATE)) handler.signal_close() def handler_ready(self, handler): """Report handler to be ready.""" logging.info("Handler ready %s", handler.name()) if self._tracker_addr: self._handler_ready_tracker_mode(handler) else: self._handler_ready_proxy_mode(handler) def run(self): """Run the proxy server""" ioloop.IOLoop.current().start() def _proxy_server(listen_sock, listen_port, web_port, timeout_client, timeout_server, tracker_addr, index_page, resource_files): handler = ProxyServerHandler(listen_sock, listen_port, web_port, timeout_client, timeout_server, tracker_addr, index_page, resource_files) handler.run() class Proxy(object): """Start RPC proxy server on a seperate process. Python implementation based on multi-processing. Parameters ---------- host : str The host url of the server. port : int The TCP port to be bind to port_end : int, optional The end TCP port to search web_port : int, optional The http/websocket port of the server. timeout_client : float, optional Timeout of client until it sees a matching connection. timeout_server : float, optional Timeout of server until it sees a matching connection. index_page : str, optional Path to an index page that can be used to display at proxy index. resource_files : str, optional Path to local resources that can be included in the http request """ def __init__(self, host, port=9091, port_end=9199, web_port=0, timeout_client=600, timeout_server=600, tracker_addr=None, index_page=None, resource_files=None): sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) self.port = None for my_port in range(port, port_end): try: sock.bind((host, my_port)) self.port = my_port break except socket.error as sock_err: if sock_err.errno in [98, 48]: continue else: raise sock_err if not self.port: raise ValueError("cannot bind to any port in [%d, %d)" % (port, port_end)) logging.info("RPCProxy: client port bind to %s:%d", host, self.port) sock.listen(1) self.proc = multiprocessing.Process( target=_proxy_server, args=(sock, self.port, web_port, timeout_client, timeout_server, tracker_addr, index_page, resource_files)) self.proc.start() sock.close() self.host = host def terminate(self): """Terminate the server process""" if self.proc: logging.info("Terminating Proxy Server...") self.proc.terminate() self.proc = None def __del__(self): self.terminate() def websocket_proxy_server(url, key=""): """Create a RPC server that uses an websocket that connects to a proxy. Parameters ---------- url : str The url to be connected. key : str The key to identify the server. """ def create_on_message(conn): def _fsend(data): data = bytes(data) conn.write_message(data, binary=True) return len(data) on_message = base._CreateEventDrivenServer( _fsend, "WebSocketProxyServer", "%toinit") return on_message @gen.coroutine def _connect(key): conn = yield websocket.websocket_connect(url) on_message = create_on_message(conn) temp = _server_env(None, None) # Start connecton conn.write_message(struct.pack('<i', base.RPC_MAGIC), binary=True) key = "server:" + key conn.write_message(struct.pack('<i', len(key)), binary=True) conn.write_message(key.encode("utf-8"), binary=True) msg = yield conn.read_message() assert len(msg) >= 4 magic = struct.unpack('<i', msg[:4])[0] if magic == base.RPC_CODE_DUPLICATE: raise RuntimeError("key: %s has already been used in proxy" % key) elif magic == base.RPC_CODE_MISMATCH: logging.info("RPCProxy do not have matching client key %s", key) elif magic != base.RPC_CODE_SUCCESS: raise RuntimeError("%s is not RPC Proxy" % url) msg = msg[4:] logging.info("Connection established with remote") if msg: on_message(bytearray(msg), 3) while True: try: msg = yield conn.read_message() if msg is None: break on_message(bytearray(msg), 3) except websocket.WebSocketClosedError as err: break logging.info("WebSocketProxyServer closed...") temp.remove() ioloop.IOLoop.current().stop() ioloop.IOLoop.current().spawn_callback(_connect, key) ioloop.IOLoop.current().start()