Commit 8d241b9d by Tianqi Chen Committed by GitHub

[RPC] Allow back pressure from writer (#250)

* [RPC] Allow backpressure from writer

* fix

* fix
parent c6d4f5af
...@@ -235,8 +235,9 @@ class RequestHandler(tornado.web.RequestHandler): ...@@ -235,8 +235,9 @@ class RequestHandler(tornado.web.RequestHandler):
self.page = open(kwargs.pop("file_path")).read() self.page = open(kwargs.pop("file_path")).read()
web_port = kwargs.pop("rpc_web_port", None) web_port = kwargs.pop("rpc_web_port", None)
if web_port: if web_port:
self.page.replace(r"ws://localhost:9888/ws", self.page = self.page.replace(
r"ws://localhost:%d/ws" % web_port) "ws://localhost:9190/ws",
"ws://localhost:%d/ws" % web_port)
super(RequestHandler, self).__init__(*args, **kwargs) super(RequestHandler, self).__init__(*args, **kwargs)
def data_received(self, _): def data_received(self, _):
...@@ -468,14 +469,14 @@ def websocket_proxy_server(url, key=""): ...@@ -468,14 +469,14 @@ def websocket_proxy_server(url, key=""):
logging.info("Connection established") logging.info("Connection established")
msg = msg[4:] msg = msg[4:]
if msg: if msg:
on_message(bytearray(msg)) on_message(bytearray(msg), 3)
while True: while True:
try: try:
msg = yield conn.read_message() msg = yield conn.read_message()
if msg is None: if msg is None:
break break
on_message(bytearray(msg)) on_message(bytearray(msg), 3)
except websocket.WebSocketClosedError as err: except websocket.WebSocketClosedError as err:
break break
logging.info("WebSocketProxyServer closed...") logging.info("WebSocketProxyServer closed...")
......
...@@ -29,7 +29,7 @@ def main(): ...@@ -29,7 +29,7 @@ def main():
help='the hostname of the server') help='the hostname of the server')
parser.add_argument('--port', type=int, default=9090, parser.add_argument('--port', type=int, default=9090,
help='The port of the PRC') help='The port of the PRC')
parser.add_argument('--web-port', type=int, default=9888, parser.add_argument('--web-port', type=int, default=9190,
help='The port of the http/websocket server') help='The port of the http/websocket server')
parser.add_argument('--example-rpc', type=bool, default=False, parser.add_argument('--example-rpc', type=bool, default=False,
help='Whether to switch on example rpc mode') help='Whether to switch on example rpc mode')
......
...@@ -32,18 +32,18 @@ class CallbackChannel final : public RPCChannel { ...@@ -32,18 +32,18 @@ class CallbackChannel final : public RPCChannel {
PackedFunc fsend_; PackedFunc fsend_;
}; };
PackedFunc CreateEvenDrivenServer(PackedFunc fsend, std::string name) { PackedFunc CreateEventDrivenServer(PackedFunc fsend, std::string name) {
std::unique_ptr<CallbackChannel> ch(new CallbackChannel(fsend)); std::unique_ptr<CallbackChannel> ch(new CallbackChannel(fsend));
std::shared_ptr<RPCSession> sess = RPCSession::Create(std::move(ch), name); std::shared_ptr<RPCSession> sess = RPCSession::Create(std::move(ch), name);
return PackedFunc([sess](TVMArgs args, TVMRetValue* rv) { return PackedFunc([sess](TVMArgs args, TVMRetValue* rv) {
bool ret = sess->ServerOnMessageHandler(args[0]); int ret = sess->ServerEventHandler(args[0], args[1]);
*rv = ret; *rv = ret;
}); });
} }
TVM_REGISTER_GLOBAL("contrib.rpc._CreateEventDrivenServer") TVM_REGISTER_GLOBAL("contrib.rpc._CreateEventDrivenServer")
.set_body([](TVMArgs args, TVMRetValue* rv) { .set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = CreateEvenDrivenServer(args[0], args[1]); *rv = CreateEventDrivenServer(args[0], args[1]);
}); });
} // namespace runtime } // namespace runtime
} // namespace tvm } // namespace tvm
...@@ -752,19 +752,23 @@ void RPCSession::ServerLoop() { ...@@ -752,19 +752,23 @@ void RPCSession::ServerLoop() {
channel_.reset(nullptr); channel_.reset(nullptr);
} }
int RPCSession::ServerEventHandler(const std::string& bytes, int event_flag) {
bool RPCSession::ServerOnMessageHandler(const std::string& bytes) {
std::lock_guard<std::recursive_mutex> lock(mutex_); std::lock_guard<std::recursive_mutex> lock(mutex_);
reader_.Write(bytes.c_str(), bytes.length()); RPCCode code = RPCCode::kNone;
TVMRetValue rv; if (bytes.length() != 0) {
RPCCode code = handler_->HandleNextEvent(&rv, false, nullptr); reader_.Write(bytes.c_str(), bytes.length());
while (writer_.bytes_available() != 0) { TVMRetValue rv;
code = handler_->HandleNextEvent(&rv, false, nullptr);
}
if ((event_flag & 2) != 0 && writer_.bytes_available() != 0) {
writer_.ReadWithCallback([this](const void *data, size_t size) { writer_.ReadWithCallback([this](const void *data, size_t size) {
return channel_->Send(data, size); return channel_->Send(data, size);
}, writer_.bytes_available()); }, writer_.bytes_available());
} }
CHECK(code != RPCCode::kReturn && code != RPCCode::kCopyAck); CHECK(code != RPCCode::kReturn && code != RPCCode::kCopyAck);
return code != RPCCode::kShutdown; if (code == RPCCode::kShutdown) return 0;
if (writer_.bytes_available() != 0) return 2;
return 1;
} }
// Get remote function with name // Get remote function with name
......
...@@ -86,13 +86,18 @@ class RPCSession { ...@@ -86,13 +86,18 @@ class RPCSession {
* \brief Message handling function for event driven server. * \brief Message handling function for event driven server.
* Called when the server receives a message. * Called when the server receives a message.
* Event driven handler will never call recv on the channel * Event driven handler will never call recv on the channel
* and always relies on the ServerOnMessageHandler * and always relies on the ServerEventHandler.
* to receive the data. * to receive the data.
* *
* \param bytes The incoming bytes. * \param in_bytes The incoming bytes.
* \return Whether need continue running, return false when receive a shutdown message. * \param event_flag 1: read_available, 2: write_avaiable.
* \return State flag.
* 1: continue running, no need to write,
* 2: need to write
* 0: shutdown
*/ */
bool ServerOnMessageHandler(const std::string& bytes); int ServerEventHandler(const std::string& in_bytes,
int event_flag);
/*! /*!
* \brief Call into remote function * \brief Call into remote function
* \param handle The function handle * \param handle The function handle
...@@ -161,7 +166,7 @@ class RPCSession { ...@@ -161,7 +166,7 @@ class RPCSession {
return table_index_; return table_index_;
} }
/*! /*!
* \brief Create a RPC session with given socket * \brief Create a RPC session with given channel.
* \param channel The communication channel. * \param channel The communication channel.
* \param name The name of the session, used for debug * \param name The name of the session, used for debug
* \return The session. * \return The session.
......
...@@ -5,7 +5,7 @@ import time ...@@ -5,7 +5,7 @@ import time
import multiprocessing import multiprocessing
from tvm.contrib import rpc from tvm.contrib import rpc
def rpc_proxy_test(): def rpc_proxy_check():
"""This is a simple test function for RPC Proxy """This is a simple test function for RPC Proxy
It is not included as nosetests, because: It is not included as nosetests, because:
...@@ -47,4 +47,4 @@ def rpc_proxy_test(): ...@@ -47,4 +47,4 @@ def rpc_proxy_test():
if __name__ == "__main__": if __name__ == "__main__":
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
rpc_proxy_test() rpc_proxy_check()
...@@ -31,7 +31,7 @@ ...@@ -31,7 +31,7 @@
<li> run "python tests/web/websock_rpc_test.py" to run the rpc client. <li> run "python tests/web/websock_rpc_test.py" to run the rpc client.
</ul> </ul>
<h2>Options</h2> <h2>Options</h2>
Proxy URL<input name="proxyurl" id="proxyURL" type="text" value="ws://localhost:9888/ws"><br> Proxy URL<input name="proxyurl" id="proxyURL" type="text" value="ws://localhost:9190/ws"><br>
RPC Server Key<input name="serverkey" id="proxyKey" type="text" value="js"><br> RPC Server Key<input name="serverkey" id="proxyKey" type="text" value="js"><br>
<button onclick="connect_rpc()">Connect To Proxy</button> <button onclick="connect_rpc()">Connect To Proxy</button>
<button onclick="clear_log()">Clear Log</button> <button onclick="clear_log()">Clear Log</button>
......
...@@ -9,6 +9,6 @@ var Module = require("../lib/libtvm_web_runtime.js"); ...@@ -9,6 +9,6 @@ var Module = require("../lib/libtvm_web_runtime.js");
const tvm_runtime = require("../web/tvm_runtime.js"); const tvm_runtime = require("../web/tvm_runtime.js");
const tvm = tvm_runtime.create(Module); const tvm = tvm_runtime.create(Module);
var websock_proxy = "ws://localhost:9888/ws"; var websock_proxy = "ws://localhost:9190/ws";
var num_sess = 100; var num_sess = 100;
tvm.startRPCServer(websock_proxy, "js", num_sess) tvm.startRPCServer(websock_proxy, "js", num_sess)
...@@ -503,7 +503,7 @@ var tvm_runtime = tvm_runtime || {}; ...@@ -503,7 +503,7 @@ var tvm_runtime = tvm_runtime || {};
* @return {boolean} Whether f is PackedFunc * @return {boolean} Whether f is PackedFunc
*/ */
this.isPackedFunc = function(f) { this.isPackedFunc = function(f) {
return (typeof f._tvm_function !== "undefined"); return (typeof f == "function") && f.hasOwnProperty("_tvm_function");
}; };
var isPackedFunc = this.isPackedFunc; var isPackedFunc = this.isPackedFunc;
/** /**
...@@ -633,7 +633,7 @@ var tvm_runtime = tvm_runtime || {}; ...@@ -633,7 +633,7 @@ var tvm_runtime = tvm_runtime || {};
} }
} else if (tp == "number") { } else if (tp == "number") {
this.setDouble(i, v); this.setDouble(i, v);
} else if (typeof v._tvm_function !== "undefined") { } else if (tp == "function" && v.hasOwnProperty("_tvm_function")) {
this.setString(i, v._tvm_function.handle, kFuncHandle); this.setString(i, v._tvm_function.handle, kFuncHandle);
} else if (v === null) { } else if (v === null) {
this.setHandle(i, 0, kNull); this.setHandle(i, 0, kNull);
...@@ -907,12 +907,15 @@ var tvm_runtime = tvm_runtime || {}; ...@@ -907,12 +907,15 @@ var tvm_runtime = tvm_runtime || {};
} }
logging(server_name + "init end..."); logging(server_name + "init end...");
if (msg.length > 4) { if (msg.length > 4) {
if (!message_handler(new Uint8Array(event.data, 4, msg.length -4))) { if (message_handler(
new Uint8Array(event.data, 4, msg.length -4),
new TVMConstant(3, "int32")) == 0) {
socket.close(); socket.close();
} }
} }
} else { } else {
if (!message_handler(new Uint8Array(event.data))) { if (message_handler(new Uint8Array(event.data),
new TVMConstant(3, "int32")) == 0) {
socket.close(); socket.close();
} }
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment