Commit ce72e9b5 authored by Xingyu Zhou, committed by Wuwei Lin

[codegen] Add multiple operands and function support when using fp16 compilation (#4056)

* overload half operators for cuda codegen

* add float16 te test_op_level1

* fix test_op_level1.py

* fix lint

* disable fp16 test if gpu does not support

* disable fp16 test if gpu does not support

* bypass float16 test if gpu does not support float16
parent d08ec106
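The overloads added by this commit let CodeGenCUDA emit kernels whose operands and intermediates are half. Below is a minimal sketch of the kind of workload this enables, written against the 0.6-era Python API that the test diff uses; the shapes, names, and schedule are illustrative assumptions, and it only runs on a CUDA GPU that reports fp16 support.

import numpy as np
import tvm
from tvm.contrib.nvcc import have_fp16

# Sketch: compile a float16 elementwise add for CUDA (names/shapes assumed).
n = 1024
A = tvm.placeholder((n,), dtype="float16", name="A")
B = tvm.placeholder((n,), dtype="float16", name="B")
C = tvm.compute((n,), lambda i: A[i] + B[i], name="C")

s = tvm.create_schedule(C.op)
bx, tx = s[C].split(C.op.axis[0], factor=64)
s[C].bind(bx, tvm.thread_axis("blockIdx.x"))
s[C].bind(tx, tvm.thread_axis("threadIdx.x"))

ctx = tvm.gpu(0)
if ctx.exist and have_fp16(ctx.compute_version):
    fadd = tvm.build(s, [A, B, C], "cuda", name="fp16_add")
    a = tvm.nd.array(np.random.uniform(size=n).astype("float16"), ctx)
    b = tvm.nd.array(np.random.uniform(size=n).astype("float16"), ctx)
    c = tvm.nd.array(np.zeros(n, dtype="float16"), ctx)
    fadd(a, b, c)
    np.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy(), rtol=1e-3)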
@@ -50,6 +50,20 @@ void CodeGenCUDA::AddFunction(LoweredFunc f) {
 std::string CodeGenCUDA::Finish() {
   if (enable_fp16_) {
     decl_stream << "#include <cuda_fp16.h>\n";
+    decl_stream << "__device__ half max" \
+    "(const half a, const half b)\n"
+    "{\n  return __hgt(__half(a), __half(b)) ? a : b;\n}\n";
+    decl_stream << "__device__ half min(const half a, const half b)\n"
+    "{\n  return __hlt(__half(a), __half(b)) ? a : b;\n}\n";
+    decl_stream << "__device__ half operator+" \
+    "(const volatile __half &a, const volatile __half &b)\n"
+    "{\n  return __hadd(a, b);\n}\n";
+    decl_stream << "__device__ half operator<=" \
+    "(const volatile __half &a, const volatile __half &b)\n"
+    "{\n  return __hlt(a, b);\n}\n";
+    decl_stream << "__device__ half operator*" \
+    "(const volatile __half &a, const volatile __half &b)\n"
+    "{\n  return __hmul(a, b);\n}\n";
   }
   if (enable_int8_) {
...
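The preamble above is injected right after the cuda_fp16.h include whenever the kernel touches fp16. The half max/min helpers cover elementwise maximum/minimum, and the const volatile __half overloads appear to be there for code paths that read half values through volatile references (e.g. shared-memory reductions), which cuda_fp16.h alone does not provide. One way to confirm the helpers are emitted is to dump the generated CUDA source; a hedged sketch, reusing fadd from the snippet above, follows.

# Sketch: inspect the CUDA source produced for the fp16 kernel built above.
# The injected "__device__ half ..." helpers should appear near the top,
# right after the '#include <cuda_fp16.h>' line.
cuda_src = fadd.imported_modules[0].get_source()
print(cuda_src)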
@@ -19,6 +19,7 @@ import numpy as np
 import tvm
 import topi
 import topi.testing
+from tvm.contrib.nvcc import have_fp16
 from common import get_all_backend
@@ -53,6 +54,9 @@ def verify_reinterpret(in_shape, in_dtype, out_dtype, generator):
         if not ctx.exist:
             print("Skip because %s is not enabled" % device)
             return
+        if in_dtype == "float16" and device == 'cuda' and not have_fp16(ctx.compute_version):
+            print("Skip because %s does not have fp16 support" % device)
+            return
         print("Running on target: %s" % device)
         with tvm.target.create(device):
             s = topi.generic.schedule_elemwise(B)
...
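The guard added to the test mirrors a pattern any CUDA test that touches float16 can reuse: check ctx.exist first, then consult tvm.contrib.nvcc.have_fp16 with the device's compute capability. A small standalone helper along these lines is sketched below; the name skip_if_no_fp16 is hypothetical and not part of this commit.

import tvm
from tvm.contrib.nvcc import have_fp16

def skip_if_no_fp16(device, dtype):
    """Return True when a test using `dtype` should be skipped on `device`."""
    ctx = tvm.context(device, 0)
    if not ctx.exist:
        print("Skip because %s is not enabled" % device)
        return True
    if dtype == "float16" and device == "cuda" and not have_fp16(ctx.compute_version):
        print("Skip because %s does not have fp16 support" % device)
        return True
    return False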