codegen_hybrid.cc 15.5 KB
Newer Older
1 2 3 4 5 6 7 8
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

20
/*!
 * \file codegen_hybrid.cc
 */
23
#include <tvm/runtime/registry.h>
24 25 26 27 28 29 30
#include <iomanip>
#include <cctype>
#include "codegen_hybrid.h"

namespace tvm {
namespace contrib {

31 32 33
using runtime::TVMArgs;
using runtime::TVMRetValue;

34
using namespace tir;
35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63

/*!
 * \brief Replace every '.' in a string with '_'.
 * \param s The input string; taken by value and edited in place.
 * \return The edited copy.
 */
std::string dot_to_underscore(std::string s) {
  for (std::string::size_type pos = 0; pos < s.size(); ++pos) {
    if (s[pos] == '.') {
      s[pos] = '_';
    }
  }
  return s;
}

/*!
 * \brief Allocate a name that has not been handed out before.
 *
 * Dots in the requested prefix are first turned into underscores. If the
 * resulting name is already recorded in ids_allocated_, the counter stored
 * for that prefix is incremented and appended until an unused name is found.
 * The returned name is recorded with a counter of zero.
 *
 * \param prefix The requested name.
 * \return The allocated name.
 */
std::string CodeGenHybrid::GetUniqueName(std::string prefix) {
  prefix = dot_to_underscore(prefix);
  auto entry = ids_allocated_.find(prefix);
  if (entry != ids_allocated_.end()) {
    // The base name is taken: keep bumping its counter until the
    // suffixed candidate is free.
    std::string candidate;
    do {
      std::ostringstream builder;
      builder << prefix << (++entry->second);
      candidate = builder.str();
    } while (ids_allocated_.count(candidate) != 0);
    prefix = candidate;
  }
  ids_allocated_[prefix] = 0;
  return prefix;
}

/*!
 * \brief Return everything written to the output stream so far.
 * \return The accumulated text of the member stream.
 */
std::string CodeGenHybrid::Finish() {
  std::string generated = stream.str();
  return generated;
}

64
/*!
 * \brief Print a data type as "<family><bits>", e.g. "float32" or "uint8".
 *
 * The family is "float", "int" or "uint". The family name is written before
 * the bit width is validated, and the bit width is written last.
 *
 * \param t The data type to print.
 * \param os The stream to write to.
 */
void CodeGenHybrid::PrintType(DataType t, std::ostream &os) {
  if (t.is_float()) {
    // Floating point: only 16/32/64 bits are accepted.
    os << "float";
    CHECK(t.bits() == 16 || t.bits() == 32 || t.bits() == 64);
  } else if (t.is_int()) {
    // Signed integer: 8/16/32/64 bits are accepted.
    os << "int";
    CHECK(t.bits() == 8 || t.bits() == 16 || t.bits() == 32 || t.bits() == 64);
  } else {
    // Anything that is neither float nor int must be an unsigned integer.
    CHECK(t.is_uint()) << "Unsupported type " << t;
    os << "uint";
    CHECK(t.bits() == 8 || t.bits() == 16 || t.bits() == 32 || t.bits() == 64);
  }
  os << t.bits();
}

79
void CodeGenHybrid::VisitExpr_(const IntImmNode* op, std::ostream& os) {  // NOLINT(*)
80 81
  os << op->value;
}
82

83
void CodeGenHybrid::VisitExpr_(const FloatImmNode* op, std::ostream& os) { // NOLINT(*)
84
  PrintType(op->dtype, os);
85 86
  os << "(" << std::setprecision(20) << op->value << ")";
}
87
void CodeGenHybrid::VisitExpr_(const StringImmNode* op, std::ostream& os) { // NOLINT(*)
88 89 90 91 92
  os << "'" << op->value << "'";
}

template<typename T>
inline void PrintBinaryExpr(const T* op,
93
                            const char* opstr,
94 95
                            std::ostream& os,  // NOLINT(*)
                            CodeGenHybrid* p) {
96
  CHECK(op->dtype.lanes() == 1)  << "vec bin op not implemented";
97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113
  if (isalpha(opstr[0])) {
    os << opstr << '(';
    p->PrintExpr(op->a, os);
    os << ", ";
    p->PrintExpr(op->b, os);
    os << ')';
  } else {
    os << '(';
    p->PrintExpr(op->a, os);
    if (!strcmp(opstr, "&&")) opstr = "and";
    if (!strcmp(opstr, "||")) opstr = "or";
    os << ' ' << opstr << ' ';
    p->PrintExpr(op->b, os);
    os << ')';
  }
}

114
inline void PrintBinaryIntrinsitc(const CallNode* op,
115
                                  const char* opstr,
116 117
                                  std::ostream& os,  // NOLINT(*)
                                  CodeGenHybrid* p) {
118
  CHECK(op->dtype.lanes() == 1)  << "vec bin intrin not implemented";
119 120 121 122 123 124 125 126
  CHECK_EQ(op->args.size(), 2U);
  os << '(';
  p->PrintExpr(op->args[0], os);
  os << opstr;
  p->PrintExpr(op->args[1], os);
  os << ')';
}

127
void CodeGenHybrid::VisitExpr_(const CastNode* op, std::ostream& os) {  // NOLINT(*)
128
  if (op->dtype == op->value.dtype()) {
129 130
    PrintExpr(op->value, stream);
  } else {
131
    PrintType(op->dtype, os);
132 133 134 135 136 137
    os << "(";
    PrintExpr(op->value, os);
    os << ")";
  }
}

138
void CodeGenHybrid::VisitExpr_(const VarNode* op, std::ostream& os) {  // NOLINT(*)
139 140
  os << GetVarID(op);
}
141
void CodeGenHybrid::VisitExpr_(const AddNode* op, std::ostream& os) {  // NOLINT(*)
142 143
  PrintBinaryExpr(op, "+", os, this);
}
144
void CodeGenHybrid::VisitExpr_(const SubNode* op, std::ostream& os) {  // NOLINT(*)
145 146
  PrintBinaryExpr(op, "-", os, this);
}
147
void CodeGenHybrid::VisitExpr_(const MulNode* op, std::ostream& os) {  // NOLINT(*)
148 149
  PrintBinaryExpr(op, "*", os, this);
}
150

151
void CodeGenHybrid::VisitExpr_(const DivNode* op, std::ostream& os) {  // NOLINT(*)
152
  if (op->dtype.is_int())
153 154 155 156
    PrintBinaryExpr(op, "//", os, this);
  else
    PrintBinaryExpr(op, "/", os, this);
}
157

158
void CodeGenHybrid::VisitExpr_(const FloorDivNode* op, std::ostream& os) {  // NOLINT(*)
159
  if (op->dtype.is_int())
160 161 162 163 164
    PrintBinaryExpr(op, "//", os, this);
  else
    PrintBinaryExpr(op, "/", os, this);
}

165
void CodeGenHybrid::VisitExpr_(const ModNode* op, std::ostream& os) {  // NOLINT(*)
166 167
  PrintBinaryExpr(op, "%", os, this);
}
168

169
void CodeGenHybrid::VisitExpr_(const FloorModNode* op, std::ostream& os) {  // NOLINT(*)
170 171
  PrintBinaryExpr(op, "%", os, this);
}
172
void CodeGenHybrid::VisitExpr_(const MinNode* op, std::ostream& os) {  // NOLINT(*)
173 174
  PrintBinaryExpr(op, "min", os, this);
}
175
void CodeGenHybrid::VisitExpr_(const MaxNode* op, std::ostream& os) {  // NOLINT(*)
176 177
  PrintBinaryExpr(op, "max", os, this);
}
178
void CodeGenHybrid::VisitExpr_(const EQNode* op, std::ostream& os) {  // NOLINT(*)
179 180
  PrintBinaryExpr(op, "==", os, this);
}
181
void CodeGenHybrid::VisitExpr_(const NENode* op, std::ostream& os) {  // NOLINT(*)
182 183
  PrintBinaryExpr(op, "!=", os, this);
}
184
void CodeGenHybrid::VisitExpr_(const LTNode* op, std::ostream& os) {  // NOLINT(*)
185 186
  PrintBinaryExpr(op, "<", os, this);
}
187
void CodeGenHybrid::VisitExpr_(const LENode* op, std::ostream& os) {  // NOLINT(*)
188 189
  PrintBinaryExpr(op, "<=", os, this);
}
190
void CodeGenHybrid::VisitExpr_(const GTNode* op, std::ostream& os) {  // NOLINT(*)
191 192
  PrintBinaryExpr(op, ">", os, this);
}
193
void CodeGenHybrid::VisitExpr_(const GENode* op, std::ostream& os) {  // NOLINT(*)
194 195
  PrintBinaryExpr(op, ">=", os, this);
}
196
void CodeGenHybrid::VisitExpr_(const AndNode* op, std::ostream& os) {  // NOLINT(*)
197 198
  PrintBinaryExpr(op, "&&", os, this);
}
199
void CodeGenHybrid::VisitExpr_(const OrNode* op, std::ostream& os) {  // NOLINT(*)
200 201
  PrintBinaryExpr(op, "||", os, this);
}
202
/*! \brief Print a logical negation as "not " followed by the operand. */
void CodeGenHybrid::VisitExpr_(const NotNode* op, std::ostream& os) {  // NOLINT(*)
  static const char kPrefix[] = "not ";
  os << kPrefix;
  PrintExpr(op->a, os);
}

207 208
void CodeGenHybrid::VisitExpr_(const CallNode* op, std::ostream& os) {  // NOLINT(*)
  if (op->call_type == CallNode::Halide) {
209 210 211 212 213 214 215 216 217
    os << GetTensorID(op->func, op->value_index);
    os << "[";
    for (size_t i = 0; i < op->args.size(); ++i) {
      if (i) os << ", ";
      std::stringstream idx;
      PrintExpr(op->args[i], idx);
      os << idx.str();
    }
    os << "]";
218
  } else if (op->is_intrinsic(CallNode::bitwise_and)) {
219
    PrintBinaryIntrinsitc(op, "&", os, this);
220
  } else if (op->is_intrinsic(CallNode::bitwise_xor)) {
221
    PrintBinaryIntrinsitc(op, "^", os, this);
222
  } else if (op->is_intrinsic(CallNode::bitwise_or)) {
223
    PrintBinaryIntrinsitc(op, "|", os, this);
224
  } else if (op->is_intrinsic(CallNode::shift_left)) {
225
    PrintBinaryIntrinsitc(op, "<<", os, this);
226
  } else if (op->is_intrinsic(CallNode::shift_right)) {
227
    PrintBinaryIntrinsitc(op, ">>", os, this);
228
  } else if (op->is_intrinsic(CallNode::bitwise_not)) {
229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250
    CHECK_EQ(op->args.size(), 1U);
    os << "(~";
    PrintExpr(op->args[0], os);
    os << ')';
  } else if (op->is_intrinsic(intrinsic::tvm_if_then_else)) {
    PrintExpr(op->args[1], os);
    os << " if ";
    PrintExpr(op->args[0], os);
    os << " else ";
    PrintExpr(op->args[2], os);
  } else {
    os << op->name << "(";
    for (size_t i = 0; i < op->args.size(); i++) {
      PrintExpr(op->args[i], os);
      if (i < op->args.size() - 1) {
        os << ", ";
      }
    }
    os << ")";
  }
}

251
void CodeGenHybrid::VisitExpr_(const LoadNode* op, std::ostream& os) {  // NOLINT(*)
252 253 254
  LOG(FATAL) << "Phase 0 has no Load(s)!";
}

255
void CodeGenHybrid::VisitStmt_(const StoreNode* op) {
256 257 258
  LOG(FATAL) << "Phase 0 has no Store(s)!";
}

259
void CodeGenHybrid::VisitExpr_(const LetNode* op, std::ostream& os) {  // NOLINT(*)
260 261 262
  LOG(FATAL) << "Phase 0 has no Let(s)!";
}

263
void CodeGenHybrid::VisitStmt_(const AllocateNode* op) {
264 265 266
  LOG(FATAL) << "Phase 0 has no Allocate(s)!";
}

267
void CodeGenHybrid::VisitExpr_(const RampNode* op, std::ostream& os) {  // NOLINT(*)
268 269 270
  LOG(FATAL) << "Ramp to be supported yet";
}

271
void CodeGenHybrid::VisitExpr_(const BroadcastNode* op, std::ostream& os) {   // NOLINT(*)
272 273 274
  LOG(FATAL) << "Broadcast: not supported ";
}

275
void CodeGenHybrid::VisitExpr_(const SelectNode* op, std::ostream& os) {  // NOLINT(*)
276 277 278 279 280 281 282 283
  PrintExpr(op->true_value, os);
  os << " if ";
  PrintExpr(op->condition, os);
  os << " else ";
  PrintExpr(op->false_value, os);
  os << "\n";
}

284
void CodeGenHybrid::VisitStmt_(const LetStmtNode* op) {
285 286 287 288 289
  std::string value = PrintExpr(op->value);
  stream << GetVarID(op->var.get()) << " = " << value << ";\n";
  PrintStmt(op->body);
}

290
void CodeGenHybrid::VisitStmt_(const AttrStmtNode* op) {
291
  if (op->attr_key == tir::attr::thread_extent) {
292 293 294 295 296 297 298 299 300 301 302
    auto iter_var = op->node.as<IterVarNode>();
    CHECK(iter_var);
    binds_[iter_var->var.get()] = dot_to_underscore(iter_var->var->name_hint);
    PrintIndent();
    stream << "for " << binds_[iter_var->var.get()] << " in bind('"
           << iter_var->var->name_hint << "', ";
    PrintExpr(op->value, stream);
    stream << "):\n";
    indent_ += tab_;
    PrintStmt(op->body);
    indent_ -= tab_;
303
  } else if (op->attr_key == tir::attr::realize_scope) {
304
    auto v = Downcast<FunctionRef>(op->node);
305
    alloc_storage_scope_[v] = op->value.as<StringImmNode>()->value;
306 307 308 309 310 311 312
    PrintStmt(op->body);
  } else {
    // For now we ignore the unsupported AttrStmt
    PrintStmt(op->body);
  }
}

313
void CodeGenHybrid::VisitStmt_(const RealizeNode* op) {
314 315 316 317 318 319 320 321 322 323
  CHECK(alloc_storage_scope_.count(op->func));
  if (!alloc_storage_scope_[op->func].empty()) {
    PrintIndent();
    stream << GetTensorID(op->func, op->value_index) << " = allocate((";
    for (size_t i = 0; i < op->bounds.size(); ++i) {
      if (i) stream << ", ";
      stream << PrintExpr(op->bounds[i]->extent);
    }
    if (op->bounds.size() == 1) stream << ", ";
    stream << "), '";
324
    PrintType(op->dtype, stream);
325 326 327 328 329 330
    stream << "', '";
    stream << alloc_storage_scope_[op->func] << "')\n";
  }
  PrintStmt(op->body);
}

331
void CodeGenHybrid::VisitStmt_(const AssertStmtNode* op) {
332 333 334 335 336 337 338 339 340
  PrintIndent();
  stream << "assert ";
  PrintExpr(op->condition, stream);
  stream << ", ";
  PrintExpr(op->message, stream);
  stream << "\n";
  PrintStmt(op->body);
}

341
void CodeGenHybrid::VisitStmt_(const ProvideNode* op) {
342 343 344 345 346 347 348 349 350 351 352 353
  PrintIndent();
  stream << GetTensorID(op->func, op->value_index);
  stream << "[";
  for (size_t i = 0; i < op->args.size(); ++i) {
    if (i) stream << ", ";
    PrintExpr(op->args[i], stream);
  }
  stream << "] = ";
  PrintExpr(op->value, stream);
  stream << "\n";
}

354
void CodeGenHybrid::VisitStmt_(const ForNode* op) {
355 356 357 358 359 360 361 362 363 364 365 366
  std::string extent = PrintExpr(op->extent);
  PrintIndent();
  std::string vid = GetVarID(op->loop_var.get());
  stream << "for " << vid << " in " << "range(" << extent << "):\n";
  indent_ += tab_;
  PrintStmt(op->body);
  indent_ -= tab_;
}

bool is_noop(const Stmt &stmt) {
  if (!stmt.defined())
    return true;
367
  if (auto eval = stmt.as<EvaluateNode>())
368 369 370 371
    return is_const(eval->value);
  return false;
}

372
void CodeGenHybrid::VisitStmt_(const IfThenElseNode* op) {
373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388
  std::string cond = PrintExpr(op->condition);
  PrintIndent();
  stream << "if " << cond << ":\n";
  indent_ += tab_;
  PrintStmt(op->then_case);
  indent_ -= tab_;

  if (!is_noop(op->else_case)) {
    PrintIndent();
    stream << "else:\n";
    indent_ += tab_;
    PrintStmt(op->else_case);
    indent_ -= tab_;
  }
}

389 390 391 392
void CodeGenHybrid::VisitStmt_(const SeqStmtNode* op) {
  for (Stmt stmt : op->seq) {
    PrintStmt(stmt);
  }
393 394
}

395
void CodeGenHybrid::VisitStmt_(const EvaluateNode* op) {
396 397 398 399 400 401
  if (is_const(op->value)) return;
  std::string str = PrintExpr(op->value);
  if (!str.empty())
    stream << str << "\n";
}

402
void CodeGenHybrid::VisitStmt_(const ProducerConsumerNode* op) {
403 404 405 406 407 408 409
  PrintStmt(op->body);
}

/*! \brief Write indent_ space characters to the output stream. */
void CodeGenHybrid::PrintIndent() {
  const std::string padding(indent_, ' ');
  stream << padding;
}

410
std::string CodeGenHybrid::GetVarID(const VarNode *v) {
411 412
  if (binds_.count(v))
    return binds_[v];
413
  auto key = std::make_pair(static_cast<const Object*>(v), 0);
414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436
  if (id_map_.count(key)) {
    return id_map_[key];
  }
  return id_map_[key] = GetUniqueName(v->name_hint);
}

/*!
 * \brief Get the identifier used for one output of a function.
 *
 * The name cached in id_map_ under (function, value_index) is returned when
 * present. Otherwise a unique name is derived from the function name, with
 * "_v<value_index>" appended when the function has more than one output, and
 * cached.
 *
 * \param func The function producing the tensor.
 * \param value_index The index of the output.
 * \return The identifier for the tensor.
 */
std::string CodeGenHybrid::GetTensorID(const FunctionRef &func, int value_index) {
  auto key = std::make_pair(func.get(), value_index);
  auto cached = id_map_.find(key);
  if (cached != id_map_.end()) {
    return cached->second;
  }
  std::string name_hint = func->func_name();
  if (func->num_outputs() > 1) {
    // Distinguish the outputs of a multi-output function.
    name_hint.append("_v").append(std::to_string(value_index));
  }
  std::string fresh = GetUniqueName(name_hint);
  id_map_[key] = fresh;
  return fresh;
}

/*!
 * \brief Register a fixed list of names with GetUniqueName so that later
 *        requests for the same names receive a different identifier.
 *
 * The names are registered in the order listed below.
 */
void CodeGenHybrid::ReserveKeywords() {
  static const char* const kReservedNames[] = {
    "def", "for", "in", "range",
    "True", "False",
    "unroll", "const_range", "parallel", "vectorize", "bind",
    "threadIdx.x", "threadIdx.y", "threadIdx.z",
    "blockIdx.x", "blockIdx.y", "blockIdx.z",
    "vthread", "allocate", "output_tensor",
    "sqrt", "log", "tanh", "power", "exp", "sigmoid", "popcount", "likely",
    "int8", "int16", "int32", "int64",
    "uint8", "uint16", "uint32", "uint64",
    "float16", "float32", "float64",
    "ceil_div",
    "max_num_threads",
  };
  for (const char* reserved : kReservedNames) {
    GetUniqueName(reserved);
  }
}

void CodeGenHybrid::DumpStmt(const Stmt &stmt,
477
                             const Array<ObjectRef> &inputs,
478 479 480 481 482 483 484 485 486 487 488
                             const Array<Tensor> &outputs,
                             const std::string &name) {
  ReserveKeywords();
  GetUniqueName(name);

  stream << "def " << name << "(";
  for (size_t i = 0; i < inputs.size(); ++i) {
    if (i) stream << ", ";
    if (auto tensor = inputs[i].as<TensorNode>()) {
      stream << GetTensorID(tensor->op, tensor->value_index);
    } else {
489
      auto var = inputs[i].as<VarNode>();
490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528
      CHECK(var) << "Input should either be a tensor or a variable!";
      stream << GetVarID(var);
    }
  }
  stream << "):\n";
  indent_ += tab_;
  for (size_t i = 0; i < outputs.size(); ++i) {
    PrintIndent();
    stream << GetTensorID(outputs[i]->op, outputs[i]->value_index)
           << " = output_tensor((";
    for (size_t j = 0; j < outputs[i]->shape.size(); ++j) {
      if (j) stream << ", ";
      PrintExpr(outputs[i]->shape[j], stream);
    }
    if (outputs[i]->shape.size() == 1)
      stream << ", ";
    stream << "), '" << outputs[i]->dtype << "')\n";
  }
  PrintStmt(stmt);
  PrintIndent();
  stream << "return ";
  for (size_t i = 0; i < outputs.size(); ++i) {
    if (i) stream << ", ";
    stream << GetTensorID(outputs[i]->op, outputs[i]->value_index);
  }
  stream << "\n";
}

// Registers the global function "hybrid._Dump". It runs DumpStmt on a fresh
// CodeGenHybrid and returns the generated text. With exactly four arguments
// the fourth is forwarded as the function name; otherwise only the first
// three arguments are forwarded.
TVM_REGISTER_GLOBAL("hybrid._Dump")
.set_body([](TVMArgs args, TVMRetValue* rv) {
    CodeGenHybrid codegen;
    if (args.size() == 4) {
      codegen.DumpStmt(args[0], args[1], args[2], args[3]);
    } else {
      codegen.DumpStmt(args[0], args[1], args[2]);
    }
    *rv = codegen.Finish();
  });
}  // namespace contrib
}  // namespace tvm