Commit cf83d50c by Yizhi Liu Committed by Wuwei Lin

[Codegen] remove fp16 function override for cuda (#4331)

* add volatile override back

* [codegen] remove fp16 function override for cuda
parent b127dc76
......@@ -58,15 +58,19 @@ std::string CodeGenCUDA::Finish() {
<< "{\n return __hgt(__half(a), __half(b)) ? a : b;\n}\n";
decl_stream << "__device__ half min(half a, half b)\n"
<< "{\n return __hlt(__half(a), __half(b)) ? a : b;\n}\n";
decl_stream << "__device__ half operator<="
<< "(__half a, __half b)\n"
<< "{\n return __hlt(a, b);\n}\n";
decl_stream << "__device__ half operator+"
<< "(__half a, __half &b)\n"
<<"{\n return __hadd(a, b);\n}\n";
decl_stream << "__device__ half operator*"
<< "(__half a, __half b)\n"
<< "{\n return __hmul(a, b);\n}\n";
// FIXME(tvm-team): "volatile" is used to enable cross thread reduction,
// which is needed by operations such as softmax.
// However, volatile overloading is not supported in NVRTC and CUDA < 9.2.
// We need to figure out a solution which can satisfy both scenario.
// decl_stream << "__device__ half operator<="
// << "(const volatile __half &a, const volatile __half &b)\n"
// << "{\n return __hlt(a, b);\n}\n";
// decl_stream << "__device__ half operator+"
// << "(const volatile __half &a, const volatile __half &b)\n"
// <<"{\n return __hadd(a, b);\n}\n";
// decl_stream << "__device__ half operator*"
// << "(const volatile __half &a, const volatile __half &b)\n"
// << "{\n return __hmul(a, b);\n}\n";
// otherwise simulate computation via float32
decl_stream << "#else\n";
decl_stream << _cuda_half_t_def;
......
......@@ -28,6 +28,7 @@
static constexpr const char* _cuda_half_t_def = R"(
typedef unsigned short uint16_t;
typedef unsigned char uint8_t;
typedef signed char int8_t;
typedef int int32_t;
typedef unsigned long long uint64_t;
typedef unsigned int uint32_t;
......@@ -76,7 +77,7 @@ class TVM_ALIGNED(2) half {
TVM_XINLINE explicit half(const uint8_t& value) { constructor(value); }
TVM_XINLINE explicit half(const int32_t& value) { constructor(value); }
TVM_XINLINE explicit half(const uint32_t& value) { constructor(value); }
TVM_XINLINE explicit half(const int64_t& value) { constructor(value); }
TVM_XINLINE explicit half(const long long& value) { constructor(value); }
TVM_XINLINE explicit half(const uint64_t& value) { constructor(value); }
TVM_XINLINE operator float() const { \
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment