# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
import numpy as np

import topi
import topi.testing
import tvm
from tvm import relay
from tvm.relay.testing import check_grad, ctx_list, run_infer_type
from tvm.relay.transform import gradient


def verify_max_pool2d_grad(x_shape, pool_size, strides, padding, ceil_mode):
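    """Compare the gradient of max_pool2d against the topi.testing.pool_grad_nchw reference."""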
    x = relay.var("x", relay.TensorType(x_shape, "float32"))
    y = relay.nn.max_pool2d(x, pool_size=pool_size, strides=strides, padding=padding,
                            ceil_mode=ceil_mode)

    fwd_func = relay.Function([x], y)
    fwd_func = run_infer_type(fwd_func)
    bwd_func = run_infer_type(gradient(fwd_func))

    data = np.random.rand(*x_shape).astype("float32")
    ph, pw = padding
    y_shape = topi.util.get_const_tuple(fwd_func.ret_type.shape)
    out_grad = np.ones(shape=y_shape)
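    # The topi reference takes 4-way padding: [top, left, bottom, right].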
    ref_grad = topi.testing.pool_grad_nchw(data, out_grad, pool_size=pool_size, strides=strides,
                                           padding=[ph, pw, ph, pw],
                                           pool_type='max', ceil_mode=ceil_mode)

    for target, ctx in ctx_list():
        intrp = relay.create_executor(ctx=ctx, target=target)
        op_res, (op_grad,) = intrp.evaluate(bwd_func)(data)
        np.testing.assert_allclose(op_grad.asnumpy(), ref_grad, rtol=0.01)


def test_max_pool2d_grad():
    verify_max_pool2d_grad((1, 4, 16, 16), pool_size=(2, 2), strides=(2, 2),
                           padding=(0, 0), ceil_mode=False)
    verify_max_pool2d_grad((1, 4, 16, 16), pool_size=(1, 1), strides=(1, 1),
                           padding=(1, 1), ceil_mode=False)


def verify_avg_pool2d_grad(x_shape, pool_size, strides, padding, ceil_mode, count_include_pad):
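    """Compare the gradient of avg_pool2d against the topi.testing.pool_grad_nchw reference."""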
    x = relay.var("x", relay.TensorType(x_shape, "float32"))
    y = relay.nn.avg_pool2d(x, pool_size=pool_size, strides=strides, padding=padding,
                            ceil_mode=ceil_mode, count_include_pad=count_include_pad)

    fwd_func = relay.Function([x], y)
    fwd_func = run_infer_type(fwd_func)
    bwd_func = run_infer_type(gradient(fwd_func))

    data = np.random.rand(*x_shape).astype("float32")
    ph, pw = padding
    y_shape = topi.util.get_const_tuple(fwd_func.ret_type.shape)
    out_grad = np.ones(shape=y_shape)
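    # The topi reference takes 4-way padding; count_include_pad is passed through
    # so padded elements are averaged the same way as in the forward op.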
    ref_grad = topi.testing.pool_grad_nchw(data, out_grad, pool_size=pool_size, strides=strides,
                                           padding=[ph, pw, ph, pw], pool_type='avg',
                                           ceil_mode=ceil_mode, count_include_pad=count_include_pad)

    for target, ctx in ctx_list():
        intrp = relay.create_executor(ctx=ctx, target=target)
        op_res, (op_grad,) = intrp.evaluate(bwd_func)(data)
        np.testing.assert_allclose(op_grad.asnumpy(), ref_grad, rtol=0.01)


def test_avg_pool2d_grad():
    verify_avg_pool2d_grad((1, 4, 16, 16), pool_size=(2, 2), strides=(2, 2), padding=(0, 0),
                           ceil_mode=False, count_include_pad=True)
    verify_avg_pool2d_grad((1, 4, 16, 16), pool_size=(1, 1), strides=(1, 1), padding=(1, 1),
                           ceil_mode=False, count_include_pad=False)


def verify_global_avg_pool2d_grad(x_shape):
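    """Compare the gradient of global_avg_pool2d against the pooling reference; a global
    average pool is an average pool whose window spans the full spatial extent.
    """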
    x = relay.var("x", relay.TensorType(x_shape, "float32"))
    y = relay.nn.global_avg_pool2d(x)

    fwd_func = relay.Function([x], y)
    fwd_func = run_infer_type(fwd_func)
    bwd_func = run_infer_type(gradient(fwd_func))

    data = np.random.rand(*x_shape).astype("float32")
    y_shape = topi.util.get_const_tuple(fwd_func.ret_type.shape)
    out_grad = np.ones(shape=y_shape)
    ref_grad = topi.testing.pool_grad_nchw(data, out_grad, pool_size=(x_shape[2], x_shape[3]),
                                           strides=(1, 1), padding=[0, 0, 0, 0], pool_type='avg',
                                           ceil_mode=False)

    for target, ctx in ctx_list():
        intrp = relay.create_executor(ctx=ctx, target=target)
        op_res, (op_grad,) = intrp.evaluate(bwd_func)(data)
        np.testing.assert_allclose(op_grad.asnumpy(), ref_grad, rtol=0.01)


def test_global_avg_pool2d_grad():
    verify_global_avg_pool2d_grad((1, 4, 16, 16))
    verify_global_avg_pool2d_grad((1, 8, 8, 24))


def verify_conv2d_grad(dshape, wshape, strides, padding, dilation, groups=1, mode='higher_order'):
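    """Compare Relay's conv2d gradients against PyTorch's conv2d_input/conv2d_weight reference."""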
    try:
        import torch
        import torch.nn.functional as F
    except ImportError:
        print('Skipping test: PyTorch is not installed')
        return

    dtype = 'float32'
    data = relay.var('data', shape=dshape, dtype=dtype)
    weight = relay.var('weight', shape=wshape, dtype=dtype)
    conv = relay.nn.conv2d(data, weight, strides=strides, padding=padding, dilation=dilation,
                           groups=groups)
    fwd_func = relay.Function([data, weight], conv)
    fwd_func = run_infer_type(fwd_func)
    bwd_func = run_infer_type(gradient(fwd_func, mode=mode))

    data_pt = torch.randn(*dshape, dtype=torch.float32, requires_grad=True)
    weight_pt = torch.randn(*wshape, dtype=torch.float32, requires_grad=True)
    out_pt = F.conv2d(data_pt, weight_pt, stride=strides, padding=padding, dilation=dilation,
                      groups=groups)
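    # Seed the backward pass with an all-ones output gradient, matching the seed
    # used by Relay's gradient transform, so the references are directly comparable.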
    grad_output_pt = torch.ones(out_pt.shape)
    grad_input_pt = F.grad.conv2d_input(dshape, weight_pt, grad_output_pt, stride=strides,
                                        padding=padding, dilation=dilation, groups=groups) \
                          .detach().numpy()
    grad_weight_pt = F.grad.conv2d_weight(data_pt, wshape, grad_output_pt, stride=strides,
                                          padding=padding, dilation=dilation, groups=groups) \
                           .detach().numpy()

    for target, ctx in ctx_list():
        data_nd = tvm.nd.array(data_pt.detach().numpy(), ctx)
        weight_nd = tvm.nd.array(weight_pt.detach().numpy(), ctx)
        intrp = relay.create_executor(ctx=ctx, target=target)
        op_res, (grad_input, grad_weight) = intrp.evaluate(bwd_func)(data_nd, weight_nd)
        np.testing.assert_allclose(grad_input.asnumpy(), grad_input_pt, rtol=1e-4, atol=1e-4)
        np.testing.assert_allclose(grad_weight.asnumpy(), grad_weight_pt, rtol=1e-4, atol=1e-4)


def test_conv2d_grad():
    verify_conv2d_grad((1, 4, 16, 16), (16, 4, 3, 3), [1, 1], [1, 1], [1, 1])
    verify_conv2d_grad((1, 4, 16, 16), (16, 4, 1, 1), [1, 1], [0, 0], [1, 1])
    verify_conv2d_grad((1, 4, 16, 16), (16, 4, 1, 1), [2, 2], [0, 0], [1, 1])
    verify_conv2d_grad((1, 4, 16, 16), (16, 4, 3, 3), [1, 1], [1, 1], [1, 1], mode='first_order')


def verify_dense_grad(d_shape, w_shape):
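    """Numerically check the dense gradient via relay.testing.check_grad."""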
    data = relay.var("data", relay.TensorType(d_shape, "float32"))
    weight = relay.var("weight", relay.TensorType(w_shape, "float32"))
    fwd_func = relay.Function([data, weight], relay.nn.dense(data, weight))
    check_grad(fwd_func)


def test_dense_grad():
    verify_dense_grad((1, 8), (16, 8))
    verify_dense_grad((1, 4), (3, 4))


def verify_batch_flatten_grad(d_shape):
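    """Numerically check the batch_flatten gradient via relay.testing.check_grad."""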
    data = relay.var("data", relay.TensorType(d_shape, "float32"))
    fwd_func = relay.Function([data], relay.nn.batch_flatten(data))
    check_grad(fwd_func)


def test_batch_flatten_grad():
    verify_batch_flatten_grad((1, 2, 3, 4))
    verify_batch_flatten_grad((1, 8))


if __name__ == "__main__":
    test_max_pool2d_grad()
    test_avg_pool2d_grad()
    test_global_avg_pool2d_grad()
    test_conv2d_grad()
    test_dense_grad()
    test_batch_flatten_grad()