Commit d7998d39 by Animesh Jain Committed by Zhi

[QNN][Conv2D] Optimize lowering. (#4006)

parent b330d301
...@@ -217,15 +217,6 @@ Expr Conv2DSecondTerm(const Expr& padded_data, const Expr& zp_kernel, const QnnC ...@@ -217,15 +217,6 @@ Expr Conv2DSecondTerm(const Expr& padded_data, const Expr& zp_kernel, const QnnC
auto scaled_hw_t2 = Multiply(casted_t2, MakeConstantScalar(Int(32), kernel_h * kernel_w)); auto scaled_hw_t2 = Multiply(casted_t2, MakeConstantScalar(Int(32), kernel_h * kernel_w));
Array<IndexExpr> padding({0, 0}); Array<IndexExpr> padding({0, 0});
// If the pool_size is 1x1, we don't need avg_pool2d.
auto reduced_hw_t2 = scaled_hw_t2;
if (kernel_h * kernel_w != 1) {
reduced_hw_t2 =
AvgPool2D(scaled_hw_t2, param->kernel_size, param->strides, padding, param->data_layout,
false, // ceil_mode
false); // count_include_pad
}
// Reduce the C dimension. Find the dimension. // Reduce the C dimension. Find the dimension.
Array<Integer> axes_t2; Array<Integer> axes_t2;
if (param->data_layout == "NCHW") { if (param->data_layout == "NCHW") {
...@@ -236,7 +227,17 @@ Expr Conv2DSecondTerm(const Expr& padded_data, const Expr& zp_kernel, const QnnC ...@@ -236,7 +227,17 @@ Expr Conv2DSecondTerm(const Expr& padded_data, const Expr& zp_kernel, const QnnC
LOG(FATAL) << "qnn.conv2d does not support " << param->data_layout << " layout"; LOG(FATAL) << "qnn.conv2d does not support " << param->data_layout << " layout";
} }
// Keep dims true to retain 4D tensor // Keep dims true to retain 4D tensor
auto reduced_t2 = Sum(reduced_hw_t2, axes_t2, true, false); auto reduced_c_t2 = Sum(scaled_hw_t2, axes_t2, true, false);
// If the pool_size is 1x1, we don't need avg_pool2d.
auto reduced_t2 = reduced_c_t2;
if (kernel_h * kernel_w != 1) {
reduced_t2 =
AvgPool2D(reduced_c_t2, param->kernel_size, param->strides, padding, param->data_layout,
false, // ceil_mode
false); // count_include_pad
}
auto multiplied_t2 = reduced_t2; auto multiplied_t2 = reduced_t2;
if (param->kernel_zero_point != 1) { if (param->kernel_zero_point != 1) {
multiplied_t2 = Multiply(zp_kernel, reduced_t2); multiplied_t2 = Multiply(zp_kernel, reduced_t2);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment