/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file src/tvm/relay/ir/hash.cc
 * \brief Hash functions for Relay types and expressions.
 */
#include <tvm/ir/type_functor.h>
#include <tvm/tir/ir_pass.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/pattern_functor.h>
#include <tvm/runtime/ndarray.h>
#include <tvm/relay/analysis.h>
#include <tvm/ir/attrs.h>
#include "../../ir/attr_functor.h"

namespace tvm {
namespace relay {

/*!
 * \brief Structural hash handler for Relay types, expressions, patterns and attributes.
 *
 * Hashes of visited types and expressions are memoized in hash_map_.
 * Variables introduced at a binding site (function parameters, let-bound
 * variables, pattern variables, type parameters) are hashed from a running
 * counter in BindVar rather than from their names.
 *
 * The callers in this file construct a fresh handler for every top-level hash.
 */
class RelayHashHandler:
    public AttrsHashHandler,
    public TypeFunctor<size_t(const Type&)>,
    public ExprFunctor<size_t(const Expr&)>,
    public PatternFunctor<size_t(const Pattern&)> {
 public:
  RelayHashHandler() = default;

  /*!
   * \brief Compute hash of a node.
   *
   * Dispatches to TypeHash for TypeNode instances, to ExprHash for ExprNode
   * instances and to AttrHash for everything else.
   *
   * \param ref The node to hash.
   * \return the hash value.
   */
  size_t Hash(const ObjectRef& ref) {
    if (!ref.defined()) return ObjectHash()(ref);

    if (ref->IsInstance<TypeNode>()) {
      return TypeHash(Downcast<Type>(ref));
    }
    if (ref->IsInstance<ExprNode>()) {
      return ExprHash(Downcast<Expr>(ref));
    }
    return AttrHash(ref);
  }

  /*!
   * \brief Compute hash of the attributes.
   * \param ref The attributes.
   * \return the hash value; an undefined reference is hashed with ObjectHash.
   */
  size_t AttrHash(const ObjectRef& ref) {
    if (!ref.defined()) {
      return ObjectHash()(ref);
    }
    return AttrsHashHandler::Hash(ref);
  }

  /*!
   * \brief Compute hash of a Relay type.
   *
   * The result is memoized in hash_map_, so a type that was already visited
   * (or already bound through BindVar) returns its recorded hash.
   *
   * \param type The type to hash.
   * \return the hash value; an undefined type is hashed with ObjectHash.
   */
  size_t TypeHash(const Type& type) {
    if (!type.defined()) {
      return ObjectHash()(type);
    }
    const auto found = hash_map_.find(type);
    if (found != hash_map_.end()) {
      return found->second;
    }
    const size_t hash = this->VisitType(type);
    // insert() keeps an entry that the visit itself may have recorded via BindVar.
    hash_map_.insert({type, hash});
    return hash;
  }

  /*!
   * \brief Compute the hash of an expression.
   *
   * \note The result is memoized in hash_map_, so a sub-expression that is
   *   referenced from several places is visited only once and contributes
   *   the same hash at every use.
   *
   * \param expr The expression to hash.
   * \return the hash value; an undefined expression is hashed with ObjectHash.
   */
  size_t ExprHash(const Expr& expr) {
    if (!expr.defined()) {
      return ObjectHash()(expr);
    }
    const auto found = hash_map_.find(expr);
    if (found != hash_map_.end()) {
      return found->second;
    }
    const size_t hash = this->VisitExpr(expr);
    // insert() keeps an entry that the visit itself may have recorded via BindVar.
    hash_map_.insert({expr, hash});
    return hash;
  }

 protected:
  /*!
   * \brief Hash a DataType.
   * \param dtype The dtype to hash.
   * \return the hash value.
   */
  size_t DataTypeHash(const DataType& dtype) {
    return ::tvm::AttrsHash()(dtype);
  }

  using AttrsHashHandler::VisitAttr_;
  /*!
   * \brief Hash a tvm::tir::Var reached through attribute hashing.
   *
   * A variable already recorded in hash_map_ reuses its stored hash;
   * otherwise the hash combines a type key with the variable's name hint.
   *
   * \param var The variable node.
   * \return the hash value.
   */
  size_t VisitAttr_(const tvm::tir::VarNode* var) final {
    // Consult the memo first so the string hashes are only computed when needed.
    const auto it = hash_map_.find(GetRef<tvm::tir::Var>(var));
    if (it != hash_map_.end()) {
      return it->second;
    }
    // NOTE(review): the type key comes from the unqualified VarNode, which in
    // this scope appears to be the Relay VarNode rather than tvm::tir::VarNode.
    // Confirm whether that is intended; kept so hash values do not change.
    const size_t hash = std::hash<std::string>()(VarNode::_type_key);
    return Combine(hash, std::hash<std::string>()(var->name_hint));
  }

  // Type hashing

  /*! \brief Hash a tensor type from its type key, dtype and shape. */
  size_t VisitType_(const TensorTypeNode* tensor_type) final {
    size_t hash = std::hash<std::string>()(TensorTypeNode::_type_key);
    hash = Combine(hash, DataTypeHash(tensor_type->dtype));
    hash = Combine(hash, Hash(tensor_type->shape));
    return hash;
  }

  /*! \brief Hash an incomplete type from its type key and kind. */
  size_t VisitType_(const IncompleteTypeNode* incomplete) final {
    const size_t hash = std::hash<std::string>()(IncompleteTypeNode::_type_key);
    return Combine(hash, std::hash<int>()(incomplete->kind));
  }

  size_t VisitType_(const TypeVarNode* tyvar) final {
    /*
      TypeVar/Var/Variable have two locations where they are hashed:

        The declaration site of a function, let, or function type.
        The first occurrence in the term.

      We will only reach this code if the TypeVar itself is unbound, we assign
      a free variable index to it, meaning this hashing function implements
      structural equality for both open (i.e graph equality) and closed terms
      (i.e alpha_equality).
    */
    return BindVar(GetRef<TypeVar>(tyvar));
  }

  /*!
   * \brief Hash a function type.
   *
   * Type parameters are bound first, then argument types, the return type
   * and the type constraints are combined in that order.
   */
  size_t VisitType_(const FuncTypeNode* func_type) final {
    size_t hash = std::hash<std::string>()(FuncTypeNode::_type_key);

    for (const auto& type_param : func_type->type_params) {
      hash = Combine(hash, BindVar(type_param));
    }

    for (const auto& arg : func_type->arg_types) {
      hash = Combine(hash, TypeHash(arg));
    }

    hash = Combine(hash, TypeHash(func_type->ret_type));
    for (const auto& cs : func_type->type_constraints) {
      hash = Combine(hash, TypeHash(cs));
    }

    return hash;
  }

  /*! \brief Hash a type relation from its function name, attributes and arguments. */
  size_t VisitType_(const TypeRelationNode* type_rel) final {
    size_t hash = std::hash<std::string>()(TypeRelationNode::_type_key);
    hash = Combine(hash, std::hash<std::string>()(type_rel->func->name));
    hash = Combine(hash, AttrHash(type_rel->attrs));

    for (const auto& arg : type_rel->args) {
      hash = Combine(hash, TypeHash(arg));
    }

    return hash;
  }

  /*! \brief Hash a tuple type from its field types, in order. */
  size_t VisitType_(const TupleTypeNode* tuple_type) final {
    size_t hash = std::hash<std::string>()(TupleTypeNode::_type_key);
    for (const auto& field : tuple_type->fields) {
      hash = Combine(hash, TypeHash(field));
    }
    return hash;
  }

  /*! \brief Hash a reference type from the type of the value it holds. */
  size_t VisitType_(const RelayRefTypeNode* rtn) final {
    size_t hash = std::hash<std::string>()(RelayRefTypeNode::_type_key);
    hash = Combine(hash, TypeHash(rtn->value));
    return hash;
  }

  // Expr hashing.

  /*!
   * \brief Hash the dtype and the raw bytes of an NDArray.
   *
   * Fails a CHECK when the array's device type is not kDLCPU.
   *
   * \param array The array to hash.
   * \return the hash value.
   */
  size_t NDArrayHash(const runtime::NDArray& array) {
    size_t hash = std::hash<uint8_t>()(array->dtype.code);
    hash = Combine(hash, std::hash<uint8_t>()(array->dtype.bits));
    hash = Combine(hash, std::hash<uint16_t>()(array->dtype.lanes));
    CHECK_EQ(array->ctx.device_type, kDLCPU) << "can only compare CPU tensor";
    const size_t data_size = runtime::GetDataSize(*array.operator->());
    // The buffer is only read, so view it through a pointer-to-const.
    const uint8_t* data = static_cast<const uint8_t*>(array->data);
    for (size_t i = 0; i < data_size; i++) {
      hash = Combine(hash, std::hash<uint8_t>()(data[i]));
    }
    return hash;
  }

  /*!
   * \brief Record a variable at its binding site and return its hash.
   *
   * The hash is derived from the running variable counter; for a Relay Var
   * the hash of its type annotation is combined in as well. The variable
   * must not already be present in hash_map_ (enforced by a CHECK).
   *
   * \param var The variable being bound.
   * \return the hash assigned to the variable.
   */
  size_t BindVar(const ObjectRef& var) {
    size_t hash = std::hash<int>()(var_counter_++);
    CHECK_EQ(hash_map_.count(var), 0);
    if (auto var_node = var.as<VarNode>()) {
      hash = Combine(hash, TypeHash(var_node->type_annotation));
    }
    hash_map_[var] = hash;
    return hash;
  }

  /*!
   * \brief Hash a variable that was not bound through BindVar (a free variable).
   *
   * Combines the hash of the vid object pointer with the hash of the type annotation.
   */
  size_t VisitExpr_(const VarNode* var) final {
    // hash free variable
    const size_t name_hash = std::hash<const Object*>()(var->vid.get());
    return Combine(name_hash, TypeHash(var->type_annotation));
  }

  /*! \brief Hash a global variable from its name hint only. */
  size_t VisitExpr_(const GlobalVarNode* global) final {
    return std::hash<std::string>()(global->name_hint);
  }

  /*! \brief Hash a tuple from its fields, in order. */
  size_t VisitExpr_(const TupleNode* tuple) final {
    size_t hash = std::hash<std::string>()(TupleNode::_type_key);
    for (const auto& field : tuple->fields) {
      hash = Combine(hash, ExprHash(field));
    }
    return hash;
  }

  /*!
   * \brief Hash a function.
   *
   * Type parameters and parameters are bound first, then the return type,
   * the body and the attributes are combined in that order.
   */
  size_t VisitExpr_(const FunctionNode* func) final {
    size_t hash = std::hash<std::string>()(FunctionNode::_type_key);
    for (const auto& type_param : func->type_params) {
      hash = Combine(hash, BindVar(type_param));
    }

    for (const auto& param : func->params) {
      hash = Combine(hash, BindVar(param));
    }

    hash = Combine(hash, TypeHash(func->ret_type));
    hash = Combine(hash, ExprHash(func->body));

    hash = Combine(hash, AttrHash(func->attrs));

    return hash;
  }

  /*!
   * \brief Hash a call from its operator, arguments, type arguments and attributes.
   *
   * Every type argument must be defined (enforced by a CHECK).
   */
  size_t VisitExpr_(const CallNode* call) final {
    size_t hash = std::hash<std::string>()(CallNode::_type_key);
    hash = Combine(hash, ExprHash(call->op));

    for (const auto& arg : call->args) {
      hash = Combine(hash, ExprHash(arg));
    }

    for (const auto& t : call->type_args) {
      CHECK(t.defined());
      hash = Combine(hash, TypeHash(t));
    }

    hash = Combine(hash, AttrHash(call->attrs));

    return hash;
  }

  /*! \brief Hash a let binding: the bound variable, then the value, then the body. */
  size_t VisitExpr_(const LetNode* let) final {
    size_t hash = std::hash<std::string>()(LetNode::_type_key);
    hash = Combine(hash, BindVar(let->var));
    hash = Combine(hash, ExprHash(let->value));
    hash = Combine(hash, ExprHash(let->body));
    return hash;
  }

  /*! \brief Hash a conditional from its condition, true branch and false branch. */
  size_t VisitExpr_(const IfNode* ite) final {
    size_t hash = std::hash<std::string>()(IfNode::_type_key);
    hash = Combine(hash, ExprHash(ite->cond));
    hash = Combine(hash, ExprHash(ite->true_branch));
    hash = Combine(hash, ExprHash(ite->false_branch));
    return hash;
  }

  /*! \brief Hash an operator by applying ObjectHash to its reference. */
  size_t VisitExpr_(const OpNode* op) final {
    return ObjectHash()(GetRef<Op>(op));
  }

  /*! \brief Hash a constant from the contents of its NDArray. */
  size_t VisitExpr_(const ConstantNode* rconst) final {
    return NDArrayHash(rconst->data);
  }

  /*! \brief Hash a tuple projection from the tuple expression and the index. */
  size_t VisitExpr_(const TupleGetItemNode* get_item) final {
    size_t hash = std::hash<std::string>()(TupleGetItemNode::_type_key);
    hash = Combine(hash, ExprHash(get_item->tuple));
    hash = Combine(hash, std::hash<int>()(get_item->index));
    return hash;
  }

  /*! \brief Hash a reference creation from its initial value. */
  size_t VisitExpr_(const RefCreateNode* rn) final {
    size_t hash = std::hash<std::string>()(RefCreateNode::_type_key);
    hash = Combine(hash, ExprHash(rn->value));
    return hash;
  }

  /*! \brief Hash a reference read from the reference expression. */
  size_t VisitExpr_(const RefReadNode* rn) final {
    size_t hash = std::hash<std::string>()(RefReadNode::_type_key);
    hash = Combine(hash, ExprHash(rn->ref));
    return hash;
  }

  /*! \brief Hash a reference write from the reference and the written value. */
  size_t VisitExpr_(const RefWriteNode* rn) final {
    size_t hash = std::hash<std::string>()(RefWriteNode::_type_key);
    hash = Combine(hash, ExprHash(rn->ref));
    hash = Combine(hash, ExprHash(rn->value));
    return hash;
  }

  /*!
   * \brief Hash a match expression.
   *
   * Combines the scrutinee, then for each clause its pattern followed by
   * its right-hand side, and finally the completeness flag.
   */
  size_t VisitExpr_(const MatchNode* mn) final {
    size_t hash = std::hash<std::string>()(MatchNode::_type_key);
    hash = Combine(hash, ExprHash(mn->data));
    for (const auto& c : mn->clauses) {
      hash = Combine(hash, PatternHash(c->lhs));
      hash = Combine(hash, ExprHash(c->rhs));
    }
    hash = Combine(hash, std::hash<bool>()(mn->complete));
    return hash;
  }

  /*! \brief Hash a constructor from its type key and name hint. */
  size_t VisitExpr_(const ConstructorNode* cn) final {
    size_t hash = std::hash<std::string>()(ConstructorNode::_type_key);
    hash = Combine(hash, std::hash<std::string>()(cn->name_hint));
    return hash;
  }

  /*! \brief Hash a type call from the called type and its arguments. */
  size_t VisitType_(const TypeCallNode* tcn) final {
    size_t hash = std::hash<std::string>()(TypeCallNode::_type_key);
    hash = Combine(hash, TypeHash(tcn->func));
    for (const auto& t : tcn->args) {
      hash = Combine(hash, TypeHash(t));
    }
    return hash;
  }

  /*! \brief Hash a type definition from its header, type variables and constructors. */
  size_t VisitType_(const TypeDataNode* tdn) final {
    size_t hash = std::hash<std::string>()(TypeDataNode::_type_key);
    hash = Combine(hash, TypeHash(tdn->header));
    for (const auto& tv : tdn->type_vars) {
      hash = Combine(hash, TypeHash(tv));
    }
    for (const auto& cn : tdn->constructors) {
      hash = Combine(hash, ExprHash(cn));
    }
    return hash;
  }

  /*! \brief Hash a global type variable by binding it like any other variable. */
  size_t VisitType_(const GlobalTypeVarNode* tvn) final {
    return BindVar(GetRef<GlobalTypeVar>(tvn));
  }

  /*!
   * \brief Compute the hash of a pattern (not memoized).
   * \param p The pattern to hash.
   * \return the hash value.
   */
  size_t PatternHash(const Pattern& p) {
    return VisitPattern(p);
  }

  /*! \brief Hash a constructor pattern from its constructor and sub-patterns. */
  size_t VisitPattern_(const PatternConstructorNode* pcn) final {
    size_t hash = std::hash<std::string>()(PatternConstructorNode::_type_key);
    hash = Combine(hash, ExprHash(pcn->constructor));
    for (const auto& p : pcn->patterns) {
      hash = Combine(hash, PatternHash(p));
    }
    return hash;
  }

  /*! \brief Hash a tuple pattern from its sub-patterns, in order. */
  size_t VisitPattern_(const PatternTupleNode* ptn) final {
    size_t hash = std::hash<std::string>()(PatternTupleNode::_type_key);
    for (const auto& p : ptn->patterns) {
      hash = Combine(hash, PatternHash(p));
    }
    return hash;
  }

  /*! \brief Hash a variable pattern; the variable is bound as a side effect. */
  size_t VisitPattern_(const PatternVarNode* pvn) final {
    size_t hash = std::hash<std::string>()(PatternVarNode::_type_key);
    hash = Combine(hash, BindVar(pvn->var));
    return hash;
  }

  /*! \brief Hash a wildcard pattern; only the type key contributes. */
  size_t VisitPattern_(const PatternWildcardNode* pwn) final {
    return std::hash<std::string>()(PatternWildcardNode::_type_key);
  }

 private:
  /*! \brief Memoized hash of every visited type/expression and of every bound variable. */
  std::unordered_map<ObjectRef, size_t, ObjectHash, ObjectEqual> hash_map_;
  /*! \brief Running counter that BindVar turns into the hash of a bound variable. */
  int var_counter_ = 0;
};

/*!
 * \brief Structurally hash a Relay type.
 * \param type The type to hash.
 * \return The hash computed by a freshly constructed RelayHashHandler.
 */
size_t StructuralHash::operator()(const Type& type) const {
  RelayHashHandler handler;
  return handler.TypeHash(type);
}

/*!
 * \brief Structurally hash a Relay expression.
 * \param expr The expression to hash.
 * \return The hash computed by a freshly constructed RelayHashHandler.
 */
size_t StructuralHash::operator()(const Expr& expr) const {
  RelayHashHandler handler;
  return handler.ExprHash(expr);
}

// Global function "relay._analysis._expr_hash": hashes an arbitrary ObjectRef
// with a fresh handler and returns the result converted to int64_t.
TVM_REGISTER_GLOBAL("relay._analysis._expr_hash")
.set_body_typed([](ObjectRef ref) {
  RelayHashHandler handler;
  const size_t hash_value = handler.Hash(ref);
  return static_cast<int64_t>(hash_value);
});

// Global function "relay._analysis._type_hash": hashes a Relay type with a
// fresh handler and returns the result converted to int64_t.
TVM_REGISTER_GLOBAL("relay._analysis._type_hash")
.set_body_typed([](Type type) {
  RelayHashHandler handler;
  const size_t hash_value = handler.TypeHash(type);
  return static_cast<int64_t>(hash_value);
});

}  // namespace relay
}  // namespace tvm