import tvm

def test_vthread():
    """Exercise ``tvm.ir_pass.InjectVirtualThread`` on a hand-built loop.

    The statement under test is a serial loop whose body is wrapped in two
    ``"virtual_thread"`` scope attributes, allocates a small ``"shared"``
    scratch buffer named ``"B"``, and passes a read pointer to that buffer
    to an extern call. The pass is run twice, once with both thread axes
    tagged ``"vthread"`` and once with both tagged ``"cthread"``, and the
    ``extents`` of the node two ``.body`` levels below the result are checked.
    """
    dtype = 'int64'  # NOTE(review): never read inside this function.
    loop_extent = 100
    scratch_size = 4
    num_vthreads = 2

    def get_vthread(name):
        """Build the loop statement, using ``name`` as the tag of both thread axes."""
        axis_x = tvm.thread_axis(name)
        axis_y = tvm.thread_axis(name)
        builder = tvm.ir_builder.create()
        src = builder.pointer("float32", name="A")
        dst = builder.pointer("float32", name="C")
        with builder.for_range(0, loop_extent) as idx:
            # Both axes are attached as virtual-thread scopes, x first, then y.
            for axis in (axis_x, axis_y):
                builder.scope_attr(axis, "virtual_thread", num_vthreads)
            scratch = builder.allocate(
                "float32", scratch_size, name="B", scope="shared")

            loaded = src[idx * num_vthreads + axis_x]
            scratch[idx] = loaded

            # Wrap the scratch allocation in a buffer so that a read access
            # pointer to it can be handed to the extern call.
            scratch_buffer = tvm.decl_buffer(
                (scratch_size,), dtype=scratch.dtype, data=scratch.asnode())
            read_ptr = scratch_buffer.access_ptr("r")
            context_id = tvm.call_pure_intrin("int32", "tvm_context_id")
            builder.emit(
                tvm.call_extern("int32", "Run", read_ptr, context_id))

            incremented = scratch[idx] + 1
            dst[idx * num_vthreads + axis_x] = incremented
        return builder.get()

    inject = tvm.ir_pass.InjectVirtualThread

    # "vthread" axes: the first extent of the inner node must equal 2.
    # NOTE(review): presumably this node is the allocation of "B" - confirm.
    inner = inject(get_vthread("vthread")).body.body
    assert inner.extents[0].value == 2

    # "cthread" axes: the inner node must carry exactly three extents.
    inner = inject(get_vthread("cthread")).body.body
    assert len(inner.extents) == 3

# Allow running this test directly as a script, without a test runner.
if __name__ == "__main__":
    test_vthread()