Commit b07ceff5 by Tianqi Chen Committed by GitHub

[CODEGEN] Enable closure with no argument (#635)

parent f1aabedc
...@@ -337,7 +337,11 @@ void CodeGenCPU::CreateComputeScope(const AttrStmt* op) { ...@@ -337,7 +337,11 @@ void CodeGenCPU::CreateComputeScope(const AttrStmt* op) {
builder_->SetInsertPoint(compute_call_end); builder_->SetInsertPoint(compute_call_end);
} }
llvm::Value* CodeGenCPU::PackClosureData(const Array<Var>& vfields) { llvm::Value* CodeGenCPU::PackClosureData(const Array<Var>& vfields, uint64_t* num_bytes) {
if (vfields.size() == 0) {
*num_bytes = 0U;
return llvm::Constant::getNullValue(t_void_p_);
}
std::vector<llvm::Type*> fields; std::vector<llvm::Type*> fields;
for (Var v : vfields) { for (Var v : vfields) {
auto it = var_map_.find(v.get()); auto it = var_map_.find(v.get());
...@@ -352,6 +356,8 @@ llvm::Value* CodeGenCPU::PackClosureData(const Array<Var>& vfields) { ...@@ -352,6 +356,8 @@ llvm::Value* CodeGenCPU::PackClosureData(const Array<Var>& vfields) {
var_map_.at(vfields[i].get()), var_map_.at(vfields[i].get()),
builder_->CreateInBoundsGEP(cdata, {zero, ConstInt32(i)})); builder_->CreateInBoundsGEP(cdata, {zero, ConstInt32(i)}));
} }
*num_bytes = data_layout_->getTypeAllocSize(
llvm::cast<llvm::PointerType>(cdata->getType())->getElementType());
return cdata; return cdata;
} }
...@@ -374,7 +380,8 @@ void CodeGenCPU::CreateParallelLaunch(const Stmt& body, int num_task) { ...@@ -374,7 +380,8 @@ void CodeGenCPU::CreateParallelLaunch(const Stmt& body, int num_task) {
"__tvm_parallel_lambda", module_.get()); "__tvm_parallel_lambda", module_.get());
// allocate and setup the closure, call the closure. // allocate and setup the closure, call the closure.
Array<Var> vfields = ir::UndefinedVars(body, {}); Array<Var> vfields = ir::UndefinedVars(body, {});
llvm::Value* cdata = PackClosureData(vfields); uint64_t nbytes;
llvm::Value* cdata = PackClosureData(vfields, &nbytes);
BasicBlock* par_launch_end = CheckCallSuccess( BasicBlock* par_launch_end = CheckCallSuccess(
builder_->CreateCall( builder_->CreateCall(
RuntimeTVMParallelLaunch(), RuntimeTVMParallelLaunch(),
...@@ -431,14 +438,13 @@ void CodeGenCPU::CreateStaticInit(const std::string& init_fname, const Stmt& bod ...@@ -431,14 +438,13 @@ void CodeGenCPU::CreateStaticInit(const std::string& init_fname, const Stmt& bod
ftype_tvm_static_init_, llvm::Function::ExternalLinkage, init_fname, module_.get()); ftype_tvm_static_init_, llvm::Function::ExternalLinkage, init_fname, module_.get());
} }
// allocate and setup the closure, call the closure. // allocate and setup the closure, call the closure.
uint64_t nbytes;
Array<Var> vfields = ir::UndefinedVars(body, {}); Array<Var> vfields = ir::UndefinedVars(body, {});
llvm::Value* cdata = PackClosureData(vfields); llvm::Value* cdata = PackClosureData(vfields, &nbytes);
llvm::Value* nbytes = ConstInt32(data_layout_->getTypeAllocSize(
llvm::cast<llvm::PointerType>(cdata->getType())->getElementType()));
BasicBlock* init_end = CheckCallSuccess( BasicBlock* init_end = CheckCallSuccess(
builder_->CreateCall( builder_->CreateCall(
finit, finit,
{gv, f, builder_->CreatePointerCast(cdata, t_void_p_), nbytes})); {gv, f, builder_->CreatePointerCast(cdata, t_void_p_), ConstInt32(nbytes)}));
// Setup the closure function. // Setup the closure function.
BasicBlock *lambda_entry = BasicBlock::Create(*ctx_, "entry", f); BasicBlock *lambda_entry = BasicBlock::Create(*ctx_, "entry", f);
builder_->SetInsertPoint(lambda_entry); builder_->SetInsertPoint(lambda_entry);
......
...@@ -73,7 +73,7 @@ class CodeGenCPU : public CodeGenLLVM { ...@@ -73,7 +73,7 @@ class CodeGenCPU : public CodeGenLLVM {
llvm::Value* RuntimeTVMParallelLaunch(); llvm::Value* RuntimeTVMParallelLaunch();
llvm::Value* RuntimeTVMParallelBarrier(); llvm::Value* RuntimeTVMParallelBarrier();
llvm::Value* GetPackedFuncHandle(const std::string& str); llvm::Value* GetPackedFuncHandle(const std::string& str);
llvm::Value* PackClosureData(const Array<Var>& fields); llvm::Value* PackClosureData(const Array<Var>& fields, uint64_t *num_bytes);
llvm::Value* CreateStructRefPtr(Type t, llvm::Value* buffer, llvm::Value* index, int kind); llvm::Value* CreateStructRefPtr(Type t, llvm::Value* buffer, llvm::Value* index, int kind);
void UnpackClosureData(llvm::Value*cdata, void UnpackClosureData(llvm::Value*cdata,
const Array<Var>& fields, const Array<Var>& fields,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment