Commit 880c2603 by shoubhik, committed by Zhi

Do type checking for the input and kernel in the qnn conv2d (#3904)

* [QNN] Convolution 2D Implementation.

Rebasing. Empty commit.

Clang-format styling.

* Reformatting code.

* Fixing lint issues.
parent 88f9bfd4
@@ -40,6 +40,26 @@ namespace qnn {
// relay.op.qnn.conv2d
TVM_REGISTER_NODE_TYPE(QnnConv2DAttrs);
bool QnnConv2DRel(const Array<Type>& types,
                  int num_inputs,
                  const Attrs& attrs,
                  const TypeReporter& reporter) {
  CHECK_EQ(types.size(), 3);
  const auto* data = types[0].as<TensorTypeNode>();
  const auto* weight = types[1].as<TensorTypeNode>();
  if (data == nullptr || weight == nullptr) return false;
  const auto* param = attrs.as<QnnConv2DAttrs>();
  CHECK(param != nullptr) << "QnnConv2DAttrs cannot be nullptr.";
  CHECK(data->dtype == Int(8) || data->dtype == UInt(8))
      << "Expected qnn conv2d type(int8, uint8) for input but was " << data->dtype;
  CHECK(weight->dtype == Int(8) || weight->dtype == UInt(8))
      << "Expected qnn conv2d type(int8, uint8) for weight but was " << weight->dtype;
  CHECK(param->out_dtype == Int(16) || param->out_dtype == Int(32))
      << "Expected qnn conv2d type(int32, int16) for output but was " << param->out_dtype;
  CHECK(param->out_dtype.bits() > 0) << "Output dtype bits should be greater than 0.";
  return Conv2DRel<QnnConv2DAttrs>(types, num_inputs, attrs, reporter);
}
// Workload - batch_size, in_channels, out_channels, kernel_h, kernel_w
using WorkloadType = std::tuple<int, int, int, int, int>;

@@ -475,7 +495,7 @@ operator to understand how to scale back the int32 output to (u)int8.
.add_argument("data", "Tensor", "The quantized input data tensor.")
.add_argument("weight", "Tensor", "The quantized weight tensor.")
.set_support_level(11)
.add_type_rel("QnnConv2D", Conv2DRel<QnnConv2DAttrs>) .add_type_rel("QnnConv2D", QnnConv2DRel)
.set_attr<FTVMLegalize>("FTVMQnnCanonicalize", QnnConv2DCanonicalize);

TVM_REGISTER_API("relay.qnn.op._make.conv2d").set_body_typed(MakeQnnConv2D);
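
For context (not part of the commit): a minimal sketch of how the stricter type relation surfaces from the Python frontend. The relay.qnn.op.conv2d keyword arguments shown here are assumptions based on the API around this TVM version and may differ in other releases.

from tvm import relay

# Hedged sketch: int8/uint8 operands with an int32 accumulator satisfy the
# new QnnConv2DRel checks. Keyword-argument names are assumptions for this
# TVM version and may vary across releases.
data = relay.var("data", shape=(1, 3, 224, 224), dtype="uint8")   # int8/uint8 required
weight = relay.var("weight", shape=(16, 3, 3, 3), dtype="int8")   # int8/uint8 required
conv = relay.qnn.op.conv2d(
    data, weight,
    input_zero_point=0, kernel_zero_point=0,
    kernel_size=(3, 3), channels=16,
    out_dtype="int32")                                             # int16/int32 accumulator

# Declaring `data` (or `weight`) as float32 would now fail during type inference
# with "Expected qnn conv2d type(int8, uint8) ..." instead of falling through to
# the generic Conv2DRel relation.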