loop_partition.cc 20.7 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 * 
 *   http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

20 21 22 23 24 25 26 27
/*!
 *  Copyright (c) 2017 by Contributors
 * \file loop_partition.cc
 */
#include <tvm/ir.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_pass.h>
28
#include <tvm/arithmetic.h>
29 30 31
#include <unordered_map>
#include <unordered_set>
#include "../arithmetic/int_set_internal.h"
32
#include "../runtime/thread_storage_scope.h"
33 34 35 36 37

namespace tvm {
namespace ir {

using arith::IntSet;
38 39
using arith::DeduceBound;
using arith::Intersect;
40

41 42 43 44 45 46 47
using PartitionKey = std::pair<const Node*, bool>;
struct PartitionKeyHash {
  std::size_t operator()(PartitionKey const& k) const noexcept {
    std::size_t h1 = std::hash<const Node*>{}(k.first);
    std::size_t h2 = std::hash<bool>{}(k.second);
    return h1 ^ h2;
  }
48 49
};

50 51 52 53 54
// Each mapping (cond, cond_value) -> interval represents the fact that
// condition cond is proven to have value cond_value (true or false) in interval.
using Partition = std::unordered_map<PartitionKey, IntSet, PartitionKeyHash>;


55 56 57 58 59 60 61 62 63 64 65 66 67
/*!
 * \brief Check whether an expression references any variable of a given set.
 * \param expr The expression to scan.
 * \param vars The set of variables to look for.
 * \return true if at least one Variable node in expr is contained in vars.
 */
bool ExprUseVars(Expr expr, const std::unordered_set<const Variable*>& vars) {
  bool found = false;
  auto check_node = [&vars, &found](const NodeRef& node) {
    const Variable* candidate = node.as<Variable>();
    if (candidate != nullptr && vars.count(candidate) != 0) {
      found = true;
    }
  };
  // The whole expression is always traversed; the flag only records a hit.
  PostOrderVisit(expr, check_node);
  return found;
}

68 69 70 71
// Select potential candidate IRs that can be partitioned.
// Rule:
//   - the range should not be const
//   - there exist a condition expression in the scope that use the var
72
class CandidateSelector final : public IRVisitor {
73 74
 public:
  using VarIsUsed = bool;
75 76
  explicit CandidateSelector(bool split_const_loop)
      : split_const_loop_(split_const_loop) {}
77 78

  void Visit_(const For* op) {
79 80
    // partition const loop when sets split_const_loop_
    if (!is_const(op->min) || !is_const(op->extent) || split_const_loop_) {
81 82 83
      const Variable* var = op->loop_var.get();
      record_.insert({var, false});
      IRVisitor::Visit_(op);
84
      if (record_.at(var) && !no_split_) {
85 86 87 88 89 90 91 92 93 94 95 96 97 98
        candidates.insert(op);
      }
      record_.erase(var);
    } else {
      IRVisitor::Visit_(op);
    }
  }

  void Visit_(const AttrStmt* op) {
    if (op->attr_key == attr::thread_extent) {
      const IterVarNode *iv = op->node.as<IterVarNode>();
      CHECK(iv);
      Var var = iv->var;
      runtime::ThreadScope scope = runtime::ThreadScope::make(iv->thread_tag);
99
      if ((scope.rank == 0) && (!is_const(op->value) || split_const_loop_)) {
100 101
        record_.insert({var.get(), false});
        IRVisitor::Visit_(op);
102
        if (record_.at(var.get()) && !no_split_) {
103 104 105 106 107 108 109 110 111
          candidates.insert(op);
        }
        record_.erase(var.get());
        return;
      }
    }
    IRVisitor::Visit_(op);
  }

112 113 114 115 116 117 118 119 120 121
  void Visit_(const Block* op) {
    bool temp = no_split_;
    this->Visit(op->first);
    // erase the no split state of first when visit rest.
    std::swap(temp, no_split_);
    this->Visit(op->rest);
    // restore the no split flag.
    no_split_ = no_split_ || temp;
  }

122 123 124 125 126
  void Visit_(const Call* op) {
    if (op->is_intrinsic(Call::likely)) {
      in_likely_ = true;
      IRVisitor::Visit_(op);
      in_likely_ = false;
127 128 129 130
    } else if (op->is_intrinsic(intrinsic::tvm_thread_allreduce)) {
      // no split if the body contains allreduce.
      no_split_ = true;
      return;
131 132 133 134 135 136 137 138 139 140 141 142 143 144
    } else {
      IRVisitor::Visit_(op);
    }
  }

  void Visit_(const Variable* op) {
    if (in_likely_ && record_.count(op)) {
      record_.at(op) = true;
    }
  }

  std::unordered_set<const Node*> candidates;

 private:
145
  bool in_likely_{false};
146
  bool no_split_{false};
147
  bool split_const_loop_{false};
148 149 150
  std::unordered_map<const Variable*, VarIsUsed> record_;
};

151 152 153
// Populate partitions data structure, i.e., for a specific variable,
// find an interval in which each condition
// (currently, "likely" conditions) has fixed true or false value
154 155
class PartitionFinder : public IRVisitor {
 public:
156
  explicit PartitionFinder(VarExpr current_var,
157 158 159 160 161 162 163 164 165
    const std::unordered_map<const Variable*, IntSet>& hint_map,
    const std::unordered_map<const Variable*, IntSet>& relax_map)
      : current_var_(current_var), hint_map_(hint_map),  relax_map_(relax_map) {
        for (const auto& kv : hint_map) {
          out_vars_.insert(kv.first);
        }
        for (const auto& kv : relax_map) {
          out_vars_.insert(kv.first);
        }
166 167 168 169 170
      }

  void Visit_(const For* op) {
    if (ExprUseVars(op->min, out_vars_) || ExprUseVars(op->extent, out_vars_)) return;

171 172 173
    const Variable* var = op->loop_var.get();
    hint_map_.insert({var, IntSet::interval(op->min, op->min + op->extent - 1)});
    relax_map_.insert({var, IntSet::interval(op->min, op->min + op->extent - 1)});
174
    IRVisitor::Visit_(op);
175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193
    relax_map_.erase(var);
    hint_map_.erase(var);
  }

  void Visit_(const AttrStmt* op) {
    // handle thread_axis
    if (op->attr_key == attr::thread_extent) {
      const IterVarNode* thread_axis = op->node.as<IterVarNode>();
      CHECK(thread_axis);
      const Variable* var = thread_axis->var.get();
      IntSet dom = IntSet::range(Range(make_zero(op->value.type()), op->value));
      hint_map_.insert({var, dom});
      relax_map_.insert({var, dom});
      IRVisitor::Visit_(op);
      relax_map_.erase(var);
      hint_map_.erase(var);
    } else {
      IRVisitor::Visit_(op);
    }
194 195
  }

196 197 198 199 200
  void Visit_(const Call* op) {
    if (op->is_intrinsic(Call::likely)) {
      Expr cond = op->args[0];
      if (ExprUseVars(cond,
          std::unordered_set<const Variable*>({current_var_.get()}))) {
201 202 203
        // For cond, find out the interval, if exists, in which we can prove that cond is
        // true. Also find the interval, if exists, in which we can prove that cond is
        // false.
204
        IntSet interval =
205
                DeduceBound(current_var_, cond, hint_map_, relax_map_);
206
        if (!interval.is_nothing()) {
207 208 209 210 211 212 213 214 215 216 217
          // cond is true within interval
          partitions[{cond.get(), true}] = interval;
        }
        Expr inverse_cond = InverseCond(cond);
        if (inverse_cond.defined()) {
          IntSet interval =
                  DeduceBound(current_var_, inverse_cond, hint_map_, relax_map_);
          if (!interval.is_nothing()) {
            // cond is false within interval
            partitions[{cond.get(), false}] = interval;
          }
218
        }
219
      }
220 221 222 223 224
    } else {
      IRVisitor::Visit_(op);
    }
  }

225
  Partition partitions;
226 227

 private:
228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247
  Expr InverseCond(const Expr& cond) {
    // We expect most condition not to be of EQ or NE form.
    // Currently we do not handle inversing EQ or NE.
    Expr inverse_cond;
    if (const LT* op = cond.as<LT>()) {
      // a < b -> a >= b
      inverse_cond = GE::make(op->a, op->b);
    } else if (const GT* op = cond.as<GT>()) {
      // a > b -> a <= b
      inverse_cond = LE::make(op->a, op->b);
    } else if (const LE* op = cond.as<LE>()) {
      // a <= b -> a > b
      inverse_cond = GT::make(op->a, op->b);
    } else if (const GE* op = cond.as<GE>()) {
      // a >= b -> a < b
      inverse_cond = LT::make(op->a, op->b);
    }
    return inverse_cond;
  }

248
  VarExpr current_var_;
249 250 251 252 253
  std::unordered_set<const Variable*> out_vars_;
  std::unordered_map<const Variable*, IntSet> hint_map_;
  std::unordered_map<const Variable*, IntSet> relax_map_;
};

254
// Replace the set of conditions given by ps with cond_value (true or false)
255
class ConditionEliminator : public IRMutator {
256
 public:
257 258
  explicit ConditionEliminator(const std::unordered_set<const Node*>& ps, bool cond_value = true)
    : ps_(ps), cond_value_(cond_value) {}
259

260 261
  using IRMutator::Mutate;
  Expr Mutate(Expr e) final {
262 263 264
    if (ps_.find(e.get()) != ps_.end()) {
      return Mutate(cond_value_ ? const_true() : const_false());
    }
265 266 267 268
    return IRMutator::Mutate(e);
  }

 private:
269 270
  std::unordered_set<const Node*> ps_;
  bool cond_value_;
271 272
};

273 274 275 276

// Insert the partition branch at the innermost thread scope
class ThreadPartitionInserter : public IRMutator {
 public:
277
  explicit ThreadPartitionInserter(const std::unordered_set<const Node*>& ps,
278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298
    Expr cond) : ps_(ps), cond_(cond), innermost_thread_scope_(false) {}

  Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
    if (op->attr_key == attr::thread_extent) {
      innermost_thread_scope_ = true;
      Stmt stmt = IRMutator::Mutate_(op, s);
      // add branch code inside the innermost thread scope
      if (innermost_thread_scope_) {
        Stmt simplified_body = ConditionEliminator(ps_).Mutate(op->body);
        Stmt body = IfThenElse::make(cond_, simplified_body, op->body);
        Expr value = this->Mutate(op->value);
        stmt = AttrStmt::make(op->node, op->attr_key, value, body);
      }
      innermost_thread_scope_ = false;
      return stmt;
    } else {
      return IRMutator::Mutate_(op, s);
    }
  }

 private:
299
  const std::unordered_set<const Node*>& ps_;
300 301 302 303
  Expr cond_;
  bool innermost_thread_scope_;
};

304 305
// Try to partition range of iteration variables in order to remove (some)
// likely conditions
306 307
class LoopPartitioner : public IRMutator {
 public:
308 309 310 311 312 313 314
  explicit LoopPartitioner(bool split_const_loop)
      : selector(CandidateSelector(split_const_loop)) {}

  Stmt VisitAndMutate(const Stmt& stmt) {
    selector.Visit(stmt);
    return Mutate(stmt);
  }
315 316

  Stmt Mutate_(const For* op, const Stmt& stmt) {
317
    if (selector.candidates.count(op)) {
318 319
      Stmt s = TryPartition(op, stmt, op->loop_var,
          op->min, op->min + op->extent - 1, op->body, false);
320 321
      if (s.defined()) return s;
    }
322

323
    // normal path when loop partition fails
324 325
    // normal loop variable can be put into hint map.
    hint_map_.insert({op->loop_var.get(),
326 327
      IntSet::interval(op->min, op->min + op->extent - 1)});
    Stmt res = IRMutator::Mutate_(op, stmt);
328 329 330 331 332 333 334 335 336 337 338 339
    hint_map_.erase(op->loop_var.get());
    return res;
  }

  Stmt Mutate_(const AttrStmt* op, const Stmt& stmt) {
    if (op->attr_key != attr::thread_extent) {
      return IRMutator::Mutate_(op, stmt);
    }

    const IterVarNode *iv = op->node.as<IterVarNode>();
    CHECK(iv);
    Var var = iv->var;
340
    if (selector.candidates.count(op)) {
341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359
      Stmt s = TryPartition(op, stmt, var, 0, op->value - 1, op->body, true);
      if (s.defined()) return s;
    }

    // normal path when loop parittion fails.
    runtime::ThreadScope scope = runtime::ThreadScope::make(iv->thread_tag);
    Stmt res;
    if (scope.rank == 1) {
      // threadIdx should be put into relax map, in case of divergence.
      relax_map_.insert({var.get(),
        IntSet::interval(make_zero(var.type()), op->value - 1)});
      res = IRMutator::Mutate_(op, stmt);
      relax_map_.erase(var.get());
    } else {
      hint_map_.insert({var.get(),
        IntSet::interval(make_zero(var.type()), op->value - 1)});
      res = IRMutator::Mutate_(op, stmt);
      hint_map_.erase(var.get());
    }
360 361 362 363
    return res;
  }

 private:
364 365
  Stmt TryPartition(const Node* op, const Stmt& stmt, VarExpr var,
      Expr min, Expr max, Stmt body, bool partition_thread_scope);
366 367 368 369 370 371

  std::pair<IntSet, std::unordered_set<const Node*>>
  GetIntervalAndCondset(const Partition &partitions,
                        const arith::Interval &for_interval,
                        bool cond_value);

372
  inline Stmt MakeFor(const Node* op, Expr extent, Stmt body);
373

374 375 376
  /* Candidate IRs that may be partitioned potentially */
  std::unordered_map<const Variable*, IntSet> hint_map_;
  std::unordered_map<const Variable*, IntSet> relax_map_;
377
  CandidateSelector selector;
378 379
};

380 381 382 383 384 385 386 387 388 389 390 391
// Returns an interval (in the first component) in which all the conditions
// given in the second component provably have the value given by cond_value.
// Only partitions whose interval overlaps for_interval are taken into
// account; the first component is IntSet::nothing() when none qualifies.
std::pair<IntSet, std::unordered_set<const Node*>>
LoopPartitioner::GetIntervalAndCondset(const Partition &partitions,
                                       const arith::Interval &for_interval,
                                       bool cond_value) {
  Array<IntSet> sets;
  std::unordered_set<const Node*> cond_set;

  for (const auto &kv : partitions) {
    if (kv.first.second != cond_value) continue;
    // Skip sets that are not plain intervals instead of dereferencing a
    // null pointer returned by the downcast.
    const arith::IntervalSet* interval_set = kv.second.as<arith::IntervalSet>();
    if (interval_set == nullptr) continue;
    arith::Interval intersection =
        arith::Interval::make_intersection(interval_set->i, for_interval);
    if (!intersection.is_empty()) {
      sets.push_back(kv.second);
      cond_set.insert(kv.first.first);
    }
  }
  IntSet interval = sets.empty() ? IntSet::nothing() : Intersect(sets);
  return std::make_pair(interval, cond_set);
}

/*!
 * \brief Sequence two statements, tolerating undefined operands.
 * \return b if a is undefined, a if b is undefined, otherwise Block(a, b).
 */
Stmt AppendStmts(const Stmt& a, const Stmt& b) {
  if (!a.defined()) return b;
  if (!b.defined()) return a;
  return Block::make(a, b);
}

/*
 * Tries to recursively partition the range of the variable (given by var) of
 * the for loop (given by node and stmt) into a
 * number of disjoint ranges such that in some ranges one or more predicates
 * in the loopnest are provably true or false in each range. For example, given the
 * following loop to partition:
 * for (i = 0; i < 4; i++)
 *    for (j = 0; j < 10; j++)
 *        if (likely(i*10 + j < 36))
 *            A[10*i+j] = B[10*i+j]
 *
 * We first partition range of i, i.e., [0,3] into subranges [0,2] and [3,3] because the
 * likely condition is always true for the first subrange but not always true for the
 * second subrange. Therefore, we'll have
 * for (i = 0; i < 3; i++)
 *    for (j = 0; j < 10; j++)
 *        if (likely(1))
 *           A[10*i+j] = B[10*i+j]
 * for (i = 0; i < 1; i++)
 *    for (j = 0; j < 10; j++)
 *        if (likely((i+3)*10 + j < 36))
 *            A[10*(i+3)+j] = B[10*(i+3)+j]
 * Which is simplified as:
 * for (i = 0; i < 3; i++)
 *    for (j = 0; j < 10; j++)
 *        A[10*i+j] = B[10*i+j]
 * for (j = 0; j < 10; j++) // loopnest 1
 *    if (likely(j < 6))
 *            A[30+j] = B[30+j]
 * Now, we recursively partition j in loopnest 1 into subranges [0,5] and [6,9] where the
 * condition is true for the first subrange and not always true for the second subrange.
 * for (j = 0; j < 6; j++)
 *    if (likely(1))
 *         A[30+j] = B[30+j]
 * for (j = 0; j < 4; j++) // loop 2
 *    if (likely(j < 0))
 *        A[36+j] = B[36+j]
 * Finally we recursively partition loop 2 above into subrange [0,3] where the
 * condition is false and empty interval where the condition is not false,
 * therefore we generate
 * for (j = 0; j < 4; j++)
 *    if (likely(0))
 *        A[36+j] = B[36+j]
 * which will eventually be simplified to empty code. And because only one loop was generated
 * from loop 2 we stop recursing.
 */
459 460 461 462 463 464 465
/*!
 * \brief Split the range [min, max] of var into a pre, middle and post range.
 * \param node The For or thread_extent AttrStmt node being partitioned.
 * \param stmt The statement corresponding to node.
 * \param var The iteration variable.
 * \param min Inclusive lower bound of var.
 * \param max Inclusive upper bound of var.
 * \param body The body executed for each value of var.
 * \param partition_thread_scope Whether var is a thread variable; in that
 *        case a branch is inserted instead of splitting into loops.
 * \return The partitioned statement, or an undefined Stmt if no partition
 *         could be derived.
 */
Stmt LoopPartitioner::TryPartition(const Node* node,
                                   const Stmt& stmt,
                                   VarExpr var,
                                   Expr min,
                                   Expr max,
                                   Stmt body,
                                   bool partition_thread_scope) {
  PartitionFinder finder(var, hint_map_, relax_map_);
  finder.Visit(body);
  if (finder.partitions.empty()) return Stmt();

  arith::Interval for_interval(min, max);
  bool cond_value = true;
  IntSet middle_interval;
  std::unordered_set<const Node*> cond_set;
  // find an interval in which all conditions on var are true
  std::tie(middle_interval, cond_set) =
          GetIntervalAndCondset(finder.partitions, for_interval, true);
  if (middle_interval.is_nothing()) {
    // if such interval doesn't exist, find an interval in which all
    // conditions on var are false
    std::tie(middle_interval, cond_set) =
            GetIntervalAndCondset(finder.partitions, for_interval, false);
    if (middle_interval.is_nothing()) {
      // we couldn't find an interval in which the conditions are provably true or false
      // Therefore, we can't partition the loop based on those conds
      return Stmt();
    }
    cond_value = false;
  }

  // Give up instead of dereferencing a null pointer if the intersection is
  // not represented as a plain interval.
  const arith::IntervalSet* middle_set = middle_interval.as<arith::IntervalSet>();
  if (middle_set == nullptr) return Stmt();
  arith::Interval middle_interval_i = middle_set->i;
  // middle_interval is the subrange of the loop variable range for which a
  // set of conditions are true (or false resp.)
  // The part of the loop variable range that is before (after resp.) that
  // subrange is prefixed with pre- (post- resp.)

  // Calculating pre-subrange and generating code for it.
  // pre-subrange = [min, body_begin)
  Expr body_begin;
  Stmt pre_stmt;
  bool pre_stmt_recurse = true;
  if (middle_interval_i.has_lower_bound()) {
    body_begin = ir::Simplify(middle_interval.min());
    if (!can_prove(body_begin == min)) {
      Expr cond = (body_begin - min >= 0);
      if (!can_prove(cond)) {
        LOG(WARNING) << "Cannot prove: " << cond
                     << ", when generating the pre doubt loop";
        body_begin = Max::make(body_begin, min);
        // stop recursing on this interval if we can't prove it has non-negative length
        pre_stmt_recurse = false;
      }
      if (!partition_thread_scope) {
        Stmt pre_body = Substitute(body, {{Var{var}, var + min}});
        pre_stmt = MakeFor(node, body_begin - min, pre_body);
      }
    }
  } else {
    body_begin = min;
  }

  // Calculating post-subrange and generating code for it.
  // post-subrange = [post_doubt_begin, max]
  Expr post_doubt_begin;
  Stmt post_stmt;
  bool post_stmt_recurse = true;
  if (middle_interval_i.has_upper_bound()) {
    post_doubt_begin = ir::Simplify(middle_interval.max() + 1);
    if (!can_prove(middle_interval.max() == max)) {
      // require the extent to be non-negative
      Expr cond = (max - post_doubt_begin + 1 >= 0);
      if (!can_prove(cond)) {
        LOG(WARNING) << "Cannot prove: " << cond
                     << ", when generating the post doubt loop";
        // Clamp to max + 1 (not max) so that the post extent
        // max - post_doubt_begin + 1 may become zero. Clamping to max would
        // force at least one iteration, which is outside the loop range
        // whenever the original range is empty.
        post_doubt_begin = Min::make(post_doubt_begin, max + 1);
        // stop recursing on this interval if we can't prove it has non-negative length
        post_stmt_recurse = false;
      }
      if (!partition_thread_scope) {
        Stmt post_body =
                Substitute(body, {{Var{var}, var + post_doubt_begin}});
        post_stmt = MakeFor(node, max - post_doubt_begin + 1, post_body);
      }
    }
  } else {
    post_doubt_begin = max + 1;
  }

  Stmt s;

  // Generating code for middle subrange
  if (!partition_thread_scope) {
    Stmt mid_stmt;
    if (!can_prove(body_begin >= post_doubt_begin)) {
      // [body_begin, post_doubt_begin)
      Stmt simplified_body = ConditionEliminator(cond_set, cond_value).Mutate(body);
      Stmt new_body = Substitute(simplified_body, {{Var{var}, var + body_begin}});
      mid_stmt = MakeFor(node, post_doubt_begin - body_begin, new_body);

      // Recurse for each non-empty subrange only if there are at least
      // two non-empty subranges
      if (pre_stmt.defined() || post_stmt.defined()) {
        mid_stmt = VisitAndMutate(mid_stmt);
        if (pre_stmt.defined() && pre_stmt_recurse) {
          pre_stmt = VisitAndMutate(pre_stmt);
        }
        if (post_stmt.defined() && post_stmt_recurse) {
          post_stmt = VisitAndMutate(post_stmt);
        }
      }
    }
    s = AppendStmts(pre_stmt, mid_stmt);
    s = AppendStmts(s, post_stmt);
  } else {
    // Thread variables cannot be split into loops: guard the simplified body
    // with the middle-range condition instead.
    Expr cond = const_true();
    if (!can_prove(body_begin == min)) cond = cond && (var >= body_begin);
    if (!can_prove(post_doubt_begin == (max + 1))) cond = cond && (var < post_doubt_begin);
    s = ThreadPartitionInserter(cond_set, cond).Mutate(stmt);
  }
  s = ConvertSSA(s);
  return s;
}
583

584 585 586
/*!
 * \brief Build a zero-based loop with the attributes of an existing For node.
 * \param node The original For node (loop var, for type and device api are reused).
 * \param extent The extent of the new loop.
 * \param body The body of the new loop.
 * \return The new For statement, or just the body with the loop var replaced
 *         by 0 when the extent is provably 1.
 */
inline Stmt LoopPartitioner::MakeFor(const Node *node, Expr extent, Stmt body) {
  const For *for_node = static_cast<const For*>(node);
  CHECK(for_node);
  const bool single_iteration = can_prove(extent == make_const(Int(32), 1));
  if (!single_iteration) {
    return For::make(for_node->loop_var, 0, extent,
                     for_node->for_type, for_node->device_api, body);
  }
  // If the loop extent is 1, do not create the loop anymore
  return Substitute(body, {{Var{for_node->loop_var}, make_const(Int(32), 0)}});
}

596 597 598 599 600 601 602 603 604 605 606 607 608 609
// Strip likely() intrinsics: every likely(x) call is replaced by the
// mutated form of its single argument x.
class RemoveLikelyTags : public IRMutator {
 public:
  using IRMutator::Mutate;

  Expr Mutate_(const Call *op, const Expr& e) {
    if (!op->is_intrinsic(Call::likely)) {
      return IRMutator::Mutate_(op, e);
    }
    CHECK_EQ(op->args.size(), 1);
    return IRMutator::Mutate(op->args[0]);
  }
};

610
/*!
 * \brief Entry point of the pass: partition loops, then drop likely() tags.
 * \param stmt The statement to transform.
 * \param split_const_loop Whether loops with constant bounds may be partitioned.
 * \return The transformed statement.
 */
Stmt LoopPartition(Stmt stmt, bool split_const_loop) {
  Stmt partitioned = LoopPartitioner(split_const_loop).VisitAndMutate(stmt);
  return RemoveLikelyTags().Mutate(partitioned);
}

}  // namespace ir
}  // namespace tvm