/*!
 *  Copyright (c) 2017 by Contributors
 * \file codegen_amdgpu.cc
 * \brief AMDGPU code generator.
 */
#ifdef TVM_LLVM_VERSION

#include <tvm/runtime/device_api.h>
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/registry.h>
#include "codegen_llvm.h"
#include "../build_common.h"
#include "../codegen_source_base.h"
#include "../../pass/ir_util.h"
#include "../../runtime/rocm/rocm_module.h"

namespace tvm {
namespace codegen {

// AMDGPU code generator.
class CodeGenAMDGPU : public CodeGenLLVM {
 public:
  void AddFunction(const LoweredFunc& f) final {
    // add function as void return value
    CodeGenLLVM::AddFunctionInternal(f, true);
    function_->setCallingConv(llvm::CallingConv::AMDGPU_KERNEL);
  }

  void VisitStmt_(const Allocate* op) final {
    CHECK(!is_zero(op->condition));
    llvm::Value* buf = nullptr;
    if (op->new_expr.defined()) {
      CHECK_EQ(op->free_function, "nop");
      buf = MakeValue(op->new_expr);
    } else {
      int32_t constant_size = op->constant_allocation_size();
      CHECK_GT(constant_size, 0)
          << "Can only handle constant size stack allocation in GPU";
      StorageInfo& info = alloc_storage_info_[op->buffer_var.get()];
      if (constant_size % 4 == 0 && info.alignment == 0) {
        info.alignment = GetTempAllocaAlignment(op->type, constant_size);
      }
      // maximum necessary alignment in the AMD devices
      if (info.alignment > 16) {
        info.alignment = 16;
      }
47
      if (info.scope.rank == runtime::StorageRank::kLocal) {
48 49
        // const int local_address_space = 5;
        // TODO(tqchen): for higher version of LLVM, local address space can be set.
50 51 52 53
        llvm::AllocaInst* alloca = WithFunctionEntry([&]() {
            return builder_->CreateAlloca(
                LLVMType(op->type), ConstInt32(constant_size));
          });
54 55 56 57 58
        if (alloca->getAlignment() < static_cast<uint32_t>(info.alignment)) {
          alloca->setAlignment(info.alignment);
        }
        buf = alloca;
      } else {
59
        CHECK(info.scope.rank == runtime::StorageRank::kShared)
60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134
            << "Can only allocate shared or local memory inside kernel";
        // Shared memory: address space  == 3
        const unsigned shared_address_space = 3;
        llvm::Type* type = llvm::ArrayType::get(LLVMType(op->type), constant_size);
        // Allocate shared memory in global, address_space = 3
        llvm::GlobalVariable *global = new llvm::GlobalVariable(
            *module_, type, false, llvm::GlobalValue::PrivateLinkage, 0, ".shared",
            nullptr, llvm::GlobalValue::NotThreadLocal, shared_address_space);
        global->setAlignment(info.alignment);
        buf = global;
      }
    }
    buf = builder_->CreatePointerCast(
        buf, LLVMType(op->type)->getPointerTo(
            buf->getType()->getPointerAddressSpace()));
    CHECK(!var_map_.count(op->buffer_var.get()));
    var_map_[op->buffer_var.get()] = buf;
    this->VisitStmt(op->body);
  }

  // Return the thread index via intrinsics.
  llvm::Value* GetThreadIndex(const IterVar& iv) final {
    runtime::ThreadScope ts = runtime::ThreadScope::make(iv->thread_tag);
    llvm::Intrinsic::ID intrin_id = ::llvm::Intrinsic::amdgcn_workitem_id_x;
    if (ts.rank == 1) {
      switch (ts.dim_index) {
        case 0: intrin_id = ::llvm::Intrinsic::amdgcn_workitem_id_x; break;
        case 1: intrin_id = ::llvm::Intrinsic::amdgcn_workitem_id_y; break;
        case 2: intrin_id = ::llvm::Intrinsic::amdgcn_workitem_id_z; break;
        default: LOG(FATAL) << "unknown workitem idx";
      }
    } else {
      CHECK_EQ(ts.rank, 0);
      switch (ts.dim_index) {
        case 0: intrin_id = ::llvm::Intrinsic::amdgcn_workgroup_id_x; break;
        case 1: intrin_id = ::llvm::Intrinsic::amdgcn_workgroup_id_y; break;
        case 2: intrin_id = ::llvm::Intrinsic::amdgcn_workgroup_id_z; break;
        default: LOG(FATAL) << "unknown workgroup idx";
      }
    }
    llvm::Function* f = llvm::Intrinsic::getDeclaration(module_.get(), intrin_id);
    return builder_->CreateCall(f, {});
  }

