Commit a78adbd5 by Animesh Jain, committed by Tianqi Chen

[QNN] Requantize operator (#3531)

* [Relay] [Quantization] WIP - Common files for the quantization work.

* [Relay] [Quantization] WIP - Prototyping requantize op.

* Requantize operator implementation.

Requantize converts one quantized tensor representation to another quantized
representation. The PR has the following implementation features:

- Requantize operator defined in qnn namespace - relay.qnn.requantize
- Lowering of the requantize to existing Relay operators
- Integer fixed point implementation of requantize
    - Two rounding modes - FE_UPWARDS (round towards positive infinity) and
    FE_AWAY_FROM_ZERO (std::round behavior)
- Floating point implementation as well, that can act as a reference or can be
used on devices where FP32 computation is not a concern.
- Unit test cases

Relevant Issue - https://github.com/dmlc/tvm/issues/2351

Credit to TFLite and GemmLowp for providing reference implementations.
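As a hedged sketch of the fixed point path (the helper name mirrors the
GetFixedPointMultiplierShift API mentioned below, but this Python version is
illustrative only, not the PR's C++ code), the FP32 scale ratio is decomposed
into a 32-bit multiplier and a shift:

```python
import math

def get_fixed_point_multiplier_shift(scale):
    # Decompose scale = significand * 2**shift, with significand in [0.5, 1).
    significand, shift = math.frexp(scale)
    # Convert the significand to a signed Q0.31 fixed point multiplier.
    multiplier = round(significand * (1 << 31))
    if multiplier == (1 << 31):  # corner case: significand rounded up to 1.0
        multiplier //= 2
        shift += 1
    return multiplier, shift

# e.g. scale = 0.25 -> multiplier = 2**30, shift = -1
```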

* Typo and lint fixes.

* Doc fix.

* Uncommenting the lint script (fixing mistake).

* Modifying the unit tests.

* Moving C++ files into src/relay/qnn

* Moving python files to python/tvm/relay/qnn. Some minor fixes.

* Moving the attrs.h inside the include directory.

* Pushing files that I forgot earlier. Changing util location.

* Incorporating comments. API change. Lint fixes.

* Modifying the GetFixedPointMultiplierShift API as per comments.

* Forgot the dialect change.

* Changing rewrite to qnn_lower.

* Renaming Quantize to Qnn for clarity.

* Remove use_int_domain.

* Incorporating review comments.

* Adding API doc for QNN dialect.

* Move the qnn_lower pass to transform namespace.

* Moving from expr to module. Adding namespace in C++.

* Minor sentence rewrites. Added qnn namespace.

* Added the API doc.

* Changing default out_dtype to int8. Adding a test with in/out_dtype as uint8.

* Style fixes. Better error messages.

* Adding documentation.

* More documentation fixes.

* Adding out dtype check for requantize.

* Adding corner case for FP32 to fixed point conversion.

* Adding extra line.

* Documentation fix.

* Adding static inline.

* Incorporating jackwish comment. Removed idtype from requantize lowering.

* Removing Quantize/Dequantize code. Restricting Requantize to (u)int8/int32.

* Style fixes.

* Fix the docs.

* Move to Legalize API.
parent 60607eff
@@ -202,6 +202,16 @@ This level supports backpropagation of broadcast operators. It is temporary.
tvm.relay.contrib.adaptive_avg_pool2d
**Level 11: Dialect Operators**
This level supports dialect operators.
.. autosummary::
:nosignatures:
tvm.relay.qnn.op.requantize
Level 1 Definitions
-------------------
.. autofunction:: tvm.relay.log
@@ -340,3 +350,8 @@ Level 10 Definitions
.. autofunction:: tvm.relay.nn.batch_matmul
.. autofunction:: tvm.relay.contrib.adaptive_max_pool2d
.. autofunction:: tvm.relay.contrib.adaptive_avg_pool2d
Level 11 Definitions
--------------------
.. autofunction:: tvm.relay.qnn.op.requantize
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tvm/relay/qnn/attrs.h
* \brief Auxiliary attributes for qnn operators.
*/
#ifndef TVM_RELAY_QNN_ATTRS_H_
#define TVM_RELAY_QNN_ATTRS_H_
#include <tvm/attrs.h>
#include <string>
namespace tvm {
namespace relay {
namespace qnn {
/*! \brief Attribute for requantize operator */
struct RequantizeAttrs : public tvm::AttrsNode<RequantizeAttrs> {
double input_scale;
int32_t input_zero_point;
double output_scale;
int32_t output_zero_point;
std::string rounding;
DataType out_dtype;
TVM_DECLARE_ATTRS(RequantizeAttrs, "relay.attrs.RequantizeAttrs") {
TVM_ATTR_FIELD(input_scale)
.describe("The scale of the input tensor.");
TVM_ATTR_FIELD(input_zero_point)
.describe("The zero point of the input tensor.");
TVM_ATTR_FIELD(output_scale)
.describe("The scale of the output tensor.");
TVM_ATTR_FIELD(output_zero_point)
.describe("The zero point of the output tensor.");
TVM_ATTR_FIELD(rounding).set_default("TONEAREST")
.describe("Defines the rounding direction when the value is midway between "
"two representable values. There are two supported modes - UPWARD "
"or TONEAREST. Both modes behave exactly the same except at the "
"midpoints between two representable values. At the midpoint, "
"UPWARD rounds towards positive infinity (for example, -1.5 is "
"rounded to -1). TONEAREST is the standard rounding mode, where "
"the value is rounded away from zero at midpoints (for example, "
"-1.5 rounds to -2). More context can be found in the glibc manual at "
"https://www.gnu.org/software/libc/manual/html_node/Rounding.html.");
TVM_ATTR_FIELD(out_dtype)
.set_default(NullValue<DataType>())
.describe("Output data type, set to explicit type under mixed precision setting");
}
};
} // namespace qnn
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_QNN_ATTRS_H_
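To make the midpoint behavior of the two rounding modes concrete, a small
illustrative Python sketch (not the PR's C++ implementation):

```python
import math

def round_upward(x):
    # At midpoints, round towards positive infinity: -1.5 -> -1, 1.5 -> 2.
    return math.floor(x + 0.5)

def round_tonearest(x):
    # At midpoints, round away from zero (std::round): -1.5 -> -2, 1.5 -> 2.
    return int(math.copysign(math.floor(abs(x) + 0.5), x))

assert round_upward(-1.5) == -1 and round_upward(1.5) == 2
assert round_tonearest(-1.5) == -2 and round_tonearest(1.5) == 2
```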
@@ -53,6 +53,9 @@ from . import frontend
from . import backend
from . import quantize
# Dialects
from . import qnn
from .scope_builder import ScopeBuilder
# Span
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=wildcard-import
"""QNN dialect operators and IR passes."""
from __future__ import absolute_import as _abs
from . import op
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=wildcard-import
"""Neural network related operators."""
from __future__ import absolute_import as _abs
from .qnn import *
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Constructor APIs"""
from ...._ffi.function import _init_api
_init_api("relay.qnn.op._make", __name__)
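For context, `_init_api` walks TVM's global PackedFunc registry and attaches
every function registered under the given prefix to this module. A simplified,
hypothetical sketch of that mechanism (the actual implementation differs):

```python
import sys
import tvm

def _attach(prefix, module_name):
    # Expose a C++-registered PackedFunc, e.g. "relay.qnn.op._make.requantize",
    # as an attribute of the Python module (so it is callable as
    # _make.requantize).
    module = sys.modules[module_name]
    packed = tvm.get_global_func(prefix + ".requantize")
    setattr(module, "requantize", packed)
```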
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
"""QNN dialect operators."""
from __future__ import absolute_import as _abs
from . import _make
def requantize(data,
input_scale,
input_zero_point,
output_scale,
output_zero_point,
rounding="TONEAREST",
out_dtype="int8"):
r"""Requantized operator.
The requantize operator converts one quantized tensor representation to
another quantized tensor representation. For the output tensor, we are
provided with output scale and zero point. The computation is as follows
Q_output = zp_output + (scale_input)/(scale_output) * (Q_input - zp_input)
Parameters
----------
data : tvm.relay.Expr
The input data to the operator.
input_scale: float
The quantization scale for the input tensor.
input_zero_point: int
The zero point of the input tensor.
output_scale: float
The quantization scale for the output tensor.
output_zero_point: int
The zero point of the output tensor.
rounding : string, optional
Defines the rounding direction when the value is midway between two
representable values.
out_dtype : str, optional
Specifies the output data type.
Returns
-------
result : tvm.relay.Expr
The computed result.
"""
return _make.requantize(data,
input_scale,
input_zero_point,
output_scale,
output_zero_point,
rounding,
out_dtype)
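A hedged usage sketch combining a NumPy reference of the formula above with
the Relay call (all scales, zero points, and shapes are made-up example
values):

```python
import numpy as np
from tvm import relay

def requantize_ref(q_input, in_scale, in_zp, out_scale, out_zp,
                   out_dtype="int8"):
    # NumPy reference of the docstring formula. Note: np.round is
    # round-half-to-even, which differs from UPWARD/TONEAREST at midpoints.
    real = in_scale * (q_input.astype("float64") - in_zp)
    q_output = np.round(real / out_scale) + out_zp
    info = np.iinfo(out_dtype)
    return np.clip(q_output, info.min, info.max).astype(out_dtype)

# Building the Relay op with the same quantization parameters.
x = relay.var("x", shape=(2, 2), dtype="int32")
y = relay.qnn.op.requantize(x,
                            input_scale=0.5, input_zero_point=1,
                            output_scale=0.25, output_zero_point=2,
                            out_dtype="int8")
```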
@@ -394,6 +394,26 @@ inline Expr Variance(Expr data, Expr mean, Array<Integer> axis, bool keepdims, b
}
static inline Expr Where(const Expr& condition, const Expr& x, const Expr& y) {
static const Op& op = Op::Get("where");
return CallNode::make(op, {condition, x, y});
}
static inline Expr GreaterEqual(const Expr& lhs, const Expr& rhs) {
static const Op& op = Op::Get("greater_equal");
return CallNode::make(op, {lhs, rhs}, Attrs(), {});
}
static inline Expr Full(Expr fill_value,
Array<IndexExpr> shape,
DataType dtype) {
auto attrs = make_node<InitOpAttrs>();
attrs->shape = std::move(shape);
attrs->dtype = std::move(dtype);
static const Op& op = Op::Get("full");
return CallNode::make(op, {fill_value}, Attrs(attrs), {});
}
Expr MakeConcatenate(Expr data, int axis);
Expr MakeStridedSlice(Expr data, Array<Integer> begin, Array<Integer> end, Array<Integer> strides);
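These builders wrap ops that are already exposed in the Python API; a hedged
sketch of an equivalent graph built from Python:

```python
from tvm import relay

# Mirrors what the new Full / GreaterEqual / Where C++ helpers construct.
x = relay.var("x", shape=(4,), dtype="int32")
zeros = relay.full(relay.const(0, "int32"), shape=(4,), dtype="int32")
mask = relay.greater_equal(x, zeros)
clipped = relay.where(mask, x, zeros)  # keep x where x >= 0, else 0
```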
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file src/relay/qnn/util.h
* \brief Utility methods needed by quantized ops that can be shared
*/
#ifndef TVM_RELAY_QNN_UTIL_H_
#define TVM_RELAY_QNN_UTIL_H_
#include <tvm/expr.h>
#include <tvm/relay/expr.h>
#include <limits>
namespace tvm {
namespace relay {
namespace qnn {
/*! \brief Returns the minimum representable value of dtype as an int32. */
static inline int32_t GetQmin(const DataType& dtype) {
CHECK_LE(dtype.bits(), 32)
<< "QNN ops support int32 or lower precision";
if (dtype.is_int()) {
auto* min_value = as_const_int(dtype.min());
CHECK(min_value != nullptr);
return static_cast<int32_t>(min_value[0]);
} else if (dtype.is_uint()) {
auto* min_value = as_const_uint(dtype.min());
CHECK(min_value != nullptr);
return static_cast<int32_t>(min_value[0]);
} else {
LOG(FATAL) << "Type not supported " << dtype;
return -1; // To hide the warning
}
}
/*! \brief Returns the maximum representable value of dtype as an int32. */
static inline int32_t GetQmax(const DataType& dtype) {
CHECK_LE(dtype.bits(), 32)
<< "QNN ops support int32 or lower precision";
if (dtype.is_int()) {
auto* max_value = as_const_int(dtype.max());
CHECK(max_value != nullptr);
return static_cast<int32_t>(max_value[0]);
} else if (dtype.is_uint()) {
auto* max_value = as_const_uint(dtype.max());
CHECK(max_value != nullptr);
return static_cast<int32_t>(max_value[0]);
} else {
LOG(FATAL) << "Type not supported " << dtype;
return -1; // To hide the warning
}
}
} // namespace qnn
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_QNN_UTIL_H_
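For reference, the (qmin, qmax) pairs these helpers compute for the dtypes
Requantize accepts can be cross-checked with NumPy:

```python
import numpy as np

# Quantization ranges matching GetQmin/GetQmax for the supported dtypes.
for dtype in ("int8", "uint8", "int32"):
    info = np.iinfo(dtype)
    print(dtype, info.min, info.max)
# int8 -128 127
# uint8 0 255
# int32 -2147483648 2147483647
```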