Commit 425430d4 by Wuwei Lin Committed by Zhi

[QNN] Refactor fixed point multiplication in requantize (#4073)

parent 76c23926
...@@ -336,6 +336,14 @@ inline Expr ZerosLike(Expr e) { ...@@ -336,6 +336,14 @@ inline Expr ZerosLike(Expr e) {
return CallNode::make(op, {e}); return CallNode::make(op, {e});
} }
inline Expr Zeros(Array<IndexExpr> shape, DataType dtype) {
auto attrs = make_node<InitOpAttrs>();
attrs->shape = std::move(shape);
attrs->dtype = std::move(dtype);
static const Op& op = Op::Get("zeros");
return CallNode::make(op, {}, Attrs(attrs), {});
}
inline Expr OnesLike(Expr e) { inline Expr OnesLike(Expr e) {
static const Op& op = Op::Get("ones_like"); static const Op& op = Op::Get("ones_like");
return CallNode::make(op, {e}); return CallNode::make(op, {e});
......
...@@ -37,50 +37,7 @@ TVM_REGISTER_NODE_TYPE(RequantizeAttrs); ...@@ -37,50 +37,7 @@ TVM_REGISTER_NODE_TYPE(RequantizeAttrs);
// Lowering of qnn.requantize op // Lowering of qnn.requantize op
/*
* \brief Convert FP32 representation into fixed point representation.
* \param double_multplier The input FP32 number.
* \return The pair of multiplier and shift for fixed point representation.
* \note Converts a floating point number so that it can be represented by
* integers. The representation is
* float_number = (significand) * 2^(exponent)
*
* The significand is a number between 0.5 and 1. This is represented by
* an integer number. For example, if it is int32, then the decimal point
* exists between bit 31 and 30 from LSB (or between first and second bit
* from the left).
*
* Some examples are
* 0.25 = (0.5) * 2^(-1)
* 0.125 = (0.5) * 2^(-2)
*
* Credit to TFLite reference implementation.
*/
std::pair<int32_t, int32_t> GetFixedPointMultiplierShift(double double_multiplier) {
int32_t significand, exponent;
if (double_multiplier == 0.) {
significand = 0;
exponent = 0;
return std::make_pair(significand, exponent);
}
// Get the significand and exponent.
double significand_d = std::frexp(double_multiplier, &exponent);
// Convert the double significand to int significand, i.e., convert into a
// integer where the decimal point is between bit 31 and 30. This is done by
// multiplying the double value with 2^31 and then casting to int.
significand_d = std::round(significand_d * (1ll << 31));
auto significand_int64 = static_cast<int64_t>(significand_d);
CHECK_LE(significand_int64, (1ll << 31));
if (significand_int64 == (1ll << 31)) {
significand_int64 /= 2;
++exponent;
}
CHECK_LE(significand_int64, std::numeric_limits<int32_t>::max());
significand = static_cast<int32_t>(significand_int64);
return std::make_pair(significand, exponent);
}
/* /*
* \brief Lower requantize to a sequence of ops. * \brief Lower requantize to a sequence of ops.
...@@ -93,93 +50,41 @@ std::pair<int32_t, int32_t> GetFixedPointMultiplierShift(double double_multiplie ...@@ -93,93 +50,41 @@ std::pair<int32_t, int32_t> GetFixedPointMultiplierShift(double double_multiplie
* and shift. This is useful, if the target device does not support/have * and shift. This is useful, if the target device does not support/have
* very expensive floating point computations. * very expensive floating point computations.
* *
* Original compuation is scale_fp32 * quantized_tensor. To convert into
* integer computation, the multiplication with fp32 scalar can be
* replaced by multiplication with an int value and then right shifting
* the result. This approximates the floating point computation with a
* fixed point computation.
*
* The whole computation this can be broken down into following steps * The whole computation this can be broken down into following steps
* 1) Calculate the integer multiplier and integer shift. * 1) Calculate the integer multiplier and integer shift.
* 2) Subtract the input integer zero point. * 2) Subtract the input integer zero point.
* 3) Multiply the fixed point multiplier with quantized tensor. * 3) Perform fixed point multiplication.
* 4) Round the result. * 4) Add the output zero point.
* 5) Right shift the result. * 5) Cast to the out_dtype.
* 6) Add the output zero point.
* 7) Cast to the out_dtype.
*/ */
Expr RequantizeLower(const Expr& input_tensor, const RequantizeAttrs* param, Expr RequantizeLower(const Expr& input_tensor, const RequantizeAttrs* param,
const Array<IndexExpr>& input_shape, const DataType& out_dtype) { const Array<IndexExpr>& input_shape, const DataType& out_dtype) {
double double_multiplier = param->input_scale / param->output_scale; double double_multiplier = param->input_scale / param->output_scale;
// Choose high precision datatype to be int64. This is for avoiding overflow
// in multiplication of two int32 values.
DataType hp_dtype = Int(64); DataType hp_dtype = Int(64);
// 1) Calculating the integer multiplier and integer shift
int32_t fixed_point_multiplier, shift;
std::tie(fixed_point_multiplier, shift) = GetFixedPointMultiplierShift(double_multiplier);
int left_shift = shift > 0 ? shift : 0;
int right_shift = shift > 0 ? 0 : -shift;
// 2) Subtract the input_zero_point
auto tensor = Cast(input_tensor, hp_dtype); auto tensor = Cast(input_tensor, hp_dtype);
// 1) Subtract the input_zero_point
if (param->input_zero_point != 0) { if (param->input_zero_point != 0) {
auto input_zp = MakeConstantScalar(hp_dtype, param->input_zero_point); auto input_zp = MakeConstantScalar(hp_dtype, param->input_zero_point);
tensor = Subtract(tensor, input_zp); tensor = Subtract(tensor, input_zp);
} }
// If the input and output scales are same, we can skip the fixed point multiplication. // 2) If the input and output scales are same, we can skip the fixed point multiplication.
auto scaled_int64_t = tensor; auto scaled_int64_t = tensor;
if (param->input_scale != param->output_scale) { if (param->input_scale != param->output_scale) {
// 3) Multiply the integer multiplier scaled_int64_t = FixedPointMuliply(scaled_int64_t, double_multiplier, input_shape,
if (left_shift != 0) { param->rounding);
tensor = Multiply(tensor, MakeConstantScalar(hp_dtype, 1 << left_shift));
}
// Perform the multiplication in higher precision.
// The scalar is a fixed point value of int32 where the decimal point is
// between bits 31 and 30. After multiplying with input_tensor, the result is
// in int64 where the decimal point is sitting between bits 31 and 30 (from
// the right, rightmost bit is bit 0). The computation is performed in higher
// precision to avoid overflow in multiplying two int32 values.
Expr scalar = MakeConstantScalar(hp_dtype, fixed_point_multiplier);
auto multiplied_t = Multiply(tensor, scalar);
// 4) Find the rounding scalar. This depends on where the final decimal point
// sits. As we will be right shifting the multiplied_t, we need to first
// calculate the total_right_shift.
int total_right_shift = right_shift + 31;
int64_t pos_rounding_value = (1ll << (total_right_shift - 1));
tensor = multiplied_t;
Expr round_scalar;
if (param->rounding == "UPWARD") {
round_scalar = MakeConstantScalar(hp_dtype, pos_rounding_value);
} else if (param->rounding == "TONEAREST") {
auto pos_rounder = MakeConstantScalar(hp_dtype, pos_rounding_value);
auto neg_rounder = MakeConstantScalar(hp_dtype, pos_rounding_value - 1);
auto pos_rounder_t = Full(pos_rounder, input_shape, hp_dtype);
auto neg_rounder_t = Full(neg_rounder, input_shape, hp_dtype);
auto zero = MakeConstantScalar(hp_dtype, 0);
auto zero_t = Full(zero, input_shape, hp_dtype);
round_scalar = Where(GreaterEqual(tensor, zero_t), pos_rounder_t, neg_rounder_t);
}
// Add the rounding scalar.
tensor = Add(tensor, round_scalar);
// 5) Simply right shift the result to get the final output.
scaled_int64_t = RightShift(tensor, MakeConstantScalar(hp_dtype, total_right_shift));
} }
// 6) Add the output zero point. // 3) Add the output zero point.
auto shifted_int64_t = scaled_int64_t; auto shifted_int64_t = scaled_int64_t;
if (param->output_zero_point != 0) { if (param->output_zero_point != 0) {
auto output_zp = MakeConstantScalar(hp_dtype, param->output_zero_point); auto output_zp = MakeConstantScalar(hp_dtype, param->output_zero_point);
shifted_int64_t = Add(output_zp, scaled_int64_t); shifted_int64_t = Add(output_zp, scaled_int64_t);
} }
// 7) Clip to the out_dtype min/max. // 4) Clip to the out_dtype min/max.
auto q_min = GetQmin(out_dtype); auto q_min = GetQmin(out_dtype);
auto q_max = GetQmax(out_dtype); auto q_max = GetQmax(out_dtype);
auto clipped_t = Clip(shifted_int64_t, q_min, q_max); auto clipped_t = Clip(shifted_int64_t, q_min, q_max);
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2019 by Contributors
* \file src/relay/qnn/util.cc
* \brief Utility functions for QNN.
*/
#include "util.h"
#include "../pass/pattern_util.h"
namespace tvm {
namespace relay {
namespace qnn {
/*
* \brief Convert FP32 representation into fixed point representation.
* \param double_multplier The input FP32 number.
* \return The pair of multiplier and shift for fixed point representation.
* \note Converts a floating point number so that it can be represented by
* integers. The representation is
* float_number = (significand) * 2^(exponent)
*
* The significand is a number between 0.5 and 1. This is represented by
* an integer number. For example, if it is int32, then the decimal point
* exists between bit 31 and 30 from LSB (or between first and second bit
* from the left).
*
* Some examples are
* 0.25 = (0.5) * 2^(-1)
* 0.125 = (0.5) * 2^(-2)
*
* Credit to TFLite reference implementation.
*/
std::pair<int32_t, int32_t> GetFixedPointMultiplierShift(
double double_multiplier) {
int32_t significand, exponent;
if (double_multiplier == 0.) {
significand = 0;
exponent = 0;
return std::make_pair(significand, exponent);
}
// Get the significand and exponent.
double significand_d = std::frexp(double_multiplier, &exponent);
// Convert the double significand to int significand, i.e., convert into a
// integer where the decimal point is between bit 31 and 30. This is done by
// multiplying the double value with 2^31 and then casting to int.
significand_d = std::round(significand_d * (1ll << 31));
auto significand_int64 = static_cast<int64_t>(significand_d);
CHECK_LE(significand_int64, (1ll << 31));
if (significand_int64 == (1ll << 31)) {
significand_int64 /= 2;
++exponent;
}
CHECK_LE(significand_int64, std::numeric_limits<int32_t>::max());
significand = static_cast<int32_t>(significand_int64);
return std::make_pair(significand, exponent);
}
Expr FixedPointMuliply(Expr tensor, double multiplier,
const Array<IndexExpr>& input_shape, const std::string& rounding) {
// Choose high precision datatype to be int64. This is for avoiding overflow
// in multiplication of two int32 values.
DataType hp_dtype = Int(64);
// 1) Calculating the integer multiplier and integer shift
int32_t fixed_point_multiplier, shift;
std::tie(fixed_point_multiplier, shift) =
GetFixedPointMultiplierShift(multiplier);
int left_shift = shift > 0 ? shift : 0;
int right_shift = shift > 0 ? 0 : -shift;
// 2) Multiply the integer multiplier
if (left_shift != 0) {
tensor = LeftShift(tensor, MakeConstantScalar(hp_dtype, left_shift));
}
// 3) Perform the multiplication in higher precision.
// The scalar is a fixed point value of int32 where the decimal point is
// between bits 31 and 30. After multiplying with input_tensor, the result
// is in int64 where the decimal point is sitting between bits 31 and 30
// (from the right, rightmost bit is bit 0). The computation is performed in
// higher precision to avoid overflow in multiplying two int32 values.
Expr scalar = MakeConstantScalar(hp_dtype, fixed_point_multiplier);
tensor = Multiply(tensor, scalar);
// 4) Find the rounding scalar. This depends on where the final decimal
// point sits. As we will be right shifting the multiplied_t, we need to
// first calculate the total_right_shift.
int total_right_shift = right_shift + 31;
int64_t pos_rounding_value = (1ll << (total_right_shift - 1));
Expr round_scalar;
if (rounding == "UPWARD") {
round_scalar = MakeConstantScalar(hp_dtype, pos_rounding_value);
} else if (rounding == "TONEAREST") {
auto pos_rounder = MakeConstantScalar(hp_dtype, pos_rounding_value);
auto neg_rounder = MakeConstantScalar(hp_dtype, pos_rounding_value - 1);
auto pos_rounder_t = Full(pos_rounder, input_shape, hp_dtype);
auto neg_rounder_t = Full(neg_rounder, input_shape, hp_dtype);
auto zero_t = Zeros(input_shape, hp_dtype);
round_scalar =
Where(GreaterEqual(tensor, zero_t), pos_rounder_t, neg_rounder_t);
}
// Add the rounding scalar.
tensor = Add(tensor, round_scalar);
// 5) Simply right shift the result to get the final output.
tensor =
RightShift(tensor, MakeConstantScalar(hp_dtype, total_right_shift));
return tensor;
}
} // namespace qnn
} // namespace relay
} // namespace tvm
...@@ -27,6 +27,7 @@ ...@@ -27,6 +27,7 @@
#include <tvm/expr.h> #include <tvm/expr.h>
#include <tvm/relay/expr.h> #include <tvm/relay/expr.h>
#include <tvm/relay/qnn/attrs.h>
#include <limits> #include <limits>
#include <string> #include <string>
#include <utility> #include <utility>
...@@ -92,6 +93,32 @@ static inline int64_t get_const_int(const tvm::Expr& x) { ...@@ -92,6 +93,32 @@ static inline int64_t get_const_int(const tvm::Expr& x) {
return value_ptr[0]; return value_ptr[0];
} }
/*
* \brief Fixed point multiplication between integer tensor with floating point
scalar.
* \param tensor The quantized input tensor of dtype int64.
* \param multiplier The scalar multiplier.
* \param input_shape Shape of the input tensor.
* \param rounding "UPWARD" or "TONEAREST". The rounding direction when the value
is midway between" "two representable values.
* \return The sequence of Relay ops for fixed point multiplication.
* \note Original compuation is scale_fp32 * quantized_tensor. To convert into
* integer computation, the multiplication with fp32 scalar can be
* replaced by multiplication with an int value and then right shifting
* the result. This approximates the floating point computation with a
* fixed point computation.
*
* Computation of fixed point multiplication is consist of following
steps:
* 1) Multiply the fixed point multiplier with quantized tensor.
* 2) Round the result.
* 3) Right shift the result
*/
Expr FixedPointMuliply(Expr tensor, double multiplier,
const Array<IndexExpr>& input_shape,
const std::string& rounding);
} // namespace qnn } // namespace qnn
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment