Commit e4c74668 by Tianqi Chen, committed by GitHub

[PASS] StorageRewrite, Memory optimization pass as in NNVM. (#104)

* [PASS] StorageRewrite, reuse memory pass as in NNVM.

* fix issue
parent 1245cc25
......@@ -158,14 +158,15 @@ Stmt VectorizeLoop(Stmt stmt);
Stmt InjectVirtualThread(Stmt stmt);
/*!
* \brief Lift storage allocation to relevant outpost location
*
* Only do this after vectorization and virtual thread injection completes.
* \brief Rewrite storage allocation pattern.
Moves the allocation to the outermost possible scope.
Tries to share space between allocations to make
a static allocation plan when possible.
*
* \param stmt The stmt to be transformed
* \return Transformed stmt.
*/
Stmt LiftAllocate(Stmt stmt);
Stmt StorageRewrite(Stmt stmt);
/*!
* \brief partition loops in the stmt
......
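A minimal sketch of invoking the renamed pass from C++ (hedged: it assumes only the declarations in <tvm/ir_pass.h> shown above, and the helper name LowerAllocations is illustrative, not part of this commit). As in the Python pipeline below, StorageRewrite runs only after vectorization and virtual thread injection.

```cpp
// Hedged sketch: StorageRewrite takes LiftAllocate's place in the lowering order.
#include <tvm/ir_pass.h>

tvm::Stmt LowerAllocations(tvm::Stmt stmt) {
  using namespace tvm::ir;
  stmt = VectorizeLoop(stmt);        // must precede the rewrite
  stmt = InjectVirtualThread(stmt);  // must precede the rewrite
  stmt = StorageRewrite(stmt);       // hoist and reuse allocations
  return stmt;
}
```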
......@@ -70,7 +70,7 @@ def lower(sch,
stmt = ir_pass.CanonicalSimplify(stmt)
stmt = ir_pass.VectorizeLoop(stmt)
stmt = ir_pass.InjectVirtualThread(stmt)
stmt = ir_pass.LiftAllocate(stmt)
stmt = ir_pass.StorageRewrite(stmt)
stmt = ir_pass.UnrollLoop(stmt, max_auto_unroll_step)
stmt = ir_pass.Simplify(stmt)
if not with_api_wrapper:
......
......@@ -68,7 +68,7 @@ REGISTER_PASS2(UnrollLoop);
REGISTER_PASS2(StorageSync);
REGISTER_PASS4(MakeAPI);
REGISTER_PASS1(SplitHostDevice);
REGISTER_PASS1(LiftAllocate);
REGISTER_PASS1(StorageRewrite);
REGISTER_PASS1(InjectVirtualThread);
REGISTER_PASS1(LoopPartition);
REGISTER_PASS1(RemoveNoOp);
......
......@@ -249,7 +249,7 @@ llvm::BasicBlock* CodeGenLLVM::CheckCallSuccess(llvm::Value* retcode) {
}
void CodeGenLLVM::AddAliasInfo(
llvm::Instruction* inst, const Variable* buffer, Expr index) {
llvm::Instruction* inst, const Variable* buffer, Expr index, Type t) {
int base = 0, width = 0;
// create meta-data for alias analysis
// Use a group of binary tree ranges.
......@@ -274,9 +274,11 @@ void CodeGenLLVM::AddAliasInfo(
}
}
llvm::MDNode* meta = md_tbaa_root_;
std::ostringstream buffer_addr;
std::ostringstream buffer_addr, buffer_type;
buffer_addr << buffer;
meta = md_builder_->createTBAAScalarTypeNode(buffer_addr.str(), meta);
buffer_type << t.element_of();
meta = md_builder_->createTBAAScalarTypeNode(buffer_type.str(), meta);
// create a tree-shape access structure.
if (width != 0) {
for (int w = 1024; w >= width; w /= 2) {
......@@ -1033,7 +1035,7 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const Load* op) {
llvm::LoadInst* inst = builder_->CreateAlignedLoad(
CreateBufferPtr(t, buf, MakeValue(op->index)),
data_layout_->getTypeAllocSize(LLVMType(t)));
AddAliasInfo(inst, op->buffer_var.get(), op->index);
AddAliasInfo(inst, op->buffer_var.get(), op->index, op->type);
return inst;
} else if (ramp && is_one(ramp->stride)) {
int alignment, native_bits;
......@@ -1053,7 +1055,7 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const Load* op) {
llvm::LoadInst* inst = builder_->CreateAlignedLoad(
builder_->CreatePointerCast(ptr, vtype), alignment);
AddAliasInfo(inst, op->buffer_var.get(),
Ramp::make(base, make_const(base.type(), 1), lanes));
Ramp::make(base, make_const(base.type(), 1), lanes), op->type);
loads.push_back(inst);
}
return CreateVecConcat(loads);
......@@ -1127,7 +1129,7 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const Load* op) {
llvm::Value* ptr = CreateBufferPtr(t.element_of(), buf, offset);
llvm::LoadInst* inst = builder_->CreateAlignedLoad(
ptr, data_layout_->getTypeAllocSize(LLVMType(t)));
AddAliasInfo(inst, op->buffer_var.get(), Expr());
AddAliasInfo(inst, op->buffer_var.get(), Expr(), op->type);
ret = builder_->CreateInsertElement(ret, inst, ConstInt32(i));
});
return ret;
......@@ -1146,7 +1148,7 @@ void CodeGenLLVM::VisitStmt_(const Store* op) {
value,
CreateBufferPtr(t, buf, MakeValue(op->index)),
data_layout_->getTypeAllocSize(value->getType()));
AddAliasInfo(inst, op->buffer_var.get(), op->index);
AddAliasInfo(inst, op->buffer_var.get(), op->index, op->value.type());
} else if (ramp && is_one(ramp->stride)) {
int alignment, native_bits;
GetAlignment(t, op->buffer_var.get(), ramp->base,
......@@ -1165,7 +1167,7 @@ void CodeGenLLVM::VisitStmt_(const Store* op) {
CreateVecSlice(value, offset, lanes),
builder_->CreatePointerCast(ptr, vtype), alignment);
AddAliasInfo(inst, op->buffer_var.get(),
Ramp::make(base, make_const(base.type(), 1), lanes));
Ramp::make(base, make_const(base.type(), 1), lanes), op->value.type());
}
} else {
Scalarize(op->index, [&](int i, llvm::Value* offset) {
......@@ -1173,7 +1175,7 @@ void CodeGenLLVM::VisitStmt_(const Store* op) {
llvm::StoreInst* inst = builder_->CreateAlignedStore(
builder_->CreateExtractElement(value, ConstInt32(i)),
ptr, data_layout_->getTypeAllocSize(LLVMType(t)));
AddAliasInfo(inst, op->buffer_var.get(), Expr());
AddAliasInfo(inst, op->buffer_var.get(), Expr(), op->value.type());
});
}
}
......
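The extra Type argument lets AddAliasInfo refine the TBAA node with the element type, so accesses to the same buffer through different element types remain distinguishable to LLVM's alias analysis. Below is a hedged, standalone sketch of that two-level node construction with llvm::MDBuilder; the function name and the metadata strings are illustrative, not TVM's exact ones.

```cpp
// Hedged sketch: build root -> per-buffer node -> per-element-type node,
// mirroring the chain of createTBAAScalarTypeNode calls in AddAliasInfo.
#include <string>
#include <llvm/IR/LLVMContext.h>
#include <llvm/IR/MDBuilder.h>
#include <llvm/IR/Metadata.h>

llvm::MDNode* MakeAliasNode(llvm::LLVMContext& ctx,
                            const std::string& buffer_addr,
                            const std::string& element_type) {
  llvm::MDBuilder md(ctx);
  llvm::MDNode* meta = md.createTBAARoot("tvm-tbaa");
  meta = md.createTBAAScalarTypeNode(buffer_addr, meta);   // one node per buffer
  meta = md.createTBAAScalarTypeNode(element_type, meta);  // refined by element type
  return meta;
}
```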
......@@ -211,7 +211,7 @@ class CodeGenLLVM :
// Add a function to set global module context
void InitGlobalContext();
// add alias information.
void AddAliasInfo(llvm::Instruction* load, const Variable* buffer, Expr index);
void AddAliasInfo(llvm::Instruction* load, const Variable* buffer, Expr index, Type type);
// The definition of local variable.
std::unordered_map<const Variable*, llvm::Value*> var_map_;
// global strings
......
/*!
* Copyright (c) 2017 by Contributors
* \file lift_allocate.cc
*/
#include <tvm/ir.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_pass.h>
#include <unordered_map>
#include "./ir_util.h"
#include "../runtime/thread_storage_scope.h"
namespace tvm {
namespace ir {
using runtime::StorageScope;
using runtime::ThreadScope;
class AllocateLifter : public IRMutator {
public:
Stmt Lift(Stmt stmt) {
stmt = this->Mutate(stmt);
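// After mutation, any remaining global (rank 0) allocations are
// prepended at the very top of the function body.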
StorageScope key; key.rank = 0;
stmt = MergeNest(allocs_[key], stmt);
return stmt;
}
Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
CHECK(op->attr_key != attr::virtual_thread)
<< "InjectVirtualThread before LiftStorageAlloc";
if (op->attr_key == attr::storage_scope) {
StorageScope sc = StorageScope::make(op->value.as<StringImm>()->value);
allocs_[sc].emplace_back(
AttrStmt::make(
op->node, attr::storage_scope,
op->value, Evaluate::make(0)));
storage_scope_[op->node.get()] = sc;
return this->Mutate(op->body);
} else if (op->attr_key == attr::thread_extent) {
IterVar iv(op->node.node_);
ThreadScope ts = ThreadScope::make(iv->thread_tag);
curr_thread_scope_.push_back(ts);
Stmt stmt = IRMutator::Mutate_(op, s);
curr_thread_scope_.pop_back();
op = stmt.as<AttrStmt>();
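// This thread_extent is the first scope of its rank when no enclosing
// scope shares that rank; allocations whose storage rank is one deeper
// (e.g. shared memory under the outermost blockIdx) are emitted here.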
bool first_scope = true;
for (const ThreadScope& t : curr_thread_scope_) {
if (t.rank == ts.rank) first_scope = false;
}
if (first_scope) {
StorageScope key;
key.rank = ts.rank + 1;
std::vector<Stmt>& vec = allocs_[key];
if (vec.size() != 0) {
Stmt body = MergeNest(vec, op->body);
vec.clear();
return AttrStmt::make(
op->node, op->attr_key, op->value, body);
}
}
return stmt;
}
return IRMutator::Mutate_(op, s);
}
Stmt Mutate_(const For* op, const Stmt& s) final {
CHECK(op->for_type != ForType::Vectorized)
<< "VectorizeLoop before LiftStorageAlloc";
return IRMutator::Mutate_(op, s);
}
Stmt Mutate_(const Allocate* op, const Stmt& s) final {
auto it = storage_scope_.find(op->buffer_var.get());
CHECK(it != storage_scope_.end());
allocs_[it->second].emplace_back(
Allocate::make(
op->buffer_var, op->type, op->extents, op->condition,
Evaluate::make(0)));
return this->Mutate(op->body);
}
private:
// storage scope of internal allocation.
std::unordered_map<const Node*, StorageScope> storage_scope_;
// The current thread scope.
std::vector<ThreadScope> curr_thread_scope_;
// The allocations by rank
std::unordered_map<StorageScope, std::vector<Stmt> > allocs_;
};
Stmt LiftAllocate(Stmt stmt) {
return AllocateLifter().Lift(stmt);
}
} // namespace ir
} // namespace tvm
/*!
* Copyright (c) 2016 by Contributors
* \file scope.h
* \brief attribute scope data structure,
* defines attributes on current domain
*/
#ifndef TVM_PASS_SCOPE_H_
#define TVM_PASS_SCOPE_H_
#include <tvm/ir.h>
#include <unordered_map>
#include <vector>
#include <string>
namespace tvm {
namespace ir {
/*!
* \brief Attribute scope of Nodes in the IR.
* \tparam K The key type of the scope.
* \tparam V The value type of the scope.
*/
template<typename K, typename V>
class Scope {
public:
/*!
* \brief Push value to scope
* \param key the key to be pushed.
* \param v The value to be pushed.
*/
inline void Push(const K& key, V v) {
data_[key].emplace_back(v);
}
/*!
* \brief Pop value from scope.
* \param key the key to be popped
*/
inline void Pop(const K& key) {
auto& v = data_[key];
CHECK_NE(v.size(), 0U);
v.pop_back();
}
/*!
* \brief Get value from the scope
* \param key the key to fetch.
* \return The value to be fetched.
*/
inline V operator[](const K& key) const {
const auto it = data_.find(key);
CHECK(it != data_.end() && it->second.size() != 0)
<< "cannot find value in scope";
return it->second.back();
}
private:
std::unordered_map<K, std::vector<V> > data_;
};
/*! \brief Attribute key for specific attribute */
struct AttrKey {
/*! \brief The node of the attribute */
NodeRef node;
/*! \brief The type key of the attribute. */
std::string type_key;
// overload operator ==
inline bool operator==(const AttrKey& other) const {
return node == other.node && type_key == other.type_key;
}
};
} // namespace ir
} // namespace tvm
namespace std {
template <>
struct hash<::tvm::ir::AttrKey> {
std::size_t operator()(const ::tvm::ir::AttrKey& k) const {
size_t lhs = k.node.hash();
size_t rhs = std::hash<std::string>()(k.type_key);
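// boost::hash_combine-style mixing of the node and type-key hashes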
lhs ^= rhs + 0x9e3779b9 + (lhs << 6) + (lhs >> 2);
return lhs;
}
};
} // namespace std
#endif // TVM_PASS_SCOPE_H_
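A small usage sketch of the Scope container defined above (hedged: the std::string/int key and value types are purely illustrative, and the include path assumes the header sits alongside the other pass sources in src/pass/):

```cpp
// Hedged usage sketch of tvm::ir::Scope: values pushed for the same key
// shadow each other, and operator[] returns the innermost one.
#include <cassert>
#include <string>
#include "./scope.h"  // assumed relative path within src/pass/

void ScopeExample() {
  tvm::ir::Scope<std::string, int> scope;
  scope.Push("extent", 8);
  scope.Push("extent", 16);       // inner scope shadows the outer value
  assert(scope["extent"] == 16);  // lookup returns the innermost entry
  scope.Pop("extent");            // leaving the inner scope
  assert(scope["extent"] == 8);   // outer value visible again
}
```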
/*!
* Copyright (c) 2017 by Contributors
* \file storage_access.h
* \brief Common data structure for storage access analysis.
*/
#ifndef TVM_PASS_STORAGE_ACCESS_H_
#define TVM_PASS_STORAGE_ACCESS_H_
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include <tvm/ir_visitor.h>
#include <vector>
#include <unordered_map>
#include "../runtime/thread_storage_scope.h"
namespace tvm {
namespace ir {
namespace storage {
// The storage scope.
using runtime::StorageScope;
/*! \brief Storage access type */
enum AccessType {
kRead,
kWrite,
kOpaque,
kSync,
kAlloc
};
/*! \brief The access entry */
struct AccessEntry {
/*! \brief The buffer variable, if any */
const Variable* buffer{nullptr};
/*! \brief The access index */
Expr index;
/*! \brief The type of access */
AccessType type;
/*! \brief The storage scope */
StorageScope scope;
// constructor
AccessEntry() {}
AccessEntry(const Variable* buffer,
Expr index,
AccessType type,
StorageScope scope)
: buffer(buffer), index(index), type(type), scope(scope) {}
};
/*! \brief The access info about a statement */
struct StmtEntry {
/*! \brief The statement */
const Node* stmt;
/*! \brief access patterns in the statement */
std::vector<AccessEntry> access;
};
} // namespace storage
} // namespace ir
} // namespace tvm
#endif // TVM_PASS_STORAGE_ACCESS_H_
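A hedged sketch of how an analysis pass might record a single read access using these structures (the RecordLoad helper and its arguments are illustrative, not part of this commit):

```cpp
// Hedged sketch: fill a StmtEntry with one read access.
#include "./storage_access.h"  // assumed relative path within src/pass/

namespace tvm {
namespace ir {

storage::StmtEntry RecordLoad(const Node* stmt,
                              const Variable* buf,
                              Expr index,
                              storage::StorageScope scope) {
  storage::StmtEntry e;
  e.stmt = stmt;
  // one kRead entry: buffer, index expression, access type, storage scope
  e.access.emplace_back(buf, index, storage::kRead, scope);
  return e;
}

}  // namespace ir
}  // namespace tvm
```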
......@@ -9,12 +9,13 @@
#include <unordered_map>
#include <unordered_set>
#include "./ir_util.h"
#include "./storage_access.h"
#include "../runtime/thread_storage_scope.h"
namespace tvm {
namespace ir {
using runtime::StorageScope;
using namespace storage;
class StorageSyncPlanner : public IRVisitor {
public:
......@@ -130,37 +131,7 @@ class StorageSyncPlanner : public IRVisitor {
std::unordered_set<const Node*> syncs_inserted_;
private:
// Storage access type
enum AccessType {
kRead,
kWrite,
kSync
};
// The access entry
struct AccessEntry {
/*! \brief The buffer variable, if any */
const Variable* buffer{nullptr};
/*! \brief The access index */
Expr index;
/*! \brief The type of access */
AccessType type;
/*! \brief The storage scope */
StorageScope scope;
// constructor
AccessEntry() {}
AccessEntry(const Variable* buffer,
Expr index,
AccessType type,
StorageScope scope)
: buffer(buffer), index(index), type(type), scope(scope) {}
};
// The statment entry
struct StmtEntry {
// the associated statement.
const Node* stmt;
std::vector<AccessEntry> access;
};
// Get current storage scope.
// Get storage scope of buffer.
StorageScope GetScope(const Variable* buf) const {
auto it = storage_scope_.find(buf);
StorageScope s; s.rank = 0;
......
......@@ -21,6 +21,9 @@ struct StorageScope {
inline bool operator==(const StorageScope& other) const {
return rank == other.rank;
}
inline bool operator!=(const StorageScope& other) const {
return !(*this == other);
}
inline std::string to_string() const {
switch (rank) {
case 0: return "global";
......
import tvm
def test_storage_share():
m = tvm.var('m')
l = tvm.var('l')
A = tvm.placeholder((m, l), name='A')
num_stage = 5
B = A
for t in range(num_stage):
B = tvm.compute((m, l), lambda i, j: B[i, j] + (t+1), name='A%d' % t)
s = tvm.create_schedule(B.op)
bounds = tvm.schedule.InferBound(s)
assert isinstance(bounds, tvm.collections.Map)
stmt = tvm.schedule.ScheduleOps(s, bounds)
Ab = tvm.decl_buffer(A.shape, A.dtype, name='A')
Bb = tvm.decl_buffer(B.shape, B.dtype, name='B')
stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, B: Bb})
stmt = tvm.ir_pass.CanonicalSimplify(stmt)
stmt = tvm.ir_pass.Simplify(stmt)
stmt = tvm.ir_pass.StorageRewrite(stmt)
# verify that there are only two allocations.
# verify that the data is folded.
num_alloc = [0]
def verify(n):
if isinstance(n, tvm.stmt.Allocate):
num_alloc[0] += 1
elif isinstance(n, tvm.stmt.Store):
assert n.buffer_var != n.value.a.buffer_var
tvm.ir_pass.PostOrderVisit(stmt, verify)
assert num_alloc[0] == 2
def test_storage_share_gpu():
m = tvm.var('m')
A = [tvm.placeholder((m), name='A')]
num_stage = 5
for t in range(num_stage):
A.append(tvm.compute((m,), lambda i: A[-1][i] + (t+1), name='A%d_s' % t))
A.append(tvm.compute((m,), lambda i: A[-1][i], name='A%d' % t))
s = tvm.create_schedule(A[-1].op)
for t in range(num_stage):
x = A[2*t+2].op.axis[0]
bx, tx = s[A[2*t+2]].split(x, factor=32)
s[A[2*t+2]].bind(bx, tvm.thread_axis("blockIdx.x"))
s[A[2*t+2]].bind(tx, tvm.thread_axis("threadIdx.x"))
s[A[2*t+1]].compute_at(s[A[2*t+2]], tx)
s[A[2*t+1]].set_scope("shared")
bounds = tvm.schedule.InferBound(s)
assert isinstance(bounds, tvm.collections.Map)
stmt = tvm.schedule.ScheduleOps(s, bounds)
Ab = tvm.decl_buffer(A[0].shape, A[0].dtype, name='A')
Bb = tvm.decl_buffer(A[0].shape, A[0].dtype, name='B')
stmt = tvm.ir_pass.StorageFlatten(stmt, {A[0]: Ab, A[-1]: Bb})
stmt = tvm.ir_pass.CanonicalSimplify(stmt)
stmt = tvm.ir_pass.Simplify(stmt)
stmt = tvm.ir_pass.StorageRewrite(stmt)
alloc_stats = {"global": 0, "shared": 0}
def verify(n):
if isinstance(n, tvm.stmt.AttrStmt):
if n.attr_key == "storage_scope":
alloc_stats[n.value.value] += 1
tvm.ir_pass.PostOrderVisit(stmt, verify)
assert alloc_stats["global"] == 2
assert alloc_stats["shared"] == num_stage
if __name__ == "__main__":
test_storage_share_gpu()
test_storage_share()
......@@ -17,7 +17,6 @@ def test_storage_sync():
bounds = tvm.schedule.InferBound(s)
assert isinstance(bounds, tvm.collections.Map)
stmt = tvm.schedule.ScheduleOps(s, bounds)
print(stmt)
Ab = tvm.decl_buffer(A.shape, A.dtype, name='A')
A2b = tvm.decl_buffer(A2.shape, A2.dtype, name='A2')
stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, A2: A2b})
......