Unverified commit 23f3988b by Animesh Jain, committed by GitHub

[QNN] Optimize lowering for requantize and FixedPointMultiply. (#4798)

* [QNN] Optimize lowering for requantize and FixedPointMultiply.

* Add check for requantize scale gt 1.

* Added test case.
parent 2989d725
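
For context before the diff: requantize rescales a quantized tensor from an input (scale, zero point) pair to an output pair, roughly q_out = round(q_in * input_scale / output_scale) + output_zero_point, clipped to the output dtype's range. The standalone NumPy sketch below (illustrative names, not TVM code) models the arithmetic this lowering implements; the ratio input_scale / output_scale is the "requantize scale" the commit message refers to.

import numpy as np

# Reference model of the requantize arithmetic (illustrative, not TVM code).
def requantize_ref(q_in, input_scale, output_scale, output_zero_point=0, out_dtype=np.int8):
    multiplier = float(input_scale) / float(output_scale)  # the "requantize scale"
    info = np.iinfo(out_dtype)
    scaled = np.round(q_in.astype(np.float64) * multiplier) + output_zero_point
    return np.clip(scaled, info.min, info.max).astype(out_dtype)

print(requantize_ref(np.array([-100, 0, 100]), 0.25, 0.5))  # [-50   0  50]
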
@@ -67,6 +67,9 @@ Expr RequantizeLower(const Expr& input_tensor, const Expr& input_scale,
     tensor = Subtract(tensor, Cast(input_zero_point, hp_dtype));
   }
+  // Check if the multiplier is greater than 1.
+  bool is_multiplier_gt_one = false;
   // 2) If the input and output scales are the same, we can skip the fixed point multiplication.
   //    Check if the input scale is per-tensor or per-channel. If it is per-tensor, there is a
   //    single scale for the whole tensor. For per-channel (aka per-axis), there is a vector of
   //    scales for the input
@@ -78,6 +81,9 @@ Expr RequantizeLower(const Expr& input_tensor, const Expr& input_scale,
     float input_scale_float = GetScalarFromConstant<float>(input_scale);
     double double_multiplier =
         static_cast<double>(input_scale_float) / static_cast<double>(output_scale_float);
+    if (double_multiplier > 1) {
+      is_multiplier_gt_one = true;
+    }
     // Skip if the input and output scales are the same.
     if (!IsEqualScalar(input_scale, output_scale)) {
       scaled_int64_t =
@@ -88,8 +94,12 @@ Expr RequantizeLower(const Expr& input_tensor, const Expr& input_scale,
     std::vector<double> double_multipliers;
     auto input_axis_scales = GetFloatVectorFromConstant(input_scale);
     for (auto input_axis_scale : input_axis_scales) {
-      double_multipliers.push_back(static_cast<double>(input_axis_scale) /
-                                   static_cast<double>(output_scale_float));
+      double multiplier =
+          static_cast<double>(input_axis_scale) / static_cast<double>(output_scale_float);
+      double_multipliers.push_back(multiplier);
+      if (multiplier > 1) {
+        is_multiplier_gt_one = true;
+      }
     }
     int axis = param->axis;
     axis = (axis == -1) ? input_shape.size() - 1 : axis;
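
Mirroring the per-channel loop above in plain Python (standalone illustrative values, chosen to match the new test case at the bottom of the diff): a single channel whose multiplier exceeds 1 is enough to set the flag and keep the final Clip.

# Standalone mirror of the per-channel multiplier check (illustrative values).
input_axis_scales = [1.0, 0.25]
output_scale_float = 0.5

double_multipliers = []
is_multiplier_gt_one = False
for input_axis_scale in input_axis_scales:
    multiplier = input_axis_scale / output_scale_float
    double_multipliers.append(multiplier)
    if multiplier > 1:
        is_multiplier_gt_one = True

print(double_multipliers, is_multiplier_gt_one)  # [2.0, 0.5] True
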
@@ -103,7 +113,11 @@ Expr RequantizeLower(const Expr& input_tensor, const Expr& input_scale,
     shifted_int64_t = Add(Cast(output_zero_point, hp_dtype), scaled_int64_t);
   }
-  // 4) Clip to the out_dtype min/max.
+  // 4) Clip to the out_dtype min/max. Skip clipping if out_dtype is Int32. The fixed point
+  //    multiplication keeps the value in the int32 range if the requantize scale is less than 1.
+  if (out_dtype == DataType::Int(32) && !is_multiplier_gt_one) {
+    return Cast(shifted_int64_t, out_dtype);
+  }
   auto q_min = GetQmin(out_dtype);
   auto q_max = GetQmax(out_dtype);
   auto clipped_t = Clip(shifted_int64_t, q_min, q_max);
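A quick numeric check of the new early return above (plain NumPy, not TVM): with a multiplier of at most 1, values that start inside the int32 range stay inside it after rescaling, so for an int32 output the Clip is a no-op; a multiplier above 1 can push the extreme int32 values out of range, which is why the is_multiplier_gt_one guard is needed.

import numpy as np

q_min, q_max = np.iinfo(np.int32).min, np.iinfo(np.int32).max
x = np.array([q_min, -7, 0, 7, q_max], dtype=np.int64)

for multiplier in (0.5, 1.0, 2.0):
    y = np.round(x.astype(np.float64) * multiplier).astype(np.int64)
    in_range = bool((y >= q_min).all() and (y <= q_max).all())
    print(multiplier, in_range)  # 0.5 True / 1.0 True / 2.0 False
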
@@ -149,6 +149,7 @@ Expr FixedPointMultiplyPerChannel(Expr tensor, std::vector<double> multipliers,
   // 1) Calculating the integer multiplier and integer shift. These are calculated per axis/per
   //    channel.
   std::vector<int32_t> fixed_pt_multipliers, lshifts, rshifts;
+  bool is_lshift_required = false;
   for (auto multiplier : multipliers) {
     int32_t fixed_pt_multiplier, shift;
     std::tie(fixed_pt_multiplier, shift) = GetFixedPointMultiplierShift(multiplier);
@@ -157,12 +158,15 @@ Expr FixedPointMultiplyPerChannel(Expr tensor, std::vector<double> multipliers,
     fixed_pt_multipliers.push_back(fixed_pt_multiplier);
     lshifts.push_back(lshift);
     rshifts.push_back(rshift);
+    is_lshift_required = is_lshift_required | (lshift != 0);
   }
   // 2) Multiply the integer multiplier. Convert the left shifts into an expr and multiply.
-  auto lshift_expr = MakeConstantTensor(hp_dtype, {n_channels}, lshifts);
-  auto exp_lshift_expr = ExpandBiasToMatchAxis(lshift_expr, n_dim, {channel_axis});
-  tensor = LeftShift(tensor, exp_lshift_expr);
+  if (is_lshift_required) {
+    auto lshift_expr = MakeConstantTensor(hp_dtype, {n_channels}, lshifts);
+    auto exp_lshift_expr = ExpandBiasToMatchAxis(lshift_expr, n_dim, {channel_axis});
+    tensor = LeftShift(tensor, exp_lshift_expr);
+  }
   // 3) Perform the multiplication in higher precision.
   //    The scalar is a fixed point value of int32 where the decimal point is
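The lshift skip works because of how the fixed point decomposition behaves. GetFixedPointMultiplierShift expresses each double multiplier as a fixed point int32 significand plus a power-of-two shift; the sketch below, using math.frexp, is a reconstruction from the surrounding code rather than the exact TVM source. It shows that multipliers below 1 always yield a non-positive shift, i.e. lshift == 0 for every channel, so the whole LeftShift op can be dropped.

import math

def fixed_point_multiplier_shift(multiplier):
    # multiplier == significand * 2**shift, with 0.5 <= significand < 1.
    significand, shift = math.frexp(multiplier)
    fixed_pt_multiplier = round(significand * (1 << 31))  # Q0.31 fixed point
    return fixed_pt_multiplier, shift

for m in (0.25, 0.5, 1.5):
    fixed_pt, shift = fixed_point_multiplier_shift(m)
    lshift, rshift = (shift, 0) if shift > 0 else (0, -shift)
    print(m, fixed_pt, lshift, rshift)
# 0.25 -> lshift 0, rshift 1; 0.5 -> lshift 0, rshift 0; 1.5 -> lshift 1, rshift 0
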
@@ -311,6 +311,21 @@ def test_per_channel_different_scale():
                           rounding=rounding)
         verify(mod, (golden_data, golden_output))
+    # Have input scale > output scale.
+    golden_data = np.arange(-5, 5, 1).astype('int32').reshape((5, 2))
+    golden_output = np.array([-10, -2, -6, -1, -2, 0, 2, 1, 6, 2]).reshape((5, 2))
+    for rounding in roundings:
+        mod = get_mod(data_shape=(5, 2),
+                      data_dtype='int32',
+                      out_dtype="int8",
+                      input_scale=[1.0, 0.25],
+                      output_scale=0.5,
+                      axis=1,
+                      rounding=rounding)
+        verify(mod, (golden_data, golden_output))
+
 if __name__ == "__main__":
     test_same_scale()
     test_downscale()
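The golden values in the new test can be verified by hand: with axis=1, channel 0 has multiplier 1.0 / 0.5 = 2.0 (input scale greater than output scale, exercising the new is_multiplier_gt_one path) and channel 1 has 0.25 / 0.5 = 0.5. A small NumPy check, independent of TVM:

import numpy as np

golden_data = np.arange(-5, 5, 1).astype('int32').reshape((5, 2))
multipliers = np.array([1.0, 0.25]) / 0.5  # per-channel input_scale / output_scale
expected = np.round(golden_data * multipliers).astype('int32')
print(expected.flatten())  # [-10  -2  -6  -1  -2   0   2   1   6   2]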