/*!
 *  Copyright (c) 2017 by Contributors
 * \file codegen_nvptx.cc
 * \brief NVPTX code generator.
 */
#ifdef TVM_LLVM_VERSION

#include <tvm/runtime/device_api.h>
#include "codegen_llvm.h"
#include "../build_common.h"
#include "../../pass/ir_util.h"
#include "../../runtime/cuda/cuda_module.h"

namespace tvm {
namespace codegen {

// NVPTX code generator.
class CodeGenNVPTX : public CodeGenLLVM {
 public:
  void AddFunction(const LoweredFunc& f) final {
    // add function as void return value
    CodeGenLLVM::AddFunctionInternal(f, true);
    // annotate as kernel function
    module_->getOrInsertNamedMetadata("nvvm.annotations")
        ->addOperand(llvm::MDNode::get(*ctx_, {
              llvm::ValueAsMetadata::get(function_),
              llvm::MDString::get(*ctx_, "kernel"),
              llvm::ValueAsMetadata::get(ConstInt32(1)) }));
  }

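  // Lower an Allocate node. Only constant-size allocations are supported
  // on GPU: "local" buffers become stack allocas, while "shared" buffers
  // become module-level globals in NVVM address space 3.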
  void VisitStmt_(const Allocate* op) final {
    CHECK(!is_zero(op->condition));
    llvm::Value* buf = nullptr;
    if (op->new_expr.defined()) {
      CHECK_EQ(op->free_function, "nop");
      buf = MakeValue(op->new_expr);
    } else {
      int32_t constant_size = op->constant_allocation_size();
      CHECK_GT(constant_size, 0)
          << "Can only handle constant-size stack allocation on GPU";
      StorageInfo& info = alloc_storage_info_[op->buffer_var.get()];
      if (constant_size % 4 == 0 && info.alignment == 0) {
        info.alignment = GetTempAllocaAlignment(op->type, constant_size);
      }
      // maximum necessary alignment in the NV devices
      if (info.alignment > 16) {
        info.alignment = 16;
      }
      if (info.scope.rank == runtime::StorageRank::kLocal) {
        // const int local_address_space = 5;
        // TODO(tqchen): for higher version of LLVM, local address space can be set.
        llvm::AllocaInst* alloca = builder_->CreateAlloca(
            LLVMType(op->type), ConstInt32(constant_size));
        if (alloca->getAlignment() < static_cast<uint32_t>(info.alignment)) {
          alloca->setAlignment(info.alignment);
        }
        buf = alloca;
      } else {
        CHECK(info.scope.rank == runtime::StorageRank::kShared)
            << "Can only allocate shared or local memory inside kernel";
        // Shared memory: address space == 3
        const unsigned shared_address_space = 3;
        llvm::Type* type = llvm::ArrayType::get(LLVMType(op->type), constant_size);
        // Allocate shared memory in global, address_space = 3
        llvm::GlobalVariable *global = new llvm::GlobalVariable(
            *module_, type, false, llvm::GlobalValue::PrivateLinkage, 0, ".shared",
            nullptr, llvm::GlobalValue::NotThreadLocal, shared_address_space);
        global->setAlignment(info.alignment);
        buf = global;
      }
    }
    buf = builder_->CreatePointerCast(
        buf, LLVMType(op->type)->getPointerTo(
            buf->getType()->getPointerAddressSpace()));
    CHECK(!var_map_.count(op->buffer_var.get()));
    var_map_[op->buffer_var.get()] = buf;
    this->VisitStmt(op->body);
  }

  // Return the thread index via intrinsics.
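  // Rank 1 maps to threadIdx.{x,y,z}; rank 0 maps to blockIdx.{x,y,z}.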
  llvm::Value* GetThreadIndex(const IterVar& iv) final {
    runtime::ThreadScope ts = runtime::ThreadScope::make(iv->thread_tag);
    llvm::Intrinsic::ID intrin_id = ::llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x;
    if (ts.rank == 1) {
      switch (ts.dim_index) {
        case 0: intrin_id = ::llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x; break;
        case 1: intrin_id = ::llvm::Intrinsic::nvvm_read_ptx_sreg_tid_y; break;
        case 2: intrin_id = ::llvm::Intrinsic::nvvm_read_ptx_sreg_tid_z; break;
        default: LOG(FATAL) << "unknown thread idx";
      }
    } else {
      CHECK_EQ(ts.rank, 0);
      switch (ts.dim_index) {
        case 0: intrin_id = ::llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_x; break;
        case 1: intrin_id = ::llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_y; break;
        case 2: intrin_id = ::llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_z; break;
        default: LOG(FATAL) << "unknown thread idx";
      }
    }
    llvm::Function* f = llvm::Intrinsic::getDeclaration(module_.get(), intrin_id);
    return builder_->CreateCall(f, {});
  }

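  // Lower a storage sync call: "shared" emits the nvvm barrier0 intrinsic
  // (equivalent to __syncthreads()); "warp" is currently a no-op.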
  llvm::Value* CreateStorageSync(const Call* op) final {
    const std::string& sync = op->args[0].as<StringImm>()->value;
    if (sync == "warp") {
      // TODO(tqchen) warp sync in CUDA9
      return nullptr;
    } else if (sync == "shared") {
      llvm::Function* f = llvm::Intrinsic::getDeclaration(
          module_.get(),
          ::llvm::Intrinsic::nvvm_barrier0);
      return builder_->CreateCall(f, {});
    } else {
      LOG(FATAL) << "Do not support sync " << sync;
      return nullptr;
    }
  }

  void InitPassManagerBuilder(llvm::PassManagerBuilder* builder) final {
    // Additional optimization hook to tweak the builder.
  }

  void Optimize() final {
    for (auto& f : *module_) {
      auto fname = static_cast<std::string>(f.getName());
      if (fname.substr(0, 4) != "__nv") continue;
      // This is to strip off unused __nv_* functions from the final module
      // The one that is actually used will be inlined at call site
      // Adapted from Halide's runtime linker
      if (!f.isDeclaration() && !f.hasFnAttribute(llvm::Attribute::NoInline)) {
        f.setLinkage(llvm::GlobalValue::AvailableExternallyLinkage);
      }
    }
    CodeGenLLVM::Optimize();
  }

 protected:
  void InitTarget(llvm::TargetMachine* tm) final {
    // Maximum vector width: float4 (128 bits)
    native_vector_bits_ = 4 * 32;
    CodeGenLLVM::InitTarget(tm);
  }
};

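// Query the compute capability of GPU 0 through the runtime DeviceAPI.
// Returns major * 10 + minor (e.g. 52 for sm_52), falling back to 20
// (sm_20) when no CUDA device is present.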
inline int DetectCUDAComputeVersion() {
  TVMContext tvm_ctx;
  tvm_ctx.device_type = kDLGPU;
  tvm_ctx.device_id = 0;
  TVMRetValue val;
  tvm::runtime::DeviceAPI::Get(tvm_ctx)->GetAttr(
      tvm_ctx, tvm::runtime::kExist, &val);
  if (val.operator int() == 1) {
    tvm::runtime::DeviceAPI::Get(tvm_ctx)->GetAttr(
        tvm_ctx, tvm::runtime::kComputeVersion, &val);
    std::string version = val;
    std::istringstream is(version);
    double ver;
    is >> ver;
    return static_cast<int>(ver * 10);
  } else {
    return 20;
  }
}

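// Build PTX for the given lowered functions. Anything after the "nvptx"
// prefix in the target string is appended verbatim to the LLVM options,
// after the -mtriple and -mcpu flags derived from the detected device.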
runtime::Module BuildNVPTX(Array<LoweredFunc> funcs, std::string target) {
  CHECK(target.length() >= 5 &&
        target.substr(0, 5) == "nvptx")
      << "NVPTX target must start with \"nvptx\"";
  int compute_ver = DetectCUDAComputeVersion();
  std::ostringstream config;
  config << "-mtriple=nvptx64-nvidia-cuda -mcpu=sm_"
         << compute_ver
         << target.substr(5, target.length() - 5);
  llvm::TargetMachine* tm = GetLLVMTargetMachine(config.str());
  std::unique_ptr<CodeGenNVPTX> cg(new CodeGenNVPTX());
  std::unique_ptr<llvm::LLVMContext> ctx(new llvm::LLVMContext());
  cg->Init(funcs[0]->name, tm, ctx.get(), false, false);
  for (LoweredFunc f : funcs) {
    cg->AddFunction(f);
  }

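  // Link in libdevice so that math builtins (__nv_*) resolve. The callback
  // is typically registered by the Python frontend and returns the bitcode
  // path for the given compute version, or an empty string if unavailable.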
  const auto* flibdevice_path =
      tvm::runtime::Registry::Get("tvm_callback_libdevice_path");
  if (flibdevice_path != nullptr) {
    std::string path = (*flibdevice_path)(compute_ver);
    if (path.length() != 0) {
      llvm::SMDiagnostic err;
      std::unique_ptr<llvm::Module> mlib = llvm::parseIRFile(path, err, *ctx);
      if (mlib.get() == nullptr) {
        std::string msg = err.getMessage();
        LOG(FATAL) << "Fail to load bitcode file " << path << "\n"
                   << "line " << err.getLineNo() << ":" << msg;
      }
      mlib->setTargetTriple(tm->getTargetTriple().str());
      mlib->setDataLayout(tm->createDataLayout());
      cg->AddLinkModule(std::move(mlib));
    }
  }
  std::unique_ptr<llvm::Module> module = cg->Finish();
  llvm::SmallString<8> data_ptx, data_ll;
  llvm::raw_svector_ostream dest_ptx(data_ptx), dest_ll(data_ll);
  dest_ptx.SetUnbuffered();
  dest_ll.SetUnbuffered();
  // print ll
  module->print(dest_ll, nullptr);
  std::string ll(data_ll.begin(), data_ll.end());
  // emit ptx
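  // LLVM 7 added an extra DWARF output stream parameter to
  // addPassesToEmitFile, hence the version guard below.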
  llvm::legacy::PassManager pass;
#if TVM_LLVM_VERSION <= 60
  CHECK(tm->addPassesToEmitFile(
      pass, dest_ptx, llvm::TargetMachine::CGFT_AssemblyFile) == 0)
      << "Cannot emit target CGFT_AssemblyFile";
#else
  CHECK(tm->addPassesToEmitFile(
      pass, dest_ptx, nullptr, llvm::TargetMachine::CGFT_AssemblyFile) == 0)
      << "Cannot emit target CGFT_AssemblyFile";
#endif
  pass.run(*module);
  std::string ptx(data_ptx.begin(), data_ptx.end());
  return CUDAModuleCreate(ptx, "ptx", ExtractFuncInfo(funcs), ll);
}

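// Expose the builder through the global registry. A minimal usage sketch
// from C++ (hypothetical `funcs` array of LoweredFunc):
//   const runtime::PackedFunc* f =
//       runtime::Registry::Get("codegen.build_nvptx");
//   runtime::Module m = (*f)(funcs, "nvptx");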
TVM_REGISTER_API("codegen.build_nvptx")
.set_body([](TVMArgs args, TVMRetValue* rv) {
    *rv = BuildNVPTX(args[0], args[1]);
  });

}  // namespace codegen
}  // namespace tvm
#endif  // TVM_LLVM_VERSION