Commit 76efece3 by Animesh Jain, committed by Thierry Moreau

[QNN] Channel wise quantization - Quantize & Requantize (#4629)

parent eecd8cab
......@@ -33,10 +33,15 @@ namespace qnn {
/*! \brief Attribute for requantize operator */
struct RequantizeAttrs : public tvm::AttrsNode<RequantizeAttrs> {
int axis;
std::string rounding;
DataType out_dtype;
TVM_DECLARE_ATTRS(RequantizeAttrs, "relay.attrs.RequantizeAttrs") {
TVM_ATTR_FIELD(axis)
.describe("The output channel axis for channel wise quantization. Default value is -1,"
"which corresponds to the last axis.")
.set_default(-1);
TVM_ATTR_FIELD(rounding).set_default("UPWARD")
.describe("Defines the rounding direction when the value is midway between"
"two representable values. There are two supported modes - UPWARD"
......@@ -56,10 +61,15 @@ struct RequantizeAttrs : public tvm::AttrsNode<RequantizeAttrs> {
/*! \brief Attribute for quantize operator */
struct QuantizeAttrs : public tvm::AttrsNode<QuantizeAttrs> {
DataType out_dtype;
int axis;
TVM_DECLARE_ATTRS(QuantizeAttrs, "relay.attrs.QuantizeAttrs") {
TVM_ATTR_FIELD(out_dtype)
.describe("Output data type, can be one of [int8 or uint8].");
TVM_ATTR_FIELD(axis)
.describe("The output channel axis for channel wise quantization. Default value is -1,"
"which corresponds to the last axis.")
.set_default(-1);
}
};
......
......@@ -26,6 +26,7 @@ def requantize(data,
input_zero_point,
output_scale,
output_zero_point,
axis=-1,
rounding="UPWARD",
out_dtype="int8"):
r"""Requantized operator.
......@@ -53,6 +54,9 @@ def requantize(data,
output_zero_point: tvm.relay.Expr
The zero point of the output tensor.
axis : int
The channel axis for quantization. Default value is -1 which corresponds to the last axis.
rounding : string, optional
Defines the rounding direction when the value is midway between two
representable values.
......@@ -71,6 +75,7 @@ def requantize(data,
input_zero_point,
output_scale,
output_zero_point,
axis,
rounding,
out_dtype)
......@@ -78,6 +83,7 @@ def requantize(data,
def quantize(data,
output_scale,
output_zero_point,
axis=-1,
out_dtype='int8'):
r""" Quantize op
This operator takes float32 as input and produces quantized int8 or uint8 as output.
......@@ -95,6 +101,8 @@ def quantize(data,
The output zero_point.
output_scale : tvm.relay.Expr
The output scale.
axis : int
The channel axis for quantization. Default value is -1 which corresponds to the last axis.
out_dtype : str, optional
The data type of the output tensor. Can be [int8, uint8]
Returns
......@@ -106,6 +114,7 @@ def quantize(data,
return _make.quantize(data,
output_scale,
output_zero_point,
axis,
out_dtype)
......
......@@ -35,6 +35,7 @@
#include <tvm/relay/attrs/transform.h>
#include <tvm/relay/attrs/reduce.h>
#include <string>
#include <vector>
#include <utility>
......@@ -222,13 +223,26 @@ inline bool IsScalar(const Expr& expr) {
}
/*!
* \brief Check if expr is a const scalar.
* \param expr The expr.
* \return True if const scalar.
*/
inline bool IsConstScalar(const Expr& expr) {
const auto* const_expr = expr.as<ConstantNode>();
if (const_expr) {
return const_expr->is_scalar();
}
return false;
}
/*!
* \brief Create a Constant with a scalar
*
* \param dtype The data type.
* \param value The value of the scalar.
* \return A Constant.
*/
template<typename T>
template <typename T>
inline Constant MakeConstantScalar(DataType dtype, T value) {
runtime::NDArray arr = runtime::NDArray::Empty({}, dtype, {kDLCPU, 0});
TVM_DTYPE_DISPATCH(dtype, DType, {
......@@ -236,7 +250,7 @@ inline Constant MakeConstantScalar(DataType dtype, T value) {
// convert to float16
// storage is uint16_t
*static_cast<DType*>(arr->data) =
__truncXfYf2__<float, uint32_t, 23, uint16_t, uint16_t, 10>(static_cast<float>(value));
__truncXfYf2__<float, uint32_t, 23, uint16_t, uint16_t, 10>(static_cast<float>(value));
} else {
*static_cast<DType*>(arr->data) = value;
}
......@@ -245,6 +259,34 @@ inline Constant MakeConstantScalar(DataType dtype, T value) {
}
/*!
* \brief Create a Constant with a tensor.
*
* \param dtype The data type.
* \param value The vector of the tensor values.
* \return A Constant.
*/
template <typename T>
static inline Constant MakeConstantTensor(DataType dtype, std::vector<int64_t> shape,
std::vector<T> value) {
runtime::NDArray arr = runtime::NDArray::Empty(shape, dtype, {kDLCPU, 0});
TVM_DTYPE_DISPATCH(dtype, DType, {
for (size_t i = 0; i < value.size(); i++) {
if (dtype == DataType::Float(16)) {
// convert to float16
// storage is uint16_t
// Similar handling as that in MakeConstantScalar
*(static_cast<DType*>(arr->data) + i) =
__truncXfYf2__<float, uint32_t, 23, uint16_t, uint16_t, 10>(
static_cast<float>(value[i]));
} else {
*(static_cast<DType*>(arr->data) + i) = value[i];
}
}
})
return ConstantNode::make(arr);
}
/*!
* \brief Check if two expressions are equal scalars.
* \param a The expression to be checked.
* \param b The expression to be checked
......@@ -523,6 +565,8 @@ static inline Expr Tile(Expr data, Array<Integer> reps) {
return CallNode::make(op, {data}, Attrs(attrs), {});
}
Expr MakeBroadCastTo(Expr data, Array<IndexExpr> shape);
Expr MakeConcatenate(Expr data, int axis);
Expr MakeRepeat(Expr data, int repeats, int axis);
......
......@@ -45,11 +45,18 @@ bool QuantizeRel(const Array<Type>& types,
CHECK(input_dtype == DataType::Float(32))
<< "Input type should be one of float32 but was " << input_dtype;
// Check the types of scale and zero points.
CHECK(IsScalarType(types[1], DataType::Float(32))); // output_scale
CHECK(IsScalarType(types[2], DataType::Int(32))); // output_zero_point
const auto* quantize_attrs = attrs.as<QuantizeAttrs>();
int axis = quantize_attrs->axis;
axis = (axis == -1) ? data->shape.size() - 1 : axis;
CHECK_LT(axis, static_cast<int>(data->shape.size()))
<< "axis " << quantize_attrs->axis << " is out of range";
CHECK_GE(axis, 0)
<< "axis " << quantize_attrs->axis << " is out of range";
// Check and assign types for scale and zero points.
AssignType(types[1], DataType::Float(32), data->shape[axis], reporter); // scale
AssignType(types[2], DataType::Int(32), data->shape[axis], reporter); // zero point
const Array<tvm::Expr> oshape = data->shape;
const DataType out_dtype = quantize_attrs->out_dtype;
CHECK(out_dtype == DataType::Int(8) || out_dtype == DataType::UInt(8) ||
......@@ -60,8 +67,10 @@ bool QuantizeRel(const Array<Type>& types,
return true;
}
Expr MakeQuantize(Expr data, Expr output_scale, Expr output_zero_point, DataType out_dtype) {
Expr MakeQuantize(Expr data, Expr output_scale, Expr output_zero_point, int axis,
DataType out_dtype) {
auto attrs = make_object<QuantizeAttrs>();
attrs->axis = axis;
attrs->out_dtype = std::move(out_dtype);
// result_quantized_value = result_zero_point + result_real_value / result_scale.
// A more detailed explanation can be found here -
......@@ -71,13 +80,29 @@ Expr MakeQuantize(Expr data, Expr output_scale, Expr output_zero_point, DataType
}
Expr QuantizeLower(const Expr& input_tensor, const Expr& output_scale,
const Expr& output_zero_point, const QuantizeAttrs* attrs) {
const Expr& output_zero_point, const Array<IndexExpr>& input_shape,
const QuantizeAttrs* attrs) {
const auto out_dtype = attrs->out_dtype;
const auto axis = attrs->axis;
size_t n_dim = input_shape.size();
auto expanded_output_scale = output_scale;
if (!IsConstScalar(output_scale)) {
expanded_output_scale = ExpandBiasToMatchAxis(output_scale, n_dim, {axis});
}
auto expanded_output_zero_point = output_zero_point;
if (!IsConstScalar(output_zero_point)) {
expanded_output_zero_point = ExpandBiasToMatchAxis(output_zero_point, n_dim, {axis});
}
const int32_t min_val = GetQmin(out_dtype);
const int32_t max_val = GetQmax(out_dtype);
auto scale_data = Divide(input_tensor, output_scale);
auto scale_data = Divide(input_tensor, expanded_output_scale);
auto add_zero_point =
Cast(Round(Add(scale_data, Cast(output_zero_point, DataType::Float(32)))), DataType::Int(32));
Cast(Round(Add(scale_data, Cast(expanded_output_zero_point, DataType::Float(32)))),
DataType::Int(32));
auto clamped_output = Clip(add_zero_point, min_val, max_val);
auto clamp_out_dtype = Cast(clamped_output, out_dtype);
return clamp_out_dtype;
......@@ -92,8 +117,15 @@ Expr QuantizeQnnCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
const auto* quantize_attrs = attrs.as<QuantizeAttrs>();
CHECK(quantize_attrs != nullptr);
// Find input shape.
CHECK_EQ(types.size(), 4);
return QuantizeLower(data, output_scale, output_zero_point, quantize_attrs);
auto in_type = types[0];
auto in_tensor_type = in_type.as<TensorTypeNode>();
CHECK(in_tensor_type != nullptr) << "Type information missing."
<< " Please run infer_type pass.";
Array<IndexExpr> input_shape = in_tensor_type->shape;
return QuantizeLower(data, output_scale, output_zero_point, input_shape, quantize_attrs);
}
RELAY_REGISTER_OP("qnn.quantize")
......
......@@ -58,11 +58,6 @@ Expr RequantizeLower(const Expr& input_tensor, const Expr& input_scale,
const Expr& input_zero_point, const Expr& output_scale,
const Expr& output_zero_point, const RequantizeAttrs* param,
const Array<IndexExpr>& input_shape, const DataType& out_dtype) {
float input_scale_float = GetScalarFromConstant<float>(input_scale);
float output_scale_float = GetScalarFromConstant<float>(output_scale);
double double_multiplier =
static_cast<double>(input_scale_float) / static_cast<double>(output_scale_float);
DataType hp_dtype = DataType::Int(64);
auto tensor = Cast(input_tensor, hp_dtype);
......@@ -72,11 +67,34 @@ Expr RequantizeLower(const Expr& input_tensor, const Expr& input_scale,
tensor = Subtract(tensor, Cast(input_zero_point, hp_dtype));
}
// 2) If the input and output scales are the same, we can skip the fixed point multiplication.
// Check if the input scale is per-tensor or per-channel. If it is per-tensor, there is a single
// scale for the whole tensor. For per-channel (aka per-axis) quantization, there is a vector of
// scales for the input tensor. Depending on the quantization type, the appropriate fixed point
// multiplication routine is called.
auto scaled_int64_t = tensor;
if (!IsEqualScalar(input_scale, output_scale)) {
scaled_int64_t =
FixedPointMultiply(scaled_int64_t, double_multiplier, input_shape, param->rounding);
float output_scale_float = GetScalarFromConstant<float>(output_scale);
if (IsConstScalar(input_scale)) {
// This is per-tensor quantization. Single scale.
float input_scale_float = GetScalarFromConstant<float>(input_scale);
double double_multiplier =
static_cast<double>(input_scale_float) / static_cast<double>(output_scale_float);
// Skip if input and output scales are the same.
if (!IsEqualScalar(input_scale, output_scale)) {
scaled_int64_t =
FixedPointMultiply(scaled_int64_t, double_multiplier, input_shape, param->rounding);
}
} else {
// This is per-channel (per-axis) quantization.
std::vector<double> double_multipliers;
auto input_axis_scales = GetFloatVectorFromConstant(input_scale);
for (auto input_axis_scale : input_axis_scales) {
double_multipliers.push_back(static_cast<double>(input_axis_scale) /
static_cast<double>(output_scale_float));
}
int axis = param->axis;
axis = (axis == -1) ? input_shape.size() - 1 : axis;
scaled_int64_t = FixedPointMultiplyPerChannel(scaled_int64_t, double_multipliers, input_shape,
axis, param->rounding);
}
// 3) Add the output zero point.
......@@ -157,16 +175,24 @@ bool RequantizeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
in_dtype == DataType::Int(32))
<< "Input type should be one of [int8, uint8, int32] but was " << in_dtype;
// Check the types of scale and zero points.
CHECK(IsScalarType(types[1], DataType::Float(32))); // input_scale
CHECK(IsScalarType(types[2], DataType::Int(32))); // input_zero_point
const RequantizeAttrs* requantize_attrs = attrs.as<RequantizeAttrs>();
int axis = requantize_attrs->axis;
axis = (axis == -1) ? data->shape.size() - 1 : axis;
CHECK_LT(axis, static_cast<int>(data->shape.size()))
<< "axis " << requantize_attrs->axis << " is out of range";
CHECK_GE(axis, 0)
<< "axis " << requantize_attrs->axis << " is out of range";
// Check and assign types for scale and zero points.
AssignType(types[1], DataType::Float(32), data->shape[axis], reporter); // input_scale
AssignType(types[2], DataType::Int(32), data->shape[axis], reporter); // input_zero_pt
// For now, the requantize output tensor is limited to full-tensor uniform quantization.
CHECK(IsScalarType(types[3], DataType::Float(32))); // output_scale
CHECK(IsScalarType(types[4], DataType::Int(32))); // output_zero_point
const Array<tvm::Expr> oshape = data->shape;
// assign output type
const RequantizeAttrs* param = attrs.as<RequantizeAttrs>();
auto out_dtype = param->out_dtype;
auto out_dtype = requantize_attrs->out_dtype;
CHECK(out_dtype == DataType::Int(8) ||
out_dtype == DataType::UInt(8) ||
out_dtype == DataType::Int(32))
......@@ -178,8 +204,9 @@ bool RequantizeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
// Positional relay function to create qnn requantize operator
// used by frontend FFI.
Expr MakeRequantize(Expr data, Expr input_scale, Expr input_zero_point, Expr output_scale,
Expr output_zero_point, std::string rounding, DataType out_dtype) {
Expr output_zero_point, int axis, std::string rounding, DataType out_dtype) {
auto attrs = make_object<RequantizeAttrs>();
attrs->axis = axis;
attrs->rounding = std::move(rounding);
attrs->out_dtype = std::move(out_dtype);
static const Op& op = Op::Get("qnn.requantize");
......
......@@ -75,8 +75,8 @@ std::pair<int32_t, int32_t> GetFixedPointMultiplierShift(
return std::make_pair(significand, exponent);
}
Expr FixedPointMultiply(Expr tensor, double multiplier, const Array<IndexExpr>& input_shape,
                        const std::string& rounding) {
// Choose high precision datatype to be int64. This is for avoiding overflow
// in multiplication of two int32 values.
DataType hp_dtype = DataType::Int(64);
......@@ -133,6 +133,90 @@ Expr FixedPointMultiply(Expr tensor, double multiplier,
return tensor;
}
Expr FixedPointMultiplyPerChannel(Expr tensor, std::vector<double> multipliers,
const Array<IndexExpr>& input_shape, int channel_axis,
const std::string& rounding) {
// Get the n dim. This will be used to expand the multiplier to match the axis.
size_t n_dim = input_shape.size();
// Get the number of channels, i.e. the length of the axis along which the tensor was quantized.
int64_t n_channels = static_cast<int64_t>(multipliers.size());
// Choose high precision datatype to be int64. This is for avoiding overflow
// in multiplication of two int32 values.
DataType hp_dtype = DataType::Int(64);
// 1) Calculating the integer multiplier and integer shift. These are calculated per axis/per
// channel.
std::vector<int32_t> fixed_pt_multipliers, lshifts, rshifts;
for (auto multiplier : multipliers) {
int32_t fixed_pt_multiplier, shift;
std::tie(fixed_pt_multiplier, shift) = GetFixedPointMultiplierShift(multiplier);
int lshift = shift > 0 ? shift : 0;
int rshift = shift > 0 ? 0 : -shift;
fixed_pt_multipliers.push_back(fixed_pt_multiplier);
lshifts.push_back(lshift);
rshifts.push_back(rshift);
}
// 2) Apply the left shifts. Convert the left shifts into an expr and multiply.
auto lshift_expr = MakeConstantTensor(hp_dtype, {n_channels}, lshifts);
auto exp_lshift_expr = ExpandBiasToMatchAxis(lshift_expr, n_dim, {channel_axis});
tensor = LeftShift(tensor, exp_lshift_expr);
// 3) Perform the multiplication in higher precision.
// The scalar is a fixed point value of int32 where the decimal point is
// between bits 31 and 30. After multiplying with input_tensor, the result
// is in int64 where the decimal point is sitting between bits 31 and 30
// (from the right, rightmost bit is bit 0). The computation is performed in
// higher precision to avoid overflow in multiplying two int32 values.
auto fixed_pt_multiplier_expr = MakeConstantTensor(hp_dtype, {n_channels}, fixed_pt_multipliers);
auto exp_fixed_pt_multiplier_expr =
ExpandBiasToMatchAxis(fixed_pt_multiplier_expr, n_dim, {channel_axis});
tensor = Multiply(tensor, exp_fixed_pt_multiplier_expr);
// 4) Find the rounding scalar. This depends on where the final decimal point sits. As we will be
// right shifting the multiplied_t, we need to first calculate the total_rshift. Further, we can
// calculate the pos and neg rounding offset.
std::vector<int64_t> pos_rounding_values, neg_rounding_values, total_rshifts;
for (auto rshift : rshifts) {
int total_rshift = rshift + 31;
total_rshifts.push_back(total_rshift);
pos_rounding_values.push_back((1ll << (total_rshift - 1)));
neg_rounding_values.push_back((1ll << (total_rshift - 1)) - 1);
}
// Make a Relay expr from positive and negative rounding offset values.
auto pos_rounding_value_expr = MakeConstantTensor(hp_dtype, {n_channels}, pos_rounding_values);
auto exp_pos_rounding_value_expr =
ExpandBiasToMatchAxis(pos_rounding_value_expr, n_dim, {channel_axis});
auto neg_rounding_value_expr = MakeConstantTensor(hp_dtype, {n_channels}, neg_rounding_values);
auto exp_neg_rounding_value_expr =
ExpandBiasToMatchAxis(neg_rounding_value_expr, n_dim, {channel_axis});
Expr round_scalar;
if (rounding == "UPWARD") {
round_scalar = exp_pos_rounding_value_expr;
} else if (rounding == "TONEAREST") {
// To satisfy where op shape requirements, the rounding values are broadcasted.
auto pos_rounder = MakeBroadCastTo(exp_pos_rounding_value_expr, input_shape);
auto neg_rounder = MakeBroadCastTo(exp_neg_rounding_value_expr, input_shape);
auto zero_t = Zeros(input_shape, hp_dtype);
round_scalar = Where(GreaterEqual(tensor, zero_t), pos_rounder, neg_rounder);
} else {
LOG(FATAL) << "Rounding mode " << rounding << " not supported.";
}
// Add the rounding scalar.
tensor = Add(tensor, round_scalar);
// 5) Simply right shift the result to get the final output.
auto total_rshift_expr = MakeConstantTensor(hp_dtype, {n_channels}, total_rshifts);
auto exp_total_rshift_expr = ExpandBiasToMatchAxis(total_rshift_expr, n_dim, {channel_axis});
tensor = RightShift(tensor, exp_total_rshift_expr);
return tensor;
}
} // namespace qnn
} // namespace relay
} // namespace tvm
......@@ -30,6 +30,7 @@
#include <tvm/relay/qnn/attrs.h>
#include <limits>
#include <string>
#include <vector>
#include <utility>
namespace tvm {
......@@ -125,18 +126,78 @@ Expr FixedPointMultiply(Expr tensor, double multiplier, const Array<IndexExpr>&
const std::string& rounding);
/*
 * \brief Fixed point multiplication between an integer tensor and a vector of floating point
 *        multipliers, where the input tensor is per-axis/per-channel quantized.
 * \param tensor The quantized input tensor of dtype int64.
 * \param multiplier The vector of multipliers, one per channel.
 * \param input_shape Shape of the input tensor.
 * \param channel_axis The channel_axis along which the input tensor is quantized. Default value is
 *        -1, which corresponds to the last axis.
 * \param rounding "UPWARD" or "TONEAREST". The rounding direction when the value
 *        is midway between two representable values.
 * \return The sequence of Relay ops for fixed point multiplication.
 * \note The original computation is scale_fp32 * quantized_tensor. To convert this into
 *       integer computation, the multiplication with the fp32 vector can be
 *       replaced by multiplication with an int vector and then right shifting
 *       the result. This approximates the floating point computation with a
 *       fixed point computation.
 *
 *       Computation of fixed point multiplication consists of the following steps:
 *       1) Multiply the fixed point multiplier with the quantized tensor.
 *       2) Round the result.
 *       3) Right shift the result.
 */
Expr FixedPointMultiplyPerChannel(Expr tensor, std::vector<double> multiplier,
const Array<IndexExpr>& input_shape, int channel_axis,
const std::string& rounding);
/*
* \brief Checks whether an expr type is scalar of a given data type.
* \param expr_type The type of expr to be checked.
* \param dtype The expected dtype.
* \return True if the type is a scalar of given dtype
*/
static inline bool IsScalarType(const Type& expr_type, const DataType& dtype) {
const auto* scale = expr_type.as<TensorTypeNode>();
CHECK_EQ(scale->shape.size(), 0);
CHECK(scale->dtype == dtype) << "Expected " << dtype << " but got " << scale->dtype;
const auto* tensor_type = expr_type.as<TensorTypeNode>();
CHECK_EQ(tensor_type->shape.size(), 0);
CHECK(tensor_type->dtype == dtype) << "Expected " << dtype << " but got " << tensor_type->dtype;
return true;
}
/*
* \brief Checks and assigns types to scale and zero points.
* \param expr_type The type of expr to be checked.
* \param dtype The expected dtype.
* \param shape The shape at C dim of original tensor.
* \param reporter The type reporter of the original InferType call.
*/
static inline void AssignType(const Type& expr_type, const DataType& dtype, const IndexExpr& shape,
const TypeReporter& reporter) {
// Scale/zero points can be either a const scalar or a vector with as many elements as the C axis.
const auto* tensor_type = expr_type.as<TensorTypeNode>();
const auto tensor_dtype = tensor_type->dtype;
CHECK(tensor_dtype == dtype) << "Expected type is " << dtype << " but received " << tensor_dtype;
if (tensor_type->shape.size() != 0) {
reporter->Assign(expr_type, TensorTypeNode::make({shape}, tensor_type->dtype));
}
}
static inline std::vector<float> GetFloatVectorFromConstant(const Expr& expr) {
const auto* n = expr.as<ConstantNode>();
std::vector<float> vals;
CHECK(n) << "Expr must be a constant expr - " << AsText(expr, false);
int64_t num_elems = 1;
auto shape = n->data.Shape();
for (size_t i = 0; i < shape.size(); i++) {
num_elems *= shape[i];
}
for (int64_t i = 0; i < num_elems; i++) {
vals.push_back(static_cast<float*>(n->data->data)[i]);
}
return vals;
}
} // namespace qnn
} // namespace relay
} // namespace tvm
......
......@@ -20,13 +20,15 @@ import numpy as np
from tvm import relay
from tvm.contrib import graph_runtime
def quantize_test_driver(in_dtype, quant_args, out_dtype, in_data, verify_output_data):
def quantize_test_driver(in_dtype, quant_args, axis, out_dtype, in_data, verify_output_data):
shape = in_data.shape
input_data = relay.var("input_data", shape=shape, dtype=in_dtype)
output_zero_point = relay.const(quant_args['out_zero_point'], 'int32')
output_scale = relay.const(quant_args['out_scale'], 'float32')
output_zero_point = relay.const(quant_args['out_zero_point'])
output_scale = relay.const(quant_args['out_scale'])
quantized_output = relay.qnn.op.quantize(input_data, output_scale=output_scale,
output_zero_point=output_zero_point,out_dtype=out_dtype)
output_zero_point=output_zero_point,
axis=axis,
out_dtype=out_dtype)
mod = relay.Function(relay.analysis.free_vars(quantized_output), quantized_output)
mod = relay.Module.from_expr(mod)
with relay.build_config(opt_level=3):
......@@ -46,9 +48,9 @@ def test_float32_to_uint8():
output = np.array([0, 1, 2, 3, 4, 251, 252, 253, 254, 255]) \
.astype('uint8') \
.reshape((2,5))
quant_args = {"out_zero_point":127, "out_scale":0.5}
quantize_test_driver(in_dtype='float32', quant_args=quant_args, out_dtype='uint8', in_data=data,
verify_output_data=output)
quant_args = {"out_zero_point":np.int32(127), "out_scale": np.float32(0.5)}
quantize_test_driver(in_dtype='float32', quant_args=quant_args, axis=-1, out_dtype='uint8',
in_data=data, verify_output_data=output)
def test_float32_to_int8():
data = np.array([-63.5, -63, -62.5, -62, -61.5, 62, 62.5, 63, 63.5, 64]) \
......@@ -57,10 +59,37 @@ def test_float32_to_int8():
output = np.array([-128, -127, -126, -125, -124, 123, 124, 125, 126, 127]) \
.astype('int8') \
.reshape((2,5))
quant_args = {"out_zero_point":-1, "out_scale":0.5}
quantize_test_driver(in_dtype='float32', quant_args=quant_args, out_dtype='int8', in_data=data,
verify_output_data=output)
quant_args = {"out_zero_point":np.int32(-1), "out_scale":np.float32(0.5)}
quantize_test_driver(in_dtype='float32', quant_args=quant_args, axis=-1, out_dtype='int8',
in_data=data, verify_output_data=output)
def test_channelwise_axis_0():
data = np.array([-63.5, -63, -62.5, -62, -61.5, 30, 31, 31.5, 31.75, 32]) \
.astype('float32') \
.reshape((2,5))
output = np.array([0, 1, 2, 3, 4, 243, 247, 249, 250, 251]) \
.astype('uint8') \
.reshape((2,5))
quant_args = {"out_zero_point" : np.array([127, 123]).astype('int32'),
"out_scale" : np.array([0.5, 0.25]).astype('float32')}
quantize_test_driver(in_dtype='float32', quant_args=quant_args, axis=0, out_dtype='uint8',
in_data=data, verify_output_data=output)
def test_channelwise_axis_1():
data = np.transpose(np.array([-63.5, -63, -62.5, -62, -61.5, 30, 31, 31.5, 31.75, 32]) \
.astype('float32').reshape((2,5)))
output = np.transpose(np.array([0, 1, 2, 3, 4, 243, 247, 249, 250, 251]) \
.astype('uint8').reshape((2,5)))
quant_args = {"out_zero_point" : np.array([127, 123]).astype('int32'),
"out_scale" : np.array([0.5, 0.25]).astype('float32')}
quantize_test_driver(in_dtype='float32', quant_args=quant_args, axis=1, out_dtype='uint8',
in_data=data, verify_output_data=output)
if __name__ == "__main__":
test_float32_to_uint8()
test_float32_to_int8()
test_channelwise_axis_0()
test_channelwise_axis_1()
......@@ -34,15 +34,27 @@ def verify(mod, goldens):
np.testing.assert_equal(res, golden_output)
def get_mod(data_shape, data_dtype, out_dtype, input_scale, output_scale,
input_zero_point=0, output_zero_point=0, rounding="TONEAREST"):
input_zero_point=0, output_zero_point=0, rounding="TONEAREST",
axis=0):
quantized_data = relay.var("quantized_data", shape=data_shape,
dtype=data_dtype)
if isinstance(input_scale, float):
input_scale_expr = relay.const(input_scale, 'float32')
else:
input_scale_expr = relay.const(np.array(input_scale).astype('float32'))
if isinstance(input_zero_point, int):
input_zero_point_expr = relay.const(input_zero_point, 'int32')
else:
input_zero_point_expr = relay.const(np.array(input_zero_point).astype('int32'))
mod = relay.qnn.op.requantize(
quantized_data,
input_scale=relay.const(input_scale, 'float32'),
input_zero_point=relay.const(input_zero_point, 'int32'),
input_scale=input_scale_expr,
input_zero_point=input_zero_point_expr,
output_scale=relay.const(output_scale, 'float32'),
output_zero_point=relay.const(output_zero_point, 'int32'),
axis=axis,
rounding=rounding,
out_dtype=out_dtype)
......@@ -240,9 +252,70 @@ def test_zero_point():
golden_output = np.subtract(golden_output, 1)
verify(mod, (golden_data, golden_output))
def test_per_channel_same_scale():
# Have same scales, everything within range
golden_data = np.arange(-5, 5, 1).astype('int32').reshape((5,2))
golden_output = golden_data
for rounding in roundings:
mod = get_mod(data_shape=(5, 2),
data_dtype='int32',
out_dtype="int8",
input_scale=[0.5, 0.5],
output_scale=0.5,
axis=1,
rounding=rounding)
verify(mod, (golden_data, golden_output))
# Change axis
golden_data = np.arange(-10, 10, 1).astype('int32').reshape((2,2,5))
golden_output = golden_data
for rounding in roundings:
mod = get_mod(data_shape=(2, 2, 5),
data_dtype='int32',
out_dtype="int8",
input_scale=[0.5, 0.5],
output_scale=0.5,
axis=1,
rounding=rounding)
verify(mod, (golden_data, golden_output))
def test_per_channel_different_scale():
# Have different scales, everything within range
golden_data = np.arange(-5, 5, 1).astype('int32').reshape((5,2))
golden_output = np.array([-5, -2, -3, -1, -1, 0, 1, 1, 3, 2]).reshape((5, 2))
for rounding in roundings:
mod = get_mod(data_shape=(5, 2),
data_dtype='int32',
out_dtype="int8",
input_scale=[0.5, 0.25],
output_scale=0.5,
axis=1,
rounding=rounding)
verify(mod, (golden_data, golden_output))
# Change axis
golden_data = np.arange(-20, 20, 2).astype('int32').reshape((2,2,5))
golden_output = np.array([-20, -18, -16, -14, -12, -5, -4, -3, -2, -1, 0, 2, 4, 6, 8, 5, 6, 7,
8, 9]).reshape((2, 2, 5))
for rounding in roundings:
mod = get_mod(data_shape=(2, 2, 5),
data_dtype='int32',
out_dtype="int8",
input_scale=[0.5, 0.25],
output_scale=0.5,
axis=1,
rounding=rounding)
verify(mod, (golden_data, golden_output))
if __name__ == "__main__":
test_same_scale()
test_downscale()
test_upscale()
test_saturation()
test_zero_point()
test_per_channel_same_scale()
test_per_channel_different_scale()