/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file src/tvm/relay/ir/alpha_equal.cc
 * \brief Alpha equality check by deep comparing two nodes.
 */
#include <tvm/ir/type_functor.h>
#include <tvm/tir/ir_pass.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/pattern_functor.h>
#include <tvm/runtime/ndarray.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/attrs/nn.h>
#include "../../ir/attr_functor.h"
namespace tvm {
namespace relay {

// Alpha Equal handler for Relay.
class AlphaEqualHandler:
      public AttrsEqualHandler,
      public TypeFunctor<bool(const Type&, const Type&)>,
      public ExprFunctor<bool(const Expr&, const Expr&)>,
      public PatternFunctor<bool(const Pattern&, const Pattern&)> {
 public:
  explicit AlphaEqualHandler(bool map_free_var, bool assert_mode)
    : map_free_var_(map_free_var), assert_mode_(assert_mode) { }

  /*!
   * Check equality of two nodes.
   * \param lhs The left hand operand.
   * \param rhs The right hand operand.
   * \return The comparison result.
   */
  bool Equal(const ObjectRef& lhs, const ObjectRef& rhs) {
    if (!lhs.defined() || !rhs.defined()) return false;
    if (lhs.same_as(rhs)) return true;
    if (lhs->IsInstance<TypeNode>() || rhs->IsInstance<TypeNode>()) {
      if (!rhs->IsInstance<TypeNode>() || !lhs->IsInstance<TypeNode>()) return false;
      return TypeEqual(Downcast<Type>(lhs), Downcast<Type>(rhs));
    }
    if (lhs->IsInstance<ExprNode>() || rhs->IsInstance<ExprNode>()) {
      if (!rhs->IsInstance<ExprNode>() || !lhs->IsInstance<ExprNode>()) return false;
      return ExprEqual(Downcast<Expr>(lhs), Downcast<Expr>(rhs));
    }
    if (const auto lhsm = lhs.as<IRModuleNode>()) {
      auto rhsm = rhs.as<IRModuleNode>();
      if (!rhsm) return false;
      if (lhsm->functions.size() != rhsm->functions.size()) return false;
      for (const auto& p : lhsm->functions) {
        if (!Equal(p.second, rhsm->Lookup(p.first->name_hint))) return false;
      }
      if (lhsm->type_definitions.size() != rhsm->type_definitions.size()) return false;
      for (const auto& p : lhsm->type_definitions) {
        if (!rhsm->ContainGlobalTypeVar(p.first->name_hint) ||
            !Equal(p.second, rhsm->LookupTypeDef(p.first->name_hint))) {
          return false;
        }
      }
      return true;
    }
    return AttrEqual(lhs, rhs);
  }

  bool DoubleEqual(double l, double r) {
    return true;
  }
  /*!
   * Check equality of two attributes.
   * \param lhs The left hand operand.
   * \param rhs The right hand operand.
   * \return The comparison result.
   */
  bool AttrEqual(const ObjectRef& lhs, const ObjectRef& rhs) {
    auto compute = [&]() {
      if (&lhs == &rhs) return true;
      if (auto lhsd = lhs.as<DictAttrsNode>()) {
        auto rhsd = rhs.as<DictAttrsNode>();
        if (!rhsd) return false;
        if (lhsd->dict.size() != rhsd->dict.size()) return false;
        for (const auto& k : lhsd->dict) {
          if (!Equal(k.second, rhsd->dict[k.first])) return false;
        }
        return true;
      }
      if (auto lhsbn = lhs.as<BatchNormAttrs>()) {
        auto rhsbn = rhs.as<BatchNormAttrs>();
        if (!rhsbn) return false;
        return (lhsbn->axis == rhsbn->axis)
          && DoubleEqual(lhsbn->epsilon, rhsbn->epsilon)
          && (lhsbn->center == rhsbn->center)
          && (lhsbn->scale == rhsbn->scale);
      }
      return AttrsEqualHandler::Equal(lhs, rhs);
    };
    return Compare(compute(), lhs, rhs);
  }
  /*!
   * Check equality of two types.
   * \param lhs The left hand operand.
   * \param rhs The right hand operand.
   * \return the comparison result.
   */
  bool TypeEqual(const Type& lhs, const Type& rhs) {
    auto compute = [&]() {
      if (lhs.same_as(rhs)) return true;
      if (!lhs.defined() || !rhs.defined()) return false;
      return this->VisitType(lhs, rhs);
    };
    return Compare(compute(), lhs, rhs);
  }

  bool Compare(bool result, const ObjectRef& lhs, const ObjectRef& rhs) {
    if (assert_mode_) {
      CHECK(result) << "\n" << AsText(lhs, true) << "\nis not equal to:\n" << AsText(rhs, true);
    }
    return result;
  }
  /*!
   * Check equality of two expressions.
   *
   * \note We run graph structural equality checking when comparing two Exprs.
   *   This means that AlphaEqualHandler can only be used once for each pair.
   *   The equality checker checks data-flow equvalence of the Expr DAG.
   *   This function also runs faster as it memomizes equal_map.
   *
   * \param lhs The left hand operand.
   * \param rhs The right hand operand.
   * \return The comparison result.
   */
  bool ExprEqual(const Expr& lhs, const Expr& rhs) {
    auto compute = [&]() {
      if (lhs.same_as(rhs)) return true;
      if (!lhs.defined() || !rhs.defined()) return false;
      auto it = equal_map_.find(lhs);
      if (it != equal_map_.end()) {
        return it->second.same_as(rhs);
      }
      if (this->VisitExpr(lhs, rhs)) {
        equal_map_[lhs] = rhs;
        return true;
      } else {
        return false;
      }
    };
    return Compare(compute(), lhs, rhs);
  }

 protected:
  /*!
   * \brief Check if data type equals each other.
   * \param lhs The left hand operand.
   * \param rhs The right hand operand.
   * \return The compare result.
   */
  bool DataTypeEqual(const DataType& lhs, const DataType& rhs) {
    return lhs == rhs;
  }
  /*!
   * \brief Check Equality of leaf node of the graph.
   *  if map_free_var_ is set to true, try to map via equal node.
   * \param lhs The left hand operand.
   * \param rhs The right hand operand.
   * \return The compare result.
   */
  bool LeafObjectEqual(const ObjectRef& lhs, const ObjectRef& rhs) {
    if (lhs.same_as(rhs)) return true;
    auto it = equal_map_.find(lhs);
    if (it != equal_map_.end()) {
      return it->second.same_as(rhs);
    } else {
      if (map_free_var_) {
        if (lhs->type_index() != rhs->type_index()) return false;
        equal_map_[lhs] = rhs;
        return true;
      } else {
        return false;
      }
    }
  }
  using AttrsEqualHandler::VisitAttr_;
  bool VisitAttr_(const tvm::tir::VarNode* lhs, const ObjectRef& other) final {
    return LeafObjectEqual(GetRef<ObjectRef>(lhs), other);
  }

  // Type equality
  bool VisitType_(const TensorTypeNode* lhs, const Type& other) final {
    if (const TensorTypeNode* rhs = other.as<TensorTypeNode>()) {
      return (lhs->dtype == rhs->dtype &&
              AttrEqual(lhs->shape, rhs->shape));
    } else {
      return false;
    }
  }

  bool VisitType_(const IncompleteTypeNode* lhs, const Type& other) final {
    return LeafObjectEqual(GetRef<ObjectRef>(lhs), other);
  }

  bool VisitType_(const TypeVarNode* lhs, const Type& other) final {
    if (const TypeVarNode* rhs = other.as<TypeVarNode>()) {
      if (lhs->kind != rhs->kind) return false;
      return LeafObjectEqual(GetRef<ObjectRef>(lhs), other);
    } else {
      return false;
    }
  }

  bool VisitType_(const FuncTypeNode* lhs, const Type& other) final {
    if (const FuncTypeNode* rhs = other.as<FuncTypeNode>()) {
      if (lhs->arg_types.size() != rhs->arg_types.size()) return false;
      if (lhs->type_params.size() != rhs->type_params.size()) return false;
      if (lhs->type_constraints.size() != rhs->type_constraints.size()) return false;
      for (size_t i = 0; i < lhs->type_params.size(); ++i) {
        if (lhs->type_params[i]->kind != rhs->type_params[i]->kind) {
          return false;
        }
        equal_map_[lhs->type_params[i]] = rhs->type_params[i];
      }
      for (size_t i = 0; i < lhs->arg_types.size(); i++) {
        if (!TypeEqual(lhs->arg_types[i], rhs->arg_types[i])) return false;
      }
      if (!TypeEqual(lhs->ret_type, rhs->ret_type)) return false;
      for (size_t i = 0; i < lhs->type_constraints.size(); i++) {
        if (!TypeEqual(lhs->type_constraints[i],
                       rhs->type_constraints[i])) {
          return false;
        }
      }
      return true;
    } else {
      return false;
    }
  }

  bool VisitType_(const TypeRelationNode* lhs, const Type& other) final {
    if (const TypeRelationNode* rhs = other.as<TypeRelationNode>()) {
      if (lhs->func->name != rhs->func->name) return false;
      if (lhs->num_inputs != rhs->num_inputs) return false;
      if (!this->AttrEqual(lhs->attrs, rhs->attrs)) return false;
      if (lhs->args.size() != rhs->args.size()) return false;
      for (size_t i = 0; i < lhs->args.size(); ++i) {
        if (!TypeEqual(lhs->args[i], rhs->args[i])) return false;
      }
      return true;
    } else {
      return false;
    }
  }

  bool VisitType_(const TupleTypeNode* lhs, const Type& other) final {
    if (const TupleTypeNode* rhs = other.as<TupleTypeNode>()) {
      if (lhs->fields.size() != rhs->fields.size()) return false;
      for (size_t i = 0; i < lhs->fields.size(); ++i) {
        if (!TypeEqual(lhs->fields[i], rhs->fields[i])) return false;
      }
      return true;
    } else {
      return false;
    }
  }

  bool VisitType_(const RelayRefTypeNode* lhs, const Type& other) final {
    if (const RelayRefTypeNode* rhs = other.as<RelayRefTypeNode>()) {
      return TypeEqual(lhs->value, rhs->value);
    }
    return false;
  }

  bool VisitType_(const GlobalTypeVarNode* lhs, const Type& other) final {
    return LeafObjectEqual(GetRef<ObjectRef>(lhs), other);
  }

  bool VisitType_(const TypeCallNode* lhs, const Type& other) final {
    const TypeCallNode* rhs = other.as<TypeCallNode>();
    if (rhs == nullptr
        || lhs->args.size() != rhs->args.size()
        || !TypeEqual(lhs->func, rhs->func)) {
      return false;
    }

    for (size_t i = 0; i < lhs->args.size(); ++i) {
      if (!TypeEqual(lhs->args[i], rhs->args[i])) {
        return false;
      }
    }
    return true;
  }

  bool VisitType_(const TypeDataNode* lhs, const Type& other) final {
    const TypeDataNode* rhs = other.as<TypeDataNode>();
    if (rhs == nullptr
        || lhs->type_vars.size() != rhs->type_vars.size()
        || !TypeEqual(lhs->header, rhs->header)) {
      return false;
    }
    for (size_t i = 0; i < lhs->type_vars.size(); ++i) {
      if (!TypeEqual(lhs->type_vars[i], rhs->type_vars[i])) {
        return false;
      }
    }
    for (size_t i = 0; i < lhs->constructors.size(); ++i) {
      if (!ExprEqual(lhs->constructors[i], rhs->constructors[i])) {
        return false;
      }
    }
    return true;
  }

  // Expr equal checking.
  bool NDArrayEqual(const runtime::NDArray& lhs,
                    const runtime::NDArray& rhs) {
    if (lhs.defined() != rhs.defined()) {
      return false;
    } else if (lhs.same_as(rhs)) {
      return true;
    } else {
      auto ldt = lhs->dtype;
      auto rdt = rhs->dtype;
      CHECK_EQ(lhs->ctx.device_type, kDLCPU) << "can only compare CPU tensor";
      CHECK_EQ(rhs->ctx.device_type, kDLCPU) << "can only compare CPU tensor";
      if (ldt.code == rdt.code && ldt.lanes == rdt.lanes && ldt.bits == rdt.bits) {
        size_t data_size = runtime::GetDataSize(*lhs.operator->());
        return std::memcmp(lhs->data, rhs->data, data_size) == 0;
      } else {
        return false;
      }
    }
  }
  // merge declaration of two variables together.
  bool MergeVarDecl(const Var& lhs, const Var& rhs) {
    if (lhs.same_as(rhs)) return true;
    if (!lhs.defined() || !rhs.defined()) return false;
    if (!TypeEqual(lhs->type_annotation,
                   rhs->type_annotation)) return false;
    CHECK(!equal_map_.count(lhs))
        << "Duplicated declaration of variable " <<  lhs;
    equal_map_[lhs] = rhs;
    return true;
  }

  bool VisitExpr_(const VarNode* lhs, const Expr& other) final {
    // This function will only be triggered if we are matching free variables.
    if (const VarNode* rhs = other.as<VarNode>()) {
      if (lhs->name_hint() != rhs->name_hint()) return false;
      if (!TypeEqual(lhs->type_annotation, rhs->type_annotation)) return false;
      return LeafObjectEqual(GetRef<ObjectRef>(lhs), other);
    } else {
      return false;
    }
  }

  bool VisitExpr_(const GlobalVarNode* lhs, const Expr& other) final {
    if (const GlobalVarNode* rhs = other.as<GlobalVarNode>()) {
      // use name equality for global var for now.
      return lhs->name_hint == rhs->name_hint;
    }
    return false;
  }

  bool VisitExpr_(const TupleNode* lhs, const Expr& other) final {
    if (const TupleNode* rhs = other.as<TupleNode>()) {
      if (lhs->fields.size() != rhs->fields.size()) return false;
      for (size_t i = 0; i < lhs->fields.size(); ++i) {
        if (!ExprEqual(lhs->fields[i], rhs->fields[i])) return false;
      }
      return true;
    } else {
      return false;
    }
  }

  bool VisitExpr_(const FunctionNode* lhs, const Expr& other) final {
    if (const FunctionNode* rhs = other.as<FunctionNode>()) {
      if (lhs->params.size() != rhs->params.size()) return false;
      if (lhs->type_params.size() != rhs->type_params.size()) return false;
      // map type parameter to be the same
      for (size_t i = 0; i < lhs->type_params.size(); ++i) {
        if (lhs->type_params[i]->kind != rhs->type_params[i]->kind) return false;
        equal_map_[lhs->type_params[i]] = rhs->type_params[i];
      }
      // check parameter type annotations
      for (size_t i = 0; i < lhs->params.size(); ++i) {
        if (!MergeVarDecl(lhs->params[i], rhs->params[i])) return false;
      }
      // check return types.
      if (!TypeEqual(lhs->ret_type, rhs->ret_type)) return false;
      if (!AttrEqual(lhs->attrs, rhs->attrs)) return false;
      return ExprEqual(lhs->body, rhs->body);
    } else {
      return false;
    }
  }

  bool VisitExpr_(const CallNode* lhs, const Expr& other) final {
    if (const CallNode* rhs = other.as<CallNode>()) {
      if (!ExprEqual(lhs->op, rhs->op)) return false;
      if (lhs->args.size() != rhs->args.size()) return false;
      // skip type_args check for primitive ops.
      bool is_primitive = IsPrimitiveOp(lhs->op);
      if (!is_primitive) {
        if (lhs->type_args.size() != rhs->type_args.size()) {
          return false;
        }
      }
      for (size_t i = 0; i < lhs->args.size(); ++i) {
        if (!ExprEqual(lhs->args[i], rhs->args[i])) {
          return false;
        }
      }

      if (!is_primitive) {
        for (size_t i = 0; i < lhs->type_args.size(); ++i) {
          if (!TypeEqual(lhs->type_args[i], rhs->type_args[i])) return false;
        }
      }
      return AttrEqual(lhs->attrs, rhs->attrs);
    } else {
      return false;
    }
  }

  bool VisitExpr_(const LetNode* lhs, const Expr& other) final {
    if (const LetNode* rhs = other.as<LetNode>()) {
      if (!MergeVarDecl(lhs->var, rhs->var)) return false;
      if (!ExprEqual(lhs->value, rhs->value)) return false;
      return ExprEqual(lhs->body, rhs->body);
    } else {
      return false;
    }
  }

  bool VisitExpr_(const IfNode* lhs, const Expr& other) final {
    if (const IfNode* rhs = other.as<IfNode>()) {
      return ExprEqual(lhs->cond, rhs->cond) &&
          ExprEqual(lhs->true_branch, rhs->true_branch) &&
          ExprEqual(lhs->false_branch, rhs->false_branch);
    } else {
      return false;
    }
  }

  bool VisitExpr_(const OpNode* lhs, const Expr& other) final {
    return lhs == other.get();
  }

  bool VisitExpr_(const ConstantNode* lhs, const Expr& other) final {
    if (const ConstantNode* rhs = other.as<ConstantNode>()) {
      return NDArrayEqual(lhs->data, rhs->data);
    } else {
      return false;
    }
  }

  bool VisitExpr_(const TupleGetItemNode* lhs, const Expr& other) final {
    if (const TupleGetItemNode* rhs = other.as<TupleGetItemNode>()) {
      return ExprEqual(lhs->tuple, rhs->tuple) && lhs->index == rhs->index;
    } else {
      return false;
    }
  }

  bool VisitExpr_(const RefCreateNode* lhs, const Expr& other) final {
    if (const RefCreateNode* rhs = other.as<RefCreateNode>()) {
      return ExprEqual(lhs->value, rhs->value);
    } else {
      return false;
    }
  }

  bool VisitExpr_(const RefReadNode* lhs, const Expr& other) final {
    if (const RefReadNode* rhs = other.as<RefReadNode>()) {
      return ExprEqual(lhs->ref, rhs->ref);
    } else {
      return false;
    }
  }

  bool VisitExpr_(const RefWriteNode* lhs, const Expr& other) final {
    if (const RefWriteNode* rhs = other.as<RefWriteNode>()) {
      return ExprEqual(lhs->ref, rhs->ref) && ExprEqual(lhs->value, rhs->value);
    } else {
      return false;
    }
  }

  bool VisitExpr_(const ConstructorNode* lhs, const Expr& other) final {
    if (const ConstructorNode* rhs = other.as<ConstructorNode>()) {
      return lhs->name_hint == rhs->name_hint;
    }
    return false;
  }

  bool ClauseEqual(const Clause& lhs, const Clause& rhs) {
    return PatternEqual(lhs->lhs, rhs->lhs) && ExprEqual(lhs->rhs, rhs->rhs);
  }

  bool PatternEqual(const Pattern& lhs, const Pattern& rhs) {
    return Compare(VisitPattern(lhs, rhs), lhs, rhs);
  }

  bool VisitPattern_(const PatternWildcardNode* lhs, const Pattern& other) final {
    return other.as<PatternWildcardNode>();
  }

  bool VisitPattern_(const PatternVarNode* lhs, const Pattern& other) final {
    if (const auto* rhs = other.as<PatternVarNode>()) {
      return MergeVarDecl(lhs->var, rhs->var);
    }
    return false;
  }

  bool VisitPattern_(const PatternConstructorNode* lhs, const Pattern& other) final {
    const auto* rhs = other.as<PatternConstructorNode>();
    if (rhs == nullptr
        || !ExprEqual(lhs->constructor, rhs->constructor)
        || lhs->patterns.size() != rhs->patterns.size()) {
      return false;
    }

    for (size_t i = 0; i < lhs->patterns.size(); i++) {
      if (!PatternEqual(lhs->patterns[i], rhs->patterns[i])) {
        return false;
      }
    }
    return true;
  }

  bool VisitPattern_(const PatternTupleNode* lhs, const Pattern& other) final {
    const auto* rhs = other.as<PatternTupleNode>();
    if (rhs == nullptr
        || lhs->patterns.size() != rhs->patterns.size()) {
      return false;
    }

    for (size_t i = 0; i < lhs->patterns.size(); i++) {
      if (!PatternEqual(lhs->patterns[i], rhs->patterns[i])) {
        return false;
      }
    }
    return true;
  }

  bool VisitExpr_(const MatchNode* lhs, const Expr& other) final {
    const MatchNode* rhs = other.as<MatchNode>();

    if (rhs == nullptr
        || !ExprEqual(lhs->data, rhs->data)
        || lhs->clauses.size() != rhs->clauses.size()
        || lhs->complete != rhs->complete) {
      return false;
    }

    for (size_t i = 0; i < lhs->clauses.size(); ++i) {
      if (!ClauseEqual(lhs->clauses[i], rhs->clauses[i])) {
        return false;
      }
    }
    return true;
  }

 private:
  // whether to map open terms.
  bool map_free_var_;
  // if in assert mode, must return true, and will throw error otherwise.
  bool assert_mode_;
  // renaming of NodeRef to indicate two nodes equals to each other
  std::unordered_map<ObjectRef, ObjectRef, ObjectHash, ObjectEqual> equal_map_;
};

bool AlphaEqual(const Type& lhs, const Type& rhs) {
  return AlphaEqualHandler(false, false).TypeEqual(lhs, rhs);
}

bool AlphaEqual(const Expr& lhs, const Expr& rhs) {
  return AlphaEqualHandler(false, false).ExprEqual(lhs, rhs);
}

// TODO(@jroesch): move to correct namespace?
TVM_REGISTER_GLOBAL("relay._make._alpha_equal")
.set_body_typed([](ObjectRef a, ObjectRef b) {
  return AlphaEqualHandler(false, false).Equal(a, b);
});

TVM_REGISTER_GLOBAL("ir.type_alpha_equal")
.set_body_typed([](Type a, Type b) {
  return AlphaEqual(a, b);
});

TVM_REGISTER_GLOBAL("relay._make._assert_alpha_equal")
.set_body_typed([](ObjectRef a, ObjectRef b) {
  bool alpha_equal = AlphaEqualHandler(false, true).Equal(a, b);
  CHECK(alpha_equal) << AsText(a, true) << " and " << AsText(b, true) << " are not alpha equal";
});

TVM_REGISTER_GLOBAL("relay._make._graph_equal")
.set_body_typed([](ObjectRef a, ObjectRef b) {
  return AlphaEqualHandler(true, false).Equal(a, b);
});

TVM_REGISTER_GLOBAL("relay._make._assert_graph_equal")
.set_body_typed([](ObjectRef a, ObjectRef b) {
  bool graph_equal = AlphaEqualHandler(true, true).Equal(a, b);
  CHECK(graph_equal) << AsText(a, true) << " and " << AsText(b, true) << " are not graph equal";
});

}  // namespace relay
}  // namespace tvm