Commit 2ff74317 by Tianqi Chen Committed by GitHub

[PASS] StorageRewrite Fold Inplace op storage when possible (#759)

* [PASS] StorageRewrite Fold Inplace op storage when possible

* update comment to fix typos
parent 9d6dbe34
...@@ -153,6 +153,12 @@ constexpr const char* coproc_uop_scope = "coproc_uop_scope"; ...@@ -153,6 +153,12 @@ constexpr const char* coproc_uop_scope = "coproc_uop_scope";
/*! \brief Mark the scope as volatile access for certain handle. */ /*! \brief Mark the scope as volatile access for certain handle. */
constexpr const char* volatile_scope = "volatile_scope"; constexpr const char* volatile_scope = "volatile_scope";
/*! /*!
* \brief Mark the scope as generated by extern primitive.
* such scope can contain arbitrary ir program and we need to be careful
* when make certain assumptions about the structure of the program.
*/
constexpr const char* extern_scope = "extern_scope";
/*!
* \brief Mark the scope as when computation start to happen * \brief Mark the scope as when computation start to happen
* This can hint some code generator to create a new function for compute. * This can hint some code generator to create a new function for compute.
*/ */
......
...@@ -130,7 +130,7 @@ Stmt ExternOpNode::BuildProvide( ...@@ -130,7 +130,7 @@ Stmt ExternOpNode::BuildProvide(
const Stage& stage, const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map) const { const std::unordered_map<IterVar, Range>& dom_map) const {
CHECK_EQ(stage->op.operator->(), this); CHECK_EQ(stage->op.operator->(), this);
Stmt ret = this->body; Stmt ret = AttrStmt::make(make_zero(Int(32)), attr::extern_scope, 0, this->body);
auto f_push_bind = [&ret](Buffer buffer, Tensor tensor) { auto f_push_bind = [&ret](Buffer buffer, Tensor tensor) {
Array<NodeRef> bind_spec; Array<NodeRef> bind_spec;
Array<Expr> tuple; Array<Expr> tuple;
......
...@@ -41,27 +41,32 @@ class LinearAccessPatternFinder final : public IRVisitor { ...@@ -41,27 +41,32 @@ class LinearAccessPatternFinder final : public IRVisitor {
struct StmtEntry { struct StmtEntry {
// The statment // The statment
const Node* stmt; const Node* stmt;
// Scope used for allocation. // The index in the linear_seq_ to point to end of the nested scope.
StorageScope alloc_scope; // This is only set to non-zero if stmt is a nested scope.
// if offset > 0, means this is the begin, the end entry is current_index + offset
// if offset < 0, means this is the end, the begin entry is current_index + offset
int64_t scope_pair_offset{0};
// The buffer variables this statment touched. // The buffer variables this statment touched.
std::vector<const Variable*> touched; std::vector<const Variable*> touched;
}; };
// The scope of each allocation
struct AllocEntry {
// Scope used for allocation.
StorageScope storage_scope;
// scope level
size_t level{0};
// allocation stmt
const Allocate* alloc{nullptr};
};
// Get linear access pattern.
std::vector<StmtEntry> GetLinearSeq(const Stmt& s) {
this->Visit(s);
return std::move(linear_seq_);
}
void Visit_(const Allocate* op) final { void Visit_(const Allocate* op) final {
size_t level = scope_.size(); size_t level = scope_.size();
const Variable* buf = op->buffer_var.get(); const Variable* buf = op->buffer_var.get();
CHECK(!alloc_scope_level_.count(buf)); auto it = alloc_info_.find(buf);
alloc_scope_level_[buf] = level; CHECK(it != alloc_info_.end());
StmtEntry e; CHECK(it->second.alloc == nullptr);
e.stmt = op; it->second.alloc = op;
e.alloc_scope = GetScope(buf); it->second.level = level;
e.touched.push_back(buf);
linear_seq_.emplace_back(std::move(e));
IRVisitor::Visit_(op); IRVisitor::Visit_(op);
} }
void Visit_(const Store* op) final { void Visit_(const Store* op) final {
...@@ -70,9 +75,10 @@ class LinearAccessPatternFinder final : public IRVisitor { ...@@ -70,9 +75,10 @@ class LinearAccessPatternFinder final : public IRVisitor {
IRVisitor::Visit_(op); IRVisitor::Visit_(op);
// Add write access. // Add write access.
const Variable* buf = op->buffer_var.get(); const Variable* buf = op->buffer_var.get();
auto it = alloc_scope_level_.find(buf); auto it = alloc_info_.find(buf);
if (it != alloc_scope_level_.end()) { if (it != alloc_info_.end() && it->second.alloc) {
scope_[it->second].touched.push_back(buf); CHECK_LT(it->second.level, scope_.size());
scope_[it->second.level].touched.push_back(buf);
} }
StmtEntry e = scope_.back(); StmtEntry e = scope_.back();
scope_.pop_back(); scope_.pop_back();
...@@ -96,11 +102,11 @@ class LinearAccessPatternFinder final : public IRVisitor { ...@@ -96,11 +102,11 @@ class LinearAccessPatternFinder final : public IRVisitor {
// Add write access. // Add write access.
IRVisitor::Visit_(op); IRVisitor::Visit_(op);
const Variable* buf = op->buffer_var.get(); const Variable* buf = op->buffer_var.get();
auto it = alloc_scope_level_.find(buf); auto it = alloc_info_.find(buf);
if (it != alloc_scope_level_.end()) { if (it != alloc_info_.end() && it->second.alloc) {
CHECK_LT(it->second, scope_.size()) CHECK_LT(it->second.level, scope_.size())
<< "Load memory in places other than store."; << "Load memory in places other than store.";
scope_[it->second].touched.push_back(buf); scope_[it->second.level].touched.push_back(buf);
} }
} }
void Visit_(const Call* op) final { void Visit_(const Call* op) final {
...@@ -113,10 +119,11 @@ class LinearAccessPatternFinder final : public IRVisitor { ...@@ -113,10 +119,11 @@ class LinearAccessPatternFinder final : public IRVisitor {
} }
void Visit_(const Variable* buf) final { void Visit_(const Variable* buf) final {
// Directly reference to the variable count as a read. // Directly reference to the variable count as a read.
auto it = alloc_scope_level_.find(buf); auto it = alloc_info_.find(buf);
if (it != alloc_scope_level_.end()) { if (it != alloc_info_.end() && it->second.alloc) {
CHECK_LT(it->second, scope_.size()) << " buf=" << buf->name_hint; CHECK_LT(it->second.level, scope_.size())
scope_[it->second].touched.push_back(buf); << " buf=" << buf->name_hint;
scope_[it->second.level].touched.push_back(buf);
} }
} }
template<typename T> template<typename T>
...@@ -124,13 +131,20 @@ class LinearAccessPatternFinder final : public IRVisitor { ...@@ -124,13 +131,20 @@ class LinearAccessPatternFinder final : public IRVisitor {
scope_.push_back(StmtEntry()); scope_.push_back(StmtEntry());
StmtEntry e; StmtEntry e;
e.stmt = op; e.stmt = op;
int64_t begin_index = static_cast<int64_t>(linear_seq_.size());
// before scope. // before scope.
linear_seq_.push_back(e); linear_seq_.push_back(e);
IRVisitor::Visit_(op); IRVisitor::Visit_(op);
// after scope. // after scope.
e.touched = std::move(scope_.back().touched); e.touched = std::move(scope_.back().touched);
scope_.pop_back(); scope_.pop_back();
int64_t end_index = static_cast<int64_t>(linear_seq_.size());
CHECK_GT(end_index, begin_index);
e.scope_pair_offset = begin_index - end_index;
linear_seq_.push_back(e); linear_seq_.push_back(e);
// record the pointer to end index.
CHECK_NE(end_index, 0U);
linear_seq_[begin_index].scope_pair_offset = end_index - begin_index;
} }
void Visit_(const AttrStmt* op) final { void Visit_(const AttrStmt* op) final {
// Only record the outer most thread extent. // Only record the outer most thread extent.
...@@ -138,9 +152,11 @@ class LinearAccessPatternFinder final : public IRVisitor { ...@@ -138,9 +152,11 @@ class LinearAccessPatternFinder final : public IRVisitor {
in_thread_env_ = true; in_thread_env_ = true;
VisitNewScope(op); VisitNewScope(op);
in_thread_env_ = false; in_thread_env_ = false;
} else if (op->attr_key == attr::extern_scope) {
VisitNewScope(op);
} else if (op->attr_key == attr::storage_scope) { } else if (op->attr_key == attr::storage_scope) {
const Variable* buf = op->node.as<Variable>(); const Variable* buf = op->node.as<Variable>();
storage_scope_[buf] = alloc_info_[buf].storage_scope =
StorageScope::make(op->value.as<StringImm>()->value); StorageScope::make(op->value.as<StringImm>()->value);
IRVisitor::Visit_(op); IRVisitor::Visit_(op);
} else { } else {
...@@ -155,36 +171,156 @@ class LinearAccessPatternFinder final : public IRVisitor { ...@@ -155,36 +171,156 @@ class LinearAccessPatternFinder final : public IRVisitor {
VisitNewScope(op); VisitNewScope(op);
} }
// linearized access sequence.
std::vector<StmtEntry> linear_seq_;
// The storage scope of each buffer
std::unordered_map<const Variable*, AllocEntry> alloc_info_;
private: private:
// Get storage scope of buffer.
StorageScope GetScope(const Variable* buf) const {
auto it = storage_scope_.find(buf);
CHECK(it != storage_scope_.end());
return it->second;
}
// Whether already in thread env. // Whether already in thread env.
bool in_thread_env_{false}; bool in_thread_env_{false};
// linearized access sequence.
std::vector<StmtEntry> linear_seq_;
// The scope stack. // The scope stack.
std::vector<StmtEntry> scope_; std::vector<StmtEntry> scope_;
// The storage scope of each buffer };
std::unordered_map<const Variable*, StorageScope> storage_scope_;
// buffer -> allocated scope level in the IR. // Verify if the statement can be run safely via inplace fashion
std::unordered_map<const Variable*, size_t> alloc_scope_level_; //
// Detect pattern: dst[index] = f(src[index])
//
// WARNING: the current detection algorithm cannot handle the case
// when a location in an array is written multiple times
//
// For example, the following program will pass the check,
// but we cannot make A and B to be the same array.
//
// A[0] = B[0] + 1
// A[0] = B[0] + 1
//
// The high level code generator needs to ensure that the generated
// code only write each location of the target array once.
//
// This is the case with IR generated by the current compute schedule.
// We explicitly return false if we find there is an extern block
// which can be arbitrary IR.
//
// Neve-the-less, inplace detector should be used with care in mind.
// We may also consider introduce a condition checker that checks
// if every index only visited once for an absolute sufficient condition.
//
// The code after inplace transformation is no longer idempotent.
//
class InplaceOpVerifier : public IRVisitor {
public:
bool Check(const Node* stmt,
const Variable* dst,
const Variable* src) {
dst_ = dst;
src_ = src;
result_ = true;
if (stmt->is_type<AttrStmt>()) {
Visit_(static_cast<const AttrStmt*>(stmt));
} else if (stmt->is_type<For>()) {
Visit_(static_cast<const For*>(stmt));
} else if (stmt->is_type<IfThenElse>()) {
Visit_(static_cast<const IfThenElse*>(stmt));
} else if (stmt->is_type<Store>()) {
Visit_(static_cast<const Store*>(stmt));
} else {
return false;
}
return result_;
}
using IRVisitor::Visit_;
void Visit(const NodeRef& e) final {
if (!result_) return;
IRVisitor::Visit(e);
}
void Visit_(const Variable* op) final {
// assume all opaque access is unsafe
if (op == dst_ || op == src_) {
result_ = false; return;
}
}
void Visit_(const Store* op) final {
++mem_nest_;
this->Visit(op->index);
--mem_nest_;
if (op->buffer_var.get() == dst_) {
store_ = op;
this->Visit(op->value);
this->Visit(op->predicate);
store_ = nullptr;
} else {
this->Visit(op->value);
this->Visit(op->predicate);
}
}
void Visit_(const AttrStmt* op) final {
// always reject extern code
if (op->attr_key == attr::extern_scope ||
op->attr_key == attr::volatile_scope) {
result_ = false; return;
}
IRVisitor::Visit_(op);
}
void Visit_(const Load* op) final {
const Variable* buf = op->buffer_var.get();
// cannot read from dst_ (no reduction)
if (buf == dst_) {
result_ = false; return;
}
// do not allow indirect memory load
if (mem_nest_ != 0) {
result_ = false; return;
}
if (src_ == buf) {
if (store_ == nullptr ||
store_->value.type() != op->type ||
!ir::Equal(store_->index, op->index)) {
result_ = false; return;
}
}
++mem_nest_;
IRVisitor::Visit_(op);
--mem_nest_;
}
private:
// result of the check
bool result_{true};
// destination memory
const Variable* dst_;
// source variable
const Variable* src_;
// counter of load,
// it is not safe to inplace when there is nested load like A[B[i]]
int mem_nest_{0};
// The current store to be inspected
const Store* store_{nullptr};
}; };
// Planner to plan and rewrite memory allocation. // Planner to plan and rewrite memory allocation.
class StoragePlanRewriter : public IRMutator { class StoragePlanRewriter : public IRMutator {
public: public:
using StmtEntry = LinearAccessPatternFinder::StmtEntry; using StmtEntry = LinearAccessPatternFinder::StmtEntry;
using AllocEntry = LinearAccessPatternFinder::AllocEntry;
Stmt Rewrite(Stmt stmt) { Stmt Rewrite(Stmt stmt, bool detect_inplace) {
std::vector<StmtEntry> seq = detect_inplace_ = detect_inplace;
LinearAccessPatternFinder().GetLinearSeq(stmt); // plan the rewrite
this->FindFreeLocation(seq); LinearAccessPatternFinder finder;
this->PlanMemory(seq); finder.Visit(stmt);
this->LivenessAnalysis(finder.linear_seq_);
this->PlanMemory(finder.linear_seq_, finder.alloc_info_);
this->PrepareNewAlloc(); this->PrepareNewAlloc();
// start rewrite
stmt = this->Mutate(stmt); stmt = this->Mutate(stmt);
if (attach_map_.count(nullptr)) { if (attach_map_.count(nullptr)) {
std::vector<Stmt> nest; std::vector<Stmt> nest;
...@@ -308,7 +444,6 @@ class StoragePlanRewriter : public IRMutator { ...@@ -308,7 +444,6 @@ class StoragePlanRewriter : public IRMutator {
} }
private: private:
// Alllocate entry of node.
struct StorageEntry { struct StorageEntry {
// The scope that this alloc attaches after // The scope that this alloc attaches after
// For shared/local memory it is beginning of the thread extent. // For shared/local memory it is beginning of the thread extent.
...@@ -332,6 +467,16 @@ class StoragePlanRewriter : public IRMutator { ...@@ -332,6 +467,16 @@ class StoragePlanRewriter : public IRMutator {
// the address becomes alloc_var + sizeof(elem_type) * elem_offset; // the address becomes alloc_var + sizeof(elem_type) * elem_offset;
uint64_t elem_offset{0}; uint64_t elem_offset{0};
}; };
// Alllocate entry of node.
// Event entry in liveness analysis
struct EventEntry {
// variables we generate
std::vector<const Variable*> gen;
// variables we kill
std::vector<const Variable*> kill;
};
Stmt MakeAttach(const std::vector<StorageEntry*>& svec, Stmt MakeAttach(const std::vector<StorageEntry*>& svec,
Stmt body) { Stmt body) {
std::vector<Stmt> nest; std::vector<Stmt> nest;
...@@ -461,16 +606,29 @@ class StoragePlanRewriter : public IRMutator { ...@@ -461,16 +606,29 @@ class StoragePlanRewriter : public IRMutator {
<< "Allocation exceed bound of memory tag " << e->scope.to_string(); << "Allocation exceed bound of memory tag " << e->scope.to_string();
} }
} }
// Find the free location of each varaible. // Liveness analysis to find gen and kill point of each variable.
// Just do a reverse linear scan. void LivenessAnalysis(const std::vector<StmtEntry>& seq) {
void FindFreeLocation(const std::vector<StmtEntry>& seq) { // find kill point, do a reverse linear scan.
std::unordered_set<const Variable*> touched; std::unordered_set<const Variable*> touched;
for (size_t i = seq.size(); i != 0; --i) { for (size_t i = seq.size(); i != 0; --i) {
const StmtEntry& s = seq[i - 1]; const StmtEntry& s = seq[i - 1];
for (const Variable* buffer : s.touched) { for (const Variable* buffer : s.touched) {
if (!touched.count(buffer)) { if (!touched.count(buffer)) {
touched.insert(buffer); touched.insert(buffer);
free_loc_[i - 1].push_back(buffer); event_map_[s.stmt].kill.push_back(buffer);
}
}
}
// find gen point, do forward scan
touched.clear();
for (size_t i = 0; i < seq.size(); ++i) {
int64_t offset = seq[i].scope_pair_offset;
if (offset < 0) continue;
const StmtEntry& s = seq[i + offset];
for (const Variable* buffer : s.touched) {
if (!touched.count(buffer)) {
touched.insert(buffer);
event_map_[s.stmt].gen.push_back(buffer);
} }
} }
} }
...@@ -500,14 +658,66 @@ class StoragePlanRewriter : public IRMutator { ...@@ -500,14 +658,66 @@ class StoragePlanRewriter : public IRMutator {
} }
// Memory plan algorithm // Memory plan algorithm
void PlanMemory(const std::vector<StmtEntry>& seq) { void PlanMemory(const std::vector<StmtEntry>& seq,
const std::unordered_map<const Variable*, AllocEntry>& alloc_info) {
std::unordered_set<const Variable*> inplace_flag;
for (size_t i = 0; i < seq.size(); ++i) { for (size_t i = 0; i < seq.size(); ++i) {
const StmtEntry& s = seq[i]; const StmtEntry& s = seq[i];
auto it = event_map_.find(seq[i].stmt);
// scope_pair_offset >= 0 means it is either
// - leaf stmt(offset = 0)
// - beginning of scope(offset < 0)
// In both cases, we need to handle the gen event correctly
if (it != event_map_.end() && seq[i].scope_pair_offset >= 0) {
// Inplace operation detection
// specially handle this
bool detect_inplace = detect_inplace_ && (it->second.gen.size() <= 2);
for (const Variable* var : it->second.gen) {
CHECK(alloc_info.count(var));
const AllocEntry& ae = alloc_info.at(var);
StorageEntry* dst_entry = nullptr;
// inplace detection
if (detect_inplace) {
for (const Variable* src : it->second.kill) {
if (!inplace_flag.count(src) && alloc_map_.count(src)) {
InplaceOpVerifier visitor;
StorageEntry* src_entry = alloc_map_.at(src);
if (src_entry->scope == ae.storage_scope &&
src_entry->attach_scope_ == thread_scope_ &&
src_entry->elem_type == ae.alloc->type.element_of() &&
visitor.Check(s.stmt, var, src)) {
uint64_t const_nbits = static_cast<uint64_t>(
ae.alloc->constant_allocation_size() *
ae.alloc->type.bits() *
ae.alloc->type.lanes());
if (src_entry->const_nbits == const_nbits) {
// successfully inplace
dst_entry = src_entry;
inplace_flag.insert(src);
}
}
}
}
}
if (dst_entry == nullptr) {
dst_entry = FindAlloc(ae.alloc, thread_scope_, ae.storage_scope);
}
dst_entry->allocs.emplace_back(ae.alloc);
alloc_map_[var] = dst_entry;
}
}
// enter/exit new scope
if (s.stmt->is_type<AttrStmt>()) { if (s.stmt->is_type<AttrStmt>()) {
const auto* op = static_cast<const AttrStmt*>(s.stmt); const auto* op = static_cast<const AttrStmt*>(s.stmt);
CHECK(op->attr_key == attr::thread_extent || if (op->attr_key == attr::thread_extent ||
op->attr_key == attr::pragma_scope); op->attr_key == attr::pragma_scope) {
PlanNewScope(op); PlanNewScope(op);
} else {
CHECK(op->attr_key == attr::extern_scope);
}
} else if (s.stmt->is_type<For>()) { } else if (s.stmt->is_type<For>()) {
const auto* op = static_cast<const For*>(s.stmt); const auto* op = static_cast<const For*>(s.stmt);
if (op->for_type == ForType::Parallel) { if (op->for_type == ForType::Parallel) {
...@@ -515,16 +725,17 @@ class StoragePlanRewriter : public IRMutator { ...@@ -515,16 +725,17 @@ class StoragePlanRewriter : public IRMutator {
PlanNewScope(op); PlanNewScope(op);
} }
} }
} else if (s.stmt->is_type<Allocate>()) {
const auto* op = static_cast<const Allocate*>(s.stmt);
StorageEntry* e = this->FindAlloc(op, thread_scope_, s.alloc_scope);
e->allocs.emplace_back(op);
alloc_map_[op->buffer_var.get()] = e;
} }
// free list // scope_pair_offset <= 0 means it is either
if (free_loc_.count(i)) { // - leaf stmt(offset = 0)
for (const Variable* var : free_loc_.at(i)) { // - end of scope(offset < 0)
this->Free(var); // In both cases, we need to handle the kill event correctly
if (it != event_map_.end() && seq[i].scope_pair_offset <= 0) {
for (const Variable* var : it->second.kill) {
// skip space which are already replaced by inplace
if (!inplace_flag.count(var)) {
this->Free(var);
}
} }
} }
} }
...@@ -534,6 +745,7 @@ class StoragePlanRewriter : public IRMutator { ...@@ -534,6 +745,7 @@ class StoragePlanRewriter : public IRMutator {
const Node* attach_scope, const Node* attach_scope,
const StorageScope& scope, const StorageScope& scope,
size_t const_nbits) { size_t const_nbits) {
CHECK(op != nullptr);
// Re-use not successful, allocate a new buffer. // Re-use not successful, allocate a new buffer.
std::unique_ptr<StorageEntry> entry(new StorageEntry()); std::unique_ptr<StorageEntry> entry(new StorageEntry());
entry->attach_scope_ = attach_scope; entry->attach_scope_ = attach_scope;
...@@ -544,9 +756,11 @@ class StoragePlanRewriter : public IRMutator { ...@@ -544,9 +756,11 @@ class StoragePlanRewriter : public IRMutator {
alloc_vec_.emplace_back(std::move(entry)); alloc_vec_.emplace_back(std::move(entry));
return e; return e;
} }
StorageEntry* FindAlloc(const Allocate* op, StorageEntry* FindAlloc(const Allocate* op,
const Node* attach_scope, const Node* attach_scope,
const StorageScope& scope) { const StorageScope& scope) {
CHECK(op != nullptr);
// skip plan for local variable, // skip plan for local variable,
// compiler can do a better job with register allocation. // compiler can do a better job with register allocation.
const uint64_t match_range = 16; const uint64_t match_range = 16;
...@@ -603,6 +817,7 @@ class StoragePlanRewriter : public IRMutator { ...@@ -603,6 +817,7 @@ class StoragePlanRewriter : public IRMutator {
auto it = alloc_map_.find(var); auto it = alloc_map_.find(var);
CHECK(it != alloc_map_.end()); CHECK(it != alloc_map_.end());
StorageEntry* e = it->second; StorageEntry* e = it->second;
CHECK_NE(e->allocs.size(), 0U);
// Disable sharing of local memory. // Disable sharing of local memory.
if (e->scope.rank > 1 || e->allocs[0]->type.is_handle()) return; if (e->scope.rank > 1 || e->allocs[0]->type.is_handle()) return;
// disable reuse of small arrays // disable reuse of small arrays
...@@ -616,17 +831,18 @@ class StoragePlanRewriter : public IRMutator { ...@@ -616,17 +831,18 @@ class StoragePlanRewriter : public IRMutator {
} }
// thread scope. // thread scope.
const Node* thread_scope_{nullptr}; const Node* thread_scope_{nullptr};
// whether enable inplace detection.
bool detect_inplace_{false};
// Locations of free ops. // Locations of free ops.
std::unordered_map<size_t, std::unordered_map<const Node*, EventEntry> event_map_;
std::vector<const Variable*> > free_loc_;
// The allocation attach map
std::unordered_map<const Node*, std::vector<StorageEntry*> > attach_map_;
// The allocation assign map
std::unordered_map<const Variable*, StorageEntry*> alloc_map_;
// constant size free map. // constant size free map.
std::multimap<uint64_t, StorageEntry*> const_free_map_; std::multimap<uint64_t, StorageEntry*> const_free_map_;
// symbolic free list, for non constant items. // symbolic free list, for non constant items.
std::list<StorageEntry*> sym_free_list_; std::list<StorageEntry*> sym_free_list_;
// The allocation attach map
std::unordered_map<const Node*, std::vector<StorageEntry*> > attach_map_;
// The allocation assign map
std::unordered_map<const Variable*, StorageEntry*> alloc_map_;
// The allocations // The allocations
std::vector<std::unique_ptr<StorageEntry> > alloc_vec_; std::vector<std::unique_ptr<StorageEntry> > alloc_vec_;
}; };
...@@ -693,7 +909,7 @@ class VectorAllocRewriter : public IRMutator { ...@@ -693,7 +909,7 @@ class VectorAllocRewriter : public IRMutator {
Stmt StorageRewrite(Stmt stmt) { Stmt StorageRewrite(Stmt stmt) {
stmt = StoragePlanRewriter().Rewrite(stmt); stmt = StoragePlanRewriter().Rewrite(stmt, true);
return VectorAllocRewriter().Mutate(stmt); return VectorAllocRewriter().Mutate(stmt);
} }
} // namespace ir } // namespace ir
......
...@@ -15,6 +15,7 @@ def test_add_pipeline(): ...@@ -15,6 +15,7 @@ def test_add_pipeline():
C = tvm.extern(A.shape, [A], extern_generator, name='C') C = tvm.extern(A.shape, [A], extern_generator, name='C')
s = tvm.create_schedule(C.op) s = tvm.create_schedule(C.op)
print(tvm.lower(s, [A, C], simple_mode=True))
def check_llvm(): def check_llvm():
if not tvm.module.enabled("llvm"): if not tvm.module.enabled("llvm"):
......
...@@ -19,14 +19,39 @@ def test_storage_share(): ...@@ -19,14 +19,39 @@ def test_storage_share():
stmt = tvm.ir_pass.CanonicalSimplify(stmt) stmt = tvm.ir_pass.CanonicalSimplify(stmt)
stmt = tvm.ir_pass.Simplify(stmt) stmt = tvm.ir_pass.Simplify(stmt)
stmt = tvm.ir_pass.StorageRewrite(stmt) stmt = tvm.ir_pass.StorageRewrite(stmt)
# verify only have two allocations. # verify only have one allocations.
# verify that the data is folded. # verify inplace folding works
num_alloc = [0]
def verify(n):
if isinstance(n, tvm.stmt.Allocate):
num_alloc[0] += 1
tvm.ir_pass.PostOrderVisit(stmt, verify)
assert num_alloc[0] == 1
def test_inplace_rule():
m = 10
A = tvm.placeholder((m,), name='A')
A0 = tvm.compute((m,), lambda i: A[i], name='A0')
A1 = tvm.compute((m,), lambda i: A[i] + 1, name='A1')
AA = tvm.compute((m,), lambda i: A0[i] + A1[i] + A1[0], name='AA')
B = tvm.compute((m,), lambda i: AA[i] + 1, name='B')
s = tvm.create_schedule(B.op)
bounds = tvm.schedule.InferBound(s)
assert isinstance(bounds, tvm.container.Map)
stmt = tvm.schedule.ScheduleOps(s, bounds)
Ab = tvm.decl_buffer(A.shape, A.dtype, name='A')
Bb = tvm.decl_buffer(B.shape, B.dtype, name='B')
stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, B: Bb}, 64)
stmt = tvm.ir_pass.CanonicalSimplify(stmt)
stmt = tvm.ir_pass.Simplify(stmt)
stmt = tvm.ir_pass.StorageRewrite(stmt)
# verify only have one allocations.
# verify inplace folding works
num_alloc = [0] num_alloc = [0]
def verify(n): def verify(n):
if isinstance(n, tvm.stmt.Allocate): if isinstance(n, tvm.stmt.Allocate):
num_alloc[0] += 1 num_alloc[0] += 1
elif isinstance(n, tvm.stmt.Store):
assert n.buffer_var != n.value.a.buffer_var
tvm.ir_pass.PostOrderVisit(stmt, verify) tvm.ir_pass.PostOrderVisit(stmt, verify)
assert num_alloc[0] == 2 assert num_alloc[0] == 2
...@@ -38,7 +63,7 @@ def test_storage_combine(): ...@@ -38,7 +63,7 @@ def test_storage_combine():
B = A B = A
stages = [] stages = []
for t in range(num_stage): for t in range(num_stage):
B = tvm.compute((n, ), lambda i: B[i] + (t+1), name='A%d' % t) B = tvm.compute((n, ), lambda i: B[i] + B[0] + (t+1), name='A%d' % t)
stages.append(B) stages.append(B)
s = tvm.create_schedule(B.op) s = tvm.create_schedule(B.op)
...@@ -121,12 +146,14 @@ def test_parallel_alloc(): ...@@ -121,12 +146,14 @@ def test_parallel_alloc():
A[j] = A[j] + 2 A[j] = A[j] + 2
body = ib.get() body = ib.get()
body = tvm.ir_pass.StorageRewrite(body) body = tvm.ir_pass.StorageRewrite(body)
assert(isinstance(body.body.body.body.body, tvm.stmt.Allocate)) assert(isinstance(body.body.body.body.body, tvm.stmt.Allocate))
if __name__ == "__main__": if __name__ == "__main__":
test_inplace_rule()
test_storage_share()
test_parallel_alloc() test_parallel_alloc()
test_storage_combine() test_storage_combine()
test_storage_share_gpu() test_storage_share_gpu()
test_storage_share()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment