Unverified Commit 0755e4a5 by Animesh Jain Committed by GitHub

[QNN] Support 4D padding. (#5036)

* [QNN] Support 4D padding.

* Empty commit.

Co-authored-by: Ubuntu <ubuntu@ip-172-31-38-96.us-west-2.compute.internal>
parent ab839933
...@@ -19,6 +19,7 @@ ...@@ -19,6 +19,7 @@
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
from tvm.relay.expr import Tuple from tvm.relay.expr import Tuple
from tvm.relay.op.nn.util import get_pad_tuple2d
from . import _make from . import _make
def requantize(data, def requantize(data,
...@@ -280,6 +281,9 @@ def conv2d(data, ...@@ -280,6 +281,9 @@ def conv2d(data,
The computed result. The computed result.
""" """
# TODO enforce 4-way padding in topi/nn/conv2d after #4644 merged
# convert 2-way padding to 4-way padding
padding = get_pad_tuple2d(padding)
return _make.conv2d(data, kernel, return _make.conv2d(data, kernel,
input_zero_point, kernel_zero_point, input_zero_point, kernel_zero_point,
input_scale, kernel_scale, input_scale, kernel_scale,
......
...@@ -177,13 +177,17 @@ Expr Conv2DFallBack(const Expr& data, const Expr& weight, const Expr& input_zero ...@@ -177,13 +177,17 @@ Expr Conv2DFallBack(const Expr& data, const Expr& weight, const Expr& input_zero
Expr Conv2DPadInput(const Expr& data, const Expr& input_zero_point, const Conv2DAttrs* param) { Expr Conv2DPadInput(const Expr& data, const Expr& input_zero_point, const Conv2DAttrs* param) {
// 1) Pad the input data // 1) Pad the input data
auto padded_data = data; auto padded_data = data;
auto pad_h_value = get_const_int(param->padding[0]); auto pad_top_value = get_const_int(param->padding[0]);
auto pad_w_value = get_const_int(param->padding[1]); auto pad_left_value = get_const_int(param->padding[1]);
if (pad_h_value != 0 || pad_w_value != 0) { auto pad_bottom_value = get_const_int(param->padding[2]);
auto pad_right_value = get_const_int(param->padding[3]);
bool do_pad = pad_top_value != 0 || pad_left_value != 0 ||
pad_bottom_value != 0 || pad_right_value != 0;
if (do_pad) {
Array<IndexExpr> pad_n({0, 0}); Array<IndexExpr> pad_n({0, 0});
Array<IndexExpr> pad_c({0, 0}); Array<IndexExpr> pad_c({0, 0});
Array<IndexExpr> pad_h({param->padding[0], param->padding[0]}); Array<IndexExpr> pad_h({param->padding[0], param->padding[2]});
Array<IndexExpr> pad_w({param->padding[1], param->padding[1]}); Array<IndexExpr> pad_w({param->padding[1], param->padding[3]});
Array<Array<IndexExpr>> pad_width; Array<Array<IndexExpr>> pad_width;
if (param->data_layout == "NCHW") { if (param->data_layout == "NCHW") {
...@@ -336,7 +340,7 @@ Expr DepthwiseConv2DFourthTerm(int input_zero_point_int, int kernel_zero_point_i ...@@ -336,7 +340,7 @@ Expr DepthwiseConv2DFourthTerm(int input_zero_point_int, int kernel_zero_point_i
*/ */
Expr Conv2DFirstTerm(const Expr& padded_data, const Expr& weight, const Conv2DAttrs* param) { Expr Conv2DFirstTerm(const Expr& padded_data, const Expr& weight, const Conv2DAttrs* param) {
// Lowering for Term 1 // Lowering for Term 1
Array<IndexExpr> padding({0, 0}); Array<IndexExpr> padding({0, 0, 0, 0});
return Conv2D(padded_data, weight, param->strides, padding, param->dilation, param->groups, return Conv2D(padded_data, weight, param->strides, padding, param->dilation, param->groups,
param->channels, param->kernel_size, param->data_layout, param->kernel_layout, param->channels, param->kernel_size, param->data_layout, param->kernel_layout,
param->out_layout, param->out_dtype); param->out_layout, param->out_dtype);
...@@ -583,7 +587,6 @@ Expr QnnConv2DCanonicalize(const Attrs& attrs, const Array<Expr>& new_args, ...@@ -583,7 +587,6 @@ Expr QnnConv2DCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
const auto* param = attrs.as<Conv2DAttrs>(); const auto* param = attrs.as<Conv2DAttrs>();
CHECK(param != nullptr); CHECK(param != nullptr);
// Assertion checks for exisiing support. // Assertion checks for exisiing support.
CHECK_EQ(param->padding.size(), 2) << "qnn.conv2d only supports 2D padding";
CHECK(param->data_layout == "NCHW" || param->data_layout == "NHWC") CHECK(param->data_layout == "NCHW" || param->data_layout == "NHWC")
<< "qnn.conv2d supports only NCHW/NHWC input data layout."; << "qnn.conv2d supports only NCHW/NHWC input data layout.";
CHECK(param->kernel_layout == "OIHW" || param->kernel_layout == "HWIO" || CHECK(param->kernel_layout == "OIHW" || param->kernel_layout == "HWIO" ||
......
...@@ -496,6 +496,30 @@ def test_padding(): ...@@ -496,6 +496,30 @@ def test_padding():
verify(ref_func, qnn_func, data_shape, data_dtype, verify(ref_func, qnn_func, data_shape, data_dtype,
kernel_shape, kernel_dtype) kernel_shape, kernel_dtype)
# Try asymmetric padding
data_shape = (2, 2, 4, 4) # NHWC
data_dtype = 'uint8'
kernel_shape = (2, 2, 4, 3) # HWIO
kernel_dtype = 'uint8'
ref_func, qnn_func = get_funcs(data_shape=data_shape,
data_dtype=data_dtype,
kernel_shape=kernel_shape,
kernel_dtype=kernel_dtype,
input_zero_point=8,
kernel_zero_point=3,
input_scale=1.0,
kernel_scale=1.0,
kernel_size=(2, 2),
padding=(1, 1, 2, 2),
strides=(1, 1),
dilation=(1, 1),
data_layout="NHWC",
kernel_layout="HWIO",
out_dtype="int32")
verify(ref_func, qnn_func, data_shape, data_dtype,
kernel_shape, kernel_dtype)
def test_dilation(): def test_dilation():
with TempOpAttr("qnn.conv2d", "FTVMQnnLegalize", legalize_qnn_conv2d): with TempOpAttr("qnn.conv2d", "FTVMQnnLegalize", legalize_qnn_conv2d):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment