Commit 5563b72b by nhynes Committed by Tianqi Chen

Add rust runtime (#1597)

parent 6330797d
......@@ -91,10 +91,8 @@ ENV/
*~
*.pyc
*~
build
config.mk
config.cmake
build_*
Win32
*.dir
perf
......@@ -187,7 +185,6 @@ tvm_u.*
tvm_t.*
# Mac OS X
.DS_Store
build*
# Jetbrain
.idea
......
Cargo.lock
target/
**/*.rs.bk
max_width = 100
hard_tabs = false
tab_spaces = 2
newline_style = "Auto"
use_small_heuristics = "Default"
indent_style = "Block"
wrap_comments = false
comment_width = 80
normalize_comments = false
format_strings = false
format_macro_matchers = false
format_macro_bodies = true
empty_item_single_line = true
struct_lit_single_line = true
fn_single_line = false
where_single_line = false
imports_indent = "Block"
imports_layout = "Mixed"
merge_imports = true
reorder_imports = true
reorder_modules = true
reorder_impl_items = false
type_punctuation_density = "Wide"
space_before_colon = false
space_after_colon = true
spaces_around_ranges = false
binop_separator = "Front"
remove_nested_parens = true
combine_control_expr = true
struct_field_align_threshold = 0
match_arm_blocks = true
force_multiline_blocks = false
fn_args_density = "Tall"
brace_style = "SameLineWhere"
control_brace_style = "AlwaysSameLine"
trailing_semicolon = true
trailing_comma = "Vertical"
match_block_trailing_comma = false
blank_lines_upper_bound = 1
blank_lines_lower_bound = 0
edition = "Edition2015"
merge_derives = true
use_try_shorthand = true
use_field_init_shorthand = false
force_explicit_abi = true
condense_wildcard_suffixes = false
color = "Auto"
required_version = "0.99.4"
unstable_features = false
disable_all_formatting = false
skip_children = false
hide_parse_errors = false
error_on_line_overflow = false
error_on_unformatted = false
report_todo = "Never"
report_fixme = "Never"
ignore = []
emit_mode = "Files"
make_backup = false
language: rust
rust:
- nightly
matrix:
fast_finish: true
[package]
name = "tvm"
version = "0.1.0"
license = "Apache-2.0"
description = "TVM Rust runtime"
repository = "https://github.com/dmlc/tvm"
readme = "README.md"
keywords = ["tvm", "nnvm"]
categories = ["api-bindings", "science"]
authors = ["Nick Hynes <nhynes@berkeley.edu>"]
[features]
default = ["nom/std"]
sgx = ["nom/alloc"]
[dependencies]
bounded-spsc-queue = "0.4.0"
error-chain = { version = "0.12.0", default-features = false }
itertools = "0.7.8"
lazy_static = "1.1.0"
ndarray = "0.11.2"
nom = { version = "4.0.0", default-features = false }
serde = "1.0.59"
serde_derive = "1.0.79"
serde_json = "1.0.17"
[target.'cfg(not(target_env = "sgx"))'.dependencies]
num_cpus = "1.8.0"
#[cfg(target_env = "sgx")]
use alloc::alloc;
#[cfg(not(target_env = "sgx"))]
use std::alloc;
use std::num;
use ndarray;
use serde_json;
// Error types for the TVM Rust runtime, generated by the `error_chain!` macro.
error_chain! {
  errors {
    // A `TVMRetValue` was downcast to a Rust type that does not match the
    // stored type code.
    TryFromTVMRetValueError(expected: String, actual: i64) {
      description("mismatched types while downcasting TVMRetValue")
      display("invalid downcast: expected `{}` but was `{}`", expected, actual)
    }
    // The graph JSON could not be parsed into the runtime's graph structure.
    GraphFormatError(msg: String) {
      description("unable to load graph")
      display("could not load graph json: {}", msg)
    }
    // The serialized parameter blob for a graph could not be decoded.
    LoadGraphParamsError(msg: String) {
      description("unable to load graph params")
      display("could not load graph params: {}", msg)
    }
  }
  // External error types that convert into `Error` automatically via `?`.
  foreign_links {
    Alloc(alloc::AllocErr);
    GraphDeserialize(serde_json::Error);
    ParseInt(num::ParseIntError);
    ShapeError(ndarray::ShapeError);
  }
}
impl From<alloc::LayoutErr> for Error {
fn from(_err: alloc::LayoutErr) -> Error {
Error::from_kind(ErrorKind::Msg("Layout error".to_string()))
}
}
//! This crate is an implementation of the TVM runtime for modules compiled with `--system-lib`.
//! It's mainly useful for compiling to WebAssembly and SGX,
//! but also native if you prefer Rust to C++.
//!
//! For TVM graphs, the entrypoint to this crate is `runtime::GraphExecutor`.
//! Single-function modules are used via the `packed_func!` macro after obtaining
//! the function from `runtime::SystemLibModule`
//!
//! The main entrypoints to this crate are `GraphExecutor` and `SystemLibModule`.
//! For examples of use, please refer to the multi-file tests in the `tests` directory.
#![feature(
alloc,
allocator_api,
box_syntax,
extern_prelude,
fn_traits,
try_from,
unboxed_closures,
vec_remove_item
)]
#[cfg(target_env = "sgx")]
extern crate alloc;
extern crate bounded_spsc_queue;
#[cfg(target_env = "sgx")]
extern crate core;
#[macro_use]
extern crate error_chain;
#[macro_use]
extern crate itertools;
#[macro_use]
extern crate lazy_static;
extern crate ndarray;
#[macro_use]
extern crate nom;
#[cfg(not(target_env = "sgx"))]
extern crate num_cpus;
extern crate serde;
#[macro_use]
extern crate serde_derive;
extern crate serde_json;
pub mod ffi {
  //! Raw bindings to the TVM C runtime API, produced by rust-bindgen and
  //! included from `src/runtime/c_runtime_api.rs`.
  #![allow(
    non_camel_case_types,
    non_snake_case,
    non_upper_case_globals,
    unused
  )]
  pub mod runtime {
    use std::os::raw::{c_char, c_int, c_void};
    // Pull in the pre-generated bindings verbatim.
    include!(concat!(
      env!("CARGO_MANIFEST_DIR"),
      "/src/runtime/c_runtime_api.rs"
    ));
    // Signature of a `--system-lib` packed function as registered by the
    // TVM compiler backend; returns 0 on success per the C API convention.
    pub type BackendPackedCFunc =
      extern "C" fn(args: *const TVMValue, type_codes: *const c_int, num_args: c_int) -> c_int;
  }
}
pub mod errors;
pub mod runtime;
pub use errors::*;
#[cfg(target_env = "sgx")]
use alloc::alloc::{self, Layout};
#[cfg(not(target_env = "sgx"))]
use std::alloc::{self, Layout};
use errors::*;
/// Default alignment (in bytes) used when `Allocation::new` is called
/// without an explicit alignment.
const DEFAULT_ALIGN_BYTES: usize = 4;
/// An owned, aligned region of heap memory, freed on drop.
#[derive(PartialEq, Eq)]
pub struct Allocation {
  /// The layout (size + alignment) used to allocate, and later free, `ptr`.
  layout: Layout,
  /// Pointer to the start of the allocation; non-null for a live `Allocation`.
  ptr: *mut u8,
}
impl Allocation {
  /// Allocates a chunk of memory of `size` bytes with optional alignment.
  ///
  /// When `align` is `None`, `DEFAULT_ALIGN_BYTES` is used. Returns an error
  /// if `size` and the alignment do not form a valid `Layout`; aborts via
  /// `handle_alloc_error` if the allocator itself fails.
  pub fn new(size: usize, align: Option<usize>) -> Result<Self> {
    let alignment = align.unwrap_or(DEFAULT_ALIGN_BYTES);
    let layout = Layout::from_size_align(size, alignment)?;
    // SAFETY: `layout` was validated by `from_size_align` above.
    // `Layout` is `Copy`, so no clone is needed to pass it by value.
    let ptr = unsafe { alloc::alloc(layout) };
    if ptr.is_null() {
      alloc::handle_alloc_error(layout);
    }
    Ok(Self {
      ptr: ptr,
      layout: layout,
    })
  }

  /// Returns a raw pointer to the start of the allocation.
  pub fn as_mut_ptr(&self) -> *mut u8 {
    self.ptr
  }

  /// Returns the size of the Allocation in bytes.
  pub fn size(&self) -> usize {
    self.layout.size()
  }

  /// Returns the byte alignment of the Allocation.
  pub fn align(&self) -> usize {
    self.layout.align()
  }
}
impl Drop for Allocation {
  /// Frees the memory using the same `Layout` it was allocated with.
  fn drop(&mut self) {
    // SAFETY: `self.ptr` was returned by `alloc::alloc(self.layout)` in
    // `Allocation::new` and is freed exactly once, here. `Layout` is `Copy`,
    // so no clone is needed.
    unsafe {
      alloc::dealloc(self.ptr, self.layout);
    }
  }
}
use std::{
any::TypeId,
convert::TryFrom,
mem,
os::raw::{c_int, c_void},
ptr, slice,
};
use ndarray;
use super::allocator::Allocation;
use errors::*;
use ffi::runtime::{
DLContext, DLDataType, DLDataTypeCode_kDLFloat, DLDataTypeCode_kDLInt, DLDataTypeCode_kDLUInt,
DLDeviceType_kDLCPU, DLTensor,
};
/// A `Storage` is a container which holds `Tensor` data.
#[derive(PartialEq)]
pub enum Storage<'a> {
  /// A `Storage` which owns its contained bytes (backed by an `Allocation`).
  Owned(Allocation),
  /// A view of an existing `Storage`: the borrowed bytes plus the byte
  /// alignment of the region they came from.
  View(&'a mut [u8], usize), // ptr, align
}
impl<'a> Storage<'a> {
  /// Allocates a new owned `Storage` of `size` bytes; see `Allocation::new`
  /// for the alignment default and error conditions.
  pub fn new(size: usize, align: Option<usize>) -> Result<Storage<'static>> {
    Ok(Storage::Owned(Allocation::new(size, align)?))
  }

  /// Returns a mutable pointer to the start of the backing bytes.
  ///
  /// NOTE(review): this hands out a `*mut u8` from `&self` (for a `View`, by
  /// casting away const-ness); callers are responsible for avoiding aliased
  /// writes — confirm the intended contract.
  pub fn as_mut_ptr(&self) -> *mut u8 {
    match self {
      Storage::Owned(alloc) => alloc.as_mut_ptr(),
      Storage::View(slice, _) => slice.as_ptr() as *mut u8,
    }
  }

  /// Returns the size of the backing region in bytes.
  pub fn size(&self) -> usize {
    match self {
      Storage::Owned(alloc) => alloc.size(),
      Storage::View(slice, _) => slice.len(),
    }
  }

  /// Returns the byte alignment of the backing region.
  pub fn align(&self) -> usize {
    match self {
      Storage::Owned(alloc) => alloc.align(),
      Storage::View(_, align) => *align,
    }
  }

  /// Returns a const pointer to the start of the backing bytes.
  pub fn as_ptr(&self) -> *const u8 {
    self.as_mut_ptr() as *const _
  }

  /// Returns a `Storage::View` which points to an owned `Storage::Owned`.
  ///
  /// The view covers the entire region and carries the same alignment.
  /// NOTE(review): the returned view aliases `self`'s bytes mutably while
  /// only `&self` is held — same caveat as `as_mut_ptr`.
  pub fn view(&self) -> Storage<'a> {
    match self {
      Storage::Owned(alloc) => Storage::View(
        unsafe { slice::from_raw_parts_mut(alloc.as_mut_ptr(), self.size()) },
        self.align(),
      ),
      Storage::View(slice, _) => Storage::View(
        unsafe { slice::from_raw_parts_mut(self.as_mut_ptr(), slice.len()) },
        self.align(),
      ),
    }
  }

  /// Returns `true` iff this `Storage` owns its bytes.
  pub fn is_owned(&self) -> bool {
    match self {
      Storage::Owned(_) => true,
      _ => false,
    }
  }

  /// Returns an owned version of this storage via cloning.
  ///
  /// Allocates a fresh region with the same size and alignment and copies
  /// the bytes into it. Panics if the allocation fails.
  pub fn to_owned(&self) -> Storage<'static> {
    let s = Storage::new(self.size(), Some(self.align())).unwrap();
    unsafe {
      // Source and destination are distinct allocations, so the
      // non-overlapping copy is valid.
      s.as_mut_ptr()
        .copy_from_nonoverlapping(self.as_ptr(), self.size())
    }
    s
  }
}
impl<'a, T> From<&'a [T]> for Storage<'a> {
  /// Wraps a typed slice as a borrowed, byte-oriented `Storage::View`.
  fn from(data: &'a [T]) -> Self {
    // Reinterpret the `&[T]` as a mutable byte slice over the same region.
    // NOTE(review): this casts away const-ness; writing through the resulting
    // view while the original `&[T]` is live would be undefined behavior —
    // confirm callers treat such views as read-only.
    let data = unsafe {
      slice::from_raw_parts_mut(
        data.as_ptr() as *const u8 as *mut u8,
        data.len() * mem::size_of::<T>() as usize,
      )
    };
    Storage::View(data, mem::align_of::<T>())
  }
}
/// A n-dimensional array type which can be converted to/from `tvm::DLTensor` and `ndarray::Array`.
/// `Tensor` is primarily a holder of data which can be operated on via TVM (via `DLTensor`) or
/// converted to `ndarray::Array` for non-TVM processing.
///
/// # Examples
///
/// ```
/// extern crate ndarray;
///
/// let mut a_nd = ndarray::Array::from_vec(vec![1f32, 2., 3., 4.]);
/// let mut a: Tensor = a_nd.into();
/// let mut a_dl: DLTensor = (&mut a).into();
/// call_packed!(tvm_fn, &mut a_dl);
///
/// // Array -> Tensor is mostly useful when post-processing TVM graph outputs.
/// let mut a_nd = ndarray::Array::try_from(&a).unwrap();
/// ```
#[derive(PartialEq)]
pub struct Tensor<'a> {
  /// The bytes which contain the data this `Tensor` represents.
  pub(super) data: Storage<'a>,
  /// The device on which this `Tensor` resides.
  pub(super) ctx: TVMContext,
  /// The type of the elements stored in `data`.
  pub(super) dtype: DataType,
  pub(super) shape: Vec<i64>, // not usize because `typedef int64_t tvm_index_t` in c_runtime_api.h
  /// The `Tensor` strides. Can be `None` if the `Tensor` is contiguous.
  pub(super) strides: Option<Vec<usize>>,
  /// Offset in bytes from the start of `data` to the first element.
  pub(super) byte_offset: isize,
  /// Total number of elements (product of `shape`).
  pub(super) size: usize,
}
// NOTE(review): `Tensor` may hold a raw `Storage::View` into shared bytes;
// confirm cross-thread use is actually sound before relying on this `Send`.
unsafe impl<'a> Send for Tensor<'a> {}
impl<'a> Tensor<'a> {
  /// Returns a copy of this `Tensor`'s shape.
  pub fn shape(&self) -> Vec<i64> {
    self.shape.clone()
  }

  /// Returns the data of this `Tensor` as a `Vec`.
  ///
  /// # Panics
  ///
  /// Panics if the `Tensor` is not contiguous or does not contain elements of type `T`.
  pub fn to_vec<T: 'static>(&self) -> Vec<T> {
    assert!(self.is_contiguous());
    assert!(self.dtype.is_type::<T>());
    // `with_capacity` takes a count of elements, not bytes: reserve exactly
    // `size` elements (the previous `size * itemsize` over-allocated).
    let mut vec: Vec<T> = Vec::with_capacity(self.size);
    unsafe {
      // SAFETY: contiguity and element type were asserted above, and the
      // destination has capacity for `size` elements.
      vec.as_mut_ptr().copy_from_nonoverlapping(
        self.data.as_ptr().offset(self.byte_offset) as *const T,
        self.size,
      );
      vec.set_len(self.size);
    }
    vec
  }

  /// Returns `true` iff this `Tensor` is represented by a contiguous region of memory.
  pub fn is_contiguous(&self) -> bool {
    match self.strides {
      None => true,
      Some(ref strides) => {
        // check that stride for each dimension is the product of all trailing dimensons' shapes
        self
          .shape
          .iter()
          .zip(strides)
          .rfold(
            (true, 1),
            |(is_contig, expected_stride), (shape, stride)| {
              (
                is_contig && *stride == expected_stride,
                expected_stride * (*shape as usize),
              )
            },
          ).0
      }
    }
  }

  /// Copies the contents of `other` into this `Tensor`.
  ///
  /// # Panics
  ///
  /// Panics if the dtypes or sizes differ, or if either `Tensor` is not contiguous.
  pub fn copy(&mut self, other: &Tensor) {
    assert!(
      self.dtype == other.dtype && self.size == other.size,
      "Tensor shape/dtype mismatch."
    );
    assert!(
      self.is_contiguous() && other.is_contiguous(),
      "copy currently requires contiguous tensors\n`self.strides = {:?}` `other.strides = {:?}`",
      self.strides,
      other.strides
    );
    unsafe {
      // SAFETY: sizes match and both regions were asserted contiguous;
      // `byte_offset` is already an `isize`, so no cast is needed.
      self
        .data
        .as_mut_ptr()
        .offset(self.byte_offset)
        .copy_from_nonoverlapping(
          other.data.as_mut_ptr().offset(other.byte_offset),
          other.size * other.dtype.itemsize(),
        );
    }
  }

  /// Returns an owned version of this `Tensor` via cloning.
  pub fn to_owned(&self) -> Tensor<'static> {
    let t = Tensor {
      data: self.data.to_owned(),
      ctx: self.ctx.clone(),
      dtype: self.dtype.clone(),
      size: self.size, // `usize` is `Copy`; no clone needed
      shape: self.shape.clone(),
      strides: None,
      byte_offset: 0,
    };
    // SAFETY: `data` is a `Storage::Owned` produced by `to_owned()` above, so
    // the tensor no longer borrows from `'a` and may be given `'static`.
    unsafe { mem::transmute::<Tensor<'a>, Tensor<'static>>(t) }
  }

  /// Builds a `Tensor` over `storage`, taking shape/strides/length from `arr`
  /// and the dtype from `T`'s size and the given DLPack `type_code`.
  fn from_array_storage<'s, T, D: ndarray::Dimension>(
    arr: &ndarray::Array<T, D>,
    storage: Storage<'s>,
    type_code: usize,
  ) -> Tensor<'s> {
    let type_width = mem::size_of::<T>(); // already a `usize`
    Tensor {
      data: storage,
      ctx: TVMContext::default(),
      dtype: DataType {
        code: type_code,
        bits: 8 * type_width,
        lanes: 1,
      },
      size: arr.len(),
      shape: arr.shape().iter().map(|&v| v as i64).collect(),
      strides: Some(arr.strides().into_iter().map(|&v| v as usize).collect()),
      byte_offset: 0,
    }
  }
}
/// Conversions to `ndarray::Array` from `Tensor`, if the types match.
macro_rules! impl_ndarray_try_from_tensor {
  ($type:ty, $dtype:expr) => {
    impl<'a, 't> TryFrom<&'a Tensor<'t>> for ndarray::ArrayD<$type> {
      type Error = Error;
      // Fails if the tensor's dtype is not exactly `$dtype`; shape errors
      // from `from_shape_vec` propagate via the `ShapeError` foreign link.
      fn try_from(tensor: &'a Tensor) -> Result<ndarray::ArrayD<$type>> {
        ensure!(
          tensor.dtype == $dtype,
          "Cannot convert Tensor with dtype {:?} to ndarray",
          tensor.dtype
        );
        Ok(ndarray::Array::from_shape_vec(
          tensor
            .shape
            .iter()
            .map(|s| *s as usize)
            .collect::<Vec<usize>>(),
          tensor.to_vec::<$type>(),
        )?)
      }
    }
  };
}
// One impl per supported primitive dtype.
impl_ndarray_try_from_tensor!(i32, DTYPE_INT32);
impl_ndarray_try_from_tensor!(u32, DTYPE_UINT32);
impl_ndarray_try_from_tensor!(f32, DTYPE_FLOAT32);
impl_ndarray_try_from_tensor!(f64, DTYPE_FLOAT64);
impl DLTensor {
  /// Builds a non-owning `DLTensor` view over `tensor`.
  ///
  /// When `flatten` is true, the result is presented as a 1-D tensor of
  /// `tensor.size` elements (requires contiguity).
  ///
  /// NOTE(review): the returned struct holds raw pointers into `tensor`
  /// (including, when flattened, a pointer to the `size` field reinterpreted
  /// as `i64`); it must not outlive the borrow of `tensor`.
  pub(super) fn from_tensor<'a>(tensor: &'a Tensor, flatten: bool) -> Self {
    assert!(!flatten || tensor.is_contiguous());
    Self {
      // `byte_offset` is folded into `data` here, hence `byte_offset: 0` below.
      data: unsafe { tensor.data.as_mut_ptr().offset(tensor.byte_offset) } as *mut c_void,
      ctx: DLContext::from(&tensor.ctx),
      ndim: if flatten { 1 } else { tensor.shape.len() } as i32,
      dtype: DLDataType::from(&tensor.dtype),
      shape: if flatten {
        &tensor.size as *const _ as *mut i64
      } else {
        tensor.shape.as_ptr()
      } as *mut i64,
      // Null strides signal a compact (contiguous) tensor to the C API.
      strides: if flatten || tensor.is_contiguous() {
        ptr::null_mut()
      } else {
        tensor.strides.as_ref().unwrap().as_ptr()
      } as *mut i64,
      byte_offset: 0,
    }
  }
}
impl<'a, 't> From<&'a Tensor<'t>> for DLTensor {
  /// Borrows `tensor` as an unflattened `DLTensor` view.
  fn from(tensor: &'a Tensor<'t>) -> Self {
    let flatten = false;
    DLTensor::from_tensor(tensor, flatten)
  }
}
impl<'a, 't> From<&'a mut Tensor<'t>> for DLTensor {
  /// Borrows `tensor` mutably as an unflattened `DLTensor` view.
  fn from(tensor: &'a mut Tensor<'t>) -> Self {
    let flatten = false;
    DLTensor::from_tensor(tensor, flatten)
  }
}
/// An element type, mirroring the FFI `DLDataType` with wider fields.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct DataType {
  /// Base type code, per `DLDataTypeCode` (0 = int, 1 = uint, 2 = float).
  pub(super) code: usize,
  /// Width of one lane in bits (e.g. 32 for `f32`).
  pub(super) bits: usize,
  /// Number of lanes; > 1 for vector types.
  pub(super) lanes: usize,
}
impl DataType {
  /// Returns the number of bytes occupied by an element of this `DataType`.
  fn itemsize(&self) -> usize {
    (self.bits * self.lanes) >> 3
  }

  /// Returns whether this `DataType` represents primitive type `T`.
  ///
  /// Only scalar (`lanes == 1`) i32/i64/u32/u64/f32/f64 are recognized.
  fn is_type<T: 'static>(&self) -> bool {
    if self.lanes != 1 {
      return false;
    }
    let typ = TypeId::of::<T>();
    // Codes follow `DLDataTypeCode`: 0 = kDLInt, 1 = kDLUInt, 2 = kDLFloat.
    (typ == TypeId::of::<i32>() && self.code == 0 && self.bits == 32)
      || (typ == TypeId::of::<i64>() && self.code == 0 && self.bits == 64)
      || (typ == TypeId::of::<u32>() && self.code == 1 && self.bits == 32)
      || (typ == TypeId::of::<u64>() && self.code == 1 && self.bits == 64)
      || (typ == TypeId::of::<f32>() && self.code == 2 && self.bits == 32)
      || (typ == TypeId::of::<f64>() && self.code == 2 && self.bits == 64)
  }
}
impl<'a> From<&'a DataType> for DLDataType {
fn from(dtype: &'a DataType) -> Self {
Self {
code: dtype.code as u8,
bits: dtype.bits as u8,
lanes: dtype.lanes as u16,
}
}
}
/// Declares a `DataType` constant for a given `DLDataTypeCode`, bit width,
/// and lane count.
macro_rules! make_dtype_const {
  ($name: ident, $code: ident, $bits: expr, $lanes: expr) => {
    const $name: DataType = DataType {
      code: $code as usize,
      bits: $bits,
      lanes: $lanes,
    };
  };
}
make_dtype_const!(DTYPE_INT32, DLDataTypeCode_kDLInt, 32, 1);
make_dtype_const!(DTYPE_UINT32, DLDataTypeCode_kDLUInt, 32, 1);
// f16 has no native Rust primitive, so its constant stays disabled.
// make_dtype_const!(DTYPE_FLOAT16, DLDataTypeCode_kDLFloat, 16, 1);
make_dtype_const!(DTYPE_FLOAT32, DLDataTypeCode_kDLFloat, 32, 1);
make_dtype_const!(DTYPE_FLOAT64, DLDataTypeCode_kDLFloat, 64, 1);
impl Default for DLContext {
  /// The default context: CPU, device 0.
  fn default() -> Self {
    Self {
      device_type: DLDeviceType_kDLCPU,
      device_id: 0,
    }
  }
}
/// A device context, mirroring the FFI `DLContext` with `usize` fields.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct TVMContext {
  /// Device kind, per `DLDeviceType` (e.g. `kDLCPU`).
  pub(super) device_type: usize,
  /// Index of the device within its kind.
  pub(super) device_id: usize,
}
impl<'a> From<&'a TVMContext> for DLContext {
fn from(ctx: &'a TVMContext) -> Self {
Self {
device_type: ctx.device_type as u32,
device_id: ctx.device_id as i32,
}
}
}
impl Default for TVMContext {
fn default() -> Self {
Self {
device_type: DLDeviceType_kDLCPU as usize,
device_id: 0,
}
}
}
/// `From` conversions to `Tensor` for owned or borrowed `ndarray::Array`.
///
/// # Panics
///
/// Panics if the ndarray is not contiguous.
macro_rules! impl_tensor_from_ndarray {
  ($type:ty, $typecode:expr) => {
    impl<D: ndarray::Dimension> From<ndarray::Array<$type, D>> for Tensor<'static> {
      fn from(arr: ndarray::Array<$type, D>) -> Self {
        assert!(arr.is_standard_layout(), "Array must be contiguous.");
        let size = arr.len() * mem::size_of::<$type>();
        // `arr` is consumed and dropped when this function returns, so the
        // bytes must be copied into owned storage; the previous code kept a
        // `Storage::View` into `arr`, leaving the `Tensor<'static>` dangling.
        let storage = Storage::from(unsafe {
          slice::from_raw_parts(arr.as_ptr() as *const u8, size)
        }).to_owned();
        Tensor::from_array_storage(&arr, storage, $typecode as usize)
      }
    }
    impl<'a, D: ndarray::Dimension> From<&'a ndarray::Array<$type, D>> for Tensor<'a> {
      fn from(arr: &'a ndarray::Array<$type, D>) -> Self {
        assert!(arr.is_standard_layout(), "Array must be contiguous.");
        // Borrowed arrays can be viewed directly; the lifetime ties the
        // resulting `Tensor` to `arr`.
        Tensor::from_array_storage(
          arr,
          Storage::from(arr.as_slice().unwrap()),
          $typecode as usize,
        )
      }
    }
  };
}
/// `From` conversions to `DLTensor` for `ndarray::Array`.
/// Takes a reference to the `ndarray` since `DLTensor` is not owned.
macro_rules! impl_dltensor_from_ndarray {
  ($type:ty, $typecode:expr) => {
    impl<'a, D: ndarray::Dimension> From<&'a mut ndarray::Array<$type, D>> for DLTensor {
      fn from(arr: &'a mut ndarray::Array<$type, D>) -> Self {
        // The resulting struct borrows `arr`'s data/shape/strides buffers;
        // it must not outlive `arr`.
        DLTensor {
          data: arr.as_mut_ptr() as *mut c_void,
          ctx: DLContext::default(),
          ndim: arr.ndim() as c_int,
          dtype: DLDataType {
            code: $typecode as u8,
            bits: 8 * mem::size_of::<$type>() as u8,
            lanes: 1,
          },
          // NOTE(review): these pointer casts reinterpret `usize`/`isize`
          // buffers as `i64` — assumes a 64-bit target; confirm for 32-bit.
          shape: arr.shape().as_ptr() as *const i64 as *mut i64,
          strides: arr.strides().as_ptr() as *const isize as *mut i64,
          byte_offset: 0,
        }
      }
    }
  };
}
// `DLTensor` views for every primitive dtype the runtime exchanges with TVM.
impl_dltensor_from_ndarray!(f32, DLDataTypeCode_kDLFloat);
impl_dltensor_from_ndarray!(f64, DLDataTypeCode_kDLFloat);
impl_dltensor_from_ndarray!(i32, DLDataTypeCode_kDLInt);
impl_dltensor_from_ndarray!(i64, DLDataTypeCode_kDLInt);
impl_dltensor_from_ndarray!(u32, DLDataTypeCode_kDLUInt);
impl_dltensor_from_ndarray!(u64, DLDataTypeCode_kDLUInt);
// `Tensor` conversions for the same set of primitive dtypes.
impl_tensor_from_ndarray!(f32, DLDataTypeCode_kDLFloat);
impl_tensor_from_ndarray!(f64, DLDataTypeCode_kDLFloat);
impl_tensor_from_ndarray!(i32, DLDataTypeCode_kDLInt);
impl_tensor_from_ndarray!(i64, DLDataTypeCode_kDLInt);
impl_tensor_from_ndarray!(u32, DLDataTypeCode_kDLUInt);
impl_tensor_from_ndarray!(u64, DLDataTypeCode_kDLUInt);
/* automatically generated by rust-bindgen for TVM revision 6292c78 */
// NOTE(review): everything below is bindgen output; prefer regenerating the
// bindings over hand-editing. The libc/glibc `_*`/`__*` feature-test
// constants and the fixed-width limit constants reflect the host headers at
// generation time (apparently 64-bit glibc Linux) — regenerate per target.
pub const TVM_VERSION: &'static [u8; 8usize] = b"0.5.dev\0";
pub const DLPACK_VERSION: u32 = 8;
pub const _STDINT_H: u32 = 1;
pub const _FEATURES_H: u32 = 1;
pub const _DEFAULT_SOURCE: u32 = 1;
pub const __USE_ISOC11: u32 = 1;
pub const __USE_ISOC99: u32 = 1;
pub const __USE_ISOC95: u32 = 1;
pub const __USE_POSIX_IMPLICITLY: u32 = 1;
pub const _POSIX_SOURCE: u32 = 1;
pub const _POSIX_C_SOURCE: u32 = 200809;
pub const __USE_POSIX: u32 = 1;
pub const __USE_POSIX2: u32 = 1;
pub const __USE_POSIX199309: u32 = 1;
pub const __USE_POSIX199506: u32 = 1;
pub const __USE_XOPEN2K: u32 = 1;
pub const __USE_XOPEN2K8: u32 = 1;
pub const _ATFILE_SOURCE: u32 = 1;
pub const __USE_MISC: u32 = 1;
pub const __USE_ATFILE: u32 = 1;
pub const __USE_FORTIFY_LEVEL: u32 = 0;
pub const _STDC_PREDEF_H: u32 = 1;
pub const __STDC_IEC_559__: u32 = 1;
pub const __STDC_IEC_559_COMPLEX__: u32 = 1;
pub const __STDC_ISO_10646__: u32 = 201505;
pub const __STDC_NO_THREADS__: u32 = 1;
pub const __GNU_LIBRARY__: u32 = 6;
pub const __GLIBC__: u32 = 2;
pub const __GLIBC_MINOR__: u32 = 23;
pub const _SYS_CDEFS_H: u32 = 1;
pub const __WORDSIZE: u32 = 64;
pub const __WORDSIZE_TIME64_COMPAT32: u32 = 1;
pub const __SYSCALL_WORDSIZE: u32 = 64;
pub const _BITS_WCHAR_H: u32 = 1;
pub const INT8_MIN: i32 = -128;
pub const INT16_MIN: i32 = -32768;
pub const INT32_MIN: i32 = -2147483648;
pub const INT8_MAX: u32 = 127;
pub const INT16_MAX: u32 = 32767;
pub const INT32_MAX: u32 = 2147483647;
pub const UINT8_MAX: u32 = 255;
pub const UINT16_MAX: u32 = 65535;
pub const UINT32_MAX: u32 = 4294967295;
pub const INT_LEAST8_MIN: i32 = -128;
pub const INT_LEAST16_MIN: i32 = -32768;
pub const INT_LEAST32_MIN: i32 = -2147483648;
pub const INT_LEAST8_MAX: u32 = 127;
pub const INT_LEAST16_MAX: u32 = 32767;
pub const INT_LEAST32_MAX: u32 = 2147483647;
pub const UINT_LEAST8_MAX: u32 = 255;
pub const UINT_LEAST16_MAX: u32 = 65535;
pub const UINT_LEAST32_MAX: u32 = 4294967295;
pub const INT_FAST8_MIN: i32 = -128;
pub const INT_FAST16_MIN: i64 = -9223372036854775808;
pub const INT_FAST32_MIN: i64 = -9223372036854775808;
pub const INT_FAST8_MAX: u32 = 127;
pub const INT_FAST16_MAX: u64 = 9223372036854775807;
pub const INT_FAST32_MAX: u64 = 9223372036854775807;
pub const UINT_FAST8_MAX: u32 = 255;
pub const UINT_FAST16_MAX: i32 = -1;
pub const UINT_FAST32_MAX: i32 = -1;
pub const INTPTR_MIN: i64 = -9223372036854775808;
pub const INTPTR_MAX: u64 = 9223372036854775807;
pub const UINTPTR_MAX: i32 = -1;
pub const PTRDIFF_MIN: i64 = -9223372036854775808;
pub const PTRDIFF_MAX: u64 = 9223372036854775807;
pub const SIG_ATOMIC_MIN: i32 = -2147483648;
pub const SIG_ATOMIC_MAX: u32 = 2147483647;
pub const SIZE_MAX: i32 = -1;
pub const WINT_MIN: u32 = 0;
pub const WINT_MAX: u32 = 4294967295;
// Aliases for C's least/fast-width integer types (host-specific, LP64).
pub type int_least8_t = ::std::os::raw::c_schar;
pub type int_least16_t = ::std::os::raw::c_short;
pub type int_least32_t = ::std::os::raw::c_int;
pub type int_least64_t = ::std::os::raw::c_long;
pub type uint_least8_t = ::std::os::raw::c_uchar;
pub type uint_least16_t = ::std::os::raw::c_ushort;
pub type uint_least32_t = ::std::os::raw::c_uint;
pub type uint_least64_t = ::std::os::raw::c_ulong;
pub type int_fast8_t = ::std::os::raw::c_schar;
pub type int_fast16_t = ::std::os::raw::c_long;
pub type int_fast32_t = ::std::os::raw::c_long;
pub type int_fast64_t = ::std::os::raw::c_long;
pub type uint_fast8_t = ::std::os::raw::c_uchar;
pub type uint_fast16_t = ::std::os::raw::c_ulong;
pub type uint_fast32_t = ::std::os::raw::c_ulong;
pub type uint_fast64_t = ::std::os::raw::c_ulong;
pub type intmax_t = ::std::os::raw::c_long;
pub type uintmax_t = ::std::os::raw::c_ulong;
pub type wchar_t = ::std::os::raw::c_int;
/// Mirrors C's `max_align_t`; the `__bindgen_padding_0` field was inserted
/// by bindgen to reproduce the C struct layout.
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct max_align_t {
pub __clang_max_align_nonce1: ::std::os::raw::c_longlong,
pub __bindgen_padding_0: u64,
pub __clang_max_align_nonce2: f64,
}
// DLPack device type codes; note the enumeration starts at 1 (`kDLCPU`).
pub const DLDeviceType_kDLCPU: DLDeviceType = 1;
pub const DLDeviceType_kDLGPU: DLDeviceType = 2;
pub const DLDeviceType_kDLCPUPinned: DLDeviceType = 3;
pub const DLDeviceType_kDLOpenCL: DLDeviceType = 4;
pub const DLDeviceType_kDLMetal: DLDeviceType = 8;
pub const DLDeviceType_kDLVPI: DLDeviceType = 9;
pub const DLDeviceType_kDLROCM: DLDeviceType = 10;
/// \brief The device type in DLContext.
pub type DLDeviceType = u32;
/// \brief A Device context for Tensor and operator.
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct DLContext {
/// \brief The device type used in the device.
pub device_type: DLDeviceType,
/// \brief The device index
pub device_id: ::std::os::raw::c_int,
}
pub const DLDataTypeCode_kDLInt: DLDataTypeCode = 0;
pub const DLDataTypeCode_kDLUInt: DLDataTypeCode = 1;
pub const DLDataTypeCode_kDLFloat: DLDataTypeCode = 2;
/// \brief The type code options DLDataType.
pub type DLDataTypeCode = u32;
/// \brief The data type the tensor can hold.
///
/// Examples
/// - float: type_code = 2, bits = 32, lanes=1
/// - float4(vectorized 4 float): type_code = 2, bits = 32, lanes=4
/// - int8: type_code = 0, bits = 8, lanes=1
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct DLDataType {
/// \brief Type code of base types.
/// We keep it uint8_t instead of DLDataTypeCode for minimal memory
/// footprint, but the value should be one of DLDataTypeCode enum values.
///
pub code: u8,
/// \brief Number of bits, common choices are 8, 16, 32.
pub bits: u8,
/// \brief Number of lanes in the type, used for vector types.
pub lanes: u16,
}
/// \brief Plain C Tensor object, does not manage memory.
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct DLTensor {
/// \brief The opaque data pointer points to the allocated data.
/// This will be CUDA device pointer or cl_mem handle in OpenCL.
/// This pointer is always aligns to 256 bytes as in CUDA.
pub data: *mut ::std::os::raw::c_void,
/// \brief The device context of the tensor
pub ctx: DLContext,
/// \brief Number of dimensions
pub ndim: ::std::os::raw::c_int,
/// \brief The data type of the pointer
pub dtype: DLDataType,
/// \brief The shape of the tensor
pub shape: *mut i64,
/// \brief strides of the tensor,
/// can be NULL, indicating tensor is compact.
pub strides: *mut i64,
/// \brief The offset in bytes to the beginning pointer to data
pub byte_offset: u64,
}
/// \brief C Tensor object, manage memory of DLTensor. This data structure is
/// intended to faciliate the borrowing of DLTensor by another framework. It is
/// not meant to transfer the tensor. When the borrowing framework doesn't need
/// the tensor, it should call the deleter to notify the host that the resource
/// is no longer needed.
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct DLManagedTensor {
/// \brief DLTensor which is being memory managed
pub dl_tensor: DLTensor,
/// \brief the context of the original host framework of DLManagedTensor in
/// which DLManagedTensor is used in the framework. It can also be NULL.
pub manager_ctx: *mut ::std::os::raw::c_void,
/// \brief Destructor signature void (*)(void*) - this should be called
/// to destruct manager_ctx which holds the DLManagedTensor. It can be NULL
/// if there is no way for the caller to provide a reasonable destructor.
pub deleter: ::std::option::Option<unsafe extern "C" fn(self_: *mut DLManagedTensor)>,
}
/// \brief type of array index.
pub type tvm_index_t = i64;
pub const TVMDeviceExtType_kDLAOCL: TVMDeviceExtType = 5;
pub const TVMDeviceExtType_kDLSDAccel: TVMDeviceExtType = 6;
pub const TVMDeviceExtType_kDLVulkan: TVMDeviceExtType = 7;
pub const TVMDeviceExtType_kOpenGL: TVMDeviceExtType = 11;
pub const TVMDeviceExtType_kExtDev: TVMDeviceExtType = 12;
/// \brief Extension device types in TVM
pub type TVMDeviceExtType = u32;
pub const TVMTypeCode_kHandle: TVMTypeCode = 3;
pub const TVMTypeCode_kNull: TVMTypeCode = 4;
pub const TVMTypeCode_kTVMType: TVMTypeCode = 5;
pub const TVMTypeCode_kTVMContext: TVMTypeCode = 6;
pub const TVMTypeCode_kArrayHandle: TVMTypeCode = 7;
pub const TVMTypeCode_kNodeHandle: TVMTypeCode = 8;
pub const TVMTypeCode_kModuleHandle: TVMTypeCode = 9;
pub const TVMTypeCode_kFuncHandle: TVMTypeCode = 10;
pub const TVMTypeCode_kStr: TVMTypeCode = 11;
pub const TVMTypeCode_kBytes: TVMTypeCode = 12;
pub const TVMTypeCode_kNDArrayContainer: TVMTypeCode = 13;
pub const TVMTypeCode_kExtBegin: TVMTypeCode = 15;
pub const TVMTypeCode_kNNVMFirst: TVMTypeCode = 16;
pub const TVMTypeCode_kNNVMLast: TVMTypeCode = 20;
pub const TVMTypeCode_kExtReserveEnd: TVMTypeCode = 64;
pub const TVMTypeCode_kExtEnd: TVMTypeCode = 128;
/// \brief The type code in TVMType
/// \note TVMType is used in two places.
pub type TVMTypeCode = u32;
/// \brief The data type used in TVM Runtime.
///
/// Examples
/// - float: type_code = 2, bits = 32, lanes=1
/// - float4(vectorized 4 float): type_code = 2, bits = 32, lanes=4
/// - int8: type_code = 0, bits = 8, lanes=1
///
/// \note Arguments TVM API function always takes bits=64 and lanes=1
pub type TVMType = DLDataType;
/// \brief The Device information, abstract away common device types.
pub type TVMContext = DLContext;
/// \brief The tensor array stucture to TVM API.
pub type TVMArray = DLTensor;
/// \brief the array handle
pub type TVMArrayHandle = *mut TVMArray;
/// \brief Union type of values
/// being passed through API and function calls.
// NOTE(review): untagged union — which field is active is tracked externally
// via a `TVMTypeCode`; reading a non-active field is undefined behavior.
#[repr(C)]
#[derive(Copy, Clone)]
pub union TVMValue {
pub v_int64: i64,
pub v_float64: f64,
pub v_handle: *mut ::std::os::raw::c_void,
pub v_str: *const ::std::os::raw::c_char,
pub v_type: TVMType,
pub v_ctx: TVMContext,
_bindgen_union_align: u64,
}
/// \brief Byte array type used to pass in byte array
/// When kBytes is used as data type.
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct TVMByteArray {
pub data: *const ::std::os::raw::c_char,
pub size: usize,
}
/// \brief Handle to TVM runtime modules.
pub type TVMModuleHandle = *mut ::std::os::raw::c_void;
/// \brief Handle to packed function handle.
pub type TVMFunctionHandle = *mut ::std::os::raw::c_void;
/// \brief Handle to hold return value.
pub type TVMRetValueHandle = *mut ::std::os::raw::c_void;
/// \brief The stream that is specific to device
/// can be NULL, which indicates the default one.
pub type TVMStreamHandle = *mut ::std::os::raw::c_void;
extern "C" {
/// \brief Used for implementing C API function.
/// Set last error message before return.
/// \param msg The error message to be set.
pub fn TVMAPISetLastError(msg: *const ::std::os::raw::c_char);
}
extern "C" {
/// \brief return str message of the last error
/// all function in this file will return 0 when success
/// and -1 when an error occured,
/// TVMGetLastError can be called to retrieve the error
///
/// this function is threadsafe and can be called by different thread
/// \return error info
pub fn TVMGetLastError() -> *const ::std::os::raw::c_char;
}
extern "C" {
/// \brief Load module from file.
/// \param file_name The file name to load the module from.
/// \param format The format of the module.
/// \param out The result module
///
/// \return 0 when success, -1 when failure happens
/// \note The resulting module do not contain import relation.
/// It can be reconstructed by TVMModImport.
pub fn TVMModLoadFromFile(
file_name: *const ::std::os::raw::c_char,
format: *const ::std::os::raw::c_char,
out: *mut TVMModuleHandle,
) -> ::std::os::raw::c_int;
}
extern "C" {
/// \brief Add dep to mod's dependency.
/// This allows functions in this module to use modules.
///
/// \param mod The module handle.
/// \param dep The dependent module to be imported.
/// \return 0 when success, -1 when failure happens
pub fn TVMModImport(mod_: TVMModuleHandle, dep: TVMModuleHandle) -> ::std::os::raw::c_int;
}
extern "C" {
/// \brief Get function from the module.
/// \param mod The module handle.
/// \param func_name The name of the function.
/// \param query_imports Whether to query imported modules
/// \param out The result function, can be NULL if it is not available.
/// \return 0 when no error is thrown, -1 when failure happens
pub fn TVMModGetFunction(
mod_: TVMModuleHandle,
func_name: *const ::std::os::raw::c_char,
query_imports: ::std::os::raw::c_int,
out: *mut TVMFunctionHandle,
) -> ::std::os::raw::c_int;
}
extern "C" {
  /// \brief Free front-end extension type resource.
  /// \param handle The extension handle.
  /// \param type_code The type of the extension type.
  /// \return 0 when success, -1 when failure happens
  pub fn TVMExtTypeFree(
    handle: *mut ::std::os::raw::c_void,
    type_code: ::std::os::raw::c_int,
  ) -> ::std::os::raw::c_int;
}
extern "C" {
  /// \brief Free the Module
  /// \param mod The module to be freed.
  ///
  /// \note This may not free up the module's resources.
  /// If there is an active TVMFunctionHandle using the module,
  /// or if this module is imported by another active module,
  /// all functions remain valid until TVMFuncFree is called.
  /// \return 0 when success, -1 when failure happens
  pub fn TVMModFree(mod_: TVMModuleHandle) -> ::std::os::raw::c_int;
}
extern "C" {
  /// \brief Free the function when it is no longer needed.
  /// \param func The function handle
  /// \return 0 when success, -1 when failure happens
  pub fn TVMFuncFree(func: TVMFunctionHandle) -> ::std::os::raw::c_int;
}
extern "C" {
  /// \brief Call a Packed TVM Function.
  ///
  /// \param func node handle of the function.
  /// \param arg_values The arguments
  /// \param type_codes The type codes of the arguments
  /// \param num_args Number of arguments.
  ///
  /// \param ret_val The return value.
  /// \param ret_type_code the type code of return value.
  ///
  /// \return 0 when success, -1 when failure happens
  /// \note TVM calls always exchange with type bits=64, lanes=1
  ///
  /// \note API calls always exchange with type bits=64, lanes=1
  /// If API call returns container handles (e.g. FunctionHandle)
  /// these handles should be managed by the front-end.
  /// The front-end need to call free function (e.g. TVMFuncFree)
  /// to free these handles.
  pub fn TVMFuncCall(
    func: TVMFunctionHandle,
    arg_values: *mut TVMValue,
    type_codes: *mut ::std::os::raw::c_int,
    num_args: ::std::os::raw::c_int,
    ret_val: *mut TVMValue,
    ret_type_code: *mut ::std::os::raw::c_int,
  ) -> ::std::os::raw::c_int;
}
extern "C" {
  /// \brief Set the return value of TVMPackedCFunc.
  ///
  /// This function is called by TVMPackedCFunc to set the return value.
  /// When this function is not called, the function returns null by default.
  ///
  /// \param ret The return value handle, pass by ret in TVMPackedCFunc
  /// \param value The value to be returned.
  /// \param type_code The type of the value to be returned.
  /// \param num_ret Number of return values, for now only 1 is supported.
  pub fn TVMCFuncSetReturn(
    ret: TVMRetValueHandle,
    value: *mut TVMValue,
    type_code: *mut ::std::os::raw::c_int,
    num_ret: ::std::os::raw::c_int,
  ) -> ::std::os::raw::c_int;
}
extern "C" {
  /// \brief Inplace translate callback argument value to return value.
  /// This is only needed for non-POD arguments.
  ///
  /// \param value The value to be translated.
  /// \param code The type code to be translated.
  /// \note This function will do a shallow copy when necessary.
  ///
  /// \return 0 when success, -1 when failure happens.
  pub fn TVMCbArgToReturn(
    value: *mut TVMValue,
    code: ::std::os::raw::c_int,
  ) -> ::std::os::raw::c_int;
}
/// \brief C type of packed function.
///
/// \param args The arguments
/// \param type_codes The type codes of the arguments
/// \param num_args Number of arguments.
/// \param ret The return value handle.
/// \param resource_handle The handle of the additional resource handle from the front-end.
/// \return 0 if success, -1 if failure happens, set error via TVMAPISetLastError.
/// \sa TVMCFuncSetReturn
pub type TVMPackedCFunc = ::std::option::Option<
  unsafe extern "C" fn(
    args: *mut TVMValue,
    type_codes: *mut ::std::os::raw::c_int,
    num_args: ::std::os::raw::c_int,
    ret: TVMRetValueHandle,
    resource_handle: *mut ::std::os::raw::c_void,
  ) -> ::std::os::raw::c_int,
>;
/// \brief C callback to free the resource handle in C packed function.
/// \param resource_handle The handle of the additional resource handle from the front-end.
pub type TVMPackedCFuncFinalizer =
  ::std::option::Option<unsafe extern "C" fn(resource_handle: *mut ::std::os::raw::c_void)>;
/// \brief Signature for extension function declarer.
///
/// TVM calls this function to get the extension functions.
/// The declarer will call register_func to register functions and their names.
///
/// \param register_func_handle The register function
/// \return 0 if success, -1 if failure happens
pub type TVMExtensionFuncDeclarer = ::std::option::Option<
  unsafe extern "C" fn(register_func_handle: TVMFunctionHandle) -> ::std::os::raw::c_int,
>;
extern "C" {
  /// \brief Wrap a TVMPackedCFunc to become a FunctionHandle.
  ///
  /// The resource_handle will be managed by TVM API, until the function is no longer used.
  ///
  /// \param func The packed C function.
  /// \param resource_handle The resource handle from front-end, can be NULL.
  /// \param fin The finalizer on resource handle when the FunctionHandle get freed, can be NULL
  /// \param out the result function handle.
  /// \return 0 when success, -1 when failure happens
  pub fn TVMFuncCreateFromCFunc(
    func: TVMPackedCFunc,
    resource_handle: *mut ::std::os::raw::c_void,
    fin: TVMPackedCFuncFinalizer,
    out: *mut TVMFunctionHandle,
  ) -> ::std::os::raw::c_int;
}
extern "C" {
  /// \brief Register the function to runtime's global table.
  ///
  /// The registered function then can be pulled by the backend by the name.
  ///
  /// \param name The name of the function.
  /// \param f The function to be registered.
  /// \param override Whether to allow overriding an already registered function.
  pub fn TVMFuncRegisterGlobal(
    name: *const ::std::os::raw::c_char,
    f: TVMFunctionHandle,
    override_: ::std::os::raw::c_int,
  ) -> ::std::os::raw::c_int;
}
extern "C" {
  /// \brief Get a global function.
  ///
  /// \param name The name of the function.
  /// \param out the result function pointer, NULL if it does not exist.
  ///
  /// \note The function handle of a global function is managed by the TVM runtime,
  /// so TVMFuncFree should not be called when it gets deleted.
  pub fn TVMFuncGetGlobal(
    name: *const ::std::os::raw::c_char,
    out: *mut TVMFunctionHandle,
  ) -> ::std::os::raw::c_int;
}
extern "C" {
  /// \brief List all the globally registered function names
  /// \param out_size The number of functions
  /// \param out_array The array of function names.
  /// \return 0 when success, -1 when failure happens
  pub fn TVMFuncListGlobalNames(
    out_size: *mut ::std::os::raw::c_int,
    out_array: *mut *mut *const ::std::os::raw::c_char,
  ) -> ::std::os::raw::c_int;
}
extern "C" {
  /// \brief Allocate a nd-array's memory,
  /// including space of shape, of given spec.
  ///
  /// \param shape The shape of the array, the data content will be copied to out
  /// \param ndim The number of dimensions of the array.
  /// \param dtype_code The type code of the dtype
  /// \param dtype_bits The number of bits of dtype
  /// \param dtype_lanes The number of lanes in the dtype.
  /// \param device_type The device type of context
  /// \param device_id The device id of context.
  /// \param out The output handle.
  /// \return 0 when success, -1 when failure happens
  pub fn TVMArrayAlloc(
    shape: *const tvm_index_t,
    ndim: ::std::os::raw::c_int,
    dtype_code: ::std::os::raw::c_int,
    dtype_bits: ::std::os::raw::c_int,
    dtype_lanes: ::std::os::raw::c_int,
    device_type: ::std::os::raw::c_int,
    device_id: ::std::os::raw::c_int,
    out: *mut TVMArrayHandle,
  ) -> ::std::os::raw::c_int;
}
extern "C" {
  /// \brief Free the TVM Array.
  /// \param handle The array handle to be freed.
  /// \return 0 when success, -1 when failure happens
  pub fn TVMArrayFree(handle: TVMArrayHandle) -> ::std::os::raw::c_int;
}
extern "C" {
  /// \brief Copy array data from CPU byte array.
  /// \param handle The array handle.
  /// \param data the data pointer
  /// \param nbytes The number of bytes to copy.
  /// \return 0 when success, -1 when failure happens
  pub fn TVMArrayCopyFromBytes(
    handle: TVMArrayHandle,
    data: *mut ::std::os::raw::c_void,
    nbytes: usize,
  ) -> ::std::os::raw::c_int;
}
extern "C" {
  /// \brief Copy array data to CPU byte array.
  /// \param handle The array handle.
  /// \param data the data pointer
  /// \param nbytes The number of bytes to copy.
  /// \return 0 when success, -1 when failure happens
  pub fn TVMArrayCopyToBytes(
    handle: TVMArrayHandle,
    data: *mut ::std::os::raw::c_void,
    nbytes: usize,
  ) -> ::std::os::raw::c_int;
}
extern "C" {
  /// \brief Copy the array, both from and to must be valid during the copy.
  /// \param from The array to be copied from.
  /// \param to The target space.
  /// \param stream The stream where the copy happens, can be NULL.
  /// \return 0 when success, -1 when failure happens
  pub fn TVMArrayCopyFromTo(
    from: TVMArrayHandle,
    to: TVMArrayHandle,
    stream: TVMStreamHandle,
  ) -> ::std::os::raw::c_int;
}
extern "C" {
  /// \brief Produce an array from the DLManagedTensor that shares data memory
  /// with the DLManagedTensor.
  /// \param from The source DLManagedTensor.
  /// \param out The output array handle.
  /// \return 0 when success, -1 when failure happens
  pub fn TVMArrayFromDLPack(
    from: *mut DLManagedTensor,
    out: *mut TVMArrayHandle,
  ) -> ::std::os::raw::c_int;
}
extern "C" {
  /// \brief Produce a DLManagedTensor from the array that shares data memory with
  /// the array.
  /// \param from The source array.
  /// \param out The DLManagedTensor handle.
  /// \return 0 when success, -1 when failure happens
  pub fn TVMArrayToDLPack(
    from: TVMArrayHandle,
    out: *mut *mut DLManagedTensor,
  ) -> ::std::os::raw::c_int;
}
extern "C" {
  /// \brief Delete (free) a DLManagedTensor's data.
  /// \param dltensor Pointer to the DLManagedTensor.
  pub fn TVMDLManagedTensorCallDeleter(dltensor: *mut DLManagedTensor);
}
extern "C" {
  /// \brief Create a new runtime stream.
  ///
  /// \param device_type The device type of context
  /// \param device_id The device id of context
  /// \param out The new stream handle
  /// \return 0 when success, -1 when failure happens
  pub fn TVMStreamCreate(
    device_type: ::std::os::raw::c_int,
    device_id: ::std::os::raw::c_int,
    out: *mut TVMStreamHandle,
  ) -> ::std::os::raw::c_int;
}
extern "C" {
  /// \brief Free a created stream handle.
  ///
  /// \param device_type The device type of context
  /// \param device_id The device id of context
  /// \param stream The stream to be freed
  /// \return 0 when success, -1 when failure happens
  pub fn TVMStreamFree(
    device_type: ::std::os::raw::c_int,
    device_id: ::std::os::raw::c_int,
    stream: TVMStreamHandle,
  ) -> ::std::os::raw::c_int;
}
extern "C" {
  /// \brief Set the runtime stream of current thread to be stream.
  /// The subsequent calls to the same device_type
  /// will use the stream handle set here.
  /// The specific type of stream is runtime device dependent.
  ///
  /// \param device_type The device type of context
  /// \param device_id The device id of context.
  /// \param handle The stream handle.
  /// \return 0 when success, -1 when failure happens
  pub fn TVMSetStream(
    device_type: ::std::os::raw::c_int,
    device_id: ::std::os::raw::c_int,
    handle: TVMStreamHandle,
  ) -> ::std::os::raw::c_int;
}
extern "C" {
  /// \brief Wait until all computations on stream completes.
  ///
  /// \param device_type The device type of context
  /// \param device_id The device id of context.
  /// \param stream The stream to be synchronized.
  /// \return 0 when success, -1 when failure happens
  pub fn TVMSynchronize(
    device_type: ::std::os::raw::c_int,
    device_id: ::std::os::raw::c_int,
    stream: TVMStreamHandle,
  ) -> ::std::os::raw::c_int;
}
extern "C" {
  /// \brief Synchronize two streams of execution.
  ///
  /// \param device_type The device type of context
  /// \param device_id The device id of context
  /// \param src The source stream to synchronize.
  /// \param dst The destination stream to synchronize.
  /// \return 0 when success, -1 when failure happens
  pub fn TVMStreamStreamSynchronize(
    device_type: ::std::os::raw::c_int,
    device_id: ::std::os::raw::c_int,
    src: TVMStreamHandle,
    dst: TVMStreamHandle,
  ) -> ::std::os::raw::c_int;
}
extern "C" {
  /// \brief Backend function for modules to get function
  /// from its environment mod_node (its imports and global function).
  /// The user should not call TVMFuncFree on func.
  ///
  /// \param mod_node The module handle.
  /// \param func_name The name of the function.
  /// \param out The result function.
  /// \return 0 when no error is thrown, -1 when failure happens
  pub fn TVMBackendGetFuncFromEnv(
    mod_node: *mut ::std::os::raw::c_void,
    func_name: *const ::std::os::raw::c_char,
    out: *mut TVMFunctionHandle,
  ) -> ::std::os::raw::c_int;
}
extern "C" {
  /// \brief Backend function to register system-wide library symbol.
  ///
  /// \param name The name of the symbol
  /// \param ptr The symbol address.
  /// \return 0 when no error is thrown, -1 when failure happens
  pub fn TVMBackendRegisterSystemLibSymbol(
    name: *const ::std::os::raw::c_char,
    ptr: *mut ::std::os::raw::c_void,
  ) -> ::std::os::raw::c_int;
}
extern "C" {
  /// \brief Backend function to allocate temporal workspace.
  ///
  /// \note The resulting allocated space is ensured to be aligned to kTempAllocaAlignment.
  ///
  /// \param nbytes The size of the space requested.
  /// \param device_type The device type which the space will be allocated.
  /// \param device_id The device id which the space will be allocated.
  /// \param dtype_code_hint The type code of the array elements. Only used in
  /// certain backends such as OpenGL.
  /// \param dtype_bits_hint The type bits of the array elements. Only used in
  /// certain backends such as OpenGL.
  /// \return nullptr when error is thrown, a valid ptr if success
  pub fn TVMBackendAllocWorkspace(
    device_type: ::std::os::raw::c_int,
    device_id: ::std::os::raw::c_int,
    nbytes: u64,
    dtype_code_hint: ::std::os::raw::c_int,
    dtype_bits_hint: ::std::os::raw::c_int,
  ) -> *mut ::std::os::raw::c_void;
}
extern "C" {
  /// \brief Backend function to free temporal workspace.
  ///
  /// \param ptr The result allocated space pointer.
  /// \param device_type The device type which the space will be allocated.
  /// \param device_id The device id which the space will be allocated.
  /// \return 0 when no error is thrown, -1 when failure happens
  ///
  /// \sa TVMBackendAllocWorkspace
  pub fn TVMBackendFreeWorkspace(
    device_type: ::std::os::raw::c_int,
    device_id: ::std::os::raw::c_int,
    ptr: *mut ::std::os::raw::c_void,
  ) -> ::std::os::raw::c_int;
}
/// \brief Environment for TVM parallel task.
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct TVMParallelGroupEnv {
  /// \brief Auxiliary used for synchronization
  pub sync_handle: *mut ::std::os::raw::c_void,
  /// \brief total amount of task
  pub num_task: i32,
}
/// \brief The callback function to execute a parallel lambda
/// \param task_id the task id of the function.
/// \param penv The parallel environment backs the execution.
/// \param cdata The supporting closure data.
pub type FTVMParallelLambda = ::std::option::Option<
  unsafe extern "C" fn(
    task_id: ::std::os::raw::c_int,
    penv: *mut TVMParallelGroupEnv,
    cdata: *mut ::std::os::raw::c_void,
  ) -> ::std::os::raw::c_int,
>;
extern "C" {
  /// \brief Backend function for running parallel jobs.
  ///
  /// \param flambda The parallel function to be launched.
  /// \param cdata The closure data.
  /// \param num_task Number of tasks to launch, can be 0, means launch
  /// with all available threads.
  ///
  /// \return 0 when no error is thrown, -1 when failure happens
  pub fn TVMBackendParallelLaunch(
    flambda: FTVMParallelLambda,
    cdata: *mut ::std::os::raw::c_void,
    num_task: ::std::os::raw::c_int,
  ) -> ::std::os::raw::c_int;
}
extern "C" {
  /// \brief BSP barrier between parallel threads
  /// \param task_id the task id of the function.
  /// \param penv The parallel environment backs the execution.
  /// \return 0 when no error is thrown, -1 when failure happens
  pub fn TVMBackendParallelBarrier(
    task_id: ::std::os::raw::c_int,
    penv: *mut TVMParallelGroupEnv,
  ) -> ::std::os::raw::c_int;
}
extern "C" {
  /// \brief Simple static initialization function.
  /// Run f once and set handle to be not null.
  /// This function is mainly used for test purpose.
  ///
  /// \param handle A global address to indicate f
  /// \param f The function to be run
  /// \param cdata The closure data to pass to the function.
  /// \param nbytes Number of bytes in the closure data.
  /// \return 0 when no error is thrown, -1 when failure happens
  pub fn TVMBackendRunOnce(
    handle: *mut *mut ::std::os::raw::c_void,
    f: ::std::option::Option<
      unsafe extern "C" fn(arg1: *mut ::std::os::raw::c_void) -> ::std::os::raw::c_int,
    >,
    cdata: *mut ::std::os::raw::c_void,
    nbytes: ::std::os::raw::c_int,
  ) -> ::std::os::raw::c_int;
}
use std::{cmp, collections::HashMap, convert::TryFrom, iter::FromIterator, mem, str};
use nom::{alpha1, digit1, le_i32, le_i64, le_u16, le_u32, le_u64, le_u8, types::CompleteStr};
use serde;
use serde_json;
use super::{DataType, Module, Storage, TVMArgValue, TVMContext, Tensor};
use errors::{Error, ErrorKind, Result};
use ffi::runtime::{
DLDataTypeCode_kDLFloat, DLDataTypeCode_kDLInt, DLDataTypeCode_kDLUInt, DLTensor,
};
// Magic number for NDArray file. @see `kTVMNDArrayMagic` in `ndarray.h`
// (Leading underscore: currently unused — the `tensor` parser below skips the
// magic bytes with `take!(8)` instead of validating them against this value.)
const _NDARRAY_MAGIC: u64 = 0xDD5E40F096B4A13F;
// Magic number for NDArray list file. @see `kTVMNDArrayListMagic` in `graph_runtime.h`
// (Also currently unused; `parse_param_dict` skips its magic bytes the same way.)
const _NDARRAY_LIST_MAGIC: u64 = 0xF7E58D4F05049CB7;
/// A TVM computation graph.
///
/// # Examples
///
/// ```
/// let graph_json = fs::read_to_string("graph.json").unwrap();
/// let graph = Graph::try_from(&graph_json).unwrap();
/// ```
#[derive(Serialize, Deserialize, Debug)]
pub struct Graph {
  /// All graph nodes (inputs and ops).
  pub nodes: Vec<Node>,
  /// Indices into `nodes` of the graph's input nodes.
  pub arg_nodes: Vec<usize>,
  /// The graph's output entries.
  pub heads: Vec<Entry>,
  /// Per-node base offsets into the flat tensor list; `entry_index` uses
  /// `node_row_ptr[entry.id] + entry.index` to resolve an `Entry`.
  pub node_row_ptr: Option<Vec<usize>>,
  /// Graph-level attributes (e.g. "storage_id", "shape", "dltype") as raw JSON.
  pub attrs: Option<HashMap<String, serde_json::Value>>,
}
/// A reference to one output of a node within a `Graph`.
#[derive(Serialize, Deserialize, Debug)]
pub struct Entry {
  /// Index of the producing node in `Graph::nodes`.
  pub id: usize,
  /// Which of the producing node's outputs this entry refers to.
  pub index: usize,
  // NOTE(review): deserialized but never read in this module.
  pub version: usize,
}
impl Graph {
  /// Resolves an `Entry` to its index in the executor's flat tensor list
  /// using `node_row_ptr`; errors if the graph has no `node_row_ptr`.
  fn entry_index(&self, entry: &Entry) -> Result<usize> {
    match self.node_row_ptr {
      Some(ref nrp) => Ok(nrp[entry.id] + entry.index),
      None => Err("Missing node_row_ptr.".into()),
    }
  }

  /// Attempt to deserialize a JSON attribute to a type `T`.
  fn get_attr<T: serde::de::DeserializeOwned>(&self, attr: &str) -> Result<T> {
    let attrs = self
      .attrs
      .as_ref()
      .ok_or(ErrorKind::GraphFormatError(
        "Missing graph attrs".to_string(),
      ))?;
    let attr_json = attrs
      .get(attr)
      .ok_or(ErrorKind::GraphFormatError(format!(
        "Missing {} attr",
        attr
      )))?;
    Ok(serde_json::from_value::<T>(attr_json.to_owned())?)
  }
}
/// A node in the computation graph: a graph input or an executable op.
#[derive(Serialize, Deserialize, Debug)]
pub struct Node {
  /// Op kind: "null" marks a graph input, "tvm_op" an executable op.
  pub op: String,
  /// Node name; used to look up graph inputs by name.
  pub name: String,
  /// Outputs of other nodes consumed by this node.
  pub inputs: Vec<Entry>,
  /// String-typed attributes; parsed into `NodeAttrs` for tvm_ops.
  pub attrs: Option<HashMap<String, String>>,
  // NOTE(review): deserialized but not used by this executor.
  pub control_deps: Option<Vec<Entry>>,
}
/// Strongly-typed view of a tvm_op `Node`'s string attributes.
struct NodeAttrs {
  // Name of the packed function implementing the op ("__nop" means skip).
  func_name: String,
  // Number of output tensors the op produces.
  num_outputs: usize,
  // Whether tensors should be passed to the op flattened to one dimension.
  flatten_data: bool,
}
impl Node {
  /// Parses this node's string-typed `attrs` map into a `NodeAttrs`,
  /// reporting which attribute is missing on failure.
  fn parse_attrs(&self) -> Result<NodeAttrs> {
    let attrs = self
      .attrs
      .as_ref()
      .ok_or(format!("Missing node.attrs for `{}`", self.name))?;
    // Fetches a required attribute, naming it in the error when absent.
    let required = |key: &str| {
      attrs
        .get(key)
        .ok_or(format!("Node `{}` is missing attrs.{}", self.name, key))
    };
    Ok(NodeAttrs {
      func_name: required("func_name")?.to_string(),
      num_outputs: required("num_outputs")?.parse::<usize>()?,
      flatten_data: required("flatten_data")?.parse::<u8>()? == 1,
    })
  }
}
impl<'a> TryFrom<&'a String> for Graph {
  type Error = Error;

  /// Deserializes a `Graph` from its JSON string representation.
  fn try_from(graph_json: &String) -> Result<Self> {
    Ok(serde_json::from_str(graph_json)?)
  }
}
impl<'a> TryFrom<&'a str> for Graph {
  type Error = Error;

  /// Deserializes a `Graph` from its JSON string representation.
  fn try_from(graph_json: &'a str) -> Result<Self> {
    Ok(serde_json::from_str(graph_json)?)
  }
}
/// An executor for a TVM computation graph.
///
/// # Examples
///
/// ```
/// use ndarray::Array;
///
/// let syslib = SystemLibModule::default(); // a provider of TVM functions
///
/// let mut params_bytes = Vec::new();
/// fs::File::open("graph.params").unwrap().read_to_end(&mut params_bytes).unwrap();
/// let params = tvm::runtime::load_param_dict(&params_bytes).unwrap();
///
/// let graph = Graph::try_from(&fs::read_to_string("graph.json").unwrap()).unwrap();
///
/// let mut exec = GraphExecutor::new(graph, &syslib).unwrap();
/// exec.load_params(params);
///
/// let x = Array::from_vec(vec![1f32, 2., 3., 4.]);
/// exec.set_input("data", x.into());
/// exec.run();
/// let output = exec.get_output(0).unwrap();
///
/// println!("{:#?}", Array::try_from(output).unwrap());
/// ```
pub struct GraphExecutor<'m, 't> {
  // The deserialized graph description.
  graph: Graph,
  // One ready-to-call closure per executable op, in topological order.
  op_execs: Vec<Box<Fn() + 'm>>,
  // Flat list of tensors, indexed via `Graph::entry_index`.
  tensors: Vec<Tensor<'t>>,
}
// NOTE(review): `Send` is asserted manually because the boxed op closures are
// not auto-`Send`; confirm the captured DLTensor pointers are never used from
// more than one thread at a time.
unsafe impl<'m, 't> Send for GraphExecutor<'m, 't> {}
impl<'m, 't> GraphExecutor<'m, 't> {
  /// Builds an executor: allocates storage for every tensor in `graph`, then
  /// compiles each op node into a closure using functions resolved from `lib`.
  pub fn new<M: 'm + Module>(graph: Graph, lib: &'m M) -> Result<Self> {
    let tensors = Self::setup_storages(&graph)?;
    Ok(GraphExecutor {
      op_execs: Self::setup_op_execs(&graph, lib, &tensors)?,
      tensors: tensors,
      graph: graph,
    })
  }
  /// Runs the computation graph.
  pub fn run(&self) {
    self.op_execs.iter().for_each(|op_exec| {
      op_exec();
    });
  }
  /// Allocates `Storages` for each `storage_id` and returns `Tensor`s to hold each output.
  fn setup_storages<'a>(graph: &'a Graph) -> Result<Vec<Tensor<'t>>> {
    // Per-tensor storage assignment, shape, and dtype from graph attrs.
    let storage_ids = graph.get_attr::<(String, Vec<usize>)>("storage_id")?.1;
    let shapes = graph.get_attr::<(String, Vec<Vec<i64>>)>("shape")?.1;
    let dtypes = graph
      .get_attr::<(String, Vec<String>)>("dltype")?
      .1
      .iter()
      .map(|dltype| {
        if let Ok((_, dtype)) = tvm_str_to_type(CompleteStr(dltype)) {
          Ok(dtype)
        } else {
          Err(ErrorKind::GraphFormatError(format!("Invalid dltype: {}", dltype).to_string()).into())
        }
      }).collect::<Result<Vec<DataType>>>()?;
    // Shared alignment for all storages: the widest element width (in bits) seen.
    let align = dtypes.iter().map(|dtype| dtype.bits as usize).max();
    // Each storage must be large enough for the biggest tensor assigned to it.
    let mut storage_num_bytes = vec![0usize; *storage_ids.iter().max().unwrap_or(&1) + 1];
    for (i, &storage_id) in storage_ids.iter().enumerate() {
      // bits * lanes / 8 = element size in bytes.
      let dtype_size = dtypes[i].bits * dtypes[i].lanes >> 3;
      let nbytes = dtype_size * shapes[i].iter().product::<i64>() as usize;
      storage_num_bytes[storage_id] = cmp::max(nbytes, storage_num_bytes[storage_id]);
    }
    let mut storages: Vec<Storage> = storage_num_bytes
      .into_iter()
      .map(|nbytes| Storage::new(nbytes, align))
      .collect::<Result<Vec<Storage>>>()?;
    // Hand each tensor its storage. `mem::replace` leaves a view in the
    // `storages` slot so later tensors with the same storage_id alias the
    // same underlying buffer.
    let tensors = izip!(storage_ids, shapes, dtypes)
      .map(|(storage_id, shape, dtype)| {
        let storage = storages[storage_id].view();
        Tensor {
          data: mem::replace(&mut storages[storage_id], storage),
          ctx: TVMContext::default(),
          dtype: dtype,
          size: shape.iter().product::<i64>() as usize,
          shape: shape,
          strides: None,
          byte_offset: 0,
        }
      }).collect();
    Ok(tensors)
  }
  /// Creates closures which represent the computation performed by this graph.
  fn setup_op_execs<M: 'm + Module>(
    graph: &Graph,
    lib: &'m M,
    tensors: &Vec<Tensor<'t>>,
  ) -> Result<Vec<Box<Fn() + 'm>>> {
    ensure!(graph.node_row_ptr.is_some(), "Missing node_row_ptr.");
    let node_row_ptr = graph.node_row_ptr.as_ref().unwrap();
    let mut op_execs = Vec::new();
    for (i, node) in graph.nodes.iter().enumerate() {
      // Graph inputs have no computation.
      if node.op == "null" {
        continue;
      }
      ensure!(node.op == "tvm_op", "Only TVM ops are supported.");
      ensure!(node.attrs.is_some(), "Missing node attrs.");
      let attrs = node.parse_attrs()?;
      // "__nop" ops need no closure at all.
      if attrs.func_name == "__nop" {
        continue;
      }
      let func = lib
        .get_function(&attrs.func_name)
        .ok_or(format!("Missing function {}", attrs.func_name))?;
      // Flat tensor indices for the call: the op's inputs followed by its
      // own outputs (which start at this node's node_row_ptr offset).
      let arg_indices = node
        .inputs
        .iter()
        .map(|entry| graph.entry_index(entry))
        .chain((0..attrs.num_outputs).map(|oi| Ok(node_row_ptr[i].clone() + oi)));
      let dl_tensors = arg_indices
        .map(|idx| {
          let tensor = &tensors[idx?];
          Ok(if attrs.flatten_data {
            DLTensor::from_tensor(tensor, true /* flatten */)
          } else {
            DLTensor::from(tensor)
          })
        }).collect::<Result<Vec<DLTensor>>>()
        .unwrap();
      // The closure owns its prebuilt DLTensor argument list and re-wraps it
      // as TVMArgValues on every invocation.
      let op: Box<Fn()> = box move || {
        let args = dl_tensors
          .iter()
          .map(|t| t.into())
          .collect::<Vec<TVMArgValue>>();
        func(args.as_slice());
      };
      op_execs.push(op);
    }
    Ok(op_execs)
  }
  /// Copies each named parameter tensor into the matching graph input.
  pub fn load_params(&mut self, params: HashMap<String, Tensor<'t>>) {
    params.into_iter().for_each(|(name, param)| {
      self.set_input(name, param);
    })
  }
  /// Copies `value` into the graph input named `name`; prints a warning and
  /// does nothing if no such input exists.
  pub fn set_input<S: AsRef<str>>(&mut self, name: S, value: Tensor<'t>) {
    if let Some(idx) = self.get_input_index(name.as_ref()) {
      // TODO: consider `new_with_params` to avoid ever allocating
      let ptr = self.tensors[idx].data.as_ptr();
      let mut to_replace = self.tensors.iter_mut().filter(|t| t.data.as_ptr() == ptr);
      let mut owner = to_replace.nth(0).unwrap();
      if value.data.is_owned() {
        // FIXME: for no-copy, need setup_op_execs to not capture tensor ptr
        // mem::replace(&mut (*owner), value);
        // to_replace.for_each(|t| {
        //   panic!("replacing");
        //   t.data = owner.data.view();
        // });
        owner.copy(&value);
      } else {
        owner.copy(&value);
      }
    } else {
      println!("Unexpected input `{}`", name.as_ref());
    }
  }
  /// Returns the graph input with name `name`, if it exists.
  pub fn get_input<S: AsRef<str>>(&mut self, name: S) -> Option<&Tensor> {
    self
      .get_input_index(name.as_ref())
      .and_then(move |idx| Some(&self.tensors[idx]))
  }
  /// Returns the graph output with index `index`, if it exists.
  pub fn get_output(&self, idx: usize) -> Option<&Tensor> {
    let graph = &self.graph;
    graph.heads.get(idx).and_then(|entry| {
      graph
        .entry_index(entry)
        .map(|idx| self.tensors.get(idx))
        .unwrap_or(None)
    })
  }
  /// Returns the index for graph input with name `name`, if it exists.
  pub fn get_input_index<S: AsRef<str>>(&self, name: S) -> Option<usize> {
    let graph = &self.graph;
    // Locate the first node with a matching name, then accept it only if it
    // is registered as an argument (input) node.
    (0..graph.nodes.len())
      .skip_while(|&i| graph.nodes[i].name != name.as_ref())
      .nth(0)
      .and_then(|i| {
        if graph.arg_nodes.iter().any(|&id| id == i) {
          graph.node_row_ptr.as_ref().map(|nrp| nrp[i])
        } else {
          None
        }
      })
  }
}
/// Converts a string to TVM DLDataTypeCode. @see `String2TVMType` in packed_func.h
/// Accepts `<name><bits>[x<lanes>]`, e.g. "float32" or "uint8x4"; lanes
/// default to 1 and unrecognized type names fall back to the float code.
named!(
  tvm_str_to_type<CompleteStr, DataType>,
  do_parse!(
    type_name: alpha1 >>
    bits: digit1 >>
    lanes: opt!(tuple!(tag!("x"), digit1)) >>
    (DataType {
      code: match type_name {
        CompleteStr("int") => DLDataTypeCode_kDLInt,
        CompleteStr("uint") => DLDataTypeCode_kDLUInt,
        CompleteStr("float") => DLDataTypeCode_kDLFloat,
        _ => DLDataTypeCode_kDLFloat,
      } as usize,
      bits: bits.parse::<u8>().unwrap() as usize,
      lanes: match lanes {
        Some(lanes) => lanes.1.parse::<u16>().unwrap() as usize,
        None => 1,
      },
    })
  )
);
/// Parses a length-prefixed string: a little-endian u64 byte count followed
/// by that many UTF-8 bytes.
named!(
  name<String>,
  map_res!(length_bytes!(le_u64), |b: &[u8]| String::from_utf8(
    b.to_vec()
  ))
);
/// Parses a TVMContext: a little-endian u32 device type followed by a
/// little-endian i32 device id.
named!(
  tvm_ctx<&[u8], TVMContext>,
  do_parse!(
    device_type: le_u32 >>
    device_id: le_i32 >>
    (TVMContext { device_type: device_type as usize, device_id: device_id as usize })
  )
);
/// Parses a DataType: one byte each for code and bits, then a little-endian
/// u16 lane count.
named!(
  data_type<&[u8], DataType>,
  do_parse!(
    code: le_u8 >>
    bits: le_u8 >>
    lanes: le_u16 >>
    (DataType { code: code as usize, bits: bits as usize, lanes: lanes as usize })
  )
);
/// Parses a Tensor from a TVM array file.
/// Layout: 8 magic bytes (skipped, not validated against `_NDARRAY_MAGIC`),
/// 8 reserved zero bytes, context, ndim, dtype, `ndim` little-endian i64
/// shape values, then an i64 byte length and the raw tensor data.
named!(
  tensor<Tensor>,
  do_parse!(
    take!(8)
      >> bits!(tag_bits!(u64, 64, 0))
      >> ctx: tvm_ctx
      >> ndim: le_u32
      >> dtype: data_type
      >> shape: count!(map!(le_i64, |sz| sz as i64), ndim as usize)
      >> length: le_i64
      >> data: take!(length)
      >> (Tensor {
        data: Storage::from(data),
        ctx: ctx,
        dtype: dtype,
        size: shape.iter().product::<i64>() as usize,
        shape: shape,
        strides: None,
        byte_offset: 0,
      })
  )
);
/// Parses a graph params dict from a params binary file.
/// Layout: 8 magic bytes (skipped), 8 reserved zero bytes, a u64-counted list
/// of names, then a u64-counted list of tensors, zipped together by position.
named!(
  parse_param_dict<HashMap<String, Tensor>>,
  do_parse!(
    take!(8)
      >> bits!(tag_bits!(u64, 64, 0))
      >> names: length_count!(le_u64, name)
      >> tensors: length_count!(le_u64, tensor)
      >> (HashMap::from_iter(names.into_iter().zip(tensors.into_iter())))
  )
);
/// Loads a param dict saved using `nnvm.compiler.save_param_dict`.
///
/// Fails if the payload cannot be parsed or if trailing bytes remain after
/// the dict.
pub fn load_param_dict(bytes: &[u8]) -> Result<HashMap<String, Tensor>> {
  match parse_param_dict(bytes) {
    Ok((remaining, param_dict)) => {
      if remaining.is_empty() {
        Ok(param_dict)
      } else {
        bail!(ErrorKind::LoadGraphParamsError("extra input".to_string()))
      }
    }
    Err(_) => bail!(ErrorKind::LoadGraphParamsError(
      "invalid parameters file".to_string()
    )),
  }
}
#[cfg(test)]
mod tests {
  use super::*;
  #[test]
  fn test_str_to_type() {
    // Bare name + bits: lane count defaults to 1.
    assert_eq!(
      tvm_str_to_type(CompleteStr("float24")).unwrap().1,
      DataType {
        code: DLDataTypeCode_kDLFloat as usize,
        bits: 24,
        lanes: 1
      }
    );
    // Explicit lane count after the `x` separator.
    assert_eq!(
      tvm_str_to_type(CompleteStr("uint111x44")).unwrap().1,
      DataType {
        code: DLDataTypeCode_kDLUInt as usize,
        bits: 111,
        lanes: 44
      }
    );
  }
}
// Submodules of the runtime crate.
mod allocator;
mod array;
mod module;
#[macro_use]
mod packed_func;
mod graph;
// SGX support is compiled in only for the sgx target environment.
#[cfg(target_env = "sgx")]
#[macro_use]
pub mod sgx;
mod threading;
mod workspace;
use std::os::raw::c_char;
// Flatten the public API of the submodules into `runtime::*`.
pub use self::{array::*, graph::*, module::*, packed_func::*, threading::*, workspace::*};
/// C entry point used by generated TVM code to report an error message.
/// Outside SGX this panics immediately with the message (assumes `cmsg` is a
/// valid NUL-terminated UTF-8 string); inside SGX it forwards the message to
/// the untrusted host via an ocall.
#[no_mangle]
pub extern "C" fn TVMAPISetLastError(cmsg: *const c_char) {
  #[cfg(not(target_env = "sgx"))]
  unsafe {
    panic!(std::ffi::CStr::from_ptr(cmsg).to_str().unwrap());
  }
  #[cfg(target_env = "sgx")]
  ocall_packed!("__sgx_set_last_error__", cmsg);
}
use std::{
collections::HashMap, convert::AsRef, ffi::CStr, os::raw::c_char, string::String, sync::Mutex,
};
use ffi::runtime::BackendPackedCFunc;
use runtime::packed_func::{wrap_backend_packed_func, PackedFunc};
/// A provider of TVM packed functions, resolvable by name.
pub trait Module {
  /// Returns the packed function named `name`, if this module provides it.
  fn get_function<S: AsRef<str>>(&self, name: S) -> Option<PackedFunc>;
}
/// A `Module` backed by the process-wide system library: functions that were
/// registered at load time through `TVMBackendRegisterSystemLibSymbol`.
pub struct SystemLibModule;
lazy_static! {
  // Global name -> function registry filled by `TVMBackendRegisterSystemLibSymbol`.
  static ref SYSTEM_LIB_FUNCTIONS: Mutex<HashMap<String, BackendPackedCFunc>> =
    Mutex::new(HashMap::new());
}
impl Module for SystemLibModule {
  /// Looks up `name` in the global system-library registry and wraps the raw
  /// C function into a callable `PackedFunc`.
  fn get_function<S: AsRef<str>>(&self, name: S) -> Option<PackedFunc> {
    let registry = SYSTEM_LIB_FUNCTIONS.lock().unwrap();
    match registry.get(name.as_ref()) {
      Some(func) => Some(wrap_backend_packed_func(func.to_owned())),
      None => None,
    }
  }
}
impl Default for SystemLibModule {
fn default() -> Self {
SystemLibModule {}
}
}
/// C entry point called by generated module initializers to register a packed
/// function under `cname` in the global system-library table. Always returns
/// 0 (success). Assumes `cname` is a valid NUL-terminated UTF-8 string.
#[no_mangle]
pub extern "C" fn TVMBackendRegisterSystemLibSymbol(
  cname: *const c_char,
  func: BackendPackedCFunc,
) -> i32 {
  let name = unsafe { CStr::from_ptr(cname).to_str().unwrap() };
  let mut registry = SYSTEM_LIB_FUNCTIONS.lock().unwrap();
  registry.insert(name.to_string(), func);
  0
}
use std::{any::Any, convert::TryFrom, marker::PhantomData, os::raw::c_void};
use ffi::runtime::{
BackendPackedCFunc, DLDataTypeCode_kDLFloat, DLDataTypeCode_kDLInt, DLTensor,
TVMTypeCode_kArrayHandle, TVMTypeCode_kHandle, TVMValue,
};
use errors::*;
/// A dynamically-typed TVM function, callable with a slice of `TVMArgValue`s.
pub type PackedFunc = Box<Fn(&[TVMArgValue]) -> TVMRetValue + Send + Sync>;
/// Calls a packed function and returns a `TVMRetValue`.
///
/// Each argument is converted via `.into()`, so anything with a
/// `From`/`Into` impl for `TVMArgValue` can be passed directly.
///
/// # Example
///
/// `call_packed!(my_tvm_func, &mut arg1, &mut arg2)`
#[macro_export]
macro_rules! call_packed {
  ($fn:expr, $($args:expr),+) => {
    $fn(&[$($args.into(),)+])
  };
  ($fn:expr) => {
    $fn(&Vec::new())
  };
}
/// A borrowed TVMPODValue. Can be constructed using `into()` but the preferred way
/// to obtain a `TVMArgValue` is automatically via `call_packed!`.
#[derive(Clone, Copy)]
pub struct TVMArgValue<'a> {
  // Ties this argument's lifetime to the borrowed data it may point to.
  _lifetime: PhantomData<&'a ()>,
  // The raw POD/handle payload passed across the C ABI.
  pub(crate) value: TVMValue,
  // Discriminant telling the callee how to interpret `value`.
  pub(crate) type_code: i64,
}
impl<'a> TVMArgValue<'a> {
  /// Wraps a raw `TVMValue` and its type code into a borrowed argument value.
  pub fn new(value: TVMValue, type_code: i64) -> Self {
    TVMArgValue {
      value,
      type_code,
      _lifetime: PhantomData,
    }
  }
}
/// Creates a conversion to a `TVMArgValue` for a primitive type and DLDataTypeCode.
macro_rules! impl_prim_tvm_arg {
  // Full form: store `$type` into TVMValue field `$field`, cast through `$as`.
  ($type:ty, $field:ident, $code:expr, $as:ty) => {
    impl<'a> From<$type> for TVMArgValue<'a> {
      fn from(val: $type) -> Self {
        TVMArgValue {
          value: TVMValue { $field: val as $as },
          type_code: $code as i64,
          _lifetime: PhantomData,
        }
      }
    }
  };
  // The cast type defaults to `$type` itself.
  ($type:ty, $field:ident, $code:expr) => {
    impl_prim_tvm_arg!($type, $field, $code, $type);
  };
  // Integers are widened into `v_int64` with the DL int type code.
  ($type:ty,v_int64) => {
    impl_prim_tvm_arg!($type, v_int64, DLDataTypeCode_kDLInt, i64);
  };
  // Floats are widened into `v_float64` with the DL float type code.
  ($type:ty,v_float64) => {
    impl_prim_tvm_arg!($type, v_float64, DLDataTypeCode_kDLFloat, f64);
  };
}
// Primitive -> TVMArgValue conversions. All values are widened to 64 bits as
// required by the TVM calling convention.
impl_prim_tvm_arg!(f32, v_float64);
impl_prim_tvm_arg!(f64, v_float64);
impl_prim_tvm_arg!(i8, v_int64);
impl_prim_tvm_arg!(u8, v_int64);
impl_prim_tvm_arg!(i32, v_int64);
impl_prim_tvm_arg!(u32, v_int64);
impl_prim_tvm_arg!(i64, v_int64);
impl_prim_tvm_arg!(u64, v_int64);
impl_prim_tvm_arg!(bool, v_int64);
/// Creates a conversion to a `TVMArgValue` for an object handle.
impl<'a, T> From<*const T> for TVMArgValue<'a> {
  fn from(ptr: *const T) -> Self {
    // The C ABI takes a non-const void pointer even for read-only handles.
    let handle = ptr as *mut T as *mut c_void;
    TVMArgValue {
      _lifetime: PhantomData,
      value: TVMValue { v_handle: handle },
      type_code: TVMTypeCode_kArrayHandle as i64,
    }
  }
}
/// Creates a conversion to a `TVMArgValue` for a mutable object handle.
impl<'a, T> From<*mut T> for TVMArgValue<'a> {
  fn from(ptr: *mut T) -> Self {
    let handle = ptr as *mut c_void;
    TVMArgValue {
      _lifetime: PhantomData,
      value: TVMValue { v_handle: handle },
      type_code: TVMTypeCode_kHandle as i64,
    }
  }
}
/// Borrows a mutable `DLTensor` as an array-handle argument.
impl<'a> From<&'a mut DLTensor> for TVMArgValue<'a> {
  fn from(arr: &'a mut DLTensor) -> Self {
    let handle = arr as *mut _ as *mut c_void;
    TVMArgValue {
      _lifetime: PhantomData,
      value: TVMValue { v_handle: handle },
      type_code: TVMTypeCode_kArrayHandle as i64,
    }
  }
}
/// Borrows a shared `DLTensor` as an array-handle argument. The constness is
/// cast away because the C ABI only accepts a mutable pointer.
impl<'a> From<&'a DLTensor> for TVMArgValue<'a> {
  fn from(arr: &'a DLTensor) -> Self {
    let handle = arr as *const _ as *mut DLTensor as *mut c_void;
    TVMArgValue {
      _lifetime: PhantomData,
      value: TVMValue { v_handle: handle },
      type_code: TVMTypeCode_kArrayHandle as i64,
    }
  }
}
/// An owned TVMPODValue. Can be converted from a variety of primitive and object types.
/// Can be downcasted using `try_from` if it contains the desired type.
///
/// # Example
///
/// ```
/// let a = 42u32;
/// let b: u64 = TVMRetValue::from(a).try_into().unwrap();
///
/// let s = "hello, world!";
/// let t: TVMRetValue = s.into();
/// assert_eq!(String::try_from(t).unwrap(), s);
/// ```
pub struct TVMRetValue {
  /// A primitive return value, if any.
  prim_value: u64,
  /// An object return value, if any.
  box_value: Box<Any>,
  /// The DLDataTypeCode which determines whether `prim_value` or `box_value` is in use.
  type_code: i64,
}
// SGX-only conversions between TVMRetValue and the raw C-ABI (TVMValue, code)
// pair used by ocalls.
#[cfg(target_env = "sgx")]
impl TVMRetValue {
  /// Builds a `TVMRetValue` from a raw C-ABI value and type code.
  /// NOTE(review): the numeric codes (0/1 int, 2 float, handle and string
  /// ranges) are presumed to mirror `TVMTypeCode` in `c_runtime_api.h` —
  /// confirm they stay in sync with that header.
  pub(crate) fn from_tvm_value(value: TVMValue, type_code: i64) -> Self {
    // Reading a union field is unsafe; `type_code` selects the active field.
    unsafe {
      Self {
        prim_value: match type_code {
          0 | 1 => value.v_int64 as u64,
          2 => value.v_float64 as u64,
          3 | 7 | 8 | 9 | 10 => value.v_handle as u64,
          11 | 12 => value.v_str as u64,
          _ => 0,
        } as u64,
        box_value: box (),
        type_code: type_code,
      }
    }
  }
  /// Consumes `self` and produces the raw C-ABI pair. For handle and string
  /// codes the boxed payload is leaked via `Box::into_raw`; the receiver
  /// becomes responsible for freeing it.
  pub fn into_tvm_value(self) -> (TVMValue, i64) {
    let val = match self.type_code {
      0 | 1 => TVMValue {
        v_int64: self.prim_value.clone() as i64,
      },
      2 => TVMValue {
        v_float64: self.prim_value.clone() as f64,
      },
      3 | 7 | 8 | 9 | 10 => TVMValue {
        v_handle: Box::into_raw(self.box_value) as *mut c_void,
      },
      11 | 12 => TVMValue {
        v_str: Box::into_raw(self.box_value) as *const _,
      },
      _ => unreachable!(),
    };
    (val, self.type_code)
  }
}
impl Default for TVMRetValue {
  /// An empty return value: integer type code 0 with a zero payload.
  fn default() -> Self {
    TVMRetValue {
      type_code: 0,
      prim_value: 0,
      box_value: box (),
    }
  }
}
/// Implements `From<$type>` and `TryFrom<TVMRetValue>` for a primitive that is
/// stored in `prim_value` via an `as` cast and tagged with type code `$code`.
macro_rules! impl_prim_ret_value {
  ($type:ty, $code:expr) => {
    impl From<$type> for TVMRetValue {
      fn from(val: $type) -> Self {
        TVMRetValue {
          prim_value: val as u64,
          box_value: box (),
          type_code: $code,
        }
      }
    }
    impl TryFrom<TVMRetValue> for $type {
      type Error = Error;
      fn try_from(ret: TVMRetValue) -> Result<$type> {
        // Only convert when the stored type code matches; otherwise report
        // which target type the conversion was attempting.
        if ret.type_code == $code {
          Ok(ret.prim_value as $type)
        } else {
          bail!(ErrorKind::TryFromTVMRetValueError(
            stringify!($type).to_string(),
            ret.type_code
          ))
        }
      }
    }
  };
}
/// Implements `From<$type>` and `TryFrom<TVMRetValue>` for a heap-allocated
/// payload stored in `box_value` and tagged with type code `$code`.
macro_rules! impl_boxed_ret_value {
  ($type:ty, $code:expr) => {
    impl From<$type> for TVMRetValue {
      fn from(val: $type) -> Self {
        TVMRetValue {
          prim_value: 0,
          box_value: box val,
          type_code: $code,
        }
      }
    }
    impl TryFrom<TVMRetValue> for $type {
      type Error = Error;
      fn try_from(ret: TVMRetValue) -> Result<$type> {
        // NOTE(review): unlike the primitive macro, this relies solely on the
        // `Any` downcast succeeding and never consults `type_code`.
        if let Ok(val) = ret.box_value.downcast::<$type>() {
          Ok(*val)
        } else {
          bail!(ErrorKind::TryFromTVMRetValueError(
            stringify!($type).to_string(),
            ret.type_code
          ))
        }
      }
    }
  };
}
// Type codes match the mapping consumed by `from_tvm_value` above:
// 0 = signed int, 1 = unsigned int, 2 = float, 11 = string-like.
impl_prim_ret_value!(i8, 0);
impl_prim_ret_value!(u8, 1);
impl_prim_ret_value!(i16, 0);
impl_prim_ret_value!(u16, 1);
impl_prim_ret_value!(i32, 0);
impl_prim_ret_value!(u32, 1);
impl_prim_ret_value!(f32, 2);
impl_prim_ret_value!(i64, 0);
impl_prim_ret_value!(u64, 1);
impl_prim_ret_value!(f64, 2);
impl_prim_ret_value!(isize, 0);
impl_prim_ret_value!(usize, 1);
impl_boxed_ret_value!(String, 11);
// @see `WrapPackedFunc` in `llvm_module.cc`.
/// Wraps a `BackendPackedCFunc` (the calling convention emitted by the TVM
/// compiler) as a boxed `PackedFunc` closure.
pub(super) fn wrap_backend_packed_func(func: BackendPackedCFunc) -> PackedFunc {
  box move |args: &[TVMArgValue]| {
    // The two collected `Vec`s are temporaries, but Rust keeps temporaries
    // alive until the end of the enclosing statement, so the raw pointers
    // remain valid for the duration of the `func` call.
    // NOTE(review): any status `func` returns is discarded here — TODO
    // confirm callers never need to observe failure.
    func(
      args
        .iter()
        .map(|ref arg| arg.value)
        .collect::<Vec<TVMValue>>()
        .as_ptr(),
      args
        .iter()
        .map(|ref arg| arg.type_code as i32)
        .collect::<Vec<i32>>()
        .as_ptr() as *const i32,
      args.len() as i32,
    );
    TVMRetValue::default()
  }
}
use std::{
ffi::CString,
os::raw::{c_char, c_int},
};
use errors::Result;
use ffi::runtime::TVMValue;
use runtime::{threading::sgx_join_threads, SystemLibModule, TVMArgValue, TVMRetValue};
pub use runtime::threading::tvm_run_worker as run_worker;
#[macro_export]
macro_rules! tvm_ocall {
  // Wraps an SGX ocall expression, mapping its status code to a `Result`:
  // zero is success; any other value becomes a formatted error string.
  ($func: expr) => {
    match $func {
      0 => Ok(()),
      err => Err(format!("SGX error: {}", err)),
    }
  };
}
/// Raw status code returned by SGX ocalls; zero indicates success
/// (see `tvm_ocall!`).
pub type SgxStatus = u32;
#[cfg(target_env = "sgx")]
extern "C" {
  // Proxies a packed-function call out of the enclave to the untrusted
  // runtime; the result is written through `ret_val` and `ret_type_code`.
  fn tvm_ocall_packed_func(
    name: *const c_char,
    arg_values: *const TVMValue,
    type_codes: *const c_int,
    num_args: c_int,
    ret_val: *mut TVMValue,
    ret_type_code: *mut c_int,
  ) -> SgxStatus;
}
pub fn ocall_packed_func<S: AsRef<str>>(fn_name: S, args: &[TVMArgValue]) -> Result<TVMRetValue> {
let mut ret_val = TVMValue { v_int64: 0 };
let ret_type_code = 0i64;
unsafe {
tvm_ocall!(tvm_ocall_packed_func(
CString::new(fn_name.as_ref()).unwrap().as_ptr(),
args
.iter()
.map(|ref arg| arg.value)
.collect::<Vec<TVMValue>>()
.as_ptr(),
args
.iter()
.map(|ref arg| arg.type_code as i32)
.collect::<Vec<i32>>()
.as_ptr() as *const i32,
args.len() as i32,
&mut ret_val as *mut TVMValue,
&mut (ret_type_code as i32) as *mut c_int,
))?;
}
Ok(TVMRetValue::from_tvm_value(ret_val, ret_type_code as i64))
}
#[macro_export]
macro_rules! ocall_packed {
  // Calls a named packed function outside the enclave with the given
  // arguments (each converted via `.into()`), panicking on ocall failure.
  ($fn_name:expr, $($args:expr),+) => {
    ::runtime::sgx::ocall_packed_func($fn_name, &[$($args.into(),)+])
      .expect(concat!("Error calling `", $fn_name, "`"))
  };
  // Zero-argument variant.
  ($fn_name:expr) => {
    ::runtime::sgx::ocall_packed_func($fn_name, &Vec::new())
      .expect(concat!("Error calling `", $fn_name, "`"))
  }
}
/// Joins the worker thread group unless the enclave was built with zero
/// threads. `TVM_NUM_THREADS` is baked in at compile time via `env!`.
pub fn shutdown() {
  match env!("TVM_NUM_THREADS") {
    "0" => {}
    _ => sgx_join_threads(),
  }
}
impl Drop for SystemLibModule {
  // Joins worker threads (via `shutdown`) when the module goes away.
  fn drop(&mut self) {
    shutdown()
  }
}
use std::{
os::raw::{c_int, c_void},
sync::{
atomic::{AtomicUsize, Ordering, ATOMIC_USIZE_INIT},
Arc, Barrier,
},
};
#[cfg(not(target_env = "sgx"))]
use num_cpus;
#[cfg(not(target_env = "sgx"))]
use std::{
env,
thread::{self, JoinHandle},
};
#[cfg(target_env = "sgx")]
use std::{collections::VecDeque, ptr, sync::Mutex};
use bounded_spsc_queue::{self, Producer};
use super::super::errors::*;
use ffi::runtime::TVMParallelGroupEnv;
#[cfg(target_env = "sgx")]
use super::{TVMArgValue, TVMRetValue};
/// Signature of the per-task lambda that TVM-generated code hands to
/// `TVMBackendParallelLaunch`.
type FTVMParallelLambda =
  extern "C" fn(task_id: usize, penv: *const TVMParallelGroupEnv, cdata: *const c_void) -> i32;
/// Holds a parallel job request made by a TVM library function.
struct Job {
  // Lambda invoked once per task.
  cb: FTVMParallelLambda,
  // Opaque closure data forwarded to every invocation of `cb`.
  cdata: *const c_void,
  // Number of tasks requested; 0 means "one per available worker".
  req_num_tasks: usize,
  // Count of tasks that have not yet finished (see `Job::wait`).
  pending: Arc<AtomicUsize>,
}
impl Job {
  /// Splits this job into a number of `Task`s which can be scheduled.
  ///
  /// NOTE(review): `sync_handle` stores the address of the temporary produced
  /// by `Arc::clone(&barrier)`, and `barrier` itself is a local of this
  /// function — both appear to be dropped before the tasks run, so the
  /// pointer later dereferenced in `TVMBackendParallelBarrier` looks like it
  /// dangles. TODO confirm and restructure so a task owns its barrier.
  fn tasks(&self, num_workers: usize) -> Vec<Task> {
    // A request of 0 tasks means "use every worker"; otherwise cap the
    // request at the available worker count.
    let num_tasks = if self.req_num_tasks == 0 {
      num_workers
    } else {
      self.req_num_tasks.min(num_workers)
    };
    self.pending.store(num_tasks, Ordering::SeqCst);
    let barrier = Arc::new(Barrier::new(num_tasks));
    (0..num_tasks)
      .map(move |i| Task {
        id: i,
        flambda: self.cb,
        penv: TVMParallelGroupEnv {
          sync_handle: &Arc::clone(&barrier) as *const _ as *mut c_void,
          num_task: num_tasks as i32,
        },
        cdata: self.cdata,
        pending: Arc::clone(&self.pending),
      }).collect()
  }
  /// Waits for all tasks in this `Job` to be completed.
  fn wait(&self) -> Result<()> {
    // Spin until every task has decremented `pending` (see `Task::call_once`).
    while self.pending.load(Ordering::Acquire) > 0 {
      #[cfg(not(target_env = "sgx"))]
      thread::yield_now();
    }
    Ok(())
  }
}
/// A chunk of work requested by a TVM function.
struct Task {
  id: usize,
  flambda: FTVMParallelLambda,
  penv: TVMParallelGroupEnv,
  cdata: *const c_void,
  // Shared countdown of unfinished tasks in the parent `Job`.
  pending: Arc<AtomicUsize>,
}
// SAFETY(review): `cdata` and `penv.sync_handle` are raw pointers shared
// across threads; these impls assert the TVM runtime keeps the pointees alive
// and accesses them in a thread-safe way — TODO confirm.
unsafe impl Send for Task {}
unsafe impl Sync for Task {}
impl FnOnce<()> for Task {
  type Output = i32;
  // Runs the task's lambda, then decrements the job's pending counter so
  // `Job::wait` can observe completion. Returns the lambda's status code.
  extern "rust-call" fn call_once(self, _args: ()) -> Self::Output {
    let status = (self.flambda)(self.id, &self.penv as *const _, self.cdata);
    self.pending.fetch_sub(1, Ordering::AcqRel);
    status
  }
}
/// Handles to the launched worker threads and the producer side of each
/// worker's task queue.
#[derive(Default)]
struct Threads {
  #[allow(unused)]
  #[cfg(not(target_env = "sgx"))]
  handles: Vec<JoinHandle<()>>,
  // One single-producer queue per worker, indexed by worker id.
  queues: Vec<Producer<Task>>,
}
// BUG FIX: the impl previously declared an unconstrained lifetime parameter
// (`impl<'a> Threads`), which is rejected by rustc (E0207) and was unused.
impl Threads {
  /// Spawns `num_threads` OS threads, each running `cb` on the consumer end
  /// of its own bounded SPSC queue, and keeps the producer handles.
  #[cfg(not(target_env = "sgx"))]
  fn launch<F: Sync + Send + FnOnce(Consumer<Task>) + 'static + Copy>(
    num_threads: usize,
    cb: F,
  ) -> Self {
    let (handles, queues) = (0..num_threads)
      .map(|_| {
        let (p, c) = bounded_spsc_queue::make(2);
        let handle = thread::spawn(move || cb(c.into()));
        (handle, p)
      }).unzip();
    Threads {
      handles: handles,
      queues: queues,
    }
  }

  /// SGX flavor: the enclave cannot spawn threads itself, so the consumer
  /// queues are parked in `SGX_QUEUES` and worker threads are requested from
  /// the untrusted runtime via ocall; each worker later claims a queue when
  /// it re-enters through `tvm_run_worker`.
  #[cfg(target_env = "sgx")]
  fn launch<F: Sync + Send + FnOnce(Consumer<Task>) + 'static + Copy>(
    num_threads: usize,
    _cb: F,
  ) -> Self {
    let mut consumer_queues = SGX_QUEUES.lock().unwrap();
    let queues = (0..num_threads)
      .map(|_| {
        let (p, c) = bounded_spsc_queue::make(2);
        consumer_queues.push_back(c.into());
        p
      }).collect();
    ocall_packed!("__sgx_thread_group_launch__", num_threads as u64);
    Threads { queues: queues }
  }
}
/// A fixed-size group of worker threads that tasks are fanned out to.
struct ThreadPool {
  num_workers: usize,
  #[allow(unused)]
  threads: Threads,
}
// One pool per calling thread, constructed lazily on first use.
thread_local!(static THREAD_POOL: ThreadPool = ThreadPool::new());
impl ThreadPool {
  fn new() -> Self {
    let num_workers = max_concurrency();
    ThreadPool {
      num_workers: num_workers,
      threads: Threads::launch(num_workers, ThreadPool::run_worker),
    }
  }
  /// Runs `job` across the worker threads plus the calling thread, blocking
  /// until every task has finished.
  fn launch(&self, job: Job) {
    // One task per worker, plus one extra for the calling thread.
    let mut tasks = job.tasks(self.num_workers + 1);
    for (i, task) in tasks.split_off(1).into_iter().enumerate() {
      self.threads.queues[i].push(task);
    }
    // Run the remaining task inline on the current thread.
    tasks.pop().unwrap()();
    job.wait().unwrap();
  }
  /// Worker loop: pops and runs tasks until one returns `i32::min_value()`
  /// (the shutdown "poison pill" — see `sgx_join_threads`).
  fn run_worker(queue: Consumer<Task>) {
    loop {
      let task = queue.pop();
      let result = task();
      if result == <i32>::min_value() {
        break;
      } else if result != 0 {
        panic!("Error running task.");
      }
    }
  }
}
// Send + Sync wrapper for bounded_spsc_queue::Consumer
struct Consumer<T> {
  consumer: bounded_spsc_queue::Consumer<T>,
}
impl<T> From<bounded_spsc_queue::Consumer<T>> for Consumer<T> {
  fn from(c: bounded_spsc_queue::Consumer<T>) -> Self {
    Consumer { consumer: c }
  }
}
impl<T> Consumer<T> {
  // Pops the next item; presumably blocks/spins when the queue is empty —
  // see the bounded_spsc_queue crate docs to confirm.
  fn pop(&self) -> T {
    self.consumer.pop()
  }
}
// SAFETY(review): asserts the consumer end may be moved to and shared with
// other threads; sound only while a single thread pops at a time — TODO
// confirm the `Sync` claim is actually required and upheld by the pool.
unsafe impl<T> Send for Consumer<T> {}
unsafe impl<T> Sync for Consumer<T> {}
#[cfg(target_env = "sgx")]
lazy_static! {
  /// Holds tasks for untrusted threads which re-enter the enclave to execute.
  /// Filled by `Threads::launch`; drained one queue per worker in
  /// `tvm_run_worker`.
  static ref SGX_QUEUES: Mutex<VecDeque<Consumer<Task>>> = Mutex::new(VecDeque::new());
}
/// Returns the worker-thread count: `TVM_NUM_THREADS` if set, else
/// `OMP_NUM_THREADS`, else the number of physical cores.
#[cfg(all(not(target_arch = "wasm32"), not(target_env = "sgx")))]
fn max_concurrency() -> usize {
  // `or_else` avoids eagerly reading the fallback env var when the first one
  // is present (the previous `.or(...)` evaluated both unconditionally).
  if let Ok(threads_str) = env::var("TVM_NUM_THREADS").or_else(|_| env::var("OMP_NUM_THREADS")) {
    if let Ok(threads) = threads_str.parse::<usize>() {
      return threads;
    }
  }
  num_cpus::get_physical()
}
// SGX flavor: the thread count is fixed at build time through the
// `TVM_NUM_THREADS` compile-time env var; unparsable values fall back to 1.
#[cfg(target_env = "sgx")]
fn max_concurrency() -> usize {
  usize::from_str_radix(env!("TVM_NUM_THREADS"), 10).unwrap_or(1)
}
#[cfg(target_arch = "wasm32")]
fn max_concurrency() -> usize {
  0 // wasm doesn't support threads yet
}
/// Entry point for an untrusted worker thread re-entering the enclave: claims
/// one parked consumer queue (if any remain) and runs the worker loop on it.
#[cfg(target_env = "sgx")]
pub fn tvm_run_worker(_args: &[TVMArgValue]) -> TVMRetValue {
  let q = {
    let mut qs = SGX_QUEUES.lock().unwrap();
    qs.pop_front()
    // `qs: MutexGuard` needs to be dropped here since `run_worker` won't return
  };
  if let Some(q) = q {
    ThreadPool::run_worker(q);
  }
  TVMRetValue::default()
}
/// TVM backend hook: runs `cb` as up to `num_task` parallel tasks
/// (0 requests one task per worker). When `max_concurrency()` is 0 (wasm, or
/// `TVM_NUM_THREADS=0`) the lambda runs once, serially, on the calling
/// thread. Always returns 0.
#[no_mangle]
pub extern "C" fn TVMBackendParallelLaunch(
  cb: FTVMParallelLambda,
  cdata: *const c_void,
  num_task: usize,
) -> c_int {
  if max_concurrency() == 0 {
    // Serial fallback: a single-task "group" with no synchronization handle.
    let penv = TVMParallelGroupEnv {
      sync_handle: 0 as *mut c_void,
      num_task: 1,
    };
    cb(0, &penv as *const _, cdata);
  } else {
    THREAD_POOL.with(|pool| {
      pool.launch(Job {
        cb: cb,
        cdata: cdata,
        req_num_tasks: num_task,
        pending: Arc::new(ATOMIC_USIZE_INIT),
      });
    });
  }
  // Idiomatic tail expression instead of the previous `return 0;`.
  0
}
/// Shuts the worker group down: schedules the poison-pill task on every
/// worker (whose `i32::min_value()` return breaks `run_worker`'s loop), then
/// asks the untrusted runtime to join the threads.
#[cfg(target_env = "sgx")]
pub(crate) fn sgx_join_threads() {
  // Sentinel lambda signalling workers to exit.
  extern "C" fn poison_pill(
    _task_id: usize,
    _penv: *const TVMParallelGroupEnv,
    _cdata: *const c_void,
  ) -> i32 {
    <i32>::min_value()
  }
  THREAD_POOL.with(|pool| {
    pool.launch(Job {
      cb: poison_pill,
      cdata: ptr::null(),
      req_num_tasks: 0,
      pending: Arc::new(ATOMIC_USIZE_INIT),
    });
  });
  ocall_packed!("__sgx_thread_group_join__", 0);
}
// @see https://github.com/dmlc/tvm/issues/988 for information on why this function is used.
/// TVM backend hook: blocks the calling task until every task in its group
/// has reached the barrier.
#[no_mangle]
pub extern "C" fn TVMBackendParallelBarrier(_task_id: usize, penv: *const TVMParallelGroupEnv) {
  // SAFETY(review): assumes `sync_handle` points at a live `Arc<Barrier>` set
  // up in `Job::tasks` — TODO confirm the pointee outlives the task; the
  // handle there is the address of a temporary.
  let barrier: &Arc<Barrier> = unsafe { &*((*penv).sync_handle as *const Arc<Barrier>) };
  barrier.wait();
}
#[cfg(test)]
mod tests {
  use std::{ptr, thread, time::Duration};
  use super::*;
  #[test]
  fn test_max_concurrency() {
    // TVM_NUM_THREADS takes precedence; OMP_NUM_THREADS is the fallback.
    env::set_var("TVM_NUM_THREADS", "42");
    env::set_var("OMP_NUM_THREADS", "24");
    assert_eq!(max_concurrency(), 42);
    env::remove_var("TVM_NUM_THREADS");
    assert_eq!(max_concurrency(), 24);
  }
  // Test lambda: bumps a counter and sums task ids through `cdata`. A null
  // `cdata` makes it a no-op (used below to warm the thread pool up).
  extern "C" fn flambda(
    task_id: usize,
    penv: *const TVMParallelGroupEnv,
    cdata: *const c_void,
  ) -> i32 {
    if cdata == ptr::null() {
      return 0;
    }
    unsafe {
      let &(ref counter, ref task_ids_sum) = &*(cdata as *const (AtomicUsize, AtomicUsize));
      // Sleep time scales with task id, staggering the tasks.
      thread::sleep(Duration::from_millis(50 * task_id as u64));
      counter.fetch_add(1, Ordering::SeqCst);
      task_ids_sum.fetch_add(task_id, Ordering::SeqCst);
      assert_eq!((*penv).num_task, 3);
    }
    0
  }
  #[test]
  fn test_parallel_launch() {
    // First launch with null cdata only spins the pool up.
    TVMBackendParallelLaunch(flambda, ptr::null(), 6);
    let counter = ATOMIC_USIZE_INIT;
    let task_ids_sum = ATOMIC_USIZE_INIT;
    let cdata = (counter, task_ids_sum);
    let num_tasks = 3;
    TVMBackendParallelLaunch(flambda, &cdata as *const _ as *const c_void, num_tasks);
    // Each of the 3 tasks ran exactly once, with distinct ids 0..3.
    assert_eq!(cdata.0.load(Ordering::SeqCst), num_tasks);
    assert_eq!(
      cdata.1.load(Ordering::SeqCst),
      (0..num_tasks).sum::<usize>()
    );
  }
}
use std::{
cell::RefCell,
os::raw::{c_int, c_void},
ptr,
};
use super::allocator::Allocation;
use errors::*;
const WS_ALIGN: usize = 64; // taken from `kTempAllocaAlignment` in `device_api.h`
/// A pool of reusable, aligned workspace allocations.
struct WorkspacePool {
  // Backing allocations; entries are never removed, only recycled.
  workspaces: Vec<Allocation>,
  // Indices into `workspaces` that are currently available for reuse.
  free: Vec<usize>,
  // Indices into `workspaces` that are currently handed out.
  in_use: Vec<usize>,
}
impl WorkspacePool {
  fn new() -> Self {
    WorkspacePool {
      workspaces: Vec::new(),
      free: Vec::new(),
      in_use: Vec::new(),
    }
  }
  /// Allocates a fresh `size`-byte workspace and marks it in-use.
  fn alloc_new(&mut self, size: usize) -> Result<*mut u8> {
    self.workspaces.push(Allocation::new(size, Some(WS_ALIGN))?);
    self.in_use.push(self.workspaces.len() - 1);
    Ok(self.workspaces[self.workspaces.len() - 1].as_mut_ptr())
  }
  /// Returns a workspace of at least `size` bytes, reusing the smallest
  /// sufficiently-large free workspace when one exists.
  fn alloc(&mut self, size: usize) -> Result<*mut u8> {
    if self.free.len() == 0 {
      return self.alloc_new(size);
    }
    let idx = self
      .free
      .iter()
      .fold(None, |cur_ws_idx: Option<usize>, &idx| {
        let ws_size = self.workspaces[idx].size();
        // BUG FIX: this was `if !ws_size >= size`, which bitwise-negates
        // `ws_size` instead of testing "too small", so free workspaces were
        // essentially never reused.
        if ws_size < size {
          return cur_ws_idx;
        }
        // Prefer the smallest workspace that still fits.
        cur_ws_idx.or(Some(idx)).and_then(|cur_idx| {
          let cur_size = self.workspaces[cur_idx].size();
          Some(match ws_size <= cur_size {
            true => idx,
            false => cur_idx,
          })
        })
      });
    match idx {
      Some(idx) => {
        // Stable replacement for the unstable `Vec::remove_item`.
        let pos = self
          .free
          .iter()
          .position(|&free_idx| free_idx == idx)
          .unwrap();
        self.free.remove(pos);
        self.in_use.push(idx);
        Ok(self.workspaces[idx].as_mut_ptr())
      }
      None => self.alloc_new(size),
    }
  }
  /// Returns the in-use workspace starting at `ptr` to the free list.
  ///
  /// # Errors
  /// Fails when `ptr` does not correspond to an in-use workspace.
  fn free(&mut self, ptr: *mut u8) -> Result<()> {
    let mut ws_idx = None;
    for i in 0..self.in_use.len() {
      let idx = self.in_use[i];
      if self.workspaces[idx].as_mut_ptr() == ptr {
        self.in_use.remove(i);
        ws_idx = Some(idx);
        break;
      }
    }
    let ws_idx = ws_idx.ok_or("Tried to free nonexistent workspace.")?;
    self.free.push(ws_idx);
    Ok(())
  }
}
// Each thread gets its own pool, so no cross-thread locking is needed.
thread_local!(static WORKSPACE_POOL: RefCell<WorkspacePool> = RefCell::new(WorkspacePool::new()));
// Default workspace size (4 KiB) when a zero-byte allocation is requested.
const WORKSPACE_PAGE_SIZE: usize = 4 << 10;
/// TVM backend hook: returns a pointer to a workspace of at least `size`
/// bytes (one page when `size` is 0), or null on allocation failure.
/// Workspaces come from the thread-local pool; the device and dtype hint
/// parameters are ignored by this CPU implementation.
#[no_mangle]
pub extern "C" fn TVMBackendAllocWorkspace(
  _device_type: c_int,
  _device_id: c_int,
  size: u64,
  _dtype_code_hint: c_int,
  _dtype_bits_hint: c_int,
) -> *mut c_void {
  let nbytes = if size == 0 {
    WORKSPACE_PAGE_SIZE
  } else {
    size as usize
  };
  WORKSPACE_POOL.with(|pool_cell| {
    pool_cell
      .borrow_mut()
      // `nbytes` is already a `usize`; the previous `as usize` was redundant.
      .alloc(nbytes)
      .unwrap_or(ptr::null_mut()) as *mut c_void
  })
}
/// TVM backend hook: releases a workspace previously handed out by
/// `TVMBackendAllocWorkspace` back to the thread-local pool.
///
/// Returns 0 on success and -1 when `ptr` is not an in-use workspace.
#[no_mangle]
pub extern "C" fn TVMBackendFreeWorkspace(
  _device_type: c_int,
  _device_id: c_int,
  ptr: *mut c_void,
) -> c_int {
  // BUG FIX: the 0/-1 status computed inside the closure was previously
  // discarded and the function unconditionally returned 0, hiding failures.
  WORKSPACE_POOL.with(|pool_cell| {
    (match pool_cell.borrow_mut().free(ptr as *mut u8) {
      Ok(()) => 0,
      Err(_) => -1,
    }) as c_int
  })
}
*.json
*.params
*.o
"""Builds a simple NNVM graph for testing."""
from os import path as osp
import nnvm
from nnvm import sym
from nnvm.compiler import graph_util
from nnvm.testing import init
import numpy as np
import tvm
CWD = osp.dirname(osp.abspath(osp.expanduser(__file__)))
def _get_model(dshape):
  """Builds the test network: dense -> split into halves -> (+1, -1) outputs."""
  data = sym.Variable('data', shape=dshape)
  hidden = sym.dense(data, units=dshape[-1] * 2, use_bias=True)
  lhs, rhs = sym.split(hidden, indices_or_sections=2, axis=1)
  return sym.Group(((lhs + 1), (rhs - 1)))
