/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file tvm/tir/expr_functor.h
 *
 * \brief Functors for tir expressions.
 */
#ifndef TVM_TIR_EXPR_FUNCTOR_H_
#define TVM_TIR_EXPR_FUNCTOR_H_

#include <tvm/node/functor.h>
#include <tvm/tir/expr.h>

#include <utility>

namespace tvm {
namespace tir {

/*!
 * \brief A dynamic functor that dispatches on the first PrimExpr argument.
 *  You can use this as a more powerful Visitor, since it allows you to
 *  define the function signatures of the visit functions.
 *
 *  This helps you avoid book-keeping the return value of a Visitor via state,
 *  which can easily cause bugs when the state is incorrectly maintained.
 *
 * \code
 *  // A functor that sets every variable to b and calculates the result.
 *  class MyExprFunctor
 *    : public tir::ExprFunctor<int(const PrimExpr&, int)> {
 *   public:
 *    int VisitExpr_(const VarNode* op, int b) final {
 *      return b;
 *    }
 *    int VisitExpr_(const IntImmNode* op, int b) final {
 *      return op->value;
 *    }
 *    int VisitExpr_(const AddNode* op, int b) final {
 *      return VisitExpr(op->a, b) + VisitExpr(op->b, b);
 *    }
 *  };
 *  MyExprFunctor f;
 *  Var x("x");
 *  CHECK_EQ(f(x + 1, 2), 3);
 * \endcode
 *
 * \note Why do we need this more powerful Functor:
 *
 *  We often need to implement transformer tasks.
 *  Say we want to take a PrimExpr and transform it into some analysis result.
 *  This can easily be done incorrectly using a plain Visitor. See IRVisitor's
 *  documentation for possible error cases.
 *
 * \tparam FType The function signature.
 *  This type is only defined for FType with the function signature
 *  R(const PrimExpr&, Args...).
 */
template<typename FType>
class ExprFunctor;

// Default body shared by the overridable VisitExpr_ overloads of ExprFunctor:
// pass the node pointer and every extra argument on to VisitExprDefault_.
// Must expand inside a member function whose parameters are named
// `op` and `args` and whose enclosing template declares the pack `Args`.
#define EXPR_FUNCTOR_DEFAULT                                        \
  {                                                                 \
    return VisitExprDefault_(op, std::forward<Args>(args)...);      \
  }

/*!
 * \brief Register the dispatch entry for node type OP in the local `vtable`.
 *
 *  Expands to a `vtable.set_dispatch<OP>(...)` call whose callback casts the
 *  node pointer held by the ObjectRef to `const OP*` and calls the matching
 *  `self->VisitExpr_` overload with the forwarded extra arguments.
 *  `vtable`, `TSelf` and the parameter pack `Args` must be in scope at the
 *  expansion site.
 *
 *  The definition intentionally ends without a semicolon and without a
 *  trailing line continuation: the caller writes
 *  `IR_EXPR_FUNCTOR_DISPATCH(Node);`, which then expands to exactly one
 *  statement, and the line following the macro can never be absorbed into it.
 *
 * \param OP The node type to register.
 */
#define IR_EXPR_FUNCTOR_DISPATCH(OP)                                    \
  vtable.template set_dispatch<OP>(                                     \
      [](const ObjectRef& n, TSelf* self, Args... args) {               \
        return self->VisitExpr_(static_cast<const OP*>(n.get()),        \
                                std::forward<Args>(args)...);           \
      })

template<typename R, typename ...Args>
90
class ExprFunctor<R(const PrimExpr& n, Args...)> {
91
 private:
92
  using TSelf = ExprFunctor<R(const PrimExpr& n, Args...)>;
93
  using FType = NodeFunctor<R(const ObjectRef& n, TSelf* self, Args...)>;
94 95 96 97 98 99 100 101 102 103 104 105

 public:
  /*! \brief the result type of this functor */
  using result_type = R;
  /*! \brief virtual destructor */
  virtual ~ExprFunctor() {}
  /*!
   * \brief Same as call.
   * \param n The expression node.
   * \param args Additional arguments.
   * \return The result of the call
   */
106
  R operator()(const PrimExpr& n, Args... args) {
107 108 109 110 111 112 113 114
    return VisitExpr(n, std::forward<Args>(args)...);
  }
  /*!
   * \brief The functor call.
   * \param n The expression node.
   * \param args Additional arguments.
   * \return The result of the call
   */
115
  virtual R VisitExpr(const PrimExpr& n, Args... args) {
116 117 118 119
    static FType vtable = InitVTable();
    return vtable(n, this, std::forward<Args>(args)...);
  }
  // Functions that can be overriden by subclass
120
  virtual R VisitExpr_(const VarNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
121 122 123
  virtual R VisitExpr_(const SizeVarNode* op, Args... args) {
    return VisitExpr_(static_cast<const VarNode*>(op), std::forward<Args>(args)...);
  }
124
  virtual R VisitExpr_(const BufferLoadNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154
  virtual R VisitExpr_(const LoadNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
  virtual R VisitExpr_(const LetNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
  virtual R VisitExpr_(const CallNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
  virtual R VisitExpr_(const AddNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
  virtual R VisitExpr_(const SubNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
  virtual R VisitExpr_(const MulNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
  virtual R VisitExpr_(const DivNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
  virtual R VisitExpr_(const ModNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
  virtual R VisitExpr_(const FloorDivNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
  virtual R VisitExpr_(const FloorModNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
  virtual R VisitExpr_(const MinNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
  virtual R VisitExpr_(const MaxNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
  virtual R VisitExpr_(const EQNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
  virtual R VisitExpr_(const NENode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
  virtual R VisitExpr_(const LTNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
  virtual R VisitExpr_(const LENode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
  virtual R VisitExpr_(const GTNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
  virtual R VisitExpr_(const GENode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
  virtual R VisitExpr_(const AndNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
  virtual R VisitExpr_(const OrNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
  virtual R VisitExpr_(const ReduceNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
  virtual R VisitExpr_(const CastNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
  virtual R VisitExpr_(const NotNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
  virtual R VisitExpr_(const SelectNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
  virtual R VisitExpr_(const RampNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
  virtual R VisitExpr_(const BroadcastNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
  virtual R VisitExpr_(const ShuffleNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
  virtual R VisitExpr_(const IntImmNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
  virtual R VisitExpr_(const FloatImmNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
  virtual R VisitExpr_(const StringImmNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
155
  virtual R VisitExprDefault_(const Object* op, Args ...) {
156
    LOG(FATAL) << "Do not have a default for " << op->GetTypeKey();
157 158 159 160 161 162 163 164
    return R();
  }

 private:
  // initialize the vtable.
  static FType InitVTable() {
    FType vtable;
    // Set dispatch
165
    IR_EXPR_FUNCTOR_DISPATCH(VarNode);
166
    IR_EXPR_FUNCTOR_DISPATCH(SizeVarNode);
167
    IR_EXPR_FUNCTOR_DISPATCH(LoadNode);
168
    IR_EXPR_FUNCTOR_DISPATCH(BufferLoadNode);
169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197
    IR_EXPR_FUNCTOR_DISPATCH(LetNode);
    IR_EXPR_FUNCTOR_DISPATCH(CallNode);
    IR_EXPR_FUNCTOR_DISPATCH(AddNode);
    IR_EXPR_FUNCTOR_DISPATCH(SubNode);
    IR_EXPR_FUNCTOR_DISPATCH(MulNode);
    IR_EXPR_FUNCTOR_DISPATCH(DivNode);
    IR_EXPR_FUNCTOR_DISPATCH(ModNode);
    IR_EXPR_FUNCTOR_DISPATCH(FloorDivNode);
    IR_EXPR_FUNCTOR_DISPATCH(FloorModNode);
    IR_EXPR_FUNCTOR_DISPATCH(MinNode);
    IR_EXPR_FUNCTOR_DISPATCH(MaxNode);
    IR_EXPR_FUNCTOR_DISPATCH(EQNode);
    IR_EXPR_FUNCTOR_DISPATCH(NENode);
    IR_EXPR_FUNCTOR_DISPATCH(LTNode);
    IR_EXPR_FUNCTOR_DISPATCH(LENode);
    IR_EXPR_FUNCTOR_DISPATCH(GTNode);
    IR_EXPR_FUNCTOR_DISPATCH(GENode);
    IR_EXPR_FUNCTOR_DISPATCH(AndNode);
    IR_EXPR_FUNCTOR_DISPATCH(OrNode);
    IR_EXPR_FUNCTOR_DISPATCH(ReduceNode);
    IR_EXPR_FUNCTOR_DISPATCH(CastNode);
    IR_EXPR_FUNCTOR_DISPATCH(NotNode);
    IR_EXPR_FUNCTOR_DISPATCH(SelectNode);
    IR_EXPR_FUNCTOR_DISPATCH(RampNode);
    IR_EXPR_FUNCTOR_DISPATCH(ShuffleNode);
    IR_EXPR_FUNCTOR_DISPATCH(BroadcastNode);
    IR_EXPR_FUNCTOR_DISPATCH(IntImmNode);
    IR_EXPR_FUNCTOR_DISPATCH(FloatImmNode);
    IR_EXPR_FUNCTOR_DISPATCH(StringImmNode);
198 199 200 201 202 203 204
    return vtable;
  }
};

// The two helper macros are only needed by the ExprFunctor definition above;
// undefine them so they do not leak into code that includes this header.
#undef IR_EXPR_FUNCTOR_DISPATCH
#undef EXPR_FUNCTOR_DEFAULT

205 206 207 208
/*!
 * \brief ExprVisitor
 */
class TVM_DLL ExprVisitor :
209
      public ExprFunctor<void(const PrimExpr&)> {
210 211 212 213 214 215
 public:
  using ExprFunctor::operator();

 protected:
  using ExprFunctor::VisitExpr;
  // list of functions to override.
216
  void VisitExpr_(const VarNode* op) override;
217
  void VisitExpr_(const SizeVarNode* op) override;
218
  void VisitExpr_(const LoadNode* op) override;
219
  void VisitExpr_(const BufferLoadNode* op) override;
220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248
  void VisitExpr_(const LetNode* op) override;
  void VisitExpr_(const CallNode* op) override;
  void VisitExpr_(const AddNode* op) override;
  void VisitExpr_(const SubNode* op) override;
  void VisitExpr_(const MulNode* op) override;
  void VisitExpr_(const DivNode* op) override;
  void VisitExpr_(const ModNode* op) override;
  void VisitExpr_(const FloorDivNode* op) override;
  void VisitExpr_(const FloorModNode* op) override;
  void VisitExpr_(const MinNode* op) override;
  void VisitExpr_(const MaxNode* op) override;
  void VisitExpr_(const EQNode* op) override;
  void VisitExpr_(const NENode* op) override;
  void VisitExpr_(const LTNode* op) override;
  void VisitExpr_(const LENode* op) override;
  void VisitExpr_(const GTNode* op) override;
  void VisitExpr_(const GENode* op) override;
  void VisitExpr_(const AndNode* op) override;
  void VisitExpr_(const OrNode* op) override;
  void VisitExpr_(const ReduceNode* op) override;
  void VisitExpr_(const CastNode* op) override;
  void VisitExpr_(const NotNode* op) override;
  void VisitExpr_(const SelectNode* op) override;
  void VisitExpr_(const RampNode* op) override;
  void VisitExpr_(const BroadcastNode* op) override;
  void VisitExpr_(const ShuffleNode* op) override;
  void VisitExpr_(const IntImmNode* op) override;
  void VisitExpr_(const FloatImmNode* op) override;
  void VisitExpr_(const StringImmNode* op) override;
249 250 251 252 253 254
};

/*!
 * \brief ExprMutator that mutates expressions.
 *
 *  An ExprFunctor with the signature PrimExpr(const PrimExpr&). It declares
 *  an override of the VisitExpr_ overload for every node type that
 *  ExprFunctor dispatches on; the definitions are not part of this header.
 *
 *  The base class is inherited with protected access, so of the inherited
 *  members only operator() is re-exported as public; VisitExpr is
 *  re-exported as protected for use by subclasses.
 */
class TVM_DLL ExprMutator :
      protected ExprFunctor<PrimExpr(const PrimExpr&)> {
 public:
  using ExprFunctor::operator();

 protected:
  using ExprFunctor::VisitExpr;
  // list of functions to override.
  PrimExpr VisitExpr_(const VarNode* op) override;
  PrimExpr VisitExpr_(const SizeVarNode* op) override;
  PrimExpr VisitExpr_(const LoadNode* op) override;
  PrimExpr VisitExpr_(const BufferLoadNode* op) override;
  PrimExpr VisitExpr_(const LetNode* op) override;
  PrimExpr VisitExpr_(const CallNode* op) override;
  PrimExpr VisitExpr_(const AddNode* op) override;
  PrimExpr VisitExpr_(const SubNode* op) override;
  PrimExpr VisitExpr_(const MulNode* op) override;
  PrimExpr VisitExpr_(const DivNode* op) override;
  PrimExpr VisitExpr_(const ModNode* op) override;
  PrimExpr VisitExpr_(const FloorDivNode* op) override;
  PrimExpr VisitExpr_(const FloorModNode* op) override;
  PrimExpr VisitExpr_(const MinNode* op) override;
  PrimExpr VisitExpr_(const MaxNode* op) override;
  PrimExpr VisitExpr_(const EQNode* op) override;
  PrimExpr VisitExpr_(const NENode* op) override;
  PrimExpr VisitExpr_(const LTNode* op) override;
  PrimExpr VisitExpr_(const LENode* op) override;
  PrimExpr VisitExpr_(const GTNode* op) override;
  PrimExpr VisitExpr_(const GENode* op) override;
  PrimExpr VisitExpr_(const AndNode* op) override;
  PrimExpr VisitExpr_(const OrNode* op) override;
  PrimExpr VisitExpr_(const ReduceNode* op) override;
  PrimExpr VisitExpr_(const CastNode* op) override;
  PrimExpr VisitExpr_(const NotNode* op) override;
  PrimExpr VisitExpr_(const SelectNode* op) override;
  PrimExpr VisitExpr_(const RampNode* op) override;
  PrimExpr VisitExpr_(const BroadcastNode* op) override;
  PrimExpr VisitExpr_(const ShuffleNode* op) override;
  PrimExpr VisitExpr_(const IntImmNode* op) override;
  PrimExpr VisitExpr_(const FloatImmNode* op) override;
  PrimExpr VisitExpr_(const StringImmNode* op) override;
};

}  // namespace tir
}  // namespace tvm
#endif  // TVM_TIR_EXPR_FUNCTOR_H_