/*!
 *  Copyright (c) 2017 by Contributors
 * \file codegen_nvptx.cc
 * \brief NVPTX code generator.
 */
#ifdef TVM_LLVM_VERSION
#if TVM_CUDA_RUNTIME

#include <tvm/runtime/device_api.h>
#include "./codegen_llvm.h"
#include "../build_common.h"
#include "../../pass/ir_util.h"
#include "../../runtime/cuda/cuda_module.h"

namespace tvm {
namespace codegen {

// NVPTX code generator.
class CodeGenNVPTX : public CodeGenLLVM {
 public:
  void AddFunction(const LoweredFunc& f) final {
    // Add the function, forcing a void return type.
    CodeGenLLVM::AddFunctionInternal(f, true);
    // Mark it as a kernel entry via nvvm.annotations metadata, which the
    // NVPTX backend requires in order to emit a .entry function.
    module_->getOrInsertNamedMetadata("nvvm.annotations")
        ->addOperand(llvm::MDNode::get(*ctx_, {
              llvm::ValueAsMetadata::get(function_),
              llvm::MDString::get(*ctx_, "kernel"),
              llvm::ValueAsMetadata::get(ConstInt32(1)) }));
  }

  void VisitStmt_(const Allocate* op) final {
    CHECK(!is_zero(op->condition));
    llvm::Value* buf = nullptr;
    if (op->new_expr.defined()) {
      CHECK_EQ(op->free_function, "nop");
      buf = MakeValue(op->new_expr);
    } else {
      int32_t constant_size = op->constant_allocation_size();
      CHECK_GT(constant_size, 0)
          << "Can only handle constant size stack allocation in GPU";
      StorageInfo& info = alloc_storage_info_[op->buffer_var.get()];
      if (constant_size % 4 == 0 && info.alignment == 0) {
        info.alignment = GetTempAllocaAlignment(op->type, constant_size);
      }
      // 16 bytes is the maximum alignment NVIDIA devices ever need.
      if (info.alignment > 16) {
        info.alignment = 16;
      }
      if (info.scope.rank == 2) {
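        // rank 2: thread-local scratch memory, lowered to a plain alloca.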
        // const int local_address_space = 5;
        // TODO(tqchen): for higher version of LLVM, local address space can be set.
        llvm::AllocaInst* alloca = builder_->CreateAlloca(
            LLVMType(op->type), ConstInt32(constant_size));
        if (alloca->getAlignment() < static_cast<uint32_t>(info.alignment)) {
          alloca->setAlignment(info.alignment);
        }
        buf = alloca;
      } else {
        CHECK_EQ(info.scope.rank, 1)
            << "Can only allocate shared or local memory inside kernel";
        // Shared memory lives in address space 3: allocate it as a
        // module-level global so the NVPTX backend lowers it to .shared.
        const unsigned shared_address_space = 3;
        llvm::Type* type = llvm::ArrayType::get(LLVMType(op->type), constant_size);
        llvm::GlobalVariable *global = new llvm::GlobalVariable(
            *module_, type, false, llvm::GlobalValue::PrivateLinkage, 0, ".shared",
            nullptr, llvm::GlobalValue::NotThreadLocal, shared_address_space);
        global->setAlignment(info.alignment);
        buf = global;
      }
    }
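    // Cast the raw buffer to a pointer of the element type, preserving the
    // address space it was allocated in (0 for the alloca path, 3 for shared).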
    buf = builder_->CreatePointerCast(
        buf, LLVMType(op->type)->getPointerTo(
            buf->getType()->getPointerAddressSpace()));
    CHECK(!var_map_.count(op->buffer_var.get()));
    var_map_[op->buffer_var.get()] = buf;
    this->VisitStmt(op->body);
  }

  // Return the thread index via intrinsics.
  llvm::Value* GetThreadIndex(const IterVar& iv) final {
    runtime::ThreadScope ts = runtime::ThreadScope::make(iv->thread_tag);
    llvm::Intrinsic::ID intrin_id = ::llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x;
    if (ts.rank == 1) {
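      // rank 1: thread index within a block (threadIdx.{x,y,z}).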
      switch (ts.dim_index) {
        case 0: intrin_id = ::llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x; break;
        case 1: intrin_id = ::llvm::Intrinsic::nvvm_read_ptx_sreg_tid_y; break;
        case 2: intrin_id = ::llvm::Intrinsic::nvvm_read_ptx_sreg_tid_z; break;
        default: LOG(FATAL) << "unknown thread idx";
      }
    } else {
      CHECK_EQ(ts.rank, 0);
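      // rank 0: block index within the grid (blockIdx.{x,y,z}).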
      switch (ts.dim_index) {
        case 0: intrin_id = ::llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_x; break;
        case 1: intrin_id = ::llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_y; break;
        case 2: intrin_id = ::llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_z; break;
        default: LOG(FATAL) << "unknown thread idx";
      }
    }
    llvm::Function* f = llvm::Intrinsic::getDeclaration(module_.get(), intrin_id);
    return builder_->CreateCall(f, {});
  }

  llvm::Value* CreateStorageSync(const Call* op) final {
    const std::string& sync = op->args[0].as<StringImm>()->value;
    if (sync == "warp") {
      // TODO(tqchen) warp sync in CUDA9
      return nullptr;
    } else if (sync == "shared") {
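      // nvvm_barrier0 lowers to "bar.sync 0", the CUDA __syncthreads().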
      llvm::Function* f = llvm::Intrinsic::getDeclaration(
          module_.get(),
          ::llvm::Intrinsic::nvvm_barrier0);
      return builder_->CreateCall(f, {});
    } else {
      LOG(FATAL) << "Do not support sync " << sync;
      return nullptr;
    }
  }

  void InitPassManagerBuilder(llvm::PassManagerBuilder* builder) final {
    // Additional optimization hook to tweak the builder.
  }

 protected:
  void InitTarget(llvm::TargetMachine* tm) final {
    // Maximum vector width: float4, i.e. 4 lanes of 32 bits.
    native_vector_bits_ = 4 * 32;
    CodeGenLLVM::InitTarget(tm);
  }
};

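// Detect the compute capability of GPU 0 through the CUDA DeviceAPI.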
inline int DetectCUDAComputeVersion() {
  TVMContext tvm_ctx;
  tvm_ctx.device_type = kDLGPU;
  tvm_ctx.device_id = 0;
  TVMRetValue val;
  tvm::runtime::DeviceAPI::Get(tvm_ctx)->GetAttr(
      tvm_ctx, tvm::runtime::kExist, &val);
  if (val.operator int() == 1) {
    tvm::runtime::DeviceAPI::Get(tvm_ctx)->GetAttr(
        tvm_ctx, tvm::runtime::kComputeVersion, &val);
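    // The reported version is "major.minor", e.g. "5.2"; scale it to the
    // integer 52 that later forms -mcpu=sm_52.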
    std::string version = val;
    std::istringstream is(version);
    double ver;
    is >> ver;
    return static_cast<int>(ver * 10);
  } else {
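    // No CUDA device is visible at build time; fall back to sm_20.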
    return 20;
  }
}

runtime::Module BuildNVPTX(Array<LoweredFunc> funcs, std::string target) {
  CHECK(target.length() >= 5 &&
        target.substr(0, 5) == "nvptx");
  int compute_ver = DetectCUDAComputeVersion();
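  // Assemble the LLVM target string, e.g. "-mtriple=nvptx64-nvidia-cuda
  // -mcpu=sm_52"; anything after the "nvptx" prefix of the TVM target is
  // appended verbatim as extra options.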
  std::ostringstream config;
  config << "-mtriple=nvptx64-nvidia-cuda -mcpu=sm_"
         << compute_ver
         << target.substr(5, target.length() - 5);
  llvm::TargetMachine* tm = GetLLVMTargetMachine(config.str());
  std::unique_ptr<CodeGenNVPTX> cg(new CodeGenNVPTX());
  std::unique_ptr<llvm::LLVMContext> ctx(new llvm::LLVMContext());
  cg->Init(funcs[0]->name, tm, ctx.get(), false, false);
  for (LoweredFunc f : funcs) {
    cg->AddFunction(f);
  }

  const auto* flibdevice_path =
      tvm::runtime::Registry::Get("tvm_callback_libdevice_path");
  if (flibdevice_path != nullptr) {
    std::string path = (*flibdevice_path)(compute_ver);
    if (path.length() != 0) {
      llvm::SMDiagnostic err;
      std::unique_ptr<llvm::Module> mlib = llvm::parseIRFile(path, err, *ctx);
      if (mlib.get() == nullptr) {
        std::string msg = err.getMessage();
        LOG(FATAL) << "Fail to load bitcode file " << path << "\n"
                   << "line " << err.getLineNo() << ":" << msg;
      }
      mlib->setTargetTriple(tm->getTargetTriple().str());
      mlib->setDataLayout(tm->createDataLayout());
      // TODO(tqchen) libdevice linking not yet working.
      // cg->AddLinkModule(std::move(mlib));
    }
  }
  std::unique_ptr<llvm::Module> module = cg->Finish();
  llvm::SmallString<8> data_ptx, data_ll;
  llvm::raw_svector_ostream dest_ptx(data_ptx), dest_ll(data_ll);
  dest_ptx.SetUnbuffered();
  dest_ll.SetUnbuffered();
  // Dump the module as textual LLVM IR; it is stored alongside the PTX
  // in the returned module for inspection.
  module->print(dest_ll, nullptr);
  std::string ll(data_ll.begin(), data_ll.end());
  // Emit PTX assembly through the NVPTX target machine.
  llvm::legacy::PassManager pass;
  CHECK(tm->addPassesToEmitFile(
      pass, dest_ptx, llvm::TargetMachine::CGFT_AssemblyFile) == 0)
      << "Cannot emit target CGFT_AssemblyFile";
  pass.run(*module);
  std::string ptx(data_ptx.begin(), data_ptx.end());
  return CUDAModuleCreate(ptx, "ptx", ExtractFuncInfo(funcs), ll);
}

TVM_REGISTER_API("codegen.build_nvptx")
.set_body([](TVMArgs args, TVMRetValue* rv) {
    *rv = BuildNVPTX(args[0], args[1]);
  });
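// Hypothetical call site for the function registered above (sketch only;
// `funcs` stands for an Array<LoweredFunc> built by the compiler):
//   const runtime::PackedFunc* f =
//       runtime::Registry::Get("codegen.build_nvptx");
//   runtime::Module m = (*f)(funcs, "nvptx");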

}  // namespace codegen
}  // namespace tvm
#endif   // TVM_CUDA_RUNTIME
#endif  // TVM_LLVM_VERSION