/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ //! This module implements [`TVMArgValue`] and [`TVMRetValue`] types //! and their conversions needed for the types used in frontend crate. //! `TVMRetValue` is the owned version of `TVMPODValue`. use std::convert::TryFrom; use tvm_common::{ errors::ValueDowncastError, ffi::{TVMArrayHandle, TVMFunctionHandle, TVMModuleHandle}, try_downcast, }; use crate::{Function, Module, NDArray, TVMArgValue, TVMRetValue}; macro_rules! impl_handle_val { ($type:ty, $variant:ident, $inner_type:ty, $ctor:path) => { impl<'a> From<&'a $type> for TVMArgValue<'a> { fn from(arg: &'a $type) -> Self { TVMArgValue::$variant(arg.handle() as $inner_type) } } impl<'a> From<&'a mut $type> for TVMArgValue<'a> { fn from(arg: &'a mut $type) -> Self { TVMArgValue::$variant(arg.handle() as $inner_type) } } impl<'a> TryFrom<TVMArgValue<'a>> for $type { type Error = ValueDowncastError; fn try_from(val: TVMArgValue<'a>) -> Result<$type, Self::Error> { try_downcast!(val -> $type, |TVMArgValue::$variant(val)| { $ctor(val) }) } } impl<'a, 'v> TryFrom<&'a TVMArgValue<'v>> for $type { type Error = ValueDowncastError; fn try_from(val: &'a TVMArgValue<'v>) -> Result<$type, Self::Error> { try_downcast!(val -> $type, |TVMArgValue::$variant(val)| { $ctor(*val) }) } } impl From<$type> for TVMRetValue { fn from(val: $type) -> TVMRetValue { TVMRetValue::$variant(val.handle() as $inner_type) } } impl TryFrom<TVMRetValue> for $type { type Error = ValueDowncastError; fn try_from(val: TVMRetValue) -> Result<$type, Self::Error> { try_downcast!(val -> $type, |TVMRetValue::$variant(val)| { $ctor(val) }) } } }; } impl_handle_val!(Function, FuncHandle, TVMFunctionHandle, Function::new); impl_handle_val!(Module, ModuleHandle, TVMModuleHandle, Module::new); impl_handle_val!(NDArray, ArrayHandle, TVMArrayHandle, NDArray::new); #[cfg(test)] mod tests { use std::{convert::TryInto, str::FromStr}; use tvm_common::{TVMByteArray, TVMContext, TVMType}; use super::*; #[test] fn bytearray() { let w = vec![1u8, 2, 3, 4, 5]; let v = TVMByteArray::from(w.as_slice()); let tvm: TVMByteArray = TVMRetValue::from(v).try_into().unwrap(); assert_eq!( tvm.data(), w.iter().map(|e| *e).collect::<Vec<u8>>().as_slice() ); } #[test] fn ty() { let t = TVMType::from_str("int32").unwrap(); let tvm: TVMType = TVMRetValue::from(t).try_into().unwrap(); assert_eq!(tvm, t); } #[test] fn ctx() { let c = TVMContext::from_str("gpu").unwrap(); let tvm: TVMContext = TVMRetValue::from(c).try_into().unwrap(); assert_eq!(tvm, c); } }