Commit a376eb30 by Junru Shao Committed by Tianqi Chen

[SCHEDULE] Fix boundary check (#2126)

* Fix boundary check

* Add unittest
parent cb8a70f4
......@@ -491,11 +491,12 @@ std::vector<Expr> MakeBoundCheck(
IntSet s = EvalSet(value, iset_dmap);
Expr vmin = s.min();
Expr vmax = s.max();
if (vmin.type() != value.type() || !can_prove(vmin >= iv->dom->min)) {
// The range of `value` resides in [vmin, vmax]
if (vmin.type() != value.type() || !can_prove(vmin >= 0)) {
preds.emplace_back(value >= 0);
}
if (vmax.type() != value.type() || !can_prove(vmax < iv->dom->extent)) {
preds.emplace_back(value < (iv->dom->extent - iv->dom->min));
preds.emplace_back(value < iv->dom->extent);
}
}
}
......
......@@ -12,6 +12,7 @@ def test_schedule0():
assert isinstance(bounds, tvm.container.Map)
stmt = tvm.schedule.ScheduleOps(s, bounds)
def test_schedule1():
m = tvm.var('m')
l = tvm.var('l')
......@@ -53,10 +54,13 @@ def test_schedule_scan():
assert tuple(res.shape) == (m, n)
s = tvm.create_schedule(res.op)
s = s.normalize()
ir = tvm.lower(s, [s_state], simple_mode=True)
assert not hasattr(ir.body.body.body.body.rest.body.body.rest.body, "condition")
bounds = tvm.schedule.InferBound(s)
assert(bounds[res.op.scan_axis].min.value == 1)
stmt = tvm.schedule.ScheduleOps(s, bounds)
def test_inline_multi_reduce():
def argmax_comp(x, y):
idx = tvm.select((x[1] >= y[1]), x[0], y[0])
......@@ -80,7 +84,6 @@ def test_inline_multi_reduce():
stmt = tvm.schedule.ScheduleOps(s, bounds)
def test_auto_inline():
m = tvm.var('m')
n = tvm.var('n')
......@@ -96,6 +99,7 @@ def test_auto_inline():
bounds = tvm.schedule.InferBound(s)
stmt = tvm.schedule.ScheduleOps(s, bounds)
def test_schedule_const_bound():
n = 128
A = tvm.placeholder((n,), name='A')
......@@ -146,6 +150,7 @@ def test_scan_inline1():
s[s_x1].compute_inline()
stmt = tvm.lower(s, [x, res1, res2])
def test_scan_inline2():
m = tvm.var("m")
n = tvm.var("n")
......@@ -183,6 +188,7 @@ def test_schedule_cache():
bounds = tvm.schedule.InferBound(s)
stmt = tvm.schedule.ScheduleOps(s, bounds)
def test_schedule_middle_cache():
m = tvm.var('m')
n = tvm.var('n')
......@@ -202,7 +208,6 @@ def test_schedule_middle_cache():
stmt = tvm.schedule.ScheduleOps(s, bounds)
def test_schedule_cache_relayout1():
m = tvm.var('m')
n = tvm.var('n')
......@@ -249,6 +254,7 @@ def test_schedule_cache_relayout3():
bounds = tvm.schedule.InferBound(s)
stmt = tvm.schedule.ScheduleOps(s, bounds)
def test_schedule_cache_relayout4():
def _compute(*indice):
return A(*indice) + 1, B(*indice) / 2
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment