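/*!
 *  Copyright (c) 2017 by Contributors
 * \file loop_partition.cc
 * \brief Split loops and thread scopes around likely() conditions so that the
 *        conditions can be replaced by true in the main part of the iteration space.
 */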
#include <tvm/ir.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_pass.h>
#include <tvm/arithmetic.h>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include "../arithmetic/int_set_internal.h"
#include "../runtime/thread_storage_scope.h"

namespace tvm {
namespace ir {

using arith::IntSet;
using arith::DeduceBound;
using arith::Intersect;

// a partition means the expr is equal to true in the interval
//
// One entry per likely() condition found by PartitionFinder:
//   expr     - the condition (the argument of the likely() call);
//   interval - the values of the variable being partitioned for which
//              `expr` holds, as deduced by DeduceBound.
// Built positionally as Partition{cond, interval}; keep the member order.
struct Partition {
  // The likely() condition.
  Expr expr;
  // Values of the partitioned variable on which `expr` is true.
  IntSet interval;
};

// Reports whether `expr` references at least one of the variables in `vars`.
// PostOrderVisit cannot be stopped early, so the whole tree is always walked;
// once a match is found the callback simply skips further lookups.
bool ExprUseVars(Expr expr, const std::unordered_set<const Variable*>& vars) {
  bool found = false;
  auto inspect = [&](const NodeRef& node) {
    if (found) return;
    const Variable* candidate = node.as<Variable>();
    if (candidate != nullptr && vars.count(candidate) != 0) {
      found = true;
    }
  };
  PostOrderVisit(expr, inspect);
  return found;
}

41 42 43 44
// Select potential candidate IRs that can be partitioned.
// Rule:
//   - the range should not be const
//   - there exist a condition expression in the scope that use the var
45
class CandidateSelector final : public IRVisitor {
46 47
 public:
  using VarIsUsed = bool;
48 49
  explicit CandidateSelector(bool split_const_loop)
      : split_const_loop_(split_const_loop) {}
50 51

  void Visit_(const For* op) {
52 53
    // partition const loop when sets split_const_loop_
    if (!is_const(op->min) || !is_const(op->extent) || split_const_loop_) {
54 55 56
      const Variable* var = op->loop_var.get();
      record_.insert({var, false});
      IRVisitor::Visit_(op);
57
      if (record_.at(var) && !no_split_) {
58 59 60 61 62 63 64 65 66 67 68 69 70 71
        candidates.insert(op);
      }
      record_.erase(var);
    } else {
      IRVisitor::Visit_(op);
    }
  }

  void Visit_(const AttrStmt* op) {
    if (op->attr_key == attr::thread_extent) {
      const IterVarNode *iv = op->node.as<IterVarNode>();
      CHECK(iv);
      Var var = iv->var;
      runtime::ThreadScope scope = runtime::ThreadScope::make(iv->thread_tag);
72
      if ((scope.rank == 0) && (!is_const(op->value) || split_const_loop_)) {
73 74
        record_.insert({var.get(), false});
        IRVisitor::Visit_(op);
75
        if (record_.at(var.get()) && !no_split_) {
76 77 78 79 80 81 82 83 84
          candidates.insert(op);
        }
        record_.erase(var.get());
        return;
      }
    }
    IRVisitor::Visit_(op);
  }

85 86 87 88 89 90 91 92 93 94
  void Visit_(const Block* op) {
    bool temp = no_split_;
    this->Visit(op->first);
    // erase the no split state of first when visit rest.
    std::swap(temp, no_split_);
    this->Visit(op->rest);
    // restore the no split flag.
    no_split_ = no_split_ || temp;
  }

95 96 97 98 99
  void Visit_(const Call* op) {
    if (op->is_intrinsic(Call::likely)) {
      in_likely_ = true;
      IRVisitor::Visit_(op);
      in_likely_ = false;
100 101 102 103
    } else if (op->is_intrinsic(intrinsic::tvm_thread_allreduce)) {
      // no split if the body contains allreduce.
      no_split_ = true;
      return;
104 105 106 107 108 109 110 111 112 113 114 115 116 117
    } else {
      IRVisitor::Visit_(op);
    }
  }

  void Visit_(const Variable* op) {
    if (in_likely_ && record_.count(op)) {
      record_.at(op) = true;
    }
  }

  std::unordered_set<const Node*> candidates;

 private:
118
  bool in_likely_{false};
119
  bool no_split_{false};
120
  bool split_const_loop_{false};
121 122 123 124
  std::unordered_map<const Variable*, VarIsUsed> record_;
};

// Find valid partition for specific variable
125 126
class PartitionFinder : public IRVisitor {
 public:
127
  explicit PartitionFinder(VarExpr current_var,
128 129 130 131 132 133 134 135 136
    const std::unordered_map<const Variable*, IntSet>& hint_map,
    const std::unordered_map<const Variable*, IntSet>& relax_map)
      : current_var_(current_var), hint_map_(hint_map),  relax_map_(relax_map) {
        for (const auto& kv : hint_map) {
          out_vars_.insert(kv.first);
        }
        for (const auto& kv : relax_map) {
          out_vars_.insert(kv.first);
        }
137 138 139 140 141
      }

  void Visit_(const For* op) {
    if (ExprUseVars(op->min, out_vars_) || ExprUseVars(op->extent, out_vars_)) return;

142 143 144
    const Variable* var = op->loop_var.get();
    hint_map_.insert({var, IntSet::interval(op->min, op->min + op->extent - 1)});
    relax_map_.insert({var, IntSet::interval(op->min, op->min + op->extent - 1)});
145
    IRVisitor::Visit_(op);
146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164
    relax_map_.erase(var);
    hint_map_.erase(var);
  }

  void Visit_(const AttrStmt* op) {
    // handle thread_axis
    if (op->attr_key == attr::thread_extent) {
      const IterVarNode* thread_axis = op->node.as<IterVarNode>();
      CHECK(thread_axis);
      const Variable* var = thread_axis->var.get();
      IntSet dom = IntSet::range(Range(make_zero(op->value.type()), op->value));
      hint_map_.insert({var, dom});
      relax_map_.insert({var, dom});
      IRVisitor::Visit_(op);
      relax_map_.erase(var);
      hint_map_.erase(var);
    } else {
      IRVisitor::Visit_(op);
    }
165 166
  }

167 168 169 170 171 172 173
  void Visit_(const Call* op) {
    if (op->is_intrinsic(Call::likely)) {
      Expr cond = op->args[0];
      if (ExprUseVars(cond,
          std::unordered_set<const Variable*>({current_var_.get()}))) {
        IntSet interval =
          DeduceBound(current_var_, cond, hint_map_, relax_map_);
174 175 176
        if (!interval.is_nothing()) {
          partitions[cond.get()] = Partition{cond, interval};
        }
177
      }
178 179 180 181 182 183 184 185
    } else {
      IRVisitor::Visit_(op);
    }
  }

  std::unordered_map<const Node*, Partition> partitions;

 private:
186
  VarExpr current_var_;
187 188 189 190 191
  std::unordered_set<const Variable*> out_vars_;
  std::unordered_map<const Variable*, IntSet> hint_map_;
  std::unordered_map<const Variable*, IntSet> relax_map_;
};

192 193
// Eliminate the condition expressions by partitions
class ConditionEliminator : public IRMutator {
194
 public:
195
  explicit ConditionEliminator(const std::unordered_map<const Node*, Partition>& ps)
196 197
    : ps_(ps) {}

198 199 200
  using IRMutator::Mutate;
  Expr Mutate(Expr e) final {
    if (ps_.count(e.get())) return Mutate(const_true());
201 202 203 204 205 206 207
    return IRMutator::Mutate(e);
  }

 private:
  const std::unordered_map<const Node*, Partition>& ps_;
};

208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239

// Insert the partition branch at the innermost thread scope
class ThreadPartitionInserter : public IRMutator {
 public:
  explicit ThreadPartitionInserter(const std::unordered_map<const Node*, Partition>& ps,
    Expr cond) : ps_(ps), cond_(cond), innermost_thread_scope_(false) {}

  Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
    if (op->attr_key == attr::thread_extent) {
      innermost_thread_scope_ = true;
      Stmt stmt = IRMutator::Mutate_(op, s);
      // add branch code inside the innermost thread scope
      if (innermost_thread_scope_) {
        Stmt simplified_body = ConditionEliminator(ps_).Mutate(op->body);
        Stmt body = IfThenElse::make(cond_, simplified_body, op->body);
        Expr value = this->Mutate(op->value);
        stmt = AttrStmt::make(op->node, op->attr_key, value, body);
      }
      innermost_thread_scope_ = false;
      return stmt;
    } else {
      return IRMutator::Mutate_(op, s);
    }
  }

 private:
  const std::unordered_map<const Node*, Partition>& ps_;
  Expr cond_;
  bool innermost_thread_scope_;
};

// Try to do partition at the candidate IRs
240 241
class LoopPartitioner : public IRMutator {
 public:
242 243
  explicit LoopPartitioner(std::unordered_set<const Node*> candidates)
    : candidates_(candidates) {}
244 245

  Stmt Mutate_(const For* op, const Stmt& stmt) {
246 247 248
    if (candidates_.count(op)) {
      Stmt s = TryPartition(op, stmt, op->loop_var,
          op->min, op->min + op->extent - 1, op->body, false);
249 250
      if (s.defined()) return s;
    }
251 252 253 254

    // normal path when loop parittion fails
    // normal loop variable can be put into hint map.
    hint_map_.insert({op->loop_var.get(),
255 256
      IntSet::interval(op->min, op->min + op->extent - 1)});
    Stmt res = IRMutator::Mutate_(op, stmt);
257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288
    hint_map_.erase(op->loop_var.get());
    return res;
  }

  Stmt Mutate_(const AttrStmt* op, const Stmt& stmt) {
    if (op->attr_key != attr::thread_extent) {
      return IRMutator::Mutate_(op, stmt);
    }

    const IterVarNode *iv = op->node.as<IterVarNode>();
    CHECK(iv);
    Var var = iv->var;
    if (candidates_.count(op)) {
      Stmt s = TryPartition(op, stmt, var, 0, op->value - 1, op->body, true);
      if (s.defined()) return s;
    }

    // normal path when loop parittion fails.
    runtime::ThreadScope scope = runtime::ThreadScope::make(iv->thread_tag);
    Stmt res;
    if (scope.rank == 1) {
      // threadIdx should be put into relax map, in case of divergence.
      relax_map_.insert({var.get(),
        IntSet::interval(make_zero(var.type()), op->value - 1)});
      res = IRMutator::Mutate_(op, stmt);
      relax_map_.erase(var.get());
    } else {
      hint_map_.insert({var.get(),
        IntSet::interval(make_zero(var.type()), op->value - 1)});
      res = IRMutator::Mutate_(op, stmt);
      hint_map_.erase(var.get());
    }
289 290 291 292
    return res;
  }

 private:
293 294 295
  Stmt TryPartition(const Node* op, const Stmt& stmt, VarExpr var,
      Expr min, Expr max, Stmt body, bool partition_thread_scope);
  inline Stmt MakeFor(const Node* op, Expr extent, Stmt body);
296

297 298 299 300
  /* Candidate IRs that may be partitioned potentially */
  std::unordered_set<const Node*> candidates_;
  std::unordered_map<const Variable*, IntSet> hint_map_;
  std::unordered_map<const Variable*, IntSet> relax_map_;
301 302
};

303 304 305 306 307 308 309
Stmt LoopPartitioner::TryPartition(const Node* node,
                                   const Stmt& stmt,
                                   VarExpr var,
                                   Expr min,
                                   Expr max,
                                   Stmt body,
                                   bool partition_thread_scope) {
310 311
  PartitionFinder finder(var, hint_map_, relax_map_);
  finder.Visit(body);
312 313 314 315 316 317 318 319 320 321 322
  const auto& partitions = finder.partitions;
  if (partitions.empty()) return Stmt();

  Array<IntSet> sets;
  // merge partitions (take their intersect)
  for (const auto& kv : partitions) {
    sets.push_back(kv.second.interval);
  }
  IntSet true_itrv  = Intersect(sets);

  Expr body_begin;
323
  Stmt pre_stmt;
324 325 326
  if (true_itrv.as<arith::IntervalSet>()->i.has_lower_bound()) {
    body_begin = true_itrv.min();
    if (!can_prove(body_begin == min)) {
327 328 329
      Expr cond = (body_begin - min >= 0);
      if (!can_prove(cond)) {
        LOG(WARNING) << "Cannot prove: " << cond
330 331 332 333
                     << ", when generating the pre doubt loop";
        body_begin = Max::make(body_begin, min);
      }
      // [min, body_begin)
334 335 336 337
      if (!partition_thread_scope) {
        Stmt pre_body = Substitute(body, {{Var{var}, var + min}});
        pre_stmt = MakeFor(node, body_begin - min, pre_body);
      }
338 339 340 341 342 343
    }
  } else {
    body_begin = min;
  }

  Expr post_doubt_begin;
344
  Stmt post_stmt;
345 346 347
  if (true_itrv.as<arith::IntervalSet>()->i.has_upper_bound()) {
    post_doubt_begin = true_itrv.max() + 1;
    if (!can_prove(true_itrv.max() == max)) {
348 349
      // require the extent to be non-negative
      Expr cond = (max - post_doubt_begin + 1 >= 0);
350 351
      if (!can_prove(cond)) {
        LOG(WARNING) << "Cannot prove: " << cond
352 353 354 355
                     << ", when generating the post doubt loop";
        post_doubt_begin = Min::make(post_doubt_begin, max);
      }
      // [post_doubt_begin, max]
356 357 358 359
      if (!partition_thread_scope) {
        Stmt post_body = Substitute(body, {{Var{var}, var + post_doubt_begin}});
        post_stmt = MakeFor(node, max - post_doubt_begin + 1, post_body);
      }
360 361 362 363 364
    }
  } else {
    post_doubt_begin = max + 1;
  }

365 366 367 368 369 370 371 372 373 374 375 376 377
  Stmt s;
  if (!partition_thread_scope) {
    // [body_begin, post_doubt_begin)
    Stmt simplified_body = ConditionEliminator(partitions).Mutate(body);
    Stmt new_body = Substitute(simplified_body, {{Var{var}, var + body_begin}});
    s = MakeFor(node, post_doubt_begin - body_begin, new_body);
    if (pre_stmt.defined())  s = Block::make(pre_stmt, s);
    if (post_stmt.defined()) s = Block::make(s, post_stmt);
  } else {
    Expr cond = const_true();
    if (!can_prove(body_begin == min)) cond = cond && (var >= body_begin);
    if (!can_prove(post_doubt_begin == (max + 1))) cond = cond && (var < post_doubt_begin);
    s = ThreadPartitionInserter(partitions, cond).Mutate(stmt);
378
  }
379 380 381
  s = ConvertSSA(s);
  return s;
}
382

383 384 385 386 387
inline Stmt LoopPartitioner::MakeFor(const Node *node, Expr extent, Stmt body) {
  const For *for_node = static_cast<const For*>(node);
  CHECK(for_node);
  return For::make(for_node->loop_var, 0, extent,
    for_node->for_type, for_node->device_api, body);
388 389
}

390 391 392 393 394 395 396 397 398 399 400 401 402 403
class RemoveLikelyTags : public IRMutator {
 public:
  using IRMutator::Mutate;

  Expr Mutate_(const Call *op, const Expr& e) {
    if (op->is_intrinsic(Call::likely)) {
      CHECK_EQ(op->args.size(), 1);
      return IRMutator::Mutate(op->args[0]);
    } else {
      return IRMutator::Mutate_(op, e);
    }
  }
};

404 405
Stmt LoopPartition(Stmt stmt, bool split_const_loop) {
  CandidateSelector selector(split_const_loop);
406 407 408
  selector.Visit(stmt);
  stmt = LoopPartitioner(selector.candidates).Mutate(stmt);
  stmt = RemoveLikelyTags().Mutate(stmt);
409 410 411 412 413
  return stmt;
}

}  // namespace ir
}  // namespace tvm