codegen_c.cc 29.1 KB
Newer Older
1 2 3 4 5 6 7 8
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
9
 *
10
 *   http://www.apache.org/licenses/LICENSE-2.0
11
 *
12 13 14 15 16 17 18 19
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

20 21 22
/*!
 * \file codegen_c.cc
 */
23
#include <iomanip>
24
#include <cctype>
25
#include "codegen_c.h"
26 27
#include "../../arith/compute_expr.h"
#include "../../tir/pass/ir_util.h"
28 29 30 31

namespace tvm {
namespace codegen {

32
using namespace tir;
33

34
void CodeGenC::Init(bool output_ssa) {
35
  print_ssa_form_ = output_ssa;
36 37
}

38
void CodeGenC::InitFuncState(const PrimFunc& f) {
39 40
  alloc_storage_scope_.clear();
  handle_data_type_.clear();
41
  CodeGenSourceBase::ClearFuncState();
42
}
43 44

// Reserve identifiers that generated variables must not use.
// GetUniqueName is called only for its side effect; the returned names are
// discarded.
void CodeGenC::ReserveKeywordsAsUnique() {
  // The leading "_" entry skips the first underscore, so SSA variables
  // start from _1. The remaining entries are C keywords.
  static const char* const kReservedNames[] = {
      "_",        "extern",  "void",     "int",    "float",   "double",
      "char",     "unsigned", "short",   "long",   "if",      "else",
      "switch",   "case",    "default",  "for",    "do",      "while",
      "goto",     "register", "continue", "break", "typedef", "struct",
      "enum",     "union",   "return",
  };
  for (const char* name : kReservedNames) {
    GetUniqueName(name);
  }
}

75
void CodeGenC::AddFunction(const PrimFunc& f) {
76 77 78 79
  // clear previous generated state.
  this->InitFuncState(f);
  // reserve keywords
  ReserveKeywordsAsUnique();
80

81 82 83 84 85 86 87 88 89 90
  auto global_symbol = f->GetAttr<runtime::String>(tvm::attr::kGlobalSymbol);
  CHECK(global_symbol.defined())
      << "CodeGenC: Expect PrimFunc to have the global_symbol attribute";
  bool no_alias = f->HasNonzeroAttr(tir::attr::kNoAlias);

  this->PrintFuncPrefix();
  this->stream << " " << static_cast<std::string>(global_symbol) << "(";

  for (size_t i = 0; i < f->params.size(); ++i) {
    tir::Var v = f->params[i];
91 92
    std::string vid = AllocVarID(v.get());
    if (i != 0) stream << ", ";
93
    if (v.dtype().is_handle()) {
94
      auto it = alloc_storage_scope_.find(v.get());
95
      if (it != alloc_storage_scope_.end()) {
96
        PrintStorageScope(it->second, stream);
97 98
        stream << ' ';
      }
99

100 101 102 103 104 105 106 107
      PrintType(GetType(v), stream);
      // Register handle data type
      // TODO(tvm-team): consider simply keep type info in the
      // type annotation(via a normalizing rewriting).
      if (auto* ptr = v->type_annotation.as<PointerTypeNode>()) {
        if (auto* prim = ptr->element_type.as<PrimTypeNode>()) {
          RegisterHandleType(v.get(), prim->dtype);
        }
108
      }
109

110
      if (no_alias && restrict_keyword_.length() != 0) {
111 112
        stream << ' ' << restrict_keyword_;
      }
113
    } else {
114
      PrintType(GetType(v), stream);
115
    }
116 117 118
    stream << ' ' << vid;
  }
  stream << ") {\n";
119
  this->PreFunctionBody(f);
120
  int func_scope = this->BeginScope();
121
  this->PrintStmt(f->body);
122
  this->PrintFinalReturn();
123
  this->EndScope(func_scope);
124
  this->PrintIndent();
125 126 127
  this->stream << "}\n\n";
}

128 129 130 131 132 133 134
void CodeGenC::PrintFuncPrefix() {
  stream << "void";
}

// Called by AddFunction after the function body has been printed.
// This implementation emits nothing.
void CodeGenC::PrintFinalReturn() {}

135
std::string CodeGenC::Finish() {
136
  return decl_stream.str() + stream.str();
137 138
}

139
void CodeGenC::PrintExpr(const PrimExpr& n, std::ostream& os) {  // NOLINT(*)
140 141
  if (print_ssa_form_) {
    std::ostringstream temp;
142
    VisitExpr(n, temp);
143
    os << SSAGetID(temp.str(), n.dtype());
144
  } else {
145
    VisitExpr(n, os);
146 147 148
  }
}

149
void CodeGenC::PrintSSAAssign(
150
    const std::string& target, const std::string& src, DataType t) {
151 152 153 154 155
  PrintType(t, stream);
  stream << ' ' << target << " = ";
  if (src.length() > 3 &&
      src[0] == '(' && src[src.length() - 1] == ')') {
    stream << src.substr(1, src.length() - 2);
156
  } else {
157
    stream << src;
158
  }
159
  stream << ";\n";
160 161 162
}

// Print a reference expression to a buffer.
163
std::string CodeGenC::GetBufferRef(
164
    DataType t, const VarNode* buffer, PrimExpr index) {
165
  std::ostringstream os;
166
  std::string vid = GetVarID(buffer);
167 168 169 170
  std::string scope;
  if (alloc_storage_scope_.count(buffer)) {
    scope = alloc_storage_scope_.at(buffer);
  }
171
  bool is_vol = IsVolatile(buffer);
172
  if (t.lanes() == 1) {
173
    if (!HandleTypeMatch(buffer, t) || is_vol) {
174
      os << "((";
175 176 177
      if (is_vol) {
        os << "volatile ";
      }
178 179
      // Scope may not be part of type.
      if (!scope.empty() && IsScopePartOfType()) {
180 181 182
        PrintStorageScope(scope, os);
      }
      os << ' ';
183 184 185 186 187
      PrintType(t, os);
      os << "*)" << vid << ')';
    } else {
      os << vid;
    }
188
    os << "[(";
189
    PrintExpr(index, os);
190 191 192 193 194
    os << ")";
    if (t.bits() == 4 ||
        (t.bits() == 1 && t.is_int())) {
      os << " / " << (32 / t.bits());
    }
195 196 197 198
    os << ']';
  } else {
    // Buffer declared as vector type.
    // optimize for case where it is in register,
199
    if (HandleTypeMatch(buffer, t) && !is_vol) {
200 201 202 203 204 205
      // optimize for constant access
      int offset;
      if (arith::GetConstInt(index, &offset)) {
        CHECK_EQ(offset % t.lanes(), 0)
            << "Find unaligned vector load to a vector type";
        os << vid << '[' << (offset / t.lanes()) << ']';
206
        return os.str();
207 208 209
      }
    }
    os << "((";
210 211 212
    if (is_vol) {
      os << "volatile ";
    }
213
    if (!scope.empty() && IsScopePartOfType()) {
214 215 216
      PrintStorageScope(scope, os);
    }
    os << ' ';
217 218 219 220
    PrintType(t, os);
    os << "*)(";
    if (!HandleTypeMatch(buffer, t.element_of())) {
      os << '(';
221
      if (!scope.empty() && IsScopePartOfType()) {
222 223 224
        PrintStorageScope(scope, os);
      }
      os << ' ';
225 226 227
      PrintType(t.element_of(), os);
      os << "*)";
    }
228
    os << vid << " + (";
229
    PrintExpr(index, os);
230 231 232 233 234
    os << ")";
    if (t.bits() == 4 ||
        (t.bits() == 1 && t.is_int())) {
      os << " / " << (32 / t.bits());
    }
235 236
    os << "))[0]";
  }
237
  return os.str();
238 239
}

240 241
// Print a reference expression to a buffer.
std::string CodeGenC::GetStructRef(
242
    DataType t, const PrimExpr& buffer, const PrimExpr& index, int kind) {
243 244
  if (kind < intrinsic::kArrKindBound_) {
    std::ostringstream os;
245
    os << "(((DLTensor*)";
246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264
    this->PrintExpr(buffer, os);
    os << ")";
    if (kind == intrinsic::kArrAddr) {
      os << " + ";
      this->PrintExpr(index, os);
      os << ")";
      return os.str();
    }
    os << '[';
    this->PrintExpr(index, os);
    os << "].";
    // other case: get fields.
    switch (kind) {
      case intrinsic::kArrData: os << "data"; break;
      case intrinsic::kArrShape: os << "shape"; break;
      case intrinsic::kArrStrides: os << "strides"; break;
      case intrinsic::kArrNDim: os << "ndim"; break;
      case intrinsic::kArrTypeCode: os << "dtype.code"; break;
      case intrinsic::kArrTypeBits: os << "dtype.bits"; break;
265
      case intrinsic::kArrByteOffset: os << "byte_offset"; break;
266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285
      case intrinsic::kArrTypeLanes: os << "dtype.lanes"; break;
      case intrinsic::kArrDeviceId: os << "ctx.device_id"; break;
      case intrinsic::kArrDeviceType: os << "ctx.device_type"; break;
      default: LOG(FATAL) << "unknown field code";
    }
    os << ')';
    return os.str();
  } else {
    CHECK_LT(kind, intrinsic::kTVMValueKindBound_);
    std::ostringstream os;
    os << "(((TVMValue*)";
    this->PrintExpr(buffer, os);
    os << ")[" << index << "].";
    if (t.is_handle()) {
      os << "v_handle";
    } else if (t.is_float()) {
      os << "v_float64";
    } else if (t.is_int()) {
      os << "v_int64";
    } else {
Siju committed
286
      LOG(FATAL) << "Do not know how to handle type" << t;
287 288 289 290 291 292
    }
    os << ")";
    return os.str();
  }
}

293
// Check whether `buf_var` has a registered element type equal to `t`.
//
// \return false when no type is registered for `buf_var` or when the
//         registered type differs from `t`.
bool CodeGenC::HandleTypeMatch(const VarNode* buf_var, DataType t) const {
  auto entry = handle_data_type_.find(buf_var);
  return entry != handle_data_type_.end() && entry->second == t;
}

299
// Record `t` as the element type of handle variable `buf_var`.
// Registering the same variable again is allowed only with the same type;
// a different type fails the CHECK.
void CodeGenC::RegisterHandleType(const VarNode* buf_var, DataType t) {
  auto entry = handle_data_type_.find(buf_var);
  if (entry != handle_data_type_.end()) {
    CHECK(entry->second == t)
        << "conflicting buf var type";
    return;
  }
  handle_data_type_[buf_var] = t;
}

309
void CodeGenC::PrintVecElemLoad(const std::string& vec,
310
                                DataType t, int i,
311
                                std::ostream& os) {  // NOLINT(*)
312
  os << vec << ".s" << std::hex << i << std::dec;
313 314 315
}

void CodeGenC::PrintVecElemStore(const std::string& vec,
316
                                 DataType t, int i,
317 318 319
                                 const std::string& value) {
  this->PrintIndent();
  stream << vec << ".s" << std::hex << i
320
         << " = " << value << ";\n" << std::dec;
321 322
}

323
std::string CodeGenC::GetVecLoad(
324
    DataType t, const VarNode* buffer, PrimExpr base) {
325
  return GetBufferRef(t, buffer, base);
326 327
}

328
void CodeGenC::PrintVecStore(const VarNode* buffer,
329
                             DataType t, PrimExpr base,
330
                             const std::string& value) {
331
  std::string ref = GetBufferRef(t, buffer, base);
332
  this->PrintIndent();
333
  stream << ref << " = " << value << ";\n";
334 335
}

336
// Wrap `value` in a C cast from type `from` to type `target`.
//
// \return `value` unchanged when both types are equal, otherwise
//         "((<target type>)<value>)".
std::string CodeGenC::CastFromTo(std::string value, DataType from, DataType target) {
  if (from == target) {
    return value;
  }
  std::ostringstream cast;
  cast << "((";
  this->PrintType(target, cast);
  cast << ")" << value << ")";
  return cast.str();
}

345
void CodeGenC::BindThreadIndex(const IterVar& iv) {
346
  LOG(FATAL) << "not implemented";
347 348
}

349
// Handle a storage synchronization call. This implementation emits
// nothing.
void CodeGenC::PrintStorageSync(const CallNode* op) {  // NOLINT(*)
}

// Print the qualifier for storage scope `scope`. This implementation only
// accepts "global" (anything else fails the CHECK) and prints nothing.
void CodeGenC::PrintStorageScope(const std::string& scope, std::ostream& os) {  // NOLINT(*)
  CHECK_EQ(scope, "global");
}

356
// Print the C spelling of scalar data type `t`.
//
//   handle            -> void*
//   float32 / float64 -> float / double
//   uint8..uint64     -> uint<bits>_t
//   uint1             -> int
//   int8..int64       -> int<bits>_t
//
// Vector types fail the CHECK; any other type aborts with a fatal log.
void CodeGenC::PrintType(DataType t, std::ostream& os) {  // NOLINT(*)
  CHECK_EQ(t.lanes(), 1)
      << "do not yet support vector types";
  if (t.is_handle()) {
    os << "void*";
    return;
  }
  const int bits = t.bits();
  const bool fixed_width = bits == 8 || bits == 16 || bits == 32 || bits == 64;
  if (t.is_float()) {
    if (bits == 32) {
      os << "float";
      return;
    }
    if (bits == 64) {
      os << "double";
      return;
    }
  } else if (t.is_uint()) {
    if (fixed_width) {
      os << "uint" << bits << "_t";
      return;
    }
    if (bits == 1) {
      os << "int";
      return;
    }
  } else if (t.is_int()) {
    if (fixed_width) {
      os << "int" << bits << "_t";
      return;
    }
  }
  LOG(FATAL) << "Cannot convert type " << t << " to C type";
}


387 388 389 390 391 392 393 394 395 396 397 398 399 400
// Print the C spelling of IR type `type`: primitive types go through the
// DataType overload, pointer types print their element type followed by
// '*', and the void type prints "void". Any other type aborts with a fatal
// log.
void CodeGenC::PrintType(const Type& type, std::ostream& os) {  // NOLINT(*)
  if (auto* prim = type.as<PrimTypeNode>()) {
    PrintType(prim->dtype, os);
    return;
  }
  if (auto* pointer = type.as<PointerTypeNode>()) {
    PrintType(pointer->element_type, os);
    os << '*';
    return;
  }
  if (IsVoidType(type)) {
    os << "void";
    return;
  }
  LOG(FATAL) << "Type " << type << " does not have a corresponding C Type";
}


401
inline void PrintConst(const IntImmNode* op, std::ostream& os, CodeGenC* p) { // NOLINT(*)
402
  if (op->dtype == DataType::Int(32)) {
403 404 405 406 407 408
    std::ostringstream temp;
    temp << op->value;
    p->MarkConst(temp.str());
    os << temp.str();
  } else {
    os << "(";
409
    p->PrintType(op->dtype, os);
410 411 412 413
    os << ")" << op->value;
  }
}

414 415 416

// Print an unsigned integer constant. A uint32 value is printed with a "U"
// suffix and its text is passed to MarkConst; any other type is printed as
// "(<type>)<value>".
inline void PrintUIntConst(DataType dtype, uint64_t val, std::ostream& os, CodeGenC* p) {  // NOLINT(*)
  if (dtype == DataType::UInt(32)) {
    std::ostringstream literal;
    literal << val << "U";
    const std::string text = literal.str();
    p->MarkConst(text);
    os << text;
    return;
  }
  os << "(";
  p->PrintType(dtype, os);
  os << ")" << val;
}

428
inline void PrintConst(const FloatImmNode* op, std::ostream& os, CodeGenC* p) { // NOLINT(*)
429
  switch (op->dtype.bits()) {
430 431
    case 64: case 32: {
      std::ostringstream temp;
432
      temp << std::scientific << op->value;
433
      if (op->dtype.bits() == 32) temp << 'f';
434 435 436 437 438 439
      p->MarkConst(temp.str());
      os << temp.str();
      break;
    }
    case 16: {
      os << '(';
440
      p->PrintType(op->dtype, os);
441
      os << ')' << std::scientific <<op->value << 'f';
442 443
      break;
    }
444
    default: LOG(FATAL) << "Bad bit-width for float: " << op->dtype << "\n";
445 446 447
  }
}

448
void CodeGenC::VisitExpr_(const IntImmNode* op, std::ostream& os) {  // NOLINT(*)
449 450
  PrintConst(op, os, this);
}
451

452
void CodeGenC::VisitExpr_(const FloatImmNode* op, std::ostream& os) { // NOLINT(*)
453 454
  PrintConst(op, os, this);
}
455
void CodeGenC::VisitExpr_(const StringImmNode* op, std::ostream& os) { // NOLINT(*)
456 457
  os << "\"" << op->value << "\"";
}
458 459 460

template<typename T>
inline void PrintBinaryExpr(const T* op,
461
                            const char* opstr,
462 463
                            std::ostream& os,  // NOLINT(*)
                            CodeGenC* p) {
464
  if (op->dtype.lanes() == 1) {
465 466 467 468 469 470 471 472 473 474 475 476 477 478
    if (isalpha(opstr[0])) {
      os << opstr << '(';
      p->PrintExpr(op->a, os);
      os << ", ";
      p->PrintExpr(op->b, os);
      os << ')';
    } else {
      os << '(';
      p->PrintExpr(op->a, os);
      os << ' ' << opstr << ' ';
      p->PrintExpr(op->b, os);
      os << ')';
    }
  } else {
479
    p->PrintVecBinaryOp(opstr, op->dtype, op->a, op->b, os);
480
  }
481 482
}

483
inline void PrintBinaryIntrinsic(const CallNode* op,
484
                                  const char* opstr,
485 486
                                  std::ostream& os,  // NOLINT(*)
                                  CodeGenC* p) {
487
  if (op->dtype.lanes() == 1) {
488 489 490 491 492 493 494
    CHECK_EQ(op->args.size(), 2U);
    os << '(';
    p->PrintExpr(op->args[0], os);
    os << opstr;
    p->PrintExpr(op->args[1], os);
    os << ')';
  } else {
495
    p->PrintVecBinaryOp(opstr, op->dtype, op->args[0], op->args[1], os);
496
  }
497
}
498
void CodeGenC::VisitExpr_(const CastNode* op, std::ostream& os) {  // NOLINT(*)
499 500
  std::stringstream value;
  this->PrintExpr(op->value, value);
501
  os << CastFromTo(value.str(), op->value.dtype(), op->dtype);
502
}
503
void CodeGenC::VisitExpr_(const VarNode* op, std::ostream& os) {  // NOLINT(*)
504 505
  os << GetVarID(op);
}
506
void CodeGenC::VisitExpr_(const AddNode* op, std::ostream& os) {  // NOLINT(*)
507 508
  PrintBinaryExpr(op, "+", os, this);
}
509
void CodeGenC::VisitExpr_(const SubNode* op, std::ostream& os) {  // NOLINT(*)
510 511
  PrintBinaryExpr(op, "-", os, this);
}
512
void CodeGenC::VisitExpr_(const MulNode* op, std::ostream& os) {  // NOLINT(*)
513 514
  PrintBinaryExpr(op, "*", os, this);
}
515
void CodeGenC::VisitExpr_(const DivNode* op, std::ostream& os) {  // NOLINT(*)
516 517
  PrintBinaryExpr(op, "/", os, this);
}
518
void CodeGenC::VisitExpr_(const ModNode* op, std::ostream& os) {  // NOLINT(*)
519 520
  PrintBinaryExpr(op, "%", os, this);
}
521
void CodeGenC::VisitExpr_(const MinNode* op, std::ostream& os) {  // NOLINT(*)
522 523
  PrintBinaryExpr(op, "min", os, this);
}
524
void CodeGenC::VisitExpr_(const MaxNode* op, std::ostream& os) {  // NOLINT(*)
525 526
  PrintBinaryExpr(op, "max", os, this);
}
527
void CodeGenC::VisitExpr_(const EQNode* op, std::ostream& os) {  // NOLINT(*)
528 529
  PrintBinaryExpr(op, "==", os, this);
}
530
void CodeGenC::VisitExpr_(const NENode* op, std::ostream& os) {  // NOLINT(*)
531 532
  PrintBinaryExpr(op, "!=", os, this);
}
533
void CodeGenC::VisitExpr_(const LTNode* op, std::ostream& os) {  // NOLINT(*)
534 535
  PrintBinaryExpr(op, "<", os, this);
}
536
void CodeGenC::VisitExpr_(const LENode* op, std::ostream& os) {  // NOLINT(*)
537 538
  PrintBinaryExpr(op, "<=", os, this);
}
539
void CodeGenC::VisitExpr_(const GTNode* op, std::ostream& os) {  // NOLINT(*)
540 541
  PrintBinaryExpr(op, ">", os, this);
}
542
void CodeGenC::VisitExpr_(const GENode* op, std::ostream& os) {  // NOLINT(*)
543 544
  PrintBinaryExpr(op, ">=", os, this);
}
545
void CodeGenC::VisitExpr_(const AndNode* op, std::ostream& os) {  // NOLINT(*)
546 547
  PrintBinaryExpr(op, "&&", os, this);
}
548
void CodeGenC::VisitExpr_(const OrNode* op, std::ostream& os) {  // NOLINT(*)
549 550
  PrintBinaryExpr(op, "||", os, this);
}
551
void CodeGenC::VisitExpr_(const NotNode* op, std::ostream& os) {  // NOLINT(*)
552 553 554
  os << '!';
  PrintExpr(op->a, os);
}
555

556 557 558
void CodeGenC::VisitExpr_(const CallNode* op, std::ostream& os) {  // NOLINT(*)
  if (op->call_type == CallNode::Extern ||
      op->call_type == CallNode::PureExtern) {
559 560 561 562 563 564 565 566
    os << op->name << "(";
    for (size_t i = 0; i < op->args.size(); i++) {
      this->PrintExpr(op->args[i], os);
      if (i < op->args.size() - 1) {
        os << ", ";
      }
    }
    os << ")";
567
  } else if (op->is_intrinsic(CallNode::bitwise_and)) {
568
    PrintBinaryIntrinsic(op, " & ", os, this);
569 570 571 572 573 574
  } else if (op->is_intrinsic(intrinsic::tvm_large_uint_imm)) {
    CHECK_EQ(op->args.size(), 2U);
    uint64_t low = static_cast<uint64_t>(Downcast<IntImm>(op->args[0])->value);
    uint64_t high = static_cast<uint64_t>(Downcast<IntImm>(op->args[1])->value);
    uint64_t val = (high << 32U) | low;
    PrintUIntConst(op->dtype, val, os, this);
575
  } else if (op->is_intrinsic(CallNode::bitwise_xor)) {
576
    PrintBinaryIntrinsic(op, " ^ ", os, this);
577
  } else if (op->is_intrinsic(CallNode::bitwise_or)) {
578
    PrintBinaryIntrinsic(op, " | ", os, this);
579
  } else if (op->is_intrinsic(CallNode::bitwise_not)) {
580 581
    CHECK_EQ(op->args.size(), 1U);
    os << "(~";
582
    this->PrintExpr(op->args[0], os);
583
    os << ')';
584
  } else if (op->is_intrinsic(CallNode::shift_left)) {
585
    PrintBinaryIntrinsic(op, " << ", os, this);
586
  } else if (op->is_intrinsic(CallNode::shift_right)) {
587
    PrintBinaryIntrinsic(op, " >> ", os, this);
588 589 590 591 592 593 594 595
  } else if (op->is_intrinsic(intrinsic::tvm_if_then_else)) {
    os << "(";
    PrintExpr(op->args[0], os);
    os << " ? ";
    PrintExpr(op->args[1], os);
    os << " : ";
    PrintExpr(op->args[2], os);
    os << ")";
596
  } else if (op->is_intrinsic(intrinsic::tvm_address_of)) {
597
    const LoadNode *l = op->args[0].as<LoadNode>();
598 599
    CHECK(op->args.size() == 1 && l);
    os << "((";
600
    this->PrintType(l->dtype.element_of(), os);
601
    os << " *)" << this->GetVarID(l->buffer_var.get())
602
       << " + ";
603
    this->PrintExpr(l->index, os);
604
    os << ')';
605
  } else if (op->is_intrinsic(intrinsic::tvm_struct_get)) {
606
    CHECK_EQ(op->args.size(), 3U);
607
    os << GetStructRef(
608
        op->dtype, op->args[0], op->args[1],
609
        op->args[2].as<IntImmNode>()->value);
610 611 612
  } else if (op->is_intrinsic(intrinsic::tvm_handle_is_null)) {
    CHECK_EQ(op->args.size(), 1U);
    os << "(";
613
    this->PrintExpr(op->args[0], os);
614
    os << " == NULL)";
615
  } else if (op->is_intrinsic(CallNode::reinterpret)) {
616 617
    // generate (*( TYPE *)(&(ARG)))
    os << "(*(";
618
    this->PrintType(op->dtype, os);
619 620 621
    os << " *)(&(";
    this->PrintExpr(op->args[0], os);
    os << ")))";
622
  } else if (op->is_intrinsic(CallNode::isnan)) {
623 624 625 626 627
    os << "(";
    this->PrintExpr(op->args[0], os);
    os << " != ";
    this->PrintExpr(op->args[0], os);
    os << ")";
628
  } else {
629 630
    if (op->call_type == CallNode::Intrinsic ||
        op->call_type == CallNode::PureIntrinsic) {
631
      LOG(FATAL) << "Unresolved intrinsic " << op->name
632
                 << " with return type " << op->dtype;
633 634
    } else {
      LOG(FATAL) << "Unresolved call type " << op->call_type;
635 636 637 638
    }
  }
}

639
void CodeGenC::PrintVecBinaryOp(
640
    const std::string& op, DataType t,
641
    PrimExpr lhs, PrimExpr rhs, std::ostream& os) {  // NOLINT(*)
642 643 644 645 646 647 648 649 650 651 652 653 654 655 656
  if (isalpha(op[0])) {
    os << op << "(";
    this->PrintExpr(lhs, os);
    os << ", ";
    this->PrintExpr(rhs, os);
    os << ")";
  } else {
    os <<"(";
    this->PrintExpr(lhs, os);
    os << ' ' << op << ' ';
    this->PrintExpr(rhs, os);
    os << ")";
  }
}

657
void CodeGenC::VisitExpr_(const LoadNode* op, std::ostream& os) {  // NOLINT(*)
658
  int lanes = op->dtype.lanes();
659
  // delcare type.
660 661
  if (op->dtype.lanes() == 1) {
    std::string ref = GetBufferRef(op->dtype, op->buffer_var.get(), op->index);
662
    HandleVolatileLoads(ref, op, os);
663
  } else {
664 665
    CHECK(is_one(op->predicate))
        << "predicated load is not supported";
666
    PrimExpr base;
667 668
    if (GetRamp1Base(op->index, op->dtype.lanes(), &base)) {
      std::string ref = GetVecLoad(op->dtype, op->buffer_var.get(), base);
669
      HandleVolatileLoads(ref, op, os);
670
    } else {
671 672 673 674
      // The assignment below introduces side-effect, and the resulting value cannot
      // be reused across multiple expression, thus a new scope is needed
      int vec_scope = BeginScope();

675 676 677
      // load seperately.
      std::string svalue = GetUniqueName("_");
      this->PrintIndent();
678
      this->PrintType(op->dtype, stream);
679
      stream << ' ' << svalue << ";\n";
680
      std::string sindex = SSAGetID(PrintExpr(op->index), op->index.dtype());
681
      std::string vid = GetVarID(op->buffer_var.get());
682
      DataType elem_type = op->dtype.element_of();
683 684 685 686
      for (int i = 0; i < lanes; ++i) {
        std::ostringstream value_temp;
        if (!HandleTypeMatch(op->buffer_var.get(), elem_type)) {
          value_temp << "((";
687
          if (op->buffer_var.get()->dtype.is_handle()) {
688 689 690 691 692 693
            auto it = alloc_storage_scope_.find(op->buffer_var.get());
            if (it != alloc_storage_scope_.end()) {
              PrintStorageScope(it->second, value_temp);
              value_temp << ' ';
            }
          }
694
          PrintType(elem_type, value_temp);
695 696 697 698 699
          value_temp << "*)" << vid << ')';
        } else {
          value_temp << vid;
        }
        value_temp << '[';
700
        PrintVecElemLoad(sindex, op->index.dtype(), i, value_temp);
701
        value_temp << ']';
702
        PrintVecElemStore(svalue, op->dtype, i, value_temp.str());
703
      }
704
      os << svalue;
705
      EndScope(vec_scope);
706 707 708 709
    }
  }
}

710
void CodeGenC::VisitStmt_(const StoreNode* op) {
711
  DataType t = op->value.dtype();
712 713
  if (t.lanes() == 1) {
    std::string value = this->PrintExpr(op->value);
714
    std::string ref  = this->GetBufferRef(t, op->buffer_var.get(), op->index);
715
    this->PrintIndent();
716
    stream << ref << " = " << value << ";\n";
717
  } else {
718 719
    CHECK(is_one(op->predicate))
        << "Predicated store is not supported";
720
    PrimExpr base;
721
    if (GetRamp1Base(op->index, t.lanes(), &base)) {
722 723 724
      std::string value = this->PrintExpr(op->value);
      this->PrintVecStore(op->buffer_var.get(), t, base, value);
    } else {
725 726 727 728
      // The assignment below introduces side-effect, and the resulting value cannot
      // be reused across multiple expression, thus a new scope is needed
      int vec_scope = BeginScope();

729
      // store elements seperately
730 731
      std::string index = SSAGetID(PrintExpr(op->index), op->index.dtype());
      std::string value = SSAGetID(PrintExpr(op->value), op->value.dtype());
732 733 734
      std::string vid = GetVarID(op->buffer_var.get());
      for (int i = 0; i < t.lanes(); ++i) {
        this->PrintIndent();
735
        DataType elem_type = t.element_of();
736 737
        if (!HandleTypeMatch(op->buffer_var.get(), elem_type)) {
          stream << "((";
738
          if (op->buffer_var.get()->dtype.is_handle()) {
739 740 741 742 743 744
            auto it = alloc_storage_scope_.find(op->buffer_var.get());
            if (it != alloc_storage_scope_.end()) {
              PrintStorageScope(it->second, stream);
              stream << ' ';
            }
          }
745 746 747 748 749 750
          PrintType(elem_type, stream);
          stream << "*)" << vid << ')';
        } else {
          stream << vid;
        }
        stream << '[';
751
        PrintVecElemLoad(index, op->index.dtype(), i, stream);
752
        stream << "] = ";
753
        PrintVecElemLoad(value, op->value.dtype(), i, stream);
754 755
        stream << ";\n";
      }
756
      EndScope(vec_scope);
757
    }
758
  }
759 760
}

761
void CodeGenC::VisitExpr_(const LetNode* op, std::ostream& os) {  // NOLINT(*)
762 763 764
  std::string value = PrintExpr(op->value);
  CHECK(!var_idmap_.count(op->var.get()));
  var_idmap_[op->var.get()] = value;
765
  os << PrintExpr(op->body);
766 767
}

768
void CodeGenC::VisitExpr_(const RampNode* op, std::ostream& os) {  // NOLINT(*)
769
  // constraint of current logic
770
  CHECK_EQ(op->base.dtype(), DataType::Int(32));
771 772 773 774 775 776 777
  os << "((int" << op->lanes << ")(";
  for (int i = 0; i < op->lanes; i++) {
    os << "(" << PrintExpr(op->base) << ")" << "+(" << PrintExpr(op->stride) << "*" << i <<")";
    if (i != op->lanes - 1)
      os << ", ";
  }
  os << "))";
778 779
}

780
void CodeGenC::VisitExpr_(const ShuffleNode* op, std::ostream& os) {
781 782 783
  LOG(FATAL) << "Shuffle: not supported ";
}

784
void CodeGenC::VisitExpr_(const BroadcastNode* op, std::ostream& os) {   // NOLINT(*)
785
  LOG(FATAL) << "Broadcast: not supported ";
786 787
}

788
void CodeGenC::VisitExpr_(const SelectNode* op, std::ostream& os) {  // NOLINT(*)
789 790 791 792 793 794 795
  os << "(";
  PrintExpr(op->condition, os);
  os << " ? ";
  PrintExpr(op->true_value, os);
  os << " : ";
  PrintExpr(op->false_value, os);
  os << ")";
796 797
}

798
void CodeGenC::VisitStmt_(const LetStmtNode* op) {
799 800 801 802 803 804
  std::string value = PrintExpr(op->value);
  if (print_ssa_form_) {
    CHECK(!var_idmap_.count(op->var.get()));
    var_idmap_[op->var.get()] = value;
  } else {
    PrintIndent();
805
    if (op->var.dtype() == DataType::Handle() &&
806 807 808 809 810 811 812 813
        handle_data_type_.count(op->var.get())) {
      PrintType(handle_data_type_.at(op->var.get()), stream);
      stream << "* "
             << AllocVarID(op->var.get())
             << " = (";
      PrintType(handle_data_type_.at(op->var.get()), stream);
      stream << "*)"  << value << ";\n";
    } else {
814
      PrintType(op->var.dtype(), this->stream);
815 816 817 818
      this->stream << ' '
                   << AllocVarID(op->var.get())
                   << " = " << value << ";\n";
    }
819 820 821 822
  }
  PrintStmt(op->body);
}

823
void CodeGenC::VisitStmt_(const AllocateNode* op) {
824
  CHECK(!is_zero(op->condition));
825 826 827 828 829 830
  std::string vid = AllocVarID(op->buffer_var.get());
  if (op->new_expr.defined()) {
    // Prefer global static allocation for the program
    CHECK_EQ(op->free_function, "nop");
    std::string new_data = PrintExpr(op->new_expr);
    this->PrintIndent();
831
    PrintType(op->dtype, stream);
832 833 834 835 836 837
    stream << "* "<< vid << '=' << new_data << ";\n";
  } else {
    this->PrintIndent();
    int32_t constant_size = op->constant_allocation_size();
    CHECK_GT(constant_size, 0)
        << "Can only handle constant size stack allocation for now";
838
    const VarNode* buffer = op->buffer_var.as<VarNode>();
839 840
    std::string scope = alloc_storage_scope_.at(buffer);
    PrintStorageScope(scope, stream);
841
    stream << ' ';
842
    PrintType(op->dtype, stream);
843
    stream << ' '<< vid << '['
844
           << constant_size << "];\n";
845
  }
846
  RegisterHandleType(op->buffer_var.get(), op->dtype);
847 848 849
  this->PrintStmt(op->body);
}

850
void CodeGenC::VisitStmt_(const AttrStmtNode* op) {
851
  if (op->attr_key == tir::attr::thread_extent) {
852
    IterVar iv = Downcast<IterVar>(op->node);
853
    if (iv->thread_tag.length() != 0) {
854
      if (!var_idmap_.count(iv->var.get())) {
855
        BindThreadIndex(iv);
856
      }
857
    }
858
  } else if (op->attr_key == tir::attr::storage_scope) {
859
    const VarNode* v = op->node.as<VarNode>();
860
    CHECK(v);
861
    alloc_storage_scope_[v] = op->value.as<StringImmNode>()->value;
862
  } else if (op->attr_key == tir::attr::volatile_scope) {
863
    const VarNode* v = op->node.as<VarNode>();
864 865
    CHECK(v);
    volatile_buf_.insert(v);
866 867 868 869
  }
  this->PrintStmt(op->body);
}

870
void CodeGenC::VisitStmt_(const AssertStmtNode* op) {
871 872
  std::string cond = PrintExpr(op->condition);
  PrintIndent();
873
  if (const auto* str = op->message.as<StringImmNode>()) {
874
    // GLOG style check
875
    stream << "CHECK(" << cond << ") << \"" << str->value << "\";\n";
876 877 878
  } else {
    stream << "assert(" << cond << ");\n";
  }
879
  this->PrintStmt(op->body);
880 881
}

882
void CodeGenC::VisitStmt_(const ForNode* op) {
883 884 885 886 887
  std::string extent = PrintExpr(op->extent);
  PrintIndent();
  std::string vid = AllocVarID(op->loop_var.get());
  CHECK(is_zero(op->min));
  stream << "for (";
888
  PrintType(op->loop_var.dtype(), stream);
889 890 891 892 893 894 895 896 897 898
  stream << ' ' << vid << " = 0; "
            << vid << " < " << extent
            << "; ++" << vid << ") {\n";
  int for_scope = BeginScope();
  PrintStmt(op->body);
  this->EndScope(for_scope);
  PrintIndent();
  stream << "}\n";
}

899
void CodeGenC::VisitStmt_(const IfThenElseNode* op) {
900 901
  std::string cond = PrintExpr(op->condition);
  PrintIndent();
902 903 904 905 906
  if (cond[0] == '(' && cond[cond.length() - 1] == ')') {
    stream << "if " << cond << " {\n";
  } else {
    stream << "if (" << cond << ") {\n";
  }
907 908 909 910 911 912 913 914 915 916 917 918 919 920 921
  int then_scope = BeginScope();
  PrintStmt(op->then_case);
  this->EndScope(then_scope);

  if (op->else_case.defined()) {
    PrintIndent();
    stream << "} else {\n";
    int else_scope = BeginScope();
    PrintStmt(op->else_case);
    this->EndScope(else_scope);
  }
  PrintIndent();
  stream << "}\n";
}

922 923 924 925
void CodeGenC::VisitStmt_(const SeqStmtNode* op) {
  for (Stmt stmt : op->seq) {
    PrintStmt(stmt);
  }
926 927
}

928
void CodeGenC::VisitStmt_(const EvaluateNode* op) {
929
  if (is_const(op->value)) return;
930
  const CallNode* call = op->value.as<CallNode>();
931 932 933 934 935 936 937
  if (call) {
    if (call->is_intrinsic(intrinsic::tvm_storage_sync)) {
      this->PrintStorageSync(call); return;
    } else if (call->is_intrinsic(intrinsic::tvm_struct_set)) {
      CHECK_EQ(call->args.size(), 4);
      std::string value = PrintExpr(call->args[3]);
      std::string ref = GetStructRef(
938
          call->args[3].dtype(),
939 940
          call->args[0],
          call->args[1],
941
          call->args[2].as<IntImmNode>()->value);
942 943 944 945
      this->PrintIndent();
      this->stream << ref << " = " << value << ";\n";
      return;
    }
946
  }
947
  std::string vid = this->PrintExpr(op->value);
948 949 950 951
  if (vid != "") {
    this->PrintIndent();
    this->stream << "(void)" << vid << ";\n";
  }
952 953
}

954
void CodeGenC::VisitStmt_(const ProducerConsumerNode* op) {
955 956
  PrintStmt(op->body);
}
957

958 959
}  // namespace codegen
}  // namespace tvm