Unverified commit 92d0ec14 by Animesh Jain, committed by GitHub

[Requantize] Cleanup and Optimize Lowering (#5286)

* Adding Cast back to Int32 in FixedPointMultiply.

* Removing extra clip.

* Fix space.

* Retrigger.

* Retrigger.
parent e4b80bda
...@@ -132,36 +132,28 @@ Expr RequantizeLower(const Expr& input_tensor, const Expr& input_scale, ...@@ -132,36 +132,28 @@ Expr RequantizeLower(const Expr& input_tensor, const Expr& input_scale,
const Expr& input_zero_point, const Expr& output_scale, const Expr& input_zero_point, const Expr& output_scale,
const Expr& output_zero_point, const RequantizeAttrs* param, const Expr& output_zero_point, const RequantizeAttrs* param,
const Array<IndexExpr>& input_shape, const DataType& out_dtype) { const Array<IndexExpr>& input_shape, const DataType& out_dtype) {
DataType hp_dtype = DataType::Int(64); auto tensor = Cast(input_tensor, DataType::Int(32));
auto tensor = Cast(input_tensor, hp_dtype);
// 1) Subtract the input_zero_point // 1) Subtract the input_zero_point
auto zero_scalar = MakeConstantScalar(DataType::Int(32), 0); auto zero_scalar = MakeConstantScalar(DataType::Int(32), 0);
if (!IsEqualScalar(input_zero_point, zero_scalar)) { if (!IsEqualScalar(input_zero_point, zero_scalar)) {
tensor = Subtract(tensor, Cast(input_zero_point, hp_dtype)); tensor = Subtract(tensor, Cast(input_zero_point, DataType::Int(32)));
} }
// Check if multiplier is greater than 1.
bool is_multiplier_gt_one = false;
// 2) If the input and output scales are same, we can skip the fixed point multiplication. Check // 2) If the input and output scales are same, we can skip the fixed point multiplication. Check
// if the input scale is per-tensor or per-channel. If it is per-tensor, there is single scale for // if the input scale is per-tensor or per-channel. If it is per-tensor, there is single scale for
// the whole tensor. For per-channel (aka per-axis), there is a vector of scales for the input // the whole tensor. For per-channel (aka per-axis), there is a vector of scales for the input
// tensor. Depending on the quantization type, the fixed point multiplication routing is called. // tensor. Depending on the quantization type, the fixed point multiplication routing is called.
auto scaled_int64_t = tensor; auto scaled_int32_t = tensor;
float output_scale_float = GetScalarFromConstant<float>(output_scale); float output_scale_float = GetScalarFromConstant<float>(output_scale);
if (IsConstScalar(input_scale)) { if (IsConstScalar(input_scale)) {
// This is per-tensor quantization. Single scale. // This is per-tensor quantization. Single scale.
float input_scale_float = GetScalarFromConstant<float>(input_scale); float input_scale_float = GetScalarFromConstant<float>(input_scale);
double double_multiplier = double double_multiplier =
static_cast<double>(input_scale_float) / static_cast<double>(output_scale_float); static_cast<double>(input_scale_float) / static_cast<double>(output_scale_float);
if (double_multiplier > 1) {
is_multiplier_gt_one = true;
}
// Skip if input and output scales are same. // Skip if input and output scales are same.
if (!IsEqualScalar(input_scale, output_scale)) { if (!IsEqualScalar(input_scale, output_scale)) {
scaled_int64_t = scaled_int32_t =
FixedPointMultiply(scaled_int64_t, double_multiplier, input_shape, param->rounding); FixedPointMultiply(scaled_int32_t, double_multiplier, input_shape, param->rounding);
} }
} else { } else {
// This is per-channel (per=axis) quantization. // This is per-channel (per=axis) quantization.
...@@ -171,30 +163,28 @@ Expr RequantizeLower(const Expr& input_tensor, const Expr& input_scale, ...@@ -171,30 +163,28 @@ Expr RequantizeLower(const Expr& input_tensor, const Expr& input_scale,
double multiplier = double multiplier =
static_cast<double>(input_axis_scale) / static_cast<double>(output_scale_float); static_cast<double>(input_axis_scale) / static_cast<double>(output_scale_float);
double_multipliers.push_back(multiplier); double_multipliers.push_back(multiplier);
if (multiplier > 1) {
is_multiplier_gt_one = true;
}
} }
int axis = param->axis; int axis = param->axis;
axis = (axis == -1) ? input_shape.size() - 1 : axis; axis = (axis == -1) ? input_shape.size() - 1 : axis;
scaled_int64_t = FixedPointMultiplyPerChannel(scaled_int64_t, double_multipliers, input_shape, scaled_int32_t = FixedPointMultiplyPerChannel(scaled_int32_t, double_multipliers, input_shape,
axis, param->rounding); axis, param->rounding);
} }
// 3) Add the output zero point. // 3) Add the output zero point.
auto shifted_int64_t = scaled_int64_t; auto shifted_int32_t = scaled_int32_t;
if (!IsEqualScalar(output_zero_point, zero_scalar)) { if (!IsEqualScalar(output_zero_point, zero_scalar)) {
shifted_int64_t = Add(Cast(output_zero_point, hp_dtype), scaled_int64_t); shifted_int32_t = Add(Cast(output_zero_point, DataType::Int(32)), scaled_int32_t);
} }
// 4) Clip to the out_dtype min/max. Skip clipping if out_dtype is Int32. The fixed point // 4) Clip to the out_dtype min/max. Skip clipping if out_dtype is Int32. The fixed point
// multiplication keeps the value in int32 range if the requantize scale is less than 1. // multiplication keeps the value in int32 range.
if (out_dtype == DataType::Int(32) && !is_multiplier_gt_one) { if (out_dtype == DataType::Int(32)) {
return Cast(shifted_int64_t, out_dtype); return shifted_int32_t;
} }
auto q_min = GetQmin(out_dtype); auto q_min = GetQmin(out_dtype);
auto q_max = GetQmax(out_dtype); auto q_max = GetQmax(out_dtype);
auto clipped_t = Clip(shifted_int64_t, q_min, q_max); auto clipped_t = Clip(shifted_int32_t, q_min, q_max);
return Cast(clipped_t, out_dtype); return Cast(clipped_t, out_dtype);
} }
......
...@@ -80,6 +80,7 @@ Expr FixedPointMultiply(Expr tensor, double multiplier, const Array<IndexExpr>& ...@@ -80,6 +80,7 @@ Expr FixedPointMultiply(Expr tensor, double multiplier, const Array<IndexExpr>&
// Choose high precision datatype to be int64. This is for avoiding overflow // Choose high precision datatype to be int64. This is for avoiding overflow
// in multiplication of two int32 values. // in multiplication of two int32 values.
DataType hp_dtype = DataType::Int(64); DataType hp_dtype = DataType::Int(64);
tensor = Cast(tensor, hp_dtype);
// 1) Calculating the integer multiplier and integer shift // 1) Calculating the integer multiplier and integer shift
int32_t fixed_point_multiplier, shift; int32_t fixed_point_multiplier, shift;
...@@ -130,7 +131,8 @@ Expr FixedPointMultiply(Expr tensor, double multiplier, const Array<IndexExpr>& ...@@ -130,7 +131,8 @@ Expr FixedPointMultiply(Expr tensor, double multiplier, const Array<IndexExpr>&
tensor = tensor =
RightShift(tensor, MakeConstantScalar(hp_dtype, total_right_shift)); RightShift(tensor, MakeConstantScalar(hp_dtype, total_right_shift));
return tensor; // 6) The fixed point multiplication keeps the value in int32 range. Casting back to int32.
return Cast(tensor, DataType::Int(32));
} }
Expr FixedPointMultiplyPerChannel(Expr tensor, std::vector<double> multipliers, Expr FixedPointMultiplyPerChannel(Expr tensor, std::vector<double> multipliers,
...@@ -145,6 +147,7 @@ Expr FixedPointMultiplyPerChannel(Expr tensor, std::vector<double> multipliers, ...@@ -145,6 +147,7 @@ Expr FixedPointMultiplyPerChannel(Expr tensor, std::vector<double> multipliers,
// Choose high precision datatype to be int64. This is for avoiding overflow // Choose high precision datatype to be int64. This is for avoiding overflow
// in multiplication of two int32 values. // in multiplication of two int32 values.
DataType hp_dtype = DataType::Int(64); DataType hp_dtype = DataType::Int(64);
tensor = Cast(tensor, hp_dtype);
// 1) Calculating the integer multiplier and integer shift. These are calculated per axis/per // 1) Calculating the integer multiplier and integer shift. These are calculated per axis/per
// channel. // channel.
...@@ -218,7 +221,8 @@ Expr FixedPointMultiplyPerChannel(Expr tensor, std::vector<double> multipliers, ...@@ -218,7 +221,8 @@ Expr FixedPointMultiplyPerChannel(Expr tensor, std::vector<double> multipliers,
auto exp_total_rshift_expr = ExpandBiasToMatchAxis(total_rshift_expr, n_dim, {channel_axis}); auto exp_total_rshift_expr = ExpandBiasToMatchAxis(total_rshift_expr, n_dim, {channel_axis});
tensor = RightShift(tensor, exp_total_rshift_expr); tensor = RightShift(tensor, exp_total_rshift_expr);
return tensor; // 6) The fixed point multiplication keeps the value in int32 range. Casting back to int32.
return Cast(tensor, DataType::Int(32));
} }
} // namespace qnn } // namespace qnn
......
...@@ -117,8 +117,7 @@ inline Expr MulAndDiv(Expr data, float s1, float s2, DataType dtype, ...@@ -117,8 +117,7 @@ inline Expr MulAndDiv(Expr data, float s1, float s2, DataType dtype,
} else if (static_cast<int>(factor) == factor) { } else if (static_cast<int>(factor) == factor) {
return Multiply(data, MakeConstantScalar(dtype, factor)); return Multiply(data, MakeConstantScalar(dtype, factor));
} else { } else {
data = qnn::FixedPointMultiply( data = qnn::FixedPointMultiply(data, factor, data_shape, cfg->rounding);
Cast(data, DataType::Int(64)), factor, data_shape, cfg->rounding);
return Cast(data, dtype); return Cast(data, dtype);
} }
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment