Commit c324494f by Tianqi Chen Committed by GitHub

[RUNTIME][RPC] Change RPCServer to Event Driven Code (#243)

* [RUNTIME][RPC] Change RPCServer to Event Driven Code

* fix
parent f4d1dddb
...@@ -58,7 +58,7 @@ CFLAGS = -std=c++11 -Wall -O2 $(INCLUDE_FLAGS) -fPIC ...@@ -58,7 +58,7 @@ CFLAGS = -std=c++11 -Wall -O2 $(INCLUDE_FLAGS) -fPIC
LLVM_CFLAGS= -fno-rtti -DDMLC_ENABLE_RTTI=0 LLVM_CFLAGS= -fno-rtti -DDMLC_ENABLE_RTTI=0
FRAMEWORKS = FRAMEWORKS =
OBJCFLAGS = -fno-objc-arc OBJCFLAGS = -fno-objc-arc
EMCC_FLAGS= -s RESERVED_FUNCTION_POINTERS=2 -s NO_EXIT_RUNTIME=1 -DDMLC_LOG_STACK_TRACE=0\ EMCC_FLAGS= -s RESERVED_FUNCTION_POINTERS=2 -s NO_EXIT_RUNTIME=1 -s MAIN_MODULE=1 -DDMLC_LOG_STACK_TRACE=0\
-std=c++11 -Oz $(INCLUDE_FLAGS) -std=c++11 -Oz $(INCLUDE_FLAGS)
# Dependency specific rules # Dependency specific rules
......
...@@ -212,13 +212,6 @@ class TVMPODValue_ { ...@@ -212,13 +212,6 @@ class TVMPODValue_ {
int type_code() const { int type_code() const {
return type_code_; return type_code_;
} }
protected:
friend class TVMArgsSetter;
friend class TVMRetValue;
TVMPODValue_() : type_code_(kNull) {}
TVMPODValue_(TVMValue value, int type_code)
: value_(value), type_code_(type_code) {}
/*! /*!
* \brief return handle as specific pointer type. * \brief return handle as specific pointer type.
* \tparam T the data type. * \tparam T the data type.
...@@ -228,6 +221,14 @@ class TVMPODValue_ { ...@@ -228,6 +221,14 @@ class TVMPODValue_ {
T* ptr() const { T* ptr() const {
return static_cast<T*>(value_.v_handle); return static_cast<T*>(value_.v_handle);
} }
protected:
friend class TVMArgsSetter;
friend class TVMRetValue;
TVMPODValue_() : type_code_(kNull) {}
TVMPODValue_(TVMValue value, int type_code)
: value_(value), type_code_(type_code) {}
/*! \brief The value */ /*! \brief The value */
TVMValue value_; TVMValue value_;
/*! \brief the type code */ /*! \brief the type code */
...@@ -347,6 +348,8 @@ class TVMRetValue : public TVMPODValue_ { ...@@ -347,6 +348,8 @@ class TVMRetValue : public TVMPODValue_ {
operator std::string() const { operator std::string() const {
if (type_code_ == kTVMType) { if (type_code_ == kTVMType) {
return TVMType2String(operator TVMType()); return TVMType2String(operator TVMType());
} else if (type_code_ == kBytes) {
return *ptr<std::string>();
} }
TVM_CHECK_TYPE_CODE(type_code_, kStr); TVM_CHECK_TYPE_CODE(type_code_, kStr);
return *ptr<std::string>(); return *ptr<std::string>();
...@@ -414,6 +417,10 @@ class TVMRetValue : public TVMPODValue_ { ...@@ -414,6 +417,10 @@ class TVMRetValue : public TVMPODValue_ {
this->SwitchToClass(kStr, value); this->SwitchToClass(kStr, value);
return *this; return *this;
} }
TVMRetValue& operator=(TVMByteArray value) {
this->SwitchToClass(kBytes, std::string(value.data, value.size));
return *this;
}
TVMRetValue& operator=(PackedFunc f) { TVMRetValue& operator=(PackedFunc f) {
this->SwitchToClass(kFuncHandle, f); this->SwitchToClass(kFuncHandle, f);
return *this; return *this;
...@@ -442,7 +449,7 @@ class TVMRetValue : public TVMPODValue_ { ...@@ -442,7 +449,7 @@ class TVMRetValue : public TVMPODValue_ {
void MoveToCHost(TVMValue* ret_value, void MoveToCHost(TVMValue* ret_value,
int* ret_type_code) { int* ret_type_code) {
// cannot move str; need specially handle. // cannot move str; need specially handle.
CHECK(type_code_ != kStr); CHECK(type_code_ != kStr && type_code_ != kBytes);
*ret_value = value_; *ret_value = value_;
*ret_type_code = type_code_; *ret_type_code = type_code_;
type_code_ = kNull; type_code_ = kNull;
...@@ -470,11 +477,14 @@ class TVMRetValue : public TVMPODValue_ { ...@@ -470,11 +477,14 @@ class TVMRetValue : public TVMPODValue_ {
template<typename T> template<typename T>
void Assign(const T& other) { void Assign(const T& other) {
switch (other.type_code()) { switch (other.type_code()) {
case kStr: case kStr: {
case kBytes: {
SwitchToClass<std::string>(kStr, other); SwitchToClass<std::string>(kStr, other);
break; break;
} }
case kBytes: {
SwitchToClass<std::string>(kBytes, other);
break;
}
case kFuncHandle: { case kFuncHandle: {
SwitchToClass<PackedFunc>(kFuncHandle, other); SwitchToClass<PackedFunc>(kFuncHandle, other);
break; break;
...@@ -702,6 +712,7 @@ class TVMArgsSetter { ...@@ -702,6 +712,7 @@ class TVMArgsSetter {
values_[i].v_str = value.ptr<std::string>()->c_str(); values_[i].v_str = value.ptr<std::string>()->c_str();
type_codes_[i] = kStr; type_codes_[i] = kStr;
} else { } else {
CHECK_NE(value.type_code(), kBytes) << "not handled.";
values_[i] = value.value_; values_[i] = value.value_;
type_codes_[i] = value.type_code(); type_codes_[i] = value.type_code();
} }
......
...@@ -7,6 +7,7 @@ from .._ffi.libinfo import find_lib_path ...@@ -7,6 +7,7 @@ from .._ffi.libinfo import find_lib_path
def create_js(output, def create_js(output,
objects, objects,
options=None, options=None,
side_module=False,
cc="emcc"): cc="emcc"):
"""Create emscripten javascript library. """Create emscripten javascript library.
...@@ -29,6 +30,8 @@ def create_js(output, ...@@ -29,6 +30,8 @@ def create_js(output,
cmd += ["-s", "NO_EXIT_RUNTIME=1"] cmd += ["-s", "NO_EXIT_RUNTIME=1"]
cmd += ["-Oz"] cmd += ["-Oz"]
cmd += ["-o", output] cmd += ["-o", output]
if side_module:
cmd += ["-s", "SIDE_MODULE=1"]
objects = [objects] if isinstance(objects, str) else objects objects = [objects] if isinstance(objects, str) else objects
with_runtime = False with_runtime = False
...@@ -36,7 +39,7 @@ def create_js(output, ...@@ -36,7 +39,7 @@ def create_js(output,
if obj.find("libtvm_web_runtime.bc") != -1: if obj.find("libtvm_web_runtime.bc") != -1:
with_runtime = True with_runtime = True
if not with_runtime: if not with_runtime and not side_module:
objects += [find_lib_path("libtvm_web_runtime.bc")[0]] objects += [find_lib_path("libtvm_web_runtime.bc")[0]]
cmd += objects cmd += objects
......
...@@ -19,32 +19,20 @@ from . import util, cc_compiler ...@@ -19,32 +19,20 @@ from . import util, cc_compiler
from ..module import load as _load_module from ..module import load as _load_module
from .._ffi.function import _init_api, register_func from .._ffi.function import _init_api, register_func
from .._ffi.ndarray import context as _context from .._ffi.ndarray import context as _context
from .._ffi.base import py_str
RPC_MAGIC = 0xff271 RPC_MAGIC = 0xff271
RPC_SESS_MASK = 128 RPC_SESS_MASK = 128
def _serve_loop(sock, addr): def _server_env():
"""Server loop""" """Server environment function return temp dir"""
sockfd = sock.fileno()
temp = util.tempdir() temp = util.tempdir()
# pylint: disable=unused-variable # pylint: disable=unused-variable
@register_func("tvm.contrib.rpc.server.upload") @register_func("tvm.contrib.rpc.server.workpath")
def upload(file_name, blob): def get_workpath(path):
"""Upload the blob to remote temp file""" return temp.relpath(path)
path = temp.relpath(file_name)
with open(path, "wb") as out_file:
out_file.write(blob)
logging.info("upload %s", path)
@register_func("tvm.contrib.rpc.server.download")
def download(file_name):
"""Download file from remote"""
path = temp.relpath(file_name)
dat = bytearray(open(path, "rb").read())
logging.info("download %s", path)
return dat
@register_func("tvm.contrib.rpc.server.load_module") @register_func("tvm.contrib.rpc.server.load_module", override=True)
def load_module(file_name): def load_module(file_name):
"""Load module from remote side.""" """Load module from remote side."""
path = temp.relpath(file_name) path = temp.relpath(file_name)
...@@ -53,11 +41,16 @@ def _serve_loop(sock, addr): ...@@ -53,11 +41,16 @@ def _serve_loop(sock, addr):
logging.info('Create shared library based on %s', path) logging.info('Create shared library based on %s', path)
cc_compiler.create_shared(path + '.so', path) cc_compiler.create_shared(path + '.so', path)
path += '.so' path += '.so'
m = _load_module(path) m = _load_module(path)
logging.info("load_module %s", path) logging.info("load_module %s", path)
return m return m
return temp
def _serve_loop(sock, addr):
"""Server loop"""
sockfd = sock.fileno()
temp = _server_env()
_ServerLoop(sockfd) _ServerLoop(sockfd)
temp.remove() temp.remove()
logging.info("Finish serving %s", addr) logging.info("Finish serving %s", addr)
...@@ -78,11 +71,16 @@ def _listen_loop(sock): ...@@ -78,11 +71,16 @@ def _listen_loop(sock):
while True: while True:
conn, addr = sock.accept() conn, addr = sock.accept()
logging.info("RPCServer: connection from %s", addr) logging.info("RPCServer: connection from %s", addr)
conn.sendall(struct.pack('@i', RPC_MAGIC))
magic = struct.unpack('@i', _recvall(conn, 4))[0] magic = struct.unpack('@i', _recvall(conn, 4))[0]
if magic != RPC_MAGIC: if magic != RPC_MAGIC:
conn.close() conn.close()
continue continue
keylen = struct.unpack('@i', _recvall(conn, 4))[0]
key = py_str(_recvall(conn, keylen))
if not key.startswith("client:"):
conn.sendall(struct.pack('@i', RPC_MAGIC + 2))
else:
conn.sendall(struct.pack('@i', RPC_MAGIC))
logging.info("Connection from %s", addr) logging.info("Connection from %s", addr)
process = multiprocessing.Process(target=_serve_loop, args=(conn, addr)) process = multiprocessing.Process(target=_serve_loop, args=(conn, addr))
process.deamon = True process.deamon = True
...@@ -91,6 +89,27 @@ def _listen_loop(sock): ...@@ -91,6 +89,27 @@ def _listen_loop(sock):
conn.close() conn.close()
def _connect_proxy_loop(addr, key):
while True:
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.connect(addr)
sock.sendall(struct.pack('@i', RPC_MAGIC))
sock.sendall(struct.pack('@i', len(key)))
sock.sendall(key)
magic = struct.unpack('@i', _recvall(sock, 4))[0]
if magic == RPC_MAGIC + 1:
raise RuntimeError("key: %s has already been used in proxy" % key)
elif magic == RPC_MAGIC + 2:
logging.info("RPCProxy do not have matching client key %s", key)
elif magic != RPC_MAGIC:
raise RuntimeError("%s is not RPC Proxy" % str(addr))
logging.info("RPCProxy connected to %s", str(addr))
process = multiprocessing.Process(target=_serve_loop, args=(sock, addr))
process.deamon = True
process.start()
process.join()
class Server(object): class Server(object):
"""Start RPC server on a seperate process. """Start RPC server on a seperate process.
...@@ -108,27 +127,43 @@ class Server(object): ...@@ -108,27 +127,43 @@ class Server(object):
port_end : int, optional port_end : int, optional
The end port to search The end port to search
is_proxy : bool, optional
Whether the address specified is a proxy.
If this is true, the host and port actually corresponds to the
address of the proxy server.
key : str, optional
The key used to identify the server in Proxy connection.
""" """
def __init__(self, host, port=9091, port_end=9199): def __init__(self, host, port=9091, port_end=9199, is_proxy=False, key=""):
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.port = None
for my_port in range(port, port_end):
try:
sock.bind((host, my_port))
self.port = my_port
break
except socket.error as sock_err:
if sock_err.errno in [98, 48]:
continue
else:
raise sock_err
if not self.port:
raise ValueError("cannot bind to any port in [%d, %d)" % (port, port_end))
logging.info("RPCServer: bind to %s:%d", host, self.port)
sock.listen(1)
self.sock = sock
self.host = host self.host = host
self.proc = multiprocessing.Process(target=_listen_loop, args=(self.sock,)) self.port = port
if not is_proxy:
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.port = None
for my_port in range(port, port_end):
try:
sock.bind((host, my_port))
self.port = my_port
break
except socket.error as sock_err:
if sock_err.errno in [98, 48]:
continue
else:
raise sock_err
if not self.port:
raise ValueError("cannot bind to any port in [%d, %d)" % (port, port_end))
logging.info("RPCServer: bind to %s:%d", host, self.port)
sock.listen(1)
self.sock = sock
self.proc = multiprocessing.Process(
target=_listen_loop, args=(self.sock,))
else:
self.proc = multiprocessing.Process(
target=_connect_proxy_loop, args=((host, port), key))
self.proc.deamon = True
self.proc.start() self.proc.start()
def terminate(self): def terminate(self):
...@@ -141,6 +176,7 @@ class Server(object): ...@@ -141,6 +176,7 @@ class Server(object):
self.terminate() self.terminate()
class RPCSession(object): class RPCSession(object):
"""RPC Client session module """RPC Client session module
...@@ -262,7 +298,7 @@ class RPCSession(object): ...@@ -262,7 +298,7 @@ class RPCSession(object):
return _LoadRemoteModule(self._sess, path) return _LoadRemoteModule(self._sess, path)
def connect(url, port): def connect(url, port, key=""):
"""Connect to RPC Server """Connect to RPC Server
Parameters Parameters
...@@ -273,13 +309,16 @@ def connect(url, port): ...@@ -273,13 +309,16 @@ def connect(url, port):
port : int port : int
The port to connect to The port to connect to
key : str, optional
Additional key to match server
Returns Returns
------- -------
sess : RPCSession sess : RPCSession
The connected session. The connected session.
""" """
try: try:
sess = _Connect(url, port) sess = _Connect(url, port, key)
except NameError: except NameError:
raise RuntimeError('Please compile with USE_RPC=1') raise RuntimeError('Please compile with USE_RPC=1')
return RPCSession(sess) return RPCSession(sess)
......
"""RPC web proxy, allows redirect to websocket based RPC servers(browsers)"""
from __future__ import absolute_import
import logging
import argparse
import os
from ..contrib.rpc_proxy import Proxy
def find_example_resource():
"""Find resource examples."""
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
base_path = os.path.join(curr_path, "../../../")
index_page = os.path.join(base_path, "web/example_rpc.html")
js_files = [
os.path.join(base_path, "web/tvm_runtime.js"),
os.path.join(base_path, "lib/libtvm_web_runtime.js"),
os.path.join(base_path, "lib/libtvm_web_runtime.js.mem")
]
for fname in [index_page] + js_files:
if not os.path.exists(fname):
raise RuntimeError("Cannot find %s" % fname)
return index_page, js_files
def main():
"""Main funciton"""
parser = argparse.ArgumentParser()
parser.add_argument('--host', type=str, default="0.0.0.0",
help='the hostname of the server')
parser.add_argument('--port', type=int, default=9090,
help='The port of the PRC')
parser.add_argument('--web-port', type=int, default=9888,
help='The port of the http/websocket server')
parser.add_argument('--example-rpc', type=bool, default=False,
help='Whether to switch on example rpc mode')
args = parser.parse_args()
logging.basicConfig(level=logging.INFO)
if args.example_rpc:
index, js_files = find_example_resource()
prox = Proxy(args.host, port=args.port,
web_port=args.web_port, index_page=index,
resource_files=js_files)
else:
prox = Proxy(args.host, port=args.port, web_port=args.web_port)
prox.proc.join()
if __name__ == "__main__":
main()
/*!
* Copyright (c) 2017 by Contributors
* \file ring_buffer.h
* \brief this file aims to provide a wrapper of sockets
*/
#ifndef TVM_COMMON_RING_BUFFER_H_
#define TVM_COMMON_RING_BUFFER_H_
#include <vector>
#include <cstring>
#include <algorithm>
namespace tvm {
namespace common {
/*!
* \brief Ring buffer class for data buffering in IO.
* Enables easy usage for sync and async mode.
*/
class RingBuffer {
public:
/*! \brief Initial capacity of ring buffer. */
static const int kInitCapacity = 4 << 10;
/*! \brief constructor */
RingBuffer() : ring_(kInitCapacity) {}
/*! \return number of bytes available in buffer. */
size_t bytes_available() const {
return bytes_available_;
}
/*! \return Current capacity of buffer. */
size_t capacity() const {
return ring_.size();
}
/*!
* Reserve capacity to be at least n.
* Will only increase capacity if n is bigger than current capacity.
* \param n The size of capacity.
*/
void Reserve(size_t n) {
if (ring_.size() >= n) return;
size_t old_size = ring_.size();
size_t new_size = ring_.size();
while (new_size < n) {
new_size *= 2;
}
ring_.resize(new_size);
if (head_ptr_ + bytes_available_ > old_size) {
// copy the ring overflow part into the tail.
size_t ncopy = head_ptr_ + bytes_available_ - old_size;
memcpy(&ring_[0] + old_size, &ring_[0], ncopy);
}
}
/*!
* \brief Peform a non-blocking read from buffer
* size must be smaller than this->bytes_available()
* \param data the data pointer.
* \param size The number of bytes to read.
*/
void Read(void* data, size_t size) {
CHECK_GE(bytes_available_, size);
size_t ncopy = std::min(size, ring_.size() - head_ptr_);
memcpy(data, &ring_[0] + head_ptr_, ncopy);
if (ncopy < size) {
memcpy(reinterpret_cast<char*>(data) + ncopy,
&ring_[0], size - ncopy);
}
head_ptr_ = (head_ptr_ + size) % ring_.size();
bytes_available_ -= size;
}
/*!
* \brief Read data from buffer with and put them to non-blocking send function.
*
* \param frecv A send function handle to put the data to.
* \param max_nbytes Maximum number of bytes can to read.
* \tparam FSend A non-blocking function with signature size_t (const void* data, size_t size);
*/
template<typename FSend>
size_t ReadWithCallback(FSend fsend, size_t max_nbytes) {
size_t size = std::min(max_nbytes, bytes_available_);
CHECK_NE(size, 0U);
size_t ncopy = std::min(size, ring_.size() - head_ptr_);
size_t nsend = fsend(&ring_[0] + head_ptr_, ncopy);
bytes_available_ -= nsend;
if (ncopy == nsend && ncopy < size) {
size_t nsend2 = fsend(&ring_[0], size - ncopy);
bytes_available_ -= nsend2;
nsend += nsend2;
}
return nsend;
}
/*!
* \brief Write data into buffer, always ensures all data is written.
* \param data The data pointer
* \param size The size of data to be written.
*/
void Write(const void* data, size_t size) {
this->Reserve(bytes_available_ + size);
size_t tail = head_ptr_ + bytes_available_;
if (tail >= ring_.size()) {
memcpy(&ring_[0] + (tail - ring_.size()), data, size);
} else {
size_t ncopy = std::min(ring_.size() - tail, size);
memcpy(&ring_[0] + tail, data, ncopy);
if (ncopy < size) {
memcpy(&ring_[0], reinterpret_cast<const char*>(data) + ncopy, size - ncopy);
}
}
bytes_available_ += size;
}
/*!
* \brief Writen data into the buffer by give it a non-blocking callback function.
*
* \param frecv A receive function handle
* \param max_nbytes Maximum number of bytes can write.
* \tparam FRecv A non-blocking function with signature size_t (void* data, size_t size);
*/
template<typename FRecv>
size_t WriteWithCallback(FRecv frecv, size_t max_nbytes) {
this->Reserve(bytes_available_ + max_nbytes);
size_t nbytes = max_nbytes;
size_t tail = head_ptr_ + bytes_available_;
if (tail >= ring_.size()) {
size_t nrecv = frecv(&ring_[0] + (tail - ring_.size()), nbytes);
bytes_available_ += nrecv;
return nrecv;
} else {
size_t ncopy = std::min(ring_.size() - tail, nbytes);
size_t nrecv = frecv(&ring_[0] + tail, ncopy);
bytes_available_ += nrecv;
if (nrecv == ncopy && ncopy < nbytes) {
size_t nrecv2 = frecv(&ring_[0], nbytes - ncopy);
bytes_available_ += nrecv2;
nrecv += nrecv2;
}
return nrecv;
}
}
private:
// buffer head
size_t head_ptr_{0};
// number of bytes in the buffer.
size_t bytes_available_{0};
// The internald ata ring.
std::vector<char> ring_;
};
} // namespace common
} // namespace tvm
#endif // TVM_COMMON_RING_BUFFER_H_
...@@ -157,6 +157,7 @@ using namespace tvm::runtime; ...@@ -157,6 +157,7 @@ using namespace tvm::runtime;
struct TVMRuntimeEntry { struct TVMRuntimeEntry {
std::string ret_str; std::string ret_str;
std::string last_error; std::string last_error;
TVMByteArray ret_bytes;
// threads used in parallel for // threads used in parallel for
std::vector<std::thread> par_threads; std::vector<std::thread> par_threads;
// errors created in parallel for. // errors created in parallel for.
...@@ -311,11 +312,23 @@ int TVMFuncCall(TVMFunctionHandle func, ...@@ -311,11 +312,23 @@ int TVMFuncCall(TVMFunctionHandle func,
TVMArgs(args, arg_type_codes, num_args), &rv); TVMArgs(args, arg_type_codes, num_args), &rv);
// handle return string. // handle return string.
if (rv.type_code() == kStr || if (rv.type_code() == kStr ||
rv.type_code() == kTVMType) { rv.type_code() == kTVMType ||
rv.type_code() == kBytes) {
TVMRuntimeEntry* e = TVMAPIRuntimeStore::Get(); TVMRuntimeEntry* e = TVMAPIRuntimeStore::Get();
e->ret_str = rv.operator std::string(); if (rv.type_code() != kTVMType) {
*ret_type_code = kStr; e->ret_str = *rv.ptr<std::string>();
ret_val->v_str = e->ret_str.c_str(); } else {
e->ret_str = rv.operator std::string();
}
if (rv.type_code() == kBytes) {
e->ret_bytes.data = e->ret_str.c_str();
e->ret_bytes.size = e->ret_str.length();
*ret_type_code = kBytes;
ret_val->v_handle = &(e->ret_bytes);
} else {
*ret_type_code = kStr;
ret_val->v_str = e->ret_str.c_str();
}
} else { } else {
rv.MoveToCHost(ret_val, ret_type_code); rv.MoveToCHost(ret_val, ret_type_code);
} }
......
/*!
* Copyright (c) 2017 by Contributors
* \file rpc_event_impl.cc
* \brief Event based RPC server implementation.
*/
#include <tvm/runtime/registry.h>
#include <memory>
#include "./rpc_session.h"
namespace tvm {
namespace runtime {
class CallbackChannel final : public RPCChannel {
public:
explicit CallbackChannel(PackedFunc fsend)
: fsend_(fsend) {}
size_t Send(const void* data, size_t size) final {
TVMByteArray bytes;
bytes.data = static_cast<const char*>(data);
bytes.size = size;
uint64_t ret = fsend_(bytes);
return static_cast<size_t>(ret);
}
size_t Recv(void* data, size_t size) final {
LOG(FATAL) << "Do not allow explicit receive for";
return 0;
}
private:
PackedFunc fsend_;
};
PackedFunc CreateEvenDrivenServer(PackedFunc fsend, std::string name) {
std::unique_ptr<CallbackChannel> ch(new CallbackChannel(fsend));
std::shared_ptr<RPCSession> sess = RPCSession::Create(std::move(ch), name);
return PackedFunc([sess](TVMArgs args, TVMRetValue* rv) {
bool ret = sess->ServerOnMessageHandler(args[0]);
*rv = ret;
});
}
TVM_REGISTER_GLOBAL("contrib.rpc._CreateEventDrivenServer")
.set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = CreateEvenDrivenServer(args[0], args[1]);
});
} // namespace runtime
} // namespace tvm
...@@ -10,8 +10,6 @@ ...@@ -10,8 +10,6 @@
namespace tvm { namespace tvm {
namespace runtime { namespace runtime {
const int kRPCMagic = 0xff271;
// Wrapped remote function to packed func. // Wrapped remote function to packed func.
struct RPCWrappedFunc { struct RPCWrappedFunc {
public: public:
...@@ -98,36 +96,12 @@ class RPCModuleNode final : public ModuleNode { ...@@ -98,36 +96,12 @@ class RPCModuleNode final : public ModuleNode {
std::shared_ptr<RPCSession> sess_; std::shared_ptr<RPCSession> sess_;
}; };
Module RPCConnect(std::string url, int port) { Module CreateRPCModule(std::shared_ptr<RPCSession> sess) {
common::TCPSocket sock;
common::SockAddr addr(url.c_str(), port);
sock.Create();
CHECK(sock.Connect(addr))
<< "Connect to " << addr.AsString() << " failed";
// hand shake
int code = kRPCMagic;
CHECK_EQ(sock.SendAll(&code, sizeof(code)), sizeof(code));
CHECK_EQ(sock.RecvAll(&code, sizeof(code)), sizeof(code));
if (code != kRPCMagic) {
sock.Close();
LOG(FATAL) << "URL " << url << ":" << port << " is not TVM RPC server";
}
std::shared_ptr<RPCModuleNode> n = std::shared_ptr<RPCModuleNode> n =
std::make_shared<RPCModuleNode>(nullptr, RPCSession::Create(sock)); std::make_shared<RPCModuleNode>(nullptr, sess);
return Module(n); return Module(n);
} }
void RPCServerLoop(int sockfd) {
common::TCPSocket sock(
static_cast<common::TCPSocket::SockType>(sockfd));
RPCSession::Create(sock)->ServerLoop();
}
TVM_REGISTER_GLOBAL("contrib.rpc._Connect")
.set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = RPCConnect(args[0], args[1]);
});
TVM_REGISTER_GLOBAL("module._RPCTimeEvaluator") TVM_REGISTER_GLOBAL("module._RPCTimeEvaluator")
.set_body([](TVMArgs args, TVMRetValue* rv) { .set_body([](TVMArgs args, TVMRetValue* rv) {
Module m = args[0]; Module m = args[0];
...@@ -163,10 +137,5 @@ TVM_REGISTER_GLOBAL("contrib.rpc._SessTableIndex") ...@@ -163,10 +137,5 @@ TVM_REGISTER_GLOBAL("contrib.rpc._SessTableIndex")
CHECK_EQ(tkey, "rpc"); CHECK_EQ(tkey, "rpc");
*rv = static_cast<RPCModuleNode*>(m.operator->())->sess()->table_index(); *rv = static_cast<RPCModuleNode*>(m.operator->())->sess()->table_index();
}); });
TVM_REGISTER_GLOBAL("contrib.rpc._ServerLoop")
.set_body([](TVMArgs args, TVMRetValue* rv) {
RPCServerLoop(args[0]);
});
} // namespace runtime } // namespace runtime
} // namespace tvm } // namespace tvm
/*!
* Copyright (c) 2017 by Contributors
* \file rpc_server_env
* \brief Server environment of the RPC.
*/
#include <tvm/runtime/registry.h>
#include "../file_util.h"
namespace tvm {
namespace runtime {
std::string RPCGetPath(const std::string& name) {
static const PackedFunc* f =
runtime::Registry::Get("tvm.contrib.rpc.server.workpath");
CHECK(f != nullptr) << "require tvm.contrib.rpc.server.workpath";
return (*f)(name);
}
TVM_REGISTER_GLOBAL("tvm.contrib.rpc.server.upload").
set_body([](TVMArgs args, TVMRetValue *rv) {
std::string file_name = RPCGetPath(args[0]);
std::string data = args[1];
LOG(INFO) << "Upload " << file_name << "... nbytes=" << data.length();
SaveBinaryToFile(file_name, data);
});
TVM_REGISTER_GLOBAL("tvm.contrib.rpc.server.download")
.set_body([](TVMArgs args, TVMRetValue *rv) {
std::string file_name = RPCGetPath(args[0]);
std::string data;
LoadBinaryFromFile(file_name, &data);
TVMByteArray arr;
arr.data = data.c_str();
arr.size = data.length();
LOG(INFO) << "Download " << file_name << "... nbytes=" << arr.size;
*rv = arr;
});
TVM_REGISTER_GLOBAL("tvm.contrib.rpc.server.load_module")
.set_body([](TVMArgs args, TVMRetValue *rv) {
std::string file_name = RPCGetPath(args[0]);
*rv = Module::LoadFromFile(file_name, "");
LOG(INFO) << "Load module from " << file_name << " ...";
});
} // namespace runtime
} // namespace tvm
...@@ -10,11 +10,13 @@ ...@@ -10,11 +10,13 @@
#include <tvm/runtime/device_api.h> #include <tvm/runtime/device_api.h>
#include <mutex> #include <mutex>
#include <string> #include <string>
#include "../../common/socket.h" #include "../../common/ring_buffer.h"
namespace tvm { namespace tvm {
namespace runtime { namespace runtime {
const int kRPCMagic = 0xff271;
/*! \brief The remote functio handle */ /*! \brief The remote functio handle */
using RPCFuncHandle = void*; using RPCFuncHandle = void*;
...@@ -22,6 +24,7 @@ struct RPCArgBuffer; ...@@ -22,6 +24,7 @@ struct RPCArgBuffer;
/*! \brief The RPC code */ /*! \brief The RPC code */
enum class RPCCode : int { enum class RPCCode : int {
kNone,
kCallFunc, kCallFunc,
kReturn, kReturn,
kException, kException,
...@@ -30,6 +33,7 @@ enum class RPCCode : int { ...@@ -30,6 +33,7 @@ enum class RPCCode : int {
kCopyToRemote, kCopyToRemote,
kCopyAck, kCopyAck,
// The following are code that can send over CallRemote // The following are code that can send over CallRemote
kSystemFuncStart,
kGetGlobalFunc, kGetGlobalFunc,
kGetTimeEvaluator, kGetTimeEvaluator,
kFreeFunc, kFreeFunc,
...@@ -45,6 +49,30 @@ enum class RPCCode : int { ...@@ -45,6 +49,30 @@ enum class RPCCode : int {
kModuleGetSource kModuleGetSource
}; };
/*!
* \brief Abstract channel interface used to create RPCSession.
*/
class RPCChannel {
public:
/*! \brief virtual destructor */
virtual ~RPCChannel() {}
/*!
* \brief Send data over to the channel.
* \param data The data pointer.
* \param size The size fo the data.
* \return The actual bytes sent.
*/
virtual size_t Send(const void* data, size_t size) = 0;
/*!
e * \brief Recv data from channel.
*
* \param data The data pointer.
* \param size The size fo the data.
* \return The actual bytes received.
*/
virtual size_t Recv(void* data, size_t size) = 0;
};
// Bidirectional Communication Session of PackedRPC // Bidirectional Communication Session of PackedRPC
class RPCSession { class RPCSession {
public: public:
...@@ -55,6 +83,17 @@ class RPCSession { ...@@ -55,6 +83,17 @@ class RPCSession {
*/ */
void ServerLoop(); void ServerLoop();
/*! /*!
* \brief Message handling function for event driven server.
* Called when the server receives a message.
* Event driven handler will never call recv on the channel
* and always relies on the ServerOnMessageHandler
* to receive the data.
*
* \param bytes The incoming bytes.
* \return Whether need continue running, return false when receive a shutdown message.
*/
bool ServerOnMessageHandler(const std::string& bytes);
/*!
* \brief Call into remote function * \brief Call into remote function
* \param handle The function handle * \param handle The function handle
* \param args The arguments * \param args The arguments
...@@ -121,10 +160,13 @@ class RPCSession { ...@@ -121,10 +160,13 @@ class RPCSession {
} }
/*! /*!
* \brief Create a RPC session with given socket * \brief Create a RPC session with given socket
* \param sock The socket. * \param channel The communication channel.
* \param name The name of the session, used for debug
* \return The session. * \return The session.
*/ */
static std::shared_ptr<RPCSession> Create(common::TCPSocket sock); static std::shared_ptr<RPCSession> Create(
std::unique_ptr<RPCChannel> channel,
std::string name);
/*! /*!
* \brief Try get session from the global session table by table index. * \brief Try get session from the global session table by table index.
* \param table_index The table index of the session. * \param table_index The table index of the session.
...@@ -133,36 +175,28 @@ class RPCSession { ...@@ -133,36 +175,28 @@ class RPCSession {
static std::shared_ptr<RPCSession> Get(int table_index); static std::shared_ptr<RPCSession> Get(int table_index);
private: private:
/*! class EventHandler;
* \brief Handle the remote call with f // Handle events until receives a return
* \param f The handle function // Also flushes channels so that the function advances.
* \tparam F the handler function. RPCCode HandleUntilReturnEvent(TVMRetValue* rv);
*/ // Initalization
template<typename F>
void CallHandler(F f);
void Init(); void Init();
// Shutdown
void Shutdown(); void Shutdown();
void SendReturnValue(int succ, TVMValue value, int tcode); // Internal channel.
void SendPackedSeq(const TVMValue* arg_values, const int* type_codes, int n); std::unique_ptr<RPCChannel> channel_;
void RecvPackedSeq(RPCArgBuffer *buf);
RPCCode HandleNextEvent(TVMRetValue *rv);
TVMContext StripSessMask(TVMContext ctx);
// special handler.
void HandleCallFunc();
void HandleException();
void HandleCopyFromRemote();
void HandleCopyToRemote();
void HandleReturn(TVMRetValue* rv);
// Internal mutex // Internal mutex
std::recursive_mutex mutex_; std::recursive_mutex mutex_;
// Internal socket // Internal ring buffer.
common::TCPSocket sock_; common::RingBuffer reader_, writer_;
// Internal temporal data space. // Event handler.
std::string temp_data_; std::shared_ptr<EventHandler> handler_;
// call remote with the specified function coede. // call remote with the specified function coede.
PackedFunc call_remote_; PackedFunc call_remote_;
// The index of this session in RPC session table. // The index of this session in RPC session table.
int table_index_{0}; int table_index_{0};
// The name of the session.
std::string name_;
}; };
/*! /*!
...@@ -173,6 +207,13 @@ class RPCSession { ...@@ -173,6 +207,13 @@ class RPCSession {
*/ */
PackedFunc WrapTimeEvaluator(PackedFunc f, TVMContext ctx, int nstep); PackedFunc WrapTimeEvaluator(PackedFunc f, TVMContext ctx, int nstep);
/*!
* \brief Create a Global RPC module that refers to the session.
* \param sess The RPC session of the global module.
* \return The created module.
*/
Module CreateRPCModule(std::shared_ptr<RPCSession> sess);
// Remote space pointer. // Remote space pointer.
struct RemoteSpace { struct RemoteSpace {
void* data; void* data;
...@@ -183,7 +224,7 @@ struct RemoteSpace { ...@@ -183,7 +224,7 @@ struct RemoteSpace {
template<typename... Args> template<typename... Args>
inline TVMRetValue RPCSession::CallRemote(RPCCode code, Args&& ...args) { inline TVMRetValue RPCSession::CallRemote(RPCCode code, Args&& ...args) {
std::lock_guard<std::recursive_mutex> lock(mutex_); std::lock_guard<std::recursive_mutex> lock(mutex_);
CHECK_EQ(sock_.SendAll(&code, sizeof(code)), sizeof(code)); writer_.Write(&code, sizeof(code));
return call_remote_(std::forward<Args>(args)...); return call_remote_(std::forward<Args>(args)...);
} }
} // namespace runtime } // namespace runtime
......
/*!
* Copyright (c) 2017 by Contributors
* \file rpc_socket_impl.cc
* \brief Socket based RPC implementation.
*/
#include <tvm/runtime/registry.h>
#include <memory>
#include "./rpc_session.h"
#include "../../common/socket.h"
namespace tvm {
namespace runtime {
class SockChannel final : public RPCChannel {
public:
explicit SockChannel(common::TCPSocket sock)
: sock_(sock) {}
~SockChannel() {
if (!sock_.BadSocket()) {
sock_.Close();
}
}
size_t Send(const void* data, size_t size) final {
ssize_t n = sock_.Send(data, size);
if (n == -1) {
common::Socket::Error("SockChannel::Send");
}
return static_cast<size_t>(n);
}
size_t Recv(void* data, size_t size) final {
ssize_t n = sock_.Recv(data, size);
if (n == -1) {
common::Socket::Error("SockChannel::Recv");
}
return static_cast<size_t>(n);
}
private:
common::TCPSocket sock_;
};
Module RPCConnect(std::string url, int port, std::string key) {
common::TCPSocket sock;
common::SockAddr addr(url.c_str(), port);
sock.Create();
CHECK(sock.Connect(addr))
<< "Connect to " << addr.AsString() << " failed";
// hand shake
std::ostringstream os;
os << "client:" << key;
key = os.str();
int code = kRPCMagic;
int keylen = static_cast<int>(key.length());
CHECK_EQ(sock.SendAll(&code, sizeof(code)), sizeof(code));
CHECK_EQ(sock.SendAll(&keylen, sizeof(keylen)), sizeof(keylen));
if (keylen != 0) {
CHECK_EQ(sock.SendAll(key.c_str(), keylen), keylen);
}
CHECK_EQ(sock.RecvAll(&code, sizeof(code)), sizeof(code));
if (code == kRPCMagic + 2) {
sock.Close();
LOG(FATAL) << "URL " << url << ":" << port
<< " cannot find server that matches key=" << key;
} else if (code == kRPCMagic + 1) {
sock.Close();
LOG(FATAL) << "URL " << url << ":" << port
<< " server already have client key=" << key;
} else if (code != kRPCMagic) {
sock.Close();
LOG(FATAL) << "URL " << url << ":" << port << " is not TVM RPC server";
}
return CreateRPCModule(
RPCSession::Create(
std::unique_ptr<SockChannel>(new SockChannel(sock)),
"SockClient"));
}
void RPCServerLoop(int sockfd) {
common::TCPSocket sock(
static_cast<common::TCPSocket::SockType>(sockfd));
RPCSession::Create(
std::unique_ptr<SockChannel>(new SockChannel(sock)),
"SockServerLoop")->ServerLoop();
}
TVM_REGISTER_GLOBAL("contrib.rpc._Connect")
.set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = RPCConnect(args[0], args[1], args[2]);
});
TVM_REGISTER_GLOBAL("contrib.rpc._ServerLoop")
.set_body([](TVMArgs args, TVMRetValue* rv) {
RPCServerLoop(args[0]);
});
} // namespace runtime
} // namespace tvm
...@@ -12,5 +12,8 @@ RUN bash /install/ubuntu_install_python.sh ...@@ -12,5 +12,8 @@ RUN bash /install/ubuntu_install_python.sh
COPY install/ubuntu_install_iverilog.sh /install/ubuntu_install_iverilog.sh COPY install/ubuntu_install_iverilog.sh /install/ubuntu_install_iverilog.sh
RUN bash /install/ubuntu_install_iverilog.sh RUN bash /install/ubuntu_install_iverilog.sh
COPY install/ubuntu_install_python_package.sh /install/ubuntu_install_python_package.sh
RUN bash /install/ubuntu_install_python_package.sh
COPY install/ubuntu_install_java.sh /install/ubuntu_install_java.sh COPY install/ubuntu_install_java.sh /install/ubuntu_install_java.sh
RUN bash /install/ubuntu_install_java.sh RUN bash /install/ubuntu_install_java.sh
...@@ -12,4 +12,7 @@ RUN bash /install/ubuntu_install_python.sh ...@@ -12,4 +12,7 @@ RUN bash /install/ubuntu_install_python.sh
COPY install/ubuntu_install_emscripten.sh /install/ubuntu_install_emscripten.sh COPY install/ubuntu_install_emscripten.sh /install/ubuntu_install_emscripten.sh
RUN bash /install/ubuntu_install_emscripten.sh RUN bash /install/ubuntu_install_emscripten.sh
COPY install/ubuntu_install_python_package.sh /install/ubuntu_install_python_package.sh
RUN bash /install/ubuntu_install_python_package.sh
RUN cp /root/.emscripten /emsdk-portable/ RUN cp /root/.emscripten /emsdk-portable/
\ No newline at end of file
...@@ -18,6 +18,9 @@ RUN bash /install/ubuntu_install_opencl.sh ...@@ -18,6 +18,9 @@ RUN bash /install/ubuntu_install_opencl.sh
COPY install/ubuntu_install_iverilog.sh /install/ubuntu_install_iverilog.sh COPY install/ubuntu_install_iverilog.sh /install/ubuntu_install_iverilog.sh
RUN bash /install/ubuntu_install_iverilog.sh RUN bash /install/ubuntu_install_iverilog.sh
COPY install/ubuntu_install_python_package.sh /install/ubuntu_install_python_package.sh
RUN bash /install/ubuntu_install_python_package.sh
COPY install/ubuntu_install_sphinx.sh /install/ubuntu_install_sphinx.sh COPY install/ubuntu_install_sphinx.sh /install/ubuntu_install_sphinx.sh
RUN bash /install/ubuntu_install_sphinx.sh RUN bash /install/ubuntu_install_sphinx.sh
......
...@@ -5,8 +5,11 @@ RUN apt-get update --fix-missing ...@@ -5,8 +5,11 @@ RUN apt-get update --fix-missing
COPY install/ubuntu_install_core.sh /install/ubuntu_install_core.sh COPY install/ubuntu_install_core.sh /install/ubuntu_install_core.sh
RUN bash /install/ubuntu_install_core.sh RUN bash /install/ubuntu_install_core.sh
COPY install/ubuntu_install_llvm.sh /install/ubuntu_install_llvm.sh
RUN bash /install/ubuntu_install_llvm.sh
COPY install/ubuntu_install_python.sh /install/ubuntu_install_python.sh COPY install/ubuntu_install_python.sh /install/ubuntu_install_python.sh
RUN bash /install/ubuntu_install_python.sh RUN bash /install/ubuntu_install_python.sh
COPY install/ubuntu_install_llvm.sh /install/ubuntu_install_llvm.sh COPY install/ubuntu_install_python_package.sh /install/ubuntu_install_python_package.sh
RUN bash /install/ubuntu_install_llvm.sh RUN bash /install/ubuntu_install_python_package.sh
apt-get update && apt-get install -y curl apt-get update && apt-get install -y curl
curl -sL https://deb.nodesource.com/setup_6.x | bash - curl -sL https://deb.nodesource.com/setup_6.x | bash -
apt-get update && apt-get install -y nodejs apt-get update && apt-get install -y nodejs
npm install eslint jsdoc npm install eslint jsdoc ws
# install libraries for python package on ubuntu # install python and pip, don't modify this, modify install_python_package.sh
apt-get update && apt-get install -y python-pip python-dev python3-dev apt-get update && apt-get install -y python-pip python-dev python3-dev
# the version of the pip shipped with ubuntu may be too lower, install a recent version here # the version of the pip shipped with ubuntu may be too lower, install a recent version here
cd /tmp && wget https://bootstrap.pypa.io/get-pip.py && python3 get-pip.py && python2 get-pip.py cd /tmp && wget https://bootstrap.pypa.io/get-pip.py && python3 get-pip.py && python2 get-pip.py
pip2 install nose pylint numpy nose-timer cython decorator scipy
pip3 install nose pylint numpy nose-timer cython decorator scipy
# install libraries for python package on ubuntu
pip2 install nose pylint numpy nose-timer cython decorator scipy tornado
pip3 install nose pylint numpy nose-timer cython decorator scipy tornado
import tvm
import logging
import numpy as np
import time
import multiprocessing
from tvm.contrib import rpc
def rpc_proxy_test():
"""This is a simple test function for RPC Proxy
It is not included as nosetests, because:
- It depends on tornado
- It relies on the fact that Proxy starts before client and server connects,
which is often the case but not always
User can directly run this script to verify correctness.
"""
try:
from tvm.contrib import rpc_proxy
web_port = 8888
prox = rpc_proxy.Proxy("localhost", web_port=web_port)
def check():
if not tvm.module.enabled("rpc"):
return
@tvm.register_func("rpc.test2.addone")
def addone(x):
return x + 1
@tvm.register_func("rpc.test2.strcat")
def addone(name, x):
return "%s:%d" % (name, x)
server = multiprocessing.Process(
target=rpc_proxy.websocket_proxy_server,
args=("ws://localhost:%d/ws" % web_port,"x1"))
# Need to make sure that the connection start after proxy comes up
time.sleep(0.1)
server.deamon = True
server.start()
client = rpc.connect(prox.host, prox.port, key="x1")
f1 = client.get_function("rpc.test2.addone")
assert f1(10) == 11
f2 = client.get_function("rpc.test2.strcat")
assert f2("abc", 11) == "abc:11"
check()
except ImportError:
print("Skipping because tornado is not avaliable...")
if __name__ == "__main__":
logging.basicConfig(level=logging.INFO)
rpc_proxy_test()
...@@ -14,10 +14,21 @@ def test_rpc_simple(): ...@@ -14,10 +14,21 @@ def test_rpc_simple():
def addone(name, x): def addone(name, x):
return "%s:%d" % (name, x) return "%s:%d" % (name, x)
@tvm.register_func("rpc.test.except")
def remotethrow(name):
raise ValueError("%s" % name)
server = rpc.Server("localhost") server = rpc.Server("localhost")
client = rpc.connect(server.host, server.port) client = rpc.connect(server.host, server.port, key="x1")
f1 = client.get_function("rpc.test.addone") f1 = client.get_function("rpc.test.addone")
assert f1(10) == 11 assert f1(10) == 11
f3 = client.get_function("rpc.test.except")
try:
f3("abc")
assert False
except tvm.TVMError as e:
assert "abc" in str(e)
f2 = client.get_function("rpc.test.strcat") f2 = client.get_function("rpc.test.strcat")
assert f2("abc", 11) == "abc:11" assert f2("abc", 11) == "abc:11"
...@@ -42,9 +53,10 @@ def test_rpc_file_exchange(): ...@@ -42,9 +53,10 @@ def test_rpc_file_exchange():
return return
server = rpc.Server("localhost") server = rpc.Server("localhost")
remote = rpc.connect(server.host, server.port) remote = rpc.connect(server.host, server.port)
blob = bytearray(np.random.randint(0, 10, size=(127))) blob = bytearray(np.random.randint(0, 10, size=(10)))
remote.upload(blob, "dat.bin") remote.upload(blob, "dat.bin")
rev = remote.download("dat.bin") rev = remote.download("dat.bin")
assert(rev == blob)
def test_rpc_remote_module(): def test_rpc_remote_module():
if not tvm.module.enabled("rpc"): if not tvm.module.enabled("rpc"):
...@@ -79,7 +91,8 @@ def test_rpc_remote_module(): ...@@ -79,7 +91,8 @@ def test_rpc_remote_module():
if __name__ == "__main__": if __name__ == "__main__":
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
test_rpc_file_exchange()
exit(0)
test_rpc_array() test_rpc_array()
test_rpc_remote_module() test_rpc_remote_module()
test_rpc_file_exchange()
test_rpc_simple() test_rpc_simple()
"""Simple testcode to test Javascript RPC
To use it, start a rpc proxy with "python -m tvm.exec.rpc_proxy".
Connect javascript end to the websocket port and connect to the RPC.
"""
import tvm
import os
from tvm.contrib import rpc, util, emscripten
import numpy as np
proxy_host = "localhost"
proxy_port = 9090
def test_rpc_array():
if not tvm.module.enabled("rpc"):
return
# graph
n = tvm.convert(1024)
A = tvm.placeholder((n,), name='A')
B = tvm.compute(A.shape, lambda *i: A(*i) + 1.0, name='B')
s = tvm.create_schedule(B.op)
remote = rpc.connect(proxy_host, proxy_port, key="js")
target = "llvm -target=asmjs-unknown-emscripten -system-lib"
def check_remote():
if not tvm.module.enabled(target):
print("Skip because %s is not enabled" % target)
return
temp = util.tempdir()
ctx = remote.cpu(0)
f = tvm.build(s, [A, B], target, name="myadd")
path_obj = temp.relpath("dev_lib.bc")
path_dso = temp.relpath("dev_lib.js")
f.save(path_obj)
emscripten.create_js(path_dso, path_obj, side_module=True)
# Upload to suffix as dso so it can be loaded remotely
remote.upload(path_dso, "dev_lib.dso")
data = remote.download("dev_lib.dso")
f1 = remote.load_module("dev_lib.dso")
a = tvm.nd.array(np.random.uniform(size=1024).astype(A.dtype), ctx)
b = tvm.nd.array(np.zeros(1024, dtype=A.dtype), ctx)
time_f = f1.time_evaluator(f1.entry_name, remote.cpu(0), number=10)
cost = time_f(a, b)
print('%g secs/op' % cost)
np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1)
check_remote()
test_rpc_array()
...@@ -139,9 +139,18 @@ sysLib.release(); ...@@ -139,9 +139,18 @@ sysLib.release();
``` ```
## Notes Current example supports static linking, which is the preferred way to get more efficiency
- Current example supports static linking, which is the preferred way to get more efficiency
in javascript backend. in javascript backend.
- It should also be possible to use Emscripten's dynamic linking to dynamically load modules.
- Take a look at tvm_runtime.js which contains quite a few helper functions ## Proxy based RPC
to interact with TVM from javascript.
\ No newline at end of file We can now use javascript end to start an RPC server and connect to it from python side,
making the testing flow easier.
The following is an example to reproduce this. This requires everything to be in the git source and setup PYTHONPATH(instead of use setup.py install)
- run "python -m tvm.exec.rpc_proxy --example-rpc=1" to start proxy.
- Open broswer, goto the server webpage click Connect to proxy.
- Alternatively run "node web/example_rpc_node.js"
- run "python tests/web/websock_rpc_test.py" to run the rpc client.
The general idea is to use Emscripten's dynamic linking to dynamically load modules.
<html>
<head><title> TVM RPC Test Page </title></head>
<script src="libtvm_web_runtime.js"></script>
<script src="tvm_runtime.js"></script>
<script>
tvm = tvm_runtime.create(Module);
tvm.logger = function(message) {
console.log(message);
var d = document.createElement("div");
d.innerHTML = message;
document.getElementById("log").appendChild(d);
};
function clear_log() {
var node = document.getElementById("log");
while (node.hasChildNodes()) {
node.removeChild(node.lastChild);
}
}
function connect_rpc() {
var proxyurl = document.getElementById("proxyURL").value;
var key = document.getElementById("proxyKey").value;
tvm.startRPCServer(proxyurl, key, 100);
}
</script>
<body>
<h1>TVM Test Page</h1>
To use this page, the easiest way is to do
<ul>
<li> run "python -m tvm.exec.rpc_proxy --example-rpc=1" to start proxy.
<li> Click Connect to proxy.
<li> run "python tests/web/websock_rpc_test.py" to run the rpc client.
</ul>
<h2>Options</h2>
Proxy URL<input name="proxyurl" id="proxyURL" type="text" value="ws://localhost:9888/ws"><br>
RPC Server Key<input name="serverkey" id="proxyKey" type="text" value="js"><br>
<button onclick="connect_rpc()">Connect To Proxy</button>
<button onclick="clear_log()">Clear Log</button>
<div id="log"></div>
</body>
</html>
// Javascript RPC server example
// Start and connect to websocket proxy.
// Load Emscripten Module, need to change path to root/lib
const path = require("path");
process.chdir(path.join(__dirname, "../lib"));
var Module = require("../lib/libtvm_web_runtime.js");
// Bootstrap TVMruntime with emscripten module.
const tvm_runtime = require("../web/tvm_runtime.js");
const tvm = tvm_runtime.create(Module);
var websock_proxy = "ws://localhost:9888/ws";
var num_sess = 100;
tvm.startRPCServer(websock_proxy, "js", num_sess)
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
/* eslint no-unused-vars: "off" */ /* eslint no-unused-vars: "off" */
/* eslint no-unexpected-multiline: "off" */ /* eslint no-unexpected-multiline: "off" */
/* eslint indent: "off" */ /* eslint indent: "off" */
/* eslint no-console: "off" */
/** /**
* TVM Runtime namespace. * TVM Runtime namespace.
* Provide tvm_runtime.create to create a {@link tvm.TVMRuntime}. * Provide tvm_runtime.create to create a {@link tvm.TVMRuntime}.
...@@ -31,8 +32,13 @@ var tvm_runtime = tvm_runtime || {}; ...@@ -31,8 +32,13 @@ var tvm_runtime = tvm_runtime || {};
* @memberof tvm * @memberof tvm
*/ */
function TVMRuntime() { function TVMRuntime() {
"use strict";
var runtime_ref = this;
// Utility function to throw error // Utility function to throw error
function throwError(message) { function throwError(message) {
if (typeof runtime_ref.logger !== "undefined") {
runtime_ref.logger(message);
}
if (typeof Error !== "undefined") { if (typeof Error !== "undefined") {
throw new Error(message); throw new Error(message);
} }
...@@ -232,6 +238,21 @@ var tvm_runtime = tvm_runtime || {}; ...@@ -232,6 +238,21 @@ var tvm_runtime = tvm_runtime || {};
throwError(message); throwError(message);
} }
}; };
/**
* Logging function.
* Override this to change logger behavior.
*
* @param {string} message
*/
this.logger = function(message) {
console.log(message);
};
function logging(message) {
runtime_ref.logger(message);
}
// Override print error to logging
Module.printErr = logging;
var CHECK = this.assert; var CHECK = this.assert;
function TVM_CALL(ret) { function TVM_CALL(ret) {
...@@ -746,11 +767,12 @@ var tvm_runtime = tvm_runtime || {}; ...@@ -746,11 +767,12 @@ var tvm_runtime = tvm_runtime || {};
out_array.release(); out_array.release();
return names; return names;
}; };
var listGlobalFuncNames = this.listGlobalFuncNames;
/** /**
* Get a global function from TVM runtime. * Get a global function from TVM runtime.
* *
* @param {string} The name of the function. * @param {string} The name of the function.
* @return {Function} The corresponding function. * @return {Function} The corresponding function, null if function do not exist
*/ */
this.getGlobalFunc = function (name) { this.getGlobalFunc = function (name) {
// alloc // alloc
...@@ -759,7 +781,11 @@ var tvm_runtime = tvm_runtime || {}; ...@@ -759,7 +781,11 @@ var tvm_runtime = tvm_runtime || {};
var out_handle = out.asHandle(); var out_handle = out.asHandle();
// release // release
out.release(); out.release();
return makeTVMFunction(out_handle); if (out_handle != 0) {
return makeTVMFunction(out_handle);
} else {
return null;
}
}; };
var getGlobalFunc = this.getGlobalFunc; var getGlobalFunc = this.getGlobalFunc;
/** /**
...@@ -788,14 +814,123 @@ var tvm_runtime = tvm_runtime || {}; ...@@ -788,14 +814,123 @@ var tvm_runtime = tvm_runtime || {};
//----------------------------------------- //-----------------------------------------
// Wrap of TVM Functions. // Wrap of TVM Functions.
// ---------------------------------------- // ----------------------------------------
var fGetSystemLib = getGlobalFunc("module._GetSystemLib"); var systemFunc = {};
/** /**
* Get system-wide library module singleton. * Get system-wide library module singleton.5A
* System lib is a global module that contains self register functions in startup. * System lib is a global module that contains self register functions in startup.
* @return {tvm.TVMModule} The system module singleton. * @return {tvm.TVMModule} The system module singleton.
*/ */
this.systemLib = function() { this.systemLib = function() {
return fGetSystemLib(); if (typeof systemFunc.fGetSystemLib === "undefined") {
systemFunc.fGetSystemLib = getGlobalFunc("module._GetSystemLib");
}
return systemFunc.fGetSystemLib();
};
this.startRPCServer = function(url, key, counter) {
if (typeof key === "undefined") {
key = "";
}
if (typeof counter === "undefined") {
counter = 1;
}
// Node js, import websocket
var bkey = StringToUint8Array("server:" + key);
var server_name = "WebSocketRPCServer[" + key + "]";
var RPC_MAGIC = 0xff271;
function checkEndian() {
var a = new ArrayBuffer(4);
var b = new Uint8Array(a);
var c = new Uint32Array(a);
b[0] = 0x11;
b[1] = 0x22;
b[2] = 0x33;
b[3] = 0x44;
CHECK(c[0] === 0x44332211, "Need little endian to work");
}
checkEndian();
// start rpc
function RPCServer(counter) {
var socket;
if (typeof module !== "undefined" && module.exports) {
// WebSocket for nodejs
const WebSocket = require("ws");
socket = new WebSocket(url);
} else {
socket = new WebSocket(url);
}
var self = this;
socket.binaryType = "arraybuffer";
this.init = true;
this.counter = counter;
if (typeof systemFunc.fcreateServer === "undefined") {
systemFunc.fcreateServer =
getGlobalFunc("contrib.rpc._CreateEventDrivenServer");
}
if (systemFunc.fcreateServer == null) {
throwError("RPCServer is not included in runtime");
}
var message_handler = systemFunc.fcreateServer(
function(cbytes) {
if (socket.readyState == 1) {
socket.send(cbytes);
return new TVMConstant(cbytes.length, "int32");
} else {
return new TVMConstant(0, "int32");
}
} , server_name);
function on_open(event) {
var intbuf = new Int32Array(1);
intbuf[0] = RPC_MAGIC;
socket.send(intbuf);
intbuf[0] = bkey.length;
socket.send(intbuf);
socket.send(bkey);
logging(server_name + " connected...");
}
function on_message(event) {
if (self.init) {
var msg = new Uint8Array(event.data);
CHECK(msg.length >= 4, "Need message header to be bigger than 4");
var magic = new Int32Array(event.data)[0];
if (magic == RPC_MAGIC + 1) {
throwError("key: " + key + " has already been used in proxy");
} else if (magic == RPC_MAGIC + 2) {
logging(server_name + ": RPCProxy do not have matching client key " + key);
} else {
CHECK(magic == RPC_MAGIC, url + "is not RPC Proxy");
self.init = false;
}
logging(server_name + "init end...");
if (msg.length > 4) {
if (!message_handler(new Uint8Array(event.data, 4, msg.length -4))) {
socket.close();
}
}
} else {
if (!message_handler(new Uint8Array(event.data))) {
socket.close();
}
}
}
function on_close(event) {
message_handler.release();
logging(server_name + ": closed finish...");
if (!self.init && self.counter != 0) {
logging(server_name + ":reconnect to serve another request, session left=" + counter);
// start a new server.
new RPCServer(counter - 1);
}
}
socket.addEventListener("open", on_open);
socket.addEventListener("message", on_message);
socket.addEventListener("close", on_close);
}
return new RPCServer(counter);
}; };
//----------------------------------------- //-----------------------------------------
// Class defintions // Class defintions
...@@ -943,16 +1078,12 @@ var tvm_runtime = tvm_runtime || {}; ...@@ -943,16 +1078,12 @@ var tvm_runtime = tvm_runtime || {};
* @property {string} create * @property {string} create
* @memberof tvm_runtime * @memberof tvm_runtime
* @param Module The emscripten module. * @param Module The emscripten module.
* @param Runtime The emscripten runtime, optional
* @return {tvm.TVMRuntime} The created TVM runtime. * @return {tvm.TVMRuntime} The created TVM runtime.
*/ */
this.create = function(Module, Runtime) { this.create = function(Module) {
var tvm = {}; var tvm = {};
tvm.Module = Module; tvm.Module = Module;
if (typeof Runtime == "undefined") { tvm.Runtime = Module.Runtime;
Runtime = Module.Runtime;
}
tvm.Runtime = Runtime;
TVMRuntime.apply(tvm); TVMRuntime.apply(tvm);
return tvm; return tvm;
}; };
......
...@@ -2,6 +2,9 @@ ...@@ -2,6 +2,9 @@
* Copyright (c) 2017 by Contributors * Copyright (c) 2017 by Contributors
* \file web_runtime.cc * \file web_runtime.cc
*/ */
#include <sys/stat.h>
#include <fstream>
#include "../src/runtime/c_runtime_api.cc" #include "../src/runtime/c_runtime_api.cc"
#include "../src/runtime/cpu_device_api.cc" #include "../src/runtime/cpu_device_api.cc"
#include "../src/runtime/workspace_pool.cc" #include "../src/runtime/workspace_pool.cc"
...@@ -10,3 +13,33 @@ ...@@ -10,3 +13,33 @@
#include "../src/runtime/module.cc" #include "../src/runtime/module.cc"
#include "../src/runtime/registry.cc" #include "../src/runtime/registry.cc"
#include "../src/runtime/file_util.cc" #include "../src/runtime/file_util.cc"
#include "../src/runtime/dso_module.cc"
#include "../src/runtime/rpc/rpc_session.cc"
#include "../src/runtime/rpc/rpc_event_impl.cc"
#include "../src/runtime/rpc/rpc_server_env.cc"
namespace tvm {
namespace contrib {
struct RPCEnv {
public:
RPCEnv() {
base_ = "/rpc";
mkdir(&base_[0], 0777);
}
// Get Path.
std::string GetPath(const std::string& file_name) {
return base_ + "/" + file_name;
}
private:
std::string base_;
};
TVM_REGISTER_GLOBAL("tvm.contrib.rpc.server.workpath")
.set_body([](TVMArgs args, TVMRetValue* rv) {
static RPCEnv env;
*rv = env.GetPath(args[0]);
});
} // namespace contrib
} // namespace tvm
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment