/*!
 *  Copyright (c) 2017 by Contributors
 * \file storage_sync.cc
 * \brief Rewrite the program to insert storage synchronization points.
 */
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_visitor.h>
#include <unordered_map>
#include <unordered_set>
#include "./ir_util.h"
#include "./storage_access.h"
#include "../runtime/thread_storage_scope.h"

namespace tvm {
namespace ir {

using namespace storage;

class StorageSyncPlanner : public IRVisitor {
 public:
  explicit StorageSyncPlanner(StorageScope sync_scope)
      : sync_scope_(sync_scope) {}

  void Visit_(const Load* op) final {
    if (!in_device_env_) return;
    CHECK(allow_load_);
    const Variable* buf = op->buffer_var.as<Variable>();
    StorageScope s = GetScope(buf);
    if (s == sync_scope_) {
      curr_stmt_.access.emplace_back(
          AccessEntry(buf, op->index, kRead, s));
    }
  }
  void Visit_(const Store* op) final {
    if (!in_device_env_) return;
    allow_load_ = true;
    CHECK_EQ(curr_stmt_.access.size(), 0U);
    curr_stmt_.stmt = op;
    const Variable* buf = op->buffer_var.as<Variable>();
    StorageScope s = GetScope(buf);
    if (s == sync_scope_) {
      curr_stmt_.access.emplace_back(
          AccessEntry(buf, op->index, kWrite, s));
    }
    // traverse children
    IRVisitor::Visit_(op);
    // push to the scope
    scope_.back().push_back(curr_stmt_);
    // clear access entry.
    curr_stmt_.access.clear();
    allow_load_ = false;
  }
  void Visit_(const Evaluate* op) final {
    if (!in_device_env_) return;
    if (const Call* call = op->value.as<Call>()) {
      if (call->is_intrinsic(intrinsic::tvm_storage_sync)) {
        const std::string& s = call->args[0].as<StringImm>()->value;
        if (s != "warp") {
          StorageScope scope = StorageScope::make(s);
          if (scope.rank <= sync_scope_.rank) {
            CHECK_EQ(curr_stmt_.access.size(), 0U);
            curr_stmt_.access.emplace_back(
                AccessEntry(nullptr, Expr(), kSync, scope));
            // push to the scope
            scope_.back().push_back(curr_stmt_);
            curr_stmt_.access.clear();
          }
        }
      }
    }
  }
  void Visit_(const AttrStmt* op) final {
    if (op->attr_key == attr::storage_scope) {
      const Variable* buf = op->node.as<Variable>();
      storage_scope_[buf] =
          StorageScope::make(op->value.as<StringImm>()->value);
      IRVisitor::Visit_(op);
    } else if (op->attr_key == attr::thread_extent && !in_device_env_) {
      in_device_env_ = true;
      CHECK_EQ(scope_.size(), 0U);
      scope_.push_back(std::vector<StmtEntry>());
      IRVisitor::Visit_(op);
      this->PlanSync(false);
      in_device_env_ = false;
      scope_.pop_back();
    } else {
      IRVisitor::Visit_(op);
    }
  }
  void Visit_(const For* op) final {
    if (in_device_env_) {
      scope_.push_back(std::vector<StmtEntry>());
      IRVisitor::Visit_(op);
      StmtEntry s;
      s.stmt = op;
      s.access = PlanSync(true);
      scope_.pop_back();
      scope_.back().emplace_back(std::move(s));
    } else {
      IRVisitor::Visit_(op);
    }
  }
  void Visit_(const Call* op) final {
    if (op->is_intrinsic(intrinsic::tvm_address_of)) {
      const Load* l = op->args[0].as<Load>();
      CHECK(l != nullptr)
          << "tvm_address_of expects a Load as its argument";
      IRVisitor::Visit_(l);
    } else {
      IRVisitor::Visit_(op);
    }
  }
  void Visit_(const IfThenElse* op) final {
    if (in_device_env_) {
      ++condition_counter_;
      this->Visit(op->condition);
      scope_.push_back(std::vector<StmtEntry>());
      this->Visit(op->then_case);
      StmtEntry s;
      s.stmt = op;
      s.access = PlanSync(false);
      scope_.pop_back();
      if (op->else_case.defined()) {
        scope_.push_back(std::vector<StmtEntry>());
        this->Visit(op->else_case);
        auto v = PlanSync(false);
        scope_.pop_back();
        s.access.insert(s.access.end(), v.begin(), v.end());
      }
      scope_.back().emplace_back(std::move(s));
      --condition_counter_;
    } else {
      IRVisitor::Visit_(op);
    }
  }

  // The syncs inserted before each statement.
  std::unordered_set<const Node*> syncs_inserted_;
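  // Illustrative note (added commentary, not from the original source):
  // with sync_scope_ = "shared", a sequence like
  //
  //   A[tx] = ...;          // write to shared buffer A
  //   B[tx] = A[tx + 1];    // read A at a different index (B not shared)
  //
  // makes PlanSync flag the second statement, so StorageSyncInserter
  // later prepends a tvm_storage_sync("shared") barrier to it.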
 private:
  // Get storage scope of a buffer.
  StorageScope GetScope(const Variable* buf) const {
    auto it = storage_scope_.find(buf);
    StorageScope s;
    s.rank = 0;
    if (it == storage_scope_.end()) return s;
    return it->second;
  }
  // Plan the syncs for the current scope.
  std::vector<AccessEntry> PlanSync(bool is_loop) {
    // unsynced reads and writes
    std::vector<AccessEntry> reads;
    std::vector<AccessEntry> writes;
    const std::vector<StmtEntry>& seq = scope_.back();
    // if it is a loop, rotate two times to consider the effect of the loop.
    size_t max_seq = seq.size();
    if (is_loop) max_seq *= 2;
    // simulation-based approach to find dependencies
    for (size_t i = 0; i < max_seq; ++i) {
      const StmtEntry& s = seq[i % seq.size()];
      // check if a sync before the statement is needed.
      bool sync_before_stmt = (syncs_inserted_.count(s.stmt) != 0);
      // Apply the syncs added already.
      if (sync_before_stmt) {
        reads.clear();
        writes.clear();
      }
      for (const AccessEntry& acc : s.access) {
        if (acc.type == kRead) {
          if (FindConflict(writes, acc)) {
            sync_before_stmt = true; break;
          }
        } else if (acc.type == kWrite) {
          if (FindConflict(reads, acc)) {
            sync_before_stmt = true; break;
          }
        } else if (acc.type == kSync) {
          reads.clear(); writes.clear();
        }
      }
      // If a sync is inserted, the pending accesses are no longer exposed.
      if (sync_before_stmt) {
        reads.clear();
        writes.clear();
      }
      // Add the reads/writes of the current statement.
      for (const AccessEntry& acc : s.access) {
        if (acc.type == kRead) {
          reads.push_back(acc);
        } else if (acc.type == kWrite) {
          writes.push_back(acc);
        } else if (acc.type == kSync) {
          reads.clear(); writes.clear();
        }
      }
      if (sync_before_stmt) {
        CHECK_EQ(condition_counter_, 0)
            << "Cannot insert syncs inside condition";
        syncs_inserted_.insert(s.stmt);
      }
    }
    // Return the exposed entries, removing unnecessary ones.
    int sync_count = 0;
    // head: entries before the first sync; tail: entries after the last sync.
    std::vector<AccessEntry> head, tail;
    for (const StmtEntry& s : seq) {
      if (syncs_inserted_.count(s.stmt)) {
        if (sync_count != 0) {
          tail.clear();
        } else {
          head.push_back(AccessEntry(nullptr, Expr(), kSync, sync_scope_));
        }
        ++sync_count;
      }
      for (const AccessEntry& acc : s.access) {
        if (acc.type == kSync) {
          if (sync_count != 0) {
            tail.clear();
          } else {
            head.push_back(AccessEntry(nullptr, Expr(), kSync, sync_scope_));
          }
          ++sync_count;
        } else {
          if (sync_count != 0) {
            tail.push_back(acc);
          } else {
            head.push_back(acc);
          }
        }
      }
    }
    head.insert(head.end(), tail.begin(), tail.end());
    return head;
  }
  // Find a conflicting entry in vec.
  bool FindConflict(const std::vector<AccessEntry>& vec,
                    const AccessEntry& e) {
    for (const AccessEntry& x : vec) {
      if (x.buffer == e.buffer &&
          !e.index.same_as(x.index)) return true;
    }
    return false;
  }
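  // Worked example (added commentary, assuming sync scope "shared"):
  // for a loop whose body is the single statement
  //
  //   A[tx] = A[tx + 1] + 1;   // A is shared
  //
  // one pass over the body finds no exposed prior access, hence no
  // conflict. The second rotation replays the statement against the
  // first pass's recorded read of A[tx + 1] and write of A[tx],
  // exposing the cross-iteration hazard; this is why max_seq is
  // doubled when is_loop is true.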
  // Whether we are inside a condition.
  int condition_counter_{0};
  // Whether we are inside the device environment.
  bool in_device_env_{false};
  // Whether load is allowed.
  bool allow_load_{false};
  // The current free stmt entry.
  StmtEntry curr_stmt_;
  // access scopes
  std::vector<std::vector<StmtEntry> > scope_;
  // The storage scope of each buffer.
  std::unordered_map<const Variable*, StorageScope> storage_scope_;
  // The sync scope we care about.
  StorageScope sync_scope_;
};

class StorageSyncInserter : public IRMutator {
 public:
  StorageSyncInserter(StorageScope sync_scope,
                      const std::unordered_set<const Node*>& syncs)
      : sync_scope_(sync_scope), syncs_(syncs) {}

  Stmt Mutate(Stmt stmt) final {
    if (syncs_.size() == 0) return stmt;
    stmt = IRMutator::Mutate(stmt);
    if (syncs_.count(stmt.get())) {
      Stmt barrier;
      if (sync_scope_.rank == 0) {
        barrier = MakeGlobalBarrier();
      } else {
        barrier = Evaluate::make(
            Call::make(Int(32), intrinsic::tvm_storage_sync,
                       {StringImm::make(sync_scope_.to_string())},
                       Call::Intrinsic));
      }
      stmt = Block::make(barrier, stmt);
    }
    return stmt;
  }
  Expr Mutate_(const Load* op, const Expr& e) final {
    if (sync_scope_.rank == 0 &&
        GetScope(op->buffer_var.get()).rank == 0) {
      ++rw_stats_[op->buffer_var].read_count;
    }
    return IRMutator::Mutate_(op, e);
  }
  Stmt Mutate_(const Store* op, const Stmt& s) final {
    if (sync_scope_.rank == 0 &&
        GetScope(op->buffer_var.get()).rank == 0) {
      ++rw_stats_[op->buffer_var].write_count;
    }
    return IRMutator::Mutate_(op, s);
  }
  Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
    if (op->attr_key == attr::thread_extent) {
      bool temp = true;
      std::swap(temp, in_thread_env_);
      thread_extents_.push_back(op);
      Stmt ret = IRMutator::Mutate_(op, s);
      thread_extents_.pop_back();
      std::swap(temp, in_thread_env_);
      // first thread scope.
      if (!in_thread_env_ && sync_scope_.rank == 0) {
        ret = InitGlobalBarrier(ret.as<AttrStmt>());
        num_blocks_ = Expr();
        is_lead_ = Expr();
      }
      return ret;
    } else if (op->attr_key == attr::storage_scope) {
      const Variable* buf = op->node.as<Variable>();
      storage_scope_[buf] =
          StorageScope::make(op->value.as<StringImm>()->value);
      return IRMutator::Mutate_(op, s);
    } else {
      return IRMutator::Mutate_(op, s);
    }
  }

 private:
  // Read/write statistics of a buffer.
  struct Entry {
    int read_count{0};
    int write_count{0};
  };
  // Get current storage scope.
  StorageScope GetScope(const Variable* buf) const {
    auto it = storage_scope_.find(buf);
    StorageScope s;
    s.rank = 0;
    if (it == storage_scope_.end()) return s;
    return it->second;
  }
  // private functions.
  Stmt InitGlobalBarrier(const AttrStmt* op) {
    CHECK(op != nullptr);
    Array<Expr> pargs = {
        StringImm::make(runtime::symbol::tvm_prepare_global_barrier)};
    Stmt prep = Evaluate::make(
        Call::make(Int(32), intrinsic::tvm_call_packed,
                   pargs, Call::Intrinsic));
    Stmt body = op->body;
    for (const auto& kv : rw_stats_) {
      const auto& e = kv.second;
      if (e.read_count != 0 && e.write_count != 0) {
        body = AttrStmt::make(kv.first, attr::volatile_scope, 1, body);
      }
    }
    rw_stats_.clear();
    Stmt kinit = Evaluate::make(
        Call::make(Int(32), intrinsic::tvm_global_barrier_kinit,
                   {}, Call::Intrinsic));
    body = Block::make(kinit, body);
    body = AttrStmt::make(
        op->node, op->attr_key, op->value, body);
    return Block::make(prep, body);
  }
  Stmt MakeGlobalBarrier() {
    CHECK_EQ(sync_scope_.rank, 0);
    if (!num_blocks_.defined()) {
      CHECK(!is_lead_.defined());
      num_work_dim_ = thread_extents_.size();
      for (const AttrStmt* attr : thread_extents_) {
        IterVar iv(attr->node.node_);
        runtime::ThreadScope s =
            runtime::ThreadScope::make(iv->thread_tag);
        if (s.rank == 0) {
          num_blocks_ = (num_blocks_.defined() ?
                         attr->value * num_blocks_ : attr->value);
        } else if (s.rank == 1) {
          Expr cond = iv->var == make_zero(iv->var.type());
          is_lead_ = is_lead_.defined() ? (is_lead_ && cond) : cond;
        }
      }
    } else {
      CHECK_EQ(num_work_dim_, thread_extents_.size());
    }
    return Evaluate::make(
        Call::make(Int(32), intrinsic::tvm_storage_sync,
                   {StringImm::make(sync_scope_.to_string()),
                    is_lead_, num_blocks_},
                   Call::Intrinsic));
  }
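  // Note (added commentary): for the rank-0 (global) scope the emitted
  // barrier carries extra arguments,
  //
  //   tvm_storage_sync("global", is_lead_, num_blocks_)
  //
  // where is_lead_ selects one leader thread per block (all threadIdx
  // vars equal to zero) and num_blocks_ is the product of the blockIdx
  // extents collected from the enclosing thread_extent attributes.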
  // Data structures.
  StorageScope sync_scope_;
  const std::unordered_set<const Node*>& syncs_;
  // The storage scope of each buffer.
  std::unordered_map<const Variable*, StorageScope> storage_scope_;
  // The read/write statistics of each buffer.
  std::unordered_map<VarExpr, Entry, NodeHash, NodeEqual> rw_stats_;
  // State for the global barrier: whether we are inside a thread env.
  bool in_thread_env_{false};
  // Memoized thread extents of the enclosing launch.
  std::vector<const AttrStmt*> thread_extents_;
  size_t num_work_dim_{0};
  Expr num_blocks_;
  Expr is_lead_;
};

Stmt StorageSync(Stmt stmt, std::string storage_scope) {
  StorageScope sync_scope = StorageScope::make(storage_scope);
  StorageSyncPlanner planner(sync_scope);
  planner.Visit(stmt);
  return StorageSyncInserter(sync_scope, planner.syncs_inserted_).Mutate(stmt);
}

LoweredFunc StorageSync(LoweredFunc f, std::string storage_scope) {
  CHECK_NE(f->func_type, kHostFunc);
  auto n = std::make_shared<LoweredFuncNode>(*f.operator->());
  n->body = StorageSync(f->body, storage_scope);
  return LoweredFunc(n);
}

}  // namespace ir
}  // namespace tvm
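// Usage sketch (added commentary, hypothetical driver code):
//
//   LoweredFunc device_func = ...;  // a function after host/device split
//   device_func = StorageSync(device_func, "shared");
//
// The CHECK_NE(f->func_type, kHostFunc) above reflects that this pass
// is intended to run on device functions only.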