/*!
 *  Copyright (c) 2017 by Contributors
 * \file opencl_device_api.cc
 */
#include <tvm/runtime/registry.h>
#include <dmlc/thread_local.h>
#include <mutex>
#include <sstream>
#include <string>
#include <vector>
#include "opencl_common.h"

namespace tvm {
namespace runtime {
namespace cl {

13 14 15 16
OpenCLThreadEntry* OpenCLWorkspace::GetThreadEntry() {
  return OpenCLThreadEntry::ThreadLocal();
}

17 18 19
const std::shared_ptr<OpenCLWorkspace>& OpenCLWorkspace::Global() {
  static std::shared_ptr<OpenCLWorkspace> inst = std::make_shared<OpenCLWorkspace>();
  return inst;
20 21
}

22
void OpenCLWorkspace::SetDevice(TVMContext ctx) {
23
  GetThreadEntry()->context.device_id = ctx.device_id;
24 25 26
}

void OpenCLWorkspace::GetAttr(
27
    TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) {
28
  this->Init();
29
  size_t index = static_cast<size_t>(ctx.device_id);
30 31 32 33 34 35 36
  if (kind == kExist) {
    *rv = static_cast<int>(index< devices.size());
    return;
  }
  CHECK_LT(index, devices.size())
      << "Invalid device id " << index;
  switch (kind) {
37
    case kExist: break;
38
    case kMaxThreadsPerBlock: {
39
      size_t value;
40 41 42 43 44 45 46
      OPENCL_CALL(clGetDeviceInfo(
          devices[index],  CL_DEVICE_MAX_WORK_GROUP_SIZE,
          sizeof(size_t), &value, nullptr));
      *rv = static_cast<int64_t>(value);
      break;
    }
    case kWarpSize: {
47
      /* TODO: the warp size of OpenCL device is not always 1
48
               e.g. Intel Graphics has a sub group concept which contains 8 - 32 work items,
49 50 51
               corresponding to the number of SIMD entries the heardware configures.
               We need to figure out a way to query this information from the hardware.
      */
52 53 54
      *rv = 1;
      break;
    }
55 56 57 58 59 60 61 62 63
    case kMaxSharedMemoryPerBlock: {
      cl_ulong value;
      OPENCL_CALL(clGetDeviceInfo(
          devices[index], CL_DEVICE_LOCAL_MEM_SIZE,
          sizeof(cl_ulong), &value, nullptr));
      *rv = static_cast<int64_t>(value);
      break;
    }
    case kComputeVersion: return;
64 65 66 67 68 69 70 71
    case kDeviceName: {
      char value[128] = {0};
      OPENCL_CALL(clGetDeviceInfo(
          devices[index], CL_DEVICE_NAME,
          sizeof(value) - 1, value, nullptr));
      *rv = std::string(value);
      break;
    }
72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87
    case kMaxClockRate: {
      cl_uint value;
      OPENCL_CALL(clGetDeviceInfo(
          devices[index], CL_DEVICE_MAX_CLOCK_FREQUENCY,
          sizeof(cl_uint), &value, nullptr));
      *rv = static_cast<int32_t>(value);
      break;
    }
    case kMultiProcessorCount: {
      cl_uint value;
      OPENCL_CALL(clGetDeviceInfo(
          devices[index], CL_DEVICE_MAX_COMPUTE_UNITS,
          sizeof(cl_uint), &value, nullptr));
      *rv = static_cast<int32_t>(value);
      break;
    }
88 89 90 91 92 93 94 95 96 97
    case kMaxThreadDimensions: {
      size_t dims[3];
      OPENCL_CALL(clGetDeviceInfo(
          devices[index], CL_DEVICE_MAX_WORK_ITEM_SIZES, sizeof(dims), dims, nullptr));

      std::stringstream ss;  // use json string to return multiple int values;
      ss << "[" << dims[0] <<", " << dims[1] << ", " << dims[2] << "]";
      *rv = ss.str();
      break;
    }
98 99 100
  }
}

101
void* OpenCLWorkspace::AllocDataSpace(
102
    TVMContext ctx, size_t size, size_t alignment, TVMType type_hint) {
103 104
  this->Init();
  CHECK(context != nullptr) << "No OpenCL device";
105 106 107 108 109 110 111 112
  cl_int err_code;
  cl_mem mptr = clCreateBuffer(
      this->context, CL_MEM_READ_WRITE, size, nullptr, &err_code);
  OPENCL_CHECK_ERROR(err_code);
  return mptr;
}

/*! \brief Release a buffer obtained from AllocDataSpace. */
void OpenCLWorkspace::FreeDataSpace(TVMContext ctx, void* ptr) {
  // On some OpenCL platforms the memory object must no longer be referenced
  // by the command queue, so drain the queue before releasing it.
  auto queue = this->GetQueue(ctx);
  OPENCL_CALL(clFinish(queue));
  OPENCL_CALL(clReleaseMemObject(static_cast<cl_mem>(ptr)));
}

void OpenCLWorkspace::CopyDataFromTo(const void* from,
122
                                     size_t from_offset,
123
                                     void* to,
124
                                     size_t to_offset,
125 126 127
                                     size_t size,
                                     TVMContext ctx_from,
                                     TVMContext ctx_to,
128
                                     TVMType type_hint,
129
                                     TVMStreamHandle stream) {
130
  this->Init();
131
  CHECK(stream == nullptr);
132
  if (IsOpenCLDevice(ctx_from) && IsOpenCLDevice(ctx_to)) {
133 134 135 136
    OPENCL_CALL(clEnqueueCopyBuffer(
        this->GetQueue(ctx_to),
        static_cast<cl_mem>((void*)from),  // NOLINT(*)
        static_cast<cl_mem>(to),
137
        from_offset, to_offset, size, 0, nullptr, nullptr));
138
  } else if (IsOpenCLDevice(ctx_from) && ctx_to.device_type == kDLCPU) {
139 140 141
    OPENCL_CALL(clEnqueueReadBuffer(
        this->GetQueue(ctx_from),
        static_cast<cl_mem>((void*)from),  // NOLINT(*)
142 143
        CL_FALSE, from_offset, size,
        static_cast<char*>(to) + to_offset,
144 145
        0, nullptr, nullptr));
    OPENCL_CALL(clFinish(this->GetQueue(ctx_from)));
146
  } else if (ctx_from.device_type == kDLCPU && IsOpenCLDevice(ctx_to)) {
147 148 149
    OPENCL_CALL(clEnqueueWriteBuffer(
        this->GetQueue(ctx_to),
        static_cast<cl_mem>(to),
150 151
        CL_FALSE, to_offset, size,
        static_cast<const char*>(from) + from_offset,
152 153 154 155 156 157 158 159 160 161 162 163
        0, nullptr, nullptr));
    OPENCL_CALL(clFinish(this->GetQueue(ctx_to)));
  } else {
    LOG(FATAL) << "Expect copy from/to OpenCL or between OpenCL";
  }
}

/*! \brief Block until all commands queued for ctx's device have finished (default stream only). */
void OpenCLWorkspace::StreamSync(TVMContext ctx, TVMStreamHandle stream) {
  CHECK(stream == nullptr);
  auto queue = this->GetQueue(ctx);
  OPENCL_CALL(clFinish(queue));
}

164 165 166
void* OpenCLWorkspace::AllocWorkspace(TVMContext ctx,
                                      size_t size,
                                      TVMType type_hint) {
167
  return GetThreadEntry()->pool.AllocWorkspace(ctx, size);
168 169 170
}

/*! \brief Return scratch space to the calling thread's workspace pool. */
void OpenCLWorkspace::FreeWorkspace(TVMContext ctx, void* data) {
  OpenCLThreadEntry* entry = GetThreadEntry();
  entry->pool.FreeWorkspace(ctx, data);
}

174 175 176 177 178 179
typedef dmlc::ThreadLocalStore<OpenCLThreadEntry> OpenCLThreadStore;

OpenCLThreadEntry* OpenCLThreadEntry::ThreadLocal() {
  return OpenCLThreadStore::Get();
}

180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201
std::string GetPlatformInfo(
    cl_platform_id pid, cl_platform_info param_name) {
  size_t ret_size;
  OPENCL_CALL(clGetPlatformInfo(pid, param_name, 0, nullptr, &ret_size));
  std::string ret;
  ret.resize(ret_size);
  OPENCL_CALL(clGetPlatformInfo(pid, param_name, ret_size, &ret[0], nullptr));
  return ret;
}

/*!
 * \brief Fetch a string-valued device parameter (e.g. CL_DEVICE_NAME).
 * \param pid The device to query.
 * \param param_name The cl_device_info key.
 * \return The value without its terminating NUL character.
 */
std::string GetDeviceInfo(
    cl_device_id pid, cl_device_info param_name) {
  size_t ret_size = 0;
  OPENCL_CALL(clGetDeviceInfo(pid, param_name, 0, nullptr, &ret_size));
  if (ret_size == 0) return std::string();
  std::string ret(ret_size, '\0');
  OPENCL_CALL(clGetDeviceInfo(pid, param_name, ret_size, &ret[0], nullptr));
  // The reported size counts the NUL terminator. Trim at the first NUL so the
  // result compares equal to ordinary strings instead of carrying a stray '\0'.
  size_t nul = ret.find('\0');
  if (nul != std::string::npos) ret.resize(nul);
  return ret;
}

std::vector<cl_platform_id> GetPlatformIDs() {
  cl_uint ret_size;
202
  cl_int code = clGetPlatformIDs(0, nullptr, &ret_size);
203
  std::vector<cl_platform_id> ret;
204
  if (code != CL_SUCCESS) return ret;
205 206 207 208 209 210 211 212 213
  ret.resize(ret_size);
  OPENCL_CALL(clGetPlatformIDs(ret_size, &ret[0], nullptr));
  return ret;
}

/*!
 * \brief List devices of a platform, filtered by type.
 * \param pid The platform to query.
 * \param device_type "cpu", "gpu" or "accelerator". Any other value selects all device types.
 * \return The matching device ids. The list is empty if the count query fails or finds none.
 */
std::vector<cl_device_id> GetDeviceIDs(
    cl_platform_id pid, std::string device_type) {
  cl_device_type dtype = CL_DEVICE_TYPE_ALL;
  if (device_type == "cpu") {
    dtype = CL_DEVICE_TYPE_CPU;
  } else if (device_type == "gpu") {
    dtype = CL_DEVICE_TYPE_GPU;
  } else if (device_type == "accelerator") {
    dtype = CL_DEVICE_TYPE_ACCELERATOR;
  }
  cl_uint ret_size = 0;
  cl_int code = clGetDeviceIDs(pid, dtype, 0, nullptr, &ret_size);
  std::vector<cl_device_id> ret;
  // A zero count must not reach &ret[0] on an empty vector (UB), or a call
  // with num_entries == 0 and a non-null buffer, which is invalid.
  if (code != CL_SUCCESS || ret_size == 0) return ret;
  ret.resize(ret_size);
  OPENCL_CALL(clGetDeviceIDs(pid, dtype, ret_size, ret.data(), nullptr));
  return ret;
}

/*!
 * \brief Check whether a platform string parameter contains `value` as a substring.
 * \return true if `value` is empty (no filter) or occurs in the parameter value.
 */
bool MatchPlatformInfo(
    cl_platform_id pid,
    cl_platform_info param_name,
    std::string value) {
  if (value.empty()) {
    return true;
  }
  const std::string actual = GetPlatformInfo(pid, param_name);
  return actual.find(value) != std::string::npos;
}

234 235
void OpenCLWorkspace::Init(const std::string& type_key, const std::string& device_type,
                           const std::string& platform_name) {
236
  if (initialized_) return;
237
  std::lock_guard<std::mutex> lock(this->mu);
238 239
  if (initialized_) return;
  if (context != nullptr) return;
240
  this->type_key = type_key;
241
  // matched platforms
242 243
  std::vector<cl_platform_id> platform_ids = cl::GetPlatformIDs();
  if (platform_ids.size() == 0) {
244 245
    LOG(WARNING) << "No OpenCL platform matched given existing options ...";
    return;
246
  }
247 248 249 250 251 252
  this->platform_id = nullptr;
  for (auto platform_id : platform_ids) {
    if (!MatchPlatformInfo(platform_id, CL_PLATFORM_NAME, platform_name)) {
      continue;
    }
    std::vector<cl_device_id> devices_matched = cl::GetDeviceIDs(platform_id, device_type);
253 254 255 256
    if ((devices_matched.size() == 0) && (device_type == "gpu")) {
      LOG(WARNING) << "Using CPU OpenCL device";
      devices_matched = cl::GetDeviceIDs(platform_id, "cpu");
    }
257
    if (devices_matched.size() > 0) {
258 259 260 261
      this->platform_id = platform_id;
      this->platform_name = cl::GetPlatformInfo(platform_id, CL_PLATFORM_NAME);
      this->device_type = device_type;
      this->devices = devices_matched;
262
      break;
263
    }
264
  }
265
  if (this->platform_id == nullptr) {
266 267
    LOG(WARNING) << "No OpenCL device";
    return;
268
  }
269
  cl_int err_code;
270 271
  this->context = clCreateContext(
      nullptr, this->devices.size(), &(this->devices[0]),
272 273
      nullptr, nullptr, &err_code);
  OPENCL_CHECK_ERROR(err_code);
274 275 276 277 278
  CHECK_EQ(this->queues.size(), 0U);
  for (size_t i = 0; i < this->devices.size(); ++i) {
    cl_device_id did = this->devices[i];
    this->queues.push_back(
        clCreateCommandQueue(this->context, did, 0, &err_code));
279 280
    OPENCL_CHECK_ERROR(err_code);
  }
281
  initialized_ = true;
282 283
}

284
TVM_REGISTER_GLOBAL("device_api.opencl")
285
.set_body([](TVMArgs args, TVMRetValue* rv) {
286
    DeviceAPI* ptr = OpenCLWorkspace::Global().get();
287 288 289
    *rv = static_cast<void*>(ptr);
  });

290 291 292
}  // namespace cl
}  // namespace runtime
}  // namespace tvm