Commit eefcfe19 authored by Tianqi Chen, committed by GitHub

[PASS] Refactor thread storage sync to a common visitor (#296)

* [PASS] Refactor thread storage sync to a common visitor

* Fix the sync scope check behavior
parent 6bc0ae12
@@ -196,6 +196,17 @@ IntSet EvalSet(Expr e,
  */
 IntSet EvalSet(Range r,
                const Map<IterVar, IntSet>& dom_map);
+/*!
+ * \brief Find a symbolic integer set that contains the union over
+ *  all the possible conditional values in dom_map.
+ *
+ * \param s The initial set.
+ * \param dom_map The domain of each variable.
+ * \return An integer set that can cover all the possible values.
+ */
+IntSet EvalSet(IntSet s,
+               const std::unordered_map<const Variable*, IntSet>& dom_map);
 /*!
  * \brief Same as EvalSet, but takes unordered_map
  *
......
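The overload added above is what lets the new access analysis relax a per-iteration access set over a whole loop domain. Here is a minimal sketch of that use, assuming the TVM headers of this commit; RelaxOverLoop is an illustrative name, not part of the change, and the calls it makes (IntSet::vector, IntSet::range, EvalSet) mirror StorageAccessVisitor::Visit_(const For*) in storage_access.cc below.

// Sketch only: relax the access set of A[i] over the loop range of i.
#include <unordered_map>
#include <tvm/expr.h>
#include <tvm/arithmetic.h>

using namespace tvm;

arith::IntSet RelaxOverLoop(Expr index,      // e.g. the index expression of A[i]
                            Var loop_var,    // the loop variable i
                            Expr loop_min, Expr loop_extent) {
  // The set touched by one iteration (covers the lanes if index is vectorized).
  arith::IntSet touched = arith::IntSet::vector(index);
  // Map the loop variable to its full iteration domain [loop_min, loop_min + loop_extent).
  std::unordered_map<const Variable*, arith::IntSet> relax_map;
  relax_map[loop_var.get()] = arith::IntSet::range(
      Range::make_by_min_extent(loop_min, loop_extent));
  // The new overload returns a set covering every value `touched` can take
  // as loop_var ranges over its domain.
  return arith::EvalSet(touched, relax_map);
}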
@@ -319,7 +319,7 @@ Array<LoweredFunc> SplitHostDevice(LoweredFunc func);
  * \param stmt The stmt to be transformed.
  * \param storage_scope The storage scope considered.
  */
-LoweredFunc StorageSync(LoweredFunc stmt, std::string storage_scope);
+LoweredFunc ThreadSync(LoweredFunc stmt, std::string storage_scope);
 /*!
  * \brief Lower cross thread allreduce in the stmt.
......
@@ -299,8 +299,8 @@ def build(sch,
     for func in flist:
         if func.func_type == container.LoweredFunc.MixedFunc:
             if BuildConfig.current.detect_global_barrier:
-                func = ir_pass.StorageSync(func, "global")
-            func = ir_pass.StorageSync(func, "shared")
+                func = ir_pass.ThreadSync(func, "global")
+            func = ir_pass.ThreadSync(func, "shared")
             warp_size = 32 if target == "cuda" else 1
             func = ir_pass.LowerThreadAllreduce(func, warp_size)
             fsplits = [s for s in ir_pass.SplitHostDevice(func)]
......
@@ -89,7 +89,7 @@ REGISTER_PASS4(Inline);
 REGISTER_PASS3(StorageFlatten);
 REGISTER_PASS1(VectorizeLoop);
 REGISTER_PASS4(UnrollLoop);
-REGISTER_PASS2(StorageSync);
+REGISTER_PASS2(ThreadSync);
 REGISTER_PASS5(MakeAPI);
 REGISTER_PASS2(BindDeviceType);
 REGISTER_PASS1(SplitHostDevice);
......
@@ -565,10 +565,22 @@ IntSet EvalSet(Range r,
   IntSet ext_set = m.Eval(r->extent).cover_interval();
   const Interval& ei = ext_set.as<IntervalSet>()->i;
   if (!ei.has_upper_bound()) return IntSet::everything();
-  ext_set = IntervalSet::make(0, ComputeExpr<Sub>(ei.max, 1));
+  ext_set = IntervalSet::make(make_zero(ei.max.type()), ComputeExpr<Sub>(ei.max, 1));
   return Combine<Add>(min_set, ext_set);
 }
 
+IntSet EvalSet(IntSet s,
+               const std::unordered_map<const Variable*, IntSet>& dom_map) {
+  IntSetEvaluator m(dom_map);
+  s = s.cover_interval();
+  const IntervalSet* s_int = s.as<IntervalSet>();
+  Expr vmax = s_int->i.has_upper_bound() ?
+      m.Eval(s_int->i.max).cover_interval().max() : s_int->i.max;
+  Expr vmin = s_int->i.has_lower_bound() ?
+      m.Eval(s_int->i.min).cover_interval().min() : s_int->i.min;
+  return IntervalSet::make(vmin, vmax);
+}
+
 class SubExprIntSetEvaluator : public IntSetEvaluator {
  public:
   explicit SubExprIntSetEvaluator(
......
/*!
 *  Copyright (c) 2017 by Contributors
 * \file storage_access.cc
 */
#include "./storage_access.h"

namespace tvm {
namespace ir {

void StorageAccessVisitor::Visit_(const Load* op) {
  const Variable* buf = op->buffer_var.as<Variable>();
  StorageScope scope = GetScope(buf);
  if (Enabled(buf, scope)) {
    CHECK(allow_append_);
    AccessEntry e;
    e.threads = env_threads();
    e.buffer = buf;
    e.dtype = op->type.element_of();
    e.touched = arith::IntSet::vector(op->index);
    e.type = kRead;
    e.scope = scope;
    curr_stmt_.access.emplace_back(std::move(e));
  }
  // traverse child
  IRVisitor::Visit_(op);
}

void StorageAccessVisitor::Visit_(const Store* op) {
  allow_append_ = true;
  CHECK_EQ(curr_stmt_.access.size(), 0U);
  curr_stmt_.stmt = op;
  const Variable* buf = op->buffer_var.as<Variable>();
  StorageScope scope = GetScope(buf);
  if (Enabled(buf, scope)) {
    AccessEntry e;
    e.threads = env_threads();
    e.buffer = buf;
    e.dtype = op->value.type().element_of();
    e.touched = arith::IntSet::vector(op->index);
    e.type = kWrite;
    e.scope = scope;
    curr_stmt_.access.emplace_back(std::move(e));
  }
  // traverse child
  IRVisitor::Visit_(op);
  // push to the scope
  scope_.back().push_back(curr_stmt_);
  // clear access entry.
  curr_stmt_.access.clear();
  allow_append_ = false;
}

void StorageAccessVisitor::Visit_(const Evaluate* op) {
  allow_append_ = true;
  CHECK_EQ(curr_stmt_.access.size(), 0U);
  curr_stmt_.stmt = op;
  IRVisitor::Visit_(op);
  // push to the scope
  if (curr_stmt_.access.size() != 0) {
    scope_.back().push_back(curr_stmt_);
    curr_stmt_.access.clear();
  }
  allow_append_ = false;
}

void StorageAccessVisitor::Visit_(const AttrStmt* op) {
  if (op->attr_key == attr::storage_scope) {
    const Variable* buf = op->node.as<Variable>();
    storage_scope_[buf] =
        StorageScope::make(op->value.as<StringImm>()->value);
    IRVisitor::Visit_(op);
  } else if (op->attr_key == attr::thread_extent) {
    IterVar iv(op->node.node_);
    env_threads_.push_back(iv);
    if (!in_device_env_) {
      in_device_env_ = true;
      scope_.push_back(std::vector<StmtEntry>());
      IRVisitor::Visit_(op);
      // no need to take the result as the thread barrier automatically syncs.
      Summarize(std::move(scope_.back()), nullptr);
      in_device_env_ = false;
      scope_.pop_back();
    } else {
      IRVisitor::Visit_(op);
    }
    env_threads_.CopyOnWrite()->data.pop_back();
  } else {
    IRVisitor::Visit_(op);
  }
}

void StorageAccessVisitor::Visit_(const For* op) {
  scope_.push_back(std::vector<StmtEntry>());
  IRVisitor::Visit_(op);
  StmtEntry s;
  s.stmt = op;
  s.access = Summarize(std::move(scope_.back()), op);
  scope_.pop_back();
  if (s.access.size() != 0) {
    // relax the touched set to contain all ranges in the loop.
    std::unordered_map<const Variable*, arith::IntSet> relax_map;
    relax_map[op->loop_var.get()] = arith::IntSet::range(
        Range::make_by_min_extent(op->min, op->extent));
    for (AccessEntry& e : s.access) {
      if (e.buffer != nullptr) {
        CHECK(e.touched.defined());
        e.touched = arith::EvalSet(e.touched, relax_map);
      }
    }
    scope_.back().emplace_back(std::move(s));
  }
}
void StorageAccessVisitor::Visit_(const IfThenElse* op) {
  ++condition_counter_;
  this->Visit(op->condition);
  scope_.push_back(std::vector<StmtEntry>());
  this->Visit(op->then_case);
  StmtEntry s;
  s.stmt = op;
  s.access = Summarize(std::move(scope_.back()), nullptr);
  scope_.pop_back();
  if (op->else_case.defined()) {
    scope_.push_back(std::vector<StmtEntry>());
    this->Visit(op->else_case);
    auto v = Summarize(std::move(scope_.back()), nullptr);
    scope_.pop_back();
    s.access.insert(s.access.end(), v.begin(), v.end());
  }
  scope_.back().emplace_back(std::move(s));
  --condition_counter_;
}
void StorageAccessVisitor::Visit_(const Call* op) {
  if (op->is_intrinsic(intrinsic::tvm_address_of)) {
    const Load* l = op->args[0].as<Load>();
    IRVisitor::Visit_(l);
  } else if (op->is_intrinsic(intrinsic::tvm_access_ptr)) {
    CHECK_EQ(op->args.size(), 5U);
    Type dtype = op->args[0].type();
    const Variable* buffer = op->args[1].as<Variable>();
    Expr offset = op->args[2];
    Expr extent = op->args[3];
    const IntImm* flag = op->args[4].as<IntImm>();
    StorageScope scope = GetScope(buffer);
    // The buffer scope.
    if (Enabled(buffer, scope)) {
      CHECK(allow_append_);
      AccessEntry e;
      e.threads = env_threads();
      e.dtype = dtype;
      e.buffer = buffer;
      e.touched = arith::IntSet::range(
          Range::make_by_min_extent(offset, extent));
      e.scope = scope;
      if (flag->value & 1) {
        e.type = kRead;
        curr_stmt_.access.emplace_back(e);
      }
      if (flag->value & 2) {
        e.type = kWrite;
        curr_stmt_.access.emplace_back(e);
      }
    }
    IRVisitor::Visit_(op);
  } else if (op->is_intrinsic(intrinsic::tvm_storage_sync)) {
    CHECK(allow_append_);
    const std::string& s = op->args[0].as<StringImm>()->value;
    if (s != "warp") {
      StorageScope scope = StorageScope::make(s);
      AccessEntry e;
      e.threads = env_threads();
      e.type = kSync;
      e.scope = scope;
      curr_stmt_.access.emplace_back(std::move(e));
    }
  } else {
    IRVisitor::Visit_(op);
  }
}

StorageScope StorageAccessVisitor::GetScope(const Variable* buf) const {
  auto it = storage_scope_.find(buf);
  StorageScope s;
  s.rank = 0;
  if (it == storage_scope_.end()) return s;
  return it->second;
}

}  // namespace ir
}  // namespace tvm
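For orientation, here is a minimal sketch of how a pass can specialize this visitor. SharedAccessCounter and num_accesses are illustrative names only; Enabled, Summarize, in_device_env and StorageScope::make come from the interface declared in storage_access.h below. The real subclass added by this commit is ThreadSyncPlanner in thread_storage_sync.cc.

// Sketch only: a trivial StorageAccessVisitor subclass that counts accesses
// to "shared" buffers made inside device (thread_extent) regions.
#include "./storage_access.h"

namespace tvm {
namespace ir {

class SharedAccessCounter : public StorageAccessVisitor {
 public:
  int num_accesses{0};

 protected:
  // Restrict the analysis to shared-memory buffers inside a device region.
  bool Enabled(const Variable* buf,
               const StorageScope& scope) const final {
    return in_device_env() && scope == StorageScope::make("shared");
  }
  // Summarize a finished scope: count its entries and expose nothing
  // for the parent scope to take care of.
  std::vector<AccessEntry> Summarize(
      std::vector<StmtEntry> seq, const For* loop) final {
    for (const StmtEntry& s : seq) {
      num_accesses += static_cast<int>(s.access.size());
    }
    return std::vector<AccessEntry>();
  }
};

}  // namespace ir
}  // namespace tvm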
@@ -15,46 +15,113 @@
 namespace tvm {
 namespace ir {
-namespace storage {
-// The storage scope.
 using runtime::StorageScope;
-/*! \brief Storage access type */
-enum AccessType {
-  kRead,
-  kWrite,
-  kOpaque,
-  kSync,
-  kAlloc
-};
-/*! \brief The access entry */
-struct AccessEntry {
-  /*! \brief The buffer variable, if any */
-  const Variable* buffer{nullptr};
-  /*! \brief The access index */
-  Expr index;
-  /*! \brief The type of access */
-  AccessType type;
-  /*! \brief The storage scope */
-  StorageScope scope;
-  // constructor
-  AccessEntry() {}
-  AccessEntry(const Variable* buffer,
-              Expr index,
-              AccessType type,
-              StorageScope scope)
-      : buffer(buffer), index(index), type(type), scope(scope) {}
-};
-/*! \brief The access info about a statement */
-struct StmtEntry {
-  /*! \brief The statement */
-  const Node* stmt;
-  /*! \brief access patterns in the statement */
-  std::vector<AccessEntry> access;
-};
-}  // namespace storage
+/*!
+ * \brief Base class of storage access analysis
+ */
+class StorageAccessVisitor : public IRVisitor {
+ public:
+  /*! \brief Storage access type */
+  enum AccessType {
+    kRead,
+    kWrite,
+    kSync,
+    kAlloc
+  };
+  /*! \brief An access entry */
+  struct AccessEntry {
+    /*! \brief The thread index that accesses this entry */
+    Array<IterVar> threads;
+    /*! \brief The buffer variable, if any */
+    const Variable* buffer{nullptr};
+    /*! \brief The access data type */
+    Type dtype;
+    /*! \brief The touched access range */
+    arith::IntSet touched;
+    /*! \brief The type of access */
+    AccessType type;
+    /*! \brief The storage scope */
+    StorageScope scope;
+  };
+  /*! \brief Access pattern about a single statement */
+  struct StmtEntry {
+    /*! \brief The statement */
+    const Node* stmt;
+    /*! \brief access patterns in the statement */
+    std::vector<AccessEntry> access;
+  };
+  // override visitor pattern
+  void Visit_(const Load* op) final;
+  void Visit_(const Store* op) final;
+  void Visit_(const Evaluate* op) final;
+  void Visit_(const AttrStmt* op) final;
+  void Visit_(const For* op) final;
+  void Visit_(const IfThenElse* op) final;
+  void Visit_(const Call* op) final;
+
+ protected:
+  StorageAccessVisitor() {
+    scope_.push_back(std::vector<StmtEntry>());
+  }
+  /*! \return number of conditions in the current scope. */
+  int condition_counter() const {
+    return condition_counter_;
+  }
+  /*! \return whether we are in device environment. */
+  bool in_device_env() const {
+    return in_device_env_;
+  }
+  /*! \return environment threads */
+  const Array<IterVar>& env_threads() const {
+    return env_threads_;
+  }
+  /*!
+   * \brief Whether we need to analyze the buffer in the current scope.
+   * \param buffer The buffer to be checked.
+   * \param scope The scope of the buffer.
+   * \return Whether the analysis of buffer is enabled.
+   */
+  virtual bool Enabled(const Variable* buffer,
+                       const StorageScope& scope) const {
+    return true;
+  }
+  /*!
+   * \brief Summarize the sequence of operations into parent.
+   *
+   *  Insert synchronization if necessary and remove unnecessary
+   *  memory accesses which are already synced.
+   *
+   * \param seq The sequence of the access operations.
+   * \param loop Pass loop node if it is a loop, otherwise nullptr.
+   * \return The summarized sequence that represents accesses that
+   *  the parent should take care of to synchronize.
+   */
+  virtual std::vector<AccessEntry> Summarize(
+      std::vector<StmtEntry> seq, const For* loop) = 0;
+  /*!
+   * \brief Get the scope of the buffer array.
+   * \return The scope of the final buffer array.
+   */
+  StorageScope GetScope(const Variable* buf) const;
+
+ private:
+  // whether access appending is enabled.
+  bool allow_append_{false};
+  // Whether we are in device environment
+  bool in_device_env_{false};
+  // Whether we are inside condition.
+  int condition_counter_{0};
+  // the current free stmt entry.
+  StmtEntry curr_stmt_;
+  // The involving threads
+  Array<IterVar> env_threads_;
+  // access scope
+  std::vector<std::vector<StmtEntry> > scope_;
+  // The storage scope of each buffer
+  std::unordered_map<const Variable*, StorageScope> storage_scope_;
+};
 }  // namespace ir
 }  // namespace tvm
 #endif  // TVM_PASS_STORAGE_ACCESS_H_
@@ -15,142 +15,29 @@
 namespace tvm {
 namespace ir {
-using namespace storage;
-class StorageSyncPlanner : public IRVisitor {
+class ThreadSyncPlanner : public StorageAccessVisitor {
  public:
-  explicit StorageSyncPlanner(StorageScope sync_scope)
+  explicit ThreadSyncPlanner(StorageScope sync_scope)
       : sync_scope_(sync_scope) {}
-  void Visit_(const Load* op) final {
-    if (!in_device_env_) return;
-    CHECK(allow_load_);
-    const Variable* buf = op->buffer_var.as<Variable>();
-    StorageScope s = GetScope(buf);
-    if (s == sync_scope_) {
-      curr_stmt_.access.emplace_back(
-          AccessEntry(buf, op->index, kRead, s));
-    }
-  }
-  void Visit_(const Store* op) final {
-    if (!in_device_env_) return;
-    allow_load_ = true;
-    CHECK_EQ(curr_stmt_.access.size(), 0U);
-    curr_stmt_.stmt = op;
-    const Variable* buf = op->buffer_var.as<Variable>();
-    StorageScope s = GetScope(buf);
-    if (s == sync_scope_) {
-      curr_stmt_.access.emplace_back(
-          AccessEntry(buf, op->index, kWrite, s));
-    }
-    // traverse child
-    IRVisitor::Visit_(op);
-    // push to the scope
-    scope_.back().push_back(curr_stmt_);
-    // clear access entry.
-    curr_stmt_.access.clear();
-    allow_load_ = false;
-  }
-  void Visit_(const Evaluate* op) final {
-    if (!in_device_env_) return;
-    if (const Call* call = op->value.as<Call>()) {
-      if (call->is_intrinsic(intrinsic::tvm_storage_sync)) {
-        const std::string& s = call->args[0].as<StringImm>()->value;
-        if (s != "warp") {
-          StorageScope scope = StorageScope::make(s);
-          if (scope.rank <= sync_scope_.rank) {
-            CHECK_EQ(curr_stmt_.access.size(), 0U);
-            curr_stmt_.access.emplace_back(
-                AccessEntry(nullptr, Expr(), kSync, scope));
-            // push to the scope
-            scope_.back().push_back(curr_stmt_);
-            curr_stmt_.access.clear();
-          }
-        }
-      }
-    }
-  }
-  void Visit_(const AttrStmt* op) final {
-    if (op->attr_key == attr::storage_scope) {
-      const Variable* buf = op->node.as<Variable>();
-      storage_scope_[buf] =
-          StorageScope::make(op->value.as<StringImm>()->value);
-      IRVisitor::Visit_(op);
-    } else if (op->attr_key == attr::thread_extent && !in_device_env_) {
-      in_device_env_ = true;
-      CHECK_EQ(scope_.size(), 0U);
-      scope_.push_back(std::vector<StmtEntry>());
-      IRVisitor::Visit_(op);
-      this->PlanSync(false);
-      in_device_env_ = false;
-      scope_.pop_back();
-    } else {
-      IRVisitor::Visit_(op);
-    }
-  }
-  void Visit_(const For* op) final {
-    if (in_device_env_) {
-      scope_.push_back(std::vector<StmtEntry>());
-      IRVisitor::Visit_(op);
-      StmtEntry s; s.stmt = op;
-      s.access = PlanSync(true);
-      scope_.pop_back();
-      scope_.back().emplace_back(std::move(s));
-    } else {
-      IRVisitor::Visit_(op);
-    }
-  }
-  void Visit_(const Call* op) final {
-    if (op->is_intrinsic(intrinsic::tvm_address_of)) {
-      const Load* l = op->args[0].as<Load>();
-      IRVisitor::Visit_(l);
-    } else {
-      IRVisitor::Visit_(op);
-    }
-  }
-  void Visit_(const IfThenElse* op) final {
-    if (in_device_env_) {
-      ++condition_counter_;
-      this->Visit(op->condition);
-      scope_.push_back(std::vector<StmtEntry>());
-      this->Visit(op->then_case);
-      StmtEntry s; s.stmt = op;
-      s.access = PlanSync(false);
-      scope_.pop_back();
-      if (op->else_case.defined()) {
-        scope_.push_back(std::vector<StmtEntry>());
-        auto v = PlanSync(false);
-        scope_.pop_back();
-        s.access.insert(s.access.end(), v.begin(), v.end());
-      }
-      scope_.back().emplace_back(std::move(s));
-      --condition_counter_;
-    } else {
-      IRVisitor::Visit_(op);
-    }
-  }
   // The syncs inserted before each statement
   std::unordered_set<const Node*> syncs_inserted_;
- private:
-  // Get storage scope of buffer.
-  StorageScope GetScope(const Variable* buf) const {
-    auto it = storage_scope_.find(buf);
-    StorageScope s; s.rank = 0;
-    if (it == storage_scope_.end()) return s;
-    return it->second;
+ protected:
+  bool Enabled(const Variable* buf,
+               const StorageScope& scope) const final {
+    return in_device_env() && scope == sync_scope_;
   }
   // Plan the sync
-  std::vector<AccessEntry> PlanSync(bool is_loop) {
-    // unsynced reads and writes
+  std::vector<AccessEntry> Summarize(
+      std::vector<StmtEntry> seq, const For* loop) final {
+    // Unsynced reads and writes
     std::vector<AccessEntry> reads;
     std::vector<AccessEntry> writes;
-    const std::vector<StmtEntry>& seq = scope_.back();
     // if it is a loop, rotate two times to consider effect of loop.
     size_t max_seq = seq.size();
-    if (is_loop) max_seq *= 2;
+    if (loop != 0) max_seq *= 2;
     // simulation based approach to find dependencies
     for (size_t i = 0; i < max_seq; ++i) {
       const StmtEntry& s = seq[i % seq.size()];
@@ -189,7 +76,7 @@ class StorageSyncPlanner : public IRVisitor {
         }
       }
       if (sync_before_stmt) {
-        CHECK_EQ(condition_counter_, 0)
+        CHECK_EQ(condition_counter(), 0)
             << "Cannot insert syncs inside condition";
         syncs_inserted_.insert(s.stmt);
       }
@@ -198,12 +85,17 @@ class StorageSyncPlanner : public IRVisitor {
     int sync_count = 0;
     // head are before first sync, tail are after last sync
     std::vector<AccessEntry> head, tail;
+    AccessEntry esync;
+    esync.threads = this->env_threads();
+    esync.type = kSync;
+    esync.scope = sync_scope_;
     for (const StmtEntry& s : seq) {
       if (syncs_inserted_.count(s.stmt)) {
         if (sync_count != 0) {
           tail.clear();
         } else {
-          head.push_back(AccessEntry(nullptr, Expr(), kSync, sync_scope_));
+          head.push_back(esync);
         }
         ++sync_count;
       }
@@ -212,7 +104,7 @@ class StorageSyncPlanner : public IRVisitor {
         if (sync_count != 0) {
           tail.clear();
         } else {
-          head.push_back(AccessEntry(nullptr, Expr(), kSync, sync_scope_));
+          head.push_back(esync);
         }
         ++sync_count;
       } else {
@@ -227,35 +119,36 @@ class StorageSyncPlanner : public IRVisitor {
     head.insert(head.end(), tail.begin(), tail.end());
     return head;
   }
+ private:
   // find conflicting entry in vec.
   bool FindConflict(const std::vector<AccessEntry>& vec,
                     const AccessEntry& e) {
     for (const AccessEntry& x : vec) {
-      if (x.buffer == e.buffer &&
-          !e.index.same_as(x.index)) return true;
+      if (x.buffer == e.buffer) {
+        // Assumes no race between threads
+        // Same index value means no conflicts
+        // TODO(tqchen) more standard set based testing.
+        if (e.touched.is_single_point() &&
+            x.touched.is_single_point()) {
+          if (Equal(e.touched.point_value(),
+                    x.touched.point_value())) continue;
+        }
+        return true;
+      }
     }
     return false;
   }
-  // Whether we are inside condition.
-  int condition_counter_{0};
-  // whether load is enabled.
-  bool in_device_env_{false};
-  // whether load is enabled.
-  bool allow_load_{false};
-  // the current free stmt entry.
-  StmtEntry curr_stmt_;
-  // access scope
-  std::vector<std::vector<StmtEntry> > scope_;
-  // The storage scope of each buffer
-  std::unordered_map<const Variable*, StorageScope> storage_scope_;
-  // The sync scope we care about.
+ private:
+  // synchronization scope
   StorageScope sync_scope_;
 };
-class StorageSyncInserter : public IRMutator {
+class ThreadSyncInserter : public IRMutator {
  public:
-  StorageSyncInserter(StorageScope sync_scope,
+  ThreadSyncInserter(StorageScope sync_scope,
                      const std::unordered_set<const Node*>& syncs)
       : sync_scope_(sync_scope), syncs_(syncs) {}
   Stmt Mutate(Stmt stmt) final {
@@ -389,17 +282,17 @@ class StorageSyncInserter : public IRMutator {
   Expr is_lead_;
 };
-Stmt StorageSync(Stmt stmt, std::string storage_scope) {
+Stmt ThreadSync(Stmt stmt, std::string storage_scope) {
   StorageScope sync_scope = StorageScope::make(storage_scope);
-  StorageSyncPlanner planner(sync_scope);
+  ThreadSyncPlanner planner(sync_scope);
   planner.Visit(stmt);
-  return StorageSyncInserter(sync_scope, planner.syncs_inserted_).Mutate(stmt);
+  return ThreadSyncInserter(sync_scope, planner.syncs_inserted_).Mutate(stmt);
 }
-LoweredFunc StorageSync(LoweredFunc f, std::string storage_scope) {
+LoweredFunc ThreadSync(LoweredFunc f, std::string storage_scope) {
   CHECK_NE(f->func_type, kHostFunc);
   auto n = std::make_shared<LoweredFuncNode>(*f.operator->());
-  n->body = StorageSync(f->body, storage_scope);
+  n->body = ThreadSync(f->body, storage_scope);
   return LoweredFunc(n);
 }
......
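For reference, the renamed entry points are exposed to Python through REGISTER_PASS2(ThreadSync) above; from C++ the same pass can be applied directly to a lowered device function. A small sketch follows: ApplyThreadSync is a hypothetical helper, while the ThreadSync calls and the global-then-shared ordering mirror the build_module.py change earlier in this commit.

// Sketch only: drive the renamed pass on a lowered device function.
#include <tvm/ir_pass.h>
#include <tvm/lowered_func.h>

tvm::LoweredFunc ApplyThreadSync(tvm::LoweredFunc f,
                                 bool detect_global_barrier) {
  if (detect_global_barrier) {
    // Plan and insert synchronization for the global memory scope.
    f = tvm::ir::ThreadSync(f, "global");
  }
  // Always plan and insert synchronization for the shared memory scope.
  return tvm::ir::ThreadSync(f, "shared");
}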
@@ -23,7 +23,7 @@ def test_storage_sync():
     f = tvm.ir_pass.MakeAPI(stmt, "test", [Ab, A2b], 0, True)
     flist = tvm.ir_pass.SplitHostDevice(f)
     f = flist[1]
-    f = tvm.ir_pass.StorageSync(f, "shared")
+    f = tvm.ir_pass.ThreadSync(f, "shared")
     print(f.body)

 if __name__ == "__main__":
......