/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ /*! * \file tvm/relay/adt.h * \brief Algebraic data types for Relay */ #ifndef TVM_RELAY_ADT_H_ #define TVM_RELAY_ADT_H_ #include <tvm/attrs.h> #include <string> #include <functional> #include "./base.h" #include "./type.h" #include "./expr.h" namespace tvm { namespace relay { /*! \brief Base type for declaring relay pattern. */ class PatternNode : public RelayNode { public: static constexpr const char* _type_key = "relay.Pattern"; TVM_DECLARE_BASE_NODE_INFO(PatternNode, Node); }; /*! * \brief Pattern is the base type for an ADT match pattern in Relay. * * Given an ADT value, a pattern might accept it and bind the pattern variable to some value * (typically a subnode of the input or the input). Otherwise, the pattern rejects the value. * * ADT pattern matching thus takes a list of values and binds to the first that accepts the value. */ class Pattern : public NodeRef { public: Pattern() {} explicit Pattern(ObjectPtr<tvm::Object> p) : NodeRef(p) {} using ContainerType = PatternNode; }; /*! \brief A wildcard pattern: Accepts all input and binds nothing. */ class PatternWildcard; /*! \brief PatternWildcard container node */ class PatternWildcardNode : public PatternNode { public: PatternWildcardNode() {} TVM_DLL static PatternWildcard make(); void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("span", &span); } static constexpr const char* _type_key = "relay.PatternWildcard"; TVM_DECLARE_NODE_TYPE_INFO(PatternWildcardNode, PatternNode); }; RELAY_DEFINE_NODE_REF(PatternWildcard, PatternWildcardNode, Pattern); /*! \brief A var pattern. Accept all input and bind to a var. */ class PatternVar; /*! \brief PatternVar container node */ class PatternVarNode : public PatternNode { public: PatternVarNode() {} /*! \brief Variable that stores the matched value. */ tvm::relay::Var var; TVM_DLL static PatternVar make(tvm::relay::Var var); void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("var", &var); v->Visit("span", &span); } static constexpr const char* _type_key = "relay.PatternVar"; TVM_DECLARE_NODE_TYPE_INFO(PatternVarNode, PatternNode); }; RELAY_DEFINE_NODE_REF(PatternVar, PatternVarNode, Pattern); /*! * \brief ADT constructor. * Constructors compare by pointer equality. */ class Constructor; /*! \brief Constructor container node. */ class ConstructorNode : public ExprNode { public: /*! \brief The name (only a hint) */ std::string name_hint; /*! \brief Input to the constructor. */ tvm::Array<Type> inputs; /*! \brief The datatype the constructor will construct. */ GlobalTypeVar belong_to; /*! \brief Index in the table of constructors (set when the type is registered). */ mutable int32_t tag = -1; ConstructorNode() {} TVM_DLL static Constructor make(std::string name_hint, tvm::Array<Type> inputs, GlobalTypeVar belong_to); void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("name_hint", &name_hint); v->Visit("inputs", &inputs); v->Visit("belong_to", &belong_to); v->Visit("tag", &tag); v->Visit("span", &span); v->Visit("_checked_type_", &checked_type_); } static constexpr const char* _type_key = "relay.Constructor"; TVM_DECLARE_NODE_TYPE_INFO(ConstructorNode, ExprNode); }; RELAY_DEFINE_NODE_REF(Constructor, ConstructorNode, Expr); /*! \brief A constructor pattern. Matches a value with the given constructor, binds recursively. */ class PatternConstructor; /*! \brief PatternVar container node */ class PatternConstructorNode : public PatternNode { public: /*! Constructor matched by the pattern. */ Constructor constructor; /*! Sub-patterns to match against each input to the constructor. */ tvm::Array<Pattern> patterns; PatternConstructorNode() {} TVM_DLL static PatternConstructor make(Constructor constructor, tvm::Array<Pattern> var); void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("constructor", &constructor); v->Visit("patterns", &patterns); v->Visit("span", &span); } static constexpr const char* _type_key = "relay.PatternConstructor"; TVM_DECLARE_NODE_TYPE_INFO(PatternConstructorNode, PatternNode); }; RELAY_DEFINE_NODE_REF(PatternConstructor, PatternConstructorNode, Pattern); /*! \brief A tuple pattern. Matches a tuple, binds recursively. */ class PatternTuple; /*! \brief PatternVar container node */ class PatternTupleNode : public PatternNode { public: /*! Sub-patterns to match against each value of the tuple. */ tvm::Array<Pattern> patterns; PatternTupleNode() {} TVM_DLL static PatternTuple make(tvm::Array<Pattern> var); void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("patterns", &patterns); v->Visit("span", &span); } static constexpr const char* _type_key = "relay.PatternTuple"; TVM_DECLARE_NODE_TYPE_INFO(PatternTupleNode, PatternNode); }; RELAY_DEFINE_NODE_REF(PatternTuple, PatternTupleNode, Pattern); /*! * \brief Stores all data for an Algebraic Data Type (ADT). * * In particular, it stores the handle (global type var) for an ADT * and the constructors used to build it and is kept in the module. Note * that type parameters are also indicated in the type data: this means that * for any instance of an ADT, the type parameters must be indicated. That is, * an ADT definition is treated as a type-level function, so an ADT handle * must be wrapped in a TypeCall node that instantiates the type-level arguments. * The kind checker enforces this. */ class TypeData; /*! \brief TypeData container node */ class TypeDataNode : public TypeNode { public: /*! * \brief The header is simply the name of the ADT. * We adopt nominal typing for ADT definitions; * that is, differently-named ADT definitions with same constructors * have different types. */ GlobalTypeVar header; /*! \brief The type variables (to allow for polymorphism). */ tvm::Array<TypeVar> type_vars; /*! \brief The constructors. */ tvm::Array<Constructor> constructors; void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("header", &header); v->Visit("type_vars", &type_vars); v->Visit("constructors", &constructors); v->Visit("span", &span); } TVM_DLL static TypeData make(GlobalTypeVar header, tvm::Array<TypeVar> type_vars, tvm::Array<Constructor> constructors); static constexpr const char* _type_key = "relay.TypeData"; TVM_DECLARE_NODE_TYPE_INFO(TypeDataNode, TypeNode); }; RELAY_DEFINE_NODE_REF(TypeData, TypeDataNode, Type); /*! \brief A clause in a match expression. */ class Clause; /*! \brief Clause container node. */ class ClauseNode : public Node { public: /*! \brief The pattern the clause matches. */ Pattern lhs; /*! \brief The resulting value. */ Expr rhs; void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("lhs", &lhs); v->Visit("rhs", &rhs); } TVM_DLL static Clause make(Pattern lhs, Expr rhs); static constexpr const char* _type_key = "relay.Clause"; TVM_DECLARE_NODE_TYPE_INFO(ClauseNode, Node); }; RELAY_DEFINE_NODE_REF(Clause, ClauseNode, NodeRef); /*! \brief ADT pattern matching exression. */ class Match; /*! \brief Match container node. */ class MatchNode : public ExprNode { public: /*! \brief The input being deconstructed. */ Expr data; /*! \brief The match node clauses. */ tvm::Array<Clause> clauses; /*! \brief Should this match be complete (cover all cases)? * If yes, the type checker will generate an error if there are any missing cases. */ bool complete; void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("data", &data); v->Visit("clauses", &clauses); v->Visit("complete", &complete); v->Visit("span", &span); v->Visit("_checked_type_", &checked_type_); } TVM_DLL static Match make(Expr data, tvm::Array<Clause> pattern, bool complete = true); static constexpr const char* _type_key = "relay.Match"; TVM_DECLARE_NODE_TYPE_INFO(MatchNode, ExprNode); }; RELAY_DEFINE_NODE_REF(Match, MatchNode, Expr); } // namespace relay } // namespace tvm #endif // TVM_RELAY_ADT_H_