/*!
 *  Copyright (c) 2016 by Contributors
 * \file tvm/arithmetic.h
 * \brief Algebra and set operations and simplifications.
 */
#ifndef TVM_ARITHMETIC_H_
#define TVM_ARITHMETIC_H_

#include <functional>
#include <limits>
#include <memory>
#include <unordered_map>
#include <utility>
#include <vector>
#include "expr.h"

namespace tvm {
// forward declare Tensor
class Tensor;
/*! \brief namespace of arithmetic */
namespace arith {
//-------------------------------------------------------
// Base integer analysis API.
//
// We have multiple types of analyzers to do relaxed
// integer set analysis (bound analysis, modulo) and
// equivalence checking and simplification.
//
// Importantly, each analyzer may need results from
// another analyzer.
//-------------------------------------------------------

// Forward declare Analyzer
class Analyzer;
/*!
 * \brief reference class to ConstIntBoundNode
 * \sa ConstIntBoundNode
 */
class ConstIntBound;
/*!
 * \brief Constant integer up and lower bound(inclusive).
 *  Useful for value bound analysis.
 *
 *  set = [min_value, max_value]
 */
class ConstIntBoundNode : public Node {
 public:
  int64_t min_value;
  int64_t max_value;

  void VisitAttrs(tvm::AttrVisitor* v) final {
    v->Visit("min_value", &min_value);
    v->Visit("max_value", &max_value);
  }

  TVM_DLL static ConstIntBound make(int64_t min_value, int64_t max_value);

  /*! \brief Number to represent +inf */
  static const constexpr int64_t kPosInf = std::numeric_limits<int64_t>::max();
  /*!
   * \brief Number to represent -inf
   * \note We can make use the of fact that -kPosInf == kNegInf in the project.
   */
  static const constexpr int64_t kNegInf = -kPosInf;

  static constexpr const char* _type_key = "arith.ConstIntBound";
  TVM_DECLARE_NODE_TYPE_INFO(ConstIntBoundNode, Node);
};

TVM_DEFINE_NODE_REF(ConstIntBound, ConstIntBoundNode);

/*!
 * \brief Analyzer to get constant integer bound over expression.
 *
 * Only Analyzer can construct or destroy an instance (constructor and
 * destructor are private). Copying is disabled: the object holds a raw
 * Impl* pointer and has a user-declared destructor, so a member-wise copy
 * would make two objects share one implementation.
 */
class ConstIntBoundAnalyzer {
 public:
  /*!
   * \brief Analyze the expr.
   * \param expr The expression of interest.
   * \return The result of the analysis.
   */
  ConstIntBound operator()(const Expr& expr);

  /*!
   * \brief Update constant int bound information of var.
   *
   * \param var The variable of interest.
   * \param info The bound information.
   * \param override Whether we allow overriding existing information.
   */
  void Update(const Var& var,
              const ConstIntBound& info,
              bool override = false);
  /*!
   * \brief Bind variable to a range.
   *
   * \param var The variable.
   * \param range The range we bind to.
   */
  void Bind(const Var& var, const Range& range);

  // Non-copyable: a copy would alias impl_ (and the parent binding).
  ConstIntBoundAnalyzer(const ConstIntBoundAnalyzer&) = delete;
  ConstIntBoundAnalyzer& operator=(const ConstIntBoundAnalyzer&) = delete;

 private:
  friend class Analyzer;
  friend class ConstraintContext;
  explicit ConstIntBoundAnalyzer(Analyzer* parent);
  ~ConstIntBoundAnalyzer();
  /*!
   * \brief Update the internal state to enter constraint.
   * \param constraint A constraint expression.
   *
   * \return An exit function that must be called to clean up the
   *  constraint; can be nullptr.
   */
  std::function<void()> EnterConstraint(const Expr& constraint);
  struct Entry;
  class Impl;
  /*! \brief Internal impl */
  Impl* impl_;
};

/*!
 * \brief reference of ModularSetNode
 * \sa ModularSetNode
 */
class ModularSet;
/*!
 * \brief Range of a linear integer function.
 *  Use to do specify the possible index values.
 *
 *  set = { coeff * x + base | x in Z }
 *
 *  When coeff != 0, it can also be written as
 *  set = { n | n % coeff == base }
 *
 *  This is useful to decide if the index is dividable by certain value.
 *  For example, if index = 0 + 4 x, then we know it can be divided by 4.
 */
class ModularSetNode : public Node {
 public:
  /*! \brief linear co-efficient */
  int64_t coeff;
  /*! \brief The base */
  int64_t base;

  void VisitAttrs(tvm::AttrVisitor* v) final {
    v->Visit("coeff", &coeff);
    v->Visit("base", &base);
  }

  TVM_DLL static ModularSet make(int64_t coeff, int64_t base);

  static constexpr const char* _type_key = "arith.ModularSet";
  TVM_DECLARE_NODE_TYPE_INFO(ModularSetNode, Node);
};

TVM_DEFINE_NODE_REF(ModularSet, ModularSetNode);

/*!
 * \brief Analyzer to get modular information over expression.
 *
 * Only Analyzer can construct or destroy an instance (constructor and
 * destructor are private). Copying is disabled: the object holds a raw
 * Impl* pointer and has a user-declared destructor, so a member-wise copy
 * would make two objects share one implementation.
 */
class ModularSetAnalyzer {
 public:
  /*!
   * \brief Analyze the expr.
   * \param expr The expression of interest.
   * \return The result of the analysis.
   */
  ModularSet operator()(const Expr& expr);
  /*!
   * \brief Update modular set information of var.
   *
   * \param var The variable of interest.
   * \param info The modular set information.
   * \param override Whether we allow overriding existing information.
   */
  void Update(const Var& var,
              const ModularSet& info,
              bool override = false);

  // Non-copyable: a copy would alias impl_ (and the parent binding).
  ModularSetAnalyzer(const ModularSetAnalyzer&) = delete;
  ModularSetAnalyzer& operator=(const ModularSetAnalyzer&) = delete;

 private:
  friend class Analyzer;
  friend class ConstraintContext;
  explicit ModularSetAnalyzer(Analyzer* parent);
  ~ModularSetAnalyzer();
  /*!
   * \brief Update the internal state to enter constraint.
   * \param constraint A constraint expression.
   *
   * \return An exit function that must be called to clean up the
   *  constraint; can be nullptr.
   */
  std::function<void()> EnterConstraint(const Expr& constraint);
  struct Entry;
  class Impl;
  /*! \brief Internal impl */
  Impl* impl_;
};

/*!
 * \brief Rewrite-rule based simplifier.
 *
 * Only Analyzer can construct or destroy an instance (constructor and
 * destructor are private). Copying is disabled: the object holds a raw
 * Impl* pointer and has a user-declared destructor, so a member-wise copy
 * would make two objects share one implementation.
 */
class RewriteSimplifier {
 public:
  /*!
   * \brief Simplify the expr.
   * \param expr The expression of interest.
   * \return The result of the simplification.
   */
  Expr operator()(const Expr& expr);

  /*!
   * \brief Update binding of var to a new expression.
   *
   * \param var The variable of interest.
   * \param new_expr The new expression the variable is bound to.
   * \param override Whether we allow overriding existing information.
   */
  void Update(const Var& var,
              const Expr& new_expr,
              bool override = false);

  // Non-copyable: a copy would alias impl_ (and the parent binding).
  RewriteSimplifier(const RewriteSimplifier&) = delete;
  RewriteSimplifier& operator=(const RewriteSimplifier&) = delete;

 private:
  friend class Analyzer;
  friend class ConstraintContext;
  friend class CanonicalSimplifier;
  explicit RewriteSimplifier(Analyzer* parent);
  ~RewriteSimplifier();
  class Impl;
  /*! \brief Internal impl */
  Impl* impl_;
};

/*!
 * \brief Canonical-form based simplifier.
 *
 * Only Analyzer can construct or destroy an instance (constructor and
 * destructor are private). Copying is disabled: the object holds a raw
 * Impl* pointer and has a user-declared destructor, so a member-wise copy
 * would make two objects share one implementation.
 */
class CanonicalSimplifier {
 public:
  /*!
   * \brief Simplify the expr.
   * \param expr The expression of interest.
   * \return The result of the simplification.
   */
  Expr operator()(const Expr& expr);

  /*!
   * \brief Update binding of var to a new expression.
   *
   * \param var The variable of interest.
   * \param new_expr The new expression the variable is bound to.
   * \param override Whether we allow overriding existing information.
   */
  void Update(const Var& var,
              const Expr& new_expr,
              bool override = false);

  // Non-copyable: a copy would alias impl_ (and the parent binding).
  CanonicalSimplifier(const CanonicalSimplifier&) = delete;
  CanonicalSimplifier& operator=(const CanonicalSimplifier&) = delete;

 private:
  friend class Analyzer;
  friend class ConstraintContext;
  explicit CanonicalSimplifier(Analyzer* parent);
  ~CanonicalSimplifier();
  class Impl;
  /*! \brief Internal impl */
  Impl* impl_;
};

/*!
 * \brief A RAII constraint context.
 *
 *  The constraint is in effect while the context object is alive; the
 *  stored recovery function is run when the object is destroyed.
 *
 * \code
 *
 *  Var x("x");
 *  arith::Analyzer analyzer;
 *  {
 *    arith::ConstraintContext cctx(&analyzer, x % 3 == 0);
 *    CHECK_EQ(analyzer.modular_set(x)->coeff, 3);
 *  }
 *  // constraint no longer in effect.
 *  CHECK_NE(analyzer.modular_set(x)->coeff, 3);
 *
 * \endcode
 */
class ConstraintContext {
 public:
  /*!
   * \brief Construct a constraint context.
   * \param analyzer The analyzer.
   * \param constraint The constraint to be applied.
   */
  ConstraintContext(Analyzer* analyzer, const Expr& constraint) DMLC_THROW_EXCEPTION;
  /*!
   * \brief Move constructor: takes over the pending recovery function.
   * \param other The context to move from; it will no longer run the recovery.
   */
  ConstraintContext(ConstraintContext&& other) : exit_(std::move(other.exit_)) {
    // A moved-from std::function has an unspecified value; clear it explicitly
    // so only this object runs the recovery.
    other.exit_ = nullptr;
  }
  // Copying would run the same recovery function twice (once per copy).
  ConstraintContext(const ConstraintContext&) = delete;
  ConstraintContext& operator=(const ConstraintContext&) = delete;
  ConstraintContext& operator=(ConstraintContext&&) = delete;
  /*! \brief Destructor: runs the recovery function, if one is set. */
  ~ConstraintContext() DMLC_THROW_EXCEPTION {
    // EnterConstraint hooks may return nullptr, and a moved-from context holds
    // nothing; calling an empty std::function would throw bad_function_call.
    if (exit_ != nullptr) {
      exit_();
    }
  }

 private:
  /*! \brief function to be called in recovery */
  std::function<void()> exit_;
};

/*!
 * \brief Analyzer that contains a bunch of sub-analyzers.
 *
 * Each sub-analyzer can make use of another sub-analyzer
 * by weak reference of this.
 *
 * NOTE for sub-analyzer developers:
 * If the analyzer uses memoization, we need to clear the internal
 * cache when information about a Var has been overridden.
 */
class Analyzer {
 public:
  // NOTE: members are constructed in declaration order; keep this order stable.
  /*! \brief sub-analyzer: const integer bound */
  ConstIntBoundAnalyzer const_int_bound;
  /*! \brief sub-analyzer: modular set */
  ModularSetAnalyzer modular_set;
  /*! \brief sub-analyzer: rewrite simplify */
  RewriteSimplifier rewrite_simplify;
  /*! \brief sub-analyzer: canonical simplify */
  CanonicalSimplifier canonical_simplify;
  /*! \brief constructor */
  Analyzer();
  // Non-copyable: each sub-analyzer is constructed with a pointer to this
  // Analyzer and holds a raw Impl* pointer, so a member-wise copy would
  // alias both.
  Analyzer(const Analyzer&) = delete;
  Analyzer& operator=(const Analyzer&) = delete;
  /*!
   * \brief Notify all the sub-analyzers that var
   *        is created and bound to expr.
   *
   *  Each var can only be bound once.
   *
   * \param var The variable.
   * \param expr The expression we bind to.
   */
  void Bind(const VarExpr& var, const Expr& expr);
  /*!
   * \brief Notify all the sub-analyzers that var
   *        is created and bound to a range.
   *
   *  Each var can only be bound once.
   *
   * \param var The variable.
   * \param range The range we bind to.
   */
  void Bind(const VarExpr& var, const Range& range);
  /*!
   * \brief Whether we can prove expr >= lower_bound.
   *
   *  Non-negative proof is very useful in integer analysis
   *  to lower divisions and mods given difference in trunc and ceil mode.
   *
   * \param expr The expression.
   * \param lower_bound The lower bound.
   * \return Whether we can prove it.
   *
   * \note Analyzer will call into sub-analyzers to get the result.
   */
  bool CanProveGreaterEqual(const Expr& expr, int64_t lower_bound);
};

//-----------------------------------------------
// Integer set abstraction API.
//
// This is an API built on top of the base
// integer analysis API to provide set analysis.
//------------------------------------------------
/*!
 * \brief Sign classification of an expression or an integer set.
 */
enum SignType {
  /*! \brief Positive. */
  kPositive = 0,
  /*! \brief Negative. */
  kNegative = 1,
  /*! \brief Zero. */
  kZero = 2,
  /*! \brief Sign cannot be determined. */
  kUnknown = 3
};

// Internal node container of IntSet (declared here, defined below).
struct IntSetNode;

/*!
 * \brief Integer set class; represents a set of integers in one dimension.
 */
class IntSet : public NodeRef {
 public:
  /*! \brief Default constructor. */
  IntSet() {}
  /*!
   * \brief Construct from an existing node container.
   * \param n The node pointer to wrap.
   */
  explicit IntSet(NodePtr<Node> n) : NodeRef(n) {}
  /*!
   * \brief Access the internal node container.
   * \return Pointer to the internal node container.
   */
  inline const IntSetNode* operator->() const;
  /*!
   * \brief Compute a range that covers the region.
   * \param max_range The range to be covered.
   * \return The covering range.
   */
  Range cover_range(Range max_range) const;
  /*!
   * \brief Compute an interval that covers the set.
   * \return The covering interval set.
   */
  IntSet cover_interval() const;
  /*! \return The lower bound of the set. */
  Expr min() const;
  /*! \return The upper bound of the set. */
  Expr max() const;
  /*! \return Whether the set represents nothing (is empty). */
  bool is_nothing() const;
  /*! \return Whether the set represents everything. */
  bool is_everything() const;
  /*! \return Whether the set is a single point. */
  bool is_single_point() const;
  /*! \return Whether the set is proved to be greater than 0. */
  bool can_prove_positive() const;
  /*! \return Whether the set is proved to be less than 0. */
  bool can_prove_negative() const;
  /*! \return Whether the set is proved to be less than or equal to 0. */
  bool can_prove_non_positive() const;
  /*! \return Whether the set is proved to be greater than or equal to 0. */
  bool can_prove_non_negative() const;
  /*! \return The sign of the elements in the integer set. */
  SignType sign_type() const;
  /*!
   * \brief Get the value of a single-point set.
   * \note Only call this when is_single_point() is true.
   * \return The point value.
   */
  Expr point_value() const;
  /*!
   * \brief Try to match this IntSet with range r.
   *
   * \note It is guaranteed that IntSet::range(r).match_range(r) == true.
   * \return true if we can prove they are the same.
   */
  bool match_range(const Range& r) const;
  /*! \return A set that contains nothing. */
  static IntSet nothing();
  /*! \return A set that contains everything. */
  static IntSet everything();
  /*!
   * \brief Construct a set holding exactly one point.
   * \param point The point in the set.
   * \return The single-point set.
   */
  static IntSet single_point(Expr point);
  /*!
   * \brief Construct an integer set from a vector expression.
   * \param vec The vector expression; can also be a single point.
   * \return The set containing the indices in the vector.
   */
  static IntSet vector(Expr vec);
  /*!
   * \brief Construct a set representing a range.
   * \param r The range.
   * \return The constructed set.
   */
  static IntSet range(Range r);
  /*!
   * \brief Construct a set representing an interval.
   * \param min The minimum value of the interval.
   * \param max The maximum value of the interval.
   * \return The constructed set.
   */
  static IntSet interval(Expr min, Expr max);
};

/*!
 * \brief Common base node for every concrete IntSet container.
 */
struct IntSetNode : public Node {
  /*! \brief Type key used by the node type system. */
  static constexpr const char* _type_key = "IntSet";
  TVM_DECLARE_BASE_NODE_INFO(IntSetNode, Node);
};

/*!
 * \brief Detect whether e can be rewritten as
 *  e = sum_{i=0}^{n-1} var[i] * coeff[i] + coeff[n],
 *  where every coeff[i] (including the constant term coeff[n]) is
 *  invariant of var[j] for all i and j.
 *
 * \param e The expression to be detected.
 * \param vars List of variables to be used in detection.
 * \return [coeff[0], ..., coeff[n]] if possible, an empty array otherwise.
 */
Array<Expr> DetectLinearEquation(const Expr& e, const Array<Var>& vars);

/*!
 * \brief Detect whether the expression corresponds to clip bounds of the vars.
 *
 * \param e The expression to be detected.
 * \param vars List of variables to be used in detection.
 * \return concat([min_value[i], max_value[i]]); None is returned in a slot
 *         when there is no min or max value. Returns an empty array if e
 *         does not match the pattern.
 */
Array<Expr> DetectClipBound(const Expr& e, const Array<Var>& vars);

/*!
 * \brief Find a symbolic integer set that contains all possible values of
 *  e given the domain of each iteration variable.
 *
 * \param e The expression to be evaluated.
 * \param dom_map The domain of each variable.
 * \return An integer set that covers all the possible values of e.
 */
IntSet EvalSet(Expr e,
               const Map<IterVar, IntSet>& dom_map);
/*!
 * \brief Same as EvalSet, but takes an unordered_map.
 *
 * \param e The expression to be evaluated.
 * \param dom_map The domain of each variable.
 * \return An integer set that covers all the possible values of e.
 */
IntSet EvalSet(Expr e,
               const std::unordered_map<const Variable*, IntSet>& dom_map);

/*!
 * \brief Find a symbolic integer set that is the union over
 *  all the possible conditional values in dom_map.
 *
 * \param r The initial range.
 * \param dom_map The domain of each variable.
 * \return An integer set that covers all the possible values.
 */
IntSet EvalSet(Range r,
               const Map<IterVar, IntSet>& dom_map);

/*!
 * \brief Find a symbolic integer set that is the union over
 *  all the possible conditional values in dom_map.
 *
 * \param s The initial set.
 * \param dom_map The domain of each variable.
 * \return An integer set that covers all the possible values.
 */
IntSet EvalSet(IntSet s,
               const std::unordered_map<const Variable*, IntSet>& dom_map);
/*!
 * \brief Same as the Range overload of EvalSet, but takes an unordered_map.
 *
 * \param r The range to be evaluated.
 * \param dom_map The domain of each variable.
 * \return An integer set that covers all the possible values of r.
 */
IntSet EvalSet(Range r,
               const std::unordered_map<const Variable*, IntSet>& dom_map);

/*! \brief Hash map from an expression to its integer set. */
using ExprIntSetMap = std::unordered_map<Expr, IntSet, ExprHash, ExprEqual>;
/*!
 * \brief Find the integer set of every sub-expression of e, given the
 *  domain of each iteration variable.
 *
 * \param e The expression to be evaluated.
 * \param dom_map The domain of each variable.
 * \return Map from each sub-expression to its possible values.
 */
ExprIntSetMap EvalSetForEachSubExpr(
    Expr e,
    const std::unordered_map<const Variable*, IntSet>& dom_map);

/*!
 * \brief Create the union of all the given sets.
 * \param sets The sets to be unioned.
 * \return The set after union.
 */
IntSet Union(const Array<IntSet>& sets);

/*!
 * \brief Create the intersection of all the given sets.
 * \param sets The sets to be intersected.
 * \return The set after intersection.
 */
IntSet Intersect(const Array<IntSet>& sets);

/*!
 * \brief Deduce the bound of the target variable in an expression,
 *  given the domain of each variable. An undefined IntSet is returned
 *  to represent failure.
 *
 * \param v The target variable to be deduced.
 * \param cond The conditional expression.
 * \param hint_map The domain of variables, used to help deduce.
 * \param relax_map The domain of each variable, used to relax the domain.
 *        The deduced bound must imply cond for all values in relax_map.
 * \return An integer set that covers all the possible values.
 */
IntSet DeduceBound(Expr v, Expr cond,
                   const Map<Var, IntSet>& hint_map,
                   const Map<Var, IntSet>& relax_map);
/*!
 * \brief Same as DeduceBound, but with an unordered_map signature.
 *
 * \param v The target variable to be deduced.
 * \param cond The conditional expression.
 * \param hint_map The domain of variables, used to help deduce.
 * \param relax_map The domain of each variable, used to relax the domain.
 *        The deduced bound must imply cond for all values in relax_map.
 * \return An integer set that covers all the possible values.
 */
IntSet DeduceBound(Expr v, Expr cond,
                   const std::unordered_map<const Variable*, IntSet>& hint_map,
                   const std::unordered_map<const Variable*, IntSet>& relax_map);

/*!
 * \brief Infer a regular domain covering every call to, or provide of,
 *  the given tensor within a statement.
 * \param body The statement to inspect.
 * \param tensor The tensor whose calls/provides are tracked.
 * \param consider_calls Whether calls (reads) are considered.
 * \param consider_provides Whether provides (writes) are considered.
 * \return The domain covering all the considered calls/provides in body.
 */
Domain DomainTouched(Stmt body, const Tensor& tensor, bool consider_calls, bool consider_provides);

// Out-of-class definition of IntSet::operator->, placed after IntSetNode
// is a complete type so that the static_cast below is well-formed.
inline const IntSetNode* IntSet::operator->() const {
  const auto* raw_node = node_.get();
  return static_cast<const IntSetNode*>(raw_node);
}
}  // namespace arith
}  // namespace tvm
#endif  // TVM_ARITHMETIC_H_