Commit 54f903a5 by bulanova-huawei Committed by Tianqi Chen

tightening bounding box for IntSet fused in PassUpDomain (#3073)

Apply suggestions from code review

Co-Authored-By: Wei Chen <ipondering.weic@gmail.com>
parent ce363d61
...@@ -270,10 +270,24 @@ void PassUpDomain(const FuseNode* s, ...@@ -270,10 +270,24 @@ void PassUpDomain(const FuseNode* s,
*outer = IntSet::single_point(v_outer); *outer = IntSet::single_point(v_outer);
*inner = IntSet::single_point(v_inner); *inner = IntSet::single_point(v_inner);
} else { } else {
LOG(WARNING) << "use fallback inference rule in fuse"; Expr fused_extent = (fused.max() - fused.min() + 1);
// simply use the entire set, this rule can be enhanced. Expr inner_extent = dom_map.at(s->inner)->extent;
*outer = IntSet::range(dom_map.at(s->outer)); *outer = IntSet::interval(outer_min + fused.min() / inner_extent,
outer_min + fused.max() / inner_extent);
if (is_zero(Simplify(inner_extent % fused_extent)) &&
is_zero(Simplify(fused.min() % fused_extent)) ) {
// fused never spans multiple rows, make a tight bounding box
// there may be other cases when bounding box could be tightened
*inner = IntSet::interval(inner_min + fused.min() % inner_extent,
inner_min + fused.max() % inner_extent);
} else { // fused may span multiple rows, use full row widths
if (!is_zero(Simplify(fused_extent % inner_extent)) ||
!is_zero(Simplify(fused.min() % inner_extent))) {
LOG(WARNING) <<
"fused and original axes are not aligned, this may cause redundant computations";
}
*inner = IntSet::range(dom_map.at(s->inner)); *inner = IntSet::range(dom_map.at(s->inner));
}
return; return;
} }
} }
......
...@@ -69,6 +69,55 @@ def test_bound3(): ...@@ -69,6 +69,55 @@ def test_bound3():
assert(bounds[A1.op.axis[0]].extent.value==32) assert(bounds[A1.op.axis[0]].extent.value==32)
assert(bounds[A1.op.axis[1]].extent.value==16) assert(bounds[A1.op.axis[1]].extent.value==16)
def test_bound_fusesplit1():
m = tvm.var('m')
l = tvm.var('l')
split1 = tvm.var('s')
A = tvm.placeholder((m, l), name='A')
A1 = tvm.compute((m, l), lambda i, j: A[i, j], name='A1')
A2 = tvm.compute((m, l), lambda i, j: A1[i, j] + 3, name='A2')
s = tvm.create_schedule(A2.op)
fused_axes = s[A2].fuse(A2.op.axis[0], A2.op.axis[1])
xo, xi = s[A2].split(fused_axes, split1)
s[A1].compute_at(s[A2], xo)
bounds = tvm.schedule.InferBound(s)
assert isinstance(bounds, tvm.container.Map)
assert(tvm.ir_pass.Simplify(bounds[A1.op.axis[0]].min - (xo * split1) / l ).value == 0)
expected_extent = (((xo + 1) * split1 - 1) / l - (xo * split1) / l + 1)
for i in range(1, 6):
for j in range(1, 6):
for k in range(1, 6):
vars = tvm.convert({split1: tvm.const(i, "int32"), l: tvm.const(j, "int32"), xo.var: tvm.const(k, "int32")})
comp_ext = tvm.ir_pass.Simplify(tvm.ir_pass.Substitute(bounds[A1.op.axis[0]].extent, vars)).value
exp_ext = tvm.ir_pass.Simplify(tvm.ir_pass.Substitute(expected_extent, vars)).value
assert(comp_ext == exp_ext)
assert(tvm.ir_pass.Simplify(bounds[A1.op.axis[1]].extent - l).value == 0)
def test_bound_fusesplit2():
m = tvm.var("m")
l = tvm.convert(6)
split = tvm.convert(3)
A = tvm.placeholder((m, l), name='A')
A1 = tvm.compute((m, l), lambda i, j: A[i, j], name='A1')
A2 = tvm.compute((m, l), lambda i, j: A1[i, j] + 3, name='A2')
s = tvm.create_schedule(A2.op)
fused_axes = s[A2].fuse(A2.op.axis[0], A2.op.axis[1])
xo, xi = s[A2].split(fused_axes, split)
s[A1].compute_at(s[A2], xo)
bounds = tvm.schedule.InferBound(s)
assert isinstance(bounds, tvm.container.Map)
vars = tvm.convert({xo.var: tvm.const(5, "int32")})
assert(tvm.ir_pass.Simplify(tvm.ir_pass.Substitute(bounds[A1.op.axis[0]].min, vars)).value == 2)
assert(tvm.ir_pass.Simplify(tvm.ir_pass.Substitute(bounds[A1.op.axis[1]].min, vars)).value == 3)
assert(tvm.ir_pass.Simplify(tvm.ir_pass.Substitute(bounds[A1.op.axis[0]].extent, vars)).value == 1)
assert(tvm.ir_pass.Simplify(tvm.ir_pass.Substitute(bounds[A1.op.axis[1]].extent, vars)).value == 3)
def test_bound_warp(): def test_bound_warp():
m = tvm.var('m') m = tvm.var('m')
...@@ -342,3 +391,5 @@ if __name__ == "__main__": ...@@ -342,3 +391,5 @@ if __name__ == "__main__":
test_bound_warp() test_bound_warp()
test_bound_tensor_compute_op() test_bound_tensor_compute_op()
test_bound_simplification_failure() test_bound_simplification_failure()
test_bound_fusesplit1()
test_bound_fusesplit2()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment