# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""Unittest cases for AlterOpLayout pass"""
from nnvm import symbol as sym
from nnvm.compiler import graph_attr
from nnvm.top import registry as reg
import nnvm.graph as graph

def get_layouts(g):
    """Group the graph's flat "layout" attribute by node name.

    Parameters
    ----------
    g : graph-like object
        Must provide ``json_attr("layout")`` returning a sliceable sequence,
        and ``index.entry_ptr`` / ``index.nodes`` where ``entry_ptr[i]`` and
        ``entry_ptr[i + 1]`` bound the entries belonging to ``nodes[i]``.

    Returns
    -------
    dict
        Maps ``node["name"]`` to the slice of the layout attribute for that
        node. If several nodes share a name, the one appearing later in
        ``index.nodes`` overwrites the earlier ones.
    """
    all_layouts = g.json_attr("layout")
    offsets = g.index.entry_ptr
    return {
        node["name"]: all_layouts[offsets[pos]:offsets[pos + 1]]
        for pos, node in enumerate(g.index.nodes)
    }


def test_alter_conv2d_layout():
    """Check per-node layouts after conv2d is rewritten by AlterOpLayout.

    The graph is conv2d -> split -> relu (per split output) -> concatenate
    -> flatten -> softmax. Layouts are captured after CorrectLayout and
    shape/type inference, then again after AlterOpLayout runs with a hook
    that switches conv2d to "NCHW16c" and renames it to "conv_alter". The
    layouts reported for the untouched nodes, and for the renamed conv node,
    must equal the ones captured before the rewrite.
    """
    inp = sym.Variable("data", shape=(1, 32, 512, 512))
    conv_out = sym.conv2d(
        inp,
        name="conv",
        channels=16,
        kernel_size=(3, 3),
        padding=(1, 1),
        use_bias=False,
        layout="NCHW",
    )
    # The conv output is split in two and each part goes through its own relu.
    # Both relu nodes share the name "relu", so the name-keyed layout lookup
    # below only sees one of them.
    halves = sym.split(conv_out, indices_or_sections=2)
    activated = [sym.relu(half, name="relu") for half in halves]
    merged = sym.concatenate(*activated)
    flat = sym.flatten(merged, name="flatten")
    net = graph.create(sym.softmax(flat, name="softmax"))

    # Baseline: layouts before any operator is altered.
    net = net.apply("CorrectLayout")
    net = graph_attr.set_dtype_inputs(net, "float32")
    net = net.apply(["InferShape", "InferType"])
    before = get_layouts(net)

    @reg.register_alter_op_layout("conv2d", level=100)
    def alter_conv2d_layout(attrs, inputs, tinfos):
        # Copy every existing attribute, then override layout-related ones
        # and give the replacement node a distinct name.
        overrides = {
            "layout": "NCHW16c",
            "kernel_layout": "NCHW16c",
            "name": "conv_alter",
        }
        new_attrs = {key: attrs[key] for key in attrs.keys()}
        new_attrs.update(overrides)
        data_in, weight_in = inputs[0], inputs[1]
        return sym.conv2d(data_in, weight_in, **new_attrs)

    net = net.apply("AlterOpLayout")
    after = get_layouts(net)

    # Layouts of these nodes must have been carried over unchanged.
    unchanged = ("data", "relu", "flatten", "softmax", "conv_weight")
    for node_name in unchanged:
        assert after[node_name] == before[node_name]
    # The replacement conv reports the same layouts as the original conv.
    assert after["conv_alter"] == before["conv"]


def test_consecutive_alter_layout():
    """Check the layout attribute when two back-to-back ops are both altered.

    Two chained global_avg_pool2d nodes feed a relu. After AlterOpLayout
    rewrites both pooling nodes to "NCHW16c", the entry between them is
    expected to be "__undef__" while the other entries stay "NCHW".
    """
    inp = sym.Variable("data", shape=(1, 32, 512, 512))
    first_pool = sym.global_avg_pool2d(inp, name="global_avg_pool2d_1", layout="NCHW")
    second_pool = sym.global_avg_pool2d(first_pool, name="global_avg_pool2d_2", layout="NCHW")
    out = sym.relu(second_pool, name="relu")

    net = graph.create(out)
    net = net.apply("CorrectLayout")
    net = graph_attr.set_dtype_inputs(net, "float32")
    net = net.apply(["InferShape", "InferType"])
    # Before altering, every entry carries the NCHW layout.
    assert net.json_attr("layout") == ["NCHW"] * 4

    @reg.register_alter_op_layout("global_avg_pool2d", level=100)
    def alter_global_avg_pool2d_layout(attrs, inputs, tinfos):
        # Same attributes as the original op, except for the layout.
        new_attrs = {key: attrs[key] for key in attrs.keys()}
        new_attrs.update(layout="NCHW16c")
        return sym.global_avg_pool2d(inputs[0], **new_attrs)

    net = net.apply("AlterOpLayout")

    # pool1 is replaced, so its output layout is not recorded.
    # pool2 is replaced, so its input layout is not recorded.
    # The second entry therefore cannot be recovered from either side
    # and must be reported as undefined.
    expected = ["NCHW", "__undef__", "NCHW", "NCHW"]
    assert net.json_attr("layout") == expected


def test_alter_func_return_none():
    """Check that an alter hook returning None leaves the layouts untouched.

    Two chained global_max_pool2d nodes feed a relu. The registered hook for
    global_max_pool2d returns None, and the graph's layout attribute is
    expected to be identical before and after AlterOpLayout.
    """
    inp = sym.Variable("data", shape=(1, 32, 512, 512))
    first_pool = sym.global_max_pool2d(inp, name="pool1", layout="NCHW")
    second_pool = sym.global_max_pool2d(first_pool, name="pool2", layout="NCHW")
    out = sym.relu(second_pool, name="relu")

    all_nchw = ["NCHW"] * 4

    net = graph.create(out)
    net = net.apply("CorrectLayout")
    net = graph_attr.set_dtype_inputs(net, "float32")
    net = net.apply(["InferShape", "InferType"])
    assert net.json_attr("layout") == all_nchw

    @reg.register_alter_op_layout("global_max_pool2d", level=100)
    def alter_global_max_pool2d_layout(attrs, inputs, tinfos):
        # Returning None means no replacement symbol is supplied.
        return None

    net = net.apply("AlterOpLayout")

    # No node was replaced, so the layouts must be the same as before.
    assert net.json_attr("layout") == all_nchw


if __name__ == "__main__":
    # Run every test case in this module, in definition order, when the
    # file is executed directly as a script.
    for test_case in (
        test_alter_conv2d_layout,
        test_consecutive_alter_layout,
        test_alter_func_return_none,
    ):
        test_case()