Commit 7e025234 by Tianqi Chen Committed by GitHub

[RUNTIME] Add interface header of runtime (#15)

* [RUNTIME] Add interface header of runtime

* fix mac build
parent 7f82912b
Subproject commit 3278103721cfabf7435f1e9ba1fd75a7c38f13c9
Subproject commit 6375e6b76f6b70d58f66b357d946c971843f3169
export LDFLAGS = -pthread -lm
export CFLAGS = -std=c++11 -Wall -O2 -Wno-unknown-pragmas -funroll-loops\
export CFLAGS = -std=c++11 -Wall -O2\
-Iinclude -Idmlc-core/include -IHalideIR/src -fPIC
# specify tensor path
......
......@@ -6,74 +6,28 @@
#ifndef TVM_C_API_H_
#define TVM_C_API_H_
#ifdef __cplusplus
#define TVM_EXTERN_C extern "C"
#else
#define TVM_EXTERN_C
#endif
/*! \brief TVM_DLL prefix for windows */
#ifdef _WIN32
#ifdef TVM_EXPORTS
#define TVM_DLL __declspec(dllexport)
#else
#define TVM_DLL __declspec(dllimport)
#endif
#else
#define TVM_DLL
#endif
#include "./c_runtime_api.h"
TVM_EXTERN_C {
/*! \brief handle to functions */
typedef void* FunctionHandle;
typedef void* APIFunctionHandle;
/*! \brief handle to node */
typedef void* NodeHandle;
/*!
* \brief union type for returning value of attributes
* Attribute type can be identified by id
*/
typedef union {
long v_long; // NOLINT(*)
double v_double;
const char* v_str;
NodeHandle v_handle;
} ArgVariant;
/*! \brief attribute types */
typedef enum {
kNull = 0,
kLong = 1,
kDouble = 2,
kStr = 3,
kNodeHandle = 4
} ArgVariantID;
/*!
* \brief return str message of the last error
* all function in this file will return 0 when success
* and -1 when an error occured,
* NNGetLastError can be called to retrieve the error
*
* this function is threadsafe and can be called by different thread
* \return error info
*/
TVM_DLL const char *TVMGetLastError(void);
/*!
* \brief List all the node function name
* \param out_size The number of functions
* \param out_array The array of function names.
*/
TVM_DLL int TVMListFunctionNames(int *out_size,
TVM_DLL int TVMListAPIFunctionNames(int *out_size,
const char*** out_array);
/*!
* \brief get function handle by name
* \param name The name of function
* \param handle The returning function handle
*/
TVM_DLL int TVMGetFunctionHandle(const char* name,
FunctionHandle *handle);
TVM_DLL int TVMGetAPIFunctionHandle(const char* name,
APIFunctionHandle *handle);
/*!
* \brief Get the detailed information about function.
......@@ -88,7 +42,7 @@ TVM_DLL int TVMGetFunctionHandle(const char* name,
* \param return_type Return type of the function, if any.
* \return 0 when success, -1 when failure happens
*/
TVM_DLL int TVMGetFunctionInfo(FunctionHandle handle,
TVM_DLL int TVMGetAPIFunctionInfo(APIFunctionHandle handle,
const char **real_name,
const char **description,
int *num_doc_args,
......@@ -104,7 +58,7 @@ TVM_DLL int TVMGetFunctionInfo(FunctionHandle handle,
* \param arg number of attributes
* \param type_id The typeid of attributes.
*/
TVM_DLL int TVMPushStack(ArgVariant arg,
TVM_DLL int TVMAPIPushStack(TVMArg arg,
int type_id);
/*!
......@@ -115,8 +69,8 @@ TVM_DLL int TVMPushStack(ArgVariant arg,
* \param ret_val The return value.
* \param ret_typeid the type id of return value.
*/
TVM_DLL int TVMFunctionCall(FunctionHandle handle,
ArgVariant* ret_val,
TVM_DLL int TVMAPIFunctionCall(APIFunctionHandle handle,
TVMArg* ret_val,
int* ret_typeid);
/*!
......@@ -135,7 +89,7 @@ TVM_DLL int TVMNodeFree(NodeHandle handle);
*/
TVM_DLL int TVMNodeGetAttr(NodeHandle handle,
const char* key,
ArgVariant* out_value,
TVMArg* out_value,
int* out_typeid,
int* out_success);
......
/*!
* Copyright (c) 2016 by Contributors
* \file c_runtime_api.h
* \brief TVM runtime library.
*
* The philosophy of TVM project is to customize the compilation
* stage to generate code that can used by other projects transparently.
*
* So this is a minimum runtime code gluing, and some limited
* memory management code to enable quick testing.
*/
#ifndef TVM_C_RUNTIME_API_H_
#define TVM_C_RUNTIME_API_H_
#ifdef __cplusplus
#define TVM_EXTERN_C extern "C"
#else
#define TVM_EXTERN_C
#endif
/*! \brief TVM_DLL prefix for windows */
#ifdef _WIN32
#ifdef TVM_EXPORTS
#define TVM_DLL __declspec(dllexport)
#else
#define TVM_DLL __declspec(dllimport)
#endif
#else
#define TVM_DLL
#endif
#include <stdint.h>
TVM_EXTERN_C {
/*! \brief type of array index. */
typedef unsigned tvm_index_t;
/*!
* \brief union type for arguments and return values
* in both runtime API and TVM API calls
*/
typedef union {
long v_long; // NOLINT(*)
double v_double;
const char* v_str;
void* v_handle;
} TVMArg;
/*!
* \brief The type index in TVM.
*/
typedef enum {
kNull = 0,
kLong = 1,
kDouble = 2,
kStr = 3,
kNodeHandle = 4,
kArrayHandle = 5
} TVMArgTypeID;
/*!
* \brief The device type
*/
typedef enum {
/*! \brief CPU device */
kCPU = 1,
/*! \brief NVidia GPU device(CUDA) */
kGPU = 2,
/*! \brief opencl device */
KOpenCL = 4
} TVMDeviceMask;
/*!
* \brief The Device information, abstract away common device types.
*/
typedef struct {
/*! \brief The device type mask */
int dev_mask;
/*! \brief the device id */
int dev_id;
} TVMDevice;
/*! \brief The type code in TVMDataType */
typedef enum {
kInt = 0U,
kUInt = 1U,
kFloat = 2U
} TVMTypeCode;
/*!
* \brief the data type
* Examples
* - float: type_code = 2, bits = 32, lanes=1
* - float4(vectorized 4 float): type_code = 2, bits = 32, lanes=4
* - int8: type_code = 0, bits = 8, lanes=1
*/
typedef struct {
/*! \brief type code, in TVMTypeCode */
uint8_t type_code;
/*! \brief number of bits of the type */
uint8_t bits;
/*! \brief number of lanes, */
uint16_t lanes;
} TVMDataType;
/*!
* \brief Data structure representing a n-dimensional array(tensor).
* This is used to pass data specification into TVM.
*/
typedef struct {
/*! \brief The data field pointer on specified device */
void* data;
/*! \brief The shape pointers of the array */
const tvm_index_t* shape;
/*!
* \brief The stride data about each dimension of the array, can be NULL
* When strides is NULL, it indicates that the array is empty.
*/
const tvm_index_t* strides;
/*! \brief number of dimensions of the array */
tvm_index_t ndim;
/*! \brief The data type flag */
TVMDataType dtype;
/*! \brief The device this array sits on */
TVMDevice device;
} TVMArray;
/*!
* \brief The stream that is specific to device
* can be NULL, which indicates the default one.
*/
typedef void* TVMStreamHandle;
/*!
* \brief Pointer to function handle that points to
* a generated TVM function.
*/
typedef void* TVMFunctionHandle;
/*! \brief the array handle */
typedef TVMArray* TVMArrayHandle;
/*!
* \brief return str message of the last error
* all function in this file will return 0 when success
* and -1 when an error occured,
* TVMGetLastError can be called to retrieve the error
*
* this function is threadsafe and can be called by different thread
* \return error info
*/
TVM_DLL const char *TVMGetLastError(void);
/*!
* \brief Allocate a nd-array's memory,
* including space of shape, of given spec.
*
* \param shape The shape of the array, the data content will be copied to out
* \param ndim The number of dimension of the array.
* \param dtype The array data type.
* \param device The device this array sits on.
* \param out The output handle.
* \return Whether the function is successful.
*/
TVM_DLL int TVMArrayAlloc(const tvm_index_t* shape,
tvm_index_t ndim,
int dtype,
TVMDevice device,
TVMArrayHandle* out);
/*!
* \brief Free the TVM Array.
* \param handle The array handle to be freed.
*/
TVM_DLL int TVMArrayFree(TVMArrayHandle handle);
/*!
* \brief Copy the array, both from and to must be valid during the copy.
* \param from The array to be copied from.
* \param to The target space.
* \param stream The stream where the copy happens, can be NULL.
*/
TVM_DLL int TVMArrayCopyFromTo(TVMArrayHandle from,
TVMArrayHandle to,
TVMStreamHandle stream);
/*!
* \brief Wait until all computations on stream completes.
* \param stream the stream to be synchronized.
*/
TVM_DLL int TVMSynchronize(TVMStreamHandle stream);
/*!
* \brief Launch a generated TVM function
* \param func function handle to be launched.
* \param args The arguments
* \param arg_type_ids The type id of the arguments
* \param num_args Number of arguments.
* \param stream The stream this function to be launched on.
*/
TVM_DLL int TVMLaunch(TVMFunctionHandle func,
TVMArg* args,
int* arg_type_ids,
int num_args,
TVMStreamHandle stream);
} // TVM_EXTERN_C
#endif // TVM_C_RUNTIME_API_H_
......@@ -41,10 +41,6 @@ __version__ = libinfo.__version__
# library instance of nnvm
_LIB = _load_lib()
# type definitions
FunctionHandle = ctypes.c_void_p
NodeHandle = ctypes.c_void_p
#----------------------------
# helper function definition
#----------------------------
......
......@@ -10,17 +10,20 @@ from numbers import Number, Integral
from .._base import _LIB
from .._base import c_str, py_str, string_types
from .._base import FunctionHandle, NodeHandle
from .._base import check_call, ctypes2docstring
from .. import _function_internal
class ArgVariant(ctypes.Union):
"""ArgVariant in C API"""
class TVMArg(ctypes.Union):
"""TVMArg in C API"""
_fields_ = [("v_long", ctypes.c_long),
("v_double", ctypes.c_double),
("v_str", ctypes.c_char_p),
("v_handle", ctypes.c_void_p)]
# type definitions
APIFunctionHandle = ctypes.c_void_p
NodeHandle = ctypes.c_void_p
kNull = 0
kLong = 1
kDouble = 2
......@@ -34,7 +37,7 @@ def _return_node(x):
handle = x.v_handle
if not isinstance(handle, NodeHandle):
handle = NodeHandle(handle)
ret_val = ArgVariant()
ret_val = TVMArg()
ret_typeid = ctypes.c_int()
ret_success = ctypes.c_int()
check_call(_LIB.TVMNodeGetAttr(
......@@ -77,7 +80,7 @@ class NodeBase(object):
check_call(_LIB.TVMNodeFree(self.handle))
def __getattr__(self, name):
ret_val = ArgVariant()
ret_val = TVMArg()
ret_typeid = ctypes.c_int()
ret_success = ctypes.c_int()
check_call(_LIB.TVMNodeGetAttr(
......@@ -169,21 +172,21 @@ def convert(value):
def _push_arg(arg):
a = ArgVariant()
a = TVMArg()
if arg is None:
_LIB.TVMPushStack(a, ctypes.c_int(kNull))
_LIB.TVMAPIPushStack(a, ctypes.c_int(kNull))
elif isinstance(arg, NodeBase):
a.v_handle = arg.handle
_LIB.TVMPushStack(a, ctypes.c_int(kNodeHandle))
_LIB.TVMAPIPushStack(a, ctypes.c_int(kNodeHandle))
elif isinstance(arg, int):
a.v_long = ctypes.c_long(arg)
_LIB.TVMPushStack(a, ctypes.c_int(kLong))
_LIB.TVMAPIPushStack(a, ctypes.c_int(kLong))
elif isinstance(arg, Number):
a.v_double = ctypes.c_double(arg)
_LIB.TVMPushStack(a, ctypes.c_int(kDouble))
_LIB.TVMAPIPushStack(a, ctypes.c_int(kDouble))
elif isinstance(arg, string_types):
a.v_str = c_str(arg)
_LIB.TVMPushStack(a, ctypes.c_int(kStr))
_LIB.TVMAPIPushStack(a, ctypes.c_int(kStr))
else:
raise TypeError("Don't know how to handle type %s" % type(arg))
......@@ -198,7 +201,7 @@ def _make_function(handle, name):
arg_descs = ctypes.POINTER(ctypes.c_char_p)()
ret_type = ctypes.c_char_p()
check_call(_LIB.TVMGetFunctionInfo(
check_call(_LIB.TVMGetAPIFunctionInfo(
handle, ctypes.byref(real_name), ctypes.byref(desc),
ctypes.byref(num_args),
ctypes.byref(arg_names),
......@@ -232,9 +235,9 @@ def _make_function(handle, name):
for arg in cargs:
_push_arg(arg)
ret_val = ArgVariant()
ret_val = TVMArg()
ret_typeid = ctypes.c_int()
check_call(_LIB.TVMFunctionCall(
check_call(_LIB.TVMAPIFunctionCall(
handle, ctypes.byref(ret_val), ctypes.byref(ret_typeid)))
return RET_SWITCH[ret_typeid.value](ret_val)
......@@ -267,7 +270,7 @@ def _init_function_module(root_namespace):
plist = ctypes.POINTER(ctypes.c_char_p)()
size = ctypes.c_uint()
check_call(_LIB.TVMListFunctionNames(ctypes.byref(size),
check_call(_LIB.TVMListAPIFunctionNames(ctypes.byref(size),
ctypes.byref(plist)))
op_names = []
for i in range(size.value):
......@@ -282,8 +285,8 @@ def _init_function_module(root_namespace):
}
for name in op_names:
hdl = FunctionHandle()
check_call(_LIB.TVMGetFunctionHandle(c_str(name), ctypes.byref(hdl)))
hdl = APIFunctionHandle()
check_call(_LIB.TVMGetAPIFunctionHandle(c_str(name), ctypes.byref(hdl)))
fname = name
target_module = module_internal if name.startswith('_') else module_obj
for k, v in namespace_match.items():
......
......@@ -4,3 +4,4 @@
- lang The definition of DSL related data structure
- schedule The operations on the schedule graph before converting to IR.
- pass The optimization pass on the IR structure
- runtime The runtime related codes.
\ No newline at end of file
......@@ -9,8 +9,6 @@
/*! \brief entry to to easily hold returning information */
struct TVMAPIThreadLocalEntry {
/*! \brief hold last error */
std::string last_error;
/*! \brief result holder for returning strings */
std::vector<std::string> ret_vec_str;
/*! \brief result holder for returning string pointers */
......@@ -99,15 +97,7 @@ struct APIAttrDir : public AttrVisitor {
}
};
const char *TVMGetLastError() {
return TVMAPIThreadLocalStore::Get()->last_error.c_str();
}
void TVMAPISetLastError(const char* msg) {
TVMAPIThreadLocalStore::Get()->last_error = msg;
}
int TVMListFunctionNames(int *out_size,
int TVMListAPIFunctionNames(int *out_size,
const char*** out_array) {
API_BEGIN();
TVMAPIThreadLocalEntry *ret = TVMAPIThreadLocalStore::Get();
......@@ -121,16 +111,16 @@ int TVMListFunctionNames(int *out_size,
API_END();
}
int TVMGetFunctionHandle(const char* fname,
FunctionHandle* out) {
int TVMGetAPIFunctionHandle(const char* fname,
APIFunctionHandle* out) {
API_BEGIN();
const APIFunctionReg* reg = dmlc::Registry<APIFunctionReg>::Find(fname);
CHECK(reg != nullptr) << "cannot find function " << fname;
*out = (FunctionHandle)reg;
*out = (APIFunctionHandle)reg;
API_END();
}
int TVMGetFunctionInfo(FunctionHandle handle,
int TVMGetAPIFunctionInfo(APIFunctionHandle handle,
const char **real_name,
const char **description,
int *num_doc_args,
......@@ -162,7 +152,7 @@ int TVMGetFunctionInfo(FunctionHandle handle,
API_END();
}
int TVMPushStack(ArgVariant arg,
int TVMAPIPushStack(ArgVariant arg,
int type_id) {
TVMAPIThreadLocalEntry *ret = TVMAPIThreadLocalStore::Get();
API_BEGIN();
......@@ -181,7 +171,7 @@ int TVMPushStack(ArgVariant arg,
API_END_HANDLE_ERROR(ret->Clear());
}
int TVMFunctionCall(FunctionHandle handle,
int TVMAPIFunctionCall(APIFunctionHandle handle,
ArgVariant* ret_val,
int* ret_typeid) {
TVMAPIThreadLocalEntry *ret = TVMAPIThreadLocalStore::Get();
......
......@@ -14,29 +14,6 @@
#include <string>
#include <exception>
#include "./c_api_registry.h"
/*! \brief macro to guard beginning and end section of all functions */
#define API_BEGIN() try {
/*! \brief every function starts with API_BEGIN();
and finishes with API_END() or API_END_HANDLE_ERROR */
#define API_END() } catch(std::runtime_error &_except_) { return TVMAPIHandleException(_except_); } return 0; // NOLINT(*)
/*!
* \brief every function starts with API_BEGIN();
* and finishes with API_END() or API_END_HANDLE_ERROR
* The finally clause contains procedure to cleanup states when an error happens.
*/
#define API_END_HANDLE_ERROR(Finalize) } catch(std::runtime_error &_except_) { Finalize; return TVMAPIHandleException(_except_); } return 0; // NOLINT(*)
void TVMAPISetLastError(const char* msg);
/*!
* \brief handle exception throwed out
* \param e the exception
* \return the return value of API after exception is handled
*/
inline int TVMAPIHandleException(const std::runtime_error &e) {
TVMAPISetLastError(e.what());
return -1;
}
#include "../runtime/runtime_common.h"
#endif // TVM_C_API_C_API_COMMON_H_
......@@ -15,6 +15,9 @@
#include <vector>
#include "../base/common.h"
using ArgVariant = TVMArg;
using ArgVariantID = TVMArgTypeID;
namespace tvm {
inline const char* TypeId2Str(ArgVariantID type_id) {
......
/*!
* Copyright (c) 2016 by Contributors
* Implementation of error handling API
* \file error_handle.cc
*/
#include <dmlc/thread_local.h>
#include <string>
#include "./runtime_common.h"
struct TVMErrorEntry {
std::string last_error;
};
typedef dmlc::ThreadLocalStore<TVMErrorEntry> TVMAPIErrorStore;
const char *TVMGetLastError() {
return TVMAPIErrorStore::Get()->last_error.c_str();
}
void TVMAPISetLastError(const char* msg) {
TVMAPIErrorStore::Get()->last_error = msg;
}
/*!
* Copyright (c) 2016 by Contributors
* \file c_runtime_common.h
* \brief Common fields of all C APIs
*/
#ifndef TVM_RUNTIME_RUNTIME_COMMON_H_
#define TVM_RUNTIME_RUNTIME_COMMON_H_
#include <tvm/c_runtime_api.h>
#include <exception>
/*! \brief macro to guard beginning and end section of all functions */
#define API_BEGIN() try {
/*! \brief every function starts with API_BEGIN();
and finishes with API_END() or API_END_HANDLE_ERROR */
#define API_END() } catch(std::runtime_error &_except_) { return TVMAPIHandleException(_except_); } return 0; // NOLINT(*)
/*!
* \brief every function starts with API_BEGIN();
* and finishes with API_END() or API_END_HANDLE_ERROR
* The finally clause contains procedure to cleanup states when an error happens.
*/
#define API_END_HANDLE_ERROR(Finalize) } catch(std::runtime_error &_except_) { Finalize; return TVMAPIHandleException(_except_); } return 0; // NOLINT(*)
void TVMAPISetLastError(const char* msg);
/*!
* \brief handle exception throwed out
* \param e the exception
* \return the return value of API after exception is handled
*/
inline int TVMAPIHandleException(const std::runtime_error &e) {
TVMAPISetLastError(e.what());
return -1;
}
#endif // TVM_RUNTIME_RUNTIME_COMMON_H_
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment