/*!
 *  Copyright (c) 2017 by Contributors
 * \file opencl_device_api.cc
 * \brief OpenCL-backed implementation of the TVM DeviceAPI.
 */
#include "./opencl_common.h"

#if TVM_OPENCL_RUNTIME

#include <tvm/runtime/registry.h>
#include <dmlc/thread_local.h>

namespace tvm {
namespace runtime {
namespace cl {

const std::shared_ptr<OpenCLWorkspace>& OpenCLWorkspace::Global() {
  static std::shared_ptr<OpenCLWorkspace> inst = std::make_shared<OpenCLWorkspace>();
  return inst;
19 20
}

// Records ctx.device_id as the active device in the calling thread's
// OpenCLThreadEntry.
void OpenCLWorkspace::SetDevice(TVMContext ctx) {
  OpenCLThreadEntry* entry = OpenCLThreadEntry::ThreadLocal();
  entry->context.device_id = ctx.device_id;
}

void OpenCLWorkspace::GetAttr(
26
    TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) {
27
  this->Init();
28
  size_t index = static_cast<size_t>(ctx.device_id);
29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51
  if (kind == kExist) {
    *rv = static_cast<int>(index< devices.size());
    return;
  }
  CHECK_LT(index, devices.size())
      << "Invalid device id " << index;
  size_t value;
  switch (kind) {
    case kMaxThreadsPerBlock: {
      OPENCL_CALL(clGetDeviceInfo(
          devices[index],  CL_DEVICE_MAX_WORK_GROUP_SIZE,
          sizeof(size_t), &value, nullptr));
      *rv = static_cast<int64_t>(value);
      break;
    }
    case kWarpSize: {
      *rv = 1;
      break;
    }
    case kExist: break;
  }
}

// Creates a CL_MEM_READ_WRITE buffer of `size` bytes in the workspace context
// and returns its cl_mem handle as an opaque pointer. `ctx` and `alignment`
// are not consulted by this implementation.
void* OpenCLWorkspace::AllocDataSpace(
    TVMContext ctx, size_t size, size_t alignment) {
  this->Init();
  CHECK(context != nullptr) << "No OpenCL device";
  cl_int status;
  cl_mem buffer = clCreateBuffer(
      this->context, CL_MEM_READ_WRITE, size, nullptr, &status);
  OPENCL_CHECK_ERROR(status);
  return static_cast<void*>(buffer);
}

// Releases a buffer handle previously produced by AllocDataSpace.
void OpenCLWorkspace::FreeDataSpace(TVMContext ctx, void* ptr) {
  OPENCL_CALL(clReleaseMemObject(static_cast<cl_mem>(ptr)));
}

void OpenCLWorkspace::CopyDataFromTo(const void* from,
69
                                     size_t from_offset,
70
                                     void* to,
71
                                     size_t to_offset,
72 73 74 75
                                     size_t size,
                                     TVMContext ctx_from,
                                     TVMContext ctx_to,
                                     TVMStreamHandle stream) {
76
  this->Init();
77 78 79 80 81 82
  CHECK(stream == nullptr);
  if (ctx_from.device_type == kOpenCL && ctx_to.device_type == kOpenCL) {
    OPENCL_CALL(clEnqueueCopyBuffer(
        this->GetQueue(ctx_to),
        static_cast<cl_mem>((void*)from),  // NOLINT(*)
        static_cast<cl_mem>(to),
83
        from_offset, to_offset, size, 0, nullptr, nullptr));
84 85 86 87
  } else if (ctx_from.device_type == kOpenCL && ctx_to.device_type == kCPU) {
    OPENCL_CALL(clEnqueueReadBuffer(
        this->GetQueue(ctx_from),
        static_cast<cl_mem>((void*)from),  // NOLINT(*)
88 89
        CL_FALSE, from_offset, size,
        static_cast<char*>(to) + to_offset,
90 91 92 93 94 95
        0, nullptr, nullptr));
    OPENCL_CALL(clFinish(this->GetQueue(ctx_from)));
  } else if (ctx_from.device_type == kCPU && ctx_to.device_type == kOpenCL) {
    OPENCL_CALL(clEnqueueWriteBuffer(
        this->GetQueue(ctx_to),
        static_cast<cl_mem>(to),
96 97
        CL_FALSE, to_offset, size,
        static_cast<const char*>(from) + from_offset,
98 99 100 101 102 103 104 105 106 107 108 109
        0, nullptr, nullptr));
    OPENCL_CALL(clFinish(this->GetQueue(ctx_to)));
  } else {
    LOG(FATAL) << "Expect copy from/to OpenCL or between OpenCL";
  }
}

// Waits (clFinish) for every command already enqueued on ctx's queue.
// Only the default (null) stream is supported.
void OpenCLWorkspace::StreamSync(TVMContext ctx, TVMStreamHandle stream) {
  CHECK(stream == nullptr);
  auto queue = this->GetQueue(ctx);
  OPENCL_CALL(clFinish(queue));
}

// Takes a temporary buffer from the calling thread's workspace pool.
void* OpenCLWorkspace::AllocWorkspace(TVMContext ctx, size_t size) {
  OpenCLThreadEntry* entry = OpenCLThreadEntry::ThreadLocal();
  return entry->pool.AllocWorkspace(ctx, size);
}

// Hands a buffer obtained from AllocWorkspace back to the calling thread's
// workspace pool.
void OpenCLWorkspace::FreeWorkspace(TVMContext ctx, void* data) {
  OpenCLThreadEntry* entry = OpenCLThreadEntry::ThreadLocal();
  entry->pool.FreeWorkspace(ctx, data);
}

// Per-thread storage for OpenCLThreadEntry objects, provided by dmlc.
using OpenCLThreadStore = dmlc::ThreadLocalStore<OpenCLThreadEntry>;

// Returns the OpenCLThreadEntry belonging to the calling thread.
OpenCLThreadEntry* OpenCLThreadEntry::ThreadLocal() {
  OpenCLThreadEntry* entry = OpenCLThreadStore::Get();
  return entry;
}

// Returns the string-valued platform attribute `param_name` (e.g.
// CL_PLATFORM_NAME) for platform `pid`.
// The size OpenCL reports includes the C-string NUL terminator. That
// terminator is stripped so the result prints and compares as an ordinary
// std::string, with no embedded '\0'.
std::string GetPlatformInfo(
    cl_platform_id pid, cl_platform_info param_name) {
  size_t ret_size = 0;
  OPENCL_CALL(clGetPlatformInfo(pid, param_name, 0, nullptr, &ret_size));
  if (ret_size == 0) return std::string();
  std::string ret(ret_size, '\0');
  OPENCL_CALL(clGetPlatformInfo(pid, param_name, ret_size, &ret[0], nullptr));
  // Truncate at the first NUL, which is the terminator OpenCL wrote.
  const size_t end = ret.find('\0');
  if (end != std::string::npos) ret.resize(end);
  return ret;
}

// Returns the string-valued device attribute `param_name` (e.g.
// CL_DEVICE_NAME) for device `pid`. Intended for string-typed parameters only.
// The size OpenCL reports includes the C-string NUL terminator. That
// terminator is stripped so the result prints and compares as an ordinary
// std::string, with no embedded '\0'.
std::string GetDeviceInfo(
    cl_device_id pid, cl_device_info param_name) {
  size_t ret_size = 0;
  OPENCL_CALL(clGetDeviceInfo(pid, param_name, 0, nullptr, &ret_size));
  if (ret_size == 0) return std::string();
  std::string ret(ret_size, '\0');
  OPENCL_CALL(clGetDeviceInfo(pid, param_name, ret_size, &ret[0], nullptr));
  // Truncate at the first NUL, which is the terminator OpenCL wrote.
  const size_t end = ret.find('\0');
  if (end != std::string::npos) ret.resize(end);
  return ret;
}

std::vector<cl_platform_id> GetPlatformIDs() {
  cl_uint ret_size;
146
  cl_int code = clGetPlatformIDs(0, nullptr, &ret_size);
147
  std::vector<cl_platform_id> ret;
148
  if (code != CL_SUCCESS) return ret;
149 150 151 152 153 154 155 156 157
  ret.resize(ret_size);
  OPENCL_CALL(clGetPlatformIDs(ret_size, &ret[0], nullptr));
  return ret;
}

// Returns the ids of the devices on platform `pid` that match `device_type`.
// `device_type` may be "cpu", "gpu" or "accelerator"; any other string
// selects CL_DEVICE_TYPE_ALL. Returns an empty vector if the query fails or
// finds no matching device.
std::vector<cl_device_id> GetDeviceIDs(
    cl_platform_id pid, std::string device_type) {
  cl_device_type dtype = CL_DEVICE_TYPE_ALL;
  if (device_type == "cpu") {
    dtype = CL_DEVICE_TYPE_CPU;
  } else if (device_type == "gpu") {
    dtype = CL_DEVICE_TYPE_GPU;
  } else if (device_type == "accelerator") {
    dtype = CL_DEVICE_TYPE_ACCELERATOR;
  }
  cl_uint ret_size = 0;
  cl_int code = clGetDeviceIDs(pid, dtype, 0, nullptr, &ret_size);
  std::vector<cl_device_id> ret;
  // A zero count must short-circuit for two reasons: indexing an empty
  // vector is undefined behaviour, and OpenCL rejects num_entries == 0 with a
  // non-null output array (CL_INVALID_VALUE).
  if (code != CL_SUCCESS || ret_size == 0) return ret;
  ret.resize(ret_size);
  OPENCL_CALL(clGetDeviceIDs(pid, dtype, ret_size, ret.data(), nullptr));
  return ret;
}

// Returns true when `value` is empty (which matches anything) or when it
// occurs as a substring of platform attribute `param_name` for `pid`.
bool MatchPlatformInfo(
    cl_platform_id pid,
    cl_platform_info param_name,
    std::string value) {
  if (value.empty()) {
    return true;
  }
  const std::string info = GetPlatformInfo(pid, param_name);
  const bool found = info.find(value) != std::string::npos;
  return found;
}

void OpenCLWorkspace::Init() {
  if (initialized_) return;
  std::lock_guard<std::mutex>(this->mu);
  if (initialized_) return;
  initialized_ = true;
  if (context != nullptr) return;
184
  // matched platforms
185
  std::vector<cl_platform_id> platform_matched = cl::GetPlatformIDs();
186
  if (platform_matched.size() == 0) {
187 188
    LOG(WARNING) << "No OpenCL platform matched given existing options ...";
    return;
189 190 191 192
  }
  if (platform_matched.size() > 1) {
    LOG(WARNING) << "Multiple OpenCL platforms matched, use the first one ... ";
  }
193
  this->platform_id = platform_matched[0];
194
  LOG(INFO) << "Initialize OpenCL platform \'"
195
            << cl::GetPlatformInfo(this->platform_id, CL_PLATFORM_NAME) << '\'';
196
  std::vector<cl_device_id> devices_matched =
197 198 199 200 201 202
      cl::GetDeviceIDs(this->platform_id, "gpu");
  if (devices_matched.size() == 0) {
    LOG(WARNING) << "No OpenCL device any device matched given the options";
    return;
  }
  this->devices = devices_matched;
203
  cl_int err_code;
204 205
  this->context = clCreateContext(
      nullptr, this->devices.size(), &(this->devices[0]),
206 207
      nullptr, nullptr, &err_code);
  OPENCL_CHECK_ERROR(err_code);
208 209 210 211 212
  CHECK_EQ(this->queues.size(), 0U);
  for (size_t i = 0; i < this->devices.size(); ++i) {
    cl_device_id did = this->devices[i];
    this->queues.push_back(
        clCreateCommandQueue(this->context, did, 0, &err_code));
213 214 215 216 217 218 219
    OPENCL_CHECK_ERROR(err_code);
    LOG(INFO) << "opencl(" << i
              << ")=\'" << cl::GetDeviceInfo(did, CL_DEVICE_NAME)
              << "\' cl_device_id=" << did;
  }
}

// Triggers initialization of the global OpenCL workspace. `args` and `rv`
// are unused. Always returns true.
bool InitOpenCL(TVMArgs args, TVMRetValue* rv) {
  OpenCLWorkspace* workspace = cl::OpenCLWorkspace::Global().get();
  workspace->Init();
  return true;
}
// Registers "device_api.opencl". The registered function returns the global
// OpenCLWorkspace, viewed as a DeviceAPI*, as an opaque void* handle.
TVM_REGISTER_GLOBAL("device_api.opencl")
.set_body([](TVMArgs args, TVMRetValue* rv) {
    // Convert to the DeviceAPI base first so any base-pointer adjustment
    // happens before the handle is erased to void*.
    DeviceAPI* api = static_cast<DeviceAPI*>(OpenCLWorkspace::Global().get());
    *rv = static_cast<void*>(api);
  });

}  // namespace cl
}  // namespace runtime
}  // namespace tvm

#endif  // TVM_OPENCL_RUNTIME