/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file tvm/ir_mutator.h
 * \brief Defines general IRMutation pass
 */
#ifndef TVM_IR_MUTATOR_H_
#define TVM_IR_MUTATOR_H_

#include <unordered_map>
#include <utility>
#include "expr.h"
#include "ir.h"
#include "tvm/node/ir_functor.h"

namespace tvm {
namespace ir {
/*!
 * \brief a base class for mutator to iterative mutate the IR
 *
38 39
 *  This IRMutator is implemented via Visitor Pattern.
 *  Also you can implement via IRFunctor.
40
 *  This enables easy extensions of possible new Node.
41 42 43 44 45 46
 *  It also makes changing return types easier.
 *
 * \note If you want to return a different type other than Expr and Stmt,
 *       Simply following the same pattern as IRMutator and create a seperate class.
 * \sa IRFunctor
 */
47
class TVM_DLL IRMutator {
48 49 50 51 52
 public:
  /*!
   * \brief mutate expression
   * \return the mutated expr
   */
tqchen committed
53
  virtual Expr Mutate(Expr expr) {
54 55 56 57 58 59 60
    static const FMutateExpr& f = vtable_expr();
    return f(expr, expr, this);
  }
  /*!
   * \brief mutate expression
   * \return the mutated stmt
   */
tqchen committed
61
  virtual Stmt Mutate(Stmt stmt) {
62 63 64 65 66 67
    static const FMutateStmt& f = vtable_stmt();
    return f(stmt, stmt, this);
  }
  /*! \brief destructor */
  virtual ~IRMutator() {}
  /*! \brief functor type of expr mutation */
68
  using FMutateExpr = IRFunctor<Expr(const NodeRef&, const Expr&, IRMutator*)>;
69
  /*! \brief functor type of stmt mutation */
70
  using FMutateStmt = IRFunctor<Stmt(const NodeRef&, const Stmt&, IRMutator*)>;
71 72 73 74
  /*! \return internal vtable of expr */
  static FMutateExpr& vtable_expr();  // NOLINT(*)
  /*! \return internal stmt of expr */
  static FMutateStmt& vtable_stmt();  // NOLINT(*)
75 76 77 78
  // Set of overloadable functions
  // The underscore allows Mutate not to be shadowed by inheritance
  virtual Stmt Mutate_(const LetStmt* op, const Stmt& s);
  virtual Stmt Mutate_(const AttrStmt* op, const Stmt& s);
79
  virtual Stmt Mutate_(const IfThenElse* op, const Stmt& s);
80 81 82 83
  virtual Stmt Mutate_(const For* op, const Stmt& s);
  virtual Stmt Mutate_(const Allocate* op, const Stmt& s);
  virtual Stmt Mutate_(const Store* op, const Stmt& s);
  virtual Stmt Mutate_(const Free* op, const Stmt& s);
84 85 86
  virtual Stmt Mutate_(const AssertStmt* op, const Stmt& s);
  virtual Stmt Mutate_(const ProducerConsumer* op, const Stmt& s);
  virtual Stmt Mutate_(const Provide* op, const Stmt& s);
87
  virtual Stmt Mutate_(const Realize* op, const Stmt& s);
88
  virtual Stmt Mutate_(const Prefetch* op, const Stmt& s);
89
  virtual Stmt Mutate_(const Block* op, const Stmt& s);
90
  virtual Stmt Mutate_(const Evaluate* op, const Stmt& s);
91

92
  virtual Expr Mutate_(const Variable* op, const Expr& e);
93
  virtual Expr Mutate_(const Load* op, const Expr& e);
94
  virtual Expr Mutate_(const Let* op, const Expr& e);
95 96 97 98 99 100
  virtual Expr Mutate_(const Call* op, const Expr& e);
  virtual Expr Mutate_(const Add* op, const Expr& e);
  virtual Expr Mutate_(const Sub* op, const Expr& e);
  virtual Expr Mutate_(const Mul* op, const Expr& e);
  virtual Expr Mutate_(const Div* op, const Expr& e);
  virtual Expr Mutate_(const Mod* op, const Expr& e);
101 102
  virtual Expr Mutate_(const FloorDiv* op, const Expr& e);
  virtual Expr Mutate_(const FloorMod* op, const Expr& e);
103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122
  virtual Expr Mutate_(const Min* op, const Expr& e);
  virtual Expr Mutate_(const Max* op, const Expr& e);
  virtual Expr Mutate_(const EQ* op, const Expr& e);
  virtual Expr Mutate_(const NE* op, const Expr& e);
  virtual Expr Mutate_(const LT* op, const Expr& e);
  virtual Expr Mutate_(const LE* op, const Expr& e);
  virtual Expr Mutate_(const GT* op, const Expr& e);
  virtual Expr Mutate_(const GE* op, const Expr& e);
  virtual Expr Mutate_(const And* op, const Expr& e);
  virtual Expr Mutate_(const Or* op, const Expr& e);
  virtual Expr Mutate_(const Reduce* op, const Expr& e);
  virtual Expr Mutate_(const Cast* op, const Expr& e);
  virtual Expr Mutate_(const Not* op, const Expr& e);
  virtual Expr Mutate_(const Select* op, const Expr& e);
  virtual Expr Mutate_(const Ramp* op, const Expr& e);
  virtual Expr Mutate_(const Broadcast* op, const Expr& e);
  virtual Expr Mutate_(const IntImm* op, const Expr& e);
  virtual Expr Mutate_(const UIntImm* op, const Expr& e);
  virtual Expr Mutate_(const FloatImm* op, const Expr& e);
  virtual Expr Mutate_(const StringImm* op, const Expr& e);
123
  virtual Expr Mutate_(const Shuffle* op, const Expr& e);
124 125
};

/*!
 * \brief recursively visit the ir in post DFS order node, and transform it
 *
 * \param node The ir to be transformed.
 * \param preorder The function called in before recursive mutation
 *          If preorder returns None, then the transform will proceed to recursive call.
 *          If preorder returns a not None Stmt/Expr, the transformer will simply return it and
 *          won't do further recursion.
 * \param postorder The function called after recursive mutation.
 *          The recursive mutation result is passed to postorder for further mutation.
 * \param only_enable List of StringImm.
 *          If it is empty, all IRNode will call preorder/postorder
 *          If it is not empty, preorder/postorder will only be called
 *          when the IRNode's type key is in the list.
 */
Stmt IRTransform(const Stmt& node,
                 const runtime::PackedFunc& preorder,
                 const runtime::PackedFunc& postorder,
                 const Array<Expr>& only_enable = {});
}  // namespace ir
}  // namespace tvm
#endif  // TVM_IR_MUTATOR_H_