codegen_opencl.cc 7.89 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 * 
 *   http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

20 21
/*!
 *  Copyright (c) 2017 by Contributors
 * \file codegen_opencl.cc
 */
#include <tvm/packed_func_ext.h>
#include <vector>
#include <string>
27 28
#include "codegen_opencl.h"
#include "build_common.h"
29
#include "../runtime/thread_storage_scope.h"
30
#include "../runtime/opencl/opencl_module.h"
31 32 33 34

namespace tvm {
namespace codegen {

35 36 37 38
/*! \brief Construct the generator; sets restrict_keyword_ to "restrict". */
CodeGenOpenCL::CodeGenOpenCL() {
  this->restrict_keyword_ = "restrict";
}

39 40
/*!
 * \brief Run the base-class state initialization for \p f, then record the
 *        "global" storage scope for each handle-typed argument.
 * \param f The lowered function whose arguments are inspected.
 */
void CodeGenOpenCL::InitFuncState(LoweredFunc f) {
  CodeGenC::InitFuncState(f);
  for (Var param : f->args) {
    // Only handle-typed arguments get a storage scope entry.
    if (!param.type().is_handle()) continue;
    alloc_storage_scope_[param.get()] = "global";
  }
}

/*!
 * \brief Write the "__kernel " qualifier to the output stream, then delegate
 *        the rest of the function emission to CodeGenC::AddFunction.
 * \param f The lowered function to emit.
 */
void CodeGenOpenCL::AddFunction(LoweredFunc f) {
  // The qualifier is written before anything the base class emits for f.
  stream << "__kernel ";
  CodeGenC::AddFunction(f);
}

53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81
/*!
 * \brief Prepend the extension pragmas required by the types that were
 *        printed, then return the result of CodeGenC::Finish().
 *
 * enable_fp16_ / enable_fp64_ are raised by PrintType when a half or double
 * type is emitted. For each raised flag a preprocessor guard is written to
 * decl_stream that enables the khr extension, falls back to the amd one, and
 * otherwise stops compilation with an #error.
 *
 * \return The string produced by CodeGenC::Finish().
 */
std::string CodeGenOpenCL::Finish() {
  // inject extension enable pragma for fp16 and fp64
  if (enable_fp16_) {
    // NOTE: adjacent literals are concatenated verbatim, so the separating
    // space between "supported" and "by" must be part of a literal.
    decl_stream
        << "#ifdef cl_khr_fp16\n"
           "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n"
           "#elif defined(cl_amd_fp16)\n"
           "#pragma OPENCL EXTENSION cl_amd_fp16 : enable\n"
           "#else\n"
           "#error \"Half precision floating point not supported "
                    "by OpenCL implementation on your device.\" \n"
           "#endif\n\n";
  }

  if (enable_fp64_) {
    decl_stream
        << "#ifdef cl_khr_fp64\n"
           "#pragma OPENCL EXTENSION cl_khr_fp64 : enable\n"
           "#elif defined(cl_amd_fp64)\n"
           "#pragma OPENCL EXTENSION cl_amd_fp64 : enable\n"
           "#else\n"
           "#error \"Double precision floating point not supported "
                    "by OpenCL implementation on your device.\" \n"
           "#endif\n\n";
  }

  return CodeGenC::Finish();
}

82 83 84 85
void CodeGenOpenCL::BindThreadIndex(const IterVar& iv) {
  CHECK(!var_idmap_.count(iv->var.get()));
  runtime::ThreadScope ts = runtime::ThreadScope::make(iv->thread_tag);
  std::ostringstream os;
86 87
  if (ts.rank == 1) {
    os << "get_local_id(" << ts.dim_index << ")";
88
  } else {
89
    os << "get_group_id(" << ts.dim_index << ")";
90
  }
91 92
  var_idmap_[iv->var.get()] =
      CastFromTo(os.str(), UInt(64), iv->var.type());
93 94
}

95
/*!
 * \brief Write the OpenCL spelling of \p t to \p os.
 *
 * Handles print as "void*" (scalar only) and Bool() as "bool". Float and
 * integer types print their element name followed by the lane count when
 * lanes is in [2, 16]. Printing half or double raises enable_fp16_ /
 * enable_fp64_. An 8-bit, 4-lane integer is printed as a single "int"
 * (with the 'u' prefix when unsigned). Anything else is a fatal error.
 *
 * \param t The type to print.
 * \param os The stream receiving the type name.
 */
void CodeGenOpenCL::PrintType(Type t, std::ostream& os) {  // NOLINT(*)
  const int lanes = t.lanes();
  if (t.is_handle()) {
    CHECK_EQ(lanes, 1)
        << "do not yet support vector types";
    os << "void*";
    return;
  }
  if (t == Bool()) {
    os << "bool";
    return;
  }

  // Set once a recognised element name has been written to os.
  bool element_printed = false;
  if (t.is_float()) {
    const int bits = t.bits();
    if (bits == 16) {
      os << "half";
      enable_fp16_ = true;
      element_printed = true;
    } else if (bits == 32) {
      os << "float";
      element_printed = true;
    } else if (bits == 64) {
      os << "double";
      enable_fp64_ = true;
      element_printed = true;
    }
  } else if (t.is_uint() || t.is_int()) {
    if (t.is_uint()) {
      os << 'u';
    }
    if (t.bits() == 8 && lanes == 4) {
      // directly 4 8 bit int in integer.
      os << "int";
      return;
    }
    const char* element = nullptr;
    switch (t.bits()) {
      case 1: element = "int"; break;
      case 8: element = "char"; break;
      case 16: element = "short"; break;
      case 32: element = "int"; break;
      case 64: element = "long"; break;
      default: break;
    }
    if (element != nullptr) {
      os << element;
      element_printed = true;
    }
  }

  if (element_printed) {
    if (lanes == 1) return;
    if (lanes >= 2 && lanes <= 16) {
      os << lanes;
      return;
    }
  }
  LOG(FATAL) << "Cannot convert type " << t << " to OpenCL type";
}

/*!
 * \brief Write the address expression "<buffer> + <base>" to \p os.
 *
 * When the buffer's handle type does not match the element type of \p t,
 * a pointer cast to that element type is written first, qualified with the
 * storage scope recorded for the buffer in alloc_storage_scope_ (if any).
 *
 * \param buffer The buffer variable.
 * \param t The vector type being accessed.
 * \param base The base index expression.
 * \param os The stream receiving the expression.
 */
void CodeGenOpenCL::PrintVecAddr(const Variable* buffer, Type t,
                                 Expr base, std::ostream& os) {  // NOLINT(*)
  const bool needs_cast = !HandleTypeMatch(buffer, t.element_of());
  if (needs_cast) {
    os << '(';
    const auto scope_entry = alloc_storage_scope_.find(buffer);
    const bool has_scope = (scope_entry != alloc_storage_scope_.end());
    if (has_scope) {
      PrintStorageScope(scope_entry->second, os);
    }
    os << ' ';
    PrintType(t.element_of(), os);
    os << "*)";
  }
  os << GetVarID(buffer);
  os << " + ";
  PrintExpr(base, os);
}
162 163
/*!
 * \brief Build a "vload<lanes>(0, <address>)" expression.
 * \param t The vector type to load; its lane count selects the vload variant.
 * \param buffer The buffer variable.
 * \param base The base index expression.
 * \return The load expression as a string.
 */
std::string CodeGenOpenCL::GetVecLoad(
    Type t, const Variable* buffer, Expr base) {
  std::ostringstream load_expr;
  load_expr << "vload" << t.lanes() << "(0, ";
  // The address is printed by the shared helper also used for vector stores.
  this->PrintVecAddr(buffer, t, base, load_expr);
  load_expr << ")";
  return load_expr.str();
}

void CodeGenOpenCL::PrintVecStore(const Variable* buffer,
                                  Type t, Expr base,
                                  const std::string& value) {
  this->PrintIndent();
  stream << "vstore" << t.lanes() << "(" << value << ", 0, ";
  PrintVecAddr(buffer, t, base, stream);
  stream << ");\n";
}
179

180 181
void CodeGenOpenCL::PrintStorageSync(const Call* op) {
  const std::string& sync = op->args[0].as<StringImm>()->value;
182
  if (sync == "warp") {
183
    this->PrintIndent();
184
    this->stream << "barrier(CLK_LOCAL_MEM_FENCE);\n";
185
  } else if (sync == "shared") {
186 187 188 189 190 191 192
    this->PrintIndent();
    this->stream << "barrier(CLK_LOCAL_MEM_FENCE);\n";
  } else if (sync == "global") {
    LOG(FATAL) << "not supported";
  }
}

193 194
/*!
 * \brief Write the address-space qualifier for a storage scope name.
 *
 * "global" prints "__global" and "shared" prints "__local"; any other
 * scope prints nothing.
 *
 * \param scope The storage scope name.
 * \param os The stream receiving the qualifier.
 */
void CodeGenOpenCL::PrintStorageScope(
    const std::string& scope, std::ostream& os) { // NOLINT(*)
  if (scope == "global") {
    os << "__global";
    return;
  }
  if (scope == "shared") {
    os << "__local";
  }
}

202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217
/*!
 * \brief Wrap \p value in a conversion from type \p from to type \p target.
 *
 * Identical types return \p value unchanged. A single-lane target yields
 * "((T)value)"; a multi-lane target yields "(convert_T(value))".
 *
 * \param value The expression text to convert.
 * \param from The current type of the expression.
 * \param target The desired type.
 * \return The converted expression text.
 */
std::string CodeGenOpenCL::CastFromTo(std::string value, Type from, Type target) {
  if (from == target) return value;
  const bool scalar = (target.lanes() == 1);
  // Text surrounding the printed type name and the value, per form.
  const char* before_type = scalar ? "((" : "(convert_";
  const char* after_type = scalar ? ")" : "(";
  const char* closing = scalar ? ")" : "))";
  std::ostringstream cast;
  cast << before_type;
  this->PrintType(target, cast);
  cast << after_type << value << closing;
  return cast.str();
}

218 219
void CodeGenOpenCL::VisitExpr_(const Broadcast* op, std::ostream& os) {   // NOLINT(*)
  std::string v = PrintExpr(op->value);
220
  os << "((";
221 222 223 224 225
  PrintType(op->type, os);
  os << ")(";
  for (int i = 0; i < op->lanes; ++i) {
    if (i != 0) os << ", ";
    os << v;
226
  }
227
  os << "))";
228
}
229

230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248
/*!
 * \brief Emit a call expression; tvm_if_then_else calls are prefixed with a
 *        cast to the type of their third argument.
 * \param op The call node.
 * \param os The stream receiving the expression.
 */
void CodeGenOpenCL::VisitExpr_(const Call *op, std::ostream& os) {  // NOLINT(*)
  const bool is_if_then_else = op->is_intrinsic(intrinsic::tvm_if_then_else);
  if (is_if_then_else) {
    // The ternary's result type may differ from its operands' type,
    // so an explicit cast is written in front of it.
    os << "(";
    PrintType(op->args[2].type(), os);
    os << ")";
  }
  CodeGenC::VisitExpr_(op, os);
}

void CodeGenOpenCL::VisitExpr_(const Select* op, std::ostream& os) {  // NOLINT(*)
  /* Return type of ternary expression is not always same as its sub-expressions,
   * add a cast */
  os << "(";
  PrintType(op->true_value.type(), os);
  os << ")";
  CodeGenC::VisitExpr_(op, os);
}
249

250 251 252 253 254 255 256 257 258 259 260 261 262
/*!
 * \brief Emit a floating point immediate.
 *
 * NaN prints as "NAN" and infinities as "INFINITY" / "-INFINITY"; every
 * other value is delegated to CodeGenC.
 *
 * \param op The immediate node.
 * \param os The stream receiving the literal.
 */
void CodeGenOpenCL::VisitExpr_(const FloatImm *op, std::ostream& os) { // NOLINT(*)
  if (std::isnan(op->value)) {
    os << "NAN";
    return;
  }
  if (std::isinf(op->value)) {
    const bool negative = op->value < 0;
    if (negative) {
      os << "-";
    }
    os << "INFINITY";
    return;
  }
  CodeGenC::VisitExpr_(op, os);
}

263 264 265 266 267 268 269 270 271 272 273 274
/*!
 * \brief Generate OpenCL source for \p funcs and wrap it in a runtime module.
 *
 * Each function is added to a CodeGenOpenCL instance initialized with SSA
 * output disabled. If a global function named "tvm_callback_opencl_postproc"
 * is registered, the generated source is replaced by its result.
 *
 * \param funcs The lowered functions to compile.
 * \return The module returned by OpenCLModuleCreate.
 */
runtime::Module BuildOpenCL(Array<LoweredFunc> funcs) {
  using tvm::runtime::Registry;
  CodeGenOpenCL cg;
  cg.Init(/*output_ssa=*/false);
  for (LoweredFunc func : funcs) {
    cg.AddFunction(func);
  }
  std::string code = cg.Finish();
  const auto* postproc = Registry::Get("tvm_callback_opencl_postproc");
  if (postproc != nullptr) {
    code = (*postproc)(code).operator std::string();
  }
  return OpenCLModuleCreate(code, "cl", ExtractFuncInfo(funcs), code);
}

// Registers BuildOpenCL under the API name "codegen.build_opencl".
TVM_REGISTER_API("codegen.build_opencl")
    .set_body_typed(BuildOpenCL);
280 281
}  // namespace codegen
}  // namespace tvm