# test_pass_inject_vthread.py
# Tests that exercise tvm.ir_pass.InjectVirtualThread.
import tvm

def test_vthread():
    """Run InjectVirtualThread on a loop body that binds two virtual threads.

    The statement is built twice, once with the thread axes created from
    "vthread" and once from "cthread", and the shape of the ``extents``
    found two levels below the returned statement is checked for each.
    """
    n = 100
    m = 4
    nthread = 2

    def get_vthread(name):
        """Build the statement to transform; both axes come from ``name``."""
        tx = tvm.thread_axis(name)
        ty = tvm.thread_axis(name)
        ib = tvm.ir_builder.create()
        A = ib.pointer("float32", name="A")
        C = ib.pointer("float32", name="C")
        with ib.for_range(0, n) as i:
            # The builder records these calls in order; the assertions below
            # depend on the resulting nesting, so the order must not change.
            ib.scope_attr(tx, "virtual_thread", nthread)
            ib.scope_attr(ty, "virtual_thread", nthread)
            B = ib.allocate("float32", m, name="B", scope="shared")
            B[i] = A[i * nthread + tx]
            # Wrap B in a buffer so its access pointer can be handed to the
            # extern call.
            bbuffer = tvm.decl_buffer((m,), dtype=B.dtype, data=B.asnode())
            ib.emit(tvm.call_extern("int32", "Run",
                                    bbuffer.access_ptr("r"),
                                    tvm.call_pure_intrin("int32", "tvm_context_id")))
            C[i * nthread + tx] = B[i] + 1
        return ib.get()

    # "vthread": the leading extent is expected to equal nthread (2).
    stmt = tvm.ir_pass.InjectVirtualThread(get_vthread("vthread"))
    assert stmt.body.body.extents[0].value == 2
    # "cthread": three extents are expected on the same node.
    stmt = tvm.ir_pass.InjectVirtualThread(get_vthread("cthread"))
    assert len(stmt.body.body.extents) == 3

# The next test passes three "shared"-scope buffers to one extern call.

def test_vthread_extern():
    """Run InjectVirtualThread on a body that passes buffers to an extern call.

    Three "shared"-scope allocations (A, B, C) are made inside a loop that
    binds two virtual threads; A and B are written using the thread
    variables and all three are handed to the extern function "Run". The
    test then checks the ``extents`` found at two depths of the result.
    """
    n = 100
    m = 4
    nthread = 2

    def get_vthread(name):
        """Build the statement to transform; both axes come from ``name``."""
        tx = tvm.thread_axis(name)
        ty = tvm.thread_axis(name)
        ib = tvm.ir_builder.create()
        # The loop variable itself is not referenced by the body.
        with ib.for_range(0, n):
            # The builder records these calls in order; the assertions below
            # depend on the resulting nesting, so the order must not change.
            ib.scope_attr(tx, "virtual_thread", nthread)
            ib.scope_attr(ty, "virtual_thread", nthread)
            A = ib.allocate("float32", m, name="A", scope="shared")
            B = ib.allocate("float32", m, name="B", scope="shared")
            C = ib.allocate("float32", m, name="C", scope="shared")
            # Buffer wrappers whose access pointers feed the extern call.
            cbuffer = tvm.decl_buffer((m,), dtype=C.dtype, data=C.asnode())
            abuffer = tvm.decl_buffer((m,), dtype=A.dtype, data=A.asnode())
            bbuffer = tvm.decl_buffer((m,), dtype=B.dtype, data=B.asnode())
            A[tx] = tx + 1.0
            B[ty] = ty + 1.0
            ib.emit(tvm.call_extern("int32", "Run",
                                    abuffer.access_ptr("r"),
                                    bbuffer.access_ptr("r"),
                                    cbuffer.access_ptr("rw")))
        return ib.get()

    stmt = tvm.ir_pass.InjectVirtualThread(get_vthread("vthread"))
    # Outer node: leading extent is expected to equal nthread (2).
    assert stmt.body.body.extents[0].value == 2
    # Node four levels further down: leading extent 2, three extents total.
    assert stmt.body.body.body.body.body.body.extents[0].value == 2
    assert len(stmt.body.body.body.body.body.body.extents) == 3


if __name__ == "__main__":
    # Allow running the tests directly as a script, outside a test runner.
    test_vthread_extern()
    test_vthread()