/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file canonical_simplify.cc
 * \brief Canonical form based simplification.
 */
#include <tvm/arith/analyzer.h>
#include <tvm/tir/op.h>
#include <tvm/tir/analysis.h>

#include "const_fold.h"
#include "pattern_match.h"
#include "rewrite_simplify.h"

namespace tvm {
namespace arith {

using namespace tir;

class SumExpr;
class SplitExpr;


/*!
 * \brief Base class of all temporary expression introduced
 *        for canonicalization.
 */
class CanonicalExprNode : public PrimExprNode {
 public:
  virtual ~CanonicalExprNode() = default;

  /*!
   * \brief Produce the ordinary expression that is equivalent to this
   *        temporary canonical node.
   * \return The ordinary expression.
   */
  virtual PrimExpr Normalize() const = 0;

  /*! \brief Attribute visitor hook; this node exposes no attributes. */
  void VisitAttrs(tvm::AttrVisitor* v) {}

  static constexpr const char* _type_key = "arith.CanonicalExpr";
  TVM_DECLARE_BASE_OBJECT_INFO(CanonicalExprNode, PrimExprNode);
};

enum DivMode {
  /*! \brief Truncated division (lowered through truncdiv / truncmod). */
  kTruncDiv = 0,
  /*! \brief Floor division (lowered through floordiv / floormod). */
  kFloorDiv = 1
};

/*!
 * \brief Build a % b using the requested division mode.
 * \param a The dividend.
 * \param b The divisor.
 * \param mode kTruncDiv selects truncmod, kFloorDiv selects floormod.
 * \return The modulo expression.
 */
inline PrimExpr ModImpl(PrimExpr a, PrimExpr b, DivMode mode) {
  if (mode != kTruncDiv) {
    // Only two modes exist; anything that is not trunc must be floor.
    CHECK_EQ(mode, kFloorDiv);
    return floormod(a, b);
  }
  return truncmod(a, b);
}

/*!
 * \brief Build a / b using the requested division mode.
 * \param a The dividend.
 * \param b The divisor.
 * \param mode kTruncDiv selects truncdiv, kFloorDiv selects floordiv.
 * \return The division expression.
 */
inline PrimExpr DivImpl(PrimExpr a, PrimExpr b, DivMode mode) {
  if (mode != kTruncDiv) {
    // Only two modes exist; anything that is not trunc must be floor.
    CHECK_EQ(mode, kFloorDiv);
    return floordiv(a, b);
  }
  return truncdiv(a, b);
}

/*!
 * \brief Internal "Split normal form" of expression.
 *
 * This is a special expression that represents
 * a scaled value derived from a split of an index.
 *
 * result = ((index % upper_factor) / lower_factor) * scale
 */
class SplitExprNode : public CanonicalExprNode {
 public:
  /*! \brief The base index expression. */
  PrimExpr index;
  /*! \brief Divisor applied to the index (after the modulo, if any). */
  int64_t lower_factor{1};
  /*!
   * \brief Modulus applied to the index; kPosInf means "no modulo".
   * invariant: (upper_factor == kPosInf || upper_factor % lower_factor == 0)
   */
  int64_t upper_factor{kPosInf};
  /*! \brief Constant multiplier applied to the split value. */
  int64_t scale{1};
  /*! \brief Division mode used when emitting the div / mod. */
  DivMode div_mode{kTruncDiv};

  /*! \brief Check the upper_factor / lower_factor invariant. */
  void Verify() const {
    CHECK(upper_factor == kPosInf || upper_factor % lower_factor == 0);
  }

  /*!
   * \brief Emit ((index % upper_factor) / lower_factor) * (scale * sscale).
   * \param sscale Extra multiplier combined with this->scale.
   * \return The ordinary expression; a zero constant when this->scale is 0.
   */
  PrimExpr NormalizeWithScale(int64_t sscale) const {
    DataType dtype = this->dtype;
    // A zero scale wipes out the whole term regardless of sscale.
    if (this->scale == 0) {
      return make_const(dtype, 0);
    }

    PrimExpr res = this->index;
    const bool has_modulo = this->upper_factor != SplitExprNode::kPosInf;
    if (has_modulo) {
      res = ModImpl(res, make_const(dtype, this->upper_factor), div_mode);
    }
    const bool has_division = this->lower_factor != 1;
    if (has_division) {
      res = DivImpl(res, make_const(dtype, this->lower_factor), div_mode);
    }

    sscale *= this->scale;
    if (sscale == 1) {
      return res;
    }
    // Unsigned values cannot be multiplied by a non-positive constant.
    CHECK(!dtype.is_uint() || sscale > 0);
    return res * make_const(dtype, sscale);
  }

  /*! \brief Emit the split expression with its own scale only. */
  PrimExpr Normalize() const final {
    return NormalizeWithScale(1);
  }

  /*! \brief Multiply the stored scale by the given constant. */
  void MulToSelf(int64_t scale) {
    this->scale *= scale;
  }

  inline bool IndexEqual(const SplitExpr& other) const;
  inline bool DivModeCompatibleTo(DivMode mode) const;

  /*! \brief positive infinity marker for upper_factor */
  static const constexpr int64_t kPosInf = ConstIntBoundNode::kPosInf;
  static constexpr const char* _type_key = "arith.SplitExpr";
  TVM_DECLARE_FINAL_OBJECT_INFO(SplitExprNode, CanonicalExprNode);
};

/*!
 * \brief Reference type for SplitExprNode.
 *
 * The two macros below supply the reference-class boilerplate, including the
 * CopyOnWrite() accessor used throughout this file to mutate a node.
 */
class SplitExpr : public PrimExpr {
 public:
  TVM_DEFINE_OBJECT_REF_METHODS(SplitExpr, PrimExpr, SplitExprNode);
  TVM_DEFINE_OBJECT_REF_COW_METHOD(SplitExprNode);
};

/*!
 * \brief Whether this split and other refer to the same index expression.
 *
 * Reference identity is tried first; the structural comparison only runs
 * when the two indices are different objects.
 */
inline bool SplitExprNode::IndexEqual(const SplitExpr& other) const {
  return index.same_as(other->index) ||
         tir::ExprDeepEqual()(index, other->index);
}

/*!
 * \brief Whether this split can be treated as using the given division mode.
 *
 * That holds when the modes already match, or when the split performs
 * neither a division nor a modulo so the mode has no effect.
 */
inline bool SplitExprNode::DivModeCompatibleTo(DivMode mode) const {
  const bool same_mode = (this->div_mode == mode);
  return same_mode || (lower_factor == 1 && upper_factor == kPosInf);
}

/*!
 * \brief Normal form that represents sum of expressions.
 *
 *  result = sum(args) + base.
 */
class SumExprNode : public CanonicalExprNode {
 public:
  /*!
   * \brief arguments to be summed up.
   *
   * args are divided into segments with the same index.
   * within each segment, the SplitExpr is ordered in descending order of lower_factor.
   */
  std::vector<SplitExpr> args;
  /*! \brief Base value in the summation. */
  int64_t base{0};
  /*! \brief The expression equals zero: no arguments and a zero base. */
  bool IsZero() const {
    return base == 0 && args.size() == 0;
  }
  /*!
   * \brief Return the normal Expr that is equivalent to self.
   *
   * The arguments are first merged by SimplifySplitExprs (on a copy, self is
   * not modified) and then emitted by Normalize_.
   *
   * \return The normal expression.
   */
  PrimExpr Normalize() const final {
    // quick path 1: a pure constant.
    if (this->args.size() == 0) {
      return make_const(this->dtype, this->base);
    }
    return Normalize_(this->dtype,
                      SimplifySplitExprs(args),
                      base);
  }
  /*!
   * \brief Whether self is divisible by scale.
   *
   * True only when the base and every argument's scale are multiples of scale.
   *
   * \param scale The scale to be applied.
   */
  bool DivisibleBy(int64_t scale) {
    if (base % scale != 0) return false;
    for (size_t i = 0; i < this->args.size(); ++i) {
      if (args[i]->scale % scale != 0) return false;
    }
    return true;
  }
  /*!
   * \brief mul scale to self.
   * \param scale The scale to be applied.
   */
  void MulToSelf(int64_t scale) {
    this->base *= scale;
    for (size_t i = 0; i < this->args.size(); ++i) {
      args[i].CopyOnWrite()->scale *= scale;
    }
  }
  /*!
   * \brief divide by scale.
   *
   * Fails a CHECK when the base or any argument scale is not a multiple of scale.
   *
   * \param scale The scale to be applied.
   */
  void DivideBy(int64_t scale) {
    CHECK_EQ(this->base % scale, 0);
    this->base /= scale;
    for (size_t i = 0; i < this->args.size(); ++i) {
      CHECK_EQ(args[i]->scale % scale, 0);
      args[i].CopyOnWrite()->scale /= scale;
    }
  }
  /*!
   * \brief add constant value to self.
   * \param value to be added.
   */
  void AddToSelf(int64_t value) {
    this->base += value;
  }
  /*!
   * \brief self += other * scale;
   *
   * If an argument with the same index, factors and a compatible division
   * mode already exists, its scale is updated in place; otherwise other is
   * inserted so that the segment ordering of args is preserved.
   *
   * \param other The expression to be added.
   * \param scale The additional scale on value.
   */
  void AddToSelf(SplitExpr other, int64_t scale) {
    if (other->scale == 0) return;
    // We need to maintain the segment invariant:
    // Same index are stored close to each other.
    // sorted from big lower_factor to small one.
    size_t start = 0;
    // Locate the first argument that shares other's index (start of its segment).
    for (; start < args.size(); ++start) {
      if (args[start]->IndexEqual(other)) break;
    }
    for (size_t j = start; j < args.size(); ++j) {
      // Either the segment ended or we found the slot that keeps
      // lower_factor in descending order: insert here.
      if (!args[j]->IndexEqual(other) ||
          other->lower_factor > args[j]->lower_factor) {
        other.CopyOnWrite()->scale *= scale;
        this->args.insert(this->args.begin() + j, other);
        return;
      }
      // Identical split: fold the coefficient into the existing entry.
      if (other->lower_factor == args[j]->lower_factor &&
          other->upper_factor == args[j]->upper_factor &&
          other->DivModeCompatibleTo(args[j]->div_mode)) {
        args[j].CopyOnWrite()->scale += other->scale * scale;
        return;
      }
    }
    // Insert other in the end.
    other.CopyOnWrite()->scale *= scale;
    this->args.emplace_back(std::move(other));
  }

  void AddToSelf(const SumExpr& other, int64_t scale);

  static constexpr const char* _type_key = "arith.SumExpr";
  TVM_DECLARE_FINAL_OBJECT_INFO(SumExprNode, CanonicalExprNode);

 private:
  /*!
   * \brief Simplify the args by merging SplitExprs
   *
   * Merged-away entries are kept in the list with scale set to 0; the
   * returned list is sorted by (scale, lower_factor, upper_factor, div_mode)
   * in descending order.
   *
   * \param args The original list of arguments.
   * \return simplified version.
   */
  static std::vector<SplitExpr>
  SimplifySplitExprs(std::vector<SplitExpr> args) {
    // NOTE: This algorithm relies on the fact that args are divided into segments
    // and each segment is sorted in descending order of lower_factor.
    for (size_t i = 0; i < args.size(); ++i) {
      if (args[i]->scale == 0) continue;
      for (size_t j = i + 1; j < args.size(); ++j) {
        SplitExpr& lhs = args[i];
        SplitExpr& rhs = args[j];
        // Left the segment of lhs: nothing further can merge with it.
        if (!lhs->IndexEqual(rhs)) break;
        if (lhs->upper_factor < rhs->lower_factor) break;
        if (lhs->upper_factor == rhs->upper_factor &&
            lhs->lower_factor == rhs->lower_factor &&
            lhs->DivModeCompatibleTo(rhs->div_mode)) {
          // folding same co-efficient.
          rhs.CopyOnWrite()->scale += lhs->scale;
          lhs.CopyOnWrite()->scale = 0;
        } else if (lhs->lower_factor == rhs->upper_factor &&
                   rhs->scale != 0 &&
                   lhs->scale % rhs->scale == 0 &&
                   lhs->lower_factor == (lhs->scale / rhs->scale) * rhs->lower_factor &&
                   lhs->DivModeCompatibleTo(rhs->div_mode)) {
          // Rules used in the proof:
          //
          // Rule 1:  (x % (c * s)) / c  =  (x / c) % s
          // Proof:
          //  x can always be decomposed into p * c * s + q * c + r
          //  where  0 <= q * c + r < c * s  and  0 <= r  <  c.
          //  Then, lhs = ((p * c * s + q * c + r) % (c * s)) / c = (q * c + r) / c = q
          //  rhs = ((p * c * s + q * c + r) / c) % s = (p * s + q) % s = q
          //  Thus, lhs = rhs
          //
          // The above proof is for the floordiv.
          // The same rule also holds for truncdiv(division rule in C).
          // Because both sides only involve mul, div and mod,
          // we can take abs of x, c and s, apply the floordiv proof,
          // and finally add the sign back.
          //
          // Rule 2:  (x / s) * s + x % s = x  (true for both trunc and floor div)
          //
          // General merge condition and proof:
          // - x = lhs->index % lhs->upper_factor
          // - s = lhs->scale / rhs->scale
          // - c = rhs->lower_factor
          //
          //    (x / (c * s)) * s + (x % (c * s)) / c
          // => ((x / c) / s) * s + ((x / c) % s)
          // => (x / c)
          //
          // Examples:
          //
          //    (z / 6) * 6 + ((z % 6) / 3) * 3
          // => ((z / 6) * 2 + (z % 6) / 3) * 3
          // => (z / 3) * 3
          // note: x = z, c = 3, s = 2
          //
          //    ((z % 12) / 6) * 6 + ((z % 6) / 3) * 3
          // => (((z % 12) / 6) * 2 + ((z % 12) % 6) / 3) * 3
          // => ((z % 12) / 3) * 3
          // note: x = z % 12, c = 3, s = 2
          // note also the invariant lhs->upper_factor % lhs->lower_factor == 0
          //
          SplitExprNode* merged = rhs.CopyOnWrite();
          merged->upper_factor = lhs->upper_factor;
          // reset args[i] to be zero.
          lhs.CopyOnWrite()->scale = 0;
          break;
        }
      }
    }
    // sort by the entry
    // Here we simply sort by descending order of scales.
    // For now, we do not compare by index because that comparison
    // can be runtime dependent and create indeterminism.
    // we do not sort by index for now because it can be costly
    // to deep compare Exprs, and address of Vars can be runtime dependent.
    //
    auto fcompare = [](const SplitExpr& lhs, const SplitExpr& rhs) {
      // order by scale first
      if (lhs->scale > rhs->scale) return true;
      if (lhs->scale < rhs->scale) return false;
      // then order by factor
      if (lhs->lower_factor > rhs->lower_factor) return true;
      if (lhs->lower_factor < rhs->lower_factor) return false;
      // then order by upper factor
      if (lhs->upper_factor > rhs->upper_factor) return true;
      if (lhs->upper_factor < rhs->upper_factor) return false;
      // then order by div mode
      if (lhs->div_mode > rhs->div_mode) return true;
      if (lhs->div_mode < rhs->div_mode) return false;
      // tie.
      // TODO(tvm-team) We might consider index as the last comparison point,
      // after we make deep comparator more deterministic.
      // Specifically, we can consider comparing names of vars and break ties with address.
      return false;
    };
    // stable_sort keeps the incoming relative order of entries that tie.
    std::stable_sort(args.begin(), args.end(), fcompare);
    return args;
  }
  /*!
   * \brief Emit the sum as an ordinary expression.
   *
   * Terms with positive scale (and a positive base) are added first; terms
   * with negative scale (and a negative base) are then subtracted with their
   * sign flipped. Terms whose scale is 0 are skipped.
   *
   * \param dtype The data type of the result.
   * \param args The (already simplified) arguments.
   * \param base The constant base.
   * \return The ordinary expression.
   */
  static PrimExpr Normalize_(DataType dtype,
                         const std::vector<SplitExpr>& args,
                         int64_t base) {
    // Positive scales first
    PrimExpr res = make_const(dtype, 0);
    for (size_t i = 0; i < args.size(); ++i) {
      if (args[i]->scale > 0) {
        res = res + args[i]->Normalize();
      }
    }
    if (base > 0) {
      res = res + make_const(dtype, base);
    }
    // negative scales follows using sub.
    for (size_t i = 0; i < args.size(); ++i) {
      if (args[i]->scale < 0) {
        res = res - args[i]->NormalizeWithScale(-1);
      }
    }
    if (base < 0) {
      res = res - make_const(dtype, -base);
    }
    return res;
  }
};

/*!
 * \brief Reference type for SumExprNode.
 *
 * The two macros below supply the reference-class boilerplate, including the
 * CopyOnWrite() accessor used throughout this file to mutate a node.
 */
class SumExpr : public PrimExpr {
 public:
  TVM_DEFINE_OBJECT_REF_METHODS(SumExpr, PrimExpr, SumExprNode);
  TVM_DEFINE_OBJECT_REF_COW_METHOD(SumExprNode);
};

void SumExprNode::AddToSelf(const SumExpr& other, int64_t scale) {
  // NOTE: it is rare to have a balanced long expression,
  // linear scan is fine for our case.
  for (size_t i = 0; i < other->args.size(); ++i) {
    this->AddToSelf(other->args[i], scale);
  }
  this->AddToSelf(other->base * scale);
}

// Sub-class RewriteSimplifier::Impl to take advantage of its
// rewrite rules for condition simplification etc.
class CanonicalSimplifier::Impl : public RewriteSimplifier::Impl {
 public:
  using Rewriter = RewriteSimplifier::Impl;

  explicit Impl(Analyzer* parent) : Rewriter(parent) {}

  /*!
   * \brief Entry point: simplify expr and return the result.
   * \param expr The expression to simplify.
   * \return The simplified expression.
   */
  PrimExpr CanonicalSimplify(PrimExpr expr) {
    return operator()(expr);
  }

  // Visit through the rewriter, then lower any canonical node that comes back.
  PrimExpr VisitExpr(const PrimExpr& input_expr) final {
    return Normalize(Rewriter::VisitExpr(input_expr));
  }

  // Visit through the rewriter but keep canonical nodes as they are.
  PrimExpr CanonicalMutate(PrimExpr expr) {
    return Rewriter::VisitExpr(expr);
  }

  using Rewriter::VisitExpr_;
  PrimExpr VisitExpr_(const AddNode* op) final;
  PrimExpr VisitExpr_(const SubNode* op) final;
  PrimExpr VisitExpr_(const MulNode* op) final;
  PrimExpr VisitExpr_(const DivNode* op) final;
  PrimExpr VisitExpr_(const ModNode* op) final;
  PrimExpr VisitExpr_(const FloorDivNode* op) final;
  PrimExpr VisitExpr_(const FloorModNode* op) final;
  PrimExpr VisitExpr_(const ReduceNode* op) final;

 private:
  /*!
   * \brief compute lhs / cval
   * \param lhs The left operand.
   * \param cval The constant value.
   * \param div_mode The division mode.
   * \return The result expression;
   */
  SplitExpr SplitDivConst(SplitExpr lhs, int64_t cval, DivMode div_mode);
  /*!
   * \brief compute lhs % cval
   * \param lhs The left operand.
   * \param cval The constant value.
   * \param div_mode The division mode.
   * \return The result expression;
   */
  SplitExpr SplitModConst(SplitExpr lhs, int64_t cval, DivMode div_mode);
  /*!
   * \brief Separate psum into divisible and non-divisible parts.
   * \param psum The sum expression.
   * \param coeff The co-efficient.
   * \param out_divisible The result divisible component.
   * \param out_non_divisible The non-divisible component.
   */
  void SeparateDivisibleParts(const SumExprNode* psum,
                              int64_t coeff,
                              SumExpr* out_divisible,
                              SumExpr* out_non_divisible);
  /*!
   * \brief Lower a canonical node to an ordinary expression.
   * \param expr The input expression.
   * \return The lowered expression, or expr itself when it is not canonical.
   */
  PrimExpr Normalize(PrimExpr expr) {
    const auto* canonical = expr.as<CanonicalExprNode>();
    return canonical != nullptr ? canonical->Normalize() : expr;
  }
  /*!
   * \brief Create a SplitExpr from expr.
   * \param expr The input expr.
   * \return The transformed SplitExpr.
   */
  SplitExpr ToSplitExpr(PrimExpr expr) {
    // Already a split: hand it back unchanged.
    if (const auto* split = expr.as<SplitExprNode>()) {
      return GetRef<SplitExpr>(split);
    }
    // A sum that consists of exactly one term and no base is that term.
    if (const auto* sum = expr.as<SumExprNode>()) {
      if (sum->base == 0 && sum->args.size() == 1) return sum->args[0];
    }
    // Any other canonical node is lowered and then wrapped as an opaque index.
    if (const auto* canonical = expr.as<CanonicalExprNode>()) {
      expr = canonical->Normalize();
    }
    auto node = make_object<SplitExprNode>();
    node->dtype = expr.dtype();
    node->index = std::move(expr);
    node->div_mode = kTruncDiv;
    return SplitExpr(node);
  }
  /*!
   * \brief Convert expr to an equivalent SplitExpr
   *        that has the specified div_mode.
   *
   * The same expr is returned when its div_mode already matches.
   * When the modes are incompatible, expr is lowered and re-wrapped
   * as an opaque index before the mode is set.
   *
   * \param expr The input expr.
   * \param div_mode The new div_mode.
   * \return The transformed SplitExpr.
   */
  SplitExpr ConvertDivMode(SplitExpr expr, DivMode div_mode) {
    if (expr->div_mode == div_mode) return expr;
    if (!expr->DivModeCompatibleTo(div_mode)) {
      expr = ToSplitExpr(Normalize(expr));
      CHECK(expr->DivModeCompatibleTo(div_mode));
    }
    expr.CopyOnWrite()->div_mode = div_mode;
    return expr;
  }
  /*!
   * \brief Create a SumExpr from expr.
   * \param expr The input expr.
   * \return The transformed SumExpr.
   */
  SumExpr ToSumExpr(PrimExpr expr) {
    if (const auto* sum = expr.as<SumExprNode>()) {
      return GetRef<SumExpr>(sum);
    }
    auto node = make_object<SumExprNode>();
    node->dtype = expr.dtype();
    if (const auto* imm = expr.as<IntImmNode>()) {
      // An integer constant becomes the base of an otherwise empty sum.
      node->base = imm->value;
    } else {
      node->args.emplace_back(ToSplitExpr(expr));
    }
    return SumExpr(node);
  }
  // Simplify the combiner used in reduce.
  PrimExpr SimplifyReduceCombiner(const ReduceNode* op);
};

/*!
 * \brief Canonicalize a + b into a SumExpr for index types.
 *
 * Non-index types are delegated to the rewrite simplifier. Two constants are
 * folded directly.
 */
PrimExpr CanonicalSimplifier::Impl::
VisitExpr_(const AddNode* op) {
  if (!IsIndexType(op->dtype)) return Rewriter::VisitExpr_(op);

  // Mutate the operands but keep them in canonical form.
  PrimExpr a = this->CanonicalMutate(op->a);
  PrimExpr b = this->CanonicalMutate(op->b);

  PrimExpr const_res = TryConstFold<AddNode>(a, b);
  if (const_res.defined()) return const_res;

  // Accumulate b into the sum built from a.
  SumExpr ret = ToSumExpr(std::move(a));
  SumExprNode* acc = ret.CopyOnWrite();
  if (const auto* imm = b.as<IntImmNode>()) {
    acc->AddToSelf(imm->value);
  } else if (const auto* sum = b.as<SumExprNode>()) {
    acc->AddToSelf(GetRef<SumExpr>(sum), 1);
  } else {
    acc->AddToSelf(ToSplitExpr(b), 1);
  }
  return std::move(ret);
}

/*!
 * \brief Canonicalize a - b into a SumExpr for index types.
 *
 * Non-index types are delegated to the rewrite simplifier. Two constants are
 * folded directly. b is accumulated with a scale of -1.
 */
PrimExpr CanonicalSimplifier::Impl::
VisitExpr_(const SubNode* op) {
  if (!IsIndexType(op->dtype)) return Rewriter::VisitExpr_(op);

  // Mutate the operands but keep them in canonical form.
  PrimExpr a = this->CanonicalMutate(op->a);
  PrimExpr b = this->CanonicalMutate(op->b);

  PrimExpr const_res = TryConstFold<SubNode>(a, b);
  if (const_res.defined()) return const_res;

  // Accumulate -b into the sum built from a.
  SumExpr ret = ToSumExpr(std::move(a));
  SumExprNode* acc = ret.CopyOnWrite();
  if (const auto* imm = b.as<IntImmNode>()) {
    acc->AddToSelf(-imm->value);
  } else if (const auto* sum = b.as<SumExprNode>()) {
    acc->AddToSelf(GetRef<SumExpr>(sum), -1);
  } else {
    acc->AddToSelf(ToSplitExpr(b), -1);
  }
  return std::move(ret);
}


/*!
 * \brief Canonicalize a * b for index types.
 *
 * Multiplication by an integer constant is absorbed into the scale of a
 * SumExpr / SplitExpr. Any other product is rebuilt from the lowered operands.
 */
PrimExpr CanonicalSimplifier::Impl::
VisitExpr_(const MulNode* op) {
  if (!IsIndexType(op->dtype)) return Rewriter::VisitExpr_(op);

  // Mutate the operands but keep them in canonical form.
  PrimExpr a = this->CanonicalMutate(op->a);
  PrimExpr b = this->CanonicalMutate(op->b);

  PrimExpr const_res = TryConstFold<MulNode>(a, b);
  if (const_res.defined()) return const_res;

  // Put a constant operand, if any, on the right: x * c.
  if (a.as<IntImmNode>()) {
    std::swap(a, b);
  }

  const auto* bconst = b.as<IntImmNode>();
  if (bconst == nullptr) {
    // Neither side is a constant: emit an ordinary multiplication.
    a = Normalize(a);
    b = Normalize(b);
    if (op->a.same_as(a) && op->b.same_as(b)) return GetRef<PrimExpr>(op);
    return MulNode::make(a, b);
  }

  if (a.as<SumExprNode>()) {
    SumExpr scaled = Downcast<SumExpr>(std::move(a));
    scaled.CopyOnWrite()->MulToSelf(bconst->value);
    return std::move(scaled);
  }
  SplitExpr scaled = ToSplitExpr(std::move(a));
  scaled.CopyOnWrite()->MulToSelf(bconst->value);
  return std::move(scaled);
}

void CanonicalSimplifier::Impl::
SeparateDivisibleParts(const SumExprNode* psum,
                       int64_t coeff,
                       SumExpr* out_divisible,
                       SumExpr* out_non_divisible) {
  auto divisible = make_object<SumExprNode>();
  auto non_divisible = make_object<SumExprNode>();
  divisible->dtype = psum->dtype;
  non_divisible->dtype = psum->dtype;

  if (psum->base % coeff == 0) {
    divisible->base = psum->base;
  } else {
    non_divisible->base = psum->base;
  }
  for (const auto& e : psum->args) {
    if (e->scale % coeff == 0) {
      divisible->args.push_back(e);
    } else {
      non_divisible->args.push_back(e);
    }
  }
  *out_divisible = SumExpr(divisible);
  *out_non_divisible = SumExpr(non_divisible);
}

/*!
 * \brief Compute lhs / cval as a SplitExpr.
 *
 * The division is folded into the scale or the lower_factor of lhs when the
 * constants line up; otherwise lhs is lowered to an opaque index and the
 * division is recorded on a fresh split.
 *
 * \param lhs The dividend.
 * \param cval The constant divisor; must be positive.
 * \param div_mode The division mode of the result.
 * \return The quotient, carrying div_mode on every path that divides.
 */
SplitExpr CanonicalSimplifier::Impl::
SplitDivConst(SplitExpr lhs, int64_t cval, DivMode div_mode) {
  CHECK_GT(cval, 0);
  lhs = ConvertDivMode(lhs, div_mode);

  // the following rule works for both floordiv and truncdiv
  if (lhs->scale % cval == 0) {
    lhs.CopyOnWrite()->scale /= cval;
    return lhs;
  }

  if (cval % lhs->scale == 0) {
    int64_t scaled_cval = cval / lhs->scale;
    if (lhs->upper_factor == SplitExprNode::kPosInf ||
        lhs->upper_factor % (lhs->lower_factor * scaled_cval) == 0) {
      // directly fold division.
      lhs.CopyOnWrite()->scale = 1;
      lhs.CopyOnWrite()->lower_factor *= scaled_cval;
      lhs->Verify();
      return lhs;
    } else if (lhs->upper_factor <= (lhs->lower_factor * scaled_cval)) {
      // (x % c1) / c2  => 0 when c2 >= c1
      return ToSplitExpr(make_zero(lhs.dtype()));
    } else {
      // move the upper_factor modular into index.
      lhs.CopyOnWrite()->index =
          ModImpl(lhs->index, make_const(lhs.dtype(), lhs->upper_factor), div_mode);
      lhs.CopyOnWrite()->upper_factor = SplitExprNode::kPosInf;
      lhs.CopyOnWrite()->scale = 1;
      lhs.CopyOnWrite()->lower_factor *= scaled_cval;
      lhs->Verify();
      return lhs;
    }
  }
  // Fallback: lower lhs to an opaque index and record the division on it.
  lhs = ToSplitExpr(Normalize(lhs));
  CHECK(lhs->DivModeCompatibleTo(div_mode));
  CHECK_EQ(lhs->scale, 1);
  lhs.CopyOnWrite()->lower_factor *= cval;
  // ToSplitExpr always yields kTruncDiv; without this assignment a floor
  // division would be emitted as a truncated one. SplitModConst does the same.
  lhs.CopyOnWrite()->div_mode = div_mode;
  return lhs;
}

/*!
 * \brief Canonicalize truncated division a / b for index types.
 *
 * Only division by a positive integer constant is handled specially; every
 * other case is rebuilt from the lowered operands.
 */
PrimExpr CanonicalSimplifier::Impl::
VisitExpr_(const DivNode* op) {
  if (!IsIndexType(op->dtype)) return Rewriter::VisitExpr_(op);

  PrimExpr a = this->CanonicalMutate(op->a);
  PrimExpr b = this->CanonicalMutate(op->b);

  PrimExpr const_res = TryConstFold<DivNode>(a, b);
  if (const_res.defined()) return const_res;

  PVar<IntImm> c1;
  if (!(c1.Match(b) && c1.Eval()->value > 0)) {
    // Divisor is not a positive constant: emit an ordinary division.
    a = Normalize(a);
    b = Normalize(b);
    if (op->a.same_as(a) && op->b.same_as(b)) return GetRef<PrimExpr>(op);
    return DivNode::make(a, b);
  }

  // From here on: x / c1 with c1 > 0.
  int64_t cval = c1.Eval()->value;
  if (cval == 1) return a;

  const auto* psum = a.as<SumExprNode>();
  if (psum == nullptr) {
    // 0 <= a < cval implies the quotient is 0.
    auto cbound = analyzer_->const_int_bound(Normalize(a));
    if (cbound->min_value >= 0 && cbound->max_value < cval) {
      return make_zero(a.dtype());
    }
    return SplitDivConst(ToSplitExpr(std::move(a)), cval, kTruncDiv);
  }

  SumExpr divisible, rest;
  SeparateDivisibleParts(psum, cval, &divisible, &rest);

  // Everything is a multiple of cval: divide the coefficients directly.
  if (rest->IsZero()) {
    divisible.CopyOnWrite()->DivideBy(cval);
    return std::move(divisible);
  }

  // Splitting the quotient term by term is only done when both parts are
  // provably non-negative.
  const bool both_non_negative =
      analyzer_->CanProveGreaterEqual(divisible->Normalize(), 0) &&
      analyzer_->CanProveGreaterEqual(rest->Normalize(), 0);
  if (!both_non_negative) {
    return SplitDivConst(ToSplitExpr(std::move(a)), cval, kTruncDiv);
  }

  divisible.CopyOnWrite()->DivideBy(cval);
  PrimExpr temp = Normalize(rest);
  if (const auto* pconst = temp.as<IntImmNode>()) {
    divisible.CopyOnWrite()->AddToSelf(pconst->value / cval);
  } else if (TryCompare(temp, cval) != kLT) {
    // The remainder part is not known to be below cval, so its quotient
    // has to be kept; when it is below cval it contributes 0 and is dropped.
    divisible.CopyOnWrite()->AddToSelf(
        SplitDivConst(ToSplitExpr(temp), cval, kTruncDiv), 1);
  }
  return std::move(divisible);
}

/*!
 * \brief Canonicalize floordiv(a, b) for index types.
 *
 * Only division by a positive integer constant is handled specially; every
 * other case is rebuilt from the lowered operands.
 */
PrimExpr CanonicalSimplifier::Impl::
VisitExpr_(const FloorDivNode* op) {
  if (!IsIndexType(op->dtype)) return Rewriter::VisitExpr_(op);

  PrimExpr a = this->CanonicalMutate(op->a);
  PrimExpr b = this->CanonicalMutate(op->b);

  PrimExpr const_res = TryConstFold<FloorDivNode>(a, b);
  if (const_res.defined()) return const_res;

  PVar<IntImm> c1;
  if (!(c1.Match(b) && c1.Eval()->value > 0)) {
    // Divisor is not a positive constant: emit an ordinary floordiv.
    a = Normalize(a);
    b = Normalize(b);
    if (op->a.same_as(a) && op->b.same_as(b)) return GetRef<PrimExpr>(op);
    return FloorDivNode::make(a, b);
  }

  // From here on: floordiv(x, c1) with c1 > 0.
  int64_t cval = c1.Eval()->value;
  if (cval == 1) return a;

  const auto* psum = a.as<SumExprNode>();
  if (psum == nullptr) {
    // 0 <= a < cval implies the quotient is 0.
    auto cbound = analyzer_->const_int_bound(Normalize(a));
    if (cbound->min_value >= 0 && cbound->max_value < cval) {
      return make_zero(a.dtype());
    }
    return SplitDivConst(ToSplitExpr(std::move(a)), cval, kFloorDiv);
  }

  SumExpr divisible, rest;
  SeparateDivisibleParts(psum, cval, &divisible, &rest);

  // The divisible part is divided coefficient-wise in every case.
  divisible.CopyOnWrite()->DivideBy(cval);
  if (rest->IsZero()) {
    return std::move(divisible);
  }

  // Add the quotient contributed by the non-divisible part.
  PrimExpr temp = Normalize(rest);
  if (const auto* pconst = temp.as<IntImmNode>()) {
    divisible.CopyOnWrite()->AddToSelf(floordiv(pconst->value, cval));
  } else if (TryCompare(temp, cval) != kLT ||
             !analyzer_->CanProveGreaterEqual(temp, 0)) {
    // Unless 0 <= temp < cval is proven (in which case the term is 0 and is
    // dropped), the quotient of the remaining part has to be kept.
    divisible.CopyOnWrite()->AddToSelf(
        SplitDivConst(ToSplitExpr(temp), cval, kFloorDiv), 1);
  }
  return std::move(divisible);
}

/*!
 * \brief Compute lhs % cval as a SplitExpr.
 *
 * The modulo is folded into the upper_factor of lhs when the constants line
 * up; otherwise lhs is lowered to an opaque index and the modulo is recorded
 * on a fresh split.
 *
 * \param lhs The dividend.
 * \param cval The constant modulus; must be positive.
 * \param div_mode The division mode of the result.
 * \return The remainder expression.
 */
SplitExpr CanonicalSimplifier::Impl::
SplitModConst(SplitExpr lhs, int64_t cval, DivMode div_mode) {
  CHECK_GT(cval, 0);
  lhs = ConvertDivMode(lhs, div_mode);

  // The scale is a multiple of cval, so the remainder is always 0.
  if (lhs->scale % cval == 0) {
    lhs.CopyOnWrite()->scale = 0;
    return lhs;
  }
  if (cval % lhs->scale == 0) {
    // With cval = scaled_cval * scale:
    //   ((index % upper) / lower * scale) % cval
    // = (((index % upper) / lower) % scaled_cval) * scale
    //     by (x * c1) % (c2 * c1) => (x % c2) * c1
    // = (((index % upper) % (lower * scaled_cval)) / lower) * scale
    //     by (x / c1) % c2  =>  (x % (c1 * c2)) / c1
    int64_t scaled_cval = cval / lhs->scale;
    int64_t new_upper_factor = lhs->lower_factor * scaled_cval;
    // try to see if we can reduce the existing upper modular.
    if (lhs->upper_factor == SplitExprNode::kPosInf ||
        lhs->upper_factor % new_upper_factor == 0) {
      // we gained a new upper factor that is smaller
      // than the original one
      // Perhaps there are more chances in simplifying the index
      // Do a recursive call to simplify the mod with the new factor.
      if (new_upper_factor < lhs->upper_factor &&
          lhs->upper_factor != SplitExprNode::kPosInf) {
        auto updated = ToSplitExpr(this->VisitExpr(ModImpl(
            lhs->index, make_const(lhs.dtype(), new_upper_factor), div_mode)));
        // re-apply the lower_factor, then the scale: the freshly simplified
        // index carries neither, and dropping the scale would turn
        // ((x % 16) * 2) % 8 into x % 4 instead of (x % 4) * 2.
        if (lhs->lower_factor != 1) {
          auto ret = SplitDivConst(updated, lhs->lower_factor, div_mode);
          ret.CopyOnWrite()->MulToSelf(lhs->scale);
          return ret;
        } else {
          updated.CopyOnWrite()->MulToSelf(lhs->scale);
          return updated;
        }
      } else {
        lhs.CopyOnWrite()->upper_factor = new_upper_factor;
        return lhs;
      }
    } else if (new_upper_factor % lhs->upper_factor == 0) {
      // (x % 2) % 4 => x % 2
      return lhs;
    }
  }
  // Fallback: lower lhs to an opaque index and record the modulo on it.
  lhs = ToSplitExpr(Normalize(lhs));
  CHECK(lhs->DivModeCompatibleTo(div_mode));
  CHECK_EQ(lhs->scale, 1);
  CHECK_EQ(lhs->lower_factor, 1);
  lhs.CopyOnWrite()->div_mode = div_mode;
  lhs.CopyOnWrite()->upper_factor = cval;
  return lhs;
}

/*!
 * \brief Canonicalize truncated modulo a % b for index types.
 *
 * Only a positive integer constant modulus is handled specially; every other
 * case is rebuilt from the lowered operands.
 */
PrimExpr CanonicalSimplifier::Impl::
VisitExpr_(const ModNode* op) {
  if (!IsIndexType(op->dtype)) {
    return Rewriter::VisitExpr_(op);
  }
  // normalize
  PrimExpr a = this->CanonicalMutate(op->a);
  PrimExpr b = this->CanonicalMutate(op->b);

  // const folding
  PrimExpr const_res = TryConstFold<ModNode>(a, b);
  if (const_res.defined()) return const_res;

  PVar<IntImm> c1;
  // x % c1
  if (c1.Match(b) && c1.Eval()->value > 0) {
    int64_t cval = c1.Eval()->value;
    if (const auto* psum = a.as<SumExprNode>()) {
      SumExpr lhs, extra;
      SeparateDivisibleParts(psum, cval, &lhs, &extra);
      // every coefficient is a multiple of cval, so the remainder is 0.
      if (extra->IsZero()) {
        return make_zero(a.dtype());
      }
      // both lhs and extra are non-negative
      if (analyzer_->CanProveGreaterEqual(lhs->Normalize(), 0) &&
          analyzer_->CanProveGreaterEqual(extra->Normalize(), 0)) {
        PrimExpr temp = Normalize(extra);
        if (temp.as<IntImmNode>()) {
          return truncmod(temp, c1.Eval());
        } else {
          // If temp < cval && temp >=0 then can remove the mod.
          if (TryCompare(temp, cval) == kLT) {
            return temp;
          } else {
            // continue to use logic below, with the divisible part dropped.
            // psum is re-pointed at the node now held by a.
            a = extra;
            psum = a.as<SumExprNode>();
            CHECK(psum != nullptr);
          }
        }
      }
      // Simplify the offset constant if necessary.
      // (x - 5) % 3 => (x - 2) % 3 if x - 5 >= 0
      auto cbound = analyzer_->const_int_bound(Normalize(a));
      int64_t new_base = psum->base % cval;
      // the rewrite is only applied when a is non-negative both before and
      // after the base is replaced.
      if (cbound->min_value >= 0 &&
          cbound->min_value - psum->base + new_base >= 0) {
        SumExpr sum_expr = Downcast<SumExpr>(a);
        sum_expr.CopyOnWrite()->base = new_base;
        return SplitModConst(ToSplitExpr(std::move(sum_expr)), cval, kTruncDiv);
      }
    } else {
      // if a >= 0 && a < cval, then result == a
      auto cbound = analyzer_->const_int_bound(Normalize(a));
      if (cbound->min_value >= 0 && cbound->max_value < cval) {
        return a;
      }
    }
    return SplitModConst(ToSplitExpr(std::move(a)), cval, kTruncDiv);
  }
  // normal path
  a = Normalize(a);
  b = Normalize(b);
  if (op->a.same_as(a) && op->b.same_as(b)) {
    return GetRef<PrimExpr>(op);
  } else {
    return ModNode::make(a, b);
  }
}

/*!
 * \brief Canonicalize floormod(a, b) for index types.
 *
 * Only a positive integer constant modulus is handled specially; every other
 * case is rebuilt from the lowered operands.
 */
PrimExpr CanonicalSimplifier::Impl::
VisitExpr_(const FloorModNode* op) {
  if (!IsIndexType(op->dtype)) {
    return Rewriter::VisitExpr_(op);
  }
  // normalize
  PrimExpr a = this->CanonicalMutate(op->a);
  PrimExpr b = this->CanonicalMutate(op->b);

  // const folding
  PrimExpr const_res = TryConstFold<FloorModNode>(a, b);
  if (const_res.defined()) return const_res;

  PVar<IntImm> c1;
  // x % c1
  if (c1.Match(b) && c1.Eval()->value > 0) {
    int64_t cval = c1.Eval()->value;
    if (const auto* psum = a.as<SumExprNode>()) {
      SumExpr lhs, extra;
      SeparateDivisibleParts(psum, cval, &lhs, &extra);
      // only the non-divisible part is used from here on.
      PrimExpr temp = Normalize(extra);
      if (temp.as<IntImmNode>()) {
        return floormod(temp, c1.Eval());
      } else {
        // If temp < cval && temp >=0 then can remove the mod.
        if (TryCompare(temp, cval) == kLT &&
            analyzer_->CanProveGreaterEqual(temp, 0)) {
          return temp;
        } else {
          // continue to use logic below, with the divisible part dropped.
          // psum is re-pointed at the node now held by a.
          a = extra;
          psum = a.as<SumExprNode>();
          CHECK(psum != nullptr);
        }
      }
      // Simplify the offset constant if necessary.
      // floormod(x - 5, 3) => floormod(x + 1, 3)
      int64_t new_base = floormod(psum->base, cval);
      SumExpr sum_expr = Downcast<SumExpr>(std::move(a));
      sum_expr.CopyOnWrite()->base = new_base;
      return SplitModConst(ToSplitExpr(std::move(sum_expr)), cval, kFloorDiv);
    } else {
      // if a >= 0 && a < cval, then result == a
      auto cbound = analyzer_->const_int_bound(Normalize(a));
      if (cbound->min_value >= 0 && cbound->max_value < cval) {
        return a;
      }
    }
    return SplitModConst(ToSplitExpr(std::move(a)), cval, kFloorDiv);
  }
  // normal path
  a = Normalize(a);
  b = Normalize(b);
  if (op->a.same_as(a) && op->b.same_as(b)) {
    return GetRef<PrimExpr>(op);
  } else {
    return FloorModNode::make(a, b);
  }
}

// Simplify the combiner of a reduce expression.
//
// Every combiner result is simplified, then only the components that are
// needed are kept: those reachable from op->value_index through uses of the
// combiner's lhs/rhs variables, plus any component whose source, identity
// element or result has side effects. The returned Reduce node carries the
// pruned combiner/source and a value_index adjusted for removed components.
PrimExpr CanonicalSimplifier::Impl::
SimplifyReduceCombiner(const ReduceNode* op) {
  const auto& combiner = op->combiner;
  const size_t num_components = combiner->result.size();

  // Step 1: simplify each result expression of the combiner.
  Array<PrimExpr> simplified_result;
  for (size_t k = 0; k < num_components; ++k) {
    simplified_result.push_back(this->VisitExpr(combiner->result[k]));
  }

  // Step 2: work out which components must be kept.
  std::vector<int> used(num_components, false);

  // Marks `root` and, transitively, every component whose lhs or rhs
  // variable appears in the simplified result of an already marked one.
  auto mark_used = [&](size_t root) {
    if (used[root]) return;
    used[root] = true;
    std::vector<size_t> pending;
    pending.push_back(root);
    while (!pending.empty()) {
      const size_t cur = pending.back();
      pending.pop_back();
      for (size_t k = 0; k < simplified_result.size(); ++k) {
        if (used[k]) continue;
        const bool referenced =
            ExprUseVar(simplified_result[cur], combiner->lhs[k]) ||
            ExprUseVar(simplified_result[cur], combiner->rhs[k]);
        if (referenced) {
          used[k] = true;
          pending.push_back(k);
        }
      }
    }
  };

  // The requested output is always needed.
  mark_used(op->value_index);

  // Components with side effects must survive as well.
  for (size_t k = 0; k < used.size(); ++k) {
    const bool has_effect =
        HasSideEffect(op->source[k]) ||
        HasSideEffect(combiner->identity_element[k]) ||
        HasSideEffect(combiner->result[k]);
    if (has_effect) mark_used(k);
  }

  // Step 3: rebuild the pieces from the surviving components.
  int new_value_index = op->value_index;
  Array<PrimExpr> new_result;
  Array<PrimExpr> new_identity;
  Array<Var> new_lhs;
  Array<Var> new_rhs;
  Array<PrimExpr> new_source;

  for (size_t k = 0; k < used.size(); ++k) {
    if (!used[k]) {
      // A dropped component in front of the output shifts its index down.
      if (static_cast<int>(k) < op->value_index) --new_value_index;
      continue;
    }
    // Result and identity are simplified; the source is kept as is.
    new_result.push_back(simplified_result[k]);
    new_identity.push_back(this->VisitExpr(combiner->identity_element[k]));
    new_lhs.push_back(combiner->lhs[k]);
    new_rhs.push_back(combiner->rhs[k]);
    new_source.push_back(op->source[k]);
  }

  CommReducer pruned_combiner =
      CommReducerNode::make(new_lhs, new_rhs, new_result, new_identity);
  return ReduceNode::make(
      pruned_combiner, new_source, op->axis, op->condition, new_value_index);
}

// Canonical simplification of a Reduce node.
//
// The rewrite simplifier runs first. If its result is no longer a Reduce,
// that result is returned directly. A Reduce with no axes is replaced by a
// select between its source and the identity element; otherwise the
// combiner is simplified.
PrimExpr CanonicalSimplifier::Impl::
VisitExpr_(const ReduceNode* op) {
  // Let the rewrite simplifier recurse into the node first.
  PrimExpr rewritten = RewriteSimplifier::Impl::VisitExpr_(op);
  const ReduceNode* reduce = rewritten.as<ReduceNode>();

  // The rewrite step already turned the reduction into something else.
  if (reduce == nullptr) return rewritten;

  // Non-trivial reduction: simplify the combiner.
  if (!reduce->axis.empty()) return SimplifyReduceCombiner(reduce);

  // No reduction axis left. This assumes the identity element really is an
  // identity of the combiner; without that assumption one iteration of the
  // combiner would have to be applied to (identity_element, source) and the
  // value_index-th output taken, which may be harder to simplify than the
  // plain source expression used here.
  const auto idx = reduce->value_index;
  return this->VisitExpr(
      SelectNode::make(reduce->condition,
                       reduce->source[idx],
                       reduce->combiner->identity_element[idx]));
}

// Simplify `expr` by delegating to the implementation's CanonicalSimplify.
PrimExpr CanonicalSimplifier::operator()(const PrimExpr& expr) {
  PrimExpr simplified = this->impl_->CanonicalSimplify(expr);
  return simplified;
}

void CanonicalSimplifier::Update(const Var& var,
                                 const PrimExpr& info,
                                 bool override) {
  impl_->Update(var, info, override);
}

// Construct the simplifier: heap-allocates the Impl object, handing it the
// parent analyzer pointer. The raw pointer is stored in impl_.
CanonicalSimplifier::CanonicalSimplifier(Analyzer* parent)
    : impl_(new Impl(parent)) {
}

// Destroy the simplifier, deleting the Impl object held in impl_.
CanonicalSimplifier::~CanonicalSimplifier() {
  delete impl_;
}

}  // namespace arith
}  // namespace tvm