# test_pass_unroll.py — unit test for the tvm.ir_pass.UnrollLoop pass.
import tvm
import os


def test_unroll_loop():
    """Check that ``tvm.ir_pass.UnrollLoop`` honours its unrolling limits.

    Builds a two-deep loop nest (outer extent 2, inner extent 8, i.e. 16
    inner-body executions in total) and verifies that:

    * with explicit unrolling, a max-step budget of 16 removes the outer
      ``For`` entirely, while a budget of 15 leaves it in place;
    * without explicit unrolling, the outer loop is kept but its
      ``for_type`` becomes ``For.Unrolled``;
    * a ``pragma_auto_unroll_max_step`` attribute makes the annotated copy
      of the loop come back marked ``Unrolled`` even when the pass-level
      budget is 0, while an un-annotated copy of the same loop does not.
    """
    ib = tvm.ir_builder.create()
    dtype = 'int64'
    n = tvm.var('n')
    Ab = tvm.decl_buffer((n, ), dtype)
    Aptr = ib.buffer_ptr(Ab)
    # for i in n to n+1:            (extent 2)
    #     for j in 0 to 7:          (extent 8, requested as "unroll")
    #         A[j + 1] = A[i] + 1
    with ib.for_range(n, n + 2, name="i") as i:
        # Give the inner loop its own name hint so the printed IR does not
        # show two different loop variables both called "i".
        with ib.for_range(0, 8, name="j", for_type="unroll") as j:
            Aptr[j + 1] = Aptr[i] + 1

    stmt = ib.get()
    assert isinstance(stmt, tvm.stmt.For)

    # NOTE(review): positional arguments are presumably
    # (stmt, auto_max_step, auto_max_depth, auto_max_extent, explicit_unroll)
    # per TVM's UnrollLoop signature — confirm against the TVM version in use.

    # A budget of 16 steps covers the whole 2 x 8 nest: the outer For is
    # expanded away.
    ret = tvm.ir_pass.UnrollLoop(stmt, 16, 8, 0, True)
    assert not isinstance(ret, tvm.stmt.For)
    # One step short of the nest size: the outer For must survive.
    ret = tvm.ir_pass.UnrollLoop(stmt, 15, 8, 0, True)
    assert isinstance(ret, tvm.stmt.For)
    # Without explicit unrolling the loop is kept, only re-tagged as Unrolled.
    ret = tvm.ir_pass.UnrollLoop(stmt, 16, 8, 0, False)
    assert isinstance(ret, tvm.stmt.For)
    assert ret.for_type == tvm.stmt.For.Unrolled

    # Wrap one copy of the nest in a pragma that raises the unroll budget to
    # 16, and put a second, un-annotated copy after it in the same Block.
    ib = tvm.ir_builder.create()
    ib.scope_attr(tvm.const(0), "pragma_auto_unroll_max_step", 16)
    ib.emit(stmt)
    wrapped = ib.get()
    wrapped = tvm.make.Block(wrapped, stmt)
    assert isinstance(wrapped, tvm.stmt.Block)
    # With a global budget of 0, only the pragma-annotated copy is unrolled.
    ret = tvm.ir_pass.UnrollLoop(wrapped, 0, 8, 0, False)
    assert isinstance(ret.first, tvm.stmt.For)
    assert ret.first.for_type == tvm.stmt.For.Unrolled
    assert isinstance(ret.rest, tvm.stmt.For)
    assert ret.rest.for_type != tvm.stmt.For.Unrolled


if __name__ == "__main__":
    # Executing this module directly runs the test without a test runner.
    test_unroll_loop()