Commit a698ad7f by Steven S. Lyubomirsky Committed by Tianqi Chen

[Relay] Check match expressions for completeness (#3203)

parent 6e2c7ede
...@@ -123,6 +123,24 @@ TVM_DLL bool AlphaEqual(const Expr& e1, const Expr& e2); ...@@ -123,6 +123,24 @@ TVM_DLL bool AlphaEqual(const Expr& e1, const Expr& e2);
TVM_DLL bool AlphaEqual(const Type& t1, const Type& t2); TVM_DLL bool AlphaEqual(const Type& t1, const Type& t2);
/*! /*!
* \brief Compare two patterns for structural equivalence.
*
* This comparison operator respects scoping and compares
* patterns without regard to variable choice.
*
* For example: `A(x, _, y)` is equal to `A(z, _, a)`.
*
* See https://en.wikipedia.org/wiki/Lambda_calculus#Alpha_equivalence
* for more details.
*
* \param t1 The left hand pattern.
* \param t2 The right hand pattern.
*
* \return true if equal, otherwise false
*/
TVM_DLL bool AlphaEqual(const Pattern& t1, const Pattern& t2);
/*!
* \brief Add abstraction over a function * \brief Add abstraction over a function
* *
* For example: `square` is transformed to * For example: `square` is transformed to
...@@ -400,8 +418,19 @@ TVM_DLL Expr ToANormalForm(const Expr& e, const Module& mod); ...@@ -400,8 +418,19 @@ TVM_DLL Expr ToANormalForm(const Expr& e, const Module& mod);
TVM_DLL Expr ToGraphNormalForm(const Expr& e); TVM_DLL Expr ToGraphNormalForm(const Expr& e);
/*! /*!
* \brief Aggressive constant propagation/constant folding/inlining. * \brief Finds cases that the given match expression does not catch, if any.
*
* \param match the match expression to test
*
* \param mod The module used for accessing global type var definitions, can be None.
* *
* \return Returns a list of cases (as patterns) that are not handled by the match
* expression.
*/
TVM_DLL Array<Pattern> UnmatchedCases(const Match& match, const Module& mod);
/*!
* \brief Aggressive constant propagation/constant folding/inlining.
* It will do as much computation in compile time as possible. * It will do as much computation in compile time as possible.
* It has two benefit: remove runtime overhead, and allow more optimization (typically fusion). * It has two benefit: remove runtime overhead, and allow more optimization (typically fusion).
* As a side effect, code size will explode. * As a side effect, code size will explode.
......
...@@ -652,3 +652,21 @@ def partial_evaluate(expr): ...@@ -652,3 +652,21 @@ def partial_evaluate(expr):
The output expression. The output expression.
""" """
return _ir_pass.partial_evaluate(expr) return _ir_pass.partial_evaluate(expr)
def unmatched_cases(match, mod=None):
    """
    Report the cases that a match expression leaves unhandled.

    Parameters
    ----------
    match : tvm.relay.Match
        The match expression to inspect.

    mod : Optional[tvm.relay.Module]
        The module handed to the underlying pass. When omitted, the
        registered pass substitutes an empty module.

    Returns
    -------
    missing_patterns : [tvm.relay.Pattern]
        The patterns that none of the match clauses catch; empty when the
        match expression is exhaustive.
    """
    missing_patterns = _ir_pass.unmatched_cases(match, mod)
    return missing_patterns
...@@ -39,7 +39,6 @@ class Prelude: ...@@ -39,7 +39,6 @@ class Prelude:
self.cons = Constructor("cons", [a, self.l(a)], self.l) self.cons = Constructor("cons", [a, self.l(a)], self.l)
self.mod[self.l] = TypeData(self.l, [a], [self.nil, self.cons]) self.mod[self.l] = TypeData(self.l, [a], [self.nil, self.cons])
def define_list_hd(self): def define_list_hd(self):
"""Defines a function to get the head of a list. Assume the list has at least one """Defines a function to get the head of a list. Assume the list has at least one
element. element.
...@@ -54,7 +53,6 @@ class Prelude: ...@@ -54,7 +53,6 @@ class Prelude:
cons_case = Clause(PatternConstructor(self.cons, [PatternVar(y), PatternVar(z)]), y) cons_case = Clause(PatternConstructor(self.cons, [PatternVar(y), PatternVar(z)]), y)
self.mod[self.hd] = Function([x], Match(x, [cons_case]), a, [a]) self.mod[self.hd] = Function([x], Match(x, [cons_case]), a, [a])
def define_list_tl(self): def define_list_tl(self):
"""Defines a function to get the tail of a list. """Defines a function to get the tail of a list.
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2019 by Contributors
* \file match_exhaustion.cc
* \brief Checking Relay match expression exhaustiveness.
*
* This file implements a function that checks whether a match
* expression is exhaustive, that is, whether a given match clause
* matches every possible case. This is important for ensuring
* code correctness, since hitting an unmatched case results in a
* dynamic error unless exhaustiveness is checked in advance.
*/
#include <tvm/relay/adt.h>
#include <tvm/relay/error.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/pattern_functor.h>
#include <tvm/relay/pass.h>
#include <stack>
namespace tvm {
namespace relay {
/*!
 * \brief Possible outcomes of checking a clause pattern against a candidate pattern.
 *
 * Produced by CandidateChecker and consumed by UnmatchedCases to decide
 * whether a candidate is handled, skipped, or needs further expansion.
 */
enum MatchResult : int {
kMatch = 0, // the clause pattern accepts the candidate
kClash = 1, // the clause and the candidate use different constructors
kUnspecified = 2, // ambiguous: the clause needs a constructor where the candidate has none yet
};
/*!
 * \brief Classifies how a clause pattern relates to a candidate pattern.
 *
 * The first argument of every visit is the clause pattern (dispatched on),
 * the second is the candidate being tested against it.
 */
class CandidateChecker : public PatternFunctor<MatchResult(const Pattern&, const Pattern&)> {
 public:
  explicit CandidateChecker() {}

  /*!
   * \brief Entry point: compare a clause pattern with a candidate.
   * \param pat The clause pattern.
   * \param candidate The candidate pattern.
   * \return kMatch, kClash or kUnspecified.
   */
  MatchResult Check(const Pattern& pat, const Pattern& candidate) {
    return this->VisitPattern(pat, candidate);
  }

  // A constructor clause accepts a candidate only if the candidate is a
  // constructor pattern with the same constructor and every field is accepted.
  MatchResult VisitPattern_(const PatternConstructorNode* op, const Pattern& cand) override {
    auto* ctor_cand = cand.as<PatternConstructorNode>();
    if (ctor_cand == nullptr) {
      // the candidate is not a constructor pattern: it must be refined first
      return MatchResult::kUnspecified;
    }
    if (!op->constructor.same_as(ctor_cand->constructor)) {
      // different constructors can never match
      return MatchResult::kClash;
    }

    CHECK(op->patterns.size() == ctor_cand->patterns.size());
    const size_t arity = op->patterns.size();

    // A clash in any field wins immediately; otherwise a single unspecified
    // field downgrades the overall answer from match to unspecified.
    MatchResult overall = MatchResult::kMatch;
    for (size_t idx = 0; idx < arity; ++idx) {
      switch (this->Check(op->patterns[idx], ctor_cand->patterns[idx])) {
        case MatchResult::kClash:
          return MatchResult::kClash;
        case MatchResult::kUnspecified:
          overall = MatchResult::kUnspecified;
          break;
        case MatchResult::kMatch:
          break;
      }
    }
    return overall;
  }

  // A wildcard clause accepts every candidate.
  MatchResult VisitPattern_(const PatternWildcardNode*, const Pattern&) override {
    return MatchResult::kMatch;
  }

  // A variable clause accepts every candidate as well.
  MatchResult VisitPattern_(const PatternVarNode*, const Pattern&) override {
    return MatchResult::kMatch;
  }
};
/*!
 * \brief Computes the Cartesian product of a list of per-field alternatives.
 *
 * \param fields fields[i] lists the possible patterns for position i.
 *
 * \return Every combination that picks one pattern per field, each as an
 *  array ordered like the input fields. The last field varies slowest and
 *  the first field fastest. For zero fields the result is a single empty
 *  combination (needed for constructors without arguments); if any field
 *  has no alternatives the result is empty.
 */
Array<Array<Pattern>> CartesianProduct(Array<Array<Pattern>> fields) {
  // Start from the one combination of zero fields and extend it field by field.
  Array<Array<Pattern>> ret;
  ret.push_back(Array<Pattern>());
  for (auto field_vals : fields) {
    Array<Array<Pattern>> extended;
    for (auto val : field_vals) {
      for (auto partial : ret) {
        // partial shares storage with the element held in ret, so push_back
        // copies before writing and leaves ret untouched
        partial.push_back(val);
        extended.push_back(partial);
      }
    }
    ret = extended;
  }
  return ret;
}
/*!
 * \brief Refines a candidate pattern by one level, guided by a clause pattern.
 *
 * Wherever the clause has a constructor pattern and the candidate has not
 * committed to a constructor yet, the candidate is split into one pattern per
 * constructor of the ADT. Positions where the clause is a wildcard or a
 * variable keep whatever the candidate already has there, so refinements made
 * against earlier clauses are never lost.
 *
 * \param clause_pat The clause pattern; must be a constructor pattern.
 * \param cand The candidate pattern (a wildcard or a constructor pattern whose
 *  constructors agree with clause_pat wherever both are specified).
 * \param mod The module used to look up the ADT definition.
 *
 * \return All one-step expansions of the candidate.
 */
Array<Pattern> ExpandWildcards(const Pattern& clause_pat, const Pattern& cand,
                               const Module& mod) {
  auto ctor_cand = cand.as<PatternConstructorNode>();
  PatternConstructor clause_ctor = Downcast<PatternConstructor>(clause_pat);

  // Candidate has no constructor yet: produce one candidate per constructor
  // of the ADT, with wildcards for all of its arguments.
  if (!ctor_cand) {
    CHECK(mod.defined())
      << "A module is required to enumerate the constructors of an ADT";
    auto gtv = Downcast<GlobalTypeVar>(clause_ctor->constructor->belong_to);
    TypeData td = mod->LookupDef(gtv);
    Array<Pattern> ret;
    for (auto constructor : td->constructors) {
      Array<Pattern> args;
      for (size_t i = 0; i < constructor->inputs.size(); i++) {
        args.push_back(PatternWildcardNode::make());
      }
      ret.push_back(PatternConstructorNode::make(constructor, args));
    }
    return ret;
  }

  const size_t num_fields = ctor_cand->patterns.size();
  CHECK_EQ(clause_ctor->patterns.size(), num_fields);

  // A constructor without arguments (e.g. nil) is already fully specified.
  if (num_fields == 0) {
    Array<Pattern> unchanged;
    unchanged.push_back(cand);
    return unchanged;
  }

  // Otherwise expand field by field.
  Array<Array<Pattern>> values_by_field;
  for (size_t i = 0; i < num_fields; i++) {
    Pattern clause_field = clause_ctor->patterns[i];
    Pattern cand_field = ctor_cand->patterns[i];
    if (clause_field.as<PatternConstructorNode>() == nullptr) {
      // The clause does not constrain this field: keep the candidate's own
      // subpattern instead of resetting it to a wildcard.
      Array<Pattern> keep;
      keep.push_back(cand_field);
      values_by_field.push_back(keep);
      continue;
    }
    // The clause demands a constructor here: recursively expand.
    values_by_field.push_back(ExpandWildcards(clause_field, cand_field, mod));
  }

  // Combine the per-field alternatives into full candidates.
  Array<Pattern> ret;
  for (auto subfields : CartesianProduct(values_by_field)) {
    ret.push_back(PatternConstructorNode::make(ctor_cand->constructor, subfields));
  }
  return ret;
}
/*!
 * \brief Finds cases that the match expression does not catch, if any.
 *
 * Candidates are kept on a stack that starts with a single wildcard. Each
 * candidate is tried against the clauses in order:
 *  - a clause that clashes is skipped;
 *  - a clause that matches settles the candidate;
 *  - a clause for which the candidate is not specific enough replaces the
 *    candidate by its expansions, which are pushed back onto the stack.
 * A candidate that clashes with every clause is reported as unmatched.
 *
 * \param match The match expression to test.
 * \param mod The module used to look up ADT definitions.
 *
 * \return Returns a list of cases that are not handled by the match
 *  expression.
 */
Array<Pattern> UnmatchedCases(const Match& match, const Module& mod) {
  CandidateChecker checker;
  Array<Pattern> failures;

  std::stack<Pattern> worklist;
  worklist.push(PatternWildcardNode::make());

  const size_t num_clauses = match->clauses.size();
  while (!worklist.empty()) {
    const Pattern current = worklist.top();
    worklist.pop();

    // becomes true as soon as one clause matches or triggers an expansion
    bool handled = false;
    for (size_t idx = 0; idx < num_clauses && !handled; ++idx) {
      const auto clause = match->clauses[idx];
      const MatchResult outcome = checker.Check(clause->lhs, current);
      if (outcome == MatchResult::kClash) {
        // this clause cannot take the candidate; try the next one
        continue;
      }
      handled = true;
      if (outcome == MatchResult::kUnspecified) {
        // the refined candidates stand in for the current one
        for (auto refined : ExpandWildcards(clause->lhs, current, mod)) {
          worklist.push(refined);
        }
      }
    }

    if (!handled) {
      failures.push_back(current);
    }
  }
  return failures;
}
// Exposed for testing only. An undefined module argument is replaced by a
// freshly made empty module before the pass runs.
TVM_REGISTER_API("relay._ir_pass.unmatched_cases")
.set_body_typed<Array<Pattern>(const Match&, const Module&)>(
    [](const Match& match_expr, const Module& maybe_mod) {
      if (maybe_mod.defined()) {
        return UnmatchedCases(match_expr, maybe_mod);
      }
      return UnmatchedCases(match_expr, ModuleNode::make({}, {}));
    });
} // namespace relay
} // namespace tvm
...@@ -293,6 +293,15 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>, ...@@ -293,6 +293,15 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
GetType(c->rhs), GetType(c->rhs),
op->span); op->span);
} }
// check completeness
Match match = GetRef<Match>(op);
Array<Pattern> unmatched_cases = UnmatchedCases(match, this->mod_);
if (unmatched_cases.size() != 0) {
LOG(WARNING) << "Match clause " << match << " does not handle the following cases: "
<< unmatched_cases;
}
return rtype; return rtype;
} }
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
from tvm import relay
from tvm.relay.prelude import Prelude
from tvm.relay.ir_pass import unmatched_cases
def test_empty_match_block():
    """A match without clauses handles nothing, so the only unmatched case is a wildcard."""
    scrutinee = relay.Var('v')
    missing = unmatched_cases(relay.Match(scrutinee, []))
    assert len(missing) == 1
    assert isinstance(missing[0], relay.PatternWildcard)
def test_trivial_matches():
    """A single wildcard clause or a single variable clause is exhaustive."""
    scrutinee = relay.Var('v')

    # a wildcard clause catches everything
    wildcard_match = relay.Match(
        scrutinee, [relay.Clause(relay.PatternWildcard(), scrutinee)])
    assert len(unmatched_cases(wildcard_match)) == 0

    # a pattern variable catches everything too
    bound = relay.Var('w')
    var_match = relay.Match(
        scrutinee, [relay.Clause(relay.PatternVar(bound), bound)])
    assert len(unmatched_cases(var_match)) == 0
def test_single_constructor_adt():
    """For an ADT with one constructor, one constructor clause is exhaustive."""
    mod = relay.Module()
    box = relay.GlobalTypeVar('box')
    a = relay.TypeVar('a')
    box_ctor = relay.Constructor('box', [a], box)
    mod[box] = relay.TypeData(box, [a], [box_ctor])

    v = relay.Var('v')

    # box(_) covers every value of the type
    flat_pattern = relay.PatternConstructor(box_ctor, [relay.PatternWildcard()])
    flat_match = relay.Match(v, [relay.Clause(flat_pattern, v)])
    assert len(unmatched_cases(flat_match, mod)) == 0

    # nesting does not change that: box(box(box(_)))
    nested = relay.PatternWildcard()
    for _ in range(3):
        nested = relay.PatternConstructor(box_ctor, [nested])
    nested_pattern = relay.Match(v, [relay.Clause(nested, v)])
    assert len(unmatched_cases(nested_pattern, mod)) == 0
def test_too_specific_match():
    """Matching only lists of length >= 2 misses nil and single-element lists."""
    mod = relay.Module()
    p = Prelude(mod)
    v = relay.Var('v')

    def two_or_more():
        # cons(_, cons(_, _)), built from fresh nodes on every call
        inner = relay.PatternConstructor(
            p.cons, [relay.PatternWildcard(), relay.PatternWildcard()])
        return relay.PatternConstructor(p.cons, [relay.PatternWildcard(), inner])

    match = relay.Match(v, [relay.Clause(two_or_more(), v)])
    unmatched = unmatched_cases(match, mod)

    # expect exactly: nil, and cons(_, nil)
    seen_nil = False
    seen_singleton = False
    assert len(unmatched) == 2
    for case in unmatched:
        assert isinstance(case, relay.PatternConstructor)
        if case.constructor == p.nil:
            seen_nil = True
        if case.constructor == p.cons:
            tail = case.patterns[1]
            assert isinstance(tail, relay.PatternConstructor)
            assert tail.constructor == p.nil
            seen_singleton = True
    assert seen_nil and seen_singleton

    # a trailing wildcard clause makes the match exhaustive
    new_match = relay.Match(v, [
        relay.Clause(two_or_more(), v),
        relay.Clause(relay.PatternWildcard(), v),
    ])
    assert len(unmatched_cases(new_match, mod)) == 0
def test_multiple_constructor_clauses():
    """Overlapping list clauses that together cover every length are exhaustive."""
    mod = relay.Module()
    p = Prelude(mod)
    v = relay.Var('v')

    def wild():
        return relay.PatternWildcard()

    def nil():
        return relay.PatternConstructor(p.nil, [])

    def cons(tail):
        # cons(_, tail)
        return relay.PatternConstructor(p.cons, [wild(), tail])

    patterns = [
        cons(nil()),         # list of length exactly 1
        cons(cons(nil())),   # list of length exactly 2
        nil(),               # empty list
        cons(cons(wild())),  # list of length 2 or more
    ]
    match = relay.Match(v, [relay.Clause(pat, v) for pat in patterns])
    assert len(unmatched_cases(match, mod)) == 0
def test_missing_in_the_middle():
    """Covering lengths 0, 1 and >= 3 leaves exactly the length-2 list unmatched."""
    mod = relay.Module()
    p = Prelude(mod)
    v = relay.Var('v')

    def wild():
        return relay.PatternWildcard()

    def nil():
        return relay.PatternConstructor(p.nil, [])

    def cons(tail):
        # cons(_, tail)
        return relay.PatternConstructor(p.cons, [wild(), tail])

    patterns = [
        cons(nil()),               # list of length exactly 1
        nil(),                     # empty list
        cons(cons(cons(wild()))),  # list of length 3 or more
    ]
    match = relay.Match(v, [relay.Clause(pat, v) for pat in patterns])

    # the only gap is cons(_, cons(_, nil))
    unmatched = unmatched_cases(match, mod)
    assert len(unmatched) == 1
    outer = unmatched[0]
    assert isinstance(outer, relay.PatternConstructor)
    assert outer.constructor == p.cons
    middle = outer.patterns[1]
    assert isinstance(middle, relay.PatternConstructor)
    assert middle.constructor == p.cons
    last = middle.patterns[1]
    assert isinstance(last, relay.PatternConstructor)
    assert last.constructor == p.nil
def test_mixed_adt_constructors():
    """Exhaustiveness checks on patterns that nest two different ADTs (box and list)."""
    mod = relay.Module()
    box = relay.GlobalTypeVar('box')
    a = relay.TypeVar('a')
    box_ctor = relay.Constructor('box', [a], box)
    mod[box] = relay.TypeData(box, [a], [box_ctor])

    p = Prelude(mod)
    v = relay.Var('v')

    def wild():
        return relay.PatternWildcard()

    def nil():
        return relay.PatternConstructor(p.nil, [])

    def cons(head, tail):
        return relay.PatternConstructor(p.cons, [head, tail])

    def boxed(inner):
        return relay.PatternConstructor(box_ctor, [inner])

    def match_on(patterns):
        return relay.Match(v, [relay.Clause(pat, v) for pat in patterns])

    # only a box holding a non-empty list is handled
    box_of_lists_inc = match_on([boxed(cons(wild(), wild()))])
    # will fail to match a box containing an empty list
    unmatched = unmatched_cases(box_of_lists_inc, mod)
    assert len(unmatched) == 1
    assert isinstance(unmatched[0], relay.PatternConstructor)
    assert unmatched[0].constructor == box_ctor
    assert len(unmatched[0].patterns) == 1 and unmatched[0].patterns[0].constructor == p.nil

    # both list constructors inside the box: exhaustive
    box_of_lists_comp = match_on([
        boxed(nil()),
        boxed(cons(wild(), wild())),
    ])
    assert len(unmatched_cases(box_of_lists_comp, mod)) == 0

    # only a non-empty list whose head is a box is handled
    list_of_boxes_inc = match_on([cons(boxed(wild()), wild())])
    # fails to match empty list of boxes
    unmatched = unmatched_cases(list_of_boxes_inc, mod)
    assert len(unmatched) == 1
    assert isinstance(unmatched[0], relay.PatternConstructor)
    assert unmatched[0].constructor == p.nil

    list_of_boxes_comp = match_on([
        # exactly one box
        cons(boxed(wild()), nil()),
        # exactly two boxes
        cons(boxed(wild()), cons(boxed(wild()), nil())),
        # exactly three boxes
        cons(boxed(wild()), cons(boxed(wild()), cons(boxed(wild()), nil()))),
        # one or more boxes
        cons(wild(), wild()),
        # no boxes
        nil(),
    ])
    assert len(unmatched_cases(list_of_boxes_comp, mod)) == 0
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment