/*!
 *  Copyright (c) 2017 by Contributors
 * \file cuda_module.cc
 */
#include "./cuda_module.h"

#if TVM_CUDA_RUNTIME

#include <tvm/runtime/registry.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <vector>
#include <array>
#include <string>
#include <unordered_map>
#include <mutex>
#include "./cuda_common.h"
#include "../pack_args.h"
#include "../thread_storage_scope.h"
#include "../meta_data.h"
#include "../file_util.h"

namespace tvm {
namespace runtime {

// Module to support thread-safe multi-GPU execution.
// cuModule is a per-GPU module; the runtime keeps a per-device module
// table, and each module is loaded lazily on first use.
class CUDAModuleNode : public runtime::ModuleNode {
 public:
  explicit CUDAModuleNode(std::string data,
                          std::string fmt,
                          std::unordered_map<std::string, FunctionInfo> fmap,
                          std::string cuda_source)
      : data_(data), fmt_(fmt), fmap_(fmap), cuda_source_(cuda_source) {
    std::fill(module_.begin(), module_.end(), nullptr);
  }
  // Destructor: unload every lazily loaded per-device module. The device is
  // made current first so the unload targets that device's primary context.
  ~CUDAModuleNode() {
    for (size_t i = 0; i < module_.size(); ++i) {
      if (module_[i] != nullptr) {
        CUDA_CALL(cudaSetDevice(static_cast<int>(i)));
        CUDA_DRIVER_CALL(cuModuleUnload(module_[i]));
      }
    }
  }

  const char* type_key() const final {
    return "cuda";
  }

  PackedFunc GetFunction(
      const std::string& name,
      const std::shared_ptr<ModuleNode>& sptr_to_self) final;

  void SaveToFile(const std::string& file_name,
                  const std::string& format) final {
    std::string fmt = GetFileFormat(file_name, format);
    std::string meta_file = GetMetaFilePath(file_name);
    if (fmt == "cu") {
      CHECK_NE(cuda_source_.length(), 0);
      SaveMetaDataToFile(meta_file, fmap_);
      SaveBinaryToFile(file_name, cuda_source_);
    } else {
      CHECK_EQ(fmt, fmt_)
          << "Can only save to format=" << fmt_;
      SaveMetaDataToFile(meta_file, fmap_);
      SaveBinaryToFile(file_name, data_);
    }
  }
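  // A minimal usage sketch (hypothetical file name, not part of this module):
  // saving as "cu" requires the CUDA source to be available, while any other
  // format must match fmt_; the function metadata always goes to a companion
  // meta file derived from the artifact path, e.g.
  //
  //   mod->SaveToFile("kernels.ptx", "ptx");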

  void SaveToBinary(dmlc::Stream* stream) final {
    stream->Write(fmt_);
    stream->Write(fmap_);
    stream->Write(data_);
  }

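  // Return a human-readable source for this module: the stored data when the
  // requested format matches fmt_, else the saved CUDA source, else PTX
  // (which is itself text) as a last resort.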
  std::string GetSource(const std::string& format) final {
    if (format == fmt_) return data_;
    if (cuda_source_.length() != 0) {
      return cuda_source_;
    } else {
      if (fmt_ == "ptx") return data_;
      return "";
    }
  }

  // Get a CUfunction from the primary context on device_id.
  CUfunction GetFunc(int device_id, const std::string& func_name) {
    std::lock_guard<std::mutex> lock(mutex_);
    // must recheck under the lock scope
    if (module_[device_id] == nullptr) {
      CUDA_DRIVER_CALL(cuModuleLoadData(&(module_[device_id]), data_.c_str()));
    }
    CUfunction func;
    CUresult result = cuModuleGetFunction(&func, module_[device_id], func_name.c_str());
    if (result != CUDA_SUCCESS) {
      const char *msg;
      cuGetErrorName(result, &msg);
      LOG(FATAL)
          << "CUDAError: cuModuleGetFunction " << func_name
          << " failed with error: " << msg;
    }
    return func;
  }
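  // Note: CUDAWrappedFunc below caches the CUfunction returned here in its
  // per-device fcache_, so this lock is only taken on a function's first
  // call on each device.
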
  // Get a global variable from the primary context on device_id.
  CUdeviceptr GetGlobal(int device_id,
                        const std::string& global_name,
                        size_t expect_nbytes) {
    std::lock_guard<std::mutex> lock(mutex_);
    // must recheck under the lock scope
    if (module_[device_id] == nullptr) {
      CUDA_DRIVER_CALL(cuModuleLoadData(&(module_[device_id]), data_.c_str()));
    }
    CUdeviceptr global;
    size_t nbytes;

    CUresult result = cuModuleGetGlobal(&global, &nbytes,
                                        module_[device_id], global_name.c_str());
    if (result != CUDA_SUCCESS) {
      const char *msg;
      cuGetErrorName(result, &msg);
      LOG(FATAL)
          << "CUDAError: cuModuleGetGlobal " << global_name
          << " failed with error: " << msg;
    }
    // Validate the size only after the lookup succeeded; nbytes is
    // unspecified when cuModuleGetGlobal fails.
    CHECK_EQ(nbytes, expect_nbytes);
    return global;
  }
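  // GetGlobal is used by CUDAPrepGlobalBarrier below to locate the
  // tvm_global_barrier_state symbol in the loaded module.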

 private:
  // the binary data
  std::string data_;
  // The format
  std::string fmt_;
  // function information table.
  std::unordered_map<std::string, FunctionInfo> fmap_;
  // The cuda source.
  std::string cuda_source_;
  // the internal modules per GPU, to be lazily initialized.
  std::array<CUmodule, kMaxNumGPUs> module_;
  // internal mutex when updating the module
  std::mutex mutex_;
};

// A wrapped function class used to build a packed function.
class CUDAWrappedFunc {
 public:
  // initialize the CUDA function.
  void Init(CUDAModuleNode* m,
            std::shared_ptr<ModuleNode> sptr,
            const std::string& func_name,
            size_t num_void_args,
            const std::vector<std::string>& thread_axis_tags) {
    m_ = m;
    sptr_ = sptr;
    func_name_ = func_name;
    std::fill(fcache_.begin(), fcache_.end(), nullptr);
    thread_axis_cfg_.Init(num_void_args, thread_axis_tags);
  }
  // invoke the function with void arguments
  void operator()(TVMArgs args,
                  TVMRetValue* rv,
                  void** void_args) const {
    int device_id;
    CUDA_CALL(cudaGetDevice(&device_id));
    if (fcache_[device_id] == nullptr) {
      fcache_[device_id] = m_->GetFunc(device_id, func_name_);
    }
    CUstream strm = static_cast<CUstream>(CUDAThreadEntry::ThreadLocal()->stream);
    ThreadWorkLoad wl = thread_axis_cfg_.Extract(args);
    CUresult result = cuLaunchKernel(
        fcache_[device_id],
        wl.grid_dim(0),
        wl.grid_dim(1),
        wl.grid_dim(2),
        wl.block_dim(0),
        wl.block_dim(1),
        wl.block_dim(2),
        0, strm, void_args, nullptr);
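    // CUDA_ERROR_DEINITIALIZED is tolerated below because it can surface
    // while the driver is shutting down at process exit; it is not treated
    // as a real launch failure.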
    if (result != CUDA_SUCCESS && result != CUDA_ERROR_DEINITIALIZED) {
      const char *msg;
      cuGetErrorName(result, &msg);
      std::ostringstream os;
      os << "CUDALaunch Error: " << msg << "\n"
         << " grid=(" << wl.grid_dim(0) << ","
         << wl.grid_dim(1) << "," << wl.grid_dim(2) << "), "
         << " block=(" << wl.block_dim(0) << ","
         << wl.block_dim(1) << "," << wl.block_dim(2) << ")\n";
      std::string cuda = m_->GetSource("");
      if (cuda.length() != 0) {
        os << "// func_name=" << func_name_ << "\n"
           << "// CUDA Source\n"
           << "// -----------\n"
           << cuda;
      }
      LOG(FATAL) << os.str();
    }
  }

 private:
  // internal module
  CUDAModuleNode* m_;
  // the resource holder
  std::shared_ptr<ModuleNode> sptr_;
  // The name of the function.
  std::string func_name_;
  // Device function cache per device.
  // mark as mutable, to enable lazy initialization
  mutable std::array<CUfunction, kMaxNumGPUs> fcache_;
  // thread axis configuration
  ThreadAxisConfig thread_axis_cfg_;
};
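// Call-path sketch: each invocation reads the current device, resolves the
// cached CUfunction (loading the module lazily on first use), extracts the
// launch dimensions from the packed arguments via ThreadAxisConfig, and
// launches the kernel on the calling thread's CUDA stream.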

class CUDAPrepGlobalBarrier {
 public:
  CUDAPrepGlobalBarrier(CUDAModuleNode* m,
                        std::shared_ptr<ModuleNode> sptr)
      : m_(m), sptr_(sptr) {
    std::fill(pcache_.begin(), pcache_.end(), 0);
  }

  void operator()(const TVMArgs& args, TVMRetValue* rv) const {
    int device_id;
    CUDA_CALL(cudaGetDevice(&device_id));
    if (pcache_[device_id] == 0) {
      pcache_[device_id] = m_->GetGlobal(
          device_id, runtime::symbol::tvm_global_barrier_state, sizeof(unsigned));
    }
    CUDA_DRIVER_CALL(cuMemsetD32(pcache_[device_id], 0, 1));
  }

 private:
  // internal module
  CUDAModuleNode* m_;
  // the resource holder
  std::shared_ptr<ModuleNode> sptr_;
  // mark as mutable, to enable lazy initialization
  mutable std::array<CUdeviceptr, kMaxNumGPUs> pcache_;
};
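// The barrier-preparation functor above is returned by GetFunction when the
// special name tvm_prepare_global_barrier is requested; running it zeroes
// the barrier state before kernels that rely on device-wide synchronization.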

PackedFunc CUDAModuleNode::GetFunction(
      const std::string& name,
      const std::shared_ptr<ModuleNode>& sptr_to_self) {
  CHECK_EQ(sptr_to_self.get(), this);
  CHECK_NE(name, symbol::tvm_module_main)
      << "Device functions do not have main";
  if (name == symbol::tvm_prepare_global_barrier) {
    return PackedFunc(CUDAPrepGlobalBarrier(this, sptr_to_self));
  }
  auto it = fmap_.find(name);
  if (it == fmap_.end()) return PackedFunc();
  const FunctionInfo& info = it->second;
  CUDAWrappedFunc f;
  f.Init(this, sptr_to_self, name, info.arg_types.size(), info.thread_axis_tags);
  return PackFuncVoidAddr(f, info.arg_types);
}

Module CUDAModuleCreate(
    std::string data,
    std::string fmt,
    std::unordered_map<std::string, FunctionInfo> fmap,
    std::string cuda_source) {
  std::shared_ptr<CUDAModuleNode> n =
      std::make_shared<CUDAModuleNode>(data, fmt, fmap, cuda_source);
  return Module(n);
}
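
// A minimal usage sketch (hypothetical names; assumes the Module front end
// resolves functions by name):
//
//   Module m = CUDAModuleCreate(ptx_text, "ptx", fmap, cuda_text);
//   PackedFunc f = m.GetFunction("my_kernel");
//   // calling f(...) launches my_kernel on the current CUDA device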

// Load module from file.
Module CUDAModuleLoadFile(const std::string& file_name,
                          const std::string& format) {
  std::string data;
  std::unordered_map<std::string, FunctionInfo> fmap;
  std::string fmt = GetFileFormat(file_name, format);
  std::string meta_file = GetMetaFilePath(file_name);
  LoadBinaryFromFile(file_name, &data);
  LoadMetaDataFromFile(meta_file, &fmap);
  return CUDAModuleCreate(data, fmt, fmap, std::string());
}

Module CUDAModuleLoadBinary(void* strm) {
  dmlc::Stream* stream = static_cast<dmlc::Stream*>(strm);
  std::string data;
  std::unordered_map<std::string, FunctionInfo> fmap;
  std::string fmt;
  stream->Read(&fmt);
  stream->Read(&fmap);
  stream->Read(&data);
  return CUDAModuleCreate(data, fmt, fmap, std::string());
}
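// Note that the read order above (fmt, fmap, data) mirrors the write order
// in CUDAModuleNode::SaveToBinary, keeping the two paths in sync.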

TVM_REGISTER_GLOBAL("module.loadfile_cubin")
.set_body([](TVMArgs args, TVMRetValue* rv) {
    *rv = CUDAModuleLoadFile(args[0], args[1]);
  });

TVM_REGISTER_GLOBAL("module.loadfile_ptx")
.set_body([](TVMArgs args, TVMRetValue* rv) {
    *rv = CUDAModuleLoadFile(args[0], args[1]);
  });

TVM_REGISTER_GLOBAL("module.loadbinary_cuda")
.set_body([](TVMArgs args, TVMRetValue* rv) {
    *rv = CUDAModuleLoadBinary(args[0]);
  });
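
// A sketch of how these registrations are reached from generic loading code
// (hypothetical file path; Registry::Get returns nullptr for unknown names):
//
//   const PackedFunc* load = tvm::runtime::Registry::Get("module.loadfile_ptx");
//   Module m = (*load)("kernels.ptx", "ptx");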
}  // namespace runtime
}  // namespace tvm
#endif  // TVM_CUDA_RUNTIME