/*!
 *  Copyright (c) 2017 by Contributors
 *  Build CUDA modules from source.
 * \file build_cuda.cc
 */
#include <tvm/base.h>
#include <tvm/runtime/config.h>
#include "./codegen_cuda.h"
#include "./build_common.h"

#if TVM_CUDA_RUNTIME
#include <nvrtc.h>
#include "../runtime/cuda/cuda_common.h"
#include "../runtime/cuda/cuda_module.h"

namespace tvm {
namespace codegen {

/*!
 * \brief Evaluate an NVRTC API call and raise LOG(FATAL) if it does not
 *  return NVRTC_SUCCESS.
 *
 * The fatal message contains the stringified call expression followed by
 * the text returned by nvrtcGetErrorString for the failing status.
 *
 * The body is wrapped in do { ... } while (0) so that `NVRTC_CALL(...);`
 * expands to exactly one statement and stays correct inside an unbraced
 * if/else. The argument is parenthesized before it is assigned.
 */
#define NVRTC_CALL(x)                                                   \
  do {                                                                  \
    nvrtcResult result = (x);                                           \
    if (result != NVRTC_SUCCESS) {                                      \
      LOG(FATAL)                                                        \
          << "NvrtcError: " #x " failed with error: "                   \
          << nvrtcGetErrorString(result);                               \
    }                                                                   \
  } while (0)

/*!
 * \brief Compile CUDA source text to PTX using NVRTC.
 * \param code The CUDA source to compile.
 * \return A string holding the PTX, sized as reported by nvrtcGetPTXSize.
 *
 * NVRTC API failures are reported through NVRTC_CALL. A compilation
 * failure is reported through CHECK_EQ with the NVRTC program log appended
 * to the message.
 */
std::string NVRTCCompile(const std::string& code) {
  nvrtcProgram prog;
  NVRTC_CALL(nvrtcCreateProgram(
      &prog, code.c_str(), nullptr, 0, nullptr, nullptr));
  // Compile without extra options. The status is checked only after the
  // log has been fetched, so the log can be attached to the failure message.
  nvrtcResult compile_res = nvrtcCompileProgram(prog, 0, nullptr);
  size_t log_size;
  NVRTC_CALL(nvrtcGetProgramLogSize(prog, &log_size));
  std::string log;
  log.resize(log_size);
  NVRTC_CALL(nvrtcGetProgramLog(prog, &log[0]));
  if (compile_res != NVRTC_SUCCESS) {
    // Release the program before the fatal check below: if that check
    // unwinds (e.g. logging configured to throw), the handle would
    // otherwise never be destroyed. The destroy status is deliberately
    // ignored so that the compile error remains the reported failure.
    static_cast<void>(nvrtcDestroyProgram(&prog));
  }
  CHECK_EQ(compile_res, NVRTC_SUCCESS) << log;

  // Query the PTX size, then copy the PTX into a buffer of that size.
  size_t ptx_size;
  NVRTC_CALL(nvrtcGetPTXSize(prog, &ptx_size));
  std::string ptx;
  ptx.resize(ptx_size);
  NVRTC_CALL(nvrtcGetPTX(prog, &ptx[0]));
  NVRTC_CALL(nvrtcDestroyProgram(&prog));
  return ptx;
}

/*!
 * \brief Generate CUDA source for a set of lowered functions, compile it,
 *  and wrap the result in a runtime module.
 * \param funcs The lowered functions to add to the CUDA code generator.
 * \return The module returned by CUDAModuleCreate.
 *
 * Two optional functions are looked up in the global registry:
 *  - "tvm_callback_cuda_postproc": when registered, it is called with the
 *    generated source and its result replaces that source.
 *  - "tvm_callback_cuda_compile": when registered, it is called with the
 *    source and its result is used as the compiled binary; otherwise the
 *    source is compiled with NVRTCCompile.
 */
runtime::Module BuildCUDA(Array<LoweredFunc> funcs) {
  using tvm::runtime::Registry;
  bool output_ssa = false;
  CodeGenCUDA cg;
  cg.Init(output_ssa);

  for (LoweredFunc f : funcs) {
    cg.AddFunction(f);
  }
  std::string code = cg.Finish();

  if (const auto* f = Registry::Get("tvm_callback_cuda_postproc")) {
    code = (*f)(code).operator std::string();
  }

  // The format defaults to PTX; it is only changed for output produced by
  // the user-registered compile callback.
  std::string fmt = "ptx";
  std::string ptx;
  if (const auto* f = Registry::Get("tvm_callback_cuda_compile")) {
    ptx = (*f)(code).operator std::string();
    // Dirty matching to check PTX vs cubin: output whose first character
    // is not '/' is treated as cubin.
    // TODO(tqchen) more reliable checks
    if (ptx[0] != '/') fmt = "cubin";
  } else {
    ptx = NVRTCCompile(code);
  }

  // The generated source is passed along together with the binary.
  return CUDAModuleCreate(ptx, fmt, ExtractFuncInfo(funcs), code);
}

76
TVM_REGISTER_API("codegen.build_cuda")
77 78 79 80 81 82
.set_body([](TVMArgs args, TVMRetValue* rv) {
    *rv = BuildCUDA(args[0]);
  });
}  // namespace codegen
}  // namespace tvm
#endif   // TVM_CUDA_RUNTIME