Commit f1fcbaf9 by masahi Committed by Tianqi Chen

[Relay, OpFusion] Better tuple fusion implementation (#3092)

parent d0dca01a
......@@ -49,6 +49,9 @@ enum OpPatternKind {
// Complex operation: can still fuse elemwise operations into its output,
// but cannot chain another complex op.
kOutEWiseFusable = 4,
// The pattern for tuple nodes. Can fuse into subsequent injective ops,
// but treated specially
kTuple = 7,
// Opaque operation, cannot fuse anything.
kOpaque = 8
......@@ -112,6 +112,8 @@ class OpPattern(object):
# Complex op, can still fuse ewise into it
# Represents tuple node
# Not fusable opaque op
......@@ -267,7 +267,7 @@ class IndexedForwardGraph::Creator : private ExprVisitor {
void VisitExpr_(const TupleNode* op) final {
Node* tuple_node =;
tuple_node->pattern = kInjective;
tuple_node->pattern = kTuple;
for (const Expr& field : op->fields) {
if (field->checked_type().as<TensorTypeNode>()) {
this->Update(field, tuple_node, kInjective);
......@@ -661,12 +661,36 @@ class GraphPartitioner {
// no action needed if the current node has no dominator
if (dom_node->parent == nullptr) continue;
// Skip if current node is already fused to the parent.
size_t dom_parent_gindex = dom_node->parent->gnode->index;
if (phase == 2) {
// Fuse injective ops into intermediate tuples, if any
if (group_node->pattern > kInjective) continue;
Group* dom_parent_group = groups_[dom_parent_gindex];
Group* dom_root_group = dom_parent_group->FindRoot();
// If dom node group has a tuple as its root, we do not fuse tuple fields into it
if (dom_root_group->pattern == kTuple) continue;
if (dom_parent_group->pattern == kTuple && dom_root_group->pattern <= kInjective) {
// Now we know the tuple has been fused into subsequent injective ops
auto fcond = [](OpPatternKind kind, bool is_sink) {
return kind <= kInjective;
// dom_root_group can also be tuple, as in inception layers
// CheckPath is needed to avoid fusing two intermediate tuples
if (CheckPath(graph_node, dom_node->parent->gnode, fcond)) {
CommitFuse(graph_node, dom_node->parent->gnode);
// Skip if current node is already fused to the parent.
if (groups_[dom_parent_gindex] != nullptr &&
group_node->FindRoot() == groups_[dom_parent_gindex]->FindRoot()) {
// Do not fuse into tuple for now
if (groups_[dom_parent_gindex]->pattern == kTuple) continue;
// Try to fuse current node to its post-dominator.
if (group_node->pattern == kOutEWiseFusable) {
if (phase != 0) continue;
......@@ -702,7 +726,7 @@ class GraphPartitioner {
CommitFuse(graph_node, dom_node->parent->gnode);
} else if (group_node->pattern == kInjective) {
} else if (group_node->pattern == kInjective || group_node->pattern == kTuple) {
// defer injective fusion to the second phase,
// so that conv2d always finishes fusing.
if (phase != 1) continue;
......@@ -728,7 +752,7 @@ GraphPartitioner::Partition(const IndexedForwardGraph& graph) {
// get post dominator tree
auto post_dom_tree = DominatorTree::PostDom(arena_, graph);
// run fusion algorithm.
for (int phase = 0; phase < 2; ++phase) {
for (int phase = 0; phase < 3; ++phase) {
this->RunFuse(graph, post_dom_tree, phase);
return std::move(groups_);
......@@ -821,29 +845,11 @@ class FuseMutator : private ExprMutator {
Expr VisitExpr_(const TupleNode* tuple) {
auto* ret_group =>FindRoot();
Array<Expr> new_fields = GetNewArguments(tuple->fields, ret_group);
if (ret_group == {
// This tuple is the root of its group. Check if all fields come from other groups.
bool isolated = new_fields.size() == ginfo_[ret_group].params.size();
for (size_t i = 0; i < new_fields.size() && isolated; ++i) {
isolated &= (new_fields[i].same_as(ginfo_[ret_group].params[i]));
if (isolated) {
// Do not put an isolated tuple into a function
return ExprMutator::VisitExpr_(tuple);
// This tuple has been fused with other ops before it
for (size_t i = 0; i < new_fields.size(); i++) {
// Copy function arguments to tuple field of the output because currently the graph memory
// planner doesn't support in-place operations
if (new_fields[i].as<VarNode>()) {
auto copy = Copy(new_fields[i]);
new_fields.Set(i, copy);
return MakeNewFunction(ret_group, tuple->checked_type(), TupleNode::make(new_fields));
// This tuple is an intermediate node in the group
Array<Expr> new_fields = GetNewArguments(tuple->fields, ret_group);
return TupleNode::make(new_fields);
......@@ -69,8 +69,16 @@ def test_compile_injective_with_tuple():, 'llvm')
def test_compile_tuple_dup():
x = relay.var("data", shape=(16, 16))
log = relay.log(x)
output = relay.Tuple([log, log])
f = relay.Function([x], output), 'llvm')
if __name__ == "__main__":
......@@ -176,16 +176,14 @@ def test_tuple_root():
f0 = relay.Function([x], pooled)
p0 = relay.var("p0", shape=(dshape[0], dshape[1], dshape[2]//2, dshape[3]//2))
p1 = relay.var("p1", shape=(dshape[0], dshape[1], dshape[2], dshape[3]))
p1_copy = relay.copy(p1)
upsampled = relay.nn.upsampling(p0, scale=2, layout="NCHW")
out = relay.Tuple((upsampled, p1_copy))
f1 = relay.Function([p0, p1], out)
f1 = relay.Function([p0], upsampled)
x = relay.var("x", shape=dshape)
y = relay.Call(f0, [x])
z = relay.Call(f1, [y, x])
return relay.Function([x], z)
z = relay.Call(f1, [y])
tup = relay.Tuple((z, x))
return relay.Function([x], tup)
dshape = (1, 16, 64, 64)
z = before(dshape)
......@@ -199,41 +197,6 @@ def test_tuple_root():
assert relay.ir_pass.alpha_equal(zz, after)
def test_tuple_strided_slice():
Test fusion case where the number of fields of tuple and
the number of parameters to the function containing the tuple are different
def before(dshape):
x = relay.var("x", shape=dshape)
slice1 = relay.strided_slice(x, begin=[0, 0], end=[dshape[1]//2, dshape[1]], strides=[1,1])
slice2 = relay.strided_slice(x, begin=[dshape[1]//2, 0], end=[dshape[0], dshape[1]], strides=[1,1])
out = relay.Tuple((slice1, slice2))
return relay.Function([x], out)
def expected(dshape):
x = relay.var("x", shape=dshape)
slice1 = relay.strided_slice(x, begin=[0, 0], end=[dshape[1]//2, dshape[1]], strides=[1,1])
slice2 = relay.strided_slice(x, begin=[dshape[1]//2, 0], end=[dshape[0], dshape[1]], strides=[1,1])
out = relay.Tuple((slice1, slice2))
f0 = relay.Function([x], out)
x = relay.var("x", shape=dshape)
y = relay.Call(f0, [x])
return relay.Function([x], y)
dshape = (64, 64)
z = before(dshape)
z = relay.ir_pass.infer_type(z)
zz = relay.ir_pass.fuse_ops(z, opt_level=0)
assert not relay.ir_pass.free_vars(zz)
zz = relay.ir_pass.fuse_ops(z, opt_level=2)
zz = relay.ir_pass.infer_type(zz)
assert not relay.ir_pass.free_vars(zz)
after = relay.ir_pass.infer_type(expected(dshape))
assert relay.ir_pass.alpha_equal(zz, after)
def test_stop_fusion():
def before(dshape):
......@@ -377,13 +340,178 @@ def test_tuple_get_root():
assert relay.ir_pass.alpha_equal(zz, after)
def test_tuple_intermediate():
def before(x):
inj = relay.squeeze(x)
y1 = relay.add(inj, relay.const(1, "float32"))
tmp = relay.squeeze(inj)
tmp = relay.add(tmp, relay.const(1, "float32"))
y2 = relay.add(tmp, relay.const(1, "float32"))
y3 = relay.add(inj, relay.const(1, "float32"))
concat = relay.concatenate((y1, y2, y3), axis=1)
out_inj = relay.squeeze(concat)
out = relay.add(out_inj, relay.const(1, "float32"))
return relay.Function(relay.ir_pass.free_vars(out), out)
def expected(p0):
f0 = before(p0)
x = relay.var("x", shape=dshape)
y = relay.Call(f0, [x])
return relay.Function([x], y)
dshape = (1, 16, 64, 64)
x = relay.var("x", shape=dshape)
z = before(x)
z = relay.ir_pass.infer_type(z)
zz = relay.ir_pass.fuse_ops(z, opt_level=0)
assert not relay.ir_pass.free_vars(zz)
zz = relay.ir_pass.fuse_ops(z, opt_level=2), 'llvm')
zz = relay.ir_pass.infer_type(zz)
assert not relay.ir_pass.free_vars(zz)
after = relay.ir_pass.infer_type(expected(x))
assert relay.ir_pass.alpha_equal(zz, after)
def test_tuple_consecutive():
def gen_intermediate_tuple(x):
y1 = relay.add(x, relay.const(1, "float32"))
y2 = relay.add(x, relay.const(1, "float32"))
y3 = relay.add(x, relay.const(1, "float32"))
concat = relay.concatenate((y1, y2, y3), axis=1)
out = relay.add(concat, relay.const(1, "float32"))
return out
def gen_consecutive_tuple(x):
y1 = gen_intermediate_tuple(x)
y2 = gen_intermediate_tuple(x)
y3 = gen_intermediate_tuple(x)
concat = relay.concatenate((y1, y2, y3), axis=1)
return concat
def before(x):
concat = gen_consecutive_tuple(x)
pooled = relay.nn.max_pool2d(concat, pool_size=(2, 2), strides=(2, 2), padding=(0, 0))
out = relay.add(pooled, relay.const(1, "float32"))
out2 = relay.add(out, relay.const(1, "float32"))
out_tup = relay.Tuple((out, out2))
return relay.Function(relay.ir_pass.free_vars(out_tup), out_tup)
def expected(dshape):
p0 = relay.var("p0", shape=dshape)
concat = gen_consecutive_tuple(p0)
f0 = relay.Function([p0], concat)
p01 = relay.var("p01", shape=(1, dshape[1]*9, dshape[2], dshape[3]))
pooled = relay.nn.max_pool2d(p01, pool_size=(2, 2), strides=(2, 2), padding=(0, 0))
out = relay.add(pooled, relay.const(1, "float32"))
f1 = relay.Function([p01], out)
p02 = relay.var("p02", shape=(1, dshape[1]*9, dshape[2]//2, dshape[3]//2))
out = relay.add(p02, relay.const(1, "float32"))
f2 = relay.Function([p02], out)
x = relay.var("x", shape=dshape)
y = relay.Call(f0, [x])
z = relay.Call(f1, [y])
z2 = relay.Call(f2, [z])
return relay.Function([x], relay.Tuple((z, z2)))
dshape = (1, 16, 64, 64)
x = relay.var("x", shape=dshape)
z = before(x)
z = relay.ir_pass.infer_type(z)
zz = relay.ir_pass.fuse_ops(z, opt_level=0)
assert not relay.ir_pass.free_vars(zz)
zz = relay.ir_pass.fuse_ops(z, opt_level=2), 'llvm')
zz = relay.ir_pass.infer_type(zz)
assert not relay.ir_pass.free_vars(zz)
after = relay.ir_pass.infer_type(expected(dshape))
assert relay.ir_pass.alpha_equal(zz, after)
def test_inception_like():
def conv(data):
y = relay.nn.conv2d(data, relay.var("w"),
kernel_size=(3, 3),
padding=(1, 1),
return relay.nn.relu(data=y)
def inception_like(data):
c0 = conv(data)
c1 = conv(data)
return relay.concatenate((c0, c1), axis=1)
def before(dshape):
x = relay.var("x", shape=dshape)
in1 = inception_like(x)
in2 = inception_like(in1)
return relay.Function(relay.ir_pass.free_vars(in2), in2)
def expected(dshape):
p0 = relay.var("p0", shape=dshape)
c = conv(p0)
f0 = relay.Function(relay.ir_pass.free_vars(c), c)
p01 = relay.var("p01", shape=dshape)
c = conv(p01)
f1 = relay.Function(relay.ir_pass.free_vars(c), c)
p02 = relay.var("p02", shape=dshape)
p12 = relay.var("p12", shape=dshape)
concat1 = relay.concatenate((p02, p12), axis=1)
f_concat1 = relay.Function([p02, p12], concat1)
dshape2 = (dshape[0], dshape[1]*2, dshape[2], dshape[3])
p03 = relay.var("p03", shape=dshape2)
c = conv(p03)
f2 = relay.Function(relay.ir_pass.free_vars(c), c)
p04 = relay.var("p04", shape=dshape2)
c = conv(p04)
f3 = relay.Function(relay.ir_pass.free_vars(c), c)
p05 = relay.var("p05", shape=dshape)
p15 = relay.var("p15", shape=dshape)
concat2 = relay.concatenate((p05, p15), axis=1)
f_concat2 = relay.Function([p05, p15], concat2)
x = relay.var("x", shape=dshape)
c1 = relay.Call(f0, [x, relay.var("w1")])
c2 = relay.Call(f1, [x, relay.var("w2")])
concat = relay.Call(f_concat1, [c1, c2])
c3 = relay.Call(f2, [concat, relay.var("w3")])
c4 = relay.Call(f3, [concat, relay.var("w4")])
out = relay.Call(f_concat2, [c3, c4])
return relay.Function(relay.ir_pass.free_vars(out), out)
dshape = (1, 16, 64, 64)
z = before(dshape)
z = relay.ir_pass.infer_type(z)
zz = relay.ir_pass.fuse_ops(z, opt_level=0)
assert not relay.ir_pass.free_vars(zz)
zz = relay.ir_pass.fuse_ops(z, opt_level=2), 'llvm')
zz = relay.ir_pass.infer_type(zz)
assert not relay.ir_pass.free_vars(zz)
after = relay.ir_pass.infer_type(expected(dshape))
assert relay.ir_pass.alpha_equal(zz, after)
if __name__ == "__main__":
