Commit 7d67e473 by Yizhi Liu Committed by Tianqi Chen

[tvm4j] RPC Server (#268)

* [tvm4j] RPC Server

* [tvm4j] fix recursively function calling; connect to proxy server; osx rename .so to .dylib

* [tvm4j] test case for proxy connection; thread pool for serving
parent 1146495f
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
<includeBaseDirectory>false</includeBaseDirectory> <includeBaseDirectory>false</includeBaseDirectory>
<files> <files>
<file> <file>
<source>../../../lib/libtvm_runtime.so</source> <source>../../../lib/libtvm_runtime.dylib</source>
<outputDirectory>lib/native</outputDirectory> <outputDirectory>lib/native</outputDirectory>
<fileMode>0644</fileMode> <fileMode>0644</fileMode>
</file> </file>
......
...@@ -20,18 +20,21 @@ ...@@ -20,18 +20,21 @@
<id>osx-x86_64-cpu</id> <id>osx-x86_64-cpu</id>
<properties> <properties>
<platform>osx-x86_64-cpu</platform> <platform>osx-x86_64-cpu</platform>
<libtvm.so.filename>libtvm_runtime.dylib</libtvm.so.filename>
</properties> </properties>
</profile> </profile>
<profile> <profile>
<id>linux-x86_64-cpu</id> <id>linux-x86_64-cpu</id>
<properties> <properties>
<platform>linux-x86_64-cpu</platform> <platform>linux-x86_64-cpu</platform>
<libtvm.so.filename>libtvm_runtime.so</libtvm.so.filename>
</properties> </properties>
</profile> </profile>
<profile> <profile>
<id>linux-x86_64-gpu</id> <id>linux-x86_64-gpu</id>
<properties> <properties>
<platform>linux-x86_64-gpu</platform> <platform>linux-x86_64-gpu</platform>
<libtvm.so.filename>libtvm_runtime.so</libtvm.so.filename>
</properties> </properties>
</profile> </profile>
</profiles> </profiles>
...@@ -88,7 +91,7 @@ ...@@ -88,7 +91,7 @@
<threadCount>1</threadCount> <threadCount>1</threadCount>
<argLine> <argLine>
-Djava.library.path=${project.parent.basedir}/native/${platform}/target -Djava.library.path=${project.parent.basedir}/native/${platform}/target
-Dlibtvm.so.path=${project.parent.basedir}/../lib/libtvm_runtime.so -Dlibtvm.so.path=${project.parent.basedir}/../lib/${libtvm.so.filename}
</argLine> </argLine>
</configuration> </configuration>
<executions> <executions>
......
...@@ -80,7 +80,18 @@ final class Base { ...@@ -80,7 +80,18 @@ final class Base {
if (tvmLibFilename == null || !new File(tvmLibFilename).isFile() if (tvmLibFilename == null || !new File(tvmLibFilename).isFile()
|| _LIB.nativeLibInit(tvmLibFilename) != 0) { || _LIB.nativeLibInit(tvmLibFilename) != 0) {
try { try {
NativeLibraryLoader.extractResourceFileToTempDir("libtvm_runtime.so", new Action() { String runtimeLibname;
String os = System.getProperty("os.name");
// ref: http://lopica.sourceforge.net/os.html
if (os.startsWith("Linux")) {
runtimeLibname = "libtvm_runtime.so";
} else if (os.startsWith("Mac")) {
runtimeLibname = "libtvm_runtime.dylib";
} else {
// TODO(yizhi) support windows later
throw new UnsatisfiedLinkError("Windows not supported currently");
}
NativeLibraryLoader.extractResourceFileToTempDir(runtimeLibname, new Action() {
@Override public void invoke(File target) { @Override public void invoke(File target) {
System.err.println("Loading tvm runtime from " + target.getPath()); System.err.println("Loading tvm runtime from " + target.getPath());
checkCall(_LIB.nativeLibInit(target.getPath())); checkCall(_LIB.nativeLibInit(target.getPath()));
......
...@@ -34,7 +34,7 @@ public class Function extends TVMValue { ...@@ -34,7 +34,7 @@ public class Function extends TVMValue {
* @param name full function name. * @param name full function name.
* @return TVM function. * @return TVM function.
*/ */
static Function getFunction(final String name) { public static Function getFunction(final String name) {
for (String fullName : listGlobalFuncNames()) { for (String fullName : listGlobalFuncNames()) {
if (fullName.equals(name)) { if (fullName.equals(name)) {
return getGlobalFunc(fullName, true, false); return getGlobalFunc(fullName, true, false);
......
...@@ -30,7 +30,7 @@ class NativeLibraryLoader { ...@@ -30,7 +30,7 @@ class NativeLibraryLoader {
static { static {
try { try {
tempDir = File.createTempFile("tvm", ""); tempDir = File.createTempFile("tvm4j", "");
if (!tempDir.delete() || !tempDir.mkdir()) { if (!tempDir.delete() || !tempDir.mkdir()) {
throw new IOException("Couldn't create directory " + tempDir.getAbsolutePath()); throw new IOException("Couldn't create directory " + tempDir.getAbsolutePath());
} }
......
...@@ -17,12 +17,12 @@ ...@@ -17,12 +17,12 @@
package ml.dmlc.tvm; package ml.dmlc.tvm;
import ml.dmlc.tvm.rpc.RPC;
import java.util.HashMap; import java.util.HashMap;
import java.util.Map; import java.util.Map;
public class TVMContext { public class TVMContext {
private static final int RPC_SESS_MASK = 128;
private static final Map<Integer, String> MASK2STR = new HashMap<Integer, String>(); private static final Map<Integer, String> MASK2STR = new HashMap<Integer, String>();
private static final Map<String, Integer> STR2MASK = new HashMap<String, Integer>(); private static final Map<String, Integer> STR2MASK = new HashMap<String, Integer>();
...@@ -169,9 +169,9 @@ public class TVMContext { ...@@ -169,9 +169,9 @@ public class TVMContext {
} }
@Override public String toString() { @Override public String toString() {
if (deviceType >= RPC_SESS_MASK) { if (deviceType >= RPC.RPC_SESS_MASK) {
int tblId = deviceType / RPC_SESS_MASK - 1; int tblId = deviceType / RPC.RPC_SESS_MASK - 1;
int devType = deviceType % RPC_SESS_MASK; int devType = deviceType % RPC.RPC_SESS_MASK;
return String.format("remote[%d]:%s(%d)", tblId, MASK2STR.get(devType), deviceId); return String.format("remote[%d]:%s(%d)", tblId, MASK2STR.get(devType), deviceId);
} }
return String.format("%s(%d)", MASK2STR.get(deviceType), deviceId); return String.format("%s(%d)", MASK2STR.get(deviceType), deviceId);
......
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.dmlc.tvm.rpc;
import ml.dmlc.tvm.Function;
import ml.dmlc.tvm.TVMValue;
public class Client {
/**
* Connect to RPC Server.
* @param url The url of the host.
* @param port The port to connect to.
* @param key Additional key to match server.
* @return The connected session.
*/
public static RPCSession connect(String url, int port, String key) {
Function doConnect = RPC.getApi("_Connect");
if (doConnect == null) {
throw new RuntimeException("Please compile with USE_RPC=1");
}
TVMValue sess = doConnect.pushArg(url).pushArg(port).pushArg(key).invoke();
return new RPCSession(sess.asModule());
}
/**
* Connect to RPC Server.
* @param url The url of the host.
* @param port The port to connect to.
* @return The connected session.
*/
public static RPCSession connect(String url, int port) {
return connect(url, port, "");
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.dmlc.tvm.rpc;
import ml.dmlc.tvm.Function;
import java.util.HashMap;
import java.util.Map;
public class RPC {
public static final int RPC_MAGIC = 0xff271;
public static final int RPC_SESS_MASK = 128;
private static ThreadLocal<Map<String, Function>> apiFuncs
= new ThreadLocal<Map<String, Function>>() {
@Override
protected Map<String, Function> initialValue() {
return new HashMap<String, Function>();
}
};
static Function getApi(String name) {
Function func = apiFuncs.get().get(name);
if (func == null) {
func = Function.getFunction("contrib.rpc." + name);
if (func == null) {
return null;
}
apiFuncs.get().put(name, func);
}
return func;
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.dmlc.tvm.rpc;
import ml.dmlc.tvm.Function;
import ml.dmlc.tvm.Module;
import ml.dmlc.tvm.TVMContext;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.HashMap;
import java.util.Map;
/**
* RPC Client session module.
* Do not directly create the object, use Client.connect.
*/
public class RPCSession {
private final Module session;
private final int tblIndex;
private final Map<String, Function> remoteFuncs = new HashMap<String, Function>();
RPCSession(Module sess) {
session = sess;
tblIndex = (int) RPC.getApi("_SessTableIndex").pushArg(session).invoke().asLong();
}
/**
* Get function from the session.
* @param name The name of the function.
* @return The result function.
*/
public Function getFunction(String name) {
return session.getFunction(name);
}
/**
* Construct a remote context.
* @param devType device type.
* @param devId device id.
* @return The corresponding encoded remote context.
*/
public TVMContext context(String devType, int devId) {
TVMContext ctx = new TVMContext(devType, devId);
int encode = (tblIndex + 1) * RPC.RPC_SESS_MASK;
return new TVMContext(ctx.deviceType + encode, devId);
}
/**
* Construct a remote context.
* @param devType device type.
* @return The corresponding encoded remote context.
*/
public TVMContext context(String devType) {
return context(devType, 0);
}
/**
* Construct a remote context.
* @param devType device type.
* @param devId device id.
* @return The corresponding encoded remote context.
*/
public TVMContext context(int devType, int devId) {
int encode = (tblIndex + 1) * RPC.RPC_SESS_MASK;
return new TVMContext(devType + encode, devId);
}
/**
* Construct a remote context.
* @param devType device type.
* @return The corresponding encoded remote context.
*/
public TVMContext context(int devType) {
return context(devType, 0);
}
/**
* Construct remote CPU device.
* @param devId device id.
* @return Remote CPU context.
*/
public TVMContext cpu(int devId) {
return context(1, devId);
}
/**
* Construct remote CPU device.
* @return Remote CPU context.
*/
public TVMContext cpu() {
return cpu(0);
}
/**
* Construct remote GPU device.
* @param devId device id.
* @return Remote GPU context.
*/
public TVMContext gpu(int devId) {
return context(2, devId);
}
/**
* Construct remote GPU device.
* @return Remote GPU context.
*/
public TVMContext gpu() {
return gpu(0);
}
/**
* Construct remote OpenCL device.
* @param devId device id.
* @return Remote OpenCL context.
*/
public TVMContext cl(int devId) {
return context(4, devId);
}
/**
* Construct remote OpenCL device.
* @return Remote OpenCL context.
*/
public TVMContext cl() {
return cl(0);
}
/**
* Construct remote Metal device.
* @param devId device id.
* @return Remote metal context.
*/
public TVMContext metal(int devId) {
return context(8, devId);
}
/**
* Construct remote Metal device.
* @return Remote metal context.
*/
public TVMContext metal() {
return metal(0);
}
/**
* Upload binary to remote runtime temp folder.
* @param data The binary in local to upload.
* @param target The path in remote, cannot be null.
*/
public void upload(byte[] data, String target) {
if (target == null) {
throw new IllegalArgumentException("Please specify the upload target");
}
final String funcName = "upload";
Function remoteFunc = remoteFuncs.get(funcName);
if (remoteFunc == null) {
remoteFunc = getFunction("tvm.contrib.rpc.server.upload");
remoteFuncs.put(funcName, remoteFunc);
}
remoteFunc.pushArg(target).pushArg(data).invoke();
}
/**
* Upload file to remote runtime temp folder.
* @param data The file in local to upload.
* @param target The path in remote.
*/
public void upload(File data, String target) throws IOException {
byte[] blob = getBytesFromFile(data);
upload(blob, target);
}
/**
* Upload file to remote runtime temp folder.
* @param data The file in local to upload.
*/
public void upload(File data) throws IOException {
upload(data, data.getName());
}
/**
* Download file from remote temp folder.
* @param path The relative location to remote temp folder.
* @return The result blob from the file.
*/
public byte[] download(String path) {
final String name = "download";
Function func = remoteFuncs.get(name);
if (func == null) {
func = getFunction("tvm.contrib.rpc.server.download");
remoteFuncs.put(name, func);
}
return func.pushArg(path).invoke().asBytes();
}
/**
* Load a remote module, the file need to be uploaded first.
* @param path The relative location to remote temp folder.
* @return The remote module containing remote function.
*/
public Module loadModule(String path) {
return RPC.getApi("_LoadRemoteModule").pushArg(session).pushArg(path).invoke().asModule();
}
private static byte[] getBytesFromFile(File file) throws IOException {
// Get the size of the file
long length = file.length();
if (length > Integer.MAX_VALUE) {
throw new IOException("File " + file.getName() + " is too large!");
}
// cannot create an array using a long type.
byte[] bytes = new byte[(int)length];
// Read in the bytes
int offset = 0;
int numRead = 0;
InputStream is = new FileInputStream(file);
try {
while (offset < bytes.length
&& (numRead = is.read(bytes, offset, bytes.length - offset)) >= 0) {
offset += numRead;
}
} finally {
is.close();
}
// Ensure all the bytes have been read in
if (offset < bytes.length) {
throw new IOException("Could not completely read file " + file.getName());
}
return bytes;
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.dmlc.tvm.rpc;
import ml.dmlc.tvm.Function;
import ml.dmlc.tvm.Module;
import ml.dmlc.tvm.TVMValue;
import sun.misc.SharedSecrets;
import java.io.File;
import java.io.FileDescriptor;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.ServerSocket;
import java.net.Socket;
import java.net.SocketException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
/**
* RPC Server.
*/
public class Server {
private static SocketFileDescriptorGetter defaultSocketFdGetter
= new SocketFileDescriptorGetter() {
@Override public int get(Socket socket) {
try {
InputStream is = socket.getInputStream();
FileDescriptor fd = ((FileInputStream) is).getFD();
return SharedSecrets.getJavaIOFileDescriptorAccess().get(fd);
} catch (IOException e) {
e.printStackTrace();
return -1;
}
}
};
private static final int DEFAULT_THREAD_NUMBER_IN_A_POOL = 20;
private final Loop serverLoop;
private final ExecutorService threadPool;
/**
* Start a standalone server.
* @param serverPort Port.
* @param socketFdGetter Method to get system file descriptor of the server socket.
* @throws IOException if failed to bind localhost:port.
*/
public Server(int serverPort, SocketFileDescriptorGetter socketFdGetter) throws IOException {
threadPool = setupThreadPool();
serverLoop = new ListenLoop(serverPort, threadPool, socketFdGetter);
}
/**
* Start a standalone server.
* Use sun.misc.SharedSecrets.getJavaIOFileDescriptorAccess
* to get file descriptor for the socket.
* @param serverPort Port.
* @throws IOException if failed to bind localhost:port.
*/
public Server(int serverPort) throws IOException {
this(serverPort, defaultSocketFdGetter);
}
/**
* Start a server connected to proxy.
* @param proxyHost The proxy server host.
* @param proxyPort The proxy server port.
* @param key The key to identify the server.
* @param socketFdGetter Method to get system file descriptor of the server socket.
*/
public Server(String proxyHost, int proxyPort, String key,
SocketFileDescriptorGetter socketFdGetter) {
threadPool = setupThreadPool();
serverLoop = new ConnectProxyLoop(proxyHost, proxyPort, key, threadPool, socketFdGetter);
}
/**
* Start a server connected to proxy.
* Use sun.misc.SharedSecrets.getJavaIOFileDescriptorAccess
* to get file descriptor for the socket.
* @param proxyHost The proxy server host.
* @param proxyPort The proxy server port.
* @param key The key to identify the server.
*/
public Server(String proxyHost, int proxyPort, String key) {
this(proxyHost, proxyPort, key, defaultSocketFdGetter);
}
private ExecutorService setupThreadPool() {
final String workerThreadNumber = System.getProperty("rpc.server.thread.number");
final int numThread = (workerThreadNumber == null)
? DEFAULT_THREAD_NUMBER_IN_A_POOL : Integer.parseInt(workerThreadNumber);
return Executors.newFixedThreadPool(numThread);
}
/**
* Start the server.
*/
public void start() {
serverLoop.start();
}
/**
* Stop the server.
*/
public void terminate() {
serverLoop.interrupt();
serverLoop.terminate();
threadPool.shutdown();
}
public static interface SocketFileDescriptorGetter {
public int get(Socket socket);
}
static class ServerLoop implements Runnable {
private final Socket socket;
private final SocketFileDescriptorGetter socketFdGetter;
ServerLoop(Socket socket, SocketFileDescriptorGetter fdGetter) {
this.socket = socket;
socketFdGetter = fdGetter;
}
@Override public void run() {
int sockFd = socketFdGetter.get(socket);
if (sockFd != -1) {
File tempDir = null;
try {
tempDir = serverEnv();
RPC.getApi("_ServerLoop").pushArg(sockFd).invoke();
System.err.println("Finish serving " + socket.getRemoteSocketAddress().toString());
} catch (IOException e) {
e.printStackTrace();
} finally {
if (tempDir != null) {
if (!tempDir.delete()) {
System.err.println(
"[WARN] Couldn't delete temporary directory " + tempDir.getAbsolutePath());
}
}
closeQuietly(socket);
}
}
}
private File serverEnv() throws IOException {
// Server environment function return temp dir.
final File tempDir = File.createTempFile("tvm4j_rpc_", "");
if (!tempDir.delete() || !tempDir.mkdir()) {
throw new IOException("Couldn't create directory " + tempDir.getAbsolutePath());
}
Function.register("tvm.contrib.rpc.server.workpath", new Function.Callback() {
@Override public Object invoke(TVMValue... args) {
return tempDir + File.separator + args[0].asString();
}
}, true);
Function.register("tvm.contrib.rpc.server.load_module", new Function.Callback() {
@Override public Object invoke(TVMValue... args) {
String filename = args[0].asString();
String path = tempDir + File.separator + filename;
System.err.println("Load module from " + path);
return Module.load(path);
}
}, true);
return tempDir;
}
}
abstract static class Loop extends Thread {
public abstract void terminate();
}
static class ConnectProxyLoop extends Loop {
private volatile boolean running = true;
private final String host;
private final int port;
private final String key;
private final ExecutorService workerPool;
private final SocketFileDescriptorGetter socketFileDescriptorGetter;
private Socket waitingSocket = null;
public ConnectProxyLoop(String host, int port, String key,
ExecutorService workerPool,
SocketFileDescriptorGetter sockFdGetter) {
this.host = host;
this.port = port;
this.key = "server:" + key;
this.workerPool = workerPool;
socketFileDescriptorGetter = sockFdGetter;
}
@Override public void terminate() {
running = false;
if (waitingSocket != null) {
try {
waitingSocket.close();
} catch (IOException e) {
e.printStackTrace();
}
}
}
@Override public void run() {
while (running) {
try {
Socket socket = new Socket(host, port);
waitingSocket = socket;
InputStream in = socket.getInputStream();
OutputStream out = socket.getOutputStream();
out.write(toBytes(RPC.RPC_MAGIC));
out.write(toBytes(key.length()));
out.write(toBytes(key));
int magic = wrapBytes(recvAll(in, 4)).getInt();
final String address = host + ":" + port;
if (magic == RPC.RPC_MAGIC + 1) {
throw new RuntimeException(
String.format("key: %s has already been used in proxy", key));
} else if (magic == RPC.RPC_MAGIC + 2) {
System.err.println("RPCProxy do not have matching client key " + key);
} else if (magic != RPC.RPC_MAGIC) {
throw new RuntimeException(address + " is not RPC Proxy");
}
System.err.println("RPCProxy connected to " + address);
waitingSocket = null;
workerPool.execute(new ServerLoop(socket, socketFileDescriptorGetter));
} catch (SocketException e) {
// when terminates, this is what we expect, do nothing.
} catch (IOException e) {
e.printStackTrace();
terminate();
}
}
}
}
static class ListenLoop extends Loop {
private final ServerSocket server;
private final ExecutorService workerPool;
private final SocketFileDescriptorGetter socketFileDescriptorGetter;
private volatile boolean running = true;
public ListenLoop(int serverPort, ExecutorService workerPool,
SocketFileDescriptorGetter sockFdGetter) throws IOException {
this.server = new ServerSocket(serverPort);
this.workerPool = workerPool;
this.socketFileDescriptorGetter = sockFdGetter;
}
@Override public void terminate() {
this.running = false;
try {
server.close();
} catch (IOException e) {
e.printStackTrace();
}
}
@Override public void run() {
while (running) {
try {
Socket socket = server.accept();
InputStream in = socket.getInputStream();
OutputStream out = socket.getOutputStream();
int magic = wrapBytes(recvAll(in, 4)).getInt();
if (magic != RPC.RPC_MAGIC) {
closeQuietly(socket);
continue;
}
int keyLen = wrapBytes(recvAll(in, 4)).getInt();
String key = decodeToStr(recvAll(in, keyLen));
if (!key.startsWith("client:")) {
out.write(toBytes(RPC.RPC_MAGIC + 2));
} else {
out.write(toBytes(RPC.RPC_MAGIC));
}
System.err.println("Connection from " + socket.getRemoteSocketAddress().toString());
workerPool.execute(new ServerLoop(socket, socketFileDescriptorGetter));
} catch (SocketException e) {
// when terminates, this is what we expect, do nothing.
} catch (IOException e) {
e.printStackTrace();
terminate();
}
}
}
}
private static byte[] recvAll(final InputStream in, final int numBytes) throws IOException {
byte[] res = new byte[numBytes];
int numRead = 0;
while (numRead < numBytes) {
int chunk = in.read(res, numRead, Math.min(numBytes - numRead, 1024));
numRead += chunk;
}
return res;
}
private static void closeQuietly(Socket socket) {
if (socket != null) {
try {
socket.shutdownInput();
socket.shutdownOutput();
socket.close();
} catch (IOException ioe) {
// close quietly, do nothing.
}
}
}
private static ByteBuffer wrapBytes(byte[] bytes) {
ByteBuffer bb = ByteBuffer.wrap(bytes);
bb.order(ByteOrder.LITTLE_ENDIAN);
return bb;
}
private static byte[] toBytes(int number) {
ByteBuffer bb = ByteBuffer.allocate(4);
bb.order(ByteOrder.LITTLE_ENDIAN);
return bb.putInt(number).array();
}
private static byte[] toBytes(String str) {
byte[] bytes = new byte[str.length()];
for (int i = 0; i < str.length(); ++i) {
bytes[i] = (byte) str.charAt(i);
}
return bytes;
}
private static String decodeToStr(byte[] bytes) {
StringBuilder builder = new StringBuilder();
for (byte bt : bytes) {
builder.append((char) bt);
}
return builder.toString();
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.dmlc.tvm.rpc;
import ml.dmlc.tvm.Function;
import ml.dmlc.tvm.Module;
import ml.dmlc.tvm.TVMValue;
import org.junit.Test;
import java.io.IOException;
import static org.junit.Assert.assertEquals;
public class RPCTest {
static class RefInt {
public int value;
}
private static Server startServer(RefInt portRef) {
Server server = null;
int port = 9981;
for (int i = 0; i < 10; ++i) {
try {
server = new Server(port + i);
server.start();
portRef.value = port + i;
return server;
} catch (IOException e) {
}
}
throw new RuntimeException("Cannot find an available port.");
}
@Test
public void test_addone() {
if (!Module.enabled("rpc")) {
return;
}
Function.register("test.rpc.addone", new Function.Callback() {
@Override public Object invoke(TVMValue... args) {
return args[0].asLong() + 1L;
}
});
RefInt port = new RefInt();
Server server = null;
try {
server = startServer(port);
RPCSession client = Client.connect("localhost", port.value);
Function func = client.getFunction("test.rpc.addone");
assertEquals(11L, func.call(10).asLong());
} finally {
if (server != null) {
server.terminate();
}
}
}
@Test
public void test_strcat() {
if (!Module.enabled("rpc")) {
return;
}
Function.register("test.rpc.strcat", new Function.Callback() {
@Override public Object invoke(TVMValue... args) {
return args[0].asString() + ":" + args[1].asLong();
}
});
RefInt port = new RefInt();
Server server = null;
try {
server = startServer(port);
RPCSession client = Client.connect("localhost", port.value);
Function func = client.getFunction("test.rpc.strcat");
assertEquals("abc:11", func.call("abc", 11L).asString());
} finally {
if (server != null) {
server.terminate();
}
}
}
@Test
public void test_connect_proxy_server() {
String proxyHost = System.getProperty("test.rpc.proxy.host");
int proxyPort = Integer.parseInt(System.getProperty("test.rpc.proxy.port"));
Function.register("test.rpc.proxy.addone", new Function.Callback() {
@Override public Object invoke(TVMValue... tvmValues) {
return tvmValues[0].asLong() + 1L;
}
});
Server server = null;
try {
server = new Server(proxyHost, proxyPort, "x1");
server.start();
RPCSession client = Client.connect(proxyHost, proxyPort, "x1");
Function f1 = client.getFunction("test.rpc.proxy.addone");
assertEquals(11L, f1.call(10L).asLong());
} finally {
if (server != null) {
server.terminate();
}
}
}
}
import time
from tvm.contrib import rpc_proxy
def start_proxy_server(port, timeout):
prox = rpc_proxy.Proxy("localhost", port=port, port_end=port+1)
if timeout > 0:
import time
time.sleep(timeout)
prox.terminate()
else:
prox.proc.join()
if __name__ == "__main__":
import sys
if len(sys.argv) < 2:
sys.exit(-1)
port = int(sys.argv[1])
timeout = 0 if len(sys.argv) == 2 else float(sys.argv[2])
start_proxy_server(port, timeout)
...@@ -158,27 +158,33 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_tvm_LibInfo_tvmFuncCall( ...@@ -158,27 +158,33 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_tvm_LibInfo_tvmFuncCall(
TVMValue retVal; TVMValue retVal;
int retTypeCode; int retTypeCode;
// function can be invoked recursively,
// thus we copy the pushed arguments here.
auto argValues = e->tvmFuncArgValues;
auto argTypes = e->tvmFuncArgTypes;
auto pushedStrs = e->tvmFuncArgPushedStrs;
auto pushedBytes = e->tvmFuncArgPushedBytes;
e->tvmFuncArgPushedStrs.clear();
e->tvmFuncArgPushedBytes.clear();
e->tvmFuncArgTypes.clear();
e->tvmFuncArgValues.clear();
int ret = TVMFuncCall(reinterpret_cast<TVMFunctionHandle>(jhandle), int ret = TVMFuncCall(reinterpret_cast<TVMFunctionHandle>(jhandle),
&e->tvmFuncArgValues[0], &e->tvmFuncArgTypes[0], numArgs, &retVal, &retTypeCode); &argValues[0], &argTypes[0], numArgs, &retVal, &retTypeCode);
for (auto iter = e->tvmFuncArgPushedStrs.cbegin(); for (auto iter = pushedStrs.cbegin(); iter != pushedStrs.cend(); iter++) {
iter != e->tvmFuncArgPushedStrs.cend(); iter++) {
env->ReleaseStringUTFChars(iter->first, iter->second); env->ReleaseStringUTFChars(iter->first, iter->second);
env->DeleteGlobalRef(iter->first); env->DeleteGlobalRef(iter->first);
} }
for (auto iter = e->tvmFuncArgPushedBytes.cbegin(); for (auto iter = pushedBytes.cbegin(); iter != pushedBytes.cend(); iter++) {
iter != e->tvmFuncArgPushedBytes.cend(); iter++) {
env->ReleaseByteArrayElements(iter->first, env->ReleaseByteArrayElements(iter->first,
reinterpret_cast<jbyte *>(const_cast<char *>(iter->second->data)), 0); reinterpret_cast<jbyte *>(const_cast<char *>(iter->second->data)), 0);
env->DeleteGlobalRef(iter->first); env->DeleteGlobalRef(iter->first);
delete iter->second; delete iter->second;
} }
e->tvmFuncArgPushedStrs.clear();
e->tvmFuncArgPushedBytes.clear();
e->tvmFuncArgTypes.clear();
e->tvmFuncArgValues.clear();
// return TVMValue object to Java // return TVMValue object to Java
jclass refTVMValueCls = env->FindClass("ml/dmlc/tvm/Base$RefTVMValue"); jclass refTVMValueCls = env->FindClass("ml/dmlc/tvm/Base$RefTVMValue");
jfieldID refTVMValueFid jfieldID refTVMValueFid
......
...@@ -10,7 +10,14 @@ TEMP_DIR=$(mktemp -d) ...@@ -10,7 +10,14 @@ TEMP_DIR=$(mktemp -d)
python $SCRIPT_DIR/test_add_cpu.py $TEMP_DIR || exit -1 python $SCRIPT_DIR/test_add_cpu.py $TEMP_DIR || exit -1
python $SCRIPT_DIR/test_add_gpu.py $TEMP_DIR || exit -1 python $SCRIPT_DIR/test_add_gpu.py $TEMP_DIR || exit -1
# start rpc proxy server
PORT=$(( ( RANDOM % 1000 ) + 9000 ))
python $SCRIPT_DIR/test_rpc_proxy_server.py $PORT 30 &
make jvmpkg || exit -1 make jvmpkg || exit -1
make jvmpkg JVM_TEST_ARGS="-DskipTests=false -Dtest.tempdir=$TEMP_DIR" || exit -1 make jvmpkg JVM_TEST_ARGS="-DskipTests=false \
-Dtest.tempdir=$TEMP_DIR \
-Dtest.rpc.proxy.host=localhost \
-Dtest.rpc.proxy.port=$PORT" || exit -1
rm -rf $TEMP_DIR rm -rf $TEMP_DIR
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment