canonical.cc 21.4 KB
Newer Older
1 2 3 4 5 6
/*!
 *  Copyright (c) 2017 by Contributors
 * \file canonical.cc
 * \brief Canonicalize simplification.
 */
#include <tvm/ir_mutator.h>
7
#include <tvm/arithmetic.h>
8
#include <tvm/ir_pass.h>
9 10 11 12 13 14
#include <algorithm>
#include <map>
#include <limits>
#include <vector>
#include "canonical.h"
#include "compute_expr.h"
15
#include "arithmetic/Simplify.h"
16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35

namespace tvm {
namespace arith {
using namespace ir;

// Canonical entry for communicative ops.
struct ComExprEntry {
  // the value of the expression.
  Expr value;
  // the level of the expression.
  int level{0};
  // The integer scale on value
  int64_t scale{1};

  ComExprEntry() {}
  ComExprEntry(Expr value, int level)
      : value(value), level(level) {}
  inline bool operator<(const ComExprEntry& other) const {
    if (level < other.level) return true;
    if (level > other.level) return false;
36
    // compare top operator of entries and sort on that if possible (fast check)
37 38
    if (value.type_index() < other.value.type_index()) return true;
    if (value.type_index() > other.value.type_index()) return false;
39 40 41 42 43 44 45 46
    // if none of the above distinguishes the terms, compare the expression tree of the entries.
    // This is a slower check.
    int compare_result = Compare(value, other.value);
    if (compare_result < 0) return true;
    if (compare_result > 0) return false;
    // it's a problem if we see identical entries at this point. They should've been merged earlier.
    LOG(WARNING) << "we should not have identical entries at this point";
    return false;
47 48 49 50
  }
};

// canonical expression for communicative expression.
51
struct ComExprNode : public NodeBase {
52 53 54 55 56 57 58 59 60 61 62
  // base constant value.
  int64_t base{0};
  // The values to be sumed.
  std::vector<ComExprEntry> elem;
};

// canonical communicative expression
struct ComExpr {
 public:
  // constructor
  ComExpr() {}
63
  explicit ComExpr(NodePtr<ComExprNode> ptr) : ptr_(ptr) {}
64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108
  // get member
  ComExprNode* operator->() const {
    return ptr_.get();
  }
  void reset() {
    ptr_.reset();
  }
  bool defined() const {
    return ptr_.get() != nullptr;
  }
  // comparator
  bool operator<(const ComExpr& b) const {
    const ComExpr& a = *this;
    if (a->base < b->base) return true;
    if (a->base > b->base) return false;
    if (a->elem.size() < b->elem.size()) return true;
    if (a->elem.size() > b->elem.size()) return false;
    for (size_t i = 0; i < a->elem.size(); ++i) {
      const ComExprEntry& ea = a->elem[i];
      const ComExprEntry& eb = b->elem[i];
      if (ea.level < eb.level) return true;
      if (ea.level > eb.level) return false;
      if (ea.value.get() < eb.value.get()) return true;
      if (ea.value.get() > eb.value.get()) return false;
      if (ea.scale < eb.scale) return true;
      if (ea.scale > eb.scale) return false;
    }
    return false;
  }
  // equality
  bool operator==(const ComExpr& b) const {
    const ComExpr& a = *this;
    if (a->base != b->base) return false;
    if (a->elem.size() != b->elem.size()) return false;
    for (size_t i = 0; i < a->elem.size(); ++i) {
      const ComExprEntry& ea = a->elem[i];
      const ComExprEntry& eb = b->elem[i];
      if (ea.level != eb.level) return false;
      if (ea.value.get() != eb.value.get()) return false;
      if (ea.scale != eb.scale) return false;
    }
    return true;
  }

 private:
109
  NodePtr<ComExprNode> ptr_;
110 111
};

112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131
// Key describing a binary op: its kind plus the identity of both operands.
struct BinaryExpr {
  int kind;
  Expr lhs, rhs;
  // Ordering on (kind, lhs node address, rhs node address).
  bool operator<(const BinaryExpr& b) const {
    if (kind != b.kind) return kind < b.kind;
    if (lhs.get() != b.lhs.get()) return lhs.get() < b.lhs.get();
    return rhs.get() < b.rhs.get();
  }
  // Equal when the kind matches and both operands are the same nodes.
  bool operator==(const BinaryExpr& b) const {
    if (kind != b.kind) return false;
    return lhs.same_as(b.lhs) && rhs.same_as(b.rhs);
  }
};

132

133 134 135 136 137 138 139 140 141 142 143 144 145 146
// Rebuild a binary node of type T from operands a and b.
// Returns the original expression e when neither operand changed.
template<typename T>
inline Expr Binary_(const T* op,
                    const Expr& e,
                    Expr a, Expr b) {
  if (!a.same_as(op->a) || !b.same_as(op->b)) {
    return T::make(a, b);
  }
  return e;
}

// internal of canonical engine.
class Canonical::Internal : public IRMutator {
 public:
147 148 149 150 151
  // Registers every (var, range) pair of vrange at level 0.
  explicit Internal(Map<Var, Range> vrange) {
    for (auto entry : vrange) {
      this->SetRange(entry.first, entry.second, 0);
    }
  }
152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175
  // Bookkeeping pushed on stack_ for each expression being mutated.
  struct StackEntry {
    // Highest variable level recorded for the expression.
    int max_level = 0;
    // Set when an impure call was encountered.
    bool has_side_effect = false;
  };
  // Result of canonicalizing one expression.
  struct CacheEntry {
    // Canonical value of the expression.
    Expr value;
    // Level of the expression.
    int max_level{0};
    // Whether the expression might have a side effect.
    bool has_side_effect{false};
    // When defined, the sum form of the expression.
    ComExpr sum;
    // Clears the sum form only; the other fields are left as they are.
    void reset() {
      sum.reset();
    }
    // View this entry as a sum: reuse sum when defined, otherwise wrap
    // a constant as the base or the value as a single term.
    ComExpr AsSum() const {
      if (sum.defined()) return sum;
      auto node = make_node<ComExprNode>();
      if (const int64_t* signed_const = as_const_int(value)) {
        node->base = *signed_const;
      } else if (const uint64_t* unsigned_const = as_const_uint(value)) {
        // The base is stored as int64_t, so the constant must fit.
        CHECK_LE(*unsigned_const,
               static_cast<uint64_t>(std::numeric_limits<int64_t>::max()));
        node->base = static_cast<int64_t>(*unsigned_const);
      } else {
        node->elem.push_back(ComExprEntry(value, max_level));
      }
      return ComExpr(node);
    }
  };
  // Record the range and the level of variable v, and keep v in var_rec_.
  void SetRange(Var v, Range r, int level) {
    const Variable* key = v.get();
    var_range_[key] = IntSet::range(r);
    var_level_[key] = level;
    var_rec_.push_back(v);
  }
  // Statement entry point: forwards to the base mutator.
  Stmt Mutate(Stmt stmt) final {
    return IRMutator::Mutate(stmt);
  }
  // Mutate expr with a fresh stack entry, then fold that entry into its
  // parent (if any) and into ret_entry_.
  Expr MutateExpr_(Expr expr) {
    stack_.push_back(StackEntry());
    expr = IRMutator::Mutate(expr);
    // Take the entry pushed above off the stack.
    const StackEntry finished = stack_.back();
    stack_.pop_back();
    // Propagate the result to the enclosing expression, when there is one.
    if (!stack_.empty()) {
      StackEntry& parent = stack_.back();
      parent.max_level = std::max(parent.max_level, finished.max_level);
      if (finished.has_side_effect) parent.has_side_effect = true;
    }
    ret_entry_.has_side_effect = finished.has_side_effect;
    ret_entry_.max_level = finished.max_level;
    CHECK(expr.defined());
    // An IntImm result is passed through the IntImm handler once more.
    if (const IntImm* op = expr.as<IntImm>()) {
      return Mutate_(op, expr);
    }
    return expr;
  }
  // Mutate expr and return a snapshot of ret_entry_ describing the result.
  // ret_entry_ has its sum cleared both before and after.
  CacheEntry Produce(Expr expr) {
    ret_entry_.reset();
    ret_entry_.value = MutateExpr_(expr);
    CacheEntry produced(ret_entry_);
    ret_entry_.reset();
    return produced;
  }
  // Expression entry point: mutate expr, clearing the sum of ret_entry_
  // before and after.
  Expr Mutate(Expr expr) final {
    ret_entry_.reset();
    Expr result = MutateExpr_(expr);
    ret_entry_.reset();
    return result;
  }

  // Special canonicalization applies only to scalar int/uint types.
  bool EnableOpt(Type t) const {
    if (t.lanes() != 1) return false;
    return t.is_int() || t.is_uint();
  }
  // Add
240
  Expr Mutate_(const Add* op, const Expr& e) final {
241
    if (!EnableOpt(op->type)) {
242
      return Binary(op, e);
243 244 245 246 247 248 249 250 251
    }
    CacheEntry a = Produce(op->a);
    CacheEntry b = Produce(op->b);
    if (a.has_side_effect || b.has_side_effect) {
      return Binary_(op, e, a.value, b.value);
    }
    return SumAdd(a, b, +1);
  }
  // Sub
252
  Expr Mutate_(const Sub* op, const Expr& e) final {
253
    if (!EnableOpt(op->type)) {
254
      return Binary(op, e);
255 256 257 258 259 260 261 262 263
    }
    CacheEntry a = Produce(op->a);
    CacheEntry b = Produce(op->b);
    if (a.has_side_effect || b.has_side_effect) {
      return Binary_(op, e, a.value, b.value);
    }
    return SumAdd(a, b, -1);
  }
  // Mul
264
  Expr Mutate_(const Mul* op, const Expr& e) final {
265
    if (!EnableOpt(op->type)) {
266
      return Binary(op, e);
267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293
    }
    CacheEntry a = Produce(op->a);
    CacheEntry b = Produce(op->b);
    if (a.has_side_effect || b.has_side_effect) {
      return Binary_(op, e, a.value, b.value);
    }
    if (is_const(a.value) && is_const(b.value)) {
      return ComputeExpr<Mul>(a.value, b.value);
    } else if (is_const(a.value)) {
      return SumMulConst(b.AsSum(), a.value);
    } else if (is_const(b.value)) {
      return SumMulConst(a.AsSum(), b.value);
    } else {
      return Binary_(op, e, a.value, b.value);
    }
  }
  // Variable: when the variable has a registered level, store it as the
  // max_level of the current stack entry.
  Expr Mutate_(const Variable* op, const Expr& e) final {
    const auto found = var_level_.find(op);
    if (found != var_level_.end()) {
      stack_.back().max_level = found->second;
    }
    return IRMutator::Mutate_(op, e);
  }
  // LT: a < b becomes constant true when b - a can be proven positive
  // over the known variable ranges.
  Expr Mutate_(const LT* op, const Expr& e) {
    if (!EnableOpt(op->a.type())) return Binary(op, e);
    CacheEntry lhs = Produce(op->a);
    CacheEntry rhs = Produce(op->b);
    if (lhs.has_side_effect || rhs.has_side_effect) {
      return Binary_(op, e, lhs.value, rhs.value);
    }
    Expr difference = SumAdd(rhs, lhs, -1);
    if (EvalSet(difference, var_range_).can_prove_positive()) {
      return make_const(op->type, true);
    }
    return Binary_(op, e, lhs.value, rhs.value);
  }
308 309
  // IntImm: Int(32) constants are shared, one node per value.
  // Constants of any other type are returned unchanged.
  Expr Mutate_(const IntImm* op, const Expr& e) final {
    if (op->type != Int(32)) return e;
    // emplace keeps an already cached node and inserts e otherwise.
    return cache_intimm_.emplace(op->value, e).first->second;
  }
319
  // Div operator
320
  Expr Mutate_(const Div* op, const Expr& e) final {
321 322 323 324 325 326 327 328 329 330 331 332 333 334 335
    if (!EnableOpt(op->type)) {
      return Binary(op, e);
    }
    CacheEntry a = Produce(op->a);
    CacheEntry b = Produce(op->b);
    if (a.has_side_effect || b.has_side_effect) {
      return Binary_(op, e, a.value, b.value);
    }
    if (is_const(a.value) && is_const(b.value)) {
      return ComputeExpr<Div>(a.value, b.value);
    } else if (is_const(b.value)) {
      return SumDivConst(a.AsSum(), b.value);
    } else {
      return Binary(op, e);
    }
336
  }
337
  // Mod operator
338
  Expr Mutate_(const Mod* op, const Expr& e) final {
339 340 341 342 343 344 345 346 347 348 349 350 351 352 353
    if (!EnableOpt(op->type)) {
      return Binary(op, e);
    }
    CacheEntry a = Produce(op->a);
    CacheEntry b = Produce(op->b);
    if (a.has_side_effect || b.has_side_effect) {
      return Binary_(op, e, a.value, b.value);
    }
    if (is_const(a.value) && is_const(b.value)) {
      return ComputeExpr<Mul>(a.value, b.value);
    } else if (is_const(b.value)) {
      return SumModConst(a.AsSum(), b.value);
    } else {
      return Binary(op, e);
    }
354
  }
355

356 357 358 359 360 361 362
  // And: after mutating the operands, drop an operand that is constant one.
  Expr Mutate_(const And* op, const Expr& e) final {
    Expr expr = IRMutator::Mutate_(op, e);
    const And* conj = expr.as<And>();
    if (is_one(conj->a)) {
      return conj->b;
    }
    if (is_one(conj->b)) {
      return conj->a;
    }
    return expr;
  }
363 364 365 366 367
  // Call: flag impure calls on the current stack entry, and unwrap
  // likely(c) when its argument is a constant.
  Expr Mutate_(const Call* op, const Expr& e) final {
    if (!op->is_pure()) stack_.back().has_side_effect = true;
    Expr expr = IRMutator::Mutate_(op, e);
    const Call* call = expr.as<Call>();
    if (call->is_intrinsic(Call::likely) && is_const(call->args[0])) {
      return call->args[0];
    }
    return expr;
  }
  // For: register the loop variable with range [min, min + extent) at a
  // new level, then mutate the loop.
  Stmt Mutate_(const For* op, const Stmt& s) {
    ++level_counter_;
    Var loop_var(op->loop_var.node_);
    Range loop_range = Range::make_by_min_extent(op->min, op->extent);
    this->SetRange(loop_var, loop_range, level_counter_);
    Stmt mutated = IRMutator::Mutate_(op, s);
    --level_counter_;
    return mutated;
  }
387 388 389 390 391 392 393
  // IfThenElse: when the mutated condition is constant one, keep only
  // the then branch.
  Stmt Mutate_(const IfThenElse* op, const Stmt& s) {
    Stmt stmt = IRMutator::Mutate_(op, s);
    const IfThenElse* branch = stmt.as<IfThenElse>();
    if (is_one(branch->condition)) {
      return branch->then_case;
    }
    return stmt;
  }
394 395
  // AttrStmt: for thread_extent / virtual_thread attributes, register the
  // thread variable with range [0, value) at a new level (unless it already
  // has a level), then mutate the body. Other attributes are forwarded.
  Stmt Mutate_(const AttrStmt* op, const Stmt& s) {
    const bool is_thread_attr =
        op->attr_key == attr::thread_extent ||
        op->attr_key == attr::virtual_thread;
    if (!is_thread_attr) {
      return IRMutator::Mutate_(op, s);
    }
    ++level_counter_;
    IterVar iv(op->node.node_);
    CHECK_NE(iv->thread_tag.length(), 0U);
    if (var_level_.count(iv->var.get()) == 0) {
      Range thread_range = Range::make_by_min_extent(0, op->value);
      this->SetRange(iv->var, thread_range, level_counter_);
    }
    Stmt mutated = IRMutator::Mutate_(op, s);
    --level_counter_;
    return mutated;
  }
  // Accessor for a function-local static FMutateExpr instance.
  static FMutateExpr& vtable_expr() {  // NOLINT(*)
    static FMutateExpr inst;
    return inst;
  }

