Unverified Commit e68450da by Tianqi Chen Committed by GitHub

[PY][FFI] Introduce PyNativeObject, enable runtime.String to subclass str (#5426)

To make runtime.String to work as naturally as possible in the python side,
we make it sub-class the python's str object. Note that however, we cannot
sub-class Object at the same time due to python's type layout constraint.

We introduce a PyNativeObject class to handle this kind of object sub-classing
and updated the FFI to handle PyNativeObject classes.
parent 6c77195e
...@@ -108,7 +108,7 @@ if(MSVC) ...@@ -108,7 +108,7 @@ if(MSVC)
endif() endif()
else(MSVC) else(MSVC)
if ("${CMAKE_BUILD_TYPE}" STREQUAL "Debug") if ("${CMAKE_BUILD_TYPE}" STREQUAL "Debug")
message("Build in Debug mode") message(STATUS "Build in Debug mode")
set(CMAKE_C_FLAGS "-O0 -g -Wall -fPIC ${CMAKE_C_FLAGS}") set(CMAKE_C_FLAGS "-O0 -g -Wall -fPIC ${CMAKE_C_FLAGS}")
set(CMAKE_CXX_FLAGS "-O0 -g -Wall -fPIC ${CMAKE_CXX_FLAGS}") set(CMAKE_CXX_FLAGS "-O0 -g -Wall -fPIC ${CMAKE_CXX_FLAGS}")
set(CMAKE_CUDA_FLAGS "-O0 -g -Xcompiler=-Wall -Xcompiler=-fPIC ${CMAKE_CUDA_FLAGS}") set(CMAKE_CUDA_FLAGS "-O0 -g -Xcompiler=-Wall -Xcompiler=-fPIC ${CMAKE_CUDA_FLAGS}")
......
...@@ -50,6 +50,10 @@ def _return_object(x): ...@@ -50,6 +50,10 @@ def _return_object(x):
tindex = ctypes.c_uint() tindex = ctypes.c_uint()
check_call(_LIB.TVMObjectGetTypeIndex(handle, ctypes.byref(tindex))) check_call(_LIB.TVMObjectGetTypeIndex(handle, ctypes.byref(tindex)))
cls = OBJECT_TYPE.get(tindex.value, _CLASS_OBJECT) cls = OBJECT_TYPE.get(tindex.value, _CLASS_OBJECT)
if issubclass(cls, PyNativeObject):
obj = _CLASS_OBJECT.__new__(_CLASS_OBJECT)
obj.handle = handle
return cls.__from_tvm_object__(cls, obj)
# Avoid calling __init__ of cls, instead directly call __new__ # Avoid calling __init__ of cls, instead directly call __new__
# This allows child class to implement their own __init__ # This allows child class to implement their own __init__
obj = cls.__new__(cls) obj = cls.__new__(cls)
...@@ -64,6 +68,33 @@ C_TO_PY_ARG_SWITCH[TypeCode.OBJECT_RVALUE_REF_ARG] = _wrap_arg_func( ...@@ -64,6 +68,33 @@ C_TO_PY_ARG_SWITCH[TypeCode.OBJECT_RVALUE_REF_ARG] = _wrap_arg_func(
_return_object, TypeCode.OBJECT_RVALUE_REF_ARG) _return_object, TypeCode.OBJECT_RVALUE_REF_ARG)
class PyNativeObject:
"""Base class of all TVM objects that also subclass python's builtin types."""
__slots__ = []
def __init_tvm_object_by_constructor__(self, fconstructor, *args):
"""Initialize the internal tvm_object by calling constructor function.
Parameters
----------
fconstructor : Function
Constructor function.
args: list of objects
The arguments to the constructor
Note
----
We have a special calling convention to call constructor functions.
So the return object is directly set into the object
"""
# pylint: disable=assigning-non-slot
obj = _CLASS_OBJECT.__new__(_CLASS_OBJECT)
obj.__init_handle_by_constructor__(fconstructor, *args)
self.__tvm_object__ = obj
class ObjectBase(object): class ObjectBase(object):
"""Base object for all object types""" """Base object for all object types"""
__slots__ = ["handle"] __slots__ = ["handle"]
......
...@@ -29,7 +29,7 @@ from .ndarray import NDArrayBase, _make_array ...@@ -29,7 +29,7 @@ from .ndarray import NDArrayBase, _make_array
from .types import TVMValue, TypeCode from .types import TVMValue, TypeCode
from .types import TVMPackedCFunc, TVMCFuncFinalizer from .types import TVMPackedCFunc, TVMCFuncFinalizer
from .types import RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func, _ctx_to_int64 from .types import RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func, _ctx_to_int64
from .object import ObjectBase, _set_class_object from .object import ObjectBase, PyNativeObject, _set_class_object
from . import object as _object from . import object as _object
PackedFuncHandle = ctypes.c_void_p PackedFuncHandle = ctypes.c_void_p
...@@ -123,6 +123,9 @@ def _make_tvm_args(args, temp_args): ...@@ -123,6 +123,9 @@ def _make_tvm_args(args, temp_args):
values[i].v_handle = ctypes.cast(arg.handle, ctypes.c_void_p) values[i].v_handle = ctypes.cast(arg.handle, ctypes.c_void_p)
type_codes[i] = (TypeCode.NDARRAY_HANDLE type_codes[i] = (TypeCode.NDARRAY_HANDLE
if not arg.is_view else TypeCode.DLTENSOR_HANDLE) if not arg.is_view else TypeCode.DLTENSOR_HANDLE)
elif isinstance(arg, PyNativeObject):
values[i].v_handle = arg.__tvm_object__.handle
type_codes[i] = TypeCode.OBJECT_HANDLE
elif isinstance(arg, _nd._TVM_COMPATS): elif isinstance(arg, _nd._TVM_COMPATS):
values[i].v_handle = ctypes.c_void_p(arg._tvm_handle) values[i].v_handle = ctypes.c_void_p(arg._tvm_handle)
type_codes[i] = arg.__class__._tvm_tcode type_codes[i] = arg.__class__._tvm_tcode
......
...@@ -39,18 +39,49 @@ cdef inline object make_ret_object(void* chandle): ...@@ -39,18 +39,49 @@ cdef inline object make_ret_object(void* chandle):
object_type = OBJECT_TYPE object_type = OBJECT_TYPE
handle = ctypes_handle(chandle) handle = ctypes_handle(chandle)
CALL(TVMObjectGetTypeIndex(chandle, &tindex)) CALL(TVMObjectGetTypeIndex(chandle, &tindex))
if tindex < len(OBJECT_TYPE): if tindex < len(OBJECT_TYPE):
cls = OBJECT_TYPE[tindex] cls = OBJECT_TYPE[tindex]
if cls is not None: if cls is not None:
if issubclass(cls, PyNativeObject):
obj = _CLASS_OBJECT.__new__(_CLASS_OBJECT)
(<ObjectBase>obj).chandle = chandle
return cls.__from_tvm_object__(cls, obj)
obj = cls.__new__(cls) obj = cls.__new__(cls)
else: else:
obj = _CLASS_OBJECT.__new__(_CLASS_OBJECT) obj = _CLASS_OBJECT.__new__(_CLASS_OBJECT)
else: else:
obj = _CLASS_OBJECT.__new__(_CLASS_OBJECT) obj = _CLASS_OBJECT.__new__(_CLASS_OBJECT)
(<ObjectBase>obj).chandle = chandle (<ObjectBase>obj).chandle = chandle
return obj return obj
class PyNativeObject:
"""Base class of all TVM objects that also subclass python's builtin types."""
__slots__ = []
def __init_tvm_object_by_constructor__(self, fconstructor, *args):
"""Initialize the internal tvm_object by calling constructor function.
Parameters
----------
fconstructor : Function
Constructor function.
args: list of objects
The arguments to the constructor
Note
----
We have a special calling convention to call constructor functions.
So the return object is directly set into the object
"""
obj = _CLASS_OBJECT.__new__(_CLASS_OBJECT)
obj.__init_handle_by_constructor__(fconstructor, *args)
self.__tvm_object__ = obj
cdef class ObjectBase: cdef class ObjectBase:
cdef void* chandle cdef void* chandle
......
...@@ -109,6 +109,9 @@ cdef inline int make_arg(object arg, ...@@ -109,6 +109,9 @@ cdef inline int make_arg(object arg,
value[0].v_handle = (<NDArrayBase>arg).chandle value[0].v_handle = (<NDArrayBase>arg).chandle
tcode[0] = (kTVMNDArrayHandle if tcode[0] = (kTVMNDArrayHandle if
not (<NDArrayBase>arg).c_is_view else kTVMDLTensorHandle) not (<NDArrayBase>arg).c_is_view else kTVMDLTensorHandle)
elif isinstance(arg, PyNativeObject):
value[0].v_handle = (<ObjectBase>(arg.__tvm_object__)).chandle
tcode[0] = kTVMObjectHandle
elif isinstance(arg, _TVM_COMPATS): elif isinstance(arg, _TVM_COMPATS):
ptr = arg._tvm_handle ptr = arg._tvm_handle
value[0].v_handle = (<void*>ptr) value[0].v_handle = (<void*>ptr)
......
...@@ -16,9 +16,10 @@ ...@@ -16,9 +16,10 @@
# under the License. # under the License.
"""Runtime container structures.""" """Runtime container structures."""
import tvm._ffi import tvm._ffi
from tvm._ffi.base import string_types from .object import Object, PyNativeObject
from tvm.runtime import Object, ObjectTypes from .object_generic import ObjectTypes
from tvm.runtime import _ffi_api from . import _ffi_api
def getitem_helper(obj, elem_getter, length, idx): def getitem_helper(obj, elem_getter, length, idx):
"""Helper function to implement a pythonic getitem function. """Helper function to implement a pythonic getitem function.
...@@ -112,64 +113,26 @@ def tuple_object(fields=None): ...@@ -112,64 +113,26 @@ def tuple_object(fields=None):
@tvm._ffi.register_object("runtime.String") @tvm._ffi.register_object("runtime.String")
class String(Object): class String(str, PyNativeObject):
"""The string object. """TVM runtime.String object, represented as a python str.
Parameters
----------
string : str
The string used to construct a runtime String object
Returns
-------
ret : String
The created object.
"""
def __init__(self, string):
self.__init_handle_by_constructor__(_ffi_api.String, string)
def __str__(self):
return _ffi_api.GetStdString(self)
def __len__(self):
return _ffi_api.GetStringSize(self)
def __hash__(self):
return _ffi_api.StringHash(self)
def __eq__(self, other):
if isinstance(other, string_types):
return self.__str__() == other
if not isinstance(other, String):
return False
return _ffi_api.CompareString(self, other) == 0
def __ne__(self, other):
return not self.__eq__(other)
def __gt__(self, other):
return _ffi_api.CompareString(self, other) > 0
def __lt__(self, other):
return _ffi_api.CompareString(self, other) < 0
def __getitem__(self, key):
return self.__str__()[key]
def startswith(self, string):
"""Check if the runtime string starts with a given string
Parameters Parameters
---------- ----------
string : str content : str
The provided string The content string used to construct the object.
Returns
-------
ret : boolean
Return true if the runtime string starts with the given string,
otherwise, false.
""" """
return self.__str__().startswith(string) __slots__ = ["__tvm_object__"]
def __new__(cls, content):
"""Construct from string content."""
val = str.__new__(cls, content)
val.__init_tvm_object_by_constructor__(_ffi_api.String, content)
return val
# pylint: disable=no-self-argument
def __from_tvm_object__(cls, obj):
"""Construct from a given tvm object."""
content = _ffi_api.GetFFIString(obj)
val = str.__new__(cls, content)
val.__tvm_object__ = obj
return val
...@@ -27,11 +27,11 @@ try: ...@@ -27,11 +27,11 @@ try:
if _FFI_MODE == "ctypes": if _FFI_MODE == "ctypes":
raise ImportError() raise ImportError()
from tvm._ffi._cy3.core import _set_class_object, _set_class_object_generic from tvm._ffi._cy3.core import _set_class_object, _set_class_object_generic
from tvm._ffi._cy3.core import ObjectBase from tvm._ffi._cy3.core import ObjectBase, PyNativeObject
except (RuntimeError, ImportError): except (RuntimeError, ImportError):
# pylint: disable=wrong-import-position,unused-import # pylint: disable=wrong-import-position,unused-import
from tvm._ffi._ctypes.packed_func import _set_class_object, _set_class_object_generic from tvm._ffi._ctypes.packed_func import _set_class_object, _set_class_object_generic
from tvm._ffi._ctypes.object import ObjectBase from tvm._ffi._ctypes.object import ObjectBase, PyNativeObject
def _new_object(cls): def _new_object(cls):
...@@ -41,6 +41,7 @@ def _new_object(cls): ...@@ -41,6 +41,7 @@ def _new_object(cls):
class Object(ObjectBase): class Object(ObjectBase):
"""Base class for all tvm's runtime objects.""" """Base class for all tvm's runtime objects."""
__slots__ = []
def __repr__(self): def __repr__(self):
return _ffi_node_api.AsRepr(self) return _ffi_node_api.AsRepr(self)
...@@ -78,13 +79,10 @@ class Object(ObjectBase): ...@@ -78,13 +79,10 @@ class Object(ObjectBase):
def __setstate__(self, state): def __setstate__(self, state):
# pylint: disable=assigning-non-slot, assignment-from-no-return # pylint: disable=assigning-non-slot, assignment-from-no-return
handle = state['handle'] handle = state['handle']
if handle is not None:
json_str = handle
other = _ffi_node_api.LoadJSON(json_str)
self.handle = other.handle
other.handle = None
else:
self.handle = None self.handle = None
if handle is not None:
self.__init_handle_by_constructor__(
_ffi_node_api.LoadJSON, handle)
def _move(self): def _move(self):
"""Create an RValue reference to the object and mark the object as moved. """Create an RValue reference to the object and mark the object as moved.
......
...@@ -21,7 +21,7 @@ from tvm._ffi.base import string_types ...@@ -21,7 +21,7 @@ from tvm._ffi.base import string_types
from tvm._ffi.runtime_ctypes import ObjectRValueRef from tvm._ffi.runtime_ctypes import ObjectRValueRef
from . import _ffi_node_api, _ffi_api from . import _ffi_node_api, _ffi_api
from .object import ObjectBase, _set_class_object_generic from .object import ObjectBase, PyNativeObject, _set_class_object_generic
from .ndarray import NDArrayBase from .ndarray import NDArrayBase
from .packed_func import PackedFuncBase, convert_to_tvm_func from .packed_func import PackedFuncBase, convert_to_tvm_func
from .module import Module from .module import Module
...@@ -34,7 +34,7 @@ class ObjectGeneric(object): ...@@ -34,7 +34,7 @@ class ObjectGeneric(object):
raise NotImplementedError() raise NotImplementedError()
ObjectTypes = (ObjectBase, NDArrayBase, Module, ObjectRValueRef) ObjectTypes = (ObjectBase, NDArrayBase, Module, ObjectRValueRef, PyNativeObject)
def convert_to_object(value): def convert_to_object(value):
......
...@@ -19,7 +19,7 @@ ...@@ -19,7 +19,7 @@
/*! /*!
* \file src/runtime/container.cc * \file src/runtime/container.cc
* \brief Implementations of common plain old data (POD) containers. * \brief Implementations of common containers.
*/ */
#include <tvm/runtime/container.h> #include <tvm/runtime/container.h>
#include <tvm/runtime/memory.h> #include <tvm/runtime/memory.h>
...@@ -81,26 +81,11 @@ TVM_REGISTER_GLOBAL("runtime.String") ...@@ -81,26 +81,11 @@ TVM_REGISTER_GLOBAL("runtime.String")
return String(std::move(str)); return String(std::move(str));
}); });
TVM_REGISTER_GLOBAL("runtime.GetStringSize") TVM_REGISTER_GLOBAL("runtime.GetFFIString")
.set_body_typed([](String str) {
return static_cast<int64_t>(str.size());
});
TVM_REGISTER_GLOBAL("runtime.GetStdString")
.set_body_typed([](String str) { .set_body_typed([](String str) {
return std::string(str); return std::string(str);
}); });
TVM_REGISTER_GLOBAL("runtime.CompareString")
.set_body_typed([](String lhs, String rhs) {
return lhs.compare(rhs);
});
TVM_REGISTER_GLOBAL("runtime.StringHash")
.set_body_typed([](String str) {
return static_cast<int64_t>(std::hash<String>()(str));
});
TVM_REGISTER_OBJECT_TYPE(ADTObj); TVM_REGISTER_OBJECT_TYPE(ADTObj);
TVM_REGISTER_OBJECT_TYPE(StringObj); TVM_REGISTER_OBJECT_TYPE(StringObj);
TVM_REGISTER_OBJECT_TYPE(ClosureObj); TVM_REGISTER_OBJECT_TYPE(ClosureObj);
......
...@@ -58,6 +58,11 @@ TVM_REGISTER_GLOBAL("testing.nop") ...@@ -58,6 +58,11 @@ TVM_REGISTER_GLOBAL("testing.nop")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
}); });
TVM_REGISTER_GLOBAL("testing.echo")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = args[0];
});
TVM_REGISTER_GLOBAL("testing.test_wrap_callback") TVM_REGISTER_GLOBAL("testing.test_wrap_callback")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
PackedFunc pf = args[0]; PackedFunc pf = args[0];
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
import numpy as np import numpy as np
import tvm import tvm
import pickle
from tvm import te from tvm import te
from tvm import nd, relay from tvm import nd, relay
from tvm.runtime import container as _container from tvm.runtime import container as _container
...@@ -56,6 +57,29 @@ def test_tuple_object(): ...@@ -56,6 +57,29 @@ def test_tuple_object():
tvm.testing.assert_allclose(out.asnumpy(), np.array(11)) tvm.testing.assert_allclose(out.asnumpy(), np.array(11))
def test_string():
s = tvm.runtime.String("xyz")
assert isinstance(s, tvm.runtime.String)
assert isinstance(s, str)
assert s.startswith("xy")
assert s + "1" == "xyz1"
y = tvm.testing.echo(s)
assert isinstance(y, tvm.runtime.String)
assert s.__tvm_object__.same_as(y.__tvm_object__)
assert s == y
x = tvm.ir.load_json(tvm.ir.save_json(y))
assert isinstance(x, tvm.runtime.String)
assert x == y
# test pickle
z = pickle.loads(pickle.dumps(s))
assert isinstance(z, tvm.runtime.String)
assert s == z
if __name__ == "__main__": if __name__ == "__main__":
test_string()
test_adt_constructor() test_adt_constructor()
test_tuple_object() test_tuple_object()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment