Commit a376eb30 by Junru Shao Committed by Tianqi Chen

[SCHEDULE] Fix boundary check (#2126)

* Fix boundary check

* Add unittest
parent cb8a70f4
...@@ -491,11 +491,12 @@ std::vector<Expr> MakeBoundCheck( ...@@ -491,11 +491,12 @@ std::vector<Expr> MakeBoundCheck(
IntSet s = EvalSet(value, iset_dmap); IntSet s = EvalSet(value, iset_dmap);
Expr vmin = s.min(); Expr vmin = s.min();
Expr vmax = s.max(); Expr vmax = s.max();
if (vmin.type() != value.type() || !can_prove(vmin >= iv->dom->min)) { // The range of `value` resides in [vmin, vmax]
if (vmin.type() != value.type() || !can_prove(vmin >= 0)) {
preds.emplace_back(value >= 0); preds.emplace_back(value >= 0);
} }
if (vmax.type() != value.type() || !can_prove(vmax < iv->dom->extent)) { if (vmax.type() != value.type() || !can_prove(vmax < iv->dom->extent)) {
preds.emplace_back(value < (iv->dom->extent - iv->dom->min)); preds.emplace_back(value < iv->dom->extent);
} }
} }
} }
......
...@@ -12,6 +12,7 @@ def test_schedule0(): ...@@ -12,6 +12,7 @@ def test_schedule0():
assert isinstance(bounds, tvm.container.Map) assert isinstance(bounds, tvm.container.Map)
stmt = tvm.schedule.ScheduleOps(s, bounds) stmt = tvm.schedule.ScheduleOps(s, bounds)
def test_schedule1(): def test_schedule1():
m = tvm.var('m') m = tvm.var('m')
l = tvm.var('l') l = tvm.var('l')
...@@ -53,10 +54,13 @@ def test_schedule_scan(): ...@@ -53,10 +54,13 @@ def test_schedule_scan():
assert tuple(res.shape) == (m, n) assert tuple(res.shape) == (m, n)
s = tvm.create_schedule(res.op) s = tvm.create_schedule(res.op)
s = s.normalize() s = s.normalize()
ir = tvm.lower(s, [s_state], simple_mode=True)
assert not hasattr(ir.body.body.body.body.rest.body.body.rest.body, "condition")
bounds = tvm.schedule.InferBound(s) bounds = tvm.schedule.InferBound(s)
assert(bounds[res.op.scan_axis].min.value == 1) assert(bounds[res.op.scan_axis].min.value == 1)
stmt = tvm.schedule.ScheduleOps(s, bounds) stmt = tvm.schedule.ScheduleOps(s, bounds)
def test_inline_multi_reduce(): def test_inline_multi_reduce():
def argmax_comp(x, y): def argmax_comp(x, y):
idx = tvm.select((x[1] >= y[1]), x[0], y[0]) idx = tvm.select((x[1] >= y[1]), x[0], y[0])
...@@ -80,7 +84,6 @@ def test_inline_multi_reduce(): ...@@ -80,7 +84,6 @@ def test_inline_multi_reduce():
stmt = tvm.schedule.ScheduleOps(s, bounds) stmt = tvm.schedule.ScheduleOps(s, bounds)
def test_auto_inline(): def test_auto_inline():
m = tvm.var('m') m = tvm.var('m')
n = tvm.var('n') n = tvm.var('n')
...@@ -96,6 +99,7 @@ def test_auto_inline(): ...@@ -96,6 +99,7 @@ def test_auto_inline():
bounds = tvm.schedule.InferBound(s) bounds = tvm.schedule.InferBound(s)
stmt = tvm.schedule.ScheduleOps(s, bounds) stmt = tvm.schedule.ScheduleOps(s, bounds)
def test_schedule_const_bound(): def test_schedule_const_bound():
n = 128 n = 128
A = tvm.placeholder((n,), name='A') A = tvm.placeholder((n,), name='A')
...@@ -146,6 +150,7 @@ def test_scan_inline1(): ...@@ -146,6 +150,7 @@ def test_scan_inline1():
s[s_x1].compute_inline() s[s_x1].compute_inline()
stmt = tvm.lower(s, [x, res1, res2]) stmt = tvm.lower(s, [x, res1, res2])
def test_scan_inline2(): def test_scan_inline2():
m = tvm.var("m") m = tvm.var("m")
n = tvm.var("n") n = tvm.var("n")
...@@ -183,6 +188,7 @@ def test_schedule_cache(): ...@@ -183,6 +188,7 @@ def test_schedule_cache():
bounds = tvm.schedule.InferBound(s) bounds = tvm.schedule.InferBound(s)
stmt = tvm.schedule.ScheduleOps(s, bounds) stmt = tvm.schedule.ScheduleOps(s, bounds)
def test_schedule_middle_cache(): def test_schedule_middle_cache():
m = tvm.var('m') m = tvm.var('m')
n = tvm.var('n') n = tvm.var('n')
...@@ -202,7 +208,6 @@ def test_schedule_middle_cache(): ...@@ -202,7 +208,6 @@ def test_schedule_middle_cache():
stmt = tvm.schedule.ScheduleOps(s, bounds) stmt = tvm.schedule.ScheduleOps(s, bounds)
def test_schedule_cache_relayout1(): def test_schedule_cache_relayout1():
m = tvm.var('m') m = tvm.var('m')
n = tvm.var('n') n = tvm.var('n')
...@@ -249,6 +254,7 @@ def test_schedule_cache_relayout3(): ...@@ -249,6 +254,7 @@ def test_schedule_cache_relayout3():
bounds = tvm.schedule.InferBound(s) bounds = tvm.schedule.InferBound(s)
stmt = tvm.schedule.ScheduleOps(s, bounds) stmt = tvm.schedule.ScheduleOps(s, bounds)
def test_schedule_cache_relayout4(): def test_schedule_cache_relayout4():
def _compute(*indice): def _compute(*indice):
return A(*indice) + 1, B(*indice) / 2 return A(*indice) + 1, B(*indice) / 2
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment