/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ /*! * \file attr_functor.h * \brief A way to define arbitrary function signature * with dispatch on common attributes. * * Common attributes include: * - int, float, str constants * - array of attributes * - map of attributes */ #ifndef TVM_IR_ATTR_FUNCTOR_H_ #define TVM_IR_ATTR_FUNCTOR_H_ #include <tvm/node/functor.h> #include <tvm/tir/expr.h> #include <utility> namespace tvm { template <typename FType> class AttrFunctor; #define ATTR_FUNCTOR_DEFAULT \ { return VisitAttrDefault_(op, std::forward<Args>(args)...); } #define ATTR_FUNCTOR_DISPATCH(OP) \ vtable.template set_dispatch<OP>( \ [](const ObjectRef& n, TSelf* self, Args... args) { \ return self->VisitAttr_(static_cast<const OP*>(n.get()), \ std::forward<Args>(args)...); \ }); \ // A functor for common attribute information. template <typename R, typename... Args> class AttrFunctor<R(const ObjectRef& n, Args...)> { private: using TSelf = AttrFunctor<R(const ObjectRef& n, Args...)>; using FType = tvm::NodeFunctor<R(const ObjectRef& n, TSelf* self, Args...)>; public: /*! \brief the result type of this functor */ using result_type = R; /*! \brief virtual destructor */ virtual ~AttrFunctor() {} /*! * \brief The functor call. * \param n The expression node. * \param args Additional arguments. * \return The result of the call */ virtual R VisitAttr(const ObjectRef& n, Args... args) { static FType vtable = InitVTable(); if (vtable.can_dispatch(n)) { return vtable(n, this, std::forward<Args>(args)...); } else { return VisitAttrDefault_(n.get(), std::forward<Args>(args)...); } } virtual R VisitAttrDefault_(const Object* node, Args... args) = 0; virtual R VisitAttr_(const ArrayNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const StrMapNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const tir::IntImmNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const tir::FloatImmNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const tir::StringImmNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; // deep comparison of symbolic integer expressions. virtual R VisitAttr_(const tir::VarNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const tir::SizeVarNode* op, Args... args) { return VisitAttr_(static_cast<const tir::VarNode*>(op), std::forward<Args>(args)...); } virtual R VisitAttr_(const tir::AddNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const tir::SubNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const tir::MulNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const tir::DivNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const tir::ModNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const tir::FloorDivNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const tir::FloorModNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const tir::MinNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const tir::MaxNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const tir::GENode* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const tir::GTNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const tir::LTNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const tir::LENode* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const tir::EQNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const tir::NENode* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const tir::AndNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const tir::OrNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const tir::NotNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const tir::CastNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const tir::CallNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const tir::SelectNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; private: // initialize the vtable. static FType InitVTable() { using namespace tir; FType vtable; // Set dispatch ATTR_FUNCTOR_DISPATCH(StrMapNode); ATTR_FUNCTOR_DISPATCH(ArrayNode); ATTR_FUNCTOR_DISPATCH(IntImmNode); ATTR_FUNCTOR_DISPATCH(FloatImmNode); ATTR_FUNCTOR_DISPATCH(StringImmNode); ATTR_FUNCTOR_DISPATCH(VarNode); ATTR_FUNCTOR_DISPATCH(SizeVarNode); ATTR_FUNCTOR_DISPATCH(AddNode); ATTR_FUNCTOR_DISPATCH(SubNode); ATTR_FUNCTOR_DISPATCH(MulNode); ATTR_FUNCTOR_DISPATCH(DivNode); ATTR_FUNCTOR_DISPATCH(ModNode); ATTR_FUNCTOR_DISPATCH(FloorDivNode); ATTR_FUNCTOR_DISPATCH(FloorModNode); ATTR_FUNCTOR_DISPATCH(MinNode); ATTR_FUNCTOR_DISPATCH(MaxNode); ATTR_FUNCTOR_DISPATCH(GENode); ATTR_FUNCTOR_DISPATCH(GTNode); ATTR_FUNCTOR_DISPATCH(LENode); ATTR_FUNCTOR_DISPATCH(LTNode); ATTR_FUNCTOR_DISPATCH(EQNode); ATTR_FUNCTOR_DISPATCH(NENode); ATTR_FUNCTOR_DISPATCH(AndNode); ATTR_FUNCTOR_DISPATCH(OrNode); ATTR_FUNCTOR_DISPATCH(NotNode); ATTR_FUNCTOR_DISPATCH(CastNode); ATTR_FUNCTOR_DISPATCH(CallNode); ATTR_FUNCTOR_DISPATCH(SelectNode); return vtable; } }; class AttrsEqualHandler : protected AttrFunctor<bool(const ObjectRef&, const ObjectRef&)> { public: /*! * \brief Check if lhs equals rhs * \param lhs The left operand. * \param rhs The right operand. */ bool Equal(const ObjectRef& lhs, const ObjectRef& rhs); protected: bool VisitAttrDefault_(const Object* lhs, const ObjectRef& other) final; bool VisitAttr_(const ArrayNode* lhs, const ObjectRef& other) final; bool VisitAttr_(const StrMapNode* lhs, const ObjectRef& other) final; bool VisitAttr_(const tir::IntImmNode* lhs, const ObjectRef& other) final; bool VisitAttr_(const tir::FloatImmNode* lhs, const ObjectRef& other) final; bool VisitAttr_(const tir::StringImmNode* lhs, const ObjectRef& other) final; bool VisitAttr_(const tir::AddNode* lhs, const ObjectRef& other) final; bool VisitAttr_(const tir::SubNode* lhs, const ObjectRef& other) final; bool VisitAttr_(const tir::MulNode* lhs, const ObjectRef& other) final; bool VisitAttr_(const tir::DivNode* lhs, const ObjectRef& other) final; bool VisitAttr_(const tir::ModNode* lhs, const ObjectRef& other) final; bool VisitAttr_(const tir::FloorDivNode* lhs, const ObjectRef& other) final; bool VisitAttr_(const tir::FloorModNode* lhs, const ObjectRef& other) final; bool VisitAttr_(const tir::MinNode* lhs, const ObjectRef& other) final; bool VisitAttr_(const tir::MaxNode* lhs, const ObjectRef& other) final; bool VisitAttr_(const tir::GENode* lhs, const ObjectRef& other) final; bool VisitAttr_(const tir::GTNode* lhs, const ObjectRef& other) final; bool VisitAttr_(const tir::LTNode* lhs, const ObjectRef& other) final; bool VisitAttr_(const tir::LENode* lhs, const ObjectRef& other) final; bool VisitAttr_(const tir::EQNode* lhs, const ObjectRef& other) final; bool VisitAttr_(const tir::NENode* lhs, const ObjectRef& other) final; bool VisitAttr_(const tir::AndNode* lhs, const ObjectRef& other) final; bool VisitAttr_(const tir::OrNode* lhs, const ObjectRef& other) final; bool VisitAttr_(const tir::NotNode* lhs, const ObjectRef& other) final; bool VisitAttr_(const tir::CastNode* lhs, const ObjectRef& other) final; bool VisitAttr_(const tir::CallNode* lhs, const ObjectRef& other) final; bool VisitAttr_(const tir::SelectNode* lhs, const ObjectRef& other) final; }; class AttrsHashHandler : protected AttrFunctor<size_t(const ObjectRef&)> { public: /*! * \brief Get hash value of node * \param node The node to be hashed. */ size_t Hash(const ObjectRef& node) { if (!node.defined()) return 0; return this->VisitAttr(node); } protected: size_t VisitAttrDefault_(const Object* lhs) final; size_t VisitAttr_(const tir::IntImmNode* lhs) final; size_t VisitAttr_(const tir::FloatImmNode* lhs) final; size_t VisitAttr_(const tir::StringImmNode* lhs) final; size_t VisitAttr_(const ArrayNode* lhs) final; size_t VisitAttr_(const StrMapNode* lhs) final; size_t VisitAttr_(const tir::AddNode* op) final; size_t VisitAttr_(const tir::SubNode* op) final; size_t VisitAttr_(const tir::MulNode* op) final; size_t VisitAttr_(const tir::DivNode* op) final; size_t VisitAttr_(const tir::ModNode* op) final; size_t VisitAttr_(const tir::FloorDivNode* op) final; size_t VisitAttr_(const tir::FloorModNode* op) final; size_t VisitAttr_(const tir::MinNode* op) final; size_t VisitAttr_(const tir::MaxNode* op) final; size_t VisitAttr_(const tir::GENode* op) final; size_t VisitAttr_(const tir::GTNode* op) final; size_t VisitAttr_(const tir::LENode* op) final; size_t VisitAttr_(const tir::LTNode* op) final; size_t VisitAttr_(const tir::EQNode* op) final; size_t VisitAttr_(const tir::NENode* op) final; size_t VisitAttr_(const tir::AndNode* op) final; size_t VisitAttr_(const tir::OrNode* op) final; size_t VisitAttr_(const tir::NotNode* op) final; size_t VisitAttr_(const tir::CastNode* op) final; size_t VisitAttr_(const tir::CallNode* op) final; size_t VisitAttr_(const tir::SelectNode* op) final; /*! * \brief alias of dmlc::HashCombine * \param lhs The first hash value. * \param rhs The second hash value. */ static size_t Combine(size_t lhs, size_t rhs) { return dmlc::HashCombine(lhs, rhs); } }; } // namespace tvm #endif // TVM_IR_ATTR_FUNCTOR_H_