/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ //! This module implements the [`NDArray`] type for working with *TVM tensors* or //! coverting from a Rust's ndarray to TVM `NDArray`. //! //! One can create an empty NDArray given the shape, context and dtype using [`empty`]. //! To create an NDArray from a mutable buffer in cpu use [`copy_from_buffer`]. //! To copy an NDArray to different context use [`copy_to_ctx`]. //! //! Given a [`Rust's dynamic ndarray`], one can convert it to TVM NDArray as follows: //! //! # Example //! //! ``` //! let a = Array::from_shape_vec((2, 2), vec![1f32, 2., 3., 4.]) //! .unwrap() //! .into_dyn(); // Rust's ndarray //! let nd = NDArray::from_rust_ndarray(&a, TVMContext::cpu(0), TVMType::from("float32")).unwrap(); //! assert_eq!(nd.shape(), Some(&mut [2, 2])); //! let rnd: ArrayD<f32> = ArrayD::try_from(&nd).unwrap(); //! assert!(rnd.all_close(&a, 1e-8f32)); //! ``` //! //! [`Rust's dynamic ndarray`]:https://docs.rs/ndarray/0.12.1/ndarray/ //! [`copy_from_buffer`]:struct.NDArray.html#method.copy_from_buffer //! [`copy_to_ctx`]:struct.NDArray.html#method.copy_to_ctx use std::{convert::TryFrom, mem, os::raw::c_int, ptr, slice, str::FromStr}; use failure::Error; use num_traits::Num; use rust_ndarray::{Array, ArrayD}; use tvm_common::{ffi, TVMType}; use crate::{errors, TVMByteArray, TVMContext}; /// See the [`module-level documentation`](../ndarray/index.html) for more details. /// /// Wrapper around TVM array handle. #[derive(Debug)] pub struct NDArray { pub(crate) handle: ffi::TVMArrayHandle, is_view: bool, } impl NDArray { pub(crate) fn new(handle: ffi::TVMArrayHandle) -> Self { NDArray { handle: handle, is_view: true, } } /// Returns the underlying array handle. pub fn handle(&self) -> ffi::TVMArrayHandle { self.handle } pub fn is_view(&self) -> bool { self.is_view } /// Returns the shape of the NDArray. pub fn shape(&self) -> Option<&mut [usize]> { let arr = unsafe { *(self.handle) }; if arr.shape.is_null() || arr.data.is_null() { return None; }; let slc = unsafe { slice::from_raw_parts_mut(arr.shape as *mut usize, arr.ndim as usize) }; Some(slc) } /// Returns the total number of entries of the NDArray. pub fn size(&self) -> Option<usize> { self.shape() .map(|v| v.into_iter().fold(1, |acc, &mut e| acc * e)) } /// Returns the context which the NDArray was defined. pub fn ctx(&self) -> TVMContext { unsafe { (*self.handle).ctx.into() } } /// Returns the type of the entries of the NDArray. pub fn dtype(&self) -> TVMType { unsafe { (*self.handle).dtype.into() } } /// Returns the number of dimensions of the NDArray. pub fn ndim(&self) -> usize { unsafe { (*self.handle).ndim as usize } } /// Returns the strides of the underlying NDArray. pub fn strides(&self) -> Option<&[usize]> { unsafe { let sz = self.ndim() * mem::size_of::<usize>(); let slc = slice::from_raw_parts((*self.handle).strides as *const usize, sz); Some(slc) } } /// Shows whether the underlying ndarray is contiguous in memory or not. pub fn is_contiguous(&self) -> Result<bool, Error> { Ok(match self.strides() { None => true, Some(strides) => { // errors::MissingShapeError in case shape is not determined self.shape() .ok_or(errors::MissingShapeError)? .iter() .zip(strides) .rfold( (true, 1), |(is_contig, expected_stride), (shape, stride)| { ( is_contig && *stride == expected_stride, expected_stride * (*shape as usize), ) }, ) .0 } }) } pub fn byte_offset(&self) -> isize { unsafe { (*self.handle).byte_offset as isize } } /// Flattens the NDArray to a `Vec` of the same type in cpu. /// /// ## Example /// /// ``` /// let shape = &mut [4]; /// let mut data = vec![1i32, 2, 3, 4]; /// let ctx = TVMContext::cpu(0); /// let mut ndarray = empty(shape, ctx, TVMType::from("int32")); /// ndarray.copy_from_buffer(&mut data); /// assert_eq!(ndarray.shape(), Some(shape)); /// assert_eq!(ndarray.to_vec::<i32>().unwrap(), data); /// ``` pub fn to_vec<T>(&self) -> Result<Vec<T>, Error> { ensure!(self.shape().is_some(), errors::EmptyArrayError); let earr = NDArray::empty( self.shape().ok_or(errors::MissingShapeError)?, TVMContext::cpu(0), self.dtype(), ); let target = self.copy_to_ndarray(earr)?; let arr = unsafe { *(target.handle) }; let sz = self.size().ok_or(errors::MissingShapeError)?; let mut v: Vec<T> = Vec::with_capacity(sz * mem::size_of::<T>()); unsafe { v.as_mut_ptr() .copy_from_nonoverlapping(arr.data as *const T, sz); v.set_len(sz); } Ok(v) } /// Converts the NDArray to [`TVMByteArray`]. pub fn to_bytearray(&self) -> Result<TVMByteArray, Error> { let v = self.to_vec::<u8>()?; Ok(TVMByteArray::from(v)) } /// Creates an NDArray from a mutable buffer of types i32, u32 or f32 in cpu. /// /// ## Example /// /// ``` /// let shape = &mut [2]; /// let mut data = vec![1f32, 2]; /// let ctx = TVMContext::gpu(0); /// let mut ndarray = empty(shape, ctx, TVMType::from("int32")); /// ndarray.copy_from_buffer(&mut data); /// ``` /// /// *Note*: if something goes wrong during the copy, it will panic /// from TVM side. See `TVMArrayCopyFromBytes` in `include/tvm/runtime/c_runtime_api.h`. pub fn copy_from_buffer<T: Num32>(&mut self, data: &mut [T]) { check_call!(ffi::TVMArrayCopyFromBytes( self.handle, data.as_ptr() as *mut _, data.len() * mem::size_of::<T>() )); } /// Copies the NDArray to another target NDArray. pub fn copy_to_ndarray(&self, target: NDArray) -> Result<NDArray, Error> { if self.dtype() != target.dtype() { bail!( "{}", errors::TypeMismatchError { expected: format!("{}", self.dtype().to_string()), actual: format!("{}", target.dtype().to_string()), } ); } check_call!(ffi::TVMArrayCopyFromTo( self.handle, target.handle, ptr::null_mut() as ffi::TVMStreamHandle )); Ok(target) } /// Copies the NDArray to a target context. pub fn copy_to_ctx(&self, target: &TVMContext) -> Result<NDArray, Error> { let tmp = NDArray::empty( self.shape().ok_or(errors::MissingShapeError)?, target.clone(), self.dtype(), ); let copy = self.copy_to_ndarray(tmp)?; Ok(copy) } /// Converts a Rust's ndarray to TVM NDArray. pub fn from_rust_ndarray<T: Num32 + Copy>( rnd: &ArrayD<T>, ctx: TVMContext, dtype: TVMType, ) -> Result<Self, Error> { let mut shape = rnd.shape().to_vec(); let mut nd = NDArray::empty(&mut shape, ctx, dtype); let mut buf = Array::from_iter(rnd.into_iter().map(|&v| v as T)); nd.copy_from_buffer( buf.as_slice_mut() .expect("Array from iter must be contiguous."), ); Ok(nd) } /// Allocates and creates an empty NDArray given the shape, context and dtype. pub fn empty(shape: &[usize], ctx: TVMContext, dtype: TVMType) -> NDArray { let mut handle = ptr::null_mut() as ffi::TVMArrayHandle; check_call!(ffi::TVMArrayAlloc( shape.as_ptr() as *const i64, shape.len() as c_int, dtype.code as c_int, dtype.bits as c_int, dtype.lanes as c_int, ctx.device_type.0 as c_int, ctx.device_id as c_int, &mut handle as *mut _, )); NDArray { handle, is_view: false, } } } macro_rules! impl_from_ndarray_rustndarray { ($type:ty, $type_name:tt) => { impl<'a> TryFrom<&'a NDArray> for ArrayD<$type> { type Error = Error; fn try_from(nd: &NDArray) -> Result<ArrayD<$type>, Self::Error> { ensure!(nd.shape().is_some(), errors::MissingShapeError); assert_eq!(nd.dtype(), TVMType::from_str($type_name)?, "Type mismatch"); Ok(Array::from_shape_vec( &*nd.shape().ok_or(errors::MissingShapeError)?, nd.to_vec::<$type>()?, )?) } } impl<'a> TryFrom<&'a mut NDArray> for ArrayD<$type> { type Error = Error; fn try_from(nd: &mut NDArray) -> Result<ArrayD<$type>, Self::Error> { ensure!(nd.shape().is_some(), errors::MissingShapeError); assert_eq!(nd.dtype(), TVMType::from_str($type_name)?, "Type mismatch"); Ok(Array::from_shape_vec( &*nd.shape().ok_or(errors::MissingShapeError)?, nd.to_vec::<$type>()?, )?) } } }; } impl_from_ndarray_rustndarray!(i32, "int"); impl_from_ndarray_rustndarray!(u32, "uint"); impl_from_ndarray_rustndarray!(f32, "float"); impl Drop for NDArray { fn drop(&mut self) { if !self.is_view { check_call!(ffi::TVMArrayFree(self.handle)); } } } mod sealed { /// Private trait to prevent other traits from being implemeneted in downstream crates. pub trait Sealed {} } /// A trait for the supported 32-bits numerical types in frontend. pub trait Num32: Num + sealed::Sealed { const BITS: u8 = 32; } macro_rules! impl_num32 { ($($type:ty),+) => { $( impl sealed::Sealed for $type {} impl Num32 for $type {} )+ }; } impl_num32!(i32, u32, f32); #[cfg(test)] mod tests { use super::*; #[test] fn basics() { let shape = &mut [1, 2, 3]; let ctx = TVMContext::cpu(0); let ndarray = NDArray::empty(shape, ctx, TVMType::from_str("int32").unwrap()); assert_eq!(ndarray.shape().unwrap(), shape); assert_eq!( ndarray.size().unwrap(), shape.to_vec().into_iter().product() ); assert_eq!(ndarray.ndim(), 3); assert!(ndarray.strides().is_none()); assert_eq!(ndarray.byte_offset(), 0); } #[test] fn copy() { let shape = &mut [4]; let mut data = vec![1i32, 2, 3, 4]; let ctx = TVMContext::cpu(0); let mut ndarray = NDArray::empty(shape, ctx, TVMType::from_str("int32").unwrap()); assert!(ndarray.to_vec::<i32>().is_ok()); ndarray.copy_from_buffer(&mut data); assert_eq!(ndarray.shape().unwrap(), shape); assert_eq!(ndarray.to_vec::<i32>().unwrap(), data); assert_eq!(ndarray.ndim(), 1); assert!(ndarray.is_contiguous().is_ok()); assert_eq!(ndarray.byte_offset(), 0); let mut shape = vec![4]; let e = NDArray::empty( &mut shape, TVMContext::cpu(0), TVMType::from_str("int32").unwrap(), ); let nd = ndarray.copy_to_ndarray(e); assert!(nd.is_ok()); assert_eq!(nd.unwrap().to_vec::<i32>().unwrap(), data); } #[test] #[should_panic(expected = "called `Result::unwrap()` on an `Err`")] fn copy_wrong_dtype() { let mut shape = vec![4]; let mut data = vec![1f32, 2., 3., 4.]; let ctx = TVMContext::cpu(0); let mut nd_float = NDArray::empty( &mut shape, ctx.clone(), TVMType::from_str("float32").unwrap(), ); nd_float.copy_from_buffer(&mut data); let empty_int = NDArray::empty(&mut shape, ctx, TVMType::from_str("int32").unwrap()); nd_float.copy_to_ndarray(empty_int).unwrap(); } #[test] fn rust_ndarray() { let a = Array::from_shape_vec((2, 2), vec![1f32, 2., 3., 4.]) .unwrap() .into_dyn(); let nd = NDArray::from_rust_ndarray( &a, TVMContext::cpu(0), TVMType::from_str("float32").unwrap(), ) .unwrap(); assert_eq!(nd.shape().unwrap(), &mut [2, 2]); let rnd: ArrayD<f32> = ArrayD::try_from(&nd).unwrap(); assert!(rnd.all_close(&a, 1e-8f32)); } }