Commit cf83d50c by Yizhi Liu Committed by Wuwei Lin

[Codegen] remove fp16 function override for cuda (#4331)

* add volatile override back

* [codegen] remove fp16 function override for cuda
parent b127dc76
...@@ -58,15 +58,19 @@ std::string CodeGenCUDA::Finish() { ...@@ -58,15 +58,19 @@ std::string CodeGenCUDA::Finish() {
<< "{\n return __hgt(__half(a), __half(b)) ? a : b;\n}\n"; << "{\n return __hgt(__half(a), __half(b)) ? a : b;\n}\n";
decl_stream << "__device__ half min(half a, half b)\n" decl_stream << "__device__ half min(half a, half b)\n"
<< "{\n return __hlt(__half(a), __half(b)) ? a : b;\n}\n"; << "{\n return __hlt(__half(a), __half(b)) ? a : b;\n}\n";
decl_stream << "__device__ half operator<=" // FIXME(tvm-team): "volatile" is used to enable cross thread reduction,
<< "(__half a, __half b)\n" // which is needed by operations such as softmax.
<< "{\n return __hlt(a, b);\n}\n"; // However, volatile overloading is not supported in NVRTC and CUDA < 9.2.
decl_stream << "__device__ half operator+" // We need to figure out a solution which can satisfy both scenario.
<< "(__half a, __half &b)\n" // decl_stream << "__device__ half operator<="
<<"{\n return __hadd(a, b);\n}\n"; // << "(const volatile __half &a, const volatile __half &b)\n"
decl_stream << "__device__ half operator*" // << "{\n return __hlt(a, b);\n}\n";
<< "(__half a, __half b)\n" // decl_stream << "__device__ half operator+"
<< "{\n return __hmul(a, b);\n}\n"; // << "(const volatile __half &a, const volatile __half &b)\n"
// <<"{\n return __hadd(a, b);\n}\n";
// decl_stream << "__device__ half operator*"
// << "(const volatile __half &a, const volatile __half &b)\n"
// << "{\n return __hmul(a, b);\n}\n";
// otherwise simulate computation via float32 // otherwise simulate computation via float32
decl_stream << "#else\n"; decl_stream << "#else\n";
decl_stream << _cuda_half_t_def; decl_stream << _cuda_half_t_def;
......
...@@ -28,6 +28,7 @@ ...@@ -28,6 +28,7 @@
static constexpr const char* _cuda_half_t_def = R"( static constexpr const char* _cuda_half_t_def = R"(
typedef unsigned short uint16_t; typedef unsigned short uint16_t;
typedef unsigned char uint8_t; typedef unsigned char uint8_t;
typedef signed char int8_t;
typedef int int32_t; typedef int int32_t;
typedef unsigned long long uint64_t; typedef unsigned long long uint64_t;
typedef unsigned int uint32_t; typedef unsigned int uint32_t;
...@@ -76,7 +77,7 @@ class TVM_ALIGNED(2) half { ...@@ -76,7 +77,7 @@ class TVM_ALIGNED(2) half {
TVM_XINLINE explicit half(const uint8_t& value) { constructor(value); } TVM_XINLINE explicit half(const uint8_t& value) { constructor(value); }
TVM_XINLINE explicit half(const int32_t& value) { constructor(value); } TVM_XINLINE explicit half(const int32_t& value) { constructor(value); }
TVM_XINLINE explicit half(const uint32_t& value) { constructor(value); } TVM_XINLINE explicit half(const uint32_t& value) { constructor(value); }
TVM_XINLINE explicit half(const int64_t& value) { constructor(value); } TVM_XINLINE explicit half(const long long& value) { constructor(value); }
TVM_XINLINE explicit half(const uint64_t& value) { constructor(value); } TVM_XINLINE explicit half(const uint64_t& value) { constructor(value); }
TVM_XINLINE operator float() const { \ TVM_XINLINE operator float() const { \
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment