value.rs 3.76 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

20 21 22 23
//! This module implements the conversions between [`TVMArgValue`] / [`TVMRetValue`]
//! and the handle-backed types of the frontend crate ([`Function`], [`Module`] and [`NDArray`]).
//! `TVMRetValue` is the owned version of `TVMPODValue`.

24
use std::convert::TryFrom;
25 26

use tvm_common::{
27 28 29
    errors::ValueDowncastError,
    ffi::{TVMArrayHandle, TVMFunctionHandle, TVMModuleHandle},
    try_downcast,
30
};
31

32
use crate::{Function, Module, NDArray, TVMArgValue, TVMRetValue};
33

34 35 36 37 38
/// Implements the conversions between a handle-backed frontend type and the
/// FFI value enums [`TVMArgValue`] and [`TVMRetValue`].
///
/// Arguments:
/// * `$type` – the frontend type; it must expose a `handle()` method whose
///   result can be `as`-cast to `$inner_type`.
/// * `$variant` – the variant name, shared by `TVMArgValue` and `TVMRetValue`,
///   that carries the raw handle.
/// * `$inner_type` – the raw handle type stored in that variant.
/// * `$ctor` – path of a function that rebuilds a `$type` from a `$inner_type`.
///
/// The generated `TryFrom` impls report failures as [`ValueDowncastError`]
/// through `try_downcast!`.
macro_rules! impl_handle_val {
    ($type:ty, $variant:ident, $inner_type:ty, $ctor:path) => {
        // Borrowing conversion: pass the value's raw handle as an argument.
        impl<'a> From<&'a $type> for TVMArgValue<'a> {
            fn from(arg: &'a $type) -> Self {
                TVMArgValue::$variant(arg.handle() as $inner_type)
            }
        }

        // A mutable borrow converts exactly like a shared one; delegate so the
        // two impls cannot drift apart.
        impl<'a> From<&'a mut $type> for TVMArgValue<'a> {
            fn from(arg: &'a mut $type) -> Self {
                TVMArgValue::from(&*arg)
            }
        }

        // By-value downcast: rebuild `$type` from the handle held in `$variant`.
        impl<'a> TryFrom<TVMArgValue<'a>> for $type {
            type Error = ValueDowncastError;
            fn try_from(val: TVMArgValue<'a>) -> Result<$type, Self::Error> {
                try_downcast!(val -> $type, |TVMArgValue::$variant(val)| { $ctor(val) })
            }
        }

        // By-reference downcast; the handle is copied out of the borrowed value.
        impl<'a, 'v> TryFrom<&'a TVMArgValue<'v>> for $type {
            type Error = ValueDowncastError;
            fn try_from(val: &'a TVMArgValue<'v>) -> Result<$type, Self::Error> {
                try_downcast!(val -> $type, |TVMArgValue::$variant(val)| { $ctor(*val) })
            }
        }

        impl From<$type> for TVMRetValue {
            fn from(val: $type) -> TVMRetValue {
                // NOTE(review): `val` is dropped when this function returns, while
                // the returned value keeps its raw handle. If `$type` has a `Drop`
                // impl that releases the handle, the result would refer to a freed
                // object — confirm against the `Drop` impls of the invoking types.
                TVMRetValue::$variant(val.handle() as $inner_type)
            }
        }

        impl TryFrom<TVMRetValue> for $type {
            type Error = ValueDowncastError;
            fn try_from(val: TVMRetValue) -> Result<$type, Self::Error> {
                try_downcast!(val -> $type, |TVMRetValue::$variant(val)| { $ctor(val) })
            }
        }
    };
}

77 78 79
// One conversion set per handle-backed frontend type. Each entry gives the
// type, the variant holding its raw handle, the FFI handle type, and the
// constructor that rebuilds the type from that handle.
impl_handle_val! { Function, FuncHandle, TVMFunctionHandle, Function::new }
impl_handle_val! { Module, ModuleHandle, TVMModuleHandle, Module::new }
impl_handle_val! { NDArray, ArrayHandle, TVMArrayHandle, NDArray::new }
80 81 82

#[cfg(test)]
mod tests {
    use std::{convert::TryInto, str::FromStr};

    use tvm_common::{TVMByteArray, TVMContext, TVMType};

    use super::*;

    /// A byte array round-trips through `TVMRetValue` with its contents intact.
    #[test]
    fn bytearray() {
        let w = vec![1u8, 2, 3, 4, 5];
        let v = TVMByteArray::from(w.as_slice());
        let tvm: TVMByteArray = TVMRetValue::from(v).try_into().unwrap();
        // `w` already holds the expected bytes; compare against its slice
        // directly rather than building an identical copy.
        assert_eq!(tvm.data(), w.as_slice());
    }

    /// A `TVMType` parsed from a string round-trips through `TVMRetValue`.
    #[test]
    fn ty() {
        let t = TVMType::from_str("int32").unwrap();
        let tvm: TVMType = TVMRetValue::from(t).try_into().unwrap();
        assert_eq!(tvm, t);
    }

    /// A `TVMContext` parsed from a string round-trips through `TVMRetValue`.
    #[test]
    fn ctx() {
        let c = TVMContext::from_str("gpu").unwrap();
        let tvm: TVMContext = TVMRetValue::from(c).try_into().unwrap();
        assert_eq!(tvm, c);
    }
}