Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
181edb4a
Commit
181edb4a
authored
Apr 21, 2017
by
Tianqi Chen
Committed by
GitHub
Apr 21, 2017
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[LANG] Change namespace convention to dot (#100)
parent
97c67e53
Show whitespace changes
Inline
Side-by-side
Showing
31 changed files
with
208 additions
and
145 deletions
+208
-145
include/tvm/base.h
+1
-0
include/tvm/c_api.h
+12
-0
include/tvm/runtime/registry.h
+9
-6
python/tvm/_ctypes/_function.py
+19
-19
python/tvm/_ctypes/_node.py
+6
-2
python/tvm/_ctypes/_types.py
+7
-1
src/api/api_arith.cc
+9
-9
src/api/api_base.cc
+5
-5
src/api/api_codegen.cc
+2
-2
src/api/api_ir.cc
+10
-10
src/api/api_lang.cc
+42
-42
src/api/api_pass.cc
+6
-6
src/api/api_schedule.cc
+3
-3
src/api/c_api.cc
+17
-0
src/codegen/build_cuda.cc
+1
-1
src/codegen/build_opencl.cc
+1
-1
src/codegen/codegen.cc
+2
-2
src/codegen/llvm/llvm_module.cc
+1
-1
src/codegen/stack_vm/stack_vm_module.cc
+1
-1
src/codegen/verilog/verilog_module.cc
+1
-1
src/codegen/verilog/vpi_device_api.cc
+4
-4
src/codegen/verilog/vpi_session.cc
+9
-9
src/runtime/c_runtime_api.cc
+1
-1
src/runtime/cpu_device_api.cc
+1
-1
src/runtime/cuda/cuda_device_api.cc
+1
-1
src/runtime/cuda/cuda_module.cc
+2
-2
src/runtime/dso_module.cc
+1
-1
src/runtime/module.cc
+10
-10
src/runtime/opencl/opencl_device_api.cc
+2
-2
src/runtime/opencl/opencl_module.cc
+2
-2
tests/python/unittest/test_runtime_packed_func.py
+20
-0
No files found.
include/tvm/base.h
View file @
181edb4a
...
@@ -12,6 +12,7 @@
...
@@ -12,6 +12,7 @@
#include <string>
#include <string>
#include <memory>
#include <memory>
#include <functional>
#include <functional>
#include "./runtime/registry.h"
namespace
tvm
{
namespace
tvm
{
...
...
include/tvm/c_api.h
View file @
181edb4a
...
@@ -20,6 +20,18 @@ TVM_EXTERN_C {
...
@@ -20,6 +20,18 @@ TVM_EXTERN_C {
typedef
void
*
NodeHandle
;
typedef
void
*
NodeHandle
;
/*!
/*!
* \brief Inplace translate callback argument value to return value.
* This is only needed for non-POD arguments.
*
* \param value The value to be translated.
* \param code The type code to be translated.
* \note This function will do a shallow copy when necessary.
*
* \return 0 when success, -1 when failure happens.
*/
TVM_DLL
int
TVMCbArgToReturn
(
TVMValue
*
value
,
int
code
);
/*!
* \brief free the node handle
* \brief free the node handle
* \param handle The node handle to be freed.
* \param handle The node handle to be freed.
* \return 0 when success, -1 when failure happens
* \return 0 when success, -1 when failure happens
...
...
include/tvm/runtime/registry.h
View file @
181edb4a
...
@@ -81,7 +81,6 @@ class Registry {
...
@@ -81,7 +81,6 @@ class Registry {
friend
struct
Manager
;
friend
struct
Manager
;
};
};
/*! \brief helper macro to supress unused warning */
/*! \brief helper macro to supress unused warning */
#if defined(__GNUC__)
#if defined(__GNUC__)
#define TVM_ATTRIBUTE_UNUSED __attribute__((unused))
#define TVM_ATTRIBUTE_UNUSED __attribute__((unused))
...
@@ -89,19 +88,23 @@ class Registry {
...
@@ -89,19 +88,23 @@ class Registry {
#define TVM_ATTRIBUTE_UNUSED
#define TVM_ATTRIBUTE_UNUSED
#endif
#endif
#define TVM_STR_CONCAT_(__x, __y) __x##__y
#define TVM_STR_CONCAT(__x, __y) TVM_STR_CONCAT_(__x, __y)
#define TVM_FUNC_REG_VAR_DEF \
static TVM_ATTRIBUTE_UNUSED ::tvm::runtime::Registry& __make_ ## TVMOp
/*!
/*!
* \brief Register a function globally.
* \brief Register a function globally.
* \code
* \code
* TVM_REGISTER_GLOBAL(
MyPrint
)
* TVM_REGISTER_GLOBAL(
"MyPrint"
)
* .set_body([](TVMArgs args, TVMRetValue* rv) {
* .set_body([](TVMArgs args, TVMRetValue* rv) {
* // my code.
* });
* });
* \endcode
* \endcode
*/
*/
#define TVM_REGISTER_GLOBAL(OpName) \
#define TVM_REGISTER_GLOBAL(OpName) \
static TVM_ATTRIBUTE_UNUSED ::tvm::runtime::Registry& \
TVM_STR_CONCAT(TVM_FUNC_REG_VAR_DEF, __COUNTER__) = \
__make_TVMRegistry_ ## OpName = \
::tvm::runtime::Registry::Register(OpName)
::tvm::runtime::Registry::Register(#OpName)
}
// namespace runtime
}
// namespace runtime
}
// namespace tvm
}
// namespace tvm
...
...
python/tvm/_ctypes/_function.py
View file @
181edb4a
...
@@ -12,7 +12,7 @@ from .._base import _LIB, check_call
...
@@ -12,7 +12,7 @@ from .._base import _LIB, check_call
from
.._base
import
c_str
,
py_str
,
string_types
from
.._base
import
c_str
,
py_str
,
string_types
from
._types
import
TVMValue
,
TypeCode
,
TVMType
,
TVMByteArray
from
._types
import
TVMValue
,
TypeCode
,
TVMType
,
TVMByteArray
from
._types
import
TVMPackedCFunc
,
TVMCFuncFinalizer
from
._types
import
TVMPackedCFunc
,
TVMCFuncFinalizer
from
._types
import
RETURN_SWITCH
,
C_TO_PY_ARG_SWITCH
from
._types
import
RETURN_SWITCH
,
C_TO_PY_ARG_SWITCH
,
_wrap_arg_func
from
._node
import
NodeBase
,
SliceBase
,
convert_to_node
from
._node
import
NodeBase
,
SliceBase
,
convert_to_node
from
._ndarray
import
NDArrayBase
from
._ndarray
import
NDArrayBase
...
@@ -302,6 +302,10 @@ def _handle_return_func(x):
...
@@ -302,6 +302,10 @@ def _handle_return_func(x):
# setup return handle for function type
# setup return handle for function type
RETURN_SWITCH
[
TypeCode
.
FUNC_HANDLE
]
=
_handle_return_func
RETURN_SWITCH
[
TypeCode
.
FUNC_HANDLE
]
=
_handle_return_func
RETURN_SWITCH
[
TypeCode
.
MODULE_HANDLE
]
=
_return_module
RETURN_SWITCH
[
TypeCode
.
MODULE_HANDLE
]
=
_return_module
C_TO_PY_ARG_SWITCH
[
TypeCode
.
FUNC_HANDLE
]
=
_wrap_arg_func
(
_handle_return_func
,
TypeCode
.
FUNC_HANDLE
)
C_TO_PY_ARG_SWITCH
[
TypeCode
.
MODULE_HANDLE
]
=
_wrap_arg_func
(
_return_module
,
TypeCode
.
MODULE_HANDLE
)
def
register_func
(
func_name
,
f
=
None
):
def
register_func
(
func_name
,
f
=
None
):
...
@@ -415,35 +419,31 @@ def _get_api(f):
...
@@ -415,35 +419,31 @@ def _get_api(f):
return
flocal
(
*
args
)
return
flocal
(
*
args
)
return
my_api_func
return
my_api_func
def
_init_api
(
mod
):
def
_init_api
(
namespace
):
"""Initialize api for a given module name
"""Initialize api for a given module name
mod : str
mod : str
The name of the module.
The name of the module.
"""
"""
module
=
sys
.
modules
[
mod
]
module
=
sys
.
modules
[
namespace
]
namespace_match
=
{
assert
namespace
.
startswith
(
"tvm."
)
"_make_"
:
"tvm.make"
,
prefix
=
namespace
[
4
:]
"_arith_"
:
"tvm.arith"
,
"_pass_"
:
"tvm.ir_pass"
,
"_codegen_"
:
"tvm.codegen"
,
"_module_"
:
"tvm.module"
,
"_schedule_"
:
"tvm.schedule"
}
for
name
in
list_global_func_names
():
for
name
in
list_global_func_names
():
if
prefix
==
"api"
:
fname
=
name
fname
=
name
target
=
"tvm.api"
if
name
.
startswith
(
"_"
):
for
k
,
v
in
namespace_match
.
items
():
if
name
.
startswith
(
k
):
fname
=
name
[
len
(
k
):]
target
=
v
if
target
!=
mod
:
continue
if
mod
==
"tvm.api"
and
name
.
startswith
(
"_"
):
target_module
=
sys
.
modules
[
"tvm._api_internal"
]
target_module
=
sys
.
modules
[
"tvm._api_internal"
]
else
:
else
:
target_module
=
module
target_module
=
module
else
:
if
not
name
.
startswith
(
prefix
):
continue
fname
=
name
[
len
(
prefix
)
+
1
:]
target_module
=
module
if
fname
.
find
(
"."
)
!=
-
1
:
continue
f
=
get_global_func
(
name
)
f
=
get_global_func
(
name
)
ff
=
_get_api
(
f
)
ff
=
_get_api
(
f
)
ff
.
__name__
=
fname
ff
.
__name__
=
fname
...
...
python/tvm/_ctypes/_node.py
View file @
181edb4a
...
@@ -10,7 +10,9 @@ from numbers import Number, Integral
...
@@ -10,7 +10,9 @@ from numbers import Number, Integral
from
.._base
import
_LIB
,
check_call
from
.._base
import
_LIB
,
check_call
from
.._base
import
c_str
,
py_str
,
string_types
from
.._base
import
c_str
,
py_str
,
string_types
from
..
import
_api_internal
from
..
import
_api_internal
from
._types
import
TVMValue
,
TypeCode
,
RETURN_SWITCH
from
._types
import
TVMValue
,
TypeCode
from
._types
import
RETURN_SWITCH
,
C_TO_PY_ARG_SWITCH
,
_wrap_arg_func
NodeHandle
=
ctypes
.
c_void_p
NodeHandle
=
ctypes
.
c_void_p
...
@@ -19,7 +21,7 @@ NODE_TYPE = {
...
@@ -19,7 +21,7 @@ NODE_TYPE = {
}
}
def
_return_node
(
x
):
def
_return_node
(
x
):
"""Return function"""
"""Return
node
function"""
handle
=
x
.
v_handle
handle
=
x
.
v_handle
if
not
isinstance
(
handle
,
NodeHandle
):
if
not
isinstance
(
handle
,
NodeHandle
):
handle
=
NodeHandle
(
handle
)
handle
=
NodeHandle
(
handle
)
...
@@ -35,6 +37,8 @@ def _return_node(x):
...
@@ -35,6 +37,8 @@ def _return_node(x):
RETURN_SWITCH
[
TypeCode
.
NODE_HANDLE
]
=
_return_node
RETURN_SWITCH
[
TypeCode
.
NODE_HANDLE
]
=
_return_node
C_TO_PY_ARG_SWITCH
[
TypeCode
.
NODE_HANDLE
]
=
_wrap_arg_func
(
_return_node
,
TypeCode
.
NODE_HANDLE
)
class
SliceBase
(
object
):
class
SliceBase
(
object
):
...
...
python/tvm/_ctypes/_types.py
View file @
181edb4a
...
@@ -4,7 +4,7 @@ from __future__ import absolute_import as _abs
...
@@ -4,7 +4,7 @@ from __future__ import absolute_import as _abs
import
ctypes
import
ctypes
import
numpy
as
np
import
numpy
as
np
from
.._base
import
py_str
from
.._base
import
py_str
,
check_call
,
_LIB
tvm_shape_index_t
=
ctypes
.
c_int64
tvm_shape_index_t
=
ctypes
.
c_int64
...
@@ -130,6 +130,12 @@ def _return_bytes(x):
...
@@ -130,6 +130,12 @@ def _return_bytes(x):
raise
RuntimeError
(
'memmove failed'
)
raise
RuntimeError
(
'memmove failed'
)
return
res
return
res
def
_wrap_arg_func
(
return_f
,
type_code
):
tcode
=
ctypes
.
c_int
(
type_code
)
def
_wrap_func
(
x
):
check_call
(
_LIB
.
TVMCbArgToReturn
(
ctypes
.
byref
(
x
),
tcode
))
return
return_f
(
x
)
return
_wrap_func
RETURN_SWITCH
=
{
RETURN_SWITCH
=
{
TypeCode
.
INT
:
lambda
x
:
x
.
v_int64
,
TypeCode
.
INT
:
lambda
x
:
x
.
v_int64
,
...
...
src/api/api_arith.cc
View file @
181edb4a
...
@@ -11,49 +11,49 @@
...
@@ -11,49 +11,49 @@
namespace
tvm
{
namespace
tvm
{
namespace
arith
{
namespace
arith
{
TVM_REGISTER_API
(
_arith_intset_single_point
)
TVM_REGISTER_API
(
"arith.intset_single_point"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
IntSet
::
single_point
(
args
[
0
]);
*
ret
=
IntSet
::
single_point
(
args
[
0
]);
});
});
TVM_REGISTER_API
(
_arith_intset_interval
)
TVM_REGISTER_API
(
"arith.intset_interval"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
IntSet
::
interval
(
args
[
0
],
args
[
1
]);
*
ret
=
IntSet
::
interval
(
args
[
0
],
args
[
1
]);
});
});
TVM_REGISTER_API
(
_arith_EvalModular
)
TVM_REGISTER_API
(
"arith.EvalModular"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
EvalModular
(
args
[
0
],
Map
<
Var
,
IntSet
>
());
*
ret
=
EvalModular
(
args
[
0
],
Map
<
Var
,
IntSet
>
());
});
});
TVM_REGISTER_API
(
_arith_DetectLinearEquation
)
TVM_REGISTER_API
(
"arith.DetectLinearEquation"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
DetectLinearEquation
(
args
[
0
],
args
[
1
]);
*
ret
=
DetectLinearEquation
(
args
[
0
],
args
[
1
]);
});
});
TVM_REGISTER_API
(
_arith_DeduceBound
)
TVM_REGISTER_API
(
"arith.DeduceBound"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
DeduceBound
(
args
[
0
],
args
[
1
],
*
ret
=
DeduceBound
(
args
[
0
],
args
[
1
],
args
[
2
].
operator
Map
<
Var
,
IntSet
>
(),
args
[
2
].
operator
Map
<
Var
,
IntSet
>
(),
args
[
3
].
operator
Map
<
Var
,
IntSet
>
());
args
[
3
].
operator
Map
<
Var
,
IntSet
>
());
});
});
TVM_REGISTER_API
(
_IntervalSetGetMin
)
TVM_REGISTER_API
(
"_IntervalSetGetMin"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
args
[
0
].
operator
IntSet
().
min
();
*
ret
=
args
[
0
].
operator
IntSet
().
min
();
});
});
TVM_REGISTER_API
(
_IntervalSetGetMax
)
TVM_REGISTER_API
(
"_IntervalSetGetMax"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
args
[
0
].
operator
IntSet
().
max
();
*
ret
=
args
[
0
].
operator
IntSet
().
max
();
});
});
TVM_REGISTER_API
(
_IntSetIsNothing
)
TVM_REGISTER_API
(
"_IntSetIsNothing"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
args
[
0
].
operator
IntSet
().
is_nothing
();
*
ret
=
args
[
0
].
operator
IntSet
().
is_nothing
();
});
});
TVM_REGISTER_API
(
_IntSetIsEverything
)
TVM_REGISTER_API
(
"_IntSetIsEverything"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
args
[
0
].
operator
IntSet
().
is_everything
();
*
ret
=
args
[
0
].
operator
IntSet
().
is_everything
();
});
});
...
...
src/api/api_base.cc
View file @
181edb4a
/*!
/*!
* Copyright (c) 2017 by Contributors
* Copyright (c) 2017 by Contributors
* Implementation of basic API functions
* Implementation of basic API functions
* \file api_base.cc
* \file api_base.cc
...
@@ -9,7 +9,7 @@
...
@@ -9,7 +9,7 @@
namespace
tvm
{
namespace
tvm
{
TVM_REGISTER_API
(
_format_str
)
TVM_REGISTER_API
(
"_format_str"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
CHECK
(
args
[
0
].
type_code
()
==
kNodeHandle
);
CHECK
(
args
[
0
].
type_code
()
==
kNodeHandle
);
std
::
ostringstream
os
;
std
::
ostringstream
os
;
...
@@ -17,19 +17,19 @@ TVM_REGISTER_API(_format_str)
...
@@ -17,19 +17,19 @@ TVM_REGISTER_API(_format_str)
*
ret
=
os
.
str
();
*
ret
=
os
.
str
();
});
});
TVM_REGISTER_API
(
_raw_ptr
)
TVM_REGISTER_API
(
"_raw_ptr"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
CHECK
(
args
[
0
].
type_code
()
==
kNodeHandle
);
CHECK
(
args
[
0
].
type_code
()
==
kNodeHandle
);
*
ret
=
reinterpret_cast
<
int64_t
>
(
*
ret
=
reinterpret_cast
<
int64_t
>
(
args
[
0
].
node_sptr
().
get
());
args
[
0
].
node_sptr
().
get
());
});
});
TVM_REGISTER_API
(
_save_json
)
TVM_REGISTER_API
(
"_save_json"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
SaveJSON
(
args
[
0
]);
*
ret
=
SaveJSON
(
args
[
0
]);
});
});
TVM_REGISTER_API
(
_load_json
)
TVM_REGISTER_API
(
"_load_json"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
NodeRef
(
LoadJSON_
(
args
[
0
]));
*
ret
=
NodeRef
(
LoadJSON_
(
args
[
0
]));
});
});
...
...
src/api/api_codegen.cc
View file @
181edb4a
...
@@ -12,7 +12,7 @@
...
@@ -12,7 +12,7 @@
namespace
tvm
{
namespace
tvm
{
namespace
codegen
{
namespace
codegen
{
TVM_REGISTER_API
(
_codegen__Build
)
TVM_REGISTER_API
(
"codegen._Build"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
if
(
args
[
0
].
IsNodeType
<
LoweredFunc
>
())
{
if
(
args
[
0
].
IsNodeType
<
LoweredFunc
>
())
{
*
ret
=
Build
({
args
[
0
]},
args
[
1
]);
*
ret
=
Build
({
args
[
0
]},
args
[
1
]);
...
@@ -21,7 +21,7 @@ TVM_REGISTER_API(_codegen__Build)
...
@@ -21,7 +21,7 @@ TVM_REGISTER_API(_codegen__Build)
}
}
});
});
TVM_REGISTER_API
(
_codegen__Enabled
)
TVM_REGISTER_API
(
"codegen._Enabled"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
TargetEnabled
(
args
[
0
]);
*
ret
=
TargetEnabled
(
args
[
0
]);
});
});
...
...
src/api/api_ir.cc
View file @
181edb4a
...
@@ -11,12 +11,12 @@
...
@@ -11,12 +11,12 @@
namespace
tvm
{
namespace
tvm
{
namespace
ir
{
namespace
ir
{
TVM_REGISTER_API
(
_Var
)
TVM_REGISTER_API
(
"_Var"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
Variable
::
make
(
args
[
1
],
args
[
0
]);
*
ret
=
Variable
::
make
(
args
[
1
],
args
[
0
]);
});
});
TVM_REGISTER_API
(
_make_For
)
TVM_REGISTER_API
(
"make.For"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
For
::
make
(
args
[
0
],
*
ret
=
For
::
make
(
args
[
0
],
args
[
1
],
args
[
1
],
...
@@ -26,7 +26,7 @@ TVM_REGISTER_API(_make_For)
...
@@ -26,7 +26,7 @@ TVM_REGISTER_API(_make_For)
args
[
5
]);
args
[
5
]);
});
});
TVM_REGISTER_API
(
_make_Realize
)
TVM_REGISTER_API
(
"make.Realize"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
Realize
::
make
(
args
[
0
],
*
ret
=
Realize
::
make
(
args
[
0
],
args
[
1
],
args
[
1
],
...
@@ -37,7 +37,7 @@ TVM_REGISTER_API(_make_Realize)
...
@@ -37,7 +37,7 @@ TVM_REGISTER_API(_make_Realize)
});
});
TVM_REGISTER_API
(
_make_Call
)
TVM_REGISTER_API
(
"make.Call"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
Call
::
make
(
args
[
0
],
*
ret
=
Call
::
make
(
args
[
0
],
args
[
1
],
args
[
1
],
...
@@ -47,7 +47,7 @@ TVM_REGISTER_API(_make_Call)
...
@@ -47,7 +47,7 @@ TVM_REGISTER_API(_make_Call)
args
[
5
]);
args
[
5
]);
});
});
TVM_REGISTER_API
(
_make_Allocate
)
TVM_REGISTER_API
(
"make.Allocate"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
Allocate
::
make
(
args
[
0
],
*
ret
=
Allocate
::
make
(
args
[
0
],
args
[
1
],
args
[
1
],
...
@@ -58,31 +58,31 @@ TVM_REGISTER_API(_make_Allocate)
...
@@ -58,31 +58,31 @@ TVM_REGISTER_API(_make_Allocate)
// make from two arguments
// make from two arguments
#define REGISTER_MAKE1(Node) \
#define REGISTER_MAKE1(Node) \
TVM_REGISTER_API(
_make_## Node)
\
TVM_REGISTER_API(
"make."#Node)
\
.set_body([](TVMArgs args, TVMRetValue *ret) { \
.set_body([](TVMArgs args, TVMRetValue *ret) { \
*ret = Node::make(args[0]); \
*ret = Node::make(args[0]); \
}) \
}) \
#define REGISTER_MAKE2(Node) \
#define REGISTER_MAKE2(Node) \
TVM_REGISTER_API(
_make_## Node)
\
TVM_REGISTER_API(
"make."#Node)
\
.set_body([](TVMArgs args, TVMRetValue *ret) { \
.set_body([](TVMArgs args, TVMRetValue *ret) { \
*ret = Node::make(args[0], args[1]); \
*ret = Node::make(args[0], args[1]); \
}) \
}) \
#define REGISTER_MAKE3(Node) \
#define REGISTER_MAKE3(Node) \
TVM_REGISTER_API(
_make_## Node)
\
TVM_REGISTER_API(
"make."#Node)
\
.set_body([](TVMArgs args, TVMRetValue *ret) { \
.set_body([](TVMArgs args, TVMRetValue *ret) { \
*ret = Node::make(args[0], args[1], args[2]); \
*ret = Node::make(args[0], args[1], args[2]); \
}) \
}) \
#define REGISTER_MAKE4(Node) \
#define REGISTER_MAKE4(Node) \
TVM_REGISTER_API(
_make_## Node)
\
TVM_REGISTER_API(
"make."#Node)
\
.set_body([](TVMArgs args, TVMRetValue *ret) { \
.set_body([](TVMArgs args, TVMRetValue *ret) { \
*ret = Node::make(args[0], args[1], args[2], args[3]); \
*ret = Node::make(args[0], args[1], args[2], args[3]); \
}) \
}) \
#define REGISTER_MAKE_BINARY_OP(Node) \
#define REGISTER_MAKE_BINARY_OP(Node) \
TVM_REGISTER_API(
_make_## Node)
\
TVM_REGISTER_API(
"make."#Node)
\
.set_body([](TVMArgs args, TVMRetValue *ret) { \
.set_body([](TVMArgs args, TVMRetValue *ret) { \
Expr a = args[0], b = args[1]; \
Expr a = args[0], b = args[1]; \
match_types(a, b); \
match_types(a, b); \
...
...
src/api/api_lang.cc
View file @
181edb4a
...
@@ -13,7 +13,7 @@
...
@@ -13,7 +13,7 @@
namespace
tvm
{
namespace
tvm
{
TVM_REGISTER_API
(
_const
)
TVM_REGISTER_API
(
"_const"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
if
(
args
[
0
].
type_code
()
==
kInt
)
{
if
(
args
[
0
].
type_code
()
==
kInt
)
{
*
ret
=
make_const
(
args
[
1
],
args
[
0
].
operator
int64_t
());
*
ret
=
make_const
(
args
[
1
],
args
[
0
].
operator
int64_t
());
...
@@ -25,13 +25,13 @@ TVM_REGISTER_API(_const)
...
@@ -25,13 +25,13 @@ TVM_REGISTER_API(_const)
});
});
TVM_REGISTER_API
(
_str
)
TVM_REGISTER_API
(
"_str"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
ir
::
StringImm
::
make
(
args
[
0
]);
*
ret
=
ir
::
StringImm
::
make
(
args
[
0
]);
});
});
TVM_REGISTER_API
(
_Array
)
TVM_REGISTER_API
(
"_Array"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
std
::
vector
<
std
::
shared_ptr
<
Node
>
>
data
;
std
::
vector
<
std
::
shared_ptr
<
Node
>
>
data
;
for
(
int
i
=
0
;
i
<
args
.
size
();
++
i
)
{
for
(
int
i
=
0
;
i
<
args
.
size
();
++
i
)
{
...
@@ -42,7 +42,7 @@ TVM_REGISTER_API(_Array)
...
@@ -42,7 +42,7 @@ TVM_REGISTER_API(_Array)
*
ret
=
node
;
*
ret
=
node
;
});
});
TVM_REGISTER_API
(
_ArrayGetItem
)
TVM_REGISTER_API
(
"_ArrayGetItem"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
int64_t
i
=
args
[
1
];
int64_t
i
=
args
[
1
];
auto
&
sptr
=
args
[
0
].
node_sptr
();
auto
&
sptr
=
args
[
0
].
node_sptr
();
...
@@ -53,7 +53,7 @@ TVM_REGISTER_API(_ArrayGetItem)
...
@@ -53,7 +53,7 @@ TVM_REGISTER_API(_ArrayGetItem)
*
ret
=
n
->
data
[
static_cast
<
size_t
>
(
i
)];
*
ret
=
n
->
data
[
static_cast
<
size_t
>
(
i
)];
});
});
TVM_REGISTER_API
(
_ArraySize
)
TVM_REGISTER_API
(
"_ArraySize"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
auto
&
sptr
=
args
[
0
].
node_sptr
();
auto
&
sptr
=
args
[
0
].
node_sptr
();
CHECK
(
sptr
->
is_type
<
ArrayNode
>
());
CHECK
(
sptr
->
is_type
<
ArrayNode
>
());
...
@@ -61,7 +61,7 @@ TVM_REGISTER_API(_ArraySize)
...
@@ -61,7 +61,7 @@ TVM_REGISTER_API(_ArraySize)
static_cast
<
const
ArrayNode
*>
(
sptr
.
get
())
->
data
.
size
());
static_cast
<
const
ArrayNode
*>
(
sptr
.
get
())
->
data
.
size
());
});
});
TVM_REGISTER_API
(
_Map
)
TVM_REGISTER_API
(
"_Map"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
CHECK_EQ
(
args
.
size
()
%
2
,
0
);
CHECK_EQ
(
args
.
size
()
%
2
,
0
);
MapNode
::
ContainerType
data
;
MapNode
::
ContainerType
data
;
...
@@ -78,7 +78,7 @@ TVM_REGISTER_API(_Map)
...
@@ -78,7 +78,7 @@ TVM_REGISTER_API(_Map)
*
ret
=
node
;
*
ret
=
node
;
});
});
TVM_REGISTER_API
(
_MapSize
)
TVM_REGISTER_API
(
"_MapSize"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
auto
&
sptr
=
args
[
0
].
node_sptr
();
auto
&
sptr
=
args
[
0
].
node_sptr
();
CHECK
(
sptr
->
is_type
<
MapNode
>
());
CHECK
(
sptr
->
is_type
<
MapNode
>
());
...
@@ -86,7 +86,7 @@ TVM_REGISTER_API(_MapSize)
...
@@ -86,7 +86,7 @@ TVM_REGISTER_API(_MapSize)
*
ret
=
static_cast
<
int64_t
>
(
n
->
data
.
size
());
*
ret
=
static_cast
<
int64_t
>
(
n
->
data
.
size
());
});
});
TVM_REGISTER_API
(
_MapGetItem
)
TVM_REGISTER_API
(
"_MapGetItem"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
CHECK
(
args
[
0
].
type_code
()
==
kNodeHandle
);
CHECK
(
args
[
0
].
type_code
()
==
kNodeHandle
);
CHECK
(
args
[
1
].
type_code
()
==
kNodeHandle
);
CHECK
(
args
[
1
].
type_code
()
==
kNodeHandle
);
...
@@ -99,7 +99,7 @@ TVM_REGISTER_API(_MapGetItem)
...
@@ -99,7 +99,7 @@ TVM_REGISTER_API(_MapGetItem)
*
ret
=
(
*
it
).
second
;
*
ret
=
(
*
it
).
second
;
});
});
TVM_REGISTER_API
(
_MapCount
)
TVM_REGISTER_API
(
"_MapCount"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
CHECK
(
args
[
0
].
type_code
()
==
kNodeHandle
);
CHECK
(
args
[
0
].
type_code
()
==
kNodeHandle
);
CHECK
(
args
[
1
].
type_code
()
==
kNodeHandle
);
CHECK
(
args
[
1
].
type_code
()
==
kNodeHandle
);
...
@@ -110,7 +110,7 @@ TVM_REGISTER_API(_MapCount)
...
@@ -110,7 +110,7 @@ TVM_REGISTER_API(_MapCount)
n
->
data
.
count
(
args
[
1
].
node_sptr
()));
n
->
data
.
count
(
args
[
1
].
node_sptr
()));
});
});
TVM_REGISTER_API
(
_MapItems
)
TVM_REGISTER_API
(
"_MapItems"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
auto
&
sptr
=
args
[
0
].
node_sptr
();
auto
&
sptr
=
args
[
0
].
node_sptr
();
CHECK
(
sptr
->
is_type
<
MapNode
>
());
CHECK
(
sptr
->
is_type
<
MapNode
>
());
...
@@ -123,7 +123,7 @@ TVM_REGISTER_API(_MapItems)
...
@@ -123,7 +123,7 @@ TVM_REGISTER_API(_MapItems)
*
ret
=
rkvs
;
*
ret
=
rkvs
;
});
});
TVM_REGISTER_API
(
Range
)
TVM_REGISTER_API
(
"Range"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
if
(
args
.
size
()
==
1
)
{
if
(
args
.
size
()
==
1
)
{
*
ret
=
Range
(
0
,
args
[
0
]);
*
ret
=
Range
(
0
,
args
[
0
]);
...
@@ -132,7 +132,7 @@ TVM_REGISTER_API(Range)
...
@@ -132,7 +132,7 @@ TVM_REGISTER_API(Range)
}
}
});
});
TVM_REGISTER_API
(
_Buffer
)
TVM_REGISTER_API
(
"_Buffer"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
BufferNode
::
make
(
args
[
0
],
*
ret
=
BufferNode
::
make
(
args
[
0
],
args
[
1
],
args
[
1
],
...
@@ -143,7 +143,7 @@ TVM_REGISTER_API(_Buffer)
...
@@ -143,7 +143,7 @@ TVM_REGISTER_API(_Buffer)
args
[
6
]);
args
[
6
]);
});
});
TVM_REGISTER_API
(
_Tensor
)
TVM_REGISTER_API
(
"_Tensor"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
TensorNode
::
make
(
args
[
0
],
*
ret
=
TensorNode
::
make
(
args
[
0
],
args
[
1
],
args
[
1
],
...
@@ -151,32 +151,32 @@ TVM_REGISTER_API(_Tensor)
...
@@ -151,32 +151,32 @@ TVM_REGISTER_API(_Tensor)
args
[
3
]);
args
[
3
]);
});
});
TVM_REGISTER_API
(
_TensorEqual
)
TVM_REGISTER_API
(
"_TensorEqual"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
args
[
0
].
operator
Tensor
()
==
args
[
1
].
operator
Tensor
();
*
ret
=
args
[
0
].
operator
Tensor
()
==
args
[
1
].
operator
Tensor
();
});
});
TVM_REGISTER_API
(
_TensorHash
)
TVM_REGISTER_API
(
"_TensorHash"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
static_cast
<
int64_t
>
(
*
ret
=
static_cast
<
int64_t
>
(
std
::
hash
<
Tensor
>
()(
args
[
0
].
operator
Tensor
()));
std
::
hash
<
Tensor
>
()(
args
[
0
].
operator
Tensor
()));
});
});
TVM_REGISTER_API
(
_Placeholder
)
TVM_REGISTER_API
(
"_Placeholder"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
placeholder
(
args
[
0
],
*
ret
=
placeholder
(
args
[
0
],
args
[
1
],
args
[
1
],
args
[
2
]);
args
[
2
]);
});
});
TVM_REGISTER_API
(
_ComputeOp
)
TVM_REGISTER_API
(
"_ComputeOp"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
ComputeOpNode
::
make
(
args
[
0
],
*
ret
=
ComputeOpNode
::
make
(
args
[
0
],
args
[
1
],
args
[
1
],
args
[
2
]);
args
[
2
]);
});
});
TVM_REGISTER_API
(
_ScanOp
)
TVM_REGISTER_API
(
"_ScanOp"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
ScanOpNode
::
make
(
args
[
0
],
*
ret
=
ScanOpNode
::
make
(
args
[
0
],
args
[
1
],
args
[
1
],
...
@@ -186,7 +186,7 @@ TVM_REGISTER_API(_ScanOp)
...
@@ -186,7 +186,7 @@ TVM_REGISTER_API(_ScanOp)
args
[
5
]);
args
[
5
]);
});
});
TVM_REGISTER_API
(
_ExternOp
)
TVM_REGISTER_API
(
"_ExternOp"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
ExternOpNode
::
make
(
args
[
0
],
*
ret
=
ExternOpNode
::
make
(
args
[
0
],
args
[
1
],
args
[
1
],
...
@@ -195,18 +195,18 @@ TVM_REGISTER_API(_ExternOp)
...
@@ -195,18 +195,18 @@ TVM_REGISTER_API(_ExternOp)
args
[
4
]);
args
[
4
]);
});
});
TVM_REGISTER_API
(
_OpGetOutput
)
TVM_REGISTER_API
(
"_OpGetOutput"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
args
[
0
].
operator
Operation
().
output
(
*
ret
=
args
[
0
].
operator
Operation
().
output
(
static_cast
<
size_t
>
(
args
[
1
].
operator
int64_t
()));
static_cast
<
size_t
>
(
args
[
1
].
operator
int64_t
()));
});
});
TVM_REGISTER_API
(
_OpNumOutputs
)
TVM_REGISTER_API
(
"_OpNumOutputs"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
args
[
0
].
operator
Operation
()
->
num_outputs
();
*
ret
=
args
[
0
].
operator
Operation
()
->
num_outputs
();
});
});
TVM_REGISTER_API
(
_IterVar
)
TVM_REGISTER_API
(
"_IterVar"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
IterVarNode
::
make
(
*
ret
=
IterVarNode
::
make
(
args
[
0
],
args
[
1
],
args
[
0
],
args
[
1
],
...
@@ -214,24 +214,24 @@ TVM_REGISTER_API(_IterVar)
...
@@ -214,24 +214,24 @@ TVM_REGISTER_API(_IterVar)
args
[
3
]);
args
[
3
]);
});
});
TVM_REGISTER_API
(
_Schedule
)
TVM_REGISTER_API
(
"_Schedule"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
Schedule
(
args
[
0
].
operator
Array
<
Operation
>
());
*
ret
=
Schedule
(
args
[
0
].
operator
Array
<
Operation
>
());
});
});
TVM_REGISTER_API
(
_StageSetScope
)
TVM_REGISTER_API
(
"_StageSetScope"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
args
[
0
].
operator
Stage
()
args
[
0
].
operator
Stage
()
.
set_scope
(
args
[
1
]);
.
set_scope
(
args
[
1
]);
});
});
TVM_REGISTER_API
(
_StageBind
)
TVM_REGISTER_API
(
"_StageBind"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
args
[
0
].
operator
Stage
()
args
[
0
].
operator
Stage
()
.
bind
(
args
[
1
],
args
[
2
]);
.
bind
(
args
[
1
],
args
[
2
]);
});
});
TVM_REGISTER_API
(
_StageSplitByFactor
)
TVM_REGISTER_API
(
"_StageSplitByFactor"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
IterVar
outer
,
inner
;
IterVar
outer
,
inner
;
args
[
0
].
operator
Stage
()
args
[
0
].
operator
Stage
()
...
@@ -239,7 +239,7 @@ TVM_REGISTER_API(_StageSplitByFactor)
...
@@ -239,7 +239,7 @@ TVM_REGISTER_API(_StageSplitByFactor)
*
ret
=
Array
<
IterVar
>
({
outer
,
inner
});
*
ret
=
Array
<
IterVar
>
({
outer
,
inner
});
});
});
TVM_REGISTER_API
(
_StageSplitByNParts
)
TVM_REGISTER_API
(
"_StageSplitByNParts"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
IterVar
outer
,
inner
;
IterVar
outer
,
inner
;
args
[
0
].
operator
Stage
()
args
[
0
].
operator
Stage
()
...
@@ -247,7 +247,7 @@ TVM_REGISTER_API(_StageSplitByNParts)
...
@@ -247,7 +247,7 @@ TVM_REGISTER_API(_StageSplitByNParts)
*
ret
=
Array
<
IterVar
>
({
outer
,
inner
});
*
ret
=
Array
<
IterVar
>
({
outer
,
inner
});
});
});
TVM_REGISTER_API
(
_StageFuse
)
TVM_REGISTER_API
(
"_StageFuse"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
IterVar
fused
;
IterVar
fused
;
args
[
0
].
operator
Stage
()
args
[
0
].
operator
Stage
()
...
@@ -255,31 +255,31 @@ TVM_REGISTER_API(_StageFuse)
...
@@ -255,31 +255,31 @@ TVM_REGISTER_API(_StageFuse)
*
ret
=
fused
;
*
ret
=
fused
;
});
});
TVM_REGISTER_API
(
_StageComputeAt
)
TVM_REGISTER_API
(
"_StageComputeAt"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
args
[
0
].
operator
Stage
()
args
[
0
].
operator
Stage
()
.
compute_at
(
args
[
1
],
args
[
2
]);
.
compute_at
(
args
[
1
],
args
[
2
]);
});
});
TVM_REGISTER_API
(
_StageComputeInline
)
TVM_REGISTER_API
(
"_StageComputeInline"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
args
[
0
].
operator
Stage
()
args
[
0
].
operator
Stage
()
.
compute_inline
();
.
compute_inline
();
});
});
TVM_REGISTER_API
(
_StageComputeRoot
)
TVM_REGISTER_API
(
"_StageComputeRoot"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
args
[
0
].
operator
Stage
()
args
[
0
].
operator
Stage
()
.
compute_root
();
.
compute_root
();
});
});
TVM_REGISTER_API
(
_StageReorder
)
TVM_REGISTER_API
(
"_StageReorder"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
args
[
0
].
operator
Stage
()
args
[
0
].
operator
Stage
()
.
reorder
(
args
[
1
]);
.
reorder
(
args
[
1
]);
});
});
TVM_REGISTER_API
(
_StageTile
)
TVM_REGISTER_API
(
"_StageTile"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
IterVar
x_outer
,
y_outer
,
x_inner
,
y_inner
;
IterVar
x_outer
,
y_outer
,
x_inner
,
y_inner
;
args
[
0
].
operator
Stage
()
args
[
0
].
operator
Stage
()
...
@@ -290,55 +290,55 @@ TVM_REGISTER_API(_StageTile)
...
@@ -290,55 +290,55 @@ TVM_REGISTER_API(_StageTile)
*
ret
=
Array
<
IterVar
>
({
x_outer
,
y_outer
,
x_inner
,
y_inner
});
*
ret
=
Array
<
IterVar
>
({
x_outer
,
y_outer
,
x_inner
,
y_inner
});
});
});
TVM_REGISTER_API
(
_StageEnvThreads
)
TVM_REGISTER_API
(
"_StageEnvThreads"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
args
[
0
].
operator
Stage
()
args
[
0
].
operator
Stage
()
.
env_threads
(
args
[
1
]);
.
env_threads
(
args
[
1
]);
});
});
TVM_REGISTER_API
(
_StageUnroll
)
TVM_REGISTER_API
(
"_StageUnroll"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
args
[
0
].
operator
Stage
()
args
[
0
].
operator
Stage
()
.
unroll
(
args
[
1
]);
.
unroll
(
args
[
1
]);
});
});
TVM_REGISTER_API
(
_StageVectorize
)
TVM_REGISTER_API
(
"_StageVectorize"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
args
[
0
].
operator
Stage
()
args
[
0
].
operator
Stage
()
.
vectorize
(
args
[
1
]);
.
vectorize
(
args
[
1
]);
});
});
TVM_REGISTER_API
(
_StageParallel
)
TVM_REGISTER_API
(
"_StageParallel"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
args
[
0
].
operator
Stage
()
args
[
0
].
operator
Stage
()
.
parallel
(
args
[
1
]);
.
parallel
(
args
[
1
]);
});
});
TVM_REGISTER_API
(
_ScheduleNormalize
)
TVM_REGISTER_API
(
"_ScheduleNormalize"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
args
[
0
].
operator
Schedule
()
*
ret
=
args
[
0
].
operator
Schedule
()
.
normalize
();
.
normalize
();
});
});
TVM_REGISTER_API
(
_ScheduleCreateGroup
)
TVM_REGISTER_API
(
"_ScheduleCreateGroup"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
args
[
0
].
operator
Schedule
()
*
ret
=
args
[
0
].
operator
Schedule
()
.
create_group
(
args
[
1
],
args
[
2
],
args
[
3
]);
.
create_group
(
args
[
1
],
args
[
2
],
args
[
3
]);
});
});
TVM_REGISTER_API
(
_ScheduleCacheRead
)
TVM_REGISTER_API
(
"_ScheduleCacheRead"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
args
[
0
].
operator
Schedule
()
*
ret
=
args
[
0
].
operator
Schedule
()
.
cache_read
(
args
[
1
],
args
[
2
],
args
[
3
]);
.
cache_read
(
args
[
1
],
args
[
2
],
args
[
3
]);
});
});
TVM_REGISTER_API
(
_ScheduleCacheWrite
)
TVM_REGISTER_API
(
"_ScheduleCacheWrite"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
args
[
0
].
operator
Schedule
()
*
ret
=
args
[
0
].
operator
Schedule
()
.
cache_write
(
args
[
1
],
args
[
2
]);
.
cache_write
(
args
[
1
],
args
[
2
]);
});
});
TVM_REGISTER_API
(
_ScheduleRFactor
)
TVM_REGISTER_API
(
"_ScheduleRFactor"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
args
[
0
].
operator
Schedule
()
*
ret
=
args
[
0
].
operator
Schedule
()
.
rfactor
(
args
[
1
],
args
[
2
]);
.
rfactor
(
args
[
1
],
args
[
2
]);
...
...
src/api/api_pass.cc
View file @
181edb4a
...
@@ -12,7 +12,7 @@
...
@@ -12,7 +12,7 @@
namespace
tvm
{
namespace
tvm
{
namespace
ir
{
namespace
ir
{
TVM_REGISTER_API
(
_pass_Simplify
)
TVM_REGISTER_API
(
"ir_pass.Simplify"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
if
(
args
[
0
].
IsNodeType
<
Stmt
>
())
{
if
(
args
[
0
].
IsNodeType
<
Stmt
>
())
{
*
ret
=
Simplify
(
args
[
0
].
operator
Stmt
());
*
ret
=
Simplify
(
args
[
0
].
operator
Stmt
());
...
@@ -21,7 +21,7 @@ TVM_REGISTER_API(_pass_Simplify)
...
@@ -21,7 +21,7 @@ TVM_REGISTER_API(_pass_Simplify)
}
}
});
});
TVM_REGISTER_API
(
_pass_Equal
)
TVM_REGISTER_API
(
"ir_pass.Equal"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
if
(
args
[
0
].
IsNodeType
<
Stmt
>
())
{
if
(
args
[
0
].
IsNodeType
<
Stmt
>
())
{
*
ret
=
Equal
(
args
[
0
].
operator
Stmt
(),
args
[
1
].
operator
Stmt
());
*
ret
=
Equal
(
args
[
0
].
operator
Stmt
(),
args
[
1
].
operator
Stmt
());
...
@@ -30,7 +30,7 @@ TVM_REGISTER_API(_pass_Equal)
...
@@ -30,7 +30,7 @@ TVM_REGISTER_API(_pass_Equal)
}
}
});
});
TVM_REGISTER_API
(
_pass_PostOrderVisit
)
TVM_REGISTER_API
(
"ir_pass.PostOrderVisit"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
PackedFunc
f
=
args
[
1
];
PackedFunc
f
=
args
[
1
];
ir
::
PostOrderVisit
(
args
[
0
],
[
f
](
const
NodeRef
&
n
)
{
ir
::
PostOrderVisit
(
args
[
0
],
[
f
](
const
NodeRef
&
n
)
{
...
@@ -40,19 +40,19 @@ TVM_REGISTER_API(_pass_PostOrderVisit)
...
@@ -40,19 +40,19 @@ TVM_REGISTER_API(_pass_PostOrderVisit)
// make from two arguments
// make from two arguments
#define REGISTER_PASS1(PassName) \
#define REGISTER_PASS1(PassName) \
TVM_REGISTER_API(
_pass_## PassName)
\
TVM_REGISTER_API(
"ir_pass."#PassName)
\
.set_body([](TVMArgs args, TVMRetValue *ret) { \
.set_body([](TVMArgs args, TVMRetValue *ret) { \
*ret = PassName(args[0]); \
*ret = PassName(args[0]); \
}) \
}) \
#define REGISTER_PASS2(PassName) \
#define REGISTER_PASS2(PassName) \
TVM_REGISTER_API(
_pass_## PassName)
\
TVM_REGISTER_API(
"ir_pass."#PassName)
\
.set_body([](TVMArgs args, TVMRetValue *ret) { \
.set_body([](TVMArgs args, TVMRetValue *ret) { \
*ret = PassName(args[0], args[1]); \
*ret = PassName(args[0], args[1]); \
}) \
}) \
#define REGISTER_PASS4(PassName) \
#define REGISTER_PASS4(PassName) \
TVM_REGISTER_API(
_pass_## PassName)
\
TVM_REGISTER_API(
"ir_pass."#PassName)
\
.set_body([](TVMArgs args, TVMRetValue *ret) { \
.set_body([](TVMArgs args, TVMRetValue *ret) { \
*ret = PassName(args[0], args[1], args[2], args[3]); \
*ret = PassName(args[0], args[1], args[2], args[3]); \
}) \
}) \
...
...
src/api/api_schedule.cc
View file @
181edb4a
...
@@ -13,19 +13,19 @@
...
@@ -13,19 +13,19 @@
namespace
tvm
{
namespace
tvm
{
namespace
schedule
{
namespace
schedule
{
TVM_REGISTER_API
(
_schedule_AutoInlineElemWise
)
TVM_REGISTER_API
(
"schedule.AutoInlineElemWise"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
AutoInlineElemWise
(
args
[
0
]);
AutoInlineElemWise
(
args
[
0
]);
});
});
#define REGISTER_SCHEDULE_PASS1(PassName) \
#define REGISTER_SCHEDULE_PASS1(PassName) \
TVM_REGISTER_API(
_schedule_## PassName)
\
TVM_REGISTER_API(
"schedule."#PassName)
\
.set_body([](TVMArgs args, TVMRetValue *ret) { \
.set_body([](TVMArgs args, TVMRetValue *ret) { \
*ret = PassName(args[0]); \
*ret = PassName(args[0]); \
}) \
}) \
#define REGISTER_SCHEDULE_PASS2(PassName) \
#define REGISTER_SCHEDULE_PASS2(PassName) \
TVM_REGISTER_API(
_schedule_## PassName)
\
TVM_REGISTER_API(
"schedule."#PassName)
\
.set_body([](TVMArgs args, TVMRetValue *ret) { \
.set_body([](TVMArgs args, TVMRetValue *ret) { \
*ret = PassName(args[0], args[1]); \
*ret = PassName(args[0], args[1]); \
}) \
}) \
...
...
src/api/c_api.cc
View file @
181edb4a
...
@@ -103,6 +103,23 @@ int TVMNodeFree(NodeHandle handle) {
...
@@ -103,6 +103,23 @@ int TVMNodeFree(NodeHandle handle) {
API_END
();
API_END
();
}
}
int
TVMCbArgToReturn
(
TVMValue
*
value
,
int
code
)
{
API_BEGIN
();
tvm
::
runtime
::
TVMRetValue
rv
;
rv
=
tvm
::
runtime
::
TVMArgValue
(
*
value
,
code
);
int
tcode
;
rv
.
MoveToCHost
(
value
,
&
tcode
);
CHECK_EQ
(
tcode
,
code
);
API_END
();
}
int
TVMNodeDupe
(
NodeHandle
handle
,
NodeHandle
*
out_handle
)
{
API_BEGIN
();
*
out_handle
=
new
TVMAPINode
(
*
static_cast
<
TVMAPINode
*>
(
handle
));
API_END
();
}
int
TVMNodeGetAttr
(
NodeHandle
handle
,
int
TVMNodeGetAttr
(
NodeHandle
handle
,
const
char
*
key
,
const
char
*
key
,
TVMValue
*
ret_val
,
TVMValue
*
ret_val
,
...
...
src/codegen/build_cuda.cc
View file @
181edb4a
...
@@ -86,7 +86,7 @@ runtime::Module BuildCUDA(Array<LoweredFunc> funcs) {
...
@@ -86,7 +86,7 @@ runtime::Module BuildCUDA(Array<LoweredFunc> funcs) {
return
CUDAModuleCreate
(
ptx
,
fmt
,
fmap
,
code
);
return
CUDAModuleCreate
(
ptx
,
fmt
,
fmap
,
code
);
}
}
TVM_REGISTER_API
(
_codegen_build_cuda
)
TVM_REGISTER_API
(
"codegen.build_cuda"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
rv
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
rv
)
{
*
rv
=
BuildCUDA
(
args
[
0
]);
*
rv
=
BuildCUDA
(
args
[
0
]);
});
});
...
...
src/codegen/build_opencl.cc
View file @
181edb4a
...
@@ -41,7 +41,7 @@ runtime::Module BuildOpenCL(Array<LoweredFunc> funcs) {
...
@@ -41,7 +41,7 @@ runtime::Module BuildOpenCL(Array<LoweredFunc> funcs) {
return
OpenCLModuleCreate
(
code
,
"cl"
,
fmap
);
return
OpenCLModuleCreate
(
code
,
"cl"
,
fmap
);
}
}
TVM_REGISTER_API
(
_codegen_build_opencl
)
TVM_REGISTER_API
(
"codegen.build_opencl"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
rv
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
rv
)
{
*
rv
=
BuildOpenCL
(
args
[
0
]);
*
rv
=
BuildOpenCL
(
args
[
0
]);
});
});
...
...
src/codegen/codegen.cc
View file @
181edb4a
...
@@ -17,7 +17,7 @@ runtime::Module Build(const Array<LoweredFunc>& funcs,
...
@@ -17,7 +17,7 @@ runtime::Module Build(const Array<LoweredFunc>& funcs,
if
(
pos
!=
std
::
string
::
npos
)
{
if
(
pos
!=
std
::
string
::
npos
)
{
mode
=
mode
.
substr
(
0
,
pos
);
mode
=
mode
.
substr
(
0
,
pos
);
}
}
std
::
string
build_f_name
=
"
_codegen_
build_"
+
mode
;
std
::
string
build_f_name
=
"
codegen.
build_"
+
mode
;
const
PackedFunc
*
bf
=
runtime
::
Registry
::
Get
(
build_f_name
);
const
PackedFunc
*
bf
=
runtime
::
Registry
::
Get
(
build_f_name
);
CHECK
(
bf
!=
nullptr
)
CHECK
(
bf
!=
nullptr
)
...
@@ -27,7 +27,7 @@ runtime::Module Build(const Array<LoweredFunc>& funcs,
...
@@ -27,7 +27,7 @@ runtime::Module Build(const Array<LoweredFunc>& funcs,
}
}
bool
TargetEnabled
(
const
std
::
string
&
target
)
{
bool
TargetEnabled
(
const
std
::
string
&
target
)
{
std
::
string
build_f_name
=
"
_codegen_
build_"
+
target
;
std
::
string
build_f_name
=
"
codegen.
build_"
+
target
;
return
runtime
::
Registry
::
Get
(
build_f_name
)
!=
nullptr
;
return
runtime
::
Registry
::
Get
(
build_f_name
)
!=
nullptr
;
}
}
...
...
src/codegen/llvm/llvm_module.cc
View file @
181edb4a
...
@@ -150,7 +150,7 @@ class LLVMModuleNode : public runtime::ModuleNode {
...
@@ -150,7 +150,7 @@ class LLVMModuleNode : public runtime::ModuleNode {
std
::
shared_ptr
<
llvm
::
LLVMContext
>
ctx_
;
std
::
shared_ptr
<
llvm
::
LLVMContext
>
ctx_
;
};
};
TVM_REGISTER_API
(
_codegen_build_llvm
)
TVM_REGISTER_API
(
"codegen.build_llvm"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
rv
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
rv
)
{
std
::
shared_ptr
<
LLVMModuleNode
>
n
=
std
::
make_shared
<
LLVMModuleNode
>
();
std
::
shared_ptr
<
LLVMModuleNode
>
n
=
std
::
make_shared
<
LLVMModuleNode
>
();
n
->
Init
(
args
[
0
],
args
[
1
]);
n
->
Init
(
args
[
0
],
args
[
1
]);
...
...
src/codegen/stack_vm/stack_vm_module.cc
View file @
181edb4a
...
@@ -69,7 +69,7 @@ class StackVMModuleNode : public runtime::ModuleNode {
...
@@ -69,7 +69,7 @@ class StackVMModuleNode : public runtime::ModuleNode {
std
::
unordered_map
<
std
::
string
,
StackVM
>
fmap_
;
std
::
unordered_map
<
std
::
string
,
StackVM
>
fmap_
;
};
};
TVM_REGISTER_API
(
_codegen_build_stackvm
)
TVM_REGISTER_API
(
"codegen.build_stackvm"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
rv
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
rv
)
{
*
rv
=
StackVMModuleNode
::
Build
(
args
[
0
]);
*
rv
=
StackVMModuleNode
::
Build
(
args
[
0
]);
});
});
...
...
src/codegen/verilog/verilog_module.cc
View file @
181edb4a
...
@@ -86,7 +86,7 @@ class VerilogModuleNode : public runtime::ModuleNode {
...
@@ -86,7 +86,7 @@ class VerilogModuleNode : public runtime::ModuleNode {
std
::
string
fmt_
;
std
::
string
fmt_
;
};
};
TVM_REGISTER_API
(
_codegen_build_verilog
)
TVM_REGISTER_API
(
"codegen.build_verilog"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
rv
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
rv
)
{
std
::
shared_ptr
<
VerilogModuleNode
>
n
=
std
::
shared_ptr
<
VerilogModuleNode
>
n
=
std
::
make_shared
<
VerilogModuleNode
>
();
std
::
make_shared
<
VerilogModuleNode
>
();
...
...
src/codegen/verilog/vpi_device_api.cc
View file @
181edb4a
...
@@ -385,7 +385,7 @@ class VPIWriteMemMap : public VPIMemMapBase {
...
@@ -385,7 +385,7 @@ class VPIWriteMemMap : public VPIMemMapBase {
VPIHandle
enable_
;
VPIHandle
enable_
;
};
};
TVM_REGISTER_GLOBAL
(
_device_api_vpi
)
TVM_REGISTER_GLOBAL
(
"device_api.vpi"
)
.
set_body
([](
runtime
::
TVMArgs
args
,
runtime
::
TVMRetValue
*
rv
)
{
.
set_body
([](
runtime
::
TVMArgs
args
,
runtime
::
TVMRetValue
*
rv
)
{
runtime
::
DeviceAPI
*
ptr
=
VPIDeviceAPI
::
Global
();
runtime
::
DeviceAPI
*
ptr
=
VPIDeviceAPI
::
Global
();
*
rv
=
static_cast
<
void
*>
(
ptr
);
*
rv
=
static_cast
<
void
*>
(
ptr
);
...
@@ -403,13 +403,13 @@ void TVMVPIHook(runtime::TVMArgs args, runtime::TVMRetValue* rv) {
...
@@ -403,13 +403,13 @@ void TVMVPIHook(runtime::TVMArgs args, runtime::TVMRetValue* rv) {
*
rv
=
pf
;
*
rv
=
pf
;
}
}
TVM_REGISTER_GLOBAL
(
_vpi_module_tvm_vpi_mem_interface
)
TVM_REGISTER_GLOBAL
(
"_vpi_module_tvm_vpi_mem_interface"
)
.
set_body
(
TVMVPIHook
<
VPIMemoryInterface
>
);
.
set_body
(
TVMVPIHook
<
VPIMemoryInterface
>
);
TVM_REGISTER_GLOBAL
(
_vpi_module_tvm_vpi_read_mmap
)
TVM_REGISTER_GLOBAL
(
"_vpi_module_tvm_vpi_read_mmap"
)
.
set_body
(
TVMVPIHook
<
VPIReadMemMap
>
);
.
set_body
(
TVMVPIHook
<
VPIReadMemMap
>
);
TVM_REGISTER_GLOBAL
(
_vpi_module_tvm_vpi_write_mmap
)
TVM_REGISTER_GLOBAL
(
"_vpi_module_tvm_vpi_write_mmap"
)
.
set_body
(
TVMVPIHook
<
VPIWriteMemMap
>
);
.
set_body
(
TVMVPIHook
<
VPIWriteMemMap
>
);
}
// namespace codegen
}
// namespace codegen
...
...
src/codegen/verilog/vpi_session.cc
View file @
181edb4a
...
@@ -212,47 +212,47 @@ VPIHandle VPIHandle::operator[](const std::string& name) const {
...
@@ -212,47 +212,47 @@ VPIHandle VPIHandle::operator[](const std::string& name) const {
}
}
// API registration
// API registration
TVM_REGISTER_API
(
_vpi_SessMake
)
TVM_REGISTER_API
(
"_vpi_SessMake"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
VPISession
::
make
(
args
[
0
],
args
[
1
]);
*
ret
=
VPISession
::
make
(
args
[
0
],
args
[
1
]);
});
});
TVM_REGISTER_API
(
_vpi_SessGetHandleByName
)
TVM_REGISTER_API
(
"_vpi_SessGetHandleByName"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
args
[
0
].
operator
VPISession
().
operator
[](
args
[
1
]);
*
ret
=
args
[
0
].
operator
VPISession
().
operator
[](
args
[
1
]);
});
});
TVM_REGISTER_API
(
_vpi_SessYield
)
TVM_REGISTER_API
(
"_vpi_SessYield"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
args
[
0
].
operator
VPISession
().
yield
();
args
[
0
].
operator
VPISession
().
yield
();
});
});
TVM_REGISTER_API
(
_vpi_SessShutdown
)
TVM_REGISTER_API
(
"_vpi_SessShutdown"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
args
[
0
].
operator
VPISession
().
shutdown
();
args
[
0
].
operator
VPISession
().
shutdown
();
});
});
TVM_REGISTER_API
(
_vpi_HandlePutInt
)
TVM_REGISTER_API
(
"_vpi_HandlePutInt"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
args
[
0
].
operator
VPIHandle
().
put_int
(
args
[
1
]);
args
[
0
].
operator
VPIHandle
().
put_int
(
args
[
1
]);
});
});
TVM_REGISTER_API
(
_vpi_HandleGetInt
)
TVM_REGISTER_API
(
"_vpi_HandleGetInt"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
args
[
0
].
operator
VPIHandle
().
get_int
();
*
ret
=
args
[
0
].
operator
VPIHandle
().
get_int
();
});
});
TVM_REGISTER_API
(
_vpi_HandleGetName
)
TVM_REGISTER_API
(
"_vpi_HandleGetName"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
args
[
0
].
operator
VPIHandle
().
name
();
*
ret
=
args
[
0
].
operator
VPIHandle
().
name
();
});
});
TVM_REGISTER_API
(
_vpi_HandleGetSize
)
TVM_REGISTER_API
(
"_vpi_HandleGetSize"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
args
[
0
].
operator
VPIHandle
().
size
();
*
ret
=
args
[
0
].
operator
VPIHandle
().
size
();
});
});
TVM_REGISTER_API
(
_vpi_HandleGetHandleByName
)
TVM_REGISTER_API
(
"_vpi_HandleGetHandleByName"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
args
[
0
].
operator
VPIHandle
().
operator
[](
args
[
1
]);
*
ret
=
args
[
0
].
operator
VPIHandle
().
operator
[](
args
[
1
]);
});
});
...
...
src/runtime/c_runtime_api.cc
View file @
181edb4a
...
@@ -46,7 +46,7 @@ class DeviceAPIManager {
...
@@ -46,7 +46,7 @@ class DeviceAPIManager {
if
(
api_
[
type
]
!=
nullptr
)
return
api_
[
type
];
if
(
api_
[
type
]
!=
nullptr
)
return
api_
[
type
];
std
::
lock_guard
<
std
::
mutex
>
lock
(
mutex_
);
std
::
lock_guard
<
std
::
mutex
>
lock
(
mutex_
);
if
(
api_
[
type
]
!=
nullptr
)
return
api_
[
type
];
if
(
api_
[
type
]
!=
nullptr
)
return
api_
[
type
];
std
::
string
factory
=
"
_device_api_
"
+
DeviceName
(
type
);
std
::
string
factory
=
"
device_api.
"
+
DeviceName
(
type
);
auto
*
f
=
Registry
::
Get
(
factory
);
auto
*
f
=
Registry
::
Get
(
factory
);
CHECK
(
f
!=
nullptr
)
CHECK
(
f
!=
nullptr
)
<<
"Device API "
<<
DeviceName
(
type
)
<<
" is not enabled."
;
<<
"Device API "
<<
DeviceName
(
type
)
<<
" is not enabled."
;
...
...
src/runtime/cpu_device_api.cc
View file @
181edb4a
...
@@ -46,7 +46,7 @@ class CPUDeviceAPI : public DeviceAPI {
...
@@ -46,7 +46,7 @@ class CPUDeviceAPI : public DeviceAPI {
}
}
};
};
TVM_REGISTER_GLOBAL
(
_device_api_cpu
)
TVM_REGISTER_GLOBAL
(
"device_api.cpu"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
rv
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
rv
)
{
static
CPUDeviceAPI
inst
;
static
CPUDeviceAPI
inst
;
DeviceAPI
*
ptr
=
&
inst
;
DeviceAPI
*
ptr
=
&
inst
;
...
...
src/runtime/cuda/cuda_device_api.cc
View file @
181edb4a
...
@@ -77,7 +77,7 @@ class CUDADeviceAPI : public DeviceAPI {
...
@@ -77,7 +77,7 @@ class CUDADeviceAPI : public DeviceAPI {
}
}
};
};
TVM_REGISTER_GLOBAL
(
_device_api_gpu
)
TVM_REGISTER_GLOBAL
(
"device_api.gpu"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
rv
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
rv
)
{
static
CUDADeviceAPI
inst
;
static
CUDADeviceAPI
inst
;
DeviceAPI
*
ptr
=
&
inst
;
DeviceAPI
*
ptr
=
&
inst
;
...
...
src/runtime/cuda/cuda_module.cc
View file @
181edb4a
...
@@ -281,12 +281,12 @@ Module CUDAModuleLoad(const std::string& file_name,
...
@@ -281,12 +281,12 @@ Module CUDAModuleLoad(const std::string& file_name,
return
CUDAModuleCreate
(
data
,
fmt
,
fmap
,
std
::
string
());
return
CUDAModuleCreate
(
data
,
fmt
,
fmap
,
std
::
string
());
}
}
TVM_REGISTER_GLOBAL
(
_module_loadfile_cubin
)
TVM_REGISTER_GLOBAL
(
"module.loadfile_cubin"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
rv
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
rv
)
{
*
rv
=
CUDAModuleLoad
(
args
[
0
],
args
[
1
]);
*
rv
=
CUDAModuleLoad
(
args
[
0
],
args
[
1
]);
});
});
TVM_REGISTER_GLOBAL
(
_module_loadfile_ptx
)
TVM_REGISTER_GLOBAL
(
"module.loadfile_ptx"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
rv
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
rv
)
{
*
rv
=
CUDAModuleLoad
(
args
[
0
],
args
[
1
]);
*
rv
=
CUDAModuleLoad
(
args
[
0
],
args
[
1
]);
});
});
...
...
src/runtime/dso_module.cc
View file @
181edb4a
...
@@ -110,7 +110,7 @@ class DSOModuleNode : public ModuleNode {
...
@@ -110,7 +110,7 @@ class DSOModuleNode : public ModuleNode {
#endif
#endif
};
};
TVM_REGISTER_GLOBAL
(
_module_loadfile_so
)
TVM_REGISTER_GLOBAL
(
"module.loadfile_so"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
rv
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
rv
)
{
std
::
shared_ptr
<
DSOModuleNode
>
n
=
std
::
make_shared
<
DSOModuleNode
>
();
std
::
shared_ptr
<
DSOModuleNode
>
n
=
std
::
make_shared
<
DSOModuleNode
>
();
n
->
Init
(
args
[
0
]);
n
->
Init
(
args
[
0
]);
...
...
src/runtime/module.cc
View file @
181edb4a
...
@@ -53,7 +53,7 @@ Module Module::LoadFromFile(const std::string& file_name,
...
@@ -53,7 +53,7 @@ Module Module::LoadFromFile(const std::string& file_name,
if
(
fmt
==
"dll"
||
fmt
==
"dylib"
||
fmt
==
"dso"
)
{
if
(
fmt
==
"dll"
||
fmt
==
"dylib"
||
fmt
==
"dso"
)
{
fmt
=
"so"
;
fmt
=
"so"
;
}
}
std
::
string
load_f_name
=
"
_module_
loadfile_"
+
fmt
;
std
::
string
load_f_name
=
"
module.
loadfile_"
+
fmt
;
const
PackedFunc
*
f
=
Registry
::
Get
(
load_f_name
);
const
PackedFunc
*
f
=
Registry
::
Get
(
load_f_name
);
CHECK
(
f
!=
nullptr
)
CHECK
(
f
!=
nullptr
)
<<
"Loader of "
<<
format
<<
"("
<<
"Loader of "
<<
format
<<
"("
...
@@ -88,48 +88,48 @@ bool RuntimeEnabled(const std::string& target) {
...
@@ -88,48 +88,48 @@ bool RuntimeEnabled(const std::string& target) {
if
(
target
==
"cpu"
)
{
if
(
target
==
"cpu"
)
{
return
true
;
return
true
;
}
else
if
(
target
==
"cuda"
||
target
==
"gpu"
)
{
}
else
if
(
target
==
"cuda"
||
target
==
"gpu"
)
{
load_f_name
=
"
_module_
loadfile_ptx"
;
load_f_name
=
"
module.
loadfile_ptx"
;
}
else
if
(
target
==
"cl"
||
target
==
"opencl"
)
{
}
else
if
(
target
==
"cl"
||
target
==
"opencl"
)
{
load_f_name
=
"
_module_
loadfile_cl"
;
load_f_name
=
"
module.
loadfile_cl"
;
}
else
{
}
else
{
LOG
(
FATAL
)
<<
"Unknown optional runtime "
<<
target
;
LOG
(
FATAL
)
<<
"Unknown optional runtime "
<<
target
;
}
}
return
runtime
::
Registry
::
Get
(
load_f_name
)
!=
nullptr
;
return
runtime
::
Registry
::
Get
(
load_f_name
)
!=
nullptr
;
}
}
TVM_REGISTER_GLOBAL
(
_module__Enabled
)
TVM_REGISTER_GLOBAL
(
"module._Enabled"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
RuntimeEnabled
(
args
[
0
]);
*
ret
=
RuntimeEnabled
(
args
[
0
]);
});
});
TVM_REGISTER_GLOBAL
(
_module__GetSource
)
TVM_REGISTER_GLOBAL
(
"module._GetSource"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
args
[
0
].
operator
Module
()
->
GetSource
(
args
[
1
]);
*
ret
=
args
[
0
].
operator
Module
()
->
GetSource
(
args
[
1
]);
});
});
TVM_REGISTER_GLOBAL
(
_module__ImportsSize
)
TVM_REGISTER_GLOBAL
(
"module._ImportsSize"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
static_cast
<
int64_t
>
(
*
ret
=
static_cast
<
int64_t
>
(
args
[
0
].
operator
Module
()
->
imports
().
size
());
args
[
0
].
operator
Module
()
->
imports
().
size
());
});
});
TVM_REGISTER_GLOBAL
(
_module__GetImport
)
TVM_REGISTER_GLOBAL
(
"module._GetImport"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
args
[
0
].
operator
Module
()
->
*
ret
=
args
[
0
].
operator
Module
()
->
imports
().
at
(
args
[
1
].
operator
int
());
imports
().
at
(
args
[
1
].
operator
int
());
});
});
TVM_REGISTER_GLOBAL
(
_module__GetTypeKey
)
TVM_REGISTER_GLOBAL
(
"module._GetTypeKey"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
std
::
string
(
args
[
0
].
operator
Module
()
->
type_key
());
*
ret
=
std
::
string
(
args
[
0
].
operator
Module
()
->
type_key
());
});
});
TVM_REGISTER_GLOBAL
(
_module__LoadFromFile
)
TVM_REGISTER_GLOBAL
(
"module._LoadFromFile"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
Module
::
LoadFromFile
(
args
[
0
],
args
[
1
]);
*
ret
=
Module
::
LoadFromFile
(
args
[
0
],
args
[
1
]);
});
});
TVM_REGISTER_GLOBAL
(
_module__SaveToFile
)
TVM_REGISTER_GLOBAL
(
"module._SaveToFile"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
args
[
0
].
operator
Module
()
->
args
[
0
].
operator
Module
()
->
SaveToFile
(
args
[
1
],
args
[
2
]);
SaveToFile
(
args
[
1
],
args
[
2
]);
...
...
src/runtime/opencl/opencl_device_api.cc
View file @
181edb4a
...
@@ -189,10 +189,10 @@ bool InitOpenCL(TVMArgs args, TVMRetValue* rv) {
...
@@ -189,10 +189,10 @@ bool InitOpenCL(TVMArgs args, TVMRetValue* rv) {
return
true
;
return
true
;
}
}
TVM_REGISTER_GLOBAL
(
_module_init_opencl
)
TVM_REGISTER_GLOBAL
(
"module.init_opencl"
)
.
set_body
(
InitOpenCL
);
.
set_body
(
InitOpenCL
);
TVM_REGISTER_GLOBAL
(
_device_api_opencl
)
TVM_REGISTER_GLOBAL
(
"device_api.opencl"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
rv
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
rv
)
{
DeviceAPI
*
ptr
=
OpenCLWorkspace
::
Global
();
DeviceAPI
*
ptr
=
OpenCLWorkspace
::
Global
();
*
rv
=
static_cast
<
void
*>
(
ptr
);
*
rv
=
static_cast
<
void
*>
(
ptr
);
...
...
src/runtime/opencl/opencl_module.cc
View file @
181edb4a
...
@@ -314,12 +314,12 @@ Module OpenCLModuleLoad(const std::string& file_name,
...
@@ -314,12 +314,12 @@ Module OpenCLModuleLoad(const std::string& file_name,
return
OpenCLModuleCreate
(
data
,
fmt
,
fmap
);
return
OpenCLModuleCreate
(
data
,
fmt
,
fmap
);
}
}
TVM_REGISTER_GLOBAL
(
_module_loadfile_cl
)
TVM_REGISTER_GLOBAL
(
"module.loadfile_cl"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
rv
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
rv
)
{
*
rv
=
OpenCLModuleLoad
(
args
[
0
],
args
[
1
]);
*
rv
=
OpenCLModuleLoad
(
args
[
0
],
args
[
1
]);
});
});
TVM_REGISTER_GLOBAL
(
_module_loadfile_clbin
)
TVM_REGISTER_GLOBAL
(
"module.loadfile_clbin"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
rv
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
rv
)
{
*
rv
=
OpenCLModuleLoad
(
args
[
0
],
args
[
1
]);
*
rv
=
OpenCLModuleLoad
(
args
[
0
],
args
[
1
]);
});
});
...
...
tests/python/unittest/test_runtime_packed_func.py
View file @
181edb4a
...
@@ -14,6 +14,25 @@ def test_get_global():
...
@@ -14,6 +14,25 @@ def test_get_global():
y
=
f
(
*
targs
)
y
=
f
(
*
targs
)
assert
y
==
10
assert
y
==
10
def
test_get_callback_with_node
():
x
=
tvm
.
convert
(
10
)
def
test
(
y
):
assert
y
.
handle
!=
x
.
handle
return
y
f2
=
tvm
.
convert
(
test
)
# register into global function table
@tvm.register_func
def
my_callback_with_node
(
y
,
f
):
assert
y
==
x
return
f
(
y
)
# get it out from global function table
f
=
tvm
.
get_global_func
(
"my_callback_with_node"
)
assert
isinstance
(
f
,
tvm
.
Function
)
y
=
f
(
x
,
f2
)
assert
(
y
.
value
==
10
)
def
test_return_func
():
def
test_return_func
():
def
addy
(
y
):
def
addy
(
y
):
...
@@ -45,6 +64,7 @@ def test_byte_array():
...
@@ -45,6 +64,7 @@ def test_byte_array():
f
(
a
)
f
(
a
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
test_get_callback_with_node
()
test_convert
()
test_convert
()
test_get_global
()
test_get_global
()
test_return_func
()
test_return_func
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment