Commit 28120f55 by Tianqi Chen Committed by GitHub

[C API] Make DSL API registerable, add copy from/to raw bytes (#222)

* [C API] Make DSL API registerable, add copy from/to raw bytes

* fix cython
parent 0a19b16a
/*!
* Copyright (c) 2016 by Contributors
* \file c_api.h
* \brief C API of TVM DSL
* \file c_dsl_api.h
*
* \note The API is designed in a minimum way.
* Most of the API functions are registered and can be pulled out.
* \brief TVM DSL Node C API, used to interact to DSL compilation.
*
* The common flow is:
* - Use TVMFuncListGlobalNames to get global function name
* - Use TVMFuncCall to call these functions.
* These are only a few functions needed for DSL construction time.
* These function are only available when link libtvm.
* If only TVM runtime is linked, calling these function will trigger error.
*
* \note Most API functions are registerd as PackedFunc and
* can be grabbed via TVMFuncGetGlobal
*/
#ifndef TVM_C_API_H_
#define TVM_C_API_H_
#ifndef TVM_C_DSL_API_H_
#define TVM_C_DSL_API_H_
#include "./runtime/c_runtime_api.h"
#ifdef __cplusplus
TVM_EXTERN_C {
#endif
/*! \brief handle to node */
typedef void* NodeHandle;
/*!
* \brief Inplace translate callback argument value to return value.
* This is only needed for non-POD arguments.
*
* \param value The value to be translated.
* \param code The type code to be translated.
* \note This function will do a shallow copy when necessary.
*
* \return 0 when success, -1 when failure happens.
*/
TVM_DLL int TVMCbArgToReturn(TVMValue* value, int code);
/*!
* \brief free the node handle
* \param handle The node handle to be freed.
* \return 0 when success, -1 when failure happens
......@@ -82,5 +74,7 @@ TVM_DLL int TVMNodeGetAttr(NodeHandle handle,
TVM_DLL int TVMNodeListAttrNames(NodeHandle handle,
int *out_size,
const char*** out_array);
#ifdef __cplusplus
} // TVM_EXTERN_C
#endif // TVM_C_API_H_
#endif
#endif // TVM_C_DSL_API_H_
......@@ -12,6 +12,7 @@
#include <ir/IREquality.h>
#include <arithmetic/Simplify.h>
#include <tvm/ir_functor.h>
#include <arithmetic/Simplify.h>
#include <unordered_map>
#include <vector>
#include <string>
......
......@@ -5,9 +5,15 @@
*
* The philosophy of TVM project is to customize the compilation
* stage to generate code that can used by other projects transparently.
*
* So this is a minimum runtime code gluing, and some limited
* memory management code to enable quick testing.
*
* The runtime API is independent from TVM compilation stack and can
* be linked via libtvm_runtime.
*
* The common flow is:
* - Use TVMFuncListGlobalNames to get global function name
* - Use TVMFuncCall to call these functions.
*/
#ifndef TVM_RUNTIME_C_RUNTIME_API_H_
#define TVM_RUNTIME_C_RUNTIME_API_H_
......@@ -243,6 +249,18 @@ TVM_DLL int TVMCFuncSetReturn(TVMRetValueHandle ret,
int type_code);
/*!
* \brief Inplace translate callback argument value to return value.
* This is only needed for non-POD arguments.
*
* \param value The value to be translated.
* \param code The type code to be translated.
* \note This function will do a shallow copy when necessary.
*
* \return 0 when success, -1 when failure happens.
*/
TVM_DLL int TVMCbArgToReturn(TVMValue* value, int code);
/*!
* \brief C type of packed function.
*
* \param args The arguments
......@@ -378,6 +396,28 @@ TVM_DLL int TVMArrayAlloc(const tvm_index_t* shape,
TVM_DLL int TVMArrayFree(TVMArrayHandle handle);
/*!
* \brief Copy array data from CPU byte array.
* \param handle The array handle.
* \param data the data pointer
* \param nbytes The number of bytes to copy.
* \return 0 when success, -1 when failure happens
*/
TVM_DLL int TVMArrayCopyFromBytes(TVMArrayHandle handle,
void* data,
size_t nbytes);
/*!
* \brief Copy array data to CPU byte array.
* \param handle The array handle.
* \param data the data pointer
* \param nbytes The number of bytes to copy.
* \return 0 when success, -1 when failure happens
*/
TVM_DLL int TVMArrayCopyToBytes(TVMArrayHandle handle,
void* data,
size_t nbytes);
/*!
* \brief Copy the array, both from and to must be valid during the copy.
* \param from The array to be copied from.
* \param to The target space.
......
......@@ -79,6 +79,7 @@ cdef extern from "tvm/runtime/c_runtime_api.h":
void* resource_handle,
TVMPackedCFuncFinalizer fin,
TVMFunctionHandle *out)
int TVMCbArgToReturn(TVMValue* value, int code)
int TVMArrayAlloc(tvm_index_t* shape,
tvm_index_t ndim,
DLDataType dtype,
......@@ -89,8 +90,7 @@ cdef extern from "tvm/runtime/c_runtime_api.h":
DLTensorHandle to,
TVMStreamHandle stream)
cdef extern from "tvm/c_api.h":
int TVMCbArgToReturn(TVMValue* value, int code)
cdef extern from "tvm/c_dsl_api.h":
int TVMNodeFree(NodeHandle handle)
TVMNodeTypeKey2Index(const char* type_key,
int* out_index)
......
......@@ -60,7 +60,6 @@ def context(dev_type, dev_id=0):
dev_type = TVMContext.STR2MASK[dev_type]
return TVMContext(dev_type, dev_id)
def numpyasarray(np_data):
"""Return a TVMArray representation of a numpy array.
"""
......@@ -158,11 +157,10 @@ class NDArrayBase(_NDArrayBase):
source_array = np.ascontiguousarray(source_array, dtype=self.dtype)
if source_array.shape != self.shape:
raise ValueError('array shape do not match the shape of NDArray')
source_tvm_arr, shape = numpyasarray(source_array)
check_call(_LIB.TVMArrayCopyFromTo(
ctypes.byref(source_tvm_arr), self.handle, None))
# de-allocate shape until now
_ = shape
assert source_array.flags['C_CONTIGUOUS']
data = source_array.ctypes.data_as(ctypes.c_void_p)
nbytes = ctypes.c_size_t(np.prod(source_array.shape) * source_array.dtype.itemsize)
check_call(_LIB.TVMArrayCopyFromBytes(self.handle, data, nbytes))
def asnumpy(self):
"""Convert this array to numpy array
......@@ -173,10 +171,10 @@ class NDArrayBase(_NDArrayBase):
The corresponding numpy array.
"""
np_arr = np.empty(self.shape, dtype=self.dtype)
tvm_arr, shape = numpyasarray(np_arr)
check_call(_LIB.TVMArrayCopyFromTo(
self.handle, ctypes.byref(tvm_arr), None))
_ = shape
assert np_arr.flags['C_CONTIGUOUS']
data = np_arr.ctypes.data_as(ctypes.c_void_p)
nbytes = ctypes.c_size_t(np.prod(np_arr.shape) * np_arr.dtype.itemsize)
check_call(_LIB.TVMArrayCopyToBytes(self.handle, data, nbytes))
return np_arr
def copyto(self, target):
......
......@@ -81,11 +81,9 @@ def register_node(type_key=None):
def register(cls):
"""internal register function"""
tindex = ctypes.c_int()
try:
check_call(_LIB.TVMNodeTypeKey2Index(c_str(node_name), ctypes.byref(tindex)))
ret = _LIB.TVMNodeTypeKey2Index(c_str(node_name), ctypes.byref(tindex))
if ret == 0:
_register_node(tindex.value, cls)
except AttributeError:
pass
return cls
if isinstance(type_key, str):
......
......@@ -390,7 +390,7 @@ def decl_buffer(shape,
strides=None,
elem_offset=None,
scope="",
data_alignment=0,
data_alignment=-1,
offset_factor=0):
"""Decleare a new symbolic buffer.
......@@ -426,7 +426,7 @@ def decl_buffer(shape,
data_alignment: int, optional
The alignment of data pointer in bytes.
If 0 is passed, the alignment will be set to TVM's internal default.
If -1 is passed, the alignment will be set to TVM's internal default.
offset_factor: int, optional
The factor of elem_offset field, when set,
......
......@@ -28,7 +28,7 @@ class BuildConfig(object):
'unroll_explicit': True,
'detect_global_barrier': False,
'offset_factor': 0,
'data_alignment': 0,
'data_alignment': -1,
'restricted_func': True
}
def __init__(self, **kwargs):
......@@ -81,7 +81,7 @@ def build_config(**kwargs):
data_alignment: int, optional
The alignment of data pointer in bytes.
If 0 is passed, the alignment will be set to TVM's internal default.
If -1 is passed, the alignment will be set to TVM's internal default.
offset_factor: int, default=0
The factor used in default buffer declaration.
......
/*!
* Copyright (c) 2016 by Contributors
* Implementation of C API
* \file c_api.cc
* Implementation of DSL API
* \file dsl_api.cc
*/
#include <dmlc/base.h>
#include <dmlc/logging.h>
#include <dmlc/thread_local.h>
#include <tvm/c_api.h>
#include <tvm/api_registry.h>
#include <vector>
#include <string>
#include <exception>
#include "../runtime/runtime_base.h"
#include "../runtime/dsl_api.h"
namespace tvm {
namespace runtime {
/*! \brief entry to to easily hold returning information */
struct TVMAPIThreadLocalEntry {
/*! \brief result holder for returning strings */
......@@ -24,8 +24,6 @@ struct TVMAPIThreadLocalEntry {
std::string ret_str;
};
using namespace tvm;
/*! \brief Thread local store that can be used to hold return values. */
typedef dmlc::ThreadLocalStore<TVMAPIThreadLocalEntry> TVMAPIThreadLocalStore;
......@@ -96,44 +94,25 @@ struct APIAttrDir : public AttrVisitor {
}
};
int TVMNodeFree(NodeHandle handle) {
API_BEGIN();
class DSLAPIImpl : public DSLAPI {
public:
void NodeFree(NodeHandle handle) const final {
delete static_cast<TVMAPINode*>(handle);
API_END();
}
int TVMCbArgToReturn(TVMValue* value, int code) {
API_BEGIN();
tvm::runtime::TVMRetValue rv;
rv = tvm::runtime::TVMArgValue(*value, code);
int tcode;
rv.MoveToCHost(value, &tcode);
CHECK_EQ(tcode, code);
API_END();
}
int TVMNodeTypeKey2Index(const char* type_key,
int* out_index) {
API_BEGIN();
}
void NodeTypeKey2Index(const char* type_key,
int* out_index) const final {
*out_index = static_cast<int>(Node::TypeKey2Index(type_key));
API_END();
}
int TVMNodeGetTypeIndex(NodeHandle handle,
int* out_index) {
API_BEGIN();
}
void NodeGetTypeIndex(NodeHandle handle,
int* out_index) const final {
*out_index = static_cast<int>(
(*static_cast<TVMAPINode*>(handle))->type_index());
API_END();
}
int TVMNodeGetAttr(NodeHandle handle,
}
void NodeGetAttr(NodeHandle handle,
const char* key,
TVMValue* ret_val,
int* ret_type_code,
int* ret_success) {
API_BEGIN();
int* ret_success) const final {
TVMRetValue rv;
APIAttrGetter getter;
getter.skey = key;
......@@ -156,14 +135,11 @@ int TVMNodeGetAttr(NodeHandle handle,
rv.MoveToCHost(ret_val, ret_type_code);
}
}
API_END();
}
int TVMNodeListAttrNames(NodeHandle handle,
}
void NodeListAttrNames(NodeHandle handle,
int *out_size,
const char*** out_array) {
const char*** out_array) const final {
TVMAPIThreadLocalEntry *ret = TVMAPIThreadLocalStore::Get();
API_BEGIN();
ret->ret_vec_str.clear();
TVMAPINode* tnode = static_cast<TVMAPINode*>(handle);
APIAttrDir dir;
......@@ -175,5 +151,14 @@ int TVMNodeListAttrNames(NodeHandle handle,
}
*out_array = dmlc::BeginPtr(ret->ret_vec_charp);
*out_size = static_cast<int>(ret->ret_vec_str.size());
API_END();
}
}
};
TVM_REGISTER_GLOBAL("dsl_api.singleton")
.set_body([](TVMArgs args, TVMRetValue* rv) {
static DSLAPIImpl impl;
void* ptr = &impl;
*rv = ptr;
});
} // namespace runtime
} // namespace tvm
......@@ -153,7 +153,7 @@ Buffer BufferNode::make(Var data,
if (!elem_offset.defined()) {
elem_offset = make_const(n->shape[0].type(), 0);
}
if (data_alignment == 0) {
if (data_alignment <= 0) {
data_alignment = runtime::kAllocAlignment;
}
if (offset_factor == 0) {
......
......@@ -16,15 +16,15 @@ namespace ir {
void BinderAddAssert(Expr cond,
const std::string& arg_name,
std::vector<Stmt>* asserts) {
cond = Simplify(cond);
if (is_zero(cond)) {
Expr scond = Simplify(cond);
if (is_zero(scond)) {
LOG(FATAL) << "Bind have an unmet assertion: "
<< cond << ", " << " on argument " << arg_name;
}
if (!is_one(cond)) {
if (!is_one(scond)) {
std::ostringstream os;
os << "Argument " << arg_name << " has an unsatisfied constraint";
asserts->emplace_back(AssertStmt::make(cond, os.str(), Evaluate::make(0)));
asserts->emplace_back(AssertStmt::make(scond, os.str(), Evaluate::make(0)));
}
}
......
......@@ -103,9 +103,11 @@ class StorageFlattener : public IRMutator {
} else {
skey = StorageScope::make(strkey);
}
// use small alignment for small arrays
int32_t const_size = Allocate::constant_allocation_size(shape, key.GetName());
int align = GetTempAllocaAlignment(op->type, const_size);
e.buffer = BufferNode::make(
Var(key.GetName(), Handle()),
op->type, shape,
......
/*!
* Copyright (c) 2017 by Contributors
* \file cpu_dsl_api.cc
* \brief DSL API dispatcher
*/
#include <tvm/runtime/registry.h>
#include <tvm/c_dsl_api.h>
#include "./dsl_api.h"
#include "./runtime_base.h"
namespace tvm {
namespace runtime {
DSLAPI* FindDSLAPI() {
auto* f = Registry::Get("dsl_api.singleton");
if (f == nullptr) {
throw dmlc::Error("TVM runtime only environment, "\
"DSL API is not available");
}
void* ptr = (*f)();
return static_cast<DSLAPI*>(ptr);
}
static DSLAPI* GetDSLAPI() {
static DSLAPI* inst = FindDSLAPI();
return inst;
}
} // namespace runtime
} // namespace tvm
using namespace tvm::runtime;
int TVMNodeFree(NodeHandle handle) {
API_BEGIN();
GetDSLAPI()->NodeFree(handle);
API_END();
}
int TVMNodeTypeKey2Index(const char* type_key,
int* out_index) {
API_BEGIN();
GetDSLAPI()->NodeTypeKey2Index(type_key, out_index);
API_END();
}
int TVMNodeGetTypeIndex(NodeHandle handle,
int* out_index) {
API_BEGIN();
GetDSLAPI()->NodeGetTypeIndex(handle, out_index);
API_END();
}
int TVMNodeGetAttr(NodeHandle handle,
const char* key,
TVMValue* out_value,
int* out_type_code,
int* out_success) {
API_BEGIN();
GetDSLAPI()->NodeGetAttr(
handle, key, out_value, out_type_code, out_success);
API_END();
}
int TVMNodeListAttrNames(NodeHandle handle,
int *out_size,
const char*** out_array) {
API_BEGIN();
GetDSLAPI()->NodeListAttrNames(
handle, out_size, out_array);
API_END();
}
......@@ -130,7 +130,7 @@ inline size_t GetDataSize(TVMArray* arr) {
for (tvm_index_t i = 0; i < arr->ndim; ++i) {
size *= arr->shape[i];
}
size *= (arr->dtype.bits / 8) * arr->dtype.lanes;
size *= (arr->dtype.bits * arr->dtype.lanes + 7) / 8;
return size;
}
......@@ -394,6 +394,40 @@ int TVMArrayCopyFromTo(TVMArrayHandle from,
API_END();
}
int TVMArrayCopyFromBytes(TVMArrayHandle handle,
void* data,
size_t nbytes) {
API_BEGIN();
TVMContext cpu_ctx;
cpu_ctx.device_type = kCPU;
cpu_ctx.device_id = 0;
size_t arr_size = GetDataSize(handle);
CHECK_EQ(arr_size, nbytes)
<< "TVMArrayCopyFromBytes: size mismatch";
DeviceAPIManager::Get(handle->ctx)->CopyDataFromTo(
data, 0,
handle->data, handle->byte_offset,
nbytes, cpu_ctx, handle->ctx, nullptr);
API_END();
}
int TVMArrayCopyToBytes(TVMArrayHandle handle,
void* data,
size_t nbytes) {
API_BEGIN();
TVMContext cpu_ctx;
cpu_ctx.device_type = kCPU;
cpu_ctx.device_id = 0;
size_t arr_size = GetDataSize(handle);
CHECK_EQ(arr_size, nbytes)
<< "TVMArrayCopyToBytes: size mismatch";
DeviceAPIManager::Get(handle->ctx)->CopyDataFromTo(
handle->data, handle->byte_offset,
data, 0,
nbytes, handle->ctx, cpu_ctx, nullptr);
API_END();
}
int TVMSetStream(TVMContext ctx, TVMStreamHandle stream) {
API_BEGIN();
DeviceAPIManager::Get(ctx)->SetStream(ctx, stream);
......@@ -406,6 +440,16 @@ int TVMSynchronize(TVMContext ctx, TVMStreamHandle stream) {
API_END();
}
int TVMCbArgToReturn(TVMValue* value, int code) {
API_BEGIN();
tvm::runtime::TVMRetValue rv;
rv = tvm::runtime::TVMArgValue(*value, code);
int tcode;
rv.MoveToCHost(value, &tcode);
CHECK_EQ(tcode, code);
API_END();
}
// set device api
TVM_REGISTER_GLOBAL(tvm::runtime::symbol::tvm_set_device)
.set_body([](TVMArgs args, TVMRetValue *ret) {
......
/*!
* Copyright (c) 2017 by Contributors
* \file cpu_dsl_api.cc
* \brief DSL API dispatcher
*/
#ifndef TVM_RUNTIME_DSL_API_H_
#define TVM_RUNTIME_DSL_API_H_
#include <tvm/c_dsl_api.h>
namespace tvm {
namespace runtime {
/*!
* \brief Common interface for DSL API
* Used for runtime registration
*/
class DSLAPI {
public:
virtual void NodeFree(NodeHandle handle) const = 0;
virtual void NodeTypeKey2Index(const char* type_key,
int* out_index) const = 0;
virtual void NodeGetTypeIndex(NodeHandle handle,
int* out_index) const = 0;
virtual void NodeGetAttr(NodeHandle handle,
const char* key,
TVMValue* out_value,
int* out_type_code,
int* out_success) const = 0;
virtual void NodeListAttrNames(NodeHandle handle,
int *out_size,
const char*** out_array) const = 0;
};
} // namespace runtime
} // namespace tvm
#endif // TVM_RUNTIME_DSL_API_H_
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment