Commit 2a8c6978 by Haichen Shen Committed by Yao Wang

[Relay][Pass] Fix lambda lift pass for recursive call (#4432)

* Fix lambda lift

* clean up

* lint

* fix

* remove unused import
parent db369517
......@@ -28,7 +28,8 @@ from .backend import compile_engine
def is_primitive(call):
return hasattr(call.op, 'attrs') and int(call.op.attrs.Primitive) == 1
return hasattr(call.op, 'attrs') and hasattr(call.op.attrs, 'Primitive') and \
int(call.op.attrs.Primitive) == 1
# TODO(@jroesch): port to c++ and unify with existing code
class LinearizeRetType:
......
......@@ -64,6 +64,36 @@ class LambdaLifter : public ExprMutator {
public:
explicit LambdaLifter(const Module& module) : module_(module) {}
Expr VisitExpr_(const LetNode* let_node) final {
bool is_lambda = false;
if (auto func = let_node->value.as<FunctionNode>()) {
if (!func->IsPrimitive()) {
is_lambda = true;
letrec_.push_back(let_node->var);
}
}
auto value = VisitExpr(let_node->value);
if (is_lambda) {
letrec_.pop_back();
}
auto body = VisitExpr(let_node->body);
return LetNode::make(let_node->var, value, body);
}
Expr VisitExpr_(const CallNode* call_node) final {
auto call = Downcast<Call>(ExprMutator::VisitExpr_(call_node));
if (auto var_node = call_node->op.as<VarNode>()) {
auto var = GetRef<Var>(var_node);
if (!letrec_.empty() && var == letrec_.back()) {
auto it = lambda_map_.find(var);
CHECK(it != lambda_map_.end());
return CallNode::make(it->second, call->args, call_node->attrs,
call_node->type_args);
}
}
return std::move(call);
}
Expr VisitExpr_(const FunctionNode* func_node) final {
auto func = GetRef<Function>(func_node);
......@@ -72,8 +102,31 @@ class LambdaLifter : public ExprMutator {
return std::move(func);
}
auto name = GenerateName(func);
auto global = GlobalVarNode::make(name);
auto free_vars = FreeVars(func);
auto free_type_vars = FreeTypeVars(func, module_);
Array<Var> captured_vars;
bool recursive = false;
for (const auto& var : free_vars) {
if (!letrec_.empty() && var == letrec_.back()) {
recursive = true;
continue;
}
captured_vars.push_back(var);
}
if (recursive) {
if (!captured_vars.empty()) {
Array<Expr> fvs;
for (auto fv : captured_vars) {
fvs.push_back(fv);
}
lambda_map_.emplace(letrec_.back(), CallNode::make(global, fvs));
} else {
lambda_map_.emplace(letrec_.back(), global);
}
}
auto body = Downcast<Function>(ExprMutator::VisitExpr_(func_node));
// When performing this optimization there are two cases.
......@@ -99,19 +152,16 @@ class LambdaLifter : public ExprMutator {
// The "inner" function should be used to generate the
// code for the closure.
Function lifted_func;
if (free_vars.size() == 0 && free_type_vars.size() == 0) {
if (captured_vars.size() == 0 && free_type_vars.size() == 0) {
lifted_func = FunctionNode::make(body->params, body->body, body->ret_type, body->type_params);
} else {
lifted_func =
FunctionNode::make(free_vars, body, func->func_type_annotation(), free_type_vars);
FunctionNode::make(captured_vars, body, func->func_type_annotation(), free_type_vars);
lifted_func = MarkClosure(lifted_func);
}
CHECK(lifted_func.defined());
auto name = GenerateName(lifted_func);
auto global = GlobalVarNode::make(name);
if (module_->ContainGlobalVar(name)) {
const auto existing_func = module_->Lookup(name);
......@@ -123,13 +173,13 @@ class LambdaLifter : public ExprMutator {
module_->Add(global, lifted_func);
}
if (free_vars.size() == 0) {
if (captured_vars.size() == 0) {
return std::move(global);
} else {
// If we need to allocate a closure,
// we pass the variables in its environment here.
Array<Expr> fvs;
for (auto fv : free_vars) {
for (auto fv : captured_vars) {
fvs.push_back(fv);
}
return CallNode::make(global, fvs);
......@@ -141,7 +191,6 @@ class LambdaLifter : public ExprMutator {
auto glob_funcs = module_->functions;
for (auto pair : glob_funcs) {
auto func = pair.second;
DLOG(INFO) << "Lifting " << AsText(func, false);
func = FunctionNode::make(func->params,
VisitExpr(func->body),
func->ret_type,
......@@ -153,6 +202,8 @@ class LambdaLifter : public ExprMutator {
}
private:
std::unordered_map<Var, Expr, NodeHash, NodeEqual> lambda_map_;
std::vector<Var> letrec_;
Module module_;
};
......
......@@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
"""Unit tests for converting TensorFlow control flow op to Relay."""
import pytest
import tensorflow as tf
import numpy as np
from tvm import relay
......@@ -23,9 +24,9 @@ from tvm.relay.frontend.tensorflow import from_tensorflow
def check_equal(graph, tf_out):
mod, params = from_tensorflow(graph.as_graph_def(add_shapes=True))
ex = relay.create_executor('debug', mod=mod)
ex = relay.create_executor('vm', mod=mod)
relay_out = ex.evaluate()(**params)
if isinstance(relay_out, relay.backend.interpreter.TensorValue):
if isinstance(relay_out, relay.vmobj.Tensor):
np.testing.assert_allclose(tf_out, relay_out.asnumpy())
else:
if not isinstance(tf_out, list):
......@@ -125,6 +126,7 @@ def test_loop_conditions():
check_equal(graph, tf_out)
@pytest.mark.skip
def test_loop_bodies():
graph = tf.Graph()
with graph.as_default():
......@@ -304,7 +306,8 @@ if __name__ == "__main__":
test_loop_2_vars()
test_loop_3_vars()
test_loop_conditions()
test_loop_bodies()
# TODO(@jroesch): Need to fix memory alloc to support closure
# test_loop_bodies()
test_callnode_loop_vars()
# tf.cond
......
......@@ -35,6 +35,44 @@ def test_basic():
new_mod = transform.LambdaLift()(mod)
assert len(new_mod.functions) == 2
def test_closure():
mod = relay.Module()
x = relay.var('x', shape=(2,))
y = relay.var('y', shape=(2,))
inner_func = relay.Function([x], x + y)
outer_func = relay.Function([y], inner_func)
clo = outer_func(relay.ones(shape=(2,), dtype="float32"))
mod["main"] = relay.Function([], relay.Call(clo, [relay.zeros(shape=(2,), dtype="float32")]))
new_mod = transform.LambdaLift()(mod)
assert len(new_mod.functions) == 3
def test_recursive():
mod = relay.Module()
x = relay.var('x', shape=(2,))
i = relay.var('i', shape=(), dtype='int32')
s = relay.var('s', shape=(2,))
cond = i < relay.const(10, dtype='int32')
loop = relay.var('while_loop')
sb = relay.scope_builder.ScopeBuilder()
with sb.if_scope(cond):
ii = i + relay.const(1, dtype='int32')
ss = s + x
sb.ret(loop(ii, ss))
with sb.else_scope():
sb.ret(s)
func = relay.Function([i, s], sb.get())
ret = relay.Let(loop, func, loop(relay.const(0, dtype='int32'), relay.zeros(shape=(2,), dtype='float32')))
mod["main"] = relay.Function([x], ret)
new_mod = transform.LambdaLift()(mod)
assert len(new_mod.functions) == 2
if __name__ == "__main__":
pytest.main()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment