/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
/*!
 * \file tvm/node/ir_functor.h
 * \brief Defines the IRFunctor data structures.
 */
#ifndef TVM_NODE_IR_FUNCTOR_H_
#define TVM_NODE_IR_FUNCTOR_H_

#include <dmlc/logging.h>
#include <string>
#include <vector>
#include <memory>
#include <type_traits>
#include <utility>
#include <functional>
#include "node.h"

namespace tvm {
/*!
 * \brief A dynamically dispatched functor on NodeRef in the first argument.
 *
 * \code
 *   IRFunctor<std::string (const NodeRef& n, std::string prefix)> tostr;
 *   tostr.set_dispatch<Add>([](const Add* op, std::string prefix) {
 *     return prefix + "Add";
 *   });
 *   tostr.set_dispatch<IntImm>([](const IntImm* op) {
 *     return prefix + "IntImm"
 *   });
 *
 *   Expr x = make_const(1);
 *   Expr y = x + x;
 *   // dispatch to IntImm, outputs "MyIntImm"
 *   LOG(INFO) << tostr(x, "My");
 *   // dispatch to IntImm, outputs "MyAdd"
 *   LOG(INFO) << tostr(y, "My");
 * \endcode
 *
 * \tparam FType function signiture
 *  This type if only defined for FType with function signature
 */
template<typename FType>
class IRFunctor;

template<typename R, typename ...Args>
class IRFunctor<R(const NodeRef& n, Args...)> {
 private:
  using Function = std::function<R (const NodeRef&n, Args...)>;
  using TSelf = IRFunctor<R (const NodeRef& n, Args...)>;
  /*! \brief internal function table */
  std::vector<Function> func_;

 public:
  /*! \brief the result type of this functor */
  using result_type = R;
  /*!
   * \brief Whether the functor can dispatch the corresponding Node
   * \param n The node to be dispatched
   * \return Whether dispatching function is registered for n's type.
   */
  inline bool can_dispatch(const NodeRef& n) const {
    uint32_t type_index = n.type_index();
    return type_index < func_.size() && func_[type_index] != nullptr;
  }
  /*!
   * \brief invoke the functor , dispatch on type of n
   * \param n The Node argument
   * \param args The additional arguments
   * \return The result.
   */
  inline R operator()(const NodeRef& n, Args... args) const {
    uint32_t type_index = n.type_index();
    CHECK(type_index < func_.size() &&
          func_[type_index] != nullptr)
        << "IRFunctor calls un-registered function on type "
        << Node::TypeIndex2Key(type_index);
    return func_[type_index](n, std::forward<Args>(args)...);
  }
  /*!
   * \brief set the dispacher for type TNode
   * \param f The function to be set.
   * \tparam TNode the type of Node to be dispatched.
   * \return reference to self.
   */
  template<typename TNode>
  inline TSelf& set_dispatch(Function f) {  // NOLINT(*)
    uint32_t tindex = Node::TypeKey2Index(TNode::_type_key);
    if (func_.size() <= tindex) {
      func_.resize(tindex + 1, nullptr);
    }
    CHECK(func_[tindex] == nullptr)
        << "Dispatch for " << Node::TypeIndex2Key(tindex)
        << " is already set";
    func_[tindex] = f;
    return *this;
  }
  /*!
   * \brief set the dispacher for type TNode
   *  This allows f to used detailed const Node pointer to replace NodeRef
   *
   * \param f The function to be set.
   * \tparam TNode the type of Node to be dispatched.
   * \return reference to self.
   */
  template<typename TNode>
  inline TSelf& set_dispatch(std::function<R(const TNode* n, Args...)> f) { // NOLINT(*)
    Function fun = [f](const NodeRef& n, Args... args) {
      return f(static_cast<const TNode*>(n.node_.get()),
               std::forward<Args>(args)...);
    };
    return this->set_dispatch<TNode>(fun);
  }
  /*!
  * \brief unset the dispacher for type TNode
  *
  * \tparam TNode the type of Node to be dispatched.
  * \return reference to self.
  */
  template<typename TNode>
  inline TSelf& clear_dispatch() {  // NOLINT(*)
    uint32_t tindex = Node::TypeKey2Index(TNode::_type_key);
    CHECK_LT(tindex, func_.size()) << "clear_dispatch: index out of range";
    func_[tindex] = nullptr;
    return *this;
  }
};

#if defined(__GNUC__)
#define TVM_ATTRIBUTE_UNUSED __attribute__((unused))
#else
#define TVM_ATTRIBUTE_UNUSED
#endif

/*! \brief helper macro to generate string concat */
#define TVM_STR_CONCAT_(__x, __y) __x##__y
#define TVM_STR_CONCAT(__x, __y) TVM_STR_CONCAT_(__x, __y)

#define TVM_REGISTER_VAR_DEF(ClsName)                                 \
  static TVM_ATTRIBUTE_UNUSED auto & __make_functor ## _ ## ClsName

/*!
 * \brief Useful macro to set IRFunctor dispatch in a global static field.
 *
 * \code
 *  // Use IRFunctor to implement IRPrinter similar to Visitor Pattern.
 *  // vtable allows easy patch in of new Node types, without changing
 *  // interface of IRPrinter.
 *
 *  class IRPrinter {
 *   public:
 *    std::ostream& stream;
 *    // the dispatch function.
 *    void print(Expr e) {
 *      const static FType& f = *vtable();
 *      f(e, this);
 *    }
 *
 *    using FType = IRFunctor<void (const NodeRef&, IRPrinter *)>;
 *    // function to return global function table
 *    static FType& vtable();
 *  };
 *
 *  // in cpp/cc file
 *  IRPrinter::FType& IRPrinter::vtable() { // NOLINT(*)
 *    static FType inst; return inst;
 *  }
 *
 *  TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
 *  .set_dispatch<Add>([](const Add* n, IRPrinter* p) {
 *    p->print(n->a);
 *    p->stream << '+'
 *    p->print(n->b);
 *  });
 *
 *
 * \endcode
 *
 * \param ClsName The name of the class
 * \param FField The static function that returns a singleton of IRFunctor.
 */
#define TVM_STATIC_IR_FUNCTOR(ClsName, FField)                       \
  TVM_STR_CONCAT(TVM_REGISTER_VAR_DEF(ClsName), __COUNTER__)  =      \
                              ClsName::FField()

 /*!
 * \brief A container for a list of callbacks. All callbacks are invoked when
 * the object is destructed.
 */
class IRFunctorCleanList {
 public:
  ~IRFunctorCleanList() {
    for (auto &f : clean_items) {
      f();
    }
  }

  void append(std::function<void()> func) {
    clean_items.push_back(func);
  }

 private:
  std::vector< std::function<void()> > clean_items;
};

/*!
* \brief A wrapper around IRFunctor that will record calls to set_dispatch
* and make a corresponding call to clear_dispatch when the last copy of
* the IRFunctorStaticRegistry is destructed. When assigned to a static variable,
* this can be used by NNVM and other libraries to unregister callbacks when
* the library is unloaded. This prevents crashes when the underlying IRFunctor
* is destructed as it will no longer contain std::function instances allocated
* by a library that has been unloaded.
*/
template<typename FType>
class IRFunctorStaticRegistry;

template<typename R, typename ...Args>
class IRFunctorStaticRegistry<R(const NodeRef& n, Args...)> {
 private:
  IRFunctor<R(const NodeRef& n, Args...)> *irf_;
  std::shared_ptr<IRFunctorCleanList> free_list;

  using TSelf = IRFunctorStaticRegistry<R(const NodeRef& n, Args...)>;

 public:
  IRFunctorStaticRegistry(IRFunctor<R(const NodeRef& n, Args...)> *irf) {
    irf_ = irf;
    free_list = std::make_shared<IRFunctorCleanList>();
  }

  template<typename TNode>
  inline TSelf& set_dispatch(std::function<R(const TNode* n, Args...)> f) {  // NOLINT(*)
    irf_->template set_dispatch<TNode>(f);
    auto irf_copy = irf_;
    free_list.get()->append([irf_copy] {
      irf_copy->template clear_dispatch<TNode>();
      });
    return *this;
  }
};

/*!
* \brief Helper function for constructing an IRFunctorStaticRegistry. This allows
* the compiler to deduce the template types.
*/
template<typename R, typename ...Args>
IRFunctorStaticRegistry<R(const NodeRef& n, Args...)> MakeIRFunctorStaticRegistry(
  IRFunctor<R(const NodeRef& n, Args...)> *irf) {
  return IRFunctorStaticRegistry<R(const NodeRef& n, Args...)>(irf);
}

#define TVM_AUTO_REGISTER_VAR_DEF(ClsName)                           \
  static TVM_ATTRIBUTE_UNUSED auto __make_functor ## _ ## ClsName

/*!
* \brief Macro to set IRFunctor dispatch in a global static field using an IRFunctorStaticRegistry.
* Usage is exactly the same as TVM_STATIC_IR_FUNCTOR. Libraries should use this instead of
* TVM_STATIC_IR_FUNCTOR.
*/
#define TVM_STATIC_IR_FUNCTOR_REGISTER(ClsName, FField)                  \
  TVM_STR_CONCAT(TVM_AUTO_REGISTER_VAR_DEF(ClsName), __COUNTER__)  = \
                        MakeIRFunctorStaticRegistry(&ClsName::FField())

}  // namespace tvm
#endif  // TVM_NODE_IR_FUNCTOR_H_