attr_functor.h 9.43 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 * 
 *   http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

20 21 22 23 24 25 26 27 28 29 30 31 32 33
/*!
 *  Copyright (c) 2018 by Contributors
 * \file attr_functor.h
 * \brief A way to define arbitrary function signature
 *        with dispatch on common attributes.
 *
 * Common attributes include:
 *  - int, float, str constants
 *  - array of attributes
 *  - map of attributes
 */
#ifndef TVM_LANG_ATTR_FUNCTOR_H_
#define TVM_LANG_ATTR_FUNCTOR_H_

34 35
#include <utility>

36 37 38 39 40
namespace tvm {

template <typename FType>
class AttrFunctor;

41 42 43 44
#define ATTR_FUNCTOR_DEFAULT                                        \
  { return VisitAttrDefault_(op, std::forward<Args>(args)...); }


45 46 47
#define ATTR_FUNCTOR_DISPATCH(OP)                                       \
  vtable.template set_dispatch<OP>(                                     \
      [](const NodeRef& n, TSelf* self, Args... args) {                 \
48 49
        return self->VisitAttr_(static_cast<const OP*>(n.node_.get()),  \
                                std::forward<Args>(args)...);           \
50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67
      });                                                               \

// A functor for common attribute information.
template <typename R, typename... Args>
class AttrFunctor<R(const NodeRef& n, Args...)> {
 private:
  using TSelf = AttrFunctor<R(const NodeRef& n, Args...)>;
  using FType = tvm::IRFunctor<R(const NodeRef& n, TSelf* self, Args...)>;

 public:
  /*! \brief the result type of this functor */
  using result_type = R;
  /*!
   * \brief The functor call.
   * \param n The expression node.
   * \param args Additional arguments.
   * \return The result of the call
   */
68
  virtual R VisitAttr(const NodeRef& n, Args... args) {
69 70 71 72
    static FType vtable = InitVTable();
    if (vtable.can_dispatch(n)) {
      return vtable(n, this, std::forward<Args>(args)...);
    } else {
73
      return VisitAttrDefault_(n.get(), std::forward<Args>(args)...);
74 75
    }
  }
76
  virtual R VisitAttrDefault_(const Node* node, Args... args) = 0;
77 78 79 80 81 82
  virtual R VisitAttr_(const ArrayNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
  virtual R VisitAttr_(const StrMapNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
  virtual R VisitAttr_(const ir::IntImm* op, Args... args) ATTR_FUNCTOR_DEFAULT;
  virtual R VisitAttr_(const ir::UIntImm* op, Args... args) ATTR_FUNCTOR_DEFAULT;
  virtual R VisitAttr_(const ir::FloatImm* op, Args... args) ATTR_FUNCTOR_DEFAULT;
  virtual R VisitAttr_(const ir::StringImm* op, Args... args) ATTR_FUNCTOR_DEFAULT;
83 84 85 86 87
  // deep comparison of symbolic integer expressions.
  virtual R VisitAttr_(const Variable* op, Args... args) ATTR_FUNCTOR_DEFAULT;
  virtual R VisitAttr_(const ir::Add* op, Args... args) ATTR_FUNCTOR_DEFAULT;
  virtual R VisitAttr_(const ir::Sub* op, Args... args) ATTR_FUNCTOR_DEFAULT;
  virtual R VisitAttr_(const ir::Mul* op, Args... args) ATTR_FUNCTOR_DEFAULT;
Siva committed
88
  virtual R VisitAttr_(const ir::Div* op, Args... args) ATTR_FUNCTOR_DEFAULT;
89 90 91 92 93 94 95 96 97 98 99 100 101 102 103
  virtual R VisitAttr_(const ir::Mod* op, Args... args) ATTR_FUNCTOR_DEFAULT;
  virtual R VisitAttr_(const ir::Min* op, Args... args) ATTR_FUNCTOR_DEFAULT;
  virtual R VisitAttr_(const ir::Max* op, Args... args) ATTR_FUNCTOR_DEFAULT;
  virtual R VisitAttr_(const ir::GE* op, Args... args) ATTR_FUNCTOR_DEFAULT;
  virtual R VisitAttr_(const ir::GT* op, Args... args) ATTR_FUNCTOR_DEFAULT;
  virtual R VisitAttr_(const ir::LT* op, Args... args) ATTR_FUNCTOR_DEFAULT;
  virtual R VisitAttr_(const ir::LE* op, Args... args) ATTR_FUNCTOR_DEFAULT;
  virtual R VisitAttr_(const ir::EQ* op, Args... args) ATTR_FUNCTOR_DEFAULT;
  virtual R VisitAttr_(const ir::NE* op, Args... args) ATTR_FUNCTOR_DEFAULT;
  virtual R VisitAttr_(const ir::And* op, Args... args) ATTR_FUNCTOR_DEFAULT;
  virtual R VisitAttr_(const ir::Or* op, Args... args) ATTR_FUNCTOR_DEFAULT;
  virtual R VisitAttr_(const ir::Not* op, Args... args) ATTR_FUNCTOR_DEFAULT;
  virtual R VisitAttr_(const ir::Cast* op, Args... args) ATTR_FUNCTOR_DEFAULT;
  virtual R VisitAttr_(const ir::Call* op, Args... args) ATTR_FUNCTOR_DEFAULT;
  virtual R VisitAttr_(const ir::Select* op, Args... args) ATTR_FUNCTOR_DEFAULT;
104 105 106 107 108 109 110 111 112 113 114 115 116

 private:
  // initialize the vtable.
  static FType InitVTable() {
    using namespace ir;
    FType vtable;
    // Set dispatch
    ATTR_FUNCTOR_DISPATCH(StrMapNode);
    ATTR_FUNCTOR_DISPATCH(ArrayNode);
    ATTR_FUNCTOR_DISPATCH(IntImm);
    ATTR_FUNCTOR_DISPATCH(UIntImm);
    ATTR_FUNCTOR_DISPATCH(FloatImm);
    ATTR_FUNCTOR_DISPATCH(StringImm);
117 118 119 120
    ATTR_FUNCTOR_DISPATCH(Variable);
    ATTR_FUNCTOR_DISPATCH(Add);
    ATTR_FUNCTOR_DISPATCH(Sub);
    ATTR_FUNCTOR_DISPATCH(Mul);
Siva committed
121
    ATTR_FUNCTOR_DISPATCH(Div);
122 123 124 125 126 127 128 129 130 131 132 133 134 135
    ATTR_FUNCTOR_DISPATCH(Min);
    ATTR_FUNCTOR_DISPATCH(Max);
    ATTR_FUNCTOR_DISPATCH(GE);
    ATTR_FUNCTOR_DISPATCH(GT);
    ATTR_FUNCTOR_DISPATCH(LE);
    ATTR_FUNCTOR_DISPATCH(LT);
    ATTR_FUNCTOR_DISPATCH(EQ);
    ATTR_FUNCTOR_DISPATCH(NE);
    ATTR_FUNCTOR_DISPATCH(And);
    ATTR_FUNCTOR_DISPATCH(Or);
    ATTR_FUNCTOR_DISPATCH(Not);
    ATTR_FUNCTOR_DISPATCH(Cast);
    ATTR_FUNCTOR_DISPATCH(Call);
    ATTR_FUNCTOR_DISPATCH(Select);
136 137 138 139
    return vtable;
  }
};

140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160
class AttrsEqualHandler :
      protected AttrFunctor<bool(const NodeRef&, const NodeRef&)> {
 public:
  /*!
   * \brief Check if lhs equals rhs
   * \param lhs The left operand.
   * \param rhs The right operand.
   */
  bool Equal(const NodeRef& lhs, const NodeRef& rhs);

 protected:
  bool VisitAttrDefault_(const Node* lhs, const NodeRef& other) final;
  bool VisitAttr_(const ArrayNode* lhs, const NodeRef& other) final;
  bool VisitAttr_(const StrMapNode* lhs, const NodeRef& other) final;
  bool VisitAttr_(const ir::IntImm* lhs, const NodeRef& other) final;
  bool VisitAttr_(const ir::UIntImm* lhs, const NodeRef& other) final;
  bool VisitAttr_(const ir::FloatImm* lhs, const NodeRef& other) final;
  bool VisitAttr_(const ir::StringImm* lhs, const NodeRef& other) final;
  bool VisitAttr_(const ir::Add* lhs, const NodeRef& other) final;
  bool VisitAttr_(const ir::Sub* lhs, const NodeRef& other) final;
  bool VisitAttr_(const ir::Mul* lhs, const NodeRef& other) final;
Siva committed
161
  bool VisitAttr_(const ir::Div* lhs, const NodeRef& other) final;
162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186
  bool VisitAttr_(const ir::Mod* lhs, const NodeRef& other) final;
  bool VisitAttr_(const ir::Min* lhs, const NodeRef& other) final;
  bool VisitAttr_(const ir::Max* lhs, const NodeRef& other) final;
  bool VisitAttr_(const ir::GE* lhs, const NodeRef& other) final;
  bool VisitAttr_(const ir::GT* lhs, const NodeRef& other) final;
  bool VisitAttr_(const ir::LT* lhs, const NodeRef& other) final;
  bool VisitAttr_(const ir::LE* lhs, const NodeRef& other) final;
  bool VisitAttr_(const ir::EQ* lhs, const NodeRef& other) final;
  bool VisitAttr_(const ir::NE* lhs, const NodeRef& other) final;
  bool VisitAttr_(const ir::And* lhs, const NodeRef& other) final;
  bool VisitAttr_(const ir::Or* lhs, const NodeRef& other) final;
  bool VisitAttr_(const ir::Not* lhs, const NodeRef& other) final;
  bool VisitAttr_(const ir::Cast* lhs, const NodeRef& other) final;
  bool VisitAttr_(const ir::Call* lhs, const NodeRef& other) final;
  bool VisitAttr_(const ir::Select* lhs, const NodeRef& other) final;
};

class AttrsHashHandler :
      protected AttrFunctor<size_t(const NodeRef&)> {
 public:
  /*!
   * \brief Get hash value of node
   * \param node The node to be hashed.
   */
  size_t Hash(const NodeRef& node) {
187
    if (!node.defined()) return 0;
188 189 190 191 192 193 194 195 196 197 198 199 200 201
    return this->VisitAttr(node);
  }

 protected:
  size_t VisitAttrDefault_(const Node* lhs) final;
  size_t VisitAttr_(const ir::IntImm* lhs) final;
  size_t VisitAttr_(const ir::UIntImm* lhs) final;
  size_t VisitAttr_(const ir::FloatImm* lhs) final;
  size_t VisitAttr_(const ir::StringImm* lhs) final;
  size_t VisitAttr_(const ArrayNode* lhs) final;
  size_t VisitAttr_(const StrMapNode* lhs) final;
  size_t VisitAttr_(const ir::Add* op) final;
  size_t VisitAttr_(const ir::Sub* op) final;
  size_t VisitAttr_(const ir::Mul* op) final;
Siva committed
202
  size_t VisitAttr_(const ir::Div* op) final;
203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226
  size_t VisitAttr_(const ir::Mod* op) final;
  size_t VisitAttr_(const ir::Min* op) final;
  size_t VisitAttr_(const ir::Max* op) final;
  size_t VisitAttr_(const ir::GE* op) final;
  size_t VisitAttr_(const ir::GT* op) final;
  size_t VisitAttr_(const ir::LE* op) final;
  size_t VisitAttr_(const ir::LT* op) final;
  size_t VisitAttr_(const ir::EQ* op) final;
  size_t VisitAttr_(const ir::NE* op) final;
  size_t VisitAttr_(const ir::And* op) final;
  size_t VisitAttr_(const ir::Or* op) final;
  size_t VisitAttr_(const ir::Not* op) final;
  size_t VisitAttr_(const ir::Cast* op) final;
  size_t VisitAttr_(const ir::Call* op) final;
  size_t VisitAttr_(const ir::Select* op) final;
  /*!
   * \brief alias of dmlc::HashCombine
   * \param lhs The first hash value.
   * \param rhs The second hash value.
   */
  static size_t Combine(size_t lhs, size_t rhs) {
    return dmlc::HashCombine(lhs, rhs);
  }
};
227 228
}  // namespace tvm
#endif  // TVM_LANG_ATTR_FUNCTOR_H_