Commit 3dfb8459 by Lianmin Zheng Committed by Tianqi Chen

[CODEGEN] add fp16 and fp64 enable pragma for opencl (#697)

* [CODEGEN] add fp16 and fp64 enable pragma for opencl

* fix style
parent a71beda3
...@@ -272,7 +272,7 @@ void CodeGenC::PrintStorageScope(const std::string& scope, std::ostream& os) { / ...@@ -272,7 +272,7 @@ void CodeGenC::PrintStorageScope(const std::string& scope, std::ostream& os) { /
CHECK_EQ(scope, "global"); CHECK_EQ(scope, "global");
} }
void CodeGenC::PrintType(Type t, std::ostream& os) const { // NOLINT(*) void CodeGenC::PrintType(Type t, std::ostream& os) { // NOLINT(*)
CHECK_EQ(t.lanes(), 1) CHECK_EQ(t.lanes(), 1)
<< "do not yet support vector types"; << "do not yet support vector types";
if (t.is_handle()) { if (t.is_handle()) {
...@@ -402,7 +402,9 @@ inline void PrintBinaryIntrinsitc(const Call* op, ...@@ -402,7 +402,9 @@ inline void PrintBinaryIntrinsitc(const Call* op,
} }
} }
void CodeGenC::VisitExpr_(const Cast *op, std::ostream& os) { // NOLINT(*) void CodeGenC::VisitExpr_(const Cast *op, std::ostream& os) { // NOLINT(*)
os << "(";
this->PrintType(op->type, os); this->PrintType(op->type, os);
os << ")";
os << '('; os << '(';
this->PrintExpr(op->value, os); this->PrintExpr(op->value, os);
os << ')'; os << ')';
......
...@@ -118,7 +118,7 @@ class CodeGenC : ...@@ -118,7 +118,7 @@ class CodeGenC :
* \param t The type representation. * \param t The type representation.
* \param os The stream to print the ctype into * \param os The stream to print the ctype into
*/ */
virtual void PrintType(Type t, std::ostream& os) const; // NOLINT(*) virtual void PrintType(Type t, std::ostream& os); // NOLINT(*)
/*! /*!
* \brief Print expr representing the thread tag * \brief Print expr representing the thread tag
* \param IterVar iv The thread index to be binded; * \param IterVar iv The thread index to be binded;
......
...@@ -45,7 +45,7 @@ void CodeGenCUDA::BindThreadIndex(const IterVar& iv) { ...@@ -45,7 +45,7 @@ void CodeGenCUDA::BindThreadIndex(const IterVar& iv) {
CastFromTo(iv->thread_tag, UInt(32), iv->var.type()); CastFromTo(iv->thread_tag, UInt(32), iv->var.type());
} }
void CodeGenCUDA::PrintType(Type t, std::ostream& os) const { // NOLINT(*) void CodeGenCUDA::PrintType(Type t, std::ostream& os) { // NOLINT(*)
int lanes = t.lanes(); int lanes = t.lanes();
if (t.is_handle()) { if (t.is_handle()) {
CHECK_EQ(lanes, 1) CHECK_EQ(lanes, 1)
......
...@@ -26,7 +26,7 @@ class CodeGenCUDA final : public CodeGenC { ...@@ -26,7 +26,7 @@ class CodeGenCUDA final : public CodeGenC {
void PrintVecBinaryOp( void PrintVecBinaryOp(
const std::string&op, Type t, const std::string&op, Type t,
Expr lhs, Expr rhs, std::ostream& os) final; // NOLINT(*) Expr lhs, Expr rhs, std::ostream& os) final; // NOLINT(*)
void PrintType(Type t, std::ostream& os) const final; // NOLINT(*) void PrintType(Type t, std::ostream& os) final; // NOLINT(*)
void PrintVecElemLoad( void PrintVecElemLoad(
const std::string& vec, Type t, int i, std::ostream& os) final; // NOLINT(*) const std::string& vec, Type t, int i, std::ostream& os) final; // NOLINT(*)
void PrintVecElemStore( void PrintVecElemStore(
......
...@@ -132,7 +132,7 @@ void CodeGenMetal::BindThreadIndex(const IterVar& iv) { ...@@ -132,7 +132,7 @@ void CodeGenMetal::BindThreadIndex(const IterVar& iv) {
CastFromTo(iv->thread_tag, UInt(thread_index_bits_), iv->var.type()); CastFromTo(iv->thread_tag, UInt(thread_index_bits_), iv->var.type());
} }
void CodeGenMetal::PrintType(Type t, std::ostream& os) const { // NOLINT(*) void CodeGenMetal::PrintType(Type t, std::ostream& os) { // NOLINT(*)
int lanes = t.lanes(); int lanes = t.lanes();
if (t.is_handle()) { if (t.is_handle()) {
CHECK_EQ(lanes, 1) CHECK_EQ(lanes, 1)
......
...@@ -23,7 +23,7 @@ class CodeGenMetal final : public CodeGenC { ...@@ -23,7 +23,7 @@ class CodeGenMetal final : public CodeGenC {
void InitFuncState(LoweredFunc f) final; void InitFuncState(LoweredFunc f) final;
void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*) void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*)
void PrintStorageSync(const Call* op) final; // NOLINT(*) void PrintStorageSync(const Call* op) final; // NOLINT(*)
void PrintType(Type t, std::ostream& os) const final; // NOLINT(*) void PrintType(Type t, std::ostream& os) final; // NOLINT(*)
void BindThreadIndex(const IterVar& iv) final; // NOLINT(*) void BindThreadIndex(const IterVar& iv) final; // NOLINT(*)
// overload visitor // overload visitor
void VisitExpr_(const Broadcast* op, std::ostream& os) final; // NOLINT(*) void VisitExpr_(const Broadcast* op, std::ostream& os) final; // NOLINT(*)
......
...@@ -30,6 +30,35 @@ void CodeGenOpenCL::AddFunction(LoweredFunc f) { ...@@ -30,6 +30,35 @@ void CodeGenOpenCL::AddFunction(LoweredFunc f) {
CodeGenC::AddFunction(f); CodeGenC::AddFunction(f);
} }
std::string CodeGenOpenCL::Finish() {
// inject extension enable pragma for fp16 and fp64
if (enable_fp16_) {
decl_stream
<< "#ifdef cl_khr_fp16\n"
"#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n"
"#elif defined(cl_amd_fp16)\n"
"#pragma OPENCL EXTENSION cl_amd_fp16 : enable\n"
"#else\n"
"#error \"Half precision floating point not supported"
"by OpenCL implementation on your device.\" \n"
"#endif\n\n";
}
if (enable_fp64_) {
decl_stream
<< "#ifdef cl_khr_fp64\n"
"#pragma OPENCL EXTENSION cl_khr_fp64 : enable\n"
"#elif defined(cl_amd_fp64)\n"
"#pragma OPENCL EXTENSION cl_amd_fp64 : enable\n"
"#else\n"
"#error \"Double precision floating point not supported"
"by OpenCL implementation on your device.\" \n"
"#endif\n\n";
}
return CodeGenC::Finish();
}
void CodeGenOpenCL::BindThreadIndex(const IterVar& iv) { void CodeGenOpenCL::BindThreadIndex(const IterVar& iv) {
CHECK(!var_idmap_.count(iv->var.get())); CHECK(!var_idmap_.count(iv->var.get()));
runtime::ThreadScope ts = runtime::ThreadScope::make(iv->thread_tag); runtime::ThreadScope ts = runtime::ThreadScope::make(iv->thread_tag);
...@@ -43,7 +72,7 @@ void CodeGenOpenCL::BindThreadIndex(const IterVar& iv) { ...@@ -43,7 +72,7 @@ void CodeGenOpenCL::BindThreadIndex(const IterVar& iv) {
CastFromTo(os.str(), UInt(64), iv->var.type()); CastFromTo(os.str(), UInt(64), iv->var.type());
} }
void CodeGenOpenCL::PrintType(Type t, std::ostream& os) const { // NOLINT(*) void CodeGenOpenCL::PrintType(Type t, std::ostream& os) { // NOLINT(*)
int lanes = t.lanes(); int lanes = t.lanes();
if (t.is_handle()) { if (t.is_handle()) {
CHECK_EQ(lanes, 1) CHECK_EQ(lanes, 1)
...@@ -53,9 +82,15 @@ void CodeGenOpenCL::PrintType(Type t, std::ostream& os) const { // NOLINT(*) ...@@ -53,9 +82,15 @@ void CodeGenOpenCL::PrintType(Type t, std::ostream& os) const { // NOLINT(*)
bool fail = false; bool fail = false;
if (t.is_float()) { if (t.is_float()) {
switch (t.bits()) { switch (t.bits()) {
case 16: os << "half"; break; case 16:
os << "half";
enable_fp16_ = true;
break;
case 32: os << "float"; break; case 32: os << "float"; break;
case 64: os << "double"; break; case 64:
os << "double";
enable_fp64_ = true;
break;
default: fail = true; break; default: fail = true; break;
} }
if (!fail && lanes == 1) return; if (!fail && lanes == 1) return;
......
...@@ -18,12 +18,14 @@ class CodeGenOpenCL final : public CodeGenC { ...@@ -18,12 +18,14 @@ class CodeGenOpenCL final : public CodeGenC {
public: public:
CodeGenOpenCL(); CodeGenOpenCL();
void AddFunction(LoweredFunc f); void AddFunction(LoweredFunc f);
std::string Finish();
// override print thread tag. // override print thread tag.
void InitFuncState(LoweredFunc f) final; void InitFuncState(LoweredFunc f) final;
void BindThreadIndex(const IterVar& iv) final; // NOLINT(*) void BindThreadIndex(const IterVar& iv) final; // NOLINT(*)
void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*) void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*)
void PrintStorageSync(const Call* op) final; // NOLINT(*) void PrintStorageSync(const Call* op) final; // NOLINT(*)
void PrintType(Type t, std::ostream& os) const final; // NOLINT(*) void PrintType(Type t, std::ostream& os) final; // NOLINT(*)
std::string GetVecLoad(Type t, const Variable* buffer, std::string GetVecLoad(Type t, const Variable* buffer,
Expr base) final; Expr base) final;
void PrintVecStore(const Variable* buffer, void PrintVecStore(const Variable* buffer,
...@@ -34,6 +36,11 @@ class CodeGenOpenCL final : public CodeGenC { ...@@ -34,6 +36,11 @@ class CodeGenOpenCL final : public CodeGenC {
Expr base, std::ostream& os); // NOLINT(*) Expr base, std::ostream& os); // NOLINT(*)
// overload visitor // overload visitor
void VisitExpr_(const Broadcast* op, std::ostream& os) final; // NOLINT(*) void VisitExpr_(const Broadcast* op, std::ostream& os) final; // NOLINT(*)
private:
// whether enable fp16 and fp64 extension
bool enable_fp16_{false};
bool enable_fp64_{false};
}; };
} // namespace codegen } // namespace codegen
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment