Unverified Commit 09eb5082 by Trevor Morris Committed by GitHub

[BYOC] Prevent duplicate outputs in subgraph Tuple (#5320)

* Fix duplicate output in partitiongraph

* Add test case

* Fix test_annotated_regions with duplicate compiler_end outputs

* Revert "Fix duplicate output in partitiongraph"

This reverts commit e1f8ef3f4ca5b2aaa31ace6fa968bb50e5e4d1fa.

* Prevent duplicate outputs in Tuple in PartitionGraph

* Fix lint

* Add another test case for when regions are merged and TupleGetItem is duplicated

* Pull GetFunctionOutput out of branch, improve description of GetFunctionOutput

* Use std::move for GetFunctionOutput. Fix typo in test case name

* Use tvm.transform.Sequential
parent e8138f7d
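
The substance of the fix: when PartitionGraph builds the output Tuple for an extracted subgraph, a value consumed by several external ops could previously be appended to that Tuple once per consumer, yielding duplicate outputs. The change records each distinct output once and points every consumer at the same tuple field. A minimal Python sketch of that dedup idea (illustrative only; the actual pass is implemented in C++, and the helper name here is made up):

# Hypothetical sketch of the dedup step, not the actual C++ implementation.
# Given the values a subgraph must expose, return a deduplicated output
# list plus, for each requested value, the tuple index its consumer should use.
def dedup_subgraph_outputs(values):
    index_of = {}   # value -> position in the deduplicated tuple
    unique = []     # each distinct output, exactly once
    indices = []    # per-request index into `unique`
    for v in values:
        if v not in index_of:
            index_of[v] = len(unique)
            unique.append(v)
        indices.append(index_of[v])
    return unique, indices

# e.g. three consumers of the same abs(data) value:
# dedup_subgraph_outputs(["abs"] * 3) -> (["abs"], [0, 0, 0])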
@@ -23,6 +23,7 @@ import pytest
import tvm
import tvm.relay.testing
import tvm.relay.op as reg
from tvm import relay
from tvm import runtime
from tvm.relay import transform
@@ -1036,6 +1037,138 @@ def test_multiple_use_of_an_output():
    test_same_output_region()
    test_different_output_region()

def test_duplicate_outputs():
    target = "test_duplicate_outputs"

    @reg.register("abs", "target." + target)
    def abs(attrs, args):  # pylint: disable=unused-variable
        return True

    def create_graph():
        data = relay.var('data', shape=(10, 10))
        x = relay.abs(data)
        out_1 = relay.nn.relu(x)
        out_2 = relay.tanh(x)
        out_3 = relay.log(x)
        out = relay.Tuple([out_1, out_2, out_3])
        func = relay.Function([data], out)
        return func

    def expected():
        mod = tvm.IRModule()

        # function 0: the partitioned subgraph has a single output, even
        # though abs(data) is consumed by three downstream ops
        f0_i0 = relay.var(target+"_0_i0", shape=(10, 10))
        f0_o0 = relay.abs(f0_i0)
        func0 = relay.Function([f0_i0], f0_o0)
        func0 = func0.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
        func0 = func0.with_attr("Inline", tvm.tir.IntImm("int32", 1))
        func0 = func0.with_attr("Compiler", target)
        func0 = func0.with_attr("global_symbol", target+"_0")
        gv0 = relay.GlobalVar(target+"_0")
        mod[gv0] = func0

        # body: all three consumers share one call into the subgraph
        data = relay.var('data', shape=(10, 10))
        function_out = gv0(data)
        out_1 = relay.nn.relu(function_out)
        out_2 = relay.tanh(function_out)
        out_3 = relay.log(function_out)
        out = relay.Tuple([out_1, out_2, out_3])
        func = relay.Function([data], out)
        mod["main"] = func
        return mod

    mod = tvm.IRModule()
    mod["main"] = create_graph()

    seq = tvm.transform.Sequential([
        transform.AnnotateTarget(target),
        transform.MergeCompilerRegions(),
        transform.PartitionGraph(),
    ])

    ref_mod = expected()
    partitioned = seq(mod)
    assert tvm.ir.structural_equal(partitioned, ref_mod, map_free_vars=True)
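
For a quick sanity check outside the assertion, the partitioned module can be printed directly. A minimal sketch, reusing create_graph and seq from the test above (assumes a TVM build with the BYOC passes; the printed IR is not reproduced here):

mod = tvm.IRModule.from_expr(create_graph())
partitioned = seq(mod)
print(partitioned)  # "main" should call the partitioned abs subgraph once,
                    # with relu, tanh, and log all consuming that single call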

def test_duplicate_merge_and_tuplegetitem():
    target = "test_duplicate_merge_and_tuplegetitem"

    @reg.register("nn.batch_norm", "target." + target)
    def batch_norm(attrs, args):  # pylint: disable=unused-variable
        return True

    @reg.register("nn.relu", "target." + target)
    def relu(attrs, args):  # pylint: disable=unused-variable
        return True

    def create_graph():
        data = relay.var('data', shape=(10, 10))
        bn_gamma = relay.var("bn_gamma")
        bn_beta = relay.var("bn_beta")
        bn_mmean = relay.var("bn_mean")
        bn_mvar = relay.var("bn_var")
        x = relay.nn.batch_norm(data, bn_gamma, bn_beta, bn_mmean, bn_mvar)
        out_1 = relay.nn.relu(x[0])
        bn_out_1 = x[1]
        out_2 = relay.tanh(bn_out_1)
        out_3 = relay.log(bn_out_1)
        out = relay.Tuple([out_1, out_2, out_3])
        func = relay.Function([data, bn_gamma, bn_beta, bn_mmean, bn_mvar], out)
        return func

    def expected():
        mod = tvm.IRModule()

        # function 0: batch_norm and relu are merged into one region; the
        # subgraph returns two distinct outputs, each exactly once
        f0_i0 = relay.var(target+"_1_i0", shape=(10, 10))
        f0_i1 = relay.var(target+"_1_i1")
        f0_i2 = relay.var(target+"_1_i2")
        f0_i3 = relay.var(target+"_1_i3")
        f0_i4 = relay.var(target+"_1_i4")
        f0_n0 = relay.nn.batch_norm(f0_i0, f0_i1, f0_i2, f0_i3, f0_i4)
        f0_n1 = f0_n0[1]
        f0_n2 = relay.nn.relu(f0_n0[0])
        f0_o0 = relay.Tuple([f0_n1, f0_n2])
        func0 = relay.Function([f0_i0, f0_i1, f0_i2, f0_i3, f0_i4], f0_o0)
        func0 = func0.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
        func0 = func0.with_attr("Inline", tvm.tir.IntImm("int32", 1))
        func0 = func0.with_attr("Compiler", target)
        func0 = func0.with_attr("global_symbol", target+"_1")
        gv0 = relay.GlobalVar(target+"_1")
        mod[gv0] = func0

        # body: one TupleGetItem per output index, shared by all consumers
        data = relay.var('data', shape=(10, 10))
        bn_gamma = relay.var("bn_gamma")
        bn_beta = relay.var("bn_beta")
        bn_mmean = relay.var("bn_mean")
        bn_mvar = relay.var("bn_var")
        function_out = gv0(data, bn_gamma, bn_beta, bn_mmean, bn_mvar)
        get_out0 = relay.TupleGetItem(function_out, 0)
        get_out1 = relay.TupleGetItem(function_out, 1)
        out_2 = relay.tanh(get_out0)
        out_3 = relay.log(get_out0)
        out = relay.Tuple([get_out1, out_2, out_3])
        func = relay.Function([data, bn_gamma, bn_beta, bn_mmean, bn_mvar], out)
        mod["main"] = func
        return mod

    mod = tvm.IRModule()
    mod["main"] = create_graph()

    seq = tvm.transform.Sequential([
        transform.AnnotateTarget(target),
        transform.MergeCompilerRegions(),
        transform.PartitionGraph(),
    ])

    ref_mod = expected()
    partitioned = seq(mod)
    assert tvm.ir.structural_equal(partitioned, ref_mod, map_free_vars=True)
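
If one wanted to assert the no-duplicates property directly rather than against a full reference module, a small helper could scan every partitioned function. This is an illustrative sketch, not TVM API; it relies on Relay nodes comparing by reference, so it only catches literally shared duplicate nodes:

def assert_no_duplicate_outputs(mod):
    # Every function whose body is a Tuple should return each node once.
    for gvar, func in mod.functions.items():
        if isinstance(func.body, relay.Tuple):
            fields = list(func.body.fields)
            assert len(fields) == len(set(fields)), \
                "%s returns duplicate outputs" % gvar.name_hint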

if __name__ == "__main__":
    test_multi_node_compiler()
@@ -1051,3 +1184,5 @@ if __name__ == "__main__":
    test_mixed_single_multiple_outputs()
    test_dnnl_fuse()
    test_multiple_use_of_an_output()
    test_duplicate_outputs()
    test_duplicate_merge_and_tuplegetitem()