Commit a45d3b01 by Tianqi Chen Committed by GitHub

[PASS] InjectDoubleBuffer (#405)

parent b8c8aadf
...@@ -178,6 +178,14 @@ constexpr const char* pragma_scope = "pragma_scope";
 * run prefetch of Tensor on the current loop scope
 */
constexpr const char* prefetch_scope = "prefetch_scope";
/*!
* \brief Marks production of double buffer data
*/
constexpr const char* double_buffer_scope = "double_buffer_scope";
/*!
* \brief Marks region used by double buffer write
*/
constexpr const char* double_buffer_write = "double_buffer_write";
/*! \brief Mark of scan update scope */
constexpr const char* scan_update_scope = "scan_update_scope";
/*! \brief Mark of scan init scope */
......
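For orientation, these two attribute keys surface in the lowered IR as AttrStmt nodes wrapping the producer of the staged buffer. A minimal sketch of marking a buffer by hand through the IR builder (it mirrors the unit test added at the bottom of this commit; the loop bounds and buffer names are illustrative):

```python
import tvm

# Illustrative only: stage 4 elements of A into a shared buffer B and mark
# B with double_buffer_scope so InjectDoubleBuffer can ping-pong it.
ib = tvm.ir_builder.create()
A = ib.pointer("float32", name="A")
with ib.for_range(0, 100) as i:
    B = ib.allocate("float32", 4, name="B", scope="shared")
    with ib.new_scope():
        ib.scope_attr(B.asnode(), "double_buffer_scope", 1)
        with ib.for_range(0, 4) as j:
            B[j] = A[i * 4 + j]   # producer covered by the attribute
stmt = ib.get()
```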
...@@ -232,6 +232,14 @@ Stmt InjectVirtualThread(Stmt stmt);
Stmt InjectPrefetch(Stmt stmt);
/*!
* \brief Inject double buffer into stmt.
* \param stmt The statement to be transformed.
* \param split_loop Whether to split the loop containing double buffering.
* \return Transformed stmt.
*/
Stmt InjectDoubleBuffer(Stmt stmt, bool split_loop);
/*!
* \brief Rewrite storage allocation pattern.
* Moves the allocation to outer most possible scope.
* Trying to share space between allocations to make
......
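Continuing the IR builder sketch above, the new pass can be applied directly to a statement, which is exactly how the unit test at the end of this commit drives it (split_loop=True matches the build_config default added below):

```python
# Continues the sketch above: `stmt` is the statement built with ir_builder.
stmt = tvm.ir_pass.InjectDoubleBuffer(stmt, True)  # split_loop=True
stmt = tvm.ir_pass.Simplify(stmt)
# The staging allocation is now hoisted outside the loop and doubled:
# its first extent becomes 2 (see the assertions in the new unit test).
print(stmt)
```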
...@@ -209,6 +209,11 @@ class Stage : public NodeRef {
*/
Stage& storage_align(IterVar axis, int factor, int offset); //NOLINT(*)
/*!
* \brief Compute the current stage with double buffering.
* \return reference to self.
*/
Stage& double_buffer(); // NOLINT(*)
/*!
* \brief whether the stage has been scheduled.
* \return whether the stage has been scheduled.
*/
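At the schedule level the same machinery is reached through the new Stage::double_buffer() call. A hedged sketch of typical use on shared-memory staging stages (the placeholder/cache_read setup below is an assumed minimal example, not part of this commit; it mirrors the GEMM tests touched further down):

```python
import tvm

n = 1024
A = tvm.placeholder((n, n), name="A")
B = tvm.placeholder((n, n), name="B")
k = tvm.reduce_axis((0, n), name="k")
C = tvm.compute((n, n), lambda i, j: tvm.sum(A[i, k] * B[j, k], axis=k), name="C")

s = tvm.create_schedule(C.op)
AA = s.cache_read(A, "shared", [C])
BB = s.cache_read(B, "shared", [C])
# Double-buffer the staging stages so the fetch for iteration k+1 can
# overlap with the compute of iteration k.
s[AA].double_buffer()
s[BB].double_buffer()
```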
...@@ -408,6 +413,8 @@ class StageNode : public Node {
std::string scope;
/*! \brief Whether this is an output stage */
bool is_output{false};
/*! \brief Whether apply double buffer optimization to this stage */
bool double_buffer{false};
/*!
* \brief The parent group of the current stage.
* The stage cannot be assigned to stages outside the group.
...@@ -429,6 +436,7 @@ class StageNode : public Node {
v->Visit("attach_stage", &attach_stage);
v->Visit("scope", &scope);
v->Visit("is_output", &is_output);
v->Visit("double_buffer", &double_buffer);
v->Visit("group", &group); v->Visit("group", &group);
v->Visit("num_child_stages", &num_child_stages); v->Visit("num_child_stages", &num_child_stages);
} }
......
...@@ -33,6 +33,7 @@ class BuildConfig(object):
"offset_factor": 0,
"data_alignment": -1,
"restricted_func": True,
"double_buffer_split_loop": True,
"add_lower_pass": None "add_lower_pass": None
} }
def __init__(self, **kwargs): def __init__(self, **kwargs):
...@@ -97,6 +98,10 @@ def build_config(**kwargs):
not to overlap. This enables more optimization.
Corresponds to restricted keyword in C99
double_buffer_split_loop: bool, default=True
Whether to split the loop containing the double buffer so
that the buffer fetching does not contain a condition.
add_lower_pass: list of function(Stmt->Stmt), default=None
Additional lowering passes to be applied before make_api.
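A minimal sketch of the new knob, assuming the top-level tvm.build_config helper documented in this hunk; it defaults to True and is consumed inside tvm.lower()/tvm.build() right before ir_pass.InjectDoubleBuffer runs (see the lower() hunk below):

```python
import tvm

# Disable the loop split so the guarding condition stays inside the loop
# body instead of peeling the last iteration.
cfg = tvm.build_config(double_buffer_split_loop=False)
print(cfg.double_buffer_split_loop)  # -> False
```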
...@@ -187,6 +192,7 @@ def lower(sch,
Then the Stmt before make api is returned.
"""
binds, arg_list = get_binds(args, binds)
cfg = BuildConfig.current
# normalize schedule first
sch = sch.normalize()
bounds = schedule.InferBound(sch)
...@@ -198,8 +204,8 @@ def lower(sch,
stmt = ir_pass.LoopPartition(stmt)
stmt = ir_pass.VectorizeLoop(stmt)
stmt = ir_pass.InjectVirtualThread(stmt)
stmt = ir_pass.InjectDoubleBuffer(stmt, cfg.double_buffer_split_loop)
stmt = ir_pass.StorageRewrite(stmt)
cfg = BuildConfig.current
stmt = ir_pass.UnrollLoop(
stmt,
cfg.auto_unroll_max_step,
......
...@@ -268,6 +268,21 @@ class IRBuilder(object):
            self.emit(_make.IfThenElse(prev.condition, prev.then_case, self._pop_seq()))
        return WithScope(None, _exit_cb)
    def new_scope(self):
        """Create a new scope.

        This is useful to set the boundary of attr and allocate.

        Returns
        -------
        new_scope : WithScope
            The result new scope.
        """
        self._seq_stack.append([])
        def _exit_cb():
            self.emit(self._pop_seq())
        return WithScope(None, _exit_cb)
    def allocate(self, dtype, shape, name="buf", scope=None):
        """Create an allocate statement.
......
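The new IRBuilder.new_scope() helper bounds how far an attribute or allocation emitted inside it extends. A small sketch (names are illustrative; the pattern follows the new unit test at the end of this commit):

```python
import tvm

ib = tvm.ir_builder.create()
tx = tvm.thread_axis("threadIdx.x")
A = ib.pointer("float32", name="A")
C = ib.pointer("float32", name="C")
with ib.new_scope():
    # Statements emitted here are collected and wrapped together when the
    # `with` block exits, so the attr below covers only this block.
    ib.scope_attr(tx, "thread_extent", 1)
    with ib.for_range(0, 4) as j:
        C[j] = A[j]
stmt = ib.get()
print(stmt)
```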
...@@ -589,4 +589,13 @@ class Stage(NodeBase):
        """
        _api_internal._StageStorageAlign(self, axis, factor, offset)
    def double_buffer(self):
        """Compute the current stage via double buffering.

        This can only be applied to an intermediate stage.
        This will double the storage cost of the current stage.
        It can be useful to hide load latency.
        """
        _api_internal._StageDoubleBuffer(self)
_init_api("tvm.schedule") _init_api("tvm.schedule")
...@@ -385,13 +385,18 @@ TVM_REGISTER_API("_StagePragma")
TVM_REGISTER_API("_StagePrefetch")
.set_body([](TVMArgs args, TVMRetValue *ret) {
args[0].operator Stage()
.prefetch(args[1], args[2], args[3]);
});
TVM_REGISTER_API("_StageStorageAlign") TVM_REGISTER_API("_StageStorageAlign")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
args[0].operator Stage() args[0].operator Stage()
.storage_align(args[1], args[2], args[3]); .storage_align(args[1], args[2], args[3]);
});
TVM_REGISTER_API("_StageDoubleBuffer")
.set_body([](TVMArgs args, TVMRetValue *ret) {
args[0].operator Stage().double_buffer();
});
TVM_REGISTER_API("_ScheduleNormalize")
......
...@@ -101,6 +101,7 @@ REGISTER_PASS1(CoProcSync);
REGISTER_PASS1(LowerStorageAccessInfo);
REGISTER_PASS1(InjectVirtualThread);
REGISTER_PASS1(InjectPrefetch);
REGISTER_PASS2(InjectDoubleBuffer);
REGISTER_PASS1(LoopPartition);
REGISTER_PASS1(RemoveNoOp);
REGISTER_PASS2(SplitPipeline);
......
/*!
* Copyright (c) 2017 by Contributors
*
* \brief Inject double buffering optimization for data fetch.
* \file inject_double_buffer.cc
*/
#include <tvm/ir_pass.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_mutator.h>
#include "./ir_util.h"
#include "../arithmetic/compute_expr.h"
namespace tvm {
namespace ir {
// Detect double buffer variables.
class DoubleBufferDetector : public IRVisitor {
public:
void Visit_(const AttrStmt* op) final {
if (op->attr_key == attr::double_buffer_scope) {
touched_.insert(op->node.as<Variable>());
IRVisitor::Visit_(op);
} else {
IRVisitor::Visit_(op);
}
}
void Visit_(const Variable* op) final {
if (touched_.count(op)) {
touched_.erase(op);
}
}
// The set of touched variables.
std::unordered_set<const Variable*> touched_;
};
class DoubleBufferInjector : public IRMutator {
public:
explicit DoubleBufferInjector(bool split_loop)
: split_loop_(split_loop) {}
Stmt Inject(const Stmt& stmt) {
DoubleBufferDetector detector;
detector.Visit(stmt);
if (detector.touched_.empty()) return stmt;
for (const Variable* v : detector.touched_) {
dbuffer_info_[v] = StorageEntry();
}
return ConvertSSA(this->Mutate(stmt));
}
Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
if (op->attr_key == attr::storage_scope) {
const Variable* buf = op->node.as<Variable>();
auto it = dbuffer_info_.find(buf);
if (it != dbuffer_info_.end()) {
it->second.scope = op->value.as<StringImm>()->value;
return Mutate(op->body);
} else {
return IRMutator::Mutate_(op, s);
}
} else if (op->attr_key == attr::double_buffer_scope) {
return MakeProducer(op, s);
} else {
return IRMutator::Mutate_(op, s);
}
}
Stmt Mutate_(const Allocate* op, const Stmt& s) final {
auto it = dbuffer_info_.find(op->buffer_var.get());
if (it != dbuffer_info_.end()) {
it->second.size = arith::ComputeReduce<Mul>(op->extents);
Stmt stmt = IRMutator::Mutate_(op, s);
op = stmt.as<Allocate>();
Array<Expr> new_extents{make_const(op->extents[0].type(), 2)};
for (Expr e : op->extents) {
new_extents.push_back(e);
}
CHECK(it->second.loop != nullptr);
auto& alloc_nest = loop_allocs_[it->second.loop];
alloc_nest.emplace_back(AttrStmt::make(
op->buffer_var, attr::storage_scope,
StringImm::make(it->second.scope),
Evaluate::make(0)));
alloc_nest.emplace_back(Allocate::make(
op->buffer_var, op->type, new_extents, op->condition,
Evaluate::make(0)));
return op->body;
} else {
return IRMutator::Mutate_(op, s);
}
}
Stmt Mutate_(const For* op, const Stmt& s) final {
loop_nest_.push_back(op);
Stmt stmt = IRMutator::Mutate_(op, s);
auto it = loop_pre_.find(op);
if (it != loop_pre_.end()) {
const For* old_loop = stmt.as<For>();
if (split_loop_) {
Expr new_ext = arith::ComputeExpr<Sub>(
old_loop->extent, make_const(old_loop->loop_var.type(), 1));
Stmt loop = For::make(
old_loop->loop_var, old_loop->min, new_ext,
old_loop->for_type, old_loop->device_api,
old_loop->body);
std::unordered_map<const Variable*, Expr> vmap;
vmap[old_loop->loop_var.get()] = new_ext;
Stmt end = Substitute(old_loop->body, vmap);
stmt = Block::make(loop, end);
}
stmt = Block::make(MergeSeq(it->second), stmt);
}
it = loop_allocs_.find(op);
if (it != loop_allocs_.end()) {
stmt = MergeNest(it->second, stmt);
}
loop_nest_.pop_back();
return stmt;
}
Stmt Mutate_(const Store* op, const Stmt& s) final {
Stmt stmt = IRMutator::Mutate_(op, s);
op = stmt.as<Store>();
auto it = dbuffer_info_.find(op->buffer_var.get());
if (it != dbuffer_info_.end()) {
const StorageEntry& e = it->second;
CHECK(in_double_buffer_scope_);
CHECK(e.size.defined());
return Store::make(op->buffer_var,
op->value,
e.switch_write_var * e.size + op->index,
op->predicate);
} else {
return stmt;
}
}
Expr Mutate_(const Load* op, const Expr& e) final {
Expr expr = IRMutator::Mutate_(op, e);
op = expr.as<Load>();
auto it = dbuffer_info_.find(op->buffer_var.get());
if (it != dbuffer_info_.end()) {
const StorageEntry& e = it->second;
CHECK(e.size.defined());
CHECK(e.switch_read_var.defined());
return Load::make(op->type,
op->buffer_var,
e.switch_read_var * e.size + op->index,
op->predicate);
} else {
return expr;
}
}
Expr Mutate_(const Variable* op, const Expr& e) final {
CHECK(!dbuffer_info_.count(op));
return e;
}
private:
Stmt MakeProducer(const AttrStmt* op, const Stmt& s) {
const VarExpr buffer(op->node.node_);
CHECK_NE(loop_nest_.size(), 0U)
<< "Double buffer scope must be inside a loop";
auto it = dbuffer_info_.find(buffer.get());
if (it == dbuffer_info_.end()) {
LOG(WARNING) << "Skip double buffer scope " << op->node;
return Mutate(op->body);
}
StorageEntry& e = it->second;
e.loop = loop_nest_.back();
Expr zero = make_const(e.loop->loop_var.type(), 0);
Expr one = make_const(e.loop->loop_var.type(), 1);
Expr two = make_const(e.loop->loop_var.type(), 2);
Expr loop_shift = e.loop->loop_var + one;
e.switch_write_var = Var(e.loop->loop_var->name_hint + ".db",
e.loop->loop_var.type());
e.switch_read_var = e.loop->loop_var % two;
in_double_buffer_scope_ = true;
Stmt body = Mutate(op->body);
in_double_buffer_scope_ = false;
std::unordered_map<const Variable*, Expr> vmap;
vmap[e.switch_write_var.get()] = zero;
vmap[e.loop->loop_var.get()] = zero;
loop_pre_[e.loop].emplace_back(Substitute(body, vmap));
vmap[e.loop->loop_var.get()] = loop_shift;
vmap[e.switch_write_var.get()] = loop_shift % two;
body = Substitute(body, vmap);
body = AttrStmt::make(buffer, attr::double_buffer_write, 1, body);
body = IfThenElse::make(loop_shift < e.loop->extent, body);
return body;
}
// Storage entry for those who need double buffering.
struct StorageEntry {
// The size of the buffer
Expr size;
// The loop the double buffering is attached to.
const For* loop{nullptr};
// The switch variable for writing.
VarExpr switch_write_var;
// The switch variable for reading.
Expr switch_read_var;
// The storage scope.
std::string scope;
};
// Whether to split the loop.
bool split_loop_;
// Whether we are inside double buffer scope.
bool in_double_buffer_scope_{false};
// The current loop nest.
std::vector<const For*> loop_nest_;
// The allocation statements to be hoisted outside the loop.
std::unordered_map<const For*, std::vector<Stmt> > loop_allocs_;
// The statements to be prepended before the loop.
std::unordered_map<const For*, std::vector<Stmt> > loop_pre_;
// The double buffer information of each buffer variable.
std::unordered_map<const Variable*, StorageEntry> dbuffer_info_;
};
Stmt InjectDoubleBuffer(Stmt stmt, bool split_loop) {
return DoubleBufferInjector(split_loop).Inject(stmt);
}
} // namespace ir
} // namespace tvm
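To make the rewrite concrete, here is a pure-Python (non-TVM) illustration of the loop structure the injector produces when split_loop is off: the allocation is doubled, the prologue for iteration 0 is hoisted before the loop, the producer for iteration i+1 writes the (i+1) % 2 half under a guard, and consumer reads index the i % 2 half. The names and the `+ 1` consumer are illustrative:

```python
def double_buffered_copy(A, n, m):
    """Pure-Python sketch of the ping-pong indexing generated by the pass."""
    B = [0.0] * (2 * m)          # allocation doubled: extents become [2, m]
    C = [0.0] * (n * m)
    # hoisted prologue: producer body with the loop variable set to 0
    for j in range(m):
        B[0 * m + j] = A[0 * m + j]
    for i in range(n):
        # producer guarded by (i + 1 < n), writing the ((i + 1) % 2) half
        if i + 1 < n:
            for j in range(m):
                B[(i + 1) % 2 * m + j] = A[(i + 1) * m + j]
        # consumer reads the (i % 2) half filled on the previous iteration
        for j in range(m):
            C[i * m + j] = B[i % 2 * m + j] + 1
    return C

A = list(range(8))
assert double_buffered_copy(A, n=2, m=4) == [a + 1 for a in A]
```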
...@@ -74,6 +74,24 @@ void StorageAccessVisitor::Visit_(const AttrStmt* op) {
storage_scope_[buf] =
StorageScope::make(op->value.as<StringImm>()->value);
IRVisitor::Visit_(op);
} else if (op->attr_key == attr::double_buffer_write) {
CHECK(double_buffer_write_ == nullptr);
double_buffer_write_ = op->node.as<Variable>();
scope_.push_back(std::vector<StmtEntry>());
IRVisitor::Visit_(op);
StmtEntry s;
s.stmt = op;
s.access = Summarize(std::move(scope_.back()), nullptr);
scope_.pop_back();
if (!s.access.empty()) {
for (AccessEntry& e : s.access) {
if (e.type == kWrite && e.buffer.get() == double_buffer_write_) {
e.double_buffer_write = true;
}
}
scope_.back().emplace_back(std::move(s));
}
double_buffer_write_ = nullptr;
} else if (op->attr_key == attr::coproc_scope) {
IterVar iv(op->node.node_);
env_threads_.push_back(iv);
......
...@@ -45,6 +45,8 @@ class StorageAccessVisitor : public IRVisitor {
AccessType type;
/*! \brief The storage scope */
StorageScope scope;
/*! \brief Whether the access is a double buffer write */
bool double_buffer_write{false};
};
/*! \brief Access pattern about a single statement */
struct StmtEntry {
...@@ -116,6 +118,8 @@ class StorageAccessVisitor : public IRVisitor {
bool in_device_env_{false};
// Whether we are inside condition.
int condition_counter_{0};
// The current double buffer write scope.
const Variable* double_buffer_write_{nullptr};
// the current free stmt entry.
StmtEntry curr_stmt_;
// The involving threads
......
...@@ -4,6 +4,7 @@
*/
#include <tvm/ir.h>
#include <tvm/expr.h>
#include <tvm/operation.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_operator.h>
#include <tvm/ir_pass.h>
...@@ -53,6 +54,18 @@ class StorageFlattener : public IRMutator {
if (op->attr_key == attr::realize_scope) {
storage_scope_[op->node.get()] = op->value.as<StringImm>()->value;
return this->Mutate(op->body);
} else if (op->attr_key == attr::double_buffer_scope) {
Operation func(op->node.node_);
Stmt body = Mutate(op->body);
for (int i = 0; i < func->num_outputs(); ++i) {
TensorKey key{func, i};
auto it = buf_map_.find(key);
CHECK(it != buf_map_.end())
<< "Cannot find allocated buffer for " << key.f;
body = AttrStmt::make(
it->second.buffer->data, op->attr_key, op->value, body);
}
return body;
} else if (op->attr_key == attr::thread_extent) {
IterVar iv(op->node.node_);
ThreadScope ts = ThreadScope::make(iv->thread_tag);
......
...@@ -34,13 +34,10 @@ class ThreadSyncPlanner : public StorageAccessVisitor {
// Unsynced reads and writes
std::vector<AccessEntry> reads;
std::vector<AccessEntry> writes;
// if it is a loop, rotate two times to consider effect of loop.
size_t max_seq = seq.size();
if (loop != nullptr) max_seq *= 2;
// simulation based approach to find dependencies
for (size_t i = 0; i < seq.size(); ++i) {
const StmtEntry& s = seq[i];
// check if sync before statement is needed.
bool sync_before_stmt = (syncs_inserted_.count(s.stmt) != 0);
// Apply the syncs added already.
...@@ -50,11 +47,11 @@ class ThreadSyncPlanner : public StorageAccessVisitor {
}
for (const AccessEntry& acc : s.access) {
if (acc.type == kRead) {
if (FindConflict(writes, acc, false)) {
sync_before_stmt = true; break;
}
} else if (acc.type == kWrite) {
if (FindConflict(reads, acc, false)) {
sync_before_stmt = true; break;
}
} else if (acc.type == kSync) {
...@@ -81,6 +78,33 @@ class ThreadSyncPlanner : public StorageAccessVisitor {
syncs_inserted_.insert(s.stmt);
}
}
if (loop != nullptr) {
for (size_t i = 0; i < seq.size(); ++i) {
const StmtEntry& s = seq[i];
if (syncs_inserted_.count(s.stmt) != 0) break;
if (reads.empty() && writes.empty()) break;
bool sync_before_stmt = false;
for (const AccessEntry& acc : s.access) {
if (acc.type == kRead) {
if (FindConflict(writes, acc, true)) {
sync_before_stmt = true; break;
}
} else if (acc.type == kWrite) {
if (FindConflict(reads, acc, true)) {
sync_before_stmt = true; break;
}
} else if (acc.type == kSync) {
reads.clear(); writes.clear();
}
}
if (sync_before_stmt) {
CHECK_EQ(condition_counter(), 0)
<< "Cannot insert syncs inside condition";
syncs_inserted_.insert(s.stmt);
break;
}
}
}
// return the exposed entries, remove unnecessary ones.
int sync_count = 0;
// head entries are before the first sync, tail entries are after the last sync
...@@ -117,13 +141,20 @@ class ThreadSyncPlanner : public StorageAccessVisitor {
}
}
head.insert(head.end(), tail.begin(), tail.end());
if (loop != nullptr) {
// clear double buffer flag after a loop is finished.
for (AccessEntry& e : head) {
e.double_buffer_write = false;
}
}
return head;
}
private:
// find conflicting entry in vec.
bool FindConflict(const std::vector<AccessEntry>& vec,
const AccessEntry& e,
bool loop_carry) {
for (const AccessEntry& x : vec) {
if (x.buffer.same_as(e.buffer)) {
// Assumes no race between threads
...@@ -134,6 +165,9 @@ class ThreadSyncPlanner : public StorageAccessVisitor {
if (Equal(e.touched.point_value(),
x.touched.point_value())) continue;
}
if (x.double_buffer_write &&
e.type == kRead &&
!loop_carry) continue;
return true;
}
}
......
...@@ -385,6 +385,13 @@ Stage& Stage::storage_align(IterVar axis, int factor, int offset) {
return *this;
}
Stage& Stage::double_buffer() {
StageNode *self = operator->();
CHECK(!self->is_output) << "Cannot apply double buffer on output";
self->double_buffer = true;
return *this;
}
Stage CopyStage(const Stage& s) {
std::shared_ptr<StageNode> n =
std::make_shared<StageNode>(*s.operator->());
......
...@@ -27,6 +27,10 @@ Stmt MakePipeline(const Stage& s,
if (producer.defined()) {
producer = ProducerConsumer::make(s->op, true, producer);
}
if (s->double_buffer) {
producer = AttrStmt::make(
s->op, ir::attr::double_buffer_scope, 1, producer);
}
Stmt pipeline = producer;
if (consumer.defined() && !is_no_op(consumer)) {
...@@ -170,7 +174,8 @@ class SchedulePostProc : public IRMutator {
thread_extent_scope_.erase(op->node.get());
return ret;
}
} else if (op->attr_key == ir::attr::realize_scope ||
op->attr_key == ir::attr::double_buffer_scope) {
auto it = replace_op_.find(op->node.get());
if (it != replace_op_.end()) {
if (it->second.defined()) {
......
...@@ -47,7 +47,8 @@ def test_gemm():
s[CC].compute_at(s[C], tx)
s[AA].compute_at(s[CC], k)
s[BB].compute_at(s[CC], k)
s[AA].double_buffer()
s[BB].double_buffer()
ty, xi = s[AA].split(s[AA].op.axis[0], nparts=num_thread)
tx, xi = s[AA].split(xi, nparts=num_thread)
s[AA].bind(ty, thread_y)
...@@ -84,10 +85,10 @@ def test_gemm():
np.testing.assert_allclose(
c.asnumpy(), np.dot(a_np, b_np.T), rtol=1e-5)
check_device("nvptx -mcpu=sm_20")
check_device("metal") check_device("metal")
check_device("opencl") check_device("opencl")
check_device("cuda") check_device("cuda")
#check_device("nvptx -mcpu=sm_20")
if __name__ == "__main__": if __name__ == "__main__":
test_gemm() test_gemm()
import tvm

def test_double_buffer():
    dtype = 'int64'
    n = 100
    m = 4
    tx = tvm.thread_axis("threadIdx.x")
    ib = tvm.ir_builder.create()
    A = ib.pointer("float32", name="A")
    C = ib.pointer("float32", name="C")
    ib.scope_attr(tx, "thread_extent", 1)
    with ib.for_range(0, n) as i:
        B = ib.allocate("float32", m, name="B", scope="shared")
        with ib.new_scope():
            ib.scope_attr(B.asnode(), "double_buffer_scope", 1)
            with ib.for_range(0, m) as j:
                B[j] = A[i * 4 + j]
        with ib.for_range(0, m) as j:
            C[j] = B[j] + 1
    stmt = ib.get()
    stmt = tvm.ir_pass.InjectDoubleBuffer(stmt, True)
    stmt = tvm.ir_pass.Simplify(stmt)
    assert isinstance(stmt.body.body, tvm.stmt.Allocate)
    assert stmt.body.body.extents[0].value == 2
    f = tvm.ir_pass.MakeAPI(stmt, "db", [A.asnode(), C.asnode()], 2, True)
    f = tvm.ir_pass.ThreadSync(f, "shared")
    count = [0]
    def count_sync(op):
        if isinstance(op, tvm.expr.Call) and op.name == "tvm_storage_sync":
            count[0] += 1
    tvm.ir_pass.PostOrderVisit(f.body, count_sync)
    assert count[0] == 2

if __name__ == "__main__":
    test_double_buffer()
...@@ -96,6 +96,8 @@ def test_gemm():
s[BB].bind(ty, thread_y)
s[BB].bind(tx, thread_x)
s[BB].vectorize(xi)
s[AA].double_buffer()
s[BB].double_buffer()
# correctness
def check_device(device):
if not tvm.module.enabled(device):
......