# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""TVM Runtime Object API."""
from __future__ import absolute_import as _abs
import numpy as _np

from tvm._ffi.object import Object, register_object, getitem_helper
from tvm import ndarray as _nd
from . import _vmobj


@register_object("vm.Tensor")
class Tensor(Object):
    """Tensor object.

    Parameters
    ----------
    arr : numpy.ndarray or tvm.nd.NDArray
        The source array.

    ctx :  TVMContext, optional
        The device context to create the array
    """
    def __init__(self, arr, ctx=None):
        if isinstance(arr, _np.ndarray):
            ctx = ctx if ctx else _nd.cpu(0)
            self.__init_handle_by_constructor__(
                _vmobj.Tensor, _nd.array(arr, ctx=ctx))
        elif isinstance(arr, _nd.NDArray):
            self.__init_handle_by_constructor__(
                _vmobj.Tensor, arr)
        else:
            raise RuntimeError("Unsupported type for tensor object.")

    @property
    def data(self):
        return _vmobj.GetTensorData(self)

    def asnumpy(self):
        """Convert data to numpy array

        Returns
        -------
        np_arr : numpy.ndarray
            The corresponding numpy array.
        """
        return self.data.asnumpy()


@register_object("vm.ADT")
class ADT(Object):
    """Algebatic data type(ADT) object.

    Parameters
    ----------
    tag : int
        The tag of ADT.

    fields : list[Object] or tuple[Object]
        The source tuple.
    """
    def __init__(self, tag, fields):
        for f in fields:
            assert isinstance(f, Object)
        self.__init_handle_by_constructor__(
            _vmobj.ADT, tag, *fields)

    @property
    def tag(self):
        return _vmobj.GetADTTag(self)

    def __getitem__(self, idx):
        return getitem_helper(
            self, _vmobj.GetADTFields, len(self), idx)

    def __len__(self):
        return _vmobj.GetADTNumberOfFields(self)


def tuple_object(fields):
    """Create a ADT object from source tuple.

    Parameters
    ----------
    fields : list[Object] or tuple[Object]
        The source tuple.

    Returns
    -------
    ret : ADT
        The created object.
    """
    for f in fields:
        assert isinstance(f, Object)
    return _vmobj.Tuple(*fields)