Commit 989dda89 by Tianqi Chen Committed by GitHub

[PASS] Simplify dependency of StorageRewrite (#291)

parent 657498a3
...@@ -13,14 +13,16 @@ ...@@ -13,14 +13,16 @@
#include <unordered_set> #include <unordered_set>
#include <unordered_map> #include <unordered_map>
#include "./ir_util.h" #include "./ir_util.h"
#include "./storage_access.h"
#include "../arithmetic/compute_expr.h" #include "../arithmetic/compute_expr.h"
#include "../runtime/thread_storage_scope.h"
namespace tvm { namespace tvm {
namespace ir { namespace ir {
using namespace storage; using runtime::StorageScope;
// Find a linear pattern of storage acess // Find a linear pattern of storage acess
// Used for liveness analysis.
// Composite scopes(loop/thread_launch/IfThen) is represented by two points: // Composite scopes(loop/thread_launch/IfThen) is represented by two points:
// before_scope -> scope_body -> after_scope // before_scope -> scope_body -> after_scope
// //
...@@ -33,8 +35,18 @@ using namespace storage; ...@@ -33,8 +35,18 @@ using namespace storage;
// The storage need to be kept alive between allocate and last access. // The storage need to be kept alive between allocate and last access.
// The free point is only inserted at the same scope of allocate. // The free point is only inserted at the same scope of allocate.
// //
class StorageAccessPatternFinder final : public IRVisitor { class LinearAccessPatternFinder final : public IRVisitor {
public: public:
/*! \brief record the touch hist of statment. */
struct StmtEntry {
// The statment
const Node* stmt;
// Scope used for allocation.
StorageScope alloc_scope;
// The buffer variables this statment touched.
std::vector<const Variable*> touched;
};
// Get linear access pattern. // Get linear access pattern.
std::vector<StmtEntry> GetLinearSeq(const Stmt& s) { std::vector<StmtEntry> GetLinearSeq(const Stmt& s) {
this->Visit(s); this->Visit(s);
...@@ -49,8 +61,8 @@ class StorageAccessPatternFinder final : public IRVisitor { ...@@ -49,8 +61,8 @@ class StorageAccessPatternFinder final : public IRVisitor {
alloc_scope_level_[buf] = level; alloc_scope_level_[buf] = level;
StmtEntry e; StmtEntry e;
e.stmt = op; e.stmt = op;
e.access.emplace_back( e.alloc_scope = GetScope(buf);
AccessEntry(buf, Expr(), kAlloc, GetScope(buf))); e.touched.push_back(buf);
linear_seq_.emplace_back(std::move(e)); linear_seq_.emplace_back(std::move(e));
IRVisitor::Visit_(op); IRVisitor::Visit_(op);
} }
...@@ -62,12 +74,11 @@ class StorageAccessPatternFinder final : public IRVisitor { ...@@ -62,12 +74,11 @@ class StorageAccessPatternFinder final : public IRVisitor {
const Variable* buf = op->buffer_var.get(); const Variable* buf = op->buffer_var.get();
auto it = alloc_scope_level_.find(buf); auto it = alloc_scope_level_.find(buf);
if (it != alloc_scope_level_.end()) { if (it != alloc_scope_level_.end()) {
scope_[it->second].access.emplace_back( scope_[it->second].touched.push_back(buf);
AccessEntry(buf, op->index, kWrite, GetScope(buf)));
} }
StmtEntry e = scope_.back(); StmtEntry e = scope_.back();
scope_.pop_back(); scope_.pop_back();
if (e.access.size() != 0) { if (e.touched.size() != 0) {
e.stmt = op; e.stmt = op;
linear_seq_.push_back(e); linear_seq_.push_back(e);
} }
...@@ -78,7 +89,7 @@ class StorageAccessPatternFinder final : public IRVisitor { ...@@ -78,7 +89,7 @@ class StorageAccessPatternFinder final : public IRVisitor {
IRVisitor::Visit_(op); IRVisitor::Visit_(op);
StmtEntry e = scope_.back(); StmtEntry e = scope_.back();
scope_.pop_back(); scope_.pop_back();
if (e.access.size() != 0) { if (e.touched.size() != 0) {
e.stmt = op; e.stmt = op;
linear_seq_.push_back(e); linear_seq_.push_back(e);
} }
...@@ -91,8 +102,7 @@ class StorageAccessPatternFinder final : public IRVisitor { ...@@ -91,8 +102,7 @@ class StorageAccessPatternFinder final : public IRVisitor {
if (it != alloc_scope_level_.end()) { if (it != alloc_scope_level_.end()) {
CHECK_LT(it->second, scope_.size()) CHECK_LT(it->second, scope_.size())
<< "Load memory in places other than store."; << "Load memory in places other than store.";
scope_[it->second].access.emplace_back( scope_[it->second].touched.push_back(buf);
AccessEntry(buf, op->index, kRead, GetScope(buf)));
} }
} }
void Visit_(const Call* op) final { void Visit_(const Call* op) final {
...@@ -108,8 +118,7 @@ class StorageAccessPatternFinder final : public IRVisitor { ...@@ -108,8 +118,7 @@ class StorageAccessPatternFinder final : public IRVisitor {
auto it = alloc_scope_level_.find(buf); auto it = alloc_scope_level_.find(buf);
if (it != alloc_scope_level_.end()) { if (it != alloc_scope_level_.end()) {
CHECK_LT(it->second, scope_.size()) << " buf=" << buf->name_hint; CHECK_LT(it->second, scope_.size()) << " buf=" << buf->name_hint;
scope_[it->second].access.emplace_back( scope_[it->second].touched.push_back(buf);
AccessEntry(buf, Expr(), kOpaque, GetScope(buf)));
} }
} }
template<typename T> template<typename T>
...@@ -121,7 +130,7 @@ class StorageAccessPatternFinder final : public IRVisitor { ...@@ -121,7 +130,7 @@ class StorageAccessPatternFinder final : public IRVisitor {
linear_seq_.push_back(e); linear_seq_.push_back(e);
IRVisitor::Visit_(op); IRVisitor::Visit_(op);
// after scope. // after scope.
e.access = std::move(scope_.back().access); e.touched = std::move(scope_.back().touched);
scope_.pop_back(); scope_.pop_back();
linear_seq_.push_back(e); linear_seq_.push_back(e);
} }
...@@ -178,9 +187,11 @@ class StorageAccessPatternFinder final : public IRVisitor { ...@@ -178,9 +187,11 @@ class StorageAccessPatternFinder final : public IRVisitor {
// Planner to plan and rewrite memory allocation. // Planner to plan and rewrite memory allocation.
class StoragePlanRewriter : public IRMutator { class StoragePlanRewriter : public IRMutator {
public: public:
using StmtEntry = LinearAccessPatternFinder::StmtEntry;
Stmt Rewrite(Stmt stmt) { Stmt Rewrite(Stmt stmt) {
std::vector<StmtEntry> seq = std::vector<StmtEntry> seq =
StorageAccessPatternFinder().GetLinearSeq(stmt); LinearAccessPatternFinder().GetLinearSeq(stmt);
this->FindFreeLocation(seq); this->FindFreeLocation(seq);
this->PlanMemory(seq); this->PlanMemory(seq);
this->PrepareNewAlloc(); this->PrepareNewAlloc();
...@@ -442,10 +453,10 @@ class StoragePlanRewriter : public IRMutator { ...@@ -442,10 +453,10 @@ class StoragePlanRewriter : public IRMutator {
std::unordered_set<const Variable*> touched; std::unordered_set<const Variable*> touched;
for (size_t i = seq.size(); i != 0; --i) { for (size_t i = seq.size(); i != 0; --i) {
const StmtEntry& s = seq[i - 1]; const StmtEntry& s = seq[i - 1];
for (const AccessEntry& e : s.access) { for (const Variable* buffer : s.touched) {
if (!touched.count(e.buffer)) { if (!touched.count(buffer)) {
touched.insert(e.buffer); touched.insert(buffer);
free_loc_[i - 1].push_back(e.buffer); free_loc_[i - 1].push_back(buffer);
} }
} }
} }
...@@ -474,7 +485,7 @@ class StoragePlanRewriter : public IRMutator { ...@@ -474,7 +485,7 @@ class StoragePlanRewriter : public IRMutator {
} }
} else if (s.stmt->is_type<Allocate>()) { } else if (s.stmt->is_type<Allocate>()) {
const auto* op = static_cast<const Allocate*>(s.stmt); const auto* op = static_cast<const Allocate*>(s.stmt);
StorageEntry* e = this->FindAlloc(op, s.access[0].scope); StorageEntry* e = this->FindAlloc(op, s.alloc_scope);
e->allocs.emplace_back(op); e->allocs.emplace_back(op);
alloc_map_[op->buffer_var.get()] = e; alloc_map_[op->buffer_var.get()] = e;
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment