Commit 0720ed67 by Animesh Jain, committed by Wuwei Lin

[QNN] Making scale/zero_points as expr instead of attrs. (#4611)

parent 2440c9ce
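In short, quantization parameters (scales and zero points) move out of operator attributes and become ordinary Relay expressions passed as call arguments. A minimal before/after sketch of the user-facing Python API, assuming a TVM build that includes this commit (values are illustrative):

    from tvm import relay

    data = relay.var("data", shape=(2, 2), dtype="uint8")

    # Before: scale/zero_point were Python scalars baked into the op's attrs.
    #   out = relay.qnn.op.dequantize(data, input_scale=0.5, input_zero_point=127)

    # After: they are Relay expressions (constant scalars here).
    out = relay.qnn.op.dequantize(data,
                                  input_scale=relay.const(0.5, 'float32'),
                                  input_zero_point=relay.const(127, 'int32'))

This leaves room for non-constant quantization parameters later; where the passes below still need scalar values, they recover them with get_scalar_from_constant / GetScalarFromConstant.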
@@ -33,22 +33,10 @@ namespace qnn {
 /*! \brief Attribute for requantize operator */
 struct RequantizeAttrs : public tvm::AttrsNode<RequantizeAttrs> {
-  double input_scale;
-  int32_t input_zero_point;
-  double output_scale;
-  int32_t output_zero_point;
   std::string rounding;
   DataType out_dtype;

   TVM_DECLARE_ATTRS(RequantizeAttrs, "relay.attrs.RequantizeAttrs") {
-    TVM_ATTR_FIELD(input_scale)
-        .describe("The scale of the input tensor.");
-    TVM_ATTR_FIELD(input_zero_point)
-        .describe("The zero point of the input tensor.");
-    TVM_ATTR_FIELD(output_scale)
-        .describe("The scale of the output tensor.");
-    TVM_ATTR_FIELD(output_zero_point)
-        .describe("The zero point of the output tensor.");
     TVM_ATTR_FIELD(rounding).set_default("UPWARD")
         .describe("Defines the rounding direction when the value is midway between"
                   "two representable values. There are two supported modes - UPWARD"
@@ -67,175 +55,11 @@ struct RequantizeAttrs : public tvm::AttrsNode<RequantizeAttrs> {
 /*! \brief Attribute for quantize operator */
 struct QuantizeAttrs : public tvm::AttrsNode<QuantizeAttrs> {
-  int32_t output_zero_point;
-  double output_scale;
   DataType out_dtype;

   TVM_DECLARE_ATTRS(QuantizeAttrs, "relay.attrs.QuantizeAttrs") {
     TVM_ATTR_FIELD(out_dtype)
         .describe("Output data type, can be one of [int8 or uint8].");
-    TVM_ATTR_FIELD(output_zero_point)
-        .describe("The zero_point for the activation of this op.");
-    TVM_ATTR_FIELD(output_scale)
-        .describe("The scale for the activation of this op.");
-  }
-};
-
-/*! \brief Attribute for dequantize operator */
-struct DequantizeAttrs : public tvm::AttrsNode<DequantizeAttrs> {
-  int32_t input_zero_point;
-  double input_scale;
-
-  TVM_DECLARE_ATTRS(DequantizeAttrs, "relay.attrs.DequantizeAttrs") {
-    TVM_ATTR_FIELD(input_zero_point)
-        .describe("The zero_point for the input tensor of this op.");
-    TVM_ATTR_FIELD(input_scale)
-        .describe("The scale for the input tensor of this op.");
-  }
-};
-
-/*! \brief Attributes used in QNN concatenate operator */
-struct QnnConcatenateAttrs : public tvm::AttrsNode<QnnConcatenateAttrs> {
-  Array<tvm::Expr> input_scales;
-  Array<tvm::Expr> input_zero_points;
-  double output_scale;
-  int32_t output_zero_point;
-  int axis;
-
-  TVM_DECLARE_ATTRS(QnnConcatenateAttrs, "relay.attrs.QnnConcatenateAttrs") {
-    TVM_ATTR_FIELD(input_scales)
-        .describe("The list of scales of input quantized tensors.");
-    TVM_ATTR_FIELD(input_zero_points)
-        .describe("The list of zero points of input quantized tensors.");
-    TVM_ATTR_FIELD(output_zero_point)
-        .describe("The zero_point for the output tensor.");
-    TVM_ATTR_FIELD(output_scale)
-        .describe("The scale for the output tensor.");
-    TVM_ATTR_FIELD(axis)
-        .describe("The axis at which the input arrays are concatenated."
-                  "Should lie in range `[-ndim, ndim)`.")
-        .set_default(0);
-  }
-};  // struct QnnConcatenateAttrs
-
-/*! \brief Attribute for QNN Conv2d operator */
-struct QnnConv2DAttrs : public tvm::AttrsNode<QnnConv2DAttrs> {
-  // Traditional conv2d attributes.
-  Array<IndexExpr> strides;
-  Array<IndexExpr> padding;
-  Array<IndexExpr> dilation;
-  int groups;
-  IndexExpr channels;
-  Array<IndexExpr> kernel_size;
-  std::string data_layout;
-  std::string kernel_layout;
-  std::string out_layout;
-  DataType out_dtype;
-
-  // Quantization related attributes.
-  int32_t input_zero_point;
-  int32_t kernel_zero_point;
-  // The input tensor scale and kernel tensor scales are stored
-  // for easy access to this information.
-  double input_scale;
-  double kernel_scale;
-
-  TVM_DECLARE_ATTRS(QnnConv2DAttrs, "relay.attrs.QnnConv2DAttrs") {
-    TVM_ATTR_FIELD(strides).set_default(Array<IndexExpr>({1, 1}))
-        .describe("Specifies the strides of the convolution.");
-    TVM_ATTR_FIELD(padding).set_default(Array<IndexExpr>({0, 0}))
-        .describe("If padding is non-zero, then the input is implicitly zero-padded"
-                  "on both sides for padding number of points");
-    TVM_ATTR_FIELD(dilation).set_default(Array<IndexExpr>({1, 1}))
-        .describe("Specifies the dilation rate to use for dilated convolution.");
-    TVM_ATTR_FIELD(groups).set_default(1)
-        .describe("Controls the connections between inputs and outputs."
-                  "At groups=1, all inputs are convolved to all outputs."
-                  "At groups=2, the operation becomes equivalent to having two convolution"
-                  "layers side by side, each seeing half the input channels, and producing"
-                  "half the output channels, and both subsequently concatenated.");
-    TVM_ATTR_FIELD(channels)
-        .describe("The number of output channels in the convolution."
-                  " If it is not set, inferred by shape of the weight.")
-        .set_default(NullValue<IndexExpr>());
-    TVM_ATTR_FIELD(kernel_size)
-        .describe("Specifies the dimensions of the convolution window.")
-        .set_default(NullValue<Array<IndexExpr> >());
-    TVM_ATTR_FIELD(data_layout).set_default("NCHW")
-        .describe("Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc."
-                  "'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
-                  "dimensions respectively. Convolution is applied on the 'H' and"
-                  "'W' dimensions.");
-    TVM_ATTR_FIELD(kernel_layout).set_default("OIHW")
-        .describe("Dimension ordering of weight. Can be 'OIHW', 'OIHW16o16i', etc."
-                  "'O', 'I', 'H', 'W' stands for num_filter, input_channel, height, and width"
-                  "dimensions respectively.");
-    TVM_ATTR_FIELD(out_layout).set_default("")
-        .describe("Dimension ordering of output. Can be 'NCHW', 'NHWC', etc."
-                  "'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
-                  "dimensions respectively. Default to be same as input layout.");
-    TVM_ATTR_FIELD(out_dtype)
-        .set_default(NullValue<DataType>())
-        .describe("Output data type, set to explicit type under mixed precision setting");
-    TVM_ATTR_FIELD(input_zero_point)
-        .describe("The zero point of the input tensor.");
-    TVM_ATTR_FIELD(kernel_zero_point)
-        .describe("The zero point of the kernel tensor.");
-    TVM_ATTR_FIELD(input_scale)
-        .describe("The quantization scale for the input tensor.");
-    TVM_ATTR_FIELD(kernel_scale)
-        .describe("The quantization scale for the weight tensor.");
-  }
-};
-
-/*! \brief Attribute for QNN binary operator */
-struct QnnBinaryOpAttrs : public tvm::AttrsNode<QnnBinaryOpAttrs> {
-  int32_t lhs_zero_point;
-  double lhs_scale;
-  int32_t rhs_zero_point;
-  double rhs_scale;
-  int32_t output_zero_point;
-  double output_scale;
-
-  TVM_DECLARE_ATTRS(QnnBinaryOpAttrs, "relay.attrs.QnnBinaryOpAttrs") {
-    TVM_ATTR_FIELD(lhs_zero_point)
-        .describe("The zero_point for the lhs input tensor of this op.");
-    TVM_ATTR_FIELD(lhs_scale)
-        .describe("The scale for the lhs input tensor of this op.");
-    TVM_ATTR_FIELD(rhs_zero_point)
-        .describe("The zero_point for the rhs input tensor of this op.");
-    TVM_ATTR_FIELD(rhs_scale)
-        .describe("The scale for the rhs input tensor of this op.");
-    TVM_ATTR_FIELD(output_zero_point)
-        .describe("The zero_point for the activation of this op.");
-    TVM_ATTR_FIELD(output_scale)
-        .describe("The scale for the activation of this op.");
-  }
-};
-
-/*! \brief Attributes for qnn dense operator */
-struct QnnDenseAttrs : public tvm::AttrsNode<QnnDenseAttrs> {
-  IndexExpr units;
-  DataType out_dtype;
-  // Quantization related attributes.
-  int32_t input_zero_point;
-  int32_t kernel_zero_point;
-  double input_scale;
-  double kernel_scale;
-
-  TVM_DECLARE_ATTRS(QnnDenseAttrs, "relay.attrs.QnnDenseAttrs") {
-    TVM_ATTR_FIELD(units)
-        .describe("Number of hidden units of the dense transformation.");
-    TVM_ATTR_FIELD(out_dtype)
-        .describe("Output data type, set to explicit type under mixed precision setting");
-    TVM_ATTR_FIELD(input_zero_point)
-        .describe("The zero point of the input tensor.");
-    TVM_ATTR_FIELD(kernel_zero_point)
-        .describe("The zero point of the kernel tensor.");
-    TVM_ATTR_FIELD(input_scale)
-        .describe("The input tensor scale.");
-    TVM_ATTR_FIELD(kernel_scale)
-        .describe("The kernel tensor scale.");
-  }
 };
...
@@ -20,6 +20,7 @@
 """
 import numpy as np
+from tvm import relay
 from tvm.relay.qnn.op.qnn import dequantize

 zero_centered_uint8_quantized_range = np.float32(255)
@@ -54,8 +55,8 @@ def _dequantize_zero_centered(data,
     real_range = np.max([np.abs(np.float32(data_min)),
                          np.abs(np.float32(data_max))])
-    scale = np.divide(real_range, quantized_range)
-    zero_point = 0
+    scale = relay.const(np.divide(real_range, quantized_range), 'float32')
+    zero_point = relay.const(0, 'int32')
     return dequantize(data, scale, zero_point)
@@ -186,9 +187,11 @@ def _dequantize_mxnet_min_max_uint8(data,
     max_limit = np.float64(iinfo.max)
     imin_range = np.float64(imin_range)
     imax_range = np.float64(imax_range)
-    scale = np.divide((imax_range - imin_range),
-                      (max_limit - min_limit))
-    zero_point = np.int(-1 * np.divide(imin_range, scale))
+    scale_val = np.divide((imax_range - imin_range),
+                          (max_limit - min_limit))
+    zero_point_val = np.int(-1 * np.divide(imin_range, scale_val))
+    scale = relay.const(scale_val, 'float32')
+    zero_point = relay.const(zero_point_val, 'int32')
     return dequantize(data, scale, zero_point)
...
@@ -20,11 +20,13 @@ from __future__ import absolute_import as _abs
 import math
 import numpy as np
 import tvm
+from tvm import relay
 from .. import analysis
 from .. import expr as _expr
 from .. import module as _module
 from .. import op as _op
 from .. import qnn as _qnn
+from ..util import get_scalar_from_constant
 from ... import nd as _nd
 from .common import ExprTable
 from .common import infer_shape as _infer_shape
@@ -177,8 +179,8 @@ class OperatorConverter(object):
             # Check that the scale and zero points are valid.
             if scale != 0 or zero_point != 0:
                 qnn_params = dict()
-                qnn_params['scale'] = scale
-                qnn_params['zero_point'] = zero_point
+                qnn_params['scale'] = relay.const(scale, 'float32')
+                qnn_params['zero_point'] = relay.const(zero_point, 'int32')
             return_list.append(TensorWrapper(tensor_idx, tensor, buffer, qnn_params))
         return return_list
@@ -225,8 +227,16 @@ class OperatorConverter(object):
                 .format(str(tensor_type)))

     def has_same_qnn_params(self, lhs_tensor, rhs_tensor):
-        return lhs_tensor.qnn_params['scale'] == rhs_tensor.qnn_params['scale'] and \
-                lhs_tensor.qnn_params['zero_point'] == rhs_tensor.qnn_params['zero_point']
+        lhs_scale = lhs_tensor.qnn_params['scale']
+        rhs_scale = rhs_tensor.qnn_params['scale']
+        lhs_zero_point = lhs_tensor.qnn_params['zero_point']
+        rhs_zero_point = rhs_tensor.qnn_params['zero_point']
+        lhs_scale_value = get_scalar_from_constant(lhs_scale)
+        rhs_scale_value = get_scalar_from_constant(rhs_scale)
+        lhs_zero_point_value = get_scalar_from_constant(lhs_zero_point)
+        rhs_zero_point_value = get_scalar_from_constant(rhs_zero_point)
+        return lhs_scale_value == rhs_scale_value and \
+                lhs_zero_point_value == rhs_zero_point_value

     def is_quantized(self, op):
         """Check if an input tensor is quantized."""
@@ -750,13 +760,11 @@ class OperatorConverter(object):
         weight_expr = self.exp_tab.new_const(weight_value, dtype=weight_tensor_type_str)

         if input_tensor.qnn_params:
-            input_scale = input_tensor.qnn_params['scale']
-            kernel_scale = weight_tensor.qnn_params['scale']
             out = _qnn.op.dense(in_expr, weight_expr,
                                 input_zero_point=input_tensor.qnn_params['zero_point'],
                                 kernel_zero_point=weight_tensor.qnn_params['zero_point'],
-                                input_scale=input_scale,
-                                kernel_scale=kernel_scale,
+                                input_scale=input_tensor.qnn_params['scale'],
+                                kernel_scale=weight_tensor.qnn_params['scale'],
                                 out_dtype='int32')
         else:
             out = _op.nn.dense(in_expr, weight_expr)
@@ -783,11 +791,16 @@ class OperatorConverter(object):
         # Finally if the dense is quantized. Add a requantize at the end.
         if output_tensor.qnn_params:
-            input_scale = input_tensor.qnn_params['scale'] * weight_tensor.qnn_params['scale']
-            input_zero_point = 0
+            data_scale = input_tensor.qnn_params['scale']
+            weight_scale = weight_tensor.qnn_params['scale']
+            data_scale_val = get_scalar_from_constant(data_scale)
+            weight_scale_val = get_scalar_from_constant(weight_scale)
+            new_input_scale_val = data_scale_val * weight_scale_val
+            new_input_scale = relay.const(new_input_scale_val, 'float32')
+            new_input_zero_point = relay.const(0, 'int32')
             out = _qnn.op.requantize(out,
-                                     input_scale=input_scale,
-                                     input_zero_point=input_zero_point,
+                                     input_scale=new_input_scale,
+                                     input_zero_point=new_input_zero_point,
                                      output_scale=output_tensor.qnn_params['scale'],
                                      output_zero_point=output_tensor.qnn_params['zero_point'],
                                      out_dtype=output_tensor_type_str)
@@ -989,11 +1002,16 @@ class OperatorConverter(object):
         # Finally if the conv is quantized. Add a requantize at the end.
         if output_tensor.qnn_params:
-            input_scale = input_tensor.qnn_params['scale'] * weight_tensor.qnn_params['scale']
-            input_zero_point = 0
+            data_scale = input_tensor.qnn_params['scale']
+            weight_scale = weight_tensor.qnn_params['scale']
+            data_scale_val = get_scalar_from_constant(data_scale)
+            weight_scale_val = get_scalar_from_constant(weight_scale)
+            new_input_scale_val = data_scale_val * weight_scale_val
+            new_input_scale = relay.const(new_input_scale_val, 'float32')
+            new_input_zero_point = relay.const(0, 'int32')
             out = _qnn.op.requantize(out,
-                                     input_scale=input_scale,
-                                     input_zero_point=input_zero_point,
+                                     input_scale=new_input_scale,
+                                     input_zero_point=new_input_zero_point,
                                      output_scale=output_tensor.qnn_params['scale'],
                                      output_zero_point=output_tensor.qnn_params['zero_point'],
                                      out_dtype=output_tensor_type_str)
...
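The frontend hunks above fold the product of the data and weight scales into the requantize input scale. The reason: an int32 accumulator of products of quantized values carries the scale s_data * s_weight. A quick NumPy check with made-up values (zero points taken as 0 for brevity):

    import numpy as np

    s_data, s_weight = 0.02, 0.5
    q_data = np.array([10, 20], dtype=np.int32)    # quantized data
    q_weight = np.array([3, 4], dtype=np.int32)    # quantized weights

    acc = np.dot(q_data, q_weight)                        # int32 accumulator: 110
    real = np.dot(q_data * s_data, q_weight * s_weight)   # real-valued dot: 1.1
    assert np.isclose(acc * (s_data * s_weight), real)    # accumulator scale folds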
@@ -20,4 +20,3 @@ from __future__ import absolute_import as _abs
 from .qnn import *
 from .op import register_qnn_legalize
 from . import legalizations
-from . import op_attrs
@@ -21,6 +21,7 @@ from __future__ import absolute_import
 import tvm
 from tvm import relay
 from .. import op as reg
+from ...util import get_scalar_from_constant

 #################################################
 # Register the functions for different operators.
@@ -76,20 +77,13 @@ def helper_no_fast_int8_hw_legalization(attrs, inputs, types, relay_op):
     """
     # Collect the input exprs.
-    data, kernel = inputs
-    input_zp = attrs['input_zero_point']
-    kernel_zp = attrs['kernel_zero_point']
+    data, kernel, input_zero_point, kernel_zero_point, _, _ = inputs

     shift_data = relay.subtract(relay.cast(data, dtype='int16'),
-                                relay.const(input_zp, 'int16'))
+                                relay.cast(input_zero_point, 'int16'))
     shift_kernel = relay.subtract(relay.cast(kernel, dtype='int16'),
-                                  relay.const(kernel_zp, 'int16'))
+                                  relay.cast(kernel_zero_point, 'int16'))
     new_attrs = {k : attrs[k] for k in attrs.keys()}
-    del new_attrs['kernel_zero_point']
-    del new_attrs['input_zero_point']
-    del new_attrs['input_scale']
-    del new_attrs['kernel_scale']
     return relay_op(shift_data, shift_kernel, **new_attrs)

 # Helper function to change dtypes to uint8 x int8. Intel VNNI instructions prefer this setting.
@@ -136,36 +130,36 @@ def helper_change_dtypes_to_uint8_int8(attrs, inputs, types, relay_op):
         data_modified = relay.cast(data, 'int32')
         data_modified = relay.add(data_modified, relay.const(shift, 'int32'))
         data_modified = relay.cast(data_modified, out_dtype)
-        return (data_modified, zero_point + shift)
+        zero_point_val = get_scalar_from_constant(zero_point)
+        zero_point_modified = relay.const(zero_point_val + shift, 'int32')
+        return (data_modified, zero_point_modified)

     # Collect the dtypes.
     data_dtype = types[0].dtype
     kernel_dtype = types[1].dtype

     # Collect the input exprs.
-    data, kernel = inputs
+    data, kernel, input_zero_point, kernel_zero_point, input_scale, kernel_scale = inputs

     # VNNI supports u8 x i8 fast conv/MM. Don't do anything if it is already satisfied.
     if data_dtype == 'uint8' and kernel_dtype == 'int8':
         return None

     # Shift input if necessary.
-    input_zp = attrs['input_zero_point']
     if data_dtype == 'int8':
         # Compute (QA + 128) and (zp_a + 128)
-        data, input_zp = _shift(data, input_zp, 'uint8')
+        data, input_zero_point = _shift(data, input_zero_point, 'uint8')

     # Shift kernel if necessary.
-    kernel_zp = attrs['kernel_zero_point']
     if kernel_dtype == 'uint8':
         # Compute (QA - 128) and (zp_a - 128)
-        kernel, kernel_zp = _shift(kernel, kernel_zp, 'int8')
+        kernel, kernel_zero_point = _shift(kernel, kernel_zero_point, 'int8')

     # Call qnn.conv2d with modified inputs and zero points.
     new_attrs = {k : attrs[k] for k in attrs.keys()}
-    new_attrs['input_zero_point'] = input_zp
-    new_attrs['kernel_zero_point'] = kernel_zp
-    return relay_op(data, kernel, **new_attrs)
+    return relay_op(data, kernel,
+                    input_zero_point, kernel_zero_point,
+                    input_scale, kernel_scale, **new_attrs)

 # Helper function to change dtypes to be same. ARM dotprod instructions prefer this setting.
 def helper_change_dtypes_to_be_same(attrs, inputs, types, relay_op):
@@ -199,7 +193,9 @@ def helper_change_dtypes_to_be_same(attrs, inputs, types, relay_op):
         data_modified = relay.cast(data, 'int32')
         data_modified = relay.add(data_modified, relay.const(shift, 'int32'))
         data_modified = relay.cast(data_modified, out_dtype)
-        return (data_modified, zero_point + shift)
+        zero_point_val = get_scalar_from_constant(zero_point)
+        zero_point_modified = relay.const(zero_point_val + shift, 'int32')
+        return (data_modified, zero_point_modified)

     # Collect the dtypes.
     data_dtype = types[0].dtype
@@ -209,18 +205,18 @@ def helper_change_dtypes_to_be_same(attrs, inputs, types, relay_op):
         return None

     # Collect the input exprs.
-    data, kernel = inputs
+    data, kernel, input_zero_point, kernel_zero_point, input_scale, kernel_scale = inputs

     assert 'int8' in data_dtype and 'int8' in kernel_dtype, \
         "Qnn Conv2D/Dense only accepts uint8 or int8 inputs"

     # Shift input if necessary.
-    input_zp = attrs['input_zero_point']
-    data, input_zp = _shift(data, input_zp, kernel_dtype)
+    data, input_zero_point = _shift(data, input_zero_point, kernel_dtype)

     new_attrs = {k : attrs[k] for k in attrs.keys()}
-    new_attrs['input_zero_point'] = input_zp
-    return relay_op(data, kernel, **new_attrs)
+    return relay_op(data, kernel,
+                    input_zero_point, kernel_zero_point,
+                    input_scale, kernel_scale, **new_attrs)

 def is_fast_int8_on_intel():
     """ Checks whether the hardware has support for fast Int8 arithmetic operations. """
...
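The legalization helpers above rely on a dtype-flipping trick: adding (or subtracting) 128 maps int8 onto uint8 (or back) without changing the represented real value, provided the zero point shifts by the same amount. A standalone NumPy rendering of what _shift builds as Relay expressions (values are illustrative):

    import numpy as np

    q = np.array([-128, -1, 0, 127], dtype=np.int8)
    zero_point, scale = 3, 0.1

    q_u8 = (q.astype(np.int32) + 128).astype(np.uint8)
    zero_point_u8 = zero_point + 128

    # The dequantized (real) values are unchanged by the shift.
    assert np.allclose(scale * (q.astype(np.int32) - zero_point),
                       scale * (q_u8.astype(np.int32) - zero_point_u8))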
@@ -18,7 +18,6 @@
 """QNN dialect operators."""

 from __future__ import absolute_import as _abs
-from tvm.expr import FloatImm, IntImm
 from tvm.relay.expr import Tuple
 from . import _make
@@ -42,16 +41,16 @@ def requantize(data,
     data : tvm.relay.Expr
         The input data to the operator.

-    input_scale: float
+    input_scale: tvm.relay.Expr
         The quantization scale for the input tensor.

-    input_zero_point: int
+    input_zero_point: tvm.relay.Expr
         The zero point of the input tensor.

-    output_scale: float
+    output_scale: tvm.relay.Expr
         The quantization scale for the output tensor.

-    output_zero_point: int
+    output_zero_point: tvm.relay.Expr
         The zero point of the output tensor.

     rounding : string, optional
@@ -92,9 +91,9 @@ def quantize(data,
     ----------
     data : tvm.relay.Expr
         The input tensor to be quantized. Can be of type float32.
-    output_zero_point : int
+    output_zero_point : tvm.relay.Expr
         The output zero_point.
-    output_scale : float
+    output_scale : tvm.relay.Expr
         The output scale.
     out_dtype : str, optional
         The data type of the output tensor. Can be [int8, uint8]
@@ -122,9 +121,9 @@ def dequantize(data,
     ----------
     data : tvm.relay.Expr
         The input tensor to be dequantized. Can be of type [int8, uint8].
-    input_zero_point : int
+    input_zero_point : tvm.relay.Expr
         The input zero_point.
-    input_scale : float
+    input_scale : tvm.relay.Expr
         The input scale.

     Returns
     -------
@@ -150,16 +149,16 @@ def concatenate(data,
     data : Union(List[relay.Expr], Tuple[relay.Expr])
         The list of quantized tensors.

-    input_scales : List[float32]
+    input_scales : List[relay.Expr]
         The list of scales of input quantized tensors.

-    input_zero_points : List[int32]
+    input_zero_points : List[relay.Expr]
         The list of zero points of input quantized tensors.

-    output_scale : float32
+    output_scale : relay.Expr
         The scale of the output quantized tensor.

-    output_zero_point : int32
+    output_zero_point : relay.Expr
         The zero point of the output quantized tensor.

     axis : int
@@ -176,10 +175,12 @@ def concatenate(data,
         raise ValueError("relay.concatenate requires data to be non-empty.")
     if not isinstance(axis, int):
         raise ValueError("For now, we only support integer axis")
+    input_scales = list(input_scales)
+    input_zero_points = list(input_zero_points)

     return _make.concatenate(Tuple(data),
-                             [FloatImm("float64", x) for x in input_scales],
-                             [IntImm("int32", x) for x in input_zero_points],
+                             Tuple(input_scales),
+                             Tuple(input_zero_points),
                              output_scale,
                              output_zero_point,
                              axis)
@@ -218,22 +219,22 @@ def conv2d(data,
     kernel : tvm.relay.Expr
         The kernel expressions.

-    input_zero_point: int
+    input_zero_point: tvm.relay.Expr
         The zero point of the data distribution.

-    input_scale: float
+    kernel_zero_point: tvm.relay.Expr
+        The zero point of the quantized_kernel distribution.
+
+    input_scale: tvm.relay.Expr
         The scale for the input tensor. The scale for the input tensor is
         stored purely for convenience here. See more commentary below.

-    kernel_scale: float
+    kernel_scale: tvm.relay.Expr
         The scale for the weight tensor. The scale for the weight tensor is
         stored for access to this during relay. This information is not
         needed in the pass pipeline after qnn.conv2d is lowered to the
         sequence of steps as in nn.conv2d. See also input_scale in Requantize.

-    kernel_zero_point: int
-        The zero point of the quantized_kernel distribution.
-
     strides : tuple of int, optional
         The strides of convolution.
@@ -299,19 +300,22 @@ def add(lhs,
-    lhs_scale: float
+    lhs_scale: relay.Expr
         The scale of the lhs quantized expr.

-    lhs_zero_point: int
+    lhs_zero_point: relay.Expr
         The zero point of lhs quantized expr.

-    rhs_scale: float
+    rhs_scale: relay.Expr
         The scale of the rhs quantized expr.

-    rhs_zero_point: int
+    rhs_zero_point: relay.Expr
         The zero point of rhs quantized expr.

-    output_scale: float
+    output_scale: relay.Expr
         The scale of the output quantized expr.

-    output_zero_point: int
+    output_zero_point: relay.Expr
         The zero point of output quantized expr.

     Returns
@@ -347,13 +351,13 @@ def dense(data,
         The quantized input data to the operator.
     weight : tvm.relay.Expr
         The quantized weight expressions.
-    input_zero_point: int
+    input_zero_point: tvm.relay.Expr
         The input zero point.
-    kernel_zero_point: int
+    kernel_zero_point: tvm.relay.Expr
         The kernel zero point.
-    input_scale: float
+    input_scale: tvm.relay.Expr
         The scale for the input tensor.
-    kernel_scale: float
+    kernel_scale: tvm.relay.Expr
         The scale for the weight tensor. The scale for the weight tensor is
         stored for access to this during relay. This information is not
         needed in the pass pipeline after qnn.conv2d is lowered to the
@@ -391,22 +395,22 @@ def mul(lhs, rhs, lhs_scale, lhs_zero_point, rhs_scale, rhs_zero_point,
     rhs : relay.Expr
         The right hand side quantized input data.

-    lhs_scale: float
+    lhs_scale: relay.Expr
         The scale of the lhs quantized expr.

-    lhs_zero_point: int
+    lhs_zero_point: relay.Expr
         The zero point of lhs quantized expr.

-    rhs_scale: float
+    rhs_scale: relay.Expr
         The scale of the rhs quantized expr.

-    rhs_zero_point: int
+    rhs_zero_point: relay.Expr
         The zero point of rhs quantized expr.

-    output_scale: float
+    output_scale: relay.Expr
         The scale of the output quantized expr.

-    output_zero_point: int
+    output_zero_point: relay.Expr
         The zero point of output quantized expr.

     Returns
...
@@ -14,15 +14,20 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""The attributes node used for QNN operators"""
-
-from ....attrs import Attrs
-from ...base import register_relay_attr_node
-
-
-@register_relay_attr_node
-class QnnConv2DAttrs(Attrs):
-    """Attributes for qnn.conv2d"""
-
-
-@register_relay_attr_node
-class QnnDenseAttrs(Attrs):
-    """Attributes for qnn.dense"""
+# pylint: disable=wildcard-import, redefined-builtin, invalid-name
+""" Utility functions that are used across many directories. """
+from __future__ import absolute_import
+import numpy as np
+from . import expr as _expr
+
+def get_scalar_from_constant(expr):
+    """ Returns scalar value from Relay constant scalar. """
+    assert isinstance(expr, _expr.Constant) and not expr.data.shape, \
+        "Expr is not a constant scalar."
+    value = expr.data.asnumpy()
+    if value.dtype == np.dtype(np.int32):
+        return int(value)
+    if value.dtype == np.dtype(np.float32):
+        return float(value)
+    assert False, "Constant expr must be float32/int32"
+    return None  # To suppress pylint
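A usage sketch for the new helper; the import path follows the relative imports in the diffs above (from ..util / from ...util, i.e. tvm.relay.util) and assumes a build with this commit:

    from tvm import relay
    from tvm.relay.util import get_scalar_from_constant

    scale = relay.const(0.25, 'float32')   # 0.25 is exactly representable in float32
    zero_point = relay.const(3, 'int32')
    assert get_scalar_from_constant(scale) == 0.25
    assert get_scalar_from_constant(zero_point) == 3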
@@ -285,6 +285,7 @@ inline Expr Log(Expr e) {
 template <typename T>
 T GetScalarFromConstant(Expr expr) {
   const auto* n = expr.as<ConstantNode>();
+  CHECK(n) << "Expr must be a constant expr - " << AsText(expr, false);
   CHECK(n->is_scalar());
   return static_cast<T*>(n->data->data)[0];
 }
...
@@ -42,20 +42,18 @@ namespace qnn {
 Expr QnnAddCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
                         const Array<tvm::relay::Type>& arg_types) {
   // Get the attrs.
-  CHECK_EQ(new_args.size(), 2);
+  CHECK_EQ(new_args.size(), 8);
   auto& lhs = new_args[0];
   auto& rhs = new_args[1];
-  const auto* binary_op_attrs = attrs.as<QnnBinaryOpAttrs>();
-  CHECK(binary_op_attrs != nullptr);
-  auto lhs_scale = binary_op_attrs->lhs_scale;
-  auto lhs_zero_point = binary_op_attrs->lhs_zero_point;
-  auto rhs_scale = binary_op_attrs->rhs_scale;
-  auto rhs_zero_point = binary_op_attrs->rhs_zero_point;
-  auto output_scale = binary_op_attrs->output_scale;
-  auto output_zero_point = binary_op_attrs->output_zero_point;
+  auto& lhs_scale = new_args[2];
+  auto& lhs_zero_point = new_args[3];
+  auto& rhs_scale = new_args[4];
+  auto& rhs_zero_point = new_args[5];
+  auto& output_scale = new_args[6];
+  auto& output_zero_point = new_args[7];

   // Get the input dtype and shape.
-  CHECK_EQ(arg_types.size(), 3);
+  CHECK_EQ(arg_types.size(), 9);
   auto tensor_type = arg_types[0].as<TensorTypeNode>();
   auto input_dtype = tensor_type->dtype;
   auto input_shape = tensor_type->shape;
@@ -82,7 +80,8 @@ Expr QnnAddCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
   // Requantize LHS if necessary.
   auto requantized_lhs = lhs;
-  if (lhs_scale != output_scale || lhs_zero_point != output_zero_point) {
+  if (!IsEqualScalar(lhs_scale, output_scale) ||
+      !IsEqualScalar(lhs_zero_point, output_zero_point)) {
     requantized_lhs = Requantize(lhs, input_shape, lhs_scale, lhs_zero_point, output_scale,
                                  output_zero_point, DataType::Int(32));
   } else {
@@ -91,7 +90,8 @@ Expr QnnAddCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
   // Requantize RHS if necessary.
   auto requantized_rhs = rhs;
-  if (rhs_scale != output_scale || rhs_zero_point != output_zero_point) {
+  if (!IsEqualScalar(rhs_scale, output_scale) ||
+      !IsEqualScalar(rhs_zero_point, output_zero_point)) {
     requantized_rhs = Requantize(rhs, input_shape, rhs_scale, rhs_zero_point, output_scale,
                                  output_zero_point, DataType::Int(32));
   } else {
@@ -101,9 +101,9 @@ Expr QnnAddCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
   auto output = Add(requantized_lhs, requantized_rhs);

   // Subtract zero point.
-  if (output_zero_point != 0) {
-    auto output_zp = MakeConstantScalar(DataType::Int(32), output_zero_point);
-    output = Subtract(output, output_zp);
+  auto zero_scalar = MakeConstantScalar(DataType::Int(32), 0);
+  if (!IsEqualScalar(output_zero_point, zero_scalar)) {
+    output = Subtract(output, output_zero_point);
   }

   // Go back to lower precision.
...
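For reference, the canonicalization above implements the usual quantized-add recipe: requantize both sides into the output (scale, zero_point), add in int32, then subtract the one extra output zero point. A NumPy sketch with illustrative values and a round-to-nearest reference requantize:

    import numpy as np

    def requantize(q, s_in, zp_in, s_out, zp_out):
        # Reference (float) requantize: change of scale and zero point.
        return int(np.round((s_in / s_out) * (q - zp_in))) + zp_out

    lhs, rhs = 40, 50
    out = requantize(lhs, 0.1, 0, 0.2, 10) + requantize(rhs, 0.2, 0, 0.2, 10)
    out -= 10  # the int32 add leaves one zero point too many

    real = 0.1 * 40 + 0.2 * 50                 # 14.0
    assert np.isclose(0.2 * (out - 10), real)  # dequantized output matches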
@@ -34,19 +34,48 @@ namespace tvm {
 namespace relay {
 namespace qnn {

-TVM_REGISTER_NODE_TYPE(QnnConcatenateAttrs);
-
-Expr MakeQnnConcatenate(Expr data, Array<tvm::Expr> input_scales,
-                        Array<tvm::Expr> input_zero_points, double output_scale,
-                        int32_t output_zero_point, int axis) {
-  auto attrs = make_object<QnnConcatenateAttrs>();
-  attrs->input_scales = std::move(input_scales);
-  attrs->input_zero_points = std::move(input_zero_points);
-  attrs->output_scale = output_scale;
-  attrs->output_zero_point = output_zero_point;
+bool QnnConcatenateRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+                       const TypeReporter& reporter) {
+  CHECK_EQ(types.size(), 6);
+
+  // Check the scale and zero point types
+  const auto* input_scales_tuple = types[1].as<TupleTypeNode>();
+  if (input_scales_tuple == nullptr) {
+    throw relay::Error(
+        RELAY_ERROR("qnn concatenate requires a tuple of scales as the second argument, found "
+                    << PrettyPrint(types[1])));
+  }
+  for (const auto& input_scale : input_scales_tuple->fields) {
+    CHECK(IsScalarType(input_scale, DataType::Float(32)));  // input_scales[idx]
+  }
+
+  const auto* input_zero_points_tuple = types[2].as<TupleTypeNode>();
+  if (input_zero_points_tuple == nullptr) {
+    throw relay::Error(
+        RELAY_ERROR("qnn concatenate requires a tuple of zero_points as the third argument, found "
+                    << PrettyPrint(types[2])));
+  }
+  for (const auto& input_zero_point : input_zero_points_tuple->fields) {
+    CHECK(IsScalarType(input_zero_point, DataType::Int(32)));  // input_zero_points[idx]
+  }
+
+  CHECK(IsScalarType(types[3], DataType::Float(32)));  // output_scale
+  CHECK(IsScalarType(types[4], DataType::Int(32)));    // output_zero_point
+
+  // Collect the input tensor and output tensor devoid of scale and zero points to reuse Relay
+  // Concatenate infer type function.
+  Array<Type> tensor_types = {types[0], types[5]};
+  return ConcatenateRel<ConcatenateAttrs>(tensor_types, 2, attrs, reporter);
+}
+
+Expr MakeQnnConcatenate(Expr data, Expr input_scales, Expr input_zero_points, Expr output_scale,
+                        Expr output_zero_point, int axis) {
+  auto attrs = make_object<ConcatenateAttrs>();
   attrs->axis = axis;
   static const Op& op = Op::Get("qnn.concatenate");
-  return CallNode::make(op, {data}, Attrs(attrs), {});
+  return CallNode::make(op,
+                        {data, input_scales, input_zero_points, output_scale, output_zero_point},
+                        Attrs(attrs), {});
 }

 /*
@@ -59,14 +88,14 @@ Expr MakeQnnConcatenate(Expr data, Array<tvm::Expr> input_scales,
 Expr ConcatenateQnnCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
                                 const Array<tvm::relay::Type>& arg_types) {
   // Get the attrs.
-  CHECK_EQ(new_args.size(), 1);
+  CHECK_EQ(new_args.size(), 5);
   auto& data = new_args[0];
-  const auto* concatenate_attrs = attrs.as<QnnConcatenateAttrs>();
+  auto& input_scales = new_args[1];
+  auto& input_zero_points = new_args[2];
+  auto& output_scale = new_args[3];
+  auto& output_zero_point = new_args[4];
+  const auto* concatenate_attrs = attrs.as<ConcatenateAttrs>();
   CHECK(concatenate_attrs != nullptr);
-  auto input_scales = concatenate_attrs->input_scales;
-  auto input_zero_points = concatenate_attrs->input_zero_points;
-  auto output_scale = concatenate_attrs->output_scale;
-  auto output_zero_point = concatenate_attrs->output_zero_point;

   // Get the input dtype and shape.
   CHECK_GE(arg_types.size(), 1);
@@ -83,21 +112,24 @@ Expr ConcatenateQnnCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
   auto tuple_data = data.as<TupleNode>();
   CHECK(tuple_data != nullptr);

+  auto tuple_input_scales = input_scales.as<TupleNode>();
+  CHECK(tuple_input_scales != nullptr);
+
+  auto tuple_input_zero_points = input_zero_points.as<TupleNode>();
+  CHECK(tuple_input_zero_points != nullptr);
+
   int idx = 0;
   Array<Expr> requantized_exprs;
   for (auto quantized_expr : tuple_data->fields) {
     // Get the input scale for the idx quantized input tensor.
-    auto input_scale_expr = input_scales[idx].as<tvm::ir::FloatImm>();
-    CHECK(input_scale_expr != nullptr);
-    auto input_scale = input_scale_expr->value;
+    auto input_scale = tuple_input_scales->fields[idx];

     // Get the zero point for the idx quantized input tensor.
-    auto input_zero_point_expr = input_zero_points[idx].as<tvm::ir::IntImm>();
-    CHECK(input_zero_point_expr != nullptr);
-    auto input_zero_point = input_zero_point_expr->value;
+    auto input_zero_point = tuple_input_zero_points->fields[idx];

     // Check if output and input qnn params are same. If not, requantize.
-    if (input_scale != output_scale || input_zero_point != output_zero_point) {
+    if (!IsEqualScalar(input_scale, output_scale) ||
+        !IsEqualScalar(input_zero_point, output_zero_point)) {
       // Get the input shape and dtype.
       auto tensor_type = tuple_type->fields[idx].as<TensorTypeNode>();
       auto input_dtype = tensor_type->dtype;
@@ -118,11 +150,15 @@ Expr ConcatenateQnnCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
 RELAY_REGISTER_OP("qnn.concatenate")
 .describe(R"code(Concatenate the quantized input tensors along the given axis.
 )code" TVM_ADD_FILELINE)
-.set_attrs_type<QnnConcatenateAttrs>()
-.set_num_inputs(1)
+.set_attrs_type<ConcatenateAttrs>()
+.set_num_inputs(5)
 .add_argument("data", "Tensor", "The tensor to concatenate.")
+.add_argument("input_scales", "Tensor", "The quantization scales of the input tensors.")
+.add_argument("input_zero_points", "Tensor", "The quantization zero_points of the input tensors.")
+.add_argument("output_scale", "Tensor", "The quantization scale of the output tensor.")
+.add_argument("output_zero_point", "Tensor", "The quantization zero_point of the output tensor.")
 .set_support_level(11)
-.add_type_rel("QnnConcatenate", ConcatenateRel<QnnConcatenateAttrs>)
+.add_type_rel("QnnConcatenate", QnnConcatenateRel)
 .set_attr<FTVMLegalize>("FTVMQnnCanonicalize", ConcatenateQnnCanonicalize);

 TVM_REGISTER_API("relay.qnn.op._make.concatenate")
...
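A call sketch for the reworked qnn.concatenate, assuming a post-commit build: the per-input scales and zero points are now sequences of Relay constant exprs, which the Python wrapper packs into Tuples for the type relation above (values are illustrative):

    from tvm import relay

    a = relay.var("a", shape=(1, 2), dtype="uint8")
    b = relay.var("b", shape=(1, 3), dtype="uint8")

    out = relay.qnn.op.concatenate(
        (a, b),
        input_scales=[relay.const(0.1, 'float32'), relay.const(0.2, 'float32')],
        input_zero_points=[relay.const(0, 'int32'), relay.const(1, 'int32')],
        output_scale=relay.const(0.2, 'float32'),
        output_zero_point=relay.const(0, 'int32'),
        axis=1)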
@@ -35,59 +35,67 @@ namespace relay {
 namespace qnn {

 // relay.op.qnn.dense
-TVM_REGISTER_NODE_TYPE(QnnDenseAttrs);

 bool QnnDenseRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                  const TypeReporter& reporter) {
-  CHECK_EQ(types.size(), 3);
+  CHECK_EQ(types.size(), 7);
   const auto* data = types[0].as<TensorTypeNode>();
   const auto* weight = types[1].as<TensorTypeNode>();
   if (data == nullptr || weight == nullptr) return false;
-  const auto* param = attrs.as<QnnDenseAttrs>();
-  CHECK(param != nullptr) << "QnnDenseAttrs cannot be nullptr.";
+  const auto* param = attrs.as<DenseAttrs>();
+  CHECK(param != nullptr) << "DenseAttrs cannot be nullptr.";
   CHECK(data->dtype == DataType::Int(8) || data->dtype == DataType::UInt(8))
      << "Expected quantized dense type(int8, uint8) for input but was " << data->dtype;
   CHECK(weight->dtype == DataType::Int(8) || weight->dtype == DataType::UInt(8))
      << "Expected quantized dense type(int8, uint8) for weight but was " << weight->dtype;
   CHECK(param->out_dtype == DataType::Int(32))
      << "Expected quantized dense type(int32) for output but was " << param->out_dtype;
+
+  // Check the types of scale and zero points.
+  CHECK(IsScalarType(types[2], DataType::Int(32)));    // input_zero_point
+  CHECK(IsScalarType(types[3], DataType::Int(32)));    // kernel_zero_point
+  CHECK(IsScalarType(types[4], DataType::Float(32)));  // input_scale
+  CHECK(IsScalarType(types[5], DataType::Float(32)));  // kernel_scale
+
   CHECK(param->out_dtype.bits() > 0) << "Output dtype bits should be greater than 0.";
-  return DenseRel<QnnDenseAttrs>(types, num_inputs, attrs, reporter);
+
+  // Collect the input tensor and output tensor devoid of scale and zero points to reuse Relay
+  // Dense infer type function.
+  Array<Type> tensor_types = {types[0], types[1], types[6]};
+  return DenseRel<DenseAttrs>(tensor_types, 3, attrs, reporter);
 }

 // Positional relay function to create quantized dense operator used by frontend FFI.
-Expr MakeQuantizedDense(Expr data, Expr weight, int32_t input_zero_point,
-                        int32_t kernel_zero_point, double input_scale,
-                        double kernel_scale, IndexExpr units,
-                        DataType out_dtype) {
-  auto attrs = make_object<QnnDenseAttrs>();
+Expr MakeQuantizedDense(Expr data, Expr weight, Expr input_zero_point, Expr kernel_zero_point,
+                        Expr input_scale, Expr kernel_scale, IndexExpr units, DataType out_dtype) {
+  auto attrs = make_object<DenseAttrs>();
   attrs->units = std::move(units);
   attrs->out_dtype = out_dtype;
-  attrs->input_zero_point = input_zero_point;
-  attrs->kernel_zero_point = kernel_zero_point;
-  attrs->input_scale = input_scale;
-  attrs->kernel_scale = kernel_scale;
   static const Op& op = Op::Get("qnn.dense");
-  return CallNode::make(op, {data, weight}, Attrs(attrs), {});
+  return CallNode::make(
+      op, {data, weight, input_zero_point, kernel_zero_point, input_scale, kernel_scale},
+      Attrs(attrs), {});
 }

 Expr DenseFirstTerm(const Expr& quantized_data, const Expr& quantized_kernel,
-                    const QnnDenseAttrs* attrs) {
+                    const DenseAttrs* attrs) {
   return Dense(quantized_data, quantized_kernel, attrs->units, attrs->out_dtype);
 }

-Expr DenseSecondTerm(const Expr& quantized_data, const Expr& zp_kernel) {
+Expr DenseSecondTerm(const Expr& quantized_data, const Expr& kernel_zero_point) {
   Array<Integer> axes = {1};
-  return Multiply(zp_kernel, Sum(Cast(quantized_data, DataType::Int(32)), axes, true, false));
+  return Multiply(kernel_zero_point,
+                  Sum(Cast(quantized_data, DataType::Int(32)), axes, true, false));
 }

-Expr DenseThirdTerm(const Expr& quantized_kernel, const Expr& zp_data) {
+Expr DenseThirdTerm(const Expr& quantized_kernel, const Expr& input_zero_point) {
   Array<Integer> axes = {1};
-  return Multiply(zp_data, Sum(Cast(quantized_kernel, DataType::Int(32)), axes, false, false));
+  return Multiply(input_zero_point,
+                  Sum(Cast(quantized_kernel, DataType::Int(32)), axes, false, false));
 }

-Expr DenseFourthTerm(const QnnDenseAttrs* attrs, int reduction_dim_size) {
-  int32_t scalar_term = attrs->input_zero_point * attrs->kernel_zero_point * reduction_dim_size;
+Expr DenseFourthTerm(int input_zero_point_int, int kernel_zero_point_int, int reduction_dim_size) {
+  int32_t scalar_term = input_zero_point_int * kernel_zero_point_int * reduction_dim_size;
   return MakeConstantScalar(DataType::Int(32), scalar_term);
 }
@@ -125,31 +133,35 @@ Expr DenseFourthTerm(const QnnDenseAttrs* attrs, int reduction_dim_size) {
 */
 Expr QnnDenseCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
                           const Array<tvm::relay::Type>& arg_types) {
-  CHECK_EQ(new_args.size(), 2);
+  CHECK_EQ(new_args.size(), 6);
   Expr quantized_data = new_args[0];
   Expr quantized_kernel = new_args[1];
+  Expr input_zero_point = new_args[2];
+  Expr kernel_zero_point = new_args[3];

   const auto in_shape = get_shape(arg_types[0]);
   const int reduction_dim_size = get_const_int(in_shape[1]);

-  const auto* qnn_dense_attrs = attrs.as<QnnDenseAttrs>();
-  auto zp_kernel = MakeConstantScalar(DataType::Int(32), qnn_dense_attrs->kernel_zero_point);
-  auto zp_data = MakeConstantScalar(DataType::Int(32), qnn_dense_attrs->input_zero_point);
+  const auto* qnn_dense_attrs = attrs.as<DenseAttrs>();
+
+  // Extract the integer zero points.
+  auto input_zero_point_int = GetScalarFromConstant<int>(input_zero_point);
+  auto kernel_zero_point_int = GetScalarFromConstant<int>(kernel_zero_point);

   // Get all the terms as described in the comments.
   auto term1 = DenseFirstTerm(quantized_data, quantized_kernel, qnn_dense_attrs);
-  auto term2 = DenseSecondTerm(quantized_data, zp_kernel);
-  auto term3 = DenseThirdTerm(quantized_kernel, zp_data);
-  auto term4 = DenseFourthTerm(qnn_dense_attrs, reduction_dim_size);
+  auto term2 = DenseSecondTerm(quantized_data, kernel_zero_point);
+  auto term3 = DenseThirdTerm(quantized_kernel, input_zero_point);
+  auto term4 = DenseFourthTerm(input_zero_point_int, kernel_zero_point_int, reduction_dim_size);

   // Combine those 4 terms depending on the zero points to get the best lowering.
-  if (qnn_dense_attrs->input_zero_point == 0 && qnn_dense_attrs->kernel_zero_point == 0) {
+  if (input_zero_point_int == 0 && kernel_zero_point_int == 0) {
     // term 2, 3 and 4 become zero.
     return term1;
-  } else if (qnn_dense_attrs->input_zero_point == 0 && qnn_dense_attrs->kernel_zero_point != 0) {
+  } else if (input_zero_point_int == 0 && kernel_zero_point_int != 0) {
     // term 3 and term 4 become zero.
     return Subtract(term1, term2);
-  } else if (qnn_dense_attrs->input_zero_point != 0 && qnn_dense_attrs->kernel_zero_point == 0) {
+  } else if (input_zero_point_int != 0 && kernel_zero_point_int == 0) {
     // term 2 and term 4 become zero.
     return Subtract(term1, term3);
   } else {
@@ -166,12 +178,16 @@ RELAY_REGISTER_OP("qnn.dense")
 - **weight**: quantized(int8, uint8) `(units, input_dim)`
 - **out**: quantized(int32) `(x1, x2, ..., xn, units)`.
 )code" TVM_ADD_FILELINE)
-.set_attrs_type<QnnDenseAttrs>()
-.set_num_inputs(2)
+.set_attrs_type<DenseAttrs>()
+.set_num_inputs(6)
 .add_argument("data", "quantized nD Tensor", "Input data.")
 .add_argument("weight", "quantized 2D Tensor", "Weight matrix.")
+.add_argument("input_scale", "Tensor", "The quantization scale of the input tensor.")
+.add_argument("input_zero_point", "Tensor", "The quantization zero_point of the input tensor.")
+.add_argument("weight_scale", "Tensor", "The quantization scale of the weight tensor.")
+.add_argument("weight_zero_point", "Tensor", "The quantization zero_point of the weight tensor.")
 .set_support_level(11)
-.add_type_rel("QDense", DenseRel<QnnDenseAttrs>)
+.add_type_rel("QDense", QnnDenseRel)
 .set_attr<FTVMLegalize>("FTVMQnnCanonicalize", QnnDenseCanonicalize);

 TVM_REGISTER_API("relay.qnn.op._make.dense")
...
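The dense canonicalization expands (D - zd) @ (W - zw)^T into the four terms named above. A NumPy check of the identity with illustrative integers:

    import numpy as np

    D = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32)   # data (n, k)
    W = np.array([[1, 0, 1], [2, 2, 2]], dtype=np.int32)   # weight (units, k)
    zd, zw, k = 2, 3, D.shape[1]                           # zero points, reduction dim

    lhs = (D - zd) @ (W - zw).T
    term1 = D @ W.T                                # DenseFirstTerm
    term2 = zw * D.sum(axis=1, keepdims=True)      # DenseSecondTerm
    term3 = zd * W.sum(axis=1)                     # DenseThirdTerm
    term4 = zd * zw * k                            # DenseFourthTerm
    assert np.array_equal(lhs, term1 - term2 - term3 + term4)

Keeping the terms separate lets the pass drop whole terms when a zero point is zero, which is exactly the branching at the end of QnnDenseCanonicalize.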
...@@ -33,13 +33,11 @@ namespace tvm { ...@@ -33,13 +33,11 @@ namespace tvm {
namespace relay { namespace relay {
namespace qnn { namespace qnn {
TVM_REGISTER_NODE_TYPE(DequantizeAttrs);
bool DequantizeRel(const Array<Type>& types, bool DequantizeRel(const Array<Type>& types,
int num_inputs, int num_inputs,
const Attrs& attrs, const Attrs& attrs,
const TypeReporter& reporter) { const TypeReporter& reporter) {
CHECK_EQ(types.size(), 2); CHECK_EQ(types.size(), 4);
const auto* data = types[0].as<TensorTypeNode>(); const auto* data = types[0].as<TensorTypeNode>();
const auto input_dtype = data->dtype; const auto input_dtype = data->dtype;
CHECK(input_dtype == DataType::Int(8) || CHECK(input_dtype == DataType::Int(8) ||
...@@ -47,42 +45,40 @@ bool DequantizeRel(const Array<Type>& types, ...@@ -47,42 +45,40 @@ bool DequantizeRel(const Array<Type>& types,
input_dtype == DataType::Int(32)) input_dtype == DataType::Int(32))
<< "Input type should be one of the quantized types [unit8, int8, int32] but was " << "Input type should be one of the quantized types [unit8, int8, int32] but was "
<< input_dtype; << input_dtype;
// Check the types of scale and zero points.
CHECK(IsScalarType(types[1], DataType::Float(32))); // input_scale
CHECK(IsScalarType(types[2], DataType::Int(32))); // input_zero_point
const Array<tvm::Expr> oshape = data->shape; const Array<tvm::Expr> oshape = data->shape;
// assign output type, output will always be float 32. // assign output type, output will always be float 32.
reporter->Assign(types[1], TensorTypeNode::make(oshape, DataType::Float(32))); reporter->Assign(types[3], TensorTypeNode::make(oshape, DataType::Float(32)));
return true; return true;
} }
Expr MakeDequantize(Expr data, Expr MakeDequantize(Expr data, Expr input_scale, Expr input_zero_point) {
double input_scale,
int32_t input_zero_point) {
auto attrs = make_object<DequantizeAttrs>();
attrs->input_scale = input_scale;
attrs->input_zero_point = input_zero_point;
// real_value = scale * (quantized_value - zero_point) // real_value = scale * (quantized_value - zero_point)
// A more detailed explanation can be found here - https://github.com/google/gemmlowp/blob/master/doc/quantization.md // A more detailed explanation can be found here -
// https://github.com/google/gemmlowp/blob/master/doc/quantization.md
static const Op& op = Op::Get("qnn.dequantize"); static const Op& op = Op::Get("qnn.dequantize");
return CallNode::make(op, {data}, Attrs(attrs), {}); return CallNode::make(op, {data, input_scale, input_zero_point}, Attrs(), {});
} }
Expr DequantizeLower(const Expr& input_tensor, Expr DequantizeLower(const Expr& input_tensor, const Expr& input_scale,
const DequantizeAttrs* attrs) { const Expr& input_zero_point) {
const auto input_zero_point = MakeConstantScalar(DataType::Int(32), attrs->input_zero_point);
const auto input_scale = MakeConstantScalar(DataType::Float(32), attrs->input_scale);
auto shift = Subtract(Cast(input_tensor, DataType::Int(32)), input_zero_point); auto shift = Subtract(Cast(input_tensor, DataType::Int(32)), input_zero_point);
auto scaled_output = Multiply(Cast(shift, DataType::Float(32)), input_scale); auto scaled_output = Multiply(Cast(shift, DataType::Float(32)), input_scale);
return scaled_output; return scaled_output;
} }
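For intuition, a minimal NumPy sketch of the DequantizeLower math above, assuming plain scalar `scale`/`zero_point` values in place of the relay expressions; names are illustrative:

```python
import numpy as np

def dequantize_ref(q, scale, zero_point):
    # Mirrors DequantizeLower: shift in int32, then scale in float32.
    shifted = q.astype(np.int32) - np.int32(zero_point)
    return shifted.astype(np.float32) * np.float32(scale)

# e.g. dequantize_ref(np.array([0, 255], dtype=np.uint8), 0.5, 127)
# -> array([-63.5,  64. ], dtype=float32)
```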
Expr DequantizeQnnCanonicalize(const Attrs& attrs, Expr DequantizeQnnCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
const Array<Expr>& new_args,
const Array<tvm::relay::Type>& types) { const Array<tvm::relay::Type>& types) {
CHECK_EQ(new_args.size(), 1); CHECK_EQ(new_args.size(), 3);
auto& data = new_args[0]; auto& data = new_args[0];
const auto* dequantize_attrs = attrs.as<DequantizeAttrs>(); auto& input_scale = new_args[1];
CHECK(dequantize_attrs != nullptr); auto& input_zero_point = new_args[2];
CHECK_EQ(types.size(), 2); CHECK_EQ(types.size(), 4);
return DequantizeLower(data, dequantize_attrs); return DequantizeLower(data, input_scale, input_zero_point);
} }
RELAY_REGISTER_OP("qnn.dequantize") RELAY_REGISTER_OP("qnn.dequantize")
...@@ -90,9 +86,10 @@ RELAY_REGISTER_OP("qnn.dequantize") ...@@ -90,9 +86,10 @@ RELAY_REGISTER_OP("qnn.dequantize")
The input is always quantized (int8, uint8) and will be converted to float32 given input scale and zero_point. The input is always quantized (int8, uint8) and will be converted to float32 given input scale and zero_point.
- **data**: Quantized tensor of any shape to dequantize. The input data can be of floating point - **data**: Quantized tensor of any shape to dequantize. The input data can be of floating point
)code" TVM_ADD_FILELINE) )code" TVM_ADD_FILELINE)
.set_attrs_type<DequantizeAttrs>() .set_num_inputs(3)
.set_num_inputs(1)
.add_argument("data", "Tensor", "The tensor to dequantize.") .add_argument("data", "Tensor", "The tensor to dequantize.")
.add_argument("input_scale", "Tensor", "The quantization scale of the input tensor.")
.add_argument("input_zero_point", "Tensor", "The quantization zero_point of the input tensor.")
.set_support_level(11) .set_support_level(11)
.add_type_rel("Dequantize", DequantizeRel) .add_type_rel("Dequantize", DequantizeRel)
.set_attr<FTVMLegalize>("FTVMQnnCanonicalize", DequantizeQnnCanonicalize); .set_attr<FTVMLegalize>("FTVMQnnCanonicalize", DequantizeQnnCanonicalize);
......
...@@ -42,20 +42,18 @@ namespace qnn { ...@@ -42,20 +42,18 @@ namespace qnn {
Expr QnnMulCanonicalize(const Attrs& attrs, const Array<Expr>& new_args, Expr QnnMulCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
const Array<tvm::relay::Type>& arg_types) { const Array<tvm::relay::Type>& arg_types) {
// Get the attrs. // Get the attrs.
CHECK_EQ(new_args.size(), 2); CHECK_EQ(new_args.size(), 8);
auto& lhs = new_args[0]; auto& lhs = new_args[0];
auto& rhs = new_args[1]; auto& rhs = new_args[1];
const auto* binary_op_attrs = attrs.as<QnnBinaryOpAttrs>(); auto& lhs_scale = new_args[2];
CHECK(binary_op_attrs != nullptr); auto& lhs_zero_point = new_args[3];
auto lhs_scale = binary_op_attrs->lhs_scale; auto& rhs_scale = new_args[4];
auto lhs_zero_point = binary_op_attrs->lhs_zero_point; auto& rhs_zero_point = new_args[5];
auto rhs_scale = binary_op_attrs->rhs_scale; auto& output_scale = new_args[6];
auto rhs_zero_point = binary_op_attrs->rhs_zero_point; auto& output_zero_point = new_args[7];
auto output_scale = binary_op_attrs->output_scale;
auto output_zero_point = binary_op_attrs->output_zero_point;
// Get the input dtype and shape. // Get the input dtype and shape.
CHECK_EQ(arg_types.size(), 3); CHECK_EQ(arg_types.size(), 9);
auto tensor_type = arg_types[0].as<TensorTypeNode>(); auto tensor_type = arg_types[0].as<TensorTypeNode>();
auto input_dtype = tensor_type->dtype; auto input_dtype = tensor_type->dtype;
auto input_shape = tensor_type->shape; auto input_shape = tensor_type->shape;
...@@ -75,24 +73,28 @@ Expr QnnMulCanonicalize(const Attrs& attrs, const Array<Expr>& new_args, ...@@ -75,24 +73,28 @@ Expr QnnMulCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
auto lhs_shifted = Cast(lhs, DataType::Int(32)); auto lhs_shifted = Cast(lhs, DataType::Int(32));
auto rhs_shifted = Cast(rhs, DataType::Int(32)); auto rhs_shifted = Cast(rhs, DataType::Int(32));
if (lhs_zero_point != 0) { auto zero_scalar = MakeConstantScalar(DataType::Int(32), 0);
auto lhs_zp = MakeConstantScalar(DataType::Int(32), lhs_zero_point); if (!IsEqualScalar(lhs_zero_point, zero_scalar)) {
lhs_shifted = Subtract(lhs_shifted, lhs_zp); lhs_shifted = Subtract(lhs_shifted, lhs_zero_point);
} }
if (rhs_zero_point != 0) { if (!IsEqualScalar(rhs_zero_point, zero_scalar)) {
auto rhs_zp = MakeConstantScalar(DataType::Int(32), rhs_zero_point); rhs_shifted = Subtract(rhs_shifted, rhs_zero_point);
rhs_shifted = Subtract(rhs_shifted, rhs_zp);
} }
// Create a new tensor Q' // Create a new tensor Q'
auto output = Multiply(lhs_shifted, rhs_shifted); auto output = Multiply(lhs_shifted, rhs_shifted);
auto scale_new = rhs_scale * lhs_scale; // Get the adjusted new scale and zero points.
float lhs_scale_float = GetScalarFromConstant<float>(lhs_scale);
float rhs_scale_float = GetScalarFromConstant<float>(rhs_scale);
float new_scale_float = lhs_scale_float * rhs_scale_float;
auto new_input_scale = MakeConstantScalar(DataType::Float(32), new_scale_float);
auto new_input_zero_point = zero_scalar;
// Requantize to get Q_c // Requantize to get Q_c
output = Requantize(output, input_shape, scale_new, 0, output_scale, output = Requantize(output, input_shape, new_input_scale, new_input_zero_point, output_scale,
output_zero_point, input_dtype); output_zero_point, input_dtype);
return output; return output;
} }
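A hedged NumPy sketch of what this canonicalization computes end to end: shift out the zero points, multiply in int32, then requantize the product from scale `lhs_scale * rhs_scale` (zero point 0) to the output params. Plain float rounding stands in for the fixed-point requantize path:

```python
import numpy as np

def qnn_mul_ref(lhs, rhs, lhs_scale, lhs_zp, rhs_scale, rhs_zp,
                out_scale, out_zp, qmin=-128, qmax=127):
    # Q' = (lhs - lhs_zp) * (rhs - rhs_zp), computed in int32.
    prod = (lhs.astype(np.int32) - lhs_zp) * (rhs.astype(np.int32) - rhs_zp)
    # Requantize Q' from scale lhs_scale*rhs_scale (zero point 0) to the output params.
    multiplier = (lhs_scale * rhs_scale) / out_scale
    q = np.round(prod * multiplier) + out_zp
    return np.clip(q, qmin, qmax).astype(np.int8)
```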
......
...@@ -35,6 +35,24 @@ namespace tvm { ...@@ -35,6 +35,24 @@ namespace tvm {
namespace relay { namespace relay {
namespace qnn { namespace qnn {
static inline bool QnnBroadcastRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 9);
// Check the scale and zero point types
CHECK(IsScalarType(types[2], DataType::Float(32))); // lhs_scale
CHECK(IsScalarType(types[3], DataType::Int(32))); // lhs_zero_point
CHECK(IsScalarType(types[4], DataType::Float(32))); // rhs_scale
CHECK(IsScalarType(types[5], DataType::Int(32))); // rhs_zero_point
CHECK(IsScalarType(types[6], DataType::Float(32))); // output_scale
CHECK(IsScalarType(types[7], DataType::Int(32))); // output_zero_point
// Collect the input tensor and output tensor devoid of scale and zero points to reuse Relay
// BroadcastRel infer type function.
Array<Type> tensor_types = {types[0], types[1], types[8]};
return BroadcastRel(tensor_types, 3, attrs, reporter);
}
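For readability, the argument layout this relation assumes, sketched as a hypothetical Python helper (not a TVM API):

```python
def unpack_qnn_binary_types(types):
    # Mirrors the slot layout QnnBroadcastRel checks for every QNN binary op:
    # two tensor inputs, six scalar quantization params, one output slot.
    assert len(types) == 9
    lhs, rhs = types[0], types[1]
    (lhs_scale, lhs_zp, rhs_scale, rhs_zp,
     out_scale, out_zp) = types[2:8]
    output = types[8]  # inferred by the reused BroadcastRel
    return lhs, rhs, (lhs_scale, lhs_zp, rhs_scale, rhs_zp, out_scale, out_zp), output
```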
/*! Quick helper macro /*! Quick helper macro
* - Expose a positional make function to construct the node. * - Expose a positional make function to construct the node.
* - Register op to the registry. * - Register op to the registry.
...@@ -47,24 +65,26 @@ namespace qnn { ...@@ -47,24 +65,26 @@ namespace qnn {
*/ */
#define QNN_REGISTER_BINARY_OP(OpName) \ #define QNN_REGISTER_BINARY_OP(OpName) \
TVM_REGISTER_API("relay.qnn.op._make." OpName) \ TVM_REGISTER_API("relay.qnn.op._make." OpName) \
.set_body_typed<Expr(Expr, Expr, double, int32_t, double, int32_t, double, int32_t)>( \ .set_body_typed<Expr(Expr, Expr, Expr, Expr, Expr, Expr, Expr, Expr)>( \
[](Expr lhs, Expr rhs, double lhs_scale, int32_t lhs_zero_point, double rhs_scale, \ [](Expr lhs, Expr rhs, Expr lhs_scale, Expr lhs_zero_point, Expr rhs_scale, \
int32_t rhs_zero_point, double output_scale, int32_t output_zero_point) { \ Expr rhs_zero_point, Expr output_scale, Expr output_zero_point) { \
auto attrs = make_object<QnnBinaryOpAttrs>(); \
attrs->lhs_scale = lhs_scale; \
attrs->lhs_zero_point = lhs_zero_point; \
attrs->rhs_scale = rhs_scale; \
attrs->rhs_zero_point = rhs_zero_point; \
attrs->output_scale = output_scale; \
attrs->output_zero_point = output_zero_point; \
static const Op& op = Op::Get("qnn." OpName); \ static const Op& op = Op::Get("qnn." OpName); \
return CallNode::make(op, {lhs, rhs}, Attrs(attrs), {}); \ return CallNode::make(op, {lhs, rhs, \
lhs_scale, lhs_zero_point, \
rhs_scale, rhs_zero_point, \
output_scale, output_zero_point}, Attrs(), {}); \
}); \ }); \
RELAY_REGISTER_OP("qnn." OpName) \ RELAY_REGISTER_OP("qnn." OpName) \
.set_num_inputs(2) \ .set_num_inputs(8) \
.add_argument("lhs", "Tensor", "The left hand side quantized tensor.") \ .add_argument("lhs", "Tensor", "The left hand side quantized tensor.") \
.add_argument("rhs", "Tensor", "The right hand side quantized tensor.") \ .add_argument("rhs", "Tensor", "The right hand side quantized tensor.") \
.add_type_rel("Broadcast", BroadcastRel) .add_argument("lhs_scale", "Tensor", "The scale of the lhs tensor.") \
.add_argument("lhs_zero_point", "Tensor", "The zero_point of the lhs tensor.") \
.add_argument("rhs_scale", "Tensor", "The scale of the rhs tensor.") \
.add_argument("rhs_zero_point", "Tensor", "The zero_point of the rhs tensor.") \
.add_argument("output_scale", "Tensor", "The scale of the output tensor.") \
.add_argument("output_zero_point", "Tensor", "The zero_point of the output tensor.") \
.add_type_rel("QnnBroadcast", QnnBroadcastRel)
} // namespace qnn } // namespace qnn
} // namespace relay } // namespace relay
......
...@@ -39,61 +39,61 @@ bool QuantizeRel(const Array<Type>& types, ...@@ -39,61 +39,61 @@ bool QuantizeRel(const Array<Type>& types,
int num_inputs, int num_inputs,
const Attrs& attrs, const Attrs& attrs,
const TypeReporter& reporter) { const TypeReporter& reporter) {
CHECK_EQ(types.size(), 2); CHECK_EQ(types.size(), 4);
const auto* data = types[0].as<TensorTypeNode>(); const auto* data = types[0].as<TensorTypeNode>();
const auto input_dtype = data->dtype; const auto input_dtype = data->dtype;
CHECK(input_dtype == DataType::Float(32)) CHECK(input_dtype == DataType::Float(32))
<< "Input type should be one of float32 but was " << input_dtype; << "Input type should be one of float32 but was " << input_dtype;
// Check the types of scale and zero points.
CHECK(IsScalarType(types[1], DataType::Float(32))); // output_scale
CHECK(IsScalarType(types[2], DataType::Int(32))); // output_zero_point
const auto* quantize_attrs = attrs.as<QuantizeAttrs>(); const auto* quantize_attrs = attrs.as<QuantizeAttrs>();
const Array<tvm::Expr> oshape = data->shape; const Array<tvm::Expr> oshape = data->shape;
const DataType out_dtype = quantize_attrs->out_dtype; const DataType out_dtype = quantize_attrs->out_dtype;
CHECK(out_dtype == DataType::Int(8) || CHECK(out_dtype == DataType::Int(8) || out_dtype == DataType::UInt(8) ||
out_dtype == DataType::UInt(8) ||
out_dtype == DataType::Int(32)) out_dtype == DataType::Int(32))
<< "Output type should be one of [int8, unit8, int32] but was " << out_dtype; << "Output type should be one of [int8, unit8, int32] but was " << out_dtype;
// assign output type // assign output type
reporter->Assign(types[1], TensorTypeNode::make(oshape, out_dtype)); reporter->Assign(types[3], TensorTypeNode::make(oshape, out_dtype));
return true; return true;
} }
Expr MakeQuantize(Expr data, Expr MakeQuantize(Expr data, Expr output_scale, Expr output_zero_point, DataType out_dtype) {
double output_scale,
int32_t output_zero_point,
DataType out_dtype) {
auto attrs = make_object<QuantizeAttrs>(); auto attrs = make_object<QuantizeAttrs>();
attrs->output_scale = output_scale;
attrs->output_zero_point = output_zero_point;
attrs->out_dtype = std::move(out_dtype); attrs->out_dtype = std::move(out_dtype);
// result_quantized_value = result_zero_point + result_real_value / result_scale. // result_quantized_value = result_zero_point + result_real_value / result_scale.
// A more detailed explanation can be found here - https://github.com/google/gemmlowp/blob/master/doc/quantization.md // A more detailed explanation can be found here -
// https://github.com/google/gemmlowp/blob/master/doc/quantization.md
static const Op& op = Op::Get("qnn.quantize"); static const Op& op = Op::Get("qnn.quantize");
return CallNode::make(op, {data}, Attrs(attrs), {}); return CallNode::make(op, {data, output_scale, output_zero_point}, Attrs(attrs), {});
} }
Expr QuantizeLower(const Expr& input_tensor, Expr QuantizeLower(const Expr& input_tensor, const Expr& output_scale,
const QuantizeAttrs* attrs) { const Expr& output_zero_point, const QuantizeAttrs* attrs) {
const auto out_dtype = attrs->out_dtype; const auto out_dtype = attrs->out_dtype;
const auto output_zero_point = MakeConstantScalar(DataType::Float(32), attrs->output_zero_point);
const auto scale = MakeConstantScalar(DataType::Float(32), attrs->output_scale);
const int32_t min_val = GetQmin(out_dtype); const int32_t min_val = GetQmin(out_dtype);
const int32_t max_val = GetQmax(out_dtype); const int32_t max_val = GetQmax(out_dtype);
auto scale_data = Divide(input_tensor, scale); auto scale_data = Divide(input_tensor, output_scale);
auto add_zero_point = Cast(Round(Add(scale_data, output_zero_point)), DataType::Int(32)); auto add_zero_point =
Cast(Round(Add(scale_data, Cast(output_zero_point, DataType::Float(32)))), DataType::Int(32));
auto clamped_output = Clip(add_zero_point, min_val, max_val); auto clamped_output = Clip(add_zero_point, min_val, max_val);
auto clamp_out_dtype = Cast(clamped_output, out_dtype); auto clamp_out_dtype = Cast(clamped_output, out_dtype);
return clamp_out_dtype; return clamp_out_dtype;
} }
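A minimal NumPy analogue of QuantizeLower, assuming scalar params; `GetQmin`/`GetQmax` are replaced by explicit bounds for the chosen dtype:

```python
import numpy as np

def quantize_ref(x, scale, zero_point, qmin=0, qmax=255, out_dtype=np.uint8):
    # q = clip(round(x / scale + zero_point), qmin, qmax), then cast.
    q = np.round(x / np.float32(scale) + np.float32(zero_point))
    return np.clip(q, qmin, qmax).astype(out_dtype)

# e.g. quantize_ref(np.array([-63.5, 64.0], dtype=np.float32), 0.5, 127)
# -> array([  0, 255], dtype=uint8)
```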
Expr QuantizeQnnCanonicalize(const Attrs& attrs, Expr QuantizeQnnCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
const Array<Expr>& new_args,
const Array<tvm::relay::Type>& types) { const Array<tvm::relay::Type>& types) {
CHECK_EQ(new_args.size(), 1); CHECK_EQ(new_args.size(), 3);
auto& data = new_args[0]; auto& data = new_args[0];
auto& output_scale = new_args[1];
auto& output_zero_point = new_args[2];
const auto* quantize_attrs = attrs.as<QuantizeAttrs>(); const auto* quantize_attrs = attrs.as<QuantizeAttrs>();
CHECK(quantize_attrs != nullptr); CHECK(quantize_attrs != nullptr);
CHECK_EQ(types.size(), 2); CHECK_EQ(types.size(), 4);
return QuantizeLower(data, quantize_attrs); return QuantizeLower(data, output_scale, output_zero_point, quantize_attrs);
} }
RELAY_REGISTER_OP("qnn.quantize") RELAY_REGISTER_OP("qnn.quantize")
...@@ -108,8 +108,10 @@ scale and zero point. ...@@ -108,8 +108,10 @@ scale and zero point.
or quantized. or quantized.
)code" TVM_ADD_FILELINE) )code" TVM_ADD_FILELINE)
.set_attrs_type<QuantizeAttrs>() .set_attrs_type<QuantizeAttrs>()
.set_num_inputs(1) .set_num_inputs(3)
.add_argument("data", "Tensor", "The tensor to quantize.") .add_argument("data", "Tensor", "The tensor to quantize.")
.add_argument("output_scale", "Tensor", "The quantization scale of the output tensor.")
.add_argument("output_zero_point", "Tensor", "The quantization zero_point of the output tensor.")
.set_support_level(11) .set_support_level(11)
.add_type_rel("Quantize", QuantizeRel) .add_type_rel("Quantize", QuantizeRel)
.set_attr<FTVMLegalize>("FTVMQnnCanonicalize", QuantizeQnnCanonicalize); .set_attr<FTVMLegalize>("FTVMQnnCanonicalize", QuantizeQnnCanonicalize);
......
...@@ -54,31 +54,35 @@ TVM_REGISTER_NODE_TYPE(RequantizeAttrs); ...@@ -54,31 +54,35 @@ TVM_REGISTER_NODE_TYPE(RequantizeAttrs);
* 4) Add the output zero point. * 4) Add the output zero point.
* 5) Cast to the out_dtype. * 5) Cast to the out_dtype.
*/ */
Expr RequantizeLower(const Expr& input_tensor, const RequantizeAttrs* param, Expr RequantizeLower(const Expr& input_tensor, const Expr& input_scale,
const Expr& input_zero_point, const Expr& output_scale,
const Expr& output_zero_point, const RequantizeAttrs* param,
const Array<IndexExpr>& input_shape, const DataType& out_dtype) { const Array<IndexExpr>& input_shape, const DataType& out_dtype) {
double double_multiplier = param->input_scale / param->output_scale; float input_scale_float = GetScalarFromConstant<float>(input_scale);
float output_scale_float = GetScalarFromConstant<float>(output_scale);
double double_multiplier =
static_cast<double>(input_scale_float) / static_cast<double>(output_scale_float);
DataType hp_dtype = DataType::Int(64); DataType hp_dtype = DataType::Int(64);
auto tensor = Cast(input_tensor, hp_dtype); auto tensor = Cast(input_tensor, hp_dtype);
// 1) Subtract the input_zero_point // 1) Subtract the input_zero_point
if (param->input_zero_point != 0) { auto zero_scalar = MakeConstantScalar(DataType::Int(32), 0);
auto input_zp = MakeConstantScalar(hp_dtype, param->input_zero_point); if (!IsEqualScalar(input_zero_point, zero_scalar)) {
tensor = Subtract(tensor, input_zp); tensor = Subtract(tensor, Cast(input_zero_point, hp_dtype));
} }
// 2) If the input and output scales are same, we can skip the fixed point multiplication. // 2) If the input and output scales are same, we can skip the fixed point multiplication.
auto scaled_int64_t = tensor; auto scaled_int64_t = tensor;
if (param->input_scale != param->output_scale) { if (!IsEqualScalar(input_scale, output_scale)) {
scaled_int64_t = scaled_int64_t =
FixedPointMultiply(scaled_int64_t, double_multiplier, input_shape, param->rounding); FixedPointMultiply(scaled_int64_t, double_multiplier, input_shape, param->rounding);
} }
// 3) Add the output zero point. // 3) Add the output zero point.
auto shifted_int64_t = scaled_int64_t; auto shifted_int64_t = scaled_int64_t;
if (param->output_zero_point != 0) { if (!IsEqualScalar(output_zero_point, zero_scalar)) {
auto output_zp = MakeConstantScalar(hp_dtype, param->output_zero_point); shifted_int64_t = Add(Cast(output_zero_point, hp_dtype), scaled_int64_t);
shifted_int64_t = Add(output_zp, scaled_int64_t);
} }
// 4) Clip to the out_dtype min/max. // 4) Clip to the out_dtype min/max.
...@@ -103,13 +107,17 @@ Expr RequantizeLower(const Expr& input_tensor, const RequantizeAttrs* param, ...@@ -103,13 +107,17 @@ Expr RequantizeLower(const Expr& input_tensor, const RequantizeAttrs* param,
*/ */
Expr RequantizeQnnCanonicalize(const Attrs& attrs, const Array<Expr>& new_args, Expr RequantizeQnnCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
const Array<tvm::relay::Type>& types) { const Array<tvm::relay::Type>& types) {
CHECK_EQ(new_args.size(), 1); CHECK_EQ(new_args.size(), 5);
auto& quantized_data = new_args[0]; auto& quantized_data = new_args[0];
auto& input_scale = new_args[1];
auto& input_zero_point = new_args[2];
auto& output_scale = new_args[3];
auto& output_zero_point = new_args[4];
const auto* param = attrs.as<RequantizeAttrs>(); const auto* param = attrs.as<RequantizeAttrs>();
CHECK(param != nullptr); CHECK(param != nullptr);
// Find input shape. // Find input shape.
CHECK_EQ(types.size(), 2); CHECK_EQ(types.size(), 6);
auto in_type = types[0]; auto in_type = types[0];
auto in_tensor_type = in_type.as<TensorTypeNode>(); auto in_tensor_type = in_type.as<TensorTypeNode>();
CHECK(in_tensor_type != nullptr) << "Type information missing." CHECK(in_tensor_type != nullptr) << "Type information missing."
...@@ -117,7 +125,7 @@ Expr RequantizeQnnCanonicalize(const Attrs& attrs, const Array<Expr>& new_args, ...@@ -117,7 +125,7 @@ Expr RequantizeQnnCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
Array<IndexExpr> input_shape = in_tensor_type->shape; Array<IndexExpr> input_shape = in_tensor_type->shape;
// Find the output dtype. // Find the output dtype.
auto out_type = types[1]; auto out_type = types[5];
auto out_tensor_type = out_type.as<TensorTypeNode>(); auto out_tensor_type = out_type.as<TensorTypeNode>();
CHECK(out_tensor_type != nullptr) << "Type information missing." CHECK(out_tensor_type != nullptr) << "Type information missing."
<< " Please run infer_type pass."; << " Please run infer_type pass.";
...@@ -127,7 +135,8 @@ Expr RequantizeQnnCanonicalize(const Attrs& attrs, const Array<Expr>& new_args, ...@@ -127,7 +135,8 @@ Expr RequantizeQnnCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
CHECK(param->rounding == "UPWARD" || param->rounding == "TONEAREST") CHECK(param->rounding == "UPWARD" || param->rounding == "TONEAREST")
<< "QNN requantize supports two rounding modes - UPWARD and " << "QNN requantize supports two rounding modes - UPWARD and "
<< "TONEAREST"; << "TONEAREST";
return RequantizeLower(quantized_data, param, input_shape, out_dtype); return RequantizeLower(quantized_data, input_scale, input_zero_point, output_scale,
output_zero_point, param, input_shape, out_dtype);
} }
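A hedged floating-point sketch of the five requantize steps above; the real lowering works in int64 via FixedPointMultiply with UPWARD/TONEAREST rounding, whereas this reference just uses `np.round`:

```python
import numpy as np

def requantize_ref(q, in_scale, in_zp, out_scale, out_zp,
                   qmin=-128, qmax=127, out_dtype=np.int8):
    # 1) subtract input zero point; 2) rescale by in_scale/out_scale;
    # 3) add output zero point; 4) clip; 5) cast.
    real = (q.astype(np.int64) - in_zp) * (in_scale / out_scale)
    shifted = np.round(real) + out_zp
    return np.clip(shifted, qmin, qmax).astype(out_dtype)
```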
/* /*
...@@ -140,7 +149,7 @@ Expr RequantizeQnnCanonicalize(const Attrs& attrs, const Array<Expr>& new_args, ...@@ -140,7 +149,7 @@ Expr RequantizeQnnCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
*/ */
bool RequantizeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs, bool RequantizeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) { const TypeReporter& reporter) {
CHECK_EQ(types.size(), 2); CHECK_EQ(types.size(), 6);
const auto* data = types[0].as<TensorTypeNode>(); const auto* data = types[0].as<TensorTypeNode>();
const auto in_dtype = data->dtype; const auto in_dtype = data->dtype;
CHECK(in_dtype == DataType::Int(8) || CHECK(in_dtype == DataType::Int(8) ||
...@@ -148,6 +157,12 @@ bool RequantizeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs, ...@@ -148,6 +157,12 @@ bool RequantizeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
in_dtype == DataType::Int(32)) in_dtype == DataType::Int(32))
<< "Input type should be one of [int8, uint8, int32] but was " << in_dtype; << "Input type should be one of [int8, uint8, int32] but was " << in_dtype;
// Check the types of scale and zero points.
CHECK(IsScalarType(types[1], DataType::Float(32))); // input_scale
CHECK(IsScalarType(types[2], DataType::Int(32))); // input_zero_point
CHECK(IsScalarType(types[3], DataType::Float(32))); // output_scale
CHECK(IsScalarType(types[4], DataType::Int(32))); // output_zero_point
const Array<tvm::Expr> oshape = data->shape; const Array<tvm::Expr> oshape = data->shape;
// assign output type // assign output type
const RequantizeAttrs* param = attrs.as<RequantizeAttrs>(); const RequantizeAttrs* param = attrs.as<RequantizeAttrs>();
...@@ -156,23 +171,20 @@ bool RequantizeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs, ...@@ -156,23 +171,20 @@ bool RequantizeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
out_dtype == DataType::UInt(8) || out_dtype == DataType::UInt(8) ||
out_dtype == DataType::Int(32)) out_dtype == DataType::Int(32))
<< "Output type should be one of [int8, uint8, int32] but was " << out_dtype; << "Output type should be one of [int8, uint8, int32] but was " << out_dtype;
reporter->Assign(types[1], TensorTypeNode::make(oshape, out_dtype)); reporter->Assign(types[5], TensorTypeNode::make(oshape, out_dtype));
return true; return true;
} }
// Positional relay function to create qnn requantize operator // Positional relay function to create qnn requantize operator
// used by frontend FFI. // used by frontend FFI.
Expr MakeRequantize(Expr data, double input_scale, int32_t input_zero_point, double output_scale, Expr MakeRequantize(Expr data, Expr input_scale, Expr input_zero_point, Expr output_scale,
int32_t output_zero_point, std::string rounding, DataType out_dtype) { Expr output_zero_point, std::string rounding, DataType out_dtype) {
auto attrs = make_object<RequantizeAttrs>(); auto attrs = make_object<RequantizeAttrs>();
attrs->input_scale = std::move(input_scale);
attrs->input_zero_point = std::move(input_zero_point);
attrs->output_scale = std::move(output_scale);
attrs->output_zero_point = std::move(output_zero_point);
attrs->rounding = std::move(rounding); attrs->rounding = std::move(rounding);
attrs->out_dtype = std::move(out_dtype); attrs->out_dtype = std::move(out_dtype);
static const Op& op = Op::Get("qnn.requantize"); static const Op& op = Op::Get("qnn.requantize");
return CallNode::make(op, {data}, Attrs(attrs), {}); return CallNode::make(op, {data, input_scale, input_zero_point, output_scale, output_zero_point},
Attrs(attrs), {});
} }
RELAY_REGISTER_OP("qnn.requantize") RELAY_REGISTER_OP("qnn.requantize")
...@@ -185,8 +197,12 @@ Q_output = zp_output + (scale_input)/(scale_output) * (Q_input - zp_input) ...@@ -185,8 +197,12 @@ Q_output = zp_output + (scale_input)/(scale_output) * (Q_input - zp_input)
)code" TVM_ADD_FILELINE) )code" TVM_ADD_FILELINE)
.set_attrs_type<RequantizeAttrs>() .set_attrs_type<RequantizeAttrs>()
.set_num_inputs(1) .set_num_inputs(5)
.add_argument("data", "Tensor", "The quantized input tensor.") .add_argument("data", "Tensor", "The quantized input tensor.")
.add_argument("input_scale", "Tensor", "The quantization scale of the input tensor.")
.add_argument("input_zero_point", "Tensor", "The quantization zero_point of the input tensor.")
.add_argument("output_scale", "Tensor", "The quantization scale of the output tensor.")
.add_argument("output_zero_point", "Tensor", "The quantization zero_point of the output tensor.")
.set_support_level(11) .set_support_level(11)
.add_type_rel("Requantize", RequantizeRel) .add_type_rel("Requantize", RequantizeRel)
.set_attr<FTVMLegalize>("FTVMQnnCanonicalize", RequantizeQnnCanonicalize); .set_attr<FTVMLegalize>("FTVMQnnCanonicalize", RequantizeQnnCanonicalize);
......
...@@ -77,21 +77,20 @@ static inline const int32_t GetQmax(const DataType& dtype) { ...@@ -77,21 +77,20 @@ static inline const int32_t GetQmax(const DataType& dtype) {
} }
} }
Expr RequantizeLower(const Expr& input_tensor, const RequantizeAttrs* param, Expr RequantizeLower(const Expr& input_tensor, const Expr& input_scale,
const Expr& input_zero_point, const Expr& output_scale,
const Expr& output_zero_point, const RequantizeAttrs* param,
const Array<IndexExpr>& input_shape, const DataType& out_dtype); const Array<IndexExpr>& input_shape, const DataType& out_dtype);
static inline Expr Requantize(const Expr& data, const Array<IndexExpr>& input_shape, static inline Expr Requantize(const Expr& data, const Array<IndexExpr>& input_shape,
double input_scale, int32_t input_zero_point, double output_scale, const Expr& input_scale, const Expr& input_zero_point,
int32_t output_zero_point, const DataType& out_dtype, const Expr& output_scale, const Expr& output_zero_point,
const std::string& rounding = "UPWARD") { const DataType& out_dtype, const std::string& rounding = "UPWARD") {
auto attrs = make_object<RequantizeAttrs>(); auto attrs = make_object<RequantizeAttrs>();
attrs->input_scale = std::move(input_scale);
attrs->input_zero_point = std::move(input_zero_point);
attrs->output_scale = std::move(output_scale);
attrs->output_zero_point = std::move(output_zero_point);
attrs->rounding = std::move(rounding); attrs->rounding = std::move(rounding);
attrs->out_dtype = std::move(out_dtype); attrs->out_dtype = std::move(out_dtype);
return RequantizeLower(data, attrs.operator->(), input_shape, out_dtype); return RequantizeLower(data, input_scale, input_zero_point, output_scale, output_zero_point,
attrs.operator->(), input_shape, out_dtype);
} }
static inline int64_t get_const_int(const tvm::Expr& x) { static inline int64_t get_const_int(const tvm::Expr& x) {
...@@ -122,10 +121,22 @@ static inline int64_t get_const_int(const tvm::Expr& x) { ...@@ -122,10 +121,22 @@ static inline int64_t get_const_int(const tvm::Expr& x) {
* 2) Round the result. * 2) Round the result.
* 3) Right shift the result * 3) Right shift the result
*/ */
Expr FixedPointMultiply(Expr tensor, double multiplier, Expr FixedPointMultiply(Expr tensor, double multiplier, const Array<IndexExpr>& input_shape,
const Array<IndexExpr>& input_shape,
const std::string& rounding); const std::string& rounding);
/*!
* \brief Checks whether an expr type is scalar of a given data type.
* \param expr_type The type of expr to be checked.
* \param dtype The expected dtype.
* \return True if the type is a scalar of given dtype
*/
static inline bool IsScalarType(const Type& expr_type, const DataType& dtype) {
const auto* scale = expr_type.as<TensorTypeNode>();
CHECK(scale != nullptr) << "Expected a TensorType for the quantization parameter.";
CHECK_EQ(scale->shape.size(), 0);
CHECK(scale->dtype == dtype) << "Expected " << dtype << " but got " << scale->dtype;
return true;
}
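On the Python side, the expr-based API therefore expects rank-0 constants of the matching dtype; a quick illustrative check (standard relay constant API):

```python
from tvm import relay

scale = relay.const(0.5, 'float32')     # rank-0 float32: passes IsScalarType
zero_point = relay.const(127, 'int32')  # rank-0 int32: passes IsScalarType
# A float64 or non-scalar constant would trip the CHECKs above.
```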
} // namespace qnn } // namespace qnn
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
......
...@@ -27,12 +27,12 @@ def test_tflite_same_io_qnn_params(): ...@@ -27,12 +27,12 @@ def test_tflite_same_io_qnn_params():
x = relay.var("x", shape=(1, 4), dtype=data_dtype) x = relay.var("x", shape=(1, 4), dtype=data_dtype)
y = relay.var("y", shape=(1, 4), dtype=data_dtype) y = relay.var("y", shape=(1, 4), dtype=data_dtype)
z = relay.qnn.op.add(lhs=x, rhs=y, z = relay.qnn.op.add(lhs=x, rhs=y,
lhs_scale=0.00784314, lhs_scale=relay.const(0.00784314, 'float32'),
lhs_zero_point=127, lhs_zero_point=relay.const(127, 'int32'),
rhs_scale=0.00784314, rhs_scale=relay.const(0.00784314, 'float32'),
rhs_zero_point=127, rhs_zero_point=relay.const(127, 'int32'),
output_scale=0.00784314, output_scale=relay.const(0.00784314, 'float32'),
output_zero_point=127) output_zero_point=relay.const(127, 'int32'))
func = relay.Function([x, y], z) func = relay.Function([x, y], z)
mod = relay.Module.from_expr(func) mod = relay.Module.from_expr(func)
...@@ -65,12 +65,12 @@ def test_tflite_different_io_qnn_params(): ...@@ -65,12 +65,12 @@ def test_tflite_different_io_qnn_params():
x = relay.var("x", shape=(1, 4), dtype=data_dtype) x = relay.var("x", shape=(1, 4), dtype=data_dtype)
y = relay.var("y", shape=(1, 4), dtype=data_dtype) y = relay.var("y", shape=(1, 4), dtype=data_dtype)
z = relay.qnn.op.add(lhs=x, rhs=y, z = relay.qnn.op.add(lhs=x, rhs=y,
lhs_scale=0.0156863, lhs_scale=relay.const(0.0156863, 'float32'),
lhs_zero_point=127, lhs_zero_point=relay.const(127, 'int32'),
rhs_scale=0.0117647, rhs_scale=relay.const(0.0117647, 'float32'),
rhs_zero_point=85, rhs_zero_point=relay.const(85, 'int32'),
output_scale=0.0235294, output_scale=relay.const(0.0235294, 'float32'),
output_zero_point=128) output_zero_point=relay.const(128, 'int32'))
func = relay.Function([x, y], z) func = relay.Function([x, y], z)
mod = relay.Module.from_expr(func) mod = relay.Module.from_expr(func)
...@@ -103,12 +103,12 @@ def test_saturation(): ...@@ -103,12 +103,12 @@ def test_saturation():
x = relay.var("x", shape=(1, 4), dtype=data_dtype) x = relay.var("x", shape=(1, 4), dtype=data_dtype)
y = relay.var("y", shape=(1, 4), dtype=data_dtype) y = relay.var("y", shape=(1, 4), dtype=data_dtype)
z = relay.qnn.op.add(lhs=x, rhs=y, z = relay.qnn.op.add(lhs=x, rhs=y,
lhs_scale=0.125, lhs_scale=relay.const(0.125, 'float32'),
lhs_zero_point=0, lhs_zero_point=relay.const(0, 'int32'),
rhs_scale=0.125, rhs_scale=relay.const(0.125, 'float32'),
rhs_zero_point=0, rhs_zero_point=relay.const(0, 'int32'),
output_scale=0.125, output_scale=relay.const(0.125, 'float32'),
output_zero_point=0) output_zero_point=relay.const(0, 'int32'))
func = relay.Function([x, y], z) func = relay.Function([x, y], z)
mod = relay.Module.from_expr(func) mod = relay.Module.from_expr(func)
...@@ -125,12 +125,12 @@ def test_saturation(): ...@@ -125,12 +125,12 @@ def test_saturation():
# Same params, different scale # Same params, different scale
z = relay.qnn.op.add(lhs=x, rhs=y, z = relay.qnn.op.add(lhs=x, rhs=y,
lhs_scale=0.125, lhs_scale=relay.const(0.125, 'float32'),
lhs_zero_point=0, lhs_zero_point=relay.const(0, 'int32'),
rhs_scale=0.125, rhs_scale=relay.const(0.125, 'float32'),
rhs_zero_point=0, rhs_zero_point=relay.const(0, 'int32'),
output_scale=0.25, output_scale=relay.const(0.25, 'float32'),
output_zero_point=0) output_zero_point=relay.const(0, 'int32'))
func = relay.Function([x, y], z) func = relay.Function([x, y], z)
mod = relay.Module.from_expr(func) mod = relay.Module.from_expr(func)
...@@ -147,12 +147,12 @@ def test_saturation(): ...@@ -147,12 +147,12 @@ def test_saturation():
# Same io params, different output scale # Same io params, different output scale
z = relay.qnn.op.add(lhs=x, rhs=y, z = relay.qnn.op.add(lhs=x, rhs=y,
lhs_scale=0.125, lhs_scale=relay.const(0.125, 'float32'),
lhs_zero_point=0, lhs_zero_point=relay.const(0, 'int32'),
rhs_scale=0.125, rhs_scale=relay.const(0.125, 'float32'),
rhs_zero_point=0, rhs_zero_point=relay.const(0, 'int32'),
output_scale=0.25, output_scale=relay.const(0.25, 'float32'),
output_zero_point=0) output_zero_point=relay.const(0, 'int32'))
func = relay.Function([x, y], z) func = relay.Function([x, y], z)
mod = relay.Module.from_expr(func) mod = relay.Module.from_expr(func)
...@@ -169,12 +169,12 @@ def test_saturation(): ...@@ -169,12 +169,12 @@ def test_saturation():
# All params different # All params different
z = relay.qnn.op.add(lhs=x, rhs=y, z = relay.qnn.op.add(lhs=x, rhs=y,
lhs_scale=0.5, lhs_scale=relay.const(0.5, 'float32'),
lhs_zero_point=0, lhs_zero_point=relay.const(0, 'int32'),
rhs_scale=0.25, rhs_scale=relay.const(0.25, 'float32'),
rhs_zero_point=0, rhs_zero_point=relay.const(0, 'int32'),
output_scale=0.125, output_scale=relay.const(0.125, 'float32'),
output_zero_point=0) output_zero_point=relay.const(0, 'int32'))
func = relay.Function([x, y], z) func = relay.Function([x, y], z)
mod = relay.Module.from_expr(func) mod = relay.Module.from_expr(func)
......
...@@ -26,16 +26,17 @@ def test_same_io_qnn_params(): ...@@ -26,16 +26,17 @@ def test_same_io_qnn_params():
axis = 0 axis = 0
x_data = np.arange(-32, 32, 1).reshape(1, 64).astype(data_dtype) x_data = np.arange(-32, 32, 1).reshape(1, 64).astype(data_dtype)
y_data = np.arange(-64, 64, 2).reshape(1, 64).astype(data_dtype) y_data = np.arange(-64, 64, 2).reshape(1, 64).astype(data_dtype)
x_scale = (62 + 64) / (np.power(2, 32) - 1.0) x_scale = relay.const((62 + 64) / (np.power(2, 32) - 1.0), 'float32')
y_scale = (62 + 64) / (np.power(2, 32) - 1.0) y_scale = relay.const((62 + 64) / (np.power(2, 32) - 1.0), 'float32')
zero = relay.const(0, 'int32')
x = relay.var("x", shape=(1, 64), dtype=data_dtype) x = relay.var("x", shape=(1, 64), dtype=data_dtype)
y = relay.var("y", shape=(1, 64), dtype=data_dtype) y = relay.var("y", shape=(1, 64), dtype=data_dtype)
z = relay.qnn.op.concatenate((x, y), z = relay.qnn.op.concatenate((x, y),
input_scales=[x_scale, y_scale], input_scales=(x_scale, y_scale),
input_zero_points=[0, 0], input_zero_points=(zero, zero),
output_scale=y_scale, output_scale=y_scale,
output_zero_point=0, output_zero_point=zero,
axis=axis) axis=axis)
func = relay.Function([x, y], z) func = relay.Function([x, y], z)
...@@ -54,16 +55,19 @@ def test_different_io_qnn_params(): ...@@ -54,16 +55,19 @@ def test_different_io_qnn_params():
axis = 0 axis = 0
x_data = np.arange(-32, 32, 1).reshape(1, 64).astype(data_dtype) x_data = np.arange(-32, 32, 1).reshape(1, 64).astype(data_dtype)
y_data = np.arange(-64, 64, 2).reshape(1, 64).astype(data_dtype) y_data = np.arange(-64, 64, 2).reshape(1, 64).astype(data_dtype)
x_scale = (62 + 64) / (np.power(2, 32) - 1.0)
y_scale = (62 + 64) / (np.power(2, 32) - 1.0) x_scale = relay.const((62 + 64) / (np.power(2, 32) - 1.0), 'float32')
y_scale = relay.const((62 + 64) / (np.power(2, 32) - 1.0), 'float32')
x_zero_point = relay.const(3, 'int32')
y_zero_point = relay.const(4, 'int32')
x = relay.var("x", shape=(1, 64), dtype=data_dtype) x = relay.var("x", shape=(1, 64), dtype=data_dtype)
y = relay.var("y", shape=(1, 64), dtype=data_dtype) y = relay.var("y", shape=(1, 64), dtype=data_dtype)
z = relay.qnn.op.concatenate((x, y), z = relay.qnn.op.concatenate((x, y),
input_scales=[x_scale, y_scale], input_scales=(x_scale, y_scale),
input_zero_points=[3, 4], input_zero_points=(x_zero_point, y_zero_point),
output_scale=y_scale, output_scale=y_scale,
output_zero_point=1, output_zero_point=relay.const(1, 'int32'),
axis=axis) axis=axis)
func = relay.Function([x, y], z) func = relay.Function([x, y], z)
...@@ -82,16 +86,19 @@ def test_few_same_io_qnn_params(): ...@@ -82,16 +86,19 @@ def test_few_same_io_qnn_params():
axis = 0 axis = 0
x_data = np.arange(-32, 32, 1).reshape(1, 64).astype(data_dtype) x_data = np.arange(-32, 32, 1).reshape(1, 64).astype(data_dtype)
y_data = np.arange(-64, 64, 2).reshape(1, 64).astype(data_dtype) y_data = np.arange(-64, 64, 2).reshape(1, 64).astype(data_dtype)
x_scale = (62 + 64) / (np.power(2, 32) - 1.0)
y_scale = (62 + 64) / (np.power(2, 32) - 1.0) x_scale = relay.const((62 + 64) / (np.power(2, 32) - 1.0), 'float32')
y_scale = relay.const((62 + 64) / (np.power(2, 32) - 1.0), 'float32')
x_zero_point = relay.const(0, 'int32')
y_zero_point = relay.const(1, 'int32')
x = relay.var("x", shape=(1, 64), dtype=data_dtype) x = relay.var("x", shape=(1, 64), dtype=data_dtype)
y = relay.var("y", shape=(1, 64), dtype=data_dtype) y = relay.var("y", shape=(1, 64), dtype=data_dtype)
z = relay.qnn.op.concatenate((x, y), z = relay.qnn.op.concatenate((x, y),
input_scales=[x_scale, y_scale], input_scales=(x_scale, y_scale),
input_zero_points=[0, 1], input_zero_points=(x_zero_point, y_zero_point),
output_scale=y_scale, output_scale=y_scale,
output_zero_point=1, output_zero_point=relay.const(1, 'int32'),
axis=axis) axis=axis)
func = relay.Function([x, y], z) func = relay.Function([x, y], z)
...@@ -110,16 +117,19 @@ def test_same_i_qnn_params(): ...@@ -110,16 +117,19 @@ def test_same_i_qnn_params():
axis = 0 axis = 0
x_data = np.arange(-32, 32, 1).reshape(1, 64).astype(data_dtype) x_data = np.arange(-32, 32, 1).reshape(1, 64).astype(data_dtype)
y_data = np.arange(-64, 64, 2).reshape(1, 64).astype(data_dtype) y_data = np.arange(-64, 64, 2).reshape(1, 64).astype(data_dtype)
x_scale = (62 + 64) / (np.power(2, 32) - 1.0)
y_scale = (62 + 64) / (np.power(2, 32) - 1.0) x_scale = relay.const((62 + 64) / (np.power(2, 32) - 1.0), 'float32')
y_scale = relay.const((62 + 64) / (np.power(2, 32) - 1.0), 'float32')
x_zero_point = relay.const(0, 'int32')
y_zero_point = relay.const(0, 'int32')
x = relay.var("x", shape=(1, 64), dtype=data_dtype) x = relay.var("x", shape=(1, 64), dtype=data_dtype)
y = relay.var("y", shape=(1, 64), dtype=data_dtype) y = relay.var("y", shape=(1, 64), dtype=data_dtype)
z = relay.qnn.op.concatenate((x, y), z = relay.qnn.op.concatenate((x, y),
input_scales=[x_scale, y_scale], input_scales=(x_scale, y_scale),
input_zero_points=[0, 0], input_zero_points=(x_zero_point, y_zero_point),
output_scale=y_scale, output_scale=y_scale,
output_zero_point=1, output_zero_point=relay.const(1, 'int32'),
axis=axis) axis=axis)
func = relay.Function([x, y], z) func = relay.Function([x, y], z)
......
...@@ -52,16 +52,16 @@ def get_ref_func(data, ...@@ -52,16 +52,16 @@ def get_ref_func(data,
shifted_kernel = relay.op.subtract(casted_kernel, shifted_kernel = relay.op.subtract(casted_kernel,
relay.const(kernel_zero_point, "int32")) relay.const(kernel_zero_point, "int32"))
func = relay.op.nn.conv2d(shifted_data, func = relay.op.nn.conv2d(shifted_data,
shifted_kernel, shifted_kernel,
padding=padding, padding=padding,
strides=strides, strides=strides,
dilation=dilation, dilation=dilation,
groups=groups, groups=groups,
channels=channels, channels=channels,
kernel_size=kernel_size, kernel_size=kernel_size,
out_dtype=out_dtype, out_dtype=out_dtype,
data_layout=data_layout, data_layout=data_layout,
kernel_layout=kernel_layout) kernel_layout=kernel_layout)
func = relay.Function(relay.analysis.free_vars(func), func) func = relay.Function(relay.analysis.free_vars(func), func)
return func return func
...@@ -83,10 +83,10 @@ def get_qnn_func(data, ...@@ -83,10 +83,10 @@ def get_qnn_func(data,
channels=None): channels=None):
func = relay.qnn.op.conv2d( func = relay.qnn.op.conv2d(
data, kernel, data, kernel,
input_zero_point=input_zero_point, input_zero_point=relay.const(input_zero_point, 'int32'),
kernel_zero_point=kernel_zero_point, kernel_zero_point=relay.const(kernel_zero_point, 'int32'),
input_scale=input_scale, input_scale=relay.const(input_scale, 'float32'),
kernel_scale=kernel_scale, kernel_scale=relay.const(kernel_scale, 'float32'),
kernel_size=kernel_size, kernel_size=kernel_size,
strides=strides, strides=strides,
dilation=dilation, dilation=dilation,
......
...@@ -179,10 +179,10 @@ def qnn_dense_driver(test_configuration): ...@@ -179,10 +179,10 @@ def qnn_dense_driver(test_configuration):
mod = relay.qnn.op.dense( mod = relay.qnn.op.dense(
quantized_data, quantized_data,
quantized_kernel, quantized_kernel,
test_configuration['input_zero_point'], relay.const(test_configuration['input_zero_point'], 'int32'),
test_configuration['kernel_zero_point'], relay.const(test_configuration['kernel_zero_point'], 'int32'),
test_configuration['input_scale'], relay.const(test_configuration['input_scale'], 'float32'),
test_configuration['kernel_scale'], relay.const(test_configuration['kernel_scale'], 'float32'),
test_configuration['units']) test_configuration['units'])
if test_configuration[bias_name] is not None: if test_configuration[bias_name] is not None:
bias = relay.var(bias_name, bias = relay.var(bias_name,
...@@ -193,10 +193,10 @@ def qnn_dense_driver(test_configuration): ...@@ -193,10 +193,10 @@ def qnn_dense_driver(test_configuration):
requantize_config = test_configuration['requantize'] requantize_config = test_configuration['requantize']
mod = relay.qnn.op.requantize( mod = relay.qnn.op.requantize(
mod, mod,
input_scale=requantize_config['input_scale'], input_scale=relay.const(requantize_config['input_scale'], 'float32'),
input_zero_point=0, input_zero_point=relay.const(0, 'int32'),
output_scale=requantize_config['output_scale'], output_scale=relay.const(requantize_config['output_scale'], 'float32'),
output_zero_point=requantize_config['output_zero_point'], output_zero_point=relay.const(requantize_config['output_zero_point'], 'int32'),
out_dtype=requantize_config['out_dtype']) out_dtype=requantize_config['out_dtype'])
expected_out_dtype = requantize_config['out_dtype'] expected_out_dtype = requantize_config['out_dtype']
......
...@@ -20,61 +20,56 @@ import numpy as np ...@@ -20,61 +20,56 @@ import numpy as np
from tvm import relay from tvm import relay
from tvm.contrib import graph_runtime from tvm.contrib import graph_runtime
def test_dequantize_op(): def quantize_test_driver(in_dtype, quant_args, in_data, verify_output_data):
shape = in_data.shape
input_data = relay.var("input_data", shape=shape, dtype=in_dtype)
input_zero_point = relay.const(quant_args['in_zero_point'], 'int32')
input_scale = relay.const(quant_args['in_scale'], 'float32')
quantized_output = relay.qnn.op.dequantize(input_data, input_scale=input_scale,
input_zero_point=input_zero_point)
mod = relay.Function(relay.analysis.free_vars(quantized_output), quantized_output)
mod = relay.Module.from_expr(mod)
with relay.build_config(opt_level=3):
graph, lib, params = relay.build(mod, "llvm", params=None)
rt_mod = graph_runtime.create(graph, lib, ctx=tvm.cpu(0))
rt_mod.set_input(input_data=in_data)
rt_mod.set_input(**params)
rt_mod.run()
res = rt_mod.get_output(0).asnumpy()
np.testing.assert_equal(res, verify_output_data)
assert res.dtype == np.float32
def quantize_test_driver(in_dtype, quant_args, in_data, verify_output_data): def test_uint8_to_float32():
shape = in_data.shape data = np.array([0, 1, 2, 3, 4, 251, 252, 253, 254, 255]) \
input_data = relay.var("input_data", shape=shape, dtype=in_dtype) .astype('uint8') \
input_zero_point = quant_args['in_zero_point'] .reshape((2, 5))
input_scale = quant_args['in_scale'] output = np.array([-63.5, -63, -62.5, -62, -61.5, 62, 62.5, 63, 63.5, 64]) \
quantized_output = relay.qnn.op.dequantize(input_data, input_scale=input_scale, .astype('float32') \
input_zero_point=input_zero_point) .reshape((2, 5))
mod = relay.Function(relay.analysis.free_vars(quantized_output), quantized_output) quant_args = {"in_zero_point":127, "in_scale":0.5}
mod = relay.Module.from_expr(mod) quantize_test_driver(in_dtype='uint8', quant_args=quant_args, in_data=data,
with relay.build_config(opt_level=3): verify_output_data=output)
graph, lib, params = relay.build(mod, "llvm", params=None)
rt_mod = graph_runtime.create(graph, lib, ctx=tvm.cpu(0))
rt_mod.set_input(input_data=in_data)
rt_mod.set_input(**params)
rt_mod.run()
res = rt_mod.get_output(0).asnumpy()
np.testing.assert_equal(res, verify_output_data)
assert res.dtype == np.float32
def test_uint8_to_float32(): def test_int8_to_float32():
data = np.array([0, 1, 2, 3, 4, 251, 252, 253, 254, 255]) \ data = np.array([-128, -127, -126, -125, -124, 123, 124, 125, 126, 127]) \
.astype('uint8') \ .astype('int8') \
.reshape((2, 5)) .reshape((2, 5))
output = np.array([-63.5, -63, -62.5, -62, -61.5, 62, 62.5, 63, 63.5, 64]) \ output = np.array([-63.5, -63, -62.5, -62, -61.5, 62, 62.5, 63, 63.5, 64]) \
.astype('float32') \ .astype('float32') \
.reshape((2, 5)) .reshape((2, 5))
quant_args = {"in_zero_point":127, "in_scale":0.5} quant_args = {"in_zero_point": -1, "in_scale": 0.5}
quantize_test_driver(in_dtype='uint8', quant_args=quant_args, in_data=data, quantize_test_driver(in_dtype='int8', quant_args=quant_args, in_data=data,
verify_output_data=output) verify_output_data=output)
def test_int8_to_float32(): def test_int32_to_float32():
data = np.array([-128, -127, -126, -125, -124, 123, 124, 125, 126, 127]) \ data = np.array([113, 29, -1052]).astype('int32')
.astype('int8') \ output = np.array([0.6550452, 0.16810896, -6.098297]).astype('float32')
.reshape((2, 5)) quant_args = {"in_zero_point": 0, "in_scale": 0.0057968604}
output = np.array([-63.5, -63, -62.5, -62, -61.5, 62, 62.5, 63, 63.5, 64]) \ quantize_test_driver(in_dtype='int32', quant_args=quant_args, in_data=data,
.astype('float32') \ verify_output_data=output)
.reshape((2, 5))
quant_args = {"in_zero_point": -1, "in_scale": 0.5}
quantize_test_driver(in_dtype='int8', quant_args=quant_args, in_data=data,
verify_output_data=output)
def test_int32_to_float32():
data = np.array([113, 29, -1052]).astype('int32')
output = np.array([0.6550452, 0.16810896, -6.098297]).astype('float32')
quant_args = {"in_zero_point": 0, "in_scale": 0.0057968604}
quantize_test_driver(in_dtype='int32', quant_args=quant_args, in_data=data,
verify_output_data=output)
if __name__ == "__main__":
test_uint8_to_float32() test_uint8_to_float32()
test_int8_to_float32() test_int8_to_float32()
test_int32_to_float32() test_int32_to_float32()
if __name__ == "__main__":
test_dequantize_op()
...@@ -44,12 +44,12 @@ def test_tflite_same_io_qnn_params(): ...@@ -44,12 +44,12 @@ def test_tflite_same_io_qnn_params():
x = relay.var("x", shape=(1, 4), dtype=data_dtype) x = relay.var("x", shape=(1, 4), dtype=data_dtype)
y = relay.var("y", shape=(1, 4), dtype=data_dtype) y = relay.var("y", shape=(1, 4), dtype=data_dtype)
z = relay.qnn.op.mul(lhs=x, rhs=y, z = relay.qnn.op.mul(lhs=x, rhs=y,
lhs_scale=lhs_scale, lhs_scale=relay.const(lhs_scale, 'float32'),
lhs_zero_point=lhs_zero_point, lhs_zero_point=relay.const(lhs_zero_point, 'int32'),
rhs_scale=rhs_scale, rhs_scale=relay.const(rhs_scale, 'float32'),
rhs_zero_point=rhs_zero_point, rhs_zero_point=relay.const(rhs_zero_point, 'int32'),
output_scale=output_scale, output_scale=relay.const(output_scale, 'float32'),
output_zero_point=output_zero_point) output_zero_point=relay.const(output_zero_point, 'int32'))
func = relay.Function([x, y], z) func = relay.Function([x, y], z)
mod = relay.Module.from_expr(func) mod = relay.Module.from_expr(func)
...@@ -95,12 +95,12 @@ def test_tflite_different_io_qnn_params(): ...@@ -95,12 +95,12 @@ def test_tflite_different_io_qnn_params():
x = relay.var("x", shape=(1, 4), dtype=data_dtype) x = relay.var("x", shape=(1, 4), dtype=data_dtype)
y = relay.var("y", shape=(1, 4), dtype=data_dtype) y = relay.var("y", shape=(1, 4), dtype=data_dtype)
z = relay.qnn.op.mul(lhs=x, rhs=y, z = relay.qnn.op.mul(lhs=x, rhs=y,
lhs_scale=lhs_scale, lhs_scale=relay.const(lhs_scale, 'float32'),
lhs_zero_point=lhs_zero_point, lhs_zero_point=relay.const(lhs_zero_point, 'int32'),
rhs_scale=rhs_scale, rhs_scale=relay.const(rhs_scale, 'float32'),
rhs_zero_point=rhs_zero_point, rhs_zero_point=relay.const(rhs_zero_point, 'int32'),
output_scale=output_scale, output_scale=relay.const(output_scale, 'float32'),
output_zero_point=output_zero_point) output_zero_point=relay.const(output_zero_point, 'int32'))
func = relay.Function([x, y], z) func = relay.Function([x, y], z)
mod = relay.Module.from_expr(func) mod = relay.Module.from_expr(func)
...@@ -141,12 +141,12 @@ def test_saturation(): ...@@ -141,12 +141,12 @@ def test_saturation():
x = relay.var("x", shape=(1, 4), dtype=data_dtype) x = relay.var("x", shape=(1, 4), dtype=data_dtype)
y = relay.var("y", shape=(1, 4), dtype=data_dtype) y = relay.var("y", shape=(1, 4), dtype=data_dtype)
z = relay.qnn.op.mul(lhs=x, rhs=y, z = relay.qnn.op.mul(lhs=x, rhs=y,
lhs_scale=lhs_scale, lhs_scale=relay.const(lhs_scale, 'float32'),
lhs_zero_point=lhs_zero_point, lhs_zero_point=relay.const(lhs_zero_point, 'int32'),
rhs_scale=rhs_scale, rhs_scale=relay.const(rhs_scale, 'float32'),
rhs_zero_point=rhs_zero_point, rhs_zero_point=relay.const(rhs_zero_point, 'int32'),
output_scale=output_scale, output_scale=relay.const(output_scale, 'float32'),
output_zero_point=output_zero_point) output_zero_point=relay.const(output_zero_point, 'int32'))
func = relay.Function([x, y], z) func = relay.Function([x, y], z)
mod = relay.Module.from_expr(func) mod = relay.Module.from_expr(func)
...@@ -172,12 +172,12 @@ def test_saturation(): ...@@ -172,12 +172,12 @@ def test_saturation():
output_scale = 0.25 output_scale = 0.25
z = relay.qnn.op.mul(lhs=x, rhs=y, z = relay.qnn.op.mul(lhs=x, rhs=y,
lhs_scale=lhs_scale, lhs_scale=relay.const(lhs_scale, 'float32'),
lhs_zero_point=lhs_zero_point, lhs_zero_point=relay.const(lhs_zero_point, 'int32'),
rhs_scale=rhs_scale, rhs_scale=relay.const(rhs_scale, 'float32'),
rhs_zero_point=rhs_zero_point, rhs_zero_point=relay.const(rhs_zero_point, 'int32'),
output_scale=output_scale, output_scale=relay.const(output_scale, 'float32'),
output_zero_point=output_zero_point) output_zero_point=relay.const(output_zero_point, 'int32'))
func = relay.Function([x, y], z) func = relay.Function([x, y], z)
mod = relay.Module.from_expr(func) mod = relay.Module.from_expr(func)
...@@ -204,12 +204,12 @@ def test_saturation(): ...@@ -204,12 +204,12 @@ def test_saturation():
output_scale = 0.125 output_scale = 0.125
z = relay.qnn.op.mul(lhs=x, rhs=y, z = relay.qnn.op.mul(lhs=x, rhs=y,
lhs_scale=lhs_scale, lhs_scale=relay.const(lhs_scale, 'float32'),
lhs_zero_point=lhs_zero_point, lhs_zero_point=relay.const(lhs_zero_point, 'int32'),
rhs_scale=rhs_scale, rhs_scale=relay.const(rhs_scale, 'float32'),
rhs_zero_point=rhs_zero_point, rhs_zero_point=relay.const(rhs_zero_point, 'int32'),
output_scale=output_scale, output_scale=relay.const(output_scale, 'float32'),
output_zero_point=output_zero_point) output_zero_point=relay.const(output_zero_point, 'int32'))
func = relay.Function([x, y], z) func = relay.Function([x, y], z)
mod = relay.Module.from_expr(func) mod = relay.Module.from_expr(func)
......
...@@ -20,51 +20,47 @@ import numpy as np ...@@ -20,51 +20,47 @@ import numpy as np
from tvm import relay from tvm import relay
from tvm.contrib import graph_runtime from tvm.contrib import graph_runtime
def test_quantize_op(): def quantize_test_driver(in_dtype, quant_args, out_dtype, in_data, verify_output_data):
shape = in_data.shape
input_data = relay.var("input_data", shape=shape, dtype=in_dtype)
output_zero_point = relay.const(quant_args['out_zero_point'], 'int32')
output_scale = relay.const(quant_args['out_scale'], 'float32')
quantized_output = relay.qnn.op.quantize(input_data, output_scale=output_scale,
output_zero_point=output_zero_point,out_dtype=out_dtype)
mod = relay.Function(relay.analysis.free_vars(quantized_output), quantized_output)
mod = relay.Module.from_expr(mod)
with relay.build_config(opt_level=3):
graph, lib, params = relay.build(mod, "llvm", params=None)
rt_mod = graph_runtime.create(graph, lib, ctx=tvm.cpu(0))
rt_mod.set_input(input_data=in_data)
rt_mod.set_input(**params)
rt_mod.run()
res = rt_mod.get_output(0).asnumpy()
np.testing.assert_equal(res, verify_output_data)
assert res.dtype == out_dtype
def quantize_test_driver(in_dtype, quant_args, out_dtype, in_data, verify_output_data): def test_float32_to_uint8():
shape = in_data.shape data = np.array([-63.5, -63, -62.5, -62, -61.5, 62, 62.5, 63, 63.5, 64]) \
input_data = relay.var("input_data", shape=shape, dtype=in_dtype) .astype('float32') \
output_zero_point = quant_args['out_zero_point'] .reshape((2,5))
output_scale = quant_args['out_scale'] output = np.array([0, 1, 2, 3, 4, 251, 252, 253, 254, 255]) \
quantized_output = relay.qnn.op.quantize(input_data, output_scale=output_scale, .astype('uint8') \
output_zero_point=output_zero_point,out_dtype=out_dtype) .reshape((2,5))
mod = relay.Function(relay.analysis.free_vars(quantized_output), quantized_output) quant_args = {"out_zero_point":127, "out_scale":0.5}
mod = relay.Module.from_expr(mod) quantize_test_driver(in_dtype='float32', quant_args=quant_args, out_dtype='uint8', in_data=data,
with relay.build_config(opt_level=3): verify_output_data=output)
graph, lib, params = relay.build(mod, "llvm", params=None)
rt_mod = graph_runtime.create(graph, lib, ctx=tvm.cpu(0))
rt_mod.set_input(input_data=in_data)
rt_mod.set_input(**params)
rt_mod.run()
res = rt_mod.get_output(0).asnumpy()
np.testing.assert_equal(res, verify_output_data)
assert res.dtype == out_dtype
def test_float32_to_uint8(): def test_float32_to_int8():
data = np.array([-63.5, -63, -62.5, -62, -61.5, 62, 62.5, 63, 63.5, 64]) \ data = np.array([-63.5, -63, -62.5, -62, -61.5, 62, 62.5, 63, 63.5, 64]) \
.astype('float32') \ .astype('float32') \
.reshape((2,5)) .reshape((2,5))
output = np.array([0, 1, 2, 3, 4, 251, 252, 253, 254, 255]) \ output = np.array([-128, -127, -126, -125, -124, 123, 124, 125, 126, 127]) \
.astype('uint8') \ .astype('int8') \
.reshape((2,5)) .reshape((2,5))
quant_args = {"out_zero_point":127, "out_scale":0.5} quant_args = {"out_zero_point":-1, "out_scale":0.5}
quantize_test_driver(in_dtype='float32', quant_args=quant_args, out_dtype='uint8', in_data=data, quantize_test_driver(in_dtype='float32', quant_args=quant_args, out_dtype='int8', in_data=data,
verify_output_data=output) verify_output_data=output)
def test_float32_to_int8():
data = np.array([-63.5, -63, -62.5, -62, -61.5, 62, 62.5, 63, 63.5, 64]) \
.astype('float32') \
.reshape((2,5))
output = np.array([-128, -127, -126, -125, -124, 123, 124, 125, 126, 127]) \
.astype('int8') \
.reshape((2,5))
quant_args = {"out_zero_point":-1, "out_scale":0.5}
quantize_test_driver(in_dtype='float32', quant_args=quant_args, out_dtype='int8', in_data=data,
verify_output_data=output)
if __name__ == "__main__":
test_float32_to_uint8() test_float32_to_uint8()
test_float32_to_int8() test_float32_to_int8()
if __name__ == "__main__":
test_quantize_op()
...@@ -39,10 +39,10 @@ def get_mod(data_shape, data_dtype, out_dtype, input_scale, output_scale, ...@@ -39,10 +39,10 @@ def get_mod(data_shape, data_dtype, out_dtype, input_scale, output_scale,
dtype=data_dtype) dtype=data_dtype)
mod = relay.qnn.op.requantize( mod = relay.qnn.op.requantize(
quantized_data, quantized_data,
input_scale=input_scale, input_scale=relay.const(input_scale, 'float32'),
input_zero_point=input_zero_point, input_zero_point=relay.const(input_zero_point, 'int32'),
output_scale=output_scale, output_scale=relay.const(output_scale, 'float32'),
output_zero_point=output_zero_point, output_zero_point=relay.const(output_zero_point, 'int32'),
rounding=rounding, rounding=rounding,
out_dtype=out_dtype) out_dtype=out_dtype)
......
...@@ -46,10 +46,10 @@ def test_qnn_legalize(): ...@@ -46,10 +46,10 @@ def test_qnn_legalize():
def before(): def before():
x = relay.var("x", shape=(1, 64, 56, 56), dtype='int8') x = relay.var("x", shape=(1, 64, 56, 56), dtype='int8')
y = relay.qnn.op.requantize(x, y = relay.qnn.op.requantize(x,
input_scale=1, input_scale=relay.const(1, 'float32'),
input_zero_point=0, input_zero_point=relay.const(0, 'int32'),
output_scale=1, output_scale=relay.const(1, 'float32'),
output_zero_point=0, output_zero_point=relay.const(0, 'int32'),
out_dtype='int8') out_dtype='int8')
y = relay.Function([x], y) y = relay.Function([x], y)
return y return y
...@@ -58,10 +58,10 @@ def test_qnn_legalize(): ...@@ -58,10 +58,10 @@ def test_qnn_legalize():
data = inputs[0] data = inputs[0]
data = relay.add(relay.const(0, 'int8'), data) data = relay.add(relay.const(0, 'int8'), data)
y = relay.qnn.op.requantize(data, y = relay.qnn.op.requantize(data,
input_scale=1, input_scale=relay.const(1, 'float32'),
input_zero_point=0, input_zero_point=relay.const(0, 'int32'),
output_scale=1, output_scale=relay.const(1, 'float32'),
output_zero_point=0, output_zero_point=relay.const(0, 'int32'),
out_dtype='int8') out_dtype='int8')
return y return y
...@@ -69,10 +69,10 @@ def test_qnn_legalize(): ...@@ -69,10 +69,10 @@ def test_qnn_legalize():
x = relay.var("x", shape=(1, 64, 56, 56), dtype='int8') x = relay.var("x", shape=(1, 64, 56, 56), dtype='int8')
y = relay.add(relay.const(0, 'int8'), x) y = relay.add(relay.const(0, 'int8'), x)
z = relay.qnn.op.requantize(y, z = relay.qnn.op.requantize(y,
input_scale=1, input_scale=relay.const(1, 'float32'),
input_zero_point=0, input_zero_point=relay.const(0, 'int32'),
output_scale=1, output_scale=relay.const(1, 'float32'),
output_zero_point=0, output_zero_point=relay.const(0, 'int32'),
out_dtype='int8') out_dtype='int8')
z = relay.Function([x], z) z = relay.Function([x], z)
return z return z
...@@ -102,10 +102,10 @@ def test_qnn_legalize_qnn_conv2d(): ...@@ -102,10 +102,10 @@ def test_qnn_legalize_qnn_conv2d():
dtype=kernel_dtype) dtype=kernel_dtype)
func = relay.qnn.op.conv2d( func = relay.qnn.op.conv2d(
data, kernel, data, kernel,
input_zero_point=1, input_zero_point=relay.const(1, 'int32'),
kernel_zero_point=1, kernel_zero_point=relay.const(1, 'int32'),
input_scale=1.0, input_scale=relay.const(1.0, 'float32'),
kernel_scale=1.0, kernel_scale=relay.const(1.0, 'float32'),
kernel_size=(3, 3), kernel_size=(3, 3),
strides=(1, 1), strides=(1, 1),
dilation=(1, 1), dilation=(1, 1),
...@@ -186,10 +186,10 @@ def test_qnn_legalize_qnn_dense(): ...@@ -186,10 +186,10 @@ def test_qnn_legalize_qnn_dense():
dtype=kernel_dtype) dtype=kernel_dtype)
func = relay.qnn.op.dense( func = relay.qnn.op.dense(
data, kernel, data, kernel,
input_zero_point=1, input_zero_point=relay.const(1, 'int32'),
kernel_zero_point=1, kernel_zero_point=relay.const(1, 'int32'),
input_scale=1, input_scale=relay.const(1, 'float32'),
kernel_scale=1, kernel_scale=relay.const(1, 'float32'),
out_dtype='int32') out_dtype='int32')
mod = relay.Function(relay.analysis.free_vars(func), func) mod = relay.Function(relay.analysis.free_vars(func), func)
......