Commit b40d43c4 by Tianqi Chen Committed by GitHub

[PASS][RUNTIME] Support attr scope lift and runonce (#303)

parent 7d67e473
......@@ -145,6 +145,11 @@ constexpr const char* thread_extent = "thread_extent";
constexpr const char* virtual_thread = "virtual_thread";
/*! \brief Mark region is processed by a co-proccesor */
constexpr const char* coproc_scope = "coproc_scope";
/*!
* \brief Mark region creates coprocessor micro ops,
* can be reused if corresponding variable is independent.
*/
constexpr const char* coproc_uop_scope = "coproc_uop_scope";
/*! \brief Mark the scope as volatile access for certain handle. */
constexpr const char* volatile_scope = "volatile_scope";
/*!
......
......@@ -258,6 +258,15 @@ Stmt LoopPartition(Stmt stmt);
Stmt CoProcSync(Stmt stmt);
/*!
* \brief Lift common attrs with attr_key to outer scope.
*
* \param stmt The stmt to be trasnformed
* \param attr_key The attribute key to be checked.
* \return Transformed stmt.
*/
Stmt LiftAttrScope(Stmt stmt, std::string attr_key);
/*!
* \brief Make an user callable API LoweredFunc.
*
* The main task of this function is to create code to :
......
......@@ -110,6 +110,23 @@ TVM_DLL int TVMBackendParallelLaunch(FTVMParallelLambda flambda,
*/
TVM_DLL int TVMBackendParallelBarrier(int task_id, TVMParallelGroupEnv* penv);
/*!
* \brief Simple static initialization fucntion.
* Run f once and set handle to be not null.
* This function is mainly used for test purpose.
*
* \param handle An global address to indicate f
* \param f The function to be ran
* \param cdata The closure data to pass to the function.
* \param nbytes Number of bytes in the closure data.
* \return 0 when no error is thrown, -1 when failure happens
*/
TVM_DLL int TVMBackendRunOnce(void** handle,
int (*f)(void*),
void *cdata,
int nbytes);
#ifdef __cplusplus
} // TVM_EXTERN_C
#endif
......
......@@ -24,13 +24,14 @@ class BuildConfig(object):
"""
current = None
defaults = {
'auto_unroll_max_step': 0,
'auto_unroll_min_depth': 1,
'unroll_explicit': True,
'detect_global_barrier': False,
'offset_factor': 0,
'data_alignment': -1,
'restricted_func': True
"auto_unroll_max_step": 0,
"auto_unroll_min_depth": 1,
"unroll_explicit": True,
"detect_global_barrier": False,
"offset_factor": 0,
"data_alignment": -1,
"restricted_func": True,
"add_lower_pass": None
}
def __init__(self, **kwargs):
self._old_scope = None
......@@ -94,6 +95,9 @@ def build_config(**kwargs):
not to overlap. This enables more optimization.
Corresponds to restricted keyword in C99
add_lower_pass: list of function(Stmt->Stmt), default=None
Additional lowering passes to be applied before make_api.
Returns
-------
config: BuildConfig
......@@ -200,6 +204,9 @@ def lower(sch,
cfg.auto_unroll_max_step,
cfg.auto_unroll_min_depth,
cfg.unroll_explicit)
if cfg.add_lower_pass:
for f in cfg.add_lower_pass:
stmt = f(stmt)
stmt = ir_pass.Simplify(stmt)
if simple_mode:
return stmt
......
......@@ -100,6 +100,7 @@ REGISTER_PASS1(InjectPrefetch);
REGISTER_PASS1(LoopPartition);
REGISTER_PASS1(RemoveNoOp);
REGISTER_PASS2(SplitPipeline);
REGISTER_PASS2(LiftAttrScope);
REGISTER_PASS1(NarrowChannelAccess);
REGISTER_PASS2(LowerThreadAllreduce);
REGISTER_PASS2(LowerIntrin);
......
......@@ -104,6 +104,14 @@ void CodeGenLLVM::Init(const std::string& module_name,
llvm::FunctionType::get(t_int_, {
t_int_, t_tvm_parallel_group_env_->getPointerTo()}
, false);
ftype_tvm_static_init_callback_ =
llvm::FunctionType::get(t_int_, {t_void_p_}, false);
ftype_tvm_static_init_ =
llvm::FunctionType::get(t_int_, {
t_void_p_->getPointerTo(),
ftype_tvm_static_init_callback_->getPointerTo(),
t_void_p_, t_int_}
, false);
// initialize TVM runtime API
if (system_lib) {
// We will need this in environment for backward registration.
......@@ -802,30 +810,44 @@ void CodeGenLLVM::CreateComputeScope(const AttrStmt* op) {
builder_->SetInsertPoint(compute_call_end);
}
void CodeGenLLVM::CreateParallelLaunch(const Stmt& body, int num_task) {
using llvm::BasicBlock;
Array<Var> vfields = ir::UndefinedVars(body, {});
llvm::Value* CodeGenLLVM::PackClosureData(const Array<Var>& vfields) {
std::vector<llvm::Type*> fields;
for (Var v : vfields) {
auto it = var_map_.find(v.get());
CHECK(it != var_map_.end());
fields.push_back(it->second->getType());
}
// closure data
llvm::StructType* tcdata = llvm::StructType::create(fields);
llvm::Function* f = llvm::Function::Create(
ftype_tvm_parallel_lambda_,
llvm::Function::PrivateLinkage,
"__tvm_parallel_lambda", module_.get());
// allocate and setup the closure, call the closure.
llvm::Value* cdata = builder_->CreateAlloca(tcdata, ConstInt32(1));
llvm::Value* zero = ConstInt32(0);
for (size_t i = 0; i < vfields.size(); ++i) {
builder_->CreateStore(
var_map_.at(vfields[i].get()),
builder_->CreateInBoundsGEP(cdata, {zero, ConstInt32(i)}));
var_map_.at(vfields[i].get()),
builder_->CreateInBoundsGEP(cdata, {zero, ConstInt32(i)}));
}
return cdata;
}
void CodeGenLLVM::UnpackClosureData(llvm::Value* cdata,
const Array<Var>& vfields,
std::unordered_map<const Variable*, llvm::Value*>* vmap) {
for (size_t i = 0; i < vfields.size(); ++i) {
(*vmap)[vfields[i].get()] =
builder_->CreateLoad(builder_->CreateInBoundsGEP(
cdata, {ConstInt32(0), ConstInt32(i)}));
}
}
void CodeGenLLVM::CreateParallelLaunch(const Stmt& body, int num_task) {
using llvm::BasicBlock;
// closure data
llvm::Function* f = llvm::Function::Create(
ftype_tvm_parallel_lambda_,
llvm::Function::PrivateLinkage,
"__tvm_parallel_lambda", module_.get());
// allocate and setup the closure, call the closure.
Array<Var> vfields = ir::UndefinedVars(body, {});
llvm::Value* cdata = PackClosureData(vfields);
BasicBlock* par_launch_end = CheckCallSuccess(
builder_->CreateCall(
RuntimeTVMParallelLaunch(),
......@@ -836,15 +858,10 @@ void CodeGenLLVM::CreateParallelLaunch(const Stmt& body, int num_task) {
auto it = f->arg_begin();
llvm::Value* task_id = &(*it++);
llvm::Value* penv = &(*it++);
cdata = &(*it++);
cdata = builder_->CreatePointerCast(cdata, tcdata->getPointerTo());
cdata = builder_->CreatePointerCast(&(*it++), cdata->getType());
// setup new variable map, swap it with current var context.
std::unordered_map<const Variable*, llvm::Value*> new_vmap;
for (size_t i = 0; i < vfields.size(); ++i) {
new_vmap[vfields[i].get()] =
builder_->CreateLoad(builder_->CreateInBoundsGEP(
cdata, {zero, ConstInt32(i)}));
}
UnpackClosureData(cdata, vfields, &new_vmap);
// setup parallel env
ParallelEnv par_env;
par_env.task_id = Var("task_id", Int(32));
......@@ -852,7 +869,7 @@ void CodeGenLLVM::CreateParallelLaunch(const Stmt& body, int num_task) {
new_vmap[par_env.task_id.get()] = task_id;
new_vmap[par_env.num_task.get()] = builder_->CreateLoad(
builder_->CreateInBoundsGEP(
penv, {zero, ConstInt32(1)}));
penv, {ConstInt32(0), ConstInt32(1)}));
par_env.penv = penv;
std::swap(function_, f);
std::swap(parallel_env_, par_env);
......@@ -868,6 +885,52 @@ void CodeGenLLVM::CreateParallelLaunch(const Stmt& body, int num_task) {
builder_->SetInsertPoint(par_launch_end);
}
void CodeGenLLVM::CreateStaticInit(const std::string& init_fname, const Stmt& body) {
using llvm::BasicBlock;
// closure data
llvm::Function* f = llvm::Function::Create(
ftype_tvm_static_init_callback_,
llvm::Function::PrivateLinkage,
"__tvm_static_init_lambda", module_.get());
llvm::GlobalVariable* gv = new llvm::GlobalVariable(
*module_, t_void_p_, false,
llvm::GlobalValue::PrivateLinkage, 0,
"__tvm_static_handle");
gv->setAlignment(data_layout_->getTypeAllocSize(t_void_p_));
gv->setInitializer(llvm::Constant::getNullValue(t_void_p_));
llvm::Function* finit = module_->getFunction(init_fname);
if (finit == nullptr) {
finit = llvm::Function::Create(
ftype_tvm_static_init_, llvm::Function::ExternalLinkage, init_fname, module_.get());
}
// allocate and setup the closure, call the closure.
Array<Var> vfields = ir::UndefinedVars(body, {});
llvm::Value* cdata = PackClosureData(vfields);
llvm::Value* nbytes = ConstInt32(data_layout_->getTypeAllocSize(
llvm::cast<llvm::PointerType>(cdata->getType())->getElementType()));
BasicBlock* init_end = CheckCallSuccess(
builder_->CreateCall(
finit,
{gv, f, builder_->CreatePointerCast(cdata, t_void_p_), nbytes}));
// Setup the closure function.
BasicBlock *lambda_entry = BasicBlock::Create(*ctx_, "entry", f);
builder_->SetInsertPoint(lambda_entry);
auto it = f->arg_begin();
cdata = builder_->CreatePointerCast(&(*it++), cdata->getType());
// setup new variable map, swap it with current var context.
std::unordered_map<const Variable*, llvm::Value*> new_vmap;
UnpackClosureData(cdata, vfields, &new_vmap);
CHECK(parallel_env_.penv == nullptr);
std::swap(function_, f);
std::swap(var_map_, new_vmap);
this->VisitStmt(body);
builder_->CreateRet(ConstInt32(0));
// swap the var map back, now we are back on track.
std::swap(var_map_, new_vmap);
std::swap(function_, f);
builder_->SetInsertPoint(init_end);
}
void CodeGenLLVM::CreateSerialFor(llvm::Value* begin,
llvm::Value* end,
llvm::Value* stride,
......@@ -1626,6 +1689,8 @@ void CodeGenLLVM::VisitStmt_(const AttrStmt* op) {
alloc_storage_info_[v].alignment =
static_cast<int>(op->value.as<IntImm>()->value);
this->VisitStmt(op->body);
} else if (op->attr_key == ir::attr::coproc_uop_scope) {
this->CreateStaticInit(op->value.as<StringImm>()->value, op->body);
} else if (op->attr_key == ir::attr::compute_scope) {
this->CreateComputeScope(op);
} else if (op->attr_key == ir::attr::pragma_scope) {
......
......@@ -197,6 +197,9 @@ class CodeGenLLVM :
llvm::FunctionType* ftype_tvm_parallel_launch_{nullptr};
llvm::FunctionType* ftype_tvm_parallel_barrier_{nullptr};
llvm::FunctionType* ftype_tvm_register_system_symbol_{nullptr};
// Lazy entry for function call.
llvm::FunctionType* ftype_tvm_static_init_callback_{nullptr};
llvm::FunctionType* ftype_tvm_static_init_{nullptr};
// The acting body
llvm::BasicBlock* block_{nullptr};
/*! \brief native vector bits of current targetx*/
......@@ -241,6 +244,12 @@ class CodeGenLLVM :
llvm::Value* CreateVecFlip(llvm::Value* vec);
llvm::Value* CreateVecConcat(std::vector<llvm::Value*> vecs);
llvm::Value* CreateVecPad(llvm::Value* vec, int target_lanes);
llvm::Value* PackClosureData(const Array<Var>& fields);
void UnpackClosureData(llvm::Value*cdata,
const Array<Var>& fields,
std::unordered_map<const Variable*, llvm::Value*>* vmap);
// Create static initialization
void CreateStaticInit(const std::string& init_fname, const Stmt& body);
// Create parallel launch
void CreateParallelLaunch(const Stmt& body, int num_task);
// Create serial for
......
......@@ -47,7 +47,8 @@ class ContextCallCombiner final : public IRMutator {
}
Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
if (op->attr_key == attr::thread_extent) {
if (op->attr_key == attr::thread_extent ||
op->attr_key == attr::coproc_uop_scope) {
// Map of comparison expression to variable
std::map<Expr, Var, CompareExpr> temp;
std::swap(temp, ctx_map_);
......
/*!
* Copyright (c) 2017 by Contributors
*
* \brief Lift specified AttrStmt scope to outer if
* the body contains the same scope.
* \file lift_attr_scope.cc
*/
#include <tvm/ir_pass.h>
#include <tvm/ir_mutator.h>
namespace tvm {
namespace ir {
// NOTE: this optimization can only be applied
// to a few specified attr keys
class AttrScopeLifter : public IRMutator {
public:
explicit AttrScopeLifter(std::string attr_key)
: attr_key_(attr_key) {}
Stmt Lift(Stmt stmt) {
stmt = Mutate(stmt);
if (attr_node_.defined()) {
stmt = AttrStmt::make(
attr_node_, attr_key_, attr_value_, stmt);
}
return stmt;
}
// do not go beyond
Stmt Mutate_(const Allocate* op, const Stmt& s) final {
Stmt stmt = IRMutator::Mutate_(op, s);
op = stmt.as<Allocate>();
if (attr_node_.defined()) {
Stmt body = AttrStmt::make(
attr_node_, attr_key_, attr_value_, op->body);
// undefine them
attr_node_ = NodeRef();
attr_value_ = Expr();
return Allocate::make(
op->buffer_var, op->type,
op->extents, op->condition, body,
op->new_expr, op->free_function);
} else {
return stmt;
}
}
Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
if (op->attr_key == attr_key_) {
attr_node_ = op->node;
attr_value_ = op->value;
return op->body;
} else {
return IRMutator::Mutate_(op, s);
}
}
Stmt Mutate_(const Block* op, const Stmt& s) final {
Stmt first = this->Mutate(op->first);
NodeRef first_node_;
Expr first_value_;
std::swap(first_node_, attr_node_);
std::swap(first_value_, attr_value_);
Stmt rest = this->Mutate(op->rest);
if (attr_node_.defined() &&
attr_value_.defined() &&
first_node_.defined() &&
first_value_.defined() &&
attr_node_.same_as(first_node_) &&
attr_value_.same_as(first_value_)) {
if (first.same_as(op->first) && rest.same_as(op->rest)) {
return s;
} else {
return Block::make(first, rest);
}
} else {
if (first_node_.defined()) {
first = AttrStmt::make(
first_node_, attr_key_, first_value_, first);
}
if (attr_node_.defined()) {
rest = AttrStmt::make(
attr_node_, attr_key_, attr_value_, rest);
// undefine them
attr_node_ = NodeRef();
attr_value_ = Expr();
}
if (first.same_as(op->first) && rest.same_as(op->rest)) {
return s;
} else {
return Block::make(first, rest);
}
}
}
Stmt Mutate_(const IfThenElse* op, const Stmt& s) final {
if (!op->then_case.defined()) {
return IRMutator::Mutate_(op, s);
}
Stmt then_case = this->Mutate(op->then_case);
NodeRef first_node_;
Expr first_value_;
std::swap(first_node_, attr_node_);
std::swap(first_value_, attr_value_);
Stmt else_case = this->Mutate(op->else_case);
if (attr_node_.defined() &&
attr_value_.defined() &&
first_node_.defined() &&
first_value_.defined() &&
attr_node_.same_as(first_node_) &&
attr_value_.same_as(first_value_)) {
if (then_case.same_as(op->then_case) &&
else_case.same_as(op->else_case)) {
return s;
} else {
return IfThenElse::make(op->condition, then_case, else_case);
}
} else {
if (first_node_.defined()) {
then_case = AttrStmt::make(
first_node_, attr_key_, first_value_, then_case);
}
if (attr_node_.defined()) {
else_case = AttrStmt::make(
attr_node_, attr_key_, attr_value_, else_case);
// undefine them
attr_node_ = NodeRef();
attr_value_ = Expr();
}
if (then_case.same_as(op->then_case) &&
else_case.same_as(op->else_case)) {
return s;
} else {
return IfThenElse::make(op->condition, then_case, else_case);
}
}
}
private:
std::string attr_key_;
NodeRef attr_node_;
Expr attr_value_;
};
Stmt LiftAttrScope(Stmt stmt, std::string attr_key) {
return AttrScopeLifter(attr_key).Lift(stmt);
}
} // namespace ir
} // namespace tvm
......@@ -234,6 +234,17 @@ int TVMBackendFreeWorkspace(int device_type,
return 0;
}
int TVMBackendRunOnce(void** handle,
int (*f)(void*),
void* cdata,
int nbytes) {
if (*handle == nullptr) {
*handle = reinterpret_cast<void*>(1);
return (*f)(cdata);
}
return 0;
}
int TVMFuncFree(TVMFunctionHandle func) {
API_BEGIN();
delete static_cast<PackedFunc*>(func);
......
import tvm
import numpy as np
def test_static_init():
dtype = 'int64'
n = tvm.var('n')
Ab = tvm.decl_buffer((n, ), dtype)
i = tvm.var('i')
ib = tvm.ir_builder.create()
A = ib.buffer_ptr(Ab)
cp = tvm.thread_axis((0, 1), "cop")
finit = tvm.make.StringImm("TVMBackendRunOnce")
ib.scope_attr(cp, "coproc_uop_scope", finit)
with ib.for_range(0, n, "i", for_type="parallel") as i:
A[i] = A[i] + 1
stmt = ib.get()
fapi = tvm.ir_pass.MakeAPI(stmt, "ramp", [Ab], 0, True)
fapi = tvm.ir_pass.LowerTVMBuiltin(fapi)
f = tvm.codegen.build_module(fapi, "llvm")
a = tvm.nd.array(np.zeros(10, dtype=dtype))
f(a)
f(a)
np.testing.assert_equal(a.asnumpy(), np.ones(a.shape[0]))
if __name__ == "__main__":
test_static_init()
import tvm
def test_coproc_lift():
ib = tvm.ir_builder.create()
n = tvm.var("n")
cp = tvm.thread_axis((0, 1), "cop")
value = tvm.make.StringImm("xxx")
A = ib.allocate("float32", n, name="A", scope="global")
with ib.for_range(0, n, name="i") as i:
with ib.for_range(0, 10, name="j") as j:
ib.scope_attr(cp, "coproc_uop_scope", value)
A[i] = A[i] + 1
with ib.for_range(0, 10, name="j") as j:
ib.scope_attr(cp, "coproc_uop_scope", value)
A[j] = A[j] + 2
body = ib.get()
body = tvm.ir_pass.LiftAttrScope(body, "coproc_uop_scope")
assert body.body.body.node == cp
if __name__ == "__main__":
test_coproc_lift()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment