adt.h 8.73 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54
/*!
 * \file tvm/relay/adt.h
 * \brief Algebraic data types for Relay
 */
#ifndef TVM_RELAY_ADT_H_
#define TVM_RELAY_ADT_H_

#include <tvm/attrs.h>
#include <string>
#include <functional>
#include "./base.h"
#include "./type.h"
#include "./expr.h"

namespace tvm {
namespace relay {

/*!
 * \brief Base type for declaring relay pattern.
 *
 * All concrete pattern nodes (wildcard, var, constructor, tuple) derive from this class.
 */
class PatternNode : public RelayNode {
 public:
  /*! \brief Key used to register this node type with the reflection/type system. */
  static constexpr const char* _type_key = "relay.Pattern";
  // NOTE(review): the type-info parent is declared as Node while the C++ base class is
  // RelayNode -- confirm this is intentional before relying on RelayNode instance checks.
  TVM_DECLARE_BASE_NODE_INFO(PatternNode, Node);
};

/*!
 * \brief Pattern is the base type for an ADT match pattern in Relay.
 *
 * Given an ADT value, a pattern might accept it and bind the pattern variable to some value
 * (typically a subnode of the input or the input). Otherwise, the pattern rejects the value.
 *
 * ADT pattern matching thus takes a list of values and binds to the first that accepts the value.
 */
class Pattern : public NodeRef {
 public:
  Pattern() {}
55
  explicit Pattern(ObjectPtr<tvm::Object> p) : NodeRef(p) {}
56 57 58 59 60 61 62 63 64 65 66 67 68

  using ContainerType = PatternNode;
};

/*! \brief A wildcard pattern: Accepts all input and binds nothing. */
class PatternWildcard;
/*! \brief PatternWildcard container node */
class PatternWildcardNode : public PatternNode {
 public:
  PatternWildcardNode() {}

  TVM_DLL static PatternWildcard make();

69
  void VisitAttrs(tvm::AttrVisitor* v) {
70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90
    v->Visit("span", &span);
  }

  static constexpr const char* _type_key = "relay.PatternWildcard";
  TVM_DECLARE_NODE_TYPE_INFO(PatternWildcardNode, PatternNode);
};

RELAY_DEFINE_NODE_REF(PatternWildcard, PatternWildcardNode, Pattern);

/*! \brief A var pattern. Accept all input and bind to a var. */
class PatternVar;
/*! \brief PatternVar container node */
class PatternVarNode : public PatternNode {
 public:
  PatternVarNode() {}

  /*! \brief Variable that stores the matched value. */
  tvm::relay::Var var;

  TVM_DLL static PatternVar make(tvm::relay::Var var);

91
  void VisitAttrs(tvm::AttrVisitor* v) {
92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116
    v->Visit("var", &var);
    v->Visit("span", &span);
  }

  static constexpr const char* _type_key = "relay.PatternVar";
  TVM_DECLARE_NODE_TYPE_INFO(PatternVarNode, PatternNode);
};

RELAY_DEFINE_NODE_REF(PatternVar, PatternVarNode, Pattern);

/*!
 * \brief ADT constructor.
 * Constructors compare by pointer equality.
 */
class Constructor;
/*! \brief Constructor container node. */
class ConstructorNode : public ExprNode {
 public:
  /*! \brief The name (only a hint) */
  std::string name_hint;
  /*! \brief Input to the constructor. */
  tvm::Array<Type> inputs;
  /*! \brief The datatype the constructor will construct. */
  GlobalTypeVar belong_to;
  /*! \brief Index in the table of constructors (set when the type is registered). */
117
  mutable int32_t tag = -1;
118 119 120 121 122 123 124

  ConstructorNode() {}

  TVM_DLL static Constructor make(std::string name_hint,
                                  tvm::Array<Type> inputs,
                                  GlobalTypeVar belong_to);

125
  void VisitAttrs(tvm::AttrVisitor* v) {
126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153
    v->Visit("name_hint", &name_hint);
    v->Visit("inputs", &inputs);
    v->Visit("belong_to", &belong_to);
    v->Visit("tag", &tag);
    v->Visit("span", &span);
    v->Visit("_checked_type_", &checked_type_);
  }

  static constexpr const char* _type_key = "relay.Constructor";
  TVM_DECLARE_NODE_TYPE_INFO(ConstructorNode, ExprNode);
};

RELAY_DEFINE_NODE_REF(Constructor, ConstructorNode, Expr);

/*! \brief A constructor pattern. Matches a value with the given constructor, binds recursively. */
class PatternConstructor;
/*! \brief PatternVar container node */
class PatternConstructorNode : public PatternNode {
 public:
  /*! Constructor matched by the pattern. */
  Constructor constructor;
  /*! Sub-patterns to match against each input to the constructor. */
  tvm::Array<Pattern> patterns;

  PatternConstructorNode() {}

  TVM_DLL static PatternConstructor make(Constructor constructor, tvm::Array<Pattern> var);

154
  void VisitAttrs(tvm::AttrVisitor* v) {
155 156 157 158 159 160 161 162 163 164 165
    v->Visit("constructor", &constructor);
    v->Visit("patterns", &patterns);
    v->Visit("span", &span);
  }

  static constexpr const char* _type_key = "relay.PatternConstructor";
  TVM_DECLARE_NODE_TYPE_INFO(PatternConstructorNode, PatternNode);
};

RELAY_DEFINE_NODE_REF(PatternConstructor, PatternConstructorNode, Pattern);

166 167 168 169 170 171 172 173 174 175 176 177
/*! \brief A tuple pattern. Matches a tuple, binds recursively. */
class PatternTuple;
/*! \brief PatternVar container node */
class PatternTupleNode : public PatternNode {
 public:
  /*! Sub-patterns to match against each value of the tuple. */
  tvm::Array<Pattern> patterns;

  PatternTupleNode() {}

  TVM_DLL static PatternTuple make(tvm::Array<Pattern> var);

178
  void VisitAttrs(tvm::AttrVisitor* v) {
179 180 181 182 183 184 185 186 187 188
    v->Visit("patterns", &patterns);
    v->Visit("span", &span);
  }

  static constexpr const char* _type_key = "relay.PatternTuple";
  TVM_DECLARE_NODE_TYPE_INFO(PatternTupleNode, PatternNode);
};

RELAY_DEFINE_NODE_REF(PatternTuple, PatternTupleNode, Pattern);

189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215
/*!
 * \brief Stores all data for an Algebraic Data Type (ADT).
 *
 * In particular, it stores the handle (global type var) for an ADT
 * and the constructors used to build it and is kept in the module. Note
 * that type parameters are also indicated in the type data: this means that
 * for any instance of an ADT, the type parameters must be indicated. That is,
 * an ADT definition is treated as a type-level function, so an ADT handle
 * must be wrapped in a TypeCall node that instantiates the type-level arguments.
 * The kind checker enforces this.
 */
class TypeData;
/*! \brief TypeData container node */
class TypeDataNode : public TypeNode {
 public:
  /*!
   * \brief The header is simply the name of the ADT.
   * We adopt nominal typing for ADT definitions;
   * that is, differently-named ADT definitions with same constructors
   * have different types.
   */
  GlobalTypeVar header;
  /*! \brief The type variables (to allow for polymorphism). */
  tvm::Array<TypeVar> type_vars;
  /*! \brief The constructors. */
  tvm::Array<Constructor> constructors;

216
  void VisitAttrs(tvm::AttrVisitor* v) {
217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242
    v->Visit("header", &header);
    v->Visit("type_vars", &type_vars);
    v->Visit("constructors", &constructors);
    v->Visit("span", &span);
  }

  TVM_DLL static TypeData make(GlobalTypeVar header,
                               tvm::Array<TypeVar> type_vars,
                               tvm::Array<Constructor> constructors);

  static constexpr const char* _type_key = "relay.TypeData";
  TVM_DECLARE_NODE_TYPE_INFO(TypeDataNode, TypeNode);
};

RELAY_DEFINE_NODE_REF(TypeData, TypeDataNode, Type);

/*! \brief A clause in a match expression. */
class Clause;
/*! \brief Clause container node. */
class ClauseNode : public Node {
 public:
  /*! \brief The pattern the clause matches. */
  Pattern lhs;
  /*! \brief The resulting value. */
  Expr rhs;

243
  void VisitAttrs(tvm::AttrVisitor* v) {
244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266
    v->Visit("lhs", &lhs);
    v->Visit("rhs", &rhs);
  }

  TVM_DLL static Clause make(Pattern lhs, Expr rhs);

  static constexpr const char* _type_key = "relay.Clause";
  TVM_DECLARE_NODE_TYPE_INFO(ClauseNode, Node);
};

RELAY_DEFINE_NODE_REF(Clause, ClauseNode, NodeRef);

/*! \brief ADT pattern matching exression. */
class Match;
/*! \brief Match container node. */
class MatchNode : public ExprNode {
 public:
  /*! \brief The input being deconstructed. */
  Expr data;

  /*! \brief The match node clauses. */
  tvm::Array<Clause> clauses;

267
  /*! \brief Should this match be complete (cover all cases)?
268 269 270 271
   *  If yes, the type checker will generate an error if there are any missing cases.
   */
  bool complete;

272
  void VisitAttrs(tvm::AttrVisitor* v) {
273
    v->Visit("data", &data);
274
    v->Visit("clauses", &clauses);
275
    v->Visit("complete", &complete);
276 277 278 279
    v->Visit("span", &span);
    v->Visit("_checked_type_", &checked_type_);
  }

280
  TVM_DLL static Match make(Expr data, tvm::Array<Clause> pattern, bool complete = true);
281 282 283 284 285 286 287 288 289 290 291

  static constexpr const char* _type_key = "relay.Match";
  TVM_DECLARE_NODE_TYPE_INFO(MatchNode, ExprNode);
};

RELAY_DEFINE_NODE_REF(Match, MatchNode, Expr);

}  // namespace relay
}  // namespace tvm

#endif  // TVM_RELAY_ADT_H_