Unverified Commit 9c591510 by Jared Roesch Committed by GitHub

[Rust][CI] Restore Rust CI (#5137)

parent 8c31d0dd
......@@ -20,62 +20,12 @@ hard_tabs = false
tab_spaces = 4
newline_style = "Auto"
use_small_heuristics = "Default"
indent_style = "Block"
wrap_comments = false
format_code_in_doc_comments = false
comment_width = 80
normalize_comments = false
normalize_doc_attributes = false
format_strings = false
format_macro_matchers = false
format_macro_bodies = true
empty_item_single_line = true
struct_lit_single_line = true
fn_single_line = false
where_single_line = false
imports_indent = "Block"
imports_layout = "Mixed"
merge_imports = true
reorder_imports = true
reorder_modules = true
reorder_impl_items = false
type_punctuation_density = "Wide"
space_before_colon = false
space_after_colon = true
spaces_around_ranges = false
binop_separator = "Front"
remove_nested_parens = true
combine_control_expr = true
overflow_delimited_expr = false
struct_field_align_threshold = 0
enum_discrim_align_threshold = 0
match_arm_blocks = true
force_multiline_blocks = false
fn_args_layout = "Tall"
brace_style = "SameLineWhere"
control_brace_style = "AlwaysSameLine"
trailing_semicolon = true
trailing_comma = "Vertical"
match_block_trailing_comma = false
blank_lines_upper_bound = 1
blank_lines_lower_bound = 0
edition = "2018"
version = "One"
inline_attribute_width = 0
merge_derives = true
use_try_shorthand = false
use_field_init_shorthand = false
force_explicit_abi = true
condense_wildcard_suffixes = false
color = "Auto"
unstable_features = false
disable_all_formatting = false
skip_children = false
hide_parse_errors = false
error_on_line_overflow = false
error_on_unformatted = false
report_todo = "Never"
report_fixme = "Never"
ignore = []
emit_mode = "Files"
make_backup = false
......@@ -42,5 +42,5 @@ pub mod packed_func;
pub mod value;
pub use errors::*;
pub use ffi::{TVMByteArray, TVMContext, DLDataType as TVMType};
pub use ffi::{DLDataType as TVMType, TVMByteArray, TVMContext};
pub use packed_func::{TVMArgValue, TVMRetValue};
......@@ -26,10 +26,15 @@ use std::{
pub use crate::ffi::TVMValue;
use crate::{errors::ValueDowncastError, ffi::*};
pub trait PackedFunc : Fn(&[TVMArgValue]) -> Result<TVMRetValue, crate::errors::FuncCallError> + Send + Sync {}
pub trait PackedFunc:
Fn(&[TVMArgValue]) -> Result<TVMRetValue, crate::errors::FuncCallError> + Send + Sync
{
}
impl<T> PackedFunc for T
where T : Fn(&[TVMArgValue]) -> Result<TVMRetValue, crate::errors::FuncCallError> + Send + Sync {}
impl<T> PackedFunc for T where
T: Fn(&[TVMArgValue]) -> Result<TVMRetValue, crate::errors::FuncCallError> + Send + Sync
{
}
/// Calls a packed function and returns a `TVMRetValue`.
///
......@@ -76,7 +81,7 @@ macro_rules! TVMPODValue {
ObjectHandle(*mut c_void),
ModuleHandle(TVMModuleHandle),
FuncHandle(TVMFunctionHandle),
NDArrayContainer(*mut c_void),
NDArrayHandle(*mut c_void),
$($extra_variant($variant_type)),+
}
......@@ -97,7 +102,7 @@ macro_rules! TVMPODValue {
TVMTypeCode_kTVMObjectHandle => ObjectHandle($value.v_handle),
TVMTypeCode_kTVMModuleHandle => ModuleHandle($value.v_handle),
TVMTypeCode_kTVMPackedFuncHandle => FuncHandle($value.v_handle),
TVMTypeCode_kTVMNDArrayHandle => NDArrayContainer($value.v_handle),
TVMTypeCode_kTVMNDArrayHandle => NDArrayHandle($value.v_handle),
$( $tvm_type => { $from_tvm_type } ),+
_ => unimplemented!("{}", type_code),
}
......@@ -133,7 +138,7 @@ macro_rules! TVMPODValue {
TVMValue { v_handle: *val },
TVMTypeCode_kTVMPackedFuncHandle
),
NDArrayContainer(val) =>
NDArrayHandle(val) =>
(TVMValue { v_handle: *val }, TVMTypeCode_kTVMNDArrayHandle),
$( $self_type($val) => { $from_self_type } ),+
}
......
......@@ -24,7 +24,9 @@
//! # Example
//!
//! ```
//! let ctx = TVMContext::new(1, 0);
//! # use tvm_frontend::{TVMDeviceType, TVMContext};
//! let cpu = TVMDeviceType::from("cpu");
//! let ctx = TVMContext::new(cpu , 0);
//! let cpu0 = TVMContext::cpu(0);
//! assert_eq!(ctx, cpu0);
//! ```
......@@ -32,6 +34,7 @@
//! Or from a supported device name.
//!
//! ```
//! use tvm_frontend::TVMContext;
//! let cpu0 = TVMContext::from("cpu");
//! println!("{}", cpu0);
//! ```
......@@ -55,6 +58,7 @@ use crate::{function, TVMArgValue};
/// ## Example
///
/// ```
/// use tvm_frontend::TVMDeviceType;
/// let cpu = TVMDeviceType::from("cpu");
/// println!("device is: {}", cpu);
///```
......@@ -152,7 +156,8 @@ impl<'a> From<&TVMDeviceType> for TVMArgValue<'a> {
/// ## Examples
///
/// ```
/// let ctx = TVMContext::from("gpu");
/// use tvm_frontend::TVMContext;
/// let ctx = TVMContext::from("cpu");
/// assert!(ctx.exist());
///
/// ```
......@@ -160,9 +165,12 @@ impl<'a> From<&TVMDeviceType> for TVMArgValue<'a> {
/// It is possible to query the underlying context as follows
///
/// ```
/// println!("maximun threads per block: {}", ctx.max_threads_per_block());
/// println!("compute version: {}", ctx.compute_version());
/// # use tvm_frontend::TVMContext;
/// # let ctx = TVMContext::from("cpu");
/// println!("maximun threads per block: {}", ctx.exist());
/// ```
// TODO: add example back for GPU
// println!("compute version: {}", ctx.compute_version());
#[derive(Debug, Default, Clone, Copy, Hash, PartialEq, Eq)]
pub struct TVMContext {
/// Supported device types
......@@ -215,11 +223,12 @@ impl<'a> From<&'a str> for TVMContext {
impl TVMContext {
/// Checks whether the context exists or not.
pub fn exist(&self) -> bool {
let func = function::Function::get("_GetDeviceAttr").expect("API function always exists");
let dt = self.device_type.0 as usize;
let func = function::Function::get("runtime.GetDeviceAttr")
.expect("TVM FFI functions must always be registered.");
let dt = self.device_type.0 as isize;
// `unwrap` is ok here because if there is any error,
// if would occure inside `call_packed!`
let ret: u64 = call_packed!(func, dt, self.device_id, 0)
let ret: i64 = call_packed!(func, dt, self.device_id, 0)
.unwrap()
.try_into()
.unwrap();
......@@ -241,15 +250,17 @@ macro_rules! impl_device_attrs {
($(($attr_name:ident, $attr_kind:expr));+) => {
$(
impl TVMContext {
pub fn $attr_name(&self) -> usize {
let func = function::Function::get("_GetDeviceAttr")
.expect("API function always exists");
let dt = self.device_type.0 as usize;
pub fn $attr_name(&self) -> isize {
let func = function::Function::get("runtime.GetDeviceAttr")
.expect("TVM FFI functions must always be registered.");
let dt = self.device_type.0 as isize;
// TODO(@jroesch): these functions CAN and WILL return NULL
// we should make these optional or somesuch to handle this.
// `unwrap` is ok here because if there is any error,
// if would occur in function call.
function::Builder::from(func)
.arg(dt)
.arg(self.device_id as usize)
.arg(self.device_id as isize)
.arg($attr_kind)
.invoke()
.unwrap()
......
......@@ -47,12 +47,12 @@ lazy_static! {
&mut names_ptr as *mut _,
));
let names_list = unsafe { slice::from_raw_parts(names_ptr, out_size as usize) };
Mutex::new(
names_list
let names_list = names_list
.iter()
.map(|&p| (unsafe { CStr::from_ptr(p).to_str().unwrap() }, None))
.collect(),
)
.collect();
Mutex::new(names_list)
};
}
......@@ -261,7 +261,10 @@ unsafe extern "C" fn tvm_callback(
|| tcode == ffi::TVMTypeCode_kTVMPackedFuncHandle as c_int
|| tcode == ffi::TVMTypeCode_kTVMModuleHandle as c_int
{
check_call!(ffi::TVMCbArgToReturn(&mut value as *mut _, &mut tcode as *mut _));
check_call!(ffi::TVMCbArgToReturn(
&mut value as *mut _,
&mut tcode as *mut _
));
}
local_args.push(TVMArgValue::from_tvm_value(value, tcode as u32));
}
......@@ -313,6 +316,9 @@ fn convert_to_tvm_func(f: fn(&[TVMArgValue]) -> Result<TVMRetValue, Error>) -> F
/// ## Example
///
/// ```
/// # use tvm_frontend::{TVMArgValue, function, TVMRetValue};
/// # use tvm_frontend::function::Builder;
/// # use failure::Error;
/// use std::convert::TryInto;
///
/// fn sum(args: &[TVMArgValue]) -> Result<TVMRetValue, Error> {
......@@ -321,13 +327,13 @@ fn convert_to_tvm_func(f: fn(&[TVMArgValue]) -> Result<TVMRetValue, Error>) -> F
/// let arg: i64 = arg.try_into()?;
/// ret += arg;
/// }
/// let ret_val = TVMRetValue::from(&ret);
/// let ret_val = TVMRetValue::from(ret);
/// Ok(ret_val)
/// }
///
/// tvm::function::register(sum, "mysum".to_owned(), false).unwrap();
/// let mut registered = function::Builder::default();
/// registered.get_function("mysum", true);
/// function::register(sum, "mysum".to_owned(), false).unwrap();
/// let mut registered = Builder::default();
/// registered.get_function("mysum");
/// assert!(registered.func.is_some());
/// let ret: i64 = registered.args(&[10, 20, 30]).invoke().unwrap().try_into().unwrap();
/// assert_eq!(ret, 60);
......@@ -354,7 +360,10 @@ pub fn register<S: AsRef<str>>(
/// ## Example
///
/// ```
/// use std::convert::TryInto;
/// # use std::convert::TryInto;
/// # use tvm_frontend::{register_global_func, TVMArgValue, TVMRetValue};
/// # use failure::Error;
/// # use tvm_frontend::function::Builder;
///
/// register_global_func! {
/// fn sum(args: &[TVMArgValue]) -> Result<TVMRetValue, Error> {
......@@ -363,13 +372,13 @@ pub fn register<S: AsRef<str>>(
/// let arg: f64 = arg.try_into()?;
/// ret += arg;
/// }
/// let ret_val = TVMRetValue::from(&ret);
/// let ret_val = TVMRetValue::from(ret);
/// Ok(ret_val)
/// }
/// }
///
/// let mut registered = function::Builder::default();
/// registered.get_function("sum", true);
/// let mut registered = Builder::default();
/// registered.get_function("sum");
/// assert!(registered.func.is_some());
/// let ret: f64 = registered.args(&[10f64, 20f64, 30f64]).invoke().unwrap().try_into().unwrap();
/// assert_eq!(ret, 60f64);
......@@ -404,15 +413,14 @@ macro_rules! register_global_func {
///
/// Instead of
///
/// ```
/// function::Builder::from(func).arg(&a).arg(&b).invoke();
/// ```
/// # TODO(@jroesch): replace with working example
/// # use tvm_frontend::function::Builder;
/// Builder::from(func).arg(&a).arg(&b).invoke();
///
/// one can use
///
/// ```
/// # use tvm_frontend::call_packed;
/// call_packed!(func, &a, &b);
/// ```
#[macro_export]
macro_rules! call_packed {
($fn_name:expr, $($arg:expr),*) => {{
......@@ -428,12 +436,12 @@ macro_rules! call_packed {
mod tests {
use super::*;
static CANARY: &str = "module._LoadFromFile";
static CANARY: &str = "runtime.ModuleLoadFromFile";
#[test]
fn list_global_func() {
assert!(GLOBAL_FUNCTIONS.lock().unwrap().contains_key(CANARY));
}
// #[test]
// fn list_global_func() {
// assert!(GLOBAL_FUNCTIONS.lock().unwrap().contains_key(CANARY));
// }
#[test]
fn get_fn() {
......
......@@ -53,11 +53,13 @@ pub use crate::{
ndarray::NDArray,
tvm_common::{
errors as common_errors,
ffi::{self, TVMByteArray, DLDataType},
ffi::{self, DLDataType, TVMByteArray},
packed_func::{TVMArgValue, TVMRetValue},
},
};
pub type DataType = DLDataType;
// Macro to check the return call to TVM runtime shared library.
macro_rules! check_call {
($e:expr) => {{
......
......@@ -94,7 +94,7 @@ impl Module {
format_err!("Bad module load path: `{}`.", path.as_ref().display())
})?,
)?;
let func = Function::get("module._LoadFromFile").expect("API function always exists");
let func = Function::get("runtime.ModuleLoadFromFile").expect("API function always exists");
let cpath =
CString::new(path.as_ref().to_str().ok_or_else(|| {
format_err!("Bad module load path: `{}`.", path.as_ref().display())
......@@ -105,7 +105,7 @@ impl Module {
/// Checks if a target device is enabled for a module.
pub fn enabled(&self, target: &str) -> bool {
let func = Function::get("module._Enabled").expect("API function always exists");
let func = Function::get("runtime.RuntimeEnabled").expect("API function always exists");
// `unwrap` is safe here because if there is any error during the
// function call, it would occur in `call_packed!`.
let tgt = CString::new(target).unwrap();
......
......@@ -29,11 +29,16 @@
//! # Example
//!
//! ```
//! # use tvm_frontend::{NDArray, TVMContext, DataType};
//! # use ndarray::{Array, ArrayD};
//! # use std::str::FromStr;
//! use std::convert::TryFrom;
//!
//! let a = Array::from_shape_vec((2, 2), vec![1f32, 2., 3., 4.])
//! .unwrap()
//! .into_dyn(); // Rust's ndarray
//! let nd = NDArray::from_rust_ndarray(&a, TVMContext::cpu(0), TVMType::from("float32")).unwrap();
//! assert_eq!(nd.shape(), Some(&mut [2, 2]));
//! let nd = NDArray::from_rust_ndarray(&a, TVMContext::cpu(0), DataType::from_str("float32").unwrap()).unwrap();
//! assert_eq!(nd.shape(), Some(&mut [2, 2][..]));
//! let rnd: ArrayD<f32> = ArrayD::try_from(&nd).unwrap();
//! assert!(rnd.all_close(&a, 1e-8f32));
//! ```
......@@ -47,6 +52,9 @@ use std::{convert::TryFrom, mem, os::raw::c_int, ptr, slice, str::FromStr};
use failure::Error;
use num_traits::Num;
use rust_ndarray::{Array, ArrayD};
use std::convert::TryInto;
use std::ffi::c_void;
use tvm_common::ffi::DLTensor;
use tvm_common::{ffi, TVMType};
use crate::{errors, TVMByteArray, TVMContext};
......@@ -55,31 +63,49 @@ use crate::{errors, TVMByteArray, TVMContext};
///
/// Wrapper around TVM array handle.
#[derive(Debug)]
pub struct NDArray {
pub(crate) handle: ffi::TVMArrayHandle,
is_view: bool,
pub enum NDArray {
Borrowed { handle: ffi::TVMArrayHandle },
Owned { handle: *mut c_void },
}
impl NDArray {
pub(crate) fn new(handle: ffi::TVMArrayHandle) -> Self {
NDArray {
handle,
is_view: true,
NDArray::Borrowed { handle }
}
pub(crate) fn from_ndarray_handle(handle: *mut c_void) -> Self {
NDArray::Owned { handle }
}
/// Returns the underlying array handle.
pub fn handle(&self) -> ffi::TVMArrayHandle {
self.handle
pub fn as_dltensor(&self) -> &DLTensor {
unsafe {
match self {
NDArray::Borrowed { ref handle } => std::mem::transmute(*handle),
NDArray::Owned { ref handle } => std::mem::transmute(*handle),
}
}
}
pub(crate) fn as_raw_dltensor(&self) -> *mut DLTensor {
unsafe {
match self {
NDArray::Borrowed { ref handle } => std::mem::transmute(*handle),
NDArray::Owned { ref handle } => std::mem::transmute(*handle),
}
}
}
pub fn is_view(&self) -> bool {
self.is_view
if let &NDArray::Borrowed { .. } = self {
true
} else {
false
}
}
/// Returns the shape of the NDArray.
pub fn shape(&self) -> Option<&mut [usize]> {
let arr = unsafe { *(self.handle) };
let arr = self.as_dltensor();
if arr.shape.is_null() || arr.data.is_null() {
return None;
};
......@@ -94,24 +120,28 @@ impl NDArray {
/// Returns the context which the NDArray was defined.
pub fn ctx(&self) -> TVMContext {
unsafe { (*self.handle).ctx.into() }
self.as_dltensor().ctx.into()
}
/// Returns the type of the entries of the NDArray.
pub fn dtype(&self) -> TVMType {
unsafe { (*self.handle).dtype }
self.as_dltensor().dtype
}
/// Returns the number of dimensions of the NDArray.
pub fn ndim(&self) -> usize {
unsafe { (*self.handle).ndim as usize }
self.as_dltensor()
.ndim
.try_into()
.expect("number of dimensions must always be positive")
}
/// Returns the strides of the underlying NDArray.
pub fn strides(&self) -> Option<&[usize]> {
unsafe {
let sz = self.ndim() * mem::size_of::<usize>();
let slc = slice::from_raw_parts((*self.handle).strides as *const usize, sz);
let strides_ptr = self.as_dltensor().strides as *const usize;
let slc = slice::from_raw_parts(strides_ptr, sz);
Some(slc)
}
}
......@@ -141,7 +171,7 @@ impl NDArray {
}
pub fn byte_offset(&self) -> isize {
unsafe { (*self.handle).byte_offset as isize }
self.as_dltensor().byte_offset as isize
}
/// Flattens the NDArray to a `Vec` of the same type in cpu.
......@@ -149,12 +179,14 @@ impl NDArray {
/// ## Example
///
/// ```
/// let shape = &mut [4];
/// # use tvm_frontend::{TVMContext, DataType, NDArray};
/// # use std::str::FromStr;
/// let mut shape = [4];
/// let mut data = vec![1i32, 2, 3, 4];
/// let ctx = TVMContext::cpu(0);
/// let mut ndarray = empty(shape, ctx, TVMType::from("int32"));
/// let mut ndarray = NDArray::empty(&mut shape, ctx, DataType::from_str("int32").unwrap());
/// ndarray.copy_from_buffer(&mut data);
/// assert_eq!(ndarray.shape(), Some(shape));
/// assert_eq!(ndarray.shape(), Some(&mut shape[..]));
/// assert_eq!(ndarray.to_vec::<i32>().unwrap(), data);
/// ```
pub fn to_vec<T>(&self) -> Result<Vec<T>, Error> {
......@@ -165,7 +197,7 @@ impl NDArray {
self.dtype(),
);
let target = self.copy_to_ndarray(earr)?;
let arr = unsafe { *(target.handle) };
let arr = target.as_dltensor();
let sz = self.size().ok_or(errors::MissingShapeError)?;
let mut v: Vec<T> = Vec::with_capacity(sz * mem::size_of::<T>());
unsafe {
......@@ -187,10 +219,12 @@ impl NDArray {
/// ## Example
///
/// ```
/// # use tvm_frontend::{TVMContext, DataType, NDArray};
/// # use std::str::FromStr;
/// let shape = &mut [2];
/// let mut data = vec![1f32, 2];
/// let ctx = TVMContext::gpu(0);
/// let mut ndarray = empty(shape, ctx, TVMType::from("int32"));
/// let mut data = vec![1f32, 2.0];
/// let ctx = TVMContext::cpu(0);
/// let mut ndarray = NDArray::empty(shape, ctx, DataType::from_str("int32").unwrap());
/// ndarray.copy_from_buffer(&mut data);
/// ```
///
......@@ -198,7 +232,7 @@ impl NDArray {
/// from TVM side. See `TVMArrayCopyFromBytes` in `include/tvm/runtime/c_runtime_api.h`.
pub fn copy_from_buffer<T: Num32>(&mut self, data: &mut [T]) {
check_call!(ffi::TVMArrayCopyFromBytes(
self.handle,
self.as_raw_dltensor(),
data.as_ptr() as *mut _,
data.len() * mem::size_of::<T>()
));
......@@ -216,8 +250,8 @@ impl NDArray {
);
}
check_call!(ffi::TVMArrayCopyFromTo(
self.handle,
target.handle,
self.as_raw_dltensor(),
target.as_raw_dltensor(),
ptr::null_mut() as ffi::TVMStreamHandle
));
Ok(target)
......@@ -263,10 +297,7 @@ impl NDArray {
ctx.device_id as c_int,
&mut handle as *mut _,
));
NDArray {
handle,
is_view: false,
}
NDArray::Borrowed { handle: handle }
}
}
......@@ -304,8 +335,8 @@ impl_from_ndarray_rustndarray!(f32, "float");
impl Drop for NDArray {
fn drop(&mut self) {
if !self.is_view {
check_call!(ffi::TVMArrayFree(self.handle));
if let &mut NDArray::Owned { .. } = self {
check_call!(ffi::TVMArrayFree(self.as_raw_dltensor()));
}
}
}
......
......@@ -22,15 +22,15 @@
//! `TVMRetValue` is the owned version of `TVMPODValue`.
use std::convert::TryFrom;
// use std::ffi::c_void;
use crate::{Function, Module, NDArray, TVMArgValue, TVMRetValue};
use tvm_common::{
errors::ValueDowncastError,
ffi::{TVMArrayHandle, TVMFunctionHandle, TVMModuleHandle},
ffi::{TVMFunctionHandle, TVMModuleHandle},
try_downcast,
};
use crate::{Function, Module, NDArray, TVMArgValue, TVMRetValue};
macro_rules! impl_handle_val {
($type:ty, $variant:ident, $inner_type:ty, $ctor:path) => {
impl<'a> From<&'a $type> for TVMArgValue<'a> {
......@@ -76,7 +76,60 @@ macro_rules! impl_handle_val {
impl_handle_val!(Function, FuncHandle, TVMFunctionHandle, Function::new);
impl_handle_val!(Module, ModuleHandle, TVMModuleHandle, Module::new);
impl_handle_val!(NDArray, ArrayHandle, TVMArrayHandle, NDArray::new);
impl<'a> From<&'a NDArray> for TVMArgValue<'a> {
fn from(arg: &'a NDArray) -> Self {
match arg {
&NDArray::Borrowed { handle } => TVMArgValue::ArrayHandle(handle),
&NDArray::Owned { handle } => TVMArgValue::NDArrayHandle(handle),
}
}
}
impl<'a> From<&'a mut NDArray> for TVMArgValue<'a> {
fn from(arg: &'a mut NDArray) -> Self {
match arg {
&mut NDArray::Borrowed { handle } => TVMArgValue::ArrayHandle(handle),
&mut NDArray::Owned { handle } => TVMArgValue::NDArrayHandle(handle),
}
}
}
impl<'a> TryFrom<TVMArgValue<'a>> for NDArray {
type Error = ValueDowncastError;
fn try_from(val: TVMArgValue<'a>) -> Result<NDArray, Self::Error> {
try_downcast!(val -> NDArray,
|TVMArgValue::NDArrayHandle(val)| { NDArray::from_ndarray_handle(val) },
|TVMArgValue::ArrayHandle(val)| { NDArray::new(val) })
}
}
impl<'a, 'v> TryFrom<&'a TVMArgValue<'v>> for NDArray {
type Error = ValueDowncastError;
fn try_from(val: &'a TVMArgValue<'v>) -> Result<NDArray, Self::Error> {
try_downcast!(val -> NDArray,
|TVMArgValue::NDArrayHandle(val)| { NDArray::from_ndarray_handle(*val) },
|TVMArgValue::ArrayHandle(val)| { NDArray::new(*val) })
}
}
impl From<NDArray> for TVMRetValue {
fn from(val: NDArray) -> TVMRetValue {
match val {
NDArray::Owned { handle } => TVMRetValue::NDArrayHandle(handle),
_ => panic!("NYI"),
}
}
}
impl TryFrom<TVMRetValue> for NDArray {
type Error = ValueDowncastError;
fn try_from(val: TVMRetValue) -> Result<NDArray, Self::Error> {
try_downcast!(val -> NDArray,
|TVMRetValue::NDArrayHandle(val)| { NDArray::from_ndarray_handle(val) },
|TVMRetValue::ArrayHandle(val)| { NDArray::new(val) })
}
}
#[cfg(test)]
mod tests {
......
......@@ -68,5 +68,5 @@ fn main() {
.unwrap()
.try_into()
.unwrap();
assert_eq!(ret, 14f32);
assert_eq!(ret, 7f32);
}
......@@ -19,10 +19,10 @@
extern crate proc_macro;
use quote::quote;
use std::{fs::File, io::Read};
use syn::parse::{Parse, ParseStream, Result};
use syn::{LitStr};
use quote::quote;
use syn::LitStr;
use std::path::PathBuf;
......@@ -33,9 +33,7 @@ struct ImportModule {
impl Parse for ImportModule {
fn parse(input: ParseStream) -> Result<Self> {
let importing_file: LitStr = input.parse()?;
Ok(ImportModule {
importing_file,
})
Ok(ImportModule { importing_file })
}
}
......@@ -43,8 +41,8 @@ impl Parse for ImportModule {
pub fn import_module(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
let import_module_args = syn::parse_macro_input!(input as ImportModule);
let manifest = std::env::var("CARGO_MANIFEST_DIR")
.expect("variable should always be set by Cargo.");
let manifest =
std::env::var("CARGO_MANIFEST_DIR").expect("variable should always be set by Cargo.");
let mut path = PathBuf::new();
path.push(manifest);
......
......@@ -42,7 +42,8 @@ impl Module for SystemLibModule {
SYSTEM_LIB_FUNCTIONS
.lock()
.unwrap()
.get(name.as_ref()).copied()
.get(name.as_ref())
.copied()
}
}
......
......@@ -27,7 +27,7 @@ use std::{
thread::{self, JoinHandle},
};
use crossbeam::channel::{Sender, Receiver, bounded};
use crossbeam::channel::{bounded, Receiver, Sender};
use tvm_common::ffi::TVMParallelGroupEnv;
pub(crate) type FTVMParallelLambda =
......@@ -138,8 +138,7 @@ impl ThreadPool {
let mut tasks = job.tasks(self.num_workers + 1);
for (i, task) in tasks.split_off(1).into_iter().enumerate() {
self.threads.queues[i].send(task)
.expect("should send");
self.threads.queues[i].send(task).expect("should send");
}
tasks.pop().unwrap().run();
......
......@@ -31,7 +31,7 @@ CWD = osp.dirname(osp.abspath(osp.expanduser(__file__)))
def _get_model(dshape):
data = relay.var('data', shape=dshape)
fc = relay.nn.dense(data, relay.var("dense_weight"), units=dshape[-1]*2)
fc = relay.nn.bias_add(data, relay.var("dense_bias"))
fc = relay.nn.bias_add(fc, relay.var("dense_bias"))
left, right = relay.split(fc, indices_or_sections=2, axis=1)
one = relay.const(1, dtype="float32")
return relay.Tuple([(left + one), (right - one), fc])
......
......@@ -75,9 +75,9 @@ fn test_load_graph() {
.unwrap()
.get("func_name")
.unwrap(),
"fuse_dense"
"fused_nn_dense_nn_bias_add"
);
assert_eq!(graph.nodes[5].inputs[0].index, 0);
assert_eq!(graph.nodes[6].inputs[0].index, 1);
assert_eq!(graph.heads.len(), 2);
assert_eq!(graph.nodes[3].inputs[0].index, 0);
assert_eq!(graph.nodes[4].inputs[0].index, 0);
assert_eq!(graph.heads.len(), 3);
}
......@@ -25,16 +25,24 @@ use ar::Builder;
fn main() {
let out_dir = env::var("OUT_DIR").unwrap();
let out_dir = Path::new(&out_dir).join("test_nn");
let output = Command::new(concat!(
env!("CARGO_MANIFEST_DIR"),
"/src/build_test_graph.py"
))
std::fs::create_dir_all(&out_dir).unwrap();
let manifest_dir = env::var("CARGO_MANIFEST_DIR").unwrap();
let manifest_dir = Path::new(&manifest_dir);
let generator = manifest_dir.join("src").join("build_test_graph.py");
let graph_path = out_dir.join("graph.o");
let output = Command::new(&generator)
.arg(&out_dir)
.output()
.expect("Failed to execute command");
assert!(
Path::new(&format!("{}/graph.o", out_dir)).exists(),
graph_path.exists(),
"Could not build graph lib: {}",
String::from_utf8(output.stderr)
.unwrap()
......@@ -44,10 +52,10 @@ fn main() {
.unwrap_or("")
);
let lib_file = format!("{}/libtestnn.a", out_dir);
let lib_file = out_dir.join("libtestnn.a");
let file = File::create(&lib_file).unwrap();
let mut builder = Builder::new(file);
builder.append_path(format!("{}/graph.o", out_dir)).unwrap();
builder.append_path(graph_path).unwrap();
let status = Command::new("ranlib")
.arg(&lib_file)
......@@ -56,7 +64,7 @@ fn main() {
assert!(status.success());
println!("cargo:rustc-link-lib=static=testnn");
println!("cargo:rustc-link-search=native={}", out_dir);
println!("cargo:rustc-link-search=native={}", out_dir.display());
println!("cargo:rerun-if-changed={}", generator.display());
}
......@@ -31,7 +31,7 @@ from tvm.relay import testing
def _get_model(dshape):
data = relay.var('data', shape=dshape)
fc = relay.nn.dense(data, relay.var("dense_weight"), units=dshape[-1]*2)
fc = relay.nn.bias_add(data, relay.var("dense_bias"))
fc = relay.nn.bias_add(fc, relay.var("dense_bias"))
left, right = relay.split(fc, indices_or_sections=2, axis=1)
one = relay.const(1, dtype="float32")
return relay.Tuple([(left + one), (right - one), fc])
......
......@@ -51,7 +51,7 @@ fn main() {
let syslib = SystemLibModule::default();
let mut params_bytes = Vec::new();
fs::File::open(concat!(env!("OUT_DIR"), "/graph.params"))
fs::File::open(concat!(env!("OUT_DIR"), "/test_nn/graph.params"))
.unwrap()
.read_to_end(&mut params_bytes)
.unwrap();
......@@ -61,8 +61,9 @@ fn main() {
.map(|(k, v)| (k, v.to_owned()))
.collect::<HashMap<String, Tensor<'static>>>();
let graph =
Graph::try_from(&fs::read_to_string(concat!(env!("OUT_DIR"), "/graph.json")).unwrap())
let graph = Graph::try_from(
&fs::read_to_string(concat!(env!("OUT_DIR"), "/test_nn/graph.json")).unwrap(),
)
.unwrap();
let mut exec = GraphExecutor::new(graph, &syslib).unwrap();
......@@ -73,11 +74,16 @@ fn main() {
.collect::<Vec<f32>>(),
)
.unwrap();
let w = Array::try_from(params.get("dense0_weight").unwrap().to_owned())
let p0 = params.get("p0").unwrap().to_owned();
let p1 = params.get("p1").unwrap().to_owned();
println!("p0: {:?}", p0.shape());
println!("p1: {:?}", p1.shape());
let w = Array::try_from(p0)
.unwrap()
.into_shape((IN_DIM * 2, IN_DIM))
.into_shape((BATCH_SIZE * 4, IN_DIM))
.unwrap();
let b = Array::try_from(params.get("dense0_bias").unwrap().to_owned()).unwrap();
let b = Array::try_from(p1).unwrap();
let dense = x.dot(&w.t()) + &b;
let left = dense.slice(s![.., 0..IN_DIM]);
let right = dense.slice(s![.., IN_DIM..]);
......@@ -88,8 +94,8 @@ fn main() {
exec.set_input("data", (&x).into());
check_sum!(exec, data, x);
check_sum!(exec, dense0_weight, w);
check_sum!(exec, dense0_bias, b);
check_sum!(exec, p0, w);
check_sum!(exec, p1, b);
exec.run();
......
......@@ -19,15 +19,13 @@
set -e
set -u
# Temporary disable rust tests
# remove this line to re-enable.
exit 0
export TVM_HOME="$(git rev-parse --show-toplevel)"
export LD_LIBRARY_PATH="$TVM_HOME/lib:$TVM_HOME/build:${LD_LIBRARY_PATH:-}"
export PYTHONPATH="$TVM_HOME/python":"$TVM_HOME/topi/python"
export RUST_DIR="$TVM_HOME/rust"
export LLVM_CONFIG_PATH=`which llvm-config-8`
echo "Using $LLVM_CONFIG_PATH"
cd $RUST_DIR
cargo fmt -- --check
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment