Commit c324494f by Tianqi Chen Committed by GitHub

[RUNTIME][RPC] Change RPCServer to Event Driven Code (#243)

* [RUNTIME][RPC] Change RPCServer to Event Driven Code

* fix
parent f4d1dddb
......@@ -58,7 +58,7 @@ CFLAGS = -std=c++11 -Wall -O2 $(INCLUDE_FLAGS) -fPIC
LLVM_CFLAGS= -fno-rtti -DDMLC_ENABLE_RTTI=0
FRAMEWORKS =
OBJCFLAGS = -fno-objc-arc
EMCC_FLAGS= -s RESERVED_FUNCTION_POINTERS=2 -s NO_EXIT_RUNTIME=1 -DDMLC_LOG_STACK_TRACE=0\
EMCC_FLAGS= -s RESERVED_FUNCTION_POINTERS=2 -s NO_EXIT_RUNTIME=1 -s MAIN_MODULE=1 -DDMLC_LOG_STACK_TRACE=0\
-std=c++11 -Oz $(INCLUDE_FLAGS)
# Dependency specific rules
......
......@@ -212,13 +212,6 @@ class TVMPODValue_ {
int type_code() const {
return type_code_;
}
protected:
friend class TVMArgsSetter;
friend class TVMRetValue;
TVMPODValue_() : type_code_(kNull) {}
TVMPODValue_(TVMValue value, int type_code)
: value_(value), type_code_(type_code) {}
/*!
* \brief return handle as specific pointer type.
* \tparam T the data type.
......@@ -228,6 +221,14 @@ class TVMPODValue_ {
T* ptr() const {
return static_cast<T*>(value_.v_handle);
}
protected:
friend class TVMArgsSetter;
friend class TVMRetValue;
TVMPODValue_() : type_code_(kNull) {}
TVMPODValue_(TVMValue value, int type_code)
: value_(value), type_code_(type_code) {}
/*! \brief The value */
TVMValue value_;
/*! \brief the type code */
......@@ -347,6 +348,8 @@ class TVMRetValue : public TVMPODValue_ {
operator std::string() const {
if (type_code_ == kTVMType) {
return TVMType2String(operator TVMType());
} else if (type_code_ == kBytes) {
return *ptr<std::string>();
}
TVM_CHECK_TYPE_CODE(type_code_, kStr);
return *ptr<std::string>();
......@@ -414,6 +417,10 @@ class TVMRetValue : public TVMPODValue_ {
this->SwitchToClass(kStr, value);
return *this;
}
TVMRetValue& operator=(TVMByteArray value) {
this->SwitchToClass(kBytes, std::string(value.data, value.size));
return *this;
}
TVMRetValue& operator=(PackedFunc f) {
this->SwitchToClass(kFuncHandle, f);
return *this;
......@@ -442,7 +449,7 @@ class TVMRetValue : public TVMPODValue_ {
void MoveToCHost(TVMValue* ret_value,
int* ret_type_code) {
// cannot move str; need specially handle.
CHECK(type_code_ != kStr);
CHECK(type_code_ != kStr && type_code_ != kBytes);
*ret_value = value_;
*ret_type_code = type_code_;
type_code_ = kNull;
......@@ -470,11 +477,14 @@ class TVMRetValue : public TVMPODValue_ {
template<typename T>
void Assign(const T& other) {
switch (other.type_code()) {
case kStr:
case kBytes: {
case kStr: {
SwitchToClass<std::string>(kStr, other);
break;
}
case kBytes: {
SwitchToClass<std::string>(kBytes, other);
break;
}
case kFuncHandle: {
SwitchToClass<PackedFunc>(kFuncHandle, other);
break;
......@@ -702,6 +712,7 @@ class TVMArgsSetter {
values_[i].v_str = value.ptr<std::string>()->c_str();
type_codes_[i] = kStr;
} else {
CHECK_NE(value.type_code(), kBytes) << "not handled.";
values_[i] = value.value_;
type_codes_[i] = value.type_code();
}
......
......@@ -7,6 +7,7 @@ from .._ffi.libinfo import find_lib_path
def create_js(output,
objects,
options=None,
side_module=False,
cc="emcc"):
"""Create emscripten javascript library.
......@@ -29,6 +30,8 @@ def create_js(output,
cmd += ["-s", "NO_EXIT_RUNTIME=1"]
cmd += ["-Oz"]
cmd += ["-o", output]
if side_module:
cmd += ["-s", "SIDE_MODULE=1"]
objects = [objects] if isinstance(objects, str) else objects
with_runtime = False
......@@ -36,7 +39,7 @@ def create_js(output,
if obj.find("libtvm_web_runtime.bc") != -1:
with_runtime = True
if not with_runtime:
if not with_runtime and not side_module:
objects += [find_lib_path("libtvm_web_runtime.bc")[0]]
cmd += objects
......
......@@ -19,32 +19,20 @@ from . import util, cc_compiler
from ..module import load as _load_module
from .._ffi.function import _init_api, register_func
from .._ffi.ndarray import context as _context
from .._ffi.base import py_str
RPC_MAGIC = 0xff271
RPC_SESS_MASK = 128
def _serve_loop(sock, addr):
"""Server loop"""
sockfd = sock.fileno()
def _server_env():
"""Server environment function return temp dir"""
temp = util.tempdir()
# pylint: disable=unused-variable
@register_func("tvm.contrib.rpc.server.upload")
def upload(file_name, blob):
"""Upload the blob to remote temp file"""
path = temp.relpath(file_name)
with open(path, "wb") as out_file:
out_file.write(blob)
logging.info("upload %s", path)
@register_func("tvm.contrib.rpc.server.workpath")
def get_workpath(path):
return temp.relpath(path)
@register_func("tvm.contrib.rpc.server.download")
def download(file_name):
"""Download file from remote"""
path = temp.relpath(file_name)
dat = bytearray(open(path, "rb").read())
logging.info("download %s", path)
return dat
@register_func("tvm.contrib.rpc.server.load_module")
@register_func("tvm.contrib.rpc.server.load_module", override=True)
def load_module(file_name):
"""Load module from remote side."""
path = temp.relpath(file_name)
......@@ -53,11 +41,16 @@ def _serve_loop(sock, addr):
logging.info('Create shared library based on %s', path)
cc_compiler.create_shared(path + '.so', path)
path += '.so'
m = _load_module(path)
logging.info("load_module %s", path)
return m
return temp
def _serve_loop(sock, addr):
"""Server loop"""
sockfd = sock.fileno()
temp = _server_env()
_ServerLoop(sockfd)
temp.remove()
logging.info("Finish serving %s", addr)
......@@ -78,11 +71,16 @@ def _listen_loop(sock):
while True:
conn, addr = sock.accept()
logging.info("RPCServer: connection from %s", addr)
conn.sendall(struct.pack('@i', RPC_MAGIC))
magic = struct.unpack('@i', _recvall(conn, 4))[0]
if magic != RPC_MAGIC:
conn.close()
continue
keylen = struct.unpack('@i', _recvall(conn, 4))[0]
key = py_str(_recvall(conn, keylen))
if not key.startswith("client:"):
conn.sendall(struct.pack('@i', RPC_MAGIC + 2))
else:
conn.sendall(struct.pack('@i', RPC_MAGIC))
logging.info("Connection from %s", addr)
process = multiprocessing.Process(target=_serve_loop, args=(conn, addr))
process.deamon = True
......@@ -91,6 +89,27 @@ def _listen_loop(sock):
conn.close()
def _connect_proxy_loop(addr, key):
while True:
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.connect(addr)
sock.sendall(struct.pack('@i', RPC_MAGIC))
sock.sendall(struct.pack('@i', len(key)))
sock.sendall(key)
magic = struct.unpack('@i', _recvall(sock, 4))[0]
if magic == RPC_MAGIC + 1:
raise RuntimeError("key: %s has already been used in proxy" % key)
elif magic == RPC_MAGIC + 2:
logging.info("RPCProxy do not have matching client key %s", key)
elif magic != RPC_MAGIC:
raise RuntimeError("%s is not RPC Proxy" % str(addr))
logging.info("RPCProxy connected to %s", str(addr))
process = multiprocessing.Process(target=_serve_loop, args=(sock, addr))
process.deamon = True
process.start()
process.join()
class Server(object):
"""Start RPC server on a seperate process.
......@@ -108,8 +127,20 @@ class Server(object):
port_end : int, optional
The end port to search
is_proxy : bool, optional
Whether the address specified is a proxy.
If this is true, the host and port actually corresponds to the
address of the proxy server.
key : str, optional
The key used to identify the server in Proxy connection.
"""
def __init__(self, host, port=9091, port_end=9199):
def __init__(self, host, port=9091, port_end=9199, is_proxy=False, key=""):
self.host = host
self.port = port
if not is_proxy:
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.port = None
for my_port in range(port, port_end):
......@@ -127,8 +158,12 @@ class Server(object):
logging.info("RPCServer: bind to %s:%d", host, self.port)
sock.listen(1)
self.sock = sock
self.host = host
self.proc = multiprocessing.Process(target=_listen_loop, args=(self.sock,))
self.proc = multiprocessing.Process(
target=_listen_loop, args=(self.sock,))
else:
self.proc = multiprocessing.Process(
target=_connect_proxy_loop, args=((host, port), key))
self.proc.deamon = True
self.proc.start()
def terminate(self):
......@@ -141,6 +176,7 @@ class Server(object):
self.terminate()
class RPCSession(object):
"""RPC Client session module
......@@ -262,7 +298,7 @@ class RPCSession(object):
return _LoadRemoteModule(self._sess, path)
def connect(url, port):
def connect(url, port, key=""):
"""Connect to RPC Server
Parameters
......@@ -273,13 +309,16 @@ def connect(url, port):
port : int
The port to connect to
key : str, optional
Additional key to match server
Returns
-------
sess : RPCSession
The connected session.
"""
try:
sess = _Connect(url, port)
sess = _Connect(url, port, key)
except NameError:
raise RuntimeError('Please compile with USE_RPC=1')
return RPCSession(sess)
......
"""RPC web proxy, allows redirect to websocket based RPC servers(browsers)"""
from __future__ import absolute_import
import logging
import argparse
import os
from ..contrib.rpc_proxy import Proxy
def find_example_resource():
"""Find resource examples."""
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
base_path = os.path.join(curr_path, "../../../")
index_page = os.path.join(base_path, "web/example_rpc.html")
js_files = [
os.path.join(base_path, "web/tvm_runtime.js"),
os.path.join(base_path, "lib/libtvm_web_runtime.js"),
os.path.join(base_path, "lib/libtvm_web_runtime.js.mem")
]
for fname in [index_page] + js_files:
if not os.path.exists(fname):
raise RuntimeError("Cannot find %s" % fname)
return index_page, js_files
def main():
"""Main funciton"""
parser = argparse.ArgumentParser()
parser.add_argument('--host', type=str, default="0.0.0.0",
help='the hostname of the server')
parser.add_argument('--port', type=int, default=9090,
help='The port of the PRC')
parser.add_argument('--web-port', type=int, default=9888,
help='The port of the http/websocket server')
parser.add_argument('--example-rpc', type=bool, default=False,
help='Whether to switch on example rpc mode')
args = parser.parse_args()
logging.basicConfig(level=logging.INFO)
if args.example_rpc:
index, js_files = find_example_resource()
prox = Proxy(args.host, port=args.port,
web_port=args.web_port, index_page=index,
resource_files=js_files)
else:
prox = Proxy(args.host, port=args.port, web_port=args.web_port)
prox.proc.join()
if __name__ == "__main__":
main()
/*!
* Copyright (c) 2017 by Contributors
* \file ring_buffer.h
* \brief this file aims to provide a wrapper of sockets
*/
#ifndef TVM_COMMON_RING_BUFFER_H_
#define TVM_COMMON_RING_BUFFER_H_
#include <vector>
#include <cstring>
#include <algorithm>
namespace tvm {
namespace common {
/*!
* \brief Ring buffer class for data buffering in IO.
* Enables easy usage for sync and async mode.
*/
class RingBuffer {
public:
/*! \brief Initial capacity of ring buffer. */
static const int kInitCapacity = 4 << 10;
/*! \brief constructor */
RingBuffer() : ring_(kInitCapacity) {}
/*! \return number of bytes available in buffer. */
size_t bytes_available() const {
return bytes_available_;
}
/*! \return Current capacity of buffer. */
size_t capacity() const {
return ring_.size();
}
/*!
* Reserve capacity to be at least n.
* Will only increase capacity if n is bigger than current capacity.
* \param n The size of capacity.
*/
void Reserve(size_t n) {
if (ring_.size() >= n) return;
size_t old_size = ring_.size();
size_t new_size = ring_.size();
while (new_size < n) {
new_size *= 2;
}
ring_.resize(new_size);
if (head_ptr_ + bytes_available_ > old_size) {
// copy the ring overflow part into the tail.
size_t ncopy = head_ptr_ + bytes_available_ - old_size;
memcpy(&ring_[0] + old_size, &ring_[0], ncopy);
}
}
/*!
* \brief Peform a non-blocking read from buffer
* size must be smaller than this->bytes_available()
* \param data the data pointer.
* \param size The number of bytes to read.
*/
void Read(void* data, size_t size) {
CHECK_GE(bytes_available_, size);
size_t ncopy = std::min(size, ring_.size() - head_ptr_);
memcpy(data, &ring_[0] + head_ptr_, ncopy);
if (ncopy < size) {
memcpy(reinterpret_cast<char*>(data) + ncopy,
&ring_[0], size - ncopy);
}
head_ptr_ = (head_ptr_ + size) % ring_.size();
bytes_available_ -= size;
}
/*!
* \brief Read data from buffer with and put them to non-blocking send function.
*
* \param frecv A send function handle to put the data to.
* \param max_nbytes Maximum number of bytes can to read.
* \tparam FSend A non-blocking function with signature size_t (const void* data, size_t size);
*/
template<typename FSend>
size_t ReadWithCallback(FSend fsend, size_t max_nbytes) {
size_t size = std::min(max_nbytes, bytes_available_);
CHECK_NE(size, 0U);
size_t ncopy = std::min(size, ring_.size() - head_ptr_);
size_t nsend = fsend(&ring_[0] + head_ptr_, ncopy);
bytes_available_ -= nsend;
if (ncopy == nsend && ncopy < size) {
size_t nsend2 = fsend(&ring_[0], size - ncopy);
bytes_available_ -= nsend2;
nsend += nsend2;
}
return nsend;
}
/*!
* \brief Write data into buffer, always ensures all data is written.
* \param data The data pointer
* \param size The size of data to be written.
*/
void Write(const void* data, size_t size) {
this->Reserve(bytes_available_ + size);
size_t tail = head_ptr_ + bytes_available_;
if (tail >= ring_.size()) {
memcpy(&ring_[0] + (tail - ring_.size()), data, size);
} else {
size_t ncopy = std::min(ring_.size() - tail, size);
memcpy(&ring_[0] + tail, data, ncopy);
if (ncopy < size) {
memcpy(&ring_[0], reinterpret_cast<const char*>(data) + ncopy, size - ncopy);
}
}
bytes_available_ += size;
}
/*!
* \brief Writen data into the buffer by give it a non-blocking callback function.
*
* \param frecv A receive function handle
* \param max_nbytes Maximum number of bytes can write.
* \tparam FRecv A non-blocking function with signature size_t (void* data, size_t size);
*/
template<typename FRecv>
size_t WriteWithCallback(FRecv frecv, size_t max_nbytes) {
this->Reserve(bytes_available_ + max_nbytes);
size_t nbytes = max_nbytes;
size_t tail = head_ptr_ + bytes_available_;
if (tail >= ring_.size()) {
size_t nrecv = frecv(&ring_[0] + (tail - ring_.size()), nbytes);
bytes_available_ += nrecv;
return nrecv;
} else {
size_t ncopy = std::min(ring_.size() - tail, nbytes);
size_t nrecv = frecv(&ring_[0] + tail, ncopy);
bytes_available_ += nrecv;
if (nrecv == ncopy && ncopy < nbytes) {
size_t nrecv2 = frecv(&ring_[0], nbytes - ncopy);
bytes_available_ += nrecv2;
nrecv += nrecv2;
}
return nrecv;
}
}
private:
// buffer head
size_t head_ptr_{0};
// number of bytes in the buffer.
size_t bytes_available_{0};
// The internald ata ring.
std::vector<char> ring_;
};
} // namespace common
} // namespace tvm
#endif // TVM_COMMON_RING_BUFFER_H_
......@@ -157,6 +157,7 @@ using namespace tvm::runtime;
struct TVMRuntimeEntry {
std::string ret_str;
std::string last_error;
TVMByteArray ret_bytes;
// threads used in parallel for
std::vector<std::thread> par_threads;
// errors created in parallel for.
......@@ -311,11 +312,23 @@ int TVMFuncCall(TVMFunctionHandle func,
TVMArgs(args, arg_type_codes, num_args), &rv);
// handle return string.
if (rv.type_code() == kStr ||
rv.type_code() == kTVMType) {
rv.type_code() == kTVMType ||
rv.type_code() == kBytes) {
TVMRuntimeEntry* e = TVMAPIRuntimeStore::Get();
if (rv.type_code() != kTVMType) {
e->ret_str = *rv.ptr<std::string>();
} else {
e->ret_str = rv.operator std::string();
}
if (rv.type_code() == kBytes) {
e->ret_bytes.data = e->ret_str.c_str();
e->ret_bytes.size = e->ret_str.length();
*ret_type_code = kBytes;
ret_val->v_handle = &(e->ret_bytes);
} else {
*ret_type_code = kStr;
ret_val->v_str = e->ret_str.c_str();
}
} else {
rv.MoveToCHost(ret_val, ret_type_code);
}
......
/*!
* Copyright (c) 2017 by Contributors
* \file rpc_event_impl.cc
* \brief Event based RPC server implementation.
*/
#include <tvm/runtime/registry.h>
#include <memory>
#include "./rpc_session.h"
namespace tvm {
namespace runtime {
class CallbackChannel final : public RPCChannel {
public:
explicit CallbackChannel(PackedFunc fsend)
: fsend_(fsend) {}
size_t Send(const void* data, size_t size) final {
TVMByteArray bytes;
bytes.data = static_cast<const char*>(data);
bytes.size = size;
uint64_t ret = fsend_(bytes);
return static_cast<size_t>(ret);
}
size_t Recv(void* data, size_t size) final {
LOG(FATAL) << "Do not allow explicit receive for";
return 0;
}
private:
PackedFunc fsend_;
};
PackedFunc CreateEvenDrivenServer(PackedFunc fsend, std::string name) {
std::unique_ptr<CallbackChannel> ch(new CallbackChannel(fsend));
std::shared_ptr<RPCSession> sess = RPCSession::Create(std::move(ch), name);
return PackedFunc([sess](TVMArgs args, TVMRetValue* rv) {
bool ret = sess->ServerOnMessageHandler(args[0]);
*rv = ret;
});
}
TVM_REGISTER_GLOBAL("contrib.rpc._CreateEventDrivenServer")
.set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = CreateEvenDrivenServer(args[0], args[1]);
});
} // namespace runtime
} // namespace tvm
......@@ -10,8 +10,6 @@
namespace tvm {
namespace runtime {
const int kRPCMagic = 0xff271;
// Wrapped remote function to packed func.
struct RPCWrappedFunc {
public:
......@@ -98,36 +96,12 @@ class RPCModuleNode final : public ModuleNode {
std::shared_ptr<RPCSession> sess_;
};
Module RPCConnect(std::string url, int port) {
common::TCPSocket sock;
common::SockAddr addr(url.c_str(), port);
sock.Create();
CHECK(sock.Connect(addr))
<< "Connect to " << addr.AsString() << " failed";
// hand shake
int code = kRPCMagic;
CHECK_EQ(sock.SendAll(&code, sizeof(code)), sizeof(code));
CHECK_EQ(sock.RecvAll(&code, sizeof(code)), sizeof(code));
if (code != kRPCMagic) {
sock.Close();
LOG(FATAL) << "URL " << url << ":" << port << " is not TVM RPC server";
}
Module CreateRPCModule(std::shared_ptr<RPCSession> sess) {
std::shared_ptr<RPCModuleNode> n =
std::make_shared<RPCModuleNode>(nullptr, RPCSession::Create(sock));
std::make_shared<RPCModuleNode>(nullptr, sess);
return Module(n);
}
void RPCServerLoop(int sockfd) {
common::TCPSocket sock(
static_cast<common::TCPSocket::SockType>(sockfd));
RPCSession::Create(sock)->ServerLoop();
}
TVM_REGISTER_GLOBAL("contrib.rpc._Connect")
.set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = RPCConnect(args[0], args[1]);
});
TVM_REGISTER_GLOBAL("module._RPCTimeEvaluator")
.set_body([](TVMArgs args, TVMRetValue* rv) {
Module m = args[0];
......@@ -163,10 +137,5 @@ TVM_REGISTER_GLOBAL("contrib.rpc._SessTableIndex")
CHECK_EQ(tkey, "rpc");
*rv = static_cast<RPCModuleNode*>(m.operator->())->sess()->table_index();
});
TVM_REGISTER_GLOBAL("contrib.rpc._ServerLoop")
.set_body([](TVMArgs args, TVMRetValue* rv) {
RPCServerLoop(args[0]);
});
} // namespace runtime
} // namespace tvm
/*!
* Copyright (c) 2017 by Contributors
* \file rpc_server_env
* \brief Server environment of the RPC.
*/
#include <tvm/runtime/registry.h>
#include "../file_util.h"
namespace tvm {
namespace runtime {
std::string RPCGetPath(const std::string& name) {
static const PackedFunc* f =
runtime::Registry::Get("tvm.contrib.rpc.server.workpath");
CHECK(f != nullptr) << "require tvm.contrib.rpc.server.workpath";
return (*f)(name);
}
TVM_REGISTER_GLOBAL("tvm.contrib.rpc.server.upload").
set_body([](TVMArgs args, TVMRetValue *rv) {
std::string file_name = RPCGetPath(args[0]);
std::string data = args[1];
LOG(INFO) << "Upload " << file_name << "... nbytes=" << data.length();
SaveBinaryToFile(file_name, data);
});
TVM_REGISTER_GLOBAL("tvm.contrib.rpc.server.download")
.set_body([](TVMArgs args, TVMRetValue *rv) {
std::string file_name = RPCGetPath(args[0]);
std::string data;
LoadBinaryFromFile(file_name, &data);
TVMByteArray arr;
arr.data = data.c_str();
arr.size = data.length();
LOG(INFO) << "Download " << file_name << "... nbytes=" << arr.size;
*rv = arr;
});
TVM_REGISTER_GLOBAL("tvm.contrib.rpc.server.load_module")
.set_body([](TVMArgs args, TVMRetValue *rv) {
std::string file_name = RPCGetPath(args[0]);
*rv = Module::LoadFromFile(file_name, "");
LOG(INFO) << "Load module from " << file_name << " ...";
});
} // namespace runtime
} // namespace tvm
......@@ -10,11 +10,13 @@
#include <tvm/runtime/device_api.h>
#include <mutex>
#include <string>
#include "../../common/socket.h"
#include "../../common/ring_buffer.h"
namespace tvm {
namespace runtime {
const int kRPCMagic = 0xff271;
/*! \brief The remote functio handle */
using RPCFuncHandle = void*;
......@@ -22,6 +24,7 @@ struct RPCArgBuffer;
/*! \brief The RPC code */
enum class RPCCode : int {
kNone,
kCallFunc,
kReturn,
kException,
......@@ -30,6 +33,7 @@ enum class RPCCode : int {
kCopyToRemote,
kCopyAck,
// The following are code that can send over CallRemote
kSystemFuncStart,
kGetGlobalFunc,
kGetTimeEvaluator,
kFreeFunc,
......@@ -45,6 +49,30 @@ enum class RPCCode : int {
kModuleGetSource
};
/*!
* \brief Abstract channel interface used to create RPCSession.
*/
class RPCChannel {
public:
/*! \brief virtual destructor */
virtual ~RPCChannel() {}
/*!
* \brief Send data over to the channel.
* \param data The data pointer.
* \param size The size fo the data.
* \return The actual bytes sent.
*/
virtual size_t Send(const void* data, size_t size) = 0;
/*!
e * \brief Recv data from channel.
*
* \param data The data pointer.
* \param size The size fo the data.
* \return The actual bytes received.
*/
virtual size_t Recv(void* data, size_t size) = 0;
};
// Bidirectional Communication Session of PackedRPC
class RPCSession {
public:
......@@ -55,6 +83,17 @@ class RPCSession {
*/
void ServerLoop();
/*!
* \brief Message handling function for event driven server.
* Called when the server receives a message.
* Event driven handler will never call recv on the channel
* and always relies on the ServerOnMessageHandler
* to receive the data.
*
* \param bytes The incoming bytes.
* \return Whether need continue running, return false when receive a shutdown message.
*/
bool ServerOnMessageHandler(const std::string& bytes);
/*!
* \brief Call into remote function
* \param handle The function handle
* \param args The arguments
......@@ -121,10 +160,13 @@ class RPCSession {
}
/*!
* \brief Create a RPC session with given socket
* \param sock The socket.
* \param channel The communication channel.
* \param name The name of the session, used for debug
* \return The session.
*/
static std::shared_ptr<RPCSession> Create(common::TCPSocket sock);
static std::shared_ptr<RPCSession> Create(
std::unique_ptr<RPCChannel> channel,
std::string name);
/*!
* \brief Try get session from the global session table by table index.
* \param table_index The table index of the session.
......@@ -133,36 +175,28 @@ class RPCSession {
static std::shared_ptr<RPCSession> Get(int table_index);
private:
/*!
* \brief Handle the remote call with f
* \param f The handle function
* \tparam F the handler function.
*/
template<typename F>
void CallHandler(F f);
class EventHandler;
// Handle events until receives a return
// Also flushes channels so that the function advances.
RPCCode HandleUntilReturnEvent(TVMRetValue* rv);
// Initalization
void Init();
// Shutdown
void Shutdown();
void SendReturnValue(int succ, TVMValue value, int tcode);
void SendPackedSeq(const TVMValue* arg_values, const int* type_codes, int n);
void RecvPackedSeq(RPCArgBuffer *buf);
RPCCode HandleNextEvent(TVMRetValue *rv);
TVMContext StripSessMask(TVMContext ctx);
// special handler.
void HandleCallFunc();
void HandleException();
void HandleCopyFromRemote();
void HandleCopyToRemote();
void HandleReturn(TVMRetValue* rv);
// Internal channel.
std::unique_ptr<RPCChannel> channel_;
// Internal mutex
std::recursive_mutex mutex_;
// Internal socket
common::TCPSocket sock_;
// Internal temporal data space.
std::string temp_data_;
// Internal ring buffer.
common::RingBuffer reader_, writer_;
// Event handler.
std::shared_ptr<EventHandler> handler_;
// call remote with the specified function coede.
PackedFunc call_remote_;
// The index of this session in RPC session table.
int table_index_{0};
// The name of the session.
std::string name_;
};
/*!
......@@ -173,6 +207,13 @@ class RPCSession {
*/
PackedFunc WrapTimeEvaluator(PackedFunc f, TVMContext ctx, int nstep);
/*!
* \brief Create a Global RPC module that refers to the session.
* \param sess The RPC session of the global module.
* \return The created module.
*/
Module CreateRPCModule(std::shared_ptr<RPCSession> sess);
// Remote space pointer.
struct RemoteSpace {
void* data;
......@@ -183,7 +224,7 @@ struct RemoteSpace {
template<typename... Args>
inline TVMRetValue RPCSession::CallRemote(RPCCode code, Args&& ...args) {
std::lock_guard<std::recursive_mutex> lock(mutex_);
CHECK_EQ(sock_.SendAll(&code, sizeof(code)), sizeof(code));
writer_.Write(&code, sizeof(code));
return call_remote_(std::forward<Args>(args)...);
}
} // namespace runtime
......
/*!
* Copyright (c) 2017 by Contributors
* \file rpc_socket_impl.cc
* \brief Socket based RPC implementation.
*/
#include <tvm/runtime/registry.h>
#include <memory>
#include "./rpc_session.h"
#include "../../common/socket.h"
namespace tvm {
namespace runtime {
class SockChannel final : public RPCChannel {
public:
explicit SockChannel(common::TCPSocket sock)
: sock_(sock) {}
~SockChannel() {
if (!sock_.BadSocket()) {
sock_.Close();
}
}
size_t Send(const void* data, size_t size) final {
ssize_t n = sock_.Send(data, size);
if (n == -1) {
common::Socket::Error("SockChannel::Send");
}
return static_cast<size_t>(n);
}
size_t Recv(void* data, size_t size) final {
ssize_t n = sock_.Recv(data, size);
if (n == -1) {
common::Socket::Error("SockChannel::Recv");
}
return static_cast<size_t>(n);
}
private:
common::TCPSocket sock_;
};
Module RPCConnect(std::string url, int port, std::string key) {
common::TCPSocket sock;
common::SockAddr addr(url.c_str(), port);
sock.Create();
CHECK(sock.Connect(addr))
<< "Connect to " << addr.AsString() << " failed";
// hand shake
std::ostringstream os;
os << "client:" << key;
key = os.str();
int code = kRPCMagic;
int keylen = static_cast<int>(key.length());
CHECK_EQ(sock.SendAll(&code, sizeof(code)), sizeof(code));
CHECK_EQ(sock.SendAll(&keylen, sizeof(keylen)), sizeof(keylen));
if (keylen != 0) {
CHECK_EQ(sock.SendAll(key.c_str(), keylen), keylen);
}
CHECK_EQ(sock.RecvAll(&code, sizeof(code)), sizeof(code));
if (code == kRPCMagic + 2) {
sock.Close();
LOG(FATAL) << "URL " << url << ":" << port
<< " cannot find server that matches key=" << key;
} else if (code == kRPCMagic + 1) {
sock.Close();
LOG(FATAL) << "URL " << url << ":" << port
<< " server already have client key=" << key;
} else if (code != kRPCMagic) {
sock.Close();
LOG(FATAL) << "URL " << url << ":" << port << " is not TVM RPC server";
}
return CreateRPCModule(
RPCSession::Create(
std::unique_ptr<SockChannel>(new SockChannel(sock)),
"SockClient"));
}
void RPCServerLoop(int sockfd) {
common::TCPSocket sock(
static_cast<common::TCPSocket::SockType>(sockfd));
RPCSession::Create(
std::unique_ptr<SockChannel>(new SockChannel(sock)),
"SockServerLoop")->ServerLoop();
}
TVM_REGISTER_GLOBAL("contrib.rpc._Connect")
.set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = RPCConnect(args[0], args[1], args[2]);
});
TVM_REGISTER_GLOBAL("contrib.rpc._ServerLoop")
.set_body([](TVMArgs args, TVMRetValue* rv) {
RPCServerLoop(args[0]);
});
} // namespace runtime
} // namespace tvm
......@@ -12,5 +12,8 @@ RUN bash /install/ubuntu_install_python.sh
COPY install/ubuntu_install_iverilog.sh /install/ubuntu_install_iverilog.sh
RUN bash /install/ubuntu_install_iverilog.sh
COPY install/ubuntu_install_python_package.sh /install/ubuntu_install_python_package.sh
RUN bash /install/ubuntu_install_python_package.sh
COPY install/ubuntu_install_java.sh /install/ubuntu_install_java.sh
RUN bash /install/ubuntu_install_java.sh
......@@ -12,4 +12,7 @@ RUN bash /install/ubuntu_install_python.sh
COPY install/ubuntu_install_emscripten.sh /install/ubuntu_install_emscripten.sh
RUN bash /install/ubuntu_install_emscripten.sh
COPY install/ubuntu_install_python_package.sh /install/ubuntu_install_python_package.sh
RUN bash /install/ubuntu_install_python_package.sh
RUN cp /root/.emscripten /emsdk-portable/
\ No newline at end of file
......@@ -18,6 +18,9 @@ RUN bash /install/ubuntu_install_opencl.sh
COPY install/ubuntu_install_iverilog.sh /install/ubuntu_install_iverilog.sh
RUN bash /install/ubuntu_install_iverilog.sh
COPY install/ubuntu_install_python_package.sh /install/ubuntu_install_python_package.sh
RUN bash /install/ubuntu_install_python_package.sh
COPY install/ubuntu_install_sphinx.sh /install/ubuntu_install_sphinx.sh
RUN bash /install/ubuntu_install_sphinx.sh
......
......@@ -5,8 +5,11 @@ RUN apt-get update --fix-missing
COPY install/ubuntu_install_core.sh /install/ubuntu_install_core.sh
RUN bash /install/ubuntu_install_core.sh
COPY install/ubuntu_install_llvm.sh /install/ubuntu_install_llvm.sh
RUN bash /install/ubuntu_install_llvm.sh
COPY install/ubuntu_install_python.sh /install/ubuntu_install_python.sh
RUN bash /install/ubuntu_install_python.sh
COPY install/ubuntu_install_llvm.sh /install/ubuntu_install_llvm.sh
RUN bash /install/ubuntu_install_llvm.sh
COPY install/ubuntu_install_python_package.sh /install/ubuntu_install_python_package.sh
RUN bash /install/ubuntu_install_python_package.sh
apt-get update && apt-get install -y curl
curl -sL https://deb.nodesource.com/setup_6.x | bash -
apt-get update && apt-get install -y nodejs
npm install eslint jsdoc
npm install eslint jsdoc ws
# install libraries for python package on ubuntu
# install python and pip, don't modify this, modify install_python_package.sh
apt-get update && apt-get install -y python-pip python-dev python3-dev
# the version of the pip shipped with ubuntu may be too lower, install a recent version here
cd /tmp && wget https://bootstrap.pypa.io/get-pip.py && python3 get-pip.py && python2 get-pip.py
pip2 install nose pylint numpy nose-timer cython decorator scipy
pip3 install nose pylint numpy nose-timer cython decorator scipy
# install libraries for python package on ubuntu
pip2 install nose pylint numpy nose-timer cython decorator scipy tornado
pip3 install nose pylint numpy nose-timer cython decorator scipy tornado
import tvm
import logging
import numpy as np
import time
import multiprocessing
from tvm.contrib import rpc
def rpc_proxy_test():
"""This is a simple test function for RPC Proxy
It is not included as nosetests, because:
- It depends on tornado
- It relies on the fact that Proxy starts before client and server connects,
which is often the case but not always
User can directly run this script to verify correctness.
"""
try:
from tvm.contrib import rpc_proxy
web_port = 8888
prox = rpc_proxy.Proxy("localhost", web_port=web_port)
def check():
if not tvm.module.enabled("rpc"):
return
@tvm.register_func("rpc.test2.addone")
def addone(x):
return x + 1
@tvm.register_func("rpc.test2.strcat")
def addone(name, x):
return "%s:%d" % (name, x)
server = multiprocessing.Process(
target=rpc_proxy.websocket_proxy_server,
args=("ws://localhost:%d/ws" % web_port,"x1"))
# Need to make sure that the connection start after proxy comes up
time.sleep(0.1)
server.deamon = True
server.start()
client = rpc.connect(prox.host, prox.port, key="x1")
f1 = client.get_function("rpc.test2.addone")
assert f1(10) == 11
f2 = client.get_function("rpc.test2.strcat")
assert f2("abc", 11) == "abc:11"
check()
except ImportError:
print("Skipping because tornado is not avaliable...")
if __name__ == "__main__":
logging.basicConfig(level=logging.INFO)
rpc_proxy_test()
......@@ -14,10 +14,21 @@ def test_rpc_simple():
def addone(name, x):
return "%s:%d" % (name, x)
@tvm.register_func("rpc.test.except")
def remotethrow(name):
raise ValueError("%s" % name)
server = rpc.Server("localhost")
client = rpc.connect(server.host, server.port)
client = rpc.connect(server.host, server.port, key="x1")
f1 = client.get_function("rpc.test.addone")
assert f1(10) == 11
f3 = client.get_function("rpc.test.except")
try:
f3("abc")
assert False
except tvm.TVMError as e:
assert "abc" in str(e)
f2 = client.get_function("rpc.test.strcat")
assert f2("abc", 11) == "abc:11"
......@@ -42,9 +53,10 @@ def test_rpc_file_exchange():
return
server = rpc.Server("localhost")
remote = rpc.connect(server.host, server.port)
blob = bytearray(np.random.randint(0, 10, size=(127)))
blob = bytearray(np.random.randint(0, 10, size=(10)))
remote.upload(blob, "dat.bin")
rev = remote.download("dat.bin")
assert(rev == blob)
def test_rpc_remote_module():
if not tvm.module.enabled("rpc"):
......@@ -79,7 +91,8 @@ def test_rpc_remote_module():
if __name__ == "__main__":
logging.basicConfig(level=logging.INFO)
test_rpc_file_exchange()
exit(0)
test_rpc_array()
test_rpc_remote_module()
test_rpc_file_exchange()
test_rpc_simple()
"""Simple testcode to test Javascript RPC
To use it, start a rpc proxy with "python -m tvm.exec.rpc_proxy".
Connect javascript end to the websocket port and connect to the RPC.
"""
import tvm
import os
from tvm.contrib import rpc, util, emscripten
import numpy as np
proxy_host = "localhost"
proxy_port = 9090
def test_rpc_array():
if not tvm.module.enabled("rpc"):
return
# graph
n = tvm.convert(1024)
A = tvm.placeholder((n,), name='A')
B = tvm.compute(A.shape, lambda *i: A(*i) + 1.0, name='B')
s = tvm.create_schedule(B.op)
remote = rpc.connect(proxy_host, proxy_port, key="js")
target = "llvm -target=asmjs-unknown-emscripten -system-lib"
def check_remote():
if not tvm.module.enabled(target):
print("Skip because %s is not enabled" % target)
return
temp = util.tempdir()
ctx = remote.cpu(0)
f = tvm.build(s, [A, B], target, name="myadd")
path_obj = temp.relpath("dev_lib.bc")
path_dso = temp.relpath("dev_lib.js")
f.save(path_obj)
emscripten.create_js(path_dso, path_obj, side_module=True)
# Upload to suffix as dso so it can be loaded remotely
remote.upload(path_dso, "dev_lib.dso")
data = remote.download("dev_lib.dso")
f1 = remote.load_module("dev_lib.dso")
a = tvm.nd.array(np.random.uniform(size=1024).astype(A.dtype), ctx)
b = tvm.nd.array(np.zeros(1024, dtype=A.dtype), ctx)
time_f = f1.time_evaluator(f1.entry_name, remote.cpu(0), number=10)
cost = time_f(a, b)
print('%g secs/op' % cost)
np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1)
check_remote()
test_rpc_array()
......@@ -139,9 +139,18 @@ sysLib.release();
```
## Notes
- Current example supports static linking, which is the preferred way to get more efficiency
Current example supports static linking, which is the preferred way to get more efficiency
in javascript backend.
- It should also be possible to use Emscripten's dynamic linking to dynamically load modules.
- Take a look at tvm_runtime.js which contains quite a few helper functions
to interact with TVM from javascript.
\ No newline at end of file
## Proxy based RPC
We can now use javascript end to start an RPC server and connect to it from python side,
making the testing flow easier.
The following is an example to reproduce this. This requires everything to be in the git source and setup PYTHONPATH(instead of use setup.py install)
- run "python -m tvm.exec.rpc_proxy --example-rpc=1" to start proxy.
- Open broswer, goto the server webpage click Connect to proxy.
- Alternatively run "node web/example_rpc_node.js"
- run "python tests/web/websock_rpc_test.py" to run the rpc client.
The general idea is to use Emscripten's dynamic linking to dynamically load modules.
<html>
<head><title> TVM RPC Test Page </title></head>
<script src="libtvm_web_runtime.js"></script>
<script src="tvm_runtime.js"></script>
<script>
tvm = tvm_runtime.create(Module);
tvm.logger = function(message) {
console.log(message);
var d = document.createElement("div");
d.innerHTML = message;
document.getElementById("log").appendChild(d);
};
function clear_log() {
var node = document.getElementById("log");
while (node.hasChildNodes()) {
node.removeChild(node.lastChild);
}
}
function connect_rpc() {
var proxyurl = document.getElementById("proxyURL").value;
var key = document.getElementById("proxyKey").value;
tvm.startRPCServer(proxyurl, key, 100);
}
</script>
<body>
<h1>TVM Test Page</h1>
To use this page, the easiest way is to do
<ul>
<li> run "python -m tvm.exec.rpc_proxy --example-rpc=1" to start proxy.
<li> Click Connect to proxy.
<li> run "python tests/web/websock_rpc_test.py" to run the rpc client.
</ul>
<h2>Options</h2>
Proxy URL<input name="proxyurl" id="proxyURL" type="text" value="ws://localhost:9888/ws"><br>
RPC Server Key<input name="serverkey" id="proxyKey" type="text" value="js"><br>
<button onclick="connect_rpc()">Connect To Proxy</button>
<button onclick="clear_log()">Clear Log</button>
<div id="log"></div>
</body>
</html>
// Javascript RPC server example
// Start and connect to websocket proxy.
// Load Emscripten Module, need to change path to root/lib
const path = require("path");
process.chdir(path.join(__dirname, "../lib"));
var Module = require("../lib/libtvm_web_runtime.js");
// Bootstrap TVMruntime with emscripten module.
const tvm_runtime = require("../web/tvm_runtime.js");
const tvm = tvm_runtime.create(Module);
var websock_proxy = "ws://localhost:9888/ws";
var num_sess = 100;
tvm.startRPCServer(websock_proxy, "js", num_sess)
......@@ -7,6 +7,7 @@
/* eslint no-unused-vars: "off" */
/* eslint no-unexpected-multiline: "off" */
/* eslint indent: "off" */
/* eslint no-console: "off" */
/**
* TVM Runtime namespace.
* Provide tvm_runtime.create to create a {@link tvm.TVMRuntime}.
......@@ -31,8 +32,13 @@ var tvm_runtime = tvm_runtime || {};
* @memberof tvm
*/
function TVMRuntime() {
"use strict";
var runtime_ref = this;
// Utility function to throw error
function throwError(message) {
if (typeof runtime_ref.logger !== "undefined") {
runtime_ref.logger(message);
}
if (typeof Error !== "undefined") {
throw new Error(message);
}
......@@ -232,6 +238,21 @@ var tvm_runtime = tvm_runtime || {};
throwError(message);
}
};
/**
* Logging function.
* Override this to change logger behavior.
*
* @param {string} message
*/
this.logger = function(message) {
console.log(message);
};
function logging(message) {
runtime_ref.logger(message);
}
// Override print error to logging
Module.printErr = logging;
var CHECK = this.assert;
function TVM_CALL(ret) {
......@@ -746,11 +767,12 @@ var tvm_runtime = tvm_runtime || {};
out_array.release();
return names;
};
var listGlobalFuncNames = this.listGlobalFuncNames;
/**
* Get a global function from TVM runtime.
*
* @param {string} The name of the function.
* @return {Function} The corresponding function.
* @return {Function} The corresponding function, null if function do not exist
*/
this.getGlobalFunc = function (name) {
// alloc
......@@ -759,7 +781,11 @@ var tvm_runtime = tvm_runtime || {};
var out_handle = out.asHandle();
// release
out.release();
if (out_handle != 0) {
return makeTVMFunction(out_handle);
} else {
return null;
}
};
var getGlobalFunc = this.getGlobalFunc;
/**
......@@ -788,14 +814,123 @@ var tvm_runtime = tvm_runtime || {};
//-----------------------------------------
// Wrap of TVM Functions.
// ----------------------------------------
var fGetSystemLib = getGlobalFunc("module._GetSystemLib");
var systemFunc = {};
/**
* Get system-wide library module singleton.
* Get system-wide library module singleton.5A
* System lib is a global module that contains self register functions in startup.
* @return {tvm.TVMModule} The system module singleton.
*/
this.systemLib = function() {
return fGetSystemLib();
if (typeof systemFunc.fGetSystemLib === "undefined") {
systemFunc.fGetSystemLib = getGlobalFunc("module._GetSystemLib");
}
return systemFunc.fGetSystemLib();
};
this.startRPCServer = function(url, key, counter) {
if (typeof key === "undefined") {
key = "";
}
if (typeof counter === "undefined") {
counter = 1;
}
// Node js, import websocket
var bkey = StringToUint8Array("server:" + key);
var server_name = "WebSocketRPCServer[" + key + "]";
var RPC_MAGIC = 0xff271;
function checkEndian() {
var a = new ArrayBuffer(4);
var b = new Uint8Array(a);
var c = new Uint32Array(a);
b[0] = 0x11;
b[1] = 0x22;
b[2] = 0x33;
b[3] = 0x44;
CHECK(c[0] === 0x44332211, "Need little endian to work");
}
checkEndian();
// start rpc
function RPCServer(counter) {
var socket;
if (typeof module !== "undefined" && module.exports) {
// WebSocket for nodejs
const WebSocket = require("ws");
socket = new WebSocket(url);
} else {
socket = new WebSocket(url);
}
var self = this;
socket.binaryType = "arraybuffer";
this.init = true;
this.counter = counter;
if (typeof systemFunc.fcreateServer === "undefined") {
systemFunc.fcreateServer =
getGlobalFunc("contrib.rpc._CreateEventDrivenServer");
}
if (systemFunc.fcreateServer == null) {
throwError("RPCServer is not included in runtime");
}
var message_handler = systemFunc.fcreateServer(
function(cbytes) {
if (socket.readyState == 1) {
socket.send(cbytes);
return new TVMConstant(cbytes.length, "int32");
} else {
return new TVMConstant(0, "int32");
}
} , server_name);
function on_open(event) {
var intbuf = new Int32Array(1);
intbuf[0] = RPC_MAGIC;
socket.send(intbuf);
intbuf[0] = bkey.length;
socket.send(intbuf);
socket.send(bkey);
logging(server_name + " connected...");
}
function on_message(event) {
if (self.init) {
var msg = new Uint8Array(event.data);
CHECK(msg.length >= 4, "Need message header to be bigger than 4");
var magic = new Int32Array(event.data)[0];
if (magic == RPC_MAGIC + 1) {
throwError("key: " + key + " has already been used in proxy");
} else if (magic == RPC_MAGIC + 2) {
logging(server_name + ": RPCProxy do not have matching client key " + key);
} else {
CHECK(magic == RPC_MAGIC, url + "is not RPC Proxy");
self.init = false;
}
logging(server_name + "init end...");
if (msg.length > 4) {
if (!message_handler(new Uint8Array(event.data, 4, msg.length -4))) {
socket.close();
}
}
} else {
if (!message_handler(new Uint8Array(event.data))) {
socket.close();
}
}
}
function on_close(event) {
message_handler.release();
logging(server_name + ": closed finish...");
if (!self.init && self.counter != 0) {
logging(server_name + ":reconnect to serve another request, session left=" + counter);
// start a new server.
new RPCServer(counter - 1);
}
}
socket.addEventListener("open", on_open);
socket.addEventListener("message", on_message);
socket.addEventListener("close", on_close);
}
return new RPCServer(counter);
};
//-----------------------------------------
// Class defintions
......@@ -943,16 +1078,12 @@ var tvm_runtime = tvm_runtime || {};
* @property {string} create
* @memberof tvm_runtime
* @param Module The emscripten module.
* @param Runtime The emscripten runtime, optional
* @return {tvm.TVMRuntime} The created TVM runtime.
*/
this.create = function(Module, Runtime) {
this.create = function(Module) {
var tvm = {};
tvm.Module = Module;
if (typeof Runtime == "undefined") {
Runtime = Module.Runtime;
}
tvm.Runtime = Runtime;
tvm.Runtime = Module.Runtime;
TVMRuntime.apply(tvm);
return tvm;
};
......
......@@ -2,6 +2,9 @@
* Copyright (c) 2017 by Contributors
* \file web_runtime.cc
*/
#include <sys/stat.h>
#include <fstream>
#include "../src/runtime/c_runtime_api.cc"
#include "../src/runtime/cpu_device_api.cc"
#include "../src/runtime/workspace_pool.cc"
......@@ -10,3 +13,33 @@
#include "../src/runtime/module.cc"
#include "../src/runtime/registry.cc"
#include "../src/runtime/file_util.cc"
#include "../src/runtime/dso_module.cc"
#include "../src/runtime/rpc/rpc_session.cc"
#include "../src/runtime/rpc/rpc_event_impl.cc"
#include "../src/runtime/rpc/rpc_server_env.cc"
namespace tvm {
namespace contrib {
struct RPCEnv {
public:
RPCEnv() {
base_ = "/rpc";
mkdir(&base_[0], 0777);
}
// Get Path.
std::string GetPath(const std::string& file_name) {
return base_ + "/" + file_name;
}
private:
std::string base_;
};
TVM_REGISTER_GLOBAL("tvm.contrib.rpc.server.workpath")
.set_body([](TVMArgs args, TVMRetValue* rv) {
static RPCEnv env;
*rv = env.GetPath(args[0]);
});
} // namespace contrib
} // namespace tvm
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment