/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
/*!
 * \file storage_rewrite.cc
 * \brief Memory access pattern analysis and optimization.
 *  Re-write data access to enable memory sharing when possible.
 */
#include <tvm/arith/analyzer.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/ir_pass.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/target/target_info.h>
#include <map>
#include <unordered_set>
#include <unordered_map>
#include "ir_util.h"
#include "../../arith/compute_expr.h"
#include "../../runtime/thread_storage_scope.h"

namespace tvm {
namespace tir {

using runtime::StorageRank;
using runtime::StorageScope;

// Find a linear pattern of storage access,
// used for liveness analysis.
// Composite scopes (loop/thread_launch/IfThenElse) are represented by two points:
//   before_scope -> scope_body -> after_scope
//
// The linear_seq_ stores before_scope and after_scope.
// The accesses to the arrays are stored at the after_scope point.
//
// Define "scope" as the body of For/thread_launch/IfThenElse.
// This pass tries to detect the last point at which we need to keep the memory
// alive under the same scope as the allocate.
// The storage needs to be kept alive between the allocate and the last access.
// The free point is only inserted at the same scope as the allocate.
//
class LinearAccessPatternFinder final : public StmtExprVisitor {
 public:
  /*! \brief record the touch history of a statement. */
  struct StmtEntry {
    // The statement
    const Object* stmt;
    // The index in the linear_seq_ to point to end of the nested scope.
    // This is only set to non-zero if stmt is a nested scope.
    // if offset > 0, means this is the begin, the end entry is current_index + offset
    // if offset < 0, means this is the end, the begin entry is current_index + offset
    int64_t scope_pair_offset{0};
    // The buffer variables this statement touched.
    std::vector<const VarNode*> touched;
  };
  // The scope of each allocation
  struct AllocEntry {
    // Scope used for allocation.
    StorageScope storage_scope;
    // scope level
    size_t level{0};
    // allocation stmt
    const AllocateNode* alloc{nullptr};
  };

  void VisitStmt_(const AllocateNode* op) final {
    size_t level = scope_.size();
    const VarNode* buf = op->buffer_var.get();
    auto it = alloc_info_.find(buf);
    CHECK(it != alloc_info_.end());
    CHECK(it->second.alloc == nullptr);
    it->second.alloc = op;
    it->second.level = level;
    StmtExprVisitor::VisitStmt_(op);
  }
  void VisitStmt_(const StoreNode* op) final {
    scope_.push_back(StmtEntry());
    // visit subexpr
    StmtExprVisitor::VisitStmt_(op);
    // Add write access.
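    // (The touched buffer is recorded at the scope level of its allocation,
    //  it->second.level, rather than at the innermost scope, so the liveness
    //  analysis later sees the access relative to the allocation's own scope.)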
    const VarNode* buf = op->buffer_var.get();
    auto it = alloc_info_.find(buf);
    if (it != alloc_info_.end() && it->second.alloc) {
      CHECK_LT(it->second.level, scope_.size());
      scope_[it->second.level].touched.push_back(buf);
    }
    StmtEntry e = scope_.back();
    scope_.pop_back();
    if (e.touched.size() != 0) {
      e.stmt = op;
      linear_seq_.push_back(e);
    }
  }
  void VisitStmt_(const EvaluateNode* op) final {
    scope_.push_back(StmtEntry());
    // visit subexpr
    StmtExprVisitor::VisitStmt_(op);
    StmtEntry e = scope_.back();
    scope_.pop_back();
    if (e.touched.size() != 0) {
      e.stmt = op;
      linear_seq_.push_back(e);
    }
  }
  void VisitExpr_(const LoadNode* op) final {
    // Add read access.
    StmtExprVisitor::VisitExpr_(op);
    const VarNode* buf = op->buffer_var.get();
    auto it = alloc_info_.find(buf);
    if (it != alloc_info_.end() && it->second.alloc) {
      CHECK_LT(it->second.level, scope_.size())
          << "Load memory in places other than store.";
      scope_[it->second.level].touched.push_back(buf);
    }
  }
  void VisitExpr_(const CallNode* op) final {
    if (op->is_intrinsic(intrinsic::tvm_address_of)) {
      const LoadNode* l = op->args[0].as<LoadNode>();
      this->VisitExpr(l->index);
    } else {
      StmtExprVisitor::VisitExpr_(op);
    }
  }
  void VisitExpr_(const VarNode* buf) final {
    // A direct reference to the variable counts as a read.
    auto it = alloc_info_.find(buf);
    if (it != alloc_info_.end() && it->second.alloc) {
      CHECK_LT(it->second.level, scope_.size())
          << " buf=" << buf->name_hint;
      scope_[it->second.level].touched.push_back(buf);
    }
  }
  template<typename T>
  void VisitNewScope(const T* op) {
    scope_.push_back(StmtEntry());
    StmtEntry e;
    e.stmt = op;
    int64_t begin_index = static_cast<int64_t>(linear_seq_.size());
    // before scope.
    linear_seq_.push_back(e);
    StmtExprVisitor::VisitStmt_(op);
    // after scope.
    e.touched = std::move(scope_.back().touched);
    scope_.pop_back();
    int64_t end_index = static_cast<int64_t>(linear_seq_.size());
    CHECK_GT(end_index, begin_index);
    e.scope_pair_offset = begin_index - end_index;
    linear_seq_.push_back(e);
    // record the pointer to end index.
    CHECK_NE(end_index, 0U);
    linear_seq_[begin_index].scope_pair_offset = end_index - begin_index;
  }
  void VisitStmt_(const AttrStmtNode* op) final {
    // Only record the outermost thread extent.
    if (op->attr_key == attr::thread_extent && !in_thread_env_) {
      in_thread_env_ = true;
      VisitNewScope(op);
      in_thread_env_ = false;
    } else if (op->attr_key == attr::extern_scope) {
      VisitNewScope(op);
    } else if (op->attr_key == attr::virtual_thread) {
      VisitNewScope(op);
    } else if (op->attr_key == attr::storage_scope) {
      const VarNode* buf = op->node.as<VarNode>();
      alloc_info_[buf].storage_scope =
          StorageScope::make(op->value.as<StringImmNode>()->value);
      StmtExprVisitor::VisitStmt_(op);
    } else {
      StmtExprVisitor::VisitStmt_(op);
    }
  }
  void VisitStmt_(const IfThenElseNode* op) final {
    VisitNewScope(op);
  }
  void VisitStmt_(const ForNode* op) final {
    VisitNewScope(op);
  }
  void VisitStmt_(const AssertStmtNode* op) final {
    VisitNewScope(op);
  }
  // linearized access sequence.
  std::vector<StmtEntry> linear_seq_;
  // The storage scope of each buffer
  std::unordered_map<const VarNode*, AllocEntry> alloc_info_;

 private:
  // Whether already in thread env.
  bool in_thread_env_{false};
  // The scope stack.
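  // One entry is pushed per nested scope, and per Store/Evaluate while its
  // sub-expressions are visited; touched buffers are accumulated into the
  // entry at the level of the corresponding allocation.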
  std::vector<StmtEntry> scope_;
};

// Verify if the statement can be run safely in an in-place fashion.
//
// Detect pattern: dst[index] = f(src[index])
//
// WARNING: the current detection algorithm cannot handle the case
// when a location in an array is written multiple times.
//
// For example, the following program will pass the check,
// but we cannot make A and B the same array.
//
//  A[0] = B[0] + 1
//  A[0] = B[0] + 1
//
// The high level code generator needs to ensure that the generated
// code writes each location of the target array only once.
//
// This is the case with IR generated by the current compute schedule.
// We explicitly return false if we find there is an extern block,
// which can contain arbitrary IR.
//
// Nevertheless, the inplace detector should be used with care.
// We may also consider introducing a condition checker that checks
// whether every index is visited only once, as an absolutely sufficient condition.
//
// The code after inplace transformation is no longer idempotent.
//
class InplaceOpVerifier : public StmtExprVisitor {
 public:
  bool Check(const Object* stmt,
             const VarNode* dst,
             const VarNode* src) {
    dst_ = dst;
    src_ = src;
    result_ = true;
    if (stmt->IsInstance<AttrStmtNode>()) {
      VisitStmt_(static_cast<const AttrStmtNode*>(stmt));
    } else if (stmt->IsInstance<ForNode>()) {
      VisitStmt_(static_cast<const ForNode*>(stmt));
    } else if (stmt->IsInstance<IfThenElseNode>()) {
      VisitStmt_(static_cast<const IfThenElseNode*>(stmt));
    } else if (stmt->IsInstance<StoreNode>()) {
      VisitStmt_(static_cast<const StoreNode*>(stmt));
    } else {
      return false;
    }
    return result_;
  }

  using StmtExprVisitor::VisitStmt_;

  void VisitStmt(const Stmt& n) final {
    if (!result_) return;
    StmtExprVisitor::VisitStmt(n);
  }
  void VisitExpr(const PrimExpr& n) final {
    if (!result_) return;
    StmtExprVisitor::VisitExpr(n);
  }

  void VisitExpr_(const VarNode* op) final {
    // assume all opaque access is unsafe
    if (op == dst_ || op == src_) {
      result_ = false;
      return;
    }
  }

  void VisitStmt_(const StoreNode* op) final {
    ++mem_nest_;
    this->VisitExpr(op->index);
    --mem_nest_;
    if (op->buffer_var.get() == dst_) {
      store_ = op;
      this->VisitExpr(op->value);
      this->VisitExpr(op->predicate);
      store_ = nullptr;
    } else {
      this->VisitExpr(op->value);
      this->VisitExpr(op->predicate);
    }
  }

  void VisitStmt_(const AttrStmtNode* op) final {
    // always reject extern code
    if (op->attr_key == attr::extern_scope ||
        op->attr_key == attr::volatile_scope) {
      result_ = false;
      return;
    }
    StmtExprVisitor::VisitStmt_(op);
  }

  void VisitExpr_(const LoadNode* op) final {
    const VarNode* buf = op->buffer_var.get();
    // cannot read from dst_ (no reduction)
    if (buf == dst_) {
      result_ = false;
      return;
    }
    // do not allow indirect memory load
    if (mem_nest_ != 0) {
      result_ = false;
      return;
    }
    if (src_ == buf) {
      if (store_ == nullptr ||
          store_->value.dtype() != op->dtype ||
          !tir::Equal(store_->index, op->index)) {
        result_ = false;
        return;
      }
    }
    ++mem_nest_;
    StmtExprVisitor::VisitExpr_(op);
    --mem_nest_;
  }

 private:
  // result of the check
  bool result_{true};
  // destination memory
  const VarNode* dst_;
  // source variable
  const VarNode* src_;
  // counter of load,
  // it is not safe to inplace when there is a nested load like A[B[i]]
  int mem_nest_{0};
  // The current store to be inspected
  const StoreNode* store_{nullptr};
};

// Planner to plan and rewrite memory allocation.
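//
// The planner works in three phases over the linearized access sequence:
//   1. LivenessAnalysis: compute gen/kill events for every buffer variable.
//   2. PlanMemory: walk the sequence and reuse a freed StorageEntry (or fold
//      an allocation into its in-place source) when the attach scope, storage
//      scope and element type are compatible.
//   3. PrepareNewAlloc: materialize one merged Allocate per StorageEntry and
//      attach it at the recorded attach scope.
//
// Illustrative sketch (pseudo-IR, not actual syntax): given
//   allocate A[100]; ... last use of A ...; allocate B[100]; ... uses of B ...
// where the last use of A precedes the first use of B in the same scope,
// B is assigned A's StorageEntry and only one backing buffer is emitted.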
class StoragePlanRewriter : public StmtExprMutator {
 public:
  using StmtEntry = LinearAccessPatternFinder::StmtEntry;
  using AllocEntry = LinearAccessPatternFinder::AllocEntry;

  Stmt Rewrite(Stmt stmt, bool detect_inplace) {
    detect_inplace_ = detect_inplace;
    // plan the rewrite
    LinearAccessPatternFinder finder;
    finder(stmt);
    this->LivenessAnalysis(finder.linear_seq_);
    this->PlanMemory(finder.linear_seq_, finder.alloc_info_);
    this->PrepareNewAlloc();
    // start rewrite
    stmt = operator()(std::move(stmt));
    if (attach_map_.count(nullptr)) {
      std::vector<Stmt> nest;
      for (StorageEntry* e : attach_map_.at(nullptr)) {
        // CHECK_EQ(e->scope.rank, 0);
        if (e->new_alloc.defined()) {
          nest.emplace_back(AttrStmtNode::make(
              e->alloc_var, attr::storage_scope,
              StringImmNode::make(e->scope.to_string()),
              EvaluateNode::make(0)));
          nest.push_back(e->new_alloc);
        }
      }
      stmt = MergeNest(nest, stmt);
    }
    return stmt;
  }
  Stmt VisitStmt_(const StoreNode* op) final {
    Stmt stmt = StmtExprMutator::VisitStmt_(op);
    op = stmt.as<StoreNode>();
    auto it = alloc_map_.find(op->buffer_var.get());
    if (it == alloc_map_.end()) return stmt;
    return StoreNode::make(it->second->alloc_var,
                           op->value,
                           RemapIndex(op->value.dtype(), op->index, it->second),
                           op->predicate);
  }
  PrimExpr VisitExpr_(const LoadNode* op) final {
    PrimExpr expr = StmtExprMutator::VisitExpr_(op);
    op = expr.as<LoadNode>();
    auto it = alloc_map_.find(op->buffer_var.get());
    if (it == alloc_map_.end()) return expr;
    return LoadNode::make(op->dtype,
                          it->second->alloc_var,
                          RemapIndex(op->dtype, op->index, it->second),
                          op->predicate);
  }
  PrimExpr VisitExpr_(const VarNode* op) final {
    auto it = alloc_map_.find(op);
    if (it != alloc_map_.end()) {
      if (it->second->bits_offset != 0) {
        LOG(WARNING) << "Use a merged buffer variable address, could cause error";
      }
      return it->second->alloc_var;
    } else {
      return GetRef<PrimExpr>(op);
    }
  }
  PrimExpr VisitExpr_(const CallNode* op) final {
    if (op->is_intrinsic(intrinsic::tvm_access_ptr)) {
      CHECK_EQ(op->args.size(), 5U);
      DataType dtype = op->args[0].dtype();
      const VarNode* buffer = op->args[1].as<VarNode>();
      auto it = alloc_map_.find(buffer);
      if (it == alloc_map_.end()) {
        return StmtExprMutator::VisitExpr_(op);
      }
      const StorageEntry* se = it->second;
      PrimExpr offset = this->VisitExpr(op->args[2]);
      PrimExpr extent = this->VisitExpr(op->args[3]);
      uint64_t elem_bits = dtype.bits() * dtype.lanes();
      CHECK_EQ(se->bits_offset % elem_bits, 0U);
      if (se->bits_offset != 0) {
        offset = make_const(offset.dtype(), se->bits_offset / elem_bits) + offset;
      }
      return CallNode::make(
          op->dtype, op->name,
          {op->args[0], se->alloc_var, offset, extent, op->args[4]},
          op->call_type);
    } else {
      return StmtExprMutator::VisitExpr_(op);
    }
  }
  Stmt VisitStmt_(const AttrStmtNode* op) final {
    if (op->attr_key == attr::storage_scope) {
      return this->VisitStmt(op->body);
    } else if (op->attr_key == attr::thread_extent ||
               op->attr_key == attr::virtual_thread ||
               attr::IsPragmaKey(op->attr_key)) {
      // remake all the allocations at the attach scope.
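      // (The allocations were hoisted to this attach point during PlanMemory;
      //  MakeAttach re-emits them, each wrapped in a storage_scope attr,
      //  around the rewritten body.)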
      if (attach_map_.count(op)) {
        auto& svec = attach_map_[op];
        Stmt stmt = StmtExprMutator::VisitStmt_(op);
        op = stmt.as<AttrStmtNode>();
        return AttrStmtNode::make(
            op->node, op->attr_key, op->value,
            MakeAttach(svec, op->body));
      } else {
        return StmtExprMutator::VisitStmt_(op);
      }
    } else if (op->attr_key == attr::volatile_scope) {
      Stmt stmt = StmtExprMutator::VisitStmt_(op);
      op = stmt.as<AttrStmtNode>();
      auto it = alloc_map_.find(op->node.as<VarNode>());
      if (it == alloc_map_.end()) return stmt;
      return AttrStmtNode::make(
          it->second->alloc_var, op->attr_key, op->value, op->body);
    } else {
      return StmtExprMutator::VisitStmt_(op);
    }
  }
  Stmt VisitStmt_(const ForNode* op) final {
    CHECK(op->for_type != ForType::Vectorized)
        << "VectorizeLoop before LiftStorageAlloc";
    // remake all the allocations at the attach scope.
    if (attach_map_.count(op)) {
      auto& svec = attach_map_[op];
      Stmt stmt = StmtExprMutator::VisitStmt_(op);
      op = stmt.as<ForNode>();
      return ForNode::make(
          op->loop_var, op->min, op->extent, op->for_type, op->device_api,
          MakeAttach(svec, op->body));
    } else {
      return StmtExprMutator::VisitStmt_(op);
    }
  }
  Stmt VisitStmt_(const AllocateNode* op) final {
    return this->VisitStmt(op->body);
  }

 private:
  struct StorageEntry {
    // The scope that this alloc attaches after.
    // For shared/local memory it is the beginning of the thread extent.
    // For global memory it is nullptr, meaning the beginning of everything.
    const Object* attach_scope_{nullptr};
    // The constant size of the buffer in bits, only used if it is constant
    uint64_t const_nbits{0};
    // The storage scope.
    StorageScope scope;
    // Allocs that share this entry.
    std::vector<const AllocateNode*> allocs;
    // The children of this entry, not including itself.
    std::vector<StorageEntry*> merged_children;
    // The replacement allocation, if any.
    Stmt new_alloc;
    // The var expr of new allocation.
    Var alloc_var;
    // The allocation element type.
    DataType elem_type;
    // This is non-zero if this allocate is folded into another one;
    // the address (in bits) becomes alloc_var + bits_offset,
    // which can be effectively converted to the element type.
    // We need to convert bits_offset to an offset of the specific element type later.
    //
    // We use bits (instead of bytes) to support non-conventional indexing in hardware.
    // When we are merging buffers together, the bits_offset is set to be aligned
    // to a certain value given by the max_simd_bits property of the special memory.
    //
    // This allows effective sharing among different types as long as their alignment
    // requirement fits into the max_simd_bits.
    uint64_t bits_offset{0};
  };
  // Allocate entry of node.
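  //
  // Example of tagged-memory merging (illustrative): if allocations A and B
  // share the same special scope tag and attach scope, B becomes a
  // merged_child of A's entry and is addressed as A's alloc_var plus
  // B's bits_offset, aligned to the memory's max_simd_bits.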
  // Event entry in liveness analysis
  struct EventEntry {
    // variables we generate
    std::vector<const VarNode*> gen;
    // variables we kill
    std::vector<const VarNode*> kill;
  };

  Stmt MakeAttach(const std::vector<StorageEntry*>& svec,
                  Stmt body) {
    std::vector<Stmt> nest;
    for (StorageEntry* e : svec) {
      if (e->new_alloc.defined()) {
        nest.emplace_back(AttrStmtNode::make(
            e->alloc_var, attr::storage_scope,
            StringImmNode::make(e->scope.to_string()),
            EvaluateNode::make(0)));
        nest.push_back(e->new_alloc);
      }
    }
    return MergeNest(nest, body);
  }
  // Remap the index
  PrimExpr RemapIndex(DataType dtype, PrimExpr index, StorageEntry* e) {
    if (e->bits_offset == 0) return index;
    uint64_t elem_bits = dtype.bits() * dtype.lanes();
    CHECK_EQ(e->bits_offset % elem_bits, 0U);
    return make_const(index.dtype(), e->bits_offset / elem_bits) + index;
  }
  // Prepare the new allocations
  void PrepareNewAlloc() {
    for (size_t i = 0; i < alloc_vec_.size(); ++i) {
      StorageEntry* e = alloc_vec_[i].get();
      attach_map_[e->attach_scope_].push_back(e);
    }
    // find allocation via attach map.
    for (auto &kv : attach_map_) {
      // find the element with the largest number of bytes.
      std::vector<StorageEntry*>& vec = kv.second;
      // try to find merge candidates for tagged memory
      for (size_t i = 0; i < vec.size(); ++i) {
        StorageEntry* e = vec[i];
        if (e->scope.tag.length() != 0) {
          CHECK_NE(e->const_nbits, 0U)
              << "Special tagged memory must be const size";
          for (size_t j = 0; j < i; ++j) {
            if (e->scope == vec[j]->scope) {
              vec[j]->merged_children.push_back(e);
              break;
            }
          }
        }
      }
      // Start allocation
      for (size_t i = 0; i < vec.size(); ++i) {
        StorageEntry* e = vec[i];
        // already merged
        if (e->bits_offset != 0) continue;
        if (e->merged_children.size() != 0) {
          NewAllocTagMerged(e);
          continue;
        }
        // Get the allocation size.
        e->alloc_var = e->allocs[0]->buffer_var;
        DataType alloc_type = e->allocs[0]->dtype;
        for (const AllocateNode* op : e->allocs) {
          if (op->dtype.lanes() > alloc_type.lanes()) {
            alloc_type = op->dtype;
          }
        }
        if (e->allocs.size() == 1) {
          // simply use the original allocation.
          PrimExpr sz = arith::ComputeReduce<MulNode>(e->allocs[0]->extents,
                                                      make_const(DataType::Int(32), 1));
          e->new_alloc = AllocateNode::make(
              e->alloc_var, alloc_type, {sz},
              e->allocs[0]->condition, EvaluateNode::make(0));
          if (e->scope.tag.length() != 0) {
            MemoryInfo info = GetMemoryInfo(e->scope.to_string());
            uint64_t total_elem = e->const_nbits / e->elem_type.bits();
            CHECK_LE(total_elem * e->elem_type.bits(), info->max_num_bits)
                << "Allocation exceeds bound of memory tag " << e->scope.to_string();
          }
        } else {
          // Build a merged allocation
          PrimExpr combo_size;
          for (const AllocateNode* op : e->allocs) {
            PrimExpr sz = arith::ComputeReduce<MulNode>(
                op->extents, make_const(DataType::Int(32), 1));
            auto nbits = op->dtype.bits() * op->dtype.lanes();
            if (const auto* imm = sz.as<IntImmNode>()) {
              if (imm->value > std::numeric_limits<int>::max() / nbits) {
                LOG(WARNING) << "The allocation requires : " << imm->value
                             << " * " << nbits
                             << " bits, which is greater than the maximum of"
                                " int32. The size is cast to int64."
                             << "\n";
                sz = make_const(DataType::Int(64), imm->value);
              }
            }
            // transform to bits
            auto sz_nbits = sz * nbits;
            if (combo_size.defined()) {
              combo_size = max(combo_size, sz_nbits);
            } else {
              combo_size = sz_nbits;
            }
          }
          // transform from bits to the number of alloc_type elements
          auto type_bits = alloc_type.bits() * alloc_type.lanes();
          bool divided =
              analyzer_.CanProve(indexmod(combo_size, type_bits) == 0);
          combo_size = indexdiv(combo_size, type_bits);
          // round up when the size is not exactly divisible
          if (!divided) {
            combo_size = combo_size + make_const(DataType::Int(32), 1);
          }
          combo_size = tir::Simplify(combo_size);
          e->new_alloc = AllocateNode::make(
              e->alloc_var, alloc_type, {combo_size}, const_true(),
              EvaluateNode::make(0));
          if (e->scope.tag.length() != 0) {
            MemoryInfo info = GetMemoryInfo(e->scope.to_string());
            uint64_t total_elem = e->const_nbits / e->elem_type.bits();
            CHECK_LE(total_elem * e->elem_type.bits(), info->max_num_bits)
                << "Allocation exceeds bound of memory tag " << e->scope.to_string();
          }
        }
      }
    }
  }
  // New allocation for merged data
  void NewAllocTagMerged(StorageEntry* e) {
    CHECK_NE(e->scope.tag.length(), 0U);
    // allocate with element type.
    CHECK_NE(e->const_nbits, 0U);
    MemoryInfo info = GetMemoryInfo(e->scope.to_string());
    uint64_t total_bits = e->const_nbits;
    // By default, align to 32 bits.
    size_t align = 32;
    if (info.defined()) {
      align = info->max_simd_bits;
    }
    // Always align to max_simd_bits
    // so we can remap types by keeping this property
    if (total_bits % align != 0) {
      total_bits += align - (total_bits % align);
    }
    e->alloc_var = e->allocs[0]->buffer_var;
    for (StorageEntry* child : e->merged_children) {
      CHECK_NE(child->const_nbits, 0U);
      CHECK_NE(total_bits, 0U);
      child->bits_offset = total_bits;
      child->alloc_var = e->alloc_var;
      total_bits += child->const_nbits;
      if (total_bits % align != 0) {
        total_bits += align - (total_bits % align);
      }
    }
    uint64_t type_bits = e->elem_type.bits() * e->elem_type.lanes();
    PrimExpr alloc_size = make_const(e->allocs[0]->extents[0].dtype(),
                                     (total_bits + type_bits - 1) / type_bits);
    e->new_alloc = AllocateNode::make(
        e->alloc_var, e->elem_type, {alloc_size}, const_true(),
        EvaluateNode::make(0));
    if (info.defined()) {
      CHECK_LE(total_bits, info->max_num_bits)
          << "Allocation exceeds bound of memory tag " << e->scope.to_string();
    }
  }
  // Liveness analysis to find gen and kill point of each variable.
  void LivenessAnalysis(const std::vector<StmtEntry>& seq) {
    // find kill point, do a reverse linear scan.
    std::unordered_set<const VarNode*> touched;
    for (size_t i = seq.size(); i != 0; --i) {
      const StmtEntry& s = seq[i - 1];
      for (const VarNode* buffer : s.touched) {
        if (!touched.count(buffer)) {
          touched.insert(buffer);
          event_map_[s.stmt].kill.push_back(buffer);
        }
      }
    }
    // find gen point, do forward scan
    touched.clear();
    for (size_t i = 0; i < seq.size(); ++i) {
      int64_t offset = seq[i].scope_pair_offset;
      if (offset < 0) continue;
      const StmtEntry& s = seq[i + offset];
      for (const VarNode* buffer : s.touched) {
        if (!touched.count(buffer)) {
          touched.insert(buffer);
          event_map_[s.stmt].gen.push_back(buffer);
        }
      }
    }
  }
  void PlanNewScope(const Object* op) {
    if (thread_scope_ != nullptr) {
      CHECK(thread_scope_ == op);
      // erase all memory attached to this scope.
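      // Buffers attached to the closing thread scope must not be reused by
      // later, unrelated scopes, so drop them from both free lists.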
      for (auto it = const_free_map_.begin(); it != const_free_map_.end();) {
        if (it->second->attach_scope_ == op) {
          it = const_free_map_.erase(it);
        } else {
          ++it;
        }
      }
      for (auto it = sym_free_list_.begin(); it != sym_free_list_.end();) {
        if ((*it)->attach_scope_ == op) {
          it = sym_free_list_.erase(it);
        } else {
          ++it;
        }
      }
      thread_scope_ = nullptr;
    } else {
      thread_scope_ = op;
    }
  }

  // Memory plan algorithm
  void PlanMemory(const std::vector<StmtEntry>& seq,
                  const std::unordered_map<const VarNode*, AllocEntry>& alloc_info) {
    std::unordered_set<const VarNode*> inplace_flag;
    for (size_t i = 0; i < seq.size(); ++i) {
      const StmtEntry& s = seq[i];
      auto it = event_map_.find(seq[i].stmt);
      // scope_pair_offset >= 0 means it is either
      // - leaf stmt (offset = 0)
      // - beginning of scope (offset > 0)
      // In both cases, we need to handle the gen event correctly
      if (it != event_map_.end() && seq[i].scope_pair_offset >= 0) {
        // Inplace operation detection: handle this case specially.
        bool detect_inplace = detect_inplace_ && (it->second.gen.size() <= 2);
        for (const VarNode* var : it->second.gen) {
          CHECK(alloc_info.count(var));
          const AllocEntry& ae = alloc_info.at(var);
          StorageEntry* dst_entry = nullptr;
          // inplace detection
          if (detect_inplace) {
            // only one inplace var for s.stmt
            bool inplace_found = false;
            for (const VarNode* src : it->second.kill) {
              if (!inplace_flag.count(src) && alloc_map_.count(src)) {
                InplaceOpVerifier visitor;
                StorageEntry* src_entry = alloc_map_.at(src);
                if (src_entry->scope == ae.storage_scope &&
                    src_entry->attach_scope_ == thread_scope_ &&
                    src_entry->elem_type == ae.alloc->dtype.element_of() &&
                    visitor.Check(s.stmt, var, src)) {
                  uint64_t const_nbits =
                      static_cast<uint64_t>(ae.alloc->constant_allocation_size()) *
                      ae.alloc->dtype.bits() *
                      ae.alloc->dtype.lanes();
                  if (src_entry->const_nbits == const_nbits && !inplace_found) {
                    // successful inplace
                    dst_entry = src_entry;
                    inplace_flag.insert(src);
                    inplace_found = true;
                  }
                }
              }
            }
          }
          if (dst_entry == nullptr) {
            dst_entry = FindAlloc(ae.alloc, thread_scope_, ae.storage_scope);
          }
          dst_entry->allocs.emplace_back(ae.alloc);
          alloc_map_[var] = dst_entry;
        }
      }
      // enter/exit new scope
      if (s.stmt->IsInstance<AttrStmtNode>()) {
        const auto* op = static_cast<const AttrStmtNode*>(s.stmt);
        if (op->attr_key == attr::thread_extent ||
            op->attr_key == attr::virtual_thread ||
            attr::IsPragmaKey(op->attr_key)) {
          PlanNewScope(op);
        } else {
          CHECK(op->attr_key == attr::extern_scope);
        }
      } else if (s.stmt->IsInstance<ForNode>()) {
        const auto* op = static_cast<const ForNode*>(s.stmt);
        if (op->for_type == ForType::Parallel) {
          if (thread_scope_ == nullptr || thread_scope_ == op) {
            PlanNewScope(op);
          }
        }
      }
      // scope_pair_offset <= 0 means it is either
      // - leaf stmt (offset = 0)
      // - end of scope (offset < 0)
      // In both cases, we need to handle the kill event correctly
      if (it != event_map_.end() && seq[i].scope_pair_offset <= 0) {
        for (const VarNode* var : it->second.kill) {
          // skip spaces that have already been replaced by inplace
          if (!inplace_flag.count(var)) {
            this->Free(var);
          }
        }
      }
    }
  }
  // Allocate new storage entry.
  StorageEntry* NewAlloc(const AllocateNode* op,
                         const Object* attach_scope,
                         const StorageScope& scope,
                         size_t const_nbits) {
    CHECK(op != nullptr);
    // Re-use not successful, allocate a new buffer.
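    // The new entry is owned by alloc_vec_; a raw pointer is returned because
    // the entry is referenced from alloc_map_/attach_map_ for the lifetime of
    // the rewrite.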
    std::unique_ptr<StorageEntry> entry(new StorageEntry());
    entry->attach_scope_ = attach_scope;
    entry->scope = scope;
    entry->elem_type = op->dtype.element_of();
    entry->const_nbits = const_nbits;
    StorageEntry* e = entry.get();
    alloc_vec_.emplace_back(std::move(entry));
    return e;
  }

  StorageEntry* FindAlloc(const AllocateNode* op,
                          const Object* attach_scope,
                          const StorageScope& scope) {
    CHECK(op != nullptr);
    // skip planning for local variables,
    // the compiler can do a better job with register allocation.
    const uint64_t match_range = 16;
    uint64_t op_elem_bits = op->dtype.bits() * op->dtype.lanes();
    uint64_t const_nbits = static_cast<uint64_t>(
        op->constant_allocation_size() * op_elem_bits);
    // disable reuse of small arrays; they will be lowered to registers in LLVM.
    // This rule only applies when we are using non-special memory.
    if (scope.tag.length() == 0) {
      if (scope.rank >= StorageRank::kWarp || op->dtype.is_handle()) {
        return NewAlloc(op, attach_scope, scope, const_nbits);
      }
      if (const_nbits > 0 && const_nbits <= 32) {
        return NewAlloc(op, attach_scope, scope, const_nbits);
      }
    }
    if (const_nbits != 0) {
      // constant allocation.
      auto begin = const_free_map_.lower_bound(const_nbits / match_range);
      auto mid = const_free_map_.lower_bound(const_nbits);
      auto end = const_free_map_.upper_bound(const_nbits * match_range);
      // start looking at the buffers that are bigger than the required size first
      for (auto it = mid; it != end; ++it) {
        StorageEntry *e = it->second;
        if (e->attach_scope_ != attach_scope) continue;
        if (e->scope != scope) continue;
        // when not divisible, no reuse, e.g. float4 vs float3
        if (e->bits_offset % op_elem_bits != 0) continue;
        e->const_nbits = std::max(const_nbits, e->const_nbits);
        const_free_map_.erase(it);
        return e;
      }
      // then start looking at smaller buffers.
      for (auto it = mid; it != begin;) {
        --it;
        StorageEntry *e = it->second;
        if (e->attach_scope_ != attach_scope) continue;
        if (e->scope != scope) continue;
        if (e->elem_type != op->dtype.element_of()) continue;
        e->const_nbits = std::max(const_nbits, e->const_nbits);
        const_free_map_.erase(it);
        return e;
      }
    } else {
      // Simple strategy: round robin.
      for (auto it = sym_free_list_.begin();
           it != sym_free_list_.end(); ++it) {
        StorageEntry* e = *it;
        if (e->attach_scope_ != attach_scope) continue;
        if (e->scope != scope) continue;
        if (e->elem_type != op->dtype.element_of()) continue;
        sym_free_list_.erase(it);
        return e;
      }
    }
    return NewAlloc(op, attach_scope, scope, const_nbits);
  }
  // simulated free.
  void Free(const VarNode* var) {
    auto it = alloc_map_.find(var);
    CHECK(it != alloc_map_.end());
    StorageEntry* e = it->second;
    CHECK_NE(e->allocs.size(), 0U);
    // disable reuse of small arrays; they will be lowered to registers in LLVM.
    // This rule only applies when we are using non-special memory.
    if (e->scope.tag.length() == 0) {
      // Disable sharing of local memory.
      if (e->scope.rank >= StorageRank::kWarp ||
          e->allocs[0]->dtype.is_handle()) return;
      // disable reuse of small arrays
      if (e->const_nbits > 0 && e->const_nbits <= 32) return;
    }
    // normal free.
    if (e->const_nbits != 0) {
      const_free_map_.insert({e->const_nbits, e});
    } else {
      sym_free_list_.push_back(e);
    }
  }
  // thread scope.
  const Object* thread_scope_{nullptr};
  // whether enable inplace detection.
  bool detect_inplace_{false};
  // Locations of free ops.
  std::unordered_map<const Object*, EventEntry> event_map_;
  // constant size free map.
  std::multimap<uint64_t, StorageEntry*> const_free_map_;
  // symbolic free list, for non constant items.
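  // (const_free_map_ is keyed by size so FindAlloc can search within
  //  [const_nbits / match_range, const_nbits * match_range]; allocations with
  //  symbolic sizes instead go through this first-fit list.)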
  std::list<StorageEntry*> sym_free_list_;
  // The allocation attach map
  std::unordered_map<const Object*, std::vector<StorageEntry*> > attach_map_;
  // The allocation assign map
  std::unordered_map<const VarNode*, StorageEntry*> alloc_map_;
  // The allocations
  std::vector<std::unique_ptr<StorageEntry> > alloc_vec_;
  // analyzer
  arith::Analyzer analyzer_;
};

// Turn an alloc into a vector alloc
// if all its accesses are of the same vector type.
class VectorAllocRewriter : public StmtExprMutator {
 public:
  PrimExpr VisitExpr_(const LoadNode* op) final {
    UpdateTypeMap(op->buffer_var.get(), op->dtype);
    return StmtExprMutator::VisitExpr_(op);
  }

  Stmt VisitStmt_(const StoreNode* op) final {
    UpdateTypeMap(op->buffer_var.get(), op->value.dtype());
    return StmtExprMutator::VisitStmt_(op);
  }
  PrimExpr VisitExpr_(const CallNode* op) final {
    if (op->is_intrinsic(intrinsic::tvm_access_ptr)) {
      DataType dtype = op->args[0].dtype();
      const VarNode* buffer = op->args[1].as<VarNode>();
      UpdateTypeMap(buffer, dtype);
    }
    return StmtExprMutator::VisitExpr_(op);
  }

  Stmt VisitStmt_(const AllocateNode* op) final {
    Stmt stmt = StmtExprMutator::VisitStmt_(op);
    op = stmt.as<AllocateNode>();
    const auto& tvec = acc_map_[op->buffer_var.get()];

    if (tvec.size() == 1 &&
        tvec[0].element_of() == op->dtype.element_of() &&
        tvec[0].lanes() % op->dtype.lanes() == 0 &&
        tvec[0].lanes() != op->dtype.lanes()) {
      int factor = tvec[0].lanes() / op->dtype.lanes();
      Array<PrimExpr> extents = op->extents;
      arith::ModularSet me = analyzer_.modular_set(extents[extents.size() - 1]);
      if (me->base % factor == 0 && me->coeff % factor == 0) {
        extents.Set(extents.size() - 1,
                    extents[extents.size() - 1] / make_const(extents[0].dtype(), factor));
        return AllocateNode::make(
            op->buffer_var, tvec[0], extents,
            op->condition, op->body);
      }
    }
    return stmt;
  }

  void UpdateTypeMap(const VarNode* buffer, DataType t) {
    auto& tvec = acc_map_[buffer];
    if (std::find(tvec.begin(), tvec.end(), t) == tvec.end()) {
      tvec.push_back(t);
    }
  }

  // Internal access map
  std::unordered_map<const VarNode*, std::vector<DataType> > acc_map_;
  // internal analyzer
  arith::Analyzer analyzer_;
};

LoweredFunc PointerValueTypeRewrite(LoweredFunc f) {
  auto n = make_object<LoweredFuncNode>(*f.operator->());
  VectorAllocRewriter rewriter;
  n->body = rewriter(n->body);
  for (Var arg : f->args) {
    if (arg.dtype().is_handle()) {
      const auto& tvec = rewriter.acc_map_[arg.get()];
      if (tvec.size() == 1) {
        PrimExpr dtype = make_const(tvec[0], 0);
        n->handle_data_type.Set(arg, dtype);
      } else {
        // always set data type to be non vectorized so
        // load/store can still work via scalarization
        if (tvec.size() != 0 && !n->handle_data_type.count(arg)) {
          PrimExpr dtype = make_const(tvec[0].with_lanes(1), 0);
          n->handle_data_type.Set(arg, dtype);
        }
      }
    }
  }
  return LoweredFunc(n);
}

Stmt StorageRewrite(Stmt stmt) {
  stmt = StoragePlanRewriter().Rewrite(std::move(stmt), true);
  return VectorAllocRewriter()(std::move(stmt));
}
}  // namespace tir
}  // namespace tvm