Commit 01a7ce0c by Tianqi Chen Committed by GitHub

[RUNTIME] Add Function, Unify TVMTypeCode and TVMArgTypeID (#24)

parent 4f1473f3
......@@ -6,11 +6,11 @@
#ifndef TVM_C_API_H_
#define TVM_C_API_H_
#include "./c_runtime_api.h"
#include "./runtime/c_runtime_api.h"
TVM_EXTERN_C {
/*! \brief handle to functions */
typedef void* APIFunctionHandle;
typedef void* APIFuncHandle;
/*! \brief handle to node */
typedef void* NodeHandle;
......@@ -18,16 +18,18 @@ typedef void* NodeHandle;
* \brief List all the node function name
* \param out_size The number of functions
* \param out_array The array of function names.
* \return 0 when success, -1 when failure happens
*/
TVM_DLL int TVMListAPIFunctionNames(int *out_size,
const char*** out_array);
TVM_DLL int TVMListAPIFuncNames(int *out_size,
const char*** out_array);
/*!
* \brief get function handle by name
* \param name The name of function
* \param handle The returning function handle
* \return 0 when success, -1 when failure happens
*/
TVM_DLL int TVMGetAPIFunctionHandle(const char* name,
APIFunctionHandle *handle);
TVM_DLL int TVMGetAPIFuncHandle(const char* name,
APIFuncHandle *handle);
/*!
* \brief Get the detailed information about function.
......@@ -42,24 +44,26 @@ TVM_DLL int TVMGetAPIFunctionHandle(const char* name,
* \param return_type Return type of the function, if any.
* \return 0 when success, -1 when failure happens
*/
TVM_DLL int TVMGetAPIFunctionInfo(APIFunctionHandle handle,
const char **real_name,
const char **description,
int *num_doc_args,
const char ***arg_names,
const char ***arg_type_infos,
const char ***arg_descriptions,
const char **return_type);
TVM_DLL int TVMGetAPIFuncInfo(APIFuncHandle handle,
const char **real_name,
const char **description,
int *num_doc_args,
const char ***arg_names,
const char ***arg_type_infos,
const char ***arg_descriptions,
const char **return_type);
/*!
* \brief Push an argument to the function calling stack.
* If push fails, the stack will be reset to empty
*
* \param arg number of attributes
* \param type_id The typeid of attributes.
* \param arg The argument
* \param type_code The type_code of argument as in TVMTypeCode
* \return 0 when success, -1 when failure happens
* \note API calls always exchanges with type bits=64, lanes=1
*/
TVM_DLL int TVMAPIPushStack(TVMArg arg,
int type_id);
TVM_DLL int TVMAPIPushStack(TVMValue arg,
int type_code);
/*!
* \brief call a function by using arguments in the stack.
......@@ -67,15 +71,18 @@ TVM_DLL int TVMAPIPushStack(TVMArg arg,
*
* \param handle The function handle
* \param ret_val The return value.
* \param ret_typeid the type id of return value.
* \param ret_type_code the type code of return value.
* \return 0 when success, -1 when failure happens
* \note API calls always exchanges with type bits=64, lanes=1
*/
TVM_DLL int TVMAPIFunctionCall(APIFunctionHandle handle,
TVMArg* ret_val,
int* ret_typeid);
TVM_DLL int TVMAPIFuncCall(APIFuncHandle handle,
TVMValue* ret_val,
int* ret_type_code);
/*!
* \brief free the node handle
* \param handle The node handle to be freed.
* \return 0 when success, -1 when failure happens
*/
TVM_DLL int TVMNodeFree(NodeHandle handle);
......@@ -84,13 +91,15 @@ TVM_DLL int TVMNodeFree(NodeHandle handle);
* \param handle The node handle
* \param key The attribute name
* \param out_value The attribute value
* \param out_typeid The typeid of the attribute.
* \param out_type_code The type code of the attribute.
* \param out_success Whether get is successful.
* \return 0 when success, -1 when failure happens
* \note API calls always exchanges with type bits=64, lanes=1
*/
TVM_DLL int TVMNodeGetAttr(NodeHandle handle,
const char* key,
TVMArg* out_value,
int* out_typeid,
TVMValue* out_value,
int* out_type_code,
int* out_success);
/*!
......@@ -98,6 +107,7 @@ TVM_DLL int TVMNodeGetAttr(NodeHandle handle,
* \param handle The node handle
* \param out_size The number of functions
* \param out_array The array of function names.
* \return 0 when success, -1 when failure happens
*/
TVM_DLL int TVMNodeListAttrNames(NodeHandle handle,
int *out_size,
......
......@@ -10,6 +10,8 @@
#include "./base.h"
#include "./expr.h"
#include "./module.h"
#include "./runtime/runtime.h"
namespace tvm {
/*! \brief namespace for lowlevel IR pass and codegen */
......@@ -62,6 +64,9 @@ Array<Var> UndefinedVars(const LoweredFunc& f);
*/
Array<LoweredFunc> SplitHostDevice(LoweredFunc func);
runtime::PackedFunc BuildStackVM(LoweredFunc func);
} // namespace codegen
} // namespace tvm
......
......@@ -78,6 +78,14 @@ constexpr const char* tvm_array_get_field = "tvm_array_get_field";
* }
*/
constexpr const char* tvm_handle_is_null = "tvm_handle_is_null";
/*!
* \brief See pesudo code
*
* bool tvm_print(VType value) {
* LOG(INFO) << value;
* }
*/
constexpr const char* tvm_print = "tvm_print";
/*! \brief The field id of each field in array */
enum TVMArrayFieldKind {
......
......@@ -9,8 +9,8 @@
* So this is a minimum runtime code gluing, and some limited
* memory management code to enable quick testing.
*/
#ifndef TVM_C_RUNTIME_API_H_
#define TVM_C_RUNTIME_API_H_
#ifndef TVM_RUNTIME_C_RUNTIME_API_H_
#define TVM_RUNTIME_C_RUNTIME_API_H_
#ifdef __cplusplus
#define TVM_EXTERN_C extern "C"
......@@ -38,27 +38,51 @@ TVM_EXTERN_C {
typedef uint32_t tvm_index_t;
/*!
* \brief union type for arguments and return values
* in both runtime API and TVM API calls
* \brief Union type of values
* being passed through API and function calls.
*/
typedef union {
long v_long; // NOLINT(*)
double v_double;
const char* v_str;
int64_t v_int64;
double v_float64;
void* v_handle;
} TVMArg;
const char* v_str;
} TVMValue;
/*!
* \brief The type index in TVM.
* \brief The type code in TVMType
* \note TVMType is used in two places.
*/
typedef enum {
kNull = 0,
kLong = 1,
kDouble = 2,
kStr = 3,
kNodeHandle = 4,
kArrayHandle = 5
} TVMArgTypeID;
kInt = 0U,
kUInt = 1U,
kFloat = 2U,
kHandle = 3U,
// The next few fields are extension types
// that is used by TVM API calls.
kNull = 4U,
kNodeHandle = 5U,
kStr = 6U,
kFuncHandle = 7U
} TVMTypeCode;
/*!
* \brief The data type used in TVM Runtime.
*
* Examples
* - float: type_code = 2, bits = 32, lanes=1
* - float4(vectorized 4 float): type_code = 2, bits = 32, lanes=4
* - int8: type_code = 0, bits = 8, lanes=1
*
* \note Arguments TVM API function always takes bits=64 and lanes=1
*/
typedef struct {
/*! \brief type code, in TVMTypeCode */
uint8_t type_code;
/*! \brief number of bits of the type */
uint8_t bits;
/*! \brief number of lanes, */
uint16_t lanes;
} TVMType;
/*!
* \brief The device type
......@@ -82,29 +106,6 @@ typedef struct {
int dev_id;
} TVMContext;
/*! \brief The type code in TVMDataType */
typedef enum {
kInt = 0U,
kUInt = 1U,
kFloat = 2U
} TVMTypeCode;
/*!
* \brief the data type
* Examples
* - float: type_code = 2, bits = 32, lanes=1
* - float4(vectorized 4 float): type_code = 2, bits = 32, lanes=4
* - int8: type_code = 0, bits = 8, lanes=1
*/
typedef struct {
/*! \brief type code, in TVMTypeCode */
uint8_t type_code;
/*! \brief number of bits of the type */
uint8_t bits;
/*! \brief number of lanes, */
uint16_t lanes;
} TVMDataType;
/*!
* \brief Data structure representing a n-dimensional array(tensor).
* This is used to pass data specification into TVM.
......@@ -122,7 +123,7 @@ typedef struct {
/*! \brief number of dimensions of the array */
tvm_index_t ndim;
/*! \brief The data type flag */
TVMDataType dtype;
TVMType dtype;
/*! \brief The device context this array sits on */
TVMContext ctx;
} TVMArray;
......@@ -191,7 +192,7 @@ TVM_DLL int TVMContextEnabled(TVMContext ctx,
*/
TVM_DLL int TVMArrayAlloc(const tvm_index_t* shape,
tvm_index_t ndim,
TVMDataType dtype,
TVMType dtype,
TVMContext ctx,
TVMArrayHandle* out);
/*!
......@@ -217,45 +218,27 @@ TVM_DLL int TVMArrayCopyFromTo(TVMArrayHandle from,
TVM_DLL int TVMSynchronize(TVMContext ctx, TVMStreamHandle stream);
/*!
* \brief TVM Function API: Get resource requirement
*
* By default TVM function try not to do internal allocations.
* Instead, TVMFuncRequirement can be called, given the input arguments.
*
* \param func function handle to be launched.
* \param args The arguments
* \param arg_type_ids The type id of the arguments
* \param num_args Number of arguments.
* \param out_workspace_size The workspace size needed to launch this function.
* \param out_workspace_align The alignment requirement of workspace.
*
* \note The data pointer in the arrays is not used by requirement.
* \brief Free the function when it is no longer needed.
* \param func The function handle
* \return whether
*/
TVM_DLL int TVMFuncRequirement(TVMFunctionHandle func,
TVMArg* args,
int* arg_type_ids,
int num_args,
size_t* out_workspace_size,
size_t* out_workspace_align);
TVM_DLL int TVMFuncFree(TVMFunctionHandle func);
/*!
* \brief TVM Function API: Launch generated function.
* \brief Call a function whose parameters are all packed.
*
* \param func function handle to be launched.
* \param func node handle of the function.
* \param args The arguments
* \param arg_type_ids The type id of the arguments
* \param type_codes The type codes of the arguments
* \param num_args Number of arguments.
* \param stream The stream this function to be launched on.
* \param workspace Additional workspace used to launch this function.
*
* \sa TVMFuncRequirement
* \return 0 when success, -1 when failure happens
* \note TVM calls always exchanges with type bits=64, lanes=1
*/
TVM_DLL int TVMFuncLaunch(TVMFunctionHandle func,
TVMArg* args,
int* arg_type_ids,
int num_args,
TVMStreamHandle stream,
TVMArrayHandle workspace);
TVM_DLL int TVMFuncCall(TVMFunctionHandle func,
TVMValue* args,
int* type_codes,
int num_args);
} // TVM_EXTERN_C
#endif // TVM_C_RUNTIME_API_H_
#endif // TVM_RUNTIME_C_RUNTIME_API_H_
/*!
* Copyright (c) 2016 by Contributors
* \file runtime.h
* \brief Runtime related c++ class.
*/
#ifndef TVM_RUNTIME_RUNTIME_H_
#define TVM_RUNTIME_RUNTIME_H_
#include <functional>
#include <tuple>
#include "./c_runtime_api.h"
namespace tvm {
namespace runtime {
/*!
* \brief Packed function is a runtime function
* whose argument type_codes are erased by packed format.
*
* This is an useful unified interface to call generated functions.
*/
class PackedFunc {
public:
/*! \brief The internal std::function */
using FType = std::function<void(const TVMValue* args, const int* type_codes, int num_args)>;
PackedFunc() {}
explicit PackedFunc(FType body) : body_(body) {}
/*!
* \brief invoke the packed function by directly passing in arguments.
* \param args Arguments to be passed.
* \tparam Args arguments to be passed.
* \return The first return value.
*/
template<typename... Args>
inline void operator()(Args&& ...args) const;
/*!
* \brief Call the function in packed format.
* \param args The arguments
* \param type_codes The type_codes of the arguments
* \param num_args Number of arguments.
*/
inline void CallPacked(const TVMValue* args, const int* type_codes, int num_args) const;
/*! \return the internal body function */
inline FType body() const {
return body_;
}
private:
/*! \brief internal container of packed function */
FType body_;
};
// implementations
inline void PackedFunc::CallPacked(
const TVMValue* args, const int* type_codes, int num_args) const {
body_(args, type_codes, num_args);
}
template<bool stop, std::size_t I, typename F, typename ...Args>
struct for_each_dispatcher_ {
static inline void run(const std::tuple<Args...>& args, F f) {
f(I, std::get<I>(args));
for_each_dispatcher_<(I + 1) == sizeof...(Args), (I+1), F, Args...>::run(args, f);
}
};
template<std::size_t I, typename F, typename ...Args>
struct for_each_dispatcher_<true, I, F, Args...> {
static inline void run(const std::tuple<Args...>& args, F f) {}
};
template<typename F, typename ...Args>
inline void for_each(const std::tuple<Args...>& args, F f) {
for_each_dispatcher_<sizeof...(Args) == 0, 0, F, Args...>::run(args, f);
}
namespace arg_setter {
template<typename T>
inline void Set(TVMValue& arg, int& t, T v); // NOLINT(*)
template<>
inline void Set<double>(TVMValue& arg, int& t, double value) { // NOLINT(*)
arg.v_float64 = value;
t = kFloat;
}
template<>
inline void Set<int>(TVMValue& arg, int& t, int value) { // NOLINT(*)
arg.v_int64 = value;
t = kInt;
}
template<>
inline void Set<long>(TVMValue& arg, int& t, long value) { // NOLINT(*)
arg.v_int64 = value;
t = kInt;
}
template<>
inline void Set<TVMArray*>(TVMValue& arg, int& t, TVMArray* value) { // NOLINT(*)
arg.v_handle = value;
t = kHandle;
}
template<>
inline void Set<void*>(TVMValue& arg, int& t, void* value) { // NOLINT(*)
arg.v_handle = value;
t = kHandle;
}
} // namespace arg_setter
struct PackedFuncArgSetter {
TVMValue* args;
int* type_codes;
template<typename T>
inline void operator()(size_t i, T v) const {
arg_setter::Set(args[i], type_codes[i], v);
}
};
template<typename... Args>
inline void PackedFunc::operator()(Args&& ...args) const {
auto targ = std::make_tuple(std::forward<Args>(args)...);
const int kNumArgs = sizeof...(Args);
TVMValue tvm_args[kNumArgs];
int tvm_arg_type_ids[kNumArgs];
for_each(targ, PackedFuncArgSetter{tvm_args, tvm_arg_type_ids});
body_(tvm_args, tvm_arg_type_ids, kNumArgs);
}
} // namespace runtime
} // namespace tvm
#endif // TVM_RUNTIME_RUNTIME_H_
......@@ -16,4 +16,4 @@ from . import ndarray as nd
from .ndarray import cpu, gpu, opencl, init_opencl
from ._base import TVMError
from .function import *
from .api import *
"""namespace of internal API"""
......@@ -11,24 +11,26 @@ from numbers import Number, Integral
from .._base import _LIB
from .._base import c_str, py_str, string_types
from .._base import check_call, ctypes2docstring
from .. import _function_internal
class TVMArg(ctypes.Union):
"""TVMArg in C API"""
_fields_ = [("v_long", ctypes.c_long),
("v_double", ctypes.c_double),
("v_str", ctypes.c_char_p),
("v_handle", ctypes.c_void_p)]
from .. import _api_internal
from . import _runtime_api
from ._types import TVMValue, TypeCode
# type definitions
APIFunctionHandle = ctypes.c_void_p
APIFuncHandle = ctypes.c_void_p
NodeHandle = ctypes.c_void_p
FunctionHandle = ctypes.c_void_p
class APIType(object):
"""TVMType used in API calls"""
INT = ctypes.c_int(TypeCode.INT)
UINT = ctypes.c_int(TypeCode.UINT)
FLOAT = ctypes.c_int(TypeCode.FLOAT)
HANDLE = ctypes.c_int(TypeCode.HANDLE)
NULL = ctypes.c_int(TypeCode.NULL)
NODE_HANDLE = ctypes.c_int(TypeCode.NODE_HANDLE)
STR = ctypes.c_int(TypeCode.STR)
FUNC_HANDLE = ctypes.c_int(TypeCode.FUNC_HANDLE)
kNull = 0
kLong = 1
kDouble = 2
kStr = 3
kNodeHandle = 4
NODE_TYPE = {
}
......@@ -37,22 +39,31 @@ def _return_node(x):
handle = x.v_handle
if not isinstance(handle, NodeHandle):
handle = NodeHandle(handle)
ret_val = TVMArg()
ret_typeid = ctypes.c_int()
ret_val = TVMValue()
ret_type_code = ctypes.c_int()
ret_success = ctypes.c_int()
check_call(_LIB.TVMNodeGetAttr(
handle, c_str("type_key"),
ctypes.byref(ret_val),
ctypes.byref(ret_typeid),
ctypes.byref(ret_type_code),
ctypes.byref(ret_success)))
return NODE_TYPE.get(py_str(ret_val.v_str), NodeBase)(handle)
def _return_func(x):
handle = x.v_handle
if not isinstance(handle, FunctionHandle):
handle = FunctionHandle(handle)
return _runtime_api._function_cls(handle)
RET_SWITCH = {
kNull: lambda x: None,
kLong: lambda x: x.v_long,
kDouble: lambda x: x.v_double,
kStr: lambda x: py_str(x.v_str),
kNodeHandle: _return_node
TypeCode.NULL: lambda x: None,
TypeCode.INT: lambda x: x.v_int64,
TypeCode.FLOAT: lambda x: x.v_float64,
TypeCode.STR: lambda x: py_str(x.v_str),
TypeCode.NODE_HANDLE: _return_node,
TypeCode.FUNC_HANDLE: _return_func
}
class SliceBase(object):
......@@ -74,28 +85,28 @@ class NodeBase(object):
self.handle = handle
def __repr__(self):
return _function_internal._format_str(self)
return _api_internal._format_str(self)
def __del__(self):
check_call(_LIB.TVMNodeFree(self.handle))
def __getattr__(self, name):
ret_val = TVMArg()
ret_typeid = ctypes.c_int()
ret_val = TVMValue()
ret_type_code = ctypes.c_int()
ret_success = ctypes.c_int()
check_call(_LIB.TVMNodeGetAttr(
self.handle, c_str(name),
ctypes.byref(ret_val),
ctypes.byref(ret_typeid),
ctypes.byref(ret_type_code),
ctypes.byref(ret_success)))
value = RET_SWITCH[ret_typeid.value](ret_val)
value = RET_SWITCH[ret_type_code.value](ret_val)
if not ret_success.value:
raise AttributeError(
"'%s' object has no attribute '%s'" % (str(type(self)), name))
return value
def __hash__(self):
return _function_internal._raw_ptr(self)
return _api_internal._raw_ptr(self)
def __eq__(self, other):
if not isinstance(other, NodeBase):
......@@ -121,7 +132,7 @@ class NodeBase(object):
def __getstate__(self):
handle = self.handle
if handle is not None:
return {'handle': _function_internal._save_json(self)}
return {'handle': _api_internal._save_json(self)}
else:
return {'handle': None}
......@@ -131,7 +142,7 @@ class NodeBase(object):
if handle is not None:
json_str = handle
_push_arg(json_str)
other = _function_internal._load_json(json_str)
other = _api_internal._load_json(json_str)
self.handle = other.handle
other.handle = None
else:
......@@ -145,7 +156,7 @@ def const(value, dtype=None):
dtype = 'int32'
else:
dtype = 'float32'
return _function_internal._const(value, dtype)
return _api_internal._const(value, dtype)
def convert(value):
......@@ -154,7 +165,7 @@ def convert(value):
return const(value)
elif isinstance(value, (list, tuple)):
value = [convert(x) for x in value]
return _function_internal._Array(*value)
return _api_internal._Array(*value)
elif isinstance(value, dict):
vlist = []
for it in value.items():
......@@ -162,7 +173,7 @@ def convert(value):
raise ValueError("key of map must already been a container type")
vlist.append(it[0])
vlist.append(convert(it[1]))
return _function_internal._Map(*vlist)
return _api_internal._Map(*vlist)
elif isinstance(value, SliceBase):
return value.tensor(*value.indices)
else:
......@@ -172,21 +183,21 @@ def convert(value):
def _push_arg(arg):
a = TVMArg()
a = TVMValue()
if arg is None:
_LIB.TVMAPIPushStack(a, ctypes.c_int(kNull))
_LIB.TVMAPIPushStack(a, APIType.NULL)
elif isinstance(arg, NodeBase):
a.v_handle = arg.handle
_LIB.TVMAPIPushStack(a, ctypes.c_int(kNodeHandle))
elif isinstance(arg, int):
a.v_long = ctypes.c_long(arg)
_LIB.TVMAPIPushStack(a, ctypes.c_int(kLong))
_LIB.TVMAPIPushStack(a, APIType.NODE_HANDLE)
elif isinstance(arg, Integral):
a.v_int64 = ctypes.c_int64(arg)
_LIB.TVMAPIPushStack(a, APIType.INT)
elif isinstance(arg, Number):
a.v_double = ctypes.c_double(arg)
_LIB.TVMAPIPushStack(a, ctypes.c_int(kDouble))
_LIB.TVMAPIPushStack(a, APIType.FLOAT)
elif isinstance(arg, string_types):
a.v_str = c_str(arg)
_LIB.TVMAPIPushStack(a, ctypes.c_int(kStr))
_LIB.TVMAPIPushStack(a, APIType.STR)
else:
raise TypeError("Don't know how to handle type %s" % type(arg))
......@@ -201,7 +212,7 @@ def _make_function(handle, name):
arg_descs = ctypes.POINTER(ctypes.c_char_p)()
ret_type = ctypes.c_char_p()
check_call(_LIB.TVMGetAPIFunctionInfo(
check_call(_LIB.TVMGetAPIFuncInfo(
handle, ctypes.byref(real_name), ctypes.byref(desc),
ctypes.byref(num_args),
ctypes.byref(arg_names),
......@@ -214,13 +225,7 @@ def _make_function(handle, name):
desc = py_str(desc.value)
doc_str = ('%s\n\n' +
'%s\n' +
'name : string, optional.\n' +
' Name of the resulting symbol.\n\n' +
'Returns\n' +
'-------\n' +
'symbol: Symbol\n' +
' The result symbol.')
'%s\n')
doc_str = doc_str % (desc, param_str)
arg_names = [py_str(arg_names[i]) for i in range(num_args.value)]
......@@ -235,11 +240,11 @@ def _make_function(handle, name):
for arg in cargs:
_push_arg(arg)
ret_val = TVMArg()
ret_typeid = ctypes.c_int()
check_call(_LIB.TVMAPIFunctionCall(
handle, ctypes.byref(ret_val), ctypes.byref(ret_typeid)))
return RET_SWITCH[ret_typeid.value](ret_val)
ret_val = TVMValue()
ret_type_code = ctypes.c_int()
check_call(_LIB.TVMAPIFuncCall(
handle, ctypes.byref(ret_val), ctypes.byref(ret_type_code)))
return RET_SWITCH[ret_type_code.value](ret_val)
func.__name__ = func_name
func.__doc__ = doc_str
......@@ -265,19 +270,19 @@ def register_node(type_key=None):
NODE_TYPE[cls.__name__] = cls
return cls
def _init_function_module(root_namespace):
def _init_api_module(root_namespace):
"""List and add all the functions to current module."""
plist = ctypes.POINTER(ctypes.c_char_p)()
size = ctypes.c_uint()
check_call(_LIB.TVMListAPIFunctionNames(ctypes.byref(size),
ctypes.byref(plist)))
check_call(_LIB.TVMListAPIFuncNames(ctypes.byref(size),
ctypes.byref(plist)))
op_names = []
for i in range(size.value):
op_names.append(py_str(plist[i]))
module_obj = sys.modules["%s.function" % root_namespace]
module_internal = sys.modules["%s._function_internal" % root_namespace]
module_obj = sys.modules["%s.api" % root_namespace]
module_internal = sys.modules["%s._api_internal" % root_namespace]
namespace_match = {
"_make_": sys.modules["%s.make" % root_namespace],
"_pass_": sys.modules["%s.ir_pass" % root_namespace],
......@@ -286,8 +291,8 @@ def _init_function_module(root_namespace):
}
for name in op_names:
hdl = APIFunctionHandle()
check_call(_LIB.TVMGetAPIFunctionHandle(c_str(name), ctypes.byref(hdl)))
hdl = APIFuncHandle()
check_call(_LIB.TVMGetAPIFuncHandle(c_str(name), ctypes.byref(hdl)))
fname = name
target_module = module_internal if name.startswith('_') else module_obj
for k, v in namespace_match.items():
......
......@@ -4,16 +4,16 @@
from __future__ import absolute_import as _abs
import ctypes
from numbers import Number, Integral
import numpy as np
from .._base import _LIB
from .._base import c_array, c_str
from .._base import c_array, c_str, string_types
from .._base import check_call
from ._types import TVMValue, TypeCode, TVMType
tvm_index_t = ctypes.c_uint32
class TVMContext(ctypes.Structure):
"""TVM context strucure."""
_fields_ = [("dev_mask", ctypes.c_int),
......@@ -72,52 +72,13 @@ def opencl(dev_id=0):
return TVMContext(4, dev_id)
class TVMDataType(ctypes.Structure):
"""TVM datatype structure"""
_fields_ = [("type_code", ctypes.c_uint8),
("bits", ctypes.c_uint8),
("lanes", ctypes.c_uint16)]
CODE2STR = {
0 : 'int',
1 : 'uint',
2 : 'float'
}
def __init__(self, type_str, lanes=1):
super(TVMDataType, self).__init__()
if isinstance(type_str, np.dtype):
type_str = str(type_str)
if type_str.startswith("int"):
self.type_code = 0
bits = int(type_str[3:])
elif type_str.startswith("uint"):
self.type_code = 1
bits = int(type_str[4:])
elif type_str.startswith("float"):
self.type_code = 2
bits = int(type_str[5:])
else:
raise ValueError("Donot know how to handle type %s" % type_str)
bits = 32 if bits == 0 else bits
if (bits & (bits - 1)) != 0 or bits < 8:
raise ValueError("Donot know how to handle type %s" % type_str)
self.bits = bits
self.lanes = lanes
def __repr__(self):
x = "%s%d" % (TVMDataType.CODE2STR[self.type_code], self.bits)
if self.lanes != 1:
x += "x%d" % self.lanes
return x
class TVMArray(ctypes.Structure):
"""TVMArg in C API"""
"""TVMValue in C API"""
_fields_ = [("data", ctypes.c_void_p),
("shape", ctypes.POINTER(tvm_index_t)),
("strides", ctypes.POINTER(tvm_index_t)),
("ndim", tvm_index_t),
("dtype", TVMDataType),
("dtype", TVMType),
("ctx", TVMContext)]
TVMArrayHandle = ctypes.POINTER(TVMArray)
......@@ -133,7 +94,7 @@ def numpyasarray(np_data):
arr.data = data.ctypes.data_as(ctypes.c_void_p)
arr.shape = shape
arr.strides = None
arr.dtype = TVMDataType(np.dtype(data.dtype).name)
arr.dtype = TVMType(np.dtype(data.dtype).name)
arr.ndim = data.ndim
# CPU device
arr.ctx = cpu(0)
......@@ -141,6 +102,7 @@ def numpyasarray(np_data):
_ndarray_cls = None
_function_cls = None
def empty(shape, dtype="float32", ctx=cpu(0)):
......@@ -165,7 +127,7 @@ def empty(shape, dtype="float32", ctx=cpu(0)):
shape = c_array(tvm_index_t, shape)
ndim = tvm_index_t(len(shape))
handle = TVMArrayHandle()
dtype = TVMDataType(dtype)
dtype = TVMType(dtype)
check_call(_LIB.TVMArrayAlloc(
shape, ndim, dtype, ctx, ctypes.byref(handle)))
return _ndarray_cls(handle)
......@@ -313,6 +275,51 @@ class NDArrayBase(object):
return target
def _init_runtime_module(ndarray_class):
class FunctionBase(object):
"""A function object at runtim."""
__slots__ = ["handle"]
# pylint: disable=no-member
def __init__(self, handle):
"""Initialize the function with handle
Parameters
----------
handle : FunctionHandle
the handle to the underlying function.
"""
self.handle = handle
def __del__(self):
check_call(_LIB.TVMFuncFree(self.handle))
def __call__(self, *args):
num_args = len(args)
tvm_args = (TVMValue * num_args)()
tvm_type_code = (ctypes.c_int * num_args)()
for i, arg in enumerate(args):
if arg is None:
tvm_args[i].v_handle = None
tvm_type_code[i] = TypeCode.NULL
elif isinstance(arg, NDArrayBase):
tvm_args[i].v_handle = ctypes.cast(arg.handle, ctypes.c_void_p)
tvm_type_code[i] = TypeCode.HANDLE
elif isinstance(arg, Integral):
tvm_args[i].v_int64 = arg
tvm_type_code[i] = TypeCode.INT
elif isinstance(arg, Number):
tvm_args[i].v_float64 = arg
tvm_type_code[i] = TypeCode.FLOAT
elif isinstance(arg, string_types):
tvm_args[i].v_str = c_str(arg)
tvm_type_code[i] = TypeCode.STR
else:
raise TypeError("Don't know how to handle type %s" % type(arg))
check_call(_LIB.TVMFuncCall(
self.handle, tvm_args, tvm_type_code, ctypes.c_int(num_args)))
def _init_runtime_module(ndarray_class, function_class):
global _ndarray_cls
global _function_cls
_ndarray_cls = ndarray_class
_function_cls = function_class
"""The C Types used in API."""
# pylint: disable=invalid-name
from __future__ import absolute_import as _abs
import ctypes
import numpy as np
class TVMValue(ctypes.Union):
"""TVMValue in C API"""
_fields_ = [("v_int64", ctypes.c_int64),
("v_float64", ctypes.c_double),
("v_handle", ctypes.c_void_p),
("v_str", ctypes.c_char_p)]
class TypeCode(object):
"""Type code used in API calls"""
INT = 0
UINT = 1
FLOAT = 2
HANDLE = 3
NULL = 4
NODE_HANDLE = 5
STR = 6
FUNC_HANDLE = 7
def _api_type(code):
"""create a type accepted by API"""
t = TVMType()
t.bits = 64
t.lanes = 1
t.type_code = code
return t
class TVMType(ctypes.Structure):
"""TVM datatype structure"""
_fields_ = [("type_code", ctypes.c_uint8),
("bits", ctypes.c_uint8),
("lanes", ctypes.c_uint16)]
CODE2STR = {
0 : 'int',
1 : 'uint',
2 : 'float'
}
def __init__(self, type_str, lanes=1):
super(TVMType, self).__init__()
if isinstance(type_str, np.dtype):
type_str = str(type_str)
if type_str.startswith("int"):
self.type_code = 0
bits = int(type_str[3:])
elif type_str.startswith("uint"):
self.type_code = 1
bits = int(type_str[4:])
elif type_str.startswith("float"):
self.type_code = 2
bits = int(type_str[5:])
else:
raise ValueError("Donot know how to handle type %s" % type_str)
bits = 32 if bits == 0 else bits
if (bits & (bits - 1)) != 0 or bits < 8:
raise ValueError("Donot know how to handle type %s" % type_str)
self.bits = bits
self.lanes = lanes
def __repr__(self):
x = "%s%d" % (TVMType.CODE2STR[self.type_code], self.bits)
if self.lanes != 1:
x += "x%d" % self.lanes
return x
"""namespace of internal function"""
......@@ -3,8 +3,8 @@
"""Functions defined in TVM."""
from __future__ import absolute_import as _abs
from numbers import Integral as _Integral
from ._ctypes._api import _init_function_module, convert
from . import _function_internal
from ._ctypes._api import _init_api_module, convert
from . import _api_internal
from . import make as _make
from . import expr as _expr
from . import collections as _collections
......@@ -20,7 +20,7 @@ def const(value, dtype=None):
dtype = 'int32'
else:
dtype = 'float32'
return _function_internal._const(value, dtype)
return _api_internal._const(value, dtype)
def load_json(json_str):
......@@ -36,7 +36,7 @@ def load_json(json_str):
node : Node
The loaded tvm node.
"""
return _function_internal._load_json(json_str)
return _api_internal._load_json(json_str)
def save_json(node):
......@@ -52,7 +52,7 @@ def save_json(node):
json_str : str
Saved json string.
"""
return _function_internal._save_json(node)
return _api_internal._save_json(node)
def Var(name="tindex", dtype=int32):
......@@ -66,7 +66,7 @@ def Var(name="tindex", dtype=int32):
dtype : int
The data type
"""
return _function_internal._Var(name, dtype)
return _api_internal._Var(name, dtype)
def placeholder(shape, dtype=None, name="placeholder"):
......@@ -90,7 +90,7 @@ def placeholder(shape, dtype=None, name="placeholder"):
"""
shape = (shape,) if isinstance(shape, _expr.Expr) else shape
dtype = float32 if dtype is None else dtype
return _function_internal._Placeholder(
return _api_internal._Placeholder(
shape, dtype, name)
......@@ -128,9 +128,9 @@ def compute(shape, fcompute, name="compute"):
dim_var = [IterVar((0, s), x) for x, s in zip(arg_names, shape)]
body = fcompute(*[v.var for v in dim_var])
body = convert(body)
op_node = _function_internal._ComputeOp(
op_node = _api_internal._ComputeOp(
name, dim_var, body)
return _function_internal._Tensor(
return _api_internal._Tensor(
shape, body.dtype, op_node, 0)
......@@ -168,7 +168,7 @@ def Buffer(shape, dtype=None,
if ptr is None:
ptr = Var(name, "handle")
return _function_internal._Buffer(
return _api_internal._Buffer(
name, ptr, shape, strides, dtype)
......@@ -202,7 +202,7 @@ def IterVar(dom=None, name=None, thread_tag=''):
if name is None:
name = thread_tag if thread_tag else name
name = name if name else 'iter'
return _function_internal._IterVar(dom, name, thread_tag)
return _api_internal._IterVar(dom, name, thread_tag)
def sum(expr, rdom):
......@@ -263,7 +263,7 @@ def Schedule(ops):
"""
if not isinstance(ops, (list, _collections.Array)):
ops = [ops]
return _function_internal._Schedule(ops)
return _api_internal._Schedule(ops)
_init_function_module("tvm")
_init_api_module("tvm")
......@@ -2,7 +2,7 @@
"""Collection structure in the high level DSL."""
from __future__ import absolute_import as _abs
from ._ctypes._api import NodeBase, register_node
from . import _function_internal
from . import _api_internal
from . import expr as _expr
@register_node
......@@ -11,10 +11,10 @@ class Array(NodeBase):
def __getitem__(self, i):
if i >= len(self):
raise IndexError("array index out ot range")
return _function_internal._ArrayGetItem(self, i)
return _api_internal._ArrayGetItem(self, i)
def __len__(self):
return _function_internal._ArraySize(self)
return _api_internal._ArraySize(self)
def __repr__(self):
return '[' + (','.join(str(x) for x in self)) + ']'
......@@ -23,18 +23,18 @@ class Array(NodeBase):
class Map(NodeBase):
"""Map container of TVM"""
def __getitem__(self, k):
return _function_internal._MapGetItem(self, k)
return _api_internal._MapGetItem(self, k)
def __contains__(self, k):
return _function_internal._MapCount(self, k) != 0
return _api_internal._MapCount(self, k) != 0
def items(self):
"""Get the items from the map"""
akvs = _function_internal._MapItems(self)
akvs = _api_internal._MapItems(self)
return [(akvs[i], akvs[i+1]) for i in range(0, len(akvs), 2)]
def __len__(self):
return _function_internal._MapSize(self)
return _api_internal._MapSize(self)
def __repr__(self):
return '{' + (", ".join(str(x[0]) + ": " +str(x[1]) for x in self.items())) + '}'
......
......@@ -6,7 +6,7 @@ This is a simplified runtime API for quick testing and proptyping.
from __future__ import absolute_import as _abs
import numpy as _np
from ._ctypes._runtime_api import TVMContext, TVMDataType, NDArrayBase
from ._ctypes._runtime_api import TVMContext, TVMType, NDArrayBase, FunctionBase
from ._ctypes._runtime_api import cpu, gpu, opencl, empty, sync
from ._ctypes._runtime_api import _init_runtime_module
from ._ctypes._runtime_api import init_opencl
......@@ -26,6 +26,11 @@ class NDArray(NDArrayBase):
pass
class Function(FunctionBase):
"""Function class that can executed a generated code."""
pass
def array(arr, ctx=cpu(0)):
"""Create an array from source arr.
......@@ -49,4 +54,4 @@ def array(arr, ctx=cpu(0)):
return ret
_init_runtime_module(NDArray)
_init_runtime_module(NDArray, Function)
......@@ -2,7 +2,7 @@
"""Collection structure in the high level DSL."""
from __future__ import absolute_import as _abs
from ._ctypes._api import NodeBase, register_node
from . import _function_internal
from . import _api_internal
from . import tensor as _tensor
@register_node
......@@ -56,11 +56,11 @@ class Stage(NodeBase):
if outer is not None:
if outer.thread_tag == '':
raise ValueError("split by outer must have special thread_tag")
inner = _function_internal._StageSplitByOuter(self, parent, outer, factor)
inner = _api_internal._StageSplitByOuter(self, parent, outer, factor)
else:
if factor is None:
raise ValueError("either outer or factor need to be provided")
outer, inner = _function_internal._StageSplitByFactor(self, parent, factor)
outer, inner = _api_internal._StageSplitByFactor(self, parent, factor)
return outer, inner
def fuse(self, inner, outer):
......@@ -79,7 +79,7 @@ class Stage(NodeBase):
inner : IterVar
The fused variable of iteration.
"""
return _function_internal._StageFuse(self, inner, outer)
return _api_internal._StageFuse(self, inner, outer)
def set_scope(self, scope):
"""Set the thread scope of this stage
......@@ -89,7 +89,7 @@ class Stage(NodeBase):
scope : str
The thread scope of this stage
"""
return _function_internal._StageSetScope(self, scope)
return _api_internal._StageSetScope(self, scope)
def compute_at(self, parent, scope):
"""Attach the stage at parent's scope
......@@ -102,7 +102,7 @@ class Stage(NodeBase):
scope : IterVar
The loop scope t be attached to.
"""
_function_internal._StageComputeAt(self, parent, scope)
_api_internal._StageComputeAt(self, parent, scope)
def compute_inline(self):
"""Mark stage as inline
......@@ -112,7 +112,7 @@ class Stage(NodeBase):
parent : Stage
The parent stage
"""
_function_internal._StageComputeInline(self)
_api_internal._StageComputeInline(self)
def compute_root(self):
"""Attach the stage at parent, and mark it as root
......@@ -122,7 +122,7 @@ class Stage(NodeBase):
parent : Stage
The parent stage
"""
_function_internal._StageComputeInline(self)
_api_internal._StageComputeInline(self)
def reorder(self, *args):
"""reorder the arguments in the specified order.
......@@ -132,7 +132,7 @@ class Stage(NodeBase):
args : list of IterVar
The order to be ordered
"""
_function_internal._StageReorder(self, args)
_api_internal._StageReorder(self, args)
def tile(self, x_parent, y_parent, x_factor, y_factor):
""" Perform tiling on two dimensions
......@@ -161,6 +161,6 @@ class Stage(NodeBase):
p_y_inner : IterVar
Inner axis of y dimension
"""
x_outer, y_outer, x_inner, y_inner = _function_internal._StageTile(
x_outer, y_outer, x_inner, y_inner = _api_internal._StageTile(
self, x_parent, y_parent, x_factor, y_factor)
return x_outer, y_outer, x_inner, y_inner
......@@ -3,7 +3,7 @@
from __future__ import absolute_import as _abs
from ._ctypes._api import NodeBase, SliceBase, register_node, convert
from . import collections as _collections
from . import _function_internal
from . import _api_internal
from . import make as _make
from . import expr as _expr
......@@ -44,12 +44,12 @@ class Tensor(NodeBase):
return TensorSlice(self, indices)
def __hash__(self):
return _function_internal._TensorHash(self)
return _api_internal._TensorHash(self)
def __eq__(self, other):
if not isinstance(other, Tensor):
return False
return _function_internal._TensorEqual(self, other)
return _api_internal._TensorEqual(self, other)
@property
def ndim(self):
......@@ -72,7 +72,7 @@ class Operation(NodeBase):
out : Tensor
The i-th output.
"""
return _function_internal._OpGetOutput(self, index)
return _api_internal._OpGetOutput(self, index)
@register_node
class ComputeOp(Operation):
......
......@@ -4,4 +4,5 @@
- lang The definition of DSL related data structure
- schedule The operations on the schedule graph before converting to IR.
- pass The optimization pass on the IR structure
- runtime The runtime related codes.
\ No newline at end of file
- runtime Minimum runtime related codes.
- jit JIT runtime related code.
......@@ -22,7 +22,7 @@ struct TVMAPIThreadLocalEntry {
arg_stack.clear();
ret_value.sptr.reset();
}
inline void SetReturn(ArgVariant* ret_val, int* ret_typeid);
inline void SetReturn(TVMValue* ret_val, int* ret_type_code);
};
using namespace tvm;
......@@ -97,11 +97,11 @@ struct APIAttrDir : public AttrVisitor {
}
};
int TVMListAPIFunctionNames(int *out_size,
int TVMListAPIFuncNames(int *out_size,
const char*** out_array) {
API_BEGIN();
TVMAPIThreadLocalEntry *ret = TVMAPIThreadLocalStore::Get();
ret->ret_vec_str = dmlc::Registry<APIFunctionReg>::ListAllNames();
ret->ret_vec_str = dmlc::Registry<APIFuncReg>::ListAllNames();
ret->ret_vec_charp.clear();
for (size_t i = 0; i < ret->ret_vec_str.size(); ++i) {
ret->ret_vec_charp.push_back(ret->ret_vec_str[i].c_str());
......@@ -111,16 +111,16 @@ int TVMListAPIFunctionNames(int *out_size,
API_END();
}
int TVMGetAPIFunctionHandle(const char* fname,
APIFunctionHandle* out) {
int TVMGetAPIFuncHandle(const char* fname,
APIFuncHandle* out) {
API_BEGIN();
const APIFunctionReg* reg = dmlc::Registry<APIFunctionReg>::Find(fname);
const APIFuncReg* reg = dmlc::Registry<APIFuncReg>::Find(fname);
CHECK(reg != nullptr) << "cannot find function " << fname;
*out = (APIFunctionHandle)reg;
*out = (APIFuncHandle)reg;
API_END();
}
int TVMGetAPIFunctionInfo(APIFunctionHandle handle,
int TVMGetAPIFuncInfo(APIFuncHandle handle,
const char **real_name,
const char **description,
int *num_doc_args,
......@@ -128,7 +128,7 @@ int TVMGetAPIFunctionInfo(APIFunctionHandle handle,
const char ***arg_type_infos,
const char ***arg_descriptions,
const char **return_type) {
const auto *op = static_cast<const APIFunctionReg *>(handle);
const auto *op = static_cast<const APIFuncReg *>(handle);
TVMAPIThreadLocalEntry *ret = TVMAPIThreadLocalStore::Get();
API_BEGIN();
......@@ -152,33 +152,37 @@ int TVMGetAPIFunctionInfo(APIFunctionHandle handle,
API_END();
}
int TVMAPIPushStack(ArgVariant arg,
int type_id) {
int TVMAPIPushStack(TVMValue arg,
int type_code) {
TVMAPIThreadLocalEntry *ret = TVMAPIThreadLocalStore::Get();
API_BEGIN();
ret->arg_stack.resize(ret->arg_stack.size() + 1);
APIVariantValue& v = ret->arg_stack.back();
v.type_id = static_cast<ArgVariantID>(type_id);
if (type_id == kStr) {
v.str = arg.v_str;
} else if (type_id == kNodeHandle) {
v.sptr = *static_cast<TVMAPINode*>(arg.v_handle);
} else {
v.v_union = arg;
v.type_code = type_code;
switch (type_code) {
case kInt: case kUInt: case kFloat: case kNull: {
v.v_union = arg; break;
}
case kStr: {
v.str = arg.v_str; break;
}
case kNodeHandle: {
v.sptr = *static_cast<TVMAPINode*>(arg.v_handle); break;
}
default: LOG(FATAL) << "TVM API cannot take type " << TVMTypeCode2Str(type_code);
}
API_END_HANDLE_ERROR(ret->Clear());
}
int TVMAPIFunctionCall(APIFunctionHandle handle,
ArgVariant* ret_val,
int* ret_typeid) {
int TVMAPIFuncCall(APIFuncHandle handle,
TVMValue* ret_val,
int* ret_type_code) {
TVMAPIThreadLocalEntry *ret = TVMAPIThreadLocalStore::Get();
API_BEGIN();
const auto *op = static_cast<const APIFunctionReg *>(handle);
const auto *op = static_cast<const APIFuncReg *>(handle);
op->body(ret->arg_stack, &(ret->ret_value));
ret->SetReturn(ret_val, ret_typeid);
ret->SetReturn(ret_val, ret_type_code);
ret->arg_stack.clear();
API_END_HANDLE_ERROR(ret->Clear());
}
......@@ -191,28 +195,28 @@ int TVMNodeFree(NodeHandle handle) {
int TVMNodeGetAttr(NodeHandle handle,
const char* key,
ArgVariant* ret_val,
int* ret_typeid,
TVMValue* ret_val,
int* ret_type_code,
int* ret_success) {
TVMAPIThreadLocalEntry *ret = TVMAPIThreadLocalStore::Get();
API_BEGIN();
ret->ret_value.type_id = kNull;
ret->ret_value.type_code = kNull;
APIAttrGetter getter;
getter.skey = key;
getter.ret = &(ret->ret_value);
TVMAPINode* tnode = static_cast<TVMAPINode*>(handle);
if (getter.skey == "type_key") {
ret_val->v_str = (*tnode)->type_key();
*ret_typeid = kStr;
*ret_type_code = kStr;
*ret_success = 1;
} else {
(*tnode)->VisitAttrs(&getter);
if (ret->ret_value.type_id != kNull) {
ret->SetReturn(ret_val, ret_typeid);
if (ret->ret_value.type_code != kNull) {
ret->SetReturn(ret_val, ret_type_code);
*ret_success = 1;
} else {
*ret_success = getter.found_node_ref ? 1 : 0;
*ret_typeid = kNull;
*ret_type_code = kNull;
}
}
API_END_HANDLE_ERROR(ret->Clear());
......@@ -238,16 +242,18 @@ int TVMNodeListAttrNames(NodeHandle handle,
}
inline void TVMAPIThreadLocalEntry::SetReturn(ArgVariant* ret_val,
int* ret_typeid) {
inline void TVMAPIThreadLocalEntry::SetReturn(TVMValue* ret_val,
int* ret_type_code) {
APIVariantValue& rv = ret_value;
*ret_typeid = rv.type_id;
if (rv.type_id == kNodeHandle) {
*ret_type_code = rv.type_code;
if (rv.type_code == kNodeHandle) {
if (rv.sptr.get() != nullptr) {
ret_val->v_handle = new TVMAPINode(std::move(rv.sptr));
} else {
ret_val->v_handle = nullptr;
}
} else if (rv.type_code == kFuncHandle) {
ret_val->v_handle = new runtime::PackedFunc::FType(std::move(rv.func));
} else {
*ret_val = rv.v_union;
}
......
/*!
* Copyright (c) 2016 by Contributors
* Implementation of API functions related to IR build
* \file c_api_ir.cc
* Implementation of API functions related to Codegen
* \file c_api_codegen.cc
*/
#include <tvm/expr.h>
#include <tvm/ir.h>
......@@ -32,5 +32,25 @@ TVM_REGISTER_API(_codegen_SplitHostDevice)
*ret = SplitHostDevice(args.at(0));
});
// generate a dummy packed function for testing
void DummyHelloFunction(const TVMValue* args, const int* type_code, int num_args) {
LOG(INFO) << num_args << " arguments";
for (int i = 0; i < num_args; ++i) {
switch (type_code[i]) {
case kNull: LOG(INFO) << i << ":nullptr"; break;
case kFloat: LOG(INFO) << i << ": double=" << args[i].v_float64; break;
case kInt: LOG(INFO) << i << ": long=" << args[i].v_int64; break;
case kHandle: LOG(INFO) << i << ": handle=" << args[i].v_handle; break;
default: LOG(FATAL) << "unhandled type " << TVMTypeCode2Str(type_code[i]);
}
}
}
TVM_REGISTER_API(_codegen_DummyHelloFunction)
.set_body([](const ArgStack& args, RetValue *ret) {
*ret = runtime::PackedFunc(DummyHelloFunction);
});
} // namespace codegen
} // namespace tvm
......@@ -8,7 +8,7 @@
#include "./c_api_registry.h"
namespace dmlc {
DMLC_REGISTRY_ENABLE(::tvm::APIFunctionReg);
DMLC_REGISTRY_ENABLE(::tvm::APIFuncReg);
} // namespace dmlc
namespace tvm {
......@@ -18,7 +18,7 @@ using RetValue = APIVariantValue;
TVM_REGISTER_API(_format_str)
.set_body([](const ArgStack& args, RetValue *ret) {
CHECK(args.at(0).type_id == kNodeHandle);
CHECK(args.at(0).type_code == kNodeHandle);
std::ostringstream os;
os << args.at(0).operator NodeRef();
*ret = os.str();
......@@ -27,7 +27,7 @@ TVM_REGISTER_API(_format_str)
TVM_REGISTER_API(_raw_ptr)
.set_body([](const ArgStack& args, RetValue *ret) {
CHECK(args.at(0).type_id == kNodeHandle);
CHECK(args.at(0).type_code == kNodeHandle);
*ret = reinterpret_cast<int64_t>(args.at(0).sptr.get());
})
.add_argument("src", "NodeBase", "the node base");
......
......@@ -16,9 +16,9 @@ using RetValue = APIVariantValue;
TVM_REGISTER_API(_const)
.set_body([](const ArgStack& args, RetValue *ret) {
if (args.at(0).type_id == kLong) {
if (args.at(0).type_code == kInt) {
*ret = make_const(args.at(1), args.at(0).operator int64_t());
} else if (args.at(0).type_id == kDouble) {
} else if (args.at(0).type_code == kFloat) {
*ret = make_const(args.at(1), args.at(0).operator double());
} else {
LOG(FATAL) << "only accept int or float";
......@@ -31,19 +31,19 @@ TVM_REGISTER_API(_Array)
.set_body([](const ArgStack& args, RetValue *ret) {
std::vector<std::shared_ptr<Node> > data;
for (size_t i = 0; i < args.size(); ++i) {
CHECK(args.at(i).type_id == kNodeHandle)
CHECK(args.at(i).type_code == kNodeHandle)
<< "need content of array to be NodeBase";
data.push_back(args.at(i).sptr);
}
auto node = std::make_shared<ArrayNode>();
node->data = std::move(data);
ret->type_id = kNodeHandle;
ret->type_code = kNodeHandle;
ret->sptr = node;
});
TVM_REGISTER_API(_ArrayGetItem)
.set_body([](const ArgStack& args, RetValue *ret) {
CHECK(args.at(0).type_id == kNodeHandle);
CHECK(args.at(0).type_code == kNodeHandle);
int64_t i = args.at(1);
auto& sptr = args.at(0).sptr;
CHECK(sptr->is_type<ArrayNode>());
......@@ -51,12 +51,12 @@ TVM_REGISTER_API(_ArrayGetItem)
CHECK_LT(static_cast<size_t>(i), n->data.size())
<< "out of bound of array";
ret->sptr = n->data[i];
ret->type_id = kNodeHandle;
ret->type_code = kNodeHandle;
});
TVM_REGISTER_API(_ArraySize)
.set_body([](const ArgStack& args, RetValue *ret) {
CHECK(args.at(0).type_id == kNodeHandle);
CHECK(args.at(0).type_code == kNodeHandle);
auto& sptr = args.at(0).sptr;
CHECK(sptr->is_type<ArrayNode>());
*ret = static_cast<int64_t>(
......@@ -68,21 +68,21 @@ TVM_REGISTER_API(_Map)
CHECK_EQ(args.size() % 2, 0U);
MapNode::ContainerType data;
for (size_t i = 0; i < args.size(); i += 2) {
CHECK(args.at(i).type_id == kNodeHandle)
CHECK(args.at(i).type_code == kNodeHandle)
<< "need content of array to be NodeBase";
CHECK(args.at(i + 1).type_id == kNodeHandle)
CHECK(args.at(i + 1).type_code == kNodeHandle)
<< "need content of array to be NodeBase";
data.emplace(std::make_pair(args.at(i).sptr, args.at(i + 1).sptr));
}
auto node = std::make_shared<MapNode>();
node->data = std::move(data);
ret->type_id = kNodeHandle;
ret->type_code = kNodeHandle;
ret->sptr = node;
});
TVM_REGISTER_API(_MapSize)
.set_body([](const ArgStack& args, RetValue *ret) {
CHECK(args.at(0).type_id == kNodeHandle);
CHECK(args.at(0).type_code == kNodeHandle);
auto& sptr = args.at(0).sptr;
CHECK(sptr->is_type<MapNode>());
auto* n = static_cast<const MapNode*>(sptr.get());
......@@ -91,8 +91,8 @@ TVM_REGISTER_API(_MapSize)
TVM_REGISTER_API(_MapGetItem)
.set_body([](const ArgStack& args, RetValue *ret) {
CHECK(args.at(0).type_id == kNodeHandle);
CHECK(args.at(1).type_id == kNodeHandle);
CHECK(args.at(0).type_code == kNodeHandle);
CHECK(args.at(1).type_code == kNodeHandle);
auto& sptr = args.at(0).sptr;
CHECK(sptr->is_type<MapNode>());
auto* n = static_cast<const MapNode*>(sptr.get());
......@@ -100,13 +100,13 @@ TVM_REGISTER_API(_MapGetItem)
CHECK(it != n->data.end())
<< "cannot find the corresponding key in the Map";
ret->sptr = (*it).second;
ret->type_id = kNodeHandle;
ret->type_code = kNodeHandle;
});
TVM_REGISTER_API(_MapCount)
.set_body([](const ArgStack& args, RetValue *ret) {
CHECK(args.at(0).type_id == kNodeHandle);
CHECK(args.at(1).type_id == kNodeHandle);
CHECK(args.at(0).type_code == kNodeHandle);
CHECK(args.at(1).type_code == kNodeHandle);
auto& sptr = args.at(0).sptr;
CHECK(sptr->is_type<MapNode>());
auto* n = static_cast<const MapNode*>(sptr.get());
......@@ -115,7 +115,7 @@ TVM_REGISTER_API(_MapCount)
TVM_REGISTER_API(_MapItems)
.set_body([](const ArgStack& args, RetValue *ret) {
CHECK(args.at(0).type_id == kNodeHandle);
CHECK(args.at(0).type_code == kNodeHandle);
auto& sptr = args.at(0).sptr;
CHECK(sptr->is_type<MapNode>());
auto* n = static_cast<const MapNode*>(sptr.get());
......@@ -125,7 +125,7 @@ TVM_REGISTER_API(_MapItems)
rkvs->data.push_back(kv.second);
}
ret->sptr = rkvs;
ret->type_id = kNodeHandle;
ret->type_code = kNodeHandle;
});
TVM_REGISTER_API(Range)
......
......@@ -9,25 +9,25 @@
#include <tvm/base.h>
#include <tvm/expr.h>
#include <tvm/c_api.h>
#include <tvm/runtime/runtime.h>
#include <memory>
#include <limits>
#include <string>
#include <vector>
#include "../base/common.h"
using ArgVariant = TVMArg;
using ArgVariantID = TVMArgTypeID;
namespace tvm {
inline const char* TypeId2Str(ArgVariantID type_id) {
switch (type_id) {
case kNull: return "Null";
case kLong: return "Long";
case kDouble: return "Double";
case kStr: return "Str";
inline const char* TVMTypeCode2Str(int type_code) {
switch (type_code) {
case kInt: return "int";
case kFloat: return "float";
case kStr: return "str";
case kHandle: return "Handle";
case kNull: return "NULL";
case kNodeHandle: return "NodeHandle";
default: LOG(FATAL) << "unknown type_id=" << type_id; return "";
default: LOG(FATAL) << "unknown type_code="
<< static_cast<int>(type_code); return "";
}
}
......@@ -96,72 +96,83 @@ inline std::string NodeTypeName() {
class APIVariantValue {
public:
/*! \brief the type id */
ArgVariantID type_id{kNull};
int type_code{kNull};
/*! \brief shared pointer container */
std::shared_ptr<Node> sptr;
/*! \brief string container */
std::string str;
/*! \brief the variant holder */
ArgVariant v_union;
TVMValue v_union;
/*! \brief std::function */
runtime::PackedFunc::FType func;
// constructor
APIVariantValue() {}
APIVariantValue() {
}
// clear value
inline void Clear() {
}
// assign op
inline APIVariantValue& operator=(double value) {
type_id = kDouble;
v_union.v_double = value;
type_code = kFloat;
v_union.v_float64 = value;
return *this;
}
inline APIVariantValue& operator=(std::nullptr_t value) {
type_id = kNull;
type_code = kHandle;
v_union.v_handle = value;
return *this;
}
inline APIVariantValue& operator=(int64_t value) {
type_id = kLong;
v_union.v_long = value;
type_code = kInt;
v_union.v_int64 = value;
return *this;
}
inline APIVariantValue& operator=(bool value) {
type_id = kLong;
v_union.v_long = value;
type_code = kInt;
v_union.v_int64 = value;
return *this;
}
inline APIVariantValue& operator=(std::string value) {
type_id = kStr;
type_code = kStr;
str = std::move(value);
v_union.v_str = str.c_str();
return *this;
}
inline APIVariantValue& operator=(const NodeRef& ref) {
if (ref.node_.get() == nullptr) {
type_id = kNull;
type_code = kNull;
} else {
type_id = kNodeHandle;
type_code = kNodeHandle;
this->sptr = ref.node_;
}
return *this;
}
inline APIVariantValue& operator=(const runtime::PackedFunc& f) {
type_code = kFuncHandle;
this->func = f.body();
return *this;
}
inline APIVariantValue& operator=(const Type& value) {
return operator=(Type2String(value));
}
template<typename T,
typename = typename std::enable_if<std::is_base_of<NodeRef, T>::value>::type>
inline operator T() const {
if (type_id == kNull) return T();
CHECK_EQ(type_id, kNodeHandle);
if (type_code == kNull) return T();
CHECK_EQ(type_code, kNodeHandle);
CHECK(NodeTypeChecker<T>::Check(sptr.get()))
<< "Did not get expected type " << NodeTypeName<T>();
return T(sptr);
}
inline operator Expr() const {
if (type_id == kNull) return Expr();
if (type_id == kLong) return Expr(operator int());
if (type_id == kDouble) {
if (type_code == kNull) {
return Expr();
}
if (type_code == kInt) return Expr(operator int());
if (type_code == kFloat) {
return Expr(static_cast<float>(operator double()));
}
CHECK_EQ(type_id, kNodeHandle);
CHECK_EQ(type_code, kNodeHandle);
if (sptr->is_type<IterVarNode>()) {
return IterVar(sptr)->var;
} else {
......@@ -171,52 +182,58 @@ class APIVariantValue {
}
}
inline operator double() const {
CHECK_EQ(type_id, kDouble);
return v_union.v_double;
CHECK_EQ(type_code, kFloat);
return v_union.v_float64;
}
inline operator int64_t() const {
CHECK_EQ(type_id, kLong);
return v_union.v_long;
CHECK_EQ(type_code, kInt);
return v_union.v_int64;
}
inline operator uint64_t() const {
CHECK_EQ(type_id, kLong);
return v_union.v_long;
CHECK_EQ(type_code, kInt);
return v_union.v_int64;
}
inline operator int() const {
CHECK_EQ(type_id, kLong);
CHECK_LE(v_union.v_long,
CHECK_EQ(type_code, kInt);
CHECK_LE(v_union.v_int64,
std::numeric_limits<int>::max());
return v_union.v_long;
return v_union.v_int64;
}
inline operator bool() const {
CHECK_EQ(type_id, kLong)
<< "expect boolean(int) but get " << TypeId2Str(type_id);
return v_union.v_long != 0;
CHECK_EQ(type_code, kInt)
<< "expect boolean(int) but get "
<< TVMTypeCode2Str(type_code);
return v_union.v_int64 != 0;
}
inline operator std::string() const {
CHECK_EQ(type_id, kStr)
<< "expect Str but get " << TypeId2Str(type_id);
CHECK_EQ(type_code, kStr)
<< "expect Str but get "
<< TVMTypeCode2Str(type_code);
return str;
}
inline operator Type() const {
return String2Type(operator std::string());
}
inline operator runtime::PackedFunc() const {
CHECK_EQ(type_code, kFuncHandle);
return runtime::PackedFunc(func);
}
};
// common defintiion of API function.
using APIFunction = std::function<
using APIFunc = std::function<
void(const std::vector<APIVariantValue> &args, APIVariantValue* ret)>;
/*!
* \brief Registry entry for DataIterator factory functions.
*/
struct APIFunctionReg
: public dmlc::FunctionRegEntryBase<APIFunctionReg,
APIFunction> {
struct APIFuncReg
: public dmlc::FunctionRegEntryBase<APIFuncReg,
APIFunc> {
};
#define TVM_REGISTER_API(TypeName) \
DMLC_REGISTRY_REGISTER(::tvm::APIFunctionReg, APIFunctionReg, TypeName) \
#define TVM_REGISTER_API(TypeName) \
DMLC_REGISTRY_REGISTER(::tvm::APIFuncReg, APIFuncReg, TypeName) \
} // namespace tvm
......
......@@ -7,7 +7,6 @@
#define TVM_CODEGEN_CODEGEN_C_H_
#include <tvm/ir.h>
#include <tvm/ir_visitor.h>
#include <tvm/module.h>
#include <string>
#include <unordered_map>
......
......@@ -3,7 +3,8 @@
* \file c_runtime_api.cc
* \brief Device specific implementations
*/
#include <tvm/c_runtime_api.h>
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/runtime.h>
#include <algorithm>
#include "./runtime_base.h"
#include "./device_api.h"
......@@ -34,7 +35,7 @@ inline void TVMArrayFree_(TVMArray* arr) {
delete arr;
}
inline void VerifyType(TVMDataType dtype) {
inline void VerifyType(TVMType dtype) {
CHECK_GE(dtype.lanes, 1U);
if (dtype.type_code == kFloat) {
CHECK_EQ(dtype.bits % 32U, 0U);
......@@ -98,7 +99,7 @@ int TVMContextEnabled(TVMContext ctx,
int TVMArrayAlloc(const tvm_index_t* shape,
tvm_index_t ndim,
TVMDataType dtype,
TVMType dtype,
TVMContext ctx,
TVMArrayHandle* out) {
TVMArray* arr = nullptr;
......@@ -166,3 +167,19 @@ int TVMSynchronize(TVMContext ctx, TVMStreamHandle stream) {
});
API_END();
}
int TVMFuncFree(TVMFunctionHandle func) {
API_BEGIN();
delete static_cast<PackedFunc::FType*>(func);
API_END();
}
int TVMFuncCall(TVMFunctionHandle func,
TVMValue* args,
int* arg_type_codes,
int num_args) {
API_BEGIN();
(*static_cast<const PackedFunc::FType*>(func))(
args, arg_type_codes, num_args);
API_END();
}
......@@ -7,7 +7,7 @@
#define TVM_RUNTIME_DEVICE_API_H_
#include <tvm/base.h>
#include <tvm/c_runtime_api.h>
#include <tvm/runtime/c_runtime_api.h>
namespace tvm {
namespace runtime {
......
......@@ -6,7 +6,7 @@
#ifndef TVM_RUNTIME_RUNTIME_BASE_H_
#define TVM_RUNTIME_RUNTIME_BASE_H_
#include <tvm/c_runtime_api.h>
#include <tvm/runtime/c_runtime_api.h>
#include <stdexcept>
/*! \brief macro to guard beginning and end section of all functions */
......
#include <dmlc/logging.h>
#include <gtest/gtest.h>
#include <tvm/runtime/runtime.h>
TEST(PackedFunc, Basic) {
using namespace tvm::runtime;
int x = 0;
void* handle = &x;
TVMArray a;
PackedFunc([&](const TVMValue* args, const int* type_codes, int num_args) {
CHECK(num_args == 3);
CHECK(args[0].v_float64 == 1.0);
CHECK(type_codes[0] == kFloat);
CHECK(args[1].v_handle == &a);
CHECK(type_codes[1] == kHandle);
CHECK(args[2].v_handle == &x);
CHECK(type_codes[2] == kHandle);
})(1.0, &a, handle);
}
int main(int argc, char ** argv) {
testing::InitGoogleTest(&argc, argv);
testing::FLAGS_gtest_death_test_style = "threadsafe";
return RUN_ALL_TESTS();
}
......@@ -15,11 +15,11 @@ def mock_test_add():
thread_x = tvm.IterVar((0, num_thread), thread_tag="threadIdx.x")
_, x = s[C].split(C.op.axis[0], factor=num_thread, outer=grid_x)
_, x = s[C].split(x, outer=thread_x)
# compile to IR
bounds = tvm.schedule.InferBound(s)
stmt = tvm.ir_pass.ScheduleOps(s, bounds)
Ab = tvm.Buffer(A.shape, A.dtype, name='A')
Bb = tvm.Buffer(B.shape, B.dtype, name='B')
Cb = tvm.Buffer(C.shape, C.dtype, name='C')
......
import tvm
import numpy as np
def test_function():
ctx = tvm.cpu(0)
x = np.random.randint(0, 10, size=(3, 4))
x = np.array(x)
y = tvm.nd.array(x, ctx=ctx)
f = tvm.codegen.DummyHelloFunction()
f(y, 10)
if __name__ == "__main__":
test_function()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment