import tvm

def test_vthread():
    """Run InjectVirtualThread on a loop nested under two virtual-thread axes.

    The same IR is built twice, once with ``vthread`` axes and once with
    ``cthread`` axes. After the pass, the test inspects the ``extents`` of the
    node found at ``stmt.body.body``.
    """
    n = 100
    m = 4
    nthread = 2

    def get_vthread(name):
        # Two thread axes of the requested kind, each scoped with extent `nthread`.
        tx = tvm.thread_axis(name)
        ty = tvm.thread_axis(name)
        ib = tvm.ir_builder.create()
        A = ib.pointer("float32", name="A")
        C = ib.pointer("float32", name="C")
        with ib.for_range(0, n) as i:
            ib.scope_attr(tx, "virtual_thread", nthread)
            ib.scope_attr(ty, "virtual_thread", nthread)
            B = ib.allocate("float32", m, name="B", scope="shared")
            B[i] = A[i * nthread + tx]
            # Wrap B's data in a buffer so it can be passed by pointer to an
            # extern call; the call also consumes the tvm_context_id intrinsic.
            bbuffer = tvm.decl_buffer((m,), dtype=B.dtype, data=B.asnode())
            ib.emit(tvm.call_extern("int32", "Run",
                                    bbuffer.access_ptr("r"),
                                    tvm.call_pure_intrin("int32", "tvm_context_id")))
            C[i * nthread + tx] = B[i] + 1
        return ib.get()

    stmt = tvm.ir_pass.InjectVirtualThread(get_vthread("vthread"))
    # presumably stmt.body.body is the rewritten allocation of B — its first
    # extent is expected to equal nthread (2).
    assert stmt.body.body.extents[0].value == 2
    stmt = tvm.ir_pass.InjectVirtualThread(get_vthread("cthread"))
    # With cthread axes the rewritten node is expected to carry three extents.
    assert len(stmt.body.body.extents) == 3


def test_vthread_extern():
    """Run InjectVirtualThread when shared buffers are handed to an extern call.

    Three shared allocations (A, B, C) are written under two ``vthread`` axes
    and then passed by pointer to an extern ``Run`` call. The test checks the
    ``extents`` of nodes in the rewritten IR.
    """
    n = 100
    m = 4
    nthread = 2

    def get_vthread(name):
        # Two thread axes of the requested kind, each scoped with extent `nthread`.
        tx = tvm.thread_axis(name)
        ty = tvm.thread_axis(name)
        ib = tvm.ir_builder.create()
        with ib.for_range(0, n) as i:
            ib.scope_attr(tx, "virtual_thread", nthread)
            ib.scope_attr(ty, "virtual_thread", nthread)
            A = ib.allocate("float32", m, name="A", scope="shared")
            B = ib.allocate("float32", m, name="B", scope="shared")
            C = ib.allocate("float32", m, name="C", scope="shared")
            # Buffer views over the raw allocations so their pointers can be
            # passed to the extern call below.
            cbuffer = tvm.decl_buffer((m,), dtype=C.dtype, data=C.asnode())
            abuffer = tvm.decl_buffer((m,), dtype=A.dtype, data=A.asnode())
            bbuffer = tvm.decl_buffer((m,), dtype=B.dtype, data=B.asnode())
            # A depends only on tx and B only on ty.
            A[tx] = tx + 1.0
            B[ty] = ty + 1.0
            ib.emit(tvm.call_extern("int32", "Run",
                                    abuffer.access_ptr("r"),
                                    bbuffer.access_ptr("r"),
                                    cbuffer.access_ptr("rw")))
        return ib.get()

    stmt = tvm.ir_pass.InjectVirtualThread(get_vthread("vthread"))
    # presumably stmt.body.body is the outermost rewritten allocation.
    assert stmt.body.body.extents[0].value == 2
    # A deeper node (presumably another rewritten allocation) is expected to
    # have first extent 2 and three extents in total.
    assert stmt.body.body.body.body.body.body.extents[0].value == 2
    assert len(stmt.body.body.body.body.body.body.extents) == 3

def test_vthread_if_then_else():
    """Check that InjectVirtualThread preserves if/else structure.

    The loop body has an ``if`` with an ``else`` branch followed by an ``if``
    without one. After the pass, the first conditional must still have an
    else branch and the second must still have none.
    """
    nthread = 2
    tx = tvm.thread_axis("vthread")
    ib = tvm.ir_builder.create()
    A = ib.pointer("float32", name="A")
    with ib.for_range(0, 100) as i:
        ib.scope_attr(tx, "virtual_thread", nthread)
        B = ib.allocate("float32", 128, name="B", scope="shared")
        with ib.if_scope(i == 0):
            B[i] = A[i * nthread + tx]
        with ib.else_scope():
            B[i] = A[i * nthread + tx] + 1
        with ib.if_scope(i == 0):
            B[i] = A[i * nthread + tx] + 2
    stmt = ib.get()
    stmt = tvm.ir_pass.InjectVirtualThread(stmt)
    # Identity checks against None: TVM node objects may overload ==/!=, so
    # `is` / `is not` is the unambiguous test for "branch present".
    assert stmt.body.body.body.first.else_case is not None
    assert stmt.body.body.body.rest.else_case is None

if __name__ == "__main__":
    # Run each test in the same order as before.
    for _case in (test_vthread_extern, test_vthread, test_vthread_if_then_else):
        _case()