# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""Test code for reduce."""
import os
import numpy as np
import tvm
import topi

# Local test helper module.
from common import get_all_backend

def _my_npy_argmax(arr, axis, keepdims):
    """Return ``arr.argmax(axis=axis)``, optionally keeping reduced dims.

    When ``keepdims`` is true the result is reshaped so that the reduced
    axis has size 1; if ``axis`` is None every axis becomes size 1.
    """
    if not keepdims:
        return arr.argmax(axis=axis)
    if axis is None:
        # Flattened argmax yields one index; give it the input's rank.
        kept_shape = [1] * len(arr.shape)
    else:
        kept_shape = list(arr.shape)
        kept_shape[axis] = 1
    return arr.argmax(axis=axis).reshape(kept_shape)


def _my_npy_argmin(arr, axis, keepdims):
    """Return ``arr.argmin(axis=axis)``, optionally keeping reduced dims.

    When ``keepdims`` is true the result is reshaped so that the reduced
    axis has size 1; if ``axis`` is None every axis becomes size 1.
    """
    if not keepdims:
        return arr.argmin(axis=axis)
    if axis is None:
        # Flattened argmin yields one index; give it the input's rank.
        # (Previously this indexed a list with None and raised TypeError.)
        out_shape = [1] * len(arr.shape)
    else:
        out_shape = list(arr.shape)
        out_shape[axis] = 1
    return arr.argmin(axis=axis).reshape(out_shape)


def verify_reduce_map_ele(in_shape, axis, keepdims, type="sum", dtype="float32"):
    """Build a TOPI reduction, run it on each backend and compare to NumPy.

    Parameters
    ----------
    in_shape : tuple of int
        Shape of the input placeholder ``A``.
    axis : None, int or tuple of int
        Reduction axis/axes, forwarded to both TOPI and the NumPy reference.
        The argmax/argmin result check handles only ``None`` or a single int.
    keepdims : bool
        Whether reduced axes are kept with size 1.
    type : str
        One of ``"sum"``, ``"all"``, ``"any"``, ``"max"``, ``"min"``,
        ``"argmax"``, ``"argmin"``.
    dtype : str
        Input dtype. The NumPy reference for ``"all"``/``"any"`` is only
        implemented for ``"bool"``.

    Raises
    ------
    NotImplementedError
        If ``type`` (or, for the reference, ``type``/``dtype``) is unsupported.
    """
    # Build the logic and compile the function
    A = tvm.placeholder(shape=in_shape, name="A", dtype=dtype)
    # Elementwise map applied before the numeric reductions; the NumPy
    # reference below applies the same sqrt(exp(x)) map.
    A1 = topi.sqrt(topi.exp(A))
    out_dtype = dtype
    if type == "sum":
        B = topi.sum(A1, axis=axis, keepdims=keepdims)
    elif type == "all":
        # Logical reductions use the raw input A, not the mapped A1.
        B = topi.all(A, axis=axis, keepdims=keepdims)
    elif type == "any":
        B = topi.any(A, axis=axis, keepdims=keepdims)
    elif type == "max":
        B = topi.max(A1, axis=axis, keepdims=keepdims)
    elif type == "min":
        B = topi.min(A1, axis=axis, keepdims=keepdims)
    elif type == "argmax":
        B = topi.argmax(A1, axis=axis, keepdims=keepdims)
        out_dtype = "int32"
    elif type == "argmin":
        B = topi.argmin(A1, axis=axis, keepdims=keepdims)
        out_dtype = "int32"
    else:
        raise NotImplementedError("Unsupported reduce type: %s" % type)

    def check_device(device):
        ctx = tvm.context(device, 0)
        if not ctx.exist:
            print("Skip because %s is not enabled" % device)
            return
        print("Running on target: %s" % device)
        with tvm.target.create(device):
            s = topi.generic.schedule_reduce(B)

        foo = tvm.build(s, [A, B], device, name=type)
        # Test
        if dtype == 'bool':
            in_npy_map = in_npy = np.random.choice([True, False], size=in_shape)
        else:
            in_npy = np.random.uniform(-1, 1, size=in_shape).astype(dtype)
            in_npy_map = np.sqrt(np.exp(in_npy)).astype(dtype)

        if type == "sum":
            out_npy = in_npy_map.sum(axis=axis, keepdims=keepdims)
        elif type == "all" and dtype == 'bool':
            out_npy = in_npy_map.all(axis=axis, keepdims=keepdims)
        elif type == "any" and dtype == "bool":
            out_npy = in_npy_map.any(axis=axis, keepdims=keepdims)
        elif type == "max":
            out_npy = in_npy_map.max(axis=axis, keepdims=keepdims)
        elif type == "min":
            out_npy = in_npy_map.min(axis=axis, keepdims=keepdims)
        elif type == "argmax":
            out_npy = _my_npy_argmax(in_npy_map, axis=axis, keepdims=keepdims)
        elif type == "argmin":
            out_npy = _my_npy_argmin(in_npy_map, axis=axis, keepdims=keepdims)
        else:
            raise NotImplementedError(
                "No NumPy reference for type=%s, dtype=%s" % (type, dtype))
        data_tvm = tvm.nd.array(in_npy, ctx=ctx)
        out_tvm = tvm.nd.empty(shape=out_npy.shape, ctx=ctx, dtype=out_dtype)
        foo(data_tvm, out_tvm)

        if type in ("argmax", "argmin"):
            # Compare the input values selected by the returned indices rather
            # than the indices themselves (presumably so ties do not cause
            # spurious failures).
            out_tvm_indices = out_tvm.asnumpy()
            if keepdims:
                # Drop the kept size-1 axis; with axis=None np.take flattens
                # and takes the single element.
                out_tvm_indices = np.take(out_tvm_indices, indices=0, axis=axis)
            if axis is None:
                # A flattened argmax/argmin gives a flat index into the input.
                out_tvm_val = in_npy_map.ravel()[out_tvm_indices]
            else:
                # Normalise negative axes so the slicing below is correct.
                ax = axis % len(in_shape)
                other_indices = tuple(np.indices(in_shape[0:ax] + in_shape[(ax + 1):]))
                sel_indices = other_indices[0:ax] + (out_tvm_indices,) + other_indices[ax:]
                out_tvm_val = in_npy_map[sel_indices]
            if type == "argmax":
                ref_val = in_npy_map.max(axis=axis)
            else:
                ref_val = in_npy_map.min(axis=axis)
            tvm.testing.assert_allclose(out_tvm_val, ref_val, 1E-3, 1E-3)
        else:
            tvm.testing.assert_allclose(out_tvm.asnumpy(), out_npy, 1E-3, 1E-3)

    for device in get_all_backend():
        check_device(device)


def test_reduce_map():
    """Run verify_reduce_map_ele over a fixed set of shapes, axes and ops."""
    verify_reduce_map_ele(in_shape=(32,),
                          axis=0,
                          keepdims=False,
                          type="argmax")
    verify_reduce_map_ele(in_shape=(128, 24, 128, 24),
                          axis=(1, 2, 3),
                          keepdims=True,
                          type="sum")
    verify_reduce_map_ele(in_shape=(2, 3),
                          axis=None,
                          keepdims=True,
                          type="all",
                          dtype='bool')
    verify_reduce_map_ele(in_shape=(128, 24 * 128 * 24),
                          axis=(1,),
                          keepdims=False,
                          type="max")
    verify_reduce_map_ele(in_shape=(32, 128, 24),
                          axis=None,
                          keepdims=True,
                          type="sum")
    verify_reduce_map_ele(in_shape=(32, 128, 24),
                          axis=None,
                          keepdims=True,
                          dtype='bool',
                          type="all")
    verify_reduce_map_ele(in_shape=(128, 24, 128, 24),
                          axis=(0, 2),
                          keepdims=False,
                          type="min")
    verify_reduce_map_ele(in_shape=(32, 128),
                          axis=1,
                          keepdims=True,
                          type="argmax")
    verify_reduce_map_ele(in_shape=(32, 24, 32, 24),
                          axis=2,
                          keepdims=False,
                          type="argmin")
    verify_reduce_map_ele(in_shape=(31, 21, 15),
                          axis=None,
                          keepdims=True,
                          type="argmax")
    verify_reduce_map_ele(in_shape=(31, 21, 15),
                          axis=None,
                          keepdims=False,
                          type="sum")
    verify_reduce_map_ele(in_shape=(128, 24, 128, 24),
                          axis=(1, 2, 3),
                          keepdims=True,
                          type="sum",
                          dtype="float64")
    verify_reduce_map_ele(in_shape=(2, 3),
                          axis=None,
                          keepdims=True,
                          type="any",
                          dtype="bool")
    verify_reduce_map_ele(in_shape=(32, 128, 24),
                          axis=None,
                          keepdims=True,
                          type="any",
                          dtype="bool")
    verify_reduce_map_ele(in_shape=(1, 4, 7),
                          axis=1,
                          keepdims=True,
                          type="any",
                          dtype="bool")
    verify_reduce_map_ele(in_shape=(128, 24, 128, 24),
                          axis=2,
                          keepdims=False,
                          type="any",
                          dtype="bool")


if __name__ == "__main__":
    test_reduce_map()