Commit 8d241b9d by Tianqi Chen Committed by GitHub

[RPC] Allow back pressure from writer (#250)

* [RPC] Allow backpressure from writer

* fix

* fix
parent c6d4f5af
......@@ -235,8 +235,9 @@ class RequestHandler(tornado.web.RequestHandler):
self.page = open(kwargs.pop("file_path")).read()
web_port = kwargs.pop("rpc_web_port", None)
if web_port:
self.page.replace(r"ws://localhost:9888/ws",
r"ws://localhost:%d/ws" % web_port)
self.page = self.page.replace(
"ws://localhost:9190/ws",
"ws://localhost:%d/ws" % web_port)
super(RequestHandler, self).__init__(*args, **kwargs)
def data_received(self, _):
......@@ -468,14 +469,14 @@ def websocket_proxy_server(url, key=""):
logging.info("Connection established")
msg = msg[4:]
if msg:
on_message(bytearray(msg))
on_message(bytearray(msg), 3)
while True:
try:
msg = yield conn.read_message()
if msg is None:
break
on_message(bytearray(msg))
on_message(bytearray(msg), 3)
except websocket.WebSocketClosedError as err:
break
logging.info("WebSocketProxyServer closed...")
......
......@@ -29,7 +29,7 @@ def main():
help='the hostname of the server')
parser.add_argument('--port', type=int, default=9090,
help='The port of the PRC')
parser.add_argument('--web-port', type=int, default=9888,
parser.add_argument('--web-port', type=int, default=9190,
help='The port of the http/websocket server')
parser.add_argument('--example-rpc', type=bool, default=False,
help='Whether to switch on example rpc mode')
......
......@@ -32,18 +32,18 @@ class CallbackChannel final : public RPCChannel {
PackedFunc fsend_;
};
PackedFunc CreateEvenDrivenServer(PackedFunc fsend, std::string name) {
PackedFunc CreateEventDrivenServer(PackedFunc fsend, std::string name) {
std::unique_ptr<CallbackChannel> ch(new CallbackChannel(fsend));
std::shared_ptr<RPCSession> sess = RPCSession::Create(std::move(ch), name);
return PackedFunc([sess](TVMArgs args, TVMRetValue* rv) {
bool ret = sess->ServerOnMessageHandler(args[0]);
int ret = sess->ServerEventHandler(args[0], args[1]);
*rv = ret;
});
}
TVM_REGISTER_GLOBAL("contrib.rpc._CreateEventDrivenServer")
.set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = CreateEvenDrivenServer(args[0], args[1]);
*rv = CreateEventDrivenServer(args[0], args[1]);
});
} // namespace runtime
} // namespace tvm
......@@ -752,19 +752,23 @@ void RPCSession::ServerLoop() {
channel_.reset(nullptr);
}
bool RPCSession::ServerOnMessageHandler(const std::string& bytes) {
int RPCSession::ServerEventHandler(const std::string& bytes, int event_flag) {
std::lock_guard<std::recursive_mutex> lock(mutex_);
RPCCode code = RPCCode::kNone;
if (bytes.length() != 0) {
reader_.Write(bytes.c_str(), bytes.length());
TVMRetValue rv;
RPCCode code = handler_->HandleNextEvent(&rv, false, nullptr);
while (writer_.bytes_available() != 0) {
code = handler_->HandleNextEvent(&rv, false, nullptr);
}
if ((event_flag & 2) != 0 && writer_.bytes_available() != 0) {
writer_.ReadWithCallback([this](const void *data, size_t size) {
return channel_->Send(data, size);
}, writer_.bytes_available());
}
CHECK(code != RPCCode::kReturn && code != RPCCode::kCopyAck);
return code != RPCCode::kShutdown;
if (code == RPCCode::kShutdown) return 0;
if (writer_.bytes_available() != 0) return 2;
return 1;
}
// Get remote function with name
......
......@@ -86,13 +86,18 @@ class RPCSession {
* \brief Message handling function for event driven server.
* Called when the server receives a message.
* Event driven handler will never call recv on the channel
* and always relies on the ServerOnMessageHandler
* and always relies on the ServerEventHandler.
* to receive the data.
*
* \param bytes The incoming bytes.
* \return Whether need continue running, return false when receive a shutdown message.
*/
bool ServerOnMessageHandler(const std::string& bytes);
* \param in_bytes The incoming bytes.
* \param event_flag 1: read_available, 2: write_avaiable.
* \return State flag.
* 1: continue running, no need to write,
* 2: need to write
* 0: shutdown
*/
int ServerEventHandler(const std::string& in_bytes,
int event_flag);
/*!
* \brief Call into remote function
* \param handle The function handle
......@@ -161,7 +166,7 @@ class RPCSession {
return table_index_;
}
/*!
* \brief Create a RPC session with given socket
* \brief Create a RPC session with given channel.
* \param channel The communication channel.
* \param name The name of the session, used for debug
* \return The session.
......
......@@ -5,7 +5,7 @@ import time
import multiprocessing
from tvm.contrib import rpc
def rpc_proxy_test():
def rpc_proxy_check():
"""This is a simple test function for RPC Proxy
It is not included as nosetests, because:
......@@ -47,4 +47,4 @@ def rpc_proxy_test():
if __name__ == "__main__":
logging.basicConfig(level=logging.INFO)
rpc_proxy_test()
rpc_proxy_check()
......@@ -31,7 +31,7 @@
<li> run "python tests/web/websock_rpc_test.py" to run the rpc client.
</ul>
<h2>Options</h2>
Proxy URL<input name="proxyurl" id="proxyURL" type="text" value="ws://localhost:9888/ws"><br>
Proxy URL<input name="proxyurl" id="proxyURL" type="text" value="ws://localhost:9190/ws"><br>
RPC Server Key<input name="serverkey" id="proxyKey" type="text" value="js"><br>
<button onclick="connect_rpc()">Connect To Proxy</button>
<button onclick="clear_log()">Clear Log</button>
......
......@@ -9,6 +9,6 @@ var Module = require("../lib/libtvm_web_runtime.js");
const tvm_runtime = require("../web/tvm_runtime.js");
const tvm = tvm_runtime.create(Module);
var websock_proxy = "ws://localhost:9888/ws";
var websock_proxy = "ws://localhost:9190/ws";
var num_sess = 100;
tvm.startRPCServer(websock_proxy, "js", num_sess)
......@@ -503,7 +503,7 @@ var tvm_runtime = tvm_runtime || {};
* @return {boolean} Whether f is PackedFunc
*/
this.isPackedFunc = function(f) {
return (typeof f._tvm_function !== "undefined");
return (typeof f == "function") && f.hasOwnProperty("_tvm_function");
};
var isPackedFunc = this.isPackedFunc;
/**
......@@ -633,7 +633,7 @@ var tvm_runtime = tvm_runtime || {};
}
} else if (tp == "number") {
this.setDouble(i, v);
} else if (typeof v._tvm_function !== "undefined") {
} else if (tp == "function" && v.hasOwnProperty("_tvm_function")) {
this.setString(i, v._tvm_function.handle, kFuncHandle);
} else if (v === null) {
this.setHandle(i, 0, kNull);
......@@ -907,12 +907,15 @@ var tvm_runtime = tvm_runtime || {};
}
logging(server_name + "init end...");
if (msg.length > 4) {
if (!message_handler(new Uint8Array(event.data, 4, msg.length -4))) {
if (message_handler(
new Uint8Array(event.data, 4, msg.length -4),
new TVMConstant(3, "int32")) == 0) {
socket.close();
}
}
} else {
if (!message_handler(new Uint8Array(event.data))) {
if (message_handler(new Uint8Array(event.data),
new TVMConstant(3, "int32")) == 0) {
socket.close();
}
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment