expr_functor.h 8.96 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

20 21 22 23 24 25 26 27
/*!
 * \file tvm/relay/expr_functor.h
 * \brief A more powerful visitor which enables defining arbitrary function
 * signatures with type based dispatch on first argument.
 */
#ifndef TVM_RELAY_EXPR_FUNCTOR_H_
#define TVM_RELAY_EXPR_FUNCTOR_H_

28
#include <tvm/node/functor.h>
29
#include <tvm/ir/error.h>
30 31 32 33
#include <tvm/relay/expr.h>
#include <tvm/relay/function.h>
#include <tvm/relay/adt.h>
#include <tvm/relay/op.h>
34

35
#include <string>
36 37
#include <utility>
#include <unordered_map>
38

39 40 41 42 43 44 45
namespace tvm {
namespace relay {

/*!
 * \brief A dynamic functor that dispatches on the type of its first Expr argument.
 *  You can use this as a more powerful Visitor, since it allows you to
 *  define the function signature of the Visit functions.
 *
 * \sa tvm/ir_functor.h
 *
 * \tparam FType function signature
 *  This type is only defined for FType with function signature R(const Expr&,
 * Args...)
 */
template <typename FType>
class ExprFunctor;

// Helper macros for the functions to be overridden.

// Default body for an overridable VisitExpr_ overload: forwards the node
// pointer and the extra arguments to VisitExprDefault_.
// Expects `op`, `args` and the parameter pack `Args` to be in scope where it
// is expanded.
#define EXPR_FUNCTOR_DEFAULT                                      \
  { return VisitExprDefault_(op, std::forward<Args>(args)...); }

// Registers a dispatch entry for node type OP in `vtable`. The entry
// downcasts the payload of the incoming ObjectRef to `const OP*` and calls the
// matching VisitExpr_ overload on `self`.
// Expects `vtable`, `TSelf` and the parameter pack `Args` to be in scope where
// it is expanded. The expansion already ends with `;`.
#define RELAY_EXPR_FUNCTOR_DISPATCH(OP)                                \
  vtable.template set_dispatch<OP>(                                    \
      [](const ObjectRef& n, TSelf* self, Args... args) {              \
        return self->VisitExpr_(static_cast<const OP*>(n.get()),       \
                                std::forward<Args>(args)...);          \
      });

template <typename R, typename... Args>
class ExprFunctor<R(const Expr& n, Args...)> {
 private:
  using TSelf = ExprFunctor<R(const Expr& n, Args...)>;
71
  using FType = tvm::NodeFunctor<R(const ObjectRef& n, TSelf* self, Args...)>;
72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93

 public:
  /*! \brief the result type of this functor */
  using result_type = R;
  /*! \brief virtual destructor */
  virtual ~ExprFunctor() {}
  /*!
   * \brief Same as call.
   * \param n The expression node.
   * \param args Additional arguments.
   * \return The result of the call
   */
  R operator()(const Expr& n, Args... args) {
    return VisitExpr(n, std::forward<Args>(args)...);
  }
  /*!
   * \brief The functor call.
   * \param n The expression node.
   * \param args Additional arguments.
   * \return The result of the call
   */
  virtual R VisitExpr(const Expr& n, Args... args) {
雾雨魔理沙 committed
94
    CHECK(n.defined());
95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114
    static FType vtable = InitVTable();
    return vtable(n, this, std::forward<Args>(args)...);
  }
  // Functions that can be overriden by subclass
  virtual R VisitExpr_(const ConstantNode* op,
                       Args... args) EXPR_FUNCTOR_DEFAULT;
  virtual R VisitExpr_(const TupleNode* op,
                       Args... args) EXPR_FUNCTOR_DEFAULT;
  virtual R VisitExpr_(const VarNode* op,
                       Args... args) EXPR_FUNCTOR_DEFAULT;
  virtual R VisitExpr_(const GlobalVarNode* op,
                       Args... args) EXPR_FUNCTOR_DEFAULT;
  virtual R VisitExpr_(const FunctionNode* op,
                       Args... args) EXPR_FUNCTOR_DEFAULT;
  virtual R VisitExpr_(const CallNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
  virtual R VisitExpr_(const LetNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
  virtual R VisitExpr_(const IfNode* op,
                       Args... args) EXPR_FUNCTOR_DEFAULT;
  virtual R VisitExpr_(const OpNode* op,
                       Args... args) EXPR_FUNCTOR_DEFAULT;
115
  virtual R VisitExpr_(const TupleGetItemNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
116 117 118
  virtual R VisitExpr_(const RefCreateNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
  virtual R VisitExpr_(const RefReadNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
  virtual R VisitExpr_(const RefWriteNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
119 120
  virtual R VisitExpr_(const ConstructorNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
  virtual R VisitExpr_(const MatchNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
121
  virtual R VisitExprDefault_(const Object* op, Args...) {
122
    LOG(FATAL) << "Do not have a default for " << op->GetTypeKey();
123
    throw;
124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139
  }

 private:
  // initialize the vtable.
  static FType InitVTable() {
    FType vtable;
    // Set dispatch
    RELAY_EXPR_FUNCTOR_DISPATCH(ConstantNode);
    RELAY_EXPR_FUNCTOR_DISPATCH(TupleNode);
    RELAY_EXPR_FUNCTOR_DISPATCH(VarNode);
    RELAY_EXPR_FUNCTOR_DISPATCH(GlobalVarNode);
    RELAY_EXPR_FUNCTOR_DISPATCH(FunctionNode);
    RELAY_EXPR_FUNCTOR_DISPATCH(CallNode);
    RELAY_EXPR_FUNCTOR_DISPATCH(LetNode);
    RELAY_EXPR_FUNCTOR_DISPATCH(IfNode);
    RELAY_EXPR_FUNCTOR_DISPATCH(OpNode);
140
    RELAY_EXPR_FUNCTOR_DISPATCH(TupleGetItemNode);
141 142 143
    RELAY_EXPR_FUNCTOR_DISPATCH(RefCreateNode);
    RELAY_EXPR_FUNCTOR_DISPATCH(RefReadNode);
    RELAY_EXPR_FUNCTOR_DISPATCH(RefWriteNode);
144 145
    RELAY_EXPR_FUNCTOR_DISPATCH(ConstructorNode);
    RELAY_EXPR_FUNCTOR_DISPATCH(MatchNode);
146 147 148 149
    return vtable;
  }
};

150 151 152
/*!
 * \brief A simple visitor wrapper around ExprFunctor.
 *  Recursively visit the content.
153
 *
154 155
 * ExprVisitor treats Expr as dataflow graph,
 * and only visit each Expr node once.
156
 */
157 158
class ExprVisitor
    : public ::tvm::relay::ExprFunctor<void(const Expr& n)> {
159
 public:
160
  void VisitExpr(const Expr& expr) override;
161 162 163 164 165 166 167 168 169
  void VisitExpr_(const VarNode* op) override;
  void VisitExpr_(const GlobalVarNode* op) override;
  void VisitExpr_(const ConstantNode* op) override;
  void VisitExpr_(const TupleNode* op) override;
  void VisitExpr_(const FunctionNode* op) override;
  void VisitExpr_(const CallNode* op) override;
  void VisitExpr_(const LetNode* op) override;
  void VisitExpr_(const IfNode* op) override;
  void VisitExpr_(const OpNode* op) override;
170
  void VisitExpr_(const TupleGetItemNode* op) override;
171 172 173
  void VisitExpr_(const RefCreateNode* op) override;
  void VisitExpr_(const RefReadNode* op) override;
  void VisitExpr_(const RefWriteNode* op) override;
174 175
  void VisitExpr_(const ConstructorNode* op) override;
  void VisitExpr_(const MatchNode* op) override;
176
  virtual void VisitType(const Type& t);
177 178
  virtual void VisitClause(const Clause& c);
  virtual void VisitPattern(const Pattern& c);
179

180 181
 protected:
  // Internal visiting counter
182
  std::unordered_map<const Object*, size_t> visit_counter_;
183 184
};

185 186 187 188 189 190 191
/*!
 * \brief A wrapper around ExprFunctor which functionally updates the AST.
 *
 * ExprMutator treats Expr as dataflow graph, and only Mutate each Expr once.
 * The mutated results are memoized in a map and reused so that
 * local transformation on the dataflow preserves the graph structure.
 */
192
class ExprMutator
193
    : public ::tvm::relay::ExprFunctor<Expr(const Expr&)> {
194
 public:
195 196 197 198 199 200 201 202
  /*!
   * \brief Mutate is alias for VisitExpr
   * \return expr.
   */
  Expr Mutate(const Expr& expr) {
    return this->VisitExpr(expr);
  }
  Expr VisitExpr(const Expr& expr) override;
203 204 205 206 207 208 209 210 211
  Expr VisitExpr_(const VarNode* op) override;
  Expr VisitExpr_(const ConstantNode* op) override;
  Expr VisitExpr_(const GlobalVarNode* op) override;
  Expr VisitExpr_(const OpNode* op) override;
  Expr VisitExpr_(const TupleNode* op) override;
  Expr VisitExpr_(const FunctionNode* op) override;
  Expr VisitExpr_(const CallNode* call_node) override;
  Expr VisitExpr_(const LetNode* op) override;
  Expr VisitExpr_(const IfNode* op) override;
212
  Expr VisitExpr_(const TupleGetItemNode* op) override;
213 214 215
  Expr VisitExpr_(const RefCreateNode* op) override;
  Expr VisitExpr_(const RefReadNode* op) override;
  Expr VisitExpr_(const RefWriteNode* op) override;
216 217 218
  Expr VisitExpr_(const ConstructorNode* op) override;
  Expr VisitExpr_(const MatchNode* op) override;

219 220
  /*!
   * \brief Used to visit the types inside of expressions.
221
   *
222 223 224
   * Can be overloaded to transform the types in arbitrary
   * ways, one way would be to define a sub-class of type
   * visitor for types which transform them appropriately.
225
   */
226
  virtual Type VisitType(const Type& t);
227 228
  virtual Clause VisitClause(const Clause& c);
  virtual Pattern VisitPattern(const Pattern& c);
229

230
 protected:
231
  /*! \brief Internal map used for memoization. */
232
  std::unordered_map<Expr, Expr, ObjectHash, ObjectEqual> memo_;
233 234
};

ziheng committed
235 236 237 238 239 240
/*!
 * \brief recursively visit the ir in post DFS order node, apply fvisit
 * Each node is guaranteed to be visited only once.
 * \param node The ir to be visited.
 * \param fvisit The visitor function to be applied.
 */
241
void PostOrderVisit(const Expr& node, std::function<void(const Expr&)> fvisit);
ziheng committed
242

243 244 245
}  // namespace relay
}  // namespace tvm
#endif  // TVM_RELAY_EXPR_FUNCTOR_H_