codegen_c_host.cc 9.56 KB
Newer Older
1 2 3 4 5 6 7 8
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

20 21 22
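/*!
 * \file codegen_c_host.cc
 * \brief Code generator that emits C source for host functions
 *        (exposed as the global "codegen.build_c").
 */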
23
#include <tvm/target/codegen.h>
24 25 26
#include <vector>
#include <string>
#include "codegen_c_host.h"
27
#include "../build_common.h"
28 29 30 31 32

namespace tvm {
namespace codegen {

CodeGenCHost::CodeGenCHost() {
  // Reserve the name of the module-context variable that Init() declares in
  // the generated file and that lazy packed-function lookups pass to
  // TVMBackendGetFuncFromEnv.
  const std::string ctx_symbol = GetUniqueName("__tvm_module_ctx");
  module_name_ = ctx_symbol;
}

36 37
void CodeGenCHost::Init(bool output_ssa, bool emit_asserts) {
  emit_asserts_ = emit_asserts;
38 39
  decl_stream << "#include \"tvm/runtime/c_runtime_api.h\"\n";
  decl_stream << "#include \"tvm/runtime/c_backend_api.h\"\n";
40
  decl_stream << "extern void* " << module_name_ << " = NULL;\n";
41 42 43 44 45 46 47 48 49 50
  CodeGenC::Init(output_ssa);
}

/*!
 * \brief Emit one lowered function as an exported C function returning int32_t
 *        (0 on success; the generated body returns -1 on failure paths).
 * \param f The lowered function to emit.
 */
void CodeGenCHost::AddFunction(LoweredFunc f) {
  // Reset per-function generator state and re-reserve the language keywords.
  InitFuncState(f);
  ReserveKeywordsAsUnique();
  // Remember the pointee type of each typed handle so handle parameters are
  // printed as typed pointers below.
  for (const auto& entry : f->handle_data_type) {
    RegisterHandleType(entry.first.get(), entry.second.dtype());
  }

  // C-linkage header (extern "C" only when the output is compiled as C++).
  stream << "#ifdef __cplusplus\n"
         << "extern \"C\"\n"
         << "#endif\n"
         << "TVM_DLL int32_t " << f->name << "(";
  const size_t num_params = f->args.size();
  for (size_t idx = 0; idx < num_params; ++idx) {
    const Var param = f->args[idx];
    const std::string param_id = AllocVarID(param.get());
    if (idx > 0) {
      stream << ", ";
    }
    if (!param.dtype().is_handle()) {
      PrintType(param.dtype(), stream);
    } else {
      // Handle parameter: [storage scope] <pointee type or void>* [restrict]
      const auto scope_it = alloc_storage_scope_.find(param.get());
      if (scope_it != alloc_storage_scope_.end()) {
        PrintStorageScope(scope_it->second, stream);
      }
      stream << ' ';
      if (handle_data_type_.count(param.get()) != 0) {
        PrintType(handle_data_type_.at(param.get()), stream);
      } else {
        stream << "void";
      }
      stream << "*";
      if (f->is_restricted && !restrict_keyword_.empty()) {
        stream << ' ' << restrict_keyword_;
      }
    }
    stream << ' ' << param_id;
  }
  stream << ") {\n";

  // Body: user statements, then an implicit success return.
  PreFunctionBody(f);
  const int body_scope = BeginScope();
  PrintStmt(f->body);
  PrintIndent();
  stream << "return 0;\n";
  EndScope(body_scope);
  PrintIndent();
  stream << "}\n\n";
}

std::string CodeGenCHost::Finish() {
  // The host backend appends nothing of its own; defer to the generic C emitter.
  std::string source = CodeGenC::Finish();
  return source;
}

99
/*!
 * \brief Print the C spelling of \p t to \p os.
 *
 * Handles print as "void*", scalar bool as "bool", floats as half/float/double
 * and integers as [u]intN_t (1-bit integers widen to 32 bits). Vectors of 2..16
 * lanes append the lane count to the scalar name (e.g. "float4").
 * NOTE(review): for integers this yields names such as "int32_t4", which are
 * not standard C types — presumably vectors are unused on the C host; confirm.
 * Any other type is a fatal error.
 */
void CodeGenCHost::PrintType(DataType t, std::ostream& os) {  // NOLINT(*)
  const int lanes = t.lanes();
  if (t.is_handle()) {
    CHECK_EQ(lanes, 1)
        << "does not support vector types";
    os << "void*";
    return;
  }
  if (t == DataType::Bool()) {
    os << "bool";
    return;
  }

  // Scalar element name; stays nullptr for unsupported bit widths / kinds.
  const char* scalar_name = nullptr;
  if (t.is_float()) {
    switch (t.bits()) {
      case 16: scalar_name = "half"; break;
      case 32: scalar_name = "float"; break;
      case 64: scalar_name = "double"; break;
      default: break;
    }
  } else if (t.is_uint() || t.is_int()) {
    if (t.is_uint()) {
      os << 'u';
    }
    switch (t.bits()) {
      case 8: scalar_name = "int8_t"; break;
      case 16: scalar_name = "int16_t"; break;
      case 1:   // 1-bit integers are widened to a 32-bit C integer.
      case 32: scalar_name = "int32_t"; break;
      case 64: scalar_name = "int64_t"; break;
      default: break;
    }
  }

  if (scalar_name != nullptr) {
    os << scalar_name;
    if (lanes == 1) {
      return;
    }
    if (lanes >= 2 && lanes <= 16) {
      os << lanes;
      return;
    }
  }
  LOG(FATAL) << "Cannot convert type " << t << " to C type";
}

145
void CodeGenCHost::VisitExpr_(const BroadcastNode* op, std::ostream& os) {   // NOLINT(*)
146 147
  std::string v = PrintExpr(op->value);
  os << "((";
148
  PrintType(op->dtype, os);
149 150 151 152 153 154 155 156
  os << ")(";
  for (int i = 0; i < op->lanes; ++i) {
    if (i != 0) os << ", ";
    os << v;
  }
  os << "))";
}

157 158
void CodeGenCHost::PrintGetFuncFromBackend(const std::string& func_name,
                                           const std::string& packed_func_name) {
159 160 161 162
  this->PrintIndent();
  this->stream << "if (" << packed_func_name << " == NULL) {\n";
  int packed_func_if_scope = this->BeginScope();
  this->PrintIndent();
163
  this->stream << "if (TVMBackendGetFuncFromEnv(" << module_name_
164 165 166 167 168 169 170 171 172 173 174 175 176
              << ", \"" << func_name << "\""
              << ", &" << packed_func_name << ") != 0) {\n";
  int get_func_env_scope = this->BeginScope();
  this->PrintIndent();
  this->stream << "return -1;\n";
  this->EndScope(get_func_env_scope);
  this->PrintIndent();
  this->stream << "}\n";
  this->EndScope(packed_func_if_scope);
  this->PrintIndent();
  this->stream << "}\n";
}

177
void CodeGenCHost::PrintFuncCall(const std::string& packed_func_name, int num_args) {
178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196
  this->PrintIndent();
  std::string ret_val = GetUniqueName("ret_val");
  std::string ret_type_code = GetUniqueName("ret_type_code");
  this->stream << "TVMValue " << ret_val << ";\n";
  this->PrintIndent();
  this->stream << "int " << ret_type_code << ";\n";
  this->PrintIndent();
  this->stream << "if (TVMFuncCall(" << packed_func_name << ", "
               << "(TVMValue*) stack_value" << ", " << "(int*) stack_tcode" << ", "
               << num_args << ", " << "&" << ret_val << ", " << "&"
               << ret_type_code << ") != 0) {\n";
  int func_call_scope = this->BeginScope();
  this->PrintIndent();
  this->stream << "return -1;\n";
  this->EndScope(func_call_scope);
  this->PrintIndent();
  this->stream << "}\n";
}

197
void CodeGenCHost::VisitExpr_(const CallNode *op, std::ostream& os) { // NOLINT(*)
198 199
  if (op->is_intrinsic(intrinsic::tvm_stack_alloca)) {
    std::string stack_name = GetUniqueName("stack");
200 201
    const std::string& type = op->args[0].as<StringImmNode>()->value;
    const IntImmNode* num = op->args[1].as<IntImmNode>();
202
    CHECK(num != nullptr);
203
    static_assert(alignof(TVMValue) % alignof(DLTensor) == 0, "invariant");
204 205 206 207 208 209 210 211 212
    size_t unit = sizeof(TVMValue);
    size_t size = 0;
    if (type == "shape") {
      size = (num->value * sizeof(tvm_index_t) + unit - 1) / unit;
    } else if (type == "arg_value") {
      size = (num->value * sizeof(TVMValue) + unit - 1) / unit;
    } else if (type == "arg_tcode") {
      size = (num->value * sizeof(int) + unit - 1) / unit;
    } else if (type == "array") {
213
      size = (num->value * sizeof(DLTensor) + unit - 1) / unit;
214 215 216 217 218 219 220
    } else {
      LOG(FATAL) << "Unknown stack alloca type " << type;
    }
    this->PrintIndent();
    this->stream << "TVMValue " << stack_name << "[" << size << "];\n";
    os << stack_name;
  } else if (op->is_intrinsic(intrinsic::tvm_call_packed_lowered)) {
221
    const StringImmNode* s = op->args[0].as<StringImmNode>();
222
    CHECK(s != nullptr) << "tvm_call_packed_lowered expects first argument as function name";
223 224
    int64_t begin = op->args[3].as<IntImmNode>()->value;
    int64_t end = op->args[4].as<IntImmNode>()->value;
225 226 227 228 229 230 231 232 233 234 235 236 237 238 239
    int64_t num_args = end - begin;
    CHECK_GE(num_args, 0);
    std::string func_name = s->value;
    std::string packed_func_name = GetUniqueName(func_name + "_packed");
    decl_stream << "static void* " << packed_func_name << " = NULL;\n";
    this->PrintGetFuncFromBackend(func_name, packed_func_name);
    this->PrintFuncCall(packed_func_name, num_args);
  } else if (op->is_intrinsic(intrinsic::tvm_throw_last_error)) {
    this->PrintIndent();
    this->stream << "return -1;\n";
  } else {
    CodeGenC::VisitExpr_(op, os);
  }
}

240
void CodeGenCHost::VisitStmt_(const AssertStmtNode *op) { // NOLINT(*)
241 242 243 244 245 246
  if (emit_asserts_) {
    std::string cond = PrintExpr(op->condition);
    PrintIndent();
    stream << "if (!(" << cond << ")) {\n";
    int assert_if_scope = this->BeginScope();
    PrintIndent();
247
    stream << "TVMAPISetLastError(\"" << op->message.as<StringImmNode>()->value << "\");\n";
248 249 250 251 252 253
    PrintIndent();
    stream << "return -1;\n";
    this->EndScope(assert_if_scope);
    PrintIndent();
    stream << "}\n";
  }
254 255 256
  this->PrintStmt(op->body);
}

257
void CodeGenCHost::VisitExpr_(const MinNode *op, std::ostream& os) {  // NOLINT(*)
258 259 260
  PrintTernaryCondExpr(op, "<", os);
}

261
void CodeGenCHost::VisitExpr_(const MaxNode *op, std::ostream& os) {  // NOLINT(*)
262 263 264 265 266 267 268 269 270
  PrintTernaryCondExpr(op, ">", os);
}

/*!
 * \brief Print a binary min/max-style node as
 *        ((a) <compare> (b) ? (a) : (b)).
 *
 * Each operand is first bound to an id via SSAGetID, so the emitted ternary
 * references ids rather than repeating the operand expressions.
 * Operand a is processed before operand b.
 */
template <typename T>
inline void CodeGenCHost::PrintTernaryCondExpr(const T* op,
                                               const char* compare,
                                               std::ostream& os) {  // NOLINT(*)
  std::ostringstream lhs_text;
  VisitExpr(op->a, lhs_text);
  const std::string lhs = SSAGetID(lhs_text.str(), op->a.dtype());

  std::ostringstream rhs_text;
  VisitExpr(op->b, rhs_text);
  const std::string rhs = SSAGetID(rhs_text.str(), op->b.dtype());

  os << "((" << lhs << ") " << compare << " (" << rhs << ") ? ("
     << lhs << ") : (" << rhs << "))";
}

280 281 282
/*!
 * \brief Generate C source for \p funcs and wrap it in a C source module.
 * \param funcs Lowered host functions to emit, in order.
 * \return A module created by CSourceModuleCreate with format "c".
 */
runtime::Module BuildCHost(Array<LoweredFunc> funcs) {
  // Both SSA output and runtime assert emission are disabled for this target.
  const bool output_ssa = false;
  const bool emit_asserts = false;
  CodeGenCHost cg;
  cg.Init(output_ssa, emit_asserts);
  for (LoweredFunc f : funcs) {
    cg.AddFunction(f);
  }
  std::string code = cg.Finish();
  return CSourceModuleCreate(code, "c");
}

293
TVM_REGISTER_GLOBAL("codegen.build_c")
.set_body([](TVMArgs targs, TVMRetValue* result) {
    // The single argument is the array of lowered functions to compile.
    Array<LoweredFunc> funcs = targs[0];
    *result = BuildCHost(funcs);
  });
}  // namespace codegen
}  // namespace tvm