/*! * Copyright (c) 2017 by Contributors * * \brief Inject double buffering optimization for data fetch. * \file inject_double_buffer.cc */ #include <tvm/ir_pass.h> #include <tvm/ir_visitor.h> #include <tvm/ir_mutator.h> #include "./ir_util.h" #include "../arithmetic/compute_expr.h" namespace tvm { namespace ir { // Detect double buffer variables. class DoubleBufferDetector : public IRVisitor { public: void Visit_(const AttrStmt* op) final { if (op->attr_key == attr::double_buffer_scope) { touched_.insert(op->node.as<Variable>()); IRVisitor::Visit_(op); } else { IRVisitor::Visit_(op); } } void Visit_(const Variable* op) final { if (touched_.count(op)) { touched_.erase(op); } } // The set of touched variable. std::unordered_set<const Variable*> touched_; }; class StripDoubleBufferWrite : public IRMutator { public: Stmt Mutate_(const AttrStmt* op, const Stmt& s) final { if (op->attr_key == attr::double_buffer_write) { return Mutate(op->body); } else { return IRMutator::Mutate_(op, s); } } }; class DoubleBufferInjector : public IRMutator { public: explicit DoubleBufferInjector(int split_loop) : split_loop_(split_loop) {} Stmt Inject(const Stmt& stmt) { DoubleBufferDetector detector; detector.Visit(stmt); if (detector.touched_.empty()) return stmt; for (const Variable* v : detector.touched_) { dbuffer_info_[v] = StorageEntry(); } return ConvertSSA(this->Mutate(stmt)); } Stmt Mutate_(const AttrStmt* op, const Stmt& s) final { if (op->attr_key == attr::storage_scope) { const Variable* buf = op->node.as<Variable>(); auto it = dbuffer_info_.find(buf); if (it != dbuffer_info_.end()) { it->second.scope = op->value.as<StringImm>()->value; return Mutate(op->body); } else { return IRMutator::Mutate_(op, s); } } else if (op->attr_key == attr::double_buffer_scope) { return MakeProducer(op, s); } else { return IRMutator::Mutate_(op, s); } } Stmt Mutate_(const Allocate* op, const Stmt& s) final { auto it = dbuffer_info_.find(op->buffer_var.get()); if (it != dbuffer_info_.end()) { it->second.stride = arith::ComputeReduce<Mul> (op->extents, Expr()) * op->type.lanes(); Stmt stmt = IRMutator::Mutate_(op, s); op = stmt.as<Allocate>(); Array<Expr> new_extents{make_const(op->extents[0].type(), 2)}; for (Expr e : op->extents) { new_extents.push_back(e); } CHECK(it->second.loop != nullptr); auto& alloc_nest = loop_allocs_[it->second.loop]; alloc_nest.emplace_back(AttrStmt::make( op->buffer_var, attr::storage_scope, StringImm::make(it->second.scope), Evaluate::make(0))); alloc_nest.emplace_back(Allocate::make( op->buffer_var, op->type, new_extents, op->condition, Evaluate::make(0))); return op->body; } else { return IRMutator::Mutate_(op, s); } } Stmt Mutate_(const For* op, const Stmt& s) final { loop_nest_.push_back(op); Stmt stmt = IRMutator::Mutate_(op, s); auto it = loop_pre_.find(op); if (it != loop_pre_.end()) { const For* old_loop = stmt.as<For>(); if (split_loop_ != 0) { // Explicitly unroll the loop CHECK(split_loop_ % 2 == 0 || split_loop_ == 1) << "It is better to split with multiple of 2"; CHECK(is_zero(old_loop->min)); Expr zero = old_loop->min; Expr new_ext = arith::ComputeExpr<Sub>( old_loop->extent, make_const(old_loop->loop_var.type(), 1)); Expr factor = make_const(new_ext.type(), split_loop_); Expr outer_ext = arith::ComputeExpr<Div>(new_ext, factor); Expr tail_base = arith::ComputeExpr<Mul>(outer_ext, factor); Var outer_var(old_loop->loop_var->name_hint + ".outer", old_loop->loop_var.type()); std::unordered_map<const Variable*, Expr> vmap; std::vector<Stmt> loop_seq; for (int32_t i = 0; i < split_loop_; ++i) { vmap[old_loop->loop_var.get()] = outer_var * factor + make_const(factor.type(), i); loop_seq.emplace_back(Substitute(old_loop->body, vmap)); } Stmt loop = For::make( outer_var, zero, outer_ext, old_loop->for_type, old_loop->device_api, MergeSeq(loop_seq)); // tail std::vector<Stmt> tail_seq; Stmt tail_body = StripDoubleBufferWrite().Mutate(old_loop->body); for (int32_t i = 0; i < split_loop_; ++i) { Expr idx = tail_base + make_const(tail_base.type(), i); vmap[old_loop->loop_var.get()] = idx; tail_seq.emplace_back( IfThenElse::make(idx < old_loop->extent, Substitute(tail_body, vmap))); } stmt = Block::make(loop, MergeSeq(tail_seq)); } stmt = Block::make(MergeSeq(it->second), stmt); } it = loop_allocs_.find(op); if (it != loop_allocs_.end()) { stmt = MergeNest(it->second, stmt); } loop_nest_.pop_back(); return stmt; } Stmt Mutate_(const Store* op, const Stmt& s) final { Stmt stmt = IRMutator::Mutate_(op, s); op = stmt.as<Store>(); auto it = dbuffer_info_.find(op->buffer_var.get()); if (it != dbuffer_info_.end()) { const StorageEntry& e = it->second; CHECK(in_double_buffer_scope_); CHECK(e.stride.defined()); return Store::make(op->buffer_var, op->value, e.switch_write_var * e.stride + op->index, op->predicate); } else { return stmt; } } Expr Mutate_(const Load* op, const Expr& e) final { Expr expr = IRMutator::Mutate_(op, e); op = expr.as<Load>(); auto it = dbuffer_info_.find(op->buffer_var.get()); if (it != dbuffer_info_.end()) { const StorageEntry& e = it->second; CHECK(e.stride.defined()); CHECK(e.switch_read_var.defined()); return Load::make(op->type, op->buffer_var, e.switch_read_var * e.stride + op->index, op->predicate); } else { return expr; } } Expr Mutate_(const Variable* op, const Expr& e) final { CHECK(!dbuffer_info_.count(op)); return e; } private: Stmt MakeProducer(const AttrStmt* op, const Stmt& s) { const VarExpr buffer(op->node.node_); CHECK_NE(loop_nest_.size(), 0U) << "Double buffer scope must be inside a loop"; auto it = dbuffer_info_.find(buffer.get()); if (it == dbuffer_info_.end()) { LOG(WARNING) << "Skip double buffer scope " << op->node; return Mutate(op->body); } StorageEntry& e = it->second; e.loop = loop_nest_.back(); Expr zero = make_const(e.loop->loop_var.type(), 0); Expr one = make_const(e.loop->loop_var.type(), 1); Expr two = make_const(e.loop->loop_var.type(), 2); Expr loop_shift = e.loop->loop_var + one; e.switch_write_var = Var(e.loop->loop_var->name_hint + ".db", e.loop->loop_var.type()); e.switch_read_var = e.loop->loop_var % two; in_double_buffer_scope_ = true; Stmt body = Mutate(op->body); in_double_buffer_scope_ = false; std::unordered_map<const Variable*, Expr> vmap; vmap[e.switch_write_var.get()] = zero; vmap[e.loop->loop_var.get()] = zero; loop_pre_[e.loop].emplace_back(Substitute(body, vmap)); vmap[e.loop->loop_var.get()] = loop_shift; vmap[e.switch_write_var.get()] = loop_shift % two; body = Substitute(body, vmap); body = AttrStmt::make(buffer, attr::double_buffer_write, 1, body); body = IfThenElse::make(loop_shift < e.loop->extent, body); return body; } // Storage entry for those who need double buffering. struct StorageEntry { // The size of the buffer Expr stride; // The loop we need const For* loop{nullptr}; // The switch variable. VarExpr switch_write_var; // The switch variable for reading. Expr switch_read_var; // The storage scope. std::string scope; }; // Whether split loop int32_t split_loop_; // Whether we are inside double buffer scope. bool in_double_buffer_scope_{false}; // The current loop next std::vector<const For*> loop_nest_; // The allocs to be appended before the loop std::unordered_map<const For*, std::vector<Stmt> > loop_allocs_; // The stmt to be appended before the loop std::unordered_map<const For*, std::vector<Stmt> > loop_pre_; // The allocation size of the buffer std::unordered_map<const Variable*, StorageEntry> dbuffer_info_; }; Stmt InjectDoubleBuffer(Stmt stmt, int split_loop) { return DoubleBufferInjector(split_loop).Inject(stmt); } } // namespace ir } // namespace tvm