Unverified Commit 629a293a by Tianqi Chen Committed by GitHub

[RELAY][PASS] FuseOps, fix input fusion rule for conv2d (#2110)

parent b2521604
...@@ -464,14 +464,15 @@ class GraphPartitioner { ...@@ -464,14 +464,15 @@ class GraphPartitioner {
return true; return true;
} }
/*! /*!
* \brief Check all the node between src and sink satisfies fcond. * \brief Check all the node and edge pattern
* between src and sink satisfies fcond.
* *
* src and sink are not checked. * src is not checked.
* *
* \param src The source node. * \param src The source node.
* \param sink The termination node. * \param sink The termination node.
* \param fcond The condition to be checked. * \param fcond The condition to be checked.
* \tparam F the condition function. * \tparam F the condition function, with signature
* \note sink must be a post-dominator of src. * \note sink must be a post-dominator of src.
*/ */
template<typename F> template<typename F>
...@@ -596,18 +597,24 @@ class GraphPartitioner { ...@@ -596,18 +597,24 @@ class GraphPartitioner {
} }
} }
} else if (group_node->pattern <= kBroadcast) { } else if (group_node->pattern <= kBroadcast) {
// The fuse can be executed if all the intermediate ops are still broadcast. // Pre-condition: can only be fused to parent which is injective or reduction.
auto fcond = [](OpPatternKind kind, bool is_sink) { if (dom_node->parent != nullptr &&
if (!is_sink) { (dom_node->pattern <= kInjective ||
return kind <= kBroadcast; dom_node->pattern == kCommReduce)) {
} else { // Check if all the intermediate ops are still broadcast.
return (kind <= kBroadcast || // The final terminal node can already be fused to a OutEWiseFusable group.
kind == kCommReduce || auto fcond = [](OpPatternKind kind, bool is_sink) {
kind == kOutEWiseFusable); if (!is_sink) {
return kind <= kBroadcast;
} else {
return (kind <= kBroadcast ||
kind == kCommReduce ||
kind == kOutEWiseFusable);
}
};
if (CheckPath(graph_node, dom_node->parent->gnode, fcond)) {
CommitFuse(graph_node, dom_node->parent->gnode);
} }
};
if (CheckPath(graph_node, dom_node->parent->gnode, fcond)) {
CommitFuse(graph_node, dom_node->parent->gnode);
} }
} else if (group_node->pattern == kInjective) { } else if (group_node->pattern == kInjective) {
// defer injective fusion to second phase. // defer injective fusion to second phase.
......
...@@ -29,10 +29,12 @@ def test_fuse_simple(): ...@@ -29,10 +29,12 @@ def test_fuse_simple():
def test_conv2d_fuse(): def test_conv2d_fuse():
"""Test fusion case of conv2d""" """Test fusion case of conv2d"""
def before(dshape): def before(dshape):
x = relay.var("x", shape=dshape) x = relay.var("x", shape=dshape)
x = relay.add(x, relay.const(1, "float32"))
y = relay.nn.conv2d(x, relay.var("w1"), y = relay.nn.conv2d(x, relay.var("w1"),
kernel_size=(3, 3), kernel_size=(3, 3),
padding=(1, 1), padding=(1, 1),
...@@ -54,6 +56,10 @@ def test_conv2d_fuse(): ...@@ -54,6 +56,10 @@ def test_conv2d_fuse():
return relay.Function(relay.ir_pass.free_vars(z), z) return relay.Function(relay.ir_pass.free_vars(z), z)
def expected(dshape): def expected(dshape):
# segment 0
x = relay.var("p0", shape=dshape)
y = relay.add(x, relay.const(1, "float32"))
f0 = relay.Function([x], y)
# segment 1 # segment 1
x = relay.var("p0", shape=dshape) x = relay.var("p0", shape=dshape)
w = relay.var("p1") w = relay.var("p1")
...@@ -84,7 +90,8 @@ def test_conv2d_fuse(): ...@@ -84,7 +90,8 @@ def test_conv2d_fuse():
f3 = relay.Function([x, w, offset], z3) f3 = relay.Function([x, w, offset], z3)
# compose # compose
x = relay.var("x", shape=dshape) x = relay.var("x", shape=dshape)
y = relay.Call(f1, [x, relay.var("w1")]) y = relay.Call(f0, [x])
y = relay.Call(f1, [y, relay.var("w1")])
z2 = relay.Call(f2, [y, relay.var("w3")]) z2 = relay.Call(f2, [y, relay.var("w3")])
z3 = relay.Call(f3, [y, relay.var("w2"), z2]) z3 = relay.Call(f3, [y, relay.var("w2"), z2])
z = z3 z = z3
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment