/*!
 *  Copyright (c) 2017 by Contributors
 * \file opencl_module.cc
 */
#include "./opencl_common.h"
#include "./opencl_module.h"

#if TVM_OPENCL_RUNTIME

#include <dmlc/memory_io.h>
#include <tvm/runtime/registry.h>
#include <vector>
#include <string>
#include <unordered_map>
#include "../pack_args.h"
#include "../thread_storage_scope.h"
#include "../meta_data.h"
#include "../file_util.h"

namespace tvm {
namespace runtime {

// Module to support thread-safe multi-device execution.
// The OpenCL runtime is a bit tricky because clSetKernelArg is not thread-safe.
// To make the call thread-safe, we create a thread-local kernel table
// and lazily install new kernels into it when a kernel is called.
// The kernels are recycled when the module gets destructed.
class OpenCLModuleNode : public ModuleNode {
 public:
  // Kernel table reference entry.
  struct KTRefEntry {
    size_t kernel_id;
    size_t version;
  };
  explicit OpenCLModuleNode(std::string data,
                            std::string fmt,
                            std::unordered_map<std::string, FunctionInfo> fmap)
      : data_(data), fmt_(fmt), fmap_(fmap) {}
  // destructor
  ~OpenCLModuleNode() {
    {
      // free the kernel ids in the global table.
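      // Future modules may reuse the freed ids; stale thread-local entries
      // are detected through the version stamp assigned in Init().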
      std::lock_guard<std::mutex> lock(workspace_->mu);
      for (auto& kv : kid_map_) {
        workspace_->free_kernel_ids.push_back(kv.second.kernel_id);
      }
    }
    // free the kernels
    for (cl_kernel k : kernels_) {
      OPENCL_CALL(clReleaseKernel(k));
    }
    if (program_) {
      OPENCL_CALL(clReleaseProgram(program_));
    }
  }

  const char* type_key() const final {
    return "opencl";
  }

  PackedFunc GetFunction(
      const std::string& name,
      const std::shared_ptr<ModuleNode>& sptr_to_self) final;

  void SaveToFile(const std::string& file_name,
                  const std::string& format) final {
    std::string fmt = GetFileFormat(file_name, format);
    CHECK_EQ(fmt, fmt_)
        << "Can only save to format=" << fmt_;
    std::string meta_file = GetMetaFilePath(file_name);
    SaveMetaDataToFile(meta_file, fmap_);
    SaveBinaryToFile(file_name, data_);
  }

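  // NOTE: the write order (fmt, fmap, data) must match the read order
  // in OpenCLModuleLoadBinary below.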
  void SaveToBinary(dmlc::Stream* stream) final {
    stream->Write(fmt_);
    stream->Write(fmap_);
    stream->Write(data_);
  }

  std::string GetSource(const std::string& format) final {
    if (format == fmt_) return data_;
    if (fmt_ == "cl") {
      return data_;
    } else {
      return "";
    }
  }

  // Initialize the OpenCL program and the kernel id table.
  void Init() {
    workspace_ = cl::OpenCLWorkspace::Global();
    workspace_->Init();
    CHECK(workspace_->context != nullptr) << "No OpenCL device";
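    // Only create the program object here; clBuildProgram is deferred to
    // InstallKernel so it runs lazily, once per device.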
    if (fmt_ == "cl") {
      const char* s = data_.c_str();
      size_t len = data_.length();
      cl_int err;
      program_ = clCreateProgramWithSource(
          workspace_->context, 1, &s, &len, &err);
      OPENCL_CHECK_ERROR(err);
    } else {
      LOG(FATAL) << "Unknown OpenCL format " << fmt_;
    }
    device_built_flag_.resize(workspace_->devices.size(), false);
    // initialize the kernel ids; the global table must be locked.
    std::lock_guard<std::mutex> lock(workspace_->mu);
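    // Reuse a freed kernel id when possible, otherwise allocate a new slot.
    // The fresh version timestamp invalidates any stale thread-local entries
    // that still reference a recycled id.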
    for (const auto& kv : fmap_) {
      const std::string& key = kv.first;
      KTRefEntry e;
      if (workspace_->free_kernel_ids.size() != 0) {
        e.kernel_id = workspace_->free_kernel_ids.back();
        workspace_->free_kernel_ids.pop_back();
      } else {
        e.kernel_id = workspace_->num_registered_kernels++;
      }
      e.version = workspace_->timestamp++;
      kid_map_[key] = e;
    }
  }
  // install a new kernel into the thread-local entry
  cl_kernel InstallKernel(cl::OpenCLWorkspace* w,
                          cl::OpenCLThreadEntry* t,
                          const std::string& func_name,
                          const KTRefEntry& e) {
    std::lock_guard<std::mutex> lock(build_lock_);
    int device_id = t->context.device_id;
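    // Build the program for this device the first time one of this module's
    // kernels runs on it; build_lock_ serializes concurrent builds.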
    if (!device_built_flag_[device_id]) {
      // build program
      cl_int err;
      cl_device_id dev = w->devices[device_id];
      err = clBuildProgram(program_, 1, &dev, nullptr, nullptr, nullptr);
      if (err != CL_SUCCESS) {
        size_t len;
        std::string log;
        clGetProgramBuildInfo(
            program_, dev, CL_PROGRAM_BUILD_LOG, 0, nullptr, &len);
        log.resize(len);
        clGetProgramBuildInfo(
            program_, dev, CL_PROGRAM_BUILD_LOG, len, &log[0], nullptr);
        LOG(FATAL) << "OpenCL build error for device=" << dev << "\n" << log;
      }
      device_built_flag_[device_id] = true;
    }
    // build kernel
    cl_int err;
    cl_kernel kernel = clCreateKernel(program_, func_name.c_str(), &err);
    OPENCL_CHECK_ERROR(err);
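    // Cache the kernel in the caller's thread-local table, stamped with the
    // entry version so a recycled slot can be detected later.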
    t->kernel_table[e.kernel_id].kernel = kernel;
    t->kernel_table[e.kernel_id].version = e.version;
    kernels_.push_back(kernel);
    return kernel;
  }

 private:
  // The workspace; keep a reference so it can be used in the destructor,
  // guarding against static destruction order problems.
  std::shared_ptr<cl::OpenCLWorkspace> workspace_;
  // the binary data
  std::string data_;
  // The format
  std::string fmt_;
  // function information table.
  std::unordered_map<std::string, FunctionInfo> fmap_;
  // Module local mutex
  std::mutex build_lock_;
  // internal OpenCL program object
  cl_program program_{nullptr};
  // whether the program has been built for each device
  std::vector<bool> device_built_flag_;
  // kernel id cache
  std::unordered_map<std::string, KTRefEntry> kid_map_;
  // kernels built so far.
  std::vector<cl_kernel> kernels_;
};

class OpenCLWrappedFunc {
 public:
  // initialize the OpenCL function.
  void Init(OpenCLModuleNode* m,
            std::shared_ptr<ModuleNode> sptr,
            OpenCLModuleNode::KTRefEntry entry,
            std::string func_name,
            std::vector<size_t> arg_size,
            const std::vector<std::string>& thread_axis_tags)  {
    w_ = cl::OpenCLWorkspace::Global().get();
    m_ = m;
    sptr_ = sptr;
    entry_ = entry;
    func_name_ = func_name;
    arg_size_ = arg_size;
    thread_axis_cfg_.Init(arg_size.size(), thread_axis_tags);
  }
  // invoke the function with void arguments
  void operator()(TVMArgs args,
                  TVMRetValue* rv,
                  void** void_args) const {
    cl::OpenCLThreadEntry* t = cl::OpenCLThreadEntry::ThreadLocal();
    // get the kernel from thread local kernel table.
    if (entry_.kernel_id >= t->kernel_table.size()) {
      t->kernel_table.resize(entry_.kernel_id + 1);
    }
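    // Reinstall when this thread has never built the kernel, or when the
    // table slot was recycled by another module (version mismatch).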
    const auto& e = t->kernel_table[entry_.kernel_id];
    cl_kernel kernel = e.kernel;
    if (kernel == nullptr || e.version != entry_.version) {
      kernel = m_->InstallKernel(w_, t, func_name_, entry_);
    }
    // setup arguments.
    for (cl_uint i = 0; i < arg_size_.size(); ++i) {
      OPENCL_CALL(clSetKernelArg(kernel, i, arg_size_[i], void_args[i]));
    }
    cl_command_queue queue = w_->GetQueue(t->context);
    ThreadWorkLoad wl = thread_axis_cfg_.Extract(args);
    cl_uint work_dim = static_cast<cl_uint>(thread_axis_cfg_.work_dim());
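    // work_size[0..2] holds the grid dimensions and work_size[3..5] the
    // workgroup dimensions; OpenCL expects the global size as their product.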
    for (cl_uint i = 0; i < work_dim; ++i) {
      wl.work_size[i] *= wl.work_size[i + 3];
    }
    // launch kernel
    OPENCL_CALL(clEnqueueNDRangeKernel(
        queue, kernel, work_dim, nullptr,
        wl.work_size,
        wl.work_size + 3,
        0, nullptr, nullptr));
  }

 private:
  // global workspace.
  cl::OpenCLWorkspace* w_;
  // The module
  OpenCLModuleNode* m_;
  // resource handle
  std::shared_ptr<ModuleNode> sptr_;
  // global kernel id in the kernel table.
  OpenCLModuleNode::KTRefEntry entry_;
  // The name of the function.
  std::string func_name_;
  // the byte size of each argument, used for clSetKernelArg
  std::vector<size_t> arg_size_;
  // thread axis config
  ThreadAxisConfig thread_axis_cfg_;
};

PackedFunc OpenCLModuleNode::GetFunction(
    const std::string& name,
    const std::shared_ptr<ModuleNode>& sptr_to_self) {
  CHECK_EQ(sptr_to_self.get(), this);
  CHECK_NE(name, symbol::tvm_module_main)
      << "Device functions do not have main";
  auto it = fmap_.find(name);
  if (it == fmap_.end()) return PackedFunc();
  const FunctionInfo& info = it->second;
  OpenCLWrappedFunc f;
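  // Record the byte size of each argument; values arrive by address
  // (see PackFuncVoidAddr below) and are forwarded to clSetKernelArg.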
  std::vector<size_t> arg_size(info.arg_types.size());
  for (size_t i = 0; i < info.arg_types.size(); ++i) {
    TVMType t = info.arg_types[i];
    CHECK_EQ(t.lanes, 1U);
    uint32_t bits = t.bits;
    CHECK_EQ(bits % 8, 0U);
    arg_size[i] = bits / 8;
  }
  // initialize the wrapped func.
  f.Init(this, sptr_to_self, kid_map_.at(name),
         name, arg_size, info.thread_axis_tags);
  return PackFuncVoidAddr(f, info.arg_types);
}

Module OpenCLModuleCreate(
    std::string data,
    std::string fmt,
    std::unordered_map<std::string, FunctionInfo> fmap) {
  std::shared_ptr<OpenCLModuleNode> n =
      std::make_shared<OpenCLModuleNode>(data, fmt, fmap);
  n->Init();
  return Module(n);
}

// Load module from file.
Module OpenCLModuleLoadFile(const std::string& file_name,
                            const std::string& format) {
  std::string data;
  std::unordered_map<std::string, FunctionInfo> fmap;
  std::string fmt = GetFileFormat(file_name, format);
  std::string meta_file = GetMetaFilePath(file_name);
  LoadBinaryFromFile(file_name, &data);
  LoadMetaDataFromFile(meta_file, &fmap);
  return OpenCLModuleCreate(data, fmt, fmap);
}

Module OpenCLModuleLoadBinary(void* strm) {
  dmlc::Stream* stream = static_cast<dmlc::Stream*>(strm);
  std::string data;
  std::unordered_map<std::string, FunctionInfo> fmap;
  std::string fmt;
  stream->Read(&fmt);
  stream->Read(&fmap);
  stream->Read(&data);
  return OpenCLModuleCreate(data, fmt, fmap);
}

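// Register loaders so that Module::LoadFromFile (dispatched by file format)
// and module deserialization (dispatched by type key) can find the
// OpenCL module constructors.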
TVM_REGISTER_GLOBAL("module.loadfile_cl")
.set_body([](TVMArgs args, TVMRetValue* rv) {
    *rv = OpenCLModuleLoadFile(args[0], args[1]);
  });

TVM_REGISTER_GLOBAL("module.loadfile_clbin")
.set_body([](TVMArgs args, TVMRetValue* rv) {
    *rv = OpenCLModuleLoadFile(args[0], args[1]);
  });

TVM_REGISTER_GLOBAL("module.loadbinary_opencl")
.set_body([](TVMArgs args, TVMRetValue* rv) {
    *rv = OpenCLModuleLoadBinary(args[0]);
  });
}  // namespace runtime
}  // namespace tvm

#endif  // TVM_OPENCL_RUNTIME