Commit 3ba9dd09 by Ramana Radhakrishnan, committed by Zhi

Retain qnn input kernel scales (#4292)

* Add qnn conv2d attributes for input_tensor_scale and
kernel_tensor_scale.

The lowering in the tflite frontend loses the input_tensor_scale
and the kernel_tensor_scale by multiplying them together and folding
the product into the Requantize operation. This means that graph
partitioning passes, or any other passes that need this information,
no longer have it available in the qnn dialect. (A short sketch of the
folding, and of how the retained attributes can be read back, follows
the commit message.)

regards
Ramana

* Store input tensor scale and Weight tensor scale for Dense as well

As with conv2d, the tflite frontend drops the input tensor
scale and the weight tensor scale from the relay op. Store
them as separate fields on the op.

* Fix unintentional tab

* Rename input_tensor_scale to input_scale and kernel_tensor_scale
to kernel_scale for conv2d.

* input_tensor_scale -> input_scale, weight_tensor_scale -> weight_scale

* Rework dense testcase

And use input_scale and kernel_scale

* Be consistent in use of input_scale and kernel_scale values

* Fixup qnn conv2d tests for input_scale and kernel_scale

* Make pydoc identical between conv2d and dense for weight_tensor

* Fix up conv2d parameters to be in the same order between C++ and python

* Fix ordering of parameters for dense.

* Add input_scale and output_scale to try and satisfy ci gods

* Delete input_scale and kernel_scale.

nn.conv2d does not have input_scale and kernel_scale attributes, so we
need to delete them when lowering qnn.conv2d to nn.conv2d.

* Add input_scale and kernel_scale for qnn.conv2d
parent 560280dd
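
A minimal sketch of the problem and the fix (not part of the patch; the scale
values are the ones used in the updated dense tests below). Before this change
only the folded product of the two scales survived lowering, inside the
Requantize operation; with the new attributes a pass can read the individual
scales straight off the qnn.conv2d or qnn.dense call node.

# Per-tensor scales, as in the dense test configurations in this commit.
input_scale = 0.5     # scale of the quantized input tensor
kernel_scale = 0.5    # scale of the quantized weight tensor

# All that the Requantize op ever carried was the product of the two:
requantize_input_scale = input_scale * kernel_scale   # 0.25

# With this patch the individual scales are retained as attributes on the
# QNN op itself, so a graph partitioning pass (for example) can recover
# them from the call node; attribute names as added by this commit:
#   call.attrs.input_scale   -> 0.5
#   call.attrs.kernel_scale  -> 0.5
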
......@@ -135,6 +135,10 @@ struct QnnConv2DAttrs : public tvm::AttrsNode<QnnConv2DAttrs> {
// Quantization related attributes.
int32_t input_zero_point;
int32_t kernel_zero_point;
// The input tensor scale and kernel tensor scale are stored
// here for easy access by passes that need this information.
double input_scale;
double kernel_scale;
TVM_DECLARE_ATTRS(QnnConv2DAttrs, "relay.attrs.QnnConv2DAttrs") {
TVM_ATTR_FIELD(strides).set_default(Array<IndexExpr>({1, 1}))
......@@ -177,6 +181,10 @@ struct QnnConv2DAttrs : public tvm::AttrsNode<QnnConv2DAttrs> {
.describe("The zero point of the input tensor.");
TVM_ATTR_FIELD(kernel_zero_point)
.describe("The zero point of the kernel tensor.");
TVM_ATTR_FIELD(input_scale)
.describe("The quantization scale for the input tensor.");
TVM_ATTR_FIELD(kernel_scale)
.describe("The quantization scale for the weight tensor.");
}
};
......@@ -212,6 +220,8 @@ struct QnnDenseAttrs : public tvm::AttrsNode<QnnDenseAttrs> {
// Quantization related attributes.
int32_t input_zero_point;
int32_t kernel_zero_point;
double input_scale;
double kernel_scale;
TVM_DECLARE_ATTRS(QnnDenseAttrs, "relay.attrs.QnnDenseAttrs") {
TVM_ATTR_FIELD(units)
......@@ -222,6 +232,10 @@ struct QnnDenseAttrs : public tvm::AttrsNode<QnnDenseAttrs> {
.describe("The zero point of the input tensor.");
TVM_ATTR_FIELD(kernel_zero_point)
.describe("The zero point of the kernel tensor.");
TVM_ATTR_FIELD(input_scale)
.describe("The input tensor scale.");
TVM_ATTR_FIELD(kernel_scale)
.describe("The kernel tensor scale.");
}
};
......
......@@ -729,9 +729,13 @@ class OperatorConverter(object):
weight_expr = self.exp_tab.new_const(weight_value, dtype=weight_tensor_type_str)
if input_tensor.qnn_params:
input_scale = input_tensor.qnn_params['scale']
kernel_scale = weight_tensor.qnn_params['scale']
out = _qnn.op.dense(in_expr, weight_expr,
input_zero_point=input_tensor.qnn_params['zero_point'],
kernel_zero_point=weight_tensor.qnn_params['zero_point'],
input_scale=input_scale,
kernel_scale=kernel_scale,
out_dtype='int32')
else:
out = _op.nn.dense(in_expr, weight_expr)
......@@ -935,6 +939,8 @@ class OperatorConverter(object):
qnn_conv2d_params['input_zero_point'] = input_tensor.qnn_params['zero_point']
qnn_conv2d_params['kernel_zero_point'] = weight_tensor.qnn_params['zero_point']
qnn_conv2d_params['out_dtype'] = 'int32'
qnn_conv2d_params['input_scale'] = input_tensor.qnn_params['scale']
qnn_conv2d_params['kernel_scale'] = weight_tensor.qnn_params['scale']
out = _qnn.op.conv2d(in_expr, weight_expr, **qnn_conv2d_params)
else:
out = _op.nn.conv2d(in_expr, weight_expr, **params)
......
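
For context, a hedged sketch of the per-tensor quantization record the
converter reads in the hunks above. The dictionary keys are the ones used in
the code; the container type and the concrete values are assumptions for
illustration only.

# Hypothetical example of a tensor's qnn_params as consumed above.
qnn_params = {
    'scale': 0.5,        # per-tensor scale taken from the TFLite model
    'zero_point': 127,   # per-tensor zero point
}
input_scale = qnn_params['scale']            # forwarded to qnn.conv2d / qnn.dense
input_zero_point = qnn_params['zero_point']
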
......@@ -88,6 +88,8 @@ def helper_no_fast_int8_hw_legalization(attrs, inputs, types, relay_op):
new_attrs = {k : attrs[k] for k in attrs.keys()}
del new_attrs['kernel_zero_point']
del new_attrs['input_zero_point']
del new_attrs['input_scale']
del new_attrs['kernel_scale']
return relay_op(shift_data, shift_kernel, **new_attrs)
# Helper function to change dtypes to uint8 x int8. Intel VNNI instructions prefer this setting.
......
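
The two new del statements above (for input_scale and kernel_scale) implement
the "Delete input_scale and kernel_scale" step from the commit message: the
plain nn ops have no QNN attributes, so forwarding them as keyword arguments
would fail. A small, self-contained sketch of the same idea; the helper name
is hypothetical, not from the patch.

def strip_qnn_attrs(attrs):
    """Drop the QNN-only attributes before calling the plain nn op."""
    # nn.conv2d / nn.dense accept none of these keyword arguments, so
    # passing them through would raise a TypeError.
    qnn_only = ('input_zero_point', 'kernel_zero_point',
                'input_scale', 'kernel_scale')
    return {k: attrs[k] for k in attrs.keys() if k not in qnn_only}

# Usage mirroring the helper above:
#   return relay_op(shift_data, shift_kernel, **strip_qnn_attrs(attrs))
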
......@@ -189,6 +189,8 @@ def conv2d(data,
kernel,
input_zero_point,
kernel_zero_point,
input_scale,
kernel_scale,
strides=(1, 1),
padding=(0, 0),
dilation=(1, 1),
......@@ -219,6 +221,16 @@ def conv2d(data,
input_zero_point: int
The zero point of the data distribution.
input_scale: float
The quantization scale for the input tensor. It is stored here purely
for convenience; see the note under kernel_scale below.
kernel_scale: float
The quantization scale for the weight tensor. It is stored here so that
Relay passes (graph partitioning, for example) can access it; the pass
pipeline does not need it once the QNN op has been lowered to the
equivalent sequence of nn ops. See also input_scale in Requantize.
kernel_zero_point: int
The zero point of the quantized_kernel distribution.
......@@ -260,6 +272,7 @@ def conv2d(data,
return _make.conv2d(data, kernel,
input_zero_point, kernel_zero_point,
input_scale, kernel_scale,
strides, padding, dilation,
groups, channels, kernel_size,
data_layout, kernel_layout, out_layout, out_dtype)
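
A hedged usage sketch of the updated conv2d wrapper, mirroring the call in the
legalization test further down; the tensor shapes and dtypes are illustrative
assumptions, not taken from the patch.

from tvm import relay

data = relay.var("data", shape=(1, 64, 56, 56), dtype="uint8")     # assumed shape
kernel = relay.var("kernel", shape=(64, 64, 3, 3), dtype="uint8")  # assumed shape
conv = relay.qnn.op.conv2d(
    data, kernel,
    input_zero_point=1,
    kernel_zero_point=1,
    input_scale=1.0,      # retained on the op by this patch
    kernel_scale=1.0,     # retained on the op by this patch
    kernel_size=(3, 3),
    strides=(1, 1),
    dilation=(1, 1),
    out_dtype="int32")
# conv.attrs.input_scale / conv.attrs.kernel_scale are now available to
# later passes instead of only the folded requantization scale.
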
......@@ -317,6 +330,8 @@ def dense(data,
weight,
input_zero_point,
kernel_zero_point,
input_scale,
kernel_scale,
units=None,
out_dtype="int32"):
"""Qnn Dense operator.
......@@ -332,6 +347,17 @@ def dense(data,
The quantized input data to the operator.
weight : tvm.relay.Expr
The quantized weight expressions.
input_zero_point: int
The input zero point.
kernel_zero_point: int
The kernel zero point.
input_scale: float
The quantization scale for the input tensor.
kernel_scale: float
The quantization scale for the weight tensor. It is stored here so that
Relay passes (graph partitioning, for example) can access it; the pass
pipeline does not need it once the QNN op has been lowered to the
equivalent sequence of nn ops. See also input_scale in Requantize.
units : int, optional
Number of hidden units of the dense transformation.
out_dtype : str, optional
......@@ -345,9 +371,11 @@ def dense(data,
return _make.dense(data,
weight,
input_zero_point,
kernel_zero_point,
input_scale,
kernel_scale,
units,
out_dtype)
......
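
And a corresponding hedged sketch for the updated dense wrapper, using the
per-tensor values from the uint8 test configuration below; the variable names
and shapes are illustrative.

from tvm import relay

data = relay.var("data", shape=(2, 10), dtype="uint8")
weight = relay.var("weight", shape=(3, 10), dtype="uint8")
dense = relay.qnn.op.dense(
    data, weight,
    input_zero_point=127,
    kernel_zero_point=127,
    input_scale=0.5,      # retained on the op by this patch
    kernel_scale=0.5,     # retained on the op by this patch
    units=3,
    out_dtype="int32")
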
......@@ -440,7 +440,8 @@ Expr QnnConv2DCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
// Positional relay function to create quantized conv2d operator
// used by frontend FFI.
Expr MakeQnnConv2D(Expr data, Expr weight, int32_t input_zero_point, int32_t kernel_zero_point,
double input_scale, double kernel_scale, Array<IndexExpr> strides,
Array<IndexExpr> padding, Array<IndexExpr> dilation,
int groups, IndexExpr channels, Array<IndexExpr> kernel_size,
std::string data_layout, std::string kernel_layout, std::string out_layout,
DataType out_dtype) {
......@@ -457,6 +458,8 @@ Expr MakeQnnConv2D(Expr data, Expr weight, int32_t input_zero_point, int32_t ker
attrs->out_dtype = std::move(out_dtype);
attrs->input_zero_point = std::move(input_zero_point);
attrs->kernel_zero_point = std::move(kernel_zero_point);
attrs->input_scale = std::move(input_scale);
attrs->kernel_scale = std::move(kernel_scale);
static const Op& op = Op::Get("qnn.conv2d");
return CallNode::make(op, {data, weight}, Attrs(attrs), {});
}
......
......@@ -57,13 +57,17 @@ bool QnnDenseRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
}
// Positional relay function to create quantized dense operator used by frontend FFI.
Expr MakeQuantizedDense(Expr data, Expr weight, int32_t input_zero_point,
int32_t kernel_zero_point, double input_scale,
double kernel_scale, IndexExpr units,
DataType out_dtype) {
auto attrs = make_node<QnnDenseAttrs>();
attrs->units = std::move(units);
attrs->out_dtype = out_dtype;
attrs->input_zero_point = input_zero_point;
attrs->kernel_zero_point = kernel_zero_point;
attrs->input_scale = input_scale;
attrs->kernel_scale = kernel_scale;
static const Op& op = Op::Get("qnn.dense");
return CallNode::make(op, {data, weight}, Attrs(attrs), {});
}
......
......@@ -31,13 +31,15 @@ def make_requantize_params(input_scale, output_scale, output_zero_point, out_dty
return config
def make_configuration(quantized_data,
quantized_kernel,
dtype,
input_shape,
kernel_shape,
input_zero_point,
kernel_zero_point,
input_scale,
kernel_scale,
units,
output,
out_dtype='int32',
......@@ -53,6 +55,8 @@ def make_configuration(quantized_data,
'kernel_shape': kernel_shape,
'input_zero_point': input_zero_point,
'kernel_zero_point': kernel_zero_point,
'input_scale': input_scale,
'kernel_scale': kernel_scale,
'units': units,
'output': output,
'out_dtype': out_dtype,
......@@ -65,6 +69,9 @@ def make_configuration(quantized_data,
def make_uint_configuration(use_bias=False, requantize_output=False):
input_shape, kernel_shape, output_shape = (2, 10), (3,10), (2, 3)
input_zero_point, kernel_zero_point = 127, 127
input_scale = 0.5
kernel_scale = 0.5
output_scale = 1.0
in_dtype = 'uint8'
out_dtype = 'int32' if not requantize_output else 'uint8'
units = 3
......@@ -78,7 +85,7 @@ def make_uint_configuration(use_bias=False, requantize_output=False):
.astype(in_dtype) \
.reshape(kernel_shape)
bias = np.array([4, 8, 12]).astype(out_dtype).reshape((units, )) if use_bias else None
requant_params = make_requantize_params(input_scale * kernel_scale, output_scale, 127, 'uint8') if requantize_output else None
if requantize_output:
assert use_bias
......@@ -95,6 +102,8 @@ def make_uint_configuration(use_bias=False, requantize_output=False):
kernel_shape=kernel_shape,
input_zero_point=input_zero_point,
kernel_zero_point=kernel_zero_point,
input_scale=input_scale,
kernel_scale=kernel_scale,
units=units,
output=output,
bias=bias,
......@@ -116,8 +125,11 @@ def make_int_configuration(use_bias=False, requantize_output=False):
1, 3, 5, 7, 9, 11, 13, 15, 17, 19]) \
.astype(in_dtype) \
.reshape(kernel_shape)
input_scale = 0.5
kernel_scale = 0.5
output_scale = 1.0
bias = np.array([4, 8, 12]).astype(out_dtype).reshape((units, )) if use_bias else None
requant_params = make_requantize_params(input_scale * kernel_scale, output_scale, -1, 'int8') if requantize_output else None
if requantize_output:
assert use_bias
......@@ -134,6 +146,8 @@ def make_int_configuration(use_bias=False, requantize_output=False):
kernel_shape=kernel_shape,
input_zero_point=input_zero_point,
kernel_zero_point=kernel_zero_point,
input_scale=input_scale,
kernel_scale=kernel_scale,
units=units,
output=output,
bias=bias,
......@@ -158,6 +172,8 @@ def qnn_dense_driver(test_configuration):
quantized_kernel,
test_configuration['input_zero_point'],
test_configuration['kernel_zero_point'],
test_configuration['input_scale'],
test_configuration['kernel_scale'],
test_configuration['units'])
if test_configuration[bias_name] is not None:
bias = relay.var(bias_name,
......
......@@ -103,6 +103,8 @@ def test_qnn_legalize_qnn_conv2d():
data, kernel,
input_zero_point=1,
kernel_zero_point=1,
input_scale=1.0,
kernel_scale=1.0,
kernel_size=(3, 3),
strides=(1, 1),
dilation=(1, 1),
......@@ -185,6 +187,8 @@ def test_qnn_legalize_qnn_dense():
data, kernel,
input_zero_point=1,
kernel_zero_point=1,
input_scale=1,
kernel_scale=1,
out_dtype='int32')
mod = relay.Function(relay.analysis.free_vars(func), func)
......