Commit 9572d98e by Yida Wang Committed by Zhi

[Fix] Fix the logic of the number of nodes checking in op fusion (#4074)

* move the number of nodes constraint in op fusion up to the dom tree level

* add test case of limiting the max number of ops to be fused

* uncomment other test cases
parent 283afac0
......@@ -623,9 +623,7 @@ class GraphPartitioner {
* \param parent The parent group.
*/
void MergeFromTo(Group* child, Group* parent) {
// refuse the fusion if too many ops are going to be fused together
if (child->num_nodes + parent->num_nodes > kMaxFusedOps)
return;
// update the number of nodes of the parent group
parent->num_nodes += child->num_nodes;
child = child->FindRoot();
parent = parent->FindRoot();
......@@ -701,6 +699,10 @@ class GraphPartitioner {
CHECK(!graph_node->extern_ref);
size_t dom_parent_gindex = dom_node->parent->gnode->index;
// refuse the fusion if too many ops are going to be fused together
if (groups_[dom_parent_gindex]->num_nodes + group_node->num_nodes > kMaxFusedOps)
continue;
if (phase == 2) {
// Fuse injective ops into intermediate tuples, if any
if (group_node->pattern > kInjective) continue;
......
......@@ -552,6 +552,39 @@ def test_split():
mod["main"] = relay.Function([x], a + relay.RefRead(relay.RefCreate(b)) + c)
mod = transform.FuseOps()(mod)
def test_fuse_max():
"""Test the constraint of number of nodes in op fusion."""
max_fused_ops = 256
# n is the number of nodes to be fused, should be less than 2*max_fused_ops
n = 300
def before():
x = relay.var("x", shape=(10, 20))
y = x
for i in range(n):
y = relay.exp(y)
return relay.Function([x], y)
def expected():
x = relay.var("p", shape=(10, 20))
y = x
for i in range(max_fused_ops):
y = relay.exp(y)
f1 = relay.Function([x], y)
x = relay.var("x", shape=(10, 20))
z = relay.Call(f1, [x])
xx = relay.var("pp", shape=(10, 20))
yy = xx
for i in range(n-max_fused_ops):
yy = relay.exp(yy)
f2 = relay.Function([xx], yy)
zz = relay.Call(f2, [z])
return relay.Function([x], zz)
z = before()
zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=2))
zz = run_opt_pass(z, transform.FuseOps())
after = run_opt_pass(expected(), transform.InferType())
assert relay.analysis.alpha_equal(zz, after)
if __name__ == "__main__":
test_fuse_simple()
......@@ -568,3 +601,4 @@ if __name__ == "__main__":
test_fuse_parallel_injective()
test_immutable()
test_split()
test_fuse_max()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment