Commit 1bc83853 by Animesh Jain Committed by Wuwei Lin

[QNN] Requantize - Optimize lowering for some corner cases. (#3864)

parent dee52466
...@@ -129,6 +129,9 @@ Expr RequantizeLower(const Expr& input_tensor, const RequantizeAttrs* param, ...@@ -129,6 +129,9 @@ Expr RequantizeLower(const Expr& input_tensor, const RequantizeAttrs* param,
tensor = Subtract(tensor, input_zp); tensor = Subtract(tensor, input_zp);
} }
// If the input and output scales are same, we can skip the fixed point multiplication.
auto scaled_int64_t = tensor;
if (param->input_scale != param->output_scale) {
// 3) Multiply the integer multiplier // 3) Multiply the integer multiplier
if (left_shift != 0) { if (left_shift != 0) {
tensor = Multiply(tensor, MakeConstantScalar(hp_dtype, 1 << left_shift)); tensor = Multiply(tensor, MakeConstantScalar(hp_dtype, 1 << left_shift));
...@@ -166,11 +169,15 @@ Expr RequantizeLower(const Expr& input_tensor, const RequantizeAttrs* param, ...@@ -166,11 +169,15 @@ Expr RequantizeLower(const Expr& input_tensor, const RequantizeAttrs* param,
tensor = Add(tensor, round_scalar); tensor = Add(tensor, round_scalar);
// 5) Simply right shift the result to get the final output. // 5) Simply right shift the result to get the final output.
auto scaled_int64_t = RightShift(tensor, MakeConstantScalar(hp_dtype, total_right_shift)); scaled_int64_t = RightShift(tensor, MakeConstantScalar(hp_dtype, total_right_shift));
}
// 6) Add the output zero point. // 6) Add the output zero point.
auto shifted_int64_t = scaled_int64_t;
if (param->output_zero_point != 0) {
auto output_zp = MakeConstantScalar(hp_dtype, param->output_zero_point); auto output_zp = MakeConstantScalar(hp_dtype, param->output_zero_point);
auto shifted_int64_t = Add(output_zp, scaled_int64_t); shifted_int64_t = Add(output_zp, scaled_int64_t);
}
// 7) Clip to the out_dtype min/max. // 7) Clip to the out_dtype min/max.
auto q_min = GetQmin(out_dtype); auto q_min = GetQmin(out_dtype);
......
...@@ -64,6 +64,7 @@ def test_requantize(): ...@@ -64,6 +64,7 @@ def test_requantize():
input_scale=0.5, input_scale=0.5,
output_scale=0.5, output_scale=0.5,
rounding=rounding) rounding=rounding)
assert 'right_shift' not in mod.astext()
verify(mod, (golden_data, golden_output)) verify(mod, (golden_data, golden_output))
def downscale_test(): def downscale_test():
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment