/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

//! Provides [`TVMContext`] and related device specific queries.
//!
//! Create a new context by device type (cpu is 1) and device id.
//!
//! # Example
//!
//! ```
//! let ctx = TVMContext::new(TVMDeviceType(1), 0);
//! let cpu0 = TVMContext::cpu(0);
//! assert_eq!(ctx, cpu0);
//! ```
//!
//! Or from a supported device name.
//!
//! ```
//! let cpu0 = TVMContext::from("cpu");
//! println!("{}", cpu0);
//! ```

use std::{
    convert::TryInto,
    fmt::{self, Display, Formatter},
    os::raw::c_void,
    ptr,
};

use failure::Error;

use tvm_common::ffi;

use crate::{function, TVMArgValue};

/// Integer code identifying a device type. It can be built from a supported
/// device name; see the supported devices in [TVM](https://github.com/dmlc/tvm).
///
/// ## Example
///
/// ```
/// let cpu = TVMDeviceType::from("cpu");
/// println!("device is: {}", cpu);
/// ```
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct TVMDeviceType(pub i64);

impl Default for TVMDeviceType {
    /// default device is cpu.
    fn default() -> Self {
        TVMDeviceType(1)
    }
}

impl From<TVMDeviceType> for ffi::DLDeviceType {
    fn from(device_type: TVMDeviceType) -> Self {
        match device_type.0 {
            1 => ffi::DLDeviceType_kDLCPU,
            2 => ffi::DLDeviceType_kDLGPU,
            3 => ffi::DLDeviceType_kDLCPUPinned,
            4 => ffi::DLDeviceType_kDLOpenCL,
            7 => ffi::DLDeviceType_kDLVulkan,
            8 => ffi::DLDeviceType_kDLMetal,
            9 => ffi::DLDeviceType_kDLVPI,
            10 => ffi::DLDeviceType_kDLROCM,
            12 => ffi::DLDeviceType_kDLExtDev,
            _ => panic!("device type not found!"),
        }
    }
}

impl From<ffi::DLDeviceType> for TVMDeviceType {
    /// Maps an `ffi::DLDeviceType` constant back to its TVM integer code.
    ///
    /// # Panics
    ///
    /// Panics with `device type not found!` for any value not listed below.
    fn from(device_type: ffi::DLDeviceType) -> Self {
        let code: i64 = match device_type {
            ffi::DLDeviceType_kDLCPU => 1,
            ffi::DLDeviceType_kDLGPU => 2,
            ffi::DLDeviceType_kDLCPUPinned => 3,
            ffi::DLDeviceType_kDLOpenCL => 4,
            ffi::DLDeviceType_kDLVulkan => 7,
            ffi::DLDeviceType_kDLMetal => 8,
            ffi::DLDeviceType_kDLVPI => 9,
            ffi::DLDeviceType_kDLROCM => 10,
            ffi::DLDeviceType_kDLExtDev => 12,
            _ => panic!("device type not found!"),
        };
        TVMDeviceType(code)
    }
}

impl Display for TVMDeviceType {
    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
        write!(
            f,
            "{}",
            match self {
                TVMDeviceType(1) => "cpu",
                TVMDeviceType(2) => "gpu",
                TVMDeviceType(3) => "cpu_pinned",
                TVMDeviceType(4) => "opencl",
                TVMDeviceType(8) => "meta",
                TVMDeviceType(9) => "vpi",
                TVMDeviceType(10) => "rocm",
                TVMDeviceType(_) => "rpc",
            }
        )
    }
}

impl<'a> From<&'a str> for TVMDeviceType {
    /// Parses a device or target name into a [`TVMDeviceType`].
    ///
    /// Target aliases (`llvm`, `stackvm`, `cuda`, `nvptx`, `cl`) map to their
    /// device. The names `cpu_pinned`, `vulkan`, `opengl` and `ext_dev` are also
    /// accepted. Their codes (3, 7, 11, 12) match the ones this module already
    /// uses for the `TVMContext` constructors and the DLPack conversion.
    ///
    /// # Panics
    ///
    /// Panics if `type_str` is not a recognised device name.
    fn from(type_str: &'a str) -> Self {
        match type_str {
            "cpu" | "llvm" | "stackvm" => TVMDeviceType(1),
            "gpu" | "cuda" | "nvptx" => TVMDeviceType(2),
            "cpu_pinned" => TVMDeviceType(3),
            "cl" | "opencl" => TVMDeviceType(4),
            "vulkan" => TVMDeviceType(7),
            "metal" => TVMDeviceType(8),
            "vpi" => TVMDeviceType(9),
            "rocm" => TVMDeviceType(10),
            "opengl" => TVMDeviceType(11),
            "ext_dev" => TVMDeviceType(12),
            _ => panic!("{:?} not supported!", type_str),
        }
    }
}

impl<'a> From<&TVMDeviceType> for TVMArgValue<'a> {
    fn from(dev: &TVMDeviceType) -> Self {
        Self::Int(dev.0)
    }
}

/// Represents the underlying device context. The default is cpu device 0.
///
/// ## Examples
///
/// ```
/// let ctx = TVMContext::from("gpu");
/// assert!(ctx.exist());
/// ```
///
/// It is possible to query the underlying context as follows:
///
/// ```
/// let ctx = TVMContext::gpu(0);
/// println!("maximum threads per block: {}", ctx.max_threads_per_block());
/// println!("compute version: {}", ctx.compute_version());
/// ```
#[derive(Debug, Default, Clone, Copy, Hash, PartialEq, Eq)]
pub struct TVMContext {
    /// Device type of this context (see [`TVMDeviceType`]).
    pub device_type: TVMDeviceType,
    /// Device id within `device_type`.
    pub device_id: i32,
}

impl TVMContext {
    /// Creates context from device type and id.
    pub fn new(device_type: TVMDeviceType, device_id: i32) -> Self {
        TVMContext {
            device_type,
            device_id,
        }
    }
}

// Generates one `TVMContext` constructor per `(name, device type code)` pair.
// `TVMContext::name(id)` is shorthand for `TVMContext::new(TVMDeviceType(code), id)`.
macro_rules! impl_ctxs {
    ($(($ctor:ident, $code:expr));+) => {
        impl TVMContext {
            $(
                /// Creates a context of this constructor's device type with the given device id.
                pub fn $ctor(device_id: i32) -> Self {
                    Self::new(TVMDeviceType($code), device_id)
                }
            )+
        }
    };
}

impl_ctxs!(
    (cpu, 1);
    (gpu, 2);
    (nvptx, 2);
    (cuda, 2);
    (cpu_pinned, 3);
    (cl, 4);
    (opencl, 4);
    (metal, 8);
    (vpi, 9);
    (rocm, 10);
    (opengl, 11);
    (ext_dev, 12)
);

impl<'a> From<&'a str> for TVMContext {
    fn from(target: &str) -> Self {
        TVMContext::new(TVMDeviceType::from(target), 0)
    }
}

impl TVMContext {
    /// Checks whether the context exists or not.
    ///
    /// Calls the runtime's `_GetDeviceAttr` packed function with attribute
    /// kind `0` and reports whether the returned value is non-zero.
    ///
    /// # Panics
    ///
    /// Panics in any of these cases:
    /// - `_GetDeviceAttr` is not registered.
    /// - The packed call returns an error.
    /// - The result cannot be converted to `u64`.
    pub fn exist(&self) -> bool {
        let func = function::Function::get("_GetDeviceAttr").expect("API function always exists");
        let dt = self.device_type.0 as usize;
        // Any runtime failure surfaces as the `Err` returned by `call_packed!`.
        // This method has no error channel, so such a failure is treated as a bug.
        let ret: u64 = call_packed!(func, dt, self.device_id, 0)
            .expect("_GetDeviceAttr call failed")
            .try_into()
            .expect("_GetDeviceAttr returned a value not convertible to u64");
        ret != 0
    }

    /// Synchronize the context stream.
    ///
    /// Calls `TVMSynchronize` for this device with a null stream handle
    /// (presumably the device's default stream; confirm against the runtime).
    ///
    /// # Errors
    ///
    /// A failure status from `TVMSynchronize` is handled by `check_call!`.
    /// That macro is assumed to early-return an `Err` (confirm).
    pub fn sync(&self) -> Result<(), Error> {
        check_call!(ffi::TVMSynchronize(
            self.device_type.0 as i32,
            self.device_id,
            ptr::null_mut::<c_void>()
        ));
        Ok(())
    }
}

// Generates one `TVMContext` getter per `(method name, attribute kind)` pair.
// Each getter calls the runtime's `_GetDeviceAttr` packed function with
// (device type, device id, attribute kind) and converts the result to `usize`.
macro_rules! impl_device_attrs {
    ($(($getter:ident, $kind:expr));+) => {
        impl TVMContext {
            $(
                /// Queries this device attribute through `_GetDeviceAttr`.
                ///
                /// # Panics
                ///
                /// Panics if the call fails or the result does not convert to `usize`.
                pub fn $getter(&self) -> usize {
                    let func = function::Function::get("_GetDeviceAttr")
                        .expect("API function always exists");
                    let device_code = self.device_type.0 as usize;
                    // Errors from the packed call are not recoverable here,
                    // so both the call result and the conversion are unwrapped.
                    function::Builder::from(func)
                        .arg(device_code)
                        .arg(self.device_id as usize)
                        .arg($kind)
                        .invoke()
                        .unwrap()
                        .try_into()
                        .unwrap()
                }
            )+
        }
    };
}

// NOTE(review): `compute_version`, `device_name` and `max_thread_dimensions`
// may be string-valued attributes in the TVM runtime. If so, converting them to
// `usize` will panic. Confirm against the runtime's device APIs.
impl_device_attrs!(
    (max_threads_per_block, 1);
    (warp_size, 2);
    (max_shared_memory_per_block, 3);
    (compute_version, 4);
    (device_name, 5);
    (max_clock_rate, 6);
    (multi_processor_count, 7);
    (max_thread_dimensions, 8)
);

impl From<ffi::DLContext> for TVMContext {
    /// Converts a raw `ffi::DLContext` into a [`TVMContext`].
    ///
    /// # Panics
    ///
    /// Panics if the device type has no `TVMDeviceType` mapping.
    fn from(ctx: ffi::DLContext) -> Self {
        Self::new(ctx.device_type.into(), ctx.device_id)
    }
}

impl From<TVMContext> for ffi::DLContext {
    fn from(ctx: TVMContext) -> Self {
        ffi::DLContext {
            device_type: ctx.device_type.into(),
            device_id: ctx.device_id as i32,
        }
    }
}

impl Display for TVMContext {
    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
        write!(f, "{}({})", self.device_type, self.device_id)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn context() {
        let ctx = TVMContext::cpu(0);
        println!("ctx: {}", ctx);
        // `TVMContext` is `Copy`, so comparisons need no clones.
        assert_eq!(ctx, TVMContext::new(TVMDeviceType(1), 0));
        assert_eq!(ctx, TVMContext::default());
        assert_ne!(ctx, TVMContext::gpu(0));

        let str_ctx = TVMContext::new(TVMDeviceType::from("gpu"), 0);
        assert_eq!(str_ctx, TVMContext::gpu(0));
        assert_eq!(str_ctx, TVMContext::from("cuda"));
        assert_ne!(str_ctx, TVMContext::new(TVMDeviceType::from("cpu"), 0));
    }

    #[test]
    fn dl_device_type_round_trip() {
        // Every code with a DLPack mapping must survive a round trip.
        for &code in &[1i64, 2, 3, 4, 7, 8, 9, 10, 12] {
            let raw: ffi::DLDeviceType = TVMDeviceType(code).into();
            assert_eq!(TVMDeviceType::from(raw), TVMDeviceType(code));
        }
    }

    #[test]
    fn dl_context_round_trip() {
        let ctx = TVMContext::gpu(1);
        let raw: ffi::DLContext = ctx.into();
        assert_eq!(TVMContext::from(raw), ctx);
    }

    #[test]
    fn sync() {
        let ctx = TVMContext::cpu(0);
        assert!(ctx.sync().is_ok())
    }
}