Commit f1fcbaf9 by masahi, committed by Tianqi Chen

[Relay, OpFusion] Better tuple fusion implementation (#3092)

parent d0dca01a
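
For context, a minimal sketch of the graph shape this change targets, adapted from the new test_tuple_intermediate added below (shapes and variable names are illustrative only): a tuple produced by concatenate sits between injective ops, and with this patch the whole chain can be fused into a single function instead of stopping at the tuple.

    from tvm import relay

    # Injective ops -> concatenate (an intermediate tuple) -> more injective ops.
    # After this patch fuse_ops can place the whole chain in one fused function.
    dshape = (1, 16, 64, 64)
    x = relay.var("x", shape=dshape)
    inj = relay.squeeze(x)
    y1 = relay.add(inj, relay.const(1, "float32"))
    y2 = relay.add(relay.squeeze(inj), relay.const(1, "float32"))
    concat = relay.concatenate((y1, y2), axis=1)
    out = relay.add(relay.squeeze(concat), relay.const(1, "float32"))
    f = relay.Function(relay.ir_pass.free_vars(out), out)
    f = relay.ir_pass.infer_type(f)
    fused = relay.ir_pass.fuse_ops(f, opt_level=2)  # tuple fused with its injective neighbors
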
@@ -49,6 +49,9 @@ enum OpPatternKind {
   // Complex operation, can still fuse elemwise operations into its output.
   // but cannot chain another complex op
   kOutEWiseFusable = 4,
+  // The pattern for tuple nodes. Can fuse into subsequent injective ops,
+  // but treated specially
+  kTuple = 7,
   // Opaque operation, cannot fuse anything.
   kOpaque = 8
 };

@@ -112,6 +112,8 @@ class OpPattern(object):
     COMM_REDUCE = 3
     # Complex op, can still fuse ewise into it
     OUT_ELEMWISE_FUSABLE = 4
+    # Represents tuple node
+    TUPLE = 7
     # Not fusable opaque op
     OPAQUE = 8

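The numeric value matters: the partitioner compares patterns with <=, for example the new phase-2 guard `kind <= kInjective` below, so TUPLE = 7 deliberately sits above OUT_ELEMWISE_FUSABLE and below OPAQUE. A small illustrative check, using only the constant values declared above:

    # Values as declared in OpPattern / OpPatternKind above.  Fusion rules
    # compare patterns numerically, so a tuple never passes an
    # injective-or-lower test while still ranking below OPAQUE.
    OUT_ELEMWISE_FUSABLE, TUPLE, OPAQUE = 4, 7, 8
    assert OUT_ELEMWISE_FUSABLE < TUPLE < OPAQUE
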
@@ -267,7 +267,7 @@ class IndexedForwardGraph::Creator : private ExprVisitor {
   void VisitExpr_(const TupleNode* op) final {
     CHECK(graph_.node_map.count(op));
     Node* tuple_node = graph_.node_map.at(op);
-    tuple_node->pattern = kInjective;
+    tuple_node->pattern = kTuple;
     for (const Expr& field : op->fields) {
       if (field->checked_type().as<TensorTypeNode>()) {
         this->Update(field, tuple_node, kInjective);
@@ -661,12 +661,36 @@ class GraphPartitioner {
       // no actions needed if the current node have no dominator
       if (dom_node->parent == nullptr) continue;
       CHECK(!graph_node->extern_ref);
-      // Skip if current node is already fused to the parent.
       size_t dom_parent_gindex = dom_node->parent->gnode->index;
+      if (phase == 2) {
+        // Fuse injective ops into intermediate tuples, if any
+        if (group_node->pattern > kInjective) continue;
+        Group* dom_parent_group = groups_[dom_parent_gindex];
+        Group* dom_root_group = dom_parent_group->FindRoot();
+        // If dom node group has a tuple as its root, we do not fuse tuple fields into it
+        if (dom_root_group->pattern == kTuple) continue;
+        if (dom_parent_group->pattern == kTuple && dom_root_group->pattern <= kInjective) {
+          // Now we know the tuple has been fused into subsequent injective ops
+          auto fcond = [](OpPatternKind kind, bool is_sink) {
+            return kind <= kInjective;
+          };
+          // dom_root_group can also be tuple, as in inception layers
+          // CheckPath is needed to avoid fusing two intermediate tuples
+          if (CheckPath(graph_node, dom_node->parent->gnode, fcond)) {
+            CommitFuse(graph_node, dom_node->parent->gnode);
+          }
+        }
+        continue;
+      }
+      // Skip if current node is already fused to the parent.
       if (groups_[dom_parent_gindex] != nullptr &&
           group_node->FindRoot() == groups_[dom_parent_gindex]->FindRoot()) {
         continue;
       }
+      // Do not fuse into tuple for now
+      if (groups_[dom_parent_gindex]->pattern == kTuple) continue;
       // Try to fuse current node to its post-dominator.
       if (group_node->pattern == kOutEWiseFusable) {
         if (phase != 0) continue;

@@ -702,7 +726,7 @@ class GraphPartitioner {
             CommitFuse(graph_node, dom_node->parent->gnode);
           }
         }
-      } else if (group_node->pattern == kInjective) {
+      } else if (group_node->pattern == kInjective || group_node->pattern == kTuple) {
        // defer injective fusion to second phase.
        // so conv2d always finishes fusing.
        if (phase != 1) continue;

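The comments above describe the phase ordering: out-elemwise-fusable ops such as conv2d finish fusing their elementwise consumers in phase 0, injective ops and tuples are grouped in phase 1, and the new intermediate-tuple handling runs in phase 2. A minimal Python sketch of a graph exercising that ordering, reusing the conv-plus-relu pattern from test_inception_like below (shapes and variable names are illustrative assumptions):

    from tvm import relay

    # conv2d (kOutEWiseFusable) absorbs its elementwise relu in phase 0;
    # the concatenate tuple and the trailing injective squeeze are only
    # grouped in the later phases.
    dshape = (1, 16, 64, 64)
    x = relay.var("x", shape=dshape)
    c0 = relay.nn.relu(relay.nn.conv2d(x, relay.var("w0"), kernel_size=(3, 3),
                                       padding=(1, 1), channels=16))
    c1 = relay.nn.relu(relay.nn.conv2d(x, relay.var("w1"), kernel_size=(3, 3),
                                       padding=(1, 1), channels=16))
    concat = relay.concatenate((c0, c1), axis=1)   # tuple node, kTuple pattern
    out = relay.squeeze(concat)                    # injective consumer of the tuple
    f = relay.Function(relay.ir_pass.free_vars(out), out)
    f = relay.ir_pass.infer_type(f)
    fused = relay.ir_pass.fuse_ops(f, opt_level=2)
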
@@ -728,7 +752,7 @@ GraphPartitioner::Partition(const IndexedForwardGraph& graph) {
   // get post dominator tree
   auto post_dom_tree = DominatorTree::PostDom(arena_, graph);
   // run fusion algorithm.
-  for (int phase = 0; phase < 2; ++phase) {
+  for (int phase = 0; phase < 3; ++phase) {
     this->RunFuse(graph, post_dom_tree, phase);
   }
   return std::move(groups_);

@@ -821,29 +845,11 @@ class FuseMutator : private ExprMutator {
   Expr VisitExpr_(const TupleNode* tuple) {
     auto* ret_group = gmap_.at(tuple)->FindRoot();
-    Array<Expr> new_fields = GetNewArguments(tuple->fields, ret_group);
     if (ret_group == gmap_.at(tuple)) {
-      // This tuple is the root of its group. Check if all fields come from other groups.
-      bool isolated = new_fields.size() == ginfo_[ret_group].params.size();
-      for (size_t i = 0; i < new_fields.size() && isolated; ++i) {
-        isolated &= (new_fields[i].same_as(ginfo_[ret_group].params[i]));
-      }
-      if (isolated) {
-        // Do not put a isolated tuple into a function
       return ExprMutator::VisitExpr_(tuple);
-      }
-      // This tuple has been fused with other ops before it
-      for (size_t i = 0; i < new_fields.size(); i++) {
-        // Copy function arguments to tuple field of the output because currently graph memory
-        // planer doesn't support inplace operations
-        if (new_fields[i].as<VarNode>()) {
-          auto copy = Copy(new_fields[i]);
-          new_fields.Set(i, copy);
-        }
-      }
-      return MakeNewFunction(ret_group, tuple->checked_type(), TupleNode::make(new_fields));
     }
     // This tuple is an intermediate node in the group
+    Array<Expr> new_fields = GetNewArguments(tuple->fields, ret_group);
     return TupleNode::make(new_fields);
   }

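With the partitioner changes above, a tuple that remains the root of its own group is simply left in place, so the old isolated-tuple check and the argument-copy workaround are no longer needed. A minimal sketch of that case, adapted from the updated test_tuple_root below (shapes are illustrative): the result tuple stays outside the fused functions.

    from tvm import relay

    # A Tuple that is the function result is a group root after partitioning;
    # the FuseMutator now returns it as-is instead of wrapping it in a
    # fused function with copied arguments.
    dshape = (1, 16, 64, 64)
    x = relay.var("x", shape=dshape)
    pooled = relay.nn.max_pool2d(x, pool_size=(2, 2), strides=(2, 2), padding=(0, 0))
    upsampled = relay.nn.upsampling(pooled, scale=2, layout="NCHW")
    out = relay.Tuple((upsampled, x))   # root tuple: not fused into any group
    f = relay.Function([x], out)
    f = relay.ir_pass.infer_type(f)
    fused = relay.ir_pass.fuse_ops(f, opt_level=2)
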
@@ -69,8 +69,16 @@ def test_compile_injective_with_tuple():
     relay.build(func, 'llvm')

+def test_compile_tuple_dup():
+    x = relay.var("data", shape=(16, 16))
+    log = relay.log(x)
+    output = relay.Tuple([log, log])
+    f = relay.Function([x], output)
+    relay.build(f, 'llvm')
+
 if __name__ == "__main__":
     test_compile_engine()
     test_compile_placeholder_bypass()
     test_compile_injective_with_tuple()
+    test_compile_tuple_dup()

@@ -176,16 +176,14 @@ def test_tuple_root():
         f0 = relay.Function([x], pooled)

         p0 = relay.var("p0", shape=(dshape[0], dshape[1], dshape[2]//2, dshape[3]//2))
-        p1 = relay.var("p1", shape=(dshape[0], dshape[1], dshape[2], dshape[3]))
-        p1_copy = relay.copy(p1)
         upsampled = relay.nn.upsampling(p0, scale=2, layout="NCHW")
-        out = relay.Tuple((upsampled, p1_copy))
-        f1 = relay.Function([p0, p1], out)
+        f1 = relay.Function([p0], upsampled)

         x = relay.var("x", shape=dshape)
         y = relay.Call(f0, [x])
-        z = relay.Call(f1, [y, x])
-        return relay.Function([x], z)
+        z = relay.Call(f1, [y])
+        tup = relay.Tuple((z, x))
+        return relay.Function([x], tup)

     dshape = (1, 16, 64, 64)
     z = before(dshape)

@@ -199,41 +197,6 @@ def test_tuple_root():
     assert relay.ir_pass.alpha_equal(zz, after)

-def test_tuple_strided_slice():
-    """
-    Test fusion case where the number of fields of tuple and
-    the number of parameters to the function containing the tuple are different
-    """
-    def before(dshape):
-        x = relay.var("x", shape=dshape)
-        slice1 = relay.strided_slice(x, begin=[0, 0], end=[dshape[1]//2, dshape[1]], strides=[1,1])
-        slice2 = relay.strided_slice(x, begin=[dshape[1]//2, 0], end=[dshape[0], dshape[1]], strides=[1,1])
-        out = relay.Tuple((slice1, slice2))
-        return relay.Function([x], out)
-
-    def expected(dshape):
-        x = relay.var("x", shape=dshape)
-        slice1 = relay.strided_slice(x, begin=[0, 0], end=[dshape[1]//2, dshape[1]], strides=[1,1])
-        slice2 = relay.strided_slice(x, begin=[dshape[1]//2, 0], end=[dshape[0], dshape[1]], strides=[1,1])
-        out = relay.Tuple((slice1, slice2))
-        f0 = relay.Function([x], out)
-
-        x = relay.var("x", shape=dshape)
-        y = relay.Call(f0, [x])
-        return relay.Function([x], y)
-
-    dshape = (64, 64)
-    z = before(dshape)
-    z = relay.ir_pass.infer_type(z)
-    zz = relay.ir_pass.fuse_ops(z, opt_level=0)
-    assert not relay.ir_pass.free_vars(zz)
-    zz = relay.ir_pass.fuse_ops(z, opt_level=2)
-    zz = relay.ir_pass.infer_type(zz)
-    assert not relay.ir_pass.free_vars(zz)
-    after = relay.ir_pass.infer_type(expected(dshape))
-    assert relay.ir_pass.alpha_equal(zz, after)

 def test_stop_fusion():
     def before(dshape):

@@ -377,13 +340,178 @@ def test_tuple_get_root():
     assert relay.ir_pass.alpha_equal(zz, after)

+def test_tuple_intermediate():
+    def before(x):
+        inj = relay.squeeze(x)
+        y1 = relay.add(inj, relay.const(1, "float32"))
+        tmp = relay.squeeze(inj)
+        tmp = relay.add(tmp, relay.const(1, "float32"))
+        y2 = relay.add(tmp, relay.const(1, "float32"))
+        y3 = relay.add(inj, relay.const(1, "float32"))
+        concat = relay.concatenate((y1, y2, y3), axis=1)
+        out_inj = relay.squeeze(concat)
+        out = relay.add(out_inj, relay.const(1, "float32"))
+        return relay.Function(relay.ir_pass.free_vars(out), out)
+
+    def expected(p0):
+        f0 = before(p0)
+        x = relay.var("x", shape=dshape)
+        y = relay.Call(f0, [x])
+        return relay.Function([x], y)
+
+    dshape = (1, 16, 64, 64)
+    x = relay.var("x", shape=dshape)
+    z = before(x)
+    z = relay.ir_pass.infer_type(z)
+    zz = relay.ir_pass.fuse_ops(z, opt_level=0)
+    assert not relay.ir_pass.free_vars(zz)
+    zz = relay.ir_pass.fuse_ops(z, opt_level=2)
+    relay.build(zz, 'llvm')
+    zz = relay.ir_pass.infer_type(zz)
+    assert not relay.ir_pass.free_vars(zz)
+    after = relay.ir_pass.infer_type(expected(x))
+    assert relay.ir_pass.alpha_equal(zz, after)
+
+def test_tuple_consecutive():
+    def gen_intermediate_tuple(x):
+        y1 = relay.add(x, relay.const(1, "float32"))
+        y2 = relay.add(x, relay.const(1, "float32"))
+        y3 = relay.add(x, relay.const(1, "float32"))
+        concat = relay.concatenate((y1, y2, y3), axis=1)
+        out = relay.add(concat, relay.const(1, "float32"))
+        return out
+
+    def gen_consecutive_tuple(x):
+        y1 = gen_intermediate_tuple(x)
+        y2 = gen_intermediate_tuple(x)
+        y3 = gen_intermediate_tuple(x)
+        concat = relay.concatenate((y1, y2, y3), axis=1)
+        return concat
+
+    def before(x):
+        concat = gen_consecutive_tuple(x)
+        pooled = relay.nn.max_pool2d(concat, pool_size=(2, 2), strides=(2, 2), padding=(0, 0))
+        out = relay.add(pooled, relay.const(1, "float32"))
+        out2 = relay.add(out, relay.const(1, "float32"))
+        out_tup = relay.Tuple((out, out2))
+        return relay.Function(relay.ir_pass.free_vars(out_tup), out_tup)
+
+    def expected(dshape):
+        p0 = relay.var("p0", shape=dshape)
+        concat = gen_consecutive_tuple(p0)
+        f0 = relay.Function([p0], concat)
+
+        p01 = relay.var("p01", shape=(1, dshape[1]*9, dshape[2], dshape[3]))
+        pooled = relay.nn.max_pool2d(p01, pool_size=(2, 2), strides=(2, 2), padding=(0, 0))
+        out = relay.add(pooled, relay.const(1, "float32"))
+        f1 = relay.Function([p01], out)
+
+        p02 = relay.var("p02", shape=(1, dshape[1]*9, dshape[2]//2, dshape[3]//2))
+        out = relay.add(p02, relay.const(1, "float32"))
+        f2 = relay.Function([p02], out)
+
+        x = relay.var("x", shape=dshape)
+        y = relay.Call(f0, [x])
+        z = relay.Call(f1, [y])
+        z2 = relay.Call(f2, [z])
+        return relay.Function([x], relay.Tuple((z, z2)))
+
+    dshape = (1, 16, 64, 64)
+    x = relay.var("x", shape=dshape)
+    z = before(x)
+    z = relay.ir_pass.infer_type(z)
+    zz = relay.ir_pass.fuse_ops(z, opt_level=0)
+    assert not relay.ir_pass.free_vars(zz)
+    zz = relay.ir_pass.fuse_ops(z, opt_level=2)
+    relay.build(zz, 'llvm')
+    zz = relay.ir_pass.infer_type(zz)
+    assert not relay.ir_pass.free_vars(zz)
+    after = relay.ir_pass.infer_type(expected(dshape))
+    assert relay.ir_pass.alpha_equal(zz, after)
+
+def test_inception_like():
+    def conv(data):
+        y = relay.nn.conv2d(data, relay.var("w"),
+                            kernel_size=(3, 3),
+                            padding=(1, 1),
+                            channels=16)
+        return relay.nn.relu(data=y)
+
+    def inception_like(data):
+        c0 = conv(data)
+        c1 = conv(data)
+        return relay.concatenate((c0, c1), axis=1)
+
+    def before(dshape):
+        x = relay.var("x", shape=dshape)
+        in1 = inception_like(x)
+        in2 = inception_like(in1)
+        return relay.Function(relay.ir_pass.free_vars(in2), in2)
+
+    def expected(dshape):
+        p0 = relay.var("p0", shape=dshape)
+        c = conv(p0)
+        f0 = relay.Function(relay.ir_pass.free_vars(c), c)
+
+        p01 = relay.var("p01", shape=dshape)
+        c = conv(p01)
+        f1 = relay.Function(relay.ir_pass.free_vars(c), c)
+
+        p02 = relay.var("p02", shape=dshape)
+        p12 = relay.var("p12", shape=dshape)
+        concat1 = relay.concatenate((p02, p12), axis=1)
+        f_concat1 = relay.Function([p02, p12], concat1)
+
+        dshape2 = (dshape[0], dshape[1]*2, dshape[2], dshape[3])
+
+        p03 = relay.var("p03", shape=dshape2)
+        c = conv(p03)
+        f2 = relay.Function(relay.ir_pass.free_vars(c), c)
+
+        p04 = relay.var("p04", shape=dshape2)
+        c = conv(p04)
+        f3 = relay.Function(relay.ir_pass.free_vars(c), c)
+
+        p05 = relay.var("p05", shape=dshape)
+        p15 = relay.var("p15", shape=dshape)
+        concat2 = relay.concatenate((p05, p15), axis=1)
+        f_concat2 = relay.Function([p05, p15], concat2)
+
+        x = relay.var("x", shape=dshape)
+        c1 = relay.Call(f0, [x, relay.var("w1")])
+        c2 = relay.Call(f1, [x, relay.var("w2")])
+        concat = relay.Call(f_concat1, [c1, c2])
+        c3 = relay.Call(f2, [concat, relay.var("w3")])
+        c4 = relay.Call(f3, [concat, relay.var("w4")])
+        out = relay.Call(f_concat2, [c3, c4])
+        return relay.Function(relay.ir_pass.free_vars(out), out)
+
+    dshape = (1, 16, 64, 64)
+    z = before(dshape)
+    z = relay.ir_pass.infer_type(z)
+    zz = relay.ir_pass.fuse_ops(z, opt_level=0)
+    assert not relay.ir_pass.free_vars(zz)
+    zz = relay.ir_pass.fuse_ops(z, opt_level=2)
+    relay.build(zz, 'llvm')
+    zz = relay.ir_pass.infer_type(zz)
+    assert not relay.ir_pass.free_vars(zz)
+    after = relay.ir_pass.infer_type(expected(dshape))
+    assert relay.ir_pass.alpha_equal(zz, after)
+
 if __name__ == "__main__":
     test_fuse_simple()
     test_conv2d_fuse()
     test_concatenate()
     test_tuple_root()
-    test_tuple_strided_slice()
     test_stop_fusion()
     test_fuse_myia_regression()
     test_fuse_tuple_get_elemwise()
     test_tuple_get_root()
+    test_tuple_intermediate()
+    test_tuple_consecutive()
+    test_inception_like()