Commit f433373d by Tianqi Chen Committed by GitHub

[CODEGEN] Generate main compute function separately with alias info (#253)

parent 8d241b9d
......@@ -144,6 +144,11 @@ constexpr const char* thread_extent = "thread_extent";
constexpr const char* virtual_thread = "virtual_thread";
/*! \brief Mark the scope as volatile access for certain handle. */
constexpr const char* volatile_scope = "volatile_scope";
/*!
* \brief Mark the scope where computation starts to happen.
* This can hint some code generators to create a new function for the compute.
*/
constexpr const char* compute_scope = "compute_scope";
/*! \brief Mark storage scope of buffers */
constexpr const char* storage_scope = "storage_scope";
/*! \brief Mark storage alignment requirement of buffers */
......
......@@ -174,8 +174,16 @@ void CodeGenLLVM::AddFunction(const LoweredFunc& f) {
// setup the function.
function_ = llvm::cast<llvm::Function>(module_->getOrInsertFunction(f->name, ftype));
function_->setCallingConv(llvm::CallingConv::C);
size_t idx = 0;
// set handle argument to be non alias.
if (is_restricted_) {
for (size_t i = 0; i < f->args.size(); ++i) {
if (f->args[i].type().is_handle()) {
function_->setDoesNotAlias(i + 1);
}
}
}
size_t idx = 0;
for (auto it = function_->arg_begin();
it != function_->arg_end(); ++it, ++idx) {
llvm::Argument* v = &(*it);
......@@ -649,6 +657,54 @@ llvm::Value* CodeGenLLVM::GetConstString(const std::string& str) {
}
}
void CodeGenLLVM::CreateComputeScope(const AttrStmt* op) {
// There are two reasons why we create another function for compute_scope
// - Make sure the generated compute function is clearly separately(though it can get inlined)
// - Set noalias on all the pointer arguments, some of them are loaded from TVMArgs.
// This is easier than set the alias scope manually.
using llvm::BasicBlock;
Array<Var> vargs = ir::UndefinedVars(op->body, {});
std::vector<llvm::Value*> arg_values;
std::vector<llvm::Type*> arg_types;
for (Var v : vargs) {
llvm::Value* value = MakeValue(v);
arg_values.push_back(value);
arg_types.push_back(value->getType());
}
llvm::FunctionType* ftype =
llvm::FunctionType::get(t_int_, arg_types, false);
llvm::Function* fcompute =
llvm::Function::Create(ftype,
llvm::Function::PrivateLinkage,
op->value.as<StringImm>()->value,
module_.get());
BasicBlock* compute_call_end = CheckCallSuccess(
builder_->CreateCall(fcompute, arg_values));
// setup compute fuinction.
std::unordered_map<const Variable*, llvm::Value*> new_vmap;
size_t idx = 0;
for (auto it = fcompute->arg_begin();
it != fcompute->arg_end(); ++it, ++idx) {
llvm::Argument* v = &(*it);
const Var& var = vargs[idx];
new_vmap[var.get()] = v;
if (var.type().is_handle() && !alias_var_set_.count(var.get())) {
// set non alias.
fcompute->setDoesNotAlias(idx + 1);
}
}
std::swap(function_, fcompute);
std::swap(new_vmap, var_map_);
BasicBlock *compute_entry = BasicBlock::Create(*ctx_, "entry", function_);
builder_->SetInsertPoint(compute_entry);
this->VisitStmt(op->body);
builder_->CreateRet(ConstInt32(0));
// swap the var map back, now we are back on track.
std::swap(new_vmap, var_map_);
std::swap(function_, fcompute);
builder_->SetInsertPoint(compute_call_end);
}
void CodeGenLLVM::CreateParallelFor(const For* op) {
using llvm::BasicBlock;
llvm::Value* min = MakeValue(op->min);
......@@ -1429,6 +1485,8 @@ void CodeGenLLVM::VisitStmt_(const AttrStmt* op) {
alloc_storage_info_[v].alignment =
static_cast<int>(op->value.as<IntImm>()->value);
this->VisitStmt(op->body);
} else if (op->attr_key == ir::attr::compute_scope) {
this->CreateComputeScope(op);
} else {
this->VisitStmt(op->body);
}
......
......@@ -225,6 +225,8 @@ class CodeGenLLVM :
// Create serial for
void CreateSerialFor(llvm::Value* begin, llvm::Value* end,
const VarExpr& loop_var, const Stmt& body);
// Create a new compute scope.
void CreateComputeScope(const AttrStmt* op);
// Check if the call to packed function is successful
// if not directly finalize function and pass on return code.
// return the end block after the check
......
......@@ -149,6 +149,9 @@ LoweredFunc MakeAPI(Stmt body,
Int(32), intrinsic::tvm_call_packed,
{StringImm::make(runtime::symbol::tvm_set_device),
device_type, device_id}, Call::Intrinsic)));
body = AttrStmt::make(
make_zero(Int(32)), attr::compute_scope,
StringImm::make(name + "_compute_"), body);
body = Block::make(set_device, body);
}
n->body = MergeNest(
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment