Commit a45d3b01 by Tianqi Chen, committed by GitHub

[PASS] InjectDoubleBuffer (#405)

parent b8c8aadf
......@@ -178,6 +178,14 @@ constexpr const char* pragma_scope = "pragma_scope";
* run prefetch of Tensor on the current loop scope
*/
constexpr const char* prefetch_scope = "prefetch_scope";
/*!
* \brief Marks production of double buffer data
*/
constexpr const char* double_buffer_scope = "double_buffer_scope";
/*!
* \brief Marks region used by double buffer write
*/
constexpr const char* double_buffer_write = "double_buffer_write";
/*! \brief Mark of scan update scope */
constexpr const char* scan_update_scope = "scan_update_scope";
/*! \brief Mark of scan init scope */
......
......@@ -232,6 +232,14 @@ Stmt InjectVirtualThread(Stmt stmt);
Stmt InjectPrefetch(Stmt stmt);
/*!
* \brief Inject double buffer into stmt.
* \param stmt The statement to be transformed.
* \param split_loop Whether to split the loop that contains the double buffering.
* \return Transformed stmt.
*/
Stmt InjectDoubleBuffer(Stmt stmt, bool split_loop);
/*!
* \brief Rewrite storage allocation pattern.
* Moves the allocation to outer most possible scope.
* Trying to share space between allocations to make
......
......@@ -209,6 +209,11 @@ class Stage : public NodeRef {
*/
Stage& storage_align(IterVar axis, int factor, int offset); //NOLINT(*)
/*!
* \brief Compute the current stage with double buffering.
* \return reference to self.
*/
Stage& double_buffer(); // NOLINT(*)
/*!
* \brief whether the stage has been scheduled.
* \return whether the stage has been scheduled.
*/
......@@ -408,6 +413,8 @@ class StageNode : public Node {
std::string scope;
/*! \brief Whether this is an output stage */
bool is_output{false};
/*! \brief Whether to apply the double buffer optimization to this stage */
bool double_buffer{false};
/*!
* \brief The parent group of the current stage.
* The stage cannot be assigned to stages outside the group.
......@@ -429,6 +436,7 @@ class StageNode : public Node {
v->Visit("attach_stage", &attach_stage);
v->Visit("scope", &scope);
v->Visit("is_output", &is_output);
v->Visit("double_buffer", &double_buffer);
v->Visit("group", &group);
v->Visit("num_child_stages", &num_child_stages);
}
......
......@@ -33,6 +33,7 @@ class BuildConfig(object):
"offset_factor": 0,
"data_alignment": -1,
"restricted_func": True,
"double_buffer_split_loop": True,
"add_lower_pass": None
}
def __init__(self, **kwargs):
......@@ -97,6 +98,10 @@ def build_config(**kwargs):
not to overlap. This enables more optimization.
Corresponds to restricted keyword in C99
double_buffer_split_loop: bool, default=True
Whether to split the loop containing the double buffer so
that the buffer fetch does not need a guarding condition.
add_lower_pass: list of function(Stmt->Stmt), default=None
Additional lowering passes to be applied before make_api.
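The new double_buffer_split_loop option is read from BuildConfig.current inside lower() (see the next hunk). A minimal sketch of toggling it, assuming the returned config can be installed as BuildConfig.current before lowering; the tensors and shapes below are made up for illustration:

import tvm
from tvm.build_module import BuildConfig

A = tvm.placeholder((1024,), name="A")
B = tvm.compute((1024,), lambda i: A[i] + 1, name="B")
s = tvm.create_schedule(B.op)
# Keep the guard condition on the double buffer fetch instead of splitting the loop.
BuildConfig.current = tvm.build_config(double_buffer_split_loop=False)
stmt = tvm.lower(s, [A, B])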
......@@ -187,6 +192,7 @@ def lower(sch,
Then the Stmt before make api is returned.
"""
binds, arg_list = get_binds(args, binds)
cfg = BuildConfig.current
# normalize schedule first
sch = sch.normalize()
bounds = schedule.InferBound(sch)
......@@ -198,8 +204,8 @@ def lower(sch,
stmt = ir_pass.LoopPartition(stmt)
stmt = ir_pass.VectorizeLoop(stmt)
stmt = ir_pass.InjectVirtualThread(stmt)
stmt = ir_pass.InjectDoubleBuffer(stmt, cfg.double_buffer_split_loop)
stmt = ir_pass.StorageRewrite(stmt)
cfg = BuildConfig.current
stmt = ir_pass.UnrollLoop(
stmt,
cfg.auto_unroll_max_step,
......
......@@ -268,6 +268,21 @@ class IRBuilder(object):
self.emit(_make.IfThenElse(prev.condition, prev.then_case, self._pop_seq()))
return WithScope(None, _exit_cb)
def new_scope(self):
"""Create new scope,
this is useful to set boundary of attr and allocate.
Returns
-------
new_scope : WithScope
The created scope.
"""
self._seq_stack.append([])
def _exit_cb():
self.emit(self._pop_seq())
return WithScope(None, _exit_cb)
def allocate(self, dtype, shape, name="buf", scope=None):
"""Create a allocate statement.
......
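A short sketch of how new_scope() is meant to be used, mirroring the unit test added at the bottom of this commit: the attr emitted inside the scope ends when the with-block exits, so the double_buffer_scope annotation covers only the producer loop.

import tvm

ib = tvm.ir_builder.create()
A = ib.pointer("float32", name="A")
with ib.for_range(0, 100) as i:
    B = ib.allocate("float32", 4, name="B", scope="shared")
    with ib.new_scope():
        # The attr is bounded by this scope, so it marks only the fetch loop.
        ib.scope_attr(B.asnode(), "double_buffer_scope", 1)
        with ib.for_range(0, 4) as j:
            B[j] = A[i * 4 + j]
stmt = ib.get()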
......@@ -589,4 +589,13 @@ class Stage(NodeBase):
"""
_api_internal._StageStorageAlign(self, axis, factor, offset)
def double_buffer(self):
"""Compute the current stage via double buffering.
This can only be applied to an intermediate stage.
It doubles the storage cost of the current stage,
which can be useful for hiding load latency.
"""
_api_internal._StageDoubleBuffer(self)
_init_api("tvm.schedule")
......@@ -385,13 +385,18 @@ TVM_REGISTER_API("_StagePragma")
TVM_REGISTER_API("_StagePrefetch")
.set_body([](TVMArgs args, TVMRetValue *ret) {
args[0].operator Stage()
.prefetch(args[1], args[2], args[3]);
.prefetch(args[1], args[2], args[3]);
});
TVM_REGISTER_API("_StageStorageAlign")
.set_body([](TVMArgs args, TVMRetValue *ret) {
args[0].operator Stage()
.storage_align(args[1], args[2], args[3]);
.storage_align(args[1], args[2], args[3]);
});
TVM_REGISTER_API("_StageDoubleBuffer")
.set_body([](TVMArgs args, TVMRetValue *ret) {
args[0].operator Stage().double_buffer();
});
TVM_REGISTER_API("_ScheduleNormalize")
......
......@@ -101,6 +101,7 @@ REGISTER_PASS1(CoProcSync);
REGISTER_PASS1(LowerStorageAccessInfo);
REGISTER_PASS1(InjectVirtualThread);
REGISTER_PASS1(InjectPrefetch);
REGISTER_PASS2(InjectDoubleBuffer);
REGISTER_PASS1(LoopPartition);
REGISTER_PASS1(RemoveNoOp);
REGISTER_PASS2(SplitPipeline);
......
/*!
* Copyright (c) 2017 by Contributors
*
* \brief Inject double buffering optimization for data fetch.
* \file inject_double_buffer.cc
*/
#include <tvm/ir_pass.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_mutator.h>
#include "./ir_util.h"
#include "../arithmetic/compute_expr.h"
namespace tvm {
namespace ir {
// Detect double buffer variables.
class DoubleBufferDetector : public IRVisitor {
public:
void Visit_(const AttrStmt* op) final {
if (op->attr_key == attr::double_buffer_scope) {
touched_.insert(op->node.as<Variable>());
IRVisitor::Visit_(op);
} else {
IRVisitor::Visit_(op);
}
}
void Visit_(const Variable* op) final {
if (touched_.count(op)) {
touched_.erase(op);
}
}
// The set of touched variables.
std::unordered_set<const Variable*> touched_;
};
class DoubleBufferInjector : public IRMutator {
public:
explicit DoubleBufferInjector(bool split_loop)
: split_loop_(split_loop) {}
Stmt Inject(const Stmt& stmt) {
DoubleBufferDetector detector;
detector.Visit(stmt);
if (detector.touched_.empty()) return stmt;
for (const Variable* v : detector.touched_) {
dbuffer_info_[v] = StorageEntry();
}
return ConvertSSA(this->Mutate(stmt));
}
Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
if (op->attr_key == attr::storage_scope) {
const Variable* buf = op->node.as<Variable>();
auto it = dbuffer_info_.find(buf);
if (it != dbuffer_info_.end()) {
it->second.scope = op->value.as<StringImm>()->value;
return Mutate(op->body);
} else {
return IRMutator::Mutate_(op, s);
}
} else if (op->attr_key == attr::double_buffer_scope) {
return MakeProducer(op, s);
} else {
return IRMutator::Mutate_(op, s);
}
}
Stmt Mutate_(const Allocate* op, const Stmt& s) final {
auto it = dbuffer_info_.find(op->buffer_var.get());
if (it != dbuffer_info_.end()) {
it->second.size = arith::ComputeReduce<Mul>(op->extents);
Stmt stmt = IRMutator::Mutate_(op, s);
op = stmt.as<Allocate>();
Array<Expr> new_extents{make_const(op->extents[0].type(), 2)};
for (Expr e : op->extents) {
new_extents.push_back(e);
}
CHECK(it->second.loop != nullptr);
auto& alloc_nest = loop_allocs_[it->second.loop];
alloc_nest.emplace_back(AttrStmt::make(
op->buffer_var, attr::storage_scope,
StringImm::make(it->second.scope),
Evaluate::make(0)));
alloc_nest.emplace_back(Allocate::make(
op->buffer_var, op->type, new_extents, op->condition,
Evaluate::make(0)));
return op->body;
} else {
return IRMutator::Mutate_(op, s);
}
}
Stmt Mutate_(const For* op, const Stmt& s) final {
loop_nest_.push_back(op);
Stmt stmt = IRMutator::Mutate_(op, s);
auto it = loop_pre_.find(op);
if (it != loop_pre_.end()) {
const For* old_loop = stmt.as<For>();
if (split_loop_) {
Expr new_ext = arith::ComputeExpr<Sub>(
old_loop->extent, make_const(old_loop->loop_var.type(), 1));
Stmt loop = For::make(
old_loop->loop_var, old_loop->min, new_ext,
old_loop->for_type, old_loop->device_api,
old_loop->body);
std::unordered_map<const Variable*, Expr> vmap;
vmap[old_loop->loop_var.get()] = new_ext;
Stmt end = Substitute(old_loop->body, vmap);
stmt = Block::make(loop, end);
}
stmt = Block::make(MergeSeq(it->second), stmt);
}
it = loop_allocs_.find(op);
if (it != loop_allocs_.end()) {
stmt = MergeNest(it->second, stmt);
}
loop_nest_.pop_back();
return stmt;
}
Stmt Mutate_(const Store* op, const Stmt& s) final {
Stmt stmt = IRMutator::Mutate_(op, s);
op = stmt.as<Store>();
auto it = dbuffer_info_.find(op->buffer_var.get());
if (it != dbuffer_info_.end()) {
const StorageEntry& e = it->second;
CHECK(in_double_buffer_scope_);
CHECK(e.size.defined());
return Store::make(op->buffer_var,
op->value,
e.switch_write_var * e.size + op->index,
op->predicate);
} else {
return stmt;
}
}
Expr Mutate_(const Load* op, const Expr& e) final {
Expr expr = IRMutator::Mutate_(op, e);
op = expr.as<Load>();
auto it = dbuffer_info_.find(op->buffer_var.get());
if (it != dbuffer_info_.end()) {
const StorageEntry& e = it->second;
CHECK(e.size.defined());
CHECK(e.switch_read_var.defined());
return Load::make(op->type,
op->buffer_var,
e.switch_read_var * e.size + op->index,
op->predicate);
} else {
return expr;
}
}
Expr Mutate_(const Variable* op, const Expr& e) final {
CHECK(!dbuffer_info_.count(op));
return e;
}
private:
Stmt MakeProducer(const AttrStmt* op, const Stmt& s) {
const VarExpr buffer(op->node.node_);
CHECK_NE(loop_nest_.size(), 0U)
<< "Double buffer scope must be inside a loop";
auto it = dbuffer_info_.find(buffer.get());
if (it == dbuffer_info_.end()) {
LOG(WARNING) << "Skip double buffer scope " << op->node;
return Mutate(op->body);
}
StorageEntry& e = it->second;
e.loop = loop_nest_.back();
Expr zero = make_const(e.loop->loop_var.type(), 0);
Expr one = make_const(e.loop->loop_var.type(), 1);
Expr two = make_const(e.loop->loop_var.type(), 2);
Expr loop_shift = e.loop->loop_var + one;
e.switch_write_var = Var(e.loop->loop_var->name_hint + ".db",
e.loop->loop_var.type());
e.switch_read_var = e.loop->loop_var % two;
in_double_buffer_scope_ = true;
Stmt body = Mutate(op->body);
in_double_buffer_scope_ = false;
std::unordered_map<const Variable*, Expr> vmap;
vmap[e.switch_write_var.get()] = zero;
vmap[e.loop->loop_var.get()] = zero;
loop_pre_[e.loop].emplace_back(Substitute(body, vmap));
vmap[e.loop->loop_var.get()] = loop_shift;
vmap[e.switch_write_var.get()] = loop_shift % two;
body = Substitute(body, vmap);
body = AttrStmt::make(buffer, attr::double_buffer_write, 1, body);
body = IfThenElse::make(loop_shift < e.loop->extent, body);
return body;
}
// Storage entry for a buffer that needs double buffering.
struct StorageEntry {
// The size of the buffer
Expr size;
// The loop the double buffer is attached to.
const For* loop{nullptr};
// The switch variable for writing.
VarExpr switch_write_var;
// The switch variable for reading.
Expr switch_read_var;
// The storage scope.
std::string scope;
};
// Whether to split the loop.
bool split_loop_;
// Whether we are inside double buffer scope.
bool in_double_buffer_scope_{false};
// The current loop nest.
std::vector<const For*> loop_nest_;
// The allocation nests to be wrapped around the loop.
std::unordered_map<const For*, std::vector<Stmt> > loop_allocs_;
// The stmt to be appended before the loop
std::unordered_map<const For*, std::vector<Stmt> > loop_pre_;
// Double buffer information for each buffer variable.
std::unordered_map<const Variable*, StorageEntry> dbuffer_info_;
};
Stmt InjectDoubleBuffer(Stmt stmt, bool split_loop) {
return DoubleBufferInjector(split_loop).Inject(stmt);
}
} // namespace ir
} // namespace tvm
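For intuition only, here is a plain-Python sketch (not TVM code) of the pattern the injector produces when split_loop is true: storage is doubled, the first tile is prefetched before the loop (loop_pre_), each remaining iteration writes tile i+1 into half (i+1)%2 while reading tile i from half i%2, and the final consumer iteration is peeled off after the loop so the fetch needs no guard. The helpers load_tile and compute are hypothetical stand-ins for the producer and consumer statements.

def double_buffered(num_tiles, tile_size, load_tile, compute):
    # Two halves of storage, analogous to the extra leading extent of 2.
    buf = [[0.0] * tile_size for _ in range(2)]
    load_tile(buf[0], 0)                       # prologue: prefetch tile 0
    for i in range(num_tiles - 1):             # main loop runs extent - 1 times
        load_tile(buf[(i + 1) % 2], i + 1)     # switch_write_var = (i + 1) % 2
        compute(buf[i % 2], i)                 # switch_read_var  = i % 2
    compute(buf[(num_tiles - 1) % 2], num_tiles - 1)  # peeled last iteration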
......@@ -74,6 +74,24 @@ void StorageAccessVisitor::Visit_(const AttrStmt* op) {
storage_scope_[buf] =
StorageScope::make(op->value.as<StringImm>()->value);
IRVisitor::Visit_(op);
} else if (op->attr_key == attr::double_buffer_write) {
CHECK(double_buffer_write_ == nullptr);
double_buffer_write_ = op->node.as<Variable>();
scope_.push_back(std::vector<StmtEntry>());
IRVisitor::Visit_(op);
StmtEntry s;
s.stmt = op;
s.access = Summarize(std::move(scope_.back()), nullptr);
scope_.pop_back();
if (!s.access.empty()) {
for (AccessEntry& e : s.access) {
if (e.type == kWrite && e.buffer.get() == double_buffer_write_) {
e.double_buffer_write = true;
}
}
scope_.back().emplace_back(std::move(s));
}
double_buffer_write_ = nullptr;
} else if (op->attr_key == attr::coproc_scope) {
IterVar iv(op->node.node_);
env_threads_.push_back(iv);
......
......@@ -45,6 +45,8 @@ class StorageAccessVisitor : public IRVisitor {
AccessType type;
/*! \brief The storage scope */
StorageScope scope;
/*! \brief Whether the access is double buffer write */
bool double_buffer_write{false};
};
/*! \brief Access pattern about a single statement */
struct StmtEntry {
......@@ -116,6 +118,8 @@ class StorageAccessVisitor : public IRVisitor {
bool in_device_env_{false};
// Whether we are inside condition.
int condition_counter_{0};
// The current double buffer write scope.
const Variable* double_buffer_write_{nullptr};
// the current free stmt entry.
StmtEntry curr_stmt_;
// The involving threads
......
......@@ -4,6 +4,7 @@
*/
#include <tvm/ir.h>
#include <tvm/expr.h>
#include <tvm/operation.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_operator.h>
#include <tvm/ir_pass.h>
......@@ -53,6 +54,18 @@ class StorageFlattener : public IRMutator {
if (op->attr_key == attr::realize_scope) {
storage_scope_[op->node.get()] = op->value.as<StringImm>()->value;
return this->Mutate(op->body);
} else if (op->attr_key == attr::double_buffer_scope) {
Operation func(op->node.node_);
Stmt body = Mutate(op->body);
for (int i = 0; i < func->num_outputs(); ++i) {
TensorKey key{func, i};
auto it = buf_map_.find(key);
CHECK(it != buf_map_.end())
<< "Cannot find allocated buffer for " << key.f;
body = AttrStmt::make(
it->second.buffer->data, op->attr_key, op->value, body);
}
return body;
} else if (op->attr_key == attr::thread_extent) {
IterVar iv(op->node.node_);
ThreadScope ts = ThreadScope::make(iv->thread_tag);
......
......@@ -34,13 +34,10 @@ class ThreadSyncPlanner : public StorageAccessVisitor {
// Unsynced reads and writes
std::vector<AccessEntry> reads;
std::vector<AccessEntry> writes;
// if it is a loop, rotate two times to consider effect of loop.
size_t max_seq = seq.size();
if (loop != nullptr) max_seq *= 2;
// simulation-based approach to find dependencies
for (size_t i = 0; i < max_seq; ++i) {
const StmtEntry& s = seq[i % seq.size()];
for (size_t i = 0; i < seq.size(); ++i) {
const StmtEntry& s = seq[i];
// check if sync before statement is needed.
bool sync_before_stmt = (syncs_inserted_.count(s.stmt) != 0);
// Apply the syncs added already.
......@@ -50,11 +47,11 @@ class ThreadSyncPlanner : public StorageAccessVisitor {
}
for (const AccessEntry& acc : s.access) {
if (acc.type == kRead) {
if (FindConflict(writes, acc)) {
if (FindConflict(writes, acc, false)) {
sync_before_stmt = true; break;
}
} else if (acc.type == kWrite) {
if (FindConflict(reads, acc)) {
if (FindConflict(reads, acc, false)) {
sync_before_stmt = true; break;
}
} else if (acc.type == kSync) {
......@@ -81,6 +78,33 @@ class ThreadSyncPlanner : public StorageAccessVisitor {
syncs_inserted_.insert(s.stmt);
}
}
if (loop != nullptr) {
for (size_t i = 0; i < seq.size(); ++i) {
const StmtEntry& s = seq[i];
if (syncs_inserted_.count(s.stmt) != 0) break;
if (reads.empty() && writes.empty()) break;
bool sync_before_stmt = false;
for (const AccessEntry& acc : s.access) {
if (acc.type == kRead) {
if (FindConflict(writes, acc, true)) {
sync_before_stmt = true; break;
}
} else if (acc.type == kWrite) {
if (FindConflict(reads, acc, true)) {
sync_before_stmt = true; break;
}
} else if (acc.type == kSync) {
reads.clear(); writes.clear();
}
}
if (sync_before_stmt) {
CHECK_EQ(condition_counter(), 0)
<< "Cannot insert syncs inside condition";
syncs_inserted_.insert(s.stmt);
break;
}
}
}
// return the exposed entries, removing unnecessary ones.
int sync_count = 0;
// head are before first sync, tail are after last sync
......@@ -117,13 +141,20 @@ class ThreadSyncPlanner : public StorageAccessVisitor {
}
}
head.insert(head.end(), tail.begin(), tail.end());
if (loop != nullptr) {
// clear double buffer flag after a loop is finished.
for (AccessEntry& e : head) {
e.double_buffer_write = false;
}
}
return head;
}
private:
// find conflicting entry in vec.
bool FindConflict(const std::vector<AccessEntry>& vec,
const AccessEntry& e) {
const AccessEntry& e,
bool loop_carry) {
for (const AccessEntry& x : vec) {
if (x.buffer.same_as(e.buffer)) {
// Assumes no race between threads
......@@ -134,6 +165,9 @@ class ThreadSyncPlanner : public StorageAccessVisitor {
if (Equal(e.touched.point_value(),
x.touched.point_value())) continue;
}
if (x.double_buffer_write &&
e.type == kRead &&
!loop_carry) continue;
return true;
}
}
......
......@@ -385,6 +385,13 @@ Stage& Stage::storage_align(IterVar axis, int factor, int offset) {
return *this;
}
Stage& Stage::double_buffer() {
StageNode *self = operator->();
CHECK(!self->is_output) << "Cannot apply double buffer on output";
self->double_buffer = true;
return *this;
}
Stage CopyStage(const Stage& s) {
std::shared_ptr<StageNode> n =
std::make_shared<StageNode>(*s.operator->());
......
......@@ -27,6 +27,10 @@ Stmt MakePipeline(const Stage& s,
if (producer.defined()) {
producer = ProducerConsumer::make(s->op, true, producer);
}
if (s->double_buffer) {
producer = AttrStmt::make(
s->op, ir::attr::double_buffer_scope, 1, producer);
}
Stmt pipeline = producer;
if (consumer.defined() && !is_no_op(consumer)) {
......@@ -170,7 +174,8 @@ class SchedulePostProc : public IRMutator {
thread_extent_scope_.erase(op->node.get());
return ret;
}
} else if (op->attr_key == ir::attr::realize_scope) {
} else if (op->attr_key == ir::attr::realize_scope ||
op->attr_key == ir::attr::double_buffer_scope) {
auto it = replace_op_.find(op->node.get());
if (it != replace_op_.end()) {
if (it->second.defined()) {
......
......@@ -47,7 +47,8 @@ def test_gemm():
s[CC].compute_at(s[C], tx)
s[AA].compute_at(s[CC], k)
s[BB].compute_at(s[CC], k)
s[AA].double_buffer()
s[BB].double_buffer()
ty, xi = s[AA].split(s[AA].op.axis[0], nparts=num_thread)
tx, xi = s[AA].split(xi, nparts=num_thread)
s[AA].bind(ty, thread_y)
......@@ -84,10 +85,10 @@ def test_gemm():
np.testing.assert_allclose(
c.asnumpy(), np.dot(a_np, b_np.T), rtol=1e-5)
check_device("nvptx -mcpu=sm_20")
check_device("metal")
check_device("opencl")
check_device("cuda")
#check_device("nvptx -mcpu=sm_20")
if __name__ == "__main__":
test_gemm()
import tvm

def test_double_buffer():
    dtype = 'int64'
    n = 100
    m = 4
    tx = tvm.thread_axis("threadIdx.x")
    ib = tvm.ir_builder.create()
    A = ib.pointer("float32", name="A")
    C = ib.pointer("float32", name="C")
    ib.scope_attr(tx, "thread_extent", 1)
    with ib.for_range(0, n) as i:
        B = ib.allocate("float32", m, name="B", scope="shared")
        with ib.new_scope():
            ib.scope_attr(B.asnode(), "double_buffer_scope", 1)
            with ib.for_range(0, m) as j:
                B[j] = A[i * 4 + j]
        with ib.for_range(0, m) as j:
            C[j] = B[j] + 1
    stmt = ib.get()
    stmt = tvm.ir_pass.InjectDoubleBuffer(stmt, True)
    stmt = tvm.ir_pass.Simplify(stmt)
    assert isinstance(stmt.body.body, tvm.stmt.Allocate)
    assert stmt.body.body.extents[0].value == 2
    f = tvm.ir_pass.MakeAPI(stmt, "db", [A.asnode(), C.asnode()], 2, True)
    f = tvm.ir_pass.ThreadSync(f, "shared")
    count = [0]
    def count_sync(op):
        if isinstance(op, tvm.expr.Call) and op.name == "tvm_storage_sync":
            count[0] += 1
    tvm.ir_pass.PostOrderVisit(f.body, count_sync)
    assert count[0] == 2

if __name__ == "__main__":
    test_double_buffer()
......@@ -96,6 +96,8 @@ def test_gemm():
s[BB].bind(ty, thread_y)
s[BB].bind(tx, thread_x)
s[BB].vectorize(xi)
s[AA].double_buffer()
s[BB].double_buffer()
# correctness
def check_device(device):
if not tvm.module.enabled(device):
......