rocm_module.cc 7.18 KB
Newer Older
1 2 3 4 5 6 7 8 9 10
/*!
 *  Copyright (c) 2017 by Contributors
 * \file rocm_module.cc
 */
#include <tvm/runtime/registry.h>
#include <hip/hip_runtime_api.h>
#include <array>
#include <mutex>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "rocm_module.h"
#include "rocm_common.h"
#include "../pack_args.h"
#include "../thread_storage_scope.h"
#include "../meta_data.h"
#include "../file_util.h"

namespace tvm {
namespace runtime {

// Module to support thread-safe multi-GPU execution.
// hipModule_t is a per-GPU module
// The runtime will contain a per-device module table
// The modules will be lazily loaded
class ROCMModuleNode : public runtime::ModuleNode {
 public:
  explicit ROCMModuleNode(std::string data,
                          std::string fmt,
                          std::unordered_map<std::string, FunctionInfo> fmap,
31 32 33
                          std::string hip_source,
                          std::string assembly)
    : data_(data), fmt_(fmt), fmap_(fmap), hip_source_(hip_source), assembly_(assembly) {
34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60
    std::fill(module_.begin(), module_.end(), nullptr);
  }
  // destructor
  ~ROCMModuleNode() {
    for (size_t i = 0; i < module_.size(); ++i) {
      if (module_[i] != nullptr) {
        ROCM_CALL(hipSetDevice(static_cast<int>(i)));
        ROCM_DRIVER_CALL(hipModuleUnload(module_[i]));
      }
    }
  }

  const char* type_key() const final {
    return "hip";
  }

  PackedFunc GetFunction(
      const std::string& name,
      const std::shared_ptr<ModuleNode>& sptr_to_self) final;


  void SaveToBinary(dmlc::Stream* stream) final {
    stream->Write(fmt_);
    stream->Write(fmap_);
    stream->Write(data_);
  }

61 62
  std::string GetSource(const std::string& format) final {
    if (format == fmt_) { return data_; }
63 64
    if (format == "llvm") { return hip_source_; }
    if (format == "asm") { return assembly_; }
65 66 67
    return "";
  }

68 69 70 71
  // get a CUfunction from primary context in device_id
  hipFunction_t GetFunc(int device_id, const std::string& func_name) {
    std::lock_guard<std::mutex> lock(mutex_);
    // must recheck under the lock scope
72

73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118
    if (module_[device_id] == nullptr) {
      ROCM_DRIVER_CALL(hipModuleLoadData(&(module_[device_id]), data_.c_str()));
    }
    hipFunction_t func;
    hipError_t result = hipModuleGetFunction(&func, module_[device_id], func_name.c_str());
    if (result != hipSuccess) {
      LOG(FATAL)
          << "ROCMError: hipModuleGetFunction " << func_name
          << " failed with error: " << hipGetErrorString(result);
    }
    return func;
  }
  // get a global var from primary context in device_id
  hipDeviceptr_t GetGlobal(int device_id,
                        const std::string& global_name,
                        size_t expect_nbytes) {
    std::lock_guard<std::mutex> lock(mutex_);
    // must recheck under the lock scope
    if (module_[device_id] == nullptr) {
      ROCM_DRIVER_CALL(hipModuleLoadData(&(module_[device_id]), data_.c_str()));
    }
    hipDeviceptr_t global = nullptr;
    size_t nbytes = 0;

    hipError_t result = hipSuccess;
    // ROCM doesn't support hipModuleGetGlobal yet.
    // hipError_t result = hipModuleGetGlobal(&global, &nbytes,
    //                                    module_[device_id], global_name.c_str());
    CHECK_EQ(nbytes, expect_nbytes);
    if (result != hipSuccess) {
      LOG(FATAL)
          << "ROCMError: hipModuleGetGlobal " << global_name
          << " failed with error: " << hipGetErrorString(result);
    }
    return global;
  }

 private:
  // the binary data
  std::string data_;
  // The format
  std::string fmt_;
  // function information table.
  std::unordered_map<std::string, FunctionInfo> fmap_;
  // The hip source.
  std::string hip_source_;
119 120
  // The gcn asm.
  std::string assembly_;
121 122 123 124 125 126
  // the internal modules per GPU, to be lazily initialized.
  std::array<hipModule_t, kMaxNumGPUs> module_;
  // internal mutex when updating the module
  std::mutex mutex_;
};

127
// a wrapped function class to get packed func.
128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144
class ROCMWrappedFunc {
 public:
  // initialize the ROCM function.
  void Init(ROCMModuleNode* m,
            std::shared_ptr<ModuleNode> sptr,
            const std::string& func_name,
            size_t num_void_args,
            const std::vector<std::string>& thread_axis_tags) {
    m_ = m;
    sptr_ = sptr;
    func_name_ = func_name;
    std::fill(fcache_.begin(), fcache_.end(), nullptr);
    thread_axis_cfg_.Init(num_void_args, thread_axis_tags);
  }
  // invoke the function with void arguments
  void operator()(TVMArgs args,
                  TVMRetValue* rv,
145 146
                  void* packed_args,
                  size_t packed_nbytes) const {
147 148 149 150 151
    int device_id;
    ROCM_CALL(hipGetDevice(&device_id));
    if (fcache_[device_id] == nullptr) {
      fcache_[device_id] = m_->GetFunc(device_id, func_name_);
    }
152

153
    hipStream_t strm = static_cast<hipStream_t>(ROCMThreadEntry::ThreadLocal()->stream);
154

155
    ThreadWorkLoad wl = thread_axis_cfg_.Extract(args);
156
    void* config[] = {
157
      HIP_LAUNCH_PARAM_BUFFER_POINTER, packed_args,
158 159 160
      HIP_LAUNCH_PARAM_BUFFER_SIZE, &packed_nbytes,
      HIP_LAUNCH_PARAM_END
    };
161 162 163 164 165 166 167 168 169
    // HIP supports only extra_args.
    ROCM_DRIVER_CALL(hipModuleLaunchKernel(
        fcache_[device_id],
        wl.grid_dim(0),
        wl.grid_dim(1),
        wl.grid_dim(2),
        wl.block_dim(0),
        wl.block_dim(1),
        wl.block_dim(2),
170 171
        0, strm, nullptr,
        reinterpret_cast<void**>(&config)));
172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199
  }

 private:
  // internal module
  ROCMModuleNode* m_;
  // the resource holder
  std::shared_ptr<ModuleNode> sptr_;
  // The name of the function.
  std::string func_name_;
  // Device function cache per device.
  // mark as mutable, to enable lazy initialization
  mutable std::array<hipFunction_t, kMaxNumGPUs> fcache_;
  // thread axis configuration
  ThreadAxisConfig thread_axis_cfg_;
};


PackedFunc ROCMModuleNode::GetFunction(
      const std::string& name,
      const std::shared_ptr<ModuleNode>& sptr_to_self) {
  // The wrapper stores sptr_to_self, so it must refer to this very module.
  CHECK_EQ(sptr_to_self.get(), this);
  CHECK_NE(name, symbol::tvm_module_main)
      << "Device function do not have main";
  const auto entry = fmap_.find(name);
  if (entry == fmap_.end()) {
    // Unknown name: an empty PackedFunc is returned.
    return PackedFunc();
  }
  const FunctionInfo& finfo = entry->second;
  ROCMWrappedFunc wrapped;
  wrapped.Init(this, sptr_to_self, name,
               finfo.arg_types.size(), finfo.thread_axis_tags);
  return PackFuncPackedArg(wrapped, finfo.arg_types);
}

/*!
 * \brief Create a ROCm module from a device binary and its metadata.
 * \param data The device binary.
 * \param fmt Format tag of data.
 * \param fmap Function information table.
 * \param hip_source Text served by GetSource("llvm"); may be empty.
 * \param assembly Text served by GetSource("asm"); may be empty.
 * \return The wrapped module.
 */
Module ROCMModuleCreate(
    std::string data,
    std::string fmt,
    std::unordered_map<std::string, FunctionInfo> fmap,
    std::string hip_source,
    std::string assembly) {
  // Parameters are by-value sinks; move them instead of copying again.
  std::shared_ptr<ROCMModuleNode> n =
    std::make_shared<ROCMModuleNode>(std::move(data), std::move(fmt), std::move(fmap),
                                     std::move(hip_source), std::move(assembly));
  return Module(n);
}

/*!
 * \brief Recreate a ROCm module from a stream written by ROCMModuleNode::SaveToBinary.
 * \param strm A dmlc::Stream*, passed as void*.
 * \return The module. The source and assembly texts are not serialized, so they are empty.
 */
Module ROCMModuleLoadBinary(void* strm) {
  dmlc::Stream* stream = static_cast<dmlc::Stream*>(strm);
  std::string data;
  std::unordered_map<std::string, FunctionInfo> fmap;
  std::string fmt;
  // Read order mirrors SaveToBinary: fmt, fmap, data. A failed read means a
  // truncated or corrupt binary; fail loudly rather than build an empty module.
  CHECK(stream->Read(&fmt)) << "ROCMModuleLoadBinary: failed to read format";
  CHECK(stream->Read(&fmap)) << "ROCMModuleLoadBinary: failed to read function table";
  CHECK(stream->Read(&data)) << "ROCMModuleLoadBinary: failed to read binary data";
  return ROCMModuleCreate(std::move(data), std::move(fmt), std::move(fmap),
                          std::string(), std::string());
}


// Global "module.loadbinary_hsaco": forwards its first argument (the stream)
// to ROCMModuleLoadBinary and returns the resulting module.
TVM_REGISTER_GLOBAL("module.loadbinary_hsaco")
.set_body([](TVMArgs call_args, TVMRetValue* result) {
    *result = ROCMModuleLoadBinary(call_args[0]);
  });


// Global "module.loadbinary_hip": the suffix matches ROCMModuleNode::type_key()
// ("hip"). Forwards its first argument (the stream) to ROCMModuleLoadBinary.
TVM_REGISTER_GLOBAL("module.loadbinary_hip")
.set_body([](TVMArgs call_args, TVMRetValue* result) {
    *result = ROCMModuleLoadBinary(call_args[0]);
  });
}  // namespace runtime
}  // namespace tvm