Commit 3b8e70ae by Tianqi Chen Committed by GitHub

[RUNTIME] Move device_api to include (#185)

* [RUNTIME] Move device_api to include

* fix doxygen

* fix device api

* fx
parent fcfec961
......@@ -195,43 +195,6 @@ TVM_DLL int TVMModGetFunction(TVMModuleHandle mod,
TVM_DLL int TVMModFree(TVMModuleHandle mod);
/*!
* \brief Backend function for modules to get function
* from its environment mod_node (its imports and global function).
*
* The user do should not call TVMFuncFree on func.
*
* \note This API is supposed to be used by backend,
* it is not supposed to be used by user.
*
* \param mod_node The module handle.
* \param func_name The name of the function.
* \param out The result function.
* \return 0 when no error is thrown, -1 when failure happens
*/
TVM_DLL int TVMBackendGetFuncFromEnv(void* mod_node,
const char* func_name,
TVMFunctionHandle *out);
/*!
* \brief Backend function for running parallel for loop.
*
* \note This API is supposed to be used by backend,
* it is not supposed to be used by user.
*
* \param begin The start of iteration.
* \param end The end of iteration.
* \param lambda The lambda function to be executed.
* \param env The environment of lambda function.
*
* \return 0 when no error is thrown, -1 when failure happens
*/
TVM_DLL int TVMBackendParallelFor(
int64_t begin,
int64_t end,
int (*lambda)(int64_t begin, int64_t end, void* env),
void* env);
/*!
* \brief Free the function when it is no longer needed.
* \param func The function handle
* \return 0 when success, -1 when failure happens
......@@ -351,6 +314,44 @@ TVM_DLL int TVMFuncGetGlobal(const char* name, TVMFunctionHandle* out);
TVM_DLL int TVMFuncListGlobalNames(int *out_size,
const char*** out_array);
// Backend related functions.
/*!
* \brief Backend function for modules to get function
* from its environment mod_node (its imports and global function).
*
* The user do should not call TVMFuncFree on func.
*
* \note This API is supposed to be used by backend,
* it is not supposed to be used by user.
*
* \param mod_node The module handle.
* \param func_name The name of the function.
* \param out The result function.
* \return 0 when no error is thrown, -1 when failure happens
*/
TVM_DLL int TVMBackendGetFuncFromEnv(void* mod_node,
const char* func_name,
TVMFunctionHandle *out);
/*!
* \brief Backend function for running parallel for loop.
*
* \note This API is supposed to be used by backend,
* it is not supposed to be used by user.
*
* \param begin The start of iteration.
* \param end The end of iteration.
* \param lambda The lambda function to be executed.
* \param env The environment of lambda function.
*
* \return 0 when no error is thrown, -1 when failure happens
*/
TVM_DLL int TVMBackendParallelFor(
int64_t begin,
int64_t end,
int (*lambda)(int64_t begin, int64_t end, void* env),
void* env);
// Array related apis for quick proptyping
/*!
* \brief Allocate a nd-array's memory,
......@@ -368,6 +369,7 @@ TVM_DLL int TVMArrayAlloc(const tvm_index_t* shape,
TVMType dtype,
TVMContext ctx,
TVMArrayHandle* out);
/*!
* \brief Free the TVM Array.
* \param handle The array handle to be freed.
......@@ -385,6 +387,19 @@ TVM_DLL int TVMArrayFree(TVMArrayHandle handle);
TVM_DLL int TVMArrayCopyFromTo(TVMArrayHandle from,
TVMArrayHandle to,
TVMStreamHandle stream);
/*!
* \brief Set the runtime stream of current thread to be stream.
* The subsequent calls to the same device_type
* will use the setted stream handle.
* The specific type of stream is runtime device dependent.
*
* \param ctx The context.
* \param handle The stream handle.
* \return 0 when success, -1 when failure happens
*/
TVM_DLL int TVMSetStream(TVMContext ctx, TVMStreamHandle handle);
/*!
* \brief Wait until all computations on stream completes.
* \param ctx The ctx to be synchronized.
......
/*!
* Copyright (c) 2016 by Contributors
* \file device_api.h
* \brief Device specific API
* \brief Abstract device memory management API
*/
#ifndef TVM_RUNTIME_DEVICE_API_H_
#define TVM_RUNTIME_DEVICE_API_H_
#include <tvm/base.h>
#include <tvm/runtime/c_runtime_api.h>
#include <string>
#include "./packed_func.h"
#include "./c_runtime_api.h"
namespace tvm {
namespace runtime {
/*!
* \brief the query type into GetAttr
*/
enum DeviceAttrKind : int {
kExist = 0,
kMaxThreadsPerBlock = 1,
kWarpSize = 2
};
/*!
* \brief TVM Runtime Device API, abstracts the device
* specific interface for memory management.
*/
class DeviceAPI {
public:
/*! \brief virtual destructor */
......@@ -34,6 +39,7 @@ class DeviceAPI {
* \param ctx The device context
* \param kind The result kind
* \param rv The return value.
* \sa DeviceAttrKind
*/
virtual void GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) = 0;
/*!
......@@ -53,7 +59,6 @@ class DeviceAPI {
virtual void FreeDataSpace(TVMContext ctx, void* ptr) = 0;
/*!
* \brief copy data from one place to another
* \param dev The device to perform operation.
* \param from The source array.
* \param from_offset The byte offeset in the from.
* \param to The target array.
......@@ -78,6 +83,12 @@ class DeviceAPI {
*/
virtual void StreamSync(TVMContext ctx, TVMStreamHandle stream) = 0;
/*!
* \brief Set the stream
* \param ctx The context to set stream.
* \param stream The stream to be set.
*/
virtual void SetStream(TVMContext ctx, TVMStreamHandle stream) {}
/*!
* \brief Get device API base don context.
* \param ctx The context
* \param allow_missing Whether allow missing
......@@ -88,21 +99,6 @@ class DeviceAPI {
/*! \brief The device type bigger than this is RPC device */
constexpr int kRPCSessMask = 128;
/*!
* \brief The name of Device API factory.
* \param type The device type.
*/
inline std::string DeviceName(int type) {
switch (type) {
case kCPU: return "cpu";
case kGPU: return "gpu";
case kOpenCL: return "opencl";
case kMetal: return "metal";
case kVPI: return "vpi";
default: LOG(FATAL) << "unknown type =" << type; return "Unknown";
}
}
} // namespace runtime
} // namespace tvm
#endif // TVM_RUNTIME_DEVICE_API_H_
......@@ -4,12 +4,12 @@
* \brief Simulated VPI RAM device.
*/
#include <tvm/runtime/registry.h>
#include <tvm/runtime/device_api.h>
#include <tvm/packed_func_ext.h>
#include <cstdlib>
#include <unordered_map>
#include <map>
#include <queue>
#include "../../runtime/device_api.h"
#include "./vpi_session.h"
namespace tvm {
......
......@@ -8,7 +8,7 @@
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/module.h>
#include <tvm/runtime/registry.h>
#include <dmlc/timer.h>
#include <tvm/runtime/device_api.h>
#include <array>
#include <algorithm>
#include <string>
......@@ -16,11 +16,25 @@
#include <thread>
#include <mutex>
#include "./runtime_base.h"
#include "./device_api.h"
namespace tvm {
namespace runtime {
/*!
* \brief The name of Device API factory.
* \param type The device type.
*/
inline std::string DeviceName(int type) {
switch (type) {
case kCPU: return "cpu";
case kGPU: return "gpu";
case kOpenCL: return "opencl";
case kMetal: return "metal";
case kVPI: return "vpi";
default: LOG(FATAL) << "unknown type =" << type; return "Unknown";
}
}
class DeviceAPIManager {
public:
static const int kMaxDeviceAPI = 32;
......@@ -380,6 +394,12 @@ int TVMArrayCopyFromTo(TVMArrayHandle from,
API_END();
}
int TVMSetStream(TVMContext ctx, TVMStreamHandle stream) {
API_BEGIN();
DeviceAPIManager::Get(ctx)->SetStream(ctx, stream);
API_END();
}
int TVMSynchronize(TVMContext ctx, TVMStreamHandle stream) {
API_BEGIN();
DeviceAPIManager::Get(ctx)->StreamSync(ctx, stream);
......
......@@ -4,9 +4,9 @@
*/
#include <dmlc/logging.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/device_api.h>
#include <cstdlib>
#include <cstring>
#include "./device_api.h"
namespace tvm {
namespace runtime {
......
......@@ -34,13 +34,14 @@ namespace runtime {
<< "CUDA: " << cudaGetErrorString(e); \
}
/*!
* \brief Compile code into ptx using NVRTC
* \param code The cuda code.
* \return The PTX code.
*/
std::string NVRTCCompile(const std::string& code);
/*! \brief Thread local workspace */
class CUDAThreadEntry {
public:
/*! \brief The cuda stream */
cudaStream_t stream{nullptr};
// get the threadlocal workspace
static CUDAThreadEntry* ThreadLocal();
};
} // namespace runtime
} // namespace tvm
#endif // TVM_CUDA_RUNTIME
......
......@@ -4,13 +4,14 @@
* \brief GPU specific API
*/
#include <tvm/runtime/config.h>
#include <tvm/runtime/device_api.h>
#if TVM_CUDA_RUNTIME
#include <dmlc/logging.h>
#include <dmlc/thread_local.h>
#include <tvm/runtime/registry.h>
#include <cuda_runtime.h>
#include "./cuda_common.h"
#include "../device_api.h"
namespace tvm {
namespace runtime {
......@@ -92,6 +93,11 @@ class CUDADeviceAPI final : public DeviceAPI {
CUDA_CALL(cudaStreamSynchronize(static_cast<cudaStream_t>(stream)));
}
void SetStream(TVMContext ctx, TVMStreamHandle stream) final {
CUDAThreadEntry::ThreadLocal()
->stream = static_cast<cudaStream_t>(stream);
}
private:
static void GPUCopy(const void* from,
void* to,
......@@ -106,6 +112,12 @@ class CUDADeviceAPI final : public DeviceAPI {
}
};
typedef dmlc::ThreadLocalStore<CUDAThreadEntry> CUDAThreadStore;
CUDAThreadEntry* CUDAThreadEntry::ThreadLocal() {
return CUDAThreadStore::Get();
}
TVM_REGISTER_GLOBAL("device_api.gpu")
.set_body([](TVMArgs args, TVMRetValue* rv) {
static CUDADeviceAPI inst;
......
......@@ -167,6 +167,7 @@ class CUDAWrappedFunc {
if (fcache_[device_id] == nullptr) {
fcache_[device_id] = m_->GetFunc(device_id, func_name_);
}
CUstream strm = static_cast<CUstream>(CUDAThreadEntry::ThreadLocal()->stream);
ThreadWorkLoad wl = thread_axis_cfg_.Extract(args);
CUDA_DRIVER_CALL(cuLaunchKernel(
fcache_[device_id],
......@@ -176,7 +177,7 @@ class CUDAWrappedFunc {
wl.block_dim(0),
wl.block_dim(1),
wl.block_dim(2),
0, nullptr, void_args, 0));
0, strm, void_args, 0));
}
private:
......
......@@ -16,11 +16,11 @@
#include <tvm/runtime/config.h>
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/device_api.h>
#include <dmlc/logging.h>
#include <mutex>
#include <string>
#include <vector>
#include "../device_api.h"
namespace tvm {
namespace runtime {
......
......@@ -9,10 +9,10 @@
#include <tvm/runtime/config.h>
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/device_api.h>
#include <dmlc/logging.h>
#if TVM_OPENCL_RUNTIME
#include "../device_api.h"
#if TVM_OPENCL_RUNTIME
#ifdef __APPLE__
#include <OpenCL/opencl.h>
#else
......
......@@ -4,8 +4,8 @@
*/
#include <dmlc/logging.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/device_api.h>
#include "./rpc_session.h"
#include "../device_api.h"
namespace tvm {
namespace runtime {
......
......@@ -4,11 +4,12 @@
* \brief RPC session for remote function call.
*/
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/registry.h>
#include <memory>
#include <array>
#include <chrono>
#include "./rpc_session.h"
#include "../device_api.h"
namespace tvm {
namespace runtime {
......
......@@ -7,9 +7,9 @@
#define TVM_RUNTIME_RPC_RPC_SESSION_H_
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/device_api.h>
#include <mutex>
#include <string>
#include "../device_api.h"
#include "../../common/socket.h"
namespace tvm {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment