# test_topi_broadcast.py
"""Test code for broadcasting operators."""
import os
import numpy as np
import tvm
import topi

def verify_broadcast_to_ele(in_shape, out_shape):
    """Check ``topi.cpp.broadcast_to`` against ``np.broadcast_to``.

    Parameters
    ----------
    in_shape : tuple of int
        Shape of the input placeholder.
    out_shape : tuple of int
        Shape the input is broadcast to.
    """
    # Declare the computation once; every device below gets its own build.
    A = tvm.placeholder(shape=in_shape, name="A")
    B = topi.cpp.broadcast_to(A, out_shape)

    def check_device(device):
        """Build and run the broadcast on ``device``; skip if it is not enabled."""
        if not tvm.module.enabled(device):
            print("Skip because %s is not enabled" % device)
            return
        print("Running on target: %s" % device)
        target = topi.cpp.TEST_create_target(device)
        # The cuda injective schedule is the one used for every device here.
        schedule = topi.cpp.cuda.schedule_injective(target, [B])
        ctx = tvm.context(device, 0)
        func = tvm.build(schedule, [A, B], device, name="broadcast_to")

        # Random input plus the NumPy reference result.
        src = np.random.uniform(size=in_shape).astype(A.dtype)
        expected = np.broadcast_to(src, out_shape)

        src_nd = tvm.nd.array(src, ctx)
        dst_nd = tvm.nd.array(np.empty(out_shape).astype(B.dtype), ctx)
        func(src_nd, dst_nd)
        np.testing.assert_allclose(dst_nd.asnumpy(), expected)

    # "metal" and "rocm" are not exercised by this test.
    for device in ("opencl", "cuda"):
        check_device(device)


def verify_broadcast_binary_ele(lhs_shape, rhs_shape, typ="add"):
    """Check a ``topi.cpp`` broadcasting binary operator against NumPy.

    Parameters
    ----------
    lhs_shape : tuple of int
        Shape of the left-hand operand.
    rhs_shape : tuple of int
        Shape of the right-hand operand.
    typ : str, optional
        Operator to test. One of "add", "sub", "div", "mul", "maximum",
        "minimum" or "pow".

    Raises
    ------
    NotImplementedError
        If ``typ`` is not one of the supported operator names.
    """
    # NumPy reference for each supported operator. The operator under test is
    # looked up on topi.cpp as "broadcast_<typ>".
    numpy_refs = {
        "add": np.add,
        "sub": np.subtract,
        "div": np.divide,
        "mul": np.multiply,
        "maximum": np.maximum,
        "minimum": np.minimum,
        "pow": np.power,
    }

    # Build the logic and compile the function
    A = tvm.placeholder(shape=lhs_shape, name="A")
    B = tvm.placeholder(shape=rhs_shape, name="B")
    if typ not in numpy_refs:
        raise NotImplementedError("Unsupported broadcast operator: %s" % typ)
    C = getattr(topi.cpp, "broadcast_" + typ)(A, B)
    numpy_ref = numpy_refs[typ]

    def check_device(device):
        """Build and run the operator on ``device``; skip if it is not enabled."""
        if not tvm.module.enabled(device):
            print("Skip because %s is not enabled" % device)
            return
        print("Running on target: %s" % device)
        target = topi.cpp.TEST_create_target(device)
        # The cuda injective schedule is the one used for every device here.
        s = topi.cpp.cuda.schedule_injective(target, [C])
        ctx = tvm.context(device, 0)
        foo = tvm.build(s, [A, B, C], device, name="broadcast_binary" + "_" + typ)

        # Each operand is generated in the dtype of its own placeholder.
        lhs_npy = np.random.uniform(size=lhs_shape).astype(A.dtype)
        rhs_npy = np.random.uniform(size=rhs_shape).astype(B.dtype)
        if typ == "div":
            # Keep the divisor away from zero so the comparison stays stable.
            rhs_npy = np.abs(rhs_npy) + 0.001
        out_npy = numpy_ref(lhs_npy, rhs_npy)

        lhs_nd = tvm.nd.array(lhs_npy, ctx)
        rhs_nd = tvm.nd.array(rhs_npy, ctx)
        # The output buffer takes the dtype of the result tensor C.
        out_nd = tvm.nd.array(np.empty(out_npy.shape, dtype=C.dtype), ctx)
        foo(lhs_nd, rhs_nd, out_nd)
        np.testing.assert_allclose(out_nd.asnumpy(), out_npy, rtol=1E-4, atol=1E-4)

    check_device("opencl")
    check_device("cuda")
    #check_device("metal")
    #check_device("rocm")

def test_broadcast_to():
    """Run the broadcast_to check over several (input, output) shape pairs."""
    shape_pairs = [
        ((1,), (10,)),
        ((), (10,)),
        ((1, 1, 5, 4), (3, 4, 4, 4, 5, 4)),
        ((1, 128, 1, 32), (64, 128, 64, 32)),
    ]
    for in_shape, out_shape in shape_pairs:
        verify_broadcast_to_ele(in_shape, out_shape)


def test_broadcast_binary():
    """Run the binary broadcast check over several shape/operator cases."""
    # Each entry is (lhs_shape, rhs_shape, operator name).
    cases = [
        ((5, 2, 3), (2, 1), "add"),
        ((5, 2, 3), (), "add"),
        ((5, 64, 128), (2, 5, 64, 1), "mul"),
        ((2, 3, 1, 32), (64, 32), "div"),
        ((1, 32), (64, 32), "sub"),
        ((32,), (64, 32), "maximum"),
        ((1, 2, 2, 1, 32), (64, 32), "minimum"),
        ((1, 32), (64, 32), "pow"),
    ]
    for lhs_shape, rhs_shape, op_name in cases:
        verify_broadcast_binary_ele(lhs_shape, rhs_shape, typ=op_name)


if __name__ == "__main__":
    # Allow running the tests directly, without a test runner.
    for run_case in (test_broadcast_to, test_broadcast_binary):
        run_case()