Commit feabd406 by Yizhi Liu Committed by Tianqi Chen

[tvm4j] support kNDArrayContainer (#1510)

parent f4b2c293
...@@ -187,7 +187,8 @@ public class Function extends TVMValue { ...@@ -187,7 +187,8 @@ public class Function extends TVMValue {
* @return this * @return this
*/ */
public Function pushArg(NDArrayBase arg) { public Function pushArg(NDArrayBase arg) {
Base._LIB.tvmFuncPushArgHandle(arg.handle, TypeCode.ARRAY_HANDLE.id); int id = arg.isView ? TypeCode.ARRAY_HANDLE.id : TypeCode.NDARRAY_CONTAINER.id;
Base._LIB.tvmFuncPushArgHandle(arg.handle, id);
return this; return this;
} }
...@@ -247,7 +248,9 @@ public class Function extends TVMValue { ...@@ -247,7 +248,9 @@ public class Function extends TVMValue {
} else if (arg instanceof byte[]) { } else if (arg instanceof byte[]) {
Base._LIB.tvmFuncPushArgBytes((byte[]) arg); Base._LIB.tvmFuncPushArgBytes((byte[]) arg);
} else if (arg instanceof NDArrayBase) { } else if (arg instanceof NDArrayBase) {
Base._LIB.tvmFuncPushArgHandle(((NDArrayBase) arg).handle, TypeCode.ARRAY_HANDLE.id); NDArrayBase nd = (NDArrayBase) arg;
int id = nd.isView ? TypeCode.ARRAY_HANDLE.id : TypeCode.NDARRAY_CONTAINER.id;
Base._LIB.tvmFuncPushArgHandle(nd.handle, id);
} else if (arg instanceof Module) { } else if (arg instanceof Module) {
Base._LIB.tvmFuncPushArgHandle(((Module) arg).handle, TypeCode.MODULE_HANDLE.id); Base._LIB.tvmFuncPushArgHandle(((Module) arg).handle, TypeCode.MODULE_HANDLE.id);
} else if (arg instanceof Function) { } else if (arg instanceof Function) {
......
...@@ -21,7 +21,7 @@ package ml.dmlc.tvm; ...@@ -21,7 +21,7 @@ package ml.dmlc.tvm;
public enum TypeCode { public enum TypeCode {
INT(0), UINT(1), FLOAT(2), HANDLE(3), NULL(4), TVM_TYPE(5), INT(0), UINT(1), FLOAT(2), HANDLE(3), NULL(4), TVM_TYPE(5),
TVM_CONTEXT(6), ARRAY_HANDLE(7), NODE_HANDLE(8), MODULE_HANDLE(9), TVM_CONTEXT(6), ARRAY_HANDLE(7), NODE_HANDLE(8), MODULE_HANDLE(9),
FUNC_HANDLE(10), STR(11), BYTES(12); FUNC_HANDLE(10), STR(11), BYTES(12), NDARRAY_CONTAINER(13);
public final int id; public final int id;
......
...@@ -134,10 +134,10 @@ jobject newFunction(JNIEnv *env, jlong value) { ...@@ -134,10 +134,10 @@ jobject newFunction(JNIEnv *env, jlong value) {
return object; return object;
} }
jobject newNDArray(JNIEnv *env, jlong value) { jobject newNDArray(JNIEnv *env, jlong handle, jboolean isview) {
jclass cls = env->FindClass("ml/dmlc/tvm/NDArrayBase"); jclass cls = env->FindClass("ml/dmlc/tvm/NDArrayBase");
jmethodID constructor = env->GetMethodID(cls, "<init>", "(J)V"); jmethodID constructor = env->GetMethodID(cls, "<init>", "(JZ)V");
jobject object = env->NewObject(cls, constructor, value); jobject object = env->NewObject(cls, constructor, handle, isview);
env->DeleteLocalRef(cls); env->DeleteLocalRef(cls);
return object; return object;
} }
...@@ -181,7 +181,9 @@ jobject tvmRetValueToJava(JNIEnv *env, TVMValue value, int tcode) { ...@@ -181,7 +181,9 @@ jobject tvmRetValueToJava(JNIEnv *env, TVMValue value, int tcode) {
case kFuncHandle: case kFuncHandle:
return newFunction(env, reinterpret_cast<jlong>(value.v_handle)); return newFunction(env, reinterpret_cast<jlong>(value.v_handle));
case kArrayHandle: case kArrayHandle:
return newNDArray(env, reinterpret_cast<jlong>(value.v_handle)); return newNDArray(env, reinterpret_cast<jlong>(value.v_handle), true);
case kNDArrayContainer:
return newNDArray(env, reinterpret_cast<jlong>(value.v_handle), false);
case kStr: case kStr:
return newTVMValueString(env, value.v_str); return newTVMValueString(env, value.v_str);
case kBytes: case kBytes:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment