/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file src/relay/qnn/util.cc
 * \brief Utility functions for QNN.
 */

#include "util.h"
#include "../transforms/pattern_util.h"

namespace tvm {
namespace relay {
namespace qnn {

/*
 * \brief Convert FP32 representation into fixed point representation.
 * \param double_multplier The input FP32 number.
 * \return The pair of multiplier and shift for fixed point representation.
 * \note Converts a floating point number so that it can be represented by
 *       integers. The representation is
 *             float_number = (significand) * 2^(exponent)
 *
 *       The significand is a number between 0.5 and 1. This is represented by
 *       an integer number. For example, if it is int32, then the decimal point
 *       exists between bit 31 and 30 from LSB (or between first and second bit
 *       from the left).
 *
 *       Some examples are
 *           0.25 = (0.5) * 2^(-1)
 *           0.125 = (0.5) * 2^(-2)
 *
 *       Credit to TFLite reference implementation.
 */
std::pair<int32_t, int32_t> GetFixedPointMultiplierShift(
    double double_multiplier) {
  int32_t significand, exponent;
  if (double_multiplier == 0.) {
    significand = 0;
    exponent = 0;
    return std::make_pair(significand, exponent);
  }

  // Get the significand and exponent.
  double significand_d = std::frexp(double_multiplier, &exponent);

  // Convert the double significand to int significand, i.e., convert into a
  // integer where the decimal point is between bit 31 and 30. This is done by
  // multiplying the double value with 2^31 and then casting to int.
  significand_d = std::round(significand_d * (1ll << 31));
  auto significand_int64 = static_cast<int64_t>(significand_d);
  CHECK_LE(significand_int64, (1ll << 31));
  if (significand_int64 == (1ll << 31)) {
    significand_int64 /= 2;
    ++exponent;
  }
  CHECK_LE(significand_int64, std::numeric_limits<int32_t>::max());
  significand = static_cast<int32_t>(significand_int64);
  return std::make_pair(significand, exponent);
}

Expr FixedPointMultiply(Expr tensor, double multiplier, const Array<IndexExpr>& input_shape,
                        const std::string& rounding) {
  // Choose high precision datatype to be int64. This is for avoiding overflow
  // in multiplication of two int32 values.
  DataType hp_dtype = DataType::Int(64);

  // 1) Calculating the integer multiplier and integer shift
  int32_t fixed_point_multiplier, shift;
  std::tie(fixed_point_multiplier, shift) =
      GetFixedPointMultiplierShift(multiplier);
  int left_shift = shift > 0 ? shift : 0;
  int right_shift = shift > 0 ? 0 : -shift;

  // 2) Multiply the integer multiplier
  if (left_shift != 0) {
    tensor = LeftShift(tensor, MakeConstantScalar(hp_dtype, left_shift));
  }

  // 3) Perform the multiplication in higher precision.
  // The scalar is a fixed point value of int32 where the decimal point is
  // between bits 31 and 30. After multiplying with input_tensor, the result
  // is in int64 where the decimal point is sitting between bits 31 and 30
  // (from the right, rightmost bit is bit 0). The computation is performed in
  // higher precision to avoid overflow in multiplying two int32 values.
  Expr scalar = MakeConstantScalar(hp_dtype, fixed_point_multiplier);
  tensor = Multiply(tensor, scalar);

  // 4) Find the rounding scalar. This depends on where the final decimal
  // point sits. As we will be right shifting the multiplied_t, we need to
  // first calculate the total_right_shift.
  int total_right_shift = right_shift + 31;
  int64_t pos_rounding_value = (1ll << (total_right_shift - 1));

  Expr round_scalar;
  if (rounding == "UPWARD") {
    round_scalar = MakeConstantScalar(hp_dtype, pos_rounding_value);
  } else if (rounding == "TONEAREST") {
    auto pos_rounder = MakeConstantScalar(hp_dtype, pos_rounding_value);
    auto neg_rounder = MakeConstantScalar(hp_dtype, pos_rounding_value - 1);
    auto pos_rounder_t = Full(pos_rounder, input_shape, hp_dtype);
    auto neg_rounder_t = Full(neg_rounder, input_shape, hp_dtype);

    auto zero_t = Zeros(input_shape, hp_dtype);
    round_scalar =
        Where(GreaterEqual(tensor, zero_t), pos_rounder_t, neg_rounder_t);
  } else {
    LOG(FATAL) << "Rounding mode " << rounding << " not supported.";
  }
  // Add the rounding scalar.
  tensor = Add(tensor, round_scalar);

  // 5) Simply right shift the result to get the final output.
  tensor =
      RightShift(tensor, MakeConstantScalar(hp_dtype, total_right_shift));

  return tensor;
}

Expr FixedPointMultiplyPerChannel(Expr tensor, std::vector<double> multipliers,
                                  const Array<IndexExpr>& input_shape, int channel_axis,
                                  const std::string& rounding) {
  // Get the n dim. This will be used to expand the multiplier to match the axis.
  size_t n_dim = input_shape.size();

  // Get the num of channels/axis along which the tensor was quantized.
  int64_t n_channels = (int64_t)multipliers.size();

  // Choose high precision datatype to be int64. This is for avoiding overflow
  // in multiplication of two int32 values.
  DataType hp_dtype = DataType::Int(64);

  // 1) Calculating the integer multiplier and integer shift. These are calculated per axis/per
  // channel.
  std::vector<int32_t> fixed_pt_multipliers, lshifts, rshifts;
  bool is_lshift_required = false;
  for (auto multiplier : multipliers) {
    int32_t fixed_pt_multiplier, shift;
    std::tie(fixed_pt_multiplier, shift) = GetFixedPointMultiplierShift(multiplier);
    int lshift = shift > 0 ? shift : 0;
    int rshift = shift > 0 ? 0 : -shift;
    fixed_pt_multipliers.push_back(fixed_pt_multiplier);
    lshifts.push_back(lshift);
    rshifts.push_back(rshift);
    is_lshift_required = is_lshift_required | (lshift != 0);
  }

  // 2) Multiply the integer multiplier. Convert lefts shifts into expr and multiply.
  if (is_lshift_required) {
    auto lshift_expr = MakeConstantTensor(hp_dtype, {n_channels}, lshifts);
    auto exp_lshift_expr = ExpandBiasToMatchAxis(lshift_expr, n_dim, {channel_axis});
    tensor = LeftShift(tensor, exp_lshift_expr);
  }

  // 3) Perform the multiplication in higher precision.
  // The scalar is a fixed point value of int32 where the decimal point is
  // between bits 31 and 30. After multiplying with input_tensor, the result
  // is in int64 where the decimal point is sitting between bits 31 and 30
  // (from the right, rightmost bit is bit 0). The computation is performed in
  // higher precision to avoid overflow in multiplying two int32 values.
  auto fixed_pt_multiplier_expr = MakeConstantTensor(hp_dtype, {n_channels}, fixed_pt_multipliers);
  auto exp_fixed_pt_multiplier_expr =
      ExpandBiasToMatchAxis(fixed_pt_multiplier_expr, n_dim, {channel_axis});
  tensor = Multiply(tensor, exp_fixed_pt_multiplier_expr);

  // 4) Find the rounding scalar. This depends on where the final decimal point sits. As we will be
  // right shifting the multiplied_t, we need to first calculate the total_rshift. Further, we can
  // calculate the pos and neg rounding offset.
  std::vector<int64_t> pos_rounding_values, neg_rounding_values, total_rshifts;
  for (auto rshift : rshifts) {
    int total_rshift = rshift + 31;
    total_rshifts.push_back(total_rshift);
    pos_rounding_values.push_back((1ll << (total_rshift - 1)));
    neg_rounding_values.push_back((1ll << (total_rshift - 1)) - 1);
  }
  // Make a Relay expr from positive and negative rounding offset values.
  auto pos_rounding_value_expr = MakeConstantTensor(hp_dtype, {n_channels}, pos_rounding_values);
  auto exp_pos_rounding_value_expr =
      ExpandBiasToMatchAxis(pos_rounding_value_expr, n_dim, {channel_axis});
  auto neg_rounding_value_expr = MakeConstantTensor(hp_dtype, {n_channels}, neg_rounding_values);
  auto exp_neg_rounding_value_expr =
      ExpandBiasToMatchAxis(neg_rounding_value_expr, n_dim, {channel_axis});

  Expr round_scalar;
  if (rounding == "UPWARD") {
    round_scalar = exp_pos_rounding_value_expr;
  } else if (rounding == "TONEAREST") {
    // To satisfy where op shape requirements, the rounding values are broadcasted.
    auto pos_rounder = MakeBroadCastTo(exp_pos_rounding_value_expr, input_shape);
    auto neg_rounder = MakeBroadCastTo(exp_neg_rounding_value_expr, input_shape);

    auto zero_t = Zeros(input_shape, hp_dtype);
    round_scalar = Where(GreaterEqual(tensor, zero_t), pos_rounder, neg_rounder);
  } else {
    LOG(FATAL) << "Rounding mode " << rounding << " not supported.";
  }
  // Add the rounding scalar.
  tensor = Add(tensor, round_scalar);

  // 5) Simply right shift the result to get the final output.
  auto total_rshift_expr = MakeConstantTensor(hp_dtype, {n_channels}, total_rshifts);
  auto exp_total_rshift_expr = ExpandBiasToMatchAxis(total_rshift_expr, n_dim, {channel_axis});
  tensor = RightShift(tensor, exp_total_rshift_expr);

  return tensor;
}

}  // namespace qnn
}  // namespace relay
}  // namespace tvm