/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 *  Copyright (c) 2017 by Contributors
 * \file codegen_llvm.cc
 */
#ifdef TVM_LLVM_VERSION
// Part of the code is adapted from Halide's CodeGen_LLVM
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/c_runtime_api.h>

#include <algorithm>

#include "codegen_llvm.h"
#include "codegen_cpu.h"
#include "../build_common.h"
#include "../../pass/ir_util.h"
#include "../../arithmetic/compute_expr.h"

namespace tvm {
namespace codegen {

std::unique_ptr<CodeGenLLVM> CodeGenLLVM::Create(llvm::TargetMachine *tm) {
  std::string target = tm->getTarget().getName();
  std::string factory_name = "tvm.codegen.llvm.target_" + target;
  const PackedFunc* f = runtime::Registry::Get(factory_name);
  if (f != nullptr) {
    void* handle = (*f)();
    return std::unique_ptr<CodeGenLLVM>(static_cast<CodeGenLLVM*>(handle));
  } else {
    return std::unique_ptr<CodeGenLLVM>(new CodeGenCPU());
  }
}

void CodeGenLLVM::Init(const std::string& module_name,
                       llvm::TargetMachine* tm,
                       llvm::LLVMContext* ctx,
                       bool system_lib,
                       bool dynamic_lookup) {
  InitializeLLVM();
  ctx_ = ctx;
  builder_.reset(new IRBuilder(*ctx_));
  module_.reset(new llvm::Module(module_name, *ctx_));
  md_builder_.reset(new llvm::MDBuilder(*ctx_));
  // types
  t_void_ = llvm::Type::getVoidTy(*ctx_);
  t_void_p_ = llvm::Type::getInt8Ty(*ctx_)->getPointerTo();
  t_int_ = llvm::Type::getInt32Ty(*ctx_);
  t_char_ = llvm::Type::getInt8Ty(*ctx_);
  t_int8_ = llvm::Type::getInt8Ty(*ctx_);
  t_int16_ = llvm::Type::getInt16Ty(*ctx_);
  t_int32_ = llvm::Type::getInt32Ty(*ctx_);
  t_int64_ = llvm::Type::getInt64Ty(*ctx_);
  t_float64_ = llvm::Type::getDoubleTy(*ctx_);
  // meta data
  md_very_likely_branch_ = md_builder_->createBranchWeights(1<<20, 1);
  md_tbaa_root_ = md_builder_->createTBAARoot("tvm-tbaa");
  md_tbaa_alias_set_ = md_builder_->createTBAANode("tvm-alias", md_tbaa_root_);
  this->InitTarget(tm);
}

void CodeGenLLVM::InitTarget(llvm::TargetMachine* tm) {
  module_->setTargetTriple(tm->getTargetTriple().str());
  module_->setDataLayout(tm->createDataLayout());
  data_layout_.reset(new llvm::DataLayout(module_.get()));
  target_machine_ = tm;
  if (native_vector_bits_ == 0) {
    const auto& arch = tm->getTargetTriple().getArch();
    if (arch == llvm::Triple::x86_64) {
      // for avx512
      native_vector_bits_ = 512;
    } else if (arch == llvm::Triple::x86) {
      native_vector_bits_ = 256;
    } else if (arch == llvm::Triple::arm || arch == llvm::Triple::aarch64) {
      native_vector_bits_ = 128;
    } else {
      native_vector_bits_ = 128;
      std::string arch_name = tm->getTargetTriple().getArchName();
      LOG(WARNING) << "Set native vector bits to be 128 for " << arch_name;
    }
  }
}

void CodeGenLLVM::AddFunction(const LoweredFunc& f) {
  this->AddFunctionInternal(f, false);
}

void CodeGenLLVM::InitFuncState() {
  var_map_.clear();
  alias_var_set_.clear();
  alloc_storage_info_.clear();
  volatile_buf_.clear();
  analyzer_.reset(new arith::Analyzer());
}


void CodeGenLLVM::AddFunctionInternal(const LoweredFunc& f, bool ret_void) {
  this->InitFuncState();
  std::vector<llvm::Type*> arg_types;
  is_restricted_ = f->is_restricted;
  for (Var arg : f->args) {
    Type t = arg.type();
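    // Pointer (handle) arguments are typed by their recorded element type
    // when it is known, otherwise passed as i8*, both in the global
    // address space.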
    if (t.is_handle()) {
      auto it = f->handle_data_type.find(arg);
      if (it != f->handle_data_type.end()) {
        arg_types.push_back(LLVMType((*it).second.type())
                            ->getPointerTo(GetGlobalAddressSpace()));
      } else {
        arg_types.push_back(t_int8_->getPointerTo(GetGlobalAddressSpace()));
      }
      if (!is_restricted_) {
        alias_var_set_.insert(arg.get());
      }
    } else {
      arg_types.push_back(LLVMType(arg.type()));
    }
  }
  llvm::FunctionType* ftype = llvm::FunctionType::get(
      ret_void ? t_void_ : t_int_, arg_types, false);
  CHECK(module_->getFunction(f->name) == nullptr)
      << "Function " << f->name << " already exist in module";
  function_ = llvm::Function::Create(
      ftype, llvm::Function::ExternalLinkage,
      f->name, module_.get());
  function_->setCallingConv(llvm::CallingConv::C);
  function_->setDLLStorageClass(llvm::GlobalValue::DLLStorageClassTypes::DLLExportStorageClass);
  // set var map and align information
  auto arg_it = function_->arg_begin();
  for (size_t i = 0; i < f->args.size(); ++i, ++arg_it) {
    llvm::Argument* v = &(*arg_it);
    const Var& var = f->args[i];
    var_map_[var.get()] = v;
    if (is_restricted_) {
      if (var.type().is_handle() && !alias_var_set_.count(var.get())) {
        // set non alias.
#if TVM_LLVM_VERSION >= 50
        function_->addParamAttr(i, llvm::Attribute::NoAlias);
#else
        function_->setDoesNotAlias(i + 1);
#endif
      }
    }
  }
  llvm::BasicBlock* entry = llvm::BasicBlock::Create(*ctx_, "entry", function_);
  builder_->SetInsertPoint(entry);
  this->VisitStmt(f->body);
  if (ret_void) {
    builder_->CreateRetVoid();
  } else {
    builder_->CreateRet(ConstInt32(0));
  }
}


std::unique_ptr<llvm::Module> CodeGenLLVM::Finish() {
  this->AddStartupFunction();
  for (size_t i = 0; i < link_modules_.size(); ++i) {
    CHECK(!llvm::Linker::linkModules(*module_, std::move(link_modules_[i])))
        << "Failed to link modules";
  }
  link_modules_.clear();
  // optimize
  this->Optimize();
  return std::move(module_);
}


void CodeGenLLVM::HandleImport(const std::string& code) {
  std::unique_ptr<llvm::Module> mlib;
  llvm::SMDiagnostic err;
  if (code.length() >= 3 &&
      (code.substr(code.length() - 3) == ".ll" ||
       code.substr(code.length() - 3) == ".bc")) {
    mlib = llvm::parseIRFile(code, err, *ctx_);
    if (mlib.get() == nullptr) {
      std::string msg = err.getMessage();
      LOG(FATAL) << "Fail to load bitcode file " << code << "\n"
                 << "line " << err.getLineNo() << ":" << msg;
    }
  } else {
    std::unique_ptr<llvm::MemoryBuffer> buf =
        llvm::MemoryBuffer::getMemBuffer(code);
    mlib = llvm::parseIR(*buf, err, *ctx_);
    if (mlib.get() == nullptr) {
      std::string msg = err.getMessage();
      LOG(FATAL) << "Fail to load llvm ir "
                 << "line " << err.getLineNo() << ":" << msg
                 << "\ncontent:\n"  << code;
    }
  }
  mlib->setTargetTriple(target_machine_->getTargetTriple().str());
  mlib->setDataLayout(target_machine_->createDataLayout());
  // mark all the functions as force inline
  for (llvm::Function &f : mlib->functions()) {
    f.removeFnAttr(llvm::Attribute::NoInline);
    f.addFnAttr(llvm::Attribute::AlwaysInline);
    f.setLinkage(llvm::GlobalValue::AvailableExternallyLinkage);
  }
  // add to linker libraries.
  this->AddLinkModule(std::move(mlib));
}

void CodeGenLLVM::AddLinkModule(std::unique_ptr<llvm::Module>&& mod) {
  link_modules_.emplace_back(std::move(mod));
}

void CodeGenLLVM::AddMainFunction(const std::string& entry_func_name) {
  LOG(FATAL) << "not implemented";
}

llvm::Value* CodeGenLLVM::GetThreadIndex(const IterVar& iv) {
  LOG(FATAL) << "not implemented";
  return nullptr;
}

llvm::Value* CodeGenLLVM::CreateStorageSync(const Call* op) {
  LOG(FATAL) << "not implemented";
  return nullptr;
}

class FPassManager : public llvm::legacy::FunctionPassManager {
 public:
  explicit FPassManager(llvm::Module* m)
      : llvm::legacy::FunctionPassManager(m) {}
  // override add to allow messaging
  void add(llvm::Pass* p) final {
    llvm::legacy::FunctionPassManager::add(p);
  }
};

class MPassManager : public llvm::legacy::PassManager {
 public:
  // override add to allow messaging
  void add(llvm::Pass* p) final {
    llvm::legacy::PassManager::add(p);
  }
};

void CodeGenLLVM::InitPassManagerBuilder(llvm::PassManagerBuilder* builder) {
}

void CodeGenLLVM::Optimize() {
  // pass manager
  FPassManager fpass(module_.get());
  MPassManager mpass;
  mpass.add(llvm::createTargetTransformInfoWrapperPass(
              target_machine_ ? target_machine_->getTargetIRAnalysis() :
                                llvm::TargetIRAnalysis()));
  fpass.add(llvm::createTargetTransformInfoWrapperPass(
              target_machine_ ? target_machine_->getTargetIRAnalysis() :
              llvm::TargetIRAnalysis()));

  // place optimization pass
  llvm::PassManagerBuilder builder;
  builder.OptLevel = 3;

#if TVM_LLVM_VERSION >= 50
  builder.Inliner = llvm::createFunctionInliningPass(builder.OptLevel, 0, false);
#else
  builder.Inliner = llvm::createFunctionInliningPass(builder.OptLevel, 0);
#endif
  builder.LoopVectorize = true;
  builder.SLPVectorize = true;
  this->InitPassManagerBuilder(&builder);

#if TVM_LLVM_VERSION >= 50
  target_machine_->adjustPassManager(builder);
#endif

  builder.populateFunctionPassManager(fpass);
  builder.populateModulePassManager(mpass);

  fpass.doInitialization();
  for (auto it = module_->begin(); it != module_->end(); ++it) {
    fpass.run(*it);
  }
  fpass.doFinalization();
  mpass.run(*module_);
}

int CodeGenLLVM::NativeVectorBits(const runtime::StorageScope& storage_scope) const {
  return native_vector_bits_;
}

unsigned CodeGenLLVM::GetGlobalAddressSpace() {
  return 0;
}

llvm::Type* CodeGenLLVM::LLVMType(const Type& t) const {
  if (t.is_handle()) {
    CHECK_EQ(t.lanes(), 1);
    return t_void_p_;
  }
  llvm::Type* etype = nullptr;
  if (t.is_int() || t.is_uint()) {
    etype = llvm::Type::getIntNTy(*ctx_, t.bits());
  } else if (t.is_float()) {
    switch (t.bits()) {
      case 16: etype = llvm::Type::getHalfTy(*ctx_); break;
      case 32: etype = llvm::Type::getFloatTy(*ctx_); break;
      case 64: etype = llvm::Type::getDoubleTy(*ctx_); break;
      default: LOG(FATAL) << "do not support " << t;
    }
  }
  if (t.lanes() != 1) {
    return llvm::VectorType::get(etype, t.lanes());
  } else {
    return etype;
  }
}

// Add TBAA alias information for a load or store.
//
// Use a binary-tree-shaped TBAA type hierarchy to declare the access,
// so that accesses through different nodes can be shown not to alias.
//
// This trick comes from Halide's CodeGen_LLVM.
//
void CodeGenLLVM::AddAliasInfo(llvm::Instruction* inst,
                               const Variable* buffer,
                               Expr index,
                               Type type) {
  if (alias_var_set_.count(buffer) != 0) {
    // Mark all possibly aliased pointer as same type.
    llvm::MDNode* meta = md_tbaa_alias_set_;
    inst->setMetadata(
        "tbaa",
        md_builder_->createTBAAStructTagNode(meta, meta, 0));
    return;
  }
  int base = 0, width = 0;
  // create meta-data for alias analysis
  // Use a group of binary tree ranges of memory banks.
  if (index.defined()) {
    const Ramp* ramp = index.as<Ramp>();
    if (ramp) {
      int base, stride;
      if (arith::GetConstInt(ramp->base, &base) &&
          arith::GetConstInt(ramp->stride, &stride)) {
        int xwidth = ramp->lanes * stride;
        width = 1;
        while (width < xwidth) {
          width *= 2;
        }
        while (base % width) {
          base -= base % width;
          width *= 2;
        }
      }
    } else {
      if (arith::GetConstInt(index, &base)) width = 1;
    }
  }
  llvm::MDNode* meta = md_tbaa_root_;
  std::ostringstream buffer_addr, buffer_type;
  buffer_addr << buffer;
  meta = md_builder_->createTBAAScalarTypeNode(buffer_addr.str(), meta);
  buffer_type << type.element_of();
  meta = md_builder_->createTBAAScalarTypeNode(buffer_type.str(), meta);
  // create a tree-shape access structure.
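  // Each level names the power-of-two window [b, b + w) that contains the
  // access; narrower windows are children of wider ones, so two accesses
  // can alias only if their windows overlap at every level.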
  if (width != 0) {
    for (int w = 1024; w >= width; w /= 2) {
      int b = (base / w) * w;
      std::stringstream os;
      os << buffer << ".w" << w << ".b" << b;
      meta = md_builder_->createTBAAScalarTypeNode(os.str(), meta);
    }
  }
  inst->setMetadata(
      "tbaa",
      md_builder_->createTBAAStructTagNode(meta, meta, 0));
}

void CodeGenLLVM::GetAlignment(Type t,
                               const Variable* buf_var,
                               const Expr& index,
                               int* p_alignment,
                               int* p_native_bits) {
  int max_align_bits = t.bits();
  auto it = alloc_storage_info_.find(buf_var);
  if (it != alloc_storage_info_.end()) {
    const StorageInfo& info = it->second;
    *p_native_bits = NativeVectorBits(info.scope);
    max_align_bits = info.alignment * 8;
  } else {
    *p_native_bits = native_vector_bits_;
  }

  arith::ModularSet me = analyzer_->modular_set(index);
  int64_t base = me->base;
  int64_t coeff = me->coeff;

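  // modular_set guarantees index == coeff * n + base for some n, so the
  // alignment can be doubled for every factor of two shared by coeff and
  // base, up to the buffer's declared alignment.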
  int align_bits = t.bits();
  while (align_bits < max_align_bits &&
         base % 2  == 0 &&
         coeff % 2 == 0) {
    base =  base / 2;
    coeff =  coeff / 2;
    align_bits *= 2;
  }
  if (align_bits < 8) {
    align_bits = 8;
  }
  *p_alignment = align_bits / 8;
}

std::unique_ptr<CodeGenLLVM::DebugInfo>
CodeGenLLVM::CreateDebugInfo(llvm::Module* module) {
  auto debug_info = llvm::make_unique<CodeGenLLVM::DebugInfo>();
  debug_info->di_builder_ = llvm::make_unique<llvm::DIBuilder>(*module);
  // TODO(tulloch): pass this information through relay::Span classes to the LoweredFunc instance?
  debug_info->file_ = debug_info->di_builder_->createFile("model.tvm", "/tmp/");
  debug_info->compilation_unit_ = debug_info->di_builder_->createCompileUnit(
      llvm::dwarf::DW_LANG_C, debug_info->file_, "TVM", 0, "", 0, "",
      llvm::DICompileUnit::DebugEmissionKind::FullDebug,
      /* SplitDebugInlining */ true,
      /* DebugInfoForProfiling */ true);
  return debug_info;
}

llvm::Value* CodeGenLLVM::CreateBroadcast(llvm::Value* value, int lanes) {
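  // Standard LLVM splat idiom: insert the scalar into lane 0 of an undef
  // vector, then shuffle with an all-zero mask to replicate lane 0.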
  llvm::Constant* undef = llvm::UndefValue::get(
      llvm::VectorType::get(value->getType(), lanes));
  llvm::Constant* zero = ConstInt32(0);
  value = builder_->CreateInsertElement(undef, value, zero);
  llvm::Constant* mask = llvm::ConstantVector::getSplat(lanes, zero);
  return builder_->CreateShuffleVector(value, undef, mask);
}

llvm::Value* CodeGenLLVM::CreateVecSlice(llvm::Value* vec, int begin, int extent) {
  int num_elems = static_cast<int>(vec->getType()->getVectorNumElements());
  if (extent == num_elems && begin == 0) return vec;
  CHECK(begin >= 0 && extent <= num_elems) << "Slicing out of bounds!\n";
  std::vector<llvm::Constant*> indices;
  indices.reserve(extent);
  for (int i = 0; i < extent; ++i) {
    if (begin + i >= 0 && begin + i < num_elems) {
      indices.push_back(llvm::ConstantInt::get(t_int32_, begin + i));
    } else {
      indices.push_back(llvm::UndefValue::get(t_int32_));
    }
  }
  return builder_->CreateShuffleVector(vec, vec, llvm::ConstantVector::get(indices));
}

llvm::Value* CodeGenLLVM::CreateVecFlip(llvm::Value* vec) {
  int num_elems = static_cast<int>(vec->getType()->getVectorNumElements());
  std::vector<unsigned> indices;
  for (int i = 0; i < num_elems; ++i) {
    indices.push_back(num_elems - i - 1);
  }
  return builder_->CreateShuffleVector(vec, vec, indices);
}

llvm::Value* CodeGenLLVM::CreateVecPad(llvm::Value* vec, int target_lanes) {
  llvm::Value* mask = llvm::UndefValue::get(LLVMType(Int(32, target_lanes)));
  int num_elems = static_cast<int>(vec->getType()->getVectorNumElements());
  if (num_elems == target_lanes) return vec;
  CHECK_LT(num_elems, target_lanes);
  for (int i = 0; i < num_elems; ++i) {
    mask = builder_->CreateInsertElement(mask, ConstInt32(i), ConstInt32(i));
  }
  return builder_->CreateShuffleVector(vec, vec, mask);
}

llvm::Value* CodeGenLLVM::CreateVecConcat(std::vector<llvm::Value*> vecs) {
  // concat vector, tree shape reduction
  int total_lanes = 0;

  for (llvm::Value* v : vecs) {
    total_lanes += static_cast<int>(
        v->getType()->getVectorNumElements());
  }
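  // Pairwise concatenate neighbours (padding the shorter operand) until a
  // single vector remains, then slice it back to the exact lane count.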
  while (vecs.size() > 1) {
    std::vector<llvm::Value*> new_vecs;
    for (size_t i = 0; i < vecs.size() - 1; i += 2) {
      llvm::Value* lhs = vecs[i];
      llvm::Value* rhs = vecs[i + 1];
      const size_t lhs_lanes = lhs->getType()->getVectorNumElements();
      const size_t rhs_lanes = rhs->getType()->getVectorNumElements();
      if (lhs_lanes < rhs_lanes) {
        lhs = CreateVecPad(lhs, rhs_lanes);
      } else if (rhs_lanes < lhs_lanes) {
        rhs = CreateVecPad(rhs, lhs_lanes);
      }
      const size_t shared_lanes = std::max(lhs_lanes, rhs_lanes);
      std::vector<unsigned> mask;
      for (size_t i = 0; i < lhs_lanes; ++i) {
        mask.push_back(i);
      }
      for (size_t i = 0; i < rhs_lanes; ++i) {
        mask.push_back(shared_lanes + i);
      }
      new_vecs.push_back(builder_->CreateShuffleVector(lhs, rhs, mask));
    }
    if (vecs.size() % 2 != 0) {
      new_vecs.push_back(vecs.back());
    }
    vecs.swap(new_vecs);
  }
  return CreateVecSlice(vecs[0], 0, total_lanes);
}


void CodeGenLLVM::CreateSerialFor(llvm::Value* begin,
                                  llvm::Value* end,
                                  llvm::Value* stride,
                                  const VarExpr& loop_var,
                                  const Stmt& body) {
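  // Loop skeleton: pre_block -> for_begin (phi + condition) -> for_body
  // -> back to for_begin, exiting to for_end.  The loop variable is a PHI
  // node fed by `begin` from the preheader and by loop_value + stride from
  // the body block.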
  using llvm::BasicBlock;
  BasicBlock* pre_block = builder_->GetInsertBlock();
  BasicBlock* for_begin = BasicBlock::Create(
      *ctx_, "for_begin", function_);
  BasicBlock* for_body = BasicBlock::Create(
      *ctx_, "for_body", function_);
  BasicBlock* for_end = BasicBlock::Create(
      *ctx_, "for_end", function_);
  builder_->CreateBr(for_begin);
  builder_->SetInsertPoint(for_begin);
  llvm::PHINode* loop_value = builder_->CreatePHI(begin->getType(), 2);
  loop_value->addIncoming(begin, pre_block);
  CHECK(!var_map_.count(loop_var.get()));
  var_map_[loop_var.get()] = loop_value;
  builder_->CreateCondBr(CreateLT(loop_var.type(), loop_value, end),
                         for_body, for_end, md_very_likely_branch_);
  builder_->SetInsertPoint(for_body);
  this->VisitStmt(body);
  var_map_.erase(loop_var.get());
  llvm::Value* loop_next = CreateAdd(loop_var.type(), loop_value, stride);
  loop_value->addIncoming(loop_next, builder_->GetInsertBlock());
  builder_->CreateBr(for_begin);
  builder_->SetInsertPoint(for_end);
}

// cast operator
llvm::Value* CodeGenLLVM::CreateCast(Type from, Type to, llvm::Value* value) {
  llvm::Type * target = LLVMType(to);
  if (value->getType() == target) return value;
  if (to.is_handle()) {
    return builder_->CreateBitCast(value, target);
  } else if (to.is_uint() && to.bits() == 1) {
    if (from.is_float()) {
      llvm::Constant* zero = llvm::ConstantFP::get(LLVMType(from), 0.);
      return builder_->CreateFCmpONE(value, zero);
    } else {
      llvm::Constant* zero = llvm::ConstantInt::get(LLVMType(from), 0);
      return builder_->CreateICmpNE(value, zero);
    }
  } else if (!from.is_float() && !to.is_float()) {
    return builder_->CreateIntCast(value, target, from.is_int());
  } else if (from.is_float() && to.is_int()) {
    return builder_->CreateFPToSI(value, target);
  } else if (from.is_float() && to.is_uint()) {
    if (to.bits() < 8) {
      value = builder_->CreateFPToUI(value, LLVMType(to.with_bits(8)));
      return builder_->CreateIntCast(value, target, false);
    } else {
      return builder_->CreateFPToUI(value, target);
    }
  } else if (from.is_int() && to.is_float()) {
    return builder_->CreateSIToFP(value, target);
  } else if (from.is_uint() && to.is_float()) {
    return builder_->CreateUIToFP(value, target);
  } else {
    CHECK(from.is_float() && to.is_float());
    return builder_->CreateFPCast(value, target);
  }
}

llvm::Value* CodeGenLLVM::GetConstString(const std::string& str) {
  auto it = str_map_.find(str);
  if (it != str_map_.end()) return it->second;
  llvm::Type* type = llvm::ArrayType::get(t_char_, str.length() + 1);
  llvm::GlobalVariable *global = new llvm::GlobalVariable(
      *module_, type, true, llvm::GlobalValue::PrivateLinkage, 0, ".str");
  global->setAlignment(1);
  global->setInitializer(llvm::ConstantDataArray::getString(*ctx_, str));
  llvm::Constant* zero = ConstInt32(0);
  llvm::Constant* indices[] = {zero, zero};
  llvm::Constant* ptr = llvm::ConstantExpr::getGetElementPtr(
      type, global, indices);
  str_map_[str] = ptr;
  return ptr;
}

llvm::Value* CodeGenLLVM::CreateBufferPtr(
    Type t, llvm::Value* buffer, llvm::Value* index) {
  CHECK_EQ(t.lanes(), 1);
  llvm::PointerType* btype = llvm::dyn_cast<llvm::PointerType>(buffer->getType());
  CHECK(btype != nullptr);
  llvm::PointerType* ptype = LLVMType(t)->getPointerTo(btype->getAddressSpace());
  if (btype != ptype) {
    buffer = builder_->CreatePointerCast(buffer, ptype);
  }

  return builder_->CreateInBoundsGEP(buffer, index);
}

llvm::Value* CodeGenLLVM::CreateBufferVecPtr(
    Type t, llvm::Value* buffer, llvm::Value* index) {
  CHECK_GT(t.lanes(), 1);
  llvm::PointerType* btype = llvm::dyn_cast<llvm::PointerType>(buffer->getType());
  CHECK(btype != nullptr);
  llvm::PointerType* ptype = LLVMType(t)->getPointerTo(btype->getAddressSpace());
  if (btype != ptype) {
    buffer = builder_->CreatePointerCast(buffer, ptype);
  }
  return builder_->CreateInBoundsGEP(buffer, index);
}

llvm::Value* CodeGenLLVM::GetVarValue(const Variable* v) const {
  auto it = var_map_.find(v);
  CHECK(it != var_map_.end()) << "cannot find variable " << v->name_hint;
  return it->second;
}

llvm::Value* CodeGenLLVM::CreateCallExtern(const Call* op) {
  std::vector<llvm::Value*> arg_value;
  std::vector<llvm::Type*> arg_type;
  for (size_t i = 0; i < op->args.size(); ++i) {
    arg_value.push_back(MakeValue(op->args[i]));
    arg_type.push_back(arg_value.back()->getType());
  }
  llvm::FunctionType* ftype = llvm::FunctionType::get(
      LLVMType(op->type), arg_type, false);
  llvm::Function* f = module_->getFunction(op->name);
  if (f == nullptr) {
    f = llvm::Function::Create(
        ftype, llvm::Function::ExternalLinkage,
        op->name, module_.get());
  }
  llvm::CallInst* call = builder_->CreateCall(f, arg_value);
  return call;
}

llvm::Value* CodeGenLLVM::CreateIntrinsic(const Call* op) {
  if (op->is_intrinsic("llvm_intrin")) {
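    // Calling convention of llvm_intrin: args[0] is the LLVM intrinsic id,
    // args[1] the number of leading call operands whose types form the
    // overloaded signature; the remaining args are the call operands.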
    CHECK_GE(op->args.size(), 2U);
    llvm::Intrinsic::ID id = static_cast<llvm::Intrinsic::ID>(
        op->args[0].as<UIntImm>()->value);
    const uint64_t *num_signature = as_const_uint(op->args[1]);
    CHECK(num_signature) << "The second argument should be a uint represents number of arguments, "
                         << "but " << op->args[1] << " got!\n";
    std::vector<llvm::Value*> arg_value;
    std::vector<llvm::Type*> sig_type;
    for (size_t i = 2; i < op->args.size(); ++i) {
      arg_value.push_back(MakeValue(op->args[i]));
      if (i - 2 < *num_signature) {
        sig_type.push_back(arg_value.back()->getType());
      }
    }
    llvm::Type *return_type = LLVMType(op->type);
    if (sig_type.size() > 0 && return_type != sig_type[0]) {
      sig_type.insert(sig_type.begin(), return_type);
    }
    llvm::Function* f = llvm::Intrinsic::getDeclaration(
        module_.get(), id, sig_type);
    return builder_->CreateCall(f, arg_value);
  } else if (op->is_intrinsic(Call::bitwise_and)) {
    return builder_->CreateAnd(MakeValue(op->args[0]), MakeValue(op->args[1]));
  } else if (op->is_intrinsic(Call::bitwise_or)) {
    return builder_->CreateOr(MakeValue(op->args[0]), MakeValue(op->args[1]));
  } else if (op->is_intrinsic(Call::bitwise_not)) {
    return builder_->CreateNot(MakeValue(op->args[0]));
  } else if (op->is_intrinsic(Call::bitwise_xor)) {
    return builder_->CreateXor(MakeValue(op->args[0]), MakeValue(op->args[1]));
  } else if (op->is_intrinsic(Call::shift_left)) {
    return builder_->CreateShl(MakeValue(op->args[0]), MakeValue(op->args[1]));
  } else if (op->is_intrinsic(Call::shift_right)) {
    if (op->args[0].type().is_int()) {
      return builder_->CreateAShr(MakeValue(op->args[0]), MakeValue(op->args[1]));
    } else {
      return builder_->CreateLShr(MakeValue(op->args[0]), MakeValue(op->args[1]));
    }
  } else if (op->is_intrinsic(intrinsic::tvm_storage_sync)) {
    return CreateStorageSync(op);
  } else if (op->is_intrinsic(intrinsic::tvm_address_of)) {
    const Load *l = op->args[0].as<Load>();
    CHECK(op->args.size() == 1 && l);
    const Ramp *r = l->index.as<Ramp>();
    llvm::Value* ptr;
    unsigned addrspace;
    if (!r) {
        ptr = CreateBufferPtr(
          l->type, MakeValue(l->buffer_var), MakeValue(l->index));
        addrspace = llvm::dyn_cast<llvm::PointerType>(
          ptr->getType())->getAddressSpace();
    } else {
        Expr index = r->base / make_const(Int(32), r->lanes);
        ptr = CreateBufferVecPtr(
          l->type, MakeValue(l->buffer_var), MakeValue(index));
        addrspace = llvm::dyn_cast<llvm::PointerType>(
          ptr->getType())->getAddressSpace();
    }
    return builder_->CreatePointerCast(ptr, t_char_->getPointerTo(addrspace));
  } else if (op->is_intrinsic(Call::reinterpret) && is_zero(op->args[0])) {
    return llvm::Constant::getNullValue(t_void_p_);
  } else if (op->is_intrinsic(intrinsic::tvm_handle_is_null)) {
    return builder_->CreateIsNull(MakeValue(op->args[0]));
  } else if (op->is_intrinsic(intrinsic::tvm_if_then_else)) {
    CHECK_EQ(op->args[0].type().lanes(), 1)
        << "if_then_else can only take scalar condition";
    using llvm::BasicBlock;
    BasicBlock* then_block = BasicBlock::Create(
        *ctx_, "if_then", function_);
    BasicBlock* else_block = BasicBlock::Create(
        *ctx_, "if_else", function_);
    BasicBlock* end_block = BasicBlock::Create(
        *ctx_, "if_end", function_);
    builder_->CreateCondBr(MakeValue(op->args[0]), then_block, else_block);
    builder_->SetInsertPoint(then_block);
    llvm::Value* then_value = MakeValue(op->args[1]);
    BasicBlock* then_value_block = builder_->GetInsertBlock();
    builder_->CreateBr(end_block);
    builder_->SetInsertPoint(else_block);
    llvm::Value* else_value = MakeValue(op->args[2]);
    BasicBlock* else_value_block = builder_->GetInsertBlock();
    builder_->CreateBr(end_block);
    builder_->SetInsertPoint(end_block);
    llvm::PHINode* value = builder_->CreatePHI(then_value->getType(), 2);
    value->addIncoming(then_value, then_value_block);
    value->addIncoming(else_value, else_value_block);
    return value;
  } else if (op->is_intrinsic(Call::reinterpret)) {
    llvm::Type * target = LLVMType(op->type);
    return builder_->CreateBitCast(MakeValue(op->args[0]), target);
  } else if (op->is_intrinsic("vectorlow")) {
    llvm::Value *v = MakeValue(op->args[0]);
    int l = v->getType()->getVectorNumElements();
    return CreateVecSlice(v, 0, l/2);
  } else if (op->is_intrinsic("vectorhigh")) {
    llvm::Value *v = MakeValue(op->args[0]);
    int l = v->getType()->getVectorNumElements();
    return CreateVecSlice(v, l/2, l/2);
  } else if (op->is_intrinsic("vectorcombine")) {
    llvm::Value *v0 = MakeValue(op->args[0]);
    llvm::Value *v1 = MakeValue(op->args[1]);
    int num_elems = static_cast<int>(v0->getType()->getVectorNumElements()) * 2;
    std::vector<unsigned> indices;
    for (int i = 0; i < num_elems; ++i) {
      indices.push_back(i);
    }
    return builder_->CreateShuffleVector(v0, v1, indices);
  } else {
    LOG(FATAL) << "unknown intrinsic " << op->name;
    return nullptr;
  }
}

void CodeGenLLVM::Scalarize(const Expr& e,
                            std::function<void(int i, llvm::Value* v)> f) {
  if (const Ramp* ramp = e.as<Ramp>()) {
    for (int i = 0; i < ramp->type.lanes(); ++i) {
      Expr offset = ramp->base + (ramp->stride * i);
      f(i, MakeValue(offset));
    }
  } else {
    llvm::Value* value = MakeValue(e);
    for (int i = 0; i < e.type().lanes(); ++i) {
      f(i, builder_->CreateExtractElement(value, i));
    }
  }
}


// Visitors
llvm::Value* CodeGenLLVM::VisitExpr_(const Variable* op) {
  return GetVarValue(op);
}

llvm::Value* CodeGenLLVM::VisitExpr_(const Cast* op) {
  return CreateCast(op->value.type(), op->type, MakeValue(op->value));
}
llvm::Value* CodeGenLLVM::VisitExpr_(const IntImm* op) {
  return llvm::ConstantInt::getSigned(LLVMType(op->type), op->value);
}

llvm::Value* CodeGenLLVM::VisitExpr_(const UIntImm* op) {
  return llvm::ConstantInt::get(LLVMType(op->type), op->value);
}

llvm::Value* CodeGenLLVM::VisitExpr_(const FloatImm* op) {
  return llvm::ConstantFP::get(LLVMType(op->type), op->value);
}

llvm::Value* CodeGenLLVM::VisitExpr_(const StringImm* op) {
  return GetConstString(op->value);
}

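// Integer Add/Sub/Mul of 32 bits or wider are emitted with the nsw/nuw
// flags, letting LLVM optimize under the assumption that they do not
// overflow; narrower integer types use the plain wrapping instructions.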
#define DEFINE_CODEGEN_BINARY_OP(Op)                                    \
  llvm::Value* CodeGenLLVM::Create ## Op(                               \
      Type t, llvm::Value* a, llvm::Value *b) {                         \
    if (t.is_int()) {                                                   \
      if (t.bits() >= 32) {                                             \
        return builder_->CreateNSW ## Op (a, b);                        \
      } else {                                                          \
        return builder_->Create ## Op (a, b);                           \
      }                                                                 \
    } else if (t.is_uint()) {                                           \
      if (t.bits() >= 32) {                                             \
        return builder_->CreateNUW ## Op (a, b);                        \
      } else {                                                          \
        return builder_->Create ## Op (a, b);                           \
      }                                                                 \
    } else {                                                            \
      CHECK(t.is_float());                                              \
      return builder_->CreateF ## Op (a, b);                            \
    }                                                                   \
  }                                                                     \
  llvm::Value* CodeGenLLVM::VisitExpr_(const Op* op) {                  \
    return Create ## Op(op->type, MakeValue(op->a), MakeValue(op->b));  \
  }

DEFINE_CODEGEN_BINARY_OP(Add);
DEFINE_CODEGEN_BINARY_OP(Sub);
DEFINE_CODEGEN_BINARY_OP(Mul);

#define DEFINE_CODEGEN_CMP_OP(Op)                                       \
  llvm::Value* CodeGenLLVM::Create ## Op(                               \
      Type t, llvm::Value* a, llvm::Value* b) {                         \
    if (t.is_int()) {                                                   \
      return builder_->CreateICmpS ## Op (a, b);                        \
    } else if (t.is_uint()) {                                           \
      return builder_->CreateICmpU ## Op (a, b);                        \
    } else {                                                            \
      CHECK(t.is_float());                                              \
      return builder_->CreateFCmpO ## Op (a, b);                        \
    }                                                                   \
}                                                                       \
  llvm::Value* CodeGenLLVM::VisitExpr_(const Op* op) {                  \
    return Create ## Op(op->a.type(), MakeValue(op->a), MakeValue(op->b)); \
  }

DEFINE_CODEGEN_CMP_OP(LT);
DEFINE_CODEGEN_CMP_OP(LE);
DEFINE_CODEGEN_CMP_OP(GT);
DEFINE_CODEGEN_CMP_OP(GE);

llvm::Value* CodeGenLLVM::VisitExpr_(const Div* op) {
  llvm::Value* a = MakeValue(op->a);
  llvm::Value* b = MakeValue(op->b);
  if (op->type.is_int()) {
    return builder_->CreateSDiv(a, b);
  } else if (op->type.is_uint()) {
    return builder_->CreateUDiv(a, b);
  } else {
    CHECK(op->type.is_float());
    return builder_->CreateFDiv(a, b);
  }
}

llvm::Value* CodeGenLLVM::VisitExpr_(const Mod* op) {
  llvm::Value* a = MakeValue(op->a);
  llvm::Value* b = MakeValue(op->b);
  if (op->type.is_int()) {
    return builder_->CreateSRem(a, b);
  } else if (op->type.is_uint()) {
    return builder_->CreateURem(a, b);
  } else {
    CHECK(op->type.is_float());
    return builder_->CreateFRem(a, b);
  }
}

llvm::Value* CodeGenLLVM::VisitExpr_(const Min* op) {
  llvm::Value* a = MakeValue(op->a);
  llvm::Value* b = MakeValue(op->b);
  return builder_->CreateSelect(CreateLT(op->a.type(), a, b), a, b);
}

llvm::Value* CodeGenLLVM::VisitExpr_(const Max* op) {
  llvm::Value* a = MakeValue(op->a);
  llvm::Value* b = MakeValue(op->b);
  return builder_->CreateSelect(CreateGT(op->a.type(), a, b), a, b);
}

llvm::Value* CodeGenLLVM::VisitExpr_(const EQ* op) {
  llvm::Value* a = MakeValue(op->a);
  llvm::Value* b = MakeValue(op->b);
  if (op->a.type().is_int() || op->a.type().is_uint()) {
    return builder_->CreateICmpEQ(a, b);
  } else {
    return builder_->CreateFCmpOEQ(a, b);
  }
}

llvm::Value* CodeGenLLVM::VisitExpr_(const NE* op) {
  llvm::Value* a = MakeValue(op->a);
  llvm::Value* b = MakeValue(op->b);
  if (op->a.type().is_int() || op->a.type().is_uint()) {
    return builder_->CreateICmpNE(a, b);
  } else {
    return builder_->CreateFCmpONE(a, b);
  }
}

llvm::Value* CodeGenLLVM::VisitExpr_(const And* op) {
  return builder_->CreateAnd(MakeValue(op->a), MakeValue(op->b));
}

llvm::Value* CodeGenLLVM::VisitExpr_(const Or* op) {
  return builder_->CreateOr(MakeValue(op->a), MakeValue(op->b));
}

llvm::Value* CodeGenLLVM::VisitExpr_(const Not* op) {
  return builder_->CreateNot(MakeValue(op->a));
}

llvm::Value* CodeGenLLVM::VisitExpr_(const Select* op) {
  return builder_->CreateSelect(
      MakeValue(op->condition),
      MakeValue(op->true_value),
      MakeValue(op->false_value));
}

llvm::Value* CodeGenLLVM::VisitExpr_(const Let* op) {
  CHECK(!var_map_.count(op->var.get()));
  var_map_[op->var.get()] = MakeValue(op->value);
  analyzer_->Bind(op->var, op->value);
  return MakeValue(op->body);
}

llvm::Value* CodeGenLLVM::VisitExpr_(const Load* op) {
  Type t = op->type;
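  // Three cases: a scalar load, a single wide vector load when the index
  // is a stride-1 Ramp (contiguous lanes), otherwise a lane-by-lane
  // scalarized load.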
  bool is_volatile = volatile_buf_.count(op->buffer_var.get());
  llvm::Value* buffer = MakeValue(op->buffer_var);
  llvm::Value* index = MakeValue(op->index);

  if (t.lanes() == 1) {
    int alignment, native_bits;
    GetAlignment(t, op->buffer_var.get(), op->index, &alignment, &native_bits);
    llvm::Value* ptr = CreateBufferPtr(t, buffer, index);
    llvm::LoadInst* load = builder_->CreateAlignedLoad(ptr, alignment, is_volatile);
    AddAliasInfo(load, op->buffer_var.get(), op->index, t);
    return load;
  } else {
    // vector load
    unsigned addrspace = llvm::dyn_cast<llvm::PointerType>(
      buffer->getType())->getAddressSpace();
    if (const Ramp* ramp = op->index.as<Ramp>()) {
      if (is_one(ramp->stride)) {
        int alignment, native_bits;
        GetAlignment(t, op->buffer_var.get(), ramp->base, &alignment, &native_bits);
        CHECK_EQ(ramp->lanes, t.lanes());
        llvm::Value* ptr = CreateBufferPtr(
            t.element_of(), buffer, MakeValue(ramp->base));
        ptr = builder_->CreatePointerCast(ptr, LLVMType(t)->getPointerTo(addrspace));
        llvm::LoadInst* load = builder_->CreateAlignedLoad(ptr, alignment, is_volatile);
        AddAliasInfo(load, op->buffer_var.get(), op->index, t);
        return load;
      }
    }
  }
  // scalarized load.
  int basic_align = t.bits() / 8;
  llvm::Value* ret = llvm::UndefValue::get(LLVMType(t));
  auto f = [&](int i, llvm::Value* index) {
    llvm::Value* ptr = CreateBufferPtr(t.element_of(), buffer, index);
    llvm::LoadInst* load = builder_->CreateAlignedLoad(
        ptr, basic_align, is_volatile);
    ret = builder_->CreateInsertElement(ret, load, ConstInt32(i));
    AddAliasInfo(load, op->buffer_var.get(), Expr(), t);
  };
  this->Scalarize(op->index, f);
  return ret;
}

llvm::Value* CodeGenLLVM::VisitExpr_(const Call* op) {
  if (op->call_type == Call::Intrinsic ||
      op->call_type == Call::PureIntrinsic) {
    return CreateIntrinsic(op);
  } else if (op->call_type == Call::Extern ||
             op->call_type == Call::PureExtern) {
    return CreateCallExtern(op);
  } else {
    LOG(FATAL) << "Unknown call type " <<
      "name= " << op->name <<
      " call_type= " << op->call_type;
    return nullptr;
  }
}

llvm::Value* CodeGenLLVM::VisitExpr_(const Ramp* op) {
  llvm::Value* vec = llvm::UndefValue::get(LLVMType(op->type));
  for (int i = 0; i < op->lanes; ++i) {
    vec = builder_->CreateInsertElement(
        vec, MakeValue(op->base + op->stride * make_const(op->stride.type(), i)),
        ConstInt32(i));
  }
  return vec;
}

llvm::Value* CodeGenLLVM::VisitExpr_(const Shuffle* op) {
  std::vector<llvm::Value *> vecs(op->vectors.size());
  int total_lanes = 0;
  for (int i = 0, e = op->vectors.size(); i < e; ++i) {
    vecs[i] = VisitExpr(op->vectors[i]);
    total_lanes += op->vectors[i].type().lanes();
  }
  llvm::Value* v0 = CreateVecConcat(vecs);
  std::vector<uint32_t> idx(op->indices.size());
  for (int i = 0, e = op->indices.size(); i < e; ++i) {
    const int64_t *val = as_const_int(op->indices[i]);
    CHECK(val && *val >= 0 && *val < total_lanes) << "Shuffle indices are supposed to be constant ints within range, "
      << "but got " << op->indices[i] << "\n";
    idx[i] = *val;
  }
  llvm::Value* mask = llvm::ConstantDataVector::get(builder_->getContext(), idx);
  auto res = builder_->CreateShuffleVector(v0, llvm::UndefValue::get(v0->getType()), mask);
  return res;
}

llvm::Value* CodeGenLLVM::VisitExpr_(const Broadcast* op) {
  return CreateBroadcast(MakeValue(op->value), op->lanes);
}

void CodeGenLLVM::VisitStmt_(const Store* op) {
  CHECK(is_one(op->predicate));
  Type t = op->value.type();
  bool is_volatile = volatile_buf_.count(op->buffer_var.get());
  llvm::Value* buffer = MakeValue(op->buffer_var);
  llvm::Value* index = MakeValue(op->index);
  llvm::Value* value = MakeValue(op->value);

  if (t.lanes() == 1) {
    int alignment, native_bits;
    GetAlignment(t, op->buffer_var.get(), op->index, &alignment, &native_bits);
    llvm::Value* ptr = CreateBufferPtr(t, buffer, index);
    llvm::StoreInst* store = builder_->CreateAlignedStore(value, ptr, alignment, is_volatile);
    AddAliasInfo(store, op->buffer_var.get(), op->index, op->value.type());
    return;
  } else {
    // vector store
    unsigned addrspace = llvm::dyn_cast<llvm::PointerType>(
        buffer->getType())->getAddressSpace();
    if (const Ramp* ramp = op->index.as<Ramp>()) {
      if (is_one(ramp->stride)) {
        int alignment, native_bits;
        GetAlignment(t, op->buffer_var.get(), ramp->base, &alignment, &native_bits);
        CHECK_EQ(ramp->lanes, t.lanes());
        llvm::Value* ptr = CreateBufferPtr(
            t.element_of(), buffer, MakeValue(ramp->base));
        ptr = builder_->CreatePointerCast(ptr, LLVMType(t)->getPointerTo(addrspace));
        llvm::StoreInst* store = builder_->CreateAlignedStore(value, ptr, alignment, is_volatile);
        AddAliasInfo(store, op->buffer_var.get(), op->index, op->value.type());
        return;
      }
    }
  }
  CHECK_GE(t.bits(), 8);
  // scalarized store.
  int basic_align = t.bits() / 8;
  auto f = [&](int i, llvm::Value* index) {
    llvm::Value* ptr = CreateBufferPtr(t.element_of(), buffer, index);
    llvm::StoreInst* store = builder_->CreateAlignedStore(
        builder_->CreateExtractElement(value, i),
        ptr, basic_align, is_volatile);
    AddAliasInfo(store, op->buffer_var.get(), Expr(), op->value.type());
  };
  this->Scalarize(op->index, f);
}

void CodeGenLLVM::VisitStmt_(const For* op) {
  CHECK(is_zero(op->min));
  analyzer_->Bind(op->loop_var, Range::make_by_min_extent(op->min, op->extent));
  if (op->for_type == ForType::Unrolled) {
    LOG(WARNING) << "Unroll hint get ignore at CodeGenLLVM backend, "
                 << " consider set unroll_explicit=True";
  } else {
    CHECK(op->for_type == ForType::Serial);
  }
  CreateSerialFor(MakeValue(op->min), MakeValue(op->extent),
                  ConstInt32(1), op->loop_var, op->body);
}


void CodeGenLLVM::VisitStmt_(const IfThenElse* op) {
  using llvm::BasicBlock;
  llvm::Value* cond = MakeValue(op->condition);
  BasicBlock* then_block = BasicBlock::Create(
      *ctx_, "if_then", function_);
  BasicBlock* end_block = BasicBlock::Create(
      *ctx_, "if_end", function_);
  if (op->else_case.defined()) {
    BasicBlock* else_block = BasicBlock::Create(
        *ctx_, "if_else", function_);
    builder_->CreateCondBr(cond, then_block, else_block);
    builder_->SetInsertPoint(then_block);
    this->VisitStmt(op->then_case);
    builder_->CreateBr(end_block);
    builder_->SetInsertPoint(else_block);
    this->VisitStmt(op->else_case);
    builder_->CreateBr(end_block);
  } else {
    builder_->CreateCondBr(cond, then_block, end_block, md_very_likely_branch_);
    builder_->SetInsertPoint(then_block);
    this->VisitStmt(op->then_case);
    builder_->CreateBr(end_block);
  }
  builder_->SetInsertPoint(end_block);
}


void CodeGenLLVM::VisitStmt_(const Allocate* op) {
  CHECK(!is_zero(op->condition));
  llvm::Value* buf = nullptr;
  if (op->new_expr.defined()) {
    CHECK_EQ(op->free_function, "nop");
    buf = MakeValue(op->new_expr);
  } else {
    int32_t constant_size = op->constant_allocation_size();
    CHECK_GT(constant_size, 0)
        << "Can only handle constant size stack allocation";
    StorageInfo& info = alloc_storage_info_[op->buffer_var.get()];
    if (constant_size % 4 == 0 && info.alignment == 0) {
      info.alignment = GetTempAllocaAlignment(op->type, constant_size);
    }
    // maximum necessary alignment in the NV devices
    if (info.alignment > 16) {
      info.alignment = 16;
    }
    llvm::AllocaInst* alloca = WithFunctionEntry([&]() {
        return builder_->CreateAlloca(
            LLVMType(op->type), ConstInt32(constant_size));
      });
    if (alloca->getAlignment() < static_cast<uint32_t>(info.alignment)) {
      alloca->setAlignment(info.alignment);
    }
    info.alignment = alloca->getAlignment();
    buf = alloca;
  }
  buf = builder_->CreatePointerCast(
      buf, LLVMType(op->type)->getPointerTo(
          buf->getType()->getPointerAddressSpace()));
  CHECK(!var_map_.count(op->buffer_var.get()));
  var_map_[op->buffer_var.get()] = buf;
  this->VisitStmt(op->body);
}

void CodeGenLLVM::VisitStmt_(const AttrStmt* op) {
  if (op->attr_key == attr::thread_extent) {
    IterVar iv(op->node.node_);
    if (iv->thread_tag.length() != 0) {
      if (!var_map_.count(iv->var.get())) {
        var_map_[iv->var.get()] = GetThreadIndex(iv);
        analyzer_->Bind(iv->var, Range::make_by_min_extent(0, op->value));
      }
    }
  } else if (op->attr_key == ir::attr::storage_scope) {
    const Variable* v = op->node.as<Variable>();
    CHECK(v);
    alloc_storage_info_[v].scope =
        runtime::StorageScope::make(op->value.as<StringImm>()->value);
  } else if (op->attr_key == ir::attr::storage_alignment) {
    const Variable* v = op->node.as<Variable>();
    CHECK(v);
    alloc_storage_info_[v].alignment =
        static_cast<int>(op->value.as<IntImm>()->value);
  } else if (op->attr_key == ir::attr::volatile_scope) {
    const Variable* v = op->node.as<Variable>();
    CHECK(v);
    volatile_buf_.insert(v);
  }
  this->VisitStmt(op->body);
}

void CodeGenLLVM::VisitStmt_(const AssertStmt* op) {
  With<arith::ConstraintContext> cctx(analyzer_.get(), op->condition);
  this->VisitStmt(op->body);
}

void CodeGenLLVM::VisitStmt_(const LetStmt* op) {
  CHECK(!var_map_.count(op->var.get()));
  if (op->var.type().is_handle()) {
    if (!is_restricted_) {
      alias_var_set_.insert(op->var.get());
    }
  }
  var_map_[op->var.get()] = MakeValue(op->value);
  analyzer_->Bind(op->var, op->value);
  this->VisitStmt(op->body);
}

void CodeGenLLVM::VisitStmt_(const Block* op) {
  this->VisitStmt(op->first);
  if (op->rest.defined()) {
    this->VisitStmt(op->rest);
  }
}

void CodeGenLLVM::VisitStmt_(const Evaluate* op) {
  MakeValue(op->value);
}

void CodeGenLLVM::VisitStmt_(const ProducerConsumer* op) {
  this->VisitStmt(op->body);
}
}  // namespace codegen
}  // namespace tvm
#endif  // TVM_LLVM_VERSION