def _init_params(graph, input_shapes, initializer=init.Xavier(), seed=10):
  """Creates initialized parameter tensors for every learnable input of
  `graph`, skipping the 'data'/'label' placeholders and shapeless inputs."""
  if isinstance(graph, sym.Symbol):
    graph = nnvm.graph.create(graph)
  inferred_shapes, _ = graph_util.infer_shape(graph, **input_shapes)
  param_shapes = dict(zip(graph.index.input_names, inferred_shapes))
  np.random.seed(seed)
  params = {}
  for name, shape in param_shapes.items():
    if name in {'data', 'label'} or not shape:
      continue
    value = np.empty(shape).astype('float32')
    initializer(name, value)
    params[name] = tvm.nd.array(value)
  return params
def main():
  """Compiles the test model with NNVM and writes graph.json/graph.params
  next to this script."""
  dshape = (32, 16)
  net = _get_model(dshape)
  shape_dict = {'data': dshape}
  params = _init_params(net, shape_dict)
  graph, lib, params = nnvm.compiler.build(net, 'llvm',
                                           shape=shape_dict,
                                           params=params,
                                           dtype='float32')
  with open(osp.join(CWD, 'graph.json'), 'w') as f_resnet:
    f_resnet.write(graph.json())
  with open(osp.join(CWD, 'graph.params'), 'wb') as f_params:
    f_params.write(nnvm.compiler.save_param_dict(params))

