Commit c324494f by Tianqi Chen Committed by GitHub

[RUNTIME][RPC] Change RPCServer to Event Driven Code (#243)

* [RUNTIME][RPC] Change RPCServer to Event Driven Code

* fix
parent f4d1dddb
......@@ -58,7 +58,7 @@ CFLAGS = -std=c++11 -Wall -O2 $(INCLUDE_FLAGS) -fPIC
LLVM_CFLAGS= -fno-rtti -DDMLC_ENABLE_RTTI=0
FRAMEWORKS =
OBJCFLAGS = -fno-objc-arc
EMCC_FLAGS= -s RESERVED_FUNCTION_POINTERS=2 -s NO_EXIT_RUNTIME=1 -DDMLC_LOG_STACK_TRACE=0\
EMCC_FLAGS= -s RESERVED_FUNCTION_POINTERS=2 -s NO_EXIT_RUNTIME=1 -s MAIN_MODULE=1 -DDMLC_LOG_STACK_TRACE=0\
-std=c++11 -Oz $(INCLUDE_FLAGS)
# Dependency specific rules
......
......@@ -212,13 +212,6 @@ class TVMPODValue_ {
int type_code() const {
return type_code_;
}
protected:
friend class TVMArgsSetter;
friend class TVMRetValue;
TVMPODValue_() : type_code_(kNull) {}
TVMPODValue_(TVMValue value, int type_code)
: value_(value), type_code_(type_code) {}
/*!
* \brief return handle as specific pointer type.
* \tparam T the data type.
......@@ -228,6 +221,14 @@ class TVMPODValue_ {
T* ptr() const {
return static_cast<T*>(value_.v_handle);
}
protected:
friend class TVMArgsSetter;
friend class TVMRetValue;
TVMPODValue_() : type_code_(kNull) {}
TVMPODValue_(TVMValue value, int type_code)
: value_(value), type_code_(type_code) {}
/*! \brief The value */
TVMValue value_;
/*! \brief the type code */
......@@ -347,6 +348,8 @@ class TVMRetValue : public TVMPODValue_ {
operator std::string() const {
if (type_code_ == kTVMType) {
return TVMType2String(operator TVMType());
} else if (type_code_ == kBytes) {
return *ptr<std::string>();
}
TVM_CHECK_TYPE_CODE(type_code_, kStr);
return *ptr<std::string>();
......@@ -414,6 +417,10 @@ class TVMRetValue : public TVMPODValue_ {
this->SwitchToClass(kStr, value);
return *this;
}
TVMRetValue& operator=(TVMByteArray value) {
this->SwitchToClass(kBytes, std::string(value.data, value.size));
return *this;
}
TVMRetValue& operator=(PackedFunc f) {
this->SwitchToClass(kFuncHandle, f);
return *this;
......@@ -442,7 +449,7 @@ class TVMRetValue : public TVMPODValue_ {
void MoveToCHost(TVMValue* ret_value,
int* ret_type_code) {
// cannot move str; need specially handle.
CHECK(type_code_ != kStr);
CHECK(type_code_ != kStr && type_code_ != kBytes);
*ret_value = value_;
*ret_type_code = type_code_;
type_code_ = kNull;
......@@ -470,11 +477,14 @@ class TVMRetValue : public TVMPODValue_ {
template<typename T>
void Assign(const T& other) {
switch (other.type_code()) {
case kStr:
case kBytes: {
case kStr: {
SwitchToClass<std::string>(kStr, other);
break;
}
case kBytes: {
SwitchToClass<std::string>(kBytes, other);
break;
}
case kFuncHandle: {
SwitchToClass<PackedFunc>(kFuncHandle, other);
break;
......@@ -702,6 +712,7 @@ class TVMArgsSetter {
values_[i].v_str = value.ptr<std::string>()->c_str();
type_codes_[i] = kStr;
} else {
CHECK_NE(value.type_code(), kBytes) << "not handled.";
values_[i] = value.value_;
type_codes_[i] = value.type_code();
}
......
......@@ -7,6 +7,7 @@ from .._ffi.libinfo import find_lib_path
def create_js(output,
objects,
options=None,
side_module=False,
cc="emcc"):
"""Create emscripten javascript library.
......@@ -29,6 +30,8 @@ def create_js(output,
cmd += ["-s", "NO_EXIT_RUNTIME=1"]
cmd += ["-Oz"]
cmd += ["-o", output]
if side_module:
cmd += ["-s", "SIDE_MODULE=1"]
objects = [objects] if isinstance(objects, str) else objects
with_runtime = False
......@@ -36,7 +39,7 @@ def create_js(output,
if obj.find("libtvm_web_runtime.bc") != -1:
with_runtime = True
if not with_runtime:
if not with_runtime and not side_module:
objects += [find_lib_path("libtvm_web_runtime.bc")[0]]
cmd += objects
......
......@@ -19,32 +19,20 @@ from . import util, cc_compiler
from ..module import load as _load_module
from .._ffi.function import _init_api, register_func
from .._ffi.ndarray import context as _context
from .._ffi.base import py_str
RPC_MAGIC = 0xff271
RPC_SESS_MASK = 128
def _serve_loop(sock, addr):
"""Server loop"""
sockfd = sock.fileno()
def _server_env():
"""Server environment function return temp dir"""
temp = util.tempdir()
# pylint: disable=unused-variable
@register_func("tvm.contrib.rpc.server.upload")
def upload(file_name, blob):
"""Upload the blob to remote temp file"""
path = temp.relpath(file_name)
with open(path, "wb") as out_file:
out_file.write(blob)
logging.info("upload %s", path)
@register_func("tvm.contrib.rpc.server.download")
def download(file_name):
"""Download file from remote"""
path = temp.relpath(file_name)
dat = bytearray(open(path, "rb").read())
logging.info("download %s", path)
return dat
@register_func("tvm.contrib.rpc.server.workpath")
def get_workpath(path):
return temp.relpath(path)
@register_func("tvm.contrib.rpc.server.load_module")
@register_func("tvm.contrib.rpc.server.load_module", override=True)
def load_module(file_name):
"""Load module from remote side."""
path = temp.relpath(file_name)
......@@ -53,11 +41,16 @@ def _serve_loop(sock, addr):
logging.info('Create shared library based on %s', path)
cc_compiler.create_shared(path + '.so', path)
path += '.so'
m = _load_module(path)
logging.info("load_module %s", path)
return m
return temp
def _serve_loop(sock, addr):
"""Server loop"""
sockfd = sock.fileno()
temp = _server_env()
_ServerLoop(sockfd)
temp.remove()
logging.info("Finish serving %s", addr)
......@@ -78,11 +71,16 @@ def _listen_loop(sock):
while True:
conn, addr = sock.accept()
logging.info("RPCServer: connection from %s", addr)
conn.sendall(struct.pack('@i', RPC_MAGIC))
magic = struct.unpack('@i', _recvall(conn, 4))[0]
if magic != RPC_MAGIC:
conn.close()
continue
keylen = struct.unpack('@i', _recvall(conn, 4))[0]
key = py_str(_recvall(conn, keylen))
if not key.startswith("client:"):
conn.sendall(struct.pack('@i', RPC_MAGIC + 2))
else:
conn.sendall(struct.pack('@i', RPC_MAGIC))
logging.info("Connection from %s", addr)
process = multiprocessing.Process(target=_serve_loop, args=(conn, addr))
process.deamon = True
......@@ -91,6 +89,27 @@ def _listen_loop(sock):
conn.close()
def _connect_proxy_loop(addr, key):
while True:
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.connect(addr)
sock.sendall(struct.pack('@i', RPC_MAGIC))
sock.sendall(struct.pack('@i', len(key)))
sock.sendall(key)
magic = struct.unpack('@i', _recvall(sock, 4))[0]
if magic == RPC_MAGIC + 1:
raise RuntimeError("key: %s has already been used in proxy" % key)
elif magic == RPC_MAGIC + 2:
logging.info("RPCProxy do not have matching client key %s", key)
elif magic != RPC_MAGIC:
raise RuntimeError("%s is not RPC Proxy" % str(addr))
logging.info("RPCProxy connected to %s", str(addr))
process = multiprocessing.Process(target=_serve_loop, args=(sock, addr))
process.deamon = True
process.start()
process.join()
class Server(object):
"""Start RPC server on a seperate process.
......@@ -108,27 +127,43 @@ class Server(object):
port_end : int, optional
The end port to search
is_proxy : bool, optional
Whether the address specified is a proxy.
If this is true, the host and port actually corresponds to the
address of the proxy server.
key : str, optional
The key used to identify the server in Proxy connection.
"""
def __init__(self, host, port=9091, port_end=9199):
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.port = None
for my_port in range(port, port_end):
try:
sock.bind((host, my_port))
self.port = my_port
break
except socket.error as sock_err:
if sock_err.errno in [98, 48]:
continue
else:
raise sock_err
if not self.port:
raise ValueError("cannot bind to any port in [%d, %d)" % (port, port_end))
logging.info("RPCServer: bind to %s:%d", host, self.port)
sock.listen(1)
self.sock = sock
def __init__(self, host, port=9091, port_end=9199, is_proxy=False, key=""):
self.host = host
self.proc = multiprocessing.Process(target=_listen_loop, args=(self.sock,))
self.port = port
if not is_proxy:
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.port = None
for my_port in range(port, port_end):
try:
sock.bind((host, my_port))
self.port = my_port
break
except socket.error as sock_err:
if sock_err.errno in [98, 48]:
continue
else:
raise sock_err
if not self.port:
raise ValueError("cannot bind to any port in [%d, %d)" % (port, port_end))
logging.info("RPCServer: bind to %s:%d", host, self.port)
sock.listen(1)
self.sock = sock
self.proc = multiprocessing.Process(
target=_listen_loop, args=(self.sock,))
else:
self.proc = multiprocessing.Process(
target=_connect_proxy_loop, args=((host, port), key))
self.proc.deamon = True
self.proc.start()
def terminate(self):
......@@ -141,6 +176,7 @@ class Server(object):
self.terminate()
class RPCSession(object):
"""RPC Client session module
......@@ -262,7 +298,7 @@ class RPCSession(object):
return _LoadRemoteModule(self._sess, path)
def connect(url, port):
def connect(url, port, key=""):
"""Connect to RPC Server
Parameters
......@@ -273,13 +309,16 @@ def connect(url, port):
port : int
The port to connect to
key : str, optional
Additional key to match server
Returns
-------
sess : RPCSession
The connected session.
"""
try:
sess = _Connect(url, port)
sess = _Connect(url, port, key)
except NameError:
raise RuntimeError('Please compile with USE_RPC=1')
return RPCSession(sess)
......
"""RPC proxy, allows both client/server to connect and match connection.
In normal RPC, client directly connect to server's IP address.
Sometimes this cannot be done when server do not have a static address.
RPCProxy allows both client and server connect to the proxy server,
the proxy server will forward the message between the client and server.
"""
# pylint: disable=unused-variable, unused-argument
from __future__ import absolute_import
import os
import logging
import socket
import multiprocessing
import errno
import struct
try:
import tornado
from tornado import gen
from tornado import websocket
from tornado import ioloop
from tornado import websocket
except ImportError as error_msg:
raise ImportError("RPCProxy module requires tornado package %s" % error_msg)
from . import rpc
from .rpc import RPC_MAGIC, _server_env
from .._ffi.base import py_str
class ForwardHandler(object):
"""Forward handler to forward the message."""
def _init_handler(self):
"""Initialize handler."""
self._init_message = bytes()
self._init_req_nbytes = 4
self.forward_proxy = None
self._magic = None
self.timeout = None
self._rpc_key_length = None
self.rpc_key = None
self._done = False
def __del__(self):
logging.info("Delete %s...", self.name())
def name(self):
"""Name of this connection."""
return "RPCConnection"
def _init_step(self, message):
if self._magic is None:
assert len(message) == 4
self._magic = struct.unpack('@i', message)[0]
if self._magic != RPC_MAGIC:
logging.info("Invalid RPC magic from %s", self.name())
self.close()
self._init_req_nbytes = 4
elif self._rpc_key_length is None:
assert len(message) == 4
self._rpc_key_length = struct.unpack('@i', message)[0]
self._init_req_nbytes = self._rpc_key_length
elif self.rpc_key is None:
assert len(message) == self._rpc_key_length
self.rpc_key = py_str(message)
self.on_start()
else:
assert False
def on_start(self):
"""Event when the initialization is completed"""
ProxyServerHandler.current.handler_ready(self)
def on_data(self, message):
assert isinstance(message, bytes)
if self.forward_proxy:
self.forward_proxy.send_data(message)
else:
while message and self._init_req_nbytes > len(self._init_message):
nbytes = self._init_req_nbytes - len(self._init_message)
self._init_message += message[:nbytes]
message = message[nbytes:]
if self._init_req_nbytes == len(self._init_message):
temp = self._init_message
self._init_req_nbytes = 0
self._init_message = bytes()
self._init_step(temp)
if message:
logging.info("Invalid RPC protocol, too many bytes %s", self.name())
self.close()
def on_error(self, err):
logging.info("%s: Error in RPC %s", self.name(), err)
self.close_pair()
def close_pair(self):
if self.forward_proxy:
self.forward_proxy.signal_close()
self.forward_proxy = None
self.close()
def on_close_event(self):
assert not self._done
logging.info("RPCProxy:on_close %s ...", self.name())
self._done = True
self.forward_proxy = None
if self.rpc_key:
key = self.rpc_key[6:]
if ProxyServerHandler.current._client_pool.get(key, None) == self:
ProxyServerHandler.current._client_pool.pop(key)
if ProxyServerHandler.current._server_pool.get(key, None) == self:
ProxyServerHandler.current._server_pool.pop(key)
class TCPHandler(ForwardHandler):
"""Event driven TCP handler."""
def __init__(self, sock, addr):
self._init_handler()
self.sock = sock
assert self.sock
self.addr = addr
self.loop = ioloop.IOLoop.current()
self.sock.setblocking(0)
self.pending_write = []
self._signal_close = False
def event_handler(_, events):
self._on_event(events)
ioloop.IOLoop.current().add_handler(
self.sock.fileno(), event_handler, self.loop.READ | self.loop.ERROR)
def name(self):
return "TCPSocket: %s:%s" % (str(self.addr), self.rpc_key)
def send_data(self, message, binary=True):
assert binary
self.pending_write.append(message)
self._on_write()
def _on_write(self):
while self.pending_write:
try:
msg = self.pending_write[0]
nsend = self.sock.send(msg)
if nsend != len(msg):
self.pending_write[0] = msg[nsend:]
else:
del self.pending_write[0]
except socket.error as err:
if err.args[0] in (errno.EAGAIN, errno.EWOULDBLOCK):
break
else:
self.on_error(err)
if self.pending_write:
self.loop.update_handler(
self.sock.fileno(), self.loop.READ | self.loop.ERROR | self.loop.WRITE)
else:
if self._signal_close:
self.close()
else:
self.loop.update_handler(
self.sock.fileno(), self.loop.READ | self.loop.ERROR)
def _on_read(self):
try:
msg = bytes(self.sock.recv(4096))
if msg:
self.on_data(msg)
return True
else:
self.close_pair()
except socket.error as err:
if err.args[0] in (errno.EAGAIN, errno.EWOULDBLOCK):
pass
else:
self.on_error(e)
return False
def _on_event(self, events):
if (events & self.loop.ERROR) or (events & self.loop.READ):
if self._on_read() and (events & self.loop.WRITE):
self._on_write()
elif events & self.loop.WRITE:
self._on_write()
def signal_close(self):
if not self.pending_write:
self.close()
else:
self._signal_close = True
def close(self):
if self.sock is not None:
logging.info("%s Close socket..", self.name())
try:
ioloop.IOLoop.current().remove_handler(self.sock.fileno())
self.sock.close()
except socket.error:
pass
self.sock = None
self.on_close_event()
class WebSocketHandler(websocket.WebSocketHandler, ForwardHandler):
"""Handler for websockets."""
def __init__(self, *args, **kwargs):
super(WebSocketHandler, self).__init__(*args, **kwargs)
self._init_handler()
def name(self):
return "WebSocketProxy"
def on_message(self, message):
self.on_data(message)
def data_received(self, _):
raise NotImplementedError()
def send_data(self, message):
try:
self.write_message(message, True)
except websocket.WebSocketClosedError as err:
self.on_error(err)
def on_close(self):
if self.forward_proxy:
self.forward_proxy.signal_close()
self.forward_proxy = None
self.on_close_event()
def signal_close(self):
self.close()
class RequestHandler(tornado.web.RequestHandler):
"""Handles html request."""
def __init__(self, *args, **kwargs):
self.page = open(kwargs.pop("file_path")).read()
web_port = kwargs.pop("rpc_web_port", None)
if web_port:
self.page.replace(r"ws://localhost:9888/ws",
r"ws://localhost:%d/ws" % web_port)
super(RequestHandler, self).__init__(*args, **kwargs)
def data_received(self, _):
pass
def get(self, *args, **kwargs):
self.write(self.page)
class ProxyServerHandler(object):
"""Internal proxy server handler class."""
current = None
def __init__(self,
sock,
web_port,
timeout_client,
timeout_server,
index_page=None,
resource_files=None):
assert ProxyServerHandler.current is None
ProxyServerHandler.current = self
if web_port:
handlers = [
(r"/ws", WebSocketHandler),
]
if index_page:
handlers.append(
(r"/", RequestHandler, {"file_path": index_page, "rpc_web_port": web_port}))
logging.info("Serving RPC index html page at http://localhost:%d", web_port)
resource_files = resource_files if resource_files else []
for fname in resource_files:
basename = os.path.basename(fname)
pair = (r"/%s" % basename, RequestHandler, {"file_path": fname})
handlers.append(pair)
logging.info(pair)
self.app = tornado.web.Application(handlers)
self.app.listen(web_port)
self.sock = sock
self.sock.setblocking(0)
self.loop = ioloop.IOLoop.current()
def event_handler(_, events):
self._on_event(events)
self.loop.add_handler(
self.sock.fileno(), event_handler, self.loop.READ)
self._client_pool = {}
self._server_pool = {}
self.timeout_client = timeout_client
self.timeout_server = timeout_server
logging.info("RPCProxy: Websock port bind to %d", web_port)
def _on_event(self, _):
while True:
try:
conn, addr = self.sock.accept()
TCPHandler(conn, addr)
except socket.error as err:
if err.args[0] in (errno.EAGAIN, errno.EWOULDBLOCK):
break
def _pair_up(self, lhs, rhs):
lhs.forward_proxy = rhs
rhs.forward_proxy = lhs
lhs.send_data(struct.pack('@i', RPC_MAGIC))
rhs.send_data(struct.pack('@i', RPC_MAGIC))
logging.info("Pairup connect %s and %s", lhs.name(), rhs.name())
def handler_ready(self, handler):
"""Report handler to be ready."""
logging.info("Handler ready %s", handler.name())
key = handler.rpc_key[6:]
if handler.rpc_key.startswith("server:"):
pool_src, pool_dst = self._client_pool, self._server_pool
timeout = self.timeout_server
else:
pool_src, pool_dst = self._server_pool, self._client_pool
timeout = self.timeout_client
if key in pool_src:
self._pair_up(pool_src.pop(key), handler)
return
elif key not in pool_dst:
pool_dst[key] = handler
def cleanup():
"""Cleanup client connection if timeout"""
if pool_dst.get(key, None) == handler:
logging.info("Timeout client connection %s, cannot find match key=%s",
handler.name(), key)
pool_dst.pop(key)
handler.send_data(struct.pack('@i', RPC_MAGIC + 2))
handler.signal_close()
self.loop.call_later(timeout, cleanup)
else:
logging.info("Duplicate connection with same key=%s", key)
handler.send_data(struct.pack('@i', RPC_MAGIC + 1))
handler.signal_close()
def run(self):
"""Run the proxy server"""
ioloop.IOLoop.current().start()
def _proxy_server(listen_sock,
web_port,
timeout_client,
timeout_server,
index_page,
resource_files):
handler = ProxyServerHandler(listen_sock,
web_port,
timeout_client,
timeout_server,
index_page,
resource_files)
handler.run()
class Proxy(object):
"""Start RPC proxy server on a seperate process.
Python implementation based on multi-processing.
Parameters
----------
host : str
The host url of the server.
port : int
The TCP port to be bind to
port_end : int, optional
The end TCP port to search
web_port : int, optional
The http/websocket port of the server.
timeout_client : float, optional
Timeout of client until it sees a matching connection.
timeout_server : float, optional
Timeout of server until it sees a matching connection.
index_page : str, optional
Path to an index page that can be used to display at proxy index.
resource_files : str, optional
Path to local resources that can be included in the http request
"""
def __init__(self,
host,
port=9091,
port_end=9199,
web_port=0,
timeout_client=10,
timeout_server=600,
index_page=None,
resource_files=None):
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.port = None
for my_port in range(port, port_end):
try:
sock.bind((host, my_port))
self.port = my_port
break
except socket.error as sock_err:
if sock_err.errno in [98, 48]:
continue
else:
raise sock_err
if not self.port:
raise ValueError("cannot bind to any port in [%d, %d)" % (port, port_end))
logging.info("RPCProxy: client port bind to %s:%d", host, self.port)
sock.listen(1)
self.proc = multiprocessing.Process(
target=_proxy_server, args=(sock, web_port,
timeout_client, timeout_server,
index_page, resource_files))
self.proc.start()
self.host = host
def terminate(self):
"""Terminate the server process"""
if self.proc:
logging.info("Terminating Proxy Server...")
self.proc.terminate()
self.proc = None
def __del__(self):
self.terminate()
def websocket_proxy_server(url, key=""):
"""Create a RPC server that uses an websocket that connects to a proxy.
Parameters
----------
url : str
The url to be connected.
key : str
The key to identify the server.
"""
def create_on_message(conn):
def _fsend(data):
data = bytes(data)
conn.write_message(data, binary=True)
return len(data)
on_message = rpc._CreateEventDrivenServer(_fsend, "WebSocketProxyServer")
return on_message
@gen.coroutine
def _connect(key):
conn = yield websocket.websocket_connect(url)
on_message = create_on_message(conn)
temp = _server_env()
# Start connecton
conn.write_message(struct.pack('@i', RPC_MAGIC), binary=True)
key = "server:" + key
conn.write_message(struct.pack('@i', len(key)), binary=True)
conn.write_message(key.encode("utf-8"), binary=True)
msg = yield conn.read_message()
assert len(msg) >= 4
magic = struct.unpack('@i', msg[:4])[0]
if magic == RPC_MAGIC + 1:
raise RuntimeError("key: %s has already been used in proxy" % key)
elif magic == RPC_MAGIC + 2:
logging.info("RPCProxy do not have matching client key %s", key)
elif magic != RPC_MAGIC:
raise RuntimeError("%s is not RPC Proxy" % url)
logging.info("Connection established")
msg = msg[4:]
if msg:
on_message(bytearray(msg))
while True:
try:
msg = yield conn.read_message()
if msg is None:
break
on_message(bytearray(msg))
except websocket.WebSocketClosedError as err:
break
logging.info("WebSocketProxyServer closed...")
temp.remove()
ioloop.IOLoop.current().stop()
ioloop.IOLoop.current().spawn_callback(_connect, key)
ioloop.IOLoop.current().start()
"""RPC web proxy, allows redirect to websocket based RPC servers(browsers)"""
from __future__ import absolute_import
import logging
import argparse
import os
from ..contrib.rpc_proxy import Proxy
def find_example_resource():
"""Find resource examples."""
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
base_path = os.path.join(curr_path, "../../../")
index_page = os.path.join(base_path, "web/example_rpc.html")
js_files = [
os.path.join(base_path, "web/tvm_runtime.js"),
os.path.join(base_path, "lib/libtvm_web_runtime.js"),
os.path.join(base_path, "lib/libtvm_web_runtime.js.mem")
]
for fname in [index_page] + js_files:
if not os.path.exists(fname):
raise RuntimeError("Cannot find %s" % fname)
return index_page, js_files
def main():
"""Main funciton"""
parser = argparse.ArgumentParser()
parser.add_argument('--host', type=str, default="0.0.0.0",
help='the hostname of the server')
parser.add_argument('--port', type=int, default=9090,
help='The port of the PRC')
parser.add_argument('--web-port', type=int, default=9888,
help='The port of the http/websocket server')
parser.add_argument('--example-rpc', type=bool, default=False,
help='Whether to switch on example rpc mode')
args = parser.parse_args()
logging.basicConfig(level=logging.INFO)
if args.example_rpc:
index, js_files = find_example_resource()
prox = Proxy(args.host, port=args.port,
web_port=args.web_port, index_page=index,
resource_files=js_files)
else:
prox = Proxy(args.host, port=args.port, web_port=args.web_port)
prox.proc.join()
if __name__ == "__main__":
main()
/*!
* Copyright (c) 2017 by Contributors
* \file ring_buffer.h
* \brief this file aims to provide a wrapper of sockets
*/
#ifndef TVM_COMMON_RING_BUFFER_H_
#define TVM_COMMON_RING_BUFFER_H_
#include <vector>
#include <cstring>
#include <algorithm>
namespace tvm {
namespace common {
/*!
* \brief Ring buffer class for data buffering in IO.
* Enables easy usage for sync and async mode.
*/
class RingBuffer {
public:
/*! \brief Initial capacity of ring buffer. */
static const int kInitCapacity = 4 << 10;
/*! \brief constructor */
RingBuffer() : ring_(kInitCapacity) {}
/*! \return number of bytes available in buffer. */
size_t bytes_available() const {
return bytes_available_;
}
/*! \return Current capacity of buffer. */
size_t capacity() const {
return ring_.size();
}
/*!
* Reserve capacity to be at least n.
* Will only increase capacity if n is bigger than current capacity.
* \param n The size of capacity.
*/
void Reserve(size_t n) {
if (ring_.size() >= n) return;
size_t old_size = ring_.size();
size_t new_size = ring_.size();
while (new_size < n) {
new_size *= 2;
}
ring_.resize(new_size);
if (head_ptr_ + bytes_available_ > old_size) {
// copy the ring overflow part into the tail.
size_t ncopy = head_ptr_ + bytes_available_ - old_size;
memcpy(&ring_[0] + old_size, &ring_[0], ncopy);
}
}
/*!
* \brief Peform a non-blocking read from buffer
* size must be smaller than this->bytes_available()
* \param data the data pointer.
* \param size The number of bytes to read.
*/
void Read(void* data, size_t size) {
CHECK_GE(bytes_available_, size);
size_t ncopy = std::min(size, ring_.size() - head_ptr_);
memcpy(data, &ring_[0] + head_ptr_, ncopy);
if (ncopy < size) {
memcpy(reinterpret_cast<char*>(data) + ncopy,
&ring_[0], size - ncopy);
}
head_ptr_ = (head_ptr_ + size) % ring_.size();
bytes_available_ -= size;
}
/*!
* \brief Read data from buffer with and put them to non-blocking send function.
*
* \param frecv A send function handle to put the data to.
* \param max_nbytes Maximum number of bytes can to read.
* \tparam FSend A non-blocking function with signature size_t (const void* data, size_t size);
*/
template<typename FSend>
size_t ReadWithCallback(FSend fsend, size_t max_nbytes) {
size_t size = std::min(max_nbytes, bytes_available_);
CHECK_NE(size, 0U);
size_t ncopy = std::min(size, ring_.size() - head_ptr_);
size_t nsend = fsend(&ring_[0] + head_ptr_, ncopy);
bytes_available_ -= nsend;
if (ncopy == nsend && ncopy < size) {
size_t nsend2 = fsend(&ring_[0], size - ncopy);
bytes_available_ -= nsend2;
nsend += nsend2;
}
return nsend;
}
/*!
* \brief Write data into buffer, always ensures all data is written.
* \param data The data pointer
* \param size The size of data to be written.
*/
void Write(const void* data, size_t size) {
this->Reserve(bytes_available_ + size);
size_t tail = head_ptr_ + bytes_available_;
if (tail >= ring_.size()) {
memcpy(&ring_[0] + (tail - ring_.size()), data, size);
} else {
size_t ncopy = std::min(ring_.size() - tail, size);
memcpy(&ring_[0] + tail, data, ncopy);
if (ncopy < size) {
memcpy(&ring_[0], reinterpret_cast<const char*>(data) + ncopy, size - ncopy);
}
}
bytes_available_ += size;
}
/*!
* \brief Writen data into the buffer by give it a non-blocking callback function.
*
* \param frecv A receive function handle
* \param max_nbytes Maximum number of bytes can write.
* \tparam FRecv A non-blocking function with signature size_t (void* data, size_t size);
*/
template<typename FRecv>
size_t WriteWithCallback(FRecv frecv, size_t max_nbytes) {
this->Reserve(bytes_available_ + max_nbytes);
size_t nbytes = max_nbytes;
size_t tail = head_ptr_ + bytes_available_;
if (tail >= ring_.size()) {
size_t nrecv = frecv(&ring_[0] + (tail - ring_.size()), nbytes);
bytes_available_ += nrecv;
return nrecv;
} else {
size_t ncopy = std::min(ring_.size() - tail, nbytes);
size_t nrecv = frecv(&ring_[0] + tail, ncopy);
bytes_available_ += nrecv;
if (nrecv == ncopy && ncopy < nbytes) {
size_t nrecv2 = frecv(&ring_[0], nbytes - ncopy);
bytes_available_ += nrecv2;
nrecv += nrecv2;
}
return nrecv;
}
}
private:
// buffer head
size_t head_ptr_{0};
// number of bytes in the buffer.
size_t bytes_available_{0};
// The internald ata ring.
std::vector<char> ring_;
};
} // namespace common
} // namespace tvm
#endif // TVM_COMMON_RING_BUFFER_H_
......@@ -157,6 +157,7 @@ using namespace tvm::runtime;
struct TVMRuntimeEntry {
std::string ret_str;
std::string last_error;
TVMByteArray ret_bytes;
// threads used in parallel for
std::vector<std::thread> par_threads;
// errors created in parallel for.
......@@ -311,11 +312,23 @@ int TVMFuncCall(TVMFunctionHandle func,
TVMArgs(args, arg_type_codes, num_args), &rv);
// handle return string.
if (rv.type_code() == kStr ||
rv.type_code() == kTVMType) {
rv.type_code() == kTVMType ||
rv.type_code() == kBytes) {
TVMRuntimeEntry* e = TVMAPIRuntimeStore::Get();
e->ret_str = rv.operator std::string();
*ret_type_code = kStr;
ret_val->v_str = e->ret_str.c_str();
if (rv.type_code() != kTVMType) {
e->ret_str = *rv.ptr<std::string>();
} else {
e->ret_str = rv.operator std::string();
}
if (rv.type_code() == kBytes) {
e->ret_bytes.data = e->ret_str.c_str();
e->ret_bytes.size = e->ret_str.length();
*ret_type_code = kBytes;
ret_val->v_handle = &(e->ret_bytes);
} else {
*ret_type_code = kStr;
ret_val->v_str = e->ret_str.c_str();
}
} else {
rv.MoveToCHost(ret_val, ret_type_code);
}
......
/*!
* Copyright (c) 2017 by Contributors
* \file rpc_event_impl.cc
* \brief Event based RPC server implementation.
*/
#include <tvm/runtime/registry.h>
#include <memory>
#include "./rpc_session.h"
namespace tvm {
namespace runtime {
class CallbackChannel final : public RPCChannel {
public:
explicit CallbackChannel(PackedFunc fsend)
: fsend_(fsend) {}
size_t Send(const void* data, size_t size) final {
TVMByteArray bytes;
bytes.data = static_cast<const char*>(data);
bytes.size = size;
uint64_t ret = fsend_(bytes);
return static_cast<size_t>(ret);
}
size_t Recv(void* data, size_t size) final {
LOG(FATAL) << "Do not allow explicit receive for";
return 0;
}
private:
PackedFunc fsend_;
};
PackedFunc CreateEvenDrivenServer(PackedFunc fsend, std::string name) {
std::unique_ptr<CallbackChannel> ch(new CallbackChannel(fsend));
std::shared_ptr<RPCSession> sess = RPCSession::Create(std::move(ch), name);
return PackedFunc([sess](TVMArgs args, TVMRetValue* rv) {
bool ret = sess->ServerOnMessageHandler(args[0]);
*rv = ret;
});
}
TVM_REGISTER_GLOBAL("contrib.rpc._CreateEventDrivenServer")
.set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = CreateEvenDrivenServer(args[0], args[1]);
});
} // namespace runtime
} // namespace tvm
......@@ -10,8 +10,6 @@
namespace tvm {
namespace runtime {
const int kRPCMagic = 0xff271;
// Wrapped remote function to packed func.
struct RPCWrappedFunc {
public:
......@@ -98,36 +96,12 @@ class RPCModuleNode final : public ModuleNode {
std::shared_ptr<RPCSession> sess_;
};
Module RPCConnect(std::string url, int port) {
common::TCPSocket sock;
common::SockAddr addr(url.c_str(), port);
sock.Create();
CHECK(sock.Connect(addr))
<< "Connect to " << addr.AsString() << " failed";
// hand shake
int code = kRPCMagic;
CHECK_EQ(sock.SendAll(&code, sizeof(code)), sizeof(code));
CHECK_EQ(sock.RecvAll(&code, sizeof(code)), sizeof(code));
if (code != kRPCMagic) {
sock.Close();
LOG(FATAL) << "URL " << url << ":" << port << " is not TVM RPC server";
}
Module CreateRPCModule(std::shared_ptr<RPCSession> sess) {
std::shared_ptr<RPCModuleNode> n =
std::make_shared<RPCModuleNode>(nullptr, RPCSession::Create(sock));
std::make_shared<RPCModuleNode>(nullptr, sess);
return Module(n);
}
void RPCServerLoop(int sockfd) {
common::TCPSocket sock(
static_cast<common::TCPSocket::SockType>(sockfd));
RPCSession::Create(sock)->ServerLoop();
}
TVM_REGISTER_GLOBAL("contrib.rpc._Connect")
.set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = RPCConnect(args[0], args[1]);
});
TVM_REGISTER_GLOBAL("module._RPCTimeEvaluator")
.set_body([](TVMArgs args, TVMRetValue* rv) {
Module m = args[0];
......@@ -163,10 +137,5 @@ TVM_REGISTER_GLOBAL("contrib.rpc._SessTableIndex")
CHECK_EQ(tkey, "rpc");
*rv = static_cast<RPCModuleNode*>(m.operator->())->sess()->table_index();
});
TVM_REGISTER_GLOBAL("contrib.rpc._ServerLoop")
.set_body([](TVMArgs args, TVMRetValue* rv) {
RPCServerLoop(args[0]);
});
} // namespace runtime
} // namespace tvm
/*!
* Copyright (c) 2017 by Contributors
* \file rpc_server_env
* \brief Server environment of the RPC.
*/
#include <tvm/runtime/registry.h>
#include "../file_util.h"
namespace tvm {
namespace runtime {
std::string RPCGetPath(const std::string& name) {
static const PackedFunc* f =
runtime::Registry::Get("tvm.contrib.rpc.server.workpath");
CHECK(f != nullptr) << "require tvm.contrib.rpc.server.workpath";
return (*f)(name);
}
TVM_REGISTER_GLOBAL("tvm.contrib.rpc.server.upload").
set_body([](TVMArgs args, TVMRetValue *rv) {
std::string file_name = RPCGetPath(args[0]);
std::string data = args[1];
LOG(INFO) << "Upload " << file_name << "... nbytes=" << data.length();
SaveBinaryToFile(file_name, data);
});
TVM_REGISTER_GLOBAL("tvm.contrib.rpc.server.download")
.set_body([](TVMArgs args, TVMRetValue *rv) {
std::string file_name = RPCGetPath(args[0]);
std::string data;
LoadBinaryFromFile(file_name, &data);
TVMByteArray arr;
arr.data = data.c_str();
arr.size = data.length();
LOG(INFO) << "Download " << file_name << "... nbytes=" << arr.size;
*rv = arr;
});
TVM_REGISTER_GLOBAL("tvm.contrib.rpc.server.load_module")
.set_body([](TVMArgs args, TVMRetValue *rv) {
std::string file_name = RPCGetPath(args[0]);
*rv = Module::LoadFromFile(file_name, "");
LOG(INFO) << "Load module from " << file_name << " ...";
});
} // namespace runtime
} // namespace tvm
......@@ -8,8 +8,10 @@
#include <tvm/runtime/registry.h>
#include <memory>
#include <array>
#include <string>
#include <chrono>
#include "./rpc_session.h"
#include "../../common/ring_buffer.h"
namespace tvm {
namespace runtime {
......@@ -41,6 +43,564 @@ struct RPCArgBuffer {
}
};
// Event handler for RPC events.
class RPCSession::EventHandler {
public:
EventHandler(common::RingBuffer* reader,
common::RingBuffer* writer,
int rpc_sess_table_index,
std::string name)
: reader_(reader), writer_(writer),
rpc_sess_table_index_(rpc_sess_table_index),
name_(name) {
this->Clear();
}
// Bytes needed to fulfill current request
size_t BytesNeeded() {
if (reader_->bytes_available() < pending_request_bytes_) {
return pending_request_bytes_ - reader_->bytes_available();
} else {
return 0;
}
}
bool CanCleanShutdown() const {
return state_ == kRecvCode;
}
void FinishCopyAck() {
this->SwitchToState(kRecvCode);
}
RPCCode HandleNextEvent(TVMRetValue* rv) {
while (this->Ready()) {
switch (state_) {
case kRecvCode: HandleRecvCode(); break;
case kRecvCallHandle: {
this->Read(&call_handle_, sizeof(call_handle_));
this->SwitchToState(kRecvPackedSeqNumArgs);
break;
}
case kRecvPackedSeqNumArgs: {
this->Read(&num_packed_args_, sizeof(num_packed_args_));
arg_buf_.reset(new RPCArgBuffer());
arg_buf_->value.resize(num_packed_args_);
arg_buf_->tcode.resize(num_packed_args_);
this->SwitchToState(kRecvPackedSeqTypeCode);
break;
}
case kRecvPackedSeqTypeCode: {
if (num_packed_args_ != 0) {
this->Read(arg_buf_->tcode.data(), sizeof(int) * num_packed_args_);
}
arg_index_ = 0;
arg_recv_stage_ = 0;
this->SwitchToState(kRecvPackedSeqArg);
break;
}
case kRecvPackedSeqArg: {
this->HandleRecvPackedSeqArg();
break;
}
case kDoCopyFromRemote: {
this->HandleCopyFromRemote();
break;
}
case kDoCopyToRemote: {
this->HandleCopyToRemote();
break;
}
case kReturnReceived: {
CHECK_EQ(arg_buf_->value.size(), 1U);
TVMArgValue argv = arg_buf_->AsTVMArgs()[0];
*rv = argv;
arg_buf_.reset();
this->SwitchToState(kRecvCode);
return RPCCode::kReturn;
}
case kCopyAckReceived: {
return RPCCode::kCopyAck;
}
case kShutdownReceived: {
return RPCCode::kShutdown;
}
}
}
return RPCCode::kNone;
}
// Reset and clear all states.
void Clear() {
state_ = kRecvCode;
pending_request_bytes_ = sizeof(RPCCode);
arg_recv_stage_ = 0;
arg_buf_.reset();
}
// strip sessionon mask
TVMContext StripSessMask(TVMContext ctx) {
int dev_type = ctx.device_type;
CHECK_EQ(dev_type / kRPCSessMask, rpc_sess_table_index_ + 1)
<< "Can only TVMContext related to the same remote sesstion";
ctx.device_type = static_cast<DLDeviceType>(dev_type % kRPCSessMask);
return ctx;
}
// send Packed sequence to writer.
void SendPackedSeq(const TVMValue* arg_values, const int* type_codes, int n) {
writer_->Write(&n, sizeof(n));
writer_->Write(type_codes, sizeof(int) * n);
// Argument packing.
for (int i = 0; i < n; ++i) {
int tcode = type_codes[i];
TVMValue value = arg_values[i];
switch (tcode) {
case kInt:
case kUInt:
case kFloat:
case kTVMType: {
writer_->Write(&value, sizeof(TVMValue));
break;
}
case kTVMContext: {
value.v_ctx = StripSessMask(value.v_ctx);
writer_->Write(&value, sizeof(TVMValue));
break;
}
case kHandle: {
// always send handle in 64 bit.
uint64_t handle = reinterpret_cast<uint64_t>(value.v_handle);
writer_->Write(&handle, sizeof(uint64_t));
break;
}
case kArrayHandle: {
DLTensor* arr = static_cast<DLTensor*>(value.v_handle);
TVMContext ctx = StripSessMask(arr->ctx);
uint64_t data = reinterpret_cast<uint64_t>(
static_cast<RemoteSpace*>(arr->data)->data);
writer_->Write(&data, sizeof(uint64_t));
writer_->Write(&ctx, sizeof(ctx));
writer_->Write(&(arr->ndim), sizeof(int));
writer_->Write(&(arr->dtype), sizeof(DLDataType));
writer_->Write(arr->shape, sizeof(int64_t) * arr->ndim);
CHECK(arr->strides == nullptr)
<< "Donot support strided remote array";
CHECK_EQ(arr->byte_offset, 0)
<< "Donot support send byte offset";
break;
}
case kNull: break;
case kStr: {
const char* s = value.v_str;
uint64_t len = strlen(s);
writer_->Write(&len, sizeof(len));
writer_->Write(s, sizeof(char) * len);
break;
}
case kBytes: {
TVMByteArray* bytes = static_cast<TVMByteArray*>(arg_values[i].v_handle);
uint64_t len = bytes->size;
writer_->Write(&len, sizeof(len));
writer_->Write(bytes->data, sizeof(char) * len);
break;
}
default: {
LOG(FATAL) << "RPC cannot handle type " << TypeCode2Str(tcode);
break;
}
}
}
}
protected:
enum State {
kRecvCode,
kRecvCallHandle,
kRecvPackedSeqNumArgs,
kRecvPackedSeqTypeCode,
kRecvPackedSeqArg,
kDoCopyFromRemote,
kDoCopyToRemote,
kReturnReceived,
kCopyAckReceived,
kShutdownReceived
};
// Current state;
State state_;
// The RPCCode to be read.
RPCCode code_;
// Handle for the remote function call.
uint64_t call_handle_;
// Number of packed arguments.
int num_packed_args_;
// Current argument index.
int arg_index_;
// The stage of each argument receiver.
int arg_recv_stage_;
// Argument buffer
std::unique_ptr<RPCArgBuffer> arg_buf_;
// Temp byte buffer.
std::unique_ptr<RPCByteArrayBuffer> temp_bytes_;
// Temp array buffer.
std::unique_ptr<RPCDataArrayBuffer> temp_array_;
// Internal temporal data space.
std::string temp_data_;
// Temp variables for copy request state.
TVMContext copy_ctx_;
uint64_t copy_handle_, copy_offset_, copy_size_;
// State switcher
void SwitchToState(State state) {
// invariant
CHECK_EQ(pending_request_bytes_, 0U)
<< "state=" << state;
state_ = state;
switch (state) {
case kRecvCode: {
this->RequestBytes(sizeof(RPCCode));
break;
}
case kRecvCallHandle: {
this->RequestBytes(sizeof(call_handle_));
break;
}
case kRecvPackedSeqNumArgs: {
this->RequestBytes(sizeof(num_packed_args_));
break;
}
case kRecvPackedSeqTypeCode: {
this->RequestBytes(sizeof(int) * num_packed_args_);
break;
}
case kRecvPackedSeqArg: {
CHECK_LE(arg_index_, num_packed_args_);
if (arg_index_ == num_packed_args_) {
// The function can change state_ again.
HandlePackedCall();
} else {
RequestRecvPackedSeqArg();
}
break;
}
case kDoCopyFromRemote: {
this->RequestBytes(sizeof(uint64_t) * 3);
this->RequestBytes(sizeof(TVMContext));
break;
}
case kDoCopyToRemote: {
this->RequestBytes(sizeof(uint64_t) * 3);
this->RequestBytes(sizeof(TVMContext));
break;
}
case kCopyAckReceived:
case kReturnReceived:
case kShutdownReceived: {
break;
}
}
}
// Requets bytes needed for next computation.
void RequestRecvPackedSeqArg() {
CHECK_EQ(arg_recv_stage_, 0);
int tcode = arg_buf_->tcode[arg_index_];
static_assert(sizeof(TVMValue) == sizeof(uint64_t), "invariant");
switch (tcode) {
case kInt:
case kUInt:
case kFloat:
case kTVMType:
case kHandle:
case kStr:
case kBytes:
case kTVMContext: this->RequestBytes(sizeof(TVMValue)); break;
case kNull: break;
case kArrayHandle: {
this->RequestBytes(sizeof(uint64_t));
this->RequestBytes(sizeof(TVMContext));
this->RequestBytes(sizeof(int));
this->RequestBytes(sizeof(DLDataType));
break;
}
default: {
LOG(FATAL) << "RPC cannot handle type " << TypeCode2Str(tcode);
break;
}
}
}
// Handler for packed sequence argument receive.
void HandleRecvPackedSeqArg() {
CHECK_LT(arg_index_, num_packed_args_);
int tcode = arg_buf_->tcode[arg_index_];
TVMValue& value = arg_buf_->value[arg_index_];
if (arg_recv_stage_ == 0) {
switch (tcode) {
case kInt:
case kUInt:
case kFloat:
case kTVMType:
case kTVMContext: {
this->Read(&value, sizeof(TVMValue));
++arg_index_;
this->SwitchToState(kRecvPackedSeqArg);
break;
}
case kHandle: {
// always send handle in 64 bit.
uint64_t handle;
this->Read(&handle, sizeof(handle));
value.v_handle = reinterpret_cast<void*>(handle);
++arg_index_;
this->SwitchToState(kRecvPackedSeqArg);
break;
}
case kNull: {
value.v_handle = nullptr;
++arg_index_;
this->SwitchToState(kRecvPackedSeqArg);
break;
}
case kStr:
case kBytes: {
uint64_t len;
this->Read(&len, sizeof(len));
temp_bytes_.reset( new RPCByteArrayBuffer());
temp_bytes_->data.resize(len);
arg_recv_stage_ = 1;
this->RequestBytes(len);
break;
break;
}
case kArrayHandle: {
temp_array_.reset(new RPCDataArrayBuffer());
uint64_t handle;
this->Read(&handle, sizeof(handle));
DLTensor& tensor = temp_array_->tensor;
tensor.data = reinterpret_cast<void*>(handle);
this->Read(&(tensor.ctx), sizeof(TVMContext));
this->Read(&(tensor.ndim), sizeof(int));
this->Read(&(tensor.dtype), sizeof(DLDataType));
temp_array_->shape.resize(tensor.ndim);
tensor.shape = temp_array_->shape.data();
arg_recv_stage_ = 1;
tensor.strides = nullptr;
tensor.byte_offset = 0;
this->RequestBytes(sizeof(int64_t) * tensor.ndim);
break;
}
default: {
LOG(FATAL) << "RPC cannot handle type " << TypeCode2Str(tcode);
break;
}
}
} else {
CHECK_EQ(arg_recv_stage_, 1);
if (tcode == kStr || tcode == kBytes) {
if (temp_bytes_->data.size() != 0) {
this->Read(&(temp_bytes_->data[0]), temp_bytes_->data.size());
}
if (tcode == kStr) {
value.v_str = temp_bytes_->data.c_str();
} else {
temp_bytes_->arr.size = static_cast<size_t>(temp_bytes_->data.size());
temp_bytes_->arr.data = dmlc::BeginPtr(temp_bytes_->data);
value.v_handle = &(temp_bytes_->arr);
}
arg_buf_->temp_bytes.emplace_back(std::move(temp_bytes_));
} else {
CHECK_EQ(tcode, kArrayHandle);
DLTensor& tensor = temp_array_->tensor;
this->Read(tensor.shape, tensor.ndim * sizeof(int64_t));
value.v_handle = &tensor;
arg_buf_->temp_array.emplace_back(std::move(temp_array_));
}
++arg_index_;
arg_recv_stage_ = 0;
this->SwitchToState(kRecvPackedSeqArg);
}
}
// Handler for read code.
void HandleRecvCode() {
this->Read(&code_, sizeof(code_));
if (code_ > RPCCode::kSystemFuncStart) {
SwitchToState(kRecvPackedSeqNumArgs);
return;
}
// invariant.
CHECK_EQ(arg_recv_stage_, 0);
switch (code_) {
case RPCCode::kCallFunc: {
SwitchToState(kRecvCallHandle);
break;
}
case RPCCode::kException:
case RPCCode::kReturn: {
SwitchToState(kRecvPackedSeqNumArgs);
break;
}
case RPCCode::kCopyFromRemote: {
SwitchToState(kDoCopyFromRemote);
break;
}
case RPCCode::kCopyToRemote: {
SwitchToState(kDoCopyToRemote);
break;
}
case RPCCode::kShutdown: {
SwitchToState(kShutdownReceived);
break;
}
case RPCCode::kCopyAck: {
SwitchToState(kCopyAckReceived);
break;
}
default: LOG(FATAL) << "Unknown event " << static_cast<int>(code_);
}
}
void HandleCopyFromRemote() {
uint64_t handle, offset, size;
TVMContext ctx;
this->Read(&handle, sizeof(handle));
this->Read(&offset, sizeof(offset));
this->Read(&size, sizeof(size));
this->Read(&ctx, sizeof(ctx));
if (ctx.device_type == kCPU) {
RPCCode code = RPCCode::kCopyAck;
writer_->Write(&code, sizeof(code));
writer_->Write(reinterpret_cast<char*>(handle) + offset, size);
} else {
temp_data_.resize(size + 1);
try {
TVMContext cpu_ctx;
cpu_ctx.device_type = kCPU;
cpu_ctx.device_id = 0;
DeviceAPI::Get(ctx)->CopyDataFromTo(
reinterpret_cast<void*>(handle), offset,
dmlc::BeginPtr(temp_data_), 0,
size, ctx, cpu_ctx, nullptr);
RPCCode code = RPCCode::kCopyAck;
writer_->Write(&code, sizeof(code));
writer_->Write(&temp_data_[0], size);
} catch (const std::runtime_error &e) {
RPCCode code = RPCCode::kException;
writer_->Write(&code, sizeof(code));
TVMValue ret_value;
ret_value.v_str = e.what();
int ret_tcode = kStr;
SendPackedSeq(&ret_value, &ret_tcode, 1);
}
}
this->SwitchToState(kRecvCode);
}
void HandleCopyToRemote() {
// use static variable to persist state.
// This only works if next stage is immediately after this.
if (arg_recv_stage_ == 0) {
this->Read(&copy_handle_, sizeof(uint64_t));
this->Read(&copy_offset_, sizeof(uint64_t));
this->Read(&copy_size_, sizeof(uint64_t));
this->Read(&copy_ctx_, sizeof(TVMContext));
arg_recv_stage_ = 1;
CHECK_EQ(pending_request_bytes_, 0U);
this->RequestBytes(copy_size_);
} else {
CHECK_EQ(arg_recv_stage_, 1);
TVMValue ret_value;
ret_value.v_handle = nullptr;
int ret_tcode = kNull;
RPCCode code = RPCCode::kReturn;
std::string errmsg;
if (copy_ctx_.device_type == kCPU) {
this->Read(
reinterpret_cast<char*>(copy_handle_) + copy_offset_, copy_size_);
} else {
temp_data_.resize(copy_size_ + 1);
this->Read(&temp_data_[0], copy_size_);
try {
TVMContext cpu_ctx;
cpu_ctx.device_type = kCPU;
cpu_ctx.device_id = 0;
DeviceAPI::Get(copy_ctx_)->CopyDataFromTo(
temp_data_.data(), 0,
reinterpret_cast<void*>(copy_handle_), copy_offset_,
copy_size_, cpu_ctx, copy_ctx_, nullptr);
} catch (const std::runtime_error &e) {
code = RPCCode::kException;
errmsg = e.what();
ret_value.v_str = errmsg.c_str();
ret_tcode = kStr;
}
}
writer_->Write(&code, sizeof(code));
SendPackedSeq(&ret_value, &ret_tcode, 1);
arg_recv_stage_ = 0;
this->SwitchToState(kRecvCode);
}
}
// Handle for packed call.
void HandlePackedCall();
template<typename F>
void CallHandler(F f) {
TVMRetValue rv;
TVMValue ret_value;
int ret_tcode;
try {
// Need to move out, in case f itself need to call RecvPackedSeq
// Which will override argbuf again.
std::unique_ptr<RPCArgBuffer> args = std::move(arg_buf_);
f(args->AsTVMArgs(), &rv);
RPCCode code = RPCCode::kReturn;
writer_->Write(&code, sizeof(code));
if (rv.type_code() == kStr) {
ret_value.v_str = rv.ptr<std::string>()->c_str();
ret_tcode = kStr;
SendPackedSeq(&ret_value, &ret_tcode, 1);
} else if (rv.type_code() == kBytes) {
std::string* bytes = rv.ptr<std::string>();
TVMByteArray arr;
arr.data = bytes->c_str();
arr.size = bytes->length();
ret_value.v_handle = &arr;
ret_tcode = kBytes;
SendPackedSeq(&ret_value, &ret_tcode, 1);
} else {
ret_value = rv.value();
ret_tcode = rv.type_code();
SendPackedSeq(&ret_value, &ret_tcode, 1);
}
} catch (const std::runtime_error& e) {
RPCCode code = RPCCode::kException;
writer_->Write(&code, sizeof(code));
ret_value.v_str = e.what();
ret_tcode = kStr;
SendPackedSeq(&ret_value, &ret_tcode, 1);
}
}
private:
// Utility functions
// Internal read function, update pending_request_bytes_
void Read(void* data, size_t size) {
CHECK_LE(size, pending_request_bytes_);
reader_->Read(data, size);
pending_request_bytes_ -= size;
}
// Request number of bytes from reader.
void RequestBytes(size_t nbytes) {
pending_request_bytes_ += nbytes;
reader_->Reserve(pending_request_bytes_);
}
// Whether we are ready to handle next request.
bool Ready() {
return reader_->bytes_available() >= pending_request_bytes_;
}
// Number of pending bytes requests
size_t pending_request_bytes_;
// The ring buffer to read data from.
common::RingBuffer* reader_;
// The ringr buffer to write reply to.
common::RingBuffer* writer_;
// Session table index.
int rpc_sess_table_index_;
// Name of session.
std::string name_;
};
struct RPCSessTable {
public:
static constexpr int kMaxRPCSession = 32;
......@@ -74,22 +634,52 @@ struct RPCSessTable {
std::array<std::weak_ptr<RPCSession>, kMaxRPCSession> tbl_;
};
RPCCode RPCSession::HandleUntilReturnEvent(TVMRetValue* rv) {
RPCCode code = RPCCode::kCallFunc;
while (code != RPCCode::kReturn &&
code != RPCCode::kShutdown &&
code != RPCCode::kCopyAck) {
while (writer_.bytes_available() != 0) {
writer_.ReadWithCallback([this](const void *data, size_t size) {
return channel_->Send(data, size);
}, writer_.bytes_available());
}
size_t bytes_needed = handler_->BytesNeeded();
if (bytes_needed != 0) {
size_t n = reader_.WriteWithCallback([this](void* data, size_t size) {
return channel_->Recv(data, size);
}, bytes_needed);
if (n == 0) {
if (handler_->CanCleanShutdown()) {
return RPCCode::kShutdown;
} else {
LOG(FATAL) << "Channel closes before we get neded bytes";
}
}
}
code = handler_->HandleNextEvent(rv);
}
return code;
}
void RPCSession::Init() {
// Event handler
handler_ = std::make_shared<EventHandler>(&reader_, &writer_, table_index_, name_);
// Quick function to call remote.
call_remote_ = PackedFunc([this](TVMArgs args, TVMRetValue* rv) {
this->SendPackedSeq(args.values, args.type_codes, args.num_args);
RPCCode code = RPCCode::kCallFunc;
while (code != RPCCode::kReturn) {
code = HandleNextEvent(rv);
}
handler_->SendPackedSeq(args.values, args.type_codes, args.num_args);
RPCCode code = HandleUntilReturnEvent(rv);
CHECK(code == RPCCode::kReturn) << "code=" << static_cast<int>(code);
});
}
std::shared_ptr<RPCSession> RPCSession::Create(common::TCPSocket sock) {
std::shared_ptr<RPCSession> RPCSession::Create(
std::unique_ptr<RPCChannel> channel, std::string name) {
std::shared_ptr<RPCSession> sess = std::make_shared<RPCSession>();
sess->sock_ = sock;
sess->channel_ = std::move(channel);
sess->name_ = std::move(name);
sess->table_index_ = RPCSessTable::Global()->Insert(sess);
sess->Init();
sess->table_index_ = RPCSessTable::Global()->Insert(sess);
return sess;
}
......@@ -102,33 +692,53 @@ RPCSession::~RPCSession() {
}
void RPCSession::Shutdown() {
if (!sock_.BadSocket()) {
if (channel_ != nullptr) {
RPCCode code = RPCCode::kShutdown;
CHECK_EQ(sock_.SendAll(&code, sizeof(code)), sizeof(code));
sock_.Close();
writer_.Write(&code, sizeof(code));
// flush all writing buffer to output channel.
try {
while (writer_.bytes_available() != 0) {
size_t n = writer_.ReadWithCallback([this](const void *data, size_t size) {
return channel_->Send(data, size);
}, writer_.bytes_available());
if (n == 0) break;
}
} catch (const dmlc::Error& e) {
}
channel_.reset(nullptr);
}
}
void RPCSession::ServerLoop() {
std::lock_guard<std::recursive_mutex> lock(mutex_);
RPCCode code = RPCCode::kCallFunc;
TVMRetValue rv;
while (code != RPCCode::kShutdown) {
code = HandleNextEvent(&rv);
CHECK(code != RPCCode::kReturn);
}
if (!sock_.BadSocket()) {
sock_.Close();
CHECK(HandleUntilReturnEvent(&rv) == RPCCode::kShutdown);
LOG(INFO) << "Shutdown...";
channel_.reset(nullptr);
}
bool RPCSession::ServerOnMessageHandler(const std::string& bytes) {
std::lock_guard<std::recursive_mutex> lock(mutex_);
reader_.Write(bytes.c_str(), bytes.length());
TVMRetValue rv;
RPCCode code = handler_->HandleNextEvent(&rv);
while (writer_.bytes_available() != 0) {
writer_.ReadWithCallback([this](const void *data, size_t size) {
return channel_->Send(data, size);
}, writer_.bytes_available());
}
CHECK(code != RPCCode::kReturn && code != RPCCode::kCopyAck);
return code != RPCCode::kShutdown;
}
// Get remote function with name
void RPCSession::CallFunc(void* h, TVMArgs args, TVMRetValue* rv) {
std::lock_guard<std::recursive_mutex> lock(mutex_);
RPCCode code = RPCCode::kCallFunc;
CHECK_EQ(sock_.SendAll(&code, sizeof(code)), sizeof(code));
writer_.Write(&code, sizeof(code));
uint64_t handle = reinterpret_cast<uint64_t>(h);
CHECK_EQ(sock_.SendAll(&handle, sizeof(handle)), sizeof(handle));
writer_.Write(&handle, sizeof(handle));
call_remote_.CallPacked(args, rv);
}
......@@ -139,22 +749,19 @@ void RPCSession::CopyToRemote(void* from,
size_t data_size,
TVMContext ctx_to) {
std::lock_guard<std::recursive_mutex> lock(mutex_);
ctx_to = StripSessMask(ctx_to);
ctx_to = handler_->StripSessMask(ctx_to);
RPCCode code = RPCCode::kCopyToRemote;
CHECK_EQ(sock_.SendAll(&code, sizeof(code)), sizeof(code));
writer_.Write(&code, sizeof(code));
uint64_t handle = reinterpret_cast<uint64_t>(to);
CHECK_EQ(sock_.SendAll(&handle, sizeof(handle)), sizeof(handle));
writer_.Write(&handle, sizeof(handle));
uint64_t offset = static_cast<uint64_t>(to_offset);
CHECK_EQ(sock_.SendAll(&offset, sizeof(offset)), sizeof(offset));
writer_.Write(&offset, sizeof(offset));
uint64_t size = static_cast<uint64_t>(data_size);
CHECK_EQ(sock_.SendAll(&size, sizeof(size)), sizeof(size));
CHECK_EQ(sock_.SendAll(&ctx_to, sizeof(ctx_to)), sizeof(ctx_to));
CHECK_EQ(sock_.SendAll(reinterpret_cast<char*>(from) + from_offset, data_size),
data_size);
writer_.Write(&size, sizeof(size));
writer_.Write(&ctx_to, sizeof(ctx_to));
writer_.Write(reinterpret_cast<char*>(from) + from_offset, data_size);
TVMRetValue rv;
while (code != RPCCode::kReturn) {
code = HandleNextEvent(&rv);
}
CHECK(HandleUntilReturnEvent(&rv) == RPCCode::kReturn);
}
void RPCSession::CopyFromRemote(void* from,
......@@ -164,23 +771,29 @@ void RPCSession::CopyFromRemote(void* from,
size_t data_size,
TVMContext ctx_from) {
std::lock_guard<std::recursive_mutex> lock(mutex_);
ctx_from = StripSessMask(ctx_from);
ctx_from = handler_->StripSessMask(ctx_from);
RPCCode code = RPCCode::kCopyFromRemote;
CHECK_EQ(sock_.SendAll(&code, sizeof(code)), sizeof(code));
writer_.Write(&code, sizeof(code));
uint64_t handle = reinterpret_cast<uint64_t>(from);
CHECK_EQ(sock_.SendAll(&handle, sizeof(handle)), sizeof(handle));
writer_.Write(&handle, sizeof(handle));
uint64_t offset = static_cast<uint64_t>(from_offset);
CHECK_EQ(sock_.SendAll(&offset, sizeof(offset)), sizeof(offset));
writer_.Write(&offset, sizeof(offset));
uint64_t size = static_cast<uint64_t>(data_size);
CHECK_EQ(sock_.SendAll(&size, sizeof(size)), sizeof(size));
CHECK_EQ(sock_.SendAll(&ctx_from, sizeof(ctx_from)), sizeof(ctx_from));
CHECK_EQ(sock_.RecvAll(&code, sizeof(code)), sizeof(code));
if (code == RPCCode::kCopyAck) {
CHECK_EQ(sock_.RecvAll(reinterpret_cast<char*>(to) + to_offset, data_size),
data_size);
} else {
HandleException();
writer_.Write(&size, sizeof(size));
writer_.Write(&ctx_from, sizeof(ctx_from));
TVMRetValue rv;
CHECK(HandleUntilReturnEvent(&rv) == RPCCode::kCopyAck);
reader_.Reserve(data_size);
while (reader_.bytes_available() < data_size) {
size_t bytes_needed = data_size - reader_.bytes_available();
reader_.WriteWithCallback([this](void* data, size_t size) {
size_t n = channel_->Recv(data, size);
CHECK_NE(n, 0U) << "Channel closes before we get neded bytes";
return n;
}, bytes_needed);
}
reader_.Read(reinterpret_cast<char*>(to) + to_offset, data_size);
handler_->FinishCopyAck();
}
RPCFuncHandle RPCSession::GetTimeEvaluator(
......@@ -188,308 +801,6 @@ RPCFuncHandle RPCSession::GetTimeEvaluator(
return this->CallRemote(RPCCode::kGetTimeEvaluator, fhandle, ctx, nstep);
}
void RPCSession::SendReturnValue(
int succ, TVMValue ret_value, int ret_tcode) {
if (succ == 0) {
RPCCode code = RPCCode::kReturn;
CHECK_EQ(sock_.SendAll(&code, sizeof(code)), sizeof(code));
} else {
RPCCode code = RPCCode::kException;
CHECK_EQ(sock_.SendAll(&code, sizeof(code)), sizeof(code));
ret_value.v_str = TVMGetLastError();
ret_tcode = kStr;
}
SendPackedSeq(&ret_value, &ret_tcode, 1);
}
template<typename F>
void RPCSession::CallHandler(F f) {
RPCArgBuffer args;
this->RecvPackedSeq(&args);
TVMRetValue rv;
TVMValue ret_value;
int ret_tcode;
try {
f(TVMArgs(args.value.data(), args.tcode.data(),
static_cast<int>(args.value.size())), &rv);
RPCCode code = RPCCode::kReturn;
CHECK_EQ(sock_.SendAll(&code, sizeof(code)), sizeof(code));
if (rv.type_code() == kStr) {
std::string str = rv;
ret_value.v_str = str.c_str();
ret_tcode = kStr;
SendPackedSeq(&ret_value, &ret_tcode, 1);
} else {
ret_value = rv.value();
ret_tcode = rv.type_code();
SendPackedSeq(&ret_value, &ret_tcode, 1);
}
} catch (const std::runtime_error& e) {
RPCCode code = RPCCode::kException;
CHECK_EQ(sock_.SendAll(&code, sizeof(code)), sizeof(code));
ret_value.v_str = e.what();
ret_tcode = kStr;
SendPackedSeq(&ret_value, &ret_tcode, 1);
}
}
void RPCSession::HandleCallFunc() {
uint64_t handle;
CHECK_EQ(sock_.RecvAll(&handle, sizeof(handle)), sizeof(handle));
PackedFunc* pf = reinterpret_cast<PackedFunc*>(handle);
CallHandler([pf](TVMArgs args, TVMRetValue *rv) {
pf->CallPacked(args, rv);
});
}
void RPCSession::HandleException() {
RPCArgBuffer ret;
this->RecvPackedSeq(&ret);
CHECK_EQ(ret.value.size(), 1U);
CHECK_EQ(ret.tcode[0], kStr);
std::ostringstream os;
os << "Except caught from RPC call: " << ret.value[0].v_str;
throw dmlc::Error(os.str());
}
void RPCSession::HandleCopyToRemote() {
uint64_t handle, offset, size;
TVMContext ctx;
CHECK_EQ(sock_.RecvAll(&handle, sizeof(handle)), sizeof(handle));
CHECK_EQ(sock_.RecvAll(&offset, sizeof(offset)), sizeof(offset));
CHECK_EQ(sock_.RecvAll(&size, sizeof(size)), sizeof(size));
CHECK_EQ(sock_.RecvAll(&ctx, sizeof(ctx)), sizeof(ctx));
int succ = 0;
if (ctx.device_type == kCPU) {
CHECK_EQ(sock_.RecvAll(reinterpret_cast<char*>(handle) + offset, size),
static_cast<size_t>(size));
} else {
temp_data_.resize(size+1);
CHECK_EQ(sock_.RecvAll(&temp_data_[0], size),
static_cast<size_t>(size));
try {
TVMContext cpu_ctx;
cpu_ctx.device_type = kCPU;
cpu_ctx.device_id = 0;
DeviceAPI::Get(ctx)->CopyDataFromTo(
temp_data_.data(), 0,
reinterpret_cast<void*>(handle), offset,
size, cpu_ctx, ctx, nullptr);
} catch (const std::runtime_error &e) {
TVMAPISetLastError(e.what());
succ = -1;
}
}
TVMValue ret_value;
ret_value.v_handle = nullptr;
int ret_tcode = kNull;
SendReturnValue(succ, ret_value, ret_tcode);
}
void RPCSession::HandleCopyFromRemote() {
uint64_t handle, offset, size;
TVMContext ctx;
CHECK_EQ(sock_.RecvAll(&handle, sizeof(handle)), sizeof(handle));
CHECK_EQ(sock_.RecvAll(&offset, sizeof(offset)), sizeof(offset));
CHECK_EQ(sock_.RecvAll(&size, sizeof(size)), sizeof(size));
CHECK_EQ(sock_.RecvAll(&ctx, sizeof(ctx)), sizeof(ctx));
if (ctx.device_type == kCPU) {
RPCCode code = RPCCode::kCopyAck;
CHECK_EQ(sock_.SendAll(&code, sizeof(code)), sizeof(code));
CHECK_EQ(sock_.SendAll(reinterpret_cast<char*>(handle) + offset, size),
static_cast<size_t>(size));
} else {
temp_data_.resize(size + 1);
try {
TVMContext cpu_ctx;
cpu_ctx.device_type = kCPU;
cpu_ctx.device_id = 0;
DeviceAPI::Get(ctx)->CopyDataFromTo(
reinterpret_cast<void*>(handle), offset,
dmlc::BeginPtr(temp_data_), 0,
size, ctx, cpu_ctx, nullptr);
RPCCode code = RPCCode::kCopyAck;
CHECK_EQ(sock_.SendAll(&code, sizeof(code)), sizeof(code));
CHECK_EQ(sock_.SendAll(&temp_data_[0], size),
static_cast<size_t>(size));
} catch (const std::runtime_error &e) {
RPCCode code = RPCCode::kException;
CHECK_EQ(sock_.SendAll(&code, sizeof(code)), sizeof(code));
TVMValue ret_value;
ret_value.v_str = e.what();
int ret_tcode = kStr;
SendPackedSeq(&ret_value, &ret_tcode, 1);
}
}
}
void RPCSession::HandleReturn(TVMRetValue* rv) {
RPCArgBuffer ret;
this->RecvPackedSeq(&ret);
CHECK_EQ(ret.value.size(), 1U);
TVMArgValue argv = ret.AsTVMArgs()[0];
*rv = argv;
}
TVMContext RPCSession::StripSessMask(TVMContext ctx) {
int dev_type = ctx.device_type;
CHECK_EQ(dev_type / kRPCSessMask, table_index_ + 1)
<< "Can only TVMContext related to the same remote sesstion";
ctx.device_type = static_cast<DLDeviceType>(dev_type % kRPCSessMask);
return ctx;
}
// packed Send sequence to the channel
void RPCSession::SendPackedSeq(
const TVMValue* arg_values, const int* type_codes, int n) {
CHECK_EQ(sock_.SendAll(&n, sizeof(n)), sizeof(n));
CHECK_EQ(sock_.SendAll(type_codes, sizeof(int) * n), sizeof(int) * n);
// Argument packing.
for (int i = 0; i < n; ++i) {
int tcode = type_codes[i];
TVMValue value = arg_values[i];
switch (tcode) {
case kInt:
case kUInt:
case kFloat:
case kTVMType: {
CHECK_EQ(sock_.SendAll(&value, sizeof(TVMValue)), sizeof(TVMValue));
break;
}
case kTVMContext: {
value.v_ctx = StripSessMask(value.v_ctx);
CHECK_EQ(sock_.SendAll(&value, sizeof(TVMValue)), sizeof(TVMValue));
break;
}
case kHandle: {
// always send handle in 64 bit.
uint64_t handle = reinterpret_cast<uint64_t>(value.v_handle);
CHECK_EQ(sock_.SendAll(&handle, sizeof(uint64_t)), sizeof(uint64_t));
break;
}
case kArrayHandle: {
DLTensor* arr = static_cast<DLTensor*>(value.v_handle);
TVMContext ctx = StripSessMask(arr->ctx);
uint64_t data = reinterpret_cast<uint64_t>(
static_cast<RemoteSpace*>(arr->data)->data);
CHECK_EQ(sock_.SendAll(&data, sizeof(uint64_t)), sizeof(uint64_t));
CHECK_EQ(sock_.SendAll(&ctx, sizeof(ctx)), sizeof(ctx));
CHECK_EQ(sock_.SendAll(&(arr->ndim), sizeof(int)), sizeof(int));
CHECK_EQ(sock_.SendAll(&(arr->dtype), sizeof(DLDataType)), sizeof(DLDataType));
CHECK_EQ(sock_.SendAll(arr->shape, sizeof(int64_t) * arr->ndim),
sizeof(int64_t) * arr->ndim);
CHECK(arr->strides == nullptr)
<< "Donot support strided remote array";
CHECK_EQ(arr->byte_offset, 0)
<< "Donot support send byte offset";
break;
}
case kNull: break;
case kStr: {
const char* s = value.v_str;
uint64_t len = strlen(s);
CHECK_EQ(sock_.SendAll(&len, sizeof(len)), sizeof(len));
CHECK_EQ(sock_.SendAll(s, sizeof(char) * len), sizeof(char) * len);
break;
}
case kBytes: {
TVMByteArray* bytes = static_cast<TVMByteArray*>(arg_values[i].v_handle);
uint64_t len = bytes->size;
CHECK_EQ(sock_.SendAll(&len, sizeof(len)), sizeof(len));
CHECK_EQ(sock_.SendAll(bytes->data, sizeof(char) * len), sizeof(char) * len);
break;
}
default: {
LOG(FATAL) << "RPC cannot handle type " << TypeCode2Str(tcode);
break;
}
}
}
}
// Receive packed sequence from the channel
void RPCSession::RecvPackedSeq(RPCArgBuffer *buf) {
int n;
CHECK_EQ(sock_.RecvAll(&n, sizeof(n)), sizeof(n));
buf->value.resize(n);
buf->tcode.resize(n);
buf->temp_bytes.clear();
if (n != 0) {
buf->tcode.resize(n);
CHECK_EQ(sock_.RecvAll(buf->tcode.data(), sizeof(int) * n),
sizeof(int) * n);
}
buf->value.resize(n);
for (int i = 0; i < n; ++i) {
int tcode = buf->tcode[i];
TVMValue& value = buf->value[i];
switch (tcode) {
case kInt:
case kUInt:
case kFloat:
case kTVMType:
case kTVMContext: {
CHECK_EQ(sock_.RecvAll(&value, sizeof(TVMValue)), sizeof(TVMValue));
break;
}
case kHandle: {
// always send handle in 64 bit.
uint64_t handle;
CHECK_EQ(sock_.RecvAll(&handle, sizeof(uint64_t)), sizeof(uint64_t));
value.v_handle = reinterpret_cast<void*>(handle);
break;
}
case kNull: {
value.v_handle = nullptr;
break;
}
case kStr:
case kBytes: {
uint64_t len;
CHECK_EQ(sock_.RecvAll(&len, sizeof(len)), sizeof(len));
std::unique_ptr<RPCByteArrayBuffer> temp(new RPCByteArrayBuffer());
temp->data.resize(len);
if (len != 0) {
CHECK_EQ(sock_.RecvAll(&(temp->data[0]), sizeof(char) * len),
sizeof(char) * len);
}
if (tcode == kStr) {
value.v_str = temp->data.c_str();
} else {
temp->arr.size = static_cast<size_t>(len);
temp->arr.data = dmlc::BeginPtr(temp->data);
value.v_handle = &(temp->arr);
}
buf->temp_bytes.emplace_back(std::move(temp));
break;
}
case kArrayHandle: {
std::unique_ptr<RPCDataArrayBuffer> temp(new RPCDataArrayBuffer());
uint64_t handle;
CHECK_EQ(sock_.RecvAll(&handle, sizeof(handle)), sizeof(handle));
DLTensor& tensor = temp->tensor;
tensor.data = reinterpret_cast<void*>(handle);
CHECK_EQ(sock_.RecvAll(&(tensor.ctx), sizeof(TVMContext)), sizeof(TVMContext));
CHECK_EQ(sock_.RecvAll(&(tensor.ndim), sizeof(int)), sizeof(int));
CHECK_EQ(sock_.RecvAll(&(tensor.dtype), sizeof(DLDataType)), sizeof(DLDataType));
temp->shape.resize(tensor.ndim);
tensor.shape = temp->shape.data();
CHECK_EQ(sock_.RecvAll(tensor.shape, tensor.ndim * sizeof(int64_t)),
tensor.ndim * sizeof(int64_t));
tensor.strides = nullptr;
tensor.byte_offset = 0;
value.v_handle = &tensor;
buf->temp_array.emplace_back(std::move(temp));
break;
}
default: {
LOG(FATAL) << "RPC cannot handle type " << TypeCode2Str(tcode);
break;
}
}
}
}
// Event handler functions
void RPCGetGlobalFunc(TVMArgs args, TVMRetValue* rv) {
std::string name = args[0];
......@@ -607,16 +918,32 @@ void RPCGetTimeEvaluator(TVMArgs args, TVMRetValue *rv) {
*rv = fhandle;
}
RPCCode RPCSession::HandleNextEvent(TVMRetValue *rv) {
RPCCode code;
CHECK_EQ(sock_.RecvAll(&code, sizeof(int)), sizeof(int));
switch (code) {
case RPCCode::kCallFunc: HandleCallFunc(); break;
case RPCCode::kReturn: HandleReturn(rv); break;
case RPCCode::kException: HandleException(); break;
case RPCCode::kCopyFromRemote: HandleCopyFromRemote(); break;
case RPCCode::kCopyToRemote: HandleCopyToRemote(); break;
case RPCCode::kShutdown: break;
void RPCSession::EventHandler::HandlePackedCall() {
CHECK_EQ(pending_request_bytes_, 0U);
if (code_ == RPCCode::kReturn) {
state_ = kReturnReceived; return;
}
// reset state to clean init state
state_ = kRecvCode;
this->RequestBytes(sizeof(RPCCode));
// Event handler sit at clean state at this point.
switch (code_) {
case RPCCode::kCallFunc: {
PackedFunc* pf = reinterpret_cast<PackedFunc*>(call_handle_);
CallHandler([pf](TVMArgs args, TVMRetValue* rv) {
pf->CallPacked(args, rv);
});
break;
}
case RPCCode::kException: {
CHECK_EQ(arg_buf_->value.size(), 1U);
CHECK_EQ(arg_buf_->tcode[0], kStr);
std::ostringstream os;
os << "Except caught from RPC call: " << arg_buf_->value[0].v_str;
arg_buf_.reset();
throw dmlc::Error(os.str());
break;
}
// system functions
case RPCCode::kGetTimeEvaluator: CallHandler(RPCGetTimeEvaluator); break;
case RPCCode::kFreeFunc: CallHandler(RPCFreeFunc); break;
......@@ -631,9 +958,9 @@ RPCCode RPCSession::HandleNextEvent(TVMRetValue *rv) {
case RPCCode::kModuleFree: CallHandler(RPCModuleFree); break;
case RPCCode::kModuleGetFunc: CallHandler(RPCModuleGetFunc); break;
case RPCCode::kModuleGetSource: CallHandler(RPCModuleGetSource); break;
default: LOG(FATAL) << "Unknown event " << static_cast<int>(code);
default: LOG(FATAL) << "Unknown event " << static_cast<int>(code_);
}
return code;
CHECK_EQ(state_, kRecvCode);
}
PackedFunc WrapTimeEvaluator(PackedFunc pf, TVMContext ctx, int nstep) {
......@@ -656,5 +983,6 @@ PackedFunc WrapTimeEvaluator(PackedFunc pf, TVMContext ctx, int nstep) {
};
return PackedFunc(ftimer);
}
} // namespace runtime
} // namespace tvm
......@@ -10,11 +10,13 @@
#include <tvm/runtime/device_api.h>
#include <mutex>
#include <string>
#include "../../common/socket.h"
#include "../../common/ring_buffer.h"
namespace tvm {
namespace runtime {
const int kRPCMagic = 0xff271;
/*! \brief The remote functio handle */
using RPCFuncHandle = void*;
......@@ -22,6 +24,7 @@ struct RPCArgBuffer;
/*! \brief The RPC code */
enum class RPCCode : int {
kNone,
kCallFunc,
kReturn,
kException,
......@@ -30,6 +33,7 @@ enum class RPCCode : int {
kCopyToRemote,
kCopyAck,
// The following are code that can send over CallRemote
kSystemFuncStart,
kGetGlobalFunc,
kGetTimeEvaluator,
kFreeFunc,
......@@ -45,6 +49,30 @@ enum class RPCCode : int {
kModuleGetSource
};
/*!
* \brief Abstract channel interface used to create RPCSession.
*/
class RPCChannel {
public:
/*! \brief virtual destructor */
virtual ~RPCChannel() {}
/*!
* \brief Send data over to the channel.
* \param data The data pointer.
* \param size The size fo the data.
* \return The actual bytes sent.
*/
virtual size_t Send(const void* data, size_t size) = 0;
/*!
e * \brief Recv data from channel.
*
* \param data The data pointer.
* \param size The size fo the data.
* \return The actual bytes received.
*/
virtual size_t Recv(void* data, size_t size) = 0;
};
// Bidirectional Communication Session of PackedRPC
class RPCSession {
public:
......@@ -55,6 +83,17 @@ class RPCSession {
*/
void ServerLoop();
/*!
* \brief Message handling function for event driven server.
* Called when the server receives a message.
* Event driven handler will never call recv on the channel
* and always relies on the ServerOnMessageHandler
* to receive the data.
*
* \param bytes The incoming bytes.
* \return Whether need continue running, return false when receive a shutdown message.
*/
bool ServerOnMessageHandler(const std::string& bytes);
/*!
* \brief Call into remote function
* \param handle The function handle
* \param args The arguments
......@@ -121,10 +160,13 @@ class RPCSession {
}
/*!
* \brief Create a RPC session with given socket
* \param sock The socket.
* \param channel The communication channel.
* \param name The name of the session, used for debug
* \return The session.
*/
static std::shared_ptr<RPCSession> Create(common::TCPSocket sock);
static std::shared_ptr<RPCSession> Create(
std::unique_ptr<RPCChannel> channel,
std::string name);
/*!
* \brief Try get session from the global session table by table index.
* \param table_index The table index of the session.
......@@ -133,36 +175,28 @@ class RPCSession {
static std::shared_ptr<RPCSession> Get(int table_index);
private:
/*!
* \brief Handle the remote call with f
* \param f The handle function
* \tparam F the handler function.
*/
template<typename F>
void CallHandler(F f);
class EventHandler;
// Handle events until receives a return
// Also flushes channels so that the function advances.
RPCCode HandleUntilReturnEvent(TVMRetValue* rv);
// Initalization
void Init();
// Shutdown
void Shutdown();
void SendReturnValue(int succ, TVMValue value, int tcode);
void SendPackedSeq(const TVMValue* arg_values, const int* type_codes, int n);
void RecvPackedSeq(RPCArgBuffer *buf);
RPCCode HandleNextEvent(TVMRetValue *rv);
TVMContext StripSessMask(TVMContext ctx);
// special handler.
void HandleCallFunc();
void HandleException();
void HandleCopyFromRemote();
void HandleCopyToRemote();
void HandleReturn(TVMRetValue* rv);
// Internal channel.
std::unique_ptr<RPCChannel> channel_;
// Internal mutex
std::recursive_mutex mutex_;
// Internal socket
common::TCPSocket sock_;
// Internal temporal data space.
std::string temp_data_;
// Internal ring buffer.
common::RingBuffer reader_, writer_;
// Event handler.
std::shared_ptr<EventHandler> handler_;
// call remote with the specified function coede.
PackedFunc call_remote_;
// The index of this session in RPC session table.
int table_index_{0};
// The name of the session.
std::string name_;
};
/*!
......@@ -173,6 +207,13 @@ class RPCSession {
*/
PackedFunc WrapTimeEvaluator(PackedFunc f, TVMContext ctx, int nstep);
/*!
* \brief Create a Global RPC module that refers to the session.
* \param sess The RPC session of the global module.
* \return The created module.
*/
Module CreateRPCModule(std::shared_ptr<RPCSession> sess);
// Remote space pointer.
struct RemoteSpace {
void* data;
......@@ -183,7 +224,7 @@ struct RemoteSpace {
template<typename... Args>
inline TVMRetValue RPCSession::CallRemote(RPCCode code, Args&& ...args) {
std::lock_guard<std::recursive_mutex> lock(mutex_);
CHECK_EQ(sock_.SendAll(&code, sizeof(code)), sizeof(code));
writer_.Write(&code, sizeof(code));
return call_remote_(std::forward<Args>(args)...);
}
} // namespace runtime
......
/*!
* Copyright (c) 2017 by Contributors
* \file rpc_socket_impl.cc
* \brief Socket based RPC implementation.
*/
#include <tvm/runtime/registry.h>
#include <memory>
#include "./rpc_session.h"
#include "../../common/socket.h"
namespace tvm {
namespace runtime {
class SockChannel final : public RPCChannel {
public:
explicit SockChannel(common::TCPSocket sock)
: sock_(sock) {}
~SockChannel() {
if (!sock_.BadSocket()) {
sock_.Close();
}
}
size_t Send(const void* data, size_t size) final {
ssize_t n = sock_.Send(data, size);
if (n == -1) {
common::Socket::Error("SockChannel::Send");
}
return static_cast<size_t>(n);
}
size_t Recv(void* data, size_t size) final {
ssize_t n = sock_.Recv(data, size);
if (n == -1) {
common::Socket::Error("SockChannel::Recv");
}
return static_cast<size_t>(n);
}
private:
common::TCPSocket sock_;
};
Module RPCConnect(std::string url, int port, std::string key) {
common::TCPSocket sock;
common::SockAddr addr(url.c_str(), port);
sock.Create();
CHECK(sock.Connect(addr))
<< "Connect to " << addr.AsString() << " failed";
// hand shake
std::ostringstream os;
os << "client:" << key;
key = os.str();
int code = kRPCMagic;
int keylen = static_cast<int>(key.length());
CHECK_EQ(sock.SendAll(&code, sizeof(code)), sizeof(code));
CHECK_EQ(sock.SendAll(&keylen, sizeof(keylen)), sizeof(keylen));
if (keylen != 0) {
CHECK_EQ(sock.SendAll(key.c_str(), keylen), keylen);
}
CHECK_EQ(sock.RecvAll(&code, sizeof(code)), sizeof(code));
if (code == kRPCMagic + 2) {
sock.Close();
LOG(FATAL) << "URL " << url << ":" << port
<< " cannot find server that matches key=" << key;
} else if (code == kRPCMagic + 1) {
sock.Close();
LOG(FATAL) << "URL " << url << ":" << port
<< " server already have client key=" << key;
} else if (code != kRPCMagic) {
sock.Close();
LOG(FATAL) << "URL " << url << ":" << port << " is not TVM RPC server";
}
return CreateRPCModule(
RPCSession::Create(
std::unique_ptr<SockChannel>(new SockChannel(sock)),
"SockClient"));
}
void RPCServerLoop(int sockfd) {
common::TCPSocket sock(
static_cast<common::TCPSocket::SockType>(sockfd));
RPCSession::Create(
std::unique_ptr<SockChannel>(new SockChannel(sock)),
"SockServerLoop")->ServerLoop();
}
TVM_REGISTER_GLOBAL("contrib.rpc._Connect")
.set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = RPCConnect(args[0], args[1], args[2]);
});
TVM_REGISTER_GLOBAL("contrib.rpc._ServerLoop")
.set_body([](TVMArgs args, TVMRetValue* rv) {
RPCServerLoop(args[0]);
});
} // namespace runtime
} // namespace tvm
......@@ -12,5 +12,8 @@ RUN bash /install/ubuntu_install_python.sh
COPY install/ubuntu_install_iverilog.sh /install/ubuntu_install_iverilog.sh
RUN bash /install/ubuntu_install_iverilog.sh
COPY install/ubuntu_install_python_package.sh /install/ubuntu_install_python_package.sh
RUN bash /install/ubuntu_install_python_package.sh
COPY install/ubuntu_install_java.sh /install/ubuntu_install_java.sh
RUN bash /install/ubuntu_install_java.sh
......@@ -12,4 +12,7 @@ RUN bash /install/ubuntu_install_python.sh
COPY install/ubuntu_install_emscripten.sh /install/ubuntu_install_emscripten.sh
RUN bash /install/ubuntu_install_emscripten.sh
COPY install/ubuntu_install_python_package.sh /install/ubuntu_install_python_package.sh
RUN bash /install/ubuntu_install_python_package.sh
RUN cp /root/.emscripten /emsdk-portable/
\ No newline at end of file
......@@ -18,6 +18,9 @@ RUN bash /install/ubuntu_install_opencl.sh
COPY install/ubuntu_install_iverilog.sh /install/ubuntu_install_iverilog.sh
RUN bash /install/ubuntu_install_iverilog.sh
COPY install/ubuntu_install_python_package.sh /install/ubuntu_install_python_package.sh
RUN bash /install/ubuntu_install_python_package.sh
COPY install/ubuntu_install_sphinx.sh /install/ubuntu_install_sphinx.sh
RUN bash /install/ubuntu_install_sphinx.sh
......
......@@ -5,8 +5,11 @@ RUN apt-get update --fix-missing
COPY install/ubuntu_install_core.sh /install/ubuntu_install_core.sh
RUN bash /install/ubuntu_install_core.sh
COPY install/ubuntu_install_llvm.sh /install/ubuntu_install_llvm.sh
RUN bash /install/ubuntu_install_llvm.sh
COPY install/ubuntu_install_python.sh /install/ubuntu_install_python.sh
RUN bash /install/ubuntu_install_python.sh
COPY install/ubuntu_install_llvm.sh /install/ubuntu_install_llvm.sh
RUN bash /install/ubuntu_install_llvm.sh
COPY install/ubuntu_install_python_package.sh /install/ubuntu_install_python_package.sh
RUN bash /install/ubuntu_install_python_package.sh
apt-get update && apt-get install -y curl
curl -sL https://deb.nodesource.com/setup_6.x | bash -
apt-get update && apt-get install -y nodejs
npm install eslint jsdoc
npm install eslint jsdoc ws
# install libraries for python package on ubuntu
# install python and pip, don't modify this, modify install_python_package.sh
apt-get update && apt-get install -y python-pip python-dev python3-dev
# the version of the pip shipped with ubuntu may be too lower, install a recent version here
cd /tmp && wget https://bootstrap.pypa.io/get-pip.py && python3 get-pip.py && python2 get-pip.py
pip2 install nose pylint numpy nose-timer cython decorator scipy
pip3 install nose pylint numpy nose-timer cython decorator scipy
# install libraries for python package on ubuntu
pip2 install nose pylint numpy nose-timer cython decorator scipy tornado
pip3 install nose pylint numpy nose-timer cython decorator scipy tornado
import tvm
import logging
import numpy as np
import time
import multiprocessing
from tvm.contrib import rpc
def rpc_proxy_test():
"""This is a simple test function for RPC Proxy
It is not included as nosetests, because:
- It depends on tornado
- It relies on the fact that Proxy starts before client and server connects,
which is often the case but not always
User can directly run this script to verify correctness.
"""
try:
from tvm.contrib import rpc_proxy
web_port = 8888
prox = rpc_proxy.Proxy("localhost", web_port=web_port)
def check():
if not tvm.module.enabled("rpc"):
return
@tvm.register_func("rpc.test2.addone")
def addone(x):
return x + 1
@tvm.register_func("rpc.test2.strcat")
def addone(name, x):
return "%s:%d" % (name, x)
server = multiprocessing.Process(
target=rpc_proxy.websocket_proxy_server,
args=("ws://localhost:%d/ws" % web_port,"x1"))
# Need to make sure that the connection start after proxy comes up
time.sleep(0.1)
server.deamon = True
server.start()
client = rpc.connect(prox.host, prox.port, key="x1")
f1 = client.get_function("rpc.test2.addone")
assert f1(10) == 11
f2 = client.get_function("rpc.test2.strcat")
assert f2("abc", 11) == "abc:11"
check()
except ImportError:
print("Skipping because tornado is not avaliable...")
if __name__ == "__main__":
logging.basicConfig(level=logging.INFO)
rpc_proxy_test()
......@@ -14,10 +14,21 @@ def test_rpc_simple():
def addone(name, x):
return "%s:%d" % (name, x)
@tvm.register_func("rpc.test.except")
def remotethrow(name):
raise ValueError("%s" % name)
server = rpc.Server("localhost")
client = rpc.connect(server.host, server.port)
client = rpc.connect(server.host, server.port, key="x1")
f1 = client.get_function("rpc.test.addone")
assert f1(10) == 11
f3 = client.get_function("rpc.test.except")
try:
f3("abc")
assert False
except tvm.TVMError as e:
assert "abc" in str(e)
f2 = client.get_function("rpc.test.strcat")
assert f2("abc", 11) == "abc:11"
......@@ -42,9 +53,10 @@ def test_rpc_file_exchange():
return
server = rpc.Server("localhost")
remote = rpc.connect(server.host, server.port)
blob = bytearray(np.random.randint(0, 10, size=(127)))
blob = bytearray(np.random.randint(0, 10, size=(10)))
remote.upload(blob, "dat.bin")
rev = remote.download("dat.bin")
assert(rev == blob)
def test_rpc_remote_module():
if not tvm.module.enabled("rpc"):
......@@ -79,7 +91,8 @@ def test_rpc_remote_module():
if __name__ == "__main__":
logging.basicConfig(level=logging.INFO)
test_rpc_file_exchange()
exit(0)
test_rpc_array()
test_rpc_remote_module()
test_rpc_file_exchange()
test_rpc_simple()
"""Simple testcode to test Javascript RPC
To use it, start a rpc proxy with "python -m tvm.exec.rpc_proxy".
Connect javascript end to the websocket port and connect to the RPC.
"""
import tvm
import os
from tvm.contrib import rpc, util, emscripten
import numpy as np
proxy_host = "localhost"
proxy_port = 9090
def test_rpc_array():
if not tvm.module.enabled("rpc"):
return
# graph
n = tvm.convert(1024)
A = tvm.placeholder((n,), name='A')
B = tvm.compute(A.shape, lambda *i: A(*i) + 1.0, name='B')
s = tvm.create_schedule(B.op)
remote = rpc.connect(proxy_host, proxy_port, key="js")
target = "llvm -target=asmjs-unknown-emscripten -system-lib"
def check_remote():
if not tvm.module.enabled(target):
print("Skip because %s is not enabled" % target)
return
temp = util.tempdir()
ctx = remote.cpu(0)
f = tvm.build(s, [A, B], target, name="myadd")
path_obj = temp.relpath("dev_lib.bc")
path_dso = temp.relpath("dev_lib.js")
f.save(path_obj)
emscripten.create_js(path_dso, path_obj, side_module=True)
# Upload to suffix as dso so it can be loaded remotely
remote.upload(path_dso, "dev_lib.dso")
data = remote.download("dev_lib.dso")
f1 = remote.load_module("dev_lib.dso")
a = tvm.nd.array(np.random.uniform(size=1024).astype(A.dtype), ctx)
b = tvm.nd.array(np.zeros(1024, dtype=A.dtype), ctx)
time_f = f1.time_evaluator(f1.entry_name, remote.cpu(0), number=10)
cost = time_f(a, b)
print('%g secs/op' % cost)
np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1)
check_remote()
test_rpc_array()
......@@ -139,9 +139,18 @@ sysLib.release();
```
## Notes
- Current example supports static linking, which is the preferred way to get more efficiency
Current example supports static linking, which is the preferred way to get more efficiency
in javascript backend.
- It should also be possible to use Emscripten's dynamic linking to dynamically load modules.
- Take a look at tvm_runtime.js which contains quite a few helper functions
to interact with TVM from javascript.
\ No newline at end of file
## Proxy based RPC
We can now use javascript end to start an RPC server and connect to it from python side,
making the testing flow easier.
The following is an example to reproduce this. This requires everything to be in the git source and setup PYTHONPATH(instead of use setup.py install)
- run "python -m tvm.exec.rpc_proxy --example-rpc=1" to start proxy.
- Open broswer, goto the server webpage click Connect to proxy.
- Alternatively run "node web/example_rpc_node.js"
- run "python tests/web/websock_rpc_test.py" to run the rpc client.
The general idea is to use Emscripten's dynamic linking to dynamically load modules.
<html>
<head><title> TVM RPC Test Page </title></head>
<script src="libtvm_web_runtime.js"></script>
<script src="tvm_runtime.js"></script>
<script>
tvm = tvm_runtime.create(Module);
tvm.logger = function(message) {
console.log(message);
var d = document.createElement("div");
d.innerHTML = message;
document.getElementById("log").appendChild(d);
};
function clear_log() {
var node = document.getElementById("log");
while (node.hasChildNodes()) {
node.removeChild(node.lastChild);
}
}
function connect_rpc() {
var proxyurl = document.getElementById("proxyURL").value;
var key = document.getElementById("proxyKey").value;
tvm.startRPCServer(proxyurl, key, 100);
}
</script>
<body>
<h1>TVM Test Page</h1>
To use this page, the easiest way is to do
<ul>
<li> run "python -m tvm.exec.rpc_proxy --example-rpc=1" to start proxy.
<li> Click Connect to proxy.
<li> run "python tests/web/websock_rpc_test.py" to run the rpc client.
</ul>
<h2>Options</h2>
Proxy URL<input name="proxyurl" id="proxyURL" type="text" value="ws://localhost:9888/ws"><br>
RPC Server Key<input name="serverkey" id="proxyKey" type="text" value="js"><br>
<button onclick="connect_rpc()">Connect To Proxy</button>
<button onclick="clear_log()">Clear Log</button>
<div id="log"></div>
</body>
</html>
// Javascript RPC server example
// Start and connect to websocket proxy.
// Load Emscripten Module, need to change path to root/lib
const path = require("path");
process.chdir(path.join(__dirname, "../lib"));
var Module = require("../lib/libtvm_web_runtime.js");
// Bootstrap TVMruntime with emscripten module.
const tvm_runtime = require("../web/tvm_runtime.js");
const tvm = tvm_runtime.create(Module);
var websock_proxy = "ws://localhost:9888/ws";
var num_sess = 100;
tvm.startRPCServer(websock_proxy, "js", num_sess)
......@@ -7,6 +7,7 @@
/* eslint no-unused-vars: "off" */
/* eslint no-unexpected-multiline: "off" */
/* eslint indent: "off" */
/* eslint no-console: "off" */
/**
* TVM Runtime namespace.
* Provide tvm_runtime.create to create a {@link tvm.TVMRuntime}.
......@@ -31,8 +32,13 @@ var tvm_runtime = tvm_runtime || {};
* @memberof tvm
*/
function TVMRuntime() {
"use strict";
var runtime_ref = this;
// Utility function to throw error
function throwError(message) {
if (typeof runtime_ref.logger !== "undefined") {
runtime_ref.logger(message);
}
if (typeof Error !== "undefined") {
throw new Error(message);
}
......@@ -232,6 +238,21 @@ var tvm_runtime = tvm_runtime || {};
throwError(message);
}
};
/**
* Logging function.
* Override this to change logger behavior.
*
* @param {string} message
*/
this.logger = function(message) {
console.log(message);
};
function logging(message) {
runtime_ref.logger(message);
}
// Override print error to logging
Module.printErr = logging;
var CHECK = this.assert;
function TVM_CALL(ret) {
......@@ -746,11 +767,12 @@ var tvm_runtime = tvm_runtime || {};
out_array.release();
return names;
};
var listGlobalFuncNames = this.listGlobalFuncNames;
/**
* Get a global function from TVM runtime.
*
* @param {string} The name of the function.
* @return {Function} The corresponding function.
* @return {Function} The corresponding function, null if function do not exist
*/
this.getGlobalFunc = function (name) {
// alloc
......@@ -759,7 +781,11 @@ var tvm_runtime = tvm_runtime || {};
var out_handle = out.asHandle();
// release
out.release();
return makeTVMFunction(out_handle);
if (out_handle != 0) {
return makeTVMFunction(out_handle);
} else {
return null;
}
};
var getGlobalFunc = this.getGlobalFunc;
/**
......@@ -788,14 +814,123 @@ var tvm_runtime = tvm_runtime || {};
//-----------------------------------------
// Wrap of TVM Functions.
// ----------------------------------------
var fGetSystemLib = getGlobalFunc("module._GetSystemLib");
var systemFunc = {};
/**
* Get system-wide library module singleton.
* Get system-wide library module singleton.5A
* System lib is a global module that contains self register functions in startup.
* @return {tvm.TVMModule} The system module singleton.
*/
this.systemLib = function() {
return fGetSystemLib();
if (typeof systemFunc.fGetSystemLib === "undefined") {
systemFunc.fGetSystemLib = getGlobalFunc("module._GetSystemLib");
}
return systemFunc.fGetSystemLib();
};
this.startRPCServer = function(url, key, counter) {
if (typeof key === "undefined") {
key = "";
}
if (typeof counter === "undefined") {
counter = 1;
}
// Node js, import websocket
var bkey = StringToUint8Array("server:" + key);
var server_name = "WebSocketRPCServer[" + key + "]";
var RPC_MAGIC = 0xff271;
function checkEndian() {
var a = new ArrayBuffer(4);
var b = new Uint8Array(a);
var c = new Uint32Array(a);
b[0] = 0x11;
b[1] = 0x22;
b[2] = 0x33;
b[3] = 0x44;
CHECK(c[0] === 0x44332211, "Need little endian to work");
}
checkEndian();
// start rpc
function RPCServer(counter) {
var socket;
if (typeof module !== "undefined" && module.exports) {
// WebSocket for nodejs
const WebSocket = require("ws");
socket = new WebSocket(url);
} else {
socket = new WebSocket(url);
}
var self = this;
socket.binaryType = "arraybuffer";
this.init = true;
this.counter = counter;
if (typeof systemFunc.fcreateServer === "undefined") {
systemFunc.fcreateServer =
getGlobalFunc("contrib.rpc._CreateEventDrivenServer");
}
if (systemFunc.fcreateServer == null) {
throwError("RPCServer is not included in runtime");
}
var message_handler = systemFunc.fcreateServer(
function(cbytes) {
if (socket.readyState == 1) {
socket.send(cbytes);
return new TVMConstant(cbytes.length, "int32");
} else {
return new TVMConstant(0, "int32");
}
} , server_name);
function on_open(event) {
var intbuf = new Int32Array(1);
intbuf[0] = RPC_MAGIC;
socket.send(intbuf);
intbuf[0] = bkey.length;
socket.send(intbuf);
socket.send(bkey);
logging(server_name + " connected...");
}
function on_message(event) {
if (self.init) {
var msg = new Uint8Array(event.data);
CHECK(msg.length >= 4, "Need message header to be bigger than 4");
var magic = new Int32Array(event.data)[0];
if (magic == RPC_MAGIC + 1) {
throwError("key: " + key + " has already been used in proxy");
} else if (magic == RPC_MAGIC + 2) {
logging(server_name + ": RPCProxy do not have matching client key " + key);
} else {
CHECK(magic == RPC_MAGIC, url + "is not RPC Proxy");
self.init = false;
}
logging(server_name + "init end...");
if (msg.length > 4) {
if (!message_handler(new Uint8Array(event.data, 4, msg.length -4))) {
socket.close();
}
}
} else {
if (!message_handler(new Uint8Array(event.data))) {
socket.close();
}
}
}
function on_close(event) {
message_handler.release();
logging(server_name + ": closed finish...");
if (!self.init && self.counter != 0) {
logging(server_name + ":reconnect to serve another request, session left=" + counter);
// start a new server.
new RPCServer(counter - 1);
}
}
socket.addEventListener("open", on_open);
socket.addEventListener("message", on_message);
socket.addEventListener("close", on_close);
}
return new RPCServer(counter);
};
//-----------------------------------------
// Class defintions
......@@ -943,16 +1078,12 @@ var tvm_runtime = tvm_runtime || {};
* @property {string} create
* @memberof tvm_runtime
* @param Module The emscripten module.
* @param Runtime The emscripten runtime, optional
* @return {tvm.TVMRuntime} The created TVM runtime.
*/
this.create = function(Module, Runtime) {
this.create = function(Module) {
var tvm = {};
tvm.Module = Module;
if (typeof Runtime == "undefined") {
Runtime = Module.Runtime;
}
tvm.Runtime = Runtime;
tvm.Runtime = Module.Runtime;
TVMRuntime.apply(tvm);
return tvm;
};
......
......@@ -2,6 +2,9 @@
* Copyright (c) 2017 by Contributors
* \file web_runtime.cc
*/
#include <sys/stat.h>
#include <fstream>
#include "../src/runtime/c_runtime_api.cc"
#include "../src/runtime/cpu_device_api.cc"
#include "../src/runtime/workspace_pool.cc"
......@@ -10,3 +13,33 @@
#include "../src/runtime/module.cc"
#include "../src/runtime/registry.cc"
#include "../src/runtime/file_util.cc"
#include "../src/runtime/dso_module.cc"
#include "../src/runtime/rpc/rpc_session.cc"
#include "../src/runtime/rpc/rpc_event_impl.cc"
#include "../src/runtime/rpc/rpc_server_env.cc"
namespace tvm {
namespace contrib {
struct RPCEnv {
public:
RPCEnv() {
base_ = "/rpc";
mkdir(&base_[0], 0777);
}
// Get Path.
std::string GetPath(const std::string& file_name) {
return base_ + "/" + file_name;
}
private:
std::string base_;
};
TVM_REGISTER_GLOBAL("tvm.contrib.rpc.server.workpath")
.set_body([](TVMArgs args, TVMRetValue* rv) {
static RPCEnv env;
*rv = env.GetPath(args[0]);
});
} // namespace contrib
} // namespace tvm
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment