/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ /*! * \file attr_functor.h * \brief A way to define arbitrary function signature * with dispatch on common attributes. * * Common attributes include: * - int, float, str constants * - array of attributes * - map of attributes */ #ifndef TVM_IR_ATTR_FUNCTOR_H_ #define TVM_IR_ATTR_FUNCTOR_H_ #include <tvm/node/functor.h> #include <tvm/tir/expr.h> #include <utility> namespace tvm { template <typename FType> class AttrFunctor; #define ATTR_FUNCTOR_DEFAULT \ { return VisitAttrDefault_(op, std::forward<Args>(args)...); } #define ATTR_FUNCTOR_DISPATCH(OP) \ vtable.template set_dispatch<OP>( \ [](const ObjectRef& n, TSelf* self, Args... args) { \ return self->VisitAttr_(static_cast<const OP*>(n.get()), \ std::forward<Args>(args)...); \ }); \ // A functor for common attribute information. template <typename R, typename... Args> class AttrFunctor<R(const ObjectRef& n, Args...)> { private: using TSelf = AttrFunctor<R(const ObjectRef& n, Args...)>; using FType = tvm::NodeFunctor<R(const ObjectRef& n, TSelf* self, Args...)>; public: /*! \brief the result type of this functor */ using result_type = R; /*! \brief virtual destructor */ virtual ~AttrFunctor() {} /*! * \brief The functor call. * \param n The expression node. * \param args Additional arguments. * \return The result of the call */ virtual R VisitAttr(const ObjectRef& n, Args... args) { static FType vtable = InitVTable(); if (vtable.can_dispatch(n)) { return vtable(n, this, std::forward<Args>(args)...); } else { return VisitAttrDefault_(n.get(), std::forward<Args>(args)...); } } virtual R VisitAttrDefault_(const Object* node, Args... args) = 0; virtual R VisitAttr_(const ArrayNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const StrMapNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const tir::IntImmNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const tir::FloatImmNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const tir::StringImmNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; // deep comparison of symbolic integer expressions. virtual R VisitAttr_(const tir::VarNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const tir::SizeVarNode* op, Args... args) { return VisitAttr_(static_cast<const tir::VarNode*>(op), std::forward<Args>(args)...); } virtual R VisitAttr_(const tir::AddNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const tir::SubNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const tir::MulNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const tir::DivNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const tir::ModNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const tir::FloorDivNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const tir::FloorModNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const tir::MinNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const tir::MaxNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const tir::GENode* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const tir::GTNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const tir::LTNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const tir::LENode* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const tir::EQNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const tir::NENode* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const tir::AndNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const tir::OrNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const tir::NotNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const tir::CastNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const tir::CallNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const tir::SelectNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; private: // initialize the vtable. static FType InitVTable() { using namespace tir; FType vtable; // Set dispatch ATTR_FUNCTOR_DISPATCH(StrMapNode); ATTR_FUNCTOR_DISPATCH(ArrayNode); ATTR_FUNCTOR_DISPATCH(IntImmNode); ATTR_FUNCTOR_DISPATCH(FloatImmNode); ATTR_FUNCTOR_DISPATCH(StringImmNode); ATTR_FUNCTOR_DISPATCH(VarNode); ATTR_FUNCTOR_DISPATCH(SizeVarNode); ATTR_FUNCTOR_DISPATCH(AddNode); ATTR_FUNCTOR_DISPATCH(SubNode); ATTR_FUNCTOR_DISPATCH(MulNode); ATTR_FUNCTOR_DISPATCH(DivNode); ATTR_FUNCTOR_DISPATCH(ModNode); ATTR_FUNCTOR_DISPATCH(FloorDivNode); ATTR_FUNCTOR_DISPATCH(FloorModNode); ATTR_FUNCTOR_DISPATCH(MinNode); ATTR_FUNCTOR_DISPATCH(MaxNode); ATTR_FUNCTOR_DISPATCH(GENode); ATTR_FUNCTOR_DISPATCH(GTNode); ATTR_FUNCTOR_DISPATCH(LENode); ATTR_FUNCTOR_DISPATCH(LTNode); ATTR_FUNCTOR_DISPATCH(EQNode); ATTR_FUNCTOR_DISPATCH(NENode); ATTR_FUNCTOR_DISPATCH(AndNode); ATTR_FUNCTOR_DISPATCH(OrNode); ATTR_FUNCTOR_DISPATCH(NotNode); ATTR_FUNCTOR_DISPATCH(CastNode); ATTR_FUNCTOR_DISPATCH(CallNode); ATTR_FUNCTOR_DISPATCH(SelectNode); return vtable; } }; } // namespace tvm #endif // TVM_IR_ATTR_FUNCTOR_H_