/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ /*! * \file tvm/relay/expr_functor.h * \brief A more powerful visitor which enables defining arbitrary function * signatures with type based dispatch on first argument. */ #ifndef TVM_RELAY_EXPR_FUNCTOR_H_ #define TVM_RELAY_EXPR_FUNCTOR_H_ #include <tvm/node/functor.h> #include <tvm/ir/error.h> #include <tvm/relay/expr.h> #include <tvm/relay/function.h> #include <tvm/relay/adt.h> #include <tvm/relay/op.h> #include <string> #include <utility> #include <unordered_map> namespace tvm { namespace relay { /*! * \brief A dynamical functor that dispatches on in the first Expr argument. * You can use this as a more powerful Visitor, since it allows you to * define function signatures of Visit Function. * * \sa tvm/ir_functor.h * * \tparam FType function signiture * This type is only defined for FType with function signature R(const Expr&, * Args...) */ template <typename FType> class ExprFunctor; // functions to be overriden. #define EXPR_FUNCTOR_DEFAULT \ { return VisitExprDefault_(op, std::forward<Args>(args)...); } #define RELAY_EXPR_FUNCTOR_DISPATCH(OP) \ vtable.template set_dispatch<OP>( \ [](const ObjectRef& n, TSelf* self, Args... args) { \ return self->VisitExpr_(static_cast<const OP*>(n.get()), \ std::forward<Args>(args)...); \ }); template <typename R, typename... Args> class ExprFunctor<R(const Expr& n, Args...)> { private: using TSelf = ExprFunctor<R(const Expr& n, Args...)>; using FType = tvm::NodeFunctor<R(const ObjectRef& n, TSelf* self, Args...)>; public: /*! \brief the result type of this functor */ using result_type = R; /*! \brief virtual destructor */ virtual ~ExprFunctor() {} /*! * \brief Same as call. * \param n The expression node. * \param args Additional arguments. * \return The result of the call */ R operator()(const Expr& n, Args... args) { return VisitExpr(n, std::forward<Args>(args)...); } /*! * \brief The functor call. * \param n The expression node. * \param args Additional arguments. * \return The result of the call */ virtual R VisitExpr(const Expr& n, Args... args) { CHECK(n.defined()); static FType vtable = InitVTable(); return vtable(n, this, std::forward<Args>(args)...); } // Functions that can be overriden by subclass virtual R VisitExpr_(const ConstantNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const TupleNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const VarNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const GlobalVarNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const FunctionNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const CallNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const LetNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const IfNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const OpNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const TupleGetItemNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const RefCreateNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const RefReadNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const RefWriteNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const ConstructorNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const MatchNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExprDefault_(const Object* op, Args...) { LOG(FATAL) << "Do not have a default for " << op->GetTypeKey(); throw; } private: // initialize the vtable. static FType InitVTable() { FType vtable; // Set dispatch RELAY_EXPR_FUNCTOR_DISPATCH(ConstantNode); RELAY_EXPR_FUNCTOR_DISPATCH(TupleNode); RELAY_EXPR_FUNCTOR_DISPATCH(VarNode); RELAY_EXPR_FUNCTOR_DISPATCH(GlobalVarNode); RELAY_EXPR_FUNCTOR_DISPATCH(FunctionNode); RELAY_EXPR_FUNCTOR_DISPATCH(CallNode); RELAY_EXPR_FUNCTOR_DISPATCH(LetNode); RELAY_EXPR_FUNCTOR_DISPATCH(IfNode); RELAY_EXPR_FUNCTOR_DISPATCH(OpNode); RELAY_EXPR_FUNCTOR_DISPATCH(TupleGetItemNode); RELAY_EXPR_FUNCTOR_DISPATCH(RefCreateNode); RELAY_EXPR_FUNCTOR_DISPATCH(RefReadNode); RELAY_EXPR_FUNCTOR_DISPATCH(RefWriteNode); RELAY_EXPR_FUNCTOR_DISPATCH(ConstructorNode); RELAY_EXPR_FUNCTOR_DISPATCH(MatchNode); return vtable; } }; /*! * \brief A simple visitor wrapper around ExprFunctor. * Recursively visit the content. * * ExprVisitor treats Expr as dataflow graph, * and only visit each Expr node once. */ class ExprVisitor : public ::tvm::relay::ExprFunctor<void(const Expr& n)> { public: void VisitExpr(const Expr& expr) override; void VisitExpr_(const VarNode* op) override; void VisitExpr_(const GlobalVarNode* op) override; void VisitExpr_(const ConstantNode* op) override; void VisitExpr_(const TupleNode* op) override; void VisitExpr_(const FunctionNode* op) override; void VisitExpr_(const CallNode* op) override; void VisitExpr_(const LetNode* op) override; void VisitExpr_(const IfNode* op) override; void VisitExpr_(const OpNode* op) override; void VisitExpr_(const TupleGetItemNode* op) override; void VisitExpr_(const RefCreateNode* op) override; void VisitExpr_(const RefReadNode* op) override; void VisitExpr_(const RefWriteNode* op) override; void VisitExpr_(const ConstructorNode* op) override; void VisitExpr_(const MatchNode* op) override; virtual void VisitType(const Type& t); virtual void VisitClause(const Clause& c); virtual void VisitPattern(const Pattern& c); protected: // Internal visiting counter std::unordered_map<const Object*, size_t> visit_counter_; }; /*! * \brief A wrapper around ExprFunctor which functionally updates the AST. * * ExprMutator treats Expr as dataflow graph, and only Mutate each Expr once. * The mutated results are memoized in a map and reused so that * local transformation on the dataflow preserves the graph structure. */ class ExprMutator : public ::tvm::relay::ExprFunctor<Expr(const Expr&)> { public: /*! * \brief Mutate is alias for VisitExpr * \return expr. */ Expr Mutate(const Expr& expr) { return this->VisitExpr(expr); } Expr VisitExpr(const Expr& expr) override; Expr VisitExpr_(const VarNode* op) override; Expr VisitExpr_(const ConstantNode* op) override; Expr VisitExpr_(const GlobalVarNode* op) override; Expr VisitExpr_(const OpNode* op) override; Expr VisitExpr_(const TupleNode* op) override; Expr VisitExpr_(const FunctionNode* op) override; Expr VisitExpr_(const CallNode* call_node) override; Expr VisitExpr_(const LetNode* op) override; Expr VisitExpr_(const IfNode* op) override; Expr VisitExpr_(const TupleGetItemNode* op) override; Expr VisitExpr_(const RefCreateNode* op) override; Expr VisitExpr_(const RefReadNode* op) override; Expr VisitExpr_(const RefWriteNode* op) override; Expr VisitExpr_(const ConstructorNode* op) override; Expr VisitExpr_(const MatchNode* op) override; /*! * \brief Used to visit the types inside of expressions. * * Can be overloaded to transform the types in arbitrary * ways, one way would be to define a sub-class of type * visitor for types which transform them appropriately. */ virtual Type VisitType(const Type& t); virtual Clause VisitClause(const Clause& c); virtual Pattern VisitPattern(const Pattern& c); protected: /*! \brief Internal map used for memoization. */ std::unordered_map<Expr, Expr, ObjectHash, ObjectEqual> memo_; }; /*! * \brief recursively visit the ir in post DFS order node, apply fvisit * Each node is guaranteed to be visited only once. * \param node The ir to be visited. * \param fvisit The visitor function to be applied. */ void PostOrderVisit(const Expr& node, std::function<void(const Expr&)> fvisit); } // namespace relay } // namespace tvm #endif // TVM_RELAY_EXPR_FUNCTOR_H_