Commit a1b86100 by 雾雨魔理沙 Committed by Tianqi Chen

[Relay] fix error in ANF (too agressively inline atomic expression and create…

[Relay] fix error in ANF (too agressively inline atomic expression and create free variable). (#2665)
parent 3c70b0d0
......@@ -256,6 +256,10 @@ bool IsPrimitiveFunction(const Expr& e) {
return e.as<FunctionNode>() && Downcast<Function>(e)->IsPrimitive();
}
/* Special care is needed to handle local recursion.
* Fill additionally take a (possibly null) Var argument,
* If it is not null, Fill is required to bind the transformed result to that var.
*/
class Fill : ExprFunctor<Expr(const Expr&, const Var&)> {
public:
static Expr ToANormalForm(const Expr& e,
......@@ -307,12 +311,18 @@ class Fill : ExprFunctor<Expr(const Expr&, const Var&)> {
}
Expr VisitExpr(const Expr& e) {
Var v = VarNode::make(std::string("x"), IncompleteTypeNode::make(Kind::kType));
return this->VisitExpr(e, v);
return this->VisitExpr(e, Var());
}
Expr Atomic(const Expr& orig, const Expr& now, const Var& v) {
return v.defined() ? GetScope(orig)->ll->Push(v, now) : now;
}
Expr Compound(const Expr& orig, const Expr& now, const Var& v) {
return GetScope(orig)->ll->Push(v, now);
Var var = v.defined() ?
v :
VarNode::make(std::string("x"), IncompleteTypeNode::make(Kind::kType));
return GetScope(orig)->ll->Push(var, now);
}
Expr VisitExpr_(const CallNode* c, const Var& v) final {
......@@ -389,7 +399,8 @@ class Fill : ExprFunctor<Expr(const Expr&, const Var&)> {
}
Expr VisitExpr_(const VarNode* vn, const Var& v) final {
return GetRef<Expr>(vn);
Expr e = GetRef<Expr>(vn);
return Atomic(e, e, v);
}
Expr VisitExpr_(const GlobalVarNode* gvn, const Var& v) final {
......@@ -398,15 +409,17 @@ class Fill : ExprFunctor<Expr(const Expr&, const Var&)> {
visited_->insert(gv);
mod_->Update(gv, Downcast<Function>(relay::ToANormalForm(mod_->Lookup(gv), mod_, visited_)));
}
return std::move(gv);
return Atomic(gv, gv, v);
}
Expr VisitExpr_(const OpNode* op, const Var& v) final {
return GetRef<Expr>(op);
Expr e = GetRef<Expr>(op);
return Atomic(e, e, v);
}
Expr VisitExpr_(const ConstructorNode* c, const Var& v) final {
return GetRef<Expr>(c);
Expr e = GetRef<Expr>(c);
return Atomic(e, e, v);
}
Expr VisitExpr_(const MatchNode* m, const Var& v) final {
......@@ -418,8 +431,7 @@ class Fill : ExprFunctor<Expr(const Expr&, const Var&)> {
c->lhs,
GetSubScope(e, 1 + clauses.size())->ll->Get(VisitExpr(c->rhs))));
}
Expr r = Compound(e, MatchNode::make(data, clauses), v);
return r;
return Compound(e, MatchNode::make(data, clauses), v);
}
};
......
......@@ -138,6 +138,15 @@ def test_add():
assert count(intrp.evaluate(to_a_normal_form(add(s(z()), s(z())), mod))) == 2
assert "let" in mod[add].astext()
def test_let():
x = relay.Var("x")
y = relay.Var("y")
d = relay.const(4.0, 'float32')
body = relay.Let(y, x, x + y)
body = relay.Let(x, d, body)
check_eval(body, 8)
check_eval(to_a_normal_form(body), 8)
if __name__ == '__main__':
test_explicit_bound()
test_order()
......@@ -145,3 +154,4 @@ if __name__ == '__main__':
test_recursion()
test_ref()
test_add()
test_let()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment