Unverified Commit 10ae8ee1 by Tianqi Chen Committed by GitHub

[RUNTIME] Support TVMContext (#1720)

parent dd9589ec
...@@ -646,6 +646,11 @@ class TVMRetValue : public TVMPODValue_ { ...@@ -646,6 +646,11 @@ class TVMRetValue : public TVMPODValue_ {
value_.v_int64 = value; value_.v_int64 = value;
return *this; return *this;
} }
TVMRetValue& operator=(TVMContext value) {
this->SwitchToPOD(kTVMContext);
value_.v_ctx = value;
return *this;
}
TVMRetValue& operator=(TVMType t) { TVMRetValue& operator=(TVMType t) {
this->SwitchToPOD(kTVMType); this->SwitchToPOD(kTVMType);
value_.v_type = t; value_.v_type = t;
......
...@@ -15,7 +15,7 @@ from . import ndarray as _nd ...@@ -15,7 +15,7 @@ from . import ndarray as _nd
from .ndarray import NDArrayBase, _make_array from .ndarray import NDArrayBase, _make_array
from .types import TVMValue, TypeCode from .types import TVMValue, TypeCode
from .types import TVMPackedCFunc, TVMCFuncFinalizer from .types import TVMPackedCFunc, TVMCFuncFinalizer
from .types import RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func from .types import RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func, _ctx_to_int64
from .node import NodeBase from .node import NodeBase
from . import node as _node from . import node as _node
...@@ -110,7 +110,7 @@ def _make_tvm_args(args, temp_args): ...@@ -110,7 +110,7 @@ def _make_tvm_args(args, temp_args):
values[i].v_str = c_str(str(arg)) values[i].v_str = c_str(str(arg))
type_codes[i] = TypeCode.STR type_codes[i] = TypeCode.STR
elif isinstance(arg, TVMContext): elif isinstance(arg, TVMContext):
values[i].v_ctx = arg values[i].v_int64 = _ctx_to_int64(arg)
type_codes[i] = TypeCode.TVM_CONTEXT type_codes[i] = TypeCode.TVM_CONTEXT
elif isinstance(arg, bytearray): elif isinstance(arg, bytearray):
arr = TVMByteArray() arr = TVMByteArray()
......
...@@ -3,8 +3,9 @@ ...@@ -3,8 +3,9 @@
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
import ctypes import ctypes
import struct
from ..base import py_str, check_call, _LIB from ..base import py_str, check_call, _LIB
from ..runtime_ctypes import TVMByteArray, TypeCode from ..runtime_ctypes import TVMByteArray, TypeCode, TVMContext
class TVMValue(ctypes.Union): class TVMValue(ctypes.Union):
"""TVMValue in C API""" """TVMValue in C API"""
...@@ -36,7 +37,7 @@ def _return_handle(x): ...@@ -36,7 +37,7 @@ def _return_handle(x):
return handle return handle
def _return_bytes(x): def _return_bytes(x):
"""return handle""" """return bytes"""
handle = x.v_handle handle = x.v_handle
if not isinstance(handle, ctypes.c_void_p): if not isinstance(handle, ctypes.c_void_p):
handle = ctypes.c_void_p(handle) handle = ctypes.c_void_p(handle)
...@@ -48,6 +49,15 @@ def _return_bytes(x): ...@@ -48,6 +49,15 @@ def _return_bytes(x):
raise RuntimeError('memmove failed') raise RuntimeError('memmove failed')
return res return res
def _return_context(value):
"""return TVMContext"""
# use bit unpacking from int64 view
# We use this to get around ctypes issue on Union of Structure
data = struct.pack("=q", value.v_int64)
arr = struct.unpack("=ii", data)
return TVMContext(arr[0], arr[1])
def _wrap_arg_func(return_f, type_code): def _wrap_arg_func(return_f, type_code):
tcode = ctypes.c_int(type_code) tcode = ctypes.c_int(type_code)
def _wrap_func(x): def _wrap_func(x):
...@@ -55,13 +65,20 @@ def _wrap_arg_func(return_f, type_code): ...@@ -55,13 +65,20 @@ def _wrap_arg_func(return_f, type_code):
return return_f(x) return return_f(x)
return _wrap_func return _wrap_func
def _ctx_to_int64(ctx):
"""Pack context into int64 in native endian"""
data = struct.pack("=ii", ctx.device_type, ctx.device_id)
return struct.unpack("=q", data)[0]
RETURN_SWITCH = { RETURN_SWITCH = {
TypeCode.INT: lambda x: x.v_int64, TypeCode.INT: lambda x: x.v_int64,
TypeCode.FLOAT: lambda x: x.v_float64, TypeCode.FLOAT: lambda x: x.v_float64,
TypeCode.HANDLE: _return_handle, TypeCode.HANDLE: _return_handle,
TypeCode.NULL: lambda x: None, TypeCode.NULL: lambda x: None,
TypeCode.STR: lambda x: py_str(x.v_str), TypeCode.STR: lambda x: py_str(x.v_str),
TypeCode.BYTES: _return_bytes TypeCode.BYTES: _return_bytes,
TypeCode.TVM_CONTEXT: _return_context
} }
C_TO_PY_ARG_SWITCH = { C_TO_PY_ARG_SWITCH = {
...@@ -70,5 +87,6 @@ C_TO_PY_ARG_SWITCH = { ...@@ -70,5 +87,6 @@ C_TO_PY_ARG_SWITCH = {
TypeCode.HANDLE: _return_handle, TypeCode.HANDLE: _return_handle,
TypeCode.NULL: lambda x: None, TypeCode.NULL: lambda x: None,
TypeCode.STR: lambda x: py_str(x.v_str), TypeCode.STR: lambda x: py_str(x.v_str),
TypeCode.BYTES: _return_bytes TypeCode.BYTES: _return_bytes,
TypeCode.TVM_CONTEXT: _return_context
} }
...@@ -35,6 +35,16 @@ TVM_REGISTER_API("_nop") ...@@ -35,6 +35,16 @@ TVM_REGISTER_API("_nop")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
}); });
TVM_REGISTER_API("_context_test")
.set_body([](TVMArgs args, TVMRetValue *ret) {
DLContext ctx = args[0];
int dtype = args[1];
int did = args[2];
CHECK_EQ(static_cast<int>(ctx.device_type), dtype);
CHECK_EQ(static_cast<int>(ctx.device_id), did);
*ret = ctx;
});
// internal fucntion used for debug and testing purposes // internal fucntion used for debug and testing purposes
TVM_REGISTER_API("_ndarray_use_count") TVM_REGISTER_API("_ndarray_use_count")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
......
...@@ -70,6 +70,16 @@ def test_empty_array(): ...@@ -70,6 +70,16 @@ def test_empty_array():
tvm.convert(myfunc)(x) tvm.convert(myfunc)(x)
def test_ctx():
def test_ctx_func(ctx):
assert tvm.gpu(7) == ctx
return tvm.cpu(0)
x = test_ctx_func(tvm.gpu(7))
assert x == tvm.cpu(0)
x = tvm.opencl(10)
x = tvm._api_internal._context_test(x, x.device_type, x.device_id)
assert x == tvm.opencl(10)
if __name__ == "__main__": if __name__ == "__main__":
test_empty_array() test_empty_array()
test_get_global() test_get_global()
...@@ -77,3 +87,4 @@ if __name__ == "__main__": ...@@ -77,3 +87,4 @@ if __name__ == "__main__":
test_convert() test_convert()
test_return_func() test_return_func()
test_byte_array() test_byte_array()
test_ctx()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment