Unverified Commit 24c53a34 by masahi Committed by GitHub

[QNN] More doc fix on quantize and convolution (#4874)

* [QNN] Doc fix on quantize and convolution

* update test
parent 7013fc9a
@@ -104,7 +104,7 @@ def quantize(data,
 axis : int
 The channel axis for quantization. Default value is -1 which corresponds to the last axis.
 out_dtype : str, optional
-The data type of the input tensor. Can be [int8, uint8]
+The data type of the input tensor. Can be [int8, uint8, int32]
 Returns
 -------
 result : tvm.relay.Expr
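
For context, this hunk only changes the docstring: out_dtype of quantize may also be int32. A minimal usage sketch, assuming the usual output_scale / output_zero_point parameters of relay.qnn.op.quantize; the shape and quantization values below are illustrative only:

from tvm import relay

# Quantize a float32 tensor; out_dtype can be int8, uint8, or int32
# per the docstring updated above.
data = relay.var("data", shape=(1, 64), dtype="float32")
q = relay.qnn.op.quantize(data,
                          output_scale=relay.const(0.5, "float32"),
                          output_zero_point=relay.const(0, "int32"),
                          axis=-1,
                          out_dtype="int32")
func = relay.Function([data], q)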
@@ -202,11 +202,11 @@ def conv2d(data,
 input_scale,
 kernel_scale,
 kernel_size,
+channels,
 strides=(1, 1),
 padding=(0, 0),
 dilation=(1, 1),
 groups=1,
-channels=None,
 data_layout="NCHW",
 kernel_layout="OIHW",
 out_layout="",
@@ -247,6 +247,9 @@ def conv2d(data,
 kernel_size : tuple of int
 The spatial width and height of the convolution kernel.
+channels : int
+Number of output channels of this convolution.
 strides : tuple of int, optional
 The strides of convolution.
@@ -259,9 +262,6 @@ def conv2d(data,
 groups : int, optional
 Number of groups for grouped convolution.
-channels : int, optional
-Number of output channels of this convolution.
 data_layout : str, optional
 Layout of the input.
......
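
Taken together, the hunks above make channels a required, documented argument of relay.qnn.op.conv2d instead of an optional keyword defaulting to None. A hedged sketch of a call site after this change, using only parameter names that appear elsewhere in this diff plus kernel_zero_point; shapes and quantization parameters are invented for illustration:

from tvm import relay

data = relay.var("data", shape=(1, 3, 8, 8), dtype="uint8")      # NCHW
kernel = relay.var("kernel", shape=(16, 3, 2, 2), dtype="uint8")  # OIHW

conv = relay.qnn.op.conv2d(
    data, kernel,
    input_zero_point=relay.const(0, "int32"),
    kernel_zero_point=relay.const(0, "int32"),
    input_scale=relay.const(1.0, "float32"),
    kernel_scale=relay.const(1.0, "float32"),
    kernel_size=(2, 2),
    channels=16,          # now stated explicitly, no longer defaults to None
    strides=(1, 1),
    padding=(0, 0),
    dilation=(1, 1),
    groups=1,
    data_layout="NCHW",
    kernel_layout="OIHW",
    out_dtype="int32")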
@@ -79,8 +79,8 @@ def get_qnn_func(data,
 data_layout,
 kernel_layout,
 out_dtype,
-groups,
-channels=None):
+channels,
+groups):
 func = relay.qnn.op.conv2d(
 data, kernel,
 input_zero_point=relay.const(input_zero_point, 'int32'),
@@ -116,12 +116,23 @@ def get_funcs(data_shape,
 data_layout,
 kernel_layout,
 out_dtype,
-groups=1,
-channels=None):
+groups=1):
 data = relay.var("data", shape=data_shape,
 dtype=data_dtype)
 kernel = relay.var("kernel", shape=kernel_shape,
 dtype=kernel_dtype)
+if groups > 1:
+    channels = groups
+elif kernel_layout == "OIHW":
+    channels = kernel_shape[0]
+elif kernel_layout == "HWIO":
+    channels = kernel_shape[3]
+elif kernel_layout == "HWOI":
+    channels = kernel_shape[2]
+else:
+    raise NotImplementedError
 ref_func = get_ref_func(data,
 kernel,
 input_zero_point,
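
The block added to get_funcs infers the output-channel count from the kernel layout (or from groups for grouped convolutions) so that individual tests no longer pass channels themselves. The same rule, restated as a small standalone helper with the hypothetical name infer_output_channels:

def infer_output_channels(kernel_shape, kernel_layout, groups=1):
    """Mirror of the inference added to get_funcs above (hypothetical helper).

    For grouped (depthwise-style) convolutions the tests use channels == groups;
    otherwise the output-channel count is read from the 'O' axis of the layout.
    """
    if groups > 1:
        return groups
    # Position of the output-channel axis in the kernel layout string.
    axis = {"OIHW": 0, "HWIO": 3, "HWOI": 2}.get(kernel_layout)
    if axis is None:
        raise NotImplementedError("unsupported kernel layout " + kernel_layout)
    return kernel_shape[axis]

# Examples matching the layouts exercised in this test file:
assert infer_output_channels((16, 3, 2, 2), "OIHW") == 16
assert infer_output_channels((2, 2, 3, 16), "HWIO") == 16
assert infer_output_channels((2, 2, 1, 1), "HWOI", groups=4) == 4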
@@ -152,8 +163,9 @@ def get_funcs(data_shape,
 data_layout,
 kernel_layout,
 out_dtype,
-groups,
-channels)
+channels,
+groups)
 return (ref_func, qnn_func)
 def verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape,
@@ -418,7 +430,7 @@ def test_layout():
 verify(ref_func, qnn_func, data_shape, data_dtype,
 kernel_shape, kernel_dtype)
-# NHWC and HWIO layout. Used in depthwise conv.
+# NHWC and HWOI layout. Used in depthwise conv.
 data_shape = (2, 2, 4, 1) # NHWC
 data_dtype = 'uint8'
 kernel_shape = (2, 2, 1, 1) # HWOI
@@ -568,6 +580,7 @@ def test_const_folding():
 data_layout="NCHW",
 kernel_layout="OIHW",
 out_dtype="int32",
+channels=kernel_shape[0],
 groups=1)
 folded_mod = transform.FoldConstant()(qnn_func)
 folded_func = folded_mod["main"]
@@ -787,8 +800,8 @@ def test_depthwise_depth_multiplier():
 data_layout="NCHW",
 kernel_layout="OIHW",
 out_dtype="int32",
-groups=4,
-channels=4)
+groups=4)
 verify(ref_func, qnn_func, data_shape, data_dtype,
 kernel_shape, kernel_dtype)
@@ -813,8 +826,7 @@ def test_depthwise_depth_multiplier():
 data_layout="NCHW",
 kernel_layout="OIHW",
 out_dtype="int32",
-groups=8,
-channels=8)
+groups=8)
 verify(ref_func, qnn_func, data_shape, data_dtype,
 kernel_shape, kernel_dtype)
@@ -839,8 +851,7 @@ def test_depthwise_depth_multiplier():
 data_layout="NHWC",
 kernel_layout="HWOI",
 out_dtype="int32",
-groups=4,
-channels=4)
+groups=4)
 verify(ref_func, qnn_func, data_shape, data_dtype,
 kernel_shape, kernel_dtype)
@@ -864,8 +875,7 @@ def test_depthwise_depth_multiplier():
 data_layout="NHWC",
 kernel_layout="HWOI",
 out_dtype="int32",
-groups=8,
-channels=8)
+groups=8)
 verify(ref_func, qnn_func, data_shape, data_dtype,
 kernel_shape, kernel_dtype)
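
In each depthwise hunk above the explicit channels argument is dropped because get_funcs now sets channels = groups whenever groups > 1. A hedged sketch of the corresponding operator-level call for a depth multiplier of 1, where the output-channel count equals groups; all shapes and values are illustrative:

from tvm import relay

data = relay.var("data", shape=(1, 4, 16, 16), dtype="uint8")    # NCHW
kernel = relay.var("kernel", shape=(4, 1, 3, 3), dtype="uint8")   # OIHW, depthwise

out = relay.qnn.op.conv2d(
    data, kernel,
    input_zero_point=relay.const(0, "int32"),
    kernel_zero_point=relay.const(0, "int32"),
    input_scale=relay.const(1.0, "float32"),
    kernel_scale=relay.const(1.0, "float32"),
    kernel_size=(3, 3),
    channels=4,   # equals groups for a depth multiplier of 1
    groups=4,
    data_layout="NCHW",
    kernel_layout="OIHW",
    out_dtype="int32")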
@@ -888,6 +898,7 @@ def test_per_channel_kernel_scale():
 input_scale=relay.const(2.0, 'float32'),
 kernel_scale=kernel_scales,
 kernel_size=(2, 2),
+channels=kernel_shape[0],
 padding=(0, 0),
 strides=(1, 1),
 dilation=(1, 1),
......
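
The per-channel test passes a vector of kernel scales, one per output channel, which is why the call must also state channels explicitly (kernel_shape[0] for an OIHW kernel). A minimal sketch of building such a per-channel scale constant; the kernel shape and scale values are made up and not taken from the test:

import numpy as np
from tvm import relay

kernel_shape = (16, 3, 2, 2)  # OIHW: kernel_shape[0] output channels
# One float32 scale per output channel; passed as kernel_scale=kernel_scales
# together with channels=kernel_shape[0], mirroring the hunk above.
kernel_scales = relay.const(
    np.random.uniform(0.1, 1.0, size=(kernel_shape[0],)).astype("float32"))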
@@ -107,6 +107,7 @@ def test_qnn_legalize_qnn_conv2d():
 input_scale=relay.const(1.0, 'float32'),
 kernel_scale=relay.const(1.0, 'float32'),
 kernel_size=(3, 3),
+channels=kernel_shape[0],
 strides=(1, 1),
 dilation=(1, 1),
 out_dtype='int32',
......