Unverified Commit a5e54b1d by masahi Committed by GitHub

[QNN] Add support for per channel weight scale in dense op (#4880)

* add test case for per channel dense

* add unit arg in tflite frontend

* update qnn legalize test

* fix output dim index
parent 24c53a34
......@@ -982,6 +982,7 @@ class OperatorConverter(object):
weight_value = self.get_tensor_value(weight_tensor)
weight_expr = self.exp_tab.new_const(weight_value, dtype=weight_tensor_type_str)
weight_shape = _infer_shape(weight_expr)
if input_tensor.qnn_params:
out = _qnn.op.dense(in_expr, weight_expr,
......@@ -989,6 +990,7 @@ class OperatorConverter(object):
kernel_zero_point=weight_tensor.qnn_params['zero_point'],
input_scale=input_tensor.qnn_params['scale'],
kernel_scale=weight_tensor.qnn_params['scale'],
units=weight_shape[0],
out_dtype='int32')
else:
out = _op.nn.dense(in_expr, weight_expr)
......
......@@ -345,7 +345,7 @@ def dense(data,
kernel_zero_point,
input_scale,
kernel_scale,
units=None,
units,
out_dtype="int32"):
"""Qnn Dense operator.
Applies a quantized linear transformation
......@@ -371,7 +371,7 @@ def dense(data,
stored for access to this during relay. This information is not
needed in the pass pipeline after qnn.conv2d is lowered to the
sequence of steps as in nn.conv2d. See also input_scale in Requantize.
units : int, optional
units : int
Number of hidden units of the dense transformation.
out_dtype : str, optional
Specifies the output data type for mixed precision dense can be int32 or int16.
......
......@@ -55,7 +55,7 @@ bool QnnDenseRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
CHECK(IsScalarType(types[2], DataType::Int(32))); // input_zero_point
CHECK(IsScalarType(types[3], DataType::Int(32))); // kernel_zero_point
CHECK(IsScalarType(types[4], DataType::Float(32))); // input_scale
CHECK(IsScalarType(types[5], DataType::Float(32))); // kernel_scale
AssignType(types[5], DataType::Float(32), param->units, reporter);
CHECK(param->out_dtype.bits() > 0) << "Output dtype bits should be greater than 0.";
......
......@@ -75,52 +75,8 @@ def make_configuration(quantized_data,
return config
def make_uint_configuration(use_bias=False, requantize_output=False):
input_shape, kernel_shape, output_shape = (2, 10), (3,10), (2, 3)
input_zero_point, kernel_zero_point = 127, 127
input_scale = 0.5
kernel_scale = 0.5
output_scale = 1.0
in_dtype = 'uint8'
out_dtype = 'int32' if not requantize_output else 'uint8'
units = 3
quantized_data_np = np.array([129, 131, 133, 135, 137, 139, 141, 143, 109, 107,
129, 131, 133, 135, 137, 139, 141, 111, 145, 107]) \
.astype(in_dtype) \
.reshape(input_shape)
quantized_kernel_np = np.array([129, 131, 133, 135, 137, 139, 141, 143, 145, 147,
129, 131, 133, 135, 137, 139, 141, 143, 145, 147,
129, 131, 133, 135, 137, 139, 141, 143, 145, 147]) \
.astype(in_dtype) \
.reshape(kernel_shape)
bias = np.array([4, 8, 12]).astype(out_dtype).reshape((units, )) if use_bias else None
requant_params = make_requantize_params(input_scale * kernel_scale, output_scale, 127, 'uint8') if requantize_output else None
if requantize_output:
assert use_bias
output = np.array([151, 152, 153, 185, 186, 187])
elif use_bias:
output = np.array([96, 100, 104, 232, 236, 240 ])
else:
output = np.array([92, 92, 92, 228, 228, 228 ])
output = output.astype(out_dtype).reshape(output_shape)
return make_configuration(quantized_data=quantized_data_np,
quantized_kernel=quantized_kernel_np,
dtype=in_dtype,
input_shape=input_shape,
kernel_shape=kernel_shape,
input_zero_point=input_zero_point,
kernel_zero_point=kernel_zero_point,
input_scale=input_scale,
kernel_scale= kernel_scale,
units=units,
output=output,
bias=bias,
requantize=requant_params)
def make_int_configuration(use_bias=False, requantize_output=False):
input_shape, kernel_shape, output_shape = (2, 10), (3,10), (2, 3)
def make_int_configuration(use_bias=False, requantize_output=False, per_channel=False):
input_shape, kernel_shape, output_shape = (2, 10), (3, 10), (2, 3)
input_zero_point, kernel_zero_point = -1, -1
in_dtype = 'int8'
out_dtype = 'int32' if not requantize_output else 'int8'
......@@ -138,15 +94,22 @@ def make_int_configuration(use_bias=False, requantize_output=False):
kernel_scale = 0.5
output_scale = 1.0
bias = np.array([4, 8, 12]).astype(out_dtype).reshape((units, )) if use_bias else None
requant_params = make_requantize_params(input_scale * kernel_scale, output_scale, -1, 'int8') if requantize_output else None
if requantize_output:
if per_channel:
assert use_bias and requantize_output
kernel_scale = np.array([0.5, 0.3, 0.4], dtype=np.float32)
output = np.array([23, 14, 20, 57, 34, 47])
elif requantize_output:
assert use_bias
output = np.array([23, 24, 25, 57, 58, 59])
elif use_bias:
output = np.array([96, 100, 104, 232, 236, 240 ])
output = np.array([96, 100, 104, 232, 236, 240])
else:
output = np.array([92, 92, 92, 228, 228, 228 ])
output = np.array([92, 92, 92, 228, 228, 228])
requant_params = make_requantize_params(input_scale * kernel_scale,
output_scale, -1, 'int8') if requantize_output else None
output = output.astype(out_dtype).reshape(output_shape)
return make_configuration(quantized_data=quantized_data_np,
quantized_kernel=quantized_kernel_np,
......@@ -206,8 +169,8 @@ def qnn_dense_driver(test_configuration):
with relay.build_config(opt_level=2):
graph, lib, params = relay.build(mod, "llvm", params=None)
mod = graph_runtime.create(graph, lib, ctx=tvm.cpu(0))
mod.set_input(quantized_data_name,test_configuration[quantized_data_name])
mod.set_input(quantized_kernel_name,test_configuration[quantized_kernel_name])
mod.set_input(quantized_data_name, test_configuration[quantized_data_name])
mod.set_input(quantized_kernel_name, test_configuration[quantized_kernel_name])
if test_configuration[bias_name] is not None:
mod.set_input(bias_name, test_configuration[bias_name])
mod.set_input(**params)
......@@ -241,7 +204,15 @@ def test_qnn_dense_with_requantized_output():
qnn_dense_driver(int8_requantized_output_with_bias_params)
def test_per_channel_weight_scale():
with TempOpAttr("qnn.dense", "FTVMQnnLegalize", legalize_qnn_dense):
config = make_int_configuration(use_bias=True, requantize_output=True,
per_channel=True)
qnn_dense_driver(config)
if __name__ == "__main__":
test_qnn_dense_without_bias()
test_qnn_dense_with_bias()
test_qnn_dense_with_requantized_output()
test_per_channel_weight_scale()
......@@ -191,6 +191,7 @@ def test_qnn_legalize_qnn_dense():
kernel_zero_point=relay.const(1, 'int32'),
input_scale=relay.const(1, 'float32'),
kernel_scale=relay.const(1, 'float32'),
units=kernel_shape[0],
out_dtype='int32')
mod = relay.Function(relay.analysis.free_vars(func), func)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment