/*! * Copyright (c) 2018 by Contributors * \file tvm/relay/expr_functor.h * \brief A more powerful visitor which enables defining arbitrary function * signatures with type based dispatch on first argument. */ #ifndef TVM_RELAY_EXPR_FUNCTOR_H_ #define TVM_RELAY_EXPR_FUNCTOR_H_ #include <tvm/node/ir_functor.h> #include <string> #include "./expr.h" #include "./op.h" #include "./error.h" namespace tvm { namespace relay { /*! * \brief A dynamical functor that dispatches on in the first Expr argument. * You can use this as a more powerful Visitor, since it allows you to * define function signatures of Visit Function. * * \sa tvm/ir_functor.h * * \tparam FType function signiture * This type is only defined for FType with function signature R(const Expr&, * Args...) */ template <typename FType> class ExprFunctor; // functions to be overriden. #define EXPR_FUNCTOR_DEFAULT \ { return VisitExprDefault_(op, std::forward<Args>(args)...); } #define RELAY_EXPR_FUNCTOR_DISPATCH(OP) \ vtable.template set_dispatch<OP>( \ [](const NodeRef& n, TSelf* self, Args... args) { \ return self->VisitExpr_(static_cast<const OP*>(n.node_.get()), \ std::forward<Args>(args)...); \ }); template <typename R, typename... Args> class ExprFunctor<R(const Expr& n, Args...)> { private: using TSelf = ExprFunctor<R(const Expr& n, Args...)>; using FType = tvm::IRFunctor<R(const NodeRef& n, TSelf* self, Args...)>; public: /*! \brief the result type of this functor */ using result_type = R; /*! \brief virtual destructor */ virtual ~ExprFunctor() {} /*! * \brief Same as call. * \param n The expression node. * \param args Additional arguments. * \return The result of the call */ R operator()(const Expr& n, Args... args) { return VisitExpr(n, std::forward<Args>(args)...); } /*! * \brief The functor call. * \param n The expression node. * \param args Additional arguments. * \return The result of the call */ virtual R VisitExpr(const Expr& n, Args... args) { static FType vtable = InitVTable(); return vtable(n, this, std::forward<Args>(args)...); } // Functions that can be overriden by subclass virtual R VisitExpr_(const ConstantNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const TupleNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const VarNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const GlobalVarNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const ParamNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const FunctionNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const CallNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const LetNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const IfNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const OpNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExprDefault_(const Node* op, Args...) { throw Error(std::string("Do not have a default for ") + op->type_key()); } private: // initialize the vtable. static FType InitVTable() { FType vtable; // Set dispatch RELAY_EXPR_FUNCTOR_DISPATCH(ConstantNode); RELAY_EXPR_FUNCTOR_DISPATCH(TupleNode); RELAY_EXPR_FUNCTOR_DISPATCH(VarNode); RELAY_EXPR_FUNCTOR_DISPATCH(GlobalVarNode); RELAY_EXPR_FUNCTOR_DISPATCH(ParamNode); RELAY_EXPR_FUNCTOR_DISPATCH(FunctionNode); RELAY_EXPR_FUNCTOR_DISPATCH(CallNode); RELAY_EXPR_FUNCTOR_DISPATCH(LetNode); RELAY_EXPR_FUNCTOR_DISPATCH(IfNode); RELAY_EXPR_FUNCTOR_DISPATCH(OpNode); return vtable; } }; /*! \brief A simple visitor wrapper around ExprFunctor. * * Exposes two visitors with default traversal strategies, one * which doesn't compute a result but can mutate internal state, * and another which functionally builds a new Expr. */ class ExprVisitor : public ::tvm::relay::ExprFunctor<void(const Expr& n)> { public: void VisitExpr_(const VarNode* op) override; void VisitExpr_(const GlobalVarNode* op) override; void VisitExpr_(const ConstantNode* op) override; void VisitExpr_(const TupleNode* op) override; void VisitExpr_(const ParamNode* op) override; void VisitExpr_(const FunctionNode* op) override; void VisitExpr_(const CallNode* op) override; void VisitExpr_(const LetNode* op) override; void VisitExpr_(const IfNode* op) override; void VisitExpr_(const OpNode* op) override; virtual void VisitType(const Type& t); }; /*! \brief A wrapper around ExprFunctor which functionally updates the AST. * * ExprMutator uses memoization and self return in order to amortize * the cost of using functional updates. */ class ExprMutator : public ::tvm::relay::ExprFunctor<Expr(const Expr&)> { public: Expr Mutate(const Expr& expr); Expr VisitExpr_(const VarNode* op) override; Expr VisitExpr_(const ConstantNode* op) override; Expr VisitExpr_(const GlobalVarNode* op) override; Expr VisitExpr_(const OpNode* op) override; Expr VisitExpr_(const TupleNode* op) override; Expr VisitExpr_(const ParamNode* op) override; Expr VisitExpr_(const FunctionNode* op) override; Expr VisitExpr_(const CallNode* call_node) override; Expr VisitExpr_(const LetNode* op) override; Expr VisitExpr_(const IfNode* op) override; /*! \brief Used to visit the types inside of expressions. * * Can be overloaded to transform the types in arbitrary * ways, one way would be to define a sub-class of type * visitor for types which transform them appropriately. */ virtual Type VisitType(const Type& t); private: /*! \brief Internal map used for memoization. */ std::unordered_map<Expr, Expr, NodeHash, NodeEqual> memo_; }; } // namespace relay } // namespace tvm #endif // TVM_RELAY_EXPR_FUNCTOR_H_