/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file const_fold.h
 * \brief Centralized location for constant folding.
 */
#ifndef TVM_ARITH_CONST_FOLD_H_
#define TVM_ARITH_CONST_FOLD_H_

#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>
#include <algorithm>
#include <cmath>
#include "int_operator.h"

namespace tvm {
namespace arith {

/*!
 * \brief Try to run binary compute with constant folding.
 *
 * The primary template never folds anything; folding is implemented by the
 * explicit specializations for individual operator node types.
 *
 * \param a The left operand.
 * \param b The right operand.
 * \tparam Op The operator type.
 *
 * \note a and b must already have matching data types.
 * \return An undefined (null) PrimExpr if constant folding fails,
 *         otherwise the folded result.
 */
template<typename Op>
inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) {
  return PrimExpr();
}

/*!
 * \brief Try to run unary compute with constant folding.
 *
 * Only declared here: this header provides no primary definition, so each
 * operator that supports unary folding needs an explicit specialization.
 *
 * \param a The operand.
 * \tparam Op The operator type.
 *
 * \return An undefined (null) PrimExpr if constant folding fails,
 *         otherwise the folded result.
 */
template<typename Op>
inline PrimExpr TryConstFold(PrimExpr a);

/*!
 * \brief Decide whether a data type is one used to represent indices.
 *
 * Index types appear constantly in shape computations, so folding them
 * aggressively pays off. Only scalar (single-lane) signed integers that are
 * 32 or 64 bits wide qualify.
 *
 * \param type The data type to test.
 * \return true if type is a scalar int32 or int64.
 */
inline bool IsIndexType(const DataType& type) {
  // Vectors and non-signed-integer types are never index types.
  if (!type.is_int() || type.lanes() != 1) {
    return false;
  }
  const auto width = type.bits();
  return width == 32 || width == 64;
}


/*!
 * \brief Expose literal views of the operands `a` and `b`, then expand BODY.
 *
 * Expects PrimExpr variables named `a` and `b` in the enclosing scope and
 * declares the following for use inside BODY:
 *  - pa / pb: `a` / `b` viewed as IntImmNode (null when not an IntImm);
 *  - fa / fb: `a` / `b` viewed as FloatImmNode (null when not a FloatImm).
 *
 * BODY is expanded unconditionally. At the use sites it either returns a
 * folded expression or falls through to the caller's own fallback.
 */
#define TVM_ARITH_CONST_PROPAGATION(BODY)                               \
  using tir::FloatImmNode;                                               \
  const IntImmNode* pa = a.as<IntImmNode>();                            \
  const IntImmNode* pb = b.as<IntImmNode>();                            \
  const FloatImmNode* fa = a.as<FloatImmNode>();                        \
  const FloatImmNode* fb = b.as<FloatImmNode>();                        \
  BODY;


/*!
 * \brief Expose integer-literal views of `a` and `b`; run BODY only for index types.
 *
 * Expects PrimExpr variables named `a` and `b` in the enclosing scope and
 * declares pa / pb (`a` / `b` viewed as IntImmNode, null when not an IntImm)
 * and ta / tb (their data types). BODY runs only when both operand types
 * satisfy arith::IsIndexType (scalar int32 / int64); otherwise nothing
 * happens and control falls through.
 *
 * The final line deliberately carries no trailing backslash, so the macro
 * cannot accidentally absorb whatever line follows the definition.
 */
#define TVM_INDEX_CONST_PROPAGATION(BODY)                               \
  const IntImmNode* pa = a.as<IntImmNode>();                            \
  const IntImmNode* pb = b.as<IntImmNode>();                            \
  const DataType& ta = a.dtype();                                       \
  const DataType& tb = b.dtype();                                       \
  if (arith::IsIndexType(ta) && arith::IsIndexType(tb)) {               \
    BODY;                                                               \
  }


// Specializations of TryConstFold for individual operators.
// Fold a + b: two literals are summed, and a literal zero on either side is
// dropped (x + 0 -> x, 0 + x -> x).
template<>
inline PrimExpr TryConstFold<tir::AddNode>(PrimExpr a, PrimExpr b) {
  TVM_ARITH_CONST_PROPAGATION({
      const DataType& out_type = a.dtype();
      // Integer literals.
      if (pa != nullptr) {
        if (pb != nullptr) return IntImm(out_type, pa->value + pb->value);
        if (pa->value == 0) return b;
      }
      if (pb != nullptr && pb->value == 0) return a;
      // Float literals, same rules.
      if (fa != nullptr) {
        if (fb != nullptr) return FloatImm(out_type, fa->value + fb->value);
        if (fa->value == 0) return b;
      }
      if (fb != nullptr && fb->value == 0) return a;
    });
  return PrimExpr();
}

// Fold a - b: two literals are subtracted, and subtracting a literal zero
// (x - 0) yields the left operand unchanged.
template<>
inline PrimExpr TryConstFold<tir::SubNode>(PrimExpr a, PrimExpr b) {
  TVM_ARITH_CONST_PROPAGATION({
      const DataType& out_type = a.dtype();
      // Integer right operand.
      if (pb != nullptr) {
        if (pa != nullptr) return IntImm(out_type, pa->value - pb->value);
        if (pb->value == 0) return a;
      }
      // Float right operand.
      if (fb != nullptr) {
        if (fa != nullptr) return FloatImm(out_type, fa->value - fb->value);
        if (fb->value == 0) return a;
      }
    });
  return PrimExpr();
}

// Fold a * b: two literals are multiplied; a literal one is the identity and
// a literal zero annihilates (the zero literal itself is returned).
template<>
inline PrimExpr TryConstFold<tir::MulNode>(PrimExpr a, PrimExpr b) {
  TVM_ARITH_CONST_PROPAGATION({
      const DataType& out_type = a.dtype();
      // Integer literals.
      if (pa != nullptr && pb != nullptr) {
        return IntImm(out_type, pa->value * pb->value);
      }
      // 1 * x -> x, 0 * x -> 0
      if (pa != nullptr && (pa->value == 1 || pa->value == 0)) {
        return pa->value == 1 ? b : a;
      }
      // x * 1 -> x, x * 0 -> 0
      if (pb != nullptr && (pb->value == 1 || pb->value == 0)) {
        return pb->value == 1 ? a : b;
      }
      // Float literals, same rules.
      if (fa != nullptr && fb != nullptr) {
        return FloatImm(out_type, fa->value * fb->value);
      }
      if (fa != nullptr && (fa->value == 1 || fa->value == 0)) {
        return fa->value == 1 ? b : a;
      }
      if (fb != nullptr && (fb->value == 1 || fb->value == 0)) {
        return fb->value == 1 ? a : b;
      }
    });
  return PrimExpr();
}

// Fold a / b (truncating division for integers). A literal zero divisor is
// rejected with "Divide by zero" for both integer and float operands.
template<>
inline PrimExpr TryConstFold<tir::DivNode>(PrimExpr a, PrimExpr b) {
  TVM_ARITH_CONST_PROPAGATION({
      const DataType& rtype = a.dtype();
      if (pa && pb) {
        // Division and modulo come in several rounding modes; DivNode is
        // folded with C++ integer division, which truncates toward zero.
        CHECK_NE(pb->value, 0) << "Divide by zero";
        return IntImm(rtype, pa->value / pb->value);
      }
      if (pa) {
        // 0 / x -> 0
        if (pa->value == 0) return a;
      }
      if (pb) {
        // x / 1 -> x
        if (pb->value == 1) return a;
        CHECK_NE(pb->value, 0) << "Divide by zero";
      }
      if (fa && fb) {
        // Check the divisor here rather than letting 0.0 / 0.0 fall through
        // to the "0 / x -> 0" rule below and fold silently; this mirrors the
        // integer path above.
        CHECK_NE(fb->value, 0) << "Divide by zero";
        return FloatImm(rtype, fa->value / fb->value);
      }
      if (fa && fa->value == 0) return a;
      if (fb) {
        if (fb->value == 1) return a;
        CHECK_NE(fb->value, 0) << "Divide by zero";
      }
    });
  return PrimExpr();
}

// Fold a % b (C++ truncating remainder) for index-typed operands only.
template<>
inline PrimExpr TryConstFold<tir::ModNode>(PrimExpr a, PrimExpr b) {
  TVM_INDEX_CONST_PROPAGATION({
      const DataType& rtype = a.dtype();
      if (pa && pb) {
        // Must be checked before evaluating: integer % 0 is undefined
        // behavior in C++ (typically a hardware trap).
        CHECK_NE(pb->value, 0) << "Divide by zero";
        return IntImm(rtype, pa->value % pb->value);
      }
      if (pa) {
        // 0 % x -> 0
        if (pa->value == 0) return a;
      }
      if (pb) {
        // x % 1 -> 0
        if (pb->value == 1) return tir::make_zero(rtype);
        CHECK_NE(pb->value, 0) << "Divide by zero";
      }
    });
  return PrimExpr();
}

// Fold floordiv(a, b): integers use arith::floordiv, floats use
// std::floor(a / b). A literal zero divisor is rejected with
// "Divide by zero" for both integer and float operands.
template<>
inline PrimExpr TryConstFold<tir::FloorDivNode>(PrimExpr a, PrimExpr b) {
  TVM_ARITH_CONST_PROPAGATION({
      const DataType& rtype = a.dtype();
      if (pa && pb) {
        CHECK_NE(pb->value, 0) << "Divide by zero";
        return IntImm(rtype, arith::floordiv(pa->value, pb->value));
      }
      if (pa) {
        // floordiv(0, x) -> 0
        if (pa->value == 0) return a;
      }
      if (pb) {
        // floordiv(x, 1) -> x
        if (pb->value == 1) return a;
        CHECK_NE(pb->value, 0) << "Divide by zero";
      }
      if (fa && fb) {
        // Check the divisor here so 0.0 / 0.0 cannot fall through to the
        // "0 / x -> 0" rule below; this mirrors the integer path above.
        CHECK_NE(fb->value, 0) << "Divide by zero";
        return FloatImm(rtype, std::floor(fa->value / fb->value));
      }
      if (fa && fa->value == 0) return a;
      if (fb) {
        if (fb->value == 1) return a;
        CHECK_NE(fb->value, 0) << "Divide by zero";
      }
    });
  return PrimExpr();
}

// Fold floormod(a, b) for index-typed operands only.
template<>
inline PrimExpr TryConstFold<tir::FloorModNode>(PrimExpr a, PrimExpr b) {
  TVM_INDEX_CONST_PROPAGATION({
      const DataType& rtype = a.dtype();
      if (pa && pb) {
        // Reject a zero divisor before evaluating; the integer helper would
        // otherwise perform a modulo by zero (undefined behavior).
        CHECK_NE(pb->value, 0) << "Divide by zero";
        return IntImm(rtype, arith::floormod(pa->value, pb->value));
      }
      if (pa) {
        // floormod(0, x) -> 0
        if (pa->value == 0) return a;
      }
      if (pb) {
        // floormod(x, 1) -> 0
        if (pb->value == 1) return tir::make_zero(rtype);
        CHECK_NE(pb->value, 0) << "Divide by zero";
      }
    });
  return PrimExpr();
}

// Fold min(a, b): two literals of the same kind fold to the smaller one, and
// min(x, x) folds to x when both sides are the very same node.
template<>
inline PrimExpr TryConstFold<tir::MinNode>(PrimExpr a, PrimExpr b) {
  TVM_ARITH_CONST_PROPAGATION({
      const bool both_int = pa != nullptr && pb != nullptr;
      const bool both_float = fa != nullptr && fb != nullptr;
      if (both_int) return IntImm(a.dtype(), std::min(pa->value, pb->value));
      if (both_float) return FloatImm(a.dtype(), std::min(fa->value, fb->value));
    });
  return a.same_as(b) ? a : PrimExpr();
}

// Fold max(a, b): two literals of the same kind fold to the larger one, and
// max(x, x) folds to x when both sides are the very same node.
template<>
inline PrimExpr TryConstFold<tir::MaxNode>(PrimExpr a, PrimExpr b) {
  TVM_ARITH_CONST_PROPAGATION({
      const bool both_int = pa != nullptr && pb != nullptr;
      const bool both_float = fa != nullptr && fb != nullptr;
      if (both_int) return IntImm(a.dtype(), std::max(pa->value, pb->value));
      if (both_float) return FloatImm(a.dtype(), std::max(fa->value, fb->value));
    });
  return a.same_as(b) ? a : PrimExpr();
}

// Fold a > b when both operands are literals of the same kind.
template<>
inline PrimExpr TryConstFold<tir::GTNode>(PrimExpr a, PrimExpr b) {
  TVM_ARITH_CONST_PROPAGATION({
      // Comparisons fold to a 1-bit unsigned integer literal.
      const DataType bool_type = DataType::UInt(1);
      if (pa != nullptr && pb != nullptr) {
        const bool truth = pa->value > pb->value;
        return IntImm(bool_type, truth);
      }
      if (fa != nullptr && fb != nullptr) {
        const bool truth = fa->value > fb->value;
        return IntImm(bool_type, truth);
      }
    });
  return PrimExpr();
}

// Fold a >= b when both operands are literals of the same kind.
template<>
inline PrimExpr TryConstFold<tir::GENode>(PrimExpr a, PrimExpr b) {
  TVM_ARITH_CONST_PROPAGATION({
      // Comparisons fold to a 1-bit unsigned integer literal.
      const DataType bool_type = DataType::UInt(1);
      if (pa != nullptr && pb != nullptr) {
        const bool truth = pa->value >= pb->value;
        return IntImm(bool_type, truth);
      }
      if (fa != nullptr && fb != nullptr) {
        const bool truth = fa->value >= fb->value;
        return IntImm(bool_type, truth);
      }
    });
  return PrimExpr();
}

// Fold a < b when both operands are literals of the same kind.
template<>
inline PrimExpr TryConstFold<tir::LTNode>(PrimExpr a, PrimExpr b) {
  TVM_ARITH_CONST_PROPAGATION({
      // Comparisons fold to a 1-bit unsigned integer literal.
      const DataType bool_type = DataType::UInt(1);
      if (pa != nullptr && pb != nullptr) {
        const bool truth = pa->value < pb->value;
        return IntImm(bool_type, truth);
      }
      if (fa != nullptr && fb != nullptr) {
        const bool truth = fa->value < fb->value;
        return IntImm(bool_type, truth);
      }
    });
  return PrimExpr();
}

// Fold a <= b when both operands are literals of the same kind.
template<>
inline PrimExpr TryConstFold<tir::LENode>(PrimExpr a, PrimExpr b) {
  TVM_ARITH_CONST_PROPAGATION({
      // Comparisons fold to a 1-bit unsigned integer literal.
      const DataType bool_type = DataType::UInt(1);
      if (pa != nullptr && pb != nullptr) {
        const bool truth = pa->value <= pb->value;
        return IntImm(bool_type, truth);
      }
      if (fa != nullptr && fb != nullptr) {
        const bool truth = fa->value <= fb->value;
        return IntImm(bool_type, truth);
      }
    });
  return PrimExpr();
}

// Fold a == b when both operands are literals of the same kind.
template<>
inline PrimExpr TryConstFold<tir::EQNode>(PrimExpr a, PrimExpr b) {
  TVM_ARITH_CONST_PROPAGATION({
      // Comparisons fold to a 1-bit unsigned integer literal.
      const DataType bool_type = DataType::UInt(1);
      if (pa != nullptr && pb != nullptr) {
        const bool truth = pa->value == pb->value;
        return IntImm(bool_type, truth);
      }
      if (fa != nullptr && fb != nullptr) {
        const bool truth = fa->value == fb->value;
        return IntImm(bool_type, truth);
      }
    });
  return PrimExpr();
}

// Fold a != b when both operands are literals of the same kind.
template<>
inline PrimExpr TryConstFold<tir::NENode>(PrimExpr a, PrimExpr b) {
  TVM_ARITH_CONST_PROPAGATION({
      // Comparisons fold to a 1-bit unsigned integer literal.
      const DataType bool_type = DataType::UInt(1);
      if (pa != nullptr && pb != nullptr) {
        const bool truth = pa->value != pb->value;
        return IntImm(bool_type, truth);
      }
      if (fa != nullptr && fb != nullptr) {
        const bool truth = fa->value != fb->value;
        return IntImm(bool_type, truth);
      }
    });
  return PrimExpr();
}

// Fold a && b when either side is an integer literal:
// true && x -> x, false && x -> false (and symmetrically for the right side).
template<>
inline PrimExpr TryConstFold<tir::AndNode>(PrimExpr a, PrimExpr b) {
  if (const IntImmNode* lhs = a.as<IntImmNode>()) {
    return lhs->value != 0 ? b : a;
  }
  if (const IntImmNode* rhs = b.as<IntImmNode>()) {
    return rhs->value != 0 ? a : b;
  }
  return PrimExpr();
}

// Fold a || b when either side is an integer literal:
// true || x -> true, false || x -> x (and symmetrically for the right side).
template<>
inline PrimExpr TryConstFold<tir::OrNode>(PrimExpr a, PrimExpr b) {
  if (const IntImmNode* lhs = a.as<IntImmNode>()) {
    return lhs->value != 0 ? a : b;
  }
  if (const IntImmNode* rhs = b.as<IntImmNode>()) {
    return rhs->value != 0 ? b : a;
  }
  return PrimExpr();
}

// Fold !a when a is an integer literal; the result is a 1-bit unsigned
// literal that is 1 exactly when a is zero.
template<>
inline PrimExpr TryConstFold<tir::NotNode>(PrimExpr a) {
  const IntImmNode* operand = a.as<IntImmNode>();
  if (operand == nullptr) {
    return PrimExpr();
  }
  return IntImm(DataType::UInt(1), operand->value == 0);
}

/*!
 * \brief Holder for the opaque expressions that stand in for symbolic limits.
 *
 * The static members are only declared here; their definitions are not in
 * this header (presumably in a .cc file — confirm).
 */
struct SymbolicLimits {
  /*! \brief positive infinity */
  static PrimExpr pos_inf_;
  /*! \brief negative infinity */
  static PrimExpr neg_inf_;
};

/*!
 * \brief Opaque expression representing positive infinity.
 *
 *  It can only be used as an argument to min/max in integer analysis;
 *  it must not appear in ordinary expressions.
 *
 * \return positive infinity.
 */
inline PrimExpr pos_inf() {
  return SymbolicLimits::pos_inf_;
}

/*!
 * \brief Check if value is positive infinity.
 * \param value The value to be checked.
 *
 * \return The check result.
 */
inline bool is_pos_inf(const PrimExpr& value) {
  return value.same_as(SymbolicLimits::pos_inf_);
}

/*!
 * \brief Opaque expression representing negative infinity.
 *
 *  It can only be used as an argument to min/max in integer analysis;
 *  it must not appear in ordinary expressions.
 *
 * \return negative infinity.
 */
inline PrimExpr neg_inf() {
  return SymbolicLimits::neg_inf_;
}

/*!
 * \brief Check if value is negative infinity.
 * \param value The value to be checked.
 *
 * \return The check result.
 */
inline bool is_neg_inf(const PrimExpr& value) {
  return value.same_as(SymbolicLimits::neg_inf_);
}

}  // namespace arith
}  // namespace tvm
#endif  // TVM_ARITH_CONST_FOLD_H_