Commit 01a7ce0c by Tianqi Chen Committed by GitHub

[RUNTIME] Add Function, Unify TVMTypeCode and TVMArgTypeID (#24)

parent 4f1473f3
...@@ -6,11 +6,11 @@ ...@@ -6,11 +6,11 @@
#ifndef TVM_C_API_H_ #ifndef TVM_C_API_H_
#define TVM_C_API_H_ #define TVM_C_API_H_
#include "./c_runtime_api.h" #include "./runtime/c_runtime_api.h"
TVM_EXTERN_C { TVM_EXTERN_C {
/*! \brief handle to functions */ /*! \brief handle to functions */
typedef void* APIFunctionHandle; typedef void* APIFuncHandle;
/*! \brief handle to node */ /*! \brief handle to node */
typedef void* NodeHandle; typedef void* NodeHandle;
...@@ -18,16 +18,18 @@ typedef void* NodeHandle; ...@@ -18,16 +18,18 @@ typedef void* NodeHandle;
* \brief List all the node function name * \brief List all the node function name
* \param out_size The number of functions * \param out_size The number of functions
* \param out_array The array of function names. * \param out_array The array of function names.
* \return 0 when success, -1 when failure happens
*/ */
TVM_DLL int TVMListAPIFunctionNames(int *out_size, TVM_DLL int TVMListAPIFuncNames(int *out_size,
const char*** out_array); const char*** out_array);
/*! /*!
* \brief get function handle by name * \brief get function handle by name
* \param name The name of function * \param name The name of function
* \param handle The returning function handle * \param handle The returning function handle
* \return 0 when success, -1 when failure happens
*/ */
TVM_DLL int TVMGetAPIFunctionHandle(const char* name, TVM_DLL int TVMGetAPIFuncHandle(const char* name,
APIFunctionHandle *handle); APIFuncHandle *handle);
/*! /*!
* \brief Get the detailed information about function. * \brief Get the detailed information about function.
...@@ -42,7 +44,7 @@ TVM_DLL int TVMGetAPIFunctionHandle(const char* name, ...@@ -42,7 +44,7 @@ TVM_DLL int TVMGetAPIFunctionHandle(const char* name,
* \param return_type Return type of the function, if any. * \param return_type Return type of the function, if any.
* \return 0 when success, -1 when failure happens * \return 0 when success, -1 when failure happens
*/ */
TVM_DLL int TVMGetAPIFunctionInfo(APIFunctionHandle handle, TVM_DLL int TVMGetAPIFuncInfo(APIFuncHandle handle,
const char **real_name, const char **real_name,
const char **description, const char **description,
int *num_doc_args, int *num_doc_args,
...@@ -55,11 +57,13 @@ TVM_DLL int TVMGetAPIFunctionInfo(APIFunctionHandle handle, ...@@ -55,11 +57,13 @@ TVM_DLL int TVMGetAPIFunctionInfo(APIFunctionHandle handle,
* \brief Push an argument to the function calling stack. * \brief Push an argument to the function calling stack.
* If push fails, the stack will be reset to empty * If push fails, the stack will be reset to empty
* *
* \param arg number of attributes * \param arg The argument
* \param type_id The typeid of attributes. * \param type_code The type_code of argument as in TVMTypeCode
* \return 0 when success, -1 when failure happens
* \note API calls always exchanges with type bits=64, lanes=1
*/ */
TVM_DLL int TVMAPIPushStack(TVMArg arg, TVM_DLL int TVMAPIPushStack(TVMValue arg,
int type_id); int type_code);
/*! /*!
* \brief call a function by using arguments in the stack. * \brief call a function by using arguments in the stack.
...@@ -67,15 +71,18 @@ TVM_DLL int TVMAPIPushStack(TVMArg arg, ...@@ -67,15 +71,18 @@ TVM_DLL int TVMAPIPushStack(TVMArg arg,
* *
* \param handle The function handle * \param handle The function handle
* \param ret_val The return value. * \param ret_val The return value.
* \param ret_typeid the type id of return value. * \param ret_type_code the type code of return value.
* \return 0 when success, -1 when failure happens
* \note API calls always exchanges with type bits=64, lanes=1
*/ */
TVM_DLL int TVMAPIFunctionCall(APIFunctionHandle handle, TVM_DLL int TVMAPIFuncCall(APIFuncHandle handle,
TVMArg* ret_val, TVMValue* ret_val,
int* ret_typeid); int* ret_type_code);
/*! /*!
* \brief free the node handle * \brief free the node handle
* \param handle The node handle to be freed. * \param handle The node handle to be freed.
* \return 0 when success, -1 when failure happens
*/ */
TVM_DLL int TVMNodeFree(NodeHandle handle); TVM_DLL int TVMNodeFree(NodeHandle handle);
...@@ -84,13 +91,15 @@ TVM_DLL int TVMNodeFree(NodeHandle handle); ...@@ -84,13 +91,15 @@ TVM_DLL int TVMNodeFree(NodeHandle handle);
* \param handle The node handle * \param handle The node handle
* \param key The attribute name * \param key The attribute name
* \param out_value The attribute value * \param out_value The attribute value
* \param out_typeid The typeid of the attribute. * \param out_type_code The type code of the attribute.
* \param out_success Whether get is successful. * \param out_success Whether get is successful.
* \return 0 when success, -1 when failure happens
* \note API calls always exchanges with type bits=64, lanes=1
*/ */
TVM_DLL int TVMNodeGetAttr(NodeHandle handle, TVM_DLL int TVMNodeGetAttr(NodeHandle handle,
const char* key, const char* key,
TVMArg* out_value, TVMValue* out_value,
int* out_typeid, int* out_type_code,
int* out_success); int* out_success);
/*! /*!
...@@ -98,6 +107,7 @@ TVM_DLL int TVMNodeGetAttr(NodeHandle handle, ...@@ -98,6 +107,7 @@ TVM_DLL int TVMNodeGetAttr(NodeHandle handle,
* \param handle The node handle * \param handle The node handle
* \param out_size The number of functions * \param out_size The number of functions
* \param out_array The array of function names. * \param out_array The array of function names.
* \return 0 when success, -1 when failure happens
*/ */
TVM_DLL int TVMNodeListAttrNames(NodeHandle handle, TVM_DLL int TVMNodeListAttrNames(NodeHandle handle,
int *out_size, int *out_size,
......
...@@ -10,6 +10,8 @@ ...@@ -10,6 +10,8 @@
#include "./base.h" #include "./base.h"
#include "./expr.h" #include "./expr.h"
#include "./module.h" #include "./module.h"
#include "./runtime/runtime.h"
namespace tvm { namespace tvm {
/*! \brief namespace for lowlevel IR pass and codegen */ /*! \brief namespace for lowlevel IR pass and codegen */
...@@ -62,6 +64,9 @@ Array<Var> UndefinedVars(const LoweredFunc& f); ...@@ -62,6 +64,9 @@ Array<Var> UndefinedVars(const LoweredFunc& f);
*/ */
Array<LoweredFunc> SplitHostDevice(LoweredFunc func); Array<LoweredFunc> SplitHostDevice(LoweredFunc func);
runtime::PackedFunc BuildStackVM(LoweredFunc func);
} // namespace codegen } // namespace codegen
} // namespace tvm } // namespace tvm
......
...@@ -78,6 +78,14 @@ constexpr const char* tvm_array_get_field = "tvm_array_get_field"; ...@@ -78,6 +78,14 @@ constexpr const char* tvm_array_get_field = "tvm_array_get_field";
* } * }
*/ */
constexpr const char* tvm_handle_is_null = "tvm_handle_is_null"; constexpr const char* tvm_handle_is_null = "tvm_handle_is_null";
/*!
* \brief See pesudo code
*
* bool tvm_print(VType value) {
* LOG(INFO) << value;
* }
*/
constexpr const char* tvm_print = "tvm_print";
/*! \brief The field id of each field in array */ /*! \brief The field id of each field in array */
enum TVMArrayFieldKind { enum TVMArrayFieldKind {
......
...@@ -9,8 +9,8 @@ ...@@ -9,8 +9,8 @@
* So this is a minimum runtime code gluing, and some limited * So this is a minimum runtime code gluing, and some limited
* memory management code to enable quick testing. * memory management code to enable quick testing.
*/ */
#ifndef TVM_C_RUNTIME_API_H_ #ifndef TVM_RUNTIME_C_RUNTIME_API_H_
#define TVM_C_RUNTIME_API_H_ #define TVM_RUNTIME_C_RUNTIME_API_H_
#ifdef __cplusplus #ifdef __cplusplus
#define TVM_EXTERN_C extern "C" #define TVM_EXTERN_C extern "C"
...@@ -38,27 +38,51 @@ TVM_EXTERN_C { ...@@ -38,27 +38,51 @@ TVM_EXTERN_C {
typedef uint32_t tvm_index_t; typedef uint32_t tvm_index_t;
/*! /*!
* \brief union type for arguments and return values * \brief Union type of values
* in both runtime API and TVM API calls * being passed through API and function calls.
*/ */
typedef union { typedef union {
long v_long; // NOLINT(*) int64_t v_int64;
double v_double; double v_float64;
const char* v_str;
void* v_handle; void* v_handle;
} TVMArg; const char* v_str;
} TVMValue;
/*! /*!
* \brief The type index in TVM. * \brief The type code in TVMType
* \note TVMType is used in two places.
*/ */
typedef enum { typedef enum {
kNull = 0, kInt = 0U,
kLong = 1, kUInt = 1U,
kDouble = 2, kFloat = 2U,
kStr = 3, kHandle = 3U,
kNodeHandle = 4, // The next few fields are extension types
kArrayHandle = 5 // that is used by TVM API calls.
} TVMArgTypeID; kNull = 4U,
kNodeHandle = 5U,
kStr = 6U,
kFuncHandle = 7U
} TVMTypeCode;
/*!
* \brief The data type used in TVM Runtime.
*
* Examples
* - float: type_code = 2, bits = 32, lanes=1
* - float4(vectorized 4 float): type_code = 2, bits = 32, lanes=4
* - int8: type_code = 0, bits = 8, lanes=1
*
* \note Arguments TVM API function always takes bits=64 and lanes=1
*/
typedef struct {
/*! \brief type code, in TVMTypeCode */
uint8_t type_code;
/*! \brief number of bits of the type */
uint8_t bits;
/*! \brief number of lanes, */
uint16_t lanes;
} TVMType;
/*! /*!
* \brief The device type * \brief The device type
...@@ -82,29 +106,6 @@ typedef struct { ...@@ -82,29 +106,6 @@ typedef struct {
int dev_id; int dev_id;
} TVMContext; } TVMContext;
/*! \brief The type code in TVMDataType */
typedef enum {
kInt = 0U,
kUInt = 1U,
kFloat = 2U
} TVMTypeCode;
/*!
* \brief the data type
* Examples
* - float: type_code = 2, bits = 32, lanes=1
* - float4(vectorized 4 float): type_code = 2, bits = 32, lanes=4
* - int8: type_code = 0, bits = 8, lanes=1
*/
typedef struct {
/*! \brief type code, in TVMTypeCode */
uint8_t type_code;
/*! \brief number of bits of the type */
uint8_t bits;
/*! \brief number of lanes, */
uint16_t lanes;
} TVMDataType;
/*! /*!
* \brief Data structure representing a n-dimensional array(tensor). * \brief Data structure representing a n-dimensional array(tensor).
* This is used to pass data specification into TVM. * This is used to pass data specification into TVM.
...@@ -122,7 +123,7 @@ typedef struct { ...@@ -122,7 +123,7 @@ typedef struct {
/*! \brief number of dimensions of the array */ /*! \brief number of dimensions of the array */
tvm_index_t ndim; tvm_index_t ndim;
/*! \brief The data type flag */ /*! \brief The data type flag */
TVMDataType dtype; TVMType dtype;
/*! \brief The device context this array sits on */ /*! \brief The device context this array sits on */
TVMContext ctx; TVMContext ctx;
} TVMArray; } TVMArray;
...@@ -191,7 +192,7 @@ TVM_DLL int TVMContextEnabled(TVMContext ctx, ...@@ -191,7 +192,7 @@ TVM_DLL int TVMContextEnabled(TVMContext ctx,
*/ */
TVM_DLL int TVMArrayAlloc(const tvm_index_t* shape, TVM_DLL int TVMArrayAlloc(const tvm_index_t* shape,
tvm_index_t ndim, tvm_index_t ndim,
TVMDataType dtype, TVMType dtype,
TVMContext ctx, TVMContext ctx,
TVMArrayHandle* out); TVMArrayHandle* out);
/*! /*!
...@@ -217,45 +218,27 @@ TVM_DLL int TVMArrayCopyFromTo(TVMArrayHandle from, ...@@ -217,45 +218,27 @@ TVM_DLL int TVMArrayCopyFromTo(TVMArrayHandle from,
TVM_DLL int TVMSynchronize(TVMContext ctx, TVMStreamHandle stream); TVM_DLL int TVMSynchronize(TVMContext ctx, TVMStreamHandle stream);
/*! /*!
* \brief TVM Function API: Get resource requirement * \brief Free the function when it is no longer needed.
* * \param func The function handle
* By default TVM function try not to do internal allocations. * \return whether
* Instead, TVMFuncRequirement can be called, given the input arguments.
*
* \param func function handle to be launched.
* \param args The arguments
* \param arg_type_ids The type id of the arguments
* \param num_args Number of arguments.
* \param out_workspace_size The workspace size needed to launch this function.
* \param out_workspace_align The alignment requirement of workspace.
*
* \note The data pointer in the arrays is not used by requirement.
*/ */
TVM_DLL int TVMFuncRequirement(TVMFunctionHandle func, TVM_DLL int TVMFuncFree(TVMFunctionHandle func);
TVMArg* args,
int* arg_type_ids,
int num_args,
size_t* out_workspace_size,
size_t* out_workspace_align);
/*! /*!
* \brief TVM Function API: Launch generated function. * \brief Call a function whose parameters are all packed.
* *
* \param func function handle to be launched. * \param func node handle of the function.
* \param args The arguments * \param args The arguments
* \param arg_type_ids The type id of the arguments * \param type_codes The type codes of the arguments
* \param num_args Number of arguments. * \param num_args Number of arguments.
* \param stream The stream this function to be launched on.
* \param workspace Additional workspace used to launch this function.
* *
* \sa TVMFuncRequirement * \return 0 when success, -1 when failure happens
* \note TVM calls always exchanges with type bits=64, lanes=1
*/ */
TVM_DLL int TVMFuncLaunch(TVMFunctionHandle func, TVM_DLL int TVMFuncCall(TVMFunctionHandle func,
TVMArg* args, TVMValue* args,
int* arg_type_ids, int* type_codes,
int num_args, int num_args);
TVMStreamHandle stream,
TVMArrayHandle workspace);
} // TVM_EXTERN_C } // TVM_EXTERN_C
#endif // TVM_C_RUNTIME_API_H_ #endif // TVM_RUNTIME_C_RUNTIME_API_H_
/*!
* Copyright (c) 2016 by Contributors
* \file runtime.h
* \brief Runtime related c++ class.
*/
#ifndef TVM_RUNTIME_RUNTIME_H_
#define TVM_RUNTIME_RUNTIME_H_
#include <functional>
#include <tuple>
#include "./c_runtime_api.h"
namespace tvm {
namespace runtime {
/*!
* \brief Packed function is a runtime function
* whose argument type_codes are erased by packed format.
*
* This is an useful unified interface to call generated functions.
*/
class PackedFunc {
public:
/*! \brief The internal std::function */
using FType = std::function<void(const TVMValue* args, const int* type_codes, int num_args)>;
PackedFunc() {}
explicit PackedFunc(FType body) : body_(body) {}
/*!
* \brief invoke the packed function by directly passing in arguments.
* \param args Arguments to be passed.
* \tparam Args arguments to be passed.
* \return The first return value.
*/
template<typename... Args>
inline void operator()(Args&& ...args) const;
/*!
* \brief Call the function in packed format.
* \param args The arguments
* \param type_codes The type_codes of the arguments
* \param num_args Number of arguments.
*/
inline void CallPacked(const TVMValue* args, const int* type_codes, int num_args) const;
/*! \return the internal body function */
inline FType body() const {
return body_;
}
private:
/*! \brief internal container of packed function */
FType body_;
};
// implementations
inline void PackedFunc::CallPacked(
const TVMValue* args, const int* type_codes, int num_args) const {
body_(args, type_codes, num_args);
}
template<bool stop, std::size_t I, typename F, typename ...Args>
struct for_each_dispatcher_ {
static inline void run(const std::tuple<Args...>& args, F f) {
f(I, std::get<I>(args));
for_each_dispatcher_<(I + 1) == sizeof...(Args), (I+1), F, Args...>::run(args, f);
}
};
template<std::size_t I, typename F, typename ...Args>
struct for_each_dispatcher_<true, I, F, Args...> {
static inline void run(const std::tuple<Args...>& args, F f) {}
};
template<typename F, typename ...Args>
inline void for_each(const std::tuple<Args...>& args, F f) {
for_each_dispatcher_<sizeof...(Args) == 0, 0, F, Args...>::run(args, f);
}
namespace arg_setter {
template<typename T>
inline void Set(TVMValue& arg, int& t, T v); // NOLINT(*)
template<>
inline void Set<double>(TVMValue& arg, int& t, double value) { // NOLINT(*)
arg.v_float64 = value;
t = kFloat;
}
template<>
inline void Set<int>(TVMValue& arg, int& t, int value) { // NOLINT(*)
arg.v_int64 = value;
t = kInt;
}
template<>
inline void Set<long>(TVMValue& arg, int& t, long value) { // NOLINT(*)
arg.v_int64 = value;
t = kInt;
}
template<>
inline void Set<TVMArray*>(TVMValue& arg, int& t, TVMArray* value) { // NOLINT(*)
arg.v_handle = value;
t = kHandle;
}
template<>
inline void Set<void*>(TVMValue& arg, int& t, void* value) { // NOLINT(*)
arg.v_handle = value;
t = kHandle;
}
} // namespace arg_setter
struct PackedFuncArgSetter {
TVMValue* args;
int* type_codes;
template<typename T>
inline void operator()(size_t i, T v) const {
arg_setter::Set(args[i], type_codes[i], v);
}
};
template<typename... Args>
inline void PackedFunc::operator()(Args&& ...args) const {
auto targ = std::make_tuple(std::forward<Args>(args)...);
const int kNumArgs = sizeof...(Args);
TVMValue tvm_args[kNumArgs];
int tvm_arg_type_ids[kNumArgs];
for_each(targ, PackedFuncArgSetter{tvm_args, tvm_arg_type_ids});
body_(tvm_args, tvm_arg_type_ids, kNumArgs);
}
} // namespace runtime
} // namespace tvm
#endif // TVM_RUNTIME_RUNTIME_H_
...@@ -16,4 +16,4 @@ from . import ndarray as nd ...@@ -16,4 +16,4 @@ from . import ndarray as nd
from .ndarray import cpu, gpu, opencl, init_opencl from .ndarray import cpu, gpu, opencl, init_opencl
from ._base import TVMError from ._base import TVMError
from .function import * from .api import *
"""namespace of internal API"""
...@@ -11,24 +11,26 @@ from numbers import Number, Integral ...@@ -11,24 +11,26 @@ from numbers import Number, Integral
from .._base import _LIB from .._base import _LIB
from .._base import c_str, py_str, string_types from .._base import c_str, py_str, string_types
from .._base import check_call, ctypes2docstring from .._base import check_call, ctypes2docstring
from .. import _function_internal from .. import _api_internal
from . import _runtime_api
class TVMArg(ctypes.Union): from ._types import TVMValue, TypeCode
"""TVMArg in C API"""
_fields_ = [("v_long", ctypes.c_long),
("v_double", ctypes.c_double),
("v_str", ctypes.c_char_p),
("v_handle", ctypes.c_void_p)]
# type definitions # type definitions
APIFunctionHandle = ctypes.c_void_p APIFuncHandle = ctypes.c_void_p
NodeHandle = ctypes.c_void_p NodeHandle = ctypes.c_void_p
FunctionHandle = ctypes.c_void_p
class APIType(object):
"""TVMType used in API calls"""
INT = ctypes.c_int(TypeCode.INT)
UINT = ctypes.c_int(TypeCode.UINT)
FLOAT = ctypes.c_int(TypeCode.FLOAT)
HANDLE = ctypes.c_int(TypeCode.HANDLE)
NULL = ctypes.c_int(TypeCode.NULL)
NODE_HANDLE = ctypes.c_int(TypeCode.NODE_HANDLE)
STR = ctypes.c_int(TypeCode.STR)
FUNC_HANDLE = ctypes.c_int(TypeCode.FUNC_HANDLE)
kNull = 0
kLong = 1
kDouble = 2
kStr = 3
kNodeHandle = 4
NODE_TYPE = { NODE_TYPE = {
} }
...@@ -37,22 +39,31 @@ def _return_node(x): ...@@ -37,22 +39,31 @@ def _return_node(x):
handle = x.v_handle handle = x.v_handle
if not isinstance(handle, NodeHandle): if not isinstance(handle, NodeHandle):
handle = NodeHandle(handle) handle = NodeHandle(handle)
ret_val = TVMArg() ret_val = TVMValue()
ret_typeid = ctypes.c_int() ret_type_code = ctypes.c_int()
ret_success = ctypes.c_int() ret_success = ctypes.c_int()
check_call(_LIB.TVMNodeGetAttr( check_call(_LIB.TVMNodeGetAttr(
handle, c_str("type_key"), handle, c_str("type_key"),
ctypes.byref(ret_val), ctypes.byref(ret_val),
ctypes.byref(ret_typeid), ctypes.byref(ret_type_code),
ctypes.byref(ret_success))) ctypes.byref(ret_success)))
return NODE_TYPE.get(py_str(ret_val.v_str), NodeBase)(handle) return NODE_TYPE.get(py_str(ret_val.v_str), NodeBase)(handle)
def _return_func(x):
handle = x.v_handle
if not isinstance(handle, FunctionHandle):
handle = FunctionHandle(handle)
return _runtime_api._function_cls(handle)
RET_SWITCH = { RET_SWITCH = {
kNull: lambda x: None, TypeCode.NULL: lambda x: None,
kLong: lambda x: x.v_long, TypeCode.INT: lambda x: x.v_int64,
kDouble: lambda x: x.v_double, TypeCode.FLOAT: lambda x: x.v_float64,
kStr: lambda x: py_str(x.v_str), TypeCode.STR: lambda x: py_str(x.v_str),
kNodeHandle: _return_node TypeCode.NODE_HANDLE: _return_node,
TypeCode.FUNC_HANDLE: _return_func
} }
class SliceBase(object): class SliceBase(object):
...@@ -74,28 +85,28 @@ class NodeBase(object): ...@@ -74,28 +85,28 @@ class NodeBase(object):
self.handle = handle self.handle = handle
def __repr__(self): def __repr__(self):
return _function_internal._format_str(self) return _api_internal._format_str(self)
def __del__(self): def __del__(self):
check_call(_LIB.TVMNodeFree(self.handle)) check_call(_LIB.TVMNodeFree(self.handle))
def __getattr__(self, name): def __getattr__(self, name):
ret_val = TVMArg() ret_val = TVMValue()
ret_typeid = ctypes.c_int() ret_type_code = ctypes.c_int()
ret_success = ctypes.c_int() ret_success = ctypes.c_int()
check_call(_LIB.TVMNodeGetAttr( check_call(_LIB.TVMNodeGetAttr(
self.handle, c_str(name), self.handle, c_str(name),
ctypes.byref(ret_val), ctypes.byref(ret_val),
ctypes.byref(ret_typeid), ctypes.byref(ret_type_code),
ctypes.byref(ret_success))) ctypes.byref(ret_success)))
value = RET_SWITCH[ret_typeid.value](ret_val) value = RET_SWITCH[ret_type_code.value](ret_val)
if not ret_success.value: if not ret_success.value:
raise AttributeError( raise AttributeError(
"'%s' object has no attribute '%s'" % (str(type(self)), name)) "'%s' object has no attribute '%s'" % (str(type(self)), name))
return value return value
def __hash__(self): def __hash__(self):
return _function_internal._raw_ptr(self) return _api_internal._raw_ptr(self)
def __eq__(self, other): def __eq__(self, other):
if not isinstance(other, NodeBase): if not isinstance(other, NodeBase):
...@@ -121,7 +132,7 @@ class NodeBase(object): ...@@ -121,7 +132,7 @@ class NodeBase(object):
def __getstate__(self): def __getstate__(self):
handle = self.handle handle = self.handle
if handle is not None: if handle is not None:
return {'handle': _function_internal._save_json(self)} return {'handle': _api_internal._save_json(self)}
else: else:
return {'handle': None} return {'handle': None}
...@@ -131,7 +142,7 @@ class NodeBase(object): ...@@ -131,7 +142,7 @@ class NodeBase(object):
if handle is not None: if handle is not None:
json_str = handle json_str = handle
_push_arg(json_str) _push_arg(json_str)
other = _function_internal._load_json(json_str) other = _api_internal._load_json(json_str)
self.handle = other.handle self.handle = other.handle
other.handle = None other.handle = None
else: else:
...@@ -145,7 +156,7 @@ def const(value, dtype=None): ...@@ -145,7 +156,7 @@ def const(value, dtype=None):
dtype = 'int32' dtype = 'int32'
else: else:
dtype = 'float32' dtype = 'float32'
return _function_internal._const(value, dtype) return _api_internal._const(value, dtype)
def convert(value): def convert(value):
...@@ -154,7 +165,7 @@ def convert(value): ...@@ -154,7 +165,7 @@ def convert(value):
return const(value) return const(value)
elif isinstance(value, (list, tuple)): elif isinstance(value, (list, tuple)):
value = [convert(x) for x in value] value = [convert(x) for x in value]
return _function_internal._Array(*value) return _api_internal._Array(*value)
elif isinstance(value, dict): elif isinstance(value, dict):
vlist = [] vlist = []
for it in value.items(): for it in value.items():
...@@ -162,7 +173,7 @@ def convert(value): ...@@ -162,7 +173,7 @@ def convert(value):
raise ValueError("key of map must already been a container type") raise ValueError("key of map must already been a container type")
vlist.append(it[0]) vlist.append(it[0])
vlist.append(convert(it[1])) vlist.append(convert(it[1]))
return _function_internal._Map(*vlist) return _api_internal._Map(*vlist)
elif isinstance(value, SliceBase): elif isinstance(value, SliceBase):
return value.tensor(*value.indices) return value.tensor(*value.indices)
else: else:
...@@ -172,21 +183,21 @@ def convert(value): ...@@ -172,21 +183,21 @@ def convert(value):
def _push_arg(arg): def _push_arg(arg):
a = TVMArg() a = TVMValue()
if arg is None: if arg is None:
_LIB.TVMAPIPushStack(a, ctypes.c_int(kNull)) _LIB.TVMAPIPushStack(a, APIType.NULL)
elif isinstance(arg, NodeBase): elif isinstance(arg, NodeBase):
a.v_handle = arg.handle a.v_handle = arg.handle
_LIB.TVMAPIPushStack(a, ctypes.c_int(kNodeHandle)) _LIB.TVMAPIPushStack(a, APIType.NODE_HANDLE)
elif isinstance(arg, int): elif isinstance(arg, Integral):
a.v_long = ctypes.c_long(arg) a.v_int64 = ctypes.c_int64(arg)
_LIB.TVMAPIPushStack(a, ctypes.c_int(kLong)) _LIB.TVMAPIPushStack(a, APIType.INT)
elif isinstance(arg, Number): elif isinstance(arg, Number):
a.v_double = ctypes.c_double(arg) a.v_double = ctypes.c_double(arg)
_LIB.TVMAPIPushStack(a, ctypes.c_int(kDouble)) _LIB.TVMAPIPushStack(a, APIType.FLOAT)
elif isinstance(arg, string_types): elif isinstance(arg, string_types):
a.v_str = c_str(arg) a.v_str = c_str(arg)
_LIB.TVMAPIPushStack(a, ctypes.c_int(kStr)) _LIB.TVMAPIPushStack(a, APIType.STR)
else: else:
raise TypeError("Don't know how to handle type %s" % type(arg)) raise TypeError("Don't know how to handle type %s" % type(arg))
...@@ -201,7 +212,7 @@ def _make_function(handle, name): ...@@ -201,7 +212,7 @@ def _make_function(handle, name):
arg_descs = ctypes.POINTER(ctypes.c_char_p)() arg_descs = ctypes.POINTER(ctypes.c_char_p)()
ret_type = ctypes.c_char_p() ret_type = ctypes.c_char_p()
check_call(_LIB.TVMGetAPIFunctionInfo( check_call(_LIB.TVMGetAPIFuncInfo(
handle, ctypes.byref(real_name), ctypes.byref(desc), handle, ctypes.byref(real_name), ctypes.byref(desc),
ctypes.byref(num_args), ctypes.byref(num_args),
ctypes.byref(arg_names), ctypes.byref(arg_names),
...@@ -214,13 +225,7 @@ def _make_function(handle, name): ...@@ -214,13 +225,7 @@ def _make_function(handle, name):
desc = py_str(desc.value) desc = py_str(desc.value)
doc_str = ('%s\n\n' + doc_str = ('%s\n\n' +
'%s\n' + '%s\n')
'name : string, optional.\n' +
' Name of the resulting symbol.\n\n' +
'Returns\n' +
'-------\n' +
'symbol: Symbol\n' +
' The result symbol.')
doc_str = doc_str % (desc, param_str) doc_str = doc_str % (desc, param_str)
arg_names = [py_str(arg_names[i]) for i in range(num_args.value)] arg_names = [py_str(arg_names[i]) for i in range(num_args.value)]
...@@ -235,11 +240,11 @@ def _make_function(handle, name): ...@@ -235,11 +240,11 @@ def _make_function(handle, name):
for arg in cargs: for arg in cargs:
_push_arg(arg) _push_arg(arg)
ret_val = TVMArg() ret_val = TVMValue()
ret_typeid = ctypes.c_int() ret_type_code = ctypes.c_int()
check_call(_LIB.TVMAPIFunctionCall( check_call(_LIB.TVMAPIFuncCall(
handle, ctypes.byref(ret_val), ctypes.byref(ret_typeid))) handle, ctypes.byref(ret_val), ctypes.byref(ret_type_code)))
return RET_SWITCH[ret_typeid.value](ret_val) return RET_SWITCH[ret_type_code.value](ret_val)
func.__name__ = func_name func.__name__ = func_name
func.__doc__ = doc_str func.__doc__ = doc_str
...@@ -265,19 +270,19 @@ def register_node(type_key=None): ...@@ -265,19 +270,19 @@ def register_node(type_key=None):
NODE_TYPE[cls.__name__] = cls NODE_TYPE[cls.__name__] = cls
return cls return cls
def _init_function_module(root_namespace): def _init_api_module(root_namespace):
"""List and add all the functions to current module.""" """List and add all the functions to current module."""
plist = ctypes.POINTER(ctypes.c_char_p)() plist = ctypes.POINTER(ctypes.c_char_p)()
size = ctypes.c_uint() size = ctypes.c_uint()
check_call(_LIB.TVMListAPIFunctionNames(ctypes.byref(size), check_call(_LIB.TVMListAPIFuncNames(ctypes.byref(size),
ctypes.byref(plist))) ctypes.byref(plist)))
op_names = [] op_names = []
for i in range(size.value): for i in range(size.value):
op_names.append(py_str(plist[i])) op_names.append(py_str(plist[i]))
module_obj = sys.modules["%s.function" % root_namespace] module_obj = sys.modules["%s.api" % root_namespace]
module_internal = sys.modules["%s._function_internal" % root_namespace] module_internal = sys.modules["%s._api_internal" % root_namespace]
namespace_match = { namespace_match = {
"_make_": sys.modules["%s.make" % root_namespace], "_make_": sys.modules["%s.make" % root_namespace],
"_pass_": sys.modules["%s.ir_pass" % root_namespace], "_pass_": sys.modules["%s.ir_pass" % root_namespace],
...@@ -286,8 +291,8 @@ def _init_function_module(root_namespace): ...@@ -286,8 +291,8 @@ def _init_function_module(root_namespace):
} }
for name in op_names: for name in op_names:
hdl = APIFunctionHandle() hdl = APIFuncHandle()
check_call(_LIB.TVMGetAPIFunctionHandle(c_str(name), ctypes.byref(hdl))) check_call(_LIB.TVMGetAPIFuncHandle(c_str(name), ctypes.byref(hdl)))
fname = name fname = name
target_module = module_internal if name.startswith('_') else module_obj target_module = module_internal if name.startswith('_') else module_obj
for k, v in namespace_match.items(): for k, v in namespace_match.items():
......
...@@ -4,16 +4,16 @@ ...@@ -4,16 +4,16 @@
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
import ctypes import ctypes
from numbers import Number, Integral
import numpy as np import numpy as np
from .._base import _LIB from .._base import _LIB
from .._base import c_array, c_str from .._base import c_array, c_str, string_types
from .._base import check_call from .._base import check_call
from ._types import TVMValue, TypeCode, TVMType
tvm_index_t = ctypes.c_uint32 tvm_index_t = ctypes.c_uint32
class TVMContext(ctypes.Structure): class TVMContext(ctypes.Structure):
"""TVM context strucure.""" """TVM context strucure."""
_fields_ = [("dev_mask", ctypes.c_int), _fields_ = [("dev_mask", ctypes.c_int),
...@@ -72,52 +72,13 @@ def opencl(dev_id=0): ...@@ -72,52 +72,13 @@ def opencl(dev_id=0):
return TVMContext(4, dev_id) return TVMContext(4, dev_id)
class TVMDataType(ctypes.Structure):
"""TVM datatype structure"""
_fields_ = [("type_code", ctypes.c_uint8),
("bits", ctypes.c_uint8),
("lanes", ctypes.c_uint16)]
CODE2STR = {
0 : 'int',
1 : 'uint',
2 : 'float'
}
def __init__(self, type_str, lanes=1):
super(TVMDataType, self).__init__()
if isinstance(type_str, np.dtype):
type_str = str(type_str)
if type_str.startswith("int"):
self.type_code = 0
bits = int(type_str[3:])
elif type_str.startswith("uint"):
self.type_code = 1
bits = int(type_str[4:])
elif type_str.startswith("float"):
self.type_code = 2
bits = int(type_str[5:])
else:
raise ValueError("Donot know how to handle type %s" % type_str)
bits = 32 if bits == 0 else bits
if (bits & (bits - 1)) != 0 or bits < 8:
raise ValueError("Donot know how to handle type %s" % type_str)
self.bits = bits
self.lanes = lanes
def __repr__(self):
x = "%s%d" % (TVMDataType.CODE2STR[self.type_code], self.bits)
if self.lanes != 1:
x += "x%d" % self.lanes
return x
class TVMArray(ctypes.Structure): class TVMArray(ctypes.Structure):
"""TVMArg in C API""" """TVMValue in C API"""
_fields_ = [("data", ctypes.c_void_p), _fields_ = [("data", ctypes.c_void_p),
("shape", ctypes.POINTER(tvm_index_t)), ("shape", ctypes.POINTER(tvm_index_t)),
("strides", ctypes.POINTER(tvm_index_t)), ("strides", ctypes.POINTER(tvm_index_t)),
("ndim", tvm_index_t), ("ndim", tvm_index_t),
("dtype", TVMDataType), ("dtype", TVMType),
("ctx", TVMContext)] ("ctx", TVMContext)]
TVMArrayHandle = ctypes.POINTER(TVMArray) TVMArrayHandle = ctypes.POINTER(TVMArray)
...@@ -133,7 +94,7 @@ def numpyasarray(np_data): ...@@ -133,7 +94,7 @@ def numpyasarray(np_data):
arr.data = data.ctypes.data_as(ctypes.c_void_p) arr.data = data.ctypes.data_as(ctypes.c_void_p)
arr.shape = shape arr.shape = shape
arr.strides = None arr.strides = None
arr.dtype = TVMDataType(np.dtype(data.dtype).name) arr.dtype = TVMType(np.dtype(data.dtype).name)
arr.ndim = data.ndim arr.ndim = data.ndim
# CPU device # CPU device
arr.ctx = cpu(0) arr.ctx = cpu(0)
...@@ -141,6 +102,7 @@ def numpyasarray(np_data): ...@@ -141,6 +102,7 @@ def numpyasarray(np_data):
_ndarray_cls = None _ndarray_cls = None
_function_cls = None
def empty(shape, dtype="float32", ctx=cpu(0)): def empty(shape, dtype="float32", ctx=cpu(0)):
...@@ -165,7 +127,7 @@ def empty(shape, dtype="float32", ctx=cpu(0)): ...@@ -165,7 +127,7 @@ def empty(shape, dtype="float32", ctx=cpu(0)):
shape = c_array(tvm_index_t, shape) shape = c_array(tvm_index_t, shape)
ndim = tvm_index_t(len(shape)) ndim = tvm_index_t(len(shape))
handle = TVMArrayHandle() handle = TVMArrayHandle()
dtype = TVMDataType(dtype) dtype = TVMType(dtype)
check_call(_LIB.TVMArrayAlloc( check_call(_LIB.TVMArrayAlloc(
shape, ndim, dtype, ctx, ctypes.byref(handle))) shape, ndim, dtype, ctx, ctypes.byref(handle)))
return _ndarray_cls(handle) return _ndarray_cls(handle)
...@@ -313,6 +275,51 @@ class NDArrayBase(object): ...@@ -313,6 +275,51 @@ class NDArrayBase(object):
return target return target
def _init_runtime_module(ndarray_class): class FunctionBase(object):
"""A function object at runtim."""
__slots__ = ["handle"]
# pylint: disable=no-member
def __init__(self, handle):
"""Initialize the function with handle
Parameters
----------
handle : FunctionHandle
the handle to the underlying function.
"""
self.handle = handle
def __del__(self):
check_call(_LIB.TVMFuncFree(self.handle))
def __call__(self, *args):
num_args = len(args)
tvm_args = (TVMValue * num_args)()
tvm_type_code = (ctypes.c_int * num_args)()
for i, arg in enumerate(args):
if arg is None:
tvm_args[i].v_handle = None
tvm_type_code[i] = TypeCode.NULL
elif isinstance(arg, NDArrayBase):
tvm_args[i].v_handle = ctypes.cast(arg.handle, ctypes.c_void_p)
tvm_type_code[i] = TypeCode.HANDLE
elif isinstance(arg, Integral):
tvm_args[i].v_int64 = arg
tvm_type_code[i] = TypeCode.INT
elif isinstance(arg, Number):
tvm_args[i].v_float64 = arg
tvm_type_code[i] = TypeCode.FLOAT
elif isinstance(arg, string_types):
tvm_args[i].v_str = c_str(arg)
tvm_type_code[i] = TypeCode.STR
else:
raise TypeError("Don't know how to handle type %s" % type(arg))
check_call(_LIB.TVMFuncCall(
self.handle, tvm_args, tvm_type_code, ctypes.c_int(num_args)))
def _init_runtime_module(ndarray_class, function_class):
global _ndarray_cls global _ndarray_cls
global _function_cls
_ndarray_cls = ndarray_class _ndarray_cls = ndarray_class
_function_cls = function_class
"""The C Types used in API."""
# pylint: disable=invalid-name
from __future__ import absolute_import as _abs
import ctypes
import numpy as np
class TVMValue(ctypes.Union):
"""TVMValue in C API"""
_fields_ = [("v_int64", ctypes.c_int64),
("v_float64", ctypes.c_double),
("v_handle", ctypes.c_void_p),
("v_str", ctypes.c_char_p)]
class TypeCode(object):
"""Type code used in API calls"""
INT = 0
UINT = 1
FLOAT = 2
HANDLE = 3
NULL = 4
NODE_HANDLE = 5
STR = 6
FUNC_HANDLE = 7
def _api_type(code):
"""create a type accepted by API"""
t = TVMType()
t.bits = 64
t.lanes = 1
t.type_code = code
return t
class TVMType(ctypes.Structure):
"""TVM datatype structure"""
_fields_ = [("type_code", ctypes.c_uint8),
("bits", ctypes.c_uint8),
("lanes", ctypes.c_uint16)]
CODE2STR = {
0 : 'int',
1 : 'uint',
2 : 'float'
}
def __init__(self, type_str, lanes=1):
super(TVMType, self).__init__()
if isinstance(type_str, np.dtype):
type_str = str(type_str)
if type_str.startswith("int"):
self.type_code = 0
bits = int(type_str[3:])
elif type_str.startswith("uint"):
self.type_code = 1
bits = int(type_str[4:])
elif type_str.startswith("float"):
self.type_code = 2
bits = int(type_str[5:])
else:
raise ValueError("Donot know how to handle type %s" % type_str)
bits = 32 if bits == 0 else bits
if (bits & (bits - 1)) != 0 or bits < 8:
raise ValueError("Donot know how to handle type %s" % type_str)
self.bits = bits
self.lanes = lanes
def __repr__(self):
x = "%s%d" % (TVMType.CODE2STR[self.type_code], self.bits)
if self.lanes != 1:
x += "x%d" % self.lanes
return x
"""namespace of internal function"""
...@@ -3,8 +3,8 @@ ...@@ -3,8 +3,8 @@
"""Functions defined in TVM.""" """Functions defined in TVM."""
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
from numbers import Integral as _Integral from numbers import Integral as _Integral
from ._ctypes._api import _init_function_module, convert from ._ctypes._api import _init_api_module, convert
from . import _function_internal from . import _api_internal
from . import make as _make from . import make as _make
from . import expr as _expr from . import expr as _expr
from . import collections as _collections from . import collections as _collections
...@@ -20,7 +20,7 @@ def const(value, dtype=None): ...@@ -20,7 +20,7 @@ def const(value, dtype=None):
dtype = 'int32' dtype = 'int32'
else: else:
dtype = 'float32' dtype = 'float32'
return _function_internal._const(value, dtype) return _api_internal._const(value, dtype)
def load_json(json_str): def load_json(json_str):
...@@ -36,7 +36,7 @@ def load_json(json_str): ...@@ -36,7 +36,7 @@ def load_json(json_str):
node : Node node : Node
The loaded tvm node. The loaded tvm node.
""" """
return _function_internal._load_json(json_str) return _api_internal._load_json(json_str)
def save_json(node): def save_json(node):
...@@ -52,7 +52,7 @@ def save_json(node): ...@@ -52,7 +52,7 @@ def save_json(node):
json_str : str json_str : str
Saved json string. Saved json string.
""" """
return _function_internal._save_json(node) return _api_internal._save_json(node)
def Var(name="tindex", dtype=int32): def Var(name="tindex", dtype=int32):
...@@ -66,7 +66,7 @@ def Var(name="tindex", dtype=int32): ...@@ -66,7 +66,7 @@ def Var(name="tindex", dtype=int32):
dtype : int dtype : int
The data type The data type
""" """
return _function_internal._Var(name, dtype) return _api_internal._Var(name, dtype)
def placeholder(shape, dtype=None, name="placeholder"): def placeholder(shape, dtype=None, name="placeholder"):
...@@ -90,7 +90,7 @@ def placeholder(shape, dtype=None, name="placeholder"): ...@@ -90,7 +90,7 @@ def placeholder(shape, dtype=None, name="placeholder"):
""" """
shape = (shape,) if isinstance(shape, _expr.Expr) else shape shape = (shape,) if isinstance(shape, _expr.Expr) else shape
dtype = float32 if dtype is None else dtype dtype = float32 if dtype is None else dtype
return _function_internal._Placeholder( return _api_internal._Placeholder(
shape, dtype, name) shape, dtype, name)
...@@ -128,9 +128,9 @@ def compute(shape, fcompute, name="compute"): ...@@ -128,9 +128,9 @@ def compute(shape, fcompute, name="compute"):
dim_var = [IterVar((0, s), x) for x, s in zip(arg_names, shape)] dim_var = [IterVar((0, s), x) for x, s in zip(arg_names, shape)]
body = fcompute(*[v.var for v in dim_var]) body = fcompute(*[v.var for v in dim_var])
body = convert(body) body = convert(body)
op_node = _function_internal._ComputeOp( op_node = _api_internal._ComputeOp(
name, dim_var, body) name, dim_var, body)
return _function_internal._Tensor( return _api_internal._Tensor(
shape, body.dtype, op_node, 0) shape, body.dtype, op_node, 0)
...@@ -168,7 +168,7 @@ def Buffer(shape, dtype=None, ...@@ -168,7 +168,7 @@ def Buffer(shape, dtype=None,
if ptr is None: if ptr is None:
ptr = Var(name, "handle") ptr = Var(name, "handle")
return _function_internal._Buffer( return _api_internal._Buffer(
name, ptr, shape, strides, dtype) name, ptr, shape, strides, dtype)
...@@ -202,7 +202,7 @@ def IterVar(dom=None, name=None, thread_tag=''): ...@@ -202,7 +202,7 @@ def IterVar(dom=None, name=None, thread_tag=''):
if name is None: if name is None:
name = thread_tag if thread_tag else name name = thread_tag if thread_tag else name
name = name if name else 'iter' name = name if name else 'iter'
return _function_internal._IterVar(dom, name, thread_tag) return _api_internal._IterVar(dom, name, thread_tag)
def sum(expr, rdom): def sum(expr, rdom):
...@@ -263,7 +263,7 @@ def Schedule(ops): ...@@ -263,7 +263,7 @@ def Schedule(ops):
""" """
if not isinstance(ops, (list, _collections.Array)): if not isinstance(ops, (list, _collections.Array)):
ops = [ops] ops = [ops]
return _function_internal._Schedule(ops) return _api_internal._Schedule(ops)
_init_function_module("tvm") _init_api_module("tvm")
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
"""Collection structure in the high level DSL.""" """Collection structure in the high level DSL."""
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
from ._ctypes._api import NodeBase, register_node from ._ctypes._api import NodeBase, register_node
from . import _function_internal from . import _api_internal
from . import expr as _expr from . import expr as _expr
@register_node @register_node
...@@ -11,10 +11,10 @@ class Array(NodeBase): ...@@ -11,10 +11,10 @@ class Array(NodeBase):
def __getitem__(self, i): def __getitem__(self, i):
if i >= len(self): if i >= len(self):
raise IndexError("array index out ot range") raise IndexError("array index out ot range")
return _function_internal._ArrayGetItem(self, i) return _api_internal._ArrayGetItem(self, i)
def __len__(self): def __len__(self):
return _function_internal._ArraySize(self) return _api_internal._ArraySize(self)
def __repr__(self): def __repr__(self):
return '[' + (','.join(str(x) for x in self)) + ']' return '[' + (','.join(str(x) for x in self)) + ']'
...@@ -23,18 +23,18 @@ class Array(NodeBase): ...@@ -23,18 +23,18 @@ class Array(NodeBase):
class Map(NodeBase): class Map(NodeBase):
"""Map container of TVM""" """Map container of TVM"""
def __getitem__(self, k): def __getitem__(self, k):
return _function_internal._MapGetItem(self, k) return _api_internal._MapGetItem(self, k)
def __contains__(self, k): def __contains__(self, k):
return _function_internal._MapCount(self, k) != 0 return _api_internal._MapCount(self, k) != 0
def items(self): def items(self):
"""Get the items from the map""" """Get the items from the map"""
akvs = _function_internal._MapItems(self) akvs = _api_internal._MapItems(self)
return [(akvs[i], akvs[i+1]) for i in range(0, len(akvs), 2)] return [(akvs[i], akvs[i+1]) for i in range(0, len(akvs), 2)]
def __len__(self): def __len__(self):
return _function_internal._MapSize(self) return _api_internal._MapSize(self)
def __repr__(self): def __repr__(self):
return '{' + (", ".join(str(x[0]) + ": " +str(x[1]) for x in self.items())) + '}' return '{' + (", ".join(str(x[0]) + ": " +str(x[1]) for x in self.items())) + '}'
......
...@@ -6,7 +6,7 @@ This is a simplified runtime API for quick testing and proptyping. ...@@ -6,7 +6,7 @@ This is a simplified runtime API for quick testing and proptyping.
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
import numpy as _np import numpy as _np
from ._ctypes._runtime_api import TVMContext, TVMDataType, NDArrayBase from ._ctypes._runtime_api import TVMContext, TVMType, NDArrayBase, FunctionBase
from ._ctypes._runtime_api import cpu, gpu, opencl, empty, sync from ._ctypes._runtime_api import cpu, gpu, opencl, empty, sync
from ._ctypes._runtime_api import _init_runtime_module from ._ctypes._runtime_api import _init_runtime_module
from ._ctypes._runtime_api import init_opencl from ._ctypes._runtime_api import init_opencl
...@@ -26,6 +26,11 @@ class NDArray(NDArrayBase): ...@@ -26,6 +26,11 @@ class NDArray(NDArrayBase):
pass pass
class Function(FunctionBase):
"""Function class that can executed a generated code."""
pass
def array(arr, ctx=cpu(0)): def array(arr, ctx=cpu(0)):
"""Create an array from source arr. """Create an array from source arr.
...@@ -49,4 +54,4 @@ def array(arr, ctx=cpu(0)): ...@@ -49,4 +54,4 @@ def array(arr, ctx=cpu(0)):
return ret return ret
_init_runtime_module(NDArray) _init_runtime_module(NDArray, Function)
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
"""Collection structure in the high level DSL.""" """Collection structure in the high level DSL."""
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
from ._ctypes._api import NodeBase, register_node from ._ctypes._api import NodeBase, register_node
from . import _function_internal from . import _api_internal
from . import tensor as _tensor from . import tensor as _tensor
@register_node @register_node
...@@ -56,11 +56,11 @@ class Stage(NodeBase): ...@@ -56,11 +56,11 @@ class Stage(NodeBase):
if outer is not None: if outer is not None:
if outer.thread_tag == '': if outer.thread_tag == '':
raise ValueError("split by outer must have special thread_tag") raise ValueError("split by outer must have special thread_tag")
inner = _function_internal._StageSplitByOuter(self, parent, outer, factor) inner = _api_internal._StageSplitByOuter(self, parent, outer, factor)
else: else:
if factor is None: if factor is None:
raise ValueError("either outer or factor need to be provided") raise ValueError("either outer or factor need to be provided")
outer, inner = _function_internal._StageSplitByFactor(self, parent, factor) outer, inner = _api_internal._StageSplitByFactor(self, parent, factor)
return outer, inner return outer, inner
def fuse(self, inner, outer): def fuse(self, inner, outer):
...@@ -79,7 +79,7 @@ class Stage(NodeBase): ...@@ -79,7 +79,7 @@ class Stage(NodeBase):
inner : IterVar inner : IterVar
The fused variable of iteration. The fused variable of iteration.
""" """
return _function_internal._StageFuse(self, inner, outer) return _api_internal._StageFuse(self, inner, outer)
def set_scope(self, scope): def set_scope(self, scope):
"""Set the thread scope of this stage """Set the thread scope of this stage
...@@ -89,7 +89,7 @@ class Stage(NodeBase): ...@@ -89,7 +89,7 @@ class Stage(NodeBase):
scope : str scope : str
The thread scope of this stage The thread scope of this stage
""" """
return _function_internal._StageSetScope(self, scope) return _api_internal._StageSetScope(self, scope)
def compute_at(self, parent, scope): def compute_at(self, parent, scope):
"""Attach the stage at parent's scope """Attach the stage at parent's scope
...@@ -102,7 +102,7 @@ class Stage(NodeBase): ...@@ -102,7 +102,7 @@ class Stage(NodeBase):
scope : IterVar scope : IterVar
The loop scope t be attached to. The loop scope t be attached to.
""" """
_function_internal._StageComputeAt(self, parent, scope) _api_internal._StageComputeAt(self, parent, scope)
def compute_inline(self): def compute_inline(self):
"""Mark stage as inline """Mark stage as inline
...@@ -112,7 +112,7 @@ class Stage(NodeBase): ...@@ -112,7 +112,7 @@ class Stage(NodeBase):
parent : Stage parent : Stage
The parent stage The parent stage
""" """
_function_internal._StageComputeInline(self) _api_internal._StageComputeInline(self)
def compute_root(self): def compute_root(self):
"""Attach the stage at parent, and mark it as root """Attach the stage at parent, and mark it as root
...@@ -122,7 +122,7 @@ class Stage(NodeBase): ...@@ -122,7 +122,7 @@ class Stage(NodeBase):
parent : Stage parent : Stage
The parent stage The parent stage
""" """
_function_internal._StageComputeInline(self) _api_internal._StageComputeInline(self)
def reorder(self, *args): def reorder(self, *args):
"""reorder the arguments in the specified order. """reorder the arguments in the specified order.
...@@ -132,7 +132,7 @@ class Stage(NodeBase): ...@@ -132,7 +132,7 @@ class Stage(NodeBase):
args : list of IterVar args : list of IterVar
The order to be ordered The order to be ordered
""" """
_function_internal._StageReorder(self, args) _api_internal._StageReorder(self, args)
def tile(self, x_parent, y_parent, x_factor, y_factor): def tile(self, x_parent, y_parent, x_factor, y_factor):
""" Perform tiling on two dimensions """ Perform tiling on two dimensions
...@@ -161,6 +161,6 @@ class Stage(NodeBase): ...@@ -161,6 +161,6 @@ class Stage(NodeBase):
p_y_inner : IterVar p_y_inner : IterVar
Inner axis of y dimension Inner axis of y dimension
""" """
x_outer, y_outer, x_inner, y_inner = _function_internal._StageTile( x_outer, y_outer, x_inner, y_inner = _api_internal._StageTile(
self, x_parent, y_parent, x_factor, y_factor) self, x_parent, y_parent, x_factor, y_factor)
return x_outer, y_outer, x_inner, y_inner return x_outer, y_outer, x_inner, y_inner
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
from ._ctypes._api import NodeBase, SliceBase, register_node, convert from ._ctypes._api import NodeBase, SliceBase, register_node, convert
from . import collections as _collections from . import collections as _collections
from . import _function_internal from . import _api_internal
from . import make as _make from . import make as _make
from . import expr as _expr from . import expr as _expr
...@@ -44,12 +44,12 @@ class Tensor(NodeBase): ...@@ -44,12 +44,12 @@ class Tensor(NodeBase):
return TensorSlice(self, indices) return TensorSlice(self, indices)
def __hash__(self): def __hash__(self):
return _function_internal._TensorHash(self) return _api_internal._TensorHash(self)
def __eq__(self, other): def __eq__(self, other):
if not isinstance(other, Tensor): if not isinstance(other, Tensor):
return False return False
return _function_internal._TensorEqual(self, other) return _api_internal._TensorEqual(self, other)
@property @property
def ndim(self): def ndim(self):
...@@ -72,7 +72,7 @@ class Operation(NodeBase): ...@@ -72,7 +72,7 @@ class Operation(NodeBase):
out : Tensor out : Tensor
The i-th output. The i-th output.
""" """
return _function_internal._OpGetOutput(self, index) return _api_internal._OpGetOutput(self, index)
@register_node @register_node
class ComputeOp(Operation): class ComputeOp(Operation):
......
...@@ -4,4 +4,5 @@ ...@@ -4,4 +4,5 @@
- lang The definition of DSL related data structure - lang The definition of DSL related data structure
- schedule The operations on the schedule graph before converting to IR. - schedule The operations on the schedule graph before converting to IR.
- pass The optimization pass on the IR structure - pass The optimization pass on the IR structure
- runtime The runtime related codes. - runtime Minimum runtime related codes.
\ No newline at end of file - jit JIT runtime related code.
...@@ -22,7 +22,7 @@ struct TVMAPIThreadLocalEntry { ...@@ -22,7 +22,7 @@ struct TVMAPIThreadLocalEntry {
arg_stack.clear(); arg_stack.clear();
ret_value.sptr.reset(); ret_value.sptr.reset();
} }
inline void SetReturn(ArgVariant* ret_val, int* ret_typeid); inline void SetReturn(TVMValue* ret_val, int* ret_type_code);
}; };
using namespace tvm; using namespace tvm;
...@@ -97,11 +97,11 @@ struct APIAttrDir : public AttrVisitor { ...@@ -97,11 +97,11 @@ struct APIAttrDir : public AttrVisitor {
} }
}; };
int TVMListAPIFunctionNames(int *out_size, int TVMListAPIFuncNames(int *out_size,
const char*** out_array) { const char*** out_array) {
API_BEGIN(); API_BEGIN();
TVMAPIThreadLocalEntry *ret = TVMAPIThreadLocalStore::Get(); TVMAPIThreadLocalEntry *ret = TVMAPIThreadLocalStore::Get();
ret->ret_vec_str = dmlc::Registry<APIFunctionReg>::ListAllNames(); ret->ret_vec_str = dmlc::Registry<APIFuncReg>::ListAllNames();
ret->ret_vec_charp.clear(); ret->ret_vec_charp.clear();
for (size_t i = 0; i < ret->ret_vec_str.size(); ++i) { for (size_t i = 0; i < ret->ret_vec_str.size(); ++i) {
ret->ret_vec_charp.push_back(ret->ret_vec_str[i].c_str()); ret->ret_vec_charp.push_back(ret->ret_vec_str[i].c_str());
...@@ -111,16 +111,16 @@ int TVMListAPIFunctionNames(int *out_size, ...@@ -111,16 +111,16 @@ int TVMListAPIFunctionNames(int *out_size,
API_END(); API_END();
} }
int TVMGetAPIFunctionHandle(const char* fname, int TVMGetAPIFuncHandle(const char* fname,
APIFunctionHandle* out) { APIFuncHandle* out) {
API_BEGIN(); API_BEGIN();
const APIFunctionReg* reg = dmlc::Registry<APIFunctionReg>::Find(fname); const APIFuncReg* reg = dmlc::Registry<APIFuncReg>::Find(fname);
CHECK(reg != nullptr) << "cannot find function " << fname; CHECK(reg != nullptr) << "cannot find function " << fname;
*out = (APIFunctionHandle)reg; *out = (APIFuncHandle)reg;
API_END(); API_END();
} }
int TVMGetAPIFunctionInfo(APIFunctionHandle handle, int TVMGetAPIFuncInfo(APIFuncHandle handle,
const char **real_name, const char **real_name,
const char **description, const char **description,
int *num_doc_args, int *num_doc_args,
...@@ -128,7 +128,7 @@ int TVMGetAPIFunctionInfo(APIFunctionHandle handle, ...@@ -128,7 +128,7 @@ int TVMGetAPIFunctionInfo(APIFunctionHandle handle,
const char ***arg_type_infos, const char ***arg_type_infos,
const char ***arg_descriptions, const char ***arg_descriptions,
const char **return_type) { const char **return_type) {
const auto *op = static_cast<const APIFunctionReg *>(handle); const auto *op = static_cast<const APIFuncReg *>(handle);
TVMAPIThreadLocalEntry *ret = TVMAPIThreadLocalStore::Get(); TVMAPIThreadLocalEntry *ret = TVMAPIThreadLocalStore::Get();
API_BEGIN(); API_BEGIN();
...@@ -152,33 +152,37 @@ int TVMGetAPIFunctionInfo(APIFunctionHandle handle, ...@@ -152,33 +152,37 @@ int TVMGetAPIFunctionInfo(APIFunctionHandle handle,
API_END(); API_END();
} }
int TVMAPIPushStack(ArgVariant arg, int TVMAPIPushStack(TVMValue arg,
int type_id) { int type_code) {
TVMAPIThreadLocalEntry *ret = TVMAPIThreadLocalStore::Get(); TVMAPIThreadLocalEntry *ret = TVMAPIThreadLocalStore::Get();
API_BEGIN(); API_BEGIN();
ret->arg_stack.resize(ret->arg_stack.size() + 1); ret->arg_stack.resize(ret->arg_stack.size() + 1);
APIVariantValue& v = ret->arg_stack.back(); APIVariantValue& v = ret->arg_stack.back();
v.type_id = static_cast<ArgVariantID>(type_id); v.type_code = type_code;
if (type_id == kStr) { switch (type_code) {
v.str = arg.v_str; case kInt: case kUInt: case kFloat: case kNull: {
} else if (type_id == kNodeHandle) { v.v_union = arg; break;
v.sptr = *static_cast<TVMAPINode*>(arg.v_handle); }
} else { case kStr: {
v.v_union = arg; v.str = arg.v_str; break;
}
case kNodeHandle: {
v.sptr = *static_cast<TVMAPINode*>(arg.v_handle); break;
}
default: LOG(FATAL) << "TVM API cannot take type " << TVMTypeCode2Str(type_code);
} }
API_END_HANDLE_ERROR(ret->Clear()); API_END_HANDLE_ERROR(ret->Clear());
} }
int TVMAPIFunctionCall(APIFunctionHandle handle, int TVMAPIFuncCall(APIFuncHandle handle,
ArgVariant* ret_val, TVMValue* ret_val,
int* ret_typeid) { int* ret_type_code) {
TVMAPIThreadLocalEntry *ret = TVMAPIThreadLocalStore::Get(); TVMAPIThreadLocalEntry *ret = TVMAPIThreadLocalStore::Get();
API_BEGIN(); API_BEGIN();
const auto *op = static_cast<const APIFunctionReg *>(handle); const auto *op = static_cast<const APIFuncReg *>(handle);
op->body(ret->arg_stack, &(ret->ret_value)); op->body(ret->arg_stack, &(ret->ret_value));
ret->SetReturn(ret_val, ret_typeid); ret->SetReturn(ret_val, ret_type_code);
ret->arg_stack.clear(); ret->arg_stack.clear();
API_END_HANDLE_ERROR(ret->Clear()); API_END_HANDLE_ERROR(ret->Clear());
} }
...@@ -191,28 +195,28 @@ int TVMNodeFree(NodeHandle handle) { ...@@ -191,28 +195,28 @@ int TVMNodeFree(NodeHandle handle) {
int TVMNodeGetAttr(NodeHandle handle, int TVMNodeGetAttr(NodeHandle handle,
const char* key, const char* key,
ArgVariant* ret_val, TVMValue* ret_val,
int* ret_typeid, int* ret_type_code,
int* ret_success) { int* ret_success) {
TVMAPIThreadLocalEntry *ret = TVMAPIThreadLocalStore::Get(); TVMAPIThreadLocalEntry *ret = TVMAPIThreadLocalStore::Get();
API_BEGIN(); API_BEGIN();
ret->ret_value.type_id = kNull; ret->ret_value.type_code = kNull;
APIAttrGetter getter; APIAttrGetter getter;
getter.skey = key; getter.skey = key;
getter.ret = &(ret->ret_value); getter.ret = &(ret->ret_value);
TVMAPINode* tnode = static_cast<TVMAPINode*>(handle); TVMAPINode* tnode = static_cast<TVMAPINode*>(handle);
if (getter.skey == "type_key") { if (getter.skey == "type_key") {
ret_val->v_str = (*tnode)->type_key(); ret_val->v_str = (*tnode)->type_key();
*ret_typeid = kStr; *ret_type_code = kStr;
*ret_success = 1; *ret_success = 1;
} else { } else {
(*tnode)->VisitAttrs(&getter); (*tnode)->VisitAttrs(&getter);
if (ret->ret_value.type_id != kNull) { if (ret->ret_value.type_code != kNull) {
ret->SetReturn(ret_val, ret_typeid); ret->SetReturn(ret_val, ret_type_code);
*ret_success = 1; *ret_success = 1;
} else { } else {
*ret_success = getter.found_node_ref ? 1 : 0; *ret_success = getter.found_node_ref ? 1 : 0;
*ret_typeid = kNull; *ret_type_code = kNull;
} }
} }
API_END_HANDLE_ERROR(ret->Clear()); API_END_HANDLE_ERROR(ret->Clear());
...@@ -238,16 +242,18 @@ int TVMNodeListAttrNames(NodeHandle handle, ...@@ -238,16 +242,18 @@ int TVMNodeListAttrNames(NodeHandle handle,
} }
inline void TVMAPIThreadLocalEntry::SetReturn(ArgVariant* ret_val, inline void TVMAPIThreadLocalEntry::SetReturn(TVMValue* ret_val,
int* ret_typeid) { int* ret_type_code) {
APIVariantValue& rv = ret_value; APIVariantValue& rv = ret_value;
*ret_typeid = rv.type_id; *ret_type_code = rv.type_code;
if (rv.type_id == kNodeHandle) { if (rv.type_code == kNodeHandle) {
if (rv.sptr.get() != nullptr) { if (rv.sptr.get() != nullptr) {
ret_val->v_handle = new TVMAPINode(std::move(rv.sptr)); ret_val->v_handle = new TVMAPINode(std::move(rv.sptr));
} else { } else {
ret_val->v_handle = nullptr; ret_val->v_handle = nullptr;
} }
} else if (rv.type_code == kFuncHandle) {
ret_val->v_handle = new runtime::PackedFunc::FType(std::move(rv.func));
} else { } else {
*ret_val = rv.v_union; *ret_val = rv.v_union;
} }
......
/*! /*!
* Copyright (c) 2016 by Contributors * Copyright (c) 2016 by Contributors
* Implementation of API functions related to IR build * Implementation of API functions related to Codegen
* \file c_api_ir.cc * \file c_api_codegen.cc
*/ */
#include <tvm/expr.h> #include <tvm/expr.h>
#include <tvm/ir.h> #include <tvm/ir.h>
...@@ -32,5 +32,25 @@ TVM_REGISTER_API(_codegen_SplitHostDevice) ...@@ -32,5 +32,25 @@ TVM_REGISTER_API(_codegen_SplitHostDevice)
*ret = SplitHostDevice(args.at(0)); *ret = SplitHostDevice(args.at(0));
}); });
// generate a dummy packed function for testing
void DummyHelloFunction(const TVMValue* args, const int* type_code, int num_args) {
LOG(INFO) << num_args << " arguments";
for (int i = 0; i < num_args; ++i) {
switch (type_code[i]) {
case kNull: LOG(INFO) << i << ":nullptr"; break;
case kFloat: LOG(INFO) << i << ": double=" << args[i].v_float64; break;
case kInt: LOG(INFO) << i << ": long=" << args[i].v_int64; break;
case kHandle: LOG(INFO) << i << ": handle=" << args[i].v_handle; break;
default: LOG(FATAL) << "unhandled type " << TVMTypeCode2Str(type_code[i]);
}
}
}
TVM_REGISTER_API(_codegen_DummyHelloFunction)
.set_body([](const ArgStack& args, RetValue *ret) {
*ret = runtime::PackedFunc(DummyHelloFunction);
});
} // namespace codegen } // namespace codegen
} // namespace tvm } // namespace tvm
...@@ -8,7 +8,7 @@ ...@@ -8,7 +8,7 @@
#include "./c_api_registry.h" #include "./c_api_registry.h"
namespace dmlc { namespace dmlc {
DMLC_REGISTRY_ENABLE(::tvm::APIFunctionReg); DMLC_REGISTRY_ENABLE(::tvm::APIFuncReg);
} // namespace dmlc } // namespace dmlc
namespace tvm { namespace tvm {
...@@ -18,7 +18,7 @@ using RetValue = APIVariantValue; ...@@ -18,7 +18,7 @@ using RetValue = APIVariantValue;
TVM_REGISTER_API(_format_str) TVM_REGISTER_API(_format_str)
.set_body([](const ArgStack& args, RetValue *ret) { .set_body([](const ArgStack& args, RetValue *ret) {
CHECK(args.at(0).type_id == kNodeHandle); CHECK(args.at(0).type_code == kNodeHandle);
std::ostringstream os; std::ostringstream os;
os << args.at(0).operator NodeRef(); os << args.at(0).operator NodeRef();
*ret = os.str(); *ret = os.str();
...@@ -27,7 +27,7 @@ TVM_REGISTER_API(_format_str) ...@@ -27,7 +27,7 @@ TVM_REGISTER_API(_format_str)
TVM_REGISTER_API(_raw_ptr) TVM_REGISTER_API(_raw_ptr)
.set_body([](const ArgStack& args, RetValue *ret) { .set_body([](const ArgStack& args, RetValue *ret) {
CHECK(args.at(0).type_id == kNodeHandle); CHECK(args.at(0).type_code == kNodeHandle);
*ret = reinterpret_cast<int64_t>(args.at(0).sptr.get()); *ret = reinterpret_cast<int64_t>(args.at(0).sptr.get());
}) })
.add_argument("src", "NodeBase", "the node base"); .add_argument("src", "NodeBase", "the node base");
......
...@@ -16,9 +16,9 @@ using RetValue = APIVariantValue; ...@@ -16,9 +16,9 @@ using RetValue = APIVariantValue;
TVM_REGISTER_API(_const) TVM_REGISTER_API(_const)
.set_body([](const ArgStack& args, RetValue *ret) { .set_body([](const ArgStack& args, RetValue *ret) {
if (args.at(0).type_id == kLong) { if (args.at(0).type_code == kInt) {
*ret = make_const(args.at(1), args.at(0).operator int64_t()); *ret = make_const(args.at(1), args.at(0).operator int64_t());
} else if (args.at(0).type_id == kDouble) { } else if (args.at(0).type_code == kFloat) {
*ret = make_const(args.at(1), args.at(0).operator double()); *ret = make_const(args.at(1), args.at(0).operator double());
} else { } else {
LOG(FATAL) << "only accept int or float"; LOG(FATAL) << "only accept int or float";
...@@ -31,19 +31,19 @@ TVM_REGISTER_API(_Array) ...@@ -31,19 +31,19 @@ TVM_REGISTER_API(_Array)
.set_body([](const ArgStack& args, RetValue *ret) { .set_body([](const ArgStack& args, RetValue *ret) {
std::vector<std::shared_ptr<Node> > data; std::vector<std::shared_ptr<Node> > data;
for (size_t i = 0; i < args.size(); ++i) { for (size_t i = 0; i < args.size(); ++i) {
CHECK(args.at(i).type_id == kNodeHandle) CHECK(args.at(i).type_code == kNodeHandle)
<< "need content of array to be NodeBase"; << "need content of array to be NodeBase";
data.push_back(args.at(i).sptr); data.push_back(args.at(i).sptr);
} }
auto node = std::make_shared<ArrayNode>(); auto node = std::make_shared<ArrayNode>();
node->data = std::move(data); node->data = std::move(data);
ret->type_id = kNodeHandle; ret->type_code = kNodeHandle;
ret->sptr = node; ret->sptr = node;
}); });
TVM_REGISTER_API(_ArrayGetItem) TVM_REGISTER_API(_ArrayGetItem)
.set_body([](const ArgStack& args, RetValue *ret) { .set_body([](const ArgStack& args, RetValue *ret) {
CHECK(args.at(0).type_id == kNodeHandle); CHECK(args.at(0).type_code == kNodeHandle);
int64_t i = args.at(1); int64_t i = args.at(1);
auto& sptr = args.at(0).sptr; auto& sptr = args.at(0).sptr;
CHECK(sptr->is_type<ArrayNode>()); CHECK(sptr->is_type<ArrayNode>());
...@@ -51,12 +51,12 @@ TVM_REGISTER_API(_ArrayGetItem) ...@@ -51,12 +51,12 @@ TVM_REGISTER_API(_ArrayGetItem)
CHECK_LT(static_cast<size_t>(i), n->data.size()) CHECK_LT(static_cast<size_t>(i), n->data.size())
<< "out of bound of array"; << "out of bound of array";
ret->sptr = n->data[i]; ret->sptr = n->data[i];
ret->type_id = kNodeHandle; ret->type_code = kNodeHandle;
}); });
TVM_REGISTER_API(_ArraySize) TVM_REGISTER_API(_ArraySize)
.set_body([](const ArgStack& args, RetValue *ret) { .set_body([](const ArgStack& args, RetValue *ret) {
CHECK(args.at(0).type_id == kNodeHandle); CHECK(args.at(0).type_code == kNodeHandle);
auto& sptr = args.at(0).sptr; auto& sptr = args.at(0).sptr;
CHECK(sptr->is_type<ArrayNode>()); CHECK(sptr->is_type<ArrayNode>());
*ret = static_cast<int64_t>( *ret = static_cast<int64_t>(
...@@ -68,21 +68,21 @@ TVM_REGISTER_API(_Map) ...@@ -68,21 +68,21 @@ TVM_REGISTER_API(_Map)
CHECK_EQ(args.size() % 2, 0U); CHECK_EQ(args.size() % 2, 0U);
MapNode::ContainerType data; MapNode::ContainerType data;
for (size_t i = 0; i < args.size(); i += 2) { for (size_t i = 0; i < args.size(); i += 2) {
CHECK(args.at(i).type_id == kNodeHandle) CHECK(args.at(i).type_code == kNodeHandle)
<< "need content of array to be NodeBase"; << "need content of array to be NodeBase";
CHECK(args.at(i + 1).type_id == kNodeHandle) CHECK(args.at(i + 1).type_code == kNodeHandle)
<< "need content of array to be NodeBase"; << "need content of array to be NodeBase";
data.emplace(std::make_pair(args.at(i).sptr, args.at(i + 1).sptr)); data.emplace(std::make_pair(args.at(i).sptr, args.at(i + 1).sptr));
} }
auto node = std::make_shared<MapNode>(); auto node = std::make_shared<MapNode>();
node->data = std::move(data); node->data = std::move(data);
ret->type_id = kNodeHandle; ret->type_code = kNodeHandle;
ret->sptr = node; ret->sptr = node;
}); });
TVM_REGISTER_API(_MapSize) TVM_REGISTER_API(_MapSize)
.set_body([](const ArgStack& args, RetValue *ret) { .set_body([](const ArgStack& args, RetValue *ret) {
CHECK(args.at(0).type_id == kNodeHandle); CHECK(args.at(0).type_code == kNodeHandle);
auto& sptr = args.at(0).sptr; auto& sptr = args.at(0).sptr;
CHECK(sptr->is_type<MapNode>()); CHECK(sptr->is_type<MapNode>());
auto* n = static_cast<const MapNode*>(sptr.get()); auto* n = static_cast<const MapNode*>(sptr.get());
...@@ -91,8 +91,8 @@ TVM_REGISTER_API(_MapSize) ...@@ -91,8 +91,8 @@ TVM_REGISTER_API(_MapSize)
TVM_REGISTER_API(_MapGetItem) TVM_REGISTER_API(_MapGetItem)
.set_body([](const ArgStack& args, RetValue *ret) { .set_body([](const ArgStack& args, RetValue *ret) {
CHECK(args.at(0).type_id == kNodeHandle); CHECK(args.at(0).type_code == kNodeHandle);
CHECK(args.at(1).type_id == kNodeHandle); CHECK(args.at(1).type_code == kNodeHandle);
auto& sptr = args.at(0).sptr; auto& sptr = args.at(0).sptr;
CHECK(sptr->is_type<MapNode>()); CHECK(sptr->is_type<MapNode>());
auto* n = static_cast<const MapNode*>(sptr.get()); auto* n = static_cast<const MapNode*>(sptr.get());
...@@ -100,13 +100,13 @@ TVM_REGISTER_API(_MapGetItem) ...@@ -100,13 +100,13 @@ TVM_REGISTER_API(_MapGetItem)
CHECK(it != n->data.end()) CHECK(it != n->data.end())
<< "cannot find the corresponding key in the Map"; << "cannot find the corresponding key in the Map";
ret->sptr = (*it).second; ret->sptr = (*it).second;
ret->type_id = kNodeHandle; ret->type_code = kNodeHandle;
}); });
TVM_REGISTER_API(_MapCount) TVM_REGISTER_API(_MapCount)
.set_body([](const ArgStack& args, RetValue *ret) { .set_body([](const ArgStack& args, RetValue *ret) {
CHECK(args.at(0).type_id == kNodeHandle); CHECK(args.at(0).type_code == kNodeHandle);
CHECK(args.at(1).type_id == kNodeHandle); CHECK(args.at(1).type_code == kNodeHandle);
auto& sptr = args.at(0).sptr; auto& sptr = args.at(0).sptr;
CHECK(sptr->is_type<MapNode>()); CHECK(sptr->is_type<MapNode>());
auto* n = static_cast<const MapNode*>(sptr.get()); auto* n = static_cast<const MapNode*>(sptr.get());
...@@ -115,7 +115,7 @@ TVM_REGISTER_API(_MapCount) ...@@ -115,7 +115,7 @@ TVM_REGISTER_API(_MapCount)
TVM_REGISTER_API(_MapItems) TVM_REGISTER_API(_MapItems)
.set_body([](const ArgStack& args, RetValue *ret) { .set_body([](const ArgStack& args, RetValue *ret) {
CHECK(args.at(0).type_id == kNodeHandle); CHECK(args.at(0).type_code == kNodeHandle);
auto& sptr = args.at(0).sptr; auto& sptr = args.at(0).sptr;
CHECK(sptr->is_type<MapNode>()); CHECK(sptr->is_type<MapNode>());
auto* n = static_cast<const MapNode*>(sptr.get()); auto* n = static_cast<const MapNode*>(sptr.get());
...@@ -125,7 +125,7 @@ TVM_REGISTER_API(_MapItems) ...@@ -125,7 +125,7 @@ TVM_REGISTER_API(_MapItems)
rkvs->data.push_back(kv.second); rkvs->data.push_back(kv.second);
} }
ret->sptr = rkvs; ret->sptr = rkvs;
ret->type_id = kNodeHandle; ret->type_code = kNodeHandle;
}); });
TVM_REGISTER_API(Range) TVM_REGISTER_API(Range)
......
...@@ -9,25 +9,25 @@ ...@@ -9,25 +9,25 @@
#include <tvm/base.h> #include <tvm/base.h>
#include <tvm/expr.h> #include <tvm/expr.h>
#include <tvm/c_api.h> #include <tvm/c_api.h>
#include <tvm/runtime/runtime.h>
#include <memory> #include <memory>
#include <limits> #include <limits>
#include <string> #include <string>
#include <vector> #include <vector>
#include "../base/common.h" #include "../base/common.h"
using ArgVariant = TVMArg;
using ArgVariantID = TVMArgTypeID;
namespace tvm { namespace tvm {
inline const char* TypeId2Str(ArgVariantID type_id) { inline const char* TVMTypeCode2Str(int type_code) {
switch (type_id) { switch (type_code) {
case kNull: return "Null"; case kInt: return "int";
case kLong: return "Long"; case kFloat: return "float";
case kDouble: return "Double"; case kStr: return "str";
case kStr: return "Str"; case kHandle: return "Handle";
case kNull: return "NULL";
case kNodeHandle: return "NodeHandle"; case kNodeHandle: return "NodeHandle";
default: LOG(FATAL) << "unknown type_id=" << type_id; return ""; default: LOG(FATAL) << "unknown type_code="
<< static_cast<int>(type_code); return "";
} }
} }
...@@ -96,72 +96,83 @@ inline std::string NodeTypeName() { ...@@ -96,72 +96,83 @@ inline std::string NodeTypeName() {
class APIVariantValue { class APIVariantValue {
public: public:
/*! \brief the type id */ /*! \brief the type id */
ArgVariantID type_id{kNull}; int type_code{kNull};
/*! \brief shared pointer container */ /*! \brief shared pointer container */
std::shared_ptr<Node> sptr; std::shared_ptr<Node> sptr;
/*! \brief string container */ /*! \brief string container */
std::string str; std::string str;
/*! \brief the variant holder */ /*! \brief the variant holder */
ArgVariant v_union; TVMValue v_union;
/*! \brief std::function */
runtime::PackedFunc::FType func;
// constructor // constructor
APIVariantValue() {} APIVariantValue() {
}
// clear value // clear value
inline void Clear() { inline void Clear() {
} }
// assign op // assign op
inline APIVariantValue& operator=(double value) { inline APIVariantValue& operator=(double value) {
type_id = kDouble; type_code = kFloat;
v_union.v_double = value; v_union.v_float64 = value;
return *this; return *this;
} }
inline APIVariantValue& operator=(std::nullptr_t value) { inline APIVariantValue& operator=(std::nullptr_t value) {
type_id = kNull; type_code = kHandle;
v_union.v_handle = value;
return *this; return *this;
} }
inline APIVariantValue& operator=(int64_t value) { inline APIVariantValue& operator=(int64_t value) {
type_id = kLong; type_code = kInt;
v_union.v_long = value; v_union.v_int64 = value;
return *this; return *this;
} }
inline APIVariantValue& operator=(bool value) { inline APIVariantValue& operator=(bool value) {
type_id = kLong; type_code = kInt;
v_union.v_long = value; v_union.v_int64 = value;
return *this; return *this;
} }
inline APIVariantValue& operator=(std::string value) { inline APIVariantValue& operator=(std::string value) {
type_id = kStr; type_code = kStr;
str = std::move(value); str = std::move(value);
v_union.v_str = str.c_str(); v_union.v_str = str.c_str();
return *this; return *this;
} }
inline APIVariantValue& operator=(const NodeRef& ref) { inline APIVariantValue& operator=(const NodeRef& ref) {
if (ref.node_.get() == nullptr) { if (ref.node_.get() == nullptr) {
type_id = kNull; type_code = kNull;
} else { } else {
type_id = kNodeHandle; type_code = kNodeHandle;
this->sptr = ref.node_; this->sptr = ref.node_;
} }
return *this; return *this;
} }
inline APIVariantValue& operator=(const runtime::PackedFunc& f) {
type_code = kFuncHandle;
this->func = f.body();
return *this;
}
inline APIVariantValue& operator=(const Type& value) { inline APIVariantValue& operator=(const Type& value) {
return operator=(Type2String(value)); return operator=(Type2String(value));
} }
template<typename T, template<typename T,
typename = typename std::enable_if<std::is_base_of<NodeRef, T>::value>::type> typename = typename std::enable_if<std::is_base_of<NodeRef, T>::value>::type>
inline operator T() const { inline operator T() const {
if (type_id == kNull) return T(); if (type_code == kNull) return T();
CHECK_EQ(type_id, kNodeHandle); CHECK_EQ(type_code, kNodeHandle);
CHECK(NodeTypeChecker<T>::Check(sptr.get())) CHECK(NodeTypeChecker<T>::Check(sptr.get()))
<< "Did not get expected type " << NodeTypeName<T>(); << "Did not get expected type " << NodeTypeName<T>();
return T(sptr); return T(sptr);
} }
inline operator Expr() const { inline operator Expr() const {
if (type_id == kNull) return Expr(); if (type_code == kNull) {
if (type_id == kLong) return Expr(operator int()); return Expr();
if (type_id == kDouble) { }
if (type_code == kInt) return Expr(operator int());
if (type_code == kFloat) {
return Expr(static_cast<float>(operator double())); return Expr(static_cast<float>(operator double()));
} }
CHECK_EQ(type_id, kNodeHandle); CHECK_EQ(type_code, kNodeHandle);
if (sptr->is_type<IterVarNode>()) { if (sptr->is_type<IterVarNode>()) {
return IterVar(sptr)->var; return IterVar(sptr)->var;
} else { } else {
...@@ -171,52 +182,58 @@ class APIVariantValue { ...@@ -171,52 +182,58 @@ class APIVariantValue {
} }
} }
inline operator double() const { inline operator double() const {
CHECK_EQ(type_id, kDouble); CHECK_EQ(type_code, kFloat);
return v_union.v_double; return v_union.v_float64;
} }
inline operator int64_t() const { inline operator int64_t() const {
CHECK_EQ(type_id, kLong); CHECK_EQ(type_code, kInt);
return v_union.v_long; return v_union.v_int64;
} }
inline operator uint64_t() const { inline operator uint64_t() const {
CHECK_EQ(type_id, kLong); CHECK_EQ(type_code, kInt);
return v_union.v_long; return v_union.v_int64;
} }
inline operator int() const { inline operator int() const {
CHECK_EQ(type_id, kLong); CHECK_EQ(type_code, kInt);
CHECK_LE(v_union.v_long, CHECK_LE(v_union.v_int64,
std::numeric_limits<int>::max()); std::numeric_limits<int>::max());
return v_union.v_long; return v_union.v_int64;
} }
inline operator bool() const { inline operator bool() const {
CHECK_EQ(type_id, kLong) CHECK_EQ(type_code, kInt)
<< "expect boolean(int) but get " << TypeId2Str(type_id); << "expect boolean(int) but get "
return v_union.v_long != 0; << TVMTypeCode2Str(type_code);
return v_union.v_int64 != 0;
} }
inline operator std::string() const { inline operator std::string() const {
CHECK_EQ(type_id, kStr) CHECK_EQ(type_code, kStr)
<< "expect Str but get " << TypeId2Str(type_id); << "expect Str but get "
<< TVMTypeCode2Str(type_code);
return str; return str;
} }
inline operator Type() const { inline operator Type() const {
return String2Type(operator std::string()); return String2Type(operator std::string());
} }
inline operator runtime::PackedFunc() const {
CHECK_EQ(type_code, kFuncHandle);
return runtime::PackedFunc(func);
}
}; };
// common defintiion of API function. // common defintiion of API function.
using APIFunction = std::function< using APIFunc = std::function<
void(const std::vector<APIVariantValue> &args, APIVariantValue* ret)>; void(const std::vector<APIVariantValue> &args, APIVariantValue* ret)>;
/*! /*!
* \brief Registry entry for DataIterator factory functions. * \brief Registry entry for DataIterator factory functions.
*/ */
struct APIFunctionReg struct APIFuncReg
: public dmlc::FunctionRegEntryBase<APIFunctionReg, : public dmlc::FunctionRegEntryBase<APIFuncReg,
APIFunction> { APIFunc> {
}; };
#define TVM_REGISTER_API(TypeName) \ #define TVM_REGISTER_API(TypeName) \
DMLC_REGISTRY_REGISTER(::tvm::APIFunctionReg, APIFunctionReg, TypeName) \ DMLC_REGISTRY_REGISTER(::tvm::APIFuncReg, APIFuncReg, TypeName) \
} // namespace tvm } // namespace tvm
......
...@@ -7,7 +7,6 @@ ...@@ -7,7 +7,6 @@
#define TVM_CODEGEN_CODEGEN_C_H_ #define TVM_CODEGEN_CODEGEN_C_H_
#include <tvm/ir.h> #include <tvm/ir.h>
#include <tvm/ir_visitor.h>
#include <tvm/module.h> #include <tvm/module.h>
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
......
...@@ -3,7 +3,8 @@ ...@@ -3,7 +3,8 @@
* \file c_runtime_api.cc * \file c_runtime_api.cc
* \brief Device specific implementations * \brief Device specific implementations
*/ */
#include <tvm/c_runtime_api.h> #include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/runtime.h>
#include <algorithm> #include <algorithm>
#include "./runtime_base.h" #include "./runtime_base.h"
#include "./device_api.h" #include "./device_api.h"
...@@ -34,7 +35,7 @@ inline void TVMArrayFree_(TVMArray* arr) { ...@@ -34,7 +35,7 @@ inline void TVMArrayFree_(TVMArray* arr) {
delete arr; delete arr;
} }
inline void VerifyType(TVMDataType dtype) { inline void VerifyType(TVMType dtype) {
CHECK_GE(dtype.lanes, 1U); CHECK_GE(dtype.lanes, 1U);
if (dtype.type_code == kFloat) { if (dtype.type_code == kFloat) {
CHECK_EQ(dtype.bits % 32U, 0U); CHECK_EQ(dtype.bits % 32U, 0U);
...@@ -98,7 +99,7 @@ int TVMContextEnabled(TVMContext ctx, ...@@ -98,7 +99,7 @@ int TVMContextEnabled(TVMContext ctx,
int TVMArrayAlloc(const tvm_index_t* shape, int TVMArrayAlloc(const tvm_index_t* shape,
tvm_index_t ndim, tvm_index_t ndim,
TVMDataType dtype, TVMType dtype,
TVMContext ctx, TVMContext ctx,
TVMArrayHandle* out) { TVMArrayHandle* out) {
TVMArray* arr = nullptr; TVMArray* arr = nullptr;
...@@ -166,3 +167,19 @@ int TVMSynchronize(TVMContext ctx, TVMStreamHandle stream) { ...@@ -166,3 +167,19 @@ int TVMSynchronize(TVMContext ctx, TVMStreamHandle stream) {
}); });
API_END(); API_END();
} }
int TVMFuncFree(TVMFunctionHandle func) {
API_BEGIN();
delete static_cast<PackedFunc::FType*>(func);
API_END();
}
int TVMFuncCall(TVMFunctionHandle func,
TVMValue* args,
int* arg_type_codes,
int num_args) {
API_BEGIN();
(*static_cast<const PackedFunc::FType*>(func))(
args, arg_type_codes, num_args);
API_END();
}
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
#define TVM_RUNTIME_DEVICE_API_H_ #define TVM_RUNTIME_DEVICE_API_H_
#include <tvm/base.h> #include <tvm/base.h>
#include <tvm/c_runtime_api.h> #include <tvm/runtime/c_runtime_api.h>
namespace tvm { namespace tvm {
namespace runtime { namespace runtime {
......
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
#ifndef TVM_RUNTIME_RUNTIME_BASE_H_ #ifndef TVM_RUNTIME_RUNTIME_BASE_H_
#define TVM_RUNTIME_RUNTIME_BASE_H_ #define TVM_RUNTIME_RUNTIME_BASE_H_
#include <tvm/c_runtime_api.h> #include <tvm/runtime/c_runtime_api.h>
#include <stdexcept> #include <stdexcept>
/*! \brief macro to guard beginning and end section of all functions */ /*! \brief macro to guard beginning and end section of all functions */
......
#include <dmlc/logging.h>
#include <gtest/gtest.h>
#include <tvm/runtime/runtime.h>
TEST(PackedFunc, Basic) {
using namespace tvm::runtime;
int x = 0;
void* handle = &x;
TVMArray a;
PackedFunc([&](const TVMValue* args, const int* type_codes, int num_args) {
CHECK(num_args == 3);
CHECK(args[0].v_float64 == 1.0);
CHECK(type_codes[0] == kFloat);
CHECK(args[1].v_handle == &a);
CHECK(type_codes[1] == kHandle);
CHECK(args[2].v_handle == &x);
CHECK(type_codes[2] == kHandle);
})(1.0, &a, handle);
}
int main(int argc, char ** argv) {
testing::InitGoogleTest(&argc, argv);
testing::FLAGS_gtest_death_test_style = "threadsafe";
return RUN_ALL_TESTS();
}
...@@ -15,11 +15,11 @@ def mock_test_add(): ...@@ -15,11 +15,11 @@ def mock_test_add():
thread_x = tvm.IterVar((0, num_thread), thread_tag="threadIdx.x") thread_x = tvm.IterVar((0, num_thread), thread_tag="threadIdx.x")
_, x = s[C].split(C.op.axis[0], factor=num_thread, outer=grid_x) _, x = s[C].split(C.op.axis[0], factor=num_thread, outer=grid_x)
_, x = s[C].split(x, outer=thread_x) _, x = s[C].split(x, outer=thread_x)
# compile to IR # compile to IR
bounds = tvm.schedule.InferBound(s) bounds = tvm.schedule.InferBound(s)
stmt = tvm.ir_pass.ScheduleOps(s, bounds) stmt = tvm.ir_pass.ScheduleOps(s, bounds)
Ab = tvm.Buffer(A.shape, A.dtype, name='A') Ab = tvm.Buffer(A.shape, A.dtype, name='A')
Bb = tvm.Buffer(B.shape, B.dtype, name='B') Bb = tvm.Buffer(B.shape, B.dtype, name='B')
Cb = tvm.Buffer(C.shape, C.dtype, name='C') Cb = tvm.Buffer(C.shape, C.dtype, name='C')
......
import tvm
import numpy as np
def test_function():
ctx = tvm.cpu(0)
x = np.random.randint(0, 10, size=(3, 4))
x = np.array(x)
y = tvm.nd.array(x, ctx=ctx)
f = tvm.codegen.DummyHelloFunction()
f(y, 10)
if __name__ == "__main__":
test_function()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment