/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ use std::{ convert::TryFrom, ffi::{CStr, CString}, os::raw::c_void, }; pub use crate::ffi::TVMValue; use crate::{errors::ValueDowncastError, ffi::*}; pub trait PackedFunc = Fn(&[TVMArgValue]) -> Result<TVMRetValue, crate::errors::FuncCallError> + Send + Sync; /// Calls a packed function and returns a `TVMRetValue`. /// /// # Example /// /// `call_packed!(my_tvm_func, &mut arg1, &mut arg2)` #[macro_export] macro_rules! call_packed { ($fn:expr, $($args:expr),+) => { $fn(&[$($args.into(),)+]) }; ($fn:expr) => { $fn(&Vec::new()) }; } /// Constructs a derivative of a TVMPodValue. macro_rules! TVMPODValue { { $(#[$m:meta])+ $name:ident $(<$a:lifetime>)? { $($extra_variant:ident ( $variant_type:ty ) ),+ $(,)? }, match $value:ident { $($tvm_type:ident => { $from_tvm_type:expr })+ }, match &self { $($self_type:ident ( $val:ident ) => { $from_self_type:expr })+ } $(,)? } => { $(#[$m])+ #[derive(Clone, Debug)] pub enum $name $(<$a>)? { Int(i64), UInt(i64), Float(f64), Null, Type(TVMType), String(CString), Context(TVMContext), Handle(*mut c_void), ArrayHandle(TVMArrayHandle), NodeHandle(*mut c_void), ModuleHandle(TVMModuleHandle), FuncHandle(TVMFunctionHandle), NDArrayContainer(*mut c_void), $($extra_variant($variant_type)),+ } impl $(<$a>)? $name $(<$a>)? { pub fn from_tvm_value($value: TVMValue, type_code: u32) -> Self { use $name::*; #[allow(non_upper_case_globals)] unsafe { match type_code { DLDataTypeCode_kDLInt => Int($value.v_int64), DLDataTypeCode_kDLUInt => UInt($value.v_int64), DLDataTypeCode_kDLFloat => Float($value.v_float64), TVMTypeCode_kNull => Null, TVMTypeCode_kTVMType => Type($value.v_type), TVMTypeCode_kTVMContext => Context($value.v_ctx), TVMTypeCode_kHandle => Handle($value.v_handle), TVMTypeCode_kArrayHandle => ArrayHandle($value.v_handle as TVMArrayHandle), TVMTypeCode_kNodeHandle => NodeHandle($value.v_handle), TVMTypeCode_kModuleHandle => ModuleHandle($value.v_handle), TVMTypeCode_kFuncHandle => FuncHandle($value.v_handle), TVMTypeCode_kNDArrayContainer => NDArrayContainer($value.v_handle), $( $tvm_type => { $from_tvm_type } ),+ _ => unimplemented!("{}", type_code), } } } pub fn to_tvm_value(&self) -> (TVMValue, TVMTypeCode) { use $name::*; match self { Int(val) => (TVMValue { v_int64: *val }, DLDataTypeCode_kDLInt), UInt(val) => (TVMValue { v_int64: *val as i64 }, DLDataTypeCode_kDLUInt), Float(val) => (TVMValue { v_float64: *val }, DLDataTypeCode_kDLFloat), Null => (TVMValue{ v_int64: 0 },TVMTypeCode_kNull), Type(val) => (TVMValue { v_type: *val }, TVMTypeCode_kTVMType), Context(val) => (TVMValue { v_ctx: val.clone() }, TVMTypeCode_kTVMContext), String(val) => { ( TVMValue { v_handle: val.as_ptr() as *mut c_void }, TVMTypeCode_kStr, ) } Handle(val) => (TVMValue { v_handle: *val }, TVMTypeCode_kHandle), ArrayHandle(val) => { ( TVMValue { v_handle: *val as *const _ as *mut c_void }, TVMTypeCode_kArrayHandle, ) }, NodeHandle(val) => (TVMValue { v_handle: *val }, TVMTypeCode_kNodeHandle), ModuleHandle(val) => (TVMValue { v_handle: *val }, TVMTypeCode_kModuleHandle), FuncHandle(val) => ( TVMValue { v_handle: *val }, TVMTypeCode_kFuncHandle ), NDArrayContainer(val) => (TVMValue { v_handle: *val }, TVMTypeCode_kNDArrayContainer), $( $self_type($val) => { $from_self_type } ),+ } } } } } TVMPODValue! { /// A borrowed TVMPODValue. Can be constructed using `into()` but the preferred way /// to obtain a `TVMArgValue` is automatically via `call_packed!`. TVMArgValue<'a> { Bytes(&'a TVMByteArray), Str(&'a CStr), }, match value { TVMTypeCode_kBytes => { Bytes(&*(value.v_handle as *const TVMByteArray)) } TVMTypeCode_kStr => { Str(CStr::from_ptr(value.v_handle as *const i8)) } }, match &self { Bytes(val) => { (TVMValue { v_handle: val.clone() as *const _ as *mut c_void }, TVMTypeCode_kBytes) } Str(val) => { (TVMValue { v_handle: val.as_ptr() as *mut c_void }, TVMTypeCode_kStr) } } } TVMPODValue! { /// An owned TVMPODValue. Can be converted from a variety of primitive and object types. /// Can be downcasted using `try_from` if it contains the desired type. /// /// # Example /// /// ``` /// let a = 42u32; /// let b: i64 = TVMRetValue::from(a).try_into().unwrap(); /// /// let s = "hello, world!"; /// let t: TVMRetValue = s.into(); /// assert_eq!(String::try_from(t).unwrap(), s); /// ``` TVMRetValue { Bytes(TVMByteArray), Str(&'static CStr), }, match value { TVMTypeCode_kBytes => { Bytes(*(value.v_handle as *const TVMByteArray)) } TVMTypeCode_kStr => { Str(CStr::from_ptr(value.v_handle as *mut i8)) } }, match &self { Bytes(val) => { (TVMValue { v_handle: val as *const _ as *mut c_void }, TVMTypeCode_kBytes ) } Str(val) => { (TVMValue { v_str: val.as_ptr() }, TVMTypeCode_kStr ) } } } #[macro_export] macro_rules! try_downcast { ($val:ident -> $into:ty, $( |$pat:pat| { $converter:expr } ),+ ) => { match $val { $( $pat => { Ok($converter) } )+ _ => Err($crate::errors::ValueDowncastError { actual_type: format!("{:?}", $val), expected_type: stringify!($into), }), } }; } /// Creates a conversion to a `TVMArgValue` for a primitive type and DLDataTypeCode. macro_rules! impl_pod_value { ($variant:ident, $inner_ty:ty, [ $( $type:ty ),+ ] ) => { $( impl<'a> From<$type> for TVMArgValue<'a> { fn from(val: $type) -> Self { Self::$variant(val as $inner_ty) } } impl<'a, 'v> From<&'a $type> for TVMArgValue<'v> { fn from(val: &'a $type) -> Self { Self::$variant(*val as $inner_ty) } } impl<'a> TryFrom<TVMArgValue<'a>> for $type { type Error = $crate::errors::ValueDowncastError; fn try_from(val: TVMArgValue<'a>) -> Result<Self, Self::Error> { try_downcast!(val -> $type, |TVMArgValue::$variant(val)| { val as $type }) } } impl<'a, 'v> TryFrom<&'a TVMArgValue<'v>> for $type { type Error = $crate::errors::ValueDowncastError; fn try_from(val: &'a TVMArgValue<'v>) -> Result<Self, Self::Error> { try_downcast!(val -> $type, |TVMArgValue::$variant(val)| { *val as $type }) } } impl From<$type> for TVMRetValue { fn from(val: $type) -> Self { Self::$variant(val as $inner_ty) } } impl TryFrom<TVMRetValue> for $type { type Error = $crate::errors::ValueDowncastError; fn try_from(val: TVMRetValue) -> Result<Self, Self::Error> { try_downcast!(val -> $type, |TVMRetValue::$variant(val)| { val as $type }) } } )+ }; } impl_pod_value!(Int, i64, [i8, i16, i32, i64, isize]); impl_pod_value!(UInt, i64, [u8, u16, u32, u64, usize]); impl_pod_value!(Float, f64, [f32, f64]); impl_pod_value!(Type, TVMType, [TVMType]); impl_pod_value!(Context, TVMContext, [TVMContext]); impl<'a> From<&'a str> for TVMArgValue<'a> { fn from(s: &'a str) -> Self { Self::String(CString::new(s).unwrap()) } } impl<'a> From<String> for TVMArgValue<'a> { fn from(s: String) -> Self { Self::String(CString::new(s).unwrap()) } } impl<'a> From<&'a CStr> for TVMArgValue<'a> { fn from(s: &'a CStr) -> Self { Self::Str(s) } } impl<'a> From<&'a TVMByteArray> for TVMArgValue<'a> { fn from(s: &'a TVMByteArray) -> Self { Self::Bytes(s) } } impl<'a> TryFrom<TVMArgValue<'a>> for &'a str { type Error = ValueDowncastError; fn try_from(val: TVMArgValue<'a>) -> Result<Self, Self::Error> { try_downcast!(val -> &str, |TVMArgValue::Str(s)| { s.to_str().unwrap() }) } } impl<'a, 'v> TryFrom<&'a TVMArgValue<'v>> for &'v str { type Error = ValueDowncastError; fn try_from(val: &'a TVMArgValue<'v>) -> Result<Self, Self::Error> { try_downcast!(val -> &str, |TVMArgValue::Str(s)| { s.to_str().unwrap() }) } } /// Converts an unspecialized handle to a TVMArgValue. impl<T> From<*const T> for TVMArgValue<'static> { fn from(ptr: *const T) -> Self { Self::Handle(ptr as *mut c_void) } } /// Converts an unspecialized mutable handle to a TVMArgValue. impl<T> From<*mut T> for TVMArgValue<'static> { fn from(ptr: *mut T) -> Self { Self::Handle(ptr as *mut c_void) } } impl<'a> From<&'a mut DLTensor> for TVMArgValue<'a> { fn from(arr: &'a mut DLTensor) -> Self { Self::ArrayHandle(arr as *mut DLTensor) } } impl<'a> From<&'a DLTensor> for TVMArgValue<'a> { fn from(arr: &'a DLTensor) -> Self { Self::ArrayHandle(arr as *const _ as *mut DLTensor) } } impl TryFrom<TVMRetValue> for String { type Error = ValueDowncastError; fn try_from(val: TVMRetValue) -> Result<String, Self::Error> { try_downcast!( val -> String, |TVMRetValue::String(s)| { s.into_string().unwrap() }, |TVMRetValue::Str(s)| { s.to_str().unwrap().to_string() } ) } } impl From<String> for TVMRetValue { fn from(s: String) -> Self { Self::String(std::ffi::CString::new(s).unwrap()) } } impl From<TVMByteArray> for TVMRetValue { fn from(arr: TVMByteArray) -> Self { Self::Bytes(arr) } } impl TryFrom<TVMRetValue> for TVMByteArray { type Error = ValueDowncastError; fn try_from(val: TVMRetValue) -> Result<Self, Self::Error> { try_downcast!(val -> TVMByteArray, |TVMRetValue::Bytes(val)| { val }) } } impl Default for TVMRetValue { fn default() -> Self { Self::Int(0) } }