/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
/*!
 * \file tvm/node/structural_hash.h
 * \brief Structural hash class.
 */
#ifndef TVM_NODE_STRUCTURAL_HASH_H_
#define TVM_NODE_STRUCTURAL_HASH_H_

#include <tvm/runtime/data_type.h>
#include <tvm/node/functor.h>
#include <tvm/node/container.h>
#include <string>
#include <functional>

namespace tvm {

/*!
 * \brief Hash definition of base value classes.
 *
 *  Provides hash functions for primitive values: floating point and
 *  integer numbers, booleans, strings, runtime::DataType and enums.
 */
class BaseValueHash {
 public:
  size_t operator()(const double& key) const {
    return std::hash<double>()(key);
  }

  size_t operator()(const int64_t& key) const {
    return std::hash<int64_t>()(key);
  }

  size_t operator()(const uint64_t& key) const {
    return std::hash<uint64_t>()(key);
  }

  size_t operator()(const int& key) const {
    return std::hash<int>()(key);
  }

  size_t operator()(const bool& key) const {
    return std::hash<bool>()(key);
  }

  size_t operator()(const std::string& key) const {
    return std::hash<std::string>()(key);
  }

  /*!
   * \brief Hash a data type by packing its fields into one integer.
   *
   *  code occupies the low 8 bits, bits the next 8 bits and lanes the bits
   *  from 16 upward; fields overlap if code or bits do not fit in 8 bits.
   */
  size_t operator()(const runtime::DataType& key) const {
    return std::hash<int32_t>()(
        static_cast<int32_t>(key.code()) |
        (static_cast<int32_t>(key.bits()) << 8) |
        (static_cast<int32_t>(key.lanes()) << 16));
  }

  /*!
   * \brief Hash an enum value through its integral value.
   * \param key The enum value.
   * \return The hash value.
   * \note The return type must be size_t: returning bool would narrow the
   *       hash to 0/1 and make all non-zero enumerators collide.
   */
  template<typename ENum,
           typename = typename std::enable_if<std::is_enum<ENum>::value>::type>
  size_t operator()(const ENum& key) const {
    return std::hash<size_t>()(static_cast<size_t>(key));
  }
};

/*!
 * \brief Content-aware structural hashing.
 *
 *  The structural hash value is recursively defined in the DAG of IRNodes.
 *  There are two kinds of nodes:
 *
 *  - Normal node: the hash value is defined by its content and type only.
 *  - Graph node: each graph node will be assigned a unique index ordered by the
 *    first occurrence during the visit. The hash value of a graph node is
 *    combined from the hash values of its contents and the index.
 */
class StructuralHash : public BaseValueHash {
 public:
  // Inherit the primitive-value overloads of operator() from BaseValueHash.
  using BaseValueHash::operator();
  /*!
   * \brief Compute structural hashing value for an object.
   * \param key The object to be hashed.
   * \return The hash value.
   */
  TVM_DLL size_t operator()(const ObjectRef& key) const;
};

/*!
 * \brief A Reducer class to reduce the structural hash value.
 *
 *  The reducer is used while each object is hashed recursively: an object's
 *  hash-reduction logic makes a sequence of calls to the reducer to indicate
 *  the sequence of child hash values that the reducer needs to combine in
 *  order to obtain the hash value of the parent object.
 *
 *  Importantly, the reducer need not use recursive calls to compute the
 *  hash values of child objects directly.
 *
 *  Instead, it can store the necessary hash computing tasks in a stack
 *  and reduce the result later.
 */
class SHashReducer {
 public:
  /*! \brief Internal handler that defines custom behaviors. */
  class Handler {
   public:
    /*!
     * \brief Virtual destructor, so that a concrete handler can be
     *        destroyed safely through a pointer to this interface.
     */
    virtual ~Handler() = default;
    /*!
     * \brief Append hashed_value to the current sequence of hashes.
     *
     * \param hashed_value The hashed value
     */
    virtual void SHashReduceHashedValue(size_t hashed_value) = 0;
    /*!
     * \brief Append hash value of key to the current sequence of hashes.
     *
     * \param key The object to compute hash from.
     * \param map_free_vars Whether to map free variables by their occurrence number.
     */
    virtual void SHashReduce(const ObjectRef& key, bool map_free_vars) = 0;
    /*!
     * \brief Append a hash value of free variable to the current sequence of hashes.
     *
     * \param var The var of interest.
     * \param map_free_vars Whether to map free variables by their occurrence number.
     *
     * \note If map_free_vars is set to be true,
     *       internally the handler can maintain a counter to encode free variables
     *       by their order of occurrence. This helps to resolve variable
     *       mapping of function parameters and let binding variables.
     *
     *       If map_free_vars is set to be false, the address of the variable will be used.
     */
    virtual void SHashReduceFreeVar(const runtime::Object* var, bool map_free_vars) = 0;
    /*!
     * \brief Lookup a hash value for key
     *
     * \param key The hash key.
     * \param hashed_value the result hash value
     *
     * \return Whether there is already a pre-computed hash value.
     */
    virtual bool LookupHashedValue(const ObjectRef& key, size_t* hashed_value) = 0;
    /*!
     * \brief Mark current comparison as graph node in hashing.
     *        Graph node hash will depend on the graph structure.
     */
    virtual void MarkGraphNode() = 0;
  };

  /*!
   * \brief Default constructor.
   * \note The resulting reducer has no handler (it is nullptr) and
   *       map_free_vars is false; every push operation dereferences the
   *       handler, so such a reducer must not be used to push hash values.
   */
  SHashReducer() = default;
  /*!
   * \brief Constructor with a specific handler.
   * \param handler The hash handler for objects (not owned by the reducer).
   * \param map_free_vars Whether to map free variables.
   */
  explicit SHashReducer(Handler* handler, bool map_free_vars)
      : handler_(handler), map_free_vars_(map_free_vars) {}
  /*!
   * \brief Push hash of a non-ObjectRef value to the current sequence of hash values.
   * \param key The key to be hashed; its hash is computed with BaseValueHash.
   */
  template<typename T,
           typename = typename std::enable_if<
             !std::is_base_of<ObjectRef, T>::value>::type>
  void operator()(const T& key) const {
    // handle normal values.
    handler_->SHashReduceHashedValue(BaseValueHash()(key));
  }
  /*!
   * \brief Push hash of key to the current sequence of hash values.
   * \param key The key to be hashed.
   * \note Free variables are mapped according to the reducer's map_free_vars flag.
   */
  void operator()(const ObjectRef& key) const {
    handler_->SHashReduce(key, map_free_vars_);
  }
  /*!
   * \brief Push hash of key to the current sequence of hash values.
   * \param key The key to be hashed.
   * \note This function indicates key could contain var definitions;
   *       free variables are always mapped by occurrence (map_free_vars = true).
   */
  void DefHash(const ObjectRef& key) const {
    handler_->SHashReduce(key, true);
  }
  /*!
   * \brief Implementation for hash for a free var.
   * \param var The variable.
   */
  void FreeVarHashImpl(const runtime::Object* var) const {
    handler_->SHashReduceFreeVar(var, map_free_vars_);
  }

  /*! \return Get the internal handler. */
  Handler* operator->() const {
    return handler_;
  }

 private:
  /*!
   * \brief Internal handler pointer. The reducer never deletes it, so the
   *        handler must outlive the reducer. nullptr when default-constructed.
   */
  Handler* handler_ = nullptr;
  /*!
   * \brief Whether or not to map free variables by their occurrence.
   *        If the flag is false, then free variables will be mapped
   *        by their in-memory address.
   */
  bool map_free_vars_ = false;
};

}  // namespace tvm
#endif  // TVM_NODE_STRUCTURAL_HASH_H_