Commit 3dfb8459 by Lianmin Zheng Committed by Tianqi Chen

[CODEGEN] add fp16 and fp64 enable pragma for opencl (#697)

* [CODEGEN] add fp16 and fp64 enable pragma for opencl

* fix style
parent a71beda3
......@@ -272,7 +272,7 @@ void CodeGenC::PrintStorageScope(const std::string& scope, std::ostream& os) { /
CHECK_EQ(scope, "global");
}
void CodeGenC::PrintType(Type t, std::ostream& os) const { // NOLINT(*)
void CodeGenC::PrintType(Type t, std::ostream& os) { // NOLINT(*)
CHECK_EQ(t.lanes(), 1)
<< "do not yet support vector types";
if (t.is_handle()) {
......@@ -402,7 +402,9 @@ inline void PrintBinaryIntrinsitc(const Call* op,
}
}
void CodeGenC::VisitExpr_(const Cast *op, std::ostream& os) { // NOLINT(*)
os << "(";
this->PrintType(op->type, os);
os << ")";
os << '(';
this->PrintExpr(op->value, os);
os << ')';
......
......@@ -118,7 +118,7 @@ class CodeGenC :
* \param t The type representation.
* \param os The stream to print the ctype into
*/
virtual void PrintType(Type t, std::ostream& os) const; // NOLINT(*)
virtual void PrintType(Type t, std::ostream& os); // NOLINT(*)
/*!
* \brief Print expr representing the thread tag
* \param IterVar iv The thread index to be binded;
......
......@@ -45,7 +45,7 @@ void CodeGenCUDA::BindThreadIndex(const IterVar& iv) {
CastFromTo(iv->thread_tag, UInt(32), iv->var.type());
}
void CodeGenCUDA::PrintType(Type t, std::ostream& os) const { // NOLINT(*)
void CodeGenCUDA::PrintType(Type t, std::ostream& os) { // NOLINT(*)
int lanes = t.lanes();
if (t.is_handle()) {
CHECK_EQ(lanes, 1)
......
......@@ -26,7 +26,7 @@ class CodeGenCUDA final : public CodeGenC {
void PrintVecBinaryOp(
const std::string&op, Type t,
Expr lhs, Expr rhs, std::ostream& os) final; // NOLINT(*)
void PrintType(Type t, std::ostream& os) const final; // NOLINT(*)
void PrintType(Type t, std::ostream& os) final; // NOLINT(*)
void PrintVecElemLoad(
const std::string& vec, Type t, int i, std::ostream& os) final; // NOLINT(*)
void PrintVecElemStore(
......
......@@ -132,7 +132,7 @@ void CodeGenMetal::BindThreadIndex(const IterVar& iv) {
CastFromTo(iv->thread_tag, UInt(thread_index_bits_), iv->var.type());
}
void CodeGenMetal::PrintType(Type t, std::ostream& os) const { // NOLINT(*)
void CodeGenMetal::PrintType(Type t, std::ostream& os) { // NOLINT(*)
int lanes = t.lanes();
if (t.is_handle()) {
CHECK_EQ(lanes, 1)
......
......@@ -23,7 +23,7 @@ class CodeGenMetal final : public CodeGenC {
void InitFuncState(LoweredFunc f) final;
void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*)
void PrintStorageSync(const Call* op) final; // NOLINT(*)
void PrintType(Type t, std::ostream& os) const final; // NOLINT(*)
void PrintType(Type t, std::ostream& os) final; // NOLINT(*)
void BindThreadIndex(const IterVar& iv) final; // NOLINT(*)
// overload visitor
void VisitExpr_(const Broadcast* op, std::ostream& os) final; // NOLINT(*)
......
......@@ -30,6 +30,35 @@ void CodeGenOpenCL::AddFunction(LoweredFunc f) {
CodeGenC::AddFunction(f);
}
std::string CodeGenOpenCL::Finish() {
// inject extension enable pragma for fp16 and fp64
if (enable_fp16_) {
decl_stream
<< "#ifdef cl_khr_fp16\n"
"#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n"
"#elif defined(cl_amd_fp16)\n"
"#pragma OPENCL EXTENSION cl_amd_fp16 : enable\n"
"#else\n"
"#error \"Half precision floating point not supported"
"by OpenCL implementation on your device.\" \n"
"#endif\n\n";
}
if (enable_fp64_) {
decl_stream
<< "#ifdef cl_khr_fp64\n"
"#pragma OPENCL EXTENSION cl_khr_fp64 : enable\n"
"#elif defined(cl_amd_fp64)\n"
"#pragma OPENCL EXTENSION cl_amd_fp64 : enable\n"
"#else\n"
"#error \"Double precision floating point not supported"
"by OpenCL implementation on your device.\" \n"
"#endif\n\n";
}
return CodeGenC::Finish();
}
void CodeGenOpenCL::BindThreadIndex(const IterVar& iv) {
CHECK(!var_idmap_.count(iv->var.get()));
runtime::ThreadScope ts = runtime::ThreadScope::make(iv->thread_tag);
......@@ -43,7 +72,7 @@ void CodeGenOpenCL::BindThreadIndex(const IterVar& iv) {
CastFromTo(os.str(), UInt(64), iv->var.type());
}
void CodeGenOpenCL::PrintType(Type t, std::ostream& os) const { // NOLINT(*)
void CodeGenOpenCL::PrintType(Type t, std::ostream& os) { // NOLINT(*)
int lanes = t.lanes();
if (t.is_handle()) {
CHECK_EQ(lanes, 1)
......@@ -53,9 +82,15 @@ void CodeGenOpenCL::PrintType(Type t, std::ostream& os) const { // NOLINT(*)
bool fail = false;
if (t.is_float()) {
switch (t.bits()) {
case 16: os << "half"; break;
case 16:
os << "half";
enable_fp16_ = true;
break;
case 32: os << "float"; break;
case 64: os << "double"; break;
case 64:
os << "double";
enable_fp64_ = true;
break;
default: fail = true; break;
}
if (!fail && lanes == 1) return;
......
......@@ -18,12 +18,14 @@ class CodeGenOpenCL final : public CodeGenC {
public:
CodeGenOpenCL();
void AddFunction(LoweredFunc f);
std::string Finish();
// override print thread tag.
void InitFuncState(LoweredFunc f) final;
void BindThreadIndex(const IterVar& iv) final; // NOLINT(*)
void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*)
void PrintStorageSync(const Call* op) final; // NOLINT(*)
void PrintType(Type t, std::ostream& os) const final; // NOLINT(*)
void PrintType(Type t, std::ostream& os) final; // NOLINT(*)
std::string GetVecLoad(Type t, const Variable* buffer,
Expr base) final;
void PrintVecStore(const Variable* buffer,
......@@ -34,6 +36,11 @@ class CodeGenOpenCL final : public CodeGenC {
Expr base, std::ostream& os); // NOLINT(*)
// overload visitor
void VisitExpr_(const Broadcast* op, std::ostream& os) final; // NOLINT(*)
private:
// whether enable fp16 and fp64 extension
bool enable_fp16_{false};
bool enable_fp64_{false};
};
} // namespace codegen
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment