Unverified Commit 09eb5082 by Trevor Morris Committed by GitHub

[BYOC] Prevent duplicate outputs in subgraph Tuple (#5320)

* Fix duplicate output in partitiongraph

* Add test case

* Fix test_annotated_regions with duplicate compiler_end outputs

* Revert "Fix duplicate output in partitiongraph"

This reverts commit e1f8ef3f4ca5b2aaa31ace6fa968bb50e5e4d1fa.

* Prevent duplicate outputs in Tuple in PartitionGraph

* Fix lint

* Add another test case for when regions are merged, and when TupleGetItem was duplicated

* Pull GetFunctionOutput out of branch, improve description of GetFunctionOutput

* Use std::move for GetFunctionOutput. Fix typo with testcase name

* Use tvm.transform.Sequential
parent e8138f7d
...@@ -23,6 +23,7 @@ import pytest ...@@ -23,6 +23,7 @@ import pytest
import tvm import tvm
import tvm.relay.testing import tvm.relay.testing
import tvm.relay.op as reg
from tvm import relay from tvm import relay
from tvm import runtime from tvm import runtime
from tvm.relay import transform from tvm.relay import transform
...@@ -1036,6 +1037,138 @@ def test_multiple_use_of_an_output(): ...@@ -1036,6 +1037,138 @@ def test_multiple_use_of_an_output():
test_same_output_region() test_same_output_region()
test_different_output_region() test_different_output_region()
def test_duplicate_outputs():
target = "test_duplicate_outputs"
@reg.register("abs", "target." + target)
def abs(attrs, args): # pylint: disable=unused-variable
return True
def create_graph():
data = relay.var('data', shape=(10, 10))
x = relay.abs(data)
out_1 = relay.nn.relu(x)
out_2 = relay.tanh(x)
out_3 = relay.log(x)
out = relay.Tuple([out_1, out_2, out_3])
func = relay.Function([data], out)
return func
def expected():
mod = tvm.IRModule()
# function 0
f0_i0 = relay.var(target+"_0_i0", shape=(10, 10))
f0_o0 = relay.abs(f0_i0)
func0 = relay.Function([f0_i0], f0_o0)
func0 = func0.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
func0 = func0.with_attr("Inline", tvm.tir.IntImm("int32", 1))
func0 = func0.with_attr("Compiler", target)
func0 = func0.with_attr("global_symbol", target+"_0")
gv0 = relay.GlobalVar(target+"_0")
mod[gv0] = func0
# body
data = relay.var('data', shape=(10, 10))
function_out = gv0(data)
out_1 = relay.nn.relu(function_out)
out_2 = relay.tanh(function_out)
out_3 = relay.log(function_out)
out = relay.Tuple([out_1, out_2, out_3])
func = relay.Function([data], out)
mod["main"] = func
return mod
mod = tvm.IRModule()
mod["main"] = create_graph()
seq = tvm.transform.Sequential([
transform.AnnotateTarget(target),
transform.MergeCompilerRegions(),
transform.PartitionGraph(),
])
ref_mod = expected()
partitioned = seq(mod)
assert tvm.ir.structural_equal(partitioned, ref_mod, map_free_vars=True)
def test_duplicate_merge_and_tuplegetitem():
target = "test_duplicate_merge_and_tuplegetitem"
@reg.register("nn.batch_norm", "target." + target)
def abs(attrs, args): # pylint: disable=unused-variable
return True
@reg.register("nn.relu", "target." + target)
def abs(attrs, args): # pylint: disable=unused-variable
return True
def create_graph():
data = relay.var('data', shape=(10, 10))
bn_gamma = relay.var("bn_gamma")
bn_beta = relay.var("bn_beta")
bn_mmean = relay.var("bn_mean")
bn_mvar = relay.var("bn_var")
x = relay.nn.batch_norm(data, bn_gamma, bn_beta, bn_mmean, bn_mvar)
out_1 = relay.nn.relu(x[0])
bn_out_1 = x[1]
out_2 = relay.tanh(bn_out_1)
out_3 = relay.log(bn_out_1)
out = relay.Tuple([out_1, out_2, out_3])
func = relay.Function([data, bn_gamma, bn_beta, bn_mmean, bn_mvar], out)
return func
def expected():
mod = tvm.IRModule()
# function 0
f0_i0 = relay.var(target+"_1_i0", shape=(10, 10))
f0_i1 = relay.var(target+"_1_i1")
f0_i2 = relay.var(target+"_1_i2")
f0_i3 = relay.var(target+"_1_i3")
f0_i4 = relay.var(target+"_1_i4")
f0_n0 = relay.nn.batch_norm(f0_i0, f0_i1, f0_i2, f0_i3, f0_i4)
f0_n1 = f0_n0[1]
f0_n2 = relay.nn.relu(f0_n0[0])
f0_o0 = relay.Tuple([f0_n1, f0_n2])
func0 = relay.Function([f0_i0, f0_i1, f0_i2, f0_i3, f0_i4], f0_o0)
func0 = func0.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
func0 = func0.with_attr("Inline", tvm.tir.IntImm("int32", 1))
func0 = func0.with_attr("Compiler", target)
func0 = func0.with_attr("global_symbol", target+"_1")
gv0 = relay.GlobalVar(target+"_1")
mod[gv0] = func0
# body
data = relay.var('data', shape=(10, 10))
bn_gamma = relay.var("bn_gamma")
bn_beta = relay.var("bn_beta")
bn_mmean = relay.var("bn_mean")
bn_mvar = relay.var("bn_var")
function_out = gv0(data, bn_gamma, bn_beta, bn_mmean, bn_mvar)
get_out0 = relay.TupleGetItem(function_out, 0)
get_out1 = relay.TupleGetItem(function_out, 1)
out_2 = relay.tanh(get_out0)
out_3 = relay.log(get_out0)
out = relay.Tuple([get_out1, out_2, out_3])
func = relay.Function([data, bn_gamma, bn_beta, bn_mmean, bn_mvar], out)
mod["main"] = func
return mod
mod = tvm.IRModule()
mod["main"] = create_graph()
seq = tvm.transform.Sequential([
transform.AnnotateTarget(target),
transform.MergeCompilerRegions(),
transform.PartitionGraph(),
])
ref_mod = expected()
partitioned = seq(mod)
assert tvm.ir.structural_equal(partitioned, ref_mod, map_free_vars=True)
if __name__ == "__main__": if __name__ == "__main__":
test_multi_node_compiler() test_multi_node_compiler()
...@@ -1051,3 +1184,5 @@ if __name__ == "__main__": ...@@ -1051,3 +1184,5 @@ if __name__ == "__main__":
test_mixed_single_multiple_outputs() test_mixed_single_multiple_outputs()
test_dnnl_fuse() test_dnnl_fuse()
test_multiple_use_of_an_output() test_multiple_use_of_an_output()
test_duplicate_outputs()
test_duplicate_merge_and_tuplegetitem()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment