/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file attr_functor.h
 * \brief A way to define functors with an arbitrary function signature
 *        that dispatch on common attribute types.
 *
 * Common attributes include:
 *  - int, float, str constants
 *  - array of attributes
 *  - map of attributes
 */
#ifndef TVM_IR_ATTR_FUNCTOR_H_
#define TVM_IR_ATTR_FUNCTOR_H_

#include <tvm/node/functor.h>
#include <tvm/tir/expr.h>
#include <utility>

namespace tvm {

// Primary template, declared only: the usable functor is the partial
// specialization on a function type, AttrFunctor<R(const ObjectRef& n, Args...)>.
template <typename Signature> class AttrFunctor;

// Body shared by every VisitAttr_ overload that is not specialized: it hands
// the node pointer `op` and the trailing arguments over to VisitAttrDefault_.
// The expansion site must have `op`, `args` and the type pack `Args` in scope.
#define ATTR_FUNCTOR_DEFAULT                                   \
  {                                                            \
    return VisitAttrDefault_(op, std::forward<Args>(args)...); \
  }


// Registers a dispatch entry for node type OP in a local `vtable`. The entry
// downcasts the ObjectRef to `const OP*` and calls the matching VisitAttr_
// overload on `self`, forwarding the extra arguments.
// The expansion site must have `vtable`, the `TSelf` alias and the type pack
// `Args` in scope; AttrFunctor::InitVTable provides all three.
#define ATTR_FUNCTOR_DISPATCH(OP)                                                          \
  vtable.template set_dispatch<OP>([](const ObjectRef& n, TSelf* self, Args... args) {   \
    return self->VisitAttr_(static_cast<const OP*>(n.get()), std::forward<Args>(args)...); \
  });

// A functor for common attribute information.
template <typename R, typename... Args>
55
class AttrFunctor<R(const ObjectRef& n, Args...)> {
56
 private:
57
  using TSelf = AttrFunctor<R(const ObjectRef& n, Args...)>;
58
  using FType = tvm::NodeFunctor<R(const ObjectRef& n, TSelf* self, Args...)>;
59 60 61 62

 public:
  /*! \brief the result type of this functor */
  using result_type = R;
63 64
  /*! \brief virtual destructor */
  virtual ~AttrFunctor() {}
65 66 67 68 69 70
  /*!
   * \brief The functor call.
   * \param n The expression node.
   * \param args Additional arguments.
   * \return The result of the call
   */
71
  virtual R VisitAttr(const ObjectRef& n, Args... args) {
72 73 74 75
    static FType vtable = InitVTable();
    if (vtable.can_dispatch(n)) {
      return vtable(n, this, std::forward<Args>(args)...);
    } else {
76
      return VisitAttrDefault_(n.get(), std::forward<Args>(args)...);
77 78
    }
  }
79
  virtual R VisitAttrDefault_(const Object* node, Args... args) = 0;
80 81
  virtual R VisitAttr_(const ArrayNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
  virtual R VisitAttr_(const StrMapNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
82 83 84
  virtual R VisitAttr_(const tir::IntImmNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
  virtual R VisitAttr_(const tir::FloatImmNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
  virtual R VisitAttr_(const tir::StringImmNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
85
  // deep comparison of symbolic integer expressions.
86 87 88
  virtual R VisitAttr_(const tir::VarNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
  virtual R VisitAttr_(const tir::SizeVarNode* op, Args... args) {
    return VisitAttr_(static_cast<const tir::VarNode*>(op), std::forward<Args>(args)...);
89
  }
90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110
  virtual R VisitAttr_(const tir::AddNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
  virtual R VisitAttr_(const tir::SubNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
  virtual R VisitAttr_(const tir::MulNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
  virtual R VisitAttr_(const tir::DivNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
  virtual R VisitAttr_(const tir::ModNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
  virtual R VisitAttr_(const tir::FloorDivNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
  virtual R VisitAttr_(const tir::FloorModNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
  virtual R VisitAttr_(const tir::MinNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
  virtual R VisitAttr_(const tir::MaxNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
  virtual R VisitAttr_(const tir::GENode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
  virtual R VisitAttr_(const tir::GTNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
  virtual R VisitAttr_(const tir::LTNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
  virtual R VisitAttr_(const tir::LENode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
  virtual R VisitAttr_(const tir::EQNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
  virtual R VisitAttr_(const tir::NENode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
  virtual R VisitAttr_(const tir::AndNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
  virtual R VisitAttr_(const tir::OrNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
  virtual R VisitAttr_(const tir::NotNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
  virtual R VisitAttr_(const tir::CastNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
  virtual R VisitAttr_(const tir::CallNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
  virtual R VisitAttr_(const tir::SelectNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
111 112 113 114

 private:
  // initialize the vtable.
  static FType InitVTable() {
115
    using namespace tir;
116 117 118 119
    FType vtable;
    // Set dispatch
    ATTR_FUNCTOR_DISPATCH(StrMapNode);
    ATTR_FUNCTOR_DISPATCH(ArrayNode);
120 121 122 123
    ATTR_FUNCTOR_DISPATCH(IntImmNode);
    ATTR_FUNCTOR_DISPATCH(FloatImmNode);
    ATTR_FUNCTOR_DISPATCH(StringImmNode);
    ATTR_FUNCTOR_DISPATCH(VarNode);
124
    ATTR_FUNCTOR_DISPATCH(SizeVarNode);
125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145
    ATTR_FUNCTOR_DISPATCH(AddNode);
    ATTR_FUNCTOR_DISPATCH(SubNode);
    ATTR_FUNCTOR_DISPATCH(MulNode);
    ATTR_FUNCTOR_DISPATCH(DivNode);
    ATTR_FUNCTOR_DISPATCH(ModNode);
    ATTR_FUNCTOR_DISPATCH(FloorDivNode);
    ATTR_FUNCTOR_DISPATCH(FloorModNode);
    ATTR_FUNCTOR_DISPATCH(MinNode);
    ATTR_FUNCTOR_DISPATCH(MaxNode);
    ATTR_FUNCTOR_DISPATCH(GENode);
    ATTR_FUNCTOR_DISPATCH(GTNode);
    ATTR_FUNCTOR_DISPATCH(LENode);
    ATTR_FUNCTOR_DISPATCH(LTNode);
    ATTR_FUNCTOR_DISPATCH(EQNode);
    ATTR_FUNCTOR_DISPATCH(NENode);
    ATTR_FUNCTOR_DISPATCH(AndNode);
    ATTR_FUNCTOR_DISPATCH(OrNode);
    ATTR_FUNCTOR_DISPATCH(NotNode);
    ATTR_FUNCTOR_DISPATCH(CastNode);
    ATTR_FUNCTOR_DISPATCH(CallNode);
    ATTR_FUNCTOR_DISPATCH(SelectNode);
146 147 148 149 150
    return vtable;
  }
};

}  // namespace tvm
#endif  // TVM_IR_ATTR_FUNCTOR_H_