/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file codegen_amdgpu.cc
 * \brief AMDGPU code generator.
 */
#ifdef TVM_LLVM_VERSION

#include <tvm/runtime/device_api.h>
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/registry.h>

#include "codegen_llvm.h"
#include "../build_common.h"
#include "../codegen_source_base.h"
#include "../../pass/ir_util.h"
#include "../../runtime/rocm/rocm_module.h"

namespace tvm {
namespace codegen {

namespace {

// Calls the device API to get the maximum number of threads per block.
static inline int DetectROCMmaxThreadsPerBlock() {
  TVMContext tvm_ctx;
  tvm_ctx.device_type = kDLROCM;
  tvm_ctx.device_id = 0;
  tvm::runtime::DeviceAPI* api = tvm::runtime::DeviceAPI::Get(tvm_ctx, true);
  if (api != nullptr) {
    TVMRetValue val;
    api->GetAttr(tvm_ctx, tvm::runtime::kExist, &val);
    if (val.operator int() == 1) {
      tvm::runtime::DeviceAPI::Get(tvm_ctx)->
        GetAttr(tvm_ctx, tvm::runtime::kMaxThreadsPerBlock, &val);
      return val.operator int();
    }
  }
  LOG(WARNING) << "Cannot get maximum number of threads for AMD codegen";
  return 256;  // see the discussion at PR #4342 for the choice of default
}

}  // namespace

// AMDGPU code generator.
class CodeGenAMDGPU : public CodeGenLLVM {
 public:
  void AddFunction(const LoweredFunc& f) final {
    // add function as void return value
    CodeGenLLVM::AddFunctionInternal(f, true);
    function_->setCallingConv(llvm::CallingConv::AMDGPU_KERNEL);
    // The attribute value is "min,max": bound the flat work-group size by the
    // maximum threads per block reported by the device (or the 256 default).
    std::ostringstream attr;
    attr << "1," << DetectROCMmaxThreadsPerBlock();
    function_->addFnAttr("amdgpu-flat-work-group-size", attr.str());
  }

  void VisitStmt_(const AllocateNode* op) final {
    CHECK(!is_zero(op->condition));
    llvm::Value* buf = nullptr;
    if (op->new_expr.defined()) {
      CHECK_EQ(op->free_function, "nop");
      buf = MakeValue(op->new_expr);
    } else {
      int32_t constant_size = op->constant_allocation_size();
      CHECK_GT(constant_size, 0)
          << "Can only handle constant size stack allocation in GPU";
      StorageInfo& info = alloc_storage_info_[op->buffer_var.get()];
      if (constant_size % 4 == 0 && info.alignment == 0) {
        info.alignment = GetTempAllocaAlignment(op->dtype, constant_size);
      }
      // maximum necessary alignment in the AMD devices
      if (info.alignment > 16) {
        info.alignment = 16;
      }
      if (info.scope.rank == runtime::StorageRank::kLocal) {
        // const int local_address_space = 5;
        // TODO(tqchen): for higher version of LLVM, local address space can be set.
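        // Private (thread-local) allocation: emit an alloca in the function
        // entry block, which on AMDGPU corresponds to the private address
        // space (per-thread scratch). The alignment is only raised when the
        // alloca's default alignment is below the one requested above.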
        llvm::AllocaInst* alloca = WithFunctionEntry([&]() {
            return builder_->CreateAlloca(
                LLVMType(op->dtype), ConstInt32(constant_size));
          });
        if (alloca->getAlignment() < static_cast<uint32_t>(info.alignment)) {
#if TVM_LLVM_VERSION >= 100
          alloca->setAlignment(llvm::Align(info.alignment));
#else
          alloca->setAlignment(info.alignment);
#endif
        }
        buf = alloca;
      } else {
        CHECK(info.scope.rank == runtime::StorageRank::kShared)
            << "Can only allocate shared or local memory inside kernel";
        // Shared memory: address space == 3
        const unsigned shared_address_space = 3;
        llvm::Type* type = llvm::ArrayType::get(LLVMType(op->dtype), constant_size);
        // Allocate shared memory as a global variable in address space 3.
        llvm::GlobalVariable* global = new llvm::GlobalVariable(
            *module_, type, false, llvm::GlobalValue::PrivateLinkage, 0, ".shared",
            nullptr, llvm::GlobalValue::NotThreadLocal, shared_address_space);
#if TVM_LLVM_VERSION >= 100
        global->setAlignment(llvm::Align(info.alignment));
#else
        global->setAlignment(info.alignment);
#endif
        buf = global;
      }
    }
    buf = builder_->CreatePointerCast(
        buf, LLVMType(op->dtype)->getPointerTo(
            buf->getType()->getPointerAddressSpace()));
    CHECK(!var_map_.count(op->buffer_var.get()));
    var_map_[op->buffer_var.get()] = buf;
    this->VisitStmt(op->body);
  }

  // Return the thread index via intrinsics.
  llvm::Value* GetThreadIndex(const IterVar& iv) final {
    runtime::ThreadScope ts = runtime::ThreadScope::make(iv->thread_tag);
    llvm::Intrinsic::ID intrin_id = ::llvm::Intrinsic::amdgcn_workitem_id_x;
    if (ts.rank == 1) {
      switch (ts.dim_index) {
        case 0: intrin_id = ::llvm::Intrinsic::amdgcn_workitem_id_x; break;
        case 1: intrin_id = ::llvm::Intrinsic::amdgcn_workitem_id_y; break;
        case 2: intrin_id = ::llvm::Intrinsic::amdgcn_workitem_id_z; break;
        default: LOG(FATAL) << "unknown workitem idx";
      }
    } else {
      CHECK_EQ(ts.rank, 0);
      switch (ts.dim_index) {
        case 0: intrin_id = ::llvm::Intrinsic::amdgcn_workgroup_id_x; break;
        case 1: intrin_id = ::llvm::Intrinsic::amdgcn_workgroup_id_y; break;
        case 2: intrin_id = ::llvm::Intrinsic::amdgcn_workgroup_id_z; break;
        default: LOG(FATAL) << "unknown workgroup idx";
      }
    }
    llvm::Function* f = llvm::Intrinsic::getDeclaration(module_.get(), intrin_id);
    return builder_->CreateCall(f, {});
  }

  llvm::Value* CreateStorageSync(const CallNode* op) final {
    const std::string& sync = op->args[0].as<StringImmNode>()->value;
    if (sync == "warp") {
      return nullptr;
    } else if (sync == "shared") {
      llvm::Function* f = llvm::Intrinsic::getDeclaration(
          module_.get(), ::llvm::Intrinsic::amdgcn_s_barrier);
      return builder_->CreateCall(f, {});
    } else {
      LOG(FATAL) << "Do not support sync " << sync;
      return nullptr;
    }
  }

  void InitPassManagerBuilder(llvm::PassManagerBuilder* builder) final {
    // Additional optimization hook to tweak the builder.
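    // Nothing is tweaked here: the defaults set up by the base CodeGenLLVM
    // pipeline are used as-is. A subclass could, for instance, adjust
    // builder->OptLevel or builder->LoopVectorize at this point
    // (illustrative only; not exercised by this backend).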
  }

  unsigned GetGlobalAddressSpace() {
    return 1;
  }

 protected:
  void InitTarget(llvm::TargetMachine* tm) final {
    // Maximum vector lane = float4
    native_vector_bits_ = 4 * 32;
    CodeGenLLVM::InitTarget(tm);
  }
};

inline int DetectROCMComputeVersion(const std::string& target) {
  size_t pos = target.find("=gfx");
  if (pos != std::string::npos) {
    int value;
    std::stringstream is(target.substr(pos + 4));
    if (is >> value) return value;
  }
  TVMContext tvm_ctx;
  tvm_ctx.device_type = kDLROCM;
  tvm_ctx.device_id = 0;
  tvm::runtime::DeviceAPI* api = tvm::runtime::DeviceAPI::Get(tvm_ctx, true);
  if (api != nullptr) {
    TVMRetValue val;
    api->GetAttr(tvm_ctx, tvm::runtime::kExist, &val);
    if (val.operator int() == 1) {
      tvm::runtime::DeviceAPI::Get(tvm_ctx)->GetAttr(tvm_ctx, tvm::runtime::kGcnArch, &val);
      return val.operator int();
    }
  }
  LOG(WARNING) << "Cannot find -mcpu to specify the ROCm compute version; assuming gfx900";
  return 900;
}

runtime::Module BuildAMDGPU(Array<LoweredFunc> funcs, std::string target) {
#if TVM_LLVM_VERSION < 90
  LOG(FATAL) << "AMDGPU backend requires at least LLVM 9";
  // Lower versions will crash when loading the bitcode, see
  // issue #4087 for a discussion.
#endif
  InitializeLLVM();
  CHECK(target.length() >= 4 &&
        target.substr(0, 4) == "rocm");
  std::ostringstream config;
  config << "-mtriple=amdgcn-amd-amdhsa-hcc -mcpu=gfx"
         << DetectROCMComputeVersion(target)
         << " -mattr=-code-object-v3 "
         << target.substr(4, target.length() - 4);
  std::unique_ptr<llvm::TargetMachine> tm = GetLLVMTargetMachine(config.str());
  std::unique_ptr<CodeGenAMDGPU> cg(new CodeGenAMDGPU());
  std::unique_ptr<llvm::LLVMContext> ctx(new llvm::LLVMContext());
  cg->Init(funcs[0]->name, tm.get(), ctx.get(), false, false);
  for (LoweredFunc f : funcs) {
    cg->AddFunction(f);
  }

  // Link in the ROCm device bitcode libraries so their functions can be
  // inlined into the kernels.
  const auto* find_rocm_bitcodes =
      tvm::runtime::Registry::Get("tvm_callback_rocm_bitcode_path");
  Array<PrimExpr> bitcode_files = (*find_rocm_bitcodes)();

  for (auto& bitcode : bitcode_files) {
    std::string path = bitcode.as<StringImmNode>()->value;
    llvm::SMDiagnostic err;
    std::unique_ptr<llvm::Module> mlib = llvm::parseIRFile(path, err, *ctx);
    if (mlib.get() == nullptr) {
      std::string msg = err.getMessage();
      LOG(FATAL) << "Failed to load bitcode file " << path << "\n"
                 << "line " << err.getLineNo() << ":" << msg;
    }
    mlib->setTargetTriple(tm->getTargetTriple().str());
    mlib->setDataLayout(tm->createDataLayout());
    for (llvm::Function& f : mlib->functions()) {
      f.addFnAttr(llvm::Attribute::AlwaysInline);
    }
    cg->AddLinkModule(std::move(mlib));
  }

  std::unique_ptr<llvm::Module> module = cg->Finish();
  llvm::SmallString<8> dataObj, data_ll, dataAsm;
  llvm::raw_svector_ostream destObj(dataObj), dest_ll(data_ll), destAsm(dataAsm);
  destObj.SetUnbuffered();
  dest_ll.SetUnbuffered();
  destAsm.SetUnbuffered();
  module->print(dest_ll, nullptr);
#if TVM_LLVM_VERSION <= 60
  std::unique_ptr<llvm::Module> mAsm = llvm::CloneModule(module.get());
  std::unique_ptr<llvm::Module> mObj = llvm::CloneModule(module.get());
#else
  std::unique_ptr<llvm::Module> mAsm = llvm::CloneModule(*module.get());
  std::unique_ptr<llvm::Module> mObj = llvm::CloneModule(*module.get());
#endif
  llvm::legacy::PassManager pass;

#if TVM_LLVM_VERSION <= 60
  CHECK(tm->addPassesToEmitFile(
      pass, destObj, llvm::TargetMachine::CGFT_ObjectFile) == 0)
      << "Cannot emit target CGFT_ObjectFile";
#elif TVM_LLVM_VERSION <= 90
  CHECK(tm->addPassesToEmitFile(
      pass, destObj, nullptr, llvm::TargetMachine::CGFT_ObjectFile) == 0)
      << "Cannot emit target CGFT_ObjectFile";
#else
  CHECK(tm->addPassesToEmitFile(
      pass, destObj, nullptr, llvm::CGFT_ObjectFile) == 0)
"Cannot emit target CGFT_ObjectFile"; #endif pass.run(*mObj); std::string obj(dataObj.begin(), dataObj.end()); llvm::legacy::PassManager passAsm; #if TVM_LLVM_VERSION <= 60 CHECK(tm->addPassesToEmitFile(passAsm, destAsm, llvm::TargetMachine::CGFT_AssemblyFile) == 0) << "Cannot emit target CGFT_AssemblyFile"; #elif TVM_LLVM_VERSION <= 90 CHECK(tm->addPassesToEmitFile(passAsm, destAsm, nullptr, llvm::TargetMachine::CGFT_AssemblyFile) == 0) << "Cannot emit target CGFT_AssemblyFile"; #else CHECK(tm->addPassesToEmitFile(passAsm, destAsm, nullptr, llvm::CGFT_AssemblyFile) == 0) << "Cannot emit target CGFT_AssemblyFile"; #endif passAsm.run(*mAsm); std::string assembly(dataAsm.begin(), dataAsm.end()); const auto* f = tvm::runtime::Registry::Get("tvm_callback_rocm_link"); CHECK(f != nullptr) << "Require tvm_callback_rocm_link to exist, do import tvm.contrib.rocm"; TVMByteArray arr; arr.data = &obj[0]; arr.size = obj.length(); std::string hsaco = (*f)(arr); std::string ll(data_ll.begin(), data_ll.end()); return ROCMModuleCreate(hsaco, "hsaco", ExtractFuncInfo(funcs), ll, assembly); } TVM_REGISTER_GLOBAL("codegen.build_rocm") .set_body_typed(BuildAMDGPU); } // namespace codegen } // namespace tvm #endif // TVM_LLVM_VERSION