Commit 181edb4a by Tianqi Chen Committed by GitHub

[LANG] Change namespace convention to dot (#100)

parent 97c67e53
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
#include <string> #include <string>
#include <memory> #include <memory>
#include <functional> #include <functional>
#include "./runtime/registry.h"
namespace tvm { namespace tvm {
......
...@@ -20,6 +20,18 @@ TVM_EXTERN_C { ...@@ -20,6 +20,18 @@ TVM_EXTERN_C {
typedef void* NodeHandle; typedef void* NodeHandle;
/*! /*!
* \brief Inplace translate callback argument value to return value.
* This is only needed for non-POD arguments.
*
* \param value The value to be translated.
* \param code The type code to be translated.
* \note This function will do a shallow copy when necessary.
*
* \return 0 when success, -1 when failure happens.
*/
TVM_DLL int TVMCbArgToReturn(TVMValue* value, int code);
/*!
* \brief free the node handle * \brief free the node handle
* \param handle The node handle to be freed. * \param handle The node handle to be freed.
* \return 0 when success, -1 when failure happens * \return 0 when success, -1 when failure happens
......
...@@ -81,7 +81,6 @@ class Registry { ...@@ -81,7 +81,6 @@ class Registry {
friend struct Manager; friend struct Manager;
}; };
/*! \brief helper macro to supress unused warning */ /*! \brief helper macro to supress unused warning */
#if defined(__GNUC__) #if defined(__GNUC__)
#define TVM_ATTRIBUTE_UNUSED __attribute__((unused)) #define TVM_ATTRIBUTE_UNUSED __attribute__((unused))
...@@ -89,19 +88,23 @@ class Registry { ...@@ -89,19 +88,23 @@ class Registry {
#define TVM_ATTRIBUTE_UNUSED #define TVM_ATTRIBUTE_UNUSED
#endif #endif
#define TVM_STR_CONCAT_(__x, __y) __x##__y
#define TVM_STR_CONCAT(__x, __y) TVM_STR_CONCAT_(__x, __y)
#define TVM_FUNC_REG_VAR_DEF \
static TVM_ATTRIBUTE_UNUSED ::tvm::runtime::Registry& __make_ ## TVMOp
/*! /*!
* \brief Register a function globally. * \brief Register a function globally.
* \code * \code
* TVM_REGISTER_GLOBAL(MyPrint) * TVM_REGISTER_GLOBAL("MyPrint")
* .set_body([](TVMArgs args, TVMRetValue* rv) { * .set_body([](TVMArgs args, TVMRetValue* rv) {
* // my code.
* }); * });
* \endcode * \endcode
*/ */
#define TVM_REGISTER_GLOBAL(OpName) \ #define TVM_REGISTER_GLOBAL(OpName) \
static TVM_ATTRIBUTE_UNUSED ::tvm::runtime::Registry& \ TVM_STR_CONCAT(TVM_FUNC_REG_VAR_DEF, __COUNTER__) = \
__make_TVMRegistry_ ## OpName = \ ::tvm::runtime::Registry::Register(OpName)
::tvm::runtime::Registry::Register(#OpName)
} // namespace runtime } // namespace runtime
} // namespace tvm } // namespace tvm
......
...@@ -12,7 +12,7 @@ from .._base import _LIB, check_call ...@@ -12,7 +12,7 @@ from .._base import _LIB, check_call
from .._base import c_str, py_str, string_types from .._base import c_str, py_str, string_types
from ._types import TVMValue, TypeCode, TVMType, TVMByteArray from ._types import TVMValue, TypeCode, TVMType, TVMByteArray
from ._types import TVMPackedCFunc, TVMCFuncFinalizer from ._types import TVMPackedCFunc, TVMCFuncFinalizer
from ._types import RETURN_SWITCH, C_TO_PY_ARG_SWITCH from ._types import RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func
from ._node import NodeBase, SliceBase, convert_to_node from ._node import NodeBase, SliceBase, convert_to_node
from ._ndarray import NDArrayBase from ._ndarray import NDArrayBase
...@@ -302,6 +302,10 @@ def _handle_return_func(x): ...@@ -302,6 +302,10 @@ def _handle_return_func(x):
# setup return handle for function type # setup return handle for function type
RETURN_SWITCH[TypeCode.FUNC_HANDLE] = _handle_return_func RETURN_SWITCH[TypeCode.FUNC_HANDLE] = _handle_return_func
RETURN_SWITCH[TypeCode.MODULE_HANDLE] = _return_module RETURN_SWITCH[TypeCode.MODULE_HANDLE] = _return_module
C_TO_PY_ARG_SWITCH[TypeCode.FUNC_HANDLE] = _wrap_arg_func(
_handle_return_func, TypeCode.FUNC_HANDLE)
C_TO_PY_ARG_SWITCH[TypeCode.MODULE_HANDLE] = _wrap_arg_func(
_return_module, TypeCode.MODULE_HANDLE)
def register_func(func_name, f=None): def register_func(func_name, f=None):
...@@ -415,35 +419,31 @@ def _get_api(f): ...@@ -415,35 +419,31 @@ def _get_api(f):
return flocal(*args) return flocal(*args)
return my_api_func return my_api_func
def _init_api(mod): def _init_api(namespace):
"""Initialize api for a given module name """Initialize api for a given module name
mod : str mod : str
The name of the module. The name of the module.
""" """
module = sys.modules[mod] module = sys.modules[namespace]
namespace_match = { assert namespace.startswith("tvm.")
"_make_": "tvm.make", prefix = namespace[4:]
"_arith_": "tvm.arith",
"_pass_": "tvm.ir_pass",
"_codegen_": "tvm.codegen",
"_module_": "tvm.module",
"_schedule_": "tvm.schedule"
}
for name in list_global_func_names(): for name in list_global_func_names():
fname = name if prefix == "api":
target = "tvm.api" fname = name
for k, v in namespace_match.items(): if name.startswith("_"):
if name.startswith(k): target_module = sys.modules["tvm._api_internal"]
fname = name[len(k):] else:
target = v target_module = module
if target != mod:
continue
if mod == "tvm.api" and name.startswith("_"):
target_module = sys.modules["tvm._api_internal"]
else: else:
if not name.startswith(prefix):
continue
fname = name[len(prefix)+1:]
target_module = module target_module = module
if fname.find(".") != -1:
continue
f = get_global_func(name) f = get_global_func(name)
ff = _get_api(f) ff = _get_api(f)
ff.__name__ = fname ff.__name__ = fname
......
...@@ -10,7 +10,9 @@ from numbers import Number, Integral ...@@ -10,7 +10,9 @@ from numbers import Number, Integral
from .._base import _LIB, check_call from .._base import _LIB, check_call
from .._base import c_str, py_str, string_types from .._base import c_str, py_str, string_types
from .. import _api_internal from .. import _api_internal
from ._types import TVMValue, TypeCode, RETURN_SWITCH from ._types import TVMValue, TypeCode
from ._types import RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func
NodeHandle = ctypes.c_void_p NodeHandle = ctypes.c_void_p
...@@ -19,7 +21,7 @@ NODE_TYPE = { ...@@ -19,7 +21,7 @@ NODE_TYPE = {
} }
def _return_node(x): def _return_node(x):
"""Return function""" """Return node function"""
handle = x.v_handle handle = x.v_handle
if not isinstance(handle, NodeHandle): if not isinstance(handle, NodeHandle):
handle = NodeHandle(handle) handle = NodeHandle(handle)
...@@ -35,6 +37,8 @@ def _return_node(x): ...@@ -35,6 +37,8 @@ def _return_node(x):
RETURN_SWITCH[TypeCode.NODE_HANDLE] = _return_node RETURN_SWITCH[TypeCode.NODE_HANDLE] = _return_node
C_TO_PY_ARG_SWITCH[TypeCode.NODE_HANDLE] = _wrap_arg_func(
_return_node, TypeCode.NODE_HANDLE)
class SliceBase(object): class SliceBase(object):
......
...@@ -4,7 +4,7 @@ from __future__ import absolute_import as _abs ...@@ -4,7 +4,7 @@ from __future__ import absolute_import as _abs
import ctypes import ctypes
import numpy as np import numpy as np
from .._base import py_str from .._base import py_str, check_call, _LIB
tvm_shape_index_t = ctypes.c_int64 tvm_shape_index_t = ctypes.c_int64
...@@ -130,6 +130,12 @@ def _return_bytes(x): ...@@ -130,6 +130,12 @@ def _return_bytes(x):
raise RuntimeError('memmove failed') raise RuntimeError('memmove failed')
return res return res
def _wrap_arg_func(return_f, type_code):
tcode = ctypes.c_int(type_code)
def _wrap_func(x):
check_call(_LIB.TVMCbArgToReturn(ctypes.byref(x), tcode))
return return_f(x)
return _wrap_func
RETURN_SWITCH = { RETURN_SWITCH = {
TypeCode.INT: lambda x: x.v_int64, TypeCode.INT: lambda x: x.v_int64,
......
...@@ -11,49 +11,49 @@ ...@@ -11,49 +11,49 @@
namespace tvm { namespace tvm {
namespace arith { namespace arith {
TVM_REGISTER_API(_arith_intset_single_point) TVM_REGISTER_API("arith.intset_single_point")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = IntSet::single_point(args[0]); *ret = IntSet::single_point(args[0]);
}); });
TVM_REGISTER_API(_arith_intset_interval) TVM_REGISTER_API("arith.intset_interval")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = IntSet::interval(args[0], args[1]); *ret = IntSet::interval(args[0], args[1]);
}); });
TVM_REGISTER_API(_arith_EvalModular) TVM_REGISTER_API("arith.EvalModular")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = EvalModular(args[0], Map<Var, IntSet>()); *ret = EvalModular(args[0], Map<Var, IntSet>());
}); });
TVM_REGISTER_API(_arith_DetectLinearEquation) TVM_REGISTER_API("arith.DetectLinearEquation")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = DetectLinearEquation(args[0], args[1]); *ret = DetectLinearEquation(args[0], args[1]);
}); });
TVM_REGISTER_API(_arith_DeduceBound) TVM_REGISTER_API("arith.DeduceBound")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = DeduceBound(args[0], args[1], *ret = DeduceBound(args[0], args[1],
args[2].operator Map<Var, IntSet>(), args[2].operator Map<Var, IntSet>(),
args[3].operator Map<Var, IntSet>()); args[3].operator Map<Var, IntSet>());
}); });
TVM_REGISTER_API(_IntervalSetGetMin) TVM_REGISTER_API("_IntervalSetGetMin")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = args[0].operator IntSet().min(); *ret = args[0].operator IntSet().min();
}); });
TVM_REGISTER_API(_IntervalSetGetMax) TVM_REGISTER_API("_IntervalSetGetMax")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = args[0].operator IntSet().max(); *ret = args[0].operator IntSet().max();
}); });
TVM_REGISTER_API(_IntSetIsNothing) TVM_REGISTER_API("_IntSetIsNothing")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = args[0].operator IntSet().is_nothing(); *ret = args[0].operator IntSet().is_nothing();
}); });
TVM_REGISTER_API(_IntSetIsEverything) TVM_REGISTER_API("_IntSetIsEverything")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = args[0].operator IntSet().is_everything(); *ret = args[0].operator IntSet().is_everything();
}); });
......
/*! /*!
* Copyright (c) 2017 by Contributors * Copyright (c) 2017 by Contributors
* Implementation of basic API functions * Implementation of basic API functions
* \file api_base.cc * \file api_base.cc
...@@ -9,7 +9,7 @@ ...@@ -9,7 +9,7 @@
namespace tvm { namespace tvm {
TVM_REGISTER_API(_format_str) TVM_REGISTER_API("_format_str")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
CHECK(args[0].type_code() == kNodeHandle); CHECK(args[0].type_code() == kNodeHandle);
std::ostringstream os; std::ostringstream os;
...@@ -17,19 +17,19 @@ TVM_REGISTER_API(_format_str) ...@@ -17,19 +17,19 @@ TVM_REGISTER_API(_format_str)
*ret = os.str(); *ret = os.str();
}); });
TVM_REGISTER_API(_raw_ptr) TVM_REGISTER_API("_raw_ptr")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
CHECK(args[0].type_code() == kNodeHandle); CHECK(args[0].type_code() == kNodeHandle);
*ret = reinterpret_cast<int64_t>( *ret = reinterpret_cast<int64_t>(
args[0].node_sptr().get()); args[0].node_sptr().get());
}); });
TVM_REGISTER_API(_save_json) TVM_REGISTER_API("_save_json")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = SaveJSON(args[0]); *ret = SaveJSON(args[0]);
}); });
TVM_REGISTER_API(_load_json) TVM_REGISTER_API("_load_json")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = NodeRef(LoadJSON_(args[0])); *ret = NodeRef(LoadJSON_(args[0]));
}); });
......
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
namespace tvm { namespace tvm {
namespace codegen { namespace codegen {
TVM_REGISTER_API(_codegen__Build) TVM_REGISTER_API("codegen._Build")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
if (args[0].IsNodeType<LoweredFunc>()) { if (args[0].IsNodeType<LoweredFunc>()) {
*ret = Build({args[0]}, args[1]); *ret = Build({args[0]}, args[1]);
...@@ -21,7 +21,7 @@ TVM_REGISTER_API(_codegen__Build) ...@@ -21,7 +21,7 @@ TVM_REGISTER_API(_codegen__Build)
} }
}); });
TVM_REGISTER_API(_codegen__Enabled) TVM_REGISTER_API("codegen._Enabled")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = TargetEnabled(args[0]); *ret = TargetEnabled(args[0]);
}); });
......
...@@ -11,12 +11,12 @@ ...@@ -11,12 +11,12 @@
namespace tvm { namespace tvm {
namespace ir { namespace ir {
TVM_REGISTER_API(_Var) TVM_REGISTER_API("_Var")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = Variable::make(args[1], args[0]); *ret = Variable::make(args[1], args[0]);
}); });
TVM_REGISTER_API(_make_For) TVM_REGISTER_API("make.For")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = For::make(args[0], *ret = For::make(args[0],
args[1], args[1],
...@@ -26,7 +26,7 @@ TVM_REGISTER_API(_make_For) ...@@ -26,7 +26,7 @@ TVM_REGISTER_API(_make_For)
args[5]); args[5]);
}); });
TVM_REGISTER_API(_make_Realize) TVM_REGISTER_API("make.Realize")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = Realize::make(args[0], *ret = Realize::make(args[0],
args[1], args[1],
...@@ -37,7 +37,7 @@ TVM_REGISTER_API(_make_Realize) ...@@ -37,7 +37,7 @@ TVM_REGISTER_API(_make_Realize)
}); });
TVM_REGISTER_API(_make_Call) TVM_REGISTER_API("make.Call")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = Call::make(args[0], *ret = Call::make(args[0],
args[1], args[1],
...@@ -47,7 +47,7 @@ TVM_REGISTER_API(_make_Call) ...@@ -47,7 +47,7 @@ TVM_REGISTER_API(_make_Call)
args[5]); args[5]);
}); });
TVM_REGISTER_API(_make_Allocate) TVM_REGISTER_API("make.Allocate")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = Allocate::make(args[0], *ret = Allocate::make(args[0],
args[1], args[1],
...@@ -58,31 +58,31 @@ TVM_REGISTER_API(_make_Allocate) ...@@ -58,31 +58,31 @@ TVM_REGISTER_API(_make_Allocate)
// make from two arguments // make from two arguments
#define REGISTER_MAKE1(Node) \ #define REGISTER_MAKE1(Node) \
TVM_REGISTER_API(_make_## Node) \ TVM_REGISTER_API("make."#Node) \
.set_body([](TVMArgs args, TVMRetValue *ret) { \ .set_body([](TVMArgs args, TVMRetValue *ret) { \
*ret = Node::make(args[0]); \ *ret = Node::make(args[0]); \
}) \ }) \
#define REGISTER_MAKE2(Node) \ #define REGISTER_MAKE2(Node) \
TVM_REGISTER_API(_make_## Node) \ TVM_REGISTER_API("make."#Node) \
.set_body([](TVMArgs args, TVMRetValue *ret) { \ .set_body([](TVMArgs args, TVMRetValue *ret) { \
*ret = Node::make(args[0], args[1]); \ *ret = Node::make(args[0], args[1]); \
}) \ }) \
#define REGISTER_MAKE3(Node) \ #define REGISTER_MAKE3(Node) \
TVM_REGISTER_API(_make_## Node) \ TVM_REGISTER_API("make."#Node) \
.set_body([](TVMArgs args, TVMRetValue *ret) { \ .set_body([](TVMArgs args, TVMRetValue *ret) { \
*ret = Node::make(args[0], args[1], args[2]); \ *ret = Node::make(args[0], args[1], args[2]); \
}) \ }) \
#define REGISTER_MAKE4(Node) \ #define REGISTER_MAKE4(Node) \
TVM_REGISTER_API(_make_## Node) \ TVM_REGISTER_API("make."#Node) \
.set_body([](TVMArgs args, TVMRetValue *ret) { \ .set_body([](TVMArgs args, TVMRetValue *ret) { \
*ret = Node::make(args[0], args[1], args[2], args[3]); \ *ret = Node::make(args[0], args[1], args[2], args[3]); \
}) \ }) \
#define REGISTER_MAKE_BINARY_OP(Node) \ #define REGISTER_MAKE_BINARY_OP(Node) \
TVM_REGISTER_API(_make_## Node) \ TVM_REGISTER_API("make."#Node) \
.set_body([](TVMArgs args, TVMRetValue *ret) { \ .set_body([](TVMArgs args, TVMRetValue *ret) { \
Expr a = args[0], b = args[1]; \ Expr a = args[0], b = args[1]; \
match_types(a, b); \ match_types(a, b); \
......
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
namespace tvm { namespace tvm {
TVM_REGISTER_API(_const) TVM_REGISTER_API("_const")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
if (args[0].type_code() == kInt) { if (args[0].type_code() == kInt) {
*ret = make_const(args[1], args[0].operator int64_t()); *ret = make_const(args[1], args[0].operator int64_t());
...@@ -25,13 +25,13 @@ TVM_REGISTER_API(_const) ...@@ -25,13 +25,13 @@ TVM_REGISTER_API(_const)
}); });
TVM_REGISTER_API(_str) TVM_REGISTER_API("_str")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = ir::StringImm::make(args[0]); *ret = ir::StringImm::make(args[0]);
}); });
TVM_REGISTER_API(_Array) TVM_REGISTER_API("_Array")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
std::vector<std::shared_ptr<Node> > data; std::vector<std::shared_ptr<Node> > data;
for (int i = 0; i < args.size(); ++i) { for (int i = 0; i < args.size(); ++i) {
...@@ -42,7 +42,7 @@ TVM_REGISTER_API(_Array) ...@@ -42,7 +42,7 @@ TVM_REGISTER_API(_Array)
*ret = node; *ret = node;
}); });
TVM_REGISTER_API(_ArrayGetItem) TVM_REGISTER_API("_ArrayGetItem")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
int64_t i = args[1]; int64_t i = args[1];
auto& sptr = args[0].node_sptr(); auto& sptr = args[0].node_sptr();
...@@ -53,7 +53,7 @@ TVM_REGISTER_API(_ArrayGetItem) ...@@ -53,7 +53,7 @@ TVM_REGISTER_API(_ArrayGetItem)
*ret = n->data[static_cast<size_t>(i)]; *ret = n->data[static_cast<size_t>(i)];
}); });
TVM_REGISTER_API(_ArraySize) TVM_REGISTER_API("_ArraySize")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
auto& sptr = args[0].node_sptr(); auto& sptr = args[0].node_sptr();
CHECK(sptr->is_type<ArrayNode>()); CHECK(sptr->is_type<ArrayNode>());
...@@ -61,7 +61,7 @@ TVM_REGISTER_API(_ArraySize) ...@@ -61,7 +61,7 @@ TVM_REGISTER_API(_ArraySize)
static_cast<const ArrayNode*>(sptr.get())->data.size()); static_cast<const ArrayNode*>(sptr.get())->data.size());
}); });
TVM_REGISTER_API(_Map) TVM_REGISTER_API("_Map")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
CHECK_EQ(args.size() % 2, 0); CHECK_EQ(args.size() % 2, 0);
MapNode::ContainerType data; MapNode::ContainerType data;
...@@ -78,7 +78,7 @@ TVM_REGISTER_API(_Map) ...@@ -78,7 +78,7 @@ TVM_REGISTER_API(_Map)
*ret = node; *ret = node;
}); });
TVM_REGISTER_API(_MapSize) TVM_REGISTER_API("_MapSize")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
auto& sptr = args[0].node_sptr(); auto& sptr = args[0].node_sptr();
CHECK(sptr->is_type<MapNode>()); CHECK(sptr->is_type<MapNode>());
...@@ -86,7 +86,7 @@ TVM_REGISTER_API(_MapSize) ...@@ -86,7 +86,7 @@ TVM_REGISTER_API(_MapSize)
*ret = static_cast<int64_t>(n->data.size()); *ret = static_cast<int64_t>(n->data.size());
}); });
TVM_REGISTER_API(_MapGetItem) TVM_REGISTER_API("_MapGetItem")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
CHECK(args[0].type_code() == kNodeHandle); CHECK(args[0].type_code() == kNodeHandle);
CHECK(args[1].type_code() == kNodeHandle); CHECK(args[1].type_code() == kNodeHandle);
...@@ -99,7 +99,7 @@ TVM_REGISTER_API(_MapGetItem) ...@@ -99,7 +99,7 @@ TVM_REGISTER_API(_MapGetItem)
*ret = (*it).second; *ret = (*it).second;
}); });
TVM_REGISTER_API(_MapCount) TVM_REGISTER_API("_MapCount")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
CHECK(args[0].type_code() == kNodeHandle); CHECK(args[0].type_code() == kNodeHandle);
CHECK(args[1].type_code() == kNodeHandle); CHECK(args[1].type_code() == kNodeHandle);
...@@ -110,7 +110,7 @@ TVM_REGISTER_API(_MapCount) ...@@ -110,7 +110,7 @@ TVM_REGISTER_API(_MapCount)
n->data.count(args[1].node_sptr())); n->data.count(args[1].node_sptr()));
}); });
TVM_REGISTER_API(_MapItems) TVM_REGISTER_API("_MapItems")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
auto& sptr = args[0].node_sptr(); auto& sptr = args[0].node_sptr();
CHECK(sptr->is_type<MapNode>()); CHECK(sptr->is_type<MapNode>());
...@@ -123,7 +123,7 @@ TVM_REGISTER_API(_MapItems) ...@@ -123,7 +123,7 @@ TVM_REGISTER_API(_MapItems)
*ret = rkvs; *ret = rkvs;
}); });
TVM_REGISTER_API(Range) TVM_REGISTER_API("Range")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
if (args.size() == 1) { if (args.size() == 1) {
*ret = Range(0, args[0]); *ret = Range(0, args[0]);
...@@ -132,7 +132,7 @@ TVM_REGISTER_API(Range) ...@@ -132,7 +132,7 @@ TVM_REGISTER_API(Range)
} }
}); });
TVM_REGISTER_API(_Buffer) TVM_REGISTER_API("_Buffer")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = BufferNode::make(args[0], *ret = BufferNode::make(args[0],
args[1], args[1],
...@@ -143,7 +143,7 @@ TVM_REGISTER_API(_Buffer) ...@@ -143,7 +143,7 @@ TVM_REGISTER_API(_Buffer)
args[6]); args[6]);
}); });
TVM_REGISTER_API(_Tensor) TVM_REGISTER_API("_Tensor")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = TensorNode::make(args[0], *ret = TensorNode::make(args[0],
args[1], args[1],
...@@ -151,32 +151,32 @@ TVM_REGISTER_API(_Tensor) ...@@ -151,32 +151,32 @@ TVM_REGISTER_API(_Tensor)
args[3]); args[3]);
}); });
TVM_REGISTER_API(_TensorEqual) TVM_REGISTER_API("_TensorEqual")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = args[0].operator Tensor() == args[1].operator Tensor(); *ret = args[0].operator Tensor() == args[1].operator Tensor();
}); });
TVM_REGISTER_API(_TensorHash) TVM_REGISTER_API("_TensorHash")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = static_cast<int64_t>( *ret = static_cast<int64_t>(
std::hash<Tensor>()(args[0].operator Tensor())); std::hash<Tensor>()(args[0].operator Tensor()));
}); });
TVM_REGISTER_API(_Placeholder) TVM_REGISTER_API("_Placeholder")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = placeholder(args[0], *ret = placeholder(args[0],
args[1], args[1],
args[2]); args[2]);
}); });
TVM_REGISTER_API(_ComputeOp) TVM_REGISTER_API("_ComputeOp")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = ComputeOpNode::make(args[0], *ret = ComputeOpNode::make(args[0],
args[1], args[1],
args[2]); args[2]);
}); });
TVM_REGISTER_API(_ScanOp) TVM_REGISTER_API("_ScanOp")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = ScanOpNode::make(args[0], *ret = ScanOpNode::make(args[0],
args[1], args[1],
...@@ -186,7 +186,7 @@ TVM_REGISTER_API(_ScanOp) ...@@ -186,7 +186,7 @@ TVM_REGISTER_API(_ScanOp)
args[5]); args[5]);
}); });
TVM_REGISTER_API(_ExternOp) TVM_REGISTER_API("_ExternOp")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = ExternOpNode::make(args[0], *ret = ExternOpNode::make(args[0],
args[1], args[1],
...@@ -195,18 +195,18 @@ TVM_REGISTER_API(_ExternOp) ...@@ -195,18 +195,18 @@ TVM_REGISTER_API(_ExternOp)
args[4]); args[4]);
}); });
TVM_REGISTER_API(_OpGetOutput) TVM_REGISTER_API("_OpGetOutput")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = args[0].operator Operation().output( *ret = args[0].operator Operation().output(
static_cast<size_t>(args[1].operator int64_t())); static_cast<size_t>(args[1].operator int64_t()));
}); });
TVM_REGISTER_API(_OpNumOutputs) TVM_REGISTER_API("_OpNumOutputs")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = args[0].operator Operation()->num_outputs(); *ret = args[0].operator Operation()->num_outputs();
}); });
TVM_REGISTER_API(_IterVar) TVM_REGISTER_API("_IterVar")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = IterVarNode::make( *ret = IterVarNode::make(
args[0], args[1], args[0], args[1],
...@@ -214,24 +214,24 @@ TVM_REGISTER_API(_IterVar) ...@@ -214,24 +214,24 @@ TVM_REGISTER_API(_IterVar)
args[3]); args[3]);
}); });
TVM_REGISTER_API(_Schedule) TVM_REGISTER_API("_Schedule")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = Schedule(args[0].operator Array<Operation>()); *ret = Schedule(args[0].operator Array<Operation>());
}); });
TVM_REGISTER_API(_StageSetScope) TVM_REGISTER_API("_StageSetScope")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
args[0].operator Stage() args[0].operator Stage()
.set_scope(args[1]); .set_scope(args[1]);
}); });
TVM_REGISTER_API(_StageBind) TVM_REGISTER_API("_StageBind")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
args[0].operator Stage() args[0].operator Stage()
.bind(args[1], args[2]); .bind(args[1], args[2]);
}); });
TVM_REGISTER_API(_StageSplitByFactor) TVM_REGISTER_API("_StageSplitByFactor")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
IterVar outer, inner; IterVar outer, inner;
args[0].operator Stage() args[0].operator Stage()
...@@ -239,7 +239,7 @@ TVM_REGISTER_API(_StageSplitByFactor) ...@@ -239,7 +239,7 @@ TVM_REGISTER_API(_StageSplitByFactor)
*ret = Array<IterVar>({outer, inner}); *ret = Array<IterVar>({outer, inner});
}); });
TVM_REGISTER_API(_StageSplitByNParts) TVM_REGISTER_API("_StageSplitByNParts")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
IterVar outer, inner; IterVar outer, inner;
args[0].operator Stage() args[0].operator Stage()
...@@ -247,7 +247,7 @@ TVM_REGISTER_API(_StageSplitByNParts) ...@@ -247,7 +247,7 @@ TVM_REGISTER_API(_StageSplitByNParts)
*ret = Array<IterVar>({outer, inner}); *ret = Array<IterVar>({outer, inner});
}); });
TVM_REGISTER_API(_StageFuse) TVM_REGISTER_API("_StageFuse")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
IterVar fused; IterVar fused;
args[0].operator Stage() args[0].operator Stage()
...@@ -255,31 +255,31 @@ TVM_REGISTER_API(_StageFuse) ...@@ -255,31 +255,31 @@ TVM_REGISTER_API(_StageFuse)
*ret = fused; *ret = fused;
}); });
TVM_REGISTER_API(_StageComputeAt) TVM_REGISTER_API("_StageComputeAt")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
args[0].operator Stage() args[0].operator Stage()
.compute_at(args[1], args[2]); .compute_at(args[1], args[2]);
}); });
TVM_REGISTER_API(_StageComputeInline) TVM_REGISTER_API("_StageComputeInline")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
args[0].operator Stage() args[0].operator Stage()
.compute_inline(); .compute_inline();
}); });
TVM_REGISTER_API(_StageComputeRoot) TVM_REGISTER_API("_StageComputeRoot")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
args[0].operator Stage() args[0].operator Stage()
.compute_root(); .compute_root();
}); });
TVM_REGISTER_API(_StageReorder) TVM_REGISTER_API("_StageReorder")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
args[0].operator Stage() args[0].operator Stage()
.reorder(args[1]); .reorder(args[1]);
}); });
TVM_REGISTER_API(_StageTile) TVM_REGISTER_API("_StageTile")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
IterVar x_outer, y_outer, x_inner, y_inner; IterVar x_outer, y_outer, x_inner, y_inner;
args[0].operator Stage() args[0].operator Stage()
...@@ -290,55 +290,55 @@ TVM_REGISTER_API(_StageTile) ...@@ -290,55 +290,55 @@ TVM_REGISTER_API(_StageTile)
*ret = Array<IterVar>({x_outer, y_outer, x_inner, y_inner}); *ret = Array<IterVar>({x_outer, y_outer, x_inner, y_inner});
}); });
TVM_REGISTER_API(_StageEnvThreads) TVM_REGISTER_API("_StageEnvThreads")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
args[0].operator Stage() args[0].operator Stage()
.env_threads(args[1]); .env_threads(args[1]);
}); });
TVM_REGISTER_API(_StageUnroll) TVM_REGISTER_API("_StageUnroll")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
args[0].operator Stage() args[0].operator Stage()
.unroll(args[1]); .unroll(args[1]);
}); });
TVM_REGISTER_API(_StageVectorize) TVM_REGISTER_API("_StageVectorize")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
args[0].operator Stage() args[0].operator Stage()
.vectorize(args[1]); .vectorize(args[1]);
}); });
TVM_REGISTER_API(_StageParallel) TVM_REGISTER_API("_StageParallel")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
args[0].operator Stage() args[0].operator Stage()
.parallel(args[1]); .parallel(args[1]);
}); });
TVM_REGISTER_API(_ScheduleNormalize) TVM_REGISTER_API("_ScheduleNormalize")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = args[0].operator Schedule() *ret = args[0].operator Schedule()
.normalize(); .normalize();
}); });
TVM_REGISTER_API(_ScheduleCreateGroup) TVM_REGISTER_API("_ScheduleCreateGroup")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = args[0].operator Schedule() *ret = args[0].operator Schedule()
.create_group(args[1], args[2], args[3]); .create_group(args[1], args[2], args[3]);
}); });
TVM_REGISTER_API(_ScheduleCacheRead) TVM_REGISTER_API("_ScheduleCacheRead")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = args[0].operator Schedule() *ret = args[0].operator Schedule()
.cache_read(args[1], args[2], args[3]); .cache_read(args[1], args[2], args[3]);
}); });
TVM_REGISTER_API(_ScheduleCacheWrite) TVM_REGISTER_API("_ScheduleCacheWrite")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = args[0].operator Schedule() *ret = args[0].operator Schedule()
.cache_write(args[1], args[2]); .cache_write(args[1], args[2]);
}); });
TVM_REGISTER_API(_ScheduleRFactor) TVM_REGISTER_API("_ScheduleRFactor")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = args[0].operator Schedule() *ret = args[0].operator Schedule()
.rfactor(args[1], args[2]); .rfactor(args[1], args[2]);
......
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
namespace tvm { namespace tvm {
namespace ir { namespace ir {
TVM_REGISTER_API(_pass_Simplify) TVM_REGISTER_API("ir_pass.Simplify")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
if (args[0].IsNodeType<Stmt>()) { if (args[0].IsNodeType<Stmt>()) {
*ret = Simplify(args[0].operator Stmt()); *ret = Simplify(args[0].operator Stmt());
...@@ -21,7 +21,7 @@ TVM_REGISTER_API(_pass_Simplify) ...@@ -21,7 +21,7 @@ TVM_REGISTER_API(_pass_Simplify)
} }
}); });
TVM_REGISTER_API(_pass_Equal) TVM_REGISTER_API("ir_pass.Equal")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
if (args[0].IsNodeType<Stmt>()) { if (args[0].IsNodeType<Stmt>()) {
*ret = Equal(args[0].operator Stmt(), args[1].operator Stmt()); *ret = Equal(args[0].operator Stmt(), args[1].operator Stmt());
...@@ -30,7 +30,7 @@ TVM_REGISTER_API(_pass_Equal) ...@@ -30,7 +30,7 @@ TVM_REGISTER_API(_pass_Equal)
} }
}); });
TVM_REGISTER_API(_pass_PostOrderVisit) TVM_REGISTER_API("ir_pass.PostOrderVisit")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
PackedFunc f = args[1]; PackedFunc f = args[1];
ir::PostOrderVisit(args[0], [f](const NodeRef& n) { ir::PostOrderVisit(args[0], [f](const NodeRef& n) {
...@@ -40,19 +40,19 @@ TVM_REGISTER_API(_pass_PostOrderVisit) ...@@ -40,19 +40,19 @@ TVM_REGISTER_API(_pass_PostOrderVisit)
// make from two arguments // make from two arguments
#define REGISTER_PASS1(PassName) \ #define REGISTER_PASS1(PassName) \
TVM_REGISTER_API(_pass_## PassName) \ TVM_REGISTER_API("ir_pass."#PassName) \
.set_body([](TVMArgs args, TVMRetValue *ret) { \ .set_body([](TVMArgs args, TVMRetValue *ret) { \
*ret = PassName(args[0]); \ *ret = PassName(args[0]); \
}) \ }) \
#define REGISTER_PASS2(PassName) \ #define REGISTER_PASS2(PassName) \
TVM_REGISTER_API(_pass_## PassName) \ TVM_REGISTER_API("ir_pass."#PassName) \
.set_body([](TVMArgs args, TVMRetValue *ret) { \ .set_body([](TVMArgs args, TVMRetValue *ret) { \
*ret = PassName(args[0], args[1]); \ *ret = PassName(args[0], args[1]); \
}) \ }) \
#define REGISTER_PASS4(PassName) \ #define REGISTER_PASS4(PassName) \
TVM_REGISTER_API(_pass_## PassName) \ TVM_REGISTER_API("ir_pass."#PassName) \
.set_body([](TVMArgs args, TVMRetValue *ret) { \ .set_body([](TVMArgs args, TVMRetValue *ret) { \
*ret = PassName(args[0], args[1], args[2], args[3]); \ *ret = PassName(args[0], args[1], args[2], args[3]); \
}) \ }) \
......
...@@ -13,19 +13,19 @@ ...@@ -13,19 +13,19 @@
namespace tvm { namespace tvm {
namespace schedule { namespace schedule {
TVM_REGISTER_API(_schedule_AutoInlineElemWise) TVM_REGISTER_API("schedule.AutoInlineElemWise")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
AutoInlineElemWise(args[0]); AutoInlineElemWise(args[0]);
}); });
#define REGISTER_SCHEDULE_PASS1(PassName) \ #define REGISTER_SCHEDULE_PASS1(PassName) \
TVM_REGISTER_API(_schedule_## PassName) \ TVM_REGISTER_API("schedule."#PassName) \
.set_body([](TVMArgs args, TVMRetValue *ret) { \ .set_body([](TVMArgs args, TVMRetValue *ret) { \
*ret = PassName(args[0]); \ *ret = PassName(args[0]); \
}) \ }) \
#define REGISTER_SCHEDULE_PASS2(PassName) \ #define REGISTER_SCHEDULE_PASS2(PassName) \
TVM_REGISTER_API(_schedule_## PassName) \ TVM_REGISTER_API("schedule."#PassName) \
.set_body([](TVMArgs args, TVMRetValue *ret) { \ .set_body([](TVMArgs args, TVMRetValue *ret) { \
*ret = PassName(args[0], args[1]); \ *ret = PassName(args[0], args[1]); \
}) \ }) \
......
...@@ -103,6 +103,23 @@ int TVMNodeFree(NodeHandle handle) { ...@@ -103,6 +103,23 @@ int TVMNodeFree(NodeHandle handle) {
API_END(); API_END();
} }
int TVMCbArgToReturn(TVMValue* value, int code) {
API_BEGIN();
tvm::runtime::TVMRetValue rv;
rv = tvm::runtime::TVMArgValue(*value, code);
int tcode;
rv.MoveToCHost(value, &tcode);
CHECK_EQ(tcode, code);
API_END();
}
int TVMNodeDupe(NodeHandle handle, NodeHandle* out_handle) {
API_BEGIN();
*out_handle = new TVMAPINode(*static_cast<TVMAPINode*>(handle));
API_END();
}
int TVMNodeGetAttr(NodeHandle handle, int TVMNodeGetAttr(NodeHandle handle,
const char* key, const char* key,
TVMValue* ret_val, TVMValue* ret_val,
......
...@@ -86,7 +86,7 @@ runtime::Module BuildCUDA(Array<LoweredFunc> funcs) { ...@@ -86,7 +86,7 @@ runtime::Module BuildCUDA(Array<LoweredFunc> funcs) {
return CUDAModuleCreate(ptx, fmt, fmap, code); return CUDAModuleCreate(ptx, fmt, fmap, code);
} }
TVM_REGISTER_API(_codegen_build_cuda) TVM_REGISTER_API("codegen.build_cuda")
.set_body([](TVMArgs args, TVMRetValue* rv) { .set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = BuildCUDA(args[0]); *rv = BuildCUDA(args[0]);
}); });
......
...@@ -41,7 +41,7 @@ runtime::Module BuildOpenCL(Array<LoweredFunc> funcs) { ...@@ -41,7 +41,7 @@ runtime::Module BuildOpenCL(Array<LoweredFunc> funcs) {
return OpenCLModuleCreate(code, "cl", fmap); return OpenCLModuleCreate(code, "cl", fmap);
} }
TVM_REGISTER_API(_codegen_build_opencl) TVM_REGISTER_API("codegen.build_opencl")
.set_body([](TVMArgs args, TVMRetValue* rv) { .set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = BuildOpenCL(args[0]); *rv = BuildOpenCL(args[0]);
}); });
......
...@@ -17,7 +17,7 @@ runtime::Module Build(const Array<LoweredFunc>& funcs, ...@@ -17,7 +17,7 @@ runtime::Module Build(const Array<LoweredFunc>& funcs,
if (pos != std::string::npos) { if (pos != std::string::npos) {
mode = mode.substr(0, pos); mode = mode.substr(0, pos);
} }
std::string build_f_name = "_codegen_build_" + mode; std::string build_f_name = "codegen.build_" + mode;
const PackedFunc* bf = runtime::Registry::Get(build_f_name); const PackedFunc* bf = runtime::Registry::Get(build_f_name);
CHECK(bf != nullptr) CHECK(bf != nullptr)
...@@ -27,7 +27,7 @@ runtime::Module Build(const Array<LoweredFunc>& funcs, ...@@ -27,7 +27,7 @@ runtime::Module Build(const Array<LoweredFunc>& funcs,
} }
bool TargetEnabled(const std::string& target) { bool TargetEnabled(const std::string& target) {
std::string build_f_name = "_codegen_build_" + target; std::string build_f_name = "codegen.build_" + target;
return runtime::Registry::Get(build_f_name) != nullptr; return runtime::Registry::Get(build_f_name) != nullptr;
} }
......
...@@ -150,7 +150,7 @@ class LLVMModuleNode : public runtime::ModuleNode { ...@@ -150,7 +150,7 @@ class LLVMModuleNode : public runtime::ModuleNode {
std::shared_ptr<llvm::LLVMContext> ctx_; std::shared_ptr<llvm::LLVMContext> ctx_;
}; };
TVM_REGISTER_API(_codegen_build_llvm) TVM_REGISTER_API("codegen.build_llvm")
.set_body([](TVMArgs args, TVMRetValue* rv) { .set_body([](TVMArgs args, TVMRetValue* rv) {
std::shared_ptr<LLVMModuleNode> n = std::make_shared<LLVMModuleNode>(); std::shared_ptr<LLVMModuleNode> n = std::make_shared<LLVMModuleNode>();
n->Init(args[0], args[1]); n->Init(args[0], args[1]);
......
...@@ -69,7 +69,7 @@ class StackVMModuleNode : public runtime::ModuleNode { ...@@ -69,7 +69,7 @@ class StackVMModuleNode : public runtime::ModuleNode {
std::unordered_map<std::string, StackVM> fmap_; std::unordered_map<std::string, StackVM> fmap_;
}; };
TVM_REGISTER_API(_codegen_build_stackvm) TVM_REGISTER_API("codegen.build_stackvm")
.set_body([](TVMArgs args, TVMRetValue* rv) { .set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = StackVMModuleNode::Build(args[0]); *rv = StackVMModuleNode::Build(args[0]);
}); });
......
...@@ -86,7 +86,7 @@ class VerilogModuleNode : public runtime::ModuleNode { ...@@ -86,7 +86,7 @@ class VerilogModuleNode : public runtime::ModuleNode {
std::string fmt_; std::string fmt_;
}; };
TVM_REGISTER_API(_codegen_build_verilog) TVM_REGISTER_API("codegen.build_verilog")
.set_body([](TVMArgs args, TVMRetValue* rv) { .set_body([](TVMArgs args, TVMRetValue* rv) {
std::shared_ptr<VerilogModuleNode> n = std::shared_ptr<VerilogModuleNode> n =
std::make_shared<VerilogModuleNode>(); std::make_shared<VerilogModuleNode>();
......
...@@ -385,7 +385,7 @@ class VPIWriteMemMap : public VPIMemMapBase { ...@@ -385,7 +385,7 @@ class VPIWriteMemMap : public VPIMemMapBase {
VPIHandle enable_; VPIHandle enable_;
}; };
TVM_REGISTER_GLOBAL(_device_api_vpi) TVM_REGISTER_GLOBAL("device_api.vpi")
.set_body([](runtime::TVMArgs args, runtime::TVMRetValue* rv) { .set_body([](runtime::TVMArgs args, runtime::TVMRetValue* rv) {
runtime::DeviceAPI* ptr = VPIDeviceAPI::Global(); runtime::DeviceAPI* ptr = VPIDeviceAPI::Global();
*rv = static_cast<void*>(ptr); *rv = static_cast<void*>(ptr);
...@@ -403,13 +403,13 @@ void TVMVPIHook(runtime::TVMArgs args, runtime::TVMRetValue* rv) { ...@@ -403,13 +403,13 @@ void TVMVPIHook(runtime::TVMArgs args, runtime::TVMRetValue* rv) {
*rv = pf; *rv = pf;
} }
TVM_REGISTER_GLOBAL(_vpi_module_tvm_vpi_mem_interface) TVM_REGISTER_GLOBAL("_vpi_module_tvm_vpi_mem_interface")
.set_body(TVMVPIHook<VPIMemoryInterface>); .set_body(TVMVPIHook<VPIMemoryInterface>);
TVM_REGISTER_GLOBAL(_vpi_module_tvm_vpi_read_mmap) TVM_REGISTER_GLOBAL("_vpi_module_tvm_vpi_read_mmap")
.set_body(TVMVPIHook<VPIReadMemMap>); .set_body(TVMVPIHook<VPIReadMemMap>);
TVM_REGISTER_GLOBAL(_vpi_module_tvm_vpi_write_mmap) TVM_REGISTER_GLOBAL("_vpi_module_tvm_vpi_write_mmap")
.set_body(TVMVPIHook<VPIWriteMemMap>); .set_body(TVMVPIHook<VPIWriteMemMap>);
} // namespace codegen } // namespace codegen
......
...@@ -212,47 +212,47 @@ VPIHandle VPIHandle::operator[](const std::string& name) const { ...@@ -212,47 +212,47 @@ VPIHandle VPIHandle::operator[](const std::string& name) const {
} }
// API registration // API registration
TVM_REGISTER_API(_vpi_SessMake) TVM_REGISTER_API("_vpi_SessMake")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = VPISession::make(args[0], args[1]); *ret = VPISession::make(args[0], args[1]);
}); });
TVM_REGISTER_API(_vpi_SessGetHandleByName) TVM_REGISTER_API("_vpi_SessGetHandleByName")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = args[0].operator VPISession().operator[](args[1]); *ret = args[0].operator VPISession().operator[](args[1]);
}); });
TVM_REGISTER_API(_vpi_SessYield) TVM_REGISTER_API("_vpi_SessYield")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
args[0].operator VPISession().yield(); args[0].operator VPISession().yield();
}); });
TVM_REGISTER_API(_vpi_SessShutdown) TVM_REGISTER_API("_vpi_SessShutdown")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
args[0].operator VPISession().shutdown(); args[0].operator VPISession().shutdown();
}); });
TVM_REGISTER_API(_vpi_HandlePutInt) TVM_REGISTER_API("_vpi_HandlePutInt")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
args[0].operator VPIHandle().put_int(args[1]); args[0].operator VPIHandle().put_int(args[1]);
}); });
TVM_REGISTER_API(_vpi_HandleGetInt) TVM_REGISTER_API("_vpi_HandleGetInt")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = args[0].operator VPIHandle().get_int(); *ret = args[0].operator VPIHandle().get_int();
}); });
TVM_REGISTER_API(_vpi_HandleGetName) TVM_REGISTER_API("_vpi_HandleGetName")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = args[0].operator VPIHandle().name(); *ret = args[0].operator VPIHandle().name();
}); });
TVM_REGISTER_API(_vpi_HandleGetSize) TVM_REGISTER_API("_vpi_HandleGetSize")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = args[0].operator VPIHandle().size(); *ret = args[0].operator VPIHandle().size();
}); });
TVM_REGISTER_API(_vpi_HandleGetHandleByName) TVM_REGISTER_API("_vpi_HandleGetHandleByName")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = args[0].operator VPIHandle().operator[](args[1]); *ret = args[0].operator VPIHandle().operator[](args[1]);
}); });
......
...@@ -46,7 +46,7 @@ class DeviceAPIManager { ...@@ -46,7 +46,7 @@ class DeviceAPIManager {
if (api_[type] != nullptr) return api_[type]; if (api_[type] != nullptr) return api_[type];
std::lock_guard<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock(mutex_);
if (api_[type] != nullptr) return api_[type]; if (api_[type] != nullptr) return api_[type];
std::string factory = "_device_api_" + DeviceName(type); std::string factory = "device_api." + DeviceName(type);
auto* f = Registry::Get(factory); auto* f = Registry::Get(factory);
CHECK(f != nullptr) CHECK(f != nullptr)
<< "Device API " << DeviceName(type) << " is not enabled."; << "Device API " << DeviceName(type) << " is not enabled.";
......
...@@ -46,7 +46,7 @@ class CPUDeviceAPI : public DeviceAPI { ...@@ -46,7 +46,7 @@ class CPUDeviceAPI : public DeviceAPI {
} }
}; };
TVM_REGISTER_GLOBAL(_device_api_cpu) TVM_REGISTER_GLOBAL("device_api.cpu")
.set_body([](TVMArgs args, TVMRetValue* rv) { .set_body([](TVMArgs args, TVMRetValue* rv) {
static CPUDeviceAPI inst; static CPUDeviceAPI inst;
DeviceAPI* ptr = &inst; DeviceAPI* ptr = &inst;
......
...@@ -77,7 +77,7 @@ class CUDADeviceAPI : public DeviceAPI { ...@@ -77,7 +77,7 @@ class CUDADeviceAPI : public DeviceAPI {
} }
}; };
TVM_REGISTER_GLOBAL(_device_api_gpu) TVM_REGISTER_GLOBAL("device_api.gpu")
.set_body([](TVMArgs args, TVMRetValue* rv) { .set_body([](TVMArgs args, TVMRetValue* rv) {
static CUDADeviceAPI inst; static CUDADeviceAPI inst;
DeviceAPI* ptr = &inst; DeviceAPI* ptr = &inst;
......
...@@ -281,12 +281,12 @@ Module CUDAModuleLoad(const std::string& file_name, ...@@ -281,12 +281,12 @@ Module CUDAModuleLoad(const std::string& file_name,
return CUDAModuleCreate(data, fmt, fmap, std::string()); return CUDAModuleCreate(data, fmt, fmap, std::string());
} }
TVM_REGISTER_GLOBAL(_module_loadfile_cubin) TVM_REGISTER_GLOBAL("module.loadfile_cubin")
.set_body([](TVMArgs args, TVMRetValue* rv) { .set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = CUDAModuleLoad(args[0], args[1]); *rv = CUDAModuleLoad(args[0], args[1]);
}); });
TVM_REGISTER_GLOBAL(_module_loadfile_ptx) TVM_REGISTER_GLOBAL("module.loadfile_ptx")
.set_body([](TVMArgs args, TVMRetValue* rv) { .set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = CUDAModuleLoad(args[0], args[1]); *rv = CUDAModuleLoad(args[0], args[1]);
}); });
......
...@@ -110,7 +110,7 @@ class DSOModuleNode : public ModuleNode { ...@@ -110,7 +110,7 @@ class DSOModuleNode : public ModuleNode {
#endif #endif
}; };
TVM_REGISTER_GLOBAL(_module_loadfile_so) TVM_REGISTER_GLOBAL("module.loadfile_so")
.set_body([](TVMArgs args, TVMRetValue* rv) { .set_body([](TVMArgs args, TVMRetValue* rv) {
std::shared_ptr<DSOModuleNode> n = std::make_shared<DSOModuleNode>(); std::shared_ptr<DSOModuleNode> n = std::make_shared<DSOModuleNode>();
n->Init(args[0]); n->Init(args[0]);
......
...@@ -53,7 +53,7 @@ Module Module::LoadFromFile(const std::string& file_name, ...@@ -53,7 +53,7 @@ Module Module::LoadFromFile(const std::string& file_name,
if (fmt == "dll" || fmt == "dylib" || fmt == "dso") { if (fmt == "dll" || fmt == "dylib" || fmt == "dso") {
fmt = "so"; fmt = "so";
} }
std::string load_f_name = "_module_loadfile_" + fmt; std::string load_f_name = "module.loadfile_" + fmt;
const PackedFunc* f = Registry::Get(load_f_name); const PackedFunc* f = Registry::Get(load_f_name);
CHECK(f != nullptr) CHECK(f != nullptr)
<< "Loader of " << format << "(" << "Loader of " << format << "("
...@@ -88,48 +88,48 @@ bool RuntimeEnabled(const std::string& target) { ...@@ -88,48 +88,48 @@ bool RuntimeEnabled(const std::string& target) {
if (target == "cpu") { if (target == "cpu") {
return true; return true;
} else if (target == "cuda" || target == "gpu") { } else if (target == "cuda" || target == "gpu") {
load_f_name = "_module_loadfile_ptx"; load_f_name = "module.loadfile_ptx";
} else if (target == "cl" || target == "opencl") { } else if (target == "cl" || target == "opencl") {
load_f_name = "_module_loadfile_cl"; load_f_name = "module.loadfile_cl";
} else { } else {
LOG(FATAL) << "Unknown optional runtime " << target; LOG(FATAL) << "Unknown optional runtime " << target;
} }
return runtime::Registry::Get(load_f_name) != nullptr; return runtime::Registry::Get(load_f_name) != nullptr;
} }
TVM_REGISTER_GLOBAL(_module__Enabled) TVM_REGISTER_GLOBAL("module._Enabled")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = RuntimeEnabled(args[0]); *ret = RuntimeEnabled(args[0]);
}); });
TVM_REGISTER_GLOBAL(_module__GetSource) TVM_REGISTER_GLOBAL("module._GetSource")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = args[0].operator Module()->GetSource(args[1]); *ret = args[0].operator Module()->GetSource(args[1]);
}); });
TVM_REGISTER_GLOBAL(_module__ImportsSize) TVM_REGISTER_GLOBAL("module._ImportsSize")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = static_cast<int64_t>( *ret = static_cast<int64_t>(
args[0].operator Module()->imports().size()); args[0].operator Module()->imports().size());
}); });
TVM_REGISTER_GLOBAL(_module__GetImport) TVM_REGISTER_GLOBAL("module._GetImport")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = args[0].operator Module()-> *ret = args[0].operator Module()->
imports().at(args[1].operator int()); imports().at(args[1].operator int());
}); });
TVM_REGISTER_GLOBAL(_module__GetTypeKey) TVM_REGISTER_GLOBAL("module._GetTypeKey")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = std::string(args[0].operator Module()->type_key()); *ret = std::string(args[0].operator Module()->type_key());
}); });
TVM_REGISTER_GLOBAL(_module__LoadFromFile) TVM_REGISTER_GLOBAL("module._LoadFromFile")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = Module::LoadFromFile(args[0], args[1]); *ret = Module::LoadFromFile(args[0], args[1]);
}); });
TVM_REGISTER_GLOBAL(_module__SaveToFile) TVM_REGISTER_GLOBAL("module._SaveToFile")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
args[0].operator Module()-> args[0].operator Module()->
SaveToFile(args[1], args[2]); SaveToFile(args[1], args[2]);
......
...@@ -189,10 +189,10 @@ bool InitOpenCL(TVMArgs args, TVMRetValue* rv) { ...@@ -189,10 +189,10 @@ bool InitOpenCL(TVMArgs args, TVMRetValue* rv) {
return true; return true;
} }
TVM_REGISTER_GLOBAL(_module_init_opencl) TVM_REGISTER_GLOBAL("module.init_opencl")
.set_body(InitOpenCL); .set_body(InitOpenCL);
TVM_REGISTER_GLOBAL(_device_api_opencl) TVM_REGISTER_GLOBAL("device_api.opencl")
.set_body([](TVMArgs args, TVMRetValue* rv) { .set_body([](TVMArgs args, TVMRetValue* rv) {
DeviceAPI* ptr = OpenCLWorkspace::Global(); DeviceAPI* ptr = OpenCLWorkspace::Global();
*rv = static_cast<void*>(ptr); *rv = static_cast<void*>(ptr);
......
...@@ -314,12 +314,12 @@ Module OpenCLModuleLoad(const std::string& file_name, ...@@ -314,12 +314,12 @@ Module OpenCLModuleLoad(const std::string& file_name,
return OpenCLModuleCreate(data, fmt, fmap); return OpenCLModuleCreate(data, fmt, fmap);
} }
TVM_REGISTER_GLOBAL(_module_loadfile_cl) TVM_REGISTER_GLOBAL("module.loadfile_cl")
.set_body([](TVMArgs args, TVMRetValue* rv) { .set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = OpenCLModuleLoad(args[0], args[1]); *rv = OpenCLModuleLoad(args[0], args[1]);
}); });
TVM_REGISTER_GLOBAL(_module_loadfile_clbin) TVM_REGISTER_GLOBAL("module.loadfile_clbin")
.set_body([](TVMArgs args, TVMRetValue* rv) { .set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = OpenCLModuleLoad(args[0], args[1]); *rv = OpenCLModuleLoad(args[0], args[1]);
}); });
......
...@@ -14,6 +14,25 @@ def test_get_global(): ...@@ -14,6 +14,25 @@ def test_get_global():
y = f(*targs) y = f(*targs)
assert y == 10 assert y == 10
def test_get_callback_with_node():
x = tvm.convert(10)
def test(y):
assert y.handle != x.handle
return y
f2 = tvm.convert(test)
# register into global function table
@tvm.register_func
def my_callback_with_node(y, f):
assert y == x
return f(y)
# get it out from global function table
f = tvm.get_global_func("my_callback_with_node")
assert isinstance(f, tvm.Function)
y = f(x, f2)
assert(y.value == 10)
def test_return_func(): def test_return_func():
def addy(y): def addy(y):
...@@ -45,6 +64,7 @@ def test_byte_array(): ...@@ -45,6 +64,7 @@ def test_byte_array():
f(a) f(a)
if __name__ == "__main__": if __name__ == "__main__":
test_get_callback_with_node()
test_convert() test_convert()
test_get_global() test_get_global()
test_return_func() test_return_func()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment