Unverified Commit b422f6a9 by 雾雨魔理沙 Committed by GitHub

[WIP] Fixing an Infinite Loop case in UnmatchedChecker. (#4881)

* save

* save

* remove

* remove cerr
parent 545f6ea3
...@@ -168,8 +168,10 @@ Array<Pattern> ExpandWildcards(const Pattern& clause_pat, ...@@ -168,8 +168,10 @@ Array<Pattern> ExpandWildcards(const Pattern& clause_pat,
const IRModule& mod) { const IRModule& mod) {
if (auto clause_ctor = clause_pat.as<PatternConstructorNode>()) { if (auto clause_ctor = clause_pat.as<PatternConstructorNode>()) {
return ExpandWildcardsConstructor(GetRef<PatternConstructor>(clause_ctor), cand, mod); return ExpandWildcardsConstructor(GetRef<PatternConstructor>(clause_ctor), cand, mod);
} else if (auto clause_tup = clause_pat.as<PatternTupleNode>()) {
return ExpandWildcardsTuple(GetRef<PatternTuple>(clause_tup), cand, mod);
} else { } else {
return ExpandWildcardsTuple(Downcast<PatternTuple>(clause_pat), cand, mod); return {cand};
} }
} }
...@@ -201,18 +203,9 @@ Array<Pattern> ExpandWildcardsConstructor(const PatternConstructor& clause_ctor, ...@@ -201,18 +203,9 @@ Array<Pattern> ExpandWildcardsConstructor(const PatternConstructor& clause_ctor,
// for constructors, we will expand the wildcards in any field that is an ADT. // for constructors, we will expand the wildcards in any field that is an ADT.
Array<Array<Pattern>> values_by_field; Array<Array<Pattern>> values_by_field;
for (size_t i = 0; i < ctor_cand->constructor->inputs.size(); i++) { for (size_t i = 0; i < ctor_cand->constructor->inputs.size(); i++) {
bool subpattern = values_by_field.push_back(ExpandWildcards(clause_ctor->patterns[i],
clause_ctor->patterns[i].as<PatternConstructorNode>() || ctor_cand->patterns[i],
clause_ctor->patterns[i].as<PatternTupleNode>(); mod));
// for non-ADT fields, we can only have a wildcard for the value.
if (!subpattern) {
values_by_field.push_back({PatternWildcardNode::make()});
} else {
// otherwise, recursively expand.
values_by_field.push_back(ExpandWildcards(clause_ctor->patterns[i],
ctor_cand->patterns[i],
mod));
}
} }
// generate new candidates using a cartesian product. // generate new candidates using a cartesian product.
...@@ -243,18 +236,9 @@ Array<Pattern> ExpandWildcardsTuple(const PatternTuple& clause_tuple, ...@@ -243,18 +236,9 @@ Array<Pattern> ExpandWildcardsTuple(const PatternTuple& clause_tuple,
// for constructors, we will expand the wildcards in any field that is an ADT. // for constructors, we will expand the wildcards in any field that is an ADT.
Array<Array<Pattern>> values_by_field; Array<Array<Pattern>> values_by_field;
for (size_t i = 0; i < tuple_cand->patterns.size(); i++) { for (size_t i = 0; i < tuple_cand->patterns.size(); i++) {
bool subpattern = values_by_field.push_back(ExpandWildcards(clause_tuple->patterns[i],
clause_tuple->patterns[i].as<PatternConstructorNode>() || tuple_cand->patterns[i],
clause_tuple->patterns[i].as<PatternTupleNode>(); mod));
// for non-ADT fields, we can only have a wildcard for the value
if (!subpattern) {
values_by_field.push_back({PatternWildcardNode::make()});
} else {
// otherwise, recursively expand
values_by_field.push_back(ExpandWildcards(clause_tuple->patterns[i],
tuple_cand->patterns[i],
mod));
}
} }
// generate new candidates using a cartesian product // generate new candidates using a cartesian product
......
...@@ -19,6 +19,7 @@ import tvm ...@@ -19,6 +19,7 @@ import tvm
from tvm import relay from tvm import relay
from tvm.relay.prelude import Prelude from tvm.relay.prelude import Prelude
from tvm.relay.analysis import unmatched_cases from tvm.relay.analysis import unmatched_cases
import pytest
def test_empty_match_block(): def test_empty_match_block():
# empty match block will not match anything, so it should return a wildcard pattern # empty match block will not match anything, so it should return a wildcard pattern
...@@ -273,3 +274,27 @@ def test_tuple_match(): ...@@ -273,3 +274,27 @@ def test_tuple_match():
clause = relay.Clause(relay.PatternTuple([relay.PatternVar(a), relay.PatternVar(b)]), a + b) clause = relay.Clause(relay.PatternTuple([relay.PatternVar(a), relay.PatternVar(b)]), a + b)
x = relay.Match(relay.Tuple([relay.const(1), relay.const(1)]), [clause]) x = relay.Match(relay.Tuple([relay.const(1), relay.const(1)]), [clause])
assert len(unmatched_cases(x)) == 0 assert len(unmatched_cases(x)) == 0
def test_inf_loop_case():
code = """
v0.0.4
type Arith[A] {
Zero,
Const(A),
Plus(Arith[A], Arith[A])
}
def @shallow_opt[A](%a: Arith[A]) -> Arith[A] {
match (%a) {
Plus(Zero, %r) => %r,
Plus(%l, Zero) => %l,
_ => %a
}
}
"""
relay.fromtext(code)
# fromtext parse the module, then checked it (which include strictness checking).
if __name__ == "__main__":
pytest.main([__file__])
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment