Commit 01e1699d by kun-zh Committed by Tianqi Chen

Generate Lower Bound Conditions for issue 1014 (#1091)

parent 515d4b6f
...@@ -477,9 +477,14 @@ std::vector<Expr> MakeBoundCheck( ...@@ -477,9 +477,14 @@ std::vector<Expr> MakeBoundCheck(
CHECK(iv->dom.defined()); CHECK(iv->dom.defined());
if (!skip_ivar_domain && !iv->dom.same_as(dom)) { if (!skip_ivar_domain && !iv->dom.same_as(dom)) {
Expr value = ComputeExpr<Sub>(value_map.at(iv), iv->dom->min); Expr value = ComputeExpr<Sub>(value_map.at(iv), iv->dom->min);
Expr vmax = EvalSet(value, iset_dmap).max(); IntSet s = EvalSet(value, iset_dmap);
Expr vmin = s.min();
Expr vmax = s.max();
if (vmin.type() != value.type() || !can_prove(vmin >= iv->dom->min)) {
preds.emplace_back(value >= 0);
}
if (vmax.type() != value.type() || !can_prove(vmax < iv->dom->extent)) { if (vmax.type() != value.type() || !can_prove(vmax < iv->dom->extent)) {
preds.emplace_back(value < iv->dom->extent); preds.emplace_back(value < (iv->dom->extent - iv->dom->min));
} }
} }
} }
......
...@@ -82,6 +82,7 @@ def test_copy_pad_split(): ...@@ -82,6 +82,7 @@ def test_copy_pad_split():
Ab = tvm.decl_buffer(A.shape, A.dtype, name='A') Ab = tvm.decl_buffer(A.shape, A.dtype, name='A')
Bb = tvm.decl_buffer(B.shape, B.dtype, name='B') Bb = tvm.decl_buffer(B.shape, B.dtype, name='B')
stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, B: Bb}, 64) stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, B: Bb}, 64)
stmt = tvm.ir_pass.Simplify(stmt)
stmt = tvm.ir_pass.CanonicalSimplify(stmt) stmt = tvm.ir_pass.CanonicalSimplify(stmt)
def cb(src, dst, pad_before, pad_after, pad_value): def cb(src, dst, pad_before, pad_after, pad_value):
assert(dst.elem_offset.value == 0) assert(dst.elem_offset.value == 0)
......
...@@ -249,6 +249,18 @@ def test_schedule_cache_relayout3(): ...@@ -249,6 +249,18 @@ def test_schedule_cache_relayout3():
stmt = tvm.schedule.ScheduleOps(s, bounds) stmt = tvm.schedule.ScheduleOps(s, bounds)
def test_schedule_bound_condition():
A = tvm.placeholder((64,), name='A', dtype="float32")
Apad = tvm.compute((66,), lambda i: tvm.select(tvm.all(i>0, i < 65), A[i-1], tvm.const(0.)), name='Apad')
Apad2 = tvm.compute((66,), lambda i: Apad[i]*2, name='Apad2')
s = tvm.create_schedule(Apad2.op)
AL1 = s.cache_read(A,"local",[Apad])
s = s.normalize()
bounds = tvm.schedule.InferBound(s)
stmt = tvm.schedule.ScheduleOps(s, bounds)
stmt = tvm.ir_pass.Simplify(stmt)
assert (isinstance(stmt.body.body.first.body.body.then_case, tvm.stmt.IfThenElse))
if __name__ == "__main__": if __name__ == "__main__":
test_schedule_middle_cache() test_schedule_middle_cache()
test_inline_multi_reduce() test_inline_multi_reduce()
...@@ -265,3 +277,4 @@ if __name__ == "__main__": ...@@ -265,3 +277,4 @@ if __name__ == "__main__":
test_schedule1() test_schedule1()
test_schedule2() test_schedule2()
test_schedule_cache() test_schedule_cache()
test_schedule_bound_condition()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment