Commit 433756b9 by Tianqi Chen Committed by GitHub

Revert "[RUNTIME] Refactor extension type handling, now it is header only (#924)" (#925)

This reverts commit 12d15704d7f5d30cff7540f1fd16be64c6baca68.
parent a6b4a219
...@@ -22,10 +22,13 @@ struct extension_class_info<tvm_ext::IntVector> { ...@@ -22,10 +22,13 @@ struct extension_class_info<tvm_ext::IntVector> {
} // namespace tvm } // namespace tvm
} // namespace runtime } // namespace runtime
namespace tvm_ext {
using namespace tvm; using namespace tvm;
using namespace tvm::runtime; using namespace tvm::runtime;
namespace tvm_ext { TVM_REGISTER_EXT_TYPE(IntVector);
TVM_REGISTER_GLOBAL("tvm_ext.ivec_create") TVM_REGISTER_GLOBAL("tvm_ext.ivec_create")
.set_body([](TVMArgs args, TVMRetValue *rv) { .set_body([](TVMArgs args, TVMRetValue *rv) {
...@@ -63,18 +66,3 @@ TVM_REGISTER_GLOBAL("device_api.ext_dev") ...@@ -63,18 +66,3 @@ TVM_REGISTER_GLOBAL("device_api.ext_dev")
*rv = (*tvm::runtime::Registry::Get("device_api.cpu"))(); *rv = (*tvm::runtime::Registry::Get("device_api.cpu"))();
}); });
} // namespace tvm_ext } // namespace tvm_ext
// This callback approach allows extension allows tvm to extract
// This way can be helpful when we want to use a header only
// minimum version of TVM Runtime.
extern "C" int TVMExtDeclare(TVMFunctionHandle pregister) {
const PackedFunc& fregister =
*static_cast<PackedFunc*>(pregister);
auto mul = [](TVMArgs args, TVMRetValue *rv) {
int x = args[0];
int y = args[1];
*rv = x * y;
};
fregister("mul", PackedFunc(mul));
return 0;
}
...@@ -44,14 +44,8 @@ def test_ext_vec(): ...@@ -44,14 +44,8 @@ def test_ext_vec():
tvm.convert(ivec_cb)(ivec) tvm.convert(ivec_cb)(ivec)
def test_extract_ext():
fdict = tvm.extract_ext_funcs(tvm_ext._LIB.TVMExtDeclare)
assert fdict["mul"](3, 4) == 12
if __name__ == "__main__": if __name__ == "__main__":
test_ext_dev() test_ext_dev()
test_ext_vec() test_ext_vec()
test_bind_add() test_bind_add()
test_sym_add() test_sym_add()
test_extract_ext()
...@@ -24,13 +24,6 @@ ...@@ -24,13 +24,6 @@
#define TVM_EXTERN_C #define TVM_EXTERN_C
#endif #endif
// Macros to do weak linking
#ifdef _MSC_VER
#define TVM_WEAK __declspec(selectany)
#else
#define TVM_WEAK __attribute__((weak))
#endif
#ifdef __EMSCRIPTEN__ #ifdef __EMSCRIPTEN__
#include <emscripten/emscripten.h> #include <emscripten/emscripten.h>
#define TVM_DLL EMSCRIPTEN_KEEPALIVE #define TVM_DLL EMSCRIPTEN_KEEPALIVE
...@@ -321,17 +314,6 @@ typedef int (*TVMPackedCFunc)( ...@@ -321,17 +314,6 @@ typedef int (*TVMPackedCFunc)(
typedef void (*TVMPackedCFuncFinalizer)(void* resource_handle); typedef void (*TVMPackedCFuncFinalizer)(void* resource_handle);
/*! /*!
* \brief Signature for extension function declarer.
*
* TVM call this function to get the extension functions
* The declarer will call register_func to register function and their name.
*
* \param resource_func_handle The register function
* \return 0 if success, -1 if failure happens
*/
typedef int (*TVMExtensionFuncDeclarer)(TVMFunctionHandle register_func_handle);
/*!
* \brief Wrap a TVMPackedCFunc to become a FunctionHandle. * \brief Wrap a TVMPackedCFunc to become a FunctionHandle.
* *
* The resource_handle will be managed by TVM API, until the function is no longer used. * The resource_handle will be managed by TVM API, until the function is no longer used.
......
...@@ -38,14 +38,8 @@ class Module { ...@@ -38,14 +38,8 @@ class Module {
* \param query_imports Whether also query dependency modules. * \param query_imports Whether also query dependency modules.
* \return The result function. * \return The result function.
* This function will return PackedFunc(nullptr) if function do not exist. * This function will return PackedFunc(nullptr) if function do not exist.
* \note Implemented in packed_func.cc
*/ */
inline PackedFunc GetFunction(const std::string& name, bool query_imports = false); TVM_DLL PackedFunc GetFunction(const std::string& name, bool query_imports = false);
/*! \return internal container */
inline ModuleNode* operator->();
/*! \return internal container */
inline const ModuleNode* operator->() const;
// The following functions requires link with runtime.
/*! /*!
* \brief Import another module into this module. * \brief Import another module into this module.
* \param other The module to be imported. * \param other The module to be imported.
...@@ -63,6 +57,10 @@ class Module { ...@@ -63,6 +57,10 @@ class Module {
*/ */
TVM_DLL static Module LoadFromFile(const std::string& file_name, TVM_DLL static Module LoadFromFile(const std::string& file_name,
const std::string& format = ""); const std::string& format = "");
/*! \return internal container */
inline ModuleNode* operator->();
/*! \return internal container */
inline const ModuleNode* operator->() const;
private: private:
std::shared_ptr<ModuleNode> node_; std::shared_ptr<ModuleNode> node_;
......
...@@ -183,17 +183,31 @@ struct extension_class_info { ...@@ -183,17 +183,31 @@ struct extension_class_info {
}; };
/*! /*!
* \brief Capsule structure holding extension types * \brief Runtime function table about extension type.
* Capsule is self-contained and include
* all the information to clone and destroy the type.
*/ */
struct TVMExtTypeCapsule { class ExtTypeVTable {
/*! \brief The pointer to the object */ public:
void* ptr;
/*! \brief function to be called to delete a handle */ /*! \brief function to be called to delete a handle */
void (*destroy)(void* handle); void (*destroy)(void* handle);
/*! \brief function to be called when clone a handle */ /*! \brief function to be called when clone a handle */
void* (*clone)(void* handle); void* (*clone)(void* handle);
/*!
* \brief Register type
* \tparam T The type to be register.
* \return The registered vtable.
*/
template <typename T>
static inline ExtTypeVTable* Register_();
/*!
* \brief Get a vtable based on type code.
* \param type_code The type code
* \return The registered vtable.
*/
TVM_DLL static ExtTypeVTable* Get(int type_code);
private:
// Internal registration function.
TVM_DLL static ExtTypeVTable* RegisterInternal(int type_code, const ExtTypeVTable& vt);
}; };
/*! /*!
...@@ -241,9 +255,8 @@ class TVMPODValue_ { ...@@ -241,9 +255,8 @@ class TVMPODValue_ {
} }
template<typename TExtension> template<typename TExtension>
const TExtension& AsExtension() const { const TExtension& AsExtension() const {
CHECK_EQ(type_code_, extension_class_info<TExtension>::code); CHECK_LT(type_code_, kExtEnd);
return static_cast<TExtension*>( return static_cast<TExtension*>(value_.v_handle)[0];
static_cast<TVMExtTypeCapsule*>(value_.v_handle)->ptr)[0];
} }
int type_code() const { int type_code() const {
return type_code_; return type_code_;
...@@ -475,6 +488,14 @@ class TVMRetValue : public TVMPODValue_ { ...@@ -475,6 +488,14 @@ class TVMRetValue : public TVMPODValue_ {
this->Assign(other); this->Assign(other);
return *this; return *this;
} }
template<typename T,
typename = typename std::enable_if<
extension_class_info<T>::code != 0>::type>
TVMRetValue& operator=(const T& other) {
this->SwitchToClass<T>(
extension_class_info<T>::code, other);
return *this;
}
/*! /*!
* \brief Move the value back to front-end via C API. * \brief Move the value back to front-end via C API.
* This marks the current container as null. * This marks the current container as null.
...@@ -500,11 +521,6 @@ class TVMRetValue : public TVMPODValue_ { ...@@ -500,11 +521,6 @@ class TVMRetValue : public TVMPODValue_ {
type_code_ != kStr) << "TVMRetValue.value can only be used for POD data"; type_code_ != kStr) << "TVMRetValue.value can only be used for POD data";
return value_; return value_;
} }
// assign extension
template<typename T,
typename = typename std::enable_if<
extension_class_info<T>::code != 0>::type>
inline TVMRetValue& operator=(const T& other);
// NodeRef related extenstions: in tvm/packed_func_ext.h // NodeRef related extenstions: in tvm/packed_func_ext.h
template<typename T, template<typename T,
typename = typename std::enable_if< typename = typename std::enable_if<
...@@ -548,9 +564,11 @@ class TVMRetValue : public TVMPODValue_ { ...@@ -548,9 +564,11 @@ class TVMRetValue : public TVMPODValue_ {
SwitchToPOD(other.type_code()); SwitchToPOD(other.type_code());
value_ = other.value_; value_ = other.value_;
} else { } else {
TVMExtTypeCapsule cap = *other.template ptr<TVMExtTypeCapsule>(); this->Clear();
cap.ptr = cap.clone(cap.ptr); type_code_ = other.type_code();
SwitchToClass<TVMExtTypeCapsule>(other.type_code(), cap); value_.v_handle =
(*(ExtTypeVTable::Get(other.type_code())->clone))(
other.value().v_handle);
} }
break; break;
} }
...@@ -582,9 +600,7 @@ class TVMRetValue : public TVMPODValue_ { ...@@ -582,9 +600,7 @@ class TVMRetValue : public TVMPODValue_ {
case kNodeHandle: delete ptr<std::shared_ptr<Node> >(); break; case kNodeHandle: delete ptr<std::shared_ptr<Node> >(); break;
} }
if (type_code_ > kExtBegin) { if (type_code_ > kExtBegin) {
TVMExtTypeCapsule *cap = ptr<TVMExtTypeCapsule>(); (*(ExtTypeVTable::Get(type_code_)->destroy))(value_.v_handle);
cap->destroy(cap->ptr);
delete cap;
} }
type_code_ = kNull; type_code_ = kNull;
} }
...@@ -700,10 +716,8 @@ inline void for_each(const F& f, Args&&... args) { // NOLINT(*) ...@@ -700,10 +716,8 @@ inline void for_each(const F& f, Args&&... args) { // NOLINT(*)
/* \brief argument settter to PackedFunc */ /* \brief argument settter to PackedFunc */
class TVMArgsSetter { class TVMArgsSetter {
public: public:
TVMArgsSetter(TVMValue* values, TVMArgsSetter(TVMValue* values, int* type_codes)
int* type_codes, : values_(values), type_codes_(type_codes) {}
TVMExtTypeCapsule* exts)
: values_(values), type_codes_(type_codes), exts_(exts) {}
// setters for POD types // setters for POD types
template<typename T, template<typename T,
typename = typename std::enable_if< typename = typename std::enable_if<
...@@ -793,21 +807,15 @@ class TVMArgsSetter { ...@@ -793,21 +807,15 @@ class TVMArgsSetter {
TVMValue* values_; TVMValue* values_;
/*! \brief The type code fields */ /*! \brief The type code fields */
int* type_codes_; int* type_codes_;
/*! \brief Temporary storage for extension types */
TVMExtTypeCapsule* exts_;
}; };
template<typename... Args> template<typename... Args>
inline TVMRetValue PackedFunc::operator()(Args&& ...args) const { inline TVMRetValue PackedFunc::operator()(Args&& ...args) const {
const int kNumArgs = sizeof...(Args); const int kNumArgs = sizeof...(Args);
// Compiler will remove an static array when it is not touched.
const int kArraySize = kNumArgs > 0 ? kNumArgs : 1; const int kArraySize = kNumArgs > 0 ? kNumArgs : 1;
TVMValue values[kArraySize]; TVMValue values[kArraySize];
int type_codes[kArraySize]; int type_codes[kArraySize];
// If the function call does not contain extension type, detail::for_each(TVMArgsSetter(values, type_codes),
// exts will get optimized away by compiler.
TVMExtTypeCapsule exts[kArraySize];
detail::for_each(TVMArgsSetter(values, type_codes, exts),
std::forward<Args>(args)...); std::forward<Args>(args)...);
TVMRetValue rv; TVMRetValue rv;
body_(TVMArgs(values, type_codes, kNumArgs), &rv); body_(TVMArgs(values, type_codes, kNumArgs), &rv);
...@@ -845,6 +853,14 @@ inline TVMRetValue::operator T() const { ...@@ -845,6 +853,14 @@ inline TVMRetValue::operator T() const {
::Apply(this); ::Apply(this);
} }
template<typename T, typename>
inline void TVMArgsSetter::operator()(size_t i, const T& value) const {
static_assert(extension_class_info<T>::code != 0,
"Need to have extesion code");
type_codes_[i] = extension_class_info<T>::code;
values_[i].v_handle = const_cast<T*>(&value);
}
// extension type handling // extension type handling
template<typename T> template<typename T>
struct ExtTypeInfo { struct ExtTypeInfo {
...@@ -856,42 +872,16 @@ struct ExtTypeInfo { ...@@ -856,42 +872,16 @@ struct ExtTypeInfo {
} }
}; };
template<typename T, typename> template<typename T>
inline TVMRetValue& TVMRetValue::operator=(const T& other) { inline ExtTypeVTable* ExtTypeVTable::Register_() {
TVMExtTypeCapsule cap; const int code = extension_class_info<T>::code;
cap.clone = ExtTypeInfo<T>::clone; static_assert(code != 0,
cap.destroy = ExtTypeInfo<T>::destroy; "require extension_class_info traits to be declared with non-zero code");
cap.ptr = new T(other); ExtTypeVTable vt;
SwitchToClass<TVMExtTypeCapsule>( vt.clone = ExtTypeInfo<T>::clone;
extension_class_info<T>::code, cap); vt.destroy = ExtTypeInfo<T>::destroy;
return *this; return ExtTypeVTable::RegisterInternal(code, vt);
}
template<typename T, typename>
inline void TVMArgsSetter::operator()(size_t i, const T& value) const {
static_assert(extension_class_info<T>::code != 0,
"Need to have extesion code");
type_codes_[i] = extension_class_info<T>::code;
exts_[i].clone = ExtTypeInfo<T>::clone;
exts_[i].destroy = ExtTypeInfo<T>::destroy;
exts_[i].ptr = const_cast<T*>(&value);
values_[i].v_handle = &exts_[i];
}
// Implement Module::GetFunction
// Put implementation in this file so we have seen the PackedFunc
inline PackedFunc Module::GetFunction(const std::string& name, bool query_imports) {
PackedFunc pf = node_->GetFunction(name, node_);
if (pf != nullptr) return pf;
if (query_imports) {
for (const Module& m : node_->imports_) {
pf = m.node_->GetFunction(name, m.node_);
if (pf != nullptr) return pf;
}
}
return pf;
} }
} // namespace runtime } // namespace runtime
} // namespace tvm } // namespace tvm
#endif // TVM_RUNTIME_PACKED_FUNC_H_ #endif // TVM_RUNTIME_PACKED_FUNC_H_
...@@ -234,31 +234,6 @@ def list_global_func_names(): ...@@ -234,31 +234,6 @@ def list_global_func_names():
return fnames return fnames
def extract_ext_funcs(finit):
"""
Extract the extension PackedFuncs from a C module.
Parameters
----------
finit : ctypes function
a ctypes that takes signature of TVMExtensionDeclarer
Returns
-------
fdict : dict of str to Function
The extracted functions
"""
fdict = {}
def _list(name, func):
fdict[name] = func
myf = convert_to_tvm_func(_list)
ret = finit(myf.handle)
_ = myf
if ret != 0:
raise RuntimeError("cannot initialize with %s" % finit)
return fdict
def _get_api(f): def _get_api(f):
flocal = f flocal = f
flocal.is_global = True flocal.is_global = True
......
...@@ -8,7 +8,7 @@ from ._ffi.base import string_types ...@@ -8,7 +8,7 @@ from ._ffi.base import string_types
from ._ffi.node import register_node, NodeBase from ._ffi.node import register_node, NodeBase
from ._ffi.node import convert_to_node as _convert_to_node from ._ffi.node import convert_to_node as _convert_to_node
from ._ffi.function import Function from ._ffi.function import Function
from ._ffi.function import _init_api, register_func, get_global_func, extract_ext_funcs from ._ffi.function import _init_api, register_func, get_global_func
from ._ffi.function import convert_to_tvm_func as _convert_tvm_func from ._ffi.function import convert_to_tvm_func as _convert_tvm_func
from ._ffi.runtime_ctypes import TVMType from ._ffi.runtime_ctypes import TVMType
from . import _api_internal from . import _api_internal
......
...@@ -23,8 +23,7 @@ from . import target as _target ...@@ -23,8 +23,7 @@ from . import target as _target
from . import make from . import make
class DumpIR(object): class DumpIR(object):
""" """Dump IR for each pass.
Dump IR for each pass.
With it, you can dump ir just like gcc/llvm. With it, you can dump ir just like gcc/llvm.
How to use: How to use:
...@@ -33,6 +32,7 @@ class DumpIR(object): ...@@ -33,6 +32,7 @@ class DumpIR(object):
with tvm.build_config(dump_pass_ir=True) with tvm.build_config(dump_pass_ir=True)
run() run()
""" """
scope_level = 0 scope_level = 0
def __init__(self): def __init__(self):
...@@ -40,9 +40,9 @@ class DumpIR(object): ...@@ -40,9 +40,9 @@ class DumpIR(object):
self._recover_list = [] self._recover_list = []
def decorate(self, func): def decorate(self, func):
""" decorate the pass function""" ''' decorate the pass function'''
def dump(*args, **kwargs): def dump(*args, **kwargs):
"""dump function""" '''dump function'''
retv = func(*args, **kwargs) retv = func(*args, **kwargs)
if not isinstance(retv, (_stmt.Stmt, container.LoweredFunc, container.Array)): if not isinstance(retv, (_stmt.Stmt, container.LoweredFunc, container.Array)):
return retv return retv
...@@ -59,7 +59,7 @@ class DumpIR(object): ...@@ -59,7 +59,7 @@ class DumpIR(object):
return dump return dump
def decorate_irpass(self): def decorate_irpass(self):
"""decorate ir_pass and ScheduleOps""" '''decorate ir_pass and ScheduleOps'''
self._old_sgpass = schedule.ScheduleOps self._old_sgpass = schedule.ScheduleOps
schedule.ScheduleOps = self.decorate(schedule.ScheduleOps) schedule.ScheduleOps = self.decorate(schedule.ScheduleOps)
vset = vars(ir_pass) vset = vars(ir_pass)
...@@ -71,7 +71,7 @@ class DumpIR(object): ...@@ -71,7 +71,7 @@ class DumpIR(object):
vset[k] = self.decorate(v) if isinstance(v, types.FunctionType) else v vset[k] = self.decorate(v) if isinstance(v, types.FunctionType) else v
def decorate_custompass(self): def decorate_custompass(self):
""" decorate add_lower_pass pass in BuildConfig""" ''' decorate add_lower_pass pass in BuildConfig'''
cfg = BuildConfig.current cfg = BuildConfig.current
self._old_custom_pass = cfg.add_lower_pass self._old_custom_pass = cfg.add_lower_pass
custom_pass = cfg.add_lower_pass if cfg.add_lower_pass else [] custom_pass = cfg.add_lower_pass if cfg.add_lower_pass else []
...@@ -79,7 +79,7 @@ class DumpIR(object): ...@@ -79,7 +79,7 @@ class DumpIR(object):
BuildConfig.current.add_lower_pass = pass_list BuildConfig.current.add_lower_pass = pass_list
def enter(self): def enter(self):
"""only decorate outermost nest""" '''only decorate outermost nest'''
if DumpIR.scope_level > 0: if DumpIR.scope_level > 0:
return return
self.decorate_irpass() self.decorate_irpass()
...@@ -88,7 +88,7 @@ class DumpIR(object): ...@@ -88,7 +88,7 @@ class DumpIR(object):
DumpIR.scope_level += 1 DumpIR.scope_level += 1
def exit(self): def exit(self):
"""recover outermost nest""" '''recover outermost nest'''
if DumpIR.scope_level > 1: if DumpIR.scope_level > 1:
return return
# recover decorated functions # recover decorated functions
...@@ -163,7 +163,6 @@ class BuildConfig(NodeBase): ...@@ -163,7 +163,6 @@ class BuildConfig(NodeBase):
"'%s' object cannot set attribute '%s'" % (str(type(self)), name)) "'%s' object cannot set attribute '%s'" % (str(type(self)), name))
return super(BuildConfig, self).__setattr__(name, value) return super(BuildConfig, self).__setattr__(name, value)
def build_config(**kwargs): def build_config(**kwargs):
"""Configure the build behavior by setting config variables. """Configure the build behavior by setting config variables.
...@@ -227,7 +226,6 @@ def build_config(**kwargs): ...@@ -227,7 +226,6 @@ def build_config(**kwargs):
setattr(config, k, kwargs[k]) setattr(config, k, kwargs[k])
return config return config
if not _RUNTIME_ONLY: if not _RUNTIME_ONLY:
# BuildConfig is not available in tvm_runtime # BuildConfig is not available in tvm_runtime
BuildConfig.current = build_config() BuildConfig.current = build_config()
...@@ -354,10 +352,8 @@ def lower(sch, ...@@ -354,10 +352,8 @@ def lower(sch,
stmt = f(stmt) stmt = f(stmt)
if simple_mode: if simple_mode:
return stmt return stmt
return ir_pass.MakeAPI(stmt, name, arg_list, 0, cfg.restricted_func) return ir_pass.MakeAPI(stmt, name, arg_list, 0, cfg.restricted_func)
def build(sch, def build(sch,
args=None, args=None,
target=None, target=None,
......
...@@ -347,14 +347,6 @@ int TVMFuncCreateFromCFunc(TVMPackedCFunc func, ...@@ -347,14 +347,6 @@ int TVMFuncCreateFromCFunc(TVMPackedCFunc func,
API_END(); API_END();
} }
int TVMExtTypeFree(void* handle, int type_code) {
API_BEGIN();
TVMExtTypeCapsule* cap = static_cast<TVMExtTypeCapsule*>(handle);
cap->destroy(cap->ptr);
delete cap;
API_END();
}
int TVMArrayAlloc(const tvm_index_t* shape, int TVMArrayAlloc(const tvm_index_t* shape,
int ndim, int ndim,
int dtype_code, int dtype_code,
......
...@@ -13,6 +13,19 @@ ...@@ -13,6 +13,19 @@
namespace tvm { namespace tvm {
namespace runtime { namespace runtime {
PackedFunc Module::GetFunction(
const std::string& name, bool query_imports) {
PackedFunc pf = node_->GetFunction(name, node_);
if (pf != nullptr) return pf;
if (query_imports) {
for (const Module& m : node_->imports_) {
pf = m.node_->GetFunction(name, m.node_);
if (pf != nullptr) return pf;
}
}
return pf;
}
void Module::Import(Module other) { void Module::Import(Module other) {
// specially handle rpc // specially handle rpc
if (!std::strcmp((*this)->type_key(), "rpc")) { if (!std::strcmp((*this)->type_key(), "rpc")) {
......
...@@ -22,10 +22,15 @@ struct Registry::Manager { ...@@ -22,10 +22,15 @@ struct Registry::Manager {
// and the resource can become invalid because of indeterminstic order of destruction. // and the resource can become invalid because of indeterminstic order of destruction.
// The resources will only be recycled during program exit. // The resources will only be recycled during program exit.
std::unordered_map<std::string, Registry*> fmap; std::unordered_map<std::string, Registry*> fmap;
// vtable for extension type
std::array<ExtTypeVTable, kExtEnd> ext_vtable;
// mutex // mutex
std::mutex mutex; std::mutex mutex;
Manager() { Manager() {
for (auto& x : ext_vtable) {
x.destroy = nullptr;
}
} }
static Manager* Global() { static Manager* Global() {
...@@ -83,6 +88,24 @@ std::vector<std::string> Registry::ListNames() { ...@@ -83,6 +88,24 @@ std::vector<std::string> Registry::ListNames() {
return keys; return keys;
} }
ExtTypeVTable* ExtTypeVTable::Get(int type_code) {
CHECK(type_code > kExtBegin && type_code < kExtEnd);
Registry::Manager* m = Registry::Manager::Global();
ExtTypeVTable* vt = &(m->ext_vtable[type_code]);
CHECK(vt->destroy != nullptr)
<< "Extension type not registered";
return vt;
}
ExtTypeVTable* ExtTypeVTable::RegisterInternal(
int type_code, const ExtTypeVTable& vt) {
CHECK(type_code > kExtBegin && type_code < kExtEnd);
Registry::Manager* m = Registry::Manager::Global();
std::lock_guard<std::mutex>(m->mutex);
ExtTypeVTable* pvt = &(m->ext_vtable[type_code]);
pvt[0] = vt;
return pvt;
}
} // namespace runtime } // namespace runtime
} // namespace tvm } // namespace tvm
...@@ -97,6 +120,12 @@ struct TVMFuncThreadLocalEntry { ...@@ -97,6 +120,12 @@ struct TVMFuncThreadLocalEntry {
/*! \brief Thread local store that can be used to hold return values. */ /*! \brief Thread local store that can be used to hold return values. */
typedef dmlc::ThreadLocalStore<TVMFuncThreadLocalEntry> TVMFuncThreadLocalStore; typedef dmlc::ThreadLocalStore<TVMFuncThreadLocalEntry> TVMFuncThreadLocalStore;
int TVMExtTypeFree(void* handle, int type_code) {
API_BEGIN();
tvm::runtime::ExtTypeVTable::Get(type_code)->destroy(handle);
API_END();
}
int TVMFuncRegisterGlobal( int TVMFuncRegisterGlobal(
const char* name, TVMFunctionHandle f, int override) { const char* name, TVMFunctionHandle f, int override) {
API_BEGIN(); API_BEGIN();
......
...@@ -126,6 +126,9 @@ struct extension_class_info<test::IntVector> { ...@@ -126,6 +126,9 @@ struct extension_class_info<test::IntVector> {
} // runtime } // runtime
} // tvm } // tvm
// do registration, this need to be in cc file
TVM_REGISTER_EXT_TYPE(test::IntVector);
TEST(PackedFunc, ExtensionType) { TEST(PackedFunc, ExtensionType) {
using namespace tvm; using namespace tvm;
using namespace tvm::runtime; using namespace tvm::runtime;
......
...@@ -6,7 +6,6 @@ rm -rf python/tvm/*.pyc python/tvm/*/*.pyc ...@@ -6,7 +6,6 @@ rm -rf python/tvm/*.pyc python/tvm/*/*.pyc
# Test TVM # Test TVM
make cython || exit -1 make cython || exit -1
make cython3 || exit -1
# Test extern package package # Test extern package package
cd apps/extension cd apps/extension
......
...@@ -54,6 +54,8 @@ namespace topi { ...@@ -54,6 +54,8 @@ namespace topi {
using namespace tvm; using namespace tvm;
using namespace tvm::runtime; using namespace tvm::runtime;
TVM_REGISTER_EXT_TYPE(tvm::Target);
/*! \brief Canonicalize an argument that may be Array<Expr> or int to Array<Expr> */ /*! \brief Canonicalize an argument that may be Array<Expr> or int to Array<Expr> */
Array<Expr> ArrayOrInt(TVMArgValue arg) { Array<Expr> ArrayOrInt(TVMArgValue arg) {
if (arg.type_code() == kDLInt || arg.type_code() == kDLUInt) { if (arg.type_code() == kDLInt || arg.type_code() == kDLUInt) {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment