Commit dee8cf9b by 雾雨魔理沙 Committed by Tianqi Chen

[Relay] fix anf for reference and pattern matching (#2637)

parent cc5a3cf0
...@@ -120,6 +120,22 @@ class DependencyGraph::Creator : private ExprFunctor<void(const Expr& e)> { ...@@ -120,6 +120,22 @@ class DependencyGraph::Creator : private ExprFunctor<void(const Expr& e)> {
Depend(n, t->tuple); Depend(n, t->tuple);
} }
void VisitExpr_(const RefCreateNode* r) final {
DependencyGraph::Node* n = graph_.expr_node[GetRef<Expr>(r)];
Depend(n, r->value);
}
void VisitExpr_(const RefReadNode* r) final {
DependencyGraph::Node* n = graph_.expr_node[GetRef<Expr>(r)];
Depend(n, r->ref);
}
void VisitExpr_(const RefWriteNode* r) final {
DependencyGraph::Node* n = graph_.expr_node[GetRef<Expr>(r)];
Depend(n, r->ref);
Depend(n, r->value);
}
void VisitExpr_(const IfNode* i) final { void VisitExpr_(const IfNode* i) final {
DependencyGraph::Node* n = graph_.expr_node[GetRef<Expr>(i)]; DependencyGraph::Node* n = graph_.expr_node[GetRef<Expr>(i)];
DependencyGraph::Node* t = NewNode(true); DependencyGraph::Node* t = NewNode(true);
...@@ -150,6 +166,21 @@ class DependencyGraph::Creator : private ExprFunctor<void(const Expr& e)> { ...@@ -150,6 +166,21 @@ class DependencyGraph::Creator : private ExprFunctor<void(const Expr& e)> {
graph_.post_dfs_order.push_back(b); graph_.post_dfs_order.push_back(b);
} }
void VisitExpr_(const MatchNode* m) final {
DependencyGraph::Node* n = graph_.expr_node[GetRef<Expr>(m)];
Depend(n, m->data);
std::vector<DependencyGraph::Node*> v;
for (const Clause& c : m->clauses) {
DependencyGraph::Node* b = NewNode(true);
Depend(n, b);
Depend(b, c->rhs);
v.push_back(b);
}
for (auto it = v.rbegin(); it != v.rend(); ++it) {
graph_.post_dfs_order.push_back(*it);
}
}
void VisitExpr_(const VarNode* v) final { } void VisitExpr_(const VarNode* v) final { }
void VisitExpr_(const GlobalVarNode* v) final { } void VisitExpr_(const GlobalVarNode* v) final { }
...@@ -157,6 +188,8 @@ class DependencyGraph::Creator : private ExprFunctor<void(const Expr& e)> { ...@@ -157,6 +188,8 @@ class DependencyGraph::Creator : private ExprFunctor<void(const Expr& e)> {
void VisitExpr_(const ConstantNode* c) final { } void VisitExpr_(const ConstantNode* c) final { }
void VisitExpr_(const OpNode* o) final { } void VisitExpr_(const OpNode* o) final { }
void VisitExpr_(const ConstructorNode* c) final { }
}; };
DependencyGraph DependencyGraph::Create(common::Arena* arena, const Expr& body) { DependencyGraph DependencyGraph::Create(common::Arena* arena, const Expr& body) {
...@@ -305,6 +338,21 @@ class Fill : ExprFunctor<Expr(const Expr&, const Var&)> { ...@@ -305,6 +338,21 @@ class Fill : ExprFunctor<Expr(const Expr&, const Var&)> {
return Compound(e, TupleGetItemNode::make(VisitExpr(t->tuple), t->index), v); return Compound(e, TupleGetItemNode::make(VisitExpr(t->tuple), t->index), v);
} }
Expr VisitExpr_(const RefCreateNode* r, const Var& v) final {
Expr e = GetRef<Expr>(r);
return Compound(e, RefCreateNode::make(VisitExpr(r->value)), v);
}
Expr VisitExpr_(const RefReadNode* r, const Var& v) final {
Expr e = GetRef<Expr>(r);
return Compound(e, RefReadNode::make(VisitExpr(r->ref)), v);
}
Expr VisitExpr_(const RefWriteNode* r, const Var& v) final {
Expr e = GetRef<Expr>(r);
return Compound(e, RefWriteNode::make(VisitExpr(r->ref), VisitExpr(r->value)), v);
}
Expr VisitExpr_(const IfNode* i, const Var& v) final { Expr VisitExpr_(const IfNode* i, const Var& v) final {
Expr e = GetRef<Expr>(i); Expr e = GetRef<Expr>(i);
Expr ret = IfNode::make(VisitExpr(i->cond), Expr ret = IfNode::make(VisitExpr(i->cond),
...@@ -356,6 +404,23 @@ class Fill : ExprFunctor<Expr(const Expr&, const Var&)> { ...@@ -356,6 +404,23 @@ class Fill : ExprFunctor<Expr(const Expr&, const Var&)> {
Expr VisitExpr_(const OpNode* op, const Var& v) final { Expr VisitExpr_(const OpNode* op, const Var& v) final {
return GetRef<Expr>(op); return GetRef<Expr>(op);
} }
Expr VisitExpr_(const ConstructorNode* c, const Var& v) final {
return GetRef<Expr>(c);
}
Expr VisitExpr_(const MatchNode* m, const Var& v) final {
Expr e = GetRef<Expr>(m);
Expr data = VisitExpr(m->data);
std::vector<Clause> clauses;
for (const Clause& c : m->clauses) {
clauses.push_back(ClauseNode::make(
c->lhs,
GetSubScope(e, 1 + clauses.size())->ll->Get(VisitExpr(c->rhs))));
}
Expr r = Compound(e, MatchNode::make(data, clauses), v);
return r;
}
}; };
Expr ToANFAux(const Expr& e, const Module& m, std::set<GlobalVar>* gv) { Expr ToANFAux(const Expr& e, const Module& m, std::set<GlobalVar>* gv) {
......
...@@ -3,7 +3,8 @@ import tvm ...@@ -3,7 +3,8 @@ import tvm
from tvm import relay from tvm import relay
from tvm.relay.ir_pass import to_anf, alpha_equal, infer_type from tvm.relay.ir_pass import to_anf, alpha_equal, infer_type
from tvm.relay import op, create_executor from tvm.relay import op, create_executor
from tvm.relay.backend.interpreter import Value, TupleValue from tvm.relay.backend.interpreter import Value, TupleValue, ConstructorValue
from tvm.relay.prelude import Prelude
def check_eval(expr, expected_result, mod=None, rtol=1e-07): def check_eval(expr, expected_result, mod=None, rtol=1e-07):
...@@ -99,8 +100,48 @@ def test_recursion(): ...@@ -99,8 +100,48 @@ def test_recursion():
check_eval(f(relay.const(5, 'int64')), 30.0, mod=mod) check_eval(f(relay.const(5, 'int64')), 30.0, mod=mod)
def test_ref():
i = relay.Var('i')
iv = relay.Var('iv')
u = relay.Var('u')
uv = relay.Var('uv')
body = relay.add(iv, uv)
body = relay.Let(uv, relay.RefRead(i), body)
body = relay.Let(u, relay.RefWrite(i, relay.const(2)), body)
body = relay.Let(iv, relay.RefRead(i), body)
body = relay.Let(i, relay.RefCreate(relay.const(1)), body)
check_eval(body, 3)
check_eval(to_anf(body), 3)
# this is an example of using the adt value in python side
def count(n):
assert isinstance(n, ConstructorValue)
if n.constructor.name_hint == 's':
return 1 + count(n.fields[0])
else:
assert n.constructor.name_hint == 'z'
return 0
def test_add():
mod = relay.Module()
p = Prelude(mod)
nat = p.nat
add = p.add
s = p.s
z = p.z
ctx = tvm.context("llvm", 0)
intrp = create_executor(mod=mod, ctx=ctx, target="llvm")
assert mod[add].checked_type == relay.FuncType([nat(), nat()], nat())
assert count(intrp.evaluate(add(s(z()), s(z())))) == 2
assert count(intrp.evaluate(to_anf(add(s(z()), s(z())), mod))) == 2
assert "let" in mod[add].astext()
if __name__ == '__main__': if __name__ == '__main__':
test_explicit_bound() test_explicit_bound()
test_order() test_order()
test_if() test_if()
test_recursion() test_recursion()
test_ref()
test_add()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment