Commit a78adbd5 by Animesh Jain Committed by Tianqi Chen

[QNN] Requantize operator (#3531)

* [Relay] [Quantization] WIP - Common files for the qauntization work.

* [Relay] [Quantization] WIP - Prototyping requantize op.

* Requantize operator implementation.

Requantize converts one quantized tensor representation to another quantized
representation. The PR has following implementation features

- Requantize operator defined in qnn namespace - relay.qnn.requantize
- Lowering of the requantize to exisiting Relay operators
- Integer fixed point implementation of requantize
    - Two rounding modes - FE_UPWARDS (round towards infinity) and
    FE_AWAY_FROM_ZERO (std::round behavior)
- Floating point implementation as well, that can act as reference or can be
used for devices when FP32 computation is not used.
- Unit test cases

Relevant Issue - https://github.com/dmlc/tvm/issues/2351

Credit to TFLite and GemmLowp to provide reference implementations.

* Typo and lint fixes.

* Doc fix.

* Uncommenting the lint script (fixing mistake).

* Modifying the unit tests.

* Moving C++ files into src/relay/qnn

* Moving python files to python/tvm/relay/qnn. Some minor fixes.

* Moving the attrs.h inside the include directory.

* Pushing files that I forgot earlier. Changing util location.

* Incorporating comments. API change. Lint fixes.

* Modifying the GetFixedPointMultiplierShift API as per comments.

* Forgot the dialect change.

* Changing rewrite to qnn_lower.

* Renaming Quantize to Qnn for clarity.

* Remove use_int_domain.

* Incorportaing review comments.

* Adding API doc for QNN dialect.

* Move the qnn_lower pass to transform namespace.

* Moving from expr to module. Adding namespace in C++.

* Minor sentence rewrites. Added qnn namespace.

* Added the API doc.

* Chanding default out_dtype to int8. Adding a test with in/out_dtype as uint8.

* Style fixes. Better error messages.

* Adding documentation.

* More documentation fixes.

* Adding out dtype check for requantize.

* Adding corner case for FP32 to fixed point conversion.

* Adding extra line.

* Documentation fix.

* Adding static inline.

* Incorporating jackwish comment. Removed idtype from requantize lowering.

* Removing Quantize/Dequantize code. Restricting Requantize to (u)int8/int32.

* Style fixes.

* Fix the docs.

* Move to Legalize API.
parent 60607eff
...@@ -202,6 +202,16 @@ This level support backpropagation of broadcast operators. It is temporary. ...@@ -202,6 +202,16 @@ This level support backpropagation of broadcast operators. It is temporary.
tvm.relay.contrib.adaptive_avg_pool2d tvm.relay.contrib.adaptive_avg_pool2d
**Level 11: Dialect Operators**
This level supports dialect operators.
.. autosummary::
:nosignatures:
tvm.relay.qnn.op.requantize
Level 1 Definitions Level 1 Definitions
------------------- -------------------
.. autofunction:: tvm.relay.log .. autofunction:: tvm.relay.log
...@@ -340,3 +350,8 @@ Level 10 Definitions ...@@ -340,3 +350,8 @@ Level 10 Definitions
.. autofunction:: tvm.relay.nn.batch_matmul .. autofunction:: tvm.relay.nn.batch_matmul
.. autofunction:: tvm.relay.contrib.adaptive_max_pool2d .. autofunction:: tvm.relay.contrib.adaptive_max_pool2d
.. autofunction:: tvm.relay.contrib.adaptive_avg_pool2d .. autofunction:: tvm.relay.contrib.adaptive_avg_pool2d
Level 11 Definitions
--------------------
.. autofunction:: tvm.relay.qnn.op.requantize
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tvm/relay/qnn/attrs.h
* \brief Auxiliary attributes for qnn operators.
*/
#ifndef TVM_RELAY_QNN_ATTRS_H_
#define TVM_RELAY_QNN_ATTRS_H_
#include <tvm/attrs.h>
#include <string>
namespace tvm {
namespace relay {
namespace qnn {
/*! \brief Attribute for requantize operator */
struct RequantizeAttrs : public tvm::AttrsNode<RequantizeAttrs> {
double input_scale;
int32_t input_zero_point;
double output_scale;
int32_t output_zero_point;
std::string rounding;
DataType out_dtype;
TVM_DECLARE_ATTRS(RequantizeAttrs, "relay.attrs.RequantizeAttrs") {
TVM_ATTR_FIELD(input_scale)
.describe("The scale of the input tensor.");
TVM_ATTR_FIELD(input_zero_point)
.describe("The zero point of the input tensor.");
TVM_ATTR_FIELD(output_scale)
.describe("The scale of the output tensor.");
TVM_ATTR_FIELD(output_zero_point)
.describe("The zero point of the output tensor.");
TVM_ATTR_FIELD(rounding).set_default("TONEAREST")
.describe("Defines the rounding direction when the value is midway between"
"two representable values. There are two supported modes - UPWARD"
"or TONEAREST. Both modes behave exactly same except at the"
"midpoints between the two representable values. At the midpoint,"
"UPWARD rounds towards positive infinity (for example -1.5 will be"
"rounded to -1). TONEAREST is the standard rounding where the"
"value is rounded away from zero at midpoints (for example, -1.5"
"rounds to -2). More context can be found at following gblic manual"
"https://www.gnu.org/software/libc/manual/html_node/Rounding.html.");
TVM_ATTR_FIELD(out_dtype)
.set_default(NullValue<DataType>())
.describe("Output data type, set to explicit type under mixed precision setting");
}
};
} // namespace qnn
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_QNN_ATTRS_H_
...@@ -53,6 +53,9 @@ from . import frontend ...@@ -53,6 +53,9 @@ from . import frontend
from . import backend from . import backend
from . import quantize from . import quantize
# Dialects
from . import qnn
from .scope_builder import ScopeBuilder from .scope_builder import ScopeBuilder
# Span # Span
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=wildcard-import
"""QNN dialect operators and IR passes."""
from __future__ import absolute_import as _abs
from . import op
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=wildcard-import
"""Neural network related operators."""
from __future__ import absolute_import as _abs
from .qnn import *
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Constructor APIs"""
from ...._ffi.function import _init_api
_init_api("relay.qnn.op._make", __name__)
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#pylint: disable=invalid-name
"""QNN dialect operators."""
from __future__ import absolute_import as _abs
from . import _make
def requantize(data,
input_scale,
input_zero_point,
output_scale,
output_zero_point,
rounding="TONEAREST",
out_dtype="int8"):
r"""Requantized operator.
The requantize operator converts one quantized tensor representation to
another quantized tensor representation. For the output tensor, we are
provided with output scale and zero point. The computation is as follows
Q_output = zp_output + (scale_input)/(scale_output) * (Q_input - zp_input)
Parameters
----------
data : tvm.relay.Expr
The input data to the operator.
input_scale: float
The quantization scale for the input tensor.
input_zero_point: int
The zero point of the input tensor.
output_scale: float
The quantization scale for the output tensor.
output_zero_point: int
The zero point of the output tensor.
rounding : string, optional
Defines the rounding direction when the value is midway between two
representable values.
out_dtype : str, optional
Specifies the output data type.
Returns
-------
result : tvm.relay.Expr
The computed result.
"""
return _make.requantize(data,
input_scale,
input_zero_point,
output_scale,
output_zero_point,
rounding,
out_dtype)
...@@ -394,6 +394,26 @@ inline Expr Variance(Expr data, Expr mean, Array<Integer> axis, bool keepdims, b ...@@ -394,6 +394,26 @@ inline Expr Variance(Expr data, Expr mean, Array<Integer> axis, bool keepdims, b
} }
static inline Expr Where(const Expr& condition, const Expr& x, const Expr& y) {
static const Op& op = Op::Get("where");
return CallNode::make(op, {condition, x, y});
}
static inline Expr GreaterEqual(const Expr& lhs, const Expr& rhs) {
static const Op& op = Op::Get("greater_equal");
return CallNode::make(op, {lhs, rhs}, Attrs(), {});
}
static inline Expr Full(Expr fill_value,
Array<IndexExpr> shape,
DataType dtype) {
auto attrs = make_node<InitOpAttrs>();
attrs->shape = std::move(shape);
attrs->dtype = std::move(dtype);
static const Op& op = Op::Get("full");
return CallNode::make(op, {fill_value}, Attrs(attrs), {});
}
Expr MakeConcatenate(Expr data, int axis); Expr MakeConcatenate(Expr data, int axis);
Expr MakeStridedSlice(Expr data, Array<Integer> begin, Array<Integer> end, Array<Integer> strides); Expr MakeStridedSlice(Expr data, Array<Integer> begin, Array<Integer> end, Array<Integer> strides);
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2019 by Contributors
* \file requantize.cc
* \brief QNN requantize operator.
*/
#include <tvm/relay/analysis.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/qnn/attrs.h>
#include "../../pass/pattern_util.h"
#include "../util.h"
namespace tvm {
namespace relay {
namespace qnn {
TVM_REGISTER_NODE_TYPE(RequantizeAttrs);
// Lowering of qnn.requantize op
/*
* \brief Convert FP32 representation into fixed point representation.
* \param double_multplier The input FP32 number.
* \return The pair of multiplier and shift for fixed point representation.
* \note Converts a floating point number so that it can be represented by
* integers. The representation is
* float_number = (significand) * 2^(exponent)
*
* The significand is a number between 0.5 and 1. This is represented by
* an integer number. For example, if it is int32, then the decimal point
* exists between bit 31 and 30 from LSB (or between first and second bit
* from the left).
*
* Some examples are
* 0.25 = (0.5) * 2^(-1)
* 0.125 = (0.5) * 2^(-2)
*
* Credit to TFLite reference implementation.
*/
std::pair<int32_t, int32_t> GetFixedPointMultiplierShift(double double_multiplier) {
int32_t significand, exponent;
if (double_multiplier == 0.) {
significand = 0;
exponent = 0;
return std::make_pair(significand, exponent);
}
// Get the significand and exponent.
double significand_d = std::frexp(double_multiplier, &exponent);
// Convert the double significand to int significand, i.e., convert into a
// integer where the decimal point is between bit 31 and 30. This is done by
// multiplying the double value with 2^31 and then casting to int.
significand_d = std::round(significand_d * (1ll << 31));
auto significand_int64 = static_cast<int64_t>(significand_d);
CHECK_LE(significand_int64, (1ll << 31));
if (significand_int64 == (1ll << 31)) {
significand_int64 /= 2;
++exponent;
}
CHECK_LE(significand_int64, std::numeric_limits<int32_t>::max());
significand = static_cast<int32_t>(significand_int64);
return std::make_pair(significand, exponent);
}
/*
* \brief Lower requantize to a sequence of ops.
* \param input_tensor The input tensor to requantize op.
* \param param The requantize op attrs.
* \param input_shape The input tensor shape of the requantize op.
* \return The sequence of existing Relay ops.
* \note Requantization using only integer computation. Here, the computation is
* converted to a fixed point computation by computing output multiplier
* and shift. This is useful, if the target device does not support/have
* very expensive floating point computations.
*
* Original compuation is scale_fp32 * quantized_tensor. To convert into
* integer computation, the multiplication with fp32 scalar can be
* replaced by multiplication with an int value and then right shifting
* the result. This approximates the floating point computation with a
* fixed point computation.
*
* The whole computation this can be broken down into following steps
* 1) Calculate the integer multiplier and integer shift.
* 2) Subtract the input integer zero point.
* 3) Multiply the fixed point multiplier with quantized tensor.
* 4) Round the result.
* 5) Right shift the result.
* 6) Add the output zero point.
* 7) Cast to the out_dtype.
*/
Expr RequantizeLower(const Expr& input_tensor, const RequantizeAttrs* param,
const Array<IndexExpr>& input_shape) {
double double_multiplier = param->input_scale / param->output_scale;
// Choose high precision datatype to be int64. This is for avoiding overflow
// in multiplication of two int32 values.
DataType hp_dtype = Int(64);
// 1) Calculating the integer multiplier and integer shift
int32_t fixed_point_multiplier, shift;
std::tie(fixed_point_multiplier, shift) = GetFixedPointMultiplierShift(double_multiplier);
int left_shift = shift > 0 ? shift : 0;
int right_shift = shift > 0 ? 0 : -shift;
// 2) Subtract the input_zero_point
auto tensor = Cast(input_tensor, hp_dtype);
if (param->input_zero_point != 0) {
auto input_zp = MakeConstantScalar(hp_dtype, param->input_zero_point);
tensor = Subtract(tensor, input_zp);
}
// 3) Multiply the integer multiplier
if (left_shift != 0) {
tensor = Multiply(tensor, MakeConstantScalar(hp_dtype, 1 << left_shift));
}
// Perform the multiplication in higher precision.
// The scalar is a fixed point value of int32 where the decimal point is
// between bits 31 and 30. After multiplying with input_tensor, the result is
// in int64 where the decimal point is sitting between bits 31 and 30 (from
// the right, rightmost bit is bit 0). The computation is performed in higher
// precision to avoid overflow in multiplying two int32 values.
Expr scalar = MakeConstantScalar(hp_dtype, fixed_point_multiplier);
auto multiplied_t = Multiply(tensor, scalar);
// 4) Find the rounding scalar. This depends on where the final decimal point
// sits. As we will be right shifting the multiplied_t, we need to first
// calculate the total_right_shift.
int total_right_shift = right_shift + 31;
int64_t pos_rounding_value = (1ll << (total_right_shift - 1));
tensor = multiplied_t;
Expr round_scalar;
if (param->rounding == "UPWARD") {
round_scalar = MakeConstantScalar(hp_dtype, pos_rounding_value);
} else if (param->rounding == "TONEAREST") {
auto pos_rounder = MakeConstantScalar(hp_dtype, pos_rounding_value);
auto neg_rounder = MakeConstantScalar(hp_dtype, pos_rounding_value - 1);
auto pos_rounder_t = Full(pos_rounder, input_shape, hp_dtype);
auto neg_rounder_t = Full(neg_rounder, input_shape, hp_dtype);
auto zero = MakeConstantScalar(hp_dtype, 0);
auto zero_t = Full(zero, input_shape, hp_dtype);
round_scalar = Where(GreaterEqual(tensor, zero_t), pos_rounder_t, neg_rounder_t);
}
// Add the rounding scalar.
tensor = Add(tensor, round_scalar);
// 5) Simply right shift the result to get the final output.
auto scaled_int64_t = RightShift(tensor, MakeConstantScalar(hp_dtype, total_right_shift));
// 6) Add the output zero point.
auto output_zp = MakeConstantScalar(hp_dtype, param->output_zero_point);
auto shifted_int64_t = Add(output_zp, scaled_int64_t);
// 7) Clip to the out_dtype min/max.
auto q_min = GetQmin(param->out_dtype);
auto q_max = GetQmax(param->out_dtype);
auto clipped_t = Clip(shifted_int64_t, q_min, q_max);
return Cast(clipped_t, param->out_dtype);
}
/*
* \brief Forward rewrite the requantize op.
* \param ref_call The original call that will be lowered.
* \param new_args The new mutated args to the call node.
* \param ctx The node context.
* \return The sequence of Relay ops for requantize op.
* \note Lowering of the requantize operation. The requantize operator converts
* one quantized tensor to another quantized tensor. For the output
* tensor, we are provided with output scale and zero point. The
* computation looks like this
*
* Q_output = zp_output + (scale_input)/(scale_ouptut) * (Q_input - zp_input)
*/
Expr RequantizeLegalize(const Attrs& attrs, const Array<Expr>& new_args,
const Array<tvm::relay::Type>& arg_types) {
CHECK_EQ(new_args.size(), 1);
auto& quantized_data = new_args[0];
const auto* param = attrs.as<RequantizeAttrs>();
CHECK(param != nullptr);
// Find input shape.
CHECK_EQ(arg_types.size(), 1);
auto input_dtype = arg_types[0];
auto input_tensor_type = input_dtype.as<TensorTypeNode>();
CHECK(input_tensor_type != nullptr) << "Type information missing."
<< " Please run infer_type pass.";
Array<IndexExpr> input_shape = input_tensor_type->shape;
// Check rounding validity.
CHECK(param->rounding == "UPWARD" || param->rounding == "TONEAREST")
<< "QNN requantize supports two rounding modes - UPWARD and "
<< "TONEAREST";
return RequantizeLower(quantized_data, param, input_shape);
}
/*
* \brief Infer shape function of Requantize op.
* \param types The types of input args.
* \param num_inputs The number of inputs.
* \param attrs The op attributes.
* \param reporter The type reporter that sets the dtype and shapes.
* \return True if the infer shape succeeded.
*/
bool RequantizeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 2);
const auto* data = types[0].as<TensorTypeNode>();
const auto in_dtype = data->dtype;
CHECK(in_dtype == Int(8) || in_dtype == UInt(8) || in_dtype == Int(32))
<< "Input type should be an integer but was " << in_dtype;
const Array<tvm::Expr> oshape = data->shape;
// assign output type
const RequantizeAttrs* param = attrs.as<RequantizeAttrs>();
auto out_dtype = param->out_dtype;
CHECK(out_dtype == Int(8) || out_dtype == UInt(8) || out_dtype == Int(32))
<< "Output type should be an integer but was " << out_dtype;
reporter->Assign(types[1], TensorTypeNode::make(oshape, out_dtype));
return true;
}
// Positional relay function to create qnn requantize operator
// used by frontend FFI.
Expr MakeRequantize(Expr data, double input_scale, int32_t input_zero_point, double output_scale,
int32_t output_zero_point, std::string rounding, DataType out_dtype) {
auto attrs = make_node<RequantizeAttrs>();
attrs->input_scale = std::move(input_scale);
attrs->input_zero_point = std::move(input_zero_point);
attrs->output_scale = std::move(output_scale);
attrs->output_zero_point = std::move(output_zero_point);
attrs->rounding = std::move(rounding);
attrs->out_dtype = std::move(out_dtype);
static const Op& op = Op::Get("qnn.requantize");
return CallNode::make(op, {data}, Attrs(attrs), {});
}
RELAY_REGISTER_OP("qnn.requantize")
.describe(R"code(Requantize operator.
The requantize operator converts one quantized tensor to another quantized
tensor. For the output tensor, we are provided with output scale and zero
point. The computation looks like this
Q_output = zp_output + (scale_input)/(scale_ouptut) * (Q_input - zp_input)
)code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.RequantizeAttrs")
.set_num_inputs(1)
.add_argument("data", "Tensor", "The quantized input tensor.")
.set_support_level(11)
.add_type_rel("Requantize", RequantizeRel)
.set_attr<FTVMLegalize>("FTVMLegalize", RequantizeLegalize);
TVM_REGISTER_API("relay.qnn.op._make.requantize")
.set_body_typed(MakeRequantize);
} // namespace qnn
} // namespace relay
} // namespace tvm
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file src/relay/qnn/util.h
* \brief Utility methods needs for quantized ops that can be shared
*/
#ifndef TVM_RELAY_QNN_UTIL_H_
#define TVM_RELAY_QNN_UTIL_H_
#include <tvm/expr.h>
#include <tvm/relay/expr.h>
#include <limits>
namespace tvm {
namespace relay {
namespace qnn {
static inline const int32_t GetQmin(const DataType& dtype) {
CHECK_LE(dtype.bits(), 32)
<< "QNN ops support int32 or lower precision";
if (dtype.is_int()) {
auto* min_value = as_const_int(dtype.min());
CHECK(min_value != nullptr);
return static_cast<int32_t>(min_value[0]);
} else if (dtype.is_uint()) {
auto* min_value = as_const_uint(dtype.min());
CHECK(min_value != nullptr);
return static_cast<int32_t>(min_value[0]);
} else {
LOG(FATAL) << "Type not supported " << dtype;
return -1; // To hide the warning
}
}
static inline const int32_t GetQmax(const DataType& dtype) {
CHECK_LE(dtype.bits(), 32)
<< "QNN ops support int32 or lower precision";
if (dtype.is_int()) {
auto* max_value = as_const_int(dtype.max());
CHECK(max_value != nullptr);
return static_cast<int32_t>(max_value[0]);
} else if (dtype.is_uint()) {
auto* max_value = as_const_uint(dtype.max());
CHECK(max_value != nullptr);
return static_cast<int32_t>(max_value[0]);
} else {
LOG(FATAL) << "Type not supported " << dtype;
return -1; // To hide the warning
}
}
} // namespace qnn
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_QNN_UTIL_H_
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
import numpy as np
from tvm import relay
from tvm.relay.testing import create_workload
from tvm.contrib import graph_runtime
roundings = ["UPWARD", "TONEAREST"]
def run_infer_type(expr):
mod = relay.Module.from_expr(expr)
mod = relay.transform.InferType()(mod)
entry = mod["main"]
return entry if isinstance(expr, relay.Function) else entry.body
def test_requantize():
def verify(mod, goldens):
with relay.build_config(opt_level=3):
graph, lib, params = relay.build(mod, "llvm", params=None)
golden_data, golden_output = goldens
rt_mod = graph_runtime.create(graph, lib, ctx=tvm.cpu(0))
rt_mod.set_input("quantized_data",golden_data)
rt_mod.set_input(**params)
rt_mod.run()
res = rt_mod.get_output(0).asnumpy()
np.testing.assert_equal(res, golden_output)
def get_mod(data_shape, data_dtype, out_dtype, input_scale, output_scale,
input_zero_point=0, output_zero_point=0, rounding="TONEAREST"):
quantized_data = relay.var("quantized_data", shape=data_shape,
dtype=data_dtype)
mod = relay.qnn.op.requantize(
quantized_data,
input_scale=input_scale,
input_zero_point=input_zero_point,
output_scale=output_scale,
output_zero_point=output_zero_point,
rounding=rounding,
out_dtype=out_dtype)
mod = relay.Function(relay.analysis.free_vars(mod), mod)
mod = relay.Module.from_expr(mod)
mod = relay.transform.Legalize()(mod)
return mod
def same_scale_test():
# Have same scales, everything within range
golden_data = np.arange(-100, 100, 1).astype('int32')
golden_output = golden_data
for rounding in roundings:
mod = get_mod(data_shape=(200, ),
data_dtype='int32',
out_dtype="int8",
input_scale=0.5,
output_scale=0.5,
rounding=rounding)
verify(mod, (golden_data, golden_output))
def downscale_test():
for rounding in roundings:
mod = get_mod(data_shape=(32, ),
data_dtype='int32',
out_dtype='int8',
input_scale=1,
output_scale=16,
rounding=rounding)
# Try positive values
# 8 corresponds to 0.5, resulting in 1
golden_data = np.arange(0, 32, 1).astype('int32')
golden_output = np.repeat([0, 1, 2], [8, 16, 8])
verify(mod, (golden_data, golden_output))
# Try negative values
# -8 corresponds to -0.5. For UPWARD, this is 0
golden_data = np.arange(0, -32, -1).astype('int32')
if rounding == "UPWARD":
golden_output = np.repeat([0, -1, -2], [9, 16, 7])
else:
golden_output = np.repeat([0, -1, -2], [8, 16, 8])
verify(mod, (golden_data, golden_output))
# Try a different scale
mod = get_mod(data_shape=(32, ),
data_dtype='int32',
out_dtype="int8",
input_scale=1,
output_scale=4,
rounding=rounding)
# Try positive values
# 2I corresponds to 0.5, resulting in 1
golden_data = np.arange(0, 32, 1).astype('int32')
golden_output = np.repeat([0, 1, 2, 3, 4, 5, 6, 7, 8],
[2, 4, 4, 4, 4, 4, 4, 4, 2])
verify(mod, (golden_data, golden_output))
# Try negative values
# -8 corresponds to -0.5. For UPWARD, this is 0
golden_data = np.arange(0, -32, -1).astype('int32')
if rounding == "UPWARD":
golden_output = np.repeat([0, -1, -2, -3, -4, -5, -6, -7, -8],
[3, 4, 4, 4, 4, 4, 4, 4, 1])
else:
golden_output = np.repeat([0, -1, -2, -3, -4, -5, -6, -7, -8],
[2, 4, 4, 4, 4, 4, 4, 4, 2])
verify(mod, (golden_data, golden_output))
# Try uint8 out_dtype
mod = get_mod(data_shape=(32, ),
data_dtype='int32',
out_dtype='uint8',
input_scale=1,
output_scale=16,
rounding=rounding)
# Try positive values
# 8 corresponds to 0.5, resulting in 1
golden_data = np.arange(0, 32, 1).astype('int32')
golden_output = np.repeat([0, 1, 2], [8, 16, 8])
verify(mod, (golden_data, golden_output))
# Try uint8 in_dtyope and uint8 out_dtype
mod = get_mod(data_shape=(32, ),
data_dtype='uint8',
out_dtype='uint8',
input_scale=1,
output_scale=16,
rounding=rounding)
# Try positive values
# 8 corresponds to 0.5, resulting in 1
golden_data = np.arange(0, 32, 1).astype('int32')
golden_output = np.repeat([0, 1, 2], [8, 16, 8])
verify(mod, (golden_data, golden_output))
def upscale_test():
for rounding in roundings:
mod = get_mod(data_shape=(32, ),
data_dtype='int32',
out_dtype="int8",
input_scale=2,
output_scale=1,
rounding=rounding)
# Try positive values
# 8 corresponds to 0.5, resulting in 1
golden_data = np.arange(0, 32, 1).astype('int32')
golden_output = np.multiply(2, golden_data)
verify(mod, (golden_data, golden_output))
# Try negative values
# -8 corresponds to -0.5. For UPWARD, this is 0
golden_data = np.arange(0, -32, -1).astype('int32')
golden_output = np.multiply(2, golden_data)
verify(mod, (golden_data, golden_output))
def saturation_test():
for rounding in roundings:
mod = get_mod(data_shape=(16, ),
data_dtype='int32',
out_dtype="int8",
input_scale=0.5,
output_scale=0.5,
rounding=rounding)
golden_data = np.arange(0, 16, 1).astype('int32')
golden_data = np.add(120, golden_data)
output = np.array([120, 121, 122, 123, 124, 125, 126, 127,
127, 127, 127, 127, 127, 127, 127, 127])
golden_output = output
verify(mod, (golden_data, golden_output))
# Try negative numbers
golden_data = np.arange(0, -16, -1).astype('int32')
golden_data = np.add(-120, golden_data)
output = np.array([-120, -121, -122, -123, -124, -125, -126, -127,
-128, -128, -128, -128, -128, -128, -128, -128])
golden_output = output
verify(mod, (golden_data, golden_output))
def zero_point_test():
# Output zero point
for rounding in roundings:
mod = get_mod(data_shape=(32, ),
data_dtype='int32',
out_dtype='int8',
input_scale=1,
output_scale=16,
output_zero_point=1,
rounding=rounding)
# Try positive values
# 8 corresponds to 0.5, resulting in 1
golden_data = np.arange(0, 32, 1).astype('int32')
golden_output = np.repeat([0, 1, 2], [8, 16, 8])
golden_output = np.add(1, golden_output)
verify(mod, (golden_data, golden_output))
# Try negative values
# -8 corresponds to -0.5. For UPWARD, this is 0
golden_data = np.arange(-32, -64, -1).astype('int32')
if rounding == "UPWARD":
golden_output = np.repeat([-2, -3, -4], [9, 16, 7])
else:
golden_output = np.repeat([-2, -3, -4], [8, 16, 8])
golden_output = np.add(1, golden_output)
verify(mod, (golden_data, golden_output))
# Input zero point
for rounding in roundings:
mod = get_mod(data_shape=(32, ),
data_dtype='int32',
out_dtype='int8',
input_scale=1,
output_scale=16,
input_zero_point=16,
rounding=rounding)
# Try positive values
golden_data = np.arange(32, 64, 1).astype('int32')
golden_output = np.repeat([2, 3, 4], [8, 16, 8])
golden_output = np.subtract(golden_output, 1)
verify(mod, (golden_data, golden_output))
# Try negative values
golden_data = np.arange(-32, -64, -1).astype('int32')
if rounding == "UPWARD":
golden_output = np.repeat([-2, -3, -4], [9, 16, 7])
else:
golden_output = np.repeat([-2, -3, -4], [8, 16, 8])
golden_output = np.subtract(golden_output, 1)
verify(mod, (golden_data, golden_output))
same_scale_test()
downscale_test()
upscale_test()
saturation_test()
zero_point_test()
if __name__ == "__main__":
test_requantize()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment