ir_mutator.h 5.18 KB
Newer Older
1 2 3 4 5 6 7 8
/*!
 *  Copyright (c) 2016 by Contributors
 * \file ir_mutator.h
 * \brief Defines general IRMutation pass
 */
#ifndef TVM_IR_MUTATOR_H_
#define TVM_IR_MUTATOR_H_

9
#include <tvm/ir_functor.h>
tqchen committed
10
#include <unordered_map>
11
#include "./expr.h"
12
#include "./ir.h"
13 14 15 16 17 18

namespace tvm {
namespace ir {
/*!
 * \brief a base class for mutator to iterative mutate the IR
 *
19 20
 *  This IRMutator is implemented via Visitor Pattern.
 *  Also you can implement via IRFunctor.
21
 *  This enables easy extensions of possible new Node.
22 23 24 25 26 27
 *  It also makes changing return types easier.
 *
 * \note If you want to return a different type other than Expr and Stmt,
 *       Simply following the same pattern as IRMutator and create a seperate class.
 * \sa IRFunctor
 */
28
class TVM_DLL IRMutator {
29 30 31 32 33
 public:
  /*!
   * \brief mutate expression
   * \return the mutated expr
   */
tqchen committed
34
  virtual Expr Mutate(Expr expr) {
35 36 37 38 39 40 41
    static const FMutateExpr& f = vtable_expr();
    return f(expr, expr, this);
  }
  /*!
   * \brief mutate expression
   * \return the mutated stmt
   */
tqchen committed
42
  virtual Stmt Mutate(Stmt stmt) {
43 44 45 46 47 48
    static const FMutateStmt& f = vtable_stmt();
    return f(stmt, stmt, this);
  }
  /*! \brief destructor */
  virtual ~IRMutator() {}
  /*! \brief functor type of expr mutation */
49
  using FMutateExpr = IRFunctor<Expr(const NodeRef&, const Expr&, IRMutator*)>;
50
  /*! \brief functor type of stmt mutation */
51
  using FMutateStmt = IRFunctor<Stmt(const NodeRef&, const Stmt&, IRMutator*)>;
52 53 54 55
  /*! \return internal vtable of expr */
  static FMutateExpr& vtable_expr();  // NOLINT(*)
  /*! \return internal stmt of expr */
  static FMutateStmt& vtable_stmt();  // NOLINT(*)
56 57 58 59
  // Set of overloadable functions
  // The underscore allows Mutate not to be shadowed by inheritance
  virtual Stmt Mutate_(const LetStmt* op, const Stmt& s);
  virtual Stmt Mutate_(const AttrStmt* op, const Stmt& s);
60
  virtual Stmt Mutate_(const IfThenElse* op, const Stmt& s);
61 62 63 64
  virtual Stmt Mutate_(const For* op, const Stmt& s);
  virtual Stmt Mutate_(const Allocate* op, const Stmt& s);
  virtual Stmt Mutate_(const Store* op, const Stmt& s);
  virtual Stmt Mutate_(const Free* op, const Stmt& s);
65 66 67
  virtual Stmt Mutate_(const AssertStmt* op, const Stmt& s);
  virtual Stmt Mutate_(const ProducerConsumer* op, const Stmt& s);
  virtual Stmt Mutate_(const Provide* op, const Stmt& s);
68
  virtual Stmt Mutate_(const Realize* op, const Stmt& s);
69
  virtual Stmt Mutate_(const Prefetch* op, const Stmt& s);
70
  virtual Stmt Mutate_(const Block* op, const Stmt& s);
71
  virtual Stmt Mutate_(const Evaluate* op, const Stmt& s);
72

73
  virtual Expr Mutate_(const Variable* op, const Expr& e);
74
  virtual Expr Mutate_(const Load* op, const Expr& e);
75
  virtual Expr Mutate_(const Let* op, const Expr& e);
76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101
  virtual Expr Mutate_(const Call* op, const Expr& e);
  virtual Expr Mutate_(const Add* op, const Expr& e);
  virtual Expr Mutate_(const Sub* op, const Expr& e);
  virtual Expr Mutate_(const Mul* op, const Expr& e);
  virtual Expr Mutate_(const Div* op, const Expr& e);
  virtual Expr Mutate_(const Mod* op, const Expr& e);
  virtual Expr Mutate_(const Min* op, const Expr& e);
  virtual Expr Mutate_(const Max* op, const Expr& e);
  virtual Expr Mutate_(const EQ* op, const Expr& e);
  virtual Expr Mutate_(const NE* op, const Expr& e);
  virtual Expr Mutate_(const LT* op, const Expr& e);
  virtual Expr Mutate_(const LE* op, const Expr& e);
  virtual Expr Mutate_(const GT* op, const Expr& e);
  virtual Expr Mutate_(const GE* op, const Expr& e);
  virtual Expr Mutate_(const And* op, const Expr& e);
  virtual Expr Mutate_(const Or* op, const Expr& e);
  virtual Expr Mutate_(const Reduce* op, const Expr& e);
  virtual Expr Mutate_(const Cast* op, const Expr& e);
  virtual Expr Mutate_(const Not* op, const Expr& e);
  virtual Expr Mutate_(const Select* op, const Expr& e);
  virtual Expr Mutate_(const Ramp* op, const Expr& e);
  virtual Expr Mutate_(const Broadcast* op, const Expr& e);
  virtual Expr Mutate_(const IntImm* op, const Expr& e);
  virtual Expr Mutate_(const UIntImm* op, const Expr& e);
  virtual Expr Mutate_(const FloatImm* op, const Expr& e);
  virtual Expr Mutate_(const StringImm* op, const Expr& e);
102
  virtual Expr Mutate_(const Shuffle* op, const Expr& e);
103 104
};

105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123
/*!
 * \brief recursively visit the ir in post DFS order node, and transform it
 *
 * \param node The ir to be transformed.
 * \param preorder The function called in before recursive mutation
 *          If preorder returns None, then the transform will proceed to recursive call.
 *          If preorder returns a not None Stmt/Expr, the transformer will simply return it and
 *          won't do further recursion.
 * \param postorder The function called after recursive mutation.
 *          The recursive mutation result is passed to postorder for further mutation.
 * \param only_enable List of StringImm.
 *          If it is empty, all IRNode will call preorder/postorder
 *          If it is not empty, preorder/postorder will only be called
 *          when the IRNode's type key is in the list.
 */
Stmt IRTransform(const Stmt& node,
                 const runtime::PackedFunc& preorder,
                 const runtime::PackedFunc& postorder,
                 const Array<Expr>& only_enable = {});
124 125 126
}  // namespace ir
}  // namespace tvm
#endif  // TVM_IR_MUTATOR_H_