/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file cuda_half_t.h
 * \brief half_t (fp16) definition for cuda codegen.
 */
#ifndef TVM_TARGET_SOURCE_LITERAL_CUDA_HALF_T_H_
#define TVM_TARGET_SOURCE_LITERAL_CUDA_HALF_T_H_

/*!
 * \brief CUDA source text that defines a software `half` class.
 *
 * The text declares fixed-width integer typedefs, the TVM_* helper macros,
 * a 2-byte aligned `half` class that stores the raw fp16 bit pattern in
 * `half_` and converts to/from float with integer bit manipulation, the
 * arithmetic/comparison operators for it, and a `__float2half_rn` wrapper.
 *
 * NOTE: the literal relies on line structure (preprocessor directives,
 * backslash continuations and `//` comments), so every newline inside it is
 * significant.
 */
static constexpr const char* _cuda_half_t_def = R"(
typedef unsigned short uint16_t;
typedef unsigned char uint8_t;
typedef signed char int8_t;
typedef int int32_t;
typedef unsigned long long uint64_t;
typedef unsigned int uint32_t;

#define TVM_FORCE_INLINE inline __attribute__((always_inline))
#define TVM_XINLINE TVM_FORCE_INLINE __device__ __host__
#define TVM_ALIGNED(x) __attribute__ ((aligned(x)))

// Defines the binary operator OP for (half, half), (half, T) and (T, half).
// Both operands are converted to float, OP is applied, and the result is
// converted to RTYPE.
#define TVM_HALF_OPERATOR(RTYPE, OP)                              \
  TVM_XINLINE RTYPE operator OP (half a, half b) {                \
    return RTYPE(float(a) OP float(b));                           \
  }                                                               \
  template<typename T>                                            \
  TVM_XINLINE RTYPE operator OP (half a, T b) {                   \
    return RTYPE(float(a) OP float(b));                           \
  }                                                               \
  template<typename T>                                            \
  TVM_XINLINE RTYPE operator OP (T a, half b) {                   \
    return RTYPE(float(a) OP float(b));                           \
  }

// Defines the compound assignment AOP as a member of half, in a plain and a
// volatile flavour. The arithmetic is carried out in float and the result is
// stored back through operator=.
#define TVM_HALF_ASSIGNOP(AOP, OP)                                \
  template<typename T>                                            \
  TVM_XINLINE half operator AOP (const T& a) {                    \
    return *this = half(float(*this) OP float(a));                \
  }                                                               \
  template<typename T>                                            \
  TVM_XINLINE half operator AOP (const volatile T& a) volatile {  \
    return *this = half(float(*this) OP float(a));                \
  }

// Software fp16 value: holds the raw 16-bit pattern and converts through
// float for every arithmetic operation.
class TVM_ALIGNED(2) half {
 public:
  // Raw fp16 bit pattern.
  uint16_t half_;

  // Builds a half directly from a raw bit pattern (no numeric conversion).
  static TVM_XINLINE half Binary(uint16_t value) {
    half res;
    res.half_ = value;
    return res;
  }

  // Leaves half_ uninitialized.
  TVM_XINLINE half() {}

  // Converting constructors: the argument is cast to float and then encoded
  // with float2half. Only the float overload is implicit.
  TVM_XINLINE half(const float& value) { constructor(value); }
  TVM_XINLINE explicit half(const double& value) { constructor(value); }
  TVM_XINLINE explicit half(const int8_t& value) { constructor(value); }
  TVM_XINLINE explicit half(const uint8_t& value) { constructor(value); }
  TVM_XINLINE explicit half(const int32_t& value) { constructor(value); }
  TVM_XINLINE explicit half(const uint32_t& value) { constructor(value); }
  TVM_XINLINE explicit half(const long long& value) { constructor(value); }
  TVM_XINLINE explicit half(const uint64_t& value) { constructor(value); }

  // Decodes the stored bit pattern to float.
  TVM_XINLINE operator float() const {
    return float(half2float(half_));
  }
  TVM_XINLINE operator float() const volatile {
    return float(half2float(half_));
  }

  TVM_HALF_ASSIGNOP(+=, +)
  TVM_HALF_ASSIGNOP(-=, -)
  TVM_HALF_ASSIGNOP(*=, *)
  TVM_HALF_ASSIGNOP(/=, /)

  TVM_XINLINE half operator+() {
    return *this;
  }

  TVM_XINLINE half operator-() {
    return half(-float(*this));
  }

