/*!
 *  Copyright (c) 2017 by Contributors
 * \file canonical.cc
 * \brief Canonicalize simplification.
 */
#include <tvm/ir_mutator.h>
#include <tvm/arithmetic.h>
#include <tvm/ir_pass.h>
#include <algorithm>
#include <map>
#include <limits>
#include <vector>
#include "canonical.h"
#include "compute_expr.h"
#include "arithmetic/Simplify.h"

namespace tvm {
namespace arith {
using namespace ir;

// Canonical entry for commutative ops.
struct ComExprEntry {
  // The expression held by this entry.
  Expr value;
  // The level of the expression.
  int level{0};
  // Integer factor applied to value.
  int64_t scale{1};

  ComExprEntry() {}
  ComExprEntry(Expr value, int level)
      : value(value), level(level) {}
  // Ordering of entries: by level first, then by the type index of the
  // top-level node, and finally by a full comparison of the expression trees.
  // The scale does not take part in the ordering.
  inline bool operator<(const ComExprEntry& other) const {
    if (level != other.level) return level < other.level;
    // fast check: order by the kind of the top operator.
    const auto this_index = value.type_index();
    const auto that_index = other.value.type_index();
    if (this_index != that_index) return this_index < that_index;
    // slow check: nothing above separated the entries, so compare the trees.
    const int order = Compare(value, other.value);
    if (order != 0) return order < 0;
    // Identical entries are unexpected here; they should have been merged earlier.
    LOG(WARNING) << "we should not have identical entries at this point";
    return false;
  }
};

// Canonical expression for a commutative (sum) expression.
// It stands for: base + sum over elem of (scale * value).
struct ComExprNode : public NodeBase {
  // base constant value.
  int64_t base{0};
  // The values to be summed.
  std::vector<ComExprEntry> elem;
};

// canonical communicative expression
struct ComExpr {
 public:
  // constructor
  ComExpr() {}
  explicit ComExpr(NodePtr<ComExprNode> ptr) : ptr_(ptr) {}
  // get member
  ComExprNode* operator->() const {
    return ptr_.get();
  }
  void reset() {
    ptr_.reset();
  }
  bool defined() const {
    return ptr_.get() != nullptr;
  }
  // comparator
  bool operator<(const ComExpr& b) const {
    const ComExpr& a = *this;
    if (a->base < b->base) return true;
    if (a->base > b->base) return false;
    if (a->elem.size() < b->elem.size()) return true;
    if (a->elem.size() > b->elem.size()) return false;
    for (size_t i = 0; i < a->elem.size(); ++i) {
      const ComExprEntry& ea = a->elem[i];
      const ComExprEntry& eb = b->elem[i];
      if (ea.level < eb.level) return true;
      if (ea.level > eb.level) return false;
      if (ea.value.get() < eb.value.get()) return true;
      if (ea.value.get() > eb.value.get()) return false;
      if (ea.scale < eb.scale) return true;
      if (ea.scale > eb.scale) return false;
    }
    return false;
  }
  // equality
  bool operator==(const ComExpr& b) const {
    const ComExpr& a = *this;
    if (a->base != b->base) return false;
    if (a->elem.size() != b->elem.size()) return false;
    for (size_t i = 0; i < a->elem.size(); ++i) {
      const ComExprEntry& ea = a->elem[i];
      const ComExprEntry& eb = b->elem[i];
      if (ea.level != eb.level) return false;
      if (ea.value.get() != eb.value.get()) return false;
      if (ea.scale != eb.scale) return false;
    }
    return true;
  }

 private:
  NodePtr<ComExprNode> ptr_;
};

// Key describing a binary operation: the kind of node plus both operands.
// Operands are identified by the address of their node.
struct BinaryExpr {
  int kind;
  Expr lhs, rhs;
  // ordering on (kind, lhs node address, rhs node address)
  bool operator<(const BinaryExpr& b) const {
    if (kind != b.kind) return kind < b.kind;
    if (lhs.get() != b.lhs.get()) return lhs.get() < b.lhs.get();
    return rhs.get() < b.rhs.get();
  }
  // equal when the kind matches and both operands are the same nodes
  bool operator==(const BinaryExpr& b) const {
    if (kind != b.kind) return false;
    if (!lhs.same_as(b.lhs)) return false;
    return rhs.same_as(b.rhs);
  }
};


// Return e itself when (a, b) are the very operands already stored in op,
// otherwise build a fresh node of type T from a and b.
template<typename T>
inline Expr Binary_(const T* op,
                    const Expr& e,
                    Expr a, Expr b) {
  const bool operands_changed = !a.same_as(op->a) || !b.same_as(op->b);
  if (operands_changed) {
    return T::make(a, b);
  }
  return e;
}

// internal of canonical engine.
class Canonical::Internal : public IRMutator {
 public:
  // Register every variable of vrange with its range, at level 0.
  explicit Internal(Map<Var, Range> vrange) {
    for (auto kv : vrange) {
      Var var = kv.first;
      Range range = kv.second;
      SetRange(var, range, 0);
    }
  }
  // stack entry: information gathered while one expression is being mutated.
  struct StackEntry {
    // level recorded for the expression (set from var_level_ for variables and
    // raised to the maximum of the children when they are popped).
    int max_level{0};
    // set when an impure call is met inside the expression.
    bool has_side_effect{false};
  };
  // aggressively canonicalized expression
  struct CacheEntry {
    // The canonical value of the expression.
    Expr value;
    // The level of the expression.
    int max_level{0};
    // whether the expression might have side effect.
    bool has_side_effect{false};
    // if defined, the sum that value corresponds to
    ComExpr sum;
    // reset the return entry (only the sum is cleared).
    void reset() {
      sum.reset();
    }
    // View the entry as a sum: reuse `sum` when present, otherwise wrap value,
    // either as the constant base or as a single entry with scale 1.
    ComExpr AsSum() const {
      if (sum.defined()) return sum;
      auto n = make_node<ComExprNode>();
      if (const int64_t* v1 = as_const_int(value)) {
        n->base = *v1;
      } else if (const uint64_t* v2 = as_const_uint(value)) {
        // the unsigned constant has to fit into the signed base
        CHECK_LE(*v2,
               static_cast<uint64_t>(std::numeric_limits<int64_t>::max()));
        n->base = static_cast<int64_t>(*v2);
      } else {
        n->elem.push_back(ComExprEntry(value, max_level));
      }
      return ComExpr(n);
    }
  };
  // Record the range and the level of variable v, and keep a reference to v.
  void SetRange(Var v, Range r, int level) {
    const Variable* key = v.get();
    var_range_[key] = IntSet::range(r);
    var_level_[key] = level;
    var_rec_.push_back(v);
  }
  // functions
  // Statements are simply forwarded to the base mutator.
  Stmt Mutate(Stmt stmt) final {
    return IRMutator::Mutate(stmt);
  }
  // Mutate expr with a fresh stack entry. When the entry is popped, its level
  // and side-effect flag are folded into the parent entry (if any) and copied
  // into ret_entry_. An IntImm result is additionally routed through the
  // IntImm handler.
  Expr MutateExpr_(Expr expr) {
    stack_.push_back(StackEntry());
    expr = IRMutator::Mutate(expr);
    const StackEntry top = stack_.back();
    // propagate the result to the parent entry before popping
    if (stack_.size() > 1) {
      StackEntry& parent = stack_[stack_.size() - 2];
      parent.max_level = std::max(parent.max_level, top.max_level);
      if (top.has_side_effect) parent.has_side_effect = true;
    }
    // publish the result of this expression
    ret_entry_.has_side_effect = top.has_side_effect;
    ret_entry_.max_level = top.max_level;
    stack_.pop_back();
    CHECK(expr.defined());
    const IntImm* op = expr.as<IntImm>();
    if (op == nullptr) {
      return expr;
    }
    return Mutate_(op, expr);
  }
  // Mutate expr and return a copy of the resulting cache entry.
  // ret_entry_ is reset both before the mutation and after the copy is taken.
  CacheEntry Produce(Expr expr) {
    ret_entry_.reset();
    ret_entry_.value = MutateExpr_(expr);
    CacheEntry snapshot(ret_entry_);
    ret_entry_.reset();
    return snapshot;
  }
  // Mutate an expression; ret_entry_ is reset before and after so no sum leaks
  // out to the caller.
  Expr Mutate(Expr expr) final {
    ret_entry_.reset();
    Expr result = MutateExpr_(expr);
    ret_entry_.reset();
    return result;
  }

