# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""Data layout."""
import tvm._ffi

from tvm.runtime import Object
from . import _ffi_api

@tvm._ffi.register_object
class Layout(Object):
    """Layout is composed of uppercase letters, lowercase letters and numbers,
    where an uppercase letter indicates a primal axis and
    the corresponding lowercase letter with a factor size indicates the subordinate axis.
    For example, NCHW16c can describe a 5-D tensor of
    [batch_size, channel, height, width, channel_block].
    Here subordinate axis channel_block=16 is the factor size of the primal axis C (channel).

    See Also
    --------
    layout : Declare a layout
    """
    def __len__(self):
        """Return the number of dimensions of the layout (via ``LayoutNdim``)."""
        return _ffi_api.LayoutNdim(self)

    def __contains__(self, axis):
        """Check whether a single-letter axis name appears in the layout.

        Only one-character alphabetic strings can match. Multi-character or
        non-alphabetic strings (e.g. ``"NC"`` or ``"1"``) return False.
        The check is case-sensitive, so ``"C"`` and ``"c"`` are distinct axes.
        """
        return len(axis) == 1 and axis[0].isalpha() and axis[0] in self.name

    def __getitem__(self, index):
        """Get the axis at position ``index`` (as returned by ``LayoutGetItem``).

        Negative indices count from the end, as for Python sequences.

        Raises
        ------
        IndexError
            If ``index`` is outside ``[-len(self), len(self))``. Raising
            IndexError (rather than letting the FFI fail) also terminates
            Python's implicit ``for axis in layout`` iteration correctly.
        """
        ndim = len(self)
        if index < 0:
            # Normalize here so the FFI only ever sees an in-range, non-negative index.
            index += ndim
        if not 0 <= index < ndim:
            raise IndexError("Layout index out of range")
        return _ffi_api.LayoutGetItem(self, index)

    def index_of(self, axis):
        """Get the index of an axis

        Parameters
        ----------
        axis : str
            The axis name, need to be [a-z,A-Z]

        Returns
        -------
        index : int
            The index of the axis, -1 if not found.
        """
        return _ffi_api.LayoutIndexOf(self, axis)

    def factor_of(self, axis):
        """Get the factor size of the subordinate axis.

        Parameters
        ----------
        axis : str
            The axis name, need to be [a-z,A-Z]

        Returns
        -------
        factor : int
            the size of the subordinate-axis of axis (if axis is a primal-axis),
            or the size of axis itself (if axis is a subordinate-axis).
            Return -1 if axis is not in the layout.
        """
        return _ffi_api.LayoutFactorOf(self, axis)


@tvm._ffi.register_object
class BijectiveLayout(Object):
    """Two-way mapping between a source layout and a destination layout.

    Converts both shapes and indices in either direction
    (src-layout -> dst-layout and back).

    Do not construct directly, use :any:`bijective_layout` instead.
    See the documentation of :any:`bijective_layout` for more details.

    Parameters
    ----------
    src_layout : str or Layout
        source layout.

    dst_layout : str or Layout
        destination layout.

    See Also
    --------
    bijective_layout : Declare a bijective layout
    """
    def forward_index(self, index):
        """Map indices expressed in the src-layout onto the dst-layout.

        Parameters
        ----------
        index: Array of Expr
            The indices in src-layout.

        Returns
        -------
        dst_index: Array of Expr
            The corresponding indices in dst-layout.
        """
        dst_index = _ffi_api.BijectiveLayoutForwardIndex(self, index)
        return dst_index

    def backward_index(self, index):
        """Map indices expressed in the dst-layout back onto the src-layout.

        Parameters
        ----------
        index: Array of Expr
            The indices in dst-layout.

        Returns
        -------
        src_index: Array of Expr
            The corresponding indices in src-layout.
        """
        src_index = _ffi_api.BijectiveLayoutBackwardIndex(self, index)
        return src_index

    def forward_shape(self, shape):
        """Map a shape expressed in the src-layout onto the dst-layout.

        Parameters
        ----------
        shape: Array of Expr
            The shape in src-layout.

        Returns
        -------
        dst_shape: Array of Expr
            The corresponding shape in dst-layout.
        """
        dst_shape = _ffi_api.BijectiveLayoutForwardShape(self, shape)
        return dst_shape

    def backward_shape(self, shape):
        """Map a shape expressed in the dst-layout back onto the src-layout.

        Parameters
        ----------
        shape: Array of Expr
            The shape in dst-layout.

        Returns
        -------
        src_shape: Array of Expr
            The corresponding shape in src-layout.
        """
        src_shape = _ffi_api.BijectiveLayoutBackwardShape(self, shape)
        return src_shape


def layout(layout_str):
    """Build a :class:`Layout` from its string representation.

    Parameters
    ----------
    layout_str : str
        A layout representation composed of uppercase letters, lowercase
        letters and numbers. An uppercase letter marks a primal axis; the
        matching lowercase letter, preceded by a factor size, marks the
        subordinate axis. For example, NCHW16c can describe a 5-D tensor of
        [batch_size, channel, height, width, channel_block], where the
        subordinate axis channel_block=16 is the factor size of the primal
        axis C (channel).

    Returns
    -------
    layout : Layout
        The created layout
    """
    created = _ffi_api.Layout(layout_str)
    return created


def bijective_layout(src_layout, dst_layout):
    """Build a bijective mapping between two layouts.

    String arguments are first parsed with :func:`layout`; any other value
    is forwarded unchanged.

    Parameters
    ----------
    src_layout : str or Layout
        source layout.

    dst_layout : str or Layout
        destination layout.

    Returns
    -------
    bijective_layout : BijectiveLayout
        The created bijective layout
    """
    def _as_layout(value):
        # Parse strings into Layout objects; pass everything else through untouched.
        return layout(value) if isinstance(value, str) else value

    # Arguments are evaluated left to right, so src is converted before dst.
    return _ffi_api.BijectiveLayout(_as_layout(src_layout), _as_layout(dst_layout))