Commit 72d64520 by Tianqi Chen Committed by GitHub

[CODEGEN][LLVM] Refactor cpu runtime related code to CodeGenCPU (#361)

parent 7d5d9ec9
......@@ -4,19 +4,19 @@
* \brief ARM specific code generator
*/
#ifdef TVM_LLVM_VERSION
#include "./codegen_llvm.h"
#include "./codegen_cpu.h"
namespace tvm {
namespace codegen {
// ARM specific code generator. This is used as an example of
// how to override the behavior of the LLVM code generator for a specific target.
// Code generator specialization for ARM targets: it fixes the native vector
// width and then forwards target initialization to its base class.
class CodeGenARM final : public CodeGenLLVM {
class CodeGenARM final : public CodeGenCPU {
public:
// Target setup hook; tm is passed through unchanged to the base class.
void InitTarget(llvm::TargetMachine* tm) final {
// set native vector bits (16 * 8 = 128).
native_vector_bits_ = 16 * 8;
CodeGenLLVM::InitTarget(tm);
CodeGenCPU::InitTarget(tm);
}
};
......
/*!
* Copyright (c) 2017 by Contributors
* \file codegen_cpu.cc
*/
#ifdef TVM_LLVM_VERSION
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/ir_pass.h>
#include "./codegen_cpu.h"
#include "../../pass/ir_util.h"
namespace tvm {
namespace codegen {
void CodeGenCPU::Init(const std::string& module_name,
llvm::TargetMachine* tm,
llvm::LLVMContext* ctx,
bool system_lib,
bool dynamic_lookup) {
CodeGenLLVM::Init(module_name, tm, ctx, system_lib, dynamic_lookup);
static_assert(sizeof(TVMValue) == sizeof(double), "invariant");
func_handle_map_.clear();
export_system_symbols_.clear();
// TVM runtime types
t_tvm_shape_index_ = llvm::Type::getIntNTy(*ctx, TVMShapeIndexType().bits());
t_tvm_context_ = llvm::StructType::create({t_int_, t_int_});
t_tvm_type_ = llvm::StructType::create({t_int8_, t_int8_, t_int16_});
t_tvm_func_handle_ = t_void_p_;
t_tvm_array_ = llvm::StructType::create(
{t_void_p_,
t_tvm_context_,
t_int_,
t_tvm_type_,
t_tvm_shape_index_->getPointerTo(),
t_tvm_shape_index_->getPointerTo(),
t_int64_});
t_tvm_value_ = llvm::StructType::create({t_float64_});
t_tvm_parallel_group_env_ = llvm::StructType::create({
t_int32_->getPointerTo(), t_int32_});
ftype_tvm_parallel_lambda_ = llvm::FunctionType::get(
t_int_,
{t_int_,
t_tvm_parallel_group_env_->getPointerTo(),
t_void_p_}, false);
md_tbaa_ctx_ptr_ = md_builder_->createTBAAScalarTypeNode("ctx_ptr", md_tbaa_root_);
// Runtime functions.
ftype_tvm_func_call_ = llvm::FunctionType::get(t_int_, {
t_tvm_func_handle_,
t_tvm_value_->getPointerTo(),
t_int_->getPointerTo(),
t_int_,
t_tvm_value_->getPointerTo(),
t_int_->getPointerTo()}, false);
ftype_tvm_get_func_from_env_ = llvm::FunctionType::get(t_int_, {
t_void_p_,
t_char_->getPointerTo(),
t_tvm_func_handle_->getPointerTo()}, false);
ftype_tvm_api_set_last_error_ = llvm::FunctionType::get(
t_void_, {t_char_->getPointerTo()}, false);
ftype_tvm_parallel_launch_ =
llvm::FunctionType::get(t_int_, {
ftype_tvm_parallel_lambda_->getPointerTo(), t_void_p_, t_int_}
, false);
ftype_tvm_parallel_barrier_ =
llvm::FunctionType::get(t_int_, {
t_int_, t_tvm_parallel_group_env_->getPointerTo()}
, false);
ftype_tvm_static_init_callback_ =
llvm::FunctionType::get(t_int_, {t_void_p_}, false);
ftype_tvm_static_init_ =
llvm::FunctionType::get(t_int_, {
t_void_p_->getPointerTo(),
ftype_tvm_static_init_callback_->getPointerTo(),
t_void_p_, t_int_}
, false);
// initialize TVM runtime API
if (system_lib) {
// We will need this in environment for backward registration.
f_tvm_register_system_symbol_ = llvm::Function::Create(
llvm::FunctionType::get(t_int_, {t_char_->getPointerTo(), t_void_p_}, false),
llvm::Function::ExternalLinkage, "TVMBackendRegisterSystemLibSymbol", module_.get());
} else {
f_tvm_register_system_symbol_ = nullptr;
}
if (dynamic_lookup || system_lib) {
f_tvm_func_call_ = llvm::Function::Create(
ftype_tvm_func_call_,
llvm::Function::ExternalLinkage, "TVMFuncCall", module_.get());
f_tvm_get_func_from_env_ = llvm::Function::Create(
ftype_tvm_get_func_from_env_,
llvm::Function::ExternalLinkage, "TVMBackendGetFuncFromEnv", module_.get());
f_tvm_api_set_last_error_ = llvm::Function::Create(
ftype_tvm_api_set_last_error_,
llvm::Function::ExternalLinkage, "TVMAPISetLastError", module_.get());
f_tvm_parallel_launch_ = llvm::Function::Create(
ftype_tvm_parallel_launch_,
llvm::Function::ExternalLinkage, "TVMBackendParallelLaunch", module_.get());
f_tvm_parallel_barrier_ = llvm::Function::Create(
ftype_tvm_parallel_barrier_,
llvm::Function::ExternalLinkage, "TVMBackendParallelBarrier", module_.get());
}
this->InitGlobalContext(dynamic_lookup);
}
void CodeGenCPU::AddFunction(const LoweredFunc& f) {
CodeGenLLVM::AddFunction(f);
if (f_tvm_register_system_symbol_ != nullptr) {
export_system_symbols_.emplace_back(
std::make_pair(f->name, builder_->CreatePointerCast(function_, t_void_p_)));
}
}
/*!
 * \brief Record the name of the module's entry function in a global string.
 *
 * Emits a constant, weakly linked char array named
 * runtime::symbol::tvm_module_main whose contents are entry_func_name
 * (including the terminating NUL).
 *
 * \param entry_func_name Name of a function that must already exist in the module.
 */
void CodeGenCPU::AddMainFunction(const std::string& entry_func_name) {
  llvm::Function* f = module_->getFunction(entry_func_name);
  CHECK(f) << "Function " << entry_func_name << " does not exist in module";
  // +1 leaves room for the NUL terminator appended by getString.
  llvm::Type* type = llvm::ArrayType::get(t_char_, entry_func_name.length() + 1);
  llvm::GlobalVariable* global = new llvm::GlobalVariable(
      *module_, type, true, llvm::GlobalValue::WeakAnyLinkage, nullptr,
      runtime::symbol::tvm_module_main);
  global->setAlignment(1);
  global->setInitializer(llvm::ConstantDataArray::getString(*ctx_, entry_func_name));
}
/*!
 * \brief Build a pointer to one field of a runtime structure.
 *
 * For the array kinds (kind < kArrKindBound_) buf is treated as a pointer to
 * t_tvm_array_ elements and the result addresses a field of element `index`.
 * For kTVMValueContent buf is treated as an array of 64-bit value slots.
 *
 * \param t Type of the value stored in the slot (only used for kTVMValueContent).
 * \param buf Base pointer.
 * \param index Element index applied to buf.
 * \param kind One of the intrinsic::kArr* / kTVMValueContent field codes.
 * \return Pointer to the requested field.
 */
llvm::Value* CodeGenCPU::CreateStructRefPtr(
    Type t, llvm::Value* buf, llvm::Value* index, int kind) {
  if (kind < intrinsic::kArrKindBound_) {
    // Opaque pointers are cast; anything else must already be an array pointer.
    if (buf->getType() != t_void_p_) {
      CHECK_EQ(buf->getType(), t_tvm_array_->getPointerTo());
    } else {
      buf = builder_->CreatePointerCast(buf, t_tvm_array_->getPointerTo());
    }
  }
  // Address of buf[index] followed by the given chain of struct field numbers.
  auto field_ptr = [&](std::initializer_list<int> path) -> llvm::Value* {
    std::vector<llvm::Value*> gep_idx;
    gep_idx.push_back(index);
    for (int field : path) {
      gep_idx.push_back(ConstInt32(field));
    }
    return builder_->CreateInBoundsGEP(buf, gep_idx);
  };
  switch (kind) {
    case intrinsic::kArrAddr: return builder_->CreateInBoundsGEP(buf, index);
    case intrinsic::kArrData: return field_ptr({0});
    case intrinsic::kArrShape: return field_ptr({4});
    case intrinsic::kArrStrides: return field_ptr({5});
    case intrinsic::kArrNDim: return field_ptr({2});
    case intrinsic::kArrTypeCode: return field_ptr({3, 0});
    case intrinsic::kArrTypeBits: return field_ptr({3, 1});
    case intrinsic::kArrTypeLanes: return field_ptr({3, 2});
    case intrinsic::kArrByteOffset: return field_ptr({6});
    case intrinsic::kArrDeviceId: return field_ptr({1, 1});
    case intrinsic::kArrDeviceType: return field_ptr({1, 0});
    case intrinsic::kTVMValueContent: {
      CHECK_EQ(t.lanes(), 1);
      CHECK(t.is_handle() || t.bits() == 64);
      if (t.is_int()) {
        llvm::Value* as_i64 = builder_->CreatePointerCast(buf, t_int64_->getPointerTo());
        return builder_->CreateInBoundsGEP(as_i64, index);
      }
      if (t.is_float()) {
        llvm::Value* as_f64 = builder_->CreatePointerCast(buf, t_float64_->getPointerTo());
        return builder_->CreateInBoundsGEP(as_f64, index);
      }
      CHECK(t.is_handle());
      // Handles: index in units of t_tvm_value_, then view the slot as void*.
      llvm::Value* as_value = builder_->CreatePointerCast(buf, t_tvm_value_->getPointerTo());
      llvm::Value* slot = builder_->CreateInBoundsGEP(as_value, index);
      return builder_->CreatePointerCast(slot, t_void_p_->getPointerTo());
    }
    default: LOG(FATAL) << "unknown field code"; return nullptr;
  }
}
/*!
 * \brief Emit a call to an external function named by op->name.
 *
 * Names registered in gv_func_map_ are called through a pointer loaded from
 * a context global (created on first use as "__" + name); every other name is
 * called directly, declaring the function in the module when it is missing.
 *
 * \param op The extern call node.
 * \return The call instruction.
 */
llvm::Value* CodeGenCPU::CreateCallExtern(const Call* op) {
  // Lower the arguments and collect their LLVM types in the same pass.
  std::vector<llvm::Value*> call_args;
  std::vector<llvm::Type*> param_types;
  for (size_t i = 0; i < op->args.size(); ++i) {
    llvm::Value* lowered = MakeValue(op->args[i]);
    call_args.push_back(lowered);
    param_types.push_back(lowered->getType());
  }
  llvm::FunctionType* ftype = llvm::FunctionType::get(
      LLVMType(op->type), param_types, false);
  // Check if it is available in global function table as injected function.
  auto entry = gv_func_map_.find(op->name);
  if (entry == gv_func_map_.end()) {
    llvm::Function* callee = module_->getFunction(op->name);
    if (callee == nullptr) {
      callee = llvm::Function::Create(
          ftype, llvm::Function::ExternalLinkage, op->name, module_.get());
    }
    return builder_->CreateCall(callee, call_args);
  }
  if (entry->second == nullptr) {
    // First use of this injected function: create its context slot.
    entry->second = InitContextPtr(ftype->getPointerTo(), "__" + op->name);
  }
  return builder_->CreateCall(GetContextPtr(entry->second), call_args);
}
/*!
 * \brief Create a null-initialized, non-constant global slot of the given type.
 * \param p_type Type of the value stored in the global.
 * \param name Symbol name of the global.
 * \return The new global, using link-once linkage and aligned to its alloc size.
 */
llvm::GlobalVariable* CodeGenCPU::InitContextPtr(
    llvm::Type* p_type, std::string name) {
  auto* slot = new llvm::GlobalVariable(
      *module_, p_type, /*isConstant=*/false,
      llvm::GlobalValue::LinkOnceAnyLinkage, nullptr, name);
  slot->setAlignment(data_layout_->getTypeAllocSize(p_type));
  slot->setInitializer(llvm::Constant::getNullValue(p_type));
  return slot;
}
/*!
 * \brief Load the current value of a context global.
 * \param gv The context global; must not be null.
 * \return The load instruction, tagged with the ctx_ptr TBAA node.
 */
llvm::Value* CodeGenCPU::GetContextPtr(llvm::GlobalVariable* gv) {
  CHECK(gv != nullptr);
  llvm::LoadInst* loaded = builder_->CreateAlignedLoad(gv, gv->getAlignment());
  llvm::MDNode* ctx_tag =
      md_builder_->createTBAAStructTagNode(md_tbaa_ctx_ptr_, md_tbaa_ctx_ptr_, 0);
  loaded->setMetadata("tbaa", ctx_tag);
  return loaded;
}
/*!
 * \brief Create the global slots through which the module reaches the runtime.
 *
 * Always creates the module-context slot. With the system-lib hook declared,
 * that slot is queued for registration; otherwise, unless dynamic lookup is
 * enabled, one slot per runtime API function is created and the workspace
 * functions are marked as context functions.
 *
 * \param dynamic_lookup Whether runtime functions are resolved as external symbols.
 */
void CodeGenCPU::InitGlobalContext(bool dynamic_lookup) {
  // Module context
  gv_mod_ctx_ = InitContextPtr(t_void_p_, tvm::runtime::symbol::tvm_module_ctx);
  if (f_tvm_register_system_symbol_ != nullptr) {
    // Register back the locations.
    export_system_symbols_.emplace_back(
        tvm::runtime::symbol::tvm_module_ctx, gv_mod_ctx_);
    return;
  }
  if (dynamic_lookup) return;
  gv_tvm_func_call_ = InitContextPtr(
      ftype_tvm_func_call_->getPointerTo(), "__TVMFuncCall");
  gv_tvm_get_func_from_env_ = InitContextPtr(
      ftype_tvm_get_func_from_env_->getPointerTo(), "__TVMBackendGetFuncFromEnv");
  gv_tvm_api_set_last_error_ = InitContextPtr(
      ftype_tvm_api_set_last_error_->getPointerTo(), "__TVMAPISetLastError");
  gv_tvm_parallel_launch_ = InitContextPtr(
      ftype_tvm_parallel_launch_->getPointerTo(), "__TVMBackendParallelLaunch");
  gv_tvm_parallel_barrier_ = InitContextPtr(
      ftype_tvm_parallel_barrier_->getPointerTo(), "__TVMBackendParallelBarrier");
  // Mark as context functions; their slots are created lazily on first call.
  gv_func_map_["TVMBackendAllocWorkspace"] = nullptr;
  gv_func_map_["TVMBackendFreeWorkspace"] = nullptr;
}
/*!
 * \brief Emit a check that a runtime call returned zero.
 *
 * On a non-zero code the current function returns that code; on zero,
 * execution continues in a fresh block, which becomes the insert point.
 *
 * \param retcode Integer return code of the call.
 * \return The continuation block.
 */
llvm::BasicBlock* CodeGenCPU::CheckCallSuccess(llvm::Value* retcode) {
  llvm::BasicBlock* on_failure =
      llvm::BasicBlock::Create(*ctx_, "call_fail", function_);
  llvm::BasicBlock* on_success =
      llvm::BasicBlock::Create(*ctx_, "call_end", function_);
  llvm::Value* zero = llvm::ConstantInt::get(t_int_, 0);
  llvm::Value* is_ok = builder_->CreateICmpEQ(retcode, zero);
  builder_->CreateCondBr(is_ok, on_success, on_failure, md_very_likely_branch_);
  // Failure path: hand the error code straight back to the caller.
  builder_->SetInsertPoint(on_failure);
  builder_->CreateRet(retcode);
  // Success path: code generation resumes here.
  builder_->SetInsertPoint(on_success);
  return on_success;
}
/*!
 * \brief Outline the body of a compute_scope attribute into its own function.
 *
 * The free variables of the body become the parameters of a private function
 * named after op->value; the current function calls it and checks the result.
 *
 * \param op The compute_scope attribute statement; op->value must be a StringImm.
 */
void CodeGenCPU::CreateComputeScope(const AttrStmt* op) {
  // There are two reasons why we create another function for compute_scope
  // - Make sure the generated compute function is clearly separated (though it can get inlined)
  // - Set noalias on all the pointer arguments, some of them are loaded from TVMArgs.
  //   This is easier than setting the alias scope manually.
  using llvm::BasicBlock;
  Array<Var> vargs = ir::UndefinedVars(op->body, {});
  std::vector<llvm::Value*> arg_values;
  std::vector<llvm::Type*> arg_types;
  for (Var v : vargs) {
    llvm::Value* value = MakeValue(v);
    arg_values.push_back(value);
    arg_types.push_back(value->getType());
  }
  llvm::FunctionType* ftype =
      llvm::FunctionType::get(t_int_, arg_types, false);
  llvm::Function* fcompute =
      llvm::Function::Create(ftype,
                             llvm::Function::PrivateLinkage,
                             op->value.as<StringImm>()->value,
                             module_.get());
  BasicBlock* compute_call_end = CheckCallSuccess(
      builder_->CreateCall(fcompute, arg_values));
  // Set up the compute function: bind each free variable to its parameter.
  std::unordered_map<const Variable*, llvm::Value*> new_vmap;
  size_t idx = 0;
  for (auto it = fcompute->arg_begin();
       it != fcompute->arg_end(); ++it, ++idx) {
    llvm::Argument* v = &(*it);
    const Var& var = vargs[idx];
    new_vmap[var.get()] = v;
    if (var.type().is_handle() && !alias_var_set_.count(var.get())) {
      // set non alias.
#if TVM_LLVM_VERSION >= 50
      // addParamAttr takes a zero-based argument number.
      fcompute->addParamAttr(idx, llvm::Attribute::NoAlias);
#else
      // setDoesNotAlias uses attribute-list indices: 0 is the return value,
      // parameters start at 1.
      fcompute->setDoesNotAlias(idx + 1);
#endif
    }
  }
  // Generate the body inside the new function with its own variable map.
  std::swap(function_, fcompute);
  std::swap(new_vmap, var_map_);
  BasicBlock* compute_entry = BasicBlock::Create(*ctx_, "entry", function_);
  builder_->SetInsertPoint(compute_entry);
  this->VisitStmt(op->body);
  builder_->CreateRet(ConstInt32(0));
  // swap the var map back, now we are back on track.
  std::swap(new_vmap, var_map_);
  std::swap(function_, fcompute);
  builder_->SetInsertPoint(compute_call_end);
}
/*!
 * \brief Copy the current values of the given variables into a stack struct.
 * \param vfields Variables to capture; each must be present in var_map_.
 * \return Pointer to the allocated struct, with one field per variable in order.
 */
llvm::Value* CodeGenCPU::PackClosureData(const Array<Var>& vfields) {
  // Resolve every variable up front so types and values stay in step.
  std::vector<llvm::Type*> field_types;
  std::vector<llvm::Value*> field_values;
  for (Var v : vfields) {
    auto found = var_map_.find(v.get());
    CHECK(found != var_map_.end());
    field_types.push_back(found->second->getType());
    field_values.push_back(found->second);
  }
  llvm::StructType* closure_ty = llvm::StructType::create(field_types);
  llvm::Value* closure = builder_->CreateAlloca(closure_ty, ConstInt32(1));
  llvm::Value* zero = ConstInt32(0);
  for (size_t i = 0; i < vfields.size(); ++i) {
    llvm::Value* slot = builder_->CreateInBoundsGEP(closure, {zero, ConstInt32(i)});
    builder_->CreateStore(field_values[i], slot);
  }
  return closure;
}
/*!
 * \brief Load every captured variable back out of a closure struct.
 * \param cdata Pointer to the closure struct.
 * \param vfields Variables in the order their values are stored in cdata.
 * \param vmap Output map; receives one loaded value per variable.
 */
void CodeGenCPU::UnpackClosureData(llvm::Value* cdata,
                                   const Array<Var>& vfields,
                                   std::unordered_map<const Variable*, llvm::Value*>* vmap) {
  for (size_t field = 0; field < vfields.size(); ++field) {
    llvm::Value* slot = builder_->CreateInBoundsGEP(
        cdata, {ConstInt32(0), ConstInt32(field)});
    (*vmap)[vfields[field].get()] = builder_->CreateLoad(slot);
  }
}
// Outline `body` into a private lambda and emit a call to the runtime parallel
// launch function that receives the lambda, the packed closure and num_task.
// The body must contain a parallel loop (checked after it has been visited).
void CodeGenCPU::CreateParallelLaunch(const Stmt& body, int num_task) {
using llvm::BasicBlock;
// closure data
llvm::Function* f = llvm::Function::Create(
ftype_tvm_parallel_lambda_,
llvm::Function::PrivateLinkage,
"__tvm_parallel_lambda", module_.get());
// allocate and setup the closure, call the closure.
// The free variables of the body are what the lambda needs from this scope.
Array<Var> vfields = ir::UndefinedVars(body, {});
llvm::Value* cdata = PackClosureData(vfields);
BasicBlock* par_launch_end = CheckCallSuccess(
builder_->CreateCall(
RuntimeTVMParallelLaunch(),
{f, builder_->CreatePointerCast(cdata, t_void_p_), ConstInt32(num_task)}));
// Setup the closure function.
BasicBlock *lambda_entry = BasicBlock::Create(*ctx_, "entry", f);
builder_->SetInsertPoint(lambda_entry);
// Lambda arguments, in order: task id, parallel group env, closure pointer.
auto it = f->arg_begin();
llvm::Value* task_id = &(*it++);
llvm::Value* penv = &(*it++);
// Recover the typed closure pointer from the opaque third argument.
cdata = builder_->CreatePointerCast(&(*it++), cdata->getType());
// setup new variable map, swap it with current var context.
std::unordered_map<const Variable*, llvm::Value*> new_vmap;
UnpackClosureData(cdata, vfields, &new_vmap);
// setup parallel env
ParallelEnv par_env;
par_env.task_id = Var("task_id", Int(32));
par_env.num_task = Var("num_task", Int(32));
new_vmap[par_env.task_id.get()] = task_id;
// num_task is read from field 1 of the parallel group env struct.
new_vmap[par_env.num_task.get()] = builder_->CreateLoad(
builder_->CreateInBoundsGEP(
penv, {ConstInt32(0), ConstInt32(1)}));
par_env.penv = penv;
// Switch function, parallel env and variable map to the lambda's context.
std::swap(function_, f);
std::swap(parallel_env_, par_env);
std::swap(var_map_, new_vmap);
this->VisitStmt(body);
builder_->CreateRet(ConstInt32(0));
// swap the var map back, now we are back on track.
std::swap(var_map_, new_vmap);
std::swap(parallel_env_, par_env);
std::swap(function_, f);
// After swapping back, par_env holds the env that was active while the body
// was visited, including whether a parallel loop was encountered.
CHECK(par_env.hit_parallel_loop)
<< "Cannot find parallel loop within parallel launch";
builder_->SetInsertPoint(par_launch_end);
}
// Outline `body` into a private callback and emit a call to the function named
// init_fname, passing a private handle global, the callback, the packed
// closure and the closure size in bytes.
void CodeGenCPU::CreateStaticInit(const std::string& init_fname, const Stmt& body) {
using llvm::BasicBlock;
// closure data
llvm::Function* f = llvm::Function::Create(
ftype_tvm_static_init_callback_,
llvm::Function::PrivateLinkage,
"__tvm_static_init_lambda", module_.get());
// Null-initialized private global whose address is passed to the init function.
llvm::GlobalVariable* gv = new llvm::GlobalVariable(
*module_, t_void_p_, false,
llvm::GlobalValue::PrivateLinkage, 0,
"__tvm_static_handle");
gv->setAlignment(data_layout_->getTypeAllocSize(t_void_p_));
gv->setInitializer(llvm::Constant::getNullValue(t_void_p_));
// Reuse an existing declaration of the init function, or declare it as external.
llvm::Function* finit = module_->getFunction(init_fname);
if (finit == nullptr) {
finit = llvm::Function::Create(
ftype_tvm_static_init_, llvm::Function::ExternalLinkage, init_fname, module_.get());
}
// allocate and setup the closure, call the closure.
Array<Var> vfields = ir::UndefinedVars(body, {});
llvm::Value* cdata = PackClosureData(vfields);
// Allocation size of the closure struct that cdata points to.
llvm::Value* nbytes = ConstInt32(data_layout_->getTypeAllocSize(
llvm::cast<llvm::PointerType>(cdata->getType())->getElementType()));
BasicBlock* init_end = CheckCallSuccess(
builder_->CreateCall(
finit,
{gv, f, builder_->CreatePointerCast(cdata, t_void_p_), nbytes}));
// Setup the closure function.
BasicBlock *lambda_entry = BasicBlock::Create(*ctx_, "entry", f);
builder_->SetInsertPoint(lambda_entry);
// The callback's only argument is the opaque closure pointer.
auto it = f->arg_begin();
cdata = builder_->CreatePointerCast(&(*it++), cdata->getType());
// setup new variable map, swap it with current var context.
std::unordered_map<const Variable*, llvm::Value*> new_vmap;
UnpackClosureData(cdata, vfields, &new_vmap);
// Static init is not expected while a parallel environment is active.
CHECK(parallel_env_.penv == nullptr);
std::swap(function_, f);
std::swap(var_map_, new_vmap);
this->VisitStmt(body);
builder_->CreateRet(ConstInt32(0));
// swap the var map back, now we are back on track.
std::swap(var_map_, new_vmap);
std::swap(function_, f);
builder_->SetInsertPoint(init_end);
}
// Emit code that yields the handle of the packed function `fname`.
// The handle is cached in a per-name global; when the cached value is null the
// generated code looks the function up through the runtime and uses the result.
// Returns a PHI node merging the cached and the freshly looked-up handle.
llvm::Value* CodeGenCPU::GetPackedFuncHandle(const std::string& fname) {
using llvm::BasicBlock;
// We will store the packed function handle in global space.
// Initialize it during the first call.
llvm::DataLayout layout(module_.get());
uint64_t align = layout.getTypeAllocSize(t_tvm_func_handle_);
auto it = func_handle_map_.find(fname);
llvm::GlobalVariable* hptr;
if (it == func_handle_map_.end()) {
// create global location for the handle
// create the function handle
hptr = new llvm::GlobalVariable(
*module_, t_tvm_func_handle_, false,
llvm::GlobalValue::LinkOnceAnyLinkage, 0, ".tvm_func." + fname);
hptr->setAlignment(align);
hptr->setInitializer(llvm::Constant::getNullValue(t_tvm_func_handle_));
func_handle_map_[fname] = hptr;
} else {
hptr = it->second;
}
// create emit codes that checks and load the function.
// pre_block is remembered because it is one of the PHI's incoming edges.
BasicBlock* pre_block = builder_->GetInsertBlock();
BasicBlock* init_block = BasicBlock::Create(
*ctx_, "handle_init", function_);
BasicBlock* end_block = BasicBlock::Create(
*ctx_, "handle_init_end", function_);
llvm::Value* handle = builder_->CreateAlignedLoad(hptr, align);
llvm::Value* handle_not_null = builder_->CreateICmpNE(
handle, llvm::Constant::getNullValue(t_tvm_func_handle_));
builder_->CreateCondBr(
handle_not_null, end_block, init_block, md_very_likely_branch_);
// Initialize the handle if needed.
builder_->SetInsertPoint(init_block);
// Stack slot that receives the handle from the lookup call.
llvm::Value* out = builder_->CreateAlloca(t_tvm_func_handle_);
llvm::LoadInst* ctx = builder_->CreateAlignedLoad(
gv_mod_ctx_, gv_mod_ctx_->getAlignment());
ctx->setMetadata(
"tbaa",
md_builder_->createTBAAStructTagNode(md_tbaa_ctx_ptr_, md_tbaa_ctx_ptr_, 0));
llvm::Value* retcode = builder_->CreateCall(
RuntimeTVMGetFuncFromEnv(), {ctx, GetConstString(fname), out});
// CheckCallSuccess moves the insert point to a new block; that block, not the
// original init_block, is the predecessor of end_block on this path.
init_block = CheckCallSuccess(retcode);
llvm::Value* loaded_handle = builder_->CreateAlignedLoad(out, align);
builder_->CreateBr(end_block);
// end block
builder_->SetInsertPoint(end_block);
llvm::PHINode* phi = builder_->CreatePHI(t_tvm_func_handle_, 2);
phi->addIncoming(handle, pre_block);
phi->addIncoming(loaded_handle, init_block);
return phi;
}
/*!
 * \brief Emit a call to a packed function through the runtime call entry point.
 *
 * op->args is (name, value stack, type-code stack, begin, end): the arguments
 * occupy stack slots [begin, end) and the return value is read from slot end.
 *
 * \param op The lowered packed-call node; must carry exactly five arguments.
 * \return The returned value, cast to op->type.
 */
llvm::Value* CodeGenCPU::CreateCallPacked(const Call* op) {
  CHECK_EQ(op->args.size(), 5U);
  std::string func_name = op->args[0].as<StringImm>()->value;
  llvm::Value* handle = GetPackedFuncHandle(func_name);
  // Slot range of the arguments on the value/type-code stacks.
  int64_t begin = op->args[3].as<IntImm>()->value;
  int64_t end = op->args[4].as<IntImm>()->value;
  int64_t nargs = end - begin;
  CHECK_GE(nargs, 0);
  llvm::Value* stack_value = MakeValue(op->args[1]);
  llvm::Value* stack_tcode = MakeValue(op->args[2]);
  llvm::Type* value_ptr_ty = t_tvm_value_->getPointerTo();
  // Pointers to the first argument slot.
  llvm::Value* arg_base = builder_->CreatePointerCast(stack_value, value_ptr_ty);
  llvm::Value* arg_value = builder_->CreateInBoundsGEP(arg_base, ConstInt32(begin));
  llvm::Value* arg_tcode = CreateBufferPtr(
      Int(32), stack_tcode, ConstInt32(begin));
  // Pointers to the return slot, located right after the arguments.
  llvm::Value* ret_base = builder_->CreatePointerCast(stack_value, value_ptr_ty);
  llvm::Value* ret_value = builder_->CreateInBoundsGEP(ret_base, ConstInt32(end));
  llvm::Value* ret_tcode = CreateBufferPtr(
      Int(32), stack_tcode, ConstInt32(end));
  llvm::Value* callee = RuntimeTVMFuncCall();
  llvm::Value* retcode = builder_->CreateCall(
      callee,
      {handle, arg_value, arg_tcode, ConstInt32(nargs),
       ret_value, ret_tcode});
  CheckCallSuccess(retcode);
  // Read the result in its API type, then convert to the requested type.
  Type r_type = op->type;
  Type r_api_type = ir::APIType(r_type);
  llvm::Value* typed_ret = builder_->CreatePointerCast(
      ret_value, LLVMType(r_api_type)->getPointerTo());
  llvm::Value* rvalue = builder_->CreateAlignedLoad(typed_ret, 8);
  return CreateCast(r_api_type, r_type, rvalue);
}
// Callee for TVMFuncCall: the direct declaration when there is one,
// otherwise the pointer loaded from its context global.
llvm::Value* CodeGenCPU::RuntimeTVMFuncCall() {
  if (f_tvm_func_call_ == nullptr) {
    return GetContextPtr(gv_tvm_func_call_);
  }
  return f_tvm_func_call_;
}
// Callee for TVMBackendGetFuncFromEnv: the direct declaration when there is
// one, otherwise the pointer loaded from its context global.
llvm::Value* CodeGenCPU::RuntimeTVMGetFuncFromEnv() {
  if (f_tvm_get_func_from_env_ == nullptr) {
    return GetContextPtr(gv_tvm_get_func_from_env_);
  }
  return f_tvm_get_func_from_env_;
}
// Callee for TVMAPISetLastError: the direct declaration when there is one,
// otherwise the pointer loaded from its context global.
llvm::Value* CodeGenCPU::RuntimeTVMAPISetLastError() {
  if (f_tvm_api_set_last_error_ == nullptr) {
    return GetContextPtr(gv_tvm_api_set_last_error_);
  }
  return f_tvm_api_set_last_error_;
}
// Callee for TVMBackendParallelLaunch: the direct declaration when there is
// one, otherwise the pointer loaded from its context global.
llvm::Value* CodeGenCPU::RuntimeTVMParallelLaunch() {
  if (f_tvm_parallel_launch_ == nullptr) {
    return GetContextPtr(gv_tvm_parallel_launch_);
  }
  return f_tvm_parallel_launch_;
}
// Callee for TVMBackendParallelBarrier: the direct declaration when there is
// one, otherwise the pointer loaded from its context global.
llvm::Value* CodeGenCPU::RuntimeTVMParallelBarrier() {
  if (f_tvm_parallel_barrier_ == nullptr) {
    return GetContextPtr(gv_tvm_parallel_barrier_);
  }
  return f_tvm_parallel_barrier_;
}
void CodeGenCPU::AddStartupFunction() {
if (export_system_symbols_.size() != 0) {
llvm::FunctionType* ftype = llvm::FunctionType::get(t_void_, {}, false);
function_ = llvm::Function::Create(
ftype,
llvm::Function::InternalLinkage,
"__tvm_module_startup", module_.get());
llvm::BasicBlock* startup_entry = llvm::BasicBlock::Create(*ctx_, "entry", function_);
builder_->SetInsertPoint(startup_entry);
for (const auto& kv : export_system_symbols_) {
llvm::Value* name = GetConstString(kv.first);
builder_->CreateCall(
f_tvm_register_system_symbol_, {
name, builder_->CreateBitCast(kv.second, t_void_p_)});
}
llvm::appendToGlobalCtors(*module_, function_, 65535);
builder_->CreateRet(nullptr);
}
}
/*!
 * \brief Lower the CPU-runtime intrinsics; everything else goes to CodeGenLLVM.
 *
 * Handles tvm_call_packed_lowered, tvm_throw_last_error, tvm_struct_get,
 * tvm_struct_set and tvm_stack_alloca.
 *
 * \param op The intrinsic call node.
 * \return The value produced by the intrinsic.
 */
llvm::Value* CodeGenCPU::CreateIntrinsic(const Call* op) {
  if (op->is_intrinsic(intrinsic::tvm_call_packed_lowered)) {
    return CreateCallPacked(op);
  }
  if (op->is_intrinsic(intrinsic::tvm_throw_last_error)) {
    // Leave the current function with the error code -1.
    builder_->CreateRet(ConstInt32(-1));
    return ConstInt32(-1);
  }
  if (op->is_intrinsic(intrinsic::tvm_struct_get)) {
    CHECK_EQ(op->args.size(), 3U);
    int kind = op->args[2].as<IntImm>()->value;
    llvm::Value* ref = this->CreateStructRefPtr(
        op->type, MakeValue(op->args[0]),
        MakeValue(op->args[1]), kind);
    // kArrAddr yields the element address itself instead of a loaded field.
    if (kind != intrinsic::kArrAddr) {
      return builder_->CreateLoad(ref);
    }
    return builder_->CreatePointerCast(ref, t_void_p_);
  }
  if (op->is_intrinsic(intrinsic::tvm_struct_set)) {
    CHECK_EQ(op->args.size(), 4U);
    int kind = op->args[2].as<IntImm>()->value;
    llvm::Value* value = MakeValue(op->args[3]);
    llvm::Value* ref = this->CreateStructRefPtr(
        op->args[3].type(), MakeValue(op->args[0]),
        MakeValue(op->args[1]), kind);
    CHECK(kind != intrinsic::kArrAddr);
    if (value->getType()->isPointerTy()) {
      // Pointer values are cast to the pointer type of the destination field.
      value = builder_->CreatePointerCast(
          value, ref->getType()->getPointerElementType());
    }
    builder_->CreateStore(value, ref);
    return ConstInt32(0);
  }
  if (op->is_intrinsic(intrinsic::tvm_stack_alloca)) {
    CHECK_EQ(op->args.size(), 2U);
    const std::string& type = op->args[0].as<StringImm>()->value;
    llvm::Value* num = MakeValue(op->args[1]);
    // Element type of the stack allocation is selected by name.
    if (type == "shape") return builder_->CreateAlloca(t_tvm_shape_index_, num);
    if (type == "arg_value") return builder_->CreateAlloca(t_tvm_value_, num);
    if (type == "arg_tcode") return builder_->CreateAlloca(t_int_, num);
    if (type == "array") return builder_->CreateAlloca(t_tvm_array_, num);
    LOG(FATAL) << "Unknown stack alloca type " << type;
    return nullptr;
  }
  return CodeGenLLVM::CreateIntrinsic(op);
}
void CodeGenCPU::VisitStmt_(const AssertStmt* op) {
using llvm::BasicBlock;
llvm::Value* cond = MakeValue(op->condition);
std::ostringstream os;
os << "Assert fail: " << op->condition;
if (op->message.as<StringImm>()) {
os << ", " << op->message.as<StringImm>()->value;
}
llvm::Value* msg = GetConstString(os.str());
BasicBlock* fail_block = BasicBlock::Create(
*ctx_, "assert_fail", function_);
BasicBlock* end_block = BasicBlock::Create(
*ctx_, "assert_end", function_);
builder_->CreateCondBr(cond, end_block, fail_block, md_very_likely_branch_);
// fail condition.
builder_->SetInsertPoint(fail_block);
builder_->CreateCall(RuntimeTVMAPISetLastError(), {msg});
builder_->CreateRet(ConstInt32(-1));
// otherwise set it to be new end.
builder_->SetInsertPoint(end_block);
CodeGenLLVM::VisitStmt_(op);
}
/*!
 * \brief Dispatch on the attribute key of an AttrStmt.
 *
 * Handles coproc_uop_scope (static init), compute_scope (outlined compute
 * function) and the parallel pragmas of pragma_scope; every other key is
 * handled by CodeGenLLVM.
 *
 * \param op The attribute statement.
 */
void CodeGenCPU::VisitStmt_(const AttrStmt* op) {
  if (op->attr_key == ir::attr::coproc_uop_scope) {
    this->CreateStaticInit(op->value.as<StringImm>()->value, op->body);
    return;
  }
  if (op->attr_key == ir::attr::compute_scope) {
    this->CreateComputeScope(op);
    return;
  }
  if (op->attr_key != ir::attr::pragma_scope) {
    CodeGenLLVM::VisitStmt_(op);
    return;
  }
  // pragma_scope: the attribute value names the pragma.
  const std::string& pname = op->value.as<StringImm>()->value;
  if (pname == "parallel_stride_pattern") {
    CHECK(parallel_env_.penv != nullptr)
        << "Pragma parallel_stride_pattern only valid in parallel launch";
    parallel_env_.stride_pattern = true;
    this->VisitStmt(op->body);
    return;
  }
  if (pname == "parallel_launch_point") {
    CreateParallelLaunch(op->body, 0);
    return;
  }
  if (pname == "parallel_barrier_when_finish") {
    CHECK(parallel_env_.penv != nullptr)
        << "Cannot run barrier without parallel environment";
    CHECK(!parallel_env_.hit_parallel_loop)
        << "Cannot not place within parallel loop as the workload may differ, "
        << " place it between parallel and parallel_launch_point";
    this->VisitStmt(op->body);
    // The barrier call is emitted after the body.
    builder_->CreateCall(
        RuntimeTVMParallelBarrier(),
        {MakeValue(parallel_env_.task_id), parallel_env_.penv});
    return;
  }
  LOG(WARNING) << "Unknown pragma " << pname;
  this->VisitStmt(op->body);
}
/*!
 * \brief Lower a for loop; the loop must start at zero.
 *
 * Serial loops go to CodeGenLLVM. A parallel loop outside a parallel
 * environment is wrapped in a parallel launch; inside one, the iteration
 * space is divided among the tasks, either strided by num_task or in
 * contiguous chunks.
 *
 * \param op The for statement.
 */
void CodeGenCPU::VisitStmt_(const For* op) {
  CHECK(is_zero(op->min));
  if (op->for_type == ForType::Serial) {
    CodeGenLLVM::VisitStmt_(op);
    return;
  }
  if (op->for_type != ForType::Parallel) {
    LOG(FATAL) << "cannot handle for type " << op->for_type;
    return;
  }
  if (parallel_env_.penv == nullptr) {
    // Not inside a launch yet: open one around this loop.
    CreateParallelLaunch(
        For::make(
            op->loop_var, op->min, op->extent,
            op->for_type, op->device_api, op->body), 0);
    return;
  }
  // already in parallel env.
  CHECK(parallel_env_.task_id.defined());
  CHECK(parallel_env_.num_task.defined());
  CHECK(parallel_env_.penv != nullptr);
  Type t = op->extent.type();
  Expr num_task = cast(t, parallel_env_.num_task);
  Expr task_id = cast(t, parallel_env_.task_id);
  CHECK(!parallel_env_.hit_parallel_loop)
      << "Nested parallel loop is not supported by threadpool, try fuse them instead";
  parallel_env_.hit_parallel_loop = true;
  if (parallel_env_.stride_pattern) {
    // Task k runs iterations k, k + num_task, k + 2 * num_task, ...
    CreateSerialFor(MakeValue(task_id),
                    MakeValue(op->extent),
                    MakeValue(num_task),
                    op->loop_var,
                    op->body);
    return;
  }
  // Contiguous chunks of ceil(extent / num_task) iterations, clamped to extent.
  Expr step = (op->extent + num_task - make_const(t, 1)) / num_task;
  Expr begin = Min::make(task_id * step, op->extent);
  Expr end = Min::make((task_id + make_const(t, 1)) * step, op->extent);
  CreateSerialFor(MakeValue(begin),
                  MakeValue(end),
                  ConstInt32(1),
                  op->loop_var,
                  op->body);
}
} // namespace codegen
} // namespace tvm
#endif // TVM_LLVM_VERSION
/*!
* Copyright (c) 2017 by Contributors
* \file codegen_cpu.h
* \brief Common base class for generating into LLVM IR on CPU host.
*/
#ifndef TVM_CODEGEN_LLVM_CODEGEN_CPU_H_
#define TVM_CODEGEN_LLVM_CODEGEN_CPU_H_
#include <utility>
#include <vector>
#include <string>
#include "./codegen_llvm.h"
namespace tvm {
namespace codegen {
// CPU host code generation
// CPU host code generation
// Extends CodeGenLLVM with the parts that talk to the TVM runtime: packed
// function calls, parallel launch, static initialization, assertions that
// report through the runtime, and system-lib symbol registration.
class CodeGenCPU : public CodeGenLLVM {
public:
// Initialize base state, then create the runtime types and declarations.
void Init(const std::string& module_name,
llvm::TargetMachine* tm,
llvm::LLVMContext* ctx,
bool system_lib,
bool dynamic_lookup) override;
// Add a function and queue it for system-lib export when that is enabled.
void AddFunction(const LoweredFunc& f) override;
// Store the entry function name in the tvm_module_main global.
void AddMainFunction(const std::string& entry_func_name) override;
// Assertion failures set the last error and return -1.
void VisitStmt_(const AssertStmt* op) override;
// Handles coproc_uop_scope, compute_scope and the parallel pragmas.
void VisitStmt_(const AttrStmt* op) override;
// Handles serial and parallel loops.
void VisitStmt_(const For* op) override;
// Handles packed calls, struct get/set, stack alloca and throw_last_error.
llvm::Value* CreateIntrinsic(const Call* op) override;
// Calls an external function, through a context slot when one is registered.
llvm::Value* CreateCallExtern(const Call* op) override;
protected:
// Emits the global constructor that registers the queued system symbols.
void AddStartupFunction() final;
// meta data
llvm::MDNode* md_tbaa_ctx_ptr_{nullptr};
// TVM related data types
llvm::Type* t_tvm_shape_index_{nullptr};
llvm::Type* t_tvm_func_handle_{nullptr};
llvm::StructType* t_tvm_context_{nullptr};
llvm::StructType* t_tvm_type_{nullptr};
llvm::StructType* t_tvm_array_{nullptr};
llvm::StructType* t_tvm_value_{nullptr};
llvm::StructType* t_tvm_parallel_group_env_{nullptr};
// Function types of the runtime entry points and callbacks.
llvm::FunctionType* ftype_tvm_parallel_lambda_{nullptr};
llvm::FunctionType* ftype_tvm_func_call_{nullptr};
llvm::FunctionType* ftype_tvm_get_func_from_env_{nullptr};
llvm::FunctionType* ftype_tvm_api_set_last_error_{nullptr};
llvm::FunctionType* ftype_tvm_parallel_launch_{nullptr};
llvm::FunctionType* ftype_tvm_parallel_barrier_{nullptr};
llvm::FunctionType* ftype_tvm_register_system_symbol_{nullptr};
// Lazy entry for function call.
llvm::FunctionType* ftype_tvm_static_init_callback_{nullptr};
llvm::FunctionType* ftype_tvm_static_init_{nullptr};
private:
// the parallel group information
struct ParallelEnv {
// Variables bound to the task index and the task count inside the lambda.
VarExpr task_id;
VarExpr num_task;
// Set by the parallel_stride_pattern pragma.
bool stride_pattern{false};
// Set once a parallel loop has been lowered in this environment.
bool hit_parallel_loop{false};
// Pointer to the parallel group env argument; null outside a parallel launch.
llvm::Value* penv{nullptr};
};
// Get runtime functions
void InitGlobalContext(bool dynamic_lookup);
// Create a null-initialized global slot of the given type and name.
llvm::GlobalVariable* InitContextPtr(llvm::Type* type, std::string name);
// Load the value stored in a context slot.
llvm::Value* GetContextPtr(llvm::GlobalVariable* gv);
// Each accessor returns the direct declaration when it exists, otherwise the
// value loaded from the matching context slot.
llvm::Value* RuntimeTVMFuncCall();
llvm::Value* RuntimeTVMGetFuncFromEnv();
llvm::Value* RuntimeTVMAPISetLastError();
llvm::Value* RuntimeTVMParallelLaunch();
llvm::Value* RuntimeTVMParallelBarrier();
// Emit code that yields the (lazily looked-up) handle of a packed function.
llvm::Value* GetPackedFuncHandle(const std::string& str);
// Store the current values of `fields` into a stack-allocated struct.
llvm::Value* PackClosureData(const Array<Var>& fields);
// Pointer to a field of a runtime array element or to a TVMValue slot.
llvm::Value* CreateStructRefPtr(Type t, llvm::Value* buffer, llvm::Value* index, int kind);
// Load the values of `fields` out of a closure struct into vmap.
void UnpackClosureData(llvm::Value*cdata,
const Array<Var>& fields,
std::unordered_map<const Variable*, llvm::Value*>* vmap);
// create call into tvm packed function.
llvm::Value* CreateCallPacked(const Call* op);
// Create static initialization
void CreateStaticInit(const std::string& init_fname, const Stmt& body);
// Create parallel launch
void CreateParallelLaunch(const Stmt& body, int num_task);
// Create a new compute scope.
void CreateComputeScope(const AttrStmt* op);
// Check if the call to packed function is successful
// if not directly finalize function and pass on return code.
// return the end block after the check
llvm::BasicBlock* CheckCallSuccess(llvm::Value* retcode);
// Context for injection lookup
llvm::GlobalVariable* gv_mod_ctx_{nullptr};
llvm::GlobalVariable* gv_tvm_func_call_{nullptr};
llvm::GlobalVariable* gv_tvm_get_func_from_env_{nullptr};
llvm::GlobalVariable* gv_tvm_api_set_last_error_{nullptr};
llvm::GlobalVariable* gv_tvm_parallel_launch_{nullptr};
llvm::GlobalVariable* gv_tvm_parallel_barrier_{nullptr};
// Extern function name -> context slot; a null slot is created on first call.
std::unordered_map<std::string, llvm::GlobalVariable*> gv_func_map_;
// context for direct dynamic lookup
llvm::Function* f_tvm_func_call_{nullptr};
llvm::Function* f_tvm_get_func_from_env_{nullptr};
llvm::Function* f_tvm_api_set_last_error_{nullptr};
llvm::Function* f_tvm_parallel_launch_{nullptr};
llvm::Function* f_tvm_parallel_barrier_{nullptr};
llvm::Function* f_tvm_register_system_symbol_{nullptr};
// Current parallel environment scope.
ParallelEnv parallel_env_;
// global to packed function handle
std::unordered_map<std::string, llvm::GlobalVariable*> func_handle_map_;
// List of symbols to be exported to TVM system lib.
std::vector<std::pair<std::string, llvm::Value*> > export_system_symbols_;
};
} // namespace codegen
} // namespace tvm
#endif // TVM_CODEGEN_LLVM_CODEGEN_CPU_H_
......@@ -4,10 +4,10 @@
*/
#ifdef TVM_LLVM_VERSION
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/device_api.h>
#include <tvm/ir_pass.h>
#include <tvm/runtime/c_runtime_api.h>
#include "./codegen_llvm.h"
#include "./codegen_cpu.h"
#include "../../pass/ir_util.h"
#include "../../arithmetic/compute_expr.h"
......@@ -22,7 +22,7 @@ std::unique_ptr<CodeGenLLVM> CodeGenLLVM::Create(llvm::TargetMachine *tm) {
void* handle = (*f)();
return std::unique_ptr<CodeGenLLVM>(static_cast<CodeGenLLVM*>(handle));
} else {
return std::unique_ptr<CodeGenLLVM>(new CodeGenLLVM());
return std::unique_ptr<CodeGenLLVM>(new CodeGenCPU());
}
}
......@@ -32,15 +32,10 @@ void CodeGenLLVM::Init(const std::string& module_name,
bool system_lib,
bool dynamic_lookup) {
InitializeLLVM();
static_assert(sizeof(TVMValue) == sizeof(double), "invariant");
// static_assert(alignof(TVMValue) == alignof(double), "invariant");
// clear maps
var_map_.clear();
str_map_.clear();
func_handle_map_.clear();
export_system_symbols_.clear();
// initialize types.
if (ctx_ != ctx) {
ctx_ = ctx;
t_void_ = llvm::Type::getVoidTy(*ctx);
t_void_p_ = llvm::Type::getInt8Ty(*ctx)->getPointerTo();
t_int_ = llvm::Type::getIntNTy(*ctx, sizeof(int) * 8);
......@@ -50,98 +45,17 @@ void CodeGenLLVM::Init(const std::string& module_name,
t_int32_ = llvm::Type::getInt32Ty(*ctx);
t_int64_ = llvm::Type::getInt64Ty(*ctx);
t_float64_ = llvm::Type::getDoubleTy(*ctx);
t_tvm_shape_index_ = llvm::Type::getIntNTy(*ctx, TVMShapeIndexType().bits());
t_tvm_context_ = llvm::StructType::create({t_int_, t_int_});
t_tvm_type_ = llvm::StructType::create({t_int8_, t_int8_, t_int16_});
t_tvm_func_handle_ = t_void_p_;
t_tvm_array_ = llvm::StructType::create(
{t_void_p_,
t_tvm_context_,
t_int_,
t_tvm_type_,
t_tvm_shape_index_->getPointerTo(),
t_tvm_shape_index_->getPointerTo(),
t_int64_});
t_tvm_value_ = llvm::StructType::create({t_float64_});
t_tvm_parallel_group_env_ = llvm::StructType::create({
t_int32_->getPointerTo(),
t_int32_});
ftype_tvm_parallel_lambda_ = llvm::FunctionType::get(
t_int_,
{t_int_,
t_tvm_parallel_group_env_->getPointerTo(),
t_void_p_}, false);
md_builder_.reset(new llvm::MDBuilder(*ctx));
md_very_likely_branch_ =
md_builder_->createBranchWeights(1 << 30, 0);
md_tbaa_root_ = md_builder_->createTBAARoot("tvmtbaa");
md_tbaa_alias_set_ = md_builder_->createTBAAScalarTypeNode(
"alias_set", md_tbaa_root_);
md_tbaa_ctx_ptr_ = md_builder_->createTBAAScalarTypeNode(
"ctx_ptr", md_tbaa_root_);
}
ctx_ = ctx;
// initialize Modules and function type
module_.reset(new llvm::Module(module_name, *ctx));
ftype_tvm_func_call_ = llvm::FunctionType::get(t_int_, {
t_tvm_func_handle_,
t_tvm_value_->getPointerTo(),
t_int_->getPointerTo(),
t_int_,
t_tvm_value_->getPointerTo(),
t_int_->getPointerTo()}, false);
ftype_tvm_get_func_from_env_ = llvm::FunctionType::get(t_int_, {
t_void_p_,
t_char_->getPointerTo(),
t_tvm_func_handle_->getPointerTo()}, false);
ftype_tvm_api_set_last_error_ = llvm::FunctionType::get(
t_void_, {t_char_->getPointerTo()}, false);
ftype_tvm_parallel_launch_ =
llvm::FunctionType::get(t_int_, {
ftype_tvm_parallel_lambda_->getPointerTo(), t_void_p_, t_int_}
, false);
ftype_tvm_parallel_barrier_ =
llvm::FunctionType::get(t_int_, {
t_int_, t_tvm_parallel_group_env_->getPointerTo()}
, false);
ftype_tvm_static_init_callback_ =
llvm::FunctionType::get(t_int_, {t_void_p_}, false);
ftype_tvm_static_init_ =
llvm::FunctionType::get(t_int_, {
t_void_p_->getPointerTo(),
ftype_tvm_static_init_callback_->getPointerTo(),
t_void_p_, t_int_}
, false);
// initialize TVM runtime API
if (system_lib) {
// We will need this in environment for backward registration.
f_tvm_register_system_symbol_ = llvm::Function::Create(
llvm::FunctionType::get(t_int_, {t_char_->getPointerTo(), t_void_p_}, false),
llvm::Function::ExternalLinkage, "TVMBackendRegisterSystemLibSymbol", module_.get());
} else {
f_tvm_register_system_symbol_ = nullptr;
}
if (dynamic_lookup || system_lib) {
f_tvm_func_call_ = llvm::Function::Create(
ftype_tvm_func_call_,
llvm::Function::ExternalLinkage, "TVMFuncCall", module_.get());
f_tvm_get_func_from_env_ = llvm::Function::Create(
ftype_tvm_get_func_from_env_,
llvm::Function::ExternalLinkage, "TVMBackendGetFuncFromEnv", module_.get());
f_tvm_api_set_last_error_ = llvm::Function::Create(
ftype_tvm_api_set_last_error_,
llvm::Function::ExternalLinkage, "TVMAPISetLastError", module_.get());
f_tvm_parallel_launch_ = llvm::Function::Create(
ftype_tvm_parallel_launch_,
llvm::Function::ExternalLinkage, "TVMBackendParallelLaunch", module_.get());
f_tvm_parallel_barrier_ = llvm::Function::Create(
ftype_tvm_parallel_barrier_,
llvm::Function::ExternalLinkage, "TVMBackendParallelBarrier", module_.get());
}
this->InitTarget(tm);
// initialize builder
builder_.reset(new IRBuilder(*ctx));
this->InitGlobalContext(dynamic_lookup);
this->InitTarget(tm);
}
void CodeGenLLVM::InitTarget(llvm::TargetMachine* tm) {
......@@ -164,53 +78,6 @@ void CodeGenLLVM::InitTarget(llvm::TargetMachine* tm) {
}
}
// Create a module-level global of type p_type with the given name.
// The global uses link-once linkage, is aligned to the allocation size
// of p_type and starts out holding the null value of that type.
llvm::GlobalVariable* CodeGenLLVM::InitContextPtr(
    llvm::Type* p_type, std::string name) {
  auto* ctx_gv = new llvm::GlobalVariable(
      *module_, p_type, /*isConstant=*/false,
      llvm::GlobalValue::LinkOnceAnyLinkage, /*Initializer=*/nullptr,
      name);
  const auto alloc_size = data_layout_->getTypeAllocSize(p_type);
  ctx_gv->setAlignment(alloc_size);
  ctx_gv->setInitializer(llvm::Constant::getNullValue(p_type));
  return ctx_gv;
}
// Emit an aligned load of the context global gv and tag the load with the
// context-pointer TBAA node. gv must not be null.
llvm::Value* CodeGenLLVM::GetContextPtr(llvm::GlobalVariable* gv) {
  CHECK(gv != nullptr);
  llvm::LoadInst* loaded_ctx =
      builder_->CreateAlignedLoad(gv, gv->getAlignment());
  auto* ctx_tbaa_tag = md_builder_->createTBAAStructTagNode(
      md_tbaa_ctx_ptr_, md_tbaa_ctx_ptr_, 0);
  loaded_ctx->setMetadata("tbaa", ctx_tbaa_tag);
  return loaded_ctx;
}
// Set up the module-level context globals.
// The module context global is always created. When the system-lib
// registration function is present, the module context is recorded for
// export and nothing else is done. Otherwise, unless dynamic_lookup is
// requested, one context pointer is created per runtime entry point and
// the workspace functions are marked as context functions.
void CodeGenLLVM::InitGlobalContext(bool dynamic_lookup) {
  // Module context
  gv_mod_ctx_ = InitContextPtr(t_void_p_, tvm::runtime::symbol::tvm_module_ctx);

  // Register back the locations.
  const bool has_system_registry = (f_tvm_register_system_symbol_ != nullptr);
  if (has_system_registry) {
    export_system_symbols_.emplace_back(
        std::make_pair(tvm::runtime::symbol::tvm_module_ctx, gv_mod_ctx_));
    return;
  }
  if (dynamic_lookup) {
    return;
  }

  gv_tvm_func_call_ = InitContextPtr(
      ftype_tvm_func_call_->getPointerTo(), "__TVMFuncCall");
  gv_tvm_get_func_from_env_ = InitContextPtr(
      ftype_tvm_get_func_from_env_->getPointerTo(), "__TVMBackendGetFuncFromEnv");
  gv_tvm_api_set_last_error_ = InitContextPtr(
      ftype_tvm_api_set_last_error_->getPointerTo(), "__TVMAPISetLastError");
  gv_tvm_parallel_launch_ = InitContextPtr(
      ftype_tvm_parallel_launch_->getPointerTo(), "__TVMBackendParallelLaunch");
  gv_tvm_parallel_barrier_ = InitContextPtr(
      ftype_tvm_parallel_barrier_->getPointerTo(), "__TVMBackendParallelBarrier");

  // Mark as context functions; a null entry is filled in on first use.
  gv_func_map_["TVMBackendAllocWorkspace"] = nullptr;
  gv_func_map_["TVMBackendFreeWorkspace"] = nullptr;
}
void CodeGenLLVM::InitFuncState() {
var_map_.clear();
align_map_.clear();
......@@ -264,22 +131,10 @@ void CodeGenLLVM::AddFunction(const LoweredFunc& f) {
builder_->SetInsertPoint(block);
this->VisitStmt(f->body);
builder_->CreateRet(ConstInt32(0));
if (f_tvm_register_system_symbol_ != nullptr) {
export_system_symbols_.emplace_back(
std::make_pair(f->name, builder_->CreatePointerCast(function_, t_void_p_)));
}
}
// Look up entry_func_name in the current module and emit a weak-linkage
// char-array global named runtime::symbol::tvm_module_main whose
// initializer is the entry function name.
// NOTE(review): the body ends with an unconditional LOG(FATAL). This looks
// like removed and added lines of a change displayed together -- confirm
// which statements are actually live before relying on this function.
void CodeGenLLVM::AddMainFunction(const std::string& entry_func_name) {
llvm::Function* f = module_->getFunction(entry_func_name);
// The entry function must already be present in the module.
CHECK(f) << "Function " << entry_func_name << "does not in module";
// Array type sized to the name length plus one.
llvm::Type* type = llvm::ArrayType::get(t_char_, entry_func_name.length() + 1);
llvm::GlobalVariable *global = new llvm::GlobalVariable(
*module_, type, true, llvm::GlobalValue::WeakAnyLinkage, 0,
runtime::symbol::tvm_module_main);
global->setAlignment(1);
global->setInitializer(llvm::ConstantDataArray::getString(*ctx_, entry_func_name));
LOG(FATAL) << "Donot support add main function";
}
class FPassManager : public llvm::legacy::FunctionPassManager {
......@@ -300,7 +155,6 @@ class MPassManager : public llvm::legacy::PassManager {
}
};
void CodeGenLLVM::Optimize() {
// place optimization pass
llvm::PassManagerBuilder builder;
......@@ -330,33 +184,9 @@ void CodeGenLLVM::Optimize() {
// Finish the current pass of codegen and return the created module.
// Adds the startup function, runs the optimizer, clears the lookup tables
// and then moves module_ out to the caller (module_ is empty afterwards).
// NOTE(review): the four clear() calls may be a mix of removed and kept
// lines of a displayed change -- confirm which members still exist here.
std::unique_ptr<llvm::Module> CodeGenLLVM::Finish() {
this->AddStartupFunction();
this->Optimize();
var_map_.clear();
str_map_.clear();
func_handle_map_.clear();
export_system_symbols_.clear();
// module_ is a member, so an explicit move is required to transfer ownership.
return std::move(module_);
}
void CodeGenLLVM::AddStartupFunction() {
if (export_system_symbols_.size() != 0) {
llvm::FunctionType* ftype = llvm::FunctionType::get(t_void_, {}, false);
function_ = llvm::Function::Create(
ftype,
llvm::Function::InternalLinkage,
"__tvm_module_startup", module_.get());
llvm::BasicBlock* startup_entry = llvm::BasicBlock::Create(*ctx_, "entry", function_);
builder_->SetInsertPoint(startup_entry);
for (const auto& kv : export_system_symbols_) {
llvm::Value* name = GetConstString(kv.first);
builder_->CreateCall(
f_tvm_register_system_symbol_, {
name, builder_->CreateBitCast(kv.second, t_void_p_)});
}
llvm::appendToGlobalCtors(*module_, function_, 65535);
builder_->CreateRet(nullptr);
}
}
llvm::Type* CodeGenLLVM::LLVMType(const Type& t) const {
llvm::Type* ret = nullptr;
if (t.is_uint() || t.is_int()) {
......@@ -378,23 +208,6 @@ llvm::Type* CodeGenLLVM::LLVMType(const Type& t) const {
return ret;
}
// Emit a check that retcode equals zero.
// On failure the generated code returns retcode from the current function;
// on success control continues in a new "call_end" block. The builder is
// left positioned at that block, which is also returned.
llvm::BasicBlock* CodeGenLLVM::CheckCallSuccess(llvm::Value* retcode) {
  llvm::BasicBlock* on_failure =
      llvm::BasicBlock::Create(*ctx_, "call_fail", function_);
  llvm::BasicBlock* on_success =
      llvm::BasicBlock::Create(*ctx_, "call_end", function_);

  llvm::Value* zero = llvm::ConstantInt::get(t_int_, 0);
  llvm::Value* call_ok = builder_->CreateICmpEQ(retcode, zero);
  builder_->CreateCondBr(call_ok, on_success, on_failure, md_very_likely_branch_);

  // Failure path: hand the return code back to the caller.
  builder_->SetInsertPoint(on_failure);
  builder_->CreateRet(retcode);

  // Success path becomes the new insertion point.
  builder_->SetInsertPoint(on_success);
  return on_success;
}
void CodeGenLLVM::AddAliasInfo(
llvm::Instruction* inst, const Variable* buffer, Expr index, Type t) {
......@@ -474,74 +287,6 @@ llvm::Value* CodeGenLLVM::CreateBufferPtr(
return builder_->CreateInBoundsGEP(buffer, index);
}
// Compute the address of a field selected by kind inside element `index`
// of buf.
// For kinds below intrinsic::kArrKindBound_, buf is treated as a pointer to
// the TVM array struct (cast from void* if needed, otherwise type-checked).
// For intrinsic::kTVMValueContent, buf is reinterpreted according to t,
// which must be a single-lane handle or 64-bit type.
// Unknown kinds are a fatal error.
llvm::Value* CodeGenLLVM::CreateStructRefPtr(
    Type t, llvm::Value* buf, llvm::Value* index, int kind) {
  if (kind < intrinsic::kArrKindBound_) {
    auto* array_ptr_type = t_tvm_array_->getPointerTo();
    if (buf->getType() != t_void_p_) {
      CHECK_EQ(buf->getType(), array_ptr_type);
    } else {
      buf = builder_->CreatePointerCast(buf, array_ptr_type);
    }
  }
  // Address of a top-level struct field of element `index`.
  auto field = [&](int field_id) {
    return builder_->CreateInBoundsGEP(buf, {index, ConstInt32(field_id)});
  };
  // Address of a member of a nested struct field of element `index`.
  auto nested_field = [&](int field_id, int member_id) {
    return builder_->CreateInBoundsGEP(
        buf, {index, ConstInt32(field_id), ConstInt32(member_id)});
  };
  switch (kind) {
    case intrinsic::kArrAddr:
      return builder_->CreateInBoundsGEP(buf, index);
    case intrinsic::kArrData:
      return field(0);
    case intrinsic::kArrShape:
      return field(4);
    case intrinsic::kArrStrides:
      return field(5);
    case intrinsic::kArrNDim:
      return field(2);
    case intrinsic::kArrTypeCode:
      return nested_field(3, 0);
    case intrinsic::kArrTypeBits:
      return nested_field(3, 1);
    case intrinsic::kArrTypeLanes:
      return nested_field(3, 2);
    case intrinsic::kArrByteOffset:
      return field(6);
    case intrinsic::kArrDeviceId:
      return nested_field(1, 1);
    case intrinsic::kArrDeviceType:
      return nested_field(1, 0);
    case intrinsic::kTVMValueContent: {
      CHECK_EQ(t.lanes(), 1);
      CHECK(t.is_handle() || t.bits() == 64);
      if (t.is_int()) {
        buf = builder_->CreatePointerCast(buf, t_int64_->getPointerTo());
        return builder_->CreateInBoundsGEP(buf, index);
      }
      if (t.is_float()) {
        buf = builder_->CreatePointerCast(buf, t_float64_->getPointerTo());
        return builder_->CreateInBoundsGEP(buf, index);
      }
      // Remaining case: a handle stored in the value slot.
      CHECK(t.is_handle());
      buf = builder_->CreatePointerCast(buf, t_tvm_value_->getPointerTo());
      buf = builder_->CreateInBoundsGEP(buf, index);
      return builder_->CreatePointerCast(buf, t_void_p_->getPointerTo());
    }
    default:
      LOG(FATAL) << "unknown field code";
      return nullptr;
  }
}
llvm::Value* CodeGenLLVM::CreateCast(Type from, Type to, llvm::Value* value) {
llvm::Type * target = LLVMType(to);
if (value->getType() == target) return value;
......@@ -568,133 +313,23 @@ llvm::Value* CodeGenLLVM::CreateCast(Type from, Type to, llvm::Value* value) {
}
}
// Emit code that yields the handle of the packed function named fname.
// The handle is cached in a module-level global (".tvm_func." + fname);
// the generated code loads it, and if it is null resolves it through
// RuntimeTVMGetFuncFromEnv() using the module context.
// Returns a PHI node merging the cached and the freshly loaded handle.
llvm::Value* CodeGenLLVM::GetPackedFuncHandle(const std::string& fname) {
using llvm::BasicBlock;
// We will store the packed function handle in global space.
// Initialize it during the first call.
llvm::DataLayout layout(module_.get());
// Alignment used for the handle global and for every load of a handle.
uint64_t align = layout.getTypeAllocSize(t_tvm_func_handle_);
auto it = func_handle_map_.find(fname);
llvm::GlobalVariable* hptr;
if (it == func_handle_map_.end()) {
// First request for this name: create the global that holds the handle,
// initialized to null, and remember it for later requests.
hptr = new llvm::GlobalVariable(
*module_, t_tvm_func_handle_, false,
llvm::GlobalValue::LinkOnceAnyLinkage, 0, ".tvm_func." + fname);
hptr->setAlignment(align);
hptr->setInitializer(llvm::Constant::getNullValue(t_tvm_func_handle_));
func_handle_map_[fname] = hptr;
} else {
hptr = it->second;
}
// Emit code that checks the cached handle and loads the function if needed.
BasicBlock* pre_block = builder_->GetInsertBlock();
BasicBlock* init_block = BasicBlock::Create(
*ctx_, "handle_init", function_);
BasicBlock* end_block = BasicBlock::Create(
*ctx_, "handle_init_end", function_);
llvm::Value* handle = builder_->CreateAlignedLoad(hptr, align);
llvm::Value* handle_not_null = builder_->CreateICmpNE(
handle, llvm::Constant::getNullValue(t_tvm_func_handle_));
// A non-null cached handle skips initialization (marked as the likely path).
builder_->CreateCondBr(
handle_not_null, end_block, init_block, md_very_likely_branch_);
// Initialize the handle if needed.
builder_->SetInsertPoint(init_block);
// Stack slot that receives the looked-up handle.
llvm::Value* out = builder_->CreateAlloca(t_tvm_func_handle_);
llvm::LoadInst* ctx = builder_->CreateAlignedLoad(
gv_mod_ctx_, gv_mod_ctx_->getAlignment());
ctx->setMetadata(
"tbaa",
md_builder_->createTBAAStructTagNode(md_tbaa_ctx_ptr_, md_tbaa_ctx_ptr_, 0));
llvm::Value* retcode = builder_->CreateCall(
RuntimeTVMGetFuncFromEnv(), {ctx, GetConstString(fname), out});
// CheckCallSuccess moves the insertion point to a new block; keep that
// block because it is the actual predecessor of end_block for the PHI.
init_block = CheckCallSuccess(retcode);
llvm::Value* loaded_handle = builder_->CreateAlignedLoad(out, align);
builder_->CreateBr(end_block);
// End block: merge the cached handle and the freshly loaded one.
builder_->SetInsertPoint(end_block);
llvm::PHINode* phi = builder_->CreatePHI(t_tvm_func_handle_, 2);
phi->addIncoming(handle, pre_block);
phi->addIncoming(loaded_handle, init_block);
return phi;
}
// Emit a call to a TVM packed function.
// op must carry five arguments: the function name (StringImm), the value
// stack, the type-code stack, and the begin/end indices (IntImm) of the
// argument range inside those stacks. The slot at index `end` receives
// the return value, which is loaded and cast to op->type.
llvm::Value* CodeGenLLVM::CreateCallPacked(const Call* op) {
  CHECK_EQ(op->args.size(), 5U);
  const std::string& callee_name = op->args[0].as<StringImm>()->value;
  llvm::Value* callee_handle = GetPackedFuncHandle(callee_name);

  // Argument range [first, last) inside the stacks.
  const int64_t first = op->args[3].as<IntImm>()->value;
  const int64_t last = op->args[4].as<IntImm>()->value;
  const int64_t num_args = last - first;
  CHECK_GE(num_args, 0);

  llvm::Value* value_stack = MakeValue(op->args[1]);
  llvm::Value* tcode_stack = MakeValue(op->args[2]);

  // Pointers to the first argument value / type code.
  llvm::Value* args_base = builder_->CreatePointerCast(
      value_stack, t_tvm_value_->getPointerTo());
  llvm::Value* arg_value =
      builder_->CreateInBoundsGEP(args_base, ConstInt32(first));
  llvm::Value* arg_tcode =
      CreateBufferPtr(Int(32), tcode_stack, ConstInt32(first));

  // Pointers to the return value / type code slot, right after the arguments.
  llvm::Value* ret_base = builder_->CreatePointerCast(
      value_stack, t_tvm_value_->getPointerTo());
  llvm::Value* ret_value =
      builder_->CreateInBoundsGEP(ret_base, ConstInt32(last));
  llvm::Value* ret_tcode =
      CreateBufferPtr(Int(32), tcode_stack, ConstInt32(last));

  // Call the function and bail out of the generated function on failure.
  llvm::Value* func_call = RuntimeTVMFuncCall();
  CheckCallSuccess(
      builder_->CreateCall(
          func_call,
          {callee_handle, arg_value, arg_tcode, ConstInt32(num_args),
           ret_value, ret_tcode}));

  // Load the result using its API type, then convert to the requested type.
  Type result_type = op->type;
  Type result_api_type = ir::APIType(result_type);
  llvm::Value* typed_ret_ptr = builder_->CreatePointerCast(
      ret_value, LLVMType(result_api_type)->getPointerTo());
  llvm::Value* loaded_ret = builder_->CreateAlignedLoad(typed_ret_ptr, 8);
  return CreateCast(result_api_type, result_type, loaded_ret);
}
// Emit a call to the external function named op->name.
// All arguments are evaluated first. For a scalar result type the callee
// is either an injected context function (when the name is registered in
// gv_func_map_; its context pointer is created lazily on first use) or an
// external-linkage function declared in the module on demand.
// For a non-scalar result type the function must already exist in the
// module and is applied through CreateScalarizedCall; otherwise this is a
// fatal error.
llvm::Value* CodeGenLLVM::CreateCallExtern(const Call* op) {
  std::vector<llvm::Value*> arg_values(op->args.size());
  for (size_t i = 0; i < op->args.size(); ++i) {
    arg_values[i] = MakeValue(op->args[i]);
  }
  if (op->type.is_scalar()) {
    std::vector<llvm::Type*> arg_types;
    arg_types.reserve(arg_values.size());
    for (llvm::Value* v : arg_values) {
      arg_types.push_back(v->getType());
    }
    llvm::FunctionType* ftype = llvm::FunctionType::get(
        LLVMType(op->type), arg_types, false);
    // Check if it is available in global function table as injected function.
    auto it = gv_func_map_.find(op->name);
    if (it != gv_func_map_.end()) {
      if (it->second == nullptr) {
        // Fill the existing entry through the iterator; no further map
        // lookups are needed and the iterator stays valid.
        it->second = InitContextPtr(ftype->getPointerTo(), "__" + op->name);
      }
      return builder_->CreateCall(GetContextPtr(it->second), arg_values);
    } else {
      llvm::Function* f = module_->getFunction(op->name);
      if (f == nullptr) {
        f = llvm::Function::Create(
            ftype, llvm::Function::ExternalLinkage, op->name, module_.get());
      }
      return builder_->CreateCall(f, arg_values);
    }
  } else {
    llvm::Function* f = module_->getFunction(op->name);
    if (f) {
      return CreateScalarizedCall(op, f, arg_values);
    } else {
      LOG(FATAL) << "cannot find function " << op->name;
    }
  }
  LOG(FATAL) << "canot reach here";
  return nullptr;
}
llvm::Value* CodeGenLLVM::CreateScalarizedCall(
......@@ -721,29 +356,6 @@ llvm::Value* CodeGenLLVM::CreateScalarizedCall(
return value;
}
// Return the callee for TVMFuncCall: the directly declared function when
// available, otherwise the value loaded from its context pointer.
llvm::Value* CodeGenLLVM::RuntimeTVMFuncCall() {
  if (f_tvm_func_call_ == nullptr) {
    return GetContextPtr(gv_tvm_func_call_);
  }
  return f_tvm_func_call_;
}
// Return the callee for TVMBackendGetFuncFromEnv: the directly declared
// function when available, otherwise the value loaded from its context pointer.
llvm::Value* CodeGenLLVM::RuntimeTVMGetFuncFromEnv() {
  if (f_tvm_get_func_from_env_ == nullptr) {
    return GetContextPtr(gv_tvm_get_func_from_env_);
  }
  return f_tvm_get_func_from_env_;
}
// Return the callee for TVMAPISetLastError: the directly declared function
// when available, otherwise the value loaded from its context pointer.
llvm::Value* CodeGenLLVM::RuntimeTVMAPISetLastError() {
  if (f_tvm_api_set_last_error_ == nullptr) {
    return GetContextPtr(gv_tvm_api_set_last_error_);
  }
  return f_tvm_api_set_last_error_;
}
// Return the callee for TVMBackendParallelLaunch: the directly declared
// function when available, otherwise the value loaded from its context pointer.
llvm::Value* CodeGenLLVM::RuntimeTVMParallelLaunch() {
  if (f_tvm_parallel_launch_ == nullptr) {
    return GetContextPtr(gv_tvm_parallel_launch_);
  }
  return f_tvm_parallel_launch_;
}
// Return the callee for TVMBackendParallelBarrier: the directly declared
// function when available, otherwise the value loaded from its context pointer.
llvm::Value* CodeGenLLVM::RuntimeTVMParallelBarrier() {
  if (f_tvm_parallel_barrier_ == nullptr) {
    return GetContextPtr(gv_tvm_parallel_barrier_);
  }
  return f_tvm_parallel_barrier_;
}
llvm::Value* CodeGenLLVM::GetVarValue(const Variable* v) const {
auto it = var_map_.find(v);
CHECK(it != var_map_.end())
......@@ -771,179 +383,6 @@ llvm::Value* CodeGenLLVM::GetConstString(const std::string& str) {
}
}
// Outline the body of a compute_scope attribute into a private function
// named by op->value (a StringImm) and emit a checked call to it.
// The variables reported by ir::UndefinedVars for the body become the
// parameters of the new function; handle-typed parameters that are not in
// alias_var_set_ are marked noalias. The body is generated inside the new
// function with a fresh variable map, after which the previous function and
// variable map are restored and emission continues after the call.
void CodeGenLLVM::CreateComputeScope(const AttrStmt* op) {
  // There are two reasons why we create another function for compute_scope
  // - Make sure the generated compute function is clearly separated
  //   (though it can get inlined).
  // - Set noalias on all the pointer arguments, some of them are loaded from TVMArgs.
  //   This is easier than setting the alias scope manually.
  using llvm::BasicBlock;
  Array<Var> vargs = ir::UndefinedVars(op->body, {});
  std::vector<llvm::Value*> arg_values;
  std::vector<llvm::Type*> arg_types;
  for (Var v : vargs) {
    llvm::Value* value = MakeValue(v);
    arg_values.push_back(value);
    arg_types.push_back(value->getType());
  }
  llvm::FunctionType* ftype =
      llvm::FunctionType::get(t_int_, arg_types, false);
  llvm::Function* fcompute =
      llvm::Function::Create(ftype,
                             llvm::Function::PrivateLinkage,
                             op->value.as<StringImm>()->value,
                             module_.get());
  // Call the outlined function from the current one and propagate failures.
  BasicBlock* compute_call_end = CheckCallSuccess(
      builder_->CreateCall(fcompute, arg_values));
  // Set up the compute function: bind each parameter to its variable.
  std::unordered_map<const Variable*, llvm::Value*> new_vmap;
  size_t idx = 0;
  for (auto it = fcompute->arg_begin();
       it != fcompute->arg_end(); ++it, ++idx) {
    llvm::Argument* v = &(*it);
    const Var& var = vargs[idx];
    new_vmap[var.get()] = v;
    if (var.type().is_handle() && !alias_var_set_.count(var.get())) {
      // set non alias.
#if TVM_LLVM_VERSION >= 50
      // addParamAttr takes the zero-based argument number.
      fcompute->addParamAttr(idx, llvm::Attribute::NoAlias);
#else
      // setDoesNotAlias takes an attribute index, where parameters start at 1.
      fcompute->setDoesNotAlias(idx + 1);
#endif
    }
  }
  // Switch codegen state to the compute function while visiting the body.
  std::swap(function_, fcompute);
  std::swap(new_vmap, var_map_);
  BasicBlock *compute_entry = BasicBlock::Create(*ctx_, "entry", function_);
  builder_->SetInsertPoint(compute_entry);
  this->VisitStmt(op->body);
  builder_->CreateRet(ConstInt32(0));
  // swap the var map back, now we are back on track.
  std::swap(new_vmap, var_map_);
  std::swap(function_, fcompute);
  builder_->SetInsertPoint(compute_call_end);
}
// Allocate a stack struct with one field per variable in vfields and store
// the current value of each variable into its field.
// Every variable must already be present in var_map_.
// Returns the pointer to the allocated struct.
llvm::Value* CodeGenLLVM::PackClosureData(const Array<Var>& vfields) {
  std::vector<llvm::Type*> field_types;
  for (Var field : vfields) {
    auto found = var_map_.find(field.get());
    CHECK(found != var_map_.end());
    field_types.push_back(found->second->getType());
  }
  llvm::StructType* closure_type = llvm::StructType::create(field_types);
  llvm::Value* closure = builder_->CreateAlloca(closure_type, ConstInt32(1));
  llvm::Value* zero = ConstInt32(0);
  for (size_t i = 0; i < vfields.size(); ++i) {
    llvm::Value* slot =
        builder_->CreateInBoundsGEP(closure, {zero, ConstInt32(i)});
    builder_->CreateStore(var_map_.at(vfields[i].get()), slot);
  }
  return closure;
}
// Load each field of the closure struct pointed to by cdata and bind the
// loaded value to the corresponding variable of vfields in *vmap.
void CodeGenLLVM::UnpackClosureData(llvm::Value* cdata,
                                    const Array<Var>& vfields,
                                    std::unordered_map<const Variable*, llvm::Value*>* vmap) {
  for (size_t i = 0; i < vfields.size(); ++i) {
    llvm::Value* slot = builder_->CreateInBoundsGEP(
        cdata, {ConstInt32(0), ConstInt32(i)});
    (*vmap)[vfields[i].get()] = builder_->CreateLoad(slot);
  }
}
// Outline body into a private "__tvm_parallel_lambda" function and emit a
// checked call to RuntimeTVMParallelLaunch() that runs it with num_task.
// The variables reported by ir::UndefinedVars for body are packed into a
// closure struct in the caller and unpacked inside the lambda.
// While the body is visited, function_, parallel_env_ and var_map_ are
// swapped to the lambda's state and restored afterwards.
// It is a checked error if the body did not contain a parallel loop.
void CodeGenLLVM::CreateParallelLaunch(const Stmt& body, int num_task) {
using llvm::BasicBlock;
// closure data
llvm::Function* f = llvm::Function::Create(
ftype_tvm_parallel_lambda_,
llvm::Function::PrivateLinkage,
"__tvm_parallel_lambda", module_.get());
// allocate and setup the closure, call the closure.
Array<Var> vfields = ir::UndefinedVars(body, {});
llvm::Value* cdata = PackClosureData(vfields);
// The closure is passed as an opaque pointer; failures of the launch
// make the enclosing function return early.
BasicBlock* par_launch_end = CheckCallSuccess(
builder_->CreateCall(
RuntimeTVMParallelLaunch(),
{f, builder_->CreatePointerCast(cdata, t_void_p_), ConstInt32(num_task)}));
// Setup the closure function.
BasicBlock *lambda_entry = BasicBlock::Create(*ctx_, "entry", f);
builder_->SetInsertPoint(lambda_entry);
// Lambda parameters, in order: task id, parallel environment pointer,
// opaque closure pointer (cast back to the packed closure type).
auto it = f->arg_begin();
llvm::Value* task_id = &(*it++);
llvm::Value* penv = &(*it++);
cdata = builder_->CreatePointerCast(&(*it++), cdata->getType());
// setup new variable map, swap it with current var context.
std::unordered_map<const Variable*, llvm::Value*> new_vmap;
UnpackClosureData(cdata, vfields, &new_vmap);
// setup parallel env
ParallelEnv par_env;
par_env.task_id = Var("task_id", Int(32));
par_env.num_task = Var("num_task", Int(32));
new_vmap[par_env.task_id.get()] = task_id;
// num_task is read from field 1 of the environment struct.
new_vmap[par_env.num_task.get()] = builder_->CreateLoad(
builder_->CreateInBoundsGEP(
penv, {ConstInt32(0), ConstInt32(1)}));
par_env.penv = penv;
std::swap(function_, f);
std::swap(parallel_env_, par_env);
std::swap(var_map_, new_vmap);
this->VisitStmt(body);
builder_->CreateRet(ConstInt32(0));
// swap the var map back, now we are back on track.
std::swap(var_map_, new_vmap);
std::swap(parallel_env_, par_env);
std::swap(function_, f);
// After swapping back, par_env holds the environment used inside the lambda.
CHECK(par_env.hit_parallel_loop)
<< "Cannot find parallel loop within parallel launch";
builder_->SetInsertPoint(par_launch_end);
}
// Outline body into a private "__tvm_static_init_lambda" callback and emit
// a checked call to the function named init_fname (declared with external
// linkage if the module does not have it yet).
// The call receives a private null-initialized handle global, the callback,
// the packed closure of the variables reported by ir::UndefinedVars for
// body, and the closure size in bytes.
// Must not be used inside a parallel environment (checked).
void CodeGenLLVM::CreateStaticInit(const std::string& init_fname, const Stmt& body) {
using llvm::BasicBlock;
// closure data
llvm::Function* f = llvm::Function::Create(
ftype_tvm_static_init_callback_,
llvm::Function::PrivateLinkage,
"__tvm_static_init_lambda", module_.get());
// Global that is handed to the init function as its handle slot.
llvm::GlobalVariable* gv = new llvm::GlobalVariable(
*module_, t_void_p_, false,
llvm::GlobalValue::PrivateLinkage, 0,
"__tvm_static_handle");
gv->setAlignment(data_layout_->getTypeAllocSize(t_void_p_));
gv->setInitializer(llvm::Constant::getNullValue(t_void_p_));
// Reuse an existing declaration of the init function when there is one.
llvm::Function* finit = module_->getFunction(init_fname);
if (finit == nullptr) {
finit = llvm::Function::Create(
ftype_tvm_static_init_, llvm::Function::ExternalLinkage, init_fname, module_.get());
}
// allocate and setup the closure, call the closure.
Array<Var> vfields = ir::UndefinedVars(body, {});
llvm::Value* cdata = PackClosureData(vfields);
// Size in bytes of the closure struct that cdata points to.
llvm::Value* nbytes = ConstInt32(data_layout_->getTypeAllocSize(
llvm::cast<llvm::PointerType>(cdata->getType())->getElementType()));
BasicBlock* init_end = CheckCallSuccess(
builder_->CreateCall(
finit,
{gv, f, builder_->CreatePointerCast(cdata, t_void_p_), nbytes}));
// Setup the closure function.
BasicBlock *lambda_entry = BasicBlock::Create(*ctx_, "entry", f);
builder_->SetInsertPoint(lambda_entry);
// The callback's first parameter is the opaque closure pointer.
auto it = f->arg_begin();
cdata = builder_->CreatePointerCast(&(*it++), cdata->getType());
// setup new variable map, swap it with current var context.
std::unordered_map<const Variable*, llvm::Value*> new_vmap;
UnpackClosureData(cdata, vfields, &new_vmap);
CHECK(parallel_env_.penv == nullptr);
std::swap(function_, f);
std::swap(var_map_, new_vmap);
this->VisitStmt(body);
builder_->CreateRet(ConstInt32(0));
// swap the var map back, now we are back on track.
std::swap(var_map_, new_vmap);
std::swap(function_, f);
builder_->SetInsertPoint(init_end);
}
void CodeGenLLVM::CreateSerialFor(llvm::Value* begin,
llvm::Value* end,
llvm::Value* stride,
......@@ -1026,9 +465,6 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const Call* op) {
return builder_->CreateLShr(
MakeValue(op->args[0]), MakeValue(op->args[1]));
}
} else if (op->is_intrinsic(intrinsic::tvm_throw_last_error)) {
builder_->CreateRet(ConstInt32(-1));
return ConstInt32(-1);
} else if (op->is_intrinsic(intrinsic::tvm_address_of)) {
const Load *l = op->args[0].as<Load>();
CHECK(op->args.size() == 1 && l);
......@@ -1066,46 +502,6 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const Call* op) {
phi->addIncoming(then_value, then_block);
phi->addIncoming(else_value, else_block);
return phi;
} else if (op->is_intrinsic(intrinsic::tvm_struct_get)) {
CHECK_EQ(op->args.size(), 3U);
int kind = op->args[2].as<IntImm>()->value;
llvm::Value* ref = this->CreateStructRefPtr(
op->type, MakeValue(op->args[0]),
MakeValue(op->args[1]), kind);
if (kind == intrinsic::kArrAddr) {
return builder_->CreatePointerCast(ref, t_void_p_);
} else {
return builder_->CreateLoad(ref);
}
} else if (op->is_intrinsic(intrinsic::tvm_struct_set)) {
CHECK_EQ(op->args.size(), 4U);
int kind = op->args[2].as<IntImm>()->value;
llvm::Value* value = MakeValue(op->args[3]);
llvm::Value* ref = this->CreateStructRefPtr(
op->args[3].type(), MakeValue(op->args[0]),
MakeValue(op->args[1]), kind);
CHECK(kind != intrinsic::kArrAddr);
if (value->getType()->isPointerTy()) {
value = builder_->CreatePointerCast(
value, ref->getType()->getPointerElementType());
}
builder_->CreateStore(value, ref);
return ConstInt32(0);
} else if (op->is_intrinsic(intrinsic::tvm_stack_alloca)) {
CHECK_EQ(op->args.size(), 2U);
const std::string& type = op->args[0].as<StringImm>()->value;
llvm::Value* num = MakeValue(op->args[1]);
if (type == "shape") {
return builder_->CreateAlloca(t_tvm_shape_index_, num);
} else if (type == "arg_value") {
return builder_->CreateAlloca(t_tvm_value_, num);
} else if (type == "arg_tcode") {
return builder_->CreateAlloca(t_int_, num);
} else if (type == "array") {
return builder_->CreateAlloca(t_tvm_array_, num);
} else {
LOG(FATAL) << "Unknown stack alloca type " << type;
}
} else if (op->is_intrinsic(Call::reinterpret) && is_zero(op->args[0])) {
return llvm::Constant::getNullValue(t_void_p_);
} else {
......@@ -1594,9 +990,7 @@ void CodeGenLLVM::VisitStmt_(const Store* op) {
}
llvm::Value* CodeGenLLVM::VisitExpr_(const Call* op) {
if (op->is_intrinsic(intrinsic::tvm_call_packed_lowered)) {
return CreateCallPacked(op);
} else if (op->call_type == Call::Intrinsic ||
if (op->call_type == Call::Intrinsic ||
op->call_type == Call::PureIntrinsic) {
return CreateIntrinsic(op);
} else {
......@@ -1614,40 +1008,6 @@ void CodeGenLLVM::VisitStmt_(const For* op) {
ConstInt32(1),
op->loop_var,
op->body);
} else if (op->for_type == ForType::Parallel) {
if (parallel_env_.penv == nullptr) {
CreateParallelLaunch(
For::make(
op->loop_var, op->min, op->extent,
op->for_type, op->device_api, op->body), 0);
} else {
// already in parallel env.
CHECK(parallel_env_.task_id.defined());
CHECK(parallel_env_.num_task.defined());
CHECK(parallel_env_.penv != nullptr);
Type t = op->extent.type();
Expr num_task = cast(t, parallel_env_.num_task);
Expr task_id = cast(t, parallel_env_.task_id);
CHECK(!parallel_env_.hit_parallel_loop)
<< "Nested parallel loop is not supported by threadpool, try fuse them instead";
parallel_env_.hit_parallel_loop = true;
if (parallel_env_.stride_pattern) {
CreateSerialFor(MakeValue(task_id),
MakeValue(op->extent),
MakeValue(num_task),
op->loop_var,
op->body);
} else {
Expr step = (op->extent + num_task - make_const(t, 1)) / num_task;
Expr begin = Min::make(task_id * step, op->extent);
Expr end = Min::make((task_id + make_const(t, 1)) * step, op->extent);
CreateSerialFor(MakeValue(begin),
MakeValue(end),
ConstInt32(1),
op->loop_var,
op->body);
}
}
} else {
LOG(FATAL) << "cannot handle for type " << op->for_type;
}
......@@ -1727,58 +1087,12 @@ void CodeGenLLVM::VisitStmt_(const AttrStmt* op) {
alloc_storage_info_[v].alignment =
static_cast<int>(op->value.as<IntImm>()->value);
this->VisitStmt(op->body);
} else if (op->attr_key == ir::attr::coproc_uop_scope) {
this->CreateStaticInit(op->value.as<StringImm>()->value, op->body);
} else if (op->attr_key == ir::attr::compute_scope) {
this->CreateComputeScope(op);
} else if (op->attr_key == ir::attr::pragma_scope) {
const std::string& pname = op->value.as<StringImm>()->value;
if (pname == "parallel_stride_pattern") {
CHECK(parallel_env_.penv != nullptr)
<< "Pragma parallel_stride_pattern only valid in parallel launch";
parallel_env_.stride_pattern = true;
this->VisitStmt(op->body);
} else if (pname == "parallel_launch_point") {
CreateParallelLaunch(op->body, 0);
} else if (pname == "parallel_barrier_when_finish") {
CHECK(parallel_env_.penv != nullptr)
<< "Cannot run barrier without parallel environment";
CHECK(!parallel_env_.hit_parallel_loop)
<< "Cannot not place within parallel loop as the workload may differ, "
<< " place it between parallel and parallel_launch_point";
this->VisitStmt(op->body);
builder_->CreateCall(
RuntimeTVMParallelBarrier(),
{MakeValue(parallel_env_.task_id), parallel_env_.penv});
} else {
LOG(WARNING) << "Unknown pragma " << pname;
this->VisitStmt(op->body);
}
} else {
this->VisitStmt(op->body);
}
}
void CodeGenLLVM::VisitStmt_(const AssertStmt* op) {
using llvm::BasicBlock;
llvm::Value* cond = MakeValue(op->condition);
std::ostringstream os;
os << "Assert fail: " << op->condition;
if (op->message.as<StringImm>()) {
os << ", " << op->message.as<StringImm>()->value;
}
llvm::Value* msg = GetConstString(os.str());
BasicBlock* fail_block = BasicBlock::Create(
*ctx_, "assert_fail", function_);
BasicBlock* end_block = BasicBlock::Create(
*ctx_, "assert_end", function_);
builder_->CreateCondBr(cond, end_block, fail_block, md_very_likely_branch_);
// fail condition.
builder_->SetInsertPoint(fail_block);
builder_->CreateCall(RuntimeTVMAPISetLastError(), {msg});
builder_->CreateRet(ConstInt32(-1));
// otherwise set it to be new end.
builder_->SetInsertPoint(end_block);
// Detect useful invariant pattern and use them to visit child.
// Pattern: Var % const == 0
// TODO(tqchen) move these pattern to a generic scope info visitor.
......@@ -1819,13 +1133,16 @@ void CodeGenLLVM::VisitStmt_(const LetStmt* op) {
align_map_[op->var.get()] = arith::EvalModular(op->value, align_map_);
this->VisitStmt(op->body);
}
void CodeGenLLVM::VisitStmt_(const Block* op) {
VisitStmt(op->first);
if (op->rest.defined()) VisitStmt(op->rest);
}
// Generate code for an Evaluate statement: the expression is emitted for
// its effects and the resulting value is discarded.
void CodeGenLLVM::VisitStmt_(const Evaluate *op) {
MakeValue(op->value);
}
void CodeGenLLVM::VisitStmt_(const ProducerConsumer* op) {
VisitStmt(op->body);
}
......
......@@ -44,7 +44,7 @@ class CodeGenLLVM :
* \param dynamic_lookup Whether dynamically lookup runtime function
* or use the runtime function table passed by caller.
*/
void Init(const std::string& module_name,
virtual void Init(const std::string& module_name,
llvm::TargetMachine* tm,
llvm::LLVMContext* ctx,
bool system_lib,
......@@ -53,17 +53,17 @@ class CodeGenLLVM :
* \brief Compile and add function f to the current module.
* \param f The function to be added.
*/
void AddFunction(const LoweredFunc& f);
virtual void AddFunction(const LoweredFunc& f);
/*!
* \brief Add main function as the entry name
* \param entry_func_name The name of entry function to be added.
*/
void AddMainFunction(const std::string& entry_func_name);
virtual void AddMainFunction(const std::string& entry_func_name);
/*!
* \brief Finish current pass of codegen, get the module.
* \return the created module.
*/
std::unique_ptr<llvm::Module> Finish();
virtual std::unique_ptr<llvm::Module> Finish();
/*!
* \brief Create Value for expression e
* \param e The expression to be created value for.
......@@ -120,8 +120,6 @@ class CodeGenLLVM :
virtual llvm::Value* CreateIntrinsic(const Call* op);
// create extern function call
virtual llvm::Value* CreateCallExtern(const Call* op);
// create call into tvm packed function.
virtual llvm::Value* CreateCallPacked(const Call* op);
// Scalarize e by iterating elements of e.
// f is a callback that takes index and v.
virtual void Scalarize(const Expr& e,
......@@ -134,6 +132,14 @@ class CodeGenLLVM :
/*! \brief The alignment of allocation */
int alignment{0};
};
// Overridable hooks: each of the following is virtual so that a derived
// code generator can replace it.
// Initialize target-specific state from the given target machine.
virtual void InitTarget(llvm::TargetMachine* tm);
// Add module startup function if needed.
// The default implementation does nothing.
virtual void AddStartupFunction() {}
// Apply optimization on the module.
virtual void Optimize();
// Get the maximum storage align bits of buffer pointer given storage scope.
virtual int NativeVectorBits(const std::string& storage_scope) const;
/*!
* \param t The original type.
* \return LLVM type of t
......@@ -145,15 +151,36 @@ class CodeGenLLVM :
void GetAlignment(
Type t, const Variable* buf_var, const Expr& index,
int* p_alignment, int* p_native_bits);
// Get constant string
llvm::Value* GetConstString(const std::string& str);
// do a scalarize call with f
llvm::Value* CreateScalarizedCall(
const Call* op, llvm::Function* f, const std::vector<llvm::Value*>& args);
// Initialize target
virtual void InitTarget(llvm::TargetMachine* tm);
// apply optimization on the module.
virtual void Optimize();
// Get the maximum storage align bits of buffer pointer given storage scope.
virtual int NativeVectorBits(const std::string& storage_scope) const;
// cast operator
llvm::Value* CreateCast(Type from, Type to, llvm::Value* value);
// comparison op
llvm::Value* GetVarValue(const Variable* v) const;
llvm::Value* CreateLT(Type t, llvm::Value* a, llvm::Value* b);
llvm::Value* CreateLE(Type t, llvm::Value* a, llvm::Value* b);
llvm::Value* CreateGT(Type t, llvm::Value* a, llvm::Value* b);
llvm::Value* CreateGE(Type t, llvm::Value* a, llvm::Value* b);
llvm::Value* CreateAdd(Type t, llvm::Value* a, llvm::Value* b);
llvm::Value* CreateSub(Type t, llvm::Value* a, llvm::Value* b);
llvm::Value* CreateMul(Type t, llvm::Value* a, llvm::Value* b);
llvm::Value* CreateBroadcast(llvm::Value* value, int lanes);
llvm::Value* CreateBufferPtr(Type t, llvm::Value* buffer, llvm::Value* index);
// Vector concatenation.
llvm::Value* CreateVecSlice(llvm::Value* vec, int begin, int extent);
llvm::Value* CreateVecFlip(llvm::Value* vec);
llvm::Value* CreateVecConcat(std::vector<llvm::Value*> vecs);
llvm::Value* CreateVecPad(llvm::Value* vec, int target_lanes);
// Create serial for
void CreateSerialFor(llvm::Value* begin,
llvm::Value* end,
llvm::Value* stride,
const VarExpr& loop_var, const Stmt& body);
// add alias information.
void AddAliasInfo(llvm::Instruction* load, const Variable* buffer, Expr index, Type type);
// The IRBuilder.
using IRBuilder = llvm::IRBuilder<llvm::ConstantFolder, llvm::IRBuilderDefaultInserter>;
// The current function
......@@ -177,129 +204,25 @@ class CodeGenLLVM :
llvm::Type* t_int32_{nullptr};
llvm::Type* t_int64_{nullptr};
llvm::Type* t_float64_{nullptr};
// branch
// meta data
llvm::MDNode* md_very_likely_branch_{nullptr};
llvm::MDNode* md_tbaa_root_{nullptr};
llvm::MDNode* md_tbaa_alias_set_{nullptr};
llvm::MDNode* md_tbaa_ctx_ptr_{nullptr};
// TVM related data types
llvm::Type* t_tvm_shape_index_{nullptr};
llvm::Type* t_tvm_func_handle_{nullptr};
llvm::StructType* t_tvm_context_{nullptr};
llvm::StructType* t_tvm_type_{nullptr};
llvm::StructType* t_tvm_array_{nullptr};
llvm::StructType* t_tvm_value_{nullptr};
llvm::StructType* t_tvm_parallel_group_env_{nullptr};
llvm::FunctionType* ftype_tvm_parallel_lambda_{nullptr};
llvm::FunctionType* ftype_tvm_func_call_{nullptr};
llvm::FunctionType* ftype_tvm_get_func_from_env_{nullptr};
llvm::FunctionType* ftype_tvm_api_set_last_error_{nullptr};
llvm::FunctionType* ftype_tvm_parallel_launch_{nullptr};
llvm::FunctionType* ftype_tvm_parallel_barrier_{nullptr};
llvm::FunctionType* ftype_tvm_register_system_symbol_{nullptr};
// Lazy entry for function call.
llvm::FunctionType* ftype_tvm_static_init_callback_{nullptr};
llvm::FunctionType* ftype_tvm_static_init_{nullptr};
// The acting body
llvm::BasicBlock* block_{nullptr};
/*! \brief native vector bits of current target */
int native_vector_bits_{0};
/*! \brief the storage scope of allocation */
std::unordered_map<const Variable*, StorageInfo> alloc_storage_info_;
private:
// The parallel group information.
// A default-constructed ParallelEnv has both flags false and penv null.
struct ParallelEnv {
// NOTE(review): presumably the variable carrying the current task index -- confirm against users.
VarExpr task_id;
// NOTE(review): presumably the variable carrying the total task count -- confirm against users.
VarExpr num_task;
// Defaults to false.
bool stride_pattern{false};
// Defaults to false.
bool hit_parallel_loop{false};
// Defaults to nullptr.
// NOTE(review): looks like the LLVM value for the parallel environment argument -- confirm.
llvm::Value* penv{nullptr};
};
// Get runtime functions
llvm::GlobalVariable* InitContextPtr(llvm::Type* type, std::string name);
llvm::Value* GetContextPtr(llvm::GlobalVariable* gv);
llvm::Value* RuntimeTVMFuncCall();
llvm::Value* RuntimeTVMGetFuncFromEnv();
llvm::Value* RuntimeTVMAPISetLastError();
llvm::Value* RuntimeTVMParallelLaunch();
llvm::Value* RuntimeTVMParallelBarrier();
// comparison op
llvm::Value* GetVarValue(const Variable* v) const;
llvm::Value* CreateLT(Type t, llvm::Value* a, llvm::Value* b);
llvm::Value* CreateLE(Type t, llvm::Value* a, llvm::Value* b);
llvm::Value* CreateGT(Type t, llvm::Value* a, llvm::Value* b);
llvm::Value* CreateGE(Type t, llvm::Value* a, llvm::Value* b);
llvm::Value* CreateAdd(Type t, llvm::Value* a, llvm::Value* b);
llvm::Value* CreateSub(Type t, llvm::Value* a, llvm::Value* b);
llvm::Value* CreateMul(Type t, llvm::Value* a, llvm::Value* b);
llvm::Value* CreateBroadcast(llvm::Value* value, int lanes);
llvm::Value* GetConstString(const std::string& str);
llvm::Value* CreateBufferPtr(Type t, llvm::Value* buffer, llvm::Value* index);
llvm::Value* CreateStructRefPtr(Type t, llvm::Value* buffer, llvm::Value* index, int kind);
llvm::Value* CreateCast(Type from, Type to, llvm::Value* value);
llvm::Value* GetPackedFuncHandle(const std::string& str);
// Vector concatenation.
llvm::Value* CreateVecSlice(llvm::Value* vec, int begin, int extent);
llvm::Value* CreateVecFlip(llvm::Value* vec);
llvm::Value* CreateVecConcat(std::vector<llvm::Value*> vecs);
llvm::Value* CreateVecPad(llvm::Value* vec, int target_lanes);
llvm::Value* PackClosureData(const Array<Var>& fields);
void UnpackClosureData(llvm::Value*cdata,
const Array<Var>& fields,
std::unordered_map<const Variable*, llvm::Value*>* vmap);
// Create static initialization
void CreateStaticInit(const std::string& init_fname, const Stmt& body);
// Create parallel launch
void CreateParallelLaunch(const Stmt& body, int num_task);
// Create serial for
void CreateSerialFor(llvm::Value* begin,
llvm::Value* end,
llvm::Value* stride,
const VarExpr& loop_var, const Stmt& body);
// Create a new compute scope.
void CreateComputeScope(const AttrStmt* op);
// Check if the call to packed function is successful
// if not directly finalize function and pass on return code.
// return the end block after the check
llvm::BasicBlock* CheckCallSuccess(llvm::Value* retcode);
// Add a function to set global module context
void InitGlobalContext(bool dynamic_lookup);
// Add module startup function if needed.
void AddStartupFunction();
// add alias information.
void AddAliasInfo(llvm::Instruction* load, const Variable* buffer, Expr index, Type type);
// The definition of local variable.
std::unordered_map<const Variable*, llvm::Value*> var_map_;
// global strings
std::unordered_map<std::string, llvm::Constant*> str_map_;
// The alignment information
std::unordered_map<const Variable*, arith::ModularEntry> align_map_;
// Whether current function is restricted
bool is_restricted_{true};
// The alignment information
std::unordered_map<const Variable*, arith::ModularEntry> align_map_;
// set of var that are not restricted(can alias)
std::unordered_set<const Variable*> alias_var_set_;
// Context for injection lookup
llvm::GlobalVariable* gv_mod_ctx_{nullptr};
llvm::GlobalVariable* gv_tvm_func_call_{nullptr};
llvm::GlobalVariable* gv_tvm_get_func_from_env_{nullptr};
llvm::GlobalVariable* gv_tvm_api_set_last_error_{nullptr};
llvm::GlobalVariable* gv_tvm_parallel_launch_{nullptr};
llvm::GlobalVariable* gv_tvm_parallel_barrier_{nullptr};
std::unordered_map<std::string, llvm::GlobalVariable*> gv_func_map_;
// context for direct dynamic lookup
llvm::Function* f_tvm_func_call_{nullptr};
llvm::Function* f_tvm_get_func_from_env_{nullptr};
llvm::Function* f_tvm_api_set_last_error_{nullptr};
llvm::Function* f_tvm_parallel_launch_{nullptr};
llvm::Function* f_tvm_parallel_barrier_{nullptr};
llvm::Function* f_tvm_register_system_symbol_{nullptr};
// Current parallel environment scope.
ParallelEnv parallel_env_;
// global to packed function handle
std::unordered_map<std::string, llvm::GlobalVariable*> func_handle_map_;
// List of symbols to be exported to TVM system lib.
std::vector<std::pair<std::string, llvm::Value*> > export_system_symbols_;
};
} // namespace codegen
} // namespace tvm
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment