Commit 9b148f14 by Yizhi Liu Committed by Tianqi Chen

[schedule] Improve ceil_divide in tile/split (#3842)

parent d9bbdbc8
......@@ -56,6 +56,9 @@ void PassDownDomain(const Stage& stage,
arith::Analyzer* actx,
bool allow_missing) {
auto ceil_div = [actx](Expr a, Expr b) {
if (actx->CanProve(a % b == 0)) {
return actx->Simplify(a / b);
}
return actx->Simplify((a + (b - 1)) / b);
};
......
......@@ -69,6 +69,33 @@ def test_bound3():
assert(bounds[A1.op.axis[0]].extent.value==32)
assert(bounds[A1.op.axis[1]].extent.value==16)
def test_bound_split_divisible():
m = tvm.var('m')
l = tvm.var('l')
A = tvm.placeholder((8 * m, l), name='A')
B = tvm.compute((8 * m, l), lambda i, j: A[i, j], name='B')
s = tvm.create_schedule(B.op)
xo, xi = s[B].split(B.op.axis[0], 8)
bounds = tvm.schedule.InferBound(s)
assert isinstance(bounds, tvm.container.Map)
assert bounds[xo].extent == m
assert bounds[xi].extent.value == 8
def test_bound_tile_divisible():
m = tvm.var('m')
l = tvm.var('l')
shape = (8 * m, 32 * l)
A = tvm.placeholder(shape, name='A')
B = tvm.compute(shape, lambda i, j: A[i, j], name='B')
s = tvm.create_schedule(B.op)
xo, yo, xi, yi = s[B].tile(B.op.axis[0], B.op.axis[1], 8, 32)
bounds = tvm.schedule.InferBound(s)
assert isinstance(bounds, tvm.container.Map)
assert bounds[xo].extent == m
assert bounds[xi].extent.value == 8
assert bounds[yo].extent == l
assert bounds[yi].extent.value == 32
def test_bound_fusesplit1():
m = tvm.var('m')
l = tvm.var('l')
......@@ -393,3 +420,5 @@ if __name__ == "__main__":
test_bound_simplification_failure()
test_bound_fusesplit1()
test_bound_fusesplit2()
test_bound_split_divisible()
test_bound_tile_divisible()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment