Commit 5408d3a3 by Yizhi Liu Committed by Tianqi Chen

[rpc] use callback func to do send & recv (#4147)

* [rpc] use callback func to do send & recv. don't get fd from sock as it is deprecated in java

* fix java build

* fix min/max macro define in windows

* keep the old rpc setup for py

* add doc for CallbackChannel
parent a7404230
......@@ -30,7 +30,6 @@ public class ConnectProxyServerProcessor implements ServerProcessor {
private final String host;
private final int port;
private final String key;
private final SocketFileDescriptorGetter socketFileDescriptorGetter;
private volatile Socket currSocket = new Socket();
private Runnable callback;
......@@ -40,14 +39,11 @@ public class ConnectProxyServerProcessor implements ServerProcessor {
* @param host Proxy server host.
* @param port Proxy server port.
* @param key Proxy server key.
* @param sockFdGetter Method to get file descriptor from Java socket.
*/
public ConnectProxyServerProcessor(String host, int port, String key,
SocketFileDescriptorGetter sockFdGetter) {
public ConnectProxyServerProcessor(String host, int port, String key) {
this.host = host;
this.port = port;
this.key = "server:" + key;
socketFileDescriptorGetter = sockFdGetter;
}
/**
......@@ -70,8 +66,8 @@ public class ConnectProxyServerProcessor implements ServerProcessor {
try {
SocketAddress address = new InetSocketAddress(host, port);
currSocket.connect(address, 6000);
InputStream in = currSocket.getInputStream();
OutputStream out = currSocket.getOutputStream();
final InputStream in = currSocket.getInputStream();
final OutputStream out = currSocket.getOutputStream();
out.write(Utils.toBytes(RPC.RPC_MAGIC));
out.write(Utils.toBytes(key.length()));
out.write(Utils.toBytes(key));
......@@ -91,11 +87,10 @@ public class ConnectProxyServerProcessor implements ServerProcessor {
if (callback != null) {
callback.run();
}
final int sockFd = socketFileDescriptorGetter.get(currSocket);
if (sockFd != -1) {
new NativeServerLoop(sockFd).run();
System.err.println("Finish serving " + address);
}
SocketChannel sockChannel = new SocketChannel(currSocket);
new NativeServerLoop(sockChannel.getFsend(), sockChannel.getFrecv()).run();
System.err.println("Finish serving " + address);
} catch (Throwable e) {
e.printStackTrace();
throw new RuntimeException(e);
......
......@@ -37,7 +37,6 @@ import java.net.SocketTimeoutException;
*/
public class ConnectTrackerServerProcessor implements ServerProcessor {
private ServerSocket server;
private final SocketFileDescriptorGetter socketFileDescriptorGetter;
private final String trackerHost;
private final int trackerPort;
// device key
......@@ -62,10 +61,11 @@ public class ConnectTrackerServerProcessor implements ServerProcessor {
* @param trackerHost Tracker host.
* @param trackerPort Tracker port.
* @param key Device key.
* @param sockFdGetter Method to get file descriptor from Java socket.
* @param watchdog watch for timeout, etc.
* @throws java.io.IOException when socket fails to open.
*/
public ConnectTrackerServerProcessor(String trackerHost, int trackerPort, String key,
SocketFileDescriptorGetter sockFdGetter, RPCWatchdog watchdog) throws IOException {
RPCWatchdog watchdog) throws IOException {
while (true) {
try {
this.server = new ServerSocket(serverPort);
......@@ -81,7 +81,6 @@ public class ConnectTrackerServerProcessor implements ServerProcessor {
}
}
System.err.println("using port: " + serverPort);
this.socketFileDescriptorGetter = sockFdGetter;
this.trackerHost = trackerHost;
this.trackerPort = trackerPort;
this.key = key;
......@@ -163,11 +162,9 @@ public class ConnectTrackerServerProcessor implements ServerProcessor {
System.err.println("Connection from " + socket.getRemoteSocketAddress().toString());
// received timeout in seconds
watchdog.startTimeout(timeout * 1000);
final int sockFd = socketFileDescriptorGetter.get(socket);
if (sockFd != -1) {
new NativeServerLoop(sockFd).run();
System.err.println("Finish serving " + socket.getRemoteSocketAddress().toString());
}
SocketChannel sockChannel = new SocketChannel(socket);
new NativeServerLoop(sockChannel.getFsend(), sockChannel.getFrecv()).run();
System.err.println("Finish serving " + socket.getRemoteSocketAddress().toString());
Utils.closeQuietly(socket);
} catch (ConnectException e) {
// if the tracker connection failed, wait a bit before retrying
......
......@@ -28,14 +28,17 @@ import java.io.IOException;
* Call native ServerLoop on socket file descriptor.
*/
public class NativeServerLoop implements Runnable {
private final int sockFd;
private final Function fsend;
private final Function frecv;
/**
* Constructor for NativeServerLoop.
* @param nativeSockFd native socket file descriptor.
* @param fsend socket.send function.
* @param frecv socket.recv function.
*/
public NativeServerLoop(final int nativeSockFd) {
sockFd = nativeSockFd;
public NativeServerLoop(final Function fsend, final Function frecv) {
this.fsend = fsend;
this.frecv = frecv;
}
@Override public void run() {
......@@ -43,7 +46,7 @@ public class NativeServerLoop implements Runnable {
try {
tempDir = serverEnv();
System.err.println("starting server loop...");
RPC.getApi("_ServerLoop").pushArg(sockFd).invoke();
RPC.getApi("_ServerLoop").pushArg(fsend).pushArg(frecv).invoke();
System.err.println("done server loop...");
} catch (IOException e) {
e.printStackTrace();
......
......@@ -200,6 +200,7 @@ public class RPCSession {
* Upload file to remote runtime temp folder.
* @param data The file in local to upload.
* @param target The path in remote.
* @throws java.io.IOException for network failure.
*/
public void upload(File data, String target) throws IOException {
byte[] blob = getBytesFromFile(data);
......@@ -209,6 +210,7 @@ public class RPCSession {
/**
* Upload file to remote runtime temp folder.
* @param data The file in local to upload.
* @throws java.io.IOException for network failure.
*/
public void upload(File data) throws IOException {
upload(data, data.getName());
......
......@@ -17,31 +17,12 @@
package ml.dmlc.tvm.rpc;
import sun.misc.SharedSecrets;
import java.io.FileDescriptor;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.net.Socket;
/**
* RPC Server.
*/
public class Server {
private static SocketFileDescriptorGetter defaultSocketFdGetter
= new SocketFileDescriptorGetter() {
@Override public int get(Socket socket) {
try {
InputStream is = socket.getInputStream();
FileDescriptor fd = ((FileInputStream) is).getFD();
return SharedSecrets.getJavaIOFileDescriptorAccess().get(fd);
} catch (IOException e) {
e.printStackTrace();
return -1;
}
}
};
private final WorkerThread worker;
private static class WorkerThread extends Thread {
......@@ -72,35 +53,10 @@ public class Server {
/**
* Start a standalone server.
* @param serverPort Port.
* @param socketFdGetter Method to get system file descriptor of the server socket.
* @throws IOException if failed to bind localhost:port.
*/
public Server(int serverPort, SocketFileDescriptorGetter socketFdGetter) throws IOException {
worker = new WorkerThread(new StandaloneServerProcessor(serverPort, socketFdGetter));
}
/**
* Start a standalone server.
* Use sun.misc.SharedSecrets.getJavaIOFileDescriptorAccess
* to get file descriptor for the socket.
* @param serverPort Port.
* @throws IOException if failed to bind localhost:port.
*/
public Server(int serverPort) throws IOException {
this(serverPort, defaultSocketFdGetter);
}
/**
* Start a server connected to proxy.
* @param proxyHost The proxy server host.
* @param proxyPort The proxy server port.
* @param key The key to identify the server.
* @param socketFdGetter Method to get system file descriptor of the server socket.
*/
public Server(String proxyHost, int proxyPort, String key,
SocketFileDescriptorGetter socketFdGetter) {
worker = new WorkerThread(
new ConnectProxyServerProcessor(proxyHost, proxyPort, key, socketFdGetter));
worker = new WorkerThread(new StandaloneServerProcessor(serverPort));
}
/**
......@@ -112,7 +68,8 @@ public class Server {
* @param key The key to identify the server.
*/
public Server(String proxyHost, int proxyPort, String key) {
this(proxyHost, proxyPort, key, defaultSocketFdGetter);
worker = new WorkerThread(
new ConnectProxyServerProcessor(proxyHost, proxyPort, key));
}
/**
......
package ml.dmlc.tvm.rpc;
import ml.dmlc.tvm.Function;
import ml.dmlc.tvm.TVMValue;
import ml.dmlc.tvm.TVMValueBytes;
import java.io.IOException;
import java.net.Socket;
public class SocketChannel {
private final Socket socket;
SocketChannel(Socket sock) {
socket = sock;
}
private Function fsend = Function.convertFunc(new Function.Callback() {
@Override public Object invoke(TVMValue... args) {
byte[] data = args[0].asBytes();
try {
socket.getOutputStream().write(data);
} catch (IOException e) {
e.printStackTrace();
return -1;
}
return data.length;
}
});
private Function frecv = Function.convertFunc(new Function.Callback() {
@Override public Object invoke(TVMValue... args) {
long size = args[0].asLong();
try {
return new TVMValueBytes(Utils.recvAll(socket.getInputStream(), (int) size));
} catch (IOException e) {
e.printStackTrace();
return -1;
}
}
});
public Function getFsend() {
return fsend;
}
public Function getFrecv() {
return frecv;
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.dmlc.tvm.rpc;
import java.net.Socket;
/**
* Interface for defining different socket fd getter.
*/
public interface SocketFileDescriptorGetter {
/**
* Get native socket file descriptor.
* @param socket Java socket.
* @return native socket fd.
*/
public int get(Socket socket);
}
......@@ -28,12 +28,9 @@ import java.net.Socket;
*/
public class StandaloneServerProcessor implements ServerProcessor {
private final ServerSocket server;
private final SocketFileDescriptorGetter socketFileDescriptorGetter;
public StandaloneServerProcessor(int serverPort,
SocketFileDescriptorGetter sockFdGetter) throws IOException {
public StandaloneServerProcessor(int serverPort) throws IOException {
this.server = new ServerSocket(serverPort);
this.socketFileDescriptorGetter = sockFdGetter;
}
@Override public void terminate() {
......@@ -46,9 +43,9 @@ public class StandaloneServerProcessor implements ServerProcessor {
@Override public void run() {
try {
Socket socket = server.accept();
InputStream in = socket.getInputStream();
OutputStream out = socket.getOutputStream();
final Socket socket = server.accept();
final InputStream in = socket.getInputStream();
final OutputStream out = socket.getOutputStream();
int magic = Utils.wrapBytes(Utils.recvAll(in, 4)).getInt();
if (magic != RPC.RPC_MAGIC) {
Utils.closeQuietly(socket);
......@@ -66,12 +63,10 @@ public class StandaloneServerProcessor implements ServerProcessor {
out.write(Utils.toBytes(serverKey));
}
SocketChannel sockChannel = new SocketChannel(socket);
System.err.println("Connection from " + socket.getRemoteSocketAddress().toString());
final int sockFd = socketFileDescriptorGetter.get(socket);
if (sockFd != -1) {
new NativeServerLoop(sockFd).run();
System.err.println("Finish serving " + socket.getRemoteSocketAddress().toString());
}
new NativeServerLoop(sockChannel.getFsend(), sockChannel.getFrecv()).run();
System.err.println("Finish serving " + socket.getRemoteSocketAddress().toString());
Utils.closeQuietly(socket);
} catch (Throwable e) {
e.printStackTrace();
......
......@@ -17,7 +17,10 @@
package ml.dmlc.tvm.contrib;
import ml.dmlc.tvm.*;
import ml.dmlc.tvm.Module;
import ml.dmlc.tvm.NDArray;
import ml.dmlc.tvm.TVMContext;
import ml.dmlc.tvm.TestUtils;
import ml.dmlc.tvm.rpc.Client;
import ml.dmlc.tvm.rpc.RPCSession;
import ml.dmlc.tvm.rpc.Server;
......
......@@ -164,8 +164,8 @@
<artifactId>maven-compiler-plugin</artifactId>
<version>3.3</version>
<configuration>
<source>1.6</source>
<target>1.6</target>
<source>1.7</source>
<target>1.7</target>
<encoding>UTF-8</encoding>
</configuration>
</plugin>
......
......@@ -230,7 +230,7 @@ class TCPEventHandler(tornado_util.TCPHandler):
port, matchkey = args[2]
self.pending_matchkeys.add(matchkey)
# got custom address (from rpc server)
if args[3] is not None:
if len(args) >= 4 and args[3] is not None:
value = (self, args[3], port, matchkey)
else:
value = (self, self._addr[0], port, matchkey)
......
......@@ -27,8 +27,10 @@
#define TVM_COMMON_SOCKET_H_
#if defined(_WIN32)
#define NOMINMAX
#include <winsock2.h>
#include <ws2tcpip.h>
#undef NOMINMAX
using ssize_t = int;
#ifdef _MSC_VER
#pragma comment(lib, "Ws2_32.lib")
......
......@@ -29,32 +29,14 @@
namespace tvm {
namespace runtime {
class CallbackChannel final : public RPCChannel {
public:
explicit CallbackChannel(PackedFunc fsend)
: fsend_(fsend) {}
size_t Send(const void* data, size_t size) final {
TVMByteArray bytes;
bytes.data = static_cast<const char*>(data);
bytes.size = size;
uint64_t ret = fsend_(bytes);
return static_cast<size_t>(ret);
}
size_t Recv(void* data, size_t size) final {
LOG(FATAL) << "Do not allow explicit receive for";
return 0;
}
private:
PackedFunc fsend_;
};
PackedFunc CreateEventDrivenServer(PackedFunc fsend,
std::string name,
std::string remote_key) {
std::unique_ptr<CallbackChannel> ch(new CallbackChannel(fsend));
static PackedFunc frecv([](TVMArgs args, TVMRetValue* rv) {
LOG(FATAL) << "Do not allow explicit receive";
return 0;
});
std::unique_ptr<CallbackChannel> ch(new CallbackChannel(fsend, frecv));
std::shared_ptr<RPCSession> sess =
RPCSession::Create(std::move(ch), name, remote_key);
return PackedFunc([sess](TVMArgs args, TVMRetValue* rv) {
......
......@@ -36,6 +36,7 @@
#include <algorithm>
#include "rpc_session.h"
#include "../../common/ring_buffer.h"
#include "../../common/socket.h"
namespace tvm {
namespace runtime {
......@@ -1260,5 +1261,26 @@ PackedFunc WrapTimeEvaluator(PackedFunc pf,
return PackedFunc(ftimer);
}
size_t CallbackChannel::Send(const void* data, size_t size) {
TVMByteArray bytes;
bytes.data = static_cast<const char*>(data);
bytes.size = size;
int64_t n = fsend_(bytes);
if (n == -1) {
common::Socket::Error("CallbackChannel::Send");
}
return static_cast<size_t>(n);
}
size_t CallbackChannel::Recv(void* data, size_t size) {
TVMRetValue ret = frecv_(size);
if (ret.type_code() != kBytes) {
common::Socket::Error("CallbackChannel::Recv");
}
std::string* bytes = ret.ptr<std::string>();
memcpy(static_cast<char*>(data), bytes->c_str(), bytes->length());
return bytes->length();
}
} // namespace runtime
} // namespace tvm
......@@ -87,7 +87,7 @@ class RPCChannel {
*/
virtual size_t Send(const void* data, size_t size) = 0;
/*!
e * \brief Recv data from channel.
* \brief Recv data from channel.
*
* \param data The data pointer.
* \param size The size fo the data.
......@@ -254,6 +254,37 @@ class RPCSession {
};
/*!
* \brief RPC channel which callback
* frontend (Python/Java/etc.)'s send & recv function
*/
class CallbackChannel final : public RPCChannel {
public:
explicit CallbackChannel(PackedFunc fsend, PackedFunc frecv)
: fsend_(std::move(fsend)), frecv_(std::move(frecv)) {}
~CallbackChannel() {}
/*!
* \brief Send data over to the channel.
* \param data The data pointer.
* \param size The size fo the data.
* \return The actual bytes sent.
*/
size_t Send(const void* data, size_t size) final;
/*!
* \brief Recv data from channel.
*
* \param data The data pointer.
* \param size The size fo the data.
* \return The actual bytes received.
*/
size_t Recv(void* data, size_t size) final;
private:
PackedFunc fsend_;
PackedFunc frecv_;
};
/*!
* \brief Wrap a timer function to measure the time cost of a given packed function.
* \param f The function argument.
* \param ctx The context.
......
......@@ -36,7 +36,7 @@ class SockChannel final : public RPCChannel {
: sock_(sock) {}
~SockChannel() {
if (!sock_.BadSocket()) {
sock_.Close();
sock_.Close();
}
}
size_t Send(const void* data, size_t size) final {
......@@ -109,12 +109,25 @@ void RPCServerLoop(int sockfd) {
"SockServerLoop", "")->ServerLoop();
}
void RPCServerLoop(PackedFunc fsend, PackedFunc frecv) {
RPCSession::Create(std::unique_ptr<CallbackChannel>(
new CallbackChannel(fsend, frecv)),
"SockServerLoop", "")->ServerLoop();
}
TVM_REGISTER_GLOBAL("rpc._Connect")
.set_body_typed(RPCClientConnect);
TVM_REGISTER_GLOBAL("rpc._ServerLoop")
.set_body([](TVMArgs args, TVMRetValue* rv) {
RPCServerLoop(args[0]);
if (args.size() == 1) {
RPCServerLoop(args[0]);
} else {
CHECK_EQ(args.size(), 2);
RPCServerLoop(
args[0].operator tvm::runtime::PackedFunc(),
args[1].operator tvm::runtime::PackedFunc());
}
});
} // namespace runtime
} // namespace tvm
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment