Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
a698ad7f
Commit
a698ad7f
authored
Jun 13, 2019
by
Steven S. Lyubomirsky
Committed by
Tianqi Chen
Jun 13, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[Relay] Check match expressions for completeness (#3203)
parent
6e2c7ede
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
307 additions
and
3 deletions
+307
-3
include/tvm/relay/pass.h
+30
-1
python/tvm/relay/ir_pass.py
+18
-0
python/tvm/relay/prelude.py
+0
-2
src/relay/pass/match_exhaustion.cc
+250
-0
src/relay/pass/type_infer.cc
+9
-0
tests/python/relay/test_pass_unmatched_cases.py
+0
-0
No files found.
include/tvm/relay/pass.h
View file @
a698ad7f
...
...
@@ -123,6 +123,24 @@ TVM_DLL bool AlphaEqual(const Expr& e1, const Expr& e2);
TVM_DLL
bool
AlphaEqual
(
const
Type
&
t1
,
const
Type
&
t2
);
/*!
* \brief Compare two patterns for structural equivalence.
*
* This comparison operator respects scoping and compares
* patterns without regard to variable choice.
*
* For example: `A(x, _, y)` is equal to `A(z, _, a)`.
*
* See https://en.wikipedia.org/wiki/Lambda_calculus#Alpha_equivalence
* for more details.
*
* \param t1 The left hand pattern.
* \param t2 The right hand pattern.
*
* \return true if equal, otherwise false
*/
TVM_DLL
bool
AlphaEqual
(
const
Pattern
&
t1
,
const
Pattern
&
t2
);
/*!
* \brief Add abstraction over a function
*
* For example: `square` is transformed to
...
...
@@ -400,8 +418,19 @@ TVM_DLL Expr ToANormalForm(const Expr& e, const Module& mod);
TVM_DLL
Expr
ToGraphNormalForm
(
const
Expr
&
e
);
/*!
* \brief Aggressive constant propagation/constant folding/inlining.
* \brief Finds cases that the given match expression does not catch, if any.
*
* \param match the match expression to test
*
* \param mod The module used for accessing global type var definitions, can be None.
*
* \return Returns a list of cases (as patterns) that are not handled by the match
* expression.
*/
TVM_DLL
Array
<
Pattern
>
UnmatchedCases
(
const
Match
&
match
,
const
Module
&
mod
);
/*!
* \brief Aggressive constant propagation/constant folding/inlining.
* It will do as much computation in compile time as possible.
* It has two benefit: remove runtime overhead, and allow more optimization (typically fusion).
* As a side effect, code size will explode.
...
...
python/tvm/relay/ir_pass.py
View file @
a698ad7f
...
...
@@ -652,3 +652,21 @@ def partial_evaluate(expr):
The output expression.
"""
return
_ir_pass
.
partial_evaluate
(
expr
)
def
unmatched_cases
(
match
,
mod
=
None
):
"""
Finds cases that the match expression does not catch, if any.
Parameters
----------
match : tvm.relay.Match
The match expression
mod : Optional[tvm.relay.Module]
The module (defaults to an empty module)
Returns
-------
missing_patterns : [tvm.relay.Pattern]
Patterns that the match expression does not catch.
"""
return
_ir_pass
.
unmatched_cases
(
match
,
mod
)
python/tvm/relay/prelude.py
View file @
a698ad7f
...
...
@@ -39,7 +39,6 @@ class Prelude:
self
.
cons
=
Constructor
(
"cons"
,
[
a
,
self
.
l
(
a
)],
self
.
l
)
self
.
mod
[
self
.
l
]
=
TypeData
(
self
.
l
,
[
a
],
[
self
.
nil
,
self
.
cons
])
def
define_list_hd
(
self
):
"""Defines a function to get the head of a list. Assume the list has at least one
element.
...
...
@@ -54,7 +53,6 @@ class Prelude:
cons_case
=
Clause
(
PatternConstructor
(
self
.
cons
,
[
PatternVar
(
y
),
PatternVar
(
z
)]),
y
)
self
.
mod
[
self
.
hd
]
=
Function
([
x
],
Match
(
x
,
[
cons_case
]),
a
,
[
a
])
def
define_list_tl
(
self
):
"""Defines a function to get the tail of a list.
...
...
src/relay/pass/match_exhaustion.cc
0 → 100644
View file @
a698ad7f
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2019 by Contributors
* \file match_exhaustion.cc
* \brief Checking Relay match expression exhaustiveness.
*
* This file implements a function that checks whether a match
* expression is exhaustive, that is, whether a given match clause
* matches every possible case. This is important for ensuring
* code correctness, since hitting an unmatched case results in a
* dynamic error unless exhaustiveness is checked in advance.
*/
#include <tvm/relay/adt.h>
#include <tvm/relay/error.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/pattern_functor.h>
#include <tvm/relay/pass.h>
#include <stack>
namespace
tvm
{
namespace
relay
{
/*! \brief Possible pattern match results */
enum
MatchResult
:
int
{
kMatch
=
0
,
// pattern matches
kClash
=
1
,
// pattern conflicts
kUnspecified
=
2
,
// ambiguous: candidate needs more constructors specified
};
class
CandidateChecker
:
public
PatternFunctor
<
MatchResult
(
const
Pattern
&
,
const
Pattern
&
)
>
{
public
:
explicit
CandidateChecker
()
{}
MatchResult
Check
(
const
Pattern
&
pat
,
const
Pattern
&
candidate
)
{
return
this
->
VisitPattern
(
pat
,
candidate
);
}
// for a constructor pattern, we must ensure that the candidate is
// a ConstructorPattern, that it has the same constructor, and
// that its fields match the subpatterns.
MatchResult
VisitPattern_
(
const
PatternConstructorNode
*
op
,
const
Pattern
&
cand
)
override
{
auto
*
ctor_cand
=
cand
.
as
<
PatternConstructorNode
>
();
// attempting to match non-constructor to constructor pattern: need to specify
if
(
ctor_cand
==
nullptr
)
{
return
MatchResult
::
kUnspecified
;
}
// check that constructors match
if
(
!
op
->
constructor
.
same_as
(
ctor_cand
->
constructor
))
{
return
MatchResult
::
kClash
;
}
// now check that subpatterns match
CHECK
(
op
->
patterns
.
size
()
==
ctor_cand
->
patterns
.
size
());
bool
unspecified
=
false
;
for
(
size_t
i
=
0
;
i
<
op
->
patterns
.
size
();
i
++
)
{
MatchResult
submatch
=
this
->
Check
(
op
->
patterns
[
i
],
ctor_cand
->
patterns
[
i
]);
// if we have a clash anywhere, then we can return clash
if
(
submatch
==
MatchResult
::
kClash
)
{
return
MatchResult
::
kClash
;
}
if
(
submatch
==
MatchResult
::
kUnspecified
)
{
unspecified
=
true
;
}
}
// only return unspecified if we have ruled out a clash
if
(
unspecified
)
{
return
MatchResult
::
kUnspecified
;
}
return
MatchResult
::
kMatch
;
}
// wildcard and var patterns always match
MatchResult
VisitPattern_
(
const
PatternWildcardNode
*
,
const
Pattern
&
)
override
{
return
MatchResult
::
kMatch
;
}
MatchResult
VisitPattern_
(
const
PatternVarNode
*
,
const
Pattern
&
)
override
{
return
MatchResult
::
kMatch
;
}
};
// Returns list of arrays corresponding to Cartesian product of input list
Array
<
Array
<
Pattern
>>
CartesianProduct
(
Array
<
Array
<
Pattern
>>
fields
)
{
CHECK_NE
(
fields
.
size
(),
0
);
Array
<
Pattern
>
field_vals
=
fields
[
fields
.
size
()
-
1
];
Array
<
Array
<
Pattern
>>
ret
;
// base case: this is the last field left
if
(
fields
.
size
()
==
1
)
{
for
(
auto
val
:
field_vals
)
{
ret
.
push_back
(
Array
<
Pattern
>
{
val
});
}
return
ret
;
}
// if we have more fields left, get the sub-candidates by getting
// their cartesian product and appending the elements here onto those
Array
<
Array
<
Pattern
>>
remaining_fields
;
for
(
size_t
i
=
0
;
i
<
fields
.
size
()
-
1
;
i
++
)
{
remaining_fields
.
push_back
(
fields
[
i
]);
}
Array
<
Array
<
Pattern
>>
candidates
=
CartesianProduct
(
remaining_fields
);
for
(
auto
val
:
field_vals
)
{
for
(
auto
candidate
:
candidates
)
{
candidate
.
push_back
(
val
);
ret
.
push_back
(
candidate
);
}
}
return
ret
;
}
// Expands all wildcards in the candidate pattern once, using the pattern
// to decide which constructors to insert. Returns a list of all possible expansions.
Array
<
Pattern
>
ExpandWildcards
(
const
Pattern
&
clause_pat
,
const
Pattern
&
cand
,
const
Module
&
mod
)
{
auto
ctor_cand
=
cand
.
as
<
PatternConstructorNode
>
();
PatternConstructor
clause_ctor
=
Downcast
<
PatternConstructor
>
(
clause_pat
);
auto
gtv
=
Downcast
<
GlobalTypeVar
>
(
clause_ctor
->
constructor
->
belong_to
);
// for a wildcard node, create constructor nodes with wildcards for all args
if
(
!
ctor_cand
)
{
TypeData
td
=
mod
->
LookupDef
(
gtv
);
// for each constructor add a candidate
Array
<
Pattern
>
ret
;
for
(
auto
constructor
:
td
->
constructors
)
{
Array
<
Pattern
>
args
;
for
(
auto
inp
:
constructor
->
inputs
)
{
args
.
push_back
(
PatternWildcardNode
::
make
());
}
ret
.
push_back
(
PatternConstructorNode
::
make
(
constructor
,
args
));
}
return
ret
;
}
// for constructors, we will expand the wildcards in any field
// that is an ADT
Array
<
Array
<
Pattern
>>
values_by_field
;
for
(
size_t
i
=
0
;
i
<
ctor_cand
->
constructor
->
inputs
.
size
();
i
++
)
{
auto
*
subpattern
=
clause_ctor
->
patterns
[
i
].
as
<
PatternConstructorNode
>
();
// for non-ADT fields, we can only have a wildcard for the value
if
(
!
subpattern
)
{
values_by_field
.
push_back
({
PatternWildcardNode
::
make
()});
continue
;
}
// otherwise, recursively expand
values_by_field
.
push_back
(
ExpandWildcards
(
GetRef
<
Pattern
>
(
subpattern
),
ctor_cand
->
patterns
[
i
],
mod
));
}
// generate new candidates using a cartesian product
auto
all_subfields
=
CartesianProduct
(
values_by_field
);
Array
<
Pattern
>
ret
;
for
(
auto
subfields
:
all_subfields
)
{
ret
.
push_back
(
PatternConstructorNode
::
make
(
ctor_cand
->
constructor
,
subfields
));
}
return
ret
;
}
/*!
* \brief Finds cases that the match expression does not catch, if any.
* \return Returns a list of cases that are not handled by the match
* expression.
*/
Array
<
Pattern
>
UnmatchedCases
(
const
Match
&
match
,
const
Module
&
mod
)
{
/* algorithm:
* candidates = { Wildcard }
* while candidates not empty {
* cand = candidates.pop()
* for clause in clauses {
* if clause fails: next clause
* if clause matches candidate: next candidate
* if candidate is not specific enough:
* candidates += expand_possible_wildcards(cand)
* next candidate
* }
* failed_candidates += { cand }
* }
* return failed_candidates
*/
std
::
stack
<
Pattern
>
candidates
;
candidates
.
push
(
PatternWildcardNode
::
make
());
CandidateChecker
checker
;
Array
<
Pattern
>
failures
;
while
(
!
candidates
.
empty
())
{
Pattern
cand
=
candidates
.
top
();
candidates
.
pop
();
bool
failure
=
true
;
for
(
auto
clause
:
match
->
clauses
)
{
// if the check fails, we move on to the next
MatchResult
check
=
checker
.
Check
(
clause
->
lhs
,
cand
);
if
(
check
==
MatchResult
::
kClash
)
{
continue
;
}
// either success or we need to generate more candidates;
// either way, we're done with this candidate
failure
=
false
;
if
(
check
==
MatchResult
::
kUnspecified
)
{
auto
new_candidates
=
ExpandWildcards
(
clause
->
lhs
,
cand
,
mod
);
for
(
auto
candidate
:
new_candidates
)
{
candidates
.
push
(
candidate
);
}
}
break
;
}
if
(
failure
)
{
failures
.
push_back
(
cand
);
}
}
return
failures
;
}
// expose for testing only
TVM_REGISTER_API
(
"relay._ir_pass.unmatched_cases"
)
.
set_body_typed
<
Array
<
Pattern
>
(
const
Match
&
,
const
Module
&
)
>
([](
const
Match
&
match
,
const
Module
&
mod_ref
)
{
Module
call_mod
=
mod_ref
;
if
(
!
call_mod
.
defined
())
{
call_mod
=
ModuleNode
::
make
({},
{});
}
return
UnmatchedCases
(
match
,
call_mod
);
});
}
// namespace relay
}
// namespace tvm
src/relay/pass/type_infer.cc
View file @
a698ad7f
...
...
@@ -293,6 +293,15 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
GetType
(
c
->
rhs
),
op
->
span
);
}
// check completness
Match
match
=
GetRef
<
Match
>
(
op
);
Array
<
Pattern
>
unmatched_cases
=
UnmatchedCases
(
match
,
this
->
mod_
);
if
(
unmatched_cases
.
size
()
!=
0
)
{
LOG
(
WARNING
)
<<
"Match clause "
<<
match
<<
" does not handle the following cases: "
<<
unmatched_cases
;
}
return
rtype
;
}
...
...
tests/python/relay/test_pass_unmatched_cases.py
0 → 100644
View file @
a698ad7f
This diff is collapsed.
Click to expand it.
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment