Unverified Commit dada6761 by shoubhik Committed by GitHub

Adding support for QNN subtract op (#5153)

* Adding support for QNN subtract op

* Fixing typo.

* Fixing typo.

* Fixing lint.

* Addressing review comments.

* Renaming variables as per convention and renamed QnnBinaryOpTypes -> QnnBinaryOpType

* Renaming QnnBinaryOpType to QnnBinaryOpTensorType which now takes the index you want to extract to make the code more readable.

* Fixing lint.

* Moving common code to macro.

* Fixing alignment.

* Fixing typo.

* Fixing lint.

* Renaming method to pass CI.
parent f4286cc7
@@ -310,9 +310,6 @@ def add(lhs,
rhs : relay.Expr
The right hand side quantized input data.
-lhs_scale: float
-The scale of the lhs quantized expr.
lhs_scale: relay.Expr
The scale of the lhs quantized expr.
@@ -436,3 +433,51 @@ def mul(lhs, rhs, lhs_scale, lhs_zero_point, rhs_scale, rhs_zero_point,
lhs_scale, lhs_zero_point,
rhs_scale, rhs_zero_point,
output_scale, output_zero_point)
def subtract(lhs,
rhs,
lhs_scale,
lhs_zero_point,
rhs_scale,
rhs_zero_point,
output_scale,
output_zero_point):
"""Quantized subtraction with numpy-style broadcasting.
Parameters
----------
lhs : relay.Expr
The left hand side quantized input data.
rhs : relay.Expr
The right hand side quantized input data.
lhs_scale: relay.Expr
The scale of the lhs quantized expr.
lhs_zero_point: relay.Expr
The zero point of lhs quantized expr.
rhs_scale: relay.Expr
The scale of the rhs quantized expr.
rhs_zero_point: relay.Expr
The zero point of rhs quantized expr.
output_scale: relay.Expr
The scale of the output quantized expr.
output_zero_point: relay.Expr
The zero point of output quantized expr.
Returns
-------
result : relay.Expr
The computed result.
"""
return _make.subtract(lhs, rhs,
lhs_scale, lhs_zero_point,
rhs_scale, rhs_zero_point,
output_scale, output_zero_point)
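For reference, a minimal usage sketch of the new frontend op (the scale and zero point values below are illustrative, mirroring the tests added later in this patch):

import tvm
from tvm import relay

x = relay.var("x", shape=(1, 4), dtype="uint8")
y = relay.var("y", shape=(1, 4), dtype="uint8")
z = relay.qnn.op.subtract(lhs=x, rhs=y,
                          lhs_scale=relay.const(0.125, 'float32'),
                          lhs_zero_point=relay.const(0, 'int32'),
                          rhs_scale=relay.const(0.125, 'float32'),
                          rhs_zero_point=relay.const(0, 'int32'),
                          output_scale=relay.const(0.125, 'float32'),
                          output_zero_point=relay.const(0, 'int32'))
mod = tvm.IRModule.from_expr(relay.Function([x, y], z))
# CanonicalizeOps lowers qnn.subtract into requantize/subtract/add/clip/cast ops.
mod = relay.qnn.transform.CanonicalizeOps()(mod)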
@@ -23,59 +23,27 @@
*/
#include <tvm/relay/analysis.h>
#include <tvm/relay/op_attr_types.h>
-#include <tvm/relay/qnn/attrs.h>
-#include "../../transforms/pattern_util.h"
-#include "../../transforms/infer_layout_util.h"
-#include "../util.h"
#include "op_common.h"
namespace tvm {
namespace relay {
namespace qnn {
-/*! \brief Infer layout for QNN binary broadcast operators */
-Array<Array<Layout> > QnnBinaryBroadcastLayout(const Attrs& attrs,
-const Array<Layout>& new_in_layouts,
-const Array<Layout>& old_in_layouts,
-const Array<tvm::relay::Type>& old_in_types) {
-// Use Relay Binary Broadcast Infer correct layout.
-auto layouts = BinaryBroadcastLayout(attrs, new_in_layouts, old_in_layouts, old_in_types);
-// Fill the layouts of remaining input tensors - scales and zero points. The layouts of these
-// tensors can be treated as C.
-Layout channel_layout = Layout("C");
-Array<Layout> input_layouts = {layouts[0][0], layouts[0][1], channel_layout, channel_layout,
-channel_layout, channel_layout, channel_layout, channel_layout};
-Array<Layout> output_layouts = layouts[1];
-return {input_layouts, output_layouts};
-}
/*
* \brief Canonicalizes the QNN add op.
* \param attrs The empty attribute.
* \param new_args The new mutated args to the call node.
* \param arg_types The types of input and output.
* \return The sequence of Relay ops for add op.
*/
Expr QnnAddCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
const Array<tvm::relay::Type>& arg_types) {
// Get the args.
-CHECK_EQ(new_args.size(), 8);
-auto& lhs = new_args[0];
-auto& rhs = new_args[1];
-auto& lhs_scale = new_args[2];
-auto& lhs_zero_point = new_args[3];
-auto& rhs_scale = new_args[4];
-auto& rhs_zero_point = new_args[5];
-auto& output_scale = new_args[6];
-auto& output_zero_point = new_args[7];
QnnBinaryOpArguments args(new_args);
// Get the input dtype and shape.
-CHECK_EQ(arg_types.size(), 9);
-auto tensor_type = arg_types[0].as<TensorTypeNode>();
-CHECK(tensor_type != nullptr);
-auto input_dtype = tensor_type->dtype;
-auto input_shape = tensor_type->shape;
QnnBinaryOpTensorType input_type(arg_types, 0);
// FIXME (anijain2305) - The lowering can be further optimized. Instead of inserting requantize in
// the start, we can insert requantize at the end if both input tensors have same qnn params. In
@@ -97,47 +65,36 @@ Expr QnnAddCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
// Q_c = Q_a' + Q_b' - zp_c
// The add op is done in int32 precision.
-// Requantize LHS if necessary.
-auto requantized_lhs = lhs;
-if (!IsEqualScalar(lhs_scale, output_scale) ||
-!IsEqualScalar(lhs_zero_point, output_zero_point)) {
-requantized_lhs = Requantize(lhs, input_shape, lhs_scale, lhs_zero_point, output_scale,
-output_zero_point, DataType::Int(32));
-} else {
-requantized_lhs = Cast(requantized_lhs, DataType::Int(32));
-}
-// Requantize RHS if necessary.
-auto requantized_rhs = rhs;
-if (!IsEqualScalar(rhs_scale, output_scale) ||
-!IsEqualScalar(rhs_zero_point, output_zero_point)) {
-requantized_rhs = Requantize(rhs, input_shape, rhs_scale, rhs_zero_point, output_scale,
-output_zero_point, DataType::Int(32));
-} else {
-requantized_rhs = Cast(requantized_rhs, DataType::Int(32));
-}
// Requantize LHS if necessary. Computes Q_a'
auto requantized_lhs = RequantizeOrUpcast(args.lhs, args.lhs_scale,
args.lhs_zero_point,
args.output_scale, args.output_zero_point,
input_type.shape);
// Requantize RHS if necessary. Computes Q_b'
auto requantized_rhs = RequantizeOrUpcast(args.rhs, args.rhs_scale,
args.rhs_zero_point,
args.output_scale, args.output_zero_point,
input_type.shape);
// Computes Q_a' + Q_b'
auto output = Add(requantized_lhs, requantized_rhs);
// Subtract zero point. Computes (Q_a' + Q_b') - zp_c
auto zero_scalar = MakeConstantScalar(DataType::Int(32), 0);
if (!IsEqualScalar(args.output_zero_point, zero_scalar)) {
output = Subtract(output, args.output_zero_point);
}
// Go back to lower precision.
-auto q_min = GetQmin(input_dtype);
-auto q_max = GetQmax(input_dtype);
-output = Clip(output, q_min, q_max);
-return Cast(output, input_dtype);
return ConvertDtype(output, input_type.dtype);
}
// QNN Addition operator.
QNN_REGISTER_BINARY_OP("add")
.describe("Elementwise add with broadcasting for quantized tensors.")
.set_support_level(11)
.set_attr<FTVMLegalize>("FTVMQnnCanonicalize", QnnAddCanonicalize);
-.set_attr<FInferCorrectLayout>("FInferCorrectLayout", QnnBinaryBroadcastLayout);
} // namespace qnn
} // namespace relay
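As a sanity check of the lowering math above, a small numpy sketch (illustrative values; it assumes both inputs already share the output's scale and zero point, so only the int32 upcast path is exercised):

import numpy as np

scale, zp = 0.5, 127                         # shared qnn params (illustrative)
q_a = np.array([140, 153], dtype=np.int32)
q_b = np.array([204, 178], dtype=np.int32)
q_c = q_a + q_b - zp                         # Q_c = Q_a' + Q_b' - zp_c
# The dequantized result matches the sum of the dequantized inputs.
assert np.allclose(scale * (q_c - zp), scale * (q_a - zp) + scale * (q_b - zp))
q_c = np.clip(q_c, 0, 255).astype(np.uint8)  # ConvertDtype: saturate back to uint8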
......
@@ -42,22 +42,13 @@ namespace qnn {
Expr QnnMulCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
const Array<tvm::relay::Type>& arg_types) {
// Get the attrs.
-CHECK_EQ(new_args.size(), 8);
-auto& lhs = new_args[0];
-auto& rhs = new_args[1];
-auto& lhs_scale = new_args[2];
-auto& lhs_zero_point = new_args[3];
-auto& rhs_scale = new_args[4];
-auto& rhs_zero_point = new_args[5];
-auto& output_scale = new_args[6];
-auto& output_zero_point = new_args[7];
QnnBinaryOpArguments args(new_args);
// Get the input dtype and shape.
-CHECK_EQ(arg_types.size(), 9);
-auto tensor_type = arg_types[0].as<TensorTypeNode>();
-CHECK(tensor_type != nullptr);
-auto input_dtype = tensor_type->dtype;
-auto input_shape = tensor_type->shape;
QnnBinaryOpTensorType input_type(arg_types, 0);
// data types
const auto int32_dtype = DataType::Int(32);
const auto float32_dtype = DataType::Float(32);
/*
A tensor multiplication c = a * b can be written in terms of respective
@@ -71,31 +62,35 @@ Expr QnnMulCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
which is essentially a requantization of tensor Q' into tensor Q_c.
*/
auto lhs_shifted = Cast(args.lhs, int32_dtype);
auto rhs_shifted = Cast(args.rhs, int32_dtype);
auto zero_scalar = MakeConstantScalar(int32_dtype, 0);
if (!IsEqualScalar(args.lhs_zero_point, zero_scalar)) {
lhs_shifted = Subtract(lhs_shifted, args.lhs_zero_point);
}
if (!IsEqualScalar(args.rhs_zero_point, zero_scalar)) {
rhs_shifted = Subtract(rhs_shifted, args.rhs_zero_point);
}
// Create a new tensor Q'
auto output = Multiply(lhs_shifted, rhs_shifted);
// Get the adjusted new scale and zero points.
float lhs_scale_float = GetScalarFromConstant<float>(args.lhs_scale);
float rhs_scale_float = GetScalarFromConstant<float>(args.rhs_scale);
float new_scale_float = lhs_scale_float * rhs_scale_float;
auto new_input_scale = MakeConstantScalar(float32_dtype, new_scale_float);
auto new_input_zero_point = zero_scalar;
// Requantize to get Q_c
-output = Requantize(output, input_shape, new_input_scale, new_input_zero_point, output_scale,
-output_zero_point, input_dtype);
output = Requantize(output, input_type.shape,
new_input_scale,
new_input_zero_point,
args.output_scale,
args.output_zero_point,
input_type.dtype);
return output;
}
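A corresponding numpy sketch of the multiply lowering (illustrative scales; all zero points taken as 0, so the zero-point shift is a no-op):

import numpy as np

s_a, s_b, s_c = 0.5, 0.25, 0.125             # illustrative input/output scales
q_a = np.array([4, 10], dtype=np.int32)
q_b = np.array([3, 2], dtype=np.int32)
q_prime = q_a * q_b                          # Q' with scale s_a * s_b and zero point 0
q_c = np.round(q_prime * (s_a * s_b) / s_c).astype(np.int32)   # requantize Q' -> Q_c
assert np.allclose(s_c * q_c, (s_a * q_a) * (s_b * q_b))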
......
@@ -30,14 +30,152 @@
#include <tvm/relay/qnn/attrs.h>
#include <vector>
#include "../../op/type_relations.h"
#include "../../transforms/infer_layout_util.h"
#include "../util.h"
namespace tvm {
namespace relay {
namespace qnn {
/*
* Number of inputs for the Qnn binary operators.
* Refer the QNN_REGISTER_BINARY_OP macro to see
* what the operators are.
*/
static constexpr int kNumQnnBinaryOpInputs = 8;
/*
* Number of expected arg types.
*/
static constexpr int kNumQnnBinaryOpArgTypes = 9;
/*
* \brief Simple struct to organize the inputs to the Qnn
* binary operators. The main reason to have a struct
* is to be able to perform the common checks needed at a
* central location.
*/
struct QnnBinaryOpArguments {
Expr lhs;
Expr rhs;
Expr lhs_scale;
Expr lhs_zero_point;
Expr rhs_scale;
Expr rhs_zero_point;
Expr output_scale;
Expr output_zero_point;
explicit QnnBinaryOpArguments(const Array<Expr>& new_args) {
CHECK_EQ(new_args.size(), kNumQnnBinaryOpInputs);
int idx = 0;
lhs = new_args[idx++];
rhs = new_args[idx++];
lhs_scale = new_args[idx++];
lhs_zero_point = new_args[idx++];
rhs_scale = new_args[idx++];
rhs_zero_point = new_args[idx++];
output_scale = new_args[idx++];
output_zero_point = new_args[idx++];
CHECK_EQ(idx, kNumQnnBinaryOpInputs);
}
};
/*
* \brief Simple structure to hold the input tensor's dtype
* and shape. This structure allows a common point to do
* all the validation checks for Qnn binary operators.
*/
struct QnnBinaryOpTensorType {
DataType dtype;
Array<PrimExpr> shape;
explicit QnnBinaryOpTensorType(const Array<tvm::relay::Type>& arg_types,
const int32_t arg_idx) {
CHECK_EQ(arg_types.size(), kNumQnnBinaryOpArgTypes);
auto tensor_type = arg_types[arg_idx].as<TensorTypeNode>();
CHECK(tensor_type != nullptr);
dtype = tensor_type->dtype;
shape = tensor_type->shape;
}
};
/*
* \brief Converts the expression from expression's dtype
* to target dtype. This is mainly used for converting
* computations done in Int32 to lower precision Int8 or
* UInt8.
* \param expr The expression whose dtype needs conversion.
* \param target_dtype The dtype of the target expression
* \return New expression with target dtype and possibly lower
* precision.
*/
inline Expr ConvertDtype(const Expr& expr,
const DataType& target_dtype) {
auto q_min = GetQmin(target_dtype);
auto q_max = GetQmax(target_dtype);
auto output = Clip(expr, q_min, q_max);
return Cast(output, target_dtype);
}
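In numpy terms, the helper above is roughly equivalent to the following saturating cast (shown for uint8; illustrative only):

import numpy as np

acc = np.array([-5, 100, 300], dtype=np.int32)   # int32 accumulator values
out = np.clip(acc, 0, 255).astype(np.uint8)      # clamp to [qmin, qmax], then cast
# out == [0, 100, 255]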
/*
* \brief Requantizes the given expression if its scale or zero
* point does not match the target scale and zero point. This is
* mainly needed to bring the input tensors to the output tensor's
* scale and zero point, which simplifies the computation of the
* final quantized tensor.
* \param expr The expression on which the check needs to be performed.
* \param expr_scale The scale of the expression.
* \param expr_zero_point The zero point of the expression.
* \param target_scale The scale of the output tensor.
* \param target_zero_point The zero point of the output tensor.
* \param expr_shape The shape of the input expression.
* \return A new expression requantized to the target scale and zero
* point if they differ from the expression's own; otherwise the
* expression is simply cast to Int32, since no requantization is
* needed in that case.
*/
inline Expr RequantizeOrUpcast(const Expr& expr,
const Expr& expr_scale,
const Expr& expr_zero_point,
const Expr& target_scale,
const Expr& target_zero_point,
const Array<PrimExpr>& expr_shape,
const DataType& target_dtype = DataType::Int(32)) {
auto result = expr;
if (!IsEqualScalar(expr_scale, target_scale) ||
!IsEqualScalar(expr_zero_point, target_zero_point)) {
result = Requantize(expr, expr_shape, expr_scale, expr_zero_point,
target_scale, target_zero_point, target_dtype);
} else {
result = Cast(result, target_dtype);
}
return result;
}
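Conceptually, the helper either widens the tensor to int32 or re-expresses it in the target quantized domain. A numpy approximation (the real Requantize op uses a configurable rounding mode; round-to-nearest is assumed here):

import numpy as np

def requantize_or_upcast(q, s_in, zp_in, s_out, zp_out):
    q = q.astype(np.int32)
    if s_in == s_out and zp_in == zp_out:
        return q                              # qnn params already match: just upcast
    return np.round((q - zp_in) * (s_in / s_out)).astype(np.int32) + zp_out

q_a = np.array([76, 140], dtype=np.uint8)
requantize_or_upcast(q_a, 0.0156863, 127, 0.0235294, 128)   # -> array([ 94, 137])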
/*! \brief Infer layout for QNN binary broadcast operators */
inline Array<Array<Layout> > QnnBinaryBroadcastLayout(
const Attrs& attrs,
const Array<Layout>& new_in_layouts,
const Array<Layout>& old_in_layouts,
const Array<tvm::relay::Type>& old_in_types) {
// Use Relay Binary Broadcast Infer correct layout.
auto layouts = BinaryBroadcastLayout(attrs, new_in_layouts, old_in_layouts, old_in_types);
// Fill the layouts of remaining input tensors - scales and zero points. The layouts of these
// tensors can be treated as C.
Layout channel_layout = Layout("C");
Array<Layout> input_layouts = {layouts[0][0], layouts[0][1], channel_layout, channel_layout,
channel_layout, channel_layout, channel_layout, channel_layout};
Array<Layout> output_layouts = layouts[1];
return {input_layouts, output_layouts};
}
static inline bool QnnBroadcastRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), kNumQnnBinaryOpArgTypes);
// Check the scale and zero point types
CHECK(IsScalarType(types[2], DataType::Float(32)));  // lhs_scale
@@ -74,7 +212,7 @@ static inline bool QnnBroadcastRel(const Array<Type>& types, int num_inputs, con
output_scale, output_zero_point}, Attrs(), {}); \
}); \
RELAY_REGISTER_OP("qnn." OpName) \
.set_num_inputs(kNumQnnBinaryOpInputs) \
.add_argument("lhs", "Tensor", "The left hand side quantized tensor.") \
.add_argument("rhs", "Tensor", "The right hand side quantized tensor.") \
.add_argument("lhs_scale", "Tensor", "The scale of the lhs tensor.") \
@@ -83,7 +221,8 @@ static inline bool QnnBroadcastRel(const Array<Type>& types, int num_inputs, con
.add_argument("rhs_zero_point", "Tensor", "The zero_point of the rhs tensor.") \
.add_argument("output_scale", "Tensor", "The scale of the output tensor.") \
.add_argument("output_zero_point", "Tensor", "The zero_point of the output tensor.") \
.add_type_rel("QnnBroadcast", QnnBroadcastRel) \
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", QnnBinaryBroadcastLayout)
} // namespace qnn
} // namespace relay
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file src/relay/qnn/op/subtract.cc
* \brief QNN subtract operator.
*/
#include <tvm/relay/analysis.h>
#include <tvm/relay/op_attr_types.h>
#include "op_common.h"
namespace tvm {
namespace relay {
namespace qnn {
/*
* \brief Canonicalizes the QNN subtract op.
* \param attrs The empty attribute.
* \param new_args The new mutated args to the call node.
* \param arg_types The types of input and output.
* \return The sequence of Relay ops for subtract op.
*/
Expr QnnSubtractCanonicalize(const Attrs& attrs,
const Array<Expr>& new_args,
const Array<tvm::relay::Type>& arg_types) {
// Get the args.
QnnBinaryOpArguments args(new_args);
// Get the input dtype and shape.
QnnBinaryOpTensorType input_type(arg_types, 0);
// TODO(shoubhik) - The lowering can be further optimized. Instead of inserting requantize in
// the start, we can insert requantize at the end if both input tensors have same qnn params. In
// that case, we can first subtract the tensors, add the zero point, and requantize at the end.
// This can be done in future.
// Since the input qnn params can be different than output qnn params, we first requantize the
// input tensors to the output qnn params. Then we call relay.subtract on the requantized inputs.
// This subtraction results in extra subtraction of the output zero point. We further add
// the zero point. The whole process can be represented using following equations
//
// scale_c * (Q_c - zp_c) = scale_a * (Q_a - zp_a) - scale_b * (Q_b - zp_b)
//
// After requantizing Q_a and Q_b, equation becomes,
// scale_c * (Q_c - zp_c) = scale_c * (Q_a' - zp_c) - scale_c * (Q_b' - zp_c)
// scale_c * (Q_c - zp_c) = scale_c * (Q_a' - Q_b')
//
// Comparing the LHS and RHS, it results in
// Q_c = Q_a' - Q_b' + zp_c
// The subtract op is done in int32 precision.
// Requantize LHS if necessary. Computes Q_a'
auto requantized_lhs = RequantizeOrUpcast(args.lhs, args.lhs_scale,
args.lhs_zero_point,
args.output_scale,
args.output_zero_point,
input_type.shape);
// Requantize RHS if necessary. Computes Q_b'
auto requantized_rhs = RequantizeOrUpcast(args.rhs, args.rhs_scale,
args.rhs_zero_point,
args.output_scale,
args.output_zero_point,
input_type.shape);
// Computes Q_a' - Q_b'
auto output = Subtract(requantized_lhs, requantized_rhs);
// Add zero point. Computes (Q_a' - Q_b') + zp_c
auto zero_scalar = MakeConstantScalar(DataType::Int(32), 0);
if (!IsEqualScalar(args.output_zero_point, zero_scalar)) {
output = Add(output, args.output_zero_point);
}
// Go back to lower precision.
return ConvertDtype(output, input_type.dtype);
}
// QNN Subtraction operator.
QNN_REGISTER_BINARY_OP("subtract")
.describe("Elementwise subtract with with broadcasting for quantized tensors.")
.set_support_level(11)
.set_attr<FTVMLegalize>("FTVMQnnCanonicalize", QnnSubtractCanonicalize);
} // namespace qnn
} // namespace relay
} // namespace tvm
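A numpy re-derivation of the first saturation test added below (inputs and output share scale 0.125 and zero point 0, so no requantize is needed and the lowering is just an int32 subtract followed by clipping):

import numpy as np

q_a = np.array([255, 1, 1, 0], dtype=np.int32)
q_b = np.array([255, 255, 128, 0], dtype=np.int32)
zp_c = 0
q_c = q_a - q_b + zp_c                            # Q_c = Q_a' - Q_b' + zp_c
q_c = np.clip(q_c, 0, 255).astype(np.uint8)       # -> [0, 0, 0, 0], the golden output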
@@ -84,7 +84,7 @@ static inline Expr Requantize(const Expr& data, const Array<IndexExpr>& input_sh
attrs->rounding = std::move(rounding);
attrs->out_dtype = std::move(out_dtype);
return RequantizeLower(data, input_scale, input_zero_point, output_scale, output_zero_point,
-attrs.operator->(), input_shape, out_dtype);
attrs.operator->(), input_shape, attrs->out_dtype);
}
static inline int64_t get_const_int(const tvm::PrimExpr& x) {
......
@@ -16,11 +16,9 @@
# under the License.
import tvm
-from tvm import te
import numpy as np
from tvm import relay
-from tvm.contrib import graph_runtime
-import topi.testing
def test_tflite_same_io_qnn_params():
data_dtype = 'uint8'
@@ -40,15 +38,15 @@ def test_tflite_same_io_qnn_params():
mod = relay.qnn.transform.CanonicalizeOps()(mod)
func = mod["main"]
x_datas = [np.array((140, 153, 165, 178)).reshape((1, 4)),
np.array((25, 153, 178, 216)).reshape((1, 4)),
np.array((25, 153, 216, 165)).reshape((1, 4))]
y_datas = [np.array((204, 178, 165, 140)).reshape((1, 4)),
np.array((204, 178, 191, 25)).reshape((1, 4)),
np.array((204, 178, 25, 191)).reshape((1, 4))]
golden_outputs = [np.array((217, 204, 203, 191)).reshape((1, 4)),
np.array((102, 204, 242, 114)).reshape((1, 4)),
np.array((102, 204, 114, 229)).reshape((1, 4))]
for i in range(0, 3):
x_data = x_datas[i]
@@ -78,15 +76,15 @@ def test_tflite_different_io_qnn_params():
mod = relay.qnn.transform.CanonicalizeOps()(mod)
func = mod["main"]
x_datas = [np.array((76, 140, 153, 172)).reshape((1, 4)),
np.array((133, 140, 146, 153)).reshape((1, 4)),
np.array((76, 140, 172, 146)).reshape((1, 4))]
y_datas = [np.array((136, 119, 128, 17)).reshape((1, 4)),
np.array((136, 119, 111, 94)).reshape((1, 4)),
np.array((136, 119, 17, 128)).reshape((1, 4))]
golden_outputs = [np.array((120, 154, 167, 124)).reshape((1, 4)),
np.array((158, 154, 154, 150)).reshape((1, 4)),
np.array((120, 154, 124, 163)).reshape((1, 4))]
for i in range(0, 3):
x_data = x_datas[i]
@@ -116,8 +114,8 @@ def test_saturation():
mod = relay.qnn.transform.CanonicalizeOps()(mod)
func = mod["main"]
x_data = np.array((255, 1, 1, 0)).reshape((1, 4))
y_data = np.array((255, 255, 128, 0)).reshape((1, 4))
golden_output = np.array((255, 255, 129, 0)).reshape((1, 4))
intrp = relay.create_executor("graph", ctx=tvm.cpu(0), target="llvm")
@@ -138,8 +136,8 @@ def test_saturation():
mod = relay.qnn.transform.CanonicalizeOps()(mod)
func = mod["main"]
x_data = np.array((255, 1, 1, 0)).reshape((1, 4))
y_data = np.array((255, 255, 127, 0)).reshape((1, 4))
golden_output = np.array((255, 129, 65, 0)).reshape((1, 4))
intrp = relay.create_executor("graph", ctx=tvm.cpu(0), target="llvm")
@@ -160,8 +158,8 @@ def test_saturation():
mod = relay.qnn.transform.CanonicalizeOps()(mod)
func = mod["main"]
x_data = np.array((255, 1, 1, 0)).reshape((1, 4))
y_data = np.array((255, 255, 127, 0)).reshape((1, 4))
golden_output = np.array((255, 129, 65, 0)).reshape((1, 4))
intrp = relay.create_executor("graph", ctx=tvm.cpu(0), target="llvm")
@@ -182,8 +180,8 @@ def test_saturation():
mod = relay.qnn.transform.CanonicalizeOps()(mod)
func = mod["main"]
x_data = np.array((255, 0, 1, 0)).reshape((1, 4))
y_data = np.array((0, 128, 64, 0)).reshape((1, 4))
golden_output = np.array((255, 255, 132, 0)).reshape((1, 4))
intrp = relay.create_executor("graph", ctx=tvm.cpu(0), target="llvm")
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
import numpy as np
from tvm import relay
def qnn_subtract_driver(x_datas, y_datas, golden_outputs,
scale_and_zp, data_dtype='uint8'):
# all x, y and golden outputs should be of the same length
assert len(x_datas) == len(y_datas)
assert len(y_datas) == len(golden_outputs)
x = relay.var("x", shape=(1, 4), dtype=data_dtype)
y = relay.var("y", shape=(1, 4), dtype=data_dtype)
lhs_scale = relay.const(scale_and_zp['lhs_scale'], 'float32')
lhs_zp = relay.const(scale_and_zp['lhs_zp'], 'int32')
rhs_scale = relay.const(scale_and_zp['rhs_scale'], 'float32')
rhs_zp = relay.const(scale_and_zp['rhs_zp'], 'int32')
output_scale = relay.const(scale_and_zp['output_scale'], 'float32')
output_zp = relay.const(scale_and_zp['output_zp'], 'int32')
z = relay.qnn.op.subtract(lhs=x, rhs=y,
lhs_scale=lhs_scale,
lhs_zero_point=lhs_zp,
rhs_scale=rhs_scale,
rhs_zero_point=rhs_zp,
output_scale=output_scale,
output_zero_point=output_zp)
func = relay.Function([x, y], z)
mod = tvm.IRModule.from_expr(func)
mod = relay.qnn.transform.CanonicalizeOps()(mod)
func = mod["main"]
for i in range(0, len(x_datas)):
x_data = x_datas[i]
y_data = y_datas[i]
golden_output = golden_outputs[i]
intrp = relay.create_executor("graph", ctx=tvm.cpu(0), target="llvm")
op_res = intrp.evaluate(func)(x_data, y_data)
np.testing.assert_equal(op_res.asnumpy(), golden_output)
def test_tflite_same_io_qnn_params():
scale_and_zp = {'lhs_scale': 0.00784314,
'lhs_zp': 127,
'rhs_scale': 0.00784314,
'rhs_zp': 127,
'output_scale': 0.00784314,
'output_zp': 127}
x_datas = [np.array((140, 153, 165, 178)).reshape((1, 4)),
np.array((25, 153, 178, 216)).reshape((1, 4)),
np.array((25, 153, 216, 165)).reshape((1, 4))]
y_datas = [np.array((204, 178, 165, 140)).reshape((1, 4)),
np.array((204, 178, 191, 25)).reshape((1, 4)),
np.array((204, 178, 25, 191)).reshape((1, 4))]
golden_outputs = [np.array((63, 102, 127, 165)).reshape((1, 4)),
np.array((0, 102, 114, 255)).reshape((1, 4)),
np.array((0, 102, 255, 101)).reshape((1, 4))]
qnn_subtract_driver(x_datas, y_datas, golden_outputs, scale_and_zp)
def test_tflite_different_io_qnn_params():
scale_and_zp = {'lhs_scale': 0.0156863,
'lhs_zp': 127,
'rhs_scale': 0.0117647,
'rhs_zp': 85,
'output_scale': 0.0235294,
'output_zp': 128}
x_datas = [np.array((76, 140, 153, 172)).reshape((1, 4)),
np.array((133, 140, 146, 153)).reshape((1, 4)),
np.array((76, 140, 172, 146)).reshape((1, 4))]
y_datas = [np.array((136, 119, 128, 17)).reshape((1, 4)),
np.array((136, 119, 111, 94)).reshape((1, 4)),
np.array((136, 119, 17, 128)).reshape((1, 4))]
golden_outputs = [np.array((68, 120, 123, 192)).reshape((1, 4)),
np.array((106, 120, 128, 140)).reshape((1, 4)),
np.array((68, 120, 192, 119)).reshape((1, 4))]
qnn_subtract_driver(x_datas, y_datas, golden_outputs, scale_and_zp)
def test_saturation():
# Same params
scale_and_zp = {'lhs_scale': 0.125,
'lhs_zp': 0,
'rhs_scale': 0.125,
'rhs_zp': 0,
'output_scale': 0.125,
'output_zp': 0}
x_data = [np.array((255, 1, 1, 0)).reshape((1, 4))]
y_data = [np.array((255, 255, 128, 0)).reshape((1, 4))]
golden_output = [np.array((0, 0, 0, 0)).reshape((1, 4))]
qnn_subtract_driver(x_data, y_data, golden_output, scale_and_zp)
# Same params, different scale
scale_and_zp = {'lhs_scale': 0.125,
'lhs_zp': 0,
'rhs_scale': 0.125,
'rhs_zp': 0,
'output_scale': 0.25,
'output_zp': 0}
x_data = [np.array((255, 1, 200, 0)).reshape((1, 4))]
y_data = [np.array((255, 255, 127, 0)).reshape((1, 4))]
golden_output = [np.array((0, 0, 36, 0)).reshape((1, 4))]
qnn_subtract_driver(x_data, y_data, golden_output, scale_and_zp)
# All params different
scale_and_zp = {'lhs_scale': 0.5,
'lhs_zp': 0,
'rhs_scale': 0.25,
'rhs_zp': 0,
'output_scale': 0.125,
'output_zp': 0}
x_data = [np.array((255, 0, 1, 0)).reshape((1, 4))]
y_data = [np.array((0, 128, 64, 0)).reshape((1, 4))]
golden_output = [np.array((255, 0, 0, 0)).reshape((1, 4))]
qnn_subtract_driver(x_data, y_data, golden_output, scale_and_zp)
if __name__ == '__main__':
test_tflite_same_io_qnn_params()
test_tflite_different_io_qnn_params()
test_saturation()