# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""Test code for pooling"""
import math
import numpy as np
import tvm
from tvm import te
import topi
import topi.testing
from topi.util import get_const_tuple
from common import get_all_backend

# Schedule tables handed to ``topi.testing.dispatch`` in the verify_* helpers
# below.  Each table maps a target key ("generic", "cpu", "gpu", "hls") to the
# topi schedule function used for that target.

# Schedules for fixed-window pooling (pool1d / pool / pool3d).
_pool_schedule = dict(
    generic=topi.generic.schedule_pool,
    cpu=topi.x86.schedule_pool,
    gpu=topi.cuda.schedule_pool,
    hls=topi.hls.schedule_pool,
)

# Schedules for adaptive and global pooling.
_adaptive_pool_schedule = dict(
    generic=topi.generic.schedule_adaptive_pool,
    cpu=topi.x86.schedule_adaptive_pool,
    gpu=topi.cuda.schedule_adaptive_pool,
    hls=topi.hls.schedule_adaptive_pool,
)

# Schedules for the pooling gradient; only generic and gpu entries exist.
# NOTE(review): presumably dispatch falls back to "generic" for other
# targets — confirm against topi.testing.dispatch.
_pool_grad_schedule = dict(
    generic=topi.generic.schedule_pool_grad,
    gpu=topi.cuda.schedule_pool_grad,
)

def verify_pool(n, ic, ih, kh, sh, padding, pool_type, ceil_mode, count_include_pad=True):
    """Check ``topi.nn.pool`` followed by ReLU against a NumPy reference.

    The input is square (``iw = ih``), and the kernel and stride are shared by
    both spatial axes.  The result is checked on every backend returned by
    ``get_all_backend``.

    Parameters
    ----------
    n, ic, ih : int
        Batch size, channel count and spatial size of the NCHW input.
    kh, sh : int
        Kernel size and stride, used for both spatial axes.
    padding : list of int
        ``[pad_top, pad_left, pad_bottom, pad_right]``.
    pool_type : str
        ``'avg'`` or ``'max'``.
    ceil_mode : bool
        Whether the output size is computed with ``ceil`` instead of ``floor``.
    count_include_pad : bool
        For ``'avg'``: whether padded positions count towards the divisor.

    Raises
    ------
    ValueError
        If ``pool_type`` is neither ``'avg'`` nor ``'max'``.  Previously such
        a value silently compared against an all-zero reference.
    """
    if pool_type not in ('avg', 'max'):
        raise ValueError("Unsupported pool_type: %s" % pool_type)

    iw = ih
    kw = kh
    sw = sh
    pt, pl, pb, pr = padding
    layout = "NCHW"
    A = te.placeholder((n, ic, ih, iw), name='A')
    B = topi.nn.pool(A, kernel=[kh, kw], stride=[sh, sw], padding=padding,
                     pool_type=pool_type, ceil_mode=ceil_mode,
                     layout=layout, count_include_pad=count_include_pad)
    B = topi.nn.relu(B)
    dtype = A.dtype

    # Output spatial size must follow the (ceil|floor) pooling formula.
    bshape = get_const_tuple(B.shape)
    ashape = get_const_tuple(A.shape)
    round_fn = math.ceil if ceil_mode else math.floor
    assert bshape[2] == int(round_fn(float(ashape[2] - kh + pt + pb) / sh) + 1)
    assert bshape[3] == int(round_fn(float(ashape[3] - kw + pl + pr) / sw) + 1)

    a_np = np.random.uniform(low=0.001, size=(n, ic, ih, iw)).astype(dtype)
    # Zero-padded copy of the input, plus a boolean mask of the real
    # (non-padding) positions.  Counting with the mask instead of `pad_np > 0`
    # keeps the reference correct even if the input contains values <= 0.
    pad_np = np.zeros(shape=(n, ic, ih+pt+pb, iw+pl+pr)).astype(dtype)
    valid_np = np.zeros(shape=pad_np.shape, dtype=bool)
    no_zero = (range(n), range(ic), range(pt, ih+pt), range(pl, iw+pl))
    pad_np[np.ix_(*no_zero)] = a_np
    valid_np[np.ix_(*no_zero)] = True
    _, oc, oh, ow = bshape
    b_np = np.zeros(shape=(n, oc, oh, ow)).astype(dtype)

    for i in range(oh):
        for j in range(ow):
            # NumPy slicing clips windows that run past the padded edge
            # (possible with ceil_mode), matching the original reference.
            window = (slice(None), slice(None),
                      slice(i*sh, i*sh+kh), slice(j*sw, j*sw+kw))
            if pool_type == 'max':
                b_np[:, :, i, j] = np.max(pad_np[window], axis=(2, 3))
            elif count_include_pad:
                b_np[:, :, i, j] = np.mean(pad_np[window], axis=(2, 3))
            else:
                # Divide only by the number of real elements; guard against
                # windows that contain nothing but padding.
                pad_count = np.sum(valid_np[window], axis=(2, 3))
                b_np[:, :, i, j] = (np.sum(pad_np[window], axis=(2, 3))
                                    / np.maximum(pad_count, 1))
    b_np = np.maximum(b_np, 0.0)

    def check_device(device):
        ctx = tvm.context(device, 0)
        if not ctx.exist:
            print("Skip because %s is not enabled" % device)
            return
        print("Running on target: %s" % device)
        with tvm.target.create(device):
            s_func = topi.testing.dispatch(device, _pool_schedule)
            s = s_func(B, layout)

        a = tvm.nd.array(a_np, ctx)
        b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=dtype), ctx)
        f = tvm.build(s, [A, B], device)
        f(a, b)
        tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=2e-5, atol=1e-5)

    for device in get_all_backend():
        check_device(device)

def verify_pool_grad(n, ic, ih, kh, sh, padding, pool_type, ceil_mode, count_include_pad=True,
                     add_relu=False):
    """Check ``topi.nn.pool_grad`` against ``topi.testing.pool_grad_nchw``.

    The input is square (``iw = ih``), and the kernel and stride are shared by
    both spatial axes.  If ``add_relu`` is set, a ReLU is applied to both the
    TVM gradient and the NumPy reference before comparing.
    """
    iw, kw, sw = ih, kh, sh
    pt, pl, pb, pr = padding
    layout = "NCHW"

    A = te.placeholder((n, ic, ih, iw), name='A')
    B = topi.nn.pool(A, kernel=[kh, kw], stride=[sh, sw], padding=padding,
                     pool_type=pool_type, ceil_mode=ceil_mode,
                     layout=layout, count_include_pad=count_include_pad)
    dtype = A.dtype

    # The forward output shape must follow the (ceil|floor) pooling formula.
    ashape = get_const_tuple(A.shape)
    bshape = get_const_tuple(B.shape)
    rounding = math.ceil if ceil_mode else math.floor
    assert bshape[2] == int(rounding(float(ashape[2] - kh + pt + pb) / sh) + 1)
    assert bshape[3] == int(rounding(float(ashape[3] - kw + pl + pr) / sw) + 1)

    OutGrad = te.placeholder(bshape, name='OutGrad')
    PoolGrad = topi.nn.pool_grad(OutGrad, A, kernel=[kh, kw], stride=[sh, sw], padding=padding,
                                 pool_type=pool_type, ceil_mode=ceil_mode,
                                 layout=layout, count_include_pad=count_include_pad)

    # Draw the input first, then the upstream gradient (same RNG order as before).
    a_np = np.random.uniform(low=0.001, size=(n, ic, ih, iw)).astype(dtype)
    out_grad_np = np.random.uniform(low=0.001, size=bshape).astype(dtype)
    pool_grad_np = topi.testing.pool_grad_nchw(a_np, out_grad_np, pool_size=(kh, kw),
                                               strides=(sh, sw), padding=padding,
                                               pool_type=pool_type, ceil_mode=ceil_mode,
                                               count_include_pad=count_include_pad)
    if add_relu:
        PoolGrad = topi.nn.relu(PoolGrad)
        pool_grad_np = np.maximum(pool_grad_np, 0.)

    def run_on(target):
        ctx = tvm.context(target, 0)
        if not ctx.exist:
            print("Skip because %s is not enabled" % target)
            return
        print("Running on target: %s" % target)
        with tvm.target.create(target):
            schedule = topi.testing.dispatch(target, _pool_grad_schedule)(PoolGrad)

        grad_buf = tvm.nd.array(np.zeros(get_const_tuple(PoolGrad.shape), dtype=dtype), ctx)
        inputs = (tvm.nd.array(a_np, ctx), tvm.nd.array(out_grad_np, ctx))
        func = tvm.build(schedule, [A, OutGrad, PoolGrad], target)
        func(*inputs, grad_buf)
        tvm.testing.assert_allclose(grad_buf.asnumpy(), pool_grad_np, rtol=1e-5)

    for target in get_all_backend():
        run_on(target)

def test_pool():
    """Run verify_pool over symmetric and asymmetric 2-D padding configurations."""
    # (n, ic, ih, kh, sh, padding, pool_type, ceil_mode[, count_include_pad])
    cases = [
        (1, 256, 32, 2, 2, [0, 0, 0, 0], 'avg', False, True),
        (1, 256, 31, 3, 3, [1, 2, 1, 2], 'avg', False, True),
        (1, 256, 32, 2, 2, [1, 2, 1, 2], 'avg', False, False),
        (1, 256, 31, 4, 4, [3, 3, 3, 3], 'avg', False, False),
        (1, 256, 31, 4, 4, [0, 0, 0, 0], 'avg', False, False),
        (1, 256, 32, 2, 2, [0, 0, 0, 0], 'max', False),
        (1, 256, 31, 3, 3, [2, 1, 2, 1], 'max', False),
        (1, 256, 31, 3, 3, [2, 1, 2, 1], 'max', True),
        # asymmetric padding
        (1, 256, 31, 3, 3, [2, 1, 0, 3], 'avg', False, True),
        (1, 256, 32, 2, 2, [0, 3, 2, 1], 'avg', False, False),
        (1, 256, 31, 3, 3, [1, 0, 3, 2], 'max', False),
        (1, 256, 31, 3, 3, [3, 2, 1, 0], 'max', True),
    ]
    for case in cases:
        verify_pool(*case)

def test_pool_grad():
    """Run verify_pool_grad over many configurations, then a few with ReLU applied."""
    # (n, ic, ih, kh, sh, padding, pool_type, ceil_mode[, count_include_pad])
    plain_cases = [
        (1, 256, 32, 3, 2, [1, 1, 1, 1], 'avg', False, False),
        (1, 256, 32, 2, 2, [0, 0, 0, 0], 'avg', False, True),
        (1, 256, 31, 3, 3, [1, 2, 1, 2], 'avg', False, True),
        (1, 256, 32, 2, 2, [1, 2, 1, 2], 'avg', False, False),
        (1, 256, 31, 4, 4, [2, 2, 2, 2], 'avg', False, False),
        (1, 256, 31, 4, 4, [0, 0, 0, 0], 'avg', False, False),
        (1, 256, 32, 2, 2, [0, 0, 0, 0], 'max', False),
        (1, 256, 31, 3, 3, [2, 1, 2, 1], 'max', False),
        (1, 256, 31, 3, 3, [2, 1, 2, 1], 'max', True),
        # asymmetric padding and overlapping / 1x1 windows
        (1, 256, 31, 3, 3, [2, 1, 0, 3], 'avg', False, True),
        (1, 256, 32, 2, 2, [0, 3, 2, 1], 'avg', False, False),
        (1, 256, 31, 3, 3, [1, 0, 3, 2], 'max', False),
        (1, 256, 31, 3, 3, [3, 2, 1, 0], 'max', True),
        (1, 256, 32, 3, 2, [1, 1, 1, 1], 'max', False),
        (1, 256, 32, 1, 2, [1, 1, 1, 1], 'avg', False, False),
    ]
    relu_cases = [
        (1, 256, 31, 4, 4, [0, 0, 0, 0], 'avg', False, False),
        (1, 256, 32, 2, 2, [0, 0, 0, 0], 'max', False),
    ]
    for case in plain_cases:
        verify_pool_grad(*case)
    for case in relu_cases:
        verify_pool_grad(*case, add_relu=True)


def verify_global_pool(n, c, h, w, pool_type, layout='NCHW'):
    """Check ``topi.nn.global_pool`` followed by ReLU against NumPy.

    Parameters
    ----------
    n, c, h, w : int
        Batch, channel, height and width of the input.  The placeholder is
        laid out according to ``layout``, so ``h`` and ``w`` are always the
        pooled axes.  Previously the shape was always built as
        ``(n, c, h, w)``, so NHWC runs pooled over the ``c`` and ``h`` sizes.
    pool_type : str
        ``'avg'`` or ``'max'``.
    layout : str
        ``'NCHW'`` or ``'NHWC'``.

    Raises
    ------
    ValueError
        If ``pool_type`` is unsupported.  Previously the reference was left
        unbound in this case.
    """
    assert layout in ["NCHW", "NHWC"]
    if pool_type not in ('avg', 'max'):
        raise ValueError("Unsupported pool_type: %s" % pool_type)

    # Build the input shape in the requested axis order.
    dim_sizes = {'N': n, 'C': c, 'H': h, 'W': w}
    A = te.placeholder(tuple(dim_sizes[axis] for axis in layout), name='A')
    B = topi.nn.global_pool(A, pool_type=pool_type, layout=layout)
    B = topi.nn.relu(B)

    a_np = np.random.uniform(size=get_const_tuple(A.shape)).astype(A.dtype)

    # Reduce over the spatial axes of the chosen layout, keeping them as size-1.
    axis = (layout.find('H'), layout.find('W'))
    np_reduce = np.mean if pool_type == 'avg' else np.max
    b_np = np.maximum(np_reduce(a_np, axis=axis, keepdims=True), 0.0)

    def check_device(device):
        ctx = tvm.context(device, 0)
        if not ctx.exist:
            print("Skip because %s is not enabled" % device)
            return
        print("Running on target: %s" % device)
        with tvm.target.create(device):
            s_func = topi.testing.dispatch(device, _adaptive_pool_schedule)
            s = s_func(B)
        a = tvm.nd.array(a_np, ctx)
        b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
        f = tvm.build(s, [A, B], device)
        f(a, b)
        tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)

    for device in get_all_backend():
        check_device(device)

def test_global_pool():
    """Run verify_global_pool for avg/max, batch 1 and 4, in NCHW and NHWC."""
    for layout in ('NCHW', 'NHWC'):
        for pool_type in ('avg', 'max'):
            for batch in (1, 4):
                verify_global_pool(batch, 1024, 7, 7, pool_type, layout)

def verify_adaptive_pool(dshape, out_size, pool_type, layout="NCHW", dtype="float32"):
    """Check ``topi.nn.adaptive_pool`` against a NumPy reference.

    Output cell ``(k, l)`` pools input rows ``[floor(k*h/oh), ceil((k+1)*h/oh))``
    and columns ``[floor(l*w/ow), ceil((l+1)*w/ow))``.

    Parameters
    ----------
    dshape : tuple of int
        Input shape, interpreted according to ``layout``.
    out_size : sequence of two ints
        Target ``(out_height, out_width)``.
    pool_type : str
        ``"avg"`` reduces with ``np.mean``; any other value uses ``np.max``.
    layout : str
        ``"NCHW"`` or ``"NHWC"``.  Previously the reference always assumed
        NCHW even though ``layout`` was forwarded to topi.
    dtype : str
        Data type of the input placeholder and the reference arrays.
    """
    assert layout in ("NCHW", "NHWC"), "unsupported layout: %s" % layout

    def start_index(index, odim, idim):
        # Exact integer floor(index * idim / odim).
        return (index * idim) // odim

    def end_index(index, odim, idim):
        # Exact integer ceil((index + 1) * idim / odim).
        return -((-(index + 1) * idim) // odim)

    out_size = tuple(out_size)
    oh, ow = out_size
    np_data = np.random.uniform(low=0, high=255, size=dshape).astype(dtype)
    # Compute the reference in NCHW; NHWC data is transposed in and back out.
    nchw_data = np_data if layout == "NCHW" else np_data.transpose(0, 3, 1, 2)
    n, c, h, w = nchw_data.shape
    np_out = np.zeros((n, c, oh, ow)).astype(dtype)
    np_op = np.mean if pool_type == "avg" else np.max
    for k in range(oh):
        rows = slice(start_index(k, oh, h), end_index(k, oh, h))
        for col in range(ow):
            cols = slice(start_index(col, ow, w), end_index(col, ow, w))
            # Reduce the window for every (batch, channel) pair at once.
            np_out[:, :, k, col] = np_op(nchw_data[:, :, rows, cols], axis=(2, 3))
    if layout == "NHWC":
        np_out = np_out.transpose(0, 2, 3, 1)
    oshape = np_out.shape

    data = te.placeholder(dshape, name="data", dtype=dtype)
    out = topi.nn.adaptive_pool(data, out_size, pool_type, layout)
    assert get_const_tuple(out.shape) == oshape

    def check_device(device):
        ctx = tvm.context(device, 0)
        if not ctx.exist:
            print("Skip because %s is not enabled" % device)
            return
        print("Running on target: %s" % device)
        with tvm.target.create(device):
            s_func = topi.testing.dispatch(device, _adaptive_pool_schedule)
            s = s_func(out)
        a = tvm.nd.array(np_data, ctx)
        b = tvm.nd.array(np.zeros(oshape, dtype=out.dtype), ctx)
        f = tvm.build(s, [data, out], device)
        f(a, b)
        tvm.testing.assert_allclose(b.asnumpy(), np_out, rtol=1e-5)

    for device in get_all_backend():
        check_device(device)

def test_adaptive_pool():
    """Run verify_adaptive_pool for global (1x1) and non-divisible output sizes."""
    cases = (
        ((1, 3, 224, 224), (1, 1), "max"),
        ((1, 3, 224, 224), (1, 1), "avg"),
        ((1, 14, 56, 78), (34, 13), "max"),
        ((1, 5, 46, 97), (4, 96), "avg"),
    )
    for dshape, out_size, pool_type in cases:
        verify_adaptive_pool(dshape, out_size, pool_type)

def verify_pool3d(n, ic, ih, kh, sh, padding, pool_type,
                  ceil_mode, count_include_pad=True, layout='NCDHW'):
    """Check ``topi.nn.pool3d`` followed by ReLU against ``pool3d_ncdhw_python``.

    The input is cubic (depth = height = width = ``ih``), and the kernel and
    stride are shared by all three spatial axes.

    Parameters
    ----------
    n, ic, ih : int
        Batch size, channel count and spatial extent of the input.
    kh, sh : int
        Kernel size and stride for every spatial axis.
    padding : list of int
        Six padding values, forwarded unchanged to topi and the reference.
    pool_type : str
        ``'avg'`` or ``'max'``.
    ceil_mode, count_include_pad : bool
        Forwarded to topi and to the reference implementation.
    layout : str
        Forwarded to topi.
        NOTE(review): the NumPy reference is NCDHW-only, so other layouts
        presumably would not compare correctly — confirm before using them.
    """
    # `in_d` rather than `id`, to avoid shadowing the builtin.
    in_d = iw = ih
    kd = kw = kh
    sd = sw = sh
    input_shape = (n, ic, in_d, ih, iw)
    kernel = [kd, kh, kw]
    stride = [sd, sh, sw]
    A = te.placeholder(input_shape, name='A')
    B = topi.nn.pool3d(A, kernel=kernel, stride=stride, padding=padding,
                       pool_type=pool_type, ceil_mode=ceil_mode,
                       layout=layout, count_include_pad=count_include_pad)
    B = topi.nn.relu(B)
    dtype = A.dtype
    output_shape = [int(dim) for dim in B.shape]

    input_np = np.random.uniform(low=0.001, size=input_shape).astype(dtype)
    ref_np = topi.testing.pool3d_ncdhw_python(input_np, kernel, stride, padding,
                                              output_shape, pool_type, count_include_pad, ceil_mode)

    def check_device(device):
        ctx = tvm.context(device, 0)
        if not ctx.exist:
            print("Skip because %s is not enabled" % device)
            return
        print("Running on target: %s" % device)
        with tvm.target.create(device):
            s_func = topi.testing.dispatch(device, _pool_schedule)
            s = s_func(B, layout)

        a = tvm.nd.array(input_np, ctx)
        b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=dtype), ctx)
        f = tvm.build(s, [A, B], device)
        f(a, b)
        tvm.testing.assert_allclose(b.asnumpy(), ref_np, rtol=1e-5)

    for device in get_all_backend():
        check_device(device)


def test_pool3d():
    """Run verify_pool3d over symmetric and asymmetric 3-D padding configurations."""
    # (n, ic, ih, kh, sh, padding, pool_type, ceil_mode[, count_include_pad])
    cases = [
        (1, 256, 32, 2, 2, [0, 0, 0, 0, 0, 0], 'avg', False, True),
        (1, 256, 31, 3, 3, [1, 1, 2, 2, 2, 1], 'avg', False, True),
        (1, 256, 32, 2, 2, [1, 1, 2, 2, 2, 1], 'avg', False, False),
        (1, 256, 31, 4, 4, [3, 3, 3, 3, 3, 3], 'avg', False, False),
        (1, 256, 31, 4, 4, [0, 0, 0, 0, 0, 0], 'avg', False, False),
        (1, 256, 32, 2, 2, [0, 0, 0, 0, 0, 0], 'max', False),
        (1, 256, 31, 3, 3, [2, 2, 1, 1, 1, 2], 'max', False),
        (1, 256, 31, 3, 3, [2, 2, 1, 1, 1, 2], 'max', True),
        # asymmetric padding
        (1, 256, 31, 3, 3, [2, 1, 0, 5, 4, 3], 'avg', False, True),
        (1, 256, 32, 2, 2, [0, 5, 4, 3, 2, 1], 'avg', False, False),
        (1, 256, 31, 3, 3, [1, 0, 5, 4, 3, 2], 'max', False),
        (1, 256, 31, 3, 3, [3, 2, 1, 0, 5, 4], 'max', True),
    ]
    for case in cases:
        verify_pool3d(*case)


def verify_pool1d(n, ic, iw, kw, sw, padding, pool_type,
                  ceil_mode, count_include_pad=True, layout='NCW'):
    """Check ``topi.nn.pool1d`` followed by ReLU against ``pool1d_ncw_python``.

    ``padding`` is a two-element list forwarded unchanged to topi and to the
    reference.  The result is checked on every backend from ``get_all_backend``.
    """
    input_shape = (n, ic, iw)
    kernel, stride = [kw], [sw]
    data = te.placeholder(input_shape, name='A')
    pooled = topi.nn.relu(
        topi.nn.pool1d(data, kernel=kernel, stride=stride, padding=padding,
                       pool_type=pool_type, ceil_mode=ceil_mode,
                       layout=layout, count_include_pad=count_include_pad))
    dtype = data.dtype
    output_shape = [int(dim) for dim in pooled.shape]

    input_np = np.random.uniform(low=0.001, size=input_shape).astype(dtype)
    ref_np = topi.testing.pool1d_ncw_python(input_np, kernel, stride, padding,
                                            output_shape, pool_type, count_include_pad, ceil_mode)

    def run_on(target):
        ctx = tvm.context(target, 0)
        if not ctx.exist:
            print("Skip because %s is not enabled" % target)
            return
        print("Running on target: %s" % target)
        with tvm.target.create(target):
            schedule = topi.testing.dispatch(target, _pool_schedule)(pooled, layout)

        in_buf = tvm.nd.array(input_np, ctx)
        out_buf = tvm.nd.array(np.zeros(get_const_tuple(pooled.shape), dtype=dtype), ctx)
        func = tvm.build(schedule, [data, pooled], target)
        func(in_buf, out_buf)
        tvm.testing.assert_allclose(out_buf.asnumpy(), ref_np, rtol=1e-5)

    for target in get_all_backend():
        run_on(target)


def test_pool1d():
    """Run verify_pool1d over symmetric and asymmetric 1-D padding configurations."""
    # (n, ic, iw, kw, sw, padding, pool_type, ceil_mode[, count_include_pad])
    cases = [
        (1, 256, 32, 2, 2, [0, 0], 'avg', False, True),
        (1, 256, 31, 3, 3, [1, 2], 'avg', False, True),
        (1, 256, 32, 2, 2, [1, 2], 'avg', False, False),
        (1, 256, 31, 4, 4, [3, 3], 'avg', False, False),
        (1, 256, 31, 4, 4, [0, 0], 'avg', False, False),
        (1, 256, 32, 2, 2, [0, 0], 'max', False),
        (1, 256, 31, 3, 3, [2, 1], 'max', False),
        (1, 256, 31, 3, 3, [2, 1], 'max', True),
        # asymmetric padding
        (1, 256, 31, 3, 3, [2, 5], 'avg', False, True),
        (1, 256, 32, 2, 2, [0, 3], 'avg', False, False),
        (1, 256, 31, 3, 3, [1, 4], 'max', False),
        (1, 256, 31, 3, 3, [3, 0], 'max', True),
    ]
    for case in cases:
        verify_pool1d(*case)


if __name__ == "__main__":
    # Run every test in this module, in the same order as before.
    for _test_fn in (test_pool, test_pool1d, test_pool3d, test_pool_grad,
                     test_global_pool, test_adaptive_pool):
        _test_fn()