# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""Tensor intrinsics"""
from __future__ import absolute_import as _abs
from . import _api_internal
from . import api as _api
from . import expr as _expr
from . import stmt as _stmt
from . import make as _make
from . import tensor as _tensor
from . import schedule as _schedule
from .build_module import current_build_config
from ._ffi.node import NodeBase, register_node


def _get_region(tslice):
    """Translate the indices of a tensor slice into a list of ranges.

    Parameters
    ----------
    tslice : TensorSlice
        The slice whose ``indices`` are converted, one range per index.

    Returns
    -------
    region : list
        One entry per index. A Python ``slice`` becomes
        ``Range(start, stop)``; any other index becomes a range created by
        ``range_by_min_extent`` beginning at that index with extent 1.
    """
    def _as_range(index):
        if isinstance(index, slice):
            # Strided slices are not supported.
            assert index.step is None
            return _api.Range(index.start, index.stop)
        # An IterVar is represented by its underlying variable.
        start = index.var if isinstance(index, _schedule.IterVar) else index
        return _make.range_by_min_extent(start, 1)

    return [_as_range(index) for index in tslice.indices]

@register_node
class TensorIntrin(NodeBase):
    """Tensor intrinsic functions for certain computation.

    Calling an instance splits the positional arguments into tensor slices
    and scalar inputs and forwards them to ``_TensorIntrinCall``.

    See Also
    --------
    decl_tensor_intrin: Construct a TensorIntrin
    """
    def __call__(self, *args, **kwargs):
        """Build a call to this intrinsic.

        Parameters
        ----------
        *args
            ``TensorSlice`` arguments contribute their tensor and region;
            every other argument is treated as a scalar input.

        **kwargs
            Only ``reduce_axis`` is read. It may be a single axis or a
            list/tuple of axes.
        """
        tensors = []
        regions = []
        scalar_inputs = []
        # Partition the positional arguments in a single pass.
        for arg in args:
            if isinstance(arg, _tensor.TensorSlice):
                tensors.append(arg.tensor)
                regions.append(_get_region(arg))
            else:
                scalar_inputs.append(arg)

        reduce_axis = []
        if "reduce_axis" in kwargs:
            axes = kwargs["reduce_axis"]
            # Accept a lone axis as shorthand for a one-element list.
            if not isinstance(axes, (list, tuple)):
                axes = [axes]
            reduce_axis = _api.convert(axes)

        # An empty scalar list is passed through unconverted.
        if scalar_inputs:
            scalar_inputs = _api.convert(scalar_inputs)

        return _api_internal._TensorIntrinCall(
            self, tensors, regions, reduce_axis, scalar_inputs)

def decl_tensor_intrin(op,
                       fcompute,
                       name="tensor_intrin",
                       binds=None, scalar_params=None):
    """Declare a tensor intrinsic function.

    Parameters
    ----------
    op: Operation
        The symbolic description of the intrinsic operation

    fcompute: lambda function of inputs, outputs-> stmt
        Specifies the IR statement to do the computation.
        See the following note for function signature of fcompute

        .. note::
             **Parameters**

             - **ins** (list of :any:`Buffer`) - Placeholder for each inputs
             - **outs** (list of :any:`Buffer`) - Placeholder for each outputs
             - **scalar_params** - Passed as a third positional argument,
               only when a non-empty ``scalar_params`` is given

             **Returns**

             - **stmt** (:any:`Stmt`, or tuple of three stmts)
             - If a single stmt is returned, it represents the body
             - If tuple of three stmts are returned they corresponds to body,
               reduce_init, reduce_update

    name: str, optional
        The name of the intrinsic.

    binds: dict of :any:`Tensor` to :any:`Buffer`, optional
        Dictionary that maps the Tensor to Buffer which specified the data layout
        requirement of the function. By default, a new compact buffer is created
        for each tensor in the argument.

    scalar_params: list of variables, optional
        Variables used by op, whose values will be passed as scalar_inputs
        when the tensor intrinsic is called.

    Returns
    -------
    intrin: TensorIntrin
        A TensorIntrin that can be used in tensorize schedule.

    Raises
    ------
    TypeError
        If ``op`` is not an Operation.

    ValueError
        If any input tensor of ``op`` is not produced by a PlaceholderOp.
    """
    if not isinstance(op, _tensor.Operation):
        raise TypeError("expect Operation")

    inputs = op.input_tensors
    binds = binds or {}

    # Inputs come first, followed by every output of the operation.
    tensors = list(inputs)
    tensors.extend(op.output(idx) for idx in range(op.num_outputs))

    if any(not isinstance(tensor.op, _tensor.PlaceholderOp) for tensor in inputs):
        raise ValueError("Do not yet support composition op")

    build_cfg = current_build_config()
    buffers = []
    for tensor in tensors:
        if tensor in binds:
            # Caller supplied an explicit buffer for this tensor.
            buffers.append(binds[tensor])
        else:
            buffers.append(
                _api.decl_buffer(tensor.shape, tensor.dtype, tensor.op.name,
                                 data_alignment=build_cfg.data_alignment,
                                 offset_factor=build_cfg.offset_factor))

    num_inputs = len(inputs)
    in_buffers = buffers[:num_inputs]
    out_buffers = buffers[num_inputs:]
    if scalar_params:
        body = fcompute(in_buffers, out_buffers, scalar_params)
    else:
        body = fcompute(in_buffers, out_buffers)
        scalar_params = []

    # Normalise a single expression/statement into a one-element sequence.
    if isinstance(body, (_expr.Expr, _stmt.Stmt)):
        body = [body]
    stmts = [_make.Evaluate(item) if isinstance(item, _expr.Expr) else item
             for item in body]
    # Pad with None up to three entries; a non-positive count adds nothing.
    stmts += [None] * (3 - len(stmts))

    return _api_internal._TensorIntrin(
        name, op, inputs, buffers, scalar_params, *stmts)