Commit 749cb215 by masahi Committed by Tianqi Chen

fix handling a tuple node in op fusion (#2433)

parent e0a20ad4
......@@ -740,8 +740,9 @@ class FuseMutator : private ExprMutator {
Array<Expr> new_fields = GetNewArguments(tuple->fields, ret_group);
Tuple new_tuple = TupleNode::make(new_fields);
if (ret_group == gmap_.at(tuple)) {
bool isolated = true;
for (size_t i = 0; i < new_fields.size(); ++i) {
// This tuple is the root of its group. Check if all fields come from other groups.
bool isolated = new_fields.size() == ginfo_[ret_group].params.size();
for (size_t i = 0; i < new_fields.size() && isolated; ++i) {
isolated &= (new_fields[i].same_as(ginfo_[ret_group].params[i]));
}
if (isolated) {
......
......@@ -182,8 +182,46 @@ def test_tuple_root():
assert relay.ir_pass.alpha_equal(zz, after)
def test_tuple_strided_slice():
"""
Test fusion case where the number of fields of tuple and
the number of parameters to the function containing the tuple are different
"""
def before(dshape):
x = relay.var("x", shape=dshape)
slice1 = relay.strided_slice(x, begin=[0, 0], end=[dshape[1]//2, dshape[1]], strides=[1,1])
slice2 = relay.strided_slice(x, begin=[dshape[1]//2, 0], end=[dshape[0], dshape[1]], strides=[1,1])
out = relay.Tuple((slice1, slice2))
return relay.Function([x], out)
def expected(dshape):
x = relay.var("x", shape=dshape)
slice1 = relay.strided_slice(x, begin=[0, 0], end=[dshape[1]//2, dshape[1]], strides=[1,1])
slice2 = relay.strided_slice(x, begin=[dshape[1]//2, 0], end=[dshape[0], dshape[1]], strides=[1,1])
out = relay.Tuple((slice1, slice2))
f0 = relay.Function([x], out)
x = relay.var("x", shape=dshape)
y = relay.Call(f0, [x])
return relay.Function([x], y)
dshape = (64, 64)
z = before(dshape)
z = relay.ir_pass.infer_type(z)
zz = relay.ir_pass.fuse_ops(z, opt_level=0)
assert not relay.ir_pass.free_vars(zz)
zz = relay.ir_pass.fuse_ops(z, opt_level=2)
zz = relay.ir_pass.infer_type(zz)
assert not relay.ir_pass.free_vars(zz)
after = relay.ir_pass.infer_type(expected(dshape))
assert relay.ir_pass.alpha_equal(zz, after)
print(zz.astext())
if __name__ == "__main__":
test_fuse_simple()
test_conv2d_fuse()
test_concatenate()
test_tuple_root()
test_tuple_strided_slice()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment