Commit 8eb3157a by Animesh Jain Committed by Tianqi Chen

[QNN] Fix padding changes due to #3739 (#3989)

parent cb1faf8a
...@@ -476,10 +476,12 @@ static inline Expr AvgPool2D(Expr data, Array<IndexExpr> pool_size, Array<IndexE ...@@ -476,10 +476,12 @@ static inline Expr AvgPool2D(Expr data, Array<IndexExpr> pool_size, Array<IndexE
return CallNode::make(op, {data}, Attrs(attrs), {}); return CallNode::make(op, {data}, Attrs(attrs), {});
} }
static inline Expr Pad(Expr data, Array<Array<IndexExpr>> pad_width, double pad_value) { static inline Expr Pad(Expr data, Array<Array<IndexExpr>> pad_width, double pad_value,
std::string pad_mode) {
auto attrs = make_node<PadAttrs>(); auto attrs = make_node<PadAttrs>();
attrs->pad_value = pad_value; attrs->pad_value = pad_value;
attrs->pad_width = std::move(pad_width); attrs->pad_width = std::move(pad_width);
attrs->pad_mode = std::move(pad_mode);
static const Op& op = Op::Get("nn.pad"); static const Op& op = Op::Get("nn.pad");
return CallNode::make(op, {data}, Attrs(attrs), {}); return CallNode::make(op, {data}, Attrs(attrs), {});
} }
......
...@@ -167,7 +167,7 @@ Expr Conv2DPadInput(const Expr& data, const QnnConv2DAttrs* param) { ...@@ -167,7 +167,7 @@ Expr Conv2DPadInput(const Expr& data, const QnnConv2DAttrs* param) {
} else { } else {
LOG(FATAL) << "qnn.conv2d does not support " << param->data_layout << " layout"; LOG(FATAL) << "qnn.conv2d does not support " << param->data_layout << " layout";
} }
padded_data = Pad(data, pad_width, param->input_zero_point); padded_data = Pad(data, pad_width, param->input_zero_point, "constant");
} }
return padded_data; return padded_data;
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment