Unverified Commit 529ee1fe by masahi Committed by GitHub

[Relay] Fix VM compiler for while loop with free vars (#4889)

* add additional switch to handle nested call node

* Fix VM compiler for while loop with free var
parent d50ba721
...@@ -637,6 +637,9 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> { ...@@ -637,6 +637,9 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
// emit invoke closure here. // emit invoke closure here.
VisitExpr(GetRef<Var>(var_node)); VisitExpr(GetRef<Var>(var_node));
Emit(Instruction::InvokeClosure(last_register_, args_registers, NewRegister())); Emit(Instruction::InvokeClosure(last_register_, args_registers, NewRegister()));
} else if (auto inner_call_node = op.as<CallNode>()) {
VisitExpr(GetRef<Call>(inner_call_node));
Emit(Instruction::InvokeClosure(last_register_, args_registers, NewRegister()));
} else { } else {
// Finally if there are any other cases this is a bug. // Finally if there are any other cases this is a bug.
LOG(FATAL) << "internal error: unreachable code," LOG(FATAL) << "internal error: unreachable code,"
......
...@@ -23,6 +23,7 @@ from tvm import relay ...@@ -23,6 +23,7 @@ from tvm import relay
from tvm.relay.scope_builder import ScopeBuilder from tvm.relay.scope_builder import ScopeBuilder
from tvm.relay.testing.config import ctx_list from tvm.relay.testing.config import ctx_list
from tvm.relay.prelude import Prelude from tvm.relay.prelude import Prelude
from tvm.relay.loops import while_loop
from tvm.relay import testing from tvm.relay import testing
def check_result(args, expected_result, mod=None): def check_result(args, expected_result, mod=None):
...@@ -576,5 +577,31 @@ def test_vm_optimize(): ...@@ -576,5 +577,31 @@ def test_vm_optimize():
comp = relay.vm.VMCompiler() comp = relay.vm.VMCompiler()
opt_mod, _ = comp.optimize(mod, "llvm", params) opt_mod, _ = comp.optimize(mod, "llvm", params)
def test_loop_free_var():
x = relay.var('x', shape=(), dtype='int32')
i = relay.var('i', shape=(), dtype='int32')
s = relay.var('s', shape=(), dtype='int32')
def cond(i, _):
return i < relay.const(10, dtype='int32')
def body_no_free_var(i, acc):
incr = relay.const(1, "int32")
return i + incr, acc + i
def body_with_free_var(i, acc):
incr = relay.const(1, "int32")
return i + incr, acc + x
for args, body, expected in zip([[], [1]],
[body_no_free_var, body_with_free_var],
[45, 10]):
loop = while_loop(cond, [i, s], body)
tup = loop(relay.const(0, dtype='int32'), relay.zeros(shape=(), dtype='int32'))
ret = relay.TupleGetItem(tup, 1)
mod = tvm.IRModule()
mod["main"] = relay.Function(relay.analysis.free_vars(ret), ret)
check_result(args, expected, mod=mod)
if __name__ == "__main__": if __name__ == "__main__":
pytest.main([__file__]) pytest.main([__file__])
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment