Commit 4968279f by Nick Hynes Committed by Tianqi Chen

[Rust] Unify types between bindings and pure Rust impl (#2616)

parent 71abe36e
Cargo.lock
target/
**/*.rs.bk
*.rs.bk
Cargo.lock
c_runtime_api.rs
target
**/*.rs.bk
Cargo.lock
/tvm-sys/src/bindgen.rs
......@@ -5,9 +5,11 @@ authors = ["TVM Contributors"]
license = "Apache-2.0"
[features]
runtime = []
frontend = ["tvm-sys"]
bindings = []
[dependencies]
error-chain = { version = "0.12.0", default-features = false }
tvm-sys = { version = "0.1.0", path = "tvm-sys", optional = true }
failure = "0.1.5"
ndarray = "0.12.1"
[build-dependencies]
bindgen = "0.37.4"
......@@ -3,23 +3,29 @@ extern crate bindgen;
use std::path::PathBuf;
fn main() {
println!("cargo:rerun-if-env-changed=TVM_HOME");
println!("cargo:rustc-link-lib=dylib=tvm_runtime");
println!("cargo:rustc-link-search={}/build", env!("TVM_HOME"));
let bindings = bindgen::Builder::default()
if cfg!(feature = "bindings") {
println!("cargo:rerun-if-env-changed=TVM_HOME");
println!("cargo:rustc-link-lib=dylib=tvm_runtime");
println!("cargo:rustc-link-search={}/build", env!("TVM_HOME"));
}
// @see rust-bindgen#550 for `blacklist_type`
bindgen::Builder::default()
.header(format!(
"{}/include/tvm/runtime/c_runtime_api.h",
env!("TVM_HOME")
))
.header(format!(
"{}/include/tvm/runtime/c_backend_api.h",
env!("TVM_HOME")
))
.clang_arg(format!("-I{}/3rdparty/dlpack/include/", env!("TVM_HOME")))
.blacklist_type("max_align_t") // @see rust-bindgen#550
.blacklist_type("max_align_t")
.layout_tests(false)
.derive_partialeq(true)
.derive_eq(true)
.generate()
.expect("unable to generate bindings");
bindings
.write_to_file(PathBuf::from("src/bindgen.rs"))
.expect("unable to generate bindings")
.write_to_file(PathBuf::from("src/c_runtime_api.rs"))
.expect("can not write the bindings!");
}
use std::{
any::TypeId,
mem,
os::raw::{c_int, c_void},
};
use crate::ffi::{
DLContext, DLDataType, DLDataTypeCode_kDLFloat, DLDataTypeCode_kDLInt, DLDataTypeCode_kDLUInt,
DLDeviceType_kDLCPU, DLTensor,
};
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct DataType {
pub code: usize,
pub bits: usize,
pub lanes: usize,
}
impl DataType {
/// Returns the number of bytes occupied by an element of this `DataType`.
pub fn itemsize(&self) -> usize {
(self.bits * self.lanes) >> 3
}
/// Returns whether this `DataType` represents primitive type `T`.
pub fn is_type<T: 'static>(&self) -> bool {
if self.lanes != 1 {
return false;
}
let typ = TypeId::of::<T>();
(typ == TypeId::of::<i32>() && self.code == 0 && self.bits == 32)
|| (typ == TypeId::of::<i64>() && self.code == 0 && self.bits == 64)
|| (typ == TypeId::of::<u32>() && self.code == 1 && self.bits == 32)
|| (typ == TypeId::of::<u64>() && self.code == 1 && self.bits == 64)
|| (typ == TypeId::of::<f32>() && self.code == 2 && self.bits == 32)
|| (typ == TypeId::of::<f64>() && self.code == 2 && self.bits == 64)
}
pub fn code(&self) -> usize {
self.code
}
pub fn bits(&self) -> usize {
self.bits
}
pub fn lanes(&self) -> usize {
self.lanes
}
}
impl<'a> From<&'a DataType> for DLDataType {
fn from(dtype: &'a DataType) -> Self {
Self {
code: dtype.code as u8,
bits: dtype.bits as u8,
lanes: dtype.lanes as u16,
}
}
}
impl From<DLDataType> for DataType {
fn from(dtype: DLDataType) -> Self {
Self {
code: dtype.code as usize,
bits: dtype.bits as usize,
lanes: dtype.lanes as usize,
}
}
}
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct TVMContext {
pub device_type: usize,
pub device_id: usize,
}
impl<'a> From<&'a TVMContext> for DLContext {
fn from(ctx: &'a TVMContext) -> Self {
Self {
device_type: ctx.device_type as u32,
device_id: ctx.device_id as i32,
}
}
}
impl Default for TVMContext {
fn default() -> Self {
Self {
device_type: DLDeviceType_kDLCPU as usize,
device_id: 0,
}
}
}
/// `From` conversions to `DLTensor` for `ndarray::Array`.
/// Takes a reference to the `ndarray` since `DLTensor` is not owned.
macro_rules! impl_dltensor_from_ndarray {
($type:ty, $typecode:expr) => {
impl<'a, D: ndarray::Dimension> From<&'a mut ndarray::Array<$type, D>> for DLTensor {
fn from(arr: &'a mut ndarray::Array<$type, D>) -> Self {
DLTensor {
data: arr.as_mut_ptr() as *mut c_void,
ctx: DLContext {
device_type: DLDeviceType_kDLCPU,
device_id: 0,
},
ndim: arr.ndim() as c_int,
dtype: DLDataType {
code: $typecode as u8,
bits: 8 * mem::size_of::<$type>() as u8,
lanes: 1,
},
shape: arr.shape().as_ptr() as *const i64 as *mut i64,
strides: arr.strides().as_ptr() as *const isize as *mut i64,
byte_offset: 0,
}
}
}
};
}
impl_dltensor_from_ndarray!(f32, DLDataTypeCode_kDLFloat);
impl_dltensor_from_ndarray!(f64, DLDataTypeCode_kDLFloat);
impl_dltensor_from_ndarray!(i32, DLDataTypeCode_kDLInt);
impl_dltensor_from_ndarray!(i64, DLDataTypeCode_kDLInt);
impl_dltensor_from_ndarray!(u32, DLDataTypeCode_kDLUInt);
impl_dltensor_from_ndarray!(u64, DLDataTypeCode_kDLUInt);
//! Error types for `TVMArgValue` and `TVMRetValue` conversions.
use std::fmt;
error_chain! {
errors {
TryFromTVMArgValueError(expected: String, actual: String) {
description("mismatched types while converting from TVMArgValue")
display("expected `{}` but given `{}`", expected, actual)
static TYPE_CODE_STRS: [&str; 15] = [
"int",
"uint",
"float",
"handle",
"null",
"TVMType",
"TVMContext",
"ArrayHandle",
"NodeHandle",
"ModuleHandle",
"FuncHandle",
"str",
"bytes",
"NDArrayContainer",
"ExtBegin",
];
#[derive(Debug, Fail)]
pub struct ValueDowncastError {
actual_type_code: i64,
expected_type_code: i64,
}
impl ValueDowncastError {
pub fn new(actual_type_code: i64, expected_type_code: i64) -> Self {
Self {
actual_type_code,
expected_type_code,
}
}
}
TryFromTVMRetValueError(expected: String, actual: String) {
description("mismatched types while downcasting TVMRetValue")
display("invalid downcast: expected `{}` but given `{}`", expected, actual)
impl fmt::Display for ValueDowncastError {
fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
write!(
formatter,
"Could not downcast TVMValue: expected `{}` but was {}",
TYPE_CODE_STRS[self.actual_type_code as usize],
TYPE_CODE_STRS[self.expected_type_code as usize]
)
}
}
#[derive(Debug, Fail)]
#[fail(display = "Function call `{}` returned error: {}", context, message)]
pub struct FuncCallError {
context: String,
message: String,
}
impl FuncCallError {
pub fn get_with_context(context: String) -> Self {
Self {
context,
message: unsafe { std::ffi::CStr::from_ptr(crate::ffi::TVMGetLastError()) }
.to_str()
.expect("double fault")
.to_owned(),
}
}
}
// error_chain! {
// errors {
// TryFromTVMRetValueError(expected_type: String, actual_type_code: i64) {
// description("mismatched types while downcasting TVMRetValue")
// display("invalid downcast: expected `{}` but was `{}`",
// expected_type, type_code_to_string(actual_type_code))
// }
// }
// foreign_links {
// IntoString(std::ffi::IntoStringError);
// ParseInt(std::num::ParseIntError);
// Utf8(std::str::Utf8Error);
// }
// }
//! This crate contains the refactored basic components required
//! for `runtime` and `frontend` TVM crates.
#![crate_name = "tvm_common"]
#![recursion_limit = "1024"]
#![allow(non_camel_case_types, unused_imports)]
#![feature(box_syntax, try_from)]
#![feature(box_syntax, trait_alias)]
#[macro_use]
extern crate error_chain;
extern crate failure;
/// Unified ffi module for both runtime and frontend crates.
pub mod ffi {
#![allow(non_camel_case_types, non_snake_case, non_upper_case_globals, unused)]
#[cfg(feature = "frontend")]
pub extern crate tvm_sys as ts;
use std::os::raw::{c_char, c_int, c_void};
#[cfg(feature = "runtime")]
pub mod runtime {
use std::os::raw::{c_char, c_int, c_void};
include!(concat!(env!("CARGO_MANIFEST_DIR"), "/src/c_runtime_api.rs"));
include!(concat!(env!("CARGO_MANIFEST_DIR"), "/src/c_runtime_api.rs"));
pub type BackendPackedCFunc = extern "C" fn(
args: *const TVMValue,
type_codes: *const c_int,
num_args: c_int,
) -> c_int;
}
pub type BackendPackedCFunc =
extern "C" fn(args: *const TVMValue, type_codes: *const c_int, num_args: c_int) -> c_int;
}
pub mod array;
pub mod errors;
pub mod ty;
#[macro_use]
pub mod packed_func;
pub mod value;
pub use errors::*;
pub use ty::TVMTypeCode;
pub use value::{TVMArgValue, TVMRetValue, TVMValue};
pub use ffi::{TVMContext, TVMType};
pub use packed_func::{TVMArgValue, TVMRetValue};
use std::{any::Any, convert::TryFrom, marker::PhantomData, os::raw::c_void};
use failure::Error;
pub use crate::ffi::TVMValue;
use crate::ffi::*;
pub trait PackedFunc =
Fn(&[TVMArgValue]) -> Result<TVMRetValue, crate::errors::FuncCallError> + Send + Sync;
/// Calls a packed function and returns a `TVMRetValue`.
///
/// # Example
///
/// `call_packed!(my_tvm_func, &mut arg1, &mut arg2)`
#[macro_export]
macro_rules! call_packed {
($fn:expr, $($args:expr),+) => {
$fn(&[$($args.into(),)+])
};
($fn:expr) => {
$fn(&Vec::new())
};
}
/// A borrowed TVMPODValue. Can be constructed using `into()` but the preferred way
/// to obtain a `TVMArgValue` is automatically via `call_packed!`.
#[derive(Clone, Copy)]
pub struct TVMArgValue<'a> {
pub _lifetime: PhantomData<&'a ()>,
pub value: TVMValue,
pub type_code: i64,
}
impl<'a> TVMArgValue<'a> {
pub fn new(value: TVMValue, type_code: i64) -> Self {
TVMArgValue {
_lifetime: PhantomData,
value: value,
type_code: type_code,
}
}
}
#[macro_export]
macro_rules! ensure_type {
($val:ident, $expected_type_code:expr) => {
ensure!(
$val.type_code == $expected_type_code as i64,
$crate::errors::ValueDowncastError::new(
$val.type_code as i64,
$expected_type_code as i64
)
);
};
}
/// Creates a conversion to a `TVMArgValue` for a primitive type and DLDataTypeCode.
macro_rules! impl_prim_tvm_arg {
($type_code:ident, $field:ident, $field_type:ty, [ $( $type:ty ),+ ] ) => {
$(
impl From<$type> for TVMArgValue<'static> {
fn from(val: $type) -> Self {
TVMArgValue {
value: TVMValue { $field: val as $field_type },
type_code: $type_code as i64,
_lifetime: PhantomData,
}
}
}
impl<'a> From<&'a $type> for TVMArgValue<'a> {
fn from(val: &'a $type) -> Self {
TVMArgValue {
value: TVMValue {
$field: val.to_owned() as $field_type,
},
type_code: $type_code as i64,
_lifetime: PhantomData,
}
}
}
impl<'a> TryFrom<TVMArgValue<'a>> for $type {
type Error = Error;
fn try_from(val: TVMArgValue<'a>) -> Result<Self, Self::Error> {
ensure_type!(val, $type_code);
Ok(unsafe { val.value.$field as $type })
}
}
impl<'a> TryFrom<&TVMArgValue<'a>> for $type {
type Error = Error;
fn try_from(val: &TVMArgValue<'a>) -> Result<Self, Self::Error> {
ensure_type!(val, $type_code);
Ok(unsafe { val.value.$field as $type })
}
}
)+
};
}
impl_prim_tvm_arg!(DLDataTypeCode_kDLFloat, v_float64, f64, [f32, f64]);
impl_prim_tvm_arg!(
DLDataTypeCode_kDLInt,
v_int64,
i64,
[i8, i16, i32, i64, isize]
);
impl_prim_tvm_arg!(
DLDataTypeCode_kDLUInt,
v_int64,
i64,
[u8, u16, u32, u64, usize]
);
#[cfg(feature = "bindings")]
// only allow this in bindings because pure-rust can't take ownership of leaked CString
impl<'a> From<&String> for TVMArgValue<'a> {
fn from(string: &String) -> Self {
TVMArgValue {
value: TVMValue {
v_str: std::ffi::CString::new(string.clone()).unwrap().into_raw(),
},
type_code: TVMTypeCode_kStr as i64,
_lifetime: PhantomData,
}
}
}
impl<'a> From<&std::ffi::CString> for TVMArgValue<'a> {
fn from(string: &std::ffi::CString) -> Self {
TVMArgValue {
value: TVMValue {
v_str: string.as_ptr(),
},
type_code: TVMTypeCode_kStr as i64,
_lifetime: PhantomData,
}
}
}
impl<'a> TryFrom<TVMArgValue<'a>> for &str {
type Error = Error;
fn try_from(arg: TVMArgValue<'a>) -> Result<Self, Self::Error> {
ensure_type!(arg, TVMTypeCode_kStr);
Ok(unsafe { std::ffi::CStr::from_ptr(arg.value.v_handle as *const i8) }.to_str()?)
}
}
impl<'a> TryFrom<&TVMArgValue<'a>> for &str {
type Error = Error;
fn try_from(arg: &TVMArgValue<'a>) -> Result<Self, Self::Error> {
ensure_type!(arg, TVMTypeCode_kStr);
Ok(unsafe { std::ffi::CStr::from_ptr(arg.value.v_handle as *const i8) }.to_str()?)
}
}
/// Creates a conversion to a `TVMArgValue` for an object handle.
impl<'a, T> From<*const T> for TVMArgValue<'a> {
fn from(ptr: *const T) -> Self {
TVMArgValue {
value: TVMValue {
v_handle: ptr as *mut T as *mut c_void,
},
type_code: TVMTypeCode_kArrayHandle as i64,
_lifetime: PhantomData,
}
}
}
/// Creates a conversion to a `TVMArgValue` for a mutable object handle.
impl<'a, T> From<*mut T> for TVMArgValue<'a> {
fn from(ptr: *mut T) -> Self {
TVMArgValue {
value: TVMValue {
v_handle: ptr as *mut c_void,
},
type_code: TVMTypeCode_kHandle as i64,
_lifetime: PhantomData,
}
}
}
impl<'a> From<&'a mut DLTensor> for TVMArgValue<'a> {
fn from(arr: &'a mut DLTensor) -> Self {
TVMArgValue {
value: TVMValue {
v_handle: arr as *mut _ as *mut c_void,
},
type_code: TVMTypeCode_kArrayHandle as i64,
_lifetime: PhantomData,
}
}
}
impl<'a> From<&'a DLTensor> for TVMArgValue<'a> {
fn from(arr: &'a DLTensor) -> Self {
TVMArgValue {
value: TVMValue {
v_handle: arr as *const _ as *mut DLTensor as *mut c_void,
},
type_code: TVMTypeCode_kArrayHandle as i64,
_lifetime: PhantomData,
}
}
}
impl<'a, 'v> TryFrom<&'a TVMArgValue<'v>> for TVMType {
type Error = Error;
fn try_from(arg: &'a TVMArgValue<'v>) -> Result<Self, Self::Error> {
ensure_type!(arg, TVMTypeCode_kTVMType);
Ok(unsafe { arg.value.v_type.into() })
}
}
/// An owned TVMPODValue. Can be converted from a variety of primitive and object types.
/// Can be downcasted using `try_from` if it contains the desired type.
///
/// # Example
///
/// ```
/// let a = 42u32;
/// let b: i64 = TVMRetValue::from(a).try_into().unwrap();
///
/// let s = "hello, world!";
/// let t: TVMRetValue = s.into();
/// assert_eq!(String::try_from(t).unwrap(), s);
/// ```
pub struct TVMRetValue {
pub value: TVMValue,
pub box_value: Box<Any>,
pub type_code: i64,
}
impl TVMRetValue {
pub fn from_tvm_value(value: TVMValue, type_code: i64) -> Self {
Self {
value,
type_code,
box_value: box (),
}
}
pub fn into_tvm_value(self) -> (TVMValue, TVMTypeCode) {
(self.value, self.type_code as TVMTypeCode)
}
}
impl Default for TVMRetValue {
fn default() -> Self {
TVMRetValue {
value: TVMValue { v_int64: 0 as i64 },
type_code: 0,
box_value: box (),
}
}
}
macro_rules! impl_pod_ret_value {
($code:expr, [ $( $ty:ty ),+ ] ) => {
$(
impl From<$ty> for TVMRetValue {
fn from(val: $ty) -> Self {
Self {
value: val.into(),
type_code: $code as i64,
box_value: box (),
}
}
}
impl TryFrom<TVMRetValue> for $ty {
type Error = Error;
fn try_from(ret: TVMRetValue) -> Result<$ty, Self::Error> {
ensure_type!(ret, $code);
Ok(ret.value.into())
}
}
)+
};
}
impl_pod_ret_value!(DLDataTypeCode_kDLInt, [i8, i16, i32, i64, isize]);
impl_pod_ret_value!(DLDataTypeCode_kDLUInt, [u8, u16, u32, u64, usize]);
impl_pod_ret_value!(DLDataTypeCode_kDLFloat, [f32, f64]);
impl_pod_ret_value!(TVMTypeCode_kTVMType, [TVMType]);
impl_pod_ret_value!(TVMTypeCode_kTVMContext, [TVMContext]);
impl TryFrom<TVMRetValue> for String {
type Error = Error;
fn try_from(ret: TVMRetValue) -> Result<String, Self::Error> {
ensure_type!(ret, TVMTypeCode_kStr);
let cs = unsafe { std::ffi::CString::from_raw(ret.value.v_handle as *mut i8) };
let ret_str = cs.clone().into_string();
if cfg!(feature = "bindings") {
std::mem::forget(cs); // TVM C++ takes ownership of CString. (@see TVMFuncCall)
}
Ok(ret_str?)
}
}
impl From<String> for TVMRetValue {
fn from(s: String) -> Self {
let cs = std::ffi::CString::new(s).unwrap();
Self {
value: TVMValue {
v_str: cs.into_raw() as *mut i8,
},
box_value: box (),
type_code: TVMTypeCode_kStr as i64,
}
}
}
//! This module containes `TVMTypeCode` and `TVMType` with some conversion methods.
//!
//! # Example
//!
//! ```
//! let dtype = TVMType::from("float");
//! println!("dtype is: {}", dtype);
//! ```
use std::{
ffi::{CStr, CString},
fmt::{self, Display, Formatter},
};
/// TVM type codes.
#[repr(u32)]
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum TVMTypeCode {
kDLInt = 0,
kDLUInt = 1,
kDLFloat = 2,
kHandle = 3,
kNull = 4,
kTVMType = 5,
kTVMContext = 6,
kArrayHandle = 7,
kNodeHandle = 8,
kModuleHandle = 9,
kFuncHandle = 10,
kStr = 11,
kBytes = 12,
kNDArrayContainer = 13,
}
impl Default for TVMTypeCode {
fn default() -> Self {
TVMTypeCode::kDLInt
}
}
impl From<TVMTypeCode> for i64 {
fn from(arg: TVMTypeCode) -> i64 {
match arg {
TVMTypeCode::kDLInt => 0,
TVMTypeCode::kDLUInt => 1,
TVMTypeCode::kDLFloat => 2,
TVMTypeCode::kHandle => 3,
TVMTypeCode::kNull => 4,
TVMTypeCode::kTVMType => 5,
TVMTypeCode::kTVMContext => 6,
TVMTypeCode::kArrayHandle => 7,
TVMTypeCode::kNodeHandle => 8,
TVMTypeCode::kModuleHandle => 9,
TVMTypeCode::kFuncHandle => 10,
TVMTypeCode::kStr => 11,
TVMTypeCode::kBytes => 12,
TVMTypeCode::kNDArrayContainer => 13,
}
}
}
impl Into<TVMTypeCode> for i64 {
fn into(self) -> TVMTypeCode {
match self {
0 => TVMTypeCode::kDLInt,
1 => TVMTypeCode::kDLUInt,
2 => TVMTypeCode::kDLFloat,
3 => TVMTypeCode::kHandle,
4 => TVMTypeCode::kNull,
5 => TVMTypeCode::kTVMType,
6 => TVMTypeCode::kTVMContext,
7 => TVMTypeCode::kArrayHandle,
8 => TVMTypeCode::kNodeHandle,
9 => TVMTypeCode::kModuleHandle,
10 => TVMTypeCode::kFuncHandle,
11 => TVMTypeCode::kStr,
12 => TVMTypeCode::kBytes,
13 => TVMTypeCode::kNDArrayContainer,
_ => unreachable!(),
}
}
}
impl Display for TVMTypeCode {
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
write!(
f,
"{}",
match self {
TVMTypeCode::kDLInt => "int",
TVMTypeCode::kDLUInt => "uint",
TVMTypeCode::kDLFloat => "float",
TVMTypeCode::kHandle => "handle",
TVMTypeCode::kNull => "null",
TVMTypeCode::kTVMType => "TVM type",
TVMTypeCode::kTVMContext => "TVM context",
TVMTypeCode::kArrayHandle => "Array handle",
TVMTypeCode::kNodeHandle => "Node handle",
TVMTypeCode::kModuleHandle => "Module handle",
TVMTypeCode::kFuncHandle => "Function handle",
TVMTypeCode::kStr => "string",
TVMTypeCode::kBytes => "bytes",
TVMTypeCode::kNDArrayContainer => "ndarray container",
}
)
}
}
macro_rules! impl_prim_type {
($type:ty, $variant:ident) => {
impl<'a> From<&'a $type> for TVMTypeCode {
fn from(_arg: &$type) -> Self {
TVMTypeCode::$variant
}
}
impl<'a> From<&'a mut $type> for TVMTypeCode {
fn from(_arg: &mut $type) -> Self {
TVMTypeCode::$variant
}
}
};
}
impl_prim_type!(usize, kDLInt);
impl_prim_type!(i64, kDLInt);
impl_prim_type!(i32, kDLInt);
impl_prim_type!(i16, kDLInt);
impl_prim_type!(i8, kDLInt);
impl_prim_type!(u64, kDLUInt);
impl_prim_type!(u32, kDLUInt);
impl_prim_type!(u16, kDLUInt);
impl_prim_type!(u8, kDLUInt);
impl_prim_type!(f64, kDLFloat);
impl_prim_type!(f32, kDLFloat);
impl_prim_type!(str, kStr);
impl_prim_type!(CStr, kStr);
impl_prim_type!(String, kStr);
impl_prim_type!(CString, kStr);
impl_prim_type!([u8], kBytes);
[package]
name = "tvm-sys"
version = "0.1.0"
authors = ["TVM Contributors"]
license = "Apache-2.0"
description = "Raw C API"
[build-dependencies]
bindgen = "0.37.4"
#![allow(
non_camel_case_types,
non_snake_case,
non_upper_case_globals,
dead_code,
improper_ctypes
)]
include!("bindgen.rs");
......@@ -15,11 +15,11 @@ name = "tvm_frontend"
crate-type = ["dylib"]
[dependencies]
error-chain = "0.12.0"
failure = "0.1.5"
lazy_static = "1.1.0"
ndarray = "0.12.1"
num-traits = "0.2"
tvm-common = { version = "0.1.0", path = "../common/", features = ["frontend"] }
tvm-common = { version = "0.1.0", path = "../common/", features = ["bindings"] }
[features]
blas = ["ndarray/blas"]
#![feature(try_from)]
extern crate csv;
extern crate image;
extern crate ndarray;
......@@ -10,6 +8,7 @@ use std::{
convert::TryInto,
fs::{self, File},
path::Path,
str::FromStr,
};
use image::{FilterType, GenericImageView};
......@@ -44,8 +43,12 @@ fn main() {
// make arr shape as [1, 3, 224, 224] acceptable to resnet
let arr = arr.insert_axis(Axis(0));
// create input tensor from rust's ndarray
let input =
NDArray::from_rust_ndarray(&arr, TVMContext::cpu(0), TVMType::from("float32")).unwrap();
let input = NDArray::from_rust_ndarray(
&arr,
TVMContext::cpu(0),
TVMType::from_str("float32").unwrap(),
)
.unwrap();
println!(
"input size is {:?}",
input.shape().expect("cannot get the input shape")
......@@ -59,7 +62,7 @@ fn main() {
)))
.unwrap();
// get the global TVM graph runtime function
let runtime_create_fn = Function::get("tvm.graph_runtime.create", true).unwrap();
let runtime_create_fn = Function::get("tvm.graph_runtime.create").unwrap();
let runtime_create_fn_ret = call_packed!(
runtime_create_fn,
&graph,
......@@ -85,14 +88,19 @@ fn main() {
.get_function("set_input", false)
.unwrap();
call_packed!(set_input_fn, "data", &input).unwrap();
let data_str = "data".to_string();
call_packed!(set_input_fn, &data_str, &input).unwrap();
// get `run` function from runtime module
let ref run_fn = graph_runtime_module.get_function("run", false).unwrap();
// execute the run function. Note that it has no argument
call_packed!(run_fn,).unwrap();
// prepare to get the output
let output_shape = &mut [1, 1000];
let output = NDArray::empty(output_shape, TVMContext::cpu(0), TVMType::from("float32"));
let output = NDArray::empty(
output_shape,
TVMContext::cpu(0),
TVMType::from_str("float32").unwrap(),
);
// get the `get_output` function from runtime module
let ref get_output_fn = graph_runtime_module
.get_function("get_output", false)
......
......@@ -3,9 +3,9 @@
//!
//! For more detail, please see the example `resnet` in `examples` repository.
use std::os::raw::c_char;
use std::os::raw::{c_char, c_void};
use crate::ts;
use tvm_common::{ffi, TVMArgValue};
/// A struct holding TVM byte-array.
///
......@@ -19,11 +19,11 @@ use crate::ts;
/// ```
#[derive(Debug, Clone)]
pub struct TVMByteArray {
pub(crate) inner: ts::TVMByteArray,
pub(crate) inner: ffi::TVMByteArray,
}
impl TVMByteArray {
pub(crate) fn new(barr: ts::TVMByteArray) -> TVMByteArray {
pub(crate) fn new(barr: ffi::TVMByteArray) -> TVMByteArray {
TVMByteArray { inner: barr }
}
......@@ -46,7 +46,7 @@ impl TVMByteArray {
impl<'a> From<&'a Vec<u8>> for TVMByteArray {
fn from(arg: &Vec<u8>) -> Self {
let barr = ts::TVMByteArray {
let barr = ffi::TVMByteArray {
data: arg.as_ptr() as *const c_char,
size: arg.len(),
};
......@@ -54,6 +54,18 @@ impl<'a> From<&'a Vec<u8>> for TVMByteArray {
}
}
impl<'a> From<&TVMByteArray> for TVMArgValue<'a> {
fn from(arr: &TVMByteArray) -> Self {
Self {
value: ffi::TVMValue {
v_handle: &arr.inner as *const ffi::TVMByteArray as *const c_void as *mut c_void,
},
type_code: ffi::TVMTypeCode_kBytes as i64,
_lifetime: std::marker::PhantomData,
}
}
}
#[cfg(test)]
mod tests {
use super::*;
......
......@@ -18,12 +18,20 @@
//! ```
use std::{
convert::TryInto,
fmt::{self, Display, Formatter},
os::raw::c_void,
ptr,
};
use crate::{function, ts, Result};
use failure::Error;
use tvm_common::{
ffi::{self, TVMValue},
TVMArgValue,
};
use crate::function;
/// Device type can be from a supported device name. See the supported devices
/// in [TVM](https://github.com/dmlc/tvm).
......@@ -45,35 +53,35 @@ impl Default for TVMDeviceType {
}
}
impl From<TVMDeviceType> for ts::DLDeviceType {
impl From<TVMDeviceType> for ffi::DLDeviceType {
fn from(device_type: TVMDeviceType) -> Self {
match device_type.0 {
1 => ts::DLDeviceType_kDLCPU,
2 => ts::DLDeviceType_kDLGPU,
3 => ts::DLDeviceType_kDLCPUPinned,
4 => ts::DLDeviceType_kDLOpenCL,
7 => ts::DLDeviceType_kDLVulkan,
8 => ts::DLDeviceType_kDLMetal,
9 => ts::DLDeviceType_kDLVPI,
10 => ts::DLDeviceType_kDLROCM,
12 => ts::DLDeviceType_kDLExtDev,
1 => ffi::DLDeviceType_kDLCPU,
2 => ffi::DLDeviceType_kDLGPU,
3 => ffi::DLDeviceType_kDLCPUPinned,
4 => ffi::DLDeviceType_kDLOpenCL,
7 => ffi::DLDeviceType_kDLVulkan,
8 => ffi::DLDeviceType_kDLMetal,
9 => ffi::DLDeviceType_kDLVPI,
10 => ffi::DLDeviceType_kDLROCM,
12 => ffi::DLDeviceType_kDLExtDev,
_ => panic!("device type not found!"),
}
}
}
impl From<ts::DLDeviceType> for TVMDeviceType {
fn from(device_type: ts::DLDeviceType) -> Self {
impl From<ffi::DLDeviceType> for TVMDeviceType {
fn from(device_type: ffi::DLDeviceType) -> Self {
match device_type {
ts::DLDeviceType_kDLCPU => TVMDeviceType(1),
ts::DLDeviceType_kDLGPU => TVMDeviceType(2),
ts::DLDeviceType_kDLCPUPinned => TVMDeviceType(3),
ts::DLDeviceType_kDLOpenCL => TVMDeviceType(4),
ts::DLDeviceType_kDLVulkan => TVMDeviceType(7),
ts::DLDeviceType_kDLMetal => TVMDeviceType(8),
ts::DLDeviceType_kDLVPI => TVMDeviceType(9),
ts::DLDeviceType_kDLROCM => TVMDeviceType(10),
ts::DLDeviceType_kDLExtDev => TVMDeviceType(12),
ffi::DLDeviceType_kDLCPU => TVMDeviceType(1),
ffi::DLDeviceType_kDLGPU => TVMDeviceType(2),
ffi::DLDeviceType_kDLCPUPinned => TVMDeviceType(3),
ffi::DLDeviceType_kDLOpenCL => TVMDeviceType(4),
ffi::DLDeviceType_kDLVulkan => TVMDeviceType(7),
ffi::DLDeviceType_kDLMetal => TVMDeviceType(8),
ffi::DLDeviceType_kDLVPI => TVMDeviceType(9),
ffi::DLDeviceType_kDLROCM => TVMDeviceType(10),
ffi::DLDeviceType_kDLExtDev => TVMDeviceType(12),
_ => panic!("device type not found!"),
}
}
......@@ -117,6 +125,18 @@ impl<'a> From<&'a str> for TVMDeviceType {
}
}
impl<'a> From<&'a TVMDeviceType> for TVMArgValue<'a> {
fn from(dev_type: &'a TVMDeviceType) -> Self {
Self {
value: TVMValue {
v_int64: dev_type.0 as i64,
},
type_code: ffi::DLDataTypeCode_kDLInt as i64,
_lifetime: std::marker::PhantomData,
}
}
}
/// Represents the underlying device context. Default is cpu.
///
/// ## Examples
......@@ -138,15 +158,15 @@ pub struct TVMContext {
/// Supported device types
pub device_type: TVMDeviceType,
/// Device id
pub device_id: usize,
pub device_id: i32,
}
impl TVMContext {
/// Creates context from device type and id.
pub fn new(device_type: TVMDeviceType, device_id: usize) -> Self {
pub fn new(device_type: TVMDeviceType, device_id: i32) -> Self {
TVMContext {
device_type: device_type,
device_id: device_id,
device_type,
device_id,
}
}
}
......@@ -155,7 +175,7 @@ macro_rules! impl_ctxs {
($(($ctx:ident, $dldevt:expr));+) => {
$(
impl TVMContext {
pub fn $ctx(device_id: usize) -> Self {
pub fn $ctx(device_id: i32) -> Self {
Self::new(TVMDeviceType($dldevt), device_id)
}
}
......@@ -185,20 +205,20 @@ impl<'a> From<&'a str> for TVMContext {
impl TVMContext {
/// Checks whether the context exists or not.
pub fn exist(&self) -> bool {
let func = function::Function::get("_GetDeviceAttr", true /* is_global */)
.expect("API function always exists");
let func = function::Function::get("_GetDeviceAttr").expect("API function always exists");
let dt = self.device_type.0 as usize;
// `unwrap` is ok here because if there is any error,
// if would occure inside `call_packed!`
let ret = call_packed!(func, &dt, &self.device_id, &0)
let ret: u64 = call_packed!(func, &dt, &self.device_id, &0)
.unwrap()
.prim_value;
.try_into()
.unwrap();
ret != 0
}
/// Synchronize the context stream.
pub fn sync(&self) -> Result<()> {
check_call!(ts::TVMSynchronize(
pub fn sync(&self) -> Result<(), Error> {
check_call!(ffi::TVMSynchronize(
self.device_type.0 as i32,
self.device_id as i32,
ptr::null_mut() as *mut c_void
......@@ -212,16 +232,17 @@ macro_rules! impl_device_attrs {
$(
impl TVMContext {
pub fn $attr_name(&self) -> usize {
let func = function::Function::get("_GetDeviceAttr", true /* is_global */)
let func = function::Function::get("_GetDeviceAttr")
.expect("API function always exists");
let dt = self.device_type.0 as usize;
// `unwrap` is ok here because if there is any error,
// if would occur in function call.
let ret = function::Builder::from(func)
.args(&[dt, self.device_id, $attr_kind])
function::Builder::from(func)
.args(&[dt, self.device_id as usize, $attr_kind])
.invoke()
.unwrap();
ret.prim_value as usize
.unwrap()
.try_into()
.unwrap()
}
}
)+
......@@ -237,18 +258,18 @@ impl_device_attrs!((max_threads_per_block, 1);
(multi_processor_count, 7);
(max_thread_dimensions, 8));
impl From<ts::DLContext> for TVMContext {
fn from(ctx: ts::DLContext) -> Self {
impl From<ffi::DLContext> for TVMContext {
fn from(ctx: ffi::DLContext) -> Self {
TVMContext {
device_type: TVMDeviceType::from(ctx.device_type),
device_id: ctx.device_id as usize,
device_id: ctx.device_id,
}
}
}
impl From<TVMContext> for ts::DLContext {
impl From<TVMContext> for ffi::DLContext {
fn from(ctx: TVMContext) -> Self {
ts::DLContext {
ffi::DLContext {
device_type: ctx.device_type.into(),
device_id: ctx.device_id as i32,
}
......
//! This module implements TVM custom [`Error`], [`ErrorKind`] and [`Result`] types.
pub use failure::Error;
use std::{ffi, option};
#[derive(Debug, Fail)]
#[fail(display = "Cannot convert from an empty array.")]
pub struct EmptyArrayError;
use crate::{common_errors, rust_ndarray};
error_chain! {
errors {
EmptyArray {
description("cannot convert from an empty array")
}
NullHandle(name: String) {
description("null handle")
display("requested `{}` handle is null", name)
}
FunctionNotFound {
description("function not found")
display("function was not set in `function::Builder`")
}
TypeMismatch(expected: String, found: String) {
description("type mismatch!")
display("expected type `{}`, but found `{}`", expected, found)
}
MissingShapeError {
description("ndarray `shape()` returns `None`")
display("called `Option::unwrap()` on a `None` value")
}
AtMostOneReturn {
description("TVM functions accept at most one return value")
}
#[derive(Debug, Fail)]
#[fail(display = "Handle `{}` is null.", name)]
pub struct NullHandleError {
pub name: String,
}
}
#[derive(Debug, Fail)]
#[fail(display = "Function was not set in `function::Builder`")]
pub struct FunctionNotFoundError;
foreign_links {
ShapeError(rust_ndarray::ShapeError);
NulError(ffi::NulError);
IntoStringError(ffi::IntoStringError);
CommonError(common_errors::Error);
}
#[derive(Debug, Fail)]
#[fail(display = "Expected type `{}` but found `{}`", expected, actual)]
pub struct TypeMismatchError {
pub expected: String,
pub actual: String,
}
impl From<option::NoneError> for Error {
fn from(_err: option::NoneError) -> Self {
ErrorKind::MissingShapeError.into()
}
}
#[derive(Debug, Fail)]
#[fail(display = "Missing NDArray shape.")]
pub struct MissingShapeError;
......@@ -11,32 +11,36 @@
//!
//! Checkout the `examples` repository for more details.
#![crate_name = "tvm_frontend"]
#![recursion_limit = "1024"]
#![allow(non_camel_case_types, unused_unsafe)]
#![feature(
try_from,
try_trait,
fn_traits,
unboxed_closures,
box_syntax,
option_replace
)]
#![feature(box_syntax)]
#[macro_use]
extern crate error_chain;
extern crate tvm_common as common;
extern crate failure;
#[macro_use]
extern crate lazy_static;
extern crate ndarray as rust_ndarray;
extern crate num_traits;
extern crate tvm_common;
use std::{
ffi::{CStr, CString},
str,
};
use crate::common::ffi::ts;
use failure::Error;
pub use crate::{
bytearray::TVMByteArray,
context::{TVMContext, TVMDeviceType},
errors::*,
function::Function,
module::Module,
ndarray::NDArray,
tvm_common::{
errors as common_errors,
ffi::{self, TVMType},
packed_func::{TVMArgValue, TVMRetValue},
},
};
// Macro to check the return call to TVM runtime shared library.
macro_rules! check_call {
......@@ -50,7 +54,7 @@ macro_rules! check_call {
/// Gets the last error message.
pub fn get_last_error() -> &'static str {
unsafe {
match CStr::from_ptr(ts::TVMGetLastError()).to_str() {
match CStr::from_ptr(ffi::TVMGetLastError()).to_str() {
Ok(s) => s,
Err(_) => "Invalid UTF-8 message",
}
......@@ -60,7 +64,7 @@ pub fn get_last_error() -> &'static str {
pub(crate) fn set_last_error(err: &Error) {
let c_string = CString::new(err.to_string()).unwrap();
unsafe {
ts::TVMAPISetLastError(c_string.as_ptr());
ffi::TVMAPISetLastError(c_string.as_ptr());
}
}
......@@ -71,27 +75,11 @@ pub mod context;
pub mod errors;
pub mod module;
pub mod ndarray;
pub mod ty;
pub mod value;
pub use crate::{
bytearray::TVMByteArray,
common::{
errors as common_errors,
ty::TVMTypeCode,
value::{TVMArgValue, TVMRetValue, TVMValue},
},
context::{TVMContext, TVMDeviceType},
errors::*,
function::Function,
module::Module,
ndarray::NDArray,
ty::TVMType,
};
/// Outputs the current TVM version.
pub fn version() -> &'static str {
match str::from_utf8(ts::TVM_VERSION) {
match str::from_utf8(ffi::TVM_VERSION) {
Ok(s) => s,
Err(_) => "Invalid UTF-8 string",
}
......@@ -108,8 +96,8 @@ mod tests {
#[test]
fn set_error() {
let err = ErrorKind::EmptyArray;
let err = errors::EmptyArrayError;
set_last_error(&err.into());
assert_eq!(get_last_error().trim(), ErrorKind::EmptyArray.to_string());
assert_eq!(get_last_error().trim(), errors::EmptyArrayError.to_string());
}
}
......@@ -8,30 +8,27 @@ use std::{
ptr,
};
use crate::ts;
use failure::Error;
use tvm_common::ffi;
use crate::{function::Function, ErrorKind, Result};
use crate::{errors, function::Function};
const ENTRY_FUNC: &'static str = "__tvm_main__";
/// Wrapper around TVM module handle which contains an entry function.
/// The entry function can be applied to an imported module through [`entry_func`].
/// Also [`is_released`] shows whether the module is dropped or not.
///
/// [`entry_func`]:struct.Module.html#method.entry_func
/// [`is_released`]:struct.Module.html#method.is_released
#[derive(Debug, Clone)]
pub struct Module {
pub(crate) handle: ts::TVMModuleHandle,
is_released: bool,
pub(crate) handle: ffi::TVMModuleHandle,
entry_func: Option<Function>,
}
impl Module {
pub(crate) fn new(handle: ts::TVMModuleHandle, is_released: bool) -> Self {
pub(crate) fn new(handle: ffi::TVMModuleHandle) -> Self {
Self {
handle,
is_released,
entry_func: None,
}
}
......@@ -44,62 +41,67 @@ impl Module {
}
/// Gets a function by name from a registered module.
pub fn get_function(&self, name: &str, query_import: bool) -> Result<Function> {
pub fn get_function(&self, name: &str, query_import: bool) -> Result<Function, Error> {
let name = CString::new(name)?;
let mut fhandle = ptr::null_mut() as ts::TVMFunctionHandle;
check_call!(ts::TVMModGetFunction(
let mut fhandle = ptr::null_mut() as ffi::TVMFunctionHandle;
check_call!(ffi::TVMModGetFunction(
self.handle,
name.as_ptr() as *const c_char,
query_import as c_int,
&mut fhandle as *mut _
));
if fhandle.is_null() {
bail!(ErrorKind::NullHandle(format!("{}", name.into_string()?)))
} else {
Ok(Function::new(fhandle, false, false))
}
ensure!(
!fhandle.is_null(),
errors::NullHandleError {
name: format!("{}", name.into_string()?)
}
);
Ok(Function::new(fhandle))
}
/// Imports a dependent module such as `.ptx` for gpu.
pub fn import_module(&self, dependent_module: Module) {
check_call!(ts::TVMModImport(self.handle, dependent_module.handle))
check_call!(ffi::TVMModImport(self.handle, dependent_module.handle))
}
/// Loads a module shared library from path.
pub fn load<P: AsRef<Path>>(path: &P) -> Result<Module> {
let ext = path.as_ref().extension()?.to_str()?;
let func = Function::get("module._LoadFromFile", true /* is_global */)
.expect("API function always exists");
let ret: Module = call_packed!(func, path.as_ref().to_str()?, ext)?.try_into()?;
pub fn load<P: AsRef<Path>>(path: &P) -> Result<Module, Error> {
let ext = CString::new(
path.as_ref()
.extension()
.unwrap_or(std::ffi::OsStr::new(""))
.to_str()
.ok_or_else(|| {
format_err!("Bad module load path: `{}`.", path.as_ref().display())
})?,
)?;
let func = Function::get("module._LoadFromFile").expect("API function always exists");
let cpath =
CString::new(path.as_ref().to_str().ok_or_else(|| {
format_err!("Bad module load path: `{}`.", path.as_ref().display())
})?)?;
let ret: Module = call_packed!(func, &cpath, &ext)?.try_into()?;
Ok(ret)
}
/// Checks if a target device is enabled for a module.
pub fn enabled(&self, target: &str) -> bool {
let func = Function::get("module._Enabled", true /* is_global */)
.expect("API function always exists");
let func = Function::get("module._Enabled").expect("API function always exists");
// `unwrap` is safe here because if there is any error during the
// function call, it would occur in `call_packed!`.
let ret: i64 = call_packed!(func, target).unwrap().try_into().unwrap();
let tgt = CString::new(target).unwrap();
let ret: i64 = call_packed!(func, &tgt).unwrap().try_into().unwrap();
ret != 0
}
/// Returns the underlying module handle.
pub fn handle(&self) -> ts::TVMModuleHandle {
pub fn handle(&self) -> ffi::TVMModuleHandle {
self.handle
}
/// Returns true if the underlying module has been dropped and false otherwise.
pub fn is_released(&self) -> bool {
self.is_released
}
}
impl Drop for Module {
fn drop(&mut self) {
if !self.is_released {
check_call!(ts::TVMModFree(self.handle));
self.is_released = true;
}
check_call!(ffi::TVMModFree(self.handle));
}
}
//! This module implements the required conversions from Rust types to TVM types.
//!
//! In TVM frontend only conversions from Rust's 32-bits (POD) numeric types (i32, u32, f32)
//! and 64-bits pointers are supported.
use std::{
fmt::{self, Display, Formatter},
ops::{Deref, DerefMut},
};
use crate::ts;
use crate::{Function, Module, NDArray, TVMByteArray, TVMContext, TVMDeviceType, TVMTypeCode};
macro_rules! impl_prim_type {
($type:ty, $variant:ident) => {
impl From<$type> for TVMTypeCode {
fn from(_arg: $type) -> Self {
TVMTypeCode::$variant
}
}
impl<'a> From<&'a $type> for TVMTypeCode {
fn from(_arg: &$type) -> Self {
TVMTypeCode::$variant
}
}
impl<'a> From<&'a mut $type> for TVMTypeCode {
fn from(_arg: &mut $type) -> Self {
TVMTypeCode::$variant
}
}
};
}
impl_prim_type!(TVMDeviceType, kDLInt);
impl_prim_type!(TVMContext, kTVMContext);
impl_prim_type!(TVMType, kTVMType);
impl_prim_type!(Function, kFuncHandle);
impl_prim_type!(Module, kModuleHandle);
impl_prim_type!(NDArray, kArrayHandle);
impl_prim_type!(TVMByteArray, kBytes);
/// See the [module-level documentation](../ty/index.html) for more details.
///
/// Wrapper around underlying TVMType
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub struct TVMType {
// inner fields are (code: u8, bits: u8, lanes: u16)
pub inner: ts::TVMType,
}
impl TVMType {
pub(crate) fn new(type_code: u8, bits: u8, lanes: u16) -> Self {
TVMType {
inner: ts::TVMType {
code: type_code,
bits: bits,
lanes: lanes,
},
}
}
}
/// Implements TVMType conversion from `&str` of general format `{dtype}{bits}x{lanes}`
/// such as "int32", "float32" or with lane "float32x1".
impl<'a> From<&'a str> for TVMType {
fn from(type_str: &'a str) -> Self {
if type_str == "bool" {
return TVMType::new(1, 1, 1);
}
let mut type_lanes = type_str.split("x");
let typ = type_lanes.next().expect("Missing dtype");
let lanes = type_lanes
.next()
.map(|l| u16::from_str_radix(l, 10).expect(&format!("Bad dtype lanes: {}", l)))
.unwrap_or(1);
let (type_name, bits) = match typ.find(char::is_numeric) {
Some(idx) => {
let (name, bits_str) = typ.split_at(idx);
(
name,
u8::from_str_radix(bits_str, 10)
.expect(&format!("Bad dtype bits: {}", bits_str)),
)
}
None => (typ, 32),
};
let type_code = match type_name {
"int" => 0,
"uint" => 1,
"float" => 2,
"handle" => 3,
_ => unimplemented!(),
};
TVMType::new(type_code, bits, lanes)
}
}
impl Display for TVMType {
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
let ts::TVMType { code, bits, lanes } = self.inner;
if bits == 1 && lanes == 1 {
return write!(f, "bool");
}
let mut tcode_str = match code {
0 => "int",
1 => "uint",
2 => "float",
4 => "handle",
_ => "Unknown",
}
.to_string();
tcode_str += &bits.to_string();
if lanes > 1 {
tcode_str += &format!("x{}", lanes.to_string());
}
f.write_str(&tcode_str)
}
}
impl From<TVMType> for ts::DLDataType {
fn from(dtype: TVMType) -> Self {
dtype.inner
}
}
impl From<ts::DLDataType> for TVMType {
fn from(dtype: ts::DLDataType) -> Self {
Self::new(dtype.code, dtype.bits, dtype.lanes)
}
}
impl Deref for TVMType {
type Target = ts::TVMType;
fn deref(&self) -> &Self::Target {
&self.inner
}
}
impl DerefMut for TVMType {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.inner
}
}
extern crate ndarray as rust_ndarray;
extern crate tvm_frontend as tvm;
use std::str::FromStr;
use tvm::*;
fn main() {
......@@ -12,7 +14,7 @@ fn main() {
} else {
(TVMContext::gpu(0), "gpu")
};
let dtype = TVMType::from("float32");
let dtype = TVMType::from_str("float32").unwrap();
let mut arr = NDArray::empty(shape, ctx, dtype);
arr.copy_from_buffer(data.as_mut_slice());
let mut ret = NDArray::empty(shape, ctx, dtype);
......@@ -26,8 +28,7 @@ fn main() {
function::Builder::from(&mut fadd)
.arg(&arr)
.arg(&arr)
.set_output(&mut ret)
.unwrap()
.arg(&mut ret)
.invoke()
.unwrap();
......
#![feature(extern_crate_item_prelude, try_from)]
#![allow(unused_imports)]
extern crate ndarray as rust_ndarray;
......@@ -6,17 +5,23 @@ extern crate ndarray as rust_ndarray;
extern crate tvm_frontend as tvm;
use rust_ndarray::ArrayD;
use std::convert::{TryFrom, TryInto};
use std::{
convert::{TryFrom, TryInto},
str::FromStr,
};
use tvm::*;
use tvm::{errors::Error, *};
fn main() {
register_global_func! {
fn sum(args: &[TVMArgValue]) -> Result<TVMRetValue> {
fn sum(args: &[TVMArgValue]) -> Result<TVMRetValue, Error> {
let mut ret = 0f32;
let shape = &mut [2];
for arg in args.iter() {
let e = NDArray::empty(shape, TVMContext::cpu(0), TVMType::from("float32"));
let e = NDArray::empty(
shape, TVMContext::cpu(0),
TVMType::from_str("float32").unwrap()
);
let arg: NDArray = arg.try_into()?;
let arr = arg.copy_to_ndarray(e)?;
let rnd: ArrayD<f32> = ArrayD::try_from(&arr)?;
......@@ -28,12 +33,16 @@ fn main() {
let shape = &mut [2];
let mut data = vec![3f32, 4.0];
let mut arr = NDArray::empty(shape, TVMContext::cpu(0), TVMType::from("float32"));
let mut arr = NDArray::empty(
shape,
TVMContext::cpu(0),
TVMType::from_str("float32").unwrap(),
);
arr.copy_from_buffer(data.as_mut_slice());
let mut registered = function::Builder::default();
let ret: f32 = registered
.get_function("sum", true)
.get_function("sum")
.arg(&arr)
.arg(&arr)
.invoke()
......
#![feature(extern_crate_item_prelude, panic_info_message)]
#![feature(panic_info_message)]
#![allow(unused_imports)]
use std::panic;
......@@ -6,20 +6,20 @@ use std::panic;
#[macro_use]
extern crate tvm_frontend as tvm;
use tvm::*;
use tvm::{errors::Error, *};
fn main() {
register_global_func! {
fn error(_args: &[TVMArgValue]) -> Result<TVMRetValue> {
Err(ErrorKind::TypeMismatch(
format!("{}", "i64".to_string()),
format!("{}", "f64".to_string()),
).into())
fn error(_args: &[TVMArgValue]) -> Result<TVMRetValue, Error> {
Err(errors::TypeMismatchError{
expected: "i64".to_string(),
actual: "f64".to_string(),
}.into())
}
}
let mut registered = function::Builder::default();
registered.get_function("error", true);
registered.get_function("error");
assert!(registered.func.is_some());
registered.args(&[10, 20]);
......
#![feature(extern_crate_item_prelude, try_from)]
#![allow(unused_imports)]
#[macro_use]
extern crate tvm_frontend as tvm;
use std::convert::TryInto;
use tvm::*;
use tvm::{errors::Error, *};
fn main() {
register_global_func! {
fn sum(args: &[TVMArgValue]) -> Result<TVMRetValue> {
fn sum(args: &[TVMArgValue]) -> Result<TVMRetValue, Error> {
let mut ret = 0.0;
for arg in args.iter() {
for arg in args.into_iter() {
let val: f64 = arg.try_into()?;
ret += val;
}
Ok(TVMRetValue::from(&ret))
Ok(TVMRetValue::from(ret))
}
}
let mut registered = function::Builder::default();
registered.get_function("sum", true);
registered.get_function("sum");
assert!(registered.func.is_some());
let ret: f64 = registered
.args(&[10.0f64, 20.0, 30.0])
......
#![feature(extern_crate_item_prelude, try_from)]
#![allow(unused_imports)]
extern crate tvm_frontend as tvm;
use std::convert::TryInto;
use tvm::*;
use tvm::{errors::Error, *};
fn main() {
fn sum(args: &[TVMArgValue]) -> Result<TVMRetValue> {
fn sum(args: &[TVMArgValue]) -> Result<TVMRetValue, Error> {
let mut ret = 0i64;
for arg in args.iter() {
let val: i64 = arg.try_into()?;
ret += val;
}
Ok(TVMRetValue::from(&ret))
Ok(TVMRetValue::from(ret))
}
tvm::function::register(sum, "mysum".to_owned(), false).unwrap();
let mut registered = function::Builder::default();
registered.get_function("mysum", true);
registered.get_function("mysum");
assert!(registered.func.is_some());
let ret: i64 = registered
.args(&[10, 20, 30])
......
#![feature(extern_crate_item_prelude, try_from)]
#![allow(unused_imports)]
#[macro_use]
extern crate tvm_frontend as tvm;
use std::convert::TryInto;
use tvm::*;
use tvm::{errors::Error, *};
// FIXME
fn main() {
register_global_func! {
fn concate_str(args: &[TVMArgValue]) -> Result<TVMRetValue> {
fn concate_str(args: &[TVMArgValue]) -> Result<TVMRetValue, Error> {
let mut ret = "".to_string();
for arg in args.iter() {
let val: String = arg.try_into()?;
ret += val.as_str();
let val: &str = arg.try_into()?;
ret += val;
}
Ok(TVMRetValue::from(ret))
}
}
let a = std::ffi::CString::new("a").unwrap();
let b = std::ffi::CString::new("b").unwrap();
let c = std::ffi::CString::new("c").unwrap();
let mut registered = function::Builder::default();
registered.get_function("concate_str", true);
registered.get_function("concate_str");
assert!(registered.func.is_some());
let a = "a".to_string();
let b = "b".to_string();
let c = "c".to_string();
let ret: String = registered
.args(&[a, b, c])
.arg(&a)
.arg(&b)
.arg(&c)
.invoke()
.unwrap()
.try_into()
......
......@@ -15,15 +15,15 @@ sgx = ["nom/alloc"]
[dependencies]
bounded-spsc-queue = "0.4.0"
error-chain = { version = "0.12.0", default-features = false }
failure = "0.1.5"
itertools = "0.7.8"
lazy_static = "1.1.0"
ndarray = "0.11.2"
ndarray="0.12.1"
nom = {version = "4.0.0", default-features = false }
serde = "1.0.59"
serde_derive = "1.0.79"
serde_json = "1.0.17"
tvm-common = { version = "0.1.0", path = "../common/", features = ["runtime"] }
tvm-common = { version = "0.1.0", path = "../common/" }
[target.'cfg(not(target_env = "sgx"))'.dependencies]
num_cpus = "1.8.0"
#[cfg(target_env = "sgx")]
use alloc::alloc::{self, Layout};
use alloc::alloc::{self, Layout, LayoutErr};
#[cfg(not(target_env = "sgx"))]
use std::alloc::{self, Layout};
use crate::errors::*;
use std::alloc::{self, Layout, LayoutErr};
const DEFAULT_ALIGN_BYTES: usize = 4;
......@@ -15,7 +13,7 @@ pub struct Allocation {
impl Allocation {
/// Allocates a chunk of memory of `size` bytes with optional alignment.
pub fn new(size: usize, align: Option<usize>) -> Result<Self> {
pub fn new(size: usize, align: Option<usize>) -> Result<Self, LayoutErr> {
let alignment = align.unwrap_or(DEFAULT_ALIGN_BYTES);
let layout = Layout::from_size_align(size, alignment)?;
let ptr = unsafe { alloc::alloc(layout.clone()) };
......
use std::{
any::TypeId,
convert::TryFrom,
mem,
ops::{Deref, DerefMut},
os::raw::{c_int, c_void},
ptr, slice,
};
use std::{convert::TryFrom, mem, os::raw::c_void, ptr, slice};
use failure::Error;
use ndarray;
use crate::{
allocator::Allocation,
errors::*,
ffi::runtime::{
use tvm_common::{
array::{DataType, TVMContext},
ffi::{
DLContext, DLDataType, DLDataTypeCode_kDLFloat, DLDataTypeCode_kDLInt,
DLDataTypeCode_kDLUInt, DLDeviceType_kDLCPU, DLTensor as _DLTensor,
DLDataTypeCode_kDLUInt, DLTensor,
},
};
use crate::allocator::Allocation;
/// A `Storage` is a container which holds `Tensor` data.
#[derive(PartialEq)]
pub enum Storage<'a> {
......@@ -29,7 +23,7 @@ pub enum Storage<'a> {
}
impl<'a> Storage<'a> {
pub fn new(size: usize, align: Option<usize>) -> Result<Storage<'static>> {
pub fn new(size: usize, align: Option<usize>) -> Result<Storage<'static>, Error> {
Ok(Storage::Owned(Allocation::new(size, align)?))
}
......@@ -237,6 +231,27 @@ impl<'a> Tensor<'a> {
byte_offset: 0,
}
}
pub(crate) fn as_dltensor(&self, flatten: bool) -> DLTensor {
assert!(!flatten || self.is_contiguous());
DLTensor {
data: unsafe { self.data.as_mut_ptr().offset(self.byte_offset) } as *mut c_void,
ctx: DLContext::from(&self.ctx),
ndim: if flatten { 1 } else { self.shape.len() } as i32,
dtype: DLDataType::from(&self.dtype),
shape: if flatten {
&self.size as *const _ as *mut i64
} else {
self.shape.as_ptr()
} as *mut i64,
strides: if flatten || self.is_contiguous() {
ptr::null_mut()
} else {
self.strides.as_ref().unwrap().as_ptr()
} as *mut i64,
byte_offset: 0,
}
}
}
/// Conversions to `ndarray::Array` from `Tensor`, if the types match.
......@@ -244,7 +259,7 @@ macro_rules! impl_ndarray_try_from_tensor {
($type:ty, $dtype:expr) => {
impl<'a, 't> TryFrom<&'a Tensor<'t>> for ndarray::ArrayD<$type> {
type Error = Error;
fn try_from(tensor: &'a Tensor) -> Result<ndarray::ArrayD<$type>> {
fn try_from(tensor: &'a Tensor) -> Result<ndarray::ArrayD<$type>, Error> {
ensure!(
tensor.dtype == $dtype,
"Cannot convert Tensor with dtype {:?} to ndarray",
......@@ -263,120 +278,9 @@ macro_rules! impl_ndarray_try_from_tensor {
};
}
impl_ndarray_try_from_tensor!(i32, DTYPE_INT32);
impl_ndarray_try_from_tensor!(u32, DTYPE_UINT32);
impl_ndarray_try_from_tensor!(f32, DTYPE_FLOAT32);
impl_ndarray_try_from_tensor!(f64, DTYPE_FLOAT64);
pub struct DLTensor {
pub(crate) inner: _DLTensor,
}
impl Deref for DLTensor {
type Target = _DLTensor;
fn deref(&self) -> &Self::Target {
&self.inner
}
}
impl DerefMut for DLTensor {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.inner
}
}
impl DLTensor {
pub(crate) fn new(raw: _DLTensor) -> Self {
Self { inner: raw }
}
pub(crate) fn from_tensor<'a>(tensor: &'a Tensor, flatten: bool) -> Self {
assert!(!flatten || tensor.is_contiguous());
Self {
inner: _DLTensor {
data: unsafe { tensor.data.as_mut_ptr().offset(tensor.byte_offset) } as *mut c_void,
ctx: DLContext::from(&tensor.ctx),
ndim: if flatten { 1 } else { tensor.shape.len() } as i32,
dtype: DLDataType::from(&tensor.dtype),
shape: if flatten {
&tensor.size as *const _ as *mut i64
} else {
tensor.shape.as_ptr()
} as *mut i64,
strides: if flatten || tensor.is_contiguous() {
ptr::null_mut()
} else {
tensor.strides.as_ref().unwrap().as_ptr()
} as *mut i64,
byte_offset: 0,
},
}
}
}
impl<'a, 't> From<&'a Tensor<'t>> for DLTensor {
fn from(tensor: &'a Tensor<'t>) -> Self {
DLTensor::from_tensor(tensor, false /* flatten */)
}
}
impl<'a, 't> From<&'a mut Tensor<'t>> for DLTensor {
fn from(tensor: &'a mut Tensor<'t>) -> Self {
DLTensor::from_tensor(tensor, false /* flatten */)
}
}
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct DataType {
pub(crate) code: usize,
pub(crate) bits: usize,
pub(crate) lanes: usize,
}
impl DataType {
/// Returns the number of bytes occupied by an element of this `DataType`.
pub fn itemsize(&self) -> usize {
(self.bits * self.lanes) >> 3
}
/// Returns whether this `DataType` represents primitive type `T`.
pub fn is_type<T: 'static>(&self) -> bool {
if self.lanes != 1 {
return false;
}
let typ = TypeId::of::<T>();
(typ == TypeId::of::<i32>() && self.code == 0 && self.bits == 32)
|| (typ == TypeId::of::<i64>() && self.code == 0 && self.bits == 64)
|| (typ == TypeId::of::<u32>() && self.code == 1 && self.bits == 32)
|| (typ == TypeId::of::<u64>() && self.code == 1 && self.bits == 64)
|| (typ == TypeId::of::<f32>() && self.code == 2 && self.bits == 32)
|| (typ == TypeId::of::<f64>() && self.code == 2 && self.bits == 64)
}
}
impl<'a> From<&'a DataType> for DLDataType {
fn from(dtype: &'a DataType) -> Self {
Self {
code: dtype.code as u8,
bits: dtype.bits as u8,
lanes: dtype.lanes as u16,
}
}
}
impl From<DLDataType> for DataType {
fn from(dtype: DLDataType) -> Self {
Self {
code: dtype.code as usize,
bits: dtype.bits as usize,
lanes: dtype.lanes as usize,
}
}
}
macro_rules! make_dtype_const {
($name: ident, $code: ident, $bits: expr, $lanes: expr) => {
const $name: DataType = DataType {
pub const $name: DataType = DataType {
code: $code as usize,
bits: $bits,
lanes: $lanes,
......@@ -389,28 +293,20 @@ make_dtype_const!(DTYPE_UINT32, DLDataTypeCode_kDLUInt, 32, 1);
// make_dtype_const!(DTYPE_FLOAT16, DLDataTypeCode_kDLFloat, 16, 1);
make_dtype_const!(DTYPE_FLOAT32, DLDataTypeCode_kDLFloat, 32, 1);
make_dtype_const!(DTYPE_FLOAT64, DLDataTypeCode_kDLFloat, 64, 1);
impl_ndarray_try_from_tensor!(i32, DTYPE_INT32);
impl_ndarray_try_from_tensor!(u32, DTYPE_UINT32);
impl_ndarray_try_from_tensor!(f32, DTYPE_FLOAT32);
impl_ndarray_try_from_tensor!(f64, DTYPE_FLOAT64);
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct TVMContext {
pub(crate) device_type: usize,
pub(crate) device_id: usize,
}
impl<'a> From<&'a TVMContext> for DLContext {
fn from(ctx: &'a TVMContext) -> Self {
Self {
device_type: ctx.device_type as u32,
device_id: ctx.device_id as i32,
}
impl<'a, 't> From<&'a Tensor<'t>> for DLTensor {
fn from(tensor: &'a Tensor<'t>) -> Self {
Tensor::as_dltensor(tensor, false /* flatten */)
}
}
impl Default for TVMContext {
fn default() -> Self {
Self {
device_type: DLDeviceType_kDLCPU as usize,
device_id: 0,
}
impl<'a, 't> From<&'a mut Tensor<'t>> for DLTensor {
fn from(tensor: &'a mut Tensor<'t>) -> Self {
Tensor::as_dltensor(tensor, false /* flatten */)
}
}
......@@ -463,42 +359,6 @@ macro_rules! impl_tensor_from_ndarray {
};
}
/// `From` conversions to `DLTensor` for `ndarray::Array`.
/// Takes a reference to the `ndarray` since `DLTensor` is not owned.
macro_rules! impl_dltensor_from_ndarray {
($type:ty, $typecode:expr) => {
impl<'a, D: ndarray::Dimension> From<&'a mut ndarray::Array<$type, D>> for DLTensor {
fn from(arr: &'a mut ndarray::Array<$type, D>) -> Self {
DLTensor {
inner: _DLTensor {
data: arr.as_mut_ptr() as *mut c_void,
ctx: DLContext {
device_type: DLDeviceType_kDLCPU,
device_id: 0,
},
ndim: arr.ndim() as c_int,
dtype: DLDataType {
code: $typecode as u8,
bits: 8 * mem::size_of::<$type>() as u8,
lanes: 1,
},
shape: arr.shape().as_ptr() as *const i64 as *mut i64,
strides: arr.strides().as_ptr() as *const isize as *mut i64,
byte_offset: 0,
},
}
}
}
};
}
impl_dltensor_from_ndarray!(f32, DLDataTypeCode_kDLFloat);
impl_dltensor_from_ndarray!(f64, DLDataTypeCode_kDLFloat);
impl_dltensor_from_ndarray!(i32, DLDataTypeCode_kDLInt);
impl_dltensor_from_ndarray!(i64, DLDataTypeCode_kDLInt);
impl_dltensor_from_ndarray!(u32, DLDataTypeCode_kDLUInt);
impl_dltensor_from_ndarray!(u64, DLDataTypeCode_kDLUInt);
impl_tensor_from_ndarray!(f32, DLDataTypeCode_kDLFloat);
impl_tensor_from_ndarray!(f64, DLDataTypeCode_kDLFloat);
impl_tensor_from_ndarray!(i32, DLDataTypeCode_kDLInt);
......
#[cfg(target_env = "sgx")]
use alloc::alloc;
#[cfg(not(target_env = "sgx"))]
use std::alloc;
use std::num;
use crate::common::errors as common_errors;
use ndarray;
use serde_json;
error_chain! {
errors {
GraphFormatError(msg: String) {
description("unable to load graph")
display("could not load graph json: {}", msg)
}
LoadGraphParamsError(msg: String) {
description("unable to load graph params")
display("could not load graph params: {}", msg)
}
}
foreign_links {
Alloc(alloc::AllocErr);
GraphDeserialize(serde_json::Error);
ParseInt(num::ParseIntError);
ShapeError(ndarray::ShapeError);
CommonError(common_errors::Error);
}
#[derive(Debug, Fail)]
pub enum GraphFormatError {
#[fail(display = "Could not parse graph json")]
Parse(#[fail(cause)] failure::Error),
#[fail(display = "Could not parse graph params")]
Params,
#[fail(display = "{} is missing attr: {}", 0, 1)]
MissingAttr(String, String),
#[fail(display = "Missing field: {}", 0)]
MissingField(&'static str),
#[fail(display = "Invalid DLType: {}", 0)]
InvalidDLType(String),
}
impl From<alloc::LayoutErr> for Error {
fn from(_err: alloc::LayoutErr) -> Error {
Error::from_kind(ErrorKind::Msg("Layout error".to_string()))
}
#[derive(Debug, Fail)]
#[fail(display = "SGX error: 0x{:x}", code)]
pub struct SgxError {
pub code: u32,
}
......@@ -14,7 +14,6 @@
allocator_api,
box_syntax,
fn_traits,
try_from,
unboxed_closures,
vec_remove_item
)]
......@@ -25,7 +24,7 @@ extern crate bounded_spsc_queue;
#[cfg(target_env = "sgx")]
extern crate core;
#[macro_use]
extern crate error_chain;
extern crate failure;
#[macro_use]
extern crate itertools;
#[macro_use]
......@@ -39,36 +38,45 @@ extern crate serde;
#[macro_use]
extern crate serde_derive;
extern crate serde_json;
extern crate tvm_common as common;
extern crate tvm_common;
mod allocator;
mod array;
pub mod errors;
mod module;
#[macro_use]
mod packed_func;
mod graph;
mod module;
#[cfg(target_env = "sgx")]
#[macro_use]
pub mod sgx;
mod threading;
mod workspace;
pub use crate::common::{errors::*, ffi, TVMArgValue, TVMRetValue};
pub use self::{
array::*, errors::*, graph::*, module::*, packed_func::*, threading::*, workspace::*,
pub use tvm_common::{
call_packed,
errors::*,
ffi::{self, DLTensor},
packed_func::{self, *},
TVMArgValue, TVMRetValue,
};
#[cfg(target_env = "sgx")]
use self::sgx::ocall_packed_func;
pub use self::{array::*, errors::*, graph::*, module::*, threading::*, workspace::*};
lazy_static! {
static ref LAST_ERROR: std::sync::RwLock<Option<&'static std::ffi::CStr>> =
std::sync::RwLock::new(None);
}
#[no_mangle]
pub extern "C" fn TVMAPISetLastError(cmsg: *const i8) {
#[cfg(not(target_env = "sgx"))]
unsafe {
panic!(std::ffi::CStr::from_ptr(cmsg).to_str().unwrap());
}
*LAST_ERROR.write().unwrap() = Some(unsafe { std::ffi::CStr::from_ptr(cmsg) });
#[cfg(target_env = "sgx")]
ocall_packed!("__sgx_set_last_error__", cmsg);
}
#[no_mangle]
pub extern "C" fn TVMGetLastError() -> *const std::os::raw::c_char {
match *LAST_ERROR.read().unwrap() {
Some(err) => err.as_ptr(),
None => std::ptr::null(),
}
}
......@@ -2,29 +2,29 @@ use std::{
collections::HashMap, convert::AsRef, ffi::CStr, os::raw::c_char, string::String, sync::Mutex,
};
use crate::{
ffi::runtime::BackendPackedCFunc,
packed_func::{wrap_backend_packed_func, PackedFunc},
use tvm_common::{
ffi::BackendPackedCFunc,
packed_func::{PackedFunc, TVMArgValue, TVMRetValue, TVMValue},
};
pub trait Module {
fn get_function<S: AsRef<str>>(&self, name: S) -> Option<PackedFunc>;
fn get_function<S: AsRef<str>>(&self, name: S) -> Option<&(dyn PackedFunc)>;
}
pub struct SystemLibModule;
lazy_static! {
static ref SYSTEM_LIB_FUNCTIONS: Mutex<HashMap<String, BackendPackedCFunc>> =
static ref SYSTEM_LIB_FUNCTIONS: Mutex<HashMap<String, &'static (dyn PackedFunc)>> =
Mutex::new(HashMap::new());
}
impl Module for SystemLibModule {
fn get_function<S: AsRef<str>>(&self, name: S) -> Option<PackedFunc> {
fn get_function<S: AsRef<str>>(&self, name: S) -> Option<&(dyn PackedFunc)> {
SYSTEM_LIB_FUNCTIONS
.lock()
.unwrap()
.get(name.as_ref())
.map(|func| wrap_backend_packed_func(func.to_owned()))
.map(|f| *f)
}
}
......@@ -34,15 +34,42 @@ impl Default for SystemLibModule {
}
}
// @see `WrapPackedFunc` in `llvm_module.cc`.
pub(super) fn wrap_backend_packed_func(
func_name: String,
func: BackendPackedCFunc,
) -> Box<dyn PackedFunc> {
box move |args: &[TVMArgValue]| {
let exit_code = func(
args.iter()
.map(|ref arg| arg.value)
.collect::<Vec<TVMValue>>()
.as_ptr(),
args.iter()
.map(|ref arg| arg.type_code as i32)
.collect::<Vec<i32>>()
.as_ptr() as *const i32,
args.len() as i32,
);
if exit_code == 0 {
Ok(TVMRetValue::default())
} else {
Err(tvm_common::errors::FuncCallError::get_with_context(
func_name.clone(),
))
}
}
}
#[no_mangle]
pub extern "C" fn TVMBackendRegisterSystemLibSymbol(
cname: *const c_char,
func: BackendPackedCFunc,
) -> i32 {
let name = unsafe { CStr::from_ptr(cname).to_str().unwrap() };
SYSTEM_LIB_FUNCTIONS
.lock()
.unwrap()
.insert(name.to_string(), func);
SYSTEM_LIB_FUNCTIONS.lock().unwrap().insert(
name.to_string(),
&*Box::leak(wrap_backend_packed_func(name.to_string(), func)),
);
return 0;
}
use std::{convert::TryFrom, marker::PhantomData, os::raw::c_void};
use super::Tensor;
use crate::ffi::runtime::{
BackendPackedCFunc, DLTensor as _DLTensor, TVMTypeCode_kArrayHandle,
TVMTypeCode_kNDArrayContainer, TVMValue as _TVMValue,
};
use super::DLTensor;
use crate::{
common::{TVMArgValue, TVMRetValue, TVMTypeCode, TVMValue},
errors::*,
};
pub type PackedFunc = Box<Fn(&[TVMArgValue]) -> TVMRetValue + Send + Sync>;
/// Calls a packed function and returns a `TVMRetValue`.
///
/// # Example
///
/// `call_packed!(my_tvm_func, &mut arg1, &mut arg2)`
#[macro_export]
macro_rules! call_packed {
($fn:expr, $($args:expr),+) => {
$fn(&[$($args.into(),)+])
};
($fn:expr) => {
$fn(&Vec::new())
};
}
impl<'a> From<&'a DLTensor> for TVMArgValue<'a> {
fn from(arr: &'a DLTensor) -> Self {
let raw = _TVMValue {
v_handle: arr as *const _ as *mut DLTensor as *mut c_void,
};
TVMArgValue {
value: TVMValue::new(raw),
type_code: TVMTypeCode::kArrayHandle,
lifetime: PhantomData,
}
}
}
impl<'a> From<&'a mut DLTensor> for TVMArgValue<'a> {
fn from(arr: &'a mut DLTensor) -> Self {
let raw = _TVMValue {
v_handle: arr as *mut _ as *mut c_void,
};
TVMArgValue {
value: TVMValue::new(raw),
type_code: TVMTypeCode::kArrayHandle,
lifetime: PhantomData,
}
}
}
impl<'a> TryFrom<TVMArgValue<'a>> for Tensor<'a> {
type Error = Error;
fn try_from(val: TVMArgValue<'a>) -> Result<Self> {
ensure!(
val.type_code == TVMTypeCode::kArrayHandle
|| val.type_code == TVMTypeCode::kNDArrayContainer,
"Could not downcast arg. Expected `{}` or `{}`, but got `{}`",
TVMTypeCode::kArrayHandle,
TVMTypeCode::kNDArrayContainer,
val.type_code,
);
let dlt = unsafe { *(val.value.v_handle as *mut _DLTensor as *const _DLTensor) };
Ok(DLTensor::new(dlt).into())
}
}
impl<'a, 't> From<&'t Tensor<'a>> for TVMRetValue {
fn from(val: &'t Tensor<'a>) -> Self {
TVMRetValue {
prim_value: 0,
box_value: box DLTensor::from(val),
type_code: TVMTypeCode::kNDArrayContainer,
}
}
}
impl<'a> TryFrom<TVMRetValue> for Tensor<'a> {
type Error = Error;
fn try_from(ret: TVMRetValue) -> Result<Self> {
ensure!(
ret.type_code == TVMTypeCode::kArrayHandle
|| ret.type_code == TVMTypeCode::kNDArrayContainer,
"Could not downcast arg. Expected `{}` or `{}`, but got `{}`",
TVMTypeCode_kArrayHandle,
TVMTypeCode_kNDArrayContainer,
ret.type_code,
);
let dlt = unsafe { *(ret.prim_value as *mut _DLTensor as *const _DLTensor) };
Ok(DLTensor::new(dlt).into())
}
}
// @see `WrapPackedFunc` in `llvm_module.cc`.
pub(crate) fn wrap_backend_packed_func(func: BackendPackedCFunc) -> PackedFunc {
box move |args: &[TVMArgValue]| {
func(
args.iter()
.map(|ref arg| arg.value.inner)
.collect::<Vec<_TVMValue>>()
.as_ptr(),
args.iter()
.map(|ref arg| arg.type_code as i32)
.collect::<Vec<i32>>()
.as_ptr() as *const i32,
args.len() as i32,
);
TVMRetValue::default()
}
}
......@@ -3,18 +3,17 @@ use std::{
os::raw::{c_char, c_int},
};
use errors::Result;
use ffi::runtime::TVMValue;
use runtime::{threading::sgx_join_threads, SystemLibModule, TVMArgValue, TVMRetValue};
pub use runtime::threading::tvm_run_worker as run_worker;
pub use crate::threading::tvm_run_worker as run_worker;
use crate::{threading::sgx_join_threads, SystemLibModule, TVMArgValue, TVMRetValue};
use errors::SgxError;
use ffi::TVMValue;
#[macro_export]
macro_rules! tvm_ocall {
($func: expr) => {
match $func {
0 => Ok(()),
err => Err(format!("SGX error: {}", err)),
code => Err(SgxError { code }),
}
};
}
......@@ -33,7 +32,10 @@ extern "C" {
) -> SgxStatus;
}
pub fn ocall_packed_func<S: AsRef<str>>(fn_name: S, args: &[TVMArgValue]) -> Result<TVMRetValue> {
pub fn ocall_packed_func<S: AsRef<str>>(
fn_name: S,
args: &[TVMArgValue],
) -> Result<TVMRetValue, SgxError> {
let mut ret_val = TVMValue { v_int64: 0 };
let ret_type_code = 0i64;
unsafe {
......@@ -58,11 +60,11 @@ pub fn ocall_packed_func<S: AsRef<str>>(fn_name: S, args: &[TVMArgValue]) -> Res
#[macro_export]
macro_rules! ocall_packed {
($fn_name:expr, $($args:expr),+) => {
ocall_packed_func($fn_name, &[$($args.into(),)+])
$crate::sgx::ocall_packed_func($fn_name, &[$($args.into(),)+])
.expect(concat!("Error calling `", $fn_name, "`"))
};
($fn_name:expr) => {
ocall_packed_func($fn_name, &Vec::new())
$crate::sgx::ocall_packed_func($fn_name, &Vec::new())
.expect(concat!("Error calling `", $fn_name, "`"))
}
}
......
use std::{
os::raw::{c_int, c_void},
sync::{
atomic::{AtomicUsize, Ordering, ATOMIC_USIZE_INIT},
atomic::{AtomicUsize, Ordering},
Arc, Barrier,
},
};
......@@ -18,11 +18,10 @@ use std::{
use std::{collections::VecDeque, ptr, sync::Mutex};
use bounded_spsc_queue::{self, Producer};
use crate::{errors::*, ffi::runtime::TVMParallelGroupEnv};
use tvm_common::ffi::TVMParallelGroupEnv;
#[cfg(target_env = "sgx")]
use super::{sgx::ocall_packed_func, TVMArgValue, TVMRetValue};
use super::{TVMArgValue, TVMRetValue};
type FTVMParallelLambda =
extern "C" fn(task_id: usize, penv: *const TVMParallelGroupEnv, cdata: *const c_void) -> i32;
......@@ -62,12 +61,11 @@ impl Job {
}
/// Waits for all tasks in this `Job` to be completed.
fn wait(&self) -> Result<()> {
fn wait(&self) {
while self.pending.load(Ordering::Acquire) > 0 {
#[cfg(not(target_env = "sgx"))]
thread::yield_now();
}
Ok(())
}
}
......@@ -161,7 +159,7 @@ impl ThreadPool {
}
tasks.pop().unwrap()();
job.wait().unwrap();
job.wait();
}
fn run_worker(queue: Consumer<Task>) {
......@@ -251,7 +249,7 @@ pub extern "C" fn TVMBackendParallelLaunch(
cb: cb,
cdata: cdata,
req_num_tasks: num_task,
pending: Arc::new(ATOMIC_USIZE_INIT),
pending: Arc::new(AtomicUsize::new(0)),
});
});
}
......@@ -273,7 +271,7 @@ pub(crate) fn sgx_join_threads() {
cb: poison_pill,
cdata: ptr::null(),
req_num_tasks: 0,
pending: Arc::new(ATOMIC_USIZE_INIT),
pending: Arc::new(AtomicUsize::new(0)),
});
});
ocall_packed!("__sgx_thread_group_join__", 0);
......@@ -322,8 +320,8 @@ mod tests {
#[test]
fn test_parallel_launch() {
TVMBackendParallelLaunch(flambda, ptr::null(), 6);
let counter = ATOMIC_USIZE_INIT;
let task_ids_sum = ATOMIC_USIZE_INIT;
let counter = AtomicUsize::new(0);
let task_ids_sum = AtomicUsize::new(0);
let cdata = (counter, task_ids_sum);
let num_tasks = 3;
TVMBackendParallelLaunch(flambda, &cdata as *const _ as *const c_void, num_tasks);
......
......@@ -4,8 +4,9 @@ use std::{
ptr,
};
use super::allocator::Allocation;
use crate::errors::*;
use failure::Error;
use crate::allocator::Allocation;
const WS_ALIGN: usize = 64; // taken from `kTempAllocaAlignment` in `device_api.h`
......@@ -24,13 +25,13 @@ impl WorkspacePool {
}
}
fn alloc_new(&mut self, size: usize) -> Result<*mut u8> {
fn alloc_new(&mut self, size: usize) -> Result<*mut u8, Error> {
self.workspaces.push(Allocation::new(size, Some(WS_ALIGN))?);
self.in_use.push(self.workspaces.len() - 1);
Ok(self.workspaces[self.workspaces.len() - 1].as_mut_ptr())
}
fn alloc(&mut self, size: usize) -> Result<*mut u8> {
fn alloc(&mut self, size: usize) -> Result<*mut u8, Error> {
if self.free.len() == 0 {
return self.alloc_new(size);
}
......@@ -60,7 +61,7 @@ impl WorkspacePool {
}
}
fn free(&mut self, ptr: *mut u8) -> Result<()> {
fn free(&mut self, ptr: *mut u8) -> Result<(), Error> {
let mut ws_idx = None;
for i in 0..self.in_use.len() {
let idx = self.in_use[i];
......@@ -72,7 +73,7 @@ impl WorkspacePool {
}
Ok(self
.free
.push(ws_idx.ok_or("Tried to free nonexistent workspace.")?))
.push(ws_idx.ok_or(format_err!("Tried to free nonexistent workspace."))?))
}
}
......
......@@ -5,7 +5,7 @@ license = "Apache-2.0"
authors = ["TVM Contributors"]
[dependencies]
ndarray = "0.11.2"
ndarray="0.12.1"
serde = "1.0.59"
serde_json = "1.0.17"
tvm-runtime = { path = "../../" }
......
......@@ -5,7 +5,7 @@ license = "Apache-2.0"
authors = ["TVM Contributors"]
[dependencies]
ndarray = "0.11.2"
ndarray="0.12.1"
tvm-runtime = { path = "../../" }
[build-dependencies]
......
......@@ -17,6 +17,6 @@ fn main() {
let mut a_dl: DLTensor = (&mut a).into();
let mut b_dl: DLTensor = (&mut b).into();
let mut c_dl: DLTensor = (&mut c).into();
call_packed!(add, &mut a_dl, &mut b_dl, &mut c_dl);
call_packed!(add, &mut a_dl, &mut b_dl, &mut c_dl).unwrap();
assert!(c.all_close(&e, 1e-8f32));
}
......@@ -14,11 +14,11 @@ cargo fmt -- --check
# test common
cd $RUST_DIR/common
cargo build --features runtime
cargo test --features runtime --tests
cargo build
cargo test --tests
cargo build --features frontend
cargo test --features frontend --tests
cargo build --features bindings
cargo test --features bindings --tests
# test runtime
cd $RUST_DIR/runtime
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment