/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file codegen_opengl.cc
 *
 * We are targeting OpenGL 3.3. The reason for not targeting a more recent
 * version of OpenGL is to have better compatibility with WebGL 2.
 */
#include <tvm/packed_func_ext.h>
#include <vector>
#include <string>
#include <utility>
#include <unordered_map>
#include "codegen_opengl.h"
#include "build_common.h"
#include "../runtime/thread_storage_scope.h"

namespace tvm {
namespace codegen {

/*! \brief Construct a code generator with no output texture and no thread iter var bound. */
CodeGenOpenGL::CodeGenOpenGL()
    : output_(nullptr), output_iter_var_(nullptr) {}

/*!
 * \brief Reset all per-function state before generating a new shader.
 *
 * Besides the base-class state, this forgets the recorded output texture,
 * the set of input textures, the bound thread iter var and the thread-extent
 * uniform name, and empties both the declaration and the body streams.
 *
 * \param f The function about to be generated.
 */
void CodeGenOpenGL::InitFuncState(LoweredFunc f) {
  CodeGenC::InitFuncState(f);
  output_ = nullptr;
  inputs_.clear();
  output_iter_var_ = nullptr;
  thread_extent_var_ = "";
  this->decl_stream.str("");
  this->stream.str("");
}

/*!
 * \brief Generate one GLSL fragment shader for `f` and store it in `shaders_`.
 *
 * The body of `main()` is printed first. While printing, the visitors record
 * which handles are read (`inputs_`) and which one is written (`output_`).
 * Only afterwards can each argument be declared as an input sampler, the
 * output variable, or a plain uniform.
 *
 * \param f The lowered function to translate.
 */
void CodeGenOpenGL::AddFunction(LoweredFunc f) {
  // clear previous generated state.
  this->InitFuncState(f);

  this->decl_stream << "#version 300 es\n";
  this->decl_stream << "precision highp float;\n";

  // skip the first underscore, so SSA variable starts from _1
  GetUniqueName("_");
  // add to alloc buffer type.
  for (const auto& kv : f->handle_data_type) {
    RegisterHandleType(kv.first.get(), kv.second.dtype());
  }

  // Allocate argument names. Store in `var_idmap_`.
  for (auto arg : f->args) {
    auto arg_name = GetUniqueName(arg.get()->name_hint);
    var_idmap_[arg.get()] = arg_name;
  }

  thread_extent_var_ = GetUniqueName("thread_extent");
  this->decl_stream << "uniform int " << thread_extent_var_ << ";\n";

  this->stream << "void main() {\n";

  int func_scope = this->BeginScope();
  this->PrintStmt(f->body);
  this->EndScope(func_scope);

  this->PrintIndent();
  this->stream << "}\n\n";

  // Declare arguments.
  for (auto arg : f->args) {
    if (this->inputs_.find(arg.get()) != this->inputs_.cend()) {
      // Declare input texture.
      // Format:
      // - Float: "uniform sampler2D {name};"
      // - Int: "uniform isampler2D {name};"
      // - UInt: "uniform usampler2D {name};"

      auto arg_name = GetVarID(arg.get());

      auto type_it = this->handle_data_type_.find(arg.get());
      CHECK(type_it != this->handle_data_type_.cend()) << "Cannot find type.";
      DLDataType type = type_it->second;
      CHECK_EQ(type.lanes, 1) << "Vector type not supported.";

      switch (type.code) {
        case kDLInt:
          this->decl_stream << "uniform isampler2D " << arg_name << ";\n";
          break;
        case kDLUInt:
          this->decl_stream << "uniform usampler2D " << arg_name << ";\n";
          break;
        case kDLFloat:
          this->decl_stream << "uniform sampler2D " << arg_name << ";\n";
          break;
        default:
          LOG(FATAL) << "Unsupported type code.";
      }

    } else if (this->output_ == arg.get()) {
      // Declare output texture.
      // Format: "out {type} {name};"

      auto arg_name = GetVarID(arg.get());

      auto type_it = this->handle_data_type_.find(arg.get());
      CHECK(type_it != this->handle_data_type_.cend()) << "Cannot find type.";
      auto type = type_it->second;

      this->decl_stream << "out ";
      PrintType(type, this->decl_stream);
      this->decl_stream << " " << arg_name << ";\n";

    } else {
      // Declare uniform value.
      // Format: "uniform {type} {name};"

      auto arg_name = GetVarID(arg.get());
      auto type = arg.get()->dtype;

      this->decl_stream << "uniform ";
      PrintType(type, this->decl_stream);
      this->decl_stream << " " << arg_name << ";\n";
    }
  }

  // Record, in argument order, the generated name and the kind of each
  // argument so that the runtime shader object carries them.
  std::vector<std::string> arg_names;
  std::vector<runtime::OpenGLArgKind> arg_kinds;
  for (auto arg : f->args) {
    std::string name = GetVarID(arg.get());

    runtime::OpenGLArgKind kind;
    if (inputs_.find(arg.get()) != inputs_.cend()) {
      kind = runtime::OpenGLArgKind::kInputTexture;
    } else if (output_ == arg.get()) {
      kind = runtime::OpenGLArgKind::kOutputTexture;
    } else {
      kind = runtime::OpenGLArgKind::kUniform;
    }

    arg_names.push_back(name);
    arg_kinds.push_back(kind);
  }

  // The shader source is the declarations followed by the body.
  shaders_[f->name] = runtime::OpenGLShader(
      this->decl_stream.str() + this->stream.str(),
      std::move(arg_names), std::move(arg_kinds),
      this->thread_extent_var_);
}

/*!
 * \brief Return every shader generated so far, keyed by function name.
 * \return A copy of the internal name-to-shader map.
 */
std::unordered_map<std::string, runtime::OpenGLShader> CodeGenOpenGL::Finish() {
  return shaders_;
}

/*!
 * \brief Bind the single supported thread iter var ("threadIdx.x").
 *
 * Emits a local `threadIdx` computed from `gl_FragCoord` and an early return
 * for fragments whose index is not below the thread-extent uniform.
 *
 * \param iv The thread iter var; its tag must be "threadIdx.x" and it must be
 *           the only thread iter var bound in the current function.
 */
void CodeGenOpenGL::BindThreadIndex(const IterVar& iv) {
  CHECK_EQ(iv->thread_tag, "threadIdx.x") << "Must be threadIdx.x";
  CHECK(var_idmap_.find(iv->var.get()) == var_idmap_.end())
      << "Only support one thread iter var";
  CHECK(output_iter_var_ == nullptr) << "Only support one thread iter var";

  var_idmap_[iv->var.get()] = iv->thread_tag;
  output_iter_var_ = iv->var.get();

  // Declare threadIdx local variable.
  this->PrintIndent();
  this->stream << "ivec2 threadIdx = ivec2(" << runtime::kTextureRowSize
               << " * int(gl_FragCoord.y) + int(gl_FragCoord.x), 0);\n";

  // Return directly if threadIdx.x >= thread_extent.
  this->PrintIndent();
  this->stream << "if (threadIdx.x >= " << thread_extent_var_ << ") {\n";
  this->PrintIndent();
  this->stream << " return;\n";
  this->PrintIndent();
  this->stream << "}\n";
}

/*! \brief Plain stores are rejected; texture stores arrive as Call statements. */
void CodeGenOpenGL::VisitStmt_(const Store* op) {
  LOG(FATAL) << "Store statement not supported in OpenGL."
             << " Texture store should be a Call statement.";
}

/*!
 * \brief Build the GLSL expression that reads one element of an input texture.
 *
 * Format:
 * texelFetch(tex, ivec2(idx & kTextureRowMask, idx >> kTextureRowBits), 0).r
 *
 * \param buffer The texture variable to read from.
 * \param index The linear element index; it is printed twice, once per
 *              coordinate.
 * \return The GLSL expression as a string.
 */
std::string CodeGenOpenGL::TexelFetch(const Variable* buffer, Expr index) {
  std::ostringstream os;
  os << "texelFetch(" << GetVarID(buffer) << ", ivec2(int(";
  PrintExpr(index, os);
  os << ") & " << runtime::kTextureRowMask << ", int(";
  PrintExpr(index, os);
  os << ") >> " << runtime::kTextureRowBits << "), 0).r";
  return os.str();
}

/*!
 * \brief Print a reference expression to a buffer.
 *
 * The output texture is referenced by its plain variable name. Any other
 * buffer is recorded as an input texture and read through `TexelFetch`.
 *
 * \param t The element type; must be scalar and match the handle type.
 * \param buffer The buffer variable.
 * \param index The element index.
 * \return The GLSL expression referring to the element.
 */
std::string CodeGenOpenGL::GetBufferRef(
    DataType t, const Variable* buffer, Expr index) {
  CHECK_EQ(t.lanes(), 1) << "Vector type not supported.";
  CHECK(HandleTypeMatch(buffer, t)) << "Type mismatch not supported.";

  if (buffer == this->output_) {
    // This is the output texture.
    return GetVarID(buffer);
  } else {
    // This is an input texture.
    this->inputs_.insert(buffer);
    return TexelFetch(buffer, index);
  }
}

/*!
 * \brief Print the GLSL name of a data type.
 *
 * Only 32-bit int, uint and float are accepted; anything else is fatal.
 *
 * \param t The data type to print.
 * \param os The stream to print to.
 */
void CodeGenOpenGL::PrintType(DataType t, std::ostream& os) {
  switch (t.code()) {
    case kDLInt:
      CHECK_EQ(t.bits(), 32) << "Only support 32-bit int.";
      os << "int";
      break;
    case kDLUInt:
      CHECK_EQ(t.bits(), 32) << "Only support 32-bit uint.";
      os << "uint";
      break;
    case kDLFloat:
      CHECK_EQ(t.bits(), 32) << "Only support 32-bit float.";
      os << "float";
      break;
    default:
      LOG(FATAL) << "Unsupported type code.";
  }
}

// Codegen for immediate values

/*! \brief Print an integer immediate; only int32 is accepted. */
void CodeGenOpenGL::VisitExpr_(const IntImm* op, std::ostream& os) {
  CHECK_EQ(op->dtype, DataType::Int(32)) << "GLSL 3.0 only supports 32-bit ints.";
  CodeGenC::VisitExpr_(op, os);
}

/*! \brief Print an unsigned integer immediate; only uint32 is accepted. */
void CodeGenOpenGL::VisitExpr_(const UIntImm* op, std::ostream& os) {
  CHECK_EQ(op->dtype, DataType::UInt(32)) << "GLSL 3.0 only supports 32-bit uints.";
  CodeGenC::VisitExpr_(op, os);
}

/*! \brief Print a floating-point immediate; only float32 is accepted. */
void CodeGenOpenGL::VisitExpr_(const FloatImm* op, std::ostream& os) {
  CHECK_EQ(op->dtype, DataType::Float(32)) << "GLSL 3.0 only supports 32-bit floats.";
  CodeGenC::VisitExpr_(op, os);
}

/*! \brief String immediates are rejected. */
void CodeGenOpenGL::VisitExpr_(const StringImm*, std::ostream& os) {
  LOG(FATAL) << "GLSL 3.0 doesn't support strings.";
}

/*!
 * \brief Handle an Evaluate statement, special-casing texture stores.
 *
 * A call to `Call::glsl_texture_store` with arguments (buffer, value) is
 * printed as "{buffer} = {value};" and the buffer is recorded as the output
 * texture. Every other Evaluate statement is delegated to `CodeGenC`.
 *
 * \param op The Evaluate statement.
 */
void CodeGenOpenGL::VisitStmt_(const Evaluate* op) {
  auto call = op->value.as<Call>();
  if (call == nullptr || call->name != Call::glsl_texture_store) {
    // Fallback to normal logic.
    CodeGenC::VisitStmt_(op);
    // Must stop here: `call` may be null, and a non-store call has already
    // been emitted by the base class.
    return;
  }

  CHECK_EQ(call->args.size(), 2);
  auto buffer = call->args[0].as<Variable>();
  CHECK(buffer != nullptr) << "Texture store target must be a variable.";
  auto value = call->args[1];

  // Doesn't support store to vector.
  auto type = value.dtype();
  CHECK_EQ(type.lanes(), 1)
      << "Vectorized store not implemented, type = " << type;

  CHECK(inputs_.find(buffer) == inputs_.cend())
      << "Texture has been read from before. Must not store to it.";
  if (output_ == nullptr) {
    output_ = buffer;  // Record that this texture is the output.
  } else {
    CHECK(output_ == buffer) << "GLSL can only write to 1 texture.";
  }

  this->PrintIndent();
  this->stream << GetVarID(buffer) << " = " << PrintExpr(value) << ";\n";
}

/*!
 * \brief Generate shaders for all functions and wrap them in a runtime module.
 * \param funcs The lowered functions to compile.
 * \return The OpenGL runtime module created from the generated shaders.
 */
runtime::Module BuildOpenGL(Array<LoweredFunc> funcs) {
  bool output_ssa = false;
  CodeGenOpenGL cg;
  cg.Init(output_ssa);
  for (LoweredFunc f : funcs) {
    cg.AddFunction(f);
  }
  auto shaders = cg.Finish();
  return OpenGLModuleCreate(shaders, "gl", ExtractFuncInfo(funcs));
}

TVM_REGISTER_API("codegen.build_opengl")
.set_body_typed(BuildOpenGL);

}  // namespace codegen
}  // namespace tvm