/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

//! This module provides an idiomatic Rust API for creating and working with TVM functions.
//!
//! For calling an already registered TVM function use [`function::Builder`]
//! To register a TVM packed function from Rust side either
//! use [`function::register`] or the macro [`register_global_func`].
//!
//! See the tests and examples repository for more examples.

use std::{
    collections::BTreeMap,
    ffi::{CStr, CString},
    mem::{self, MaybeUninit},
    os::raw::{c_char, c_int, c_void},
    ptr, slice, str,
    sync::Mutex,
};

use failure::Error;

use crate::{errors, ffi, Module, TVMArgValue, TVMRetValue};

lazy_static! {
    static ref GLOBAL_FUNCTIONS: Mutex<BTreeMap<&'static str, Option<Function>>> = {
        let mut out_size = 0 as c_int;
        let mut names_ptr = ptr::null_mut() as *mut *const c_char;
        check_call!(ffi::TVMFuncListGlobalNames(
            &mut out_size as *mut _,
            &mut names_ptr as *mut _,
        ));
        let names_list = unsafe { slice::from_raw_parts(names_ptr, out_size as usize) };
        Mutex::new(
            names_list
                .iter()
                .map(|&p| (unsafe { CStr::from_ptr(p).to_str().unwrap() }, None))
                .collect(),
        )
    };
}

/// Wrapper around TVM function handle which includes `is_global`
/// indicating whether the function is global or not, and `is_cloned` showing
/// not to drop a cloned function from Rust side.
/// The value of these fields can be accessed through their respective methods.
#[derive(Debug, Hash)]
pub struct Function {
    pub(crate) handle: ffi::TVMFunctionHandle,
    // whether the registered function is global or not.
    is_global: bool,
    // whether the function has been cloned from frontend or not.
    is_cloned: bool,
}

unsafe impl Send for Function {}
unsafe impl Sync for Function {}

impl Function {
    pub(crate) fn new(handle: ffi::TVMFunctionHandle) -> Self {
        Function {
            handle,
            is_global: false,
            is_cloned: false,
        }
    }

    /// For a given function, it returns a function by name.
    pub fn get<S: AsRef<str>>(name: S) -> Option<&'static Function> {
        let mut globals = GLOBAL_FUNCTIONS.lock().unwrap();
        globals.get_mut(name.as_ref()).and_then(|maybe_func| {
            if maybe_func.is_none() {
                let name = CString::new(name.as_ref()).unwrap();
                let mut handle = ptr::null_mut() as ffi::TVMFunctionHandle;
                check_call!(ffi::TVMFuncGetGlobal(
                    name.as_ptr() as *const c_char,
                    &mut handle as *mut _
                ));
                maybe_func.replace(Function {
                    handle,
                    is_global: true,
                    is_cloned: false,
                });
            }
            unsafe {
                mem::transmute::<Option<&Function>, Option<&'static Function>>(maybe_func.as_ref())
            }
        })
    }

    /// Returns the underlying TVM function handle.
    pub fn handle(&self) -> ffi::TVMFunctionHandle {
        self.handle
    }

    /// Returns `true` if the underlying TVM function is global and `false` otherwise.
    pub fn is_global(&self) -> bool {
        self.is_global
    }

    /// Returns `true` if the underlying TVM function has been cloned
    /// from the frontend and `false` otherwise.
    pub fn is_cloned(&self) -> bool {
        self.is_cloned
    }
}

impl Clone for Function {
    fn clone(&self) -> Function {
        Self {
            handle: self.handle,
            is_global: self.is_global,
            is_cloned: true,
        }
    }
}

impl Drop for Function {
    fn drop(&mut self) {
        if !self.is_global && !self.is_cloned {
            check_call!(ffi::TVMFuncFree(self.handle));
        }
    }
}

/// Function builder in order to create and call functions.
///
/// *Note:* Currently TVM functions accept *at most* one return value.
#[derive(Default)]
pub struct Builder<'a, 'm> {
    pub func: Option<&'m Function>,
    pub arg_buf: Vec<TVMArgValue<'a>>,
    pub ret_buf: Option<TVMRetValue>,
}

impl<'a, 'm> Builder<'a, 'm> {
    pub fn new(
        func: Option<&'m Function>,
        arg_buf: Vec<TVMArgValue<'a>>,
        ret_buf: Option<TVMRetValue>,
    ) -> Self {
        Self {
            func,
            arg_buf,
            ret_buf,
        }
    }

    pub fn get_function(&mut self, name: &'m str) -> &mut Self {
        self.func = Function::get(name);
        self
    }

    /// Pushes a [`TVMArgValue`] into the function argument buffer.
    pub fn arg<T: 'a>(&mut self, arg: T) -> &mut Self
    where
        TVMArgValue<'a>: From<T>,
    {
        self.arg_buf.push(arg.into());
        self
    }

    /// Pushes multiple [`TVMArgValue`]s into the function argument buffer.
    pub fn args<T: 'a, I>(&mut self, args: I) -> &mut Self
    where
        I: IntoIterator<Item = &'a T>,
        TVMArgValue<'a>: From<&'a T>,
    {
        args.into_iter().for_each(|arg| {
            self.arg(&arg);
        });
        self
    }

    /// Sets an output for a function that requirs a mutable output to be provided.
    /// See the `basics` in tests for an example.
    pub fn set_output<T>(&mut self, ret: T) -> &mut Self
    where
        TVMRetValue: From<T>,
    {
        self.ret_buf = Some(ret.into());
        self
    }

    /// Calls the function that created from `Builder`.
    pub fn invoke(&mut self) -> Result<TVMRetValue, Error> {
        #![allow(unused_unsafe)]
        ensure!(self.func.is_some(), errors::FunctionNotFoundError);

        let num_args = self.arg_buf.len();
        let (mut values, mut type_codes): (Vec<ffi::TVMValue>, Vec<ffi::TVMTypeCode>) =
            self.arg_buf.iter().map(|arg| arg.to_tvm_value()).unzip();

        let mut ret_val = unsafe { MaybeUninit::uninit().assume_init() };
        let mut ret_type_code = 0i32;
        check_call!(ffi::TVMFuncCall(
            self.func.ok_or(errors::FunctionNotFoundError)?.handle,
            values.as_mut_ptr(),
            type_codes.as_mut_ptr() as *mut i32,
            num_args as c_int,
            &mut ret_val as *mut _,
            &mut ret_type_code as *mut _
        ));

        Ok(unsafe { TVMRetValue::from_tvm_value(ret_val, ret_type_code as u32) })
    }
}

/// Converts a [`Function`] to builder. Currently, this is the best way to work with
/// TVM functions.
impl<'a, 'm> From<&'m Function> for Builder<'a, 'm> {
    fn from(func: &'m Function) -> Self {
        Builder::new(Some(func), Vec::new(), None)
    }
}

/// Converts a mutable reference of a [`Module`] to [`Builder`].
impl<'a, 'm> From<&'m mut Module> for Builder<'a, 'm> {
    fn from(module: &'m mut Module) -> Self {
        Builder::new(module.entry(), Vec::new(), None)
    }
}

unsafe extern "C" fn tvm_callback(
    args: *mut ffi::TVMValue,
    type_codes: *mut c_int,
    num_args: c_int,
    ret: ffi::TVMRetValueHandle,
    fhandle: *mut c_void,
) -> c_int {
    // turning off the incorrect linter complaints
    #![allow(unused_assignments, unused_unsafe)]
    let len = num_args as usize;
    let args_list = slice::from_raw_parts_mut(args, len);
    let type_codes_list = slice::from_raw_parts_mut(type_codes, len);
    let mut local_args: Vec<TVMArgValue> = Vec::new();
    let mut value = MaybeUninit::uninit().assume_init();
    let mut tcode = MaybeUninit::uninit().assume_init();
    let rust_fn =
        mem::transmute::<*mut c_void, fn(&[TVMArgValue]) -> Result<TVMRetValue, Error>>(fhandle);
    for i in 0..len {
        value = args_list[i];
        tcode = type_codes_list[i];
        if tcode == ffi::TVMTypeCode_kTVMObjectHandle as c_int
            || tcode == ffi::TVMTypeCode_kTVMPackedFuncHandle as c_int
            || tcode == ffi::TVMTypeCode_kTVMModuleHandle as c_int
        {
            check_call!(ffi::TVMCbArgToReturn(&mut value as *mut _, tcode));
        }
        local_args.push(TVMArgValue::from_tvm_value(value, tcode as u32));
    }

    let rv = match rust_fn(local_args.as_slice()) {
        Ok(v) => v,
        Err(msg) => {
            crate::set_last_error(&msg);
            return -1;
        }
    };

    let (mut ret_val, ret_tcode) = rv.to_tvm_value();
    let mut ret_type_code = ret_tcode as c_int;
    check_call!(ffi::TVMCFuncSetReturn(
        ret,
        &mut ret_val as *mut _,
        &mut ret_type_code as *mut _,
        1 as c_int
    ));
    0
}

unsafe extern "C" fn tvm_callback_finalizer(fhandle: *mut c_void) {
    let _rust_fn =
        mem::transmute::<*mut c_void, fn(&[TVMArgValue]) -> Result<TVMRetValue, Error>>(fhandle);
    // XXX: give converted functions lifetimes so they're not called after use
}

fn convert_to_tvm_func(f: fn(&[TVMArgValue]) -> Result<TVMRetValue, Error>) -> Function {
    let mut fhandle = ptr::null_mut() as ffi::TVMFunctionHandle;
    let resource_handle = f as *mut fn(&[TVMArgValue]) -> Result<TVMRetValue, Error>;
    check_call!(ffi::TVMFuncCreateFromCFunc(
        Some(tvm_callback),
        resource_handle as *mut c_void,
        Some(tvm_callback_finalizer),
        &mut fhandle as *mut _
    ));
    Function::new(fhandle)
}

/// Registers a Rust function with signature
/// `fn(&[TVMArgValue]) -> Result<TVMRetValue, Error>`
/// as a **global TVM packed function** from frontend to TVM backend.
///
/// Use [`register_global_func`] if overriding an existing global TVM function
/// is not required.
///
/// ## Example
///
/// ```
/// use std::convert::TryInto;
///
/// fn sum(args: &[TVMArgValue]) -> Result<TVMRetValue, Error> {
///     let mut ret = 0i64;
///     for arg in args.iter() {
///         let arg: i64 = arg.try_into()?;
///         ret += arg;
///     }
///     let ret_val = TVMRetValue::from(&ret);
///     Ok(ret_val)
/// }
///
/// tvm::function::register(sum, "mysum".to_owned(), false).unwrap();
/// let mut registered = function::Builder::default();
/// registered.get_function("mysum", true);
/// assert!(registered.func.is_some());
/// let ret: i64 = registered.args(&[10, 20, 30]).invoke().unwrap().try_into().unwrap();
/// assert_eq!(ret, 60);
/// ```
pub fn register<S: AsRef<str>>(
    f: fn(&[TVMArgValue]) -> Result<TVMRetValue, Error>,
    name: S,
    override_: bool,
) -> Result<(), Error> {
    let func = convert_to_tvm_func(f);
    let name = CString::new(name.as_ref())?;
    check_call!(ffi::TVMFuncRegisterGlobal(
        name.into_raw(),
        func.handle(),
        override_ as c_int
    ));
    Ok(())
}

/// Convenient macro for registering functions from frontend to backend as global
/// TVM packed functions without overriding. If overriding an existing function is needed
/// use the [`function::register`] function instead.
///
/// ## Example
///
/// ```
/// use std::convert::TryInto;
///
/// register_global_func! {
///     fn sum(args: &[TVMArgValue]) -> Result<TVMRetValue, Error> {
///         let mut ret = 0f64;
///         for arg in args.iter() {
///             let arg: f64 = arg.try_into()?;
///             ret += arg;
///         }
///         let ret_val = TVMRetValue::from(&ret);
///         Ok(ret_val)
///     }
/// }
///
/// let mut registered = function::Builder::default();
/// registered.get_function("sum", true);
/// assert!(registered.func.is_some());
/// let ret: f64 = registered.args(&[10f64, 20f64, 30f64]).invoke().unwrap().try_into().unwrap();
/// assert_eq!(ret, 60f64);
/// ```
#[macro_export]
macro_rules! register_global_func {
    {
        $(#[$m:meta])*
        fn $fn_name:ident($args:ident : &[TVMArgValue]) -> Result<TVMRetValue, Error> {
            $($code:tt)*
        }
    } => {{
        $(#[$m])*
        fn $fn_name($args: &[TVMArgValue]) -> Result<TVMRetValue, Error> {
            $($code)*
        }

        $crate::function::register($fn_name, stringify!($fn_name).to_owned(), false).unwrap();
    }}
}

/// Convenient macro for calling TVM packed functions by providing a
/// function identifier and some arguments. This macro outputs a `Result` type
/// and let user to perform proper error handling.
///
/// **Note**: this macro does *not* expect an outside mutable output. To
/// set mutable output use [`set_output`] directly in the builder pattern.
///
/// [`set_output`]:function/struct.Builder.html#method.set_output
///
/// ## Example
///
/// Instead of
///
/// ```
/// function::Builder::from(func).arg(&a).arg(&b).invoke();
/// ```
///
/// one can use
///
/// ```
/// call_packed!(func, &a, &b);
/// ```
#[macro_export]
macro_rules! call_packed {
    ($fn_name:expr, $($arg:expr),*) => {{
        let mut builder = $crate::function::Builder::from($fn_name);
        $(
            builder.arg($arg);
        )*
        builder.invoke()
    }}
}

#[cfg(test)]
mod tests {
    use super::*;

    static CANARY: &str = "module._LoadFromFile";

    #[test]
    fn list_global_func() {
        assert!(GLOBAL_FUNCTIONS.lock().unwrap().contains_key(CANARY));
    }

    #[test]
    fn get_fn() {
        assert!(Function::get(CANARY).is_some());
        assert!(Function::get("does not exists!").is_none());
    }

    #[test]
    fn provide_args() {
        let str_arg = CString::new("test").unwrap();
        let mut func = Builder::default();
        func.get_function("tvm.graph_runtime.remote_create")
            .arg(10)
            .arg(20)
            .arg(str_arg.as_c_str());
        assert_eq!(func.arg_buf.len(), 3);
    }
}