Commit 7d67e473 by Yizhi Liu Committed by Tianqi Chen

[tvm4j] RPC Server (#268)

* [tvm4j] RPC Server

* [tvm4j] fix recursively function calling; connect to proxy server; osx rename .so to .dylib

* [tvm4j] test case for proxy connection; thread pool for serving
parent 1146495f
......@@ -6,7 +6,7 @@
<includeBaseDirectory>false</includeBaseDirectory>
<files>
<file>
<source>../../../lib/libtvm_runtime.so</source>
<source>../../../lib/libtvm_runtime.dylib</source>
<outputDirectory>lib/native</outputDirectory>
<fileMode>0644</fileMode>
</file>
......
......@@ -20,18 +20,21 @@
<id>osx-x86_64-cpu</id>
<properties>
<platform>osx-x86_64-cpu</platform>
<libtvm.so.filename>libtvm_runtime.dylib</libtvm.so.filename>
</properties>
</profile>
<profile>
<id>linux-x86_64-cpu</id>
<properties>
<platform>linux-x86_64-cpu</platform>
<libtvm.so.filename>libtvm_runtime.so</libtvm.so.filename>
</properties>
</profile>
<profile>
<id>linux-x86_64-gpu</id>
<properties>
<platform>linux-x86_64-gpu</platform>
<libtvm.so.filename>libtvm_runtime.so</libtvm.so.filename>
</properties>
</profile>
</profiles>
......@@ -88,7 +91,7 @@
<threadCount>1</threadCount>
<argLine>
-Djava.library.path=${project.parent.basedir}/native/${platform}/target
-Dlibtvm.so.path=${project.parent.basedir}/../lib/libtvm_runtime.so
-Dlibtvm.so.path=${project.parent.basedir}/../lib/${libtvm.so.filename}
</argLine>
</configuration>
<executions>
......
......@@ -80,7 +80,18 @@ final class Base {
if (tvmLibFilename == null || !new File(tvmLibFilename).isFile()
|| _LIB.nativeLibInit(tvmLibFilename) != 0) {
try {
NativeLibraryLoader.extractResourceFileToTempDir("libtvm_runtime.so", new Action() {
String runtimeLibname;
String os = System.getProperty("os.name");
// ref: http://lopica.sourceforge.net/os.html
if (os.startsWith("Linux")) {
runtimeLibname = "libtvm_runtime.so";
} else if (os.startsWith("Mac")) {
runtimeLibname = "libtvm_runtime.dylib";
} else {
// TODO(yizhi) support windows later
throw new UnsatisfiedLinkError("Windows not supported currently");
}
NativeLibraryLoader.extractResourceFileToTempDir(runtimeLibname, new Action() {
@Override public void invoke(File target) {
System.err.println("Loading tvm runtime from " + target.getPath());
checkCall(_LIB.nativeLibInit(target.getPath()));
......
......@@ -34,7 +34,7 @@ public class Function extends TVMValue {
* @param name full function name.
* @return TVM function.
*/
static Function getFunction(final String name) {
public static Function getFunction(final String name) {
for (String fullName : listGlobalFuncNames()) {
if (fullName.equals(name)) {
return getGlobalFunc(fullName, true, false);
......
......@@ -30,7 +30,7 @@ class NativeLibraryLoader {
static {
try {
tempDir = File.createTempFile("tvm", "");
tempDir = File.createTempFile("tvm4j", "");
if (!tempDir.delete() || !tempDir.mkdir()) {
throw new IOException("Couldn't create directory " + tempDir.getAbsolutePath());
}
......
......@@ -17,12 +17,12 @@
package ml.dmlc.tvm;
import ml.dmlc.tvm.rpc.RPC;
import java.util.HashMap;
import java.util.Map;
public class TVMContext {
private static final int RPC_SESS_MASK = 128;
private static final Map<Integer, String> MASK2STR = new HashMap<Integer, String>();
private static final Map<String, Integer> STR2MASK = new HashMap<String, Integer>();
......@@ -169,9 +169,9 @@ public class TVMContext {
}
@Override public String toString() {
if (deviceType >= RPC_SESS_MASK) {
int tblId = deviceType / RPC_SESS_MASK - 1;
int devType = deviceType % RPC_SESS_MASK;
if (deviceType >= RPC.RPC_SESS_MASK) {
int tblId = deviceType / RPC.RPC_SESS_MASK - 1;
int devType = deviceType % RPC.RPC_SESS_MASK;
return String.format("remote[%d]:%s(%d)", tblId, MASK2STR.get(devType), deviceId);
}
return String.format("%s(%d)", MASK2STR.get(deviceType), deviceId);
......
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.dmlc.tvm.rpc;
import ml.dmlc.tvm.Function;
import ml.dmlc.tvm.TVMValue;
public class Client {
/**
* Connect to RPC Server.
* @param url The url of the host.
* @param port The port to connect to.
* @param key Additional key to match server.
* @return The connected session.
*/
public static RPCSession connect(String url, int port, String key) {
Function doConnect = RPC.getApi("_Connect");
if (doConnect == null) {
throw new RuntimeException("Please compile with USE_RPC=1");
}
TVMValue sess = doConnect.pushArg(url).pushArg(port).pushArg(key).invoke();
return new RPCSession(sess.asModule());
}
/**
* Connect to RPC Server.
* @param url The url of the host.
* @param port The port to connect to.
* @return The connected session.
*/
public static RPCSession connect(String url, int port) {
return connect(url, port, "");
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.dmlc.tvm.rpc;
import ml.dmlc.tvm.Function;
import java.util.HashMap;
import java.util.Map;
public class RPC {
public static final int RPC_MAGIC = 0xff271;
public static final int RPC_SESS_MASK = 128;
private static ThreadLocal<Map<String, Function>> apiFuncs
= new ThreadLocal<Map<String, Function>>() {
@Override
protected Map<String, Function> initialValue() {
return new HashMap<String, Function>();
}
};
static Function getApi(String name) {
Function func = apiFuncs.get().get(name);
if (func == null) {
func = Function.getFunction("contrib.rpc." + name);
if (func == null) {
return null;
}
apiFuncs.get().put(name, func);
}
return func;
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.dmlc.tvm.rpc;
import ml.dmlc.tvm.Function;
import ml.dmlc.tvm.Module;
import ml.dmlc.tvm.TVMContext;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.HashMap;
import java.util.Map;
/**
* RPC Client session module.
* Do not directly create the object, use Client.connect.
*/
public class RPCSession {
private final Module session;
private final int tblIndex;
private final Map<String, Function> remoteFuncs = new HashMap<String, Function>();
RPCSession(Module sess) {
session = sess;
tblIndex = (int) RPC.getApi("_SessTableIndex").pushArg(session).invoke().asLong();
}
/**
* Get function from the session.
* @param name The name of the function.
* @return The result function.
*/
public Function getFunction(String name) {
return session.getFunction(name);
}
/**
* Construct a remote context.
* @param devType device type.
* @param devId device id.
* @return The corresponding encoded remote context.
*/
public TVMContext context(String devType, int devId) {
TVMContext ctx = new TVMContext(devType, devId);
int encode = (tblIndex + 1) * RPC.RPC_SESS_MASK;
return new TVMContext(ctx.deviceType + encode, devId);
}
/**
* Construct a remote context.
* @param devType device type.
* @return The corresponding encoded remote context.
*/
public TVMContext context(String devType) {
return context(devType, 0);
}
/**
* Construct a remote context.
* @param devType device type.
* @param devId device id.
* @return The corresponding encoded remote context.
*/
public TVMContext context(int devType, int devId) {
int encode = (tblIndex + 1) * RPC.RPC_SESS_MASK;
return new TVMContext(devType + encode, devId);
}
/**
* Construct a remote context.
* @param devType device type.
* @return The corresponding encoded remote context.
*/
public TVMContext context(int devType) {
return context(devType, 0);
}
/**
* Construct remote CPU device.
* @param devId device id.
* @return Remote CPU context.
*/
public TVMContext cpu(int devId) {
return context(1, devId);
}
/**
* Construct remote CPU device.
* @return Remote CPU context.
*/
public TVMContext cpu() {
return cpu(0);
}
/**
* Construct remote GPU device.
* @param devId device id.
* @return Remote GPU context.
*/
public TVMContext gpu(int devId) {
return context(2, devId);
}
/**
* Construct remote GPU device.
* @return Remote GPU context.
*/
public TVMContext gpu() {
return gpu(0);
}
/**
* Construct remote OpenCL device.
* @param devId device id.
* @return Remote OpenCL context.
*/
public TVMContext cl(int devId) {
return context(4, devId);
}
/**
* Construct remote OpenCL device.
* @return Remote OpenCL context.
*/
public TVMContext cl() {
return cl(0);
}
/**
* Construct remote Metal device.
* @param devId device id.
* @return Remote metal context.
*/
public TVMContext metal(int devId) {
return context(8, devId);
}
/**
* Construct remote Metal device.
* @return Remote metal context.
*/
public TVMContext metal() {
return metal(0);
}
/**
* Upload binary to remote runtime temp folder.
* @param data The binary in local to upload.
* @param target The path in remote, cannot be null.
*/
public void upload(byte[] data, String target) {
if (target == null) {
throw new IllegalArgumentException("Please specify the upload target");
}
final String funcName = "upload";
Function remoteFunc = remoteFuncs.get(funcName);
if (remoteFunc == null) {
remoteFunc = getFunction("tvm.contrib.rpc.server.upload");
remoteFuncs.put(funcName, remoteFunc);
}
remoteFunc.pushArg(target).pushArg(data).invoke();
}
/**
* Upload file to remote runtime temp folder.
* @param data The file in local to upload.
* @param target The path in remote.
*/
public void upload(File data, String target) throws IOException {
byte[] blob = getBytesFromFile(data);
upload(blob, target);
}
/**
* Upload file to remote runtime temp folder.
* @param data The file in local to upload.
*/
public void upload(File data) throws IOException {
upload(data, data.getName());
}
/**
* Download file from remote temp folder.
* @param path The relative location to remote temp folder.
* @return The result blob from the file.
*/
public byte[] download(String path) {
final String name = "download";
Function func = remoteFuncs.get(name);
if (func == null) {
func = getFunction("tvm.contrib.rpc.server.download");
remoteFuncs.put(name, func);
}
return func.pushArg(path).invoke().asBytes();
}
/**
* Load a remote module, the file need to be uploaded first.
* @param path The relative location to remote temp folder.
* @return The remote module containing remote function.
*/
public Module loadModule(String path) {
return RPC.getApi("_LoadRemoteModule").pushArg(session).pushArg(path).invoke().asModule();
}
private static byte[] getBytesFromFile(File file) throws IOException {
// Get the size of the file
long length = file.length();
if (length > Integer.MAX_VALUE) {
throw new IOException("File " + file.getName() + " is too large!");
}
// cannot create an array using a long type.
byte[] bytes = new byte[(int)length];
// Read in the bytes
int offset = 0;
int numRead = 0;
InputStream is = new FileInputStream(file);
try {
while (offset < bytes.length
&& (numRead = is.read(bytes, offset, bytes.length - offset)) >= 0) {
offset += numRead;
}
} finally {
is.close();
}
// Ensure all the bytes have been read in
if (offset < bytes.length) {
throw new IOException("Could not completely read file " + file.getName());
}
return bytes;
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.dmlc.tvm.rpc;
import ml.dmlc.tvm.Function;
import ml.dmlc.tvm.Module;
import ml.dmlc.tvm.TVMValue;
import org.junit.Test;
import java.io.IOException;
import static org.junit.Assert.assertEquals;
public class RPCTest {
static class RefInt {
public int value;
}
private static Server startServer(RefInt portRef) {
Server server = null;
int port = 9981;
for (int i = 0; i < 10; ++i) {
try {
server = new Server(port + i);
server.start();
portRef.value = port + i;
return server;
} catch (IOException e) {
}
}
throw new RuntimeException("Cannot find an available port.");
}
@Test
public void test_addone() {
if (!Module.enabled("rpc")) {
return;
}
Function.register("test.rpc.addone", new Function.Callback() {
@Override public Object invoke(TVMValue... args) {
return args[0].asLong() + 1L;
}
});
RefInt port = new RefInt();
Server server = null;
try {
server = startServer(port);
RPCSession client = Client.connect("localhost", port.value);
Function func = client.getFunction("test.rpc.addone");
assertEquals(11L, func.call(10).asLong());
} finally {
if (server != null) {
server.terminate();
}
}
}
@Test
public void test_strcat() {
if (!Module.enabled("rpc")) {
return;
}
Function.register("test.rpc.strcat", new Function.Callback() {
@Override public Object invoke(TVMValue... args) {
return args[0].asString() + ":" + args[1].asLong();
}
});
RefInt port = new RefInt();
Server server = null;
try {
server = startServer(port);
RPCSession client = Client.connect("localhost", port.value);
Function func = client.getFunction("test.rpc.strcat");
assertEquals("abc:11", func.call("abc", 11L).asString());
} finally {
if (server != null) {
server.terminate();
}
}
}
@Test
public void test_connect_proxy_server() {
String proxyHost = System.getProperty("test.rpc.proxy.host");
int proxyPort = Integer.parseInt(System.getProperty("test.rpc.proxy.port"));
Function.register("test.rpc.proxy.addone", new Function.Callback() {
@Override public Object invoke(TVMValue... tvmValues) {
return tvmValues[0].asLong() + 1L;
}
});
Server server = null;
try {
server = new Server(proxyHost, proxyPort, "x1");
server.start();
RPCSession client = Client.connect(proxyHost, proxyPort, "x1");
Function f1 = client.getFunction("test.rpc.proxy.addone");
assertEquals(11L, f1.call(10L).asLong());
} finally {
if (server != null) {
server.terminate();
}
}
}
}
import time
from tvm.contrib import rpc_proxy
def start_proxy_server(port, timeout):
prox = rpc_proxy.Proxy("localhost", port=port, port_end=port+1)
if timeout > 0:
import time
time.sleep(timeout)
prox.terminate()
else:
prox.proc.join()
if __name__ == "__main__":
import sys
if len(sys.argv) < 2:
sys.exit(-1)
port = int(sys.argv[1])
timeout = 0 if len(sys.argv) == 2 else float(sys.argv[2])
start_proxy_server(port, timeout)
......@@ -158,27 +158,33 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_tvm_LibInfo_tvmFuncCall(
TVMValue retVal;
int retTypeCode;
// function can be invoked recursively,
// thus we copy the pushed arguments here.
auto argValues = e->tvmFuncArgValues;
auto argTypes = e->tvmFuncArgTypes;
auto pushedStrs = e->tvmFuncArgPushedStrs;
auto pushedBytes = e->tvmFuncArgPushedBytes;
e->tvmFuncArgPushedStrs.clear();
e->tvmFuncArgPushedBytes.clear();
e->tvmFuncArgTypes.clear();
e->tvmFuncArgValues.clear();
int ret = TVMFuncCall(reinterpret_cast<TVMFunctionHandle>(jhandle),
&e->tvmFuncArgValues[0], &e->tvmFuncArgTypes[0], numArgs, &retVal, &retTypeCode);
&argValues[0], &argTypes[0], numArgs, &retVal, &retTypeCode);
for (auto iter = e->tvmFuncArgPushedStrs.cbegin();
iter != e->tvmFuncArgPushedStrs.cend(); iter++) {
for (auto iter = pushedStrs.cbegin(); iter != pushedStrs.cend(); iter++) {
env->ReleaseStringUTFChars(iter->first, iter->second);
env->DeleteGlobalRef(iter->first);
}
for (auto iter = e->tvmFuncArgPushedBytes.cbegin();
iter != e->tvmFuncArgPushedBytes.cend(); iter++) {
for (auto iter = pushedBytes.cbegin(); iter != pushedBytes.cend(); iter++) {
env->ReleaseByteArrayElements(iter->first,
reinterpret_cast<jbyte *>(const_cast<char *>(iter->second->data)), 0);
env->DeleteGlobalRef(iter->first);
delete iter->second;
}
e->tvmFuncArgPushedStrs.clear();
e->tvmFuncArgPushedBytes.clear();
e->tvmFuncArgTypes.clear();
e->tvmFuncArgValues.clear();
// return TVMValue object to Java
jclass refTVMValueCls = env->FindClass("ml/dmlc/tvm/Base$RefTVMValue");
jfieldID refTVMValueFid
......
......@@ -10,7 +10,14 @@ TEMP_DIR=$(mktemp -d)
python $SCRIPT_DIR/test_add_cpu.py $TEMP_DIR || exit -1
python $SCRIPT_DIR/test_add_gpu.py $TEMP_DIR || exit -1
# start rpc proxy server
PORT=$(( ( RANDOM % 1000 ) + 9000 ))
python $SCRIPT_DIR/test_rpc_proxy_server.py $PORT 30 &
make jvmpkg || exit -1
make jvmpkg JVM_TEST_ARGS="-DskipTests=false -Dtest.tempdir=$TEMP_DIR" || exit -1
make jvmpkg JVM_TEST_ARGS="-DskipTests=false \
-Dtest.tempdir=$TEMP_DIR \
-Dtest.rpc.proxy.host=localhost \
-Dtest.rpc.proxy.port=$PORT" || exit -1
rm -rf $TEMP_DIR
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment