Commit b1ffac44 by Chris Nuernberger Committed by Tianqi Chen

[RUNTIME] Stream API (#953)

parent 889573cf
...@@ -37,3 +37,4 @@ List of Contributors ...@@ -37,3 +37,4 @@ List of Contributors
- [Masahiro Masuda](https://github.com/masahi) - [Masahiro Masuda](https://github.com/masahi)
- [Haolong Zhang](https://github.com/haolongzhangm) - [Haolong Zhang](https://github.com/haolongzhangm)
- [Cody Hao Yu](https://github.com/comaniac) - [Cody Hao Yu](https://github.com/comaniac)
- [Chris Nuernberger](https://github.com/cnuernber)
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
#include "./runtime/c_runtime_api.h" #include "./runtime/c_runtime_api.h"
#ifdef __cplusplus #ifdef __cplusplus
TVM_EXTERN_C { extern "C" {
#endif #endif
/*! \brief handle to node */ /*! \brief handle to node */
......
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
#include "./c_runtime_api.h" #include "./c_runtime_api.h"
#ifdef __cplusplus #ifdef __cplusplus
TVM_EXTERN_C { extern "C" {
#endif #endif
// Backend related functions. // Backend related functions.
......
...@@ -18,12 +18,6 @@ ...@@ -18,12 +18,6 @@
#ifndef TVM_RUNTIME_C_RUNTIME_API_H_ #ifndef TVM_RUNTIME_C_RUNTIME_API_H_
#define TVM_RUNTIME_C_RUNTIME_API_H_ #define TVM_RUNTIME_C_RUNTIME_API_H_
#ifdef __cplusplus
#define TVM_EXTERN_C extern "C"
#else
#define TVM_EXTERN_C
#endif
// Macros to do weak linking // Macros to do weak linking
#ifdef _MSC_VER #ifdef _MSC_VER
#define TVM_WEAK __declspec(selectany) #define TVM_WEAK __declspec(selectany)
...@@ -52,7 +46,7 @@ ...@@ -52,7 +46,7 @@
#include <dlpack/dlpack.h> #include <dlpack/dlpack.h>
#ifdef __cplusplus #ifdef __cplusplus
TVM_EXTERN_C { extern "C" {
#endif #endif
#include <stdint.h> #include <stdint.h>
#include <stddef.h> #include <stddef.h>
...@@ -444,6 +438,26 @@ TVM_DLL int TVMArrayCopyFromTo(TVMArrayHandle from, ...@@ -444,6 +438,26 @@ TVM_DLL int TVMArrayCopyFromTo(TVMArrayHandle from,
TVMStreamHandle stream); TVMStreamHandle stream);
/*! /*!
* \brief Create a new runtime stream.
*
* \param device_type The device type of context
* \param device_id The device id of context
* \param out The new stream handle
* \return 0 when success, -1 when failure happens
*/
TVM_DLL int TVMStreamCreate(int device_type, int device_id, TVMStreamHandle* out);
/*!
* \brief Free a created stream handle.
*
* \param device_type The device type of context
* \param device_id The device id of context
* \param stream The stream to be freed
* \return 0 when success, -1 when failure happens
*/
TVM_DLL int TVMStreamFree(int device_type, int device_id, TVMStreamHandle stream);
/*!
* \brief Set the runtime stream of current thread to be stream. * \brief Set the runtime stream of current thread to be stream.
* The subsequent calls to the same device_type * The subsequent calls to the same device_type
* will use the setted stream handle. * will use the setted stream handle.
...@@ -466,6 +480,20 @@ TVM_DLL int TVMSetStream(int device_type, int device_id, TVMStreamHandle handle) ...@@ -466,6 +480,20 @@ TVM_DLL int TVMSetStream(int device_type, int device_id, TVMStreamHandle handle)
*/ */
TVM_DLL int TVMSynchronize(int device_type, int device_id, TVMStreamHandle stream); TVM_DLL int TVMSynchronize(int device_type, int device_id, TVMStreamHandle stream);
/*!
* \brief Synchronize two streams of execution.
*
* \param device_type The device type of context
* \param device_id The device id of context
* \param src The source stream to synchronize.
* \param dst The destination stream to synchronize.
* \return 0 when success, -1 when failure happens
*/
TVM_DLL int TVMStreamStreamSynchronize(int device_type,
int device_id,
TVMStreamHandle src,
TVMStreamHandle dst);
#ifdef __cplusplus #ifdef __cplusplus
} // TVM_EXTERN_C } // TVM_EXTERN_C
#endif #endif
......
...@@ -19,7 +19,7 @@ enum DeviceAttrKind : int { ...@@ -19,7 +19,7 @@ enum DeviceAttrKind : int {
kExist = 0, kExist = 0,
kMaxThreadsPerBlock = 1, kMaxThreadsPerBlock = 1,
kWarpSize = 2, kWarpSize = 2,
kComputeVersion = 3 kComputeVersion = 3,
}; };
/*! \brief Number of bytes each allocation must align to */ /*! \brief Number of bytes each allocation must align to */
...@@ -90,6 +90,21 @@ class DeviceAPI { ...@@ -90,6 +90,21 @@ class DeviceAPI {
TVMContext ctx_from, TVMContext ctx_from,
TVMContext ctx_to, TVMContext ctx_to,
TVMStreamHandle stream) = 0; TVMStreamHandle stream) = 0;
/*!
* \brief Create a new stream of execution.
*
* \param ctx The context of allocation.
*/
TVM_DLL virtual TVMStreamHandle CreateStream(TVMContext ctx);
/*!
* \brief Free a stream of execution
*
* \param ctx The context of the stream
* \param stream The pointer to be freed.
*/
TVM_DLL virtual void FreeStream(TVMContext ctx, TVMStreamHandle stream);
/*! /*!
* \brief Synchronize the stream * \brief Synchronize the stream
* \param ctx The context to perform operation. * \param ctx The context to perform operation.
...@@ -103,6 +118,21 @@ class DeviceAPI { ...@@ -103,6 +118,21 @@ class DeviceAPI {
*/ */
virtual void SetStream(TVMContext ctx, TVMStreamHandle stream) {} virtual void SetStream(TVMContext ctx, TVMStreamHandle stream) {}
/*! /*!
* \brief Synchronize 2 streams of execution.
*
* An event is created in event_src stream that the second then
* stream waits on. Neither event_src or event_dst need to be of
* the same device ID as the context, but they must be of the same
* device type.
*
* \param ctx The context of the streams.
* \param event_src The source stream to synchronize.
* \param event_dst The destination stream to synchronize.
*/
TVM_DLL virtual void SyncStreamFromTo(TVMContext ctx,
TVMStreamHandle event_src,
TVMStreamHandle event_dst);
/*!
* \brief Allocate temporal workspace for backend execution. * \brief Allocate temporal workspace for backend execution.
* *
* \note We have the following assumption about backend temporal * \note We have the following assumption about backend temporal
...@@ -128,6 +158,7 @@ class DeviceAPI { ...@@ -128,6 +158,7 @@ class DeviceAPI {
* \param ptr The pointer to be freed. * \param ptr The pointer to be freed.
*/ */
TVM_DLL virtual void FreeWorkspace(TVMContext ctx, void* ptr); TVM_DLL virtual void FreeWorkspace(TVMContext ctx, void* ptr);
/*! /*!
* \brief Get device API base don context. * \brief Get device API base don context.
* \param ctx The context * \param ctx The context
......
...@@ -106,6 +106,21 @@ void DeviceAPI::FreeWorkspace(TVMContext ctx, void* ptr) { ...@@ -106,6 +106,21 @@ void DeviceAPI::FreeWorkspace(TVMContext ctx, void* ptr) {
FreeDataSpace(ctx, ptr); FreeDataSpace(ctx, ptr);
} }
TVMStreamHandle DeviceAPI::CreateStream(TVMContext ctx) {
LOG(FATAL) << "Device does not support stream api.";
return 0;
}
void DeviceAPI::FreeStream(TVMContext ctx, TVMStreamHandle stream) {
LOG(FATAL) << "Device does not support stream api.";
}
void DeviceAPI::SyncStreamFromTo(TVMContext ctx,
TVMStreamHandle event_src,
TVMStreamHandle event_dst) {
LOG(FATAL) << "Device does not support stream api.";
}
inline TVMArray* TVMArrayCreate_() { inline TVMArray* TVMArrayCreate_() {
TVMArray* arr = new TVMArray(); TVMArray* arr = new TVMArray();
arr->shape = nullptr; arr->shape = nullptr;
...@@ -448,6 +463,24 @@ int TVMArrayCopyToBytes(TVMArrayHandle handle, ...@@ -448,6 +463,24 @@ int TVMArrayCopyToBytes(TVMArrayHandle handle,
API_END(); API_END();
} }
int TVMStreamCreate(int device_type, int device_id, TVMStreamHandle* out) {
API_BEGIN();
TVMContext ctx;
ctx.device_type = static_cast<DLDeviceType>(device_type);
ctx.device_id = device_id;
*out = DeviceAPIManager::Get(ctx)->CreateStream(ctx);
API_END();
}
int TVMStreamFree(int device_type, int device_id, TVMStreamHandle stream) {
API_BEGIN();
TVMContext ctx;
ctx.device_type = static_cast<DLDeviceType>(device_type);
ctx.device_id = device_id;
DeviceAPIManager::Get(ctx)->FreeStream(ctx, stream);
API_END();
}
int TVMSetStream(int device_type, int device_id, TVMStreamHandle stream) { int TVMSetStream(int device_type, int device_id, TVMStreamHandle stream) {
API_BEGIN(); API_BEGIN();
TVMContext ctx; TVMContext ctx;
...@@ -466,6 +499,18 @@ int TVMSynchronize(int device_type, int device_id, TVMStreamHandle stream) { ...@@ -466,6 +499,18 @@ int TVMSynchronize(int device_type, int device_id, TVMStreamHandle stream) {
API_END(); API_END();
} }
int TVMStreamStreamSynchronize(int device_type,
int device_id,
TVMStreamHandle src,
TVMStreamHandle dst) {
API_BEGIN();
TVMContext ctx;
ctx.device_type = static_cast<DLDeviceType>(device_type);
ctx.device_id = device_id;
DeviceAPIManager::Get(ctx)->SyncStreamFromTo(ctx, src, dst);
API_END();
}
int TVMCbArgToReturn(TVMValue* value, int code) { int TVMCbArgToReturn(TVMValue* value, int code) {
API_BEGIN(); API_BEGIN();
tvm::runtime::TVMRetValue rv; tvm::runtime::TVMRetValue rv;
......
...@@ -30,7 +30,7 @@ class CUDADeviceAPI final : public DeviceAPI { ...@@ -30,7 +30,7 @@ class CUDADeviceAPI final : public DeviceAPI {
&value, cudaDevAttrMaxThreadsPerBlock, ctx.device_id) &value, cudaDevAttrMaxThreadsPerBlock, ctx.device_id)
== cudaSuccess); == cudaSuccess);
break; break;
case kMaxThreadsPerBlock: { case kMaxThreadsPerBlock: {
CUDA_CALL(cudaDeviceGetAttribute( CUDA_CALL(cudaDeviceGetAttribute(
&value, cudaDevAttrMaxThreadsPerBlock, ctx.device_id)); &value, cudaDevAttrMaxThreadsPerBlock, ctx.device_id));
break; break;
...@@ -102,6 +102,30 @@ class CUDADeviceAPI final : public DeviceAPI { ...@@ -102,6 +102,30 @@ class CUDADeviceAPI final : public DeviceAPI {
} }
} }
TVMStreamHandle CreateStream(TVMContext ctx) {
CUDA_CALL(cudaSetDevice(ctx.device_id));
cudaStream_t retval;
CUDA_CALL(cudaStreamCreate(&retval));
return static_cast<TVMStreamHandle>(retval);
}
void FreeStream(TVMContext ctx, TVMStreamHandle stream) {
CUDA_CALL(cudaSetDevice(ctx.device_id));
cudaStream_t cu_stream = static_cast<cudaStream_t>(stream);
CUDA_CALL(cudaStreamDestroy(cu_stream));
}
void SyncStreamFromTo(TVMContext ctx, TVMStreamHandle event_src, TVMStreamHandle event_dst) {
CUDA_CALL(cudaSetDevice(ctx.device_id));
cudaStream_t src_stream = static_cast<cudaStream_t>(event_src);
cudaStream_t dst_stream = static_cast<cudaStream_t>(event_dst);
cudaEvent_t evt;
CUDA_CALL(cudaEventCreate(&evt));
CUDA_CALL(cudaEventRecord(evt, src_stream));
CUDA_CALL(cudaStreamWaitEvent(dst_stream, evt, 0));
CUDA_CALL(cudaEventDestroy(evt));
}
void StreamSync(TVMContext ctx, TVMStreamHandle stream) final { void StreamSync(TVMContext ctx, TVMStreamHandle stream) final {
CUDA_CALL(cudaSetDevice(ctx.device_id)); CUDA_CALL(cudaSetDevice(ctx.device_id));
CUDA_CALL(cudaStreamSynchronize(static_cast<cudaStream_t>(stream))); CUDA_CALL(cudaStreamSynchronize(static_cast<cudaStream_t>(stream)));
......
...@@ -45,8 +45,8 @@ void OpenCLWorkspace::GetAttr( ...@@ -45,8 +45,8 @@ void OpenCLWorkspace::GetAttr(
*rv = 1; *rv = 1;
break; break;
} }
case kComputeVersion: return; case kComputeVersion: return;
case kExist: break; case kExist: break;
} }
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment