Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
d99bcaf1
Commit
d99bcaf1
authored
Feb 23, 2018
by
Tianqi Chen
Committed by
GitHub
Feb 23, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[EXT] Allow easy extraction of extern module (#926)
parent
433756b9
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
116 additions
and
36 deletions
+116
-36
apps/extension/src/tvm_ext.cc
+17
-3
apps/extension/tests/test_ext.py
+6
-0
include/tvm/runtime/c_runtime_api.h
+18
-0
include/tvm/runtime/module.h
+7
-5
include/tvm/runtime/packed_func.h
+27
-0
python/tvm/_ffi/function.py
+25
-0
python/tvm/api.py
+1
-1
python/tvm/build_module.py
+14
-14
src/runtime/module.cc
+0
-13
tests/scripts/task_python_integration.sh
+1
-0
No files found.
apps/extension/src/tvm_ext.cc
View file @
d99bcaf1
...
...
@@ -22,12 +22,11 @@ struct extension_class_info<tvm_ext::IntVector> {
}
// namespace tvm
}
// namespace runtime
namespace
tvm_ext
{
using
namespace
tvm
;
using
namespace
tvm
::
runtime
;
namespace
tvm_ext
{
TVM_REGISTER_EXT_TYPE
(
IntVector
);
TVM_REGISTER_GLOBAL
(
"tvm_ext.ivec_create"
)
...
...
@@ -66,3 +65,18 @@ TVM_REGISTER_GLOBAL("device_api.ext_dev")
*
rv
=
(
*
tvm
::
runtime
::
Registry
::
Get
(
"device_api.cpu"
))();
});
}
// namespace tvm_ext
// This callback approach allows extension allows tvm to extract
// This way can be helpful when we want to use a header only
// minimum version of TVM Runtime.
extern
"C"
int
TVMExtDeclare
(
TVMFunctionHandle
pregister
)
{
const
PackedFunc
&
fregister
=
*
static_cast
<
PackedFunc
*>
(
pregister
);
auto
mul
=
[](
TVMArgs
args
,
TVMRetValue
*
rv
)
{
int
x
=
args
[
0
];
int
y
=
args
[
1
];
*
rv
=
x
*
y
;
};
fregister
(
"mul"
,
PackedFunc
(
mul
));
return
0
;
}
apps/extension/tests/test_ext.py
View file @
d99bcaf1
...
...
@@ -44,8 +44,14 @@ def test_ext_vec():
tvm
.
convert
(
ivec_cb
)(
ivec
)
def
test_extract_ext
():
fdict
=
tvm
.
extract_ext_funcs
(
tvm_ext
.
_LIB
.
TVMExtDeclare
)
assert
fdict
[
"mul"
](
3
,
4
)
==
12
if
__name__
==
"__main__"
:
test_ext_dev
()
test_ext_vec
()
test_bind_add
()
test_sym_add
()
test_extract_ext
()
include/tvm/runtime/c_runtime_api.h
View file @
d99bcaf1
...
...
@@ -24,6 +24,13 @@
#define TVM_EXTERN_C
#endif
// Macros to do weak linking
#ifdef _MSC_VER
#define TVM_WEAK __declspec(selectany)
#else
#define TVM_WEAK __attribute__((weak))
#endif
#ifdef __EMSCRIPTEN__
#include <emscripten/emscripten.h>
#define TVM_DLL EMSCRIPTEN_KEEPALIVE
...
...
@@ -314,6 +321,17 @@ typedef int (*TVMPackedCFunc)(
typedef
void
(
*
TVMPackedCFuncFinalizer
)(
void
*
resource_handle
);
/*!
* \brief Signature for extension function declarer.
*
* TVM call this function to get the extension functions
* The declarer will call register_func to register function and their name.
*
* \param resource_func_handle The register function
* \return 0 if success, -1 if failure happens
*/
typedef
int
(
*
TVMExtensionFuncDeclarer
)(
TVMFunctionHandle
register_func_handle
);
/*!
* \brief Wrap a TVMPackedCFunc to become a FunctionHandle.
*
* The resource_handle will be managed by TVM API, until the function is no longer used.
...
...
include/tvm/runtime/module.h
View file @
d99bcaf1
...
...
@@ -38,8 +38,14 @@ class Module {
* \param query_imports Whether also query dependency modules.
* \return The result function.
* This function will return PackedFunc(nullptr) if function do not exist.
* \note Implemented in packed_func.cc
*/
TVM_DLL
PackedFunc
GetFunction
(
const
std
::
string
&
name
,
bool
query_imports
=
false
);
inline
PackedFunc
GetFunction
(
const
std
::
string
&
name
,
bool
query_imports
=
false
);
/*! \return internal container */
inline
ModuleNode
*
operator
->
();
/*! \return internal container */
inline
const
ModuleNode
*
operator
->
()
const
;
// The following functions requires link with runtime.
/*!
* \brief Import another module into this module.
* \param other The module to be imported.
...
...
@@ -57,10 +63,6 @@ class Module {
*/
TVM_DLL
static
Module
LoadFromFile
(
const
std
::
string
&
file_name
,
const
std
::
string
&
format
=
""
);
/*! \return internal container */
inline
ModuleNode
*
operator
->
();
/*! \return internal container */
inline
const
ModuleNode
*
operator
->
()
const
;
private
:
std
::
shared_ptr
<
ModuleNode
>
node_
;
...
...
include/tvm/runtime/packed_func.h
View file @
d99bcaf1
...
...
@@ -24,6 +24,11 @@ struct Type;
struct
Expr
;
}
// Whether use TVM runtime in header only mode.
#ifndef TVM_RUNTIME_HEADER_ONLY
#define TVM_RUNTIME_HEADER_ONLY 0
#endif
namespace
tvm
{
// Forward declare NodeRef and Node for extensions.
// This header works fine without depend on NodeRef
...
...
@@ -564,11 +569,15 @@ class TVMRetValue : public TVMPODValue_ {
SwitchToPOD
(
other
.
type_code
());
value_
=
other
.
value_
;
}
else
{
#if TVM_RUNTIME_HEADER_ONLY
LOG
(
FATAL
)
<<
"Header only mode do not support ext type"
;
#else
this
->
Clear
();
type_code_
=
other
.
type_code
();
value_
.
v_handle
=
(
*
(
ExtTypeVTable
::
Get
(
other
.
type_code
())
->
clone
))(
other
.
value
().
v_handle
);
#endif
}
break
;
}
...
...
@@ -600,7 +609,11 @@ class TVMRetValue : public TVMPODValue_ {
case
kNodeHandle
:
delete
ptr
<
std
::
shared_ptr
<
Node
>
>
();
break
;
}
if
(
type_code_
>
kExtBegin
)
{
#if TVM_RUNTIME_HEADER_ONLY
LOG
(
FATAL
)
<<
"Header only mode do not support ext type"
;
#else
(
*
(
ExtTypeVTable
::
Get
(
type_code_
)
->
destroy
))(
value_
.
v_handle
);
#endif
}
type_code_
=
kNull
;
}
...
...
@@ -882,6 +895,20 @@ inline ExtTypeVTable* ExtTypeVTable::Register_() {
vt
.
destroy
=
ExtTypeInfo
<
T
>::
destroy
;
return
ExtTypeVTable
::
RegisterInternal
(
code
,
vt
);
}
// Implement Module::GetFunction
// Put implementation in this file so we have seen the PackedFunc
inline
PackedFunc
Module
::
GetFunction
(
const
std
::
string
&
name
,
bool
query_imports
)
{
PackedFunc
pf
=
node_
->
GetFunction
(
name
,
node_
);
if
(
pf
!=
nullptr
)
return
pf
;
if
(
query_imports
)
{
for
(
const
Module
&
m
:
node_
->
imports_
)
{
pf
=
m
.
node_
->
GetFunction
(
name
,
m
.
node_
);
if
(
pf
!=
nullptr
)
return
pf
;
}
}
return
pf
;
}
}
// namespace runtime
}
// namespace tvm
#endif // TVM_RUNTIME_PACKED_FUNC_H_
python/tvm/_ffi/function.py
View file @
d99bcaf1
...
...
@@ -234,6 +234,31 @@ def list_global_func_names():
return
fnames
def
extract_ext_funcs
(
finit
):
"""
Extract the extension PackedFuncs from a C module.
Parameters
----------
finit : ctypes function
a ctypes that takes signature of TVMExtensionDeclarer
Returns
-------
fdict : dict of str to Function
The extracted functions
"""
fdict
=
{}
def
_list
(
name
,
func
):
fdict
[
name
]
=
func
myf
=
convert_to_tvm_func
(
_list
)
ret
=
finit
(
myf
.
handle
)
_
=
myf
if
ret
!=
0
:
raise
RuntimeError
(
"cannot initialize with
%
s"
%
finit
)
return
fdict
def
_get_api
(
f
):
flocal
=
f
flocal
.
is_global
=
True
...
...
python/tvm/api.py
View file @
d99bcaf1
...
...
@@ -8,7 +8,7 @@ from ._ffi.base import string_types
from
._ffi.node
import
register_node
,
NodeBase
from
._ffi.node
import
convert_to_node
as
_convert_to_node
from
._ffi.function
import
Function
from
._ffi.function
import
_init_api
,
register_func
,
get_global_func
from
._ffi.function
import
_init_api
,
register_func
,
get_global_func
,
extract_ext_funcs
from
._ffi.function
import
convert_to_tvm_func
as
_convert_tvm_func
from
._ffi.runtime_ctypes
import
TVMType
from
.
import
_api_internal
...
...
python/tvm/build_module.py
View file @
d99bcaf1
...
...
@@ -23,16 +23,16 @@ from . import target as _target
from
.
import
make
class
DumpIR
(
object
):
"""Dump IR for each pass.
With it, you can dump ir just like gcc/llvm.
How to use:
-----------
.. code-block:: python
"""
Dump IR for each pass.
With it, you can dump ir just like gcc/llvm.
with tvm.build_config(dump_pass_ir=True)
run()
How to use:
-----------
.. code-block:: python
with tvm.build_config(dump_pass_ir=True)
run()
"""
scope_level
=
0
def
__init__
(
self
):
...
...
@@ -40,9 +40,9 @@ class DumpIR(object):
self
.
_recover_list
=
[]
def
decorate
(
self
,
func
):
''' decorate the pass function'''
""" decorate the pass function"""
def
dump
(
*
args
,
**
kwargs
):
'''dump function'''
"""dump function"""
retv
=
func
(
*
args
,
**
kwargs
)
if
not
isinstance
(
retv
,
(
_stmt
.
Stmt
,
container
.
LoweredFunc
,
container
.
Array
)):
return
retv
...
...
@@ -59,7 +59,7 @@ class DumpIR(object):
return
dump
def
decorate_irpass
(
self
):
'''decorate ir_pass and ScheduleOps'''
"""decorate ir_pass and ScheduleOps"""
self
.
_old_sgpass
=
schedule
.
ScheduleOps
schedule
.
ScheduleOps
=
self
.
decorate
(
schedule
.
ScheduleOps
)
vset
=
vars
(
ir_pass
)
...
...
@@ -71,7 +71,7 @@ class DumpIR(object):
vset
[
k
]
=
self
.
decorate
(
v
)
if
isinstance
(
v
,
types
.
FunctionType
)
else
v
def
decorate_custompass
(
self
):
''' decorate add_lower_pass pass in BuildConfig'''
""" decorate add_lower_pass pass in BuildConfig"""
cfg
=
BuildConfig
.
current
self
.
_old_custom_pass
=
cfg
.
add_lower_pass
custom_pass
=
cfg
.
add_lower_pass
if
cfg
.
add_lower_pass
else
[]
...
...
@@ -79,7 +79,7 @@ class DumpIR(object):
BuildConfig
.
current
.
add_lower_pass
=
pass_list
def
enter
(
self
):
'''only decorate outermost nest'''
"""only decorate outermost nest"""
if
DumpIR
.
scope_level
>
0
:
return
self
.
decorate_irpass
()
...
...
@@ -88,7 +88,7 @@ class DumpIR(object):
DumpIR
.
scope_level
+=
1
def
exit
(
self
):
'''recover outermost nest'''
"""recover outermost nest"""
if
DumpIR
.
scope_level
>
1
:
return
# recover decorated functions
...
...
src/runtime/module.cc
View file @
d99bcaf1
...
...
@@ -13,19 +13,6 @@
namespace
tvm
{
namespace
runtime
{
PackedFunc
Module
::
GetFunction
(
const
std
::
string
&
name
,
bool
query_imports
)
{
PackedFunc
pf
=
node_
->
GetFunction
(
name
,
node_
);
if
(
pf
!=
nullptr
)
return
pf
;
if
(
query_imports
)
{
for
(
const
Module
&
m
:
node_
->
imports_
)
{
pf
=
m
.
node_
->
GetFunction
(
name
,
m
.
node_
);
if
(
pf
!=
nullptr
)
return
pf
;
}
}
return
pf
;
}
void
Module
::
Import
(
Module
other
)
{
// specially handle rpc
if
(
!
std
::
strcmp
((
*
this
)
->
type_key
(),
"rpc"
))
{
...
...
tests/scripts/task_python_integration.sh
View file @
d99bcaf1
...
...
@@ -6,6 +6,7 @@ rm -rf python/tvm/*.pyc python/tvm/*/*.pyc
# Test TVM
make cython
||
exit
-1
make cython3
||
exit
-1
# Test extern package package
cd
apps/extension
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment