Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
a698ad7f
Commit
a698ad7f
authored
Jun 13, 2019
by
Steven S. Lyubomirsky
Committed by
Tianqi Chen
Jun 13, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[Relay] Check match expressions for completeness (#3203)
parent
6e2c7ede
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
574 additions
and
3 deletions
+574
-3
include/tvm/relay/pass.h
+30
-1
python/tvm/relay/ir_pass.py
+18
-0
python/tvm/relay/prelude.py
+0
-2
src/relay/pass/match_exhaustion.cc
+250
-0
src/relay/pass/type_infer.cc
+9
-0
tests/python/relay/test_pass_unmatched_cases.py
+267
-0
No files found.
include/tvm/relay/pass.h
View file @
a698ad7f
...
@@ -123,6 +123,24 @@ TVM_DLL bool AlphaEqual(const Expr& e1, const Expr& e2);
...
@@ -123,6 +123,24 @@ TVM_DLL bool AlphaEqual(const Expr& e1, const Expr& e2);
TVM_DLL
bool
AlphaEqual
(
const
Type
&
t1
,
const
Type
&
t2
);
TVM_DLL
bool
AlphaEqual
(
const
Type
&
t1
,
const
Type
&
t2
);
/*!
/*!
* \brief Compare two patterns for structural equivalence.
*
* This comparison operator respects scoping and compares
* patterns without regard to variable choice.
*
* For example: `A(x, _, y)` is equal to `A(z, _, a)`.
*
* See https://en.wikipedia.org/wiki/Lambda_calculus#Alpha_equivalence
* for more details.
*
* \param t1 The left hand pattern.
* \param t2 The right hand pattern.
*
* \return true if equal, otherwise false
*/
TVM_DLL
bool
AlphaEqual
(
const
Pattern
&
t1
,
const
Pattern
&
t2
);
/*!
* \brief Add abstraction over a function
* \brief Add abstraction over a function
*
*
* For example: `square` is transformed to
* For example: `square` is transformed to
...
@@ -400,8 +418,19 @@ TVM_DLL Expr ToANormalForm(const Expr& e, const Module& mod);
...
@@ -400,8 +418,19 @@ TVM_DLL Expr ToANormalForm(const Expr& e, const Module& mod);
TVM_DLL
Expr
ToGraphNormalForm
(
const
Expr
&
e
);
TVM_DLL
Expr
ToGraphNormalForm
(
const
Expr
&
e
);
/*!
/*!
* \brief Aggressive constant propagation/constant folding/inlining.
* \brief Finds cases that the given match expression does not catch, if any.
*
* \param match the match expression to test
*
* \param mod The module used for accessing global type var definitions, can be None.
*
*
* \return Returns a list of cases (as patterns) that are not handled by the match
* expression.
*/
TVM_DLL
Array
<
Pattern
>
UnmatchedCases
(
const
Match
&
match
,
const
Module
&
mod
);
/*!
* \brief Aggressive constant propagation/constant folding/inlining.
* It will do as much computation in compile time as possible.
* It will do as much computation in compile time as possible.
* It has two benefit: remove runtime overhead, and allow more optimization (typically fusion).
* It has two benefit: remove runtime overhead, and allow more optimization (typically fusion).
* As a side effect, code size will explode.
* As a side effect, code size will explode.
...
...
python/tvm/relay/ir_pass.py
View file @
a698ad7f
...
@@ -652,3 +652,21 @@ def partial_evaluate(expr):
...
@@ -652,3 +652,21 @@ def partial_evaluate(expr):
The output expression.
The output expression.
"""
"""
return
_ir_pass
.
partial_evaluate
(
expr
)
return
_ir_pass
.
partial_evaluate
(
expr
)
def
unmatched_cases
(
match
,
mod
=
None
):
"""
Finds cases that the match expression does not catch, if any.
Parameters
----------
match : tvm.relay.Match
The match expression
mod : Optional[tvm.relay.Module]
The module (defaults to an empty module)
Returns
-------
missing_patterns : [tvm.relay.Pattern]
Patterns that the match expression does not catch.
"""
return
_ir_pass
.
unmatched_cases
(
match
,
mod
)
python/tvm/relay/prelude.py
View file @
a698ad7f
...
@@ -39,7 +39,6 @@ class Prelude:
...
@@ -39,7 +39,6 @@ class Prelude:
self
.
cons
=
Constructor
(
"cons"
,
[
a
,
self
.
l
(
a
)],
self
.
l
)
self
.
cons
=
Constructor
(
"cons"
,
[
a
,
self
.
l
(
a
)],
self
.
l
)
self
.
mod
[
self
.
l
]
=
TypeData
(
self
.
l
,
[
a
],
[
self
.
nil
,
self
.
cons
])
self
.
mod
[
self
.
l
]
=
TypeData
(
self
.
l
,
[
a
],
[
self
.
nil
,
self
.
cons
])
def
define_list_hd
(
self
):
def
define_list_hd
(
self
):
"""Defines a function to get the head of a list. Assume the list has at least one
"""Defines a function to get the head of a list. Assume the list has at least one
element.
element.
...
@@ -54,7 +53,6 @@ class Prelude:
...
@@ -54,7 +53,6 @@ class Prelude:
cons_case
=
Clause
(
PatternConstructor
(
self
.
cons
,
[
PatternVar
(
y
),
PatternVar
(
z
)]),
y
)
cons_case
=
Clause
(
PatternConstructor
(
self
.
cons
,
[
PatternVar
(
y
),
PatternVar
(
z
)]),
y
)
self
.
mod
[
self
.
hd
]
=
Function
([
x
],
Match
(
x
,
[
cons_case
]),
a
,
[
a
])
self
.
mod
[
self
.
hd
]
=
Function
([
x
],
Match
(
x
,
[
cons_case
]),
a
,
[
a
])
def
define_list_tl
(
self
):
def
define_list_tl
(
self
):
"""Defines a function to get the tail of a list.
"""Defines a function to get the tail of a list.
...
...
src/relay/pass/match_exhaustion.cc
0 → 100644
View file @
a698ad7f
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2019 by Contributors
* \file match_exhaustion.cc
* \brief Checking Relay match expression exhaustiveness.
*
* This file implements a function that checks whether a match
* expression is exhaustive, that is, whether a given match clause
* matches every possible case. This is important for ensuring
* code correctness, since hitting an unmatched case results in a
* dynamic error unless exhaustiveness is checked in advance.
*/
#include <tvm/relay/adt.h>
#include <tvm/relay/error.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/pattern_functor.h>
#include <tvm/relay/pass.h>
#include <stack>
namespace
tvm
{
namespace
relay
{
/*! \brief Possible pattern match results */
enum
MatchResult
:
int
{
kMatch
=
0
,
// pattern matches
kClash
=
1
,
// pattern conflicts
kUnspecified
=
2
,
// ambiguous: candidate needs more constructors specified
};
class
CandidateChecker
:
public
PatternFunctor
<
MatchResult
(
const
Pattern
&
,
const
Pattern
&
)
>
{
public
:
explicit
CandidateChecker
()
{}
MatchResult
Check
(
const
Pattern
&
pat
,
const
Pattern
&
candidate
)
{
return
this
->
VisitPattern
(
pat
,
candidate
);
}
// for a constructor pattern, we must ensure that the candidate is
// a ConstructorPattern, that it has the same constructor, and
// that its fields match the subpatterns.
MatchResult
VisitPattern_
(
const
PatternConstructorNode
*
op
,
const
Pattern
&
cand
)
override
{
auto
*
ctor_cand
=
cand
.
as
<
PatternConstructorNode
>
();
// attempting to match non-constructor to constructor pattern: need to specify
if
(
ctor_cand
==
nullptr
)
{
return
MatchResult
::
kUnspecified
;
}
// check that constructors match
if
(
!
op
->
constructor
.
same_as
(
ctor_cand
->
constructor
))
{
return
MatchResult
::
kClash
;
}
// now check that subpatterns match
CHECK
(
op
->
patterns
.
size
()
==
ctor_cand
->
patterns
.
size
());
bool
unspecified
=
false
;
for
(
size_t
i
=
0
;
i
<
op
->
patterns
.
size
();
i
++
)
{
MatchResult
submatch
=
this
->
Check
(
op
->
patterns
[
i
],
ctor_cand
->
patterns
[
i
]);
// if we have a clash anywhere, then we can return clash
if
(
submatch
==
MatchResult
::
kClash
)
{
return
MatchResult
::
kClash
;
}
if
(
submatch
==
MatchResult
::
kUnspecified
)
{
unspecified
=
true
;
}
}
// only return unspecified if we have ruled out a clash
if
(
unspecified
)
{
return
MatchResult
::
kUnspecified
;
}
return
MatchResult
::
kMatch
;
}
// wildcard and var patterns always match
MatchResult
VisitPattern_
(
const
PatternWildcardNode
*
,
const
Pattern
&
)
override
{
return
MatchResult
::
kMatch
;
}
MatchResult
VisitPattern_
(
const
PatternVarNode
*
,
const
Pattern
&
)
override
{
return
MatchResult
::
kMatch
;
}
};
// Returns list of arrays corresponding to Cartesian product of input list
Array
<
Array
<
Pattern
>>
CartesianProduct
(
Array
<
Array
<
Pattern
>>
fields
)
{
CHECK_NE
(
fields
.
size
(),
0
);
Array
<
Pattern
>
field_vals
=
fields
[
fields
.
size
()
-
1
];
Array
<
Array
<
Pattern
>>
ret
;
// base case: this is the last field left
if
(
fields
.
size
()
==
1
)
{
for
(
auto
val
:
field_vals
)
{
ret
.
push_back
(
Array
<
Pattern
>
{
val
});
}
return
ret
;
}
// if we have more fields left, get the sub-candidates by getting
// their cartesian product and appending the elements here onto those
Array
<
Array
<
Pattern
>>
remaining_fields
;
for
(
size_t
i
=
0
;
i
<
fields
.
size
()
-
1
;
i
++
)
{
remaining_fields
.
push_back
(
fields
[
i
]);
}
Array
<
Array
<
Pattern
>>
candidates
=
CartesianProduct
(
remaining_fields
);
for
(
auto
val
:
field_vals
)
{
for
(
auto
candidate
:
candidates
)
{
candidate
.
push_back
(
val
);
ret
.
push_back
(
candidate
);
}
}
return
ret
;
}
// Expands all wildcards in the candidate pattern once, using the pattern
// to decide which constructors to insert. Returns a list of all possible expansions.
Array
<
Pattern
>
ExpandWildcards
(
const
Pattern
&
clause_pat
,
const
Pattern
&
cand
,
const
Module
&
mod
)
{
auto
ctor_cand
=
cand
.
as
<
PatternConstructorNode
>
();
PatternConstructor
clause_ctor
=
Downcast
<
PatternConstructor
>
(
clause_pat
);
auto
gtv
=
Downcast
<
GlobalTypeVar
>
(
clause_ctor
->
constructor
->
belong_to
);
// for a wildcard node, create constructor nodes with wildcards for all args
if
(
!
ctor_cand
)
{
TypeData
td
=
mod
->
LookupDef
(
gtv
);
// for each constructor add a candidate
Array
<
Pattern
>
ret
;
for
(
auto
constructor
:
td
->
constructors
)
{
Array
<
Pattern
>
args
;
for
(
auto
inp
:
constructor
->
inputs
)
{
args
.
push_back
(
PatternWildcardNode
::
make
());
}
ret
.
push_back
(
PatternConstructorNode
::
make
(
constructor
,
args
));
}
return
ret
;
}
// for constructors, we will expand the wildcards in any field
// that is an ADT
Array
<
Array
<
Pattern
>>
values_by_field
;
for
(
size_t
i
=
0
;
i
<
ctor_cand
->
constructor
->
inputs
.
size
();
i
++
)
{
auto
*
subpattern
=
clause_ctor
->
patterns
[
i
].
as
<
PatternConstructorNode
>
();
// for non-ADT fields, we can only have a wildcard for the value
if
(
!
subpattern
)
{
values_by_field
.
push_back
({
PatternWildcardNode
::
make
()});
continue
;
}
// otherwise, recursively expand
values_by_field
.
push_back
(
ExpandWildcards
(
GetRef
<
Pattern
>
(
subpattern
),
ctor_cand
->
patterns
[
i
],
mod
));
}
// generate new candidates using a cartesian product
auto
all_subfields
=
CartesianProduct
(
values_by_field
);
Array
<
Pattern
>
ret
;
for
(
auto
subfields
:
all_subfields
)
{
ret
.
push_back
(
PatternConstructorNode
::
make
(
ctor_cand
->
constructor
,
subfields
));
}
return
ret
;
}
/*!
* \brief Finds cases that the match expression does not catch, if any.
* \return Returns a list of cases that are not handled by the match
* expression.
*/
Array
<
Pattern
>
UnmatchedCases
(
const
Match
&
match
,
const
Module
&
mod
)
{
/* algorithm:
* candidates = { Wildcard }
* while candidates not empty {
* cand = candidates.pop()
* for clause in clauses {
* if clause fails: next clause
* if clause matches candidate: next candidate
* if candidate is not specific enough:
* candidates += expand_possible_wildcards(cand)
* next candidate
* }
* failed_candidates += { cand }
* }
* return failed_candidates
*/
std
::
stack
<
Pattern
>
candidates
;
candidates
.
push
(
PatternWildcardNode
::
make
());
CandidateChecker
checker
;
Array
<
Pattern
>
failures
;
while
(
!
candidates
.
empty
())
{
Pattern
cand
=
candidates
.
top
();
candidates
.
pop
();
bool
failure
=
true
;
for
(
auto
clause
:
match
->
clauses
)
{
// if the check fails, we move on to the next
MatchResult
check
=
checker
.
Check
(
clause
->
lhs
,
cand
);
if
(
check
==
MatchResult
::
kClash
)
{
continue
;
}
// either success or we need to generate more candidates;
// either way, we're done with this candidate
failure
=
false
;
if
(
check
==
MatchResult
::
kUnspecified
)
{
auto
new_candidates
=
ExpandWildcards
(
clause
->
lhs
,
cand
,
mod
);
for
(
auto
candidate
:
new_candidates
)
{
candidates
.
push
(
candidate
);
}
}
break
;
}
if
(
failure
)
{
failures
.
push_back
(
cand
);
}
}
return
failures
;
}
// expose for testing only
TVM_REGISTER_API
(
"relay._ir_pass.unmatched_cases"
)
.
set_body_typed
<
Array
<
Pattern
>
(
const
Match
&
,
const
Module
&
)
>
([](
const
Match
&
match
,
const
Module
&
mod_ref
)
{
Module
call_mod
=
mod_ref
;
if
(
!
call_mod
.
defined
())
{
call_mod
=
ModuleNode
::
make
({},
{});
}
return
UnmatchedCases
(
match
,
call_mod
);
});
}
// namespace relay
}
// namespace tvm
src/relay/pass/type_infer.cc
View file @
a698ad7f
...
@@ -293,6 +293,15 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
...
@@ -293,6 +293,15 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
GetType
(
c
->
rhs
),
GetType
(
c
->
rhs
),
op
->
span
);
op
->
span
);
}
}
// check completness
Match
match
=
GetRef
<
Match
>
(
op
);
Array
<
Pattern
>
unmatched_cases
=
UnmatchedCases
(
match
,
this
->
mod_
);
if
(
unmatched_cases
.
size
()
!=
0
)
{
LOG
(
WARNING
)
<<
"Match clause "
<<
match
<<
" does not handle the following cases: "
<<
unmatched_cases
;
}
return
rtype
;
return
rtype
;
}
}
...
...
tests/python/relay/test_pass_unmatched_cases.py
0 → 100644
View file @
a698ad7f
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import
tvm
from
tvm
import
relay
from
tvm.relay.prelude
import
Prelude
from
tvm.relay.ir_pass
import
unmatched_cases
def
test_empty_match_block
():
# empty match block will not match anything, so it should return a wildcard pattern
v
=
relay
.
Var
(
'v'
)
match
=
relay
.
Match
(
v
,
[])
unmatched
=
unmatched_cases
(
match
)
assert
len
(
unmatched
)
==
1
assert
isinstance
(
unmatched
[
0
],
relay
.
PatternWildcard
)
def
test_trivial_matches
():
# a match clause with a wildcard will match anything
v
=
relay
.
Var
(
'v'
)
match
=
relay
.
Match
(
v
,
[
relay
.
Clause
(
relay
.
PatternWildcard
(),
v
)
])
assert
len
(
unmatched_cases
(
match
))
==
0
# same with a pattern var
w
=
relay
.
Var
(
'w'
)
match
=
relay
.
Match
(
v
,
[
relay
.
Clause
(
relay
.
PatternVar
(
w
),
w
)
])
assert
len
(
unmatched_cases
(
match
))
==
0
def
test_single_constructor_adt
():
mod
=
relay
.
Module
()
box
=
relay
.
GlobalTypeVar
(
'box'
)
a
=
relay
.
TypeVar
(
'a'
)
box_ctor
=
relay
.
Constructor
(
'box'
,
[
a
],
box
)
box_data
=
relay
.
TypeData
(
box
,
[
a
],
[
box_ctor
])
mod
[
box
]
=
box_data
v
=
relay
.
Var
(
'v'
)
match
=
relay
.
Match
(
v
,
[
relay
.
Clause
(
relay
.
PatternConstructor
(
box_ctor
,
[
relay
.
PatternWildcard
()]),
v
)
])
# with one constructor, having one pattern constructor case is exhaustive
assert
len
(
unmatched_cases
(
match
,
mod
))
==
0
# this will be so if we nest the constructors too
nested_pattern
=
relay
.
Match
(
v
,
[
relay
.
Clause
(
relay
.
PatternConstructor
(
box_ctor
,
[
relay
.
PatternConstructor
(
box_ctor
,
[
relay
.
PatternConstructor
(
box_ctor
,
[
relay
.
PatternWildcard
()])])]),
v
)
])
assert
len
(
unmatched_cases
(
nested_pattern
,
mod
))
==
0
def
test_too_specific_match
():
mod
=
relay
.
Module
()
p
=
Prelude
(
mod
)
v
=
relay
.
Var
(
'v'
)
match
=
relay
.
Match
(
v
,
[
relay
.
Clause
(
relay
.
PatternConstructor
(
p
.
cons
,
[
relay
.
PatternWildcard
(),
relay
.
PatternConstructor
(
p
.
cons
,
[
relay
.
PatternWildcard
(),
relay
.
PatternWildcard
()])]),
v
)
])
unmatched
=
unmatched_cases
(
match
,
mod
)
# will not match nil or a list of length 1
nil_found
=
False
single_length_found
=
False
assert
len
(
unmatched
)
==
2
for
case
in
unmatched
:
assert
isinstance
(
case
,
relay
.
PatternConstructor
)
if
case
.
constructor
==
p
.
nil
:
nil_found
=
True
if
case
.
constructor
==
p
.
cons
:
assert
isinstance
(
case
.
patterns
[
1
],
relay
.
PatternConstructor
)
assert
case
.
patterns
[
1
]
.
constructor
==
p
.
nil
single_length_found
=
True
assert
nil_found
and
single_length_found
# if we add a wildcard, this should work
new_match
=
relay
.
Match
(
v
,
[
relay
.
Clause
(
relay
.
PatternConstructor
(
p
.
cons
,
[
relay
.
PatternWildcard
(),
relay
.
PatternConstructor
(
p
.
cons
,
[
relay
.
PatternWildcard
(),
relay
.
PatternWildcard
()])]),
v
),
relay
.
Clause
(
relay
.
PatternWildcard
(),
v
)
])
assert
len
(
unmatched_cases
(
new_match
,
mod
))
==
0
def
test_multiple_constructor_clauses
():
mod
=
relay
.
Module
()
p
=
Prelude
(
mod
)
v
=
relay
.
Var
(
'v'
)
match
=
relay
.
Match
(
v
,
[
# list of length exactly 1
relay
.
Clause
(
relay
.
PatternConstructor
(
p
.
cons
,
[
relay
.
PatternWildcard
(),
relay
.
PatternConstructor
(
p
.
nil
,
[])]),
v
),
# list of length exactly 2
relay
.
Clause
(
relay
.
PatternConstructor
(
p
.
cons
,
[
relay
.
PatternWildcard
(),
relay
.
PatternConstructor
(
p
.
cons
,
[
relay
.
PatternWildcard
(),
relay
.
PatternConstructor
(
p
.
nil
,
[])
])]),
v
),
# empty list
relay
.
Clause
(
relay
.
PatternConstructor
(
p
.
nil
,
[]),
v
),
# list of length 2 or more
relay
.
Clause
(
relay
.
PatternConstructor
(
p
.
cons
,
[
relay
.
PatternWildcard
(),
relay
.
PatternConstructor
(
p
.
cons
,
[
relay
.
PatternWildcard
(),
relay
.
PatternWildcard
()])]),
v
)
])
assert
len
(
unmatched_cases
(
match
,
mod
))
==
0
def
test_missing_in_the_middle
():
mod
=
relay
.
Module
()
p
=
Prelude
(
mod
)
v
=
relay
.
Var
(
'v'
)
match
=
relay
.
Match
(
v
,
[
# list of length exactly 1
relay
.
Clause
(
relay
.
PatternConstructor
(
p
.
cons
,
[
relay
.
PatternWildcard
(),
relay
.
PatternConstructor
(
p
.
nil
,
[])]),
v
),
# empty list
relay
.
Clause
(
relay
.
PatternConstructor
(
p
.
nil
,
[]),
v
),
# list of length 3 or more
relay
.
Clause
(
relay
.
PatternConstructor
(
p
.
cons
,
[
relay
.
PatternWildcard
(),
relay
.
PatternConstructor
(
p
.
cons
,
[
relay
.
PatternWildcard
(),
relay
.
PatternConstructor
(
p
.
cons
,
[
relay
.
PatternWildcard
(),
relay
.
PatternWildcard
()])])]),
v
)
])
# fails to match a list of length exactly two
unmatched
=
unmatched_cases
(
match
,
mod
)
assert
len
(
unmatched
)
==
1
assert
isinstance
(
unmatched
[
0
],
relay
.
PatternConstructor
)
assert
unmatched
[
0
]
.
constructor
==
p
.
cons
assert
isinstance
(
unmatched
[
0
]
.
patterns
[
1
],
relay
.
PatternConstructor
)
assert
unmatched
[
0
]
.
patterns
[
1
]
.
constructor
==
p
.
cons
assert
isinstance
(
unmatched
[
0
]
.
patterns
[
1
]
.
patterns
[
1
],
relay
.
PatternConstructor
)
assert
unmatched
[
0
]
.
patterns
[
1
]
.
patterns
[
1
]
.
constructor
==
p
.
nil
def
test_mixed_adt_constructors
():
mod
=
relay
.
Module
()
box
=
relay
.
GlobalTypeVar
(
'box'
)
a
=
relay
.
TypeVar
(
'a'
)
box_ctor
=
relay
.
Constructor
(
'box'
,
[
a
],
box
)
box_data
=
relay
.
TypeData
(
box
,
[
a
],
[
box_ctor
])
mod
[
box
]
=
box_data
p
=
Prelude
(
mod
)
v
=
relay
.
Var
(
'v'
)
box_of_lists_inc
=
relay
.
Match
(
v
,
[
relay
.
Clause
(
relay
.
PatternConstructor
(
box_ctor
,
[
relay
.
PatternConstructor
(
p
.
cons
,
[
relay
.
PatternWildcard
(),
relay
.
PatternWildcard
()])]),
v
)
])
# will fail to match a box containing an empty list
unmatched
=
unmatched_cases
(
box_of_lists_inc
,
mod
)
assert
len
(
unmatched
)
==
1
assert
isinstance
(
unmatched
[
0
],
relay
.
PatternConstructor
)
assert
unmatched
[
0
]
.
constructor
==
box_ctor
assert
len
(
unmatched
[
0
]
.
patterns
)
==
1
and
unmatched
[
0
]
.
patterns
[
0
]
.
constructor
==
p
.
nil
box_of_lists_comp
=
relay
.
Match
(
v
,
[
relay
.
Clause
(
relay
.
PatternConstructor
(
box_ctor
,
[
relay
.
PatternConstructor
(
p
.
nil
,
[])]),
v
),
relay
.
Clause
(
relay
.
PatternConstructor
(
box_ctor
,
[
relay
.
PatternConstructor
(
p
.
cons
,
[
relay
.
PatternWildcard
(),
relay
.
PatternWildcard
()])]),
v
)
])
assert
len
(
unmatched_cases
(
box_of_lists_comp
,
mod
))
==
0
list_of_boxes_inc
=
relay
.
Match
(
v
,
[
relay
.
Clause
(
relay
.
PatternConstructor
(
p
.
cons
,
[
relay
.
PatternConstructor
(
box_ctor
,
[
relay
.
PatternWildcard
()]),
relay
.
PatternWildcard
()]),
v
)
])
# fails to match empty list of boxes
unmatched
=
unmatched_cases
(
list_of_boxes_inc
,
mod
)
assert
len
(
unmatched
)
==
1
assert
isinstance
(
unmatched
[
0
],
relay
.
PatternConstructor
)
assert
unmatched
[
0
]
.
constructor
==
p
.
nil
list_of_boxes_comp
=
relay
.
Match
(
v
,
[
# exactly one box
relay
.
Clause
(
relay
.
PatternConstructor
(
p
.
cons
,
[
relay
.
PatternConstructor
(
box_ctor
,
[
relay
.
PatternWildcard
()]),
relay
.
PatternConstructor
(
p
.
nil
,
[])]),
v
),
# exactly two boxes
relay
.
Clause
(
relay
.
PatternConstructor
(
p
.
cons
,
[
relay
.
PatternConstructor
(
box_ctor
,
[
relay
.
PatternWildcard
()]),
relay
.
PatternConstructor
(
p
.
cons
,
[
relay
.
PatternConstructor
(
box_ctor
,
[
relay
.
PatternWildcard
()]),
relay
.
PatternConstructor
(
p
.
nil
,
[])
])]),
v
),
# exactly three boxes
relay
.
Clause
(
relay
.
PatternConstructor
(
p
.
cons
,
[
relay
.
PatternConstructor
(
box_ctor
,
[
relay
.
PatternWildcard
()]),
relay
.
PatternConstructor
(
p
.
cons
,
[
relay
.
PatternConstructor
(
box_ctor
,
[
relay
.
PatternWildcard
()]),
relay
.
PatternConstructor
(
p
.
cons
,
[
relay
.
PatternConstructor
(
box_ctor
,
[
relay
.
PatternWildcard
()]),
relay
.
PatternConstructor
(
p
.
nil
,
[])
])])]),
v
),
# one or more boxes
relay
.
Clause
(
relay
.
PatternConstructor
(
p
.
cons
,
[
relay
.
PatternWildcard
(),
relay
.
PatternWildcard
()]),
v
),
# no boxes
relay
.
Clause
(
relay
.
PatternConstructor
(
p
.
nil
,
[]),
v
)
])
assert
len
(
unmatched_cases
(
list_of_boxes_comp
,
mod
))
==
0
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment