Commit e4c74668 by Tianqi Chen, committed by GitHub

[PASS] StorageRewrite, Memory optimization pass as in NNVM. (#104)

* [PASS] StorageRewrite, reuse memory pass as in NNVM.

* fix issue
parent 1245cc25
......@@ -158,14 +158,15 @@ Stmt VectorizeLoop(Stmt stmt);
Stmt InjectVirtualThread(Stmt stmt);
/*!
* \brief Lift storage allocation to relevant outpost location
*
* Only do this after vectorization and virtual thread injection completes.
* \brief Rewrite storage allocation pattern.
* Moves each allocation to the outermost possible scope.
* Tries to share space between allocations to make
* a static allocation plan when possible.
*
* \param stmt The stmt to be transformed
* \return Transformed stmt.
*/
Stmt LiftAllocate(Stmt stmt);
Stmt StorageRewrite(Stmt stmt);
/*!
* \brief partition loops in the stmt
......
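The new docstring above describes the combined behavior: allocations are hoisted to the outermost scope where they remain valid, and allocations with non-overlapping lifetimes are folded onto shared buffers so that a static allocation plan can be produced. As a minimal sketch of how the pass is driven from Python (a condensed variant of the test added later in this commit; the stage names and the single intermediate stage are illustrative only):

import tvm

m = tvm.var('m')
A = tvm.placeholder((m,), name='A')
B = tvm.compute((m,), lambda i: A[i] + 1, name='B')
C = tvm.compute((m,), lambda i: B[i] + 1, name='C')
s = tvm.create_schedule(C.op)
bounds = tvm.schedule.InferBound(s)
stmt = tvm.schedule.ScheduleOps(s, bounds)
Ab = tvm.decl_buffer(A.shape, A.dtype, name='A')
Cb = tvm.decl_buffer(C.shape, C.dtype, name='C')
stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, C: Cb})
stmt = tvm.ir_pass.CanonicalSimplify(stmt)
# The rewrite hoists the intermediate allocation for B and, once several
# chained temporaries exist (see the tests below), lets them share storage.
stmt = tvm.ir_pass.StorageRewrite(stmt)
print(stmt)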
......@@ -70,7 +70,7 @@ def lower(sch,
stmt = ir_pass.CanonicalSimplify(stmt)
stmt = ir_pass.VectorizeLoop(stmt)
stmt = ir_pass.InjectVirtualThread(stmt)
stmt = ir_pass.LiftAllocate(stmt)
stmt = ir_pass.StorageRewrite(stmt)
stmt = ir_pass.UnrollLoop(stmt, max_auto_unroll_step)
stmt = ir_pass.Simplify(stmt)
if not with_api_wrapper:
......
......@@ -68,7 +68,7 @@ REGISTER_PASS2(UnrollLoop);
REGISTER_PASS2(StorageSync);
REGISTER_PASS4(MakeAPI);
REGISTER_PASS1(SplitHostDevice);
REGISTER_PASS1(LiftAllocate);
REGISTER_PASS1(StorageRewrite);
REGISTER_PASS1(InjectVirtualThread);
REGISTER_PASS1(LoopPartition);
REGISTER_PASS1(RemoveNoOp);
......
......@@ -249,7 +249,7 @@ llvm::BasicBlock* CodeGenLLVM::CheckCallSuccess(llvm::Value* retcode) {
}
void CodeGenLLVM::AddAliasInfo(
llvm::Instruction* inst, const Variable* buffer, Expr index) {
llvm::Instruction* inst, const Variable* buffer, Expr index, Type t) {
int base = 0, width = 0;
// create meta-data for alias analysis
// Use a group of binary tree ranges.
......@@ -274,9 +274,11 @@ void CodeGenLLVM::AddAliasInfo(
}
}
llvm::MDNode* meta = md_tbaa_root_;
std::ostringstream buffer_addr;
std::ostringstream buffer_addr, buffer_type;
buffer_addr << buffer;
meta = md_builder_->createTBAAScalarTypeNode(buffer_addr.str(), meta);
buffer_type << t.element_of();
meta = md_builder_->createTBAAScalarTypeNode(buffer_type.str(), meta);
// create a tree-shape access structure.
if (width != 0) {
for (int w = 1024; w >= width; w /= 2) {
......@@ -1033,7 +1035,7 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const Load* op) {
llvm::LoadInst* inst = builder_->CreateAlignedLoad(
CreateBufferPtr(t, buf, MakeValue(op->index)),
data_layout_->getTypeAllocSize(LLVMType(t)));
AddAliasInfo(inst, op->buffer_var.get(), op->index);
AddAliasInfo(inst, op->buffer_var.get(), op->index, op->type);
return inst;
} else if (ramp && is_one(ramp->stride)) {
int alignment, native_bits;
......@@ -1053,7 +1055,7 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const Load* op) {
llvm::LoadInst* inst = builder_->CreateAlignedLoad(
builder_->CreatePointerCast(ptr, vtype), alignment);
AddAliasInfo(inst, op->buffer_var.get(),
Ramp::make(base, make_const(base.type(), 1), lanes));
Ramp::make(base, make_const(base.type(), 1), lanes), op->type);
loads.push_back(inst);
}
return CreateVecConcat(loads);
......@@ -1127,7 +1129,7 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const Load* op) {
llvm::Value* ptr = CreateBufferPtr(t.element_of(), buf, offset);
llvm::LoadInst* inst = builder_->CreateAlignedLoad(
ptr, data_layout_->getTypeAllocSize(LLVMType(t)));
AddAliasInfo(inst, op->buffer_var.get(), Expr());
AddAliasInfo(inst, op->buffer_var.get(), Expr(), op->type);
ret = builder_->CreateInsertElement(ret, inst, ConstInt32(i));
});
return ret;
......@@ -1146,7 +1148,7 @@ void CodeGenLLVM::VisitStmt_(const Store* op) {
value,
CreateBufferPtr(t, buf, MakeValue(op->index)),
data_layout_->getTypeAllocSize(value->getType()));
AddAliasInfo(inst, op->buffer_var.get(), op->index);
AddAliasInfo(inst, op->buffer_var.get(), op->index, op->value.type());
} else if (ramp && is_one(ramp->stride)) {
int alignment, native_bits;
GetAlignment(t, op->buffer_var.get(), ramp->base,
......@@ -1165,7 +1167,7 @@ void CodeGenLLVM::VisitStmt_(const Store* op) {
CreateVecSlice(value, offset, lanes),
builder_->CreatePointerCast(ptr, vtype), alignment);
AddAliasInfo(inst, op->buffer_var.get(),
Ramp::make(base, make_const(base.type(), 1), lanes));
Ramp::make(base, make_const(base.type(), 1), lanes), op->value.type());
}
} else {
Scalarize(op->index, [&](int i, llvm::Value* offset) {
......@@ -1173,7 +1175,7 @@ void CodeGenLLVM::VisitStmt_(const Store* op) {
llvm::StoreInst* inst = builder_->CreateAlignedStore(
builder_->CreateExtractElement(value, ConstInt32(i)),
ptr, data_layout_->getTypeAllocSize(LLVMType(t)));
AddAliasInfo(inst, op->buffer_var.get(), Expr());
AddAliasInfo(inst, op->buffer_var.get(), Expr(), op->value.type());
});
}
}
......
......@@ -211,7 +211,7 @@ class CodeGenLLVM :
// Add a function to set global module context
void InitGlobalContext();
// add alias information.
void AddAliasInfo(llvm::Instruction* load, const Variable* buffer, Expr index);
void AddAliasInfo(llvm::Instruction* load, const Variable* buffer, Expr index, Type type);
// The definition of local variable.
std::unordered_map<const Variable*, llvm::Value*> var_map_;
// global strings
......
/*!
* Copyright (c) 2017 by Contributors
* \file lift_allocate.cc
*/
#include <tvm/ir.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_pass.h>
#include <unordered_map>
#include "./ir_util.h"
#include "../runtime/thread_storage_scope.h"
namespace tvm {
namespace ir {
using runtime::StorageScope;
using runtime::ThreadScope;
class AllocateLifter : public IRMutator {
public:
Stmt Lift(Stmt stmt) {
stmt = this->Mutate(stmt);
StorageScope key; key.rank = 0;
stmt = MergeNest(allocs_[key], stmt);
return stmt;
}
Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
CHECK(op->attr_key != attr::virtual_thread)
<< "InjectVirtualThread before LiftStorageAlloc";
if (op->attr_key == attr::storage_scope) {
StorageScope sc = StorageScope::make(op->value.as<StringImm>()->value);
allocs_[sc].emplace_back(
AttrStmt::make(
op->node, attr::storage_scope,
op->value, Evaluate::make(0)));
storage_scope_[op->node.get()] = sc;
return this->Mutate(op->body);
} else if (op->attr_key == attr::thread_extent) {
IterVar iv(op->node.node_);
ThreadScope ts = ThreadScope::make(iv->thread_tag);
curr_thread_scope_.push_back(ts);
Stmt stmt = IRMutator::Mutate_(op, s);
curr_thread_scope_.pop_back();
op = stmt.as<AttrStmt>();
bool first_scope = true;
for (const ThreadScope& t : curr_thread_scope_) {
if (t.rank == ts.rank) first_scope = false;
}
if (first_scope) {
StorageScope key;
key.rank = ts.rank + 1;
std::vector<Stmt>& vec = allocs_[key];
if (vec.size() != 0) {
Stmt body = MergeNest(vec, op->body);
vec.clear();
return AttrStmt::make(
op->node, op->attr_key, op->value, body);
}
}
return stmt;
}
return IRMutator::Mutate_(op, s);
}
Stmt Mutate_(const For* op, const Stmt& s) final {
CHECK(op->for_type != ForType::Vectorized)
<< "VectorizeLoop before LiftStorageAlloc";
return IRMutator::Mutate_(op, s);
}
Stmt Mutate_(const Allocate* op, const Stmt& s) final {
auto it = storage_scope_.find(op->buffer_var.get());
CHECK(it != storage_scope_.end());
allocs_[it->second].emplace_back(
Allocate::make(
op->buffer_var, op->type, op->extents, op->condition,
Evaluate::make(0)));
return this->Mutate(op->body);
}
private:
// storage scope of internal allocation.
std::unordered_map<const Node*, StorageScope> storage_scope_;
// The current thread scope.
std::vector<ThreadScope> curr_thread_scope_;
// The allocations by rank
std::unordered_map<StorageScope, std::vector<Stmt> > allocs_;
};
Stmt LiftAllocate(Stmt stmt) {
return AllocateLifter().Lift(stmt);
}
} // namespace ir
} // namespace tvm
/*!
* Copyright (c) 2016 by Contributors
* \file scope.h
* \brief attribute scope data structure,
* defines attributes on current domain
*/
#ifndef TVM_PASS_SCOPE_H_
#define TVM_PASS_SCOPE_H_
#include <tvm/ir.h>
#include <unordered_map>
#include <vector>
#include <string>
namespace tvm {
namespace ir {
/*!
* \brief Attribute scope of Nodes in the IR.
* \tparam K The key type of the scope.
* \tparam V The value type of the scope.
*/
template<typename K, typename V>
class Scope {
public:
/*!
* \brief Push value to scope
* \param key the key to be pushed.
* \param v The value to be pushed.
*/
inline void Push(const K& key, V v) {
data_[key].emplace_back(v);
}
/*!
* \brief Pop value from scope.
* \param key the key to be popped
*/
inline void Pop(const K& key) {
auto& v = data_[key];
CHECK_NE(v.size(), 0U);
v.pop_back();
}
/*!
* \brief Get value from the scope
* \param key the key to fetch.
* \return The value to be fetched.
*/
inline V operator[](const K& key) const {
const auto it = data_.find(key);
CHECK(it != data_.end() && it->second.size() != 0)
<< "cannot find value in scope";
return it->second.back();
}
private:
std::unordered_map<K, std::vector<V> > data_;
};
/*! \brief Attribute key for specific attribute */
struct AttrKey {
/*! \brief The node of the attribute */
NodeRef node;
/*! \brief The type key of the attribute. */
std::string type_key;
// overload operator ==
inline bool operator==(const AttrKey& other) const {
return node == other.node && type_key == other.type_key;
}
};
} // namespace ir
} // namespace tvm
namespace std {
template <>
struct hash<::tvm::ir::AttrKey> {
std::size_t operator()(const ::tvm::ir::AttrKey& k) const {
size_t lhs = k.node.hash();
size_t rhs = std::hash<std::string>()(k.type_key);
lhs ^= rhs + 0x9e3779b9 + (lhs << 6) + (lhs >> 2);
return lhs;
}
};
} // namespace std
#endif // TVM_PASS_SCOPE_H_
/*!
* Copyright (c) 2017 by Contributors
* \file storage_access.h
* \brief Common data structure for storage access analysis.
*/
#ifndef TVM_PASS_STORAGE_ACCESS_H_
#define TVM_PASS_STORAGE_ACCESS_H_
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include <tvm/ir_visitor.h>
#include <vector>
#include <unordered_map>
#include "../runtime/thread_storage_scope.h"
namespace tvm {
namespace ir {
namespace storage {
// The storage scope.
using runtime::StorageScope;
/*! \brief Storage access type */
enum AccessType {
kRead,
kWrite,
kOpaque,
kSync,
kAlloc
};
/*! \brief The access entry */
struct AccessEntry {
/*! \brief The buffer variable, if any */
const Variable* buffer{nullptr};
/*! \brief The access index */
Expr index;
/*! \brief The type of access */
AccessType type;
/*! \brief The storage scope */
StorageScope scope;
// constructor
AccessEntry() {}
AccessEntry(const Variable* buffer,
Expr index,
AccessType type,
StorageScope scope)
: buffer(buffer), index(index), type(type), scope(scope) {}
};
/*! \brief The access info about a statement */
struct StmtEntry {
/*! \brief The statement */
const Node* stmt;
/*! \brief access patterns in the statement */
std::vector<AccessEntry> access;
};
} // namespace storage
} // namespace ir
} // namespace tvm
#endif // TVM_PASS_STORAGE_ACCESS_H_
/*!
* Copyright (c) 2017 by Contributors
* \file storage_rewrite.cc
* \brief Memory access pattern analysis and optimization.
* Re-write data access to enable memory sharing when possible.
*/
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_visitor.h>
#include <map>
#include <list>
#include <unordered_set>
#include <unordered_map>
#include "./ir_util.h"
#include "./storage_access.h"
namespace tvm {
namespace ir {
using namespace storage;
// Find a linear pattern of storage access.
// Composite scopes (loop/thread_launch/IfThenElse) are represented by two points:
// before_scope -> scope_body -> after_scope
//
// The linear_seq_ stores before_scope and after_scope.
// The accesses to the arrays are stored at the after_scope point.
//
// Define "scope" as the body of For/thread_launch/IfThenElse.
// This pass tries to detect the last point at which each buffer must be
// kept alive, under the same scope as its allocate.
// The storage needs to be kept alive between the allocate and the last access.
// The free point is only inserted at the same scope as the allocate.
//
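To make the before_scope/after_scope bookkeeping concrete, here is a small, self-contained Python toy (not part of the commit, and deliberately simplified: accesses are attributed to the innermost enclosing scope, whereas the visitor below attributes them to the scope level at which the buffer was allocated):

# Each composite scope contributes two entries to the linear sequence:
# an empty before_scope entry, and an after_scope entry that carries the
# accesses performed inside its body.
def linearize(node, seq=None, stack=None):
    seq = [] if seq is None else seq
    stack = [] if stack is None else stack
    if node["kind"] in ("for", "thread", "if"):            # composite scope
        seq.append({"stmt": node["kind"], "access": []})   # before_scope
        stack.append([])
        for child in node.get("body", []):
            linearize(child, seq, stack)
        seq.append({"stmt": node["kind"], "access": stack.pop()})  # after_scope
    else:                                                  # leaf load/store
        acc = [(node["buffer"], node["kind"])]
        if stack:
            stack[-1].extend(acc)
        seq.append({"stmt": node["kind"], "access": acc})
    return seq

loop = {"kind": "for", "body": [{"kind": "store", "buffer": "A"},
                                {"kind": "load", "buffer": "A"}]}
for entry in linearize(loop):
    print(entry)
# Prints the before_scope entry, the two leaf entries, then the after_scope
# entry carrying both accesses.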
class StorageAccessPatternFinder : public IRVisitor {
public:
// Get linear access pattern.
std::vector<StmtEntry> GetLinearSeq(const Stmt& s) {
this->Visit(s);
return std::move(linear_seq_);
}
void Visit_(const Allocate* op) final {
CHECK(!in_parallel_env_)
<< "Allocation inside parallel is not yet handled.";
size_t level = scope_.size();
const Variable* buf = op->buffer_var.get();
CHECK(!alloc_scope_level_.count(buf));
alloc_scope_level_[buf] = level;
StmtEntry e;
e.stmt = op;
e.access.emplace_back(
AccessEntry(buf, Expr(), kAlloc, GetScope(buf)));
linear_seq_.emplace_back(std::move(e));
IRVisitor::Visit_(op);
}
void Visit_(const Store* op) final {
scope_.push_back(StmtEntry());
// visit subexpr
IRVisitor::Visit_(op);
// Add write access.
const Variable* buf = op->buffer_var.get();
auto it = alloc_scope_level_.find(buf);
if (it != alloc_scope_level_.end()) {
scope_[it->second].access.emplace_back(
AccessEntry(buf, op->index, kWrite, GetScope(buf)));
}
StmtEntry e = scope_.back();
scope_.pop_back();
if (e.access.size() != 0) {
e.stmt = op;
linear_seq_.push_back(e);
}
}
void Visit_(const Load* op) final {
// Add read access.
IRVisitor::Visit_(op);
const Variable* buf = op->buffer_var.get();
auto it = alloc_scope_level_.find(buf);
if (it != alloc_scope_level_.end()) {
CHECK_LT(it->second, scope_.size())
<< "Load memory in places other than store.";
scope_[it->second].access.emplace_back(
AccessEntry(buf, op->index, kRead, GetScope(buf)));
}
}
void Visit_(const Variable* buf) final {
// A direct reference to the variable counts as an opaque access.
auto it = alloc_scope_level_.find(buf);
if (it != alloc_scope_level_.end()) {
CHECK_LT(it->second, scope_.size());
scope_[it->second].access.emplace_back(
AccessEntry(buf, Expr(), kOpaque, GetScope(buf)));
}
}
template<typename T>
void VisitNewScope(const T* op) {
scope_.push_back(StmtEntry());
StmtEntry e;
e.stmt = op;
// before scope.
linear_seq_.push_back(e);
IRVisitor::Visit_(op);
// after scope.
e.access = std::move(scope_.back().access);
scope_.pop_back();
linear_seq_.push_back(e);
}
void Visit_(const AttrStmt* op) final {
// Only record the outermost thread extent.
if (op->attr_key == attr::thread_extent && !in_thread_env_) {
in_thread_env_ = true;
VisitNewScope(op);
in_thread_env_ = false;
} else if (op->attr_key == attr::storage_scope) {
const Variable* buf = op->node.as<Variable>();
storage_scope_[buf] =
StorageScope::make(op->value.as<StringImm>()->value);
IRVisitor::Visit_(op);
} else {
IRVisitor::Visit_(op);
}
}
void Visit_(const For* op) final {
if (op->for_type == ForType::Parallel) {
bool in_par = in_parallel_env_;
in_parallel_env_ = true;
VisitNewScope(op);
in_parallel_env_ = in_par;
} else {
VisitNewScope(op);
}
}
void Visit_(const IfThenElse* op) final {
VisitNewScope(op);
}
private:
// Get storage scope of buffer.
StorageScope GetScope(const Variable* buf) const {
auto it = storage_scope_.find(buf);
CHECK(it != storage_scope_.end());
return it->second;
}
// Whether already in thread env.
bool in_thread_env_{false};
// Whether already in parallel env.
bool in_parallel_env_{false};
// linearized access sequence.
std::vector<StmtEntry> linear_seq_;
// The scope stack.
std::vector<StmtEntry> scope_;
// The storage scope of each buffer
std::unordered_map<const Variable*, StorageScope> storage_scope_;
// buffer -> allocated scope level in the IR.
std::unordered_map<const Variable*, size_t> alloc_scope_level_;
};
// Planner to plan and rewrite memory allocation.
class StoragePlanRewriter : public IRMutator {
public:
Stmt Rewrite(Stmt stmt) {
std::vector<StmtEntry> seq =
StorageAccessPatternFinder().GetLinearSeq(stmt);
this->FindFreeLocation(seq);
this->PlanMemory(seq);
this->PrepareNewAlloc();
stmt = this->Mutate(stmt);
if (attach_map_.count(nullptr)) {
std::vector<Stmt> nest;
for (StorageEntry* e : attach_map_.at(nullptr)) {
CHECK_EQ(e->scope.rank, 0);
nest.emplace_back(AttrStmt::make(
e->alloc_var, attr::storage_scope,
StringImm::make(e->scope.to_string()),
Evaluate::make(0)));
nest.push_back(e->new_alloc);
}
stmt = MergeNest(nest, stmt);
}
return stmt;
}
Stmt Mutate_(const Store* op, const Stmt& s) final {
Stmt stmt = IRMutator::Mutate_(op, s);
op = stmt.as<Store>();
auto it = alloc_map_.find(op->buffer_var.get());
if (it == alloc_map_.end()) return stmt;
return Store::make(it->second->alloc_var, op->value, op->index);
}
Expr Mutate_(const Load* op, const Expr& e) final {
Expr expr = IRMutator::Mutate_(op, e);
op = expr.as<Load>();
auto it = alloc_map_.find(op->buffer_var.get());
if (it == alloc_map_.end()) return expr;
return Load::make(op->type, it->second->alloc_var, op->index);
}
Expr Mutate_(const Variable* op, const Expr& e) final {
auto it = alloc_map_.find(op);
if (it != alloc_map_.end()) {
return it->second->alloc_var;
} else {
return e;
}
}
Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
CHECK(op->attr_key != attr::virtual_thread)
<< "InjectVirtualThread before StoragePlan";
if (op->attr_key == attr::storage_scope) {
return this->Mutate(op->body);
} else if (op->attr_key == attr::thread_extent) {
// remake all the allocation at the thread extent.
if (attach_map_.count(op)) {
std::vector<Stmt> nest;
for (StorageEntry* e : attach_map_.at(op)) {
nest.emplace_back(AttrStmt::make(
e->alloc_var, attr::storage_scope,
StringImm::make(e->scope.to_string()),
Evaluate::make(0)));
nest.push_back(e->new_alloc);
}
Stmt stmt = IRMutator::Mutate_(op, s);
op = stmt.as<AttrStmt>();
Stmt body = MergeNest(nest, op->body);
return AttrStmt::make(
op->node, op->attr_key, op->value, body);
} else {
return IRMutator::Mutate_(op, s);
}
} else if (op->attr_key == attr::volatile_scope) {
Stmt stmt = IRMutator::Mutate_(op, s);
op = stmt.as<AttrStmt>();
auto it = alloc_map_.find(op->node.as<Variable>());
if (it == alloc_map_.end()) return stmt;
return AttrStmt::make(
it->second->alloc_var, op->attr_key, op->value, op->body);
} else {
return IRMutator::Mutate_(op, s);
}
}
Stmt Mutate_(const For* op, const Stmt& s) final {
CHECK(op->for_type != ForType::Vectorized)
<< "VectorizeLoop before LiftStorageAlloc";
return IRMutator::Mutate_(op, s);
}
Stmt Mutate_(const Allocate* op, const Stmt& s) final {
return this->Mutate(op->body);
}
private:
// Allocation entry of a node.
struct StorageEntry {
// The scope that this alloc attaches after
// For shared/local memory it is beginning of the thread extent.
// for global memory it is nullptr, means beginning of everything.
const Node* attach_scope_{nullptr};
// The constant size of the buffer in bytes, only used if it is constant.
size_t const_size{0};
// The storage scope.
StorageScope scope;
// Allocs that share this entry.
std::vector<const Allocate*> allocs;
// The var expr of new allocation.
VarExpr alloc_var;
// The replacement allocation
Stmt new_alloc;
};
// Prepare the new allocations
void PrepareNewAlloc() {
for (size_t i = 0; i < alloc_vec_.size(); ++i) {
StorageEntry* e = alloc_vec_[i].get();
// Find the element type with the largest size (bytes * lanes).
Type t = e->allocs[0]->type;
for (const Allocate* op : e->allocs) {
if (op->type.bytes() * op->type.lanes() > t.bytes() * t.lanes()) {
t = op->type;
}
}
// Get the allocation size;
e->alloc_var = e->allocs[0]->buffer_var;
if (e->allocs.size() == 1) {
// simply use the original allocation.
e->new_alloc = Allocate::make(
e->alloc_var, t, e->allocs[0]->extents,
e->allocs[0]->condition, Evaluate::make(0));
} else {
// Build a merged allocation.
int alloc_unit = t.bytes() * t.lanes();
Expr combo_size;
for (const Allocate* op : e->allocs) {
// Get the size
Expr sz = op->extents[0];
for (size_t i = 1; i < op->extents.size(); ++i) {
sz = sz * op->extents[i];
}
int bytes = op->type.bytes() * op->type.lanes();
if (alloc_unit != bytes) {
sz = (sz * make_const(sz.type(), bytes) +
make_const(sz.type(), alloc_unit - 1)) /
make_const(sz.type(), alloc_unit);
}
if (combo_size.defined()) {
combo_size = max(combo_size, sz);
} else {
combo_size = sz;
}
}
combo_size = ir::Simplify(combo_size);
e->new_alloc = Allocate::make(
e->alloc_var, t, {combo_size}, const_true(),
Evaluate::make(0));
}
attach_map_[e->attach_scope_].push_back(e);
}
}
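The combined extent computed in the merge branch above is the maximum, over all allocations sharing the entry, of each allocation's total byte size rounded up to whole units of the widest element type. A worked example with illustrative numbers (not taken from the commit):

# Hypothetical buffers sharing one entry: a float32[100] and an int8[300].
allocs = [(100, 4), (300, 1)]              # (elements, bytes per element)
alloc_unit = max(b for _, b in allocs)     # widest element type: 4 bytes
combo_size = max((n * b + alloc_unit - 1) // alloc_unit for n, b in allocs)
print(combo_size)                          # max(100, 75) = 100 units of 4 bytes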
// Find the free location of each variable.
// Just do a reverse linear scan.
void FindFreeLocation(const std::vector<StmtEntry>& seq) {
std::unordered_set<const Variable*> touched;
for (size_t i = seq.size(); i != 0; --i) {
const StmtEntry& s = seq[i - 1];
for (const AccessEntry& e : s.access) {
if (!touched.count(e.buffer)) {
touched.insert(e.buffer);
free_loc_[i - 1].push_back(e.buffer);
}
}
}
}
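The reverse scan works because, when the linear sequence is walked backwards, the first statement in which a buffer appears is its last use in forward order, which is exactly where its simulated free should be recorded. A compact Python restatement (illustrative; sequence entries are reduced to lists of touched buffers):

def find_free_location(seq):
    touched, free_loc = set(), {}
    for i in range(len(seq) - 1, -1, -1):
        for buf in seq[i]:
            if buf not in touched:
                touched.add(buf)
                free_loc.setdefault(i, []).append(buf)
    return free_loc

print(find_free_location([["A"], ["A", "B"], ["B"], ["A"]]))
# {3: ['A'], 2: ['B']} -- A is freed after entry 3, B after entry 2.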
// Memory plan algorithm
void PlanMemory(const std::vector<StmtEntry>& seq) {
for (size_t i = 0; i < seq.size(); ++i) {
const StmtEntry& s = seq[i];
if (s.stmt->is_type<AttrStmt>()) {
const auto* op = static_cast<const AttrStmt*>(s.stmt);
CHECK_EQ(op->attr_key, attr::thread_extent);
if (thread_scope_ != nullptr) {
CHECK(thread_scope_ == op);
// erase all non-global memory from constant free map.
for (auto it = const_free_map_.begin();
it != const_free_map_.end();) {
if (it->second->scope.rank != 0) {
it = const_free_map_.erase(it);
} else {
++it;
}
}
thread_scope_ = nullptr;
} else {
thread_scope_ = op;
}
} else if (s.stmt->is_type<Allocate>()) {
const auto* op = static_cast<const Allocate*>(s.stmt);
StorageEntry* e = this->FindAlloc(op, s.access[0].scope);
e->allocs.emplace_back(op);
alloc_map_[op->buffer_var.get()] = e;
}
// free list
if (free_loc_.count(i)) {
for (const Variable* var : free_loc_.at(i)) {
this->Free(var);
}
}
}
}
// Allocate new storage entry.
StorageEntry* NewAlloc(const Allocate* op,
const StorageScope& scope,
size_t const_size) {
// Reuse was not successful; allocate a new buffer.
std::unique_ptr<StorageEntry> entry(new StorageEntry());
entry->attach_scope_ = thread_scope_;
entry->scope = scope;
entry->const_size = const_size;
StorageEntry* e = entry.get();
alloc_vec_.emplace_back(std::move(entry));
return e;
}
StorageEntry* FindAlloc(const Allocate* op,
const StorageScope& scope) {
// Skip planning for local variables:
// the compiler can do a better job with register allocation.
const size_t match_range = 16;
size_t const_size = static_cast<size_t>(
op->constant_allocation_size()) * op->type.bytes() * op->type.lanes();
if (scope.rank > 1 || op->type.is_handle()) {
return NewAlloc(op, scope, const_size);
}
// disable reuse of small arrays
if (const_size > 0 && const_size <= 32) {
return NewAlloc(op, scope, const_size);
}
if (const_size != 0) {
// constant allocation.
auto begin = const_free_map_.lower_bound(const_size / match_range);
auto mid = const_free_map_.lower_bound(const_size);
auto end = const_free_map_.upper_bound(const_size * match_range);
for (auto it = mid; it != end; ++it) {
StorageEntry *e = it->second;
if (it->second->scope != scope) continue;
e->const_size = std::max(const_size, e->const_size);
const_free_map_.erase(it);
return e;
}
for (auto it = mid; it != begin;) {
--it;
StorageEntry *e = it->second;
if (it->second->scope != scope) continue;
const_free_map_.erase(it);
return e;
}
} else {
// Simple strategy: round robin.
for (auto it = sym_free_list_.begin();
it != sym_free_list_.end(); ++it) {
StorageEntry* e = *it;
if (e->scope != scope) continue;
sym_free_list_.erase(it);
return e;
}
}
return NewAlloc(op, scope, const_size);
}
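FindAlloc only considers previously freed constant-size buffers whose size lies within a 16x window of the request (match_range): it first takes the smallest free entry at least as large as the request, then falls back to the largest smaller one, and gives up otherwise. A rough Python sketch of that search, using bisect over a sorted list in place of the std::multimap (the scope check is kept; the const_size update and other details are omitted):

import bisect

# free_list: list of (size, scope, entry) tuples kept sorted by size.
def find_reusable(free_list, const_size, scope, match_range=16):
    sizes = [s for s, _, _ in free_list]
    begin = bisect.bisect_left(sizes, const_size // match_range)
    mid = bisect.bisect_left(sizes, const_size)
    end = bisect.bisect_right(sizes, const_size * match_range)
    # Prefer the smallest free entry that is at least as large as the request.
    for i in range(mid, end):
        if free_list[i][1] == scope:
            return free_list.pop(i)[2]
    # Otherwise take the largest free entry that is smaller than the request.
    for i in range(mid - 1, begin - 1, -1):
        if free_list[i][1] == scope:
            return free_list.pop(i)[2]
    return None  # no candidate in range: the caller allocates a new buffer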
// simulated free.
void Free(const Variable* var) {
auto it = alloc_map_.find(var);
CHECK(it != alloc_map_.end());
StorageEntry* e = it->second;
// Disable sharing of local memory.
if (e->scope.rank > 1 || e->allocs[0]->type.is_handle()) return;
// disable reuse of small arrays
if (e->const_size > 0 && e->const_size <= 32) return;
// normal free.
if (e->const_size != 0) {
const_free_map_.insert({e->const_size, e});
} else {
sym_free_list_.push_back(e);
}
}
// thread scope.
const Node* thread_scope_{nullptr};
// Locations of free ops.
std::unordered_map<size_t,
std::vector<const Variable*> > free_loc_;
// The allocation attach map
std::unordered_map<const Node*, std::vector<StorageEntry*> > attach_map_;
// The allocation assign map
std::unordered_map<const Variable*, StorageEntry*> alloc_map_;
// constant size free map.
std::multimap<size_t, StorageEntry*> const_free_map_;
// symbolic free list, for non-constant items.
std::list<StorageEntry*> sym_free_list_;
// The allocations
std::vector<std::unique_ptr<StorageEntry> > alloc_vec_;
};
Stmt StorageRewrite(Stmt stmt) {
return StoragePlanRewriter().Rewrite(stmt);
}
} // namespace ir
} // namespace tvm
......@@ -9,12 +9,13 @@
#include <unordered_map>
#include <unordered_set>
#include "./ir_util.h"
#include "./storage_access.h"
#include "../runtime/thread_storage_scope.h"
namespace tvm {
namespace ir {
using runtime::StorageScope;
using namespace storage;
class StorageSyncPlanner : public IRVisitor {
public:
......@@ -130,37 +131,7 @@ class StorageSyncPlanner : public IRVisitor {
std::unordered_set<const Node*> syncs_inserted_;
private:
// Storage access type
enum AccessType {
kRead,
kWrite,
kSync
};
// The access entry
struct AccessEntry {
/*! \brief The buffer variable, if any */
const Variable* buffer{nullptr};
/*! \brief The access index */
Expr index;
/*! \brief The type of access */
AccessType type;
/*! \brief The storage scope */
StorageScope scope;
// constructor
AccessEntry() {}
AccessEntry(const Variable* buffer,
Expr index,
AccessType type,
StorageScope scope)
: buffer(buffer), index(index), type(type), scope(scope) {}
};
// The statment entry
struct StmtEntry {
// the associated statement.
const Node* stmt;
std::vector<AccessEntry> access;
};
// Get current storage scope.
// Get storage scope of buffer.
StorageScope GetScope(const Variable* buf) const {
auto it = storage_scope_.find(buf);
StorageScope s; s.rank = 0;
......
......@@ -21,6 +21,9 @@ struct StorageScope {
inline bool operator==(const StorageScope& other) const {
return rank == other.rank;
}
inline bool operator!=(const StorageScope& other) const {
return !(*this == other);
}
inline std::string to_string() const {
switch (rank) {
case 0: return "global";
......
import tvm
def test_storage_share():
m = tvm.var('m')
l = tvm.var('l')
A = tvm.placeholder((m, l), name='A')
num_stage = 5
B = A
for t in range(num_stage):
B = tvm.compute((m, l), lambda i, j: B[i, j] + (t+1), name='A%d' % t)
s = tvm.create_schedule(B.op)
bounds = tvm.schedule.InferBound(s)
assert isinstance(bounds, tvm.collections.Map)
stmt = tvm.schedule.ScheduleOps(s, bounds)
Ab = tvm.decl_buffer(A.shape, A.dtype, name='A')
Bb = tvm.decl_buffer(B.shape, B.dtype, name='B')
stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, B: Bb})
stmt = tvm.ir_pass.CanonicalSimplify(stmt)
stmt = tvm.ir_pass.Simplify(stmt)
stmt = tvm.ir_pass.StorageRewrite(stmt)
# verify that only two allocations remain
# and that no store writes to the buffer it reads from.
num_alloc = [0]
def verify(n):
if isinstance(n, tvm.stmt.Allocate):
num_alloc[0] += 1
elif isinstance(n, tvm.stmt.Store):
assert n.buffer_var != n.value.a.buffer_var
tvm.ir_pass.PostOrderVisit(stmt, verify)
assert num_alloc[0] == 2
def test_storage_share_gpu():
m = tvm.var('m')
A = [tvm.placeholder((m,), name='A')]
num_stage = 5
for t in range(num_stage):
A.append(tvm.compute((m,), lambda i: A[-1][i] + (t+1), name='A%d_s' % t))
A.append(tvm.compute((m,), lambda i: A[-1][i], name='A%d' % t))
s = tvm.create_schedule(A[-1].op)
for t in range(num_stage):
x = A[2*t+2].op.axis[0]
bx, tx = s[A[2*t+2]].split(x, factor=32)
s[A[2*t+2]].bind(bx, tvm.thread_axis("blockIdx.x"))
s[A[2*t+2]].bind(tx, tvm.thread_axis("threadIdx.x"))
s[A[2*t+1]].compute_at(s[A[2*t+2]], tx)
s[A[2*t+1]].set_scope("shared")
bounds = tvm.schedule.InferBound(s)
assert isinstance(bounds, tvm.collections.Map)
stmt = tvm.schedule.ScheduleOps(s, bounds)
Ab = tvm.decl_buffer(A[0].shape, A[0].dtype, name='A')
Bb = tvm.decl_buffer(A[0].shape, A[0].dtype, name='B')
stmt = tvm.ir_pass.StorageFlatten(stmt, {A[0]: Ab, A[-1]: Bb})
stmt = tvm.ir_pass.CanonicalSimplify(stmt)
stmt = tvm.ir_pass.Simplify(stmt)
stmt = tvm.ir_pass.StorageRewrite(stmt)
alloc_stats = {"global": 0, "shared": 0}
def verify(n):
if isinstance(n, tvm.stmt.AttrStmt):
if n.attr_key == "storage_scope":
alloc_stats[n.value.value] += 1
tvm.ir_pass.PostOrderVisit(stmt, verify)
assert alloc_stats["global"] == 2
assert alloc_stats["shared"] == num_stage
if __name__ == "__main__":
test_storage_share_gpu()
test_storage_share()
......@@ -17,7 +17,6 @@ def test_storage_sync():
bounds = tvm.schedule.InferBound(s)
assert isinstance(bounds, tvm.collections.Map)
stmt = tvm.schedule.ScheduleOps(s, bounds)
print(stmt)
Ab = tvm.decl_buffer(A.shape, A.dtype, name='A')
A2b = tvm.decl_buffer(A2.shape, A2.dtype, name='A2')
stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, A2: A2b})
......