use std::{
    any::TypeId,
    convert::TryFrom,
    mem,
    os::raw::{c_int, c_void},
    ptr, slice,
};

use ndarray;

use super::allocator::Allocation;
use errors::*;
use ffi::runtime::{
    DLContext, DLDataType, DLDataTypeCode_kDLFloat, DLDataTypeCode_kDLInt,
    DLDataTypeCode_kDLUInt, DLDeviceType_kDLCPU, DLTensor,
};

/// A `Storage` is a container which holds `Tensor` data.
#[derive(PartialEq)]
pub enum Storage<'a> {
    /// A `Storage` which owns its contained bytes.
    Owned(Allocation),

    /// A view of an existing `Storage`.
    View(&'a mut [u8], usize), // ptr, align
}

impl<'a> Storage<'a> {
    pub fn new(size: usize, align: Option<usize>) -> Result<Storage<'static>> {
        Ok(Storage::Owned(Allocation::new(size, align)?))
    }

    pub fn as_mut_ptr(&self) -> *mut u8 {
        match self {
            Storage::Owned(alloc) => alloc.as_mut_ptr(),
            Storage::View(slice, _) => slice.as_ptr() as *mut u8,
        }
    }

    pub fn size(&self) -> usize {
        match self {
            Storage::Owned(alloc) => alloc.size(),
            Storage::View(slice, _) => slice.len(),
        }
    }

    pub fn align(&self) -> usize {
        match self {
            Storage::Owned(alloc) => alloc.align(),
            Storage::View(_, align) => *align,
        }
    }

    pub fn as_ptr(&self) -> *const u8 {
        self.as_mut_ptr() as *const _
    }

    /// Returns a `Storage::View` over this storage's bytes.
    pub fn view(&self) -> Storage<'a> {
        match self {
            Storage::Owned(alloc) => Storage::View(
                unsafe { slice::from_raw_parts_mut(alloc.as_mut_ptr(), self.size()) },
                self.align(),
            ),
            Storage::View(slice, _) => Storage::View(
                unsafe { slice::from_raw_parts_mut(self.as_mut_ptr(), slice.len()) },
                self.align(),
            ),
        }
    }

    pub fn is_owned(&self) -> bool {
        match self {
            Storage::Owned(_) => true,
            _ => false,
        }
    }

    /// Returns an owned version of this storage via cloning.
    pub fn to_owned(&self) -> Storage<'static> {
        let s = Storage::new(self.size(), Some(self.align())).unwrap();
        unsafe {
            s.as_mut_ptr()
                .copy_from_nonoverlapping(self.as_ptr(), self.size())
        }
        s
    }
}

impl<'a, T> From<&'a [T]> for Storage<'a> {
    fn from(data: &'a [T]) -> Self {
        // Reinterpret the `T` slice as raw bytes, preserving `T`'s alignment.
        let data = unsafe {
            slice::from_raw_parts_mut(
                data.as_ptr() as *const u8 as *mut u8,
                data.len() * mem::size_of::<T>(),
            )
        };
        Storage::View(data, mem::align_of::<T>())
    }
}

/// An n-dimensional array type which can be converted to/from `tvm::DLTensor` and `ndarray::Array`.
/// `Tensor` is primarily a holder of data which can be operated on via TVM (via `DLTensor`) or
/// converted to `ndarray::Array` for non-TVM processing.
///
/// # Examples
///
/// ```
/// extern crate ndarray;
///
/// let mut a_nd = ndarray::Array::from_vec(vec![1f32, 2., 3., 4.]);
/// let mut a: Tensor = a_nd.into();
/// let mut a_dl: DLTensor = (&mut a).into();
/// call_packed!(tvm_fn, &mut a_dl);
///
/// // Array -> Tensor is mostly useful when post-processing TVM graph outputs.
/// let mut a_nd = ndarray::ArrayD::<f32>::try_from(&a).unwrap();
/// ```
#[derive(PartialEq)]
pub struct Tensor<'a> {
    /// The bytes which contain the data this `Tensor` represents.
    pub(super) data: Storage<'a>,
    pub(super) ctx: TVMContext,
    pub(super) dtype: DataType,
    pub(super) shape: Vec<i64>, // not usize because `typedef int64_t tvm_index_t` in c_runtime_api.h
    /// The `Tensor` strides. Can be `None` if the `Tensor` is contiguous.
    pub(super) strides: Option<Vec<usize>>,
    pub(super) byte_offset: isize,
    /// The number of elements in the `Tensor`.
    pub(super) size: usize,
}

unsafe impl<'a> Send for Tensor<'a> {}
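// A minimal sketch (not part of the original file) exercising the `Storage`
// API above: `From<&[T]>` wraps existing memory as a `View`, and `to_owned`
// copies the bytes into a fresh `Owned` allocation. Assumes only the methods
// defined in this module.
#[cfg(test)]
mod storage_tests {
    use std::mem;

    use super::Storage;

    #[test]
    fn view_to_owned_copies_bytes() {
        let data = [1f32, 2., 3., 4.];
        // Reinterprets the `f32` slice as raw bytes with `f32` alignment.
        let view = Storage::from(&data[..]);
        assert!(!view.is_owned());
        assert_eq!(view.size(), 4 * mem::size_of::<f32>());

        // `to_owned` allocates new storage and memcpys the contents.
        let owned = view.to_owned();
        assert!(owned.is_owned());
        assert_eq!(owned.size(), view.size());
    }
}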
impl<'a> Tensor<'a> {
    pub fn shape(&self) -> Vec<i64> {
        self.shape.clone()
    }

    /// Returns the data of this `Tensor` as a `Vec`.
    ///
    /// # Panics
    ///
    /// Panics if the `Tensor` is not contiguous or does not contain elements of type `T`.
    pub fn to_vec<T: 'static>(&self) -> Vec<T> {
        assert!(self.is_contiguous());
        assert!(self.dtype.is_type::<T>());
        // capacity is measured in elements of `T`, not bytes
        let mut vec: Vec<T> = Vec::with_capacity(self.size);
        unsafe {
            vec.as_mut_ptr().copy_from_nonoverlapping(
                self.data.as_ptr().offset(self.byte_offset) as *const T,
                self.size,
            );
            vec.set_len(self.size);
        }
        vec
    }

    /// Returns `true` iff this `Tensor` is represented by a contiguous region of memory.
    pub fn is_contiguous(&self) -> bool {
        match self.strides {
            None => true,
            Some(ref strides) => {
                // check that the stride for each dimension is the product of all
                // trailing dimensions' shapes
                self.shape
                    .iter()
                    .zip(strides)
                    .rfold(
                        (true, 1),
                        |(is_contig, expected_stride), (shape, stride)| {
                            (
                                is_contig && *stride == expected_stride,
                                expected_stride * (*shape as usize),
                            )
                        },
                    )
                    .0
            }
        }
    }

    /// Copies the contents of `other` into this `Tensor`.
    ///
    /// # Panics
    ///
    /// Panics if the `Tensor`s differ in size or dtype, or if either is not contiguous.
    pub fn copy(&mut self, other: &Tensor) {
        assert!(
            self.dtype == other.dtype && self.size == other.size,
            "Tensor shape/dtype mismatch."
        );
        assert!(
            self.is_contiguous() && other.is_contiguous(),
            "copy currently requires contiguous tensors\n`self.strides = {:?}` `other.strides = {:?}`",
            self.strides,
            other.strides
        );
        unsafe {
            self.data
                .as_mut_ptr()
                .offset(self.byte_offset)
                .copy_from_nonoverlapping(
                    other.data.as_mut_ptr().offset(other.byte_offset),
                    other.size * other.dtype.itemsize(),
                );
        }
    }

    /// Returns an owned version of this `Tensor` via cloning.
    pub fn to_owned(&self) -> Tensor<'static> {
        let t = Tensor {
            data: self.data.to_owned(),
            ctx: self.ctx.clone(),
            dtype: self.dtype.clone(),
            size: self.size,
            shape: self.shape.clone(),
            strides: None,
            byte_offset: 0,
        };
        unsafe { mem::transmute::<Tensor<'a>, Tensor<'static>>(t) }
    }

    fn from_array_storage<'s, T, D: ndarray::Dimension>(
        arr: &ndarray::Array<T, D>,
        storage: Storage<'s>,
        type_code: usize,
    ) -> Tensor<'s> {
        let type_width = mem::size_of::<T>();
        Tensor {
            data: storage,
            ctx: TVMContext::default(),
            dtype: DataType {
                code: type_code,
                bits: 8 * type_width,
                lanes: 1,
            },
            size: arr.len(),
            shape: arr.shape().iter().map(|&v| v as i64).collect(),
            strides: Some(arr.strides().iter().map(|&v| v as usize).collect()),
            byte_offset: 0,
        }
    }
}
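// Worked example (added for illustration) of the stride check performed by
// `is_contiguous` above: in a contiguous row-major layout, each dimension's
// stride equals the product of all trailing dimensions' shapes, e.g. shape
// [2, 3, 4] implies strides [12, 4, 1]. The helper below is hypothetical and
// only mirrors the rfold logic.
#[cfg(test)]
mod contiguity_tests {
    // Recompute the expected strides right-to-left, the same direction the
    // `rfold` in `is_contiguous` walks the (shape, stride) pairs.
    fn expected_strides(shape: &[usize]) -> Vec<usize> {
        let mut strides = vec![1usize; shape.len()];
        for i in (0..shape.len().saturating_sub(1)).rev() {
            strides[i] = strides[i + 1] * shape[i + 1];
        }
        strides
    }

    #[test]
    fn row_major_strides() {
        assert_eq!(expected_strides(&[2, 3, 4]), vec![12, 4, 1]);
        assert_eq!(expected_strides(&[5]), vec![1]);
    }
}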
/// Conversions to `ndarray::Array` from `Tensor`, if the types match.
macro_rules! impl_ndarray_try_from_tensor {
    ($type:ty, $dtype:expr) => {
        impl<'a, 't> TryFrom<&'a Tensor<'t>> for ndarray::ArrayD<$type> {
            type Error = Error;
            fn try_from(tensor: &'a Tensor) -> Result<ndarray::ArrayD<$type>> {
                ensure!(
                    tensor.dtype == $dtype,
                    "Cannot convert Tensor with dtype {:?} to ndarray",
                    tensor.dtype
                );
                Ok(ndarray::Array::from_shape_vec(
                    tensor
                        .shape
                        .iter()
                        .map(|s| *s as usize)
                        .collect::<Vec<usize>>(),
                    tensor.to_vec::<$type>(),
                )?)
            }
        }
    };
}

impl_ndarray_try_from_tensor!(i32, DTYPE_INT32);
impl_ndarray_try_from_tensor!(u32, DTYPE_UINT32);
impl_ndarray_try_from_tensor!(f32, DTYPE_FLOAT32);
impl_ndarray_try_from_tensor!(f64, DTYPE_FLOAT64);

impl DLTensor {
    pub(super) fn from_tensor<'a>(tensor: &'a Tensor, flatten: bool) -> Self {
        assert!(!flatten || tensor.is_contiguous());
        Self {
            data: unsafe { tensor.data.as_mut_ptr().offset(tensor.byte_offset) } as *mut c_void,
            ctx: DLContext::from(&tensor.ctx),
            ndim: if flatten { 1 } else { tensor.shape.len() } as i32,
            dtype: DLDataType::from(&tensor.dtype),
            shape: if flatten {
                &tensor.size as *const _ as *mut i64
            } else {
                tensor.shape.as_ptr()
            } as *mut i64,
            strides: if flatten || tensor.is_contiguous() {
                ptr::null_mut()
            } else {
                tensor.strides.as_ref().unwrap().as_ptr()
            } as *mut i64,
            byte_offset: 0,
        }
    }
}

impl<'a, 't> From<&'a Tensor<'t>> for DLTensor {
    fn from(tensor: &'a Tensor<'t>) -> Self {
        DLTensor::from_tensor(tensor, false /* flatten */)
    }
}

impl<'a, 't> From<&'a mut Tensor<'t>> for DLTensor {
    fn from(tensor: &'a mut Tensor<'t>) -> Self {
        DLTensor::from_tensor(tensor, false /* flatten */)
    }
}

#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct DataType {
    pub(super) code: usize,
    pub(super) bits: usize,
    pub(super) lanes: usize,
}

impl DataType {
    /// Returns the number of bytes occupied by an element of this `DataType`.
    pub fn itemsize(&self) -> usize {
        (self.bits * self.lanes) >> 3
    }

    /// Returns whether this `DataType` represents primitive type `T`.
    pub fn is_type<T: 'static>(&self) -> bool {
        if self.lanes != 1 {
            return false;
        }
        let typ = TypeId::of::<T>();
        (typ == TypeId::of::<i32>() && self.code == 0 && self.bits == 32)
            || (typ == TypeId::of::<i64>() && self.code == 0 && self.bits == 64)
            || (typ == TypeId::of::<u32>() && self.code == 1 && self.bits == 32)
            || (typ == TypeId::of::<u64>() && self.code == 1 && self.bits == 64)
            || (typ == TypeId::of::<f32>() && self.code == 2 && self.bits == 32)
            || (typ == TypeId::of::<f64>() && self.code == 2 && self.bits == 64)
    }
}

impl<'a> From<&'a DataType> for DLDataType {
    fn from(dtype: &'a DataType) -> Self {
        Self {
            code: dtype.code as u8,
            bits: dtype.bits as u8,
            lanes: dtype.lanes as u16,
        }
    }
}

impl From<DLDataType> for DataType {
    fn from(dtype: DLDataType) -> Self {
        Self {
            code: dtype.code as usize,
            bits: dtype.bits as usize,
            lanes: dtype.lanes as usize,
        }
    }
}
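// A short sketch (added here, not in the original) of the `DataType` helpers
// above, assuming the DLPack type codes used throughout this file
// (0 = int, 1 = uint, 2 = float): 32 bits x 1 lane is 4 bytes, and `is_type`
// matches the (code, bits) pair against a Rust primitive.
#[cfg(test)]
mod dtype_tests {
    use super::DataType;

    #[test]
    fn itemsize_and_is_type() {
        let dt = DataType {
            code: 2, // kDLFloat
            bits: 32,
            lanes: 1,
        };
        assert_eq!(dt.itemsize(), 4);
        assert!(dt.is_type::<f32>());
        assert!(!dt.is_type::<i32>());
    }
}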
macro_rules! make_dtype_const {
    ($name:ident, $code:ident, $bits:expr, $lanes:expr) => {
        const $name: DataType = DataType {
            code: $code as usize,
            bits: $bits,
            lanes: $lanes,
        };
    };
}

make_dtype_const!(DTYPE_INT32, DLDataTypeCode_kDLInt, 32, 1);
make_dtype_const!(DTYPE_UINT32, DLDataTypeCode_kDLUInt, 32, 1);
// make_dtype_const!(DTYPE_FLOAT16, DLDataTypeCode_kDLFloat, 16, 1);
make_dtype_const!(DTYPE_FLOAT32, DLDataTypeCode_kDLFloat, 32, 1);
make_dtype_const!(DTYPE_FLOAT64, DLDataTypeCode_kDLFloat, 64, 1);

impl Default for DLContext {
    fn default() -> Self {
        DLContext {
            device_type: DLDeviceType_kDLCPU,
            device_id: 0,
        }
    }
}

#[derive(Debug, Clone, Copy, PartialEq)]
pub struct TVMContext {
    pub(super) device_type: usize,
    pub(super) device_id: usize,
}

impl<'a> From<&'a TVMContext> for DLContext {
    fn from(ctx: &'a TVMContext) -> Self {
        Self {
            device_type: ctx.device_type as u32,
            device_id: ctx.device_id as i32,
        }
    }
}

impl Default for TVMContext {
    fn default() -> Self {
        Self {
            device_type: DLDeviceType_kDLCPU as usize,
            device_id: 0,
        }
    }
}

impl<'a> From<DLTensor> for Tensor<'a> {
    fn from(dlt: DLTensor) -> Self {
        unsafe {
            let dtype = DataType::from(dlt.dtype);
            let shape = slice::from_raw_parts(dlt.shape, dlt.ndim as usize).to_vec();
            let size = shape.iter().map(|v| *v as usize).product::<usize>();
            let storage = Storage::from(slice::from_raw_parts(
                dlt.data as *const u8,
                dtype.itemsize() * size,
            ));
            Self {
                data: storage,
                ctx: TVMContext::default(),
                dtype,
                size,
                shape,
                strides: if dlt.strides.is_null() {
                    None
                } else {
                    // the strides array holds one entry per dimension
                    Some(
                        slice::from_raw_parts(dlt.strides as *const usize, dlt.ndim as usize)
                            .to_vec(),
                    )
                },
                byte_offset: dlt.byte_offset as isize,
            }
        }
    }
}

/// `From` conversions to `Tensor` for owned or borrowed `ndarray::Array`.
///
/// # Panics
///
/// Panics if the ndarray is not contiguous.
macro_rules! impl_tensor_from_ndarray {
    ($type:ty, $typecode:expr) => {
        impl<D: ndarray::Dimension> From<ndarray::Array<$type, D>> for Tensor<'static> {
            fn from(arr: ndarray::Array<$type, D>) -> Self {
                assert!(arr.is_standard_layout(), "Array must be contiguous.");
                let size = arr.len() * mem::size_of::<$type>();
                let storage = Storage::from(unsafe {
                    slice::from_raw_parts(arr.as_ptr() as *const u8, size)
                });
                Tensor::from_array_storage(&arr, storage, $typecode as usize)
            }
        }

        impl<'a, D: ndarray::Dimension> From<&'a ndarray::Array<$type, D>> for Tensor<'a> {
            fn from(arr: &'a ndarray::Array<$type, D>) -> Self {
                assert!(arr.is_standard_layout(), "Array must be contiguous.");
                Tensor::from_array_storage(
                    arr,
                    Storage::from(arr.as_slice().unwrap()),
                    $typecode as usize,
                )
            }
        }
    };
}
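// A hedged round-trip sketch (added for illustration): borrow an
// `ndarray::Array` as a `Tensor` via the macro-generated `From` impl above,
// then recover an `ndarray::ArrayD` with the `TryFrom` impl when the dtypes
// match. Assumes only items defined in this module.
#[cfg(test)]
mod roundtrip_tests {
    use std::convert::TryFrom;

    use ndarray;

    use super::Tensor;

    #[test]
    fn ndarray_to_tensor_and_back() {
        let arr = ndarray::arr2(&[[1f32, 2.], [3., 4.]]);
        // Borrowed conversion: the Tensor views the array's buffer in place.
        let tensor = Tensor::from(&arr);
        assert_eq!(tensor.shape(), vec![2, 2]);
        let back = ndarray::ArrayD::<f32>::try_from(&tensor).unwrap();
        assert_eq!(back, arr.clone().into_dyn());
    }
}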
/// `From` conversions to `DLTensor` for `ndarray::Array`.
/// Takes a reference to the `ndarray` since `DLTensor` is not owned.
macro_rules! impl_dltensor_from_ndarray {
    ($type:ty, $typecode:expr) => {
        impl<'a, D: ndarray::Dimension> From<&'a mut ndarray::Array<$type, D>> for DLTensor {
            fn from(arr: &'a mut ndarray::Array<$type, D>) -> Self {
                DLTensor {
                    data: arr.as_mut_ptr() as *mut c_void,
                    ctx: DLContext::default(),
                    ndim: arr.ndim() as c_int,
                    dtype: DLDataType {
                        code: $typecode as u8,
                        bits: 8 * mem::size_of::<$type>() as u8,
                        lanes: 1,
                    },
                    shape: arr.shape().as_ptr() as *const i64 as *mut i64,
                    strides: arr.strides().as_ptr() as *const isize as *mut i64,
                    byte_offset: 0,
                }
            }
        }
    };
}

impl_dltensor_from_ndarray!(f32, DLDataTypeCode_kDLFloat);
impl_dltensor_from_ndarray!(f64, DLDataTypeCode_kDLFloat);
impl_dltensor_from_ndarray!(i32, DLDataTypeCode_kDLInt);
impl_dltensor_from_ndarray!(i64, DLDataTypeCode_kDLInt);
impl_dltensor_from_ndarray!(u32, DLDataTypeCode_kDLUInt);
impl_dltensor_from_ndarray!(u64, DLDataTypeCode_kDLUInt);

impl_tensor_from_ndarray!(f32, DLDataTypeCode_kDLFloat);
impl_tensor_from_ndarray!(f64, DLDataTypeCode_kDLFloat);
impl_tensor_from_ndarray!(i32, DLDataTypeCode_kDLInt);
impl_tensor_from_ndarray!(i64, DLDataTypeCode_kDLInt);
impl_tensor_from_ndarray!(u32, DLDataTypeCode_kDLUInt);
impl_tensor_from_ndarray!(u64, DLDataTypeCode_kDLUInt);
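// A final sketch (added for illustration) of the `&mut ndarray::Array` ->
// `DLTensor` conversion above: the `DLTensor` aliases the array's buffer
// rather than copying it, so the array must outlive any use of the raw
// pointers. Assumes only items defined or imported in this module.
#[cfg(test)]
mod dltensor_tests {
    use ndarray;

    use ffi::runtime::DLTensor;

    #[test]
    fn dltensor_aliases_ndarray_buffer() {
        let mut arr = ndarray::arr1(&[1f32, 2., 3.]);
        let dlt = DLTensor::from(&mut arr);
        assert_eq!(dlt.ndim, 1);
        // The data pointer points into the ndarray's own storage.
        assert_eq!(dlt.data as *const f32, arr.as_ptr());
    }
}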