# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""Test code for pooling"""
import math
import numpy as np
import tvm
import topi
import topi.testing
from topi.util import get_const_tuple
from common import get_all_backend

def verify_pool(n, ic, ih, kh, sh, padding, pool_type, ceil_mode, count_include_pad=True):
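    """Verify topi.nn.pool + relu on an NCHW input against a NumPy reference on all enabled backends."""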
    iw = ih
    kw = kh
    sw = sh
    pt, pl, pb, pr = padding
    layout = "NCHW"
    A = tvm.placeholder((n, ic, ih, iw), name='A')
    B = topi.nn.pool(A, kernel=[kh, kw], stride=[sh, sw], padding=padding,
                     pool_type=pool_type, ceil_mode=ceil_mode,
                     layout="NCHW", count_include_pad=count_include_pad)
    B = topi.nn.relu(B)
    dtype = A.dtype

    bshape = get_const_tuple(B.shape)
    ashape = get_const_tuple(A.shape)
    if ceil_mode:
        assert bshape[2] == int(math.ceil(float(ashape[2] - kh + pt + pb) / sh) + 1)
        assert bshape[3] == int(math.ceil(float(ashape[3] - kw + pl + pr) / sw) + 1)
    else:
        assert bshape[2] == int(math.floor(float(ashape[2] - kh + pt + pb) / sh) + 1)
        assert bshape[3] == int(math.floor(float(ashape[3] - kw + pl + pr) / sw) + 1)

    a_np = np.random.uniform(low=0.001, size=(n, ic, ih, iw)).astype(dtype)
    pad_np = np.zeros(shape=(n, ic, ih+pt+pb, iw+pl+pr)).astype(dtype)
    no_zero = (range(n), range(ic), (range(pt, ih+pt)), (range(pl, iw+pl)))
    pad_np[np.ix_(*no_zero)] = a_np
    _, oc, oh, ow = get_const_tuple(B.shape)
    b_np = np.zeros(shape=(n, oc, oh, ow)).astype(dtype)

    if pool_type == 'avg':
        for i in range(oh):
            for j in range(ow):
                if count_include_pad:
                    b_np[:,:,i,j] = np.mean(pad_np[:, :, i*sh:i*sh+kh, j*sw:j*sw+kw], axis=(2,3))
                else:
                    pad_count = np.sum(pad_np[:, :, i*sh:i*sh+kh, j*sw:j*sw+kw] > 0, axis=(2,3))
                    b_np[:,:,i,j] = np.sum(pad_np[:, :, i*sh:i*sh+kh, j*sw:j*sw+kw], axis=(2,3)) / np.maximum(pad_count, 1)

    elif pool_type == 'max':
        for i in range(oh):
            for j in range(ow):
                b_np[:,:,i,j] = np.max(pad_np[:, :, i*sh:i*sh+kh, j*sw:j*sw+kw], axis=(2,3))
    b_np = np.maximum(b_np, 0.0)

    def check_device(device):
        ctx = tvm.context(device, 0)
        if not ctx.exist:
            print("Skip because %s is not enabled" % device)
            return
        print("Running on target: %s" % device)
        with tvm.target.create(device):
            s = topi.generic.schedule_pool(B, layout)

        a = tvm.nd.array(a_np, ctx)
        b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=dtype), ctx)
        f = tvm.build(s, [A, B], device)
        f(a, b)
        tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=2e-5, atol=1e-5)

    for device in get_all_backend():
        check_device(device)

def verify_pool_grad(n, ic, ih, kh, sh, padding, pool_type, ceil_mode, count_include_pad=True,
                     add_relu=False):
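    """Verify topi.nn.pool_grad against the topi.testing.pool_grad_nchw reference on all enabled backends."""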
    iw = ih
    kw = kh
    sw = sh
    pt, pl, pb, pr = padding
    layout = "NCHW"
    A = tvm.placeholder((n, ic, ih, iw), name='A')
    B = topi.nn.pool(A, kernel=[kh, kw], stride=[sh, sw], padding=padding,
                     pool_type=pool_type, ceil_mode=ceil_mode,
                     layout="NCHW", count_include_pad=count_include_pad)
    dtype = A.dtype

    bshape = get_const_tuple(B.shape)
    ashape = get_const_tuple(A.shape)
    if ceil_mode:
        assert bshape[2] == int(math.ceil(float(ashape[2] - kh + pt + pb) / sh) + 1)
        assert bshape[3] == int(math.ceil(float(ashape[3] - kw + pl + pr) / sw) + 1)
    else:
        assert bshape[2] == int(math.floor(float(ashape[2] - kh + pt + pb) / sh) + 1)
        assert bshape[3] == int(math.floor(float(ashape[3] - kw + pl + pr) / sw) + 1)
    OutGrad = tvm.placeholder(bshape, name='OutGrad')
    PoolGrad = topi.nn.pool_grad(OutGrad, A, kernel=[kh, kw], stride=[sh, sw], padding=padding,
                                 pool_type=pool_type, ceil_mode=ceil_mode,
                                 layout="NCHW", count_include_pad=count_include_pad)
    if add_relu:
        PoolGrad = topi.nn.relu(PoolGrad)

    a_np = np.random.uniform(low=0.001, size=(n, ic, ih, iw)).astype(dtype)
    out_grad_np = np.random.uniform(low=0.001, size=bshape).astype(dtype)
    pool_grad_np = topi.testing.pool_grad_nchw(a_np, out_grad_np, pool_size=(kh, kw),
                                               strides=(sh, sw), padding=padding,
                                               pool_type=pool_type, ceil_mode=ceil_mode,
                                               count_include_pad=count_include_pad)
    if add_relu:
        pool_grad_np = np.maximum(pool_grad_np, 0.)

    def check_device(device):
        ctx = tvm.context(device, 0)
        if not ctx.exist:
            print("Skip because %s is not enabled" % device)
            return
        print("Running on target: %s" % device)
        with tvm.target.create(device):
            s = topi.generic.schedule_pool_grad(PoolGrad)

        a = tvm.nd.array(a_np, ctx)
        out_grad = tvm.nd.array(out_grad_np, ctx)
        pool_grad = tvm.nd.array(np.zeros(get_const_tuple(PoolGrad.shape), dtype=dtype), ctx)
        f = tvm.build(s, [A, OutGrad, PoolGrad], device)
        f(a, out_grad, pool_grad)
        tvm.testing.assert_allclose(pool_grad.asnumpy(), pool_grad_np, rtol=1e-5)

    for device in get_all_backend():
        check_device(device)

def test_pool():
    verify_pool(1, 256, 32, 2, 2, [0, 0, 0, 0], 'avg', False, True)
    verify_pool(1, 256, 31, 3, 3, [1, 2, 1, 2], 'avg', False, True)
    verify_pool(1, 256, 32, 2, 2, [1, 2, 1, 2], 'avg', False, False)
    verify_pool(1, 256, 31, 4, 4, [3, 3, 3, 3], 'avg', False, False)
    verify_pool(1, 256, 31, 4, 4, [0, 0, 0, 0], 'avg', False, False)
    verify_pool(1, 256, 32, 2, 2, [0, 0, 0, 0], 'max', False)
    verify_pool(1, 256, 31, 3, 3, [2, 1, 2, 1], 'max', False)
    verify_pool(1, 256, 31, 3, 3, [2, 1, 2, 1], 'max', True)

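    # asymmetric padding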
    verify_pool(1, 256, 31, 3, 3, [2, 1, 0, 3], 'avg', False, True)
    verify_pool(1, 256, 32, 2, 2, [0, 3, 2, 1], 'avg', False, False)
    verify_pool(1, 256, 31, 3, 3, [1, 0, 3, 2], 'max', False)
    verify_pool(1, 256, 31, 3, 3, [3, 2, 1, 0], 'max', True)

def test_pool_grad():
    verify_pool_grad(1, 256, 32, 3, 2, [1, 1, 1, 1], 'avg', False, False)
    verify_pool_grad(1, 256, 32, 2, 2, [0, 0, 0, 0], 'avg', False, True)
    verify_pool_grad(1, 256, 31, 3, 3, [1, 2, 1, 2], 'avg', False, True)
    verify_pool_grad(1, 256, 32, 2, 2, [1, 2, 1, 2], 'avg', False, False)
    verify_pool_grad(1, 256, 31, 4, 4, [2, 2, 2, 2], 'avg', False, False)
    verify_pool_grad(1, 256, 31, 4, 4, [0, 0, 0, 0], 'avg', False, False)
    verify_pool_grad(1, 256, 32, 2, 2, [0, 0, 0, 0], 'max', False)
    verify_pool_grad(1, 256, 31, 3, 3, [2, 1, 2, 1], 'max', False)
    verify_pool_grad(1, 256, 31, 3, 3, [2, 1, 2, 1], 'max', True)

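    # asymmetric padding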
    verify_pool_grad(1, 256, 31, 3, 3, [2, 1, 0, 3], 'avg', False, True)
    verify_pool_grad(1, 256, 32, 2, 2, [0, 3, 2, 1], 'avg', False, False)
    verify_pool_grad(1, 256, 31, 3, 3, [1, 0, 3, 2], 'max', False)
    verify_pool_grad(1, 256, 31, 3, 3, [3, 2, 1, 0], 'max', True)
    verify_pool_grad(1, 256, 32, 3, 2, [1, 1, 1, 1], 'max', False)
    verify_pool_grad(1, 256, 32, 1, 2, [1, 1, 1, 1], 'avg', False, False)

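    # pooling gradient followed by relu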
    verify_pool_grad(1, 256, 31, 4, 4, [0, 0, 0, 0], 'avg', False, False, add_relu=True)
    verify_pool_grad(1, 256, 32, 2, 2, [0, 0, 0, 0], 'max', False, add_relu=True)


def verify_global_pool(n, c, h, w, pool_type, layout='NCHW'):
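    """Verify topi.nn.global_pool + relu against a NumPy mean/max reduction over the spatial axes."""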

    assert layout in ["NCHW", "NHWC"]
    A = tvm.placeholder((n, c, h, w), name='A')
    B = topi.nn.global_pool(A, pool_type=pool_type, layout=layout)
    B = topi.nn.relu(B)

    a_np = np.random.uniform(size=get_const_tuple(A.shape)).astype(A.dtype)

    axis = (layout.find('H'), layout.find('W'))
    if pool_type == 'avg':
        b_np = np.mean(a_np, axis=axis, keepdims=True)
    elif pool_type == 'max':
        b_np = np.max(a_np, axis=axis, keepdims=True)
    b_np = np.maximum(b_np, 0.0)

    def check_device(device):
        ctx = tvm.context(device, 0)
        if not ctx.exist:
            print("Skip because %s is not enabled" % device)
            return
        print("Running on target: %s" % device)
        with tvm.target.create(device):
            s = topi.generic.schedule_adaptive_pool(B)
        a = tvm.nd.array(a_np, ctx)
        b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
        f = tvm.build(s, [A, B], device)
        f(a, b)
        tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)

    for device in get_all_backend():
        check_device(device)

def test_global_pool():
    verify_global_pool(1, 1024, 7, 7, 'avg')
    verify_global_pool(4, 1024, 7, 7, 'avg')
    verify_global_pool(1, 1024, 7, 7, 'max')
    verify_global_pool(4, 1024, 7, 7, 'max')
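    # NHWC layout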
    verify_global_pool(1, 1024, 7, 7, 'avg', 'NHWC')
    verify_global_pool(4, 1024, 7, 7, 'avg', 'NHWC')
    verify_global_pool(1, 1024, 7, 7, 'max', 'NHWC')
    verify_global_pool(4, 1024, 7, 7, 'max', 'NHWC')

def verify_adaptive_pool(dshape, out_size, pool_type, layout="NCHW", dtype="float32"):
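    """Verify topi.nn.adaptive_pool against a per-output-window NumPy reference."""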
    def start_index(index, odim, idim):
        return int(np.floor(index * idim / odim))

    def end_index(index, odim, idim):
        return int(np.ceil((index + 1) * idim / odim))

    np_data = np.random.uniform(low=0, high=255, size=dshape).astype(dtype)
    n, c, h, w = dshape
    oh, ow = out_size
    oshape = (n, c) + out_size
    np_out = np.zeros(oshape).astype(dtype)
    np_op = np.mean if pool_type == "avg" else np.max
    for i in range(n):
        for j in range(c):
            for k in range(oh):
                k_start = start_index(k, oh, h)
                k_end = end_index(k, oh, h)
                k_sl = slice(k_start, k_end)
                for l in range(ow):
                    l_start = start_index(l, ow, w)
                    l_end = end_index(l, ow, w)
                    l_sl = slice(l_start, l_end)
                    np_out[i, j, k, l] = np_op(np_data[i, j, k_sl, l_sl])

    data = tvm.placeholder(dshape, name="data", dtype=dtype)
    out = topi.nn.adaptive_pool(data, out_size, pool_type, layout)
    def check_device(device):
        ctx = tvm.context(device, 0)
        if not ctx.exist:
            print("Skip because %s is not enabled" % device)
            return
        print("Running on target: %s" % device)
        with tvm.target.create(device):
            s = topi.generic.schedule_adaptive_pool(out)
        a = tvm.nd.array(np_data, ctx)
        b = tvm.nd.array(np.zeros(get_const_tuple(oshape), dtype=out.dtype), ctx)
        f = tvm.build(s, [data, out], device)
        f(a, b)
        tvm.testing.assert_allclose(b.asnumpy(), np_out, rtol=1e-5)

    for device in get_all_backend():
        check_device(device)

def test_adaptive_pool():
    verify_adaptive_pool((1, 3, 224, 224), (1, 1), "max")
    verify_adaptive_pool((1, 3, 224, 224), (1, 1), "avg")
    verify_adaptive_pool((1, 14, 56, 78), (34, 13), "max")
    verify_adaptive_pool((1, 5, 46, 97), (4, 96), "avg")

def verify_pool3d(n, ic, ih, kh, sh, padding, pool_type,
                  ceil_mode, count_include_pad=True, layout='NCDHW'):
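    """Verify topi.nn.pool3d + relu on an NCDHW input against topi.testing.pool3d_ncdhw_python."""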
    id = iw = ih
    kd = kw = kh
    sd = sw = sh
    input_shape = (n, ic, id, ih, iw)
    kernel = [kd, kh, kw]
    stride = [sd, sh, sw]
    A = tvm.placeholder(input_shape, name='A')
    B = topi.nn.pool3d(A, kernel=kernel, stride=stride, padding=padding,
                       pool_type=pool_type, ceil_mode=ceil_mode,
                       layout=layout, count_include_pad=count_include_pad)
    B = topi.nn.relu(B)
    dtype = A.dtype
    output_shape = [int(i) for i in B.shape]

    input_np = np.random.uniform(low=0.001, size=input_shape).astype(dtype)
    ref_np = topi.testing.pool3d_ncdhw_python(input_np, kernel, stride, padding,
                                              output_shape, pool_type, count_include_pad, ceil_mode)

    def check_device(device):
        ctx = tvm.context(device, 0)
        if not ctx.exist:
            print("Skip because %s is not enabled" % device)
            return
        print("Running on target: %s" % device)
        with tvm.target.create(device):
            s = topi.generic.schedule_pool(B, layout)

        a = tvm.nd.array(input_np, ctx)
        b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=dtype), ctx)
        f = tvm.build(s, [A, B], device)
        f(a, b)
        tvm.testing.assert_allclose(b.asnumpy(), ref_np, rtol=1e-5)

    for device in get_all_backend():
        check_device(device)


def test_pool3d():
    verify_pool3d(1, 256, 32, 2, 2, [0, 0, 0, 0, 0, 0], 'avg', False, True)
    verify_pool3d(1, 256, 31, 3, 3, [1, 1, 2, 2, 2, 1], 'avg', False, True)
    verify_pool3d(1, 256, 32, 2, 2, [1, 1, 2, 2, 2, 1], 'avg', False, False)
    verify_pool3d(1, 256, 31, 4, 4, [3, 3, 3, 3, 3, 3], 'avg', False, False)
    verify_pool3d(1, 256, 31, 4, 4, [0, 0, 0, 0, 0, 0], 'avg', False, False)
    verify_pool3d(1, 256, 32, 2, 2, [0, 0, 0, 0, 0, 0], 'max', False)
    verify_pool3d(1, 256, 31, 3, 3, [2, 2, 1, 1, 1, 2], 'max', False)
    verify_pool3d(1, 256, 31, 3, 3, [2, 2, 1, 1, 1, 2], 'max', True)

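    # asymmetric padding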
    verify_pool3d(1, 256, 31, 3, 3, [2, 1, 0, 5, 4, 3], 'avg', False, True)
    verify_pool3d(1, 256, 32, 2, 2, [0, 5, 4, 3, 2, 1], 'avg', False, False)
    verify_pool3d(1, 256, 31, 3, 3, [1, 0, 5, 4, 3, 2], 'max', False)
    verify_pool3d(1, 256, 31, 3, 3, [3, 2, 1, 0, 5, 4], 'max', True)


def verify_pool1d(n, ic, iw, kw, sw, padding, pool_type,
                  ceil_mode, count_include_pad=True, layout='NCW'):
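    """Verify topi.nn.pool1d + relu on an NCW input against topi.testing.pool1d_ncw_python."""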
    input_shape = (n, ic, iw)
    kernel = [kw]
    stride = [sw]
    A = tvm.placeholder(input_shape, name='A')
    B = topi.nn.pool1d(A, kernel=kernel, stride=stride, padding=padding,
                       pool_type=pool_type, ceil_mode=ceil_mode,
                       layout=layout, count_include_pad=count_include_pad)
    B = topi.nn.relu(B)
    dtype = A.dtype
    output_shape = [int(i) for i in B.shape]

    input_np = np.random.uniform(low=0.001, size=input_shape).astype(dtype)
    ref_np = topi.testing.pool1d_ncw_python(input_np, kernel, stride, padding,
                                            output_shape, pool_type, count_include_pad, ceil_mode)

    def check_device(device):
        ctx = tvm.context(device, 0)
        if not ctx.exist:
            print("Skip because %s is not enabled" % device)
            return
        print("Running on target: %s" % device)
        with tvm.target.create(device):
            s = topi.generic.schedule_pool(B, layout)

        a = tvm.nd.array(input_np, ctx)
        b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=dtype), ctx)
        f = tvm.build(s, [A, B], device)
        f(a, b)
        tvm.testing.assert_allclose(b.asnumpy(), ref_np, rtol=1e-5)

    for device in get_all_backend():
        check_device(device)


def test_pool1d():
    verify_pool1d(1, 256, 32, 2, 2, [0, 0], 'avg', False, True)
    verify_pool1d(1, 256, 31, 3, 3, [1, 2], 'avg', False, True)
    verify_pool1d(1, 256, 32, 2, 2, [1, 2], 'avg', False, False)
    verify_pool1d(1, 256, 31, 4, 4, [3, 3], 'avg', False, False)
    verify_pool1d(1, 256, 31, 4, 4, [0, 0], 'avg', False, False)
    verify_pool1d(1, 256, 32, 2, 2, [0, 0], 'max', False)
    verify_pool1d(1, 256, 31, 3, 3, [2, 1], 'max', False)
    verify_pool1d(1, 256, 31, 3, 3, [2, 1], 'max', True)

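    # asymmetric padding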
    verify_pool1d(1, 256, 31, 3, 3, [2, 5], 'avg', False, True)
    verify_pool1d(1, 256, 32, 2, 2, [0, 3], 'avg', False, False)
    verify_pool1d(1, 256, 31, 3, 3, [1, 4], 'max', False)
    verify_pool1d(1, 256, 31, 3, 3, [3, 0], 'max', True)


if __name__ == "__main__":
    test_pool()
    test_pool1d()
    test_pool3d()
    test_pool_grad()
    test_global_pool()
    test_adaptive_pool()