Commit 08d92203 by 雾雨魔理沙 Committed by Jared Roesch

[Relay] add Tuple pattern (#3596)

* implement tuple pattern

* add tuple pattern

* lint;

* lint

* lint

* fix error

* fix

* add test
parent 98c99805
......@@ -163,6 +163,29 @@ class PatternConstructorNode : public PatternNode {
RELAY_DEFINE_NODE_REF(PatternConstructor, PatternConstructorNode, Pattern);
/*! \brief A tuple pattern. Matches a tuple, binds recursively. */
class PatternTuple;
/*! \brief PatternVar container node */
class PatternTupleNode : public PatternNode {
public:
/*! Sub-patterns to match against each value of the tuple. */
tvm::Array<Pattern> patterns;
PatternTupleNode() {}
TVM_DLL static PatternTuple make(tvm::Array<Pattern> var);
void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("patterns", &patterns);
v->Visit("span", &span);
}
static constexpr const char* _type_key = "relay.PatternTuple";
TVM_DECLARE_NODE_TYPE_INFO(PatternTupleNode, PatternNode);
};
RELAY_DEFINE_NODE_REF(PatternTuple, PatternTupleNode, Pattern);
/*!
* \brief Stores all data for an Algebraic Data Type (ADT).
*
......
......@@ -100,6 +100,8 @@ class PatternFunctor<R(const Pattern& n, Args...)> {
Args... args) PATTERN_FUNCTOR_DEFAULT;
virtual R VisitPattern_(const PatternConstructorNode* op,
Args... args) PATTERN_FUNCTOR_DEFAULT;
virtual R VisitPattern_(const PatternTupleNode* op,
Args... args) PATTERN_FUNCTOR_DEFAULT;
virtual R VisitPatternDefault_(const Node* op, Args...) {
throw Error(std::string("Do not have a default for ") + op->type_key());
}
......@@ -112,6 +114,7 @@ class PatternFunctor<R(const Pattern& n, Args...)> {
RELAY_PATTERN_FUNCTOR_DISPATCH(PatternWildcardNode);
RELAY_PATTERN_FUNCTOR_DISPATCH(PatternVarNode);
RELAY_PATTERN_FUNCTOR_DISPATCH(PatternConstructorNode);
RELAY_PATTERN_FUNCTOR_DISPATCH(PatternTupleNode);
return vtable;
}
};
......@@ -127,6 +130,7 @@ class PatternVisitor : public ::tvm::relay::PatternFunctor<void(const Pattern& n
void VisitPattern_(const PatternWildcardNode* op) override;
void VisitPattern_(const PatternVarNode* op) override;
void VisitPattern_(const PatternConstructorNode* op) override;
void VisitPattern_(const PatternTupleNode* op) override;
virtual void VisitType(const Type& t);
virtual void VisitVar(const Var& v);
virtual void VisitConstructor(const Constructor& c);
......@@ -144,6 +148,7 @@ class PatternMutator
Pattern VisitPattern_(const PatternWildcardNode* op) override;
Pattern VisitPattern_(const PatternVarNode* op) override;
Pattern VisitPattern_(const PatternConstructorNode* op) override;
Pattern VisitPattern_(const PatternTupleNode* op) override;
/*! \brief Used to visit the types inside of patterns.
*
* Can be overloaded to transform the types in arbitrary
......
......@@ -105,6 +105,7 @@ RefWrite = expr.RefWrite
PatternWildcard = adt.PatternWildcard
PatternVar = adt.PatternVar
PatternConstructor = adt.PatternConstructor
PatternTuple = adt.PatternTuple
Constructor = adt.Constructor
TypeData = adt.TypeData
Clause = adt.Clause
......
......@@ -90,6 +90,29 @@ class PatternConstructor(Pattern):
@register_relay_node
class PatternTuple(Pattern):
"""Constructor pattern in Relay: Matches a tuple, binds recursively."""
def __init__(self, patterns=None):
"""Construct a tuple pattern.
Parameters
----------
patterns: Optional[List[Pattern]]
Optional subpatterns: for each field of the constructor,
match to the given subpattern (treated as a variable pattern by default).
Returns
-------
wildcard: PatternWildcard
a wildcard pattern.
"""
if patterns is None:
patterns = []
self.__init_handle_by_constructor__(_make.PatternTuple, patterns)
@register_relay_node
class Constructor(Expr):
"""Relay ADT constructor."""
......
......@@ -21,7 +21,7 @@ from .ty import GlobalTypeVar, TypeVar, FuncType, TupleType, scalar_type
from .expr import Var, Function, GlobalVar, Let, If, Tuple, TupleGetItem, const
from .op.tensor import add, subtract, equal
from .adt import Constructor, TypeData, Clause, Match
from .adt import PatternConstructor, PatternVar, PatternWildcard
from .adt import PatternConstructor, PatternVar, PatternWildcard, PatternTuple
from .parser import fromtext
__PRELUDE_PATH__ = os.path.dirname(os.path.realpath(__file__))
from .module import Module
......@@ -239,18 +239,19 @@ class Prelude:
self.zip = GlobalVar("zip")
a = TypeVar("a")
b = TypeVar("b")
nil_case = Clause(PatternConstructor(self.nil), self.nil())
l1 = Var("l1")
l2 = Var("l2")
h1 = Var("h1")
h2 = Var("h2")
t1 = Var("t1")
t2 = Var("t2")
inner_cons_case = Clause(PatternConstructor(self.cons, [PatternVar(h2), PatternVar(t2)]),
cons_case = Clause(PatternTuple([PatternConstructor(self.cons,
[PatternVar(h1), PatternVar(t1)]),
PatternConstructor(self.cons,
[PatternVar(h2), PatternVar(t2)])]),
self.cons(Tuple([h1, h2]), self.zip(t1, t2)))
outer_cons_case = Clause(PatternConstructor(self.cons, [PatternVar(h1), PatternVar(t1)]),
Match(l2, [nil_case, inner_cons_case]))
self.mod[self.zip] = Function([l1, l2], Match(l1, [nil_case, outer_cons_case]),
nil_case = Clause(PatternWildcard(), self.nil())
self.mod[self.zip] = Function([l1, l2], Match(Tuple([l1, l2]), [cons_case, nil_case]),
self.l(TupleType([a, b])), [a, b])
......
......@@ -311,14 +311,18 @@ class PythonConverter(ExprFunctor):
if isinstance(pattern, (relay.PatternWildcard, relay.PatternVar)):
return NameConstant(True)
conds = []
if isinstance(pattern, relay.PatternConstructor):
# constructor patterns check whether the constructors match
# and also the matches of any nested patterns
# equiv: (arg.tag == patern_constructor.tag)
conds = [ast.Compare(ast.Attribute(data, 'tag', Load()),
conds.append(ast.Compare(ast.Attribute(data, 'tag', Load()),
[ast.Eq()],
[ast.Num(pattern.constructor.tag)])]
[ast.Num(pattern.constructor.tag)]))
assert isinstance(pattern, (relay.PatternConstructor, relay.PatternTuple))
# now check for any nested patterns
for i in range(len(pattern.patterns)):
nested_pat = pattern.patterns[i]
......
......@@ -18,7 +18,7 @@
*/
/*!
* Copyright (c) 2018 by Contributors
* Copyright (c) 2019 by Contributors
* \file src/tvm/relay/interpreter.cc
* \brief An interpreter for the Relay IR.
*/
......@@ -708,6 +708,18 @@ class Interpreter :
return false;
}
bool VisitPattern_(const PatternTupleNode* op, const Value& v) final {
const TupleValueNode* tvn = v.as<TupleValueNode>();
CHECK(tvn) << "need to be a tuple for match";
CHECK_EQ(op->patterns.size(), tvn->fields.size());
for (size_t i = 0; i < op->patterns.size(); ++i) {
if (!VisitPattern(op->patterns[i], tvn->fields[i])) {
return false;
}
}
return true;
}
bool VisitPattern_(const PatternWildcardNode* op, const Value& v) final {
return true;
}
......
......@@ -152,19 +152,27 @@ TreeNodePtr BuildDecisionTreeFromPattern(MatchValuePtr data,
auto pattern = GetRef<PatternVar>(pat);
auto cond = std::make_shared<VarBinding>(pattern->var, data);
return TreeBranchNode::Make(cond, then_branch, else_branch);
} else {
auto pat = pattern.as<PatternConstructorNode>();
auto pattern = GetRef<PatternConstructor>(pat);
auto tag = pattern->constructor->tag;
} else if (auto pcn = pattern.as<PatternConstructorNode>()) {
auto tag = pcn->constructor->tag;
size_t field_index = 0;
for (auto& p : pattern->patterns) {
for (auto& p : pcn->patterns) {
auto d = std::make_shared<AccessField>(data, field_index);
then_branch = BuildDecisionTreeFromPattern(d, p, then_branch, else_branch);
field_index++;
}
auto cond = std::make_shared<TagCompare>(data, tag);
return TreeBranchNode::Make(cond, then_branch, else_branch);
} else {
auto pt = pattern.as<PatternTupleNode>();
CHECK(pt) << "unhandled case: " << pattern;
size_t field_index = 0;
for (auto& p : pt->patterns) {
auto d = std::make_shared<AccessField>(data, field_index);
then_branch = BuildDecisionTreeFromPattern(d, p, then_branch, else_branch);
field_index++;
}
return then_branch;
}
}
......
......@@ -81,6 +81,23 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
<< ", " << node->patterns << ")";
});
PatternTuple PatternTupleNode::make(tvm::Array<Pattern> patterns) {
NodePtr<PatternTupleNode> n = make_node<PatternTupleNode>();
n->patterns = std::move(patterns);
return PatternTuple(n);
}
TVM_REGISTER_NODE_TYPE(PatternTupleNode);
TVM_REGISTER_API("relay._make.PatternTuple")
.set_body_typed(PatternTupleNode::make);
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<PatternTupleNode>([](const PatternTupleNode* node,
tvm::IRPrinter* p) {
p->stream << "PatternTupleNode(" << node->patterns << ")";
});
Constructor ConstructorNode::make(std::string name_hint,
tvm::Array<Type> inputs,
GlobalTypeVar belong_to) {
......
......@@ -493,7 +493,7 @@ class AlphaEqualHandler:
}
bool PatternEqual(const Pattern& lhs, const Pattern& rhs) {
return VisitPattern(lhs, rhs);
return Compare(VisitPattern(lhs, rhs), lhs, rhs);
}
bool VisitPattern_(const PatternWildcardNode* lhs, const Pattern& other) final {
......@@ -523,6 +523,21 @@ class AlphaEqualHandler:
return true;
}
bool VisitPattern_(const PatternTupleNode* lhs, const Pattern& other) final {
const auto* rhs = other.as<PatternTupleNode>();
if (rhs == nullptr
|| lhs->patterns.size() != rhs->patterns.size()) {
return false;
}
for (size_t i = 0; i < lhs->patterns.size(); i++) {
if (!PatternEqual(lhs->patterns[i], rhs->patterns[i])) {
return false;
}
}
return true;
}
bool VisitExpr_(const MatchNode* lhs, const Expr& other) final {
const MatchNode* rhs = other.as<MatchNode>();
......
......@@ -389,6 +389,14 @@ class RelayHashHandler:
return hash;
}
size_t VisitPattern_(const PatternTupleNode* ptn) final {
size_t hash = std::hash<std::string>()(PatternTupleNode::_type_key);
for (const auto& p : ptn->patterns) {
hash = Combine(hash, PatternHash(p));
}
return hash;
}
size_t VisitPattern_(const PatternVarNode* pvn) final {
size_t hash = std::hash<std::string>()(PatternVarNode::_type_key);
hash = Combine(hash, BindVar(pvn->var));
......
......@@ -18,8 +18,8 @@
*/
/*!
* Copyright (c) 2018 by Contributors
* \file src/tvm/relay/pattern_functor.cc
* Copyright (c) 2019 by Contributors
* \file src/relay/ir/pattern_functor.cc
* \brief Implementations of visitors and mutators for ADT patterns.
*/
......@@ -48,6 +48,14 @@ Pattern PatternMutator::VisitPattern_(const PatternConstructorNode* op) {
return PatternConstructorNode::make(VisitConstructor(op->constructor), pat);
}
Pattern PatternMutator::VisitPattern_(const PatternTupleNode* op) {
std::vector<Pattern> pat;
for (const auto& p : op->patterns) {
pat.push_back(VisitPattern(p));
}
return PatternTupleNode::make(pat);
}
Type PatternMutator::VisitType(const Type& t) {
return t;
}
......@@ -78,6 +86,12 @@ void PatternVisitor::VisitPattern_(const PatternConstructorNode* op) {
}
}
void PatternVisitor::VisitPattern_(const PatternTupleNode* op) {
for (const auto& p : op->patterns) {
VisitPattern(p);
}
}
void PatternVisitor::VisitType(const Type& t) { }
void PatternVisitor::VisitVar(const Var& v) {
......
......@@ -68,7 +68,7 @@ class CandidateChecker : public PatternFunctor<MatchResult(const Pattern&, const
}
// now check that subpatterns match
CHECK(op->patterns.size() == ctor_cand->patterns.size());
CHECK_EQ(op->patterns.size(), ctor_cand->patterns.size());
bool unspecified = false;
for (size_t i = 0; i < op->patterns.size(); i++) {
MatchResult submatch = this->Check(op->patterns[i], ctor_cand->patterns[i]);
......@@ -87,6 +87,33 @@ class CandidateChecker : public PatternFunctor<MatchResult(const Pattern&, const
return MatchResult::kMatch;
}
MatchResult VisitPattern_(const PatternTupleNode* op, const Pattern& cand) override {
auto* tuple_cand = cand.as<PatternTupleNode>();
// attempting to match non-tuple to constructor pattern: need to specify
if (tuple_cand == nullptr) {
return MatchResult::kUnspecified;
}
// now check that subpatterns match
CHECK_EQ(op->patterns.size(), tuple_cand->patterns.size());
bool unspecified = false;
for (size_t i = 0; i < op->patterns.size(); i++) {
MatchResult submatch = this->Check(op->patterns[i], tuple_cand->patterns[i]);
// if we have a clash anywhere, then we can return clash
if (submatch == MatchResult::kClash) {
return MatchResult::kClash;
}
if (submatch == MatchResult::kUnspecified) {
unspecified = true;
}
}
// only return unspecified if we have ruled out a clash
if (unspecified) {
return MatchResult::kUnspecified;
}
return MatchResult::kMatch;
}
// wildcard and var patterns always match
MatchResult VisitPattern_(const PatternWildcardNode*, const Pattern&) override {
return MatchResult::kMatch;
......@@ -127,18 +154,38 @@ Array<Array<Pattern>> CartesianProduct(Array<Array<Pattern>> fields) {
return ret;
}
// Expands all wildcards in the candidate pattern once, using the pattern
// to decide which constructors to insert. Returns a list of all possible expansions.
Array<Pattern> ExpandWildcards(const Pattern& clause_pat, const Pattern& cand,
Array<Pattern> ExpandWildcardsConstructor(const PatternConstructor& clause_ctor,
const Pattern& cand,
const Module& mod);
Array<Pattern> ExpandWildcardsTuple(const PatternTuple& clause_tuple,
const Pattern& cand,
const Module& mod);
// Expands all wildcards in the candidate pattern once
// Returns a list of all possible expansions.
Array<Pattern> ExpandWildcards(const Pattern& clause_pat,
const Pattern& cand,
const Module& mod) {
if (auto clause_ctor = clause_pat.as<PatternConstructorNode>()) {
return ExpandWildcardsConstructor(GetRef<PatternConstructor>(clause_ctor), cand, mod);
} else {
return ExpandWildcardsTuple(Downcast<PatternTuple>(clause_pat), cand, mod);
}
}
// Expands all wildcards in the candidate pattern once.
// Use the pattern to decide which constructors to insert.
// Returns a list of all possible expansions.
Array<Pattern> ExpandWildcardsConstructor(const PatternConstructor& clause_ctor,
const Pattern& cand,
const Module& mod) {
auto ctor_cand = cand.as<PatternConstructorNode>();
PatternConstructor clause_ctor = Downcast<PatternConstructor>(clause_pat);
auto gtv = Downcast<GlobalTypeVar>(clause_ctor->constructor->belong_to);
// for a wildcard node, create constructor nodes with wildcards for all args
if (!ctor_cand) {
// for a wildcard node, create constructor nodes with wildcards for all args.
if (cand.as<PatternWildcardNode>()) {
TypeData td = mod->LookupDef(gtv);
// for each constructor add a candidate
// for each constructor add a candidate.
Array<Pattern> ret;
for (auto constructor : td->constructors) {
Array<Pattern> args;
......@@ -150,27 +197,72 @@ Array<Pattern> ExpandWildcards(const Pattern& clause_pat, const Pattern& cand,
return ret;
}
// for constructors, we will expand the wildcards in any field
// that is an ADT
auto ctor_cand = Downcast<PatternConstructor>(cand);
// for constructors, we will expand the wildcards in any field that is an ADT.
Array<Array<Pattern>> values_by_field;
for (size_t i = 0; i < ctor_cand->constructor->inputs.size(); i++) {
auto* subpattern = clause_ctor->patterns[i].as<PatternConstructorNode>();
// for non-ADT fields, we can only have a wildcard for the value
bool subpattern =
clause_ctor->patterns[i].as<PatternConstructorNode>() ||
clause_ctor->patterns[i].as<PatternTupleNode>();
// for non-ADT fields, we can only have a wildcard for the value.
if (!subpattern) {
values_by_field.push_back({PatternWildcardNode::make()});
continue;
} else {
// otherwise, recursively expand.
values_by_field.push_back(ExpandWildcards(clause_ctor->patterns[i],
ctor_cand->patterns[i],
mod));
}
}
// generate new candidates using a cartesian product.
auto all_subfields = CartesianProduct(values_by_field);
Array<Pattern> ret;
for (auto subfields : all_subfields) {
ret.push_back(PatternConstructorNode::make(ctor_cand->constructor, subfields));
}
return ret;
}
// Expands all wildcards in the candidate pattern once.
// Returns a list of all possible expansions.
Array<Pattern> ExpandWildcardsTuple(const PatternTuple& clause_tuple,
const Pattern& cand,
const Module& mod) {
// for a wildcard node, create constructor nodes with wildcards for all args.
if (cand.as<PatternWildcardNode>()) {
Array<Pattern> args;
for (auto inp : clause_tuple->patterns) {
args.push_back(PatternWildcardNode::make());
}
return {PatternTupleNode::make(args)};
}
auto tuple_cand = Downcast<PatternTuple>(cand);
// for constructors, we will expand the wildcards in any field that is an ADT.
Array<Array<Pattern>> values_by_field;
for (size_t i = 0; i < tuple_cand->patterns.size(); i++) {
bool subpattern =
clause_tuple->patterns[i].as<PatternConstructorNode>() ||
clause_tuple->patterns[i].as<PatternTupleNode>();
// for non-ADT fields, we can only have a wildcard for the value
if (!subpattern) {
values_by_field.push_back({PatternWildcardNode::make()});
} else {
// otherwise, recursively expand
values_by_field.push_back(ExpandWildcards(GetRef<Pattern>(subpattern),
ctor_cand->patterns[i], mod));
values_by_field.push_back(ExpandWildcards(clause_tuple->patterns[i],
tuple_cand->patterns[i],
mod));
}
}
// generate new candidates using a cartesian product
auto all_subfields = CartesianProduct(values_by_field);
Array<Pattern> ret;
for (auto subfields : all_subfields) {
ret.push_back(PatternConstructorNode::make(ctor_cand->constructor, subfields));
ret.push_back(PatternTupleNode::make(subfields));
}
return ret;
}
......
......@@ -1051,6 +1051,28 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
}
}
MatchStatus VisitPattern_(const PatternTupleNode* op, const PStatic& ps) final {
if (ps->pstatic.defined()) {
STuple stn = Downcast<STuple>(ps->pstatic);
CHECK_EQ(op->patterns.size(), stn->fields.size());
MatchStatus current_match_status = MatchStatus::Match;
for (size_t i = 0; i < op->patterns.size(); ++i) {
MatchStatus ms = VisitPattern(op->patterns[i], stn->fields[i]);
switch (ms) {
case MatchStatus::Match:
continue;
case MatchStatus::NoMatch:
return MatchStatus::NoMatch;
case MatchStatus::Unknown:
current_match_status = MatchStatus::Unknown;
}
}
return current_match_status;
} else {
return MatchStatus::Unknown;
}
}
void InitializeFuncId(const Expr& e) {
struct InitializeFuncIdVisitor : ExprVisitor, PatternVisitor {
PartialEvaluator* pe;
......
......@@ -276,6 +276,27 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
}
}
void VisitPattern_(const PatternTupleNode* tup, const Type& t) {
auto pt = GetRef<PatternTuple>(tup);
// we can expect a certain number of arguments
Array<Type> unknown_args;
for (size_t i = 0; i < tup->patterns.size(); i++) {
unknown_args.push_back(IncompleteTypeNode::make(Kind::kType));
}
Type expected = TupleTypeNode::make(unknown_args);
Type unified = Unify(t, expected, GetRef<NodeRef>(tup));
auto* tt = unified.as<TupleTypeNode>();
if (!tt) {
this->ReportFatalError(pt, RELAY_ERROR("Expected a tuple type, got " << unified));
}
CHECK(tup->patterns.size() == tt->fields.size()) << "not enough pattern";
for (size_t i = 0; i < tup->patterns.size(); ++i) {
VisitPattern(tup->patterns[i], tt->fields[i]);
}
}
void VisitPattern_(const PatternVarNode* pv, const Type& t) {
Type vt = GetType(pv->var);
Unify(vt, t, pv->span);
......
......@@ -611,6 +611,21 @@ def test_hash_unequal():
assert not analysis.structural_hash(func1) == analysis.structural_hash(func3)
def test_tuple_match():
a = relay.Var("a")
b = relay.Var("b")
clause = relay.Clause(relay.PatternTuple([relay.PatternVar(a), relay.PatternVar(b)]), a + b)
x = relay.Match(relay.Tuple([relay.const(1), relay.const(1)]), [clause])
a = relay.Var("a")
b = relay.Var("b")
clause = relay.Clause(relay.PatternTuple([relay.PatternVar(a), relay.PatternVar(b)]), a + b)
y = relay.Match(relay.Tuple([relay.const(1), relay.const(1)]), [clause])
assert analysis.alpha_equal(x, y)
assert analysis.structural_hash(x) == analysis.structural_hash(y)
if __name__ == "__main__":
test_tensor_type_alpha_equal()
test_incomplete_type_alpha_equal()
......
......@@ -331,6 +331,14 @@ def test_nat_update():
transform.PartialEvaluate()(m)
def test_tuple_match():
a = relay.Var("a")
b = relay.Var("b")
clause = relay.Clause(relay.PatternTuple([relay.PatternVar(a), relay.PatternVar(b)]), a + b)
x = relay.Match(relay.Tuple([relay.const(1), relay.const(1)]), [clause])
assert_alpha_equal(dcpe(x), const(2))
if __name__ == '__main__':
test_nat_update()
test_ref()
......@@ -351,3 +359,4 @@ if __name__ == '__main__':
test_match_nat_id()
test_concat()
test_triangle_number()
test_tuple_match()
......@@ -265,3 +265,11 @@ def test_mixed_adt_constructors():
relay.Clause(relay.PatternConstructor(p.nil, []), v)
])
assert len(unmatched_cases(list_of_boxes_comp, mod)) == 0
def test_tuple_match():
a = relay.Var("a")
b = relay.Var("b")
clause = relay.Clause(relay.PatternTuple([relay.PatternVar(a), relay.PatternVar(b)]), a + b)
x = relay.Match(relay.Tuple([relay.const(1), relay.const(1)]), [clause])
assert len(unmatched_cases(x)) == 0
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment