Commit 24fe04f8 by lixiaoquan Committed by Tianqi Chen

[CODEGEN][CUDA][OPENCL] Handle INF and NAN (#3194)

parent 246b4109
...@@ -57,6 +57,10 @@ std::string CodeGenCUDA::Finish() { ...@@ -57,6 +57,10 @@ std::string CodeGenCUDA::Finish() {
decl_stream << "#include <sm_61_intrinsics.h>\n"; decl_stream << "#include <sm_61_intrinsics.h>\n";
} }
if (need_math_constants_h_) {
decl_stream << "#include <math_constants.h>\n";
}
return CodeGenC::Finish(); return CodeGenC::Finish();
} }
...@@ -318,8 +322,19 @@ inline void PrintConst(const FloatImm* op, std::ostream& os, CodeGenCUDA* p) { / ...@@ -318,8 +322,19 @@ inline void PrintConst(const FloatImm* op, std::ostream& os, CodeGenCUDA* p) { /
switch (op->type.bits()) { switch (op->type.bits()) {
case 64: case 32: { case 64: case 32: {
std::ostringstream temp; std::ostringstream temp;
temp << std::scientific << op->value; if (std::isinf(op->value)) {
if (op->type.bits() == 32) temp << 'f'; if (op->value < 0) {
temp << "-";
}
temp << ((op->type.bits() == 32) ? "CUDART_INF_F" : "CUDART_INF");
p->need_math_constants_h_ = true;
} else if (std::isnan(op->value)) {
temp << ((op->type.bits() == 32) ? "CUDART_NAN_F" : "CUDART_NAN");
p->need_math_constants_h_ = true;
} else {
temp << std::scientific << op->value;
if (op->type.bits() == 32) temp << 'f';
}
p->MarkConst(temp.str()); p->MarkConst(temp.str());
os << temp.str(); os << temp.str();
break; break;
......
...@@ -39,7 +39,9 @@ class CodeGenCUDA final : public CodeGenC { ...@@ -39,7 +39,9 @@ class CodeGenCUDA final : public CodeGenC {
void Init(bool output_ssa); void Init(bool output_ssa);
void AddFunction(LoweredFunc f); void AddFunction(LoweredFunc f);
std::string Finish(); std::string Finish();
bool need_include_path() { return (enable_fp16_ || enable_int8_); } bool need_include_path() {
return (enable_fp16_ || enable_int8_ || need_math_constants_h_);
}
// override behavior // override behavior
void VisitStmt_(const ir::For* op) final; void VisitStmt_(const ir::For* op) final;
void PrintStorageSync(const Call* op) final; void PrintStorageSync(const Call* op) final;
...@@ -70,6 +72,9 @@ class CodeGenCUDA final : public CodeGenC { ...@@ -70,6 +72,9 @@ class CodeGenCUDA final : public CodeGenC {
bool enable_fp16_{false}; bool enable_fp16_{false};
// whether enable int8 // whether enable int8
bool enable_int8_{false}; bool enable_int8_{false};
// whether need math_constants.h
bool need_math_constants_h_{false};
friend void PrintConst(const FloatImm* op, std::ostream& os, CodeGenCUDA* p);
}; };
} // namespace codegen } // namespace codegen
......
...@@ -247,6 +247,19 @@ void CodeGenOpenCL::VisitExpr_(const Select* op, std::ostream& os) { // NOLINT( ...@@ -247,6 +247,19 @@ void CodeGenOpenCL::VisitExpr_(const Select* op, std::ostream& os) { // NOLINT(
CodeGenC::VisitExpr_(op, os); CodeGenC::VisitExpr_(op, os);
} }
void CodeGenOpenCL::VisitExpr_(const FloatImm *op, std::ostream& os) { // NOLINT(*)
if (std::isinf(op->value)) {
if (op->value < 0) {
os << "-";
}
os << "INFINITY";
} else if (std::isnan(op->value)) {
os << "NAN";
} else {
CodeGenC::VisitExpr_(op, os);
}
}
runtime::Module BuildOpenCL(Array<LoweredFunc> funcs) { runtime::Module BuildOpenCL(Array<LoweredFunc> funcs) {
using tvm::runtime::Registry; using tvm::runtime::Registry;
bool output_ssa = false; bool output_ssa = false;
......
...@@ -59,6 +59,7 @@ class CodeGenOpenCL final : public CodeGenC { ...@@ -59,6 +59,7 @@ class CodeGenOpenCL final : public CodeGenC {
void VisitExpr_(const Broadcast* op, std::ostream& os) final; // NOLINT(*) void VisitExpr_(const Broadcast* op, std::ostream& os) final; // NOLINT(*)
void VisitExpr_(const Call* op, std::ostream& os) final; // NOLINT(*) void VisitExpr_(const Call* op, std::ostream& os) final; // NOLINT(*)
void VisitExpr_(const Select* op, std::ostream& os) final; // NOLINT(*) void VisitExpr_(const Select* op, std::ostream& os) final; // NOLINT(*)
void VisitExpr_(const FloatImm *op, std::ostream& os) final; // NOLINT(*)
private: private:
// whether enable fp16 and fp64 extension // whether enable fp16 and fp64 extension
......
...@@ -125,8 +125,38 @@ def test_cuda_make_int8x4(): ...@@ -125,8 +125,38 @@ def test_cuda_make_int8x4():
check_cuda(64, 0) check_cuda(64, 0)
check_cuda(64, -3) check_cuda(64, -3)
def test_cuda_inf_nan():
target = 'cuda'
def check_inf_nan(ctx, n, value, dtype):
A = tvm.placeholder((n,), name='A', dtype=dtype)
inf_value = tvm.const(value, dtype=dtype)
C = tvm.compute((n,), lambda i: inf_value, name='C')
s = tvm.create_schedule(C.op)
s[C].bind(s[C].op.axis[0], tvm.thread_axis("threadIdx.x"))
fun = tvm.build(s, [A, C], target)
a = tvm.nd.empty((n,), A.dtype, ctx)
c = tvm.nd.empty((n,), A.dtype, ctx)
# Only need to test compiling here
fun(a, c)
if not tvm.gpu(0).exist or not tvm.module.enabled("cuda"):
print("skip because cuda is not enabled..")
return
ctx = tvm.context(target, 0)
check_inf_nan(ctx, 1, -float('inf'), 'float32')
check_inf_nan(ctx, 1, -float('inf'), 'float64')
check_inf_nan(ctx, 1, float('inf'), 'float32')
check_inf_nan(ctx, 1, float('inf'), 'float64')
check_inf_nan(ctx, 1, float('nan'), 'float32')
check_inf_nan(ctx, 1, float('nan'), 'float64')
if __name__ == "__main__": if __name__ == "__main__":
test_cuda_vectorize_add() test_cuda_vectorize_add()
test_cuda_multiply_add() test_cuda_multiply_add()
test_cuda_vectorize_load() test_cuda_vectorize_load()
test_cuda_make_int8x4() test_cuda_make_int8x4()
test_cuda_inf_nan()
...@@ -66,6 +66,33 @@ def test_opencl_ternary_expression(): ...@@ -66,6 +66,33 @@ def test_opencl_ternary_expression():
check_select(ctx, 1, 'int16') check_select(ctx, 1, 'int16')
check_select(ctx, 1, 'uint16') check_select(ctx, 1, 'uint16')
def test_opencl_inf_nan():
def check_inf_nan(ctx, n, value, dtype):
A = tvm.placeholder((n,), name='A', dtype=dtype)
inf_value = tvm.const(value, dtype=dtype)
C = tvm.compute((n,), lambda i: inf_value, name='C')
s = tvm.create_schedule(C.op)
s[C].bind(s[C].op.axis[0], tvm.thread_axis("threadIdx.x"))
fun = tvm.build(s, [A, C], target)
a = tvm.nd.empty((n,), A.dtype, ctx)
c = tvm.nd.empty((n,), A.dtype, ctx)
# Only need to test compiling here
fun(a, c)
if not tvm.module.enabled(target):
print("skip because opencl is not enabled..")
return
ctx = tvm.context(target, 0)
check_inf_nan(ctx, 1, -float('inf'), 'float32')
check_inf_nan(ctx, 1, -float('inf'), 'float64')
check_inf_nan(ctx, 1, float('inf'), 'float32')
check_inf_nan(ctx, 1, float('inf'), 'float64')
check_inf_nan(ctx, 1, float('nan'), 'float32')
check_inf_nan(ctx, 1, float('nan'), 'float64')
if __name__ == "__main__": if __name__ == "__main__":
test_opencl_ternary_expression() test_opencl_ternary_expression()
test_opencl_inf_nan()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment