# test_topi_reduce.py
"""Test code for reduce."""
import os
import numpy as np
import tvm
import topi

def _my_npy_argmax(arr, axis, keepdims):
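    """NumPy reference argmax that supports ``keepdims`` by reshaping the
    result so the reduced axis (or all axes, if ``axis`` is None) has size 1."""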
    if not keepdims:
        return arr.argmax(axis=axis)
    else:
        if axis is not None:
            out_shape = list(arr.shape)
            out_shape[axis] = 1
        else:
            out_shape = [1 for _ in range(len(arr.shape))]
        return arr.argmax(axis=axis).reshape(out_shape)


def _my_npy_argmin(arr, axis, keepdims):
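    """NumPy reference argmin with the same ``keepdims`` handling as
    ``_my_npy_argmax``."""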
    if not keepdims:
        return arr.argmin(axis=axis)
    else:
        if axis is not None:
            out_shape = list(arr.shape)
            out_shape[axis] = 1
        else:
            out_shape = [1 for _ in range(len(arr.shape))]
        return arr.argmin(axis=axis).reshape(out_shape)

def verify_reduce_map_ele(in_shape, axis, keepdims, type="sum"):
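    """Build sqrt(exp(A)) followed by the requested reduction through the TOPI
    C++ FFI, compile it for every enabled target, and compare against NumPy."""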
    # Build the logic and compile the function
    dat_dtype = "float32"
    A = tvm.placeholder(shape=in_shape, name="A", dtype=dat_dtype)
    A1 = topi.cpp.sqrt(topi.cpp.exp(A))
    out_dtype = "float32"
    if type == "sum":
        B = topi.cpp.sum(A1, axis, keepdims)
    elif type == "max":
        B = topi.cpp.max(A1, axis, keepdims)
    elif type == "min":
        B = topi.cpp.min(A1, axis, keepdims)
    elif type == "argmax":
        B = topi.cpp.argmax(A1, axis, keepdims)
        out_dtype = "int32"
    elif type == "argmin":
        B = topi.cpp.argmin(A1, axis, keepdims)
        out_dtype = "int32"
    elif type == "prod":
        B = topi.cpp.prod(A1, axis, keepdims)
    else:
        raise NotImplementedError

    def check_device(device):
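        """Compile and run the reduction on ``device``; skip targets that are
        not enabled in this TVM build."""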
        ctx = tvm.context(device, 0)
        if not ctx.exist:
            print("Skip because %s is not enabled" % device)
            return
        print("Running on target: %s" % device)
        target = topi.cpp.TEST_create_target(device)
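        # llvm (CPU) uses the generic default schedule; all other targets go
        # through the CUDA reduce schedule exposed by the C++ FFI.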
        if device == "llvm":
            s = topi.cpp.generic.default_schedule(target, [B], True)
        else:
            s = topi.cpp.cuda.schedule_reduce(target, [B])

        foo = tvm.build(s, [A, B], device, name=type)
        # Test
        in_npy = np.random.uniform(size=in_shape).astype(np.float32)
        in_npy_map = np.sqrt(np.exp(in_npy)).astype(np.float32)
        if type == "sum":
            out_npy = in_npy_map.sum(axis=axis, keepdims=keepdims)
        elif type == "max":
            out_npy = in_npy_map.max(axis=axis, keepdims=keepdims)
        elif type == "min":
            out_npy = in_npy_map.min(axis=axis, keepdims=keepdims)
        elif type == "argmax":
            out_npy = _my_npy_argmax(in_npy_map, axis=axis, keepdims=keepdims)
        elif type == "argmin":
            out_npy = _my_npy_argmin(in_npy_map, axis=axis, keepdims=keepdims)
        elif type == "prod":
            out_npy = in_npy_map.prod(axis=axis, keepdims=keepdims)
        else:
            raise NotImplementedError
        data_tvm = tvm.nd.array(in_npy, ctx=ctx)
        out_tvm = tvm.nd.empty(shape=out_npy.shape, ctx=ctx, dtype=out_dtype)
        foo(data_tvm, out_tvm)
        if type == "argmax" or type == "argmin":
            out_tvm_indices = out_tvm.asnumpy()
            if keepdims:
                out_tvm_indices = np.take(out_tvm_indices, indices=0, axis=axis)
            if axis is None:
                out_tvm_val = in_npy_map.ravel()[out_tvm_indices]
            else:
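                # Build a fancy-index tuple: np.indices enumerates every
                # position over the non-reduced axes, and the returned
                # indices are spliced in at position `axis`.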
                other_indices = tuple(np.indices(in_shape[0:axis] + in_shape[(axis+1):]))
                sel_indices = other_indices[0:axis] + (out_tvm_indices,) + other_indices[axis:]
                out_tvm_val = in_npy_map[sel_indices]
            if type == "argmax":
                np.testing.assert_allclose(out_tvm_val, in_npy_map.max(axis=axis), 1E-3, 1E-3)
            elif type == "argmin":
                np.testing.assert_allclose(out_tvm_val, in_npy_map.min(axis=axis), 1E-3, 1E-3)
        else:
            np.testing.assert_allclose(out_tvm.asnumpy(), out_npy, 1E-3, 1E-3)
    for device in ["cuda", "opencl", "metal", "llvm", "rocm"]:
        check_device(device)


def test_reduce_map():
    verify_reduce_map_ele(in_shape=(128, 24, 128, 24),
                          axis=(1, 2, 3),
                          keepdims=True,
                          type="sum")
    verify_reduce_map_ele(in_shape=(128, 24 * 128 * 24),
                          axis=(1,),
                          keepdims=False,
                          type="max")
    verify_reduce_map_ele(in_shape=(32, 128, 24),
                          axis=None,
                          keepdims=True,
                          type="sum")
    verify_reduce_map_ele(in_shape=(128, 24, 128, 24),
                          axis=(0, 2),
                          keepdims=False,
                          type="min")
    verify_reduce_map_ele(in_shape=(128, 4, 4, 128),
                          axis=(1, ),
                          keepdims=True,
                          type="prod")
    verify_reduce_map_ele(in_shape=(4, 4),
                          axis=(0, 1),
                          keepdims=False,
                          type="prod")
    verify_reduce_map_ele(in_shape=(32, 128),
                          axis=1,
                          keepdims=True,
                          type="argmax")
    verify_reduce_map_ele(in_shape=(32, 24, 32, 24),
                          axis=2,
                          keepdims=False,
                          type="argmin")
    verify_reduce_map_ele(in_shape=(31, 21, 15),
                          axis=None,
                          keepdims=True,
                          type="argmax")
    verify_reduce_map_ele(in_shape=(31, 21, 15),
                          axis=None,
                          keepdims=False,
                          type="sum")

if __name__ == "__main__":
    test_reduce_map()