Commit a5bb789a by Animesh Jain, committed by Zhi

[QNN] Conv2D type checking for kernel per-channel scales. (#4732)

* [QNN] Conv2D type checking for kernel per-channel scales.

* Address comments.

* Address comments.

* - Adding safety checks for downcasts.

Co-authored-by: shoubhik <shoubhikbhatti@gmail.com>
parent 03ffb01c
...
@@ -57,7 +57,10 @@ bool QnnConv2DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
   CHECK(IsScalarType(types[2], DataType::Int(32)));    // input_zero_point
   CHECK(IsScalarType(types[3], DataType::Int(32)));    // kernel_zero_point
   CHECK(IsScalarType(types[4], DataType::Float(32)));  // input_scale
-  CHECK(IsScalarType(types[5], DataType::Float(32)));  // kernel_scale
+  // Kernel scale can be a vector of length output_channels or a scalar.
+  size_t axis = param->kernel_layout.find('O');
+  CHECK(axis != std::string::npos) << "Kernel layout attribute is not defined";
+  AssignType(types[5], DataType::Float(32), weight->shape[axis], reporter);  // kernel scale
   // Collect the input tensor and output tensor devoid of scale and zero points to reuse Relay
   // Conv2D infer type function.
...
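For context, the rule this hunk enforces is simple: a per-channel kernel scale must have exactly as many elements as the kernel's output-channel axis, and that axis is located by looking up 'O' in the kernel layout string. A minimal Python sketch of the same rule (the helper name expected_kernel_scale_size is illustrative only, not part of the codebase):

def expected_kernel_scale_size(kernel_shape, kernel_layout):
    # Locate the output-channel axis via the layout string, as the C++ check does.
    axis = kernel_layout.find('O')
    assert axis != -1, "Kernel layout attribute is not defined"
    return kernel_shape[axis]

# For the OIHW kernel of shape (3, 1, 2, 2) used in the new test below,
# a per-channel kernel_scale needs 3 elements; a scalar is also accepted.
assert expected_kernel_scale_size((3, 1, 2, 2), "OIHW") == 3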
...
@@ -152,6 +152,8 @@ Expr FixedPointMultiplyPerChannel(Expr tensor, std::vector<double> multiplier,
  */
 static inline bool IsScalarType(const Type& expr_type, const DataType& dtype) {
   const auto* tensor_type = expr_type.as<TensorTypeNode>();
+  CHECK(tensor_type) << "Only tensor type can be checked for scalar values. But got"
+                     << AsText(expr_type, false);
   CHECK_EQ(tensor_type->shape.size(), 0);
   CHECK(tensor_type->dtype == dtype) << "Expected " << dtype << " but got " << tensor_type->dtype;
   return true;
...
@@ -168,6 +170,8 @@ static inline void AssignType(const Type& expr_type, const DataType& dtype, cons
                               const TypeReporter& reporter) {
   // Scale/Zero_points can be either const scalar or a vector with C axis num elems.
   const auto* tensor_type = expr_type.as<TensorTypeNode>();
+  CHECK(tensor_type) << "Can assign type to Tensor type only. But got "
+                     << AsText(expr_type, false);
   const auto tensor_dtype = tensor_type->dtype;
   CHECK(tensor_dtype == dtype) << "Expected type is " << dtype << " but received " << tensor_dtype;
   if (tensor_type->shape.size() != 0) {
...
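At the Relay frontend, the comment above translates into two accepted forms for a quantization parameter such as kernel_scale. A small sketch of both forms, assuming the standard relay.const constructor (the variable names are illustrative):

import numpy as np
from tvm import relay

scalar_scale = relay.const(2.0, 'float32')  # scalar: one scale for the whole kernel
per_channel_scale = relay.const(np.array([2.0, 2.0, 2.0], dtype='float32'))  # one scale per output channel ('O' axis of length 3)

With the added CHECKs, an argument that is not a tensor type at all should now fail with a readable error message instead of crashing on a null downcast.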
...
@@ -768,8 +768,8 @@ def test_depthwise_depth_multiplier():
                              channels=4)
         verify(ref_func, qnn_func, data_shape, data_dtype,
                kernel_shape, kernel_dtype)
         # Depthwise multiplier = 2
         data_shape = (10, 4, 16, 16)
         data_dtype = 'uint8'
...
@@ -794,7 +794,7 @@ def test_depthwise_depth_multiplier():
                              channels=8)
         verify(ref_func, qnn_func, data_shape, data_dtype,
                kernel_shape, kernel_dtype)
         # uint8 input, NHWC and HWOI
         # Depthwise multiplier = 1
         data_shape = (2, 16, 16, 4)
...
@@ -820,7 +820,7 @@ def test_depthwise_depth_multiplier():
                              channels=4)
         verify(ref_func, qnn_func, data_shape, data_dtype,
                kernel_shape, kernel_dtype)
         # Depthwise multiplier = 2
         data_shape = (2, 16, 16, 4)
         data_dtype = 'uint8'
...
@@ -846,6 +846,35 @@ def test_depthwise_depth_multiplier():
         verify(ref_func, qnn_func, data_shape, data_dtype,
                kernel_shape, kernel_dtype)

+def test_per_channel_kernel_scale():
+    with TempOpAttr("qnn.conv2d", "FTVMQnnLegalize", legalize_qnn_conv2d):
+        data_shape = (2, 1, 2, 4)
+        data_dtype = 'uint8'
+        kernel_shape = (3, 1, 2, 2)
+        kernel_dtype = 'uint8'
+        data = relay.var("data", shape=data_shape,
+                         dtype=data_dtype)
+        kernel = relay.var("kernel", shape=kernel_shape,
+                           dtype=kernel_dtype)
+        kernel_scales = [2, 2, 2]
+        kernel_scales = relay.const(np.array(kernel_scales).astype('float32'))
+        func = relay.qnn.op.conv2d(
+            data, kernel,
+            input_zero_point=relay.const(0, 'int32'),
+            kernel_zero_point=relay.const(0, 'int32'),
+            input_scale=relay.const(2.0, 'float32'),
+            kernel_scale=kernel_scales,
+            kernel_size=(2, 2),
+            padding=(0, 0),
+            strides=(1, 1),
+            dilation=(1, 1),
+            data_layout="NCHW",
+            kernel_layout="OIHW",
+            out_dtype="int32")
+
+        mod = relay.Function(relay.analysis.free_vars(func), func)
+        mod = relay.Module.from_expr(mod)
+
 if __name__ == "__main__":
     test_no_zero_point()
     test_input_zero_point()
...
@@ -861,3 +890,4 @@ if __name__ == "__main__":
     test_tflite_output_multiplier_greater_than_one()
     test_tflite_anistropic_strides()
     test_depthwise_depth_multiplier()
+    test_per_channel_kernel_scale()
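The new test builds a module with a three-element kernel_scale but only exercises the accepting path. As a complementary sanity check (not part of the commit), one could run Relay's type inference on a module whose kernel_scale length does not match the kernel's 'O' axis and confirm it is rejected with a type error rather than a crash. This is a sketch that assumes the relay.transform.InferType pass and the relay.Module API used elsewhere in the test file:

import numpy as np
import tvm
from tvm import relay

# Same shapes as test_per_channel_kernel_scale: OIHW kernel with 3 output channels.
data = relay.var("data", shape=(2, 1, 2, 4), dtype='uint8')
kernel = relay.var("kernel", shape=(3, 1, 2, 2), dtype='uint8')
conv = relay.qnn.op.conv2d(
    data, kernel,
    input_zero_point=relay.const(0, 'int32'),
    kernel_zero_point=relay.const(0, 'int32'),
    input_scale=relay.const(2.0, 'float32'),
    # Deliberately wrong length: 2 scales for 3 output channels.
    kernel_scale=relay.const(np.array([2.0, 2.0], dtype='float32')),
    kernel_size=(2, 2),
    padding=(0, 0),
    strides=(1, 1),
    dilation=(1, 1),
    data_layout="NCHW",
    kernel_layout="OIHW",
    out_dtype="int32")
func = relay.Function(relay.analysis.free_vars(conv), conv)
mod = relay.Module.from_expr(func)

try:
    relay.transform.InferType()(mod)  # runs QnnConv2DRel, which should flag the mismatch
except tvm.TVMError as err:
    print("Rejected as expected:", err)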