/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 * 
 *   http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * Copyright (c) 2018 by Contributors
 * \file simplify_inference.cc
 */
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/attrs/nn.h>
#include <tvm/relay/transform.h>
#include <tvm/relay/op.h>
#include "./pattern_util.h"

namespace tvm {
namespace relay {

/*!
 * \brief Expand an inference-time batch_norm into plain elementwise arithmetic.
 *
 * The returned expression computes
 *   out   = data * scale + shift
 * where
 *   scale = 1 / sqrt(moving_var + epsilon)   (times gamma when attrs->scale)
 *   shift = -moving_mean * scale             (plus beta when attrs->center)
 * and scale/shift are expanded so they line up with attrs->axis of data.
 *
 * \param attrs Call attributes; must be BatchNormAttrs.
 * \param data The input expression being normalized.
 * \param gamma Scale factor, used only when attrs->scale is set.
 * \param beta Offset, used only when attrs->center is set.
 * \param moving_mean The running mean.
 * \param moving_var The running variance.
 * \param tdata Type of data; must be a TensorType.
 * \return The unpacked expression.
 */
Expr BatchNormToInferUnpack(const Attrs attrs,
                            Expr data,
                            Expr gamma,
                            Expr beta,
                            Expr moving_mean,
                            Expr moving_var,
                            Type tdata) {
  auto ttype = tdata.as<TensorTypeNode>();
  CHECK(ttype);
  const auto param = attrs.as<BatchNormAttrs>();
  // Fail loudly on mismatched attrs instead of dereferencing a null pointer.
  CHECK(param);

  // Constants are created in the dtype of the data so the arithmetic stays in one dtype.
  Expr epsilon = MakeConstantScalar(ttype->dtype, static_cast<float>(param->epsilon));
  Expr var_add_eps = Add(moving_var, epsilon);
  Expr sqrt_var = Sqrt(var_add_eps);
  Expr scale = Divide(MakeConstantScalar(ttype->dtype, 1.0f), sqrt_var);

  if (param->scale) {
    scale = Multiply(scale, gamma);
  }
  Expr neg_mean = Negative(moving_mean);
  Expr shift = Multiply(neg_mean, scale);
  if (param->center) {
    shift = Add(shift, beta);
  }

  // Normalize a negative axis in signed arithmetic; mixing int with size_t would
  // go through unsigned wrap-around and a narrowing conversion.
  const int ndim = static_cast<int>(ttype->shape.size());
  const int axis = (param->axis < 0) ? param->axis + ndim : param->axis;
  scale = ExpandBiasToMatchAxis(scale, ndim, {axis});
  shift = ExpandBiasToMatchAxis(shift, ndim, {axis});

  Expr out = Multiply(data, scale);
  out = Add(out, shift);
  return out;
}

/*!
 * \brief Expand layer_norm into plain reduction and elementwise arithmetic.
 *
 * The returned expression computes
 *   out = (data - mean) / sqrt(var + epsilon)
 * with mean and var taken over attrs->axis of data, then multiplies by gamma
 * when attrs->scale is set and adds beta when attrs->center is set. gamma and
 * beta are expanded so they line up with attrs->axis of data.
 *
 * \param attrs Call attributes; must be LayerNormAttrs.
 * \param data The input expression being normalized.
 * \param gamma Scale factor, used only when attrs->scale is set.
 * \param beta Offset, used only when attrs->center is set.
 * \param tdata Type of data; must be a TensorType.
 * \return The unpacked expression.
 */
Expr LayerNormToInferUnpack(const Attrs attrs,
                            Expr data,
                            Expr gamma,
                            Expr beta,
                            Type tdata) {
  auto ttype = tdata.as<TensorTypeNode>();
  CHECK(ttype);
  const auto param = attrs.as<LayerNormAttrs>();
  CHECK(param);

  // Build epsilon in the dtype of the data (as BatchNormToInferUnpack does) rather
  // than a fixed float32, so non-float32 inputs do not end up with mixed dtypes.
  Expr epsilon = MakeConstantScalar(ttype->dtype, static_cast<float>(param->epsilon));
  // The trailing (true, false) flags are forwarded unchanged; presumably
  // keepdims=true, exclude=false -- confirm against pattern_util.h.
  Expr mean = Mean(data, {param->axis}, true, false);
  Expr var = Variance(data, mean, {param->axis}, true, false);
  Expr denom = Sqrt(Add(var, epsilon));
  Expr out = Divide(Subtract(data, mean), denom);

  // Normalize a negative axis in signed arithmetic; mixing int with size_t would
  // go through unsigned wrap-around and a narrowing conversion.
  const int ndim = static_cast<int>(ttype->shape.size());
  const int axis = (param->axis < 0) ? param->axis + ndim : param->axis;
  if (param->scale) {
    out = Multiply(out, ExpandBiasToMatchAxis(gamma, ndim, {axis}));
  }
  if (param->center) {
    out = Add(out, ExpandBiasToMatchAxis(beta, ndim, {axis}));
  }
  return out;
}

/*!
 * \brief Mutator that rewrites nn.batch_norm, nn.dropout and nn.layer_norm
 *  into simpler expressions.
 *
 * - nn.layer_norm calls are replaced directly when the call is visited.
 * - nn.batch_norm and nn.dropout are replaced when field 0 of their result
 *   is read through a TupleGetItem; reads of any other field are left alone.
 */
class InferenceSimplifier : public ExprMutator {
 public:
  Expr VisitExpr_(const TupleGetItemNode* n) final {
    static const Op& batch_norm = Op::Get("nn.batch_norm");
    static const Op& dropout = Op::Get("nn.dropout");

    // Mutate the children first so the tuple below is already rewritten.
    Expr new_e = ExprMutator::VisitExpr_(n);
    const auto* new_n = new_e.as<TupleGetItemNode>();
    CHECK(new_n);
    // Only the first output is rewritten.
    if (new_n->index != 0) {
      return new_e;
    }
    if (const auto* call = new_n->tuple.as<CallNode>()) {
      if (call->op.same_as(batch_norm)) {
        // The data type was recorded by the CallNode visitor, keyed by the
        // (already mutated) data argument.
        return BatchNormToInferUnpack(call->attrs, call->args[0], call->args[1], call->args[2],
                                      call->args[3], call->args[4], ty_map_.at(call->args[0]));
      } else if (call->op.same_as(dropout)) {
        // Field 0 of dropout is replaced by the dropout input itself.
        return call->args[0];
      }
    }
    return new_e;
  }

  // Marked final like the TupleGetItemNode overload, so a signature mismatch with
  // the base class is a compile error instead of a silently unused method.
  Expr VisitExpr_(const CallNode* n) final {
    static const Op& batch_norm = Op::Get("nn.batch_norm");
    static const Op& layer_norm = Op::Get("nn.layer_norm");
    auto new_n = ExprMutator::VisitExpr_(n);
    if (n->op.same_as(batch_norm)) {
      // Remember the checked type of the original data argument under the
      // mutated data argument, for use when the TupleGetItem is visited.
      ty_map_[new_n.as<CallNode>()->args[0]] = n->args[0]->checked_type();
    } else if (n->op.same_as(layer_norm)) {
      const auto* call = new_n.as<CallNode>();
      return LayerNormToInferUnpack(call->attrs, call->args[0], call->args[1],
                                    call->args[2], n->args[0]->checked_type());
    }
    return new_n;
  }

 private:
  /*! \brief Maps a mutated batch_norm data argument to the checked type of the original one. */
  std::unordered_map<Expr, Type, NodeHash, NodeEqual> ty_map_;
};

/*!
 * \brief Run InferenceSimplifier over an expression.
 * \param e The expression to rewrite.
 * \return The result of mutating e with a fresh InferenceSimplifier.
 */
Expr SimplifyInference(const Expr& e) {
  InferenceSimplifier simplifier;
  return simplifier.Mutate(e);
}

namespace transform {

/*!
 * \brief Build the function-level pass named "SimplifyInference".
 *
 * The pass applies the expression-level SimplifyInference to each function,
 * is created with opt level 0, and lists "InferType" as a required pass.
 *
 * \return The created pass.
 */
Pass SimplifyInference() {
  const int opt_level = 0;
  runtime::TypedPackedFunc<Function(Function, Module, PassContext)> simplify_func =
    [=](Function f, Module m, PassContext pc) {
    // The module and pass context are not used by this rewrite.
    Expr simplified = SimplifyInference(f);
    return Downcast<Function>(simplified);
  };
  return CreateFunctionPass(simplify_func, opt_level, "SimplifyInference",
                            {ir::StringImm::make("InferType")});
}

// Register the pass factory under the name "relay._transform.SimplifyInference".
// The unqualified SimplifyInference here names transform::SimplifyInference(),
// i.e. the function returning a Pass, since this sits inside namespace transform.
TVM_REGISTER_API("relay._transform.SimplifyInference")
.set_body_typed(SimplifyInference);

}  // namespace transform

}  // namespace relay
}  // namespace tvm