Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
f2ab736b
Commit
f2ab736b
authored
Sep 11, 2017
by
Tianqi Chen
Committed by
GitHub
Sep 11, 2017
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[RUNTIME] Enable extension type to PackedFunc. (#447)
* [RUNTIME] Enable extension type to PackedFunc. * More comments
parent
3130f2d5
Hide whitespace changes
Inline
Side-by-side
Showing
18 changed files
with
419 additions
and
78 deletions
+419
-78
apps/extension/python/tvm_ext/__init__.py
+23
-0
apps/extension/src/tvm_ext.cc
+32
-0
apps/extension/tests/test_ext.py
+13
-0
include/tvm/packed_func_ext.h
+6
-6
include/tvm/runtime/c_runtime_api.h
+20
-1
include/tvm/runtime/packed_func.h
+175
-47
include/tvm/runtime/registry.h
+15
-2
python/tvm/_ffi/_ctypes/function.py
+1
-1
python/tvm/_ffi/_ctypes/ndarray.py
+7
-1
python/tvm/_ffi/_cython/base.pxi
+1
-0
python/tvm/_ffi/_cython/function.pxi
+8
-4
python/tvm/_ffi/_cython/ndarray.pxi
+6
-1
python/tvm/_ffi/ndarray.py
+26
-10
python/tvm/_ffi/runtime_ctypes.py
+1
-0
python/tvm/ndarray.py
+2
-1
src/runtime/registry.cc
+33
-0
tests/cpp/packed_func_test.cc
+49
-0
tests/python/unittest/test_runtime_extension.py
+1
-4
No files found.
apps/extension/python/tvm_ext/__init__.py
View file @
f2ab736b
...
...
@@ -16,4 +16,27 @@ _LIB = load_lib()
# Expose two functions into python
bind_add
=
tvm
.
get_global_func
(
"tvm_ext.bind_add"
)
sym_add
=
tvm
.
get_global_func
(
"tvm_ext.sym_add"
)
ivec_create
=
tvm
.
get_global_func
(
"tvm_ext.ivec_create"
)
ivec_get
=
tvm
.
get_global_func
(
"tvm_ext.ivec_get"
)
class
IntVec
(
object
):
"""Example for using extension class in c++ """
_tvm_tcode
=
17
def
__init__
(
self
,
handle
):
self
.
handle
=
handle
def
__del__
(
self
):
# You can also call your own customized
# deleter if you can free it via your own FFI.
tvm
.
nd
.
free_extension_handle
(
self
.
handle
,
17
)
@property
def
_tvm_handle
(
self
):
return
self
.
handle
.
value
def
__getitem__
(
self
,
idx
):
return
ivec_get
(
self
,
idx
)
# Register IntVec extension on python side.
tvm
.
register_extension
(
IntVec
,
IntVec
)
apps/extension/src/tvm_ext.cc
View file @
f2ab736b
...
...
@@ -10,9 +10,41 @@
#include <tvm/packed_func_ext.h>
namespace
tvm_ext
{
using
IntVector
=
std
::
vector
<
int
>
;
}
// namespace tvm_ext
namespace
tvm
{
namespace
runtime
{
template
<>
struct
extension_class_info
<
tvm_ext
::
IntVector
>
{
static
const
int
code
=
17
;
};
}
// namespace tvm
}
// namespace runtime
namespace
tvm_ext
{
using
namespace
tvm
;
using
namespace
tvm
::
runtime
;
TVM_REGISTER_EXT_TYPE
(
IntVector
);
TVM_REGISTER_GLOBAL
(
"tvm_ext.ivec_create"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
rv
)
{
IntVector
vec
;
for
(
int
i
=
0
;
i
<
args
.
size
();
++
i
)
{
vec
.
push_back
(
args
[
i
].
operator
int
());
}
*
rv
=
vec
;
});
TVM_REGISTER_GLOBAL
(
"tvm_ext.ivec_get"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
rv
)
{
*
rv
=
args
[
0
].
AsExtension
<
IntVector
>
()[
args
[
1
].
operator
int
()];
});
TVM_REGISTER_GLOBAL
(
"tvm_ext.bind_add"
)
.
set_body
([](
TVMArgs
args_
,
TVMRetValue
*
rv_
)
{
PackedFunc
pf
=
args_
[
0
];
...
...
apps/extension/tests/test_ext.py
View file @
f2ab736b
...
...
@@ -13,6 +13,19 @@ def test_sym_add():
c
=
tvm_ext
.
sym_add
(
a
,
b
)
assert
c
.
a
==
a
and
c
.
b
==
b
def
test_ext_vec
():
ivec
=
tvm_ext
.
ivec_create
(
1
,
2
,
3
)
assert
(
isinstance
(
ivec
,
tvm_ext
.
IntVec
))
assert
ivec
[
0
]
==
1
assert
ivec
[
1
]
==
2
def
ivec_cb
(
v2
):
assert
(
isinstance
(
v2
,
tvm_ext
.
IntVec
))
assert
v2
[
2
]
==
3
tvm
.
convert
(
ivec_cb
)(
ivec
)
if
__name__
==
"__main__"
:
test_ext_vec
()
test_bind_add
()
test_sym_add
()
include/tvm/packed_func_ext.h
View file @
f2ab736b
...
...
@@ -89,8 +89,8 @@ inline std::string NodeTypeName() {
// extensions for tvm arg value
template
<
typename
TNodeRef
,
typename
>
inline
T
VMArgValue
::
operator
T
NodeRef
()
const
{
template
<
typename
TNodeRef
>
inline
T
NodeRef
TVMArgValue
::
As
NodeRef
()
const
{
static_assert
(
std
::
is_base_of
<
NodeRef
,
TNodeRef
>::
value
,
"Conversion only works for NodeRef"
);
...
...
@@ -156,8 +156,8 @@ inline TVMRetValue& TVMRetValue::operator=(const NodeRef& other) {
return
*
this
;
}
template
<
typename
TNodeRef
,
typename
>
inline
T
VMRetValue
::
operator
T
NodeRef
()
const
{
template
<
typename
TNodeRef
>
inline
T
NodeRef
TVMRetValue
::
As
NodeRef
()
const
{
static_assert
(
std
::
is_base_of
<
NodeRef
,
TNodeRef
>::
value
,
"Conversion only works for NodeRef"
);
...
...
@@ -166,8 +166,8 @@ inline TVMRetValue::operator TNodeRef() const {
return
TNodeRef
(
*
ptr
<
std
::
shared_ptr
<
Node
>
>
());
}
inline
void
TVMArgsSetter
::
operator
()(
size_t
i
,
NodeRef
&
other
)
const
{
// NOLINT(*)
values_
[
i
].
v_handle
=
&
(
other
.
node_
);
inline
void
TVMArgsSetter
::
operator
()(
size_t
i
,
const
NodeRef
&
other
)
const
{
// NOLINT(*)
values_
[
i
].
v_handle
=
const_cast
<
std
::
shared_ptr
<
Node
>*>
(
&
(
other
.
node_
)
);
type_codes_
[
i
]
=
kNodeHandle
;
}
...
...
include/tvm/runtime/c_runtime_api.h
View file @
f2ab736b
...
...
@@ -75,7 +75,17 @@ typedef enum {
kModuleHandle
=
9U
,
kFuncHandle
=
10U
,
kStr
=
11U
,
kBytes
=
12U
kBytes
=
12U
,
// Extension codes for other frameworks to integrate TVM PackedFunc.
// To make sure each framework's id do not conflict, use first and
// last sections to mark ranges.
// Open an issue at the repo if you need a section of code.
kExtBegin
=
15U
,
kNNVMFirst
=
16U
,
kNNVMLast
=
20U
,
// The following section of code is used for non-reserved types.
kExtReserveEnd
=
64U
,
kExtEnd
=
128U
}
TVMTypeCode
;
/*!
...
...
@@ -192,6 +202,14 @@ TVM_DLL int TVMModGetFunction(TVMModuleHandle mod,
TVMFunctionHandle
*
out
);
/*!
* \brief Free front-end extension type resource.
* \param handle The extension handle.
* \param type_code The type of of the extension type.
* \return 0 when success, -1 when failure happens
*/
TVM_DLL
int
TVMExtTypeFree
(
void
*
handle
,
int
type_code
);
/*!
* \brief Free the Module
* \param mod The module to be freed.
*
...
...
@@ -200,6 +218,7 @@ TVM_DLL int TVMModGetFunction(TVMModuleHandle mod,
* Or if this module is imported by another active module.
*
* The all functions remains valid until TVMFuncFree is called.
* \return 0 when success, -1 when failure happens
*/
TVM_DLL
int
TVMModFree
(
TVMModuleHandle
mod
);
...
...
include/tvm/runtime/packed_func.h
View file @
f2ab736b
...
...
@@ -167,6 +167,50 @@ inline std::string TVMType2String(TVMType t);
<< TypeCode2Str(T) << " but get " << TypeCode2Str(CODE) \
/*!
* \brief Type traits to mark if a class is tvm extension type.
*
* To enable extension type in C++ must be register () ed via marco.
* TVM_REGISTER_EXT_TYPE(TypeName) after defining this with this traits.
*
* Extension class can be passed and returned via PackedFunc in all tvm runtime.
* Internally extension class is stored as T*.
*
* \tparam T the typename
*/
template
<
typename
T
>
struct
extension_class_info
{
static
const
int
code
=
0
;
};
/*!
* \brief Runtime function table about extension type.
*/
class
ExtTypeVTable
{
public
:
/*! \brief function to be called to delete a handle */
void
(
*
destroy
)(
void
*
handle
);
/*! \brief function to be called when clone a handle */
void
*
(
*
clone
)(
void
*
handle
);
/*!
* \brief Register type
* \tparam T The type to be register.
* \return The registered vtable.
*/
template
<
typename
T
>
static
inline
ExtTypeVTable
*
Register_
();
/*!
* \brief Get a vtable based on type code.
* \param type_code The type code
* \return The registered vtable.
*/
static
ExtTypeVTable
*
Get
(
int
type_code
);
private
:
// Internal registration function.
static
ExtTypeVTable
*
RegisterInternal
(
int
type_code
,
const
ExtTypeVTable
&
vt
);
};
/*!
* \brief Internal base class to
* handle conversion to POD values.
*/
...
...
@@ -209,6 +253,11 @@ class TVMPODValue_ {
TVM_CHECK_TYPE_CODE
(
type_code_
,
kTVMContext
);
return
value_
.
v_ctx
;
}
template
<
typename
TExtension
>
const
TExtension
&
AsExtension
()
const
{
CHECK_LT
(
type_code_
,
kExtEnd
);
return
static_cast
<
TExtension
*>
(
value_
.
v_handle
)[
0
];
}
int
type_code
()
const
{
return
type_code_
;
}
...
...
@@ -291,11 +340,13 @@ class TVMArgValue : public TVMPODValue_ {
const
TVMValue
&
value
()
const
{
return
value_
;
}
// NodeRef related extenstions: in tvm/packed_func_ext.h
template
<
typename
TNodeRef
,
// Deferred extension handler.
template
<
typename
TNodeRef
>
inline
TNodeRef
AsNodeRef
()
const
;
template
<
typename
T
,
typename
=
typename
std
::
enable_if
<
std
::
is_class
<
T
NodeRef
>::
value
>::
type
>
inline
operator
T
NodeRef
()
const
;
std
::
is_class
<
T
>::
value
>::
type
>
inline
operator
T
()
const
;
template
<
typename
TNodeRef
,
typename
=
typename
std
::
enable_if
<
std
::
is_class
<
TNodeRef
>::
value
>::
type
>
...
...
@@ -433,10 +484,18 @@ class TVMRetValue : public TVMPODValue_ {
this
->
Assign
(
other
);
return
*
this
;
}
TVMRetValue
&
operator
=
(
TVMArgValue
other
)
{
TVMRetValue
&
operator
=
(
const
TVMArgValue
&
other
)
{
this
->
Assign
(
other
);
return
*
this
;
}
template
<
typename
T
,
typename
=
typename
std
::
enable_if
<
extension_class_info
<
T
>::
code
!=
0
>::
type
>
TVMRetValue
&
operator
=
(
const
T
&
other
)
{
this
->
SwitchToClass
<
T
>
(
extension_class_info
<
T
>::
code
,
other
);
return
*
this
;
}
/*!
* \brief Move the value back to front-end via C API.
* This marks the current container as null.
...
...
@@ -463,12 +522,14 @@ class TVMRetValue : public TVMPODValue_ {
return
value_
;
}
// NodeRef related extenstions: in tvm/packed_func_ext.h
template
<
typename
T
,
typename
=
typename
std
::
enable_if
<
std
::
is_class
<
T
>::
value
>::
type
>
inline
operator
T
()
const
;
template
<
typename
TNodeRef
>
inline
TNodeRef
AsNodeRef
()
const
;
inline
TVMRetValue
&
operator
=
(
const
NodeRef
&
other
);
inline
TVMRetValue
&
operator
=
(
const
std
::
shared_ptr
<
Node
>&
other
);
template
<
typename
TNodeRef
,
typename
=
typename
std
::
enable_if
<
std
::
is_class
<
TNodeRef
>::
value
>::
type
>
inline
operator
TNodeRef
()
const
;
// type related
inline
operator
Halide
::
Type
()
const
;
inline
TVMRetValue
&
operator
=
(
const
Halide
::
Type
&
other
);
...
...
@@ -499,13 +560,20 @@ class TVMRetValue : public TVMPODValue_ {
break
;
}
default
:
{
SwitchToPOD
(
other
.
type_code
());
value_
=
other
.
value_
;
if
(
other
.
type_code
()
<
kExtBegin
)
{
SwitchToPOD
(
other
.
type_code
());
value_
=
other
.
value_
;
}
else
{
this
->
Clear
();
type_code_
=
other
.
type_code
();
value_
.
v_handle
=
(
*
(
ExtTypeVTable
::
Get
(
other
.
type_code
())
->
clone
))(
other
.
value
().
v_handle
);
}
break
;
}
}
}
// get the internal container.
void
SwitchToPOD
(
int
type_code
)
{
if
(
type_code_
!=
type_code
)
{
...
...
@@ -531,6 +599,9 @@ class TVMRetValue : public TVMPODValue_ {
case
kModuleHandle
:
delete
ptr
<
Module
>
();
break
;
case
kNodeHandle
:
delete
ptr
<
std
::
shared_ptr
<
Node
>
>
();
break
;
}
if
(
type_code_
>
kExtBegin
)
{
(
*
(
ExtTypeVTable
::
Get
(
type_code_
)
->
destroy
))(
value_
.
v_handle
);
}
type_code_
=
kNull
;
}
};
...
...
@@ -619,24 +690,28 @@ inline PackedFunc::FType PackedFunc::body() const {
// internal namespace
namespace
detail
{
template
<
bool
stop
,
std
::
size_t
I
,
typename
F
,
typename
...
Args
>
template
<
bool
stop
,
std
::
size_t
I
,
typename
F
>
struct
for_each_dispatcher
{
static
void
run
(
std
::
tuple
<
Args
...
>&
args
,
const
F
&
f
)
{
// NOLINT(*)
f
(
I
,
std
::
get
<
I
>
(
args
));
for_each_dispatcher
<
(
I
+
1
)
==
sizeof
...(
Args
),
(
I
+
1
),
F
,
Args
...
>::
run
(
args
,
f
);
template
<
typename
T
,
typename
...
Args
>
static
void
run
(
const
F
&
f
,
T
&&
value
,
Args
&&
...
args
)
{
// NOLINT(*)
f
(
I
,
std
::
forward
<
T
>
(
value
));
for_each_dispatcher
<
sizeof
...(
Args
)
==
0
,
(
I
+
1
),
F
>
::
run
(
f
,
std
::
forward
<
Args
>
(
args
)...);
}
};
template
<
std
::
size_t
I
,
typename
F
,
typename
...
Args
>
struct
for_each_dispatcher
<
true
,
I
,
F
,
Args
...
>
{
static
void
run
(
std
::
tuple
<
Args
...
>&
args
,
const
F
&
f
)
{}
// NOLINT(*)
template
<
std
::
size_t
I
,
typename
F
>
struct
for_each_dispatcher
<
true
,
I
,
F
>
{
static
void
run
(
const
F
&
f
)
{}
// NOLINT(*)
};
}
// namespace detail
template
<
typename
F
,
typename
...
Args
>
inline
void
for_each
(
std
::
tuple
<
Args
...
>&
args
,
const
F
&
f
)
{
// NOLINT(*)
detail
::
for_each_dispatcher
<
sizeof
...(
Args
)
==
0
,
0
,
F
,
Args
...
>::
run
(
args
,
f
);
inline
void
for_each
(
const
F
&
f
,
Args
&&
...
args
)
{
// NOLINT(*)
for_each_dispatcher
<
sizeof
...(
Args
)
==
0
,
0
,
F
>
::
run
(
f
,
std
::
forward
<
Args
>
(
args
)...);
}
}
// namespace detail
/* \brief argument settter to PackedFunc */
class
TVMArgsSetter
{
...
...
@@ -645,7 +720,8 @@ class TVMArgsSetter {
:
values_
(
values
),
type_codes_
(
type_codes
)
{}
// setters for POD types
template
<
typename
T
,
typename
=
typename
std
::
enable_if
<
std
::
is_integral
<
T
>::
value
>::
type
>
typename
=
typename
std
::
enable_if
<
std
::
is_integral
<
T
>::
value
>::
type
>
void
operator
()(
size_t
i
,
T
value
)
const
{
values_
[
i
].
v_int64
=
static_cast
<
int64_t
>
(
value
);
type_codes_
[
i
]
=
kInt
;
...
...
@@ -691,23 +767,23 @@ class TVMArgsSetter {
// setters for container type
// They must be reference(instead of const ref)
// to make sure they are alive in the tuple(instead of getting converted)
void
operator
()(
size_t
i
,
std
::
string
&
value
)
const
{
// NOLINT(*)
void
operator
()(
size_t
i
,
const
std
::
string
&
value
)
const
{
// NOLINT(*)
values_
[
i
].
v_str
=
value
.
c_str
();
type_codes_
[
i
]
=
kStr
;
}
void
operator
()(
size_t
i
,
TVMByteArray
&
value
)
const
{
// NOLINT(*)
values_
[
i
].
v_handle
=
&
value
;
void
operator
()(
size_t
i
,
const
TVMByteArray
&
value
)
const
{
// NOLINT(*)
values_
[
i
].
v_handle
=
const_cast
<
TVMByteArray
*>
(
&
value
)
;
type_codes_
[
i
]
=
kBytes
;
}
void
operator
()(
size_t
i
,
PackedFunc
&
value
)
const
{
// NOLINT(*)
values_
[
i
].
v_handle
=
&
value
;
void
operator
()(
size_t
i
,
const
PackedFunc
&
value
)
const
{
// NOLINT(*)
values_
[
i
].
v_handle
=
const_cast
<
PackedFunc
*>
(
&
value
)
;
type_codes_
[
i
]
=
kFuncHandle
;
}
void
operator
()(
size_t
i
,
Module
&
value
)
const
{
// NOLINT(*)
values_
[
i
].
v_handle
=
&
value
;
void
operator
()(
size_t
i
,
const
Module
&
value
)
const
{
// NOLINT(*)
values_
[
i
].
v_handle
=
const_cast
<
Module
*>
(
&
value
)
;
type_codes_
[
i
]
=
kModuleHandle
;
}
void
operator
()(
size_t
i
,
TVMRetValue
&
value
)
const
{
// NOLINT(*)
void
operator
()(
size_t
i
,
const
TVMRetValue
&
value
)
const
{
// NOLINT(*)
if
(
value
.
type_code
()
==
kStr
)
{
values_
[
i
].
v_str
=
value
.
ptr
<
std
::
string
>
()
->
c_str
();
type_codes_
[
i
]
=
kStr
;
...
...
@@ -717,8 +793,13 @@ class TVMArgsSetter {
type_codes_
[
i
]
=
value
.
type_code
();
}
}
// extension
template
<
typename
T
,
typename
=
typename
std
::
enable_if
<
extension_class_info
<
T
>::
code
!=
0
>::
type
>
inline
void
operator
()(
size_t
i
,
const
T
&
value
)
const
;
// NodeRef related extenstions: in tvm/packed_func_ext.h
inline
void
operator
()(
size_t
i
,
NodeRef
&
other
)
const
;
// NOLINT(*)
inline
void
operator
()(
size_t
i
,
const
NodeRef
&
other
)
const
;
// NOLINT(*)
inline
void
operator
()(
size_t
i
,
const
Halide
::
Type
&
t
)
const
;
private
:
...
...
@@ -728,32 +809,79 @@ class TVMArgsSetter {
int
*
type_codes_
;
};
class
TVMArgsGetter
{
public
:
explicit
TVMArgsGetter
(
TVMArgs
args
)
:
args_
(
args
)
{}
template
<
typename
T
>
inline
void
operator
()(
size_t
i
,
T
&
target
)
const
{
// NOLINT(*)
target
=
args_
[
i
].
operator
T
();
}
private
:
TVMArgs
args_
;
};
template
<
typename
...
Args
>
inline
TVMRetValue
PackedFunc
::
operator
()(
Args
&&
...
args
)
const
{
auto
targs
=
std
::
make_tuple
(
std
::
forward
<
Args
>
(
args
)...);
const
int
kNumArgs
=
sizeof
...(
Args
);
const
int
kArraySize
=
kNumArgs
>
0
?
kNumArgs
:
1
;
TVMValue
values
[
kArraySize
];
int
type_codes
[
kArraySize
];
for_each
(
targs
,
TVMArgsSetter
(
values
,
type_codes
));
detail
::
for_each
(
TVMArgsSetter
(
values
,
type_codes
),
std
::
forward
<
Args
>
(
args
)...);
TVMRetValue
rv
;
body_
(
TVMArgs
(
values
,
type_codes
,
kNumArgs
),
&
rv
);
return
rv
;
}
// extension and node type handling
namespace
detail
{
template
<
typename
T
,
typename
TSrc
,
bool
is_ext
>
struct
TVMValueCast
{
static
T
Apply
(
const
TSrc
*
self
)
{
return
self
->
template
AsNodeRef
<
T
>
();
}
};
template
<
typename
T
,
typename
TSrc
>
struct
TVMValueCast
<
T
,
TSrc
,
true
>
{
static
T
Apply
(
const
TSrc
*
self
)
{
return
self
->
template
AsExtension
<
T
>
();
}
};
}
// namespace detail
template
<
typename
T
,
typename
>
inline
TVMArgValue
::
operator
T
()
const
{
return
detail
::
TVMValueCast
<
T
,
TVMArgValue
,
extension_class_info
<
T
>::
code
!=
0
>
::
Apply
(
this
);
}
template
<
typename
T
,
typename
>
inline
TVMRetValue
::
operator
T
()
const
{
return
detail
::
TVMValueCast
<
T
,
TVMRetValue
,
extension_class_info
<
T
>::
code
!=
0
>
::
Apply
(
this
);
}
template
<
typename
T
,
typename
>
inline
void
TVMArgsSetter
::
operator
()(
size_t
i
,
const
T
&
value
)
const
{
static_assert
(
extension_class_info
<
T
>::
code
!=
0
,
"Need to have extesion code"
);
type_codes_
[
i
]
=
extension_class_info
<
T
>::
code
;
values_
[
i
].
v_handle
=
const_cast
<
T
*>
(
&
value
);
}
// extension type handling
template
<
typename
T
>
struct
ExtTypeInfo
{
static
void
destroy
(
void
*
handle
)
{
delete
static_cast
<
T
*>
(
handle
);
}
static
void
*
clone
(
void
*
handle
)
{
return
new
T
(
*
static_cast
<
T
*>
(
handle
));
}
};
template
<
typename
T
>
inline
ExtTypeVTable
*
ExtTypeVTable
::
Register_
()
{
const
int
code
=
extension_class_info
<
T
>::
code
;
static_assert
(
code
!=
0
,
"require extension_class_info traits to be declared with non-zero code"
);
ExtTypeVTable
vt
;
vt
.
clone
=
ExtTypeInfo
<
T
>::
clone
;
vt
.
destroy
=
ExtTypeInfo
<
T
>::
destroy
;
return
ExtTypeVTable
::
RegisterInternal
(
code
,
vt
);
}
}
// namespace runtime
}
// namespace tvm
#endif // TVM_RUNTIME_PACKED_FUNC_H_
include/tvm/runtime/registry.h
View file @
f2ab736b
...
...
@@ -73,13 +73,14 @@ class Registry {
*/
static
std
::
vector
<
std
::
string
>
ListNames
();
// Internal class.
struct
Manager
;
private
:
/*! \brief name of the function */
std
::
string
name_
;
/*! \brief internal packed function */
PackedFunc
func_
;
// Internal class.
struct
Manager
;
friend
struct
Manager
;
};
...
...
@@ -96,6 +97,9 @@ class Registry {
#define TVM_FUNC_REG_VAR_DEF \
static TVM_ATTRIBUTE_UNUSED ::tvm::runtime::Registry& __mk_ ## TVM
#define TVM_TYPE_REG_VAR_DEF \
static TVM_ATTRIBUTE_UNUSED ::tvm::runtime::ExtTypeVTable* __mk_ ## TVMT
/*!
* \brief Register a function globally.
* \code
...
...
@@ -108,6 +112,15 @@ class Registry {
TVM_STR_CONCAT(TVM_FUNC_REG_VAR_DEF, __COUNTER__) = \
::tvm::runtime::Registry::Register(OpName)
/*!
* \brief Macro to register extension type.
* This must be registered in a cc file
* after the trait extension_class_info is defined.
*/
#define TVM_REGISTER_EXT_TYPE(T) \
TVM_STR_CONCAT(TVM_TYPE_REG_VAR_DEF, __COUNTER__) = \
::tvm::runtime::ExtTypeVTable::Register_<T>()
}
// namespace runtime
}
// namespace tvm
#endif // TVM_RUNTIME_REGISTRY_H_
python/tvm/_ffi/_ctypes/function.py
View file @
f2ab736b
...
...
@@ -97,7 +97,7 @@ def _make_tvm_args(args, temp_args):
type_codes
[
i
]
=
TypeCode
.
ARRAY_HANDLE
elif
isinstance
(
arg
,
_nd
.
_TVM_COMPATS
):
values
[
i
]
.
v_handle
=
ctypes
.
c_void_p
(
arg
.
_tvm_handle
)
type_codes
[
i
]
=
arg
.
_tvm_tcode
type_codes
[
i
]
=
arg
.
_
_class__
.
_
tvm_tcode
elif
isinstance
(
arg
,
Integral
):
values
[
i
]
.
v_int64
=
arg
type_codes
[
i
]
=
TypeCode
.
INT
...
...
python/tvm/_ffi/_ctypes/ndarray.py
View file @
f2ab736b
...
...
@@ -4,6 +4,7 @@ from __future__ import absolute_import
import
ctypes
from
..base
import
_LIB
,
check_call
from
..runtime_ctypes
import
TVMArrayHandle
from
.types
import
RETURN_SWITCH
,
C_TO_PY_ARG_SWITCH
,
_wrap_arg_func
,
_return_handle
class
NDArrayBase
(
object
):
"""A simple Device/CPU Array object in runtime."""
...
...
@@ -35,9 +36,14 @@ def _make_array(handle, is_view):
_TVM_COMPATS
=
()
def
_reg_extension
(
cls
):
def
_reg_extension
(
cls
,
fcreate
):
global
_TVM_COMPATS
_TVM_COMPATS
+=
(
cls
,)
if
fcreate
:
fret
=
lambda
x
:
fcreate
(
_return_handle
(
x
))
RETURN_SWITCH
[
cls
.
_tvm_tcode
]
=
fret
C_TO_PY_ARG_SWITCH
[
cls
.
_tvm_tcode
]
=
_wrap_arg_func
(
fret
,
cls
.
_tvm_tcode
)
_CLASS_NDARRAY
=
None
...
...
python/tvm/_ffi/_cython/base.pxi
View file @
f2ab736b
...
...
@@ -18,6 +18,7 @@ cdef enum TVMTypeCode:
kFuncHandle = 10
kStr = 11
kBytes = 12
kExtBegin = 15
cdef extern from "tvm/runtime/c_runtime_api.h":
ctypedef struct DLDataType:
...
...
python/tvm/_ffi/_cython/function.pxi
View file @
f2ab736b
...
...
@@ -27,8 +27,10 @@ cdef int tvm_callback(TVMValue* args,
tcode = type_codes[i]
if (tcode == kNodeHandle or
tcode == kFuncHandle or
tcode == kModuleHandle):
tcode == kModuleHandle or
tcode > kExtBegin):
CALL(TVMCbArgToReturn(&value, tcode))
if tcode != kArrayHandle:
pyargs.append(make_ret(value, tcode))
else:
...
...
@@ -87,7 +89,7 @@ cdef inline void make_arg(object arg,
elif isinstance(arg, _TVM_COMPATS):
ptr = arg._tvm_handle
value[0].v_handle = (<void*>ptr)
tcode[0] = arg._tvm_tcode
tcode[0] = arg._
_class__._
tvm_tcode
elif isinstance(arg, (int, long)):
value[0].v_int64 = arg
tcode[0] = kInt
...
...
@@ -185,8 +187,10 @@ cdef inline object make_ret(TVMValue value, int tcode):
fobj = _CLASS_FUNCTION(None, False)
(<FunctionBase>fobj).chandle = value.v_handle
return fobj
else:
raise ValueError("Unhandled type code %d" % tcode)
elif tcode in _TVM_EXT_RET:
return _TVM_EXT_RET[tcode](ctypes_handle(value.v_handle))
raise ValueError("Unhandled type code %d" % tcode)
cdef inline object FuncCall3(void* chandle, tuple args, int nargs):
...
...
python/tvm/_ffi/_cython/ndarray.pxi
View file @
f2ab736b
...
...
@@ -43,9 +43,14 @@ cdef c_make_array(void* chandle, is_view):
cdef _TVM_COMPATS = ()
def _reg_extension(cls):
cdef _TVM_EXT_RET = {}
def _reg_extension(cls, fcreate):
global _TVM_COMPATS
_TVM_COMPATS += (cls,)
if fcreate:
_TVM_EXT_RET[cls._tvm_tcode] = fcreate
def _make_array(handle, is_view):
cdef unsigned long long ptr
...
...
python/tvm/_ffi/ndarray.py
View file @
f2ab736b
...
...
@@ -6,7 +6,8 @@ import sys
import
ctypes
import
numpy
as
np
from
.base
import
_LIB
,
check_call
,
c_array
,
string_types
,
_FFI_MODE
from
.runtime_ctypes
import
TVMType
,
TVMContext
,
TVMArray
,
TVMArrayHandle
,
tvm_shape_index_t
from
.runtime_ctypes
import
TVMType
,
TVMContext
,
TVMArray
,
TVMArrayHandle
from
.runtime_ctypes
import
TypeCode
,
tvm_shape_index_t
IMPORT_EXCEPT
=
RuntimeError
if
_FFI_MODE
==
"cython"
else
ImportError
...
...
@@ -222,9 +223,21 @@ class NDArrayBase(_NDArrayBase):
raise
ValueError
(
"Unsupported target type
%
s"
%
str
(
type
(
target
)))
return
target
def
free_extension_handle
(
handle
,
type_code
):
"""Free c++ extension type handle
def
register_extension
(
cls
):
"""Register a extensio class to TVM.
Parameters
----------
handle : ctypes.c_void_p
The handle to the extension type.
type_code : int
The tyoe code
"""
check_call
(
_LIB
.
TVMExtTypeFree
(
handle
,
ctypes
.
c_int
(
type_code
)))
def
register_extension
(
cls
,
fcreate
=
None
):
"""Register a extension class to TVM.
After the class is registered, the class will be able
to directly pass as Function argument generated by TVM.
...
...
@@ -236,16 +249,19 @@ def register_extension(cls):
Note
----
The registered class is requires
two properties: _tvm_handle and _tvm_tcode
The registered class is requires
one property: _tvm_handle and a class attribute _tvm_tcode.
- ```_tvm_handle``` returns integer represents the address of the handle.
- ```_tvm_tcode```
return
s integer represents type code of the class.
- ```_tvm_tcode```
give
s integer represents type code of the class.
Returns
-------
cls : class
The class being registered.
fcreate : function, optional
The creation function to create a class object given handle value.
Example
-------
The following code registers user defined class
...
...
@@ -255,16 +271,16 @@ def register_extension(cls):
@tvm.register_extension
class MyTensor(object):
_tvm_tcode = tvm.TypeCode.ARRAY_HANDLE
def __init__(self):
self.handle = _LIB.NewDLTensor()
@property
def _tvm_handle(self):
return self.handle.value
@property
def _tvm_tcode(self):
return tvm.TypeCode.ARRAY_HANDLE
"""
_reg_extension
(
cls
)
if
fcreate
and
cls
.
_tvm_tcode
<
TypeCode
.
EXT_BEGIN
:
raise
ValueError
(
"Cannot register create when extension tcode is same as buildin"
)
_reg_extension
(
cls
,
fcreate
)
return
cls
python/tvm/_ffi/runtime_ctypes.py
View file @
f2ab736b
...
...
@@ -24,6 +24,7 @@ class TypeCode(object):
FUNC_HANDLE
=
10
STR
=
11
BYTES
=
12
EXT_BEGIN
=
15
class
TVMByteArray
(
ctypes
.
Structure
):
"""Temp data structure for byte array."""
...
...
python/tvm/ndarray.py
View file @
f2ab736b
...
...
@@ -9,7 +9,8 @@ import numpy as _np
from
._ffi.ndarray
import
TVMContext
,
TVMType
,
NDArrayBase
from
._ffi.ndarray
import
context
,
empty
from
._ffi.ndarray
import
_set_class_ndarray
,
register_extension
from
._ffi.ndarray
import
_set_class_ndarray
from
._ffi.ndarray
import
register_extension
,
free_extension_handle
class
NDArray
(
NDArrayBase
):
"""Lightweight NDArray class of TVM runtime.
...
...
src/runtime/registry.cc
View file @
f2ab736b
...
...
@@ -9,6 +9,7 @@
#include <unordered_map>
#include <mutex>
#include <memory>
#include <array>
#include "./runtime_base.h"
namespace
tvm
{
...
...
@@ -21,8 +22,17 @@ struct Registry::Manager {
// and the resource can become invalid because of indeterminstic order of destruction.
// The resources will only be recycled during program exit.
std
::
unordered_map
<
std
::
string
,
Registry
*>
fmap
;
// vtable for extension type
std
::
array
<
ExtTypeVTable
,
kExtEnd
>
ext_vtable
;
// mutex
std
::
mutex
mutex
;
Manager
()
{
for
(
auto
&
x
:
ext_vtable
)
{
x
.
destroy
=
nullptr
;
}
}
static
Manager
*
Global
()
{
static
Manager
inst
;
return
&
inst
;
...
...
@@ -78,6 +88,24 @@ std::vector<std::string> Registry::ListNames() {
return
keys
;
}
ExtTypeVTable
*
ExtTypeVTable
::
Get
(
int
type_code
)
{
CHECK
(
type_code
>
kExtBegin
&&
type_code
<
kExtEnd
);
Registry
::
Manager
*
m
=
Registry
::
Manager
::
Global
();
ExtTypeVTable
*
vt
=
&
(
m
->
ext_vtable
[
type_code
]);
CHECK
(
vt
->
destroy
!=
nullptr
)
<<
"Extension type not registered"
;
return
vt
;
}
ExtTypeVTable
*
ExtTypeVTable
::
RegisterInternal
(
int
type_code
,
const
ExtTypeVTable
&
vt
)
{
CHECK
(
type_code
>
kExtBegin
&&
type_code
<
kExtEnd
);
Registry
::
Manager
*
m
=
Registry
::
Manager
::
Global
();
std
::
lock_guard
<
std
::
mutex
>
(
m
->
mutex
);
ExtTypeVTable
*
pvt
=
&
(
m
->
ext_vtable
[
type_code
]);
pvt
[
0
]
=
vt
;
return
pvt
;
}
}
// namespace runtime
}
// namespace tvm
...
...
@@ -92,6 +120,11 @@ struct TVMFuncThreadLocalEntry {
/*! \brief Thread local store that can be used to hold return values. */
typedef
dmlc
::
ThreadLocalStore
<
TVMFuncThreadLocalEntry
>
TVMFuncThreadLocalStore
;
int
TVMExtTypeFree
(
void
*
handle
,
int
type_code
)
{
API_BEGIN
();
tvm
::
runtime
::
ExtTypeVTable
::
Get
(
type_code
)
->
destroy
(
handle
);
API_END
();
}
int
TVMFuncRegisterGlobal
(
const
char
*
name
,
TVMFunctionHandle
f
,
int
override
)
{
...
...
tests/cpp/packed_func_test.cc
View file @
f2ab736b
...
...
@@ -110,6 +110,55 @@ TEST(PackedFunc, Type) {
CHECK
(
get_type2
(
"float32x2"
).
operator
Type
()
==
Float
(
32
,
2
));
}
// new namespoace
namespace
test
{
// register int vector as extension type
using
IntVector
=
std
::
vector
<
int
>
;
}
// namespace test
namespace
tvm
{
namespace
runtime
{
template
<>
struct
extension_class_info
<
test
::
IntVector
>
{
static
const
int
code
=
kExtBegin
+
1
;
};
}
// runtime
}
// tvm
// do registration, this need to be in cc file
TVM_REGISTER_EXT_TYPE
(
test
::
IntVector
);
TEST
(
PackedFunc
,
ExtensionType
)
{
using
namespace
tvm
;
using
namespace
tvm
::
runtime
;
// note: class are copy by value.
test
::
IntVector
vec
{
1
,
2
,
4
};
auto
copy_vec
=
PackedFunc
([
&
](
TVMArgs
args
,
TVMRetValue
*
rv
)
{
// copy by value
const
test
::
IntVector
&
v
=
args
[
0
].
AsExtension
<
test
::
IntVector
>
();
CHECK
(
&
v
==
&
vec
);
test
::
IntVector
v2
=
args
[
0
];
CHECK_EQ
(
v2
.
size
(),
3U
);
CHECK_EQ
(
v
[
2
],
4
);
// return copy by value
*
rv
=
v2
;
});
auto
pass_vec
=
PackedFunc
([
&
](
TVMArgs
args
,
TVMRetValue
*
rv
)
{
// copy by value
*
rv
=
args
[
0
];
});
test
::
IntVector
vret1
=
copy_vec
(
vec
);
test
::
IntVector
vret2
=
pass_vec
(
copy_vec
(
vec
));
CHECK_EQ
(
vret1
.
size
(),
3U
);
CHECK_EQ
(
vret2
.
size
(),
3U
);
CHECK_EQ
(
vret1
[
2
],
4
);
CHECK_EQ
(
vret2
[
2
],
4
);
}
int
main
(
int
argc
,
char
**
argv
)
{
testing
::
InitGoogleTest
(
&
argc
,
argv
);
...
...
tests/python/unittest/test_runtime_extension.py
View file @
f2ab736b
...
...
@@ -3,6 +3,7 @@ import numpy as np
@tvm.register_extension
class
MyTensorView
(
object
):
_tvm_tcode
=
tvm
.
TypeCode
.
ARRAY_HANDLE
def
__init__
(
self
,
arr
):
self
.
arr
=
arr
...
...
@@ -10,10 +11,6 @@ class MyTensorView(object):
def
_tvm_handle
(
self
):
return
self
.
arr
.
_tvm_handle
@property
def
_tvm_tcode
(
self
):
return
tvm
.
TypeCode
.
ARRAY_HANDLE
def
test_dltensor_compatible
():
dtype
=
'int64'
n
=
tvm
.
var
(
'n'
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment