Commit 989dda89 by Tianqi Chen Committed by GitHub

[PASS] Simplify dependency of StorageRewrite (#291)

parent 657498a3
......@@ -13,14 +13,16 @@
#include <unordered_set>
#include <unordered_map>
#include "./ir_util.h"
#include "./storage_access.h"
#include "../arithmetic/compute_expr.h"
#include "../runtime/thread_storage_scope.h"
namespace tvm {
namespace ir {
using namespace storage;
using runtime::StorageScope;
// Find a linear pattern of storage acess
// Used for liveness analysis.
// Composite scopes(loop/thread_launch/IfThen) is represented by two points:
// before_scope -> scope_body -> after_scope
//
......@@ -33,8 +35,18 @@ using namespace storage;
// The storage need to be kept alive between allocate and last access.
// The free point is only inserted at the same scope of allocate.
//
class StorageAccessPatternFinder final : public IRVisitor {
class LinearAccessPatternFinder final : public IRVisitor {
public:
/*! \brief record the touch hist of statment. */
struct StmtEntry {
// The statment
const Node* stmt;
// Scope used for allocation.
StorageScope alloc_scope;
// The buffer variables this statment touched.
std::vector<const Variable*> touched;
};
// Get linear access pattern.
std::vector<StmtEntry> GetLinearSeq(const Stmt& s) {
this->Visit(s);
......@@ -49,8 +61,8 @@ class StorageAccessPatternFinder final : public IRVisitor {
alloc_scope_level_[buf] = level;
StmtEntry e;
e.stmt = op;
e.access.emplace_back(
AccessEntry(buf, Expr(), kAlloc, GetScope(buf)));
e.alloc_scope = GetScope(buf);
e.touched.push_back(buf);
linear_seq_.emplace_back(std::move(e));
IRVisitor::Visit_(op);
}
......@@ -62,12 +74,11 @@ class StorageAccessPatternFinder final : public IRVisitor {
const Variable* buf = op->buffer_var.get();
auto it = alloc_scope_level_.find(buf);
if (it != alloc_scope_level_.end()) {
scope_[it->second].access.emplace_back(
AccessEntry(buf, op->index, kWrite, GetScope(buf)));
scope_[it->second].touched.push_back(buf);
}
StmtEntry e = scope_.back();
scope_.pop_back();
if (e.access.size() != 0) {
if (e.touched.size() != 0) {
e.stmt = op;
linear_seq_.push_back(e);
}
......@@ -78,7 +89,7 @@ class StorageAccessPatternFinder final : public IRVisitor {
IRVisitor::Visit_(op);
StmtEntry e = scope_.back();
scope_.pop_back();
if (e.access.size() != 0) {
if (e.touched.size() != 0) {
e.stmt = op;
linear_seq_.push_back(e);
}
......@@ -91,8 +102,7 @@ class StorageAccessPatternFinder final : public IRVisitor {
if (it != alloc_scope_level_.end()) {
CHECK_LT(it->second, scope_.size())
<< "Load memory in places other than store.";
scope_[it->second].access.emplace_back(
AccessEntry(buf, op->index, kRead, GetScope(buf)));
scope_[it->second].touched.push_back(buf);
}
}
void Visit_(const Call* op) final {
......@@ -108,8 +118,7 @@ class StorageAccessPatternFinder final : public IRVisitor {
auto it = alloc_scope_level_.find(buf);
if (it != alloc_scope_level_.end()) {
CHECK_LT(it->second, scope_.size()) << " buf=" << buf->name_hint;
scope_[it->second].access.emplace_back(
AccessEntry(buf, Expr(), kOpaque, GetScope(buf)));
scope_[it->second].touched.push_back(buf);
}
}
template<typename T>
......@@ -121,7 +130,7 @@ class StorageAccessPatternFinder final : public IRVisitor {
linear_seq_.push_back(e);
IRVisitor::Visit_(op);
// after scope.
e.access = std::move(scope_.back().access);
e.touched = std::move(scope_.back().touched);
scope_.pop_back();
linear_seq_.push_back(e);
}
......@@ -178,9 +187,11 @@ class StorageAccessPatternFinder final : public IRVisitor {
// Planner to plan and rewrite memory allocation.
class StoragePlanRewriter : public IRMutator {
public:
using StmtEntry = LinearAccessPatternFinder::StmtEntry;
Stmt Rewrite(Stmt stmt) {
std::vector<StmtEntry> seq =
StorageAccessPatternFinder().GetLinearSeq(stmt);
LinearAccessPatternFinder().GetLinearSeq(stmt);
this->FindFreeLocation(seq);
this->PlanMemory(seq);
this->PrepareNewAlloc();
......@@ -442,10 +453,10 @@ class StoragePlanRewriter : public IRMutator {
std::unordered_set<const Variable*> touched;
for (size_t i = seq.size(); i != 0; --i) {
const StmtEntry& s = seq[i - 1];
for (const AccessEntry& e : s.access) {
if (!touched.count(e.buffer)) {
touched.insert(e.buffer);
free_loc_[i - 1].push_back(e.buffer);
for (const Variable* buffer : s.touched) {
if (!touched.count(buffer)) {
touched.insert(buffer);
free_loc_[i - 1].push_back(buffer);
}
}
}
......@@ -474,7 +485,7 @@ class StoragePlanRewriter : public IRMutator {
}
} else if (s.stmt->is_type<Allocate>()) {
const auto* op = static_cast<const Allocate*>(s.stmt);
StorageEntry* e = this->FindAlloc(op, s.access[0].scope);
StorageEntry* e = this->FindAlloc(op, s.alloc_scope);
e->allocs.emplace_back(op);
alloc_map_[op->buffer_var.get()] = e;
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment