//! This module provides an idiomatic Rust API for creating and working with TVM functions.
//!
//! To call an already registered TVM function, use [`function::Builder`].
//! To register a TVM packed function from the Rust side, either
//! use [`function::register`] or the macro [`register_global_func`].
//!
//! See the tests and examples repository for more examples.

use std::{
    collections::BTreeMap,
    ffi::{CStr, CString},
    mem,
    os::raw::{c_char, c_int, c_void},
    ptr, slice, str,
    sync::Mutex,
};

use crate::{ts, ErrorKind, Module, Result, TVMArgValue, TVMRetValue, TVMTypeCode, TVMValue};

lazy_static! {
    // Every global TVM function name known when this map is first accessed, each mapped to a
    // lazily-fetched `Function` (populated on demand by `Function::get`).
    static ref GLOBAL_FUNCTIONS: Mutex<BTreeMap<&'static str, Option<Function>>> = {
        let mut out_size = 0 as c_int;
        let name = ptr::null_mut() as *mut c_char;
        let mut out_array = name as *mut _;
        check_call!(ts::TVMFuncListGlobalNames(
            &mut out_size as *mut _,
            &mut out_array
        ));
        // `slice::from_raw_parts` requires a non-null pointer even for an empty slice, so an
        // empty (or nonsensical negative) count is handled without touching `out_array`.
        let names_list: &[_] = if out_size <= 0 {
            &[]
        } else {
            unsafe { slice::from_raw_parts(out_array, out_size as usize) }
        };
        Mutex::new(
            names_list
                .iter()
                .map(|&p| {
                    // The C strings are owned by TVM, and nothing here guarantees they outlive
                    // the program. Copy each name and leak the copy once, so the key really is
                    // a `&'static str`. This is a one-time allocation made at initialization.
                    let owned = unsafe { CStr::from_ptr(p) }.to_str().unwrap().to_owned();
                    let leaked: &'static str = Box::leak(owned.into_boxed_str());
                    (leaked, None)
                })
                .collect(),
        )
    };
}

/// Wrapper around TVM function handle which includes `is_global`
/// indicating whether the function is global or not, `is_released`
/// to hint dropping the function handle and `is_cloned` showing
/// not to drop a cloned function from Rust side.
/// The value of these fields can be accessed through their respective methods.
///
/// When dropped, the handle is freed with `TVMFuncFree` only if the function is
/// neither global, already released, nor a clone.
#[derive(Debug, Hash)]
pub struct Function {
    /// Raw handle to the underlying TVM packed function.
    pub(crate) handle: ts::TVMFunctionHandle,
    // whether the registered function is global or not.
    is_global: bool,
    // whether the function has been dropped from frontend or not.
    is_released: bool,
    // whether the function has been cloned from frontend or not.
    is_cloned: bool,
}

// SAFETY: `Function` holds only a raw TVM handle plus plain flags.
// NOTE(review): these impls assume the TVM runtime allows a packed-function handle to be
// called and freed from any thread; confirm against TVM's threading guarantees.
unsafe impl Send for Function {}
unsafe impl Sync for Function {}

impl Function {
    /// Wraps a raw TVM function handle. The result is never marked as cloned.
    pub(crate) fn new(handle: ts::TVMFunctionHandle, is_global: bool, is_released: bool) -> Self {
        Function {
            handle,
            is_global,
            is_released,
            is_cloned: false,
        }
    }

    /// Looks up a global TVM function by `name`.
    ///
    /// Only names present in the global-function cache (the names TVM reported when the cache
    /// was first accessed) are found. On the first successful lookup, the handle is fetched
    /// with `TVMFuncGetGlobal` and cached. Later lookups return that same `Function` and
    /// ignore their `is_global` argument.
    ///
    /// Returns `None` if the name is unknown, or if TVM hands back a null handle for it. In the
    /// null case nothing is cached, so a later call retries.
    pub fn get<S: AsRef<str>>(name: S, is_global: bool) -> Option<&'static Function> {
        let name = name.as_ref();
        let mut globals = GLOBAL_FUNCTIONS.lock().unwrap();
        let slot = globals.get_mut(name)?;
        if slot.is_none() {
            let c_name = CString::new(name).unwrap();
            let mut handle = ptr::null_mut() as ts::TVMFunctionHandle;
            check_call!(ts::TVMFuncGetGlobal(
                c_name.as_ptr() as *const c_char,
                &mut handle as *mut _
            ));
            // Never cache (or hand out) a wrapper around a null handle.
            if handle.is_null() {
                return None;
            }
            *slot = Some(Function::new(
                handle, is_global, false, /* is_released */
            ));
        }
        // SAFETY: the map lives in a `static` and is never dropped. Nothing in this module
        // inserts or removes keys after initialization, so the `Function` stored in this slot
        // is not moved or freed once created. Extending the borrow past the lock guard is
        // therefore sound as long as that invariant holds.
        slot.as_ref()
            .map(|func| unsafe { &*(func as *const Function) })
    }

    /// Returns the underlying TVM function handle.
    pub fn handle(&self) -> ts::TVMFunctionHandle {
        self.handle
    }

    /// Returns `true` if the underlying TVM function is global and `false` otherwise.
    pub fn is_global(&self) -> bool {
        self.is_global
    }

    /// Returns `true` if the underlying TVM function has been released
    /// from the frontend and `false` otherwise.
    pub fn is_released(&self) -> bool {
        self.is_released
    }

    /// Returns `true` if the underlying TVM function has been cloned
    /// from the frontend and `false` otherwise.
    pub fn is_cloned(&self) -> bool {
        self.is_cloned
    }
}

impl Clone for Function {
    fn clone(&self) -> Function {
        if !self.is_released && !self.is_cloned {
            Self {
                handle: self.handle,
                is_global: self.is_global,
                is_released: self.is_released,
                is_cloned: true,
            }
        } else {
            Function::new(self.handle, self.is_global, self.is_released)
        }
    }
}

impl Drop for Function {
    /// Frees the TVM handle, but only when this value owns it. A value does not own its handle
    /// if it is global, a clone, or already released.
    fn drop(&mut self) {
        let owns_handle = !(self.is_released || self.is_global || self.is_cloned);
        if !owns_handle {
            return;
        }
        check_call!(ts::TVMFuncFree(self.handle));
        self.is_released = true;
    }
}

/// Function builder in order to create and call functions.
///
/// *Note:* Currently TVM functions accept *at most* one return value.
#[derive(Debug, Clone, Default)]
pub struct Builder<'a, 'm> {
    /// Function to call. While this is `None`, calling the builder fails with
    /// `FunctionNotFound`.
    pub func: Option<&'m Function>,
    /// Arguments pushed via `arg`/`args`, in the order they were pushed.
    pub arg_buf: Option<Box<[TVMArgValue<'a>]>>,
    /// Output buffer set via `set_output`; at most one may be set.
    pub ret_buf: Option<TVMRetValue>,
}

impl<'a, 'm> Builder<'a, 'm> {
    /// Creates a builder from an optional function, argument buffer and output buffer.
    pub fn new(
        func: Option<&'m Function>,
        arg_buf: Option<Box<[TVMArgValue<'a>]>>,
        ret_buf: Option<TVMRetValue>,
    ) -> Self {
        Self {
            func,
            arg_buf,
            ret_buf,
        }
    }

    /// Looks up `name` with [`Function::get`] and makes it the function to call.
    ///
    /// If the lookup fails, `func` becomes `None` and a later call reports
    /// `FunctionNotFound`.
    pub fn get_function(&mut self, name: &'m str, is_global: bool) -> &mut Self {
        self.func = Function::get(name, is_global);
        self
    }

    /// Pushes a [`TVMArgValue`] onto the end of the function argument buffer.
    pub fn arg<'b, T: ?Sized>(&mut self, arg: &'b T) -> &mut Self
    where
        TVMValue: From<&'b T>,
        TVMTypeCode: From<&'b T>,
    {
        let tvm_arg = TVMArgValue::from(arg);
        // Start from the existing arguments (or an empty list) and append the new one.
        let mut args = self
            .arg_buf
            .take()
            .map(|buf| buf.into_vec())
            .unwrap_or_default();
        args.push(tvm_arg);
        self.arg_buf = Some(args.into_boxed_slice());
        self
    }

    /// Pushes multiple [`TVMArgValue`]s into the function argument buffer.
    pub fn args<'b, T: 'b + ?Sized, I>(&mut self, args: I) -> &mut Self
    where
        I: IntoIterator<Item = &'b T>,
        TVMValue: From<&'b T>,
        TVMTypeCode: From<&'b T>,
    {
        for arg in args {
            self.arg(&arg);
        }
        self
    }

    /// Sets an output for a function that requires a mutable output to be provided.
    /// See the `basics` in tests for an example.
    ///
    /// # Errors
    ///
    /// Fails with `ErrorKind::AtMostOneReturn` if an output has already been set.
    pub fn set_output<'b, T: 'b + ?Sized>(&mut self, arg: &'b mut T) -> Result<&mut Self>
    where
        TVMValue: From<&'b T>,
        TVMTypeCode: From<&'b T>,
    {
        if self.ret_buf.is_none() {
            let tvm_ret =
                unsafe { TVMRetValue::from_tvm_value(TVMValue::from(arg), TVMTypeCode::from(arg)) };
            self.ret_buf = Some(tvm_ret);
        } else {
            bail!(ErrorKind::AtMostOneReturn)
        }
        Ok(self)
    }

    /// Calls the function configured on this `Builder`.
    ///
    /// The builder is cloned first, so it can be invoked again afterwards.
    pub fn invoke(&mut self) -> Result<TVMRetValue> {
        self.clone()(())
    }
}

impl<'a, 'm> FnOnce<((),)> for Builder<'a, 'm> {
    type Output = Result<TVMRetValue>;
    extern "rust-call" fn call_once(self, _: ((),)) -> Self::Output {
        if self.func.is_none() {
            bail!("{}", ErrorKind::FunctionNotFound);
        }

        let mut ret_val = unsafe { mem::uninitialized::<ts::TVMValue>() };
        let mut ret_type_code = 0 as c_int;
        if self.arg_buf.is_some() {
            let arg_buf = self.arg_buf?;
            let mut num_args = arg_buf.len();
            let mut values = arg_buf
                .iter()
                .map(|tav| tav.value.inner)
                .collect::<Vec<ts::TVMValue>>();
            let mut tcodes = arg_buf
                .iter()
                .map(|tav| tav.type_code as c_int)
                .collect::<Vec<_>>();

            if self.ret_buf.is_some() {
                num_args = num_args + 1;
                let ret_buf = self.ret_buf?;
                let (ret_val, ret_type_code) = TVMRetValue::into_tvm_value(ret_buf);
                values.append(&mut vec![ret_val.inner]);
                tcodes.append(&mut vec![ret_type_code as c_int]);
            }

            values.truncate(num_args);
            tcodes.truncate(num_args);
            check_call!(ts::TVMFuncCall(
                self.func?.handle,
                values.as_mut_ptr(),
                tcodes.as_mut_ptr(),
                num_args as c_int,
                &mut ret_val as *mut _,
                &mut ret_type_code as *mut _
            ));
        } else {
            check_call!(ts::TVMFuncCall(
                self.func?.handle,
                ptr::null_mut(),
                ptr::null_mut(),
                0 as c_int,
                &mut ret_val as *mut _,
                &mut ret_type_code as *mut _
            ));
        }

        let ret = unsafe {
            TVMRetValue::from_tvm_value(TVMValue::new(ret_val), (ret_type_code as i64).into())
        };
        Ok(ret)
    }
}

/// Converts a [`Function`] to builder. Currently, this is the best way to work with
/// TVM functions.
impl<'a, 'm> From<&'m Function> for Builder<'a, 'm> {
    fn from(func: &'m Function) -> Self {
        Builder::new(Some(func), None, None)
    }
}

/// Converts a mutable reference of a [`Module`] to [`Builder`].
impl<'a, 'm> From<&'m mut Module> for Builder<'a, 'm> {
    fn from(module: &'m mut Module) -> Self {
        Builder::new(module.entry(), None, None)
    }
}

unsafe extern "C" fn tvm_callback(
    args: *mut ts::TVMValue,
    type_codes: *mut c_int,
    num_args: c_int,
    ret: ts::TVMRetValueHandle,
    fhandle: *mut c_void,
) -> c_int {
    // turning off the incorrect linter complaints
    #![allow(unused_assignments)]
    let len = num_args as usize;
    let args_list = slice::from_raw_parts_mut(args, len);
    let type_codes_list = slice::from_raw_parts_mut(type_codes, len);
    let mut local_args: Vec<TVMArgValue> = Vec::new();
    let mut value = mem::uninitialized::<ts::TVMValue>();
    let mut tcode = mem::uninitialized::<c_int>();
    let rust_fn = mem::transmute::<*mut c_void, fn(&[TVMArgValue]) -> Result<TVMRetValue>>(fhandle);
    for i in 0..len {
        value = args_list[i];
        tcode = type_codes_list[i];
        if tcode == ts::TVMTypeCode_kNodeHandle as c_int
            || tcode == ts::TVMTypeCode_kFuncHandle as c_int
            || tcode == ts::TVMTypeCode_kModuleHandle as c_int
        {
            check_call!(ts::TVMCbArgToReturn(&mut value as *mut _, tcode));
        }
        local_args.push(TVMArgValue::new(
            TVMValue::new(value),
            (tcode as i64).into(),
        ));
    }

    let rv = match rust_fn(local_args.as_slice()) {
        Ok(v) => v,
        Err(msg) => {
            crate::set_last_error(&msg);
            return -1;
        }
    };

    let (ret_val, ret_tcode) = TVMRetValue::into_tvm_value(rv);
    let mut ret_val = ret_val.inner;
    let mut ret_type_code = ret_tcode as c_int;
    check_call!(ts::TVMCFuncSetReturn(
        ret,
        &mut ret_val as *mut _,
        &mut ret_type_code as *mut _,
        1 as c_int
    ));
    0
}

/// Finalizer passed to TVM alongside the resource handle from `convert_to_tvm_func`.
///
/// The resource is a bare `fn` pointer, which owns no data, so releasing it amounts to
/// reinterpreting the handle and discarding it.
unsafe extern "C" fn tvm_callback_finalizer(fhandle: *mut c_void) {
    let rust_fn: fn(&[TVMArgValue]) -> Result<TVMRetValue> = mem::transmute(fhandle);
    drop(rust_fn);
}

/// Wraps a Rust function pointer in a new, non-global TVM packed function.
///
/// The function pointer itself is handed to TVM as the resource handle. `tvm_callback` turns
/// it back into a callable, and `tvm_callback_finalizer` is registered as its finalizer.
fn convert_to_tvm_func(f: fn(&[TVMArgValue]) -> Result<TVMRetValue>) -> Function {
    let mut fhandle: ts::TVMFunctionHandle = ptr::null_mut();
    check_call!(ts::TVMFuncCreateFromCFunc(
        Some(tvm_callback),
        f as *mut c_void,
        Some(tvm_callback_finalizer),
        &mut fhandle as *mut _
    ));
    Function::new(fhandle, false, false)
}

/// Registers a Rust function with signature
/// `fn(&[TVMArgValue]) -> Result<TVMRetValue>`
/// under `name` as a **global TVM packed function**, from the frontend to the TVM backend.
/// When `override_` is `true`, an existing global function with the same name may be
/// replaced.
///
/// Use [`register_global_func`] if overriding an existing global TVM function
/// is not required.
///
/// # Errors
///
/// Fails if `name` contains an interior NUL byte.
///
/// ## Example
///
/// ```
/// use std::convert::TryInto;
///
/// fn sum(args: &[TVMArgValue]) -> Result<TVMRetValue> {
///     let mut ret = 0i64;
///     for arg in args.iter() {
///         let arg: i64 = arg.try_into()?;
///         ret += arg;
///     }
///     let ret_val = TVMRetValue::from(&ret);
///     Ok(ret_val)
/// }
///
/// tvm::function::register(sum, "mysum".to_owned(), false).unwrap();
/// let mut registered = function::Builder::default();
/// registered.get_function("mysum", true);
/// assert!(registered.func.is_some());
/// let ret: i64 = registered.args(&[10, 20, 30]).invoke().unwrap().try_into().unwrap();
/// assert_eq!(ret, 60);
/// ```
pub fn register<S: AsRef<str>>(
    f: fn(&[TVMArgValue]) -> Result<TVMRetValue>,
    name: S,
    override_: bool,
) -> Result<()> {
    let tvm_func = convert_to_tvm_func(f);
    let c_name = CString::new(name.as_ref())?;
    let allow_override = override_ as c_int;
    check_call!(ts::TVMFuncRegisterGlobal(
        c_name.as_ptr() as *const c_char,
        tvm_func.handle(),
        allow_override
    ));
    // The name buffer is intentionally leaked.
    // NOTE(review): if TVM copies the name on registration, this leak is unnecessary; confirm.
    mem::forget(c_name);
    Ok(())
}

/// Convenient macro for registering functions from frontend to backend as global
/// TVM packed functions without overriding. If overriding an existing function is needed,
/// use the [`function::register`] function instead.
///
/// The function is registered under its own Rust identifier (via `stringify!`), and the
/// macro panics (through `unwrap`) if registration fails.
///
/// ## Example
///
/// ```
/// use std::convert::TryInto;
///
/// register_global_func! {
///     fn sum(args: &[TVMArgValue]) -> Result<TVMRetValue> {
///         let mut ret = 0f64;
///         for arg in args.iter() {
///             let arg: f64 = arg.try_into()?;
///             ret += arg;
///         }
///         let ret_val = TVMRetValue::from(&ret);
///         Ok(ret_val)
///     }
/// }
///
/// let mut registered = function::Builder::default();
/// registered.get_function("sum", true);
/// assert!(registered.func.is_some());
/// let ret: f64 = registered.args(&[10f64, 20f64, 30f64]).invoke().unwrap().try_into().unwrap();
/// assert_eq!(ret, 60f64);
/// ```
#[macro_export]
macro_rules! register_global_func {
    {
        $(#[$m:meta])*
        fn $fn_name:ident($args:ident : &[TVMArgValue]) -> Result<TVMRetValue> {
            $($code:tt)*
        }
    } => {{
        // Re-emit the user's function definition unchanged, attributes included.
        $(#[$m])*
        fn $fn_name($args: &[TVMArgValue]) -> Result<TVMRetValue> {
            $($code)*
        }

        // Register it under its identifier without overriding; panics on failure.
        $crate::function::register($fn_name, stringify!($fn_name).to_owned(), false).unwrap();
    }}
}

/// Convenient macro for calling TVM packed functions by providing a
/// function identifier and zero or more arguments. The macro evaluates to a `Result`,
/// which lets the caller handle errors properly.
///
/// **Note**: this macro does *not* expect an outside mutable output. To
/// set mutable output use [`set_output`] directly in the builder pattern.
///
/// [`set_output`]:function/struct.Builder.html#method.set_output
///
/// ## Example
///
/// Instead of
///
/// ```
/// function::Builder::from(func).arg(&a).arg(&b).invoke();
/// ```
///
/// one can use
///
/// ```
/// call_packed!(func, &a, &b);
/// ```
///
/// A function taking no arguments can be called with `call_packed!(func)`, and a trailing
/// comma is accepted.
#[macro_export]
macro_rules! call_packed {
    ($fn_name:expr $(, $arg:expr)* $(,)?) => {{
        let mut builder = $crate::function::Builder::from($fn_name);
        $(
            builder.arg($arg);
        )*
        builder.invoke()
    }}
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Name of a global function the tests assume the TVM runtime registers.
    static CANARY: &str = "module._LoadFromFile";

    #[test]
    fn list_global_func() {
        let globals = GLOBAL_FUNCTIONS.lock().unwrap();
        assert!(globals.contains_key(CANARY));
    }

    #[test]
    fn get_fn() {
        let found = Function::get(CANARY, true);
        assert!(found.is_some());
        let missing = Function::get("does not exists!", false);
        assert!(missing.is_none());
    }

    #[test]
    fn provide_args() {
        let mut builder = Builder::default();
        builder
            .get_function("tvm.graph_runtime.remote_create", true)
            .args(&[10, 20])
            .arg(&"test".to_owned());
        assert!(builder.arg_buf.is_some());
        let arg_count = builder.arg_buf.as_ref().map(|buf| buf.len());
        assert_eq!(arg_count, Some(3));
    }
}