/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file tvm/relay/pattern_functor.h
 * \brief A more powerful visitor on ADT patterns that enables defining
 * arbitrary function signatures with type-based dispatch on the first argument.
 */
#ifndef TVM_RELAY_PATTERN_FUNCTOR_H_
#define TVM_RELAY_PATTERN_FUNCTOR_H_

#include <tvm/node/ir_functor.h>
#include <string>
#include <utility>
#include <unordered_map>
#include "./expr.h"
#include "./op.h"
#include "./error.h"
#include "./adt.h"

namespace tvm {
namespace relay {

/*!
 * \brief A dynamic functor on ADT patterns that dispatches on its first argument.
 *  You can use this as a more powerful visitor, since it allows you to
 *  define the types of further arguments to VisitPattern.
 *
 * \sa tvm/node/ir_functor.h
 *
 * \tparam FType function signature
 *  This type is only defined for FType with function signature
 *  R(const Pattern&, Args...); for any other FType it is declared
 *  but left undefined.
 */
template <typename FType>
class PatternFunctor;

// Default body for the VisitPattern_ overloads declared in PatternFunctor
// (functions that subclasses may override): forwards the visit to
// VisitPatternDefault_. Expects `op` and the pack `args`/`Args` in scope.
#define PATTERN_FUNCTOR_DEFAULT                                      \
  { return VisitPatternDefault_(op, std::forward<Args>(args)...); }

// Registers a vtable entry for node type OP that downcasts the incoming
// NodeRef to `const OP*` and calls the matching VisitPattern_ overload.
// Expects `vtable`, `TSelf` and `Args` to be in scope (as inside
// PatternFunctor::InitVTable). The macro does not end in a semicolon, so
// call sites write `RELAY_PATTERN_FUNCTOR_DISPATCH(X);` without producing
// a stray empty statement.
#define RELAY_PATTERN_FUNCTOR_DISPATCH(OP)                                \
  vtable.template set_dispatch<OP>(                                       \
      [](const NodeRef& n, TSelf* self, Args... args) {                   \
        return self->VisitPattern_(static_cast<const OP*>(n.node_.get()), \
                                   std::forward<Args>(args)...);          \
      })

template <typename R, typename... Args>
class PatternFunctor<R(const Pattern& n, Args...)> {
 private:
  using TSelf = PatternFunctor<R(const Pattern& n, Args...)>;
  using FType = tvm::IRFunctor<R(const NodeRef& n, TSelf* self, Args...)>;

 public:
  /*! \brief the result type of this functor */
  using result_type = R;
  /*! \brief virtual destructor */
  virtual ~PatternFunctor() {}
  /*!
   * \brief Same as VisitPattern.
   * \param n The pattern node.
   * \param args Additional arguments.
   * \return The result of the call
   */
  R operator()(const Pattern& n, Args... args) {
    return VisitPattern(n, std::forward<Args>(args)...);
  }
  /*!
   * \brief The functor call: dispatches to the VisitPattern_ overload that
   *  matches the runtime node type of \p n.
   * \param n The pattern node; must be defined (enforced with CHECK).
   * \param args Additional arguments, forwarded to the selected overload.
   * \return The result of the call
   */
  virtual R VisitPattern(const Pattern& n, Args... args) {
    CHECK(n.defined());
    // Function-local static: the dispatch table is built once per
    // instantiation of this template, on first use.
    static FType vtable = InitVTable();
    return vtable(n, this, std::forward<Args>(args)...);
  }
  // Functions that can be overridden by subclass. Unless overridden, each
  // one forwards to VisitPatternDefault_.
  virtual R VisitPattern_(const PatternWildcardNode* op,
                          Args... args) PATTERN_FUNCTOR_DEFAULT;
  virtual R VisitPattern_(const PatternVarNode* op,
                          Args... args) PATTERN_FUNCTOR_DEFAULT;
  virtual R VisitPattern_(const PatternConstructorNode* op,
                          Args... args) PATTERN_FUNCTOR_DEFAULT;
  /*!
   * \brief Fallback used by VisitPattern_ overloads that are not overridden.
   * \param op The node being visited.
   * \throws Error whose message names the type key of \p op.
   */
  virtual R VisitPatternDefault_(const Node* op, Args...) {
    throw Error(std::string("Do not have a default for ") + op->type_key());
  }

 private:
  // Build the dispatch table: one entry per supported pattern node type.
  static FType InitVTable() {
    FType vtable;
    // Set dispatch
    RELAY_PATTERN_FUNCTOR_DISPATCH(PatternWildcardNode);
    RELAY_PATTERN_FUNCTOR_DISPATCH(PatternVarNode);
    RELAY_PATTERN_FUNCTOR_DISPATCH(PatternConstructorNode);
    return vtable;
  }
};

/*! \brief A simple visitor wrapper around PatternFunctor.
 *
 * Walks a pattern with a default traversal strategy. It does not compute
 * a result, but subclasses can update internal state while visiting.
 * PatternMutator is the counterpart that functionally builds a new pattern.
 */
class PatternVisitor : public ::tvm::relay::PatternFunctor<void(const Pattern& n)> {
 public:
  void VisitPattern_(const PatternWildcardNode* op) override;
  void VisitPattern_(const PatternVarNode* op) override;
  void VisitPattern_(const PatternConstructorNode* op) override;
  /*! \brief Hook for the types inside of patterns (defined out of line). */
  virtual void VisitType(const Type& t);
  /*! \brief Hook for the vars inside of patterns (defined out of line). */
  virtual void VisitVar(const Var& v);
  /*! \brief Hook for the constructors inside of patterns (defined out of line). */
  virtual void VisitConstructor(const Constructor& c);
};

/*! \brief A wrapper around PatternFunctor which functionally updates patterns.
 *
 * Rather than modifying a pattern in place, each VisitPattern_ overload
 * returns a Pattern, which lets callers build a new pattern from an old one.
 */
class PatternMutator
    : public ::tvm::relay::PatternFunctor<Pattern(const Pattern&)> {
 public:
  /*!
   * \brief Transform a pattern and return the result.
   * NOTE(review): defined out of line; presumably dispatches via
   * VisitPattern — confirm against the implementation.
   */
  Pattern Mutate(const Pattern& pat);
  Pattern VisitPattern_(const PatternWildcardNode* op) override;
  Pattern VisitPattern_(const PatternVarNode* op) override;
  Pattern VisitPattern_(const PatternConstructorNode* op) override;
  /*! \brief Used to visit the types inside of patterns.
   *
   * Can be overloaded to transform the types in arbitrary
   * ways, one way would be to define a sub-class of type
   * visitor for types which transform them appropriately.
   */
  virtual Type VisitType(const Type& t);
  /*! \brief Used to visit the vars inside of patterns. */
  virtual Var VisitVar(const Var& v);
  /*! \brief Used to visit the constructors inside of patterns. */
  virtual Constructor VisitConstructor(const Constructor& c);

 private:
  // Var -> Var mapping, keyed by node identity (NodeHash/NodeEqual).
  // NOTE(review): presumably lets VisitVar rewrite each occurrence of the
  // same Var consistently — confirm in the implementation.
  std::unordered_map<Var, Var, NodeHash, NodeEqual> var_map_;
};

}  // namespace relay
}  // namespace tvm
#endif  // TVM_RELAY_PATTERN_FUNCTOR_H_