Commit b1ffac44 by Chris Nuernberger Committed by Tianqi Chen

[RUNTIME] Stream API (#953)

parent 889573cf
......@@ -37,3 +37,4 @@ List of Contributors
- [Masahiro Masuda](https://github.com/masahi)
- [Haolong Zhang](https://github.com/haolongzhangm)
- [Cody Hao Yu](https://github.com/comaniac)
- [Chris Nuernberger](https://github.com/cnuernber)
......@@ -17,7 +17,7 @@
#include "./runtime/c_runtime_api.h"
#ifdef __cplusplus
TVM_EXTERN_C {
extern "C" {
#endif
/*! \brief handle to node */
......
......@@ -13,7 +13,7 @@
#include "./c_runtime_api.h"
#ifdef __cplusplus
TVM_EXTERN_C {
extern "C" {
#endif
// Backend related functions.
......
......@@ -18,12 +18,6 @@
#ifndef TVM_RUNTIME_C_RUNTIME_API_H_
#define TVM_RUNTIME_C_RUNTIME_API_H_
#ifdef __cplusplus
#define TVM_EXTERN_C extern "C"
#else
#define TVM_EXTERN_C
#endif
// Macros to do weak linking
#ifdef _MSC_VER
#define TVM_WEAK __declspec(selectany)
......@@ -52,7 +46,7 @@
#include <dlpack/dlpack.h>
#ifdef __cplusplus
TVM_EXTERN_C {
extern "C" {
#endif
#include <stdint.h>
#include <stddef.h>
......@@ -444,6 +438,26 @@ TVM_DLL int TVMArrayCopyFromTo(TVMArrayHandle from,
TVMStreamHandle stream);
/*!
* \brief Create a new runtime stream.
*
* \param device_type The device type of context
* \param device_id The device id of context
* \param out The new stream handle. The result is stored through this
*            pointer, so it must point to writable storage (not NULL).
* \return 0 when success, -1 when failure happens
*
* \note Release the returned handle with TVMStreamFree.
*/
TVM_DLL int TVMStreamCreate(int device_type, int device_id, TVMStreamHandle* out);
/*!
* \brief Free a created stream handle.
*
* \param device_type The device type of context
* \param device_id The device id of context
* \param stream The stream to be freed
* \return 0 when success, -1 when failure happens
*
* \note The call is forwarded to DeviceAPI::FreeStream of the selected
*       device; the default implementation raises a fatal
*       "Device does not support stream api." error.
*/
TVM_DLL int TVMStreamFree(int device_type, int device_id, TVMStreamHandle stream);
/*!
* \brief Set the runtime stream of current thread to be stream.
* The subsequent calls to the same device_type
* will use the stream handle that was set.
......@@ -466,6 +480,20 @@ TVM_DLL int TVMSetStream(int device_type, int device_id, TVMStreamHandle handle)
*/
TVM_DLL int TVMSynchronize(int device_type, int device_id, TVMStreamHandle stream);
/*!
* \brief Synchronize two streams of execution.
*
* \param device_type The device type of context
* \param device_id The device id of context
* \param src The source stream to synchronize.
* \param dst The destination stream to synchronize.
* \return 0 when success, -1 when failure happens
*
* \note In the CUDA backend an event is recorded in src and dst is made
*       to wait on that event.
*       NOTE(review): confirm other backends follow the same direction.
*/
TVM_DLL int TVMStreamStreamSynchronize(int device_type,
int device_id,
TVMStreamHandle src,
TVMStreamHandle dst);
#ifdef __cplusplus
} // extern "C"
#endif
......
......@@ -19,7 +19,7 @@ enum DeviceAttrKind : int {
kExist = 0,
kMaxThreadsPerBlock = 1,
kWarpSize = 2,
kComputeVersion = 3
kComputeVersion = 3,
};
/*! \brief Number of bytes each allocation must align to */
......@@ -90,6 +90,21 @@ class DeviceAPI {
TVMContext ctx_from,
TVMContext ctx_to,
TVMStreamHandle stream) = 0;
/*!
* \brief Create a new stream of execution.
*
* \param ctx The context on which the stream is created.
* \return The handle of the newly created stream.
*
* \note The default implementation raises a fatal
*       "Device does not support stream api." error; backends that
*       support streams override it.
*/
TVM_DLL virtual TVMStreamHandle CreateStream(TVMContext ctx);
/*!
* \brief Free a stream of execution.
*
* \param ctx The context of the stream.
* \param stream The stream handle to be freed.
*
* \note The default implementation raises a fatal
*       "Device does not support stream api." error.
*/
TVM_DLL virtual void FreeStream(TVMContext ctx, TVMStreamHandle stream);
/*!
* \brief Synchronize the stream
* \param ctx The context to perform operation.
......@@ -103,6 +118,21 @@ class DeviceAPI {
*/
virtual void SetStream(TVMContext ctx, TVMStreamHandle stream) {}
/*!
* \brief Synchronize 2 streams of execution.
*
* An event is recorded in the event_src stream, and the event_dst
* stream then waits on that event. Neither event_src nor event_dst
* needs to be of the same device ID as the context, but they must be
* of the same device type.
*
* NOTE(review): the CUDA backend creates the event on ctx.device_id
* before recording it in event_src -- confirm that event_src may
* really belong to a different device ID.
*
* \param ctx The context of the streams.
* \param event_src The source stream to synchronize.
* \param event_dst The destination stream to synchronize.
*/
TVM_DLL virtual void SyncStreamFromTo(TVMContext ctx,
TVMStreamHandle event_src,
TVMStreamHandle event_dst);
/*!
* \brief Allocate temporal workspace for backend execution.
*
* \note We have the following assumption about backend temporal
......@@ -128,6 +158,7 @@ class DeviceAPI {
* \param ptr The pointer to be freed.
*/
TVM_DLL virtual void FreeWorkspace(TVMContext ctx, void* ptr);
/*!
* \brief Get device API based on context.
* \param ctx The context
......
......@@ -106,6 +106,21 @@ void DeviceAPI::FreeWorkspace(TVMContext ctx, void* ptr) {
FreeDataSpace(ctx, ptr);
}
/*!
 * \brief Default CreateStream for devices that do not override it.
 *
 * Always reports a fatal "Device does not support stream api." error.
 *
 * \param ctx The context the stream was requested for (unused here).
 * \return A null handle; the statement only exists so that every
 *         control path of the function returns a value.
 */
TVMStreamHandle DeviceAPI::CreateStream(TVMContext ctx) {
  LOG(FATAL) << "Device does not support stream api.";
  // The handle is a pointer type, so return nullptr rather than literal 0.
  return nullptr;
}
// Fallback for backends that do not override FreeStream: no stream can
// exist for such a device, so the request is reported as a fatal error.
void DeviceAPI::FreeStream(TVMContext ctx,
                           TVMStreamHandle stream) {
  LOG(FATAL) << "Device does not support stream api.";
}
// Fallback for backends that do not override SyncStreamFromTo: the
// request is reported as a fatal error and neither stream is touched.
void DeviceAPI::SyncStreamFromTo(
    TVMContext ctx, TVMStreamHandle event_src, TVMStreamHandle event_dst) {
  LOG(FATAL) << "Device does not support stream api.";
}
inline TVMArray* TVMArrayCreate_() {
TVMArray* arr = new TVMArray();
arr->shape = nullptr;
......@@ -448,6 +463,24 @@ int TVMArrayCopyToBytes(TVMArrayHandle handle,
API_END();
}
int TVMStreamCreate(int device_type, int device_id, TVMStreamHandle* out) {
API_BEGIN();
TVMContext ctx;
ctx.device_type = static_cast<DLDeviceType>(device_type);
ctx.device_id = device_id;
*out = DeviceAPIManager::Get(ctx)->CreateStream(ctx);
API_END();
}
int TVMStreamFree(int device_type, int device_id, TVMStreamHandle stream) {
API_BEGIN();
TVMContext ctx;
ctx.device_type = static_cast<DLDeviceType>(device_type);
ctx.device_id = device_id;
DeviceAPIManager::Get(ctx)->FreeStream(ctx, stream);
API_END();
}
int TVMSetStream(int device_type, int device_id, TVMStreamHandle stream) {
API_BEGIN();
TVMContext ctx;
......@@ -466,6 +499,18 @@ int TVMSynchronize(int device_type, int device_id, TVMStreamHandle stream) {
API_END();
}
int TVMStreamStreamSynchronize(int device_type,
int device_id,
TVMStreamHandle src,
TVMStreamHandle dst) {
API_BEGIN();
TVMContext ctx;
ctx.device_type = static_cast<DLDeviceType>(device_type);
ctx.device_id = device_id;
DeviceAPIManager::Get(ctx)->SyncStreamFromTo(ctx, src, dst);
API_END();
}
int TVMCbArgToReturn(TVMValue* value, int code) {
API_BEGIN();
tvm::runtime::TVMRetValue rv;
......
......@@ -30,7 +30,7 @@ class CUDADeviceAPI final : public DeviceAPI {
&value, cudaDevAttrMaxThreadsPerBlock, ctx.device_id)
== cudaSuccess);
break;
case kMaxThreadsPerBlock: {
case kMaxThreadsPerBlock: {
CUDA_CALL(cudaDeviceGetAttribute(
&value, cudaDevAttrMaxThreadsPerBlock, ctx.device_id));
break;
......@@ -102,6 +102,30 @@ class CUDADeviceAPI final : public DeviceAPI {
}
}
/*!
 * \brief Create a CUDA stream on the device named by ctx.
 *
 * Marked final, like StreamSync, so the compiler verifies that this
 * really overrides DeviceAPI::CreateStream.
 *
 * \param ctx The context whose device_id selects the CUDA device.
 * \return The new cudaStream_t, cast to an opaque TVMStreamHandle.
 */
TVMStreamHandle CreateStream(TVMContext ctx) final {
  // Make the target device current before creating the stream.
  CUDA_CALL(cudaSetDevice(ctx.device_id));
  cudaStream_t retval;
  CUDA_CALL(cudaStreamCreate(&retval));
  return static_cast<TVMStreamHandle>(retval);
}
/*!
 * \brief Destroy a CUDA stream.
 *
 * Marked final, like StreamSync, so the compiler verifies that this
 * really overrides DeviceAPI::FreeStream.
 *
 * \param ctx The context whose device_id is made current first.
 * \param stream The handle to destroy; it is cast back to cudaStream_t.
 */
void FreeStream(TVMContext ctx, TVMStreamHandle stream) final {
  CUDA_CALL(cudaSetDevice(ctx.device_id));
  cudaStream_t cu_stream = static_cast<cudaStream_t>(stream);
  CUDA_CALL(cudaStreamDestroy(cu_stream));
}
/*!
 * \brief Make event_dst wait for an event recorded in event_src.
 *
 * Marked final, like StreamSync, so the compiler verifies that this
 * really overrides DeviceAPI::SyncStreamFromTo.
 *
 * \param ctx The context whose device_id is made current; the temporary
 *            event is created while that device is current.
 * \param event_src The stream in which the event is recorded.
 * \param event_dst The stream that waits on the recorded event.
 */
void SyncStreamFromTo(TVMContext ctx, TVMStreamHandle event_src, TVMStreamHandle event_dst) final {
  CUDA_CALL(cudaSetDevice(ctx.device_id));
  cudaStream_t src_stream = static_cast<cudaStream_t>(event_src);
  cudaStream_t dst_stream = static_cast<cudaStream_t>(event_dst);
  cudaEvent_t evt;
  // The event is only used for ordering and is never timed, so create it
  // without timing support, as the CUDA documentation recommends for
  // events passed to cudaStreamWaitEvent.
  CUDA_CALL(cudaEventCreateWithFlags(&evt, cudaEventDisableTiming));
  CUDA_CALL(cudaEventRecord(evt, src_stream));
  CUDA_CALL(cudaStreamWaitEvent(dst_stream, evt, 0));
  // Destroyed straight after the wait is queued, as in the original code.
  CUDA_CALL(cudaEventDestroy(evt));
}
void StreamSync(TVMContext ctx, TVMStreamHandle stream) final {
CUDA_CALL(cudaSetDevice(ctx.device_id));
CUDA_CALL(cudaStreamSynchronize(static_cast<cudaStream_t>(stream)));
......
......@@ -45,8 +45,8 @@ void OpenCLWorkspace::GetAttr(
*rv = 1;
break;
}
case kComputeVersion: return;
case kExist: break;
case kComputeVersion: return;
case kExist: break;
}
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment