Commit a698ad7f by Steven S. Lyubomirsky Committed by Tianqi Chen

[Relay] Check match expressions for completeness (#3203)

parent 6e2c7ede
......@@ -123,6 +123,24 @@ TVM_DLL bool AlphaEqual(const Expr& e1, const Expr& e2);
TVM_DLL bool AlphaEqual(const Type& t1, const Type& t2);
/*!
* \brief Compare two patterns for structural equivalence.
*
* This comparison operator respects scoping and compares
* patterns without regard to variable choice.
*
* For example: `A(x, _, y)` is equal to `A(z, _, a)`.
*
* See https://en.wikipedia.org/wiki/Lambda_calculus#Alpha_equivalence
* for more details.
*
* \param t1 The left hand pattern.
* \param t2 The right hand pattern.
*
* \return true if equal, otherwise false
*/
TVM_DLL bool AlphaEqual(const Pattern& t1, const Pattern& t2);
/*!
* \brief Add abstraction over a function
*
* For example: `square` is transformed to
......@@ -400,8 +418,19 @@ TVM_DLL Expr ToANormalForm(const Expr& e, const Module& mod);
TVM_DLL Expr ToGraphNormalForm(const Expr& e);
/*!
* \brief Aggressive constant propagation/constant folding/inlining.
* \brief Finds cases that the given match expression does not catch, if any.
*
* \param match the match expression to test
*
* \param mod The module used for accessing global type var definitions, can be None.
*
* \return Returns a list of cases (as patterns) that are not handled by the match
* expression.
*/
TVM_DLL Array<Pattern> UnmatchedCases(const Match& match, const Module& mod);
/*!
* \brief Aggressive constant propagation/constant folding/inlining.
* It will do as much computation in compile time as possible.
* It has two benefit: remove runtime overhead, and allow more optimization (typically fusion).
* As a side effect, code size will explode.
......
......@@ -652,3 +652,21 @@ def partial_evaluate(expr):
The output expression.
"""
return _ir_pass.partial_evaluate(expr)
def unmatched_cases(match, mod=None):
"""
Finds cases that the match expression does not catch, if any.
Parameters
----------
match : tvm.relay.Match
The match expression
mod : Optional[tvm.relay.Module]
The module (defaults to an empty module)
Returns
-------
missing_patterns : [tvm.relay.Pattern]
Patterns that the match expression does not catch.
"""
return _ir_pass.unmatched_cases(match, mod)
......@@ -39,7 +39,6 @@ class Prelude:
self.cons = Constructor("cons", [a, self.l(a)], self.l)
self.mod[self.l] = TypeData(self.l, [a], [self.nil, self.cons])
def define_list_hd(self):
"""Defines a function to get the head of a list. Assume the list has at least one
element.
......@@ -54,7 +53,6 @@ class Prelude:
cons_case = Clause(PatternConstructor(self.cons, [PatternVar(y), PatternVar(z)]), y)
self.mod[self.hd] = Function([x], Match(x, [cons_case]), a, [a])
def define_list_tl(self):
"""Defines a function to get the tail of a list.
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2019 by Contributors
* \file match_exhaustion.cc
* \brief Checking Relay match expression exhaustiveness.
*
* This file implements a function that checks whether a match
* expression is exhaustive, that is, whether a given match clause
* matches every possible case. This is important for ensuring
* code correctness, since hitting an unmatched case results in a
* dynamic error unless exhaustiveness is checked in advance.
*/
#include <tvm/relay/adt.h>
#include <tvm/relay/error.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/pattern_functor.h>
#include <tvm/relay/pass.h>
#include <stack>
namespace tvm {
namespace relay {
/*! \brief Possible pattern match results */
enum MatchResult : int {
kMatch = 0, // pattern matches
kClash = 1, // pattern conflicts
kUnspecified = 2, // ambiguous: candidate needs more constructors specified
};
class CandidateChecker : public PatternFunctor<MatchResult(const Pattern&, const Pattern&)> {
public:
explicit CandidateChecker() {}
MatchResult Check(const Pattern& pat, const Pattern& candidate) {
return this->VisitPattern(pat, candidate);
}
// for a constructor pattern, we must ensure that the candidate is
// a ConstructorPattern, that it has the same constructor, and
// that its fields match the subpatterns.
MatchResult VisitPattern_(const PatternConstructorNode* op, const Pattern& cand) override {
auto* ctor_cand = cand.as<PatternConstructorNode>();
// attempting to match non-constructor to constructor pattern: need to specify
if (ctor_cand == nullptr) {
return MatchResult::kUnspecified;
}
// check that constructors match
if (!op->constructor.same_as(ctor_cand->constructor)) {
return MatchResult::kClash;
}
// now check that subpatterns match
CHECK(op->patterns.size() == ctor_cand->patterns.size());
bool unspecified = false;
for (size_t i = 0; i < op->patterns.size(); i++) {
MatchResult submatch = this->Check(op->patterns[i], ctor_cand->patterns[i]);
// if we have a clash anywhere, then we can return clash
if (submatch == MatchResult::kClash) {
return MatchResult::kClash;
}
if (submatch == MatchResult::kUnspecified) {
unspecified = true;
}
}
// only return unspecified if we have ruled out a clash
if (unspecified) {
return MatchResult::kUnspecified;
}
return MatchResult::kMatch;
}
// wildcard and var patterns always match
MatchResult VisitPattern_(const PatternWildcardNode*, const Pattern&) override {
return MatchResult::kMatch;
}
MatchResult VisitPattern_(const PatternVarNode*, const Pattern&) override {
return MatchResult::kMatch;
}
};
// Returns list of arrays corresponding to Cartesian product of input list
Array<Array<Pattern>> CartesianProduct(Array<Array<Pattern>> fields) {
CHECK_NE(fields.size(), 0);
Array<Pattern> field_vals = fields[fields.size() - 1];
Array<Array<Pattern>> ret;
// base case: this is the last field left
if (fields.size() == 1) {
for (auto val : field_vals) {
ret.push_back(Array<Pattern>{val});
}
return ret;
}
// if we have more fields left, get the sub-candidates by getting
// their cartesian product and appending the elements here onto those
Array<Array<Pattern>> remaining_fields;
for (size_t i = 0; i < fields.size() - 1; i++) {
remaining_fields.push_back(fields[i]);
}
Array<Array<Pattern>> candidates = CartesianProduct(remaining_fields);
for (auto val : field_vals) {
for (auto candidate : candidates) {
candidate.push_back(val);
ret.push_back(candidate);
}
}
return ret;
}
// Expands all wildcards in the candidate pattern once, using the pattern
// to decide which constructors to insert. Returns a list of all possible expansions.
Array<Pattern> ExpandWildcards(const Pattern& clause_pat, const Pattern& cand,
const Module& mod) {
auto ctor_cand = cand.as<PatternConstructorNode>();
PatternConstructor clause_ctor = Downcast<PatternConstructor>(clause_pat);
auto gtv = Downcast<GlobalTypeVar>(clause_ctor->constructor->belong_to);
// for a wildcard node, create constructor nodes with wildcards for all args
if (!ctor_cand) {
TypeData td = mod->LookupDef(gtv);
// for each constructor add a candidate
Array<Pattern> ret;
for (auto constructor : td->constructors) {
Array<Pattern> args;
for (auto inp : constructor->inputs) {
args.push_back(PatternWildcardNode::make());
}
ret.push_back(PatternConstructorNode::make(constructor, args));
}
return ret;
}
// for constructors, we will expand the wildcards in any field
// that is an ADT
Array<Array<Pattern>> values_by_field;
for (size_t i = 0; i < ctor_cand->constructor->inputs.size(); i++) {
auto* subpattern = clause_ctor->patterns[i].as<PatternConstructorNode>();
// for non-ADT fields, we can only have a wildcard for the value
if (!subpattern) {
values_by_field.push_back({PatternWildcardNode::make()});
continue;
}
// otherwise, recursively expand
values_by_field.push_back(ExpandWildcards(GetRef<Pattern>(subpattern),
ctor_cand->patterns[i], mod));
}
// generate new candidates using a cartesian product
auto all_subfields = CartesianProduct(values_by_field);
Array<Pattern> ret;
for (auto subfields : all_subfields) {
ret.push_back(PatternConstructorNode::make(ctor_cand->constructor, subfields));
}
return ret;
}
/*!
* \brief Finds cases that the match expression does not catch, if any.
* \return Returns a list of cases that are not handled by the match
* expression.
*/
Array<Pattern> UnmatchedCases(const Match& match, const Module& mod) {
/* algorithm:
* candidates = { Wildcard }
* while candidates not empty {
* cand = candidates.pop()
* for clause in clauses {
* if clause fails: next clause
* if clause matches candidate: next candidate
* if candidate is not specific enough:
* candidates += expand_possible_wildcards(cand)
* next candidate
* }
* failed_candidates += { cand }
* }
* return failed_candidates
*/
std::stack<Pattern> candidates;
candidates.push(PatternWildcardNode::make());
CandidateChecker checker;
Array<Pattern> failures;
while (!candidates.empty()) {
Pattern cand = candidates.top();
candidates.pop();
bool failure = true;
for (auto clause : match->clauses) {
// if the check fails, we move on to the next
MatchResult check = checker.Check(clause->lhs, cand);
if (check == MatchResult::kClash) {
continue;
}
// either success or we need to generate more candidates;
// either way, we're done with this candidate
failure = false;
if (check == MatchResult::kUnspecified) {
auto new_candidates = ExpandWildcards(clause->lhs, cand, mod);
for (auto candidate : new_candidates) {
candidates.push(candidate);
}
}
break;
}
if (failure) {
failures.push_back(cand);
}
}
return failures;
}
// expose for testing only
TVM_REGISTER_API("relay._ir_pass.unmatched_cases")
.set_body_typed<Array<Pattern>(const Match&,
const Module&)>([](const Match& match,
const Module& mod_ref) {
Module call_mod = mod_ref;
if (!call_mod.defined()) {
call_mod = ModuleNode::make({}, {});
}
return UnmatchedCases(match, call_mod);
});
} // namespace relay
} // namespace tvm
......@@ -293,6 +293,15 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
GetType(c->rhs),
op->span);
}
// check completness
Match match = GetRef<Match>(op);
Array<Pattern> unmatched_cases = UnmatchedCases(match, this->mod_);
if (unmatched_cases.size() != 0) {
LOG(WARNING) << "Match clause " << match << " does not handle the following cases: "
<< unmatched_cases;
}
return rtype;
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment