Commit a2b45887 by Salem Derisavi Committed by Tianqi Chen

Ensure loop count is a constant before trying to unroll. (#2797)

parent 864da840
......@@ -78,7 +78,7 @@ class LoopUnroller : public IRMutator {
if ((auto_unroll && explicit_unroll_) ||
// unroll loops with extent = 1, no matter how many steps in body
(value <= auto_max_extent_ && auto_max_extent_ == 1)) {
(0 <= value && value <= auto_max_extent_ && auto_max_extent_ == 1)) {
return Unroll(op);
} else {
if (auto_unroll) {
......
......@@ -51,7 +51,20 @@ def test_unroll_fake_loop():
ret = tvm.ir_pass.UnrollLoop(stmt, 8, 0, 1, True)
assert isinstance(ret.first, tvm.stmt.Store)
def test_unroll_single_count_loops():
n = tvm.var('n')
A = tvm.placeholder((n,), name='A')
B = tvm.compute((n,), lambda *i: A(*i), name='B')
s = tvm.create_schedule(B.op)
s = s.normalize()
dom_map = tvm.schedule.InferBound(s)
stmt = tvm.schedule.ScheduleOps(s, dom_map)
# all parameters to UnrolLoops are default values except for
# auto_unroll_max_extent which has been set to 1 (default:0)
after_unroll_stmt = tvm.ir_pass.UnrollLoop(stmt, 0, 8, 1, True)
assert after_unroll_stmt == stmt
if __name__ == "__main__":
test_unroll_loop()
test_unroll_fake_loop()
test_unroll_single_count_loops()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment