opencl_device_api.cc 10.3 KB
Newer Older
1 2 3 4 5 6 7 8
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

20
/*!
 * \file opencl_device_api.cc
 */
23
#include <tvm/runtime/registry.h>
24
#include <dmlc/thread_local.h>
25
#include "opencl_common.h"
26 27 28 29 30

namespace tvm {
namespace runtime {
namespace cl {

31 32 33 34
/*!
 * \brief Fetch the OpenCL thread entry for the calling thread.
 * \return The pointer handed back by OpenCLThreadEntry::ThreadLocal().
 */
OpenCLThreadEntry* OpenCLWorkspace::GetThreadEntry() {
  OpenCLThreadEntry* entry = OpenCLThreadEntry::ThreadLocal();
  return entry;
}

35 36 37
/*!
 * \brief Access the shared OpenCLWorkspace instance.
 * \return Reference to the function-local static shared_ptr holding the workspace.
 */
const std::shared_ptr<OpenCLWorkspace>& OpenCLWorkspace::Global() {
  // Function-local static: created on the first call, reused by every later call.
  static std::shared_ptr<OpenCLWorkspace> workspace{std::make_shared<OpenCLWorkspace>()};
  return workspace;
}

40
/*!
 * \brief Record the device id of ctx in the calling thread's entry.
 * \param ctx Context whose device_id is stored.
 */
void OpenCLWorkspace::SetDevice(TVMContext ctx) {
  OpenCLThreadEntry* entry = GetThreadEntry();
  entry->context.device_id = ctx.device_id;
}

void OpenCLWorkspace::GetAttr(
45
    TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) {
46
  this->Init();
47
  size_t index = static_cast<size_t>(ctx.device_id);
48 49 50 51 52 53 54
  if (kind == kExist) {
    *rv = static_cast<int>(index< devices.size());
    return;
  }
  CHECK_LT(index, devices.size())
      << "Invalid device id " << index;
  switch (kind) {
55
    case kExist: break;
56
    case kMaxThreadsPerBlock: {
57
      size_t value;
58 59 60 61 62 63 64
      OPENCL_CALL(clGetDeviceInfo(
          devices[index],  CL_DEVICE_MAX_WORK_GROUP_SIZE,
          sizeof(size_t), &value, nullptr));
      *rv = static_cast<int64_t>(value);
      break;
    }
    case kWarpSize: {
65
      /* TODO: the warp size of OpenCL device is not always 1
66
               e.g. Intel Graphics has a sub group concept which contains 8 - 32 work items,
67 68 69
               corresponding to the number of SIMD entries the heardware configures.
               We need to figure out a way to query this information from the hardware.
      */
70 71 72
      *rv = 1;
      break;
    }
73 74 75 76 77 78 79 80 81
    case kMaxSharedMemoryPerBlock: {
      cl_ulong value;
      OPENCL_CALL(clGetDeviceInfo(
          devices[index], CL_DEVICE_LOCAL_MEM_SIZE,
          sizeof(cl_ulong), &value, nullptr));
      *rv = static_cast<int64_t>(value);
      break;
    }
    case kComputeVersion: return;
82 83 84 85 86 87 88 89
    case kDeviceName: {
      char value[128] = {0};
      OPENCL_CALL(clGetDeviceInfo(
          devices[index], CL_DEVICE_NAME,
          sizeof(value) - 1, value, nullptr));
      *rv = std::string(value);
      break;
    }
90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105
    case kMaxClockRate: {
      cl_uint value;
      OPENCL_CALL(clGetDeviceInfo(
          devices[index], CL_DEVICE_MAX_CLOCK_FREQUENCY,
          sizeof(cl_uint), &value, nullptr));
      *rv = static_cast<int32_t>(value);
      break;
    }
    case kMultiProcessorCount: {
      cl_uint value;
      OPENCL_CALL(clGetDeviceInfo(
          devices[index], CL_DEVICE_MAX_COMPUTE_UNITS,
          sizeof(cl_uint), &value, nullptr));
      *rv = static_cast<int32_t>(value);
      break;
    }
106 107 108 109 110 111 112 113 114 115
    case kMaxThreadDimensions: {
      size_t dims[3];
      OPENCL_CALL(clGetDeviceInfo(
          devices[index], CL_DEVICE_MAX_WORK_ITEM_SIZES, sizeof(dims), dims, nullptr));

      std::stringstream ss;  // use json string to return multiple int values;
      ss << "[" << dims[0] <<", " << dims[1] << ", " << dims[2] << "]";
      *rv = ss.str();
      break;
    }
116
    case kGcnArch: return;
117 118 119
  }
}

120
void* OpenCLWorkspace::AllocDataSpace(
121
    TVMContext ctx, size_t size, size_t alignment, DLDataType type_hint) {
122 123
  this->Init();
  CHECK(context != nullptr) << "No OpenCL device";
124 125 126 127 128 129 130 131
  cl_int err_code;
  cl_mem mptr = clCreateBuffer(
      this->context, CL_MEM_READ_WRITE, size, nullptr, &err_code);
  OPENCL_CHECK_ERROR(err_code);
  return mptr;
}

/*!
 * \brief Release an OpenCL memory object.
 * \param ctx Context used to look up the command queue that is drained first.
 * \param ptr The cl_mem handle to release, passed as a void pointer.
 */
void OpenCLWorkspace::FreeDataSpace(TVMContext ctx, void* ptr) {
  // Drain the queue before releasing: on some OpenCL platforms the memory
  // object must no longer be in the command queue when it is released.
  auto queue = this->GetQueue(ctx);
  OPENCL_CALL(clFinish(queue));

  OPENCL_CALL(clReleaseMemObject(static_cast<cl_mem>(ptr)));
}

void OpenCLWorkspace::CopyDataFromTo(const void* from,
141
                                     size_t from_offset,
142
                                     void* to,
143
                                     size_t to_offset,
144 145 146
                                     size_t size,
                                     TVMContext ctx_from,
                                     TVMContext ctx_to,
147
                                     DLDataType type_hint,
148
                                     TVMStreamHandle stream) {
149
  this->Init();
150
  CHECK(stream == nullptr);
151
  if (IsOpenCLDevice(ctx_from) && IsOpenCLDevice(ctx_to)) {
152 153 154 155
    OPENCL_CALL(clEnqueueCopyBuffer(
        this->GetQueue(ctx_to),
        static_cast<cl_mem>((void*)from),  // NOLINT(*)
        static_cast<cl_mem>(to),
156
        from_offset, to_offset, size, 0, nullptr, nullptr));
157
  } else if (IsOpenCLDevice(ctx_from) && ctx_to.device_type == kDLCPU) {
158 159 160
    OPENCL_CALL(clEnqueueReadBuffer(
        this->GetQueue(ctx_from),
        static_cast<cl_mem>((void*)from),  // NOLINT(*)
161 162
        CL_FALSE, from_offset, size,
        static_cast<char*>(to) + to_offset,
163 164
        0, nullptr, nullptr));
    OPENCL_CALL(clFinish(this->GetQueue(ctx_from)));
165
  } else if (ctx_from.device_type == kDLCPU && IsOpenCLDevice(ctx_to)) {
166 167 168
    OPENCL_CALL(clEnqueueWriteBuffer(
        this->GetQueue(ctx_to),
        static_cast<cl_mem>(to),
169 170
        CL_FALSE, to_offset, size,
        static_cast<const char*>(from) + from_offset,
171 172 173 174 175 176 177 178 179 180 181 182
        0, nullptr, nullptr));
    OPENCL_CALL(clFinish(this->GetQueue(ctx_to)));
  } else {
    LOG(FATAL) << "Expect copy from/to OpenCL or between OpenCL";
  }
}

/*!
 * \brief Block until the command queue for ctx has finished.
 * \param ctx    Context whose command queue is waited on.
 * \param stream Must be nullptr.
 */
void OpenCLWorkspace::StreamSync(TVMContext ctx, TVMStreamHandle stream) {
  CHECK(stream == nullptr);
  auto queue = this->GetQueue(ctx);
  OPENCL_CALL(clFinish(queue));
}

183 184
void* OpenCLWorkspace::AllocWorkspace(TVMContext ctx,
                                      size_t size,
185
                                      DLDataType type_hint) {
186
  return GetThreadEntry()->pool.AllocWorkspace(ctx, size);
187 188 189
}

/*!
 * \brief Hand workspace back to the calling thread's pool.
 * \param ctx  Context forwarded to the pool.
 * \param data Pointer forwarded to the pool's FreeWorkspace.
 */
void OpenCLWorkspace::FreeWorkspace(TVMContext ctx, void* data) {
  OpenCLThreadEntry* entry = GetThreadEntry();
  entry->pool.FreeWorkspace(ctx, data);
}

193 194 195 196 197 198
/*! \brief dmlc thread-local store holding one OpenCLThreadEntry. */
using OpenCLThreadStore = dmlc::ThreadLocalStore<OpenCLThreadEntry>;

/*!
 * \brief Obtain the thread entry from the thread-local store.
 * \return The pointer returned by OpenCLThreadStore::Get().
 */
OpenCLThreadEntry* OpenCLThreadEntry::ThreadLocal() {
  OpenCLThreadEntry* entry = OpenCLThreadStore::Get();
  return entry;
}

199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220
/*!
 * \brief Read one platform info value as a string.
 *
 * \param pid        Platform to query.
 * \param param_name Which platform info value to fetch.
 * \return A string whose length equals the byte count reported by the driver.
 *         NOTE(review): for text values that count presumably includes the
 *         trailing NUL — confirm before comparing with ==.
 */
std::string GetPlatformInfo(
    cl_platform_id pid, cl_platform_info param_name) {
  // First call: ask only for the size of the value in bytes.
  size_t nbytes;
  OPENCL_CALL(clGetPlatformInfo(pid, param_name, 0, nullptr, &nbytes));
  // Second call: fetch the value into a buffer of exactly that size.
  std::string info(nbytes, '\0');
  OPENCL_CALL(clGetPlatformInfo(pid, param_name, nbytes, &info[0], nullptr));
  return info;
}

/*!
 * \brief Read one device info value as a string.
 *
 * \param pid        Device to query.
 * \param param_name Which device info value to fetch.
 * \return A string whose length equals the byte count reported by the driver.
 *         NOTE(review): for text values that count presumably includes the
 *         trailing NUL — confirm before comparing with ==.
 */
std::string GetDeviceInfo(
    cl_device_id pid, cl_device_info param_name) {
  // First call: ask only for the size of the value in bytes.
  size_t nbytes;
  OPENCL_CALL(clGetDeviceInfo(pid, param_name, 0, nullptr, &nbytes));
  // Second call: fetch the value into a buffer of exactly that size.
  std::string info(nbytes, '\0');
  OPENCL_CALL(clGetDeviceInfo(pid, param_name, nbytes, &info[0], nullptr));
  return info;
}

std::vector<cl_platform_id> GetPlatformIDs() {
  cl_uint ret_size;
221
  cl_int code = clGetPlatformIDs(0, nullptr, &ret_size);
222
  std::vector<cl_platform_id> ret;
223
  if (code != CL_SUCCESS) return ret;
224 225 226 227 228 229 230 231 232
  ret.resize(ret_size);
  OPENCL_CALL(clGetPlatformIDs(ret_size, &ret[0], nullptr));
  return ret;
}

/*!
 * \brief List the devices of a platform, optionally filtered by type.
 *
 * \param pid         Platform to enumerate.
 * \param device_type "cpu", "gpu" or "accelerator" select that device type;
 *                    any other value selects all device types.
 * \return The matching device ids, or an empty vector when the count query
 *         fails or reports zero devices.
 */
std::vector<cl_device_id> GetDeviceIDs(
    cl_platform_id pid, std::string device_type) {
  cl_device_type dtype = CL_DEVICE_TYPE_ALL;
  if (device_type == "cpu") {
    dtype = CL_DEVICE_TYPE_CPU;
  } else if (device_type == "gpu") {
    dtype = CL_DEVICE_TYPE_GPU;
  } else if (device_type == "accelerator") {
    dtype = CL_DEVICE_TYPE_ACCELERATOR;
  }
  std::vector<cl_device_id> ret;
  cl_uint ret_size = 0;
  cl_int code = clGetDeviceIDs(pid, dtype, 0, nullptr, &ret_size);
  // Also stop on a zero count so the second call never receives
  // num_entries == 0 together with a pointer into an empty vector.
  if (code != CL_SUCCESS || ret_size == 0) return ret;
  ret.resize(ret_size);
  OPENCL_CALL(clGetDeviceIDs(pid, dtype, ret_size, ret.data(), nullptr));
  return ret;
}

/*!
 * \brief Test whether a platform info value contains a given substring.
 *
 * \param pid        Platform to query.
 * \param param_name Which platform info value to inspect.
 * \param value      Substring to look for; an empty string matches every platform.
 * \return true when value is empty or occurs within the queried info string.
 */
bool MatchPlatformInfo(
    cl_platform_id pid,
    cl_platform_info param_name,
    std::string value) {
  // An empty filter matches without querying the platform.
  if (value.empty()) {
    return true;
  }
  const std::string actual = GetPlatformInfo(pid, param_name);
  const bool found = actual.find(value) != std::string::npos;
  return found;
}

253 254
void OpenCLWorkspace::Init(const std::string& type_key, const std::string& device_type,
                           const std::string& platform_name) {
255
  if (initialized_) return;
256
  std::lock_guard<std::mutex> lock(this->mu);
257 258
  if (initialized_) return;
  if (context != nullptr) return;
259
  this->type_key = type_key;
260
  // matched platforms
261 262
  std::vector<cl_platform_id> platform_ids = cl::GetPlatformIDs();
  if (platform_ids.size() == 0) {
263 264
    LOG(WARNING) << "No OpenCL platform matched given existing options ...";
    return;
265
  }
266 267 268 269 270 271
  this->platform_id = nullptr;
  for (auto platform_id : platform_ids) {
    if (!MatchPlatformInfo(platform_id, CL_PLATFORM_NAME, platform_name)) {
      continue;
    }
    std::vector<cl_device_id> devices_matched = cl::GetDeviceIDs(platform_id, device_type);
272 273 274 275
    if ((devices_matched.size() == 0) && (device_type == "gpu")) {
      LOG(WARNING) << "Using CPU OpenCL device";
      devices_matched = cl::GetDeviceIDs(platform_id, "cpu");
    }
276
    if (devices_matched.size() > 0) {
277 278 279 280
      this->platform_id = platform_id;
      this->platform_name = cl::GetPlatformInfo(platform_id, CL_PLATFORM_NAME);
      this->device_type = device_type;
      this->devices = devices_matched;
281
      break;
282
    }
283
  }
284
  if (this->platform_id == nullptr) {
285 286
    LOG(WARNING) << "No OpenCL device";
    return;
287
  }
288
  cl_int err_code;
289 290
  this->context = clCreateContext(
      nullptr, this->devices.size(), &(this->devices[0]),
291 292
      nullptr, nullptr, &err_code);
  OPENCL_CHECK_ERROR(err_code);
293 294 295 296 297
  CHECK_EQ(this->queues.size(), 0U);
  for (size_t i = 0; i < this->devices.size(); ++i) {
    cl_device_id did = this->devices[i];
    this->queues.push_back(
        clCreateCommandQueue(this->context, did, 0, &err_code));
298 299
    OPENCL_CHECK_ERROR(err_code);
  }
300
  initialized_ = true;
301 302
}

303
// Registers "device_api.opencl": the body returns the global OpenCLWorkspace
// as an opaque pointer to its DeviceAPI interface.
TVM_REGISTER_GLOBAL("device_api.opencl")
.set_body([](TVMArgs args, TVMRetValue* rv) {
    // Convert to DeviceAPI* before erasing the type to void*.
    DeviceAPI* api = OpenCLWorkspace::Global().get();
    *rv = static_cast<void*>(api);
  });

309 310 311
}  // namespace cl
}  // namespace runtime
}  // namespace tvm