Commit 14a0ecba by Nick Hynes Committed by Tianqi Chen

Rustify PackedFunc & Friends (#2969)

parent 0708c48d
......@@ -3,6 +3,7 @@ name = "tvm-common"
version = "0.1.0"
authors = ["TVM Contributors"]
license = "Apache-2.0"
edition = "2018"
[features]
bindings = []
......
use std::fmt;
static TYPE_CODE_STRS: [&str; 15] = [
"int",
"uint",
"float",
"handle",
"null",
"TVMType",
"TVMContext",
"ArrayHandle",
"NodeHandle",
"ModuleHandle",
"FuncHandle",
"str",
"bytes",
"NDArrayContainer",
"ExtBegin",
];
#[derive(Debug, Fail)]
#[fail(
display = "Could not downcast `{}` into `{}`",
expected_type, actual_type
)]
pub struct ValueDowncastError {
actual_type_code: i64,
expected_type_code: i64,
}
impl ValueDowncastError {
pub fn new(actual_type_code: i64, expected_type_code: i64) -> Self {
Self {
actual_type_code,
expected_type_code,
}
}
}
impl fmt::Display for ValueDowncastError {
fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
write!(
formatter,
"Could not downcast TVMValue: expected `{}` but was {}",
TYPE_CODE_STRS[self.actual_type_code as usize],
TYPE_CODE_STRS[self.expected_type_code as usize]
)
}
pub actual_type: String,
pub expected_type: &'static str,
}
#[derive(Debug, Fail)]
......@@ -62,18 +26,3 @@ impl FuncCallError {
}
}
}
// error_chain! {
// errors {
// TryFromTVMRetValueError(expected_type: String, actual_type_code: i64) {
// description("mismatched types while downcasting TVMRetValue")
// display("invalid downcast: expected `{}` but was `{}`",
// expected_type, type_code_to_string(actual_type_code))
// }
// }
// foreign_links {
// IntoString(std::ffi::IntoStringError);
// ParseInt(std::num::ParseIntError);
// Utf8(std::str::Utf8Error);
// }
// }
//! This crate contains the refactored basic components required
//! for `runtime` and `frontend` TVM crates.
#![feature(box_syntax, trait_alias)]
#![feature(box_syntax, type_alias_enum_variants, trait_alias)]
#[macro_use]
extern crate failure;
......@@ -25,5 +25,5 @@ pub mod packed_func;
pub mod value;
pub use errors::*;
pub use ffi::{TVMContext, TVMType};
pub use ffi::{TVMByteArray, TVMContext, TVMType};
pub use packed_func::{TVMArgValue, TVMRetValue};
use std::{any::Any, convert::TryFrom, marker::PhantomData, os::raw::c_void};
use failure::Error;
use std::{
convert::TryFrom,
ffi::{CStr, CString},
os::raw::c_void,
};
pub use crate::ffi::TVMValue;
use crate::ffi::*;
use crate::{errors::ValueDowncastError, ffi::*};
pub trait PackedFunc =
Fn(&[TVMArgValue]) -> Result<TVMRetValue, crate::errors::FuncCallError> + Send + Sync;
......@@ -15,298 +17,308 @@ pub trait PackedFunc =
/// `call_packed!(my_tvm_func, &mut arg1, &mut arg2)`
#[macro_export]
macro_rules! call_packed {
($fn:expr, $($args:expr),+) => {
$fn(&[$($args.into(),)+])
};
($fn:expr) => {
$fn(&Vec::new())
};
($fn:expr, $($args:expr),+) => {
$fn(&[$($args.into(),)+])
};
($fn:expr) => {
$fn(&Vec::new())
};
}
/// A borrowed TVMPODValue. Can be constructed using `into()` but the preferred way
/// to obtain a `TVMArgValue` is automatically via `call_packed!`.
#[derive(Clone, Copy)]
pub struct TVMArgValue<'a> {
pub _lifetime: PhantomData<&'a ()>,
pub value: TVMValue,
pub type_code: i64,
/// Constructs a derivative of a TVMPodValue.
macro_rules! TVMPODValue {
{
$(#[$m:meta])+
$name:ident $(<$a:lifetime>)? {
$($extra_variant:ident ( $variant_type:ty ) ),+ $(,)?
},
match $value:ident {
$($tvm_type:ident => { $from_tvm_type:expr })+
},
match &self {
$($self_type:ident ( $val:ident ) => { $from_self_type:expr })+
}
$(,)?
} => {
$(#[$m])+
#[derive(Clone, Debug)]
pub enum $name $(<$a>)? {
Int(i64),
UInt(i64),
Float(f64),
Null,
Type(TVMType),
String(CString),
Context(TVMContext),
Handle(*mut c_void),
ArrayHandle(TVMArrayHandle),
NodeHandle(*mut c_void),
ModuleHandle(TVMModuleHandle),
FuncHandle(TVMFunctionHandle),
NDArrayContainer(*mut c_void),
$($extra_variant($variant_type)),+
}
impl $(<$a>)? $name $(<$a>)? {
pub fn from_tvm_value($value: TVMValue, type_code: u32) -> Self {
use $name::*;
#[allow(non_upper_case_globals)]
unsafe {
match type_code {
DLDataTypeCode_kDLInt => Int($value.v_int64),
DLDataTypeCode_kDLUInt => UInt($value.v_int64),
DLDataTypeCode_kDLFloat => Float($value.v_float64),
TVMTypeCode_kNull => Null,
TVMTypeCode_kTVMType => Type($value.v_type),
TVMTypeCode_kTVMContext => Context($value.v_ctx),
TVMTypeCode_kHandle => Handle($value.v_handle),
TVMTypeCode_kArrayHandle => ArrayHandle($value.v_handle as TVMArrayHandle),
TVMTypeCode_kNodeHandle => NodeHandle($value.v_handle),
TVMTypeCode_kModuleHandle => ModuleHandle($value.v_handle),
TVMTypeCode_kFuncHandle => FuncHandle($value.v_handle),
TVMTypeCode_kNDArrayContainer => NDArrayContainer($value.v_handle),
$( $tvm_type => { $from_tvm_type } ),+
_ => unimplemented!("{}", type_code),
}
}
}
pub fn to_tvm_value(&self) -> (TVMValue, TVMTypeCode) {
use $name::*;
match self {
Int(val) => (TVMValue { v_int64: *val }, DLDataTypeCode_kDLInt),
UInt(val) => (TVMValue { v_int64: *val as i64 }, DLDataTypeCode_kDLUInt),
Float(val) => (TVMValue { v_float64: *val }, DLDataTypeCode_kDLFloat),
Null => (TVMValue{ v_int64: 0 },TVMTypeCode_kNull),
Type(val) => (TVMValue { v_type: *val }, TVMTypeCode_kTVMType),
Context(val) => (TVMValue { v_ctx: val.clone() }, TVMTypeCode_kTVMContext),
String(val) => {
(
TVMValue { v_handle: val.as_ptr() as *mut c_void },
TVMTypeCode_kStr,
)
}
Handle(val) => (TVMValue { v_handle: *val }, TVMTypeCode_kHandle),
ArrayHandle(val) => {
(
TVMValue { v_handle: *val as *const _ as *mut c_void },
TVMTypeCode_kArrayHandle,
)
},
NodeHandle(val) => (TVMValue { v_handle: *val }, TVMTypeCode_kNodeHandle),
ModuleHandle(val) =>
(TVMValue { v_handle: *val }, TVMTypeCode_kModuleHandle),
FuncHandle(val) => (
TVMValue { v_handle: *val },
TVMTypeCode_kFuncHandle
),
NDArrayContainer(val) =>
(TVMValue { v_handle: *val }, TVMTypeCode_kNDArrayContainer),
$( $self_type($val) => { $from_self_type } ),+
}
}
}
}
}
impl<'a> TVMArgValue<'a> {
pub fn new(value: TVMValue, type_code: i64) -> Self {
TVMArgValue {
_lifetime: PhantomData,
value: value,
type_code: type_code,
TVMPODValue! {
/// A borrowed TVMPODValue. Can be constructed using `into()` but the preferred way
/// to obtain a `TVMArgValue` is automatically via `call_packed!`.
TVMArgValue<'a> {
Bytes(&'a TVMByteArray),
Str(&'a CStr),
},
match value {
TVMTypeCode_kBytes => { Bytes(&*(value.v_handle as *const TVMByteArray)) }
TVMTypeCode_kStr => { Str(CStr::from_ptr(value.v_handle as *const i8)) }
},
match &self {
Bytes(val) => {
(TVMValue { v_handle: val.clone() as *const _ as *mut c_void }, TVMTypeCode_kBytes)
}
Str(val) => { (TVMValue { v_handle: val.as_ptr() as *mut c_void }, TVMTypeCode_kStr)}
}
}
TVMPODValue! {
/// An owned TVMPODValue. Can be converted from a variety of primitive and object types.
/// Can be downcasted using `try_from` if it contains the desired type.
///
/// # Example
///
/// ```
/// let a = 42u32;
/// let b: i64 = TVMRetValue::from(a).try_into().unwrap();
///
/// let s = "hello, world!";
/// let t: TVMRetValue = s.into();
/// assert_eq!(String::try_from(t).unwrap(), s);
/// ```
TVMRetValue {
Bytes(TVMByteArray),
Str(&'static CStr),
},
match value {
TVMTypeCode_kBytes => { Bytes(*(value.v_handle as *const TVMByteArray)) }
TVMTypeCode_kStr => { Str(CStr::from_ptr(value.v_handle as *mut i8)) }
},
match &self {
Bytes(val) =>
{ (TVMValue { v_handle: val as *const _ as *mut c_void }, TVMTypeCode_kBytes ) }
Str(val) =>
{ (TVMValue { v_str: val.as_ptr() }, TVMTypeCode_kStr ) }
}
}
#[macro_export]
macro_rules! ensure_type {
($val:ident, $expected_type_code:expr) => {
ensure!(
$val.type_code == $expected_type_code as i64,
$crate::errors::ValueDowncastError::new(
$val.type_code as i64,
$expected_type_code as i64
)
);
macro_rules! try_downcast {
($val:ident -> $into:ty, $( |$pat:pat| { $converter:expr } ),+ ) => {
match $val {
$( $pat => { Ok($converter) } )+
_ => Err($crate::errors::ValueDowncastError {
actual_type: format!("{:?}", $val),
expected_type: stringify!($into),
}),
}
};
}
/// Creates a conversion to a `TVMArgValue` for a primitive type and DLDataTypeCode.
macro_rules! impl_prim_tvm_arg {
($type_code:ident, $field:ident, $field_type:ty, [ $( $type:ty ),+ ] ) => {
macro_rules! impl_pod_value {
($variant:ident, $inner_ty:ty, [ $( $type:ty ),+ ] ) => {
$(
impl From<$type> for TVMArgValue<'static> {
impl<'a> From<$type> for TVMArgValue<'a> {
fn from(val: $type) -> Self {
TVMArgValue {
value: TVMValue { $field: val as $field_type },
type_code: $type_code as i64,
_lifetime: PhantomData,
}
Self::$variant(val as $inner_ty)
}
}
impl<'a> From<&'a $type> for TVMArgValue<'a> {
impl<'a, 'v> From<&'a $type> for TVMArgValue<'v> {
fn from(val: &'a $type) -> Self {
TVMArgValue {
value: TVMValue {
$field: val.to_owned() as $field_type,
},
type_code: $type_code as i64,
_lifetime: PhantomData,
}
Self::$variant(*val as $inner_ty)
}
}
impl<'a> TryFrom<TVMArgValue<'a>> for $type {
type Error = Error;
type Error = $crate::errors::ValueDowncastError;
fn try_from(val: TVMArgValue<'a>) -> Result<Self, Self::Error> {
ensure_type!(val, $type_code);
Ok(unsafe { val.value.$field as $type })
try_downcast!(val -> $type, |TVMArgValue::$variant(val)| { val as $type })
}
}
impl<'a> TryFrom<&TVMArgValue<'a>> for $type {
type Error = Error;
fn try_from(val: &TVMArgValue<'a>) -> Result<Self, Self::Error> {
ensure_type!(val, $type_code);
Ok(unsafe { val.value.$field as $type })
impl<'a, 'v> TryFrom<&'a TVMArgValue<'v>> for $type {
type Error = $crate::errors::ValueDowncastError;
fn try_from(val: &'a TVMArgValue<'v>) -> Result<Self, Self::Error> {
try_downcast!(val -> $type, |TVMArgValue::$variant(val)| { *val as $type })
}
}
impl From<$type> for TVMRetValue {
fn from(val: $type) -> Self {
Self::$variant(val as $inner_ty)
}
}
impl TryFrom<TVMRetValue> for $type {
type Error = $crate::errors::ValueDowncastError;
fn try_from(val: TVMRetValue) -> Result<Self, Self::Error> {
try_downcast!(val -> $type, |TVMRetValue::$variant(val)| { val as $type })
}
}
)+
};
}
impl_prim_tvm_arg!(DLDataTypeCode_kDLFloat, v_float64, f64, [f32, f64]);
impl_prim_tvm_arg!(
DLDataTypeCode_kDLInt,
v_int64,
i64,
[i8, i16, i32, i64, isize]
);
impl_prim_tvm_arg!(
DLDataTypeCode_kDLUInt,
v_int64,
i64,
[u8, u16, u32, u64, usize]
);
impl_pod_value!(Int, i64, [i8, i16, i32, i64, isize]);
impl_pod_value!(UInt, i64, [u8, u16, u32, u64, usize]);
impl_pod_value!(Float, f64, [f32, f64]);
impl_pod_value!(Type, TVMType, [TVMType]);
impl_pod_value!(Context, TVMContext, [TVMContext]);
#[cfg(feature = "bindings")]
// only allow this in bindings because pure-rust can't take ownership of leaked CString
impl<'a> From<&String> for TVMArgValue<'a> {
fn from(string: &String) -> Self {
TVMArgValue {
value: TVMValue {
v_str: std::ffi::CString::new(string.clone()).unwrap().into_raw(),
},
type_code: TVMTypeCode_kStr as i64,
_lifetime: PhantomData,
}
impl<'a> From<&'a str> for TVMArgValue<'a> {
fn from(s: &'a str) -> Self {
Self::String(CString::new(s).unwrap())
}
}
impl<'a> From<&std::ffi::CString> for TVMArgValue<'a> {
fn from(string: &std::ffi::CString) -> Self {
TVMArgValue {
value: TVMValue {
v_str: string.as_ptr(),
},
type_code: TVMTypeCode_kStr as i64,
_lifetime: PhantomData,
}
impl<'a> From<&'a CStr> for TVMArgValue<'a> {
fn from(s: &'a CStr) -> Self {
Self::Str(s)
}
}
impl<'a> TryFrom<TVMArgValue<'a>> for &str {
type Error = Error;
fn try_from(arg: TVMArgValue<'a>) -> Result<Self, Self::Error> {
ensure_type!(arg, TVMTypeCode_kStr);
Ok(unsafe { std::ffi::CStr::from_ptr(arg.value.v_handle as *const i8) }.to_str()?)
impl<'a> TryFrom<TVMArgValue<'a>> for &'a str {
type Error = ValueDowncastError;
fn try_from(val: TVMArgValue<'a>) -> Result<Self, Self::Error> {
try_downcast!(val -> &str, |TVMArgValue::Str(s)| { s.to_str().unwrap() })
}
}
impl<'a> TryFrom<&TVMArgValue<'a>> for &str {
type Error = Error;
fn try_from(arg: &TVMArgValue<'a>) -> Result<Self, Self::Error> {
ensure_type!(arg, TVMTypeCode_kStr);
Ok(unsafe { std::ffi::CStr::from_ptr(arg.value.v_handle as *const i8) }.to_str()?)
impl<'a, 'v> TryFrom<&'a TVMArgValue<'v>> for &'v str {
type Error = ValueDowncastError;
fn try_from(val: &'a TVMArgValue<'v>) -> Result<Self, Self::Error> {
try_downcast!(val -> &str, |TVMArgValue::Str(s)| { s.to_str().unwrap() })
}
}
/// Creates a conversion to a `TVMArgValue` for an object handle.
impl<'a, T> From<*const T> for TVMArgValue<'a> {
/// Converts an unspecialized handle to a TVMArgValue.
impl<T> From<*const T> for TVMArgValue<'static> {
fn from(ptr: *const T) -> Self {
TVMArgValue {
value: TVMValue {
v_handle: ptr as *mut T as *mut c_void,
},
type_code: TVMTypeCode_kArrayHandle as i64,
_lifetime: PhantomData,
}
Self::Handle(ptr as *mut c_void)
}
}
/// Creates a conversion to a `TVMArgValue` for a mutable object handle.
impl<'a, T> From<*mut T> for TVMArgValue<'a> {
/// Converts an unspecialized mutable handle to a TVMArgValue.
impl<T> From<*mut T> for TVMArgValue<'static> {
fn from(ptr: *mut T) -> Self {
TVMArgValue {
value: TVMValue {
v_handle: ptr as *mut c_void,
},
type_code: TVMTypeCode_kHandle as i64,
_lifetime: PhantomData,
}
Self::Handle(ptr as *mut c_void)
}
}
impl<'a> From<&'a mut DLTensor> for TVMArgValue<'a> {
fn from(arr: &'a mut DLTensor) -> Self {
TVMArgValue {
value: TVMValue {
v_handle: arr as *mut _ as *mut c_void,
},
type_code: TVMTypeCode_kArrayHandle as i64,
_lifetime: PhantomData,
}
Self::ArrayHandle(arr as *mut DLTensor)
}
}
impl<'a> From<&'a DLTensor> for TVMArgValue<'a> {
fn from(arr: &'a DLTensor) -> Self {
TVMArgValue {
value: TVMValue {
v_handle: arr as *const _ as *mut DLTensor as *mut c_void,
},
type_code: TVMTypeCode_kArrayHandle as i64,
_lifetime: PhantomData,
}
Self::ArrayHandle(arr as *const _ as *mut DLTensor)
}
}
impl<'a, 'v> TryFrom<&'a TVMArgValue<'v>> for TVMType {
type Error = Error;
fn try_from(arg: &'a TVMArgValue<'v>) -> Result<Self, Self::Error> {
ensure_type!(arg, TVMTypeCode_kTVMType);
Ok(unsafe { arg.value.v_type.into() })
impl TryFrom<TVMRetValue> for String {
type Error = ValueDowncastError;
fn try_from(val: TVMRetValue) -> Result<String, Self::Error> {
try_downcast!(
val -> String,
|TVMRetValue::String(s)| { s.into_string().unwrap() },
|TVMRetValue::Str(s)| { s.to_str().unwrap().to_string() }
)
}
}
/// An owned TVMPODValue. Can be converted from a variety of primitive and object types.
/// Can be downcasted using `try_from` if it contains the desired type.
///
/// # Example
///
/// ```
/// let a = 42u32;
/// let b: i64 = TVMRetValue::from(a).try_into().unwrap();
///
/// let s = "hello, world!";
/// let t: TVMRetValue = s.into();
/// assert_eq!(String::try_from(t).unwrap(), s);
/// ```
pub struct TVMRetValue {
pub value: TVMValue,
pub box_value: Box<Any>,
pub type_code: i64,
}
impl TVMRetValue {
pub fn from_tvm_value(value: TVMValue, type_code: i64) -> Self {
Self {
value,
type_code,
box_value: box (),
}
}
pub fn into_tvm_value(self) -> (TVMValue, TVMTypeCode) {
(self.value, self.type_code as TVMTypeCode)
impl From<String> for TVMRetValue {
fn from(s: String) -> Self {
Self::String(std::ffi::CString::new(s).unwrap())
}
}
impl Default for TVMRetValue {
fn default() -> Self {
TVMRetValue {
value: TVMValue { v_int64: 0 as i64 },
type_code: 0,
box_value: box (),
}
impl From<TVMByteArray> for TVMRetValue {
fn from(arr: TVMByteArray) -> Self {
Self::Bytes(arr)
}
}
macro_rules! impl_pod_ret_value {
($code:expr, [ $( $ty:ty ),+ ] ) => {
$(
impl From<$ty> for TVMRetValue {
fn from(val: $ty) -> Self {
Self {
value: val.into(),
type_code: $code as i64,
box_value: box (),
}
}
}
impl TryFrom<TVMRetValue> for $ty {
type Error = Error;
fn try_from(ret: TVMRetValue) -> Result<$ty, Self::Error> {
ensure_type!(ret, $code);
Ok(ret.value.into())
}
}
)+
};
}
impl_pod_ret_value!(DLDataTypeCode_kDLInt, [i8, i16, i32, i64, isize]);
impl_pod_ret_value!(DLDataTypeCode_kDLUInt, [u8, u16, u32, u64, usize]);
impl_pod_ret_value!(DLDataTypeCode_kDLFloat, [f32, f64]);
impl_pod_ret_value!(TVMTypeCode_kTVMType, [TVMType]);
impl_pod_ret_value!(TVMTypeCode_kTVMContext, [TVMContext]);
impl TryFrom<TVMRetValue> for String {
type Error = Error;
fn try_from(ret: TVMRetValue) -> Result<String, Self::Error> {
ensure_type!(ret, TVMTypeCode_kStr);
let cs = unsafe { std::ffi::CString::from_raw(ret.value.v_handle as *mut i8) };
let ret_str = cs.clone().into_string();
if cfg!(feature = "bindings") {
std::mem::forget(cs); // TVM C++ takes ownership of CString. (@see TVMFuncCall)
}
Ok(ret_str?)
impl TryFrom<TVMRetValue> for TVMByteArray {
type Error = ValueDowncastError;
fn try_from(val: TVMRetValue) -> Result<Self, Self::Error> {
try_downcast!(val -> TVMByteArray, |TVMRetValue::Bytes(val)| { val })
}
}
impl From<String> for TVMRetValue {
fn from(s: String) -> Self {
let cs = std::ffi::CString::new(s).unwrap();
Self {
value: TVMValue {
v_str: cs.into_raw() as *mut i8,
},
box_value: box (),
type_code: TVMTypeCode_kStr as i64,
}
impl Default for TVMRetValue {
fn default() -> Self {
Self::Int(0)
}
}
......@@ -137,3 +137,18 @@ impl_tvm_context!(
DLDeviceType_kDLROCM: [rocm],
DLDeviceType_kDLExtDev: [ext_dev]
);
impl TVMByteArray {
pub fn data(&self) -> &'static [u8] {
unsafe { std::slice::from_raw_parts(self.data as *const u8, self.size) }
}
}
impl<'a> From<&'a [u8]> for TVMByteArray {
fn from(bytes: &[u8]) -> Self {
Self {
data: bytes.as_ptr() as *const i8,
size: bytes.len(),
}
}
}
......@@ -9,6 +9,7 @@ readme = "README.md"
keywords = ["rust", "tvm", "nnvm"]
categories = ["api-bindings", "science"]
authors = ["TVM Contributors"]
edition = "2018"
[lib]
name = "tvm_frontend"
......
......@@ -3,9 +3,9 @@
//!
//! For more detail, please see the example `resnet` in `examples` repository.
use std::os::raw::{c_char, c_void};
use std::os::raw::c_char;
use tvm_common::{ffi, TVMArgValue};
use tvm_common::ffi;
/// A struct holding TVM byte-array.
///
......@@ -44,8 +44,9 @@ impl TVMByteArray {
}
}
impl<'a> From<&'a Vec<u8>> for TVMByteArray {
fn from(arg: &Vec<u8>) -> Self {
impl<'a, T: AsRef<[u8]>> From<T> for TVMByteArray {
fn from(arg: T) -> Self {
let arg = arg.as_ref();
let barr = ffi::TVMByteArray {
data: arg.as_ptr() as *const c_char,
size: arg.len(),
......@@ -54,18 +55,6 @@ impl<'a> From<&'a Vec<u8>> for TVMByteArray {
}
}
impl<'a> From<&TVMByteArray> for TVMArgValue<'a> {
fn from(arr: &TVMByteArray) -> Self {
Self {
value: ffi::TVMValue {
v_handle: &arr.inner as *const ffi::TVMByteArray as *const c_void as *mut c_void,
},
type_code: ffi::TVMTypeCode_kBytes as i64,
_lifetime: std::marker::PhantomData,
}
}
}
#[cfg(test)]
mod tests {
use super::*;
......
......@@ -26,10 +26,7 @@ use std::{
use failure::Error;
use tvm_common::{
ffi::{self, TVMValue},
TVMArgValue,
};
use tvm_common::ffi;
use crate::function;
......@@ -125,18 +122,6 @@ impl<'a> From<&'a str> for TVMDeviceType {
}
}
impl<'a> From<&'a TVMDeviceType> for TVMArgValue<'a> {
fn from(dev_type: &'a TVMDeviceType) -> Self {
Self {
value: TVMValue {
v_int64: dev_type.0 as i64,
},
type_code: ffi::DLDataTypeCode_kDLInt as i64,
_lifetime: std::marker::PhantomData,
}
}
}
/// Represents the underlying device context. Default is cpu.
///
/// ## Examples
......@@ -209,7 +194,7 @@ impl TVMContext {
let dt = self.device_type.0 as usize;
// `unwrap` is ok here because if there is any error,
// if would occure inside `call_packed!`
let ret: u64 = call_packed!(func, &dt, &self.device_id, &0)
let ret: u64 = call_packed!(func, dt, self.device_id, 0)
.unwrap()
.try_into()
.unwrap();
......@@ -238,7 +223,9 @@ macro_rules! impl_device_attrs {
// `unwrap` is ok here because if there is any error,
// if would occur in function call.
function::Builder::from(func)
.args(&[dt, self.device_id as usize, $attr_kind])
.arg(dt)
.arg(self.device_id as usize)
.arg($attr_kind)
.invoke()
.unwrap()
.try_into()
......
......@@ -156,9 +156,9 @@ impl<'a, 'm> Builder<'a, 'm> {
}
/// Pushes a [`TVMArgValue`] into the function argument buffer.
pub fn arg<T: 'a>(&mut self, arg: &'a T) -> &mut Self
pub fn arg<T: 'a>(&mut self, arg: T) -> &mut Self
where
TVMArgValue<'a>: From<&'a T>,
TVMArgValue<'a>: From<T>,
{
self.arg_buf.push(arg.into());
self
......@@ -192,14 +192,11 @@ impl<'a, 'm> Builder<'a, 'm> {
ensure!(self.func.is_some(), errors::FunctionNotFoundError);
let num_args = self.arg_buf.len();
let (mut values, mut type_codes): (Vec<ffi::TVMValue>, Vec<ffi::TVMTypeCode>) = self
.arg_buf
.iter()
.map(|tvm_arg| (tvm_arg.value, tvm_arg.type_code as ffi::TVMTypeCode))
.unzip();
let (mut values, mut type_codes): (Vec<ffi::TVMValue>, Vec<ffi::TVMTypeCode>) =
self.arg_buf.iter().map(|arg| arg.to_tvm_value()).unzip();
let mut ret_val = unsafe { std::mem::uninitialized::<TVMValue>() };
let mut ret_type_code = 0;
let mut ret_type_code = 0i32;
check_call!(ffi::TVMFuncCall(
self.func.ok_or(errors::FunctionNotFoundError)?.handle,
values.as_mut_ptr(),
......@@ -209,7 +206,7 @@ impl<'a, 'm> Builder<'a, 'm> {
&mut ret_type_code as *mut _
));
Ok(unsafe { TVMRetValue::from_tvm_value(ret_val, ret_type_code as i64) })
Ok(unsafe { TVMRetValue::from_tvm_value(ret_val, ret_type_code as u32) })
}
}
......@@ -254,7 +251,7 @@ unsafe extern "C" fn tvm_callback(
{
check_call!(ffi::TVMCbArgToReturn(&mut value as *mut _, tcode));
}
local_args.push(TVMArgValue::new(value.into(), (tcode as i64).into()));
local_args.push(TVMArgValue::from_tvm_value(value.into(), tcode as u32));
}
let rv = match rust_fn(local_args.as_slice()) {
......@@ -265,7 +262,7 @@ unsafe extern "C" fn tvm_callback(
}
};
let (mut ret_val, ret_tcode) = rv.into_tvm_value();
let (mut ret_val, ret_tcode) = rv.to_tvm_value();
let mut ret_type_code = ret_tcode as c_int;
check_call!(ffi::TVMCFuncSetReturn(
ret,
......@@ -437,8 +434,9 @@ mod tests {
let str_arg = CString::new("test").unwrap();
let mut func = Builder::default();
func.get_function("tvm.graph_runtime.remote_create")
.args(&[10, 20])
.arg(&str_arg);
.arg(10)
.arg(20)
.arg(str_arg.as_c_str());
assert_eq!(func.arg_buf.len(), 3);
}
}
......@@ -80,7 +80,7 @@ impl Module {
CString::new(path.as_ref().to_str().ok_or_else(|| {
format_err!("Bad module load path: `{}`.", path.as_ref().display())
})?)?;
let ret: Module = call_packed!(func, &cpath, &ext)?.try_into()?;
let ret: Module = call_packed!(func, cpath.as_c_str(), ext.as_c_str())?.try_into()?;
Ok(ret)
}
......@@ -90,7 +90,10 @@ impl Module {
// `unwrap` is safe here because if there is any error during the
// function call, it would occur in `call_packed!`.
let tgt = CString::new(target).unwrap();
let ret: i64 = call_packed!(func, &tgt).unwrap().try_into().unwrap();
let ret: i64 = call_packed!(func, tgt.as_c_str())
.unwrap()
.try_into()
.unwrap();
ret != 0
}
......
......@@ -161,7 +161,7 @@ impl NDArray {
/// Converts the NDArray to [`TVMByteArray`].
pub fn to_bytearray(&self) -> Result<TVMByteArray, Error> {
let v = self.to_vec::<u8>()?;
Ok(TVMByteArray::from(&v))
Ok(TVMByteArray::from(v))
}
/// Creates an NDArray from a mutable buffer of types i32, u32 or f32 in cpu.
......
......@@ -2,140 +2,80 @@
//! and their conversions needed for the types used in frontend crate.
//! `TVMRetValue` is the owned version of `TVMPODValue`.
use std::{convert::TryFrom, os::raw::c_void};
use std::convert::TryFrom;
use failure::Error;
use tvm_common::{
ensure_type,
ffi::{self, TVMValue},
errors::ValueDowncastError,
ffi::{TVMArrayHandle, TVMFunctionHandle, TVMModuleHandle},
try_downcast,
};
use crate::{
common_errors::*, context::TVMContext, Function, Module, NDArray, TVMArgValue, TVMByteArray,
TVMRetValue,
};
use crate::{Function, Module, NDArray, TVMArgValue, TVMRetValue};
macro_rules! impl_tvm_val_from_handle {
($ty:ident, $type_code:expr, $handle:ty) => {
impl<'a> From<&'a $ty> for TVMArgValue<'a> {
fn from(arg: &$ty) -> Self {
TVMArgValue {
value: TVMValue {
v_handle: arg.handle as *mut _ as *mut c_void,
},
type_code: $type_code as i64,
_lifetime: std::marker::PhantomData,
}
macro_rules! impl_handle_val {
($type:ty, $variant:ident, $inner_type:ty, $ctor:path) => {
impl<'a> From<&'a $type> for TVMArgValue<'a> {
fn from(arg: &'a $type) -> Self {
TVMArgValue::$variant(arg.handle() as $inner_type)
}
}
impl<'a> From<&'a mut $ty> for TVMArgValue<'a> {
fn from(arg: &mut $ty) -> Self {
TVMArgValue {
value: TVMValue {
v_handle: arg.handle as *mut _ as *mut c_void,
},
type_code: $type_code as i64,
_lifetime: std::marker::PhantomData,
}
impl<'a> From<&'a mut $type> for TVMArgValue<'a> {
fn from(arg: &'a mut $type) -> Self {
TVMArgValue::$variant(arg.handle() as $inner_type)
}
}
impl<'a, 'v> TryFrom<&'a TVMArgValue<'v>> for $ty {
type Error = Error;
fn try_from(arg: &TVMArgValue<'v>) -> Result<$ty, Self::Error> {
ensure_type!(arg, $type_code);
Ok($ty::new(unsafe { arg.value.v_handle as $handle }))
impl<'a> TryFrom<TVMArgValue<'a>> for $type {
type Error = ValueDowncastError;
fn try_from(val: TVMArgValue<'a>) -> Result<$type, Self::Error> {
try_downcast!(val -> $type, |TVMArgValue::$variant(val)| { $ctor(val) })
}
}
impl From<$ty> for TVMRetValue {
fn from(val: $ty) -> TVMRetValue {
TVMRetValue {
value: TVMValue {
v_handle: val.handle() as *mut c_void,
},
box_value: box val,
type_code: $type_code as i64,
}
impl<'a, 'v> TryFrom<&'a TVMArgValue<'v>> for $type {
type Error = ValueDowncastError;
fn try_from(val: &'a TVMArgValue<'v>) -> Result<$type, Self::Error> {
try_downcast!(val -> $type, |TVMArgValue::$variant(val)| { $ctor(*val) })
}
}
impl TryFrom<TVMRetValue> for $ty {
type Error = Error;
fn try_from(ret: TVMRetValue) -> Result<$ty, Self::Error> {
ensure_type!(ret, $type_code);
Ok($ty::new(unsafe { ret.value.v_handle as $handle }))
}
}
};
}
impl_tvm_val_from_handle!(
Function,
ffi::TVMTypeCode_kFuncHandle,
ffi::TVMFunctionHandle
);
impl_tvm_val_from_handle!(Module, ffi::TVMTypeCode_kModuleHandle, ffi::TVMModuleHandle);
impl_tvm_val_from_handle!(NDArray, ffi::TVMTypeCode_kArrayHandle, ffi::TVMArrayHandle);
impl<'a> From<&'a TVMByteArray> for TVMValue {
fn from(barr: &TVMByteArray) -> Self {
TVMValue {
v_handle: &barr.inner as *const ffi::TVMByteArray as *mut c_void,
}
}
}
macro_rules! impl_boxed_ret_value {
($type:ty, $code:expr) => {
impl From<$type> for TVMRetValue {
fn from(val: $type) -> Self {
TVMRetValue {
value: TVMValue { v_int64: 0 },
box_value: box val,
type_code: $code as i64,
}
fn from(val: $type) -> TVMRetValue {
TVMRetValue::$variant(val.handle() as $inner_type)
}
}
impl TryFrom<TVMRetValue> for $type {
type Error = Error;
fn try_from(ret: TVMRetValue) -> Result<$type, Self::Error> {
if let Ok(val) = ret.box_value.downcast::<$type>() {
Ok(*val)
} else {
bail!(ValueDowncastError::new($code as i64, ret.type_code as i64))
}
type Error = ValueDowncastError;
fn try_from(val: TVMRetValue) -> Result<$type, Self::Error> {
try_downcast!(val -> $type, |TVMRetValue::$variant(val)| { $ctor(val) })
}
}
};
}
impl_boxed_ret_value!(TVMContext, ffi::TVMTypeCode_kTVMContext);
impl_boxed_ret_value!(TVMByteArray, ffi::TVMTypeCode_kBytes);
impl<'a, 'v> TryFrom<&'a TVMArgValue<'v>> for TVMByteArray {
type Error = Error;
fn try_from(arg: &TVMArgValue<'v>) -> Result<Self, Self::Error> {
ensure_type!(arg, ffi::TVMTypeCode_kBytes);
Ok(TVMByteArray::new(unsafe {
*(arg.value.v_handle as *mut ffi::TVMByteArray)
}))
}
}
impl_handle_val!(Function, FuncHandle, TVMFunctionHandle, Function::new);
impl_handle_val!(Module, ModuleHandle, TVMModuleHandle, Module::new);
impl_handle_val!(NDArray, ArrayHandle, TVMArrayHandle, NDArray::new);
#[cfg(test)]
mod tests {
use super::*;
use std::{convert::TryInto, str::FromStr};
use tvm_common::ffi::TVMType;
use tvm_common::{TVMByteArray, TVMContext, TVMType};
use super::*;
#[test]
fn bytearray() {
let w = vec![1u8, 2, 3, 4, 5];
let v = TVMByteArray::from(&w);
let v = TVMByteArray::from(w.as_slice());
let tvm: TVMByteArray = TVMRetValue::from(v).try_into().unwrap();
assert_eq!(tvm.data(), w.iter().map(|e| *e as i8).collect::<Vec<i8>>());
assert_eq!(
tvm.data(),
w.iter().map(|e| *e).collect::<Vec<u8>>().as_slice()
);
}
#[test]
......@@ -147,7 +87,7 @@ mod tests {
#[test]
fn ctx() {
let c = TVMContext::from("gpu");
let c = TVMContext::from_str("gpu").unwrap();
let tvm: TVMContext = TVMRetValue::from(c).try_into().unwrap();
assert_eq!(tvm, c);
}
......
......@@ -24,9 +24,9 @@ fn main() {
registered.get_function("concate_str");
assert!(registered.func.is_some());
let ret: String = registered
.arg(&a)
.arg(&b)
.arg(&c)
.arg(a.as_c_str())
.arg(b.as_c_str())
.arg(c.as_c_str())
.invoke()
.unwrap()
.try_into()
......
......@@ -8,6 +8,7 @@ readme = "README.md"
keywords = ["tvm", "nnvm"]
categories = ["api-bindings", "science"]
authors = ["TVM Contributors"]
edition = "2018"
[features]
default = ["nom/std"]
......
......@@ -265,7 +265,7 @@ impl<'m, 't> GraphExecutor<'m, 't> {
.iter()
.map(|t| t.into())
.collect::<Vec<TVMArgValue>>();
func(args.as_slice()).unwrap();
func(&args).unwrap();
};
op_execs.push(op);
}
......@@ -283,7 +283,7 @@ impl<'m, 't> GraphExecutor<'m, 't> {
// TODO: consider `new_with_params` to avoid ever allocating
let ptr = self.tensors[idx].data.as_ptr();
let mut to_replace = self.tensors.iter_mut().filter(|t| t.data.as_ptr() == ptr);
let mut owner = to_replace.nth(0).unwrap();
let owner = to_replace.nth(0).unwrap();
if value.data.is_owned() {
// FIXME: for no-copy, need setup_op_execs to not capture tensor ptr
// mem::replace(&mut (*owner), value);
......
......@@ -40,17 +40,14 @@ pub(super) fn wrap_backend_packed_func(
func: BackendPackedCFunc,
) -> Box<dyn PackedFunc> {
box move |args: &[TVMArgValue]| {
let exit_code = func(
args.iter()
.map(|ref arg| arg.value)
.collect::<Vec<TVMValue>>()
.as_ptr(),
args.iter()
.map(|ref arg| arg.type_code as i32)
.collect::<Vec<i32>>()
.as_ptr() as *const i32,
args.len() as i32,
);
let (values, type_codes): (Vec<TVMValue>, Vec<i32>) = args
.into_iter()
.map(|arg| {
let (val, code) = arg.to_tvm_value();
(val, code as i32)
})
.unzip();
let exit_code = func(values.as_ptr(), type_codes.as_ptr(), values.len() as i32);
if exit_code == 0 {
Ok(TVMRetValue::default())
} else {
......
#![feature(try_from)]
extern crate serde;
extern crate serde_json;
......
#![feature(try_from)]
#[macro_use]
extern crate ndarray;
extern crate serde;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment