Commit cd29c18c by Tianqi Chen Committed by GitHub

[RUNTIME][ABI] Flat structure arguments (#232)

parent 2dec0510
......@@ -173,7 +173,7 @@ void GraphExecutor::SetupStorage() {
TShape shape{static_cast<int64_t>(pool_entry_bytes[i] + 3) / 4};
DLTensor* tensor;
TVM_CCALL(TVMArrayAlloc(
shape.data(), 1, DLDataType{kFloat, 32U, 1U}, ctx_, &tensor));
shape.data(), 1, kFloat, 32, 1, ctx_.device_type, ctx_.device_id, &tensor));
storage_pool_.push_back(tensor);
}
// Assign the pooled entries.
......
......@@ -339,15 +339,21 @@ TVM_DLL int TVMFuncListGlobalNames(int *out_size,
*
* \param shape The shape of the array, the data content will be copied to out
* \param ndim The number of dimension of the array.
* \param dtype The array data type.
* \param ctx The ctx this array sits on.
* \param dtype_code The type code of the dtype
* \param dtype_bits The number of bits of dtype
* \param dtype_lanes The number of lanes in the dtype.
* \param device_type The device type of context
* \param device_id The device id of context.
* \param out The output handle.
* \return 0 when success, -1 when failure happens
*/
TVM_DLL int TVMArrayAlloc(const tvm_index_t* shape,
int ndim,
TVMType dtype,
TVMContext ctx,
int dtype_code,
int dtype_bits,
int dtype_lanes,
int device_type,
int device_id,
TVMArrayHandle* out);
/*!
......@@ -396,19 +402,22 @@ TVM_DLL int TVMArrayCopyFromTo(TVMArrayHandle from,
* will use the setted stream handle.
* The specific type of stream is runtime device dependent.
*
* \param ctx The context.
* \param device_type The device type of context
* \param device_id The device id of context.
* \param handle The stream handle.
* \return 0 when success, -1 when failure happens
*/
TVM_DLL int TVMSetStream(TVMContext ctx, TVMStreamHandle handle);
TVM_DLL int TVMSetStream(int device_type, int device_id, TVMStreamHandle handle);
/*!
* \brief Wait until all computations on stream completes.
* \param ctx The ctx to be synchronized.
*
* \param device_type The device type of context
* \param device_id The device id of context.
* \param stream The stream to be synchronized.
* \return 0 when success, -1 when failure happens
*/
TVM_DLL int TVMSynchronize(TVMContext ctx, TVMStreamHandle stream);
TVM_DLL int TVMSynchronize(int device_type, int device_id, TVMStreamHandle stream);
#ifdef __cplusplus
} // TVM_EXTERN_C
......
......@@ -101,7 +101,13 @@ def empty(shape, dtype="float32", ctx=context(1, 0)):
handle = TVMArrayHandle()
dtype = TVMType(dtype)
check_call(_LIB.TVMArrayAlloc(
shape, ndim, dtype, ctx, ctypes.byref(handle)))
shape, ndim,
ctypes.c_int(dtype.type_code),
ctypes.c_int(dtype.bits),
ctypes.c_int(dtype.lanes),
ctx.device_type,
ctx.device_id,
ctypes.byref(handle)))
return _make_array(handle, False)
class NDArrayBase(_NDArrayBase):
......
......@@ -127,7 +127,7 @@ class TVMContext(ctypes.Structure):
def sync(self):
"""Synchronize until jobs finished at the context."""
check_call(_LIB.TVMSynchronize(self, None))
check_call(_LIB.TVMSynchronize(self.device_type, self.device_id, None))
def __eq__(self, other):
return (isinstance(other, TVMContext) and
......
......@@ -124,14 +124,14 @@ inline void TVMArrayFree_(TVMArray* arr) {
delete arr;
}
inline void VerifyType(TVMType dtype) {
CHECK_GE(dtype.lanes, 1U);
if (dtype.code == kFloat) {
CHECK_EQ(dtype.bits % 32U, 0U);
inline void VerifyType(int dtype_code, int dtype_bits, int dtype_lanes) {
CHECK_GE(dtype_lanes, 1);
if (dtype_code == kFloat) {
CHECK_EQ(dtype_bits % 32, 0);
} else {
CHECK_EQ(dtype.bits % 8U, 0U);
CHECK_EQ(dtype_bits % 8, 0);
}
CHECK_EQ(dtype.bits & (dtype.bits - 1), 0);
CHECK_EQ(dtype_bits & (dtype_bits - 1), 0);
}
inline size_t GetDataSize(TVMArray* arr) {
......@@ -367,8 +367,11 @@ int TVMFuncCreateFromCFunc(TVMPackedCFunc func,
int TVMArrayAlloc(const tvm_index_t* shape,
int ndim,
TVMType dtype,
TVMContext ctx,
int dtype_code,
int dtype_bits,
int dtype_lanes,
int device_type,
int device_id,
TVMArrayHandle* out) {
TVMArray* arr = nullptr;
API_BEGIN();
......@@ -377,17 +380,20 @@ int TVMArrayAlloc(const tvm_index_t* shape,
// ndim
arr->ndim = ndim;
// dtype
VerifyType(dtype);
arr->dtype = dtype;
VerifyType(dtype_code, dtype_bits, dtype_lanes);
arr->dtype.code = static_cast<uint8_t>(dtype_code);
arr->dtype.bits = static_cast<uint8_t>(dtype_bits);
arr->dtype.lanes = static_cast<uint16_t>(dtype_lanes);
tvm_index_t* shape_copy = new tvm_index_t[ndim];
std::copy(shape, shape + ndim, shape_copy);
arr->shape = shape_copy;
// ctx
arr->ctx = ctx;
arr->ctx.device_type = static_cast<DLDeviceType>(device_type);
arr->ctx.device_id = device_id;
size_t size = GetDataSize(arr);
size_t alignment = GetDataAlignment(arr);
arr->data = DeviceAPIManager::Get(ctx)->AllocDataSpace(
ctx, size, alignment);
arr->data = DeviceAPIManager::Get(arr->ctx)->AllocDataSpace(
arr->ctx, size, alignment);
*out = arr;
API_END_HANDLE_ERROR(TVMArrayFree_(arr));
}
......@@ -456,14 +462,20 @@ int TVMArrayCopyToBytes(TVMArrayHandle handle,
API_END();
}
int TVMSetStream(TVMContext ctx, TVMStreamHandle stream) {
int TVMSetStream(int device_type, int device_id, TVMStreamHandle stream) {
API_BEGIN();
TVMContext ctx;
ctx.device_type = static_cast<DLDeviceType>(device_type);
ctx.device_id = device_id;
DeviceAPIManager::Get(ctx)->SetStream(ctx, stream);
API_END();
}
int TVMSynchronize(TVMContext ctx, TVMStreamHandle stream) {
int TVMSynchronize(int device_type, int device_id, TVMStreamHandle stream) {
API_BEGIN();
TVMContext ctx;
ctx.device_type = static_cast<DLDeviceType>(device_type);
ctx.device_id = device_id;
DeviceAPIManager::Get(ctx)->StreamSync(ctx, stream);
API_END();
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment