/*! * Copyright (c) 2017 by Contributors * Build cuda modules from source. * \file build_cuda.cc */ #include <tvm/base.h> #include <tvm/runtime/config.h> #include "./codegen_cuda.h" #include "./build_common.h" #if TVM_CUDA_RUNTIME #include <nvrtc.h> #include "../runtime/cuda/cuda_common.h" #include "../runtime/cuda/cuda_module.h" namespace tvm { namespace codegen { #define NVRTC_CALL(x) \ { \ nvrtcResult result = x; \ if (result != NVRTC_SUCCESS) { \ LOG(FATAL) \ << "NvrtcError: " #x " failed with error: " \ << nvrtcGetErrorString(result); \ } \ } std::string NVRTCCompile(const std::string& code) { nvrtcProgram prog; NVRTC_CALL(nvrtcCreateProgram( &prog, code.c_str(), nullptr, 0, nullptr, nullptr)); nvrtcResult compile_res = nvrtcCompileProgram(prog, 0, nullptr); size_t log_size; NVRTC_CALL(nvrtcGetProgramLogSize(prog, &log_size)); std::string log; log.resize(log_size); NVRTC_CALL(nvrtcGetProgramLog(prog, &log[0])); CHECK_EQ(compile_res, NVRTC_SUCCESS) << log; size_t ptx_size; NVRTC_CALL(nvrtcGetPTXSize(prog, &ptx_size)); std::string ptx; ptx.resize(ptx_size); NVRTC_CALL(nvrtcGetPTX(prog, &ptx[0])); NVRTC_CALL(nvrtcDestroyProgram(&prog)); return ptx; } runtime::Module BuildCUDA(Array<LoweredFunc> funcs) { using tvm::runtime::Registry; bool output_ssa = false; CodeGenCUDA cg; cg.Init(output_ssa); for (LoweredFunc f : funcs) { cg.AddFunction(f); } std::string code = cg.Finish(); if (const auto* f = Registry::Get("tvm_callback_cuda_postproc")) { code = (*f)(code).operator std::string(); } std::string fmt = "ptx"; std::string ptx; if (const auto* f = Registry::Get("tvm_callback_cuda_compile")) { ptx = (*f)(code).operator std::string(); // Dirty matching to check PTX vs cubin. // TODO(tqchen) more reliable checks if (ptx[0] != '/') fmt = "cubin"; } else { ptx = NVRTCCompile(code); } return CUDAModuleCreate(ptx, fmt, ExtractFuncInfo(funcs), code); } TVM_REGISTER_API("codegen.build_cuda") .set_body([](TVMArgs args, TVMRetValue* rv) { *rv = BuildCUDA(args[0]); }); } // namespace codegen } // namespace tvm #endif // TVM_CUDA_RUNTIME