Commit 4968279f by Nick Hynes Committed by Tianqi Chen

[Rust] Unify types between bindings and pure Rust impl (#2616)

parent 71abe36e
Cargo.lock
target/ target/
**/*.rs.bk *.rs.bk
Cargo.lock
c_runtime_api.rs
target
**/*.rs.bk
Cargo.lock
/tvm-sys/src/bindgen.rs
...@@ -5,9 +5,11 @@ authors = ["TVM Contributors"] ...@@ -5,9 +5,11 @@ authors = ["TVM Contributors"]
license = "Apache-2.0" license = "Apache-2.0"
[features] [features]
runtime = [] bindings = []
frontend = ["tvm-sys"]
[dependencies] [dependencies]
error-chain = { version = "0.12.0", default-features = false } failure = "0.1.5"
tvm-sys = { version = "0.1.0", path = "tvm-sys", optional = true } ndarray = "0.12.1"
[build-dependencies]
bindgen = "0.37.4"
...@@ -3,23 +3,29 @@ extern crate bindgen; ...@@ -3,23 +3,29 @@ extern crate bindgen;
use std::path::PathBuf; use std::path::PathBuf;
fn main() { fn main() {
if cfg!(feature = "bindings") {
println!("cargo:rerun-if-env-changed=TVM_HOME"); println!("cargo:rerun-if-env-changed=TVM_HOME");
println!("cargo:rustc-link-lib=dylib=tvm_runtime"); println!("cargo:rustc-link-lib=dylib=tvm_runtime");
println!("cargo:rustc-link-search={}/build", env!("TVM_HOME")); println!("cargo:rustc-link-search={}/build", env!("TVM_HOME"));
let bindings = bindgen::Builder::default() }
// @see rust-bindgen#550 for `blacklist_type`
bindgen::Builder::default()
.header(format!( .header(format!(
"{}/include/tvm/runtime/c_runtime_api.h", "{}/include/tvm/runtime/c_runtime_api.h",
env!("TVM_HOME") env!("TVM_HOME")
)) ))
.header(format!(
"{}/include/tvm/runtime/c_backend_api.h",
env!("TVM_HOME")
))
.clang_arg(format!("-I{}/3rdparty/dlpack/include/", env!("TVM_HOME"))) .clang_arg(format!("-I{}/3rdparty/dlpack/include/", env!("TVM_HOME")))
.blacklist_type("max_align_t") // @see rust-bindgen#550 .blacklist_type("max_align_t")
.layout_tests(false) .layout_tests(false)
.derive_partialeq(true) .derive_partialeq(true)
.derive_eq(true) .derive_eq(true)
.generate() .generate()
.expect("unable to generate bindings"); .expect("unable to generate bindings")
.write_to_file(PathBuf::from("src/c_runtime_api.rs"))
bindings
.write_to_file(PathBuf::from("src/bindgen.rs"))
.expect("can not write the bindings!"); .expect("can not write the bindings!");
} }
use std::{
any::TypeId,
mem,
os::raw::{c_int, c_void},
};
use crate::ffi::{
DLContext, DLDataType, DLDataTypeCode_kDLFloat, DLDataTypeCode_kDLInt, DLDataTypeCode_kDLUInt,
DLDeviceType_kDLCPU, DLTensor,
};
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct DataType {
pub code: usize,
pub bits: usize,
pub lanes: usize,
}
impl DataType {
/// Returns the number of bytes occupied by an element of this `DataType`.
pub fn itemsize(&self) -> usize {
(self.bits * self.lanes) >> 3
}
/// Returns whether this `DataType` represents primitive type `T`.
pub fn is_type<T: 'static>(&self) -> bool {
if self.lanes != 1 {
return false;
}
let typ = TypeId::of::<T>();
(typ == TypeId::of::<i32>() && self.code == 0 && self.bits == 32)
|| (typ == TypeId::of::<i64>() && self.code == 0 && self.bits == 64)
|| (typ == TypeId::of::<u32>() && self.code == 1 && self.bits == 32)
|| (typ == TypeId::of::<u64>() && self.code == 1 && self.bits == 64)
|| (typ == TypeId::of::<f32>() && self.code == 2 && self.bits == 32)
|| (typ == TypeId::of::<f64>() && self.code == 2 && self.bits == 64)
}
pub fn code(&self) -> usize {
self.code
}
pub fn bits(&self) -> usize {
self.bits
}
pub fn lanes(&self) -> usize {
self.lanes
}
}
impl<'a> From<&'a DataType> for DLDataType {
fn from(dtype: &'a DataType) -> Self {
Self {
code: dtype.code as u8,
bits: dtype.bits as u8,
lanes: dtype.lanes as u16,
}
}
}
impl From<DLDataType> for DataType {
fn from(dtype: DLDataType) -> Self {
Self {
code: dtype.code as usize,
bits: dtype.bits as usize,
lanes: dtype.lanes as usize,
}
}
}
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct TVMContext {
pub device_type: usize,
pub device_id: usize,
}
impl<'a> From<&'a TVMContext> for DLContext {
fn from(ctx: &'a TVMContext) -> Self {
Self {
device_type: ctx.device_type as u32,
device_id: ctx.device_id as i32,
}
}
}
impl Default for TVMContext {
fn default() -> Self {
Self {
device_type: DLDeviceType_kDLCPU as usize,
device_id: 0,
}
}
}
/// `From` conversions to `DLTensor` for `ndarray::Array`.
/// Takes a reference to the `ndarray` since `DLTensor` is not owned.
macro_rules! impl_dltensor_from_ndarray {
($type:ty, $typecode:expr) => {
impl<'a, D: ndarray::Dimension> From<&'a mut ndarray::Array<$type, D>> for DLTensor {
fn from(arr: &'a mut ndarray::Array<$type, D>) -> Self {
DLTensor {
data: arr.as_mut_ptr() as *mut c_void,
ctx: DLContext {
device_type: DLDeviceType_kDLCPU,
device_id: 0,
},
ndim: arr.ndim() as c_int,
dtype: DLDataType {
code: $typecode as u8,
bits: 8 * mem::size_of::<$type>() as u8,
lanes: 1,
},
shape: arr.shape().as_ptr() as *const i64 as *mut i64,
strides: arr.strides().as_ptr() as *const isize as *mut i64,
byte_offset: 0,
}
}
}
};
}
impl_dltensor_from_ndarray!(f32, DLDataTypeCode_kDLFloat);
impl_dltensor_from_ndarray!(f64, DLDataTypeCode_kDLFloat);
impl_dltensor_from_ndarray!(i32, DLDataTypeCode_kDLInt);
impl_dltensor_from_ndarray!(i64, DLDataTypeCode_kDLInt);
impl_dltensor_from_ndarray!(u32, DLDataTypeCode_kDLUInt);
impl_dltensor_from_ndarray!(u64, DLDataTypeCode_kDLUInt);
/* automatically generated by rust-bindgen for TVM revision 6292c78 */
pub const TVM_VERSION: &'static [u8; 8usize] = b"0.5.dev\0";
pub const DLPACK_VERSION: u32 = 8;
pub const _STDINT_H: u32 = 1;
pub const _FEATURES_H: u32 = 1;
pub const _DEFAULT_SOURCE: u32 = 1;
pub const __USE_ISOC11: u32 = 1;
pub const __USE_ISOC99: u32 = 1;
pub const __USE_ISOC95: u32 = 1;
pub const __USE_POSIX_IMPLICITLY: u32 = 1;
pub const _POSIX_SOURCE: u32 = 1;
pub const _POSIX_C_SOURCE: u32 = 200809;
pub const __USE_POSIX: u32 = 1;
pub const __USE_POSIX2: u32 = 1;
pub const __USE_POSIX199309: u32 = 1;
pub const __USE_POSIX199506: u32 = 1;
pub const __USE_XOPEN2K: u32 = 1;
pub const __USE_XOPEN2K8: u32 = 1;
pub const _ATFILE_SOURCE: u32 = 1;
pub const __USE_MISC: u32 = 1;
pub const __USE_ATFILE: u32 = 1;
pub const __USE_FORTIFY_LEVEL: u32 = 0;
pub const _STDC_PREDEF_H: u32 = 1;
pub const __STDC_IEC_559__: u32 = 1;
pub const __STDC_IEC_559_COMPLEX__: u32 = 1;
pub const __STDC_ISO_10646__: u32 = 201505;
pub const __STDC_NO_THREADS__: u32 = 1;
pub const __GNU_LIBRARY__: u32 = 6;
pub const __GLIBC__: u32 = 2;
pub const __GLIBC_MINOR__: u32 = 23;
pub const _SYS_CDEFS_H: u32 = 1;
pub const __WORDSIZE: u32 = 64;
pub const __WORDSIZE_TIME64_COMPAT32: u32 = 1;
pub const __SYSCALL_WORDSIZE: u32 = 64;
pub const _BITS_WCHAR_H: u32 = 1;
pub const INT8_MIN: i32 = -128;
pub const INT16_MIN: i32 = -32768;
pub const INT32_MIN: i32 = -2147483648;
pub const INT8_MAX: u32 = 127;
pub const INT16_MAX: u32 = 32767;
pub const INT32_MAX: u32 = 2147483647;
pub const UINT8_MAX: u32 = 255;
pub const UINT16_MAX: u32 = 65535;
pub const UINT32_MAX: u32 = 4294967295;
pub const INT_LEAST8_MIN: i32 = -128;
pub const INT_LEAST16_MIN: i32 = -32768;
pub const INT_LEAST32_MIN: i32 = -2147483648;
pub const INT_LEAST8_MAX: u32 = 127;
pub const INT_LEAST16_MAX: u32 = 32767;
pub const INT_LEAST32_MAX: u32 = 2147483647;
pub const UINT_LEAST8_MAX: u32 = 255;
pub const UINT_LEAST16_MAX: u32 = 65535;
pub const UINT_LEAST32_MAX: u32 = 4294967295;
pub const INT_FAST8_MIN: i32 = -128;
pub const INT_FAST16_MIN: i64 = -9223372036854775808;
pub const INT_FAST32_MIN: i64 = -9223372036854775808;
pub const INT_FAST8_MAX: u32 = 127;
pub const INT_FAST16_MAX: u64 = 9223372036854775807;
pub const INT_FAST32_MAX: u64 = 9223372036854775807;
pub const UINT_FAST8_MAX: u32 = 255;
pub const UINT_FAST16_MAX: i32 = -1;
pub const UINT_FAST32_MAX: i32 = -1;
pub const INTPTR_MIN: i64 = -9223372036854775808;
pub const INTPTR_MAX: u64 = 9223372036854775807;
pub const UINTPTR_MAX: i32 = -1;
pub const PTRDIFF_MIN: i64 = -9223372036854775808;
pub const PTRDIFF_MAX: u64 = 9223372036854775807;
pub const SIG_ATOMIC_MIN: i32 = -2147483648;
pub const SIG_ATOMIC_MAX: u32 = 2147483647;
pub const SIZE_MAX: i32 = -1;
pub const WINT_MIN: u32 = 0;
pub const WINT_MAX: u32 = 4294967295;
pub type int_least8_t = ::std::os::raw::c_schar;
pub type int_least16_t = ::std::os::raw::c_short;
pub type int_least32_t = ::std::os::raw::c_int;
pub type int_least64_t = ::std::os::raw::c_long;
pub type uint_least8_t = ::std::os::raw::c_uchar;
pub type uint_least16_t = ::std::os::raw::c_ushort;
pub type uint_least32_t = ::std::os::raw::c_uint;
pub type uint_least64_t = ::std::os::raw::c_ulong;
pub type int_fast8_t = ::std::os::raw::c_schar;
pub type int_fast16_t = ::std::os::raw::c_long;
pub type int_fast32_t = ::std::os::raw::c_long;
pub type int_fast64_t = ::std::os::raw::c_long;
pub type uint_fast8_t = ::std::os::raw::c_uchar;
pub type uint_fast16_t = ::std::os::raw::c_ulong;
pub type uint_fast32_t = ::std::os::raw::c_ulong;
pub type uint_fast64_t = ::std::os::raw::c_ulong;
pub type intmax_t = ::std::os::raw::c_long;
pub type uintmax_t = ::std::os::raw::c_ulong;
pub type wchar_t = ::std::os::raw::c_int;
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct max_align_t {
pub __clang_max_align_nonce1: ::std::os::raw::c_longlong,
pub __bindgen_padding_0: u64,
pub __clang_max_align_nonce2: f64,
}
pub const DLDeviceType_kDLCPU: DLDeviceType = 1;
pub const DLDeviceType_kDLGPU: DLDeviceType = 2;
pub const DLDeviceType_kDLCPUPinned: DLDeviceType = 3;
pub const DLDeviceType_kDLOpenCL: DLDeviceType = 4;
pub const DLDeviceType_kDLMetal: DLDeviceType = 8;
pub const DLDeviceType_kDLVPI: DLDeviceType = 9;
pub const DLDeviceType_kDLROCM: DLDeviceType = 10;
/// \brief The device type in DLContext.
pub type DLDeviceType = u32;
/// \brief A Device context for Tensor and operator.
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct DLContext {
/// \brief The device type used in the device.
pub device_type: DLDeviceType,
/// \brief The device index
pub device_id: ::std::os::raw::c_int,
}
pub const DLDataTypeCode_kDLInt: DLDataTypeCode = 0;
pub const DLDataTypeCode_kDLUInt: DLDataTypeCode = 1;
pub const DLDataTypeCode_kDLFloat: DLDataTypeCode = 2;
/// \brief The type code options DLDataType.
pub type DLDataTypeCode = u32;
/// \brief The data type the tensor can hold.
///
/// Examples
/// - float: type_code = 2, bits = 32, lanes=1
/// - float4(vectorized 4 float): type_code = 2, bits = 32, lanes=4
/// - int8: type_code = 0, bits = 8, lanes=1
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct DLDataType {
/// \brief Type code of base types.
/// We keep it uint8_t instead of DLDataTypeCode for minimal memory
/// footprint, but the value should be one of DLDataTypeCode enum values.
///
pub code: u8,
/// \brief Number of bits, common choices are 8, 16, 32.
pub bits: u8,
/// \brief Number of lanes in the type, used for vector types.
pub lanes: u16,
}
/// \brief Plain C Tensor object, does not manage memory.
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct DLTensor {
/// \brief The opaque data pointer points to the allocated data.
/// This will be CUDA device pointer or cl_mem handle in OpenCL.
/// This pointer is always aligns to 256 bytes as in CUDA.
pub data: *mut ::std::os::raw::c_void,
/// \brief The device context of the tensor
pub ctx: DLContext,
/// \brief Number of dimensions
pub ndim: ::std::os::raw::c_int,
/// \brief The data type of the pointer
pub dtype: DLDataType,
/// \brief The shape of the tensor
pub shape: *mut i64,
/// \brief strides of the tensor,
/// can be NULL, indicating tensor is compact.
pub strides: *mut i64,
/// \brief The offset in bytes to the beginning pointer to data
pub byte_offset: u64,
}
/// \brief C Tensor object, manage memory of DLTensor. This data structure is
/// intended to faciliate the borrowing of DLTensor by another framework. It is
/// not meant to transfer the tensor. When the borrowing framework doesn't need
/// the tensor, it should call the deleter to notify the host that the resource
/// is no longer needed.
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct DLManagedTensor {
/// \brief DLTensor which is being memory managed
pub dl_tensor: DLTensor,
/// \brief the context of the original host framework of DLManagedTensor in
/// which DLManagedTensor is used in the framework. It can also be NULL.
pub manager_ctx: *mut ::std::os::raw::c_void,
/// \brief Destructor signature void (*)(void*) - this should be called
/// to destruct manager_ctx which holds the DLManagedTensor. It can be NULL
/// if there is no way for the caller to provide a reasonable destructor.
pub deleter: ::std::option::Option<unsafe extern "C" fn(self_: *mut DLManagedTensor)>,
}
/// \brief type of array index.
pub type tvm_index_t = i64;
pub const TVMDeviceExtType_kDLAOCL: TVMDeviceExtType = 5;
pub const TVMDeviceExtType_kDLSDAccel: TVMDeviceExtType = 6;
pub const TVMDeviceExtType_kDLVulkan: TVMDeviceExtType = 7;
pub const TVMDeviceExtType_kOpenGL: TVMDeviceExtType = 11;
pub const TVMDeviceExtType_kExtDev: TVMDeviceExtType = 12;
/// \brief Extension device types in TVM
pub type TVMDeviceExtType = u32;
pub const TVMTypeCode_kHandle: TVMTypeCode = 3;
pub const TVMTypeCode_kNull: TVMTypeCode = 4;
pub const TVMTypeCode_kTVMType: TVMTypeCode = 5;
pub const TVMTypeCode_kTVMContext: TVMTypeCode = 6;
pub const TVMTypeCode_kArrayHandle: TVMTypeCode = 7;
pub const TVMTypeCode_kNodeHandle: TVMTypeCode = 8;
pub const TVMTypeCode_kModuleHandle: TVMTypeCode = 9;
pub const TVMTypeCode_kFuncHandle: TVMTypeCode = 10;
pub const TVMTypeCode_kStr: TVMTypeCode = 11;
pub const TVMTypeCode_kBytes: TVMTypeCode = 12;
pub const TVMTypeCode_kNDArrayContainer: TVMTypeCode = 13;
pub const TVMTypeCode_kExtBegin: TVMTypeCode = 15;
pub const TVMTypeCode_kNNVMFirst: TVMTypeCode = 16;
pub const TVMTypeCode_kNNVMLast: TVMTypeCode = 20;
pub const TVMTypeCode_kExtReserveEnd: TVMTypeCode = 64;
pub const TVMTypeCode_kExtEnd: TVMTypeCode = 128;
/// \brief The type code in TVMType
/// \note TVMType is used in two places.
pub type TVMTypeCode = u32;
/// \brief The data type used in TVM Runtime.
///
/// Examples
/// - float: type_code = 2, bits = 32, lanes=1
/// - float4(vectorized 4 float): type_code = 2, bits = 32, lanes=4
/// - int8: type_code = 0, bits = 8, lanes=1
///
/// \note Arguments TVM API function always takes bits=64 and lanes=1
pub type TVMType = DLDataType;
/// \brief The Device information, abstract away common device types.
pub type TVMContext = DLContext;
/// \brief The tensor array stucture to TVM API.
pub type TVMArray = DLTensor;
/// \brief the array handle
pub type TVMArrayHandle = *mut TVMArray;
/// \brief Union type of values
/// being passed through API and function calls.
#[repr(C)]
#[derive(Copy, Clone)]
pub union TVMValue {
pub v_int64: i64,
pub v_float64: f64,
pub v_handle: *mut ::std::os::raw::c_void,
pub v_str: *const ::std::os::raw::c_char,
pub v_type: TVMType,
pub v_ctx: TVMContext,
_bindgen_union_align: u64,
}
/// \brief Byte array type used to pass in byte array
/// When kBytes is used as data type.
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct TVMByteArray {
pub data: *const ::std::os::raw::c_char,
pub size: usize,
}
/// \brief Handle to TVM runtime modules.
pub type TVMModuleHandle = *mut ::std::os::raw::c_void;
/// \brief Handle to packed function handle.
pub type TVMFunctionHandle = *mut ::std::os::raw::c_void;
/// \brief Handle to hold return value.
pub type TVMRetValueHandle = *mut ::std::os::raw::c_void;
/// \brief The stream that is specific to device
/// can be NULL, which indicates the default one.
pub type TVMStreamHandle = *mut ::std::os::raw::c_void;
extern "C" {
/// \brief Used for implementing C API function.
/// Set last error message before return.
/// \param msg The error message to be set.
pub fn TVMAPISetLastError(msg: *const ::std::os::raw::c_char);
}
extern "C" {
/// \brief return str message of the last error
/// all function in this file will return 0 when success
/// and -1 when an error occured,
/// TVMGetLastError can be called to retrieve the error
///
/// this function is threadsafe and can be called by different thread
/// \return error info
pub fn TVMGetLastError() -> *const ::std::os::raw::c_char;
}
extern "C" {
/// \brief Load module from file.
/// \param file_name The file name to load the module from.
/// \param format The format of the module.
/// \param out The result module
///
/// \return 0 when success, -1 when failure happens
/// \note The resulting module do not contain import relation.
/// It can be reconstructed by TVMModImport.
pub fn TVMModLoadFromFile(
file_name: *const ::std::os::raw::c_char,
format: *const ::std::os::raw::c_char,
out: *mut TVMModuleHandle,
) -> ::std::os::raw::c_int;
}
extern "C" {
/// \brief Add dep to mod's dependency.
/// This allows functions in this module to use modules.
///
/// \param mod The module handle.
/// \param dep The dependent module to be imported.
/// \return 0 when success, -1 when failure happens
pub fn TVMModImport(mod_: TVMModuleHandle, dep: TVMModuleHandle) -> ::std::os::raw::c_int;
}
extern "C" {
/// \brief Get function from the module.
/// \param mod The module handle.
/// \param func_name The name of the function.
/// \param query_imports Whether to query imported modules
/// \param out The result function, can be NULL if it is not available.
/// \return 0 when no error is thrown, -1 when failure happens
pub fn TVMModGetFunction(
mod_: TVMModuleHandle,
func_name: *const ::std::os::raw::c_char,
query_imports: ::std::os::raw::c_int,
out: *mut TVMFunctionHandle,
) -> ::std::os::raw::c_int;
}
extern "C" {
/// \brief Free front-end extension type resource.
/// \param handle The extension handle.
/// \param type_code The type of of the extension type.
/// \return 0 when success, -1 when failure happens
pub fn TVMExtTypeFree(
handle: *mut ::std::os::raw::c_void,
type_code: ::std::os::raw::c_int,
) -> ::std::os::raw::c_int;
}
extern "C" {
/// \brief Free the Module
/// \param mod The module to be freed.
///
/// \note This may not free up the module's resources.
/// If there is active TVMFunctionHandle uses the module
/// Or if this module is imported by another active module.
///
/// The all functions remains valid until TVMFuncFree is called.
/// \return 0 when success, -1 when failure happens
pub fn TVMModFree(mod_: TVMModuleHandle) -> ::std::os::raw::c_int;
}
extern "C" {
/// \brief Free the function when it is no longer needed.
/// \param func The function handle
/// \return 0 when success, -1 when failure happens
pub fn TVMFuncFree(func: TVMFunctionHandle) -> ::std::os::raw::c_int;
}
extern "C" {
/// \brief Call a Packed TVM Function.
///
/// \param func node handle of the function.
/// \param arg_values The arguments
/// \param type_codes The type codes of the arguments
/// \param num_args Number of arguments.
///
/// \param ret_val The return value.
/// \param ret_type_code the type code of return value.
///
/// \return 0 when success, -1 when failure happens
/// \note TVM calls always exchanges with type bits=64, lanes=1
///
/// \note API calls always exchanges with type bits=64, lanes=1
/// If API call returns container handles (e.g. FunctionHandle)
/// these handles should be managed by the front-end.
/// The front-end need to call free function (e.g. TVMFuncFree)
/// to free these handles.
pub fn TVMFuncCall(
func: TVMFunctionHandle,
arg_values: *mut TVMValue,
type_codes: *mut ::std::os::raw::c_int,
num_args: ::std::os::raw::c_int,
ret_val: *mut TVMValue,
ret_type_code: *mut ::std::os::raw::c_int,
) -> ::std::os::raw::c_int;
}
extern "C" {
/// \brief Set the return value of TVMPackedCFunc.
///
/// This function is called by TVMPackedCFunc to set the return value.
/// When this function is not called, the function returns null by default.
///
/// \param ret The return value handle, pass by ret in TVMPackedCFunc
/// \param value The value to be returned.
/// \param type_code The type of the value to be returned.
/// \param num_ret Number of return values, for now only 1 is supported.
pub fn TVMCFuncSetReturn(
ret: TVMRetValueHandle,
value: *mut TVMValue,
type_code: *mut ::std::os::raw::c_int,
num_ret: ::std::os::raw::c_int,
) -> ::std::os::raw::c_int;
}
extern "C" {
/// \brief Inplace translate callback argument value to return value.
/// This is only needed for non-POD arguments.
///
/// \param value The value to be translated.
/// \param code The type code to be translated.
/// \note This function will do a shallow copy when necessary.
///
/// \return 0 when success, -1 when failure happens.
pub fn TVMCbArgToReturn(
value: *mut TVMValue,
code: ::std::os::raw::c_int,
) -> ::std::os::raw::c_int;
}
/// \brief C type of packed function.
///
/// \param args The arguments
/// \param type_codes The type codes of the arguments
/// \param num_args Number of arguments.
/// \param ret The return value handle.
/// \param resource_handle The handle additional resouce handle from fron-end.
/// \return 0 if success, -1 if failure happens, set error via TVMAPISetLastError.
/// \sa TVMCFuncSetReturn
pub type TVMPackedCFunc = ::std::option::Option<
unsafe extern "C" fn(
args: *mut TVMValue,
type_codes: *mut ::std::os::raw::c_int,
num_args: ::std::os::raw::c_int,
ret: TVMRetValueHandle,
resource_handle: *mut ::std::os::raw::c_void,
) -> ::std::os::raw::c_int,
>;
/// \brief C callback to free the resource handle in C packed function.
/// \param resource_handle The handle additional resouce handle from fron-end.
pub type TVMPackedCFuncFinalizer =
::std::option::Option<unsafe extern "C" fn(resource_handle: *mut ::std::os::raw::c_void)>;
/// \brief Signature for extension function declarer.
///
/// TVM call this function to get the extension functions
/// The declarer will call register_func to register function and their name.
///
/// \param register_func_handle The register function
/// \return 0 if success, -1 if failure happens
pub type TVMExtensionFuncDeclarer = ::std::option::Option<
unsafe extern "C" fn(register_func_handle: TVMFunctionHandle) -> ::std::os::raw::c_int,
>;
extern "C" {
/// \brief Wrap a TVMPackedCFunc to become a FunctionHandle.
///
/// The resource_handle will be managed by TVM API, until the function is no longer used.
///
/// \param func The packed C function.
/// \param resource_handle The resource handle from front-end, can be NULL.
/// \param fin The finalizer on resource handle when the FunctionHandle get freed, can be NULL
/// \param out the result function handle.
/// \return 0 when success, -1 when failure happens
pub fn TVMFuncCreateFromCFunc(
func: TVMPackedCFunc,
resource_handle: *mut ::std::os::raw::c_void,
fin: TVMPackedCFuncFinalizer,
out: *mut TVMFunctionHandle,
) -> ::std::os::raw::c_int;
}
extern "C" {
/// \brief Register the function to runtime's global table.
///
/// The registered function then can be pulled by the backend by the name.
///
/// \param name The name of the function.
/// \param f The function to be registered.
/// \param override Whether allow override already registered function.
pub fn TVMFuncRegisterGlobal(
name: *const ::std::os::raw::c_char,
f: TVMFunctionHandle,
override_: ::std::os::raw::c_int,
) -> ::std::os::raw::c_int;
}
extern "C" {
/// \brief Get a global function.
///
/// \param name The name of the function.
/// \param out the result function pointer, NULL if it does not exist.
///
/// \note The function handle of global function is managed by TVM runtime,
/// So TVMFuncFree is should not be called when it get deleted.
pub fn TVMFuncGetGlobal(
name: *const ::std::os::raw::c_char,
out: *mut TVMFunctionHandle,
) -> ::std::os::raw::c_int;
}
extern "C" {
/// \brief List all the globally registered function name
/// \param out_size The number of functions
/// \param out_array The array of function names.
/// \return 0 when success, -1 when failure happens
pub fn TVMFuncListGlobalNames(
out_size: *mut ::std::os::raw::c_int,
out_array: *mut *mut *const ::std::os::raw::c_char,
) -> ::std::os::raw::c_int;
}
extern "C" {
/// \brief Allocate a nd-array's memory,
/// including space of shape, of given spec.
///
/// \param shape The shape of the array, the data content will be copied to out
/// \param ndim The number of dimension of the array.
/// \param dtype_code The type code of the dtype
/// \param dtype_bits The number of bits of dtype
/// \param dtype_lanes The number of lanes in the dtype.
/// \param device_type The device type of context
/// \param device_id The device id of context.
/// \param out The output handle.
/// \return 0 when success, -1 when failure happens
pub fn TVMArrayAlloc(
shape: *const tvm_index_t,
ndim: ::std::os::raw::c_int,
dtype_code: ::std::os::raw::c_int,
dtype_bits: ::std::os::raw::c_int,
dtype_lanes: ::std::os::raw::c_int,
device_type: ::std::os::raw::c_int,
device_id: ::std::os::raw::c_int,
out: *mut TVMArrayHandle,
) -> ::std::os::raw::c_int;
}
extern "C" {
/// \brief Free the TVM Array.
/// \param handle The array handle to be freed.
/// \return 0 when success, -1 when failure happens
pub fn TVMArrayFree(handle: TVMArrayHandle) -> ::std::os::raw::c_int;
}
extern "C" {
/// \brief Copy array data from CPU byte array.
/// \param handle The array handle.
/// \param data the data pointer
/// \param nbytes The number of bytes to copy.
/// \return 0 when success, -1 when failure happens
pub fn TVMArrayCopyFromBytes(
handle: TVMArrayHandle,
data: *mut ::std::os::raw::c_void,
nbytes: usize,
) -> ::std::os::raw::c_int;
}
extern "C" {
/// \brief Copy array data to CPU byte array.
/// \param handle The array handle.
/// \param data the data pointer
/// \param nbytes The number of bytes to copy.
/// \return 0 when success, -1 when failure happens
pub fn TVMArrayCopyToBytes(
handle: TVMArrayHandle,
data: *mut ::std::os::raw::c_void,
nbytes: usize,
) -> ::std::os::raw::c_int;
}
extern "C" {
/// \brief Copy the array, both from and to must be valid during the copy.
/// \param from The array to be copied from.
/// \param to The target space.
/// \param stream The stream where the copy happens, can be NULL.
/// \return 0 when success, -1 when failure happens
pub fn TVMArrayCopyFromTo(
from: TVMArrayHandle,
to: TVMArrayHandle,
stream: TVMStreamHandle,
) -> ::std::os::raw::c_int;
}
extern "C" {
/// \brief Produce an array from the DLManagedTensor that shares data memory
/// with the DLManagedTensor.
/// \param from The source DLManagedTensor.
/// \param out The output array handle.
/// \return 0 when success, -1 when failure happens
pub fn TVMArrayFromDLPack(
from: *mut DLManagedTensor,
out: *mut TVMArrayHandle,
) -> ::std::os::raw::c_int;
}
extern "C" {
/// \brief Produce a DLMangedTensor from the array that shares data memory with
/// the array.
/// \param from The source array.
/// \param out The DLManagedTensor handle.
/// \return 0 when success, -1 when failure happens
pub fn TVMArrayToDLPack(
from: TVMArrayHandle,
out: *mut *mut DLManagedTensor,
) -> ::std::os::raw::c_int;
}
extern "C" {
/// \brief Delete (free) a DLManagedTensor's data.
/// \param dltensor Pointer to the DLManagedTensor.
pub fn TVMDLManagedTensorCallDeleter(dltensor: *mut DLManagedTensor);
}
extern "C" {
/// \brief Create a new runtime stream.
///
/// \param device_type The device type of context
/// \param device_id The device id of context
/// \param out The new stream handle
/// \return 0 when success, -1 when failure happens
pub fn TVMStreamCreate(
device_type: ::std::os::raw::c_int,
device_id: ::std::os::raw::c_int,
out: *mut TVMStreamHandle,
) -> ::std::os::raw::c_int;
}
extern "C" {
/// \brief Free a created stream handle.
///
/// \param device_type The device type of context
/// \param device_id The device id of context
/// \param stream The stream to be freed
/// \return 0 when success, -1 when failure happens
pub fn TVMStreamFree(
device_type: ::std::os::raw::c_int,
device_id: ::std::os::raw::c_int,
stream: TVMStreamHandle,
) -> ::std::os::raw::c_int;
}
extern "C" {
/// \brief Set the runtime stream of current thread to be stream.
/// The subsequent calls to the same device_type
/// will use the setted stream handle.
/// The specific type of stream is runtime device dependent.
///
/// \param device_type The device type of context
/// \param device_id The device id of context.
/// \param handle The stream handle.
/// \return 0 when success, -1 when failure happens
pub fn TVMSetStream(
device_type: ::std::os::raw::c_int,
device_id: ::std::os::raw::c_int,
handle: TVMStreamHandle,
) -> ::std::os::raw::c_int;
}
extern "C" {
/// \brief Wait until all computations on stream completes.
///
/// \param device_type The device type of context
/// \param device_id The device id of context.
/// \param stream The stream to be synchronized.
/// \return 0 when success, -1 when failure happens
pub fn TVMSynchronize(
device_type: ::std::os::raw::c_int,
device_id: ::std::os::raw::c_int,
stream: TVMStreamHandle,
) -> ::std::os::raw::c_int;
}
extern "C" {
/// \brief Synchronize two streams of execution.
///
/// \param device_type The device type of context
/// \param device_id The device id of context
/// \param src The source stream to synchronize.
/// \param dst The destination stream to synchronize.
/// \return 0 when success, -1 when failure happens
pub fn TVMStreamStreamSynchronize(
device_type: ::std::os::raw::c_int,
device_id: ::std::os::raw::c_int,
src: TVMStreamHandle,
dst: TVMStreamHandle,
) -> ::std::os::raw::c_int;
}
extern "C" {
/// \brief Backend function for modules to get function
/// from its environment mod_node (its imports and global function).
/// The user do should not call TVMFuncFree on func.
///
/// \param mod_node The module handle.
/// \param func_name The name of the function.
/// \param out The result function.
/// \return 0 when no error is thrown, -1 when failure happens
pub fn TVMBackendGetFuncFromEnv(
mod_node: *mut ::std::os::raw::c_void,
func_name: *const ::std::os::raw::c_char,
out: *mut TVMFunctionHandle,
) -> ::std::os::raw::c_int;
}
extern "C" {
/// \brief Backend function to register system-wide library symbol.
///
/// \param name The name of the symbol
/// \param ptr The symbol address.
/// \return 0 when no error is thrown, -1 when failure happens
pub fn TVMBackendRegisterSystemLibSymbol(
name: *const ::std::os::raw::c_char,
ptr: *mut ::std::os::raw::c_void,
) -> ::std::os::raw::c_int;
}
extern "C" {
/// \brief Backend function to allocate temporal workspace.
///
/// \note The result allocate spaced is ensured to be aligned to kTempAllocaAlignment.
///
/// \param nbytes The size of the space requested.
/// \param device_type The device type which the space will be allocated.
/// \param device_id The device id which the space will be allocated.
/// \param dtype_code_hint The type code of the array elements. Only used in
/// certain backends such as OpenGL.
/// \param dtype_bits_hint The type bits of the array elements. Only used in
/// certain backends such as OpenGL.
/// \return nullptr when error is thrown, a valid ptr if success
pub fn TVMBackendAllocWorkspace(
device_type: ::std::os::raw::c_int,
device_id: ::std::os::raw::c_int,
nbytes: u64,
dtype_code_hint: ::std::os::raw::c_int,
dtype_bits_hint: ::std::os::raw::c_int,
) -> *mut ::std::os::raw::c_void;
}
extern "C" {
/// \brief Backend function to free temporal workspace.
///
/// \param ptr The result allocated space pointer.
/// \param device_type The device type which the space will be allocated.
/// \param device_id The device id which the space will be allocated.
/// \return 0 when no error is thrown, -1 when failure happens
///
/// \sa TVMBackendAllocWorkspace
pub fn TVMBackendFreeWorkspace(
device_type: ::std::os::raw::c_int,
device_id: ::std::os::raw::c_int,
ptr: *mut ::std::os::raw::c_void,
) -> ::std::os::raw::c_int;
}
/// \brief Environment for TVM parallel task.
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct TVMParallelGroupEnv {
/// \brief Auxiliary used for synchronization
pub sync_handle: *mut ::std::os::raw::c_void,
/// \brief total amount of task
pub num_task: i32,
}
/// \brief The callback function to execute a parallel lambda
/// \param task_id the task id of the function.
/// \param penv The parallel environment backs the execution.
/// \param cdata The supporting closure data.
pub type FTVMParallelLambda = ::std::option::Option<
unsafe extern "C" fn(
task_id: ::std::os::raw::c_int,
penv: *mut TVMParallelGroupEnv,
cdata: *mut ::std::os::raw::c_void,
) -> ::std::os::raw::c_int,
>;
extern "C" {
/// \brief Backend function for running parallel jobs.
///
/// \param flambda The parallel function to be launched.
/// \param cdata The closure data.
/// \param num_task Number of tasks to launch, can be 0, means launch
/// with all available threads.
///
/// \return 0 when no error is thrown, -1 when failure happens
pub fn TVMBackendParallelLaunch(
flambda: FTVMParallelLambda,
cdata: *mut ::std::os::raw::c_void,
num_task: ::std::os::raw::c_int,
) -> ::std::os::raw::c_int;
}
extern "C" {
/// \brief BSP barrrier between parallel threads
/// \param task_id the task id of the function.
/// \param penv The parallel environment backs the execution.
/// \return 0 when no error is thrown, -1 when failure happens
pub fn TVMBackendParallelBarrier(
task_id: ::std::os::raw::c_int,
penv: *mut TVMParallelGroupEnv,
) -> ::std::os::raw::c_int;
}
extern "C" {
/// \brief Simple static initialization function.
/// Run f once and set handle to be not null.
/// This function is mainly used for test purpose.
///
/// \param handle An global address to indicate f
/// \param f The function to be ran
/// \param cdata The closure data to pass to the function.
/// \param nbytes Number of bytes in the closure data.
/// \return 0 when no error is thrown, -1 when failure happens
pub fn TVMBackendRunOnce(
handle: *mut *mut ::std::os::raw::c_void,
f: ::std::option::Option<
unsafe extern "C" fn(arg1: *mut ::std::os::raw::c_void) -> ::std::os::raw::c_int,
>,
cdata: *mut ::std::os::raw::c_void,
nbytes: ::std::os::raw::c_int,
) -> ::std::os::raw::c_int;
}
//! Error types for `TVMArgValue` and `TVMRetValue` conversions. use std::fmt;
error_chain! { static TYPE_CODE_STRS: [&str; 15] = [
errors { "int",
TryFromTVMArgValueError(expected: String, actual: String) { "uint",
description("mismatched types while converting from TVMArgValue") "float",
display("expected `{}` but given `{}`", expected, actual) "handle",
"null",
"TVMType",
"TVMContext",
"ArrayHandle",
"NodeHandle",
"ModuleHandle",
"FuncHandle",
"str",
"bytes",
"NDArrayContainer",
"ExtBegin",
];
#[derive(Debug, Fail)]
pub struct ValueDowncastError {
actual_type_code: i64,
expected_type_code: i64,
}
impl ValueDowncastError {
pub fn new(actual_type_code: i64, expected_type_code: i64) -> Self {
Self {
actual_type_code,
expected_type_code,
} }
}
}
TryFromTVMRetValueError(expected: String, actual: String) { impl fmt::Display for ValueDowncastError {
description("mismatched types while downcasting TVMRetValue") fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
display("invalid downcast: expected `{}` but given `{}`", expected, actual) write!(
formatter,
"Could not downcast TVMValue: expected `{}` but was {}",
TYPE_CODE_STRS[self.actual_type_code as usize],
TYPE_CODE_STRS[self.expected_type_code as usize]
)
}
}
#[derive(Debug, Fail)]
#[fail(display = "Function call `{}` returned error: {}", context, message)]
pub struct FuncCallError {
context: String,
message: String,
}
impl FuncCallError {
pub fn get_with_context(context: String) -> Self {
Self {
context,
message: unsafe { std::ffi::CStr::from_ptr(crate::ffi::TVMGetLastError()) }
.to_str()
.expect("double fault")
.to_owned(),
} }
} }
} }
// error_chain! {
// errors {
// TryFromTVMRetValueError(expected_type: String, actual_type_code: i64) {
// description("mismatched types while downcasting TVMRetValue")
// display("invalid downcast: expected `{}` but was `{}`",
// expected_type, type_code_to_string(actual_type_code))
// }
// }
// foreign_links {
// IntoString(std::ffi::IntoStringError);
// ParseInt(std::num::ParseIntError);
// Utf8(std::str::Utf8Error);
// }
// }
//! This crate contains the refactored basic components required //! This crate contains the refactored basic components required
//! for `runtime` and `frontend` TVM crates. //! for `runtime` and `frontend` TVM crates.
#![crate_name = "tvm_common"] #![feature(box_syntax, trait_alias)]
#![recursion_limit = "1024"]
#![allow(non_camel_case_types, unused_imports)]
#![feature(box_syntax, try_from)]
#[macro_use] #[macro_use]
extern crate error_chain; extern crate failure;
/// Unified ffi module for both runtime and frontend crates. /// Unified ffi module for both runtime and frontend crates.
pub mod ffi { pub mod ffi {
#![allow(non_camel_case_types, non_snake_case, non_upper_case_globals, unused)] #![allow(non_camel_case_types, non_snake_case, non_upper_case_globals, unused)]
#[cfg(feature = "frontend")]
pub extern crate tvm_sys as ts;
#[cfg(feature = "runtime")]
pub mod runtime {
use std::os::raw::{c_char, c_int, c_void}; use std::os::raw::{c_char, c_int, c_void};
include!(concat!(env!("CARGO_MANIFEST_DIR"), "/src/c_runtime_api.rs")); include!(concat!(env!("CARGO_MANIFEST_DIR"), "/src/c_runtime_api.rs"));
pub type BackendPackedCFunc = extern "C" fn( pub type BackendPackedCFunc =
args: *const TVMValue, extern "C" fn(args: *const TVMValue, type_codes: *const c_int, num_args: c_int) -> c_int;
type_codes: *const c_int,
num_args: c_int,
) -> c_int;
}
} }
pub mod array;
pub mod errors; pub mod errors;
pub mod ty; #[macro_use]
pub mod packed_func;
pub mod value; pub mod value;
pub use errors::*; pub use errors::*;
pub use ty::TVMTypeCode; pub use ffi::{TVMContext, TVMType};
pub use value::{TVMArgValue, TVMRetValue, TVMValue}; pub use packed_func::{TVMArgValue, TVMRetValue};
use std::{any::Any, convert::TryFrom, marker::PhantomData, os::raw::c_void};
use failure::Error;
pub use crate::ffi::TVMValue;
use crate::ffi::*;
pub trait PackedFunc =
Fn(&[TVMArgValue]) -> Result<TVMRetValue, crate::errors::FuncCallError> + Send + Sync;
/// Calls a packed function and returns a `TVMRetValue`.
///
/// # Example
///
/// `call_packed!(my_tvm_func, &mut arg1, &mut arg2)`
#[macro_export]
macro_rules! call_packed {
($fn:expr, $($args:expr),+) => {
$fn(&[$($args.into(),)+])
};
($fn:expr) => {
$fn(&Vec::new())
};
}
/// A borrowed TVMPODValue. Can be constructed using `into()` but the preferred way
/// to obtain a `TVMArgValue` is automatically via `call_packed!`.
#[derive(Clone, Copy)]
pub struct TVMArgValue<'a> {
pub _lifetime: PhantomData<&'a ()>,
pub value: TVMValue,
pub type_code: i64,
}
impl<'a> TVMArgValue<'a> {
pub fn new(value: TVMValue, type_code: i64) -> Self {
TVMArgValue {
_lifetime: PhantomData,
value: value,
type_code: type_code,
}
}
}
#[macro_export]
macro_rules! ensure_type {
($val:ident, $expected_type_code:expr) => {
ensure!(
$val.type_code == $expected_type_code as i64,
$crate::errors::ValueDowncastError::new(
$val.type_code as i64,
$expected_type_code as i64
)
);
};
}
/// Creates a conversion to a `TVMArgValue` for a primitive type and DLDataTypeCode.
macro_rules! impl_prim_tvm_arg {
($type_code:ident, $field:ident, $field_type:ty, [ $( $type:ty ),+ ] ) => {
$(
impl From<$type> for TVMArgValue<'static> {
fn from(val: $type) -> Self {
TVMArgValue {
value: TVMValue { $field: val as $field_type },
type_code: $type_code as i64,
_lifetime: PhantomData,
}
}
}
impl<'a> From<&'a $type> for TVMArgValue<'a> {
fn from(val: &'a $type) -> Self {
TVMArgValue {
value: TVMValue {
$field: val.to_owned() as $field_type,
},
type_code: $type_code as i64,
_lifetime: PhantomData,
}
}
}
impl<'a> TryFrom<TVMArgValue<'a>> for $type {
type Error = Error;
fn try_from(val: TVMArgValue<'a>) -> Result<Self, Self::Error> {
ensure_type!(val, $type_code);
Ok(unsafe { val.value.$field as $type })
}
}
impl<'a> TryFrom<&TVMArgValue<'a>> for $type {
type Error = Error;
fn try_from(val: &TVMArgValue<'a>) -> Result<Self, Self::Error> {
ensure_type!(val, $type_code);
Ok(unsafe { val.value.$field as $type })
}
}
)+
};
}
impl_prim_tvm_arg!(DLDataTypeCode_kDLFloat, v_float64, f64, [f32, f64]);
impl_prim_tvm_arg!(
DLDataTypeCode_kDLInt,
v_int64,
i64,
[i8, i16, i32, i64, isize]
);
impl_prim_tvm_arg!(
DLDataTypeCode_kDLUInt,
v_int64,
i64,
[u8, u16, u32, u64, usize]
);
#[cfg(feature = "bindings")]
// only allow this in bindings because pure-rust can't take ownership of leaked CString
impl<'a> From<&String> for TVMArgValue<'a> {
fn from(string: &String) -> Self {
TVMArgValue {
value: TVMValue {
v_str: std::ffi::CString::new(string.clone()).unwrap().into_raw(),
},
type_code: TVMTypeCode_kStr as i64,
_lifetime: PhantomData,
}
}
}
impl<'a> From<&std::ffi::CString> for TVMArgValue<'a> {
fn from(string: &std::ffi::CString) -> Self {
TVMArgValue {
value: TVMValue {
v_str: string.as_ptr(),
},
type_code: TVMTypeCode_kStr as i64,
_lifetime: PhantomData,
}
}
}
impl<'a> TryFrom<TVMArgValue<'a>> for &str {
type Error = Error;
fn try_from(arg: TVMArgValue<'a>) -> Result<Self, Self::Error> {
ensure_type!(arg, TVMTypeCode_kStr);
Ok(unsafe { std::ffi::CStr::from_ptr(arg.value.v_handle as *const i8) }.to_str()?)
}
}
impl<'a> TryFrom<&TVMArgValue<'a>> for &str {
type Error = Error;
fn try_from(arg: &TVMArgValue<'a>) -> Result<Self, Self::Error> {
ensure_type!(arg, TVMTypeCode_kStr);
Ok(unsafe { std::ffi::CStr::from_ptr(arg.value.v_handle as *const i8) }.to_str()?)
}
}
/// Creates a conversion to a `TVMArgValue` for an object handle.
impl<'a, T> From<*const T> for TVMArgValue<'a> {
fn from(ptr: *const T) -> Self {
TVMArgValue {
value: TVMValue {
v_handle: ptr as *mut T as *mut c_void,
},
type_code: TVMTypeCode_kArrayHandle as i64,
_lifetime: PhantomData,
}
}
}
/// Creates a conversion to a `TVMArgValue` for a mutable object handle.
impl<'a, T> From<*mut T> for TVMArgValue<'a> {
fn from(ptr: *mut T) -> Self {
TVMArgValue {
value: TVMValue {
v_handle: ptr as *mut c_void,
},
type_code: TVMTypeCode_kHandle as i64,
_lifetime: PhantomData,
}
}
}
impl<'a> From<&'a mut DLTensor> for TVMArgValue<'a> {
fn from(arr: &'a mut DLTensor) -> Self {
TVMArgValue {
value: TVMValue {
v_handle: arr as *mut _ as *mut c_void,
},
type_code: TVMTypeCode_kArrayHandle as i64,
_lifetime: PhantomData,
}
}
}
impl<'a> From<&'a DLTensor> for TVMArgValue<'a> {
fn from(arr: &'a DLTensor) -> Self {
TVMArgValue {
value: TVMValue {
v_handle: arr as *const _ as *mut DLTensor as *mut c_void,
},
type_code: TVMTypeCode_kArrayHandle as i64,
_lifetime: PhantomData,
}
}
}
impl<'a, 'v> TryFrom<&'a TVMArgValue<'v>> for TVMType {
type Error = Error;
fn try_from(arg: &'a TVMArgValue<'v>) -> Result<Self, Self::Error> {
ensure_type!(arg, TVMTypeCode_kTVMType);
Ok(unsafe { arg.value.v_type.into() })
}
}
/// An owned TVMPODValue. Can be converted from a variety of primitive and object types.
/// Can be downcasted using `try_from` if it contains the desired type.
///
/// # Example
///
/// ```
/// let a = 42u32;
/// let b: i64 = TVMRetValue::from(a).try_into().unwrap();
///
/// let s = "hello, world!";
/// let t: TVMRetValue = s.into();
/// assert_eq!(String::try_from(t).unwrap(), s);
/// ```
pub struct TVMRetValue {
pub value: TVMValue,
pub box_value: Box<Any>,
pub type_code: i64,
}
impl TVMRetValue {
pub fn from_tvm_value(value: TVMValue, type_code: i64) -> Self {
Self {
value,
type_code,
box_value: box (),
}
}
pub fn into_tvm_value(self) -> (TVMValue, TVMTypeCode) {
(self.value, self.type_code as TVMTypeCode)
}
}
impl Default for TVMRetValue {
fn default() -> Self {
TVMRetValue {
value: TVMValue { v_int64: 0 as i64 },
type_code: 0,
box_value: box (),
}
}
}
macro_rules! impl_pod_ret_value {
($code:expr, [ $( $ty:ty ),+ ] ) => {
$(
impl From<$ty> for TVMRetValue {
fn from(val: $ty) -> Self {
Self {
value: val.into(),
type_code: $code as i64,
box_value: box (),
}
}
}
impl TryFrom<TVMRetValue> for $ty {
type Error = Error;
fn try_from(ret: TVMRetValue) -> Result<$ty, Self::Error> {
ensure_type!(ret, $code);
Ok(ret.value.into())
}
}
)+
};
}
impl_pod_ret_value!(DLDataTypeCode_kDLInt, [i8, i16, i32, i64, isize]);
impl_pod_ret_value!(DLDataTypeCode_kDLUInt, [u8, u16, u32, u64, usize]);
impl_pod_ret_value!(DLDataTypeCode_kDLFloat, [f32, f64]);
impl_pod_ret_value!(TVMTypeCode_kTVMType, [TVMType]);
impl_pod_ret_value!(TVMTypeCode_kTVMContext, [TVMContext]);
impl TryFrom<TVMRetValue> for String {
type Error = Error;
fn try_from(ret: TVMRetValue) -> Result<String, Self::Error> {
ensure_type!(ret, TVMTypeCode_kStr);
let cs = unsafe { std::ffi::CString::from_raw(ret.value.v_handle as *mut i8) };
let ret_str = cs.clone().into_string();
if cfg!(feature = "bindings") {
std::mem::forget(cs); // TVM C++ takes ownership of CString. (@see TVMFuncCall)
}
Ok(ret_str?)
}
}
impl From<String> for TVMRetValue {
fn from(s: String) -> Self {
let cs = std::ffi::CString::new(s).unwrap();
Self {
value: TVMValue {
v_str: cs.into_raw() as *mut i8,
},
box_value: box (),
type_code: TVMTypeCode_kStr as i64,
}
}
}
//! This module containes `TVMTypeCode` and `TVMType` with some conversion methods.
//!
//! # Example
//!
//! ```
//! let dtype = TVMType::from("float");
//! println!("dtype is: {}", dtype);
//! ```
use std::{
ffi::{CStr, CString},
fmt::{self, Display, Formatter},
};
/// TVM type codes.
#[repr(u32)]
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum TVMTypeCode {
kDLInt = 0,
kDLUInt = 1,
kDLFloat = 2,
kHandle = 3,
kNull = 4,
kTVMType = 5,
kTVMContext = 6,
kArrayHandle = 7,
kNodeHandle = 8,
kModuleHandle = 9,
kFuncHandle = 10,
kStr = 11,
kBytes = 12,
kNDArrayContainer = 13,
}
impl Default for TVMTypeCode {
fn default() -> Self {
TVMTypeCode::kDLInt
}
}
impl From<TVMTypeCode> for i64 {
fn from(arg: TVMTypeCode) -> i64 {
match arg {
TVMTypeCode::kDLInt => 0,
TVMTypeCode::kDLUInt => 1,
TVMTypeCode::kDLFloat => 2,
TVMTypeCode::kHandle => 3,
TVMTypeCode::kNull => 4,
TVMTypeCode::kTVMType => 5,
TVMTypeCode::kTVMContext => 6,
TVMTypeCode::kArrayHandle => 7,
TVMTypeCode::kNodeHandle => 8,
TVMTypeCode::kModuleHandle => 9,
TVMTypeCode::kFuncHandle => 10,
TVMTypeCode::kStr => 11,
TVMTypeCode::kBytes => 12,
TVMTypeCode::kNDArrayContainer => 13,
}
}
}
impl Into<TVMTypeCode> for i64 {
fn into(self) -> TVMTypeCode {
match self {
0 => TVMTypeCode::kDLInt,
1 => TVMTypeCode::kDLUInt,
2 => TVMTypeCode::kDLFloat,
3 => TVMTypeCode::kHandle,
4 => TVMTypeCode::kNull,
5 => TVMTypeCode::kTVMType,
6 => TVMTypeCode::kTVMContext,
7 => TVMTypeCode::kArrayHandle,
8 => TVMTypeCode::kNodeHandle,
9 => TVMTypeCode::kModuleHandle,
10 => TVMTypeCode::kFuncHandle,
11 => TVMTypeCode::kStr,
12 => TVMTypeCode::kBytes,
13 => TVMTypeCode::kNDArrayContainer,
_ => unreachable!(),
}
}
}
impl Display for TVMTypeCode {
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
write!(
f,
"{}",
match self {
TVMTypeCode::kDLInt => "int",
TVMTypeCode::kDLUInt => "uint",
TVMTypeCode::kDLFloat => "float",
TVMTypeCode::kHandle => "handle",
TVMTypeCode::kNull => "null",
TVMTypeCode::kTVMType => "TVM type",
TVMTypeCode::kTVMContext => "TVM context",
TVMTypeCode::kArrayHandle => "Array handle",
TVMTypeCode::kNodeHandle => "Node handle",
TVMTypeCode::kModuleHandle => "Module handle",
TVMTypeCode::kFuncHandle => "Function handle",
TVMTypeCode::kStr => "string",
TVMTypeCode::kBytes => "bytes",
TVMTypeCode::kNDArrayContainer => "ndarray container",
}
)
}
}
macro_rules! impl_prim_type {
($type:ty, $variant:ident) => {
impl<'a> From<&'a $type> for TVMTypeCode {
fn from(_arg: &$type) -> Self {
TVMTypeCode::$variant
}
}
impl<'a> From<&'a mut $type> for TVMTypeCode {
fn from(_arg: &mut $type) -> Self {
TVMTypeCode::$variant
}
}
};
}
impl_prim_type!(usize, kDLInt);
impl_prim_type!(i64, kDLInt);
impl_prim_type!(i32, kDLInt);
impl_prim_type!(i16, kDLInt);
impl_prim_type!(i8, kDLInt);
impl_prim_type!(u64, kDLUInt);
impl_prim_type!(u32, kDLUInt);
impl_prim_type!(u16, kDLUInt);
impl_prim_type!(u8, kDLUInt);
impl_prim_type!(f64, kDLFloat);
impl_prim_type!(f32, kDLFloat);
impl_prim_type!(str, kStr);
impl_prim_type!(CStr, kStr);
impl_prim_type!(String, kStr);
impl_prim_type!(CString, kStr);
impl_prim_type!([u8], kBytes);
//! This module provides the the wrapped `TVMValue`, `TVMArgValue` and `TVMRetValue` use std::str::FromStr;
//! required for using TVM functions.
use std::{ use failure::Error;
any::Any,
convert::TryFrom,
ffi::{CStr, CString},
fmt::{self, Debug, Formatter},
marker::PhantomData,
mem,
ops::Deref,
os::raw::{c_char, c_void},
};
#[cfg(feature = "runtime")] use crate::ffi::*;
use ffi::runtime::TVMValue as _TVMValue;
#[cfg(feature = "frontend")] impl TVMType {
use ffi::ts::TVMValue as _TVMValue; fn new(type_code: u8, bits: u8, lanes: u16) -> Self {
Self {
use errors::*; code: type_code,
bits,
use ty::TVMTypeCode; lanes,
/// Wrapped TVMValue type.
#[derive(Clone, Copy)]
pub struct TVMValue {
pub inner: _TVMValue,
}
impl TVMValue {
/// Creates TVMValue from the raw part.
pub fn new(inner: _TVMValue) -> Self {
TVMValue { inner }
}
pub(crate) fn into_raw(self) -> _TVMValue {
self.inner
}
}
impl Debug for TVMValue {
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
unsafe {
write!(
f,
"TVMValue: [v_int64: {:?}], [v_float64: {:?}], [v_handle: {:?}],\
[v_str: {:?}]",
self.inner.v_int64, self.inner.v_float64, self.inner.v_handle, self.inner.v_str
)
} }
} }
} }
impl Deref for TVMValue { /// Implements TVMType conversion from `&str` of general format `{dtype}{bits}x{lanes}`
type Target = _TVMValue; /// such as "int32", "float32" or with lane "float32x1".
fn deref(&self) -> &Self::Target { impl FromStr for TVMType {
&self.inner type Err = Error;
fn from_str(type_str: &str) -> Result<Self, Self::Err> {
if type_str == "bool" {
return Ok(TVMType::new(1, 1, 1));
} }
}
macro_rules! impl_prim_val { let mut type_lanes = type_str.split("x");
($type:ty, $field:ident, $cast:ty) => { let typ = type_lanes.next().expect("Missing dtype");
impl From<$type> for TVMValue { let lanes = type_lanes
fn from(arg: $type) -> Self { .next()
let inner = _TVMValue { .map(|l| <u16>::from_str_radix(l, 10))
$field: arg as $cast, .unwrap_or(Ok(1))?;
}; let (type_name, bits) = match typ.find(char::is_numeric) {
Self::new(inner) Some(idx) => {
let (name, bits_str) = typ.split_at(idx);
(name, u8::from_str_radix(bits_str, 10)?)
} }
} None => (typ, 32),
impl<'a> From<&'a $type> for TVMValue {
fn from(arg: &$type) -> Self {
let inner = _TVMValue {
$field: *arg as $cast,
}; };
Self::new(inner)
}
}
impl<'a> From<&'a mut $type> for TVMValue { let type_code = match type_name {
fn from(arg: &mut $type) -> Self { "int" => 0,
let inner = _TVMValue { "uint" => 1,
$field: *arg as $cast, "float" => 2,
"handle" => 3,
_ => return Err(format_err!("Unknown type {}", type_name)),
}; };
Self::new(inner)
}
}
impl TryFrom<TVMValue> for $type {
type Error = Error;
fn try_from(val: TVMValue) -> Result<Self> {
Ok(unsafe { val.inner.$field as $type })
}
}
impl<'a> TryFrom<&'a TVMValue> for $type { Ok(TVMType::new(type_code, bits, lanes))
type Error = Error;
fn try_from(val: &TVMValue) -> Result<Self> {
Ok(unsafe { val.into_raw().$field as $type })
}
}
impl<'a> TryFrom<&'a mut TVMValue> for $type {
type Error = Error;
fn try_from(val: &mut TVMValue) -> Result<Self> {
Ok(unsafe { val.into_raw().$field as $type })
}
} }
};
} }
impl_prim_val!(isize, v_int64, i64); impl std::fmt::Display for TVMType {
impl_prim_val!(i64, v_int64, i64); fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
impl_prim_val!(i32, v_int64, i64); if self.bits == 1 && self.lanes == 1 {
impl_prim_val!(i16, v_int64, i64); return write!(f, "bool");
impl_prim_val!(i8, v_int64, i64);
impl_prim_val!(usize, v_int64, i64);
impl_prim_val!(u64, v_int64, i64);
impl_prim_val!(u32, v_int64, i64);
impl_prim_val!(u16, v_int64, i64);
impl_prim_val!(u8, v_int64, i64);
impl_prim_val!(f64, v_float64, f64);
impl_prim_val!(f32, v_float64, f64);
impl<'a> From<&'a str> for TVMValue {
fn from(arg: &str) -> TVMValue {
let arg = CString::new(arg).unwrap();
let inner = _TVMValue {
v_str: arg.as_ptr() as *const c_char,
};
mem::forget(arg);
Self::new(inner)
} }
} let mut type_str = match self.code {
0 => "int",
impl<'a> From<&'a String> for TVMValue { 1 => "uint",
fn from(arg: &String) -> TVMValue { 2 => "float",
let arg = CString::new(arg.as_bytes()).unwrap(); 4 => "handle",
let inner = _TVMValue { _ => "unknown",
v_str: arg.as_ptr() as *const c_char,
};
mem::forget(arg);
Self::new(inner)
} }
} .to_string();
impl<'a> From<&'a CString> for TVMValue { type_str += &self.bits.to_string();
fn from(arg: &CString) -> TVMValue { if self.lanes > 1 {
let arg = arg.to_owned(); type_str += &format!("x{}", self.lanes);
let inner = _TVMValue {
v_str: arg.as_ptr() as *const c_char,
};
mem::forget(arg);
Self::new(inner)
} }
} f.write_str(&type_str)
impl<'a> From<&'a [u8]> for TVMValue {
fn from(arg: &[u8]) -> TVMValue {
let arg = arg.to_owned();
let inner = _TVMValue {
v_handle: &arg as *const _ as *mut c_void,
};
mem::forget(arg);
Self::new(inner)
} }
} }
/// Captures both `TVMValue` and `TVMTypeCode` needed for TVM function. macro_rules! impl_pod_tvm_value {
/// The preferred way to obtain a `TVMArgValue` is automatically via `call_packed!`. ($field:ident, $field_ty:ty, $( $ty:ty ),+) => {
/// or in the frontend crate, with `function::Builder`. Checkout the methods for conversions. $(
/// impl From<$ty> for TVMValue {
/// ## Example fn from(val: $ty) -> Self {
/// TVMValue { $field: val as $field_ty }
/// ```
/// let s = "hello".to_string();
/// let arg = TVMArgValue::from(&s);
/// let tvm: String = arg.try_into().unwrap();
/// assert_eq!(arg, s);
/// ```
#[derive(Debug, Clone, Copy)]
pub struct TVMArgValue<'a> {
/// The wrapped TVMValue
pub value: TVMValue,
/// The matching type code.
pub type_code: TVMTypeCode,
/// This is only exposed to runtime and frontend crates and is not meant to be used directly.
pub lifetime: PhantomData<&'a ()>,
}
impl<'a> TVMArgValue<'a> {
pub fn new(value: TVMValue, type_code: TVMTypeCode) -> Self {
TVMArgValue {
value: value,
type_code: type_code,
lifetime: PhantomData,
}
}
}
impl<'a, 'b> TryFrom<&'b TVMArgValue<'a>> for i64 {
type Error = Error;
fn try_from(arg: &TVMArgValue<'a>) -> Result<Self> {
if (arg.type_code == TVMTypeCode::kDLInt)
| (arg.type_code == TVMTypeCode::kDLUInt)
| (arg.type_code == TVMTypeCode::kNull)
{
Ok(unsafe { arg.value.inner.v_int64 })
} else {
bail!(ErrorKind::TryFromTVMArgValueError(
stringify!(i64).to_string(),
arg.type_code.to_string()
))
} }
} }
}
impl<'a, 'b> TryFrom<&'b TVMArgValue<'a>> for f64 { impl From<TVMValue> for $ty {
type Error = Error; fn from(val: TVMValue) -> Self {
fn try_from(arg: &TVMArgValue<'a>) -> Result<Self> { unsafe { val.$field as $ty }
if arg.type_code == TVMTypeCode::kDLFloat {
Ok(unsafe { arg.value.inner.v_float64 })
} else {
bail!(ErrorKind::TryFromTVMArgValueError(
stringify!(f64).to_string(),
arg.type_code.to_string()
))
}
} }
}
impl<'a, 'b> TryFrom<&'b TVMArgValue<'a>> for String {
type Error = Error;
fn try_from(arg: &TVMArgValue<'a>) -> Result<Self> {
if arg.type_code == TVMTypeCode::kStr {
let ret_str = unsafe {
match CStr::from_ptr(arg.value.inner.v_str).to_str() {
Ok(s) => s,
Err(_) => "Invalid UTF-8 message",
} }
)+
}; };
Ok(ret_str.to_string()) ($field:ident, $ty:ty) => {
} else { impl_pod_tvm_value!($field, $ty, $ty);
bail!(ErrorKind::TryFromTVMArgValueError(
stringify!(String).to_string(),
arg.type_code.to_string()
))
}
} }
} }
/// Main way to create a TVMArgValue from suported Rust values. impl_pod_tvm_value!(v_int64, i64, i8, u8, i16, u16, i32, u32, i64, u64, isize, usize);
impl<'b, 'a: 'b, T: 'b + ?Sized> From<&'b T> for TVMArgValue<'a> impl_pod_tvm_value!(v_float64, f64, f32, f64);
where impl_pod_tvm_value!(v_type, TVMType);
TVMValue: From<&'b T>, impl_pod_tvm_value!(v_ctx, TVMContext);
TVMTypeCode: From<&'b T>,
{
fn from(arg: &'b T) -> Self {
TVMArgValue::new(TVMValue::from(arg), TVMTypeCode::from(arg))
}
}
/// Creates a conversion to a `TVMArgValue` for an object handle. macro_rules! impl_tvm_context {
impl<'a, T> From<*const T> for TVMArgValue<'a> { ( $( $dev_type:ident : [ $( $dev_name:ident ),+ ] ),+ ) => {
fn from(ptr: *const T) -> Self { /// Creates a TVMContext from a string (e.g., "cpu", "gpu", "ext_dev")
let value = TVMValue::new(_TVMValue { impl FromStr for TVMContext {
v_handle: ptr as *mut T as *mut c_void, type Err = Error;
}); fn from_str(type_str: &str) -> Result<Self, Self::Err> {
Ok(Self {
TVMArgValue::new(value, TVMTypeCode::kArrayHandle) device_type: match type_str {
$( $( stringify!($dev_name) )|+ => $dev_type ),+,
_ => return Err(format_err!("device {} not supported", type_str).into()),
},
device_id: 0,
})
} }
}
/// Creates a conversion to a `TVMArgValue` for a mutable object handle.
impl<'a, T> From<*mut T> for TVMArgValue<'a> {
fn from(ptr: *mut T) -> Self {
let value = TVMValue::new(_TVMValue {
v_handle: ptr as *mut c_void,
});
TVMArgValue::new(value, TVMTypeCode::kHandle)
} }
}
/// An owned version of TVMPODValue. It can be converted from varieties of
/// primitive and object types.
/// It can be downcasted using `try_from` if it contains the desired type.
///
/// # Example
///
/// ```
/// let a = 42u32;
/// let b: i64 = TVMRetValue::from(a).try_into().unwrap();
///
/// let s = "hello, world!";
/// let t: TVMRetValue = s.into();
/// assert_eq!(String::try_from(t).unwrap(), s);
/// ```
pub struct TVMRetValue {
/// A primitive return value, if any.
pub prim_value: usize,
/// An object return value, if any.
pub box_value: Box<Any>,
pub type_code: TVMTypeCode,
}
impl TVMRetValue { impl TVMContext {
fn new(prim_value: usize, box_value: Box<Any>, type_code: TVMTypeCode) -> Self { $(
$(
pub fn $dev_name(device_id: usize) -> Self {
Self { Self {
prim_value, device_type: $dev_type,
box_value, device_id: device_id as i32,
type_code,
}
}
/// unsafe function to create `TVMRetValue` from `TVMValue` and
/// its matching `TVMTypeCode`.
pub unsafe fn from_tvm_value(value: TVMValue, type_code: TVMTypeCode) -> Self {
let value = value.into_raw();
match type_code {
TVMTypeCode::kDLInt | TVMTypeCode::kDLUInt => {
Self::new(value.v_int64 as usize, box (), type_code)
}
TVMTypeCode::kDLFloat => Self::new(value.v_float64 as usize, box (), type_code),
TVMTypeCode::kHandle
| TVMTypeCode::kArrayHandle
| TVMTypeCode::kNodeHandle
| TVMTypeCode::kModuleHandle
| TVMTypeCode::kFuncHandle => {
Self::new(value.v_handle as usize, box value.v_handle, type_code)
}
TVMTypeCode::kStr | TVMTypeCode::kBytes => {
Self::new(value.v_str as usize, box (value.v_str), type_code)
} }
_ => Self::new(0usize, box (), type_code),
} }
} )+
)+
/// Returns the underlying `TVMValue` and `TVMTypeCode`.
pub fn into_tvm_value(self) -> (TVMValue, TVMTypeCode) {
let val = match self.type_code {
TVMTypeCode::kDLInt | TVMTypeCode::kDLUInt => TVMValue::new(_TVMValue {
v_int64: self.prim_value as i64,
}),
TVMTypeCode::kDLFloat => TVMValue::new(_TVMValue {
v_float64: self.prim_value as f64,
}),
TVMTypeCode::kHandle
| TVMTypeCode::kArrayHandle
| TVMTypeCode::kNodeHandle
| TVMTypeCode::kModuleHandle
| TVMTypeCode::kFuncHandle
| TVMTypeCode::kNDArrayContainer => TVMValue::new(_TVMValue {
v_handle: self.prim_value as *const c_void as *mut c_void,
}),
TVMTypeCode::kStr | TVMTypeCode::kBytes => TVMValue::new(_TVMValue {
v_str: self.prim_value as *const c_char,
}),
_ => unreachable!(),
};
(val, self.type_code)
}
}
impl Default for TVMRetValue {
fn default() -> Self {
TVMRetValue {
prim_value: 0usize,
box_value: box (),
type_code: TVMTypeCode::default(),
}
}
}
impl Clone for TVMRetValue {
fn clone(&self) -> Self {
match self.type_code {
TVMTypeCode::kDLInt | TVMTypeCode::kDLUInt | TVMTypeCode::kDLFloat => {
Self::new(self.prim_value.clone(), box (), self.type_code.clone())
}
TVMTypeCode::kHandle
| TVMTypeCode::kArrayHandle
| TVMTypeCode::kNodeHandle
| TVMTypeCode::kModuleHandle
| TVMTypeCode::kFuncHandle
| TVMTypeCode::kNDArrayContainer => Self::new(
self.prim_value.clone(),
box (self.prim_value.clone() as *const c_void as *mut c_void),
self.type_code.clone(),
),
TVMTypeCode::kStr | TVMTypeCode::kBytes => Self::new(
self.prim_value.clone(),
box (self.prim_value.clone() as *const c_char),
self.type_code.clone(),
),
_ => unreachable!(),
}
}
}
impl Debug for TVMRetValue {
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
write!(
f,
"prim_value: {:?}, box_value: {:?}, type_code: {:?}",
self.prim_value, self.prim_value as *const c_void as *mut c_void, self.type_code
)
}
}
macro_rules! impl_prim_ret_value {
($type:ty, $code:expr) => {
impl From<$type> for TVMRetValue {
fn from(val: $type) -> Self {
TVMRetValue {
prim_value: val as usize,
box_value: box (),
type_code: $code,
}
}
}
impl<'a> From<&'a $type> for TVMRetValue {
fn from(val: &$type) -> Self {
TVMRetValue {
prim_value: *val as usize,
box_value: box (),
type_code: $code,
}
}
}
impl<'a> From<&'a mut $type> for TVMRetValue {
fn from(val: &mut $type) -> Self {
TVMRetValue {
prim_value: *val as usize,
box_value: box (),
type_code: $code,
}
}
}
impl TryFrom<TVMRetValue> for $type {
type Error = Error;
fn try_from(ret: TVMRetValue) -> Result<$type> {
if ret.type_code == $code {
Ok(ret.prim_value as $type)
} else {
bail!(ErrorKind::TryFromTVMRetValueError(
stringify!($type).to_string(),
ret.type_code.to_string(),
))
}
}
}
};
}
impl_prim_ret_value!(i8, TVMTypeCode::kDLInt);
impl_prim_ret_value!(i16, TVMTypeCode::kDLInt);
impl_prim_ret_value!(i32, TVMTypeCode::kDLInt);
impl_prim_ret_value!(i64, TVMTypeCode::kDLInt);
impl_prim_ret_value!(isize, TVMTypeCode::kDLInt);
impl_prim_ret_value!(u8, TVMTypeCode::kDLUInt);
impl_prim_ret_value!(u16, TVMTypeCode::kDLUInt);
impl_prim_ret_value!(u32, TVMTypeCode::kDLUInt);
impl_prim_ret_value!(u64, TVMTypeCode::kDLUInt);
impl_prim_ret_value!(usize, TVMTypeCode::kDLUInt);
impl_prim_ret_value!(f32, TVMTypeCode::kDLFloat);
impl_prim_ret_value!(f64, TVMTypeCode::kDLFloat);
macro_rules! impl_ptr_ret_value {
($type:ty) => {
impl From<$type> for TVMRetValue {
fn from(ptr: $type) -> Self {
TVMRetValue {
prim_value: ptr as usize,
box_value: box (),
type_code: TVMTypeCode::kHandle,
}
}
}
impl TryFrom<TVMRetValue> for $type {
type Error = Error;
fn try_from(ret: TVMRetValue) -> Result<$type> {
if ret.type_code == TVMTypeCode::kHandle {
Ok(ret.prim_value as $type)
} else {
bail!(ErrorKind::TryFromTVMRetValueError(
stringify!($type).to_string(),
ret.type_code.to_string(),
))
}
}
}
};
}
impl_ptr_ret_value!(*const c_void);
impl_ptr_ret_value!(*mut c_void);
impl From<String> for TVMRetValue {
fn from(val: String) -> Self {
let pval = val.as_ptr() as *const c_char as usize;
let bval = box (val.as_ptr() as *const c_char);
mem::forget(val);
TVMRetValue::new(pval, bval, TVMTypeCode::kStr)
}
}
impl TryFrom<TVMRetValue> for String {
type Error = Error;
fn try_from(ret: TVMRetValue) -> Result<String> {
// Note: simple downcast doesn't work for function call return values
let ret_str = unsafe {
match CStr::from_ptr(ret.prim_value as *const c_char).to_str() {
Ok(s) => s,
Err(_) => "Invalid UTF-8 message",
} }
}; };
Ok(ret_str.to_string())
}
} }
#[cfg(test)] impl_tvm_context!(
mod tests { DLDeviceType_kDLCPU: [cpu, llvm, stackvm],
use super::*; DLDeviceType_kDLGPU: [gpu, cuda, nvptx],
use std::convert::TryInto; DLDeviceType_kDLOpenCL: [cl],
DLDeviceType_kDLMetal: [metal],
#[test] DLDeviceType_kDLVPI: [vpi],
fn numeric() { DLDeviceType_kDLROCM: [rocm],
macro_rules! arg_ret_tests { DLDeviceType_kDLExtDev: [ext_dev]
($v:expr; $($ty:ty),+) => {{ );
$(
let v = $v as $ty;
let b = TVMRetValue::from(&v);
let b: $ty = b.try_into().unwrap();
assert_eq!(b, v);
)+
}};
}
arg_ret_tests!(42; i8, i16, i32, i64, f32, f64);
}
#[test]
fn string() {
let s = "hello".to_string();
let tvm_arg: String = TVMRetValue::from(s.clone()).try_into().unwrap();
assert_eq!(tvm_arg, s);
}
}
[package]
name = "tvm-sys"
version = "0.1.0"
authors = ["TVM Contributors"]
license = "Apache-2.0"
description = "Raw C API"
[build-dependencies]
bindgen = "0.37.4"
#![allow(
non_camel_case_types,
non_snake_case,
non_upper_case_globals,
dead_code,
improper_ctypes
)]
include!("bindgen.rs");
...@@ -15,11 +15,11 @@ name = "tvm_frontend" ...@@ -15,11 +15,11 @@ name = "tvm_frontend"
crate-type = ["dylib"] crate-type = ["dylib"]
[dependencies] [dependencies]
error-chain = "0.12.0" failure = "0.1.5"
lazy_static = "1.1.0" lazy_static = "1.1.0"
ndarray = "0.12.1" ndarray = "0.12.1"
num-traits = "0.2" num-traits = "0.2"
tvm-common = { version = "0.1.0", path = "../common/", features = ["frontend"] } tvm-common = { version = "0.1.0", path = "../common/", features = ["bindings"] }
[features] [features]
blas = ["ndarray/blas"] blas = ["ndarray/blas"]
#![feature(try_from)]
extern crate csv; extern crate csv;
extern crate image; extern crate image;
extern crate ndarray; extern crate ndarray;
...@@ -10,6 +8,7 @@ use std::{ ...@@ -10,6 +8,7 @@ use std::{
convert::TryInto, convert::TryInto,
fs::{self, File}, fs::{self, File},
path::Path, path::Path,
str::FromStr,
}; };
use image::{FilterType, GenericImageView}; use image::{FilterType, GenericImageView};
...@@ -44,8 +43,12 @@ fn main() { ...@@ -44,8 +43,12 @@ fn main() {
// make arr shape as [1, 3, 224, 224] acceptable to resnet // make arr shape as [1, 3, 224, 224] acceptable to resnet
let arr = arr.insert_axis(Axis(0)); let arr = arr.insert_axis(Axis(0));
// create input tensor from rust's ndarray // create input tensor from rust's ndarray
let input = let input = NDArray::from_rust_ndarray(
NDArray::from_rust_ndarray(&arr, TVMContext::cpu(0), TVMType::from("float32")).unwrap(); &arr,
TVMContext::cpu(0),
TVMType::from_str("float32").unwrap(),
)
.unwrap();
println!( println!(
"input size is {:?}", "input size is {:?}",
input.shape().expect("cannot get the input shape") input.shape().expect("cannot get the input shape")
...@@ -59,7 +62,7 @@ fn main() { ...@@ -59,7 +62,7 @@ fn main() {
))) )))
.unwrap(); .unwrap();
// get the global TVM graph runtime function // get the global TVM graph runtime function
let runtime_create_fn = Function::get("tvm.graph_runtime.create", true).unwrap(); let runtime_create_fn = Function::get("tvm.graph_runtime.create").unwrap();
let runtime_create_fn_ret = call_packed!( let runtime_create_fn_ret = call_packed!(
runtime_create_fn, runtime_create_fn,
&graph, &graph,
...@@ -85,14 +88,19 @@ fn main() { ...@@ -85,14 +88,19 @@ fn main() {
.get_function("set_input", false) .get_function("set_input", false)
.unwrap(); .unwrap();
call_packed!(set_input_fn, "data", &input).unwrap(); let data_str = "data".to_string();
call_packed!(set_input_fn, &data_str, &input).unwrap();
// get `run` function from runtime module // get `run` function from runtime module
let ref run_fn = graph_runtime_module.get_function("run", false).unwrap(); let ref run_fn = graph_runtime_module.get_function("run", false).unwrap();
// execute the run function. Note that it has no argument // execute the run function. Note that it has no argument
call_packed!(run_fn,).unwrap(); call_packed!(run_fn,).unwrap();
// prepare to get the output // prepare to get the output
let output_shape = &mut [1, 1000]; let output_shape = &mut [1, 1000];
let output = NDArray::empty(output_shape, TVMContext::cpu(0), TVMType::from("float32")); let output = NDArray::empty(
output_shape,
TVMContext::cpu(0),
TVMType::from_str("float32").unwrap(),
);
// get the `get_output` function from runtime module // get the `get_output` function from runtime module
let ref get_output_fn = graph_runtime_module let ref get_output_fn = graph_runtime_module
.get_function("get_output", false) .get_function("get_output", false)
......
...@@ -3,9 +3,9 @@ ...@@ -3,9 +3,9 @@
//! //!
//! For more detail, please see the example `resnet` in `examples` repository. //! For more detail, please see the example `resnet` in `examples` repository.
use std::os::raw::c_char; use std::os::raw::{c_char, c_void};
use crate::ts; use tvm_common::{ffi, TVMArgValue};
/// A struct holding TVM byte-array. /// A struct holding TVM byte-array.
/// ///
...@@ -19,11 +19,11 @@ use crate::ts; ...@@ -19,11 +19,11 @@ use crate::ts;
/// ``` /// ```
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct TVMByteArray { pub struct TVMByteArray {
pub(crate) inner: ts::TVMByteArray, pub(crate) inner: ffi::TVMByteArray,
} }
impl TVMByteArray { impl TVMByteArray {
pub(crate) fn new(barr: ts::TVMByteArray) -> TVMByteArray { pub(crate) fn new(barr: ffi::TVMByteArray) -> TVMByteArray {
TVMByteArray { inner: barr } TVMByteArray { inner: barr }
} }
...@@ -46,7 +46,7 @@ impl TVMByteArray { ...@@ -46,7 +46,7 @@ impl TVMByteArray {
impl<'a> From<&'a Vec<u8>> for TVMByteArray { impl<'a> From<&'a Vec<u8>> for TVMByteArray {
fn from(arg: &Vec<u8>) -> Self { fn from(arg: &Vec<u8>) -> Self {
let barr = ts::TVMByteArray { let barr = ffi::TVMByteArray {
data: arg.as_ptr() as *const c_char, data: arg.as_ptr() as *const c_char,
size: arg.len(), size: arg.len(),
}; };
...@@ -54,6 +54,18 @@ impl<'a> From<&'a Vec<u8>> for TVMByteArray { ...@@ -54,6 +54,18 @@ impl<'a> From<&'a Vec<u8>> for TVMByteArray {
} }
} }
impl<'a> From<&TVMByteArray> for TVMArgValue<'a> {
fn from(arr: &TVMByteArray) -> Self {
Self {
value: ffi::TVMValue {
v_handle: &arr.inner as *const ffi::TVMByteArray as *const c_void as *mut c_void,
},
type_code: ffi::TVMTypeCode_kBytes as i64,
_lifetime: std::marker::PhantomData,
}
}
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
......
...@@ -18,12 +18,20 @@ ...@@ -18,12 +18,20 @@
//! ``` //! ```
use std::{ use std::{
convert::TryInto,
fmt::{self, Display, Formatter}, fmt::{self, Display, Formatter},
os::raw::c_void, os::raw::c_void,
ptr, ptr,
}; };
use crate::{function, ts, Result}; use failure::Error;
use tvm_common::{
ffi::{self, TVMValue},
TVMArgValue,
};
use crate::function;
/// Device type can be from a supported device name. See the supported devices /// Device type can be from a supported device name. See the supported devices
/// in [TVM](https://github.com/dmlc/tvm). /// in [TVM](https://github.com/dmlc/tvm).
...@@ -45,35 +53,35 @@ impl Default for TVMDeviceType { ...@@ -45,35 +53,35 @@ impl Default for TVMDeviceType {
} }
} }
impl From<TVMDeviceType> for ts::DLDeviceType { impl From<TVMDeviceType> for ffi::DLDeviceType {
fn from(device_type: TVMDeviceType) -> Self { fn from(device_type: TVMDeviceType) -> Self {
match device_type.0 { match device_type.0 {
1 => ts::DLDeviceType_kDLCPU, 1 => ffi::DLDeviceType_kDLCPU,
2 => ts::DLDeviceType_kDLGPU, 2 => ffi::DLDeviceType_kDLGPU,
3 => ts::DLDeviceType_kDLCPUPinned, 3 => ffi::DLDeviceType_kDLCPUPinned,
4 => ts::DLDeviceType_kDLOpenCL, 4 => ffi::DLDeviceType_kDLOpenCL,
7 => ts::DLDeviceType_kDLVulkan, 7 => ffi::DLDeviceType_kDLVulkan,
8 => ts::DLDeviceType_kDLMetal, 8 => ffi::DLDeviceType_kDLMetal,
9 => ts::DLDeviceType_kDLVPI, 9 => ffi::DLDeviceType_kDLVPI,
10 => ts::DLDeviceType_kDLROCM, 10 => ffi::DLDeviceType_kDLROCM,
12 => ts::DLDeviceType_kDLExtDev, 12 => ffi::DLDeviceType_kDLExtDev,
_ => panic!("device type not found!"), _ => panic!("device type not found!"),
} }
} }
} }
impl From<ts::DLDeviceType> for TVMDeviceType { impl From<ffi::DLDeviceType> for TVMDeviceType {
fn from(device_type: ts::DLDeviceType) -> Self { fn from(device_type: ffi::DLDeviceType) -> Self {
match device_type { match device_type {
ts::DLDeviceType_kDLCPU => TVMDeviceType(1), ffi::DLDeviceType_kDLCPU => TVMDeviceType(1),
ts::DLDeviceType_kDLGPU => TVMDeviceType(2), ffi::DLDeviceType_kDLGPU => TVMDeviceType(2),
ts::DLDeviceType_kDLCPUPinned => TVMDeviceType(3), ffi::DLDeviceType_kDLCPUPinned => TVMDeviceType(3),
ts::DLDeviceType_kDLOpenCL => TVMDeviceType(4), ffi::DLDeviceType_kDLOpenCL => TVMDeviceType(4),
ts::DLDeviceType_kDLVulkan => TVMDeviceType(7), ffi::DLDeviceType_kDLVulkan => TVMDeviceType(7),
ts::DLDeviceType_kDLMetal => TVMDeviceType(8), ffi::DLDeviceType_kDLMetal => TVMDeviceType(8),
ts::DLDeviceType_kDLVPI => TVMDeviceType(9), ffi::DLDeviceType_kDLVPI => TVMDeviceType(9),
ts::DLDeviceType_kDLROCM => TVMDeviceType(10), ffi::DLDeviceType_kDLROCM => TVMDeviceType(10),
ts::DLDeviceType_kDLExtDev => TVMDeviceType(12), ffi::DLDeviceType_kDLExtDev => TVMDeviceType(12),
_ => panic!("device type not found!"), _ => panic!("device type not found!"),
} }
} }
...@@ -117,6 +125,18 @@ impl<'a> From<&'a str> for TVMDeviceType { ...@@ -117,6 +125,18 @@ impl<'a> From<&'a str> for TVMDeviceType {
} }
} }
impl<'a> From<&'a TVMDeviceType> for TVMArgValue<'a> {
fn from(dev_type: &'a TVMDeviceType) -> Self {
Self {
value: TVMValue {
v_int64: dev_type.0 as i64,
},
type_code: ffi::DLDataTypeCode_kDLInt as i64,
_lifetime: std::marker::PhantomData,
}
}
}
/// Represents the underlying device context. Default is cpu. /// Represents the underlying device context. Default is cpu.
/// ///
/// ## Examples /// ## Examples
...@@ -138,15 +158,15 @@ pub struct TVMContext { ...@@ -138,15 +158,15 @@ pub struct TVMContext {
/// Supported device types /// Supported device types
pub device_type: TVMDeviceType, pub device_type: TVMDeviceType,
/// Device id /// Device id
pub device_id: usize, pub device_id: i32,
} }
impl TVMContext { impl TVMContext {
/// Creates context from device type and id. /// Creates context from device type and id.
pub fn new(device_type: TVMDeviceType, device_id: usize) -> Self { pub fn new(device_type: TVMDeviceType, device_id: i32) -> Self {
TVMContext { TVMContext {
device_type: device_type, device_type,
device_id: device_id, device_id,
} }
} }
} }
...@@ -155,7 +175,7 @@ macro_rules! impl_ctxs { ...@@ -155,7 +175,7 @@ macro_rules! impl_ctxs {
($(($ctx:ident, $dldevt:expr));+) => { ($(($ctx:ident, $dldevt:expr));+) => {
$( $(
impl TVMContext { impl TVMContext {
pub fn $ctx(device_id: usize) -> Self { pub fn $ctx(device_id: i32) -> Self {
Self::new(TVMDeviceType($dldevt), device_id) Self::new(TVMDeviceType($dldevt), device_id)
} }
} }
...@@ -185,20 +205,20 @@ impl<'a> From<&'a str> for TVMContext { ...@@ -185,20 +205,20 @@ impl<'a> From<&'a str> for TVMContext {
impl TVMContext { impl TVMContext {
/// Checks whether the context exists or not. /// Checks whether the context exists or not.
pub fn exist(&self) -> bool { pub fn exist(&self) -> bool {
let func = function::Function::get("_GetDeviceAttr", true /* is_global */) let func = function::Function::get("_GetDeviceAttr").expect("API function always exists");
.expect("API function always exists");
let dt = self.device_type.0 as usize; let dt = self.device_type.0 as usize;
// `unwrap` is ok here because if there is any error, // `unwrap` is ok here because if there is any error,
// if would occure inside `call_packed!` // if would occure inside `call_packed!`
let ret = call_packed!(func, &dt, &self.device_id, &0) let ret: u64 = call_packed!(func, &dt, &self.device_id, &0)
.unwrap() .unwrap()
.prim_value; .try_into()
.unwrap();
ret != 0 ret != 0
} }
/// Synchronize the context stream. /// Synchronize the context stream.
pub fn sync(&self) -> Result<()> { pub fn sync(&self) -> Result<(), Error> {
check_call!(ts::TVMSynchronize( check_call!(ffi::TVMSynchronize(
self.device_type.0 as i32, self.device_type.0 as i32,
self.device_id as i32, self.device_id as i32,
ptr::null_mut() as *mut c_void ptr::null_mut() as *mut c_void
...@@ -212,16 +232,17 @@ macro_rules! impl_device_attrs { ...@@ -212,16 +232,17 @@ macro_rules! impl_device_attrs {
$( $(
impl TVMContext { impl TVMContext {
pub fn $attr_name(&self) -> usize { pub fn $attr_name(&self) -> usize {
let func = function::Function::get("_GetDeviceAttr", true /* is_global */) let func = function::Function::get("_GetDeviceAttr")
.expect("API function always exists"); .expect("API function always exists");
let dt = self.device_type.0 as usize; let dt = self.device_type.0 as usize;
// `unwrap` is ok here because if there is any error, // `unwrap` is ok here because if there is any error,
// if would occur in function call. // if would occur in function call.
let ret = function::Builder::from(func) function::Builder::from(func)
.args(&[dt, self.device_id, $attr_kind]) .args(&[dt, self.device_id as usize, $attr_kind])
.invoke() .invoke()
.unwrap(); .unwrap()
ret.prim_value as usize .try_into()
.unwrap()
} }
} }
)+ )+
...@@ -237,18 +258,18 @@ impl_device_attrs!((max_threads_per_block, 1); ...@@ -237,18 +258,18 @@ impl_device_attrs!((max_threads_per_block, 1);
(multi_processor_count, 7); (multi_processor_count, 7);
(max_thread_dimensions, 8)); (max_thread_dimensions, 8));
impl From<ts::DLContext> for TVMContext { impl From<ffi::DLContext> for TVMContext {
fn from(ctx: ts::DLContext) -> Self { fn from(ctx: ffi::DLContext) -> Self {
TVMContext { TVMContext {
device_type: TVMDeviceType::from(ctx.device_type), device_type: TVMDeviceType::from(ctx.device_type),
device_id: ctx.device_id as usize, device_id: ctx.device_id,
} }
} }
} }
impl From<TVMContext> for ts::DLContext { impl From<TVMContext> for ffi::DLContext {
fn from(ctx: TVMContext) -> Self { fn from(ctx: TVMContext) -> Self {
ts::DLContext { ffi::DLContext {
device_type: ctx.device_type.into(), device_type: ctx.device_type.into(),
device_id: ctx.device_id as i32, device_id: ctx.device_id as i32,
} }
......
//! This module implements TVM custom [`Error`], [`ErrorKind`] and [`Result`] types. pub use failure::Error;
use std::{ffi, option}; #[derive(Debug, Fail)]
#[fail(display = "Cannot convert from an empty array.")]
pub struct EmptyArrayError;
use crate::{common_errors, rust_ndarray}; #[derive(Debug, Fail)]
#[fail(display = "Handle `{}` is null.", name)]
error_chain! { pub struct NullHandleError {
errors { pub name: String,
EmptyArray { }
description("cannot convert from an empty array")
}
NullHandle(name: String) {
description("null handle")
display("requested `{}` handle is null", name)
}
FunctionNotFound {
description("function not found")
display("function was not set in `function::Builder`")
}
TypeMismatch(expected: String, found: String) {
description("type mismatch!")
display("expected type `{}`, but found `{}`", expected, found)
}
MissingShapeError {
description("ndarray `shape()` returns `None`")
display("called `Option::unwrap()` on a `None` value")
}
AtMostOneReturn {
description("TVM functions accept at most one return value")
}
} #[derive(Debug, Fail)]
#[fail(display = "Function was not set in `function::Builder`")]
pub struct FunctionNotFoundError;
foreign_links { #[derive(Debug, Fail)]
ShapeError(rust_ndarray::ShapeError); #[fail(display = "Expected type `{}` but found `{}`", expected, actual)]
NulError(ffi::NulError); pub struct TypeMismatchError {
IntoStringError(ffi::IntoStringError); pub expected: String,
CommonError(common_errors::Error); pub actual: String,
}
} }
impl From<option::NoneError> for Error { #[derive(Debug, Fail)]
fn from(_err: option::NoneError) -> Self { #[fail(display = "Missing NDArray shape.")]
ErrorKind::MissingShapeError.into() pub struct MissingShapeError;
}
}
...@@ -15,14 +15,20 @@ use std::{ ...@@ -15,14 +15,20 @@ use std::{
sync::Mutex, sync::Mutex,
}; };
use crate::{ts, ErrorKind, Module, Result, TVMArgValue, TVMRetValue, TVMTypeCode, TVMValue}; use failure::Error;
use crate::{
errors,
ffi::{self, TVMValue},
Module, TVMArgValue, TVMRetValue,
};
lazy_static! { lazy_static! {
static ref GLOBAL_FUNCTIONS: Mutex<BTreeMap<&'static str, Option<Function>>> = { static ref GLOBAL_FUNCTIONS: Mutex<BTreeMap<&'static str, Option<Function>>> = {
let mut out_size = 0 as c_int; let mut out_size = 0 as c_int;
let name = ptr::null_mut() as *mut c_char; let name = ptr::null_mut() as *mut c_char;
let mut out_array = name as *mut _; let mut out_array = name as *mut _;
check_call!(ts::TVMFuncListGlobalNames( check_call!(ffi::TVMFuncListGlobalNames(
&mut out_size as *mut _, &mut out_size as *mut _,
&mut out_array &mut out_array
)); ));
...@@ -37,17 +43,14 @@ lazy_static! { ...@@ -37,17 +43,14 @@ lazy_static! {
} }
/// Wrapper around TVM function handle which includes `is_global` /// Wrapper around TVM function handle which includes `is_global`
/// indicating whether the function is global or not, `is_released` /// indicating whether the function is global or not, and `is_cloned` showing
/// to hint dropping the function handle and `is_cloned` showing
/// not to drop a cloned function from Rust side. /// not to drop a cloned function from Rust side.
/// The value of these fields can be accessed through their respective methods. /// The value of these fields can be accessed through their respective methods.
#[derive(Debug, Hash)] #[derive(Debug, Hash)]
pub struct Function { pub struct Function {
pub(crate) handle: ts::TVMFunctionHandle, pub(crate) handle: ffi::TVMFunctionHandle,
// whether the registered function is global or not. // whether the registered function is global or not.
is_global: bool, is_global: bool,
// whether the function has been dropped from frontend or not.
is_released: bool,
// whether the function has been cloned from frontend or not. // whether the function has been cloned from frontend or not.
is_cloned: bool, is_cloned: bool,
} }
...@@ -56,29 +59,30 @@ unsafe impl Send for Function {} ...@@ -56,29 +59,30 @@ unsafe impl Send for Function {}
unsafe impl Sync for Function {} unsafe impl Sync for Function {}
impl Function { impl Function {
pub(crate) fn new(handle: ts::TVMFunctionHandle, is_global: bool, is_released: bool) -> Self { pub(crate) fn new(handle: ffi::TVMFunctionHandle) -> Self {
Function { Function {
handle: handle, handle: handle,
is_global: is_global, is_global: false,
is_released: is_released,
is_cloned: false, is_cloned: false,
} }
} }
/// For a given function, it returns a function by name. /// For a given function, it returns a function by name.
pub fn get<S: AsRef<str>>(name: S, is_global: bool) -> Option<&'static Function> { pub fn get<S: AsRef<str>>(name: S) -> Option<&'static Function> {
let mut globals = GLOBAL_FUNCTIONS.lock().unwrap(); let mut globals = GLOBAL_FUNCTIONS.lock().unwrap();
globals.get_mut(name.as_ref()).and_then(|maybe_func| { globals.get_mut(name.as_ref()).and_then(|maybe_func| {
if maybe_func.is_none() { if maybe_func.is_none() {
let name = CString::new(name.as_ref()).unwrap(); let name = CString::new(name.as_ref()).unwrap();
let mut handle = ptr::null_mut() as ts::TVMFunctionHandle; let mut handle = ptr::null_mut() as ffi::TVMFunctionHandle;
check_call!(ts::TVMFuncGetGlobal( check_call!(ffi::TVMFuncGetGlobal(
name.as_ptr() as *const c_char, name.as_ptr() as *const c_char,
&mut handle as *mut _ &mut handle as *mut _
)); ));
maybe_func.replace(Function::new( maybe_func.replace(Function {
handle, is_global, false, /* is_released */ handle: handle,
)); is_global: true,
is_cloned: false,
});
} }
unsafe { unsafe {
std::mem::transmute::<Option<&Function>, Option<&'static Function>>( std::mem::transmute::<Option<&Function>, Option<&'static Function>>(
...@@ -89,7 +93,7 @@ impl Function { ...@@ -89,7 +93,7 @@ impl Function {
} }
/// Returns the underlying TVM function handle. /// Returns the underlying TVM function handle.
pub fn handle(&self) -> ts::TVMFunctionHandle { pub fn handle(&self) -> ffi::TVMFunctionHandle {
self.handle self.handle
} }
...@@ -98,12 +102,6 @@ impl Function { ...@@ -98,12 +102,6 @@ impl Function {
self.is_global self.is_global
} }
/// Returns `true` if the underlying TVM function has been released
/// from the frontend and `false` otherwise.
pub fn is_released(&self) -> bool {
self.is_released
}
/// Returns `true` if the underlying TVM function has been cloned /// Returns `true` if the underlying TVM function has been cloned
/// from the frontend and `false` otherwise. /// from the frontend and `false` otherwise.
pub fn is_cloned(&self) -> bool { pub fn is_cloned(&self) -> bool {
...@@ -113,24 +111,18 @@ impl Function { ...@@ -113,24 +111,18 @@ impl Function {
impl Clone for Function { impl Clone for Function {
fn clone(&self) -> Function { fn clone(&self) -> Function {
if !self.is_released && !self.is_cloned {
Self { Self {
handle: self.handle, handle: self.handle,
is_global: self.is_global, is_global: self.is_global,
is_released: self.is_released,
is_cloned: true, is_cloned: true,
} }
} else {
Function::new(self.handle, self.is_global, self.is_released)
}
} }
} }
impl Drop for Function { impl Drop for Function {
fn drop(&mut self) { fn drop(&mut self) {
if !self.is_released && !self.is_global && !self.is_cloned { if !self.is_global && !self.is_cloned {
check_call!(ts::TVMFuncFree(self.handle)); check_call!(ffi::TVMFuncFree(self.handle));
self.is_released = true;
} }
} }
} }
...@@ -138,17 +130,17 @@ impl Drop for Function { ...@@ -138,17 +130,17 @@ impl Drop for Function {
/// Function builder in order to create and call functions. /// Function builder in order to create and call functions.
/// ///
/// *Note:* Currently TVM functions accept *at most* one return value. /// *Note:* Currently TVM functions accept *at most* one return value.
#[derive(Debug, Clone, Default)] #[derive(Default)]
pub struct Builder<'a, 'm> { pub struct Builder<'a, 'm> {
pub func: Option<&'m Function>, pub func: Option<&'m Function>,
pub arg_buf: Option<Box<[TVMArgValue<'a>]>>, pub arg_buf: Vec<TVMArgValue<'a>>,
pub ret_buf: Option<TVMRetValue>, pub ret_buf: Option<TVMRetValue>,
} }
impl<'a, 'm> Builder<'a, 'm> { impl<'a, 'm> Builder<'a, 'm> {
pub fn new( pub fn new(
func: Option<&'m Function>, func: Option<&'m Function>,
arg_buf: Option<Box<[TVMArgValue<'a>]>>, arg_buf: Vec<TVMArgValue<'a>>,
ret_buf: Option<TVMRetValue>, ret_buf: Option<TVMRetValue>,
) -> Self { ) -> Self {
Self { Self {
...@@ -158,123 +150,66 @@ impl<'a, 'm> Builder<'a, 'm> { ...@@ -158,123 +150,66 @@ impl<'a, 'm> Builder<'a, 'm> {
} }
} }
pub fn get_function(&mut self, name: &'m str, is_global: bool) -> &mut Self { pub fn get_function(&mut self, name: &'m str) -> &mut Self {
self.func = Function::get(name, is_global); self.func = Function::get(name);
self self
} }
/// Pushes a [`TVMArgValue`] into the function argument buffer. /// Pushes a [`TVMArgValue`] into the function argument buffer.
pub fn arg<'b, T: ?Sized>(&mut self, arg: &'b T) -> &mut Self pub fn arg<T: 'a>(&mut self, arg: &'a T) -> &mut Self
where where
TVMValue: From<&'b T>, TVMArgValue<'a>: From<&'a T>,
TVMTypeCode: From<&'b T>,
{ {
let tvm_arg = TVMArgValue::from(arg); self.arg_buf.push(arg.into());
if self.arg_buf.is_none() {
self.arg_buf = Some(Box::new([tvm_arg]));
} else {
let new_arg_buf = self.arg_buf.take().map(|bbuf| {
let mut new_arg_buf = Vec::from(bbuf);
new_arg_buf.push(tvm_arg);
let new_len = new_arg_buf.len();
new_arg_buf.truncate(new_len);
new_arg_buf.into_boxed_slice()
});
self.arg_buf = new_arg_buf;
}
self self
} }
/// Pushes multiple [`TVMArgValue`]s into the function argument buffer. /// Pushes multiple [`TVMArgValue`]s into the function argument buffer.
pub fn args<'b, T: 'b + ?Sized, I>(&mut self, args: I) -> &mut Self pub fn args<T: 'a, I>(&mut self, args: I) -> &mut Self
where where
I: IntoIterator<Item = &'b T>, I: IntoIterator<Item = &'a T>,
TVMValue: From<&'b T>, TVMArgValue<'a>: From<&'a T>,
TVMTypeCode: From<&'b T>,
{ {
for arg in args { args.into_iter().for_each(|arg| {
self.arg(&arg); self.arg(&arg);
} });
self self
} }
/// Sets an output for a function that requirs a mutable output to be provided. /// Sets an output for a function that requirs a mutable output to be provided.
/// See the `basics` in tests for an example. /// See the `basics` in tests for an example.
pub fn set_output<'b, T: 'b + ?Sized>(&mut self, arg: &'b mut T) -> Result<&mut Self> pub fn set_output<T>(&mut self, ret: T) -> &mut Self
where where
TVMValue: From<&'b T>, TVMRetValue: From<T>,
TVMTypeCode: From<&'b T>,
{ {
if self.ret_buf.is_none() { self.ret_buf = Some(ret.into());
let tvm_ret = self
unsafe { TVMRetValue::from_tvm_value(TVMValue::from(arg), TVMTypeCode::from(arg)) };
self.ret_buf = Some(tvm_ret);
} else {
bail!(ErrorKind::AtMostOneReturn)
}
Ok(self)
} }
/// Calls the function that created from `Builder`. /// Calls the function that created from `Builder`.
pub fn invoke(&mut self) -> Result<TVMRetValue> { pub fn invoke(&mut self) -> Result<TVMRetValue, Error> {
self.clone()(()) #![allow(unused_unsafe)]
} ensure!(self.func.is_some(), errors::FunctionNotFoundError);
}
impl<'a, 'm> FnOnce<((),)> for Builder<'a, 'm> { let num_args = self.arg_buf.len();
type Output = Result<TVMRetValue>; let (mut values, mut type_codes): (Vec<ffi::TVMValue>, Vec<ffi::TVMTypeCode>) = self
extern "rust-call" fn call_once(self, _: ((),)) -> Self::Output { .arg_buf
if self.func.is_none() {
bail!("{}", ErrorKind::FunctionNotFound);
}
let mut ret_val = unsafe { mem::uninitialized::<ts::TVMValue>() };
let mut ret_type_code = 0 as c_int;
if self.arg_buf.is_some() {
let arg_buf = self.arg_buf?;
let mut num_args = arg_buf.len();
let mut values = arg_buf
.iter()
.map(|tav| tav.value.inner)
.collect::<Vec<ts::TVMValue>>();
let mut tcodes = arg_buf
.iter() .iter()
.map(|tav| tav.type_code as c_int) .map(|tvm_arg| (tvm_arg.value, tvm_arg.type_code as ffi::TVMTypeCode))
.collect::<Vec<_>>(); .unzip();
if self.ret_buf.is_some() {
num_args = num_args + 1;
let ret_buf = self.ret_buf?;
let (ret_val, ret_type_code) = TVMRetValue::into_tvm_value(ret_buf);
values.append(&mut vec![ret_val.inner]);
tcodes.append(&mut vec![ret_type_code as c_int]);
}
values.truncate(num_args); let mut ret_val = unsafe { std::mem::uninitialized::<TVMValue>() };
tcodes.truncate(num_args); let mut ret_type_code = 0;
check_call!(ts::TVMFuncCall( check_call!(ffi::TVMFuncCall(
self.func?.handle, self.func.ok_or(errors::FunctionNotFoundError)?.handle,
values.as_mut_ptr(), values.as_mut_ptr(),
tcodes.as_mut_ptr(), type_codes.as_mut_ptr() as *mut i32,
num_args as c_int, num_args as c_int,
&mut ret_val as *mut _, &mut ret_val as *mut _,
&mut ret_type_code as *mut _ &mut ret_type_code as *mut _
)); ));
} else {
check_call!(ts::TVMFuncCall(
self.func?.handle,
ptr::null_mut(),
ptr::null_mut(),
0 as c_int,
&mut ret_val as *mut _,
&mut ret_type_code as *mut _
));
}
let ret = unsafe { Ok(unsafe { TVMRetValue::from_tvm_value(ret_val, ret_type_code as i64) })
TVMRetValue::from_tvm_value(TVMValue::new(ret_val), (ret_type_code as i64).into())
};
Ok(ret)
} }
} }
...@@ -282,46 +217,44 @@ impl<'a, 'm> FnOnce<((),)> for Builder<'a, 'm> { ...@@ -282,46 +217,44 @@ impl<'a, 'm> FnOnce<((),)> for Builder<'a, 'm> {
/// TVM functions. /// TVM functions.
impl<'a, 'm> From<&'m Function> for Builder<'a, 'm> { impl<'a, 'm> From<&'m Function> for Builder<'a, 'm> {
fn from(func: &'m Function) -> Self { fn from(func: &'m Function) -> Self {
Builder::new(Some(func), None, None) Builder::new(Some(func), Vec::new(), None)
} }
} }
/// Converts a mutable reference of a [`Module`] to [`Builder`]. /// Converts a mutable reference of a [`Module`] to [`Builder`].
impl<'a, 'm> From<&'m mut Module> for Builder<'a, 'm> { impl<'a, 'm> From<&'m mut Module> for Builder<'a, 'm> {
fn from(module: &'m mut Module) -> Self { fn from(module: &'m mut Module) -> Self {
Builder::new(module.entry(), None, None) Builder::new(module.entry(), Vec::new(), None)
} }
} }
unsafe extern "C" fn tvm_callback( unsafe extern "C" fn tvm_callback(
args: *mut ts::TVMValue, args: *mut ffi::TVMValue,
type_codes: *mut c_int, type_codes: *mut c_int,
num_args: c_int, num_args: c_int,
ret: ts::TVMRetValueHandle, ret: ffi::TVMRetValueHandle,
fhandle: *mut c_void, fhandle: *mut c_void,
) -> c_int { ) -> c_int {
// turning off the incorrect linter complaints // turning off the incorrect linter complaints
#![allow(unused_assignments)] #![allow(unused_assignments, unused_unsafe)]
let len = num_args as usize; let len = num_args as usize;
let args_list = slice::from_raw_parts_mut(args, len); let args_list = slice::from_raw_parts_mut(args, len);
let type_codes_list = slice::from_raw_parts_mut(type_codes, len); let type_codes_list = slice::from_raw_parts_mut(type_codes, len);
let mut local_args: Vec<TVMArgValue> = Vec::new(); let mut local_args: Vec<TVMArgValue> = Vec::new();
let mut value = mem::uninitialized::<ts::TVMValue>(); let mut value = mem::uninitialized::<ffi::TVMValue>();
let mut tcode = mem::uninitialized::<c_int>(); let mut tcode = mem::uninitialized::<c_int>();
let rust_fn = mem::transmute::<*mut c_void, fn(&[TVMArgValue]) -> Result<TVMRetValue>>(fhandle); let rust_fn =
mem::transmute::<*mut c_void, fn(&[TVMArgValue]) -> Result<TVMRetValue, Error>>(fhandle);
for i in 0..len { for i in 0..len {
value = args_list[i]; value = args_list[i];
tcode = type_codes_list[i]; tcode = type_codes_list[i];
if tcode == ts::TVMTypeCode_kNodeHandle as c_int if tcode == ffi::TVMTypeCode_kNodeHandle as c_int
|| tcode == ts::TVMTypeCode_kFuncHandle as c_int || tcode == ffi::TVMTypeCode_kFuncHandle as c_int
|| tcode == ts::TVMTypeCode_kModuleHandle as c_int || tcode == ffi::TVMTypeCode_kModuleHandle as c_int
{ {
check_call!(ts::TVMCbArgToReturn(&mut value as *mut _, tcode)); check_call!(ffi::TVMCbArgToReturn(&mut value as *mut _, tcode));
} }
local_args.push(TVMArgValue::new( local_args.push(TVMArgValue::new(value.into(), (tcode as i64).into()));
TVMValue::new(value),
(tcode as i64).into(),
));
} }
let rv = match rust_fn(local_args.as_slice()) { let rv = match rust_fn(local_args.as_slice()) {
...@@ -332,10 +265,9 @@ unsafe extern "C" fn tvm_callback( ...@@ -332,10 +265,9 @@ unsafe extern "C" fn tvm_callback(
} }
}; };
let (ret_val, ret_tcode) = TVMRetValue::into_tvm_value(rv); let (mut ret_val, ret_tcode) = rv.into_tvm_value();
let mut ret_val = ret_val.inner;
let mut ret_type_code = ret_tcode as c_int; let mut ret_type_code = ret_tcode as c_int;
check_call!(ts::TVMCFuncSetReturn( check_call!(ffi::TVMCFuncSetReturn(
ret, ret,
&mut ret_val as *mut _, &mut ret_val as *mut _,
&mut ret_type_code as *mut _, &mut ret_type_code as *mut _,
...@@ -345,24 +277,25 @@ unsafe extern "C" fn tvm_callback( ...@@ -345,24 +277,25 @@ unsafe extern "C" fn tvm_callback(
} }
unsafe extern "C" fn tvm_callback_finalizer(fhandle: *mut c_void) { unsafe extern "C" fn tvm_callback_finalizer(fhandle: *mut c_void) {
let rust_fn = mem::transmute::<*mut c_void, fn(&[TVMArgValue]) -> Result<TVMRetValue>>(fhandle); let rust_fn =
mem::transmute::<*mut c_void, fn(&[TVMArgValue]) -> Result<TVMRetValue, Error>>(fhandle);
mem::drop(rust_fn); mem::drop(rust_fn);
} }
fn convert_to_tvm_func(f: fn(&[TVMArgValue]) -> Result<TVMRetValue>) -> Function { fn convert_to_tvm_func(f: fn(&[TVMArgValue]) -> Result<TVMRetValue, Error>) -> Function {
let mut fhandle = ptr::null_mut() as ts::TVMFunctionHandle; let mut fhandle = ptr::null_mut() as ffi::TVMFunctionHandle;
let resource_handle = f as *mut fn(&[TVMArgValue]) -> Result<TVMRetValue>; let resource_handle = f as *mut fn(&[TVMArgValue]) -> Result<TVMRetValue, Error>;
check_call!(ts::TVMFuncCreateFromCFunc( check_call!(ffi::TVMFuncCreateFromCFunc(
Some(tvm_callback), Some(tvm_callback),
resource_handle as *mut c_void, resource_handle as *mut c_void,
Some(tvm_callback_finalizer), Some(tvm_callback_finalizer),
&mut fhandle as *mut _ &mut fhandle as *mut _
)); ));
Function::new(fhandle, false, false) Function::new(fhandle)
} }
/// Registers a Rust function with signature /// Registers a Rust function with signature
/// `fn(&[TVMArgValue]) -> Result<TVMRetValue>` /// `fn(&[TVMArgValue]) -> Result<TVMRetValue, Error>`
/// as a **global TVM packed function** from frontend to TVM backend. /// as a **global TVM packed function** from frontend to TVM backend.
/// ///
/// Use [`register_global_func`] if overriding an existing global TVM function /// Use [`register_global_func`] if overriding an existing global TVM function
...@@ -373,7 +306,7 @@ fn convert_to_tvm_func(f: fn(&[TVMArgValue]) -> Result<TVMRetValue>) -> Function ...@@ -373,7 +306,7 @@ fn convert_to_tvm_func(f: fn(&[TVMArgValue]) -> Result<TVMRetValue>) -> Function
/// ``` /// ```
/// use std::convert::TryInto; /// use std::convert::TryInto;
/// ///
/// fn sum(args: &[TVMArgValue]) -> Result<TVMRetValue> { /// fn sum(args: &[TVMArgValue]) -> Result<TVMRetValue, Error> {
/// let mut ret = 0i64; /// let mut ret = 0i64;
/// for arg in args.iter() { /// for arg in args.iter() {
/// let arg: i64 = arg.try_into()?; /// let arg: i64 = arg.try_into()?;
...@@ -391,18 +324,17 @@ fn convert_to_tvm_func(f: fn(&[TVMArgValue]) -> Result<TVMRetValue>) -> Function ...@@ -391,18 +324,17 @@ fn convert_to_tvm_func(f: fn(&[TVMArgValue]) -> Result<TVMRetValue>) -> Function
/// assert_eq!(ret, 60); /// assert_eq!(ret, 60);
/// ``` /// ```
pub fn register<S: AsRef<str>>( pub fn register<S: AsRef<str>>(
f: fn(&[TVMArgValue]) -> Result<TVMRetValue>, f: fn(&[TVMArgValue]) -> Result<TVMRetValue, Error>,
name: S, name: S,
override_: bool, override_: bool,
) -> Result<()> { ) -> Result<(), Error> {
let func = convert_to_tvm_func(f); let func = convert_to_tvm_func(f);
let name = CString::new(name.as_ref())?; let name = CString::new(name.as_ref())?;
check_call!(ts::TVMFuncRegisterGlobal( check_call!(ffi::TVMFuncRegisterGlobal(
name.as_ref().as_ptr() as *const c_char, name.into_raw(),
func.handle(), func.handle(),
override_ as c_int override_ as c_int
)); ));
mem::forget(name);
Ok(()) Ok(())
} }
...@@ -416,7 +348,7 @@ pub fn register<S: AsRef<str>>( ...@@ -416,7 +348,7 @@ pub fn register<S: AsRef<str>>(
/// use std::convert::TryInto; /// use std::convert::TryInto;
/// ///
/// register_global_func! { /// register_global_func! {
/// fn sum(args: &[TVMArgValue]) -> Result<TVMRetValue> { /// fn sum(args: &[TVMArgValue]) -> Result<TVMRetValue, Error> {
/// let mut ret = 0f64; /// let mut ret = 0f64;
/// for arg in args.iter() { /// for arg in args.iter() {
/// let arg: f64 = arg.try_into()?; /// let arg: f64 = arg.try_into()?;
...@@ -437,12 +369,12 @@ pub fn register<S: AsRef<str>>( ...@@ -437,12 +369,12 @@ pub fn register<S: AsRef<str>>(
macro_rules! register_global_func { macro_rules! register_global_func {
{ {
$(#[$m:meta])* $(#[$m:meta])*
fn $fn_name:ident($args:ident : &[TVMArgValue]) -> Result<TVMRetValue> { fn $fn_name:ident($args:ident : &[TVMArgValue]) -> Result<TVMRetValue, Error> {
$($code:tt)* $($code:tt)*
} }
} => {{ } => {{
$(#[$m])* $(#[$m])*
fn $fn_name($args: &[TVMArgValue]) -> Result<TVMRetValue> { fn $fn_name($args: &[TVMArgValue]) -> Result<TVMRetValue, Error> {
$($code)* $($code)*
} }
...@@ -496,17 +428,17 @@ mod tests { ...@@ -496,17 +428,17 @@ mod tests {
#[test] #[test]
fn get_fn() { fn get_fn() {
assert!(Function::get(CANARY, true).is_some()); assert!(Function::get(CANARY).is_some());
assert!(Function::get("does not exists!", false).is_none()); assert!(Function::get("does not exists!").is_none());
} }
#[test] #[test]
fn provide_args() { fn provide_args() {
let str_arg = CString::new("test").unwrap();
let mut func = Builder::default(); let mut func = Builder::default();
func.get_function("tvm.graph_runtime.remote_create", true) func.get_function("tvm.graph_runtime.remote_create")
.args(&[10, 20]) .args(&[10, 20])
.arg(&"test".to_owned()); .arg(&str_arg);
assert!(func.arg_buf.is_some()); assert_eq!(func.arg_buf.len(), 3);
assert_eq!(func.arg_buf.take().map(|bv| Vec::from(bv).len()), Some(3));
} }
} }
...@@ -11,32 +11,36 @@ ...@@ -11,32 +11,36 @@
//! //!
//! Checkout the `examples` repository for more details. //! Checkout the `examples` repository for more details.
#![crate_name = "tvm_frontend"] #![feature(box_syntax)]
#![recursion_limit = "1024"]
#![allow(non_camel_case_types, unused_unsafe)]
#![feature(
try_from,
try_trait,
fn_traits,
unboxed_closures,
box_syntax,
option_replace
)]
#[macro_use] #[macro_use]
extern crate error_chain; extern crate failure;
extern crate tvm_common as common;
#[macro_use] #[macro_use]
extern crate lazy_static; extern crate lazy_static;
extern crate ndarray as rust_ndarray; extern crate ndarray as rust_ndarray;
extern crate num_traits; extern crate num_traits;
extern crate tvm_common;
use std::{ use std::{
ffi::{CStr, CString}, ffi::{CStr, CString},
str, str,
}; };
use crate::common::ffi::ts; use failure::Error;
pub use crate::{
bytearray::TVMByteArray,
context::{TVMContext, TVMDeviceType},
errors::*,
function::Function,
module::Module,
ndarray::NDArray,
tvm_common::{
errors as common_errors,
ffi::{self, TVMType},
packed_func::{TVMArgValue, TVMRetValue},
},
};
// Macro to check the return call to TVM runtime shared library. // Macro to check the return call to TVM runtime shared library.
macro_rules! check_call { macro_rules! check_call {
...@@ -50,7 +54,7 @@ macro_rules! check_call { ...@@ -50,7 +54,7 @@ macro_rules! check_call {
/// Gets the last error message. /// Gets the last error message.
pub fn get_last_error() -> &'static str { pub fn get_last_error() -> &'static str {
unsafe { unsafe {
match CStr::from_ptr(ts::TVMGetLastError()).to_str() { match CStr::from_ptr(ffi::TVMGetLastError()).to_str() {
Ok(s) => s, Ok(s) => s,
Err(_) => "Invalid UTF-8 message", Err(_) => "Invalid UTF-8 message",
} }
...@@ -60,7 +64,7 @@ pub fn get_last_error() -> &'static str { ...@@ -60,7 +64,7 @@ pub fn get_last_error() -> &'static str {
pub(crate) fn set_last_error(err: &Error) { pub(crate) fn set_last_error(err: &Error) {
let c_string = CString::new(err.to_string()).unwrap(); let c_string = CString::new(err.to_string()).unwrap();
unsafe { unsafe {
ts::TVMAPISetLastError(c_string.as_ptr()); ffi::TVMAPISetLastError(c_string.as_ptr());
} }
} }
...@@ -71,27 +75,11 @@ pub mod context; ...@@ -71,27 +75,11 @@ pub mod context;
pub mod errors; pub mod errors;
pub mod module; pub mod module;
pub mod ndarray; pub mod ndarray;
pub mod ty;
pub mod value; pub mod value;
pub use crate::{
bytearray::TVMByteArray,
common::{
errors as common_errors,
ty::TVMTypeCode,
value::{TVMArgValue, TVMRetValue, TVMValue},
},
context::{TVMContext, TVMDeviceType},
errors::*,
function::Function,
module::Module,
ndarray::NDArray,
ty::TVMType,
};
/// Outputs the current TVM version. /// Outputs the current TVM version.
pub fn version() -> &'static str { pub fn version() -> &'static str {
match str::from_utf8(ts::TVM_VERSION) { match str::from_utf8(ffi::TVM_VERSION) {
Ok(s) => s, Ok(s) => s,
Err(_) => "Invalid UTF-8 string", Err(_) => "Invalid UTF-8 string",
} }
...@@ -108,8 +96,8 @@ mod tests { ...@@ -108,8 +96,8 @@ mod tests {
#[test] #[test]
fn set_error() { fn set_error() {
let err = ErrorKind::EmptyArray; let err = errors::EmptyArrayError;
set_last_error(&err.into()); set_last_error(&err.into());
assert_eq!(get_last_error().trim(), ErrorKind::EmptyArray.to_string()); assert_eq!(get_last_error().trim(), errors::EmptyArrayError.to_string());
} }
} }
...@@ -8,30 +8,27 @@ use std::{ ...@@ -8,30 +8,27 @@ use std::{
ptr, ptr,
}; };
use crate::ts; use failure::Error;
use tvm_common::ffi;
use crate::{function::Function, ErrorKind, Result}; use crate::{errors, function::Function};
const ENTRY_FUNC: &'static str = "__tvm_main__"; const ENTRY_FUNC: &'static str = "__tvm_main__";
/// Wrapper around TVM module handle which contains an entry function. /// Wrapper around TVM module handle which contains an entry function.
/// The entry function can be applied to an imported module through [`entry_func`]. /// The entry function can be applied to an imported module through [`entry_func`].
/// Also [`is_released`] shows whether the module is dropped or not.
/// ///
/// [`entry_func`]:struct.Module.html#method.entry_func /// [`entry_func`]:struct.Module.html#method.entry_func
/// [`is_released`]:struct.Module.html#method.is_released
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct Module { pub struct Module {
pub(crate) handle: ts::TVMModuleHandle, pub(crate) handle: ffi::TVMModuleHandle,
is_released: bool,
entry_func: Option<Function>, entry_func: Option<Function>,
} }
impl Module { impl Module {
pub(crate) fn new(handle: ts::TVMModuleHandle, is_released: bool) -> Self { pub(crate) fn new(handle: ffi::TVMModuleHandle) -> Self {
Self { Self {
handle, handle,
is_released,
entry_func: None, entry_func: None,
} }
} }
...@@ -44,62 +41,67 @@ impl Module { ...@@ -44,62 +41,67 @@ impl Module {
} }
/// Gets a function by name from a registered module. /// Gets a function by name from a registered module.
pub fn get_function(&self, name: &str, query_import: bool) -> Result<Function> { pub fn get_function(&self, name: &str, query_import: bool) -> Result<Function, Error> {
let name = CString::new(name)?; let name = CString::new(name)?;
let mut fhandle = ptr::null_mut() as ts::TVMFunctionHandle; let mut fhandle = ptr::null_mut() as ffi::TVMFunctionHandle;
check_call!(ts::TVMModGetFunction( check_call!(ffi::TVMModGetFunction(
self.handle, self.handle,
name.as_ptr() as *const c_char, name.as_ptr() as *const c_char,
query_import as c_int, query_import as c_int,
&mut fhandle as *mut _ &mut fhandle as *mut _
)); ));
if fhandle.is_null() { ensure!(
bail!(ErrorKind::NullHandle(format!("{}", name.into_string()?))) !fhandle.is_null(),
} else { errors::NullHandleError {
Ok(Function::new(fhandle, false, false)) name: format!("{}", name.into_string()?)
} }
);
Ok(Function::new(fhandle))
} }
/// Imports a dependent module such as `.ptx` for gpu. /// Imports a dependent module such as `.ptx` for gpu.
pub fn import_module(&self, dependent_module: Module) { pub fn import_module(&self, dependent_module: Module) {
check_call!(ts::TVMModImport(self.handle, dependent_module.handle)) check_call!(ffi::TVMModImport(self.handle, dependent_module.handle))
} }
/// Loads a module shared library from path. /// Loads a module shared library from path.
pub fn load<P: AsRef<Path>>(path: &P) -> Result<Module> { pub fn load<P: AsRef<Path>>(path: &P) -> Result<Module, Error> {
let ext = path.as_ref().extension()?.to_str()?; let ext = CString::new(
let func = Function::get("module._LoadFromFile", true /* is_global */) path.as_ref()
.expect("API function always exists"); .extension()
let ret: Module = call_packed!(func, path.as_ref().to_str()?, ext)?.try_into()?; .unwrap_or(std::ffi::OsStr::new(""))
.to_str()
.ok_or_else(|| {
format_err!("Bad module load path: `{}`.", path.as_ref().display())
})?,
)?;
let func = Function::get("module._LoadFromFile").expect("API function always exists");
let cpath =
CString::new(path.as_ref().to_str().ok_or_else(|| {
format_err!("Bad module load path: `{}`.", path.as_ref().display())
})?)?;
let ret: Module = call_packed!(func, &cpath, &ext)?.try_into()?;
Ok(ret) Ok(ret)
} }
/// Checks if a target device is enabled for a module. /// Checks if a target device is enabled for a module.
pub fn enabled(&self, target: &str) -> bool { pub fn enabled(&self, target: &str) -> bool {
let func = Function::get("module._Enabled", true /* is_global */) let func = Function::get("module._Enabled").expect("API function always exists");
.expect("API function always exists");
// `unwrap` is safe here because if there is any error during the // `unwrap` is safe here because if there is any error during the
// function call, it would occur in `call_packed!`. // function call, it would occur in `call_packed!`.
let ret: i64 = call_packed!(func, target).unwrap().try_into().unwrap(); let tgt = CString::new(target).unwrap();
let ret: i64 = call_packed!(func, &tgt).unwrap().try_into().unwrap();
ret != 0 ret != 0
} }
/// Returns the underlying module handle. /// Returns the underlying module handle.
pub fn handle(&self) -> ts::TVMModuleHandle { pub fn handle(&self) -> ffi::TVMModuleHandle {
self.handle self.handle
} }
/// Returns true if the underlying module has been dropped and false otherwise.
pub fn is_released(&self) -> bool {
self.is_released
}
} }
impl Drop for Module { impl Drop for Module {
fn drop(&mut self) { fn drop(&mut self) {
if !self.is_released { check_call!(ffi::TVMModFree(self.handle));
check_call!(ts::TVMModFree(self.handle));
self.is_released = true;
}
} }
} }
...@@ -23,34 +23,34 @@ ...@@ -23,34 +23,34 @@
//! [`copy_from_buffer`]:struct.NDArray.html#method.copy_from_buffer //! [`copy_from_buffer`]:struct.NDArray.html#method.copy_from_buffer
//! [`copy_to_ctx`]:struct.NDArray.html#method.copy_to_ctx //! [`copy_to_ctx`]:struct.NDArray.html#method.copy_to_ctx
use std::{convert::TryFrom, mem, os::raw::c_int, ptr, slice}; use std::{convert::TryFrom, mem, os::raw::c_int, ptr, slice, str::FromStr};
use crate::rust_ndarray::{Array, ArrayD}; use failure::Error;
use num_traits::Num; use num_traits::Num;
use rust_ndarray::{Array, ArrayD};
use tvm_common::{ffi, TVMType};
use crate::ts; use crate::{errors, TVMByteArray, TVMContext};
use crate::{Error, ErrorKind, Result, TVMByteArray, TVMContext, TVMType};
/// See the [`module-level documentation`](../ndarray/index.html) for more details. /// See the [`module-level documentation`](../ndarray/index.html) for more details.
/// ///
/// Wrapper around TVM array handle. /// Wrapper around TVM array handle.
#[derive(Debug)] #[derive(Debug)]
pub struct NDArray { pub struct NDArray {
pub(crate) handle: ts::TVMArrayHandle, pub(crate) handle: ffi::TVMArrayHandle,
is_view: bool, is_view: bool,
} }
impl NDArray { impl NDArray {
pub(crate) fn new(handle: ts::TVMArrayHandle, is_view: bool) -> Self { pub(crate) fn new(handle: ffi::TVMArrayHandle) -> Self {
NDArray { NDArray {
handle: handle, handle: handle,
is_view: is_view, is_view: true,
} }
} }
/// Returns the underlying array handle. /// Returns the underlying array handle.
pub fn handle(&self) -> ts::TVMArrayHandle { pub fn handle(&self) -> ffi::TVMArrayHandle {
self.handle self.handle
} }
...@@ -99,12 +99,13 @@ impl NDArray { ...@@ -99,12 +99,13 @@ impl NDArray {
} }
/// Shows whether the underlying ndarray is contiguous in memory or not. /// Shows whether the underlying ndarray is contiguous in memory or not.
pub fn is_contiguous(&self) -> Result<bool> { pub fn is_contiguous(&self) -> Result<bool, Error> {
Ok(match self.strides() { Ok(match self.strides() {
None => true, None => true,
Some(strides) => { Some(strides) => {
// MissingShapeError in case shape is not determined // errors::MissingShapeError in case shape is not determined
self.shape()? self.shape()
.ok_or(errors::MissingShapeError)?
.iter() .iter()
.zip(strides) .zip(strides)
.rfold( .rfold(
...@@ -138,14 +139,16 @@ impl NDArray { ...@@ -138,14 +139,16 @@ impl NDArray {
/// assert_eq!(ndarray.shape(), Some(shape)); /// assert_eq!(ndarray.shape(), Some(shape));
/// assert_eq!(ndarray.to_vec::<i32>().unwrap(), data); /// assert_eq!(ndarray.to_vec::<i32>().unwrap(), data);
/// ``` /// ```
pub fn to_vec<T>(&self) -> Result<Vec<T>> { pub fn to_vec<T>(&self) -> Result<Vec<T>, Error> {
if self.shape().is_none() { ensure!(self.shape().is_some(), errors::EmptyArrayError);
bail!("{}", ErrorKind::EmptyArray); let earr = NDArray::empty(
} self.shape().ok_or(errors::MissingShapeError)?,
let earr = NDArray::empty(self.shape()?, TVMContext::cpu(0), self.dtype()); TVMContext::cpu(0),
self.dtype(),
);
let target = self.copy_to_ndarray(earr)?; let target = self.copy_to_ndarray(earr)?;
let arr = unsafe { *(target.handle) }; let arr = unsafe { *(target.handle) };
let sz = self.size()? as usize; let sz = self.size().ok_or(errors::MissingShapeError)?;
let mut v: Vec<T> = Vec::with_capacity(sz * mem::size_of::<T>()); let mut v: Vec<T> = Vec::with_capacity(sz * mem::size_of::<T>());
unsafe { unsafe {
v.as_mut_ptr() v.as_mut_ptr()
...@@ -156,7 +159,7 @@ impl NDArray { ...@@ -156,7 +159,7 @@ impl NDArray {
} }
/// Converts the NDArray to [`TVMByteArray`]. /// Converts the NDArray to [`TVMByteArray`].
pub fn to_bytearray(&self) -> Result<TVMByteArray> { pub fn to_bytearray(&self) -> Result<TVMByteArray, Error> {
let v = self.to_vec::<u8>()?; let v = self.to_vec::<u8>()?;
Ok(TVMByteArray::from(&v)) Ok(TVMByteArray::from(&v))
} }
...@@ -176,7 +179,7 @@ impl NDArray { ...@@ -176,7 +179,7 @@ impl NDArray {
/// *Note*: if something goes wrong during the copy, it will panic /// *Note*: if something goes wrong during the copy, it will panic
/// from TVM side. See `TVMArrayCopyFromBytes` in `include/tvm/runtime/c_runtime_api.h`. /// from TVM side. See `TVMArrayCopyFromBytes` in `include/tvm/runtime/c_runtime_api.h`.
pub fn copy_from_buffer<T: Num32>(&mut self, data: &mut [T]) { pub fn copy_from_buffer<T: Num32>(&mut self, data: &mut [T]) {
check_call!(ts::TVMArrayCopyFromBytes( check_call!(ffi::TVMArrayCopyFromBytes(
self.handle, self.handle,
data.as_ptr() as *mut _, data.as_ptr() as *mut _,
data.len() * mem::size_of::<T>() data.len() * mem::size_of::<T>()
...@@ -184,27 +187,31 @@ impl NDArray { ...@@ -184,27 +187,31 @@ impl NDArray {
} }
/// Copies the NDArray to another target NDArray. /// Copies the NDArray to another target NDArray.
pub fn copy_to_ndarray(&self, target: NDArray) -> Result<NDArray> { pub fn copy_to_ndarray(&self, target: NDArray) -> Result<NDArray, Error> {
if self.dtype() != target.dtype() { if self.dtype() != target.dtype() {
bail!( bail!(
"{}", "{}",
ErrorKind::TypeMismatch( errors::TypeMismatchError {
format!("{}", self.dtype().to_string()), expected: format!("{}", self.dtype().to_string()),
format!("{}", target.dtype().to_string()), actual: format!("{}", target.dtype().to_string()),
) }
); );
} }
check_call!(ts::TVMArrayCopyFromTo( check_call!(ffi::TVMArrayCopyFromTo(
self.handle, self.handle,
target.handle, target.handle,
ptr::null_mut() as ts::TVMStreamHandle ptr::null_mut() as ffi::TVMStreamHandle
)); ));
Ok(target) Ok(target)
} }
/// Copies the NDArray to a target context. /// Copies the NDArray to a target context.
pub fn copy_to_ctx(&self, target: &TVMContext) -> Result<NDArray> { pub fn copy_to_ctx(&self, target: &TVMContext) -> Result<NDArray, Error> {
let tmp = NDArray::empty(self.shape()?, target.clone(), self.dtype()); let tmp = NDArray::empty(
self.shape().ok_or(errors::MissingShapeError)?,
target.clone(),
self.dtype(),
);
let copy = self.copy_to_ndarray(tmp)?; let copy = self.copy_to_ndarray(tmp)?;
Ok(copy) Ok(copy)
} }
...@@ -214,28 +221,34 @@ impl NDArray { ...@@ -214,28 +221,34 @@ impl NDArray {
rnd: &ArrayD<T>, rnd: &ArrayD<T>,
ctx: TVMContext, ctx: TVMContext,
dtype: TVMType, dtype: TVMType,
) -> Result<Self> { ) -> Result<Self, Error> {
let mut shape = rnd.shape().to_vec(); let mut shape = rnd.shape().to_vec();
let mut nd = NDArray::empty(&mut shape, ctx, dtype); let mut nd = NDArray::empty(&mut shape, ctx, dtype);
let mut buf = Array::from_iter(rnd.into_iter().map(|&v| v as T)); let mut buf = Array::from_iter(rnd.into_iter().map(|&v| v as T));
nd.copy_from_buffer(buf.as_slice_mut()?); nd.copy_from_buffer(
buf.as_slice_mut()
.expect("Array from iter must be contiguous."),
);
Ok(nd) Ok(nd)
} }
/// Allocates and creates an empty NDArray given the shape, context and dtype. /// Allocates and creates an empty NDArray given the shape, context and dtype.
pub fn empty(shape: &[usize], ctx: TVMContext, dtype: TVMType) -> NDArray { pub fn empty(shape: &[usize], ctx: TVMContext, dtype: TVMType) -> NDArray {
let mut handle = ptr::null_mut() as ts::TVMArrayHandle; let mut handle = ptr::null_mut() as ffi::TVMArrayHandle;
check_call!(ts::TVMArrayAlloc( check_call!(ffi::TVMArrayAlloc(
shape.as_ptr() as *const i64, shape.as_ptr() as *const i64,
shape.len() as c_int, shape.len() as c_int,
dtype.inner.code as c_int, dtype.code as c_int,
dtype.inner.bits as c_int, dtype.bits as c_int,
dtype.inner.lanes as c_int, dtype.lanes as c_int,
ctx.device_type.0 as c_int, ctx.device_type.0 as c_int,
ctx.device_id as c_int, ctx.device_id as c_int,
&mut handle as *mut _, &mut handle as *mut _,
)); ));
NDArray::new(handle, false) NDArray {
handle,
is_view: false,
}
} }
} }
...@@ -243,23 +256,25 @@ macro_rules! impl_from_ndarray_rustndarray { ...@@ -243,23 +256,25 @@ macro_rules! impl_from_ndarray_rustndarray {
($type:ty, $type_name:tt) => { ($type:ty, $type_name:tt) => {
impl<'a> TryFrom<&'a NDArray> for ArrayD<$type> { impl<'a> TryFrom<&'a NDArray> for ArrayD<$type> {
type Error = Error; type Error = Error;
fn try_from(nd: &NDArray) -> Result<ArrayD<$type>> { fn try_from(nd: &NDArray) -> Result<ArrayD<$type>, Self::Error> {
if nd.shape().is_none() { ensure!(nd.shape().is_some(), errors::MissingShapeError);
bail!("{}", ErrorKind::EmptyArray); assert_eq!(nd.dtype(), TVMType::from_str($type_name)?, "Type mismatch");
} Ok(Array::from_shape_vec(
assert_eq!(nd.dtype(), TVMType::from($type_name), "Type mismatch"); &*nd.shape().ok_or(errors::MissingShapeError)?,
Ok(Array::from_shape_vec(&*nd.shape()?, nd.to_vec::<$type>()?)?) nd.to_vec::<$type>()?,
)?)
} }
} }
impl<'a> TryFrom<&'a mut NDArray> for ArrayD<$type> { impl<'a> TryFrom<&'a mut NDArray> for ArrayD<$type> {
type Error = Error; type Error = Error;
fn try_from(nd: &mut NDArray) -> Result<ArrayD<$type>> { fn try_from(nd: &mut NDArray) -> Result<ArrayD<$type>, Self::Error> {
if nd.shape().is_none() { ensure!(nd.shape().is_some(), errors::MissingShapeError);
bail!("{}", ErrorKind::EmptyArray); assert_eq!(nd.dtype(), TVMType::from_str($type_name)?, "Type mismatch");
} Ok(Array::from_shape_vec(
assert_eq!(nd.dtype(), TVMType::from($type_name), "Type mismatch"); &*nd.shape().ok_or(errors::MissingShapeError)?,
Ok(Array::from_shape_vec(&*nd.shape()?, nd.to_vec::<$type>()?)?) nd.to_vec::<$type>()?,
)?)
} }
} }
}; };
...@@ -272,7 +287,7 @@ impl_from_ndarray_rustndarray!(f32, "float"); ...@@ -272,7 +287,7 @@ impl_from_ndarray_rustndarray!(f32, "float");
impl Drop for NDArray { impl Drop for NDArray {
fn drop(&mut self) { fn drop(&mut self) {
if !self.is_view { if !self.is_view {
check_call!(ts::TVMArrayFree(self.handle)); check_call!(ffi::TVMArrayFree(self.handle));
} }
} }
} }
...@@ -306,7 +321,7 @@ mod tests { ...@@ -306,7 +321,7 @@ mod tests {
fn basics() { fn basics() {
let shape = &mut [1, 2, 3]; let shape = &mut [1, 2, 3];
let ctx = TVMContext::cpu(0); let ctx = TVMContext::cpu(0);
let ndarray = NDArray::empty(shape, ctx, TVMType::from("int32")); let ndarray = NDArray::empty(shape, ctx, TVMType::from_str("int32").unwrap());
assert_eq!(ndarray.shape().unwrap(), shape); assert_eq!(ndarray.shape().unwrap(), shape);
assert_eq!( assert_eq!(
ndarray.size().unwrap(), ndarray.size().unwrap(),
...@@ -322,7 +337,7 @@ mod tests { ...@@ -322,7 +337,7 @@ mod tests {
let shape = &mut [4]; let shape = &mut [4];
let mut data = vec![1i32, 2, 3, 4]; let mut data = vec![1i32, 2, 3, 4];
let ctx = TVMContext::cpu(0); let ctx = TVMContext::cpu(0);
let mut ndarray = NDArray::empty(shape, ctx, TVMType::from("int32")); let mut ndarray = NDArray::empty(shape, ctx, TVMType::from_str("int32").unwrap());
assert!(ndarray.to_vec::<i32>().is_ok()); assert!(ndarray.to_vec::<i32>().is_ok());
ndarray.copy_from_buffer(&mut data); ndarray.copy_from_buffer(&mut data);
assert_eq!(ndarray.shape().unwrap(), shape); assert_eq!(ndarray.shape().unwrap(), shape);
...@@ -331,7 +346,11 @@ mod tests { ...@@ -331,7 +346,11 @@ mod tests {
assert!(ndarray.is_contiguous().is_ok()); assert!(ndarray.is_contiguous().is_ok());
assert_eq!(ndarray.byte_offset(), 0); assert_eq!(ndarray.byte_offset(), 0);
let mut shape = vec![4]; let mut shape = vec![4];
let e = NDArray::empty(&mut shape, TVMContext::cpu(0), TVMType::from("int32")); let e = NDArray::empty(
&mut shape,
TVMContext::cpu(0),
TVMType::from_str("int32").unwrap(),
);
let nd = ndarray.copy_to_ndarray(e); let nd = ndarray.copy_to_ndarray(e);
assert!(nd.is_ok()); assert!(nd.is_ok());
assert_eq!(nd.unwrap().to_vec::<i32>().unwrap(), data); assert_eq!(nd.unwrap().to_vec::<i32>().unwrap(), data);
...@@ -343,9 +362,13 @@ mod tests { ...@@ -343,9 +362,13 @@ mod tests {
let mut shape = vec![4]; let mut shape = vec![4];
let mut data = vec![1f32, 2., 3., 4.]; let mut data = vec![1f32, 2., 3., 4.];
let ctx = TVMContext::cpu(0); let ctx = TVMContext::cpu(0);
let mut nd_float = NDArray::empty(&mut shape, ctx.clone(), TVMType::from("float32")); let mut nd_float = NDArray::empty(
&mut shape,
ctx.clone(),
TVMType::from_str("float32").unwrap(),
);
nd_float.copy_from_buffer(&mut data); nd_float.copy_from_buffer(&mut data);
let empty_int = NDArray::empty(&mut shape, ctx, TVMType::from("int32")); let empty_int = NDArray::empty(&mut shape, ctx, TVMType::from_str("int32").unwrap());
nd_float.copy_to_ndarray(empty_int).unwrap(); nd_float.copy_to_ndarray(empty_int).unwrap();
} }
...@@ -354,8 +377,12 @@ mod tests { ...@@ -354,8 +377,12 @@ mod tests {
let a = Array::from_shape_vec((2, 2), vec![1f32, 2., 3., 4.]) let a = Array::from_shape_vec((2, 2), vec![1f32, 2., 3., 4.])
.unwrap() .unwrap()
.into_dyn(); .into_dyn();
let nd = let nd = NDArray::from_rust_ndarray(
NDArray::from_rust_ndarray(&a, TVMContext::cpu(0), TVMType::from("float32")).unwrap(); &a,
TVMContext::cpu(0),
TVMType::from_str("float32").unwrap(),
)
.unwrap();
assert_eq!(nd.shape().unwrap(), &mut [2, 2]); assert_eq!(nd.shape().unwrap(), &mut [2, 2]);
let rnd: ArrayD<f32> = ArrayD::try_from(&nd).unwrap(); let rnd: ArrayD<f32> = ArrayD::try_from(&nd).unwrap();
assert!(rnd.all_close(&a, 1e-8f32)); assert!(rnd.all_close(&a, 1e-8f32));
......
//! This module implements the required conversions from Rust types to TVM types.
//!
//! In TVM frontend only conversions from Rust's 32-bits (POD) numeric types (i32, u32, f32)
//! and 64-bits pointers are supported.
use std::{
fmt::{self, Display, Formatter},
ops::{Deref, DerefMut},
};
use crate::ts;
use crate::{Function, Module, NDArray, TVMByteArray, TVMContext, TVMDeviceType, TVMTypeCode};
macro_rules! impl_prim_type {
($type:ty, $variant:ident) => {
impl From<$type> for TVMTypeCode {
fn from(_arg: $type) -> Self {
TVMTypeCode::$variant
}
}
impl<'a> From<&'a $type> for TVMTypeCode {
fn from(_arg: &$type) -> Self {
TVMTypeCode::$variant
}
}
impl<'a> From<&'a mut $type> for TVMTypeCode {
fn from(_arg: &mut $type) -> Self {
TVMTypeCode::$variant
}
}
};
}
impl_prim_type!(TVMDeviceType, kDLInt);
impl_prim_type!(TVMContext, kTVMContext);
impl_prim_type!(TVMType, kTVMType);
impl_prim_type!(Function, kFuncHandle);
impl_prim_type!(Module, kModuleHandle);
impl_prim_type!(NDArray, kArrayHandle);
impl_prim_type!(TVMByteArray, kBytes);
/// See the [module-level documentation](../ty/index.html) for more details.
///
/// Wrapper around underlying TVMType
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub struct TVMType {
// inner fields are (code: u8, bits: u8, lanes: u16)
pub inner: ts::TVMType,
}
impl TVMType {
pub(crate) fn new(type_code: u8, bits: u8, lanes: u16) -> Self {
TVMType {
inner: ts::TVMType {
code: type_code,
bits: bits,
lanes: lanes,
},
}
}
}
/// Implements TVMType conversion from `&str` of general format `{dtype}{bits}x{lanes}`
/// such as "int32", "float32" or with lane "float32x1".
impl<'a> From<&'a str> for TVMType {
fn from(type_str: &'a str) -> Self {
if type_str == "bool" {
return TVMType::new(1, 1, 1);
}
let mut type_lanes = type_str.split("x");
let typ = type_lanes.next().expect("Missing dtype");
let lanes = type_lanes
.next()
.map(|l| u16::from_str_radix(l, 10).expect(&format!("Bad dtype lanes: {}", l)))
.unwrap_or(1);
let (type_name, bits) = match typ.find(char::is_numeric) {
Some(idx) => {
let (name, bits_str) = typ.split_at(idx);
(
name,
u8::from_str_radix(bits_str, 10)
.expect(&format!("Bad dtype bits: {}", bits_str)),
)
}
None => (typ, 32),
};
let type_code = match type_name {
"int" => 0,
"uint" => 1,
"float" => 2,
"handle" => 3,
_ => unimplemented!(),
};
TVMType::new(type_code, bits, lanes)
}
}
impl Display for TVMType {
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
let ts::TVMType { code, bits, lanes } = self.inner;
if bits == 1 && lanes == 1 {
return write!(f, "bool");
}
let mut tcode_str = match code {
0 => "int",
1 => "uint",
2 => "float",
4 => "handle",
_ => "Unknown",
}
.to_string();
tcode_str += &bits.to_string();
if lanes > 1 {
tcode_str += &format!("x{}", lanes.to_string());
}
f.write_str(&tcode_str)
}
}
impl From<TVMType> for ts::DLDataType {
fn from(dtype: TVMType) -> Self {
dtype.inner
}
}
impl From<ts::DLDataType> for TVMType {
fn from(dtype: ts::DLDataType) -> Self {
Self::new(dtype.code, dtype.bits, dtype.lanes)
}
}
impl Deref for TVMType {
type Target = ts::TVMType;
fn deref(&self) -> &Self::Target {
&self.inner
}
}
impl DerefMut for TVMType {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.inner
}
}
...@@ -2,139 +2,87 @@ ...@@ -2,139 +2,87 @@
//! and their conversions needed for the types used in frontend crate. //! and their conversions needed for the types used in frontend crate.
//! `TVMRetValue` is the owned version of `TVMPODValue`. //! `TVMRetValue` is the owned version of `TVMPODValue`.
use std::{convert::TryFrom, mem, os::raw::c_void}; use std::{convert::TryFrom, os::raw::c_void};
use failure::Error;
use tvm_common::{
ensure_type,
ffi::{self, TVMValue},
};
use crate::{ use crate::{
common_errors::*, ts, Function, Module, NDArray, TVMArgValue, TVMByteArray, TVMContext, common_errors::*, context::TVMContext, Function, Module, NDArray, TVMArgValue, TVMByteArray,
TVMDeviceType, TVMRetValue, TVMType, TVMTypeCode, TVMValue, TVMRetValue,
}; };
macro_rules! impl_tvm_val_from_handle { macro_rules! impl_tvm_val_from_handle {
($($ty:ty),+) => { ($ty:ident, $type_code:expr, $handle:ty) => {
$( impl<'a> From<&'a $ty> for TVMArgValue<'a> {
impl<'a> From<&'a $ty> for TVMValue {
fn from(arg: &$ty) -> Self { fn from(arg: &$ty) -> Self {
let inner = ts::TVMValue { TVMArgValue {
value: TVMValue {
v_handle: arg.handle as *mut _ as *mut c_void, v_handle: arg.handle as *mut _ as *mut c_void,
}; },
Self::new(inner) type_code: $type_code as i64,
} _lifetime: std::marker::PhantomData,
} }
)+
}
}
impl_tvm_val_from_handle!(Module, Function, NDArray);
impl<'a> From<&'a TVMType> for TVMValue {
fn from(ty: &TVMType) -> Self {
let inner = ts::TVMValue { v_type: ty.inner };
Self::new(inner)
} }
}
impl<'a> From<&'a TVMContext> for TVMValue {
fn from(ctx: &TVMContext) -> Self {
let inner = ts::TVMValue {
v_ctx: ctx.clone().into(),
};
Self::new(inner)
} }
}
impl<'a> From<&'a TVMDeviceType> for TVMValue { impl<'a> From<&'a mut $ty> for TVMArgValue<'a> {
fn from(dev: &TVMDeviceType) -> Self { fn from(arg: &mut $ty) -> Self {
let inner = ts::TVMValue { TVMArgValue {
v_int64: dev.0 as i64, value: TVMValue {
}; v_handle: arg.handle as *mut _ as *mut c_void,
Self::new(inner) },
} type_code: $type_code as i64,
} _lifetime: std::marker::PhantomData,
impl<'a> From<&'a TVMByteArray> for TVMValue {
fn from(barr: &TVMByteArray) -> Self {
let inner = ts::TVMValue {
v_handle: &barr.inner as *const ts::TVMByteArray as *mut c_void,
};
Self::new(inner)
} }
}
impl<'a, 'b> TryFrom<&'b TVMArgValue<'a>> for NDArray {
type Error = Error;
fn try_from(arg: &TVMArgValue<'a>) -> Result<Self> {
if arg.type_code == TVMTypeCode::kArrayHandle {
let handle = unsafe { arg.value.inner.v_handle };
let arr_handle = unsafe { mem::transmute::<*mut c_void, ts::TVMArrayHandle>(handle) };
Ok(Self::new(arr_handle, true))
} else {
bail!(ErrorKind::TryFromTVMArgValueError(
stringify!(NDArray).to_string(),
arg.type_code.to_string()
))
} }
} }
}
impl<'a, 'b> TryFrom<&'b TVMArgValue<'a>> for Module { impl<'a, 'v> TryFrom<&'a TVMArgValue<'v>> for $ty {
type Error = Error; type Error = Error;
fn try_from(arg: &TVMArgValue<'a>) -> Result<Self> { fn try_from(arg: &TVMArgValue<'v>) -> Result<$ty, Self::Error> {
if arg.type_code == TVMTypeCode::kModuleHandle { ensure_type!(arg, $type_code);
let handle = unsafe { arg.value.inner.v_handle }; Ok($ty::new(unsafe { arg.value.v_handle as $handle }))
Ok(Self::new(handle, false))
} else {
bail!(ErrorKind::TryFromTVMArgValueError(
stringify!(Module).to_string(),
arg.type_code.to_string()
))
} }
} }
}
impl<'a, 'b> TryFrom<&'b TVMArgValue<'a>> for TVMByteArray { impl From<$ty> for TVMRetValue {
type Error = Error; fn from(val: $ty) -> TVMRetValue {
fn try_from(arg: &TVMArgValue<'a>) -> Result<Self> { TVMRetValue {
if arg.type_code == TVMTypeCode::kBytes { value: TVMValue {
unsafe { v_handle: val.handle() as *mut c_void,
let barr_ptr = },
mem::transmute::<*mut c_void, *mut ts::TVMByteArray>(arg.value.inner.v_handle); box_value: box val,
Ok(Self::new(*barr_ptr)) type_code: $type_code as i64,
} }
} else {
bail!(ErrorKind::TryFromTVMArgValueError(
stringify!(TVMByteArray).to_string(),
arg.type_code.to_string()
))
} }
} }
}
impl<'a, 'b> TryFrom<&'b TVMArgValue<'a>> for TVMType { impl TryFrom<TVMRetValue> for $ty {
type Error = Error; type Error = Error;
fn try_from(arg: &TVMArgValue<'a>) -> Result<Self> { fn try_from(ret: TVMRetValue) -> Result<$ty, Self::Error> {
if arg.type_code == TVMTypeCode::kTVMType { ensure_type!(ret, $type_code);
let ty = unsafe { arg.value.inner.v_type }; Ok($ty::new(unsafe { ret.value.v_handle as $handle }))
Ok(TVMType::from(ty))
} else {
bail!(ErrorKind::TryFromTVMArgValueError(
stringify!(TVMType).to_string(),
arg.type_code.to_string()
))
} }
} }
};
} }
impl<'a, 'b> TryFrom<&'b TVMArgValue<'a>> for TVMContext { impl_tvm_val_from_handle!(
type Error = Error; Function,
fn try_from(arg: &TVMArgValue<'a>) -> Result<Self> { ffi::TVMTypeCode_kFuncHandle,
if arg.type_code == TVMTypeCode::kTVMContext { ffi::TVMFunctionHandle
let ty = unsafe { arg.value.inner.v_ctx }; );
Ok(TVMContext::from(ty)) impl_tvm_val_from_handle!(Module, ffi::TVMTypeCode_kModuleHandle, ffi::TVMModuleHandle);
} else { impl_tvm_val_from_handle!(NDArray, ffi::TVMTypeCode_kArrayHandle, ffi::TVMArrayHandle);
bail!(ErrorKind::TryFromTVMArgValueError(
stringify!(TVMContext).to_string(), impl<'a> From<&'a TVMByteArray> for TVMValue {
arg.type_code.to_string() fn from(barr: &TVMByteArray) -> Self {
)) TVMValue {
v_handle: &barr.inner as *const ffi::TVMByteArray as *mut c_void,
} }
} }
} }
...@@ -144,78 +92,43 @@ macro_rules! impl_boxed_ret_value { ...@@ -144,78 +92,43 @@ macro_rules! impl_boxed_ret_value {
impl From<$type> for TVMRetValue { impl From<$type> for TVMRetValue {
fn from(val: $type) -> Self { fn from(val: $type) -> Self {
TVMRetValue { TVMRetValue {
prim_value: 0, value: TVMValue { v_int64: 0 },
box_value: box val, box_value: box val,
type_code: $code, type_code: $code as i64,
} }
} }
} }
impl TryFrom<TVMRetValue> for $type { impl TryFrom<TVMRetValue> for $type {
type Error = Error; type Error = Error;
fn try_from(ret: TVMRetValue) -> Result<$type> { fn try_from(ret: TVMRetValue) -> Result<$type, Self::Error> {
if let Ok(val) = ret.box_value.downcast::<$type>() { if let Ok(val) = ret.box_value.downcast::<$type>() {
Ok(*val) Ok(*val)
} else { } else {
bail!(ErrorKind::TryFromTVMRetValueError( bail!(ValueDowncastError::new($code as i64, ret.type_code as i64))
stringify!($type).to_string(),
ret.type_code.to_string()
))
} }
} }
} }
}; };
} }
impl_boxed_ret_value!(TVMType, TVMTypeCode::kTVMType); impl_boxed_ret_value!(TVMContext, ffi::TVMTypeCode_kTVMContext);
impl_boxed_ret_value!(TVMContext, TVMTypeCode::kTVMContext); impl_boxed_ret_value!(TVMByteArray, ffi::TVMTypeCode_kBytes);
impl_boxed_ret_value!(TVMByteArray, TVMTypeCode::kBytes);
impl TryFrom<TVMRetValue> for Module {
type Error = Error;
fn try_from(ret: TVMRetValue) -> Result<Module> {
if let Ok(handle) = ret.box_value.downcast::<ts::TVMModuleHandle>() {
Ok(Module::new(*handle, false))
} else {
bail!(ErrorKind::TryFromTVMRetValueError(
stringify!(TVMTypeCode::kModuleHandle).to_string(),
ret.type_code.to_string()
))
}
}
}
impl TryFrom<TVMRetValue> for Function {
type Error = Error;
fn try_from(ret: TVMRetValue) -> Result<Function> {
if let Ok(handle) = ret.box_value.downcast::<ts::TVMFunctionHandle>() {
Ok(Function::new(*handle, false, false))
} else {
bail!(ErrorKind::TryFromTVMRetValueError(
stringify!(TVMTypeCode::kFuncHandle).to_string(),
ret.type_code.to_string()
))
}
}
}
impl TryFrom<TVMRetValue> for NDArray { impl<'a, 'v> TryFrom<&'a TVMArgValue<'v>> for TVMByteArray {
type Error = Error; type Error = Error;
fn try_from(ret: TVMRetValue) -> Result<NDArray> { fn try_from(arg: &TVMArgValue<'v>) -> Result<Self, Self::Error> {
if let Ok(handle) = ret.box_value.downcast::<ts::TVMArrayHandle>() { ensure_type!(arg, ffi::TVMTypeCode_kBytes);
Ok(NDArray::new(*handle, false)) Ok(TVMByteArray::new(unsafe {
} else { *(arg.value.v_handle as *mut ffi::TVMByteArray)
bail!(ErrorKind::TryFromTVMRetValueError( }))
stringify!(TVMTypeCode::kArrayHandle).to_string(),
ret.type_code.to_string()
))
}
} }
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use std::convert::TryInto; use std::{convert::TryInto, str::FromStr};
use tvm_common::ffi::TVMType;
#[test] #[test]
fn bytearray() { fn bytearray() {
...@@ -227,7 +140,7 @@ mod tests { ...@@ -227,7 +140,7 @@ mod tests {
#[test] #[test]
fn ty() { fn ty() {
let t = TVMType::from("int32"); let t = TVMType::from_str("int32").unwrap();
let tvm: TVMType = TVMRetValue::from(t).try_into().unwrap(); let tvm: TVMType = TVMRetValue::from(t).try_into().unwrap();
assert_eq!(tvm, t); assert_eq!(tvm, t);
} }
......
extern crate ndarray as rust_ndarray; extern crate ndarray as rust_ndarray;
extern crate tvm_frontend as tvm; extern crate tvm_frontend as tvm;
use std::str::FromStr;
use tvm::*; use tvm::*;
fn main() { fn main() {
...@@ -12,7 +14,7 @@ fn main() { ...@@ -12,7 +14,7 @@ fn main() {
} else { } else {
(TVMContext::gpu(0), "gpu") (TVMContext::gpu(0), "gpu")
}; };
let dtype = TVMType::from("float32"); let dtype = TVMType::from_str("float32").unwrap();
let mut arr = NDArray::empty(shape, ctx, dtype); let mut arr = NDArray::empty(shape, ctx, dtype);
arr.copy_from_buffer(data.as_mut_slice()); arr.copy_from_buffer(data.as_mut_slice());
let mut ret = NDArray::empty(shape, ctx, dtype); let mut ret = NDArray::empty(shape, ctx, dtype);
...@@ -26,8 +28,7 @@ fn main() { ...@@ -26,8 +28,7 @@ fn main() {
function::Builder::from(&mut fadd) function::Builder::from(&mut fadd)
.arg(&arr) .arg(&arr)
.arg(&arr) .arg(&arr)
.set_output(&mut ret) .arg(&mut ret)
.unwrap()
.invoke() .invoke()
.unwrap(); .unwrap();
......
#![feature(extern_crate_item_prelude, try_from)]
#![allow(unused_imports)] #![allow(unused_imports)]
extern crate ndarray as rust_ndarray; extern crate ndarray as rust_ndarray;
...@@ -6,17 +5,23 @@ extern crate ndarray as rust_ndarray; ...@@ -6,17 +5,23 @@ extern crate ndarray as rust_ndarray;
extern crate tvm_frontend as tvm; extern crate tvm_frontend as tvm;
use rust_ndarray::ArrayD; use rust_ndarray::ArrayD;
use std::convert::{TryFrom, TryInto}; use std::{
convert::{TryFrom, TryInto},
str::FromStr,
};
use tvm::*; use tvm::{errors::Error, *};
fn main() { fn main() {
register_global_func! { register_global_func! {
fn sum(args: &[TVMArgValue]) -> Result<TVMRetValue> { fn sum(args: &[TVMArgValue]) -> Result<TVMRetValue, Error> {
let mut ret = 0f32; let mut ret = 0f32;
let shape = &mut [2]; let shape = &mut [2];
for arg in args.iter() { for arg in args.iter() {
let e = NDArray::empty(shape, TVMContext::cpu(0), TVMType::from("float32")); let e = NDArray::empty(
shape, TVMContext::cpu(0),
TVMType::from_str("float32").unwrap()
);
let arg: NDArray = arg.try_into()?; let arg: NDArray = arg.try_into()?;
let arr = arg.copy_to_ndarray(e)?; let arr = arg.copy_to_ndarray(e)?;
let rnd: ArrayD<f32> = ArrayD::try_from(&arr)?; let rnd: ArrayD<f32> = ArrayD::try_from(&arr)?;
...@@ -28,12 +33,16 @@ fn main() { ...@@ -28,12 +33,16 @@ fn main() {
let shape = &mut [2]; let shape = &mut [2];
let mut data = vec![3f32, 4.0]; let mut data = vec![3f32, 4.0];
let mut arr = NDArray::empty(shape, TVMContext::cpu(0), TVMType::from("float32")); let mut arr = NDArray::empty(
shape,
TVMContext::cpu(0),
TVMType::from_str("float32").unwrap(),
);
arr.copy_from_buffer(data.as_mut_slice()); arr.copy_from_buffer(data.as_mut_slice());
let mut registered = function::Builder::default(); let mut registered = function::Builder::default();
let ret: f32 = registered let ret: f32 = registered
.get_function("sum", true) .get_function("sum")
.arg(&arr) .arg(&arr)
.arg(&arr) .arg(&arr)
.invoke() .invoke()
......
#![feature(extern_crate_item_prelude, panic_info_message)] #![feature(panic_info_message)]
#![allow(unused_imports)] #![allow(unused_imports)]
use std::panic; use std::panic;
...@@ -6,20 +6,20 @@ use std::panic; ...@@ -6,20 +6,20 @@ use std::panic;
#[macro_use] #[macro_use]
extern crate tvm_frontend as tvm; extern crate tvm_frontend as tvm;
use tvm::*; use tvm::{errors::Error, *};
fn main() { fn main() {
register_global_func! { register_global_func! {
fn error(_args: &[TVMArgValue]) -> Result<TVMRetValue> { fn error(_args: &[TVMArgValue]) -> Result<TVMRetValue, Error> {
Err(ErrorKind::TypeMismatch( Err(errors::TypeMismatchError{
format!("{}", "i64".to_string()), expected: "i64".to_string(),
format!("{}", "f64".to_string()), actual: "f64".to_string(),
).into()) }.into())
} }
} }
let mut registered = function::Builder::default(); let mut registered = function::Builder::default();
registered.get_function("error", true); registered.get_function("error");
assert!(registered.func.is_some()); assert!(registered.func.is_some());
registered.args(&[10, 20]); registered.args(&[10, 20]);
......
#![feature(extern_crate_item_prelude, try_from)]
#![allow(unused_imports)] #![allow(unused_imports)]
#[macro_use] #[macro_use]
extern crate tvm_frontend as tvm; extern crate tvm_frontend as tvm;
use std::convert::TryInto; use std::convert::TryInto;
use tvm::*; use tvm::{errors::Error, *};
fn main() { fn main() {
register_global_func! { register_global_func! {
fn sum(args: &[TVMArgValue]) -> Result<TVMRetValue> { fn sum(args: &[TVMArgValue]) -> Result<TVMRetValue, Error> {
let mut ret = 0.0; let mut ret = 0.0;
for arg in args.iter() { for arg in args.into_iter() {
let val: f64 = arg.try_into()?; let val: f64 = arg.try_into()?;
ret += val; ret += val;
} }
Ok(TVMRetValue::from(&ret)) Ok(TVMRetValue::from(ret))
} }
} }
let mut registered = function::Builder::default(); let mut registered = function::Builder::default();
registered.get_function("sum", true); registered.get_function("sum");
assert!(registered.func.is_some()); assert!(registered.func.is_some());
let ret: f64 = registered let ret: f64 = registered
.args(&[10.0f64, 20.0, 30.0]) .args(&[10.0f64, 20.0, 30.0])
......
#![feature(extern_crate_item_prelude, try_from)]
#![allow(unused_imports)] #![allow(unused_imports)]
extern crate tvm_frontend as tvm; extern crate tvm_frontend as tvm;
use std::convert::TryInto; use std::convert::TryInto;
use tvm::*; use tvm::{errors::Error, *};
fn main() { fn main() {
fn sum(args: &[TVMArgValue]) -> Result<TVMRetValue> { fn sum(args: &[TVMArgValue]) -> Result<TVMRetValue, Error> {
let mut ret = 0i64; let mut ret = 0i64;
for arg in args.iter() { for arg in args.iter() {
let val: i64 = arg.try_into()?; let val: i64 = arg.try_into()?;
ret += val; ret += val;
} }
Ok(TVMRetValue::from(&ret)) Ok(TVMRetValue::from(ret))
} }
tvm::function::register(sum, "mysum".to_owned(), false).unwrap(); tvm::function::register(sum, "mysum".to_owned(), false).unwrap();
let mut registered = function::Builder::default(); let mut registered = function::Builder::default();
registered.get_function("mysum", true); registered.get_function("mysum");
assert!(registered.func.is_some()); assert!(registered.func.is_some());
let ret: i64 = registered let ret: i64 = registered
.args(&[10, 20, 30]) .args(&[10, 20, 30])
......
#![feature(extern_crate_item_prelude, try_from)]
#![allow(unused_imports)] #![allow(unused_imports)]
#[macro_use] #[macro_use]
extern crate tvm_frontend as tvm; extern crate tvm_frontend as tvm;
use std::convert::TryInto; use std::convert::TryInto;
use tvm::*; use tvm::{errors::Error, *};
// FIXME // FIXME
fn main() { fn main() {
register_global_func! { register_global_func! {
fn concate_str(args: &[TVMArgValue]) -> Result<TVMRetValue> { fn concate_str(args: &[TVMArgValue]) -> Result<TVMRetValue, Error> {
let mut ret = "".to_string(); let mut ret = "".to_string();
for arg in args.iter() { for arg in args.iter() {
let val: String = arg.try_into()?; let val: &str = arg.try_into()?;
ret += val.as_str(); ret += val;
} }
Ok(TVMRetValue::from(ret)) Ok(TVMRetValue::from(ret))
} }
} }
let a = std::ffi::CString::new("a").unwrap();
let b = std::ffi::CString::new("b").unwrap();
let c = std::ffi::CString::new("c").unwrap();
let mut registered = function::Builder::default(); let mut registered = function::Builder::default();
registered.get_function("concate_str", true); registered.get_function("concate_str");
assert!(registered.func.is_some()); assert!(registered.func.is_some());
let a = "a".to_string();
let b = "b".to_string();
let c = "c".to_string();
let ret: String = registered let ret: String = registered
.args(&[a, b, c]) .arg(&a)
.arg(&b)
.arg(&c)
.invoke() .invoke()
.unwrap() .unwrap()
.try_into() .try_into()
......
...@@ -15,15 +15,15 @@ sgx = ["nom/alloc"] ...@@ -15,15 +15,15 @@ sgx = ["nom/alloc"]
[dependencies] [dependencies]
bounded-spsc-queue = "0.4.0" bounded-spsc-queue = "0.4.0"
error-chain = { version = "0.12.0", default-features = false } failure = "0.1.5"
itertools = "0.7.8" itertools = "0.7.8"
lazy_static = "1.1.0" lazy_static = "1.1.0"
ndarray = "0.11.2" ndarray="0.12.1"
nom = {version = "4.0.0", default-features = false } nom = {version = "4.0.0", default-features = false }
serde = "1.0.59" serde = "1.0.59"
serde_derive = "1.0.79" serde_derive = "1.0.79"
serde_json = "1.0.17" serde_json = "1.0.17"
tvm-common = { version = "0.1.0", path = "../common/", features = ["runtime"] } tvm-common = { version = "0.1.0", path = "../common/" }
[target.'cfg(not(target_env = "sgx"))'.dependencies] [target.'cfg(not(target_env = "sgx"))'.dependencies]
num_cpus = "1.8.0" num_cpus = "1.8.0"
#[cfg(target_env = "sgx")] #[cfg(target_env = "sgx")]
use alloc::alloc::{self, Layout}; use alloc::alloc::{self, Layout, LayoutErr};
#[cfg(not(target_env = "sgx"))] #[cfg(not(target_env = "sgx"))]
use std::alloc::{self, Layout}; use std::alloc::{self, Layout, LayoutErr};
use crate::errors::*;
const DEFAULT_ALIGN_BYTES: usize = 4; const DEFAULT_ALIGN_BYTES: usize = 4;
...@@ -15,7 +13,7 @@ pub struct Allocation { ...@@ -15,7 +13,7 @@ pub struct Allocation {
impl Allocation { impl Allocation {
/// Allocates a chunk of memory of `size` bytes with optional alignment. /// Allocates a chunk of memory of `size` bytes with optional alignment.
pub fn new(size: usize, align: Option<usize>) -> Result<Self> { pub fn new(size: usize, align: Option<usize>) -> Result<Self, LayoutErr> {
let alignment = align.unwrap_or(DEFAULT_ALIGN_BYTES); let alignment = align.unwrap_or(DEFAULT_ALIGN_BYTES);
let layout = Layout::from_size_align(size, alignment)?; let layout = Layout::from_size_align(size, alignment)?;
let ptr = unsafe { alloc::alloc(layout.clone()) }; let ptr = unsafe { alloc::alloc(layout.clone()) };
......
use std::{ use std::{convert::TryFrom, mem, os::raw::c_void, ptr, slice};
any::TypeId,
convert::TryFrom,
mem,
ops::{Deref, DerefMut},
os::raw::{c_int, c_void},
ptr, slice,
};
use failure::Error;
use ndarray; use ndarray;
use tvm_common::{
use crate::{ array::{DataType, TVMContext},
allocator::Allocation, ffi::{
errors::*,
ffi::runtime::{
DLContext, DLDataType, DLDataTypeCode_kDLFloat, DLDataTypeCode_kDLInt, DLContext, DLDataType, DLDataTypeCode_kDLFloat, DLDataTypeCode_kDLInt,
DLDataTypeCode_kDLUInt, DLDeviceType_kDLCPU, DLTensor as _DLTensor, DLDataTypeCode_kDLUInt, DLTensor,
}, },
}; };
use crate::allocator::Allocation;
/// A `Storage` is a container which holds `Tensor` data. /// A `Storage` is a container which holds `Tensor` data.
#[derive(PartialEq)] #[derive(PartialEq)]
pub enum Storage<'a> { pub enum Storage<'a> {
...@@ -29,7 +23,7 @@ pub enum Storage<'a> { ...@@ -29,7 +23,7 @@ pub enum Storage<'a> {
} }
impl<'a> Storage<'a> { impl<'a> Storage<'a> {
pub fn new(size: usize, align: Option<usize>) -> Result<Storage<'static>> { pub fn new(size: usize, align: Option<usize>) -> Result<Storage<'static>, Error> {
Ok(Storage::Owned(Allocation::new(size, align)?)) Ok(Storage::Owned(Allocation::new(size, align)?))
} }
...@@ -237,6 +231,27 @@ impl<'a> Tensor<'a> { ...@@ -237,6 +231,27 @@ impl<'a> Tensor<'a> {
byte_offset: 0, byte_offset: 0,
} }
} }
pub(crate) fn as_dltensor(&self, flatten: bool) -> DLTensor {
assert!(!flatten || self.is_contiguous());
DLTensor {
data: unsafe { self.data.as_mut_ptr().offset(self.byte_offset) } as *mut c_void,
ctx: DLContext::from(&self.ctx),
ndim: if flatten { 1 } else { self.shape.len() } as i32,
dtype: DLDataType::from(&self.dtype),
shape: if flatten {
&self.size as *const _ as *mut i64
} else {
self.shape.as_ptr()
} as *mut i64,
strides: if flatten || self.is_contiguous() {
ptr::null_mut()
} else {
self.strides.as_ref().unwrap().as_ptr()
} as *mut i64,
byte_offset: 0,
}
}
} }
/// Conversions to `ndarray::Array` from `Tensor`, if the types match. /// Conversions to `ndarray::Array` from `Tensor`, if the types match.
...@@ -244,7 +259,7 @@ macro_rules! impl_ndarray_try_from_tensor { ...@@ -244,7 +259,7 @@ macro_rules! impl_ndarray_try_from_tensor {
($type:ty, $dtype:expr) => { ($type:ty, $dtype:expr) => {
impl<'a, 't> TryFrom<&'a Tensor<'t>> for ndarray::ArrayD<$type> { impl<'a, 't> TryFrom<&'a Tensor<'t>> for ndarray::ArrayD<$type> {
type Error = Error; type Error = Error;
fn try_from(tensor: &'a Tensor) -> Result<ndarray::ArrayD<$type>> { fn try_from(tensor: &'a Tensor) -> Result<ndarray::ArrayD<$type>, Error> {
ensure!( ensure!(
tensor.dtype == $dtype, tensor.dtype == $dtype,
"Cannot convert Tensor with dtype {:?} to ndarray", "Cannot convert Tensor with dtype {:?} to ndarray",
...@@ -263,120 +278,9 @@ macro_rules! impl_ndarray_try_from_tensor { ...@@ -263,120 +278,9 @@ macro_rules! impl_ndarray_try_from_tensor {
}; };
} }
impl_ndarray_try_from_tensor!(i32, DTYPE_INT32);
impl_ndarray_try_from_tensor!(u32, DTYPE_UINT32);
impl_ndarray_try_from_tensor!(f32, DTYPE_FLOAT32);
impl_ndarray_try_from_tensor!(f64, DTYPE_FLOAT64);
pub struct DLTensor {
pub(crate) inner: _DLTensor,
}
impl Deref for DLTensor {
type Target = _DLTensor;
fn deref(&self) -> &Self::Target {
&self.inner
}
}
impl DerefMut for DLTensor {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.inner
}
}
impl DLTensor {
pub(crate) fn new(raw: _DLTensor) -> Self {
Self { inner: raw }
}
pub(crate) fn from_tensor<'a>(tensor: &'a Tensor, flatten: bool) -> Self {
assert!(!flatten || tensor.is_contiguous());
Self {
inner: _DLTensor {
data: unsafe { tensor.data.as_mut_ptr().offset(tensor.byte_offset) } as *mut c_void,
ctx: DLContext::from(&tensor.ctx),
ndim: if flatten { 1 } else { tensor.shape.len() } as i32,
dtype: DLDataType::from(&tensor.dtype),
shape: if flatten {
&tensor.size as *const _ as *mut i64
} else {
tensor.shape.as_ptr()
} as *mut i64,
strides: if flatten || tensor.is_contiguous() {
ptr::null_mut()
} else {
tensor.strides.as_ref().unwrap().as_ptr()
} as *mut i64,
byte_offset: 0,
},
}
}
}
impl<'a, 't> From<&'a Tensor<'t>> for DLTensor {
fn from(tensor: &'a Tensor<'t>) -> Self {
DLTensor::from_tensor(tensor, false /* flatten */)
}
}
impl<'a, 't> From<&'a mut Tensor<'t>> for DLTensor {
fn from(tensor: &'a mut Tensor<'t>) -> Self {
DLTensor::from_tensor(tensor, false /* flatten */)
}
}
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct DataType {
pub(crate) code: usize,
pub(crate) bits: usize,
pub(crate) lanes: usize,
}
impl DataType {
/// Returns the number of bytes occupied by an element of this `DataType`.
pub fn itemsize(&self) -> usize {
(self.bits * self.lanes) >> 3
}
/// Returns whether this `DataType` represents primitive type `T`.
pub fn is_type<T: 'static>(&self) -> bool {
if self.lanes != 1 {
return false;
}
let typ = TypeId::of::<T>();
(typ == TypeId::of::<i32>() && self.code == 0 && self.bits == 32)
|| (typ == TypeId::of::<i64>() && self.code == 0 && self.bits == 64)
|| (typ == TypeId::of::<u32>() && self.code == 1 && self.bits == 32)
|| (typ == TypeId::of::<u64>() && self.code == 1 && self.bits == 64)
|| (typ == TypeId::of::<f32>() && self.code == 2 && self.bits == 32)
|| (typ == TypeId::of::<f64>() && self.code == 2 && self.bits == 64)
}
}
impl<'a> From<&'a DataType> for DLDataType {
fn from(dtype: &'a DataType) -> Self {
Self {
code: dtype.code as u8,
bits: dtype.bits as u8,
lanes: dtype.lanes as u16,
}
}
}
impl From<DLDataType> for DataType {
fn from(dtype: DLDataType) -> Self {
Self {
code: dtype.code as usize,
bits: dtype.bits as usize,
lanes: dtype.lanes as usize,
}
}
}
macro_rules! make_dtype_const { macro_rules! make_dtype_const {
($name: ident, $code: ident, $bits: expr, $lanes: expr) => { ($name: ident, $code: ident, $bits: expr, $lanes: expr) => {
const $name: DataType = DataType { pub const $name: DataType = DataType {
code: $code as usize, code: $code as usize,
bits: $bits, bits: $bits,
lanes: $lanes, lanes: $lanes,
...@@ -389,28 +293,20 @@ make_dtype_const!(DTYPE_UINT32, DLDataTypeCode_kDLUInt, 32, 1); ...@@ -389,28 +293,20 @@ make_dtype_const!(DTYPE_UINT32, DLDataTypeCode_kDLUInt, 32, 1);
// make_dtype_const!(DTYPE_FLOAT16, DLDataTypeCode_kDLFloat, 16, 1); // make_dtype_const!(DTYPE_FLOAT16, DLDataTypeCode_kDLFloat, 16, 1);
make_dtype_const!(DTYPE_FLOAT32, DLDataTypeCode_kDLFloat, 32, 1); make_dtype_const!(DTYPE_FLOAT32, DLDataTypeCode_kDLFloat, 32, 1);
make_dtype_const!(DTYPE_FLOAT64, DLDataTypeCode_kDLFloat, 64, 1); make_dtype_const!(DTYPE_FLOAT64, DLDataTypeCode_kDLFloat, 64, 1);
impl_ndarray_try_from_tensor!(i32, DTYPE_INT32);
impl_ndarray_try_from_tensor!(u32, DTYPE_UINT32);
impl_ndarray_try_from_tensor!(f32, DTYPE_FLOAT32);
impl_ndarray_try_from_tensor!(f64, DTYPE_FLOAT64);
#[derive(Debug, Clone, Copy, PartialEq)] impl<'a, 't> From<&'a Tensor<'t>> for DLTensor {
pub struct TVMContext { fn from(tensor: &'a Tensor<'t>) -> Self {
pub(crate) device_type: usize, Tensor::as_dltensor(tensor, false /* flatten */)
pub(crate) device_id: usize,
}
impl<'a> From<&'a TVMContext> for DLContext {
fn from(ctx: &'a TVMContext) -> Self {
Self {
device_type: ctx.device_type as u32,
device_id: ctx.device_id as i32,
}
} }
} }
impl Default for TVMContext { impl<'a, 't> From<&'a mut Tensor<'t>> for DLTensor {
fn default() -> Self { fn from(tensor: &'a mut Tensor<'t>) -> Self {
Self { Tensor::as_dltensor(tensor, false /* flatten */)
device_type: DLDeviceType_kDLCPU as usize,
device_id: 0,
}
} }
} }
...@@ -463,42 +359,6 @@ macro_rules! impl_tensor_from_ndarray { ...@@ -463,42 +359,6 @@ macro_rules! impl_tensor_from_ndarray {
}; };
} }
/// `From` conversions to `DLTensor` for `ndarray::Array`.
/// Takes a reference to the `ndarray` since `DLTensor` is not owned.
macro_rules! impl_dltensor_from_ndarray {
($type:ty, $typecode:expr) => {
impl<'a, D: ndarray::Dimension> From<&'a mut ndarray::Array<$type, D>> for DLTensor {
fn from(arr: &'a mut ndarray::Array<$type, D>) -> Self {
DLTensor {
inner: _DLTensor {
data: arr.as_mut_ptr() as *mut c_void,
ctx: DLContext {
device_type: DLDeviceType_kDLCPU,
device_id: 0,
},
ndim: arr.ndim() as c_int,
dtype: DLDataType {
code: $typecode as u8,
bits: 8 * mem::size_of::<$type>() as u8,
lanes: 1,
},
shape: arr.shape().as_ptr() as *const i64 as *mut i64,
strides: arr.strides().as_ptr() as *const isize as *mut i64,
byte_offset: 0,
},
}
}
}
};
}
impl_dltensor_from_ndarray!(f32, DLDataTypeCode_kDLFloat);
impl_dltensor_from_ndarray!(f64, DLDataTypeCode_kDLFloat);
impl_dltensor_from_ndarray!(i32, DLDataTypeCode_kDLInt);
impl_dltensor_from_ndarray!(i64, DLDataTypeCode_kDLInt);
impl_dltensor_from_ndarray!(u32, DLDataTypeCode_kDLUInt);
impl_dltensor_from_ndarray!(u64, DLDataTypeCode_kDLUInt);
impl_tensor_from_ndarray!(f32, DLDataTypeCode_kDLFloat); impl_tensor_from_ndarray!(f32, DLDataTypeCode_kDLFloat);
impl_tensor_from_ndarray!(f64, DLDataTypeCode_kDLFloat); impl_tensor_from_ndarray!(f64, DLDataTypeCode_kDLFloat);
impl_tensor_from_ndarray!(i32, DLDataTypeCode_kDLInt); impl_tensor_from_ndarray!(i32, DLDataTypeCode_kDLInt);
......
#[cfg(target_env = "sgx")] #[derive(Debug, Fail)]
use alloc::alloc; pub enum GraphFormatError {
#[cfg(not(target_env = "sgx"))] #[fail(display = "Could not parse graph json")]
use std::alloc; Parse(#[fail(cause)] failure::Error),
use std::num; #[fail(display = "Could not parse graph params")]
Params,
use crate::common::errors as common_errors; #[fail(display = "{} is missing attr: {}", 0, 1)]
use ndarray; MissingAttr(String, String),
use serde_json; #[fail(display = "Missing field: {}", 0)]
MissingField(&'static str),
error_chain! { #[fail(display = "Invalid DLType: {}", 0)]
errors { InvalidDLType(String),
GraphFormatError(msg: String) {
description("unable to load graph")
display("could not load graph json: {}", msg)
}
LoadGraphParamsError(msg: String) {
description("unable to load graph params")
display("could not load graph params: {}", msg)
}
}
foreign_links {
Alloc(alloc::AllocErr);
GraphDeserialize(serde_json::Error);
ParseInt(num::ParseIntError);
ShapeError(ndarray::ShapeError);
CommonError(common_errors::Error);
}
} }
impl From<alloc::LayoutErr> for Error { #[derive(Debug, Fail)]
fn from(_err: alloc::LayoutErr) -> Error { #[fail(display = "SGX error: 0x{:x}", code)]
Error::from_kind(ErrorKind::Msg("Layout error".to_string())) pub struct SgxError {
} pub code: u32,
} }
use std::{cmp, collections::HashMap, convert::TryFrom, iter::FromIterator, mem, str}; use std::{cmp, collections::HashMap, convert::TryFrom, iter::FromIterator, mem, str};
use failure::Error;
use nom::{alpha1, digit1, le_i32, le_i64, le_u16, le_u32, le_u64, le_u8, types::CompleteStr}; use nom::{alpha1, digit1, le_i32, le_i64, le_u16, le_u32, le_u64, le_u8, types::CompleteStr};
use serde; use serde;
use serde_json; use serde_json;
use tvm_common::{
use super::{DLTensor, DataType, Module, Storage, TVMContext, Tensor}; array::{DataType, TVMContext},
use crate::{ ffi::{DLDataTypeCode_kDLFloat, DLDataTypeCode_kDLInt, DLDataTypeCode_kDLUInt, DLTensor},
common::value::TVMArgValue, TVMArgValue,
errors::{Error, ErrorKind, Result},
ffi::runtime::{DLDataTypeCode_kDLFloat, DLDataTypeCode_kDLInt, DLDataTypeCode_kDLUInt},
}; };
use crate::{errors::GraphFormatError, Module, Storage, Tensor};
// @see `kTVMNDArrayMagic` in `ndarray.h` // @see `kTVMNDArrayMagic` in `ndarray.h`
const _NDARRAY_MAGIC: u64 = 0xDD5E40F096B4A13F; const _NDARRAY_MAGIC: u64 = 0xDD5E40F096B4A13F;
// @see `kTVMNDArrayListMagic` in `graph_runtime.h` // @see `kTVMNDArrayListMagic` in `graph_runtime.h`
...@@ -41,28 +42,26 @@ pub struct Entry { ...@@ -41,28 +42,26 @@ pub struct Entry {
} }
impl Graph { impl Graph {
fn entry_index(&self, entry: &Entry) -> Result<usize> { fn entry_index(&self, entry: &Entry) -> Result<usize, GraphFormatError> {
self.node_row_ptr self.node_row_ptr
.as_ref() .as_ref()
.map(|nrp| nrp[entry.id] + entry.index) .map(|nrp| nrp[entry.id] + entry.index)
.ok_or("Missing node_row_ptr.".into()) .ok_or_else(|| GraphFormatError::MissingField("node_row_ptr"))
} }
/// Attempt to deserialize a JSON attribute to a type `T`. /// Attempt to deserialize a JSON attribute to a type `T`.
fn get_attr<T: serde::de::DeserializeOwned>(&self, attr: &str) -> Result<T> { fn get_attr<T: serde::de::DeserializeOwned>(&self, attr: &str) -> Result<T, GraphFormatError> {
Ok(serde_json::from_value::<T>( Ok(serde_json::from_value::<T>(
self.attrs self.attrs
.as_ref() .as_ref()
.ok_or(ErrorKind::GraphFormatError( .ok_or(GraphFormatError::MissingField("attrs"))?
"Missing graph attrs".to_string(),
))?
.get(attr) .get(attr)
.ok_or(ErrorKind::GraphFormatError(format!( .ok_or_else(|| {
"Missing {} attr", GraphFormatError::MissingAttr("graph".to_string(), attr.to_string())
attr })?
)))?
.to_owned(), .to_owned(),
)?) )
.map_err(|err| GraphFormatError::Parse(err.into()))?)
} }
} }
...@@ -81,39 +80,31 @@ struct NodeAttrs { ...@@ -81,39 +80,31 @@ struct NodeAttrs {
flatten_data: bool, flatten_data: bool,
} }
macro_rules! get_node_attr {
($node:expr, $attrs:ident, $attr:literal) => {
$attrs
.get($attr)
.ok_or_else(|| GraphFormatError::MissingAttr($node.to_owned(), $attr.to_owned()))
};
}
impl Node { impl Node {
fn parse_attrs(&self) -> Result<NodeAttrs> { fn parse_attrs(&self) -> Result<NodeAttrs, Error> {
let attrs = self let attrs = self
.attrs .attrs
.as_ref() .as_ref()
.ok_or(format!("Missing node.attrs for `{}`", self.name))?; .ok_or_else(|| GraphFormatError::MissingAttr(self.name.clone(), "attrs".to_owned()))?;
let func_name = attrs
.get("func_name")
.ok_or(format!("Node `{}` is missing attrs.func_name", self.name))?
.to_string();
let num_outputs = attrs
.get("num_outputs")
.ok_or(format!("Node `{}` is missing attrs.num_outputs", self.name))?
.parse::<usize>()?;
let flatten_data = attrs
.get("flatten_data")
.ok_or(format!(
"Node `{}` is missing attrs.flatten_data",
self.name
))?
.parse::<u8>()?
== 1;
Ok(NodeAttrs { Ok(NodeAttrs {
func_name, func_name: get_node_attr!(self.name, attrs, "func_name")?.to_owned(),
num_outputs, num_outputs: get_node_attr!(self.name, attrs, "num_outputs")?.parse::<usize>()?,
flatten_data, flatten_data: get_node_attr!(self.name, attrs, "flatten_data")?.parse::<u8>()? == 1,
}) })
} }
} }
impl<'a> TryFrom<&'a String> for Graph { impl<'a> TryFrom<&'a String> for Graph {
type Error = Error; type Error = Error;
fn try_from(graph_json: &String) -> Result<Self> { fn try_from(graph_json: &String) -> Result<Self, self::Error> {
let graph = serde_json::from_str(graph_json)?; let graph = serde_json::from_str(graph_json)?;
Ok(graph) Ok(graph)
} }
...@@ -121,7 +112,7 @@ impl<'a> TryFrom<&'a String> for Graph { ...@@ -121,7 +112,7 @@ impl<'a> TryFrom<&'a String> for Graph {
impl<'a> TryFrom<&'a str> for Graph { impl<'a> TryFrom<&'a str> for Graph {
type Error = Error; type Error = Error;
fn try_from(graph_json: &'a str) -> Result<Self> { fn try_from(graph_json: &'a str) -> Result<Self, Self::Error> {
let graph = serde_json::from_str(graph_json)?; let graph = serde_json::from_str(graph_json)?;
Ok(graph) Ok(graph)
} }
...@@ -161,7 +152,7 @@ pub struct GraphExecutor<'m, 't> { ...@@ -161,7 +152,7 @@ pub struct GraphExecutor<'m, 't> {
unsafe impl<'m, 't> Send for GraphExecutor<'m, 't> {} unsafe impl<'m, 't> Send for GraphExecutor<'m, 't> {}
impl<'m, 't> GraphExecutor<'m, 't> { impl<'m, 't> GraphExecutor<'m, 't> {
pub fn new<M: 'm + Module>(graph: Graph, lib: &'m M) -> Result<Self> { pub fn new<M: 'm + Module>(graph: Graph, lib: &'m M) -> Result<Self, Error> {
let tensors = Self::setup_storages(&graph)?; let tensors = Self::setup_storages(&graph)?;
Ok(GraphExecutor { Ok(GraphExecutor {
op_execs: Self::setup_op_execs(&graph, lib, &tensors)?, op_execs: Self::setup_op_execs(&graph, lib, &tensors)?,
...@@ -178,7 +169,7 @@ impl<'m, 't> GraphExecutor<'m, 't> { ...@@ -178,7 +169,7 @@ impl<'m, 't> GraphExecutor<'m, 't> {
} }
/// Allocates `Storages` for each `storage_id` and returns `Tensor`s to hold each output. /// Allocates `Storages` for each `storage_id` and returns `Tensor`s to hold each output.
fn setup_storages<'a>(graph: &'a Graph) -> Result<Vec<Tensor<'t>>> { fn setup_storages<'a>(graph: &'a Graph) -> Result<Vec<Tensor<'t>>, Error> {
let storage_ids = graph.get_attr::<(String, Vec<usize>)>("storage_id")?.1; let storage_ids = graph.get_attr::<(String, Vec<usize>)>("storage_id")?.1;
let shapes = graph.get_attr::<(String, Vec<Vec<i64>>)>("shape")?.1; let shapes = graph.get_attr::<(String, Vec<Vec<i64>>)>("shape")?.1;
let dtypes = graph let dtypes = graph
...@@ -189,18 +180,15 @@ impl<'m, 't> GraphExecutor<'m, 't> { ...@@ -189,18 +180,15 @@ impl<'m, 't> GraphExecutor<'m, 't> {
if let Ok((_, dtype)) = tvm_str_to_type(CompleteStr(dltype)) { if let Ok((_, dtype)) = tvm_str_to_type(CompleteStr(dltype)) {
Ok(dtype) Ok(dtype)
} else { } else {
Err(ErrorKind::GraphFormatError( Err(GraphFormatError::InvalidDLType(dltype.to_string()))
format!("Invalid dltype: {}", dltype).to_string(),
)
.into())
} }
}) })
.collect::<Result<Vec<DataType>>>()?; .collect::<Result<Vec<DataType>, GraphFormatError>>()?;
let align = dtypes.iter().map(|dtype| dtype.bits as usize).max(); let align = dtypes.iter().map(|dtype| dtype.bits() as usize).max();
let mut storage_num_bytes = vec![0usize; *storage_ids.iter().max().unwrap_or(&1) + 1]; let mut storage_num_bytes = vec![0usize; *storage_ids.iter().max().unwrap_or(&1) + 1];
for (i, &storage_id) in storage_ids.iter().enumerate() { for (i, &storage_id) in storage_ids.iter().enumerate() {
let dtype_size = dtypes[i].bits * dtypes[i].lanes >> 3; let dtype_size = dtypes[i].bits() * dtypes[i].lanes() >> 3;
let nbytes = dtype_size * shapes[i].iter().product::<i64>() as usize; let nbytes = dtype_size * shapes[i].iter().product::<i64>() as usize;
storage_num_bytes[storage_id] = cmp::max(nbytes, storage_num_bytes[storage_id]); storage_num_bytes[storage_id] = cmp::max(nbytes, storage_num_bytes[storage_id]);
} }
...@@ -208,7 +196,7 @@ impl<'m, 't> GraphExecutor<'m, 't> { ...@@ -208,7 +196,7 @@ impl<'m, 't> GraphExecutor<'m, 't> {
let mut storages: Vec<Storage> = storage_num_bytes let mut storages: Vec<Storage> = storage_num_bytes
.into_iter() .into_iter()
.map(|nbytes| Storage::new(nbytes, align)) .map(|nbytes| Storage::new(nbytes, align))
.collect::<Result<Vec<Storage>>>()?; .collect::<Result<Vec<Storage>, Error>>()?;
let tensors = izip!(storage_ids, shapes, dtypes) let tensors = izip!(storage_ids, shapes, dtypes)
.map(|(storage_id, shape, dtype)| { .map(|(storage_id, shape, dtype)| {
...@@ -233,7 +221,7 @@ impl<'m, 't> GraphExecutor<'m, 't> { ...@@ -233,7 +221,7 @@ impl<'m, 't> GraphExecutor<'m, 't> {
graph: &Graph, graph: &Graph,
lib: &'m M, lib: &'m M,
tensors: &Vec<Tensor<'t>>, tensors: &Vec<Tensor<'t>>,
) -> Result<Vec<Box<Fn() + 'm>>> { ) -> Result<Vec<Box<Fn() + 'm>>, Error> {
ensure!(graph.node_row_ptr.is_some(), "Missing node_row_ptr."); ensure!(graph.node_row_ptr.is_some(), "Missing node_row_ptr.");
let node_row_ptr = graph.node_row_ptr.as_ref().unwrap(); let node_row_ptr = graph.node_row_ptr.as_ref().unwrap();
...@@ -251,9 +239,10 @@ impl<'m, 't> GraphExecutor<'m, 't> { ...@@ -251,9 +239,10 @@ impl<'m, 't> GraphExecutor<'m, 't> {
continue; continue;
} }
let func = lib let func = lib.get_function(&attrs.func_name).ok_or(format_err!(
.get_function(&attrs.func_name) "Library is missing function {}",
.ok_or(format!("Missing function {}", attrs.func_name))?; attrs.func_name
))?;
let arg_indices = node let arg_indices = node
.inputs .inputs
.iter() .iter()
...@@ -264,19 +253,19 @@ impl<'m, 't> GraphExecutor<'m, 't> { ...@@ -264,19 +253,19 @@ impl<'m, 't> GraphExecutor<'m, 't> {
.map(|idx| { .map(|idx| {
let tensor = &tensors[idx?]; let tensor = &tensors[idx?];
Ok(if attrs.flatten_data { Ok(if attrs.flatten_data {
DLTensor::from_tensor(tensor, true /* flatten */) Tensor::as_dltensor(tensor, true /* flatten */)
} else { } else {
DLTensor::from(tensor) DLTensor::from(tensor)
}) })
}) })
.collect::<Result<Vec<DLTensor>>>() .collect::<Result<Vec<DLTensor>, Error>>()
.unwrap(); .unwrap();
let op: Box<Fn()> = box move || { let op: Box<Fn()> = box move || {
let args = dl_tensors let args = dl_tensors
.iter() .iter()
.map(|t| t.into()) .map(|t| t.into())
.collect::<Vec<TVMArgValue>>(); .collect::<Vec<TVMArgValue>>();
func(args.as_slice()); func(args.as_slice()).unwrap();
}; };
op_execs.push(op); op_execs.push(op);
} }
...@@ -344,7 +333,7 @@ impl<'m, 't> GraphExecutor<'m, 't> { ...@@ -344,7 +333,7 @@ impl<'m, 't> GraphExecutor<'m, 't> {
} }
} }
/// Converts a string to TVM DLDataTypeCode. @see `String2TVMType` in packed_func.h // Converts a string to TVM DLDataTypeCode. @see `String2TVMType` in packed_func.h
named!( named!(
tvm_str_to_type<CompleteStr, DataType>, tvm_str_to_type<CompleteStr, DataType>,
do_parse!( do_parse!(
...@@ -367,7 +356,7 @@ named!( ...@@ -367,7 +356,7 @@ named!(
) )
); );
/// Converts a bytes to String. // Converts a bytes to String.
named!( named!(
name<String>, name<String>,
map_res!(length_bytes!(le_u64), |b: &[u8]| String::from_utf8( map_res!(length_bytes!(le_u64), |b: &[u8]| String::from_utf8(
...@@ -375,7 +364,7 @@ named!( ...@@ -375,7 +364,7 @@ named!(
)) ))
); );
/// Parses a TVMContext // Parses a TVMContext
named!( named!(
tvm_ctx<&[u8], TVMContext>, tvm_ctx<&[u8], TVMContext>,
do_parse!( do_parse!(
...@@ -385,7 +374,7 @@ named!( ...@@ -385,7 +374,7 @@ named!(
) )
); );
/// Parses a DataType // Parses a DataType
named!( named!(
data_type<&[u8], DataType>, data_type<&[u8], DataType>,
do_parse!( do_parse!(
...@@ -396,7 +385,7 @@ named!( ...@@ -396,7 +385,7 @@ named!(
) )
); );
/// Parses a Tensor from a TVM array file. // Parses a Tensor from a TVM array file.
named!( named!(
tensor<Tensor>, tensor<Tensor>,
do_parse!( do_parse!(
...@@ -420,7 +409,7 @@ named!( ...@@ -420,7 +409,7 @@ named!(
) )
); );
/// Parses a graph params dict from a params binary file. // Parses a graph params dict from a params binary file.
named!( named!(
parse_param_dict<HashMap<String, Tensor>>, parse_param_dict<HashMap<String, Tensor>>,
do_parse!( do_parse!(
...@@ -433,17 +422,15 @@ named!( ...@@ -433,17 +422,15 @@ named!(
); );
/// Loads a param dict saved using `nnvm.compiler.save_param_dict`. /// Loads a param dict saved using `nnvm.compiler.save_param_dict`.
pub fn load_param_dict(bytes: &[u8]) -> Result<HashMap<String, Tensor>> { pub fn load_param_dict(bytes: &[u8]) -> Result<HashMap<String, Tensor>, GraphFormatError> {
if let Ok((remaining_bytes, param_dict)) = parse_param_dict(bytes) { if let Ok((remaining_bytes, param_dict)) = parse_param_dict(bytes) {
if remaining_bytes.len() > 0 { if remaining_bytes.len() == 0 {
bail!(ErrorKind::LoadGraphParamsError("extra input".to_string()))
} else {
Ok(param_dict) Ok(param_dict)
} else {
Err(GraphFormatError::Params)
} }
} else { } else {
bail!(ErrorKind::LoadGraphParamsError( Err(GraphFormatError::Params)
"invalid parameters file".to_string()
))
} }
} }
......
...@@ -14,7 +14,6 @@ ...@@ -14,7 +14,6 @@
allocator_api, allocator_api,
box_syntax, box_syntax,
fn_traits, fn_traits,
try_from,
unboxed_closures, unboxed_closures,
vec_remove_item vec_remove_item
)] )]
...@@ -25,7 +24,7 @@ extern crate bounded_spsc_queue; ...@@ -25,7 +24,7 @@ extern crate bounded_spsc_queue;
#[cfg(target_env = "sgx")] #[cfg(target_env = "sgx")]
extern crate core; extern crate core;
#[macro_use] #[macro_use]
extern crate error_chain; extern crate failure;
#[macro_use] #[macro_use]
extern crate itertools; extern crate itertools;
#[macro_use] #[macro_use]
...@@ -39,36 +38,45 @@ extern crate serde; ...@@ -39,36 +38,45 @@ extern crate serde;
#[macro_use] #[macro_use]
extern crate serde_derive; extern crate serde_derive;
extern crate serde_json; extern crate serde_json;
extern crate tvm_common as common; extern crate tvm_common;
mod allocator; mod allocator;
mod array; mod array;
pub mod errors; pub mod errors;
mod module;
#[macro_use]
mod packed_func;
mod graph; mod graph;
mod module;
#[cfg(target_env = "sgx")] #[cfg(target_env = "sgx")]
#[macro_use] #[macro_use]
pub mod sgx; pub mod sgx;
mod threading; mod threading;
mod workspace; mod workspace;
pub use crate::common::{errors::*, ffi, TVMArgValue, TVMRetValue}; pub use tvm_common::{
call_packed,
pub use self::{ errors::*,
array::*, errors::*, graph::*, module::*, packed_func::*, threading::*, workspace::*, ffi::{self, DLTensor},
packed_func::{self, *},
TVMArgValue, TVMRetValue,
}; };
#[cfg(target_env = "sgx")] pub use self::{array::*, errors::*, graph::*, module::*, threading::*, workspace::*};
use self::sgx::ocall_packed_func;
lazy_static! {
static ref LAST_ERROR: std::sync::RwLock<Option<&'static std::ffi::CStr>> =
std::sync::RwLock::new(None);
}
#[no_mangle] #[no_mangle]
pub extern "C" fn TVMAPISetLastError(cmsg: *const i8) { pub extern "C" fn TVMAPISetLastError(cmsg: *const i8) {
#[cfg(not(target_env = "sgx"))] *LAST_ERROR.write().unwrap() = Some(unsafe { std::ffi::CStr::from_ptr(cmsg) });
unsafe {
panic!(std::ffi::CStr::from_ptr(cmsg).to_str().unwrap());
}
#[cfg(target_env = "sgx")] #[cfg(target_env = "sgx")]
ocall_packed!("__sgx_set_last_error__", cmsg); ocall_packed!("__sgx_set_last_error__", cmsg);
} }
#[no_mangle]
pub extern "C" fn TVMGetLastError() -> *const std::os::raw::c_char {
match *LAST_ERROR.read().unwrap() {
Some(err) => err.as_ptr(),
None => std::ptr::null(),
}
}
...@@ -2,29 +2,29 @@ use std::{ ...@@ -2,29 +2,29 @@ use std::{
collections::HashMap, convert::AsRef, ffi::CStr, os::raw::c_char, string::String, sync::Mutex, collections::HashMap, convert::AsRef, ffi::CStr, os::raw::c_char, string::String, sync::Mutex,
}; };
use crate::{ use tvm_common::{
ffi::runtime::BackendPackedCFunc, ffi::BackendPackedCFunc,
packed_func::{wrap_backend_packed_func, PackedFunc}, packed_func::{PackedFunc, TVMArgValue, TVMRetValue, TVMValue},
}; };
pub trait Module { pub trait Module {
fn get_function<S: AsRef<str>>(&self, name: S) -> Option<PackedFunc>; fn get_function<S: AsRef<str>>(&self, name: S) -> Option<&(dyn PackedFunc)>;
} }
pub struct SystemLibModule; pub struct SystemLibModule;
lazy_static! { lazy_static! {
static ref SYSTEM_LIB_FUNCTIONS: Mutex<HashMap<String, BackendPackedCFunc>> = static ref SYSTEM_LIB_FUNCTIONS: Mutex<HashMap<String, &'static (dyn PackedFunc)>> =
Mutex::new(HashMap::new()); Mutex::new(HashMap::new());
} }
impl Module for SystemLibModule { impl Module for SystemLibModule {
fn get_function<S: AsRef<str>>(&self, name: S) -> Option<PackedFunc> { fn get_function<S: AsRef<str>>(&self, name: S) -> Option<&(dyn PackedFunc)> {
SYSTEM_LIB_FUNCTIONS SYSTEM_LIB_FUNCTIONS
.lock() .lock()
.unwrap() .unwrap()
.get(name.as_ref()) .get(name.as_ref())
.map(|func| wrap_backend_packed_func(func.to_owned())) .map(|f| *f)
} }
} }
...@@ -34,15 +34,42 @@ impl Default for SystemLibModule { ...@@ -34,15 +34,42 @@ impl Default for SystemLibModule {
} }
} }
// @see `WrapPackedFunc` in `llvm_module.cc`.
pub(super) fn wrap_backend_packed_func(
func_name: String,
func: BackendPackedCFunc,
) -> Box<dyn PackedFunc> {
box move |args: &[TVMArgValue]| {
let exit_code = func(
args.iter()
.map(|ref arg| arg.value)
.collect::<Vec<TVMValue>>()
.as_ptr(),
args.iter()
.map(|ref arg| arg.type_code as i32)
.collect::<Vec<i32>>()
.as_ptr() as *const i32,
args.len() as i32,
);
if exit_code == 0 {
Ok(TVMRetValue::default())
} else {
Err(tvm_common::errors::FuncCallError::get_with_context(
func_name.clone(),
))
}
}
}
#[no_mangle] #[no_mangle]
pub extern "C" fn TVMBackendRegisterSystemLibSymbol( pub extern "C" fn TVMBackendRegisterSystemLibSymbol(
cname: *const c_char, cname: *const c_char,
func: BackendPackedCFunc, func: BackendPackedCFunc,
) -> i32 { ) -> i32 {
let name = unsafe { CStr::from_ptr(cname).to_str().unwrap() }; let name = unsafe { CStr::from_ptr(cname).to_str().unwrap() };
SYSTEM_LIB_FUNCTIONS SYSTEM_LIB_FUNCTIONS.lock().unwrap().insert(
.lock() name.to_string(),
.unwrap() &*Box::leak(wrap_backend_packed_func(name.to_string(), func)),
.insert(name.to_string(), func); );
return 0; return 0;
} }
use std::{convert::TryFrom, marker::PhantomData, os::raw::c_void};
use super::Tensor;
use crate::ffi::runtime::{
BackendPackedCFunc, DLTensor as _DLTensor, TVMTypeCode_kArrayHandle,
TVMTypeCode_kNDArrayContainer, TVMValue as _TVMValue,
};
use super::DLTensor;
use crate::{
common::{TVMArgValue, TVMRetValue, TVMTypeCode, TVMValue},
errors::*,
};
pub type PackedFunc = Box<Fn(&[TVMArgValue]) -> TVMRetValue + Send + Sync>;
/// Calls a packed function and returns a `TVMRetValue`.
///
/// # Example
///
/// `call_packed!(my_tvm_func, &mut arg1, &mut arg2)`
#[macro_export]
macro_rules! call_packed {
($fn:expr, $($args:expr),+) => {
$fn(&[$($args.into(),)+])
};
($fn:expr) => {
$fn(&Vec::new())
};
}
impl<'a> From<&'a DLTensor> for TVMArgValue<'a> {
fn from(arr: &'a DLTensor) -> Self {
let raw = _TVMValue {
v_handle: arr as *const _ as *mut DLTensor as *mut c_void,
};
TVMArgValue {
value: TVMValue::new(raw),
type_code: TVMTypeCode::kArrayHandle,
lifetime: PhantomData,
}
}
}
impl<'a> From<&'a mut DLTensor> for TVMArgValue<'a> {
fn from(arr: &'a mut DLTensor) -> Self {
let raw = _TVMValue {
v_handle: arr as *mut _ as *mut c_void,
};
TVMArgValue {
value: TVMValue::new(raw),
type_code: TVMTypeCode::kArrayHandle,
lifetime: PhantomData,
}
}
}
impl<'a> TryFrom<TVMArgValue<'a>> for Tensor<'a> {
type Error = Error;
fn try_from(val: TVMArgValue<'a>) -> Result<Self> {
ensure!(
val.type_code == TVMTypeCode::kArrayHandle
|| val.type_code == TVMTypeCode::kNDArrayContainer,
"Could not downcast arg. Expected `{}` or `{}`, but got `{}`",
TVMTypeCode::kArrayHandle,
TVMTypeCode::kNDArrayContainer,
val.type_code,
);
let dlt = unsafe { *(val.value.v_handle as *mut _DLTensor as *const _DLTensor) };
Ok(DLTensor::new(dlt).into())
}
}
impl<'a, 't> From<&'t Tensor<'a>> for TVMRetValue {
fn from(val: &'t Tensor<'a>) -> Self {
TVMRetValue {
prim_value: 0,
box_value: box DLTensor::from(val),
type_code: TVMTypeCode::kNDArrayContainer,
}
}
}
impl<'a> TryFrom<TVMRetValue> for Tensor<'a> {
type Error = Error;
fn try_from(ret: TVMRetValue) -> Result<Self> {
ensure!(
ret.type_code == TVMTypeCode::kArrayHandle
|| ret.type_code == TVMTypeCode::kNDArrayContainer,
"Could not downcast arg. Expected `{}` or `{}`, but got `{}`",
TVMTypeCode_kArrayHandle,
TVMTypeCode_kNDArrayContainer,
ret.type_code,
);
let dlt = unsafe { *(ret.prim_value as *mut _DLTensor as *const _DLTensor) };
Ok(DLTensor::new(dlt).into())
}
}
// @see `WrapPackedFunc` in `llvm_module.cc`.
pub(crate) fn wrap_backend_packed_func(func: BackendPackedCFunc) -> PackedFunc {
box move |args: &[TVMArgValue]| {
func(
args.iter()
.map(|ref arg| arg.value.inner)
.collect::<Vec<_TVMValue>>()
.as_ptr(),
args.iter()
.map(|ref arg| arg.type_code as i32)
.collect::<Vec<i32>>()
.as_ptr() as *const i32,
args.len() as i32,
);
TVMRetValue::default()
}
}
...@@ -3,18 +3,17 @@ use std::{ ...@@ -3,18 +3,17 @@ use std::{
os::raw::{c_char, c_int}, os::raw::{c_char, c_int},
}; };
use errors::Result; pub use crate::threading::tvm_run_worker as run_worker;
use ffi::runtime::TVMValue; use crate::{threading::sgx_join_threads, SystemLibModule, TVMArgValue, TVMRetValue};
use runtime::{threading::sgx_join_threads, SystemLibModule, TVMArgValue, TVMRetValue}; use errors::SgxError;
use ffi::TVMValue;
pub use runtime::threading::tvm_run_worker as run_worker;
#[macro_export] #[macro_export]
macro_rules! tvm_ocall { macro_rules! tvm_ocall {
($func: expr) => { ($func: expr) => {
match $func { match $func {
0 => Ok(()), 0 => Ok(()),
err => Err(format!("SGX error: {}", err)), code => Err(SgxError { code }),
} }
}; };
} }
...@@ -33,7 +32,10 @@ extern "C" { ...@@ -33,7 +32,10 @@ extern "C" {
) -> SgxStatus; ) -> SgxStatus;
} }
pub fn ocall_packed_func<S: AsRef<str>>(fn_name: S, args: &[TVMArgValue]) -> Result<TVMRetValue> { pub fn ocall_packed_func<S: AsRef<str>>(
fn_name: S,
args: &[TVMArgValue],
) -> Result<TVMRetValue, SgxError> {
let mut ret_val = TVMValue { v_int64: 0 }; let mut ret_val = TVMValue { v_int64: 0 };
let ret_type_code = 0i64; let ret_type_code = 0i64;
unsafe { unsafe {
...@@ -58,11 +60,11 @@ pub fn ocall_packed_func<S: AsRef<str>>(fn_name: S, args: &[TVMArgValue]) -> Res ...@@ -58,11 +60,11 @@ pub fn ocall_packed_func<S: AsRef<str>>(fn_name: S, args: &[TVMArgValue]) -> Res
#[macro_export] #[macro_export]
macro_rules! ocall_packed { macro_rules! ocall_packed {
($fn_name:expr, $($args:expr),+) => { ($fn_name:expr, $($args:expr),+) => {
ocall_packed_func($fn_name, &[$($args.into(),)+]) $crate::sgx::ocall_packed_func($fn_name, &[$($args.into(),)+])
.expect(concat!("Error calling `", $fn_name, "`")) .expect(concat!("Error calling `", $fn_name, "`"))
}; };
($fn_name:expr) => { ($fn_name:expr) => {
ocall_packed_func($fn_name, &Vec::new()) $crate::sgx::ocall_packed_func($fn_name, &Vec::new())
.expect(concat!("Error calling `", $fn_name, "`")) .expect(concat!("Error calling `", $fn_name, "`"))
} }
} }
......
use std::{ use std::{
os::raw::{c_int, c_void}, os::raw::{c_int, c_void},
sync::{ sync::{
atomic::{AtomicUsize, Ordering, ATOMIC_USIZE_INIT}, atomic::{AtomicUsize, Ordering},
Arc, Barrier, Arc, Barrier,
}, },
}; };
...@@ -18,11 +18,10 @@ use std::{ ...@@ -18,11 +18,10 @@ use std::{
use std::{collections::VecDeque, ptr, sync::Mutex}; use std::{collections::VecDeque, ptr, sync::Mutex};
use bounded_spsc_queue::{self, Producer}; use bounded_spsc_queue::{self, Producer};
use tvm_common::ffi::TVMParallelGroupEnv;
use crate::{errors::*, ffi::runtime::TVMParallelGroupEnv};
#[cfg(target_env = "sgx")] #[cfg(target_env = "sgx")]
use super::{sgx::ocall_packed_func, TVMArgValue, TVMRetValue}; use super::{TVMArgValue, TVMRetValue};
type FTVMParallelLambda = type FTVMParallelLambda =
extern "C" fn(task_id: usize, penv: *const TVMParallelGroupEnv, cdata: *const c_void) -> i32; extern "C" fn(task_id: usize, penv: *const TVMParallelGroupEnv, cdata: *const c_void) -> i32;
...@@ -62,12 +61,11 @@ impl Job { ...@@ -62,12 +61,11 @@ impl Job {
} }
/// Waits for all tasks in this `Job` to be completed. /// Waits for all tasks in this `Job` to be completed.
fn wait(&self) -> Result<()> { fn wait(&self) {
while self.pending.load(Ordering::Acquire) > 0 { while self.pending.load(Ordering::Acquire) > 0 {
#[cfg(not(target_env = "sgx"))] #[cfg(not(target_env = "sgx"))]
thread::yield_now(); thread::yield_now();
} }
Ok(())
} }
} }
...@@ -161,7 +159,7 @@ impl ThreadPool { ...@@ -161,7 +159,7 @@ impl ThreadPool {
} }
tasks.pop().unwrap()(); tasks.pop().unwrap()();
job.wait().unwrap(); job.wait();
} }
fn run_worker(queue: Consumer<Task>) { fn run_worker(queue: Consumer<Task>) {
...@@ -251,7 +249,7 @@ pub extern "C" fn TVMBackendParallelLaunch( ...@@ -251,7 +249,7 @@ pub extern "C" fn TVMBackendParallelLaunch(
cb: cb, cb: cb,
cdata: cdata, cdata: cdata,
req_num_tasks: num_task, req_num_tasks: num_task,
pending: Arc::new(ATOMIC_USIZE_INIT), pending: Arc::new(AtomicUsize::new(0)),
}); });
}); });
} }
...@@ -273,7 +271,7 @@ pub(crate) fn sgx_join_threads() { ...@@ -273,7 +271,7 @@ pub(crate) fn sgx_join_threads() {
cb: poison_pill, cb: poison_pill,
cdata: ptr::null(), cdata: ptr::null(),
req_num_tasks: 0, req_num_tasks: 0,
pending: Arc::new(ATOMIC_USIZE_INIT), pending: Arc::new(AtomicUsize::new(0)),
}); });
}); });
ocall_packed!("__sgx_thread_group_join__", 0); ocall_packed!("__sgx_thread_group_join__", 0);
...@@ -322,8 +320,8 @@ mod tests { ...@@ -322,8 +320,8 @@ mod tests {
#[test] #[test]
fn test_parallel_launch() { fn test_parallel_launch() {
TVMBackendParallelLaunch(flambda, ptr::null(), 6); TVMBackendParallelLaunch(flambda, ptr::null(), 6);
let counter = ATOMIC_USIZE_INIT; let counter = AtomicUsize::new(0);
let task_ids_sum = ATOMIC_USIZE_INIT; let task_ids_sum = AtomicUsize::new(0);
let cdata = (counter, task_ids_sum); let cdata = (counter, task_ids_sum);
let num_tasks = 3; let num_tasks = 3;
TVMBackendParallelLaunch(flambda, &cdata as *const _ as *const c_void, num_tasks); TVMBackendParallelLaunch(flambda, &cdata as *const _ as *const c_void, num_tasks);
......
...@@ -4,8 +4,9 @@ use std::{ ...@@ -4,8 +4,9 @@ use std::{
ptr, ptr,
}; };
use super::allocator::Allocation; use failure::Error;
use crate::errors::*;
use crate::allocator::Allocation;
const WS_ALIGN: usize = 64; // taken from `kTempAllocaAlignment` in `device_api.h` const WS_ALIGN: usize = 64; // taken from `kTempAllocaAlignment` in `device_api.h`
...@@ -24,13 +25,13 @@ impl WorkspacePool { ...@@ -24,13 +25,13 @@ impl WorkspacePool {
} }
} }
fn alloc_new(&mut self, size: usize) -> Result<*mut u8> { fn alloc_new(&mut self, size: usize) -> Result<*mut u8, Error> {
self.workspaces.push(Allocation::new(size, Some(WS_ALIGN))?); self.workspaces.push(Allocation::new(size, Some(WS_ALIGN))?);
self.in_use.push(self.workspaces.len() - 1); self.in_use.push(self.workspaces.len() - 1);
Ok(self.workspaces[self.workspaces.len() - 1].as_mut_ptr()) Ok(self.workspaces[self.workspaces.len() - 1].as_mut_ptr())
} }
fn alloc(&mut self, size: usize) -> Result<*mut u8> { fn alloc(&mut self, size: usize) -> Result<*mut u8, Error> {
if self.free.len() == 0 { if self.free.len() == 0 {
return self.alloc_new(size); return self.alloc_new(size);
} }
...@@ -60,7 +61,7 @@ impl WorkspacePool { ...@@ -60,7 +61,7 @@ impl WorkspacePool {
} }
} }
fn free(&mut self, ptr: *mut u8) -> Result<()> { fn free(&mut self, ptr: *mut u8) -> Result<(), Error> {
let mut ws_idx = None; let mut ws_idx = None;
for i in 0..self.in_use.len() { for i in 0..self.in_use.len() {
let idx = self.in_use[i]; let idx = self.in_use[i];
...@@ -72,7 +73,7 @@ impl WorkspacePool { ...@@ -72,7 +73,7 @@ impl WorkspacePool {
} }
Ok(self Ok(self
.free .free
.push(ws_idx.ok_or("Tried to free nonexistent workspace.")?)) .push(ws_idx.ok_or(format_err!("Tried to free nonexistent workspace."))?))
} }
} }
......
...@@ -5,7 +5,7 @@ license = "Apache-2.0" ...@@ -5,7 +5,7 @@ license = "Apache-2.0"
authors = ["TVM Contributors"] authors = ["TVM Contributors"]
[dependencies] [dependencies]
ndarray = "0.11.2" ndarray="0.12.1"
serde = "1.0.59" serde = "1.0.59"
serde_json = "1.0.17" serde_json = "1.0.17"
tvm-runtime = { path = "../../" } tvm-runtime = { path = "../../" }
......
...@@ -5,7 +5,7 @@ license = "Apache-2.0" ...@@ -5,7 +5,7 @@ license = "Apache-2.0"
authors = ["TVM Contributors"] authors = ["TVM Contributors"]
[dependencies] [dependencies]
ndarray = "0.11.2" ndarray="0.12.1"
tvm-runtime = { path = "../../" } tvm-runtime = { path = "../../" }
[build-dependencies] [build-dependencies]
......
...@@ -17,6 +17,6 @@ fn main() { ...@@ -17,6 +17,6 @@ fn main() {
let mut a_dl: DLTensor = (&mut a).into(); let mut a_dl: DLTensor = (&mut a).into();
let mut b_dl: DLTensor = (&mut b).into(); let mut b_dl: DLTensor = (&mut b).into();
let mut c_dl: DLTensor = (&mut c).into(); let mut c_dl: DLTensor = (&mut c).into();
call_packed!(add, &mut a_dl, &mut b_dl, &mut c_dl); call_packed!(add, &mut a_dl, &mut b_dl, &mut c_dl).unwrap();
assert!(c.all_close(&e, 1e-8f32)); assert!(c.all_close(&e, 1e-8f32));
} }
...@@ -14,11 +14,11 @@ cargo fmt -- --check ...@@ -14,11 +14,11 @@ cargo fmt -- --check
# test common # test common
cd $RUST_DIR/common cd $RUST_DIR/common
cargo build --features runtime cargo build
cargo test --features runtime --tests cargo test --tests
cargo build --features frontend cargo build --features bindings
cargo test --features frontend --tests cargo test --features bindings --tests
# test runtime # test runtime
cd $RUST_DIR/runtime cd $RUST_DIR/runtime
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment