# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""Tensor class for computation declaration."""
# pylint: disable=invalid-name
import tvm._ffi

from tvm.runtime import Object, ObjectGeneric, convert_to_object
from tvm.tir import expr as _expr

from . import _ffi_api

class TensorSlice(ObjectGeneric, _expr.ExprOp):
    """Deferred element access created by ``tensor[...]`` indexing.

    Each ``[...]`` application appends its indices to the ones already
    collected, so ``A[i][j]`` and ``A[i, j]`` end up with the same index
    tuple. The access expression itself is only built by :py:meth:`asobject`.
    """

    def __init__(self, tensor, indices):
        self.tensor = tensor
        # A single (non-tuple) index is wrapped so ``indices`` is always a tuple.
        self.indices = indices if isinstance(indices, tuple) else (indices,)

    def __getitem__(self, indices):
        more = indices if isinstance(indices, tuple) else (indices,)
        return TensorSlice(self.tensor, self.indices + more)

    def asobject(self):
        """Materialize the slice by calling the tensor with the collected indices."""
        tensor, idx = self.tensor, self.indices
        return tensor(*idx)

    @property
    def dtype(self):
        """Data type of the underlying tensor."""
        return self.tensor.dtype

@tvm._ffi.register_object
class TensorIntrinCall(Object):
    """Intermediate structure used when calling a tensor intrinsic.

    No Python-side members are defined here; everything it exposes comes
    from :py:class:`Object`.
    """


@tvm._ffi.register_object
class Tensor(Object, _expr.ExprOp):
    """Tensor object. To construct one, see function.Tensor.

    A tensor exposes the operation that produces it (``op``) and which of
    that operation's outputs it is (``value_index``). Element access is
    written either as a call, ``A(i, j)``, or by slicing, ``A[i][j]``.
    """

    def __call__(self, *indices):
        """Build an element-access expression for this tensor.

        Parameters
        ----------
        *indices : PrimExpr or IterVar
            Exactly one index per dimension. Values are first passed through
            ``convert_to_object``; each converted value must be a ``PrimExpr``
            or an ``IterVar`` (an ``IterVar`` is indexed through its ``var``).

        Returns
        -------
        expr : Call
            A ``Call.Halide`` call that references ``self.op`` and
            ``self.value_index`` with the given indices.

        Raises
        ------
        ValueError
            If the number of indices differs from ``ndim``, or an index is
            neither an expression nor an ``IterVar``.
        """
        ndim = self.ndim
        if len(indices) != ndim:
            raise ValueError("Need to provide %d index in tensor but %d was provided"
                             % (ndim, len(indices)))
        indices = convert_to_object(indices)
        args = []
        for x in indices:
            if isinstance(x, _expr.PrimExpr):
                args.append(x)
            elif isinstance(x, _expr.IterVar):
                # An IterVar is indexed through its underlying loop variable.
                args.append(x.var)
            else:
                raise ValueError("The indices must be expression")

        return _expr.Call(self.dtype, self.op.name,
                          args, _expr.Call.Halide,
                          self.op, self.value_index)

    def __getitem__(self, indices):
        """Start slicing syntax; indices accumulate in a :py:class:`TensorSlice`."""
        return TensorSlice(self, indices)

    def __hash__(self):
        """Hash computed on the C++ side via ``TensorHash``."""
        return _ffi_api.TensorHash(self)

    def __eq__(self, other):
        """Compare with another tensor, or build an equality expression.

        A non-Tensor ``ExprOp`` operand yields an ``EqualOp`` expression; any
        other non-Tensor operand compares unequal. Two tensors are compared
        with ``TensorEqual`` on the C++ side.

        Raises
        ------
        ValueError
            If both operands are rank-0 tensors: there ``==`` could mean either
            content comparison or reference comparison, so the caller must
            choose explicitly.
        """
        if not isinstance(other, Tensor):
            if isinstance(other, _expr.ExprOp):
                return _expr.EqualOp(self, other)
            return False
        if self.ndim == 0 and other.ndim == 0:
            raise ValueError("Equal == comparison among rank-0 tensor is ambiguous, "
                             "use Tensor.equal for content expression equivalence, "
                             "use Tensor.same_as for exact reference comparison")
        return _ffi_api.TensorEqual(self, other)

    @property
    def ndim(self):
        """Dimension of the tensor (the length of ``shape``)."""
        return len(self.shape)

    @property
    def axis(self):
        """Axis of the tensor."""
        return self.__getattr__("axis")

    @property
    def op(self):
        """The corresponding :py:class:`Operation`."""
        return self.__getattr__("op")

    @property
    def value_index(self):
        """The output value index the tensor corresponds to."""
        return self.__getattr__("value_index")

    @property
    def shape(self):
        """The output shape of the tensor."""
        return self.__getattr__("shape")

    @property
    def name(self):
        """Readable name of the tensor.

        This is the producing operation's name when that operation has a
        single output. Otherwise it is ``"<op name>.v<value_index>"``, which
        tells the outputs apart.
        """
        op = self.op
        if op.num_outputs == 1:
            return op.name
        return "%s.v%d" % (op.name, self.value_index)


class Operation(Object):
    """Represent an operation that generates a tensor."""

    def output(self, index):
        """Get the index-th output of the operation.

        Parameters
        ----------
        index : int
            Position of the output to fetch. Presumably expected to lie in
            ``[0, num_outputs)``; range checking is left to the C++ side.

        Returns
        -------
        out : Tensor
            The index-th output.
        """
        return _ffi_api.OpGetOutput(self, index)

    @property
    def num_outputs(self):
        """Number of outputs from this op."""
        return _ffi_api.OpNumOutputs(self)

    @property
    def input_tensors(self):
        """List of input tensors to this op."""
        return _ffi_api.OpInputTensors(self)


@tvm._ffi.register_object
class PlaceholderOp(Operation):
    """Placeholder operation.

    Registered wrapper that adds no members beyond :py:class:`Operation`.
    """


@tvm._ffi.register_object
class BaseComputeOp(Operation):
    """Shared base of the compute-style operations (ComputeOp, TensorComputeOp)."""

    @property
    def axis(self):
        """IterVar axes of this compute operation, read from the underlying object."""
        # __getattr__ is called explicitly; getattr() would hit this property again.
        return self.__getattr__("axis")

    @property
    def reduce_axis(self):
        """IterVar axes that this compute operation reduces over."""
        return self.__getattr__("reduce_axis")


@tvm._ffi.register_object
class ComputeOp(BaseComputeOp):
    """Scalar compute operation.

    Inherits ``axis`` and ``reduce_axis`` from :py:class:`BaseComputeOp`.
    """


@tvm._ffi.register_object
class TensorComputeOp(BaseComputeOp):
    """Tensor compute operation.

    Inherits ``axis`` and ``reduce_axis`` from :py:class:`BaseComputeOp`.
    """


@tvm._ffi.register_object
class ScanOp(Operation):
    """Scan operation."""

    @property
    def scan_axis(self):
        """The axis this scan iterates along, read from the underlying object."""
        # Explicit __getattr__ avoids recursing into this property.
        return self.__getattr__("scan_axis")


@tvm._ffi.register_object
class ExternOp(Operation):
    """External operation.

    Registered wrapper that adds no members beyond :py:class:`Operation`.
    """


@tvm._ffi.register_object
class HybridOp(Operation):
    """Hybrid operation."""

    @property
    def axis(self):
        """IterVar axes of this hybrid operation, read from the underlying object."""
        # Explicit __getattr__ avoids recursing into this property.
        return self.__getattr__("axis")