  llvm::Value* CreateStorageSync(const Call* op) final {
    const std::string& sync = op->args[0].as<StringImm>()->value;
    if (sync == "warp") {
      return nullptr;
    } else if (sync == "shared") {
      llvm::Function* f = llvm::Intrinsic::getDeclaration(
          module_.get(),
          ::llvm::Intrinsic::amdgcn_s_barrier);
      return builder_->CreateCall(f, {});
    } else {
      LOG(FATAL) << "Do not support sync " << sync;
      return nullptr;
    }
  }

  void InitPassManagerBuilder(llvm::PassManagerBuilder* builder) final {
    // Additional optimization hook to tweak the builder.
  }

  unsigned GetGlobalAddressSpace() {
    return 1;
  }

 protected:
  void InitTarget(llvm::TargetMachine* tm) final {
    // Maximum vector lane = float4
    native_vector_bits_ = 4 * 32;
    CodeGenLLVM::InitTarget(tm);
  }
};

inline int DetectROCMComputeVersion(const std::string& target) {
  size_t pos = target.find("=gfx");
  if (pos != std::string::npos) {
    int value;
    std::stringstream is(target.substr(pos + 4));
    if (is >> value) return value;
  }
142
  TVMContext tvm_ctx;
143
  tvm_ctx.device_type = kDLROCM;
144
  tvm_ctx.device_id = 0;
145 146 147 148 149 150 151 152
  tvm::runtime::DeviceAPI* api = tvm::runtime::DeviceAPI::Get(tvm_ctx, true);
  if (api != nullptr) {
    TVMRetValue val;
    api->GetAttr(tvm_ctx, tvm::runtime::kExist, &val);
    if (val.operator int() == 1) {
      tvm::runtime::DeviceAPI::Get(tvm_ctx)->GetAttr(tvm_ctx, tvm::runtime::kComputeVersion, &val);
      return val.operator int();
    }
153
  }
154 155
  LOG(WARNING) << "Cannot find -mcpu to specify rocm compute version assume gfx803";
  return 803;
156
}
/*!
 * \brief Compile a set of lowered functions into a ROCm runtime module.
 *
 *  Steps performed:
 *   - build an LLVM option string from the amdgcn triple, the detected gfx
 *     version and whatever follows the leading "rocm" in \p target;
 *   - generate LLVM IR for every function;
 *   - parse the bitcode files returned by the registered
 *     "tvm_callback_rocm_bitcode_path" callback, mark their functions
 *     always-inline and add them as link modules;
 *   - emit an object file and an assembly listing from clones of the module;
 *   - pass the object bytes to the registered "tvm_callback_rocm_link"
 *     callback and wrap its result as an "hsaco" module.
 *
 * \param funcs The lowered functions to compile; must not be empty.
 * \param target The target string; must start with "rocm".
 * \return The ROCm module, carrying the LLVM IR text and the assembly text
 *         alongside the linked binary.
 */
runtime::Module BuildAMDGPU(Array<LoweredFunc> funcs, std::string target) {
  InitializeLLVM();
  CHECK(target.length() >= 4 &&
        target.substr(0, 4) == "rocm");
  // funcs[0] names the module below, so an empty list cannot be compiled.
  CHECK(funcs.size() != 0)
      << "Require at least one function to build a rocm module";
  std::ostringstream config;
  config << "-mtriple=amdgcn-amd-amdhsa-hcc -mcpu=gfx"
         << DetectROCMComputeVersion(target)
         << target.substr(4, target.length() - 4);
  std::unique_ptr<llvm::TargetMachine> tm = GetLLVMTargetMachine(config.str());
  std::unique_ptr<CodeGenAMDGPU> cg(new CodeGenAMDGPU());
  std::unique_ptr<llvm::LLVMContext> ctx(new llvm::LLVMContext());
  cg->Init(funcs[0]->name, tm.get(), ctx.get(), false, false);
  for (LoweredFunc f :  funcs) {
    cg->AddFunction(f);
  }

  const auto *find_rocm_bitcodes =
      tvm::runtime::Registry::Get("tvm_callback_rocm_bitcode_path");
  // Registry::Get returns nullptr for an unregistered name; report that
  // instead of dereferencing it.
  CHECK(find_rocm_bitcodes != nullptr)
      << "Require tvm_callback_rocm_bitcode_path to exist, do import tvm.contrib.rocm";
  Array<Expr> bitcode_files = (*find_rocm_bitcodes)();

  for (auto &bitcode : bitcode_files) {
    const StringImm* path_imm = bitcode.as<StringImm>();
    CHECK(path_imm != nullptr)
        << "tvm_callback_rocm_bitcode_path must return an array of strings";
    std::string path = path_imm->value;
    llvm::SMDiagnostic err;
    std::unique_ptr<llvm::Module> mlib = llvm::parseIRFile(path, err, *ctx);
    if (mlib.get() == nullptr) {
      std::string msg = err.getMessage();
      LOG(FATAL) << "Fail to load bitcode file " << path << "\n"
                 << "line " << err.getLineNo() << ":" << msg;
    }
    mlib->setTargetTriple(tm->getTargetTriple().str());
    mlib->setDataLayout(tm->createDataLayout());
    for (llvm::Function &f : mlib->functions()) {
      f.addFnAttr(llvm::Attribute::AlwaysInline);
    }
    cg->AddLinkModule(std::move(mlib));
  }

  std::unique_ptr<llvm::Module> module = cg->Finish();
  llvm::SmallString<8> dataObj, data_ll, dataAsm;
  llvm::raw_svector_ostream destObj(dataObj), dest_ll(data_ll), destAsm(dataAsm);
  destObj.SetUnbuffered();
  dest_ll.SetUnbuffered();
  destAsm.SetUnbuffered();
  // Capture the textual IR of the finished module.
  module->print(dest_ll, nullptr);
  // Object and assembly emission each run on their own clone of the module.
#if TVM_LLVM_VERSION <= 60
  std::unique_ptr<llvm::Module> mAsm = llvm::CloneModule(module.get());
  std::unique_ptr<llvm::Module> mObj = llvm::CloneModule(module.get());
#else
  std::unique_ptr<llvm::Module> mAsm = llvm::CloneModule(*module.get());
  std::unique_ptr<llvm::Module> mObj = llvm::CloneModule(*module.get());
#endif
  llvm::legacy::PassManager pass;

#if TVM_LLVM_VERSION <= 60
  CHECK(tm->addPassesToEmitFile(
            pass, destObj, llvm::TargetMachine::CGFT_ObjectFile) == 0)
            << "Cannot emit target CGFT_ObjectFile";
#else
  CHECK(tm->addPassesToEmitFile(
            pass, destObj, nullptr, llvm::TargetMachine::CGFT_ObjectFile) == 0)
            << "Cannot emit target CGFT_ObjectFile";
#endif
  pass.run(*mObj);
  std::string obj(dataObj.begin(), dataObj.end());

  llvm::legacy::PassManager passAsm;
#if TVM_LLVM_VERSION <= 60
  CHECK(tm->addPassesToEmitFile(passAsm, destAsm,
                                llvm::TargetMachine::CGFT_AssemblyFile) == 0)
      << "Cannot emit target CGFT_AssemblyFile";
#else
  CHECK(tm->addPassesToEmitFile(passAsm, destAsm, nullptr,
                                llvm::TargetMachine::CGFT_AssemblyFile) == 0)
      << "Cannot emit target CGFT_AssemblyFile";
#endif
  passAsm.run(*mAsm);
  std::string assembly(dataAsm.begin(), dataAsm.end());

  const auto* f = tvm::runtime::Registry::Get("tvm_callback_rocm_link");
  CHECK(f != nullptr) << "Require tvm_callback_rocm_link to exist, do import tvm.contrib.rocm";

  // Pass the object bytes to the link callback as a byte array.
  TVMByteArray arr;
  arr.data = &obj[0];
  arr.size = obj.length();

  std::string hsaco = (*f)(arr);
  std::string ll(data_ll.begin(), data_ll.end());
  return ROCMModuleCreate(hsaco, "hsaco", ExtractFuncInfo(funcs), ll, assembly);
}

// Expose BuildAMDGPU through the global function registry.
// args[0]: the lowered functions, args[1]: the target string.
TVM_REGISTER_API("codegen.build_rocm")
.set_body([](TVMArgs args, TVMRetValue* rv) {
    Array<LoweredFunc> lowered = args[0];
    std::string target_str = args[1];
    *rv = BuildAMDGPU(lowered, target_str);
  });

}  // namespace codegen
}  // namespace tvm
#endif  // TVM_LLVM_VERSION