Unverified Commit 24c53a34 by masahi Committed by GitHub

[QNN] More doc fix on quantize and convolution (#4874)

* [QNN] Doc fix on quantize and convolution

* update test
parent 7013fc9a
...@@ -104,7 +104,7 @@ def quantize(data, ...@@ -104,7 +104,7 @@ def quantize(data,
axis : int axis : int
The channel axis for quantization. Default value is -1 which corresponds to the last axis. The channel axis for quantization. Default value is -1 which corresponds to the last axis.
out_dtype : str, optional out_dtype : str, optional
The data type of the input tensor. Can be [int8, uint8] The data type of the input tensor. Can be [int8, uint8, int32]
Returns Returns
------- -------
result : tvm.relay.Expr result : tvm.relay.Expr
...@@ -202,11 +202,11 @@ def conv2d(data, ...@@ -202,11 +202,11 @@ def conv2d(data,
input_scale, input_scale,
kernel_scale, kernel_scale,
kernel_size, kernel_size,
channels,
strides=(1, 1), strides=(1, 1),
padding=(0, 0), padding=(0, 0),
dilation=(1, 1), dilation=(1, 1),
groups=1, groups=1,
channels=None,
data_layout="NCHW", data_layout="NCHW",
kernel_layout="OIHW", kernel_layout="OIHW",
out_layout="", out_layout="",
...@@ -247,6 +247,9 @@ def conv2d(data, ...@@ -247,6 +247,9 @@ def conv2d(data,
kernel_size : tuple of int kernel_size : tuple of int
The spatial width and height of the convolution kernel. The spatial width and height of the convolution kernel.
channels : int
Number of output channels of this convolution.
strides : tuple of int, optional strides : tuple of int, optional
The strides of convolution. The strides of convolution.
...@@ -259,9 +262,6 @@ def conv2d(data, ...@@ -259,9 +262,6 @@ def conv2d(data,
groups : int, optional groups : int, optional
Number of groups for grouped convolution. Number of groups for grouped convolution.
channels : int, optional
Number of output channels of this convolution.
data_layout : str, optional data_layout : str, optional
Layout of the input. Layout of the input.
......
...@@ -79,8 +79,8 @@ def get_qnn_func(data, ...@@ -79,8 +79,8 @@ def get_qnn_func(data,
data_layout, data_layout,
kernel_layout, kernel_layout,
out_dtype, out_dtype,
groups, channels,
channels=None): groups):
func = relay.qnn.op.conv2d( func = relay.qnn.op.conv2d(
data, kernel, data, kernel,
input_zero_point=relay.const(input_zero_point, 'int32'), input_zero_point=relay.const(input_zero_point, 'int32'),
...@@ -116,12 +116,23 @@ def get_funcs(data_shape, ...@@ -116,12 +116,23 @@ def get_funcs(data_shape,
data_layout, data_layout,
kernel_layout, kernel_layout,
out_dtype, out_dtype,
groups=1, groups=1):
channels=None):
data = relay.var("data", shape=data_shape, data = relay.var("data", shape=data_shape,
dtype=data_dtype) dtype=data_dtype)
kernel = relay.var("kernel", shape=kernel_shape, kernel = relay.var("kernel", shape=kernel_shape,
dtype=kernel_dtype) dtype=kernel_dtype)
if groups > 1:
channels = groups
elif kernel_layout == "OIHW":
channels = kernel_shape[0]
elif kernel_layout == "HWIO":
channels = kernel_shape[3]
elif kernel_layout == "HWOI":
channels = kernel_shape[2]
else:
raise NotImplementedError
ref_func = get_ref_func(data, ref_func = get_ref_func(data,
kernel, kernel,
input_zero_point, input_zero_point,
...@@ -152,8 +163,9 @@ def get_funcs(data_shape, ...@@ -152,8 +163,9 @@ def get_funcs(data_shape,
data_layout, data_layout,
kernel_layout, kernel_layout,
out_dtype, out_dtype,
groups, channels,
channels) groups)
return (ref_func, qnn_func) return (ref_func, qnn_func)
def verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape, def verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape,
...@@ -418,7 +430,7 @@ def test_layout(): ...@@ -418,7 +430,7 @@ def test_layout():
verify(ref_func, qnn_func, data_shape, data_dtype, verify(ref_func, qnn_func, data_shape, data_dtype,
kernel_shape, kernel_dtype) kernel_shape, kernel_dtype)
# NHWC and HWIO layout. Used in depthwise conv. # NHWC and HWOI layout. Used in depthwise conv.
data_shape = (2, 2, 4, 1) # NHWC data_shape = (2, 2, 4, 1) # NHWC
data_dtype = 'uint8' data_dtype = 'uint8'
kernel_shape = (2, 2, 1, 1) # HWOI kernel_shape = (2, 2, 1, 1) # HWOI
...@@ -568,6 +580,7 @@ def test_const_folding(): ...@@ -568,6 +580,7 @@ def test_const_folding():
data_layout="NCHW", data_layout="NCHW",
kernel_layout="OIHW", kernel_layout="OIHW",
out_dtype="int32", out_dtype="int32",
channels=kernel_shape[0],
groups=1) groups=1)
folded_mod = transform.FoldConstant()(qnn_func) folded_mod = transform.FoldConstant()(qnn_func)
folded_func = folded_mod["main"] folded_func = folded_mod["main"]
...@@ -787,8 +800,8 @@ def test_depthwise_depth_multiplier(): ...@@ -787,8 +800,8 @@ def test_depthwise_depth_multiplier():
data_layout="NCHW", data_layout="NCHW",
kernel_layout="OIHW", kernel_layout="OIHW",
out_dtype="int32", out_dtype="int32",
groups=4, groups=4)
channels=4)
verify(ref_func, qnn_func, data_shape, data_dtype, verify(ref_func, qnn_func, data_shape, data_dtype,
kernel_shape, kernel_dtype) kernel_shape, kernel_dtype)
...@@ -813,8 +826,7 @@ def test_depthwise_depth_multiplier(): ...@@ -813,8 +826,7 @@ def test_depthwise_depth_multiplier():
data_layout="NCHW", data_layout="NCHW",
kernel_layout="OIHW", kernel_layout="OIHW",
out_dtype="int32", out_dtype="int32",
groups=8, groups=8)
channels=8)
verify(ref_func, qnn_func, data_shape, data_dtype, verify(ref_func, qnn_func, data_shape, data_dtype,
kernel_shape, kernel_dtype) kernel_shape, kernel_dtype)
...@@ -839,8 +851,7 @@ def test_depthwise_depth_multiplier(): ...@@ -839,8 +851,7 @@ def test_depthwise_depth_multiplier():
data_layout="NHWC", data_layout="NHWC",
kernel_layout="HWOI", kernel_layout="HWOI",
out_dtype="int32", out_dtype="int32",
groups=4, groups=4)
channels=4)
verify(ref_func, qnn_func, data_shape, data_dtype, verify(ref_func, qnn_func, data_shape, data_dtype,
kernel_shape, kernel_dtype) kernel_shape, kernel_dtype)
...@@ -864,8 +875,7 @@ def test_depthwise_depth_multiplier(): ...@@ -864,8 +875,7 @@ def test_depthwise_depth_multiplier():
data_layout="NHWC", data_layout="NHWC",
kernel_layout="HWOI", kernel_layout="HWOI",
out_dtype="int32", out_dtype="int32",
groups=8, groups=8)
channels=8)
verify(ref_func, qnn_func, data_shape, data_dtype, verify(ref_func, qnn_func, data_shape, data_dtype,
kernel_shape, kernel_dtype) kernel_shape, kernel_dtype)
...@@ -888,6 +898,7 @@ def test_per_channel_kernel_scale(): ...@@ -888,6 +898,7 @@ def test_per_channel_kernel_scale():
input_scale=relay.const(2.0, 'float32'), input_scale=relay.const(2.0, 'float32'),
kernel_scale=kernel_scales, kernel_scale=kernel_scales,
kernel_size=(2, 2), kernel_size=(2, 2),
channels=kernel_shape[0],
padding=(0, 0), padding=(0, 0),
strides=(1, 1), strides=(1, 1),
dilation=(1, 1), dilation=(1, 1),
......
...@@ -107,6 +107,7 @@ def test_qnn_legalize_qnn_conv2d(): ...@@ -107,6 +107,7 @@ def test_qnn_legalize_qnn_conv2d():
input_scale=relay.const(1.0, 'float32'), input_scale=relay.const(1.0, 'float32'),
kernel_scale=relay.const(1.0, 'float32'), kernel_scale=relay.const(1.0, 'float32'),
kernel_size=(3, 3), kernel_size=(3, 3),
channels=kernel_shape[0],
strides=(1, 1), strides=(1, 1),
dilation=(1, 1), dilation=(1, 1),
out_dtype='int32', out_dtype='int32',
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment