//! This module provides the the wrapped `TVMValue`, `TVMArgValue` and `TVMRetValue`
//! required for using TVM functions.

use std::{
    any::Any,
    convert::TryFrom,
    ffi::{CStr, CString},
    fmt::{self, Debug, Formatter},
    marker::PhantomData,
    mem,
    ops::Deref,
    os::raw::{c_char, c_void},
};

#[cfg(feature = "runtime")]
use ffi::runtime::TVMValue as _TVMValue;

#[cfg(feature = "frontend")]
use ffi::ts::TVMValue as _TVMValue;

use errors::*;

use ty::TVMTypeCode;

/// Wrapped TVMValue type.
#[derive(Clone, Copy)]
pub struct TVMValue {
    pub inner: _TVMValue,
}

impl TVMValue {
    /// Creates TVMValue from the raw part.
    pub fn new(inner: _TVMValue) -> Self {
        TVMValue { inner }
    }

    pub(crate) fn into_raw(self) -> _TVMValue {
        self.inner
    }
}

impl Debug for TVMValue {
    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
        unsafe {
            write!(
                f,
                "TVMValue: [v_int64: {:?}], [v_float64: {:?}], [v_handle: {:?}],\
                 [v_str: {:?}]",
                self.inner.v_int64, self.inner.v_float64, self.inner.v_handle, self.inner.v_str
            )
        }
    }
}

impl Deref for TVMValue {
    type Target = _TVMValue;
    fn deref(&self) -> &Self::Target {
        &self.inner
    }
}

macro_rules! impl_prim_val {
    ($type:ty, $field:ident, $cast:ty) => {
        impl From<$type> for TVMValue {
            fn from(arg: $type) -> Self {
                let inner = _TVMValue {
                    $field: arg as $cast,
                };
                Self::new(inner)
            }
        }

        impl<'a> From<&'a $type> for TVMValue {
            fn from(arg: &$type) -> Self {
                let inner = _TVMValue {
                    $field: *arg as $cast,
                };
                Self::new(inner)
            }
        }

        impl<'a> From<&'a mut $type> for TVMValue {
            fn from(arg: &mut $type) -> Self {
                let inner = _TVMValue {
                    $field: *arg as $cast,
                };
                Self::new(inner)
            }
        }

        impl TryFrom<TVMValue> for $type {
            type Error = Error;
            fn try_from(val: TVMValue) -> Result<Self> {
                Ok(unsafe { val.inner.$field as $type })
            }
        }

        impl<'a> TryFrom<&'a TVMValue> for $type {
            type Error = Error;
            fn try_from(val: &TVMValue) -> Result<Self> {
                Ok(unsafe { val.into_raw().$field as $type })
            }
        }

        impl<'a> TryFrom<&'a mut TVMValue> for $type {
            type Error = Error;
            fn try_from(val: &mut TVMValue) -> Result<Self> {
                Ok(unsafe { val.into_raw().$field as $type })
            }
        }
    };
}

impl_prim_val!(isize, v_int64, i64);
impl_prim_val!(i64, v_int64, i64);
impl_prim_val!(i32, v_int64, i64);
impl_prim_val!(i16, v_int64, i64);
impl_prim_val!(i8, v_int64, i64);
impl_prim_val!(usize, v_int64, i64);
impl_prim_val!(u64, v_int64, i64);
impl_prim_val!(u32, v_int64, i64);
impl_prim_val!(u16, v_int64, i64);
impl_prim_val!(u8, v_int64, i64);

impl_prim_val!(f64, v_float64, f64);
impl_prim_val!(f32, v_float64, f64);

impl<'a> From<&'a str> for TVMValue {
    fn from(arg: &str) -> TVMValue {
        let arg = CString::new(arg).unwrap();
        let inner = _TVMValue {
            v_str: arg.as_ptr() as *const c_char,
        };
        mem::forget(arg);
        Self::new(inner)
    }
}

impl<'a> From<&'a String> for TVMValue {
    fn from(arg: &String) -> TVMValue {
        let arg = CString::new(arg.as_bytes()).unwrap();
        let inner = _TVMValue {
            v_str: arg.as_ptr() as *const c_char,
        };
        mem::forget(arg);
        Self::new(inner)
    }
}

impl<'a> From<&'a CString> for TVMValue {
    fn from(arg: &CString) -> TVMValue {
        let arg = arg.to_owned();
        let inner = _TVMValue {
            v_str: arg.as_ptr() as *const c_char,
        };
        mem::forget(arg);
        Self::new(inner)
    }
}

impl<'a> From<&'a [u8]> for TVMValue {
    fn from(arg: &[u8]) -> TVMValue {
        let arg = arg.to_owned();
        let inner = _TVMValue {
            v_handle: &arg as *const _ as *mut c_void,
        };
        mem::forget(arg);
        Self::new(inner)
    }
}

/// Captures both `TVMValue` and `TVMTypeCode` needed for TVM function.
/// The preferred way to obtain a `TVMArgValue` is automatically via `call_packed!`.
/// or in the frontend crate, with `function::Builder`. Checkout the methods for conversions.
///
/// ## Example
///
/// ```
/// let s = "hello".to_string();
/// let arg = TVMArgValue::from(&s);
/// let tvm: String = arg.try_into().unwrap();
/// assert_eq!(arg, s);
/// ```
#[derive(Debug, Clone, Copy)]
pub struct TVMArgValue<'a> {
    /// The wrapped TVMValue
    pub value: TVMValue,
    /// The matching type code.
    pub type_code: TVMTypeCode,
    /// This is only exposed to runtime and frontend crates and is not meant to be used directly.
    pub lifetime: PhantomData<&'a ()>,
}

impl<'a> TVMArgValue<'a> {
    pub fn new(value: TVMValue, type_code: TVMTypeCode) -> Self {
        TVMArgValue {
            value: value,
            type_code: type_code,
            lifetime: PhantomData,
        }
    }
}

impl<'a, 'b> TryFrom<&'b TVMArgValue<'a>> for i64 {
    type Error = Error;
    fn try_from(arg: &TVMArgValue<'a>) -> Result<Self> {
        if (arg.type_code == TVMTypeCode::kDLInt)
            | (arg.type_code == TVMTypeCode::kDLUInt)
            | (arg.type_code == TVMTypeCode::kNull)
        {
            Ok(unsafe { arg.value.inner.v_int64 })
        } else {
            bail!(ErrorKind::TryFromTVMArgValueError(
                stringify!(i64).to_string(),
                arg.type_code.to_string()
            ))
        }
    }
}

impl<'a, 'b> TryFrom<&'b TVMArgValue<'a>> for f64 {
    type Error = Error;
    fn try_from(arg: &TVMArgValue<'a>) -> Result<Self> {
        if arg.type_code == TVMTypeCode::kDLFloat {
            Ok(unsafe { arg.value.inner.v_float64 })
        } else {
            bail!(ErrorKind::TryFromTVMArgValueError(
                stringify!(f64).to_string(),
                arg.type_code.to_string()
            ))
        }
    }
}

impl<'a, 'b> TryFrom<&'b TVMArgValue<'a>> for String {
    type Error = Error;
    fn try_from(arg: &TVMArgValue<'a>) -> Result<Self> {
        if arg.type_code == TVMTypeCode::kStr {
            let ret_str = unsafe {
                match CStr::from_ptr(arg.value.inner.v_str).to_str() {
                    Ok(s) => s,
                    Err(_) => "Invalid UTF-8 message",
                }
            };
            Ok(ret_str.to_string())
        } else {
            bail!(ErrorKind::TryFromTVMArgValueError(
                stringify!(String).to_string(),
                arg.type_code.to_string()
            ))
        }
    }
}

/// Main way to create a TVMArgValue from suported Rust values.
impl<'b, 'a: 'b, T: 'b + ?Sized> From<&'b T> for TVMArgValue<'a>
where
    TVMValue: From<&'b T>,
    TVMTypeCode: From<&'b T>,
{
    fn from(arg: &'b T) -> Self {
        TVMArgValue::new(TVMValue::from(arg), TVMTypeCode::from(arg))
    }
}

/// Creates a conversion to a `TVMArgValue` for an object handle.
impl<'a, T> From<*const T> for TVMArgValue<'a> {
    fn from(ptr: *const T) -> Self {
        let value = TVMValue::new(_TVMValue {
            v_handle: ptr as *mut T as *mut c_void,
        });

        TVMArgValue::new(value, TVMTypeCode::kArrayHandle)
    }
}

/// Creates a conversion to a `TVMArgValue` for a mutable object handle.
impl<'a, T> From<*mut T> for TVMArgValue<'a> {
    fn from(ptr: *mut T) -> Self {
        let value = TVMValue::new(_TVMValue {
            v_handle: ptr as *mut c_void,
        });

        TVMArgValue::new(value, TVMTypeCode::kHandle)
    }
}

/// An owned version of TVMPODValue. It can be converted from varieties of
/// primitive and object types.
/// It can be downcasted using `try_from` if it contains the desired type.
///
/// # Example
///
/// ```
/// let a = 42u32;
/// let b: i64 = TVMRetValue::from(a).try_into().unwrap();
///
/// let s = "hello, world!";
/// let t: TVMRetValue = s.into();
/// assert_eq!(String::try_from(t).unwrap(), s);
/// ```
pub struct TVMRetValue {
    /// A primitive return value, if any.
    pub prim_value: usize,
    /// An object return value, if any.
    pub box_value: Box<Any>,
    pub type_code: TVMTypeCode,
}

impl TVMRetValue {
    fn new(prim_value: usize, box_value: Box<Any>, type_code: TVMTypeCode) -> Self {
        Self {
            prim_value,
            box_value,
            type_code,
        }
    }

    /// unsafe function to create `TVMRetValue` from `TVMValue` and
    /// its matching `TVMTypeCode`.
    pub unsafe fn from_tvm_value(value: TVMValue, type_code: TVMTypeCode) -> Self {
        let value = value.into_raw();
        match type_code {
            TVMTypeCode::kDLInt | TVMTypeCode::kDLUInt => {
                Self::new(value.v_int64 as usize, box (), type_code)
            }
            TVMTypeCode::kDLFloat => Self::new(value.v_float64 as usize, box (), type_code),
            TVMTypeCode::kHandle
            | TVMTypeCode::kArrayHandle
            | TVMTypeCode::kNodeHandle
            | TVMTypeCode::kModuleHandle
            | TVMTypeCode::kFuncHandle => {
                Self::new(value.v_handle as usize, box value.v_handle, type_code)
            }
            TVMTypeCode::kStr | TVMTypeCode::kBytes => {
                Self::new(value.v_str as usize, box (value.v_str), type_code)
            }
            _ => Self::new(0usize, box (), type_code),
        }
    }

    /// Returns the underlying `TVMValue` and `TVMTypeCode`.
    pub fn into_tvm_value(self) -> (TVMValue, TVMTypeCode) {
        let val = match self.type_code {
            TVMTypeCode::kDLInt | TVMTypeCode::kDLUInt => TVMValue::new(_TVMValue {
                v_int64: self.prim_value as i64,
            }),
            TVMTypeCode::kDLFloat => TVMValue::new(_TVMValue {
                v_float64: self.prim_value as f64,
            }),
            TVMTypeCode::kHandle
            | TVMTypeCode::kArrayHandle
            | TVMTypeCode::kNodeHandle
            | TVMTypeCode::kModuleHandle
            | TVMTypeCode::kFuncHandle
            | TVMTypeCode::kNDArrayContainer => TVMValue::new(_TVMValue {
                v_handle: self.prim_value as *const c_void as *mut c_void,
            }),
            TVMTypeCode::kStr | TVMTypeCode::kBytes => TVMValue::new(_TVMValue {
                v_str: self.prim_value as *const c_char,
            }),
            _ => unreachable!(),
        };
        (val, self.type_code)
    }
}

impl Default for TVMRetValue {
    fn default() -> Self {
        TVMRetValue {
            prim_value: 0usize,
            box_value: box (),
            type_code: TVMTypeCode::default(),
        }
    }
}

impl Clone for TVMRetValue {
    fn clone(&self) -> Self {
        match self.type_code {
            TVMTypeCode::kDLInt | TVMTypeCode::kDLUInt | TVMTypeCode::kDLFloat => {
                Self::new(self.prim_value.clone(), box (), self.type_code.clone())
            }
            TVMTypeCode::kHandle
            | TVMTypeCode::kArrayHandle
            | TVMTypeCode::kNodeHandle
            | TVMTypeCode::kModuleHandle
            | TVMTypeCode::kFuncHandle
            | TVMTypeCode::kNDArrayContainer => Self::new(
                self.prim_value.clone(),
                box (self.prim_value.clone() as *const c_void as *mut c_void),
                self.type_code.clone(),
            ),
            TVMTypeCode::kStr | TVMTypeCode::kBytes => Self::new(
                self.prim_value.clone(),
                box (self.prim_value.clone() as *const c_char),
                self.type_code.clone(),
            ),
            _ => unreachable!(),
        }
    }
}

impl Debug for TVMRetValue {
    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
        write!(
            f,
            "prim_value: {:?}, box_value: {:?}, type_code: {:?}",
            self.prim_value, self.prim_value as *const c_void as *mut c_void, self.type_code
        )
    }
}

macro_rules! impl_prim_ret_value {
    ($type:ty, $code:expr) => {
        impl From<$type> for TVMRetValue {
            fn from(val: $type) -> Self {
                TVMRetValue {
                    prim_value: val as usize,
                    box_value: box (),
                    type_code: $code,
                }
            }
        }

        impl<'a> From<&'a $type> for TVMRetValue {
            fn from(val: &$type) -> Self {
                TVMRetValue {
                    prim_value: *val as usize,
                    box_value: box (),
                    type_code: $code,
                }
            }
        }

        impl<'a> From<&'a mut $type> for TVMRetValue {
            fn from(val: &mut $type) -> Self {
                TVMRetValue {
                    prim_value: *val as usize,
                    box_value: box (),
                    type_code: $code,
                }
            }
        }

        impl TryFrom<TVMRetValue> for $type {
            type Error = Error;
            fn try_from(ret: TVMRetValue) -> Result<$type> {
                if ret.type_code == $code {
                    Ok(ret.prim_value as $type)
                } else {
                    bail!(ErrorKind::TryFromTVMRetValueError(
                        stringify!($type).to_string(),
                        ret.type_code.to_string(),
                    ))
                }
            }
        }
    };
}

impl_prim_ret_value!(i8, TVMTypeCode::kDLInt);
impl_prim_ret_value!(i16, TVMTypeCode::kDLInt);
impl_prim_ret_value!(i32, TVMTypeCode::kDLInt);
impl_prim_ret_value!(i64, TVMTypeCode::kDLInt);
impl_prim_ret_value!(isize, TVMTypeCode::kDLInt);

impl_prim_ret_value!(u8, TVMTypeCode::kDLUInt);
impl_prim_ret_value!(u16, TVMTypeCode::kDLUInt);
impl_prim_ret_value!(u32, TVMTypeCode::kDLUInt);
impl_prim_ret_value!(u64, TVMTypeCode::kDLUInt);
impl_prim_ret_value!(usize, TVMTypeCode::kDLUInt);

impl_prim_ret_value!(f32, TVMTypeCode::kDLFloat);
impl_prim_ret_value!(f64, TVMTypeCode::kDLFloat);

macro_rules! impl_ptr_ret_value {
    ($type:ty) => {
        impl From<$type> for TVMRetValue {
            fn from(ptr: $type) -> Self {
                TVMRetValue {
                    prim_value: ptr as usize,
                    box_value: box (),
                    type_code: TVMTypeCode::kHandle,
                }
            }
        }

        impl TryFrom<TVMRetValue> for $type {
            type Error = Error;
            fn try_from(ret: TVMRetValue) -> Result<$type> {
                if ret.type_code == TVMTypeCode::kHandle {
                    Ok(ret.prim_value as $type)
                } else {
                    bail!(ErrorKind::TryFromTVMRetValueError(
                        stringify!($type).to_string(),
                        ret.type_code.to_string(),
                    ))
                }
            }
        }
    };
}

impl_ptr_ret_value!(*const c_void);
impl_ptr_ret_value!(*mut c_void);

impl From<String> for TVMRetValue {
    fn from(val: String) -> Self {
        let pval = val.as_ptr() as *const c_char as usize;
        let bval = box (val.as_ptr() as *const c_char);
        mem::forget(val);
        TVMRetValue::new(pval, bval, TVMTypeCode::kStr)
    }
}

impl TryFrom<TVMRetValue> for String {
    type Error = Error;
    fn try_from(ret: TVMRetValue) -> Result<String> {
        // Note: simple downcast doesn't work for function call return values
        let ret_str = unsafe {
            match CStr::from_ptr(ret.prim_value as *const c_char).to_str() {
                Ok(s) => s,
                Err(_) => "Invalid UTF-8 message",
            }
        };

        Ok(ret_str.to_string())
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::convert::TryInto;

    #[test]
    fn numeric() {
        macro_rules! arg_ret_tests {
            ($v:expr; $($ty:ty),+) => {{
                $(
                    let v = $v as $ty;
                    let b = TVMRetValue::from(&v);
                    let b: $ty = b.try_into().unwrap();
                    assert_eq!(b, v);
                )+
            }};
        }

        arg_ret_tests!(42; i8, i16, i32, i64, f32, f64);
    }

    #[test]
    fn string() {
        let s = "hello".to_string();
        let tvm_arg: String = TVMRetValue::from(s.clone()).try_into().unwrap();
        assert_eq!(tvm_arg, s);
    }
}