Commit 3b8e70ae by Tianqi Chen Committed by GitHub

[RUNTIME] Move device_api to include (#185)

* [RUNTIME] Move device_api to include

* fix doxygen

* fix device api

* fx
parent fcfec961
...@@ -195,43 +195,6 @@ TVM_DLL int TVMModGetFunction(TVMModuleHandle mod, ...@@ -195,43 +195,6 @@ TVM_DLL int TVMModGetFunction(TVMModuleHandle mod,
TVM_DLL int TVMModFree(TVMModuleHandle mod); TVM_DLL int TVMModFree(TVMModuleHandle mod);
/*! /*!
* \brief Backend function for modules to get function
* from its environment mod_node (its imports and global function).
*
* The user do should not call TVMFuncFree on func.
*
* \note This API is supposed to be used by backend,
* it is not supposed to be used by user.
*
* \param mod_node The module handle.
* \param func_name The name of the function.
* \param out The result function.
* \return 0 when no error is thrown, -1 when failure happens
*/
TVM_DLL int TVMBackendGetFuncFromEnv(void* mod_node,
const char* func_name,
TVMFunctionHandle *out);
/*!
* \brief Backend function for running parallel for loop.
*
* \note This API is supposed to be used by backend,
* it is not supposed to be used by user.
*
* \param begin The start of iteration.
* \param end The end of iteration.
* \param lambda The lambda function to be executed.
* \param env The environment of lambda function.
*
* \return 0 when no error is thrown, -1 when failure happens
*/
TVM_DLL int TVMBackendParallelFor(
int64_t begin,
int64_t end,
int (*lambda)(int64_t begin, int64_t end, void* env),
void* env);
/*!
* \brief Free the function when it is no longer needed. * \brief Free the function when it is no longer needed.
* \param func The function handle * \param func The function handle
* \return 0 when success, -1 when failure happens * \return 0 when success, -1 when failure happens
...@@ -351,6 +314,44 @@ TVM_DLL int TVMFuncGetGlobal(const char* name, TVMFunctionHandle* out); ...@@ -351,6 +314,44 @@ TVM_DLL int TVMFuncGetGlobal(const char* name, TVMFunctionHandle* out);
TVM_DLL int TVMFuncListGlobalNames(int *out_size, TVM_DLL int TVMFuncListGlobalNames(int *out_size,
const char*** out_array); const char*** out_array);
// Backend related functions.
/*!
* \brief Backend function for modules to get function
* from its environment mod_node (its imports and global function).
*
* The user do should not call TVMFuncFree on func.
*
* \note This API is supposed to be used by backend,
* it is not supposed to be used by user.
*
* \param mod_node The module handle.
* \param func_name The name of the function.
* \param out The result function.
* \return 0 when no error is thrown, -1 when failure happens
*/
TVM_DLL int TVMBackendGetFuncFromEnv(void* mod_node,
const char* func_name,
TVMFunctionHandle *out);
/*!
* \brief Backend function for running parallel for loop.
*
* \note This API is supposed to be used by backend,
* it is not supposed to be used by user.
*
* \param begin The start of iteration.
* \param end The end of iteration.
* \param lambda The lambda function to be executed.
* \param env The environment of lambda function.
*
* \return 0 when no error is thrown, -1 when failure happens
*/
TVM_DLL int TVMBackendParallelFor(
int64_t begin,
int64_t end,
int (*lambda)(int64_t begin, int64_t end, void* env),
void* env);
// Array related apis for quick proptyping // Array related apis for quick proptyping
/*! /*!
* \brief Allocate a nd-array's memory, * \brief Allocate a nd-array's memory,
...@@ -368,6 +369,7 @@ TVM_DLL int TVMArrayAlloc(const tvm_index_t* shape, ...@@ -368,6 +369,7 @@ TVM_DLL int TVMArrayAlloc(const tvm_index_t* shape,
TVMType dtype, TVMType dtype,
TVMContext ctx, TVMContext ctx,
TVMArrayHandle* out); TVMArrayHandle* out);
/*! /*!
* \brief Free the TVM Array. * \brief Free the TVM Array.
* \param handle The array handle to be freed. * \param handle The array handle to be freed.
...@@ -385,6 +387,19 @@ TVM_DLL int TVMArrayFree(TVMArrayHandle handle); ...@@ -385,6 +387,19 @@ TVM_DLL int TVMArrayFree(TVMArrayHandle handle);
TVM_DLL int TVMArrayCopyFromTo(TVMArrayHandle from, TVM_DLL int TVMArrayCopyFromTo(TVMArrayHandle from,
TVMArrayHandle to, TVMArrayHandle to,
TVMStreamHandle stream); TVMStreamHandle stream);
/*!
* \brief Set the runtime stream of current thread to be stream.
* The subsequent calls to the same device_type
* will use the setted stream handle.
* The specific type of stream is runtime device dependent.
*
* \param ctx The context.
* \param handle The stream handle.
* \return 0 when success, -1 when failure happens
*/
TVM_DLL int TVMSetStream(TVMContext ctx, TVMStreamHandle handle);
/*! /*!
* \brief Wait until all computations on stream completes. * \brief Wait until all computations on stream completes.
* \param ctx The ctx to be synchronized. * \param ctx The ctx to be synchronized.
......
/*! /*!
* Copyright (c) 2016 by Contributors * Copyright (c) 2016 by Contributors
* \file device_api.h * \file device_api.h
* \brief Device specific API * \brief Abstract device memory management API
*/ */
#ifndef TVM_RUNTIME_DEVICE_API_H_ #ifndef TVM_RUNTIME_DEVICE_API_H_
#define TVM_RUNTIME_DEVICE_API_H_ #define TVM_RUNTIME_DEVICE_API_H_
#include <tvm/base.h>
#include <tvm/runtime/c_runtime_api.h>
#include <string> #include <string>
#include "./packed_func.h"
#include "./c_runtime_api.h"
namespace tvm { namespace tvm {
namespace runtime { namespace runtime {
/*!
* \brief the query type into GetAttr
*/
enum DeviceAttrKind : int { enum DeviceAttrKind : int {
kExist = 0, kExist = 0,
kMaxThreadsPerBlock = 1, kMaxThreadsPerBlock = 1,
kWarpSize = 2 kWarpSize = 2
}; };
/*!
* \brief TVM Runtime Device API, abstracts the device
* specific interface for memory management.
*/
class DeviceAPI { class DeviceAPI {
public: public:
/*! \brief virtual destructor */ /*! \brief virtual destructor */
...@@ -34,6 +39,7 @@ class DeviceAPI { ...@@ -34,6 +39,7 @@ class DeviceAPI {
* \param ctx The device context * \param ctx The device context
* \param kind The result kind * \param kind The result kind
* \param rv The return value. * \param rv The return value.
* \sa DeviceAttrKind
*/ */
virtual void GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) = 0; virtual void GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) = 0;
/*! /*!
...@@ -53,7 +59,6 @@ class DeviceAPI { ...@@ -53,7 +59,6 @@ class DeviceAPI {
virtual void FreeDataSpace(TVMContext ctx, void* ptr) = 0; virtual void FreeDataSpace(TVMContext ctx, void* ptr) = 0;
/*! /*!
* \brief copy data from one place to another * \brief copy data from one place to another
* \param dev The device to perform operation.
* \param from The source array. * \param from The source array.
* \param from_offset The byte offeset in the from. * \param from_offset The byte offeset in the from.
* \param to The target array. * \param to The target array.
...@@ -78,6 +83,12 @@ class DeviceAPI { ...@@ -78,6 +83,12 @@ class DeviceAPI {
*/ */
virtual void StreamSync(TVMContext ctx, TVMStreamHandle stream) = 0; virtual void StreamSync(TVMContext ctx, TVMStreamHandle stream) = 0;
/*! /*!
* \brief Set the stream
* \param ctx The context to set stream.
* \param stream The stream to be set.
*/
virtual void SetStream(TVMContext ctx, TVMStreamHandle stream) {}
/*!
* \brief Get device API base don context. * \brief Get device API base don context.
* \param ctx The context * \param ctx The context
* \param allow_missing Whether allow missing * \param allow_missing Whether allow missing
...@@ -88,21 +99,6 @@ class DeviceAPI { ...@@ -88,21 +99,6 @@ class DeviceAPI {
/*! \brief The device type bigger than this is RPC device */ /*! \brief The device type bigger than this is RPC device */
constexpr int kRPCSessMask = 128; constexpr int kRPCSessMask = 128;
/*!
* \brief The name of Device API factory.
* \param type The device type.
*/
inline std::string DeviceName(int type) {
switch (type) {
case kCPU: return "cpu";
case kGPU: return "gpu";
case kOpenCL: return "opencl";
case kMetal: return "metal";
case kVPI: return "vpi";
default: LOG(FATAL) << "unknown type =" << type; return "Unknown";
}
}
} // namespace runtime } // namespace runtime
} // namespace tvm } // namespace tvm
#endif // TVM_RUNTIME_DEVICE_API_H_ #endif // TVM_RUNTIME_DEVICE_API_H_
...@@ -4,12 +4,12 @@ ...@@ -4,12 +4,12 @@
* \brief Simulated VPI RAM device. * \brief Simulated VPI RAM device.
*/ */
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <tvm/runtime/device_api.h>
#include <tvm/packed_func_ext.h> #include <tvm/packed_func_ext.h>
#include <cstdlib> #include <cstdlib>
#include <unordered_map> #include <unordered_map>
#include <map> #include <map>
#include <queue> #include <queue>
#include "../../runtime/device_api.h"
#include "./vpi_session.h" #include "./vpi_session.h"
namespace tvm { namespace tvm {
......
...@@ -8,7 +8,7 @@ ...@@ -8,7 +8,7 @@
#include <tvm/runtime/packed_func.h> #include <tvm/runtime/packed_func.h>
#include <tvm/runtime/module.h> #include <tvm/runtime/module.h>
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <dmlc/timer.h> #include <tvm/runtime/device_api.h>
#include <array> #include <array>
#include <algorithm> #include <algorithm>
#include <string> #include <string>
...@@ -16,11 +16,25 @@ ...@@ -16,11 +16,25 @@
#include <thread> #include <thread>
#include <mutex> #include <mutex>
#include "./runtime_base.h" #include "./runtime_base.h"
#include "./device_api.h"
namespace tvm { namespace tvm {
namespace runtime { namespace runtime {
/*!
* \brief The name of Device API factory.
* \param type The device type.
*/
inline std::string DeviceName(int type) {
switch (type) {
case kCPU: return "cpu";
case kGPU: return "gpu";
case kOpenCL: return "opencl";
case kMetal: return "metal";
case kVPI: return "vpi";
default: LOG(FATAL) << "unknown type =" << type; return "Unknown";
}
}
class DeviceAPIManager { class DeviceAPIManager {
public: public:
static const int kMaxDeviceAPI = 32; static const int kMaxDeviceAPI = 32;
...@@ -380,6 +394,12 @@ int TVMArrayCopyFromTo(TVMArrayHandle from, ...@@ -380,6 +394,12 @@ int TVMArrayCopyFromTo(TVMArrayHandle from,
API_END(); API_END();
} }
int TVMSetStream(TVMContext ctx, TVMStreamHandle stream) {
API_BEGIN();
DeviceAPIManager::Get(ctx)->SetStream(ctx, stream);
API_END();
}
int TVMSynchronize(TVMContext ctx, TVMStreamHandle stream) { int TVMSynchronize(TVMContext ctx, TVMStreamHandle stream) {
API_BEGIN(); API_BEGIN();
DeviceAPIManager::Get(ctx)->StreamSync(ctx, stream); DeviceAPIManager::Get(ctx)->StreamSync(ctx, stream);
......
...@@ -4,9 +4,9 @@ ...@@ -4,9 +4,9 @@
*/ */
#include <dmlc/logging.h> #include <dmlc/logging.h>
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <tvm/runtime/device_api.h>
#include <cstdlib> #include <cstdlib>
#include <cstring> #include <cstring>
#include "./device_api.h"
namespace tvm { namespace tvm {
namespace runtime { namespace runtime {
......
...@@ -34,13 +34,14 @@ namespace runtime { ...@@ -34,13 +34,14 @@ namespace runtime {
<< "CUDA: " << cudaGetErrorString(e); \ << "CUDA: " << cudaGetErrorString(e); \
} }
/*! \brief Thread local workspace */
/*! class CUDAThreadEntry {
* \brief Compile code into ptx using NVRTC public:
* \param code The cuda code. /*! \brief The cuda stream */
* \return The PTX code. cudaStream_t stream{nullptr};
*/ // get the threadlocal workspace
std::string NVRTCCompile(const std::string& code); static CUDAThreadEntry* ThreadLocal();
};
} // namespace runtime } // namespace runtime
} // namespace tvm } // namespace tvm
#endif // TVM_CUDA_RUNTIME #endif // TVM_CUDA_RUNTIME
......
...@@ -4,13 +4,14 @@ ...@@ -4,13 +4,14 @@
* \brief GPU specific API * \brief GPU specific API
*/ */
#include <tvm/runtime/config.h> #include <tvm/runtime/config.h>
#include <tvm/runtime/device_api.h>
#if TVM_CUDA_RUNTIME #if TVM_CUDA_RUNTIME
#include <dmlc/logging.h> #include <dmlc/logging.h>
#include <dmlc/thread_local.h>
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include "./cuda_common.h" #include "./cuda_common.h"
#include "../device_api.h"
namespace tvm { namespace tvm {
namespace runtime { namespace runtime {
...@@ -92,6 +93,11 @@ class CUDADeviceAPI final : public DeviceAPI { ...@@ -92,6 +93,11 @@ class CUDADeviceAPI final : public DeviceAPI {
CUDA_CALL(cudaStreamSynchronize(static_cast<cudaStream_t>(stream))); CUDA_CALL(cudaStreamSynchronize(static_cast<cudaStream_t>(stream)));
} }
void SetStream(TVMContext ctx, TVMStreamHandle stream) final {
CUDAThreadEntry::ThreadLocal()
->stream = static_cast<cudaStream_t>(stream);
}
private: private:
static void GPUCopy(const void* from, static void GPUCopy(const void* from,
void* to, void* to,
...@@ -106,6 +112,12 @@ class CUDADeviceAPI final : public DeviceAPI { ...@@ -106,6 +112,12 @@ class CUDADeviceAPI final : public DeviceAPI {
} }
}; };
typedef dmlc::ThreadLocalStore<CUDAThreadEntry> CUDAThreadStore;
CUDAThreadEntry* CUDAThreadEntry::ThreadLocal() {
return CUDAThreadStore::Get();
}
TVM_REGISTER_GLOBAL("device_api.gpu") TVM_REGISTER_GLOBAL("device_api.gpu")
.set_body([](TVMArgs args, TVMRetValue* rv) { .set_body([](TVMArgs args, TVMRetValue* rv) {
static CUDADeviceAPI inst; static CUDADeviceAPI inst;
......
...@@ -167,6 +167,7 @@ class CUDAWrappedFunc { ...@@ -167,6 +167,7 @@ class CUDAWrappedFunc {
if (fcache_[device_id] == nullptr) { if (fcache_[device_id] == nullptr) {
fcache_[device_id] = m_->GetFunc(device_id, func_name_); fcache_[device_id] = m_->GetFunc(device_id, func_name_);
} }
CUstream strm = static_cast<CUstream>(CUDAThreadEntry::ThreadLocal()->stream);
ThreadWorkLoad wl = thread_axis_cfg_.Extract(args); ThreadWorkLoad wl = thread_axis_cfg_.Extract(args);
CUDA_DRIVER_CALL(cuLaunchKernel( CUDA_DRIVER_CALL(cuLaunchKernel(
fcache_[device_id], fcache_[device_id],
...@@ -176,7 +177,7 @@ class CUDAWrappedFunc { ...@@ -176,7 +177,7 @@ class CUDAWrappedFunc {
wl.block_dim(0), wl.block_dim(0),
wl.block_dim(1), wl.block_dim(1),
wl.block_dim(2), wl.block_dim(2),
0, nullptr, void_args, 0)); 0, strm, void_args, 0));
} }
private: private:
......
...@@ -16,11 +16,11 @@ ...@@ -16,11 +16,11 @@
#include <tvm/runtime/config.h> #include <tvm/runtime/config.h>
#include <tvm/runtime/c_runtime_api.h> #include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/packed_func.h> #include <tvm/runtime/packed_func.h>
#include <tvm/runtime/device_api.h>
#include <dmlc/logging.h> #include <dmlc/logging.h>
#include <mutex> #include <mutex>
#include <string> #include <string>
#include <vector> #include <vector>
#include "../device_api.h"
namespace tvm { namespace tvm {
namespace runtime { namespace runtime {
......
...@@ -9,10 +9,10 @@ ...@@ -9,10 +9,10 @@
#include <tvm/runtime/config.h> #include <tvm/runtime/config.h>
#include <tvm/runtime/c_runtime_api.h> #include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/packed_func.h> #include <tvm/runtime/packed_func.h>
#include <tvm/runtime/device_api.h>
#include <dmlc/logging.h> #include <dmlc/logging.h>
#if TVM_OPENCL_RUNTIME
#include "../device_api.h"
#if TVM_OPENCL_RUNTIME
#ifdef __APPLE__ #ifdef __APPLE__
#include <OpenCL/opencl.h> #include <OpenCL/opencl.h>
#else #else
......
...@@ -4,8 +4,8 @@ ...@@ -4,8 +4,8 @@
*/ */
#include <dmlc/logging.h> #include <dmlc/logging.h>
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <tvm/runtime/device_api.h>
#include "./rpc_session.h" #include "./rpc_session.h"
#include "../device_api.h"
namespace tvm { namespace tvm {
namespace runtime { namespace runtime {
......
...@@ -4,11 +4,12 @@ ...@@ -4,11 +4,12 @@
* \brief RPC session for remote function call. * \brief RPC session for remote function call.
*/ */
#include <tvm/runtime/packed_func.h> #include <tvm/runtime/packed_func.h>
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/registry.h>
#include <memory> #include <memory>
#include <array> #include <array>
#include <chrono> #include <chrono>
#include "./rpc_session.h" #include "./rpc_session.h"
#include "../device_api.h"
namespace tvm { namespace tvm {
namespace runtime { namespace runtime {
......
...@@ -7,9 +7,9 @@ ...@@ -7,9 +7,9 @@
#define TVM_RUNTIME_RPC_RPC_SESSION_H_ #define TVM_RUNTIME_RPC_RPC_SESSION_H_
#include <tvm/runtime/packed_func.h> #include <tvm/runtime/packed_func.h>
#include <tvm/runtime/device_api.h>
#include <mutex> #include <mutex>
#include <string> #include <string>
#include "../device_api.h"
#include "../../common/socket.h" #include "../../common/socket.h"
namespace tvm { namespace tvm {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment