packed_func.rs 12.4 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

20 21 22 23 24
use std::{
    convert::TryFrom,
    ffi::{CStr, CString},
    os::raw::c_void,
};
25 26

pub use crate::ffi::TVMValue;
27
use crate::{errors::ValueDowncastError, ffi::*};
28

29 30 31 32
pub trait PackedFunc : Fn(&[TVMArgValue]) -> Result<TVMRetValue, crate::errors::FuncCallError> + Send + Sync {}

impl<T> PackedFunc for T
    where T : Fn(&[TVMArgValue]) -> Result<TVMRetValue, crate::errors::FuncCallError> + Send + Sync {}
33 34 35 36 37 38 39 40

/// Calls a packed function with the given arguments and returns whatever the function
/// returns (for a `PackedFunc`, a `Result<TVMRetValue, FuncCallError>`).
///
/// Every argument is converted with `.into()`, so anything convertible into the callee's
/// argument type (e.g. `TVMArgValue`) can be passed directly. A trailing comma is
/// accepted, and calling with no arguments passes an empty slice.
///
/// # Example
///
/// `call_packed!(my_tvm_func, &mut arg1, &mut arg2)`
#[macro_export]
macro_rules! call_packed {
    ($fn:expr, $($args:expr),+ $(,)?) => {
        $fn(&[$($args.into(),)+])
    };
    ($fn:expr $(,)?) => {
        // An empty array literal coerces to an empty slice; no allocation involved.
        $fn(&[])
    };
}

/// Constructs a derivative of a TVMPodValue.
///
/// Invoked with doc/attribute lines, a type name (optionally with one lifetime), the
/// extra variants that type carries, and two sets of match arms. It generates:
///
/// * `pub enum $name` holding the variants common to every POD value followed by the
///   extra variants, deriving `Clone` and `Debug`;
/// * `from_tvm_value(value, type_code)`, which decodes a raw `TVMValue` by type code and
///   uses the invocation's `match $value { .. }` arms for the extra type codes;
/// * `to_tvm_value(&self)`, which encodes the value back into a `(TVMValue, TVMTypeCode)`
///   pair and uses the invocation's `match &self { .. }` arms for the extra variants.
macro_rules! TVMPODValue {
    {
        $(#[$m:meta])+
        $name:ident $(<$a:lifetime>)? {
            $($extra_variant:ident ( $variant_type:ty ) ),+ $(,)?
        },
        match $value:ident {
            $($tvm_type:ident => { $from_tvm_type:expr })+
        },
        match &self {
            $($self_type:ident ( $val:ident ) => { $from_self_type:expr })+
        }
        $(,)?
    } => {
        $(#[$m])+
        #[derive(Clone, Debug)]
        pub enum $name $(<$a>)? {
            Int(i64),
            /// Unsigned integers are carried in an `i64`.
            UInt(i64),
            Float(f64),
            Null,
            DataType(DLDataType),
            /// An owned, NUL-terminated string.
            String(CString),
            Context(TVMContext),
            /// An opaque handle (`kTVMOpaqueHandle`).
            Handle(*mut c_void),
            /// A raw `DLTensor` pointer, decoded from `kTVMDLTensorHandle`.
            ArrayHandle(TVMArrayHandle),
            ObjectHandle(*mut c_void),
            ModuleHandle(TVMModuleHandle),
            FuncHandle(TVMFunctionHandle),
            /// An NDArray handle, decoded from `kTVMNDArrayHandle`.
            NDArrayContainer(*mut c_void),
            $($extra_variant($variant_type)),+
        }

        impl $(<$a>)? $name $(<$a>)? {
            /// Decodes a raw `TVMValue` whose active union field is described by `type_code`.
            ///
            /// # Panics
            ///
            /// Panics (via `unimplemented!`) if `type_code` is handled neither by the common
            /// arms nor by the arms supplied at macro invocation.
            pub fn from_tvm_value($value: TVMValue, type_code: u32) -> Self {
                use $name::*;
                // SAFETY: each arm reads the union field that corresponds to `type_code`.
                // This is only sound if `type_code` truly describes `$value`; the function
                // trusts that rather than checking it (note it is a safe `pub fn`).
                #[allow(non_upper_case_globals)]
                unsafe {
                    match type_code as _ {
                        DLDataTypeCode_kDLInt => Int($value.v_int64),
                        DLDataTypeCode_kDLUInt => UInt($value.v_int64),
                        DLDataTypeCode_kDLFloat => Float($value.v_float64),
                        TVMTypeCode_kTVMNullptr => Null,
                        TVMTypeCode_kTVMDataType => DataType($value.v_type),
                        TVMTypeCode_kTVMContext => Context($value.v_ctx),
                        TVMTypeCode_kTVMOpaqueHandle => Handle($value.v_handle),
                        TVMTypeCode_kTVMDLTensorHandle => ArrayHandle($value.v_handle as TVMArrayHandle),
                        TVMTypeCode_kTVMObjectHandle => ObjectHandle($value.v_handle),
                        TVMTypeCode_kTVMModuleHandle => ModuleHandle($value.v_handle),
                        TVMTypeCode_kTVMPackedFuncHandle => FuncHandle($value.v_handle),
                        TVMTypeCode_kTVMNDArrayHandle => NDArrayContainer($value.v_handle),
                        $( $tvm_type => { $from_tvm_type } ),+
                        _ => unimplemented!("{}", type_code),
                    }
                }
            }

            /// Encodes this value as a raw `TVMValue` plus the type code describing it.
            ///
            /// For the `String` variant the returned handle points into the `CString` owned
            /// by `self`, so it is only valid while `self` is alive.
            pub fn to_tvm_value(&self) -> (TVMValue, TVMTypeCode) {
                use $name::*;
                match self {
                    Int(val) => (TVMValue { v_int64: *val }, DLDataTypeCode_kDLInt),
                    // `UInt` already holds an `i64`, so it is stored without a cast.
                    UInt(val) => (TVMValue { v_int64: *val }, DLDataTypeCode_kDLUInt),
                    Float(val) => (TVMValue { v_float64: *val }, DLDataTypeCode_kDLFloat),
                    Null => (TVMValue { v_int64: 0 }, TVMTypeCode_kTVMNullptr),
                    DataType(val) => (TVMValue { v_type: *val }, TVMTypeCode_kTVMDataType),
                    Context(val) => (TVMValue { v_ctx: val.clone() }, TVMTypeCode_kTVMContext),
                    String(val) => (
                        TVMValue { v_handle: val.as_ptr() as *mut c_void },
                        TVMTypeCode_kTVMStr,
                    ),
                    Handle(val) => (TVMValue { v_handle: *val }, TVMTypeCode_kTVMOpaqueHandle),
                    // NOTE(review): `from_tvm_value` builds `ArrayHandle` from
                    // `kTVMDLTensorHandle`, but it is encoded here as `kTVMNDArrayHandle`, so a
                    // decode/encode round trip does not preserve the type code. Kept as-is
                    // because callers may rely on the NDArray type code; confirm before changing.
                    ArrayHandle(val) => (
                        TVMValue { v_handle: *val as *const _ as *mut c_void },
                        TVMTypeCode_kTVMNDArrayHandle,
                    ),
                    ObjectHandle(val) => (TVMValue { v_handle: *val }, TVMTypeCode_kTVMObjectHandle),
                    ModuleHandle(val) => (TVMValue { v_handle: *val }, TVMTypeCode_kTVMModuleHandle),
                    FuncHandle(val) => (TVMValue { v_handle: *val }, TVMTypeCode_kTVMPackedFuncHandle),
                    NDArrayContainer(val) => (TVMValue { v_handle: *val }, TVMTypeCode_kTVMNDArrayHandle),
                    $( $self_type($val) => { $from_self_type } ),+
                }
            }
        }
    }
}

TVMPODValue! {
    /// A borrowed TVMPODValue. Can be constructed using `into()` but the preferred way
    /// to obtain a `TVMArgValue` is automatically via `call_packed!`.
    TVMArgValue<'a> {
        Bytes(&'a TVMByteArray),
        Str(&'a CStr),
    },
    match value {
        // NOTE(review): the lifetime `'a` of these borrowed payloads is not tied to
        // anything at decode time; callers of `from_tvm_value` must ensure the pointees
        // really outlive it.
        TVMTypeCode_kTVMBytes => { Bytes(&*(value.v_handle as *const TVMByteArray)) }
        // `c_char` is `i8` on some targets and `u8` on others (e.g. aarch64), so the
        // pointer is cast to `c_char` rather than a hard-coded `i8`.
        TVMTypeCode_kTVMStr => { Str(CStr::from_ptr(value.v_handle as *const std::os::raw::c_char)) }
    },
    match &self {
        // `val` is a `&&TVMByteArray` here. Deref once so the handle is the address of the
        // byte array itself, not the address of the reference stored inside `self`.
        Bytes(val) => {
            (TVMValue { v_handle: *val as *const TVMByteArray as *mut c_void }, TVMTypeCode_kTVMBytes)
        }
        Str(val) => { (TVMValue { v_handle: val.as_ptr() as *mut c_void }, TVMTypeCode_kTVMStr) }
    }
}

TVMPODValue! {
    /// An owned TVMPODValue. Can be converted from a variety of primitive and object types.
    /// Can be downcasted using `try_from` if it contains the desired type.
    ///
    /// # Example
    ///
    /// ```
    /// use std::convert::{TryFrom, TryInto};
    /// use tvm_common::TVMRetValue;
    ///
    /// let a = 42u32;
    /// let b: u32 = tvm_common::TVMRetValue::from(a).try_into().unwrap();
    /// assert_eq!(b, a);
    ///
    /// let s = "hello, world!";
    /// let t: TVMRetValue = s.to_string().into();
    /// assert_eq!(String::try_from(t).unwrap(), s);
    /// ```
    TVMRetValue {
        Bytes(TVMByteArray),
        // NOTE(review): `Str` borrows memory produced by whoever filled in the raw
        // `TVMValue`; the `'static` lifetime is asserted here, not verified.
        Str(&'static CStr),
    },
    match value {
        TVMTypeCode_kTVMBytes => { Bytes(*(value.v_handle as *const TVMByteArray)) }
        // `CStr::from_ptr` takes `*const c_char`, which is `u8`-based on some targets, so
        // a hard-coded `i8` pointer would not compile there.
        TVMTypeCode_kTVMStr => { Str(CStr::from_ptr(value.v_handle as *const std::os::raw::c_char)) }
    },
    match &self {
        Bytes(val) =>
            { (TVMValue { v_handle: val as *const _ as *mut c_void }, TVMTypeCode_kTVMBytes ) }
        Str(val) =>
            { (TVMValue { v_str: val.as_ptr() }, TVMTypeCode_kTVMStr ) }
    }
}

#[macro_export]
198 199 200 201 202 203 204 205 206
macro_rules! try_downcast {
    ($val:ident -> $into:ty, $( |$pat:pat| { $converter:expr } ),+ ) => {
        match $val {
            $( $pat => { Ok($converter) } )+
            _ => Err($crate::errors::ValueDowncastError {
                actual_type: format!("{:?}", $val),
                expected_type: stringify!($into),
            }),
        }
207 208 209 210
    };
}

/// Generates the conversions between primitive Rust types and one POD variant of
/// `TVMArgValue` / `TVMRetValue`.
///
/// For every type `$prim` in the bracketed list this emits:
///
/// * `From<$prim>` and `From<&$prim>` for `TVMArgValue`, and `From<$prim>` for
///   `TVMRetValue`, storing the value (cast to `$payload`) in the `$tag` variant;
/// * `TryFrom<TVMArgValue>`, `TryFrom<&TVMArgValue>` and `TryFrom<TVMRetValue>` for
///   `$prim`, succeeding only when the value is the `$tag` variant.
///
/// Both directions use `as` casts, so narrowing conversions (e.g. an `Int` holding 300
/// extracted as `i8`) truncate silently rather than failing.
macro_rules! impl_pod_value {
    ($tag:ident, $payload:ty, [ $( $prim:ty ),+ ] ) => {
        $(
            // Construction: primitive -> TVM value.
            impl<'a> From<$prim> for TVMArgValue<'a> {
                fn from(val: $prim) -> Self {
                    TVMArgValue::$tag(val as $payload)
                }
            }

            impl<'a, 'v> From<&'a $prim> for TVMArgValue<'v> {
                fn from(val: &'a $prim) -> Self {
                    let copied: $prim = *val;
                    TVMArgValue::$tag(copied as $payload)
                }
            }

            impl From<$prim> for TVMRetValue {
                fn from(val: $prim) -> Self {
                    TVMRetValue::$tag(val as $payload)
                }
            }

            // Extraction: TVM value -> primitive.
            impl<'a> TryFrom<TVMArgValue<'a>> for $prim {
                type Error = $crate::errors::ValueDowncastError;
                fn try_from(val: TVMArgValue<'a>) -> Result<Self, Self::Error> {
                    try_downcast!(val -> $prim, |TVMArgValue::$tag(inner)| { inner as $prim })
                }
            }

            impl<'a, 'v> TryFrom<&'a TVMArgValue<'v>> for $prim {
                type Error = $crate::errors::ValueDowncastError;
                fn try_from(val: &'a TVMArgValue<'v>) -> Result<Self, Self::Error> {
                    try_downcast!(val -> $prim, |TVMArgValue::$tag(inner)| { *inner as $prim })
                }
            }

            impl TryFrom<TVMRetValue> for $prim {
                type Error = $crate::errors::ValueDowncastError;
                fn try_from(val: TVMRetValue) -> Result<Self, Self::Error> {
                    try_downcast!(val -> $prim, |TVMRetValue::$tag(inner)| { inner as $prim })
                }
            }
        )+
    };
}

impl_pod_value!(Int, i64, [i8, i16, i32, i64, isize]);
impl_pod_value!(UInt, i64, [u8, u16, u32, u64, usize]);
impl_pod_value!(Float, f64, [f32, f64]);
259
impl_pod_value!(DataType, DLDataType, [DLDataType]);
260
impl_pod_value!(Context, TVMContext, [TVMContext]);
261

262 263 264
impl<'a> From<&'a str> for TVMArgValue<'a> {
    fn from(s: &'a str) -> Self {
        Self::String(CString::new(s).unwrap())
265 266 267
    }
}

268 269 270 271 272 273
impl<'a> From<String> for TVMArgValue<'a> {
    fn from(s: String) -> Self {
        Self::String(CString::new(s).unwrap())
    }
}

274 275 276
impl<'a> From<&'a CStr> for TVMArgValue<'a> {
    fn from(s: &'a CStr) -> Self {
        Self::Str(s)
277 278 279
    }
}

280 281 282 283 284 285
impl<'a> From<&'a TVMByteArray> for TVMArgValue<'a> {
    fn from(s: &'a TVMByteArray) -> Self {
        Self::Bytes(s)
    }
}

286 287 288 289
impl<'a> TryFrom<TVMArgValue<'a>> for &'a str {
    type Error = ValueDowncastError;
    fn try_from(val: TVMArgValue<'a>) -> Result<Self, Self::Error> {
        try_downcast!(val -> &str, |TVMArgValue::Str(s)| { s.to_str().unwrap() })
290 291 292
    }
}

293 294 295 296
impl<'a, 'v> TryFrom<&'a TVMArgValue<'v>> for &'v str {
    type Error = ValueDowncastError;
    fn try_from(val: &'a TVMArgValue<'v>) -> Result<Self, Self::Error> {
        try_downcast!(val -> &str, |TVMArgValue::Str(s)| { s.to_str().unwrap() })
297 298 299
    }
}

300 301
/// Converts an unspecialized handle to a TVMArgValue.
impl<T> From<*const T> for TVMArgValue<'static> {
302
    fn from(ptr: *const T) -> Self {
303
        Self::Handle(ptr as *mut c_void)
304 305 306
    }
}

307 308
/// Converts an unspecialized mutable handle to a TVMArgValue.
impl<T> From<*mut T> for TVMArgValue<'static> {
309
    fn from(ptr: *mut T) -> Self {
310
        Self::Handle(ptr as *mut c_void)
311 312 313 314 315
    }
}

impl<'a> From<&'a mut DLTensor> for TVMArgValue<'a> {
    fn from(arr: &'a mut DLTensor) -> Self {
316
        Self::ArrayHandle(arr as *mut DLTensor)
317 318 319 320 321
    }
}

impl<'a> From<&'a DLTensor> for TVMArgValue<'a> {
    fn from(arr: &'a DLTensor) -> Self {
322
        Self::ArrayHandle(arr as *const _ as *mut DLTensor)
323 324 325
    }
}

326 327 328 329 330 331 332 333
impl TryFrom<TVMRetValue> for String {
    type Error = ValueDowncastError;
    fn try_from(val: TVMRetValue) -> Result<String, Self::Error> {
        try_downcast!(
            val -> String,
            |TVMRetValue::String(s)| { s.into_string().unwrap() },
            |TVMRetValue::Str(s)| { s.to_str().unwrap().to_string() }
        )
334 335 336
    }
}

337 338 339
impl From<String> for TVMRetValue {
    fn from(s: String) -> Self {
        Self::String(std::ffi::CString::new(s).unwrap())
340 341 342
    }
}

343 344 345
impl From<TVMByteArray> for TVMRetValue {
    fn from(arr: TVMByteArray) -> Self {
        Self::Bytes(arr)
346 347 348
    }
}

349 350 351 352
impl TryFrom<TVMRetValue> for TVMByteArray {
    type Error = ValueDowncastError;
    fn try_from(val: TVMRetValue) -> Result<Self, Self::Error> {
        try_downcast!(val -> TVMByteArray, |TVMRetValue::Bytes(val)| { val })
353 354 355
    }
}

356 357 358
impl Default for TVMRetValue {
    fn default() -> Self {
        Self::Int(0)
359 360
    }
}