  // Assignment returns the assigned value by copy rather than a reference.
  TVM_XINLINE half operator=(const half& a) {
    half_ = a.half_;
    return a;
  }

  template<typename T>
  TVM_XINLINE half operator=(const T& a) {
    return *this = half(a);
  }

  TVM_XINLINE half operator=(const half& a) volatile {
    half_ = a.half_;
    return a;
  }

  template<typename T>
  TVM_XINLINE half operator=(const T& a) volatile {
    return *this = half(a);
  }

 private:
  // Views the same 32 bits as float, signed int or unsigned int.
  union Bits {
    float f;
    int32_t si;
    uint32_t ui;
  };

  static int const fp16FractionBits = 10;
  static int const fp32FractionBits = 23;
  static int32_t const fp32FractionMask = ~(~0u << fp32FractionBits);   // == 0x7fffff
  static int32_t const fp32HiddenBit = 1 << fp32FractionBits;   // == 0x800000
  static int const shift = fp32FractionBits - fp16FractionBits;   // == 13
  static int const shiftSign = 16;
  static int32_t const expAdjust = 127 - 15;   // exp32-127 = exp16-15, so exp16 = exp32 - (127-15)

  static int32_t const infN = 0x7F800000;   // flt32 infinity
  static int32_t const maxN = 0x477FFFFF;   // max flt32 that's a flt16 normal after >> by shift
  static int32_t const minN = 0x38800000;   // min flt16 normal as a flt32
  static int32_t const maxZ = 0x33000000;   // max fp32 number that's still rounded to zero in fp16
  static int32_t const signN = 0x80000000;  // flt32 sign bit

  static int32_t const infC = infN >> shift;
  static int32_t const nanN = (infC + 1) << shift;   // minimum flt16 nan as a flt32
  static int32_t const maxC = maxN >> shift;
  static int32_t const minC = minN >> shift;
  static int32_t const signC = signN >> shiftSign;  // flt16 sign bit

  static int32_t const mulN = 0x52000000;  // (1 << 23) / minN
  static int32_t const mulC = 0x33800000;  // minN / (1 << (23 - shift))

  static int32_t const subC = 0x003FF;  // max flt32 subnormal down shifted
  static int32_t const norC = 0x00400;  // min flt32 normal down shifted

  static int32_t const maxD = infC - maxC - 1;
  static int32_t const minD = minC - subC - 1;

  // Encodes a float as an fp16 bit pattern.
  // The sign is split off first; the magnitude is then classified as
  // zero / fp16 denormal / fp16 normal / infinity / nan by comparing its
  // integer bit pattern against the thresholds above. The 13 low fraction
  // bits are dropped after adding half a unit in the last kept place, except
  // on an exact tie whose kept bit is already even (round to nearest even).
  TVM_XINLINE uint16_t float2half(const float& value) const {
    Bits v;
    v.f = value;
    uint32_t sign = v.si & signN;    // grab sign bit
    v.si ^= sign;                    // clear sign bit from v
    sign >>= shiftSign;              // logical shift sign to fp16 position

    if (v.si <= maxZ) {
      // Handle eventual zeros here to ensure
      // vshift will not exceed 32 below.
      v.ui = 0;
    } else if (v.si < minN) {
      // Handle denorms
      uint32_t exp32 = v.ui >> fp32FractionBits;
      int32_t exp16 = exp32 - expAdjust;
      // If exp16 == 0 (just into the denorm range), then significand should be shifted right 1.
      // Smaller (so negative) exp16 values should result in greater right shifts.
      uint32_t vshift = 1 - exp16;
      uint32_t significand = fp32HiddenBit | (v.ui & fp32FractionMask);
      v.ui = significand >> vshift;
      v.ui += (v.ui & 0x3fff) != 0x1000 || (significand & 0x7ff) ? 0x1000 : 0;
    } else if (v.si <= maxN) {
      // Handle norms
      v.ui += (v.ui & 0x3fff) != 0x1000 ? 0x1000 : 0;
      v.ui -= expAdjust << fp32FractionBits;
    } else if (v.si <= infN) {
      v.si = infN;
    } else if (v.si < nanN) {
      v.si = nanN;
    }

    v.ui >>= shift;
    return sign | (v.ui & 0x7fff);
  }

  // Same as above routine, except for addition of volatile keyword
  TVM_XINLINE uint16_t float2half(
    const volatile float& value) const volatile {
    Bits v;
    v.f = value;
    uint32_t sign = v.si & signN;    // grab sign bit
    v.si ^= sign;                    // clear sign bit from v
    sign >>= shiftSign;              // logical shift sign to fp16 position

    if (v.si <= maxZ) {
      // Handle eventual zeros here to ensure
      // vshift will not exceed 32 below.
      v.ui = 0;
    } else if (v.si < minN) {
      // Handle denorms
      uint32_t exp32 = v.ui >> fp32FractionBits;
      int32_t exp16 = exp32 - expAdjust;
      // If exp16 == 0 (just into the denorm range), then significand should be shifted right 1.
      // Smaller (so negative) exp16 values should result in greater right shifts.
      uint32_t vshift = 1 - exp16;
      uint32_t significand = fp32HiddenBit | (v.ui & fp32FractionMask);
      v.ui = significand >> vshift;
      v.ui += (v.ui & 0x3fff) != 0x1000 || (significand & 0x7ff) ? 0x1000 : 0;
    } else if (v.si <= maxN) {
      // Handle norms
      v.ui += (v.ui & 0x3fff) != 0x1000 ? 0x1000 : 0;
      v.ui -= expAdjust << fp32FractionBits;
    } else if (v.si <= infN) {
      v.si = infN;
    } else if (v.si < nanN) {
      v.si = nanN;
    }

    v.ui >>= shift;
    return sign | (v.ui & 0x7fff);
  }

  // Decodes an fp16 bit pattern to float without branching.
  // The exponent is re-biased with masked adds (minD / maxD), and values
  // below norC (fp16 denormals and zero) are instead taken from the float
  // product mulC * pattern, selected through an all-ones / all-zeros mask.
  TVM_XINLINE float half2float(const uint16_t& value) const {
    Bits v;
    v.ui = value;
    int32_t sign = v.si & signC;
    v.si ^= sign;
    sign <<= shiftSign;
    v.si ^= ((v.si + minD) ^ v.si) & -(v.si > subC);
    v.si ^= ((v.si + maxD) ^ v.si) & -(v.si > maxC);
    Bits s;
    s.si = mulC;
    s.f *= v.si;
    int32_t mask = -(norC > v.si);
    v.si <<= shift;
    v.si ^= (s.si ^ v.si) & mask;
    v.si |= sign;
    return v.f;
  }

  // Same as above routine, except for addition of volatile keyword
  TVM_XINLINE float half2float(
    const volatile uint16_t& value) const volatile {
    Bits v;
    v.ui = value;
    int32_t sign = v.si & signC;
    v.si ^= sign;
    sign <<= shiftSign;
    v.si ^= ((v.si + minD) ^ v.si) & -(v.si > subC);
    v.si ^= ((v.si + maxD) ^ v.si) & -(v.si > maxC);
    Bits s;
    s.si = mulC;
    s.f *= v.si;
    int32_t mask = -(norC > v.si);
    v.si <<= shift;
    v.si ^= (s.si ^ v.si) & mask;
    v.si |= sign;
    return v.f;
  }

  // Shared body of the converting constructors.
  template<typename T>
  TVM_XINLINE void constructor(const T& value) {
    half_ = float2half(float(value));
  }
};

TVM_HALF_OPERATOR(half, +)
TVM_HALF_OPERATOR(half, -)
TVM_HALF_OPERATOR(half, *)
TVM_HALF_OPERATOR(half, /)
TVM_HALF_OPERATOR(bool, >)
TVM_HALF_OPERATOR(bool, <)
TVM_HALF_OPERATOR(bool, >=)
TVM_HALF_OPERATOR(bool, <=)

TVM_XINLINE half __float2half_rn(const float a) {
  return half(a);
}
)";

/*!
 * \brief CUDA source text with helper routines that operate on `half`.
 *
 * `__pack_half2(x, y)` reinterprets both arguments as 16-bit patterns and
 * returns them in one 32-bit word with `x` in bits 0-15 and `y` in bits
 * 16-31, i.e. the word a little-endian target reads back from a two-element
 * half vector whose first lane is `x`.
 */
static constexpr const char* _cuda_half_util = R"(
// Pack two half values.
// x (lane 0) goes to the low 16 bits and y (lane 1) to the high 16 bits.
static inline __device__ __host__ unsigned
__pack_half2(const half x, const half y) {
  unsigned v0 = *((unsigned short *)&x);
  unsigned v1 = *((unsigned short *)&y);
  return (v1 << 16) | v0;
}
)";

#endif  // TVM_TARGET_SOURCE_LITERAL_CUDA_HALF_T_H_