/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file attr_functor.h
 * \brief A way to define arbitrary function signature
 *        with dispatch on common attributes.
 *
 * Common attributes include:
 *  - int, float, str constants
 *  - array of attributes
 *  - map of attributes
 */
#ifndef TVM_IR_ATTR_FUNCTOR_H_
#define TVM_IR_ATTR_FUNCTOR_H_

#include <tvm/node/functor.h>
#include <tvm/tir/expr.h>
#include <utility>

namespace tvm {

/*!
 * \brief Primary template of the attribute functor (declaration only).
 *
 * The usable form is the partial specialization for function signatures
 * of the shape R(const ObjectRef& n, Args...) defined later in this file.
 *
 * \tparam FType The function signature of the functor.
 */
template <typename FType>
class AttrFunctor;

/*!
 * \brief Function body shared by the VisitAttr_ overloads that have no
 *        specialised behaviour.
 *
 * Expands to a braced body that passes the node pointer `op` and the extra
 * arguments on to VisitAttrDefault_. It must be used where a parameter named
 * `op`, a parameter pack named `args` and its type pack `Args` are in scope.
 */
#define ATTR_FUNCTOR_DEFAULT                                        \
  { return VisitAttrDefault_(op, std::forward<Args>(args)...); }


/*!
 * \brief Register the VisitAttr_ overload for node type OP in the dispatch table.
 *
 * Must be expanded where a table object named `vtable`, the alias `TSelf` and
 * the parameter pack `Args` are in scope. The registered lambda downcasts the
 * object held by `n` to `const OP*` and calls self->VisitAttr_ with it,
 * forwarding the extra arguments.
 *
 * The expansion already ends with a semicolon, so a use site written as
 * `ATTR_FUNCTOR_DISPATCH(X);` produces one additional (harmless) empty statement.
 *
 * \param OP The node type to register.
 */
#define ATTR_FUNCTOR_DISPATCH(OP)                                       \
  vtable.template set_dispatch<OP>(                                     \
      [](const ObjectRef& n, TSelf* self, Args... args) {               \
        return self->VisitAttr_(static_cast<const OP*>(n.get()),        \
                                std::forward<Args>(args)...);           \
      });

/*!
 * \brief A functor that dispatches on the runtime type of common attribute values.
 *
 * VisitAttr consults a dispatch table (built once by InitVTable) and calls the
 * VisitAttr_ overload registered for the type of the incoming object. When the
 * table cannot dispatch on the object, VisitAttrDefault_ is called instead.
 * Unless overridden by a subclass, every VisitAttr_ overload also ends up in
 * VisitAttrDefault_ (the SizeVarNode overload goes through the VarNode overload).
 *
 * \tparam R The result type of the functor.
 * \tparam Args The types of the additional arguments handed to every visitor.
 */
template <typename R, typename... Args>
class AttrFunctor<R(const ObjectRef& n, Args...)> {
 private:
  /*! \brief The type of this functor, used as the `self` argument of table entries. */
  using TSelf = AttrFunctor<R(const ObjectRef& n, Args...)>;
  /*! \brief The type of the dispatch table. */
  using FType = tvm::NodeFunctor<R(const ObjectRef& n, TSelf* self, Args...)>;

 public:
  /*! \brief the result type of this functor */
  using result_type = R;
  /*! \brief virtual destructor */
  virtual ~AttrFunctor() = default;
  /*!
   * \brief The functor call.
   * \param n The attribute object to visit.
   * \param args Additional arguments.
   * \return The result of the call
   */
  virtual R VisitAttr(const ObjectRef& n, Args... args) {
    // The table is a function-local static: it is initialized on the first call
    // and then shared by all later calls of this template instantiation.
    static FType vtable = InitVTable();
    if (vtable.can_dispatch(n)) {
      return vtable(n, this, std::forward<Args>(args)...);
    } else {
      return VisitAttrDefault_(n.get(), std::forward<Args>(args)...);
    }
  }
  /*!
   * \brief Fallback visitor that subclasses must implement.
   *
   * Called by VisitAttr when the dispatch table has no entry for the object,
   * and by every VisitAttr_ overload that is not overridden.
   *
   * \param node The raw pointer to the visited object.
   * \param args Additional arguments.
   * \return The result of the call
   */
  virtual R VisitAttrDefault_(const Object* node, Args... args) = 0;
  // Containers and constants; each one defaults to VisitAttrDefault_.
  virtual R VisitAttr_(const ArrayNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
  virtual R VisitAttr_(const StrMapNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
  virtual R VisitAttr_(const tir::IntImmNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
  virtual R VisitAttr_(const tir::FloatImmNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
  virtual R VisitAttr_(const tir::StringImmNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
  // deep comparison of symbolic integer expressions.
  virtual R VisitAttr_(const tir::VarNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
  // By default a SizeVarNode is handled exactly like a VarNode.
  virtual R VisitAttr_(const tir::SizeVarNode* op, Args... args) {
    return VisitAttr_(static_cast<const tir::VarNode*>(op), std::forward<Args>(args)...);
  }
  virtual R VisitAttr_(const tir::AddNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
  virtual R VisitAttr_(const tir::SubNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
  virtual R VisitAttr_(const tir::MulNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
  virtual R VisitAttr_(const tir::DivNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
  virtual R VisitAttr_(const tir::ModNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
  virtual R VisitAttr_(const tir::FloorDivNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
  virtual R VisitAttr_(const tir::FloorModNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
  virtual R VisitAttr_(const tir::MinNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
  virtual R VisitAttr_(const tir::MaxNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
  virtual R VisitAttr_(const tir::GENode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
  virtual R VisitAttr_(const tir::GTNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
  virtual R VisitAttr_(const tir::LTNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
  virtual R VisitAttr_(const tir::LENode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
  virtual R VisitAttr_(const tir::EQNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
  virtual R VisitAttr_(const tir::NENode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
  virtual R VisitAttr_(const tir::AndNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
  virtual R VisitAttr_(const tir::OrNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
  virtual R VisitAttr_(const tir::NotNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
  virtual R VisitAttr_(const tir::CastNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
  virtual R VisitAttr_(const tir::CallNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
  virtual R VisitAttr_(const tir::SelectNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;

 private:
  /*!
   * \brief Build the dispatch table used by VisitAttr.
   *
   * Registers one entry per node type that has a VisitAttr_ overload above.
   *
   * \return The populated dispatch table.
   */
  static FType InitVTable() {
    using namespace tir;
    FType vtable;
    // Set dispatch
    ATTR_FUNCTOR_DISPATCH(StrMapNode);
    ATTR_FUNCTOR_DISPATCH(ArrayNode);
    ATTR_FUNCTOR_DISPATCH(IntImmNode);
    ATTR_FUNCTOR_DISPATCH(FloatImmNode);
    ATTR_FUNCTOR_DISPATCH(StringImmNode);
    ATTR_FUNCTOR_DISPATCH(VarNode);
    ATTR_FUNCTOR_DISPATCH(SizeVarNode);
    ATTR_FUNCTOR_DISPATCH(AddNode);
    ATTR_FUNCTOR_DISPATCH(SubNode);
    ATTR_FUNCTOR_DISPATCH(MulNode);
    ATTR_FUNCTOR_DISPATCH(DivNode);
    ATTR_FUNCTOR_DISPATCH(ModNode);
    ATTR_FUNCTOR_DISPATCH(FloorDivNode);
    ATTR_FUNCTOR_DISPATCH(FloorModNode);
    ATTR_FUNCTOR_DISPATCH(MinNode);
    ATTR_FUNCTOR_DISPATCH(MaxNode);
    ATTR_FUNCTOR_DISPATCH(GENode);
    ATTR_FUNCTOR_DISPATCH(GTNode);
    ATTR_FUNCTOR_DISPATCH(LENode);
    ATTR_FUNCTOR_DISPATCH(LTNode);
    ATTR_FUNCTOR_DISPATCH(EQNode);
    ATTR_FUNCTOR_DISPATCH(NENode);
    ATTR_FUNCTOR_DISPATCH(AndNode);
    ATTR_FUNCTOR_DISPATCH(OrNode);
    ATTR_FUNCTOR_DISPATCH(NotNode);
    ATTR_FUNCTOR_DISPATCH(CastNode);
    ATTR_FUNCTOR_DISPATCH(CallNode);
    ATTR_FUNCTOR_DISPATCH(SelectNode);
    return vtable;
  }
};

}  // namespace tvm
#endif  // TVM_IR_ATTR_FUNCTOR_H_