Commit 30409045 by Tatsuya Nishiyama, committed by Tianqi Chen

[CUDA] FP16 support (#1413)

parent 2cf3fd02
...@@ -167,3 +167,78 @@ def callback_libdevice_path(arch):
    except RuntimeError:
        warnings.warn("Cannot find libdevice path")
        return ""
def parse_compute_version(compute_version):
    """Parse a compute capability string into major and minor version numbers

    Parameters
    ----------
    compute_version : str
        compute capability of a GPU (e.g. "6.0")

    Returns
    -------
    major : int
        major version number
    minor : int
        minor version number
    """
    split_ver = compute_version.split('.')
    try:
        major = int(split_ver[0])
        minor = int(split_ver[1])
        return major, minor
    except (IndexError, ValueError) as err:
        raise RuntimeError("Compute version parsing error: " + str(err))
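A quick usage sketch of the helper above (assuming it is called from `tvm.contrib.nvcc`, the module this hunk patches):

```python
from tvm.contrib.nvcc import parse_compute_version

major, minor = parse_compute_version("6.1")
assert (major, minor) == (6, 1)

try:
    parse_compute_version("bogus")  # no '.'-separated integers
except RuntimeError as err:
    print(err)  # "Compute version parsing error: ..."
```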
def have_fp16(compute_version):
    """Whether fp16 support is provided by the given compute capability

    Parameters
    ----------
    compute_version : str
        compute capability of a GPU (e.g. "6.0")
    """
    major, minor = parse_compute_version(compute_version)
    # fp16 support, in reference to:
    # https://docs.nvidia.com/cuda/cuda-c-programming-guide/#arithmetic-instructions
    if major == 5 and minor == 3:
        return True
    # NOTE: compute capability 6.1 devices are excluded even though they can
    # compute fp16, because they only offer low-rate fp16 performance.
    if major == 6 and minor != 1:
        return True
    if major == 7:
        return True
    return False
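For illustration, the decision table this encodes (a sketch; the architecture set is read off the code above):

```python
from tvm.contrib.nvcc import have_fp16

assert have_fp16("5.3")      # sm_53 (Maxwell) has fp16 arithmetic
assert have_fp16("6.0")      # sm_60 (Pascal)
assert not have_fp16("6.1")  # excluded: low-rate fp16 only
assert have_fp16("7.0")      # sm_70 (Volta)
assert not have_fp16("3.7")  # pre-5.3 parts have no fp16 arithmetic
```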
def have_int8(compute_version):
    """Whether int8 support is provided by the given compute capability

    Parameters
    ----------
    compute_version : str
        compute capability of a GPU (e.g. "6.1")
    """
    major, minor = parse_compute_version(compute_version)
    if major == 6 and minor == 1:
        return True
    return False
def have_tensorcore(compute_version):
    """Whether TensorCore support is provided by the given compute capability

    Parameters
    ----------
    compute_version : str
        compute capability of a GPU (e.g. "7.0")
    """
    major, _ = parse_compute_version(compute_version)
    if major == 7:
        return True
    return False
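And the corresponding checks for the other two helpers (sketch, same caveats):

```python
from tvm.contrib.nvcc import have_int8, have_tensorcore

assert have_int8("6.1") and not have_int8("6.0")          # int8 is 6.1-specific here
assert have_tensorcore("7.0") and not have_tensorcore("6.1")  # TensorCores on 7.x
```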
...@@ -29,6 +29,14 @@ void CodeGenCUDA::AddFunction(LoweredFunc f) {
  CodeGenC::AddFunction(f);
}
std::string CodeGenCUDA::Finish() {
  if (enable_fp16_) {
    decl_stream << "#include <cuda_fp16.h>\n";
  }
  return CodeGenC::Finish();
}
void CodeGenCUDA::VisitStmt_(const ir::For* op) {
  CHECK(is_zero(op->min));
  if (op->for_type == ir::ForType::Unrolled) {
...@@ -54,7 +62,9 @@ void CodeGenCUDA::PrintType(Type t, std::ostream& os) {  // NOLINT(*)
  bool fail = false;
  if (t.is_float()) {
    switch (t.bits()) {
      case 16: os << "half";
        enable_fp16_ = true;
        break;
      case 32: os << "float"; break;
      case 64: os << "double"; break;
      default: fail = true; break;
...@@ -258,5 +268,30 @@ void CodeGenCUDA::VisitExpr_(const Broadcast* op, std::ostream& os) {  // NOLINT(*)
  os << ')';
}
inline void PrintConst(const FloatImm* op, std::ostream& os, CodeGenCUDA* p) { // NOLINT(*)
  switch (op->type.bits()) {
    case 64: case 32: {
      std::ostringstream temp;
      temp << std::scientific << op->value;
      if (op->type.bits() == 32) temp << 'f';
      p->MarkConst(temp.str());
      os << temp.str();
      break;
    }
    case 16: {
      os << "__float2half_rn";
      os << '(' << std::scientific << op->value << 'f' << ')';
      break;
    }
    default: LOG(FATAL) << "Bad bit-width for float: " << op->type << "\n";
  }
}

void CodeGenCUDA::VisitExpr_(const FloatImm *op, std::ostream& os) { // NOLINT(*)
  PrintConst(op, os, this);
}
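The net effect of the two codegen changes, as a sketch (assumes a CUDA-enabled build with a visible GPU; `get_source()` on the imported module returns the generated CUDA C):

```python
import tvm

# Build a tiny fp16 kernel and inspect the CUDA source TVM emits.
n = 16
A = tvm.placeholder((n,), name='A', dtype='float16')
B = tvm.compute((n,), lambda i: A[i] + tvm.const(1, 'float16'), name='B')
s = tvm.create_schedule(B.op)
bx, tx = s[B].split(B.op.axis[0], factor=8)
s[B].bind(bx, tvm.thread_axis("blockIdx.x"))
s[B].bind(tx, tvm.thread_axis("threadIdx.x"))
fun = tvm.build(s, [A, B], "cuda")
src = fun.imported_modules[0].get_source()
assert "#include <cuda_fp16.h>" in src  # emitted by Finish()
assert "__float2half_rn" in src         # fp16 constant from PrintConst
```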
}  // namespace codegen
}  // namespace tvm
...@@ -19,6 +19,8 @@ class CodeGenCUDA final : public CodeGenC {
  CodeGenCUDA();
  void Init(bool output_ssa);
  void AddFunction(LoweredFunc f);
  std::string Finish();
  bool need_include_path() { return enable_fp16_; }
  // override behavior
  void VisitStmt_(const ir::For* op) final;
  void PrintStorageSync(const Call* op) final;
...@@ -35,6 +37,7 @@ class CodeGenCUDA final : public CodeGenC {
  // overload visitor
  void VisitExpr_(const Ramp* op, std::ostream& os) final;  // NOLINT(*)
  void VisitExpr_(const Broadcast* op, std::ostream& os) final;  // NOLINT(*)
  void VisitExpr_(const FloatImm *op, std::ostream& os) final;
  void VisitStmt_(const Evaluate *op) final;

 private:
...@@ -44,6 +47,8 @@ class CodeGenCUDA final : public CodeGenC {
  std::string vid_global_barrier_state_;
  // Global barrier expected node.
  std::string vid_global_barrier_expect_;
  // whether fp16 is enabled
  bool enable_fp16_{false};
};
}  // namespace codegen
...
...@@ -5,14 +5,20 @@
 *
 * \file build_cuda.cc
 */
#if defined(__linux__)
#include <sys/stat.h>
#endif
#include <cuda_runtime.h>
#include <tvm/base.h>
#include <nvrtc.h>
#include <cstdlib>
#include "../codegen_cuda.h" #include "../codegen_cuda.h"
#include "../build_common.h" #include "../build_common.h"
#include "../../runtime/cuda/cuda_common.h" #include "../../runtime/cuda/cuda_common.h"
#include "../../runtime/cuda/cuda_module.h" #include "../../runtime/cuda/cuda_module.h"
namespace tvm {
namespace codegen {
...@@ -26,11 +32,69 @@ namespace codegen {
  } \
}
std::string NVRTCCompile(const std::string& code) {
std::string FindCUDAIncludePath() {
#if defined(_WIN32)
const std::string delimiter = "\\";
#else
const std::string delimiter = "/";
#endif
std::string cuda_include_path;
const char* cuda_path_env = std::getenv("CUDA_PATH");
if (cuda_path_env != nullptr) {
cuda_include_path += cuda_path_env;
cuda_include_path += delimiter + "include";
return cuda_include_path;
}
#if defined(__linux__)
struct stat st;
cuda_include_path = "/usr/local/cuda/include";
if (stat(cuda_include_path.c_str(), &st) == 0) {
return cuda_include_path;
}
#endif
LOG(FATAL) << "Cannot find cuda include path."
<< "CUDA_PATH is not set or CUDA is not installed in the default installation path."
<< "In other than linux, it is necessary to set CUDA_PATH.";
return cuda_include_path;
}
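The lookup order, restated as a rough Python mirror for readability (illustrative only; `find_cuda_include_path` is a hypothetical helper, not part of this patch):

```python
import os

def find_cuda_include_path():
    # 1. honor CUDA_PATH if set (required on non-Linux platforms)
    cuda_path = os.environ.get("CUDA_PATH")
    if cuda_path:
        return os.path.join(cuda_path, "include")
    # 2. fall back to the default Linux install location
    default = "/usr/local/cuda/include"
    if os.path.exists(default):
        return default
    raise RuntimeError("Cannot find CUDA include path; set CUDA_PATH.")
```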
std::string NVRTCCompile(const std::string& code, bool include_path = false) {
std::vector<std::string> compile_params;
std::vector<const char*> param_cstrings{};
int num_options = 0;
nvrtcProgram prog; nvrtcProgram prog;
cudaDeviceProp device_prop;
std::string cc = "30";
cudaError_t e = cudaGetDeviceProperties(&device_prop, 0);
if (e == cudaSuccess) {
cc = std::to_string(device_prop.major) + std::to_string(device_prop.minor);
} else {
LOG(WARNING) << "cannot detect compute capability from your device, "
<< "fall back to compute_30.";
}
compile_params.push_back("-arch=compute_" + cc);
num_options++;
if (include_path) {
std::string include_option = "--include-path=" + FindCUDAIncludePath();
compile_params.push_back(include_option);
num_options++;
}
for (const auto& string : compile_params) {
param_cstrings.push_back(string.c_str());
}
  NVRTC_CALL(nvrtcCreateProgram(
      &prog, code.c_str(), nullptr, 0, nullptr, nullptr));
  nvrtcResult compile_res =
      nvrtcCompileProgram(prog, param_cstrings.size(), param_cstrings.data());
  size_t log_size;
  NVRTC_CALL(nvrtcGetProgramLogSize(prog, &log_size));
  std::string log; log.resize(log_size);
...@@ -43,6 +107,7 @@ std::string NVRTCCompile(const std::string& code) {
  ptx.resize(ptx_size);
  NVRTC_CALL(nvrtcGetPTX(prog, &ptx[0]));
  NVRTC_CALL(nvrtcDestroyProgram(&prog));
  return ptx;
}
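The same compute capability that NVRTCCompile folds into `-arch=compute_XX` is visible from Python as `tvm.gpu(0).compute_version`, which is what the test below feeds to `have_fp16` (sketch, assuming a visible GPU):

```python
import tvm
from tvm.contrib.nvcc import have_fp16

ctx = tvm.gpu(0)
if ctx.exist:
    # e.g. "6.0"; NVRTCCompile turns the major/minor pair into -arch=compute_60
    print(ctx.compute_version, have_fp16(ctx.compute_version))
```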
...@@ -68,7 +133,7 @@ runtime::Module BuildCUDA(Array<LoweredFunc> funcs) {
    // TODO(tqchen) more reliable checks
    if (ptx[0] != '/') fmt = "cubin";
  } else {
    ptx = NVRTCCompile(code, cg.need_include_path());
  }
  return CUDAModuleCreate(ptx, fmt, ExtractFuncInfo(funcs), code);
}
...
import tvm
import numpy as np
from tvm.contrib.nvcc import have_fp16


def test_cuda_vectorize_add():
    num_thread = 8

    def check_cuda(dtype, n, lanes):
        if not tvm.gpu(0).exist or not tvm.module.enabled("cuda"):
            print("skip because cuda is not enabled..")
            return
        if dtype == "float16" and not have_fp16(tvm.gpu(0).compute_version):
            print("skip because gpu does not support fp16")
            return
        A = tvm.placeholder((n,), name='A', dtype="%sx%d" % (dtype, lanes))
        B = tvm.compute((n,), lambda i: A[i] + tvm.const(1, A.dtype), name='B')
        s = tvm.create_schedule(B.op)
        xo, xi = s[B].split(B.op.axis[0], factor=num_thread)
        s[B].bind(xo, tvm.thread_axis("blockIdx.x"))
        s[B].bind(xi, tvm.thread_axis("threadIdx.x"))
        fun = tvm.build(s, [A, B], "cuda")
        ctx = tvm.gpu(0)
        a = tvm.nd.empty((n,), A.dtype, ctx).copyfrom(
            np.random.uniform(size=(n, lanes)))
        c = tvm.nd.empty((n,), B.dtype, ctx)
        fun(a, c)
        np.testing.assert_allclose(c.asnumpy(), a.asnumpy() + 1)

    check_cuda("float32", 64, 2)
    check_cuda("float16", 64, 2)


if __name__ == "__main__":
    test_cuda_vectorize_add()