Commit e47bc1dd by reminisce, committed by Yizhi Liu

Add __float2half_rn for cuda compute capabilities less than 53 (#4489)

* Fix

* clean up
parent bca7914a
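
On devices of compute capability 5.3 and above, TVM's generated CUDA source includes cuda_fp16.h, which supplies the hardware half type and the __float2half_rn intrinsic. Below that, the generated source falls back to the software half class shown in the hunks below, and until this change that fallback path had no __float2half_rn, so any kernel that casts a float constant to float16 failed to compile. A minimal sketch of the guard pattern involved (illustrative only; this is not the verbatim layout of the emitted header):

// Sketch of the arch guard in the emitted CUDA source (illustrative, not the
// verbatim TVM literal).
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)
#include <cuda_fp16.h>   // hardware half; __float2half_rn comes from here
#else
// The software half class from cuda_half_t.h is emitted here, so
// __float2half_rn has to be defined alongside it for pre-sm_53 targets.
#endif
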
@@ -176,8 +176,10 @@ class TVM_ALIGNED(2) half {
      uint32_t vshift = 1 - exp16;
      uint32_t significand = fp32HiddenBit | (v.ui & fp32FractionMask);
      v.ui = significand >> vshift;
      v.ui += (v.ui & 0x3fff) != 0x1000 || (significand & 0x7ff) ? 0x1000 : 0;
    } else if (v.si <= maxN) {
      // Handle norms
      v.ui += (v.ui & 0x3fff) != 0x1000 ? 0x1000 : 0;
      v.ui -= expAdjust << fp32FractionBits;
    } else if (v.si <= infN) {
      v.si = infN;
@@ -211,8 +213,10 @@ class TVM_ALIGNED(2) half {
      uint32_t vshift = 1 - exp16;
      uint32_t significand = fp32HiddenBit | (v.ui & fp32FractionMask);
      v.ui = significand >> vshift;
      v.ui += (v.ui & 0x3fff) != 0x1000 || (significand & 0x7ff) ? 0x1000 : 0;
    } else if (v.si <= maxN) {
      // Handle norms
      v.ui += (v.ui & 0x3fff) != 0x1000 ? 0x1000 : 0;
      v.ui -= expAdjust << fp32FractionBits;
    } else if (v.si <= infN) {
      v.si = infN;
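
The rounding increments added in both hunks implement round-to-nearest-even when the 23-bit float fraction is narrowed to the 10-bit half fraction, i.e. when the low 13 bits are dropped. Adding 0x1000 (half of the dropped range, 1 << 12) rounds up, and the guard (v.ui & 0x3fff) != 0x1000 skips the increment exactly on a halfway value whose kept least-significant bit is already even; the extra (significand & 0x7ff) term in the subnormal branch rounds up whenever bits were already lost in the significand >> vshift step, since the value is then strictly above the apparent tie. A standalone sketch of the trick on a raw bit pattern (the helper below is illustrative, not part of the header):

#include <cstdint>
#include <cstdio>

// Narrow a value whose low 13 bits are about to be shifted out, using the
// same round-to-nearest-even increment as the lines above.
static uint32_t round_shift13_rn(uint32_t bits) {
  // Skip the increment only on an exact tie (dropped bits == 0x1000) whose
  // kept least-significant bit (bit 13) is already 0, i.e. already even.
  bits += (bits & 0x3fff) != 0x1000 ? 0x1000 : 0;
  return bits >> 13;
}

int main() {
  printf("%u\n", round_shift13_rn(0x1001));  // just above the tie -> 1
  printf("%u\n", round_shift13_rn(0x1000));  // tie, even kept LSB -> 0
  printf("%u\n", round_shift13_rn(0x3000));  // tie, odd kept LSB  -> 2 (rounds to even)
  return 0;
}
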
@@ -275,6 +279,10 @@ TVM_HALF_OPERATOR(bool, >)
TVM_HALF_OPERATOR(bool, <)
TVM_HALF_OPERATOR(bool, >=)
TVM_HALF_OPERATOR(bool, <=)
TVM_XINLINE half __float2half_rn(const float a) {
  return half(a);
}
)";
#endif // TVM_CODEGEN_LITERAL_CUDA_HALF_T_H_
@@ -17,6 +17,7 @@
# under the License.
import tvm
import numpy as np
import unittest
from tvm.contrib.nvcc import have_fp16, have_int8
from tvm.contrib import nvcc
@@ -263,6 +264,32 @@ def test_rfactor_predicates():
    fcuda = tvm.build(s, [A, B], "cuda")

@unittest.skipIf(not tvm.gpu(0).exist or not tvm.module.enabled("cuda"), "skip because cuda is not enabled")
def test_cuda_const_float_to_half():
    # This import is required so that nvcc performs the code generation;
    # otherwise the code generation falls back to nvrtc.
    from tvm import autotvm
    shape = (2, 3, 4)
    a = tvm.placeholder(shape, dtype='float16', name='a')
    b = tvm.const(0.5, dtype='float16')
    c = tvm.compute(shape, lambda i, j, k: a[i, j, k] > b, name='c')
    s = tvm.create_schedule(c.op)
    axes = [axis for axis in c.op.axis]
    fused = s[c].fuse(*axes)
    bx, tx = s[c].split(fused, factor=64)
    s[c].bind(bx, tvm.thread_axis('blockIdx.x'))
    s[c].bind(tx, tvm.thread_axis('threadIdx.x'))
    func = tvm.build(s, [a, c], 'cuda')
    ctx = tvm.gpu(0)
    a_np = np.random.uniform(size=shape).astype(a.dtype)
    c_np = np.zeros(shape=shape, dtype=c.dtype)
    a = tvm.nd.array(a_np, ctx)
    c = tvm.nd.array(c_np, ctx)
    func(a, c)
    np.testing.assert_equal(c.asnumpy(), a_np > b.value)

if __name__ == "__main__":
    test_cuda_vectorize_add()
    test_cuda_multiply_add()
@@ -272,3 +299,4 @@ if __name__ == "__main__":
    test_cuda_shuffle()
    test_cuda_reducition_binding()
    test_rfactor_predicates()
    test_cuda_const_float_to_half()
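
The new test builds a kernel that compares float16 inputs against the float16 constant 0.5 and compiles it with nvcc (the autotvm import switches code generation away from nvrtc). This is exactly the path that needs __float2half_rn: TVM's CUDA codegen emits float16 immediates through that call, so on a pre-sm_53 GPU the generated source contains a kernel roughly of the following shape (the names and guard condition below are illustrative, not the actual generated code):

// Illustrative sketch of the generated comparison kernel for the test above
// (kernel and buffer names are hypothetical).
extern "C" __global__ void default_function_kernel0(half* __restrict__ a,
                                                    bool* __restrict__ c) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < 24) {  // 2 * 3 * 4 elements after fusing and splitting by 64
    c[i] = (a[i] > __float2half_rn(5.000000e-01f));
  }
}

Without the wrapper added above, this call has no definition on targets where cuda_fp16.h is not included, and compilation fails.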