/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file match_exhaustion.cc
 * \brief Checking Relay match expression exhaustiveness.
 *
 * This file implements a function that checks whether a match
 * expression is exhaustive, that is, whether its clauses together
 * cover every possible case. This is important for ensuring
 * code correctness, since hitting an unmatched case results in a
 * dynamic error unless exhaustiveness is checked in advance.
 */
#include <tvm/relay/adt.h>
#include <tvm/ir/error.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/pattern_functor.h>
#include <stack>

namespace tvm {
namespace relay {

/*!
 * \brief Possible outcomes of checking a candidate pattern against a
 * clause pattern.
 */
enum MatchResult : int {
  /*! \brief The pattern matches the candidate. */
  kMatch = 0,
  /*! \brief The pattern conflicts with the candidate. */
  kClash,
  /*! \brief Ambiguous: the candidate needs more constructors specified. */
  kUnspecified,
};

/*!
 * \brief Pattern functor that checks a candidate pattern against a clause pattern.
 *
 * The visited pattern (first argument) is the clause pattern; the extra
 * argument is the candidate. Each visit yields kMatch, kClash, or
 * kUnspecified.
 */
class CandidateChecker : public PatternFunctor<MatchResult(const Pattern&, const Pattern&)> {
 public:
  explicit CandidateChecker() {}

  /*!
   * \brief Checks one candidate against one clause pattern.
   * \param pat The clause pattern.
   * \param candidate The candidate pattern.
   * \return The result of visiting pat with candidate as the extra argument.
   */
  MatchResult Check(const Pattern& pat, const Pattern& candidate) {
    return this->VisitPattern(pat, candidate);
  }

  // A constructor pattern matches only a constructor candidate with the
  // same constructor whose fields match the clause's subpatterns.
  MatchResult VisitPattern_(const PatternConstructorNode* op, const Pattern& cand) override {
    auto* ctor_cand = cand.as<PatternConstructorNode>();
    // the candidate is not a constructor pattern: it has to be made more specific
    if (ctor_cand == nullptr) {
      return MatchResult::kUnspecified;
    }

    // different constructors can never match
    if (!op->constructor.same_as(ctor_cand->constructor)) {
      return MatchResult::kClash;
    }

    CHECK_EQ(op->patterns.size(), ctor_cand->patterns.size());
    return CheckFields(op->patterns, ctor_cand->patterns);
  }

  // A tuple pattern matches only a tuple candidate whose fields match the
  // clause's subpatterns.
  MatchResult VisitPattern_(const PatternTupleNode* op, const Pattern& cand) override {
    auto* tuple_cand = cand.as<PatternTupleNode>();
    // the candidate is not a tuple pattern: it has to be made more specific
    if (tuple_cand == nullptr) {
      return MatchResult::kUnspecified;
    }

    CHECK_EQ(op->patterns.size(), tuple_cand->patterns.size());
    return CheckFields(op->patterns, tuple_cand->patterns);
  }

  // wildcard patterns match any candidate
  MatchResult VisitPattern_(const PatternWildcardNode*, const Pattern&) override {
    return MatchResult::kMatch;
  }

  // var patterns match any candidate
  MatchResult VisitPattern_(const PatternVarNode*, const Pattern&) override {
    return MatchResult::kMatch;
  }

 private:
  // Checks the subpatterns pairwise. A clash in any position wins
  // immediately; otherwise the result is kUnspecified if at least one
  // position was unspecified, and kMatch if every position matched.
  // Callers are expected to have verified that both arrays have equal size.
  MatchResult CheckFields(const Array<Pattern>& clause_fields,
                          const Array<Pattern>& cand_fields) {
    MatchResult combined = MatchResult::kMatch;
    for (size_t idx = 0; idx < clause_fields.size(); ++idx) {
      const MatchResult field_result = this->Check(clause_fields[idx], cand_fields[idx]);
      if (field_result == MatchResult::kClash) {
        return MatchResult::kClash;
      }
      if (field_result == MatchResult::kUnspecified) {
        // remember it, but keep scanning: a later clash still takes priority
        combined = MatchResult::kUnspecified;
      }
    }
    return combined;
  }
};

/*!
 * \brief Computes the Cartesian product of a list of per-field alternatives.
 *
 * \param fields fields[i] holds the possible patterns for position i.
 * \return Every combination that takes exactly one pattern from each field,
 * with position i of a combination drawn from fields[i]. Combinations are
 * ordered so that the first field varies fastest. An empty field list yields
 * a single empty combination, since the only way to choose from zero fields
 * is to choose nothing.
 */
Array<Array<Pattern>> CartesianProduct(Array<Array<Pattern>> fields) {
  // Start from the single empty combination. This is also the final answer
  // when there are no fields (e.g. a nullary constructor or an empty tuple),
  // a case that previously tripped a CHECK failure.
  Array<Array<Pattern>> ret;
  ret.push_back(Array<Pattern>());

  // Extend every partial combination with each value of the next field.
  for (size_t i = 0; i < fields.size(); i++) {
    Array<Pattern> field_vals = fields[i];
    Array<Array<Pattern>> extended;
    for (auto val : field_vals) {
      for (auto candidate : ret) {
        // candidate is a separate handle to the partial combination, so
        // appending to it must not modify the entry stored in ret
        candidate.push_back(val);
        extended.push_back(candidate);
      }
    }
    ret = extended;
  }
  return ret;
}

// Forward declarations: ExpandWildcards dispatches to these two helpers,
// and each of them calls back into ExpandWildcards for its subpatterns.

// Expands a candidate against a constructor clause pattern (defined below).
Array<Pattern> ExpandWildcardsConstructor(const PatternConstructor& clause_ctor,
                                          const Pattern& cand,
                                          const IRModule& mod);

// Expands a candidate against a tuple clause pattern (defined below).
Array<Pattern> ExpandWildcardsTuple(const PatternTuple& clause_tuple,
                                    const Pattern& cand,
                                    const IRModule& mod);

// Performs one round of wildcard expansion on the candidate, guided by the
// shape of the clause pattern. Constructor and tuple clauses are handed to
// their dedicated helpers; for any other clause pattern the candidate is
// returned unchanged as the only expansion.
Array<Pattern> ExpandWildcards(const Pattern& clause_pat,
                               const Pattern& cand,
                               const IRModule& mod) {
  if (auto ctor_node = clause_pat.as<PatternConstructorNode>()) {
    return ExpandWildcardsConstructor(GetRef<PatternConstructor>(ctor_node), cand, mod);
  }
  if (auto tuple_node = clause_pat.as<PatternTupleNode>()) {
    return ExpandWildcardsTuple(GetRef<PatternTuple>(tuple_node), cand, mod);
  }
  return {cand};
}

// Performs one round of wildcard expansion on the candidate against a
// constructor clause pattern and returns every resulting expansion.
//  - A wildcard candidate becomes one constructor pattern per constructor of
//    the type the clause's constructor belongs to, each with wildcards for
//    all of its arguments.
//  - Otherwise the candidate must be a constructor pattern: each of its
//    fields is expanded against the clause's corresponding subpattern and
//    the per-field results are combined with a Cartesian product.
Array<Pattern> ExpandWildcardsConstructor(const PatternConstructor& clause_ctor,
                                          const Pattern& cand,
                                          const IRModule& mod) {
  auto gtv = Downcast<GlobalTypeVar>(clause_ctor->constructor->belong_to);
  Array<Pattern> expansions;

  if (cand.as<PatternWildcardNode>()) {
    // look up the type definition to enumerate all of its constructors
    TypeData type_def = mod->LookupTypeDef(gtv);
    for (auto ctor : type_def->constructors) {
      Array<Pattern> wildcard_args;
      for (size_t arg = 0; arg < ctor->inputs.size(); ++arg) {
        wildcard_args.push_back(PatternWildcard());
      }
      expansions.push_back(PatternConstructor(ctor, wildcard_args));
    }
    return expansions;
  }

  auto ctor_cand = Downcast<PatternConstructor>(cand);

  // expand every field of the candidate against the matching clause subpattern
  const size_t num_fields = ctor_cand->constructor->inputs.size();
  Array<Array<Pattern>> per_field;
  for (size_t idx = 0; idx < num_fields; ++idx) {
    per_field.push_back(
        ExpandWildcards(clause_ctor->patterns[idx], ctor_cand->patterns[idx], mod));
  }

  // each combination of field expansions gives one new constructor candidate
  for (auto combo : CartesianProduct(per_field)) {
    expansions.push_back(PatternConstructor(ctor_cand->constructor, combo));
  }
  return expansions;
}

// Performs one round of wildcard expansion on the candidate against a tuple
// clause pattern and returns every resulting expansion.
//  - A wildcard candidate becomes a single tuple pattern with one wildcard
//    per field of the clause tuple.
//  - Otherwise the candidate must be a tuple pattern: each of its fields is
//    expanded against the clause's corresponding subpattern and the
//    per-field results are combined with a Cartesian product.
Array<Pattern> ExpandWildcardsTuple(const PatternTuple& clause_tuple,
                                    const Pattern& cand,
                                    const IRModule& mod) {
  if (cand.as<PatternWildcardNode>()) {
    Array<Pattern> wildcard_fields;
    for (size_t idx = 0; idx < clause_tuple->patterns.size(); ++idx) {
      wildcard_fields.push_back(PatternWildcard());
    }
    return {PatternTuple(wildcard_fields)};
  }

  auto tuple_cand = Downcast<PatternTuple>(cand);

  // expand every field of the candidate against the matching clause subpattern
  const size_t num_fields = tuple_cand->patterns.size();
  Array<Array<Pattern>> per_field;
  for (size_t idx = 0; idx < num_fields; ++idx) {
    per_field.push_back(
        ExpandWildcards(clause_tuple->patterns[idx], tuple_cand->patterns[idx], mod));
  }

  // each combination of field expansions gives one new tuple candidate
  Array<Pattern> expansions;
  for (auto combo : CartesianProduct(per_field)) {
    expansions.push_back(PatternTuple(combo));
  }
  return expansions;
}

/*!
 * \brief Finds cases that the match expression does not catch, if any.
 * \param match The match expression whose clauses are examined.
 * \param mod The module passed on to wildcard expansion.
 * \return Returns a list of cases that are not handled by the match
 * expression.
 */
Array<Pattern> UnmatchedCases(const Match& match, const IRModule& mod) {
  /* Worklist algorithm:
   *   worklist = { Wildcard }
   *   while worklist is not empty:
   *     cand = worklist.pop()
   *     take the first clause that does not clash with cand:
   *       - it matches cand: cand is covered
   *       - cand is not specific enough: push the one-step expansions of
   *         cand (guided by that clause) onto the worklist
   *     if every clause clashes: cand is an unmatched case
   */
  std::stack<Pattern> worklist;
  worklist.push(PatternWildcard());
  CandidateChecker checker;

  Array<Pattern> unmatched;

  while (!worklist.empty()) {
    Pattern current = worklist.top();
    worklist.pop();

    bool handled = false;
    for (auto clause : match->clauses) {
      const MatchResult outcome = checker.Check(clause->lhs, current);
      // a clashing clause tells us nothing; try the next one
      if (outcome == MatchResult::kClash) {
        continue;
      }

      // the first non-clashing clause settles this candidate: it is either
      // covered outright or replaced by more specific candidates
      handled = true;
      if (outcome == MatchResult::kUnspecified) {
        for (auto refined : ExpandWildcards(clause->lhs, current, mod)) {
          worklist.push(refined);
        }
      }
      break;
    }

    if (!handled) {
      unmatched.push_back(current);
    }
  }

  return unmatched;
}

// Registered for testing only: runs UnmatchedCases, substituting an empty
// module when the supplied module reference is undefined.
TVM_REGISTER_GLOBAL("relay.analysis.unmatched_cases")
.set_body_typed(
  [](const Match& match, const IRModule& mod_ref) {
    const IRModule call_mod = mod_ref.defined() ? mod_ref : IRModule({}, {});
    return UnmatchedCases(match, call_mod);
  });

}  // namespace relay
}  // namespace tvm