Commit a1b86100 by 雾雨魔理沙 Committed by Tianqi Chen

[Relay] fix error in ANF (too agressively inline atomic expression and create…

[Relay] fix error in ANF (too agressively inline atomic expression and create free variable). (#2665)
parent 3c70b0d0
...@@ -256,6 +256,10 @@ bool IsPrimitiveFunction(const Expr& e) { ...@@ -256,6 +256,10 @@ bool IsPrimitiveFunction(const Expr& e) {
return e.as<FunctionNode>() && Downcast<Function>(e)->IsPrimitive(); return e.as<FunctionNode>() && Downcast<Function>(e)->IsPrimitive();
} }
/* Special care is needed to handle local recursion.
* Fill additionally take a (possibly null) Var argument,
* If it is not null, Fill is required to bind the transformed result to that var.
*/
class Fill : ExprFunctor<Expr(const Expr&, const Var&)> { class Fill : ExprFunctor<Expr(const Expr&, const Var&)> {
public: public:
static Expr ToANormalForm(const Expr& e, static Expr ToANormalForm(const Expr& e,
...@@ -307,12 +311,18 @@ class Fill : ExprFunctor<Expr(const Expr&, const Var&)> { ...@@ -307,12 +311,18 @@ class Fill : ExprFunctor<Expr(const Expr&, const Var&)> {
} }
Expr VisitExpr(const Expr& e) { Expr VisitExpr(const Expr& e) {
Var v = VarNode::make(std::string("x"), IncompleteTypeNode::make(Kind::kType)); return this->VisitExpr(e, Var());
return this->VisitExpr(e, v); }
Expr Atomic(const Expr& orig, const Expr& now, const Var& v) {
return v.defined() ? GetScope(orig)->ll->Push(v, now) : now;
} }
Expr Compound(const Expr& orig, const Expr& now, const Var& v) { Expr Compound(const Expr& orig, const Expr& now, const Var& v) {
return GetScope(orig)->ll->Push(v, now); Var var = v.defined() ?
v :
VarNode::make(std::string("x"), IncompleteTypeNode::make(Kind::kType));
return GetScope(orig)->ll->Push(var, now);
} }
Expr VisitExpr_(const CallNode* c, const Var& v) final { Expr VisitExpr_(const CallNode* c, const Var& v) final {
...@@ -389,7 +399,8 @@ class Fill : ExprFunctor<Expr(const Expr&, const Var&)> { ...@@ -389,7 +399,8 @@ class Fill : ExprFunctor<Expr(const Expr&, const Var&)> {
} }
Expr VisitExpr_(const VarNode* vn, const Var& v) final { Expr VisitExpr_(const VarNode* vn, const Var& v) final {
return GetRef<Expr>(vn); Expr e = GetRef<Expr>(vn);
return Atomic(e, e, v);
} }
Expr VisitExpr_(const GlobalVarNode* gvn, const Var& v) final { Expr VisitExpr_(const GlobalVarNode* gvn, const Var& v) final {
...@@ -398,15 +409,17 @@ class Fill : ExprFunctor<Expr(const Expr&, const Var&)> { ...@@ -398,15 +409,17 @@ class Fill : ExprFunctor<Expr(const Expr&, const Var&)> {
visited_->insert(gv); visited_->insert(gv);
mod_->Update(gv, Downcast<Function>(relay::ToANormalForm(mod_->Lookup(gv), mod_, visited_))); mod_->Update(gv, Downcast<Function>(relay::ToANormalForm(mod_->Lookup(gv), mod_, visited_)));
} }
return std::move(gv); return Atomic(gv, gv, v);
} }
Expr VisitExpr_(const OpNode* op, const Var& v) final { Expr VisitExpr_(const OpNode* op, const Var& v) final {
return GetRef<Expr>(op); Expr e = GetRef<Expr>(op);
return Atomic(e, e, v);
} }
Expr VisitExpr_(const ConstructorNode* c, const Var& v) final { Expr VisitExpr_(const ConstructorNode* c, const Var& v) final {
return GetRef<Expr>(c); Expr e = GetRef<Expr>(c);
return Atomic(e, e, v);
} }
Expr VisitExpr_(const MatchNode* m, const Var& v) final { Expr VisitExpr_(const MatchNode* m, const Var& v) final {
...@@ -418,8 +431,7 @@ class Fill : ExprFunctor<Expr(const Expr&, const Var&)> { ...@@ -418,8 +431,7 @@ class Fill : ExprFunctor<Expr(const Expr&, const Var&)> {
c->lhs, c->lhs,
GetSubScope(e, 1 + clauses.size())->ll->Get(VisitExpr(c->rhs)))); GetSubScope(e, 1 + clauses.size())->ll->Get(VisitExpr(c->rhs))));
} }
Expr r = Compound(e, MatchNode::make(data, clauses), v); return Compound(e, MatchNode::make(data, clauses), v);
return r;
} }
}; };
......
...@@ -138,6 +138,15 @@ def test_add(): ...@@ -138,6 +138,15 @@ def test_add():
assert count(intrp.evaluate(to_a_normal_form(add(s(z()), s(z())), mod))) == 2 assert count(intrp.evaluate(to_a_normal_form(add(s(z()), s(z())), mod))) == 2
assert "let" in mod[add].astext() assert "let" in mod[add].astext()
def test_let():
x = relay.Var("x")
y = relay.Var("y")
d = relay.const(4.0, 'float32')
body = relay.Let(y, x, x + y)
body = relay.Let(x, d, body)
check_eval(body, 8)
check_eval(to_a_normal_form(body), 8)
if __name__ == '__main__': if __name__ == '__main__':
test_explicit_bound() test_explicit_bound()
test_order() test_order()
...@@ -145,3 +154,4 @@ if __name__ == '__main__': ...@@ -145,3 +154,4 @@ if __name__ == '__main__':
test_recursion() test_recursion()
test_ref() test_ref()
test_add() test_add()
test_let()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment