/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ /*! * \file tvm/node/node.h * \brief Node system data structure. */ #ifndef TVM_NODE_NODE_H_ #define TVM_NODE_NODE_H_ #include <dmlc/logging.h> #include <tvm/runtime/c_runtime_api.h> #include <tvm/runtime/node_base.h> #include <string> #include <vector> #include <utility> #include <type_traits> namespace tvm { // forward declaration class DataType; class Node; class NodeRef; namespace runtime { // forward declaration class NDArray; // forward declaration class Object; } // namespace runtime /*! * \brief Visitor class to each node content. * The content is going to be called for each field. */ class TVM_DLL AttrVisitor { public: //! \cond Doxygen_Suppress virtual ~AttrVisitor() = default; virtual void Visit(const char* key, double* value) = 0; virtual void Visit(const char* key, int64_t* value) = 0; virtual void Visit(const char* key, uint64_t* value) = 0; virtual void Visit(const char* key, int* value) = 0; virtual void Visit(const char* key, bool* value) = 0; virtual void Visit(const char* key, std::string* value) = 0; virtual void Visit(const char* key, void** value) = 0; virtual void Visit(const char* key, DataType* value) = 0; virtual void Visit(const char* key, NodeRef* value) = 0; virtual void Visit(const char* key, runtime::NDArray* value) = 0; virtual void Visit(const char* key, runtime::Object* value) = 0; template<typename ENum, typename = typename std::enable_if<std::is_enum<ENum>::value>::type> void Visit(const char* key, ENum* ptr) { static_assert(std::is_same<int, typename std::underlying_type<ENum>::type>::value, "declare enum to be enum int to use visitor"); this->Visit(key, reinterpret_cast<int*>(ptr)); } //! \endcond }; /*! * \brief base class of node container in DSL AST. */ class TVM_DLL Node : public NodeBase { public: /*! \brief virtual destructor */ virtual ~Node() {} /*! \return The unique type key of the node */ virtual const char* type_key() const = 0; /*! * \brief Apply visitor to each field of the Node * Visitor could mutate the content of the node. * override if Node contains attribute fields. * \param visitor The visitor */ virtual void VisitAttrs(AttrVisitor* visitor) {} /*! \return the type index of the node */ virtual uint32_t type_index() const = 0; /*! * \brief Whether this node derives from node with type_index=tid. * Implemented by TVM_DECLARE_NODE_TYPE_INFO * * \param tid The type index. * \return the check result. */ virtual bool _DerivedFrom(uint32_t tid) const; /*! * \brief get a runtime unique type index given a type key * \param type_key Type key of a type. * \return the corresponding type index. */ static uint32_t TypeKey2Index(const char* type_key); /*! * \brief get type key from type index. * \param index The type index * \return the corresponding type key. */ static const char* TypeIndex2Key(uint32_t index); /*! * \return whether the type is derived from */ template<typename T> inline bool derived_from() const; /*! * \return whether the node is of type T * \tparam The type to be checked. */ template<typename T> inline bool is_type() const; /*! * \brief Get a NodePtr that holds reference to this Node. * \return the NodePtr */ inline NodePtr<Node> GetNodePtr() const; // node ref can see this friend class NodeRef; static constexpr const char* _type_key = "Node"; }; /*! \brief Base class of all node reference object */ class NodeRef { public: /*! \brief type indicate the container type */ using ContainerType = Node; /*! * \brief Comparator * \param other Another node ref. * \return the compare result. */ inline bool operator==(const NodeRef& other) const; /*! * \brief Comparator * \param other Another node ref. * \return the compare result. */ inline bool same_as(const NodeRef& other) const; /*! * \brief Comparator * \param other Another node ref. * \return the compare result. */ inline bool operator<(const NodeRef& other) const; /*! * \brief Comparator * \param other Another node ref. * \return the compare result. */ inline bool operator!=(const NodeRef& other) const; /*! \return the hash function for NodeRef */ inline size_t hash() const; /*! \return whether the expression is null */ inline bool defined() const; /*! \return the internal type index of IRNode */ inline uint32_t type_index() const; /*! \return the internal node pointer */ inline const Node* get() const; /*! \return the internal node pointer */ inline const Node* operator->() const; /*! * \brief Downcast this ir node to its actual type (e.g. Add, or * Select). This returns nullptr if the node is not of the requested * type. Example usage: * * if (const Add *add = node->as<Add>()) { * // This is an add node * } * \tparam T the target type, must be subtype of IRNode */ template<typename T> inline const T *as() const; /*! * \brief A more powerful version of as that also works with * intermediate base types. * \tparam T the target type, must be subtype of IRNode */ template<typename T> inline const T *as_derived() const; /*! \brief default constructor */ NodeRef() = default; explicit NodeRef(NodePtr<Node> node) : node_(node) {} /*! \brief the internal node object, do not touch */ NodePtr<Node> node_; }; /*! * \brief Get a reference type from a Node ptr type * * It is always important to get a reference type * if we want to return a value as reference or keep * the node alive beyond the scope of the function. * * \param ptr The node pointer * \tparam RefType The reference type * \tparam NodeType The node type * \return The corresponding RefType */ template <typename RefType, typename NodeType> inline RefType GetRef(const NodeType* ptr); /*! * \brief Downcast a base reference type to a more specific type. * * \param ref The inptut reference * \return The corresponding SubRef. * \tparam SubRef The target specific reference type. * \tparam BaseRef the current reference type. */ template <typename SubRef, typename BaseRef> inline SubRef Downcast(BaseRef ref); /*! * \brief helper macro to declare type information in a base node. */ #define TVM_DECLARE_BASE_NODE_INFO(TypeName, Parent) \ bool _DerivedFrom(uint32_t tid) const override { \ static uint32_t tidx = TypeKey2Index(TypeName::_type_key); \ if (tidx == tid) return true; \ return Parent::_DerivedFrom(tid); \ } /*! * \brief helper macro to declare type information in a terminal node */ #define TVM_DECLARE_NODE_TYPE_INFO(TypeName, Parent) \ const char* type_key() const final { \ return TypeName::_type_key; \ } \ uint32_t type_index() const final { \ static uint32_t tidx = TypeKey2Index(TypeName::_type_key); \ return tidx; \ } \ bool _DerivedFrom(uint32_t tid) const final { \ static uint32_t tidx = TypeKey2Index(TypeName::_type_key); \ if (tidx == tid) return true; \ return Parent::_DerivedFrom(tid); \ } // implementations of inline functions after this template<typename T> inline bool Node::derived_from() const { // use static field so query only happens once. static uint32_t type_id = Node::TypeKey2Index(T::_type_key); return this->_DerivedFrom(type_id); } template<typename T> inline bool Node::is_type() const { // use static field so query only happens once. static uint32_t type_id = Node::TypeKey2Index(T::_type_key); return type_id == this->type_index(); } inline NodePtr<Node> Node::GetNodePtr() const { return NodePtr<Node>(const_cast<Node*>(this)); } template <typename RefType, typename NodeType> inline RefType GetRef(const NodeType* ptr) { static_assert(std::is_base_of<typename RefType::ContainerType, NodeType>::value, "Can only cast to the ref of same container type"); return RefType(ptr->GetNodePtr()); } template <typename SubRef, typename BaseRef> inline SubRef Downcast(BaseRef ref) { CHECK(ref->template is_type<typename SubRef::ContainerType>() || ref->template derived_from<typename SubRef::ContainerType>()) << "Downcast from " << ref->type_key() << " to " << SubRef::ContainerType::_type_key << " failed."; return SubRef(std::move(ref.node_)); } inline const Node* NodeRef::get() const { return node_.get(); } inline const Node* NodeRef::operator->() const { return node_.get(); } inline bool NodeRef::defined() const { return node_.get() != nullptr; } inline bool NodeRef::operator==(const NodeRef& other) const { return node_.get() == other.node_.get(); } inline bool NodeRef::same_as(const NodeRef& other) const { return node_.get() == other.node_.get(); } inline bool NodeRef::operator<(const NodeRef& other) const { return node_.get() < other.node_.get(); } inline bool NodeRef::operator!=(const NodeRef& other) const { return node_.get() != other.node_.get(); } inline size_t NodeRef::hash() const { return std::hash<Node*>()(node_.get()); } inline uint32_t NodeRef::type_index() const { CHECK(node_.get() != nullptr) << "null type"; return get()->type_index(); } template<typename T> inline const T* NodeRef::as() const { const Node* ptr = static_cast<const Node*>(get()); if (ptr && ptr->is_type<T>()) { return static_cast<const T*>(ptr); } return nullptr; } template<typename T> inline const T* NodeRef::as_derived() const { const Node* ptr = static_cast<const Node*>(get()); if (ptr && (ptr->is_type<T>() || ptr->derived_from<T>())) { return static_cast<const T*>(ptr); } return nullptr; } /*! \brief The hash function for nodes */ struct NodeHash { size_t operator()(const NodeRef& a) const { return a.hash(); } }; /*! \brief The equal comparator for nodes */ struct NodeEqual { bool operator()(const NodeRef& a, const NodeRef& b) const { return a.get() == b.get(); } }; } // namespace tvm #endif // TVM_NODE_NODE_H_