# test_pass_storage_flatten.py
import tvm

def test_flatten2():
    """Run ScheduleOps -> StorageFlatten -> Simplify on a two-stage pipeline.

    Builds ``A2 = A1 + 3``, where ``A1`` copies the placeholder ``A``. All
    three tensors have the symbolic shape ``(m, l)``. The schedule splits
    ``A2``'s first axis by a factor of 8 and attaches ``A1`` at the resulting
    outer loop ``xo``.

    The scheduled statement is passed through ``StorageFlatten`` with
    explicit buffers bound for ``A`` and ``A2``; ``A1`` is given no buffer
    here. The result is then passed through ``Simplify``.

    NOTE(review): apart from the ``InferBound`` return-type check, this test
    only verifies that the passes run without raising. The lowered
    statements are printed for manual inspection rather than asserted on.
    """
    m = tvm.var('m')
    l = tvm.var('l')
    A = tvm.placeholder((m, l), name='A')
    A1 = tvm.compute((m, l), lambda i, j: A[i, j], name='A1')
    A2 = tvm.compute((m, l), lambda i, j: A1[i, j] + 3, name='A2')

    s = tvm.create_schedule(A2.op)
    # Split A2's axis 0 by 8 and compute A1 inside the outer loop `xo`, so
    # A1 is produced once per iteration of `xo`. The inner axis is unused.
    xo, _ = s[A2].split(A2.op.axis[0], 8)
    s[A1].compute_at(s[A2], xo)
    bounds = tvm.schedule.InferBound(s)
    assert isinstance(bounds, tvm.collections.Map)
    stmt = tvm.schedule.ScheduleOps(s, bounds)
    print(stmt)

    # Bind explicit buffers for A (input) and A2 (output) only; A1 is not
    # in the binding map.
    Ab = tvm.decl_buffer(A.shape, A.dtype, name='A')
    A2b = tvm.decl_buffer(A2.shape, A2.dtype, name='A2')
    stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, A2: A2b})
    stmt = tvm.ir_pass.Simplify(stmt)
    print(stmt)

if __name__ == '__main__':
    # Allow running this module directly (outside pytest) to execute the
    # single test case and see its printed output.
    test_flatten2()