/*!
 *  Copyright (c) 2017 by Contributors
 * \file codegen_llvm.cc
 */
#ifdef TVM_LLVM_VERSION
// Part of the code is adapted from Halide's CodeGen_LLVM

#include <tvm/runtime/device_api.h>
#include <tvm/runtime/c_runtime_api.h>
#include "codegen_llvm.h"
#include "codegen_cpu.h"
#include "../../pass/ir_util.h"
#include "../../arithmetic/compute_expr.h"

namespace tvm {
namespace codegen {

std::unique_ptr<CodeGenLLVM> CodeGenLLVM::Create(llvm::TargetMachine *tm) {
  std::string target = tm->getTarget().getName();
  std::string factory_name = "tvm.codegen.llvm.target_" + target;
  const PackedFunc* f = runtime::Registry::Get(factory_name);
  if (f != nullptr) {
    void* handle = (*f)();
    return std::unique_ptr<CodeGenLLVM>(static_cast<CodeGenLLVM*>(handle));
  } else {
    return std::unique_ptr<CodeGenLLVM>(new CodeGenCPU());
  }
}

void CodeGenLLVM::Init(const std::string& module_name,
                       llvm::TargetMachine* tm,
                       llvm::LLVMContext* ctx,
                       bool system_lib,
                       bool dynamic_lookup) {
  InitializeLLVM();
  ctx_ = ctx;
  builder_.reset(new IRBuilder(*ctx_));
  module_.reset(new llvm::Module(module_name, *ctx_));
  md_builder_.reset(new llvm::MDBuilder(*ctx_));
  // types
  t_void_ = llvm::Type::getVoidTy(*ctx_);
  t_void_p_ = llvm::Type::getInt8Ty(*ctx_)->getPointerTo();
  t_int_ = llvm::Type::getInt32Ty(*ctx_);
  t_char_ = llvm::Type::getInt8Ty(*ctx_);
  t_int8_ = llvm::Type::getInt8Ty(*ctx_);
  t_int16_ = llvm::Type::getInt16Ty(*ctx_);
  t_int32_ = llvm::Type::getInt32Ty(*ctx_);
  t_int64_ = llvm::Type::getInt64Ty(*ctx_);
  t_float64_ = llvm::Type::getDoubleTy(*ctx_);
  // meta data
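  // Branch-weight metadata (roughly 2^20 : 1) used to mark loop-continue and
  // likely if-then edges as hot so LLVM lays out the common path first.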
  md_very_likely_branch_ = md_builder_->createBranchWeights(1<<20, 1);
  md_tbaa_root_ = md_builder_->createTBAARoot("tvm-tbaa");
  md_tbaa_alias_set_ = md_builder_->createTBAANode("tvm-alias", md_tbaa_root_);
  this->InitTarget(tm);
}

void CodeGenLLVM::InitTarget(llvm::TargetMachine* tm) {
  module_->setTargetTriple(tm->getTargetTriple().str());
  module_->setDataLayout(tm->createDataLayout());
  data_layout_.reset(new llvm::DataLayout(module_.get()));
  target_machine_ = tm;
  if (native_vector_bits_ == 0) {
    const auto& arch = tm->getTargetTriple().getArch();
    if (arch == llvm::Triple::x86_64) {
      // for avx512
      native_vector_bits_ = 512;
    } else if (arch == llvm::Triple::x86) {
      native_vector_bits_ = 256;
    } else if (arch == llvm::Triple::arm || arch == llvm::Triple::aarch64) {
      native_vector_bits_ = 128;
    } else {
      native_vector_bits_ = 128;
      std::string arch_name = tm->getTargetTriple().getArchName();
      LOG(WARNING) << "Set native vector bits to be 128 for " << arch_name;
    }
  }
}

void CodeGenLLVM::AddFunction(const LoweredFunc& f) {
  this->AddFunctionInternal(f, false);
}

void CodeGenLLVM::InitFuncState() {
  var_map_.clear();
  alias_var_set_.clear();
  alloc_storage_info_.clear();
  volatile_buf_.clear();
  analyzer_.reset(new arith::Analyzer());
}

void CodeGenLLVM::AddFunctionInternal(const LoweredFunc& f, bool ret_void) {
  this->InitFuncState();
  std::vector<llvm::Type*> arg_types;
  is_restricted_ = f->is_restricted;
  for (Var arg : f->args) {
    Type t = arg.type();
    if (t.is_handle()) {
      auto it = f->handle_data_type.find(arg);
      if (it != f->handle_data_type.end()) {
        arg_types.push_back(LLVMType((*it).second.type())
                            ->getPointerTo(GetGlobalAddressSpace()));
      } else {
        arg_types.push_back(t_int8_->getPointerTo(GetGlobalAddressSpace()));
      }
      if (!is_restricted_) {
        alias_var_set_.insert(arg.get());
      }
    } else {
      arg_types.push_back(LLVMType(arg.type()));
    }
  }
  llvm::FunctionType* ftype = llvm::FunctionType::get(
      ret_void ? t_void_ : t_int_, arg_types, false);
  CHECK(module_->getFunction(f->name) == nullptr)
      << "Function " << f->name << " already exist in module";
  function_ = llvm::Function::Create(
      ftype, llvm::Function::ExternalLinkage,
      f->name, module_.get());
  function_->setCallingConv(llvm::CallingConv::C);
  function_->setDLLStorageClass(llvm::GlobalValue::DLLStorageClassTypes::DLLExportStorageClass);
  // set var map and align information
  auto arg_it = function_->arg_begin();
  for (size_t i = 0; i < f->args.size(); ++i, ++arg_it) {
    llvm::Argument* v = &(*arg_it);
    const Var& var = f->args[i];
    var_map_[var.get()] = v;
    if (is_restricted_) {
      if (var.type().is_handle() && !alias_var_set_.count(var.get())) {
        // Mark this handle argument as noalias.
#if TVM_LLVM_VERSION >= 50
        function_->addParamAttr(i, llvm::Attribute::NoAlias);
#else
        function_->setDoesNotAlias(i + 1);
#endif
      }
    }
  }
  llvm::BasicBlock* entry = llvm::BasicBlock::Create(*ctx_, "entry", function_);
  builder_->SetInsertPoint(entry);
  this->VisitStmt(f->body);
  if (ret_void) {
    builder_->CreateRetVoid();
  } else {
    builder_->CreateRet(ConstInt32(0));
  }
}

std::unique_ptr<llvm::Module> CodeGenLLVM::Finish() {
  this->AddStartupFunction();
  // link modules
  for (size_t i = 0; i < link_modules_.size(); ++i) {
    CHECK(!llvm::Linker::linkModules(*module_, std::move(link_modules_[i])))
        << "Failed to link modules";
  }
  link_modules_.clear();
  // optimize
  this->Optimize();
  return std::move(module_);
}


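// Import an auxiliary LLVM module, given either as a path to a .ll/.bc file
// or as inline IR text; its functions are marked always-inline and
// available_externally so they get folded into the generated code.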
void CodeGenLLVM::HandleImport(const std::string& code) {
  std::unique_ptr<llvm::Module> mlib;
  llvm::SMDiagnostic err;
  if (code.length() >= 3 &&
      (code.substr(code.length() - 3) == ".ll" ||
       code.substr(code.length() - 3) == ".bc")) {
    mlib = llvm::parseIRFile(code, err, *ctx_);
    if (mlib.get() == nullptr) {
      std::string msg = err.getMessage();
      LOG(FATAL) << "Fail to load bitcode file " << code << "\n"
                 << "line " << err.getLineNo() << ":" << msg;
    }
  } else {
    std::unique_ptr<llvm::MemoryBuffer> buf =
        llvm::MemoryBuffer::getMemBuffer(code);
    mlib = llvm::parseIR(*buf, err, *ctx_);
    if (mlib.get() == nullptr) {
      std::string msg = err.getMessage();
      LOG(FATAL) << "Fail to load llvm ir "
                 << "line " << err.getLineNo() << ":" << msg
                 << "\ncontent:\n"  << code;
    }
  }
  mlib->setTargetTriple(target_machine_->getTargetTriple().str());
  mlib->setDataLayout(target_machine_->createDataLayout());
  // mark all the functions as force inline
  for (llvm::Function &f : mlib->functions()) {
    f.removeFnAttr(llvm::Attribute::NoInline);
    f.addFnAttr(llvm::Attribute::AlwaysInline);
    f.setLinkage(llvm::GlobalValue::AvailableExternallyLinkage);
  }
  // add to linker libraries.
  this->AddLinkModule(std::move(mlib));
}

void CodeGenLLVM::AddLinkModule(std::unique_ptr<llvm::Module>&& mod) {
  link_modules_.emplace_back(std::move(mod));
}

void CodeGenLLVM::AddMainFunction(const std::string& entry_func_name) {
  LOG(FATAL) << "not implemented";
}

llvm::Value* CodeGenLLVM::GetThreadIndex(const IterVar& iv) {
  LOG(FATAL) << "not implemented";
  return nullptr;
}

llvm::Value* CodeGenLLVM::CreateStorageSync(const Call* op) {
  LOG(FATAL) << "not implemented";
  return nullptr;
}

class FPassManager : public llvm::legacy::FunctionPassManager {
 public:
  explicit FPassManager(llvm::Module* m)
      : llvm::legacy::FunctionPassManager(m) {}
  // override add to allow messaging
  void add(llvm::Pass* p) final {
    llvm::legacy::FunctionPassManager::add(p);
  }
};

class MPassManager : public llvm::legacy::PassManager {
 public:
  // override add to allow messaging
  void add(llvm::Pass* p) final {
    llvm::legacy::PassManager::add(p);
  }
};

void CodeGenLLVM::InitPassManagerBuilder(llvm::PassManagerBuilder* builder) {
}

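// Run LLVM's standard -O3 pipeline (inlining, loop and SLP vectorization)
// through the legacy function- and module-level pass managers.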
void CodeGenLLVM::Optimize() {
  // pass manager
  FPassManager fpass(module_.get());
  MPassManager mpass;
  mpass.add(llvm::createTargetTransformInfoWrapperPass(
              target_machine_ ? target_machine_->getTargetIRAnalysis() :
                                llvm::TargetIRAnalysis()));
  fpass.add(llvm::createTargetTransformInfoWrapperPass(
              target_machine_ ? target_machine_->getTargetIRAnalysis() :
              llvm::TargetIRAnalysis()));

  // place optimization pass
  llvm::PassManagerBuilder builder;
  builder.OptLevel = 3;

#if TVM_LLVM_VERSION >= 50
  builder.Inliner = llvm::createFunctionInliningPass(builder.OptLevel, 0, false);
#else
  builder.Inliner = llvm::createFunctionInliningPass(builder.OptLevel, 0);
#endif
  builder.LoopVectorize = true;
  builder.SLPVectorize = true;
  this->InitPassManagerBuilder(&builder);

#if TVM_LLVM_VERSION >= 50
  target_machine_->adjustPassManager(builder);
#endif

  builder.populateFunctionPassManager(fpass);
  builder.populateModulePassManager(mpass);

  fpass.doInitialization();
  for (auto it = module_->begin(); it != module_->end(); ++it) {
    fpass.run(*it);
  }
  fpass.doFinalization();
  mpass.run(*module_);
}

int CodeGenLLVM::NativeVectorBits(const runtime::StorageScope& storage_scope) const {
  return native_vector_bits_;
}

unsigned CodeGenLLVM::GetGlobalAddressSpace() {
  return 0;
}

llvm::Type* CodeGenLLVM::LLVMType(const Type& t) const {
  if (t.is_handle()) {
    CHECK_EQ(t.lanes(), 1);
    return t_void_p_;
  }
  llvm::Type* etype = nullptr;
  if (t.is_int() || t.is_uint()) {
    etype = llvm::Type::getIntNTy(*ctx_, t.bits());
  } else if (t.is_float()) {
    switch (t.bits()) {
      case 16: etype = llvm::Type::getHalfTy(*ctx_); break;
      case 32: etype = llvm::Type::getFloatTy(*ctx_); break;
      case 64: etype = llvm::Type::getDoubleTy(*ctx_); break;
      default: LOG(FATAL) << "do not support " << t;
    }
  }
  if (t.lanes() != 1) {
    return llvm::VectorType::get(etype, t.lanes());
  } else {
    return etype;
  }
}

// Add TBAA alias information for a load or store.
//
// Use a binary-tree-shaped type hierarchy to declare the information
// and allow aliasing to be distinguished across nodes.
//
// This trick comes from Halide's CodeGen_LLVM.
//
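// As an illustration (names are schematic), the metadata chain for a
// constant-stride access looks roughly like
//   tvm-tbaa -> <buffer address> -> <element type> -> <buf>.w1024.b0 -> ... -> <buf>.w8.b16
// and two accesses may alias only if one tag is an ancestor of the other.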
void CodeGenLLVM::AddAliasInfo(llvm::Instruction* inst,
                               const Variable* buffer,
                               Expr index,
                               Type type) {
  if (alias_var_set_.count(buffer) != 0) {
    // Mark all possibly aliased pointers as the same type.
    llvm::MDNode* meta = md_tbaa_alias_set_;
    inst->setMetadata(
        "tbaa",
        md_builder_->createTBAAStructTagNode(meta, meta, 0));
    return;
  }
  int base = 0, width = 0;
  // create meta-data for alias analysis
  // Use a group of binary tree ranges of memory banks.
  if (index.defined()) {
    const Ramp* ramp = index.as<Ramp>();
    if (ramp) {
      int base, stride;
      if (arith::GetConstInt(ramp->base, &base) &&
          arith::GetConstInt(ramp->stride, &stride)) {
        int xwidth = ramp->lanes * stride;
        width = 1;
        while (width < xwidth) {
          width *= 2;
        }
        while (base % width) {
          base -= base % width;
          width *= 2;
        }
      }
    } else {
      if (arith::GetConstInt(index, &base)) width = 1;
    }
  }
  llvm::MDNode* meta = md_tbaa_root_;
  std::ostringstream buffer_addr, buffer_type;
  buffer_addr << buffer;
  meta = md_builder_->createTBAAScalarTypeNode(buffer_addr.str(), meta);
  buffer_type << type.element_of();
  meta = md_builder_->createTBAAScalarTypeNode(buffer_type.str(), meta);
  // create a tree-shape access structure.
  if (width != 0) {
    for (int w = 1024; w >= width; w /= 2) {
      int b = (base / w) * w;
      std::stringstream os;
      os << buffer << ".w" << w << ".b" << b;
      meta = md_builder_->createTBAAScalarTypeNode(os.str(), meta);
    }
  }
  inst->setMetadata(
      "tbaa",
      md_builder_->createTBAAStructTagNode(meta, meta, 0));
}

void CodeGenLLVM::GetAlignment(Type t,
                               const Variable* buf_var,
                               const Expr& index,
                               int* p_alignment,
                               int* p_native_bits) {
  int max_align_bits = t.bits();
  auto it = alloc_storage_info_.find(buf_var);
  if (it != alloc_storage_info_.end()) {
    const StorageInfo& info = it->second;
    *p_native_bits = NativeVectorBits(info.scope);
    max_align_bits = info.alignment * 8;
  } else {
    *p_native_bits = native_vector_bits_;
  }

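  // The modular set describes index as coeff * x + base. For example (an
  // illustrative case), index = 4*i + 2 on a 32-bit element type gives
  // coeff = 4, base = 2; one halving step below doubles align_bits to 64,
  // i.e. 8-byte alignment, as long as max_align_bits permits it.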
  arith::ModularSet me = analyzer_->modular_set(index);
  int64_t base = me->base;
  int64_t coeff = me->coeff;

  int align_bits = t.bits();
  while (align_bits < max_align_bits &&
         base % 2  == 0 &&
         coeff % 2 == 0) {
    base =  base / 2;
    coeff =  coeff / 2;
    align_bits *= 2;
  }
  if (align_bits < 8) {
    align_bits = 8;
  }
  *p_alignment = align_bits / 8;
}

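// Broadcast a scalar into a vector using the standard LLVM idiom: insert the
// value into lane 0 of an undef vector, then shufflevector with an all-zero
// mask to splat it across all lanes.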
llvm::Value* CodeGenLLVM::CreateBroadcast(llvm::Value* value, int lanes) {
  llvm::Constant* undef = llvm::UndefValue::get(
      llvm::VectorType::get(value->getType(), lanes));
  llvm::Constant* zero = ConstInt32(0);
  value = builder_->CreateInsertElement(undef, value, zero);
  llvm::Constant* mask = llvm::ConstantVector::getSplat(lanes, zero);
  return builder_->CreateShuffleVector(value, undef, mask);
}

llvm::Value* CodeGenLLVM::CreateVecSlice(llvm::Value* vec, int begin, int extent) {
  int num_elems = static_cast<int>(vec->getType()->getVectorNumElements());
  if (extent == num_elems && begin == 0) return vec;
  CHECK_LE(begin + extent, num_elems);
  std::vector<unsigned> indices;
  for (int i = 0; i < extent; ++i) {
    indices.push_back(begin + i);
  }
  return builder_->CreateShuffleVector(vec, vec, indices);
}

llvm::Value* CodeGenLLVM::CreateVecFlip(llvm::Value* vec) {
  int num_elems = static_cast<int>(vec->getType()->getVectorNumElements());
  std::vector<unsigned> indices;
  for (int i = 0; i < num_elems; ++i) {
    indices.push_back(num_elems - i - 1);
  }
  return builder_->CreateShuffleVector(vec, vec, indices);
}

llvm::Value* CodeGenLLVM::CreateVecPad(llvm::Value* vec, int target_lanes) {
  llvm::Value* mask = llvm::UndefValue::get(LLVMType(Int(32, target_lanes)));
  int num_elems = static_cast<int>(vec->getType()->getVectorNumElements());
  if (num_elems == target_lanes) return vec;
  CHECK_LT(num_elems, target_lanes);
  for (int i = 0; i < num_elems; ++i) {
    mask = builder_->CreateInsertElement(mask, ConstInt32(i), ConstInt32(i));
  }
  return builder_->CreateShuffleVector(vec, vec, mask);
}

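// Concatenate vectors with a tree-shaped reduction: each round pads the
// shorter vector of a pair, shuffles the pair into one vector, and halves the
// number of vectors; a final slice trims the result to the exact lane count.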
llvm::Value* CodeGenLLVM::CreateVecConcat(std::vector<llvm::Value*> vecs) {
  // concat vector, tree shape reduction
  int total_lanes = 0;
  for (llvm::Value* v : vecs) {
    total_lanes += static_cast<int>(
        v->getType()->getVectorNumElements());
  }
  while (vecs.size() > 1) {
    for (size_t i = 0; i < vecs.size(); i+=2) {
      if (i + 1 >= vecs.size()) {
        vecs[i / 2] = vecs[i]; continue;
      }
      llvm::Value* lhs = vecs[i];
      llvm::Value* rhs = vecs[i + 1];
      int lanes = static_cast<int>(std::max(
          lhs->getType()->getVectorNumElements(),
          rhs->getType()->getVectorNumElements()));
      lhs = CreateVecPad(lhs, lanes);
      rhs = CreateVecPad(rhs, lanes);
      std::vector<unsigned> mask;
      for (int i = 0; i < lanes * 2; ++i) {
        mask.push_back(i);
      }
      vecs[i / 2] = builder_->CreateShuffleVector(lhs, rhs, mask);
    }
    vecs.resize((vecs.size() + 1) / 2);
  }
  return CreateVecSlice(vecs[0], 0, total_lanes);
}


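// Emit a canonical counted loop with three blocks:
//   for_begin: PHI for the loop variable and the continue condition
//   for_body:  loop body, increment, branch back to for_begin
//   for_end:   continuation after the loop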
void CodeGenLLVM::CreateSerialFor(llvm::Value* begin,
                                  llvm::Value* end,
                                  llvm::Value* stride,
                                  const VarExpr& loop_var,
                                  const Stmt& body) {
  using llvm::BasicBlock;
  BasicBlock* pre_block = builder_->GetInsertBlock();
  BasicBlock* for_begin = BasicBlock::Create(
      *ctx_, "for_begin", function_);
  BasicBlock* for_body = BasicBlock::Create(
      *ctx_, "for_body", function_);
  BasicBlock* for_end = BasicBlock::Create(
      *ctx_, "for_end", function_);
  builder_->CreateBr(for_begin);
  builder_->SetInsertPoint(for_begin);
  llvm::PHINode* loop_value = builder_->CreatePHI(begin->getType(), 2);
  loop_value->addIncoming(begin, pre_block);
  CHECK(!var_map_.count(loop_var.get()));
  var_map_[loop_var.get()] = loop_value;
  builder_->CreateCondBr(CreateLT(loop_var.type(), loop_value, end),
                         for_body, for_end, md_very_likely_branch_);
  builder_->SetInsertPoint(for_body);
  this->VisitStmt(body);
  var_map_.erase(loop_var.get());
  llvm::Value* loop_next = CreateAdd(loop_var.type(), loop_value, stride);
  loop_value->addIncoming(loop_next, builder_->GetInsertBlock());
  builder_->CreateBr(for_begin);
  builder_->SetInsertPoint(for_end);
}

// cast operator
llvm::Value* CodeGenLLVM::CreateCast(Type from, Type to, llvm::Value* value) {
  llvm::Type * target = LLVMType(to);
  if (value->getType() == target) return value;
  if (to.is_handle()) {
    return builder_->CreateBitCast(value, target);
  } else if (!from.is_float() && !to.is_float()) {
    return builder_->CreateIntCast(value, target, from.is_int());
  } else if (from.is_float() && to.is_int()) {
    return builder_->CreateFPToSI(value, target);
  } else if (from.is_float() && to.is_uint()) {
    if (to.bits() < 8) {
      value = builder_->CreateFPToUI(value, LLVMType(to.with_bits(8)));
      return builder_->CreateIntCast(value, target, false);
    } else {
      return builder_->CreateFPToUI(value, target);
    }
  } else if (from.is_int() && to.is_float()) {
    return builder_->CreateSIToFP(value, target);
  } else if (from.is_uint() && to.is_float()) {
    return builder_->CreateUIToFP(value, target);
  } else {
    CHECK(from.is_float() && to.is_float());
    return builder_->CreateFPCast(value, target);
  }
}

llvm::Value* CodeGenLLVM::GetConstString(const std::string& str) {
  auto it = str_map_.find(str);
  if (it != str_map_.end()) return it->second;
  llvm::Type* type = llvm::ArrayType::get(t_char_, str.length() + 1);
  llvm::GlobalVariable *global = new llvm::GlobalVariable(
      *module_, type, true, llvm::GlobalValue::PrivateLinkage, 0, ".str");
  global->setAlignment(1);
  global->setInitializer(llvm::ConstantDataArray::getString(*ctx_, str));
  llvm::Constant* zero = ConstInt32(0);
  llvm::Constant* indices[] = {zero, zero};
  llvm::Constant* ptr = llvm::ConstantExpr::getGetElementPtr(
      type, global, indices);
  str_map_[str] = ptr;
  return ptr;
}

llvm::Value* CodeGenLLVM::CreateBufferPtr(
    Type t, llvm::Value* buffer, llvm::Value* index) {
  CHECK_EQ(t.lanes(), 1);
  llvm::PointerType* btype = llvm::dyn_cast<llvm::PointerType>(buffer->getType());
  CHECK(btype != nullptr);
  llvm::PointerType* ptype = LLVMType(t)->getPointerTo(btype->getAddressSpace());
  if (btype != ptype) {
    buffer = builder_->CreatePointerCast(buffer, ptype);
  }
  return builder_->CreateInBoundsGEP(buffer, index);
}

llvm::Value* CodeGenLLVM::CreateBufferVecPtr(
    Type t, llvm::Value* buffer, llvm::Value* index) {
  CHECK_GT(t.lanes(), 1);
  llvm::PointerType* btype = llvm::dyn_cast<llvm::PointerType>(buffer->getType());
  CHECK(btype != nullptr);
  llvm::PointerType* ptype = LLVMType(t)->getPointerTo(btype->getAddressSpace());
  if (btype != ptype) {
    buffer = builder_->CreatePointerCast(buffer, ptype);
  }
  return builder_->CreateInBoundsGEP(buffer, index);
}

llvm::Value* CodeGenLLVM::GetVarValue(const Variable* v) const {
  auto it = var_map_.find(v);
  CHECK(it != var_map_.end()) << "cannot find variable " << v->name_hint;
  return it->second;
}

llvm::Value* CodeGenLLVM::CreateCallExtern(const Call* op) {
  std::vector<llvm::Value*> arg_value;
  std::vector<llvm::Type*> arg_type;
  for (size_t i = 0; i < op->args.size(); ++i) {
    arg_value.push_back(MakeValue(op->args[i]));
    arg_type.push_back(arg_value.back()->getType());
  }
  llvm::FunctionType* ftype = llvm::FunctionType::get(
      LLVMType(op->type), arg_type, false);
  llvm::Function* f = module_->getFunction(op->name);
  if (f == nullptr) {
    f = llvm::Function::Create(
        ftype, llvm::Function::ExternalLinkage,
        op->name, module_.get());
  }
  llvm::CallInst* call = builder_->CreateCall(f, arg_value);
  return call;
}

llvm::Value* CodeGenLLVM::CreateIntrinsic(const Call* op) {
  if (op->is_intrinsic("llvm_intrin")) {
    CHECK_GE(op->args.size(), 2U);
    llvm::Intrinsic::ID id = static_cast<llvm::Intrinsic::ID>(
        op->args[0].as<UIntImm>()->value);
    uint64_t num_signature = op->args[1].as<UIntImm>()->value;
    std::vector<llvm::Value*> arg_value;
    std::vector<llvm::Type*> sig_type;
    for (size_t i = 2; i < op->args.size(); ++i) {
      arg_value.push_back(MakeValue(op->args[i]));
      if (i - 2 < num_signature) {
        sig_type.push_back(arg_value.back()->getType());
      }
    }
    llvm::Type *return_type = LLVMType(op->type);
    if (sig_type.size() > 0 && return_type != sig_type[0]) {
      sig_type.insert(sig_type.begin(), return_type);
    }
    llvm::Function* f = llvm::Intrinsic::getDeclaration(
        module_.get(), id, sig_type);
    return builder_->CreateCall(f, arg_value);
  } else if (op->is_intrinsic(Call::bitwise_and)) {
    return builder_->CreateAnd(MakeValue(op->args[0]), MakeValue(op->args[1]));
  } else if (op->is_intrinsic(Call::bitwise_or)) {
    return builder_->CreateOr(MakeValue(op->args[0]), MakeValue(op->args[1]));
  } else if (op->is_intrinsic(Call::bitwise_not)) {
    return builder_->CreateNot(MakeValue(op->args[0]));
  } else if (op->is_intrinsic(Call::bitwise_xor)) {
    return builder_->CreateXor(MakeValue(op->args[0]), MakeValue(op->args[1]));
  } else if (op->is_intrinsic(Call::shift_left)) {
    return builder_->CreateShl(MakeValue(op->args[0]), MakeValue(op->args[1]));
  } else if (op->is_intrinsic(Call::shift_right)) {
    if (op->args[0].type().is_int()) {
      return builder_->CreateAShr(MakeValue(op->args[0]), MakeValue(op->args[1]));
    } else {
      return builder_->CreateLShr(MakeValue(op->args[0]), MakeValue(op->args[1]));
    }
  } else if (op->is_intrinsic(intrinsic::tvm_storage_sync)) {
    return CreateStorageSync(op);
  } else if (op->is_intrinsic(intrinsic::tvm_address_of)) {
    const Load *l = op->args[0].as<Load>();
    CHECK(op->args.size() == 1 && l);
    const Ramp *r = l->index.as<Ramp>();
    llvm::Value* ptr;
    unsigned addrspace;
    if (!r) {
        ptr = CreateBufferPtr(
          l->type, MakeValue(l->buffer_var), MakeValue(l->index));
        addrspace = llvm::dyn_cast<llvm::PointerType>(
          ptr->getType())->getAddressSpace();
    } else {
        Expr index = r->base / make_const(Int(32), r->lanes);
        ptr = CreateBufferVecPtr(
          l->type, MakeValue(l->buffer_var), MakeValue(index));
        addrspace = llvm::dyn_cast<llvm::PointerType>(
          ptr->getType())->getAddressSpace();
    }
    return builder_->CreatePointerCast(ptr, t_char_->getPointerTo(addrspace));
  } else if (op->is_intrinsic(Call::reinterpret) && is_zero(op->args[0])) {
    return llvm::Constant::getNullValue(t_void_p_);
  } else if (op->is_intrinsic(intrinsic::tvm_handle_is_null)) {
    return builder_->CreateIsNull(MakeValue(op->args[0]));
  } else if (op->is_intrinsic(intrinsic::tvm_if_then_else)) {
    CHECK_EQ(op->args[0].type().lanes(), 1)
        << "if_then_else can only take scalar condition";
    using llvm::BasicBlock;
    BasicBlock* then_block = BasicBlock::Create(
        *ctx_, "if_then", function_);
    BasicBlock* else_block = BasicBlock::Create(
        *ctx_, "if_else", function_);
    BasicBlock* end_block = BasicBlock::Create(
        *ctx_, "if_end", function_);
    builder_->CreateCondBr(MakeValue(op->args[0]), then_block, else_block);
    builder_->SetInsertPoint(then_block);
    llvm::Value* then_value = MakeValue(op->args[1]);
    BasicBlock* then_value_block = builder_->GetInsertBlock();
    builder_->CreateBr(end_block);
    builder_->SetInsertPoint(else_block);
    llvm::Value* else_value = MakeValue(op->args[2]);
    BasicBlock* else_value_block = builder_->GetInsertBlock();
    builder_->CreateBr(end_block);
    builder_->SetInsertPoint(end_block);
    llvm::PHINode* value = builder_->CreatePHI(then_value->getType(), 2);
    value->addIncoming(then_value, then_value_block);
    value->addIncoming(else_value, else_value_block);
    return value;
  } else if (op->is_intrinsic(Call::reinterpret)) {
    llvm::Type * target = LLVMType(op->type);
    return builder_->CreateBitCast(MakeValue(op->args[0]), target);
  } else if (op->is_intrinsic("vectorlow")) {
    llvm::Value *v = MakeValue(op->args[0]);
    int l = v->getType()->getVectorNumElements();
    return CreateVecSlice(v, 0, l/2);
  } else if (op->is_intrinsic("vectorhigh")) {
    llvm::Value *v = MakeValue(op->args[0]);
    int l = v->getType()->getVectorNumElements();
    return CreateVecSlice(v, l/2, l/2);
  } else if (op->is_intrinsic("vectorcombine")) {
    llvm::Value *v0 = MakeValue(op->args[0]);
    llvm::Value *v1 = MakeValue(op->args[1]);
    int num_elems = static_cast<int>(v0->getType()->getVectorNumElements()) * 2;
    std::vector<unsigned> indices;
    for (int i = 0; i < num_elems; ++i) {
      indices.push_back(i);
    }
    return builder_->CreateShuffleVector(v0, v1, indices);
  } else {
    LOG(FATAL) << "unknown intrinsic " << op->name;
    return nullptr;
  }
}

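// Apply f to each lane of a vector expression: a Ramp index is expanded
// symbolically as base + i * stride, while any other expression is evaluated
// once as a vector and each lane extracted.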
void CodeGenLLVM::Scalarize(const Expr& e,
                            std::function<void(int i, llvm::Value* v)> f) {
  if (const Ramp* ramp = e.as<Ramp>()) {
    for (int i = 0; i < ramp->type.lanes(); ++i) {
      Expr offset = arith::ComputeExpr<Add>(
          ramp->base,
          arith::ComputeExpr<Mul>(ramp->stride, i));
      f(i, MakeValue(offset));
    }
  } else {
    llvm::Value* value = MakeValue(e);
    for (int i = 0; i < e.type().lanes(); ++i) {
      f(i, builder_->CreateExtractElement(value, i));
    }
  }
}


// Visitors
llvm::Value* CodeGenLLVM::VisitExpr_(const Variable* op) {
  return GetVarValue(op);
}

llvm::Value* CodeGenLLVM::VisitExpr_(const Cast* op) {
  return CreateCast(op->value.type(), op->type, MakeValue(op->value));
}

llvm::Value* CodeGenLLVM::VisitExpr_(const IntImm* op) {
  return llvm::ConstantInt::getSigned(LLVMType(op->type), op->value);
}

llvm::Value* CodeGenLLVM::VisitExpr_(const UIntImm* op) {
  return llvm::ConstantInt::get(LLVMType(op->type), op->value);
}

llvm::Value* CodeGenLLVM::VisitExpr_(const FloatImm* op) {
  return llvm::ConstantFP::get(LLVMType(op->type), op->value);
}

llvm::Value* CodeGenLLVM::VisitExpr_(const StringImm* op) {
  return GetConstString(op->value);
}

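// Note: only integer operations of 32 bits or wider are tagged nsw/nuw;
// narrower integer types fall back to plain (wrapping) instructions.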
#define DEFINE_CODEGEN_BINARY_OP(Op)                                    \
  llvm::Value* CodeGenLLVM::Create ## Op(                               \
      Type t, llvm::Value* a, llvm::Value *b) {                         \
    if (t.is_int()) {                                                   \
      if (t.bits() >= 32) {                                             \
        return builder_->CreateNSW ## Op (a, b);                        \
      } else {                                                          \
        return builder_->Create ## Op (a, b);                           \
      }                                                                 \
    } else if (t.is_uint()) {                                           \
      if (t.bits() >= 32) {                                             \
        return builder_->CreateNUW ## Op (a, b);                        \
      } else {                                                          \
        return builder_->Create ## Op (a, b);                           \
      }                                                                 \
    } else {                                                            \
      CHECK(t.is_float());                                              \
      return builder_->CreateF ## Op (a, b);                            \
    }                                                                   \
  }                                                                     \
  llvm::Value* CodeGenLLVM::VisitExpr_(const Op* op) {                  \
    return Create ## Op(op->type, MakeValue(op->a), MakeValue(op->b));  \
  }

DEFINE_CODEGEN_BINARY_OP(Add);
DEFINE_CODEGEN_BINARY_OP(Sub);
DEFINE_CODEGEN_BINARY_OP(Mul);

#define DEFINE_CODEGEN_CMP_OP(Op)                                       \
  llvm::Value* CodeGenLLVM::Create ## Op(                               \
      Type t, llvm::Value* a, llvm::Value* b) {                         \
    if (t.is_int()) {                                                   \
      return builder_->CreateICmpS ## Op (a, b);                        \
    } else if (t.is_uint()) {                                           \
      return builder_->CreateICmpU ## Op (a, b);                        \
    } else {                                                            \
      CHECK(t.is_float());                                              \
      return builder_->CreateFCmpO ## Op (a, b);                        \
    }                                                                   \
}                                                                       \
  llvm::Value* CodeGenLLVM::VisitExpr_(const Op* op) {                  \
    return Create ## Op(op->a.type(), MakeValue(op->a), MakeValue(op->b)); \
  }

DEFINE_CODEGEN_CMP_OP(LT);
DEFINE_CODEGEN_CMP_OP(LE);
DEFINE_CODEGEN_CMP_OP(GT);
DEFINE_CODEGEN_CMP_OP(GE);

llvm::Value* CodeGenLLVM::VisitExpr_(const Div* op) {
  llvm::Value* a = MakeValue(op->a);
  llvm::Value* b = MakeValue(op->b);
  if (op->type.is_int()) {
    return builder_->CreateSDiv(a, b);
  } else if (op->type.is_uint()) {
    return builder_->CreateUDiv(a, b);
  } else {
    CHECK(op->type.is_float());
    return builder_->CreateFDiv(a, b);
  }
}

llvm::Value* CodeGenLLVM::VisitExpr_(const Mod* op) {
  llvm::Value* a = MakeValue(op->a);
  llvm::Value* b = MakeValue(op->b);
  if (op->type.is_int()) {
    return builder_->CreateSRem(a, b);
  } else if (op->type.is_uint()) {
    return builder_->CreateURem(a, b);
  } else {
    CHECK(op->type.is_float());
    return builder_->CreateFRem(a, b);
  }
}

llvm::Value* CodeGenLLVM::VisitExpr_(const Min* op) {
  llvm::Value* a = MakeValue(op->a);
  llvm::Value* b = MakeValue(op->b);
  return builder_->CreateSelect(CreateLT(op->a.type(), a, b), a, b);
}

llvm::Value* CodeGenLLVM::VisitExpr_(const Max* op) {
  llvm::Value* a = MakeValue(op->a);
  llvm::Value* b = MakeValue(op->b);
  return builder_->CreateSelect(CreateGT(op->a.type(), a, b), a, b);
}

llvm::Value* CodeGenLLVM::VisitExpr_(const EQ* op) {
  llvm::Value* a = MakeValue(op->a);
  llvm::Value* b = MakeValue(op->b);
  if (op->a.type().is_int() || op->a.type().is_uint()) {
    return builder_->CreateICmpEQ(a, b);
  } else {
    return builder_->CreateFCmpOEQ(a, b);
  }
}

llvm::Value* CodeGenLLVM::VisitExpr_(const NE* op) {
  llvm::Value* a = MakeValue(op->a);
  llvm::Value* b = MakeValue(op->b);
  if (op->a.type().is_int() || op->a.type().is_uint()) {
    return builder_->CreateICmpNE(a, b);
  } else {
    return builder_->CreateFCmpONE(a, b);
  }
}

llvm::Value* CodeGenLLVM::VisitExpr_(const And* op) {
  return builder_->CreateAnd(MakeValue(op->a), MakeValue(op->b));
}

llvm::Value* CodeGenLLVM::VisitExpr_(const Or* op) {
  return builder_->CreateOr(MakeValue(op->a), MakeValue(op->b));
}

llvm::Value* CodeGenLLVM::VisitExpr_(const Not* op) {
  return builder_->CreateNot(MakeValue(op->a));
}

llvm::Value* CodeGenLLVM::VisitExpr_(const Select* op) {
  return builder_->CreateSelect(
      MakeValue(op->condition),
      MakeValue(op->true_value),
      MakeValue(op->false_value));
}

llvm::Value* CodeGenLLVM::VisitExpr_(const Let* op) {
  CHECK(!var_map_.count(op->var.get()));
  var_map_[op->var.get()] = MakeValue(op->value);
  analyzer_->Bind(op->var, op->value);
  return MakeValue(op->body);
}

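// Loads: a scalar access or a unit-stride Ramp index becomes a single aligned
// load (the latter via a pointer cast to the vector type); any other vector
// index falls through to the scalarized per-lane path at the end.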
llvm::Value* CodeGenLLVM::VisitExpr_(const Load* op) {
  Type t = op->type;
  bool is_volatile = volatile_buf_.count(op->buffer_var.get());
  llvm::Value* buffer = MakeValue(op->buffer_var);
  llvm::Value* index = MakeValue(op->index);

  if (t.lanes() == 1) {
    int alignment, native_bits;
    GetAlignment(t, op->buffer_var.get(), op->index, &alignment, &native_bits);
    llvm::Value* ptr = CreateBufferPtr(t, buffer, index);
    llvm::LoadInst* load = builder_->CreateAlignedLoad(ptr, alignment, is_volatile);
    AddAliasInfo(load, op->buffer_var.get(), op->index, t);
    return load;
  } else {
    // vector load
    unsigned addrspace = llvm::dyn_cast<llvm::PointerType>(
      buffer->getType())->getAddressSpace();
    if (const Ramp* ramp = op->index.as<Ramp>()) {
      if (is_one(ramp->stride)) {
        int alignment, native_bits;
        GetAlignment(t, op->buffer_var.get(), ramp->base, &alignment, &native_bits);
        CHECK_EQ(ramp->lanes, t.lanes());
        llvm::Value* ptr = CreateBufferPtr(
            t.element_of(), buffer, MakeValue(ramp->base));
        ptr = builder_->CreatePointerCast(ptr, LLVMType(t)->getPointerTo(addrspace));
        llvm::LoadInst* load = builder_->CreateAlignedLoad(ptr, alignment, is_volatile);
        AddAliasInfo(load, op->buffer_var.get(), op->index, t);
        return load;
      }
    }
  }
  // scalarized load.
  int basic_align = t.bits() / 8;
  llvm::Value* ret = llvm::UndefValue::get(LLVMType(t));
  auto f = [&](int i, llvm::Value* index) {
    llvm::Value* ptr = CreateBufferPtr(t.element_of(), buffer, index);
    llvm::LoadInst* load = builder_->CreateAlignedLoad(
        ptr, basic_align, is_volatile);
    ret = builder_->CreateInsertElement(ret, load, ConstInt32(i));
    AddAliasInfo(load, op->buffer_var.get(), Expr(), t);
  };
  this->Scalarize(op->index, f);
  return ret;
}

llvm::Value* CodeGenLLVM::VisitExpr_(const Call* op) {
  if (op->call_type == Call::Intrinsic ||
      op->call_type == Call::PureIntrinsic) {
    return CreateIntrinsic(op);
  } else if (op->call_type == Call::Extern ||
             op->call_type == Call::PureExtern) {
    return CreateCallExtern(op);
  } else {
    LOG(FATAL) << "Unknown call type ";
    return nullptr;
  }
}

llvm::Value* CodeGenLLVM::VisitExpr_(const Ramp* op) {
  llvm::Value* vec = llvm::UndefValue::get(LLVMType(op->type));
  for (int i = 0; i < op->lanes; ++i) {
    vec = builder_->CreateInsertElement(
        vec, MakeValue(op->base + op->stride * make_const(op->stride.type(), i)),
        ConstInt32(i));
  }
  return vec;
}

llvm::Value* CodeGenLLVM::VisitExpr_(const Broadcast* op) {
  return CreateBroadcast(MakeValue(op->value), op->lanes);
}

void CodeGenLLVM::VisitStmt_(const Store* op) {
  CHECK(is_one(op->predicate));
  Type t = op->value.type();
  bool is_volatile = volatile_buf_.count(op->buffer_var.get());
  llvm::Value* buffer = MakeValue(op->buffer_var);
  llvm::Value* index = MakeValue(op->index);
  llvm::Value* value = MakeValue(op->value);

  if (t.lanes() == 1) {
    int alignment, native_bits;
    GetAlignment(t, op->buffer_var.get(), op->index, &alignment, &native_bits);
    llvm::Value* ptr = CreateBufferPtr(t, buffer, index);
    llvm::StoreInst* store = builder_->CreateAlignedStore(value, ptr, alignment, is_volatile);
    AddAliasInfo(store, op->buffer_var.get(), op->index, op->value.type());
    return;
  } else {
    // vector store
    unsigned addrspace = llvm::dyn_cast<llvm::PointerType>(
        buffer->getType())->getAddressSpace();
    if (const Ramp* ramp = op->index.as<Ramp>()) {
      if (is_one(ramp->stride)) {
        int alignment, native_bits;
        GetAlignment(t, op->buffer_var.get(), ramp->base, &alignment, &native_bits);
        CHECK_EQ(ramp->lanes, t.lanes());
        llvm::Value* ptr = CreateBufferPtr(
            t.element_of(), buffer, MakeValue(ramp->base));
        ptr = builder_->CreatePointerCast(ptr, LLVMType(t)->getPointerTo(addrspace));
        llvm::StoreInst* store = builder_->CreateAlignedStore(value, ptr, alignment, is_volatile);
        AddAliasInfo(store, op->buffer_var.get(), op->index, op->value.type());
        return;
      }
    }
  }
  CHECK_GE(t.bits(), 8);
  // scalarized store.
  int basic_align = t.bits() / 8;
  auto f = [&](int i, llvm::Value* index) {
    llvm::Value* ptr = CreateBufferPtr(t.element_of(), buffer, index);
    llvm::StoreInst* store = builder_->CreateAlignedStore(
        builder_->CreateExtractElement(value, i),
        ptr, basic_align, is_volatile);
    AddAliasInfo(store, op->buffer_var.get(), Expr(), op->value.type());
  };
  this->Scalarize(op->index, f);
}

void CodeGenLLVM::VisitStmt_(const For* op) {
  CHECK(is_zero(op->min));
  analyzer_->Bind(op->loop_var, Range::make_by_min_extent(op->min, op->extent));
  if (op->for_type == ForType::Unrolled) {
    LOG(WARNING) << "Unroll hint is ignored by the CodeGenLLVM backend; "
                 << "consider setting unroll_explicit=True";
  } else {
    CHECK(op->for_type == ForType::Serial);
  }
  CreateSerialFor(MakeValue(op->min), MakeValue(op->extent),
                  ConstInt32(1), op->loop_var, op->body);
}


void CodeGenLLVM::VisitStmt_(const IfThenElse* op) {
  using llvm::BasicBlock;
  llvm::Value* cond = MakeValue(op->condition);
  BasicBlock* then_block = BasicBlock::Create(
      *ctx_, "if_then", function_);
  BasicBlock* end_block = BasicBlock::Create(
      *ctx_, "if_end", function_);
  if (op->else_case.defined()) {
    BasicBlock* else_block = BasicBlock::Create(
        *ctx_, "if_else", function_);
    builder_->CreateCondBr(cond, then_block, else_block);
    builder_->SetInsertPoint(then_block);
    this->VisitStmt(op->then_case);
    builder_->CreateBr(end_block);
    builder_->SetInsertPoint(else_block);
    this->VisitStmt(op->else_case);
    builder_->CreateBr(end_block);
  } else {
    builder_->CreateCondBr(cond, then_block, end_block, md_very_likely_branch_);
    builder_->SetInsertPoint(then_block);
    this->VisitStmt(op->then_case);
    builder_->CreateBr(end_block);
  }
  builder_->SetInsertPoint(end_block);
}


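// Constant-size allocations become a single alloca placed in the function
// entry block (via WithFunctionEntry), with the requested alignment capped
// at 16 bytes.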
void CodeGenLLVM::VisitStmt_(const Allocate* op) {
  CHECK(!is_zero(op->condition));
  llvm::Value* buf = nullptr;
  if (op->new_expr.defined()) {
    CHECK_EQ(op->free_function, "nop");
    buf = MakeValue(op->new_expr);
  } else {
    int32_t constant_size = op->constant_allocation_size();
    CHECK_GT(constant_size, 0)
        << "Can only handle constant size stack allocation";
    StorageInfo& info = alloc_storage_info_[op->buffer_var.get()];
    if (constant_size % 4 == 0 && info.alignment == 0) {
      info.alignment = GetTempAllocaAlignment(op->type, constant_size);
    }
    // maximum necessary alignment in the NV devices
    if (info.alignment > 16) {
      info.alignment = 16;
    }
    llvm::AllocaInst* alloca = WithFunctionEntry([&]() {
        return builder_->CreateAlloca(
            LLVMType(op->type), ConstInt32(constant_size));
      });
    if (alloca->getAlignment() < static_cast<uint32_t>(info.alignment)) {
      alloca->setAlignment(info.alignment);
    }
    info.alignment = alloca->getAlignment();
    buf = alloca;
  }
  buf = builder_->CreatePointerCast(
      buf, LLVMType(op->type)->getPointerTo(
          buf->getType()->getPointerAddressSpace()));
  CHECK(!var_map_.count(op->buffer_var.get()));
  var_map_[op->buffer_var.get()] = buf;
  this->VisitStmt(op->body);
}

void CodeGenLLVM::VisitStmt_(const AttrStmt* op) {
  if (op->attr_key == attr::thread_extent) {
    IterVar iv(op->node.node_);
    if (iv->thread_tag.length() != 0) {
      if (!var_map_.count(iv->var.get())) {
        var_map_[iv->var.get()] = GetThreadIndex(iv);
        analyzer_->Bind(iv->var, Range::make_by_min_extent(0, op->value));
      }
    }
  } else if (op->attr_key == ir::attr::storage_scope) {
    const Variable* v = op->node.as<Variable>();
    CHECK(v);
    alloc_storage_info_[v].scope =
        runtime::StorageScope::make(op->value.as<StringImm>()->value);
  } else if (op->attr_key == ir::attr::storage_alignment) {
    const Variable* v = op->node.as<Variable>();
    CHECK(v);
    alloc_storage_info_[v].alignment =
        static_cast<int>(op->value.as<IntImm>()->value);
  } else if (op->attr_key == ir::attr::volatile_scope) {
    const Variable* v = op->node.as<Variable>();
    CHECK(v);
    volatile_buf_.insert(v);
  }
  this->VisitStmt(op->body);
}

void CodeGenLLVM::VisitStmt_(const AssertStmt* op) {
  arith::ConstraintContext cctx(analyzer_.get(), op->condition);
  this->VisitStmt(op->body);
}

void CodeGenLLVM::VisitStmt_(const LetStmt* op) {
  CHECK(!var_map_.count(op->var.get()));
  if (op->var.type().is_handle()) {
    if (!is_restricted_) {
      alias_var_set_.insert(op->var.get());
    }
  }
  var_map_[op->var.get()] = MakeValue(op->value);
  analyzer_->Bind(op->var, op->value);
  this->VisitStmt(op->body);
}

void CodeGenLLVM::VisitStmt_(const Block* op) {
  this->VisitStmt(op->first);
  if (op->rest.defined()) {
    this->VisitStmt(op->rest);
  }
}

void CodeGenLLVM::VisitStmt_(const Evaluate* op) {
  MakeValue(op->value);
}

void CodeGenLLVM::VisitStmt_(const ProducerConsumer* op) {
  this->VisitStmt(op->body);
}
}  // namespace codegen
}  // namespace tvm
#endif  // TVM_LLVM_VERSION