Unverified Commit e68450da by Tianqi Chen Committed by GitHub

[PY][FFI] Introduce PyNativeObject, enable runtime.String to subclass str (#5426)

To make runtime.String to work as naturally as possible in the python side,
we make it sub-class the python's str object. Note that however, we cannot
sub-class Object at the same time due to python's type layout constraint.

We introduce a PyNativeObject class to handle this kind of object sub-classing
and updated the FFI to handle PyNativeObject classes.
parent 6c77195e
......@@ -108,7 +108,7 @@ if(MSVC)
endif()
else(MSVC)
if ("${CMAKE_BUILD_TYPE}" STREQUAL "Debug")
message("Build in Debug mode")
message(STATUS "Build in Debug mode")
set(CMAKE_C_FLAGS "-O0 -g -Wall -fPIC ${CMAKE_C_FLAGS}")
set(CMAKE_CXX_FLAGS "-O0 -g -Wall -fPIC ${CMAKE_CXX_FLAGS}")
set(CMAKE_CUDA_FLAGS "-O0 -g -Xcompiler=-Wall -Xcompiler=-fPIC ${CMAKE_CUDA_FLAGS}")
......
......@@ -50,6 +50,10 @@ def _return_object(x):
tindex = ctypes.c_uint()
check_call(_LIB.TVMObjectGetTypeIndex(handle, ctypes.byref(tindex)))
cls = OBJECT_TYPE.get(tindex.value, _CLASS_OBJECT)
if issubclass(cls, PyNativeObject):
obj = _CLASS_OBJECT.__new__(_CLASS_OBJECT)
obj.handle = handle
return cls.__from_tvm_object__(cls, obj)
# Avoid calling __init__ of cls, instead directly call __new__
# This allows child class to implement their own __init__
obj = cls.__new__(cls)
......@@ -64,6 +68,33 @@ C_TO_PY_ARG_SWITCH[TypeCode.OBJECT_RVALUE_REF_ARG] = _wrap_arg_func(
_return_object, TypeCode.OBJECT_RVALUE_REF_ARG)
class PyNativeObject:
"""Base class of all TVM objects that also subclass python's builtin types."""
__slots__ = []
def __init_tvm_object_by_constructor__(self, fconstructor, *args):
"""Initialize the internal tvm_object by calling constructor function.
Parameters
----------
fconstructor : Function
Constructor function.
args: list of objects
The arguments to the constructor
Note
----
We have a special calling convention to call constructor functions.
So the return object is directly set into the object
"""
# pylint: disable=assigning-non-slot
obj = _CLASS_OBJECT.__new__(_CLASS_OBJECT)
obj.__init_handle_by_constructor__(fconstructor, *args)
self.__tvm_object__ = obj
class ObjectBase(object):
"""Base object for all object types"""
__slots__ = ["handle"]
......
......@@ -29,7 +29,7 @@ from .ndarray import NDArrayBase, _make_array
from .types import TVMValue, TypeCode
from .types import TVMPackedCFunc, TVMCFuncFinalizer
from .types import RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func, _ctx_to_int64
from .object import ObjectBase, _set_class_object
from .object import ObjectBase, PyNativeObject, _set_class_object
from . import object as _object
PackedFuncHandle = ctypes.c_void_p
......@@ -123,6 +123,9 @@ def _make_tvm_args(args, temp_args):
values[i].v_handle = ctypes.cast(arg.handle, ctypes.c_void_p)
type_codes[i] = (TypeCode.NDARRAY_HANDLE
if not arg.is_view else TypeCode.DLTENSOR_HANDLE)
elif isinstance(arg, PyNativeObject):
values[i].v_handle = arg.__tvm_object__.handle
type_codes[i] = TypeCode.OBJECT_HANDLE
elif isinstance(arg, _nd._TVM_COMPATS):
values[i].v_handle = ctypes.c_void_p(arg._tvm_handle)
type_codes[i] = arg.__class__._tvm_tcode
......
......@@ -39,18 +39,49 @@ cdef inline object make_ret_object(void* chandle):
object_type = OBJECT_TYPE
handle = ctypes_handle(chandle)
CALL(TVMObjectGetTypeIndex(chandle, &tindex))
if tindex < len(OBJECT_TYPE):
cls = OBJECT_TYPE[tindex]
if cls is not None:
if issubclass(cls, PyNativeObject):
obj = _CLASS_OBJECT.__new__(_CLASS_OBJECT)
(<ObjectBase>obj).chandle = chandle
return cls.__from_tvm_object__(cls, obj)
obj = cls.__new__(cls)
else:
obj = _CLASS_OBJECT.__new__(_CLASS_OBJECT)
else:
obj = _CLASS_OBJECT.__new__(_CLASS_OBJECT)
(<ObjectBase>obj).chandle = chandle
return obj
class PyNativeObject:
"""Base class of all TVM objects that also subclass python's builtin types."""
__slots__ = []
def __init_tvm_object_by_constructor__(self, fconstructor, *args):
"""Initialize the internal tvm_object by calling constructor function.
Parameters
----------
fconstructor : Function
Constructor function.
args: list of objects
The arguments to the constructor
Note
----
We have a special calling convention to call constructor functions.
So the return object is directly set into the object
"""
obj = _CLASS_OBJECT.__new__(_CLASS_OBJECT)
obj.__init_handle_by_constructor__(fconstructor, *args)
self.__tvm_object__ = obj
cdef class ObjectBase:
cdef void* chandle
......
......@@ -109,6 +109,9 @@ cdef inline int make_arg(object arg,
value[0].v_handle = (<NDArrayBase>arg).chandle
tcode[0] = (kTVMNDArrayHandle if
not (<NDArrayBase>arg).c_is_view else kTVMDLTensorHandle)
elif isinstance(arg, PyNativeObject):
value[0].v_handle = (<ObjectBase>(arg.__tvm_object__)).chandle
tcode[0] = kTVMObjectHandle
elif isinstance(arg, _TVM_COMPATS):
ptr = arg._tvm_handle
value[0].v_handle = (<void*>ptr)
......
......@@ -16,9 +16,10 @@
# under the License.
"""Runtime container structures."""
import tvm._ffi
from tvm._ffi.base import string_types
from tvm.runtime import Object, ObjectTypes
from tvm.runtime import _ffi_api
from .object import Object, PyNativeObject
from .object_generic import ObjectTypes
from . import _ffi_api
def getitem_helper(obj, elem_getter, length, idx):
"""Helper function to implement a pythonic getitem function.
......@@ -112,64 +113,26 @@ def tuple_object(fields=None):
@tvm._ffi.register_object("runtime.String")
class String(Object):
"""The string object.
class String(str, PyNativeObject):
"""TVM runtime.String object, represented as a python str.
Parameters
----------
string : str
The string used to construct a runtime String object
Returns
-------
ret : String
The created object.
content : str
The content string used to construct the object.
"""
def __init__(self, string):
self.__init_handle_by_constructor__(_ffi_api.String, string)
def __str__(self):
return _ffi_api.GetStdString(self)
def __len__(self):
return _ffi_api.GetStringSize(self)
def __hash__(self):
return _ffi_api.StringHash(self)
def __eq__(self, other):
if isinstance(other, string_types):
return self.__str__() == other
if not isinstance(other, String):
return False
return _ffi_api.CompareString(self, other) == 0
def __ne__(self, other):
return not self.__eq__(other)
def __gt__(self, other):
return _ffi_api.CompareString(self, other) > 0
def __lt__(self, other):
return _ffi_api.CompareString(self, other) < 0
def __getitem__(self, key):
return self.__str__()[key]
def startswith(self, string):
"""Check if the runtime string starts with a given string
Parameters
----------
string : str
The provided string
Returns
-------
ret : boolean
Return true if the runtime string starts with the given string,
otherwise, false.
"""
return self.__str__().startswith(string)
__slots__ = ["__tvm_object__"]
def __new__(cls, content):
"""Construct from string content."""
val = str.__new__(cls, content)
val.__init_tvm_object_by_constructor__(_ffi_api.String, content)
return val
# pylint: disable=no-self-argument
def __from_tvm_object__(cls, obj):
"""Construct from a given tvm object."""
content = _ffi_api.GetFFIString(obj)
val = str.__new__(cls, content)
val.__tvm_object__ = obj
return val
......@@ -27,11 +27,11 @@ try:
if _FFI_MODE == "ctypes":
raise ImportError()
from tvm._ffi._cy3.core import _set_class_object, _set_class_object_generic
from tvm._ffi._cy3.core import ObjectBase
from tvm._ffi._cy3.core import ObjectBase, PyNativeObject
except (RuntimeError, ImportError):
# pylint: disable=wrong-import-position,unused-import
from tvm._ffi._ctypes.packed_func import _set_class_object, _set_class_object_generic
from tvm._ffi._ctypes.object import ObjectBase
from tvm._ffi._ctypes.object import ObjectBase, PyNativeObject
def _new_object(cls):
......@@ -41,6 +41,7 @@ def _new_object(cls):
class Object(ObjectBase):
"""Base class for all tvm's runtime objects."""
__slots__ = []
def __repr__(self):
return _ffi_node_api.AsRepr(self)
......@@ -78,13 +79,10 @@ class Object(ObjectBase):
def __setstate__(self, state):
# pylint: disable=assigning-non-slot, assignment-from-no-return
handle = state['handle']
self.handle = None
if handle is not None:
json_str = handle
other = _ffi_node_api.LoadJSON(json_str)
self.handle = other.handle
other.handle = None
else:
self.handle = None
self.__init_handle_by_constructor__(
_ffi_node_api.LoadJSON, handle)
def _move(self):
"""Create an RValue reference to the object and mark the object as moved.
......
......@@ -21,7 +21,7 @@ from tvm._ffi.base import string_types
from tvm._ffi.runtime_ctypes import ObjectRValueRef
from . import _ffi_node_api, _ffi_api
from .object import ObjectBase, _set_class_object_generic
from .object import ObjectBase, PyNativeObject, _set_class_object_generic
from .ndarray import NDArrayBase
from .packed_func import PackedFuncBase, convert_to_tvm_func
from .module import Module
......@@ -34,7 +34,7 @@ class ObjectGeneric(object):
raise NotImplementedError()
ObjectTypes = (ObjectBase, NDArrayBase, Module, ObjectRValueRef)
ObjectTypes = (ObjectBase, NDArrayBase, Module, ObjectRValueRef, PyNativeObject)
def convert_to_object(value):
......
......@@ -19,7 +19,7 @@
/*!
* \file src/runtime/container.cc
* \brief Implementations of common plain old data (POD) containers.
* \brief Implementations of common containers.
*/
#include <tvm/runtime/container.h>
#include <tvm/runtime/memory.h>
......@@ -81,26 +81,11 @@ TVM_REGISTER_GLOBAL("runtime.String")
return String(std::move(str));
});
TVM_REGISTER_GLOBAL("runtime.GetStringSize")
.set_body_typed([](String str) {
return static_cast<int64_t>(str.size());
});
TVM_REGISTER_GLOBAL("runtime.GetStdString")
TVM_REGISTER_GLOBAL("runtime.GetFFIString")
.set_body_typed([](String str) {
return std::string(str);
});
TVM_REGISTER_GLOBAL("runtime.CompareString")
.set_body_typed([](String lhs, String rhs) {
return lhs.compare(rhs);
});
TVM_REGISTER_GLOBAL("runtime.StringHash")
.set_body_typed([](String str) {
return static_cast<int64_t>(std::hash<String>()(str));
});
TVM_REGISTER_OBJECT_TYPE(ADTObj);
TVM_REGISTER_OBJECT_TYPE(StringObj);
TVM_REGISTER_OBJECT_TYPE(ClosureObj);
......
......@@ -58,6 +58,11 @@ TVM_REGISTER_GLOBAL("testing.nop")
.set_body([](TVMArgs args, TVMRetValue *ret) {
});
TVM_REGISTER_GLOBAL("testing.echo")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = args[0];
});
TVM_REGISTER_GLOBAL("testing.test_wrap_callback")
.set_body([](TVMArgs args, TVMRetValue *ret) {
PackedFunc pf = args[0];
......
......@@ -17,6 +17,7 @@
import numpy as np
import tvm
import pickle
from tvm import te
from tvm import nd, relay
from tvm.runtime import container as _container
......@@ -56,6 +57,29 @@ def test_tuple_object():
tvm.testing.assert_allclose(out.asnumpy(), np.array(11))
def test_string():
s = tvm.runtime.String("xyz")
assert isinstance(s, tvm.runtime.String)
assert isinstance(s, str)
assert s.startswith("xy")
assert s + "1" == "xyz1"
y = tvm.testing.echo(s)
assert isinstance(y, tvm.runtime.String)
assert s.__tvm_object__.same_as(y.__tvm_object__)
assert s == y
x = tvm.ir.load_json(tvm.ir.save_json(y))
assert isinstance(x, tvm.runtime.String)
assert x == y
# test pickle
z = pickle.loads(pickle.dumps(s))
assert isinstance(z, tvm.runtime.String)
assert s == z
if __name__ == "__main__":
test_string()
test_adt_constructor()
test_tuple_object()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment