/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file src/relay/qnn/util.cc
 * \brief Utility functions for QNN.
 */

#include "util.h"
#include "../pass/pattern_util.h"

namespace tvm {
namespace relay {
namespace qnn {

/*!
 * \brief Convert a floating point multiplier into a fixed point
 *        (significand, exponent) pair.
 * \param double_multiplier The input floating point number. Must be finite.
 * \return The pair of multiplier (significand) and shift (exponent) for the
 *         fixed point representation. A zero input yields (0, 0).
 * \note Converts a floating point number so that it can be represented by
 *       integers. The representation is
 *             float_number = (significand) * 2^(exponent)
 *
 *       The magnitude of the significand lies between 0.5 and 1. It is stored
 *       as an int32 whose binary point sits between bit 31 and bit 30 counted
 *       from the LSB (i.e. between the first and second bit from the left).
 *
 *       Some examples are
 *           0.25  = (0.5) * 2^(-1)   -> significand 1 << 30, exponent -1
 *           0.125 = (0.5) * 2^(-2)   -> significand 1 << 30, exponent -2
 *
 *       Credit to TFLite reference implementation.
 */
std::pair<int32_t, int32_t> GetFixedPointMultiplierShift(
    double double_multiplier) {
  // Zero has no normalized significand; it is represented as (0, 0).
  if (double_multiplier == 0.) {
    return std::make_pair(static_cast<int32_t>(0), static_cast<int32_t>(0));
  }

  // Reject Inf and NaN up front. Converting a non-finite double to an integer
  // is undefined behaviour and could slip past the range checks below. The
  // comparison is false for NaN and for +/-Inf, so both are caught here.
  CHECK_LE(std::fabs(double_multiplier), std::numeric_limits<double>::max())
      << "Fixed point multiplier must be finite, got " << double_multiplier;

  // Split into significand and power-of-two exponent. std::frexp writes the
  // exponent through an int*, so a plain int is used to receive it.
  int exponent = 0;
  double significand_d = std::frexp(double_multiplier, &exponent);

  // Scale the significand by 2^31 and round, so that the binary point of the
  // resulting integer sits between bit 31 and bit 30.
  significand_d = std::round(significand_d * (1ll << 31));
  auto significand_int64 = static_cast<int64_t>(significand_d);
  CHECK_LE(significand_int64, (1ll << 31));

  // Rounding can carry a significand just below 1.0 up to exactly 2^31, which
  // does not fit in int32. Halve it and compensate in the exponent.
  if (significand_int64 == (1ll << 31)) {
    significand_int64 /= 2;
    ++exponent;
  }
  CHECK_LE(significand_int64, std::numeric_limits<int32_t>::max());

  return std::make_pair(static_cast<int32_t>(significand_int64),
                        static_cast<int32_t>(exponent));
}

78
Expr FixedPointMultiply(Expr tensor, double multiplier,
79 80 81
                   const Array<IndexExpr>& input_shape, const std::string& rounding) {
  // Choose high precision datatype to be int64. This is for avoiding overflow
  // in multiplication of two int32 values.
82
  DataType hp_dtype = DataType::Int(64);
83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122

  // 1) Calculating the integer multiplier and integer shift
  int32_t fixed_point_multiplier, shift;
  std::tie(fixed_point_multiplier, shift) =
      GetFixedPointMultiplierShift(multiplier);
  int left_shift = shift > 0 ? shift : 0;
  int right_shift = shift > 0 ? 0 : -shift;

  // 2) Multiply the integer multiplier
  if (left_shift != 0) {
    tensor = LeftShift(tensor, MakeConstantScalar(hp_dtype, left_shift));
  }

  // 3) Perform the multiplication in higher precision.
  // The scalar is a fixed point value of int32 where the decimal point is
  // between bits 31 and 30. After multiplying with input_tensor, the result
  // is in int64 where the decimal point is sitting between bits 31 and 30
  // (from the right, rightmost bit is bit 0). The computation is performed in
  // higher precision to avoid overflow in multiplying two int32 values.
  Expr scalar = MakeConstantScalar(hp_dtype, fixed_point_multiplier);
  tensor = Multiply(tensor, scalar);

  // 4) Find the rounding scalar. This depends on where the final decimal
  // point sits. As we will be right shifting the multiplied_t, we need to
  // first calculate the total_right_shift.
  int total_right_shift = right_shift + 31;
  int64_t pos_rounding_value = (1ll << (total_right_shift - 1));

  Expr round_scalar;
  if (rounding == "UPWARD") {
    round_scalar = MakeConstantScalar(hp_dtype, pos_rounding_value);
  } else if (rounding == "TONEAREST") {
    auto pos_rounder = MakeConstantScalar(hp_dtype, pos_rounding_value);
    auto neg_rounder = MakeConstantScalar(hp_dtype, pos_rounding_value - 1);
    auto pos_rounder_t = Full(pos_rounder, input_shape, hp_dtype);
    auto neg_rounder_t = Full(neg_rounder, input_shape, hp_dtype);

    auto zero_t = Zeros(input_shape, hp_dtype);
    round_scalar =
        Where(GreaterEqual(tensor, zero_t), pos_rounder_t, neg_rounder_t);
123 124
  } else {
    LOG(FATAL) << "Rounding mode " << rounding << " not supported.";
125 126 127 128 129 130 131 132 133 134 135 136 137 138
  }
  // Add the rounding scalar.
  tensor = Add(tensor, round_scalar);

  // 5) Simply right shift the result to get the final output.
  tensor =
      RightShift(tensor, MakeConstantScalar(hp_dtype, total_right_shift));

  return tensor;
}

}  // namespace qnn
}  // namespace relay
}  // namespace tvm