Commit 331585f4 by 雾雨魔理沙 Committed by Thierry Moreau

[Relay] Strict mode in pattern matching (#3620)

* add fatal

lint

lint

lint

do

make completeness check an error

lint

remove fatal

* fix test

* reset parser file

* remove unneeded import

* Update python/tvm/relay/adt.py

Co-Authored-By: Steven S. Lyubomirsky <slyubomirsky@gmail.com>

* Update include/tvm/relay/adt.h

Co-Authored-By: Steven S. Lyubomirsky <slyubomirsky@gmail.com>

* Eliminate trailing whitespace (my fault)
parent 461e019e
......@@ -241,14 +241,20 @@ class MatchNode : public ExprNode {
/*! \brief The match node clauses. */
tvm::Array<Clause> clauses;
/*! \brief Should this match be complete (cover all cases)?
* If yes, the type checker will generate an error if there are any missing cases.
*/
bool complete;
void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("data", &data);
v->Visit("clauses", &clauses);
v->Visit("complete", &complete);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
}
TVM_DLL static Match make(Expr data, tvm::Array<Clause> pattern);
TVM_DLL static Match make(Expr data, tvm::Array<Clause> pattern, bool complete = true);
static constexpr const char* _type_key = "relay.Match";
TVM_DECLARE_NODE_TYPE_INFO(MatchNode, ExprNode);
......
......@@ -186,18 +186,24 @@ class Clause(NodeBase):
class Match(Expr):
"""Pattern matching expression in Relay."""
def __init__(self, data, clauses):
def __init__(self, data, clauses, complete=True):
"""Construct a Match.
Parameters
----------
data: tvm.relay.Expr
The value being deconstructed and matched.
clauses: List[tvm.relay.Clause]
The pattern match clauses.
complete: Optional[Bool]
Should the match be complete (cover all cases)?
If yes, the type checker will generate an error if there are any missing cases.
Returns
-------
match: tvm.relay.Expr
The match expression.
"""
self.__init_handle_by_constructor__(_make.Match, data, clauses)
self.__init_handle_by_constructor__(_make.Match, data, clauses, complete)
......@@ -51,7 +51,7 @@ class Prelude:
y = Var("y")
z = Var("z")
cons_case = Clause(PatternConstructor(self.cons, [PatternVar(y), PatternVar(z)]), y)
self.mod[self.hd] = Function([x], Match(x, [cons_case]), a, [a])
self.mod[self.hd] = Function([x], Match(x, [cons_case], False), a, [a])
def define_list_tl(self):
"""Defines a function to get the tail of a list.
......@@ -64,7 +64,7 @@ class Prelude:
y = Var("y")
z = Var("z")
cons_case = Clause(PatternConstructor(self.cons, [PatternVar(y), PatternVar(z)]), z)
self.mod[self.tl] = Function([x], Match(x, [cons_case]), self.l(a), [a])
self.mod[self.tl] = Function([x], Match(x, [cons_case], False), self.l(a), [a])
def define_list_nth(self):
......@@ -191,7 +191,7 @@ class Prelude:
cons_case = Clause(PatternConstructor(self.cons, [PatternVar(y), PatternVar(z)]),
f(y, self.foldr1(f, z)))
self.mod[self.foldr1] = Function([f, av],
Match(av, [one_case, cons_case]), a, [a])
Match(av, [one_case, cons_case], False), a, [a])
def define_list_concat(self):
......
......@@ -144,10 +144,11 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
<< node->rhs << ")";
});
Match MatchNode::make(Expr data, tvm::Array<Clause> clauses) {
Match MatchNode::make(Expr data, tvm::Array<Clause> clauses, bool complete) {
NodePtr<MatchNode> n = make_node<MatchNode>();
n->data = std::move(data);
n->clauses = std::move(clauses);
n->complete = complete;
return Match(n);
}
......@@ -160,7 +161,7 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<MatchNode>([](const MatchNode* node,
tvm::IRPrinter* p) {
p->stream << "MatchNode(" << node->data << ", "
<< node->clauses << ")";
<< node->clauses << ", " << node->complete << ")";
});
} // namespace relay
......
......@@ -525,7 +525,8 @@ class AlphaEqualHandler:
if (rhs == nullptr
|| !ExprEqual(lhs->data, rhs->data)
|| lhs->clauses.size() != rhs->clauses.size()) {
|| lhs->clauses.size() != rhs->clauses.size()
|| lhs->complete != rhs->complete) {
return false;
}
......
......@@ -332,7 +332,7 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
TVM_REGISTER_API("relay._expr.TempExprRealize")
.set_body_typed<Expr(TempExpr)>([](TempExpr temp) {
return temp->Realize();
return temp->Realize();
});
} // namespace relay
......
......@@ -212,7 +212,7 @@ Expr ExprMutator::VisitExpr_(const MatchNode* m) {
for (const Clause& p : m->clauses) {
clauses.push_back(VisitClause(p));
}
return MatchNode::make(VisitExpr(m->data), clauses);
return MatchNode::make(VisitExpr(m->data), clauses, m->complete);
}
Clause ExprMutator::VisitClause(const Clause& c) {
......
......@@ -341,6 +341,7 @@ class RelayHashHandler:
hash = Combine(hash, PatternHash(c->lhs));
hash = Combine(hash, ExprHash(c->rhs));
}
hash = Combine(hash, std::hash<bool>()(mn->complete));
return hash;
}
......
......@@ -34,6 +34,7 @@
#include <utility>
#include <vector>
#include <tuple>
#include <string>
#include "tvm/relay/type.h"
namespace tvm {
......
......@@ -825,7 +825,7 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
clauses.push_back(ClauseNode::make(c->lhs, expr));
}
store_.Invalidate();
return NoStatic(ll->Push(MatchNode::make(ps->dynamic, clauses)));
return NoStatic(ll->Push(MatchNode::make(ps->dynamic, clauses, op->complete)));
}
}
LOG(FATAL) << "No case Match";
......
......@@ -262,7 +262,7 @@ class Fill : ExprFunctor<Expr(const Expr&, const Var&)> {
c->lhs,
GetSubScope(e, 1 + clauses.size())->ll->Get(VisitExpr(c->rhs))));
}
return Compound(e, MatchNode::make(data, clauses), v);
return Compound(e, MatchNode::make(data, clauses, m->complete), v);
}
};
......
......@@ -195,7 +195,7 @@ Function ToCPS(const Function& f, const Module& m, CPSMap* cm, VarMap* vm, const
for (const auto& c : op->clauses) {
clauses.push_back(ClauseNode::make(VisitPattern(c->lhs), VisitExpr(c->rhs, kf)));
}
return MatchNode::make(v, clauses);
return MatchNode::make(v, clauses, op->complete);
});
});
}
......
......@@ -295,12 +295,14 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
op->span);
}
// check completness
Match match = GetRef<Match>(op);
Array<Pattern> unmatched_cases = UnmatchedCases(match, this->mod_);
if (unmatched_cases.size() != 0) {
LOG(WARNING) << "Match clause " << match << " does not handle the following cases: "
if (op->complete) {
// check completness
Match match = GetRef<Match>(op);
Array<Pattern> unmatched_cases = UnmatchedCases(match, this->mod_);
if (unmatched_cases.size() != 0) {
LOG(FATAL) << "Match clause " << match << " does not handle the following cases: "
<< unmatched_cases;
}
}
return rtype;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment