/*!
 *  Copyright (c) 2017 by Contributors
 * \file compute_expr.h
 * \brief Utility integer expression with quick eager simplification.
 *  This is weaker than Simplify but can be done Eagerly.
 */
#ifndef TVM_ARITHMETIC_COMPUTE_EXPR_H_
#define TVM_ARITHMETIC_COMPUTE_EXPR_H_

#include <tvm/ir.h>
#include <arithmetic/Interval.h>
#include <limits>

namespace tvm {
namespace arith {

using HalideIR::Internal::add_would_overflow;
using HalideIR::Internal::sub_would_overflow;
using HalideIR::Internal::mul_would_overflow;

/*!
 * \brief Build the binary expression OP(lhs, rhs).
 *
 *  This is the generic fallback: it forwards both operands to OP::make
 *  without attempting any simplification. Operator-specific
 *  specializations may simplify before constructing the node.
 *
 * \param lhs The left operand
 * \param rhs The right operand
 * \tparam OP The computation operator; must provide a static make(lhs, rhs).
 * \return The constructed expression.
 */
template<typename OP>
inline Expr ComputeExpr(Expr lhs, Expr rhs) {
  Expr combined = OP::make(lhs, rhs);
  return combined;
}

/*!
 * \brief Compute a reduction with Op
 * \param values The input values.
 * \param empty_value The value to return if values is empty, can be Expr()
 *        which will cause an error to be raised.
 * \tparam Op The computation operator
 * \return The result.
 */
template<typename Op>
inline Expr ComputeReduce(
    const Array<Expr>& values, Expr empty_value);

/*!
 * \brief Try to extract a constant value of type T from an expression.
 *
 *  Only the declaration lives here; explicit specializations for
 *  int64_t and uint64_t are defined in this header.
 *
 * \param e The expression to inspect.
 * \param out Receives the constant value on success.
 * \tparam T The integer type of the constant to extract.
 * \return true if a constant was written to *out, false otherwise.
 */
template<typename T>
inline bool GetConst(Expr e, T* out);

/*!
 * \brief Extract a signed 64-bit constant from a scalar expression.
 * \param e The expression to inspect; vector-typed expressions are rejected.
 * \param out Receives the constant when as_const_int(e) yields a value.
 * \return true if *out was written, false otherwise (*out is left untouched).
 */
template<>
inline bool GetConst<int64_t>(Expr e, int64_t *out) {
  // Vector-typed expressions are never treated as a single constant.
  if (e.type().is_vector()) {
    return false;
  }
  const int64_t *value = as_const_int(e);
  if (value == nullptr) {
    return false;
  }
  *out = *value;
  return true;
}
/*!
 * \brief Extract an unsigned 64-bit constant from a scalar expression.
 * \param e The expression to inspect; vector-typed expressions are rejected.
 * \param out Receives the constant when as_const_uint(e) yields a value.
 * \return true if *out was written, false otherwise (*out is left untouched).
 */
template<>
inline bool GetConst<uint64_t>(Expr e, uint64_t *out) {
  // Vector-typed expressions are never treated as a single constant.
  if (e.type().is_vector()) {
    return false;
  }
  const uint64_t *value = as_const_uint(e);
  if (value == nullptr) {
    return false;
  }
  *out = *value;
  return true;
}

/*!
 * \brief Try to read a scalar integer constant that fits in a plain int.
 *
 *  The expression is first probed as a signed 64-bit constant and then as
 *  an unsigned 64-bit constant. A constant whose value lies outside
 *  [std::numeric_limits<int>::min(), std::numeric_limits<int>::max()]
 *  is rejected rather than narrowed.
 *
 * \param e The expression to inspect.
 * \param out Receives the value on success; left untouched on failure.
 * \return true if e is a constant representable as int, false otherwise.
 */
inline bool GetConstInt(Expr e, int* out) {
  int64_t v1 = 0;
  uint64_t v2 = 0;
  if (GetConst(e, &v1)) {
    // Check both ends of the int range: a signed constant may be too
    // negative as well as too large, and narrowing it would report
    // success with a wrong value.
    if (v1 > static_cast<int64_t>(std::numeric_limits<int>::max()) ||
        v1 < static_cast<int64_t>(std::numeric_limits<int>::min())) {
      return false;
    }
    *out = static_cast<int>(v1);
    return true;
  }
  if (GetConst(e, &v2)) {
    // Unsigned values are never negative, so only the upper bound matters.
    if (v2 > static_cast<uint64_t>(std::numeric_limits<int>::max())) {
      return false;
    }
    *out = static_cast<int>(v2);
    return true;
  }
  return false;
}

/*!
 * \brief Constant-folding prologue shared by the binary ComputeExpr
 *        specializations.
 *
 *  Must be expanded inside a function that has Expr variables named `a`
 *  and `b` in scope and that returns Expr. The expansion declares the
 *  locals ia, ib, ua and ub in the enclosing scope.
 *
 *  - If GetConst yields signed 64-bit constants for both operands, the
 *    function OP_NAME##_would_overflow is consulted with the bit width of
 *    a's type; an overflow is reported through LOG(FATAL). Otherwise the
 *    enclosing function returns an IntImm of a's type holding `ia OP ib`.
 *  - Else, if GetConst yields unsigned 64-bit constants for both operands,
 *    the enclosing function returns a UIntImm of a's type holding
 *    `ua OP ub`. No overflow check is made on this path.
 *  - Otherwise control falls through to the code after the macro.
 *
 *  NOTE: the last line of the definition ends with a line continuation,
 *  so the line immediately following the macro must stay blank.
 *
 * \param OP_NAME Prefix pasted onto `_would_overflow` (e.g. add, sub, mul).
 * \param OP The C++ binary operator token applied to the constants.
 */
#define TVM_CONST_PROPAGATION(OP_NAME, OP)                       \
  int64_t ia = 0, ib = 0;                                        \
  if (GetConst(a, &ia) && GetConst(b, &ib)) {                    \
    if (OP_NAME ## _would_overflow(a.type().bits(), ia, ib)) {   \
      LOG(FATAL) << "signed int overflow";                       \
    }                                                            \
    return ir::IntImm::make(a.type(), ia OP ib);                 \
  }                                                              \
  uint64_t ua = 0, ub = 0;                                       \
  if (GetConst(a, &ua) && GetConst(b, &ub)) {                    \
    return ir::UIntImm::make(a.type(), ua OP ub);                \
  }                                                              \

/*!
 * \brief Addition with eager simplification.
 * \param a The left operand
 * \param b The right operand
 * \return b if is_zero(a), a if is_zero(b), a folded constant when
 *         TVM_CONST_PROPAGATION applies, otherwise ir::Add::make(a, b).
 */
template<>
inline Expr ComputeExpr<ir::Add>(Expr a, Expr b) {
  // A zero operand contributes nothing: hand back the other side as-is.
  if (is_zero(a)) {
    return b;
  }
  if (is_zero(b)) {
    return a;
  }
  // May return early with a folded constant.
  TVM_CONST_PROPAGATION(add, +);
  return ir::Add::make(a, b);
}

/*!
 * \brief Subtraction with eager simplification.
 * \param a The left operand
 * \param b The right operand
 * \return a if is_zero(b), a folded constant when TVM_CONST_PROPAGATION
 *         applies, otherwise ir::Sub::make(a, b).
 */
template<>
inline Expr ComputeExpr<ir::Sub>(Expr a, Expr b) {
  // Subtracting zero leaves the left operand unchanged.
  if (is_zero(b)) {
    return a;
  }
  // May return early with a folded constant.
  TVM_CONST_PROPAGATION(sub, -);
  return ir::Sub::make(a, b);
}

/*!
 * \brief Multiplication with eager simplification.
 * \param a The left operand
 * \param b The right operand
 * \return b if is_one(a), a if is_one(b), a folded constant when
 *         TVM_CONST_PROPAGATION applies, otherwise ir::Mul::make(a, b).
 */
template<>
inline Expr ComputeExpr<ir::Mul>(Expr a, Expr b) {
  // A factor of one is the identity: hand back the other side as-is.
  if (is_one(a)) {
    return b;
  }
  if (is_one(b)) {
    return a;
  }
  // May return early with a folded constant.
  TVM_CONST_PROPAGATION(mul, *);
  return ir::Mul::make(a, b);
}

/*!
 * \brief Division with eager simplification.
 * \param a The dividend
 * \param b The divisor
 * \return a if is_one(b), otherwise ir::Div::make(a, b).
 */
template<>
inline Expr ComputeExpr<ir::Div>(Expr a, Expr b) {
  if (!is_one(b)) {
    return ir::Div::make(a, b);
  }
  // Dividing by one leaves the dividend unchanged.
  return a;
}

/*!
 * \brief Modulo with eager simplification.
 * \param a The dividend
 * \param b The divisor
 * \return make_zero(a.type()) if is_zero(a), otherwise ir::Mod::make(a, b).
 */
template<>
inline Expr ComputeExpr<ir::Mod>(Expr a, Expr b) {
  if (!is_zero(a)) {
    return ir::Mod::make(a, b);
  }
  // A zero dividend yields a zero of the dividend's type.
  return make_zero(a.type());
}

/*!
 * \brief Maximum of two expressions.
 * \param a The left operand
 * \param b The right operand
 * \return The result of HalideIR::Internal::Interval::make_max(a, b).
 */
template<>
inline Expr ComputeExpr<ir::Max>(Expr a, Expr b) {
  Expr upper = HalideIR::Internal::Interval::make_max(a, b);
  return upper;
}

/*!
 * \brief Minimum of two expressions.
 * \param a The left operand
 * \param b The right operand
 * \return The result of HalideIR::Internal::Interval::make_min(a, b).
 */
template<>
inline Expr ComputeExpr<ir::Min>(Expr a, Expr b) {
  Expr lower = HalideIR::Internal::Interval::make_min(a, b);
  return lower;
}

/*!
 * \brief Left-fold the values with ComputeExpr<Op>.
 * \param values The input values, combined in index order starting at 0.
 * \param empty_value Returned when values is empty; it must be defined in
 *        that case, otherwise the CHECK fails.
 * \tparam Op The computation operator
 * \return empty_value for an empty array, otherwise the folded expression.
 */
template<typename Op>
inline Expr ComputeReduce(const Array<Expr>& values, Expr empty_value) {
  const size_t count = values.size();
  if (count == 0U) {
    CHECK(empty_value.defined());
    return empty_value;
  }
  // Seed the accumulator with the first element, then fold in the rest.
  Expr acc = values[0];
  size_t idx = 1;
  while (idx < count) {
    acc = ComputeExpr<Op>(acc, values[idx]);
    ++idx;
  }
  return acc;
}

}  // namespace arith
}  // namespace tvm
#endif   // TVM_ARITHMETIC_COMPUTE_EXPR_H_