Commit f2ab736b by Tianqi Chen Committed by GitHub

[RUNTIME] Enable extension type to PackedFunc. (#447)

* [RUNTIME] Enable extension type to PackedFunc.

* More comments
parent 3130f2d5
...@@ -16,4 +16,27 @@ _LIB = load_lib() ...@@ -16,4 +16,27 @@ _LIB = load_lib()
# Expose two functions into python # Expose two functions into python
bind_add = tvm.get_global_func("tvm_ext.bind_add") bind_add = tvm.get_global_func("tvm_ext.bind_add")
sym_add = tvm.get_global_func("tvm_ext.sym_add") sym_add = tvm.get_global_func("tvm_ext.sym_add")
ivec_create = tvm.get_global_func("tvm_ext.ivec_create")
ivec_get = tvm.get_global_func("tvm_ext.ivec_get")
class IntVec(object):
"""Example for using extension class in c++ """
_tvm_tcode = 17
def __init__(self, handle):
self.handle = handle
def __del__(self):
# You can also call your own customized
# deleter if you can free it via your own FFI.
tvm.nd.free_extension_handle(self.handle, 17)
@property
def _tvm_handle(self):
return self.handle.value
def __getitem__(self, idx):
return ivec_get(self, idx)
# Register IntVec extension on python side.
tvm.register_extension(IntVec, IntVec)
...@@ -10,9 +10,41 @@ ...@@ -10,9 +10,41 @@
#include <tvm/packed_func_ext.h> #include <tvm/packed_func_ext.h>
namespace tvm_ext { namespace tvm_ext {
using IntVector = std::vector<int>;
} // namespace tvm_ext
namespace tvm {
namespace runtime {
template<>
struct extension_class_info<tvm_ext::IntVector> {
static const int code = 17;
};
} // namespace tvm
} // namespace runtime
namespace tvm_ext {
using namespace tvm; using namespace tvm;
using namespace tvm::runtime; using namespace tvm::runtime;
TVM_REGISTER_EXT_TYPE(IntVector);
TVM_REGISTER_GLOBAL("tvm_ext.ivec_create")
.set_body([](TVMArgs args, TVMRetValue *rv) {
IntVector vec;
for (int i = 0; i < args.size(); ++i) {
vec.push_back(args[i].operator int());
}
*rv = vec;
});
TVM_REGISTER_GLOBAL("tvm_ext.ivec_get")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = args[0].AsExtension<IntVector>()[args[1].operator int()];
});
TVM_REGISTER_GLOBAL("tvm_ext.bind_add") TVM_REGISTER_GLOBAL("tvm_ext.bind_add")
.set_body([](TVMArgs args_, TVMRetValue *rv_) { .set_body([](TVMArgs args_, TVMRetValue *rv_) {
PackedFunc pf = args_[0]; PackedFunc pf = args_[0];
......
...@@ -13,6 +13,19 @@ def test_sym_add(): ...@@ -13,6 +13,19 @@ def test_sym_add():
c = tvm_ext.sym_add(a, b) c = tvm_ext.sym_add(a, b)
assert c.a == a and c.b == b assert c.a == a and c.b == b
def test_ext_vec():
ivec = tvm_ext.ivec_create(1, 2, 3)
assert(isinstance(ivec, tvm_ext.IntVec))
assert ivec[0] == 1
assert ivec[1] == 2
def ivec_cb(v2):
assert(isinstance(v2, tvm_ext.IntVec))
assert v2[2] == 3
tvm.convert(ivec_cb)(ivec)
if __name__ == "__main__": if __name__ == "__main__":
test_ext_vec()
test_bind_add() test_bind_add()
test_sym_add() test_sym_add()
...@@ -89,8 +89,8 @@ inline std::string NodeTypeName() { ...@@ -89,8 +89,8 @@ inline std::string NodeTypeName() {
// extensions for tvm arg value // extensions for tvm arg value
template<typename TNodeRef, typename> template<typename TNodeRef>
inline TVMArgValue::operator TNodeRef() const { inline TNodeRef TVMArgValue::AsNodeRef() const {
static_assert( static_assert(
std::is_base_of<NodeRef, TNodeRef>::value, std::is_base_of<NodeRef, TNodeRef>::value,
"Conversion only works for NodeRef"); "Conversion only works for NodeRef");
...@@ -156,8 +156,8 @@ inline TVMRetValue& TVMRetValue::operator=(const NodeRef& other) { ...@@ -156,8 +156,8 @@ inline TVMRetValue& TVMRetValue::operator=(const NodeRef& other) {
return *this; return *this;
} }
template<typename TNodeRef, typename> template<typename TNodeRef>
inline TVMRetValue::operator TNodeRef() const { inline TNodeRef TVMRetValue::AsNodeRef() const {
static_assert( static_assert(
std::is_base_of<NodeRef, TNodeRef>::value, std::is_base_of<NodeRef, TNodeRef>::value,
"Conversion only works for NodeRef"); "Conversion only works for NodeRef");
...@@ -166,8 +166,8 @@ inline TVMRetValue::operator TNodeRef() const { ...@@ -166,8 +166,8 @@ inline TVMRetValue::operator TNodeRef() const {
return TNodeRef(*ptr<std::shared_ptr<Node> >()); return TNodeRef(*ptr<std::shared_ptr<Node> >());
} }
inline void TVMArgsSetter::operator()(size_t i, NodeRef& other) const { // NOLINT(*) inline void TVMArgsSetter::operator()(size_t i, const NodeRef& other) const { // NOLINT(*)
values_[i].v_handle = &(other.node_); values_[i].v_handle = const_cast<std::shared_ptr<Node>*>(&(other.node_));
type_codes_[i] = kNodeHandle; type_codes_[i] = kNodeHandle;
} }
......
...@@ -75,7 +75,17 @@ typedef enum { ...@@ -75,7 +75,17 @@ typedef enum {
kModuleHandle = 9U, kModuleHandle = 9U,
kFuncHandle = 10U, kFuncHandle = 10U,
kStr = 11U, kStr = 11U,
kBytes = 12U kBytes = 12U,
// Extension codes for other frameworks to integrate TVM PackedFunc.
// To make sure each framework's id do not conflict, use first and
// last sections to mark ranges.
// Open an issue at the repo if you need a section of code.
kExtBegin = 15U,
kNNVMFirst = 16U,
kNNVMLast = 20U,
// The following section of code is used for non-reserved types.
kExtReserveEnd = 64U,
kExtEnd = 128U
} TVMTypeCode; } TVMTypeCode;
/*! /*!
...@@ -192,6 +202,14 @@ TVM_DLL int TVMModGetFunction(TVMModuleHandle mod, ...@@ -192,6 +202,14 @@ TVM_DLL int TVMModGetFunction(TVMModuleHandle mod,
TVMFunctionHandle *out); TVMFunctionHandle *out);
/*! /*!
* \brief Free front-end extension type resource.
* \param handle The extension handle.
* \param type_code The type of of the extension type.
* \return 0 when success, -1 when failure happens
*/
TVM_DLL int TVMExtTypeFree(void* handle, int type_code);
/*!
* \brief Free the Module * \brief Free the Module
* \param mod The module to be freed. * \param mod The module to be freed.
* *
...@@ -200,6 +218,7 @@ TVM_DLL int TVMModGetFunction(TVMModuleHandle mod, ...@@ -200,6 +218,7 @@ TVM_DLL int TVMModGetFunction(TVMModuleHandle mod,
* Or if this module is imported by another active module. * Or if this module is imported by another active module.
* *
* The all functions remains valid until TVMFuncFree is called. * The all functions remains valid until TVMFuncFree is called.
* \return 0 when success, -1 when failure happens
*/ */
TVM_DLL int TVMModFree(TVMModuleHandle mod); TVM_DLL int TVMModFree(TVMModuleHandle mod);
......
...@@ -167,6 +167,50 @@ inline std::string TVMType2String(TVMType t); ...@@ -167,6 +167,50 @@ inline std::string TVMType2String(TVMType t);
<< TypeCode2Str(T) << " but get " << TypeCode2Str(CODE) \ << TypeCode2Str(T) << " but get " << TypeCode2Str(CODE) \
/*! /*!
* \brief Type traits to mark if a class is tvm extension type.
*
* To enable extension type in C++ must be register () ed via marco.
* TVM_REGISTER_EXT_TYPE(TypeName) after defining this with this traits.
*
* Extension class can be passed and returned via PackedFunc in all tvm runtime.
* Internally extension class is stored as T*.
*
* \tparam T the typename
*/
template<typename T>
struct extension_class_info {
static const int code = 0;
};
/*!
* \brief Runtime function table about extension type.
*/
class ExtTypeVTable {
public:
/*! \brief function to be called to delete a handle */
void (*destroy)(void* handle);
/*! \brief function to be called when clone a handle */
void* (*clone)(void* handle);
/*!
* \brief Register type
* \tparam T The type to be register.
* \return The registered vtable.
*/
template <typename T>
static inline ExtTypeVTable* Register_();
/*!
* \brief Get a vtable based on type code.
* \param type_code The type code
* \return The registered vtable.
*/
static ExtTypeVTable* Get(int type_code);
private:
// Internal registration function.
static ExtTypeVTable* RegisterInternal(int type_code, const ExtTypeVTable& vt);
};
/*!
* \brief Internal base class to * \brief Internal base class to
* handle conversion to POD values. * handle conversion to POD values.
*/ */
...@@ -209,6 +253,11 @@ class TVMPODValue_ { ...@@ -209,6 +253,11 @@ class TVMPODValue_ {
TVM_CHECK_TYPE_CODE(type_code_, kTVMContext); TVM_CHECK_TYPE_CODE(type_code_, kTVMContext);
return value_.v_ctx; return value_.v_ctx;
} }
template<typename TExtension>
const TExtension& AsExtension() const {
CHECK_LT(type_code_, kExtEnd);
return static_cast<TExtension*>(value_.v_handle)[0];
}
int type_code() const { int type_code() const {
return type_code_; return type_code_;
} }
...@@ -291,11 +340,13 @@ class TVMArgValue : public TVMPODValue_ { ...@@ -291,11 +340,13 @@ class TVMArgValue : public TVMPODValue_ {
const TVMValue& value() const { const TVMValue& value() const {
return value_; return value_;
} }
// NodeRef related extenstions: in tvm/packed_func_ext.h // Deferred extension handler.
template<typename TNodeRef, template<typename TNodeRef>
inline TNodeRef AsNodeRef() const;
template<typename T,
typename = typename std::enable_if< typename = typename std::enable_if<
std::is_class<TNodeRef>::value>::type> std::is_class<T>::value>::type>
inline operator TNodeRef() const; inline operator T() const;
template<typename TNodeRef, template<typename TNodeRef,
typename = typename std::enable_if< typename = typename std::enable_if<
std::is_class<TNodeRef>::value>::type> std::is_class<TNodeRef>::value>::type>
...@@ -433,10 +484,18 @@ class TVMRetValue : public TVMPODValue_ { ...@@ -433,10 +484,18 @@ class TVMRetValue : public TVMPODValue_ {
this->Assign(other); this->Assign(other);
return *this; return *this;
} }
TVMRetValue& operator=(TVMArgValue other) { TVMRetValue& operator=(const TVMArgValue& other) {
this->Assign(other); this->Assign(other);
return *this; return *this;
} }
template<typename T,
typename = typename std::enable_if<
extension_class_info<T>::code != 0>::type>
TVMRetValue& operator=(const T& other) {
this->SwitchToClass<T>(
extension_class_info<T>::code, other);
return *this;
}
/*! /*!
* \brief Move the value back to front-end via C API. * \brief Move the value back to front-end via C API.
* This marks the current container as null. * This marks the current container as null.
...@@ -463,12 +522,14 @@ class TVMRetValue : public TVMPODValue_ { ...@@ -463,12 +522,14 @@ class TVMRetValue : public TVMPODValue_ {
return value_; return value_;
} }
// NodeRef related extenstions: in tvm/packed_func_ext.h // NodeRef related extenstions: in tvm/packed_func_ext.h
template<typename T,
typename = typename std::enable_if<
std::is_class<T>::value>::type>
inline operator T() const;
template<typename TNodeRef>
inline TNodeRef AsNodeRef() const;
inline TVMRetValue& operator=(const NodeRef& other); inline TVMRetValue& operator=(const NodeRef& other);
inline TVMRetValue& operator=(const std::shared_ptr<Node>& other); inline TVMRetValue& operator=(const std::shared_ptr<Node>& other);
template<typename TNodeRef,
typename = typename std::enable_if<
std::is_class<TNodeRef>::value>::type>
inline operator TNodeRef() const;
// type related // type related
inline operator Halide::Type() const; inline operator Halide::Type() const;
inline TVMRetValue& operator=(const Halide::Type& other); inline TVMRetValue& operator=(const Halide::Type& other);
...@@ -499,13 +560,20 @@ class TVMRetValue : public TVMPODValue_ { ...@@ -499,13 +560,20 @@ class TVMRetValue : public TVMPODValue_ {
break; break;
} }
default: { default: {
SwitchToPOD(other.type_code()); if (other.type_code() < kExtBegin) {
value_ = other.value_; SwitchToPOD(other.type_code());
value_ = other.value_;
} else {
this->Clear();
type_code_ = other.type_code();
value_.v_handle =
(*(ExtTypeVTable::Get(other.type_code())->clone))(
other.value().v_handle);
}
break; break;
} }
} }
} }
// get the internal container. // get the internal container.
void SwitchToPOD(int type_code) { void SwitchToPOD(int type_code) {
if (type_code_ != type_code) { if (type_code_ != type_code) {
...@@ -531,6 +599,9 @@ class TVMRetValue : public TVMPODValue_ { ...@@ -531,6 +599,9 @@ class TVMRetValue : public TVMPODValue_ {
case kModuleHandle: delete ptr<Module>(); break; case kModuleHandle: delete ptr<Module>(); break;
case kNodeHandle: delete ptr<std::shared_ptr<Node> >(); break; case kNodeHandle: delete ptr<std::shared_ptr<Node> >(); break;
} }
if (type_code_ > kExtBegin) {
(*(ExtTypeVTable::Get(type_code_)->destroy))(value_.v_handle);
}
type_code_ = kNull; type_code_ = kNull;
} }
}; };
...@@ -619,24 +690,28 @@ inline PackedFunc::FType PackedFunc::body() const { ...@@ -619,24 +690,28 @@ inline PackedFunc::FType PackedFunc::body() const {
// internal namespace // internal namespace
namespace detail { namespace detail {
template<bool stop, std::size_t I, typename F, typename ...Args>
template<bool stop, std::size_t I, typename F>
struct for_each_dispatcher { struct for_each_dispatcher {
static void run(std::tuple<Args...>& args, const F& f) { // NOLINT(*) template<typename T, typename ...Args>
f(I, std::get<I>(args)); static void run(const F& f, T&& value, Args&&... args) { // NOLINT(*)
for_each_dispatcher<(I + 1) == sizeof...(Args), (I+1), F, Args...>::run(args, f); f(I, std::forward<T>(value));
for_each_dispatcher<sizeof...(Args) == 0, (I+1), F>
::run(f, std::forward<Args>(args)...);
} }
}; };
template<std::size_t I, typename F, typename ...Args> template<std::size_t I, typename F>
struct for_each_dispatcher<true, I, F, Args...> { struct for_each_dispatcher<true, I, F> {
static void run(std::tuple<Args...>& args, const F& f) {} // NOLINT(*) static void run(const F& f) {} // NOLINT(*)
}; };
} // namespace detail
template<typename F, typename ...Args> template<typename F, typename ...Args>
inline void for_each(std::tuple<Args...>& args, const F& f) { // NOLINT(*) inline void for_each(const F& f, Args&&... args) { // NOLINT(*)
detail::for_each_dispatcher<sizeof...(Args) == 0, 0, F, Args...>::run(args, f); for_each_dispatcher<sizeof...(Args) == 0, 0, F>
::run(f, std::forward<Args>(args)...);
} }
} // namespace detail
/* \brief argument settter to PackedFunc */ /* \brief argument settter to PackedFunc */
class TVMArgsSetter { class TVMArgsSetter {
...@@ -645,7 +720,8 @@ class TVMArgsSetter { ...@@ -645,7 +720,8 @@ class TVMArgsSetter {
: values_(values), type_codes_(type_codes) {} : values_(values), type_codes_(type_codes) {}
// setters for POD types // setters for POD types
template<typename T, template<typename T,
typename = typename std::enable_if<std::is_integral<T>::value>::type> typename = typename std::enable_if<
std::is_integral<T>::value>::type>
void operator()(size_t i, T value) const { void operator()(size_t i, T value) const {
values_[i].v_int64 = static_cast<int64_t>(value); values_[i].v_int64 = static_cast<int64_t>(value);
type_codes_[i] = kInt; type_codes_[i] = kInt;
...@@ -691,23 +767,23 @@ class TVMArgsSetter { ...@@ -691,23 +767,23 @@ class TVMArgsSetter {
// setters for container type // setters for container type
// They must be reference(instead of const ref) // They must be reference(instead of const ref)
// to make sure they are alive in the tuple(instead of getting converted) // to make sure they are alive in the tuple(instead of getting converted)
void operator()(size_t i, std::string& value) const { // NOLINT(*) void operator()(size_t i, const std::string& value) const { // NOLINT(*)
values_[i].v_str = value.c_str(); values_[i].v_str = value.c_str();
type_codes_[i] = kStr; type_codes_[i] = kStr;
} }
void operator()(size_t i, TVMByteArray& value) const { // NOLINT(*) void operator()(size_t i, const TVMByteArray& value) const { // NOLINT(*)
values_[i].v_handle = &value; values_[i].v_handle = const_cast<TVMByteArray*>(&value);
type_codes_[i] = kBytes; type_codes_[i] = kBytes;
} }
void operator()(size_t i, PackedFunc& value) const { // NOLINT(*) void operator()(size_t i, const PackedFunc& value) const { // NOLINT(*)
values_[i].v_handle = &value; values_[i].v_handle = const_cast<PackedFunc*>(&value);
type_codes_[i] = kFuncHandle; type_codes_[i] = kFuncHandle;
} }
void operator()(size_t i, Module& value) const { // NOLINT(*) void operator()(size_t i, const Module& value) const { // NOLINT(*)
values_[i].v_handle = &value; values_[i].v_handle = const_cast<Module*>(&value);
type_codes_[i] = kModuleHandle; type_codes_[i] = kModuleHandle;
} }
void operator()(size_t i, TVMRetValue& value) const { // NOLINT(*) void operator()(size_t i, const TVMRetValue& value) const { // NOLINT(*)
if (value.type_code() == kStr) { if (value.type_code() == kStr) {
values_[i].v_str = value.ptr<std::string>()->c_str(); values_[i].v_str = value.ptr<std::string>()->c_str();
type_codes_[i] = kStr; type_codes_[i] = kStr;
...@@ -717,8 +793,13 @@ class TVMArgsSetter { ...@@ -717,8 +793,13 @@ class TVMArgsSetter {
type_codes_[i] = value.type_code(); type_codes_[i] = value.type_code();
} }
} }
// extension
template<typename T,
typename = typename std::enable_if<
extension_class_info<T>::code != 0>::type>
inline void operator()(size_t i, const T& value) const;
// NodeRef related extenstions: in tvm/packed_func_ext.h // NodeRef related extenstions: in tvm/packed_func_ext.h
inline void operator()(size_t i, NodeRef& other) const; // NOLINT(*) inline void operator()(size_t i, const NodeRef& other) const; // NOLINT(*)
inline void operator()(size_t i, const Halide::Type& t) const; inline void operator()(size_t i, const Halide::Type& t) const;
private: private:
...@@ -728,32 +809,79 @@ class TVMArgsSetter { ...@@ -728,32 +809,79 @@ class TVMArgsSetter {
int* type_codes_; int* type_codes_;
}; };
class TVMArgsGetter {
public:
explicit TVMArgsGetter(TVMArgs args)
: args_(args) {}
template<typename T>
inline void operator()(size_t i, T& target) const { // NOLINT(*)
target = args_[i].operator T();
}
private:
TVMArgs args_;
};
template<typename... Args> template<typename... Args>
inline TVMRetValue PackedFunc::operator()(Args&& ...args) const { inline TVMRetValue PackedFunc::operator()(Args&& ...args) const {
auto targs = std::make_tuple(std::forward<Args>(args)...);
const int kNumArgs = sizeof...(Args); const int kNumArgs = sizeof...(Args);
const int kArraySize = kNumArgs > 0 ? kNumArgs : 1; const int kArraySize = kNumArgs > 0 ? kNumArgs : 1;
TVMValue values[kArraySize]; TVMValue values[kArraySize];
int type_codes[kArraySize]; int type_codes[kArraySize];
for_each(targs, TVMArgsSetter(values, type_codes)); detail::for_each(TVMArgsSetter(values, type_codes),
std::forward<Args>(args)...);
TVMRetValue rv; TVMRetValue rv;
body_(TVMArgs(values, type_codes, kNumArgs), &rv); body_(TVMArgs(values, type_codes, kNumArgs), &rv);
return rv; return rv;
} }
// extension and node type handling
namespace detail {
template<typename T, typename TSrc, bool is_ext>
struct TVMValueCast {
static T Apply(const TSrc* self) {
return self->template AsNodeRef<T>();
}
};
template<typename T, typename TSrc>
struct TVMValueCast<T, TSrc, true> {
static T Apply(const TSrc* self) {
return self->template AsExtension<T>();
}
};
} // namespace detail
template<typename T, typename>
inline TVMArgValue::operator T() const {
return detail::
TVMValueCast<T, TVMArgValue, extension_class_info<T>::code != 0>
::Apply(this);
}
template<typename T, typename>
inline TVMRetValue::operator T() const {
return detail::
TVMValueCast<T, TVMRetValue, extension_class_info<T>::code != 0>
::Apply(this);
}
template<typename T, typename>
inline void TVMArgsSetter::operator()(size_t i, const T& value) const {
static_assert(extension_class_info<T>::code != 0,
"Need to have extesion code");
type_codes_[i] = extension_class_info<T>::code;
values_[i].v_handle = const_cast<T*>(&value);
}
// extension type handling
template<typename T>
struct ExtTypeInfo {
static void destroy(void* handle) {
delete static_cast<T*>(handle);
}
static void* clone(void* handle) {
return new T(*static_cast<T*>(handle));
}
};
template<typename T>
inline ExtTypeVTable* ExtTypeVTable::Register_() {
const int code = extension_class_info<T>::code;
static_assert(code != 0,
"require extension_class_info traits to be declared with non-zero code");
ExtTypeVTable vt;
vt.clone = ExtTypeInfo<T>::clone;
vt.destroy = ExtTypeInfo<T>::destroy;
return ExtTypeVTable::RegisterInternal(code, vt);
}
} // namespace runtime } // namespace runtime
} // namespace tvm } // namespace tvm
#endif // TVM_RUNTIME_PACKED_FUNC_H_ #endif // TVM_RUNTIME_PACKED_FUNC_H_
...@@ -73,13 +73,14 @@ class Registry { ...@@ -73,13 +73,14 @@ class Registry {
*/ */
static std::vector<std::string> ListNames(); static std::vector<std::string> ListNames();
// Internal class.
struct Manager;
private: private:
/*! \brief name of the function */ /*! \brief name of the function */
std::string name_; std::string name_;
/*! \brief internal packed function */ /*! \brief internal packed function */
PackedFunc func_; PackedFunc func_;
// Internal class.
struct Manager;
friend struct Manager; friend struct Manager;
}; };
...@@ -96,6 +97,9 @@ class Registry { ...@@ -96,6 +97,9 @@ class Registry {
#define TVM_FUNC_REG_VAR_DEF \ #define TVM_FUNC_REG_VAR_DEF \
static TVM_ATTRIBUTE_UNUSED ::tvm::runtime::Registry& __mk_ ## TVM static TVM_ATTRIBUTE_UNUSED ::tvm::runtime::Registry& __mk_ ## TVM
#define TVM_TYPE_REG_VAR_DEF \
static TVM_ATTRIBUTE_UNUSED ::tvm::runtime::ExtTypeVTable* __mk_ ## TVMT
/*! /*!
* \brief Register a function globally. * \brief Register a function globally.
* \code * \code
...@@ -108,6 +112,15 @@ class Registry { ...@@ -108,6 +112,15 @@ class Registry {
TVM_STR_CONCAT(TVM_FUNC_REG_VAR_DEF, __COUNTER__) = \ TVM_STR_CONCAT(TVM_FUNC_REG_VAR_DEF, __COUNTER__) = \
::tvm::runtime::Registry::Register(OpName) ::tvm::runtime::Registry::Register(OpName)
/*!
* \brief Macro to register extension type.
* This must be registered in a cc file
* after the trait extension_class_info is defined.
*/
#define TVM_REGISTER_EXT_TYPE(T) \
TVM_STR_CONCAT(TVM_TYPE_REG_VAR_DEF, __COUNTER__) = \
::tvm::runtime::ExtTypeVTable::Register_<T>()
} // namespace runtime } // namespace runtime
} // namespace tvm } // namespace tvm
#endif // TVM_RUNTIME_REGISTRY_H_ #endif // TVM_RUNTIME_REGISTRY_H_
...@@ -97,7 +97,7 @@ def _make_tvm_args(args, temp_args): ...@@ -97,7 +97,7 @@ def _make_tvm_args(args, temp_args):
type_codes[i] = TypeCode.ARRAY_HANDLE type_codes[i] = TypeCode.ARRAY_HANDLE
elif isinstance(arg, _nd._TVM_COMPATS): elif isinstance(arg, _nd._TVM_COMPATS):
values[i].v_handle = ctypes.c_void_p(arg._tvm_handle) values[i].v_handle = ctypes.c_void_p(arg._tvm_handle)
type_codes[i] = arg._tvm_tcode type_codes[i] = arg.__class__._tvm_tcode
elif isinstance(arg, Integral): elif isinstance(arg, Integral):
values[i].v_int64 = arg values[i].v_int64 = arg
type_codes[i] = TypeCode.INT type_codes[i] = TypeCode.INT
......
...@@ -4,6 +4,7 @@ from __future__ import absolute_import ...@@ -4,6 +4,7 @@ from __future__ import absolute_import
import ctypes import ctypes
from ..base import _LIB, check_call from ..base import _LIB, check_call
from ..runtime_ctypes import TVMArrayHandle from ..runtime_ctypes import TVMArrayHandle
from .types import RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func, _return_handle
class NDArrayBase(object): class NDArrayBase(object):
"""A simple Device/CPU Array object in runtime.""" """A simple Device/CPU Array object in runtime."""
...@@ -35,9 +36,14 @@ def _make_array(handle, is_view): ...@@ -35,9 +36,14 @@ def _make_array(handle, is_view):
_TVM_COMPATS = () _TVM_COMPATS = ()
def _reg_extension(cls): def _reg_extension(cls, fcreate):
global _TVM_COMPATS global _TVM_COMPATS
_TVM_COMPATS += (cls,) _TVM_COMPATS += (cls,)
if fcreate:
fret = lambda x: fcreate(_return_handle(x))
RETURN_SWITCH[cls._tvm_tcode] = fret
C_TO_PY_ARG_SWITCH[cls._tvm_tcode] = _wrap_arg_func(fret, cls._tvm_tcode)
_CLASS_NDARRAY = None _CLASS_NDARRAY = None
......
...@@ -18,6 +18,7 @@ cdef enum TVMTypeCode: ...@@ -18,6 +18,7 @@ cdef enum TVMTypeCode:
kFuncHandle = 10 kFuncHandle = 10
kStr = 11 kStr = 11
kBytes = 12 kBytes = 12
kExtBegin = 15
cdef extern from "tvm/runtime/c_runtime_api.h": cdef extern from "tvm/runtime/c_runtime_api.h":
ctypedef struct DLDataType: ctypedef struct DLDataType:
......
...@@ -27,8 +27,10 @@ cdef int tvm_callback(TVMValue* args, ...@@ -27,8 +27,10 @@ cdef int tvm_callback(TVMValue* args,
tcode = type_codes[i] tcode = type_codes[i]
if (tcode == kNodeHandle or if (tcode == kNodeHandle or
tcode == kFuncHandle or tcode == kFuncHandle or
tcode == kModuleHandle): tcode == kModuleHandle or
tcode > kExtBegin):
CALL(TVMCbArgToReturn(&value, tcode)) CALL(TVMCbArgToReturn(&value, tcode))
if tcode != kArrayHandle: if tcode != kArrayHandle:
pyargs.append(make_ret(value, tcode)) pyargs.append(make_ret(value, tcode))
else: else:
...@@ -87,7 +89,7 @@ cdef inline void make_arg(object arg, ...@@ -87,7 +89,7 @@ cdef inline void make_arg(object arg,
elif isinstance(arg, _TVM_COMPATS): elif isinstance(arg, _TVM_COMPATS):
ptr = arg._tvm_handle ptr = arg._tvm_handle
value[0].v_handle = (<void*>ptr) value[0].v_handle = (<void*>ptr)
tcode[0] = arg._tvm_tcode tcode[0] = arg.__class__._tvm_tcode
elif isinstance(arg, (int, long)): elif isinstance(arg, (int, long)):
value[0].v_int64 = arg value[0].v_int64 = arg
tcode[0] = kInt tcode[0] = kInt
...@@ -185,8 +187,10 @@ cdef inline object make_ret(TVMValue value, int tcode): ...@@ -185,8 +187,10 @@ cdef inline object make_ret(TVMValue value, int tcode):
fobj = _CLASS_FUNCTION(None, False) fobj = _CLASS_FUNCTION(None, False)
(<FunctionBase>fobj).chandle = value.v_handle (<FunctionBase>fobj).chandle = value.v_handle
return fobj return fobj
else: elif tcode in _TVM_EXT_RET:
raise ValueError("Unhandled type code %d" % tcode) return _TVM_EXT_RET[tcode](ctypes_handle(value.v_handle))
raise ValueError("Unhandled type code %d" % tcode)
cdef inline object FuncCall3(void* chandle, tuple args, int nargs): cdef inline object FuncCall3(void* chandle, tuple args, int nargs):
......
...@@ -43,9 +43,14 @@ cdef c_make_array(void* chandle, is_view): ...@@ -43,9 +43,14 @@ cdef c_make_array(void* chandle, is_view):
cdef _TVM_COMPATS = () cdef _TVM_COMPATS = ()
def _reg_extension(cls): cdef _TVM_EXT_RET = {}
def _reg_extension(cls, fcreate):
global _TVM_COMPATS global _TVM_COMPATS
_TVM_COMPATS += (cls,) _TVM_COMPATS += (cls,)
if fcreate:
_TVM_EXT_RET[cls._tvm_tcode] = fcreate
def _make_array(handle, is_view): def _make_array(handle, is_view):
cdef unsigned long long ptr cdef unsigned long long ptr
......
...@@ -6,7 +6,8 @@ import sys ...@@ -6,7 +6,8 @@ import sys
import ctypes import ctypes
import numpy as np import numpy as np
from .base import _LIB, check_call, c_array, string_types, _FFI_MODE from .base import _LIB, check_call, c_array, string_types, _FFI_MODE
from .runtime_ctypes import TVMType, TVMContext, TVMArray, TVMArrayHandle, tvm_shape_index_t from .runtime_ctypes import TVMType, TVMContext, TVMArray, TVMArrayHandle
from .runtime_ctypes import TypeCode, tvm_shape_index_t
IMPORT_EXCEPT = RuntimeError if _FFI_MODE == "cython" else ImportError IMPORT_EXCEPT = RuntimeError if _FFI_MODE == "cython" else ImportError
...@@ -222,9 +223,21 @@ class NDArrayBase(_NDArrayBase): ...@@ -222,9 +223,21 @@ class NDArrayBase(_NDArrayBase):
raise ValueError("Unsupported target type %s" % str(type(target))) raise ValueError("Unsupported target type %s" % str(type(target)))
return target return target
def free_extension_handle(handle, type_code):
"""Free c++ extension type handle
def register_extension(cls): Parameters
"""Register a extensio class to TVM. ----------
handle : ctypes.c_void_p
The handle to the extension type.
type_code : int
The tyoe code
"""
check_call(_LIB.TVMExtTypeFree(handle, ctypes.c_int(type_code)))
def register_extension(cls, fcreate=None):
"""Register a extension class to TVM.
After the class is registered, the class will be able After the class is registered, the class will be able
to directly pass as Function argument generated by TVM. to directly pass as Function argument generated by TVM.
...@@ -236,16 +249,19 @@ def register_extension(cls): ...@@ -236,16 +249,19 @@ def register_extension(cls):
Note Note
---- ----
The registered class is requires two properties: _tvm_handle and _tvm_tcode The registered class is requires one property: _tvm_handle and a class attribute _tvm_tcode.
- ```_tvm_handle``` returns integer represents the address of the handle. - ```_tvm_handle``` returns integer represents the address of the handle.
- ```_tvm_tcode``` returns integer represents type code of the class. - ```_tvm_tcode``` gives integer represents type code of the class.
Returns Returns
------- -------
cls : class cls : class
The class being registered. The class being registered.
fcreate : function, optional
The creation function to create a class object given handle value.
Example Example
------- -------
The following code registers user defined class The following code registers user defined class
...@@ -255,16 +271,16 @@ def register_extension(cls): ...@@ -255,16 +271,16 @@ def register_extension(cls):
@tvm.register_extension @tvm.register_extension
class MyTensor(object): class MyTensor(object):
_tvm_tcode = tvm.TypeCode.ARRAY_HANDLE
def __init__(self): def __init__(self):
self.handle = _LIB.NewDLTensor() self.handle = _LIB.NewDLTensor()
@property @property
def _tvm_handle(self): def _tvm_handle(self):
return self.handle.value return self.handle.value
@property
def _tvm_tcode(self):
return tvm.TypeCode.ARRAY_HANDLE
""" """
_reg_extension(cls) if fcreate and cls._tvm_tcode < TypeCode.EXT_BEGIN:
raise ValueError("Cannot register create when extension tcode is same as buildin")
_reg_extension(cls, fcreate)
return cls return cls
...@@ -24,6 +24,7 @@ class TypeCode(object): ...@@ -24,6 +24,7 @@ class TypeCode(object):
FUNC_HANDLE = 10 FUNC_HANDLE = 10
STR = 11 STR = 11
BYTES = 12 BYTES = 12
EXT_BEGIN = 15
class TVMByteArray(ctypes.Structure): class TVMByteArray(ctypes.Structure):
"""Temp data structure for byte array.""" """Temp data structure for byte array."""
......
...@@ -9,7 +9,8 @@ import numpy as _np ...@@ -9,7 +9,8 @@ import numpy as _np
from ._ffi.ndarray import TVMContext, TVMType, NDArrayBase from ._ffi.ndarray import TVMContext, TVMType, NDArrayBase
from ._ffi.ndarray import context, empty from ._ffi.ndarray import context, empty
from ._ffi.ndarray import _set_class_ndarray, register_extension from ._ffi.ndarray import _set_class_ndarray
from ._ffi.ndarray import register_extension, free_extension_handle
class NDArray(NDArrayBase): class NDArray(NDArrayBase):
"""Lightweight NDArray class of TVM runtime. """Lightweight NDArray class of TVM runtime.
......
...@@ -9,6 +9,7 @@ ...@@ -9,6 +9,7 @@
#include <unordered_map> #include <unordered_map>
#include <mutex> #include <mutex>
#include <memory> #include <memory>
#include <array>
#include "./runtime_base.h" #include "./runtime_base.h"
namespace tvm { namespace tvm {
...@@ -21,8 +22,17 @@ struct Registry::Manager { ...@@ -21,8 +22,17 @@ struct Registry::Manager {
// and the resource can become invalid because of indeterminstic order of destruction. // and the resource can become invalid because of indeterminstic order of destruction.
// The resources will only be recycled during program exit. // The resources will only be recycled during program exit.
std::unordered_map<std::string, Registry*> fmap; std::unordered_map<std::string, Registry*> fmap;
// vtable for extension type
std::array<ExtTypeVTable, kExtEnd> ext_vtable;
// mutex
std::mutex mutex; std::mutex mutex;
Manager() {
for (auto& x : ext_vtable) {
x.destroy = nullptr;
}
}
static Manager* Global() { static Manager* Global() {
static Manager inst; static Manager inst;
return &inst; return &inst;
...@@ -78,6 +88,24 @@ std::vector<std::string> Registry::ListNames() { ...@@ -78,6 +88,24 @@ std::vector<std::string> Registry::ListNames() {
return keys; return keys;
} }
ExtTypeVTable* ExtTypeVTable::Get(int type_code) {
CHECK(type_code > kExtBegin && type_code < kExtEnd);
Registry::Manager* m = Registry::Manager::Global();
ExtTypeVTable* vt = &(m->ext_vtable[type_code]);
CHECK(vt->destroy != nullptr)
<< "Extension type not registered";
return vt;
}
ExtTypeVTable* ExtTypeVTable::RegisterInternal(
int type_code, const ExtTypeVTable& vt) {
CHECK(type_code > kExtBegin && type_code < kExtEnd);
Registry::Manager* m = Registry::Manager::Global();
std::lock_guard<std::mutex>(m->mutex);
ExtTypeVTable* pvt = &(m->ext_vtable[type_code]);
pvt[0] = vt;
return pvt;
}
} // namespace runtime } // namespace runtime
} // namespace tvm } // namespace tvm
...@@ -92,6 +120,11 @@ struct TVMFuncThreadLocalEntry { ...@@ -92,6 +120,11 @@ struct TVMFuncThreadLocalEntry {
/*! \brief Thread local store that can be used to hold return values. */ /*! \brief Thread local store that can be used to hold return values. */
typedef dmlc::ThreadLocalStore<TVMFuncThreadLocalEntry> TVMFuncThreadLocalStore; typedef dmlc::ThreadLocalStore<TVMFuncThreadLocalEntry> TVMFuncThreadLocalStore;
int TVMExtTypeFree(void* handle, int type_code) {
API_BEGIN();
tvm::runtime::ExtTypeVTable::Get(type_code)->destroy(handle);
API_END();
}
int TVMFuncRegisterGlobal( int TVMFuncRegisterGlobal(
const char* name, TVMFunctionHandle f, int override) { const char* name, TVMFunctionHandle f, int override) {
......
...@@ -110,6 +110,55 @@ TEST(PackedFunc, Type) { ...@@ -110,6 +110,55 @@ TEST(PackedFunc, Type) {
CHECK(get_type2("float32x2").operator Type() == Float(32, 2)); CHECK(get_type2("float32x2").operator Type() == Float(32, 2));
} }
// new namespoace
namespace test {
// register int vector as extension type
using IntVector = std::vector<int>;
} // namespace test
namespace tvm {
namespace runtime {
template<>
struct extension_class_info<test::IntVector> {
static const int code = kExtBegin + 1;
};
} // runtime
} // tvm
// do registration, this need to be in cc file
TVM_REGISTER_EXT_TYPE(test::IntVector);
TEST(PackedFunc, ExtensionType) {
using namespace tvm;
using namespace tvm::runtime;
// note: class are copy by value.
test::IntVector vec{1, 2, 4};
auto copy_vec = PackedFunc([&](TVMArgs args, TVMRetValue* rv) {
// copy by value
const test::IntVector& v = args[0].AsExtension<test::IntVector>();
CHECK(&v == &vec);
test::IntVector v2 = args[0];
CHECK_EQ(v2.size(), 3U);
CHECK_EQ(v[2], 4);
// return copy by value
*rv = v2;
});
auto pass_vec = PackedFunc([&](TVMArgs args, TVMRetValue* rv) {
// copy by value
*rv = args[0];
});
test::IntVector vret1 = copy_vec(vec);
test::IntVector vret2 = pass_vec(copy_vec(vec));
CHECK_EQ(vret1.size(), 3U);
CHECK_EQ(vret2.size(), 3U);
CHECK_EQ(vret1[2], 4);
CHECK_EQ(vret2[2], 4);
}
int main(int argc, char ** argv) { int main(int argc, char ** argv) {
testing::InitGoogleTest(&argc, argv); testing::InitGoogleTest(&argc, argv);
......
...@@ -3,6 +3,7 @@ import numpy as np ...@@ -3,6 +3,7 @@ import numpy as np
@tvm.register_extension @tvm.register_extension
class MyTensorView(object): class MyTensorView(object):
_tvm_tcode = tvm.TypeCode.ARRAY_HANDLE
def __init__(self, arr): def __init__(self, arr):
self.arr = arr self.arr = arr
...@@ -10,10 +11,6 @@ class MyTensorView(object): ...@@ -10,10 +11,6 @@ class MyTensorView(object):
def _tvm_handle(self): def _tvm_handle(self):
return self.arr._tvm_handle return self.arr._tvm_handle
@property
def _tvm_tcode(self):
return tvm.TypeCode.ARRAY_HANDLE
def test_dltensor_compatible(): def test_dltensor_compatible():
dtype = 'int64' dtype = 'int64'
n = tvm.var('n') n = tvm.var('n')
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment