Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
d99bcaf1
Commit
d99bcaf1
authored
Feb 23, 2018
by
Tianqi Chen
Committed by
GitHub
Feb 23, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[EXT] Allow easy extraction of extern module (#926)
parent
433756b9
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
116 additions
and
36 deletions
+116
-36
apps/extension/src/tvm_ext.cc
+17
-3
apps/extension/tests/test_ext.py
+6
-0
include/tvm/runtime/c_runtime_api.h
+18
-0
include/tvm/runtime/module.h
+7
-5
include/tvm/runtime/packed_func.h
+27
-0
python/tvm/_ffi/function.py
+25
-0
python/tvm/api.py
+1
-1
python/tvm/build_module.py
+14
-14
src/runtime/module.cc
+0
-13
tests/scripts/task_python_integration.sh
+1
-0
No files found.
apps/extension/src/tvm_ext.cc
View file @
d99bcaf1
...
@@ -22,12 +22,11 @@ struct extension_class_info<tvm_ext::IntVector> {
...
@@ -22,12 +22,11 @@ struct extension_class_info<tvm_ext::IntVector> {
}
// namespace tvm
}
// namespace tvm
}
// namespace runtime
}
// namespace runtime
namespace
tvm_ext
{
using
namespace
tvm
;
using
namespace
tvm
;
using
namespace
tvm
::
runtime
;
using
namespace
tvm
::
runtime
;
namespace
tvm_ext
{
TVM_REGISTER_EXT_TYPE
(
IntVector
);
TVM_REGISTER_EXT_TYPE
(
IntVector
);
TVM_REGISTER_GLOBAL
(
"tvm_ext.ivec_create"
)
TVM_REGISTER_GLOBAL
(
"tvm_ext.ivec_create"
)
...
@@ -66,3 +65,18 @@ TVM_REGISTER_GLOBAL("device_api.ext_dev")
...
@@ -66,3 +65,18 @@ TVM_REGISTER_GLOBAL("device_api.ext_dev")
*
rv
=
(
*
tvm
::
runtime
::
Registry
::
Get
(
"device_api.cpu"
))();
*
rv
=
(
*
tvm
::
runtime
::
Registry
::
Get
(
"device_api.cpu"
))();
});
});
}
// namespace tvm_ext
}
// namespace tvm_ext
// This callback approach allows extension allows tvm to extract
// This way can be helpful when we want to use a header only
// minimum version of TVM Runtime.
extern
"C"
int
TVMExtDeclare
(
TVMFunctionHandle
pregister
)
{
const
PackedFunc
&
fregister
=
*
static_cast
<
PackedFunc
*>
(
pregister
);
auto
mul
=
[](
TVMArgs
args
,
TVMRetValue
*
rv
)
{
int
x
=
args
[
0
];
int
y
=
args
[
1
];
*
rv
=
x
*
y
;
};
fregister
(
"mul"
,
PackedFunc
(
mul
));
return
0
;
}
apps/extension/tests/test_ext.py
View file @
d99bcaf1
...
@@ -44,8 +44,14 @@ def test_ext_vec():
...
@@ -44,8 +44,14 @@ def test_ext_vec():
tvm
.
convert
(
ivec_cb
)(
ivec
)
tvm
.
convert
(
ivec_cb
)(
ivec
)
def
test_extract_ext
():
fdict
=
tvm
.
extract_ext_funcs
(
tvm_ext
.
_LIB
.
TVMExtDeclare
)
assert
fdict
[
"mul"
](
3
,
4
)
==
12
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
test_ext_dev
()
test_ext_dev
()
test_ext_vec
()
test_ext_vec
()
test_bind_add
()
test_bind_add
()
test_sym_add
()
test_sym_add
()
test_extract_ext
()
include/tvm/runtime/c_runtime_api.h
View file @
d99bcaf1
...
@@ -24,6 +24,13 @@
...
@@ -24,6 +24,13 @@
#define TVM_EXTERN_C
#define TVM_EXTERN_C
#endif
#endif
// Macros to do weak linking
#ifdef _MSC_VER
#define TVM_WEAK __declspec(selectany)
#else
#define TVM_WEAK __attribute__((weak))
#endif
#ifdef __EMSCRIPTEN__
#ifdef __EMSCRIPTEN__
#include <emscripten/emscripten.h>
#include <emscripten/emscripten.h>
#define TVM_DLL EMSCRIPTEN_KEEPALIVE
#define TVM_DLL EMSCRIPTEN_KEEPALIVE
...
@@ -314,6 +321,17 @@ typedef int (*TVMPackedCFunc)(
...
@@ -314,6 +321,17 @@ typedef int (*TVMPackedCFunc)(
typedef
void
(
*
TVMPackedCFuncFinalizer
)(
void
*
resource_handle
);
typedef
void
(
*
TVMPackedCFuncFinalizer
)(
void
*
resource_handle
);
/*!
/*!
* \brief Signature for extension function declarer.
*
* TVM call this function to get the extension functions
* The declarer will call register_func to register function and their name.
*
* \param resource_func_handle The register function
* \return 0 if success, -1 if failure happens
*/
typedef
int
(
*
TVMExtensionFuncDeclarer
)(
TVMFunctionHandle
register_func_handle
);
/*!
* \brief Wrap a TVMPackedCFunc to become a FunctionHandle.
* \brief Wrap a TVMPackedCFunc to become a FunctionHandle.
*
*
* The resource_handle will be managed by TVM API, until the function is no longer used.
* The resource_handle will be managed by TVM API, until the function is no longer used.
...
...
include/tvm/runtime/module.h
View file @
d99bcaf1
...
@@ -38,8 +38,14 @@ class Module {
...
@@ -38,8 +38,14 @@ class Module {
* \param query_imports Whether also query dependency modules.
* \param query_imports Whether also query dependency modules.
* \return The result function.
* \return The result function.
* This function will return PackedFunc(nullptr) if function do not exist.
* This function will return PackedFunc(nullptr) if function do not exist.
* \note Implemented in packed_func.cc
*/
*/
TVM_DLL
PackedFunc
GetFunction
(
const
std
::
string
&
name
,
bool
query_imports
=
false
);
inline
PackedFunc
GetFunction
(
const
std
::
string
&
name
,
bool
query_imports
=
false
);
/*! \return internal container */
inline
ModuleNode
*
operator
->
();
/*! \return internal container */
inline
const
ModuleNode
*
operator
->
()
const
;
// The following functions requires link with runtime.
/*!
/*!
* \brief Import another module into this module.
* \brief Import another module into this module.
* \param other The module to be imported.
* \param other The module to be imported.
...
@@ -57,10 +63,6 @@ class Module {
...
@@ -57,10 +63,6 @@ class Module {
*/
*/
TVM_DLL
static
Module
LoadFromFile
(
const
std
::
string
&
file_name
,
TVM_DLL
static
Module
LoadFromFile
(
const
std
::
string
&
file_name
,
const
std
::
string
&
format
=
""
);
const
std
::
string
&
format
=
""
);
/*! \return internal container */
inline
ModuleNode
*
operator
->
();
/*! \return internal container */
inline
const
ModuleNode
*
operator
->
()
const
;
private
:
private
:
std
::
shared_ptr
<
ModuleNode
>
node_
;
std
::
shared_ptr
<
ModuleNode
>
node_
;
...
...
include/tvm/runtime/packed_func.h
View file @
d99bcaf1
...
@@ -24,6 +24,11 @@ struct Type;
...
@@ -24,6 +24,11 @@ struct Type;
struct
Expr
;
struct
Expr
;
}
}
// Whether use TVM runtime in header only mode.
#ifndef TVM_RUNTIME_HEADER_ONLY
#define TVM_RUNTIME_HEADER_ONLY 0
#endif
namespace
tvm
{
namespace
tvm
{
// Forward declare NodeRef and Node for extensions.
// Forward declare NodeRef and Node for extensions.
// This header works fine without depend on NodeRef
// This header works fine without depend on NodeRef
...
@@ -564,11 +569,15 @@ class TVMRetValue : public TVMPODValue_ {
...
@@ -564,11 +569,15 @@ class TVMRetValue : public TVMPODValue_ {
SwitchToPOD
(
other
.
type_code
());
SwitchToPOD
(
other
.
type_code
());
value_
=
other
.
value_
;
value_
=
other
.
value_
;
}
else
{
}
else
{
#if TVM_RUNTIME_HEADER_ONLY
LOG
(
FATAL
)
<<
"Header only mode do not support ext type"
;
#else
this
->
Clear
();
this
->
Clear
();
type_code_
=
other
.
type_code
();
type_code_
=
other
.
type_code
();
value_
.
v_handle
=
value_
.
v_handle
=
(
*
(
ExtTypeVTable
::
Get
(
other
.
type_code
())
->
clone
))(
(
*
(
ExtTypeVTable
::
Get
(
other
.
type_code
())
->
clone
))(
other
.
value
().
v_handle
);
other
.
value
().
v_handle
);
#endif
}
}
break
;
break
;
}
}
...
@@ -600,7 +609,11 @@ class TVMRetValue : public TVMPODValue_ {
...
@@ -600,7 +609,11 @@ class TVMRetValue : public TVMPODValue_ {
case
kNodeHandle
:
delete
ptr
<
std
::
shared_ptr
<
Node
>
>
();
break
;
case
kNodeHandle
:
delete
ptr
<
std
::
shared_ptr
<
Node
>
>
();
break
;
}
}
if
(
type_code_
>
kExtBegin
)
{
if
(
type_code_
>
kExtBegin
)
{
#if TVM_RUNTIME_HEADER_ONLY
LOG
(
FATAL
)
<<
"Header only mode do not support ext type"
;
#else
(
*
(
ExtTypeVTable
::
Get
(
type_code_
)
->
destroy
))(
value_
.
v_handle
);
(
*
(
ExtTypeVTable
::
Get
(
type_code_
)
->
destroy
))(
value_
.
v_handle
);
#endif
}
}
type_code_
=
kNull
;
type_code_
=
kNull
;
}
}
...
@@ -882,6 +895,20 @@ inline ExtTypeVTable* ExtTypeVTable::Register_() {
...
@@ -882,6 +895,20 @@ inline ExtTypeVTable* ExtTypeVTable::Register_() {
vt
.
destroy
=
ExtTypeInfo
<
T
>::
destroy
;
vt
.
destroy
=
ExtTypeInfo
<
T
>::
destroy
;
return
ExtTypeVTable
::
RegisterInternal
(
code
,
vt
);
return
ExtTypeVTable
::
RegisterInternal
(
code
,
vt
);
}
}
// Implement Module::GetFunction
// Put implementation in this file so we have seen the PackedFunc
inline
PackedFunc
Module
::
GetFunction
(
const
std
::
string
&
name
,
bool
query_imports
)
{
PackedFunc
pf
=
node_
->
GetFunction
(
name
,
node_
);
if
(
pf
!=
nullptr
)
return
pf
;
if
(
query_imports
)
{
for
(
const
Module
&
m
:
node_
->
imports_
)
{
pf
=
m
.
node_
->
GetFunction
(
name
,
m
.
node_
);
if
(
pf
!=
nullptr
)
return
pf
;
}
}
return
pf
;
}
}
// namespace runtime
}
// namespace runtime
}
// namespace tvm
}
// namespace tvm
#endif // TVM_RUNTIME_PACKED_FUNC_H_
#endif // TVM_RUNTIME_PACKED_FUNC_H_
python/tvm/_ffi/function.py
View file @
d99bcaf1
...
@@ -234,6 +234,31 @@ def list_global_func_names():
...
@@ -234,6 +234,31 @@ def list_global_func_names():
return
fnames
return
fnames
def
extract_ext_funcs
(
finit
):
"""
Extract the extension PackedFuncs from a C module.
Parameters
----------
finit : ctypes function
a ctypes that takes signature of TVMExtensionDeclarer
Returns
-------
fdict : dict of str to Function
The extracted functions
"""
fdict
=
{}
def
_list
(
name
,
func
):
fdict
[
name
]
=
func
myf
=
convert_to_tvm_func
(
_list
)
ret
=
finit
(
myf
.
handle
)
_
=
myf
if
ret
!=
0
:
raise
RuntimeError
(
"cannot initialize with
%
s"
%
finit
)
return
fdict
def
_get_api
(
f
):
def
_get_api
(
f
):
flocal
=
f
flocal
=
f
flocal
.
is_global
=
True
flocal
.
is_global
=
True
...
...
python/tvm/api.py
View file @
d99bcaf1
...
@@ -8,7 +8,7 @@ from ._ffi.base import string_types
...
@@ -8,7 +8,7 @@ from ._ffi.base import string_types
from
._ffi.node
import
register_node
,
NodeBase
from
._ffi.node
import
register_node
,
NodeBase
from
._ffi.node
import
convert_to_node
as
_convert_to_node
from
._ffi.node
import
convert_to_node
as
_convert_to_node
from
._ffi.function
import
Function
from
._ffi.function
import
Function
from
._ffi.function
import
_init_api
,
register_func
,
get_global_func
from
._ffi.function
import
_init_api
,
register_func
,
get_global_func
,
extract_ext_funcs
from
._ffi.function
import
convert_to_tvm_func
as
_convert_tvm_func
from
._ffi.function
import
convert_to_tvm_func
as
_convert_tvm_func
from
._ffi.runtime_ctypes
import
TVMType
from
._ffi.runtime_ctypes
import
TVMType
from
.
import
_api_internal
from
.
import
_api_internal
...
...
python/tvm/build_module.py
View file @
d99bcaf1
...
@@ -23,16 +23,16 @@ from . import target as _target
...
@@ -23,16 +23,16 @@ from . import target as _target
from
.
import
make
from
.
import
make
class
DumpIR
(
object
):
class
DumpIR
(
object
):
"""Dump IR for each pass.
"""
With it, you can dump ir just like gcc/llvm.
Dump IR for each pass.
With it, you can dump ir just like gcc/llvm.
How to use:
-----------
.. code-block:: python
with tvm.build_config(dump_pass_ir=True)
How to use:
run()
-----------
.. code-block:: python
with tvm.build_config(dump_pass_ir=True)
run()
"""
"""
scope_level
=
0
scope_level
=
0
def
__init__
(
self
):
def
__init__
(
self
):
...
@@ -40,9 +40,9 @@ class DumpIR(object):
...
@@ -40,9 +40,9 @@ class DumpIR(object):
self
.
_recover_list
=
[]
self
.
_recover_list
=
[]
def
decorate
(
self
,
func
):
def
decorate
(
self
,
func
):
''' decorate the pass function'''
""" decorate the pass function"""
def
dump
(
*
args
,
**
kwargs
):
def
dump
(
*
args
,
**
kwargs
):
'''dump function'''
"""dump function"""
retv
=
func
(
*
args
,
**
kwargs
)
retv
=
func
(
*
args
,
**
kwargs
)
if
not
isinstance
(
retv
,
(
_stmt
.
Stmt
,
container
.
LoweredFunc
,
container
.
Array
)):
if
not
isinstance
(
retv
,
(
_stmt
.
Stmt
,
container
.
LoweredFunc
,
container
.
Array
)):
return
retv
return
retv
...
@@ -59,7 +59,7 @@ class DumpIR(object):
...
@@ -59,7 +59,7 @@ class DumpIR(object):
return
dump
return
dump
def
decorate_irpass
(
self
):
def
decorate_irpass
(
self
):
'''decorate ir_pass and ScheduleOps'''
"""decorate ir_pass and ScheduleOps"""
self
.
_old_sgpass
=
schedule
.
ScheduleOps
self
.
_old_sgpass
=
schedule
.
ScheduleOps
schedule
.
ScheduleOps
=
self
.
decorate
(
schedule
.
ScheduleOps
)
schedule
.
ScheduleOps
=
self
.
decorate
(
schedule
.
ScheduleOps
)
vset
=
vars
(
ir_pass
)
vset
=
vars
(
ir_pass
)
...
@@ -71,7 +71,7 @@ class DumpIR(object):
...
@@ -71,7 +71,7 @@ class DumpIR(object):
vset
[
k
]
=
self
.
decorate
(
v
)
if
isinstance
(
v
,
types
.
FunctionType
)
else
v
vset
[
k
]
=
self
.
decorate
(
v
)
if
isinstance
(
v
,
types
.
FunctionType
)
else
v
def
decorate_custompass
(
self
):
def
decorate_custompass
(
self
):
''' decorate add_lower_pass pass in BuildConfig'''
""" decorate add_lower_pass pass in BuildConfig"""
cfg
=
BuildConfig
.
current
cfg
=
BuildConfig
.
current
self
.
_old_custom_pass
=
cfg
.
add_lower_pass
self
.
_old_custom_pass
=
cfg
.
add_lower_pass
custom_pass
=
cfg
.
add_lower_pass
if
cfg
.
add_lower_pass
else
[]
custom_pass
=
cfg
.
add_lower_pass
if
cfg
.
add_lower_pass
else
[]
...
@@ -79,7 +79,7 @@ class DumpIR(object):
...
@@ -79,7 +79,7 @@ class DumpIR(object):
BuildConfig
.
current
.
add_lower_pass
=
pass_list
BuildConfig
.
current
.
add_lower_pass
=
pass_list
def
enter
(
self
):
def
enter
(
self
):
'''only decorate outermost nest'''
"""only decorate outermost nest"""
if
DumpIR
.
scope_level
>
0
:
if
DumpIR
.
scope_level
>
0
:
return
return
self
.
decorate_irpass
()
self
.
decorate_irpass
()
...
@@ -88,7 +88,7 @@ class DumpIR(object):
...
@@ -88,7 +88,7 @@ class DumpIR(object):
DumpIR
.
scope_level
+=
1
DumpIR
.
scope_level
+=
1
def
exit
(
self
):
def
exit
(
self
):
'''recover outermost nest'''
"""recover outermost nest"""
if
DumpIR
.
scope_level
>
1
:
if
DumpIR
.
scope_level
>
1
:
return
return
# recover decorated functions
# recover decorated functions
...
...
src/runtime/module.cc
View file @
d99bcaf1
...
@@ -13,19 +13,6 @@
...
@@ -13,19 +13,6 @@
namespace
tvm
{
namespace
tvm
{
namespace
runtime
{
namespace
runtime
{
PackedFunc
Module
::
GetFunction
(
const
std
::
string
&
name
,
bool
query_imports
)
{
PackedFunc
pf
=
node_
->
GetFunction
(
name
,
node_
);
if
(
pf
!=
nullptr
)
return
pf
;
if
(
query_imports
)
{
for
(
const
Module
&
m
:
node_
->
imports_
)
{
pf
=
m
.
node_
->
GetFunction
(
name
,
m
.
node_
);
if
(
pf
!=
nullptr
)
return
pf
;
}
}
return
pf
;
}
void
Module
::
Import
(
Module
other
)
{
void
Module
::
Import
(
Module
other
)
{
// specially handle rpc
// specially handle rpc
if
(
!
std
::
strcmp
((
*
this
)
->
type_key
(),
"rpc"
))
{
if
(
!
std
::
strcmp
((
*
this
)
->
type_key
(),
"rpc"
))
{
...
...
tests/scripts/task_python_integration.sh
View file @
d99bcaf1
...
@@ -6,6 +6,7 @@ rm -rf python/tvm/*.pyc python/tvm/*/*.pyc
...
@@ -6,6 +6,7 @@ rm -rf python/tvm/*.pyc python/tvm/*/*.pyc
# Test TVM
# Test TVM
make cython
||
exit
-1
make cython
||
exit
-1
make cython3
||
exit
-1
# Test extern package package
# Test extern package package
cd
apps/extension
cd
apps/extension
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment