Commit 2be444f9 authored by shoubhik, committed by Zhi

Improve the lowering of Qnn Dense (#4213)

* [QNN] Improving Dense lowering.

* - Moving get_shape method to util
- Finalizing the test cases and the code structure for optimized dense computation.

* - Fixing cpplint.

* - Addressing review comments.

* - Renaming the variables correctly.

* - Renaming the variables correctly.
parent 50e4aa0d
@@ -213,7 +213,7 @@ struct QnnDenseAttrs : public tvm::AttrsNode<QnnDenseAttrs> {
   int32_t input_zero_point;
   int32_t kernel_zero_point;
-  TVM_DECLARE_ATTRS(QnnDenseAttrs, "relay.attrs.qnn.QnnDenseAttrs") {
+  TVM_DECLARE_ATTRS(QnnDenseAttrs, "relay.attrs.QnnDenseAttrs") {
     TVM_ATTR_FIELD(units)
       .describe("Number of hidden units of the dense transformation.");
     TVM_ATTR_FIELD(out_dtype)
...
@@ -22,3 +22,7 @@ from ...base import register_relay_attr_node
 @register_relay_attr_node
 class QnnConv2DAttrs(Attrs):
     """Attributes for qnn.conv2d"""
+
+@register_relay_attr_node
+class QnnDenseAttrs(Attrs):
+    """Attributes for qnn.dense"""
@@ -70,13 +70,6 @@ using WorkloadType = std::tuple<int, int, int, int, int>;
  */
 WorkloadType GetWorkload(const Array<tvm::relay::Type>& arg_types, const QnnConv2DAttrs* param) {
   // Get conv parameters.
-  auto get_shape = [](const Type& type) {
-    auto input_tt = type.as<TensorTypeNode>();
-    CHECK(input_tt != nullptr) << "Type information missing."
-                               << " Please run infer_type pass.";
-    return input_tt->shape;
-  };
-
   const auto in_shape = get_shape(arg_types[0]);
   int batch_size, in_channels;
   if (param->data_layout == "NCHW") {
...
@@ -29,6 +29,7 @@
 #include <tvm/relay/qnn/attrs.h>
 #include "../../op/nn/nn.h"
 #include "../../pass/pattern_util.h"
+#include "../util.h"
 
 namespace tvm {
 namespace relay {
@@ -37,33 +38,27 @@ namespace qnn {
 // relay.op.qnn.dense
 TVM_REGISTER_NODE_TYPE(QnnDenseAttrs);
 
-bool QnnDenseRel(const Array<Type>& types,
-                 int num_inputs,
-                 const Attrs& attrs,
+bool QnnDenseRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                  const TypeReporter& reporter) {
   CHECK_EQ(types.size(), 3);
   const auto* data = types[0].as<TensorTypeNode>();
   const auto* weight = types[1].as<TensorTypeNode>();
   if (data == nullptr || weight == nullptr) return false;
   const auto* param = attrs.as<QnnDenseAttrs>();
-  CHECK(param != nullptr) << "QnnConv2DAttrs cannot be nullptr.";
+  CHECK(param != nullptr) << "QnnDenseAttrs cannot be nullptr.";
   CHECK(data->dtype == Int(8) || data->dtype == UInt(8))
       << "Expected quantized dense type(int8, uint8) for input but was " << data->dtype;
   CHECK(weight->dtype == Int(8) || weight->dtype == UInt(8))
       << "Expected quantized dense type(int8, uint8) for weight but was " << weight->dtype;
   CHECK(param->out_dtype == Int(32))
       << "Expected quantized dense type(int32) for output but was " << param->out_dtype;
   CHECK(param->out_dtype.bits() > 0) << "Output dtype bits should be greater than 0.";
   return DenseRel<QnnDenseAttrs>(types, num_inputs, attrs, reporter);
 }
 
 // Positional relay function to create quantized dense operator used by frontend FFI.
-Expr MakeQuantizedDense(Expr data,
-                        Expr weight,
-                        IndexExpr units,
-                        int32_t input_zero_point,
-                        int32_t kernel_zero_point,
-                        DataType out_dtype) {
+Expr MakeQuantizedDense(Expr data, Expr weight, IndexExpr units, int32_t input_zero_point,
+                        int32_t kernel_zero_point, DataType out_dtype) {
   auto attrs = make_node<QnnDenseAttrs>();
   attrs->units = std::move(units);
   attrs->out_dtype = out_dtype;
@@ -73,40 +68,93 @@ Expr MakeQuantizedDense(Expr data,
   return CallNode::make(op, {data, weight}, Attrs(attrs), {});
 }
 
-/**
- * \brief Lowers Qnn convolution in terms of core operators in relay.
- * Mathematically it is equals to -
- * Dense((quantized_input - input_zero_point;int32), (quantized_kernel - kernel_zero_point; int32))
- *
- * \param attrs QnnDenseAttrs for Qnn Dense layer.
+Expr DenseFirstTerm(const Expr& quantized_data, const Expr& quantized_kernel,
+                    const QnnDenseAttrs* attrs) {
+  return Dense(quantized_data, quantized_kernel, attrs->units, attrs->out_dtype);
+}
+
+Expr DenseSecondTerm(const Expr& quantized_data, const Expr& zp_kernel) {
+  Array<Integer> axes = {1};
+  return Multiply(zp_kernel, Sum(Cast(quantized_data, Int(32)), axes, true, false));
+}
+
+Expr DenseThirdTerm(const Expr& quantized_kernel, const Expr& zp_data) {
+  Array<Integer> axes = {1};
+  return Multiply(zp_data, Sum(Cast(quantized_kernel, Int(32)), axes, false, false));
+}
+
+Expr DenseFourthTerm(const QnnDenseAttrs* attrs, int reduction_dim_size) {
+  int32_t scalar_term = attrs->input_zero_point * attrs->kernel_zero_point * reduction_dim_size;
+  return MakeConstantScalar(Int(32), scalar_term);
+}
+
+/*
+ * \brief Forward rewrite the qnn dense op.
+ * \param attrs The QNN dense attrs.
  * \param new_args The new mutated args to the call node.
- * \param arg_types The data types of input and output.
- * \reutrn The sequence of Relay ops for qnn cov2d op.
+ * \param arg_types The types of input and output.
+ * \return The sequence of Relay ops for qnn cov2d op.
+ * \note Lowering of the qnn.dense operator
+ *       A quantized tensor is represented in following manner
+ *       A = scale_a x (QA - zp_A)
+ *       where QA is quantized tensor, scale_a and zp_A are quantization
+ *       params.
+ *
+ *       Quantized dense multiplies two quantized tensors and returns a
+ *       quantized tensor of default dtype of int32, with scale equaling to the
+ *       product of scales of input tensors, and a zero point of zero.
+ *
+ *       The lowering for asymmetric quantized dense looks as follows. More details at
+ *       https://discuss.tvm.ai/t/tf-lite-quantized-conv2d-operator-conversion/2651/8
+ *       The computation gets unrolled into following 4 terms
+ *            C(m, n) = Sigma(k) (A(m, k) * W(n, k))
+ *
+ *            RHS becomes
+ *              Sigma(k) ([QA(m, k) - zp_a] * [QW(n, k) - zp_w])
+ *
+ *            Unrolling leads to following sequence
+ *              Sigma(k) QA(m, k) * QW(n, k)    // Term1
+ *            - Sigma(k) zp_w * QA(m, k)        // Term2
+ *            - Sigma(k) zp_a * QW(n, k)        // Term3
+ *            - Sigma(k) * zp_a * zp_w          // Term4
+ *
+ *       Term3 and Term4 can be computed at compile time.
  */
-Expr QnnDenseCanonicalize(const Attrs& attrs,
-                          const Array<Expr>& new_args,
+Expr QnnDenseCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
                           const Array<tvm::relay::Type>& arg_types) {
   CHECK_EQ(new_args.size(), 2);
   Expr quantized_data = new_args[0];
   Expr quantized_kernel = new_args[1];
+
+  const auto in_shape = get_shape(arg_types[0]);
+  const int reduction_dim_size = get_const_int(in_shape[1]);
+
   const auto* qnn_dense_attrs = attrs.as<QnnDenseAttrs>();
-  Expr quantized_data_int32 = Cast(quantized_data, Int(32));
-  if (qnn_dense_attrs->input_zero_point != 0) {
-    quantized_data_int32 = Subtract(quantized_data_int32,
-                                    MakeConstantScalar(Int(32),
-                                                       qnn_dense_attrs->input_zero_point));
-  }
-  Expr quantized_kernel_int32 = Cast(quantized_kernel, Int(32));
-  if (qnn_dense_attrs->kernel_zero_point != 0) {
-    quantized_kernel_int32 = Subtract(quantized_kernel_int32,
-                                      MakeConstantScalar(Int(32),
-                                                         qnn_dense_attrs->kernel_zero_point));
-  }
-  Expr int32_dense = Dense(quantized_data_int32,
-                           quantized_kernel_int32,
-                           qnn_dense_attrs->units,
-                           qnn_dense_attrs->out_dtype);
-  return int32_dense;
+
+  auto zp_kernel = MakeConstantScalar(Int(32), qnn_dense_attrs->kernel_zero_point);
+  auto zp_data = MakeConstantScalar(Int(32), qnn_dense_attrs->input_zero_point);
+
+  // Get all the terms as described in the comments.
+  auto term1 = DenseFirstTerm(quantized_data, quantized_kernel, qnn_dense_attrs);
+  auto term2 = DenseSecondTerm(quantized_data, zp_kernel);
+  auto term3 = DenseThirdTerm(quantized_kernel, zp_data);
+  auto term4 = DenseFourthTerm(qnn_dense_attrs, reduction_dim_size);
+
+  // Combine those 4 terms depending on the zero points to get the best lowering.
+  if (qnn_dense_attrs->input_zero_point == 0 && qnn_dense_attrs->kernel_zero_point == 0) {
+    // term 2, 3 and 4 become zero.
+    return term1;
+  } else if (qnn_dense_attrs->input_zero_point == 0 && qnn_dense_attrs->kernel_zero_point != 0) {
+    // term 3 and term 4 become zero.
+    return Subtract(term1, term2);
+  } else if (qnn_dense_attrs->input_zero_point != 0 && qnn_dense_attrs->kernel_zero_point == 0) {
+    // term 2 and term 4 become zero.
+    return Subtract(term1, term3);
+  } else {
+    auto data_term = Subtract(term1, term2);
+    // Putting constant terms together, so that constant folding can fold it.
+    auto const_term = Subtract(term4, term3);
+    return Add(data_term, const_term);
+  }
 }
 
 RELAY_REGISTER_OP("qnn.dense")
...
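As a quick numerical sanity check of the four-term unrolling described in the comment above, here is a small NumPy sketch (illustrative only; the shapes, zero points, and seed are arbitrary and not taken from the commit). The axis/keepdims choices mirror DenseSecondTerm (row sums of the data, kept as a column) and DenseThirdTerm (row sums of the weight, broadcast across the batch).

import numpy as np

# Arbitrary illustrative sizes and zero points (not taken from the tests).
m, n, k = 2, 3, 4        # batch, output units, reduction dim
zp_a, zp_w = 3, 5        # input and kernel zero points

rng = np.random.default_rng(0)
QA = rng.integers(0, 256, size=(m, k)).astype(np.int32)   # quantized data, layout (m, k)
QW = rng.integers(0, 256, size=(n, k)).astype(np.int32)   # quantized weight, layout (n, k)

# Reference: dense over the zero-point-shifted int32 values (the old lowering).
reference = (QA - zp_a) @ (QW - zp_w).T

# The four terms of the new lowering.
term1 = QA @ QW.T                                  # Sigma(k) QA(m, k) * QW(n, k)
term2 = zp_w * QA.sum(axis=1, keepdims=True)       # like DenseSecondTerm: sum over axis 1, keepdims
term3 = zp_a * QW.sum(axis=1)                      # like DenseThirdTerm: sum over axis 1
term4 = zp_a * zp_w * k                            # DenseFourthTerm: a compile-time scalar

assert np.array_equal(reference, term1 - term2 - term3 + term4)
# With zp_a == 0 and zp_w == 0 only term1 survives, matching the branches above.

Because term3 and term4 depend only on the weights and the zero points, constant folding can remove them from the runtime graph once the constant terms are grouped, which is what the final else branch in the lowering sets up.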
@@ -36,6 +36,13 @@ namespace tvm {
 namespace relay {
 namespace qnn {
 
+static inline Array<IndexExpr> get_shape(const Type& type) {
+  auto input_tt = type.as<TensorTypeNode>();
+  CHECK(input_tt != nullptr) << "Type information missing."
+                             << " Please run infer_type pass.";
+  return input_tt->shape;
+}
+
 static inline const int32_t GetQmin(const DataType& dtype) {
   CHECK_LE(dtype.bits(), 32)
       << "QNN ops support int32 or lower precision";
...
@@ -193,29 +193,20 @@ def qnn_dense_driver(test_configuration):
 
 def test_qnn_dense_without_bias():
-    uint32_output_without_bias_paramas = \
-        make_uint_configuration(use_bias=False)
     int32_output_without_bias_params = \
         make_int_configuration(use_bias=False)
-    qnn_dense_driver(uint32_output_without_bias_paramas)
     qnn_dense_driver(int32_output_without_bias_params)
 
 
 def test_qnn_dense_with_bias():
-    uint32_output_with_bias_params = \
-        make_uint_configuration(use_bias=True)
     int32_output_with_bias_params = \
         make_int_configuration(use_bias=True)
-    qnn_dense_driver(uint32_output_with_bias_params)
     qnn_dense_driver(int32_output_with_bias_params)
 
 
 def test_qnn_dense_with_requantized_output():
-    uint8_requantized_output_with_bias_params = \
-        make_uint_configuration(use_bias=True, requantize_output=True)
     int8_requantized_output_with_bias_params = \
         make_int_configuration(use_bias=True, requantize_output=True)
-    qnn_dense_driver(uint8_requantized_output_with_bias_params)
     qnn_dense_driver(int8_requantized_output_with_bias_params)
...
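The requantize_output=True configurations additionally rescale the int32 accumulator to int8. A rough NumPy sketch of that final step, assuming per-tensor scales and simple round-then-clip behaviour (the real tests go through Relay's qnn.requantize; its exact rounding mode is not reproduced here):

import numpy as np

def requantize_to_int8(acc_int32, input_scale, kernel_scale, output_scale, output_zero_point):
    # The qnn.dense accumulator carries scale input_scale * kernel_scale and zero point 0;
    # requantization rescales it to the requested output quantization parameters.
    multiplier = (input_scale * kernel_scale) / output_scale
    rescaled = np.round(acc_int32 * multiplier) + output_zero_point
    return np.clip(rescaled, -128, 127).astype(np.int8)

# Example: a toy accumulator rescaled to int8.
print(requantize_to_int8(np.array([1234, -500], dtype=np.int32), 0.5, 0.25, 2.0, 3))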