 private:
419
  template<typename T>
420
  Expr Binary(const T* op, Expr e) {
421 422 423 424 425 426 427 428 429 430 431 432
    Expr a = this->Mutate(op->a);
    Expr b = this->Mutate(op->b);
    BinaryExpr key{static_cast<int>(T::_type_info), a, b};
    auto it = cache_binary_.find(key);
    if (it != cache_binary_.end()) {
      return it->second;
    } else {
      Expr ret = Binary_(op, e, a, b);
      cache_binary_[key] = ret;
      return ret;
    }
  }
433 434 435 436 437 438
  // return entry
  CacheEntry ret_entry_;
  // internal information stack
  std::vector<StackEntry> stack_;
  // cache sum
  std::map<ComExpr, CacheEntry> cache_sum_;
439 440 441 442
  // cache of normal binary op
  std::map<BinaryExpr, Expr> cache_binary_;
  // cache of int constant
  std::unordered_map<int64_t, Expr> cache_intimm_;
443 444 445 446 447 448 449 450
  // range of each var
  std::unordered_map<const Variable*, IntSet> var_range_;
  // level of each var
  std::unordered_map<const Variable*, int> var_level_;
  // record history vars, to avoid false positive.
  std::vector<Var> var_rec_;
  // level counter
  int level_counter_{0};
451 452
  // Read v as an int64_t. v must be a signed or unsigned integer constant;
  // an unsigned constant must fit into int64_t.
  int64_t GetConstIntValue(const Expr& v) {
    const int64_t* signed_const = as_const_int(v);
    const uint64_t* unsigned_const = as_const_uint(v);
    CHECK(signed_const || unsigned_const);
    if (signed_const) {
      return *signed_const;
    }
    CHECK_LE(*unsigned_const,
             static_cast<uint64_t>(std::numeric_limits<int64_t>::max()));
    return static_cast<int64_t>(*unsigned_const);
  }
466 467 468 469 470 471 472 473
  // Detect if a = x * coeff + y, where y \in [0, coeff), x >= 0
  // return true if such detection is successful
  // return false if it is not.
  std::vector<ComExpr> TryLinearEquation(const ComExpr& a,
                                         const Expr& coeff) {
    Type type = coeff.type();
    int64_t value = GetConstIntValue(coeff);
    if (value < 0) return {};
474 475
    auto xnode = make_node<ComExprNode>();
    auto ynode = make_node<ComExprNode>();
476 477 478 479 480 481 482 483
    if (a->base % value == 0) {
      xnode->base = a->base;
    } else {
      ynode->base = a->base;
    }
    for (const auto& e : a->elem) {
      if (e.scale % value == 0) {
        xnode->elem.push_back(e);
484
      } else {
485
        ynode->elem.push_back(e);
486 487
      }
    }
488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509
    Expr yres = Sum2Expr(ComExpr(ynode), type);
    IntSet yset = EvalSet(yres, var_range_);
    // This relies on the integer division rounds down
    // Most cases it is good for integer division.
    if (yset.min().type() == type &&
        can_prove(yset.min() >= make_zero(type)) &&
        yset.max().type() == type &&
        can_prove(yset.max() < coeff)) {
      xnode->base /= value;
      for (auto &e : xnode->elem) {
        e.scale /= value;
      }
      return {ComExpr(xnode), ComExpr(ynode)};
    } else {
      return {};
    }
  }
  // Subroutine to produce a % v for a constant v.
  // When a splits as x * v + y with y in [0, v), the result is y.
  // Otherwise the base and scales of a are reduced modulo v and an explicit
  // Mod node is built. A zero modulus is left as an unsimplified Mod node.
  Expr SumModConst(ComExpr a, Expr v) {
    const int64_t value = GetConstIntValue(v);
    if (value == 0) {
      // Every reduction below computes `% value`; with a zero modulus
      // that is an integer division by zero, so skip them all.
      Expr ret = Mod::make(Sum2Expr(a, v.type()), v);
      return Binary(ret.as<Mod>(), ret);
    }
    std::vector<ComExpr> pair = TryLinearEquation(a, v);
    if (pair.size() == 0) {
      auto n = make_node<ComExprNode>();
      n->base = a->base % value;
      for (auto e : a->elem) {
        // Terms whose scale is a multiple of v vanish modulo v.
        if (e.scale % value == 0) continue;
        e.scale = e.scale % value;
        n->elem.push_back(e);
      }
      Expr ret = Sum2Expr(ComExpr(n), v.type()) % v;
      return Binary(ret.as<Mod>(), ret);
    }
    // pair[1] is the remainder part y.
    ret_entry_.sum = pair[1];
    ret_entry_.max_level = stack_.back().max_level;
    ret_entry_.has_side_effect = stack_.back().has_side_effect;
    auto it = cache_sum_.find(ret_entry_.sum);
    if (it != cache_sum_.end()) {
      ret_entry_ = it->second;
    } else {
      ret_entry_.value = Sum2Expr(ret_entry_.sum, v.type());
      cache_sum_[ret_entry_.sum] = ret_entry_;
    }
    return ret_entry_.value;
  }
  // Subroutine to produce a / v for a constant v.
  // When a splits as x * v + y with y in [0, v), the result is x;
  // otherwise an explicit Div node is built.
  Expr SumDivConst(ComExpr a, Expr v) {
    std::vector<ComExpr> parts = TryLinearEquation(a, v);
    if (parts.empty()) {
      Expr quotient = Sum2Expr(a, v.type()) / v;
      return Binary(quotient.as<Div>(), quotient);
    }
    // parts[0] is the quotient part x.
    ret_entry_.sum = parts[0];
    ret_entry_.max_level = stack_.back().max_level;
    ret_entry_.has_side_effect = stack_.back().has_side_effect;
    const auto cached = cache_sum_.find(ret_entry_.sum);
    if (cached != cache_sum_.end()) {
      ret_entry_ = cached->second;
      return ret_entry_.value;
    }
    ret_entry_.value = Sum2Expr(ret_entry_.sum, v.type());
    cache_sum_[ret_entry_.sum] = ret_entry_;
    return ret_entry_.value;
  }
  // Subroutine to produce a * v for a constant v: scales the base and
  // every term of a. A zero factor yields constant zero directly.
  Expr SumMulConst(ComExpr a, Expr v) {
    const int64_t factor = GetConstIntValue(v);
    if (factor == 0) return make_zero(v.type());
    // Work on a copy of the node behind a.
    auto scaled = make_node<ComExprNode>(*a.operator->());
    scaled->base *= factor;
    for (ComExprEntry& term : scaled->elem) {
      term.scale *= factor;
    }
    ret_entry_.sum = ComExpr(scaled);
    ret_entry_.max_level = stack_.back().max_level;
    ret_entry_.has_side_effect = stack_.back().has_side_effect;
    const auto cached = cache_sum_.find(ret_entry_.sum);
    if (cached != cache_sum_.end()) {
      ret_entry_ = cached->second;
      return ret_entry_.value;
    }
    ret_entry_.value = Sum2Expr(ret_entry_.sum, v.type());
    cache_sum_[ret_entry_.sum] = ret_entry_;
    return ret_entry_.value;
  }
  // Compute suma + sumb * bscale by merging both element lists.
  // Terms with the same value node and level are combined; a combined
  // term whose scale becomes zero is dropped.
  ComExpr SumAdd_(const ComExpr& suma,
                  const ComExpr& sumb,
                  int bscale) {
    auto merged = make_node<ComExprNode>();
    merged->base = suma->base + sumb->base * bscale;
    const std::vector<ComExprEntry>& lhs = suma->elem;
    const std::vector<ComExprEntry>& rhs = sumb->elem;
    size_t i = 0;
    size_t j = 0;
    while (i < lhs.size() && j < rhs.size()) {
      const ComExprEntry& x = lhs[i];
      const ComExprEntry& y = rhs[j];
      if (x.value.same_as(y.value) && x.level == y.level) {
        ComExprEntry combined = x;
        combined.scale = x.scale + y.scale * bscale;
        if (combined.scale != 0) {
          merged->elem.push_back(combined);
        }
        ++i;
        ++j;
      } else if (x < y) {
        merged->elem.push_back(x);
        ++i;
      } else {
        ComExprEntry scaled = y;
        scaled.scale *= bscale;
        merged->elem.push_back(scaled);
        ++j;
      }
    }
    // Leftover terms of suma are copied as they are.
    while (i < lhs.size()) {
      merged->elem.push_back(lhs[i]);
      ++i;
    }
    // Leftover terms of sumb still need the scale applied.
    while (j < rhs.size()) {
      ComExprEntry scaled = rhs[j];
      scaled.scale *= bscale;
      merged->elem.push_back(scaled);
      ++j;
    }
    return ComExpr(merged);
  }
  // Subroutine to produce a + b * bscale as a canonical sum, reusing a
  // cached entry when the same sum was produced before.
  Expr SumAdd(CacheEntry a, CacheEntry b, int bscale) {
    ret_entry_.sum = SumAdd_(a.AsSum(), b.AsSum(), bscale);
    CHECK_NE(stack_.size(), 0U);
    const StackEntry& current = stack_.back();
    ret_entry_.max_level = current.max_level;
    ret_entry_.has_side_effect = current.has_side_effect;
    const auto cached = cache_sum_.find(ret_entry_.sum);
    if (cached != cache_sum_.end()) {
      ret_entry_ = cached->second;
      return ret_entry_.value;
    }
    ret_entry_.value = Sum2Expr(ret_entry_.sum, a.value.type());
    cache_sum_[ret_entry_.sum] = ret_entry_;
    return ret_entry_.value;
  }
  // Convert a canonical sum to an expression with constants of type t.
  // Emission order: positive base, positively scaled terms, negative base,
  // negatively scaled terms (as subtractions). An empty sum yields zero.
  Expr Sum2Expr(const ComExpr& com, Type t) {
    Expr result;
    const int64_t base = com->base;
    if (base > 0) {
      result = make_const(t, base);
    }
    for (const ComExprEntry& term : com->elem) {
      if (term.scale <= 0) continue;
      Expr piece = term.value;
      if (term.scale != 1) {
        piece = Mul::make(piece, make_const(t, term.scale));
      }
      if (result.defined()) {
        result = Add::make(result, piece);
      } else {
        result = piece;
      }
    }
    if (base < 0) {
      if (result.defined()) {
        result = Sub::make(result, make_const(t, -base));
      } else {
        result = make_const(t, base);
      }
    }
    for (const ComExprEntry& term : com->elem) {
      if (term.scale >= 0) continue;
      Expr piece = term.value;
      if (term.scale != -1) {
        piece = Mul::make(piece, make_const(t, -term.scale));
      }
      if (result.defined()) {
        result = Sub::make(result, piece);
      } else {
        // Nothing emitted yet: subtract from an explicit zero.
        result = Sub::make(make_zero(t), piece);
      }
    }
    if (!result.defined()) {
      return make_zero(t);
    }
    return result;
  }
};

using CInternal = Canonical::Internal;

676 677
Canonical::Canonical(Map<Var, Range> vrange)
    : ptr_(std::make_shared<Internal>(vrange)) {}
678 679 680 681 682 683 684 685 686 687 688 689 690 691 692

Expr Canonical::Simplify(Expr expr) {
  return ptr_->Mutate(expr);
}

Stmt Canonical::Simplify(Stmt stmt) {
  return ptr_->Mutate(stmt);
}

void Canonical::SetRange(Var v, Range r, int level) {
  ptr_->SetRange(v, r, level);
}
}  // namespace arith

namespace ir {
693

694 695
Stmt CanonicalSimplify(Stmt stmt, Map<Var, Range> vrange) {
  return arith::Canonical(vrange).Simplify(stmt);
696 697
}

698 699
Expr CanonicalSimplify(Expr expr, Map<Var, Range> vrange) {
  return arith::Canonical(vrange).Simplify(expr);
700
}
701 702 703

template<typename T>
T Simplify_(T a, Map<Var, Range> vrange) {
704
  using namespace HalideIR::Internal;
705 706 707 708 709 710 711 712
  Scope<Interval> rscope;
  for (auto kv : vrange) {
    Range r = kv.second;
    rscope.push(
        kv.first.get(),
        Interval(r->min,
                 simplify(r->min + r->extent - make_const(r->min.type(), 1))));
  }
713
  return HalideIR::Internal::simplify(a, true, rscope);
714 715 716 717
}


// Simplify an expression under the given variable ranges.
Expr Simplify(Expr a, Map<Var, Range> vrange) {
  // We should not pass an expression having a non-HalideIR op to
  // Halide::Internal::simplify. Reduce op is the only such op at this time
  // and it only appears as the top op in an expression. So we strip it
  // first and send the sub-expressions to the simplifier.
  const Reduce* r = a.as<Reduce>();
  if (r == nullptr) {
    return Simplify_(a, vrange);
  }
  Array<Expr> new_source;
  for (auto& e : r->source) {
    new_source.push_back(Simplify_(e, vrange));
  }
  Expr new_condition = Simplify_(r->condition, vrange);
  const bool unchanged = r->source.same_as(new_source) &&
      r->condition.same_as(new_condition);
  if (unchanged) {
    return a;
  }
  return Reduce::make(
          r->combiner, new_source, r->axis, new_condition, r->value_index);
}

// Simplify a statement under the given variable ranges.
Stmt Simplify(Stmt a, Map<Var, Range> vrange) {
  return Simplify_(a, vrange);
}
742 743
}  // namespace ir
}  // namespace tvm