Commit adf39837 by Tianqi Chen Committed by GitHub

[PASS] Improve double buffer (#413)

parent 5072efae
...@@ -234,10 +234,10 @@ Stmt InjectPrefetch(Stmt stmt); ...@@ -234,10 +234,10 @@ Stmt InjectPrefetch(Stmt stmt);
/*! /*!
* \brief Inject double buffer into stmt. * \brief Inject double buffer into stmt.
* \param stmt The statement to be transformed. * \param stmt The statement to be transformed.
* \param split_loop Whether split the loop containing double buffering. * \param split_loop Loop splitting factor.
* \return Transformed stmt. * \return Transformed stmt.
*/ */
Stmt InjectDoubleBuffer(Stmt stmt, bool split_loop); Stmt InjectDoubleBuffer(Stmt stmt, int split_loop);
/*! /*!
* \brief Rewrite storage allocation pattern. * \brief Rewrite storage allocation pattern.
......
...@@ -33,7 +33,7 @@ class BuildConfig(object): ...@@ -33,7 +33,7 @@ class BuildConfig(object):
"offset_factor": 0, "offset_factor": 0,
"data_alignment": -1, "data_alignment": -1,
"restricted_func": True, "restricted_func": True,
"double_buffer_split_loop": True, "double_buffer_split_loop": 1,
"add_lower_pass": None "add_lower_pass": None
} }
def __init__(self, **kwargs): def __init__(self, **kwargs):
...@@ -99,9 +99,10 @@ def build_config(**kwargs): ...@@ -99,9 +99,10 @@ def build_config(**kwargs):
not to overlap. This enables more optimization. not to overlap. This enables more optimization.
Corresponds to restricted keyword in C99 Corresponds to restricted keyword in C99
double_buffer_split_loop: bool, default=True double_buffer_split_loop: int, default=1
Whether split the loop containing double buffer so Whether split the loop with factor. If it is zero, no splitting will happen.
that the buffer fetching won't contain condition. If it is bigger than one, the logic will do a split with a factor equal to the integer
and unroll the inner loop. This ensures the buffer fetching won't contain a condition.
add_lower_pass: list of tuple (phase, function(Stmt->Stmt)), default=None add_lower_pass: list of tuple (phase, function(Stmt->Stmt)), default=None
phase contains an integer on which optimization pass we apply the pass. phase contains an integer on which optimization pass we apply the pass.
......
...@@ -34,9 +34,21 @@ class DoubleBufferDetector : public IRVisitor { ...@@ -34,9 +34,21 @@ class DoubleBufferDetector : public IRVisitor {
std::unordered_set<const Variable*> touched_; std::unordered_set<const Variable*> touched_;
}; };
// Mutator that removes attr::double_buffer_write attribute nodes from a
// statement, replacing each one with its (recursively mutated) body.
class StripDoubleBufferWrite : public IRMutator {
 public:
  Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
    // Attributes with any other key are left to the base IRMutator.
    if (op->attr_key != attr::double_buffer_write) {
      return IRMutator::Mutate_(op, s);
    }
    // Drop the attribute wrapper itself and keep mutating what it enclosed.
    return Mutate(op->body);
  }
};
class DoubleBufferInjector : public IRMutator { class DoubleBufferInjector : public IRMutator {
public: public:
explicit DoubleBufferInjector(bool split_loop) explicit DoubleBufferInjector(int split_loop)
: split_loop_(split_loop) {} : split_loop_(split_loop) {}
Stmt Inject(const Stmt& stmt) { Stmt Inject(const Stmt& stmt) {
...@@ -97,17 +109,38 @@ class DoubleBufferInjector : public IRMutator { ...@@ -97,17 +109,38 @@ class DoubleBufferInjector : public IRMutator {
auto it = loop_pre_.find(op); auto it = loop_pre_.find(op);
if (it != loop_pre_.end()) { if (it != loop_pre_.end()) {
const For* old_loop = stmt.as<For>(); const For* old_loop = stmt.as<For>();
if (split_loop_) { if (split_loop_ != 0) {
// Explicitly unroll the loop
CHECK(split_loop_ % 2 == 0 || split_loop_ == 1)
<< "It is better to split with multiple of 2";
CHECK(is_zero(old_loop->min));
Expr zero = old_loop->min;
Expr new_ext = arith::ComputeExpr<Sub>( Expr new_ext = arith::ComputeExpr<Sub>(
old_loop->extent, make_const(old_loop->loop_var.type(), 1)); old_loop->extent, make_const(old_loop->loop_var.type(), 1));
Stmt loop = For::make( Expr factor = make_const(new_ext.type(), split_loop_);
old_loop->loop_var, old_loop->min, new_ext, Expr outer_ext = arith::ComputeExpr<Div>(new_ext, factor);
old_loop->for_type, old_loop->device_api, Expr tail_base = arith::ComputeExpr<Mul>(outer_ext, factor);
old_loop->body); Var outer_var(old_loop->loop_var->name_hint + ".outer", old_loop->loop_var.type());
std::unordered_map<const Variable*, Expr> vmap; std::unordered_map<const Variable*, Expr> vmap;
vmap[old_loop->loop_var.get()] = new_ext; std::vector<Stmt> loop_seq;
Stmt end = Substitute(old_loop->body, vmap); for (size_t i = 0; i < split_loop_; ++i) {
stmt = Block::make(loop, end); vmap[old_loop->loop_var.get()] = outer_var * factor + make_const(factor.type(), i);
loop_seq.emplace_back(Substitute(old_loop->body, vmap));
}
Stmt loop = For::make(
outer_var, zero, outer_ext, old_loop->for_type, old_loop->device_api,
MergeSeq(loop_seq));
// tail
std::vector<Stmt> tail_seq;
Stmt tail_body = StripDoubleBufferWrite().Mutate(old_loop->body);
for (size_t i = 0; i < split_loop_; ++i) {
Expr idx = tail_base + make_const(tail_base.type(), i);
vmap[old_loop->loop_var.get()] = idx;
tail_seq.emplace_back(
IfThenElse::make(idx < old_loop->extent,
Substitute(tail_body, vmap)));
}
stmt = Block::make(loop, MergeSeq(tail_seq));
} }
stmt = Block::make(MergeSeq(it->second), stmt); stmt = Block::make(MergeSeq(it->second), stmt);
} }
...@@ -205,7 +238,7 @@ class DoubleBufferInjector : public IRMutator { ...@@ -205,7 +238,7 @@ class DoubleBufferInjector : public IRMutator {
std::string scope; std::string scope;
}; };
// Whether split loop // Whether split loop
bool split_loop_; int split_loop_;
// Whether we are inside double buffer scope. // Whether we are inside double buffer scope.
bool in_double_buffer_scope_{false}; bool in_double_buffer_scope_{false};
// The current loop next // The current loop next
...@@ -219,7 +252,7 @@ class DoubleBufferInjector : public IRMutator { ...@@ -219,7 +252,7 @@ class DoubleBufferInjector : public IRMutator {
}; };
Stmt InjectDoubleBuffer(Stmt stmt, bool split_loop) { Stmt InjectDoubleBuffer(Stmt stmt, int split_loop) {
return DoubleBufferInjector(split_loop).Inject(stmt); return DoubleBufferInjector(split_loop).Inject(stmt);
} }
} // namespace ir } // namespace ir
......
...@@ -19,7 +19,7 @@ def test_double_buffer(): ...@@ -19,7 +19,7 @@ def test_double_buffer():
C[j] = B[j] + 1 C[j] = B[j] + 1
stmt = ib.get() stmt = ib.get()
stmt = tvm.ir_pass.InjectDoubleBuffer(stmt, True) stmt = tvm.ir_pass.InjectDoubleBuffer(stmt, 2)
stmt = tvm.ir_pass.Simplify(stmt) stmt = tvm.ir_pass.Simplify(stmt)
assert isinstance(stmt.body.body, tvm.stmt.Allocate) assert isinstance(stmt.body.body, tvm.stmt.Allocate)
assert stmt.body.body.extents[0].value == 2 assert stmt.body.body.extents[0].value == 2
...@@ -30,7 +30,7 @@ def test_double_buffer(): ...@@ -30,7 +30,7 @@ def test_double_buffer():
if isinstance(op, tvm.expr.Call) and op.name == "tvm_storage_sync": if isinstance(op, tvm.expr.Call) and op.name == "tvm_storage_sync":
count[0] += 1 count[0] += 1
tvm.ir_pass.PostOrderVisit(f.body, count_sync) tvm.ir_pass.PostOrderVisit(f.body, count_sync)
assert count[0] == 2 assert count[0] == 4
if __name__ == "__main__": if __name__ == "__main__":
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment