/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file src/relay/qnn/op/dequantize.cc
 * \brief QNN dequantize operator. The dequantize operator converts a tensor from the
 * quantized domain to the unquantized (float32) domain.
 */

#include <tvm/relay/analysis.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/qnn/attrs.h>
#include "../../pass/pattern_util.h"
#include "../util.h"

namespace tvm {
namespace relay {
namespace qnn {

TVM_REGISTER_NODE_TYPE(DequantizeAttrs);

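/*
 * \brief Type relation for the qnn.dequantize op.
 * \param types The types of input and output.
 * \param num_inputs The number of inputs.
 * \param attrs The op attributes.
 * \param reporter The type reporter that assigns the output type.
 * \return True if the input type could be resolved.
 */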
bool DequantizeRel(const Array<Type>& types,
                   int num_inputs,
                   const Attrs& attrs,
                   const TypeReporter& reporter) {
  CHECK_EQ(types.size(), 2);
  const auto* data = types[0].as<TensorTypeNode>();
  if (data == nullptr) {
    return false;
  }
  const auto input_dtype = data->dtype;
  CHECK(input_dtype == DataType::Int(8) ||
        input_dtype == DataType::UInt(8) ||
        input_dtype == DataType::Int(32))
    << "Input type should be one of the quantized types [unit8, int8, int32] but was "
    <<  input_dtype;
  const Array<tvm::Expr> oshape = data->shape;
  // Assign the output type. Dequantized output is always float32.
  reporter->Assign(types[1], TensorTypeNode::make(oshape, DataType::Float(32)));
  return true;
}

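/*
 * \brief Makes the call node for the qnn.dequantize op.
 * \param data The quantized input tensor.
 * \param input_scale The scale of the quantized input.
 * \param input_zero_point The zero point of the quantized input.
 * \return The call node for qnn.dequantize.
 */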
Expr MakeDequantize(Expr data,
                    double input_scale,
                    int32_t input_zero_point) {
  auto attrs = make_node<DequantizeAttrs>();
  attrs->input_scale = input_scale;
  attrs->input_zero_point = input_zero_point;
  // real_value = scale * (quantized_value - zero_point)
  // A more detailed explanation can be found here - https://github.com/google/gemmlowp/blob/master/doc/quantization.md
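  // For example, with input_scale = 0.5 and input_zero_point = 10, a quantized
  // value of 14 maps to 0.5 * (14 - 10) = 2.0.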
  static const Op& op = Op::Get("qnn.dequantize");
  return CallNode::make(op, {data}, Attrs(attrs), {});
}

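/*
 * \brief Lowers qnn.dequantize into a sequence of existing Relay ops,
 *        computing input_scale * (data - input_zero_point) in float32.
 * \param input_tensor The quantized input tensor.
 * \param attrs The dequantize attributes carrying input_scale and input_zero_point.
 * \return The lowered Relay expression producing the float32 output.
 */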
Expr DequantizeLower(const Expr& input_tensor,
                     const DequantizeAttrs* attrs) {
  const auto input_zero_point = MakeConstantScalar(DataType::Int(32), attrs->input_zero_point);
  const auto input_scale = MakeConstantScalar(DataType::Float(32), attrs->input_scale);
  auto shift = Subtract(Cast(input_tensor, DataType::Int(32)), input_zero_point);
  auto scaled_output = Multiply(Cast(shift, DataType::Float(32)), input_scale);
  return scaled_output;
}

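/*
 * \brief Canonicalizes the qnn.dequantize op to a sequence of existing Relay ops.
 * \param attrs The dequantize op attributes.
 * \param new_args The mutated args of the call node.
 * \param types The types of input and output.
 * \return The lowered Relay expression for the dequantize op.
 */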
Expr DequantizeQnnCanonicalize(const Attrs& attrs,
                               const Array<Expr>& new_args,
                               const Array<tvm::relay::Type>& types) {
  CHECK_EQ(new_args.size(), 1);
  auto& data = new_args[0];
  const auto* dequantize_attrs = attrs.as<DequantizeAttrs>();
  CHECK(dequantize_attrs != nullptr);
  CHECK_EQ(types.size(), 2);
  return DequantizeLower(data, dequantize_attrs);
}

RELAY_REGISTER_OP("qnn.dequantize")
.describe(R"code(Dequantizes the input and produces float32 output.
The input is always quantized (int8, uint8 or int32) and is converted to float32 using
the given input scale and zero point.
- **data**: Quantized tensor of any shape to dequantize.
)code" TVM_ADD_FILELINE)
.set_attrs_type<DequantizeAttrs>()
.set_num_inputs(1)
.add_argument("data", "Tensor", "The tensor to dequantize.")
.set_support_level(11)
.add_type_rel("Dequantize", DequantizeRel)
.set_attr<FTVMLegalize>("FTVMQnnCanonicalize", DequantizeQnnCanonicalize);

TVM_REGISTER_API("relay.qnn.op._make.dequantize")
.set_body_typed(MakeDequantize);

}  // namespace qnn
}  // namespace relay
}  // namespace tvm