Commit a1dfb9ae by xqdan Committed by Tianqi Chen

[PASS]unroll loops with extent=1 (#2027)

parent ed1718b6
...@@ -76,7 +76,9 @@ class LoopUnroller : public IRMutator { ...@@ -76,7 +76,9 @@ class LoopUnroller : public IRMutator {
normal_loop_depth_ += 1; normal_loop_depth_ += 1;
} }
if (auto_unroll && explicit_unroll_) { if ((auto_unroll && explicit_unroll_) ||
// unroll loops with extent = 1, no matter how many steps in body
(value <= auto_max_extent_ && auto_max_extent_ == 1)) {
return Unroll(op); return Unroll(op);
} else { } else {
if (auto_unroll) { if (auto_unroll) {
......
...@@ -35,6 +35,23 @@ def test_unroll_loop(): ...@@ -35,6 +35,23 @@ def test_unroll_loop():
assert isinstance(ret.rest, tvm.stmt.For) assert isinstance(ret.rest, tvm.stmt.For)
assert ret.rest.for_type != tvm.stmt.For.Unrolled assert ret.rest.for_type != tvm.stmt.For.Unrolled
def test_unroll_fake_loop():
ib = tvm.ir_builder.create()
dtype = 'int32'
n = tvm.var('n')
Ab = tvm.decl_buffer((n, ), dtype)
Aptr = ib.buffer_ptr(Ab)
# for i in 0 to n-1:
with ib.for_range(0, 1, name="i") as i:
Aptr[i*2] = 3
with ib.for_range(0, 10, name="j") as j:
Aptr[j + 1] = Aptr[i] + 1
stmt = ib.get()
ret = tvm.ir_pass.UnrollLoop(stmt, 8, 0, 1, True)
assert isinstance(ret.first, tvm.stmt.Store)
if __name__ == "__main__": if __name__ == "__main__":
test_unroll_loop() test_unroll_loop()
test_unroll_fake_loop()
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment