/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

use std::{convert::TryFrom, mem, os::raw::c_void, ptr, slice};

use failure::Error;
use ndarray;
use tvm_common::{
    array::{DataType, TVMContext},
    ffi::{
        DLContext, DLDataType, DLDataTypeCode_kDLFloat, DLDataTypeCode_kDLInt,
        DLDataTypeCode_kDLUInt, DLTensor,
    },
};

use crate::allocator::Allocation;

/// A `Storage` is a container which holds `Tensor` data.
#[derive(PartialEq)]
pub enum Storage<'a> {
    /// A `Storage` which owns its contained bytes.
    Owned(Allocation),

    /// A borrowed view of bytes owned elsewhere: either another `Storage` (via
    /// `Storage::view`) or a caller-provided slice (via `From<&[T]>`).
    ///
    /// Fields: the viewed bytes, and their alignment in bytes.
    View(&'a mut [u8], usize), // bytes, alignment (bytes)
}

impl<'a> Storage<'a> {
    pub fn new(size: usize, align: Option<usize>) -> Result<Storage<'static>, Error> {
        Ok(Storage::Owned(Allocation::new(size, align)?))
    }

    pub fn as_mut_ptr(&self) -> *mut u8 {
        match self {
            Storage::Owned(alloc) => alloc.as_mut_ptr(),
            Storage::View(slice, _) => slice.as_ptr() as *mut u8,
        }
    }

    pub fn size(&self) -> usize {
        match self {
            Storage::Owned(alloc) => alloc.size(),
            Storage::View(slice, _) => slice.len(),
        }
    }

    pub fn align(&self) -> usize {
        match self {
            Storage::Owned(alloc) => alloc.align(),
            Storage::View(_, align) => *align,
        }
    }

    pub fn as_ptr(&self) -> *const u8 {
        self.as_mut_ptr() as *const _
    }

    /// Returns a `Storage::View` which points to an owned `Storage::Owned`.
    pub fn view(&self) -> Storage<'a> {
        match self {
            Storage::Owned(alloc) => Storage::View(
                unsafe { slice::from_raw_parts_mut(alloc.as_mut_ptr(), self.size()) },
                self.align(),
            ),
            Storage::View(slice, _) => Storage::View(
                unsafe { slice::from_raw_parts_mut(self.as_mut_ptr(), slice.len()) },
                self.align(),
            ),
        }
    }

    pub fn is_owned(&self) -> bool {
        match self {
            Storage::Owned(_) => true,
            _ => false,
        }
    }

    /// Returns an owned version of this storage via cloning.
    pub fn to_owned(&self) -> Storage<'static> {
        let s = Storage::new(self.size(), Some(self.align())).unwrap();
        unsafe {
            s.as_mut_ptr()
                .copy_from_nonoverlapping(self.as_ptr(), self.size());
        }
        s
    }
}

impl<'d, 's, T> From<&'d [T]> for Storage<'s> {
    /// Reinterprets `data` as a `Storage::View` of its raw bytes, aligned like `T`.
    /// No bytes are copied.
    ///
    /// NOTE(review): the view's lifetime `'s` is not tied to `'d`, and a `&mut [u8]` is
    /// derived from a shared slice. The caller must keep `data` alive for as long as the view
    /// is used, and writing through the view is only sound if the caller really has
    /// exclusive access to that memory.
    fn from(data: &'d [T]) -> Self {
        let byte_len = mem::size_of_val(data);
        let bytes = unsafe { slice::from_raw_parts_mut(data.as_ptr() as *mut u8, byte_len) };
        Storage::View(bytes, mem::align_of::<T>())
    }
}

/// An n-dimensional array type which can be converted to/from `tvm::DLTensor` and `ndarray::Array`.
/// `Tensor` is primarily a holder of data which can be operated on via TVM (via `DLTensor`) or
/// converted to `ndarray::Array` for non-TVM processing.
///
/// # Examples
///
/// ```ignore
/// extern crate ndarray;
///
/// let a_nd: ndarray::Array1<f32> = ndarray::Array::from_vec(vec![1f32, 2., 3., 4.]);
/// let mut a: Tensor = a_nd.into();
/// let mut a_dl: DLTensor = (&mut a).into();
/// call_packed!(tvm_fn, &mut a_dl);
///
/// // Tensor -> Array is mostly useful when post-processing TVM graph outputs.
/// let a_nd = ndarray::ArrayD::<f32>::try_from(&a).unwrap();
/// ```
#[derive(PartialEq)]
pub struct Tensor<'a> {
    /// The bytes which contain the data this `Tensor` represents.
    pub(crate) data: Storage<'a>,
    /// The TVM device context (converted to a `DLContext` when passed to TVM).
    pub(crate) ctx: TVMContext,
    /// The element type of the data.
    pub(crate) dtype: DataType,
    /// The extent of each dimension.
    ///
    /// `i64` rather than `usize` because of `typedef int64_t tvm_index_t` in c_runtime_api.h.
    pub(crate) shape: Vec<i64>,
    /// The `Tensor` strides. Can be `None` if the `Tensor` is contiguous.
    pub(crate) strides: Option<Vec<usize>>,
    /// Offset, in bytes, from the start of `data` to the first element.
    pub(crate) byte_offset: isize,
    /// The number of elements in the `Tensor`.
    pub(crate) size: usize,
}

// SAFETY (review): `Tensor` is not automatically `Send`, presumably because
// `Storage::Owned` wraps an `Allocation` holding a raw pointer. Confirm this against the
// allocator. Sending a `Tensor` is only sound if no other `Tensor`/`Storage::View` aliasing
// the same bytes is used concurrently on another thread; `Storage::view` can create such
// aliases.
unsafe impl Send for Tensor<'_> {}

impl<'a> Tensor<'a> {
    /// Returns a copy of this `Tensor`'s shape (one extent per dimension).
    pub fn shape(&self) -> Vec<i64> {
        self.shape.clone()
    }

    /// Returns the data of this `Tensor` as a `Vec`.
    ///
    /// Elements are read starting `byte_offset` bytes into the underlying storage. This is
    /// the same addressing used by `copy` and `as_dltensor`.
    ///
    /// # Panics
    ///
    /// Panics if the `Tensor` is not contiguous or does not contain elements of type `T`.
    pub fn to_vec<T: 'static + std::fmt::Debug + Clone>(&self) -> Vec<T> {
        assert!(self.is_contiguous());
        assert!(self.dtype.is_type::<T>());
        unsafe {
            // The first element lives `byte_offset` bytes past the storage base. Reading
            // from the base itself would return the wrong elements whenever the offset is
            // non-zero.
            let first = self.data.as_ptr().offset(self.byte_offset) as *const T;
            slice::from_raw_parts(first, self.size).to_vec()
        }
    }

    /// Returns `true` iff this `Tensor` is represented by a contiguous region of memory.
    pub fn is_contiguous(&self) -> bool {
        match self.strides {
            None => true,
            Some(ref strides) => {
                // check that stride for each dimension is the
                // product of all trailing dimensions' shapes
                self.shape
                    .iter()
                    .zip(strides)
                    .rfold(
                        (true, 1),
                        |(is_contig, expected_stride), (shape, stride)| {
                            (
                                is_contig && *stride == expected_stride,
                                expected_stride * (*shape as usize),
                            )
                        },
                    )
                    .0
            }
        }
    }

    /// Copies the elements of `other` into this `Tensor`'s memory.
    ///
    /// Both tensors are addressed at their own `byte_offset`, and
    /// `other.size * other.dtype.itemsize()` bytes are copied. The two regions must not
    /// overlap (the copy uses `copy_from_nonoverlapping`).
    ///
    /// # Panics
    ///
    /// Panics if the tensors differ in dtype or element count, or if either is not contiguous.
    pub fn copy(&mut self, other: &Tensor) {
        assert!(
            self.dtype == other.dtype && self.size == other.size,
            "Tensor shape/dtype mismatch."
        );
        assert!(
            self.is_contiguous() && other.is_contiguous(),
            "copy currently requires contiguous tensors\n`self.strides = {:?}` `other.strides = {:?}`",
            self.strides,
            other.strides
        );
        unsafe {
            self.data
                .as_mut_ptr()
                .offset(self.byte_offset)
                .copy_from_nonoverlapping(
                    other.data.as_ptr().offset(other.byte_offset),
                    other.size * other.dtype.itemsize(),
                );
        }
    }

    /// Returns an owned version of this `Tensor` via cloning.
    ///
    /// The whole underlying storage is copied, and `byte_offset` is preserved, so the new
    /// `Tensor` addresses the same elements as `self`. Strides are kept only when `self` is
    /// non-contiguous; contiguous tensors get `None`.
    pub fn to_owned(&self) -> Tensor<'static> {
        Tensor {
            data: self.data.to_owned(),
            ctx: self.ctx.clone(),
            dtype: self.dtype.clone(),
            size: self.size,
            shape: self.shape.clone(),
            // Dropping the strides of a non-contiguous tensor would make the copy be read
            // as contiguous, i.e. with the wrong element layout.
            strides: if self.is_contiguous() {
                None
            } else {
                self.strides.clone()
            },
            // The copied storage has the same layout as the original, so the first element
            // is still `byte_offset` bytes in.
            byte_offset: self.byte_offset,
        }
    }

    /// Builds a `Tensor` describing `arr` and backed by `storage`.
    ///
    /// `type_code` becomes the dtype code; the bit width is `8 * size_of::<T>()` with one
    /// lane. Strides are taken from `arr.strides()`, and `byte_offset` is 0.
    fn from_array_storage<'s, T, D: ndarray::Dimension>(
        arr: &ndarray::Array<T, D>,
        storage: Storage<'s>,
        type_code: usize,
    ) -> Tensor<'s> {
        let type_width = mem::size_of::<T>();
        Tensor {
            data: storage,
            ctx: TVMContext::default(),
            dtype: DataType {
                code: type_code,
                bits: 8 * type_width,
                lanes: 1,
            },
            size: arr.len(),
            shape: arr.shape().iter().map(|&v| v as i64).collect(),
            strides: Some(arr.strides().iter().map(|&v| v as usize).collect()),
            byte_offset: 0,
        }
    }

    /// Builds a `DLTensor` that borrows this `Tensor`'s data, shape and strides.
    ///
    /// The result holds raw pointers into `self` and must not outlive it. Its `data` pointer
    /// already includes `byte_offset`, so its own `byte_offset` is 0. With `flatten` set, the
    /// tensor is exposed as 1-D with `size` elements.
    ///
    /// # Panics
    ///
    /// Panics if `flatten` is requested for a non-contiguous `Tensor`.
    pub(crate) fn as_dltensor(&self, flatten: bool) -> DLTensor {
        assert!(!flatten || self.is_contiguous());
        let shape = if flatten {
            // NOTE(review): `size` is a `usize` exposed as a single `i64` extent; this is only
            // layout-compatible where `usize` is 64 bits wide. Confirm for 32-bit targets.
            &self.size as *const usize as *mut i64
        } else {
            self.shape.as_ptr() as *mut i64
        };
        let strides = if flatten || self.is_contiguous() {
            // Contiguous tensors are passed to TVM without explicit strides.
            ptr::null_mut()
        } else {
            // `is_contiguous` returns true for `None`, so strides are present here.
            // NOTE(review): `Vec<usize>` is reinterpreted as `i64`; same 32-bit caveat as shape.
            self.strides.as_ref().unwrap().as_ptr() as *mut i64
        };
        DLTensor {
            data: unsafe { self.data.as_mut_ptr().offset(self.byte_offset) } as *mut c_void,
            ctx: DLContext::from(&self.ctx),
            ndim: if flatten { 1 } else { self.shape.len() } as i32,
            dtype: DLDataType::from(&self.dtype),
            shape,
            strides,
            byte_offset: 0,
        }
    }
}

/// Conversions to `ndarray::Array` from `Tensor`, if the types match.
///
/// The generated `try_from` returns an error when the `Tensor`'s dtype is not `$dtype`, or
/// when the shape does not fit the element count. It panics (via `Tensor::to_vec`) if the
/// `Tensor` is not contiguous.
macro_rules! impl_ndarray_try_from_tensor {
    ($type:ty, $dtype:expr) => {
        impl<'a, 't> TryFrom<&'a Tensor<'t>> for ndarray::ArrayD<$type> {
            type Error = Error;
            fn try_from(tensor: &'a Tensor) -> Result<ndarray::ArrayD<$type>, Error> {
                ensure!(
                    tensor.dtype == $dtype,
                    "Cannot convert Tensor with dtype {:?} to ndarray",
                    tensor.dtype
                );
                let dims: Vec<usize> = tensor
                    .shape
                    .iter()
                    .map(|&extent| extent as usize)
                    .collect();
                let elements = tensor.to_vec::<$type>();
                let array = ndarray::Array::from_shape_vec(dims, elements)?;
                Ok(array)
            }
        }
    };
}

/// Defines a public `DataType` constant named `$const_name`.
///
/// `$type_code` must name a DLPack type-code constant (cast to `usize`); `$bit_width` and
/// `$lane_count` fill the `bits` and `lanes` fields.
macro_rules! make_dtype_const {
    ($const_name:ident, $type_code:ident, $bit_width:expr, $lane_count:expr) => {
        pub const $const_name: DataType = DataType {
            lanes: $lane_count,
            bits: $bit_width,
            code: $type_code as usize,
        };
    };
}

// Scalar (single-lane) dtypes for which Tensor -> ndarray conversions are provided. These
// mirror the element types accepted by the ndarray -> Tensor conversions.
make_dtype_const!(DTYPE_INT32, DLDataTypeCode_kDLInt, 32, 1);
make_dtype_const!(DTYPE_INT64, DLDataTypeCode_kDLInt, 64, 1);
make_dtype_const!(DTYPE_UINT32, DLDataTypeCode_kDLUInt, 32, 1);
make_dtype_const!(DTYPE_UINT64, DLDataTypeCode_kDLUInt, 64, 1);
// make_dtype_const!(DTYPE_FLOAT16, DLDataTypeCode_kDLFloat, 16, 1);
make_dtype_const!(DTYPE_FLOAT32, DLDataTypeCode_kDLFloat, 32, 1);
make_dtype_const!(DTYPE_FLOAT64, DLDataTypeCode_kDLFloat, 64, 1);
impl_ndarray_try_from_tensor!(i32, DTYPE_INT32);
impl_ndarray_try_from_tensor!(i64, DTYPE_INT64);
impl_ndarray_try_from_tensor!(u32, DTYPE_UINT32);
impl_ndarray_try_from_tensor!(u64, DTYPE_UINT64);
impl_ndarray_try_from_tensor!(f32, DTYPE_FLOAT32);
impl_ndarray_try_from_tensor!(f64, DTYPE_FLOAT64);

impl<'a, 't> From<&'a Tensor<'t>> for DLTensor {
    fn from(tensor: &'a Tensor<'t>) -> Self {
        Tensor::as_dltensor(tensor, false /* flatten */)
    }
}

impl<'a, 't> From<&'a mut Tensor<'t>> for DLTensor {
    fn from(tensor: &'a mut Tensor<'t>) -> Self {
        Tensor::as_dltensor(tensor, false /* flatten */)
    }
}

impl<'a> From<DLTensor> for Tensor<'a> {
    /// Wraps the memory described by `dlt` in a borrowing `Tensor`, without copying.
    ///
    /// The storage view starts at `dlt.data + dlt.byte_offset` and covers exactly
    /// `size * itemsize` bytes, so the resulting `Tensor` has `byte_offset == 0`. The
    /// `DLTensor` produced by `as_dltensor` therefore points at the same first element as
    /// `dlt`. The device context is not read from `dlt`; it is always `TVMContext::default()`.
    ///
    /// NOTE(review): the lifetime `'a` is not tied to the memory `dlt` points to. The caller
    /// must keep that memory alive for as long as the returned `Tensor` is used.
    fn from(dlt: DLTensor) -> Self {
        unsafe {
            let dtype = DataType::from(dlt.dtype);
            let ndim = dlt.ndim as usize;
            // `from_raw_parts` requires a non-null pointer even for an empty slice, and a 0-d
            // tensor's `shape` may be null, so skip the read entirely when `ndim == 0`.
            let shape: Vec<i64> = if ndim == 0 {
                Vec::new()
            } else {
                slice::from_raw_parts(dlt.shape, ndim).to_vec()
            };
            let size = shape.iter().map(|&v| v as usize).product::<usize>();
            // Fold `byte_offset` into the view so that the view covers exactly the tensor's
            // elements. Every consumer of the storage then sees the right bytes.
            let first_byte = (dlt.data as *const u8).add(dlt.byte_offset as usize);
            let storage = Storage::from(slice::from_raw_parts(first_byte, dtype.itemsize() * size));
            // There is one stride per dimension (`ndim`), not one per element. Each `i64`
            // value is converted rather than reinterpreting the buffer as `usize`.
            let strides = if dlt.strides.is_null() {
                None
            } else {
                Some(
                    slice::from_raw_parts(dlt.strides, ndim)
                        .iter()
                        .map(|&s| s as usize)
                        .collect(),
                )
            };
            Self {
                data: storage,
                ctx: TVMContext::default(),
                dtype,
                size,
                shape,
                strides,
                byte_offset: 0,
            }
        }
    }
}

/// `From` conversions to `Tensor` for owned or borrowed `ndarray::Array`.
///
/// The owned conversion copies the elements into a newly allocated, owned `Storage`. The
/// borrowed conversion produces a `Tensor` that views the array's buffer in place.
///
/// # Panics
///
/// Panics if the ndarray is not contiguous.
macro_rules! impl_tensor_from_ndarray {
    ($type:ty, $typecode:expr) => {
        impl<D: ndarray::Dimension> From<ndarray::Array<$type, D>> for Tensor<'static> {
            fn from(arr: ndarray::Array<$type, D>) -> Self {
                let elems = arr.as_slice().expect("NDArray must be contiguous");
                let owned = Storage::from(elems).to_owned();
                Tensor::from_array_storage(&arr, owned, $typecode as usize)
            }
        }
        impl<'a, D: ndarray::Dimension> From<&'a ndarray::Array<$type, D>> for Tensor<'a> {
            fn from(arr: &'a ndarray::Array<$type, D>) -> Self {
                let elems = arr.as_slice().expect("NDArray must be contiguous");
                Tensor::from_array_storage(arr, Storage::from(elems), $typecode as usize)
            }
        }
    };
}

// ndarray -> Tensor conversions, grouped by DLPack type code.
impl_tensor_from_ndarray!(i32, DLDataTypeCode_kDLInt);
impl_tensor_from_ndarray!(i64, DLDataTypeCode_kDLInt);
impl_tensor_from_ndarray!(u32, DLDataTypeCode_kDLUInt);
impl_tensor_from_ndarray!(u64, DLDataTypeCode_kDLUInt);
impl_tensor_from_ndarray!(f32, DLDataTypeCode_kDLFloat);
impl_tensor_from_ndarray!(f64, DLDataTypeCode_kDLFloat);