# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name
"""Algebraic data types in Relay."""
from .base import RelayNode, register_relay_node, Object
from . import _make
from .ty import Type
from .expr import Expr, Call


class Pattern(RelayNode):
    """Common base class of every pattern usable in a match clause."""


@register_relay_node
class PatternWildcard(Pattern):
    """Wildcard pattern in Relay: matches any ADT and binds nothing."""

    def __init__(self):
        """Build a wildcard pattern.

        Takes no arguments; the node is created through the
        ``_make.PatternWildcard`` constructor.
        """
        self.__init_handle_by_constructor__(_make.PatternWildcard)


@register_relay_node
class PatternVar(Pattern):
    """Variable pattern in Relay: matches anything and binds it to the variable."""

    def __init__(self, var):
        """Build a variable pattern.

        Parameters
        ----------
        var: tvm.relay.Var
            The variable that the matched value is bound to.
        """
        self.__init_handle_by_constructor__(_make.PatternVar, var)


@register_relay_node
class PatternConstructor(Pattern):
    """Constructor pattern in Relay: matches an ADT built by the given
    constructor and binds its fields recursively."""

    def __init__(self, constructor, patterns=None):
        """Build a constructor pattern.

        Parameters
        ----------
        constructor: Constructor
            The constructor to match against.

        patterns: Optional[List[Pattern]]
            Optional subpatterns: for each field of the constructor,
            match to the given subpattern (treated as a variable pattern
            by default). When omitted, an empty list is passed on.
        """
        # A fresh list is created per call so no default is ever shared.
        subpatterns = [] if patterns is None else patterns
        self.__init_handle_by_constructor__(
            _make.PatternConstructor, constructor, subpatterns
        )


@register_relay_node
class PatternTuple(Pattern):
    """Tuple pattern in Relay: matches a tuple and binds its fields recursively."""

    def __init__(self, patterns=None):
        """Build a tuple pattern.

        Parameters
        ----------
        patterns: Optional[List[Pattern]]
            Optional subpatterns: for each field of the tuple, match to
            the given subpattern (treated as a variable pattern by
            default). When omitted, an empty list is passed on.
        """
        # A fresh list is created per call so no default is ever shared.
        subpatterns = [] if patterns is None else patterns
        self.__init_handle_by_constructor__(_make.PatternTuple, subpatterns)


@register_relay_node
class Constructor(Expr):
    """Relay ADT constructor."""

    def __init__(self, name_hint, inputs, belong_to):
        """Define an ADT constructor.

        Parameters
        ----------
        name_hint : str
            Name of the constructor (only a hint).

        inputs : List[Type]
            Input types.

        belong_to : tvm.relay.GlobalTypeVar
            Denotes which ADT the constructor belongs to.
        """
        self.__init_handle_by_constructor__(
            _make.Constructor, name_hint, inputs, belong_to
        )

    def __call__(self, *args):
        """Apply the constructor to the given arguments.

        Parameters
        ----------
        args: List[relay.Expr]
            The arguments to the constructor.

        Returns
        -------
        call: relay.Call
            A call expression whose callee is this constructor.
        """
        return Call(self, args)


@register_relay_node
class TypeData(Type):
    """Stores the definition for an Algebraic Data Type (ADT) in Relay.

    Note that ADT definitions are treated as type-level functions because
    the type parameters need to be given for an instance of the ADT. Thus,
    any global type var that is an ADT header needs to be wrapped in a
    type call that passes in the type params.
    """

    def __init__(self, header, type_vars, constructors):
        """Define a TypeData object.

        Parameters
        ----------
        header: tvm.relay.GlobalTypeVar
            The name of the ADT.
            ADTs with the same constructors but different names are
            treated as different types.

        type_vars: List[TypeVar]
            Type variables that appear in constructors.

        constructors: List[tvm.relay.Constructor]
            The constructors for the ADT.
        """
        self.__init_handle_by_constructor__(
            _make.TypeData, header, type_vars, constructors
        )


@register_relay_node
class Clause(Object):
    """Clause for pattern matching in Relay."""

    def __init__(self, lhs, rhs):
        """Build a match clause.

        Parameters
        ----------
        lhs: tvm.relay.Pattern
            Left-hand side of match clause.

        rhs: tvm.relay.Expr
            Right-hand side of match clause.
        """
        self.__init_handle_by_constructor__(_make.Clause, lhs, rhs)


@register_relay_node
class Match(Expr):
    """Pattern matching expression in Relay."""

    def __init__(self, data, clauses, complete=True):
        """Build a match expression.

        Parameters
        ----------
        data: tvm.relay.Expr
            The value being deconstructed and matched.

        clauses: List[tvm.relay.Clause]
            The pattern match clauses.

        complete: Optional[Bool]
            Should the match be complete (cover all cases)?
            If yes, the type checker will generate an error if there are
            any missing cases.
        """
        self.__init_handle_by_constructor__(_make.Match, data, clauses, complete)