/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file tvm/relay/expr_functor.h
 * \brief A more powerful visitor which enables defining arbitrary function
 * signatures with type based dispatch on first argument.
 */
#ifndef TVM_RELAY_EXPR_FUNCTOR_H_
#define TVM_RELAY_EXPR_FUNCTOR_H_

#include <tvm/node/ir_functor.h>
#include <functional>
#include <string>
#include <unordered_map>
#include <utility>
#include "./expr.h"
#include "./adt.h"
#include "./op.h"
#include "./error.h"

namespace tvm {
namespace relay {

/*!
 * \brief A dynamic functor that dispatches on the node type of its first
 *  Expr argument. You can use this as a more powerful visitor, since it
 *  lets you choose the signature of the Visit functions.
 *
 * \sa tvm/ir_functor.h
 *
 * \tparam FType The function signature.
 *  This type is only defined for FType with function signature
 *  R(const Expr&, Args...).
 */
template <typename FType>
class ExprFunctor;

// Default body for the overridable VisitExpr_ functions declared in
// ExprFunctor: it forwards to VisitExprDefault_. The expansion refers to the
// names `op`, `args` and `Args`, so it may only follow a declaration whose
// parameters are spelled exactly `op` and `args`.
#define EXPR_FUNCTOR_DEFAULT \
  { return VisitExprDefault_(op, std::forward<Args>(args)...); }

// Adds to the local `vtable` an entry for node type OP. The entry downcasts
// the generic ObjectRef to `const OP*` and calls the matching VisitExpr_
// overload on `self`. It must be expanded where `vtable`, `TSelf` and `Args`
// are in scope (ExprFunctor::InitVTable).
#define RELAY_EXPR_FUNCTOR_DISPATCH(OP)                             \
  vtable.template set_dispatch<OP>(                                 \
      [](const ObjectRef& n, TSelf* self, Args... args) {           \
        return self->VisitExpr_(static_cast<const OP*>(n.get()),    \
                                std::forward<Args>(args)...);       \
      });

template <typename R, typename... Args>
class ExprFunctor<R(const Expr& n, Args...)> {
 private:
  using TSelf = ExprFunctor<R(const Expr& n, Args...)>;
69
  using FType = tvm::IRFunctor<R(const ObjectRef& n, TSelf* self, Args...)>;
70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91

 public:
  /*! \brief the result type of this functor */
  using result_type = R;
  /*! \brief virtual destructor */
  virtual ~ExprFunctor() {}
  /*!
   * \brief Same as call.
   * \param n The expression node.
   * \param args Additional arguments.
   * \return The result of the call
   */
  R operator()(const Expr& n, Args... args) {
    return VisitExpr(n, std::forward<Args>(args)...);
  }
  /*!
   * \brief The functor call.
   * \param n The expression node.
   * \param args Additional arguments.
   * \return The result of the call
   */
  virtual R VisitExpr(const Expr& n, Args... args) {
雾雨魔理沙 committed
92
    CHECK(n.defined());
93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112
    static FType vtable = InitVTable();
    return vtable(n, this, std::forward<Args>(args)...);
  }
  // Functions that can be overriden by subclass
  virtual R VisitExpr_(const ConstantNode* op,
                       Args... args) EXPR_FUNCTOR_DEFAULT;
  virtual R VisitExpr_(const TupleNode* op,
                       Args... args) EXPR_FUNCTOR_DEFAULT;
  virtual R VisitExpr_(const VarNode* op,
                       Args... args) EXPR_FUNCTOR_DEFAULT;
  virtual R VisitExpr_(const GlobalVarNode* op,
                       Args... args) EXPR_FUNCTOR_DEFAULT;
  virtual R VisitExpr_(const FunctionNode* op,
                       Args... args) EXPR_FUNCTOR_DEFAULT;
  virtual R VisitExpr_(const CallNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
  virtual R VisitExpr_(const LetNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
  virtual R VisitExpr_(const IfNode* op,
                       Args... args) EXPR_FUNCTOR_DEFAULT;
  virtual R VisitExpr_(const OpNode* op,
                       Args... args) EXPR_FUNCTOR_DEFAULT;
113
  virtual R VisitExpr_(const TupleGetItemNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
114 115 116
  virtual R VisitExpr_(const RefCreateNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
  virtual R VisitExpr_(const RefReadNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
  virtual R VisitExpr_(const RefWriteNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
117 118
  virtual R VisitExpr_(const ConstructorNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
  virtual R VisitExpr_(const MatchNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
119
  virtual R VisitExprDefault_(const Node* op, Args...) {
120
    LOG(FATAL) << "Do not have a default for " << op->GetTypeKey();
121
    throw;
122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137
  }

 private:
  // initialize the vtable.
  static FType InitVTable() {
    FType vtable;
    // Set dispatch
    RELAY_EXPR_FUNCTOR_DISPATCH(ConstantNode);
    RELAY_EXPR_FUNCTOR_DISPATCH(TupleNode);
    RELAY_EXPR_FUNCTOR_DISPATCH(VarNode);
    RELAY_EXPR_FUNCTOR_DISPATCH(GlobalVarNode);
    RELAY_EXPR_FUNCTOR_DISPATCH(FunctionNode);
    RELAY_EXPR_FUNCTOR_DISPATCH(CallNode);
    RELAY_EXPR_FUNCTOR_DISPATCH(LetNode);
    RELAY_EXPR_FUNCTOR_DISPATCH(IfNode);
    RELAY_EXPR_FUNCTOR_DISPATCH(OpNode);
138
    RELAY_EXPR_FUNCTOR_DISPATCH(TupleGetItemNode);
139 140 141
    RELAY_EXPR_FUNCTOR_DISPATCH(RefCreateNode);
    RELAY_EXPR_FUNCTOR_DISPATCH(RefReadNode);
    RELAY_EXPR_FUNCTOR_DISPATCH(RefWriteNode);
142 143
    RELAY_EXPR_FUNCTOR_DISPATCH(ConstructorNode);
    RELAY_EXPR_FUNCTOR_DISPATCH(MatchNode);
144 145 146 147
    return vtable;
  }
};

/*!
 * \brief A simple visitor wrapper around ExprFunctor.
 *  Recursively visit the content.
151
 *
152 153
 * ExprVisitor treats Expr as dataflow graph,
 * and only visit each Expr node once.
154
 */
155 156
class ExprVisitor
    : public ::tvm::relay::ExprFunctor<void(const Expr& n)> {
157
 public:
158
  void VisitExpr(const Expr& expr) override;
159 160 161 162 163 164 165 166 167
  void VisitExpr_(const VarNode* op) override;
  void VisitExpr_(const GlobalVarNode* op) override;
  void VisitExpr_(const ConstantNode* op) override;
  void VisitExpr_(const TupleNode* op) override;
  void VisitExpr_(const FunctionNode* op) override;
  void VisitExpr_(const CallNode* op) override;
  void VisitExpr_(const LetNode* op) override;
  void VisitExpr_(const IfNode* op) override;
  void VisitExpr_(const OpNode* op) override;
168
  void VisitExpr_(const TupleGetItemNode* op) override;
169 170 171
  void VisitExpr_(const RefCreateNode* op) override;
  void VisitExpr_(const RefReadNode* op) override;
  void VisitExpr_(const RefWriteNode* op) override;
172 173
  void VisitExpr_(const ConstructorNode* op) override;
  void VisitExpr_(const MatchNode* op) override;
174
  virtual void VisitType(const Type& t);
175 176
  virtual void VisitClause(const Clause& c);
  virtual void VisitPattern(const Pattern& c);
177

178 179 180
 protected:
  // Internal visiting counter
  std::unordered_map<const Node*, size_t> visit_counter_;
181 182
};

/*!
 * \brief A wrapper around ExprFunctor which functionally updates the AST.
 *
 * ExprMutator treats Expr as dataflow graph, and only Mutate each Expr once.
 * The mutated results are memoized in a map and reused so that
 * local transformation on the dataflow preserves the graph structure.
 */
190
class ExprMutator
191
    : public ::tvm::relay::ExprFunctor<Expr(const Expr&)> {
192
 public:
193 194 195 196 197 198 199 200
  /*!
   * \brief Mutate is alias for VisitExpr
   * \return expr.
   */
  Expr Mutate(const Expr& expr) {
    return this->VisitExpr(expr);
  }
  Expr VisitExpr(const Expr& expr) override;
201 202 203 204 205 206 207 208 209
  Expr VisitExpr_(const VarNode* op) override;
  Expr VisitExpr_(const ConstantNode* op) override;
  Expr VisitExpr_(const GlobalVarNode* op) override;
  Expr VisitExpr_(const OpNode* op) override;
  Expr VisitExpr_(const TupleNode* op) override;
  Expr VisitExpr_(const FunctionNode* op) override;
  Expr VisitExpr_(const CallNode* call_node) override;
  Expr VisitExpr_(const LetNode* op) override;
  Expr VisitExpr_(const IfNode* op) override;
210
  Expr VisitExpr_(const TupleGetItemNode* op) override;
211 212 213
  Expr VisitExpr_(const RefCreateNode* op) override;
  Expr VisitExpr_(const RefReadNode* op) override;
  Expr VisitExpr_(const RefWriteNode* op) override;
214 215 216
  Expr VisitExpr_(const ConstructorNode* op) override;
  Expr VisitExpr_(const MatchNode* op) override;

217 218
  /*!
   * \brief Used to visit the types inside of expressions.
219
   *
220 221 222
   * Can be overloaded to transform the types in arbitrary
   * ways, one way would be to define a sub-class of type
   * visitor for types which transform them appropriately.
223
   */
224
  virtual Type VisitType(const Type& t);
225 226
  virtual Clause VisitClause(const Clause& c);
  virtual Pattern VisitPattern(const Pattern& c);
227

228
 protected:
229
  /*! \brief Internal map used for memoization. */
230
  std::unordered_map<Expr, Expr, NodeHash, NodeEqual> memo_;
231 232
};

/*!
 * \brief recursively visit the ir in post DFS order node, apply fvisit
 * Each node is guaranteed to be visited only once.
 * \param node The ir to be visited.
 * \param fvisit The visitor function to be applied.
 */
239
void PostOrderVisit(const Expr& node, std::function<void(const Expr&)> fvisit);
}  // namespace relay
}  // namespace tvm
#endif  // TVM_RELAY_EXPR_FUNCTOR_H_