if __name__ == '__main__':
  main()
#![feature(try_from)]
extern crate serde;
extern crate serde_json;
extern crate tvm;
use std::{convert::TryFrom, fs, io::Read};
use tvm::runtime::Graph;
#[test]
fn test_load_graph() {
  // Load the parameter dict produced by the model-building script.
  let mut params_bytes = Vec::new();
  fs::File::open(concat!(env!("CARGO_MANIFEST_DIR"), "/tests/graph.params"))
    .expect("Could not find TVM graph. Did you run `tests/build_model.py`?")
    .read_to_end(&mut params_bytes)
    .unwrap();
  let _params = tvm::runtime::load_param_dict(&params_bytes);

  // Parse the serialized graph JSON.
  let graph_json =
    fs::read_to_string(concat!(env!("CARGO_MANIFEST_DIR"), "/tests/graph.json")).unwrap();
  let graph = Graph::try_from(&graph_json).unwrap();

  // Node 3 is expected to be the fused dense op.
  let dense_node = &graph.nodes[3];
  assert_eq!(dense_node.op, "tvm_op");
  let func_name = dense_node.attrs.as_ref().unwrap().get("func_name").unwrap();
  assert_eq!(func_name, "fuse_dense");
  // The two consumer nodes should read distinct output indices, and the
  // graph should expose exactly two heads.
  assert_eq!(graph.nodes[5].inputs[0].index, 0);
  assert_eq!(graph.nodes[6].inputs[0].index, 1);
  assert_eq!(graph.heads.len(), 2);
}
[package]
name = "test-nnvm"
version = "0.0.0"
license = "Apache-2.0"
authors = ["Nick Hynes <nhynes@berkeley.edu>"]
[dependencies]
ndarray = "0.11.2"
tvm = { path = "../../" }
serde = "1.0.59"
serde_json = "1.0.17"
[build-dependencies]
ar = "0.6.0"
extern crate ar;
use std::{env, path::PathBuf, process::Command};
use ar::Builder;
use std::fs::File;
/// Build script: generates the test graph artifacts with
/// `src/build_test_graph.py`, archives `graph.o` into `libgraph.a`, and
/// emits the link directives for the static library.
fn main() {
  let out_dir = env::var("OUT_DIR").unwrap();

  let output = Command::new(concat!(
    env!("CARGO_MANIFEST_DIR"),
    "/src/build_test_graph.py"
  )).arg(&out_dir)
    .output()
    .expect("Failed to execute command");
  // BUG FIX: also check the exit status; previously only stderr was
  // inspected, so a failing script that printed nothing went undetected.
  if !output.status.success() || output.stderr.len() > 0 {
    panic!(String::from_utf8(output.stderr).unwrap());
  }

  let in_path: PathBuf = [&out_dir, "graph.o"].iter().collect();
  let out_path: PathBuf = [&out_dir, "libgraph.a"].iter().collect();
  let mut builder = Builder::new(File::create(out_path.to_str().unwrap()).unwrap());
  builder.append_path(in_path.to_str().unwrap()).unwrap();

  println!("cargo:rustc-link-lib=static=graph");
  println!("cargo:rustc-link-search=native={}", out_dir);
}
#!/usr/bin/env python3
"""Builds a simple NNVM graph for testing."""
from os import path as osp
import sys
import nnvm
from nnvm import sym
from nnvm.compiler import graph_util
from nnvm.testing import init
import numpy as np
import tvm
def _get_model(dshape):
  """Builds the test network: dense -> split -> (+1, -1) branches, also
  exposing the raw dense output as a third head."""
  data = sym.Variable('data', shape=dshape)
  dense = sym.dense(data, units=dshape[-1] * 2, use_bias=True)
  lhs, rhs = sym.split(dense, indices_or_sections=2, axis=1)
  return sym.Group(((lhs + 1), (rhs - 1), dense))
def _init_params(graph, input_shapes, initializer=init.Xavier(), seed=10):
  """Creates parameter tensors for `graph`: '_bias' params get a
  deterministic descending ramp; all others are drawn from `initializer`."""
  if isinstance(graph, sym.Symbol):
    graph = nnvm.graph.create(graph)
  inferred_shapes, _ = graph_util.infer_shape(graph, **input_shapes)
  param_shapes = dict(zip(graph.index.input_names, inferred_shapes))
  np.random.seed(seed)
  params = {}
  for name, shape in param_shapes.items():
    if name in {'data', 'label'} or not shape:
      continue
    if name.endswith('_bias'):
      ramp = np.arange(np.product(shape), 0, -1).reshape(*shape).astype('float32')
      params[name] = tvm.nd.array(ramp)
      continue
    value = np.empty(shape).astype('float32')
    initializer(name, value)
    params[name] = tvm.nd.array(value)
  return params
def main():
  """Compiles the system-lib test model and writes graph.o/.json/.params
  into the directory given as argv[1]."""
  dshape = (4, 8)
  net = _get_model(dshape)
  shape_dict = {'data': dshape}
  params = _init_params(net, shape_dict)
  graph, lib, params = nnvm.compiler.build(net, 'llvm --system-lib',
                                           shape=shape_dict,
                                           params=params,
                                           dtype='float32')
  out_dir = sys.argv[1]
  lib.save(osp.join(out_dir, 'graph.o'))
  with open(osp.join(out_dir, 'graph.json'), 'w') as f_resnet:
    f_resnet.write(graph.json())
  with open(osp.join(out_dir, 'graph.params'), 'wb') as f_params:
    f_params.write(nnvm.compiler.save_param_dict(params))

if __name__ == '__main__':
  main()
#![feature(try_from)]
#[macro_use]
extern crate ndarray;
extern crate serde;
extern crate serde_json;
extern crate tvm;
use std::{collections::HashMap, convert::TryFrom, fs, io::Read};
use ndarray::Array;
use tvm::runtime::{Graph, GraphExecutor, SystemLibModule, Tensor};
const BATCH_SIZE: usize = 4;
const IN_DIM: usize = 8;
/// Asserts that a tensor fetched from the executor approximately matches a
/// host-computed ndarray, compared by total sum with tolerance 1e-2.
macro_rules! check_sum {
  // Compare the named graph input against expected array `$b`.
  ($e:expr, $a:ident, $b:ident) => {
    let a = Array::try_from($e.get_input(stringify!($a)).unwrap()).unwrap();
    check_sum!(a, $b);
  };
  // Compare the output at index `$a` against expected array `$b`.
  ($e:expr, $a:expr, $b:ident) => {
    let a = Array::try_from($e.get_output($a).unwrap()).unwrap();
    check_sum!(a, $b);
  };
  // Base case: compare the scalar sums of the two arrays.
  ($a:ident, $b:ident) => {
    let a_sum: f32 = $a.scalar_sum();
    let b_sum: f32 = $b.scalar_sum();
    assert!((a_sum - b_sum).abs() < 1e-2, "{} != {}", a_sum, b_sum);
  };
}
/// End-to-end check: runs the compiled NNVM graph via the Rust runtime and
/// compares inputs/outputs against the same computation done with ndarray.
fn main() {
  let syslib = SystemLibModule::default();
  // Load the serialized parameters produced by the build script.
  let mut params_bytes = Vec::new();
  fs::File::open(concat!(env!("OUT_DIR"), "/graph.params"))
    .unwrap()
    .read_to_end(&mut params_bytes)
    .unwrap();
  let params = tvm::runtime::load_param_dict(&params_bytes)
    .unwrap()
    .into_iter()
    .map(|(k, v)| (k, v.to_owned()))
    .collect::<HashMap<String, Tensor<'static>>>();
  // Deserialize the graph and build an executor over the system library.
  let graph =
    Graph::try_from(&fs::read_to_string(concat!(env!("OUT_DIR"), "/graph.json")).unwrap()).unwrap();
  let mut exec = GraphExecutor::new(graph, &syslib).unwrap();
  // Input batch: 0..BATCH_SIZE*IN_DIM as f32, shaped (BATCH_SIZE, IN_DIM).
  let x = Array::from_shape_vec(
    (BATCH_SIZE, IN_DIM),
    (0..BATCH_SIZE * IN_DIM)
      .map(|x| x as f32)
      .collect::<Vec<f32>>(),
  ).unwrap();
  // Recompute the dense layer on the host: x · Wᵀ + b.
  let w = Array::try_from(params.get("dense0_weight").unwrap())
    .unwrap()
    .into_shape((IN_DIM * 2, IN_DIM))
    .unwrap();
  let b = Array::try_from(params.get("dense0_bias").unwrap()).unwrap();
  let dense = x.dot(&w.t()) + &b;
  // The graph splits the dense output in half and applies +1 / -1.
  let left = dense.slice(s![.., 0..IN_DIM]);
  let right = dense.slice(s![.., IN_DIM..]);
  let expected_o0 = &left + 1f32;
  let expected_o1 = &right - 1f32;
  exec.load_params(params);
  exec.set_input("data", x.clone().into());
  // Verify the executor stored the inputs/params we loaded...
  check_sum!(exec, data, x);
  check_sum!(exec, dense0_weight, w);
  check_sum!(exec, dense0_bias, b);
  exec.run();
  // ...and that its three outputs match the host computation.
  check_sum!(exec, 0, expected_o0);
  check_sum!(exec, 1, expected_o1);
  check_sum!(exec, 2, dense);
}
[package]
name = "test-tvm-basic"
version = "0.0.0"
license = "Apache-2.0"
authors = ["Nick Hynes <nhynes@berkeley.edu>"]
[dependencies]
ndarray = "0.11.2"
tvm = { path = "../../" }
[build-dependencies]
ar = "0.6.0"
extern crate ar;
use std::{env, path::PathBuf, process::Command};
use ar::Builder;
use std::fs::File;
/// Build script: compiles the test kernel with `src/build_test_lib.py`,
/// archives `test.o` into `libtest.a`, and emits the static-link directives.
fn main() {
  let out_dir = env::var("OUT_DIR").unwrap();

  let output = Command::new(concat!(
    env!("CARGO_MANIFEST_DIR"),
    "/src/build_test_lib.py"
  )).arg(&out_dir)
    .output()
    .expect("Failed to execute command");
  // BUG FIX: also check the exit status; previously only stderr was
  // inspected, so a failing script that printed nothing went undetected.
  if !output.status.success() || output.stderr.len() > 0 {
    panic!(String::from_utf8(output.stderr).unwrap());
  }

  let in_path: PathBuf = [&out_dir, "test.o"].iter().collect();
  let out_path: PathBuf = [&out_dir, "libtest.a"].iter().collect();
  let mut builder = Builder::new(File::create(out_path.to_str().unwrap()).unwrap());
  builder.append_path(in_path.to_str().unwrap()).unwrap();

  println!("cargo:rustc-link-lib=static=test");
  println!("cargo:rustc-link-search=native={}", out_dir);
}
#!/usr/bin/env python3
"""Prepares a simple TVM library for testing."""
from os import path as osp
import sys
import tvm
def main():
  """Builds a parallelized vector-add kernel and saves it as a system-lib
  object file at argv[1]/test.o."""
  n = tvm.var('n')
  A = tvm.placeholder((n,), name='A')
  B = tvm.placeholder((n,), name='B')
  C = tvm.compute(A.shape, lambda *i: A(*i) + B(*i), name='C')
  sched = tvm.create_schedule(C.op)
  sched[C].parallel(sched[C].op.axis[0])
  print(tvm.lower(sched, [A, B, C], simple_mode=True))
  module = tvm.build(sched, [A, B, C], 'llvm --system-lib')
  module.save(osp.join(sys.argv[1], 'test.o'))

if __name__ == '__main__':
  main()
extern crate ndarray;
#[macro_use]
extern crate tvm;
use ndarray::Array;
use tvm::{
ffi::runtime::DLTensor,
runtime::{Module, SystemLibModule},
};
/// Smoke test: calls the statically linked TVM vector-add kernel and checks
/// the result against a host-computed expectation.
fn main() {
  let syslib = SystemLibModule::default();
  let add = syslib
    .get_function("default_function")
    .expect("main function not found");

  let mut lhs = Array::from_vec(vec![1f32, 2., 3., 4.]);
  let mut rhs = Array::from_vec(vec![1f32, 0., 1., 0.]);
  let mut out = Array::from_vec(vec![0f32; 4]);
  let expected = Array::from_vec(vec![2f32, 2., 4., 4.]);

  // Wrap the ndarrays as borrowed DLTensors and invoke the packed function.
  let mut lhs_dl: DLTensor = (&mut lhs).into();
  let mut rhs_dl: DLTensor = (&mut rhs).into();
  let mut out_dl: DLTensor = (&mut out).into();
  call_packed!(add, &mut lhs_dl, &mut rhs_dl, &mut out_dl);

  assert!(out.all_close(&expected, 1e-8f32));
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment