Unverified Commit 629a293a by Tianqi Chen Committed by GitHub

[RELAY][PASS] FuseOps, fix input fusion rule for conv2d (#2110)

parent b2521604
......@@ -464,14 +464,15 @@ class GraphPartitioner {
return true;
}
/*!
* \brief Check all the node between src and sink satisfies fcond.
* \brief Check all the node and edge pattern
* between src and sink satisfies fcond.
*
* src and sink are not checked.
* src is not checked.
*
* \param src The source node.
* \param sink The termination node.
* \param fcond The condition to be checked.
* \tparam F the condition function.
* \tparam F the condition function, with signature
* \note sink must be a post-dominator of src.
*/
template<typename F>
......@@ -596,18 +597,24 @@ class GraphPartitioner {
}
}
} else if (group_node->pattern <= kBroadcast) {
// The fuse can be executed if all the intermediate ops are still broadcast.
auto fcond = [](OpPatternKind kind, bool is_sink) {
if (!is_sink) {
return kind <= kBroadcast;
} else {
return (kind <= kBroadcast ||
kind == kCommReduce ||
kind == kOutEWiseFusable);
// Pre-condition: can only be fused to parent which is injective or reduction.
if (dom_node->parent != nullptr &&
(dom_node->pattern <= kInjective ||
dom_node->pattern == kCommReduce)) {
// Check if all the intermediate ops are still broadcast.
// The final terminal node can already be fused to a OutEWiseFusable group.
auto fcond = [](OpPatternKind kind, bool is_sink) {
if (!is_sink) {
return kind <= kBroadcast;
} else {
return (kind <= kBroadcast ||
kind == kCommReduce ||
kind == kOutEWiseFusable);
}
};
if (CheckPath(graph_node, dom_node->parent->gnode, fcond)) {
CommitFuse(graph_node, dom_node->parent->gnode);
}
};
if (CheckPath(graph_node, dom_node->parent->gnode, fcond)) {
CommitFuse(graph_node, dom_node->parent->gnode);
}
} else if (group_node->pattern == kInjective) {
// defer injective fusion to second phase.
......
......@@ -29,10 +29,12 @@ def test_fuse_simple():
def test_conv2d_fuse():
"""Test fusion case of conv2d"""
def before(dshape):
x = relay.var("x", shape=dshape)
x = relay.add(x, relay.const(1, "float32"))
y = relay.nn.conv2d(x, relay.var("w1"),
kernel_size=(3, 3),
padding=(1, 1),
......@@ -54,6 +56,10 @@ def test_conv2d_fuse():
return relay.Function(relay.ir_pass.free_vars(z), z)
def expected(dshape):
# segment 0
x = relay.var("p0", shape=dshape)
y = relay.add(x, relay.const(1, "float32"))
f0 = relay.Function([x], y)
# segment 1
x = relay.var("p0", shape=dshape)
w = relay.var("p1")
......@@ -84,7 +90,8 @@ def test_conv2d_fuse():
f3 = relay.Function([x, w, offset], z3)
# compose
x = relay.var("x", shape=dshape)
y = relay.Call(f1, [x, relay.var("w1")])
y = relay.Call(f0, [x])
y = relay.Call(f1, [y, relay.var("w1")])
z2 = relay.Call(f2, [y, relay.var("w3")])
z3 = relay.Call(f3, [y, relay.var("w2"), z2])
z = z3
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment