"""Tests for the ``tvm.ir_pass.UnrollLoop`` IR pass."""
import os

import tvm


def test_unroll_loop():
    """Check how ``tvm.ir_pass.UnrollLoop`` treats a small two-level loop nest.

    The nest built below is::

        for i in [n, n + 2):               # serial, 2 iterations
            for j in [0, 8):               # for_type="unroll"
                A[j + 1] = A[i] + 1

    so fully expanding it takes 2 * 8 = 16 steps.

    ``UnrollLoop`` is called positionally as
    ``UnrollLoop(stmt, auto_max_step, auto_max_depth, auto_max_extent,
    explicit_unroll)``.
    NOTE(review): these parameter names are from the pass signature of this
    TVM era; confirm against the installed version.
    """
    ib = tvm.ir_builder.create()
    dtype = 'int64'
    n = tvm.var('n')
    Ab = tvm.decl_buffer((n, ), dtype)
    Aptr = ib.buffer_ptr(Ab)
    # for i in n .. n+1:
    #     for j in 0 .. 7 (unroll):
    #         A[j + 1] = A[i] + 1
    with ib.for_range(n, n + 2, name="i") as i:
        # Give the inner loop its own name so the printed IR does not show
        # two different loop variables both called "i".
        with ib.for_range(0, 8, name="j", for_type="unroll") as j:
            Aptr[j + 1] = Aptr[i] + 1

    stmt = ib.get()
    assert isinstance(stmt, tvm.stmt.For)

    # A step budget of 16 is enough for the outer loop to be expanded away:
    # the result is no longer a For node.
    ret = tvm.ir_pass.UnrollLoop(stmt, 16, 8, 0, True)
    assert not isinstance(ret, tvm.stmt.For)

    # One step short of that budget, the outer For survives.
    ret = tvm.ir_pass.UnrollLoop(stmt, 15, 8, 0, True)
    assert isinstance(ret, tvm.stmt.For)

    # With explicit_unroll=False the loop is kept as a For and only marked
    # as Unrolled instead of being expanded in the IR.
    ret = tvm.ir_pass.UnrollLoop(stmt, 16, 8, 0, False)
    assert isinstance(ret, tvm.stmt.For)
    assert ret.for_type == tvm.stmt.For.Unrolled

    # Per-scope override: wrap one copy of the nest in a
    # "pragma_auto_unroll_max_step" attribute of 16 and pair it with an
    # unwrapped copy. With a global auto_max_step of 0, only the
    # pragma-annotated copy should be marked Unrolled.
    ib = tvm.ir_builder.create()
    ib.scope_attr(tvm.const(0), "pragma_auto_unroll_max_step", 16)
    ib.emit(stmt)
    wrapped = ib.get()
    wrapped = tvm.make.Block(wrapped, stmt)
    ret = tvm.ir_pass.UnrollLoop(wrapped, 0, 8, 0, False)
    # The pass strips the pragma attribute, leaving the loop itself as `first`.
    assert isinstance(ret.first, tvm.stmt.For)
    assert ret.first.for_type == tvm.stmt.For.Unrolled
    assert isinstance(ret.rest, tvm.stmt.For)
    assert ret.rest.for_type != tvm.stmt.For.Unrolled


if __name__ == "__main__":
    test_unroll_loop()