Unverified Commit 2586b4d3 by Wei Chen Committed by GitHub

[Relay][VM] Fix compilation of If-Elses (#5040)

parent d56829ea
......@@ -366,7 +366,9 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
this->Emit(Instruction::If(test_register, target_register, 0, 0));
this->VisitExpr(if_node->true_branch);
size_t true_register = last_register_;
// It saves the result of If-Else expression.
auto merge_register = NewRegister();
Emit(Instruction::Move(last_register_, merge_register));
Emit(Instruction::Goto(0));
// Finally store how many instructions there are in the
......@@ -378,7 +380,7 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
size_t false_register = last_register_;
// In else-branch, override the then-branch register
Emit(Instruction::Move(false_register, true_register));
Emit(Instruction::Move(false_register, merge_register));
// Compute the total number of instructions
// after generating false.
auto after_false = this->instructions_.size();
......@@ -397,7 +399,7 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
// Patch the Goto.
this->instructions_[after_true - 1].pc_offset = (after_false - after_true) + 1;
this->last_register_ = true_register;
this->last_register_ = merge_register;
}
void EmitShapeFunc(Function func, Array<Expr> inputs, Array<Expr> outputs) {
......
......@@ -142,6 +142,25 @@ def test_simple_if():
# diff
check_result([x_data, y_data], y_data, mod=mod)
def test_multiple_ifs():
mod = tvm.IRModule({})
b = relay.var('b')
v0 = relay.var('v0')
v1 = relay.var('v1')
v2 = relay.var('v2')
v3 = relay.var('v3')
out = relay.Tuple([v2, v3])
out = relay.Let(v3, relay.If(b, v1, v0), out)
out = relay.Let(v2, relay.If(b, v0, v1), out)
out = relay.Let(v1, relay.Tuple([relay.const(1)]), out)
out = relay.Let(v0, relay.Tuple([relay.const(0)]), out)
fn = relay.Function([b], out)
mod['main'] = fn
ctx = tvm.runtime.ndarray.context('llvm', 0)
vm = relay.create_executor(ctx=ctx, mod=mod, kind='vm')
res = vmobj_to_list(vm.evaluate()(False))
assert(res == [1, 0])
def test_simple_call():
mod = tvm.IRModule({})
sum_up = relay.GlobalVar('sum_up')
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment