/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 * 
 *   http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 *  Copyright (c) 2017 by Contributors
 * \file expr_operator.cc
 */
#include <tvm/base.h>
#include <tvm/ir.h>
#include <tvm/expr_operator.h>
#include <cmath>
#include <cstdint>
// Centralized header for constant folders.
#include "../arithmetic/const_fold.h"

namespace tvm {

// Convert `value` to type `t` by wrapping it in a Cast node, unless it already
// has that exact type. Unlike cast(), no constant folding or broadcasting is
// attempted here.
inline Expr SimpleCast(const Type& t, Expr value) {
  return value.type() == t ? value : ir::Cast::make(t, value);
}

// Coerce `lhs` and `rhs` in place so that both operands end up with the same
// type (public entry point; returns immediately when types already match).
//
// Lanes: a scalar operand is broadcast to the other operand's lane count;
// two vectors with different lane counts are rejected.
//
// Element types: only simple conversions are performed. This keeps the
// amount of code generated by operators small and helps users notice
// potential type-conversion problems:
//   - non-float mixed with float   -> the non-float side is cast to the float type
//   - int with int / uint with uint -> the narrower side is widened
//   - int mixed with uint          -> both become a signed int of the wider width
// Any other combination is a fatal error.
void BinaryOpMatchTypes(Expr& lhs, Expr& rhs) {  // NOLINT(*)
  if (lhs.type() == rhs.type()) return;
  const Type ltype = lhs.type();
  const Type rtype = rhs.type();

  // Step 1: reconcile the lane counts.
  const bool lhs_scalar = ltype.lanes() == 1;
  const bool rhs_scalar = rtype.lanes() == 1;
  if (lhs_scalar && !rhs_scalar) {
    lhs = ir::Broadcast::make(lhs, rtype.lanes());
  } else if (rhs_scalar && !lhs_scalar) {
    rhs = ir::Broadcast::make(rhs, ltype.lanes());
  } else {
    CHECK(ltype.lanes() == rtype.lanes())
        << "Cannot match type " << ltype << " vs " << rtype;
  }

  // Step 2: reconcile the element types.
  const Type lt = lhs.type();
  const Type rt = rhs.type();
  if (lt == rt) return;
  const bool l_float = lt.is_float();
  const bool r_float = rt.is_float();
  if (r_float && !l_float) {
    lhs = cast(rt, lhs);
  } else if (l_float && !r_float) {
    rhs = cast(lt, rhs);
  } else if ((lt.is_int() && rt.is_int()) || (lt.is_uint() && rt.is_uint())) {
    // Same signedness: widen whichever side has fewer bits.
    if (lt.bits() < rt.bits()) {
      lhs = cast(rt, lhs);
    } else {
      rhs = cast(lt, rhs);
    }
  } else if ((lt.is_int() && rt.is_uint()) || (lt.is_uint() && rt.is_int())) {
    // Mixed signedness: settle on a signed int wide enough for either side.
    const int wider = std::max(lt.bits(), rt.bits());
    lhs = SimpleCast(Int(wider, lt.lanes()), lhs);
    rhs = SimpleCast(Int(wider, rt.lanes()), rhs);
  } else {
    LOG(FATAL) << "Cannot match type " << ltype << " vs " << rtype;
  }
}


// Report whether `val` is a positive power of two.
// For any positive `val`, shift[0] receives the number of trailing zero bits,
// which is log2(val) when the answer is true. For non-positive `val`, false is
// returned and shift[0] is left untouched.
template<typename ValueType>
inline bool ConstPowerHelper(ValueType val, int *shift) {
  if (val <= 0) return false;
  int trailing_zeros = 0;
  // val > 0 guarantees a set bit, so this loop terminates.
  while ((val & 1) == 0) {
    val = val >> 1;
    ++trailing_zeros;
  }
  shift[0] = trailing_zeros;
  // After stripping trailing zeros, a power of two collapses to exactly 1.
  return val == 1;
}

// Return true when `x` is an integer literal (signed or unsigned) equal to a
// positive power of two; see ConstPowerHelper for how `shift` is filled in.
bool is_const_power_of_two_integer(const Expr& x, int* shift) {
  if (const auto* signed_imm = x.as<ir::IntImm>()) {
    return ConstPowerHelper(signed_imm->value, shift);
  }
  if (const auto* unsigned_imm = x.as<ir::UIntImm>()) {
    return ConstPowerHelper(unsigned_imm->value, shift);
  }
  // Non-literal expressions are never treated as constant powers of two.
  return false;
}

// Convert `value` to type `t`.
// Scalar target: IntImm/FloatImm literals are folded into a new constant of
// type `t` (they are common in index computations); anything else becomes a
// Cast node.
// Vector target: a scalar source is converted to the element type (with the
// same folding) and then broadcast; a vector source must already have the
// same lane count and is wrapped in a Cast node.
Expr cast(const Type& t, Expr value) {
  using ir::IntImm;
  using ir::FloatImm;
  if (value.type() == t) return value;

  if (t.lanes() != 1) {
    if (value.type().lanes() != 1) {
      CHECK(value.type().lanes() == t.lanes());
      return ir::Cast::make(t, value);
    }
    // Scalar -> vector: reuse the scalar path (element_of() has one lane),
    // which also folds literals, then replicate across lanes.
    return ir::Broadcast::make(cast(t.element_of(), value), t.lanes());
  }

  if (const IntImm* int_imm = value.as<IntImm>()) {
    return make_const(t, int_imm->value);
  }
  if (const FloatImm* float_imm = value.as<FloatImm>()) {
    return make_const(t, float_imm->value);
  }
  return ir::Cast::make(t, value);
}

// Wrap `value` in a pure `reinterpret` intrinsic call producing type `t`;
// returns `value` itself when it already has type `t`.
Expr reinterpret(const Type& t, Expr value) {
  if (value.type() != t) {
    return ir::Call::make(t, ir::Call::reinterpret, {value}, ir::Call::PureIntrinsic);
  }
  return value;
}

// Addition: unify operand types, try constant folding, and otherwise build an
// Add node.
Expr operator+(Expr a, Expr b) {
  BinaryOpMatchTypes(a, b);
  Expr folded = arith::TryConstFold<ir::Add>(a, b);
  return folded.defined() ? folded : ir::Add::make(a, b);
}

// Unary negation.
// IntImm and FloatImm literals are folded directly. Any other expression
// becomes `0 - a` through the binary operator-.
Expr operator-(Expr a) {
  using ir::IntImm;
  using ir::FloatImm;
  if (const IntImm* pa = a.as<IntImm>()) {
    // Negate in uint64_t: `-pa->value` is signed-overflow UB when the value is
    // INT64_MIN, whereas unsigned arithmetic wraps (two's complement).
    const uint64_t bits = static_cast<uint64_t>(pa->value);
    return IntImm::make(a.type(), static_cast<int64_t>(uint64_t(0) - bits));
  }
  if (const FloatImm* fa = a.as<FloatImm>()) {
    return FloatImm::make(a.type(), -fa->value);
  }
  return make_zero(a.type()) - a;
}

// The arithmetic operators below all follow one recipe: unify the operand
// types with BinaryOpMatchTypes, attempt constant folding with
// arith::TryConstFold, and only build a new IR node when folding does not
// produce a result.

Expr operator-(Expr a, Expr b) {
  BinaryOpMatchTypes(a, b);
  Expr folded = arith::TryConstFold<ir::Sub>(a, b);
  return folded.defined() ? folded : ir::Sub::make(a, b);
}

Expr operator*(Expr a, Expr b) {
  BinaryOpMatchTypes(a, b);
  Expr folded = arith::TryConstFold<ir::Mul>(a, b);
  return folded.defined() ? folded : ir::Mul::make(a, b);
}

Expr operator/(Expr a, Expr b) {
  BinaryOpMatchTypes(a, b);
  Expr folded = arith::TryConstFold<ir::Div>(a, b);
  return folded.defined() ? folded : ir::Div::make(a, b);
}

Expr operator%(Expr a, Expr b) {
  BinaryOpMatchTypes(a, b);
  Expr folded = arith::TryConstFold<ir::Mod>(a, b);
  return folded.defined() ? folded : ir::Mod::make(a, b);
}

// Element-wise minimum of two expressions.
Expr min(Expr a, Expr b) {
  BinaryOpMatchTypes(a, b);
  Expr folded = arith::TryConstFold<ir::Min>(a, b);
  return folded.defined() ? folded : ir::Min::make(a, b);
}

// Element-wise maximum of two expressions.
Expr max(Expr a, Expr b) {
  BinaryOpMatchTypes(a, b);
  Expr folded = arith::TryConstFold<ir::Max>(a, b);
  return folded.defined() ? folded : ir::Max::make(a, b);
}

// Build a `tvm_if_then_else` intrinsic selecting between two values.
// `cond` must be a single boolean. The two branch values are unified to a
// common type. When `cond` is an integer literal, the matching branch is
// returned directly and no call node is built.
Expr if_then_else(Expr cond, Expr true_value, Expr false_value) {
  CHECK(cond.type() == Bool(1))
      << "if_then_else only accept a single condition";
  BinaryOpMatchTypes(true_value, false_value);
  if (const ir::UIntImm* uimm = cond.as<ir::UIntImm>()) {
    return uimm->value != 0 ? true_value : false_value;
  }
  if (const ir::IntImm* simm = cond.as<ir::IntImm>()) {
    return simm->value != 0 ? true_value : false_value;
  }
  return ir::Call::make(true_value.type(),
                        ir::intrinsic::tvm_if_then_else,
                        {cond, true_value, false_value},
                        ir::Call::PureIntrinsic);
}

// Wrap `cond` in a pure `likely` hint intrinsic. Constant conditions are
// returned unchanged.
Expr likely(Expr cond) {
  if (!is_const(cond)) {
    return ir::Call::make(cond.type(), ir::Call::likely, {cond}, ir::Call::PureIntrinsic);
  }
  return cond;
}

// Comparison operators. Each one unifies the operand types, tries constant
// folding, and otherwise emits the corresponding comparison node.

Expr operator>(Expr a, Expr b) {
  BinaryOpMatchTypes(a, b);
  Expr folded = arith::TryConstFold<ir::GT>(a, b);
  return folded.defined() ? folded : ir::GT::make(a, b);
}

Expr operator>=(Expr a, Expr b) {
  BinaryOpMatchTypes(a, b);
  Expr folded = arith::TryConstFold<ir::GE>(a, b);
  return folded.defined() ? folded : ir::GE::make(a, b);
}

Expr operator<(Expr a, Expr b) {
  BinaryOpMatchTypes(a, b);
  Expr folded = arith::TryConstFold<ir::LT>(a, b);
  return folded.defined() ? folded : ir::LT::make(a, b);
}

Expr operator<=(Expr a, Expr b) {
  BinaryOpMatchTypes(a, b);
  Expr folded = arith::TryConstFold<ir::LE>(a, b);
  return folded.defined() ? folded : ir::LE::make(a, b);
}

Expr operator==(Expr a, Expr b) {
  BinaryOpMatchTypes(a, b);
  Expr folded = arith::TryConstFold<ir::EQ>(a, b);
  return folded.defined() ? folded : ir::EQ::make(a, b);
}

Expr operator!=(Expr a, Expr b) {
  BinaryOpMatchTypes(a, b);
  Expr folded = arith::TryConstFold<ir::NE>(a, b);
  return folded.defined() ? folded : ir::NE::make(a, b);
}

// Logical operators. Operands must already be boolean; no implicit type
// unification is attempted. Constant operands are folded when possible.

Expr operator&&(Expr a, Expr b) {
  CHECK(a.type().is_bool());
  CHECK(b.type().is_bool());
  Expr folded = arith::TryConstFold<ir::And>(a, b);
  return folded.defined() ? folded : ir::And::make(a, b);
}

Expr operator||(Expr a, Expr b) {
  CHECK(a.type().is_bool());
  CHECK(b.type().is_bool());
  Expr folded = arith::TryConstFold<ir::Or>(a, b);
  return folded.defined() ? folded : ir::Or::make(a, b);
}

Expr operator!(Expr a) {
  CHECK(a.type().is_bool());
  Expr folded = arith::TryConstFold<ir::Not>(a);
  return folded.defined() ? folded : ir::Not::make(a);
}

// Shift right. Under TVM_INDEX_CONST_PROPAGATION, two integer literals are
// folded, but only when the shift amount lies in [0, bits of the type). A
// literal shift amount of 0 returns `a` unchanged. Otherwise a shift_right
// intrinsic call is emitted.
Expr operator>>(Expr a, Expr b) {
  BinaryOpMatchTypes(a, b);
  TVM_INDEX_CONST_PROPAGATION({
      const Type& rtype = a.type();
      // Shifting by a negative amount or by >= the operand width is undefined
      // behavior in C++; leave such cases to the emitted intrinsic.
      if (pa && pb && pb->value >= 0 && pb->value < rtype.bits()) {
        return IntImm::make(rtype, (pa->value >> pb->value));
      }
      if (pb) {
        if (pb->value == 0) return a;
      }
    });
  return ir::Call::make(a.type(), ir::Call::shift_right, { a, b }, ir::Call::PureIntrinsic);
}

// Shift left. Folding follows the same rules as operator>>. The folded value
// is computed in uint64_t so the result wraps in two's complement.
Expr operator<<(Expr a, Expr b) {
  BinaryOpMatchTypes(a, b);
  TVM_INDEX_CONST_PROPAGATION({
      const Type& rtype = a.type();
      // Out-of-range shift amounts are UB; so is left-shifting a negative
      // signed value or shifting set bits out of the top (pre-C++20). Shifting
      // the unsigned representation avoids both issues.
      if (pa && pb && pb->value >= 0 && pb->value < rtype.bits()) {
        return IntImm::make(rtype, static_cast<int64_t>(
            static_cast<uint64_t>(pa->value) << pb->value));
      }
      if (pb) {
        if (pb->value == 0) return a;
      }
    });
  return ir::Call::make(a.type(), ir::Call::shift_left, { a, b }, ir::Call::PureIntrinsic);
}

// Bitwise operators. The binary forms unify operand types and fold two
// integer literals under TVM_INDEX_CONST_PROPAGATION. Otherwise they emit
// the matching pure bitwise intrinsic call.

Expr operator&(Expr a, Expr b) {
  BinaryOpMatchTypes(a, b);
  TVM_INDEX_CONST_PROPAGATION({
      if (pa && pb) return IntImm::make(a.type(), (pa->value & pb->value));
    });
  return ir::Call::make(a.type(), ir::Call::bitwise_and, {a, b}, ir::Call::PureIntrinsic);
}

Expr operator|(Expr a, Expr b) {
  BinaryOpMatchTypes(a, b);
  TVM_INDEX_CONST_PROPAGATION({
      if (pa && pb) return IntImm::make(a.type(), (pa->value | pb->value));
    });
  return ir::Call::make(a.type(), ir::Call::bitwise_or, {a, b}, ir::Call::PureIntrinsic);
}

Expr operator^(Expr a, Expr b) {
  BinaryOpMatchTypes(a, b);
  TVM_INDEX_CONST_PROPAGATION({
      if (pa && pb) return IntImm::make(a.type(), (pa->value ^ pb->value));
    });
  return ir::Call::make(a.type(), ir::Call::bitwise_xor, {a, b}, ir::Call::PureIntrinsic);
}

// Bitwise complement; only int/uint operands are accepted. No folding is done.
Expr operator~(Expr a) {
  const Type& ty = a.type();
  CHECK(ty.is_int() || ty.is_uint());
  return ir::Call::make(ty, ir::Call::bitwise_not, {a}, ir::Call::PureIntrinsic);
}

// Power x^y emitted as a pure "pow" intrinsic call. The operands are first
// unified to one type, and that type must be floating point.
Expr pow(Expr x, Expr y) {
  BinaryOpMatchTypes(x, y);
  const Type& ty = x.type();
  CHECK(ty.is_float()) << "power only applies to float";
  return ir::Call::make(ty, "pow", {x, y}, ir::Call::PureIntrinsic);
}

// Absolute value.
//  - int:   literals are folded; otherwise emits Select(x >= 0, x, -x).
//  - float: literals are folded with std::fabs; otherwise emits a "fabs" call.
//  - uint:  returned unchanged.
//  - other: fatal error.
Expr abs(Expr x) {
  if (x.type().is_int()) {
    using ir::IntImm;
    const IntImm* px = x.as<IntImm>();
    if (px) {
      // std::abs(INT64_MIN) is undefined behavior. Compute the magnitude in
      // uint64_t instead; INT64_MIN then wraps onto itself (two's complement).
      const uint64_t bits = static_cast<uint64_t>(px->value);
      const uint64_t magnitude = px->value < 0 ? uint64_t(0) - bits : bits;
      return ir::IntImm::make(x.type(), static_cast<int64_t>(magnitude));
    }
    return ir::Select::make(x >= make_zero(x.type()), x, -x);
  } else if (x.type().is_float()) {
    using ir::FloatImm;
    const FloatImm* fx = x.as<FloatImm>();
    if (fx) {
      return ir::FloatImm::make(x.type(), std::fabs(fx->value));
    }
    return ir::Call::make(x.type(), "fabs", {x}, ir::Call::PureIntrinsic);
  } else if (x.type().is_uint()) {
    return x;
  } else {
    LOG(FATAL) << "Data type " << x.type()
               <<" not supported for absolute op. Skipping absolute op...";
    return x;
  }
}

namespace {
// Build a single-output Reduce of `source` over `rdom`. The combiner maps the
// placeholder pair (lhs, rhs) to `combined` and has identity `identity`. The
// reduction condition is constant true and the value index is 0.
Expr MakeSingleCommReduce(const Expr& source, const Array<IterVar>& rdom,
                          const Var& lhs, const Var& rhs,
                          const Expr& combined, const Expr& identity) {
  ir::CommReducer reducer =
    ir::CommReducerNode::make({lhs}, {rhs}, {combined}, {identity});
  return ir::Reduce::make(reducer, {source}, rdom, make_const(Bool(1), true), 0);
}
}  // namespace

// Sum reduction; identity is zero of the source type.
Expr sum(Expr source, Array<IterVar> rdom) {
  Var lhs("x", source.type()), rhs("y", source.type());
  return MakeSingleCommReduce(source, rdom, lhs, rhs,
                              ir::Add::make(lhs, rhs), make_zero(source.type()));
}

// Max reduction; identity is the minimum value of the source type.
Expr max(Expr source, Array<IterVar> rdom) {
  Var lhs("x", source.type()), rhs("y", source.type());
  return MakeSingleCommReduce(source, rdom, lhs, rhs,
                              ir::Max::make(lhs, rhs), source.type().min());
}

// Min reduction; identity is the maximum value of the source type.
Expr min(Expr source, Array<IterVar> rdom) {
  Var lhs("x", source.type()), rhs("y", source.type());
  return MakeSingleCommReduce(source, rdom, lhs, rhs,
                              ir::Min::make(lhs, rhs), source.type().max());
}

// Product reduction; identity is one of the source type.
Expr prod(Expr source, Array<IterVar> rdom) {
  Var lhs("x", source.type()), rhs("y", source.type());
  return MakeSingleCommReduce(source, rdom, lhs, rhs,
                              ir::Mul::make(lhs, rhs), make_const(source.type(), 1));
}

// Floating-point remainder emitted as a pure "fmod" intrinsic call. The
// operands are first unified to one type, and that type must be floating
// point.
Expr fmod(Expr x, Expr y) {
  BinaryOpMatchTypes(x, y);
  const Type& ty = x.type();
  CHECK(ty.is_float()) << "fmod only applies to float";
  return ir::Call::make(ty, "fmod", {x, y}, ir::Call::PureIntrinsic);
}

// Rounding helpers. A FloatImm argument is folded on the host with the
// corresponding <cmath> routine. Any other expression becomes a pure call to
// the same-named intrinsic.

Expr floor(Expr x) {
  if (const ir::FloatImm* imm = x.as<ir::FloatImm>()) {
    return ir::FloatImm::make(x.type(), std::floor(imm->value));
  }
  return ir::Call::make(x.type(), "floor", {x}, ir::Call::PureIntrinsic);
}

Expr ceil(Expr x) {
  if (const ir::FloatImm* imm = x.as<ir::FloatImm>()) {
    return ir::FloatImm::make(x.type(), std::ceil(imm->value));
  }
  return ir::Call::make(x.type(), "ceil", {x}, ir::Call::PureIntrinsic);
}

// NOTE(review): std::nearbyint follows the current FP rounding mode
// (ties-to-even by default). Confirm this matches the runtime "round"
// intrinsic's tie handling.
Expr round(Expr x) {
  if (const ir::FloatImm* imm = x.as<ir::FloatImm>()) {
    return ir::FloatImm::make(x.type(), std::nearbyint(imm->value));
  }
  return ir::Call::make(x.type(), "round", {x}, ir::Call::PureIntrinsic);
}

// Round toward zero.
Expr trunc(Expr x) {
  if (const ir::FloatImm* imm = x.as<ir::FloatImm>()) {
    return ir::FloatImm::make(x.type(), std::trunc(imm->value));
  }
  return ir::Call::make(x.type(), "trunc", {x}, ir::Call::PureIntrinsic);
}

}  // namespace tvm