Commit f433373d by Tianqi Chen Committed by GitHub

[CODEGEN] Generate main compute function separately with alias info (#253)

parent 8d241b9d
...@@ -144,6 +144,11 @@ constexpr const char* thread_extent = "thread_extent"; ...@@ -144,6 +144,11 @@ constexpr const char* thread_extent = "thread_extent";
constexpr const char* virtual_thread = "virtual_thread"; constexpr const char* virtual_thread = "virtual_thread";
/*! \brief Mark the scope as volatile access for certain handle. */ /*! \brief Mark the scope as volatile access for certain handle. */
constexpr const char* volatile_scope = "volatile_scope"; constexpr const char* volatile_scope = "volatile_scope";
/*!
* \brief Mark the scope where computation starts to happen.
* This can hint some code generators to create a new function for the compute.
*/
constexpr const char* compute_scope = "compute_scope";
/*! \brief Mark storage scope of buffers */ /*! \brief Mark storage scope of buffers */
constexpr const char* storage_scope = "storage_scope"; constexpr const char* storage_scope = "storage_scope";
/*! \brief Mark storage alignment requirement of buffers */ /*! \brief Mark storage alignment requirement of buffers */
......
...@@ -174,8 +174,16 @@ void CodeGenLLVM::AddFunction(const LoweredFunc& f) { ...@@ -174,8 +174,16 @@ void CodeGenLLVM::AddFunction(const LoweredFunc& f) {
// setup the function. // setup the function.
function_ = llvm::cast<llvm::Function>(module_->getOrInsertFunction(f->name, ftype)); function_ = llvm::cast<llvm::Function>(module_->getOrInsertFunction(f->name, ftype));
function_->setCallingConv(llvm::CallingConv::C); function_->setCallingConv(llvm::CallingConv::C);
size_t idx = 0; // set handle argument to be non alias.
if (is_restricted_) {
for (size_t i = 0; i < f->args.size(); ++i) {
if (f->args[i].type().is_handle()) {
function_->setDoesNotAlias(i + 1);
}
}
}
size_t idx = 0;
for (auto it = function_->arg_begin(); for (auto it = function_->arg_begin();
it != function_->arg_end(); ++it, ++idx) { it != function_->arg_end(); ++it, ++idx) {
llvm::Argument* v = &(*it); llvm::Argument* v = &(*it);
...@@ -649,6 +657,54 @@ llvm::Value* CodeGenLLVM::GetConstString(const std::string& str) { ...@@ -649,6 +657,54 @@ llvm::Value* CodeGenLLVM::GetConstString(const std::string& str) {
} }
} }
void CodeGenLLVM::CreateComputeScope(const AttrStmt* op) {
// There are two reasons why we create another function for compute_scope
// - Make sure the generated compute function is clearly separately(though it can get inlined)
// - Set noalias on all the pointer arguments, some of them are loaded from TVMArgs.
// This is easier than set the alias scope manually.
using llvm::BasicBlock;
Array<Var> vargs = ir::UndefinedVars(op->body, {});
std::vector<llvm::Value*> arg_values;
std::vector<llvm::Type*> arg_types;
for (Var v : vargs) {
llvm::Value* value = MakeValue(v);
arg_values.push_back(value);
arg_types.push_back(value->getType());
}
llvm::FunctionType* ftype =
llvm::FunctionType::get(t_int_, arg_types, false);
llvm::Function* fcompute =
llvm::Function::Create(ftype,
llvm::Function::PrivateLinkage,
op->value.as<StringImm>()->value,
module_.get());
BasicBlock* compute_call_end = CheckCallSuccess(
builder_->CreateCall(fcompute, arg_values));
// setup compute fuinction.
std::unordered_map<const Variable*, llvm::Value*> new_vmap;
size_t idx = 0;
for (auto it = fcompute->arg_begin();
it != fcompute->arg_end(); ++it, ++idx) {
llvm::Argument* v = &(*it);
const Var& var = vargs[idx];
new_vmap[var.get()] = v;
if (var.type().is_handle() && !alias_var_set_.count(var.get())) {
// set non alias.
fcompute->setDoesNotAlias(idx + 1);
}
}
std::swap(function_, fcompute);
std::swap(new_vmap, var_map_);
BasicBlock *compute_entry = BasicBlock::Create(*ctx_, "entry", function_);
builder_->SetInsertPoint(compute_entry);
this->VisitStmt(op->body);
builder_->CreateRet(ConstInt32(0));
// swap the var map back, now we are back on track.
std::swap(new_vmap, var_map_);
std::swap(function_, fcompute);
builder_->SetInsertPoint(compute_call_end);
}
void CodeGenLLVM::CreateParallelFor(const For* op) { void CodeGenLLVM::CreateParallelFor(const For* op) {
using llvm::BasicBlock; using llvm::BasicBlock;
llvm::Value* min = MakeValue(op->min); llvm::Value* min = MakeValue(op->min);
...@@ -1429,6 +1485,8 @@ void CodeGenLLVM::VisitStmt_(const AttrStmt* op) { ...@@ -1429,6 +1485,8 @@ void CodeGenLLVM::VisitStmt_(const AttrStmt* op) {
alloc_storage_info_[v].alignment = alloc_storage_info_[v].alignment =
static_cast<int>(op->value.as<IntImm>()->value); static_cast<int>(op->value.as<IntImm>()->value);
this->VisitStmt(op->body); this->VisitStmt(op->body);
} else if (op->attr_key == ir::attr::compute_scope) {
this->CreateComputeScope(op);
} else { } else {
this->VisitStmt(op->body); this->VisitStmt(op->body);
} }
......
...@@ -225,6 +225,8 @@ class CodeGenLLVM : ...@@ -225,6 +225,8 @@ class CodeGenLLVM :
// Create serial for // Create serial for
void CreateSerialFor(llvm::Value* begin, llvm::Value* end, void CreateSerialFor(llvm::Value* begin, llvm::Value* end,
const VarExpr& loop_var, const Stmt& body); const VarExpr& loop_var, const Stmt& body);
// Create a new compute scope.
void CreateComputeScope(const AttrStmt* op);
// Check if the call to packed function is successful // Check if the call to packed function is successful
// if not directly finalize function and pass on return code. // if not directly finalize function and pass on return code.
// return the end block after the check // return the end block after the check
......
...@@ -149,6 +149,9 @@ LoweredFunc MakeAPI(Stmt body, ...@@ -149,6 +149,9 @@ LoweredFunc MakeAPI(Stmt body,
Int(32), intrinsic::tvm_call_packed, Int(32), intrinsic::tvm_call_packed,
{StringImm::make(runtime::symbol::tvm_set_device), {StringImm::make(runtime::symbol::tvm_set_device),
device_type, device_id}, Call::Intrinsic))); device_type, device_id}, Call::Intrinsic)));
body = AttrStmt::make(
make_zero(Int(32)), attr::compute_scope,
StringImm::make(name + "_compute_"), body);
body = Block::make(set_device, body); body = Block::make(set_device, body);
} }
n->body = MergeNest( n->body = MergeNest(
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment