# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""Test alter op layout pass"""
import tvm

from tvm import relay
from tvm.relay import transform, analysis


def run_opt_pass(expr, passes):
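    """Apply the given pass (or list of passes) to expr at opt_level=3
    and return the transformed expression."""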
    passes = passes if isinstance(passes, list) else [passes]
    mod = relay.Module.from_expr(expr)
    seq = transform.Sequential(passes)
    with transform.PassContext(opt_level=3):
        mod = seq(mod)
    entry = mod["main"]
    return entry if isinstance(expr, relay.Function) else entry.body


def test_no_convert_layout():
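    """ConvertLayout('NCHW') leaves a graph that is already NCHW untouched."""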
    def before():
        x = relay.var("x", shape=(1, 64, 56, 56))
        weight = relay.var('weight', shape=(64, 64, 3, 3))
        y = relay.nn.conv2d(x, weight,
                            channels=64,
                            kernel_size=(3, 3),
                            padding=(1, 1))
        y = relay.nn.relu(y)
        y = relay.Function([x, weight], y)
        return y

    def expected():
        return before()

    a = before()
    a = run_opt_pass(a, transform.ConvertLayout('NCHW'))
    b = run_opt_pass(expected(), transform.InferType())

    assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)


def test_conv_convert_layout():
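    """A single NHWC conv2d is converted to NCHW, with layout transforms
    inserted on the data, the weight, and the output."""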
    def before():
        x = relay.var("x", shape=(1, 56, 56, 64))
        weight = relay.var('weight', shape=(3, 3, 64, 64))
        y = relay.nn.conv2d(x, weight,
                            channels=64,
                            kernel_size=(3, 3),
                            padding=(1, 1),
                            data_layout='NHWC',
                            kernel_layout='HWIO')
        y = relay.nn.relu(y)
        y = relay.Function([x, weight], y)
        return y

    def expected():
        x = relay.var("x", shape=(1, 56, 56, 64))
        weight = relay.var('weight', shape=(3, 3, 64, 64))
        x = relay.layout_transform(x, 'NHWC', 'NCHW')
        weight = relay.layout_transform(weight, 'HWIO', 'OIHW')
        y = relay.nn.conv2d(x, weight,
                            channels=64,
                            kernel_size=(3, 3),
                            padding=(1, 1))
        y = relay.nn.relu(y)
        y = relay.layout_transform(y, 'NCHW', 'NHWC')
        y = relay.Function(analysis.free_vars(y), y)
        return y

    a = before()
    a = run_opt_pass(a, transform.ConvertLayout('NCHW'))
    b = run_opt_pass(expected(), transform.InferType())

    assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)


def test_conv_bias_pool_convert_layout():
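    """Layout transforms propagate through bias_add, relu, max_pool2d and
    cast; the transform back to NHWC is inserted just before batch_flatten."""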
    def before():
        x = relay.var("x", shape=(1, 56, 56, 64))
        bias = relay.var("bias", shape=(64,))
        weight = relay.var("weight", shape=(3, 3, 64, 64))
        y = relay.nn.conv2d(x, weight, channels=64, kernel_size=(3, 3), padding=(1, 1),
                            data_layout='NHWC', kernel_layout='HWIO')
        y = relay.nn.bias_add(y, bias, axis=3)
        # a useless tuple, which will be eliminated
        y = relay.Tuple([y])[0]
        y = relay.nn.relu(y)
        y = relay.nn.max_pool2d(y, pool_size=(2, 2), layout='NHWC')
        y = relay.cast(y, 'int32')
        y = relay.nn.batch_flatten(y)
        y = relay.Function(analysis.free_vars(y), y)
        return y

    def expected():
        x = relay.var("x", shape=(1, 56, 56, 64))
        bias = relay.var("bias", shape=(64,))
        weight = relay.var("weight", shape=(3, 3, 64, 64))
        x = relay.layout_transform(x, 'NHWC', 'NCHW')
        weight = relay.layout_transform(weight, 'HWIO', 'OIHW')
        y = relay.nn.conv2d(x, weight, channels=64, kernel_size=(3, 3), padding=(1, 1))

        bias = relay.expand_dims(bias, axis=0, num_newaxis=3)
        bias = relay.layout_transform(bias, 'NHWC', 'NCHW')
        y = relay.add(y, bias)
        # a useless tuple, which will be eliminated
        y = relay.Tuple([y])[0]
        y = relay.nn.relu(y)
        y = relay.nn.max_pool2d(y, pool_size=(2, 2))
        y = relay.cast(y, 'int32')
        y = relay.layout_transform(y, 'NCHW', 'NHWC')
        y = relay.nn.batch_flatten(y)
        y = relay.Function(analysis.free_vars(y), y)
        return y

    a = before()
    a = run_opt_pass(a, transform.ConvertLayout('NCHW'))
    b = run_opt_pass(expected(), transform.InferType())

    assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)


def test_conv_concat_convert_layout():
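    """Two chained NHWC convs feeding a concatenate are converted together;
    the concatenate axis is rewritten from 3 (NHWC) to 1 (NCHW)."""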
    def before():
        x = relay.var("x", shape=(1, 56, 56, 64))
        weight1 = relay.var('weight1', shape=(3, 3, 64, 64))
        weight2 = relay.var('weight2', shape=(3, 3, 64, 64))
        y = relay.nn.conv2d(x, weight1,
                            channels=64,
                            kernel_size=(3, 3),
                            padding=(1, 1),
                            data_layout='NHWC',
                            kernel_layout='HWIO')
        y1 = relay.nn.conv2d(y, weight2,
                             channels=64,
                             kernel_size=(3, 3),
                             padding=(1, 1),
                             data_layout='NHWC',
                             kernel_layout='HWIO')
        ret = relay.concatenate([y, y1], axis=3)
        y = relay.Function(analysis.free_vars(ret), ret)
        return y

    def expected():
        x = relay.var("x", shape=(1, 56, 56, 64))
        weight1 = relay.var('weight1', shape=(3, 3, 64, 64))
        weight2 = relay.var('weight2', shape=(3, 3, 64, 64))
        weight1 = relay.layout_transform(weight1, 'HWIO', 'OIHW')
        weight2 = relay.layout_transform(weight2, 'HWIO', 'OIHW')
        y = relay.layout_transform(x, "NHWC", "NCHW")
        y = relay.nn.conv2d(y, weight1,
                            channels=64,
                            kernel_size=(3, 3),
                            padding=(1, 1))
        y1 = relay.nn.conv2d(y, weight2,
                             channels=64,
                             kernel_size=(3, 3),
                             padding=(1, 1))
        ret = relay.concatenate([y, y1], axis=1)
        ret = relay.layout_transform(ret, "NCHW", "NHWC")
        y = relay.Function(analysis.free_vars(ret), ret)
        return y

    a = before()
    a = run_opt_pass(a, transform.ConvertLayout('NCHW'))
    b = run_opt_pass(expected(), transform.InferType())

    assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)


def test_dual_path_convert_layout():
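    """A conv output consumed by both a second conv and batch_flatten gets a
    separate NCHW-to-NHWC transform on each consumer branch."""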
    def before():
        x = relay.var("x", shape=(1, 56, 56, 64))
        weight1 = relay.var('weight1', shape=(3, 3, 64, 32))
        weight2 = relay.var('weight2', shape=(3, 3, 32, 32))
        y = relay.nn.conv2d(x, weight1,
                            channels=32,
                            kernel_size=(3, 3),
                            padding=(1, 1),
                            data_layout='NHWC',
                            kernel_layout='HWIO')
        y = relay.nn.relu(y)
        y1 = relay.nn.conv2d(y, weight2,
                             channels=32,
                             kernel_size=(3, 3),
                             padding=(1, 1),
                             data_layout='NHWC',
                             kernel_layout='HWIO')
        y1 = relay.nn.relu(y1)
        y2 = relay.nn.batch_flatten(y)
        ret = relay.Tuple([y1, y2])
        y = relay.Function(analysis.free_vars(ret), ret)
        return y

    def expected():
        x = relay.var("x", shape=(1, 56, 56, 64))
        weight1 = relay.var('weight1', shape=(3, 3, 64, 32))
        weight2 = relay.var('weight2', shape=(3, 3, 32, 32))
        weight1 = relay.layout_transform(weight1, 'HWIO', 'OIHW')
        weight2 = relay.layout_transform(weight2, 'HWIO', 'OIHW')
        y = relay.layout_transform(x, "NHWC", "NCHW")
        y = relay.nn.conv2d(y, weight1,
                            channels=32,
                            kernel_size=(3, 3),
                            padding=(1, 1))
        y = relay.nn.relu(y)
        y1 = relay.nn.conv2d(y, weight2,
                             channels=32,
                             kernel_size=(3, 3),
                             padding=(1, 1))
        y1 = relay.nn.relu(y1)
        y1 = relay.layout_transform(y1, "NCHW", "NHWC")
        y2 = relay.layout_transform(y, "NCHW", "NHWC")
        y2 = relay.nn.batch_flatten(y2)
        ret = relay.Tuple([y1, y2])
        y = relay.Function(analysis.free_vars(ret), ret)
        return y

    a = before()
    a = run_opt_pass(a, transform.ConvertLayout('NCHW'))
    b = run_opt_pass(expected(), transform.InferType())

    assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)


def test_bn_convert_layout():
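    """batch_norm absorbs the layout change, so exactly one NCHW-to-NHWC
    transform appears in the converted graph."""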
    def before():
        x = relay.var("x", shape=(1, 56, 56, 64))
        weight1 = relay.var('weight1', shape=(3, 3, 64, 32))
        y = relay.nn.conv2d(x, weight1,
                            channels=32,
                            kernel_size=(3, 3),
                            padding=(1, 1),
                            data_layout='NHWC',
                            kernel_layout='HWIO')
        gamma = relay.var("gamma")
        beta = relay.var("beta")
        mean = relay.var("mean")
        variance = relay.var("variance")
        y, _, _ = relay.nn.batch_norm(y, gamma, beta, mean, variance, axis=3)
        return relay.Function(analysis.free_vars(y), y)

    a = before()
    a = run_opt_pass(a, transform.ConvertLayout('NCHW'))

    # Check that exactly one NCHW-to-NHWC layout transform is inserted,
    # i.e. the transform back to NHWC lands once, after the batch_norm.
    def is_nchw_to_nhwc_transform(node):
        return isinstance(node, tvm.relay.expr.Call) \
            and node.op.name == "layout_transform" \
            and node.attrs.src_layout == 'NCHW' \
            and node.attrs.dst_layout == 'NHWC'

    found = []
    relay.analysis.post_order_visit(a, lambda node: found.append(is_nchw_to_nhwc_transform(node)))
    assert found.count(True) == 1


def test_resnet_convert_layout():
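    """A residual block: both conv branches, the add, and global_max_pool2d
    all run in NCHW, with a single transform back to NHWC at the end."""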
    def before():
        x = relay.var("x", shape=(1, 56, 56, 64))
        weight1 = relay.var('weight1', shape=(3, 3, 64, 32))
        weight2 = relay.var('weight2', shape=(1, 1, 64, 32))
        y = relay.nn.conv2d(x, weight1,
                            channels=32,
                            kernel_size=(3, 3),
                            padding=(1, 1),
                            data_layout='NHWC',
                            kernel_layout='HWIO')
        y = relay.nn.relu(y)
        y2 = relay.nn.conv2d(x, weight2,
                             channels=32,
                             kernel_size=(1, 1),
                             data_layout='NHWC',
                             kernel_layout='HWIO')
        y2 = relay.nn.relu(y2)
        y = y + y2
        y = relay.nn.global_max_pool2d(y, layout='NHWC')
        return relay.Function(analysis.free_vars(y), y)

    def expected():
        x = relay.var("x", shape=(1,56, 56, 64))
        weight1 = relay.var('weight1', shape=(3, 3, 64, 32))
        weight2 = relay.var('weight2', shape=(1, 1, 64, 32))
        weight1 = relay.layout_transform(weight1, 'HWIO', 'OIHW')
        weight2 = relay.layout_transform(weight2, 'HWIO', 'OIHW')
        x = relay.layout_transform(x, "NHWC", "NCHW")
        y = relay.nn.conv2d(x, weight1,
                            channels=32,
                            kernel_size=(3, 3),
                            padding=(1, 1))
        y = relay.nn.relu(y)
        y2 = relay.nn.conv2d(x, weight2,
                             channels=32,
                             kernel_size=(1, 1))
        y2 = relay.nn.relu(y2)
        y = y + y2
        y = relay.nn.global_max_pool2d(y)
        y = relay.layout_transform(y, "NCHW", "NHWC")
        return relay.Function(analysis.free_vars(y), y)

    a = before()
    a = run_opt_pass(a, transform.ConvertLayout('NCHW'))
    b = run_opt_pass(expected(), transform.InferType())

    assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)


def test_scalar_convert_layout():
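    """Adding a scalar constant needs no layout transform of its own; only
    the conv inputs and the final output are transformed."""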
    def before():
        x = relay.var("x", shape=(1, 56, 56, 64))
        weight = relay.var("weight", shape=(3, 3, 64, 64))
        y = relay.nn.conv2d(x, weight, channels=64, kernel_size=(3, 3), padding=(1, 1),
                            data_layout='NHWC', kernel_layout='HWIO')
        y = relay.add(y, relay.const(1, "float32"))
        y = relay.Function(analysis.free_vars(y), y)
        return y

    def expected():
        x = relay.var("x", shape=(1, 56, 56, 64))
        w = relay.var("weight", shape=(3, 3, 64, 64))
        x = relay.layout_transform(x, 'NHWC', 'NCHW')
        w = relay.layout_transform(w, 'HWIO', 'OIHW')
        y = relay.nn.conv2d(x, w,
                            channels=64,
                            kernel_size=(3, 3),
                            padding=(1, 1))
        y = relay.add(y, relay.const(1.0, "float32"))

        y = relay.layout_transform(y, "NCHW", "NHWC")
        y = relay.Function(analysis.free_vars(y), y)
        return y

    a = before()
    a = run_opt_pass(a, transform.ConvertLayout('NCHW'))
    b = run_opt_pass(expected(), transform.InferType())

    assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)


def test_conv_bn_convert_layout():
    """ Check that layout transforms are propagated through bn. """
    def before():
        x = relay.var("x", shape=(1, 56, 56, 64))
        weight = relay.var("weight", shape=(3, 3, 64, 64))
        y = relay.nn.conv2d(x, weight, channels=64, kernel_size=(3, 3), padding=(1, 1),
                            data_layout='NHWC', kernel_layout='HWIO')

        dtype = "float32"
        beta = relay.var("beta", relay.TensorType((64,), dtype))
        gamma = relay.var("gamma", relay.TensorType((64,), dtype))
        moving_mean = relay.var("moving_mean", relay.TensorType((64,), dtype))
        moving_var = relay.var("moving_var", relay.TensorType((64,), dtype))

        y = relay.nn.batch_norm(y, gamma, beta, moving_mean, moving_var, axis=3)
        y = relay.nn.relu(y[0])
        y = relay.Function(analysis.free_vars(y), y)
        return y

    def expected():
        x = relay.var("x", shape=(1, 56, 56, 64))
        w = relay.var("weight", shape=(3, 3, 64, 64))
        x = relay.layout_transform(x, 'NHWC', 'NCHW')
        w = relay.layout_transform(w, 'HWIO', 'OIHW')
        y = relay.nn.conv2d(x, w,
                            channels=64,
                            kernel_size=(3, 3),
                            padding=(1, 1))

        dtype = "float32"
        beta = relay.var("beta", relay.TensorType((64,), dtype))
        gamma = relay.var("gamma", relay.TensorType((64,), dtype))
        moving_mean = relay.var("moving_mean", relay.TensorType((64,), dtype))
        moving_var = relay.var("moving_var", relay.TensorType((64,), dtype))

        y = relay.nn.batch_norm(y, gamma, beta, moving_mean, moving_var, axis=1)
        y = relay.nn.relu(y[0])
        y = relay.layout_transform(y, "NCHW", "NHWC")
        y = relay.Function(analysis.free_vars(y), y)
        return y

    a = before()
    a = run_opt_pass(a, transform.ConvertLayout('NCHW'))
    b = run_opt_pass(expected(), transform.InferType())

    assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)


if __name__ == "__main__":
    test_no_convert_layout()
    test_conv_convert_layout()
    test_conv_bias_pool_convert_layout()
    test_conv_concat_convert_layout()
    test_dual_path_convert_layout()
    test_bn_convert_layout()
    test_resnet_convert_layout()
    test_scalar_convert_layout()
    test_conv_bn_convert_layout()