Unverified Commit e91cc5ab by Tianqi Chen Committed by GitHub

[RUNTIME] Remove Extension VTable in favor of Unified Object system. (#4578)

Before the unified object protocol, we support pass
additional extension objects around by declaring a type as an extension type.
The old extension mechanism requires the types to register their
constructor and deleter to a VTable and does not enjoy the benefit of the
self-contained deletion property of the new Object system.

This PR upgrades the extension example to make use of the new object system
and removed the old Extension VTable.

Note that the register_extension funtion in the python side continues to work
when the passed argument does not require explicit container copy/deletion,
which covers the current usecases of the extension mechanism.
parent ff65698f
...@@ -20,8 +20,7 @@ TVM_ROOT=$(shell cd ../..; pwd) ...@@ -20,8 +20,7 @@ TVM_ROOT=$(shell cd ../..; pwd)
PKG_CFLAGS = -std=c++11 -O2 -fPIC\ PKG_CFLAGS = -std=c++11 -O2 -fPIC\
-I${TVM_ROOT}/include\ -I${TVM_ROOT}/include\
-I${TVM_ROOT}/3rdparty/dmlc-core/include\ -I${TVM_ROOT}/3rdparty/dmlc-core/include\
-I${TVM_ROOT}/3rdparty/dlpack/include\ -I${TVM_ROOT}/3rdparty/dlpack/include
-I${TVM_ROOT}/3rdparty/HalideIR/src
PKG_LDFLAGS =-L${TVM_ROOT}/build PKG_LDFLAGS =-L${TVM_ROOT}/build
UNAME_S := $(shell uname -s) UNAME_S := $(shell uname -s)
......
...@@ -38,18 +38,9 @@ sym_add = tvm.get_global_func("tvm_ext.sym_add") ...@@ -38,18 +38,9 @@ sym_add = tvm.get_global_func("tvm_ext.sym_add")
ivec_create = tvm.get_global_func("tvm_ext.ivec_create") ivec_create = tvm.get_global_func("tvm_ext.ivec_create")
ivec_get = tvm.get_global_func("tvm_ext.ivec_get") ivec_get = tvm.get_global_func("tvm_ext.ivec_get")
class IntVec(object): @tvm.register_object("tvm_ext.IntVector")
class IntVec(tvm.Object):
"""Example for using extension class in c++ """ """Example for using extension class in c++ """
_tvm_tcode = 17
def __init__(self, handle):
self.handle = handle
def __del__(self):
# You can also call your own customized
# deleter if you can free it via your own FFI.
tvm.nd.free_extension_handle(self.handle, self.__class__._tvm_tcode)
@property @property
def _tvm_handle(self): def _tvm_handle(self):
return self.handle.value return self.handle.value
...@@ -57,9 +48,6 @@ class IntVec(object): ...@@ -57,9 +48,6 @@ class IntVec(object):
def __getitem__(self, idx): def __getitem__(self, idx):
return ivec_get(self, idx) return ivec_get(self, idx)
# Register IntVec extension on python side.
tvm.register_extension(IntVec, IntVec)
nd_create = tvm.get_global_func("tvm_ext.nd_create") nd_create = tvm.get_global_func("tvm_ext.nd_create")
nd_add_two = tvm.get_global_func("tvm_ext.nd_add_two") nd_add_two = tvm.get_global_func("tvm_ext.nd_add_two")
......
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the * to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at * with the License. You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...@@ -30,17 +30,12 @@ ...@@ -30,17 +30,12 @@
#include <tvm/runtime/device_api.h> #include <tvm/runtime/device_api.h>
namespace tvm_ext { namespace tvm_ext {
using IntVector = std::vector<int>;
class NDSubClass; class NDSubClass;
} // namespace tvm_ext } // namespace tvm_ext
namespace tvm { namespace tvm {
namespace runtime { namespace runtime {
template<> template<>
struct extension_type_info<tvm_ext::IntVector> {
static const int code = 17;
};
template<>
struct array_type_info<tvm_ext::NDSubClass> { struct array_type_info<tvm_ext::NDSubClass> {
static const int code = 1; static const int code = 1;
}; };
...@@ -104,24 +99,47 @@ class NDSubClass : public tvm::runtime::NDArray { ...@@ -104,24 +99,47 @@ class NDSubClass : public tvm::runtime::NDArray {
return self->addtional_info_; return self->addtional_info_;
} }
}; };
/*!
* \brief Introduce additional extension data structures
* by sub-classing TVM's object system.
*/
class IntVectorObj : public Object {
public:
std::vector<int> vec;
static constexpr const char* _type_key = "tvm_ext.IntVector";
TVM_DECLARE_FINAL_OBJECT_INFO(IntVectorObj, Object);
};
/*!
* \brief Int vector reference class.
*/
class IntVector : public ObjectRef {
public:
TVM_DEFINE_OBJECT_REF_METHODS(IntVector, ObjectRef, IntVectorObj);
};
TVM_REGISTER_OBJECT_TYPE(IntVectorObj);
} // namespace tvm_ext } // namespace tvm_ext
namespace tvm_ext { namespace tvm_ext {
TVM_REGISTER_EXT_TYPE(IntVector);
TVM_REGISTER_GLOBAL("tvm_ext.ivec_create") TVM_REGISTER_GLOBAL("tvm_ext.ivec_create")
.set_body([](TVMArgs args, TVMRetValue *rv) { .set_body([](TVMArgs args, TVMRetValue *rv) {
IntVector vec; auto n = tvm::runtime::make_object<IntVectorObj>();
for (int i = 0; i < args.size(); ++i) { for (int i = 0; i < args.size(); ++i) {
vec.push_back(args[i].operator int()); n->vec.push_back(args[i].operator int());
} }
*rv = vec; *rv = IntVector(n);
}); });
TVM_REGISTER_GLOBAL("tvm_ext.ivec_get") TVM_REGISTER_GLOBAL("tvm_ext.ivec_get")
.set_body([](TVMArgs args, TVMRetValue *rv) { .set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = args[0].AsExtension<IntVector>()[args[1].operator int()]; IntVector p = args[0];
*rv = p->vec[args[1].operator int()];
}); });
......
...@@ -235,14 +235,6 @@ TVM_DLL int TVMModGetFunction(TVMModuleHandle mod, ...@@ -235,14 +235,6 @@ TVM_DLL int TVMModGetFunction(TVMModuleHandle mod,
TVMFunctionHandle *out); TVMFunctionHandle *out);
/*! /*!
* \brief Free front-end extension type resource.
* \param handle The extension handle.
* \param type_code The type of of the extension type.
* \return 0 when success, -1 when failure happens
*/
TVM_DLL int TVMExtTypeFree(void* handle, int type_code);
/*!
* \brief Free the Module * \brief Free the Module
* \param mod The module to be freed. * \param mod The module to be freed.
* *
......
...@@ -387,7 +387,6 @@ inline std::string TVMType2String(TVMType t); ...@@ -387,7 +387,6 @@ inline std::string TVMType2String(TVMType t);
#define TVM_CHECK_TYPE_CODE(CODE, T) \ #define TVM_CHECK_TYPE_CODE(CODE, T) \
CHECK_EQ(CODE, T) << " expected " \ CHECK_EQ(CODE, T) << " expected " \
<< TypeCode2Str(T) << " but get " << TypeCode2Str(CODE) \ << TypeCode2Str(T) << " but get " << TypeCode2Str(CODE) \
/*! /*!
* \brief Type traits to mark if a class is tvm extension type. * \brief Type traits to mark if a class is tvm extension type.
* *
...@@ -405,34 +404,6 @@ struct extension_type_info { ...@@ -405,34 +404,6 @@ struct extension_type_info {
}; };
/*! /*!
* \brief Runtime function table about extension type.
*/
class ExtTypeVTable {
public:
/*! \brief function to be called to delete a handle */
void (*destroy)(void* handle);
/*! \brief function to be called when clone a handle */
void* (*clone)(void* handle);
/*!
* \brief Register type
* \tparam T The type to be register.
* \return The registered vtable.
*/
template <typename T>
static inline ExtTypeVTable* Register_();
/*!
* \brief Get a vtable based on type code.
* \param type_code The type code
* \return The registered vtable.
*/
TVM_DLL static ExtTypeVTable* Get(int type_code);
private:
// Internal registration function.
TVM_DLL static ExtTypeVTable* RegisterInternal(int type_code, const ExtTypeVTable& vt);
};
/*!
* \brief Internal base class to * \brief Internal base class to
* handle conversion to POD values. * handle conversion to POD values.
*/ */
...@@ -518,11 +489,6 @@ class TVMPODValue_ { ...@@ -518,11 +489,6 @@ class TVMPODValue_ {
CHECK_EQ(container->array_type_code_, array_type_info<TNDArray>::code); CHECK_EQ(container->array_type_code_, array_type_info<TNDArray>::code);
return TNDArray(container); return TNDArray(container);
} }
template<typename TExtension>
const TExtension& AsExtension() const {
CHECK_LT(type_code_, kExtEnd);
return static_cast<TExtension*>(value_.v_handle)[0];
}
template<typename TObjectRef, template<typename TObjectRef,
typename = typename std::enable_if< typename = typename std::enable_if<
std::is_class<TObjectRef>::value>::type> std::is_class<TObjectRef>::value>::type>
...@@ -867,20 +833,8 @@ class TVMRetValue : public TVMPODValue_ { ...@@ -867,20 +833,8 @@ class TVMRetValue : public TVMPODValue_ {
break; break;
} }
default: { default: {
if (other.type_code() < kExtBegin) { SwitchToPOD(other.type_code());
SwitchToPOD(other.type_code()); value_ = other.value_;
value_ = other.value_;
} else {
#if TVM_RUNTIME_HEADER_ONLY
LOG(FATAL) << "Header only mode do not support ext type";
#else
this->Clear();
type_code_ = other.type_code();
value_.v_handle =
(*(ExtTypeVTable::Get(other.type_code())->clone))(
other.value().v_handle);
#endif
}
break; break;
} }
} }
...@@ -931,13 +885,6 @@ class TVMRetValue : public TVMPODValue_ { ...@@ -931,13 +885,6 @@ class TVMRetValue : public TVMPODValue_ {
break; break;
} }
} }
if (type_code_ > kExtBegin) {
#if TVM_RUNTIME_HEADER_ONLY
LOG(FATAL) << "Header only mode do not support ext type";
#else
(*(ExtTypeVTable::Get(type_code_)->destroy))(value_.v_handle);
#endif
}
type_code_ = kNull; type_code_ = kNull;
} }
}; };
...@@ -1317,23 +1264,16 @@ inline R TypedPackedFunc<R(Args...)>::operator()(Args... args) const { ...@@ -1317,23 +1264,16 @@ inline R TypedPackedFunc<R(Args...)>::operator()(Args... args) const {
// extension and node type handling // extension and node type handling
namespace detail { namespace detail {
template<typename T, typename TSrc, bool is_ext, bool is_nd> template<typename T, typename TSrc, bool is_nd>
struct TVMValueCast { struct TVMValueCast {
static T Apply(const TSrc* self) { static T Apply(const TSrc* self) {
static_assert(!is_ext && !is_nd, "The default case accepts only non-extensions"); static_assert(!is_nd, "The default case accepts only non-extensions");
return self->template AsObjectRef<T>(); return self->template AsObjectRef<T>();
} }
}; };
template<typename T, typename TSrc> template<typename T, typename TSrc>
struct TVMValueCast<T, TSrc, true, false> { struct TVMValueCast<T, TSrc, true> {
static T Apply(const TSrc* self) {
return self->template AsExtension<T>();
}
};
template<typename T, typename TSrc>
struct TVMValueCast<T, TSrc, false, true> {
static T Apply(const TSrc* self) { static T Apply(const TSrc* self) {
return self->template AsNDArray<T>(); return self->template AsNDArray<T>();
} }
...@@ -1345,7 +1285,6 @@ template<typename T, typename> ...@@ -1345,7 +1285,6 @@ template<typename T, typename>
inline TVMArgValue::operator T() const { inline TVMArgValue::operator T() const {
return detail:: return detail::
TVMValueCast<T, TVMArgValue, TVMValueCast<T, TVMArgValue,
(extension_type_info<T>::code != 0),
(array_type_info<T>::code > 0)> (array_type_info<T>::code > 0)>
::Apply(this); ::Apply(this);
} }
...@@ -1354,19 +1293,10 @@ template<typename T, typename> ...@@ -1354,19 +1293,10 @@ template<typename T, typename>
inline TVMRetValue::operator T() const { inline TVMRetValue::operator T() const {
return detail:: return detail::
TVMValueCast<T, TVMRetValue, TVMValueCast<T, TVMRetValue,
(extension_type_info<T>::code != 0),
(array_type_info<T>::code > 0)> (array_type_info<T>::code > 0)>
::Apply(this); ::Apply(this);
} }
template<typename T, typename>
inline void TVMArgsSetter::operator()(size_t i, const T& value) const {
static_assert(extension_type_info<T>::code != 0,
"Need to have extesion code");
type_codes_[i] = extension_type_info<T>::code;
values_[i].v_handle = const_cast<T*>(&value);
}
// PackedFunc support // PackedFunc support
inline TVMRetValue& TVMRetValue::operator=(const DataType& t) { inline TVMRetValue& TVMRetValue::operator=(const DataType& t) {
return this->operator=(t.operator DLDataType()); return this->operator=(t.operator DLDataType());
...@@ -1385,28 +1315,6 @@ inline void TVMArgsSetter::operator()( ...@@ -1385,28 +1315,6 @@ inline void TVMArgsSetter::operator()(
this->operator()(i, t.operator DLDataType()); this->operator()(i, t.operator DLDataType());
} }
// extension type handling
template<typename T>
struct ExtTypeInfo {
static void destroy(void* handle) {
delete static_cast<T*>(handle);
}
static void* clone(void* handle) {
return new T(*static_cast<T*>(handle));
}
};
template<typename T>
inline ExtTypeVTable* ExtTypeVTable::Register_() {
const int code = extension_type_info<T>::code;
static_assert(code != 0,
"require extension_type_info traits to be declared with non-zero code");
ExtTypeVTable vt;
vt.clone = ExtTypeInfo<T>::clone;
vt.destroy = ExtTypeInfo<T>::destroy;
return ExtTypeVTable::RegisterInternal(code, vt);
}
inline PackedFunc Module::GetFunction(const std::string& name, bool query_imports) { inline PackedFunc Module::GetFunction(const std::string& name, bool query_imports) {
return (*this)->GetFunction(name, query_imports); return (*this)->GetFunction(name, query_imports);
} }
......
...@@ -311,15 +311,6 @@ class Registry { ...@@ -311,15 +311,6 @@ class Registry {
TVM_STR_CONCAT(TVM_FUNC_REG_VAR_DEF, __COUNTER__) = \ TVM_STR_CONCAT(TVM_FUNC_REG_VAR_DEF, __COUNTER__) = \
::tvm::runtime::Registry::Register(OpName) ::tvm::runtime::Registry::Register(OpName)
/*!
* \brief Macro to register extension type.
* This must be registered in a cc file
* after the trait extension_type_info is defined.
*/
#define TVM_REGISTER_EXT_TYPE(T) \
TVM_STR_CONCAT(TVM_TYPE_REG_VAR_DEF, __COUNTER__) = \
::tvm::runtime::ExtTypeVTable::Register_<T>()
} // namespace runtime } // namespace runtime
} // namespace tvm } // namespace tvm
#endif // TVM_RUNTIME_REGISTRY_H_ #endif // TVM_RUNTIME_REGISTRY_H_
...@@ -299,20 +299,6 @@ class NDArrayBase(_NDArrayBase): ...@@ -299,20 +299,6 @@ class NDArrayBase(_NDArrayBase):
raise ValueError("Unsupported target type %s" % str(type(target))) raise ValueError("Unsupported target type %s" % str(type(target)))
def free_extension_handle(handle, type_code):
"""Free c++ extension type handle
Parameters
----------
handle : ctypes.c_void_p
The handle to the extension type.
type_code : int
The tyoe code
"""
check_call(_LIB.TVMExtTypeFree(handle, ctypes.c_int(type_code)))
def register_extension(cls, fcreate=None): def register_extension(cls, fcreate=None):
"""Register a extension class to TVM. """Register a extension class to TVM.
......
...@@ -26,7 +26,7 @@ import numpy as _np ...@@ -26,7 +26,7 @@ import numpy as _np
from ._ffi.ndarray import TVMContext, TVMType, NDArrayBase from ._ffi.ndarray import TVMContext, TVMType, NDArrayBase
from ._ffi.ndarray import context, empty, from_dlpack from ._ffi.ndarray import context, empty, from_dlpack
from ._ffi.ndarray import _set_class_ndarray from ._ffi.ndarray import _set_class_ndarray
from ._ffi.ndarray import register_extension, free_extension_handle from ._ffi.ndarray import register_extension
class NDArray(NDArrayBase): class NDArray(NDArrayBase):
"""Lightweight NDArray class of TVM runtime. """Lightweight NDArray class of TVM runtime.
......
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the * to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at * with the License. You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...@@ -40,15 +40,10 @@ struct Registry::Manager { ...@@ -40,15 +40,10 @@ struct Registry::Manager {
// and the resource can become invalid because of indeterminstic order of destruction. // and the resource can become invalid because of indeterminstic order of destruction.
// The resources will only be recycled during program exit. // The resources will only be recycled during program exit.
std::unordered_map<std::string, Registry*> fmap; std::unordered_map<std::string, Registry*> fmap;
// vtable for extension type
std::array<ExtTypeVTable, kExtEnd> ext_vtable;
// mutex // mutex
std::mutex mutex; std::mutex mutex;
Manager() { Manager() {
for (auto& x : ext_vtable) {
x.destroy = nullptr;
}
} }
static Manager* Global() { static Manager* Global() {
...@@ -109,24 +104,6 @@ std::vector<std::string> Registry::ListNames() { ...@@ -109,24 +104,6 @@ std::vector<std::string> Registry::ListNames() {
return keys; return keys;
} }
ExtTypeVTable* ExtTypeVTable::Get(int type_code) {
CHECK(type_code > kExtBegin && type_code < kExtEnd);
Registry::Manager* m = Registry::Manager::Global();
ExtTypeVTable* vt = &(m->ext_vtable[type_code]);
CHECK(vt->destroy != nullptr)
<< "Extension type not registered";
return vt;
}
ExtTypeVTable* ExtTypeVTable::RegisterInternal(
int type_code, const ExtTypeVTable& vt) {
CHECK(type_code > kExtBegin && type_code < kExtEnd);
Registry::Manager* m = Registry::Manager::Global();
std::lock_guard<std::mutex> lock(m->mutex);
ExtTypeVTable* pvt = &(m->ext_vtable[type_code]);
pvt[0] = vt;
return pvt;
}
} // namespace runtime } // namespace runtime
} // namespace tvm } // namespace tvm
...@@ -141,12 +118,6 @@ struct TVMFuncThreadLocalEntry { ...@@ -141,12 +118,6 @@ struct TVMFuncThreadLocalEntry {
/*! \brief Thread local store that can be used to hold return values. */ /*! \brief Thread local store that can be used to hold return values. */
typedef dmlc::ThreadLocalStore<TVMFuncThreadLocalEntry> TVMFuncThreadLocalStore; typedef dmlc::ThreadLocalStore<TVMFuncThreadLocalEntry> TVMFuncThreadLocalStore;
int TVMExtTypeFree(void* handle, int type_code) {
API_BEGIN();
tvm::runtime::ExtTypeVTable::Get(type_code)->destroy(handle);
API_END();
}
int TVMFuncRegisterGlobal( int TVMFuncRegisterGlobal(
const char* name, TVMFunctionHandle f, int override) { const char* name, TVMFunctionHandle f, int override) {
API_BEGIN(); API_BEGIN();
......
...@@ -178,56 +178,6 @@ TEST(TypedPackedFunc, HighOrder) { ...@@ -178,56 +178,6 @@ TEST(TypedPackedFunc, HighOrder) {
CHECK_EQ(f1(3), 4); CHECK_EQ(f1(3), 4);
} }
// new namespoace
namespace test {
// register int vector as extension type
using IntVector = std::vector<int>;
} // namespace test
namespace tvm {
namespace runtime {
template<>
struct extension_type_info<test::IntVector> {
static const int code = kExtBegin + 1;
};
} // runtime
} // tvm
// do registration, this need to be in cc file
TVM_REGISTER_EXT_TYPE(test::IntVector);
TEST(PackedFunc, ExtensionType) {
using namespace tvm;
using namespace tvm::runtime;
// note: class are copy by value.
test::IntVector vec{1, 2, 4};
auto copy_vec = PackedFunc([&](TVMArgs args, TVMRetValue* rv) {
// copy by value
const test::IntVector& v = args[0].AsExtension<test::IntVector>();
CHECK(&v == &vec);
test::IntVector v2 = args[0];
CHECK_EQ(v2.size(), 3U);
CHECK_EQ(v[2], 4);
// return copy by value
*rv = v2;
});
auto pass_vec = PackedFunc([&](TVMArgs args, TVMRetValue* rv) {
// copy by value
*rv = args[0];
});
test::IntVector vret1 = copy_vec(vec);
test::IntVector vret2 = pass_vec(copy_vec(vec));
CHECK_EQ(vret1.size(), 3U);
CHECK_EQ(vret2.size(), 3U);
CHECK_EQ(vret1[2], 4);
CHECK_EQ(vret2[2], 4);
}
int main(int argc, char ** argv) { int main(int argc, char ** argv) {
testing::InitGoogleTest(&argc, argv); testing::InitGoogleTest(&argc, argv);
testing::FLAGS_gtest_death_test_style = "threadsafe"; testing::FLAGS_gtest_death_test_style = "threadsafe";
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment