Commit 76efece3 by Animesh Jain, committed by Thierry Moreau

[QNN] Channel wise quantization - Quantize & Requantize (#4629)

parent eecd8cab
@@ -33,10 +33,15 @@ namespace qnn {
 /*! \brief Attribute for requantize operator */
 struct RequantizeAttrs : public tvm::AttrsNode<RequantizeAttrs> {
+  int axis;
   std::string rounding;
   DataType out_dtype;

   TVM_DECLARE_ATTRS(RequantizeAttrs, "relay.attrs.RequantizeAttrs") {
+    TVM_ATTR_FIELD(axis)
+      .describe("The output channel axis for channel wise quantization. Default value is -1,"
+                "which corresponds to the last axis.")
+      .set_default(-1);
     TVM_ATTR_FIELD(rounding).set_default("UPWARD")
       .describe("Defines the rounding direction when the value is midway between"
                 "two representable values. There are two supported modes - UPWARD"

@@ -56,10 +61,15 @@ struct RequantizeAttrs : public tvm::AttrsNode<RequantizeAttrs> {
 /*! \brief Attribute for quantize operator */
 struct QuantizeAttrs : public tvm::AttrsNode<QuantizeAttrs> {
   DataType out_dtype;
+  int axis;

   TVM_DECLARE_ATTRS(QuantizeAttrs, "relay.attrs.QuantizeAttrs") {
     TVM_ATTR_FIELD(out_dtype)
       .describe("Output data type, can be one of [int8 or uint8].");
+    TVM_ATTR_FIELD(axis)
+      .describe("The output channel axis for channel wise quantization. Default value is -1,"
+                "which corresponds to the last axis.")
+      .set_default(-1);
   }
 };
@@ -26,6 +26,7 @@ def requantize(data,
                input_zero_point,
                output_scale,
                output_zero_point,
+               axis=-1,
                rounding="UPWARD",
                out_dtype="int8"):
     r"""Requantized operator.

@@ -53,6 +54,9 @@ def requantize(data,
     output_zero_point: tvm.relay.Expr
         The zero point of the output tensor.

+    axis : int
+        The channel axis for quantization. Default value is -1 which corresponds to the last axis.
+
     rounding : string, optional
         Defines the rounding direction when the value is midway between two
         representable values.

@@ -71,6 +75,7 @@ def requantize(data,
                             input_zero_point,
                             output_scale,
                             output_zero_point,
+                            axis,
                             rounding,
                             out_dtype)
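For orientation, a hedged usage sketch of the new axis parameter (shapes, names, and scale values here are illustrative, not taken from this commit). Per-channel input scales are passed as a 1-D float32 constant whose length matches the channel dimension, while the output side stays a scalar, in line with the per-tensor restriction kept in RequantizeRel below:

import numpy as np
from tvm import relay

data = relay.var("data", shape=(1, 2, 5, 5), dtype="int32")
out = relay.qnn.op.requantize(
    data,
    input_scale=relay.const(np.array([0.5, 0.25]).astype('float32')),  # one scale per channel
    input_zero_point=relay.const(0, 'int32'),
    output_scale=relay.const(0.5, 'float32'),  # output is still per-tensor
    output_zero_point=relay.const(0, 'int32'),
    axis=1,                                    # channel axis of the NCHW tensor
    rounding="TONEAREST",
    out_dtype="int8")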
@@ -78,6 +83,7 @@ def requantize(data,
 def quantize(data,
              output_scale,
              output_zero_point,
+             axis=-1,
              out_dtype='int8'):
     r""" Quantize op
     This operator takes float32 as input and produces quantized int8 or uint8 as output.

@@ -95,6 +101,8 @@ def quantize(data,
         The output zero_point.
     output_scale : tvm.relay.Expr
         The output scale.
+    axis : int
+        The channel axis for quantization. Default value is -1 which corresponds to the last axis.
     out_dtype : str, optional
         The data type of the input tensor. Can be [int8, uint8]

     Returns

@@ -106,6 +114,7 @@ def quantize(data,
     return _make.quantize(data,
                           output_scale,
                           output_zero_point,
+                          axis,
                           out_dtype)
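A hedged usage sketch of channel-wise quantize with the new axis argument (tensor shape and values are illustrative; they mirror the test_channelwise_axis_0 test added further down):

import numpy as np
from tvm import relay

data = relay.var("data", shape=(2, 5), dtype="float32")
q = relay.qnn.op.quantize(
    data,
    output_scale=relay.const(np.array([0.5, 0.25]).astype('float32')),    # one scale per row
    output_zero_point=relay.const(np.array([127, 123]).astype('int32')),  # one zero point per row
    axis=0,            # quantize along the first axis
    out_dtype='uint8')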
@@ -35,6 +35,7 @@
 #include <tvm/relay/attrs/transform.h>
 #include <tvm/relay/attrs/reduce.h>
 #include <string>
+#include <vector>
 #include <utility>

@@ -222,13 +223,26 @@ inline bool IsScalar(const Expr& expr) {
 }

 /*!
+ * \brief Check if expr is a const scalar.
+ * \param expr The expr.
+ * \return True if const scalar.
+ */
+inline bool IsConstScalar(const Expr& expr) {
+  const auto* const_expr = expr.as<ConstantNode>();
+  if (const_expr) {
+    return const_expr->is_scalar();
+  }
+  return false;
+}
+
+/*!
  * \brief Create a Constant with a scalar
  *
  * \param dtype The data type.
 * \param value The value of the scalar.
 * \return A Constant.
 */
-template<typename T>
+template <typename T>
 inline Constant MakeConstantScalar(DataType dtype, T value) {
   runtime::NDArray arr = runtime::NDArray::Empty({}, dtype, {kDLCPU, 0});
   TVM_DTYPE_DISPATCH(dtype, DType, {

@@ -245,6 +259,34 @@ inline Constant MakeConstantScalar(DataType dtype, T value) {
 }

 /*!
+ * \brief Create a Constant with a tensor.
+ *
+ * \param dtype The data type.
+ * \param value The vector of the tensor values.
+ * \return A Constant.
+ */
+template <typename T>
+static inline Constant MakeConstantTensor(DataType dtype, std::vector<int64_t> shape,
+                                          std::vector<T> value) {
+  runtime::NDArray arr = runtime::NDArray::Empty(shape, dtype, {kDLCPU, 0});
+  TVM_DTYPE_DISPATCH(dtype, DType, {
+    for (size_t i = 0; i < value.size(); i++) {
+      if (dtype == DataType::Float(16)) {
+        // convert to float16
+        // storage is uint16_t
+        // Similar handling as that in MakeConstantScalar
+        *(static_cast<DType*>(arr->data) + i) =
+            __truncXfYf2__<float, uint32_t, 23, uint16_t, uint16_t, 10>(
+                static_cast<float>(value[i]));
+      } else {
+        *(static_cast<DType*>(arr->data) + i) = value[i];
+      }
+    }
+  })
+  return ConstantNode::make(arr);
+}
+
+/*!
 * \brief Check if two expressions are equal scalars.
 * \param a The expression to be checked.
 * \param b The expression to be checked

@@ -523,6 +565,8 @@ static inline Expr Tile(Expr data, Array<Integer> reps) {
   return CallNode::make(op, {data}, Attrs(attrs), {});
 }

+Expr MakeBroadCastTo(Expr data, Array<IndexExpr> shape);
+
 Expr MakeConcatenate(Expr data, int axis);

 Expr MakeRepeat(Expr data, int repeats, int axis);
@@ -45,11 +45,18 @@ bool QuantizeRel(const Array<Type>& types,
   CHECK(input_dtype == DataType::Float(32))
     << "Input type should be one of float32 but was " << input_dtype;

-  // Check the types of scale and zero points.
-  CHECK(IsScalarType(types[1], DataType::Float(32)));  // output_scale
-  CHECK(IsScalarType(types[2], DataType::Int(32)));    // output_zero_point
-
   const auto* quantize_attrs = attrs.as<QuantizeAttrs>();
+  int axis = quantize_attrs->axis;
+  axis = (axis == -1) ? data->shape.size() - 1 : axis;
+  CHECK_LT(axis, static_cast<int>(data->shape.size()))
+      << "axis " << quantize_attrs->axis << " is out of range";
+  CHECK_GE(axis, 0)
+      << "axis " << quantize_attrs->axis << " is out of range";
+
+  // Check and assign types for scale and zero points.
+  AssignType(types[1], DataType::Float(32), data->shape[axis], reporter);  // scale
+  AssignType(types[2], DataType::Int(32), data->shape[axis], reporter);    // zero point
+
   const Array<tvm::Expr> oshape = data->shape;
   const DataType out_dtype = quantize_attrs->out_dtype;
   CHECK(out_dtype == DataType::Int(8) || out_dtype == DataType::UInt(8) ||

@@ -60,8 +67,10 @@ bool QuantizeRel(const Array<Type>& types,
   return true;
 }

-Expr MakeQuantize(Expr data, Expr output_scale, Expr output_zero_point, DataType out_dtype) {
+Expr MakeQuantize(Expr data, Expr output_scale, Expr output_zero_point, int axis,
+                  DataType out_dtype) {
   auto attrs = make_object<QuantizeAttrs>();
+  attrs->axis = axis;
   attrs->out_dtype = std::move(out_dtype);
   // result_quantized_value = result_zero_point + result_real_value / result_scale.
   // A more detailed explanation can be found here -

@@ -71,13 +80,29 @@ Expr MakeQuantize(Expr data, Expr output_scale, Expr output_zero_point, DataType
 }

 Expr QuantizeLower(const Expr& input_tensor, const Expr& output_scale,
-                   const Expr& output_zero_point, const QuantizeAttrs* attrs) {
+                   const Expr& output_zero_point, const Array<IndexExpr>& input_shape,
+                   const QuantizeAttrs* attrs) {
   const auto out_dtype = attrs->out_dtype;
+  const auto axis = attrs->axis;
+
+  size_t n_dim = input_shape.size();
+
+  auto expanded_output_scale = output_scale;
+  if (!IsConstScalar(output_scale)) {
+    expanded_output_scale = ExpandBiasToMatchAxis(output_scale, n_dim, {axis});
+  }
+
+  auto expanded_output_zero_point = output_zero_point;
+  if (!IsConstScalar(output_zero_point)) {
+    expanded_output_zero_point = ExpandBiasToMatchAxis(output_zero_point, n_dim, {axis});
+  }
+
   const int32_t min_val = GetQmin(out_dtype);
   const int32_t max_val = GetQmax(out_dtype);
-  auto scale_data = Divide(input_tensor, output_scale);
+  auto scale_data = Divide(input_tensor, expanded_output_scale);
   auto add_zero_point =
-      Cast(Round(Add(scale_data, Cast(output_zero_point, DataType::Float(32)))), DataType::Int(32));
+      Cast(Round(Add(scale_data, Cast(expanded_output_zero_point, DataType::Float(32)))),
+           DataType::Int(32));
   auto clamped_output = Clip(add_zero_point, min_val, max_val);
   auto clamp_out_dtype = Cast(clamped_output, out_dtype);
   return clamp_out_dtype;
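The lowering above is just clip(round(x / scale + zero_point)) with the scale and zero point broadcast along the quantization axis. A minimal NumPy sketch of that reference semantics (the helper name and defaults are mine, not part of the commit; np.round rounds half to even, so it only approximates the Round op on exact midpoints):

import numpy as np

def quantize_ref(x, scale, zero_point, axis=-1, qmin=0, qmax=255):
    scale = np.asarray(scale, dtype='float32')
    zero_point = np.asarray(zero_point, dtype='float32')
    if scale.ndim != 0:
        # Per-channel case: reshape so the vectors broadcast along `axis`,
        # mirroring what ExpandBiasToMatchAxis does in the lowering.
        bcast = [1] * x.ndim
        bcast[axis] = -1
        scale = scale.reshape(bcast)
        zero_point = zero_point.reshape(bcast)
    q = np.round(x / scale + zero_point).astype('int32')
    return np.clip(q, qmin, qmax)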
@@ -92,8 +117,15 @@ Expr QuantizeQnnCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
   const auto* quantize_attrs = attrs.as<QuantizeAttrs>();
   CHECK(quantize_attrs != nullptr);

+  // Find input shape.
   CHECK_EQ(types.size(), 4);
-  return QuantizeLower(data, output_scale, output_zero_point, quantize_attrs);
+  auto in_type = types[0];
+  auto in_tensor_type = in_type.as<TensorTypeNode>();
+  CHECK(in_tensor_type != nullptr) << "Type information missing."
+                                   << " Please run infer_type pass.";
+  Array<IndexExpr> input_shape = in_tensor_type->shape;
+
+  return QuantizeLower(data, output_scale, output_zero_point, input_shape, quantize_attrs);
 }

 RELAY_REGISTER_OP("qnn.quantize")
@@ -58,11 +58,6 @@ Expr RequantizeLower(const Expr& input_tensor, const Expr& input_scale,
                      const Expr& input_zero_point, const Expr& output_scale,
                      const Expr& output_zero_point, const RequantizeAttrs* param,
                      const Array<IndexExpr>& input_shape, const DataType& out_dtype) {
-  float input_scale_float = GetScalarFromConstant<float>(input_scale);
-  float output_scale_float = GetScalarFromConstant<float>(output_scale);
-  double double_multiplier =
-      static_cast<double>(input_scale_float) / static_cast<double>(output_scale_float);
-
   DataType hp_dtype = DataType::Int(64);

   auto tensor = Cast(input_tensor, hp_dtype);

@@ -72,12 +67,35 @@ Expr RequantizeLower(const Expr& input_tensor, const Expr& input_scale,
     tensor = Subtract(tensor, Cast(input_zero_point, hp_dtype));
   }

-  // 2) If the input and output scales are same, we can skip the fixed point multiplication.
+  // 2) If the input and output scales are same, we can skip the fixed point multiplication. Check
+  // if the input scale is per-tensor or per-channel. If it is per-tensor, there is a single scale
+  // for the whole tensor. For per-channel (aka per-axis), there is a vector of scales for the
+  // input tensor. Depending on the quantization type, the fixed point multiplication routine is
+  // called.
   auto scaled_int64_t = tensor;
-  if (!IsEqualScalar(input_scale, output_scale)) {
-    scaled_int64_t =
-        FixedPointMultiply(scaled_int64_t, double_multiplier, input_shape, param->rounding);
+  float output_scale_float = GetScalarFromConstant<float>(output_scale);
+  if (IsConstScalar(input_scale)) {
+    // This is per-tensor quantization. Single scale.
+    float input_scale_float = GetScalarFromConstant<float>(input_scale);
+    double double_multiplier =
+        static_cast<double>(input_scale_float) / static_cast<double>(output_scale_float);
+    // Skip if input and output scales are same.
+    if (!IsEqualScalar(input_scale, output_scale)) {
+      scaled_int64_t =
+          FixedPointMultiply(scaled_int64_t, double_multiplier, input_shape, param->rounding);
+    }
+  } else {
+    // This is per-channel (per-axis) quantization.
+    std::vector<double> double_multipliers;
+    auto input_axis_scales = GetFloatVectorFromConstant(input_scale);
+    for (auto input_axis_scale : input_axis_scales) {
+      double_multipliers.push_back(static_cast<double>(input_axis_scale) /
+                                   static_cast<double>(output_scale_float));
+    }
+    int axis = param->axis;
+    axis = (axis == -1) ? input_shape.size() - 1 : axis;
+    scaled_int64_t = FixedPointMultiplyPerChannel(scaled_int64_t, double_multipliers, input_shape,
+                                                  axis, param->rounding);
  }

   // 3) Add the output zero point.
   auto shifted_int64_t = scaled_int64_t;
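In plain math, the step above computes round((x - input_zp) * input_scale / output_scale) per channel before the output zero point is added. A hedged NumPy reference (the helper name is mine; the UPWARD/TONEAREST rounding distinction of the fixed point path is approximated with np.round, and clipping to the output dtype range is omitted):

import numpy as np

def requantize_ref(x, input_scale, input_zp, output_scale, output_zp, axis=-1):
    s_in = np.asarray(input_scale, dtype='float64')
    if s_in.ndim != 0:
        # Broadcast the per-channel scales along the quantization axis.
        bcast = [1] * x.ndim
        bcast[axis] = -1
        s_in = s_in.reshape(bcast)
    multiplier = s_in / float(output_scale)  # the double_multiplier(s) above
    return np.round((x - input_zp) * multiplier) + output_zp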
@@ -157,16 +175,24 @@ bool RequantizeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
         in_dtype == DataType::Int(32))
     << "Input type should be one of [int8, uint8, int32] but was " << in_dtype;

-  // Check the types of scale and zero points.
-  CHECK(IsScalarType(types[1], DataType::Float(32)));  // input_scale
-  CHECK(IsScalarType(types[2], DataType::Int(32)));    // input_zero_point
+  const RequantizeAttrs* requantize_attrs = attrs.as<RequantizeAttrs>();
+  int axis = requantize_attrs->axis;
+  axis = (axis == -1) ? data->shape.size() - 1 : axis;
+  CHECK_LT(axis, static_cast<int>(data->shape.size()))
+      << "axis " << requantize_attrs->axis << " is out of range";
+  CHECK_GE(axis, 0)
+      << "axis " << requantize_attrs->axis << " is out of range";
+
+  // Check and assign types for scale and zero points.
+  AssignType(types[1], DataType::Float(32), data->shape[axis], reporter);  // input_scale
+  AssignType(types[2], DataType::Int(32), data->shape[axis], reporter);    // input_zero_pt
+
+  // For now, requantize output tensor is limited to full tensor uniform quantization.
   CHECK(IsScalarType(types[3], DataType::Float(32)));  // output_scale
   CHECK(IsScalarType(types[4], DataType::Int(32)));    // output_zero_point

   const Array<tvm::Expr> oshape = data->shape;
   // assign output type
-  const RequantizeAttrs* param = attrs.as<RequantizeAttrs>();
-  auto out_dtype = param->out_dtype;
+  auto out_dtype = requantize_attrs->out_dtype;
   CHECK(out_dtype == DataType::Int(8) ||
         out_dtype == DataType::UInt(8) ||
         out_dtype == DataType::Int(32))

@@ -178,8 +204,9 @@ bool RequantizeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
 // Positional relay function to create qnn requantize operator
 // used by frontend FFI.
 Expr MakeRequantize(Expr data, Expr input_scale, Expr input_zero_point, Expr output_scale,
-                    Expr output_zero_point, std::string rounding, DataType out_dtype) {
+                    Expr output_zero_point, int axis, std::string rounding, DataType out_dtype) {
   auto attrs = make_object<RequantizeAttrs>();
+  attrs->axis = axis;
   attrs->rounding = std::move(rounding);
   attrs->out_dtype = std::move(out_dtype);
   static const Op& op = Op::Get("qnn.requantize");
@@ -75,8 +75,8 @@ std::pair<int32_t, int32_t> GetFixedPointMultiplierShift(
   return std::make_pair(significand, exponent);
 }

-Expr FixedPointMultiply(Expr tensor, double multiplier,
-                        const Array<IndexExpr>& input_shape, const std::string& rounding) {
+Expr FixedPointMultiply(Expr tensor, double multiplier, const Array<IndexExpr>& input_shape,
+                        const std::string& rounding) {
   // Choose high precision datatype to be int64. This is for avoiding overflow
   // in multiplication of two int32 values.
   DataType hp_dtype = DataType::Int(64);

@@ -133,6 +133,90 @@ Expr FixedPointMultiply(Expr tensor, double multiplier,
   return tensor;
 }

+Expr FixedPointMultiplyPerChannel(Expr tensor, std::vector<double> multipliers,
+                                  const Array<IndexExpr>& input_shape, int channel_axis,
+                                  const std::string& rounding) {
+  // Get the n dim. This will be used to expand the multiplier to match the axis.
+  size_t n_dim = input_shape.size();
+
+  // Get the num of channels/axis along which the tensor was quantized.
+  int64_t n_channels = (int64_t)multipliers.size();
+
+  // Choose high precision datatype to be int64. This is for avoiding overflow
+  // in multiplication of two int32 values.
+  DataType hp_dtype = DataType::Int(64);
+
+  // 1) Calculating the integer multiplier and integer shift. These are calculated per axis/per
+  // channel.
+  std::vector<int32_t> fixed_pt_multipliers, lshifts, rshifts;
+  for (auto multiplier : multipliers) {
+    int32_t fixed_pt_multiplier, shift;
+    std::tie(fixed_pt_multiplier, shift) = GetFixedPointMultiplierShift(multiplier);
+    int lshift = shift > 0 ? shift : 0;
+    int rshift = shift > 0 ? 0 : -shift;
+    fixed_pt_multipliers.push_back(fixed_pt_multiplier);
+    lshifts.push_back(lshift);
+    rshifts.push_back(rshift);
+  }
+
+  // 2) Multiply the integer multiplier. Convert the left shifts into an expr and multiply.
+  auto lshift_expr = MakeConstantTensor(hp_dtype, {n_channels}, lshifts);
+  auto exp_lshift_expr = ExpandBiasToMatchAxis(lshift_expr, n_dim, {channel_axis});
+  tensor = LeftShift(tensor, exp_lshift_expr);
+
+  // 3) Perform the multiplication in higher precision.
+  // The scalar is a fixed point value of int32 where the decimal point is
+  // between bits 31 and 30. After multiplying with input_tensor, the result
+  // is in int64 where the decimal point is sitting between bits 31 and 30
+  // (from the right, rightmost bit is bit 0). The computation is performed in
+  // higher precision to avoid overflow in multiplying two int32 values.
+  auto fixed_pt_multiplier_expr = MakeConstantTensor(hp_dtype, {n_channels}, fixed_pt_multipliers);
+  auto exp_fixed_pt_multiplier_expr =
+      ExpandBiasToMatchAxis(fixed_pt_multiplier_expr, n_dim, {channel_axis});
+  tensor = Multiply(tensor, exp_fixed_pt_multiplier_expr);
+
+  // 4) Find the rounding scalar. This depends on where the final decimal point sits. As we will be
+  // right shifting the multiplied_t, we need to first calculate the total_rshift. Further, we can
+  // calculate the pos and neg rounding offset.
+  std::vector<int64_t> pos_rounding_values, neg_rounding_values, total_rshifts;
+  for (auto rshift : rshifts) {
+    int total_rshift = rshift + 31;
+    total_rshifts.push_back(total_rshift);
+    pos_rounding_values.push_back((1ll << (total_rshift - 1)));
+    neg_rounding_values.push_back((1ll << (total_rshift - 1)) - 1);
+  }
+
+  // Make a Relay expr from positive and negative rounding offset values.
+  auto pos_rounding_value_expr = MakeConstantTensor(hp_dtype, {n_channels}, pos_rounding_values);
+  auto exp_pos_rounding_value_expr =
+      ExpandBiasToMatchAxis(pos_rounding_value_expr, n_dim, {channel_axis});
+  auto neg_rounding_value_expr = MakeConstantTensor(hp_dtype, {n_channels}, neg_rounding_values);
+  auto exp_neg_rounding_value_expr =
+      ExpandBiasToMatchAxis(neg_rounding_value_expr, n_dim, {channel_axis});
+
+  Expr round_scalar;
+  if (rounding == "UPWARD") {
+    round_scalar = exp_pos_rounding_value_expr;
+  } else if (rounding == "TONEAREST") {
+    // To satisfy the where op shape requirements, the rounding values are broadcast.
+    auto pos_rounder = MakeBroadCastTo(exp_pos_rounding_value_expr, input_shape);
+    auto neg_rounder = MakeBroadCastTo(exp_neg_rounding_value_expr, input_shape);
+
+    auto zero_t = Zeros(input_shape, hp_dtype);
+    round_scalar = Where(GreaterEqual(tensor, zero_t), pos_rounder, neg_rounder);
+  } else {
+    LOG(FATAL) << "Rounding mode " << rounding << " not supported.";
+  }
+
+  // Add the rounding scalar.
+  tensor = Add(tensor, round_scalar);
+
+  // 5) Simply right shift the result to get the final output.
+  auto total_rshift_expr = MakeConstantTensor(hp_dtype, {n_channels}, total_rshifts);
+  auto exp_total_rshift_expr = ExpandBiasToMatchAxis(total_rshift_expr, n_dim, {channel_axis});
+  tensor = RightShift(tensor, exp_total_rshift_expr);
+
+  return tensor;
+}
+
 }  // namespace qnn
 }  // namespace relay
 }  // namespace tvm
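To make steps 1-5 concrete, here is a hedged NumPy sketch. It assumes GetFixedPointMultiplierShift decomposes each double into a Q31 significand and a power-of-two shift (math.frexp stands in for that decomposition; the function names are mine, and edge cases such as multiplier saturation are ignored):

import math
import numpy as np

def fixed_point_multiplier_shift(multiplier):
    # Decompose multiplier as significand * 2**exponent, with the
    # significand stored as a Q31 fixed point int32.
    if multiplier == 0.0:
        return 0, 0
    significand, exponent = math.frexp(multiplier)
    significand_q31 = int(round(significand * (1 << 31)))
    if significand_q31 == (1 << 31):  # rounding overflowed the Q31 range
        significand_q31 //= 2
        exponent += 1
    return significand_q31, exponent

def fixed_point_multiply_per_channel(x, multipliers, axis, rounding='UPWARD'):
    params = [fixed_point_multiplier_shift(m) for m in multipliers]
    bcast = [1] * x.ndim
    bcast[axis] = -1
    fp_mult = np.array([p[0] for p in params], dtype=np.int64).reshape(bcast)
    shift = np.array([p[1] for p in params], dtype=np.int64).reshape(bcast)
    lshift = np.where(shift > 0, shift, 0)             # step 2
    total_rshift = np.where(shift > 0, 0, -shift) + 31
    t = (x.astype(np.int64) << lshift) * fp_mult       # step 3
    pos = np.int64(1) << (total_rshift - 1)            # step 4: rounding offsets
    neg = pos - 1
    round_vals = pos if rounding == 'UPWARD' else np.where(t >= 0, pos, neg)
    return (t + round_vals) >> total_rshift            # step 5

Splitting the shift into separate left and right parts keeps the significand in Q31 form regardless of whether the multiplier is above or below 1.0.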
@@ -30,6 +30,7 @@
 #include <tvm/relay/qnn/attrs.h>
 #include <limits>
 #include <string>
+#include <vector>
 #include <utility>

 namespace tvm {
@@ -125,18 +126,78 @@ Expr FixedPointMultiply(Expr tensor, double multiplier, const Array<IndexExpr>&
                         const std::string& rounding);

 /*
+ * \brief Fixed point multiplication between an integer tensor and a vector of floating point
+   scalars, where the input tensor is per-axis/per-channel quantized.
+ * \param tensor The quantized input tensor of dtype int64.
+ * \param multiplier The vector of scalar multipliers, one per channel.
+ * \param input_shape Shape of the input tensor.
+ * \param channel_axis The channel axis along which the input tensor is quantized. Default value
+   is -1, which corresponds to the last axis.
+ * \param rounding "UPWARD" or "TONEAREST". The rounding direction when the value
+   is midway between two representable values.
+ * \return The sequence of Relay ops for fixed point multiplication.
+
+ * \note The original computation is scale_fp32 * quantized_tensor. To convert
+ *       into an integer computation, the multiplication with the fp32 vector
+ *       can be replaced by multiplication with an int vector and then right
+ *       shifting the result. This approximates the floating point computation
+ *       with a fixed point computation.
+ *
+ *       Computation of fixed point multiplication consists of the following steps:
+ *       1) Multiply the fixed point multiplier with the quantized tensor.
+ *       2) Round the result.
+ *       3) Right shift the result.
+ */
+Expr FixedPointMultiplyPerChannel(Expr tensor, std::vector<double> multiplier,
+                                  const Array<IndexExpr>& input_shape, int channel_axis,
+                                  const std::string& rounding);
+
+/*
 * \brief Checks whether an expr type is scalar of a given data type.
 * \param expr_type The type of expr to be checked.
 * \param dtype The expected dtype.
 * \return True if the type is a scalar of given dtype
 */
 static inline bool IsScalarType(const Type& expr_type, const DataType& dtype) {
-  const auto* scale = expr_type.as<TensorTypeNode>();
-  CHECK_EQ(scale->shape.size(), 0);
-  CHECK(scale->dtype == dtype) << "Expected " << dtype << " but got " << scale->dtype;
+  const auto* tensor_type = expr_type.as<TensorTypeNode>();
+  CHECK_EQ(tensor_type->shape.size(), 0);
+  CHECK(tensor_type->dtype == dtype) << "Expected " << dtype << " but got " << tensor_type->dtype;
   return true;
 }

+/*
+ * \brief Checks and assigns types to scale and zero points.
+ * \param expr_type The type of expr to be checked.
+ * \param dtype The expected dtype.
+ * \param shape The shape at the C dim of the original tensor.
+ * \param reporter The type reporter of the original InferType call.
+ */
+static inline void AssignType(const Type& expr_type, const DataType& dtype, const IndexExpr& shape,
+                              const TypeReporter& reporter) {
+  // Scale/Zero_points can be either const scalar or a vector with C axis num elems.
+  const auto* tensor_type = expr_type.as<TensorTypeNode>();
+  const auto tensor_dtype = tensor_type->dtype;
+  CHECK(tensor_dtype == dtype) << "Expected type is " << dtype << " but received " << tensor_dtype;
+  if (tensor_type->shape.size() != 0) {
+    reporter->Assign(expr_type, TensorTypeNode::make({shape}, tensor_type->dtype));
+  }
+}
+
+static inline std::vector<float> GetFloatVectorFromConstant(const Expr& expr) {
+  const auto* n = expr.as<ConstantNode>();
+  std::vector<float> vals;
+  CHECK(n) << "Expr must be a constant expr - " << AsText(expr, false);
+  int64_t num_elems = 1;
+  auto shape = n->data.Shape();
+  for (size_t i = 0; i < shape.size(); i++) {
+    num_elems *= shape[i];
+  }
+  for (int64_t i = 0; i < num_elems; i++) {
+    vals.push_back(static_cast<float*>(n->data->data)[i]);
+  }
+  return vals;
+}
+
 }  // namespace qnn
 }  // namespace relay
 }  // namespace tvm
@@ -20,13 +20,15 @@ import numpy as np
 from tvm import relay
 from tvm.contrib import graph_runtime

-def quantize_test_driver(in_dtype, quant_args, out_dtype, in_data, verify_output_data):
+def quantize_test_driver(in_dtype, quant_args, axis, out_dtype, in_data, verify_output_data):
     shape = in_data.shape
     input_data = relay.var("input_data", shape=shape, dtype=in_dtype)
-    output_zero_point = relay.const(quant_args['out_zero_point'], 'int32')
-    output_scale = relay.const(quant_args['out_scale'], 'float32')
+    output_zero_point = relay.const(quant_args['out_zero_point'])
+    output_scale = relay.const(quant_args['out_scale'])
     quantized_output = relay.qnn.op.quantize(input_data, output_scale=output_scale,
-                                             output_zero_point=output_zero_point,out_dtype=out_dtype)
+                                             output_zero_point=output_zero_point,
+                                             axis=axis,
+                                             out_dtype=out_dtype)
     mod = relay.Function(relay.analysis.free_vars(quantized_output), quantized_output)
     mod = relay.Module.from_expr(mod)
     with relay.build_config(opt_level=3):

@@ -46,9 +48,9 @@ def test_float32_to_uint8():
     output = np.array([0, 1, 2, 3, 4, 251, 252, 253, 254, 255]) \
         .astype('uint8') \
         .reshape((2,5))
-    quant_args = {"out_zero_point":127, "out_scale":0.5}
-    quantize_test_driver(in_dtype='float32', quant_args=quant_args, out_dtype='uint8', in_data=data,
-                         verify_output_data=output)
+    quant_args = {"out_zero_point":np.int32(127), "out_scale":np.float32(0.5)}
+    quantize_test_driver(in_dtype='float32', quant_args=quant_args, axis=-1, out_dtype='uint8',
+                         in_data=data, verify_output_data=output)

 def test_float32_to_int8():
     data = np.array([-63.5, -63, -62.5, -62, -61.5, 62, 62.5, 63, 63.5, 64]) \

@@ -57,10 +59,37 @@ def test_float32_to_int8():
     output = np.array([-128, -127, -126, -125, -124, 123, 124, 125, 126, 127]) \
         .astype('int8') \
         .reshape((2,5))
-    quant_args = {"out_zero_point":-1, "out_scale":0.5}
-    quantize_test_driver(in_dtype='float32', quant_args=quant_args, out_dtype='int8', in_data=data,
-                         verify_output_data=output)
+    quant_args = {"out_zero_point":np.int32(-1), "out_scale":np.float32(0.5)}
+    quantize_test_driver(in_dtype='float32', quant_args=quant_args, axis=-1, out_dtype='int8',
+                         in_data=data, verify_output_data=output)
+
+def test_channelwise_axis_0():
+    data = np.array([-63.5, -63, -62.5, -62, -61.5, 30, 31, 31.5, 31.75, 32]) \
+        .astype('float32') \
+        .reshape((2,5))
+    output = np.array([0, 1, 2, 3, 4, 243, 247, 249, 250, 251]) \
+        .astype('uint8') \
+        .reshape((2,5))
+    quant_args = {"out_zero_point" : np.array([127, 123]).astype('int32'),
+                  "out_scale" : np.array([0.5, 0.25]).astype('float32')}
+
+    quantize_test_driver(in_dtype='float32', quant_args=quant_args, axis=0, out_dtype='uint8',
+                         in_data=data, verify_output_data=output)
+
+def test_channelwise_axis_1():
+    data = np.transpose(np.array([-63.5, -63, -62.5, -62, -61.5, 30, 31, 31.5, 31.75, 32]) \
+        .astype('float32').reshape((2,5)))
+    output = np.transpose(np.array([0, 1, 2, 3, 4, 243, 247, 249, 250, 251]) \
+        .astype('uint8').reshape((2,5)))
+    quant_args = {"out_zero_point" : np.array([127, 123]).astype('int32'),
+                  "out_scale" : np.array([0.5, 0.25]).astype('float32')}
+
+    quantize_test_driver(in_dtype='float32', quant_args=quant_args, axis=1, out_dtype='uint8',
+                         in_data=data, verify_output_data=output)

 if __name__ == "__main__":
     test_float32_to_uint8()
     test_float32_to_int8()
+    test_channelwise_axis_0()
+    test_channelwise_axis_1()
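As a quick cross-check of the axis=0 goldens above, the same numbers fall out of a standalone NumPy computation (not part of the test suite):

import numpy as np

x = np.array([[-63.5, -63, -62.5, -62, -61.5],
              [30, 31, 31.5, 31.75, 32]], dtype='float32')
scale = np.array([0.5, 0.25], dtype='float32').reshape(2, 1)  # broadcast along axis 0
zp = np.array([127, 123], dtype='float32').reshape(2, 1)
q = np.clip(np.round(x / scale + zp), 0, 255).astype('uint8')
# e.g. row 0: -63.5 / 0.5 + 127 = 0; row 1: 30 / 0.25 + 123 = 243
assert (q == [[0, 1, 2, 3, 4], [243, 247, 249, 250, 251]]).all()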
@@ -34,15 +34,27 @@ def verify(mod, goldens):
         np.testing.assert_equal(res, golden_output)

 def get_mod(data_shape, data_dtype, out_dtype, input_scale, output_scale,
-            input_zero_point=0, output_zero_point=0, rounding="TONEAREST"):
+            input_zero_point=0, output_zero_point=0, rounding="TONEAREST",
+            axis=0):
     quantized_data = relay.var("quantized_data", shape=data_shape,
                                dtype=data_dtype)
+    if isinstance(input_scale, float):
+        input_scale_expr = relay.const(input_scale, 'float32')
+    else:
+        input_scale_expr = relay.const(np.array(input_scale).astype('float32'))
+
+    if isinstance(input_zero_point, float):
+        input_zero_point_expr = relay.const(input_zero_point, 'int32')
+    else:
+        input_zero_point_expr = relay.const(np.array(input_zero_point).astype('int32'))
+
     mod = relay.qnn.op.requantize(
         quantized_data,
-        input_scale=relay.const(input_scale, 'float32'),
-        input_zero_point=relay.const(input_zero_point, 'int32'),
+        input_scale=input_scale_expr,
+        input_zero_point=input_zero_point_expr,
         output_scale=relay.const(output_scale, 'float32'),
         output_zero_point=relay.const(output_zero_point, 'int32'),
+        axis=axis,
         rounding=rounding,
         out_dtype=out_dtype)
@@ -240,9 +252,70 @@ def test_zero_point():
         golden_output = np.subtract(golden_output, 1)
         verify(mod, (golden_data, golden_output))

+def test_per_channel_same_scale():
+    # Have same scales, everything within range
+    golden_data = np.arange(-5, 5, 1).astype('int32').reshape((5,2))
+    golden_output = golden_data
+
+    for rounding in roundings:
+        mod = get_mod(data_shape=(5, 2),
+                      data_dtype='int32',
+                      out_dtype="int8",
+                      input_scale=[0.5, 0.5],
+                      output_scale=0.5,
+                      axis=1,
+                      rounding=rounding)
+        verify(mod, (golden_data, golden_output))
+
+    # Change axis
+    golden_data = np.arange(-10, 10, 1).astype('int32').reshape((2,2,5))
+    golden_output = golden_data
+
+    for rounding in roundings:
+        mod = get_mod(data_shape=(2, 2, 5),
+                      data_dtype='int32',
+                      out_dtype="int8",
+                      input_scale=[0.5, 0.5],
+                      output_scale=0.5,
+                      axis=1,
+                      rounding=rounding)
+        verify(mod, (golden_data, golden_output))
+
+def test_per_channel_different_scale():
+    # Have different scales, everything within range
+    golden_data = np.arange(-5, 5, 1).astype('int32').reshape((5,2))
+    golden_output = np.array([-5, -2, -3, -1, -1, 0, 1, 1, 3, 2]).reshape((5, 2))
+
+    for rounding in roundings:
+        mod = get_mod(data_shape=(5, 2),
+                      data_dtype='int32',
+                      out_dtype="int8",
+                      input_scale=[0.5, 0.25],
+                      output_scale=0.5,
+                      axis=1,
+                      rounding=rounding)
+        verify(mod, (golden_data, golden_output))
+
+    # Change axis
+    golden_data = np.arange(-20, 20, 2).astype('int32').reshape((2,2,5))
+    golden_output = np.array([-20, -18, -16, -14, -12, -5, -4, -3, -2, -1, 0, 2, 4, 6, 8, 5, 6, 7,
+                              8, 9]).reshape((2, 2, 5))
+
+    for rounding in roundings:
+        mod = get_mod(data_shape=(2, 2, 5),
+                      data_dtype='int32',
+                      out_dtype="int8",
+                      input_scale=[0.5, 0.25],
+                      output_scale=0.5,
+                      axis=1,
+                      rounding=rounding)
+        verify(mod, (golden_data, golden_output))
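The column-1 goldens in test_per_channel_different_scale follow directly from multiplier = input_scale / output_scale; a standalone NumPy check (not part of the test suite):

import numpy as np

data = np.arange(-5, 5, 1).reshape((5, 2))
multipliers = np.array([0.5 / 0.5, 0.25 / 0.5])  # per-column input_scale / output_scale
requantized = np.round(data * multipliers).astype('int32')
# column 0 passes through unchanged; column 1 is halved, e.g. round(-4 * 0.5) = -2
assert (requantized == np.array([-5, -2, -3, -1, -1, 0, 1, 1, 3, 2]).reshape((5, 2))).all()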
 if __name__ == "__main__":
     test_same_scale()
     test_downscale()
     test_upscale()
     test_saturation()
     test_zero_point()
+    test_per_channel_same_scale()
+    test_per_channel_different_scale()