/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

//! This module implements the conversions between the FFI value types [`TVMArgValue`] and
//! [`TVMRetValue`] and the handle-wrapping types of the frontend crate ([`Function`],
//! [`Module`], and [`NDArray`]).
//! `TVMRetValue` is the owned version of `TVMPODValue`.

use std::convert::TryFrom;

use tvm_common::{
    errors::ValueDowncastError,
    ffi::{TVMArrayHandle, TVMFunctionHandle, TVMModuleHandle},
    try_downcast,
};

use crate::{Function, Module, NDArray, TVMArgValue, TVMRetValue};

// Generates the conversion glue between a handle-wrapping frontend type and the
// FFI value enums.
//
// Macro arguments, in order:
//   1. `$wrapper`    - the wrapper type (e.g. `Function`);
//   2. `$variant`    - the variant name, shared by `TVMArgValue` and `TVMRetValue`,
//                      that carries the wrapper's raw handle;
//   3. `$raw_handle` - the raw handle type stored in that variant;
//   4. `$ctor`       - a constructor path that rebuilds the wrapper from a raw handle.
//
// Emitted impls:
//   * `From<&T>` and `From<&mut T>` for `TVMArgValue` (copies out `handle()`),
//   * `TryFrom<TVMArgValue>` and `TryFrom<&TVMArgValue>` for `T`,
//   * `From<T>` for `TVMRetValue` and `TryFrom<TVMRetValue>` for `T`.
// The `TryFrom` impls delegate to `try_downcast!` and report a mismatched variant
// as `ValueDowncastError`.
//
// NOTE(review): `From<T> for TVMRetValue` takes the wrapper by value, copies out its
// raw handle and then drops the wrapper at the end of `from`. Likewise,
// `TryFrom<&TVMArgValue>` builds an owning wrapper from a handle it only borrowed.
// If the wrappers' `Drop` impls release their handles, both paths would need
// care. Confirm against the `Function`/`Module`/`NDArray` definitions.
macro_rules! impl_handle_val {
    ($wrapper:ty, $variant:ident, $raw_handle:ty, $ctor:path) => {
        // Borrowing a wrapper as a call argument: only the raw handle is copied out.
        impl<'h> From<&'h $wrapper> for TVMArgValue<'h> {
            fn from(wrapped: &'h $wrapper) -> Self {
                let raw = wrapped.handle() as $raw_handle;
                TVMArgValue::$variant(raw)
            }
        }

        impl<'h> From<&'h mut $wrapper> for TVMArgValue<'h> {
            fn from(wrapped: &'h mut $wrapper) -> Self {
                let raw = wrapped.handle() as $raw_handle;
                TVMArgValue::$variant(raw)
            }
        }

        // Recovering a wrapper from an argument value, owned or borrowed.
        impl<'h> TryFrom<TVMArgValue<'h>> for $wrapper {
            type Error = ValueDowncastError;

            fn try_from(arg: TVMArgValue<'h>) -> Result<Self, Self::Error> {
                try_downcast!(arg -> $wrapper, |TVMArgValue::$variant(raw)| { $ctor(raw) })
            }
        }

        impl<'r, 'h> TryFrom<&'r TVMArgValue<'h>> for $wrapper {
            type Error = ValueDowncastError;

            fn try_from(arg: &'r TVMArgValue<'h>) -> Result<Self, Self::Error> {
                // The pattern binds `raw` by reference here, hence the deref.
                try_downcast!(arg -> $wrapper, |TVMArgValue::$variant(raw)| { $ctor(*raw) })
            }
        }

        // Moving a wrapper into / out of a return value.
        impl From<$wrapper> for TVMRetValue {
            fn from(wrapped: $wrapper) -> Self {
                let raw = wrapped.handle() as $raw_handle;
                TVMRetValue::$variant(raw)
            }
        }

        impl TryFrom<TVMRetValue> for $wrapper {
            type Error = ValueDowncastError;

            fn try_from(ret: TVMRetValue) -> Result<Self, Self::Error> {
                try_downcast!(ret -> $wrapper, |TVMRetValue::$variant(raw)| { $ctor(raw) })
            }
        }
    };
}

// Connect each handle-wrapping frontend type to the `TVMArgValue`/`TVMRetValue`
// variant that carries its raw handle, using the type's `new` as the constructor.
impl_handle_val! { NDArray, ArrayHandle, TVMArrayHandle, NDArray::new }
impl_handle_val! { Module, ModuleHandle, TVMModuleHandle, Module::new }
impl_handle_val! { Function, FuncHandle, TVMFunctionHandle, Function::new }

#[cfg(test)]
mod tests {
    use std::{convert::TryInto, str::FromStr};

    use tvm_common::{TVMByteArray, TVMContext, TVMType};

    use super::*;

    /// A byte array converted into a `TVMRetValue` and back keeps its contents.
    #[test]
    fn bytearray() {
        let w = vec![1u8, 2, 3, 4, 5];
        let v = TVMByteArray::from(w.as_slice());
        let tvm: TVMByteArray = TVMRetValue::from(v).try_into().unwrap();
        // Compare against the source bytes directly; copying them into a fresh
        // `Vec` first (the previous `iter().map(|e| *e).collect()`) was redundant.
        assert_eq!(tvm.data(), w.as_slice());
    }

    /// A `TVMType` parsed from a string round-trips through `TVMRetValue` unchanged.
    #[test]
    fn ty() {
        let t = TVMType::from_str("int32").unwrap();
        let tvm: TVMType = TVMRetValue::from(t).try_into().unwrap();
        assert_eq!(tvm, t);
    }

    /// A `TVMContext` parsed from a string round-trips through `TVMRetValue` unchanged.
    #[test]
    fn ctx() {
        let c = TVMContext::from_str("gpu").unwrap();
        let tvm: TVMContext = TVMRetValue::from(c).try_into().unwrap();
        assert_eq!(tvm, c);
    }
}