Commit 5408d3a3 by Yizhi Liu Committed by Tianqi Chen

[rpc] use callback func to do send & recv (#4147)

* [rpc] use callback func to do send & recv. don't get fd from sock as it is deprecated in java

* fix java build

* fix min/max macro define in windows

* keep the old rpc setup for py

* add doc for CallbackChannel
parent a7404230
...@@ -30,7 +30,6 @@ public class ConnectProxyServerProcessor implements ServerProcessor { ...@@ -30,7 +30,6 @@ public class ConnectProxyServerProcessor implements ServerProcessor {
private final String host; private final String host;
private final int port; private final int port;
private final String key; private final String key;
private final SocketFileDescriptorGetter socketFileDescriptorGetter;
private volatile Socket currSocket = new Socket(); private volatile Socket currSocket = new Socket();
private Runnable callback; private Runnable callback;
...@@ -40,14 +39,11 @@ public class ConnectProxyServerProcessor implements ServerProcessor { ...@@ -40,14 +39,11 @@ public class ConnectProxyServerProcessor implements ServerProcessor {
* @param host Proxy server host. * @param host Proxy server host.
* @param port Proxy server port. * @param port Proxy server port.
* @param key Proxy server key. * @param key Proxy server key.
* @param sockFdGetter Method to get file descriptor from Java socket.
*/ */
public ConnectProxyServerProcessor(String host, int port, String key, public ConnectProxyServerProcessor(String host, int port, String key) {
SocketFileDescriptorGetter sockFdGetter) {
this.host = host; this.host = host;
this.port = port; this.port = port;
this.key = "server:" + key; this.key = "server:" + key;
socketFileDescriptorGetter = sockFdGetter;
} }
/** /**
...@@ -70,8 +66,8 @@ public class ConnectProxyServerProcessor implements ServerProcessor { ...@@ -70,8 +66,8 @@ public class ConnectProxyServerProcessor implements ServerProcessor {
try { try {
SocketAddress address = new InetSocketAddress(host, port); SocketAddress address = new InetSocketAddress(host, port);
currSocket.connect(address, 6000); currSocket.connect(address, 6000);
InputStream in = currSocket.getInputStream(); final InputStream in = currSocket.getInputStream();
OutputStream out = currSocket.getOutputStream(); final OutputStream out = currSocket.getOutputStream();
out.write(Utils.toBytes(RPC.RPC_MAGIC)); out.write(Utils.toBytes(RPC.RPC_MAGIC));
out.write(Utils.toBytes(key.length())); out.write(Utils.toBytes(key.length()));
out.write(Utils.toBytes(key)); out.write(Utils.toBytes(key));
...@@ -91,11 +87,10 @@ public class ConnectProxyServerProcessor implements ServerProcessor { ...@@ -91,11 +87,10 @@ public class ConnectProxyServerProcessor implements ServerProcessor {
if (callback != null) { if (callback != null) {
callback.run(); callback.run();
} }
final int sockFd = socketFileDescriptorGetter.get(currSocket);
if (sockFd != -1) { SocketChannel sockChannel = new SocketChannel(currSocket);
new NativeServerLoop(sockFd).run(); new NativeServerLoop(sockChannel.getFsend(), sockChannel.getFrecv()).run();
System.err.println("Finish serving " + address); System.err.println("Finish serving " + address);
}
} catch (Throwable e) { } catch (Throwable e) {
e.printStackTrace(); e.printStackTrace();
throw new RuntimeException(e); throw new RuntimeException(e);
......
...@@ -37,7 +37,6 @@ import java.net.SocketTimeoutException; ...@@ -37,7 +37,6 @@ import java.net.SocketTimeoutException;
*/ */
public class ConnectTrackerServerProcessor implements ServerProcessor { public class ConnectTrackerServerProcessor implements ServerProcessor {
private ServerSocket server; private ServerSocket server;
private final SocketFileDescriptorGetter socketFileDescriptorGetter;
private final String trackerHost; private final String trackerHost;
private final int trackerPort; private final int trackerPort;
// device key // device key
...@@ -62,10 +61,11 @@ public class ConnectTrackerServerProcessor implements ServerProcessor { ...@@ -62,10 +61,11 @@ public class ConnectTrackerServerProcessor implements ServerProcessor {
* @param trackerHost Tracker host. * @param trackerHost Tracker host.
* @param trackerPort Tracker port. * @param trackerPort Tracker port.
* @param key Device key. * @param key Device key.
* @param sockFdGetter Method to get file descriptor from Java socket. * @param watchdog watch for timeout, etc.
* @throws java.io.IOException when socket fails to open.
*/ */
public ConnectTrackerServerProcessor(String trackerHost, int trackerPort, String key, public ConnectTrackerServerProcessor(String trackerHost, int trackerPort, String key,
SocketFileDescriptorGetter sockFdGetter, RPCWatchdog watchdog) throws IOException { RPCWatchdog watchdog) throws IOException {
while (true) { while (true) {
try { try {
this.server = new ServerSocket(serverPort); this.server = new ServerSocket(serverPort);
...@@ -81,7 +81,6 @@ public class ConnectTrackerServerProcessor implements ServerProcessor { ...@@ -81,7 +81,6 @@ public class ConnectTrackerServerProcessor implements ServerProcessor {
} }
} }
System.err.println("using port: " + serverPort); System.err.println("using port: " + serverPort);
this.socketFileDescriptorGetter = sockFdGetter;
this.trackerHost = trackerHost; this.trackerHost = trackerHost;
this.trackerPort = trackerPort; this.trackerPort = trackerPort;
this.key = key; this.key = key;
...@@ -163,11 +162,9 @@ public class ConnectTrackerServerProcessor implements ServerProcessor { ...@@ -163,11 +162,9 @@ public class ConnectTrackerServerProcessor implements ServerProcessor {
System.err.println("Connection from " + socket.getRemoteSocketAddress().toString()); System.err.println("Connection from " + socket.getRemoteSocketAddress().toString());
// received timeout in seconds // received timeout in seconds
watchdog.startTimeout(timeout * 1000); watchdog.startTimeout(timeout * 1000);
final int sockFd = socketFileDescriptorGetter.get(socket); SocketChannel sockChannel = new SocketChannel(socket);
if (sockFd != -1) { new NativeServerLoop(sockChannel.getFsend(), sockChannel.getFrecv()).run();
new NativeServerLoop(sockFd).run(); System.err.println("Finish serving " + socket.getRemoteSocketAddress().toString());
System.err.println("Finish serving " + socket.getRemoteSocketAddress().toString());
}
Utils.closeQuietly(socket); Utils.closeQuietly(socket);
} catch (ConnectException e) { } catch (ConnectException e) {
// if the tracker connection failed, wait a bit before retrying // if the tracker connection failed, wait a bit before retrying
......
...@@ -28,14 +28,17 @@ import java.io.IOException; ...@@ -28,14 +28,17 @@ import java.io.IOException;
* Call native ServerLoop on socket file descriptor. * Call native ServerLoop on socket file descriptor.
*/ */
public class NativeServerLoop implements Runnable { public class NativeServerLoop implements Runnable {
private final int sockFd; private final Function fsend;
private final Function frecv;
/** /**
* Constructor for NativeServerLoop. * Constructor for NativeServerLoop.
* @param nativeSockFd native socket file descriptor. * @param fsend socket.send function.
* @param frecv socket.recv function.
*/ */
public NativeServerLoop(final int nativeSockFd) { public NativeServerLoop(final Function fsend, final Function frecv) {
sockFd = nativeSockFd; this.fsend = fsend;
this.frecv = frecv;
} }
@Override public void run() { @Override public void run() {
...@@ -43,7 +46,7 @@ public class NativeServerLoop implements Runnable { ...@@ -43,7 +46,7 @@ public class NativeServerLoop implements Runnable {
try { try {
tempDir = serverEnv(); tempDir = serverEnv();
System.err.println("starting server loop..."); System.err.println("starting server loop...");
RPC.getApi("_ServerLoop").pushArg(sockFd).invoke(); RPC.getApi("_ServerLoop").pushArg(fsend).pushArg(frecv).invoke();
System.err.println("done server loop..."); System.err.println("done server loop...");
} catch (IOException e) { } catch (IOException e) {
e.printStackTrace(); e.printStackTrace();
......
...@@ -200,6 +200,7 @@ public class RPCSession { ...@@ -200,6 +200,7 @@ public class RPCSession {
* Upload file to remote runtime temp folder. * Upload file to remote runtime temp folder.
* @param data The file in local to upload. * @param data The file in local to upload.
* @param target The path in remote. * @param target The path in remote.
* @throws java.io.IOException for network failure.
*/ */
public void upload(File data, String target) throws IOException { public void upload(File data, String target) throws IOException {
byte[] blob = getBytesFromFile(data); byte[] blob = getBytesFromFile(data);
...@@ -209,6 +210,7 @@ public class RPCSession { ...@@ -209,6 +210,7 @@ public class RPCSession {
/** /**
* Upload file to remote runtime temp folder. * Upload file to remote runtime temp folder.
* @param data The file in local to upload. * @param data The file in local to upload.
* @throws java.io.IOException for network failure.
*/ */
public void upload(File data) throws IOException { public void upload(File data) throws IOException {
upload(data, data.getName()); upload(data, data.getName());
......
...@@ -17,31 +17,12 @@ ...@@ -17,31 +17,12 @@
package ml.dmlc.tvm.rpc; package ml.dmlc.tvm.rpc;
import sun.misc.SharedSecrets;
import java.io.FileDescriptor;
import java.io.FileInputStream;
import java.io.IOException; import java.io.IOException;
import java.io.InputStream;
import java.net.Socket;
/** /**
* RPC Server. * RPC Server.
*/ */
public class Server { public class Server {
private static SocketFileDescriptorGetter defaultSocketFdGetter
= new SocketFileDescriptorGetter() {
@Override public int get(Socket socket) {
try {
InputStream is = socket.getInputStream();
FileDescriptor fd = ((FileInputStream) is).getFD();
return SharedSecrets.getJavaIOFileDescriptorAccess().get(fd);
} catch (IOException e) {
e.printStackTrace();
return -1;
}
}
};
private final WorkerThread worker; private final WorkerThread worker;
private static class WorkerThread extends Thread { private static class WorkerThread extends Thread {
...@@ -72,35 +53,10 @@ public class Server { ...@@ -72,35 +53,10 @@ public class Server {
/** /**
* Start a standalone server. * Start a standalone server.
* @param serverPort Port. * @param serverPort Port.
* @param socketFdGetter Method to get system file descriptor of the server socket.
* @throws IOException if failed to bind localhost:port.
*/
public Server(int serverPort, SocketFileDescriptorGetter socketFdGetter) throws IOException {
worker = new WorkerThread(new StandaloneServerProcessor(serverPort, socketFdGetter));
}
/**
* Start a standalone server.
* Use sun.misc.SharedSecrets.getJavaIOFileDescriptorAccess
* to get file descriptor for the socket.
* @param serverPort Port.
* @throws IOException if failed to bind localhost:port. * @throws IOException if failed to bind localhost:port.
*/ */
public Server(int serverPort) throws IOException { public Server(int serverPort) throws IOException {
this(serverPort, defaultSocketFdGetter); worker = new WorkerThread(new StandaloneServerProcessor(serverPort));
}
/**
* Start a server connected to proxy.
* @param proxyHost The proxy server host.
* @param proxyPort The proxy server port.
* @param key The key to identify the server.
* @param socketFdGetter Method to get system file descriptor of the server socket.
*/
public Server(String proxyHost, int proxyPort, String key,
SocketFileDescriptorGetter socketFdGetter) {
worker = new WorkerThread(
new ConnectProxyServerProcessor(proxyHost, proxyPort, key, socketFdGetter));
} }
/** /**
...@@ -112,7 +68,8 @@ public class Server { ...@@ -112,7 +68,8 @@ public class Server {
* @param key The key to identify the server. * @param key The key to identify the server.
*/ */
public Server(String proxyHost, int proxyPort, String key) { public Server(String proxyHost, int proxyPort, String key) {
this(proxyHost, proxyPort, key, defaultSocketFdGetter); worker = new WorkerThread(
new ConnectProxyServerProcessor(proxyHost, proxyPort, key));
} }
/** /**
......
package ml.dmlc.tvm.rpc;
import ml.dmlc.tvm.Function;
import ml.dmlc.tvm.TVMValue;
import ml.dmlc.tvm.TVMValueBytes;
import java.io.IOException;
import java.net.Socket;
public class SocketChannel {
private final Socket socket;
SocketChannel(Socket sock) {
socket = sock;
}
private Function fsend = Function.convertFunc(new Function.Callback() {
@Override public Object invoke(TVMValue... args) {
byte[] data = args[0].asBytes();
try {
socket.getOutputStream().write(data);
} catch (IOException e) {
e.printStackTrace();
return -1;
}
return data.length;
}
});
private Function frecv = Function.convertFunc(new Function.Callback() {
@Override public Object invoke(TVMValue... args) {
long size = args[0].asLong();
try {
return new TVMValueBytes(Utils.recvAll(socket.getInputStream(), (int) size));
} catch (IOException e) {
e.printStackTrace();
return -1;
}
}
});
public Function getFsend() {
return fsend;
}
public Function getFrecv() {
return frecv;
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.dmlc.tvm.rpc;
import java.net.Socket;
/**
* Interface for defining different socket fd getter.
*/
public interface SocketFileDescriptorGetter {
/**
* Get native socket file descriptor.
* @param socket Java socket.
* @return native socket fd.
*/
public int get(Socket socket);
}
...@@ -28,12 +28,9 @@ import java.net.Socket; ...@@ -28,12 +28,9 @@ import java.net.Socket;
*/ */
public class StandaloneServerProcessor implements ServerProcessor { public class StandaloneServerProcessor implements ServerProcessor {
private final ServerSocket server; private final ServerSocket server;
private final SocketFileDescriptorGetter socketFileDescriptorGetter;
public StandaloneServerProcessor(int serverPort, public StandaloneServerProcessor(int serverPort) throws IOException {
SocketFileDescriptorGetter sockFdGetter) throws IOException {
this.server = new ServerSocket(serverPort); this.server = new ServerSocket(serverPort);
this.socketFileDescriptorGetter = sockFdGetter;
} }
@Override public void terminate() { @Override public void terminate() {
...@@ -46,9 +43,9 @@ public class StandaloneServerProcessor implements ServerProcessor { ...@@ -46,9 +43,9 @@ public class StandaloneServerProcessor implements ServerProcessor {
@Override public void run() { @Override public void run() {
try { try {
Socket socket = server.accept(); final Socket socket = server.accept();
InputStream in = socket.getInputStream(); final InputStream in = socket.getInputStream();
OutputStream out = socket.getOutputStream(); final OutputStream out = socket.getOutputStream();
int magic = Utils.wrapBytes(Utils.recvAll(in, 4)).getInt(); int magic = Utils.wrapBytes(Utils.recvAll(in, 4)).getInt();
if (magic != RPC.RPC_MAGIC) { if (magic != RPC.RPC_MAGIC) {
Utils.closeQuietly(socket); Utils.closeQuietly(socket);
...@@ -66,12 +63,10 @@ public class StandaloneServerProcessor implements ServerProcessor { ...@@ -66,12 +63,10 @@ public class StandaloneServerProcessor implements ServerProcessor {
out.write(Utils.toBytes(serverKey)); out.write(Utils.toBytes(serverKey));
} }
SocketChannel sockChannel = new SocketChannel(socket);
System.err.println("Connection from " + socket.getRemoteSocketAddress().toString()); System.err.println("Connection from " + socket.getRemoteSocketAddress().toString());
final int sockFd = socketFileDescriptorGetter.get(socket); new NativeServerLoop(sockChannel.getFsend(), sockChannel.getFrecv()).run();
if (sockFd != -1) { System.err.println("Finish serving " + socket.getRemoteSocketAddress().toString());
new NativeServerLoop(sockFd).run();
System.err.println("Finish serving " + socket.getRemoteSocketAddress().toString());
}
Utils.closeQuietly(socket); Utils.closeQuietly(socket);
} catch (Throwable e) { } catch (Throwable e) {
e.printStackTrace(); e.printStackTrace();
......
...@@ -17,7 +17,10 @@ ...@@ -17,7 +17,10 @@
package ml.dmlc.tvm.contrib; package ml.dmlc.tvm.contrib;
import ml.dmlc.tvm.*; import ml.dmlc.tvm.Module;
import ml.dmlc.tvm.NDArray;
import ml.dmlc.tvm.TVMContext;
import ml.dmlc.tvm.TestUtils;
import ml.dmlc.tvm.rpc.Client; import ml.dmlc.tvm.rpc.Client;
import ml.dmlc.tvm.rpc.RPCSession; import ml.dmlc.tvm.rpc.RPCSession;
import ml.dmlc.tvm.rpc.Server; import ml.dmlc.tvm.rpc.Server;
......
...@@ -164,8 +164,8 @@ ...@@ -164,8 +164,8 @@
<artifactId>maven-compiler-plugin</artifactId> <artifactId>maven-compiler-plugin</artifactId>
<version>3.3</version> <version>3.3</version>
<configuration> <configuration>
<source>1.6</source> <source>1.7</source>
<target>1.6</target> <target>1.7</target>
<encoding>UTF-8</encoding> <encoding>UTF-8</encoding>
</configuration> </configuration>
</plugin> </plugin>
......
...@@ -230,7 +230,7 @@ class TCPEventHandler(tornado_util.TCPHandler): ...@@ -230,7 +230,7 @@ class TCPEventHandler(tornado_util.TCPHandler):
port, matchkey = args[2] port, matchkey = args[2]
self.pending_matchkeys.add(matchkey) self.pending_matchkeys.add(matchkey)
# got custom address (from rpc server) # got custom address (from rpc server)
if args[3] is not None: if len(args) >= 4 and args[3] is not None:
value = (self, args[3], port, matchkey) value = (self, args[3], port, matchkey)
else: else:
value = (self, self._addr[0], port, matchkey) value = (self, self._addr[0], port, matchkey)
......
...@@ -27,8 +27,10 @@ ...@@ -27,8 +27,10 @@
#define TVM_COMMON_SOCKET_H_ #define TVM_COMMON_SOCKET_H_
#if defined(_WIN32) #if defined(_WIN32)
#define NOMINMAX
#include <winsock2.h> #include <winsock2.h>
#include <ws2tcpip.h> #include <ws2tcpip.h>
#undef NOMINMAX
using ssize_t = int; using ssize_t = int;
#ifdef _MSC_VER #ifdef _MSC_VER
#pragma comment(lib, "Ws2_32.lib") #pragma comment(lib, "Ws2_32.lib")
......
...@@ -29,32 +29,14 @@ ...@@ -29,32 +29,14 @@
namespace tvm { namespace tvm {
namespace runtime { namespace runtime {
class CallbackChannel final : public RPCChannel {
public:
explicit CallbackChannel(PackedFunc fsend)
: fsend_(fsend) {}
size_t Send(const void* data, size_t size) final {
TVMByteArray bytes;
bytes.data = static_cast<const char*>(data);
bytes.size = size;
uint64_t ret = fsend_(bytes);
return static_cast<size_t>(ret);
}
size_t Recv(void* data, size_t size) final {
LOG(FATAL) << "Do not allow explicit receive for";
return 0;
}
private:
PackedFunc fsend_;
};
PackedFunc CreateEventDrivenServer(PackedFunc fsend, PackedFunc CreateEventDrivenServer(PackedFunc fsend,
std::string name, std::string name,
std::string remote_key) { std::string remote_key) {
std::unique_ptr<CallbackChannel> ch(new CallbackChannel(fsend)); static PackedFunc frecv([](TVMArgs args, TVMRetValue* rv) {
LOG(FATAL) << "Do not allow explicit receive";
return 0;
});
std::unique_ptr<CallbackChannel> ch(new CallbackChannel(fsend, frecv));
std::shared_ptr<RPCSession> sess = std::shared_ptr<RPCSession> sess =
RPCSession::Create(std::move(ch), name, remote_key); RPCSession::Create(std::move(ch), name, remote_key);
return PackedFunc([sess](TVMArgs args, TVMRetValue* rv) { return PackedFunc([sess](TVMArgs args, TVMRetValue* rv) {
......
...@@ -36,6 +36,7 @@ ...@@ -36,6 +36,7 @@
#include <algorithm> #include <algorithm>
#include "rpc_session.h" #include "rpc_session.h"
#include "../../common/ring_buffer.h" #include "../../common/ring_buffer.h"
#include "../../common/socket.h"
namespace tvm { namespace tvm {
namespace runtime { namespace runtime {
...@@ -1260,5 +1261,26 @@ PackedFunc WrapTimeEvaluator(PackedFunc pf, ...@@ -1260,5 +1261,26 @@ PackedFunc WrapTimeEvaluator(PackedFunc pf,
return PackedFunc(ftimer); return PackedFunc(ftimer);
} }
size_t CallbackChannel::Send(const void* data, size_t size) {
TVMByteArray bytes;
bytes.data = static_cast<const char*>(data);
bytes.size = size;
int64_t n = fsend_(bytes);
if (n == -1) {
common::Socket::Error("CallbackChannel::Send");
}
return static_cast<size_t>(n);
}
size_t CallbackChannel::Recv(void* data, size_t size) {
TVMRetValue ret = frecv_(size);
if (ret.type_code() != kBytes) {
common::Socket::Error("CallbackChannel::Recv");
}
std::string* bytes = ret.ptr<std::string>();
memcpy(static_cast<char*>(data), bytes->c_str(), bytes->length());
return bytes->length();
}
} // namespace runtime } // namespace runtime
} // namespace tvm } // namespace tvm
...@@ -87,7 +87,7 @@ class RPCChannel { ...@@ -87,7 +87,7 @@ class RPCChannel {
*/ */
virtual size_t Send(const void* data, size_t size) = 0; virtual size_t Send(const void* data, size_t size) = 0;
/*! /*!
e * \brief Recv data from channel. * \brief Recv data from channel.
* *
* \param data The data pointer. * \param data The data pointer.
* \param size The size fo the data. * \param size The size fo the data.
...@@ -254,6 +254,37 @@ class RPCSession { ...@@ -254,6 +254,37 @@ class RPCSession {
}; };
/*! /*!
* \brief RPC channel which callback
* frontend (Python/Java/etc.)'s send & recv function
*/
class CallbackChannel final : public RPCChannel {
public:
explicit CallbackChannel(PackedFunc fsend, PackedFunc frecv)
: fsend_(std::move(fsend)), frecv_(std::move(frecv)) {}
~CallbackChannel() {}
/*!
* \brief Send data over to the channel.
* \param data The data pointer.
* \param size The size fo the data.
* \return The actual bytes sent.
*/
size_t Send(const void* data, size_t size) final;
/*!
* \brief Recv data from channel.
*
* \param data The data pointer.
* \param size The size fo the data.
* \return The actual bytes received.
*/
size_t Recv(void* data, size_t size) final;
private:
PackedFunc fsend_;
PackedFunc frecv_;
};
/*!
* \brief Wrap a timer function to measure the time cost of a given packed function. * \brief Wrap a timer function to measure the time cost of a given packed function.
* \param f The function argument. * \param f The function argument.
* \param ctx The context. * \param ctx The context.
......
...@@ -36,7 +36,7 @@ class SockChannel final : public RPCChannel { ...@@ -36,7 +36,7 @@ class SockChannel final : public RPCChannel {
: sock_(sock) {} : sock_(sock) {}
~SockChannel() { ~SockChannel() {
if (!sock_.BadSocket()) { if (!sock_.BadSocket()) {
sock_.Close(); sock_.Close();
} }
} }
size_t Send(const void* data, size_t size) final { size_t Send(const void* data, size_t size) final {
...@@ -109,12 +109,25 @@ void RPCServerLoop(int sockfd) { ...@@ -109,12 +109,25 @@ void RPCServerLoop(int sockfd) {
"SockServerLoop", "")->ServerLoop(); "SockServerLoop", "")->ServerLoop();
} }
void RPCServerLoop(PackedFunc fsend, PackedFunc frecv) {
RPCSession::Create(std::unique_ptr<CallbackChannel>(
new CallbackChannel(fsend, frecv)),
"SockServerLoop", "")->ServerLoop();
}
TVM_REGISTER_GLOBAL("rpc._Connect") TVM_REGISTER_GLOBAL("rpc._Connect")
.set_body_typed(RPCClientConnect); .set_body_typed(RPCClientConnect);
TVM_REGISTER_GLOBAL("rpc._ServerLoop") TVM_REGISTER_GLOBAL("rpc._ServerLoop")
.set_body([](TVMArgs args, TVMRetValue* rv) { .set_body([](TVMArgs args, TVMRetValue* rv) {
RPCServerLoop(args[0]); if (args.size() == 1) {
RPCServerLoop(args[0]);
} else {
CHECK_EQ(args.size(), 2);
RPCServerLoop(
args[0].operator tvm::runtime::PackedFunc(),
args[1].operator tvm::runtime::PackedFunc());
}
}); });
} // namespace runtime } // namespace runtime
} // namespace tvm } // namespace tvm
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment