Commit f2ab736b by Tianqi Chen Committed by GitHub

[RUNTIME] Enable extension type to PackedFunc. (#447)

* [RUNTIME] Enable extension type to PackedFunc.

* More comments
parent 3130f2d5
...@@ -16,4 +16,27 @@ _LIB = load_lib() ...@@ -16,4 +16,27 @@ _LIB = load_lib()
# Expose two functions into python # Expose two functions into python
bind_add = tvm.get_global_func("tvm_ext.bind_add") bind_add = tvm.get_global_func("tvm_ext.bind_add")
sym_add = tvm.get_global_func("tvm_ext.sym_add") sym_add = tvm.get_global_func("tvm_ext.sym_add")
ivec_create = tvm.get_global_func("tvm_ext.ivec_create")
ivec_get = tvm.get_global_func("tvm_ext.ivec_get")
class IntVec(object):
"""Example for using extension class in c++ """
_tvm_tcode = 17
def __init__(self, handle):
self.handle = handle
def __del__(self):
# You can also call your own customized
# deleter if you can free it via your own FFI.
tvm.nd.free_extension_handle(self.handle, 17)
@property
def _tvm_handle(self):
return self.handle.value
def __getitem__(self, idx):
return ivec_get(self, idx)
# Register IntVec extension on python side.
tvm.register_extension(IntVec, IntVec)
...@@ -10,9 +10,41 @@ ...@@ -10,9 +10,41 @@
#include <tvm/packed_func_ext.h> #include <tvm/packed_func_ext.h>
namespace tvm_ext { namespace tvm_ext {
using IntVector = std::vector<int>;
} // namespace tvm_ext
namespace tvm {
namespace runtime {
template<>
struct extension_class_info<tvm_ext::IntVector> {
static const int code = 17;
};
} // namespace tvm
} // namespace runtime
namespace tvm_ext {
using namespace tvm; using namespace tvm;
using namespace tvm::runtime; using namespace tvm::runtime;
TVM_REGISTER_EXT_TYPE(IntVector);
TVM_REGISTER_GLOBAL("tvm_ext.ivec_create")
.set_body([](TVMArgs args, TVMRetValue *rv) {
IntVector vec;
for (int i = 0; i < args.size(); ++i) {
vec.push_back(args[i].operator int());
}
*rv = vec;
});
TVM_REGISTER_GLOBAL("tvm_ext.ivec_get")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = args[0].AsExtension<IntVector>()[args[1].operator int()];
});
TVM_REGISTER_GLOBAL("tvm_ext.bind_add") TVM_REGISTER_GLOBAL("tvm_ext.bind_add")
.set_body([](TVMArgs args_, TVMRetValue *rv_) { .set_body([](TVMArgs args_, TVMRetValue *rv_) {
PackedFunc pf = args_[0]; PackedFunc pf = args_[0];
......
...@@ -13,6 +13,19 @@ def test_sym_add(): ...@@ -13,6 +13,19 @@ def test_sym_add():
c = tvm_ext.sym_add(a, b) c = tvm_ext.sym_add(a, b)
assert c.a == a and c.b == b assert c.a == a and c.b == b
def test_ext_vec():
ivec = tvm_ext.ivec_create(1, 2, 3)
assert(isinstance(ivec, tvm_ext.IntVec))
assert ivec[0] == 1
assert ivec[1] == 2
def ivec_cb(v2):
assert(isinstance(v2, tvm_ext.IntVec))
assert v2[2] == 3
tvm.convert(ivec_cb)(ivec)
if __name__ == "__main__": if __name__ == "__main__":
test_ext_vec()
test_bind_add() test_bind_add()
test_sym_add() test_sym_add()
...@@ -89,8 +89,8 @@ inline std::string NodeTypeName() { ...@@ -89,8 +89,8 @@ inline std::string NodeTypeName() {
// extensions for tvm arg value // extensions for tvm arg value
template<typename TNodeRef, typename> template<typename TNodeRef>
inline TVMArgValue::operator TNodeRef() const { inline TNodeRef TVMArgValue::AsNodeRef() const {
static_assert( static_assert(
std::is_base_of<NodeRef, TNodeRef>::value, std::is_base_of<NodeRef, TNodeRef>::value,
"Conversion only works for NodeRef"); "Conversion only works for NodeRef");
...@@ -156,8 +156,8 @@ inline TVMRetValue& TVMRetValue::operator=(const NodeRef& other) { ...@@ -156,8 +156,8 @@ inline TVMRetValue& TVMRetValue::operator=(const NodeRef& other) {
return *this; return *this;
} }
template<typename TNodeRef, typename> template<typename TNodeRef>
inline TVMRetValue::operator TNodeRef() const { inline TNodeRef TVMRetValue::AsNodeRef() const {
static_assert( static_assert(
std::is_base_of<NodeRef, TNodeRef>::value, std::is_base_of<NodeRef, TNodeRef>::value,
"Conversion only works for NodeRef"); "Conversion only works for NodeRef");
...@@ -166,8 +166,8 @@ inline TVMRetValue::operator TNodeRef() const { ...@@ -166,8 +166,8 @@ inline TVMRetValue::operator TNodeRef() const {
return TNodeRef(*ptr<std::shared_ptr<Node> >()); return TNodeRef(*ptr<std::shared_ptr<Node> >());
} }
inline void TVMArgsSetter::operator()(size_t i, NodeRef& other) const { // NOLINT(*) inline void TVMArgsSetter::operator()(size_t i, const NodeRef& other) const { // NOLINT(*)
values_[i].v_handle = &(other.node_); values_[i].v_handle = const_cast<std::shared_ptr<Node>*>(&(other.node_));
type_codes_[i] = kNodeHandle; type_codes_[i] = kNodeHandle;
} }
......
...@@ -75,7 +75,17 @@ typedef enum { ...@@ -75,7 +75,17 @@ typedef enum {
kModuleHandle = 9U, kModuleHandle = 9U,
kFuncHandle = 10U, kFuncHandle = 10U,
kStr = 11U, kStr = 11U,
kBytes = 12U kBytes = 12U,
// Extension codes for other frameworks to integrate TVM PackedFunc.
// To make sure each framework's id do not conflict, use first and
// last sections to mark ranges.
// Open an issue at the repo if you need a section of code.
kExtBegin = 15U,
kNNVMFirst = 16U,
kNNVMLast = 20U,
// The following section of code is used for non-reserved types.
kExtReserveEnd = 64U,
kExtEnd = 128U
} TVMTypeCode; } TVMTypeCode;
/*! /*!
...@@ -192,6 +202,14 @@ TVM_DLL int TVMModGetFunction(TVMModuleHandle mod, ...@@ -192,6 +202,14 @@ TVM_DLL int TVMModGetFunction(TVMModuleHandle mod,
TVMFunctionHandle *out); TVMFunctionHandle *out);
/*! /*!
* \brief Free front-end extension type resource.
* \param handle The extension handle.
* \param type_code The type of of the extension type.
* \return 0 when success, -1 when failure happens
*/
TVM_DLL int TVMExtTypeFree(void* handle, int type_code);
/*!
* \brief Free the Module * \brief Free the Module
* \param mod The module to be freed. * \param mod The module to be freed.
* *
...@@ -200,6 +218,7 @@ TVM_DLL int TVMModGetFunction(TVMModuleHandle mod, ...@@ -200,6 +218,7 @@ TVM_DLL int TVMModGetFunction(TVMModuleHandle mod,
* Or if this module is imported by another active module. * Or if this module is imported by another active module.
* *
* The all functions remains valid until TVMFuncFree is called. * The all functions remains valid until TVMFuncFree is called.
* \return 0 when success, -1 when failure happens
*/ */
TVM_DLL int TVMModFree(TVMModuleHandle mod); TVM_DLL int TVMModFree(TVMModuleHandle mod);
......
...@@ -73,13 +73,14 @@ class Registry { ...@@ -73,13 +73,14 @@ class Registry {
*/ */
static std::vector<std::string> ListNames(); static std::vector<std::string> ListNames();
// Internal class.
struct Manager;
private: private:
/*! \brief name of the function */ /*! \brief name of the function */
std::string name_; std::string name_;
/*! \brief internal packed function */ /*! \brief internal packed function */
PackedFunc func_; PackedFunc func_;
// Internal class.
struct Manager;
friend struct Manager; friend struct Manager;
}; };
...@@ -96,6 +97,9 @@ class Registry { ...@@ -96,6 +97,9 @@ class Registry {
#define TVM_FUNC_REG_VAR_DEF \ #define TVM_FUNC_REG_VAR_DEF \
static TVM_ATTRIBUTE_UNUSED ::tvm::runtime::Registry& __mk_ ## TVM static TVM_ATTRIBUTE_UNUSED ::tvm::runtime::Registry& __mk_ ## TVM
#define TVM_TYPE_REG_VAR_DEF \
static TVM_ATTRIBUTE_UNUSED ::tvm::runtime::ExtTypeVTable* __mk_ ## TVMT
/*! /*!
* \brief Register a function globally. * \brief Register a function globally.
* \code * \code
...@@ -108,6 +112,15 @@ class Registry { ...@@ -108,6 +112,15 @@ class Registry {
TVM_STR_CONCAT(TVM_FUNC_REG_VAR_DEF, __COUNTER__) = \ TVM_STR_CONCAT(TVM_FUNC_REG_VAR_DEF, __COUNTER__) = \
::tvm::runtime::Registry::Register(OpName) ::tvm::runtime::Registry::Register(OpName)
/*!
* \brief Macro to register extension type.
* This must be registered in a cc file
* after the trait extension_class_info is defined.
*/
#define TVM_REGISTER_EXT_TYPE(T) \
TVM_STR_CONCAT(TVM_TYPE_REG_VAR_DEF, __COUNTER__) = \
::tvm::runtime::ExtTypeVTable::Register_<T>()
} // namespace runtime } // namespace runtime
} // namespace tvm } // namespace tvm
#endif // TVM_RUNTIME_REGISTRY_H_ #endif // TVM_RUNTIME_REGISTRY_H_
...@@ -97,7 +97,7 @@ def _make_tvm_args(args, temp_args): ...@@ -97,7 +97,7 @@ def _make_tvm_args(args, temp_args):
type_codes[i] = TypeCode.ARRAY_HANDLE type_codes[i] = TypeCode.ARRAY_HANDLE
elif isinstance(arg, _nd._TVM_COMPATS): elif isinstance(arg, _nd._TVM_COMPATS):
values[i].v_handle = ctypes.c_void_p(arg._tvm_handle) values[i].v_handle = ctypes.c_void_p(arg._tvm_handle)
type_codes[i] = arg._tvm_tcode type_codes[i] = arg.__class__._tvm_tcode
elif isinstance(arg, Integral): elif isinstance(arg, Integral):
values[i].v_int64 = arg values[i].v_int64 = arg
type_codes[i] = TypeCode.INT type_codes[i] = TypeCode.INT
......
...@@ -4,6 +4,7 @@ from __future__ import absolute_import ...@@ -4,6 +4,7 @@ from __future__ import absolute_import
import ctypes import ctypes
from ..base import _LIB, check_call from ..base import _LIB, check_call
from ..runtime_ctypes import TVMArrayHandle from ..runtime_ctypes import TVMArrayHandle
from .types import RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func, _return_handle
class NDArrayBase(object): class NDArrayBase(object):
"""A simple Device/CPU Array object in runtime.""" """A simple Device/CPU Array object in runtime."""
...@@ -35,9 +36,14 @@ def _make_array(handle, is_view): ...@@ -35,9 +36,14 @@ def _make_array(handle, is_view):
_TVM_COMPATS = () _TVM_COMPATS = ()
def _reg_extension(cls): def _reg_extension(cls, fcreate):
global _TVM_COMPATS global _TVM_COMPATS
_TVM_COMPATS += (cls,) _TVM_COMPATS += (cls,)
if fcreate:
fret = lambda x: fcreate(_return_handle(x))
RETURN_SWITCH[cls._tvm_tcode] = fret
C_TO_PY_ARG_SWITCH[cls._tvm_tcode] = _wrap_arg_func(fret, cls._tvm_tcode)
_CLASS_NDARRAY = None _CLASS_NDARRAY = None
......
...@@ -18,6 +18,7 @@ cdef enum TVMTypeCode: ...@@ -18,6 +18,7 @@ cdef enum TVMTypeCode:
kFuncHandle = 10 kFuncHandle = 10
kStr = 11 kStr = 11
kBytes = 12 kBytes = 12
kExtBegin = 15
cdef extern from "tvm/runtime/c_runtime_api.h": cdef extern from "tvm/runtime/c_runtime_api.h":
ctypedef struct DLDataType: ctypedef struct DLDataType:
......
...@@ -27,8 +27,10 @@ cdef int tvm_callback(TVMValue* args, ...@@ -27,8 +27,10 @@ cdef int tvm_callback(TVMValue* args,
tcode = type_codes[i] tcode = type_codes[i]
if (tcode == kNodeHandle or if (tcode == kNodeHandle or
tcode == kFuncHandle or tcode == kFuncHandle or
tcode == kModuleHandle): tcode == kModuleHandle or
tcode > kExtBegin):
CALL(TVMCbArgToReturn(&value, tcode)) CALL(TVMCbArgToReturn(&value, tcode))
if tcode != kArrayHandle: if tcode != kArrayHandle:
pyargs.append(make_ret(value, tcode)) pyargs.append(make_ret(value, tcode))
else: else:
...@@ -87,7 +89,7 @@ cdef inline void make_arg(object arg, ...@@ -87,7 +89,7 @@ cdef inline void make_arg(object arg,
elif isinstance(arg, _TVM_COMPATS): elif isinstance(arg, _TVM_COMPATS):
ptr = arg._tvm_handle ptr = arg._tvm_handle
value[0].v_handle = (<void*>ptr) value[0].v_handle = (<void*>ptr)
tcode[0] = arg._tvm_tcode tcode[0] = arg.__class__._tvm_tcode
elif isinstance(arg, (int, long)): elif isinstance(arg, (int, long)):
value[0].v_int64 = arg value[0].v_int64 = arg
tcode[0] = kInt tcode[0] = kInt
...@@ -185,8 +187,10 @@ cdef inline object make_ret(TVMValue value, int tcode): ...@@ -185,8 +187,10 @@ cdef inline object make_ret(TVMValue value, int tcode):
fobj = _CLASS_FUNCTION(None, False) fobj = _CLASS_FUNCTION(None, False)
(<FunctionBase>fobj).chandle = value.v_handle (<FunctionBase>fobj).chandle = value.v_handle
return fobj return fobj
else: elif tcode in _TVM_EXT_RET:
raise ValueError("Unhandled type code %d" % tcode) return _TVM_EXT_RET[tcode](ctypes_handle(value.v_handle))
raise ValueError("Unhandled type code %d" % tcode)
cdef inline object FuncCall3(void* chandle, tuple args, int nargs): cdef inline object FuncCall3(void* chandle, tuple args, int nargs):
......
...@@ -43,9 +43,14 @@ cdef c_make_array(void* chandle, is_view): ...@@ -43,9 +43,14 @@ cdef c_make_array(void* chandle, is_view):
cdef _TVM_COMPATS = () cdef _TVM_COMPATS = ()
def _reg_extension(cls): cdef _TVM_EXT_RET = {}
def _reg_extension(cls, fcreate):
global _TVM_COMPATS global _TVM_COMPATS
_TVM_COMPATS += (cls,) _TVM_COMPATS += (cls,)
if fcreate:
_TVM_EXT_RET[cls._tvm_tcode] = fcreate
def _make_array(handle, is_view): def _make_array(handle, is_view):
cdef unsigned long long ptr cdef unsigned long long ptr
......
...@@ -6,7 +6,8 @@ import sys ...@@ -6,7 +6,8 @@ import sys
import ctypes import ctypes
import numpy as np import numpy as np
from .base import _LIB, check_call, c_array, string_types, _FFI_MODE from .base import _LIB, check_call, c_array, string_types, _FFI_MODE
from .runtime_ctypes import TVMType, TVMContext, TVMArray, TVMArrayHandle, tvm_shape_index_t from .runtime_ctypes import TVMType, TVMContext, TVMArray, TVMArrayHandle
from .runtime_ctypes import TypeCode, tvm_shape_index_t
IMPORT_EXCEPT = RuntimeError if _FFI_MODE == "cython" else ImportError IMPORT_EXCEPT = RuntimeError if _FFI_MODE == "cython" else ImportError
...@@ -222,9 +223,21 @@ class NDArrayBase(_NDArrayBase): ...@@ -222,9 +223,21 @@ class NDArrayBase(_NDArrayBase):
raise ValueError("Unsupported target type %s" % str(type(target))) raise ValueError("Unsupported target type %s" % str(type(target)))
return target return target
def free_extension_handle(handle, type_code):
"""Free c++ extension type handle
def register_extension(cls): Parameters
"""Register a extensio class to TVM. ----------
handle : ctypes.c_void_p
The handle to the extension type.
type_code : int
The tyoe code
"""
check_call(_LIB.TVMExtTypeFree(handle, ctypes.c_int(type_code)))
def register_extension(cls, fcreate=None):
"""Register a extension class to TVM.
After the class is registered, the class will be able After the class is registered, the class will be able
to directly pass as Function argument generated by TVM. to directly pass as Function argument generated by TVM.
...@@ -236,16 +249,19 @@ def register_extension(cls): ...@@ -236,16 +249,19 @@ def register_extension(cls):
Note Note
---- ----
The registered class is requires two properties: _tvm_handle and _tvm_tcode The registered class is requires one property: _tvm_handle and a class attribute _tvm_tcode.
- ```_tvm_handle``` returns integer represents the address of the handle. - ```_tvm_handle``` returns integer represents the address of the handle.
- ```_tvm_tcode``` returns integer represents type code of the class. - ```_tvm_tcode``` gives integer represents type code of the class.
Returns Returns
------- -------
cls : class cls : class
The class being registered. The class being registered.
fcreate : function, optional
The creation function to create a class object given handle value.
Example Example
------- -------
The following code registers user defined class The following code registers user defined class
...@@ -255,16 +271,16 @@ def register_extension(cls): ...@@ -255,16 +271,16 @@ def register_extension(cls):
@tvm.register_extension @tvm.register_extension
class MyTensor(object): class MyTensor(object):
_tvm_tcode = tvm.TypeCode.ARRAY_HANDLE
def __init__(self): def __init__(self):
self.handle = _LIB.NewDLTensor() self.handle = _LIB.NewDLTensor()
@property @property
def _tvm_handle(self): def _tvm_handle(self):
return self.handle.value return self.handle.value
@property
def _tvm_tcode(self):
return tvm.TypeCode.ARRAY_HANDLE
""" """
_reg_extension(cls) if fcreate and cls._tvm_tcode < TypeCode.EXT_BEGIN:
raise ValueError("Cannot register create when extension tcode is same as buildin")
_reg_extension(cls, fcreate)
return cls return cls
...@@ -24,6 +24,7 @@ class TypeCode(object): ...@@ -24,6 +24,7 @@ class TypeCode(object):
FUNC_HANDLE = 10 FUNC_HANDLE = 10
STR = 11 STR = 11
BYTES = 12 BYTES = 12
EXT_BEGIN = 15
class TVMByteArray(ctypes.Structure): class TVMByteArray(ctypes.Structure):
"""Temp data structure for byte array.""" """Temp data structure for byte array."""
......
...@@ -9,7 +9,8 @@ import numpy as _np ...@@ -9,7 +9,8 @@ import numpy as _np
from ._ffi.ndarray import TVMContext, TVMType, NDArrayBase from ._ffi.ndarray import TVMContext, TVMType, NDArrayBase
from ._ffi.ndarray import context, empty from ._ffi.ndarray import context, empty
from ._ffi.ndarray import _set_class_ndarray, register_extension from ._ffi.ndarray import _set_class_ndarray
from ._ffi.ndarray import register_extension, free_extension_handle
class NDArray(NDArrayBase): class NDArray(NDArrayBase):
"""Lightweight NDArray class of TVM runtime. """Lightweight NDArray class of TVM runtime.
......
...@@ -9,6 +9,7 @@ ...@@ -9,6 +9,7 @@
#include <unordered_map> #include <unordered_map>
#include <mutex> #include <mutex>
#include <memory> #include <memory>
#include <array>
#include "./runtime_base.h" #include "./runtime_base.h"
namespace tvm { namespace tvm {
...@@ -21,8 +22,17 @@ struct Registry::Manager { ...@@ -21,8 +22,17 @@ struct Registry::Manager {
// and the resource can become invalid because of indeterminstic order of destruction. // and the resource can become invalid because of indeterminstic order of destruction.
// The resources will only be recycled during program exit. // The resources will only be recycled during program exit.
std::unordered_map<std::string, Registry*> fmap; std::unordered_map<std::string, Registry*> fmap;
// vtable for extension type
std::array<ExtTypeVTable, kExtEnd> ext_vtable;
// mutex
std::mutex mutex; std::mutex mutex;
Manager() {
for (auto& x : ext_vtable) {
x.destroy = nullptr;
}
}
static Manager* Global() { static Manager* Global() {
static Manager inst; static Manager inst;
return &inst; return &inst;
...@@ -78,6 +88,24 @@ std::vector<std::string> Registry::ListNames() { ...@@ -78,6 +88,24 @@ std::vector<std::string> Registry::ListNames() {
return keys; return keys;
} }
ExtTypeVTable* ExtTypeVTable::Get(int type_code) {
CHECK(type_code > kExtBegin && type_code < kExtEnd);
Registry::Manager* m = Registry::Manager::Global();
ExtTypeVTable* vt = &(m->ext_vtable[type_code]);
CHECK(vt->destroy != nullptr)
<< "Extension type not registered";
return vt;
}
ExtTypeVTable* ExtTypeVTable::RegisterInternal(
int type_code, const ExtTypeVTable& vt) {
CHECK(type_code > kExtBegin && type_code < kExtEnd);
Registry::Manager* m = Registry::Manager::Global();
std::lock_guard<std::mutex>(m->mutex);
ExtTypeVTable* pvt = &(m->ext_vtable[type_code]);
pvt[0] = vt;
return pvt;
}
} // namespace runtime } // namespace runtime
} // namespace tvm } // namespace tvm
...@@ -92,6 +120,11 @@ struct TVMFuncThreadLocalEntry { ...@@ -92,6 +120,11 @@ struct TVMFuncThreadLocalEntry {
/*! \brief Thread local store that can be used to hold return values. */ /*! \brief Thread local store that can be used to hold return values. */
typedef dmlc::ThreadLocalStore<TVMFuncThreadLocalEntry> TVMFuncThreadLocalStore; typedef dmlc::ThreadLocalStore<TVMFuncThreadLocalEntry> TVMFuncThreadLocalStore;
int TVMExtTypeFree(void* handle, int type_code) {
API_BEGIN();
tvm::runtime::ExtTypeVTable::Get(type_code)->destroy(handle);
API_END();
}
int TVMFuncRegisterGlobal( int TVMFuncRegisterGlobal(
const char* name, TVMFunctionHandle f, int override) { const char* name, TVMFunctionHandle f, int override) {
......
...@@ -110,6 +110,55 @@ TEST(PackedFunc, Type) { ...@@ -110,6 +110,55 @@ TEST(PackedFunc, Type) {
CHECK(get_type2("float32x2").operator Type() == Float(32, 2)); CHECK(get_type2("float32x2").operator Type() == Float(32, 2));
} }
// new namespoace
namespace test {
// register int vector as extension type
using IntVector = std::vector<int>;
} // namespace test
namespace tvm {
namespace runtime {
template<>
struct extension_class_info<test::IntVector> {
static const int code = kExtBegin + 1;
};
} // runtime
} // tvm
// do registration, this need to be in cc file
TVM_REGISTER_EXT_TYPE(test::IntVector);
TEST(PackedFunc, ExtensionType) {
using namespace tvm;
using namespace tvm::runtime;
// note: class are copy by value.
test::IntVector vec{1, 2, 4};
auto copy_vec = PackedFunc([&](TVMArgs args, TVMRetValue* rv) {
// copy by value
const test::IntVector& v = args[0].AsExtension<test::IntVector>();
CHECK(&v == &vec);
test::IntVector v2 = args[0];
CHECK_EQ(v2.size(), 3U);
CHECK_EQ(v[2], 4);
// return copy by value
*rv = v2;
});
auto pass_vec = PackedFunc([&](TVMArgs args, TVMRetValue* rv) {
// copy by value
*rv = args[0];
});
test::IntVector vret1 = copy_vec(vec);
test::IntVector vret2 = pass_vec(copy_vec(vec));
CHECK_EQ(vret1.size(), 3U);
CHECK_EQ(vret2.size(), 3U);
CHECK_EQ(vret1[2], 4);
CHECK_EQ(vret2[2], 4);
}
int main(int argc, char ** argv) { int main(int argc, char ** argv) {
testing::InitGoogleTest(&argc, argv); testing::InitGoogleTest(&argc, argv);
......
...@@ -3,6 +3,7 @@ import numpy as np ...@@ -3,6 +3,7 @@ import numpy as np
@tvm.register_extension @tvm.register_extension
class MyTensorView(object): class MyTensorView(object):
_tvm_tcode = tvm.TypeCode.ARRAY_HANDLE
def __init__(self, arr): def __init__(self, arr):
self.arr = arr self.arr = arr
...@@ -10,10 +11,6 @@ class MyTensorView(object): ...@@ -10,10 +11,6 @@ class MyTensorView(object):
def _tvm_handle(self): def _tvm_handle(self):
return self.arr._tvm_handle return self.arr._tvm_handle
@property
def _tvm_tcode(self):
return tvm.TypeCode.ARRAY_HANDLE
def test_dltensor_compatible(): def test_dltensor_compatible():
dtype = 'int64' dtype = 'int64'
n = tvm.var('n') n = tvm.var('n')
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment