Commit cd29c18c by Tianqi Chen Committed by GitHub

[RUNTIME][ABI] Flat structure arguments (#232)

parent 2dec0510
...@@ -173,7 +173,7 @@ void GraphExecutor::SetupStorage() { ...@@ -173,7 +173,7 @@ void GraphExecutor::SetupStorage() {
TShape shape{static_cast<int64_t>(pool_entry_bytes[i] + 3) / 4}; TShape shape{static_cast<int64_t>(pool_entry_bytes[i] + 3) / 4};
DLTensor* tensor; DLTensor* tensor;
TVM_CCALL(TVMArrayAlloc( TVM_CCALL(TVMArrayAlloc(
shape.data(), 1, DLDataType{kFloat, 32U, 1U}, ctx_, &tensor)); shape.data(), 1, kFloat, 32, 1, ctx_.device_type, ctx_.device_id, &tensor));
storage_pool_.push_back(tensor); storage_pool_.push_back(tensor);
} }
// Assign the pooled entries. // Assign the pooled entries.
......
...@@ -339,15 +339,21 @@ TVM_DLL int TVMFuncListGlobalNames(int *out_size, ...@@ -339,15 +339,21 @@ TVM_DLL int TVMFuncListGlobalNames(int *out_size,
* *
* \param shape The shape of the array, the data content will be copied to out * \param shape The shape of the array, the data content will be copied to out
* \param ndim The number of dimension of the array. * \param ndim The number of dimension of the array.
* \param dtype The array data type. * \param dtype_code The type code of the dtype
* \param ctx The ctx this array sits on. * \param dtype_bits The number of bits of dtype
* \param dtype_lanes The number of lanes in the dtype.
* \param device_type The device type of context
* \param device_id The device id of context.
* \param out The output handle. * \param out The output handle.
* \return 0 when success, -1 when failure happens * \return 0 when success, -1 when failure happens
*/ */
TVM_DLL int TVMArrayAlloc(const tvm_index_t* shape, TVM_DLL int TVMArrayAlloc(const tvm_index_t* shape,
int ndim, int ndim,
TVMType dtype, int dtype_code,
TVMContext ctx, int dtype_bits,
int dtype_lanes,
int device_type,
int device_id,
TVMArrayHandle* out); TVMArrayHandle* out);
/*! /*!
...@@ -396,19 +402,22 @@ TVM_DLL int TVMArrayCopyFromTo(TVMArrayHandle from, ...@@ -396,19 +402,22 @@ TVM_DLL int TVMArrayCopyFromTo(TVMArrayHandle from,
* will use the setted stream handle. * will use the setted stream handle.
* The specific type of stream is runtime device dependent. * The specific type of stream is runtime device dependent.
* *
* \param ctx The context. * \param device_type The device type of context
* \param device_id The device id of context.
* \param handle The stream handle. * \param handle The stream handle.
* \return 0 when success, -1 when failure happens * \return 0 when success, -1 when failure happens
*/ */
TVM_DLL int TVMSetStream(TVMContext ctx, TVMStreamHandle handle); TVM_DLL int TVMSetStream(int device_type, int device_id, TVMStreamHandle handle);
/*! /*!
* \brief Wait until all computations on stream completes. * \brief Wait until all computations on stream completes.
* \param ctx The ctx to be synchronized. *
* \param device_type The device type of context
* \param device_id The device id of context.
* \param stream The stream to be synchronized. * \param stream The stream to be synchronized.
* \return 0 when success, -1 when failure happens * \return 0 when success, -1 when failure happens
*/ */
TVM_DLL int TVMSynchronize(TVMContext ctx, TVMStreamHandle stream); TVM_DLL int TVMSynchronize(int device_type, int device_id, TVMStreamHandle stream);
#ifdef __cplusplus #ifdef __cplusplus
} // TVM_EXTERN_C } // TVM_EXTERN_C
......
...@@ -101,7 +101,13 @@ def empty(shape, dtype="float32", ctx=context(1, 0)): ...@@ -101,7 +101,13 @@ def empty(shape, dtype="float32", ctx=context(1, 0)):
handle = TVMArrayHandle() handle = TVMArrayHandle()
dtype = TVMType(dtype) dtype = TVMType(dtype)
check_call(_LIB.TVMArrayAlloc( check_call(_LIB.TVMArrayAlloc(
shape, ndim, dtype, ctx, ctypes.byref(handle))) shape, ndim,
ctypes.c_int(dtype.type_code),
ctypes.c_int(dtype.bits),
ctypes.c_int(dtype.lanes),
ctx.device_type,
ctx.device_id,
ctypes.byref(handle)))
return _make_array(handle, False) return _make_array(handle, False)
class NDArrayBase(_NDArrayBase): class NDArrayBase(_NDArrayBase):
......
...@@ -127,7 +127,7 @@ class TVMContext(ctypes.Structure): ...@@ -127,7 +127,7 @@ class TVMContext(ctypes.Structure):
def sync(self): def sync(self):
"""Synchronize until jobs finished at the context.""" """Synchronize until jobs finished at the context."""
check_call(_LIB.TVMSynchronize(self, None)) check_call(_LIB.TVMSynchronize(self.device_type, self.device_id, None))
def __eq__(self, other): def __eq__(self, other):
return (isinstance(other, TVMContext) and return (isinstance(other, TVMContext) and
......
...@@ -124,14 +124,14 @@ inline void TVMArrayFree_(TVMArray* arr) { ...@@ -124,14 +124,14 @@ inline void TVMArrayFree_(TVMArray* arr) {
delete arr; delete arr;
} }
inline void VerifyType(TVMType dtype) { inline void VerifyType(int dtype_code, int dtype_bits, int dtype_lanes) {
CHECK_GE(dtype.lanes, 1U); CHECK_GE(dtype_lanes, 1);
if (dtype.code == kFloat) { if (dtype_code == kFloat) {
CHECK_EQ(dtype.bits % 32U, 0U); CHECK_EQ(dtype_bits % 32, 0);
} else { } else {
CHECK_EQ(dtype.bits % 8U, 0U); CHECK_EQ(dtype_bits % 8, 0);
} }
CHECK_EQ(dtype.bits & (dtype.bits - 1), 0); CHECK_EQ(dtype_bits & (dtype_bits - 1), 0);
} }
inline size_t GetDataSize(TVMArray* arr) { inline size_t GetDataSize(TVMArray* arr) {
...@@ -367,8 +367,11 @@ int TVMFuncCreateFromCFunc(TVMPackedCFunc func, ...@@ -367,8 +367,11 @@ int TVMFuncCreateFromCFunc(TVMPackedCFunc func,
int TVMArrayAlloc(const tvm_index_t* shape, int TVMArrayAlloc(const tvm_index_t* shape,
int ndim, int ndim,
TVMType dtype, int dtype_code,
TVMContext ctx, int dtype_bits,
int dtype_lanes,
int device_type,
int device_id,
TVMArrayHandle* out) { TVMArrayHandle* out) {
TVMArray* arr = nullptr; TVMArray* arr = nullptr;
API_BEGIN(); API_BEGIN();
...@@ -377,17 +380,20 @@ int TVMArrayAlloc(const tvm_index_t* shape, ...@@ -377,17 +380,20 @@ int TVMArrayAlloc(const tvm_index_t* shape,
// ndim // ndim
arr->ndim = ndim; arr->ndim = ndim;
// dtype // dtype
VerifyType(dtype); VerifyType(dtype_code, dtype_bits, dtype_lanes);
arr->dtype = dtype; arr->dtype.code = static_cast<uint8_t>(dtype_code);
arr->dtype.bits = static_cast<uint8_t>(dtype_bits);
arr->dtype.lanes = static_cast<uint16_t>(dtype_lanes);
tvm_index_t* shape_copy = new tvm_index_t[ndim]; tvm_index_t* shape_copy = new tvm_index_t[ndim];
std::copy(shape, shape + ndim, shape_copy); std::copy(shape, shape + ndim, shape_copy);
arr->shape = shape_copy; arr->shape = shape_copy;
// ctx // ctx
arr->ctx = ctx; arr->ctx.device_type = static_cast<DLDeviceType>(device_type);
arr->ctx.device_id = device_id;
size_t size = GetDataSize(arr); size_t size = GetDataSize(arr);
size_t alignment = GetDataAlignment(arr); size_t alignment = GetDataAlignment(arr);
arr->data = DeviceAPIManager::Get(ctx)->AllocDataSpace( arr->data = DeviceAPIManager::Get(arr->ctx)->AllocDataSpace(
ctx, size, alignment); arr->ctx, size, alignment);
*out = arr; *out = arr;
API_END_HANDLE_ERROR(TVMArrayFree_(arr)); API_END_HANDLE_ERROR(TVMArrayFree_(arr));
} }
...@@ -456,14 +462,20 @@ int TVMArrayCopyToBytes(TVMArrayHandle handle, ...@@ -456,14 +462,20 @@ int TVMArrayCopyToBytes(TVMArrayHandle handle,
API_END(); API_END();
} }
int TVMSetStream(TVMContext ctx, TVMStreamHandle stream) { int TVMSetStream(int device_type, int device_id, TVMStreamHandle stream) {
API_BEGIN(); API_BEGIN();
TVMContext ctx;
ctx.device_type = static_cast<DLDeviceType>(device_type);
ctx.device_id = device_id;
DeviceAPIManager::Get(ctx)->SetStream(ctx, stream); DeviceAPIManager::Get(ctx)->SetStream(ctx, stream);
API_END(); API_END();
} }
int TVMSynchronize(TVMContext ctx, TVMStreamHandle stream) { int TVMSynchronize(int device_type, int device_id, TVMStreamHandle stream) {
API_BEGIN(); API_BEGIN();
TVMContext ctx;
ctx.device_type = static_cast<DLDeviceType>(device_type);
ctx.device_id = device_id;
DeviceAPIManager::Get(ctx)->StreamSync(ctx, stream); DeviceAPIManager::Get(ctx)->StreamSync(ctx, stream);
API_END(); API_END();
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment