/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file codegen_nvptx.cc
 * \brief NVPTX code generator.
 */
#ifdef TVM_LLVM_VERSION

#include <memory>
#include <sstream>
#include <string>

#include <tvm/runtime/device_api.h>

#include "codegen_llvm.h"
#include "../build_common.h"
#include "../../pass/ir_util.h"
#include "../../runtime/cuda/cuda_module.h"

namespace tvm {
namespace codegen {

// NVPTX code generator.
class CodeGenNVPTX : public CodeGenLLVM { public: void AddFunction(const LoweredFunc& f) final { // add function as void return value CodeGenLLVM::AddFunctionInternal(f, true); // annotate as kernel function module_->getOrInsertNamedMetadata("nvvm.annotations") ->addOperand(llvm::MDNode::get(*ctx_, { llvm::ValueAsMetadata::get(function_), llvm::MDString::get(*ctx_, "kernel"), llvm::ValueAsMetadata::get(ConstInt32(1)) })); } void VisitStmt_(const AllocateNode* op) final { CHECK(!is_zero(op->condition)); llvm::Value* buf = nullptr; if (op->new_expr.defined()) { CHECK_EQ(op->free_function, "nop"); buf = MakeValue(op->new_expr); } else { int32_t constant_size = op->constant_allocation_size(); CHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation in GPU"; StorageInfo& info = alloc_storage_info_[op->buffer_var.get()]; if (constant_size % 4 == 0 && info.alignment == 0) { info.alignment = GetTempAllocaAlignment(op->dtype, constant_size); } // maximum necessary alignment in the NV devices if (info.alignment > 16) { info.alignment = 16; } if (info.scope.rank == runtime::StorageRank::kLocal) { // const int local_address_space = 5; // TODO(tqchen): for higher version of LLVM, local address space can be set. 
llvm::AllocaInst* alloca = WithFunctionEntry([&]() { return builder_->CreateAlloca( LLVMType(op->dtype), ConstInt32(constant_size)); }); if (alloca->getAlignment() < static_cast<uint32_t>(info.alignment)) { #if TVM_LLVM_VERSION >= 100 alloca->setAlignment(llvm::Align(info.alignment)); #else alloca->setAlignment(info.alignment); #endif } buf = alloca; } else { CHECK(info.scope.rank == runtime::StorageRank::kShared) << "Can only allocate shared or local memory inside kernel"; // Shared memory: address space == 3 const unsigned shared_address_space = 3; llvm::Type* type = llvm::ArrayType::get(LLVMType(op->dtype), constant_size); // Allocate shared memory in global, address_space = 3 llvm::GlobalVariable *global = new llvm::GlobalVariable( *module_, type, false, llvm::GlobalValue::PrivateLinkage, 0, ".shared", nullptr, llvm::GlobalValue::NotThreadLocal, shared_address_space); #if TVM_LLVM_VERSION >= 100 global->setAlignment(llvm::Align(info.alignment)); #else global->setAlignment(info.alignment); #endif buf = global; } } buf = builder_->CreatePointerCast( buf, LLVMType(op->dtype)->getPointerTo( buf->getType()->getPointerAddressSpace())); CHECK(!var_map_.count(op->buffer_var.get())); var_map_[op->buffer_var.get()] = buf; this->VisitStmt(op->body); } // Return the thread index via intrinsics. 
llvm::Value* GetThreadIndex(const IterVar& iv) final { runtime::ThreadScope ts = runtime::ThreadScope::make(iv->thread_tag); llvm::Intrinsic::ID intrin_id = ::llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x; if (ts.rank == 1) { switch (ts.dim_index) { case 0: intrin_id = ::llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x; break; case 1: intrin_id = ::llvm::Intrinsic::nvvm_read_ptx_sreg_tid_y; break; case 2: intrin_id = ::llvm::Intrinsic::nvvm_read_ptx_sreg_tid_z; break; default: LOG(FATAL) << "unknown thread idx"; } } else { CHECK_EQ(ts.rank, 0); switch (ts.dim_index) { case 0: intrin_id = ::llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_x; break; case 1: intrin_id = ::llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_y; break; case 2: intrin_id = ::llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_z; break; default: LOG(FATAL) << "unknown thread idx"; } } llvm::Function* f = llvm::Intrinsic::getDeclaration(module_.get(), intrin_id); return builder_->CreateCall(f, {}); } llvm::Value* CreateStorageSync(const CallNode* op) final { const std::string& sync = op->args[0].as<StringImmNode>()->value; if (sync == "warp") { // TODO(tqchen) warp sync in CUDA9 return nullptr; } else if (sync == "shared") { llvm::Function* f = llvm::Intrinsic::getDeclaration( module_.get(), ::llvm::Intrinsic::nvvm_barrier0); return builder_->CreateCall(f, {}); } else { LOG(FATAL) << "Do not support sync " << sync; return nullptr; } } void InitPassManagerBuilder(llvm::PassManagerBuilder* builder) final { // Additional optimization hook to tweak the builder. 
} void Optimize() final { for (auto& f : *module_) { auto fname = static_cast<std::string>(f.getName()); if (fname.substr(0, 4) != "__nv") continue; // This is to strip off unused __nv_* functions from the final module // The one that is actually used will be inlined at call site // Adapted from Halide's runtime linker if (!f.isDeclaration() && !f.hasFnAttribute(llvm::Attribute::NoInline)) { f.setLinkage(llvm::GlobalValue::AvailableExternallyLinkage); } } CodeGenLLVM::Optimize(); } protected: void InitTarget(llvm::TargetMachine* tm) final { // Maximum vector lane = float4 native_vector_bits_ = 4 * 32; CodeGenLLVM::InitTarget(tm); } }; inline int DetectCUDAComputeVersion() { TVMContext tvm_ctx; tvm_ctx.device_type = kDLGPU; tvm_ctx.device_id = 0; TVMRetValue val; tvm::runtime::DeviceAPI::Get(tvm_ctx)->GetAttr( tvm_ctx, tvm::runtime::kExist, &val); if (val.operator int() == 1) { tvm::runtime::DeviceAPI::Get(tvm_ctx)->GetAttr( tvm_ctx, tvm::runtime::kComputeVersion, &val); std::string version = val; std::istringstream is(version); double ver; is >> ver; return static_cast<int>(ver * 10); } else { return 20; } } runtime::Module BuildNVPTX(Array<LoweredFunc> funcs, std::string target) { InitializeLLVM(); CHECK(target.length() >= 5 && target.substr(0, 5) == "nvptx"); int compute_ver = DetectCUDAComputeVersion(); std::ostringstream config; config << "-mtriple=nvptx64-nvidia-cuda -mcpu=sm_" << compute_ver << target.substr(5, target.length() - 5); std::unique_ptr<llvm::TargetMachine> tm = GetLLVMTargetMachine(config.str()); std::unique_ptr<CodeGenNVPTX> cg(new CodeGenNVPTX()); std::unique_ptr<llvm::LLVMContext> ctx(new llvm::LLVMContext()); cg->Init(funcs[0]->name, tm.get(), ctx.get(), false, false); for (LoweredFunc f : funcs) { cg->AddFunction(f); } const auto* flibdevice_path = tvm::runtime::Registry::Get("tvm_callback_libdevice_path"); if (flibdevice_path != nullptr) { std::string path = (*flibdevice_path)(compute_ver); if (path.length() != 0) { llvm::SMDiagnostic err; 
std::unique_ptr<llvm::Module> mlib = llvm::parseIRFile(path, err, *ctx); if (mlib.get() == nullptr) { std::string msg = err.getMessage(); LOG(FATAL) << "Fail to load bitcode file " << path << "\n" << "line " << err.getLineNo() << ":" << msg; } mlib->setTargetTriple(tm->getTargetTriple().str()); mlib->setDataLayout(tm->createDataLayout()); cg->AddLinkModule(std::move(mlib)); } } std::unique_ptr<llvm::Module> module = cg->Finish(); llvm::SmallString<8> data_ptx, data_ll; llvm::raw_svector_ostream dest_ptx(data_ptx), dest_ll(data_ll); dest_ptx.SetUnbuffered(); dest_ll.SetUnbuffered(); // print ll module->print(dest_ll, nullptr); std::string ll(data_ll.begin(), data_ll.end()); // emit ptx llvm::legacy::PassManager pass; #if TVM_LLVM_VERSION <= 60 CHECK(tm->addPassesToEmitFile( pass, dest_ptx, llvm::TargetMachine::CGFT_AssemblyFile) == 0) << "Cannot emit target CGFT_ObjectFile"; #elif TVM_LLVM_VERSION <= 90 CHECK(tm->addPassesToEmitFile( pass, dest_ptx, nullptr, llvm::TargetMachine::CGFT_AssemblyFile) == 0) << "Cannot emit target CGFT_ObjectFile"; #else CHECK(tm->addPassesToEmitFile( pass, dest_ptx, nullptr, llvm::CGFT_AssemblyFile) == 0) << "Cannot emit target CGFT_ObjectFile"; #endif pass.run(*module); std::string ptx(data_ptx.begin(), data_ptx.end()); return CUDAModuleCreate(ptx, "ptx", ExtractFuncInfo(funcs), ll); } TVM_REGISTER_GLOBAL("codegen.build_nvptx") .set_body_typed(BuildNVPTX); } // namespace codegen } // namespace tvm #endif // TVM_LLVM_VERSION