  // Special canonicalization is only applied to scalar (single lane)
  // signed or unsigned integer types.
  bool EnableOpt(Type t) const {
    if (t.lanes() != 1) return false;
    return t.is_int() || t.is_uint();
  }
  // Max
  Expr Mutate_(const Max* op, const Expr& e) final {
    CacheEntry lhs = Produce(op->a);
    CacheEntry rhs = Produce(op->b);
    const bool pure = !lhs.has_side_effect && !rhs.has_side_effect;
    if (pure) {
      // go through the cache of binary expressions
      return Binary(op, e);
    }
    return Binary_(op, e, lhs.value, rhs.value);
  }
  // Min
  Expr Mutate_(const Min* op, const Expr& e) final {
    CacheEntry lhs = Produce(op->a);
    CacheEntry rhs = Produce(op->b);
    const bool pure = !lhs.has_side_effect && !rhs.has_side_effect;
    if (pure) {
      // go through the cache of binary expressions
      return Binary(op, e);
    }
    return Binary_(op, e, lhs.value, rhs.value);
  }
  // Add: canonicalized as the sum a + b when the type allows it.
  Expr Mutate_(const Add* op, const Expr& e) final {
    if (!EnableOpt(op->type)) return Binary(op, e);
    CacheEntry lhs = Produce(op->a);
    CacheEntry rhs = Produce(op->b);
    const bool impure = lhs.has_side_effect || rhs.has_side_effect;
    if (impure) {
      // operands with side effects are only rebuilt, not merged
      return Binary_(op, e, lhs.value, rhs.value);
    }
    return SumAdd(lhs, rhs, +1);
  }
  // Sub: canonicalized as the sum a + (-1) * b when the type allows it.
  Expr Mutate_(const Sub* op, const Expr& e) final {
    if (!EnableOpt(op->type)) return Binary(op, e);
    CacheEntry lhs = Produce(op->a);
    CacheEntry rhs = Produce(op->b);
    const bool impure = lhs.has_side_effect || rhs.has_side_effect;
    if (impure) {
      // operands with side effects are only rebuilt, not merged
      return Binary_(op, e, lhs.value, rhs.value);
    }
    return SumAdd(lhs, rhs, -1);
  }
  // Mul: fold constant * constant, scale a sum by a constant operand,
  // and fall back to the generic binary path otherwise.
  Expr Mutate_(const Mul* op, const Expr& e) final {
    if (!EnableOpt(op->type)) return Binary(op, e);
    CacheEntry a = Produce(op->a);
    CacheEntry b = Produce(op->b);
    if (a.has_side_effect || b.has_side_effect) {
      return Binary_(op, e, a.value, b.value);
    }
    const bool a_is_const = is_const(a.value);
    const bool b_is_const = is_const(b.value);
    if (a_is_const && b_is_const) {
      return ComputeExpr<Mul>(a.value, b.value);
    }
    if (a_is_const) {
      return SumMulConst(b.AsSum(), a.value);
    }
    if (b_is_const) {
      return SumMulConst(a.AsSum(), b.value);
    }
    return Binary(op, e);
  }
  // Variable: record the level of a known variable on the current stack entry.
  Expr Mutate_(const Variable* op, const Expr& e) final {
    auto known = var_level_.find(op);
    const bool has_level = known != var_level_.end();
    if (has_level) {
      stack_.back().max_level = known->second;
    }
    return IRMutator::Mutate_(op, e);
  }
  // comparison
  // a < b is replaced by the constant true when b - a can be proven to be
  // strictly positive under the known variable ranges; otherwise the
  // comparison is rebuilt from the canonicalized operands.
  Expr Mutate_(const LT* op, const Expr& e) {
    if (!EnableOpt(op->a.type())) {
      return Binary(op, e);
    }
    CacheEntry a = Produce(op->a);
    CacheEntry b = Produce(op->b);
    if (a.has_side_effect || b.has_side_effect) {
      return Binary_(op, e, a.value, b.value);
    }
    Expr b_sub_a = SumAdd(b, a, -1);
    // SumAdd leaves the canonical sum of (b - a) in ret_entry_. That sum does
    // not describe the comparison returned below, so it must be dropped:
    // otherwise a caller going through Produce() would receive an entry whose
    // AsSum() is b - a instead of the comparison itself.
    ret_entry_.reset();
    if (EvalSet(b_sub_a, var_range_).can_prove_positive()) {
      return make_const(op->type, true);
    } else {
      return Binary_(op, e, a.value, b.value);
    }
  }
  // IntImm: Int(32) constants are unified so that equal values share one node.
  // The first node seen for a value is remembered and returned from then on.
  Expr Mutate_(const IntImm* op, const Expr& e) final {
    if (op->type != Int(32)) return e;
    // emplace leaves an already cached node untouched and inserts e otherwise
    return cache_intimm_.emplace(op->value, e).first->second;
  }
  // Div operator: fold constant / constant, simplify sum / constant,
  // and fall back to the generic binary path otherwise.
  Expr Mutate_(const Div* op, const Expr& e) final {
    if (!EnableOpt(op->type)) return Binary(op, e);
    CacheEntry a = Produce(op->a);
    CacheEntry b = Produce(op->b);
    if (a.has_side_effect || b.has_side_effect) {
      return Binary_(op, e, a.value, b.value);
    }
    const bool divisor_is_const = is_const(b.value);
    if (divisor_is_const && is_const(a.value)) {
      return ComputeExpr<Div>(a.value, b.value);
    }
    if (divisor_is_const) {
      return SumDivConst(a.AsSum(), b.value);
    }
    return Binary(op, e);
  }
  // Mod operator: fold constant % constant, simplify sum % constant,
  // and fall back to the generic binary path otherwise.
  Expr Mutate_(const Mod* op, const Expr& e) final {
    if (!EnableOpt(op->type)) return Binary(op, e);
    CacheEntry a = Produce(op->a);
    CacheEntry b = Produce(op->b);
    if (a.has_side_effect || b.has_side_effect) {
      return Binary_(op, e, a.value, b.value);
    }
    const bool divisor_is_const = is_const(b.value);
    if (divisor_is_const && is_const(a.value)) {
      return ComputeExpr<Mod>(a.value, b.value);
    }
    if (divisor_is_const) {
      return SumModConst(a.AsSum(), b.value);
    }
    return Binary(op, e);
  }

  // And: once both sides are mutated, an operand equal to one is dropped
  // and the other operand is returned.
  Expr Mutate_(const And* op, const Expr& e) final {
    Expr expr = IRMutator::Mutate_(op, e);
    const And* mutated = expr.as<And>();
    if (is_one(mutated->a)) {
      return mutated->b;
    }
    if (is_one(mutated->b)) {
      return mutated->a;
    }
    return expr;
  }
  // Call: an impure call flags the current stack entry as having a side
  // effect; a `likely` intrinsic around a constant is replaced by the constant.
  Expr Mutate_(const Call* op, const Expr& e) final {
    if (!op->is_pure()) {
      stack_.back().has_side_effect = true;
    }
    Expr expr = IRMutator::Mutate_(op, e);
    const Call* call = expr.as<Call>();
    const bool likely_of_const =
        call->is_intrinsic(Call::likely) && is_const(call->args[0]);
    if (likely_of_const) {
      return call->args[0];
    }
    return expr;
  }
  // For: the loop variable gets the range [min, min + extent) and a level one
  // deeper than the enclosing scope for the duration of the loop body.
  Stmt Mutate_(const For* op, const Stmt& s) {
    ++level_counter_;
    Var loop_var(op->loop_var.node_);
    Range loop_range = Range::make_by_min_extent(op->min, op->extent);
    this->SetRange(loop_var, loop_range, level_counter_);
    Stmt stmt = IRMutator::Mutate_(op, s);
    --level_counter_;
    return stmt;
  }
  // IfThenElse: when the mutated condition is the constant one,
  // only the then-branch is kept.
  Stmt Mutate_(const IfThenElse* op, const Stmt& s) {
    Stmt stmt = IRMutator::Mutate_(op, s);
    const IfThenElse* mutated = stmt.as<IfThenElse>();
    if (is_one(mutated->condition)) {
      return mutated->then_case;
    }
    return stmt;
  }
  // AttrStmt: thread_extent / virtual_thread attributes open a new level and
  // give the thread variable the range [0, value) unless it already has a level.
  Stmt Mutate_(const AttrStmt* op, const Stmt& s) {
    const bool is_thread_attr =
        op->attr_key == attr::thread_extent ||
        op->attr_key == attr::virtual_thread;
    if (!is_thread_attr) {
      return IRMutator::Mutate_(op, s);
    }
    ++level_counter_;
    IterVar iv(op->node.node_);
    CHECK_NE(iv->thread_tag.length(), 0U);
    const bool already_known = var_level_.count(iv->var.get()) != 0;
    if (!already_known) {
      this->SetRange(iv->var,
                     Range::make_by_min_extent(0, op->value),
                     level_counter_);
    }
    Stmt stmt = IRMutator::Mutate_(op, s);
    --level_counter_;
    return stmt;
  }
  // Accessor for the function-local static FMutateExpr table.
  static FMutateExpr& vtable_expr() {  // NOLINT(*)
    static FMutateExpr inst;
    return inst;
  }

 private:
  // Generic binary path: mutate both operands, then look the (kind, a, b)
  // triple up in cache_binary_ so that repeated operations share one node.
  template<typename T>
  Expr Binary(const T* op, Expr e) {
    Expr lhs = this->Mutate(op->a);
    Expr rhs = this->Mutate(op->b);
    BinaryExpr key{static_cast<int>(T::_type_info), lhs, rhs};
    auto cached = cache_binary_.find(key);
    if (cached != cache_binary_.end()) {
      return cached->second;
    }
    // first occurrence: build (or reuse e) and remember the result
    Expr result = Binary_(op, e, lhs, rhs);
    cache_binary_[key] = result;
    return result;
  }
  // return entry: filled by the Sum* subroutines and MutateExpr_,
  // and read back by Produce.
  CacheEntry ret_entry_;
  // internal information stack, one entry per expression being mutated
  std::vector<StackEntry> stack_;
  // cache of sums, keyed by their canonical form
  std::map<ComExpr, CacheEntry> cache_sum_;
  // cache of normal binary op, keyed by (kind, lhs, rhs)
  std::map<BinaryExpr, Expr> cache_binary_;
  // cache of int constant, keyed by value (only Int(32) constants are stored)
  std::unordered_map<int64_t, Expr> cache_intimm_;
  // range of each var
  std::unordered_map<const Variable*, IntSet> var_range_;
  // level of each var
  std::unordered_map<const Variable*, int> var_level_;
  // record history vars, to avoid false positive.
  std::vector<Var> var_rec_;
  // level counter, incremented while inside For / thread attribute scopes
  int level_counter_{0};
  // Return the value of a constant integer expression as int64_t.
  // Fails a CHECK if v is neither a signed nor an unsigned constant, or if an
  // unsigned constant does not fit into int64_t.
  int64_t GetConstIntValue(const Expr& v) {
    const int64_t *v1 = as_const_int(v);
    const uint64_t *v2 = as_const_uint(v);
    CHECK(v1 || v2);
    if (v1) {
      return *v1;
    }
    if (v2) {
      CHECK_LE(*v2,
               static_cast<uint64_t>(std::numeric_limits<int64_t>::max()));
      return static_cast<int64_t>(*v2);
    }
    return 0;
  }
  // Detect if a = q * coeff + r, where r \in [0, coeff), coeff > 0
  // (in Euclidean division)
  // returns pair (q, r) if such detection is successful
  // returns empty vector otherwise.
  // Assumes that coeff is a constant integer.
  // A zero coeff fails a CHECK; a negative coeff yields the empty vector.
  std::vector<ComExpr> TryLinearEquation(const ComExpr& a,
                                         const Expr& coeff) {
    Type type = coeff.type();
    int64_t value = GetConstIntValue(coeff);
    CHECK_NE(value, 0);
    if (value < 0) return {};
    // Given that denominator (value variable) is positive, truncated division
    // (i.e., TVM's division semantics) is equivalent to Euclidean division if and only if
    // numerator is non-negative or numerator is divisible by denominator (i.e., value)
    IntSet numerator_int_set = EvalSet(Sum2Expr(a, type), var_range_);
    bool numerator_is_non_neg = numerator_int_set.can_prove_non_negative();
    // Try to separate terms of a into ones that can be proven to be
    // divisible by coeff and ones that are not
    // We will build q and r from divisible and non_divisible respectively
    auto divisible = make_node<ComExprNode>();
    auto non_divisible = make_node<ComExprNode>();
    // the constant base goes to whichever side its divisibility dictates
    if (a->base % value == 0) {
      divisible->base = a->base;
    } else {
      non_divisible->base = a->base;
    }
    // a term is divisible when its integer scale is a multiple of value
    for (const auto& e : a->elem) {
      if (e.scale % value == 0) {
        divisible->elem.push_back(e);
      } else {
        non_divisible->elem.push_back(e);
      }
    }
    bool non_divisible_is_simplified = false;
    // only meaningful (and only read) when non_divisible_is_simplified is true
    int64_t div_result;
    Expr non_divisible_res = Sum2Expr(ComExpr(non_divisible), type);
    // if non_divisible part consists of only an integer and numerator is non-negative,
    // we can simply divide it by coeff
    if (is_const(non_divisible_res)) {
      int64_t non_divisible_const = GetConstIntValue(non_divisible_res);
      if (numerator_is_non_neg || non_divisible_const == 0) {
        non_divisible_is_simplified = true;
        // We need to do an Euclidean division here because (a*b + c)/b == a + c/b
        // holds true only if division is Euclidean
        div_result = HalideIR::Internal::div_imp(non_divisible_const , value);
      }
    } else {
      // If we can prove that non_divisible part lies within [0, coeff), then
      // non_divisible itself will be our r
      IntSet non_divisible_set = EvalSet(non_divisible_res, var_range_);
      // only reason about the bounds when both of them have the expected type
      if (non_divisible_set.min().type() == type &&
          non_divisible_set.max().type() == type) {
        // accepted when the part is provably the single value 0, or when the
        // numerator is non-negative and the part is provably in [0, coeff)
        if ( (non_divisible_set.is_single_point() &&
              can_prove(non_divisible_set.point_value() == 0)) ||
             (numerator_is_non_neg &&
              can_prove(non_divisible_set.min() >= make_zero(type)) &&
              can_prove(non_divisible_set.max() < coeff)) ) {
          non_divisible_is_simplified = true;
          div_result = 0;
        }
      }
    }
    if (non_divisible_is_simplified) {
      // move the quotient of the constant remainder part over to q,
      // then divide every part of q by value
      non_divisible->base -= div_result * value;
      divisible->base /= value;
      divisible->base += div_result;
      for (auto& e : divisible->elem) {
        e.scale /= value;
      }
      return {ComExpr(divisible), ComExpr(non_divisible)};
    } else {
      return {};
    }
  }
  // subroutine to produce a % v, where v is a constant.
  // When a can be split as q * v + r the remainder r is returned through the
  // sum cache; otherwise base and scales are reduced modulo v and an explicit
  // Mod is built.
  Expr SumModConst(ComExpr a, Expr v) {
    std::vector<ComExpr> pair = TryLinearEquation(a, v);
    if (pair.empty()) {
      const int64_t value = GetConstIntValue(v);
      auto n = make_node<ComExprNode>();
      // FIXME(derisavi) : The following can be done only for Euclidean division/mod.
      //  Therefore, it's only valid when truncated division/mod is equivalent to Euclidean one,
      //  that is, if and only if a and v are
      //  both negative or both positive or a is divisible by v.
      //  Extend the code to handle cases where the above condition is not satisfied, i.e.,
      //  a and v are of different signs and a is not divisible by v.
      n->base = a->base % value;
      for (ComExprEntry e : a->elem) {
        const int64_t reduced_scale = e.scale % value;
        // terms whose scale is a multiple of value vanish
        if (reduced_scale == 0) continue;
        e.scale = reduced_scale;
        n->elem.push_back(e);
      }
      Expr ret = Sum2Expr(ComExpr(n), v.type()) % v;
      const Mod* mod = ret.as<Mod>();
      if (mod != nullptr) {
        return Binary(mod, ret);
      }
      // Sometimes the result is a constant, this may happen when value is -1
      CHECK(is_const(ret)) << "CanonicalSimplify: "
        << Sum2Expr(ComExpr(n), v.type()) << " % " << v << " is " << ret
        << " which is neither Mod, nor a constant";
      return ret;
    }
    // the remainder is the second element of the pair
    const StackEntry& top = stack_.back();
    ret_entry_.sum = pair[1];
    ret_entry_.max_level = top.max_level;
    ret_entry_.has_side_effect = top.has_side_effect;
    auto cached = cache_sum_.find(ret_entry_.sum);
    if (cached == cache_sum_.end()) {
      ret_entry_.value = Sum2Expr(ret_entry_.sum, v.type());
      cache_sum_[ret_entry_.sum] = ret_entry_;
    } else {
      ret_entry_ = cached->second;
    }
    return ret_entry_.value;
  }
  // subroutine to produce a / v, where v is a constant.
  // When a can be split as q * v + r the quotient q is returned through the
  // sum cache; otherwise an explicit division is built.
  Expr SumDivConst(ComExpr a, Expr v) {
    std::vector<ComExpr> pair = TryLinearEquation(a, v);
    if (pair.size() == 0) {
      Expr ret = Sum2Expr(a, v.type()) / v;
      if (const Div* div = ret.as<Div>()) {
        return Binary(div, ret);
      }
      // The division did not yield a Div node. Hand the result back as is
      // instead of dereferencing a null pointer in Binary (SumModConst guards
      // the analogous case for Mod).
      return ret;
    }
    // the quotient is the first element of the pair
    ret_entry_.sum = pair[0];
    ret_entry_.max_level = stack_.back().max_level;
    ret_entry_.has_side_effect = stack_.back().has_side_effect;
    auto it = cache_sum_.find(ret_entry_.sum);
    if (it != cache_sum_.end()) {
      ret_entry_ = it->second;
    } else {
      ret_entry_.value = Sum2Expr(ret_entry_.sum, v.type());
      cache_sum_[ret_entry_.sum] = ret_entry_;
    }
    return ret_entry_.value;
  }
  // subroutine to produce a * v, where v is a constant:
  // the base and every scale of a copy of a are multiplied by v.
  Expr SumMulConst(ComExpr a, Expr v) {
    const int64_t value = GetConstIntValue(v);
    if (value == 0) {
      return make_zero(v.type());
    }
    // work on a copy so that a itself is left untouched
    auto vsum = make_node<ComExprNode>(*a.operator->());
    vsum->base *= value;
    for (ComExprEntry& entry : vsum->elem) {
      entry.scale *= value;
    }
    const StackEntry& top = stack_.back();
    ret_entry_.sum = ComExpr(vsum);
    ret_entry_.max_level = top.max_level;
    ret_entry_.has_side_effect = top.has_side_effect;
    auto cached = cache_sum_.find(ret_entry_.sum);
    if (cached == cache_sum_.end()) {
      ret_entry_.value = Sum2Expr(ret_entry_.sum, v.type());
      cache_sum_[ret_entry_.sum] = ret_entry_;
    } else {
      ret_entry_ = cached->second;
    }
    return ret_entry_.value;
  }
  // Compute suma + bscale * sumb as a new ComExpr.
  // The two entry lists are merged in order; entries holding the same value
  // node at the same level are combined and dropped when their scale cancels.
  ComExpr SumAdd_(const ComExpr& suma,
                  const ComExpr& sumb,
                  int bscale) {
    auto n = make_node<ComExprNode>();
    n->base = suma->base + sumb->base * bscale;
    const std::vector<ComExprEntry>& lhs = suma->elem;
    const std::vector<ComExprEntry>& rhs = sumb->elem;
    size_t i = 0, j = 0;
    // merge while both lists still have entries
    while (i < lhs.size() && j < rhs.size()) {
      const ComExprEntry& a = lhs[i];
      const ComExprEntry& b = rhs[j];
      if (a.value.same_as(b.value) && a.level == b.level) {
        ComExprEntry merged = a;
        merged.scale = a.scale + b.scale * bscale;
        if (merged.scale != 0) {
          n->elem.push_back(merged);
        }
        ++i;
        ++j;
        continue;
      }
      if (a < b) {
        n->elem.push_back(a);
        ++i;
      } else {
        ComExprEntry scaled = b;
        scaled.scale *= bscale;
        n->elem.push_back(scaled);
        ++j;
      }
    }
    // leftovers of suma are taken as they are
    while (i < lhs.size()) {
      n->elem.push_back(lhs[i]);
      ++i;
    }
    // leftovers of sumb still have to be scaled
    while (j < rhs.size()) {
      ComExprEntry scaled = rhs[j];
      scaled.scale *= bscale;
      n->elem.push_back(scaled);
      ++j;
    }
    return ComExpr(n);
  }
  // subroutine to produce a + bscale * b.
  // The merged sum is stored in ret_entry_ and resolved through cache_sum_,
  // so that equal sums map to the same expression.
  Expr SumAdd(CacheEntry a, CacheEntry b, int bscale) {
    ret_entry_.sum = SumAdd_(a.AsSum(), b.AsSum(), bscale);
    CHECK_NE(stack_.size(), 0U);
    const StackEntry& top = stack_.back();
    ret_entry_.max_level = top.max_level;
    ret_entry_.has_side_effect = top.has_side_effect;
    auto cached = cache_sum_.find(ret_entry_.sum);
    if (cached == cache_sum_.end()) {
      ret_entry_.value = Sum2Expr(ret_entry_.sum, a.value.type());
      cache_sum_[ret_entry_.sum] = ret_entry_;
    } else {
      ret_entry_ = cached->second;
    }
    return ret_entry_.value;
  }
  // Convert a sum to an expression of type t.
  // Emission order: positive base, terms with positive scale (added), negative
  // base, terms with negative scale (subtracted). An empty sum gives zero.
  Expr Sum2Expr(const ComExpr& com, Type t) {
    Expr vsum;
    if (com->base > 0) {
      vsum = make_const(t, com->base);
    }
    // terms that are added
    for (const ComExprEntry& term : com->elem) {
      if (term.scale <= 0) continue;
      Expr piece = term.value;
      if (term.scale != 1) {
        piece = Mul::make(piece, make_const(t, term.scale));
      }
      if (!vsum.defined()) {
        vsum = piece;
      } else {
        vsum = Add::make(vsum, piece);
      }
    }
    // a negative base is subtracted, unless nothing has been emitted yet
    if (com->base < 0) {
      if (!vsum.defined()) {
        vsum = make_const(t, com->base);
      } else {
        vsum = Sub::make(vsum, make_const(t, -com->base));
      }
    }
    // terms that are subtracted
    for (const ComExprEntry& term : com->elem) {
      if (term.scale >= 0) continue;
      Expr piece = term.value;
      if (term.scale != -1) {
        piece = Mul::make(piece, make_const(t, -term.scale));
      }
      if (!vsum.defined()) {
        vsum = Sub::make(make_zero(t), piece);
      } else {
        vsum = Sub::make(vsum, piece);
      }
    }
    if (!vsum.defined()) {
      return make_zero(t);
    }
    return vsum;
  }
};

// Short alias for the implementation class of Canonical.
using CInternal = Canonical::Internal;

// Create the shared implementation object, initialized with the ranges in vrange.
Canonical::Canonical(Map<Var, Range> vrange)
    : ptr_(std::make_shared<Internal>(vrange)) {}

// Canonicalize an expression by running the internal mutator over it.
Expr Canonical::Simplify(Expr expr) {
  Expr simplified = ptr_->Mutate(expr);
  return simplified;
}

// Canonicalize a statement by running the internal mutator over it.
Stmt Canonical::Simplify(Stmt stmt) {
  Stmt simplified = ptr_->Mutate(stmt);
  return simplified;
}

// Forward the range and level of variable v to the implementation object.
void Canonical::SetRange(Var v, Range r, int level) {
  Internal* impl = ptr_.get();
  impl->SetRange(v, r, level);
}
}  // namespace arith

namespace ir {

// Canonical simplification of a statement, given the ranges of free variables.
Stmt CanonicalSimplify(Stmt stmt, Map<Var, Range> vrange) {
  arith::Canonical canonical(vrange);
  return canonical.Simplify(stmt);
}

// Canonical simplification of an expression, given the ranges of free variables.
Expr CanonicalSimplify(Expr expr, Map<Var, Range> vrange) {
  arith::Canonical canonical(vrange);
  return canonical.Simplify(expr);
}

template<typename T>
T Simplify_(T a, Map<Var, Range> vrange) {
  using namespace HalideIR::Internal;
  Scope<Interval> rscope;
  for (auto kv : vrange) {
    Range r = kv.second;
    rscope.push(
        kv.first.get(),
        Interval(r->min,
                 simplify(r->min + r->extent - make_const(r->min.type(), 1))));
  }
  return HalideIR::Internal::simplify(a, true, rscope);
}


/*!
 * \brief Simplify just the combiner of the given reduce node.
 *
 *  Simplify is applied to the result and identity expressions of the top
 *  reduction's combiner, but not to the source or the condition of the
 *  reduction. Components which are not needed to compute the resulting
 *  value (the value_index-th value) and have no side effects are removed.
 *
 *  If \p expr is not a reduction node, it is returned unchanged.
 *
 * \param expr The expression to be simplified.
 * \param vrange The ranges of variables passed on to Simplify.
 * \return Simplified expression.
 */
Expr SimplifyCombiner(const Expr& expr, const Map<Var, Range>& vrange = Map<Var, Range>()) {
  const Reduce* op = expr.as<Reduce>();
  if (op == nullptr) {
    return expr;
  }

  // Step 1: simplify every result expression of the combiner
  Array<Expr> simplified_result;
  for (const auto& res : op->combiner->result) {
    simplified_result.push_back(Simplify(res, vrange));
  }

  // Step 2: work out which components have to be kept
  std::vector<int> used(op->combiner->result.size(), false);

  // Marks component idx as used and, recursively, every component whose
  // lhs or rhs variable appears in the idx-th simplified result.
  std::function<void(int)> mark_used;
  mark_used = [&used, &simplified_result, op, &mark_used](size_t idx) {
    // already visited
    if (used[idx]) return;
    used[idx] = true;

    for (size_t i = 0; i < simplified_result.size(); ++i) {
      if (used[i]) continue;
      const bool referenced =
          ExprUseVar(simplified_result[idx], op->combiner->lhs[i]) ||
          ExprUseVar(simplified_result[idx], op->combiner->rhs[i]);
      if (referenced) {
        mark_used(i);
      }
    }
  };

  // everything reachable from the requested value is needed
  mark_used(op->value_index);

  // components with side effects are kept as well
  for (size_t i = 0; i < used.size(); ++i) {
    const bool has_effect =
        HasSideEffect(op->source[i]) ||
        HasSideEffect(op->combiner->identity_element[i]) ||
        HasSideEffect(op->combiner->result[i]);
    if (has_effect) {
      mark_used(i);
    }
  }

  // Step 3: rebuild the reduction from the kept components
  int new_value_index = op->value_index;
  Array<Expr> new_result;
  Array<Expr> new_identity;
  Array<Var> new_lhs;
  Array<Var> new_rhs;
  Array<Expr> new_source;

  for (size_t i = 0; i < used.size(); ++i) {
    if (!used[i]) {
      // a dropped component in front of value_index shifts the index down
      if (static_cast<int>(i) < op->value_index) {
        new_value_index--;
      }
      continue;
    }
    // result and identity are simplified, the source is taken as it is
    new_result.push_back(simplified_result[i]);
    new_identity.push_back(Simplify(op->combiner->identity_element[i], vrange));
    new_lhs.push_back(op->combiner->lhs[i]);
    new_rhs.push_back(op->combiner->rhs[i]);
    new_source.push_back(op->source[i]);
  }

  CommReducer new_combiner = CommReducerNode::make(new_lhs, new_rhs, new_result, new_identity);
  return Reduce::make(new_combiner, new_source, op->axis, op->condition, new_value_index);
}

/*!
 * \brief Remove a single reduction over empty axis.
 *
 *  A reduction node with an empty axis is replaced by a Select between its
 *  value_index-th source (when the condition holds) and the corresponding
 *  identity element (otherwise). Anything else is returned unchanged.
 *
 * \param e The expression to be transformed.
 * \return The transformed expression.
 */
Expr RemoveEmptyReduction(const Expr& e) {
  const Reduce* r = e.as<Reduce>();
  if (r == nullptr || !r->axis.empty()) {
    return e;
  }
  // Note that here we assume that the identity element is indeed identity. Without this
  // assumption we would have to perform a single iteration of the loop, i.e. use
  // `(*r->combiner.get())(r->combiner->identity_element, r->source)[r->value_index]`
  // instead of `r->source[r->value_index]`. The former may be more difficult to simplify.
  Expr when_true = r->source[r->value_index];
  Expr when_false = r->combiner->identity_element[r->value_index];
  return Select::make(r->condition, when_true, when_false);
}

// Simplify an expression under the given variable ranges.
// A top-level Reduce is handled here: its combiner is simplified, the ranges
// of its reduction axes are added to vrange, and its sources and condition
// are simplified individually. Everything else goes straight to Simplify_.
Expr Simplify(Expr a, Map<Var, Range> vrange) {
  // We should not pass an expression having a non-HalideIR op to
  // Halide::Internal::simplify. Reduce op is the only such op at this time
  // and it only appears as the top op in an expression. So we strip it
  // first and send the sub-expressions to the simplifier.
  if (const Reduce* r = a.as<Reduce>()) {
    // If axis is empty, we can remove the reduce op completely.
    if (r->axis.empty())
      return Simplify_(RemoveEmptyReduction(a), vrange);

    // Simplify the combiner of the reduction
    a = SimplifyCombiner(a, vrange);
    r = a.as<Reduce>();

    // If axis is not empty then we add the information about ranges to vrange
    for (const IterVar& iv : r->axis) {
      if (vrange.count(iv->var)) {
        Range existing_range = vrange[iv->var];
        CHECK(Equal(existing_range->min, iv->dom->min) &&
              Equal(existing_range->extent, iv->dom->extent))
          << "Simplify was given vrange stating that the range of the reduction var "
          << iv << " is " << existing_range << ". This is probably a mistake.";
      }
      vrange.Set(iv->var, iv->dom);
    }

    // Track changes element by element: new_source is always a freshly built
    // array, so comparing the arrays themselves with same_as can never report
    // "unchanged".
    bool source_changed = false;
    Array<Expr> new_source;
    for (const Expr& e : r->source) {
      Expr simplified = Simplify_(e, vrange);
      if (!simplified.same_as(e)) {
        source_changed = true;
      }
      new_source.push_back(simplified);
    }
    Expr new_condition = Simplify_(r->condition, vrange);
    if (!source_changed &&
        r->condition.same_as(new_condition)) {
      // nothing was rewritten; keep the existing node
      return a;
    } else {
      return Reduce::make(
              r->combiner, new_source, r->axis, new_condition, r->value_index);
    }
  }
  return Simplify_(a, vrange);
}

// Simplify a statement under the given variable ranges.
Stmt Simplify(Stmt a, Map<Var, Range> vrange) {
  Stmt simplified = Simplify_(a, vrange);
  return simplified;
}
}  // namespace ir
}  // namespace tvm