/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file codegen_opencl.cc
 */
#include <tvm/packed_func_ext.h>
#include <cmath>
#include <vector>
#include <string>
#include "codegen_opencl.h"
#include "build_common.h"
#include "../runtime/thread_storage_scope.h"
#include "../runtime/opencl/opencl_module.h"

namespace tvm {
namespace codegen {

/*! \brief Construct the generator and select "restrict" as the restrict keyword. */
CodeGenOpenCL::CodeGenOpenCL() {
  restrict_keyword_ = "restrict";
}

/*!
 * \brief Reset per-function state, then record every handle-typed argument
 *        of \p f under the "global" storage scope.
 * \param f The function whose state is being initialised.
 */
void CodeGenOpenCL::InitFuncState(LoweredFunc f) {
  CodeGenC::InitFuncState(f);
  for (Var arg : f->args) {
    if (arg.dtype().is_handle()) {
      alloc_storage_scope_[arg.get()] = "global";
    }
  }
}

/*!
 * \brief Emit \p f, prefixing whatever CodeGenC::AddFunction writes with the
 *        "__kernel " qualifier.
 * \param f The function to add to the output stream.
 */
void CodeGenOpenCL::AddFunction(LoweredFunc f) {
  this->stream << "__kernel ";
  CodeGenC::AddFunction(f);
}

/*!
 * \brief Finalise the generated source.
 *
 * When PrintType has seen a 16-bit or 64-bit float (enable_fp16_ /
 * enable_fp64_), the matching extension-enable preamble is written to
 * decl_stream before delegating to CodeGenC::Finish().
 *
 * \return The result of CodeGenC::Finish().
 */
std::string CodeGenOpenCL::Finish() {
  // inject extension enable pragma for fp16 and fp64
  if (enable_fp16_) {
    // The adjacent literals of the #error message are concatenated by the
    // compiler, so the first one must end with a space.
    decl_stream
        << "#ifdef cl_khr_fp16\n"
           "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n"
           "#elif defined(cl_amd_fp16)\n"
           "#pragma OPENCL EXTENSION cl_amd_fp16 : enable\n"
           "#else\n"
           "#error \"Half precision floating point not supported "
           "by OpenCL implementation on your device.\" \n"
           "#endif\n\n";
  }

  if (enable_fp64_) {
    decl_stream
        << "#ifdef cl_khr_fp64\n"
           "#pragma OPENCL EXTENSION cl_khr_fp64 : enable\n"
           "#elif defined(cl_amd_fp64)\n"
           "#pragma OPENCL EXTENSION cl_amd_fp64 : enable\n"
           "#else\n"
           "#error \"Double precision floating point not supported "
           "by OpenCL implementation on your device.\" \n"
           "#endif\n\n";
  }

  return CodeGenC::Finish();
}

/*!
 * \brief Bind a thread-axis variable to the expression that yields its index.
 *
 * A thread scope of rank 1 maps to get_local_id(dim); any other rank maps to
 * get_group_id(dim). The expression is treated as a 64-bit unsigned value and
 * cast to the variable's own dtype.
 *
 * \param iv The iteration variable; it must not already be bound.
 */
void CodeGenOpenCL::BindThreadIndex(const IterVar& iv) {
  CHECK(!var_idmap_.count(iv->var.get()));
  runtime::ThreadScope ts = runtime::ThreadScope::make(iv->thread_tag);
  std::ostringstream os;
  if (ts.rank == 1) {
    os << "get_local_id(" << ts.dim_index << ")";
  } else {
    os << "get_group_id(" << ts.dim_index << ")";
  }
  var_idmap_[iv->var.get()] =
      CastFromTo(os.str(), DataType::UInt(64), iv->var.dtype());
}

/*!
 * \brief Write the OpenCL spelling of \p t to \p os.
 *
 * Handles print as "void*" (single lane only), DataType::Bool() as "bool",
 * floats as half/float/double and integers as char/short/int/long with an
 * optional 'u' prefix. Vector types with 2 to 16 lanes get the lane count
 * appended. Printing a 16-bit or 64-bit float also sets enable_fp16_ or
 * enable_fp64_, which Finish() reads. Anything else is a fatal error.
 *
 * \param t The type to print.
 * \param os The stream receiving the type name.
 */
void CodeGenOpenCL::PrintType(DataType t, std::ostream& os) {  // NOLINT(*)
  int lanes = t.lanes();
  if (t.is_handle()) {
    CHECK_EQ(lanes, 1)
        << "do not yet support vector types";
    os << "void*";
    return;
  }
  if (t == DataType::Bool()) {
    os << "bool";
    return;
  }
  bool fail = false;
  if (t.is_float()) {
    switch (t.bits()) {
      case 16:
        os << "half";
        enable_fp16_ = true;
        break;
      case 32:
        os << "float";
        break;
      case 64:
        os << "double";
        enable_fp64_ = true;
        break;
      default:
        fail = true;
        break;
    }
    if (!fail && lanes == 1) return;
    if (!fail && (lanes >= 2 && lanes <= 16)) {
      os << lanes;
      return;
    }
  } else if (t.is_uint() || t.is_int()) {
    if (t.is_uint()) {
      os << 'u';
    }
    if (t.bits() == 8 && t.lanes() == 4) {
      // Four 8-bit lanes are emitted as one "int" (or "uint" after the
      // prefix above) instead of a char4 vector.
      os << "int";
      return;
    }
    switch (t.bits()) {
      case 8: os << "char"; break;
      case 16: os << "short"; break;
      case 32: os << "int"; break;
      case 64: os << "long"; break;
      case 1: os << "int"; break;
      default: fail = true; break;
    }
    if (!fail && lanes == 1) return;
    if (!fail && (lanes >= 2 && lanes <= 16)) {
      os << lanes;
      return;
    }
  }
  // Reached for unsupported bit widths, lane counts or type codes.
  LOG(FATAL) << "Cannot convert type " << t << " to OpenCL type";
}

/*!
 * \brief Print the address expression "<buffer> + <base>" used by vector
 *        loads and stores.
 *
 * If the buffer's recorded handle type does not match the element type of
 * \p t, a pointer cast "(<scope> <elem>*)" is emitted first; the scope
 * qualifier is printed only when the buffer has an entry in
 * alloc_storage_scope_.
 *
 * \param buffer The buffer variable.
 * \param t The vector type being accessed.
 * \param base The element offset expression.
 * \param os The stream receiving the expression.
 */
void CodeGenOpenCL::PrintVecAddr(const Variable* buffer, DataType t,
                                 Expr base, std::ostream& os) {  // NOLINT(*)
  if (!HandleTypeMatch(buffer, t.element_of())) {
    os << '(';
    auto it = alloc_storage_scope_.find(buffer);
    if (it != alloc_storage_scope_.end()) {
      PrintStorageScope(it->second, os);
    }
    os << ' ';
    PrintType(t.element_of(), os);
    os << "*)";
  }
  os << GetVarID(buffer) << " + ";
  PrintExpr(base, os);
}

/*!
 * \brief Build a "vload<lanes>(0, <addr>)" expression.
 * \param t The vector type to load.
 * \param buffer The buffer variable.
 * \param base The element offset expression.
 * \return The load expression as a string.
 */
std::string CodeGenOpenCL::GetVecLoad(
    DataType t, const Variable* buffer, Expr base) {
  std::ostringstream os;
  os << "vload" << t.lanes() << "(0, ";
  PrintVecAddr(buffer, t, base, os);
  os << ")";
  return os.str();
}

/*!
 * \brief Emit an indented "vstore<lanes>(<value>, 0, <addr>);" statement.
 * \param buffer The buffer variable.
 * \param t The vector type to store.
 * \param base The element offset expression.
 * \param value The already-printed value expression.
 */
void CodeGenOpenCL::PrintVecStore(const Variable* buffer,
                                  DataType t, Expr base,
                                  const std::string& value) {
  this->PrintIndent();
  stream << "vstore" << t.lanes() << "(" << value << ", 0, ";
  PrintVecAddr(buffer, t, base, stream);
  stream << ");\n";
}

/*!
 * \brief Emit the synchronisation statement for a storage-sync call.
 *
 * Both "warp" and "shared" produce barrier(CLK_LOCAL_MEM_FENCE); "global" is
 * a fatal error; any other scope string emits nothing.
 *
 * \param op The call node; args[0] is read as a StringImm naming the scope.
 */
void CodeGenOpenCL::PrintStorageSync(const Call* op) {
  const std::string& sync = op->args[0].as<StringImm>()->value;
  if (sync == "warp") {
    this->PrintIndent();
    this->stream << "barrier(CLK_LOCAL_MEM_FENCE);\n";
  } else if (sync == "shared") {
    this->PrintIndent();
    this->stream << "barrier(CLK_LOCAL_MEM_FENCE);\n";
  } else if (sync == "global") {
    LOG(FATAL) << "not supported";
  }
}

/*!
 * \brief Print the address-space qualifier for a storage scope.
 *
 * "global" prints "__global" and "shared" prints "__local"; any other scope
 * prints nothing.
 *
 * \param scope The storage scope name.
 * \param os The stream receiving the qualifier.
 */
void CodeGenOpenCL::PrintStorageScope(
    const std::string& scope, std::ostream& os) {  // NOLINT(*)
  if (scope == "global") {
    os << "__global";
  } else if (scope == "shared") {
    os << "__local";
  }
}

/*!
 * \brief Wrap \p value in a conversion from \p from to \p target.
 *
 * Identical types return \p value unchanged. Single-lane targets use a
 * C-style cast "((T)value)"; multi-lane targets use "(convert_T(value))".
 *
 * \param value The printed source expression.
 * \param from The source type.
 * \param target The destination type.
 * \return The converted expression as a string.
 */
std::string CodeGenOpenCL::CastFromTo(std::string value, DataType from, DataType target) {
  if (from == target) return value;
  std::ostringstream os;
  if (target.lanes() == 1) {
    os << "((";
    this->PrintType(target, os);
    os << ")" << value << ")";
  } else {  // convert vector type
    os << "(";
    os << "convert_";
    this->PrintType(target, os);
    os << "(" << value << "))";
  }
  return os.str();
}

/*!
 * \brief Print a broadcast as a vector literal "((T)(v, v, ..., v))" that
 *        repeats the printed scalar once per lane.
 * \param op The broadcast node.
 * \param os The stream receiving the expression.
 */
void CodeGenOpenCL::VisitExpr_(const Broadcast* op, std::ostream& os) {   // NOLINT(*)
  std::string v = PrintExpr(op->value);
  os << "((";
  PrintType(op->dtype, os);
  os << ")(";
  for (int i = 0; i < op->lanes; ++i) {
    if (i != 0) os << ", ";
    os << v;
  }
  os << "))";
}

/*!
 * \brief Print a call; tvm_if_then_else is prefixed with a cast to the type
 *        of its third argument before CodeGenC prints the call itself.
 * \param op The call node.
 * \param os The stream receiving the expression.
 */
void CodeGenOpenCL::VisitExpr_(const Call *op, std::ostream& os) {  // NOLINT(*)
  /* Return type of ternary expression is not always same as its sub-expressions,
   * add a cast */
  if (op->is_intrinsic(intrinsic::tvm_if_then_else)) {
    os << "(";
    PrintType(op->args[2].dtype(), os);
    os << ")";
  }
  CodeGenC::VisitExpr_(op, os);
}

/*!
 * \brief Print a select, prefixed with a cast to the type of its true branch.
 * \param op The select node.
 * \param os The stream receiving the expression.
 */
void CodeGenOpenCL::VisitExpr_(const Select* op, std::ostream& os) {  // NOLINT(*)
  /* Return type of ternary expression is not always same as its sub-expressions,
   * add a cast */
  os << "(";
  PrintType(op->true_value.dtype(), os);
  os << ")";
  CodeGenC::VisitExpr_(op, os);
}

/*!
 * \brief Print a floating-point immediate.
 *
 * Infinite values print as "INFINITY" (with a leading "-" when negative) and
 * NaN prints as "NAN"; all other values are printed by CodeGenC.
 *
 * \param op The immediate node.
 * \param os The stream receiving the literal.
 */
void CodeGenOpenCL::VisitExpr_(const FloatImm *op, std::ostream& os) {  // NOLINT(*)
  if (std::isinf(op->value)) {
    if (op->value < 0) {
      os << "-";
    }
    os << "INFINITY";
  } else if (std::isnan(op->value)) {
    os << "NAN";
  } else {
    CodeGenC::VisitExpr_(op, os);
  }
}

/*!
 * \brief Generate OpenCL source for \p funcs and wrap it in a runtime module.
 *
 * If a global function named "tvm_callback_opencl_postproc" is registered,
 * the generated source is passed through it before the module is created.
 *
 * \param funcs The lowered functions to compile.
 * \return The module returned by OpenCLModuleCreate.
 */
runtime::Module BuildOpenCL(Array<LoweredFunc> funcs) {
  using tvm::runtime::Registry;
  bool output_ssa = false;
  CodeGenOpenCL cg;
  cg.Init(output_ssa);
  for (LoweredFunc f : funcs) {
    cg.AddFunction(f);
  }
  std::string code = cg.Finish();
  if (const auto* f = Registry::Get("tvm_callback_opencl_postproc")) {
    code = (*f)(code).operator std::string();
  }
  return OpenCLModuleCreate(code, "cl", ExtractFuncInfo(funcs), code);
}

TVM_REGISTER_GLOBAL("codegen.build_opencl")
.set_body_typed(BuildOpenCL);
}  // namespace codegen
}  // namespace tvm