Commit 5bd4afee by Yizhi Liu Committed by Tianqi Chen

[tvm4j] add GraphRuntime (#1472)

parent 05afac09
...@@ -109,8 +109,7 @@ public class Function extends TVMValue { ...@@ -109,8 +109,7 @@ public class Function extends TVMValue {
/** /**
* Release the Function. * Release the Function.
* <p> * <p>
* We highly recommend you to do this manually since the GC strategy is lazy * We highly recommend you to do this manually since the GC strategy is lazy.
* and `finalize()` is not guaranteed to be called when GC happens.
* </p> * </p>
*/ */
@Override public void release() { @Override public void release() {
...@@ -269,6 +268,7 @@ public class Function extends TVMValue { ...@@ -269,6 +268,7 @@ public class Function extends TVMValue {
case BYTES: case BYTES:
Base._LIB.tvmFuncPushArgBytes(tvmArg.asBytes()); Base._LIB.tvmFuncPushArgBytes(tvmArg.asBytes());
break; break;
case HANDLE:
case ARRAY_HANDLE: case ARRAY_HANDLE:
case MODULE_HANDLE: case MODULE_HANDLE:
case FUNC_HANDLE: case FUNC_HANDLE:
......
...@@ -72,8 +72,7 @@ public class Module extends TVMValue { ...@@ -72,8 +72,7 @@ public class Module extends TVMValue {
/** /**
* Release the Module. * Release the Module.
* <p> * <p>
* We highly recommend you to do this manually since the GC strategy is lazy * We highly recommend you to do this manually since the GC strategy is lazy.
* and `finalize()` is not guaranteed to be called when GC happens.
* </p> * </p>
*/ */
@Override public void release() { @Override public void release() {
...@@ -123,6 +122,13 @@ public class Module extends TVMValue { ...@@ -123,6 +122,13 @@ public class Module extends TVMValue {
} }
/** /**
* @return type key of the module.
*/
public String typeKey() {
return getApi("_GetTypeKey").pushArg(this).invoke().asString();
}
/**
* Load module from file. * Load module from file.
* @param path The path to the module file. * @param path The path to the module file.
* @param fmt The format of the file, * @param fmt The format of the file,
......
...@@ -27,10 +27,12 @@ import java.util.List; ...@@ -27,10 +27,12 @@ import java.util.List;
*/ */
public class NDArray extends NDArrayBase { public class NDArray extends NDArrayBase {
private final TVMType dtype; private final TVMType dtype;
private final TVMContext context;
NDArray(long handle, boolean isView, TVMType dtype) { NDArray(long handle, boolean isView, TVMType dtype, TVMContext ctx) {
super(handle, isView); super(handle, isView);
this.dtype = dtype; this.dtype = dtype;
this.context = ctx;
} }
@Override protected void finalize() throws Throwable { @Override protected void finalize() throws Throwable {
...@@ -362,6 +364,14 @@ public class NDArray extends NDArrayBase { ...@@ -362,6 +364,14 @@ public class NDArray extends NDArrayBase {
} }
/** /**
* Get the context of current array.
* @return the context.
*/
public TVMContext ctx() {
return context;
}
/**
* Create an empty array given shape, type and device. * Create an empty array given shape, type and device.
* @param shape The shape of the array. * @param shape The shape of the array.
* @param dtype The data type of the array. * @param dtype The data type of the array.
...@@ -373,7 +383,7 @@ public class NDArray extends NDArrayBase { ...@@ -373,7 +383,7 @@ public class NDArray extends NDArrayBase {
Base.checkCall(Base._LIB.tvmArrayAlloc( Base.checkCall(Base._LIB.tvmArrayAlloc(
shape, dtype.typeCode, dtype.bits, dtype.lanes, shape, dtype.typeCode, dtype.bits, dtype.lanes,
ctx.deviceType, ctx.deviceId, refHandle)); ctx.deviceType, ctx.deviceId, refHandle));
return new NDArray(refHandle.value, false, dtype); return new NDArray(refHandle.value, false, dtype, ctx);
} }
/** /**
......
...@@ -57,8 +57,7 @@ public class NDArrayBase extends TVMValue { ...@@ -57,8 +57,7 @@ public class NDArrayBase extends TVMValue {
/** /**
* Release the NDArray memory. * Release the NDArray memory.
* <p> * <p>
* We highly recommend you to do this manually since the GC strategy is lazy * We highly recommend you to do this manually since the GC strategy is lazy.
* and `finalize()` is not guaranteed to be called when GC happens.
* </p> * </p>
*/ */
public void release() { public void release() {
......
...@@ -37,16 +37,16 @@ public class TVMType { ...@@ -37,16 +37,16 @@ public class TVMType {
this.lanes = lanes; this.lanes = lanes;
int bitsTemp = 0; int bitsTemp = 0;
if (typeStr.startsWith("int")) { if (typeStr.startsWith("int")) {
typeCode = 0; typeCode = INT;
bitsTemp = Integer.parseInt(typeStr.substring(3)); bitsTemp = Integer.parseInt(typeStr.substring(3));
} else if (typeStr.startsWith("uint")) { } else if (typeStr.startsWith("uint")) {
typeCode = 1; typeCode = UINT;
bitsTemp = Integer.parseInt(typeStr.substring(4)); bitsTemp = Integer.parseInt(typeStr.substring(4));
} else if (typeStr.startsWith("float")) { } else if (typeStr.startsWith("float")) {
typeCode = 2; typeCode = FLOAT;
bitsTemp = Integer.parseInt(typeStr.substring(5)); bitsTemp = Integer.parseInt(typeStr.substring(5));
} else if (typeStr.startsWith("handle")) { } else if (typeStr.startsWith("handle")) {
typeCode = 4; typeCode = HANDLE;
bitsTemp = 64; bitsTemp = 64;
} else { } else {
throw new IllegalArgumentException("Do not know how to handle type " + typeStr); throw new IllegalArgumentException("Do not know how to handle type " + typeStr);
...@@ -78,16 +78,16 @@ public class TVMType { ...@@ -78,16 +78,16 @@ public class TVMType {
@Override public String toString() { @Override public String toString() {
String typeCodeStr; String typeCodeStr;
switch (typeCode) { switch (typeCode) {
case 0: case INT:
typeCodeStr = "int"; typeCodeStr = "int";
break; break;
case 1: case UINT:
typeCodeStr = "uint"; typeCodeStr = "uint";
break; break;
case 2: case FLOAT:
typeCodeStr = "float"; typeCodeStr = "float";
break; break;
case 4: case HANDLE:
typeCodeStr = "handle"; typeCodeStr = "handle";
break; break;
default: default:
......
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.dmlc.tvm;
/**
* Java class related to TVM handles (TypeCode.HANDLE)
*/
public class TVMValueHandle extends TVMValue {
public final long value;
public TVMValueHandle(long value) {
super(TypeCode.HANDLE);
this.value = value;
}
@Override public long asHandle() {
return value;
}
}
package ml.dmlc.tvm.contrib;
import ml.dmlc.tvm.Function;
import ml.dmlc.tvm.Module;
import ml.dmlc.tvm.NDArray;
import ml.dmlc.tvm.TVMContext;
/**
* Wrapper runtime module.
* This is a thin wrapper of the underlying TVM module.
* you can also directly call set_input, run, and get_output
* of underlying module functions.
*/
public class GraphModule {
private Module module;
private TVMContext ctx;
private Function fsetInput;
private Function frun;
private Function fgetOutput;
private Function fgetInput;
private Function fdebugGetOutput;
private Function floadParams;
GraphModule(Module module, TVMContext ctx) {
this.module = module;
this.ctx = ctx;
fsetInput = module.getFunction("set_input");
frun = module.getFunction("run");
fgetInput = module.getFunction("get_input");
fgetOutput = module.getFunction("get_output");
try {
fdebugGetOutput = module.getFunction("debug_get_output");
} catch (IllegalArgumentException ignored) {
// ignore
}
floadParams = module.getFunction("load_params");
}
/**
* Release the GraphModule.
* <p>
* We highly recommend you to do this manually since the GC strategy is lazy.
* </p>
*/
public void release() {
fsetInput.release();
frun.release();
fgetInput.release();
fgetOutput.release();
if (fdebugGetOutput != null) {
fdebugGetOutput.release();
}
floadParams.release();
module.release();
}
/**
* Set inputs to the module.
* @param key The input key.
* @param value The input value
* @return self.
*/
public GraphModule setInput(String key, NDArray value) {
NDArray input = value;
if (!value.ctx().equals(ctx)) {
input = NDArray.empty(value.shape(), ctx);
value.copyTo(input);
}
fsetInput.pushArg(key).pushArg(input).invoke();
return this;
}
/**
* Set inputs to the module
* @param key The input key.
* @param value The input value.
* @return self.
*/
public GraphModule setInput(int key, NDArray value) {
NDArray input = value;
if (!value.ctx().equals(ctx)) {
input = NDArray.empty(value.shape(), ctx);
value.copyTo(input);
}
fsetInput.pushArg(key).pushArg(input).invoke();
return this;
}
/**
* Run forward execution of the graph.
* @return self.
*/
public GraphModule run() {
frun.invoke();
return this;
}
/**
* Get index-th input to out.
* @param index The input index.
* @param out The output array container.
* @return out.
*/
public NDArray getInput(int index, NDArray out) {
fgetInput.pushArg(index).pushArg(out).invoke();
return out;
}
/**
* Get index-th output to out.
* @param index The output index.
* @param out The output array container.
* @return out.
*/
public NDArray getOutput(int index, NDArray out) {
fgetOutput.pushArg(index).pushArg(out).invoke();
return out;
}
/**
* Run graph up to node and get the output to out.
* @param node The node name.
* @param out The output array container.
* @return out.
*/
public NDArray debugGetOutput(String node, NDArray out) {
if (fdebugGetOutput != null) {
fdebugGetOutput.pushArg(node).pushArg(out).invoke();
} else {
throw new RuntimeException("Please compile runtime with USE_GRAPH_RUNTIME_DEBUG = 0");
}
return out;
}
/**
* Run graph up to node and get the output to out.
* @param node The node index.
* @param out The output array container.
* @return out.
*/
public NDArray debugGetOutput(int node, NDArray out) {
if (fdebugGetOutput != null) {
fdebugGetOutput.pushArg(node).pushArg(out).invoke();
} else {
throw new RuntimeException("Please compile runtime with USE_GRAPH_RUNTIME_DEBUG = 0");
}
return out;
}
/**
* Load parameters from serialized byte array of parameter dict.
* @param params The serialized parameter.
* @return self.
*/
public GraphModule loadParams(byte[] params) {
floadParams.pushArg(params).invoke();
return this;
}
/**
* Get internal module function.
* @param key The key to the module.
* @return The function.
* @throws IllegalArgumentException if function does not exist.
*/
public Function getFunction(String key) {
return module.getFunction(key);
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.dmlc.tvm.contrib;
import ml.dmlc.tvm.Function;
import ml.dmlc.tvm.Module;
import ml.dmlc.tvm.TVMContext;
import ml.dmlc.tvm.TVMValue;
import ml.dmlc.tvm.rpc.RPC;
import ml.dmlc.tvm.rpc.RPCSession;
import ml.dmlc.tvm.rpc.TVMRemoteContext;
import java.lang.reflect.Field;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
public class GraphRuntime {
/**
* Create a runtime executor module given a graph and module.
* @param graphJson The graph deployed in json format output by nnvm graph.
* @param libmod The module of the corresponding function.
* @param ctx The local or remote context to deploy the module.
* @return Runtime graph module that can be used to execute the graph.
*/
public static GraphModule create(String graphJson, Module libmod, TVMContext ctx) {
Module graphModule = null;
if (ctx.deviceType >= RPC.RPC_SESS_MASK) {
if (!(ctx instanceof TVMRemoteContext)) {
throw new IllegalArgumentException(
"Looks like you are using remote context with no RPCSession bind."
+ "Use session.context instead.");
}
RPCSession rpcSession = ((TVMRemoteContext) ctx).rpcSession;
// check arguments
if (!"rpc".equals(libmod.typeKey())) {
throw new IllegalArgumentException("libmod.typeKey != rpc");
}
final int sessIndex = (int) ((Function) reflectionStaticCall(
RPC.class, "getApi", "_SessTableIndex"))
.pushArg(libmod).invoke().asLong();
if (sessIndex != (Integer) reflectionGetField(rpcSession, "tblIndex")) {
throw new IllegalArgumentException(String.format(
"libmod SessTableIndex=%d mismatch rpcSession.tblIndex=%d",
sessIndex, reflectionGetField(rpcSession, "tblIndex")));
}
Function rpcModuleHandle = (Function) reflectionStaticCall(
RPC.class, "getApi","_ModuleHandle");
if (rpcModuleHandle == null) {
throw new RuntimeException("Cannot find global function tvm.rpc._ModuleHandle."
+ "Did you compile tvm_runtime with the correct version?");
}
Function fcreate = Function.getFunction("tvm.graph_runtime.remote_create");
if (fcreate == null) {
throw new RuntimeException("Cannot find global function tvm.graph_runtime.remote_create."
+ "Did you compile tvm_runtime with correct version?");
}
TVMValue hmod = rpcModuleHandle.pushArg(libmod).invoke();
graphModule = fcreate.call(graphJson, hmod,
ctx.deviceType % RPC.RPC_SESS_MASK, ctx.deviceId).asModule();
} else {
Function fcreate = Function.getFunction("tvm.graph_runtime.create");
if (fcreate == null) {
throw new RuntimeException("Cannot find global function tvm.graph_runtime.create."
+ "Did you compile tvm_runtime with correct version?");
}
graphModule = fcreate.pushArg(graphJson)
.pushArg(libmod).pushArg(ctx.deviceType).pushArg(ctx.deviceId)
.invoke().asModule();
}
return new GraphModule(graphModule, ctx);
}
private static Object reflectionGetField(Object obj, String fieldName) {
try {
Field field = obj.getClass().getDeclaredField(fieldName);
field.setAccessible(true);
return field.get(obj);
} catch (NoSuchFieldException e) {
throw new RuntimeException(e);
} catch (IllegalAccessException e) {
throw new RuntimeException(e);
}
}
private static Object reflectionStaticCall(Class<?> clazz, String methodName, Object ... args) {
Class<?>[] types = new Class<?>[args.length];
for (int i = 0; i < args.length; ++i) {
types[i] = args[i].getClass();
}
try {
Method method = clazz.getDeclaredMethod(methodName, types);
method.setAccessible(true);
return method.invoke(null, args);
} catch (NoSuchMethodException e) {
throw new RuntimeException(e);
} catch (IllegalAccessException e) {
throw new RuntimeException(e);
} catch (InvocationTargetException e) {
throw new RuntimeException(e);
}
}
}
...@@ -44,6 +44,11 @@ public class RPC { ...@@ -44,6 +44,11 @@ public class RPC {
} }
}; };
/**
* Get internal function starts with namespace tvm.rpc.
* @param name function name.
* @return the function, null if not exists.
*/
static Function getApi(String name) { static Function getApi(String name) {
Function func = apiFuncs.get().get(name); Function func = apiFuncs.get().get(name);
if (func == null) { if (func == null) {
......
...@@ -60,7 +60,7 @@ public class RPCSession { ...@@ -60,7 +60,7 @@ public class RPCSession {
public TVMContext context(String devType, int devId) { public TVMContext context(String devType, int devId) {
TVMContext ctx = new TVMContext(devType, devId); TVMContext ctx = new TVMContext(devType, devId);
int encode = (tblIndex + 1) * RPC.RPC_SESS_MASK; int encode = (tblIndex + 1) * RPC.RPC_SESS_MASK;
return new TVMContext(ctx.deviceType + encode, devId); return new TVMRemoteContext(ctx.deviceType + encode, devId, this);
} }
/** /**
...@@ -80,7 +80,7 @@ public class RPCSession { ...@@ -80,7 +80,7 @@ public class RPCSession {
*/ */
public TVMContext context(int devType, int devId) { public TVMContext context(int devType, int devId) {
int encode = (tblIndex + 1) * RPC.RPC_SESS_MASK; int encode = (tblIndex + 1) * RPC.RPC_SESS_MASK;
return new TVMContext(devType + encode, devId); return new TVMRemoteContext(devType + encode, devId, this);
} }
/** /**
......
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.dmlc.tvm.rpc;
import ml.dmlc.tvm.TVMContext;
// always related to RPCSession. Cannot construct by users.
public class TVMRemoteContext extends TVMContext {
public final RPCSession rpcSession;
TVMRemoteContext(int deviceType, int deviceId, RPCSession rpcSession) {
super(deviceType, deviceId);
this.rpcSession = rpcSession;
}
}
package ml.dmlc.tvm;
import ml.dmlc.tvm.rpc.Server;
import java.io.IOException;
public class TestUtils {
public static class RefInt {
public int value;
}
public static Server startServer(RefInt portRef) {
Server server = null;
int port = 9981;
for (int i = 0; i < 10; ++i) {
try {
server = new Server(port + i);
server.start();
portRef.value = port + i;
return server;
} catch (IOException e) {
}
}
throw new RuntimeException("Cannot find an available port.");
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.dmlc.tvm.contrib;
import ml.dmlc.tvm.*;
import ml.dmlc.tvm.rpc.Client;
import ml.dmlc.tvm.rpc.RPCSession;
import ml.dmlc.tvm.rpc.Server;
import org.junit.BeforeClass;
import org.junit.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.File;
import java.io.IOException;
import java.util.Scanner;
import static org.junit.Assert.assertArrayEquals;
public class GraphRuntimeTest {
private final Logger logger = LoggerFactory.getLogger(GraphRuntime.class);
private static String loadingDir;
@BeforeClass
public static void beforeClass() {
loadingDir = System.getProperty("test.tempdir");
}
@Test
public void test_add_one_local() throws IOException {
Module libmod = Module.load(loadingDir + File.separator + "graph_addone_lib.so");
String graphJson = new Scanner(new File(
loadingDir + File.separator + "graph_addone.json"))
.useDelimiter("\\Z").next();
TVMContext ctx = TVMContext.cpu();
GraphModule graph = GraphRuntime.create(graphJson, libmod, ctx);
long[] shape = new long[]{4};
NDArray arr = NDArray.empty(shape, ctx);
arr.copyFrom(new float[]{1f, 2f, 3f, 4f});
NDArray out = NDArray.empty(shape, ctx);
graph.setInput("x", arr).run();
graph.getOutput(0, out);
assertArrayEquals(new float[]{2f, 3f, 4f, 5f}, out.asFloatArray(), 1e-3f);
arr.release();
out.release();
graph.release();
}
@Test
public void test_add_one_remote() throws IOException {
if (!Module.enabled("rpc")) {
logger.warn("RPC is not enabled. Skip.");
return;
}
String libPath = loadingDir + File.separator + "graph_addone_lib.so";
String graphJson = new Scanner(new File(
loadingDir + File.separator + "graph_addone.json"))
.useDelimiter("\\Z").next();
TestUtils.RefInt port = new TestUtils.RefInt();
Server server = null;
try {
server = TestUtils.startServer(port);
RPCSession remote = Client.connect("localhost", port.value);
TVMContext ctx = remote.cpu();
remote.upload(new File(libPath));
Module mlib = remote.loadModule("graph_addone_lib.so");
GraphModule graph = GraphRuntime.create(graphJson, mlib, ctx);
long[] shape = new long[]{4};
NDArray arr = NDArray.empty(shape, ctx);
arr.copyFrom(new float[]{1f, 2f, 3f, 4f});
NDArray out = NDArray.empty(shape, ctx);
graph.setInput("x", arr).run();
graph.getOutput(0, out);
assertArrayEquals(new float[]{2f, 3f, 4f, 5f}, out.asFloatArray(), 1e-3f);
arr.release();
out.release();
graph.release();
} finally {
if (server != null) {
server.terminate();
}
}
}
}
...@@ -20,36 +20,21 @@ package ml.dmlc.tvm.rpc; ...@@ -20,36 +20,21 @@ package ml.dmlc.tvm.rpc;
import ml.dmlc.tvm.Function; import ml.dmlc.tvm.Function;
import ml.dmlc.tvm.Module; import ml.dmlc.tvm.Module;
import ml.dmlc.tvm.TVMValue; import ml.dmlc.tvm.TVMValue;
import ml.dmlc.tvm.TestUtils;
import org.junit.Ignore; import org.junit.Ignore;
import org.junit.Test; import org.junit.Test;
import org.slf4j.Logger;
import java.io.IOException; import org.slf4j.LoggerFactory;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
public class RPCTest { public class RPCTest {
static class RefInt { private final Logger logger = LoggerFactory.getLogger(RPCTest.class);
public int value;
}
private static Server startServer(RefInt portRef) {
Server server = null;
int port = 9981;
for (int i = 0; i < 10; ++i) {
try {
server = new Server(port + i);
server.start();
portRef.value = port + i;
return server;
} catch (IOException e) {
}
}
throw new RuntimeException("Cannot find an available port.");
}
@Test @Test
public void test_addone() { public void test_addone() {
if (!Module.enabled("rpc")) { if (!Module.enabled("rpc")) {
logger.warn("RPC is not enabled. Skip.");
return; return;
} }
Function.register("test.rpc.addone", new Function.Callback() { Function.register("test.rpc.addone", new Function.Callback() {
...@@ -58,10 +43,10 @@ public class RPCTest { ...@@ -58,10 +43,10 @@ public class RPCTest {
} }
}); });
RefInt port = new RefInt(); TestUtils.RefInt port = new TestUtils.RefInt();
Server server = null; Server server = null;
try { try {
server = startServer(port); server = TestUtils.startServer(port);
RPCSession client = Client.connect("localhost", port.value); RPCSession client = Client.connect("localhost", port.value);
Function func = client.getFunction("test.rpc.addone"); Function func = client.getFunction("test.rpc.addone");
assertEquals(11L, func.call(10).asLong()); assertEquals(11L, func.call(10).asLong());
...@@ -75,6 +60,7 @@ public class RPCTest { ...@@ -75,6 +60,7 @@ public class RPCTest {
@Test @Test
public void test_strcat() { public void test_strcat() {
if (!Module.enabled("rpc")) { if (!Module.enabled("rpc")) {
logger.warn("RPC is not enabled. Skip.");
return; return;
} }
Function.register("test.rpc.strcat", new Function.Callback() { Function.register("test.rpc.strcat", new Function.Callback() {
...@@ -83,10 +69,10 @@ public class RPCTest { ...@@ -83,10 +69,10 @@ public class RPCTest {
} }
}); });
RefInt port = new RefInt(); TestUtils.RefInt port = new TestUtils.RefInt();
Server server = null; Server server = null;
try { try {
server = startServer(port); server = TestUtils.startServer(port);
RPCSession client = Client.connect("localhost", port.value); RPCSession client = Client.connect("localhost", port.value);
Function func = client.getFunction("test.rpc.strcat"); Function func = client.getFunction("test.rpc.strcat");
assertEquals("abc:11", func.call("abc", 11L).asString()); assertEquals("abc:11", func.call("abc", 11L).asString());
......
import os
import tvm
import json
from tvm.contrib import graph_runtime
def dump_graph_lib(target_dir):
dim = 4
A = tvm.placeholder((dim,), name='A')
B = tvm.compute(A.shape, lambda *i: A(*i) + 1.0, name='B')
sched = tvm.create_schedule(B.op)
node0 = {"op": "null", "name": "x", "inputs": []}
node1 = {"op": "tvm_op", "name": "add",
"inputs": [[0, 0, 0]],
"attrs": {"func_name": "myadd",
"flatten_data": "1",
"num_inputs" : "1",
"num_outputs" : "1"}}
nodes = [node0, node1]
arg_nodes = [0]
node_row_ptr = [0, 1, 2]
outputs = [[1, 0, 0]]
shape = (4,)
attrs = {
"shape" : ["list_shape", [shape, shape]],
"dltype" : ["list_str", ["float32", "float32"]],
"storage_id" : ["list_int", [0, 1]],
}
graph = {"nodes": nodes,
"arg_nodes": arg_nodes,
"node_row_ptr": node_row_ptr,
"heads": outputs,
"attrs": attrs}
graph = json.dumps(graph)
mlib = tvm.build(sched, [A, B], "llvm", name="myadd")
mlib.export_library(os.path.join(target_dir, "graph_addone_lib.so"))
with open(os.path.join(target_dir, "graph_addone.json"), "w") as fo:
fo.write(graph)
if __name__ == "__main__":
import sys
if len(sys.argv) != 2:
sys.exit(-1)
dump_graph_lib(sys.argv[1])
...@@ -72,6 +72,14 @@ jstring getTVMValueStringField(JNIEnv *env, jobject obj) { ...@@ -72,6 +72,14 @@ jstring getTVMValueStringField(JNIEnv *env, jobject obj) {
return ret; return ret;
} }
jobject newTVMValueHandle(JNIEnv *env, jlong value) {
jclass cls = env->FindClass("ml/dmlc/tvm/TVMValueHandle");
jmethodID constructor = env->GetMethodID(cls, "<init>", "(J)V");
jobject object = env->NewObject(cls, constructor, value);
env->DeleteLocalRef(cls);
return object;
}
jobject newTVMValueLong(JNIEnv *env, jlong value) { jobject newTVMValueLong(JNIEnv *env, jlong value) {
jclass cls = env->FindClass("ml/dmlc/tvm/TVMValueLong"); jclass cls = env->FindClass("ml/dmlc/tvm/TVMValueLong");
jmethodID constructor = env->GetMethodID(cls, "<init>", "(J)V"); jmethodID constructor = env->GetMethodID(cls, "<init>", "(J)V");
...@@ -166,6 +174,8 @@ jobject tvmRetValueToJava(JNIEnv *env, TVMValue value, int tcode) { ...@@ -166,6 +174,8 @@ jobject tvmRetValueToJava(JNIEnv *env, TVMValue value, int tcode) {
return newTVMValueLong(env, static_cast<jlong>(value.v_int64)); return newTVMValueLong(env, static_cast<jlong>(value.v_int64));
case kDLFloat: case kDLFloat:
return newTVMValueDouble(env, static_cast<jdouble>(value.v_float64)); return newTVMValueDouble(env, static_cast<jdouble>(value.v_float64));
case kHandle:
return newTVMValueHandle(env, reinterpret_cast<jlong>(value.v_handle));
case kModuleHandle: case kModuleHandle:
return newModule(env, reinterpret_cast<jlong>(value.v_handle)); return newModule(env, reinterpret_cast<jlong>(value.v_handle));
case kFuncHandle: case kFuncHandle:
......
...@@ -8,6 +8,7 @@ TEMP_DIR=$(mktemp -d) ...@@ -8,6 +8,7 @@ TEMP_DIR=$(mktemp -d)
python $SCRIPT_DIR/test_add_cpu.py $TEMP_DIR || exit -1 python $SCRIPT_DIR/test_add_cpu.py $TEMP_DIR || exit -1
python $SCRIPT_DIR/test_add_gpu.py $TEMP_DIR || exit -1 python $SCRIPT_DIR/test_add_gpu.py $TEMP_DIR || exit -1
python $SCRIPT_DIR/test_graph_runtime.py $TEMP_DIR || exit -1
# start rpc proxy server # start rpc proxy server
PORT=$(( ( RANDOM % 1000 ) + 9000 )) PORT=$(( ( RANDOM % 1000 ) + 9000 ))
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment