/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file tvm/relay/adt.h
 * \brief Algebraic data types for Relay
 */
#ifndef TVM_RELAY_ADT_H_
#define TVM_RELAY_ADT_H_

#include <tvm/ir/attrs.h>
#include <tvm/ir/adt.h>
#include <string>
#include <functional>
#include "./base.h"
#include "./type.h"
#include "./expr.h"

namespace tvm {
namespace relay {

// Relay reuses the IR-level algebraic data type definitions; the aliases
// below make them reachable under the relay namespace as well.

// ADT constructor: node type and reference type.
using ConstructorNode = tvm::ConstructorNode;
using Constructor = tvm::Constructor;

// ADT type definition: node type and reference type.
using TypeDataNode = tvm::TypeDataNode;
using TypeData = tvm::TypeData;

/*! \brief Base type for declaring relay pattern. */
class PatternNode : public RelayNode {
 public:
  static constexpr const char* _type_key = "relay.Pattern";
  TVM_DECLARE_BASE_OBJECT_INFO(PatternNode, Object);
};

/*!
 * \brief Pattern is the base reference type for an ADT match pattern in Relay.
 *
 * A pattern inspects an ADT value and either accepts it -- possibly binding its
 * pattern variables to the value itself or to sub-values of it -- or rejects it.
 *
 * ADT pattern matching therefore tries a list of patterns in order and binds
 * using the first one that accepts the value.
 */
class Pattern : public ObjectRef {
 public:
  /*! \brief Node type backing this reference. */
  using ContainerType = PatternNode;

  /*! \brief Construct an undefined pattern reference. */
  Pattern() {}

  /*!
   * \brief Wrap an existing object pointer.
   * \param node The underlying node; expected to be a PatternNode subclass.
   */
  explicit Pattern(ObjectPtr<tvm::Object> node) : ObjectRef(node) {}
};

/*! \brief A wildcard pattern: accepts all input and binds nothing. */
class PatternWildcard;

/*! \brief Node backing PatternWildcard; it declares no fields of its own. */
class PatternWildcardNode : public PatternNode {
 public:
  static constexpr const char* _type_key = "relay.PatternWildcard";
  TVM_DECLARE_FINAL_OBJECT_INFO(PatternWildcardNode, PatternNode);

  PatternWildcardNode() {}

  /*!
   * \brief Factory for wildcard patterns.
   * \return A PatternWildcard reference.
   */
  TVM_DLL static PatternWildcard make();

  /*!
   * \brief Report this node's fields to the attribute visitor.
   * \param v The visitor; only the inherited span is reported.
   */
  void VisitAttrs(tvm::AttrVisitor* v) {
    v->Visit("span", &span);
  }
};

/*! \brief Reference type for PatternWildcardNode. */
class PatternWildcard : public Pattern {
 public:
  TVM_DEFINE_OBJECT_REF_METHODS(PatternWildcard, Pattern, PatternWildcardNode);
};

/*! \brief A var pattern: accepts all input and binds it to a variable. */
class PatternVar;

/*! \brief Node backing PatternVar. */
class PatternVarNode : public PatternNode {
 public:
  static constexpr const char* _type_key = "relay.PatternVar";
  TVM_DECLARE_FINAL_OBJECT_INFO(PatternVarNode, PatternNode);

  /*! \brief The variable that receives the matched value. */
  tvm::relay::Var var;

  PatternVarNode() {}

  /*!
   * \brief Factory for variable patterns.
   * \param var The variable to bind the matched value to.
   * \return A PatternVar reference.
   */
  TVM_DLL static PatternVar make(tvm::relay::Var var);

  /*!
   * \brief Report this node's fields to the attribute visitor.
   * \param v The visitor; receives "var" followed by "span".
   */
  void VisitAttrs(tvm::AttrVisitor* v) {
    v->Visit("var", &var);
    v->Visit("span", &span);
  }
};

/*! \brief Reference type for PatternVarNode. */
class PatternVar : public Pattern {
 public:
  TVM_DEFINE_OBJECT_REF_METHODS(PatternVar, Pattern, PatternVarNode);
};

/*! \brief A constructor pattern. Matches a value with the given constructor, binds recursively. */
class PatternConstructor;
/*! \brief PatternConstructor container node. */
class PatternConstructorNode : public PatternNode {
 public:
  /*! \brief Constructor matched by the pattern. */
  Constructor constructor;
  /*! \brief Sub-patterns to match against each input to the constructor. */
  tvm::Array<Pattern> patterns;

  PatternConstructorNode() {}

  /*!
   * \brief Create a constructor pattern.
   * \param constructor The constructor the matched value must have been built with.
   * \param patterns Sub-patterns matched against the constructor's inputs, one per input.
   * \return The created PatternConstructor.
   */
  TVM_DLL static PatternConstructor make(Constructor constructor, tvm::Array<Pattern> patterns);

  void VisitAttrs(tvm::AttrVisitor* v) {
    v->Visit("constructor", &constructor);
    v->Visit("patterns", &patterns);
    v->Visit("span", &span);
  }

  static constexpr const char* _type_key = "relay.PatternConstructor";
  TVM_DECLARE_FINAL_OBJECT_INFO(PatternConstructorNode, PatternNode);
};

/*! \brief Reference type for PatternConstructorNode. */
class PatternConstructor : public Pattern {
 public:
  TVM_DEFINE_OBJECT_REF_METHODS(PatternConstructor, Pattern, PatternConstructorNode);
};

/*! \brief A tuple pattern. Matches a tuple, binds recursively. */
class PatternTuple;
/*! \brief PatternTuple container node. */
class PatternTupleNode : public PatternNode {
 public:
  /*! \brief Sub-patterns to match against each value of the tuple. */
  tvm::Array<Pattern> patterns;

  PatternTupleNode() {}

  /*!
   * \brief Create a tuple pattern.
   * \param patterns Sub-patterns matched against the tuple's values, one per value.
   * \return The created PatternTuple.
   */
  TVM_DLL static PatternTuple make(tvm::Array<Pattern> patterns);

  void VisitAttrs(tvm::AttrVisitor* v) {
    v->Visit("patterns", &patterns);
    v->Visit("span", &span);
  }

  static constexpr const char* _type_key = "relay.PatternTuple";
  TVM_DECLARE_FINAL_OBJECT_INFO(PatternTupleNode, PatternNode);
};

/*! \brief Reference type for PatternTupleNode. */
class PatternTuple : public Pattern {
 public:
  TVM_DEFINE_OBJECT_REF_METHODS(PatternTuple, Pattern, PatternTupleNode);
};

/*! \brief A clause in a match expression. */
class Clause;

/*! \brief Node backing Clause: a pattern paired with the expression it yields. */
class ClauseNode : public Object {
 public:
  static constexpr const char* _type_key = "relay.Clause";
  TVM_DECLARE_FINAL_OBJECT_INFO(ClauseNode, Object);

  /*! \brief The pattern this clause matches against. */
  Pattern lhs;
  /*! \brief The expression producing the clause's result. */
  Expr rhs;

  /*!
   * \brief Factory for match clauses.
   * \param lhs The pattern to match.
   * \param rhs The resulting expression.
   * \return A Clause reference.
   */
  TVM_DLL static Clause make(Pattern lhs, Expr rhs);

  /*!
   * \brief Report this node's fields to the attribute visitor.
   * \param v The visitor; receives "lhs" followed by "rhs".
   */
  void VisitAttrs(tvm::AttrVisitor* v) {
    v->Visit("lhs", &lhs);
    v->Visit("rhs", &rhs);
  }
};

/*! \brief Reference type for ClauseNode. */
class Clause : public ObjectRef {
 public:
  TVM_DEFINE_OBJECT_REF_METHODS(Clause, ObjectRef, ClauseNode);
};

/*! \brief ADT pattern matching exression. */
class Match;
/*! \brief Match container node. */
class MatchNode : public ExprNode {
 public:
  /*! \brief The input being deconstructed. */
  Expr data;

  /*! \brief The match node clauses. */
  tvm::Array<Clause> clauses;

  /*! \brief Should this match be complete (cover all cases)?
   *  If yes, the type checker will generate an error if there are any missing cases.
   */
  bool complete;

  void VisitAttrs(tvm::AttrVisitor* v) {
    v->Visit("data", &data);
    v->Visit("clauses", &clauses);
    v->Visit("complete", &complete);
    v->Visit("span", &span);
    v->Visit("_checked_type_", &checked_type_);
  }

  TVM_DLL static Match make(Expr data, tvm::Array<Clause> pattern, bool complete = true);

  static constexpr const char* _type_key = "relay.Match";
  TVM_DECLARE_FINAL_OBJECT_INFO(MatchNode, ExprNode);
};

class Match : public Expr {
 public:
  TVM_DEFINE_OBJECT_REF_METHODS(Match, RelayExpr, MatchNode);
};

}  // namespace relay
}  // namespace tvm

#endif  // TVM_RELAY_ADT_H_