/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 *
 * \file realize.cc
 *
 * \brief Realize the simulated-quantization graph into a real
 *   low-precision graph.
 */

#include <tvm/relay/transform.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/attrs/annotation.h>
#include "./quantize.h"
#include "../pattern_util.h"
#include "../../qnn/util.h"

namespace tvm {
namespace relay {
namespace quantize {

// Pull the pass-infrastructure names from relay::transform into this namespace.
using namespace relay::transform;

// Forward declarations of the reference classes defined further below;
// QRealizeIntExpr is needed early as the return type of QRealizeIntExprNode::make.
class QRealizeExpr;
class QRealizeIntExpr;

/*!
 * \brief Base node for the temporary expressions produced while realizing
 *  a simulated-quantization graph.
 */
class QRealizeExprNode : public TempExprNode {
 public:
  /*! \brief The expression wrapped by this temporary node. */
  Expr data;
  static constexpr const char* _type_key = "relay.quantize.QRealizeExpr";
  TVM_DECLARE_BASE_OBJECT_INFO(QRealizeExprNode, TempExprNode);
};

/*! \brief Reference class for QRealizeExprNode. */
class QRealizeExpr : public TempExpr {
 public:
  TVM_DEFINE_OBJECT_REF_METHODS(QRealizeExpr, TempExpr, QRealizeExprNode);
};


/*!
 * \brief Temporary expression that carries a realized value together with
 *  its domain scale and data type.
 *
 *  Realize() turns it back into a float32 expression by computing
 *  `cast(data, float32) * dom_scale`.
 */
class QRealizeIntExprNode : public QRealizeExprNode {
 public:
  /*! \brief Scale that Realize() multiplies `data` by; read as a scalar constant by the rewrites. */
  Expr dom_scale;
  /*! \brief Data type recorded for `data`. */
  DataType dtype;

  /*! \brief Expose the fields to the attribute visitor. */
  void VisitAttrs(tvm::AttrVisitor* v) {
    v->Visit("data", &data);
    v->Visit("dom_scale", &dom_scale);
    v->Visit("dtype", &dtype);
  }

  /*! \brief Dequantize: build `cast(data, float32) * dom_scale`. */
  Expr Realize() const final;

  /*!
   * \brief Construct a QRealizeIntExpr from its three fields.
   * \param data The wrapped expression.
   * \param dom_scale The domain scale of `data`.
   * \param dtype The data type recorded for `data`.
   * \return The newly created reference.
   */
  TVM_DLL static QRealizeIntExpr make(Expr data, Expr dom_scale, DataType dtype);

  static constexpr const char * _type_key = "relay.quantize.QRealizeIntExpr";
  TVM_DECLARE_FINAL_OBJECT_INFO(QRealizeIntExprNode, QRealizeExprNode);
};

/*! \brief Reference class for QRealizeIntExprNode. */
class QRealizeIntExpr : public QRealizeExpr {
 public:
  TVM_DEFINE_OBJECT_REF_METHODS(QRealizeIntExpr, QRealizeExpr, QRealizeIntExprNode);
};


/*!
 * \brief Dequantize the wrapped expression.
 * \return `cast(data, float32) * dom_scale`.
 */
Expr QRealizeIntExprNode::Realize() const {
  // Convert to float32 first, then rescale back with the domain scale.
  Expr as_float = Cast(this->data, DataType::Float(32));
  return Multiply(as_float, this->dom_scale);
}

/*!
 * \brief Build a QRealizeIntExpr from its fields.
 * \param data The wrapped expression.
 * \param dom_scale The domain scale of `data`.
 * \param dtype The data type recorded for `data`.
 * \return Reference to the freshly allocated node.
 */
QRealizeIntExpr QRealizeIntExprNode::make(Expr data, Expr dom_scale, DataType dtype) {
  auto node = make_object<QRealizeIntExprNode>();
  node->dtype = std::move(dtype);
  node->dom_scale = std::move(dom_scale);
  node->data = std::move(data);
  return QRealizeIntExpr(node);
}


/*!
 * \brief Re-issue a call with new arguments.
 * \param ref_call The call whose op, attrs and type_args are reused.
 * \param args The arguments of the new call.
 * \return A call to the same operator applied to `args`.
 */
inline Expr ForwardOp(const Call& ref_call, const Array<Expr>& args) {
  Expr forwarded = CallNode::make(ref_call->op, args, ref_call->attrs, ref_call->type_args);
  return forwarded;
}


/*!
 * \brief Calculate `data * s1 / s2`, using a left shift when the ratio is a power of two.
 *
 *  The ratio `s1 / s2` must be larger than one unless the scales are equal
 *  (enforced by a CHECK on its base-2 logarithm).
 *
 * \param data The expression to rescale; assumed to already have the activation dtype.
 * \param s1 The current scale of `data`.
 * \param s2 The target scale.
 * \param dtype The dtype of the scalar constants and of the result.
 * \param data_shape Shape forwarded to qnn::FixedPointMultiply on the fallback path.
 * \return The rescaled expression.
 */
inline Expr MulAndDiv(Expr data, float s1, float s2, DataType dtype,
                      const Array<IndexExpr> &data_shape) {
  const QConfig& config = QConfig::Current();
  // Equal scales: nothing to do.
  if (s1 == s2) {
    return data;
  }

  const float ratio = s1 / s2;
  const float log2_ratio = std::log2(ratio);
  CHECK_GT(log2_ratio, 0);

  // Power-of-two ratio: a plain left shift is enough.
  const int whole_shift = static_cast<int>(log2_ratio);
  if (whole_shift == log2_ratio) {
    return LeftShift(data, MakeConstantScalar(dtype, whole_shift));
  }

  // Integral ratio: multiply by the constant directly.
  if (static_cast<int>(ratio) == ratio) {
    return Multiply(data, MakeConstantScalar(dtype, ratio));
  }

  // General case: fixed-point multiplication on int64, then cast back.
  Expr widened = Cast(data, DataType::Int(64));
  Expr scaled = qnn::FixedPointMultiply(widened, ratio, data_shape, config->rounding);
  return Cast(scaled, dtype);
}

/*!
 * \brief Realize a simulated_quantize call.
 *
 *  If the data argument is already a QRealizeIntExpr, it is requantized from
 *  its own domain scale to the requested one (shift when the ratio is a power
 *  of two, fixed-point multiply otherwise) and clipped. Otherwise the real
 *  valued input is quantized as `clip(round(data / dom_scale))`.
 *
 * \param ref_call The original simulated_quantize call.
 * \param new_args Rewritten arguments: data, dom_scale, clip_min, clip_max.
 * \param ctx Rewrite context (unused).
 * \return A QRealizeIntExpr wrapping the realized value and its domain scale.
 */
Expr QuantizeRealize(const Call& ref_call,
                     const Array<Expr>& new_args,
                     const ObjectRef& ctx) {
  const QConfig& cfg = QConfig::Current();
  // do not handle data type cast
  const auto param = ref_call->attrs.as<SimulatedQuantizeAttrs>();
  CHECK(param != nullptr);
  CHECK_EQ(param->rounding, "round");
  // new_args[0..3] are indexed below.
  CHECK_EQ(new_args.size(), 4);

  Expr dom_scale = new_args[1];
  Expr clip_min = new_args[2];
  Expr clip_max = new_args[3];

  float dom_scale_imm = GetScalarFromConstant<float>(dom_scale);
  float clip_min_imm = GetScalarFromConstant<float>(clip_min);
  float clip_max_imm = GetScalarFromConstant<float>(clip_max);

  // x * idom_scale = y * odom_scale
  // => y = x * idom_scale / odom_scale
  if (const auto* n = new_args[0].as<QRealizeIntExprNode>()) {
    // int32->int8
    Expr data = n->data;
    float idom_scale_imm = GetScalarFromConstant<float>(n->dom_scale);
    float odom_scale_imm = dom_scale_imm;
    if (idom_scale_imm == odom_scale_imm) {
      // same domain scale, only clip
      data = Clip(data, clip_min_imm, clip_max_imm);
      return QRealizeIntExprNode::make(data, dom_scale, n->dtype);
    }

    float shift_nbit = std::log2(odom_scale_imm / idom_scale_imm);
    CHECK_NE(shift_nbit, 0);
    if (static_cast<int>(shift_nbit) == shift_nbit) {
      if (shift_nbit > 0) {
        // use right shift
        if (cfg->round_for_shift) {
          // adding 2^(shift_nbit - 1) before shifting rounds to nearest
          float round_bias = std::pow(2.0, shift_nbit - 1);
          data = Add(data, MakeConstantScalar(cfg->dtype_activation,
                                              static_cast<int>(round_bias)));
        }
        data = RightShift(data, MakeConstantScalar(cfg->dtype_activation,
                                                   static_cast<int>(shift_nbit)));
      } else {
        // shift_nbit is negative here: the shift amount is its magnitude.
        // Passing the negative value would ask for a shift by a negative count.
        data = LeftShift(data, MakeConstantScalar(cfg->dtype_activation,
                                                  static_cast<int>(-shift_nbit)));
      }
      data = Clip(data, clip_min_imm, clip_max_imm);
      return QRealizeIntExprNode::make(data, dom_scale, n->dtype);
    } else {
      // scale ratio is not a power of two: use fixed-point multiplication on int64
      data = Cast(data, DataType::Int(64));
      data = qnn::FixedPointMultiply(data, idom_scale_imm / odom_scale_imm,
                                     ref_call->type_as<TensorTypeNode>()->shape,
                                     cfg->rounding);
      data = Cast(Clip(data, clip_min_imm, clip_max_imm), n->dtype);
      return QRealizeIntExprNode::make(data, dom_scale, n->dtype);
    }
  }

  // quantize from real
  CHECK(!new_args[0]->IsInstance<TempExprNode>());
  Expr data = new_args[0];
  Expr scaled_data = Multiply(data, MakeConstantScalar(DataType::Float(32), 1 / dom_scale_imm));
  Expr round_data = Clip(Round(scaled_data), clip_min_imm, clip_max_imm);
  // the rounded value is still stored as float32 at this point
  return QRealizeIntExprNode::make(round_data, dom_scale, DataType::Float(32));
}

/*!
 * \brief Run the FoldConstant pass on a single expression.
 * \param expr The expression to fold.
 * \return The folded "main" function if `expr` is itself a function,
 *  otherwise the body of that function.
 */
Expr FoldConstantOpt(const Expr& expr) {
  // Wrap the expression into a module so the module-level pass can be applied.
  auto folded_mod = transform::FoldConstant()(IRModule::FromExpr(expr));
  auto main_func = Downcast<Function>(folded_mod->Lookup("main"));
  if (expr.as<FunctionNode>() != nullptr) {
    return main_func;
  }
  return main_func->body;
}

// Attach QuantizeRealize as the "FQRealizeRewrite" rule of simulated_quantize.
RELAY_REGISTER_OP("relay.op.annotation.simulated_quantize")
.set_attr<FForwardRewrite>("FQRealizeRewrite", QuantizeRealize);


/*!
 * \brief Realize nn.conv2d on quantized operands.
 *
 *  The data operand is cast to cfg->dtype_input (if needed), the weight to
 *  cfg->dtype_weight, and the convolution accumulates in cfg->dtype_activation.
 *  The output domain scale is the constant-folded product of both input scales.
 *
 * \param ref_call The original conv2d call.
 * \param new_args Rewritten arguments (data, weight).
 * \param ctx Rewrite context (unused).
 * \return A QRealizeIntExpr, or a null Expr when neither argument is a TempExpr.
 */
Expr Conv2dRealize(const Call& ref_call,
                   const Array<Expr>& new_args,
                   const ObjectRef& ctx) {
  const QConfig& cfg = QConfig::Current();
  CHECK_EQ(new_args.size(), 2);

  const bool data_is_temp = new_args[0]->IsInstance<TempExprNode>();
  const bool weight_is_temp = new_args[1]->IsInstance<TempExprNode>();
  if (!(data_is_temp || weight_is_temp)) {
    return Expr(nullptr);
  }

  // Once one operand is quantized, both must be.
  const auto* lhs = new_args[0].as<QRealizeIntExprNode>();
  CHECK(lhs);
  const auto* rhs = new_args[1].as<QRealizeIntExprNode>();
  CHECK(rhs);

  Expr input = lhs->data;
  if (lhs->dtype != cfg->dtype_input) {
    input = Cast(input, cfg->dtype_input);
  }
  Expr weight = Cast(rhs->data, cfg->dtype_weight);

  // Copy the original attributes and override only the accumulation dtype.
  DataType out_dtype = cfg->dtype_activation;
  const auto ref_attrs = ref_call->attrs.as<Conv2DAttrs>();
  auto new_attrs = make_object<Conv2DAttrs>();
  *new_attrs = *ref_attrs;
  new_attrs->out_dtype = out_dtype;

  Expr conv = CallNode::make(ref_call->op, {input, weight}, Attrs(new_attrs),
                             ref_call->type_args);
  Expr out_scale = FoldConstantOpt(Multiply(lhs->dom_scale, rhs->dom_scale));
  return QRealizeIntExprNode::make(conv, out_scale, out_dtype);
}

// Attach Conv2dRealize as the "FQRealizeRewrite" rule of nn.conv2d.
RELAY_REGISTER_OP("nn.conv2d")
.set_attr<FForwardRewrite>("FQRealizeRewrite", Conv2dRealize);


/*!
 * \brief Realize nn.dense on quantized operands.
 *
 *  The data operand is cast to cfg->dtype_input (if needed), the weight to
 *  cfg->dtype_weight, and the result accumulates in cfg->dtype_activation.
 *  The output domain scale is the constant-folded product of both input scales.
 *
 * \param ref_call The original dense call.
 * \param new_args Rewritten arguments (data, weight).
 * \param ctx Rewrite context (unused).
 * \return A QRealizeIntExpr, or a null Expr unless both arguments are TempExprs.
 */
Expr DenseRealize(const Call& ref_call,
                  const Array<Expr>& new_args,
                  const ObjectRef& ctx) {
  const QConfig& cfg = QConfig::Current();
  CHECK_EQ(new_args.size(), 2);
  if (!new_args[0]->IsInstance<TempExprNode>() || !new_args[1]->IsInstance<TempExprNode>()) {
    return Expr(nullptr);
  }
  // Both operands are TempExprs here; make sure they are the quantized kind
  // before dereferencing them (mirrors the checks in Conv2dRealize).
  const auto* lhs = new_args[0].as<QRealizeIntExprNode>();
  CHECK(lhs);
  const auto* rhs = new_args[1].as<QRealizeIntExprNode>();
  CHECK(rhs);

  Expr ldata = lhs->data;
  if (lhs->dtype != cfg->dtype_input) {
    ldata = Cast(ldata, cfg->dtype_input);
  }
  Expr rdata = Cast(rhs->data, cfg->dtype_weight);

  // Copy the original attributes and override only the accumulation dtype.
  const auto ref_attrs = ref_call->attrs.as<DenseAttrs>();
  auto attrs = make_object<DenseAttrs>();
  *attrs = *ref_attrs;
  DataType out_dtype = cfg->dtype_activation;
  attrs->out_dtype = out_dtype;

  Expr ret = CallNode::make(ref_call->op,
          {ldata, rdata}, Attrs(attrs), ref_call->type_args);
  Expr mul = Multiply(lhs->dom_scale, rhs->dom_scale);
  Expr dom_scale = FoldConstantOpt(mul);
  return QRealizeIntExprNode::make(ret, dom_scale, out_dtype);
}

// Attach DenseRealize as the "FQRealizeRewrite" rule of nn.dense.
RELAY_REGISTER_OP("nn.dense")
.set_attr<FForwardRewrite>("FQRealizeRewrite", DenseRealize);


/*!
 * \brief Realize multiply on quantized operands.
 *
 *  Both operands are cast to cfg->dtype_activation when needed; the output
 *  domain scale is the constant-folded product of both input scales.
 *
 * \param ref_call The original multiply call.
 * \param new_args Rewritten arguments (lhs, rhs).
 * \param ctx Rewrite context (unused).
 * \return A QRealizeIntExpr when both operands are quantized, otherwise a
 *  null Expr (in which case neither operand may be a TempExpr).
 */
Expr MulRealize(const Call& ref_call,
                const Array<Expr>& new_args,
                const ObjectRef& ctx) {
  const QConfig& cfg = QConfig::Current();
  CHECK_EQ(new_args.size(), 2);

  const auto* lhs = new_args[0].as<QRealizeIntExprNode>();
  const auto* rhs = new_args[1].as<QRealizeIntExprNode>();
  if (lhs == nullptr || rhs == nullptr) {
    CHECK(!new_args[0]->IsInstance<TempExprNode>() && !new_args[1]->IsInstance<TempExprNode>());
    return Expr(nullptr);
  }

  // execute the operation with activation data type.
  DataType act_dtype = cfg->dtype_activation;
  Expr lhs_data = lhs->data;
  if (lhs->dtype != act_dtype) {
    lhs_data = Cast(lhs_data, act_dtype);
  }
  Expr rhs_data = rhs->data;
  if (rhs->dtype != act_dtype) {
    rhs_data = Cast(rhs_data, act_dtype);
  }

  Expr product = ForwardOp(ref_call, {lhs_data, rhs_data});
  Expr out_scale = FoldConstantOpt(Multiply(lhs->dom_scale, rhs->dom_scale));
  return QRealizeIntExprNode::make(product, out_scale, act_dtype);
}

// Attach MulRealize as the "FQRealizeRewrite" rule of multiply.
RELAY_REGISTER_OP("multiply")
.set_attr<FForwardRewrite>("FQRealizeRewrite", MulRealize);


/*!
 * \brief Pick the common domain scale for a group of quantized operands.
 * \param nptrs The quantized operands.
 * \return For exactly two operands, the smaller of their scales; otherwise
 *  `global_scale / 2^(nbit_activation - 1)` from the current QConfig.
 */
float ChooseDomScale(const std::vector<const QRealizeIntExprNode*>& nptrs) {
  if (nptrs.size() != 2) {
    const QConfig& cfg = QConfig::Current();
    float scale = cfg->global_scale;
    return scale / std::pow(2.0, cfg->nbit_activation - 1);
  }

  // x = a * s1, y = b * s2
  // x + y = (a * s1 / s2 + b) * s2, if s1 > s2
  //       = (a + b * s2 / s1) * s1, if s2 > s1
  const float lhs_scale = GetScalarFromConstant<float>(nptrs[0]->dom_scale);
  const float rhs_scale = GetScalarFromConstant<float>(nptrs[1]->dom_scale);
  if (lhs_scale > rhs_scale) {
    return rhs_scale;
  }
  return lhs_scale;
}


/*!
 * \brief Unify the data type and the domain scale of quantized arguments.
 *
 * \param ref_args The original (pre-rewrite) arguments, one per entry of `args`.
 * \param args The rewritten arguments; every one must be a QRealizeIntExpr.
 * \param dtype_ptr Output: the common data type that was chosen.
 * \param scale_ptr Output: the common domain scale as a float32 constant.
 * \return The argument data expressions, cast and rescaled to the common dtype/scale.
 */
Array<Expr> UnifyDTypeScale(const Array<Expr>& ref_args, const Array<Expr>& args,
                            DataType* dtype_ptr, Expr* scale_ptr) {
  static const Op& simulated_quantize = Op::Get("relay.op.annotation.simulated_quantize");
  const QConfig& cfg = QConfig::Current();

  // Collect the quantized nodes and their raw data expressions.
  std::vector<const QRealizeIntExprNode*> quantized;
  Array<Expr> unified;
  for (size_t i = 0; i < args.size(); ++i) {
    const auto* node = args[i].as<QRealizeIntExprNode>();
    CHECK(node);
    quantized.push_back(node);
    unified.push_back(node->data);
  }

  // unify the data type
  CHECK_EQ(ref_args.size(), args.size());
  DataType common_dtype = cfg->dtype_activation;
  if (unified.size() == 2 && quantized[1]->dtype == cfg->dtype_input) {
    common_dtype = cfg->dtype_input;
  }

  for (size_t i = 0; i < unified.size(); ++i) {
    if (quantized[i]->dtype != common_dtype) {
      unified.Set(i, Cast(unified[i], common_dtype));
      continue;
    }
    const auto* ref_call = ref_args[i].as<CallNode>();
    const bool from_input_quantize =
        ref_call != nullptr && ref_call->op.same_as(simulated_quantize) &&
        ref_call->attrs.as<SimulatedQuantizeAttrs>()->kind == kQInput;
    if (from_input_quantize) {
      // Cast down to the input dtype and stop fusion there before casting back.
      Expr narrowed = StopFusion(Cast(unified[i], cfg->dtype_input));
      unified.Set(i, Cast(narrowed, common_dtype));
    }
  }

  // unify the dom_scale
  const float target_scale = ChooseDomScale(quantized);
  Expr common_scale = MakeConstantScalar(DataType::Float(32), target_scale);
  for (size_t i = 0; i < unified.size(); ++i) {
    const float own_scale = GetScalarFromConstant<float>(quantized[i]->dom_scale);
    const auto& shape = ref_args[i]->type_as<TensorTypeNode>()->shape;
    unified.Set(i, MulAndDiv(unified[i], own_scale, target_scale, common_dtype, shape));
  }

  *dtype_ptr = common_dtype;
  *scale_ptr = common_scale;
  return unified;
}

/*!
 * \brief Realize add on quantized operands.
 * \param ref_call The original add call.
 * \param new_args Rewritten arguments (lhs, rhs).
 * \param ctx Rewrite context (unused).
 * \return A QRealizeIntExpr over operands unified by UnifyDTypeScale when both
 *  are quantized, otherwise a null Expr (neither operand may be a TempExpr).
 */
Expr AddRealize(const Call& ref_call,
                const Array<Expr>& new_args,
                const ObjectRef& ctx) {
  CHECK_EQ(new_args.size(), 2);

  const bool both_quantized =
      new_args[0].as<QRealizeIntExprNode>() && new_args[1].as<QRealizeIntExprNode>();
  if (!both_quantized) {
    CHECK(!new_args[0]->IsInstance<TempExprNode>() && !new_args[1]->IsInstance<TempExprNode>());
    return Expr(nullptr);
  }

  DataType out_dtype;
  Expr out_scale;
  Array<Expr> unified = UnifyDTypeScale(ref_call->args, new_args, &out_dtype, &out_scale);
  Expr sum = ForwardOp(ref_call, unified);
  return QRealizeIntExprNode::make(sum, out_scale, out_dtype);
}

// Attach AddRealize as the "FQRealizeRewrite" rule of add.
RELAY_REGISTER_OP("add")
.set_attr<FForwardRewrite>("FQRealizeRewrite", AddRealize);

/*!
 * \brief Realize clip on a quantized operand.
 *
 *  The clip bounds of the original call are divided by the operand's domain
 *  scale so that they apply to the quantized values.
 *
 * \param ref_call The original clip call.
 * \param new_args Rewritten arguments (a single operand).
 * \param ctx Rewrite context (unused).
 * \return A QRealizeIntExpr when the operand is quantized, otherwise a null
 *  Expr (the operand may then not be a TempExpr).
 */
Expr ClipRealize(const Call& ref_call,
                 const Array<Expr>& new_args,
                 const ObjectRef& ctx) {
  CHECK_EQ(new_args.size(), 1);

  const auto* quantized = new_args[0].as<QRealizeIntExprNode>();
  if (quantized == nullptr) {
    CHECK(!new_args[0]->IsInstance<TempExprNode>());
    return Expr(nullptr);
  }

  const auto ref_attrs = ref_call->attrs.as<ClipAttrs>();
  auto scaled_attrs = make_object<ClipAttrs>();
  double scale = GetScalarFromConstant<float>(quantized->dom_scale);
  scaled_attrs->a_min = ref_attrs->a_min / scale;
  scaled_attrs->a_max = ref_attrs->a_max / scale;

  Expr clipped = CallNode::make(ref_call->op, {quantized->data}, Attrs(scaled_attrs),
                                ref_call->type_args);
  return QRealizeIntExprNode::make(clipped, quantized->dom_scale, quantized->dtype);
}

// Attach ClipRealize as the "FQRealizeRewrite" rule of clip.
RELAY_REGISTER_OP("clip")
.set_attr<FForwardRewrite>("FQRealizeRewrite", ClipRealize);


/*!
 * \brief Realize concatenate on a tuple of quantized fields.
 *
 *  Whether the quantized path is taken is decided from the first tuple field
 *  only; on that path UnifyDTypeScale requires every field to be quantized.
 *
 * \param ref_call The original concatenate call.
 * \param new_args Rewritten arguments (a single tuple).
 * \param ctx Rewrite context (unused).
 * \return A QRealizeIntExpr over the unified fields, or a null Expr when the
 *  first field is not quantized.
 */
Expr ConcatenateRealize(const Call& ref_call,
                        const Array<Expr>& new_args,
                        const ObjectRef& ctx) {
  CHECK_EQ(new_args.size(), 1);
  CHECK_EQ(ref_call->args.size(), 1);

  const auto* tuple = new_args[0].as<TupleNode>();
  const auto* ref_tuple = ref_call->args[0].as<TupleNode>();
  CHECK(tuple);
  CHECK(ref_tuple);
  const Array<Expr>& fields = tuple->fields;
  const Array<Expr>& ref_fields = ref_tuple->fields;

  if (!fields[0].as<QRealizeIntExprNode>()) {
    // NOTE(review): this inspects the call arguments (the tuple itself),
    // not the individual tuple fields.
    for (size_t i = 0; i < new_args.size(); ++i) {
      CHECK(!new_args[i]->IsInstance<TempExprNode>());
    }
    return Expr(nullptr);
  }

  DataType out_dtype;
  Expr out_scale;
  Array<Expr> unified = UnifyDTypeScale(ref_fields, fields, &out_dtype, &out_scale);
  Expr concat = ForwardOp(ref_call, {TupleNode::make(unified)});
  return QRealizeIntExprNode::make(concat, out_scale, out_dtype);
}

// Attach ConcatenateRealize as the "FQRealizeRewrite" rule of concatenate.
RELAY_REGISTER_OP("concatenate")
.set_attr<FForwardRewrite>("FQRealizeRewrite", ConcatenateRealize);


/*!
 * \brief Forward the original operator onto the quantized data unchanged.
 * \param ref_call The original call.
 * \param new_args Rewritten arguments (a single operand).
 * \param ctx Rewrite context (unused).
 * \return A QRealizeIntExpr keeping the operand's scale and dtype, or a null
 *  Expr when the operand is not quantized (it may then not be a TempExpr).
 */
Expr IdentityRealize(const Call& ref_call,
                     const Array<Expr>& new_args,
                     const ObjectRef& ctx) {
  CHECK_EQ(new_args.size(), 1);

  const auto* quantized = new_args[0].as<QRealizeIntExprNode>();
  if (quantized == nullptr) {
    CHECK(!new_args[0]->IsInstance<TempExprNode>());
    return Expr(nullptr);
  }

  Expr forwarded = ForwardOp(ref_call, {quantized->data});
  return QRealizeIntExprNode::make(forwarded, quantized->dom_scale, quantized->dtype);
}

// Operators realized by IdentityRealize: the op is re-applied to the quantized
// data and the scale/dtype are carried over unchanged.
RELAY_REGISTER_OP("nn.relu")
.set_attr<FForwardRewrite>("FQRealizeRewrite", IdentityRealize);

RELAY_REGISTER_OP("strided_slice")
.set_attr<FForwardRewrite>("FQRealizeRewrite", IdentityRealize);

RELAY_REGISTER_OP("annotation.stop_fusion")
.set_attr<FForwardRewrite>("FQRealizeRewrite", IdentityRealize);

/*!
 * \brief For unary operators that cast their input to cfg->dtype_input first.
 * \param ref_call The original call.
 * \param new_args Rewritten arguments (a single operand).
 * \param ctx Rewrite context (unused).
 * \return A QRealizeIntExpr with dtype cfg->dtype_input and the operand's
 *  scale, or a null Expr when the operand is not quantized (it may then not
 *  be a TempExpr).
 */
Expr CastDtypeInputRealize(const Call& ref_call,
                           const Array<Expr>& new_args,
                           const ObjectRef& ctx) {
  const QConfig& cfg = QConfig::Current();
  CHECK_EQ(new_args.size(), 1);

  const auto* quantized = new_args[0].as<QRealizeIntExprNode>();
  if (quantized == nullptr) {
    CHECK(!new_args[0]->IsInstance<TempExprNode>());
    return Expr(nullptr);
  }

  Expr narrowed = Cast(quantized->data, cfg->dtype_input);
  Expr forwarded = ForwardOp(ref_call, {narrowed});
  return QRealizeIntExprNode::make(forwarded, quantized->dom_scale, cfg->dtype_input);
}

// Attach CastDtypeInputRealize as the "FQRealizeRewrite" rule of nn.max_pool2d.
RELAY_REGISTER_OP("nn.max_pool2d")
.set_attr<FForwardRewrite>("FQRealizeRewrite", CastDtypeInputRealize);


/*!
 * \brief Realize average pooling on a quantized operand.
 *
 *  The operand is cast to cfg->dtype_activation when it has a different dtype
 *  before the original operator is applied.
 *
 * \param ref_call The original pooling call.
 * \param new_args Rewritten arguments (a single operand).
 * \param ctx Rewrite context (unused).
 * \return A QRealizeIntExpr with dtype cfg->dtype_activation and the operand's
 *  scale, or a null Expr when the operand is not quantized (it may then not
 *  be a TempExpr).
 */
Expr AvgPoolRealize(const Call& ref_call,
                    const Array<Expr>& new_args,
                    const ObjectRef& ctx) {
  const QConfig& cfg = QConfig::Current();
  CHECK_EQ(new_args.size(), 1);

  const auto* quantized = new_args[0].as<QRealizeIntExprNode>();
  if (quantized == nullptr) {
    CHECK(!new_args[0]->IsInstance<TempExprNode>());
    return Expr(nullptr);
  }

  Expr input = quantized->data;
  if (quantized->dtype != cfg->dtype_activation) {
    input = Cast(quantized->data, cfg->dtype_activation);
  }
  Expr pooled = ForwardOp(ref_call, {input});
  return QRealizeIntExprNode::make(pooled, quantized->dom_scale, cfg->dtype_activation);
}

// Both average-pooling operators are realized by AvgPoolRealize.
RELAY_REGISTER_OP("nn.avg_pool2d")
.set_attr<FForwardRewrite>("FQRealizeRewrite", AvgPoolRealize);

RELAY_REGISTER_OP("nn.global_avg_pool2d")
.set_attr<FForwardRewrite>("FQRealizeRewrite", AvgPoolRealize);

/*!
 * \brief Realize annotation.cast_hint by casting the quantized data.
 * \param ref_call The original cast_hint call; its CastHintAttrs give the target dtype.
 * \param new_args Rewritten arguments (a single operand).
 * \param ctx Rewrite context (unused).
 * \return A QRealizeIntExpr with the hinted dtype and the operand's scale, or
 *  a null Expr when the operand is not quantized (it may then not be a TempExpr).
 */
Expr CastHintRealize(const Call& ref_call,
                     const Array<Expr>& new_args,
                     const ObjectRef& ctx) {
  const auto hint = ref_call->attrs.as<CastHintAttrs>();
  CHECK_EQ(new_args.size(), 1);

  const auto* quantized = new_args[0].as<QRealizeIntExprNode>();
  if (quantized == nullptr) {
    CHECK(!new_args[0]->IsInstance<TempExprNode>());
    return Expr(nullptr);
  }

  Expr casted = Cast(quantized->data, hint->dtype);
  return QRealizeIntExprNode::make(casted, quantized->dom_scale, hint->dtype);
}

// Attach CastHintRealize as the "FQRealizeRewrite" rule of annotation.cast_hint.
RELAY_REGISTER_OP("annotation.cast_hint")
.set_attr<FForwardRewrite>("FQRealizeRewrite", CastHintRealize);

/*!
 * \brief Create the function pass that realizes a simulated-quantization graph.
 * \return A function pass named "QuantizeRealize" (opt level 1) that applies
 *  ForwardRewrite with the "FQRealizeRewrite" operator attribute.
 */
Pass QuantizeRealizePass() {
  using RealizeFn = runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)>;
  RealizeFn realize = [=](Function func, IRModule mod, PassContext pass_ctx) {
    Expr rewritten = ForwardRewrite(func, "FQRealizeRewrite", nullptr, nullptr);
    return Downcast<Function>(rewritten);
  };
  return CreateFunctionPass(realize, 1, "QuantizeRealize", {});
}

// Expose the pass constructor under the global name "relay._quantize.QuantizeRealize".
TVM_REGISTER_GLOBAL("relay._quantize.QuantizeRealize")
.set_body_typed(QuantizeRealizePass);

}  // namespace quantize
}  // namespace relay
}  // namespace tvm