Commit f2ab736b by Tianqi Chen Committed by GitHub

[RUNTIME] Enable extension type to PackedFunc. (#447)

* [RUNTIME] Enable extension type to PackedFunc.

* More comments
parent 3130f2d5
......@@ -16,4 +16,27 @@ _LIB = load_lib()
# Expose two functions into python
bind_add = tvm.get_global_func("tvm_ext.bind_add")
sym_add = tvm.get_global_func("tvm_ext.sym_add")
ivec_create = tvm.get_global_func("tvm_ext.ivec_create")
ivec_get = tvm.get_global_func("tvm_ext.ivec_get")
class IntVec(object):
"""Example for using extension class in c++ """
_tvm_tcode = 17
def __init__(self, handle):
self.handle = handle
def __del__(self):
# You can also call your own customized
# deleter if you can free it via your own FFI.
tvm.nd.free_extension_handle(self.handle, 17)
@property
def _tvm_handle(self):
return self.handle.value
def __getitem__(self, idx):
return ivec_get(self, idx)
# Register IntVec extension on python side.
tvm.register_extension(IntVec, IntVec)
......@@ -10,9 +10,41 @@
#include <tvm/packed_func_ext.h>
namespace tvm_ext {
using IntVector = std::vector<int>;
} // namespace tvm_ext
namespace tvm {
namespace runtime {
template<>
struct extension_class_info<tvm_ext::IntVector> {
static const int code = 17;
};
} // namespace tvm
} // namespace runtime
namespace tvm_ext {
using namespace tvm;
using namespace tvm::runtime;
TVM_REGISTER_EXT_TYPE(IntVector);
TVM_REGISTER_GLOBAL("tvm_ext.ivec_create")
.set_body([](TVMArgs args, TVMRetValue *rv) {
IntVector vec;
for (int i = 0; i < args.size(); ++i) {
vec.push_back(args[i].operator int());
}
*rv = vec;
});
TVM_REGISTER_GLOBAL("tvm_ext.ivec_get")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = args[0].AsExtension<IntVector>()[args[1].operator int()];
});
TVM_REGISTER_GLOBAL("tvm_ext.bind_add")
.set_body([](TVMArgs args_, TVMRetValue *rv_) {
PackedFunc pf = args_[0];
......
......@@ -13,6 +13,19 @@ def test_sym_add():
c = tvm_ext.sym_add(a, b)
assert c.a == a and c.b == b
def test_ext_vec():
ivec = tvm_ext.ivec_create(1, 2, 3)
assert(isinstance(ivec, tvm_ext.IntVec))
assert ivec[0] == 1
assert ivec[1] == 2
def ivec_cb(v2):
assert(isinstance(v2, tvm_ext.IntVec))
assert v2[2] == 3
tvm.convert(ivec_cb)(ivec)
if __name__ == "__main__":
test_ext_vec()
test_bind_add()
test_sym_add()
......@@ -89,8 +89,8 @@ inline std::string NodeTypeName() {
// extensions for tvm arg value
template<typename TNodeRef, typename>
inline TVMArgValue::operator TNodeRef() const {
template<typename TNodeRef>
inline TNodeRef TVMArgValue::AsNodeRef() const {
static_assert(
std::is_base_of<NodeRef, TNodeRef>::value,
"Conversion only works for NodeRef");
......@@ -156,8 +156,8 @@ inline TVMRetValue& TVMRetValue::operator=(const NodeRef& other) {
return *this;
}
template<typename TNodeRef, typename>
inline TVMRetValue::operator TNodeRef() const {
template<typename TNodeRef>
inline TNodeRef TVMRetValue::AsNodeRef() const {
static_assert(
std::is_base_of<NodeRef, TNodeRef>::value,
"Conversion only works for NodeRef");
......@@ -166,8 +166,8 @@ inline TVMRetValue::operator TNodeRef() const {
return TNodeRef(*ptr<std::shared_ptr<Node> >());
}
inline void TVMArgsSetter::operator()(size_t i, NodeRef& other) const { // NOLINT(*)
values_[i].v_handle = &(other.node_);
inline void TVMArgsSetter::operator()(size_t i, const NodeRef& other) const { // NOLINT(*)
values_[i].v_handle = const_cast<std::shared_ptr<Node>*>(&(other.node_));
type_codes_[i] = kNodeHandle;
}
......
......@@ -75,7 +75,17 @@ typedef enum {
kModuleHandle = 9U,
kFuncHandle = 10U,
kStr = 11U,
kBytes = 12U
kBytes = 12U,
// Extension codes for other frameworks to integrate TVM PackedFunc.
// To make sure each framework's id do not conflict, use first and
// last sections to mark ranges.
// Open an issue at the repo if you need a section of code.
kExtBegin = 15U,
kNNVMFirst = 16U,
kNNVMLast = 20U,
// The following section of code is used for non-reserved types.
kExtReserveEnd = 64U,
kExtEnd = 128U
} TVMTypeCode;
/*!
......@@ -192,6 +202,14 @@ TVM_DLL int TVMModGetFunction(TVMModuleHandle mod,
TVMFunctionHandle *out);
/*!
* \brief Free front-end extension type resource.
* \param handle The extension handle.
* \param type_code The type of of the extension type.
* \return 0 when success, -1 when failure happens
*/
TVM_DLL int TVMExtTypeFree(void* handle, int type_code);
/*!
* \brief Free the Module
* \param mod The module to be freed.
*
......@@ -200,6 +218,7 @@ TVM_DLL int TVMModGetFunction(TVMModuleHandle mod,
* Or if this module is imported by another active module.
*
* The all functions remains valid until TVMFuncFree is called.
* \return 0 when success, -1 when failure happens
*/
TVM_DLL int TVMModFree(TVMModuleHandle mod);
......
......@@ -167,6 +167,50 @@ inline std::string TVMType2String(TVMType t);
<< TypeCode2Str(T) << " but get " << TypeCode2Str(CODE) \
/*!
* \brief Type traits to mark if a class is tvm extension type.
*
* To enable extension type in C++ must be register () ed via marco.
* TVM_REGISTER_EXT_TYPE(TypeName) after defining this with this traits.
*
* Extension class can be passed and returned via PackedFunc in all tvm runtime.
* Internally extension class is stored as T*.
*
* \tparam T the typename
*/
template<typename T>
struct extension_class_info {
static const int code = 0;
};
/*!
* \brief Runtime function table about extension type.
*/
class ExtTypeVTable {
public:
/*! \brief function to be called to delete a handle */
void (*destroy)(void* handle);
/*! \brief function to be called when clone a handle */
void* (*clone)(void* handle);
/*!
* \brief Register type
* \tparam T The type to be register.
* \return The registered vtable.
*/
template <typename T>
static inline ExtTypeVTable* Register_();
/*!
* \brief Get a vtable based on type code.
* \param type_code The type code
* \return The registered vtable.
*/
static ExtTypeVTable* Get(int type_code);
private:
// Internal registration function.
static ExtTypeVTable* RegisterInternal(int type_code, const ExtTypeVTable& vt);
};
/*!
* \brief Internal base class to
* handle conversion to POD values.
*/
......@@ -209,6 +253,11 @@ class TVMPODValue_ {
TVM_CHECK_TYPE_CODE(type_code_, kTVMContext);
return value_.v_ctx;
}
template<typename TExtension>
const TExtension& AsExtension() const {
CHECK_LT(type_code_, kExtEnd);
return static_cast<TExtension*>(value_.v_handle)[0];
}
int type_code() const {
return type_code_;
}
......@@ -291,11 +340,13 @@ class TVMArgValue : public TVMPODValue_ {
const TVMValue& value() const {
return value_;
}
// NodeRef related extenstions: in tvm/packed_func_ext.h
template<typename TNodeRef,
// Deferred extension handler.
template<typename TNodeRef>
inline TNodeRef AsNodeRef() const;
template<typename T,
typename = typename std::enable_if<
std::is_class<TNodeRef>::value>::type>
inline operator TNodeRef() const;
std::is_class<T>::value>::type>
inline operator T() const;
template<typename TNodeRef,
typename = typename std::enable_if<
std::is_class<TNodeRef>::value>::type>
......@@ -433,10 +484,18 @@ class TVMRetValue : public TVMPODValue_ {
this->Assign(other);
return *this;
}
TVMRetValue& operator=(TVMArgValue other) {
TVMRetValue& operator=(const TVMArgValue& other) {
this->Assign(other);
return *this;
}
template<typename T,
typename = typename std::enable_if<
extension_class_info<T>::code != 0>::type>
TVMRetValue& operator=(const T& other) {
this->SwitchToClass<T>(
extension_class_info<T>::code, other);
return *this;
}
/*!
* \brief Move the value back to front-end via C API.
* This marks the current container as null.
......@@ -463,12 +522,14 @@ class TVMRetValue : public TVMPODValue_ {
return value_;
}
// NodeRef related extenstions: in tvm/packed_func_ext.h
template<typename T,
typename = typename std::enable_if<
std::is_class<T>::value>::type>
inline operator T() const;
template<typename TNodeRef>
inline TNodeRef AsNodeRef() const;
inline TVMRetValue& operator=(const NodeRef& other);
inline TVMRetValue& operator=(const std::shared_ptr<Node>& other);
template<typename TNodeRef,
typename = typename std::enable_if<
std::is_class<TNodeRef>::value>::type>
inline operator TNodeRef() const;
// type related
inline operator Halide::Type() const;
inline TVMRetValue& operator=(const Halide::Type& other);
......@@ -499,13 +560,20 @@ class TVMRetValue : public TVMPODValue_ {
break;
}
default: {
SwitchToPOD(other.type_code());
value_ = other.value_;
if (other.type_code() < kExtBegin) {
SwitchToPOD(other.type_code());
value_ = other.value_;
} else {
this->Clear();
type_code_ = other.type_code();
value_.v_handle =
(*(ExtTypeVTable::Get(other.type_code())->clone))(
other.value().v_handle);
}
break;
}
}
}
// get the internal container.
void SwitchToPOD(int type_code) {
if (type_code_ != type_code) {
......@@ -531,6 +599,9 @@ class TVMRetValue : public TVMPODValue_ {
case kModuleHandle: delete ptr<Module>(); break;
case kNodeHandle: delete ptr<std::shared_ptr<Node> >(); break;
}
if (type_code_ > kExtBegin) {
(*(ExtTypeVTable::Get(type_code_)->destroy))(value_.v_handle);
}
type_code_ = kNull;
}
};
......@@ -619,24 +690,28 @@ inline PackedFunc::FType PackedFunc::body() const {
// internal namespace
namespace detail {
template<bool stop, std::size_t I, typename F, typename ...Args>
template<bool stop, std::size_t I, typename F>
struct for_each_dispatcher {
static void run(std::tuple<Args...>& args, const F& f) { // NOLINT(*)
f(I, std::get<I>(args));
for_each_dispatcher<(I + 1) == sizeof...(Args), (I+1), F, Args...>::run(args, f);
template<typename T, typename ...Args>
static void run(const F& f, T&& value, Args&&... args) { // NOLINT(*)
f(I, std::forward<T>(value));
for_each_dispatcher<sizeof...(Args) == 0, (I+1), F>
::run(f, std::forward<Args>(args)...);
}
};
template<std::size_t I, typename F, typename ...Args>
struct for_each_dispatcher<true, I, F, Args...> {
static void run(std::tuple<Args...>& args, const F& f) {} // NOLINT(*)
template<std::size_t I, typename F>
struct for_each_dispatcher<true, I, F> {
static void run(const F& f) {} // NOLINT(*)
};
} // namespace detail
template<typename F, typename ...Args>
inline void for_each(std::tuple<Args...>& args, const F& f) { // NOLINT(*)
detail::for_each_dispatcher<sizeof...(Args) == 0, 0, F, Args...>::run(args, f);
inline void for_each(const F& f, Args&&... args) { // NOLINT(*)
for_each_dispatcher<sizeof...(Args) == 0, 0, F>
::run(f, std::forward<Args>(args)...);
}
} // namespace detail
/* \brief argument settter to PackedFunc */
class TVMArgsSetter {
......@@ -645,7 +720,8 @@ class TVMArgsSetter {
: values_(values), type_codes_(type_codes) {}
// setters for POD types
template<typename T,
typename = typename std::enable_if<std::is_integral<T>::value>::type>
typename = typename std::enable_if<
std::is_integral<T>::value>::type>
void operator()(size_t i, T value) const {
values_[i].v_int64 = static_cast<int64_t>(value);
type_codes_[i] = kInt;
......@@ -691,23 +767,23 @@ class TVMArgsSetter {
// setters for container type
// They must be reference(instead of const ref)
// to make sure they are alive in the tuple(instead of getting converted)
void operator()(size_t i, std::string& value) const { // NOLINT(*)
void operator()(size_t i, const std::string& value) const { // NOLINT(*)
values_[i].v_str = value.c_str();
type_codes_[i] = kStr;
}
void operator()(size_t i, TVMByteArray& value) const { // NOLINT(*)
values_[i].v_handle = &value;
void operator()(size_t i, const TVMByteArray& value) const { // NOLINT(*)
values_[i].v_handle = const_cast<TVMByteArray*>(&value);
type_codes_[i] = kBytes;
}
void operator()(size_t i, PackedFunc& value) const { // NOLINT(*)
values_[i].v_handle = &value;
void operator()(size_t i, const PackedFunc& value) const { // NOLINT(*)
values_[i].v_handle = const_cast<PackedFunc*>(&value);
type_codes_[i] = kFuncHandle;
}
void operator()(size_t i, Module& value) const { // NOLINT(*)
values_[i].v_handle = &value;
void operator()(size_t i, const Module& value) const { // NOLINT(*)
values_[i].v_handle = const_cast<Module*>(&value);
type_codes_[i] = kModuleHandle;
}
void operator()(size_t i, TVMRetValue& value) const { // NOLINT(*)
void operator()(size_t i, const TVMRetValue& value) const { // NOLINT(*)
if (value.type_code() == kStr) {
values_[i].v_str = value.ptr<std::string>()->c_str();
type_codes_[i] = kStr;
......@@ -717,8 +793,13 @@ class TVMArgsSetter {
type_codes_[i] = value.type_code();
}
}
// extension
template<typename T,
typename = typename std::enable_if<
extension_class_info<T>::code != 0>::type>
inline void operator()(size_t i, const T& value) const;
// NodeRef related extenstions: in tvm/packed_func_ext.h
inline void operator()(size_t i, NodeRef& other) const; // NOLINT(*)
inline void operator()(size_t i, const NodeRef& other) const; // NOLINT(*)
inline void operator()(size_t i, const Halide::Type& t) const;
private:
......@@ -728,32 +809,79 @@ class TVMArgsSetter {
int* type_codes_;
};
class TVMArgsGetter {
public:
explicit TVMArgsGetter(TVMArgs args)
: args_(args) {}
template<typename T>
inline void operator()(size_t i, T& target) const { // NOLINT(*)
target = args_[i].operator T();
}
private:
TVMArgs args_;
};
template<typename... Args>
inline TVMRetValue PackedFunc::operator()(Args&& ...args) const {
auto targs = std::make_tuple(std::forward<Args>(args)...);
const int kNumArgs = sizeof...(Args);
const int kArraySize = kNumArgs > 0 ? kNumArgs : 1;
TVMValue values[kArraySize];
int type_codes[kArraySize];
for_each(targs, TVMArgsSetter(values, type_codes));
detail::for_each(TVMArgsSetter(values, type_codes),
std::forward<Args>(args)...);
TVMRetValue rv;
body_(TVMArgs(values, type_codes, kNumArgs), &rv);
return rv;
}
// extension and node type handling
namespace detail {
template<typename T, typename TSrc, bool is_ext>
struct TVMValueCast {
static T Apply(const TSrc* self) {
return self->template AsNodeRef<T>();
}
};
template<typename T, typename TSrc>
struct TVMValueCast<T, TSrc, true> {
static T Apply(const TSrc* self) {
return self->template AsExtension<T>();
}
};
} // namespace detail
template<typename T, typename>
inline TVMArgValue::operator T() const {
return detail::
TVMValueCast<T, TVMArgValue, extension_class_info<T>::code != 0>
::Apply(this);
}
template<typename T, typename>
inline TVMRetValue::operator T() const {
return detail::
TVMValueCast<T, TVMRetValue, extension_class_info<T>::code != 0>
::Apply(this);
}
template<typename T, typename>
inline void TVMArgsSetter::operator()(size_t i, const T& value) const {
static_assert(extension_class_info<T>::code != 0,
"Need to have extesion code");
type_codes_[i] = extension_class_info<T>::code;
values_[i].v_handle = const_cast<T*>(&value);
}
// extension type handling
template<typename T>
struct ExtTypeInfo {
static void destroy(void* handle) {
delete static_cast<T*>(handle);
}
static void* clone(void* handle) {
return new T(*static_cast<T*>(handle));
}
};
template<typename T>
inline ExtTypeVTable* ExtTypeVTable::Register_() {
const int code = extension_class_info<T>::code;
static_assert(code != 0,
"require extension_class_info traits to be declared with non-zero code");
ExtTypeVTable vt;
vt.clone = ExtTypeInfo<T>::clone;
vt.destroy = ExtTypeInfo<T>::destroy;
return ExtTypeVTable::RegisterInternal(code, vt);
}
} // namespace runtime
} // namespace tvm
#endif // TVM_RUNTIME_PACKED_FUNC_H_
......@@ -73,13 +73,14 @@ class Registry {
*/
static std::vector<std::string> ListNames();
// Internal class.
struct Manager;
private:
/*! \brief name of the function */
std::string name_;
/*! \brief internal packed function */
PackedFunc func_;
// Internal class.
struct Manager;
friend struct Manager;
};
......@@ -96,6 +97,9 @@ class Registry {
#define TVM_FUNC_REG_VAR_DEF \
static TVM_ATTRIBUTE_UNUSED ::tvm::runtime::Registry& __mk_ ## TVM
#define TVM_TYPE_REG_VAR_DEF \
static TVM_ATTRIBUTE_UNUSED ::tvm::runtime::ExtTypeVTable* __mk_ ## TVMT
/*!
* \brief Register a function globally.
* \code
......@@ -108,6 +112,15 @@ class Registry {
TVM_STR_CONCAT(TVM_FUNC_REG_VAR_DEF, __COUNTER__) = \
::tvm::runtime::Registry::Register(OpName)
/*!
* \brief Macro to register extension type.
* This must be registered in a cc file
* after the trait extension_class_info is defined.
*/
#define TVM_REGISTER_EXT_TYPE(T) \
TVM_STR_CONCAT(TVM_TYPE_REG_VAR_DEF, __COUNTER__) = \
::tvm::runtime::ExtTypeVTable::Register_<T>()
} // namespace runtime
} // namespace tvm
#endif // TVM_RUNTIME_REGISTRY_H_
......@@ -97,7 +97,7 @@ def _make_tvm_args(args, temp_args):
type_codes[i] = TypeCode.ARRAY_HANDLE
elif isinstance(arg, _nd._TVM_COMPATS):
values[i].v_handle = ctypes.c_void_p(arg._tvm_handle)
type_codes[i] = arg._tvm_tcode
type_codes[i] = arg.__class__._tvm_tcode
elif isinstance(arg, Integral):
values[i].v_int64 = arg
type_codes[i] = TypeCode.INT
......
......@@ -4,6 +4,7 @@ from __future__ import absolute_import
import ctypes
from ..base import _LIB, check_call
from ..runtime_ctypes import TVMArrayHandle
from .types import RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func, _return_handle
class NDArrayBase(object):
"""A simple Device/CPU Array object in runtime."""
......@@ -35,9 +36,14 @@ def _make_array(handle, is_view):
_TVM_COMPATS = ()
def _reg_extension(cls):
def _reg_extension(cls, fcreate):
global _TVM_COMPATS
_TVM_COMPATS += (cls,)
if fcreate:
fret = lambda x: fcreate(_return_handle(x))
RETURN_SWITCH[cls._tvm_tcode] = fret
C_TO_PY_ARG_SWITCH[cls._tvm_tcode] = _wrap_arg_func(fret, cls._tvm_tcode)
_CLASS_NDARRAY = None
......
......@@ -18,6 +18,7 @@ cdef enum TVMTypeCode:
kFuncHandle = 10
kStr = 11
kBytes = 12
kExtBegin = 15
cdef extern from "tvm/runtime/c_runtime_api.h":
ctypedef struct DLDataType:
......
......@@ -27,8 +27,10 @@ cdef int tvm_callback(TVMValue* args,
tcode = type_codes[i]
if (tcode == kNodeHandle or
tcode == kFuncHandle or
tcode == kModuleHandle):
tcode == kModuleHandle or
tcode > kExtBegin):
CALL(TVMCbArgToReturn(&value, tcode))
if tcode != kArrayHandle:
pyargs.append(make_ret(value, tcode))
else:
......@@ -87,7 +89,7 @@ cdef inline void make_arg(object arg,
elif isinstance(arg, _TVM_COMPATS):
ptr = arg._tvm_handle
value[0].v_handle = (<void*>ptr)
tcode[0] = arg._tvm_tcode
tcode[0] = arg.__class__._tvm_tcode
elif isinstance(arg, (int, long)):
value[0].v_int64 = arg
tcode[0] = kInt
......@@ -185,8 +187,10 @@ cdef inline object make_ret(TVMValue value, int tcode):
fobj = _CLASS_FUNCTION(None, False)
(<FunctionBase>fobj).chandle = value.v_handle
return fobj
else:
raise ValueError("Unhandled type code %d" % tcode)
elif tcode in _TVM_EXT_RET:
return _TVM_EXT_RET[tcode](ctypes_handle(value.v_handle))
raise ValueError("Unhandled type code %d" % tcode)
cdef inline object FuncCall3(void* chandle, tuple args, int nargs):
......
......@@ -43,9 +43,14 @@ cdef c_make_array(void* chandle, is_view):
cdef _TVM_COMPATS = ()
def _reg_extension(cls):
cdef _TVM_EXT_RET = {}
def _reg_extension(cls, fcreate):
global _TVM_COMPATS
_TVM_COMPATS += (cls,)
if fcreate:
_TVM_EXT_RET[cls._tvm_tcode] = fcreate
def _make_array(handle, is_view):
cdef unsigned long long ptr
......
......@@ -6,7 +6,8 @@ import sys
import ctypes
import numpy as np
from .base import _LIB, check_call, c_array, string_types, _FFI_MODE
from .runtime_ctypes import TVMType, TVMContext, TVMArray, TVMArrayHandle, tvm_shape_index_t
from .runtime_ctypes import TVMType, TVMContext, TVMArray, TVMArrayHandle
from .runtime_ctypes import TypeCode, tvm_shape_index_t
IMPORT_EXCEPT = RuntimeError if _FFI_MODE == "cython" else ImportError
......@@ -222,9 +223,21 @@ class NDArrayBase(_NDArrayBase):
raise ValueError("Unsupported target type %s" % str(type(target)))
return target
def free_extension_handle(handle, type_code):
"""Free c++ extension type handle
def register_extension(cls):
"""Register a extensio class to TVM.
Parameters
----------
handle : ctypes.c_void_p
The handle to the extension type.
type_code : int
The tyoe code
"""
check_call(_LIB.TVMExtTypeFree(handle, ctypes.c_int(type_code)))
def register_extension(cls, fcreate=None):
"""Register a extension class to TVM.
After the class is registered, the class will be able
to directly pass as Function argument generated by TVM.
......@@ -236,16 +249,19 @@ def register_extension(cls):
Note
----
The registered class is requires two properties: _tvm_handle and _tvm_tcode
The registered class is requires one property: _tvm_handle and a class attribute _tvm_tcode.
- ```_tvm_handle``` returns integer represents the address of the handle.
- ```_tvm_tcode``` returns integer represents type code of the class.
- ```_tvm_tcode``` gives integer represents type code of the class.
Returns
-------
cls : class
The class being registered.
fcreate : function, optional
The creation function to create a class object given handle value.
Example
-------
The following code registers user defined class
......@@ -255,16 +271,16 @@ def register_extension(cls):
@tvm.register_extension
class MyTensor(object):
_tvm_tcode = tvm.TypeCode.ARRAY_HANDLE
def __init__(self):
self.handle = _LIB.NewDLTensor()
@property
def _tvm_handle(self):
return self.handle.value
@property
def _tvm_tcode(self):
return tvm.TypeCode.ARRAY_HANDLE
"""
_reg_extension(cls)
if fcreate and cls._tvm_tcode < TypeCode.EXT_BEGIN:
raise ValueError("Cannot register create when extension tcode is same as buildin")
_reg_extension(cls, fcreate)
return cls
......@@ -24,6 +24,7 @@ class TypeCode(object):
FUNC_HANDLE = 10
STR = 11
BYTES = 12
EXT_BEGIN = 15
class TVMByteArray(ctypes.Structure):
"""Temp data structure for byte array."""
......
......@@ -9,7 +9,8 @@ import numpy as _np
from ._ffi.ndarray import TVMContext, TVMType, NDArrayBase
from ._ffi.ndarray import context, empty
from ._ffi.ndarray import _set_class_ndarray, register_extension
from ._ffi.ndarray import _set_class_ndarray
from ._ffi.ndarray import register_extension, free_extension_handle
class NDArray(NDArrayBase):
"""Lightweight NDArray class of TVM runtime.
......
......@@ -9,6 +9,7 @@
#include <unordered_map>
#include <mutex>
#include <memory>
#include <array>
#include "./runtime_base.h"
namespace tvm {
......@@ -21,8 +22,17 @@ struct Registry::Manager {
// and the resource can become invalid because of indeterminstic order of destruction.
// The resources will only be recycled during program exit.
std::unordered_map<std::string, Registry*> fmap;
// vtable for extension type
std::array<ExtTypeVTable, kExtEnd> ext_vtable;
// mutex
std::mutex mutex;
Manager() {
for (auto& x : ext_vtable) {
x.destroy = nullptr;
}
}
static Manager* Global() {
static Manager inst;
return &inst;
......@@ -78,6 +88,24 @@ std::vector<std::string> Registry::ListNames() {
return keys;
}
ExtTypeVTable* ExtTypeVTable::Get(int type_code) {
CHECK(type_code > kExtBegin && type_code < kExtEnd);
Registry::Manager* m = Registry::Manager::Global();
ExtTypeVTable* vt = &(m->ext_vtable[type_code]);
CHECK(vt->destroy != nullptr)
<< "Extension type not registered";
return vt;
}
ExtTypeVTable* ExtTypeVTable::RegisterInternal(
int type_code, const ExtTypeVTable& vt) {
CHECK(type_code > kExtBegin && type_code < kExtEnd);
Registry::Manager* m = Registry::Manager::Global();
std::lock_guard<std::mutex>(m->mutex);
ExtTypeVTable* pvt = &(m->ext_vtable[type_code]);
pvt[0] = vt;
return pvt;
}
} // namespace runtime
} // namespace tvm
......@@ -92,6 +120,11 @@ struct TVMFuncThreadLocalEntry {
/*! \brief Thread local store that can be used to hold return values. */
typedef dmlc::ThreadLocalStore<TVMFuncThreadLocalEntry> TVMFuncThreadLocalStore;
int TVMExtTypeFree(void* handle, int type_code) {
API_BEGIN();
tvm::runtime::ExtTypeVTable::Get(type_code)->destroy(handle);
API_END();
}
int TVMFuncRegisterGlobal(
const char* name, TVMFunctionHandle f, int override) {
......
......@@ -110,6 +110,55 @@ TEST(PackedFunc, Type) {
CHECK(get_type2("float32x2").operator Type() == Float(32, 2));
}
// new namespoace
namespace test {
// register int vector as extension type
using IntVector = std::vector<int>;
} // namespace test
namespace tvm {
namespace runtime {
template<>
struct extension_class_info<test::IntVector> {
static const int code = kExtBegin + 1;
};
} // runtime
} // tvm
// do registration, this need to be in cc file
TVM_REGISTER_EXT_TYPE(test::IntVector);
TEST(PackedFunc, ExtensionType) {
using namespace tvm;
using namespace tvm::runtime;
// note: class are copy by value.
test::IntVector vec{1, 2, 4};
auto copy_vec = PackedFunc([&](TVMArgs args, TVMRetValue* rv) {
// copy by value
const test::IntVector& v = args[0].AsExtension<test::IntVector>();
CHECK(&v == &vec);
test::IntVector v2 = args[0];
CHECK_EQ(v2.size(), 3U);
CHECK_EQ(v[2], 4);
// return copy by value
*rv = v2;
});
auto pass_vec = PackedFunc([&](TVMArgs args, TVMRetValue* rv) {
// copy by value
*rv = args[0];
});
test::IntVector vret1 = copy_vec(vec);
test::IntVector vret2 = pass_vec(copy_vec(vec));
CHECK_EQ(vret1.size(), 3U);
CHECK_EQ(vret2.size(), 3U);
CHECK_EQ(vret1[2], 4);
CHECK_EQ(vret2[2], 4);
}
int main(int argc, char ** argv) {
testing::InitGoogleTest(&argc, argv);
......
......@@ -3,6 +3,7 @@ import numpy as np
@tvm.register_extension
class MyTensorView(object):
_tvm_tcode = tvm.TypeCode.ARRAY_HANDLE
def __init__(self, arr):
self.arr = arr
......@@ -10,10 +11,6 @@ class MyTensorView(object):
def _tvm_handle(self):
return self.arr._tvm_handle
@property
def _tvm_tcode(self):
return tvm.TypeCode.ARRAY_HANDLE
def test_dltensor_compatible():
dtype = 'int64'
n = tvm.var('n')
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment