Commit 331585f4 by 雾雨魔理沙 Committed by Thierry Moreau

[Relay] Strict mode in pattern matching (#3620)

* add fatal

lint

lint

lint

do

make completeness check an error

lint

remove fatal

* fix test

* reset parser file

* remove unneeded import

* Update python/tvm/relay/adt.py

Co-Authored-By: Steven S. Lyubomirsky <slyubomirsky@gmail.com>

* Update include/tvm/relay/adt.h

Co-Authored-By: Steven S. Lyubomirsky <slyubomirsky@gmail.com>

* Eliminate trailing whitespace (my fault)
parent 461e019e
...@@ -241,14 +241,20 @@ class MatchNode : public ExprNode { ...@@ -241,14 +241,20 @@ class MatchNode : public ExprNode {
/*! \brief The match node clauses. */ /*! \brief The match node clauses. */
tvm::Array<Clause> clauses; tvm::Array<Clause> clauses;
/*! \brief Should this match be complete (cover all cases)?
* If yes, the type checker will generate an error if there are any missing cases.
*/
bool complete;
void VisitAttrs(tvm::AttrVisitor* v) final { void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("data", &data); v->Visit("data", &data);
v->Visit("clauses", &clauses); v->Visit("clauses", &clauses);
v->Visit("complete", &complete);
v->Visit("span", &span); v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_); v->Visit("_checked_type_", &checked_type_);
} }
TVM_DLL static Match make(Expr data, tvm::Array<Clause> pattern); TVM_DLL static Match make(Expr data, tvm::Array<Clause> pattern, bool complete = true);
static constexpr const char* _type_key = "relay.Match"; static constexpr const char* _type_key = "relay.Match";
TVM_DECLARE_NODE_TYPE_INFO(MatchNode, ExprNode); TVM_DECLARE_NODE_TYPE_INFO(MatchNode, ExprNode);
......
...@@ -186,18 +186,24 @@ class Clause(NodeBase): ...@@ -186,18 +186,24 @@ class Clause(NodeBase):
class Match(Expr): class Match(Expr):
"""Pattern matching expression in Relay.""" """Pattern matching expression in Relay."""
def __init__(self, data, clauses): def __init__(self, data, clauses, complete=True):
"""Construct a Match. """Construct a Match.
Parameters Parameters
---------- ----------
data: tvm.relay.Expr data: tvm.relay.Expr
The value being deconstructed and matched. The value being deconstructed and matched.
clauses: List[tvm.relay.Clause] clauses: List[tvm.relay.Clause]
The pattern match clauses. The pattern match clauses.
complete: Optional[Bool]
Should the match be complete (cover all cases)?
If yes, the type checker will generate an error if there are any missing cases.
Returns Returns
------- -------
match: tvm.relay.Expr match: tvm.relay.Expr
The match expression. The match expression.
""" """
self.__init_handle_by_constructor__(_make.Match, data, clauses) self.__init_handle_by_constructor__(_make.Match, data, clauses, complete)
...@@ -51,7 +51,7 @@ class Prelude: ...@@ -51,7 +51,7 @@ class Prelude:
y = Var("y") y = Var("y")
z = Var("z") z = Var("z")
cons_case = Clause(PatternConstructor(self.cons, [PatternVar(y), PatternVar(z)]), y) cons_case = Clause(PatternConstructor(self.cons, [PatternVar(y), PatternVar(z)]), y)
self.mod[self.hd] = Function([x], Match(x, [cons_case]), a, [a]) self.mod[self.hd] = Function([x], Match(x, [cons_case], False), a, [a])
def define_list_tl(self): def define_list_tl(self):
"""Defines a function to get the tail of a list. """Defines a function to get the tail of a list.
...@@ -64,7 +64,7 @@ class Prelude: ...@@ -64,7 +64,7 @@ class Prelude:
y = Var("y") y = Var("y")
z = Var("z") z = Var("z")
cons_case = Clause(PatternConstructor(self.cons, [PatternVar(y), PatternVar(z)]), z) cons_case = Clause(PatternConstructor(self.cons, [PatternVar(y), PatternVar(z)]), z)
self.mod[self.tl] = Function([x], Match(x, [cons_case]), self.l(a), [a]) self.mod[self.tl] = Function([x], Match(x, [cons_case], False), self.l(a), [a])
def define_list_nth(self): def define_list_nth(self):
...@@ -191,7 +191,7 @@ class Prelude: ...@@ -191,7 +191,7 @@ class Prelude:
cons_case = Clause(PatternConstructor(self.cons, [PatternVar(y), PatternVar(z)]), cons_case = Clause(PatternConstructor(self.cons, [PatternVar(y), PatternVar(z)]),
f(y, self.foldr1(f, z))) f(y, self.foldr1(f, z)))
self.mod[self.foldr1] = Function([f, av], self.mod[self.foldr1] = Function([f, av],
Match(av, [one_case, cons_case]), a, [a]) Match(av, [one_case, cons_case], False), a, [a])
def define_list_concat(self): def define_list_concat(self):
......
...@@ -144,10 +144,11 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) ...@@ -144,10 +144,11 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
<< node->rhs << ")"; << node->rhs << ")";
}); });
Match MatchNode::make(Expr data, tvm::Array<Clause> clauses) { Match MatchNode::make(Expr data, tvm::Array<Clause> clauses, bool complete) {
NodePtr<MatchNode> n = make_node<MatchNode>(); NodePtr<MatchNode> n = make_node<MatchNode>();
n->data = std::move(data); n->data = std::move(data);
n->clauses = std::move(clauses); n->clauses = std::move(clauses);
n->complete = complete;
return Match(n); return Match(n);
} }
...@@ -160,7 +161,7 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) ...@@ -160,7 +161,7 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<MatchNode>([](const MatchNode* node, .set_dispatch<MatchNode>([](const MatchNode* node,
tvm::IRPrinter* p) { tvm::IRPrinter* p) {
p->stream << "MatchNode(" << node->data << ", " p->stream << "MatchNode(" << node->data << ", "
<< node->clauses << ")"; << node->clauses << ", " << node->complete << ")";
}); });
} // namespace relay } // namespace relay
......
...@@ -525,7 +525,8 @@ class AlphaEqualHandler: ...@@ -525,7 +525,8 @@ class AlphaEqualHandler:
if (rhs == nullptr if (rhs == nullptr
|| !ExprEqual(lhs->data, rhs->data) || !ExprEqual(lhs->data, rhs->data)
|| lhs->clauses.size() != rhs->clauses.size()) { || lhs->clauses.size() != rhs->clauses.size()
|| lhs->complete != rhs->complete) {
return false; return false;
} }
......
...@@ -332,7 +332,7 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) ...@@ -332,7 +332,7 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
TVM_REGISTER_API("relay._expr.TempExprRealize") TVM_REGISTER_API("relay._expr.TempExprRealize")
.set_body_typed<Expr(TempExpr)>([](TempExpr temp) { .set_body_typed<Expr(TempExpr)>([](TempExpr temp) {
return temp->Realize(); return temp->Realize();
}); });
} // namespace relay } // namespace relay
......
...@@ -212,7 +212,7 @@ Expr ExprMutator::VisitExpr_(const MatchNode* m) { ...@@ -212,7 +212,7 @@ Expr ExprMutator::VisitExpr_(const MatchNode* m) {
for (const Clause& p : m->clauses) { for (const Clause& p : m->clauses) {
clauses.push_back(VisitClause(p)); clauses.push_back(VisitClause(p));
} }
return MatchNode::make(VisitExpr(m->data), clauses); return MatchNode::make(VisitExpr(m->data), clauses, m->complete);
} }
Clause ExprMutator::VisitClause(const Clause& c) { Clause ExprMutator::VisitClause(const Clause& c) {
......
...@@ -341,6 +341,7 @@ class RelayHashHandler: ...@@ -341,6 +341,7 @@ class RelayHashHandler:
hash = Combine(hash, PatternHash(c->lhs)); hash = Combine(hash, PatternHash(c->lhs));
hash = Combine(hash, ExprHash(c->rhs)); hash = Combine(hash, ExprHash(c->rhs));
} }
hash = Combine(hash, std::hash<bool>()(mn->complete));
return hash; return hash;
} }
......
...@@ -34,6 +34,7 @@ ...@@ -34,6 +34,7 @@
#include <utility> #include <utility>
#include <vector> #include <vector>
#include <tuple> #include <tuple>
#include <string>
#include "tvm/relay/type.h" #include "tvm/relay/type.h"
namespace tvm { namespace tvm {
......
...@@ -825,7 +825,7 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)> ...@@ -825,7 +825,7 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
clauses.push_back(ClauseNode::make(c->lhs, expr)); clauses.push_back(ClauseNode::make(c->lhs, expr));
} }
store_.Invalidate(); store_.Invalidate();
return NoStatic(ll->Push(MatchNode::make(ps->dynamic, clauses))); return NoStatic(ll->Push(MatchNode::make(ps->dynamic, clauses, op->complete)));
} }
} }
LOG(FATAL) << "No case Match"; LOG(FATAL) << "No case Match";
......
...@@ -262,7 +262,7 @@ class Fill : ExprFunctor<Expr(const Expr&, const Var&)> { ...@@ -262,7 +262,7 @@ class Fill : ExprFunctor<Expr(const Expr&, const Var&)> {
c->lhs, c->lhs,
GetSubScope(e, 1 + clauses.size())->ll->Get(VisitExpr(c->rhs)))); GetSubScope(e, 1 + clauses.size())->ll->Get(VisitExpr(c->rhs))));
} }
return Compound(e, MatchNode::make(data, clauses), v); return Compound(e, MatchNode::make(data, clauses, m->complete), v);
} }
}; };
......
...@@ -195,7 +195,7 @@ Function ToCPS(const Function& f, const Module& m, CPSMap* cm, VarMap* vm, const ...@@ -195,7 +195,7 @@ Function ToCPS(const Function& f, const Module& m, CPSMap* cm, VarMap* vm, const
for (const auto& c : op->clauses) { for (const auto& c : op->clauses) {
clauses.push_back(ClauseNode::make(VisitPattern(c->lhs), VisitExpr(c->rhs, kf))); clauses.push_back(ClauseNode::make(VisitPattern(c->lhs), VisitExpr(c->rhs, kf)));
} }
return MatchNode::make(v, clauses); return MatchNode::make(v, clauses, op->complete);
}); });
}); });
} }
......
...@@ -295,12 +295,14 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>, ...@@ -295,12 +295,14 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
op->span); op->span);
} }
// check completness if (op->complete) {
Match match = GetRef<Match>(op); // check completness
Array<Pattern> unmatched_cases = UnmatchedCases(match, this->mod_); Match match = GetRef<Match>(op);
if (unmatched_cases.size() != 0) { Array<Pattern> unmatched_cases = UnmatchedCases(match, this->mod_);
LOG(WARNING) << "Match clause " << match << " does not handle the following cases: " if (unmatched_cases.size() != 0) {
LOG(FATAL) << "Match clause " << match << " does not handle the following cases: "
<< unmatched_cases; << unmatched_cases;
}
} }
return rtype; return rtype;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment