Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
81334be3
Commit
81334be3
authored
Feb 21, 2019
by
Junru Shao
Committed by
Tianqi Chen
Feb 21, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[RUNTIME][NDArray] Allowing External Libraries to Subclass NDArrays (#2613)
parent
79abd2c3
Hide whitespace changes
Inline
Side-by-side
Showing
18 changed files
with
280 additions
and
54 deletions
+280
-54
apps/extension/Makefile
+1
-1
apps/extension/python/tvm_ext/__init__.py
+28
-1
apps/extension/src/tvm_ext.cc
+84
-1
apps/extension/tests/test_ext.py
+16
-0
include/tvm/runtime/ndarray.h
+26
-3
include/tvm/runtime/packed_func.h
+35
-13
include/tvm/runtime/registry.h
+1
-1
nnvm/include/nnvm/compiler/packed_func_ext.h
+3
-3
nnvm/src/compiler/packed_func_ext.cc
+2
-2
python/tvm/_ffi/_ctypes/function.py
+3
-3
python/tvm/_ffi/_ctypes/ndarray.py
+15
-4
python/tvm/_ffi/_cython/base.pxi
+9
-1
python/tvm/_ffi/_cython/function.pxi
+2
-2
python/tvm/_ffi/_cython/ndarray.pxi
+17
-5
python/tvm/_ffi/ndarray.py
+27
-12
python/tvm/_ffi/runtime_ctypes.py
+9
-0
python/tvm/ndarray.py
+1
-1
tests/cpp/packed_func_test.cc
+1
-1
No files found.
apps/extension/Makefile
View file @
81334be3
...
...
@@ -6,7 +6,7 @@ PKG_CFLAGS = -std=c++11 -O2 -fPIC\
-I
${
TVM_ROOT
}
/3rdparty/dlpack/include
\
-I
${
TVM_ROOT
}
/3rdparty/HalideIR/src
PKG_LDFLAGS
=
-L
${
TVM_ROOT
}
/
lib
PKG_LDFLAGS
=
-L
${
TVM_ROOT
}
/
build
UNAME_S
:=
$(
shell
uname
-s
)
ifeq
($(UNAME_S),
Darwin)
...
...
apps/extension/python/tvm_ext/__init__.py
View file @
81334be3
...
...
@@ -31,7 +31,7 @@ class IntVec(object):
def
__del__
(
self
):
# You can also call your own customized
# deleter if you can free it via your own FFI.
tvm
.
nd
.
free_extension_handle
(
self
.
handle
,
17
)
tvm
.
nd
.
free_extension_handle
(
self
.
handle
,
self
.
__class__
.
_tvm_tcode
)
@property
def
_tvm_handle
(
self
):
...
...
@@ -42,3 +42,30 @@ class IntVec(object):
# Register IntVec extension on python side.
tvm
.
register_extension
(
IntVec
,
IntVec
)
nd_create
=
tvm
.
get_global_func
(
"tvm_ext.nd_create"
)
nd_add_two
=
tvm
.
get_global_func
(
"tvm_ext.nd_add_two"
)
nd_get_addtional_info
=
tvm
.
get_global_func
(
"tvm_ext.nd_get_addtional_info"
)
class
NDSubClass
(
tvm
.
nd
.
NDArrayBase
):
"""Example for subclassing TVM's NDArray infrastructure.
By inheriting TMV's NDArray, external libraries could
leverage TVM's FFI without any modification.
"""
# Should be consistent with the type-trait set in the backend
_array_type_code
=
1
@staticmethod
def
create
(
addtional_info
):
return
nd_create
(
addtional_info
)
@property
def
addtional_info
(
self
):
return
nd_get_addtional_info
(
self
)
def
__add__
(
self
,
other
):
return
nd_add_two
(
self
,
other
)
tvm
.
register_extension
(
NDSubClass
,
NDSubClass
)
apps/extension/src/tvm_ext.cc
View file @
81334be3
...
...
@@ -7,18 +7,25 @@
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/module.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/ndarray.h>
#include <tvm/packed_func_ext.h>
#include <tvm/runtime/device_api.h>
namespace
tvm_ext
{
using
IntVector
=
std
::
vector
<
int
>
;
class
NDSubClass
;
}
// namespace tvm_ext
namespace
tvm
{
namespace
runtime
{
template
<>
struct
extension_
class
_info
<
tvm_ext
::
IntVector
>
{
struct
extension_
type
_info
<
tvm_ext
::
IntVector
>
{
static
const
int
code
=
17
;
};
template
<>
struct
array_type_info
<
tvm_ext
::
NDSubClass
>
{
static
const
int
code
=
1
;
};
}
// namespace tvm
}
// namespace runtime
...
...
@@ -26,6 +33,62 @@ using namespace tvm;
using
namespace
tvm
::
runtime
;
namespace
tvm_ext
{
/*!
* \brief A subclass of TVM's NDArray.
*
* To use this extension, an external library should
*
* 1) Inherit TVM's NDArray and NDArray container,
* and define the trait `array_type_info` for this class.
*
* 2) Define a constructor in the inherited class that accepts
* a pointer to TVM's Container, which is nullable.
*
* 3) On Python frontend, inherit `tvm.nd.NDArrayBase`,
* define the class attribute `_array_type_code` consistent to
* the C++ type trait, and register the subclass using `tvm.register_extension`.
*/
class
NDSubClass
:
public
tvm
::
runtime
::
NDArray
{
public
:
class
SubContainer
:
public
NDArray
::
Container
{
public
:
SubContainer
(
int
addtional_info
)
:
addtional_info_
(
addtional_info
)
{
array_type_code_
=
array_type_info
<
NDSubClass
>::
code
;
}
static
bool
Is
(
NDArray
::
Container
*
container
)
{
SubContainer
*
c
=
static_cast
<
SubContainer
*>
(
container
);
return
c
->
array_type_code_
==
array_type_info
<
NDSubClass
>::
code
;
}
int
addtional_info_
{
0
};
};
NDSubClass
(
NDArray
::
Container
*
container
)
{
if
(
container
==
nullptr
)
{
data_
=
nullptr
;
return
;
}
CHECK
(
SubContainer
::
Is
(
container
));
container
->
IncRef
();
data_
=
container
;
}
~
NDSubClass
()
{
this
->
reset
();
}
NDSubClass
AddWith
(
const
NDSubClass
&
other
)
const
{
SubContainer
*
a
=
static_cast
<
SubContainer
*>
(
data_
);
SubContainer
*
b
=
static_cast
<
SubContainer
*>
(
other
.
data_
);
CHECK
(
a
!=
nullptr
&&
b
!=
nullptr
);
return
NDSubClass
(
new
SubContainer
(
a
->
addtional_info_
+
b
->
addtional_info_
));
}
int
get_additional_info
()
const
{
SubContainer
*
self
=
static_cast
<
SubContainer
*>
(
data_
);
CHECK
(
self
!=
nullptr
);
return
self
->
addtional_info_
;
}
};
}
// namespace tvm_ext
namespace
tvm_ext
{
TVM_REGISTER_EXT_TYPE
(
IntVector
);
...
...
@@ -64,6 +127,26 @@ TVM_REGISTER_GLOBAL("device_api.ext_dev")
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
rv
)
{
*
rv
=
(
*
tvm
::
runtime
::
Registry
::
Get
(
"device_api.cpu"
))();
});
TVM_REGISTER_GLOBAL
(
"tvm_ext.nd_create"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
rv
)
{
int
addtional_info
=
args
[
0
];
*
rv
=
NDSubClass
(
new
NDSubClass
::
SubContainer
(
addtional_info
));
});
TVM_REGISTER_GLOBAL
(
"tvm_ext.nd_add_two"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
rv
)
{
NDSubClass
a
=
args
[
0
];
NDSubClass
b
=
args
[
1
];
*
rv
=
a
.
AddWith
(
b
);
});
TVM_REGISTER_GLOBAL
(
"tvm_ext.nd_get_addtional_info"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
rv
)
{
NDSubClass
a
=
args
[
0
];
*
rv
=
a
.
get_additional_info
();
});
}
// namespace tvm_ext
// External function exposed to runtime.
...
...
apps/extension/tests/test_ext.py
View file @
81334be3
...
...
@@ -32,6 +32,7 @@ def test_sym_add():
c
=
tvm_ext
.
sym_add
(
a
,
b
)
assert
c
.
a
==
a
and
c
.
b
==
b
def
test_ext_vec
():
ivec
=
tvm_ext
.
ivec_create
(
1
,
2
,
3
)
assert
(
isinstance
(
ivec
,
tvm_ext
.
IntVec
))
...
...
@@ -44,6 +45,7 @@ def test_ext_vec():
tvm
.
convert
(
ivec_cb
)(
ivec
)
def
test_extract_ext
():
fdict
=
tvm
.
extract_ext_funcs
(
tvm_ext
.
_LIB
.
TVMExtDeclare
)
assert
fdict
[
"mul"
](
3
,
4
)
==
12
...
...
@@ -68,7 +70,21 @@ def test_extern_call():
check_llvm
()
def
test_nd_subclass
():
a
=
tvm_ext
.
NDSubClass
.
create
(
addtional_info
=
3
)
b
=
tvm_ext
.
NDSubClass
.
create
(
addtional_info
=
5
)
c
=
a
+
b
d
=
a
+
a
e
=
b
+
b
assert
(
a
.
addtional_info
==
3
)
assert
(
b
.
addtional_info
==
5
)
assert
(
c
.
addtional_info
==
8
)
assert
(
d
.
addtional_info
==
6
)
assert
(
e
.
addtional_info
==
10
)
if
__name__
==
"__main__"
:
test_nd_subclass
()
test_extern_call
()
test_ext_dev
()
test_ext_vec
()
...
...
include/tvm/runtime/ndarray.h
View file @
81334be3
...
...
@@ -178,11 +178,31 @@ class NDArray {
Container
*
data_
{
nullptr
};
// enable internal functions
friend
struct
Internal
;
friend
class
TVMPODValue_
;
friend
class
TVMArgValue
;
friend
class
TVMRetValue
;
friend
class
TVMArgsSetter
;
};
/*!
* \brief The type trait indicates subclass of TVM's NDArray.
* For irrelavant classes, code = -1.
* For TVM NDArray itself, code = 0.
* All subclasses of NDArray should override code > 0.
*/
template
<
typename
T
>
struct
array_type_info
{
/*! \brief the value of the traits */
static
const
int
code
=
-
1
;
};
// Overrides the type trait for tvm's NDArray.
template
<>
struct
array_type_info
<
NDArray
>
{
static
const
int
code
=
0
;
};
/*!
* \brief Save a DLTensor to stream
* \param strm The outpu stream
* \param tensor The tensor to be saved.
...
...
@@ -196,7 +216,7 @@ inline bool SaveDLTensor(dmlc::Stream* strm, const DLTensor* tensor);
* the pointer to the NDArrayContainer can be directly
* interpreted as a DLTensor*
*
* \note
:
do not use this function directly, use NDArray.
* \note do not use this function directly, use NDArray.
*/
class
NDArray
::
Container
{
public
:
...
...
@@ -228,16 +248,19 @@ class NDArray::Container {
protected
:
friend
class
NDArray
;
friend
class
TVMPODValue_
;
friend
class
TVMArgValue
;
friend
class
TVMRetValue
;
friend
class
RPCWrappedFunc
;
/*!
* \brief Type flag used to indicate subclass.
* Default value 0 means normal NDArray::Conatainer.
*
* We can extend a more specialized NDArray::Container
* and use the array_type_
index
_ to indicate
* and use the array_type_
code
_ to indicate
* the specific array subclass.
*/
uint32_t
array_type_index
_
{
0
};
int32_t
array_type_code
_
{
0
};
/*! \brief The internal reference counter */
std
::
atomic
<
int
>
ref_counter_
{
0
};
/*!
...
...
include/tvm/runtime/packed_func.h
View file @
81334be3
...
...
@@ -362,7 +362,7 @@ inline std::string TVMType2String(TVMType t);
* \tparam T the typename
*/
template
<
typename
T
>
struct
extension_
class
_info
{
struct
extension_
type
_info
{
static
const
int
code
=
0
;
};
...
...
@@ -455,6 +455,15 @@ class TVMPODValue_ {
TVM_CHECK_TYPE_CODE
(
type_code_
,
kTVMContext
);
return
value_
.
v_ctx
;
}
template
<
typename
TNDArray
,
typename
=
typename
std
::
enable_if
<
std
::
is_base_of
<
NDArray
,
TNDArray
>::
value
>::
type
>
TNDArray
AsNDArray
()
const
{
if
(
type_code_
==
kNull
)
return
TNDArray
(
nullptr
);
auto
*
container
=
static_cast
<
NDArray
::
Container
*>
(
value_
.
v_handle
);
CHECK_EQ
(
container
->
array_type_code_
,
array_type_info
<
TNDArray
>::
code
);
return
TNDArray
(
container
);
}
template
<
typename
TExtension
>
const
TExtension
&
AsExtension
()
const
{
CHECK_LT
(
type_code_
,
kExtEnd
);
...
...
@@ -561,7 +570,7 @@ class TVMArgValue : public TVMPODValue_ {
inline
TNodeRef
AsNodeRef
()
const
;
template
<
typename
T
,
typename
=
typename
std
::
enable_if
<
std
::
is_class
<
T
>::
value
>::
type
>
std
::
is_class
<
T
>::
value
>::
type
>
inline
operator
T
()
const
;
template
<
typename
TNodeRef
,
typename
=
typename
std
::
enable_if
<
...
...
@@ -727,10 +736,10 @@ class TVMRetValue : public TVMPODValue_ {
}
template
<
typename
T
,
typename
=
typename
std
::
enable_if
<
extension_
class
_info
<
T
>::
code
!=
0
>::
type
>
extension_
type
_info
<
T
>::
code
!=
0
>::
type
>
TVMRetValue
&
operator
=
(
const
T
&
other
)
{
this
->
SwitchToClass
<
T
>
(
extension_
class
_info
<
T
>::
code
,
other
);
extension_
type
_info
<
T
>::
code
,
other
);
return
*
this
;
}
/*!
...
...
@@ -1094,7 +1103,7 @@ class TVMArgsSetter {
// extension
template
<
typename
T
,
typename
=
typename
std
::
enable_if
<
extension_
class
_info
<
T
>::
code
!=
0
>::
type
>
extension_
type
_info
<
T
>::
code
!=
0
>::
type
>
inline
void
operator
()(
size_t
i
,
const
T
&
value
)
const
;
// NodeRef related extenstions: in tvm/packed_func_ext.h
inline
void
operator
()(
size_t
i
,
const
NodeRef
&
other
)
const
;
// NOLINT(*)
...
...
@@ -1212,40 +1221,53 @@ inline R TypedPackedFunc<R(Args...)>::operator()(Args... args) const {
// extension and node type handling
namespace
detail
{
template
<
typename
T
,
typename
TSrc
,
bool
is_ext
>
template
<
typename
T
,
typename
TSrc
,
bool
is_ext
,
bool
is_nd
>
struct
TVMValueCast
{
static
T
Apply
(
const
TSrc
*
self
)
{
static_assert
(
!
is_ext
&&
!
is_nd
,
"The default case accepts only non-extensions"
);
return
self
->
template
AsNodeRef
<
T
>
();
}
};
template
<
typename
T
,
typename
TSrc
>
struct
TVMValueCast
<
T
,
TSrc
,
true
>
{
struct
TVMValueCast
<
T
,
TSrc
,
true
,
false
>
{
static
T
Apply
(
const
TSrc
*
self
)
{
return
self
->
template
AsExtension
<
T
>
();
}
};
template
<
typename
T
,
typename
TSrc
>
struct
TVMValueCast
<
T
,
TSrc
,
false
,
true
>
{
static
T
Apply
(
const
TSrc
*
self
)
{
return
self
->
template
AsNDArray
<
T
>
();
}
};
}
// namespace detail
template
<
typename
T
,
typename
>
inline
TVMArgValue
::
operator
T
()
const
{
return
detail
::
TVMValueCast
<
T
,
TVMArgValue
,
extension_class_info
<
T
>::
code
!=
0
>
TVMValueCast
<
T
,
TVMArgValue
,
(
extension_type_info
<
T
>::
code
!=
0
),
(
array_type_info
<
T
>::
code
>
0
)
>
::
Apply
(
this
);
}
template
<
typename
T
,
typename
>
inline
TVMRetValue
::
operator
T
()
const
{
return
detail
::
TVMValueCast
<
T
,
TVMRetValue
,
extension_class_info
<
T
>::
code
!=
0
>
TVMValueCast
<
T
,
TVMRetValue
,
(
extension_type_info
<
T
>::
code
!=
0
),
(
array_type_info
<
T
>::
code
>
0
)
>
::
Apply
(
this
);
}
template
<
typename
T
,
typename
>
inline
void
TVMArgsSetter
::
operator
()(
size_t
i
,
const
T
&
value
)
const
{
static_assert
(
extension_
class
_info
<
T
>::
code
!=
0
,
static_assert
(
extension_
type
_info
<
T
>::
code
!=
0
,
"Need to have extesion code"
);
type_codes_
[
i
]
=
extension_
class
_info
<
T
>::
code
;
type_codes_
[
i
]
=
extension_
type
_info
<
T
>::
code
;
values_
[
i
].
v_handle
=
const_cast
<
T
*>
(
&
value
);
}
...
...
@@ -1262,9 +1284,9 @@ struct ExtTypeInfo {
template
<
typename
T
>
inline
ExtTypeVTable
*
ExtTypeVTable
::
Register_
()
{
const
int
code
=
extension_
class
_info
<
T
>::
code
;
const
int
code
=
extension_
type
_info
<
T
>::
code
;
static_assert
(
code
!=
0
,
"require extension_
class
_info traits to be declared with non-zero code"
);
"require extension_
type
_info traits to be declared with non-zero code"
);
ExtTypeVTable
vt
;
vt
.
clone
=
ExtTypeInfo
<
T
>::
clone
;
vt
.
destroy
=
ExtTypeInfo
<
T
>::
destroy
;
...
...
include/tvm/runtime/registry.h
View file @
81334be3
...
...
@@ -133,7 +133,7 @@ class Registry {
/*!
* \brief Macro to register extension type.
* This must be registered in a cc file
* after the trait extension_
class
_info is defined.
* after the trait extension_
type
_info is defined.
*/
#define TVM_REGISTER_EXT_TYPE(T) \
TVM_STR_CONCAT(TVM_TYPE_REG_VAR_DEF, __COUNTER__) = \
...
...
nnvm/include/nnvm/compiler/packed_func_ext.h
View file @
81334be3
...
...
@@ -40,17 +40,17 @@ namespace tvm {
namespace
runtime
{
template
<>
struct
extension_
class
_info
<
nnvm
::
Symbol
>
{
struct
extension_
type
_info
<
nnvm
::
Symbol
>
{
static
const
int
code
=
16
;
};
template
<>
struct
extension_
class
_info
<
nnvm
::
Graph
>
{
struct
extension_
type
_info
<
nnvm
::
Graph
>
{
static
const
int
code
=
17
;
};
template
<>
struct
extension_
class
_info
<
nnvm
::
compiler
::
AttrDict
>
{
struct
extension_
type
_info
<
nnvm
::
compiler
::
AttrDict
>
{
static
const
int
code
=
18
;
};
...
...
nnvm/src/compiler/packed_func_ext.cc
View file @
81334be3
...
...
@@ -76,8 +76,8 @@ TVM_REGISTER_GLOBAL("nnvm.compiler._register_alter_op_layout")
if
(
ret
.
type_code
()
==
TVMTypeCode
::
kNull
)
{
return
false
;
}
CHECK_EQ
(
ret
.
type_code
(),
tvm
::
runtime
::
extension_
class
_info
<
Symbol
>::
code
)
<<
" expected "
<<
"Symbol (code = "
<<
tvm
::
runtime
::
extension_
class
_info
<
Symbol
>::
code
CHECK_EQ
(
ret
.
type_code
(),
tvm
::
runtime
::
extension_
type
_info
<
Symbol
>::
code
)
<<
" expected "
<<
"Symbol (code = "
<<
tvm
::
runtime
::
extension_
type
_info
<
Symbol
>::
code
<<
") but get code = "
<<
ret
.
type_code
();
*
ret_symbol
=
*
(
static_cast
<
Symbol
*>
(
ret
.
value
().
v_handle
));
return
true
;
...
...
python/tvm/_ffi/_ctypes/function.py
View file @
81334be3
...
...
@@ -223,13 +223,13 @@ def _handle_return_func(x):
_node
.
__init_by_constructor__
=
__init_handle_by_constructor__
RETURN_SWITCH
[
TypeCode
.
FUNC_HANDLE
]
=
_handle_return_func
RETURN_SWITCH
[
TypeCode
.
MODULE_HANDLE
]
=
_return_module
RETURN_SWITCH
[
TypeCode
.
NDARRAY_CONTAINER
]
=
lambda
x
:
_make_array
(
x
.
v_handle
,
False
)
RETURN_SWITCH
[
TypeCode
.
NDARRAY_CONTAINER
]
=
lambda
x
:
_make_array
(
x
.
v_handle
,
False
,
True
)
C_TO_PY_ARG_SWITCH
[
TypeCode
.
FUNC_HANDLE
]
=
_wrap_arg_func
(
_handle_return_func
,
TypeCode
.
FUNC_HANDLE
)
C_TO_PY_ARG_SWITCH
[
TypeCode
.
MODULE_HANDLE
]
=
_wrap_arg_func
(
_return_module
,
TypeCode
.
MODULE_HANDLE
)
C_TO_PY_ARG_SWITCH
[
TypeCode
.
ARRAY_HANDLE
]
=
lambda
x
:
_make_array
(
x
.
v_handle
,
True
)
C_TO_PY_ARG_SWITCH
[
TypeCode
.
NDARRAY_CONTAINER
]
=
lambda
x
:
_make_array
(
x
.
v_handle
,
False
)
C_TO_PY_ARG_SWITCH
[
TypeCode
.
ARRAY_HANDLE
]
=
lambda
x
:
_make_array
(
x
.
v_handle
,
True
,
False
)
C_TO_PY_ARG_SWITCH
[
TypeCode
.
NDARRAY_CONTAINER
]
=
lambda
x
:
_make_array
(
x
.
v_handle
,
False
,
True
)
_CLASS_MODULE
=
None
_CLASS_FUNCTION
=
None
...
...
python/tvm/_ffi/_ctypes/ndarray.py
View file @
81334be3
...
...
@@ -4,7 +4,7 @@ from __future__ import absolute_import
import
ctypes
from
..base
import
_LIB
,
check_call
,
c_str
from
..runtime_ctypes
import
TVMArrayHandle
from
..runtime_ctypes
import
TVMArrayHandle
,
TVMNDArrayContainerHandle
from
.types
import
RETURN_SWITCH
,
C_TO_PY_ARG_SWITCH
,
_wrap_arg_func
,
_return_handle
...
...
@@ -28,7 +28,7 @@ def _from_dlpack(dltensor):
check_call
(
_LIB
.
TVMArrayFromDLPack
(
ptr
,
ctypes
.
byref
(
handle
)))
ctypes
.
pythonapi
.
PyCapsule_SetName
(
dltensor
,
_c_str_used_dltensor
)
ctypes
.
pythonapi
.
PyCapsule_SetDestructor
(
dltensor
,
TVMPyCapsuleDestructor
(
0
))
return
_make_array
(
handle
,
False
)
return
_make_array
(
handle
,
False
,
False
)
raise
ValueError
(
"Expect a dltensor field, PyCapsule can only be consumed once"
)
...
...
@@ -77,9 +77,15 @@ class NDArrayBase(object):
return
ctypes
.
pythonapi
.
PyCapsule_New
(
handle
,
_c_str_dltensor
,
_c_dlpack_deleter
)
def
_make_array
(
handle
,
is_view
):
def
_make_array
(
handle
,
is_view
,
is_container
):
global
_TVM_ND_CLS
handle
=
ctypes
.
cast
(
handle
,
TVMArrayHandle
)
return
_CLASS_NDARRAY
(
handle
,
is_view
)
fcreate
=
_CLASS_NDARRAY
if
is_container
and
_TVM_ND_CLS
:
array_type_info
=
ctypes
.
cast
(
handle
,
TVMNDArrayContainerHandle
)
.
array_type_info
.
value
if
array_type_info
>
0
:
fcreate
=
_TVM_ND_CLS
[
array_type_info
]
return
fcreate
(
handle
,
is_view
)
_TVM_COMPATS
=
()
...
...
@@ -91,6 +97,11 @@ def _reg_extension(cls, fcreate):
RETURN_SWITCH
[
cls
.
_tvm_tcode
]
=
fret
C_TO_PY_ARG_SWITCH
[
cls
.
_tvm_tcode
]
=
_wrap_arg_func
(
fret
,
cls
.
_tvm_tcode
)
_TVM_ND_CLS
=
{}
def
_reg_ndarray
(
cls
,
fcreate
):
global
_TVM_ND_CLS
_TVM_ND_CLS
[
cls
.
_array_type_code
]
=
fcreate
_CLASS_NDARRAY
=
None
...
...
python/tvm/_ffi/_cython/base.pxi
View file @
81334be3
...
...
@@ -2,7 +2,7 @@ from ..base import TVMError
from libcpp.vector cimport vector
from cpython.version cimport PY_MAJOR_VERSION
from cpython cimport pycapsule
from libc.stdint cimport int64_t, uint64_t, uint8_t, uint16_t
from libc.stdint cimport int
32_t, int
64_t, uint64_t, uint8_t, uint16_t
import ctypes
cdef enum TVMTypeCode:
...
...
@@ -61,6 +61,14 @@ ctypedef void* TVMRetValueHandle
ctypedef void* TVMFunctionHandle
ctypedef void* NodeHandle
ctypedef struct TVMNDArrayContainer:
DLTensor dl_tensor
void* manager_ctx
void (*deleter)(DLManagedTensor* self)
int32_t array_type_info
ctypedef TVMNDArrayContainer* TVMNDArrayContainerHandle
ctypedef int (*TVMPackedCFunc)(
TVMValue* args,
int* type_codes,
...
...
python/tvm/_ffi/_cython/function.pxi
View file @
81334be3
...
...
@@ -33,7 +33,7 @@ cdef int tvm_callback(TVMValue* args,
if tcode != kArrayHandle:
pyargs.append(make_ret(value, tcode))
else:
pyargs.append(c_make_array(value.v_handle, True))
pyargs.append(c_make_array(value.v_handle, True
, False
))
try:
rv = local_pyfunc(*pyargs)
except Exception:
...
...
@@ -175,7 +175,7 @@ cdef inline object make_ret(TVMValue value, int tcode):
elif tcode == kFloat:
return value.v_float64
elif tcode == kNDArrayContainer:
return c_make_array(value.v_handle, False)
return c_make_array(value.v_handle, False
, True
)
elif tcode == kStr:
return py_str(value.v_str)
elif tcode == kBytes:
...
...
python/tvm/_ffi/_cython/ndarray.pxi
View file @
81334be3
...
...
@@ -20,7 +20,7 @@ def _from_dlpack(object dltensor):
# set name and destructor to be empty
pycapsule.PyCapsule_SetDestructor(dltensor, NULL)
pycapsule.PyCapsule_SetName(dltensor, _c_str_used_dltensor)
return c_make_array(chandle,
0
)
return c_make_array(chandle,
False, False
)
raise ValueError("Expect a dltensor field, pycapsule.PyCapsule can only be consumed once")
...
...
@@ -73,8 +73,15 @@ cdef class NDArrayBase:
return pycapsule.PyCapsule_New(dltensor, _c_str_dltensor, _c_dlpack_deleter)
cdef c_make_array(void* chandle, is_view):
ret = _CLASS_NDARRAY(None, is_view)
cdef c_make_array(void* chandle, is_view, is_container):
global _TVM_ND_CLS
cdef int32_t array_type_info
fcreate = _CLASS_NDARRAY
if is_container and len(_TVM_ND_CLS) > 0:
array_type_info = (<TVMNDArrayContainerHandle>chandle).array_type_info
if array_type_info > 0:
fcreate = _TVM_ND_CLS[array_type_info]
ret = fcreate(None, is_view)
(<NDArrayBase>ret).chandle = <DLTensor*>chandle
return ret
...
...
@@ -89,11 +96,16 @@ def _reg_extension(cls, fcreate):
if fcreate:
_TVM_EXT_RET[cls._tvm_tcode] = fcreate
cdef _TVM_ND_CLS = {}
def _make_array(handle, is_view):
def _reg_ndarray(cls, fcreate):
global _TVM_ND_CLS
_TVM_ND_CLS[cls._array_type_code] = fcreate
def _make_array(handle, is_view, is_container):
cdef unsigned long long ptr
ptr = ctypes.cast(handle, ctypes.c_void_p).value
return c_make_array(<void*>ptr, is_view)
return c_make_array(<void*>ptr, is_view
, is_container
)
cdef object _CLASS_NDARRAY = None
...
...
python/tvm/_ffi/ndarray.py
View file @
81334be3
...
...
@@ -17,15 +17,18 @@ try:
if
_FFI_MODE
==
"ctypes"
:
raise
ImportError
()
if
sys
.
version_info
>=
(
3
,
0
):
from
._cy3.core
import
_set_class_ndarray
,
_
reg_extension
,
_
make_array
,
_from_dlpack
from
._cy3.core
import
_set_class_ndarray
,
_make_array
,
_from_dlpack
from
._cy3.core
import
NDArrayBase
as
_NDArrayBase
from
._cy3.core
import
_reg_extension
,
_reg_ndarray
else
:
from
._cy2.core
import
_set_class_ndarray
,
_
reg_extension
,
_
make_array
,
_from_dlpack
from
._cy2.core
import
_set_class_ndarray
,
_make_array
,
_from_dlpack
from
._cy2.core
import
NDArrayBase
as
_NDArrayBase
from
._cy2.core
import
_reg_extension
,
_reg_ndarray
except
IMPORT_EXCEPT
:
# pylint: disable=wrong-import-position
from
._ctypes.ndarray
import
_set_class_ndarray
,
_
reg_extension
,
_
make_array
,
_from_dlpack
from
._ctypes.ndarray
import
_set_class_ndarray
,
_make_array
,
_from_dlpack
from
._ctypes.ndarray
import
NDArrayBase
as
_NDArrayBase
from
._ctypes.ndarray
import
_reg_extension
,
_reg_ndarray
def
context
(
dev_type
,
dev_id
=
0
):
...
...
@@ -111,7 +114,7 @@ def empty(shape, dtype="float32", ctx=context(1, 0)):
ctx
.
device_type
,
ctx
.
device_id
,
ctypes
.
byref
(
handle
)))
return
_make_array
(
handle
,
False
)
return
_make_array
(
handle
,
False
,
False
)
def
from_dlpack
(
dltensor
):
...
...
@@ -295,6 +298,7 @@ def free_extension_handle(handle, type_code):
"""
check_call
(
_LIB
.
TVMExtTypeFree
(
handle
,
ctypes
.
c_int
(
type_code
)))
def
register_extension
(
cls
,
fcreate
=
None
):
"""Register a extension class to TVM.
...
...
@@ -306,21 +310,26 @@ def register_extension(cls, fcreate=None):
cls : class
The class object to be registered as extension.
fcreate : function, optional
The creation function to create a class object given handle value.
Note
----
The registered class is requires one property: _tvm_handle and a class attribute _tvm_tcode.
The registered class is requires one property: _tvm_handle.
If the registered class is a subclass of NDArray,
it is required to have a class attribute _array_type_code.
Otherwise, it is required to have a class attribute _tvm_tcode.
- ```_tvm_handle``` returns integer represents the address of the handle.
- ```_tvm_tcode``` gives integer represents type code of the class.
- ```_tvm_tcode``` or ```_array_type_code``` gives integer represents type
code of the class.
Returns
-------
cls : class
The class being registered.
fcreate : function, optional
The creation function to create a class object given handle value.
Example
-------
The following code registers user defined class
...
...
@@ -339,7 +348,13 @@ def register_extension(cls, fcreate=None):
def _tvm_handle(self):
return self.handle.value
"""
if
fcreate
and
cls
.
_tvm_tcode
<
TypeCode
.
EXT_BEGIN
:
raise
ValueError
(
"Cannot register create when extension tcode is same as buildin"
)
_reg_extension
(
cls
,
fcreate
)
if
issubclass
(
cls
,
_NDArrayBase
):
assert
fcreate
is
not
None
assert
hasattr
(
cls
,
"_array_type_code"
)
_reg_ndarray
(
cls
,
fcreate
)
else
:
assert
hasattr
(
cls
,
"_tvm_tcode"
)
if
fcreate
and
cls
.
_tvm_tcode
<
TypeCode
.
EXT_BEGIN
:
raise
ValueError
(
"Cannot register create when extension tcode is same as buildin"
)
_reg_extension
(
cls
,
fcreate
)
return
cls
python/tvm/_ffi/runtime_ctypes.py
View file @
81334be3
...
...
@@ -240,3 +240,12 @@ class TVMArray(ctypes.Structure):
(
"byte_offset"
,
ctypes
.
c_uint64
)]
TVMArrayHandle
=
ctypes
.
POINTER
(
TVMArray
)
class
TVMNDArrayContainer
(
ctypes
.
Structure
):
"""TVM NDArray::Container"""
_fields_
=
[(
"dl_tensor"
,
TVMArray
),
(
"manager_ctx"
,
ctypes
.
c_void_p
),
(
"deleter"
,
ctypes
.
c_void_p
),
(
"array_type_info"
,
ctypes
.
c_int32
)]
TVMNDArrayContainerHandle
=
ctypes
.
POINTER
(
TVMNDArrayContainer
)
python/tvm/ndarray.py
View file @
81334be3
...
...
@@ -15,7 +15,7 @@ from ._ffi.ndarray import register_extension, free_extension_handle
class
NDArray
(
NDArrayBase
):
"""Lightweight NDArray class of TVM runtime.
Strictly this is only an Array Container(a buffer object)
Strictly this is only an Array Container
(a buffer object)
No arthimetic operations are defined.
All operations are performed by TVM functions.
...
...
tests/cpp/packed_func_test.cc
View file @
81334be3
...
...
@@ -168,7 +168,7 @@ namespace tvm {
namespace
runtime
{
template
<>
struct
extension_
class
_info
<
test
::
IntVector
>
{
struct
extension_
type
_info
<
test
::
IntVector
>
{
static
const
int
code
=
kExtBegin
+
1
;
};
}
// runtime
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment