Unverified Commit 10ae8ee1 by Tianqi Chen Committed by GitHub

[RUNTIME] Support TVMContext (#1720)

parent dd9589ec
......@@ -646,6 +646,11 @@ class TVMRetValue : public TVMPODValue_ {
value_.v_int64 = value;
return *this;
}
TVMRetValue& operator=(TVMContext value) {
this->SwitchToPOD(kTVMContext);
value_.v_ctx = value;
return *this;
}
TVMRetValue& operator=(TVMType t) {
this->SwitchToPOD(kTVMType);
value_.v_type = t;
......
......@@ -15,7 +15,7 @@ from . import ndarray as _nd
from .ndarray import NDArrayBase, _make_array
from .types import TVMValue, TypeCode
from .types import TVMPackedCFunc, TVMCFuncFinalizer
from .types import RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func
from .types import RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func, _ctx_to_int64
from .node import NodeBase
from . import node as _node
......@@ -110,7 +110,7 @@ def _make_tvm_args(args, temp_args):
values[i].v_str = c_str(str(arg))
type_codes[i] = TypeCode.STR
elif isinstance(arg, TVMContext):
values[i].v_ctx = arg
values[i].v_int64 = _ctx_to_int64(arg)
type_codes[i] = TypeCode.TVM_CONTEXT
elif isinstance(arg, bytearray):
arr = TVMByteArray()
......
......@@ -3,8 +3,9 @@
from __future__ import absolute_import as _abs
import ctypes
import struct
from ..base import py_str, check_call, _LIB
from ..runtime_ctypes import TVMByteArray, TypeCode
from ..runtime_ctypes import TVMByteArray, TypeCode, TVMContext
class TVMValue(ctypes.Union):
"""TVMValue in C API"""
......@@ -36,7 +37,7 @@ def _return_handle(x):
return handle
def _return_bytes(x):
"""return handle"""
"""return bytes"""
handle = x.v_handle
if not isinstance(handle, ctypes.c_void_p):
handle = ctypes.c_void_p(handle)
......@@ -48,6 +49,15 @@ def _return_bytes(x):
raise RuntimeError('memmove failed')
return res
def _return_context(value):
"""return TVMContext"""
# use bit unpacking from int64 view
# We use this to get around ctypes issue on Union of Structure
data = struct.pack("=q", value.v_int64)
arr = struct.unpack("=ii", data)
return TVMContext(arr[0], arr[1])
def _wrap_arg_func(return_f, type_code):
tcode = ctypes.c_int(type_code)
def _wrap_func(x):
......@@ -55,13 +65,20 @@ def _wrap_arg_func(return_f, type_code):
return return_f(x)
return _wrap_func
def _ctx_to_int64(ctx):
"""Pack context into int64 in native endian"""
data = struct.pack("=ii", ctx.device_type, ctx.device_id)
return struct.unpack("=q", data)[0]
RETURN_SWITCH = {
TypeCode.INT: lambda x: x.v_int64,
TypeCode.FLOAT: lambda x: x.v_float64,
TypeCode.HANDLE: _return_handle,
TypeCode.NULL: lambda x: None,
TypeCode.STR: lambda x: py_str(x.v_str),
TypeCode.BYTES: _return_bytes
TypeCode.BYTES: _return_bytes,
TypeCode.TVM_CONTEXT: _return_context
}
C_TO_PY_ARG_SWITCH = {
......@@ -70,5 +87,6 @@ C_TO_PY_ARG_SWITCH = {
TypeCode.HANDLE: _return_handle,
TypeCode.NULL: lambda x: None,
TypeCode.STR: lambda x: py_str(x.v_str),
TypeCode.BYTES: _return_bytes
TypeCode.BYTES: _return_bytes,
TypeCode.TVM_CONTEXT: _return_context
}
......@@ -35,6 +35,16 @@ TVM_REGISTER_API("_nop")
.set_body([](TVMArgs args, TVMRetValue *ret) {
});
TVM_REGISTER_API("_context_test")
.set_body([](TVMArgs args, TVMRetValue *ret) {
DLContext ctx = args[0];
int dtype = args[1];
int did = args[2];
CHECK_EQ(static_cast<int>(ctx.device_type), dtype);
CHECK_EQ(static_cast<int>(ctx.device_id), did);
*ret = ctx;
});
// internal fucntion used for debug and testing purposes
TVM_REGISTER_API("_ndarray_use_count")
.set_body([](TVMArgs args, TVMRetValue *ret) {
......
......@@ -70,6 +70,16 @@ def test_empty_array():
tvm.convert(myfunc)(x)
def test_ctx():
def test_ctx_func(ctx):
assert tvm.gpu(7) == ctx
return tvm.cpu(0)
x = test_ctx_func(tvm.gpu(7))
assert x == tvm.cpu(0)
x = tvm.opencl(10)
x = tvm._api_internal._context_test(x, x.device_type, x.device_id)
assert x == tvm.opencl(10)
if __name__ == "__main__":
test_empty_array()
test_get_global()
......@@ -77,3 +87,4 @@ if __name__ == "__main__":
test_convert()
test_return_func()
test_byte_array()
test_ctx()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment