Commit 396c161f by Tianqi Chen, committed via GitHub

Acknowledge related projects (#465)

* [CODEGEN] Redo CodegenLLVM.

* Add remarks about origin of the pass

Properly acknowledge related projects

* Fix an and-expression
parent 2607a836
@@ -28,3 +28,12 @@ TVM adopts apache committer model, we aim to create an open source project that
 - [Contributor Guide](docs/how_to/contribute.md)
 - Please add your name to [CONTRIBUTORS.md](CONTRIBUTORS.md)
 - Please also update [NEWS.md](NEWS.md) on changes and improvements in API and codes.
+Acknowledgement
+---------------
+We learned a lot from the following projects when building TVM.
+- [Halide](https://github.com/halide/Halide): TVM uses [HalideIR](https://github.com/dmlc/HalideIR) as the data structure for
+  arithmetic simplification and low-level lowering. HalideIR is derived from Halide.
+  We also learned from Halide when implementing the lowering pipeline in TVM.
+- [Loopy](https://github.com/inducer/loopy): use of integer set analysis and its loop transformation primitives.
+- [Theano](https://github.com/Theano/Theano): the design inspiration for the symbolic scan operator for recurrence.
@@ -32,29 +32,24 @@ void CodeGenLLVM::Init(const std::string& module_name,
                        bool system_lib,
                        bool dynamic_lookup) {
   InitializeLLVM();
-  // clear maps
-  var_map_.clear();
-  str_map_.clear();
   ctx_ = ctx;
-  t_void_ = llvm::Type::getVoidTy(*ctx);
-  t_void_p_ = llvm::Type::getInt8Ty(*ctx)->getPointerTo();
-  t_int_ = llvm::Type::getIntNTy(*ctx, sizeof(int) * 8);
-  t_char_ = llvm::Type::getInt8Ty(*ctx);
-  t_int8_ = llvm::Type::getInt8Ty(*ctx);
-  t_int16_ = llvm::Type::getInt16Ty(*ctx);
-  t_int32_ = llvm::Type::getInt32Ty(*ctx);
-  t_int64_ = llvm::Type::getInt64Ty(*ctx);
-  t_float64_ = llvm::Type::getDoubleTy(*ctx);
-  md_builder_.reset(new llvm::MDBuilder(*ctx));
-  md_very_likely_branch_ =
-      md_builder_->createBranchWeights(1 << 30, 0);
-  md_tbaa_root_ = md_builder_->createTBAARoot("tvmtbaa");
-  md_tbaa_alias_set_ = md_builder_->createTBAAScalarTypeNode(
-      "alias_set", md_tbaa_root_);
-  // initialize Modules and function type
-  module_.reset(new llvm::Module(module_name, *ctx));
-  // initialize builder
-  builder_.reset(new IRBuilder(*ctx));
+  builder_.reset(new IRBuilder(*ctx_));
+  module_.reset(new llvm::Module(module_name, *ctx_));
+  md_builder_.reset(new llvm::MDBuilder(*ctx_));
+  // types
+  t_void_ = llvm::Type::getVoidTy(*ctx_);
+  t_void_p_ = llvm::Type::getInt8Ty(*ctx_)->getPointerTo();
+  t_int_ = llvm::Type::getInt32Ty(*ctx_);
+  t_char_ = llvm::Type::getInt8Ty(*ctx_);
+  t_int8_ = llvm::Type::getInt8Ty(*ctx_);
+  t_int16_ = llvm::Type::getInt16Ty(*ctx_);
+  t_int32_ = llvm::Type::getInt32Ty(*ctx_);
+  t_int64_ = llvm::Type::getInt64Ty(*ctx_);
+  t_float64_ = llvm::Type::getDoubleTy(*ctx_);
+  // meta data
+  md_very_likely_branch_ = md_builder_->createBranchWeights(1 << 20, 1);
+  md_tbaa_root_ = md_builder_->createTBAARoot("tvm-tbaa");
+  md_tbaa_alias_set_ = md_builder_->createTBAANode("tvm-alias", md_tbaa_root_);
   this->InitTarget(tm);
 }
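For readers skimming the rewritten Init, a brief illustrative sketch of what the three MDBuilder calls expand to; the IR shown in the comments is standard MDBuilder output, and the local names here are placeholders, not part of the patch:

    // Illustrative only, assuming llvm::MDBuilder md_builder(*ctx_):
    // emits !{!"branch_weights", i32 1048576, i32 1} -- a ~2^20 : 1 hint
    // that the first successor (e.g. a loop's continue edge) is taken.
    llvm::MDNode* very_likely = md_builder.createBranchWeights(1 << 20, 1);
    // emits !{!"tvm-tbaa"} -- the root of the type-based alias analysis tree.
    llvm::MDNode* root = md_builder.createTBAARoot("tvm-tbaa");
    // emits !{!"tvm-alias", !root} -- the set of possibly aliased pointers,
    // which AddAliasInfo below refines per buffer and per element type.
    llvm::MDNode* alias_set = md_builder.createTBAANode("tvm-alias", root);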
@@ -63,61 +58,73 @@ void CodeGenLLVM::InitTarget(llvm::TargetMachine* tm) {
   module_->setDataLayout(tm->createDataLayout());
   data_layout_.reset(new llvm::DataLayout(module_.get()));
   target_machine_ = tm;
-  // initialize native vector bits
-  std::string target = tm->getTarget().getName();
-  if (target == "x86-64") {
-    // for avx512
-    native_vector_bits_ = 64 * 8;
-  } else if (target == "x86") {
-    native_vector_bits_ = 32 * 8;
-  } else {
-    if (native_vector_bits_ == 0) {
-      native_vector_bits_ = 32 * 8;
-      LOG(WARNING) << "set native vector to be " << native_vector_bits_ / 8
-                   << " for target " << target;
+  if (native_vector_bits_ == 0) {
+    const auto& arch = tm->getTargetTriple().getArch();
+    if (arch == llvm::Triple::x86_64) {
+      // for avx512
+      native_vector_bits_ = 512;
+    } else if (arch == llvm::Triple::x86) {
+      native_vector_bits_ = 256;
+    } else if (arch == llvm::Triple::arm || arch == llvm::Triple::aarch64) {
+      native_vector_bits_ = 128;
+    } else {
+      native_vector_bits_ = 128;
+      std::string arch_name = tm->getTargetTriple().getArchName();
+      LOG(WARNING) << "Set native vector bits to be 128 for " << arch_name;
     }
   }
 }

+void CodeGenLLVM::AddFunction(const LoweredFunc& f) {
+  this->AddFunctionInternal(f, false);
+}
+
 void CodeGenLLVM::InitFuncState() {
   var_map_.clear();
+  alias_var_set_.clear();
   align_map_.clear();
   alloc_storage_info_.clear();
-  alias_var_set_.clear();
-}
-
-void CodeGenLLVM::AddFunction(const LoweredFunc& f) {
-  AddFunctionInternal(f, false);
+  volatile_buf_.clear();
 }
 void CodeGenLLVM::AddFunctionInternal(const LoweredFunc& f, bool ret_void) {
   this->InitFuncState();
+  std::vector<llvm::Type*> arg_types;
   is_restricted_ = f->is_restricted;
-  CHECK(!module_->getFunction(f->name))
-      << "Function " << f->name << " already exists in module";
-  std::vector<llvm::Type*> arg_type;
   for (Var arg : f->args) {
     Type t = arg.type();
-    if (t.is_handle() && f->handle_data_type.count(arg)) {
-      arg_type.push_back(
-          LLVMType(f->handle_data_type[arg].type())->getPointerTo(GetGlobalAddressSpace()));
+    if (t.is_handle()) {
+      auto it = f->handle_data_type.find(arg);
+      if (it != f->handle_data_type.end()) {
+        arg_types.push_back(LLVMType((*it).second.type())
+                            ->getPointerTo(GetGlobalAddressSpace()));
+      } else {
+        arg_types.push_back(t_int8_->getPointerTo(GetGlobalAddressSpace()));
+      }
       if (!is_restricted_) {
         alias_var_set_.insert(arg.get());
       }
     } else {
-      arg_type.push_back(LLVMType(t));
+      arg_types.push_back(LLVMType(arg.type()));
     }
   }
   llvm::FunctionType* ftype = llvm::FunctionType::get(
-      ret_void ? t_void_ : t_int_, arg_type, false);
-  // setup the function.
-  function_ = llvm::cast<llvm::Function>(module_->getOrInsertFunction(f->name, ftype));
+      ret_void ? t_void_ : t_int_, arg_types, false);
+  CHECK(module_->getFunction(f->name) == nullptr)
+      << "Function " << f->name << " already exists in module";
+  function_ = llvm::Function::Create(
+      ftype, llvm::Function::ExternalLinkage,
+      f->name, module_.get());
   function_->setCallingConv(llvm::CallingConv::C);
-  // set handle argument to be non alias.
-  if (is_restricted_) {
-    for (size_t i = 0; i < f->args.size(); ++i) {
-      if (f->args[i].type().is_handle()) {
+  // set var map and align information
+  auto arg_it = function_->arg_begin();
+  for (size_t i = 0; i < f->args.size(); ++i, ++arg_it) {
+    llvm::Argument* v = &(*arg_it);
+    const Var& var = f->args[i];
+    var_map_[var.get()] = v;
+    if (is_restricted_) {
+      if (var.type().is_handle() && !alias_var_set_.count(var.get())) {
+        // set non alias.
 #if TVM_LLVM_VERSION >= 50
         function_->addParamAttr(i, llvm::Attribute::NoAlias);
 #else
@@ -126,17 +133,8 @@ void CodeGenLLVM::AddFunctionInternal(const LoweredFunc& f, bool ret_void) {
       }
     }
   }
-  size_t idx = 0;
-  for (auto it = function_->arg_begin();
-       it != function_->arg_end(); ++it, ++idx) {
-    llvm::Argument* v = &(*it);
-    var_map_[f->args[idx].get()] = v;
-  }
-  llvm::BasicBlock* block = llvm::BasicBlock::Create(*ctx_, "entry", function_);
-  builder_->SetInsertPoint(block);
+  llvm::BasicBlock* entry = llvm::BasicBlock::Create(*ctx_, "entry", function_);
+  builder_->SetInsertPoint(entry);
   this->VisitStmt(f->body);
   if (ret_void) {
     builder_->CreateRetVoid();
@@ -145,8 +143,24 @@ void CodeGenLLVM::AddFunctionInternal(const LoweredFunc& f, bool ret_void) {
   }
 }

+std::unique_ptr<llvm::Module> CodeGenLLVM::Finish() {
+  this->AddStartupFunction();
+  this->Optimize();
+  return std::move(module_);
+}
+
 void CodeGenLLVM::AddMainFunction(const std::string& entry_func_name) {
-  LOG(FATAL) << "Donot support add main function";
+  LOG(FATAL) << "not implemented";
+}
+
+llvm::Value* CodeGenLLVM::GetThreadIndex(const IterVar& iv) {
+  LOG(FATAL) << "not implemented";
+  return nullptr;
+}
+
+llvm::Value* CodeGenLLVM::CreateStorageSync(const Call* op) {
+  LOG(FATAL) << "not implemented";
+  return nullptr;
 }

 class FPassManager : public llvm::legacy::FunctionPassManager {
@@ -202,36 +216,48 @@ void CodeGenLLVM::Optimize() {
   mpass.run(*module_);
 }

-std::unique_ptr<llvm::Module> CodeGenLLVM::Finish() {
-  this->AddStartupFunction();
-  this->Optimize();
-  return std::move(module_);
+int CodeGenLLVM::NativeVectorBits(const runtime::StorageScope& storage_scope) const {
+  return native_vector_bits_;
+}
+
+unsigned CodeGenLLVM::GetGlobalAddressSpace() {
+  return 0;
 }
 llvm::Type* CodeGenLLVM::LLVMType(const Type& t) const {
-  llvm::Type* ret = nullptr;
-  if (t.is_uint() || t.is_int()) {
-    ret = llvm::Type::getIntNTy(*ctx_, t.bits());
+  if (t.is_handle()) {
+    CHECK_EQ(t.lanes(), 1);
+    return t_void_p_;
+  }
+  llvm::Type* etype = nullptr;
+  if (t.is_int() || t.is_uint()) {
+    etype = llvm::Type::getIntNTy(*ctx_, t.bits());
   } else if (t.is_float()) {
     switch (t.bits()) {
-      case 16: ret = llvm::Type::getHalfTy(*ctx_); break;
-      case 32: ret = llvm::Type::getFloatTy(*ctx_); break;
-      case 64: ret = llvm::Type::getDoubleTy(*ctx_); break;
-      default: LOG(FATAL) << "cannot handle " << t;
+      case 16: etype = llvm::Type::getHalfTy(*ctx_); break;
+      case 32: etype = llvm::Type::getFloatTy(*ctx_); break;
+      case 64: etype = llvm::Type::getDoubleTy(*ctx_); break;
+      default: LOG(FATAL) << "do not support " << t;
     }
-  } else {
-    CHECK(t.is_handle());
-    ret = t_void_p_;
   }
   if (t.lanes() != 1) {
-    ret = llvm::VectorType::get(ret, t.lanes());
+    return llvm::VectorType::get(etype, t.lanes());
+  } else {
+    return etype;
   }
-  return ret;
 }
 // Add tbaa alias information for load
-void CodeGenLLVM::AddAliasInfo(
-    llvm::Instruction* inst, const Variable* buffer, Expr index, Type t) {
+//
+// use a binary tree typed system to declare information
+// and allow alias to be distinguished across nodes.
+//
+// This trick comes from Halide's CodeGen_LLVM
+//
+void CodeGenLLVM::AddAliasInfo(llvm::Instruction* inst,
+                               const Variable* buffer,
+                               Expr index,
+                               Type type) {
   if (alias_var_set_.count(buffer) != 0) {
     // Mark all possibly aliased pointer as same type.
     llvm::MDNode* meta = md_tbaa_alias_set_;
@@ -242,7 +268,7 @@ void CodeGenLLVM::AddAliasInfo(
   }
   int base = 0, width = 0;
   // create meta-data for alias analysis
-  // Use a group of binary tree ranges.
+  // Use a group of binary tree ranges of memory banks.
   if (index.defined()) {
     const Ramp* ramp = index.as<Ramp>();
     if (ramp) {
@@ -267,7 +293,7 @@ void CodeGenLLVM::AddAliasInfo(
   std::ostringstream buffer_addr, buffer_type;
   buffer_addr << buffer;
   meta = md_builder_->createTBAAScalarTypeNode(buffer_addr.str(), meta);
-  buffer_type << t.element_of();
+  buffer_type << type.element_of();
   meta = md_builder_->createTBAAScalarTypeNode(buffer_type.str(), meta);
   // create a tree-shape access structure.
   if (width != 0) {
@@ -283,6 +309,36 @@ void CodeGenLLVM::AddAliasInfo(
       md_builder_->createTBAAStructTagNode(meta, meta, 0));
 }
+void CodeGenLLVM::GetAlignment(Type t,
+                               const Variable* buf_var,
+                               const Expr& index,
+                               int* p_alignment,
+                               int* p_native_bits) {
+  int max_align_bits = t.bits();
+  auto it = alloc_storage_info_.find(buf_var);
+  if (it != alloc_storage_info_.end()) {
+    const StorageInfo& info = it->second;
+    *p_native_bits = NativeVectorBits(info.scope);
+    max_align_bits = info.alignment * 8;
+  } else {
+    *p_native_bits = native_vector_bits_;
+  }
+  arith::ModularEntry me = arith::EvalModular(index, align_map_);
+  int align_bits = t.bits();
+  while (align_bits < max_align_bits &&
+         me.base % 2 == 0 &&
+         me.coeff % 2 == 0) {
+    me.base = me.base / 2;
+    me.coeff = me.coeff / 2;
+    align_bits *= 2;
+  }
+  if (align_bits < 8) {
+    align_bits = 8;
+  }
+  *p_alignment = align_bits / 8;
+}
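A worked example of the loop above, assuming a float32 buffer whose allocation is 64-byte aligned (so max_align_bits = 512) and an access at index 8*k + 4, for which EvalModular yields coeff = 8, base = 4:

    // index = 8*k + 4, float32 (t.bits() = 32), allocation aligned to 64 bytes:
    //   align_bits: 32 -> 64 -> 128   (base: 4 -> 2 -> 1, coeff: 8 -> 4 -> 2)
    //   loop stops when base turns odd => *p_alignment = 128 / 8 = 16 bytes
    // which is the largest power of two dividing the byte offset 4*(8*k + 4) = 32*k + 16.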
 llvm::Value* CodeGenLLVM::CreateBroadcast(llvm::Value* value, int lanes) {
   llvm::Constant* undef = llvm::UndefValue::get(
       llvm::VectorType::get(value->getType(), lanes));
@@ -292,26 +348,103 @@ llvm::Value* CodeGenLLVM::CreateBroadcast(llvm::Value* value, int lanes) {
   return builder_->CreateShuffleVector(value, undef, mask);
 }
-llvm::Value* CodeGenLLVM::CreateBufferPtr(
-    Type t, llvm::Value* buffer, llvm::Value* index) {
-  llvm::Type* elem_type = buffer->getType();
-  unsigned address_space = elem_type->getPointerAddressSpace();
-  llvm::Type* load_type = LLVMType(t)->getPointerTo(address_space);
-  if (load_type != elem_type) {
-    buffer = builder_->CreatePointerCast(buffer, load_type);
-  }
-  llvm::Constant* cindex = llvm::dyn_cast<llvm::Constant>(index);
-  if (cindex && cindex->isZeroValue()) {
-    return buffer;
-  }
-  return builder_->CreateInBoundsGEP(buffer, index);
-}
+llvm::Value* CodeGenLLVM::CreateVecSlice(llvm::Value* vec, int begin, int extent) {
+  int num_elems = static_cast<int>(vec->getType()->getVectorNumElements());
+  if (extent == num_elems && begin == 0) return vec;
+  CHECK_LT(begin + extent, num_elems);
+  std::vector<unsigned> indices;
+  for (int i = 0; i < extent; ++i) {
+    indices.push_back(begin + i);
+  }
+  return builder_->CreateShuffleVector(vec, vec, indices);
+}
+
+llvm::Value* CodeGenLLVM::CreateVecFlip(llvm::Value* vec) {
+  int num_elems = static_cast<int>(vec->getType()->getVectorNumElements());
+  std::vector<unsigned> indices;
+  for (int i = 0; i < num_elems; ++i) {
+    indices.push_back(num_elems - i - 1);
+  }
+  return builder_->CreateShuffleVector(vec, vec, indices);
+}
+
+llvm::Value* CodeGenLLVM::CreateVecPad(llvm::Value* vec, int target_lanes) {
+  llvm::Value* mask = llvm::UndefValue::get(LLVMType(Int(32, target_lanes)));
+  int num_elems = static_cast<int>(vec->getType()->getVectorNumElements());
+  if (num_elems == target_lanes) return vec;
+  CHECK_LT(num_elems, target_lanes);
+  for (int i = 0; i < num_elems; ++i) {
+    mask = builder_->CreateInsertElement(mask, ConstInt32(i), ConstInt32(i));
+  }
+  return builder_->CreateShuffleVector(vec, vec, mask);
+}
+
+llvm::Value* CodeGenLLVM::CreateVecConcat(std::vector<llvm::Value*> vecs) {
+  // concat vector, tree shape reduction
+  int total_lanes = 0;
+  for (llvm::Value* v : vecs) {
+    total_lanes += static_cast<int>(
+        v->getType()->getVectorNumElements());
+  }
+  while (vecs.size() > 1) {
+    for (size_t i = 0; i < vecs.size(); i += 2) {
+      if (i + 1 >= vecs.size()) {
+        vecs[i / 2] = vecs[i]; continue;
+      }
+      llvm::Value* lhs = vecs[i];
+      llvm::Value* rhs = vecs[i + 1];
+      int lanes = static_cast<int>(std::max(
+          lhs->getType()->getVectorNumElements(),
+          rhs->getType()->getVectorNumElements()));
+      lhs = CreateVecPad(lhs, lanes);
+      rhs = CreateVecPad(rhs, lanes);
+      std::vector<unsigned> mask;
+      for (int j = 0; j < lanes * 2; ++j) {
+        mask.push_back(j);
+      }
+      vecs[i / 2] = builder_->CreateShuffleVector(lhs, rhs, mask);
+    }
+    vecs.resize((vecs.size() + 1) / 2);
+  }
+  return CreateVecSlice(vecs[0], 0, total_lanes);
+}
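The tree-shape reduction above merges pairs of vectors per pass, so the shuffle depth grows as O(log n) in the number of input vectors rather than O(n). An illustrative trace for inputs of 4, 4, and 3 lanes (total_lanes = 11):

    // pass 1: {a<4>, b<4>, c<3>} -> {ab<8>, c<3>}       (one shuffle)
    // pass 2: pad c<3> to 8 lanes (undef tail), merge -> abc'<16>
    // final : CreateVecSlice(abc'<16>, 0, 11) drops the padded tail.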
+void CodeGenLLVM::CreateSerialFor(llvm::Value* begin,
+                                  llvm::Value* end,
+                                  llvm::Value* stride,
+                                  const VarExpr& loop_var,
+                                  const Stmt& body) {
+  using llvm::BasicBlock;
+  BasicBlock* pre_block = builder_->GetInsertBlock();
+  BasicBlock* for_begin = BasicBlock::Create(
+      *ctx_, "for_begin", function_);
+  BasicBlock* for_body = BasicBlock::Create(
+      *ctx_, "for_body", function_);
+  BasicBlock* for_end = BasicBlock::Create(
+      *ctx_, "for_end", function_);
+  builder_->CreateBr(for_begin);
+  builder_->SetInsertPoint(for_begin);
+  llvm::PHINode* loop_value = builder_->CreatePHI(begin->getType(), 2);
+  loop_value->addIncoming(begin, pre_block);
+  CHECK(!var_map_.count(loop_var.get()));
+  var_map_[loop_var.get()] = loop_value;
+  builder_->CreateCondBr(CreateLT(loop_var.type(), loop_value, end),
+                         for_body, for_end, md_very_likely_branch_);
+  builder_->SetInsertPoint(for_body);
+  this->VisitStmt(body);
+  var_map_.erase(loop_var.get());
+  llvm::Value* loop_next = CreateAdd(loop_var.type(), loop_value, stride);
+  loop_value->addIncoming(loop_next, builder_->GetInsertBlock());
+  builder_->CreateBr(for_begin);
+  builder_->SetInsertPoint(for_end);
+}
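Schematically, CreateSerialFor emits the classic rotated-loop CFG; an illustrative reduction to LLVM IR for an i32 loop variable follows (block names match the code above, the rest is hypothetical):

    pre:
      br label %for_begin
    for_begin:
      %i = phi i32 [ %begin, %pre ], [ %i.next, %body.tail ]
      %cond = icmp slt i32 %i, %end
      br i1 %cond, label %for_body, label %for_end, !prof !very_likely
    for_body:
      ; ...loop body, which may end in a different block (%body.tail)...
      %i.next = add nsw i32 %i, %stride
      br label %for_begin
    for_end:

Note that the second PHI incoming uses builder_->GetInsertBlock() rather than for_body, since visiting the body can leave the insertion point in a later block.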
+// cast operator
 llvm::Value* CodeGenLLVM::CreateCast(Type from, Type to, llvm::Value* value) {
   llvm::Type * target = LLVMType(to);
   if (value->getType() == target) return value;
-  if (from.is_handle() && from.is_handle()) {
+  if (to.is_handle()) {
     return builder_->CreateBitCast(value, target);
   } else if (!from.is_float() && !to.is_float()) {
     return builder_->CreateIntCast(value, target, from.is_int());
@@ -334,262 +467,159 @@ llvm::Value* CodeGenLLVM::CreateCast(Type from, Type to, llvm::Value* value) {
   }
 }
-llvm::CallInst* CodeGenLLVM::CreateCallExtern(
-    llvm::Type* ret,
-    const std::string& name,
-    const std::vector<llvm::Value*>& arg_values) {
-  std::vector<llvm::Type*> arg_types;
-  for (llvm::Value* v : arg_values) {
-    arg_types.push_back(v->getType());
-  }
-  llvm::FunctionType* ftype = llvm::FunctionType::get(ret, arg_types, false);
-  llvm::Function* f = module_->getFunction(name);
-  if (f == nullptr) {
-    f = llvm::Function::Create(
-        ftype, llvm::Function::ExternalLinkage, name, module_.get());
-  }
-  return builder_->CreateCall(f, arg_values);
-}
-
-llvm::Value* CodeGenLLVM::CreateCallExtern(const Call* op) {
-  std::vector<llvm::Value*> arg_values(op->args.size());
-  for (size_t i = 0; i < op->args.size(); ++i) {
-    arg_values[i] = MakeValue(op->args[i]);
-  }
-  return CreateCallExtern(LLVMType(op->type), op->name, arg_values);
-}
-
-llvm::Value* CodeGenLLVM::CreateScalarizedCall(
-    const Call* op, llvm::Function* f, const std::vector<llvm::Value*>& args) {
-  llvm::Value* value = llvm::UndefValue::get(LLVMType(op->type));
-  for (int i = 0; i < op->type.lanes(); ++i) {
-    std::vector<llvm::Value*> sargs(args.size());
-    for (size_t j = 0; j < args.size(); ++j) {
-      if (args[j]->getType()->isVectorTy()) {
-        sargs[j] = builder_->CreateExtractElement(args[j], ConstInt32(i));
-      } else {
-        sargs[j] = args[j];
-      }
-    }
-    llvm::CallInst* call = builder_->CreateCall(f, sargs);
-    if (op->is_pure()) {
-      call->setDoesNotAccessMemory();
-    }
-    call->setDoesNotThrow();
-    if (!call->getType()->isVoidTy()) {
-      value = builder_->CreateInsertElement(value, call, ConstInt32(i));
-    }
-  }
-  return value;
-}
-
-llvm::Value* CodeGenLLVM::GetVarValue(const Variable* v) const {
-  auto it = var_map_.find(v);
-  CHECK(it != var_map_.end())
-      << "Cannot find " << v->name_hint << " in the var map";
-  return it->second;
-}
 llvm::Value* CodeGenLLVM::GetConstString(const std::string& str) {
   auto it = str_map_.find(str);
-  if (it == str_map_.end()) {
-    llvm::Type* type = llvm::ArrayType::get(t_char_, str.length() + 1);
-    llvm::GlobalVariable *global = new llvm::GlobalVariable(
-        *module_, type, true, llvm::GlobalValue::PrivateLinkage, 0, ".str");
-    global->setAlignment(1);
-    global->setInitializer(llvm::ConstantDataArray::getString(*ctx_, str));
-    // useful constant value
-    llvm::Constant* zero = ConstInt32(0);
-    llvm::Constant* indices[] = {zero, zero};
-    llvm::Constant* sptr = llvm::ConstantExpr::getGetElementPtr(
-        type, global, indices);
-    str_map_[str] = sptr;
-    return sptr;
-  } else {
-    return it->second;
-  }
+  if (it != str_map_.end()) return it->second;
+  llvm::Type* type = llvm::ArrayType::get(t_char_, str.length() + 1);
+  llvm::GlobalVariable *global = new llvm::GlobalVariable(
+      *module_, type, true, llvm::GlobalValue::PrivateLinkage, 0, ".str");
+  global->setAlignment(1);
+  global->setInitializer(llvm::ConstantDataArray::getString(*ctx_, str));
+  llvm::Constant* zero = ConstInt32(0);
+  llvm::Constant* indices[] = {zero, zero};
+  llvm::Constant* ptr = llvm::ConstantExpr::getGetElementPtr(
+      type, global, indices);
+  str_map_[str] = ptr;
+  return ptr;
+}
+
+llvm::Value* CodeGenLLVM::CreateBufferPtr(
+    Type t, llvm::Value* buffer, llvm::Value* index) {
+  CHECK_EQ(t.lanes(), 1);
+  llvm::PointerType* btype = llvm::dyn_cast<llvm::PointerType>(buffer->getType());
+  CHECK(btype != nullptr);
+  llvm::PointerType* ptype = LLVMType(t)->getPointerTo(btype->getAddressSpace());
+  if (btype != ptype) {
+    buffer = builder_->CreatePointerCast(buffer, ptype);
+  }
+  return builder_->CreateInBoundsGEP(buffer, index);
 }
-void CodeGenLLVM::CreateSerialFor(llvm::Value* begin,
-                                  llvm::Value* end,
-                                  llvm::Value* stride,
-                                  const VarExpr& loop_var, const Stmt& body) {
-  using llvm::BasicBlock;
-  Type t = loop_var.type();
-  BasicBlock* for_head = BasicBlock::Create(
-      *ctx_, "for_head", function_);
-  BasicBlock* for_body = BasicBlock::Create(
-      *ctx_, "for_body", function_);
-  BasicBlock* for_end = BasicBlock::Create(
-      *ctx_, "for_end", function_);
-  BasicBlock* pre_block = builder_->GetInsertBlock();
-  builder_->CreateBr(for_head);
-  builder_->SetInsertPoint(for_head);
-  llvm::PHINode* index = builder_->CreatePHI(begin->getType(), 2);
-  index->addIncoming(begin, pre_block);
-  llvm::Value* cond = CreateLT(t, index, end);
-  builder_->CreateCondBr(cond, for_body, for_end, md_very_likely_branch_);
-  // body of for
-  builder_->SetInsertPoint(for_body);
-  var_map_[loop_var.get()] = index;
-  this->VisitStmt(body);
-  llvm::Value* next_index = CreateAdd(t, index, stride);
-  index->addIncoming(next_index, builder_->GetInsertBlock());
-  builder_->CreateBr(for_head);
-  // end of for
-  builder_->SetInsertPoint(for_end);
-}
+llvm::Value* CodeGenLLVM::GetVarValue(const Variable* v) const {
+  auto it = var_map_.find(v);
+  CHECK(it != var_map_.end()) << "cannot find variable " << v->name_hint;
+  return it->second;
+}
+
+llvm::Value* CodeGenLLVM::CreateCallExtern(const Call* op) {
+  std::vector<llvm::Value*> arg_value;
+  std::vector<llvm::Type*> arg_type;
+  for (size_t i = 0; i < op->args.size(); ++i) {
+    arg_value.push_back(MakeValue(op->args[i]));
+    arg_type.push_back(arg_value.back()->getType());
+  }
+  llvm::FunctionType* ftype = llvm::FunctionType::get(
+      LLVMType(op->type), arg_type, false);
+  llvm::Function* f = module_->getFunction(op->name);
+  if (f == nullptr) {
+    f = llvm::Function::Create(
+        ftype, llvm::Function::ExternalLinkage,
+        op->name, module_.get());
+  }
+  llvm::CallInst* call = builder_->CreateCall(f, arg_value);
+  return call;
+}
 llvm::Value* CodeGenLLVM::CreateIntrinsic(const Call* op) {
   if (op->is_intrinsic("llvm_intrin")) {
     CHECK_GE(op->args.size(), 1U);
-    std::vector<llvm::Value*> arg_values;
-    std::vector<llvm::Type*> arg_types;
+    llvm::Intrinsic::ID id = static_cast<llvm::Intrinsic::ID>(
+        op->args[0].as<UIntImm>()->value);
+    std::vector<llvm::Value*> arg_value;
+    std::vector<llvm::Type*> arg_type;
     for (size_t i = 1; i < op->args.size(); ++i) {
-      llvm::Value* v = MakeValue(op->args[i]);
-      arg_values.push_back(v);
-      arg_types.push_back(v->getType());
+      arg_value.push_back(MakeValue(op->args[i]));
+      arg_type.push_back(arg_value.back()->getType());
     }
-    auto id = static_cast<llvm::Intrinsic::ID>(op->args[0].as<UIntImm>()->value);
     llvm::Function* f = llvm::Intrinsic::getDeclaration(
-        module_.get(), id, arg_types);
-    return builder_->CreateCall(f, arg_values);
+        module_.get(), id, arg_type);
+    return builder_->CreateCall(f, arg_value);
   } else if (op->is_intrinsic("llvm_builtin")) {
     CHECK_GE(op->args.size(), 1U);
-    std::vector<llvm::Value*> arg_values;
+    llvm::Intrinsic::ID id = static_cast<llvm::Intrinsic::ID>(
+        op->args[0].as<UIntImm>()->value);
+    std::vector<llvm::Value*> arg_value;
     for (size_t i = 1; i < op->args.size(); ++i) {
-      llvm::Value* v = MakeValue(op->args[i]);
-      arg_values.push_back(v);
+      arg_value.push_back(MakeValue(op->args[i]));
     }
-    auto id = static_cast<llvm::Intrinsic::ID>(op->args[0].as<UIntImm>()->value);
-    llvm::Function* f = llvm::Intrinsic::getDeclaration(module_.get(), id);
-    return builder_->CreateCall(f, arg_values);
-  } else if (op->is_intrinsic(intrinsic::tvm_storage_sync)) {
-    return CreateStorageSync(op);
+    llvm::Function* f = llvm::Intrinsic::getDeclaration(module_.get(), id, {});
+    return builder_->CreateCall(f, arg_value);
   } else if (op->is_intrinsic(Call::bitwise_and)) {
-    CHECK_EQ(op->args.size(), 2U);
-    return builder_->CreateAnd(
-        MakeValue(op->args[0]), MakeValue(op->args[1]));
-  } else if (op->is_intrinsic(Call::bitwise_xor)) {
-    CHECK_EQ(op->args.size(), 2U);
-    return builder_->CreateXor(
-        MakeValue(op->args[0]), MakeValue(op->args[1]));
+    return builder_->CreateAnd(MakeValue(op->args[0]), MakeValue(op->args[1]));
   } else if (op->is_intrinsic(Call::bitwise_or)) {
-    CHECK_EQ(op->args.size(), 2U);
-    return builder_->CreateOr(
-        MakeValue(op->args[0]), MakeValue(op->args[1]));
+    return builder_->CreateOr(MakeValue(op->args[0]), MakeValue(op->args[1]));
   } else if (op->is_intrinsic(Call::bitwise_not)) {
-    CHECK_EQ(op->args.size(), 1U);
     return builder_->CreateNot(MakeValue(op->args[0]));
+  } else if (op->is_intrinsic(Call::bitwise_xor)) {
+    return builder_->CreateXor(MakeValue(op->args[0]), MakeValue(op->args[1]));
   } else if (op->is_intrinsic(Call::shift_left)) {
-    CHECK_EQ(op->args.size(), 2U);
-    return builder_->CreateShl(
-        MakeValue(op->args[0]), MakeValue(op->args[1]));
+    return builder_->CreateShl(MakeValue(op->args[0]), MakeValue(op->args[1]));
   } else if (op->is_intrinsic(Call::shift_right)) {
-    CHECK_EQ(op->args.size(), 2U);
-    if (op->type.is_int()) {
-      return builder_->CreateAShr(
-          MakeValue(op->args[0]), MakeValue(op->args[1]));
+    if (op->args[0].type().is_int()) {
+      return builder_->CreateAShr(MakeValue(op->args[0]), MakeValue(op->args[1]));
     } else {
-      return builder_->CreateLShr(
-          MakeValue(op->args[0]), MakeValue(op->args[1]));
+      return builder_->CreateLShr(MakeValue(op->args[0]), MakeValue(op->args[1]));
     }
+  } else if (op->is_intrinsic(intrinsic::tvm_storage_sync)) {
+    return CreateStorageSync(op);
   } else if (op->is_intrinsic(intrinsic::tvm_address_of)) {
     const Load *l = op->args[0].as<Load>();
     CHECK(op->args.size() == 1 && l);
-    return builder_->CreatePointerCast(
-        CreateBufferPtr(
-            l->type, GetVarValue(l->buffer_var.get()), MakeValue(l->index)),
-        t_void_p_);
+    llvm::Value* ptr = CreateBufferPtr(
+        l->type, MakeValue(l->buffer_var), MakeValue(l->index));
+    unsigned addrspace = llvm::dyn_cast<llvm::PointerType>(
+        ptr->getType())->getAddressSpace();
+    return builder_->CreatePointerCast(ptr, t_void_->getPointerTo(addrspace));
+  } else if (op->is_intrinsic(Call::reinterpret) && is_zero(op->args[0])) {
+    return llvm::Constant::getNullValue(t_void_p_);
   } else if (op->is_intrinsic(intrinsic::tvm_handle_is_null)) {
-    CHECK_EQ(op->args.size(), 1U);
-    llvm::Value* ptr = MakeValue(op->args[0]);
-    return builder_->CreateICmpEQ(
-        ptr, llvm::Constant::getNullValue(ptr->getType()));
+    return builder_->CreateIsNull(MakeValue(op->args[0]));
   } else if (op->is_intrinsic(intrinsic::tvm_if_then_else)) {
     using llvm::BasicBlock;
-    CHECK_EQ(op->args.size(), 3U);
-    llvm::Value* cond = MakeValue(op->args[0]);
     BasicBlock* then_block = BasicBlock::Create(
         *ctx_, "if_then", function_);
     BasicBlock* else_block = BasicBlock::Create(
         *ctx_, "if_else", function_);
     BasicBlock* end_block = BasicBlock::Create(
         *ctx_, "if_end", function_);
-    builder_->CreateCondBr(cond, then_block, else_block);
+    builder_->CreateCondBr(MakeValue(op->args[0]), then_block, else_block);
     builder_->SetInsertPoint(then_block);
     llvm::Value* then_value = MakeValue(op->args[1]);
     builder_->CreateBr(end_block);
     builder_->SetInsertPoint(else_block);
     llvm::Value* else_value = MakeValue(op->args[2]);
     builder_->CreateBr(end_block);
     builder_->SetInsertPoint(end_block);
-    llvm::PHINode* phi = builder_->CreatePHI(then_value->getType(), 2);
-    phi->addIncoming(then_value, then_block);
-    phi->addIncoming(else_value, else_block);
-    return phi;
-  } else if (op->is_intrinsic(Call::reinterpret) && is_zero(op->args[0])) {
-    return llvm::Constant::getNullValue(t_void_p_);
+    llvm::PHINode* value = builder_->CreatePHI(then_value->getType(), 2);
+    value->addIncoming(then_value, then_block);
+    value->addIncoming(else_value, else_block);
+    return value;
   } else {
-    LOG(FATAL) << "Unknown intrinstic " << op->name;
+    LOG(FATAL) << "unknown intrinsic " << op->name;
+    return nullptr;
   }
-  return nullptr;
 }
-// Get the corresponding thread index
-llvm::Value* CodeGenLLVM::GetThreadIndex(const IterVar& iv) {
-  LOG(FATAL) << "Donot support threading " << iv;
-  return nullptr;
-}
-
-llvm::Value* CodeGenLLVM::CreateStorageSync(const Call* op) {
-  LOG(FATAL) << "Donot support storage sync in CPU mode";
-  return nullptr;
-}
-
-int CodeGenLLVM::NativeVectorBits(const runtime::StorageScope& storage_scope) const {
-  // By default, we ask the buffer to be aligned to 64 bytes
-  return native_vector_bits_;
-}
-
-unsigned CodeGenLLVM::GetGlobalAddressSpace() {
-  return 0;
-}
+void CodeGenLLVM::Scalarize(const Expr& e,
+                            std::function<void(int i, llvm::Value* v)> f) {
+  if (const Ramp* ramp = e.as<Ramp>()) {
+    for (int i = 0; i < ramp->type.lanes(); ++i) {
+      Expr offset = arith::ComputeExpr<Add>(
+          ramp->base,
+          arith::ComputeExpr<Mul>(ramp->stride, i));
+      f(i, MakeValue(offset));
+    }
+  } else {
+    llvm::Value* value = MakeValue(e);
+    for (int i = 0; i < e.type().lanes(); ++i) {
+      f(i, builder_->CreateExtractElement(value, i));
+    }
+  }
+}
-void CodeGenLLVM::GetAlignment(
-    Type t, const Variable* buf_var, const Expr& index,
-    int* p_alignment, int* p_native_bits) {
-  int& alignment = *p_alignment;
-  int& native_bits = *p_native_bits;
-  // The storage scope.
-  StorageInfo info;
-  auto it = alloc_storage_info_.find(buf_var);
-  if (it != alloc_storage_info_.end()) {
-    info = it->second;
-  }
-  arith::ModularEntry m = EvalModular(index, align_map_);
-  native_bits = NativeVectorBits(info.scope);
-  alignment = t.element_of().bits();
-  // find alignment, cannot exceed allocated alignment
-  int max_align_bits = std::min(
-      info.alignment * 8, alignment * t.lanes());
-  while ((m.coeff & 1) == 0 &&
-         (m.base & 1) == 0 &&
-         alignment < max_align_bits &&
-         alignment < native_bits) {
-    m.coeff /= 2;
-    m.base /= 2;
-    alignment *= 2;
-  }
-  CHECK_EQ(alignment % 8, 0)
-      << "Load from memory that does not align to 8 bits";
-  alignment /= 8;
-}
-// visitor overrides
+// Visitors
 llvm::Value* CodeGenLLVM::VisitExpr_(const Variable* op) {
   return GetVarValue(op);
 }
@@ -597,7 +627,6 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const Variable* op) {
 llvm::Value* CodeGenLLVM::VisitExpr_(const Cast* op) {
   return CreateCast(op->value.type(), op->type, MakeValue(op->value));
 }
-
 llvm::Value* CodeGenLLVM::VisitExpr_(const IntImm* op) {
   return llvm::ConstantInt::getSigned(LLVMType(op->type), op->value);
 }
@@ -614,121 +643,114 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const StringImm* op) {
   return GetConstString(op->value);
 }
-#define DEFINE_CODEGEN_BINARY_OP(OP)                        \
-  llvm::Value* CodeGenLLVM::Create ## OP(                   \
-      Type t, llvm::Value* a, llvm::Value *b) {             \
-    if (t.is_float()) {                                     \
-      return builder_->CreateF ## OP (a, b);                \
-    } else if (t.is_int() && t.bits() >= 32) {              \
-      return builder_->CreateNSW ## OP (a, b);              \
-    } else {                                                \
-      return builder_->Create ## OP (a, b);                 \
-    }                                                       \
-  }
+#define DEFINE_CODEGEN_BINARY_OP(Op)                                    \
+  llvm::Value* CodeGenLLVM::Create ## Op(                               \
+      Type t, llvm::Value* a, llvm::Value *b) {                         \
+    if (t.is_int()) {                                                   \
+      if (t.bits() >= 32) {                                             \
+        return builder_->CreateNSW ## Op (a, b);                        \
+      } else {                                                          \
+        return builder_->Create ## Op (a, b);                           \
+      }                                                                 \
+    } else if (t.is_uint()) {                                           \
+      if (t.bits() >= 32) {                                             \
+        return builder_->CreateNUW ## Op (a, b);                        \
+      } else {                                                          \
+        return builder_->Create ## Op (a, b);                           \
+      }                                                                 \
+    } else {                                                            \
+      CHECK(t.is_float());                                              \
+      return builder_->CreateF ## Op (a, b);                            \
+    }                                                                   \
+  }                                                                     \
+  llvm::Value* CodeGenLLVM::VisitExpr_(const Op* op) {                  \
+    return Create ## Op(op->type, MakeValue(op->a), MakeValue(op->b));  \
+  }

 DEFINE_CODEGEN_BINARY_OP(Add);
 DEFINE_CODEGEN_BINARY_OP(Sub);
 DEFINE_CODEGEN_BINARY_OP(Mul);

-llvm::Value* CodeGenLLVM::VisitExpr_(const Add* op) {
-  return CreateAdd(op->type, MakeValue(op->a), MakeValue(op->b));
-}
-
-llvm::Value* CodeGenLLVM::VisitExpr_(const Sub* op) {
-  return CreateSub(op->type, MakeValue(op->a), MakeValue(op->b));
-}
-
-llvm::Value* CodeGenLLVM::VisitExpr_(const Mul* op) {
-  return CreateMul(op->type, MakeValue(op->a), MakeValue(op->b));
-}
+#define DEFINE_CODEGEN_CMP_OP(Op)                                          \
+  llvm::Value* CodeGenLLVM::Create ## Op(                                  \
+      Type t, llvm::Value* a, llvm::Value* b) {                            \
+    if (t.is_int()) {                                                      \
+      return builder_->CreateICmpS ## Op (a, b);                           \
+    } else if (t.is_uint()) {                                              \
+      return builder_->CreateICmpU ## Op (a, b);                           \
+    } else {                                                               \
+      CHECK(t.is_float());                                                 \
+      return builder_->CreateFCmpO ## Op (a, b);                           \
+    }                                                                      \
+  }                                                                        \
+  llvm::Value* CodeGenLLVM::VisitExpr_(const Op* op) {                     \
+    return Create ## Op(op->a.type(), MakeValue(op->a), MakeValue(op->b)); \
+  }
+
+DEFINE_CODEGEN_CMP_OP(LT);
+DEFINE_CODEGEN_CMP_OP(LE);
+DEFINE_CODEGEN_CMP_OP(GT);
+DEFINE_CODEGEN_CMP_OP(GE);
 llvm::Value* CodeGenLLVM::VisitExpr_(const Div* op) {
   llvm::Value* a = MakeValue(op->a);
+  llvm::Value* b = MakeValue(op->b);
   int shift;
-  if (op->type.is_float()) {
-    return builder_->CreateFDiv(a, MakeValue(op->b));
-  } else if ((op->type.is_int() || op->type.is_uint()) &&
-             is_const_power_of_two_integer(op->b, &shift)) {
+  if ((op->type.is_int() || op->type.is_uint()) &&
+      is_const_power_of_two_integer(op->b, &shift)) {
     return builder_->CreateAShr(a, shift);
+  } else if (op->type.is_int()) {
+    return builder_->CreateSDiv(a, b);
+  } else if (op->type.is_uint()) {
+    return builder_->CreateUDiv(a, b);
   } else {
-    llvm::Value* b = MakeValue(op->b);
-    if (op->type.is_int()) {
-      return builder_->CreateSDiv(a, b);
-    } else {
-      CHECK(op->type.is_uint());
-      return builder_->CreateUDiv(a, b);
-    }
+    CHECK(op->type.is_float());
+    return builder_->CreateFDiv(a, b);
   }
 }
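One subtlety in the power-of-two fast path above: an arithmetic right shift floors, while SDiv truncates toward zero, so the two agree only when the numerator is non-negative (the common case for buffer indices). Illustrative arithmetic, not part of the patch:

    //  7 / 2 == 3    and   7 >> 1 == 3     (agree)
    // -1 / 2 == 0    but  -1 >> 1 == -1    (AShr floors, SDiv truncates)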
 llvm::Value* CodeGenLLVM::VisitExpr_(const Mod* op) {
-  CHECK(!op->type.is_float())
-      << "Cannot do mod for float";
+  llvm::Value* a = MakeValue(op->a);
+  llvm::Value* b = MakeValue(op->b);
   if (op->type.is_int()) {
-    return builder_->CreateSRem(MakeValue(op->a), MakeValue(op->b));
+    return builder_->CreateSRem(a, b);
+  } else if (op->type.is_uint()) {
+    return builder_->CreateURem(a, b);
   } else {
-    CHECK(op->type.is_uint());
-    return builder_->CreateURem(MakeValue(op->a), MakeValue(op->b));
+    CHECK(op->type.is_float());
+    return builder_->CreateFRem(a, b);
   }
 }
 llvm::Value* CodeGenLLVM::VisitExpr_(const Min* op) {
   llvm::Value* a = MakeValue(op->a);
   llvm::Value* b = MakeValue(op->b);
-  llvm::Value* cond = CreateLT(op->a.type(), a, b);
-  return builder_->CreateSelect(cond, a, b);
+  return builder_->CreateSelect(CreateLT(op->a.type(), a, b), a, b);
 }

 llvm::Value* CodeGenLLVM::VisitExpr_(const Max* op) {
   llvm::Value* a = MakeValue(op->a);
   llvm::Value* b = MakeValue(op->b);
-  llvm::Value* cond = CreateGT(op->a.type(), a, b);
-  return builder_->CreateSelect(cond, a, b);
+  return builder_->CreateSelect(CreateGT(op->a.type(), a, b), a, b);
 }
-#define DEFINE_CODEGEN_CMP_OP(OP)                  \
-  llvm::Value* CodeGenLLVM::Create ## OP(          \
-      Type t, llvm::Value* a, llvm::Value* b) {    \
-    if (t.is_float()) {                            \
-      return builder_->CreateFCmpO ## OP (a, b);   \
-    } else if (t.is_int()) {                       \
-      return builder_->CreateICmpS ## OP (a, b);   \
-    } else {                                       \
-      return builder_->CreateICmpU ## OP (a, b);   \
-    }                                              \
-  }
-
-DEFINE_CODEGEN_CMP_OP(LT);
-DEFINE_CODEGEN_CMP_OP(LE);
-DEFINE_CODEGEN_CMP_OP(GT);
-DEFINE_CODEGEN_CMP_OP(GE);
-
-llvm::Value* CodeGenLLVM::VisitExpr_(const LT* op) {
-  return CreateLT(op->a.type(), MakeValue(op->a), MakeValue(op->b));
-}
-
-llvm::Value* CodeGenLLVM::VisitExpr_(const LE* op) {
-  return CreateLE(op->a.type(), MakeValue(op->a), MakeValue(op->b));
-}
-
-llvm::Value* CodeGenLLVM::VisitExpr_(const GT* op) {
-  return CreateGT(op->a.type(), MakeValue(op->a), MakeValue(op->b));
-}
-
-llvm::Value* CodeGenLLVM::VisitExpr_(const GE* op) {
-  return CreateGE(op->a.type(), MakeValue(op->a), MakeValue(op->b));
-}
 llvm::Value* CodeGenLLVM::VisitExpr_(const EQ* op) {
-  if (op->a.type().is_float()) {
-    return builder_->CreateFCmpOEQ(MakeValue(op->a), MakeValue(op->b));
+  llvm::Value* a = MakeValue(op->a);
+  llvm::Value* b = MakeValue(op->b);
+  if (op->a.type().is_int() || op->a.type().is_uint()) {
+    return builder_->CreateICmpEQ(a, b);
   } else {
-    return builder_->CreateICmpEQ(MakeValue(op->a), MakeValue(op->b));
+    return builder_->CreateFCmpOEQ(a, b);
   }
 }

 llvm::Value* CodeGenLLVM::VisitExpr_(const NE* op) {
-  if (op->a.type().is_float()) {
-    return builder_->CreateFCmpONE(MakeValue(op->a), MakeValue(op->b));
+  llvm::Value* a = MakeValue(op->a);
+  llvm::Value* b = MakeValue(op->b);
+  if (op->a.type().is_int() || op->a.type().is_uint()) {
+    return builder_->CreateICmpNE(a, b);
   } else {
-    return builder_->CreateICmpNE(MakeValue(op->a), MakeValue(op->b));
+    return builder_->CreateFCmpONE(a, b);
   }
 }
@@ -752,345 +774,161 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const Select* op) {
 }
 llvm::Value* CodeGenLLVM::VisitExpr_(const Let* op) {
-  llvm::Value* v = MakeValue(op->value);
   CHECK(!var_map_.count(op->var.get()));
-  CHECK(!align_map_.count(op->var.get()));
-  var_map_[op->var.get()] = v;
-  align_map_[op->var.get()] = arith::EvalModular(op->value, align_map_);
+  var_map_[op->var.get()] = MakeValue(op->value);
+  align_map_[op->var.get()] = EvalModular(op->value, align_map_);
   return MakeValue(op->body);
 }
-llvm::Value* CodeGenLLVM::VisitExpr_(const Broadcast* op) {
-  return CreateBroadcast(MakeValue(op->value), op->lanes);
-}
-
-llvm::Value* CodeGenLLVM::VisitExpr_(const Ramp* op) {
-  Type t = op->type;
-  llvm::Value* base = MakeValue(op->base);
-  llvm::Value* stride = MakeValue(op->stride);
-  llvm::Value* value = llvm::UndefValue::get(LLVMType(t));
-  for (int i = 0; i < t.lanes(); ++i) {
-    if (i != 0) {
-      base = CreateAdd(t, base, stride);
-    }
-    value = builder_->CreateInsertElement(
-        value, base, llvm::ConstantInt::get(t_int32_, i));
-  }
-  return value;
-}
-
-void CodeGenLLVM::Scalarize(
-    const Expr& e,
-    std::function<void(int i, llvm::Value* v)> f) {
-  const Ramp* ramp = e.as<Ramp>();
-  Type t = e.type();
-  if (ramp) {
-    for (int i = 0; i < t.lanes(); ++i) {
-      Expr offset = arith::ComputeExpr<Add>(
-          ramp->base,
-          arith::ComputeExpr<Mul>(ramp->stride, i));
-      f(i, MakeValue(offset));
-    }
-  } else {
-    llvm::Value* index = MakeValue(e);
-    for (int i = 0; i < t.lanes(); ++i) {
-      f(i, builder_->CreateExtractElement(index, ConstInt32(i)));
-    }
-  }
-}
-
-llvm::Value* CodeGenLLVM::CreateVecFlip(llvm::Value* vec) {
-  int lanes = static_cast<int>(vec->getType()->getVectorNumElements());
-  std::vector<llvm::Constant*> indices;
-  for (int i = lanes; i != 0; --i) {
-    indices.push_back(ConstInt32(i - 1));
-  }
-  llvm::Constant* undef = llvm::UndefValue::get(vec->getType());
-  return builder_->CreateShuffleVector(
-      vec, undef, llvm::ConstantVector::get(indices));
-}
-
-llvm::Value* CodeGenLLVM::CreateVecSlice(
-    llvm::Value* vec, int begin, int lanes) {
-  int total_lanes = static_cast<int>(vec->getType()->getVectorNumElements());
-  CHECK_LE(begin + lanes, total_lanes);
-  if (lanes == total_lanes && begin == 0) return vec;
-  std::vector<llvm::Constant*> indices;
-  for (int i = 0; i < lanes; ++i) {
-    indices.push_back(ConstInt32(begin + i));
-  }
-  llvm::Constant* undef = llvm::UndefValue::get(vec->getType());
-  return builder_->CreateShuffleVector(
-      vec, undef, llvm::ConstantVector::get(indices));
-}
-
-llvm::Value* CodeGenLLVM::CreateVecPad(llvm::Value* vec, int target_lanes) {
-  int lanes = static_cast<int>(vec->getType()->getVectorNumElements());
-  if (target_lanes == lanes) return vec;
-  CHECK_GT(target_lanes, lanes);
-  int pad_lanes = target_lanes - lanes;
-  llvm::Constant* undef = llvm::UndefValue::get(
-      llvm::VectorType::get(vec->getType()->getVectorElementType(), pad_lanes));
-  std::vector<llvm::Constant*> indices;
-  for (int i = 0; i < target_lanes; ++i) {
-    indices.push_back(ConstInt32(i));
-  }
-  return builder_->CreateShuffleVector(
-      vec, undef, llvm::ConstantVector::get(indices));
-}
-
-llvm::Value* CodeGenLLVM::CreateVecConcat(
-    std::vector<llvm::Value*> vec) {
-  CHECK_NE(vec.size(), 0U);
-  int target_lanes = 0;
-  for (llvm::Value* v : vec) {
-    target_lanes += static_cast<int>(v->getType()->getVectorNumElements());
-  }
-  // tree shape merging
-  while (vec.size() != 1) {
-    std::vector<llvm::Value*> merged;
-    for (size_t i = 0; i < vec.size() - 1; i += 2) {
-      llvm::Value* v1 = vec[i];
-      llvm::Value* v2 = vec[i + 1];
-      int w1 = static_cast<int>(v1->getType()->getVectorNumElements());
-      int w2 = static_cast<int>(v2->getType()->getVectorNumElements());
-      int w = std::max(w1, w2);
-      v1 = CreateVecPad(v1, w);
-      v2 = CreateVecPad(v2, w);
-      std::vector<llvm::Constant*> indices;
-      for (int j = 0; j < w * 2; ++j) {
-        indices.push_back(ConstInt32(j));
-      }
-      merged.push_back(
-          builder_->CreateShuffleVector(
-              v1, v2, llvm::ConstantVector::get(indices)));
-    }
-    if (vec.size() % 2 == 1) {
-      merged.push_back(vec.back());
-    }
-    vec = merged;
-  }
-  return CreateVecSlice(vec[0], 0, target_lanes);
-}
-
-llvm::Value* CodeGenLLVM::VisitExpr_(const Load* op) {
-  CHECK(is_one(op->predicate))
-      << "Predicated Load is not supported";
-  Type t = op->type;
-  const Ramp* ramp = op->index.as<Ramp>();
-  llvm::Value* buf = GetVarValue(op->buffer_var.get());
-  if (t.is_scalar()) {
-    llvm::LoadInst* inst = builder_->CreateAlignedLoad(
-        CreateBufferPtr(t, buf, MakeValue(op->index)),
-        data_layout_->getTypeAllocSize(LLVMType(t)));
-    AddAliasInfo(inst, op->buffer_var.get(), op->index, op->type);
-    return inst;
-  } else if (ramp && is_one(ramp->stride)) {
-    int alignment, native_bits;
-    GetAlignment(t, op->buffer_var.get(), ramp->base,
-                 &alignment, &native_bits);
-    int total_lanes = t.lanes();
-    int step = native_bits / t.bits();
-    std::vector<llvm::Value*> loads;
-    for (int offset = 0; offset < total_lanes; offset += step) {
-      int lanes = std::min(step, total_lanes - offset);
-      Expr base = arith::ComputeExpr<Add>(
-          ramp->base, make_const(ramp->base.type(), offset));
-      llvm::Value* ptr = CreateBufferPtr(t.element_of(), buf, MakeValue(base));
-      llvm::Type* vtype = llvm::VectorType::get(
-          LLVMType(t.element_of()), lanes)->getPointerTo(
-              ptr->getType()->getPointerAddressSpace());
-      llvm::LoadInst* inst = builder_->CreateAlignedLoad(
-          builder_->CreatePointerCast(ptr, vtype), alignment);
-      AddAliasInfo(inst, op->buffer_var.get(),
-                   Ramp::make(base, make_const(base.type(), 1), lanes), op->type);
-      loads.push_back(inst);
-    }
-    return CreateVecConcat(loads);
-  } else if (ramp && is_const(ramp->stride, 2)) {
-    int alignment, native_bits;
-    GetAlignment(t, op->buffer_var.get(), ramp->base,
-                 &alignment, &native_bits);
-    arith::ModularEntry e = arith::EvalModular(ramp->base, align_map_);
-    Type bt = ramp->base.type();
-    int first_shift, next_shift;
-    // If it is even base, and native alignments is bigger than twice
-    // of the type, to ensure safe loading.
-    if (e.coeff % 2 == 0 &&
-        e.base % 2 == 0 &&
-        native_bits >= t.bits() * 2) {
-      first_shift = 0;
-      next_shift = 0;
-    } else if (e.coeff % 2 == 0 && e.base % 2 == 1) {
-      // odd base, shift both to left.
-      first_shift = -1;
-      next_shift = -1;
-    } else {
-      // safe option: keep the first part, shift the next part left.
-      first_shift = 0;
-      next_shift = -1;
-    }
-    llvm::Value* first = MakeValue(Load::make(
-        t, op->buffer_var,
-        Ramp::make(arith::ComputeExpr<Add>(
-            ramp->base, make_const(bt, first_shift)),
-                   make_const(bt, 1), ramp->lanes),
-        const_true(t.lanes())));
-    llvm::Value* next = MakeValue(Load::make(
-        t, op->buffer_var,
-        Ramp::make(arith::ComputeExpr<Add>(
-            ramp->base, make_const(bt, ramp->lanes + next_shift)),
-                   make_const(bt, 1), ramp->lanes),
-        const_true(t.lanes())));
-    // shuffle
-    std::vector<llvm::Constant*> indices;
-    int target_index = 0;
-    for (int i = 0; i < ramp->lanes; ++i) {
-      int idx = first_shift + i;
-      if (idx == target_index) {
-        indices.push_back(ConstInt32(i));
-        target_index += 2;
-      }
-    }
-    for (int i = 0; i < ramp->lanes; ++i) {
-      int idx = ramp->lanes + next_shift + i;
-      if (idx == target_index) {
-        indices.push_back(ConstInt32(i + ramp->lanes));
-        target_index += 2;
-      }
-    }
-    CHECK_EQ(indices.size(), static_cast<size_t>(ramp->lanes));
-    return builder_->CreateShuffleVector(
-        first, next, llvm::ConstantVector::get(indices));
-  } else if (ramp && is_const(ramp->stride, -1)) {
-    int lanes = ramp->type.lanes();
-    Expr neg_ramp = Ramp::make(
-        arith::ComputeExpr<Sub>(
-            ramp->base,
-            make_const(ramp->base.type(), lanes - 1)),
-        make_const(ramp->base.type(), 1),
-        lanes);
-    // load value then flip
-    llvm::Value* v = MakeValue(
-        Load::make(t, op->buffer_var, neg_ramp, const_true(t.lanes())));
-    return CreateVecFlip(v);
-  } else {
-    llvm::Value* ret = llvm::UndefValue::get(LLVMType(t));
-    Scalarize(op->index, [&](int i, llvm::Value* offset) {
-      llvm::Value* ptr = CreateBufferPtr(t.element_of(), buf, offset);
-      llvm::LoadInst* inst = builder_->CreateAlignedLoad(
-          ptr, data_layout_->getTypeAllocSize(LLVMType(t)));
-      AddAliasInfo(inst, op->buffer_var.get(), Expr(), op->type);
-      ret = builder_->CreateInsertElement(ret, inst, ConstInt32(i));
-    });
-    return ret;
-  }
-}
+llvm::Value* CodeGenLLVM::VisitExpr_(const Load* op) {
+  Type t = op->type;
+  int alignment, native_bits;
+  bool is_volatile = volatile_buf_.count(op->buffer_var.get());
+  GetAlignment(t, op->buffer_var.get(), op->index, &alignment, &native_bits);
+  llvm::Value* buffer = MakeValue(op->buffer_var);
+  llvm::Value* index = MakeValue(op->index);
+
+  if (t.lanes() == 1) {
+    llvm::Value* ptr = CreateBufferPtr(t, buffer, index);
+    llvm::LoadInst* load = builder_->CreateAlignedLoad(ptr, alignment, is_volatile);
+    AddAliasInfo(load, op->buffer_var.get(), op->index, t);
+    return load;
+  } else {
+    // vector load
+    unsigned addrspace = llvm::dyn_cast<llvm::PointerType>(
+        buffer->getType())->getAddressSpace();
+    if (const Ramp* ramp = op->index.as<Ramp>()) {
+      if (is_one(ramp->stride)) {
+        CHECK_EQ(ramp->lanes, t.lanes());
+        llvm::Value* ptr = CreateBufferPtr(
+            t.element_of(), buffer, MakeValue(ramp->base));
+        ptr = builder_->CreatePointerCast(ptr, LLVMType(t)->getPointerTo(addrspace));
+        llvm::LoadInst* load = builder_->CreateAlignedLoad(ptr, alignment, is_volatile);
+        AddAliasInfo(load, op->buffer_var.get(), op->index, t);
+        return load;
+      }
+    }
+  }
+  // scalarized load.
+  int basic_align = t.bits() / 8;
+  llvm::Value* ret = llvm::UndefValue::get(LLVMType(t));
+  auto f = [&](int i, llvm::Value* index) {
+    llvm::Value* ptr = CreateBufferPtr(t.element_of(), buffer, index);
+    llvm::LoadInst* load = builder_->CreateAlignedLoad(
+        ptr, basic_align, is_volatile);
+    ret = builder_->CreateInsertElement(ret, load, ConstInt32(i));
+    AddAliasInfo(load, op->buffer_var.get(), Expr(), t);
+  };
+  this->Scalarize(op->index, f);
+  return ret;
+}
+
+llvm::Value* CodeGenLLVM::VisitExpr_(const Call* op) {
+  if (op->call_type == Call::Intrinsic ||
+      op->call_type == Call::PureIntrinsic) {
+    return CreateIntrinsic(op);
+  } else if (op->call_type == Call::Extern ||
+             op->call_type == Call::PureExtern) {
+    return CreateCallExtern(op);
+  } else {
+    LOG(FATAL) << "Unknown call type ";
+    return nullptr;
+  }
+}
+
+llvm::Value* CodeGenLLVM::VisitExpr_(const Ramp* op) {
+  llvm::Value* vec = llvm::UndefValue::get(LLVMType(op->type));
+  for (int i = 0; i < op->lanes; ++i) {
+    vec = builder_->CreateInsertElement(
+        vec, MakeValue(op->base + op->stride * make_const(op->stride.type(), i)),
+        ConstInt32(i));
+  }
+  return vec;
+}
+
+llvm::Value* CodeGenLLVM::VisitExpr_(const Broadcast* op) {
+  return CreateBroadcast(MakeValue(op->value), op->lanes);
+}
-// stmts
 void CodeGenLLVM::VisitStmt_(const Store* op) {
-  CHECK(is_one(op->predicate))
-      << "Predicated Load is not supported";
-  llvm::Value* value = MakeValue(op->value);
+  CHECK(is_one(op->predicate));
   Type t = op->value.type();
-  const Ramp* ramp = op->index.as<Ramp>();
-  llvm::Value* buf = GetVarValue(op->buffer_var.get());
-  if (t.is_scalar()) {
-    llvm::StoreInst* inst = builder_->CreateAlignedStore(
-        value,
-        CreateBufferPtr(t, buf, MakeValue(op->index)),
-        data_layout_->getTypeAllocSize(value->getType()));
-    AddAliasInfo(inst, op->buffer_var.get(), op->index, op->value.type());
-  } else if (ramp && is_one(ramp->stride)) {
-    int alignment, native_bits;
-    GetAlignment(t, op->buffer_var.get(), ramp->base,
-                 &alignment, &native_bits);
-    int total_lanes = t.lanes();
-    int step = native_bits / t.bits();
-    // vector store.
-    for (int offset = 0; offset < total_lanes; offset += step) {
-      int lanes = std::min(step, total_lanes - offset);
-      Expr base = arith::ComputeExpr<Add>(
-          ramp->base, make_const(ramp->base.type(), offset));
-      llvm::Value* ptr = CreateBufferPtr(t.element_of(), buf, MakeValue(base));
-      llvm::Type* vtype = llvm::VectorType::get(
-          LLVMType(t.element_of()), lanes)->getPointerTo(
-              ptr->getType()->getPointerAddressSpace());
-      llvm::StoreInst* inst = builder_->CreateAlignedStore(
-          CreateVecSlice(value, offset, lanes),
-          builder_->CreatePointerCast(ptr, vtype), alignment);
-      AddAliasInfo(inst, op->buffer_var.get(),
-                   Ramp::make(base, make_const(base.type(), 1), lanes), op->value.type());
-    }
-  } else {
-    Scalarize(op->index, [&](int i, llvm::Value* offset) {
-      llvm::Value* ptr = CreateBufferPtr(t.element_of(), buf, offset);
-      llvm::StoreInst* inst = builder_->CreateAlignedStore(
-          builder_->CreateExtractElement(value, ConstInt32(i)),
-          ptr, data_layout_->getTypeAllocSize(LLVMType(t)));
-      AddAliasInfo(inst, op->buffer_var.get(), Expr(), op->value.type());
-    });
-  }
-}
-
-llvm::Value* CodeGenLLVM::VisitExpr_(const Call* op) {
-  if (op->call_type == Call::Intrinsic ||
-      op->call_type == Call::PureIntrinsic) {
-    return CreateIntrinsic(op);
-  } else {
-    CHECK(op->call_type == Call::Extern ||
-          op->call_type == Call::PureExtern);
-    return CreateCallExtern(op);
-  }
+  int alignment, native_bits;
+  bool is_volatile = volatile_buf_.count(op->buffer_var.get());
+  GetAlignment(t, op->buffer_var.get(), op->index, &alignment, &native_bits);
+  llvm::Value* buffer = MakeValue(op->buffer_var);
+  llvm::Value* index = MakeValue(op->index);
+  llvm::Value* value = MakeValue(op->value);
+
+  if (t.lanes() == 1) {
+    llvm::Value* ptr = CreateBufferPtr(t, buffer, index);
+    llvm::StoreInst* store =
+        builder_->CreateAlignedStore(value, ptr, alignment, is_volatile);
+    AddAliasInfo(store, op->buffer_var.get(), op->index, op->value.type());
+    return;
+  } else {
+    // vector store
+    unsigned addrspace = llvm::dyn_cast<llvm::PointerType>(
+        buffer->getType())->getAddressSpace();
+    if (const Ramp* ramp = op->index.as<Ramp>()) {
+      if (is_one(ramp->stride)) {
+        CHECK_EQ(ramp->lanes, t.lanes());
+        llvm::Value* ptr = CreateBufferPtr(
+            t.element_of(), buffer, MakeValue(ramp->base));
+        ptr = builder_->CreatePointerCast(ptr, LLVMType(t)->getPointerTo(addrspace));
+        llvm::StoreInst* store =
+            builder_->CreateAlignedStore(value, ptr, alignment, is_volatile);
+        AddAliasInfo(store, op->buffer_var.get(), op->index, op->value.type());
+        return;
+      }
+    }
+  }
+  CHECK_GE(t.bits(), 8);
+  // scalarized store.
+  int basic_align = t.bits() / 8;
+  auto f = [&](int i, llvm::Value* index) {
+    llvm::Value* ptr = CreateBufferPtr(t.element_of(), buffer, index);
+    llvm::StoreInst* store = builder_->CreateAlignedStore(
+        builder_->CreateExtractElement(value, i),
+        ptr, basic_align, is_volatile);
+    AddAliasInfo(store, op->buffer_var.get(), Expr(), op->value.type());
+  };
+  this->Scalarize(op->index, f);
 }
 void CodeGenLLVM::VisitStmt_(const For* op) {
   CHECK(is_zero(op->min));
-  if (op->for_type == ForType::Serial) {
-    CreateSerialFor(ConstInt32(0),
-                    MakeValue(op->extent),
-                    ConstInt32(1),
-                    op->loop_var,
-                    op->body);
-  } else {
-    LOG(FATAL) << "cannot handle for type " << op->for_type;
-  }
+  CHECK(op->for_type == ForType::Serial);
+  CreateSerialFor(MakeValue(op->min), MakeValue(op->extent),
+                  ConstInt32(1), op->loop_var, op->body);
 }
void CodeGenLLVM::VisitStmt_(const IfThenElse* op) {
  using llvm::BasicBlock;
  llvm::Value* cond = MakeValue(op->condition);
  BasicBlock* then_block = BasicBlock::Create(
      *ctx_, "if_then", function_);
  BasicBlock* end_block = BasicBlock::Create(
      *ctx_, "if_end", function_);
  if (op->else_case.defined()) {
    BasicBlock* else_block = BasicBlock::Create(
        *ctx_, "if_else", function_);
    builder_->CreateCondBr(cond, then_block, else_block);
    builder_->SetInsertPoint(then_block);
    this->VisitStmt(op->then_case);
    builder_->CreateBr(end_block);
    builder_->SetInsertPoint(else_block);
    this->VisitStmt(op->else_case);
    builder_->CreateBr(end_block);
  } else {
    builder_->CreateCondBr(cond, then_block, end_block, md_very_likely_branch_);
    builder_->SetInsertPoint(then_block);
    this->VisitStmt(op->then_case);
    builder_->CreateBr(end_block);
  }
  builder_->SetInsertPoint(end_block);
}
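In the else-less branch the conditional jump carries md_very_likely_branch_, telling later passes to lay out and optimize for the then-block. A small self-contained sketch of attaching such profile metadata (LLVM 4.x-era API assumed; the function name is illustrative):

#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/MDBuilder.h"

void EmitLikelyBranch(llvm::LLVMContext& ctx, llvm::IRBuilder<>& builder,
                      llvm::Value* cond, llvm::BasicBlock* then_block,
                      llvm::BasicBlock* end_block) {
  llvm::MDBuilder md_builder(ctx);
  // A 2^20 : 1 weight marks the condition as almost always true.
  llvm::MDNode* very_likely = md_builder.createBranchWeights(1 << 20, 1);
  builder.CreateCondBr(cond, then_block, end_block, very_likely);
}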
void CodeGenLLVM::VisitStmt_(const Allocate* op) {
  CHECK(!is_zero(op->condition));
  llvm::Value* buf = nullptr;
@@ -1100,20 +938,22 @@
  } else {
    int32_t constant_size = op->constant_allocation_size();
    CHECK_GT(constant_size, 0)
        << "Can only handle constant size stack allocation";
    StorageInfo& info = alloc_storage_info_[op->buffer_var.get()];
    if (constant_size % 4 == 0 && info.alignment == 0) {
      info.alignment = GetTempAllocaAlignment(op->type, constant_size);
    }
    // maximum necessary alignment in the NV devices
    if (info.alignment > 16) {
      info.alignment = 16;
    }
    llvm::AllocaInst* alloca = builder_->CreateAlloca(
        LLVMType(op->type), ConstInt32(constant_size));
    if (alloca->getAlignment() < static_cast<uint32_t>(info.alignment)) {
      alloca->setAlignment(info.alignment);
    }
    info.alignment = alloca->getAlignment();
    buf = alloca;
  }
  buf = builder_->CreatePointerCast(
      buf, LLVMType(op->type)->getPointerTo(
@@ -1124,29 +964,29 @@
}
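The 16-byte cap reflects that the widest vectorized access NVIDIA hardware issues (e.g. a float4 load) is 16 bytes, so any larger stack alignment is wasted. A hypothetical helper mirroring the clamp-then-raise pattern above (not part of the commit):

#include <algorithm>
#include "llvm/IR/IRBuilder.h"

llvm::AllocaInst* EmitAlignedAlloca(llvm::IRBuilder<>& builder,
                                    llvm::Type* type, llvm::Value* num_elems,
                                    int proposed_alignment) {
  // Cap the heuristic alignment at the 16-byte maximum useful on NV devices.
  int alignment = std::min(proposed_alignment, 16);
  llvm::AllocaInst* alloca = builder.CreateAlloca(type, num_elems);
  // Only raise the alloca's alignment; never lower what LLVM already chose.
  if (alloca->getAlignment() < static_cast<uint32_t>(alignment)) {
    alloca->setAlignment(alignment);
  }
  return alloca;
}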
void CodeGenLLVM::VisitStmt_(const AttrStmt* op) {
  if (op->attr_key == attr::thread_extent) {
    IterVar iv(op->node.node_);
    if (iv->thread_tag.length() != 0) {
      if (!var_map_.count(iv->var.get())) {
        var_map_[iv->var.get()] = GetThreadIndex(iv);
      }
    }
  } else if (op->attr_key == ir::attr::storage_scope) {
    const Variable* v = op->node.as<Variable>();
    CHECK(v);
    alloc_storage_info_[v].scope =
        runtime::StorageScope::make(op->value.as<StringImm>()->value);
  } else if (op->attr_key == ir::attr::storage_alignment) {
    const Variable* v = op->node.as<Variable>();
    CHECK(v);
    alloc_storage_info_[v].alignment =
        static_cast<int>(op->value.as<IntImm>()->value);
  } else if (op->attr_key == ir::attr::volatile_scope) {
    const Variable* v = op->node.as<Variable>();
    CHECK(v);
    volatile_buf_.insert(v);
  }
  this->VisitStmt(op->body);
}
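The volatile_scope branch only records the buffer variable; the flag is consumed where loads and stores are emitted, which is why the store path above threads an is_volatile argument through CreateAlignedStore. The pattern, sketched as a fragment in this class's context (illustrative, not a verbatim quote of the file):

// Inside the Load/Store visitors: buffers recorded in volatile_buf_
// turn every access into a volatile LLVM instruction.
bool is_volatile = volatile_buf_.count(op->buffer_var.get()) != 0;
llvm::StoreInst* store =
    builder_->CreateAlignedStore(value, ptr, alignment, is_volatile);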
void CodeGenLLVM::VisitStmt_(const AssertStmt* op) {
@@ -1178,7 +1018,6 @@
}
void CodeGenLLVM::VisitStmt_(const LetStmt* op) {
  CHECK(!var_map_.count(op->var.get()));
  CHECK(!align_map_.count(op->var.get()));
  if (op->var.type().is_handle()) {
@@ -1186,24 +1025,25 @@
      alias_var_set_.insert(op->var.get());
    }
  }
  var_map_[op->var.get()] = MakeValue(op->value);
  align_map_[op->var.get()] = EvalModular(op->value, align_map_);
  this->VisitStmt(op->body);
}
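EvalModular tracks each bound value as coeff * k + base, which is how alignment information survives a let-binding. A self-contained sketch of the arithmetic with a hypothetical struct mirroring arith::ModularEntry (field names assumed):

// Hypothetical mirror of arith::ModularEntry: value = coeff * k + base.
struct ModularEntry {
  int coeff;  // known stride, >= 1 (1 means nothing is known)
  int base;   // known offset, 0 <= base < coeff
};

// Adding a constant shifts the offset but preserves the stride:
// if x = 8*k (coeff 8, base 0), then y = x + 4 gives coeff 8, base 4,
// so a buffer index y is still provably 4-byte aligned.
ModularEntry AddConst(ModularEntry e, int c) {
  e.base = ((e.base + c) % e.coeff + e.coeff) % e.coeff;
  return e;
}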
void CodeGenLLVM::VisitStmt_(const Block* op) {
  this->VisitStmt(op->first);
  if (op->rest.defined()) {
    this->VisitStmt(op->rest);
  }
}

void CodeGenLLVM::VisitStmt_(const Evaluate* op) {
  MakeValue(op->value);
}

void CodeGenLLVM::VisitStmt_(const ProducerConsumer* op) {
  this->VisitStmt(op->body);
}
}  // namespace codegen
}  // namespace tvm
#endif  // TVM_LLVM_VERSION
@@ -242,6 +242,8 @@ class CodeGenLLVM :
  std::unordered_map<const Variable*, arith::ModularEntry> align_map_;
  // set of vars that are not restricted (can alias)
  std::unordered_set<const Variable*> alias_var_set_;
  // set of volatile buffers.
  std::unordered_set<const Variable*> volatile_buf_;
};
}  // namespace codegen
}  // namespace tvm
@@ -28,7 +28,13 @@ inline Expr BroadcastTo(Expr e, int lanes) {
}

// Rewrite vectorized allocation access.
// This is necessary to make each vector component contain its own workspace.
// Originates from Halide's loop vectorizer.
//
//   s[i] = s[i * lanes + var]
//
// The same principle applies when using one thread to simulate multiple contexts.
class VecAllocAccess : public IRMutator {
 public:
  VecAllocAccess(const Variable* buf, Var var, int var_lanes)
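The rewrite is easiest to see with numbers. Vectorizing over a lane variable v in [0, 4) widens the allocation by 4x and redirects every access, so each lane owns a disjoint slot instead of racing on the same element; a hedged before/after in TVM-like pseudo-IR:

// before vectorization: all four lanes would share one scratch element
//   allocate s[float32 * n];  s[i] = ...
// after VecAllocAccess with var = v, var_lanes = 4:
//   allocate s[float32 * (n * 4)];  s[i * 4 + v] = ...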