Commit 181edb4a by Tianqi Chen Committed by GitHub

[LANG] Change namespace convention to dot (#100)

parent 97c67e53
......@@ -12,6 +12,7 @@
#include <string>
#include <memory>
#include <functional>
#include "./runtime/registry.h"
namespace tvm {
......
......@@ -20,6 +20,18 @@ TVM_EXTERN_C {
typedef void* NodeHandle;
/*!
* \brief Inplace translate callback argument value to return value.
* This is only needed for non-POD arguments.
*
* \param value The value to be translated.
* \param code The type code to be translated.
* \note This function will do a shallow copy when necessary.
*
* \return 0 when success, -1 when failure happens.
*/
TVM_DLL int TVMCbArgToReturn(TVMValue* value, int code);
/*!
* \brief free the node handle
* \param handle The node handle to be freed.
* \return 0 when success, -1 when failure happens
......
......@@ -81,7 +81,6 @@ class Registry {
friend struct Manager;
};
/*! \brief helper macro to supress unused warning */
#if defined(__GNUC__)
#define TVM_ATTRIBUTE_UNUSED __attribute__((unused))
......@@ -89,19 +88,23 @@ class Registry {
#define TVM_ATTRIBUTE_UNUSED
#endif
#define TVM_STR_CONCAT_(__x, __y) __x##__y
#define TVM_STR_CONCAT(__x, __y) TVM_STR_CONCAT_(__x, __y)
#define TVM_FUNC_REG_VAR_DEF \
static TVM_ATTRIBUTE_UNUSED ::tvm::runtime::Registry& __make_ ## TVMOp
/*!
* \brief Register a function globally.
* \code
* TVM_REGISTER_GLOBAL(MyPrint)
* TVM_REGISTER_GLOBAL("MyPrint")
* .set_body([](TVMArgs args, TVMRetValue* rv) {
* // my code.
* });
* \endcode
*/
#define TVM_REGISTER_GLOBAL(OpName) \
static TVM_ATTRIBUTE_UNUSED ::tvm::runtime::Registry& \
__make_TVMRegistry_ ## OpName = \
::tvm::runtime::Registry::Register(#OpName)
TVM_STR_CONCAT(TVM_FUNC_REG_VAR_DEF, __COUNTER__) = \
::tvm::runtime::Registry::Register(OpName)
} // namespace runtime
} // namespace tvm
......
......@@ -12,7 +12,7 @@ from .._base import _LIB, check_call
from .._base import c_str, py_str, string_types
from ._types import TVMValue, TypeCode, TVMType, TVMByteArray
from ._types import TVMPackedCFunc, TVMCFuncFinalizer
from ._types import RETURN_SWITCH, C_TO_PY_ARG_SWITCH
from ._types import RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func
from ._node import NodeBase, SliceBase, convert_to_node
from ._ndarray import NDArrayBase
......@@ -302,6 +302,10 @@ def _handle_return_func(x):
# setup return handle for function type
RETURN_SWITCH[TypeCode.FUNC_HANDLE] = _handle_return_func
RETURN_SWITCH[TypeCode.MODULE_HANDLE] = _return_module
C_TO_PY_ARG_SWITCH[TypeCode.FUNC_HANDLE] = _wrap_arg_func(
_handle_return_func, TypeCode.FUNC_HANDLE)
C_TO_PY_ARG_SWITCH[TypeCode.MODULE_HANDLE] = _wrap_arg_func(
_return_module, TypeCode.MODULE_HANDLE)
def register_func(func_name, f=None):
......@@ -415,35 +419,31 @@ def _get_api(f):
return flocal(*args)
return my_api_func
def _init_api(mod):
def _init_api(namespace):
"""Initialize api for a given module name
mod : str
The name of the module.
"""
module = sys.modules[mod]
namespace_match = {
"_make_": "tvm.make",
"_arith_": "tvm.arith",
"_pass_": "tvm.ir_pass",
"_codegen_": "tvm.codegen",
"_module_": "tvm.module",
"_schedule_": "tvm.schedule"
}
module = sys.modules[namespace]
assert namespace.startswith("tvm.")
prefix = namespace[4:]
for name in list_global_func_names():
fname = name
target = "tvm.api"
for k, v in namespace_match.items():
if name.startswith(k):
fname = name[len(k):]
target = v
if target != mod:
continue
if mod == "tvm.api" and name.startswith("_"):
target_module = sys.modules["tvm._api_internal"]
if prefix == "api":
fname = name
if name.startswith("_"):
target_module = sys.modules["tvm._api_internal"]
else:
target_module = module
else:
if not name.startswith(prefix):
continue
fname = name[len(prefix)+1:]
target_module = module
if fname.find(".") != -1:
continue
f = get_global_func(name)
ff = _get_api(f)
ff.__name__ = fname
......
......@@ -10,7 +10,9 @@ from numbers import Number, Integral
from .._base import _LIB, check_call
from .._base import c_str, py_str, string_types
from .. import _api_internal
from ._types import TVMValue, TypeCode, RETURN_SWITCH
from ._types import TVMValue, TypeCode
from ._types import RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func
NodeHandle = ctypes.c_void_p
......@@ -19,7 +21,7 @@ NODE_TYPE = {
}
def _return_node(x):
"""Return function"""
"""Return node function"""
handle = x.v_handle
if not isinstance(handle, NodeHandle):
handle = NodeHandle(handle)
......@@ -35,6 +37,8 @@ def _return_node(x):
RETURN_SWITCH[TypeCode.NODE_HANDLE] = _return_node
C_TO_PY_ARG_SWITCH[TypeCode.NODE_HANDLE] = _wrap_arg_func(
_return_node, TypeCode.NODE_HANDLE)
class SliceBase(object):
......
......@@ -4,7 +4,7 @@ from __future__ import absolute_import as _abs
import ctypes
import numpy as np
from .._base import py_str
from .._base import py_str, check_call, _LIB
tvm_shape_index_t = ctypes.c_int64
......@@ -130,6 +130,12 @@ def _return_bytes(x):
raise RuntimeError('memmove failed')
return res
def _wrap_arg_func(return_f, type_code):
tcode = ctypes.c_int(type_code)
def _wrap_func(x):
check_call(_LIB.TVMCbArgToReturn(ctypes.byref(x), tcode))
return return_f(x)
return _wrap_func
RETURN_SWITCH = {
TypeCode.INT: lambda x: x.v_int64,
......
......@@ -11,49 +11,49 @@
namespace tvm {
namespace arith {
TVM_REGISTER_API(_arith_intset_single_point)
TVM_REGISTER_API("arith.intset_single_point")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = IntSet::single_point(args[0]);
});
TVM_REGISTER_API(_arith_intset_interval)
TVM_REGISTER_API("arith.intset_interval")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = IntSet::interval(args[0], args[1]);
});
TVM_REGISTER_API(_arith_EvalModular)
TVM_REGISTER_API("arith.EvalModular")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = EvalModular(args[0], Map<Var, IntSet>());
});
TVM_REGISTER_API(_arith_DetectLinearEquation)
TVM_REGISTER_API("arith.DetectLinearEquation")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = DetectLinearEquation(args[0], args[1]);
});
TVM_REGISTER_API(_arith_DeduceBound)
TVM_REGISTER_API("arith.DeduceBound")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = DeduceBound(args[0], args[1],
args[2].operator Map<Var, IntSet>(),
args[3].operator Map<Var, IntSet>());
});
TVM_REGISTER_API(_IntervalSetGetMin)
TVM_REGISTER_API("_IntervalSetGetMin")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = args[0].operator IntSet().min();
});
TVM_REGISTER_API(_IntervalSetGetMax)
TVM_REGISTER_API("_IntervalSetGetMax")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = args[0].operator IntSet().max();
});
TVM_REGISTER_API(_IntSetIsNothing)
TVM_REGISTER_API("_IntSetIsNothing")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = args[0].operator IntSet().is_nothing();
});
TVM_REGISTER_API(_IntSetIsEverything)
TVM_REGISTER_API("_IntSetIsEverything")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = args[0].operator IntSet().is_everything();
});
......
/*!
/*!
* Copyright (c) 2017 by Contributors
* Implementation of basic API functions
* \file api_base.cc
......@@ -9,7 +9,7 @@
namespace tvm {
TVM_REGISTER_API(_format_str)
TVM_REGISTER_API("_format_str")
.set_body([](TVMArgs args, TVMRetValue *ret) {
CHECK(args[0].type_code() == kNodeHandle);
std::ostringstream os;
......@@ -17,19 +17,19 @@ TVM_REGISTER_API(_format_str)
*ret = os.str();
});
TVM_REGISTER_API(_raw_ptr)
TVM_REGISTER_API("_raw_ptr")
.set_body([](TVMArgs args, TVMRetValue *ret) {
CHECK(args[0].type_code() == kNodeHandle);
*ret = reinterpret_cast<int64_t>(
args[0].node_sptr().get());
});
TVM_REGISTER_API(_save_json)
TVM_REGISTER_API("_save_json")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = SaveJSON(args[0]);
});
TVM_REGISTER_API(_load_json)
TVM_REGISTER_API("_load_json")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = NodeRef(LoadJSON_(args[0]));
});
......
......@@ -12,7 +12,7 @@
namespace tvm {
namespace codegen {
TVM_REGISTER_API(_codegen__Build)
TVM_REGISTER_API("codegen._Build")
.set_body([](TVMArgs args, TVMRetValue *ret) {
if (args[0].IsNodeType<LoweredFunc>()) {
*ret = Build({args[0]}, args[1]);
......@@ -21,7 +21,7 @@ TVM_REGISTER_API(_codegen__Build)
}
});
TVM_REGISTER_API(_codegen__Enabled)
TVM_REGISTER_API("codegen._Enabled")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = TargetEnabled(args[0]);
});
......
......@@ -11,12 +11,12 @@
namespace tvm {
namespace ir {
TVM_REGISTER_API(_Var)
TVM_REGISTER_API("_Var")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = Variable::make(args[1], args[0]);
});
TVM_REGISTER_API(_make_For)
TVM_REGISTER_API("make.For")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = For::make(args[0],
args[1],
......@@ -26,7 +26,7 @@ TVM_REGISTER_API(_make_For)
args[5]);
});
TVM_REGISTER_API(_make_Realize)
TVM_REGISTER_API("make.Realize")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = Realize::make(args[0],
args[1],
......@@ -37,7 +37,7 @@ TVM_REGISTER_API(_make_Realize)
});
TVM_REGISTER_API(_make_Call)
TVM_REGISTER_API("make.Call")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = Call::make(args[0],
args[1],
......@@ -47,7 +47,7 @@ TVM_REGISTER_API(_make_Call)
args[5]);
});
TVM_REGISTER_API(_make_Allocate)
TVM_REGISTER_API("make.Allocate")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = Allocate::make(args[0],
args[1],
......@@ -58,31 +58,31 @@ TVM_REGISTER_API(_make_Allocate)
// make from two arguments
#define REGISTER_MAKE1(Node) \
TVM_REGISTER_API(_make_## Node) \
TVM_REGISTER_API("make."#Node) \
.set_body([](TVMArgs args, TVMRetValue *ret) { \
*ret = Node::make(args[0]); \
}) \
#define REGISTER_MAKE2(Node) \
TVM_REGISTER_API(_make_## Node) \
.set_body([](TVMArgs args, TVMRetValue *ret) { \
TVM_REGISTER_API("make."#Node) \
.set_body([](TVMArgs args, TVMRetValue *ret) { \
*ret = Node::make(args[0], args[1]); \
}) \
#define REGISTER_MAKE3(Node) \
TVM_REGISTER_API(_make_## Node) \
TVM_REGISTER_API("make."#Node) \
.set_body([](TVMArgs args, TVMRetValue *ret) { \
*ret = Node::make(args[0], args[1], args[2]); \
}) \
#define REGISTER_MAKE4(Node) \
TVM_REGISTER_API(_make_## Node) \
TVM_REGISTER_API("make."#Node) \
.set_body([](TVMArgs args, TVMRetValue *ret) { \
*ret = Node::make(args[0], args[1], args[2], args[3]); \
}) \
*ret = Node::make(args[0], args[1], args[2], args[3]); \
}) \
#define REGISTER_MAKE_BINARY_OP(Node) \
TVM_REGISTER_API(_make_## Node) \
TVM_REGISTER_API("make."#Node) \
.set_body([](TVMArgs args, TVMRetValue *ret) { \
Expr a = args[0], b = args[1]; \
match_types(a, b); \
......
......@@ -13,7 +13,7 @@
namespace tvm {
TVM_REGISTER_API(_const)
TVM_REGISTER_API("_const")
.set_body([](TVMArgs args, TVMRetValue* ret) {
if (args[0].type_code() == kInt) {
*ret = make_const(args[1], args[0].operator int64_t());
......@@ -25,13 +25,13 @@ TVM_REGISTER_API(_const)
});
TVM_REGISTER_API(_str)
TVM_REGISTER_API("_str")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = ir::StringImm::make(args[0]);
});
TVM_REGISTER_API(_Array)
TVM_REGISTER_API("_Array")
.set_body([](TVMArgs args, TVMRetValue* ret) {
std::vector<std::shared_ptr<Node> > data;
for (int i = 0; i < args.size(); ++i) {
......@@ -42,7 +42,7 @@ TVM_REGISTER_API(_Array)
*ret = node;
});
TVM_REGISTER_API(_ArrayGetItem)
TVM_REGISTER_API("_ArrayGetItem")
.set_body([](TVMArgs args, TVMRetValue* ret) {
int64_t i = args[1];
auto& sptr = args[0].node_sptr();
......@@ -53,7 +53,7 @@ TVM_REGISTER_API(_ArrayGetItem)
*ret = n->data[static_cast<size_t>(i)];
});
TVM_REGISTER_API(_ArraySize)
TVM_REGISTER_API("_ArraySize")
.set_body([](TVMArgs args, TVMRetValue* ret) {
auto& sptr = args[0].node_sptr();
CHECK(sptr->is_type<ArrayNode>());
......@@ -61,7 +61,7 @@ TVM_REGISTER_API(_ArraySize)
static_cast<const ArrayNode*>(sptr.get())->data.size());
});
TVM_REGISTER_API(_Map)
TVM_REGISTER_API("_Map")
.set_body([](TVMArgs args, TVMRetValue* ret) {
CHECK_EQ(args.size() % 2, 0);
MapNode::ContainerType data;
......@@ -78,7 +78,7 @@ TVM_REGISTER_API(_Map)
*ret = node;
});
TVM_REGISTER_API(_MapSize)
TVM_REGISTER_API("_MapSize")
.set_body([](TVMArgs args, TVMRetValue* ret) {
auto& sptr = args[0].node_sptr();
CHECK(sptr->is_type<MapNode>());
......@@ -86,7 +86,7 @@ TVM_REGISTER_API(_MapSize)
*ret = static_cast<int64_t>(n->data.size());
});
TVM_REGISTER_API(_MapGetItem)
TVM_REGISTER_API("_MapGetItem")
.set_body([](TVMArgs args, TVMRetValue* ret) {
CHECK(args[0].type_code() == kNodeHandle);
CHECK(args[1].type_code() == kNodeHandle);
......@@ -99,7 +99,7 @@ TVM_REGISTER_API(_MapGetItem)
*ret = (*it).second;
});
TVM_REGISTER_API(_MapCount)
TVM_REGISTER_API("_MapCount")
.set_body([](TVMArgs args, TVMRetValue* ret) {
CHECK(args[0].type_code() == kNodeHandle);
CHECK(args[1].type_code() == kNodeHandle);
......@@ -110,7 +110,7 @@ TVM_REGISTER_API(_MapCount)
n->data.count(args[1].node_sptr()));
});
TVM_REGISTER_API(_MapItems)
TVM_REGISTER_API("_MapItems")
.set_body([](TVMArgs args, TVMRetValue* ret) {
auto& sptr = args[0].node_sptr();
CHECK(sptr->is_type<MapNode>());
......@@ -123,7 +123,7 @@ TVM_REGISTER_API(_MapItems)
*ret = rkvs;
});
TVM_REGISTER_API(Range)
TVM_REGISTER_API("Range")
.set_body([](TVMArgs args, TVMRetValue* ret) {
if (args.size() == 1) {
*ret = Range(0, args[0]);
......@@ -132,7 +132,7 @@ TVM_REGISTER_API(Range)
}
});
TVM_REGISTER_API(_Buffer)
TVM_REGISTER_API("_Buffer")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = BufferNode::make(args[0],
args[1],
......@@ -143,7 +143,7 @@ TVM_REGISTER_API(_Buffer)
args[6]);
});
TVM_REGISTER_API(_Tensor)
TVM_REGISTER_API("_Tensor")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = TensorNode::make(args[0],
args[1],
......@@ -151,32 +151,32 @@ TVM_REGISTER_API(_Tensor)
args[3]);
});
TVM_REGISTER_API(_TensorEqual)
TVM_REGISTER_API("_TensorEqual")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = args[0].operator Tensor() == args[1].operator Tensor();
});
TVM_REGISTER_API(_TensorHash)
TVM_REGISTER_API("_TensorHash")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = static_cast<int64_t>(
std::hash<Tensor>()(args[0].operator Tensor()));
});
TVM_REGISTER_API(_Placeholder)
TVM_REGISTER_API("_Placeholder")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = placeholder(args[0],
args[1],
args[2]);
});
TVM_REGISTER_API(_ComputeOp)
TVM_REGISTER_API("_ComputeOp")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = ComputeOpNode::make(args[0],
args[1],
args[2]);
});
TVM_REGISTER_API(_ScanOp)
TVM_REGISTER_API("_ScanOp")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = ScanOpNode::make(args[0],
args[1],
......@@ -186,7 +186,7 @@ TVM_REGISTER_API(_ScanOp)
args[5]);
});
TVM_REGISTER_API(_ExternOp)
TVM_REGISTER_API("_ExternOp")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = ExternOpNode::make(args[0],
args[1],
......@@ -195,18 +195,18 @@ TVM_REGISTER_API(_ExternOp)
args[4]);
});
TVM_REGISTER_API(_OpGetOutput)
TVM_REGISTER_API("_OpGetOutput")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = args[0].operator Operation().output(
static_cast<size_t>(args[1].operator int64_t()));
});
TVM_REGISTER_API(_OpNumOutputs)
TVM_REGISTER_API("_OpNumOutputs")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = args[0].operator Operation()->num_outputs();
});
TVM_REGISTER_API(_IterVar)
TVM_REGISTER_API("_IterVar")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = IterVarNode::make(
args[0], args[1],
......@@ -214,24 +214,24 @@ TVM_REGISTER_API(_IterVar)
args[3]);
});
TVM_REGISTER_API(_Schedule)
TVM_REGISTER_API("_Schedule")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = Schedule(args[0].operator Array<Operation>());
});
TVM_REGISTER_API(_StageSetScope)
TVM_REGISTER_API("_StageSetScope")
.set_body([](TVMArgs args, TVMRetValue* ret) {
args[0].operator Stage()
.set_scope(args[1]);
});
TVM_REGISTER_API(_StageBind)
TVM_REGISTER_API("_StageBind")
.set_body([](TVMArgs args, TVMRetValue* ret) {
args[0].operator Stage()
.bind(args[1], args[2]);
});
TVM_REGISTER_API(_StageSplitByFactor)
TVM_REGISTER_API("_StageSplitByFactor")
.set_body([](TVMArgs args, TVMRetValue* ret) {
IterVar outer, inner;
args[0].operator Stage()
......@@ -239,7 +239,7 @@ TVM_REGISTER_API(_StageSplitByFactor)
*ret = Array<IterVar>({outer, inner});
});
TVM_REGISTER_API(_StageSplitByNParts)
TVM_REGISTER_API("_StageSplitByNParts")
.set_body([](TVMArgs args, TVMRetValue* ret) {
IterVar outer, inner;
args[0].operator Stage()
......@@ -247,7 +247,7 @@ TVM_REGISTER_API(_StageSplitByNParts)
*ret = Array<IterVar>({outer, inner});
});
TVM_REGISTER_API(_StageFuse)
TVM_REGISTER_API("_StageFuse")
.set_body([](TVMArgs args, TVMRetValue* ret) {
IterVar fused;
args[0].operator Stage()
......@@ -255,31 +255,31 @@ TVM_REGISTER_API(_StageFuse)
*ret = fused;
});
TVM_REGISTER_API(_StageComputeAt)
TVM_REGISTER_API("_StageComputeAt")
.set_body([](TVMArgs args, TVMRetValue* ret) {
args[0].operator Stage()
.compute_at(args[1], args[2]);
});
TVM_REGISTER_API(_StageComputeInline)
TVM_REGISTER_API("_StageComputeInline")
.set_body([](TVMArgs args, TVMRetValue* ret) {
args[0].operator Stage()
.compute_inline();
});
TVM_REGISTER_API(_StageComputeRoot)
TVM_REGISTER_API("_StageComputeRoot")
.set_body([](TVMArgs args, TVMRetValue* ret) {
args[0].operator Stage()
.compute_root();
});
TVM_REGISTER_API(_StageReorder)
TVM_REGISTER_API("_StageReorder")
.set_body([](TVMArgs args, TVMRetValue* ret) {
args[0].operator Stage()
.reorder(args[1]);
});
TVM_REGISTER_API(_StageTile)
TVM_REGISTER_API("_StageTile")
.set_body([](TVMArgs args, TVMRetValue* ret) {
IterVar x_outer, y_outer, x_inner, y_inner;
args[0].operator Stage()
......@@ -290,55 +290,55 @@ TVM_REGISTER_API(_StageTile)
*ret = Array<IterVar>({x_outer, y_outer, x_inner, y_inner});
});
TVM_REGISTER_API(_StageEnvThreads)
TVM_REGISTER_API("_StageEnvThreads")
.set_body([](TVMArgs args, TVMRetValue* ret) {
args[0].operator Stage()
.env_threads(args[1]);
});
TVM_REGISTER_API(_StageUnroll)
TVM_REGISTER_API("_StageUnroll")
.set_body([](TVMArgs args, TVMRetValue* ret) {
args[0].operator Stage()
.unroll(args[1]);
});
TVM_REGISTER_API(_StageVectorize)
TVM_REGISTER_API("_StageVectorize")
.set_body([](TVMArgs args, TVMRetValue* ret) {
args[0].operator Stage()
.vectorize(args[1]);
});
TVM_REGISTER_API(_StageParallel)
TVM_REGISTER_API("_StageParallel")
.set_body([](TVMArgs args, TVMRetValue* ret) {
args[0].operator Stage()
.parallel(args[1]);
});
TVM_REGISTER_API(_ScheduleNormalize)
TVM_REGISTER_API("_ScheduleNormalize")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = args[0].operator Schedule()
.normalize();
});
TVM_REGISTER_API(_ScheduleCreateGroup)
TVM_REGISTER_API("_ScheduleCreateGroup")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = args[0].operator Schedule()
.create_group(args[1], args[2], args[3]);
});
TVM_REGISTER_API(_ScheduleCacheRead)
TVM_REGISTER_API("_ScheduleCacheRead")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = args[0].operator Schedule()
.cache_read(args[1], args[2], args[3]);
});
TVM_REGISTER_API(_ScheduleCacheWrite)
TVM_REGISTER_API("_ScheduleCacheWrite")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = args[0].operator Schedule()
.cache_write(args[1], args[2]);
});
TVM_REGISTER_API(_ScheduleRFactor)
TVM_REGISTER_API("_ScheduleRFactor")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = args[0].operator Schedule()
.rfactor(args[1], args[2]);
......
......@@ -12,7 +12,7 @@
namespace tvm {
namespace ir {
TVM_REGISTER_API(_pass_Simplify)
TVM_REGISTER_API("ir_pass.Simplify")
.set_body([](TVMArgs args, TVMRetValue *ret) {
if (args[0].IsNodeType<Stmt>()) {
*ret = Simplify(args[0].operator Stmt());
......@@ -21,7 +21,7 @@ TVM_REGISTER_API(_pass_Simplify)
}
});
TVM_REGISTER_API(_pass_Equal)
TVM_REGISTER_API("ir_pass.Equal")
.set_body([](TVMArgs args, TVMRetValue *ret) {
if (args[0].IsNodeType<Stmt>()) {
*ret = Equal(args[0].operator Stmt(), args[1].operator Stmt());
......@@ -30,7 +30,7 @@ TVM_REGISTER_API(_pass_Equal)
}
});
TVM_REGISTER_API(_pass_PostOrderVisit)
TVM_REGISTER_API("ir_pass.PostOrderVisit")
.set_body([](TVMArgs args, TVMRetValue *ret) {
PackedFunc f = args[1];
ir::PostOrderVisit(args[0], [f](const NodeRef& n) {
......@@ -40,19 +40,19 @@ TVM_REGISTER_API(_pass_PostOrderVisit)
// make from two arguments
#define REGISTER_PASS1(PassName) \
TVM_REGISTER_API(_pass_## PassName) \
TVM_REGISTER_API("ir_pass."#PassName) \
.set_body([](TVMArgs args, TVMRetValue *ret) { \
*ret = PassName(args[0]); \
}) \
#define REGISTER_PASS2(PassName) \
TVM_REGISTER_API(_pass_## PassName) \
TVM_REGISTER_API("ir_pass."#PassName) \
.set_body([](TVMArgs args, TVMRetValue *ret) { \
*ret = PassName(args[0], args[1]); \
}) \
#define REGISTER_PASS4(PassName) \
TVM_REGISTER_API(_pass_## PassName) \
TVM_REGISTER_API("ir_pass."#PassName) \
.set_body([](TVMArgs args, TVMRetValue *ret) { \
*ret = PassName(args[0], args[1], args[2], args[3]); \
}) \
......
......@@ -13,19 +13,19 @@
namespace tvm {
namespace schedule {
TVM_REGISTER_API(_schedule_AutoInlineElemWise)
TVM_REGISTER_API("schedule.AutoInlineElemWise")
.set_body([](TVMArgs args, TVMRetValue* ret) {
AutoInlineElemWise(args[0]);
});
#define REGISTER_SCHEDULE_PASS1(PassName) \
TVM_REGISTER_API(_schedule_## PassName) \
TVM_REGISTER_API("schedule."#PassName) \
.set_body([](TVMArgs args, TVMRetValue *ret) { \
*ret = PassName(args[0]); \
}) \
#define REGISTER_SCHEDULE_PASS2(PassName) \
TVM_REGISTER_API(_schedule_## PassName) \
TVM_REGISTER_API("schedule."#PassName) \
.set_body([](TVMArgs args, TVMRetValue *ret) { \
*ret = PassName(args[0], args[1]); \
}) \
......
......@@ -103,6 +103,23 @@ int TVMNodeFree(NodeHandle handle) {
API_END();
}
int TVMCbArgToReturn(TVMValue* value, int code) {
API_BEGIN();
tvm::runtime::TVMRetValue rv;
rv = tvm::runtime::TVMArgValue(*value, code);
int tcode;
rv.MoveToCHost(value, &tcode);
CHECK_EQ(tcode, code);
API_END();
}
int TVMNodeDupe(NodeHandle handle, NodeHandle* out_handle) {
API_BEGIN();
*out_handle = new TVMAPINode(*static_cast<TVMAPINode*>(handle));
API_END();
}
int TVMNodeGetAttr(NodeHandle handle,
const char* key,
TVMValue* ret_val,
......
......@@ -86,7 +86,7 @@ runtime::Module BuildCUDA(Array<LoweredFunc> funcs) {
return CUDAModuleCreate(ptx, fmt, fmap, code);
}
TVM_REGISTER_API(_codegen_build_cuda)
TVM_REGISTER_API("codegen.build_cuda")
.set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = BuildCUDA(args[0]);
});
......
......@@ -41,7 +41,7 @@ runtime::Module BuildOpenCL(Array<LoweredFunc> funcs) {
return OpenCLModuleCreate(code, "cl", fmap);
}
TVM_REGISTER_API(_codegen_build_opencl)
TVM_REGISTER_API("codegen.build_opencl")
.set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = BuildOpenCL(args[0]);
});
......
......@@ -17,7 +17,7 @@ runtime::Module Build(const Array<LoweredFunc>& funcs,
if (pos != std::string::npos) {
mode = mode.substr(0, pos);
}
std::string build_f_name = "_codegen_build_" + mode;
std::string build_f_name = "codegen.build_" + mode;
const PackedFunc* bf = runtime::Registry::Get(build_f_name);
CHECK(bf != nullptr)
......@@ -27,7 +27,7 @@ runtime::Module Build(const Array<LoweredFunc>& funcs,
}
bool TargetEnabled(const std::string& target) {
std::string build_f_name = "_codegen_build_" + target;
std::string build_f_name = "codegen.build_" + target;
return runtime::Registry::Get(build_f_name) != nullptr;
}
......
......@@ -150,7 +150,7 @@ class LLVMModuleNode : public runtime::ModuleNode {
std::shared_ptr<llvm::LLVMContext> ctx_;
};
TVM_REGISTER_API(_codegen_build_llvm)
TVM_REGISTER_API("codegen.build_llvm")
.set_body([](TVMArgs args, TVMRetValue* rv) {
std::shared_ptr<LLVMModuleNode> n = std::make_shared<LLVMModuleNode>();
n->Init(args[0], args[1]);
......
......@@ -69,7 +69,7 @@ class StackVMModuleNode : public runtime::ModuleNode {
std::unordered_map<std::string, StackVM> fmap_;
};
TVM_REGISTER_API(_codegen_build_stackvm)
TVM_REGISTER_API("codegen.build_stackvm")
.set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = StackVMModuleNode::Build(args[0]);
});
......
......@@ -86,7 +86,7 @@ class VerilogModuleNode : public runtime::ModuleNode {
std::string fmt_;
};
TVM_REGISTER_API(_codegen_build_verilog)
TVM_REGISTER_API("codegen.build_verilog")
.set_body([](TVMArgs args, TVMRetValue* rv) {
std::shared_ptr<VerilogModuleNode> n =
std::make_shared<VerilogModuleNode>();
......
......@@ -385,7 +385,7 @@ class VPIWriteMemMap : public VPIMemMapBase {
VPIHandle enable_;
};
TVM_REGISTER_GLOBAL(_device_api_vpi)
TVM_REGISTER_GLOBAL("device_api.vpi")
.set_body([](runtime::TVMArgs args, runtime::TVMRetValue* rv) {
runtime::DeviceAPI* ptr = VPIDeviceAPI::Global();
*rv = static_cast<void*>(ptr);
......@@ -403,13 +403,13 @@ void TVMVPIHook(runtime::TVMArgs args, runtime::TVMRetValue* rv) {
*rv = pf;
}
TVM_REGISTER_GLOBAL(_vpi_module_tvm_vpi_mem_interface)
TVM_REGISTER_GLOBAL("_vpi_module_tvm_vpi_mem_interface")
.set_body(TVMVPIHook<VPIMemoryInterface>);
TVM_REGISTER_GLOBAL(_vpi_module_tvm_vpi_read_mmap)
TVM_REGISTER_GLOBAL("_vpi_module_tvm_vpi_read_mmap")
.set_body(TVMVPIHook<VPIReadMemMap>);
TVM_REGISTER_GLOBAL(_vpi_module_tvm_vpi_write_mmap)
TVM_REGISTER_GLOBAL("_vpi_module_tvm_vpi_write_mmap")
.set_body(TVMVPIHook<VPIWriteMemMap>);
} // namespace codegen
......
......@@ -212,47 +212,47 @@ VPIHandle VPIHandle::operator[](const std::string& name) const {
}
// API registration
TVM_REGISTER_API(_vpi_SessMake)
TVM_REGISTER_API("_vpi_SessMake")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = VPISession::make(args[0], args[1]);
});
TVM_REGISTER_API(_vpi_SessGetHandleByName)
TVM_REGISTER_API("_vpi_SessGetHandleByName")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = args[0].operator VPISession().operator[](args[1]);
});
TVM_REGISTER_API(_vpi_SessYield)
TVM_REGISTER_API("_vpi_SessYield")
.set_body([](TVMArgs args, TVMRetValue *ret) {
args[0].operator VPISession().yield();
});
TVM_REGISTER_API(_vpi_SessShutdown)
TVM_REGISTER_API("_vpi_SessShutdown")
.set_body([](TVMArgs args, TVMRetValue *ret) {
args[0].operator VPISession().shutdown();
});
TVM_REGISTER_API(_vpi_HandlePutInt)
TVM_REGISTER_API("_vpi_HandlePutInt")
.set_body([](TVMArgs args, TVMRetValue *ret) {
args[0].operator VPIHandle().put_int(args[1]);
});
TVM_REGISTER_API(_vpi_HandleGetInt)
TVM_REGISTER_API("_vpi_HandleGetInt")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = args[0].operator VPIHandle().get_int();
});
TVM_REGISTER_API(_vpi_HandleGetName)
TVM_REGISTER_API("_vpi_HandleGetName")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = args[0].operator VPIHandle().name();
});
TVM_REGISTER_API(_vpi_HandleGetSize)
TVM_REGISTER_API("_vpi_HandleGetSize")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = args[0].operator VPIHandle().size();
});
TVM_REGISTER_API(_vpi_HandleGetHandleByName)
TVM_REGISTER_API("_vpi_HandleGetHandleByName")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = args[0].operator VPIHandle().operator[](args[1]);
});
......
......@@ -46,7 +46,7 @@ class DeviceAPIManager {
if (api_[type] != nullptr) return api_[type];
std::lock_guard<std::mutex> lock(mutex_);
if (api_[type] != nullptr) return api_[type];
std::string factory = "_device_api_" + DeviceName(type);
std::string factory = "device_api." + DeviceName(type);
auto* f = Registry::Get(factory);
CHECK(f != nullptr)
<< "Device API " << DeviceName(type) << " is not enabled.";
......
......@@ -46,7 +46,7 @@ class CPUDeviceAPI : public DeviceAPI {
}
};
TVM_REGISTER_GLOBAL(_device_api_cpu)
TVM_REGISTER_GLOBAL("device_api.cpu")
.set_body([](TVMArgs args, TVMRetValue* rv) {
static CPUDeviceAPI inst;
DeviceAPI* ptr = &inst;
......
......@@ -77,7 +77,7 @@ class CUDADeviceAPI : public DeviceAPI {
}
};
TVM_REGISTER_GLOBAL(_device_api_gpu)
TVM_REGISTER_GLOBAL("device_api.gpu")
.set_body([](TVMArgs args, TVMRetValue* rv) {
static CUDADeviceAPI inst;
DeviceAPI* ptr = &inst;
......
......@@ -281,12 +281,12 @@ Module CUDAModuleLoad(const std::string& file_name,
return CUDAModuleCreate(data, fmt, fmap, std::string());
}
TVM_REGISTER_GLOBAL(_module_loadfile_cubin)
TVM_REGISTER_GLOBAL("module.loadfile_cubin")
.set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = CUDAModuleLoad(args[0], args[1]);
});
TVM_REGISTER_GLOBAL(_module_loadfile_ptx)
TVM_REGISTER_GLOBAL("module.loadfile_ptx")
.set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = CUDAModuleLoad(args[0], args[1]);
});
......
......@@ -110,7 +110,7 @@ class DSOModuleNode : public ModuleNode {
#endif
};
TVM_REGISTER_GLOBAL(_module_loadfile_so)
TVM_REGISTER_GLOBAL("module.loadfile_so")
.set_body([](TVMArgs args, TVMRetValue* rv) {
std::shared_ptr<DSOModuleNode> n = std::make_shared<DSOModuleNode>();
n->Init(args[0]);
......
......@@ -53,7 +53,7 @@ Module Module::LoadFromFile(const std::string& file_name,
if (fmt == "dll" || fmt == "dylib" || fmt == "dso") {
fmt = "so";
}
std::string load_f_name = "_module_loadfile_" + fmt;
std::string load_f_name = "module.loadfile_" + fmt;
const PackedFunc* f = Registry::Get(load_f_name);
CHECK(f != nullptr)
<< "Loader of " << format << "("
......@@ -88,48 +88,48 @@ bool RuntimeEnabled(const std::string& target) {
if (target == "cpu") {
return true;
} else if (target == "cuda" || target == "gpu") {
load_f_name = "_module_loadfile_ptx";
load_f_name = "module.loadfile_ptx";
} else if (target == "cl" || target == "opencl") {
load_f_name = "_module_loadfile_cl";
load_f_name = "module.loadfile_cl";
} else {
LOG(FATAL) << "Unknown optional runtime " << target;
}
return runtime::Registry::Get(load_f_name) != nullptr;
}
TVM_REGISTER_GLOBAL(_module__Enabled)
TVM_REGISTER_GLOBAL("module._Enabled")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = RuntimeEnabled(args[0]);
});
TVM_REGISTER_GLOBAL(_module__GetSource)
TVM_REGISTER_GLOBAL("module._GetSource")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = args[0].operator Module()->GetSource(args[1]);
});
TVM_REGISTER_GLOBAL(_module__ImportsSize)
TVM_REGISTER_GLOBAL("module._ImportsSize")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = static_cast<int64_t>(
args[0].operator Module()->imports().size());
});
TVM_REGISTER_GLOBAL(_module__GetImport)
TVM_REGISTER_GLOBAL("module._GetImport")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = args[0].operator Module()->
imports().at(args[1].operator int());
});
TVM_REGISTER_GLOBAL(_module__GetTypeKey)
TVM_REGISTER_GLOBAL("module._GetTypeKey")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = std::string(args[0].operator Module()->type_key());
});
TVM_REGISTER_GLOBAL(_module__LoadFromFile)
TVM_REGISTER_GLOBAL("module._LoadFromFile")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = Module::LoadFromFile(args[0], args[1]);
});
TVM_REGISTER_GLOBAL(_module__SaveToFile)
TVM_REGISTER_GLOBAL("module._SaveToFile")
.set_body([](TVMArgs args, TVMRetValue *ret) {
args[0].operator Module()->
SaveToFile(args[1], args[2]);
......
......@@ -189,10 +189,10 @@ bool InitOpenCL(TVMArgs args, TVMRetValue* rv) {
return true;
}
TVM_REGISTER_GLOBAL(_module_init_opencl)
TVM_REGISTER_GLOBAL("module.init_opencl")
.set_body(InitOpenCL);
TVM_REGISTER_GLOBAL(_device_api_opencl)
TVM_REGISTER_GLOBAL("device_api.opencl")
.set_body([](TVMArgs args, TVMRetValue* rv) {
DeviceAPI* ptr = OpenCLWorkspace::Global();
*rv = static_cast<void*>(ptr);
......
......@@ -314,12 +314,12 @@ Module OpenCLModuleLoad(const std::string& file_name,
return OpenCLModuleCreate(data, fmt, fmap);
}
TVM_REGISTER_GLOBAL(_module_loadfile_cl)
TVM_REGISTER_GLOBAL("module.loadfile_cl")
.set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = OpenCLModuleLoad(args[0], args[1]);
});
TVM_REGISTER_GLOBAL(_module_loadfile_clbin)
TVM_REGISTER_GLOBAL("module.loadfile_clbin")
.set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = OpenCLModuleLoad(args[0], args[1]);
});
......
......@@ -14,6 +14,25 @@ def test_get_global():
y = f(*targs)
assert y == 10
def test_get_callback_with_node():
x = tvm.convert(10)
def test(y):
assert y.handle != x.handle
return y
f2 = tvm.convert(test)
# register into global function table
@tvm.register_func
def my_callback_with_node(y, f):
assert y == x
return f(y)
# get it out from global function table
f = tvm.get_global_func("my_callback_with_node")
assert isinstance(f, tvm.Function)
y = f(x, f2)
assert(y.value == 10)
def test_return_func():
def addy(y):
......@@ -45,6 +64,7 @@ def test_byte_array():
f(a)
if __name__ == "__main__":
test_get_callback_with_node()
test_convert()
test_get_global()
test_return_func()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment