"""codegen related to bool types"""

import tvm
import numpy as np

def test_cmp_load_store():
    """Check codegen for bool tensors that are stored and then loaded back.

    ``C`` holds the element-wise comparison ``A > B``; ``D`` reads ``C`` and
    combines it with ``A > 1`` through ``tvm.all``. The kernel is built for
    LLVM and for each listed device target; a target that is not enabled or
    whose context does not exist is skipped silently.
    """
    n = 32
    A = tvm.placeholder((n,), name='A')
    B = tvm.placeholder((n,), name='B')
    C = tvm.compute(A.shape, lambda *i: A(*i) > B(*i), name='C')
    D = tvm.compute(C.shape, lambda *i: tvm.all(C(*i), A(*i) > 1), name="D")

    def run_and_verify(f, ctx):
        """Run the built kernel ``f`` on ``ctx`` and compare D against numpy."""
        # Sample from [0, 2): the default [0, 1) range can never satisfy
        # ``A > 1``, which makes the expected output all-False and lets the
        # assertion pass even when the kernel writes nothing.
        a_np = np.random.uniform(low=0.0, high=2.0, size=n).astype(A.dtype)
        b_np = np.random.uniform(low=0.0, high=2.0, size=n).astype(B.dtype)
        expected = np.logical_and(a_np > b_np, a_np > 1)
        a = tvm.nd.array(a_np, ctx)
        b = tvm.nd.array(b_np, ctx)
        # Pre-fill the output with the complement of the expected result so
        # that every element has to be written by the kernel to pass.
        d = tvm.nd.array(np.logical_not(expected).astype(D.dtype), ctx)
        f(a, b, d)
        np.testing.assert_equal(d.asnumpy(), expected)

    def check_llvm():
        """Build and verify on the CPU through the LLVM backend."""
        if not tvm.module.enabled("llvm"):
            return
        s = tvm.create_schedule(D.op)
        xo, _ = s[C].split(C.op.axis[0], factor=4)
        # 13 does not divide the outer extent of 8 (32 / 4); presumably this
        # is meant to exercise an uneven split under a parallel loop.
        _, xo2 = s[C].split(xo, factor=13)
        s[C].parallel(xo2)
        # Build and invoke the kernel.
        f = tvm.build(s, [A, B, D], "llvm")
        run_and_verify(f, tvm.cpu(0))

    def check_device(device):
        """Build and verify on ``device`` with both stages bound to threads."""
        ctx = tvm.context(device, 0)
        if not ctx.exist:
            return
        s = tvm.create_schedule(D.op)
        for stage in [C, D]:
            xo, xi = s[stage].split(stage.op.axis[0], factor=4)
            s[stage].bind(xo, tvm.thread_axis("blockIdx.x"))
            s[stage].bind(xi, tvm.thread_axis("threadIdx.x"))
        f = tvm.build(s, [A, B, D], device)
        run_and_verify(f, ctx)

    check_llvm()
    for device in ["vulkan", "opencl", "cuda", "rocm", "metal"]:
        check_device(device)



# Run the test directly when this file is executed as a script.
if __name__ == "__main__":
    test_cmp_load_store()