Commit 68c4400e by Yizhi Liu Committed by Tianqi Chen

[tvm4j] register user-defined function (#251)

* [tvm4j] register user-defined function

* [tvm4j] define java function (pushArgToStack) to convert arguments to C TVMValue

* [tvm4j] make Module & Function extends TVMValue

* [tvm4j] make registered cb function return Object

* [tvm4j] add cb finalizer; add TVMValueBytes

* [tvm4j] support NDArrayBase cb arg

* [tvm4j] register cb function unit tests

* [tvm4j] pass Function.Callback to resource_handle

* [tvm4j] fix type cast
parent 532ee9b7
......@@ -21,7 +21,10 @@ import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
public class Function {
/**
* TVM Packed Function.
*/
public class Function extends TVMValue {
final long handle;
public final boolean isResident;
private boolean isReleased = false;
......@@ -76,24 +79,41 @@ public class Function {
* @param handle the handle to the underlying function.
* @param isResident Whether this is a resident function in jvm
*/
public Function(long handle, boolean isResident) {
Function(long handle, boolean isResident) {
super(TypeCode.FUNC_HANDLE);
this.handle = handle;
this.isResident = isResident;
}
Function(long handle) {
this(handle, false);
}
@Override protected void finalize() throws Throwable {
release();
super.finalize();
}
/**
* Easy for user to get the instance from returned TVMValue.
* @return this
*/
@Override public Function asFunction() {
return this;
}
@Override long asHandle() {
return handle;
}
/**
* Release the Function.
* <p>
* We highly recommend you to do this manually since the GC strategy is lazy
* and `finalize()` is not guaranteed to be called when GC happens.
* </p>
*/
public void release() {
@Override public void release() {
if (!isReleased) {
if (!isResident) {
Base.checkCall(Base._LIB.tvmFuncFree(handle));
......@@ -167,34 +187,143 @@ public class Function {
* @param arg NDArray.
* @return this
*/
public Function pushArg(NDArray arg) {
public Function pushArg(NDArrayBase arg) {
Base._LIB.tvmFuncPushArgHandle(arg.handle, TypeCode.ARRAY_HANDLE.id);
return this;
}
/**
* Push argument to the function.
* @param arg Module.
* @return this
*/
public Function pushArg(Module arg) {
Base._LIB.tvmFuncPushArgHandle(arg.handle, TypeCode.MODULE_HANDLE.id);
return this;
}
/**
* Push argument to the function.
* @param arg Function.
* @return this
*/
public Function pushArg(Function arg) {
Base._LIB.tvmFuncPushArgHandle(arg.handle, TypeCode.FUNC_HANDLE.id);
return this;
}
/**
* Push argument to the function.
* @param arg bytes.
* @return this
*/
public Function pushArg(byte[] arg) {
Base._LIB.tvmFuncPushArgBytes(arg);
return this;
}
/**
* Invoke function with arguments.
* @param args Can be Integer, Long, Float, Double, String, NDArray.
* @return the result.
*/
public TVMValue call(Object... args) {
for (Object arg : args) {
if (arg instanceof Integer) {
pushArg((Integer) arg);
} else if (arg instanceof Long) {
pushArg((Long) arg);
} else if (arg instanceof Float) {
pushArg((Float) arg);
} else if (arg instanceof Double) {
pushArg((Double) arg);
} else if (arg instanceof String) {
pushArg((String) arg);
} else if (arg instanceof NDArray) {
pushArg((NDArray) arg);
} else {
throw new IllegalArgumentException("Invalid argument: " + arg);
}
pushArgToStack(arg);
}
return invoke();
}
private static void pushArgToStack(Object arg) {
if (arg instanceof Integer) {
Base._LIB.tvmFuncPushArgLong((Integer) arg);
} else if (arg instanceof Long) {
Base._LIB.tvmFuncPushArgLong((Long) arg);
} else if (arg instanceof Float) {
Base._LIB.tvmFuncPushArgDouble((Float) arg);
} else if (arg instanceof Double) {
Base._LIB.tvmFuncPushArgDouble((Double) arg);
} else if (arg instanceof String) {
Base._LIB.tvmFuncPushArgString((String) arg);
} else if (arg instanceof byte[]) {
Base._LIB.tvmFuncPushArgBytes((byte[]) arg);
} else if (arg instanceof NDArrayBase) {
Base._LIB.tvmFuncPushArgHandle(((NDArrayBase) arg).handle, TypeCode.ARRAY_HANDLE.id);
} else if (arg instanceof Module) {
Base._LIB.tvmFuncPushArgHandle(((Module) arg).handle, TypeCode.MODULE_HANDLE.id);
} else if (arg instanceof Function) {
Base._LIB.tvmFuncPushArgHandle(((Function) arg).handle, TypeCode.FUNC_HANDLE.id);
} else if (arg instanceof TVMValue) {
TVMValue tvmArg = (TVMValue) arg;
switch (tvmArg.typeCode) {
case UINT:
case INT:
Base._LIB.tvmFuncPushArgLong(tvmArg.asLong());
break;
case FLOAT:
Base._LIB.tvmFuncPushArgDouble(tvmArg.asDouble());
break;
case STR:
Base._LIB.tvmFuncPushArgString(tvmArg.asString());
break;
case BYTES:
Base._LIB.tvmFuncPushArgBytes(tvmArg.asBytes());
break;
case ARRAY_HANDLE:
case MODULE_HANDLE:
case FUNC_HANDLE:
Base._LIB.tvmFuncPushArgHandle(tvmArg.asHandle(), tvmArg.typeCode.id);
break;
default:
throw new IllegalArgumentException("Invalid argument: " + arg);
}
} else {
throw new IllegalArgumentException("Invalid argument: " + arg);
}
}
public static interface Callback {
public Object invoke(TVMValue... args);
}
/**
* Register user-defined global function.
* @param name The function name.
* @param function The function to be registered.
* @param override Whether override existing entry.
*/
public static void register(String name, Callback function, boolean override) {
Base.RefLong createdFuncHandleRef = new Base.RefLong();
Base.checkCall(Base._LIB.tvmFuncCreateFromCFunc(function, createdFuncHandleRef));
int ioverride = override ? 1 : 0;
Base.checkCall(Base._LIB.tvmFuncRegisterGlobal(name, createdFuncHandleRef.value, ioverride));
}
/**
* Register user-defined global function, do not override existing entry.
* @param name The function name.
* @param function The function to be registered.
*/
public static void register(String name, Callback function) {
register(name, function, false);
}
/**
* Convert a Java function to TVM function.
* @param function Java function.
* @return TVM function.
*/
public static Function convertFunc(Callback function) {
Base.RefLong createdFuncHandleRef = new Base.RefLong();
Base.checkCall(Base._LIB.tvmFuncCreateFromCFunc(function, createdFuncHandleRef));
return new Function(createdFuncHandleRef.value);
}
private static Object invokeRegisteredCbFunc(Callback cb, TVMValue[] args) {
if (cb == null) {
System.err.println("[ERROR] Failed to get registered function");
return null;
}
return cb.invoke(args);
}
}
......@@ -20,56 +20,57 @@ package ml.dmlc.tvm;
import java.util.List;
class LibInfo {
public native int nativeLibInit(String tvmLibFile);
native int nativeLibInit(String tvmLibFile);
public native int shutdown();
native int shutdown();
public native String tvmGetLastError();
native String tvmGetLastError();
// Function
public native void tvmFuncPushArgLong(long arg);
native void tvmFuncPushArgLong(long arg);
public native void tvmFuncPushArgDouble(double arg);
native void tvmFuncPushArgDouble(double arg);
public native void tvmFuncPushArgString(String arg);
native void tvmFuncPushArgString(String arg);
public native void tvmFuncPushArgHandle(long arg, int argType);
native void tvmFuncPushArgBytes(byte[] arg);
public native int tvmFuncListGlobalNames(List<String> funcNames);
native void tvmFuncPushArgHandle(long arg, int argType);
public native int tvmFuncFree(long handle);
native int tvmFuncListGlobalNames(List<String> funcNames);
public native int tvmFuncGetGlobal(String name, Base.RefLong handle);
native int tvmFuncFree(long handle);
public native int tvmFuncCall(long handle, Base.RefTVMValue retVal);
native int tvmFuncGetGlobal(String name, Base.RefLong handle);
native int tvmFuncCall(long handle, Base.RefTVMValue retVal);
native int tvmFuncCreateFromCFunc(Function.Callback function, Base.RefLong handle);
native int tvmFuncRegisterGlobal(String name, long handle, int override);
// Module
public native int tvmModFree(long handle);
native int tvmModFree(long handle);
public native int tvmModGetFunction(long handle, String name,
native int tvmModGetFunction(long handle, String name,
int queryImports, Base.RefLong retHandle);
public native int tvmModImport(long mod, long dep);
native int tvmModImport(long mod, long dep);
// NDArray
public native int tvmArrayFree(long handle);
native int tvmArrayFree(long handle);
public native int tvmArrayAlloc(long[] shape,
int dtypeCode,
int dtypeBits,
int dtypeLanes,
int deviceType,
int deviceId,
Base.RefLong refHandle);
native int tvmArrayAlloc(long[] shape, int dtypeCode, int dtypeBits, int dtypeLanes,
int deviceType, int deviceId, Base.RefLong refHandle);
public native int tvmArrayGetShape(long handle, List<Long> shape);
native int tvmArrayGetShape(long handle, List<Long> shape);
public native int tvmArrayCopyFromTo(long from, long to);
native int tvmArrayCopyFromTo(long from, long to);
public native int tvmArrayCopyFromJArray(byte[] fromRaw, long from, long to);
native int tvmArrayCopyFromJArray(byte[] fromRaw, long from, long to);
public native int tvmArrayCopyToJArray(long from, byte[] to);
native int tvmArrayCopyToJArray(long from, byte[] to);
// TVMContext
public native int tvmSynchronize(int deviceType, int deviceId);
native int tvmSynchronize(int deviceType, int deviceId);
}
......@@ -23,7 +23,7 @@ import java.util.Map;
/**
* Container of compiled functions of TVM.
*/
public class Module {
public class Module extends TVMValue {
public final long handle;
private boolean isReleased = false;
......@@ -44,7 +44,8 @@ public class Module {
return func;
}
public Module(long handle) {
Module(long handle) {
super(TypeCode.MODULE_HANDLE);
this.handle = handle;
}
......@@ -57,13 +58,25 @@ public class Module {
}
/**
* Easy for user to get the instance from returned TVMValue.
* @return this
*/
@Override public Module asModule() {
return this;
}
@Override long asHandle() {
return handle;
}
/**
* Release the Module.
* <p>
* We highly recommend you to do this manually since the GC strategy is lazy
* and `finalize()` is not guaranteed to be called when GC happens.
* </p>
*/
public void release() {
@Override public void release() {
if (!isReleased) {
Base.checkCall(Base._LIB.tvmModFree(handle));
isReleased = true;
......
......@@ -25,49 +25,19 @@ import java.util.List;
/**
* Lightweight NDArray class of TVM runtime.
*/
public class NDArray {
public final long handle;
private final boolean isView;
public class NDArray extends NDArrayBase {
private final TVMType dtype;
private boolean isReleased = false;
NDArray(long handle, boolean isView, TVMType dtype) {
this.handle = handle;
this.isView = isView;
super(handle, isView);
this.dtype = dtype;
}
NDArray(long handle) {
this(handle, false, new TVMType("float32", 1));
}
NDArray(long handle, boolean isView) {
this(handle, isView, new TVMType("float32", 1));
}
@Override protected void finalize() throws Throwable {
release();
super.finalize();
}
/**
* Release the NDArray memory.
* <p>
* We highly recommend you to do this manually since the GC strategy is lazy
* and `finalize()` is not guaranteed to be called when GC happens.
* </p>
*/
public void release() {
if (!isReleased) {
if (!isView) {
Base.checkCall(Base._LIB.tvmArrayFree(handle));
isReleased = true;
}
}
}
/**
* Copy from a native array.
* The NDArray type must by float64
* @param sourceArray the source data
......@@ -366,7 +336,7 @@ public class NDArray {
*/
public byte[] internal() {
NDArray tmp = NDArray.empty(shape(), dtype);
Base.checkCall(Base._LIB.tvmArrayCopyFromTo(handle, tmp.handle));
copyTo(tmp);
int arrLength = dtype.numOfBytes * (int) size();
byte[] arr = new byte[arrLength];
......
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.dmlc.tvm;
/**
* Base class of NDArray. To handle callback array.
* Only deep-copy supported.
*/
public class NDArrayBase extends TVMValue {
protected final long handle;
protected final boolean isView;
private boolean isReleased = false;
NDArrayBase(long handle, boolean isView) {
super(TypeCode.ARRAY_HANDLE);
this.handle = handle;
this.isView = isView;
}
NDArrayBase(long handle) {
this(handle, true);
}
@Override public NDArrayBase asNDArray() {
return this;
}
@Override long asHandle() {
return handle;
}
/**
* Copy array to target
* @param target The target array to be copied, must have same shape as this array.
* @return target
*/
public NDArrayBase copyTo(NDArrayBase target) {
Base.checkCall(Base._LIB.tvmArrayCopyFromTo(handle, target.handle));
return target;
}
/**
* Release the NDArray memory.
* <p>
* We highly recommend you to do this manually since the GC strategy is lazy
* and `finalize()` is not guaranteed to be called when GC happens.
* </p>
*/
public void release() {
if (!isReleased) {
if (!isView) {
Base.checkCall(Base._LIB.tvmArrayFree(handle));
isReleased = true;
}
}
}
@Override protected void finalize() throws Throwable {
release();
super.finalize();
}
}
......@@ -24,6 +24,9 @@ public class TVMValue {
typeCode = tc;
}
public void release() {
}
public long asLong() {
throw new UnsupportedOperationException();
}
......@@ -32,15 +35,28 @@ public class TVMValue {
throw new UnsupportedOperationException();
}
public byte[] asBytes() {
throw new UnsupportedOperationException();
}
public Module asModule() {
throw new UnsupportedOperationException();
}
public NDArray asNDArray() {
public Function asFunction() {
throw new UnsupportedOperationException();
}
public NDArrayBase asNDArray() {
throw new UnsupportedOperationException();
}
public String asString() {
throw new UnsupportedOperationException();
}
// easy for JNI to use.
long asHandle() {
throw new UnsupportedOperationException();
}
}
......@@ -17,15 +17,15 @@
package ml.dmlc.tvm;
public class TVMValueModuleHandle extends TVMValue {
public final long value;
public class TVMValueBytes extends TVMValue {
public final byte[] value;
public TVMValueModuleHandle(long value) {
super(TypeCode.MODULE_HANDLE);
public TVMValueBytes(byte[] value) {
super(TypeCode.BYTES);
this.value = value;
}
@Override public Module asModule() {
return new Module(value);
@Override public byte[] asBytes() {
return value;
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.dmlc.tvm;
public class TVMValueNDArrayHandle extends TVMValue {
public final long value;
public TVMValueNDArrayHandle(long value) {
super(TypeCode.ARRAY_HANDLE);
this.value = value;
}
@Override public NDArray asNDArray() {
return new NDArray(value);
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.dmlc.tvm;
import org.junit.Test;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
public class FunctionTest {
@Test
public void test_reg_sum_number() {
Function.register("sum_number", new Function.Callback() {
@Override public Object invoke(TVMValue... args) {
long res = 0L;
for (TVMValue arg : args) {
res += arg.asLong();
}
return res;
}
});
Function func = Function.getFunction("sum_number");
TVMValue res = func.pushArg(10).pushArg(20).invoke();
assertEquals(30, res.asLong());
res.release();
func.release();
}
@Test
public void test_add_string() {
Function func = Function.convertFunc(new Function.Callback() {
@Override public Object invoke(TVMValue... args) {
String res = "";
for (TVMValue arg : args) {
res += arg.asString();
}
return res;
}
});
TVMValue res = func.pushArg("Hello").pushArg(" ").pushArg("World!").invoke();
assertEquals("Hello World!", res.asString());
res.release();
func.release();
}
@Test
public void test_sum_first_byte() {
Function func = Function.convertFunc(new Function.Callback() {
@Override public Object invoke(TVMValue... args) {
byte[] bt = new byte[1];
for (TVMValue arg : args) {
bt[0] += arg.asBytes()[0];
}
return bt;
}
});
TVMValue res = func.pushArg(new byte[]{1}).pushArg(new byte[]{2, 3}).invoke();
assertArrayEquals(new byte[]{3}, res.asBytes());
res.release();
func.release();
}
@Test
public void test_sum_ndarray() {
final long[] shape = new long[]{2, 1};
Function func = Function.convertFunc(new Function.Callback() {
@Override public Object invoke(TVMValue... args) {
double sum = 0.0;
for (TVMValue arg : args) {
NDArray arr = NDArray.empty(shape, new TVMType("float32"));
arg.asNDArray().copyTo(arr);
float[] nativeArr = arr.asFloatArray();
for (int i = 0; i < nativeArr.length; ++i) {
sum += nativeArr[i];
}
arr.release();
}
return sum;
}
});
NDArray arr = NDArray.empty(shape, new TVMType("float32"));
arr.copyFrom(new float[]{2f, 3f});
TVMValue res = func.pushArg(arr).pushArg(arr).invoke();
assertEquals(10.0, res.asDouble(), 1e-3);
res.release();
func.release();
}
@Test
public void test_return_function() {
Function myFunc = Function.convertFunc(new Function.Callback() {
@Override public Object invoke(TVMValue... args) {
final long y = args[0].asLong();
return Function.convertFunc(new Function.Callback() {
@Override public Object invoke(TVMValue... args) {
final long x = args[0].asLong();
return x + y;
}
});
}
});
Function func = myFunc.pushArg(10).invoke().asFunction();
TVMValue res = func.pushArg(20).invoke();
assertEquals(30, res.asLong());
func.release();
myFunc.release();
}
}
......@@ -88,8 +88,46 @@ jobject newTVMValueDouble(JNIEnv *env, jdouble value) {
return object;
}
jobject newTVMValueModuleHandle(JNIEnv *env, jlong value) {
jclass cls = env->FindClass("ml/dmlc/tvm/TVMValueModuleHandle");
jobject newTVMValueString(JNIEnv *env, const char *value) {
jstring jvalue = env->NewStringUTF(value);
jclass cls = env->FindClass("ml/dmlc/tvm/TVMValueString");
jmethodID constructor = env->GetMethodID(cls, "<init>", "(Ljava/lang/String;)V");
jobject object = env->NewObject(cls, constructor, jvalue);
env->DeleteLocalRef(cls);
env->DeleteLocalRef(jvalue);
return object;
}
jobject newTVMValueBytes(JNIEnv *env, const TVMByteArray *arr) {
jbyteArray jarr = env->NewByteArray(arr->size);
env->SetByteArrayRegion(jarr, 0, arr->size,
reinterpret_cast<jbyte *>(const_cast<char *>(arr->data)));
jclass cls = env->FindClass("ml/dmlc/tvm/TVMValueBytes");
jmethodID constructor = env->GetMethodID(cls, "<init>", "([B)V");
jobject object = env->NewObject(cls, constructor, jarr);
env->DeleteLocalRef(cls);
env->DeleteLocalRef(jarr);
return object;
}
jobject newModule(JNIEnv *env, jlong value) {
jclass cls = env->FindClass("ml/dmlc/tvm/Module");
jmethodID constructor = env->GetMethodID(cls, "<init>", "(J)V");
jobject object = env->NewObject(cls, constructor, value);
env->DeleteLocalRef(cls);
return object;
}
jobject newFunction(JNIEnv *env, jlong value) {
jclass cls = env->FindClass("ml/dmlc/tvm/Function");
jmethodID constructor = env->GetMethodID(cls, "<init>", "(J)V");
jobject object = env->NewObject(cls, constructor, value);
env->DeleteLocalRef(cls);
return object;
}
jobject newNDArray(JNIEnv *env, jlong value) {
jclass cls = env->FindClass("ml/dmlc/tvm/NDArrayBase");
jmethodID constructor = env->GetMethodID(cls, "<init>", "(J)V");
jobject object = env->NewObject(cls, constructor, value);
env->DeleteLocalRef(cls);
......@@ -121,4 +159,28 @@ void fromJavaContext(JNIEnv *env, jobject jctx, TVMContext *ctx) {
env->DeleteLocalRef(tvmContextClass);
}
jobject tvmRetValueToJava(JNIEnv *env, TVMValue value, int tcode) {
switch (tcode) {
case kUInt:
case kInt:
return newTVMValueLong(env, static_cast<jlong>(value.v_int64));
case kFloat:
return newTVMValueDouble(env, static_cast<jdouble>(value.v_float64));
case kModuleHandle:
return newModule(env, reinterpret_cast<jlong>(value.v_handle));
case kFuncHandle:
return newFunction(env, reinterpret_cast<jlong>(value.v_handle));
case kArrayHandle:
return newNDArray(env, reinterpret_cast<jlong>(value.v_handle));
case kStr:
return newTVMValueString(env, value.v_str);
case kBytes:
return newTVMValueBytes(env, reinterpret_cast<TVMByteArray *>(value.v_handle));
case kNull:
return newObject(env, "ml/dmlc/tvm/TVMValueNull");
default:
LOG(FATAL) << "Do NOT know how to handle return type code " << tcode;
}
}
#endif // TVM4J_JNI_MAIN_NATIVE_JNI_HELPER_FUNC_H_
......@@ -22,6 +22,7 @@ struct TVMFuncArgsThreadLocalEntry {
std::vector<int> tvmFuncArgTypes;
// for later release
std::vector<std::pair<jstring, const char *> > tvmFuncArgPushedStrs;
std::vector<std::pair<jbyteArray, TVMByteArray *> > tvmFuncArgPushedBytes;
};
typedef dmlc::ThreadLocalStore<TVMFuncArgsThreadLocalEntry> TVMFuncArgsThreadLocalStore;
......@@ -90,6 +91,26 @@ JNIEXPORT void JNICALL Java_ml_dmlc_tvm_LibInfo_tvmFuncPushArgHandle(
e->tvmFuncArgTypes.push_back(static_cast<int>(argType));
}
JNIEXPORT void JNICALL Java_ml_dmlc_tvm_LibInfo_tvmFuncPushArgBytes(
JNIEnv *env, jobject obj, jbyteArray arg) {
jbyteArray garg = reinterpret_cast<jbyteArray>(env->NewGlobalRef(arg));
jbyte *data = env->GetByteArrayElements(garg, 0);
TVMByteArray *byteArray = new TVMByteArray();
byteArray->size = static_cast<size_t>(env->GetArrayLength(garg));
byteArray->data = reinterpret_cast<const char *>(data);
TVMValue value;
value.v_handle = reinterpret_cast<void *>(byteArray);
TVMFuncArgsThreadLocalEntry *e = TVMFuncArgsThreadLocalStore::Get();
e->tvmFuncArgValues.push_back(value);
e->tvmFuncArgTypes.push_back(kBytes);
e->tvmFuncArgPushedBytes.push_back(std::make_pair(garg, byteArray));
// release (garg, data), byteArray later
}
JNIEXPORT jint JNICALL Java_ml_dmlc_tvm_LibInfo_tvmFuncListGlobalNames(
JNIEnv *env, jobject obj, jobject jfuncNames) {
int outSize;
......@@ -145,7 +166,16 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_tvm_LibInfo_tvmFuncCall(
env->ReleaseStringUTFChars(iter->first, iter->second);
env->DeleteGlobalRef(iter->first);
}
for (auto iter = e->tvmFuncArgPushedBytes.cbegin();
iter != e->tvmFuncArgPushedBytes.cend(); iter++) {
env->ReleaseByteArrayElements(iter->first,
reinterpret_cast<jbyte *>(const_cast<char *>(iter->second->data)), 0);
env->DeleteGlobalRef(iter->first);
delete iter->second;
}
e->tvmFuncArgPushedStrs.clear();
e->tvmFuncArgPushedBytes.clear();
e->tvmFuncArgTypes.clear();
e->tvmFuncArgValues.clear();
......@@ -154,32 +184,116 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_tvm_LibInfo_tvmFuncCall(
jfieldID refTVMValueFid
= env->GetFieldID(refTVMValueCls, "value", "Lml/dmlc/tvm/TVMValue;");
switch (retTypeCode) {
case kInt:
env->SetObjectField(jretVal, refTVMValueFid,
newTVMValueLong(env, static_cast<jlong>(retVal.v_int64)));
break;
case kFloat:
env->SetObjectField(jretVal, refTVMValueFid,
newTVMValueDouble(env, static_cast<jdouble>(retVal.v_float64)));
break;
case kModuleHandle:
env->SetObjectField(jretVal, refTVMValueFid,
newTVMValueModuleHandle(env, reinterpret_cast<jlong>(retVal.v_handle)));
break;
case kNull:
env->SetObjectField(jretVal, refTVMValueFid,
newObject(env, "ml/dmlc/tvm/TVMValueNull"));
break;
default:
LOG(FATAL) << "Do NOT know how to handle return type code " << retTypeCode;
}
env->SetObjectField(jretVal, refTVMValueFid, tvmRetValueToJava(env, retVal, retTypeCode));
env->DeleteLocalRef(refTVMValueCls);
return ret;
}
// Callback function
extern "C" int funcInvokeCallback(TVMValue *args,
int *typeCodes, int numArgs, TVMRetValueHandle ret, void *resourceHandle) {
JNIEnv *env;
int jniStatus = _jvm->GetEnv(reinterpret_cast<void **>(&env), JNI_VERSION_1_6);
if (jniStatus == JNI_EDETACHED) {
_jvm->AttachCurrentThread(reinterpret_cast<void **>(&env), nullptr);
} else {
CHECK(jniStatus == JNI_OK);
}
jclass tvmValueCls = env->FindClass("ml/dmlc/tvm/TVMValue");
jobjectArray jargs = env->NewObjectArray(numArgs, tvmValueCls, 0);
for (int i = 0; i < numArgs; ++i) {
TVMValue arg = args[i];
int tcode = typeCodes[i];
if (tcode == kNodeHandle || tcode == kFuncHandle || tcode == kModuleHandle) {
TVMCbArgToReturn(&arg, tcode);
}
jobject jarg = tvmRetValueToJava(env, arg, tcode);
env->SetObjectArrayElement(jargs, i, jarg);
}
jclass clsFunc = env->FindClass("ml/dmlc/tvm/Function");
jmethodID invokeRegisteredCbFunc = env->GetStaticMethodID(clsFunc, "invokeRegisteredCbFunc",
"(Lml/dmlc/tvm/Function$Callback;[Lml/dmlc/tvm/TVMValue;)Ljava/lang/Object;");
jmethodID pushArgToStack = env->GetStaticMethodID(clsFunc, "pushArgToStack",
"(Ljava/lang/Object;)V");
jobject jretValue = env->CallStaticObjectMethod(clsFunc, invokeRegisteredCbFunc,
reinterpret_cast<jobject>(resourceHandle), jargs);
TVMFuncArgsThreadLocalEntry *e = TVMFuncArgsThreadLocalStore::Get();
const int prevNumStrArg = e->tvmFuncArgPushedStrs.size();
const int prevNumBytesArg = e->tvmFuncArgPushedBytes.size();
// convert returned (java) TVMValue to (C) TVMValue
env->CallStaticVoidMethod(clsFunc, pushArgToStack, jretValue);
TVMValue retValue = e->tvmFuncArgValues.back();
e->tvmFuncArgValues.pop_back();
int retCode = e->tvmFuncArgTypes.back();
e->tvmFuncArgTypes.pop_back();
// set back the return value
TVMCFuncSetReturn(ret, &retValue, &retCode, 1);
// release allocated strings.
if (e->tvmFuncArgPushedStrs.size() > prevNumStrArg) {
const auto &pairArg = e->tvmFuncArgPushedStrs.back();
env->ReleaseStringUTFChars(pairArg.first, pairArg.second);
env->DeleteGlobalRef(pairArg.first);
e->tvmFuncArgPushedStrs.pop_back();
}
// release allocated bytes.
if (e->tvmFuncArgPushedBytes.size() > prevNumBytesArg) {
const auto &pairArg = e->tvmFuncArgPushedBytes.back();
env->ReleaseByteArrayElements(pairArg.first,
reinterpret_cast<jbyte *>(const_cast<char *>(pairArg.second->data)), 0);
env->DeleteGlobalRef(pairArg.first);
delete pairArg.second;
e->tvmFuncArgPushedBytes.pop_back();
}
env->DeleteLocalRef(clsFunc);
env->DeleteLocalRef(tvmValueCls);
return 0;
}
// Free callback function
extern "C" void funcFreeCallback(void *resourceHandle) {
JNIEnv *env;
int jniStatus = _jvm->GetEnv(reinterpret_cast<void **>(&env), JNI_VERSION_1_6);
if (jniStatus == JNI_EDETACHED) {
_jvm->AttachCurrentThread(reinterpret_cast<void **>(&env), nullptr);
} else {
CHECK(jniStatus == JNI_OK);
}
env->DeleteGlobalRef(reinterpret_cast<jobject>(resourceHandle));
}
JNIEXPORT jint JNICALL Java_ml_dmlc_tvm_LibInfo_tvmFuncCreateFromCFunc(
JNIEnv *env, jobject obj, jobject jfunction, jobject jretHandle) {
TVMFunctionHandle out;
int ret = TVMFuncCreateFromCFunc(reinterpret_cast<TVMPackedCFunc>(&funcInvokeCallback),
reinterpret_cast<void *>(env->NewGlobalRef(jfunction)),
reinterpret_cast<TVMPackedCFuncFinalizer>(&funcFreeCallback),
&out);
setLongField(env, jretHandle, reinterpret_cast<jlong>(out));
return ret;
}
JNIEXPORT jint JNICALL Java_ml_dmlc_tvm_LibInfo_tvmFuncRegisterGlobal(
JNIEnv *env, jobject obj, jstring jname, jlong jhandle, jint joverride) {
const char *name = env->GetStringUTFChars(jname, 0);
int ret = TVMFuncRegisterGlobal(
name, reinterpret_cast<TVMFunctionHandle>(jhandle), reinterpret_cast<int>(joverride));
env->ReleaseStringUTFChars(jname, name);
return ret;
}
// Module
JNIEXPORT jint JNICALL Java_ml_dmlc_tvm_LibInfo_tvmModFree(
JNIEnv *env, jobject obj, jlong jhandle) {
......@@ -223,14 +337,14 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_tvm_LibInfo_tvmArrayAlloc(
jlong *shapeArray = env->GetLongArrayElements(jshape, NULL);
int ret = TVMArrayAlloc(
reinterpret_cast<const tvm_index_t*>(shapeArray),
ndim,
static_cast<int>(jdtypeCode),
static_cast<int>(jdtypeBits),
static_cast<int>(jdtypeLanes),
static_cast<int>(jdeviceType),
static_cast<int>(jdeviceId),
&out);
reinterpret_cast<const tvm_index_t*>(shapeArray),
ndim,
static_cast<int>(jdtypeCode),
static_cast<int>(jdtypeBits),
static_cast<int>(jdtypeLanes),
static_cast<int>(jdeviceType),
static_cast<int>(jdeviceId),
&out);
env->ReleaseLongArrayElements(jshape, shapeArray, 0);
setLongField(env, jret, reinterpret_cast<jlong>(out));
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment