/*! * Copyright (c) 2018 by Contributors * * Lower warp memory to use local memory * and shuffle intrinsics. * * \file lower_warp_memory.cc */ // Thanks to Andrew Adams and Vinod Grover for // explaining the concept of warp shuffle. #include <tvm/ir.h> #include <tvm/ir_mutator.h> #include <tvm/ir_visitor.h> #include <tvm/ir_pass.h> #include <unordered_set> #include "./ir_util.h" #include "../arithmetic/compute_expr.h" #include "../runtime/thread_storage_scope.h" namespace tvm { namespace ir { // Rewrite Rule // // There is no special warp memory in most GPUs. // Instead, we can stripe the data into threads // and store the data into local memory. // // This requires us to do the following rewriting: // - Rewrite allocation to use local memory. // - Rewrite store of warp memory to local store. // - Rewrite load of waro memory to local plus a shuffle. // // Define a generic shuffle instrinsic warp_shuffle(data, warp_index). // We can use the following rewriting rule // // Before rewrite, // // alloc warp warp_mem[n * warp_size * m] // store warp_mem[m * warp_index + (warp_size * m) * y + x] // load warp_mem[m * z + (warp_size * m) * y + x] // subject to x \in [0, m), y \in [0, n) // // After rewrite: // // alloc local local_mem[n * m] // store warp_mem[m * y + x] // warp_shuffle(load warp_mem[m * y + x], z) // subject to (m * y + x) is invariant to warp_index // Algorithm // // To implement this rewrite rule, we can do the follow step: // For each warp memory alloc // - Use linear pattern detector on load index to find m // - Deduce n given warp_size and alloc size // - Now that we have m, n, warp_size, we can proceed with the rewrite // Visitor to find m in pattern // store warp_mem[m * warp_index + (warp_size * m) * y + x] class WarpStoreCoeffFinder : private IRVisitor { public: WarpStoreCoeffFinder(const Variable* buffer, Var warp_index) : buffer_(buffer), warp_index_(warp_index) { } // find the warp co-efficient in the statement given the warp size int Find(const Stmt& stmt) { this->Visit(stmt); return warp_coeff_; } private: /// Visitor implementation void Visit_(const Store *op) final { if (op->buffer_var.get() == buffer_) { if (op->value.type().lanes() == 1) { UpdatePattern(op->index); } else { Expr base; CHECK(GetRamp1Base(op->index, op->value.type().lanes(), &base)) << "LowerWarpMemory failed due to store index=" << op->index << ", can only handle continuous store"; UpdatePattern(base); } } else { IRVisitor::Visit_(op); } } void UpdatePattern(const Expr& index) { Array<Expr> m = arith::DetectLinearEquation(index, {warp_index_}); CHECK_EQ(m.size(), 2U) << "LowerWarpMemory failed due to store index=" << index; int coeff; Expr mcoeff = ir::Simplify(m[0]); CHECK(arith::GetConstInt(mcoeff, &coeff) && coeff > 0) << "LowerWarpMemory failed due to store index=" << index << ", require positive constant coefficient on warp index " << warp_index_ << " but get " << mcoeff; if (warp_coeff_ != 0) { CHECK_EQ(warp_coeff_, coeff) << "LowerWarpMemory failed due to two different store coefficient to warp index"; } else { warp_coeff_ = coeff; } } // The buffer variable const Variable* buffer_; // the warp index Var warp_index_; // the coefficient int warp_coeff_{0}; }; // Visitor to find the warp index class WarpIndexFinder : private IRVisitor { public: explicit WarpIndexFinder(int warp_size) : warp_size_(warp_size) { } // find the warp co-efficient in the statement given the warp size IterVar Find(const Stmt& stmt) { this->Visit(stmt); CHECK(warp_index_.defined()) << "Cannot find warp index(threadIdx.x) within the scope of warp memory"; return warp_index_; } private: /// Visitor implementation void Visit_(const AttrStmt *op) final { if (op->attr_key == attr::thread_extent) { IterVar iv(op->node.node_); if (iv->thread_tag == "threadIdx.x") { int value; CHECK(arith::GetConstInt(op->value, &value) && value == warp_size_) << "Expect threadIdx.x 's size to be equal to warp size(" << warp_size_ << ")" << " to enable warp memory" << " but get " << op->value << " instead"; if (warp_index_.defined()) { CHECK(warp_index_.same_as(iv)) << "Find two instance of " << warp_index_->thread_tag << " in the same kernel. " << "Please create it using thread_axis once and reuse the axis " << "across multiple binds in the same kernel"; } else { warp_index_ = iv; } } } IRVisitor::Visit_(op); } // warp size int warp_size_{0}; // the warp index IterVar warp_index_{nullptr}; }; // Mutator to change the read pattern class WarpAccessRewriter : protected IRMutator { public: explicit WarpAccessRewriter(int warp_size) : warp_size_(warp_size) {} // Rewrite the allocate statement which transforms // warp memory to local memory. Stmt Rewrite(const Allocate* op, const Stmt& stmt) { buffer_ = op->buffer_var.get(); int alloc_size = op->constant_allocation_size(); CHECK_GT(alloc_size, 0) << "warp memory only support constant alloc size"; alloc_size *= op->type.lanes(); warp_index_ = WarpIndexFinder(warp_size_).Find(op->body)->var; warp_coeff_ = WarpStoreCoeffFinder( buffer_, warp_index_).Find(op->body); CHECK_EQ(alloc_size % (warp_size_ * warp_coeff_), 0) << "Warp memory must be multiple of warp size"; warp_group_ = alloc_size / (warp_size_ * warp_coeff_); return Allocate::make( op->buffer_var, op->type, {make_const(Int(32), alloc_size / warp_size_)}, op->condition, this->Mutate(op->body)); } protected: Expr Mutate_(const Variable* op, const Expr& expr) { CHECK(op != buffer_) << "Cannot access address of warp memory directly"; return IRMutator::Mutate_(op, expr); } Stmt Mutate_(const Store* op, const Stmt& stmt) { if (op->buffer_var.get() == buffer_) { Expr local_index, group; std::tie(local_index, group) = SplitIndexByGroup(op->index); return Store::make(op->buffer_var, op->value, local_index, op->predicate); } else { return IRMutator::Mutate_(op, stmt); } } Expr Mutate_(const Load* op, const Expr& expr) { if (op->buffer_var.get() == buffer_) { Expr local_index, group; std::tie(local_index, group) = SplitIndexByGroup(op->index); // invariance: local index must do not contain warp id CHECK(!ExprUseVar(local_index, {warp_index_.get()})) << "LowerWarpMemory failed to rewrite load to shuffle for index " << op->index << " local_index=" << local_index; Expr load_value = Load::make( op->type, op->buffer_var, local_index, op->predicate); return Call::make(load_value.type(), intrinsic::tvm_warp_shuffle, {load_value, group}, Call::Intrinsic); } else { return IRMutator::Mutate_(op, expr); } } // Split the index to the two component // <local_index, source_index> // local index is the index in the local // source index is the corresponding source index // in this access pattern. std::pair<Expr, Expr> SplitIndexByGroup(const Expr& index) { if (index.type().lanes() != 1) { Expr base, local_index, group; CHECK(GetRamp1Base(index, index.type().lanes(), &base)); std::tie(local_index, group) = SplitIndexByGroup(base); local_index = Ramp::make(local_index, make_const(local_index.type(), 1), index.type().lanes()); return std::make_pair(local_index, group); } Expr m = make_const(index.type(), warp_coeff_); Range rng = Range::make_by_min_extent( make_zero(index.type()), make_const(index.type(), warp_size_)); Map<Var, Range> vrange({{warp_index_, rng}}); // simple case, warp index is on the highest. if (warp_group_ == 1) { Expr x = Simplify(index % m, vrange); Expr z = Simplify(index / m, vrange); return std::make_pair(x, z); } else { Expr x = Simplify(index % m, vrange); Expr y = index / make_const(index.type(), warp_coeff_ * warp_size_); y = y * m + x; Expr z = index % make_const(index.type(), warp_coeff_ * warp_size_) / m; return std::make_pair(Simplify(y, vrange), Simplify(z, vrange)); } } private: // the warp size int warp_size_{0}; // The buffer variable const Variable* buffer_; // Warp index Var warp_index_; // the coefficient m int warp_coeff_{0}; // the coefficient n int warp_group_{0}; }; // Mutator to change the read pattern class WarpMemoryRewriter : private IRMutator { public: explicit WarpMemoryRewriter(int warp_size) : warp_size_(warp_size) { } Stmt Rewrite(Stmt stmt) { if (warp_size_ == 1) return stmt; stmt = this->Mutate(stmt); stmt = CanonicalSimplify(stmt); return stmt; } private: Stmt Mutate_(const Allocate* op, const Stmt& stmt) { if (warp_buffer_.count(op->buffer_var.get())) { WarpAccessRewriter rewriter(warp_size_); return rewriter.Rewrite(op, stmt); } else { return IRMutator::Mutate_(op, stmt); } } Stmt Mutate_(const AttrStmt* op, const Stmt& stmt) { using runtime::StorageScope; if (op->attr_key == attr::storage_scope) { const Variable* buf = op->node.as<Variable>(); StorageScope scope = StorageScope::make(op->value.as<StringImm>()->value); if (scope.rank == runtime::StorageRank::kWarp) { warp_buffer_.insert(buf); Stmt ret = IRMutator::Mutate_(op, stmt); op = ret.as<AttrStmt>(); return AttrStmt::make( op->node, op->attr_key, StringImm::make("local"), op->body); } } return IRMutator::Mutate_(op, stmt); } int warp_size_{0}; std::unordered_set<const Variable*> warp_buffer_; }; LoweredFunc LowerWarpMemory(LoweredFunc f, int warp_size) { CHECK_EQ(f->func_type, kDeviceFunc); auto n = std::make_shared<LoweredFuncNode>(*f.operator->()); n->body = WarpMemoryRewriter(warp_size).Rewrite(n->body); return LoweredFunc(n); } } // namespace ir } // namespace tvm