/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 *  Copyright (c) 2017 by Contributors
 * \file loop_partition.cc
 */
#include <tvm/ir.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_pass.h>
#include <tvm/arithmetic.h>
#include <unordered_map>
#include <unordered_set>
#include "../arithmetic/int_set.h"
#include "../runtime/thread_storage_scope.h"

namespace tvm {
namespace ir {

using arith::IntSet;
using arith::DeduceBound;
using arith::Intersect;

// Key of a partition entry: (condition node, value the condition takes).
using PartitionKey = std::pair<const Node*, bool>;

// Hash functor for PartitionKey so that it can be used as an unordered_map key.
// The result is the XOR of the standard hashes of both components.
struct PartitionKeyHash {
  std::size_t operator()(PartitionKey const& k) const noexcept {
    const std::size_t node_hash = std::hash<const Node*>{}(k.first);
    const std::size_t value_hash = std::hash<bool>{}(k.second);
    return node_hash ^ value_hash;
  }
};

// Each mapping (cond, cond_value) -> interval represents the fact that
// condition cond is proven to have value cond_value (true or false) in interval.
using Partition = std::unordered_map<PartitionKey, IntSet, PartitionKeyHash>; bool ExprUseVars(Expr expr, const std::unordered_set<const Variable*>& vars) { bool success = false; PostOrderVisit(expr, [&vars, &success](const NodeRef& node) { if (const Variable* v = node.as<Variable>()) { if (vars.count(v)) { success = true; return; } } }); return success; } // Select potential candidate IRs that can be partitioned. // Rule: // - the range should not be const // - there exist a condition expression in the scope that use the var class CandidateSelector final : public IRVisitor { public: using VarIsUsed = bool; explicit CandidateSelector(bool split_const_loop) : split_const_loop_(split_const_loop) {} void Visit_(const For* op) { // partition const loop when sets split_const_loop_ if (!is_const(op->min) || !is_const(op->extent) || split_const_loop_) { const Variable* var = op->loop_var.get(); record_.insert({var, false}); IRVisitor::Visit_(op); if (record_.at(var) && !no_split_) { candidates.insert(op); } record_.erase(var); } else { IRVisitor::Visit_(op); } } void Visit_(const AttrStmt* op) { if (op->attr_key == attr::thread_extent) { const IterVarNode *iv = op->node.as<IterVarNode>(); CHECK(iv); Var var = iv->var; runtime::ThreadScope scope = runtime::ThreadScope::make(iv->thread_tag); if ((scope.rank == 0) && (!is_const(op->value) || split_const_loop_)) { record_.insert({var.get(), false}); IRVisitor::Visit_(op); if (record_.at(var.get()) && !no_split_) { candidates.insert(op); } record_.erase(var.get()); return; } } IRVisitor::Visit_(op); } void Visit_(const Block* op) { bool temp = no_split_; this->Visit(op->first); // erase the no split state of first when visit rest. std::swap(temp, no_split_); this->Visit(op->rest); // restore the no split flag. 
no_split_ = no_split_ || temp; } void Visit_(const Call* op) { if (op->is_intrinsic(Call::likely)) { in_likely_ = true; IRVisitor::Visit_(op); in_likely_ = false; } else if (op->is_intrinsic(intrinsic::tvm_thread_allreduce)) { // no split if the body contains allreduce. no_split_ = true; return; } else { IRVisitor::Visit_(op); } } void Visit_(const Variable* op) { if (in_likely_ && record_.count(op)) { record_.at(op) = true; } } std::unordered_set<const Node*> candidates; private: bool in_likely_{false}; bool no_split_{false}; bool split_const_loop_{false}; std::unordered_map<const Variable*, VarIsUsed> record_; }; // Populate partitions data structure, i.e., for a specific variable, // find an interval in which each condition // (currently, "likely" conditions) has fixed true or false value class PartitionFinder : public IRVisitor { public: explicit PartitionFinder(VarExpr current_var, const std::unordered_map<const Variable*, IntSet>& hint_map, const std::unordered_map<const Variable*, IntSet>& relax_map) : current_var_(current_var), hint_map_(hint_map), relax_map_(relax_map) { for (const auto& kv : hint_map) { out_vars_.insert(kv.first); } for (const auto& kv : relax_map) { out_vars_.insert(kv.first); } } void Visit_(const For* op) { if (ExprUseVars(op->min, out_vars_) || ExprUseVars(op->extent, out_vars_)) return; const Variable* var = op->loop_var.get(); hint_map_.insert({var, IntSet::interval(op->min, op->min + op->extent - 1)}); relax_map_.insert({var, IntSet::interval(op->min, op->min + op->extent - 1)}); IRVisitor::Visit_(op); relax_map_.erase(var); hint_map_.erase(var); } void Visit_(const AttrStmt* op) { // handle thread_axis if (op->attr_key == attr::thread_extent) { const IterVarNode* thread_axis = op->node.as<IterVarNode>(); CHECK(thread_axis); const Variable* var = thread_axis->var.get(); IntSet dom = IntSet::range(Range(make_zero(op->value.type()), op->value)); hint_map_.insert({var, dom}); relax_map_.insert({var, dom}); IRVisitor::Visit_(op); 
relax_map_.erase(var); hint_map_.erase(var); } else { IRVisitor::Visit_(op); } } void Visit_(const Call* op) { if (op->is_intrinsic(Call::likely)) { Expr cond = op->args[0]; if (ExprUseVars(cond, std::unordered_set<const Variable*>({current_var_.get()}))) { // For cond, find out the interval, if exists, in which we can prove that cond is // true. Also find the interval, if exists, in which we can prove that cond is // false. IntSet interval = DeduceBound(current_var_, cond, hint_map_, relax_map_); if (!interval.is_nothing()) { // cond is true within interval partitions[{cond.get(), true}] = interval; } Expr inverse_cond = InverseCond(cond); if (inverse_cond.defined()) { IntSet interval = DeduceBound(current_var_, inverse_cond, hint_map_, relax_map_); if (!interval.is_nothing()) { // cond is false within interval partitions[{cond.get(), false}] = interval; } } } } else { IRVisitor::Visit_(op); } } Partition partitions; private: Expr InverseCond(const Expr& cond) { // We expect most condition not to be of EQ or NE form. // Currently we do not handle inversing EQ or NE. 
Expr inverse_cond; if (const LT* op = cond.as<LT>()) { // a < b -> a >= b inverse_cond = GE::make(op->a, op->b); } else if (const GT* op = cond.as<GT>()) { // a > b -> a <= b inverse_cond = LE::make(op->a, op->b); } else if (const LE* op = cond.as<LE>()) { // a <= b -> a > b inverse_cond = GT::make(op->a, op->b); } else if (const GE* op = cond.as<GE>()) { // a >= b -> a < b inverse_cond = LT::make(op->a, op->b); } return inverse_cond; } VarExpr current_var_; std::unordered_set<const Variable*> out_vars_; std::unordered_map<const Variable*, IntSet> hint_map_; std::unordered_map<const Variable*, IntSet> relax_map_; }; // Replace the set of conditions given by ps with cond_value (true or false) class ConditionEliminator : public IRMutator { public: explicit ConditionEliminator(const std::unordered_set<const Node*>& ps, bool cond_value = true) : ps_(ps), cond_value_(cond_value) {} using IRMutator::Mutate; Expr Mutate(Expr e) final { if (ps_.find(e.get()) != ps_.end()) { return Mutate(cond_value_ ? 
const_true() : const_false()); } return IRMutator::Mutate(e); } private: std::unordered_set<const Node*> ps_; bool cond_value_; }; // Insert the partition branch at the innermost thread scope class ThreadPartitionInserter : public IRMutator { public: explicit ThreadPartitionInserter(const std::unordered_set<const Node*>& ps, Expr cond) : ps_(ps), cond_(cond), innermost_thread_scope_(false) {} Stmt Mutate_(const AttrStmt* op, const Stmt& s) final { if (op->attr_key == attr::thread_extent) { innermost_thread_scope_ = true; Stmt stmt = IRMutator::Mutate_(op, s); // add branch code inside the innermost thread scope if (innermost_thread_scope_) { Stmt simplified_body = ConditionEliminator(ps_).Mutate(op->body); Stmt body = IfThenElse::make(cond_, simplified_body, op->body); Expr value = this->Mutate(op->value); stmt = AttrStmt::make(op->node, op->attr_key, value, body); } innermost_thread_scope_ = false; return stmt; } else { return IRMutator::Mutate_(op, s); } } private: const std::unordered_set<const Node*>& ps_; Expr cond_; bool innermost_thread_scope_; }; // Try to partition range of iteration variables in order to remove (some) // likely conditions class LoopPartitioner : public IRMutator { public: explicit LoopPartitioner(bool split_const_loop) : selector(CandidateSelector(split_const_loop)) {} Stmt VisitAndMutate(const Stmt& stmt) { selector.Visit(stmt); return Mutate(stmt); } Stmt Mutate_(const For* op, const Stmt& stmt) { if (selector.candidates.count(op)) { Stmt s = TryPartition(op, stmt, op->loop_var, op->min, op->min + op->extent - 1, op->body, false); if (s.defined()) return s; } // normal path when loop partition fails // normal loop variable can be put into hint map. 
hint_map_.insert({op->loop_var.get(), IntSet::interval(op->min, op->min + op->extent - 1)}); Stmt res = IRMutator::Mutate_(op, stmt); hint_map_.erase(op->loop_var.get()); return res; } Stmt Mutate_(const AttrStmt* op, const Stmt& stmt) { if (op->attr_key != attr::thread_extent) { return IRMutator::Mutate_(op, stmt); } const IterVarNode *iv = op->node.as<IterVarNode>(); CHECK(iv); Var var = iv->var; if (selector.candidates.count(op)) { Stmt s = TryPartition(op, stmt, var, 0, op->value - 1, op->body, true); if (s.defined()) return s; } // normal path when loop parittion fails. runtime::ThreadScope scope = runtime::ThreadScope::make(iv->thread_tag); Stmt res; if (scope.rank == 1) { // threadIdx should be put into relax map, in case of divergence. relax_map_.insert({var.get(), IntSet::interval(make_zero(var.type()), op->value - 1)}); res = IRMutator::Mutate_(op, stmt); relax_map_.erase(var.get()); } else { hint_map_.insert({var.get(), IntSet::interval(make_zero(var.type()), op->value - 1)}); res = IRMutator::Mutate_(op, stmt); hint_map_.erase(var.get()); } return res; } private: Stmt TryPartition(const Node* op, const Stmt& stmt, VarExpr var, Expr min, Expr max, Stmt body, bool partition_thread_scope); std::pair<IntSet, std::unordered_set<const Node*>> GetIntervalAndCondset(const Partition &partitions, const arith::IntervalSet &for_interval, bool cond_value); inline Stmt MakeFor(const Node* op, Expr extent, Stmt body); /* Candidate IRs that may be partitioned potentially */ std::unordered_map<const Variable*, IntSet> hint_map_; std::unordered_map<const Variable*, IntSet> relax_map_; arith::Analyzer analyzer_; CandidateSelector selector; }; // Returns an interval (in the first component) in which all the conditions // given in the second component provably have value given by cond_value std::pair<IntSet, std::unordered_set<const Node*>> LoopPartitioner::GetIntervalAndCondset(const Partition &partitions, const arith::IntervalSet &for_interval, bool cond_value) { 
Array<IntSet> sets; std::unordered_set<const Node*> cond_set; for (const auto &kv : partitions) { if (kv.first.second == cond_value) { arith::IntervalSet interval = Downcast<arith::IntervalSet>(kv.second); arith::IntervalSet intersection = arith::Intersect( &analyzer_, interval, for_interval); if (!intersection->IsEmpty()) { sets.push_back(kv.second); cond_set.insert(kv.first.first); } } } IntSet interval = sets.empty() ? IntSet::nothing() : Intersect(sets); return std::make_pair(interval, cond_set); } Stmt AppendStmts(const Stmt& a, const Stmt& b) { if (!a.defined()) { return b; } else if (!b.defined()) { return a; } else { return Block::make(a, b); } } /* * Tries to recursively partition the range of the variable (given by var) of * the for loop (given by node and stmt) into a * number of disjoint ranges such that in some ranges one or more predicates * in the loopnest are provably true or false in each range. For example, given the * following loop to partition: * for (i = 0; i < 4; i++) * for (j = 0; j < 10; j++) * if (likely(i*10 + j < 36)) * A[10*i+j] = B[10*i+j] * * We first partition range of i, i.e., [0,3] into subranges [0,2] and [3,3] because the * likely condition is always true for the first subrange but not always true for the * second subrange. Therefore, we'll have * for (i = 0; i < 3; i++) * for (j = 0; j < 10; j++) * if (likely(1)) * A[10*i+j] = B[10*i+j] * for (i = 0; i < 1; i++) * for (j = 0; j < 10; j++) * if (likely((i+3)*10 + j < 36)) * A[10*(i+3)+j] = B[10*(i+3)+j] * Which is simplified as: * for (i = 0; i < 3; i++) * for (j = 0; j < 10; j++) * A[10*i+j] = B[10*i+j] * for (j = 0; j < 10; j++) // loopnest 1 * if (likely(j < 6)) * A[30+j] = B[30+j] * Now, we recursively partition j in loopnest 1 into subranges [0,5] and [6,9] where the * condition is true for the first subrange and now always true for the second subrange. 
* for (j = 0; j < 6; j++) * if (likely(1)) * A[30+j] = B[30+j] * for (j = 0; j < 4; j++) // loop 2 * if (likely(j < 0)) * A[36+j] = B[36+j] * Finally we recursively partition loop 2 above into subrange [0,3] where the * condition is false and empty interval where the condition is not false, * therefore we generate * for (j = 0; j < 4; j++) * if (likely(0)) * A[36+j] = B[36+j] * which will eventually be simplified to empty code. And because only one loop was generated * from loop 2 we stop recursing. */ Stmt LoopPartitioner::TryPartition(const Node* node, const Stmt& stmt, VarExpr var, Expr min, Expr max, Stmt body, bool partition_thread_scope) { using namespace arith; // include hint of var. hint_map_.insert({var.get(), IntSet::interval(min, max)}); PartitionFinder finder(var, hint_map_, relax_map_); finder.Visit(body); hint_map_.erase(var.get()); if (finder.partitions.empty()) return Stmt(); arith::IntervalSet for_interval(min, max); bool cond_value; IntSet middle_interval; std::unordered_set<const Node*> cond_set; // find an interval in which all conditions on var are true std::tie(middle_interval, cond_set) = GetIntervalAndCondset(finder.partitions, for_interval, true); if (middle_interval.is_nothing()) { // if such interval doesn't exist, find an interval in which all // conditions on var are false std::tie(middle_interval, cond_set) = GetIntervalAndCondset(finder.partitions, for_interval, false); if (middle_interval.is_nothing()) // we couldn't find an interval in which the condintions are provably true or false // Therefore, we can't partition the loop based on those conds return Stmt(); cond_value = false; } else { cond_value = true; } IntervalSet middle_interval_i = Downcast<IntervalSet>(middle_interval); // middle_interval is the subrange of the loop variable range for which a // set of conditions are true (or false resp.) // The part of the loop variable range that is before (after resp.) that // subrange is prefixed with pre- (post- resp.) 
// Calculating pre-subrange and generating code for it. // pre-subrange = [min, body_begin) Expr body_begin; Stmt pre_stmt; bool pre_stmt_recurse = true; if (middle_interval_i->HasLowerBound()) { body_begin = ir::Simplify(middle_interval.min()); if (!analyzer_.CanProve(body_begin == min)) { Expr cond = (body_begin - min >= 0); if (!analyzer_.CanProve(cond)) { LOG(WARNING) << "Cannot prove: " << cond << ", when generating the pre doubt loop"; body_begin = Max::make(body_begin, min); // stop recursing on this interval if we can't prove it has non-negative length pre_stmt_recurse = false; } if (!partition_thread_scope) { Stmt pre_body = Substitute(body, {{Var{var}, var + min}}); pre_stmt = MakeFor(node, body_begin - min, pre_body); } } } else { body_begin = min; } // Calculating post-subrange and generating code for it. // post-subrange = [post_doubt_begin, max] Expr post_doubt_begin; Stmt post_stmt; bool post_stmt_recurse = true; if (middle_interval_i->HasUpperBound()) { post_doubt_begin = ir::Simplify(middle_interval.max() + 1); if (!analyzer_.CanProve(middle_interval.max() == max)) { // require the extent to be non-negative Expr cond = (max - post_doubt_begin + 1 >= 0); if (!analyzer_.CanProve(cond)) { LOG(WARNING) << "Cannot prove: " << cond << ", when generating the post doubt loop"; post_doubt_begin = Min::make(post_doubt_begin, max); // stop recursing on this interval if we can't prove it has non-negative length post_stmt_recurse = false; } if (!partition_thread_scope) { Stmt post_body = Substitute(body, {{Var{var}, var + post_doubt_begin}}); post_stmt = MakeFor(node, max - post_doubt_begin + 1, post_body); } } } else { post_doubt_begin = max + 1; } Stmt s; // Generating code for middle subrange if (!partition_thread_scope) { Stmt mid_stmt; if (!analyzer_.CanProve(body_begin >= post_doubt_begin)) { // [body_begin, post_doubt_begin) Stmt simplified_body = ConditionEliminator(cond_set, cond_value).Mutate(body); Stmt new_body = Substitute(simplified_body, 
{{Var{var}, var + body_begin}}); mid_stmt = MakeFor(node, post_doubt_begin - body_begin, new_body); // Recurse for each non-empty subrange only if there are at least // two non-empty subranges if (pre_stmt.defined() || post_stmt.defined()) { mid_stmt = VisitAndMutate(mid_stmt); if (pre_stmt.defined() && pre_stmt_recurse) { pre_stmt = VisitAndMutate(pre_stmt); } if (post_stmt.defined() && post_stmt_recurse) { post_stmt = VisitAndMutate(post_stmt); } } } s = AppendStmts(pre_stmt, mid_stmt); s = AppendStmts(s, post_stmt); } else { Expr cond = const_true(); if (!analyzer_.CanProve(body_begin == min)) cond = cond && (var >= body_begin); if (!analyzer_.CanProve(post_doubt_begin == (max + 1))) cond = cond && (var < post_doubt_begin); s = ThreadPartitionInserter(cond_set, cond).Mutate(stmt); } s = ConvertSSA(s); return s; } inline Stmt LoopPartitioner::MakeFor(const Node *node, Expr extent, Stmt body) { const For *for_node = static_cast<const For*>(node); CHECK(for_node); if (analyzer_.CanProve(extent == make_const(Int(32), 1))) { // If the loop extent is 1, do not create the loop anymore return Substitute(body, {{Var{for_node->loop_var}, make_const(Int(32), 0)}}); } else { return For::make(for_node->loop_var, 0, extent, for_node->for_type, for_node->device_api, body); } } class RemoveLikelyTags : public IRMutator { public: using IRMutator::Mutate; Expr Mutate_(const Call *op, const Expr& e) { if (op->is_intrinsic(Call::likely)) { CHECK_EQ(op->args.size(), 1); return IRMutator::Mutate(op->args[0]); } else { return IRMutator::Mutate_(op, e); } } }; Stmt LoopPartition(Stmt stmt, bool split_const_loop) { stmt = LoopPartitioner(split_const_loop).VisitAndMutate(stmt); stmt = RemoveLikelyTags().Mutate(stmt); return stmt; } } // namespace ir } // namespace tvm