/*!
 *  Copyright (c) 2017 by Contributors
 * \file opencl_device_api.cc
 */
#include <tvm/runtime/registry.h>
#include <dmlc/thread_local.h>
#include <sstream>
#include <vector>
#include "opencl_common.h"

namespace tvm {
namespace runtime {
namespace cl {

13 14 15 16
/*!
 * \brief Fetch the OpenCL thread entry for the calling thread.
 * \return The pointer produced by OpenCLThreadEntry::ThreadLocal().
 */
OpenCLThreadEntry* OpenCLWorkspace::GetThreadEntry() {
  OpenCLThreadEntry* entry = OpenCLThreadEntry::ThreadLocal();
  return entry;
}

17 18 19
/*!
 * \brief Access the process-wide OpenCLWorkspace instance.
 * \return Reference to the shared_ptr that owns the singleton.
 */
const std::shared_ptr<OpenCLWorkspace>& OpenCLWorkspace::Global() {
  // Function-local static: the workspace is constructed on the first call
  // and the same shared_ptr is handed out on every later call.
  static std::shared_ptr<OpenCLWorkspace> inst = std::make_shared<OpenCLWorkspace>();
  return inst;
}

22
/*!
 * \brief Record the device selected by the calling thread.
 * \param ctx Context whose device_id is stored in this thread's entry.
 */
void OpenCLWorkspace::SetDevice(TVMContext ctx) {
  // Only the device id is copied; the rest of the thread entry's context
  // is left as it was.
  GetThreadEntry()->context.device_id = ctx.device_id;
}

void OpenCLWorkspace::GetAttr(
27
    TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) {
28
  this->Init();
29
  size_t index = static_cast<size_t>(ctx.device_id);
30 31 32 33 34 35 36
  if (kind == kExist) {
    *rv = static_cast<int>(index< devices.size());
    return;
  }
  CHECK_LT(index, devices.size())
      << "Invalid device id " << index;
  switch (kind) {
37
    case kExist: break;
38
    case kMaxThreadsPerBlock: {
39
      size_t value;
40 41 42 43 44 45 46
      OPENCL_CALL(clGetDeviceInfo(
          devices[index],  CL_DEVICE_MAX_WORK_GROUP_SIZE,
          sizeof(size_t), &value, nullptr));
      *rv = static_cast<int64_t>(value);
      break;
    }
    case kWarpSize: {
47
      /* TODO: the warp size of OpenCL device is not always 1
48
               e.g. Intel Graphics has a sub group concept which contains 8 - 32 work items,
49 50 51
               corresponding to the number of SIMD entries the heardware configures.
               We need to figure out a way to query this information from the hardware.
      */
52 53 54
      *rv = 1;
      break;
    }
55 56 57 58 59 60 61 62 63
    case kMaxSharedMemoryPerBlock: {
      cl_ulong value;
      OPENCL_CALL(clGetDeviceInfo(
          devices[index], CL_DEVICE_LOCAL_MEM_SIZE,
          sizeof(cl_ulong), &value, nullptr));
      *rv = static_cast<int64_t>(value);
      break;
    }
    case kComputeVersion: return;
64 65 66 67 68 69 70 71
    case kDeviceName: {
      char value[128] = {0};
      OPENCL_CALL(clGetDeviceInfo(
          devices[index], CL_DEVICE_NAME,
          sizeof(value) - 1, value, nullptr));
      *rv = std::string(value);
      break;
    }
72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87
    case kMaxClockRate: {
      cl_uint value;
      OPENCL_CALL(clGetDeviceInfo(
          devices[index], CL_DEVICE_MAX_CLOCK_FREQUENCY,
          sizeof(cl_uint), &value, nullptr));
      *rv = static_cast<int32_t>(value);
      break;
    }
    case kMultiProcessorCount: {
      cl_uint value;
      OPENCL_CALL(clGetDeviceInfo(
          devices[index], CL_DEVICE_MAX_COMPUTE_UNITS,
          sizeof(cl_uint), &value, nullptr));
      *rv = static_cast<int32_t>(value);
      break;
    }
88 89 90 91 92 93 94 95 96 97
    case kMaxThreadDimensions: {
      size_t dims[3];
      OPENCL_CALL(clGetDeviceInfo(
          devices[index], CL_DEVICE_MAX_WORK_ITEM_SIZES, sizeof(dims), dims, nullptr));

      std::stringstream ss;  // use json string to return multiple int values;
      ss << "[" << dims[0] <<", " << dims[1] << ", " << dims[2] << "]";
      *rv = ss.str();
      break;
    }
98 99 100
  }
}

101
void* OpenCLWorkspace::AllocDataSpace(
102
    TVMContext ctx, size_t size, size_t alignment, TVMType type_hint) {
103 104
  this->Init();
  CHECK(context != nullptr) << "No OpenCL device";
105 106 107 108 109 110 111 112 113 114 115 116 117
  cl_int err_code;
  cl_mem mptr = clCreateBuffer(
      this->context, CL_MEM_READ_WRITE, size, nullptr, &err_code);
  OPENCL_CHECK_ERROR(err_code);
  return mptr;
}

/*!
 * \brief Release a buffer handle.
 * \param ctx Not read by this function.
 * \param ptr Handle to release; it is cast to cl_mem and passed to
 *            clReleaseMemObject.
 */
void OpenCLWorkspace::FreeDataSpace(TVMContext ctx, void* ptr) {
  OPENCL_CALL(clReleaseMemObject(static_cast<cl_mem>(ptr)));
}

void OpenCLWorkspace::CopyDataFromTo(const void* from,
118
                                     size_t from_offset,
119
                                     void* to,
120
                                     size_t to_offset,
121 122 123
                                     size_t size,
                                     TVMContext ctx_from,
                                     TVMContext ctx_to,
124
                                     TVMType type_hint,
125
                                     TVMStreamHandle stream) {
126
  this->Init();
127
  CHECK(stream == nullptr);
128
  if (IsOpenCLDevice(ctx_from) && IsOpenCLDevice(ctx_to)) {
129 130 131 132
    OPENCL_CALL(clEnqueueCopyBuffer(
        this->GetQueue(ctx_to),
        static_cast<cl_mem>((void*)from),  // NOLINT(*)
        static_cast<cl_mem>(to),
133
        from_offset, to_offset, size, 0, nullptr, nullptr));
134
  } else if (IsOpenCLDevice(ctx_from) && ctx_to.device_type == kDLCPU) {
135 136 137
    OPENCL_CALL(clEnqueueReadBuffer(
        this->GetQueue(ctx_from),
        static_cast<cl_mem>((void*)from),  // NOLINT(*)
138 139
        CL_FALSE, from_offset, size,
        static_cast<char*>(to) + to_offset,
140 141
        0, nullptr, nullptr));
    OPENCL_CALL(clFinish(this->GetQueue(ctx_from)));
142
  } else if (ctx_from.device_type == kDLCPU && IsOpenCLDevice(ctx_to)) {
143 144 145
    OPENCL_CALL(clEnqueueWriteBuffer(
        this->GetQueue(ctx_to),
        static_cast<cl_mem>(to),
146 147
        CL_FALSE, to_offset, size,
        static_cast<const char*>(from) + from_offset,
148 149 150 151 152 153 154 155 156 157 158 159
        0, nullptr, nullptr));
    OPENCL_CALL(clFinish(this->GetQueue(ctx_to)));
  } else {
    LOG(FATAL) << "Expect copy from/to OpenCL or between OpenCL";
  }
}

/*!
 * \brief Wait until the command queue selected by ctx has finished.
 * \param ctx    Context passed to GetQueue to pick the queue.
 * \param stream Must be nullptr.
 */
void OpenCLWorkspace::StreamSync(TVMContext ctx, TVMStreamHandle stream) {
  CHECK(stream == nullptr);
  // Argument evaluation order matches the original: GetQueue runs inside the
  // clFinish call wrapped by OPENCL_CALL.
  OPENCL_CALL(clFinish(GetQueue(ctx)));
}

160 161 162
/*!
 * \brief Allocate temporary workspace from the calling thread's pool.
 * \param ctx       Context forwarded to the pool.
 * \param size      Requested size, forwarded to the pool.
 * \param type_hint Not used by this function.
 * \return Whatever the thread-local pool's AllocWorkspace returns.
 */
void* OpenCLWorkspace::AllocWorkspace(TVMContext ctx,
                                      size_t size,
                                      TVMType type_hint) {
  return GetThreadEntry()->pool.AllocWorkspace(ctx, size);
}

/*!
 * \brief Hand a workspace allocation back to the calling thread's pool.
 * \param ctx  Context forwarded to the pool.
 * \param data Pointer forwarded to the pool's FreeWorkspace.
 */
void OpenCLWorkspace::FreeWorkspace(TVMContext ctx, void* data) {
  GetThreadEntry()->pool.FreeWorkspace(ctx, data);
}

170 171 172 173 174 175
/*! \brief Store type that holds the OpenCLThreadEntry instances. */
using OpenCLThreadStore = dmlc::ThreadLocalStore<OpenCLThreadEntry>;

/*!
 * \brief Fetch the OpenCLThreadEntry held by OpenCLThreadStore.
 * \return The pointer returned by OpenCLThreadStore::Get().
 */
OpenCLThreadEntry* OpenCLThreadEntry::ThreadLocal() {
  OpenCLThreadEntry* entry = OpenCLThreadStore::Get();
  return entry;
}

176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197
/*!
 * \brief Read a string-valued platform parameter.
 *
 * Uses the usual two-step query: first ask for the required size, then fetch
 * the value into a buffer of that size.
 *
 * \param pid        Platform to query.
 * \param param_name Parameter to read; expected to be string-valued.
 * \return The parameter text without its terminating NUL; empty when the
 *         platform reports a zero-length value.
 */
std::string GetPlatformInfo(
    cl_platform_id pid, cl_platform_info param_name) {
  size_t ret_size = 0;
  OPENCL_CALL(clGetPlatformInfo(pid, param_name, 0, nullptr, &ret_size));
  std::string ret;
  if (ret_size == 0) return ret;
  ret.resize(ret_size);
  OPENCL_CALL(clGetPlatformInfo(pid, param_name, ret_size, &ret[0], nullptr));
  // The size reported for a string includes its NUL terminator. Drop it so
  // the std::string's length matches the visible text and comparisons
  // against ordinary strings behave as expected.
  if (ret.back() == '\0') ret.pop_back();
  return ret;
}

/*!
 * \brief Read a string-valued device parameter.
 *
 * Uses the usual two-step query: first ask for the required size, then fetch
 * the value into a buffer of that size.
 *
 * \param pid        Device to query.
 * \param param_name Parameter to read; expected to be string-valued.
 * \return The parameter text without its terminating NUL; empty when the
 *         device reports a zero-length value.
 */
std::string GetDeviceInfo(
    cl_device_id pid, cl_device_info param_name) {
  size_t ret_size = 0;
  OPENCL_CALL(clGetDeviceInfo(pid, param_name, 0, nullptr, &ret_size));
  std::string ret;
  if (ret_size == 0) return ret;
  ret.resize(ret_size);
  OPENCL_CALL(clGetDeviceInfo(pid, param_name, ret_size, &ret[0], nullptr));
  // The size reported for a string includes its NUL terminator. Drop it so
  // the std::string's length matches the visible text.
  if (ret.back() == '\0') ret.pop_back();
  return ret;
}

std::vector<cl_platform_id> GetPlatformIDs() {
  cl_uint ret_size;
198
  cl_int code = clGetPlatformIDs(0, nullptr, &ret_size);
199
  std::vector<cl_platform_id> ret;
200
  if (code != CL_SUCCESS) return ret;
201 202 203 204 205 206 207 208 209
  ret.resize(ret_size);
  OPENCL_CALL(clGetPlatformIDs(ret_size, &ret[0], nullptr));
  return ret;
}

/*!
 * \brief List the devices of a given kind on a platform.
 * \param pid         Platform to enumerate.
 * \param device_type "cpu", "gpu" or "accelerator"; any other value selects
 *                    CL_DEVICE_TYPE_ALL.
 * \return The matching device ids; empty when the count query fails or
 *         reports zero devices.
 */
std::vector<cl_device_id> GetDeviceIDs(
    cl_platform_id pid, std::string device_type) {
  cl_device_type dtype = CL_DEVICE_TYPE_ALL;
  if (device_type == "cpu") dtype = CL_DEVICE_TYPE_CPU;
  if (device_type == "gpu") dtype = CL_DEVICE_TYPE_GPU;
  if (device_type == "accelerator") dtype = CL_DEVICE_TYPE_ACCELERATOR;
  cl_uint ret_size = 0;
  cl_int code = clGetDeviceIDs(pid, dtype, 0, nullptr, &ret_size);
  std::vector<cl_device_id> ret;
  // A failed count query is treated as "no devices" rather than an error.
  // A zero count also returns here: the second call would otherwise be made
  // with zero entries and a pointer taken from an empty vector.
  if (code != CL_SUCCESS || ret_size == 0) return ret;
  ret.resize(ret_size);
  OPENCL_CALL(clGetDeviceIDs(pid, dtype, ret_size, ret.data(), nullptr));
  return ret;
}

/*!
 * \brief Check whether a platform parameter contains a given substring.
 * \param pid        Platform to inspect.
 * \param param_name Platform parameter to read via GetPlatformInfo.
 * \param value      Substring to look for; an empty value matches everything
 *                   and skips the platform query.
 * \return true when \p value is empty or occurs in the parameter text.
 */
bool MatchPlatformInfo(
    cl_platform_id pid,
    cl_platform_info param_name,
    std::string value) {
  if (value.empty()) {
    return true;
  }
  const std::string info = GetPlatformInfo(pid, param_name);
  const bool found = info.find(value) != std::string::npos;
  return found;
}

230 231
void OpenCLWorkspace::Init(const std::string& type_key, const std::string& device_type,
                           const std::string& platform_name) {
232
  if (initialized_) return;
233
  std::lock_guard<std::mutex> lock(this->mu);
234 235
  if (initialized_) return;
  if (context != nullptr) return;
236
  // matched platforms
237 238
  std::vector<cl_platform_id> platform_ids = cl::GetPlatformIDs();
  if (platform_ids.size() == 0) {
239 240
    LOG(WARNING) << "No OpenCL platform matched given existing options ...";
    return;
241
  }
242 243 244 245 246 247
  this->platform_id = nullptr;
  for (auto platform_id : platform_ids) {
    if (!MatchPlatformInfo(platform_id, CL_PLATFORM_NAME, platform_name)) {
      continue;
    }
    std::vector<cl_device_id> devices_matched = cl::GetDeviceIDs(platform_id, device_type);
248 249 250 251
    if ((devices_matched.size() == 0) && (device_type == "gpu")) {
      LOG(WARNING) << "Using CPU OpenCL device";
      devices_matched = cl::GetDeviceIDs(platform_id, "cpu");
    }
252
    if (devices_matched.size() > 0) {
253
      this->type_key = type_key;
254 255 256 257
      this->platform_id = platform_id;
      this->platform_name = cl::GetPlatformInfo(platform_id, CL_PLATFORM_NAME);
      this->device_type = device_type;
      this->devices = devices_matched;
258
      break;
259
    }
260
  }
261
  if (this->platform_id == nullptr) {
262 263
    LOG(WARNING) << "No OpenCL device";
    return;
264
  }
265
  cl_int err_code;
266 267
  this->context = clCreateContext(
      nullptr, this->devices.size(), &(this->devices[0]),
268 269
      nullptr, nullptr, &err_code);
  OPENCL_CHECK_ERROR(err_code);
270 271 272 273 274
  CHECK_EQ(this->queues.size(), 0U);
  for (size_t i = 0; i < this->devices.size(); ++i) {
    cl_device_id did = this->devices[i];
    this->queues.push_back(
        clCreateCommandQueue(this->context, did, 0, &err_code));
275 276
    OPENCL_CHECK_ERROR(err_code);
  }
277
  initialized_ = true;
278 279
}

280
// Registers a global function named "device_api.opencl" whose result is the
// OpenCLWorkspace singleton, handed back as an opaque void* handle.
TVM_REGISTER_GLOBAL("device_api.opencl")
.set_body([](TVMArgs args, TVMRetValue* rv) {
    // Convert to the DeviceAPI base pointer first, then erase the type.
    DeviceAPI* ptr = OpenCLWorkspace::Global().get();
    *rv = static_cast<void*>(ptr);
  });

286 287 288
}  // namespace cl
}  // namespace runtime
}  // namespace tvm