Commit 396c161f by Tianqi Chen, committed by GitHub

Acknowledge related projects (#465)

* [CODEGEN] Redo CodegenLLVM.

* Add remarks about origin of the pass

Properly acknowledge related projects

* Fix `&&` expression
parent 2607a836
@@ -28,3 +28,12 @@ TVM adopts apache committer model, we aim to create an open source project that
- [Contributor Guide](docs/how_to/contribute.md)
- Please add your name to [CONTRIBUTORS.md](CONTRIBUTORS.md)
- Please also update [NEWS.md](NEWS.md) on changes and improvements in API and codes.
Acknowledgement
---------------
We learned a lot from the following projects when building TVM.
- [Halide](https://github.com/halide/Halide): TVM uses [HalideIR](https://github.com/dmlc/HalideIR) as the data structure for
  arithmetic simplification and low-level lowering. HalideIR is derived from Halide.
  We also learned from Halide when implementing the lowering pipeline in TVM.
- [Loopy](https://github.com/inducer/loopy): the use of integer set analysis and its loop transformation primitives.
- [Theano](https://github.com/Theano/Theano): the design inspiration of symbolic scan operator for recurrence.
@@ -32,29 +32,24 @@ void CodeGenLLVM::Init(const std::string& module_name,
bool system_lib,
bool dynamic_lookup) {
InitializeLLVM();
// clear maps
var_map_.clear();
str_map_.clear();
ctx_ = ctx;
t_void_ = llvm::Type::getVoidTy(*ctx);
t_void_p_ = llvm::Type::getInt8Ty(*ctx)->getPointerTo();
t_int_ = llvm::Type::getIntNTy(*ctx, sizeof(int) * 8);
t_char_ = llvm::Type::getInt8Ty(*ctx);
t_int8_ = llvm::Type::getInt8Ty(*ctx);
t_int16_ = llvm::Type::getInt16Ty(*ctx);
t_int32_ = llvm::Type::getInt32Ty(*ctx);
t_int64_ = llvm::Type::getInt64Ty(*ctx);
t_float64_ = llvm::Type::getDoubleTy(*ctx);
md_builder_.reset(new llvm::MDBuilder(*ctx));
md_very_likely_branch_ =
md_builder_->createBranchWeights(1 << 30, 0);
md_tbaa_root_ = md_builder_->createTBAARoot("tvmtbaa");
md_tbaa_alias_set_ = md_builder_->createTBAAScalarTypeNode(
"alias_set", md_tbaa_root_);
// initialize Modules and function type
module_.reset(new llvm::Module(module_name, *ctx));
// initialize builder
builder_.reset(new IRBuilder(*ctx));
builder_.reset(new IRBuilder(*ctx_));
module_.reset(new llvm::Module(module_name, *ctx_));
md_builder_.reset(new llvm::MDBuilder(*ctx_));
// types
t_void_ = llvm::Type::getVoidTy(*ctx_);
t_void_p_ = llvm::Type::getInt8Ty(*ctx_)->getPointerTo();
t_int_ = llvm::Type::getInt32Ty(*ctx_);
t_char_ = llvm::Type::getInt8Ty(*ctx_);
t_int8_ = llvm::Type::getInt8Ty(*ctx_);
t_int16_ = llvm::Type::getInt16Ty(*ctx_);
t_int32_ = llvm::Type::getInt32Ty(*ctx_);
t_int64_ = llvm::Type::getInt64Ty(*ctx_);
t_float64_ = llvm::Type::getDoubleTy(*ctx_);
// meta data
md_very_likely_branch_ = md_builder_->createBranchWeights(1 << 20, 1);
md_tbaa_root_ = md_builder_->createTBAARoot("tvm-tbaa");
md_tbaa_alias_set_ = md_builder_->createTBAANode("tvm-alias", md_tbaa_root_);
this->InitTarget(tm);
}
@@ -63,61 +58,73 @@ void CodeGenLLVM::InitTarget(llvm::TargetMachine* tm) {
module_->setDataLayout(tm->createDataLayout());
data_layout_.reset(new llvm::DataLayout(module_.get()));
target_machine_ = tm;
// initialize native vector bits
std::string target = tm->getTarget().getName();
if (target == "x86-64") {
// for avx512
native_vector_bits_ = 64 * 8;
} else if (target == "x86") {
native_vector_bits_ = 32 * 8;
} else {
if (native_vector_bits_ == 0) {
native_vector_bits_ = 32 * 8;
LOG(WARNING) << "set native vector to be " << native_vector_bits_ / 8
<< " for target " << target;
if (native_vector_bits_ == 0) {
const auto& arch = tm->getTargetTriple().getArch();
if (arch == llvm::Triple::x86_64) {
// for avx512
native_vector_bits_ = 512;
} else if (arch == llvm::Triple::x86) {
native_vector_bits_ = 256;
} else if (arch == llvm::Triple::arm || arch == llvm::Triple::aarch64) {
native_vector_bits_ = 128;
} else {
native_vector_bits_ = 128;
std::string arch_name = tm->getTargetTriple().getArchName();
LOG(WARNING) << "Set native vector bits to be 128 for " << arch_name;
}
}
}
void CodeGenLLVM::AddFunction(const LoweredFunc& f) {
this->AddFunctionInternal(f, false);
}
void CodeGenLLVM::InitFuncState() {
var_map_.clear();
alias_var_set_.clear();
align_map_.clear();
alloc_storage_info_.clear();
alias_var_set_.clear();
}
void CodeGenLLVM::AddFunction(const LoweredFunc& f) {
AddFunctionInternal(f, false);
volatile_buf_.clear();
}
void CodeGenLLVM::AddFunctionInternal(const LoweredFunc& f, bool ret_void) {
this->InitFuncState();
std::vector<llvm::Type*> arg_types;
is_restricted_ = f->is_restricted;
CHECK(!module_->getFunction(f->name))
<< "Function " << f->name << "already exists in module";
std::vector<llvm::Type*> arg_type;
for (Var arg : f->args) {
Type t = arg.type();
if (t.is_handle() && f->handle_data_type.count(arg)) {
arg_type.push_back(
LLVMType(f->handle_data_type[arg].type())->getPointerTo(GetGlobalAddressSpace()));
if (t.is_handle()) {
auto it = f->handle_data_type.find(arg);
if (it != f->handle_data_type.end()) {
arg_types.push_back(LLVMType((*it).second.type())
->getPointerTo(GetGlobalAddressSpace()));
} else {
arg_types.push_back(t_int8_->getPointerTo(GetGlobalAddressSpace()));
}
if (!is_restricted_) {
alias_var_set_.insert(arg.get());
}
} else {
arg_type.push_back(LLVMType(t));
arg_types.push_back(LLVMType(arg.type()));
}
}
llvm::FunctionType* ftype = llvm::FunctionType::get(
ret_void ? t_void_ : t_int_, arg_type, false);
// setup the function.
function_ = llvm::cast<llvm::Function>(module_->getOrInsertFunction(f->name, ftype));
ret_void ? t_void_ : t_int_, arg_types, false);
CHECK(module_->getFunction(f->name) == nullptr)
<< "Function " << f->name << " already exist in module";
function_ = llvm::Function::Create(
ftype, llvm::Function::ExternalLinkage,
f->name, module_.get());
function_->setCallingConv(llvm::CallingConv::C);
// set handle argument to be non alias.
if (is_restricted_) {
for (size_t i = 0; i < f->args.size(); ++i) {
if (f->args[i].type().is_handle()) {
// set var map and align information
auto arg_it = function_->arg_begin();
for (size_t i = 0; i < f->args.size(); ++i, ++arg_it) {
llvm::Argument* v = &(*arg_it);
const Var& var = f->args[i];
var_map_[var.get()] = v;
if (is_restricted_) {
if (var.type().is_handle() && !alias_var_set_.count(var.get())) {
// set non alias.
#if TVM_LLVM_VERSION >= 50
function_->addParamAttr(i, llvm::Attribute::NoAlias);
#else
@@ -126,17 +133,8 @@ void CodeGenLLVM::AddFunctionInternal(const LoweredFunc& f, bool ret_void) {
}
}
}
size_t idx = 0;
for (auto it = function_->arg_begin();
it != function_->arg_end(); ++it, ++idx) {
llvm::Argument* v = &(*it);
var_map_[f->args[idx].get()] = v;
}
llvm::BasicBlock* block = llvm::BasicBlock::Create(*ctx_, "entry", function_);
builder_->SetInsertPoint(block);
llvm::BasicBlock* entry = llvm::BasicBlock::Create(*ctx_, "entry", function_);
builder_->SetInsertPoint(entry);
this->VisitStmt(f->body);
if (ret_void) {
builder_->CreateRetVoid();
@@ -145,8 +143,24 @@ void CodeGenLLVM::AddFunctionInternal(const LoweredFunc& f, bool ret_void) {
}
}
std::unique_ptr<llvm::Module> CodeGenLLVM::Finish() {
this->AddStartupFunction();
this->Optimize();
return std::move(module_);
}
void CodeGenLLVM::AddMainFunction(const std::string& entry_func_name) {
LOG(FATAL) << "Donot support add main function";
LOG(FATAL) << "not implemented";
}
llvm::Value* CodeGenLLVM::GetThreadIndex(const IterVar& iv) {
LOG(FATAL) << "not implemented";
return nullptr;
}
llvm::Value* CodeGenLLVM::CreateStorageSync(const Call* op) {
LOG(FATAL) << "not implemented";
return nullptr;
}
class FPassManager : public llvm::legacy::FunctionPassManager {
@@ -202,36 +216,48 @@ void CodeGenLLVM::Optimize() {
mpass.run(*module_);
}
std::unique_ptr<llvm::Module> CodeGenLLVM::Finish() {
this->AddStartupFunction();
this->Optimize();
return std::move(module_);
int CodeGenLLVM::NativeVectorBits(const runtime::StorageScope& storage_scope) const {
return native_vector_bits_;
}
unsigned CodeGenLLVM::GetGlobalAddressSpace() {
return 0;
}
llvm::Type* CodeGenLLVM::LLVMType(const Type& t) const {
llvm::Type* ret = nullptr;
if (t.is_uint() || t.is_int()) {
ret = llvm::Type::getIntNTy(*ctx_, t.bits());
if (t.is_handle()) {
CHECK_EQ(t.lanes(), 1);
return t_void_p_;
}
llvm::Type* etype;
if (t.is_int() || t.is_uint()) {
etype = llvm::Type::getIntNTy(*ctx_, t.bits());
} else if (t.is_float()) {
switch (t.bits()) {
case 16: ret = llvm::Type::getHalfTy(*ctx_); break;
case 32: ret = llvm::Type::getFloatTy(*ctx_); break;
case 64: ret = llvm::Type::getDoubleTy(*ctx_); break;
default: LOG(FATAL) << "cannot handle " << t;
case 16: etype = llvm::Type::getHalfTy(*ctx_); break;
case 32: etype = llvm::Type::getFloatTy(*ctx_); break;
case 64: etype = llvm::Type::getDoubleTy(*ctx_); break;
default: LOG(FATAL) << "do not support " << t;
}
} else {
CHECK(t.is_handle());
ret = t_void_p_;
}
if (t.lanes() != 1) {
ret = llvm::VectorType::get(ret, t.lanes());
return llvm::VectorType::get(etype, t.lanes());
} else {
return etype;
}
return ret;
}
void CodeGenLLVM::AddAliasInfo(
llvm::Instruction* inst, const Variable* buffer, Expr index, Type t) {
// Add tbaa alias information for load
//
// use a binary tree typed system to declare information
// and allow alias to be distinguished across nodes.
//
// This trick comes from Halide's CodeGen_LLVM
//
void CodeGenLLVM::AddAliasInfo(llvm::Instruction* inst,
const Variable* buffer,
Expr index,
Type type) {
if (alias_var_set_.count(buffer) != 0) {
// Mark all possibly aliased pointer as same type.
llvm::MDNode* meta = md_tbaa_alias_set_;
@@ -242,7 +268,7 @@ void CodeGenLLVM::AddAliasInfo(
}
int base = 0, width = 0;
// create meta-data for alias analysis
// Use a group of binary tree ranges.
// Use a group of binary tree ranges of memory banks.
if (index.defined()) {
const Ramp* ramp = index.as<Ramp>();
if (ramp) {
@@ -267,7 +293,7 @@ void CodeGenLLVM::AddAliasInfo(
std::ostringstream buffer_addr, buffer_type;
buffer_addr << buffer;
meta = md_builder_->createTBAAScalarTypeNode(buffer_addr.str(), meta);
buffer_type << t.element_of();
buffer_type << type.element_of();
meta = md_builder_->createTBAAScalarTypeNode(buffer_type.str(), meta);
// create a tree-shape access structure.
if (width != 0) {
@@ -283,6 +309,36 @@ void CodeGenLLVM::AddAliasInfo(
md_builder_->createTBAAStructTagNode(meta, meta, 0));
}
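// A minimal standalone sketch of the tree-shaped TBAA construction used
// above, assuming only LLVM's MDBuilder API; BuildExampleTBAATag and the
// node names ("example-tbaa", "buffer:A") are hypothetical, for illustration.
#include <llvm/IR/LLVMContext.h>
#include <llvm/IR/MDBuilder.h>
#include <llvm/IR/Metadata.h>

llvm::MDNode* BuildExampleTBAATag(llvm::LLVMContext& ctx) {
  llvm::MDBuilder md(ctx);
  // One shared root; nodes in disjoint subtrees are proven not to alias.
  llvm::MDNode* root = md.createTBAARoot("example-tbaa");
  // First level distinguishes buffers, the next the element type within one.
  llvm::MDNode* buf = md.createTBAAScalarTypeNode("buffer:A", root);
  llvm::MDNode* elem = md.createTBAAScalarTypeNode("float32", buf);
  // The access tag (base type, access type, offset) is what gets attached to
  // a load/store, as with createTBAAStructTagNode(meta, meta, 0) above.
  return md.createTBAAStructTagNode(elem, elem, 0);
}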
void CodeGenLLVM::GetAlignment(Type t,
const Variable* buf_var,
const Expr& index,
int* p_alignment,
int* p_native_bits) {
int max_align_bits = t.bits();
auto it = alloc_storage_info_.find(buf_var);
if (it != alloc_storage_info_.end()) {
const StorageInfo& info = it->second;
*p_native_bits = NativeVectorBits(info.scope);
max_align_bits = info.alignment * 8;
} else {
*p_native_bits = native_vector_bits_;
}
arith::ModularEntry me = arith::EvalModular(index, align_map_);
int align_bits = t.bits();
while (align_bits < max_align_bits &&
me.base % 2 == 0 &&
me.coeff % 2 == 0) {
me.base = me.base / 2;
me.coeff = me.coeff / 2;
align_bits *= 2;
}
if (align_bits < 8) {
align_bits = 8;
}
*p_alignment = align_bits / 8;
}
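// A self-contained sketch of the inference above (assumed semantics: the
// modular analysis has proven index == coeff * n + base, in element units,
// for some n); InferAlignBytes is a hypothetical name, for illustration.
#include <cstdio>

int InferAlignBytes(int elem_bits, int max_align_bits, int coeff, int base) {
  int align_bits = elem_bits;
  while (align_bits < max_align_bits && base % 2 == 0 && coeff % 2 == 0) {
    base /= 2;
    coeff /= 2;
    align_bits *= 2;  // every access still starts on this bit boundary
  }
  if (align_bits < 8) align_bits = 8;
  return align_bits / 8;
}

int main() {
  // float32 buffer with a 16-byte-aligned allocation and index = 8 * n + 4:
  // each access lands on a 16-byte boundary, so 128-bit vector loads are safe.
  std::printf("%d\n", InferAlignBytes(32, 128, 8, 4));  // prints 16
  return 0;
}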
llvm::Value* CodeGenLLVM::CreateBroadcast(llvm::Value* value, int lanes) {
llvm::Constant* undef = llvm::UndefValue::get(
llvm::VectorType::get(value->getType(), lanes));
@@ -292,26 +348,103 @@ llvm::Value* CodeGenLLVM::CreateBroadcast(llvm::Value* value, int lanes) {
return builder_->CreateShuffleVector(value, undef, mask);
}
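// The broadcast above follows the standard LLVM idiom: insert the scalar
// into lane 0 of an undef vector, then shuffle it against undef with an
// all-zero mask. Roughly, for a 4-lane float broadcast (illustrative IR):
//
//   %t = insertelement <4 x float> undef, float %x, i32 0
//   %v = shufflevector <4 x float> %t, <4 x float> undef,
//                      <4 x i32> zeroinitializer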
llvm::Value* CodeGenLLVM::CreateBufferPtr(
Type t, llvm::Value* buffer, llvm::Value* index) {
llvm::Type* elem_type = buffer->getType();
unsigned address_space = elem_type->getPointerAddressSpace();
llvm::Type* load_type = LLVMType(t)->getPointerTo(address_space);
llvm::Value* CodeGenLLVM::CreateVecSlice(llvm::Value* vec, int begin, int extent) {
int num_elems = static_cast<int>(vec->getType()->getVectorNumElements());
if (extent == num_elems && begin == 0) return vec;
CHECK_LE(begin + extent, num_elems);  // a slice may end exactly at the last lane
std::vector<unsigned> indices;
for (int i = 0; i < extent; ++i) {
indices.push_back(begin + i);
}
return builder_->CreateShuffleVector(vec, vec, indices);
}
if (load_type != elem_type) {
buffer = builder_->CreatePointerCast(buffer, load_type);
llvm::Value* CodeGenLLVM::CreateVecFlip(llvm::Value* vec) {
int num_elems = static_cast<int>(vec->getType()->getVectorNumElements());
std::vector<unsigned> indices;
for (int i = 0; i < num_elems; ++i) {
indices.push_back(num_elems - i - 1);
}
llvm::Constant* cindex = llvm::dyn_cast<llvm::Constant>(index);
if (cindex && cindex->isZeroValue()) {
return buffer;
return builder_->CreateShuffleVector(vec, vec, indices);
}
llvm::Value* CodeGenLLVM::CreateVecPad(llvm::Value* vec, int target_lanes) {
llvm::Value* mask = llvm::UndefValue::get(LLVMType(Int(32, target_lanes)));
int num_elems = static_cast<int>(vec->getType()->getVectorNumElements());
if (num_elems == target_lanes) return vec;
CHECK_LT(num_elems, target_lanes);
for (int i = 0; i < num_elems; ++i) {
mask = builder_->CreateInsertElement(mask, ConstInt32(i), ConstInt32(i));
}
return builder_->CreateInBoundsGEP(buffer, index);
return builder_->CreateShuffleVector(vec, vec, mask);
}
llvm::Value* CodeGenLLVM::CreateVecConcat(std::vector<llvm::Value*> vecs) {
// concat vector, tree shape reduction
int total_lanes = 0;
for (llvm::Value* v : vecs) {
total_lanes += static_cast<int>(
v->getType()->getVectorNumElements());
}
while (vecs.size() > 1) {
for (size_t i = 0; i < vecs.size(); i+=2) {
if (i + 1 >= vecs.size()) {
vecs[i / 2] = vecs[i]; continue;
}
llvm::Value* lhs = vecs[i];
llvm::Value* rhs = vecs[i + 1];
int lanes = static_cast<int>(std::max(
lhs->getType()->getVectorNumElements(),
rhs->getType()->getVectorNumElements()));
lhs = CreateVecPad(lhs, lanes);
rhs = CreateVecPad(rhs, lanes);
std::vector<unsigned> mask;
for (int i = 0; i < lanes * 2; ++i) {
mask.push_back(i);
}
vecs[i / 2] = builder_->CreateShuffleVector(lhs, rhs, mask);
}
vecs.resize((vecs.size() + 1) / 2);
}
return CreateVecSlice(vecs[0], 0, total_lanes);
}
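// A sketch of the same tree-shape reduction on plain std::vector, standing
// in for the shuffle-based merge above (ConcatTree is a hypothetical name):
// pairs merge level by level, so n inputs take O(log n) rounds of shuffles
// rather than n - 1 chained ones, keeping the shuffle dependency chain short.
#include <vector>

std::vector<int> ConcatTree(std::vector<std::vector<int>> vecs) {
  while (vecs.size() > 1) {
    for (size_t i = 0; i < vecs.size(); i += 2) {
      if (i + 1 >= vecs.size()) { vecs[i / 2] = vecs[i]; continue; }
      std::vector<int> merged = vecs[i];  // one CreateShuffleVector above
      merged.insert(merged.end(), vecs[i + 1].begin(), vecs[i + 1].end());
      vecs[i / 2] = merged;
    }
    vecs.resize((vecs.size() + 1) / 2);
  }
  return vecs[0];
}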
void CodeGenLLVM::CreateSerialFor(llvm::Value* begin,
llvm::Value* end,
llvm::Value* stride,
const VarExpr& loop_var,
const Stmt& body) {
using llvm::BasicBlock;
BasicBlock* pre_block = builder_->GetInsertBlock();
BasicBlock* for_begin = BasicBlock::Create(
*ctx_, "for_begin", function_);
BasicBlock* for_body = BasicBlock::Create(
*ctx_, "for_body", function_);
BasicBlock* for_end = BasicBlock::Create(
*ctx_, "for_end", function_);
builder_->CreateBr(for_begin);
builder_->SetInsertPoint(for_begin);
llvm::PHINode* loop_value = builder_->CreatePHI(begin->getType(), 2);
loop_value->addIncoming(begin, pre_block);
CHECK(!var_map_.count(loop_var.get()));
var_map_[loop_var.get()] = loop_value;
builder_->CreateCondBr(CreateLT(loop_var.type(), loop_value, end),
for_body, for_end, md_very_likely_branch_);
builder_->SetInsertPoint(for_body);
this->VisitStmt(body);
var_map_.erase(loop_var.get());
llvm::Value* loop_next = CreateAdd(loop_var.type(), loop_value, stride);
loop_value->addIncoming(loop_next, builder_->GetInsertBlock());
builder_->CreateBr(for_begin);
builder_->SetInsertPoint(for_end);
}
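// The blocks above form the canonical rotated-loop CFG. For
// `for (i = begin; i < end; i += stride) body`, the emitted IR is roughly
// (illustrative, with types and block names simplified):
//
//   pre:
//     br label %for_begin
//   for_begin:
//     %i = phi i32 [ %begin, %pre ], [ %i.next, %body.end ]
//     %cmp = icmp slt i32 %i, %end
//     br i1 %cmp, label %for_body, label %for_end, !prof !very_likely
//   for_body:
//     ...               ; body, with loop_var bound to %i
//     %i.next = add nsw i32 %i, %stride
//     br label %for_begin
//   for_end:
//
// The second PHI incoming edge is added only after the body is visited,
// since visiting the body may itself create new basic blocks.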
// cast operator
llvm::Value* CodeGenLLVM::CreateCast(Type from, Type to, llvm::Value* value) {
  llvm::Type* target = LLVMType(to);
if (value->getType() == target) return value;
if (from.is_handle() && from.is_handle()) {
if (to.is_handle()) {
return builder_->CreateBitCast(value, target);
} else if (!from.is_float() && !to.is_float()) {
return builder_->CreateIntCast(value, target, from.is_int());
@@ -334,262 +467,159 @@ llvm::Value* CodeGenLLVM::CreateCast(Type from, Type to, llvm::Value* value) {
}
}
llvm::CallInst* CodeGenLLVM::CreateCallExtern(
llvm::Type* ret,
const std::string& name,
const std::vector<llvm::Value*>& arg_values) {
std::vector<llvm::Type*> arg_types;
for (llvm::Value* v : arg_values) {
arg_types.push_back(v->getType());
}
llvm::FunctionType* ftype = llvm::FunctionType::get(ret, arg_types, false);
llvm::Function* f = module_->getFunction(name);
if (f == nullptr) {
f = llvm::Function::Create(
ftype, llvm::Function::ExternalLinkage, name, module_.get());
}
return builder_->CreateCall(f, arg_values);
}
llvm::Value* CodeGenLLVM::CreateCallExtern(const Call* op) {
std::vector<llvm::Value*> arg_values(op->args.size());
for (size_t i = 0; i < op->args.size(); ++i) {
arg_values[i] = MakeValue(op->args[i]);
}
return CreateCallExtern(LLVMType(op->type), op->name, arg_values);
llvm::Value* CodeGenLLVM::GetConstString(const std::string& str) {
auto it = str_map_.find(str);
if (it != str_map_.end()) return it->second;
llvm::Type* type = llvm::ArrayType::get(t_char_, str.length() + 1);
llvm::GlobalVariable *global = new llvm::GlobalVariable(
*module_, type, true, llvm::GlobalValue::PrivateLinkage, 0, ".str");
global->setAlignment(1);
global->setInitializer(llvm::ConstantDataArray::getString(*ctx_, str));
llvm::Constant* zero = ConstInt32(0);
llvm::Constant* indices[] = {zero, zero};
llvm::Constant* ptr = llvm::ConstantExpr::getGetElementPtr(
type, global, indices);
str_map_[str] = ptr;
return ptr;
}
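// For GetConstString("hi"), the module would contain roughly (illustrative):
//
//   @.str = private constant [3 x i8] c"hi\00", align 1
//
// and the cached value is the constant GEP to its first byte, so repeated
// uses of the same literal share one global.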
llvm::Value* CodeGenLLVM::CreateScalarizedCall(
const Call* op, llvm::Function* f, const std::vector<llvm::Value*>& args) {
llvm::Value* value = llvm::UndefValue::get(LLVMType(op->type));
for (int i = 0; i < op->type.lanes(); ++i) {
std::vector<llvm::Value*> sargs(args.size());
for (size_t j = 0; j < args.size(); ++j) {
if (args[j]->getType()->isVectorTy()) {
sargs[j] = builder_->CreateExtractElement(args[j], ConstInt32(i));
} else {
sargs[j] = args[j];
}
}
llvm::CallInst* call = builder_->CreateCall(f, sargs);
if (op->is_pure()) {
call->setDoesNotAccessMemory();
}
call->setDoesNotThrow();
if (!call->getType()->isVoidTy()) {
value = builder_->CreateInsertElement(value, call, ConstInt32(i));
}
llvm::Value* CodeGenLLVM::CreateBufferPtr(
Type t, llvm::Value* buffer, llvm::Value* index) {
CHECK_EQ(t.lanes(), 1);
llvm::PointerType* btype = llvm::dyn_cast<llvm::PointerType>(buffer->getType());
CHECK(btype != nullptr);
llvm::PointerType* ptype = LLVMType(t)->getPointerTo(btype->getAddressSpace());
if (btype != ptype) {
buffer = builder_->CreatePointerCast(buffer, ptype);
}
return value;
return builder_->CreateInBoundsGEP(buffer, index);
}
llvm::Value* CodeGenLLVM::GetVarValue(const Variable* v) const {
auto it = var_map_.find(v);
CHECK(it != var_map_.end())
<< "Cannot find " << v->name_hint << " in the var map";
CHECK(it != var_map_.end()) << "cannot find variable " << v->name_hint;
return it->second;
}
llvm::Value* CodeGenLLVM::GetConstString(const std::string& str) {
auto it = str_map_.find(str);
if (it == str_map_.end()) {
llvm::Type* type = llvm::ArrayType::get(t_char_, str.length() + 1);
llvm::GlobalVariable *global = new llvm::GlobalVariable(
*module_, type, true, llvm::GlobalValue::PrivateLinkage, 0, ".str");
global->setAlignment(1);
global->setInitializer(llvm::ConstantDataArray::getString(*ctx_, str));
// useful constant value
llvm::Constant* zero = ConstInt32(0);
llvm::Constant* indices[] = {zero, zero};
llvm::Constant* sptr = llvm::ConstantExpr::getGetElementPtr(
type, global, indices);
str_map_[str] = sptr;
return sptr;
} else {
return it->second;
llvm::Value* CodeGenLLVM::CreateCallExtern(const Call* op) {
  std::vector<llvm::Value*> arg_value;
  std::vector<llvm::Type*> arg_type;
  // extern calls pass all of op->args through, unlike the intrinsic paths
  // below where args[0] carries the intrinsic id
  for (size_t i = 0; i < op->args.size(); ++i) {
    arg_value.push_back(MakeValue(op->args[i]));
arg_type.push_back(arg_value.back()->getType());
}
}
void CodeGenLLVM::CreateSerialFor(llvm::Value* begin,
llvm::Value* end,
llvm::Value* stride,
const VarExpr& loop_var, const Stmt& body) {
using llvm::BasicBlock;
Type t = loop_var.type();
BasicBlock* for_head = BasicBlock::Create(
*ctx_, "for_head", function_);
BasicBlock* for_body = BasicBlock::Create(
*ctx_, "for_body", function_);
BasicBlock* for_end = BasicBlock::Create(
*ctx_, "for_end", function_);
BasicBlock* pre_block = builder_->GetInsertBlock();
builder_->CreateBr(for_head);
builder_->SetInsertPoint(for_head);
llvm::PHINode* index = builder_->CreatePHI(begin->getType(), 2);
index->addIncoming(begin, pre_block);
llvm::Value* cond = CreateLT(t, index, end);
builder_->CreateCondBr(cond, for_body, for_end, md_very_likely_branch_);
// body of for
builder_->SetInsertPoint(for_body);
var_map_[loop_var.get()] = index;
this->VisitStmt(body);
llvm::Value* next_index = CreateAdd(t, index, stride);
index->addIncoming(next_index, builder_->GetInsertBlock());
builder_->CreateBr(for_head);
// end of for
builder_->SetInsertPoint(for_end);
llvm::FunctionType* ftype = llvm::FunctionType::get(
LLVMType(op->type), arg_type, false);
llvm::Function* f = module_->getFunction(op->name);
if (f == nullptr) {
f = llvm::Function::Create(
ftype, llvm::Function::ExternalLinkage,
op->name, module_.get());
}
llvm::CallInst* call = builder_->CreateCall(f, arg_value);
return call;
}
llvm::Value* CodeGenLLVM::CreateIntrinsic(const Call* op) {
if (op->is_intrinsic("llvm_intrin")) {
CHECK_GE(op->args.size(), 1U);
std::vector<llvm::Value*> arg_values;
std::vector<llvm::Type*> arg_types;
llvm::Intrinsic::ID id = static_cast<llvm::Intrinsic::ID>(
op->args[0].as<UIntImm>()->value);
std::vector<llvm::Value*> arg_value;
std::vector<llvm::Type*> arg_type;
for (size_t i = 1; i < op->args.size(); ++i) {
llvm::Value* v = MakeValue(op->args[i]);
arg_values.push_back(v);
arg_types.push_back(v->getType());
arg_value.push_back(MakeValue(op->args[i]));
arg_type.push_back(arg_value.back()->getType());
}
auto id = static_cast<llvm::Intrinsic::ID>(op->args[0].as<UIntImm>()->value);
llvm::Function* f = llvm::Intrinsic::getDeclaration(
module_.get(), id, arg_types);
return builder_->CreateCall(f, arg_values);
module_.get(), id, arg_type);
return builder_->CreateCall(f, arg_value);
} else if (op->is_intrinsic("llvm_builtin")) {
CHECK_GE(op->args.size(), 1U);
std::vector<llvm::Value*> arg_values;
llvm::Intrinsic::ID id = static_cast<llvm::Intrinsic::ID>(
op->args[0].as<UIntImm>()->value);
std::vector<llvm::Value*> arg_value;
for (size_t i = 1; i < op->args.size(); ++i) {
llvm::Value* v = MakeValue(op->args[i]);
arg_values.push_back(v);
arg_value.push_back(MakeValue(op->args[i]));
}
auto id = static_cast<llvm::Intrinsic::ID>(op->args[0].as<UIntImm>()->value);
llvm::Function* f = llvm::Intrinsic::getDeclaration(module_.get(), id);
return builder_->CreateCall(f, arg_values);
} else if (op->is_intrinsic(intrinsic::tvm_storage_sync)) {
return CreateStorageSync(op);
llvm::Function* f = llvm::Intrinsic::getDeclaration(module_.get(), id, {});
return builder_->CreateCall(f, arg_value);
} else if (op->is_intrinsic(Call::bitwise_and)) {
CHECK_EQ(op->args.size(), 2U);
return builder_->CreateAnd(
MakeValue(op->args[0]), MakeValue(op->args[1]));
} else if (op->is_intrinsic(Call::bitwise_xor)) {
CHECK_EQ(op->args.size(), 2U);
return builder_->CreateXor(
MakeValue(op->args[0]), MakeValue(op->args[1]));
return builder_->CreateAnd(MakeValue(op->args[0]), MakeValue(op->args[1]));
} else if (op->is_intrinsic(Call::bitwise_or)) {
CHECK_EQ(op->args.size(), 2U);
return builder_->CreateOr(
MakeValue(op->args[0]), MakeValue(op->args[1]));
return builder_->CreateOr(MakeValue(op->args[0]), MakeValue(op->args[1]));
} else if (op->is_intrinsic(Call::bitwise_not)) {
CHECK_EQ(op->args.size(), 1U);
return builder_->CreateNot(MakeValue(op->args[0]));
} else if (op->is_intrinsic(Call::bitwise_xor)) {
return builder_->CreateXor(MakeValue(op->args[0]), MakeValue(op->args[1]));
} else if (op->is_intrinsic(Call::shift_left)) {
CHECK_EQ(op->args.size(), 2U);
return builder_->CreateShl(
MakeValue(op->args[0]), MakeValue(op->args[1]));
return builder_->CreateShl(MakeValue(op->args[0]), MakeValue(op->args[1]));
} else if (op->is_intrinsic(Call::shift_right)) {
CHECK_EQ(op->args.size(), 2U);
if (op->type.is_int()) {
return builder_->CreateAShr(
MakeValue(op->args[0]), MakeValue(op->args[1]));
if (op->args[0].type().is_int()) {
return builder_->CreateAShr(MakeValue(op->args[0]), MakeValue(op->args[1]));
} else {
return builder_->CreateLShr(
MakeValue(op->args[0]), MakeValue(op->args[1]));
return builder_->CreateLShr(MakeValue(op->args[0]), MakeValue(op->args[1]));
}
} else if (op->is_intrinsic(intrinsic::tvm_storage_sync)) {
return CreateStorageSync(op);
} else if (op->is_intrinsic(intrinsic::tvm_address_of)) {
const Load *l = op->args[0].as<Load>();
CHECK(op->args.size() == 1 && l);
return builder_->CreatePointerCast(
CreateBufferPtr(
l->type, GetVarValue(l->buffer_var.get()), MakeValue(l->index)),
t_void_p_);
llvm::Value* ptr = CreateBufferPtr(
l->type, MakeValue(l->buffer_var), MakeValue(l->index));
unsigned addrspace = llvm::dyn_cast<llvm::PointerType>(
ptr->getType())->getAddressSpace();
return builder_->CreatePointerCast(ptr, t_void_->getPointerTo(addrspace));
} else if (op->is_intrinsic(Call::reinterpret) && is_zero(op->args[0])) {
return llvm::Constant::getNullValue(t_void_p_);
} else if (op->is_intrinsic(intrinsic::tvm_handle_is_null)) {
CHECK_EQ(op->args.size(), 1U);
llvm::Value* ptr = MakeValue(op->args[0]);
return builder_->CreateICmpEQ(
ptr, llvm::Constant::getNullValue(ptr->getType()));
return builder_->CreateIsNull(MakeValue(op->args[0]));
} else if (op->is_intrinsic(intrinsic::tvm_if_then_else)) {
using llvm::BasicBlock;
CHECK_EQ(op->args.size(), 3U);
llvm::Value* cond = MakeValue(op->args[0]);
BasicBlock* then_block = BasicBlock::Create(
*ctx_, "if_then", function_);
BasicBlock* else_block = BasicBlock::Create(
*ctx_, "if_else", function_);
BasicBlock* end_block = BasicBlock::Create(
*ctx_, "if_end", function_);
builder_->CreateCondBr(cond, then_block, else_block);
// Then
builder_->CreateCondBr(MakeValue(op->args[0]), then_block, else_block);
builder_->SetInsertPoint(then_block);
llvm::Value* then_value = MakeValue(op->args[1]);
builder_->CreateBr(end_block);
builder_->SetInsertPoint(else_block);
// else
llvm::Value* else_value = MakeValue(op->args[2]);
builder_->CreateBr(end_block);
builder_->SetInsertPoint(end_block);
// phi
llvm::PHINode* phi = builder_->CreatePHI(then_value->getType(), 2);
phi->addIncoming(then_value, then_block);
phi->addIncoming(else_value, else_block);
return phi;
} else if (op->is_intrinsic(Call::reinterpret) && is_zero(op->args[0])) {
return llvm::Constant::getNullValue(t_void_p_);
llvm::PHINode* value = builder_->CreatePHI(then_value->getType(), 2);
value->addIncoming(then_value, then_block);
value->addIncoming(else_value, else_block);
return value;
} else {
LOG(FATAL) << "Unknown intrinstic " << op->name;
LOG(FATAL) << "unknown intrinsic " << op->name;
return nullptr;
}
return nullptr;
}
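// A sketch of how a frontend can target the "llvm_intrin" path above,
// assuming this era's HalideIR-based ir::Call / ir::UIntImm constructors;
// MakeLLVMSqrt is a hypothetical helper, for illustration only.
#include <tvm/ir.h>
#include <llvm/IR/Intrinsics.h>

tvm::Expr MakeLLVMSqrt(tvm::Expr x) {
  // args[0] carries the llvm::Intrinsic::ID as a UIntImm; the remaining args
  // become the operands passed to llvm::Intrinsic::getDeclaration's result.
  return tvm::ir::Call::make(
      tvm::Float(32), "llvm_intrin",
      {tvm::ir::UIntImm::make(tvm::UInt(32), llvm::Intrinsic::sqrt), x},
      tvm::ir::Call::PureIntrinsic);
}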
// Get the corresponding thread index
llvm::Value* CodeGenLLVM::GetThreadIndex(const IterVar& iv) {
LOG(FATAL) << "Donot support threading " << iv;
return nullptr;
}
llvm::Value* CodeGenLLVM::CreateStorageSync(const Call* op) {
LOG(FATAL) << "Donot support storage sync in CPU mode";
return nullptr;
}
int CodeGenLLVM::NativeVectorBits(const runtime::StorageScope& storage_scope) const {
// By default, we ask the buffer to be aligned to 64 bytes
return native_vector_bits_;
}
unsigned CodeGenLLVM::GetGlobalAddressSpace() {
return 0;
}
void CodeGenLLVM::GetAlignment(
Type t, const Variable* buf_var, const Expr& index,
int* p_alignment, int* p_native_bits) {
int& alignment = *p_alignment;
int& native_bits = *p_native_bits;
// The storage scope.
StorageInfo info;
auto it = alloc_storage_info_.find(buf_var);
if (it != alloc_storage_info_.end()) {
info = it->second;
}
arith::ModularEntry m = EvalModular(index, align_map_);
native_bits = NativeVectorBits(info.scope);
alignment = t.element_of().bits();
// find alignment, cannot exceed allocated alignment
int max_align_bits = std::min(
info.alignment * 8, alignment * t.lanes());
while ((m.coeff & 1) == 0 &&
(m.base & 1) == 0 &&
alignment < max_align_bits &&
alignment < native_bits) {
m.coeff /= 2;
m.base /= 2;
alignment *= 2;
void CodeGenLLVM::Scalarize(const Expr& e,
std::function<void(int i, llvm::Value* v)> f) {
if (const Ramp* ramp = e.as<Ramp>()) {
for (int i = 0; i < ramp->type.lanes(); ++i) {
Expr offset = arith::ComputeExpr<Add>(
ramp->base,
arith::ComputeExpr<Mul>(ramp->stride, i));
f(i, MakeValue(offset));
}
} else {
llvm::Value* value = MakeValue(e);
for (int i = 0; i < e.type().lanes(); ++i) {
f(i, builder_->CreateExtractElement(value, i));
}
}
CHECK_EQ(alignment % 8, 0)
<< "Load from memory that does not align to 8 bits";
alignment /= 8;
}
// visitor overrides
// Visitors
llvm::Value* CodeGenLLVM::VisitExpr_(const Variable* op) {
return GetVarValue(op);
}
@@ -597,7 +627,6 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const Variable* op) {
llvm::Value* CodeGenLLVM::VisitExpr_(const Cast* op) {
return CreateCast(op->value.type(), op->type, MakeValue(op->value));
}
llvm::Value* CodeGenLLVM::VisitExpr_(const IntImm* op) {
return llvm::ConstantInt::getSigned(LLVMType(op->type), op->value);
}
@@ -614,121 +643,114 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const StringImm* op) {
return GetConstString(op->value);
}
#define DEFINE_CODEGEN_BINARY_OP(OP) \
llvm::Value* CodeGenLLVM::Create ## OP( \
#define DEFINE_CODEGEN_BINARY_OP(Op) \
llvm::Value* CodeGenLLVM::Create ## Op( \
Type t, llvm::Value* a, llvm::Value *b) { \
if (t.is_float()) { \
return builder_->CreateF ## OP (a, b); \
} else if (t.is_int() && t.bits() >= 32) { \
return builder_->CreateNSW ## OP (a, b); \
if (t.is_int()) { \
if (t.bits() >= 32) { \
return builder_->CreateNSW ## Op (a, b); \
} else { \
return builder_->Create ## Op (a, b); \
} \
} else if (t.is_uint()) { \
if (t.bits() >= 32) { \
return builder_->CreateNUW ## Op (a, b); \
} else { \
return builder_->Create ## Op (a, b); \
} \
} else { \
return builder_->Create ## OP (a, b); \
CHECK(t.is_float()); \
return builder_->CreateF ## Op (a, b); \
} \
} \
llvm::Value* CodeGenLLVM::VisitExpr_(const Op* op) { \
return Create ## Op(op->type, MakeValue(op->a), MakeValue(op->b)); \
}
DEFINE_CODEGEN_BINARY_OP(Add);
DEFINE_CODEGEN_BINARY_OP(Sub);
DEFINE_CODEGEN_BINARY_OP(Mul);
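// Why the wrap flags matter (a note with an illustrative identity): nsw/nuw
// declare signed/unsigned overflow undefined, which licenses LLVM to fold
// comparisons and re-associate index arithmetic, e.g.
//
//   %y = add nsw i32 %x, 1
//   %c = icmp sgt i32 %y, %x   ; folds to true only because of nsw
//
// For types narrower than 32 bits the plain ops are used above, presumably
// because narrow integer arithmetic may be relied on to wrap.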
llvm::Value* CodeGenLLVM::VisitExpr_(const Add* op) {
return CreateAdd(op->type, MakeValue(op->a), MakeValue(op->b));
}
llvm::Value* CodeGenLLVM::VisitExpr_(const Sub* op) {
return CreateSub(op->type, MakeValue(op->a), MakeValue(op->b));
}
#define DEFINE_CODEGEN_CMP_OP(Op) \
llvm::Value* CodeGenLLVM::Create ## Op( \
Type t, llvm::Value* a, llvm::Value* b) { \
if (t.is_int()) { \
return builder_->CreateICmpS ## Op (a, b); \
} else if (t.is_uint()) { \
return builder_->CreateICmpU ## Op (a, b); \
} else { \
CHECK(t.is_float()); \
return builder_->CreateFCmpO ## Op (a, b); \
} \
} \
llvm::Value* CodeGenLLVM::VisitExpr_(const Op* op) { \
return Create ## Op(op->a.type(), MakeValue(op->a), MakeValue(op->b)); \
}
llvm::Value* CodeGenLLVM::VisitExpr_(const Mul* op) {
return CreateMul(op->type, MakeValue(op->a), MakeValue(op->b));
}
DEFINE_CODEGEN_CMP_OP(LT);
DEFINE_CODEGEN_CMP_OP(LE);
DEFINE_CODEGEN_CMP_OP(GT);
DEFINE_CODEGEN_CMP_OP(GE);
llvm::Value* CodeGenLLVM::VisitExpr_(const Div* op) {
llvm::Value* a = MakeValue(op->a);
llvm::Value* b = MakeValue(op->b);
int shift;
if (op->type.is_float()) {
return builder_->CreateFDiv(a, MakeValue(op->b));
} else if ((op->type.is_int() || op->type.is_uint()) &&
is_const_power_of_two_integer(op->b, &shift)) {
if ((op->type.is_int() || op->type.is_uint()) &&
is_const_power_of_two_integer(op->b, &shift)) {
return builder_->CreateAShr(a, shift);
} else if (op->type.is_int()) {
return builder_->CreateSDiv(a, b);
} else if (op->type.is_uint()) {
return builder_->CreateUDiv(a, b);
} else {
llvm::Value* b = MakeValue(op->b);
if (op->type.is_int()) {
return builder_->CreateSDiv(a, b);
} else {
CHECK(op->type.is_uint());
return builder_->CreateUDiv(a, b);
}
CHECK(op->type.is_float());
return builder_->CreateFDiv(a, b);
}
}
llvm::Value* CodeGenLLVM::VisitExpr_(const Mod* op) {
CHECK(!op->type.is_float())
<< "Cannot do mod for float";
llvm::Value* a = MakeValue(op->a);
llvm::Value* b = MakeValue(op->b);
if (op->type.is_int()) {
return builder_->CreateSRem(MakeValue(op->a), MakeValue(op->b));
return builder_->CreateSRem(a, b);
} else if (op->type.is_uint()) {
return builder_->CreateURem(a, b);
} else {
CHECK(op->type.is_uint());
return builder_->CreateURem(MakeValue(op->a), MakeValue(op->b));
CHECK(op->type.is_float());
return builder_->CreateFRem(a, b);
}
}
llvm::Value* CodeGenLLVM::VisitExpr_(const Min* op) {
llvm::Value* a = MakeValue(op->a);
llvm::Value* b = MakeValue(op->b);
llvm::Value* cond = CreateLT(op->a.type(), a, b);
return builder_->CreateSelect(cond, a, b);
return builder_->CreateSelect(CreateLT(op->a.type(), a, b), a, b);
}
llvm::Value* CodeGenLLVM::VisitExpr_(const Max* op) {
llvm::Value* a = MakeValue(op->a);
llvm::Value* b = MakeValue(op->b);
llvm::Value* cond = CreateGT(op->a.type(), a, b);
return builder_->CreateSelect(cond, a, b);
}
#define DEFINE_CODEGEN_CMP_OP(OP) \
llvm::Value* CodeGenLLVM::Create ## OP( \
Type t, llvm::Value* a, llvm::Value* b) { \
if (t.is_float()) { \
return builder_->CreateFCmpO ## OP (a, b); \
} else if (t.is_int()) { \
return builder_->CreateICmpS ## OP (a, b); \
} else { \
return builder_->CreateICmpU ## OP (a, b); \
} \
} \
DEFINE_CODEGEN_CMP_OP(LT);
DEFINE_CODEGEN_CMP_OP(LE);
DEFINE_CODEGEN_CMP_OP(GT);
DEFINE_CODEGEN_CMP_OP(GE);
llvm::Value* CodeGenLLVM::VisitExpr_(const LT* op) {
return CreateLT(op->a.type(), MakeValue(op->a), MakeValue(op->b));
}
llvm::Value* CodeGenLLVM::VisitExpr_(const LE* op) {
return CreateLE(op->a.type(), MakeValue(op->a), MakeValue(op->b));
}
llvm::Value* CodeGenLLVM::VisitExpr_(const GT* op) {
return CreateGT(op->a.type(), MakeValue(op->a), MakeValue(op->b));
}
llvm::Value* CodeGenLLVM::VisitExpr_(const GE* op) {
return CreateGE(op->a.type(), MakeValue(op->a), MakeValue(op->b));
return builder_->CreateSelect(CreateGT(op->a.type(), a, b), a, b);
}
llvm::Value* CodeGenLLVM::VisitExpr_(const EQ* op) {
if (op->a.type().is_float()) {
return builder_->CreateFCmpOEQ(MakeValue(op->a), MakeValue(op->b));
llvm::Value* a = MakeValue(op->a);
llvm::Value* b = MakeValue(op->b);
if (op->a.type().is_int() || op->a.type().is_uint()) {
return builder_->CreateICmpEQ(a, b);
} else {
return builder_->CreateICmpEQ(MakeValue(op->a), MakeValue(op->b));
return builder_->CreateFCmpOEQ(a, b);
}
}
llvm::Value* CodeGenLLVM::VisitExpr_(const NE* op) {
if (op->a.type().is_float()) {
return builder_->CreateFCmpONE(MakeValue(op->a), MakeValue(op->b));
llvm::Value* a = MakeValue(op->a);
llvm::Value* b = MakeValue(op->b);
if (op->a.type().is_int() || op->a.type().is_uint()) {
return builder_->CreateICmpNE(a, b);
} else {
return builder_->CreateICmpNE(MakeValue(op->a), MakeValue(op->b));
return builder_->CreateFCmpONE(a, b);
}
}
@@ -752,345 +774,161 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const Select* op) {
}
llvm::Value* CodeGenLLVM::VisitExpr_(const Let* op) {
llvm::Value* v = MakeValue(op->value);
CHECK(!var_map_.count(op->var.get()));
CHECK(!align_map_.count(op->var.get()));
var_map_[op->var.get()] = v;
align_map_[op->var.get()] = arith::EvalModular(op->value, align_map_);
var_map_[op->var.get()] = MakeValue(op->value);
align_map_[op->var.get()] = EvalModular(op->value, align_map_);
return MakeValue(op->body);
}
llvm::Value* CodeGenLLVM::VisitExpr_(const Broadcast* op) {
return CreateBroadcast(MakeValue(op->value), op->lanes);
}
llvm::Value* CodeGenLLVM::VisitExpr_(const Ramp* op) {
llvm::Value* CodeGenLLVM::VisitExpr_(const Load* op) {
Type t = op->type;
llvm::Value* base = MakeValue(op->base);
llvm::Value* stride = MakeValue(op->stride);
llvm::Value* value = llvm::UndefValue::get(LLVMType(t));
for (int i = 0; i < t.lanes(); ++i) {
if (i != 0) {
base = CreateAdd(t, base, stride);
}
value = builder_->CreateInsertElement(
value, base, llvm::ConstantInt::get(t_int32_, i));
}
return value;
}
void CodeGenLLVM::Scalarize(
const Expr& e,
std::function<void(int i, llvm::Value* v)> f) {
const Ramp* ramp = e.as<Ramp>();
Type t = e.type();
if (ramp) {
for (int i = 0; i < t.lanes(); ++i) {
Expr offset = arith::ComputeExpr<Add>(
ramp->base,
arith::ComputeExpr<Mul>(ramp->stride, i));
f(i, MakeValue(offset));
}
int alignment, native_bits;
bool is_volatile = volatile_buf_.count(op->buffer_var.get());
GetAlignment(t, op->buffer_var.get(), op->index, &alignment, &native_bits);
llvm::Value* buffer = MakeValue(op->buffer_var);
llvm::Value* index = MakeValue(op->index);
if (t.lanes() == 1) {
llvm::Value* ptr = CreateBufferPtr(t, buffer, index);
llvm::LoadInst* load = builder_->CreateAlignedLoad(ptr, alignment, is_volatile);
AddAliasInfo(load, op->buffer_var.get(), op->index, t);
return load;
} else {
llvm::Value* index = MakeValue(e);
for (int i = 0; i < t.lanes(); ++i) {
f(i, builder_->CreateExtractElement(index, ConstInt32(i)));
// vector load
unsigned addrspace = llvm::dyn_cast<llvm::PointerType>(
buffer->getType())->getAddressSpace();
if (const Ramp* ramp = op->index.as<Ramp>()) {
if (is_one(ramp->stride)) {
CHECK_EQ(ramp->lanes, t.lanes());
llvm::Value* ptr = CreateBufferPtr(
t.element_of(), buffer, MakeValue(ramp->base));
ptr = builder_->CreatePointerCast(ptr, LLVMType(t)->getPointerTo(addrspace));
llvm::LoadInst* load = builder_->CreateAlignedLoad(ptr, alignment, is_volatile);
AddAliasInfo(load, op->buffer_var.get(), op->index, t);
return load;
}
}
}
// scalarized load.
int basic_align = t.bits() / 8;
llvm::Value* ret = llvm::UndefValue::get(LLVMType(t));
auto f = [&](int i, llvm::Value* index) {
llvm::Value* ptr = CreateBufferPtr(t.element_of(), buffer, index);
llvm::LoadInst* load = builder_->CreateAlignedLoad(
ptr, basic_align, is_volatile);
ret = builder_->CreateInsertElement(ret, load, ConstInt32(i));
AddAliasInfo(load, op->buffer_var.get(), Expr(), t);
};
this->Scalarize(op->index, f);
return ret;
}
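// The unit-stride fast path above lowers Load(float32x4, buf, Ramp(base, 1, 4))
// to roughly (illustrative IR):
//
//   %p0 = getelementptr inbounds float, float* %buf, i32 %base
//   %p  = bitcast float* %p0 to <4 x float>*
//   %v  = load <4 x float>, <4 x float>* %p, align 16
//
// with the alignment computed by GetAlignment and TBAA tags from AddAliasInfo;
// any other index pattern falls back to the scalarized per-lane loads.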
llvm::Value* CodeGenLLVM::CreateVecFlip(llvm::Value* vec) {
int lanes = static_cast<int>(vec->getType()->getVectorNumElements());
std::vector<llvm::Constant*> indices;
for (int i = lanes; i != 0; --i) {
indices.push_back(ConstInt32(i - 1));
}
llvm::Constant* undef = llvm::UndefValue::get(vec->getType());
return builder_->CreateShuffleVector(
vec, undef, llvm::ConstantVector::get(indices));
}
llvm::Value* CodeGenLLVM::CreateVecSlice(
llvm::Value* vec, int begin, int lanes) {
int total_lanes = static_cast<int>(vec->getType()->getVectorNumElements());
CHECK_LE(begin + lanes, total_lanes);
if (lanes == total_lanes && begin == 0) return vec;
std::vector<llvm::Constant*> indices;
for (int i = 0; i < lanes; ++i) {
indices.push_back(ConstInt32(begin + i));
}
llvm::Constant* undef = llvm::UndefValue::get(vec->getType());
return builder_->CreateShuffleVector(
vec, undef, llvm::ConstantVector::get(indices));
}
llvm::Value* CodeGenLLVM::CreateVecPad(llvm::Value* vec, int target_lanes) {
int lanes = static_cast<int>(vec->getType()->getVectorNumElements());
if (target_lanes == lanes) return vec;
CHECK_GT(target_lanes, lanes);
int pad_lanes = target_lanes - lanes;
llvm::Constant* undef = llvm::UndefValue::get(
llvm::VectorType::get(vec->getType()->getVectorElementType(), pad_lanes));
std::vector<llvm::Constant*> indices;
for (int i = 0; i < target_lanes; ++i) {
indices.push_back(ConstInt32(i));
llvm::Value* CodeGenLLVM::VisitExpr_(const Call* op) {
if (op->call_type == Call::Intrinsic ||
op->call_type == Call::PureIntrinsic) {
return CreateIntrinsic(op);
} else if (op->call_type == Call::Extern ||
op->call_type == Call::PureExtern) {
return CreateCallExtern(op);
} else {
LOG(FATAL) << "Unknown call type ";
return nullptr;
}
return builder_->CreateShuffleVector(
vec, undef, llvm::ConstantVector::get(indices));
}
llvm::Value* CodeGenLLVM::CreateVecConcat(
std::vector<llvm::Value*> vec) {
CHECK_NE(vec.size(), 0U);
int target_lanes = 0;
for (llvm::Value* v : vec) {
target_lanes += static_cast<int>(v->getType()->getVectorNumElements());
}
// tree shape merging
while (vec.size() != 1) {
std::vector<llvm::Value*> merged;
for (size_t i = 0; i < vec.size() - 1; i += 2) {
llvm::Value* v1 = vec[i];
llvm::Value* v2 = vec[i + 1];
int w1 = static_cast<int>(v1->getType()->getVectorNumElements());
int w2 = static_cast<int>(v2->getType()->getVectorNumElements());
int w = std::max(w1, w2);
v1 = CreateVecPad(v1, w);
v2 = CreateVecPad(v2, w);
std::vector<llvm::Constant*> indices;
for (int i = 0; i < w * 2; ++i) {
indices.push_back(ConstInt32(i));
}
merged.push_back(
builder_->CreateShuffleVector(
v1, v2, llvm::ConstantVector::get(indices)));
}
if (vec.size() % 2 == 1) {
merged.push_back(vec.back());
}
vec = merged;
llvm::Value* CodeGenLLVM::VisitExpr_(const Ramp* op) {
llvm::Value* vec = llvm::UndefValue::get(LLVMType(op->type));
for (int i = 0; i < op->lanes; ++i) {
vec = builder_->CreateInsertElement(
vec, MakeValue(op->base + op->stride * make_const(op->stride.type(), i)),
ConstInt32(i));
}
return CreateVecSlice(vec[0], 0, target_lanes);
return vec;
}
llvm::Value* CodeGenLLVM::VisitExpr_(const Load* op) {
CHECK(is_one(op->predicate))
<< "Predicated Load is not supported";
Type t = op->type;
const Ramp* ramp = op->index.as<Ramp>();
llvm::Value* buf = GetVarValue(op->buffer_var.get());
if (t.is_scalar()) {
llvm::LoadInst* inst = builder_->CreateAlignedLoad(
CreateBufferPtr(t, buf, MakeValue(op->index)),
data_layout_->getTypeAllocSize(LLVMType(t)));
AddAliasInfo(inst, op->buffer_var.get(), op->index, op->type);
return inst;
} else if (ramp && is_one(ramp->stride)) {
int alignment, native_bits;
GetAlignment(t, op->buffer_var.get(), ramp->base,
&alignment, &native_bits);
int total_lanes = t.lanes();
int step = native_bits / t.bits();
std::vector<llvm::Value*> loads;
for (int offset = 0; offset < total_lanes; offset += step) {
int lanes = std::min(step, total_lanes - offset);
Expr base = arith::ComputeExpr<Add>(
ramp->base, make_const(ramp->base.type(), offset));
llvm::Value* ptr = CreateBufferPtr(t.element_of(), buf, MakeValue(base));
llvm::Type* vtype = llvm::VectorType::get(
LLVMType(t.element_of()), lanes)->getPointerTo(
ptr->getType()->getPointerAddressSpace());
llvm::LoadInst* inst = builder_->CreateAlignedLoad(
builder_->CreatePointerCast(ptr, vtype), alignment);
AddAliasInfo(inst, op->buffer_var.get(),
Ramp::make(base, make_const(base.type(), 1), lanes), op->type);
loads.push_back(inst);
}
return CreateVecConcat(loads);
} else if (ramp && is_const(ramp->stride, 2)) {
int alignment, native_bits;
GetAlignment(t, op->buffer_var.get(), ramp->base,
&alignment, &native_bits);
arith::ModularEntry e = arith::EvalModular(ramp->base, align_map_);
Type bt = ramp->base.type();
int first_shift, next_shift;
// If the base is even and the native alignment is at least twice
// the data type's width, both full-width loads are safe.
if (e.coeff % 2 == 0 &&
e.base % 2 == 0 &&
native_bits >= t.bits() * 2) {
first_shift = 0;
next_shift = 0;
} else if (e.coeff % 2 == 0 && e.base % 2 == 1) {
// odd base, shift both to left.
first_shift = -1;
next_shift = -1;
} else {
// otherwise take the safe option: keep the first load in place and shift the second left.
first_shift = 0;
next_shift = -1;
}
llvm::Value* first = MakeValue(Load::make(
t, op->buffer_var,
Ramp::make(arith::ComputeExpr<Add>(
ramp->base, make_const(bt, first_shift)),
make_const(bt, 1), ramp->lanes),
const_true(t.lanes())));
llvm::Value* next = MakeValue(Load::make(
t, op->buffer_var,
Ramp::make(arith::ComputeExpr<Add>(
ramp->base, make_const(bt, ramp->lanes + next_shift)),
make_const(bt, 1), ramp->lanes),
const_true(t.lanes())));
// shuffle
std::vector<llvm::Constant*> indices;
int target_index = 0;
for (int i = 0; i < ramp->lanes; ++i) {
int idx = first_shift + i;
if (idx == target_index) {
indices.push_back(ConstInt32(i));
target_index += 2;
}
}
for (int i = 0; i < ramp->lanes; ++i) {
int idx = ramp->lanes + next_shift + i;
if (idx == target_index) {
indices.push_back(ConstInt32(i + ramp->lanes));
target_index += 2;
}
}
CHECK_EQ(indices.size(), static_cast<size_t>(ramp->lanes));
return builder_->CreateShuffleVector(
first, next, llvm::ConstantVector::get(indices));
} else if (ramp && is_const(ramp->stride, -1)) {
int lanes = ramp->type.lanes();
Expr neg_ramp = Ramp::make(
arith::ComputeExpr<Sub>(
ramp->base,
make_const(ramp->base.type(), lanes - 1)),
make_const(ramp->base.type(), 1),
lanes);
// load value then flip
llvm::Value* v = MakeValue(
Load::make(t, op->buffer_var, neg_ramp, const_true(t.lanes())));
return CreateVecFlip(v);
} else {
llvm::Value* ret = llvm::UndefValue::get(LLVMType(t));
Scalarize(op->index, [&](int i, llvm::Value* offset) {
llvm::Value* ptr = CreateBufferPtr(t.element_of(), buf, offset);
llvm::LoadInst* inst = builder_->CreateAlignedLoad(
ptr, data_layout_->getTypeAllocSize(LLVMType(t)));
AddAliasInfo(inst, op->buffer_var.get(), Expr(), op->type);
ret = builder_->CreateInsertElement(ret, inst, ConstInt32(i));
});
return ret;
}
llvm::Value* CodeGenLLVM::VisitExpr_(const Broadcast* op) {
return CreateBroadcast(MakeValue(op->value), op->lanes);
}
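// A worked example of the stride-2 path above (illustrative): for
// Load(buf, Ramp(base, 2, 4)) with an even, well-aligned base, two contiguous
// loads cover lanes [0..3] and [4..7], and a single shuffle picks out the
// strided elements, avoiding a gather:
//
//   %a = load <4 x float>* %buf_base          ; elements 0..3
//   %b = load <4 x float>* %buf_base_plus_4   ; elements 4..7
//   %r = shufflevector <4 x float> %a, <4 x float> %b,
//                      <4 x i32> <i32 0, i32 2, i32 4, i32 6>
//
// The first_shift/next_shift cases adjust the two load windows when the base
// is odd or the native alignment cannot guarantee the wider second load.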
// stmts
void CodeGenLLVM::VisitStmt_(const Store* op) {
CHECK(is_one(op->predicate))
<< "Predicated Store is not supported";
llvm::Value* value = MakeValue(op->value);
CHECK(is_one(op->predicate));
Type t = op->value.type();
const Ramp* ramp = op->index.as<Ramp>();
llvm::Value* buf = GetVarValue(op->buffer_var.get());
if (t.is_scalar()) {
llvm::StoreInst* inst = builder_->CreateAlignedStore(
value,
CreateBufferPtr(t, buf, MakeValue(op->index)),
data_layout_->getTypeAllocSize(value->getType()));
AddAliasInfo(inst, op->buffer_var.get(), op->index, op->value.type());
} else if (ramp && is_one(ramp->stride)) {
int alignment, native_bits;
GetAlignment(t, op->buffer_var.get(), ramp->base,
&alignment, &native_bits);
int total_lanes = t.lanes();
int step = native_bits / t.bits();
// vector store.
for (int offset = 0; offset < total_lanes; offset += step) {
int lanes = std::min(step, total_lanes - offset);
Expr base = arith::ComputeExpr<Add>(
ramp->base, make_const(ramp->base.type(), offset));
llvm::Value* ptr = CreateBufferPtr(t.element_of(), buf, MakeValue(base));
llvm::Type* vtype = llvm::VectorType::get(
LLVMType(t.element_of()), lanes)->getPointerTo(
ptr->getType()->getPointerAddressSpace());
llvm::StoreInst* inst = builder_->CreateAlignedStore(
CreateVecSlice(value, offset, lanes),
builder_->CreatePointerCast(ptr, vtype), alignment);
AddAliasInfo(inst, op->buffer_var.get(),
Ramp::make(base, make_const(base.type(), 1), lanes), op->value.type());
}
} else {
Scalarize(op->index, [&](int i, llvm::Value* offset) {
llvm::Value* ptr = CreateBufferPtr(t.element_of(), buf, offset);
llvm::StoreInst* inst = builder_->CreateAlignedStore(
builder_->CreateExtractElement(value, ConstInt32(i)),
ptr, data_layout_->getTypeAllocSize(LLVMType(t)));
AddAliasInfo(inst, op->buffer_var.get(), Expr(), op->value.type());
});
}
}
int alignment, native_bits;
bool is_volatile = volatile_buf_.count(op->buffer_var.get());
GetAlignment(t, op->buffer_var.get(), op->index, &alignment, &native_bits);
llvm::Value* buffer = MakeValue(op->buffer_var);
llvm::Value* index = MakeValue(op->index);
llvm::Value* value = MakeValue(op->value);
llvm::Value* CodeGenLLVM::VisitExpr_(const Call* op) {
if (op->call_type == Call::Intrinsic ||
op->call_type == Call::PureIntrinsic) {
return CreateIntrinsic(op);
if (t.lanes() == 1) {
llvm::Value* ptr = CreateBufferPtr(t, buffer, index);
llvm::StoreInst* store = builder_->CreateAlignedStore(value, ptr, alignment, is_volatile);
AddAliasInfo(store, op->buffer_var.get(), op->index, op->value.type());
return;
} else {
CHECK(op->call_type == Call::Extern ||
op->call_type == Call::PureExtern);
return CreateCallExtern(op);
// vector store
unsigned addrspace = llvm::dyn_cast<llvm::PointerType>(
buffer->getType())->getAddressSpace();
if (const Ramp* ramp = op->index.as<Ramp>()) {
if (is_one(ramp->stride)) {
CHECK_EQ(ramp->lanes, t.lanes());
llvm::Value* ptr = CreateBufferPtr(
t.element_of(), buffer, MakeValue(ramp->base));
ptr = builder_->CreatePointerCast(ptr, LLVMType(t)->getPointerTo(addrspace));
llvm::StoreInst* store = builder_->CreateAlignedStore(value, ptr, alignment, is_volatile);
AddAliasInfo(store, op->buffer_var.get(), op->index, op->value.type());
return;
}
}
}
CHECK_GE(t.bits(), 8);
// scalarized store.
int basic_align = t.bits() / 8;
auto f = [&](int i, llvm::Value* index) {
llvm::Value* ptr = CreateBufferPtr(t.element_of(), buffer, index);
llvm::StoreInst* store = builder_->CreateAlignedStore(
builder_->CreateExtractElement(value, i),
ptr, basic_align, is_volatile);
AddAliasInfo(store, op->buffer_var.get(), Expr(), op->value.type());
};
this->Scalarize(op->index, f);
}
void CodeGenLLVM::VisitStmt_(const For* op) {
CHECK(is_zero(op->min));
if (op->for_type == ForType::Serial) {
CreateSerialFor(ConstInt32(0),
MakeValue(op->extent),
ConstInt32(1),
op->loop_var,
op->body);
} else {
LOG(FATAL) << "cannot handle for type " << op->for_type;
}
CHECK(op->for_type == ForType::Serial);
CreateSerialFor(MakeValue(op->min), MakeValue(op->extent),
ConstInt32(1), op->loop_var, op->body);
}
void CodeGenLLVM::VisitStmt_(const IfThenElse* op) {
using llvm::BasicBlock;
llvm::Value* cond = MakeValue(op->condition);
BasicBlock* then_block = BasicBlock::Create(
*ctx_, "if_then", function_);
BasicBlock* else_block = BasicBlock::Create(
*ctx_, "if_else", function_);
BasicBlock* end_block = BasicBlock::Create(
*ctx_, "if_end", function_);
if (!op->else_case.defined()) {
else_block = end_block;
}
// condition.
llvm::Value* cond = MakeValue(op->condition);
bool likely = true;
if (likely) {
builder_->CreateCondBr(cond, then_block, else_block, md_very_likely_branch_);
} else {
builder_->CreateCondBr(cond, then_block, else_block);
}
// then case.
builder_->SetInsertPoint(then_block);
this->VisitStmt(op->then_case);
builder_->CreateBr(end_block);
// else case.
if (op->else_case.defined()) {
BasicBlock* else_block = BasicBlock::Create(
*ctx_, "if_else", function_);
builder_->CreateCondBr(cond, then_block, else_block);
builder_->SetInsertPoint(then_block);
this->VisitStmt(op->then_case);
builder_->CreateBr(end_block);
builder_->SetInsertPoint(else_block);
this->VisitStmt(op->else_case);
builder_->CreateBr(end_block);
} else {
builder_->CreateCondBr(cond, then_block, end_block, md_very_likely_branch_);
builder_->SetInsertPoint(then_block);
this->VisitStmt(op->then_case);
builder_->CreateBr(end_block);
}
builder_->SetInsertPoint(end_block);
}
void CodeGenLLVM::VisitStmt_(const Allocate* op) {
CHECK(!is_zero(op->condition));
llvm::Value* buf = nullptr;
@@ -1100,20 +938,22 @@ void CodeGenLLVM::VisitStmt_(const Allocate* op) {
} else {
int32_t constant_size = op->constant_allocation_size();
CHECK_GT(constant_size, 0)
<< "Can only handle constant size stack allocation for now";
llvm::AllocaInst* alloca = builder_->CreateAlloca(
LLVMType(op->type), ConstInt32(constant_size));
buf = alloca;
<< "Can only handle constant size stack allocation";
StorageInfo& info = alloc_storage_info_[op->buffer_var.get()];
// Align stack to be TempAllocaAlignment.
// TODO(tqchen) have pass to detect vector access and pre-set alignment
if (constant_size % 4 == 0 && info.alignment == 0) {
info.alignment = GetTempAllocaAlignment(op->type, constant_size);
}
// maximum necessary alignment in the NV devices
if (info.alignment > 16) {
info.alignment = 16;
}
llvm::AllocaInst* alloca = builder_->CreateAlloca(
LLVMType(op->type), ConstInt32(constant_size));
if (alloca->getAlignment() < static_cast<uint32_t>(info.alignment)) {
alloca->setAlignment(info.alignment);
}
info.alignment = alloca->getAlignment();
buf = alloca;
}
buf = builder_->CreatePointerCast(
buf, LLVMType(op->type)->getPointerTo(
@@ -1124,29 +964,29 @@ void CodeGenLLVM::VisitStmt_(const Allocate* op) {
}
void CodeGenLLVM::VisitStmt_(const AttrStmt* op) {
if (op->attr_key == ir::attr::thread_extent) {
if (op->attr_key == attr::thread_extent) {
IterVar iv(op->node.node_);
if (iv->thread_tag.length() != 0) {
if (!var_map_.count(iv->var.get())) {
var_map_[iv->var.get()] = GetThreadIndex(iv);
}
}
this->VisitStmt(op->body);
} else if (op->attr_key == ir::attr::storage_scope) {
const Variable* v = op->node.as<Variable>();
CHECK(v);
alloc_storage_info_[v].scope = runtime::StorageScope::make(
op->value.as<StringImm>()->value);
this->VisitStmt(op->body);
alloc_storage_info_[v].scope =
runtime::StorageScope::make(op->value.as<StringImm>()->value);
} else if (op->attr_key == ir::attr::storage_alignment) {
const Variable* v = op->node.as<Variable>();
CHECK(v);
alloc_storage_info_[v].alignment =
static_cast<int>(op->value.as<IntImm>()->value);
this->VisitStmt(op->body);
} else {
this->VisitStmt(op->body);
} else if (op->attr_key == ir::attr::volatile_scope) {
const Variable* v = op->node.as<Variable>();
CHECK(v);
volatile_buf_.insert(v);
}
this->VisitStmt(op->body);
}
void CodeGenLLVM::VisitStmt_(const AssertStmt* op) {
@@ -1178,7 +1018,6 @@ void CodeGenLLVM::VisitStmt_(const AssertStmt* op) {
}
void CodeGenLLVM::VisitStmt_(const LetStmt* op) {
llvm::Value* v = MakeValue(op->value);
CHECK(!var_map_.count(op->var.get()));
CHECK(!align_map_.count(op->var.get()));
if (op->var.type().is_handle()) {
@@ -1186,24 +1025,25 @@ void CodeGenLLVM::VisitStmt_(const LetStmt* op) {
alias_var_set_.insert(op->var.get());
}
}
var_map_[op->var.get()] = v;
align_map_[op->var.get()] = arith::EvalModular(op->value, align_map_);
var_map_[op->var.get()] = MakeValue(op->value);
align_map_[op->var.get()] = EvalModular(op->value, align_map_);
this->VisitStmt(op->body);
}
void CodeGenLLVM::VisitStmt_(const Block* op) {
VisitStmt(op->first);
if (op->rest.defined()) VisitStmt(op->rest);
this->VisitStmt(op->first);
if (op->rest.defined()) {
this->VisitStmt(op->rest);
}
}
void CodeGenLLVM::VisitStmt_(const Evaluate *op) {
void CodeGenLLVM::VisitStmt_(const Evaluate* op) {
MakeValue(op->value);
}
void CodeGenLLVM::VisitStmt_(const ProducerConsumer* op) {
VisitStmt(op->body);
this->VisitStmt(op->body);
}
} // namespace codegen
} // namespace tvm
#endif // TVM_LLVM_VERSION
@@ -242,6 +242,8 @@ class CodeGenLLVM :
std::unordered_map<const Variable*, arith::ModularEntry> align_map_;
// set of var that are not restricted(can alias)
std::unordered_set<const Variable*> alias_var_set_;
// set of volatile buffer.
std::unordered_set<const Variable*> volatile_buf_;
};
} // namespace codegen
} // namespace tvm
@@ -28,7 +28,13 @@ inline Expr BroadcastTo(Expr e, int lanes) {
}
// Rewrite vectorized allocation access
// This is necessary so that each vector component gets its own workspace.
// Originates from Halide's loop vectorizer
//
//   s[i] -> s[i * lanes + var]
//
// The same principle applies when using one thread to simulate multiple contexts.
//
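// Example (illustrative): with lanes = 4 and vectorized loop variable v,
//
//   allocate s[n]
//   s[i] = f(i, v)
//
// becomes
//
//   allocate s[n * 4]
//   s[i * 4 + v] = f(i, v)
//
// so the four lanes no longer race on the same scratch slot.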
class VecAllocAccess : public IRMutator {
public:
VecAllocAccess(const Variable* buf, Var var, int var_lanes)