Commit 08d92203 by 雾雨魔理沙 Committed by Jared Roesch

[Relay] add Tuple pattern (#3596)

* implement tuple pattern

* add tuple pattern

* lint;

* lint

* lint

* fix error

* fix

* add test
parent 98c99805
...@@ -163,6 +163,29 @@ class PatternConstructorNode : public PatternNode { ...@@ -163,6 +163,29 @@ class PatternConstructorNode : public PatternNode {
RELAY_DEFINE_NODE_REF(PatternConstructor, PatternConstructorNode, Pattern); RELAY_DEFINE_NODE_REF(PatternConstructor, PatternConstructorNode, Pattern);
/*! \brief A tuple pattern. Matches a tuple, binds recursively. */
class PatternTuple;
/*! \brief PatternVar container node */
class PatternTupleNode : public PatternNode {
public:
/*! Sub-patterns to match against each value of the tuple. */
tvm::Array<Pattern> patterns;
PatternTupleNode() {}
TVM_DLL static PatternTuple make(tvm::Array<Pattern> var);
void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("patterns", &patterns);
v->Visit("span", &span);
}
static constexpr const char* _type_key = "relay.PatternTuple";
TVM_DECLARE_NODE_TYPE_INFO(PatternTupleNode, PatternNode);
};
RELAY_DEFINE_NODE_REF(PatternTuple, PatternTupleNode, Pattern);
/*! /*!
* \brief Stores all data for an Algebraic Data Type (ADT). * \brief Stores all data for an Algebraic Data Type (ADT).
* *
......
...@@ -100,6 +100,8 @@ class PatternFunctor<R(const Pattern& n, Args...)> { ...@@ -100,6 +100,8 @@ class PatternFunctor<R(const Pattern& n, Args...)> {
Args... args) PATTERN_FUNCTOR_DEFAULT; Args... args) PATTERN_FUNCTOR_DEFAULT;
virtual R VisitPattern_(const PatternConstructorNode* op, virtual R VisitPattern_(const PatternConstructorNode* op,
Args... args) PATTERN_FUNCTOR_DEFAULT; Args... args) PATTERN_FUNCTOR_DEFAULT;
virtual R VisitPattern_(const PatternTupleNode* op,
Args... args) PATTERN_FUNCTOR_DEFAULT;
virtual R VisitPatternDefault_(const Node* op, Args...) { virtual R VisitPatternDefault_(const Node* op, Args...) {
throw Error(std::string("Do not have a default for ") + op->type_key()); throw Error(std::string("Do not have a default for ") + op->type_key());
} }
...@@ -112,6 +114,7 @@ class PatternFunctor<R(const Pattern& n, Args...)> { ...@@ -112,6 +114,7 @@ class PatternFunctor<R(const Pattern& n, Args...)> {
RELAY_PATTERN_FUNCTOR_DISPATCH(PatternWildcardNode); RELAY_PATTERN_FUNCTOR_DISPATCH(PatternWildcardNode);
RELAY_PATTERN_FUNCTOR_DISPATCH(PatternVarNode); RELAY_PATTERN_FUNCTOR_DISPATCH(PatternVarNode);
RELAY_PATTERN_FUNCTOR_DISPATCH(PatternConstructorNode); RELAY_PATTERN_FUNCTOR_DISPATCH(PatternConstructorNode);
RELAY_PATTERN_FUNCTOR_DISPATCH(PatternTupleNode);
return vtable; return vtable;
} }
}; };
...@@ -127,6 +130,7 @@ class PatternVisitor : public ::tvm::relay::PatternFunctor<void(const Pattern& n ...@@ -127,6 +130,7 @@ class PatternVisitor : public ::tvm::relay::PatternFunctor<void(const Pattern& n
void VisitPattern_(const PatternWildcardNode* op) override; void VisitPattern_(const PatternWildcardNode* op) override;
void VisitPattern_(const PatternVarNode* op) override; void VisitPattern_(const PatternVarNode* op) override;
void VisitPattern_(const PatternConstructorNode* op) override; void VisitPattern_(const PatternConstructorNode* op) override;
void VisitPattern_(const PatternTupleNode* op) override;
virtual void VisitType(const Type& t); virtual void VisitType(const Type& t);
virtual void VisitVar(const Var& v); virtual void VisitVar(const Var& v);
virtual void VisitConstructor(const Constructor& c); virtual void VisitConstructor(const Constructor& c);
...@@ -144,6 +148,7 @@ class PatternMutator ...@@ -144,6 +148,7 @@ class PatternMutator
Pattern VisitPattern_(const PatternWildcardNode* op) override; Pattern VisitPattern_(const PatternWildcardNode* op) override;
Pattern VisitPattern_(const PatternVarNode* op) override; Pattern VisitPattern_(const PatternVarNode* op) override;
Pattern VisitPattern_(const PatternConstructorNode* op) override; Pattern VisitPattern_(const PatternConstructorNode* op) override;
Pattern VisitPattern_(const PatternTupleNode* op) override;
/*! \brief Used to visit the types inside of patterns. /*! \brief Used to visit the types inside of patterns.
* *
* Can be overloaded to transform the types in arbitrary * Can be overloaded to transform the types in arbitrary
......
...@@ -105,6 +105,7 @@ RefWrite = expr.RefWrite ...@@ -105,6 +105,7 @@ RefWrite = expr.RefWrite
PatternWildcard = adt.PatternWildcard PatternWildcard = adt.PatternWildcard
PatternVar = adt.PatternVar PatternVar = adt.PatternVar
PatternConstructor = adt.PatternConstructor PatternConstructor = adt.PatternConstructor
PatternTuple = adt.PatternTuple
Constructor = adt.Constructor Constructor = adt.Constructor
TypeData = adt.TypeData TypeData = adt.TypeData
Clause = adt.Clause Clause = adt.Clause
......
...@@ -90,6 +90,29 @@ class PatternConstructor(Pattern): ...@@ -90,6 +90,29 @@ class PatternConstructor(Pattern):
@register_relay_node @register_relay_node
class PatternTuple(Pattern):
"""Constructor pattern in Relay: Matches a tuple, binds recursively."""
def __init__(self, patterns=None):
"""Construct a tuple pattern.
Parameters
----------
patterns: Optional[List[Pattern]]
Optional subpatterns: for each field of the constructor,
match to the given subpattern (treated as a variable pattern by default).
Returns
-------
wildcard: PatternWildcard
a wildcard pattern.
"""
if patterns is None:
patterns = []
self.__init_handle_by_constructor__(_make.PatternTuple, patterns)
@register_relay_node
class Constructor(Expr): class Constructor(Expr):
"""Relay ADT constructor.""" """Relay ADT constructor."""
......
...@@ -21,7 +21,7 @@ from .ty import GlobalTypeVar, TypeVar, FuncType, TupleType, scalar_type ...@@ -21,7 +21,7 @@ from .ty import GlobalTypeVar, TypeVar, FuncType, TupleType, scalar_type
from .expr import Var, Function, GlobalVar, Let, If, Tuple, TupleGetItem, const from .expr import Var, Function, GlobalVar, Let, If, Tuple, TupleGetItem, const
from .op.tensor import add, subtract, equal from .op.tensor import add, subtract, equal
from .adt import Constructor, TypeData, Clause, Match from .adt import Constructor, TypeData, Clause, Match
from .adt import PatternConstructor, PatternVar, PatternWildcard from .adt import PatternConstructor, PatternVar, PatternWildcard, PatternTuple
from .parser import fromtext from .parser import fromtext
__PRELUDE_PATH__ = os.path.dirname(os.path.realpath(__file__)) __PRELUDE_PATH__ = os.path.dirname(os.path.realpath(__file__))
from .module import Module from .module import Module
...@@ -239,18 +239,19 @@ class Prelude: ...@@ -239,18 +239,19 @@ class Prelude:
self.zip = GlobalVar("zip") self.zip = GlobalVar("zip")
a = TypeVar("a") a = TypeVar("a")
b = TypeVar("b") b = TypeVar("b")
nil_case = Clause(PatternConstructor(self.nil), self.nil())
l1 = Var("l1") l1 = Var("l1")
l2 = Var("l2") l2 = Var("l2")
h1 = Var("h1") h1 = Var("h1")
h2 = Var("h2") h2 = Var("h2")
t1 = Var("t1") t1 = Var("t1")
t2 = Var("t2") t2 = Var("t2")
inner_cons_case = Clause(PatternConstructor(self.cons, [PatternVar(h2), PatternVar(t2)]), cons_case = Clause(PatternTuple([PatternConstructor(self.cons,
self.cons(Tuple([h1, h2]), self.zip(t1, t2))) [PatternVar(h1), PatternVar(t1)]),
outer_cons_case = Clause(PatternConstructor(self.cons, [PatternVar(h1), PatternVar(t1)]), PatternConstructor(self.cons,
Match(l2, [nil_case, inner_cons_case])) [PatternVar(h2), PatternVar(t2)])]),
self.mod[self.zip] = Function([l1, l2], Match(l1, [nil_case, outer_cons_case]), self.cons(Tuple([h1, h2]), self.zip(t1, t2)))
nil_case = Clause(PatternWildcard(), self.nil())
self.mod[self.zip] = Function([l1, l2], Match(Tuple([l1, l2]), [cons_case, nil_case]),
self.l(TupleType([a, b])), [a, b]) self.l(TupleType([a, b])), [a, b])
......
...@@ -311,14 +311,18 @@ class PythonConverter(ExprFunctor): ...@@ -311,14 +311,18 @@ class PythonConverter(ExprFunctor):
if isinstance(pattern, (relay.PatternWildcard, relay.PatternVar)): if isinstance(pattern, (relay.PatternWildcard, relay.PatternVar)):
return NameConstant(True) return NameConstant(True)
# constructor patterns check whether the constructors match conds = []
# and also the matches of any nested patterns
# equiv: (arg.tag == patern_constructor.tag) if isinstance(pattern, relay.PatternConstructor):
conds = [ast.Compare(ast.Attribute(data, 'tag', Load()), # constructor patterns check whether the constructors match
[ast.Eq()], # and also the matches of any nested patterns
[ast.Num(pattern.constructor.tag)])]
# equiv: (arg.tag == patern_constructor.tag)
conds.append(ast.Compare(ast.Attribute(data, 'tag', Load()),
[ast.Eq()],
[ast.Num(pattern.constructor.tag)]))
assert isinstance(pattern, (relay.PatternConstructor, relay.PatternTuple))
# now check for any nested patterns # now check for any nested patterns
for i in range(len(pattern.patterns)): for i in range(len(pattern.patterns)):
nested_pat = pattern.patterns[i] nested_pat = pattern.patterns[i]
......
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
*/ */
/*! /*!
* Copyright (c) 2018 by Contributors * Copyright (c) 2019 by Contributors
* \file src/tvm/relay/interpreter.cc * \file src/tvm/relay/interpreter.cc
* \brief An interpreter for the Relay IR. * \brief An interpreter for the Relay IR.
*/ */
...@@ -708,6 +708,18 @@ class Interpreter : ...@@ -708,6 +708,18 @@ class Interpreter :
return false; return false;
} }
bool VisitPattern_(const PatternTupleNode* op, const Value& v) final {
const TupleValueNode* tvn = v.as<TupleValueNode>();
CHECK(tvn) << "need to be a tuple for match";
CHECK_EQ(op->patterns.size(), tvn->fields.size());
for (size_t i = 0; i < op->patterns.size(); ++i) {
if (!VisitPattern(op->patterns[i], tvn->fields[i])) {
return false;
}
}
return true;
}
bool VisitPattern_(const PatternWildcardNode* op, const Value& v) final { bool VisitPattern_(const PatternWildcardNode* op, const Value& v) final {
return true; return true;
} }
......
...@@ -152,19 +152,27 @@ TreeNodePtr BuildDecisionTreeFromPattern(MatchValuePtr data, ...@@ -152,19 +152,27 @@ TreeNodePtr BuildDecisionTreeFromPattern(MatchValuePtr data,
auto pattern = GetRef<PatternVar>(pat); auto pattern = GetRef<PatternVar>(pat);
auto cond = std::make_shared<VarBinding>(pattern->var, data); auto cond = std::make_shared<VarBinding>(pattern->var, data);
return TreeBranchNode::Make(cond, then_branch, else_branch); return TreeBranchNode::Make(cond, then_branch, else_branch);
} else { } else if (auto pcn = pattern.as<PatternConstructorNode>()) {
auto pat = pattern.as<PatternConstructorNode>(); auto tag = pcn->constructor->tag;
auto pattern = GetRef<PatternConstructor>(pat);
auto tag = pattern->constructor->tag;
size_t field_index = 0; size_t field_index = 0;
for (auto& p : pattern->patterns) { for (auto& p : pcn->patterns) {
auto d = std::make_shared<AccessField>(data, field_index); auto d = std::make_shared<AccessField>(data, field_index);
then_branch = BuildDecisionTreeFromPattern(d, p, then_branch, else_branch); then_branch = BuildDecisionTreeFromPattern(d, p, then_branch, else_branch);
field_index++; field_index++;
} }
auto cond = std::make_shared<TagCompare>(data, tag); auto cond = std::make_shared<TagCompare>(data, tag);
return TreeBranchNode::Make(cond, then_branch, else_branch); return TreeBranchNode::Make(cond, then_branch, else_branch);
} else {
auto pt = pattern.as<PatternTupleNode>();
CHECK(pt) << "unhandled case: " << pattern;
size_t field_index = 0;
for (auto& p : pt->patterns) {
auto d = std::make_shared<AccessField>(data, field_index);
then_branch = BuildDecisionTreeFromPattern(d, p, then_branch, else_branch);
field_index++;
}
return then_branch;
} }
} }
......
...@@ -81,6 +81,23 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) ...@@ -81,6 +81,23 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
<< ", " << node->patterns << ")"; << ", " << node->patterns << ")";
}); });
PatternTuple PatternTupleNode::make(tvm::Array<Pattern> patterns) {
NodePtr<PatternTupleNode> n = make_node<PatternTupleNode>();
n->patterns = std::move(patterns);
return PatternTuple(n);
}
TVM_REGISTER_NODE_TYPE(PatternTupleNode);
TVM_REGISTER_API("relay._make.PatternTuple")
.set_body_typed(PatternTupleNode::make);
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<PatternTupleNode>([](const PatternTupleNode* node,
tvm::IRPrinter* p) {
p->stream << "PatternTupleNode(" << node->patterns << ")";
});
Constructor ConstructorNode::make(std::string name_hint, Constructor ConstructorNode::make(std::string name_hint,
tvm::Array<Type> inputs, tvm::Array<Type> inputs,
GlobalTypeVar belong_to) { GlobalTypeVar belong_to) {
......
...@@ -493,7 +493,7 @@ class AlphaEqualHandler: ...@@ -493,7 +493,7 @@ class AlphaEqualHandler:
} }
bool PatternEqual(const Pattern& lhs, const Pattern& rhs) { bool PatternEqual(const Pattern& lhs, const Pattern& rhs) {
return VisitPattern(lhs, rhs); return Compare(VisitPattern(lhs, rhs), lhs, rhs);
} }
bool VisitPattern_(const PatternWildcardNode* lhs, const Pattern& other) final { bool VisitPattern_(const PatternWildcardNode* lhs, const Pattern& other) final {
...@@ -523,6 +523,21 @@ class AlphaEqualHandler: ...@@ -523,6 +523,21 @@ class AlphaEqualHandler:
return true; return true;
} }
bool VisitPattern_(const PatternTupleNode* lhs, const Pattern& other) final {
const auto* rhs = other.as<PatternTupleNode>();
if (rhs == nullptr
|| lhs->patterns.size() != rhs->patterns.size()) {
return false;
}
for (size_t i = 0; i < lhs->patterns.size(); i++) {
if (!PatternEqual(lhs->patterns[i], rhs->patterns[i])) {
return false;
}
}
return true;
}
bool VisitExpr_(const MatchNode* lhs, const Expr& other) final { bool VisitExpr_(const MatchNode* lhs, const Expr& other) final {
const MatchNode* rhs = other.as<MatchNode>(); const MatchNode* rhs = other.as<MatchNode>();
......
...@@ -389,6 +389,14 @@ class RelayHashHandler: ...@@ -389,6 +389,14 @@ class RelayHashHandler:
return hash; return hash;
} }
size_t VisitPattern_(const PatternTupleNode* ptn) final {
size_t hash = std::hash<std::string>()(PatternTupleNode::_type_key);
for (const auto& p : ptn->patterns) {
hash = Combine(hash, PatternHash(p));
}
return hash;
}
size_t VisitPattern_(const PatternVarNode* pvn) final { size_t VisitPattern_(const PatternVarNode* pvn) final {
size_t hash = std::hash<std::string>()(PatternVarNode::_type_key); size_t hash = std::hash<std::string>()(PatternVarNode::_type_key);
hash = Combine(hash, BindVar(pvn->var)); hash = Combine(hash, BindVar(pvn->var));
......
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the * to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at * with the License. You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...@@ -18,8 +18,8 @@ ...@@ -18,8 +18,8 @@
*/ */
/*! /*!
* Copyright (c) 2018 by Contributors * Copyright (c) 2019 by Contributors
* \file src/tvm/relay/pattern_functor.cc * \file src/relay/ir/pattern_functor.cc
* \brief Implementations of visitors and mutators for ADT patterns. * \brief Implementations of visitors and mutators for ADT patterns.
*/ */
...@@ -48,6 +48,14 @@ Pattern PatternMutator::VisitPattern_(const PatternConstructorNode* op) { ...@@ -48,6 +48,14 @@ Pattern PatternMutator::VisitPattern_(const PatternConstructorNode* op) {
return PatternConstructorNode::make(VisitConstructor(op->constructor), pat); return PatternConstructorNode::make(VisitConstructor(op->constructor), pat);
} }
Pattern PatternMutator::VisitPattern_(const PatternTupleNode* op) {
std::vector<Pattern> pat;
for (const auto& p : op->patterns) {
pat.push_back(VisitPattern(p));
}
return PatternTupleNode::make(pat);
}
Type PatternMutator::VisitType(const Type& t) { Type PatternMutator::VisitType(const Type& t) {
return t; return t;
} }
...@@ -78,6 +86,12 @@ void PatternVisitor::VisitPattern_(const PatternConstructorNode* op) { ...@@ -78,6 +86,12 @@ void PatternVisitor::VisitPattern_(const PatternConstructorNode* op) {
} }
} }
void PatternVisitor::VisitPattern_(const PatternTupleNode* op) {
for (const auto& p : op->patterns) {
VisitPattern(p);
}
}
void PatternVisitor::VisitType(const Type& t) { } void PatternVisitor::VisitType(const Type& t) { }
void PatternVisitor::VisitVar(const Var& v) { void PatternVisitor::VisitVar(const Var& v) {
......
...@@ -68,7 +68,7 @@ class CandidateChecker : public PatternFunctor<MatchResult(const Pattern&, const ...@@ -68,7 +68,7 @@ class CandidateChecker : public PatternFunctor<MatchResult(const Pattern&, const
} }
// now check that subpatterns match // now check that subpatterns match
CHECK(op->patterns.size() == ctor_cand->patterns.size()); CHECK_EQ(op->patterns.size(), ctor_cand->patterns.size());
bool unspecified = false; bool unspecified = false;
for (size_t i = 0; i < op->patterns.size(); i++) { for (size_t i = 0; i < op->patterns.size(); i++) {
MatchResult submatch = this->Check(op->patterns[i], ctor_cand->patterns[i]); MatchResult submatch = this->Check(op->patterns[i], ctor_cand->patterns[i]);
...@@ -87,6 +87,33 @@ class CandidateChecker : public PatternFunctor<MatchResult(const Pattern&, const ...@@ -87,6 +87,33 @@ class CandidateChecker : public PatternFunctor<MatchResult(const Pattern&, const
return MatchResult::kMatch; return MatchResult::kMatch;
} }
MatchResult VisitPattern_(const PatternTupleNode* op, const Pattern& cand) override {
auto* tuple_cand = cand.as<PatternTupleNode>();
// attempting to match non-tuple to constructor pattern: need to specify
if (tuple_cand == nullptr) {
return MatchResult::kUnspecified;
}
// now check that subpatterns match
CHECK_EQ(op->patterns.size(), tuple_cand->patterns.size());
bool unspecified = false;
for (size_t i = 0; i < op->patterns.size(); i++) {
MatchResult submatch = this->Check(op->patterns[i], tuple_cand->patterns[i]);
// if we have a clash anywhere, then we can return clash
if (submatch == MatchResult::kClash) {
return MatchResult::kClash;
}
if (submatch == MatchResult::kUnspecified) {
unspecified = true;
}
}
// only return unspecified if we have ruled out a clash
if (unspecified) {
return MatchResult::kUnspecified;
}
return MatchResult::kMatch;
}
// wildcard and var patterns always match // wildcard and var patterns always match
MatchResult VisitPattern_(const PatternWildcardNode*, const Pattern&) override { MatchResult VisitPattern_(const PatternWildcardNode*, const Pattern&) override {
return MatchResult::kMatch; return MatchResult::kMatch;
...@@ -127,18 +154,38 @@ Array<Array<Pattern>> CartesianProduct(Array<Array<Pattern>> fields) { ...@@ -127,18 +154,38 @@ Array<Array<Pattern>> CartesianProduct(Array<Array<Pattern>> fields) {
return ret; return ret;
} }
// Expands all wildcards in the candidate pattern once, using the pattern Array<Pattern> ExpandWildcardsConstructor(const PatternConstructor& clause_ctor,
// to decide which constructors to insert. Returns a list of all possible expansions. const Pattern& cand,
Array<Pattern> ExpandWildcards(const Pattern& clause_pat, const Pattern& cand, const Module& mod);
Array<Pattern> ExpandWildcardsTuple(const PatternTuple& clause_tuple,
const Pattern& cand,
const Module& mod);
// Expands all wildcards in the candidate pattern once
// Returns a list of all possible expansions.
Array<Pattern> ExpandWildcards(const Pattern& clause_pat,
const Pattern& cand,
const Module& mod) { const Module& mod) {
auto ctor_cand = cand.as<PatternConstructorNode>(); if (auto clause_ctor = clause_pat.as<PatternConstructorNode>()) {
PatternConstructor clause_ctor = Downcast<PatternConstructor>(clause_pat); return ExpandWildcardsConstructor(GetRef<PatternConstructor>(clause_ctor), cand, mod);
} else {
return ExpandWildcardsTuple(Downcast<PatternTuple>(clause_pat), cand, mod);
}
}
// Expands all wildcards in the candidate pattern once.
// Use the pattern to decide which constructors to insert.
// Returns a list of all possible expansions.
Array<Pattern> ExpandWildcardsConstructor(const PatternConstructor& clause_ctor,
const Pattern& cand,
const Module& mod) {
auto gtv = Downcast<GlobalTypeVar>(clause_ctor->constructor->belong_to); auto gtv = Downcast<GlobalTypeVar>(clause_ctor->constructor->belong_to);
// for a wildcard node, create constructor nodes with wildcards for all args // for a wildcard node, create constructor nodes with wildcards for all args.
if (!ctor_cand) { if (cand.as<PatternWildcardNode>()) {
TypeData td = mod->LookupDef(gtv); TypeData td = mod->LookupDef(gtv);
// for each constructor add a candidate // for each constructor add a candidate.
Array<Pattern> ret; Array<Pattern> ret;
for (auto constructor : td->constructors) { for (auto constructor : td->constructors) {
Array<Pattern> args; Array<Pattern> args;
...@@ -150,27 +197,72 @@ Array<Pattern> ExpandWildcards(const Pattern& clause_pat, const Pattern& cand, ...@@ -150,27 +197,72 @@ Array<Pattern> ExpandWildcards(const Pattern& clause_pat, const Pattern& cand,
return ret; return ret;
} }
// for constructors, we will expand the wildcards in any field auto ctor_cand = Downcast<PatternConstructor>(cand);
// that is an ADT
// for constructors, we will expand the wildcards in any field that is an ADT.
Array<Array<Pattern>> values_by_field; Array<Array<Pattern>> values_by_field;
for (size_t i = 0; i < ctor_cand->constructor->inputs.size(); i++) { for (size_t i = 0; i < ctor_cand->constructor->inputs.size(); i++) {
auto* subpattern = clause_ctor->patterns[i].as<PatternConstructorNode>(); bool subpattern =
// for non-ADT fields, we can only have a wildcard for the value clause_ctor->patterns[i].as<PatternConstructorNode>() ||
clause_ctor->patterns[i].as<PatternTupleNode>();
// for non-ADT fields, we can only have a wildcard for the value.
if (!subpattern) { if (!subpattern) {
values_by_field.push_back({PatternWildcardNode::make()}); values_by_field.push_back({PatternWildcardNode::make()});
continue; } else {
// otherwise, recursively expand.
values_by_field.push_back(ExpandWildcards(clause_ctor->patterns[i],
ctor_cand->patterns[i],
mod));
} }
}
// otherwise, recursively expand // generate new candidates using a cartesian product.
values_by_field.push_back(ExpandWildcards(GetRef<Pattern>(subpattern), auto all_subfields = CartesianProduct(values_by_field);
ctor_cand->patterns[i], mod)); Array<Pattern> ret;
for (auto subfields : all_subfields) {
ret.push_back(PatternConstructorNode::make(ctor_cand->constructor, subfields));
}
return ret;
}
// Expands all wildcards in the candidate pattern once.
// Returns a list of all possible expansions.
Array<Pattern> ExpandWildcardsTuple(const PatternTuple& clause_tuple,
const Pattern& cand,
const Module& mod) {
// for a wildcard node, create constructor nodes with wildcards for all args.
if (cand.as<PatternWildcardNode>()) {
Array<Pattern> args;
for (auto inp : clause_tuple->patterns) {
args.push_back(PatternWildcardNode::make());
}
return {PatternTupleNode::make(args)};
}
auto tuple_cand = Downcast<PatternTuple>(cand);
// for constructors, we will expand the wildcards in any field that is an ADT.
Array<Array<Pattern>> values_by_field;
for (size_t i = 0; i < tuple_cand->patterns.size(); i++) {
bool subpattern =
clause_tuple->patterns[i].as<PatternConstructorNode>() ||
clause_tuple->patterns[i].as<PatternTupleNode>();
// for non-ADT fields, we can only have a wildcard for the value
if (!subpattern) {
values_by_field.push_back({PatternWildcardNode::make()});
} else {
// otherwise, recursively expand
values_by_field.push_back(ExpandWildcards(clause_tuple->patterns[i],
tuple_cand->patterns[i],
mod));
}
} }
// generate new candidates using a cartesian product // generate new candidates using a cartesian product
auto all_subfields = CartesianProduct(values_by_field); auto all_subfields = CartesianProduct(values_by_field);
Array<Pattern> ret; Array<Pattern> ret;
for (auto subfields : all_subfields) { for (auto subfields : all_subfields) {
ret.push_back(PatternConstructorNode::make(ctor_cand->constructor, subfields)); ret.push_back(PatternTupleNode::make(subfields));
} }
return ret; return ret;
} }
......
...@@ -1051,6 +1051,28 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)> ...@@ -1051,6 +1051,28 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
} }
} }
MatchStatus VisitPattern_(const PatternTupleNode* op, const PStatic& ps) final {
if (ps->pstatic.defined()) {
STuple stn = Downcast<STuple>(ps->pstatic);
CHECK_EQ(op->patterns.size(), stn->fields.size());
MatchStatus current_match_status = MatchStatus::Match;
for (size_t i = 0; i < op->patterns.size(); ++i) {
MatchStatus ms = VisitPattern(op->patterns[i], stn->fields[i]);
switch (ms) {
case MatchStatus::Match:
continue;
case MatchStatus::NoMatch:
return MatchStatus::NoMatch;
case MatchStatus::Unknown:
current_match_status = MatchStatus::Unknown;
}
}
return current_match_status;
} else {
return MatchStatus::Unknown;
}
}
void InitializeFuncId(const Expr& e) { void InitializeFuncId(const Expr& e) {
struct InitializeFuncIdVisitor : ExprVisitor, PatternVisitor { struct InitializeFuncIdVisitor : ExprVisitor, PatternVisitor {
PartialEvaluator* pe; PartialEvaluator* pe;
......
...@@ -276,6 +276,27 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>, ...@@ -276,6 +276,27 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
} }
} }
void VisitPattern_(const PatternTupleNode* tup, const Type& t) {
auto pt = GetRef<PatternTuple>(tup);
// we can expect a certain number of arguments
Array<Type> unknown_args;
for (size_t i = 0; i < tup->patterns.size(); i++) {
unknown_args.push_back(IncompleteTypeNode::make(Kind::kType));
}
Type expected = TupleTypeNode::make(unknown_args);
Type unified = Unify(t, expected, GetRef<NodeRef>(tup));
auto* tt = unified.as<TupleTypeNode>();
if (!tt) {
this->ReportFatalError(pt, RELAY_ERROR("Expected a tuple type, got " << unified));
}
CHECK(tup->patterns.size() == tt->fields.size()) << "not enough pattern";
for (size_t i = 0; i < tup->patterns.size(); ++i) {
VisitPattern(tup->patterns[i], tt->fields[i]);
}
}
void VisitPattern_(const PatternVarNode* pv, const Type& t) { void VisitPattern_(const PatternVarNode* pv, const Type& t) {
Type vt = GetType(pv->var); Type vt = GetType(pv->var);
Unify(vt, t, pv->span); Unify(vt, t, pv->span);
......
...@@ -611,6 +611,21 @@ def test_hash_unequal(): ...@@ -611,6 +611,21 @@ def test_hash_unequal():
assert not analysis.structural_hash(func1) == analysis.structural_hash(func3) assert not analysis.structural_hash(func1) == analysis.structural_hash(func3)
def test_tuple_match():
a = relay.Var("a")
b = relay.Var("b")
clause = relay.Clause(relay.PatternTuple([relay.PatternVar(a), relay.PatternVar(b)]), a + b)
x = relay.Match(relay.Tuple([relay.const(1), relay.const(1)]), [clause])
a = relay.Var("a")
b = relay.Var("b")
clause = relay.Clause(relay.PatternTuple([relay.PatternVar(a), relay.PatternVar(b)]), a + b)
y = relay.Match(relay.Tuple([relay.const(1), relay.const(1)]), [clause])
assert analysis.alpha_equal(x, y)
assert analysis.structural_hash(x) == analysis.structural_hash(y)
if __name__ == "__main__": if __name__ == "__main__":
test_tensor_type_alpha_equal() test_tensor_type_alpha_equal()
test_incomplete_type_alpha_equal() test_incomplete_type_alpha_equal()
......
...@@ -331,6 +331,14 @@ def test_nat_update(): ...@@ -331,6 +331,14 @@ def test_nat_update():
transform.PartialEvaluate()(m) transform.PartialEvaluate()(m)
def test_tuple_match():
a = relay.Var("a")
b = relay.Var("b")
clause = relay.Clause(relay.PatternTuple([relay.PatternVar(a), relay.PatternVar(b)]), a + b)
x = relay.Match(relay.Tuple([relay.const(1), relay.const(1)]), [clause])
assert_alpha_equal(dcpe(x), const(2))
if __name__ == '__main__': if __name__ == '__main__':
test_nat_update() test_nat_update()
test_ref() test_ref()
...@@ -351,3 +359,4 @@ if __name__ == '__main__': ...@@ -351,3 +359,4 @@ if __name__ == '__main__':
test_match_nat_id() test_match_nat_id()
test_concat() test_concat()
test_triangle_number() test_triangle_number()
test_tuple_match()
...@@ -265,3 +265,11 @@ def test_mixed_adt_constructors(): ...@@ -265,3 +265,11 @@ def test_mixed_adt_constructors():
relay.Clause(relay.PatternConstructor(p.nil, []), v) relay.Clause(relay.PatternConstructor(p.nil, []), v)
]) ])
assert len(unmatched_cases(list_of_boxes_comp, mod)) == 0 assert len(unmatched_cases(list_of_boxes_comp, mod)) == 0
def test_tuple_match():
a = relay.Var("a")
b = relay.Var("b")
clause = relay.Clause(relay.PatternTuple([relay.PatternVar(a), relay.PatternVar(b)]), a + b)
x = relay.Match(relay.Tuple([relay.const(1), relay.const(1)]), [clause])
assert len(unmatched_cases(x)) == 0
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment