Commit b07ceff5 by Tianqi Chen Committed by GitHub

[CODEGEN] Enable closure with no argument (#635)

parent f1aabedc
......@@ -337,7 +337,11 @@ void CodeGenCPU::CreateComputeScope(const AttrStmt* op) {
builder_->SetInsertPoint(compute_call_end);
}
llvm::Value* CodeGenCPU::PackClosureData(const Array<Var>& vfields) {
llvm::Value* CodeGenCPU::PackClosureData(const Array<Var>& vfields, uint64_t* num_bytes) {
if (vfields.size() == 0) {
*num_bytes = 0U;
return llvm::Constant::getNullValue(t_void_p_);
}
std::vector<llvm::Type*> fields;
for (Var v : vfields) {
auto it = var_map_.find(v.get());
......@@ -352,6 +356,8 @@ llvm::Value* CodeGenCPU::PackClosureData(const Array<Var>& vfields) {
var_map_.at(vfields[i].get()),
builder_->CreateInBoundsGEP(cdata, {zero, ConstInt32(i)}));
}
*num_bytes = data_layout_->getTypeAllocSize(
llvm::cast<llvm::PointerType>(cdata->getType())->getElementType());
return cdata;
}
......@@ -374,7 +380,8 @@ void CodeGenCPU::CreateParallelLaunch(const Stmt& body, int num_task) {
"__tvm_parallel_lambda", module_.get());
// allocate and setup the closure, call the closure.
Array<Var> vfields = ir::UndefinedVars(body, {});
llvm::Value* cdata = PackClosureData(vfields);
uint64_t nbytes;
llvm::Value* cdata = PackClosureData(vfields, &nbytes);
BasicBlock* par_launch_end = CheckCallSuccess(
builder_->CreateCall(
RuntimeTVMParallelLaunch(),
......@@ -431,14 +438,13 @@ void CodeGenCPU::CreateStaticInit(const std::string& init_fname, const Stmt& bod
ftype_tvm_static_init_, llvm::Function::ExternalLinkage, init_fname, module_.get());
}
// allocate and setup the closure, call the closure.
uint64_t nbytes;
Array<Var> vfields = ir::UndefinedVars(body, {});
llvm::Value* cdata = PackClosureData(vfields);
llvm::Value* nbytes = ConstInt32(data_layout_->getTypeAllocSize(
llvm::cast<llvm::PointerType>(cdata->getType())->getElementType()));
llvm::Value* cdata = PackClosureData(vfields, &nbytes);
BasicBlock* init_end = CheckCallSuccess(
builder_->CreateCall(
finit,
{gv, f, builder_->CreatePointerCast(cdata, t_void_p_), nbytes}));
{gv, f, builder_->CreatePointerCast(cdata, t_void_p_), ConstInt32(nbytes)}));
// Setup the closure function.
BasicBlock *lambda_entry = BasicBlock::Create(*ctx_, "entry", f);
builder_->SetInsertPoint(lambda_entry);
......
......@@ -73,7 +73,7 @@ class CodeGenCPU : public CodeGenLLVM {
llvm::Value* RuntimeTVMParallelLaunch();
llvm::Value* RuntimeTVMParallelBarrier();
llvm::Value* GetPackedFuncHandle(const std::string& str);
llvm::Value* PackClosureData(const Array<Var>& fields);
llvm::Value* PackClosureData(const Array<Var>& fields, uint64_t *num_bytes);
llvm::Value* CreateStructRefPtr(Type t, llvm::Value* buffer, llvm::Value* index, int kind);
void UnpackClosureData(llvm::Value*cdata,
const Array<Var>& fields,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment