Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
181edb4a
Commit
181edb4a
authored
Apr 21, 2017
by
Tianqi Chen
Committed by
GitHub
Apr 21, 2017
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[LANG] Change namespace convention to dot (#100)
parent
97c67e53
Hide whitespace changes
Inline
Side-by-side
Showing
31 changed files
with
213 additions
and
150 deletions
+213
-150
include/tvm/base.h
+1
-0
include/tvm/c_api.h
+12
-0
include/tvm/runtime/registry.h
+9
-6
python/tvm/_ctypes/_function.py
+21
-21
python/tvm/_ctypes/_node.py
+6
-2
python/tvm/_ctypes/_types.py
+7
-1
src/api/api_arith.cc
+9
-9
src/api/api_base.cc
+5
-5
src/api/api_codegen.cc
+2
-2
src/api/api_ir.cc
+13
-13
src/api/api_lang.cc
+42
-42
src/api/api_pass.cc
+6
-6
src/api/api_schedule.cc
+3
-3
src/api/c_api.cc
+17
-0
src/codegen/build_cuda.cc
+1
-1
src/codegen/build_opencl.cc
+1
-1
src/codegen/codegen.cc
+2
-2
src/codegen/llvm/llvm_module.cc
+1
-1
src/codegen/stack_vm/stack_vm_module.cc
+1
-1
src/codegen/verilog/verilog_module.cc
+1
-1
src/codegen/verilog/vpi_device_api.cc
+4
-4
src/codegen/verilog/vpi_session.cc
+9
-9
src/runtime/c_runtime_api.cc
+1
-1
src/runtime/cpu_device_api.cc
+1
-1
src/runtime/cuda/cuda_device_api.cc
+1
-1
src/runtime/cuda/cuda_module.cc
+2
-2
src/runtime/dso_module.cc
+1
-1
src/runtime/module.cc
+10
-10
src/runtime/opencl/opencl_device_api.cc
+2
-2
src/runtime/opencl/opencl_module.cc
+2
-2
tests/python/unittest/test_runtime_packed_func.py
+20
-0
No files found.
include/tvm/base.h
View file @
181edb4a
...
...
@@ -12,6 +12,7 @@
#include <string>
#include <memory>
#include <functional>
#include "./runtime/registry.h"
namespace
tvm
{
...
...
include/tvm/c_api.h
View file @
181edb4a
...
...
@@ -20,6 +20,18 @@ TVM_EXTERN_C {
typedef
void
*
NodeHandle
;
/*!
* \brief Inplace translate callback argument value to return value.
* This is only needed for non-POD arguments.
*
* \param value The value to be translated.
* \param code The type code to be translated.
* \note This function will do a shallow copy when necessary.
*
* \return 0 when success, -1 when failure happens.
*/
TVM_DLL
int
TVMCbArgToReturn
(
TVMValue
*
value
,
int
code
);
/*!
* \brief free the node handle
* \param handle The node handle to be freed.
* \return 0 when success, -1 when failure happens
...
...
include/tvm/runtime/registry.h
View file @
181edb4a
...
...
@@ -81,7 +81,6 @@ class Registry {
friend
struct
Manager
;
};
/*! \brief helper macro to supress unused warning */
#if defined(__GNUC__)
#define TVM_ATTRIBUTE_UNUSED __attribute__((unused))
...
...
@@ -89,19 +88,23 @@ class Registry {
#define TVM_ATTRIBUTE_UNUSED
#endif
#define TVM_STR_CONCAT_(__x, __y) __x##__y
#define TVM_STR_CONCAT(__x, __y) TVM_STR_CONCAT_(__x, __y)
#define TVM_FUNC_REG_VAR_DEF \
static TVM_ATTRIBUTE_UNUSED ::tvm::runtime::Registry& __make_ ## TVMOp
/*!
* \brief Register a function globally.
* \code
* TVM_REGISTER_GLOBAL(
MyPrint
)
* TVM_REGISTER_GLOBAL(
"MyPrint"
)
* .set_body([](TVMArgs args, TVMRetValue* rv) {
* // my code.
* });
* \endcode
*/
#define TVM_REGISTER_GLOBAL(OpName) \
static TVM_ATTRIBUTE_UNUSED ::tvm::runtime::Registry& \
__make_TVMRegistry_ ## OpName = \
::tvm::runtime::Registry::Register(#OpName)
TVM_STR_CONCAT(TVM_FUNC_REG_VAR_DEF, __COUNTER__) = \
::tvm::runtime::Registry::Register(OpName)
}
// namespace runtime
}
// namespace tvm
...
...
python/tvm/_ctypes/_function.py
View file @
181edb4a
...
...
@@ -12,7 +12,7 @@ from .._base import _LIB, check_call
from
.._base
import
c_str
,
py_str
,
string_types
from
._types
import
TVMValue
,
TypeCode
,
TVMType
,
TVMByteArray
from
._types
import
TVMPackedCFunc
,
TVMCFuncFinalizer
from
._types
import
RETURN_SWITCH
,
C_TO_PY_ARG_SWITCH
from
._types
import
RETURN_SWITCH
,
C_TO_PY_ARG_SWITCH
,
_wrap_arg_func
from
._node
import
NodeBase
,
SliceBase
,
convert_to_node
from
._ndarray
import
NDArrayBase
...
...
@@ -302,6 +302,10 @@ def _handle_return_func(x):
# setup return handle for function type
RETURN_SWITCH
[
TypeCode
.
FUNC_HANDLE
]
=
_handle_return_func
RETURN_SWITCH
[
TypeCode
.
MODULE_HANDLE
]
=
_return_module
C_TO_PY_ARG_SWITCH
[
TypeCode
.
FUNC_HANDLE
]
=
_wrap_arg_func
(
_handle_return_func
,
TypeCode
.
FUNC_HANDLE
)
C_TO_PY_ARG_SWITCH
[
TypeCode
.
MODULE_HANDLE
]
=
_wrap_arg_func
(
_return_module
,
TypeCode
.
MODULE_HANDLE
)
def
register_func
(
func_name
,
f
=
None
):
...
...
@@ -415,35 +419,31 @@ def _get_api(f):
return
flocal
(
*
args
)
return
my_api_func
def
_init_api
(
mod
):
def
_init_api
(
namespace
):
"""Initialize api for a given module name
mod : str
The name of the module.
"""
module
=
sys
.
modules
[
mod
]
namespace_match
=
{
"_make_"
:
"tvm.make"
,
"_arith_"
:
"tvm.arith"
,
"_pass_"
:
"tvm.ir_pass"
,
"_codegen_"
:
"tvm.codegen"
,
"_module_"
:
"tvm.module"
,
"_schedule_"
:
"tvm.schedule"
}
module
=
sys
.
modules
[
namespace
]
assert
namespace
.
startswith
(
"tvm."
)
prefix
=
namespace
[
4
:]
for
name
in
list_global_func_names
():
fname
=
name
target
=
"tvm.api"
for
k
,
v
in
namespace_match
.
items
():
if
name
.
startswith
(
k
):
fname
=
name
[
len
(
k
):]
target
=
v
if
target
!=
mod
:
continue
if
mod
==
"tvm.api"
and
name
.
startswith
(
"_"
):
target_module
=
sys
.
modules
[
"tvm._api_internal"
]
if
prefix
==
"api"
:
fname
=
name
if
name
.
startswith
(
"_"
):
target_module
=
sys
.
modules
[
"tvm._api_internal"
]
else
:
target_module
=
module
else
:
if
not
name
.
startswith
(
prefix
):
continue
fname
=
name
[
len
(
prefix
)
+
1
:]
target_module
=
module
if
fname
.
find
(
"."
)
!=
-
1
:
continue
f
=
get_global_func
(
name
)
ff
=
_get_api
(
f
)
ff
.
__name__
=
fname
...
...
python/tvm/_ctypes/_node.py
View file @
181edb4a
...
...
@@ -10,7 +10,9 @@ from numbers import Number, Integral
from
.._base
import
_LIB
,
check_call
from
.._base
import
c_str
,
py_str
,
string_types
from
..
import
_api_internal
from
._types
import
TVMValue
,
TypeCode
,
RETURN_SWITCH
from
._types
import
TVMValue
,
TypeCode
from
._types
import
RETURN_SWITCH
,
C_TO_PY_ARG_SWITCH
,
_wrap_arg_func
NodeHandle
=
ctypes
.
c_void_p
...
...
@@ -19,7 +21,7 @@ NODE_TYPE = {
}
def
_return_node
(
x
):
"""Return function"""
"""Return
node
function"""
handle
=
x
.
v_handle
if
not
isinstance
(
handle
,
NodeHandle
):
handle
=
NodeHandle
(
handle
)
...
...
@@ -35,6 +37,8 @@ def _return_node(x):
RETURN_SWITCH
[
TypeCode
.
NODE_HANDLE
]
=
_return_node
C_TO_PY_ARG_SWITCH
[
TypeCode
.
NODE_HANDLE
]
=
_wrap_arg_func
(
_return_node
,
TypeCode
.
NODE_HANDLE
)
class
SliceBase
(
object
):
...
...
python/tvm/_ctypes/_types.py
View file @
181edb4a
...
...
@@ -4,7 +4,7 @@ from __future__ import absolute_import as _abs
import
ctypes
import
numpy
as
np
from
.._base
import
py_str
from
.._base
import
py_str
,
check_call
,
_LIB
tvm_shape_index_t
=
ctypes
.
c_int64
...
...
@@ -130,6 +130,12 @@ def _return_bytes(x):
raise
RuntimeError
(
'memmove failed'
)
return
res
def
_wrap_arg_func
(
return_f
,
type_code
):
tcode
=
ctypes
.
c_int
(
type_code
)
def
_wrap_func
(
x
):
check_call
(
_LIB
.
TVMCbArgToReturn
(
ctypes
.
byref
(
x
),
tcode
))
return
return_f
(
x
)
return
_wrap_func
RETURN_SWITCH
=
{
TypeCode
.
INT
:
lambda
x
:
x
.
v_int64
,
...
...
src/api/api_arith.cc
View file @
181edb4a
...
...
@@ -11,49 +11,49 @@
namespace
tvm
{
namespace
arith
{
TVM_REGISTER_API
(
_arith_intset_single_point
)
TVM_REGISTER_API
(
"arith.intset_single_point"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
IntSet
::
single_point
(
args
[
0
]);
});
TVM_REGISTER_API
(
_arith_intset_interval
)
TVM_REGISTER_API
(
"arith.intset_interval"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
IntSet
::
interval
(
args
[
0
],
args
[
1
]);
});
TVM_REGISTER_API
(
_arith_EvalModular
)
TVM_REGISTER_API
(
"arith.EvalModular"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
EvalModular
(
args
[
0
],
Map
<
Var
,
IntSet
>
());
});
TVM_REGISTER_API
(
_arith_DetectLinearEquation
)
TVM_REGISTER_API
(
"arith.DetectLinearEquation"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
DetectLinearEquation
(
args
[
0
],
args
[
1
]);
});
TVM_REGISTER_API
(
_arith_DeduceBound
)
TVM_REGISTER_API
(
"arith.DeduceBound"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
DeduceBound
(
args
[
0
],
args
[
1
],
args
[
2
].
operator
Map
<
Var
,
IntSet
>
(),
args
[
3
].
operator
Map
<
Var
,
IntSet
>
());
});
TVM_REGISTER_API
(
_IntervalSetGetMin
)
TVM_REGISTER_API
(
"_IntervalSetGetMin"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
args
[
0
].
operator
IntSet
().
min
();
});
TVM_REGISTER_API
(
_IntervalSetGetMax
)
TVM_REGISTER_API
(
"_IntervalSetGetMax"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
args
[
0
].
operator
IntSet
().
max
();
});
TVM_REGISTER_API
(
_IntSetIsNothing
)
TVM_REGISTER_API
(
"_IntSetIsNothing"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
args
[
0
].
operator
IntSet
().
is_nothing
();
});
TVM_REGISTER_API
(
_IntSetIsEverything
)
TVM_REGISTER_API
(
"_IntSetIsEverything"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
args
[
0
].
operator
IntSet
().
is_everything
();
});
...
...
src/api/api_base.cc
View file @
181edb4a
/*!
/*!
* Copyright (c) 2017 by Contributors
* Implementation of basic API functions
* \file api_base.cc
...
...
@@ -9,7 +9,7 @@
namespace
tvm
{
TVM_REGISTER_API
(
_format_str
)
TVM_REGISTER_API
(
"_format_str"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
CHECK
(
args
[
0
].
type_code
()
==
kNodeHandle
);
std
::
ostringstream
os
;
...
...
@@ -17,19 +17,19 @@ TVM_REGISTER_API(_format_str)
*
ret
=
os
.
str
();
});
TVM_REGISTER_API
(
_raw_ptr
)
TVM_REGISTER_API
(
"_raw_ptr"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
CHECK
(
args
[
0
].
type_code
()
==
kNodeHandle
);
*
ret
=
reinterpret_cast
<
int64_t
>
(
args
[
0
].
node_sptr
().
get
());
});
TVM_REGISTER_API
(
_save_json
)
TVM_REGISTER_API
(
"_save_json"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
SaveJSON
(
args
[
0
]);
});
TVM_REGISTER_API
(
_load_json
)
TVM_REGISTER_API
(
"_load_json"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
NodeRef
(
LoadJSON_
(
args
[
0
]));
});
...
...
src/api/api_codegen.cc
View file @
181edb4a
...
...
@@ -12,7 +12,7 @@
namespace
tvm
{
namespace
codegen
{
TVM_REGISTER_API
(
_codegen__Build
)
TVM_REGISTER_API
(
"codegen._Build"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
if
(
args
[
0
].
IsNodeType
<
LoweredFunc
>
())
{
*
ret
=
Build
({
args
[
0
]},
args
[
1
]);
...
...
@@ -21,7 +21,7 @@ TVM_REGISTER_API(_codegen__Build)
}
});
TVM_REGISTER_API
(
_codegen__Enabled
)
TVM_REGISTER_API
(
"codegen._Enabled"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
TargetEnabled
(
args
[
0
]);
});
...
...
src/api/api_ir.cc
View file @
181edb4a
...
...
@@ -11,12 +11,12 @@
namespace
tvm
{
namespace
ir
{
TVM_REGISTER_API
(
_Var
)
TVM_REGISTER_API
(
"_Var"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
Variable
::
make
(
args
[
1
],
args
[
0
]);
});
TVM_REGISTER_API
(
_make_For
)
TVM_REGISTER_API
(
"make.For"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
For
::
make
(
args
[
0
],
args
[
1
],
...
...
@@ -26,7 +26,7 @@ TVM_REGISTER_API(_make_For)
args
[
5
]);
});
TVM_REGISTER_API
(
_make_Realize
)
TVM_REGISTER_API
(
"make.Realize"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
Realize
::
make
(
args
[
0
],
args
[
1
],
...
...
@@ -37,7 +37,7 @@ TVM_REGISTER_API(_make_Realize)
});
TVM_REGISTER_API
(
_make_Call
)
TVM_REGISTER_API
(
"make.Call"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
Call
::
make
(
args
[
0
],
args
[
1
],
...
...
@@ -47,7 +47,7 @@ TVM_REGISTER_API(_make_Call)
args
[
5
]);
});
TVM_REGISTER_API
(
_make_Allocate
)
TVM_REGISTER_API
(
"make.Allocate"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
Allocate
::
make
(
args
[
0
],
args
[
1
],
...
...
@@ -58,31 +58,31 @@ TVM_REGISTER_API(_make_Allocate)
// make from two arguments
#define REGISTER_MAKE1(Node) \
TVM_REGISTER_API(
_make_## Node)
\
TVM_REGISTER_API(
"make."#Node)
\
.set_body([](TVMArgs args, TVMRetValue *ret) { \
*ret = Node::make(args[0]); \
}) \
#define REGISTER_MAKE2(Node) \
TVM_REGISTER_API(
_make_## Node)
\
.set_body([](TVMArgs args, TVMRetValue *ret) { \
TVM_REGISTER_API(
"make."#Node)
\
.set_body([](TVMArgs args, TVMRetValue *ret) {
\
*ret = Node::make(args[0], args[1]); \
}) \
#define REGISTER_MAKE3(Node) \
TVM_REGISTER_API(
_make_## Node)
\
TVM_REGISTER_API(
"make."#Node)
\
.set_body([](TVMArgs args, TVMRetValue *ret) { \
*ret = Node::make(args[0], args[1], args[2]); \
}) \
#define REGISTER_MAKE4(Node) \
TVM_REGISTER_API(
_make_## Node)
\
TVM_REGISTER_API(
"make."#Node)
\
.set_body([](TVMArgs args, TVMRetValue *ret) { \
*ret = Node::make(args[0], args[1], args[2], args[3]); \
}) \
*ret = Node::make(args[0], args[1], args[2], args[3]);
\
})
\
#define REGISTER_MAKE_BINARY_OP(Node) \
TVM_REGISTER_API(
_make_## Node)
\
TVM_REGISTER_API(
"make."#Node)
\
.set_body([](TVMArgs args, TVMRetValue *ret) { \
Expr a = args[0], b = args[1]; \
match_types(a, b); \
...
...
src/api/api_lang.cc
View file @
181edb4a
...
...
@@ -13,7 +13,7 @@
namespace
tvm
{
TVM_REGISTER_API
(
_const
)
TVM_REGISTER_API
(
"_const"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
if
(
args
[
0
].
type_code
()
==
kInt
)
{
*
ret
=
make_const
(
args
[
1
],
args
[
0
].
operator
int64_t
());
...
...
@@ -25,13 +25,13 @@ TVM_REGISTER_API(_const)
});
TVM_REGISTER_API
(
_str
)
TVM_REGISTER_API
(
"_str"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
ir
::
StringImm
::
make
(
args
[
0
]);
});
TVM_REGISTER_API
(
_Array
)
TVM_REGISTER_API
(
"_Array"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
std
::
vector
<
std
::
shared_ptr
<
Node
>
>
data
;
for
(
int
i
=
0
;
i
<
args
.
size
();
++
i
)
{
...
...
@@ -42,7 +42,7 @@ TVM_REGISTER_API(_Array)
*
ret
=
node
;
});
TVM_REGISTER_API
(
_ArrayGetItem
)
TVM_REGISTER_API
(
"_ArrayGetItem"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
int64_t
i
=
args
[
1
];
auto
&
sptr
=
args
[
0
].
node_sptr
();
...
...
@@ -53,7 +53,7 @@ TVM_REGISTER_API(_ArrayGetItem)
*
ret
=
n
->
data
[
static_cast
<
size_t
>
(
i
)];
});
TVM_REGISTER_API
(
_ArraySize
)
TVM_REGISTER_API
(
"_ArraySize"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
auto
&
sptr
=
args
[
0
].
node_sptr
();
CHECK
(
sptr
->
is_type
<
ArrayNode
>
());
...
...
@@ -61,7 +61,7 @@ TVM_REGISTER_API(_ArraySize)
static_cast
<
const
ArrayNode
*>
(
sptr
.
get
())
->
data
.
size
());
});
TVM_REGISTER_API
(
_Map
)
TVM_REGISTER_API
(
"_Map"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
CHECK_EQ
(
args
.
size
()
%
2
,
0
);
MapNode
::
ContainerType
data
;
...
...
@@ -78,7 +78,7 @@ TVM_REGISTER_API(_Map)
*
ret
=
node
;
});
TVM_REGISTER_API
(
_MapSize
)
TVM_REGISTER_API
(
"_MapSize"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
auto
&
sptr
=
args
[
0
].
node_sptr
();
CHECK
(
sptr
->
is_type
<
MapNode
>
());
...
...
@@ -86,7 +86,7 @@ TVM_REGISTER_API(_MapSize)
*
ret
=
static_cast
<
int64_t
>
(
n
->
data
.
size
());
});
TVM_REGISTER_API
(
_MapGetItem
)
TVM_REGISTER_API
(
"_MapGetItem"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
CHECK
(
args
[
0
].
type_code
()
==
kNodeHandle
);
CHECK
(
args
[
1
].
type_code
()
==
kNodeHandle
);
...
...
@@ -99,7 +99,7 @@ TVM_REGISTER_API(_MapGetItem)
*
ret
=
(
*
it
).
second
;
});
TVM_REGISTER_API
(
_MapCount
)
TVM_REGISTER_API
(
"_MapCount"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
CHECK
(
args
[
0
].
type_code
()
==
kNodeHandle
);
CHECK
(
args
[
1
].
type_code
()
==
kNodeHandle
);
...
...
@@ -110,7 +110,7 @@ TVM_REGISTER_API(_MapCount)
n
->
data
.
count
(
args
[
1
].
node_sptr
()));
});
TVM_REGISTER_API
(
_MapItems
)
TVM_REGISTER_API
(
"_MapItems"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
auto
&
sptr
=
args
[
0
].
node_sptr
();
CHECK
(
sptr
->
is_type
<
MapNode
>
());
...
...
@@ -123,7 +123,7 @@ TVM_REGISTER_API(_MapItems)
*
ret
=
rkvs
;
});
TVM_REGISTER_API
(
Range
)
TVM_REGISTER_API
(
"Range"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
if
(
args
.
size
()
==
1
)
{
*
ret
=
Range
(
0
,
args
[
0
]);
...
...
@@ -132,7 +132,7 @@ TVM_REGISTER_API(Range)
}
});
TVM_REGISTER_API
(
_Buffer
)
TVM_REGISTER_API
(
"_Buffer"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
BufferNode
::
make
(
args
[
0
],
args
[
1
],
...
...
@@ -143,7 +143,7 @@ TVM_REGISTER_API(_Buffer)
args
[
6
]);
});
TVM_REGISTER_API
(
_Tensor
)
TVM_REGISTER_API
(
"_Tensor"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
TensorNode
::
make
(
args
[
0
],
args
[
1
],
...
...
@@ -151,32 +151,32 @@ TVM_REGISTER_API(_Tensor)
args
[
3
]);
});
TVM_REGISTER_API
(
_TensorEqual
)
TVM_REGISTER_API
(
"_TensorEqual"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
args
[
0
].
operator
Tensor
()
==
args
[
1
].
operator
Tensor
();
});
TVM_REGISTER_API
(
_TensorHash
)
TVM_REGISTER_API
(
"_TensorHash"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
static_cast
<
int64_t
>
(
std
::
hash
<
Tensor
>
()(
args
[
0
].
operator
Tensor
()));
});
TVM_REGISTER_API
(
_Placeholder
)
TVM_REGISTER_API
(
"_Placeholder"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
placeholder
(
args
[
0
],
args
[
1
],
args
[
2
]);
});
TVM_REGISTER_API
(
_ComputeOp
)
TVM_REGISTER_API
(
"_ComputeOp"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
ComputeOpNode
::
make
(
args
[
0
],
args
[
1
],
args
[
2
]);
});
TVM_REGISTER_API
(
_ScanOp
)
TVM_REGISTER_API
(
"_ScanOp"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
ScanOpNode
::
make
(
args
[
0
],
args
[
1
],
...
...
@@ -186,7 +186,7 @@ TVM_REGISTER_API(_ScanOp)
args
[
5
]);
});
TVM_REGISTER_API
(
_ExternOp
)
TVM_REGISTER_API
(
"_ExternOp"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
ExternOpNode
::
make
(
args
[
0
],
args
[
1
],
...
...
@@ -195,18 +195,18 @@ TVM_REGISTER_API(_ExternOp)
args
[
4
]);
});
TVM_REGISTER_API
(
_OpGetOutput
)
TVM_REGISTER_API
(
"_OpGetOutput"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
args
[
0
].
operator
Operation
().
output
(
static_cast
<
size_t
>
(
args
[
1
].
operator
int64_t
()));
});
TVM_REGISTER_API
(
_OpNumOutputs
)
TVM_REGISTER_API
(
"_OpNumOutputs"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
args
[
0
].
operator
Operation
()
->
num_outputs
();
});
TVM_REGISTER_API
(
_IterVar
)
TVM_REGISTER_API
(
"_IterVar"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
IterVarNode
::
make
(
args
[
0
],
args
[
1
],
...
...
@@ -214,24 +214,24 @@ TVM_REGISTER_API(_IterVar)
args
[
3
]);
});
TVM_REGISTER_API
(
_Schedule
)
TVM_REGISTER_API
(
"_Schedule"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
Schedule
(
args
[
0
].
operator
Array
<
Operation
>
());
});
TVM_REGISTER_API
(
_StageSetScope
)
TVM_REGISTER_API
(
"_StageSetScope"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
args
[
0
].
operator
Stage
()
.
set_scope
(
args
[
1
]);
});
TVM_REGISTER_API
(
_StageBind
)
TVM_REGISTER_API
(
"_StageBind"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
args
[
0
].
operator
Stage
()
.
bind
(
args
[
1
],
args
[
2
]);
});
TVM_REGISTER_API
(
_StageSplitByFactor
)
TVM_REGISTER_API
(
"_StageSplitByFactor"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
IterVar
outer
,
inner
;
args
[
0
].
operator
Stage
()
...
...
@@ -239,7 +239,7 @@ TVM_REGISTER_API(_StageSplitByFactor)
*
ret
=
Array
<
IterVar
>
({
outer
,
inner
});
});
TVM_REGISTER_API
(
_StageSplitByNParts
)
TVM_REGISTER_API
(
"_StageSplitByNParts"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
IterVar
outer
,
inner
;
args
[
0
].
operator
Stage
()
...
...
@@ -247,7 +247,7 @@ TVM_REGISTER_API(_StageSplitByNParts)
*
ret
=
Array
<
IterVar
>
({
outer
,
inner
});
});
TVM_REGISTER_API
(
_StageFuse
)
TVM_REGISTER_API
(
"_StageFuse"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
IterVar
fused
;
args
[
0
].
operator
Stage
()
...
...
@@ -255,31 +255,31 @@ TVM_REGISTER_API(_StageFuse)
*
ret
=
fused
;
});
TVM_REGISTER_API
(
_StageComputeAt
)
TVM_REGISTER_API
(
"_StageComputeAt"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
args
[
0
].
operator
Stage
()
.
compute_at
(
args
[
1
],
args
[
2
]);
});
TVM_REGISTER_API
(
_StageComputeInline
)
TVM_REGISTER_API
(
"_StageComputeInline"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
args
[
0
].
operator
Stage
()
.
compute_inline
();
});
TVM_REGISTER_API
(
_StageComputeRoot
)
TVM_REGISTER_API
(
"_StageComputeRoot"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
args
[
0
].
operator
Stage
()
.
compute_root
();
});
TVM_REGISTER_API
(
_StageReorder
)
TVM_REGISTER_API
(
"_StageReorder"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
args
[
0
].
operator
Stage
()
.
reorder
(
args
[
1
]);
});
TVM_REGISTER_API
(
_StageTile
)
TVM_REGISTER_API
(
"_StageTile"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
IterVar
x_outer
,
y_outer
,
x_inner
,
y_inner
;
args
[
0
].
operator
Stage
()
...
...
@@ -290,55 +290,55 @@ TVM_REGISTER_API(_StageTile)
*
ret
=
Array
<
IterVar
>
({
x_outer
,
y_outer
,
x_inner
,
y_inner
});
});
TVM_REGISTER_API
(
_StageEnvThreads
)
TVM_REGISTER_API
(
"_StageEnvThreads"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
args
[
0
].
operator
Stage
()
.
env_threads
(
args
[
1
]);
});
TVM_REGISTER_API
(
_StageUnroll
)
TVM_REGISTER_API
(
"_StageUnroll"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
args
[
0
].
operator
Stage
()
.
unroll
(
args
[
1
]);
});
TVM_REGISTER_API
(
_StageVectorize
)
TVM_REGISTER_API
(
"_StageVectorize"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
args
[
0
].
operator
Stage
()
.
vectorize
(
args
[
1
]);
});
TVM_REGISTER_API
(
_StageParallel
)
TVM_REGISTER_API
(
"_StageParallel"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
args
[
0
].
operator
Stage
()
.
parallel
(
args
[
1
]);
});
TVM_REGISTER_API
(
_ScheduleNormalize
)
TVM_REGISTER_API
(
"_ScheduleNormalize"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
args
[
0
].
operator
Schedule
()
.
normalize
();
});
TVM_REGISTER_API
(
_ScheduleCreateGroup
)
TVM_REGISTER_API
(
"_ScheduleCreateGroup"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
args
[
0
].
operator
Schedule
()
.
create_group
(
args
[
1
],
args
[
2
],
args
[
3
]);
});
TVM_REGISTER_API
(
_ScheduleCacheRead
)
TVM_REGISTER_API
(
"_ScheduleCacheRead"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
args
[
0
].
operator
Schedule
()
.
cache_read
(
args
[
1
],
args
[
2
],
args
[
3
]);
});
TVM_REGISTER_API
(
_ScheduleCacheWrite
)
TVM_REGISTER_API
(
"_ScheduleCacheWrite"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
args
[
0
].
operator
Schedule
()
.
cache_write
(
args
[
1
],
args
[
2
]);
});
TVM_REGISTER_API
(
_ScheduleRFactor
)
TVM_REGISTER_API
(
"_ScheduleRFactor"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
args
[
0
].
operator
Schedule
()
.
rfactor
(
args
[
1
],
args
[
2
]);
...
...
src/api/api_pass.cc
View file @
181edb4a
...
...
@@ -12,7 +12,7 @@
namespace
tvm
{
namespace
ir
{
TVM_REGISTER_API
(
_pass_Simplify
)
TVM_REGISTER_API
(
"ir_pass.Simplify"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
if
(
args
[
0
].
IsNodeType
<
Stmt
>
())
{
*
ret
=
Simplify
(
args
[
0
].
operator
Stmt
());
...
...
@@ -21,7 +21,7 @@ TVM_REGISTER_API(_pass_Simplify)
}
});
TVM_REGISTER_API
(
_pass_Equal
)
TVM_REGISTER_API
(
"ir_pass.Equal"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
if
(
args
[
0
].
IsNodeType
<
Stmt
>
())
{
*
ret
=
Equal
(
args
[
0
].
operator
Stmt
(),
args
[
1
].
operator
Stmt
());
...
...
@@ -30,7 +30,7 @@ TVM_REGISTER_API(_pass_Equal)
}
});
TVM_REGISTER_API
(
_pass_PostOrderVisit
)
TVM_REGISTER_API
(
"ir_pass.PostOrderVisit"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
PackedFunc
f
=
args
[
1
];
ir
::
PostOrderVisit
(
args
[
0
],
[
f
](
const
NodeRef
&
n
)
{
...
...
@@ -40,19 +40,19 @@ TVM_REGISTER_API(_pass_PostOrderVisit)
// make from two arguments
#define REGISTER_PASS1(PassName) \
TVM_REGISTER_API(
_pass_## PassName)
\
TVM_REGISTER_API(
"ir_pass."#PassName)
\
.set_body([](TVMArgs args, TVMRetValue *ret) { \
*ret = PassName(args[0]); \
}) \
#define REGISTER_PASS2(PassName) \
TVM_REGISTER_API(
_pass_## PassName)
\
TVM_REGISTER_API(
"ir_pass."#PassName)
\
.set_body([](TVMArgs args, TVMRetValue *ret) { \
*ret = PassName(args[0], args[1]); \
}) \
#define REGISTER_PASS4(PassName) \
TVM_REGISTER_API(
_pass_## PassName)
\
TVM_REGISTER_API(
"ir_pass."#PassName)
\
.set_body([](TVMArgs args, TVMRetValue *ret) { \
*ret = PassName(args[0], args[1], args[2], args[3]); \
}) \
...
...
src/api/api_schedule.cc
View file @
181edb4a
...
...
@@ -13,19 +13,19 @@
namespace
tvm
{
namespace
schedule
{
TVM_REGISTER_API
(
_schedule_AutoInlineElemWise
)
TVM_REGISTER_API
(
"schedule.AutoInlineElemWise"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
AutoInlineElemWise
(
args
[
0
]);
});
#define REGISTER_SCHEDULE_PASS1(PassName) \
TVM_REGISTER_API(
_schedule_## PassName)
\
TVM_REGISTER_API(
"schedule."#PassName)
\
.set_body([](TVMArgs args, TVMRetValue *ret) { \
*ret = PassName(args[0]); \
}) \
#define REGISTER_SCHEDULE_PASS2(PassName) \
TVM_REGISTER_API(
_schedule_## PassName)
\
TVM_REGISTER_API(
"schedule."#PassName)
\
.set_body([](TVMArgs args, TVMRetValue *ret) { \
*ret = PassName(args[0], args[1]); \
}) \
...
...
src/api/c_api.cc
View file @
181edb4a
...
...
@@ -103,6 +103,23 @@ int TVMNodeFree(NodeHandle handle) {
API_END
();
}
int
TVMCbArgToReturn
(
TVMValue
*
value
,
int
code
)
{
API_BEGIN
();
tvm
::
runtime
::
TVMRetValue
rv
;
rv
=
tvm
::
runtime
::
TVMArgValue
(
*
value
,
code
);
int
tcode
;
rv
.
MoveToCHost
(
value
,
&
tcode
);
CHECK_EQ
(
tcode
,
code
);
API_END
();
}
int
TVMNodeDupe
(
NodeHandle
handle
,
NodeHandle
*
out_handle
)
{
API_BEGIN
();
*
out_handle
=
new
TVMAPINode
(
*
static_cast
<
TVMAPINode
*>
(
handle
));
API_END
();
}
int
TVMNodeGetAttr
(
NodeHandle
handle
,
const
char
*
key
,
TVMValue
*
ret_val
,
...
...
src/codegen/build_cuda.cc
View file @
181edb4a
...
...
@@ -86,7 +86,7 @@ runtime::Module BuildCUDA(Array<LoweredFunc> funcs) {
return
CUDAModuleCreate
(
ptx
,
fmt
,
fmap
,
code
);
}
TVM_REGISTER_API
(
_codegen_build_cuda
)
TVM_REGISTER_API
(
"codegen.build_cuda"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
rv
)
{
*
rv
=
BuildCUDA
(
args
[
0
]);
});
...
...
src/codegen/build_opencl.cc
View file @
181edb4a
...
...
@@ -41,7 +41,7 @@ runtime::Module BuildOpenCL(Array<LoweredFunc> funcs) {
return
OpenCLModuleCreate
(
code
,
"cl"
,
fmap
);
}
TVM_REGISTER_API
(
_codegen_build_opencl
)
TVM_REGISTER_API
(
"codegen.build_opencl"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
rv
)
{
*
rv
=
BuildOpenCL
(
args
[
0
]);
});
...
...
src/codegen/codegen.cc
View file @
181edb4a
...
...
@@ -17,7 +17,7 @@ runtime::Module Build(const Array<LoweredFunc>& funcs,
if
(
pos
!=
std
::
string
::
npos
)
{
mode
=
mode
.
substr
(
0
,
pos
);
}
std
::
string
build_f_name
=
"
_codegen_
build_"
+
mode
;
std
::
string
build_f_name
=
"
codegen.
build_"
+
mode
;
const
PackedFunc
*
bf
=
runtime
::
Registry
::
Get
(
build_f_name
);
CHECK
(
bf
!=
nullptr
)
...
...
@@ -27,7 +27,7 @@ runtime::Module Build(const Array<LoweredFunc>& funcs,
}
bool
TargetEnabled
(
const
std
::
string
&
target
)
{
std
::
string
build_f_name
=
"
_codegen_
build_"
+
target
;
std
::
string
build_f_name
=
"
codegen.
build_"
+
target
;
return
runtime
::
Registry
::
Get
(
build_f_name
)
!=
nullptr
;
}
...
...
src/codegen/llvm/llvm_module.cc
View file @
181edb4a
...
...
@@ -150,7 +150,7 @@ class LLVMModuleNode : public runtime::ModuleNode {
std
::
shared_ptr
<
llvm
::
LLVMContext
>
ctx_
;
};
TVM_REGISTER_API
(
_codegen_build_llvm
)
TVM_REGISTER_API
(
"codegen.build_llvm"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
rv
)
{
std
::
shared_ptr
<
LLVMModuleNode
>
n
=
std
::
make_shared
<
LLVMModuleNode
>
();
n
->
Init
(
args
[
0
],
args
[
1
]);
...
...
src/codegen/stack_vm/stack_vm_module.cc
View file @
181edb4a
...
...
@@ -69,7 +69,7 @@ class StackVMModuleNode : public runtime::ModuleNode {
std
::
unordered_map
<
std
::
string
,
StackVM
>
fmap_
;
};
TVM_REGISTER_API
(
_codegen_build_stackvm
)
TVM_REGISTER_API
(
"codegen.build_stackvm"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
rv
)
{
*
rv
=
StackVMModuleNode
::
Build
(
args
[
0
]);
});
...
...
src/codegen/verilog/verilog_module.cc
View file @
181edb4a
...
...
@@ -86,7 +86,7 @@ class VerilogModuleNode : public runtime::ModuleNode {
std
::
string
fmt_
;
};
TVM_REGISTER_API
(
_codegen_build_verilog
)
TVM_REGISTER_API
(
"codegen.build_verilog"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
rv
)
{
std
::
shared_ptr
<
VerilogModuleNode
>
n
=
std
::
make_shared
<
VerilogModuleNode
>
();
...
...
src/codegen/verilog/vpi_device_api.cc
View file @
181edb4a
...
...
@@ -385,7 +385,7 @@ class VPIWriteMemMap : public VPIMemMapBase {
VPIHandle
enable_
;
};
TVM_REGISTER_GLOBAL
(
_device_api_vpi
)
TVM_REGISTER_GLOBAL
(
"device_api.vpi"
)
.
set_body
([](
runtime
::
TVMArgs
args
,
runtime
::
TVMRetValue
*
rv
)
{
runtime
::
DeviceAPI
*
ptr
=
VPIDeviceAPI
::
Global
();
*
rv
=
static_cast
<
void
*>
(
ptr
);
...
...
@@ -403,13 +403,13 @@ void TVMVPIHook(runtime::TVMArgs args, runtime::TVMRetValue* rv) {
*
rv
=
pf
;
}
TVM_REGISTER_GLOBAL
(
_vpi_module_tvm_vpi_mem_interface
)
TVM_REGISTER_GLOBAL
(
"_vpi_module_tvm_vpi_mem_interface"
)
.
set_body
(
TVMVPIHook
<
VPIMemoryInterface
>
);
TVM_REGISTER_GLOBAL
(
_vpi_module_tvm_vpi_read_mmap
)
TVM_REGISTER_GLOBAL
(
"_vpi_module_tvm_vpi_read_mmap"
)
.
set_body
(
TVMVPIHook
<
VPIReadMemMap
>
);
TVM_REGISTER_GLOBAL
(
_vpi_module_tvm_vpi_write_mmap
)
TVM_REGISTER_GLOBAL
(
"_vpi_module_tvm_vpi_write_mmap"
)
.
set_body
(
TVMVPIHook
<
VPIWriteMemMap
>
);
}
// namespace codegen
...
...
src/codegen/verilog/vpi_session.cc
View file @
181edb4a
...
...
@@ -212,47 +212,47 @@ VPIHandle VPIHandle::operator[](const std::string& name) const {
}
// API registration
TVM_REGISTER_API
(
_vpi_SessMake
)
TVM_REGISTER_API
(
"_vpi_SessMake"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
VPISession
::
make
(
args
[
0
],
args
[
1
]);
});
TVM_REGISTER_API
(
_vpi_SessGetHandleByName
)
TVM_REGISTER_API
(
"_vpi_SessGetHandleByName"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
args
[
0
].
operator
VPISession
().
operator
[](
args
[
1
]);
});
TVM_REGISTER_API
(
_vpi_SessYield
)
TVM_REGISTER_API
(
"_vpi_SessYield"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
args
[
0
].
operator
VPISession
().
yield
();
});
TVM_REGISTER_API
(
_vpi_SessShutdown
)
TVM_REGISTER_API
(
"_vpi_SessShutdown"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
args
[
0
].
operator
VPISession
().
shutdown
();
});
TVM_REGISTER_API
(
_vpi_HandlePutInt
)
TVM_REGISTER_API
(
"_vpi_HandlePutInt"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
args
[
0
].
operator
VPIHandle
().
put_int
(
args
[
1
]);
});
TVM_REGISTER_API
(
_vpi_HandleGetInt
)
TVM_REGISTER_API
(
"_vpi_HandleGetInt"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
args
[
0
].
operator
VPIHandle
().
get_int
();
});
TVM_REGISTER_API
(
_vpi_HandleGetName
)
TVM_REGISTER_API
(
"_vpi_HandleGetName"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
args
[
0
].
operator
VPIHandle
().
name
();
});
TVM_REGISTER_API
(
_vpi_HandleGetSize
)
TVM_REGISTER_API
(
"_vpi_HandleGetSize"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
args
[
0
].
operator
VPIHandle
().
size
();
});
TVM_REGISTER_API
(
_vpi_HandleGetHandleByName
)
TVM_REGISTER_API
(
"_vpi_HandleGetHandleByName"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
args
[
0
].
operator
VPIHandle
().
operator
[](
args
[
1
]);
});
...
...
src/runtime/c_runtime_api.cc
View file @
181edb4a
...
...
@@ -46,7 +46,7 @@ class DeviceAPIManager {
if
(
api_
[
type
]
!=
nullptr
)
return
api_
[
type
];
std
::
lock_guard
<
std
::
mutex
>
lock
(
mutex_
);
if
(
api_
[
type
]
!=
nullptr
)
return
api_
[
type
];
std
::
string
factory
=
"
_device_api_
"
+
DeviceName
(
type
);
std
::
string
factory
=
"
device_api.
"
+
DeviceName
(
type
);
auto
*
f
=
Registry
::
Get
(
factory
);
CHECK
(
f
!=
nullptr
)
<<
"Device API "
<<
DeviceName
(
type
)
<<
" is not enabled."
;
...
...
src/runtime/cpu_device_api.cc
View file @
181edb4a
...
...
@@ -46,7 +46,7 @@ class CPUDeviceAPI : public DeviceAPI {
}
};
TVM_REGISTER_GLOBAL
(
_device_api_cpu
)
TVM_REGISTER_GLOBAL
(
"device_api.cpu"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
rv
)
{
static
CPUDeviceAPI
inst
;
DeviceAPI
*
ptr
=
&
inst
;
...
...
src/runtime/cuda/cuda_device_api.cc
View file @
181edb4a
...
...
@@ -77,7 +77,7 @@ class CUDADeviceAPI : public DeviceAPI {
}
};
TVM_REGISTER_GLOBAL
(
_device_api_gpu
)
TVM_REGISTER_GLOBAL
(
"device_api.gpu"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
rv
)
{
static
CUDADeviceAPI
inst
;
DeviceAPI
*
ptr
=
&
inst
;
...
...
src/runtime/cuda/cuda_module.cc
View file @
181edb4a
...
...
@@ -281,12 +281,12 @@ Module CUDAModuleLoad(const std::string& file_name,
return
CUDAModuleCreate
(
data
,
fmt
,
fmap
,
std
::
string
());
}
TVM_REGISTER_GLOBAL
(
_module_loadfile_cubin
)
TVM_REGISTER_GLOBAL
(
"module.loadfile_cubin"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
rv
)
{
*
rv
=
CUDAModuleLoad
(
args
[
0
],
args
[
1
]);
});
TVM_REGISTER_GLOBAL
(
_module_loadfile_ptx
)
TVM_REGISTER_GLOBAL
(
"module.loadfile_ptx"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
rv
)
{
*
rv
=
CUDAModuleLoad
(
args
[
0
],
args
[
1
]);
});
...
...
src/runtime/dso_module.cc
View file @
181edb4a
...
...
@@ -110,7 +110,7 @@ class DSOModuleNode : public ModuleNode {
#endif
};
TVM_REGISTER_GLOBAL
(
_module_loadfile_so
)
TVM_REGISTER_GLOBAL
(
"module.loadfile_so"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
rv
)
{
std
::
shared_ptr
<
DSOModuleNode
>
n
=
std
::
make_shared
<
DSOModuleNode
>
();
n
->
Init
(
args
[
0
]);
...
...
src/runtime/module.cc
View file @
181edb4a
...
...
@@ -53,7 +53,7 @@ Module Module::LoadFromFile(const std::string& file_name,
if
(
fmt
==
"dll"
||
fmt
==
"dylib"
||
fmt
==
"dso"
)
{
fmt
=
"so"
;
}
std
::
string
load_f_name
=
"
_module_
loadfile_"
+
fmt
;
std
::
string
load_f_name
=
"
module.
loadfile_"
+
fmt
;
const
PackedFunc
*
f
=
Registry
::
Get
(
load_f_name
);
CHECK
(
f
!=
nullptr
)
<<
"Loader of "
<<
format
<<
"("
...
...
@@ -88,48 +88,48 @@ bool RuntimeEnabled(const std::string& target) {
if
(
target
==
"cpu"
)
{
return
true
;
}
else
if
(
target
==
"cuda"
||
target
==
"gpu"
)
{
load_f_name
=
"
_module_
loadfile_ptx"
;
load_f_name
=
"
module.
loadfile_ptx"
;
}
else
if
(
target
==
"cl"
||
target
==
"opencl"
)
{
load_f_name
=
"
_module_
loadfile_cl"
;
load_f_name
=
"
module.
loadfile_cl"
;
}
else
{
LOG
(
FATAL
)
<<
"Unknown optional runtime "
<<
target
;
}
return
runtime
::
Registry
::
Get
(
load_f_name
)
!=
nullptr
;
}
TVM_REGISTER_GLOBAL
(
_module__Enabled
)
TVM_REGISTER_GLOBAL
(
"module._Enabled"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
RuntimeEnabled
(
args
[
0
]);
});
TVM_REGISTER_GLOBAL
(
_module__GetSource
)
TVM_REGISTER_GLOBAL
(
"module._GetSource"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
args
[
0
].
operator
Module
()
->
GetSource
(
args
[
1
]);
});
TVM_REGISTER_GLOBAL
(
_module__ImportsSize
)
TVM_REGISTER_GLOBAL
(
"module._ImportsSize"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
static_cast
<
int64_t
>
(
args
[
0
].
operator
Module
()
->
imports
().
size
());
});
TVM_REGISTER_GLOBAL
(
_module__GetImport
)
TVM_REGISTER_GLOBAL
(
"module._GetImport"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
args
[
0
].
operator
Module
()
->
imports
().
at
(
args
[
1
].
operator
int
());
});
TVM_REGISTER_GLOBAL
(
_module__GetTypeKey
)
TVM_REGISTER_GLOBAL
(
"module._GetTypeKey"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
std
::
string
(
args
[
0
].
operator
Module
()
->
type_key
());
});
TVM_REGISTER_GLOBAL
(
_module__LoadFromFile
)
TVM_REGISTER_GLOBAL
(
"module._LoadFromFile"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
Module
::
LoadFromFile
(
args
[
0
],
args
[
1
]);
});
TVM_REGISTER_GLOBAL
(
_module__SaveToFile
)
TVM_REGISTER_GLOBAL
(
"module._SaveToFile"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
args
[
0
].
operator
Module
()
->
SaveToFile
(
args
[
1
],
args
[
2
]);
...
...
src/runtime/opencl/opencl_device_api.cc
View file @
181edb4a
...
...
@@ -189,10 +189,10 @@ bool InitOpenCL(TVMArgs args, TVMRetValue* rv) {
return
true
;
}
TVM_REGISTER_GLOBAL
(
_module_init_opencl
)
TVM_REGISTER_GLOBAL
(
"module.init_opencl"
)
.
set_body
(
InitOpenCL
);
TVM_REGISTER_GLOBAL
(
_device_api_opencl
)
TVM_REGISTER_GLOBAL
(
"device_api.opencl"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
rv
)
{
DeviceAPI
*
ptr
=
OpenCLWorkspace
::
Global
();
*
rv
=
static_cast
<
void
*>
(
ptr
);
...
...
src/runtime/opencl/opencl_module.cc
View file @
181edb4a
...
...
@@ -314,12 +314,12 @@ Module OpenCLModuleLoad(const std::string& file_name,
return
OpenCLModuleCreate
(
data
,
fmt
,
fmap
);
}
TVM_REGISTER_GLOBAL
(
_module_loadfile_cl
)
TVM_REGISTER_GLOBAL
(
"module.loadfile_cl"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
rv
)
{
*
rv
=
OpenCLModuleLoad
(
args
[
0
],
args
[
1
]);
});
TVM_REGISTER_GLOBAL
(
_module_loadfile_clbin
)
TVM_REGISTER_GLOBAL
(
"module.loadfile_clbin"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
rv
)
{
*
rv
=
OpenCLModuleLoad
(
args
[
0
],
args
[
1
]);
});
...
...
tests/python/unittest/test_runtime_packed_func.py
View file @
181edb4a
...
...
@@ -14,6 +14,25 @@ def test_get_global():
y
=
f
(
*
targs
)
assert
y
==
10
def
test_get_callback_with_node
():
x
=
tvm
.
convert
(
10
)
def
test
(
y
):
assert
y
.
handle
!=
x
.
handle
return
y
f2
=
tvm
.
convert
(
test
)
# register into global function table
@tvm.register_func
def
my_callback_with_node
(
y
,
f
):
assert
y
==
x
return
f
(
y
)
# get it out from global function table
f
=
tvm
.
get_global_func
(
"my_callback_with_node"
)
assert
isinstance
(
f
,
tvm
.
Function
)
y
=
f
(
x
,
f2
)
assert
(
y
.
value
==
10
)
def
test_return_func
():
def
addy
(
y
):
...
...
@@ -45,6 +64,7 @@ def test_byte_array():
f
(
a
)
if
__name__
==
"__main__"
:
test_get_callback_with_node
()
test_convert
()
test_get_global
()
test_return_func
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment