Unverified Commit 9c591510 by Jared Roesch Committed by GitHub

[Rust][CI] Restore Rust CI (#5137)

parent 8c31d0dd
...@@ -20,62 +20,12 @@ hard_tabs = false ...@@ -20,62 +20,12 @@ hard_tabs = false
tab_spaces = 4 tab_spaces = 4
newline_style = "Auto" newline_style = "Auto"
use_small_heuristics = "Default" use_small_heuristics = "Default"
indent_style = "Block"
wrap_comments = false
format_code_in_doc_comments = false
comment_width = 80
normalize_comments = false
normalize_doc_attributes = false
format_strings = false
format_macro_matchers = false
format_macro_bodies = true
empty_item_single_line = true
struct_lit_single_line = true
fn_single_line = false
where_single_line = false
imports_indent = "Block"
imports_layout = "Mixed"
merge_imports = true
reorder_imports = true reorder_imports = true
reorder_modules = true reorder_modules = true
reorder_impl_items = false
type_punctuation_density = "Wide"
space_before_colon = false
space_after_colon = true
spaces_around_ranges = false
binop_separator = "Front"
remove_nested_parens = true remove_nested_parens = true
combine_control_expr = true
overflow_delimited_expr = false
struct_field_align_threshold = 0
enum_discrim_align_threshold = 0
match_arm_blocks = true
force_multiline_blocks = false
fn_args_layout = "Tall" fn_args_layout = "Tall"
brace_style = "SameLineWhere"
control_brace_style = "AlwaysSameLine"
trailing_semicolon = true
trailing_comma = "Vertical"
match_block_trailing_comma = false
blank_lines_upper_bound = 1
blank_lines_lower_bound = 0
edition = "2018" edition = "2018"
version = "One"
inline_attribute_width = 0
merge_derives = true merge_derives = true
use_try_shorthand = false use_try_shorthand = false
use_field_init_shorthand = false use_field_init_shorthand = false
force_explicit_abi = true force_explicit_abi = true
condense_wildcard_suffixes = false
color = "Auto"
unstable_features = false
disable_all_formatting = false
skip_children = false
hide_parse_errors = false
error_on_line_overflow = false
error_on_unformatted = false
report_todo = "Never"
report_fixme = "Never"
ignore = []
emit_mode = "Files"
make_backup = false
...@@ -42,5 +42,5 @@ pub mod packed_func; ...@@ -42,5 +42,5 @@ pub mod packed_func;
pub mod value; pub mod value;
pub use errors::*; pub use errors::*;
pub use ffi::{TVMByteArray, TVMContext, DLDataType as TVMType}; pub use ffi::{DLDataType as TVMType, TVMByteArray, TVMContext};
pub use packed_func::{TVMArgValue, TVMRetValue}; pub use packed_func::{TVMArgValue, TVMRetValue};
...@@ -26,10 +26,15 @@ use std::{ ...@@ -26,10 +26,15 @@ use std::{
pub use crate::ffi::TVMValue; pub use crate::ffi::TVMValue;
use crate::{errors::ValueDowncastError, ffi::*}; use crate::{errors::ValueDowncastError, ffi::*};
pub trait PackedFunc : Fn(&[TVMArgValue]) -> Result<TVMRetValue, crate::errors::FuncCallError> + Send + Sync {} pub trait PackedFunc:
Fn(&[TVMArgValue]) -> Result<TVMRetValue, crate::errors::FuncCallError> + Send + Sync
{
}
impl<T> PackedFunc for T impl<T> PackedFunc for T where
where T : Fn(&[TVMArgValue]) -> Result<TVMRetValue, crate::errors::FuncCallError> + Send + Sync {} T: Fn(&[TVMArgValue]) -> Result<TVMRetValue, crate::errors::FuncCallError> + Send + Sync
{
}
/// Calls a packed function and returns a `TVMRetValue`. /// Calls a packed function and returns a `TVMRetValue`.
/// ///
...@@ -76,7 +81,7 @@ macro_rules! TVMPODValue { ...@@ -76,7 +81,7 @@ macro_rules! TVMPODValue {
ObjectHandle(*mut c_void), ObjectHandle(*mut c_void),
ModuleHandle(TVMModuleHandle), ModuleHandle(TVMModuleHandle),
FuncHandle(TVMFunctionHandle), FuncHandle(TVMFunctionHandle),
NDArrayContainer(*mut c_void), NDArrayHandle(*mut c_void),
$($extra_variant($variant_type)),+ $($extra_variant($variant_type)),+
} }
...@@ -97,7 +102,7 @@ macro_rules! TVMPODValue { ...@@ -97,7 +102,7 @@ macro_rules! TVMPODValue {
TVMTypeCode_kTVMObjectHandle => ObjectHandle($value.v_handle), TVMTypeCode_kTVMObjectHandle => ObjectHandle($value.v_handle),
TVMTypeCode_kTVMModuleHandle => ModuleHandle($value.v_handle), TVMTypeCode_kTVMModuleHandle => ModuleHandle($value.v_handle),
TVMTypeCode_kTVMPackedFuncHandle => FuncHandle($value.v_handle), TVMTypeCode_kTVMPackedFuncHandle => FuncHandle($value.v_handle),
TVMTypeCode_kTVMNDArrayHandle => NDArrayContainer($value.v_handle), TVMTypeCode_kTVMNDArrayHandle => NDArrayHandle($value.v_handle),
$( $tvm_type => { $from_tvm_type } ),+ $( $tvm_type => { $from_tvm_type } ),+
_ => unimplemented!("{}", type_code), _ => unimplemented!("{}", type_code),
} }
...@@ -133,7 +138,7 @@ macro_rules! TVMPODValue { ...@@ -133,7 +138,7 @@ macro_rules! TVMPODValue {
TVMValue { v_handle: *val }, TVMValue { v_handle: *val },
TVMTypeCode_kTVMPackedFuncHandle TVMTypeCode_kTVMPackedFuncHandle
), ),
NDArrayContainer(val) => NDArrayHandle(val) =>
(TVMValue { v_handle: *val }, TVMTypeCode_kTVMNDArrayHandle), (TVMValue { v_handle: *val }, TVMTypeCode_kTVMNDArrayHandle),
$( $self_type($val) => { $from_self_type } ),+ $( $self_type($val) => { $from_self_type } ),+
} }
......
...@@ -24,7 +24,9 @@ ...@@ -24,7 +24,9 @@
//! # Example //! # Example
//! //!
//! ``` //! ```
//! let ctx = TVMContext::new(1, 0); //! # use tvm_frontend::{TVMDeviceType, TVMContext};
//! let cpu = TVMDeviceType::from("cpu");
//! let ctx = TVMContext::new(cpu , 0);
//! let cpu0 = TVMContext::cpu(0); //! let cpu0 = TVMContext::cpu(0);
//! assert_eq!(ctx, cpu0); //! assert_eq!(ctx, cpu0);
//! ``` //! ```
...@@ -32,6 +34,7 @@ ...@@ -32,6 +34,7 @@
//! Or from a supported device name. //! Or from a supported device name.
//! //!
//! ``` //! ```
//! use tvm_frontend::TVMContext;
//! let cpu0 = TVMContext::from("cpu"); //! let cpu0 = TVMContext::from("cpu");
//! println!("{}", cpu0); //! println!("{}", cpu0);
//! ``` //! ```
...@@ -55,6 +58,7 @@ use crate::{function, TVMArgValue}; ...@@ -55,6 +58,7 @@ use crate::{function, TVMArgValue};
/// ## Example /// ## Example
/// ///
/// ``` /// ```
/// use tvm_frontend::TVMDeviceType;
/// let cpu = TVMDeviceType::from("cpu"); /// let cpu = TVMDeviceType::from("cpu");
/// println!("device is: {}", cpu); /// println!("device is: {}", cpu);
///``` ///```
...@@ -152,7 +156,8 @@ impl<'a> From<&TVMDeviceType> for TVMArgValue<'a> { ...@@ -152,7 +156,8 @@ impl<'a> From<&TVMDeviceType> for TVMArgValue<'a> {
/// ## Examples /// ## Examples
/// ///
/// ``` /// ```
/// let ctx = TVMContext::from("gpu"); /// use tvm_frontend::TVMContext;
/// let ctx = TVMContext::from("cpu");
/// assert!(ctx.exist()); /// assert!(ctx.exist());
/// ///
/// ``` /// ```
...@@ -160,9 +165,12 @@ impl<'a> From<&TVMDeviceType> for TVMArgValue<'a> { ...@@ -160,9 +165,12 @@ impl<'a> From<&TVMDeviceType> for TVMArgValue<'a> {
/// It is possible to query the underlying context as follows /// It is possible to query the underlying context as follows
/// ///
/// ``` /// ```
/// println!("maximun threads per block: {}", ctx.max_threads_per_block()); /// # use tvm_frontend::TVMContext;
/// println!("compute version: {}", ctx.compute_version()); /// # let ctx = TVMContext::from("cpu");
/// println!("maximun threads per block: {}", ctx.exist());
/// ``` /// ```
// TODO: add example back for GPU
// println!("compute version: {}", ctx.compute_version());
#[derive(Debug, Default, Clone, Copy, Hash, PartialEq, Eq)] #[derive(Debug, Default, Clone, Copy, Hash, PartialEq, Eq)]
pub struct TVMContext { pub struct TVMContext {
/// Supported device types /// Supported device types
...@@ -215,11 +223,12 @@ impl<'a> From<&'a str> for TVMContext { ...@@ -215,11 +223,12 @@ impl<'a> From<&'a str> for TVMContext {
impl TVMContext { impl TVMContext {
/// Checks whether the context exists or not. /// Checks whether the context exists or not.
pub fn exist(&self) -> bool { pub fn exist(&self) -> bool {
let func = function::Function::get("_GetDeviceAttr").expect("API function always exists"); let func = function::Function::get("runtime.GetDeviceAttr")
let dt = self.device_type.0 as usize; .expect("TVM FFI functions must always be registered.");
let dt = self.device_type.0 as isize;
// `unwrap` is ok here because if there is any error, // `unwrap` is ok here because if there is any error,
// if would occure inside `call_packed!` // if would occure inside `call_packed!`
let ret: u64 = call_packed!(func, dt, self.device_id, 0) let ret: i64 = call_packed!(func, dt, self.device_id, 0)
.unwrap() .unwrap()
.try_into() .try_into()
.unwrap(); .unwrap();
...@@ -241,15 +250,17 @@ macro_rules! impl_device_attrs { ...@@ -241,15 +250,17 @@ macro_rules! impl_device_attrs {
($(($attr_name:ident, $attr_kind:expr));+) => { ($(($attr_name:ident, $attr_kind:expr));+) => {
$( $(
impl TVMContext { impl TVMContext {
pub fn $attr_name(&self) -> usize { pub fn $attr_name(&self) -> isize {
let func = function::Function::get("_GetDeviceAttr") let func = function::Function::get("runtime.GetDeviceAttr")
.expect("API function always exists"); .expect("TVM FFI functions must always be registered.");
let dt = self.device_type.0 as usize; let dt = self.device_type.0 as isize;
// TODO(@jroesch): these functions CAN and WILL return NULL
// we should make these optional or somesuch to handle this.
// `unwrap` is ok here because if there is any error, // `unwrap` is ok here because if there is any error,
// if would occur in function call. // if would occur in function call.
function::Builder::from(func) function::Builder::from(func)
.arg(dt) .arg(dt)
.arg(self.device_id as usize) .arg(self.device_id as isize)
.arg($attr_kind) .arg($attr_kind)
.invoke() .invoke()
.unwrap() .unwrap()
......
...@@ -47,12 +47,12 @@ lazy_static! { ...@@ -47,12 +47,12 @@ lazy_static! {
&mut names_ptr as *mut _, &mut names_ptr as *mut _,
)); ));
let names_list = unsafe { slice::from_raw_parts(names_ptr, out_size as usize) }; let names_list = unsafe { slice::from_raw_parts(names_ptr, out_size as usize) };
Mutex::new( let names_list = names_list
names_list .iter()
.iter() .map(|&p| (unsafe { CStr::from_ptr(p).to_str().unwrap() }, None))
.map(|&p| (unsafe { CStr::from_ptr(p).to_str().unwrap() }, None)) .collect();
.collect(),
) Mutex::new(names_list)
}; };
} }
...@@ -261,7 +261,10 @@ unsafe extern "C" fn tvm_callback( ...@@ -261,7 +261,10 @@ unsafe extern "C" fn tvm_callback(
|| tcode == ffi::TVMTypeCode_kTVMPackedFuncHandle as c_int || tcode == ffi::TVMTypeCode_kTVMPackedFuncHandle as c_int
|| tcode == ffi::TVMTypeCode_kTVMModuleHandle as c_int || tcode == ffi::TVMTypeCode_kTVMModuleHandle as c_int
{ {
check_call!(ffi::TVMCbArgToReturn(&mut value as *mut _, &mut tcode as *mut _)); check_call!(ffi::TVMCbArgToReturn(
&mut value as *mut _,
&mut tcode as *mut _
));
} }
local_args.push(TVMArgValue::from_tvm_value(value, tcode as u32)); local_args.push(TVMArgValue::from_tvm_value(value, tcode as u32));
} }
...@@ -313,6 +316,9 @@ fn convert_to_tvm_func(f: fn(&[TVMArgValue]) -> Result<TVMRetValue, Error>) -> F ...@@ -313,6 +316,9 @@ fn convert_to_tvm_func(f: fn(&[TVMArgValue]) -> Result<TVMRetValue, Error>) -> F
/// ## Example /// ## Example
/// ///
/// ``` /// ```
/// # use tvm_frontend::{TVMArgValue, function, TVMRetValue};
/// # use tvm_frontend::function::Builder;
/// # use failure::Error;
/// use std::convert::TryInto; /// use std::convert::TryInto;
/// ///
/// fn sum(args: &[TVMArgValue]) -> Result<TVMRetValue, Error> { /// fn sum(args: &[TVMArgValue]) -> Result<TVMRetValue, Error> {
...@@ -321,13 +327,13 @@ fn convert_to_tvm_func(f: fn(&[TVMArgValue]) -> Result<TVMRetValue, Error>) -> F ...@@ -321,13 +327,13 @@ fn convert_to_tvm_func(f: fn(&[TVMArgValue]) -> Result<TVMRetValue, Error>) -> F
/// let arg: i64 = arg.try_into()?; /// let arg: i64 = arg.try_into()?;
/// ret += arg; /// ret += arg;
/// } /// }
/// let ret_val = TVMRetValue::from(&ret); /// let ret_val = TVMRetValue::from(ret);
/// Ok(ret_val) /// Ok(ret_val)
/// } /// }
/// ///
/// tvm::function::register(sum, "mysum".to_owned(), false).unwrap(); /// function::register(sum, "mysum".to_owned(), false).unwrap();
/// let mut registered = function::Builder::default(); /// let mut registered = Builder::default();
/// registered.get_function("mysum", true); /// registered.get_function("mysum");
/// assert!(registered.func.is_some()); /// assert!(registered.func.is_some());
/// let ret: i64 = registered.args(&[10, 20, 30]).invoke().unwrap().try_into().unwrap(); /// let ret: i64 = registered.args(&[10, 20, 30]).invoke().unwrap().try_into().unwrap();
/// assert_eq!(ret, 60); /// assert_eq!(ret, 60);
...@@ -354,7 +360,10 @@ pub fn register<S: AsRef<str>>( ...@@ -354,7 +360,10 @@ pub fn register<S: AsRef<str>>(
/// ## Example /// ## Example
/// ///
/// ``` /// ```
/// use std::convert::TryInto; /// # use std::convert::TryInto;
/// # use tvm_frontend::{register_global_func, TVMArgValue, TVMRetValue};
/// # use failure::Error;
/// # use tvm_frontend::function::Builder;
/// ///
/// register_global_func! { /// register_global_func! {
/// fn sum(args: &[TVMArgValue]) -> Result<TVMRetValue, Error> { /// fn sum(args: &[TVMArgValue]) -> Result<TVMRetValue, Error> {
...@@ -363,13 +372,13 @@ pub fn register<S: AsRef<str>>( ...@@ -363,13 +372,13 @@ pub fn register<S: AsRef<str>>(
/// let arg: f64 = arg.try_into()?; /// let arg: f64 = arg.try_into()?;
/// ret += arg; /// ret += arg;
/// } /// }
/// let ret_val = TVMRetValue::from(&ret); /// let ret_val = TVMRetValue::from(ret);
/// Ok(ret_val) /// Ok(ret_val)
/// } /// }
/// } /// }
/// ///
/// let mut registered = function::Builder::default(); /// let mut registered = Builder::default();
/// registered.get_function("sum", true); /// registered.get_function("sum");
/// assert!(registered.func.is_some()); /// assert!(registered.func.is_some());
/// let ret: f64 = registered.args(&[10f64, 20f64, 30f64]).invoke().unwrap().try_into().unwrap(); /// let ret: f64 = registered.args(&[10f64, 20f64, 30f64]).invoke().unwrap().try_into().unwrap();
/// assert_eq!(ret, 60f64); /// assert_eq!(ret, 60f64);
...@@ -404,15 +413,14 @@ macro_rules! register_global_func { ...@@ -404,15 +413,14 @@ macro_rules! register_global_func {
/// ///
/// Instead of /// Instead of
/// ///
/// ``` /// # TODO(@jroesch): replace with working example
/// function::Builder::from(func).arg(&a).arg(&b).invoke(); /// # use tvm_frontend::function::Builder;
/// ``` /// Builder::from(func).arg(&a).arg(&b).invoke();
/// ///
/// one can use /// one can use
/// ///
/// ``` /// # use tvm_frontend::call_packed;
/// call_packed!(func, &a, &b); /// call_packed!(func, &a, &b);
/// ```
#[macro_export] #[macro_export]
macro_rules! call_packed { macro_rules! call_packed {
($fn_name:expr, $($arg:expr),*) => {{ ($fn_name:expr, $($arg:expr),*) => {{
...@@ -428,12 +436,12 @@ macro_rules! call_packed { ...@@ -428,12 +436,12 @@ macro_rules! call_packed {
mod tests { mod tests {
use super::*; use super::*;
static CANARY: &str = "module._LoadFromFile"; static CANARY: &str = "runtime.ModuleLoadFromFile";
#[test] // #[test]
fn list_global_func() { // fn list_global_func() {
assert!(GLOBAL_FUNCTIONS.lock().unwrap().contains_key(CANARY)); // assert!(GLOBAL_FUNCTIONS.lock().unwrap().contains_key(CANARY));
} // }
#[test] #[test]
fn get_fn() { fn get_fn() {
......
...@@ -53,11 +53,13 @@ pub use crate::{ ...@@ -53,11 +53,13 @@ pub use crate::{
ndarray::NDArray, ndarray::NDArray,
tvm_common::{ tvm_common::{
errors as common_errors, errors as common_errors,
ffi::{self, TVMByteArray, DLDataType}, ffi::{self, DLDataType, TVMByteArray},
packed_func::{TVMArgValue, TVMRetValue}, packed_func::{TVMArgValue, TVMRetValue},
}, },
}; };
pub type DataType = DLDataType;
// Macro to check the return call to TVM runtime shared library. // Macro to check the return call to TVM runtime shared library.
macro_rules! check_call { macro_rules! check_call {
($e:expr) => {{ ($e:expr) => {{
......
...@@ -94,7 +94,7 @@ impl Module { ...@@ -94,7 +94,7 @@ impl Module {
format_err!("Bad module load path: `{}`.", path.as_ref().display()) format_err!("Bad module load path: `{}`.", path.as_ref().display())
})?, })?,
)?; )?;
let func = Function::get("module._LoadFromFile").expect("API function always exists"); let func = Function::get("runtime.ModuleLoadFromFile").expect("API function always exists");
let cpath = let cpath =
CString::new(path.as_ref().to_str().ok_or_else(|| { CString::new(path.as_ref().to_str().ok_or_else(|| {
format_err!("Bad module load path: `{}`.", path.as_ref().display()) format_err!("Bad module load path: `{}`.", path.as_ref().display())
...@@ -105,7 +105,7 @@ impl Module { ...@@ -105,7 +105,7 @@ impl Module {
/// Checks if a target device is enabled for a module. /// Checks if a target device is enabled for a module.
pub fn enabled(&self, target: &str) -> bool { pub fn enabled(&self, target: &str) -> bool {
let func = Function::get("module._Enabled").expect("API function always exists"); let func = Function::get("runtime.RuntimeEnabled").expect("API function always exists");
// `unwrap` is safe here because if there is any error during the // `unwrap` is safe here because if there is any error during the
// function call, it would occur in `call_packed!`. // function call, it would occur in `call_packed!`.
let tgt = CString::new(target).unwrap(); let tgt = CString::new(target).unwrap();
......
...@@ -29,11 +29,16 @@ ...@@ -29,11 +29,16 @@
//! # Example //! # Example
//! //!
//! ``` //! ```
//! # use tvm_frontend::{NDArray, TVMContext, DataType};
//! # use ndarray::{Array, ArrayD};
//! # use std::str::FromStr;
//! use std::convert::TryFrom;
//!
//! let a = Array::from_shape_vec((2, 2), vec![1f32, 2., 3., 4.]) //! let a = Array::from_shape_vec((2, 2), vec![1f32, 2., 3., 4.])
//! .unwrap() //! .unwrap()
//! .into_dyn(); // Rust's ndarray //! .into_dyn(); // Rust's ndarray
//! let nd = NDArray::from_rust_ndarray(&a, TVMContext::cpu(0), TVMType::from("float32")).unwrap(); //! let nd = NDArray::from_rust_ndarray(&a, TVMContext::cpu(0), DataType::from_str("float32").unwrap()).unwrap();
//! assert_eq!(nd.shape(), Some(&mut [2, 2])); //! assert_eq!(nd.shape(), Some(&mut [2, 2][..]));
//! let rnd: ArrayD<f32> = ArrayD::try_from(&nd).unwrap(); //! let rnd: ArrayD<f32> = ArrayD::try_from(&nd).unwrap();
//! assert!(rnd.all_close(&a, 1e-8f32)); //! assert!(rnd.all_close(&a, 1e-8f32));
//! ``` //! ```
...@@ -47,6 +52,9 @@ use std::{convert::TryFrom, mem, os::raw::c_int, ptr, slice, str::FromStr}; ...@@ -47,6 +52,9 @@ use std::{convert::TryFrom, mem, os::raw::c_int, ptr, slice, str::FromStr};
use failure::Error; use failure::Error;
use num_traits::Num; use num_traits::Num;
use rust_ndarray::{Array, ArrayD}; use rust_ndarray::{Array, ArrayD};
use std::convert::TryInto;
use std::ffi::c_void;
use tvm_common::ffi::DLTensor;
use tvm_common::{ffi, TVMType}; use tvm_common::{ffi, TVMType};
use crate::{errors, TVMByteArray, TVMContext}; use crate::{errors, TVMByteArray, TVMContext};
...@@ -55,31 +63,49 @@ use crate::{errors, TVMByteArray, TVMContext}; ...@@ -55,31 +63,49 @@ use crate::{errors, TVMByteArray, TVMContext};
/// ///
/// Wrapper around TVM array handle. /// Wrapper around TVM array handle.
#[derive(Debug)] #[derive(Debug)]
pub struct NDArray { pub enum NDArray {
pub(crate) handle: ffi::TVMArrayHandle, Borrowed { handle: ffi::TVMArrayHandle },
is_view: bool, Owned { handle: *mut c_void },
} }
impl NDArray { impl NDArray {
pub(crate) fn new(handle: ffi::TVMArrayHandle) -> Self { pub(crate) fn new(handle: ffi::TVMArrayHandle) -> Self {
NDArray { NDArray::Borrowed { handle }
handle, }
is_view: true,
pub(crate) fn from_ndarray_handle(handle: *mut c_void) -> Self {
NDArray::Owned { handle }
}
pub fn as_dltensor(&self) -> &DLTensor {
unsafe {
match self {
NDArray::Borrowed { ref handle } => std::mem::transmute(*handle),
NDArray::Owned { ref handle } => std::mem::transmute(*handle),
}
} }
} }
/// Returns the underlying array handle. pub(crate) fn as_raw_dltensor(&self) -> *mut DLTensor {
pub fn handle(&self) -> ffi::TVMArrayHandle { unsafe {
self.handle match self {
NDArray::Borrowed { ref handle } => std::mem::transmute(*handle),
NDArray::Owned { ref handle } => std::mem::transmute(*handle),
}
}
} }
pub fn is_view(&self) -> bool { pub fn is_view(&self) -> bool {
self.is_view if let &NDArray::Borrowed { .. } = self {
true
} else {
false
}
} }
/// Returns the shape of the NDArray. /// Returns the shape of the NDArray.
pub fn shape(&self) -> Option<&mut [usize]> { pub fn shape(&self) -> Option<&mut [usize]> {
let arr = unsafe { *(self.handle) }; let arr = self.as_dltensor();
if arr.shape.is_null() || arr.data.is_null() { if arr.shape.is_null() || arr.data.is_null() {
return None; return None;
}; };
...@@ -94,24 +120,28 @@ impl NDArray { ...@@ -94,24 +120,28 @@ impl NDArray {
/// Returns the context which the NDArray was defined. /// Returns the context which the NDArray was defined.
pub fn ctx(&self) -> TVMContext { pub fn ctx(&self) -> TVMContext {
unsafe { (*self.handle).ctx.into() } self.as_dltensor().ctx.into()
} }
/// Returns the type of the entries of the NDArray. /// Returns the type of the entries of the NDArray.
pub fn dtype(&self) -> TVMType { pub fn dtype(&self) -> TVMType {
unsafe { (*self.handle).dtype } self.as_dltensor().dtype
} }
/// Returns the number of dimensions of the NDArray. /// Returns the number of dimensions of the NDArray.
pub fn ndim(&self) -> usize { pub fn ndim(&self) -> usize {
unsafe { (*self.handle).ndim as usize } self.as_dltensor()
.ndim
.try_into()
.expect("number of dimensions must always be positive")
} }
/// Returns the strides of the underlying NDArray. /// Returns the strides of the underlying NDArray.
pub fn strides(&self) -> Option<&[usize]> { pub fn strides(&self) -> Option<&[usize]> {
unsafe { unsafe {
let sz = self.ndim() * mem::size_of::<usize>(); let sz = self.ndim() * mem::size_of::<usize>();
let slc = slice::from_raw_parts((*self.handle).strides as *const usize, sz); let strides_ptr = self.as_dltensor().strides as *const usize;
let slc = slice::from_raw_parts(strides_ptr, sz);
Some(slc) Some(slc)
} }
} }
...@@ -141,7 +171,7 @@ impl NDArray { ...@@ -141,7 +171,7 @@ impl NDArray {
} }
pub fn byte_offset(&self) -> isize { pub fn byte_offset(&self) -> isize {
unsafe { (*self.handle).byte_offset as isize } self.as_dltensor().byte_offset as isize
} }
/// Flattens the NDArray to a `Vec` of the same type in cpu. /// Flattens the NDArray to a `Vec` of the same type in cpu.
...@@ -149,12 +179,14 @@ impl NDArray { ...@@ -149,12 +179,14 @@ impl NDArray {
/// ## Example /// ## Example
/// ///
/// ``` /// ```
/// let shape = &mut [4]; /// # use tvm_frontend::{TVMContext, DataType, NDArray};
/// # use std::str::FromStr;
/// let mut shape = [4];
/// let mut data = vec![1i32, 2, 3, 4]; /// let mut data = vec![1i32, 2, 3, 4];
/// let ctx = TVMContext::cpu(0); /// let ctx = TVMContext::cpu(0);
/// let mut ndarray = empty(shape, ctx, TVMType::from("int32")); /// let mut ndarray = NDArray::empty(&mut shape, ctx, DataType::from_str("int32").unwrap());
/// ndarray.copy_from_buffer(&mut data); /// ndarray.copy_from_buffer(&mut data);
/// assert_eq!(ndarray.shape(), Some(shape)); /// assert_eq!(ndarray.shape(), Some(&mut shape[..]));
/// assert_eq!(ndarray.to_vec::<i32>().unwrap(), data); /// assert_eq!(ndarray.to_vec::<i32>().unwrap(), data);
/// ``` /// ```
pub fn to_vec<T>(&self) -> Result<Vec<T>, Error> { pub fn to_vec<T>(&self) -> Result<Vec<T>, Error> {
...@@ -165,7 +197,7 @@ impl NDArray { ...@@ -165,7 +197,7 @@ impl NDArray {
self.dtype(), self.dtype(),
); );
let target = self.copy_to_ndarray(earr)?; let target = self.copy_to_ndarray(earr)?;
let arr = unsafe { *(target.handle) }; let arr = target.as_dltensor();
let sz = self.size().ok_or(errors::MissingShapeError)?; let sz = self.size().ok_or(errors::MissingShapeError)?;
let mut v: Vec<T> = Vec::with_capacity(sz * mem::size_of::<T>()); let mut v: Vec<T> = Vec::with_capacity(sz * mem::size_of::<T>());
unsafe { unsafe {
...@@ -187,10 +219,12 @@ impl NDArray { ...@@ -187,10 +219,12 @@ impl NDArray {
/// ## Example /// ## Example
/// ///
/// ``` /// ```
/// # use tvm_frontend::{TVMContext, DataType, NDArray};
/// # use std::str::FromStr;
/// let shape = &mut [2]; /// let shape = &mut [2];
/// let mut data = vec![1f32, 2]; /// let mut data = vec![1f32, 2.0];
/// let ctx = TVMContext::gpu(0); /// let ctx = TVMContext::cpu(0);
/// let mut ndarray = empty(shape, ctx, TVMType::from("int32")); /// let mut ndarray = NDArray::empty(shape, ctx, DataType::from_str("int32").unwrap());
/// ndarray.copy_from_buffer(&mut data); /// ndarray.copy_from_buffer(&mut data);
/// ``` /// ```
/// ///
...@@ -198,7 +232,7 @@ impl NDArray { ...@@ -198,7 +232,7 @@ impl NDArray {
/// from TVM side. See `TVMArrayCopyFromBytes` in `include/tvm/runtime/c_runtime_api.h`. /// from TVM side. See `TVMArrayCopyFromBytes` in `include/tvm/runtime/c_runtime_api.h`.
pub fn copy_from_buffer<T: Num32>(&mut self, data: &mut [T]) { pub fn copy_from_buffer<T: Num32>(&mut self, data: &mut [T]) {
check_call!(ffi::TVMArrayCopyFromBytes( check_call!(ffi::TVMArrayCopyFromBytes(
self.handle, self.as_raw_dltensor(),
data.as_ptr() as *mut _, data.as_ptr() as *mut _,
data.len() * mem::size_of::<T>() data.len() * mem::size_of::<T>()
)); ));
...@@ -216,8 +250,8 @@ impl NDArray { ...@@ -216,8 +250,8 @@ impl NDArray {
); );
} }
check_call!(ffi::TVMArrayCopyFromTo( check_call!(ffi::TVMArrayCopyFromTo(
self.handle, self.as_raw_dltensor(),
target.handle, target.as_raw_dltensor(),
ptr::null_mut() as ffi::TVMStreamHandle ptr::null_mut() as ffi::TVMStreamHandle
)); ));
Ok(target) Ok(target)
...@@ -263,10 +297,7 @@ impl NDArray { ...@@ -263,10 +297,7 @@ impl NDArray {
ctx.device_id as c_int, ctx.device_id as c_int,
&mut handle as *mut _, &mut handle as *mut _,
)); ));
NDArray { NDArray::Borrowed { handle: handle }
handle,
is_view: false,
}
} }
} }
...@@ -304,8 +335,8 @@ impl_from_ndarray_rustndarray!(f32, "float"); ...@@ -304,8 +335,8 @@ impl_from_ndarray_rustndarray!(f32, "float");
impl Drop for NDArray { impl Drop for NDArray {
fn drop(&mut self) { fn drop(&mut self) {
if !self.is_view { if let &mut NDArray::Owned { .. } = self {
check_call!(ffi::TVMArrayFree(self.handle)); check_call!(ffi::TVMArrayFree(self.as_raw_dltensor()));
} }
} }
} }
......
...@@ -22,15 +22,15 @@ ...@@ -22,15 +22,15 @@
//! `TVMRetValue` is the owned version of `TVMPODValue`. //! `TVMRetValue` is the owned version of `TVMPODValue`.
use std::convert::TryFrom; use std::convert::TryFrom;
// use std::ffi::c_void;
use crate::{Function, Module, NDArray, TVMArgValue, TVMRetValue};
use tvm_common::{ use tvm_common::{
errors::ValueDowncastError, errors::ValueDowncastError,
ffi::{TVMArrayHandle, TVMFunctionHandle, TVMModuleHandle}, ffi::{TVMFunctionHandle, TVMModuleHandle},
try_downcast, try_downcast,
}; };
use crate::{Function, Module, NDArray, TVMArgValue, TVMRetValue};
macro_rules! impl_handle_val { macro_rules! impl_handle_val {
($type:ty, $variant:ident, $inner_type:ty, $ctor:path) => { ($type:ty, $variant:ident, $inner_type:ty, $ctor:path) => {
impl<'a> From<&'a $type> for TVMArgValue<'a> { impl<'a> From<&'a $type> for TVMArgValue<'a> {
...@@ -76,7 +76,60 @@ macro_rules! impl_handle_val { ...@@ -76,7 +76,60 @@ macro_rules! impl_handle_val {
impl_handle_val!(Function, FuncHandle, TVMFunctionHandle, Function::new); impl_handle_val!(Function, FuncHandle, TVMFunctionHandle, Function::new);
impl_handle_val!(Module, ModuleHandle, TVMModuleHandle, Module::new); impl_handle_val!(Module, ModuleHandle, TVMModuleHandle, Module::new);
impl_handle_val!(NDArray, ArrayHandle, TVMArrayHandle, NDArray::new);
impl<'a> From<&'a NDArray> for TVMArgValue<'a> {
fn from(arg: &'a NDArray) -> Self {
match arg {
&NDArray::Borrowed { handle } => TVMArgValue::ArrayHandle(handle),
&NDArray::Owned { handle } => TVMArgValue::NDArrayHandle(handle),
}
}
}
impl<'a> From<&'a mut NDArray> for TVMArgValue<'a> {
fn from(arg: &'a mut NDArray) -> Self {
match arg {
&mut NDArray::Borrowed { handle } => TVMArgValue::ArrayHandle(handle),
&mut NDArray::Owned { handle } => TVMArgValue::NDArrayHandle(handle),
}
}
}
impl<'a> TryFrom<TVMArgValue<'a>> for NDArray {
type Error = ValueDowncastError;
fn try_from(val: TVMArgValue<'a>) -> Result<NDArray, Self::Error> {
try_downcast!(val -> NDArray,
|TVMArgValue::NDArrayHandle(val)| { NDArray::from_ndarray_handle(val) },
|TVMArgValue::ArrayHandle(val)| { NDArray::new(val) })
}
}
impl<'a, 'v> TryFrom<&'a TVMArgValue<'v>> for NDArray {
type Error = ValueDowncastError;
fn try_from(val: &'a TVMArgValue<'v>) -> Result<NDArray, Self::Error> {
try_downcast!(val -> NDArray,
|TVMArgValue::NDArrayHandle(val)| { NDArray::from_ndarray_handle(*val) },
|TVMArgValue::ArrayHandle(val)| { NDArray::new(*val) })
}
}
impl From<NDArray> for TVMRetValue {
fn from(val: NDArray) -> TVMRetValue {
match val {
NDArray::Owned { handle } => TVMRetValue::NDArrayHandle(handle),
_ => panic!("NYI"),
}
}
}
impl TryFrom<TVMRetValue> for NDArray {
type Error = ValueDowncastError;
fn try_from(val: TVMRetValue) -> Result<NDArray, Self::Error> {
try_downcast!(val -> NDArray,
|TVMRetValue::NDArrayHandle(val)| { NDArray::from_ndarray_handle(val) },
|TVMRetValue::ArrayHandle(val)| { NDArray::new(val) })
}
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
......
...@@ -68,5 +68,5 @@ fn main() { ...@@ -68,5 +68,5 @@ fn main() {
.unwrap() .unwrap()
.try_into() .try_into()
.unwrap(); .unwrap();
assert_eq!(ret, 14f32); assert_eq!(ret, 7f32);
} }
...@@ -19,10 +19,10 @@ ...@@ -19,10 +19,10 @@
extern crate proc_macro; extern crate proc_macro;
use quote::quote;
use std::{fs::File, io::Read}; use std::{fs::File, io::Read};
use syn::parse::{Parse, ParseStream, Result}; use syn::parse::{Parse, ParseStream, Result};
use syn::{LitStr}; use syn::LitStr;
use quote::quote;
use std::path::PathBuf; use std::path::PathBuf;
...@@ -33,9 +33,7 @@ struct ImportModule { ...@@ -33,9 +33,7 @@ struct ImportModule {
impl Parse for ImportModule { impl Parse for ImportModule {
fn parse(input: ParseStream) -> Result<Self> { fn parse(input: ParseStream) -> Result<Self> {
let importing_file: LitStr = input.parse()?; let importing_file: LitStr = input.parse()?;
Ok(ImportModule { Ok(ImportModule { importing_file })
importing_file,
})
} }
} }
...@@ -43,8 +41,8 @@ impl Parse for ImportModule { ...@@ -43,8 +41,8 @@ impl Parse for ImportModule {
pub fn import_module(input: proc_macro::TokenStream) -> proc_macro::TokenStream { pub fn import_module(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
let import_module_args = syn::parse_macro_input!(input as ImportModule); let import_module_args = syn::parse_macro_input!(input as ImportModule);
let manifest = std::env::var("CARGO_MANIFEST_DIR") let manifest =
.expect("variable should always be set by Cargo."); std::env::var("CARGO_MANIFEST_DIR").expect("variable should always be set by Cargo.");
let mut path = PathBuf::new(); let mut path = PathBuf::new();
path.push(manifest); path.push(manifest);
......
...@@ -42,7 +42,8 @@ impl Module for SystemLibModule { ...@@ -42,7 +42,8 @@ impl Module for SystemLibModule {
SYSTEM_LIB_FUNCTIONS SYSTEM_LIB_FUNCTIONS
.lock() .lock()
.unwrap() .unwrap()
.get(name.as_ref()).copied() .get(name.as_ref())
.copied()
} }
} }
......
...@@ -27,7 +27,7 @@ use std::{ ...@@ -27,7 +27,7 @@ use std::{
thread::{self, JoinHandle}, thread::{self, JoinHandle},
}; };
use crossbeam::channel::{Sender, Receiver, bounded}; use crossbeam::channel::{bounded, Receiver, Sender};
use tvm_common::ffi::TVMParallelGroupEnv; use tvm_common::ffi::TVMParallelGroupEnv;
pub(crate) type FTVMParallelLambda = pub(crate) type FTVMParallelLambda =
...@@ -138,8 +138,7 @@ impl ThreadPool { ...@@ -138,8 +138,7 @@ impl ThreadPool {
let mut tasks = job.tasks(self.num_workers + 1); let mut tasks = job.tasks(self.num_workers + 1);
for (i, task) in tasks.split_off(1).into_iter().enumerate() { for (i, task) in tasks.split_off(1).into_iter().enumerate() {
self.threads.queues[i].send(task) self.threads.queues[i].send(task).expect("should send");
.expect("should send");
} }
tasks.pop().unwrap().run(); tasks.pop().unwrap().run();
......
...@@ -31,7 +31,7 @@ CWD = osp.dirname(osp.abspath(osp.expanduser(__file__))) ...@@ -31,7 +31,7 @@ CWD = osp.dirname(osp.abspath(osp.expanduser(__file__)))
def _get_model(dshape): def _get_model(dshape):
data = relay.var('data', shape=dshape) data = relay.var('data', shape=dshape)
fc = relay.nn.dense(data, relay.var("dense_weight"), units=dshape[-1]*2) fc = relay.nn.dense(data, relay.var("dense_weight"), units=dshape[-1]*2)
fc = relay.nn.bias_add(data, relay.var("dense_bias")) fc = relay.nn.bias_add(fc, relay.var("dense_bias"))
left, right = relay.split(fc, indices_or_sections=2, axis=1) left, right = relay.split(fc, indices_or_sections=2, axis=1)
one = relay.const(1, dtype="float32") one = relay.const(1, dtype="float32")
return relay.Tuple([(left + one), (right - one), fc]) return relay.Tuple([(left + one), (right - one), fc])
......
...@@ -75,9 +75,9 @@ fn test_load_graph() { ...@@ -75,9 +75,9 @@ fn test_load_graph() {
.unwrap() .unwrap()
.get("func_name") .get("func_name")
.unwrap(), .unwrap(),
"fuse_dense" "fused_nn_dense_nn_bias_add"
); );
assert_eq!(graph.nodes[5].inputs[0].index, 0); assert_eq!(graph.nodes[3].inputs[0].index, 0);
assert_eq!(graph.nodes[6].inputs[0].index, 1); assert_eq!(graph.nodes[4].inputs[0].index, 0);
assert_eq!(graph.heads.len(), 2); assert_eq!(graph.heads.len(), 3);
} }
...@@ -25,16 +25,24 @@ use ar::Builder; ...@@ -25,16 +25,24 @@ use ar::Builder;
fn main() { fn main() {
let out_dir = env::var("OUT_DIR").unwrap(); let out_dir = env::var("OUT_DIR").unwrap();
let out_dir = Path::new(&out_dir).join("test_nn");
std::fs::create_dir_all(&out_dir).unwrap();
let manifest_dir = env::var("CARGO_MANIFEST_DIR").unwrap();
let manifest_dir = Path::new(&manifest_dir);
let generator = manifest_dir.join("src").join("build_test_graph.py");
let graph_path = out_dir.join("graph.o");
let output = Command::new(&generator)
.arg(&out_dir)
.output()
.expect("Failed to execute command");
let output = Command::new(concat!(
env!("CARGO_MANIFEST_DIR"),
"/src/build_test_graph.py"
))
.arg(&out_dir)
.output()
.expect("Failed to execute command");
assert!( assert!(
Path::new(&format!("{}/graph.o", out_dir)).exists(), graph_path.exists(),
"Could not build graph lib: {}", "Could not build graph lib: {}",
String::from_utf8(output.stderr) String::from_utf8(output.stderr)
.unwrap() .unwrap()
...@@ -44,10 +52,10 @@ fn main() { ...@@ -44,10 +52,10 @@ fn main() {
.unwrap_or("") .unwrap_or("")
); );
let lib_file = format!("{}/libtestnn.a", out_dir); let lib_file = out_dir.join("libtestnn.a");
let file = File::create(&lib_file).unwrap(); let file = File::create(&lib_file).unwrap();
let mut builder = Builder::new(file); let mut builder = Builder::new(file);
builder.append_path(format!("{}/graph.o", out_dir)).unwrap(); builder.append_path(graph_path).unwrap();
let status = Command::new("ranlib") let status = Command::new("ranlib")
.arg(&lib_file) .arg(&lib_file)
...@@ -56,7 +64,7 @@ fn main() { ...@@ -56,7 +64,7 @@ fn main() {
assert!(status.success()); assert!(status.success());
println!("cargo:rustc-link-lib=static=testnn"); println!("cargo:rustc-link-lib=static=testnn");
println!("cargo:rustc-link-search=native={}", out_dir); println!("cargo:rustc-link-search=native={}", out_dir.display());
println!("cargo:rerun-if-changed={}", generator.display());
} }
...@@ -31,7 +31,7 @@ from tvm.relay import testing ...@@ -31,7 +31,7 @@ from tvm.relay import testing
def _get_model(dshape): def _get_model(dshape):
data = relay.var('data', shape=dshape) data = relay.var('data', shape=dshape)
fc = relay.nn.dense(data, relay.var("dense_weight"), units=dshape[-1]*2) fc = relay.nn.dense(data, relay.var("dense_weight"), units=dshape[-1]*2)
fc = relay.nn.bias_add(data, relay.var("dense_bias")) fc = relay.nn.bias_add(fc, relay.var("dense_bias"))
left, right = relay.split(fc, indices_or_sections=2, axis=1) left, right = relay.split(fc, indices_or_sections=2, axis=1)
one = relay.const(1, dtype="float32") one = relay.const(1, dtype="float32")
return relay.Tuple([(left + one), (right - one), fc]) return relay.Tuple([(left + one), (right - one), fc])
......
...@@ -51,7 +51,7 @@ fn main() { ...@@ -51,7 +51,7 @@ fn main() {
let syslib = SystemLibModule::default(); let syslib = SystemLibModule::default();
let mut params_bytes = Vec::new(); let mut params_bytes = Vec::new();
fs::File::open(concat!(env!("OUT_DIR"), "/graph.params")) fs::File::open(concat!(env!("OUT_DIR"), "/test_nn/graph.params"))
.unwrap() .unwrap()
.read_to_end(&mut params_bytes) .read_to_end(&mut params_bytes)
.unwrap(); .unwrap();
...@@ -61,9 +61,10 @@ fn main() { ...@@ -61,9 +61,10 @@ fn main() {
.map(|(k, v)| (k, v.to_owned())) .map(|(k, v)| (k, v.to_owned()))
.collect::<HashMap<String, Tensor<'static>>>(); .collect::<HashMap<String, Tensor<'static>>>();
let graph = let graph = Graph::try_from(
Graph::try_from(&fs::read_to_string(concat!(env!("OUT_DIR"), "/graph.json")).unwrap()) &fs::read_to_string(concat!(env!("OUT_DIR"), "/test_nn/graph.json")).unwrap(),
.unwrap(); )
.unwrap();
let mut exec = GraphExecutor::new(graph, &syslib).unwrap(); let mut exec = GraphExecutor::new(graph, &syslib).unwrap();
let x = Array::from_shape_vec( let x = Array::from_shape_vec(
...@@ -73,11 +74,16 @@ fn main() { ...@@ -73,11 +74,16 @@ fn main() {
.collect::<Vec<f32>>(), .collect::<Vec<f32>>(),
) )
.unwrap(); .unwrap();
let w = Array::try_from(params.get("dense0_weight").unwrap().to_owned())
let p0 = params.get("p0").unwrap().to_owned();
let p1 = params.get("p1").unwrap().to_owned();
println!("p0: {:?}", p0.shape());
println!("p1: {:?}", p1.shape());
let w = Array::try_from(p0)
.unwrap() .unwrap()
.into_shape((IN_DIM * 2, IN_DIM)) .into_shape((BATCH_SIZE * 4, IN_DIM))
.unwrap(); .unwrap();
let b = Array::try_from(params.get("dense0_bias").unwrap().to_owned()).unwrap(); let b = Array::try_from(p1).unwrap();
let dense = x.dot(&w.t()) + &b; let dense = x.dot(&w.t()) + &b;
let left = dense.slice(s![.., 0..IN_DIM]); let left = dense.slice(s![.., 0..IN_DIM]);
let right = dense.slice(s![.., IN_DIM..]); let right = dense.slice(s![.., IN_DIM..]);
...@@ -88,8 +94,8 @@ fn main() { ...@@ -88,8 +94,8 @@ fn main() {
exec.set_input("data", (&x).into()); exec.set_input("data", (&x).into());
check_sum!(exec, data, x); check_sum!(exec, data, x);
check_sum!(exec, dense0_weight, w); check_sum!(exec, p0, w);
check_sum!(exec, dense0_bias, b); check_sum!(exec, p1, b);
exec.run(); exec.run();
......
...@@ -19,15 +19,13 @@ ...@@ -19,15 +19,13 @@
set -e set -e
set -u set -u
# Temporary disable rust tests
# remove this line to re-enable.
exit 0
export TVM_HOME="$(git rev-parse --show-toplevel)" export TVM_HOME="$(git rev-parse --show-toplevel)"
export LD_LIBRARY_PATH="$TVM_HOME/lib:$TVM_HOME/build:${LD_LIBRARY_PATH:-}" export LD_LIBRARY_PATH="$TVM_HOME/lib:$TVM_HOME/build:${LD_LIBRARY_PATH:-}"
export PYTHONPATH="$TVM_HOME/python":"$TVM_HOME/topi/python" export PYTHONPATH="$TVM_HOME/python":"$TVM_HOME/topi/python"
export RUST_DIR="$TVM_HOME/rust" export RUST_DIR="$TVM_HOME/rust"
export LLVM_CONFIG_PATH=`which llvm-config-8`
echo "Using $LLVM_CONFIG_PATH"
cd $RUST_DIR cd $RUST_DIR
cargo fmt -- --check cargo fmt -- --check
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment