Commit 28120f55 by Tianqi Chen Committed by GitHub

[C API] Make DSL API registerable, add copy from/to raw bytes (#222)

* [C API] Make DSL API registerable, add copy from/to raw bytes

* fix cython
parent 0a19b16a
/*! /*!
* Copyright (c) 2016 by Contributors * Copyright (c) 2016 by Contributors
* \file c_api.h * \file c_dsl_api.h
* \brief C API of TVM DSL
* *
* \note The API is designed in a minimum way. * \brief TVM DSL Node C API, used to interact to DSL compilation.
* Most of the API functions are registered and can be pulled out.
* *
* The common flow is: * These are only a few functions needed for DSL construction time.
* - Use TVMFuncListGlobalNames to get global function name * These function are only available when link libtvm.
* - Use TVMFuncCall to call these functions. * If only TVM runtime is linked, calling these function will trigger error.
*
* \note Most API functions are registerd as PackedFunc and
* can be grabbed via TVMFuncGetGlobal
*/ */
#ifndef TVM_C_API_H_ #ifndef TVM_C_DSL_API_H_
#define TVM_C_API_H_ #define TVM_C_DSL_API_H_
#include "./runtime/c_runtime_api.h" #include "./runtime/c_runtime_api.h"
#ifdef __cplusplus
TVM_EXTERN_C { TVM_EXTERN_C {
#endif
/*! \brief handle to node */ /*! \brief handle to node */
typedef void* NodeHandle; typedef void* NodeHandle;
/*! /*!
* \brief Inplace translate callback argument value to return value.
* This is only needed for non-POD arguments.
*
* \param value The value to be translated.
* \param code The type code to be translated.
* \note This function will do a shallow copy when necessary.
*
* \return 0 when success, -1 when failure happens.
*/
TVM_DLL int TVMCbArgToReturn(TVMValue* value, int code);
/*!
* \brief free the node handle * \brief free the node handle
* \param handle The node handle to be freed. * \param handle The node handle to be freed.
* \return 0 when success, -1 when failure happens * \return 0 when success, -1 when failure happens
...@@ -82,5 +74,7 @@ TVM_DLL int TVMNodeGetAttr(NodeHandle handle, ...@@ -82,5 +74,7 @@ TVM_DLL int TVMNodeGetAttr(NodeHandle handle,
TVM_DLL int TVMNodeListAttrNames(NodeHandle handle, TVM_DLL int TVMNodeListAttrNames(NodeHandle handle,
int *out_size, int *out_size,
const char*** out_array); const char*** out_array);
#ifdef __cplusplus
} // TVM_EXTERN_C } // TVM_EXTERN_C
#endif // TVM_C_API_H_ #endif
#endif // TVM_C_DSL_API_H_
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
#include <ir/IREquality.h> #include <ir/IREquality.h>
#include <arithmetic/Simplify.h> #include <arithmetic/Simplify.h>
#include <tvm/ir_functor.h> #include <tvm/ir_functor.h>
#include <arithmetic/Simplify.h>
#include <unordered_map> #include <unordered_map>
#include <vector> #include <vector>
#include <string> #include <string>
......
...@@ -5,9 +5,15 @@ ...@@ -5,9 +5,15 @@
* *
* The philosophy of TVM project is to customize the compilation * The philosophy of TVM project is to customize the compilation
* stage to generate code that can used by other projects transparently. * stage to generate code that can used by other projects transparently.
*
* So this is a minimum runtime code gluing, and some limited * So this is a minimum runtime code gluing, and some limited
* memory management code to enable quick testing. * memory management code to enable quick testing.
*
* The runtime API is independent from TVM compilation stack and can
* be linked via libtvm_runtime.
*
* The common flow is:
* - Use TVMFuncListGlobalNames to get global function name
* - Use TVMFuncCall to call these functions.
*/ */
#ifndef TVM_RUNTIME_C_RUNTIME_API_H_ #ifndef TVM_RUNTIME_C_RUNTIME_API_H_
#define TVM_RUNTIME_C_RUNTIME_API_H_ #define TVM_RUNTIME_C_RUNTIME_API_H_
...@@ -243,6 +249,18 @@ TVM_DLL int TVMCFuncSetReturn(TVMRetValueHandle ret, ...@@ -243,6 +249,18 @@ TVM_DLL int TVMCFuncSetReturn(TVMRetValueHandle ret,
int type_code); int type_code);
/*! /*!
* \brief Inplace translate callback argument value to return value.
* This is only needed for non-POD arguments.
*
* \param value The value to be translated.
* \param code The type code to be translated.
* \note This function will do a shallow copy when necessary.
*
* \return 0 when success, -1 when failure happens.
*/
TVM_DLL int TVMCbArgToReturn(TVMValue* value, int code);
/*!
* \brief C type of packed function. * \brief C type of packed function.
* *
* \param args The arguments * \param args The arguments
...@@ -378,6 +396,28 @@ TVM_DLL int TVMArrayAlloc(const tvm_index_t* shape, ...@@ -378,6 +396,28 @@ TVM_DLL int TVMArrayAlloc(const tvm_index_t* shape,
TVM_DLL int TVMArrayFree(TVMArrayHandle handle); TVM_DLL int TVMArrayFree(TVMArrayHandle handle);
/*! /*!
* \brief Copy array data from CPU byte array.
* \param handle The array handle.
* \param data the data pointer
* \param nbytes The number of bytes to copy.
* \return 0 when success, -1 when failure happens
*/
TVM_DLL int TVMArrayCopyFromBytes(TVMArrayHandle handle,
void* data,
size_t nbytes);
/*!
* \brief Copy array data to CPU byte array.
* \param handle The array handle.
* \param data the data pointer
* \param nbytes The number of bytes to copy.
* \return 0 when success, -1 when failure happens
*/
TVM_DLL int TVMArrayCopyToBytes(TVMArrayHandle handle,
void* data,
size_t nbytes);
/*!
* \brief Copy the array, both from and to must be valid during the copy. * \brief Copy the array, both from and to must be valid during the copy.
* \param from The array to be copied from. * \param from The array to be copied from.
* \param to The target space. * \param to The target space.
......
...@@ -79,6 +79,7 @@ cdef extern from "tvm/runtime/c_runtime_api.h": ...@@ -79,6 +79,7 @@ cdef extern from "tvm/runtime/c_runtime_api.h":
void* resource_handle, void* resource_handle,
TVMPackedCFuncFinalizer fin, TVMPackedCFuncFinalizer fin,
TVMFunctionHandle *out) TVMFunctionHandle *out)
int TVMCbArgToReturn(TVMValue* value, int code)
int TVMArrayAlloc(tvm_index_t* shape, int TVMArrayAlloc(tvm_index_t* shape,
tvm_index_t ndim, tvm_index_t ndim,
DLDataType dtype, DLDataType dtype,
...@@ -89,8 +90,7 @@ cdef extern from "tvm/runtime/c_runtime_api.h": ...@@ -89,8 +90,7 @@ cdef extern from "tvm/runtime/c_runtime_api.h":
DLTensorHandle to, DLTensorHandle to,
TVMStreamHandle stream) TVMStreamHandle stream)
cdef extern from "tvm/c_api.h": cdef extern from "tvm/c_dsl_api.h":
int TVMCbArgToReturn(TVMValue* value, int code)
int TVMNodeFree(NodeHandle handle) int TVMNodeFree(NodeHandle handle)
TVMNodeTypeKey2Index(const char* type_key, TVMNodeTypeKey2Index(const char* type_key,
int* out_index) int* out_index)
......
...@@ -60,7 +60,6 @@ def context(dev_type, dev_id=0): ...@@ -60,7 +60,6 @@ def context(dev_type, dev_id=0):
dev_type = TVMContext.STR2MASK[dev_type] dev_type = TVMContext.STR2MASK[dev_type]
return TVMContext(dev_type, dev_id) return TVMContext(dev_type, dev_id)
def numpyasarray(np_data): def numpyasarray(np_data):
"""Return a TVMArray representation of a numpy array. """Return a TVMArray representation of a numpy array.
""" """
...@@ -158,11 +157,10 @@ class NDArrayBase(_NDArrayBase): ...@@ -158,11 +157,10 @@ class NDArrayBase(_NDArrayBase):
source_array = np.ascontiguousarray(source_array, dtype=self.dtype) source_array = np.ascontiguousarray(source_array, dtype=self.dtype)
if source_array.shape != self.shape: if source_array.shape != self.shape:
raise ValueError('array shape do not match the shape of NDArray') raise ValueError('array shape do not match the shape of NDArray')
source_tvm_arr, shape = numpyasarray(source_array) assert source_array.flags['C_CONTIGUOUS']
check_call(_LIB.TVMArrayCopyFromTo( data = source_array.ctypes.data_as(ctypes.c_void_p)
ctypes.byref(source_tvm_arr), self.handle, None)) nbytes = ctypes.c_size_t(np.prod(source_array.shape) * source_array.dtype.itemsize)
# de-allocate shape until now check_call(_LIB.TVMArrayCopyFromBytes(self.handle, data, nbytes))
_ = shape
def asnumpy(self): def asnumpy(self):
"""Convert this array to numpy array """Convert this array to numpy array
...@@ -173,10 +171,10 @@ class NDArrayBase(_NDArrayBase): ...@@ -173,10 +171,10 @@ class NDArrayBase(_NDArrayBase):
The corresponding numpy array. The corresponding numpy array.
""" """
np_arr = np.empty(self.shape, dtype=self.dtype) np_arr = np.empty(self.shape, dtype=self.dtype)
tvm_arr, shape = numpyasarray(np_arr) assert np_arr.flags['C_CONTIGUOUS']
check_call(_LIB.TVMArrayCopyFromTo( data = np_arr.ctypes.data_as(ctypes.c_void_p)
self.handle, ctypes.byref(tvm_arr), None)) nbytes = ctypes.c_size_t(np.prod(np_arr.shape) * np_arr.dtype.itemsize)
_ = shape check_call(_LIB.TVMArrayCopyToBytes(self.handle, data, nbytes))
return np_arr return np_arr
def copyto(self, target): def copyto(self, target):
......
...@@ -81,11 +81,9 @@ def register_node(type_key=None): ...@@ -81,11 +81,9 @@ def register_node(type_key=None):
def register(cls): def register(cls):
"""internal register function""" """internal register function"""
tindex = ctypes.c_int() tindex = ctypes.c_int()
try: ret = _LIB.TVMNodeTypeKey2Index(c_str(node_name), ctypes.byref(tindex))
check_call(_LIB.TVMNodeTypeKey2Index(c_str(node_name), ctypes.byref(tindex))) if ret == 0:
_register_node(tindex.value, cls) _register_node(tindex.value, cls)
except AttributeError:
pass
return cls return cls
if isinstance(type_key, str): if isinstance(type_key, str):
......
...@@ -390,7 +390,7 @@ def decl_buffer(shape, ...@@ -390,7 +390,7 @@ def decl_buffer(shape,
strides=None, strides=None,
elem_offset=None, elem_offset=None,
scope="", scope="",
data_alignment=0, data_alignment=-1,
offset_factor=0): offset_factor=0):
"""Decleare a new symbolic buffer. """Decleare a new symbolic buffer.
...@@ -426,7 +426,7 @@ def decl_buffer(shape, ...@@ -426,7 +426,7 @@ def decl_buffer(shape,
data_alignment: int, optional data_alignment: int, optional
The alignment of data pointer in bytes. The alignment of data pointer in bytes.
If 0 is passed, the alignment will be set to TVM's internal default. If -1 is passed, the alignment will be set to TVM's internal default.
offset_factor: int, optional offset_factor: int, optional
The factor of elem_offset field, when set, The factor of elem_offset field, when set,
......
...@@ -28,7 +28,7 @@ class BuildConfig(object): ...@@ -28,7 +28,7 @@ class BuildConfig(object):
'unroll_explicit': True, 'unroll_explicit': True,
'detect_global_barrier': False, 'detect_global_barrier': False,
'offset_factor': 0, 'offset_factor': 0,
'data_alignment': 0, 'data_alignment': -1,
'restricted_func': True 'restricted_func': True
} }
def __init__(self, **kwargs): def __init__(self, **kwargs):
...@@ -81,7 +81,7 @@ def build_config(**kwargs): ...@@ -81,7 +81,7 @@ def build_config(**kwargs):
data_alignment: int, optional data_alignment: int, optional
The alignment of data pointer in bytes. The alignment of data pointer in bytes.
If 0 is passed, the alignment will be set to TVM's internal default. If -1 is passed, the alignment will be set to TVM's internal default.
offset_factor: int, default=0 offset_factor: int, default=0
The factor used in default buffer declaration. The factor used in default buffer declaration.
......
/*! /*!
* Copyright (c) 2016 by Contributors * Copyright (c) 2016 by Contributors
* Implementation of C API * Implementation of DSL API
* \file c_api.cc * \file dsl_api.cc
*/ */
#include <dmlc/base.h> #include <dmlc/base.h>
#include <dmlc/logging.h> #include <dmlc/logging.h>
#include <dmlc/thread_local.h> #include <dmlc/thread_local.h>
#include <tvm/c_api.h>
#include <tvm/api_registry.h> #include <tvm/api_registry.h>
#include <vector> #include <vector>
#include <string> #include <string>
#include <exception> #include <exception>
#include "../runtime/runtime_base.h" #include "../runtime/dsl_api.h"
namespace tvm {
namespace runtime {
/*! \brief entry to to easily hold returning information */ /*! \brief entry to to easily hold returning information */
struct TVMAPIThreadLocalEntry { struct TVMAPIThreadLocalEntry {
/*! \brief result holder for returning strings */ /*! \brief result holder for returning strings */
...@@ -24,8 +24,6 @@ struct TVMAPIThreadLocalEntry { ...@@ -24,8 +24,6 @@ struct TVMAPIThreadLocalEntry {
std::string ret_str; std::string ret_str;
}; };
using namespace tvm;
/*! \brief Thread local store that can be used to hold return values. */ /*! \brief Thread local store that can be used to hold return values. */
typedef dmlc::ThreadLocalStore<TVMAPIThreadLocalEntry> TVMAPIThreadLocalStore; typedef dmlc::ThreadLocalStore<TVMAPIThreadLocalEntry> TVMAPIThreadLocalStore;
...@@ -96,84 +94,71 @@ struct APIAttrDir : public AttrVisitor { ...@@ -96,84 +94,71 @@ struct APIAttrDir : public AttrVisitor {
} }
}; };
class DSLAPIImpl : public DSLAPI {
int TVMNodeFree(NodeHandle handle) { public:
API_BEGIN(); void NodeFree(NodeHandle handle) const final {
delete static_cast<TVMAPINode*>(handle); delete static_cast<TVMAPINode*>(handle);
API_END(); }
} void NodeTypeKey2Index(const char* type_key,
int* out_index) const final {
int TVMCbArgToReturn(TVMValue* value, int code) { *out_index = static_cast<int>(Node::TypeKey2Index(type_key));
API_BEGIN(); }
tvm::runtime::TVMRetValue rv; void NodeGetTypeIndex(NodeHandle handle,
rv = tvm::runtime::TVMArgValue(*value, code); int* out_index) const final {
int tcode; *out_index = static_cast<int>(
rv.MoveToCHost(value, &tcode); (*static_cast<TVMAPINode*>(handle))->type_index());
CHECK_EQ(tcode, code); }
API_END(); void NodeGetAttr(NodeHandle handle,
} const char* key,
TVMValue* ret_val,
int TVMNodeTypeKey2Index(const char* type_key, int* ret_type_code,
int* out_index) { int* ret_success) const final {
API_BEGIN(); TVMRetValue rv;
*out_index = static_cast<int>(Node::TypeKey2Index(type_key)); APIAttrGetter getter;
API_END(); getter.skey = key;
} getter.ret = &rv;
TVMAPINode* tnode = static_cast<TVMAPINode*>(handle);
int TVMNodeGetTypeIndex(NodeHandle handle, if (getter.skey == "type_key") {
int* out_index) { ret_val->v_str = (*tnode)->type_key();
API_BEGIN();
*out_index = static_cast<int>(
(*static_cast<TVMAPINode*>(handle))->type_index());
API_END();
}
int TVMNodeGetAttr(NodeHandle handle,
const char* key,
TVMValue* ret_val,
int* ret_type_code,
int* ret_success) {
API_BEGIN();
TVMRetValue rv;
APIAttrGetter getter;
getter.skey = key;
getter.ret = &rv;
TVMAPINode* tnode = static_cast<TVMAPINode*>(handle);
if (getter.skey == "type_key") {
ret_val->v_str = (*tnode)->type_key();
*ret_type_code = kStr;
*ret_success = 1;
} else {
(*tnode)->VisitAttrs(&getter);
*ret_success = getter.found_node_ref || rv.type_code() != kNull;
if (rv.type_code() == kStr ||
rv.type_code() == kTVMType) {
TVMAPIThreadLocalEntry *e = TVMAPIThreadLocalStore::Get();
e->ret_str = rv.operator std::string();
*ret_type_code = kStr; *ret_type_code = kStr;
ret_val->v_str = e->ret_str.c_str(); *ret_success = 1;
} else { } else {
rv.MoveToCHost(ret_val, ret_type_code); (*tnode)->VisitAttrs(&getter);
*ret_success = getter.found_node_ref || rv.type_code() != kNull;
if (rv.type_code() == kStr ||
rv.type_code() == kTVMType) {
TVMAPIThreadLocalEntry *e = TVMAPIThreadLocalStore::Get();
e->ret_str = rv.operator std::string();
*ret_type_code = kStr;
ret_val->v_str = e->ret_str.c_str();
} else {
rv.MoveToCHost(ret_val, ret_type_code);
}
} }
} }
API_END(); void NodeListAttrNames(NodeHandle handle,
} int *out_size,
const char*** out_array) const final {
TVMAPIThreadLocalEntry *ret = TVMAPIThreadLocalStore::Get();
ret->ret_vec_str.clear();
TVMAPINode* tnode = static_cast<TVMAPINode*>(handle);
APIAttrDir dir;
dir.names = &(ret->ret_vec_str);
(*tnode)->VisitAttrs(&dir);
ret->ret_vec_charp.clear();
for (size_t i = 0; i < ret->ret_vec_str.size(); ++i) {
ret->ret_vec_charp.push_back(ret->ret_vec_str[i].c_str());
}
*out_array = dmlc::BeginPtr(ret->ret_vec_charp);
*out_size = static_cast<int>(ret->ret_vec_str.size());
}
};
int TVMNodeListAttrNames(NodeHandle handle, TVM_REGISTER_GLOBAL("dsl_api.singleton")
int *out_size, .set_body([](TVMArgs args, TVMRetValue* rv) {
const char*** out_array) { static DSLAPIImpl impl;
TVMAPIThreadLocalEntry *ret = TVMAPIThreadLocalStore::Get(); void* ptr = &impl;
API_BEGIN(); *rv = ptr;
ret->ret_vec_str.clear(); });
TVMAPINode* tnode = static_cast<TVMAPINode*>(handle); } // namespace runtime
APIAttrDir dir; } // namespace tvm
dir.names = &(ret->ret_vec_str);
(*tnode)->VisitAttrs(&dir);
ret->ret_vec_charp.clear();
for (size_t i = 0; i < ret->ret_vec_str.size(); ++i) {
ret->ret_vec_charp.push_back(ret->ret_vec_str[i].c_str());
}
*out_array = dmlc::BeginPtr(ret->ret_vec_charp);
*out_size = static_cast<int>(ret->ret_vec_str.size());
API_END();
}
...@@ -153,7 +153,7 @@ Buffer BufferNode::make(Var data, ...@@ -153,7 +153,7 @@ Buffer BufferNode::make(Var data,
if (!elem_offset.defined()) { if (!elem_offset.defined()) {
elem_offset = make_const(n->shape[0].type(), 0); elem_offset = make_const(n->shape[0].type(), 0);
} }
if (data_alignment == 0) { if (data_alignment <= 0) {
data_alignment = runtime::kAllocAlignment; data_alignment = runtime::kAllocAlignment;
} }
if (offset_factor == 0) { if (offset_factor == 0) {
......
...@@ -16,15 +16,15 @@ namespace ir { ...@@ -16,15 +16,15 @@ namespace ir {
void BinderAddAssert(Expr cond, void BinderAddAssert(Expr cond,
const std::string& arg_name, const std::string& arg_name,
std::vector<Stmt>* asserts) { std::vector<Stmt>* asserts) {
cond = Simplify(cond); Expr scond = Simplify(cond);
if (is_zero(cond)) { if (is_zero(scond)) {
LOG(FATAL) << "Bind have an unmet assertion: " LOG(FATAL) << "Bind have an unmet assertion: "
<< cond << ", " << " on argument " << arg_name; << cond << ", " << " on argument " << arg_name;
} }
if (!is_one(cond)) { if (!is_one(scond)) {
std::ostringstream os; std::ostringstream os;
os << "Argument " << arg_name << " has an unsatisfied constraint"; os << "Argument " << arg_name << " has an unsatisfied constraint";
asserts->emplace_back(AssertStmt::make(cond, os.str(), Evaluate::make(0))); asserts->emplace_back(AssertStmt::make(scond, os.str(), Evaluate::make(0)));
} }
} }
......
...@@ -103,9 +103,11 @@ class StorageFlattener : public IRMutator { ...@@ -103,9 +103,11 @@ class StorageFlattener : public IRMutator {
} else { } else {
skey = StorageScope::make(strkey); skey = StorageScope::make(strkey);
} }
// use small alignment for small arrays // use small alignment for small arrays
int32_t const_size = Allocate::constant_allocation_size(shape, key.GetName()); int32_t const_size = Allocate::constant_allocation_size(shape, key.GetName());
int align = GetTempAllocaAlignment(op->type, const_size); int align = GetTempAllocaAlignment(op->type, const_size);
e.buffer = BufferNode::make( e.buffer = BufferNode::make(
Var(key.GetName(), Handle()), Var(key.GetName(), Handle()),
op->type, shape, op->type, shape,
......
/*!
* Copyright (c) 2017 by Contributors
* \file cpu_dsl_api.cc
* \brief DSL API dispatcher
*/
#include <tvm/runtime/registry.h>
#include <tvm/c_dsl_api.h>
#include "./dsl_api.h"
#include "./runtime_base.h"
namespace tvm {
namespace runtime {
DSLAPI* FindDSLAPI() {
auto* f = Registry::Get("dsl_api.singleton");
if (f == nullptr) {
throw dmlc::Error("TVM runtime only environment, "\
"DSL API is not available");
}
void* ptr = (*f)();
return static_cast<DSLAPI*>(ptr);
}
static DSLAPI* GetDSLAPI() {
static DSLAPI* inst = FindDSLAPI();
return inst;
}
} // namespace runtime
} // namespace tvm
using namespace tvm::runtime;
int TVMNodeFree(NodeHandle handle) {
API_BEGIN();
GetDSLAPI()->NodeFree(handle);
API_END();
}
int TVMNodeTypeKey2Index(const char* type_key,
int* out_index) {
API_BEGIN();
GetDSLAPI()->NodeTypeKey2Index(type_key, out_index);
API_END();
}
int TVMNodeGetTypeIndex(NodeHandle handle,
int* out_index) {
API_BEGIN();
GetDSLAPI()->NodeGetTypeIndex(handle, out_index);
API_END();
}
int TVMNodeGetAttr(NodeHandle handle,
const char* key,
TVMValue* out_value,
int* out_type_code,
int* out_success) {
API_BEGIN();
GetDSLAPI()->NodeGetAttr(
handle, key, out_value, out_type_code, out_success);
API_END();
}
int TVMNodeListAttrNames(NodeHandle handle,
int *out_size,
const char*** out_array) {
API_BEGIN();
GetDSLAPI()->NodeListAttrNames(
handle, out_size, out_array);
API_END();
}
...@@ -130,7 +130,7 @@ inline size_t GetDataSize(TVMArray* arr) { ...@@ -130,7 +130,7 @@ inline size_t GetDataSize(TVMArray* arr) {
for (tvm_index_t i = 0; i < arr->ndim; ++i) { for (tvm_index_t i = 0; i < arr->ndim; ++i) {
size *= arr->shape[i]; size *= arr->shape[i];
} }
size *= (arr->dtype.bits / 8) * arr->dtype.lanes; size *= (arr->dtype.bits * arr->dtype.lanes + 7) / 8;
return size; return size;
} }
...@@ -394,6 +394,40 @@ int TVMArrayCopyFromTo(TVMArrayHandle from, ...@@ -394,6 +394,40 @@ int TVMArrayCopyFromTo(TVMArrayHandle from,
API_END(); API_END();
} }
int TVMArrayCopyFromBytes(TVMArrayHandle handle,
void* data,
size_t nbytes) {
API_BEGIN();
TVMContext cpu_ctx;
cpu_ctx.device_type = kCPU;
cpu_ctx.device_id = 0;
size_t arr_size = GetDataSize(handle);
CHECK_EQ(arr_size, nbytes)
<< "TVMArrayCopyFromBytes: size mismatch";
DeviceAPIManager::Get(handle->ctx)->CopyDataFromTo(
data, 0,
handle->data, handle->byte_offset,
nbytes, cpu_ctx, handle->ctx, nullptr);
API_END();
}
int TVMArrayCopyToBytes(TVMArrayHandle handle,
void* data,
size_t nbytes) {
API_BEGIN();
TVMContext cpu_ctx;
cpu_ctx.device_type = kCPU;
cpu_ctx.device_id = 0;
size_t arr_size = GetDataSize(handle);
CHECK_EQ(arr_size, nbytes)
<< "TVMArrayCopyToBytes: size mismatch";
DeviceAPIManager::Get(handle->ctx)->CopyDataFromTo(
handle->data, handle->byte_offset,
data, 0,
nbytes, handle->ctx, cpu_ctx, nullptr);
API_END();
}
int TVMSetStream(TVMContext ctx, TVMStreamHandle stream) { int TVMSetStream(TVMContext ctx, TVMStreamHandle stream) {
API_BEGIN(); API_BEGIN();
DeviceAPIManager::Get(ctx)->SetStream(ctx, stream); DeviceAPIManager::Get(ctx)->SetStream(ctx, stream);
...@@ -406,6 +440,16 @@ int TVMSynchronize(TVMContext ctx, TVMStreamHandle stream) { ...@@ -406,6 +440,16 @@ int TVMSynchronize(TVMContext ctx, TVMStreamHandle stream) {
API_END(); API_END();
} }
int TVMCbArgToReturn(TVMValue* value, int code) {
API_BEGIN();
tvm::runtime::TVMRetValue rv;
rv = tvm::runtime::TVMArgValue(*value, code);
int tcode;
rv.MoveToCHost(value, &tcode);
CHECK_EQ(tcode, code);
API_END();
}
// set device api // set device api
TVM_REGISTER_GLOBAL(tvm::runtime::symbol::tvm_set_device) TVM_REGISTER_GLOBAL(tvm::runtime::symbol::tvm_set_device)
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
......
/*!
* Copyright (c) 2017 by Contributors
* \file cpu_dsl_api.cc
* \brief DSL API dispatcher
*/
#ifndef TVM_RUNTIME_DSL_API_H_
#define TVM_RUNTIME_DSL_API_H_
#include <tvm/c_dsl_api.h>
namespace tvm {
namespace runtime {
/*!
* \brief Common interface for DSL API
* Used for runtime registration
*/
class DSLAPI {
public:
virtual void NodeFree(NodeHandle handle) const = 0;
virtual void NodeTypeKey2Index(const char* type_key,
int* out_index) const = 0;
virtual void NodeGetTypeIndex(NodeHandle handle,
int* out_index) const = 0;
virtual void NodeGetAttr(NodeHandle handle,
const char* key,
TVMValue* out_value,
int* out_type_code,
int* out_success) const = 0;
virtual void NodeListAttrNames(NodeHandle handle,
int *out_size,
const char*** out_array) const = 0;
};
} // namespace runtime
} // namespace tvm
#endif // TVM_RUNTIME_DSL_API_H_
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment