rocm_module.cc 7.08 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32
/*!
 *  Copyright (c) 2017 by Contributors
 * \file rocm_module.cc
 */
#include "./rocm_module.h"

#if TVM_ROCM_RUNTIME

#include <tvm/runtime/registry.h>
#include <hip/hip_runtime_api.h>
#include <algorithm>
#include <array>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "./rocm_common.h"
#include "../pack_args.h"
#include "../thread_storage_scope.h"
#include "../meta_data.h"
#include "../file_util.h"

namespace tvm {
namespace runtime {

// Module to support thread-safe multi-GPU execution.
// hipModule_t is a per-GPU module
// The runtime will contain a per-device module table
// The modules will be lazily loaded
class ROCMModuleNode : public runtime::ModuleNode {
 public:
  explicit ROCMModuleNode(std::string data,
                          std::string fmt,
                          std::unordered_map<std::string, FunctionInfo> fmap,
33 34 35
                          std::string hip_source,
                          std::string assembly)
    : data_(data), fmt_(fmt), fmap_(fmap), hip_source_(hip_source), assembly_(assembly) {
36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62
    std::fill(module_.begin(), module_.end(), nullptr);
  }
  // destructor
  ~ROCMModuleNode() {
    for (size_t i = 0; i < module_.size(); ++i) {
      if (module_[i] != nullptr) {
        ROCM_CALL(hipSetDevice(static_cast<int>(i)));
        ROCM_DRIVER_CALL(hipModuleUnload(module_[i]));
      }
    }
  }

  const char* type_key() const final {
    return "hip";
  }

  PackedFunc GetFunction(
      const std::string& name,
      const std::shared_ptr<ModuleNode>& sptr_to_self) final;


  void SaveToBinary(dmlc::Stream* stream) final {
    stream->Write(fmt_);
    stream->Write(fmap_);
    stream->Write(data_);
  }

63 64
  std::string GetSource(const std::string& format) final {
    if (format == fmt_) { return data_; }
65 66
    if (format == "llvm") { return hip_source_; }
    if (format == "asm") { return assembly_; }
67 68 69
    return "";
  }

70 71 72 73
  // get a CUfunction from primary context in device_id
  hipFunction_t GetFunc(int device_id, const std::string& func_name) {
    std::lock_guard<std::mutex> lock(mutex_);
    // must recheck under the lock scope
74

75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120
    if (module_[device_id] == nullptr) {
      ROCM_DRIVER_CALL(hipModuleLoadData(&(module_[device_id]), data_.c_str()));
    }
    hipFunction_t func;
    hipError_t result = hipModuleGetFunction(&func, module_[device_id], func_name.c_str());
    if (result != hipSuccess) {
      LOG(FATAL)
          << "ROCMError: hipModuleGetFunction " << func_name
          << " failed with error: " << hipGetErrorString(result);
    }
    return func;
  }
  // get a global var from primary context in device_id
  hipDeviceptr_t GetGlobal(int device_id,
                        const std::string& global_name,
                        size_t expect_nbytes) {
    std::lock_guard<std::mutex> lock(mutex_);
    // must recheck under the lock scope
    if (module_[device_id] == nullptr) {
      ROCM_DRIVER_CALL(hipModuleLoadData(&(module_[device_id]), data_.c_str()));
    }
    hipDeviceptr_t global = nullptr;
    size_t nbytes = 0;

    hipError_t result = hipSuccess;
    // ROCM doesn't support hipModuleGetGlobal yet.
    // hipError_t result = hipModuleGetGlobal(&global, &nbytes,
    //                                    module_[device_id], global_name.c_str());
    CHECK_EQ(nbytes, expect_nbytes);
    if (result != hipSuccess) {
      LOG(FATAL)
          << "ROCMError: hipModuleGetGlobal " << global_name
          << " failed with error: " << hipGetErrorString(result);
    }
    return global;
  }

 private:
  // the binary data
  std::string data_;
  // The format
  std::string fmt_;
  // function information table.
  std::unordered_map<std::string, FunctionInfo> fmap_;
  // The hip source.
  std::string hip_source_;
121 122
  // The gcn asm.
  std::string assembly_;
123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146
  // the internal modules per GPU, to be lazily initialized.
  std::array<hipModule_t, kMaxNumGPUs> module_;
  // internal mutex when updating the module
  std::mutex mutex_;
};

// a wrapped function class to get packed fucn.
class ROCMWrappedFunc {
 public:
  // initialize the ROCM function.
  void Init(ROCMModuleNode* m,
            std::shared_ptr<ModuleNode> sptr,
            const std::string& func_name,
            size_t num_void_args,
            const std::vector<std::string>& thread_axis_tags) {
    m_ = m;
    sptr_ = sptr;
    func_name_ = func_name;
    std::fill(fcache_.begin(), fcache_.end(), nullptr);
    thread_axis_cfg_.Init(num_void_args, thread_axis_tags);
  }
  // invoke the function with void arguments
  void operator()(TVMArgs args,
                  TVMRetValue* rv,
147 148
                  void* packed_args,
                  size_t packed_nbytes) const {
149 150 151 152 153
    int device_id;
    ROCM_CALL(hipGetDevice(&device_id));
    if (fcache_[device_id] == nullptr) {
      fcache_[device_id] = m_->GetFunc(device_id, func_name_);
    }
154

155
    hipStream_t strm = static_cast<hipStream_t>(ROCMThreadEntry::ThreadLocal()->stream);
156

157
    ThreadWorkLoad wl = thread_axis_cfg_.Extract(args);
158
    void* config[] = {
159
      HIP_LAUNCH_PARAM_BUFFER_POINTER, packed_args,
160 161 162
      HIP_LAUNCH_PARAM_BUFFER_SIZE, &packed_nbytes,
      HIP_LAUNCH_PARAM_END
    };
163 164 165 166 167 168 169 170 171
    // HIP supports only extra_args.
    ROCM_DRIVER_CALL(hipModuleLaunchKernel(
        fcache_[device_id],
        wl.grid_dim(0),
        wl.grid_dim(1),
        wl.grid_dim(2),
        wl.block_dim(0),
        wl.block_dim(1),
        wl.block_dim(2),
172 173
        0, strm, nullptr,
        reinterpret_cast<void**>(&config)));
174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201
  }

 private:
  // internal module
  ROCMModuleNode* m_;
  // the resource holder
  std::shared_ptr<ModuleNode> sptr_;
  // The name of the function.
  std::string func_name_;
  // Device function cache per device.
  // mark as mutable, to enable lazy initialization
  mutable std::array<hipFunction_t, kMaxNumGPUs> fcache_;
  // thread axis configuration
  ThreadAxisConfig thread_axis_cfg_;
};


// Build a PackedFunc that launches the kernel `name` from this module.
// sptr_to_self must own this node; it is captured so the module outlives the
// returned function. Returns an empty PackedFunc when `name` is not in fmap_.
PackedFunc ROCMModuleNode::GetFunction(
      const std::string& name,
      const std::shared_ptr<ModuleNode>& sptr_to_self) {
  CHECK_EQ(sptr_to_self.get(), this);
  CHECK_NE(name, symbol::tvm_module_main)
      << "Device function do not have main";
  auto it = fmap_.find(name);
  if (it == fmap_.end()) return PackedFunc();
  const FunctionInfo& info = it->second;
  ROCMWrappedFunc f;
  f.Init(this, sptr_to_self, name, info.arg_types.size(), info.thread_axis_tags);
  return PackFuncPackedArg(f, info.arg_types);
}

// Create a Module backed by a new ROCMModuleNode.
// All arguments are taken by value and moved into the node, avoiding copies
// of the (potentially large) binary and source strings.
Module ROCMModuleCreate(
    std::string data,
    std::string fmt,
    std::unordered_map<std::string, FunctionInfo> fmap,
    std::string hip_source,
    std::string assembly) {
  std::shared_ptr<ROCMModuleNode> n =
    std::make_shared<ROCMModuleNode>(std::move(data), std::move(fmt), std::move(fmap),
                                     std::move(hip_source), std::move(assembly));
  return Module(n);
}

// Deserialize a ROCm module. strm must point to a dmlc::Stream.
// Fields are read in the order ROCMModuleNode::SaveToBinary writes them
// (format, function table, binary). Source and assembly are not serialized,
// so the loaded module has them empty.
Module ROCMModuleLoadBinary(void* strm) {
  dmlc::Stream* stream = static_cast<dmlc::Stream*>(strm);
  std::string data;
  std::unordered_map<std::string, FunctionInfo> fmap;
  std::string fmt;
  // dmlc::Stream::Read returns false on a short/corrupt read; fail loudly
  // instead of silently building a module from partial data.
  CHECK(stream->Read(&fmt)) << "ROCMError: failed to read module format";
  CHECK(stream->Read(&fmap)) << "ROCMError: failed to read function table";
  CHECK(stream->Read(&data)) << "ROCMError: failed to read module binary";
  return ROCMModuleCreate(std::move(data), std::move(fmt), std::move(fmap),
                          std::string(), std::string());
}


// Expose ROCMModuleLoadBinary as the global "module.loadbinary_hsaco";
// the first argument is the opaque stream handle it deserializes from.
TVM_REGISTER_GLOBAL("module.loadbinary_hsaco")
.set_body([](TVMArgs args, TVMRetValue* rv) {
    void* stream_handle = args[0];
    *rv = ROCMModuleLoadBinary(stream_handle);
  });
}  // namespace runtime
}  // namespace tvm
#endif  // TVM_ROCM_RUNTIME