Commit 4968279f by Nick Hynes Committed by Tianqi Chen

[Rust] Unify types between bindings and pure Rust impl (#2616)

parent 71abe36e
Cargo.lock
target/
**/*.rs.bk
*.rs.bk
Cargo.lock
c_runtime_api.rs
target
**/*.rs.bk
Cargo.lock
/tvm-sys/src/bindgen.rs
......@@ -5,9 +5,11 @@ authors = ["TVM Contributors"]
license = "Apache-2.0"
[features]
runtime = []
frontend = ["tvm-sys"]
bindings = []
[dependencies]
error-chain = { version = "0.12.0", default-features = false }
tvm-sys = { version = "0.1.0", path = "tvm-sys", optional = true }
failure = "0.1.5"
ndarray = "0.12.1"
[build-dependencies]
bindgen = "0.37.4"
......@@ -3,23 +3,29 @@ extern crate bindgen;
use std::path::PathBuf;
fn main() {
if cfg!(feature = "bindings") {
println!("cargo:rerun-if-env-changed=TVM_HOME");
println!("cargo:rustc-link-lib=dylib=tvm_runtime");
println!("cargo:rustc-link-search={}/build", env!("TVM_HOME"));
let bindings = bindgen::Builder::default()
}
// @see rust-bindgen#550 for `blacklist_type`
bindgen::Builder::default()
.header(format!(
"{}/include/tvm/runtime/c_runtime_api.h",
env!("TVM_HOME")
))
.header(format!(
"{}/include/tvm/runtime/c_backend_api.h",
env!("TVM_HOME")
))
.clang_arg(format!("-I{}/3rdparty/dlpack/include/", env!("TVM_HOME")))
.blacklist_type("max_align_t") // @see rust-bindgen#550
.blacklist_type("max_align_t")
.layout_tests(false)
.derive_partialeq(true)
.derive_eq(true)
.generate()
.expect("unable to generate bindings");
bindings
.write_to_file(PathBuf::from("src/bindgen.rs"))
.expect("unable to generate bindings")
.write_to_file(PathBuf::from("src/c_runtime_api.rs"))
.expect("can not write the bindings!");
}
use std::{
any::TypeId,
mem,
os::raw::{c_int, c_void},
};
use crate::ffi::{
DLContext, DLDataType, DLDataTypeCode_kDLFloat, DLDataTypeCode_kDLInt, DLDataTypeCode_kDLUInt,
DLDeviceType_kDLCPU, DLTensor,
};
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct DataType {
pub code: usize,
pub bits: usize,
pub lanes: usize,
}
impl DataType {
/// Returns the number of bytes occupied by an element of this `DataType`.
pub fn itemsize(&self) -> usize {
(self.bits * self.lanes) >> 3
}
/// Returns whether this `DataType` represents primitive type `T`.
pub fn is_type<T: 'static>(&self) -> bool {
if self.lanes != 1 {
return false;
}
let typ = TypeId::of::<T>();
(typ == TypeId::of::<i32>() && self.code == 0 && self.bits == 32)
|| (typ == TypeId::of::<i64>() && self.code == 0 && self.bits == 64)
|| (typ == TypeId::of::<u32>() && self.code == 1 && self.bits == 32)
|| (typ == TypeId::of::<u64>() && self.code == 1 && self.bits == 64)
|| (typ == TypeId::of::<f32>() && self.code == 2 && self.bits == 32)
|| (typ == TypeId::of::<f64>() && self.code == 2 && self.bits == 64)
}
pub fn code(&self) -> usize {
self.code
}
pub fn bits(&self) -> usize {
self.bits
}
pub fn lanes(&self) -> usize {
self.lanes
}
}
impl<'a> From<&'a DataType> for DLDataType {
fn from(dtype: &'a DataType) -> Self {
Self {
code: dtype.code as u8,
bits: dtype.bits as u8,
lanes: dtype.lanes as u16,
}
}
}
impl From<DLDataType> for DataType {
fn from(dtype: DLDataType) -> Self {
Self {
code: dtype.code as usize,
bits: dtype.bits as usize,
lanes: dtype.lanes as usize,
}
}
}
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct TVMContext {
pub device_type: usize,
pub device_id: usize,
}
impl<'a> From<&'a TVMContext> for DLContext {
fn from(ctx: &'a TVMContext) -> Self {
Self {
device_type: ctx.device_type as u32,
device_id: ctx.device_id as i32,
}
}
}
impl Default for TVMContext {
fn default() -> Self {
Self {
device_type: DLDeviceType_kDLCPU as usize,
device_id: 0,
}
}
}
/// `From` conversions to `DLTensor` for `ndarray::Array`.
/// Takes a reference to the `ndarray` since `DLTensor` is not owned.
macro_rules! impl_dltensor_from_ndarray {
($type:ty, $typecode:expr) => {
impl<'a, D: ndarray::Dimension> From<&'a mut ndarray::Array<$type, D>> for DLTensor {
fn from(arr: &'a mut ndarray::Array<$type, D>) -> Self {
DLTensor {
data: arr.as_mut_ptr() as *mut c_void,
ctx: DLContext {
device_type: DLDeviceType_kDLCPU,
device_id: 0,
},
ndim: arr.ndim() as c_int,
dtype: DLDataType {
code: $typecode as u8,
bits: 8 * mem::size_of::<$type>() as u8,
lanes: 1,
},
shape: arr.shape().as_ptr() as *const i64 as *mut i64,
strides: arr.strides().as_ptr() as *const isize as *mut i64,
byte_offset: 0,
}
}
}
};
}
impl_dltensor_from_ndarray!(f32, DLDataTypeCode_kDLFloat);
impl_dltensor_from_ndarray!(f64, DLDataTypeCode_kDLFloat);
impl_dltensor_from_ndarray!(i32, DLDataTypeCode_kDLInt);
impl_dltensor_from_ndarray!(i64, DLDataTypeCode_kDLInt);
impl_dltensor_from_ndarray!(u32, DLDataTypeCode_kDLUInt);
impl_dltensor_from_ndarray!(u64, DLDataTypeCode_kDLUInt);
/* automatically generated by rust-bindgen for TVM revision 6292c78 */
pub const TVM_VERSION: &'static [u8; 8usize] = b"0.5.dev\0";
pub const DLPACK_VERSION: u32 = 8;
pub const _STDINT_H: u32 = 1;
pub const _FEATURES_H: u32 = 1;
pub const _DEFAULT_SOURCE: u32 = 1;
pub const __USE_ISOC11: u32 = 1;
pub const __USE_ISOC99: u32 = 1;
pub const __USE_ISOC95: u32 = 1;
pub const __USE_POSIX_IMPLICITLY: u32 = 1;
pub const _POSIX_SOURCE: u32 = 1;
pub const _POSIX_C_SOURCE: u32 = 200809;
pub const __USE_POSIX: u32 = 1;
pub const __USE_POSIX2: u32 = 1;
pub const __USE_POSIX199309: u32 = 1;
pub const __USE_POSIX199506: u32 = 1;
pub const __USE_XOPEN2K: u32 = 1;
pub const __USE_XOPEN2K8: u32 = 1;
pub const _ATFILE_SOURCE: u32 = 1;
pub const __USE_MISC: u32 = 1;
pub const __USE_ATFILE: u32 = 1;
pub const __USE_FORTIFY_LEVEL: u32 = 0;
pub const _STDC_PREDEF_H: u32 = 1;
pub const __STDC_IEC_559__: u32 = 1;
pub const __STDC_IEC_559_COMPLEX__: u32 = 1;
pub const __STDC_ISO_10646__: u32 = 201505;
pub const __STDC_NO_THREADS__: u32 = 1;
pub const __GNU_LIBRARY__: u32 = 6;
pub const __GLIBC__: u32 = 2;
pub const __GLIBC_MINOR__: u32 = 23;
pub const _SYS_CDEFS_H: u32 = 1;
pub const __WORDSIZE: u32 = 64;
pub const __WORDSIZE_TIME64_COMPAT32: u32 = 1;
pub const __SYSCALL_WORDSIZE: u32 = 64;
pub const _BITS_WCHAR_H: u32 = 1;
pub const INT8_MIN: i32 = -128;
pub const INT16_MIN: i32 = -32768;
pub const INT32_MIN: i32 = -2147483648;
pub const INT8_MAX: u32 = 127;
pub const INT16_MAX: u32 = 32767;
pub const INT32_MAX: u32 = 2147483647;
pub const UINT8_MAX: u32 = 255;
pub const UINT16_MAX: u32 = 65535;
pub const UINT32_MAX: u32 = 4294967295;
pub const INT_LEAST8_MIN: i32 = -128;
pub const INT_LEAST16_MIN: i32 = -32768;
pub const INT_LEAST32_MIN: i32 = -2147483648;
pub const INT_LEAST8_MAX: u32 = 127;
pub const INT_LEAST16_MAX: u32 = 32767;
pub const INT_LEAST32_MAX: u32 = 2147483647;
pub const UINT_LEAST8_MAX: u32 = 255;
pub const UINT_LEAST16_MAX: u32 = 65535;
pub const UINT_LEAST32_MAX: u32 = 4294967295;
pub const INT_FAST8_MIN: i32 = -128;
pub const INT_FAST16_MIN: i64 = -9223372036854775808;
pub const INT_FAST32_MIN: i64 = -9223372036854775808;
pub const INT_FAST8_MAX: u32 = 127;
pub const INT_FAST16_MAX: u64 = 9223372036854775807;
pub const INT_FAST32_MAX: u64 = 9223372036854775807;
pub const UINT_FAST8_MAX: u32 = 255;
pub const UINT_FAST16_MAX: i32 = -1;
pub const UINT_FAST32_MAX: i32 = -1;
pub const INTPTR_MIN: i64 = -9223372036854775808;
pub const INTPTR_MAX: u64 = 9223372036854775807;
pub const UINTPTR_MAX: i32 = -1;
pub const PTRDIFF_MIN: i64 = -9223372036854775808;
pub const PTRDIFF_MAX: u64 = 9223372036854775807;
pub const SIG_ATOMIC_MIN: i32 = -2147483648;
pub const SIG_ATOMIC_MAX: u32 = 2147483647;
pub const SIZE_MAX: i32 = -1;
pub const WINT_MIN: u32 = 0;
pub const WINT_MAX: u32 = 4294967295;
pub type int_least8_t = ::std::os::raw::c_schar;
pub type int_least16_t = ::std::os::raw::c_short;
pub type int_least32_t = ::std::os::raw::c_int;
pub type int_least64_t = ::std::os::raw::c_long;
pub type uint_least8_t = ::std::os::raw::c_uchar;
pub type uint_least16_t = ::std::os::raw::c_ushort;
pub type uint_least32_t = ::std::os::raw::c_uint;
pub type uint_least64_t = ::std::os::raw::c_ulong;
pub type int_fast8_t = ::std::os::raw::c_schar;
pub type int_fast16_t = ::std::os::raw::c_long;
pub type int_fast32_t = ::std::os::raw::c_long;
pub type int_fast64_t = ::std::os::raw::c_long;
pub type uint_fast8_t = ::std::os::raw::c_uchar;
pub type uint_fast16_t = ::std::os::raw::c_ulong;
pub type uint_fast32_t = ::std::os::raw::c_ulong;
pub type uint_fast64_t = ::std::os::raw::c_ulong;
pub type intmax_t = ::std::os::raw::c_long;
pub type uintmax_t = ::std::os::raw::c_ulong;
pub type wchar_t = ::std::os::raw::c_int;
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct max_align_t {
pub __clang_max_align_nonce1: ::std::os::raw::c_longlong,
pub __bindgen_padding_0: u64,
pub __clang_max_align_nonce2: f64,
}
pub const DLDeviceType_kDLCPU: DLDeviceType = 1;
pub const DLDeviceType_kDLGPU: DLDeviceType = 2;
pub const DLDeviceType_kDLCPUPinned: DLDeviceType = 3;
pub const DLDeviceType_kDLOpenCL: DLDeviceType = 4;
pub const DLDeviceType_kDLMetal: DLDeviceType = 8;
pub const DLDeviceType_kDLVPI: DLDeviceType = 9;
pub const DLDeviceType_kDLROCM: DLDeviceType = 10;
/// \brief The device type in DLContext.
pub type DLDeviceType = u32;
/// \brief A Device context for Tensor and operator.
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct DLContext {
/// \brief The device type used in the device.
pub device_type: DLDeviceType,
/// \brief The device index
pub device_id: ::std::os::raw::c_int,
}
pub const DLDataTypeCode_kDLInt: DLDataTypeCode = 0;
pub const DLDataTypeCode_kDLUInt: DLDataTypeCode = 1;
pub const DLDataTypeCode_kDLFloat: DLDataTypeCode = 2;
/// \brief The type code options DLDataType.
pub type DLDataTypeCode = u32;
/// \brief The data type the tensor can hold.
///
/// Examples
/// - float: type_code = 2, bits = 32, lanes=1
/// - float4(vectorized 4 float): type_code = 2, bits = 32, lanes=4
/// - int8: type_code = 0, bits = 8, lanes=1
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct DLDataType {
/// \brief Type code of base types.
/// We keep it uint8_t instead of DLDataTypeCode for minimal memory
/// footprint, but the value should be one of DLDataTypeCode enum values.
///
pub code: u8,
/// \brief Number of bits, common choices are 8, 16, 32.
pub bits: u8,
/// \brief Number of lanes in the type, used for vector types.
pub lanes: u16,
}
/// \brief Plain C Tensor object, does not manage memory.
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct DLTensor {
/// \brief The opaque data pointer points to the allocated data.
/// This will be CUDA device pointer or cl_mem handle in OpenCL.
/// This pointer is always aligns to 256 bytes as in CUDA.
pub data: *mut ::std::os::raw::c_void,
/// \brief The device context of the tensor
pub ctx: DLContext,
/// \brief Number of dimensions
pub ndim: ::std::os::raw::c_int,
/// \brief The data type of the pointer
pub dtype: DLDataType,
/// \brief The shape of the tensor
pub shape: *mut i64,
/// \brief strides of the tensor,
/// can be NULL, indicating tensor is compact.
pub strides: *mut i64,
/// \brief The offset in bytes to the beginning pointer to data
pub byte_offset: u64,
}
/// \brief C Tensor object, manage memory of DLTensor. This data structure is
/// intended to faciliate the borrowing of DLTensor by another framework. It is
/// not meant to transfer the tensor. When the borrowing framework doesn't need
/// the tensor, it should call the deleter to notify the host that the resource
/// is no longer needed.
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct DLManagedTensor {
/// \brief DLTensor which is being memory managed
pub dl_tensor: DLTensor,
/// \brief the context of the original host framework of DLManagedTensor in
/// which DLManagedTensor is used in the framework. It can also be NULL.
pub manager_ctx: *mut ::std::os::raw::c_void,
/// \brief Destructor signature void (*)(void*) - this should be called
/// to destruct manager_ctx which holds the DLManagedTensor. It can be NULL
/// if there is no way for the caller to provide a reasonable destructor.
pub deleter: ::std::option::Option<unsafe extern "C" fn(self_: *mut DLManagedTensor)>,
}
/// \brief type of array index.
pub type tvm_index_t = i64;
pub const TVMDeviceExtType_kDLAOCL: TVMDeviceExtType = 5;
pub const TVMDeviceExtType_kDLSDAccel: TVMDeviceExtType = 6;
pub const TVMDeviceExtType_kDLVulkan: TVMDeviceExtType = 7;
pub const TVMDeviceExtType_kOpenGL: TVMDeviceExtType = 11;
pub const TVMDeviceExtType_kExtDev: TVMDeviceExtType = 12;
/// \brief Extension device types in TVM
pub type TVMDeviceExtType = u32;
pub const TVMTypeCode_kHandle: TVMTypeCode = 3;
pub const TVMTypeCode_kNull: TVMTypeCode = 4;
pub const TVMTypeCode_kTVMType: TVMTypeCode = 5;
pub const TVMTypeCode_kTVMContext: TVMTypeCode = 6;
pub const TVMTypeCode_kArrayHandle: TVMTypeCode = 7;
pub const TVMTypeCode_kNodeHandle: TVMTypeCode = 8;
pub const TVMTypeCode_kModuleHandle: TVMTypeCode = 9;
pub const TVMTypeCode_kFuncHandle: TVMTypeCode = 10;
pub const TVMTypeCode_kStr: TVMTypeCode = 11;
pub const TVMTypeCode_kBytes: TVMTypeCode = 12;
pub const TVMTypeCode_kNDArrayContainer: TVMTypeCode = 13;
pub const TVMTypeCode_kExtBegin: TVMTypeCode = 15;
pub const TVMTypeCode_kNNVMFirst: TVMTypeCode = 16;
pub const TVMTypeCode_kNNVMLast: TVMTypeCode = 20;
pub const TVMTypeCode_kExtReserveEnd: TVMTypeCode = 64;
pub const TVMTypeCode_kExtEnd: TVMTypeCode = 128;
/// \brief The type code in TVMType
/// \note TVMType is used in two places.
pub type TVMTypeCode = u32;
/// \brief The data type used in TVM Runtime.
///
/// Examples
/// - float: type_code = 2, bits = 32, lanes=1
/// - float4(vectorized 4 float): type_code = 2, bits = 32, lanes=4
/// - int8: type_code = 0, bits = 8, lanes=1
///
/// \note Arguments TVM API function always takes bits=64 and lanes=1
pub type TVMType = DLDataType;
/// \brief The Device information, abstract away common device types.
pub type TVMContext = DLContext;
/// \brief The tensor array stucture to TVM API.
pub type TVMArray = DLTensor;
/// \brief the array handle
pub type TVMArrayHandle = *mut TVMArray;
/// \brief Union type of values
/// being passed through API and function calls.
#[repr(C)]
#[derive(Copy, Clone)]
pub union TVMValue {
pub v_int64: i64,
pub v_float64: f64,
pub v_handle: *mut ::std::os::raw::c_void,
pub v_str: *const ::std::os::raw::c_char,
pub v_type: TVMType,
pub v_ctx: TVMContext,
_bindgen_union_align: u64,
}
/// \brief Byte array type used to pass in byte array
/// When kBytes is used as data type.
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct TVMByteArray {
pub data: *const ::std::os::raw::c_char,
pub size: usize,
}
/// \brief Handle to TVM runtime modules.
pub type TVMModuleHandle = *mut ::std::os::raw::c_void;
/// \brief Handle to packed function handle.
pub type TVMFunctionHandle = *mut ::std::os::raw::c_void;
/// \brief Handle to hold return value.
pub type TVMRetValueHandle = *mut ::std::os::raw::c_void;
/// \brief The stream that is specific to device
/// can be NULL, which indicates the default one.
pub type TVMStreamHandle = *mut ::std::os::raw::c_void;
extern "C" {
/// \brief Used for implementing C API function.
/// Set last error message before return.
/// \param msg The error message to be set.
pub fn TVMAPISetLastError(msg: *const ::std::os::raw::c_char);
}
extern "C" {
/// \brief return str message of the last error
/// all function in this file will return 0 when success
/// and -1 when an error occured,
/// TVMGetLastError can be called to retrieve the error
///
/// this function is threadsafe and can be called by different thread
/// \return error info
pub fn TVMGetLastError() -> *const ::std::os::raw::c_char;
}
extern "C" {
/// \brief Load module from file.
/// \param file_name The file name to load the module from.
/// \param format The format of the module.
/// \param out The result module
///
/// \return 0 when success, -1 when failure happens
/// \note The resulting module do not contain import relation.
/// It can be reconstructed by TVMModImport.
pub fn TVMModLoadFromFile(
file_name: *const ::std::os::raw::c_char,
format: *const ::std::os::raw::c_char,
out: *mut TVMModuleHandle,
) -> ::std::os::raw::c_int;
}
extern "C" {
/// \brief Add dep to mod's dependency.
/// This allows functions in this module to use modules.
///
/// \param mod The module handle.
/// \param dep The dependent module to be imported.
/// \return 0 when success, -1 when failure happens
pub fn TVMModImport(mod_: TVMModuleHandle, dep: TVMModuleHandle) -> ::std::os::raw::c_int;
}
extern "C" {
/// \brief Get function from the module.
/// \param mod The module handle.
/// \param func_name The name of the function.
/// \param query_imports Whether to query imported modules
/// \param out The result function, can be NULL if it is not available.
/// \return 0 when no error is thrown, -1 when failure happens
pub fn TVMModGetFunction(
mod_: TVMModuleHandle,
func_name: *const ::std::os::raw::c_char,
query_imports: ::std::os::raw::c_int,
out: *mut TVMFunctionHandle,
) -> ::std::os::raw::c_int;
}
extern "C" {
/// \brief Free front-end extension type resource.
/// \param handle The extension handle.
/// \param type_code The type of of the extension type.
/// \return 0 when success, -1 when failure happens
pub fn TVMExtTypeFree(
handle: *mut ::std::os::raw::c_void,
type_code: ::std::os::raw::c_int,
) -> ::std::os::raw::c_int;
}
extern "C" {
/// \brief Free the Module
/// \param mod The module to be freed.
///
/// \note This may not free up the module's resources.
/// If there is active TVMFunctionHandle uses the module
/// Or if this module is imported by another active module.
///
/// The all functions remains valid until TVMFuncFree is called.
/// \return 0 when success, -1 when failure happens
pub fn TVMModFree(mod_: TVMModuleHandle) -> ::std::os::raw::c_int;
}
extern "C" {
/// \brief Free the function when it is no longer needed.
/// \param func The function handle
/// \return 0 when success, -1 when failure happens
pub fn TVMFuncFree(func: TVMFunctionHandle) -> ::std::os::raw::c_int;
}
extern "C" {
/// \brief Call a Packed TVM Function.
///
/// \param func node handle of the function.
/// \param arg_values The arguments
/// \param type_codes The type codes of the arguments
/// \param num_args Number of arguments.
///
/// \param ret_val The return value.
/// \param ret_type_code the type code of return value.
///
/// \return 0 when success, -1 when failure happens
/// \note TVM calls always exchanges with type bits=64, lanes=1
///
/// \note API calls always exchanges with type bits=64, lanes=1
/// If API call returns container handles (e.g. FunctionHandle)
/// these handles should be managed by the front-end.
/// The front-end need to call free function (e.g. TVMFuncFree)
/// to free these handles.
pub fn TVMFuncCall(
func: TVMFunctionHandle,
arg_values: *mut TVMValue,
type_codes: *mut ::std::os::raw::c_int,
num_args: ::std::os::raw::c_int,
ret_val: *mut TVMValue,
ret_type_code: *mut ::std::os::raw::c_int,
) -> ::std::os::raw::c_int;
}
extern "C" {
/// \brief Set the return value of TVMPackedCFunc.
///
/// This function is called by TVMPackedCFunc to set the return value.
/// When this function is not called, the function returns null by default.
///
/// \param ret The return value handle, pass by ret in TVMPackedCFunc
/// \param value The value to be returned.
/// \param type_code The type of the value to be returned.
/// \param num_ret Number of return values, for now only 1 is supported.
pub fn TVMCFuncSetReturn(
ret: TVMRetValueHandle,
value: *mut TVMValue,
type_code: *mut ::std::os::raw::c_int,
num_ret: ::std::os::raw::c_int,
) -> ::std::os::raw::c_int;
}
extern "C" {
/// \brief Inplace translate callback argument value to return value.
/// This is only needed for non-POD arguments.
///
/// \param value The value to be translated.
/// \param code The type code to be translated.
/// \note This function will do a shallow copy when necessary.
///
/// \return 0 when success, -1 when failure happens.
pub fn TVMCbArgToReturn(
value: *mut TVMValue,
code: ::std::os::raw::c_int,
) -> ::std::os::raw::c_int;
}
/// \brief C type of packed function.
///
/// \param args The arguments
/// \param type_codes The type codes of the arguments
/// \param num_args Number of arguments.
/// \param ret The return value handle.
/// \param resource_handle The handle additional resouce handle from fron-end.
/// \return 0 if success, -1 if failure happens, set error via TVMAPISetLastError.
/// \sa TVMCFuncSetReturn
pub type TVMPackedCFunc = ::std::option::Option<
unsafe extern "C" fn(
args: *mut TVMValue,
type_codes: *mut ::std::os::raw::c_int,
num_args: ::std::os::raw::c_int,
ret: TVMRetValueHandle,
resource_handle: *mut ::std::os::raw::c_void,
) -> ::std::os::raw::c_int,
>;
/// \brief C callback to free the resource handle in C packed function.
/// \param resource_handle The handle additional resouce handle from fron-end.
pub type TVMPackedCFuncFinalizer =
::std::option::Option<unsafe extern "C" fn(resource_handle: *mut ::std::os::raw::c_void)>;
/// \brief Signature for extension function declarer.
///
/// TVM call this function to get the extension functions
/// The declarer will call register_func to register function and their name.
///
/// \param register_func_handle The register function
/// \return 0 if success, -1 if failure happens
pub type TVMExtensionFuncDeclarer = ::std::option::Option<
unsafe extern "C" fn(register_func_handle: TVMFunctionHandle) -> ::std::os::raw::c_int,
>;
extern "C" {
/// \brief Wrap a TVMPackedCFunc to become a FunctionHandle.
///
/// The resource_handle will be managed by TVM API, until the function is no longer used.
///
/// \param func The packed C function.
/// \param resource_handle The resource handle from front-end, can be NULL.
/// \param fin The finalizer on resource handle when the FunctionHandle get freed, can be NULL
/// \param out the result function handle.
/// \return 0 when success, -1 when failure happens
pub fn TVMFuncCreateFromCFunc(
func: TVMPackedCFunc,
resource_handle: *mut ::std::os::raw::c_void,
fin: TVMPackedCFuncFinalizer,
out: *mut TVMFunctionHandle,
) -> ::std::os::raw::c_int;
}
extern "C" {
/// \brief Register the function to runtime's global table.
///
/// The registered function then can be pulled by the backend by the name.
///
/// \param name The name of the function.
/// \param f The function to be registered.
/// \param override Whether allow override already registered function.
pub fn TVMFuncRegisterGlobal(
name: *const ::std::os::raw::c_char,
f: TVMFunctionHandle,
override_: ::std::os::raw::c_int,
) -> ::std::os::raw::c_int;
}
extern "C" {
/// \brief Get a global function.
///
/// \param name The name of the function.
/// \param out the result function pointer, NULL if it does not exist.
///
/// \note The function handle of global function is managed by TVM runtime,
/// So TVMFuncFree is should not be called when it get deleted.
pub fn TVMFuncGetGlobal(
name: *const ::std::os::raw::c_char,
out: *mut TVMFunctionHandle,
) -> ::std::os::raw::c_int;
}
extern "C" {
/// \brief List all the globally registered function name
/// \param out_size The number of functions
/// \param out_array The array of function names.
/// \return 0 when success, -1 when failure happens
pub fn TVMFuncListGlobalNames(
out_size: *mut ::std::os::raw::c_int,
out_array: *mut *mut *const ::std::os::raw::c_char,
) -> ::std::os::raw::c_int;
}
extern "C" {
/// \brief Allocate a nd-array's memory,
/// including space of shape, of given spec.
///
/// \param shape The shape of the array, the data content will be copied to out
/// \param ndim The number of dimension of the array.
/// \param dtype_code The type code of the dtype
/// \param dtype_bits The number of bits of dtype
/// \param dtype_lanes The number of lanes in the dtype.
/// \param device_type The device type of context
/// \param device_id The device id of context.
/// \param out The output handle.
/// \return 0 when success, -1 when failure happens
pub fn TVMArrayAlloc(
shape: *const tvm_index_t,
ndim: ::std::os::raw::c_int,
dtype_code: ::std::os::raw::c_int,
dtype_bits: ::std::os::raw::c_int,
dtype_lanes: ::std::os::raw::c_int,
device_type: ::std::os::raw::c_int,
device_id: ::std::os::raw::c_int,
out: *mut TVMArrayHandle,
) -> ::std::os::raw::c_int;
}
extern "C" {
/// \brief Free the TVM Array.
/// \param handle The array handle to be freed.
/// \return 0 when success, -1 when failure happens
pub fn TVMArrayFree(handle: TVMArrayHandle) -> ::std::os::raw::c_int;
}
extern "C" {
/// \brief Copy array data from CPU byte array.
/// \param handle The array handle.
/// \param data the data pointer
/// \param nbytes The number of bytes to copy.
/// \return 0 when success, -1 when failure happens
pub fn TVMArrayCopyFromBytes(
handle: TVMArrayHandle,
data: *mut ::std::os::raw::c_void,
nbytes: usize,
) -> ::std::os::raw::c_int;
}
extern "C" {
/// \brief Copy array data to CPU byte array.
/// \param handle The array handle.
/// \param data the data pointer
/// \param nbytes The number of bytes to copy.
/// \return 0 when success, -1 when failure happens
pub fn TVMArrayCopyToBytes(
handle: TVMArrayHandle,
data: *mut ::std::os::raw::c_void,
nbytes: usize,
) -> ::std::os::raw::c_int;
}
extern "C" {
/// \brief Copy the array, both from and to must be valid during the copy.
/// \param from The array to be copied from.
/// \param to The target space.
/// \param stream The stream where the copy happens, can be NULL.
/// \return 0 when success, -1 when failure happens
pub fn TVMArrayCopyFromTo(
from: TVMArrayHandle,
to: TVMArrayHandle,
stream: TVMStreamHandle,
) -> ::std::os::raw::c_int;
}
extern "C" {
/// \brief Produce an array from the DLManagedTensor that shares data memory
/// with the DLManagedTensor.
/// \param from The source DLManagedTensor.
/// \param out The output array handle.
/// \return 0 when success, -1 when failure happens
pub fn TVMArrayFromDLPack(
from: *mut DLManagedTensor,
out: *mut TVMArrayHandle,
) -> ::std::os::raw::c_int;
}
extern "C" {
/// \brief Produce a DLMangedTensor from the array that shares data memory with
/// the array.
/// \param from The source array.
/// \param out The DLManagedTensor handle.
/// \return 0 when success, -1 when failure happens
pub fn TVMArrayToDLPack(
from: TVMArrayHandle,
out: *mut *mut DLManagedTensor,
) -> ::std::os::raw::c_int;
}
extern "C" {
/// \brief Delete (free) a DLManagedTensor's data.
/// \param dltensor Pointer to the DLManagedTensor.
pub fn TVMDLManagedTensorCallDeleter(dltensor: *mut DLManagedTensor);
}
extern "C" {
/// \brief Create a new runtime stream.
///
/// \param device_type The device type of context
/// \param device_id The device id of context
/// \param out The new stream handle
/// \return 0 when success, -1 when failure happens
pub fn TVMStreamCreate(
device_type: ::std::os::raw::c_int,
device_id: ::std::os::raw::c_int,
out: *mut TVMStreamHandle,
) -> ::std::os::raw::c_int;
}
extern "C" {
/// \brief Free a created stream handle.
///
/// \param device_type The device type of context
/// \param device_id The device id of context
/// \param stream The stream to be freed
/// \return 0 when success, -1 when failure happens
pub fn TVMStreamFree(
device_type: ::std::os::raw::c_int,
device_id: ::std::os::raw::c_int,
stream: TVMStreamHandle,
) -> ::std::os::raw::c_int;
}
extern "C" {
/// \brief Set the runtime stream of current thread to be stream.
/// The subsequent calls to the same device_type
/// will use the setted stream handle.
/// The specific type of stream is runtime device dependent.
///
/// \param device_type The device type of context
/// \param device_id The device id of context.
/// \param handle The stream handle.
/// \return 0 when success, -1 when failure happens
pub fn TVMSetStream(
device_type: ::std::os::raw::c_int,
device_id: ::std::os::raw::c_int,
handle: TVMStreamHandle,
) -> ::std::os::raw::c_int;
}
extern "C" {
/// \brief Wait until all computations on stream completes.
///
/// \param device_type The device type of context
/// \param device_id The device id of context.
/// \param stream The stream to be synchronized.
/// \return 0 when success, -1 when failure happens
pub fn TVMSynchronize(
device_type: ::std::os::raw::c_int,
device_id: ::std::os::raw::c_int,
stream: TVMStreamHandle,
) -> ::std::os::raw::c_int;
}
extern "C" {
/// \brief Synchronize two streams of execution.
///
/// \param device_type The device type of context
/// \param device_id The device id of context
/// \param src The source stream to synchronize.
/// \param dst The destination stream to synchronize.
/// \return 0 when success, -1 when failure happens
pub fn TVMStreamStreamSynchronize(
device_type: ::std::os::raw::c_int,
device_id: ::std::os::raw::c_int,
src: TVMStreamHandle,
dst: TVMStreamHandle,
) -> ::std::os::raw::c_int;
}
extern "C" {
/// \brief Backend function for modules to get function
/// from its environment mod_node (its imports and global function).
/// The user do should not call TVMFuncFree on func.
///
/// \param mod_node The module handle.
/// \param func_name The name of the function.
/// \param out The result function.
/// \return 0 when no error is thrown, -1 when failure happens
pub fn TVMBackendGetFuncFromEnv(
mod_node: *mut ::std::os::raw::c_void,
func_name: *const ::std::os::raw::c_char,
out: *mut TVMFunctionHandle,
) -> ::std::os::raw::c_int;
}
extern "C" {
/// \brief Backend function to register system-wide library symbol.
///
/// \param name The name of the symbol
/// \param ptr The symbol address.
/// \return 0 when no error is thrown, -1 when failure happens
pub fn TVMBackendRegisterSystemLibSymbol(
name: *const ::std::os::raw::c_char,
ptr: *mut ::std::os::raw::c_void,
) -> ::std::os::raw::c_int;
}
extern "C" {
/// \brief Backend function to allocate temporal workspace.
///
/// \note The result allocate spaced is ensured to be aligned to kTempAllocaAlignment.
///
/// \param nbytes The size of the space requested.
/// \param device_type The device type which the space will be allocated.
/// \param device_id The device id which the space will be allocated.
/// \param dtype_code_hint The type code of the array elements. Only used in
/// certain backends such as OpenGL.
/// \param dtype_bits_hint The type bits of the array elements. Only used in
/// certain backends such as OpenGL.
/// \return nullptr when error is thrown, a valid ptr if success
pub fn TVMBackendAllocWorkspace(
device_type: ::std::os::raw::c_int,
device_id: ::std::os::raw::c_int,
nbytes: u64,
dtype_code_hint: ::std::os::raw::c_int,
dtype_bits_hint: ::std::os::raw::c_int,
) -> *mut ::std::os::raw::c_void;
}
extern "C" {
/// \brief Backend function to free temporal workspace.
///
/// \param ptr The result allocated space pointer.
/// \param device_type The device type which the space will be allocated.
/// \param device_id The device id which the space will be allocated.
/// \return 0 when no error is thrown, -1 when failure happens
///
/// \sa TVMBackendAllocWorkspace
pub fn TVMBackendFreeWorkspace(
device_type: ::std::os::raw::c_int,
device_id: ::std::os::raw::c_int,
ptr: *mut ::std::os::raw::c_void,
) -> ::std::os::raw::c_int;
}
/// \brief Environment for TVM parallel task.
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct TVMParallelGroupEnv {
/// \brief Auxiliary used for synchronization
pub sync_handle: *mut ::std::os::raw::c_void,
/// \brief total amount of task
pub num_task: i32,
}
/// \brief The callback function to execute a parallel lambda
/// \param task_id the task id of the function.
/// \param penv The parallel environment backs the execution.
/// \param cdata The supporting closure data.
pub type FTVMParallelLambda = ::std::option::Option<
unsafe extern "C" fn(
task_id: ::std::os::raw::c_int,
penv: *mut TVMParallelGroupEnv,
cdata: *mut ::std::os::raw::c_void,
) -> ::std::os::raw::c_int,
>;
extern "C" {
/// \brief Backend function for running parallel jobs.
///
/// \param flambda The parallel function to be launched.
/// \param cdata The closure data.
/// \param num_task Number of tasks to launch, can be 0, means launch
/// with all available threads.
///
/// \return 0 when no error is thrown, -1 when failure happens
pub fn TVMBackendParallelLaunch(
flambda: FTVMParallelLambda,
cdata: *mut ::std::os::raw::c_void,
num_task: ::std::os::raw::c_int,
) -> ::std::os::raw::c_int;
}
extern "C" {
/// \brief BSP barrrier between parallel threads
/// \param task_id the task id of the function.
/// \param penv The parallel environment backs the execution.
/// \return 0 when no error is thrown, -1 when failure happens
pub fn TVMBackendParallelBarrier(
task_id: ::std::os::raw::c_int,
penv: *mut TVMParallelGroupEnv,
) -> ::std::os::raw::c_int;
}
extern "C" {
/// \brief Simple static initialization function.
/// Run f once and set handle to be not null.
/// This function is mainly used for test purpose.
///
/// \param handle An global address to indicate f
/// \param f The function to be ran
/// \param cdata The closure data to pass to the function.
/// \param nbytes Number of bytes in the closure data.
/// \return 0 when no error is thrown, -1 when failure happens
pub fn TVMBackendRunOnce(
handle: *mut *mut ::std::os::raw::c_void,
f: ::std::option::Option<
unsafe extern "C" fn(arg1: *mut ::std::os::raw::c_void) -> ::std::os::raw::c_int,
>,
cdata: *mut ::std::os::raw::c_void,
nbytes: ::std::os::raw::c_int,
) -> ::std::os::raw::c_int;
}
//! Error types for `TVMArgValue` and `TVMRetValue` conversions.
use std::fmt;
error_chain! {
errors {
TryFromTVMArgValueError(expected: String, actual: String) {
description("mismatched types while converting from TVMArgValue")
display("expected `{}` but given `{}`", expected, actual)
static TYPE_CODE_STRS: [&str; 15] = [
"int",
"uint",
"float",
"handle",
"null",
"TVMType",
"TVMContext",
"ArrayHandle",
"NodeHandle",
"ModuleHandle",
"FuncHandle",
"str",
"bytes",
"NDArrayContainer",
"ExtBegin",
];
#[derive(Debug, Fail)]
pub struct ValueDowncastError {
actual_type_code: i64,
expected_type_code: i64,
}
impl ValueDowncastError {
pub fn new(actual_type_code: i64, expected_type_code: i64) -> Self {
Self {
actual_type_code,
expected_type_code,
}
}
}
TryFromTVMRetValueError(expected: String, actual: String) {
description("mismatched types while downcasting TVMRetValue")
display("invalid downcast: expected `{}` but given `{}`", expected, actual)
impl fmt::Display for ValueDowncastError {
fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
write!(
formatter,
"Could not downcast TVMValue: expected `{}` but was {}",
TYPE_CODE_STRS[self.actual_type_code as usize],
TYPE_CODE_STRS[self.expected_type_code as usize]
)
}
}
#[derive(Debug, Fail)]
#[fail(display = "Function call `{}` returned error: {}", context, message)]
pub struct FuncCallError {
context: String,
message: String,
}
impl FuncCallError {
pub fn get_with_context(context: String) -> Self {
Self {
context,
message: unsafe { std::ffi::CStr::from_ptr(crate::ffi::TVMGetLastError()) }
.to_str()
.expect("double fault")
.to_owned(),
}
}
}
// error_chain! {
// errors {
// TryFromTVMRetValueError(expected_type: String, actual_type_code: i64) {
// description("mismatched types while downcasting TVMRetValue")
// display("invalid downcast: expected `{}` but was `{}`",
// expected_type, type_code_to_string(actual_type_code))
// }
// }
// foreign_links {
// IntoString(std::ffi::IntoStringError);
// ParseInt(std::num::ParseIntError);
// Utf8(std::str::Utf8Error);
// }
// }
//! This crate contains the refactored basic components required
//! for `runtime` and `frontend` TVM crates.
#![crate_name = "tvm_common"]
#![recursion_limit = "1024"]
#![allow(non_camel_case_types, unused_imports)]
#![feature(box_syntax, try_from)]
#![feature(box_syntax, trait_alias)]
#[macro_use]
extern crate error_chain;
extern crate failure;
/// Unified ffi module for both runtime and frontend crates.
pub mod ffi {
#![allow(non_camel_case_types, non_snake_case, non_upper_case_globals, unused)]
#[cfg(feature = "frontend")]
pub extern crate tvm_sys as ts;
#[cfg(feature = "runtime")]
pub mod runtime {
use std::os::raw::{c_char, c_int, c_void};
include!(concat!(env!("CARGO_MANIFEST_DIR"), "/src/c_runtime_api.rs"));
pub type BackendPackedCFunc = extern "C" fn(
args: *const TVMValue,
type_codes: *const c_int,
num_args: c_int,
) -> c_int;
}
pub type BackendPackedCFunc =
extern "C" fn(args: *const TVMValue, type_codes: *const c_int, num_args: c_int) -> c_int;
}
pub mod array;
pub mod errors;
pub mod ty;
#[macro_use]
pub mod packed_func;
pub mod value;
pub use errors::*;
pub use ty::TVMTypeCode;
pub use value::{TVMArgValue, TVMRetValue, TVMValue};
pub use ffi::{TVMContext, TVMType};
pub use packed_func::{TVMArgValue, TVMRetValue};
use std::{any::Any, convert::TryFrom, marker::PhantomData, os::raw::c_void};
use failure::Error;
pub use crate::ffi::TVMValue;
use crate::ffi::*;
pub trait PackedFunc =
Fn(&[TVMArgValue]) -> Result<TVMRetValue, crate::errors::FuncCallError> + Send + Sync;
/// Calls a packed function and returns a `TVMRetValue`.
///
/// # Example
///
/// `call_packed!(my_tvm_func, &mut arg1, &mut arg2)`
#[macro_export]
macro_rules! call_packed {
($fn:expr, $($args:expr),+) => {
$fn(&[$($args.into(),)+])
};
($fn:expr) => {
$fn(&Vec::new())
};
}
/// A borrowed TVMPODValue. Can be constructed using `into()` but the preferred way
/// to obtain a `TVMArgValue` is automatically via `call_packed!`.
#[derive(Clone, Copy)]
pub struct TVMArgValue<'a> {
pub _lifetime: PhantomData<&'a ()>,
pub value: TVMValue,
pub type_code: i64,
}
impl<'a> TVMArgValue<'a> {
pub fn new(value: TVMValue, type_code: i64) -> Self {
TVMArgValue {
_lifetime: PhantomData,
value: value,
type_code: type_code,
}
}
}
#[macro_export]
macro_rules! ensure_type {
($val:ident, $expected_type_code:expr) => {
ensure!(
$val.type_code == $expected_type_code as i64,
$crate::errors::ValueDowncastError::new(
$val.type_code as i64,
$expected_type_code as i64
)
);
};
}
/// Creates a conversion to a `TVMArgValue` for a primitive type and DLDataTypeCode.
macro_rules! impl_prim_tvm_arg {
($type_code:ident, $field:ident, $field_type:ty, [ $( $type:ty ),+ ] ) => {
$(
impl From<$type> for TVMArgValue<'static> {
fn from(val: $type) -> Self {
TVMArgValue {
value: TVMValue { $field: val as $field_type },
type_code: $type_code as i64,
_lifetime: PhantomData,
}
}
}
impl<'a> From<&'a $type> for TVMArgValue<'a> {
fn from(val: &'a $type) -> Self {
TVMArgValue {
value: TVMValue {
$field: val.to_owned() as $field_type,
},
type_code: $type_code as i64,
_lifetime: PhantomData,
}
}
}
impl<'a> TryFrom<TVMArgValue<'a>> for $type {
type Error = Error;
fn try_from(val: TVMArgValue<'a>) -> Result<Self, Self::Error> {
ensure_type!(val, $type_code);
Ok(unsafe { val.value.$field as $type })
}
}
impl<'a> TryFrom<&TVMArgValue<'a>> for $type {
type Error = Error;
fn try_from(val: &TVMArgValue<'a>) -> Result<Self, Self::Error> {
ensure_type!(val, $type_code);
Ok(unsafe { val.value.$field as $type })
}
}
)+
};
}
impl_prim_tvm_arg!(DLDataTypeCode_kDLFloat, v_float64, f64, [f32, f64]);
impl_prim_tvm_arg!(
DLDataTypeCode_kDLInt,
v_int64,
i64,
[i8, i16, i32, i64, isize]
);
impl_prim_tvm_arg!(
DLDataTypeCode_kDLUInt,
v_int64,
i64,
[u8, u16, u32, u64, usize]
);
#[cfg(feature = "bindings")]
// only allow this in bindings because pure-rust can't take ownership of leaked CString
impl<'a> From<&String> for TVMArgValue<'a> {
fn from(string: &String) -> Self {
TVMArgValue {
value: TVMValue {
v_str: std::ffi::CString::new(string.clone()).unwrap().into_raw(),
},
type_code: TVMTypeCode_kStr as i64,
_lifetime: PhantomData,
}
}
}
impl<'a> From<&std::ffi::CString> for TVMArgValue<'a> {
fn from(string: &std::ffi::CString) -> Self {
TVMArgValue {
value: TVMValue {
v_str: string.as_ptr(),
},
type_code: TVMTypeCode_kStr as i64,
_lifetime: PhantomData,
}
}
}
impl<'a> TryFrom<TVMArgValue<'a>> for &str {
type Error = Error;
fn try_from(arg: TVMArgValue<'a>) -> Result<Self, Self::Error> {
ensure_type!(arg, TVMTypeCode_kStr);
Ok(unsafe { std::ffi::CStr::from_ptr(arg.value.v_handle as *const i8) }.to_str()?)
}
}
impl<'a> TryFrom<&TVMArgValue<'a>> for &str {
type Error = Error;
fn try_from(arg: &TVMArgValue<'a>) -> Result<Self, Self::Error> {
ensure_type!(arg, TVMTypeCode_kStr);
Ok(unsafe { std::ffi::CStr::from_ptr(arg.value.v_handle as *const i8) }.to_str()?)
}
}
/// Creates a conversion to a `TVMArgValue` for an object handle.
impl<'a, T> From<*const T> for TVMArgValue<'a> {
fn from(ptr: *const T) -> Self {
TVMArgValue {
value: TVMValue {
v_handle: ptr as *mut T as *mut c_void,
},
type_code: TVMTypeCode_kArrayHandle as i64,
_lifetime: PhantomData,
}
}
}
/// Creates a conversion to a `TVMArgValue` for a mutable object handle.
impl<'a, T> From<*mut T> for TVMArgValue<'a> {
fn from(ptr: *mut T) -> Self {
TVMArgValue {
value: TVMValue {
v_handle: ptr as *mut c_void,
},
type_code: TVMTypeCode_kHandle as i64,
_lifetime: PhantomData,
}
}
}
impl<'a> From<&'a mut DLTensor> for TVMArgValue<'a> {
fn from(arr: &'a mut DLTensor) -> Self {
TVMArgValue {
value: TVMValue {
v_handle: arr as *mut _ as *mut c_void,
},
type_code: TVMTypeCode_kArrayHandle as i64,
_lifetime: PhantomData,
}
}
}
impl<'a> From<&'a DLTensor> for TVMArgValue<'a> {
fn from(arr: &'a DLTensor) -> Self {
TVMArgValue {
value: TVMValue {
v_handle: arr as *const _ as *mut DLTensor as *mut c_void,
},
type_code: TVMTypeCode_kArrayHandle as i64,
_lifetime: PhantomData,
}
}
}
impl<'a, 'v> TryFrom<&'a TVMArgValue<'v>> for TVMType {
type Error = Error;
fn try_from(arg: &'a TVMArgValue<'v>) -> Result<Self, Self::Error> {
ensure_type!(arg, TVMTypeCode_kTVMType);
Ok(unsafe { arg.value.v_type.into() })
}
}
/// An owned TVMPODValue. Can be converted from a variety of primitive and object types.
/// Can be downcasted using `try_from` if it contains the desired type.
///
/// # Example
///
/// ```
/// let a = 42u32;
/// let b: i64 = TVMRetValue::from(a).try_into().unwrap();
///
/// let s = "hello, world!";
/// let t: TVMRetValue = s.into();
/// assert_eq!(String::try_from(t).unwrap(), s);
/// ```
pub struct TVMRetValue {
pub value: TVMValue,
pub box_value: Box<Any>,
pub type_code: i64,
}
impl TVMRetValue {
pub fn from_tvm_value(value: TVMValue, type_code: i64) -> Self {
Self {
value,
type_code,
box_value: box (),
}
}
pub fn into_tvm_value(self) -> (TVMValue, TVMTypeCode) {
(self.value, self.type_code as TVMTypeCode)
}
}
impl Default for TVMRetValue {
fn default() -> Self {
TVMRetValue {
value: TVMValue { v_int64: 0 as i64 },
type_code: 0,
box_value: box (),
}
}
}
macro_rules! impl_pod_ret_value {
($code:expr, [ $( $ty:ty ),+ ] ) => {
$(
impl From<$ty> for TVMRetValue {
fn from(val: $ty) -> Self {
Self {
value: val.into(),
type_code: $code as i64,
box_value: box (),
}
}
}
impl TryFrom<TVMRetValue> for $ty {
type Error = Error;
fn try_from(ret: TVMRetValue) -> Result<$ty, Self::Error> {
ensure_type!(ret, $code);
Ok(ret.value.into())
}
}
)+
};
}
impl_pod_ret_value!(DLDataTypeCode_kDLInt, [i8, i16, i32, i64, isize]);
impl_pod_ret_value!(DLDataTypeCode_kDLUInt, [u8, u16, u32, u64, usize]);
impl_pod_ret_value!(DLDataTypeCode_kDLFloat, [f32, f64]);
impl_pod_ret_value!(TVMTypeCode_kTVMType, [TVMType]);
impl_pod_ret_value!(TVMTypeCode_kTVMContext, [TVMContext]);
impl TryFrom<TVMRetValue> for String {
type Error = Error;
fn try_from(ret: TVMRetValue) -> Result<String, Self::Error> {
ensure_type!(ret, TVMTypeCode_kStr);
let cs = unsafe { std::ffi::CString::from_raw(ret.value.v_handle as *mut i8) };
let ret_str = cs.clone().into_string();
if cfg!(feature = "bindings") {
std::mem::forget(cs); // TVM C++ takes ownership of CString. (@see TVMFuncCall)
}
Ok(ret_str?)
}
}
impl From<String> for TVMRetValue {
fn from(s: String) -> Self {
let cs = std::ffi::CString::new(s).unwrap();
Self {
value: TVMValue {
v_str: cs.into_raw() as *mut i8,
},
box_value: box (),
type_code: TVMTypeCode_kStr as i64,
}
}
}
//! This module containes `TVMTypeCode` and `TVMType` with some conversion methods.
//!
//! # Example
//!
//! ```
//! let dtype = TVMType::from("float");
//! println!("dtype is: {}", dtype);
//! ```
use std::{
ffi::{CStr, CString},
fmt::{self, Display, Formatter},
};
/// TVM type codes.
#[repr(u32)]
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum TVMTypeCode {
kDLInt = 0,
kDLUInt = 1,
kDLFloat = 2,
kHandle = 3,
kNull = 4,
kTVMType = 5,
kTVMContext = 6,
kArrayHandle = 7,
kNodeHandle = 8,
kModuleHandle = 9,
kFuncHandle = 10,
kStr = 11,
kBytes = 12,
kNDArrayContainer = 13,
}
impl Default for TVMTypeCode {
fn default() -> Self {
TVMTypeCode::kDLInt
}
}
impl From<TVMTypeCode> for i64 {
fn from(arg: TVMTypeCode) -> i64 {
match arg {
TVMTypeCode::kDLInt => 0,
TVMTypeCode::kDLUInt => 1,
TVMTypeCode::kDLFloat => 2,
TVMTypeCode::kHandle => 3,
TVMTypeCode::kNull => 4,
TVMTypeCode::kTVMType => 5,
TVMTypeCode::kTVMContext => 6,
TVMTypeCode::kArrayHandle => 7,
TVMTypeCode::kNodeHandle => 8,
TVMTypeCode::kModuleHandle => 9,
TVMTypeCode::kFuncHandle => 10,
TVMTypeCode::kStr => 11,
TVMTypeCode::kBytes => 12,
TVMTypeCode::kNDArrayContainer => 13,
}
}
}
impl Into<TVMTypeCode> for i64 {
fn into(self) -> TVMTypeCode {
match self {
0 => TVMTypeCode::kDLInt,
1 => TVMTypeCode::kDLUInt,
2 => TVMTypeCode::kDLFloat,
3 => TVMTypeCode::kHandle,
4 => TVMTypeCode::kNull,
5 => TVMTypeCode::kTVMType,
6 => TVMTypeCode::kTVMContext,
7 => TVMTypeCode::kArrayHandle,
8 => TVMTypeCode::kNodeHandle,
9 => TVMTypeCode::kModuleHandle,
10 => TVMTypeCode::kFuncHandle,
11 => TVMTypeCode::kStr,
12 => TVMTypeCode::kBytes,
13 => TVMTypeCode::kNDArrayContainer,
_ => unreachable!(),
}
}
}
impl Display for TVMTypeCode {
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
write!(
f,
"{}",
match self {
TVMTypeCode::kDLInt => "int",
TVMTypeCode::kDLUInt => "uint",
TVMTypeCode::kDLFloat => "float",
TVMTypeCode::kHandle => "handle",
TVMTypeCode::kNull => "null",
TVMTypeCode::kTVMType => "TVM type",
TVMTypeCode::kTVMContext => "TVM context",
TVMTypeCode::kArrayHandle => "Array handle",
TVMTypeCode::kNodeHandle => "Node handle",
TVMTypeCode::kModuleHandle => "Module handle",
TVMTypeCode::kFuncHandle => "Function handle",
TVMTypeCode::kStr => "string",
TVMTypeCode::kBytes => "bytes",
TVMTypeCode::kNDArrayContainer => "ndarray container",
}
)
}
}
macro_rules! impl_prim_type {
($type:ty, $variant:ident) => {
impl<'a> From<&'a $type> for TVMTypeCode {
fn from(_arg: &$type) -> Self {
TVMTypeCode::$variant
}
}
impl<'a> From<&'a mut $type> for TVMTypeCode {
fn from(_arg: &mut $type) -> Self {
TVMTypeCode::$variant
}
}
};
}
impl_prim_type!(usize, kDLInt);
impl_prim_type!(i64, kDLInt);
impl_prim_type!(i32, kDLInt);
impl_prim_type!(i16, kDLInt);
impl_prim_type!(i8, kDLInt);
impl_prim_type!(u64, kDLUInt);
impl_prim_type!(u32, kDLUInt);
impl_prim_type!(u16, kDLUInt);
impl_prim_type!(u8, kDLUInt);
impl_prim_type!(f64, kDLFloat);
impl_prim_type!(f32, kDLFloat);
impl_prim_type!(str, kStr);
impl_prim_type!(CStr, kStr);
impl_prim_type!(String, kStr);
impl_prim_type!(CString, kStr);
impl_prim_type!([u8], kBytes);
//! This module provides the the wrapped `TVMValue`, `TVMArgValue` and `TVMRetValue`
//! required for using TVM functions.
use std::str::FromStr;
use std::{
any::Any,
convert::TryFrom,
ffi::{CStr, CString},
fmt::{self, Debug, Formatter},
marker::PhantomData,
mem,
ops::Deref,
os::raw::{c_char, c_void},
};
use failure::Error;
#[cfg(feature = "runtime")]
use ffi::runtime::TVMValue as _TVMValue;
use crate::ffi::*;
#[cfg(feature = "frontend")]
use ffi::ts::TVMValue as _TVMValue;
use errors::*;
use ty::TVMTypeCode;
/// Wrapped TVMValue type.
#[derive(Clone, Copy)]
pub struct TVMValue {
pub inner: _TVMValue,
}
impl TVMValue {
/// Creates TVMValue from the raw part.
pub fn new(inner: _TVMValue) -> Self {
TVMValue { inner }
}
pub(crate) fn into_raw(self) -> _TVMValue {
self.inner
}
}
impl Debug for TVMValue {
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
unsafe {
write!(
f,
"TVMValue: [v_int64: {:?}], [v_float64: {:?}], [v_handle: {:?}],\
[v_str: {:?}]",
self.inner.v_int64, self.inner.v_float64, self.inner.v_handle, self.inner.v_str
)
impl TVMType {
fn new(type_code: u8, bits: u8, lanes: u16) -> Self {
Self {
code: type_code,
bits,
lanes,
}
}
}
impl Deref for TVMValue {
type Target = _TVMValue;
fn deref(&self) -> &Self::Target {
&self.inner
/// Implements TVMType conversion from `&str` of general format `{dtype}{bits}x{lanes}`
/// such as "int32", "float32" or with lane "float32x1".
impl FromStr for TVMType {
type Err = Error;
fn from_str(type_str: &str) -> Result<Self, Self::Err> {
if type_str == "bool" {
return Ok(TVMType::new(1, 1, 1));
}
}
macro_rules! impl_prim_val {
($type:ty, $field:ident, $cast:ty) => {
impl From<$type> for TVMValue {
fn from(arg: $type) -> Self {
let inner = _TVMValue {
$field: arg as $cast,
};
Self::new(inner)
let mut type_lanes = type_str.split("x");
let typ = type_lanes.next().expect("Missing dtype");
let lanes = type_lanes
.next()
.map(|l| <u16>::from_str_radix(l, 10))
.unwrap_or(Ok(1))?;
let (type_name, bits) = match typ.find(char::is_numeric) {
Some(idx) => {
let (name, bits_str) = typ.split_at(idx);
(name, u8::from_str_radix(bits_str, 10)?)
}
}
impl<'a> From<&'a $type> for TVMValue {
fn from(arg: &$type) -> Self {
let inner = _TVMValue {
$field: *arg as $cast,
None => (typ, 32),
};
Self::new(inner)
}
}
impl<'a> From<&'a mut $type> for TVMValue {
fn from(arg: &mut $type) -> Self {
let inner = _TVMValue {
$field: *arg as $cast,
let type_code = match type_name {
"int" => 0,
"uint" => 1,
"float" => 2,
"handle" => 3,
_ => return Err(format_err!("Unknown type {}", type_name)),
};
Self::new(inner)
}
}
impl TryFrom<TVMValue> for $type {
type Error = Error;
fn try_from(val: TVMValue) -> Result<Self> {
Ok(unsafe { val.inner.$field as $type })
}
}
impl<'a> TryFrom<&'a TVMValue> for $type {
type Error = Error;
fn try_from(val: &TVMValue) -> Result<Self> {
Ok(unsafe { val.into_raw().$field as $type })
}
}
impl<'a> TryFrom<&'a mut TVMValue> for $type {
type Error = Error;
fn try_from(val: &mut TVMValue) -> Result<Self> {
Ok(unsafe { val.into_raw().$field as $type })
}
Ok(TVMType::new(type_code, bits, lanes))
}
};
}
impl_prim_val!(isize, v_int64, i64);
impl_prim_val!(i64, v_int64, i64);
impl_prim_val!(i32, v_int64, i64);
impl_prim_val!(i16, v_int64, i64);
impl_prim_val!(i8, v_int64, i64);
impl_prim_val!(usize, v_int64, i64);
impl_prim_val!(u64, v_int64, i64);
impl_prim_val!(u32, v_int64, i64);
impl_prim_val!(u16, v_int64, i64);
impl_prim_val!(u8, v_int64, i64);
impl_prim_val!(f64, v_float64, f64);
impl_prim_val!(f32, v_float64, f64);
impl<'a> From<&'a str> for TVMValue {
fn from(arg: &str) -> TVMValue {
let arg = CString::new(arg).unwrap();
let inner = _TVMValue {
v_str: arg.as_ptr() as *const c_char,
};
mem::forget(arg);
Self::new(inner)
impl std::fmt::Display for TVMType {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
if self.bits == 1 && self.lanes == 1 {
return write!(f, "bool");
}
}
impl<'a> From<&'a String> for TVMValue {
fn from(arg: &String) -> TVMValue {
let arg = CString::new(arg.as_bytes()).unwrap();
let inner = _TVMValue {
v_str: arg.as_ptr() as *const c_char,
};
mem::forget(arg);
Self::new(inner)
let mut type_str = match self.code {
0 => "int",
1 => "uint",
2 => "float",
4 => "handle",
_ => "unknown",
}
}
.to_string();
impl<'a> From<&'a CString> for TVMValue {
fn from(arg: &CString) -> TVMValue {
let arg = arg.to_owned();
let inner = _TVMValue {
v_str: arg.as_ptr() as *const c_char,
};
mem::forget(arg);
Self::new(inner)
type_str += &self.bits.to_string();
if self.lanes > 1 {
type_str += &format!("x{}", self.lanes);
}
}
impl<'a> From<&'a [u8]> for TVMValue {
fn from(arg: &[u8]) -> TVMValue {
let arg = arg.to_owned();
let inner = _TVMValue {
v_handle: &arg as *const _ as *mut c_void,
};
mem::forget(arg);
Self::new(inner)
f.write_str(&type_str)
}
}
/// Captures both `TVMValue` and `TVMTypeCode` needed for TVM function.
/// The preferred way to obtain a `TVMArgValue` is automatically via `call_packed!`.
/// or in the frontend crate, with `function::Builder`. Checkout the methods for conversions.
///
/// ## Example
///
/// ```
/// let s = "hello".to_string();
/// let arg = TVMArgValue::from(&s);
/// let tvm: String = arg.try_into().unwrap();
/// assert_eq!(arg, s);
/// ```
#[derive(Debug, Clone, Copy)]
pub struct TVMArgValue<'a> {
/// The wrapped TVMValue
pub value: TVMValue,
/// The matching type code.
pub type_code: TVMTypeCode,
/// This is only exposed to runtime and frontend crates and is not meant to be used directly.
pub lifetime: PhantomData<&'a ()>,
}
impl<'a> TVMArgValue<'a> {
pub fn new(value: TVMValue, type_code: TVMTypeCode) -> Self {
TVMArgValue {
value: value,
type_code: type_code,
lifetime: PhantomData,
}
}
}
impl<'a, 'b> TryFrom<&'b TVMArgValue<'a>> for i64 {
type Error = Error;
fn try_from(arg: &TVMArgValue<'a>) -> Result<Self> {
if (arg.type_code == TVMTypeCode::kDLInt)
| (arg.type_code == TVMTypeCode::kDLUInt)
| (arg.type_code == TVMTypeCode::kNull)
{
Ok(unsafe { arg.value.inner.v_int64 })
} else {
bail!(ErrorKind::TryFromTVMArgValueError(
stringify!(i64).to_string(),
arg.type_code.to_string()
))
macro_rules! impl_pod_tvm_value {
($field:ident, $field_ty:ty, $( $ty:ty ),+) => {
$(
impl From<$ty> for TVMValue {
fn from(val: $ty) -> Self {
TVMValue { $field: val as $field_ty }
}
}
}
impl<'a, 'b> TryFrom<&'b TVMArgValue<'a>> for f64 {
type Error = Error;
fn try_from(arg: &TVMArgValue<'a>) -> Result<Self> {
if arg.type_code == TVMTypeCode::kDLFloat {
Ok(unsafe { arg.value.inner.v_float64 })
} else {
bail!(ErrorKind::TryFromTVMArgValueError(
stringify!(f64).to_string(),
arg.type_code.to_string()
))
}
impl From<TVMValue> for $ty {
fn from(val: TVMValue) -> Self {
unsafe { val.$field as $ty }
}
}
impl<'a, 'b> TryFrom<&'b TVMArgValue<'a>> for String {
type Error = Error;
fn try_from(arg: &TVMArgValue<'a>) -> Result<Self> {
if arg.type_code == TVMTypeCode::kStr {
let ret_str = unsafe {
match CStr::from_ptr(arg.value.inner.v_str).to_str() {
Ok(s) => s,
Err(_) => "Invalid UTF-8 message",
}
)+
};
Ok(ret_str.to_string())
} else {
bail!(ErrorKind::TryFromTVMArgValueError(
stringify!(String).to_string(),
arg.type_code.to_string()
))
}
($field:ident, $ty:ty) => {
impl_pod_tvm_value!($field, $ty, $ty);
}
}
/// Main way to create a TVMArgValue from suported Rust values.
impl<'b, 'a: 'b, T: 'b + ?Sized> From<&'b T> for TVMArgValue<'a>
where
TVMValue: From<&'b T>,
TVMTypeCode: From<&'b T>,
{
fn from(arg: &'b T) -> Self {
TVMArgValue::new(TVMValue::from(arg), TVMTypeCode::from(arg))
}
}
impl_pod_tvm_value!(v_int64, i64, i8, u8, i16, u16, i32, u32, i64, u64, isize, usize);
impl_pod_tvm_value!(v_float64, f64, f32, f64);
impl_pod_tvm_value!(v_type, TVMType);
impl_pod_tvm_value!(v_ctx, TVMContext);
/// Creates a conversion to a `TVMArgValue` for an object handle.
impl<'a, T> From<*const T> for TVMArgValue<'a> {
fn from(ptr: *const T) -> Self {
let value = TVMValue::new(_TVMValue {
v_handle: ptr as *mut T as *mut c_void,
});
TVMArgValue::new(value, TVMTypeCode::kArrayHandle)
macro_rules! impl_tvm_context {
( $( $dev_type:ident : [ $( $dev_name:ident ),+ ] ),+ ) => {
/// Creates a TVMContext from a string (e.g., "cpu", "gpu", "ext_dev")
impl FromStr for TVMContext {
type Err = Error;
fn from_str(type_str: &str) -> Result<Self, Self::Err> {
Ok(Self {
device_type: match type_str {
$( $( stringify!($dev_name) )|+ => $dev_type ),+,
_ => return Err(format_err!("device {} not supported", type_str).into()),
},
device_id: 0,
})
}
}
/// Creates a conversion to a `TVMArgValue` for a mutable object handle.
impl<'a, T> From<*mut T> for TVMArgValue<'a> {
fn from(ptr: *mut T) -> Self {
let value = TVMValue::new(_TVMValue {
v_handle: ptr as *mut c_void,
});
TVMArgValue::new(value, TVMTypeCode::kHandle)
}
}
/// An owned version of TVMPODValue. It can be converted from varieties of
/// primitive and object types.
/// It can be downcasted using `try_from` if it contains the desired type.
///
/// # Example
///
/// ```
/// let a = 42u32;
/// let b: i64 = TVMRetValue::from(a).try_into().unwrap();
///
/// let s = "hello, world!";
/// let t: TVMRetValue = s.into();
/// assert_eq!(String::try_from(t).unwrap(), s);
/// ```
pub struct TVMRetValue {
/// A primitive return value, if any.
pub prim_value: usize,
/// An object return value, if any.
pub box_value: Box<Any>,
pub type_code: TVMTypeCode,
}
impl TVMRetValue {
fn new(prim_value: usize, box_value: Box<Any>, type_code: TVMTypeCode) -> Self {
impl TVMContext {
$(
$(
pub fn $dev_name(device_id: usize) -> Self {
Self {
prim_value,
box_value,
type_code,
}
}
/// unsafe function to create `TVMRetValue` from `TVMValue` and
/// its matching `TVMTypeCode`.
pub unsafe fn from_tvm_value(value: TVMValue, type_code: TVMTypeCode) -> Self {
let value = value.into_raw();
match type_code {
TVMTypeCode::kDLInt | TVMTypeCode::kDLUInt => {
Self::new(value.v_int64 as usize, box (), type_code)
}
TVMTypeCode::kDLFloat => Self::new(value.v_float64 as usize, box (), type_code),
TVMTypeCode::kHandle
| TVMTypeCode::kArrayHandle
| TVMTypeCode::kNodeHandle
| TVMTypeCode::kModuleHandle
| TVMTypeCode::kFuncHandle => {
Self::new(value.v_handle as usize, box value.v_handle, type_code)
}
TVMTypeCode::kStr | TVMTypeCode::kBytes => {
Self::new(value.v_str as usize, box (value.v_str), type_code)
device_type: $dev_type,
device_id: device_id as i32,
}
_ => Self::new(0usize, box (), type_code),
}
}
/// Returns the underlying `TVMValue` and `TVMTypeCode`.
pub fn into_tvm_value(self) -> (TVMValue, TVMTypeCode) {
let val = match self.type_code {
TVMTypeCode::kDLInt | TVMTypeCode::kDLUInt => TVMValue::new(_TVMValue {
v_int64: self.prim_value as i64,
}),
TVMTypeCode::kDLFloat => TVMValue::new(_TVMValue {
v_float64: self.prim_value as f64,
}),
TVMTypeCode::kHandle
| TVMTypeCode::kArrayHandle
| TVMTypeCode::kNodeHandle
| TVMTypeCode::kModuleHandle
| TVMTypeCode::kFuncHandle
| TVMTypeCode::kNDArrayContainer => TVMValue::new(_TVMValue {
v_handle: self.prim_value as *const c_void as *mut c_void,
}),
TVMTypeCode::kStr | TVMTypeCode::kBytes => TVMValue::new(_TVMValue {
v_str: self.prim_value as *const c_char,
}),
_ => unreachable!(),
};
(val, self.type_code)
}
}
impl Default for TVMRetValue {
fn default() -> Self {
TVMRetValue {
prim_value: 0usize,
box_value: box (),
type_code: TVMTypeCode::default(),
}
}
}
impl Clone for TVMRetValue {
fn clone(&self) -> Self {
match self.type_code {
TVMTypeCode::kDLInt | TVMTypeCode::kDLUInt | TVMTypeCode::kDLFloat => {
Self::new(self.prim_value.clone(), box (), self.type_code.clone())
}
TVMTypeCode::kHandle
| TVMTypeCode::kArrayHandle
| TVMTypeCode::kNodeHandle
| TVMTypeCode::kModuleHandle
| TVMTypeCode::kFuncHandle
| TVMTypeCode::kNDArrayContainer => Self::new(
self.prim_value.clone(),
box (self.prim_value.clone() as *const c_void as *mut c_void),
self.type_code.clone(),
),
TVMTypeCode::kStr | TVMTypeCode::kBytes => Self::new(
self.prim_value.clone(),
box (self.prim_value.clone() as *const c_char),
self.type_code.clone(),
),
_ => unreachable!(),
}
}
}
impl Debug for TVMRetValue {
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
write!(
f,
"prim_value: {:?}, box_value: {:?}, type_code: {:?}",
self.prim_value, self.prim_value as *const c_void as *mut c_void, self.type_code
)
}
}
macro_rules! impl_prim_ret_value {
($type:ty, $code:expr) => {
impl From<$type> for TVMRetValue {
fn from(val: $type) -> Self {
TVMRetValue {
prim_value: val as usize,
box_value: box (),
type_code: $code,
}
}
}
impl<'a> From<&'a $type> for TVMRetValue {
fn from(val: &$type) -> Self {
TVMRetValue {
prim_value: *val as usize,
box_value: box (),
type_code: $code,
}
}
}
impl<'a> From<&'a mut $type> for TVMRetValue {
fn from(val: &mut $type) -> Self {
TVMRetValue {
prim_value: *val as usize,
box_value: box (),
type_code: $code,
}
}
}
impl TryFrom<TVMRetValue> for $type {
type Error = Error;
fn try_from(ret: TVMRetValue) -> Result<$type> {
if ret.type_code == $code {
Ok(ret.prim_value as $type)
} else {
bail!(ErrorKind::TryFromTVMRetValueError(
stringify!($type).to_string(),
ret.type_code.to_string(),
))
}
}
}
};
}
impl_prim_ret_value!(i8, TVMTypeCode::kDLInt);
impl_prim_ret_value!(i16, TVMTypeCode::kDLInt);
impl_prim_ret_value!(i32, TVMTypeCode::kDLInt);
impl_prim_ret_value!(i64, TVMTypeCode::kDLInt);
impl_prim_ret_value!(isize, TVMTypeCode::kDLInt);
impl_prim_ret_value!(u8, TVMTypeCode::kDLUInt);
impl_prim_ret_value!(u16, TVMTypeCode::kDLUInt);
impl_prim_ret_value!(u32, TVMTypeCode::kDLUInt);
impl_prim_ret_value!(u64, TVMTypeCode::kDLUInt);
impl_prim_ret_value!(usize, TVMTypeCode::kDLUInt);
impl_prim_ret_value!(f32, TVMTypeCode::kDLFloat);
impl_prim_ret_value!(f64, TVMTypeCode::kDLFloat);
macro_rules! impl_ptr_ret_value {
($type:ty) => {
impl From<$type> for TVMRetValue {
fn from(ptr: $type) -> Self {
TVMRetValue {
prim_value: ptr as usize,
box_value: box (),
type_code: TVMTypeCode::kHandle,
}
}
}
impl TryFrom<TVMRetValue> for $type {
type Error = Error;
fn try_from(ret: TVMRetValue) -> Result<$type> {
if ret.type_code == TVMTypeCode::kHandle {
Ok(ret.prim_value as $type)
} else {
bail!(ErrorKind::TryFromTVMRetValueError(
stringify!($type).to_string(),
ret.type_code.to_string(),
))
}
}
}
};
}
impl_ptr_ret_value!(*const c_void);
impl_ptr_ret_value!(*mut c_void);
impl From<String> for TVMRetValue {
fn from(val: String) -> Self {
let pval = val.as_ptr() as *const c_char as usize;
let bval = box (val.as_ptr() as *const c_char);
mem::forget(val);
TVMRetValue::new(pval, bval, TVMTypeCode::kStr)
}
}
impl TryFrom<TVMRetValue> for String {
type Error = Error;
fn try_from(ret: TVMRetValue) -> Result<String> {
// Note: simple downcast doesn't work for function call return values
let ret_str = unsafe {
match CStr::from_ptr(ret.prim_value as *const c_char).to_str() {
Ok(s) => s,
Err(_) => "Invalid UTF-8 message",
)+
)+
}
};
Ok(ret_str.to_string())
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::convert::TryInto;
#[test]
fn numeric() {
macro_rules! arg_ret_tests {
($v:expr; $($ty:ty),+) => {{
$(
let v = $v as $ty;
let b = TVMRetValue::from(&v);
let b: $ty = b.try_into().unwrap();
assert_eq!(b, v);
)+
}};
}
arg_ret_tests!(42; i8, i16, i32, i64, f32, f64);
}
#[test]
fn string() {
let s = "hello".to_string();
let tvm_arg: String = TVMRetValue::from(s.clone()).try_into().unwrap();
assert_eq!(tvm_arg, s);
}
}
impl_tvm_context!(
DLDeviceType_kDLCPU: [cpu, llvm, stackvm],
DLDeviceType_kDLGPU: [gpu, cuda, nvptx],
DLDeviceType_kDLOpenCL: [cl],
DLDeviceType_kDLMetal: [metal],
DLDeviceType_kDLVPI: [vpi],
DLDeviceType_kDLROCM: [rocm],
DLDeviceType_kDLExtDev: [ext_dev]
);
[package]
name = "tvm-sys"
version = "0.1.0"
authors = ["TVM Contributors"]
license = "Apache-2.0"
description = "Raw C API"
[build-dependencies]
bindgen = "0.37.4"
#![allow(
non_camel_case_types,
non_snake_case,
non_upper_case_globals,
dead_code,
improper_ctypes
)]
include!("bindgen.rs");
......@@ -15,11 +15,11 @@ name = "tvm_frontend"
crate-type = ["dylib"]
[dependencies]
error-chain = "0.12.0"
failure = "0.1.5"
lazy_static = "1.1.0"
ndarray = "0.12.1"
num-traits = "0.2"
tvm-common = { version = "0.1.0", path = "../common/", features = ["frontend"] }
tvm-common = { version = "0.1.0", path = "../common/", features = ["bindings"] }
[features]
blas = ["ndarray/blas"]
#![feature(try_from)]
extern crate csv;
extern crate image;
extern crate ndarray;
......@@ -10,6 +8,7 @@ use std::{
convert::TryInto,
fs::{self, File},
path::Path,
str::FromStr,
};
use image::{FilterType, GenericImageView};
......@@ -44,8 +43,12 @@ fn main() {
// make arr shape as [1, 3, 224, 224] acceptable to resnet
let arr = arr.insert_axis(Axis(0));
// create input tensor from rust's ndarray
let input =
NDArray::from_rust_ndarray(&arr, TVMContext::cpu(0), TVMType::from("float32")).unwrap();
let input = NDArray::from_rust_ndarray(
&arr,
TVMContext::cpu(0),
TVMType::from_str("float32").unwrap(),
)
.unwrap();
println!(
"input size is {:?}",
input.shape().expect("cannot get the input shape")
......@@ -59,7 +62,7 @@ fn main() {
)))
.unwrap();
// get the global TVM graph runtime function
let runtime_create_fn = Function::get("tvm.graph_runtime.create", true).unwrap();
let runtime_create_fn = Function::get("tvm.graph_runtime.create").unwrap();
let runtime_create_fn_ret = call_packed!(
runtime_create_fn,
&graph,
......@@ -85,14 +88,19 @@ fn main() {
.get_function("set_input", false)
.unwrap();
call_packed!(set_input_fn, "data", &input).unwrap();
let data_str = "data".to_string();
call_packed!(set_input_fn, &data_str, &input).unwrap();
// get `run` function from runtime module
let ref run_fn = graph_runtime_module.get_function("run", false).unwrap();
// execute the run function. Note that it has no argument
call_packed!(run_fn,).unwrap();
// prepare to get the output
let output_shape = &mut [1, 1000];
let output = NDArray::empty(output_shape, TVMContext::cpu(0), TVMType::from("float32"));
let output = NDArray::empty(
output_shape,
TVMContext::cpu(0),
TVMType::from_str("float32").unwrap(),
);
// get the `get_output` function from runtime module
let ref get_output_fn = graph_runtime_module
.get_function("get_output", false)
......
......@@ -3,9 +3,9 @@
//!
//! For more detail, please see the example `resnet` in `examples` repository.
use std::os::raw::c_char;
use std::os::raw::{c_char, c_void};
use crate::ts;
use tvm_common::{ffi, TVMArgValue};
/// A struct holding TVM byte-array.
///
......@@ -19,11 +19,11 @@ use crate::ts;
/// ```
#[derive(Debug, Clone)]
pub struct TVMByteArray {
pub(crate) inner: ts::TVMByteArray,
pub(crate) inner: ffi::TVMByteArray,
}
impl TVMByteArray {
pub(crate) fn new(barr: ts::TVMByteArray) -> TVMByteArray {
pub(crate) fn new(barr: ffi::TVMByteArray) -> TVMByteArray {
TVMByteArray { inner: barr }
}
......@@ -46,7 +46,7 @@ impl TVMByteArray {
impl<'a> From<&'a Vec<u8>> for TVMByteArray {
fn from(arg: &Vec<u8>) -> Self {
let barr = ts::TVMByteArray {
let barr = ffi::TVMByteArray {
data: arg.as_ptr() as *const c_char,
size: arg.len(),
};
......@@ -54,6 +54,18 @@ impl<'a> From<&'a Vec<u8>> for TVMByteArray {
}
}
impl<'a> From<&TVMByteArray> for TVMArgValue<'a> {
fn from(arr: &TVMByteArray) -> Self {
Self {
value: ffi::TVMValue {
v_handle: &arr.inner as *const ffi::TVMByteArray as *const c_void as *mut c_void,
},
type_code: ffi::TVMTypeCode_kBytes as i64,
_lifetime: std::marker::PhantomData,
}
}
}
#[cfg(test)]
mod tests {
use super::*;
......
......@@ -18,12 +18,20 @@
//! ```
use std::{
convert::TryInto,
fmt::{self, Display, Formatter},
os::raw::c_void,
ptr,
};
use crate::{function, ts, Result};
use failure::Error;
use tvm_common::{
ffi::{self, TVMValue},
TVMArgValue,
};
use crate::function;
/// Device type can be from a supported device name. See the supported devices
/// in [TVM](https://github.com/dmlc/tvm).
......@@ -45,35 +53,35 @@ impl Default for TVMDeviceType {
}
}
impl From<TVMDeviceType> for ts::DLDeviceType {
impl From<TVMDeviceType> for ffi::DLDeviceType {
fn from(device_type: TVMDeviceType) -> Self {
match device_type.0 {
1 => ts::DLDeviceType_kDLCPU,
2 => ts::DLDeviceType_kDLGPU,
3 => ts::DLDeviceType_kDLCPUPinned,
4 => ts::DLDeviceType_kDLOpenCL,
7 => ts::DLDeviceType_kDLVulkan,
8 => ts::DLDeviceType_kDLMetal,
9 => ts::DLDeviceType_kDLVPI,
10 => ts::DLDeviceType_kDLROCM,
12 => ts::DLDeviceType_kDLExtDev,
1 => ffi::DLDeviceType_kDLCPU,
2 => ffi::DLDeviceType_kDLGPU,
3 => ffi::DLDeviceType_kDLCPUPinned,
4 => ffi::DLDeviceType_kDLOpenCL,
7 => ffi::DLDeviceType_kDLVulkan,
8 => ffi::DLDeviceType_kDLMetal,
9 => ffi::DLDeviceType_kDLVPI,
10 => ffi::DLDeviceType_kDLROCM,
12 => ffi::DLDeviceType_kDLExtDev,
_ => panic!("device type not found!"),
}
}
}
impl From<ts::DLDeviceType> for TVMDeviceType {
fn from(device_type: ts::DLDeviceType) -> Self {
impl From<ffi::DLDeviceType> for TVMDeviceType {
fn from(device_type: ffi::DLDeviceType) -> Self {
match device_type {
ts::DLDeviceType_kDLCPU => TVMDeviceType(1),
ts::DLDeviceType_kDLGPU => TVMDeviceType(2),
ts::DLDeviceType_kDLCPUPinned => TVMDeviceType(3),
ts::DLDeviceType_kDLOpenCL => TVMDeviceType(4),
ts::DLDeviceType_kDLVulkan => TVMDeviceType(7),
ts::DLDeviceType_kDLMetal => TVMDeviceType(8),
ts::DLDeviceType_kDLVPI => TVMDeviceType(9),
ts::DLDeviceType_kDLROCM => TVMDeviceType(10),
ts::DLDeviceType_kDLExtDev => TVMDeviceType(12),
ffi::DLDeviceType_kDLCPU => TVMDeviceType(1),
ffi::DLDeviceType_kDLGPU => TVMDeviceType(2),
ffi::DLDeviceType_kDLCPUPinned => TVMDeviceType(3),
ffi::DLDeviceType_kDLOpenCL => TVMDeviceType(4),
ffi::DLDeviceType_kDLVulkan => TVMDeviceType(7),
ffi::DLDeviceType_kDLMetal => TVMDeviceType(8),
ffi::DLDeviceType_kDLVPI => TVMDeviceType(9),
ffi::DLDeviceType_kDLROCM => TVMDeviceType(10),
ffi::DLDeviceType_kDLExtDev => TVMDeviceType(12),
_ => panic!("device type not found!"),
}
}
......@@ -117,6 +125,18 @@ impl<'a> From<&'a str> for TVMDeviceType {
}
}
impl<'a> From<&'a TVMDeviceType> for TVMArgValue<'a> {
fn from(dev_type: &'a TVMDeviceType) -> Self {
Self {
value: TVMValue {
v_int64: dev_type.0 as i64,
},
type_code: ffi::DLDataTypeCode_kDLInt as i64,
_lifetime: std::marker::PhantomData,
}
}
}
/// Represents the underlying device context. Default is cpu.
///
/// ## Examples
......@@ -138,15 +158,15 @@ pub struct TVMContext {
/// Supported device types
pub device_type: TVMDeviceType,
/// Device id
pub device_id: usize,
pub device_id: i32,
}
impl TVMContext {
/// Creates context from device type and id.
pub fn new(device_type: TVMDeviceType, device_id: usize) -> Self {
pub fn new(device_type: TVMDeviceType, device_id: i32) -> Self {
TVMContext {
device_type: device_type,
device_id: device_id,
device_type,
device_id,
}
}
}
......@@ -155,7 +175,7 @@ macro_rules! impl_ctxs {
($(($ctx:ident, $dldevt:expr));+) => {
$(
impl TVMContext {
pub fn $ctx(device_id: usize) -> Self {
pub fn $ctx(device_id: i32) -> Self {
Self::new(TVMDeviceType($dldevt), device_id)
}
}
......@@ -185,20 +205,20 @@ impl<'a> From<&'a str> for TVMContext {
impl TVMContext {
/// Checks whether the context exists or not.
pub fn exist(&self) -> bool {
let func = function::Function::get("_GetDeviceAttr", true /* is_global */)
.expect("API function always exists");
let func = function::Function::get("_GetDeviceAttr").expect("API function always exists");
let dt = self.device_type.0 as usize;
// `unwrap` is ok here because if there is any error,
// if would occure inside `call_packed!`
let ret = call_packed!(func, &dt, &self.device_id, &0)
let ret: u64 = call_packed!(func, &dt, &self.device_id, &0)
.unwrap()
.prim_value;
.try_into()
.unwrap();
ret != 0
}
/// Synchronize the context stream.
pub fn sync(&self) -> Result<()> {
check_call!(ts::TVMSynchronize(
pub fn sync(&self) -> Result<(), Error> {
check_call!(ffi::TVMSynchronize(
self.device_type.0 as i32,
self.device_id as i32,
ptr::null_mut() as *mut c_void
......@@ -212,16 +232,17 @@ macro_rules! impl_device_attrs {
$(
impl TVMContext {
pub fn $attr_name(&self) -> usize {
let func = function::Function::get("_GetDeviceAttr", true /* is_global */)
let func = function::Function::get("_GetDeviceAttr")
.expect("API function always exists");
let dt = self.device_type.0 as usize;
// `unwrap` is ok here because if there is any error,
// if would occur in function call.
let ret = function::Builder::from(func)
.args(&[dt, self.device_id, $attr_kind])
function::Builder::from(func)
.args(&[dt, self.device_id as usize, $attr_kind])
.invoke()
.unwrap();
ret.prim_value as usize
.unwrap()
.try_into()
.unwrap()
}
}
)+
......@@ -237,18 +258,18 @@ impl_device_attrs!((max_threads_per_block, 1);
(multi_processor_count, 7);
(max_thread_dimensions, 8));
impl From<ts::DLContext> for TVMContext {
fn from(ctx: ts::DLContext) -> Self {
impl From<ffi::DLContext> for TVMContext {
fn from(ctx: ffi::DLContext) -> Self {
TVMContext {
device_type: TVMDeviceType::from(ctx.device_type),
device_id: ctx.device_id as usize,
device_id: ctx.device_id,
}
}
}
impl From<TVMContext> for ts::DLContext {
impl From<TVMContext> for ffi::DLContext {
fn from(ctx: TVMContext) -> Self {
ts::DLContext {
ffi::DLContext {
device_type: ctx.device_type.into(),
device_id: ctx.device_id as i32,
}
......
//! This module implements TVM custom [`Error`], [`ErrorKind`] and [`Result`] types.
pub use failure::Error;
use std::{ffi, option};
#[derive(Debug, Fail)]
#[fail(display = "Cannot convert from an empty array.")]
pub struct EmptyArrayError;
use crate::{common_errors, rust_ndarray};
error_chain! {
errors {
EmptyArray {
description("cannot convert from an empty array")
}
NullHandle(name: String) {
description("null handle")
display("requested `{}` handle is null", name)
}
FunctionNotFound {
description("function not found")
display("function was not set in `function::Builder`")
}
TypeMismatch(expected: String, found: String) {
description("type mismatch!")
display("expected type `{}`, but found `{}`", expected, found)
}
MissingShapeError {
description("ndarray `shape()` returns `None`")
display("called `Option::unwrap()` on a `None` value")
}
AtMostOneReturn {
description("TVM functions accept at most one return value")
}
#[derive(Debug, Fail)]
#[fail(display = "Handle `{}` is null.", name)]
pub struct NullHandleError {
pub name: String,
}
}
#[derive(Debug, Fail)]
#[fail(display = "Function was not set in `function::Builder`")]
pub struct FunctionNotFoundError;
foreign_links {
ShapeError(rust_ndarray::ShapeError);
NulError(ffi::NulError);
IntoStringError(ffi::IntoStringError);
CommonError(common_errors::Error);
}
#[derive(Debug, Fail)]
#[fail(display = "Expected type `{}` but found `{}`", expected, actual)]
pub struct TypeMismatchError {
pub expected: String,
pub actual: String,
}
impl From<option::NoneError> for Error {
fn from(_err: option::NoneError) -> Self {
ErrorKind::MissingShapeError.into()
}
}
#[derive(Debug, Fail)]
#[fail(display = "Missing NDArray shape.")]
pub struct MissingShapeError;
......@@ -15,14 +15,20 @@ use std::{
sync::Mutex,
};
use crate::{ts, ErrorKind, Module, Result, TVMArgValue, TVMRetValue, TVMTypeCode, TVMValue};
use failure::Error;
use crate::{
errors,
ffi::{self, TVMValue},
Module, TVMArgValue, TVMRetValue,
};
lazy_static! {
static ref GLOBAL_FUNCTIONS: Mutex<BTreeMap<&'static str, Option<Function>>> = {
let mut out_size = 0 as c_int;
let name = ptr::null_mut() as *mut c_char;
let mut out_array = name as *mut _;
check_call!(ts::TVMFuncListGlobalNames(
check_call!(ffi::TVMFuncListGlobalNames(
&mut out_size as *mut _,
&mut out_array
));
......@@ -37,17 +43,14 @@ lazy_static! {
}
/// Wrapper around TVM function handle which includes `is_global`
/// indicating whether the function is global or not, `is_released`
/// to hint dropping the function handle and `is_cloned` showing
/// indicating whether the function is global or not, and `is_cloned` showing
/// not to drop a cloned function from Rust side.
/// The value of these fields can be accessed through their respective methods.
#[derive(Debug, Hash)]
pub struct Function {
pub(crate) handle: ts::TVMFunctionHandle,
pub(crate) handle: ffi::TVMFunctionHandle,
// whether the registered function is global or not.
is_global: bool,
// whether the function has been dropped from frontend or not.
is_released: bool,
// whether the function has been cloned from frontend or not.
is_cloned: bool,
}
......@@ -56,29 +59,30 @@ unsafe impl Send for Function {}
unsafe impl Sync for Function {}
impl Function {
pub(crate) fn new(handle: ts::TVMFunctionHandle, is_global: bool, is_released: bool) -> Self {
pub(crate) fn new(handle: ffi::TVMFunctionHandle) -> Self {
Function {
handle: handle,
is_global: is_global,
is_released: is_released,
is_global: false,
is_cloned: false,
}
}
/// For a given function, it returns a function by name.
pub fn get<S: AsRef<str>>(name: S, is_global: bool) -> Option<&'static Function> {
pub fn get<S: AsRef<str>>(name: S) -> Option<&'static Function> {
let mut globals = GLOBAL_FUNCTIONS.lock().unwrap();
globals.get_mut(name.as_ref()).and_then(|maybe_func| {
if maybe_func.is_none() {
let name = CString::new(name.as_ref()).unwrap();
let mut handle = ptr::null_mut() as ts::TVMFunctionHandle;
check_call!(ts::TVMFuncGetGlobal(
let mut handle = ptr::null_mut() as ffi::TVMFunctionHandle;
check_call!(ffi::TVMFuncGetGlobal(
name.as_ptr() as *const c_char,
&mut handle as *mut _
));
maybe_func.replace(Function::new(
handle, is_global, false, /* is_released */
));
maybe_func.replace(Function {
handle: handle,
is_global: true,
is_cloned: false,
});
}
unsafe {
std::mem::transmute::<Option<&Function>, Option<&'static Function>>(
......@@ -89,7 +93,7 @@ impl Function {
}
/// Returns the underlying TVM function handle.
pub fn handle(&self) -> ts::TVMFunctionHandle {
pub fn handle(&self) -> ffi::TVMFunctionHandle {
self.handle
}
......@@ -98,12 +102,6 @@ impl Function {
self.is_global
}
/// Returns `true` if the underlying TVM function has been released
/// from the frontend and `false` otherwise.
pub fn is_released(&self) -> bool {
self.is_released
}
/// Returns `true` if the underlying TVM function has been cloned
/// from the frontend and `false` otherwise.
pub fn is_cloned(&self) -> bool {
......@@ -113,24 +111,18 @@ impl Function {
impl Clone for Function {
fn clone(&self) -> Function {
if !self.is_released && !self.is_cloned {
Self {
handle: self.handle,
is_global: self.is_global,
is_released: self.is_released,
is_cloned: true,
}
} else {
Function::new(self.handle, self.is_global, self.is_released)
}
}
}
impl Drop for Function {
fn drop(&mut self) {
if !self.is_released && !self.is_global && !self.is_cloned {
check_call!(ts::TVMFuncFree(self.handle));
self.is_released = true;
if !self.is_global && !self.is_cloned {
check_call!(ffi::TVMFuncFree(self.handle));
}
}
}
......@@ -138,17 +130,17 @@ impl Drop for Function {
/// Function builder in order to create and call functions.
///
/// *Note:* Currently TVM functions accept *at most* one return value.
#[derive(Debug, Clone, Default)]
#[derive(Default)]
pub struct Builder<'a, 'm> {
pub func: Option<&'m Function>,
pub arg_buf: Option<Box<[TVMArgValue<'a>]>>,
pub arg_buf: Vec<TVMArgValue<'a>>,
pub ret_buf: Option<TVMRetValue>,
}
impl<'a, 'm> Builder<'a, 'm> {
pub fn new(
func: Option<&'m Function>,
arg_buf: Option<Box<[TVMArgValue<'a>]>>,
arg_buf: Vec<TVMArgValue<'a>>,
ret_buf: Option<TVMRetValue>,
) -> Self {
Self {
......@@ -158,123 +150,66 @@ impl<'a, 'm> Builder<'a, 'm> {
}
}
pub fn get_function(&mut self, name: &'m str, is_global: bool) -> &mut Self {
self.func = Function::get(name, is_global);
pub fn get_function(&mut self, name: &'m str) -> &mut Self {
self.func = Function::get(name);
self
}
/// Pushes a [`TVMArgValue`] into the function argument buffer.
pub fn arg<'b, T: ?Sized>(&mut self, arg: &'b T) -> &mut Self
pub fn arg<T: 'a>(&mut self, arg: &'a T) -> &mut Self
where
TVMValue: From<&'b T>,
TVMTypeCode: From<&'b T>,
TVMArgValue<'a>: From<&'a T>,
{
let tvm_arg = TVMArgValue::from(arg);
if self.arg_buf.is_none() {
self.arg_buf = Some(Box::new([tvm_arg]));
} else {
let new_arg_buf = self.arg_buf.take().map(|bbuf| {
let mut new_arg_buf = Vec::from(bbuf);
new_arg_buf.push(tvm_arg);
let new_len = new_arg_buf.len();
new_arg_buf.truncate(new_len);
new_arg_buf.into_boxed_slice()
});
self.arg_buf = new_arg_buf;
}
self.arg_buf.push(arg.into());
self
}
/// Pushes multiple [`TVMArgValue`]s into the function argument buffer.
pub fn args<'b, T: 'b + ?Sized, I>(&mut self, args: I) -> &mut Self
pub fn args<T: 'a, I>(&mut self, args: I) -> &mut Self
where
I: IntoIterator<Item = &'b T>,
TVMValue: From<&'b T>,
TVMTypeCode: From<&'b T>,
I: IntoIterator<Item = &'a T>,
TVMArgValue<'a>: From<&'a T>,
{
for arg in args {
args.into_iter().for_each(|arg| {
self.arg(&arg);
}
});
self
}
/// Sets an output for a function that requirs a mutable output to be provided.
/// See the `basics` in tests for an example.
pub fn set_output<'b, T: 'b + ?Sized>(&mut self, arg: &'b mut T) -> Result<&mut Self>
pub fn set_output<T>(&mut self, ret: T) -> &mut Self
where
TVMValue: From<&'b T>,
TVMTypeCode: From<&'b T>,
TVMRetValue: From<T>,
{
if self.ret_buf.is_none() {
let tvm_ret =
unsafe { TVMRetValue::from_tvm_value(TVMValue::from(arg), TVMTypeCode::from(arg)) };
self.ret_buf = Some(tvm_ret);
} else {
bail!(ErrorKind::AtMostOneReturn)
}
Ok(self)
self.ret_buf = Some(ret.into());
self
}
/// Calls the function that created from `Builder`.
pub fn invoke(&mut self) -> Result<TVMRetValue> {
self.clone()(())
}
}
pub fn invoke(&mut self) -> Result<TVMRetValue, Error> {
#![allow(unused_unsafe)]
ensure!(self.func.is_some(), errors::FunctionNotFoundError);
impl<'a, 'm> FnOnce<((),)> for Builder<'a, 'm> {
type Output = Result<TVMRetValue>;
extern "rust-call" fn call_once(self, _: ((),)) -> Self::Output {
if self.func.is_none() {
bail!("{}", ErrorKind::FunctionNotFound);
}
let mut ret_val = unsafe { mem::uninitialized::<ts::TVMValue>() };
let mut ret_type_code = 0 as c_int;
if self.arg_buf.is_some() {
let arg_buf = self.arg_buf?;
let mut num_args = arg_buf.len();
let mut values = arg_buf
.iter()
.map(|tav| tav.value.inner)
.collect::<Vec<ts::TVMValue>>();
let mut tcodes = arg_buf
let num_args = self.arg_buf.len();
let (mut values, mut type_codes): (Vec<ffi::TVMValue>, Vec<ffi::TVMTypeCode>) = self
.arg_buf
.iter()
.map(|tav| tav.type_code as c_int)
.collect::<Vec<_>>();
if self.ret_buf.is_some() {
num_args = num_args + 1;
let ret_buf = self.ret_buf?;
let (ret_val, ret_type_code) = TVMRetValue::into_tvm_value(ret_buf);
values.append(&mut vec![ret_val.inner]);
tcodes.append(&mut vec![ret_type_code as c_int]);
}
.map(|tvm_arg| (tvm_arg.value, tvm_arg.type_code as ffi::TVMTypeCode))
.unzip();
values.truncate(num_args);
tcodes.truncate(num_args);
check_call!(ts::TVMFuncCall(
self.func?.handle,
let mut ret_val = unsafe { std::mem::uninitialized::<TVMValue>() };
let mut ret_type_code = 0;
check_call!(ffi::TVMFuncCall(
self.func.ok_or(errors::FunctionNotFoundError)?.handle,
values.as_mut_ptr(),
tcodes.as_mut_ptr(),
type_codes.as_mut_ptr() as *mut i32,
num_args as c_int,
&mut ret_val as *mut _,
&mut ret_type_code as *mut _
));
} else {
check_call!(ts::TVMFuncCall(
self.func?.handle,
ptr::null_mut(),
ptr::null_mut(),
0 as c_int,
&mut ret_val as *mut _,
&mut ret_type_code as *mut _
));
}
let ret = unsafe {
TVMRetValue::from_tvm_value(TVMValue::new(ret_val), (ret_type_code as i64).into())
};
Ok(ret)
Ok(unsafe { TVMRetValue::from_tvm_value(ret_val, ret_type_code as i64) })
}
}
......@@ -282,46 +217,44 @@ impl<'a, 'm> FnOnce<((),)> for Builder<'a, 'm> {
/// TVM functions.
impl<'a, 'm> From<&'m Function> for Builder<'a, 'm> {
fn from(func: &'m Function) -> Self {
Builder::new(Some(func), None, None)
Builder::new(Some(func), Vec::new(), None)
}
}
/// Converts a mutable reference of a [`Module`] to [`Builder`].
impl<'a, 'm> From<&'m mut Module> for Builder<'a, 'm> {
fn from(module: &'m mut Module) -> Self {
Builder::new(module.entry(), None, None)
Builder::new(module.entry(), Vec::new(), None)
}
}
unsafe extern "C" fn tvm_callback(
args: *mut ts::TVMValue,
args: *mut ffi::TVMValue,
type_codes: *mut c_int,
num_args: c_int,
ret: ts::TVMRetValueHandle,
ret: ffi::TVMRetValueHandle,
fhandle: *mut c_void,
) -> c_int {
// turning off the incorrect linter complaints
#![allow(unused_assignments)]
#![allow(unused_assignments, unused_unsafe)]
let len = num_args as usize;
let args_list = slice::from_raw_parts_mut(args, len);
let type_codes_list = slice::from_raw_parts_mut(type_codes, len);
let mut local_args: Vec<TVMArgValue> = Vec::new();
let mut value = mem::uninitialized::<ts::TVMValue>();
let mut value = mem::uninitialized::<ffi::TVMValue>();
let mut tcode = mem::uninitialized::<c_int>();
let rust_fn = mem::transmute::<*mut c_void, fn(&[TVMArgValue]) -> Result<TVMRetValue>>(fhandle);
let rust_fn =
mem::transmute::<*mut c_void, fn(&[TVMArgValue]) -> Result<TVMRetValue, Error>>(fhandle);
for i in 0..len {
value = args_list[i];
tcode = type_codes_list[i];
if tcode == ts::TVMTypeCode_kNodeHandle as c_int
|| tcode == ts::TVMTypeCode_kFuncHandle as c_int
|| tcode == ts::TVMTypeCode_kModuleHandle as c_int
if tcode == ffi::TVMTypeCode_kNodeHandle as c_int
|| tcode == ffi::TVMTypeCode_kFuncHandle as c_int
|| tcode == ffi::TVMTypeCode_kModuleHandle as c_int
{
check_call!(ts::TVMCbArgToReturn(&mut value as *mut _, tcode));
check_call!(ffi::TVMCbArgToReturn(&mut value as *mut _, tcode));
}
local_args.push(TVMArgValue::new(
TVMValue::new(value),
(tcode as i64).into(),
));
local_args.push(TVMArgValue::new(value.into(), (tcode as i64).into()));
}
let rv = match rust_fn(local_args.as_slice()) {
......@@ -332,10 +265,9 @@ unsafe extern "C" fn tvm_callback(
}
};
let (ret_val, ret_tcode) = TVMRetValue::into_tvm_value(rv);
let mut ret_val = ret_val.inner;
let (mut ret_val, ret_tcode) = rv.into_tvm_value();
let mut ret_type_code = ret_tcode as c_int;
check_call!(ts::TVMCFuncSetReturn(
check_call!(ffi::TVMCFuncSetReturn(
ret,
&mut ret_val as *mut _,
&mut ret_type_code as *mut _,
......@@ -345,24 +277,25 @@ unsafe extern "C" fn tvm_callback(
}
unsafe extern "C" fn tvm_callback_finalizer(fhandle: *mut c_void) {
let rust_fn = mem::transmute::<*mut c_void, fn(&[TVMArgValue]) -> Result<TVMRetValue>>(fhandle);
let rust_fn =
mem::transmute::<*mut c_void, fn(&[TVMArgValue]) -> Result<TVMRetValue, Error>>(fhandle);
mem::drop(rust_fn);
}
fn convert_to_tvm_func(f: fn(&[TVMArgValue]) -> Result<TVMRetValue>) -> Function {
let mut fhandle = ptr::null_mut() as ts::TVMFunctionHandle;
let resource_handle = f as *mut fn(&[TVMArgValue]) -> Result<TVMRetValue>;
check_call!(ts::TVMFuncCreateFromCFunc(
fn convert_to_tvm_func(f: fn(&[TVMArgValue]) -> Result<TVMRetValue, Error>) -> Function {
let mut fhandle = ptr::null_mut() as ffi::TVMFunctionHandle;
let resource_handle = f as *mut fn(&[TVMArgValue]) -> Result<TVMRetValue, Error>;
check_call!(ffi::TVMFuncCreateFromCFunc(
Some(tvm_callback),
resource_handle as *mut c_void,
Some(tvm_callback_finalizer),
&mut fhandle as *mut _
));
Function::new(fhandle, false, false)
Function::new(fhandle)
}
/// Registers a Rust function with signature
/// `fn(&[TVMArgValue]) -> Result<TVMRetValue>`
/// `fn(&[TVMArgValue]) -> Result<TVMRetValue, Error>`
/// as a **global TVM packed function** from frontend to TVM backend.
///
/// Use [`register_global_func`] if overriding an existing global TVM function
......@@ -373,7 +306,7 @@ fn convert_to_tvm_func(f: fn(&[TVMArgValue]) -> Result<TVMRetValue>) -> Function
/// ```
/// use std::convert::TryInto;
///
/// fn sum(args: &[TVMArgValue]) -> Result<TVMRetValue> {
/// fn sum(args: &[TVMArgValue]) -> Result<TVMRetValue, Error> {
/// let mut ret = 0i64;
/// for arg in args.iter() {
/// let arg: i64 = arg.try_into()?;
......@@ -391,18 +324,17 @@ fn convert_to_tvm_func(f: fn(&[TVMArgValue]) -> Result<TVMRetValue>) -> Function
/// assert_eq!(ret, 60);
/// ```
pub fn register<S: AsRef<str>>(
f: fn(&[TVMArgValue]) -> Result<TVMRetValue>,
f: fn(&[TVMArgValue]) -> Result<TVMRetValue, Error>,
name: S,
override_: bool,
) -> Result<()> {
) -> Result<(), Error> {
let func = convert_to_tvm_func(f);
let name = CString::new(name.as_ref())?;
check_call!(ts::TVMFuncRegisterGlobal(
name.as_ref().as_ptr() as *const c_char,
check_call!(ffi::TVMFuncRegisterGlobal(
name.into_raw(),
func.handle(),
override_ as c_int
));
mem::forget(name);
Ok(())
}
......@@ -416,7 +348,7 @@ pub fn register<S: AsRef<str>>(
/// use std::convert::TryInto;
///
/// register_global_func! {
/// fn sum(args: &[TVMArgValue]) -> Result<TVMRetValue> {
/// fn sum(args: &[TVMArgValue]) -> Result<TVMRetValue, Error> {
/// let mut ret = 0f64;
/// for arg in args.iter() {
/// let arg: f64 = arg.try_into()?;
......@@ -437,12 +369,12 @@ pub fn register<S: AsRef<str>>(
macro_rules! register_global_func {
{
$(#[$m:meta])*
fn $fn_name:ident($args:ident : &[TVMArgValue]) -> Result<TVMRetValue> {
fn $fn_name:ident($args:ident : &[TVMArgValue]) -> Result<TVMRetValue, Error> {
$($code:tt)*
}
} => {{
$(#[$m])*
fn $fn_name($args: &[TVMArgValue]) -> Result<TVMRetValue> {
fn $fn_name($args: &[TVMArgValue]) -> Result<TVMRetValue, Error> {
$($code)*
}
......@@ -496,17 +428,17 @@ mod tests {
#[test]
fn get_fn() {
assert!(Function::get(CANARY, true).is_some());
assert!(Function::get("does not exists!", false).is_none());
assert!(Function::get(CANARY).is_some());
assert!(Function::get("does not exists!").is_none());
}
#[test]
fn provide_args() {
let str_arg = CString::new("test").unwrap();
let mut func = Builder::default();
func.get_function("tvm.graph_runtime.remote_create", true)
func.get_function("tvm.graph_runtime.remote_create")
.args(&[10, 20])
.arg(&"test".to_owned());
assert!(func.arg_buf.is_some());
assert_eq!(func.arg_buf.take().map(|bv| Vec::from(bv).len()), Some(3));
.arg(&str_arg);
assert_eq!(func.arg_buf.len(), 3);
}
}
......@@ -11,32 +11,36 @@
//!
//! Checkout the `examples` repository for more details.
#![crate_name = "tvm_frontend"]
#![recursion_limit = "1024"]
#![allow(non_camel_case_types, unused_unsafe)]
#![feature(
try_from,
try_trait,
fn_traits,
unboxed_closures,
box_syntax,
option_replace
)]
#![feature(box_syntax)]
#[macro_use]
extern crate error_chain;
extern crate tvm_common as common;
extern crate failure;
#[macro_use]
extern crate lazy_static;
extern crate ndarray as rust_ndarray;
extern crate num_traits;
extern crate tvm_common;
use std::{
ffi::{CStr, CString},
str,
};
use crate::common::ffi::ts;
use failure::Error;
pub use crate::{
bytearray::TVMByteArray,
context::{TVMContext, TVMDeviceType},
errors::*,
function::Function,
module::Module,
ndarray::NDArray,
tvm_common::{
errors as common_errors,
ffi::{self, TVMType},
packed_func::{TVMArgValue, TVMRetValue},
},
};
// Macro to check the return call to TVM runtime shared library.
macro_rules! check_call {
......@@ -50,7 +54,7 @@ macro_rules! check_call {
/// Gets the last error message.
pub fn get_last_error() -> &'static str {
unsafe {
match CStr::from_ptr(ts::TVMGetLastError()).to_str() {
match CStr::from_ptr(ffi::TVMGetLastError()).to_str() {
Ok(s) => s,
Err(_) => "Invalid UTF-8 message",
}
......@@ -60,7 +64,7 @@ pub fn get_last_error() -> &'static str {
pub(crate) fn set_last_error(err: &Error) {
let c_string = CString::new(err.to_string()).unwrap();
unsafe {
ts::TVMAPISetLastError(c_string.as_ptr());
ffi::TVMAPISetLastError(c_string.as_ptr());
}
}
......@@ -71,27 +75,11 @@ pub mod context;
pub mod errors;
pub mod module;
pub mod ndarray;
pub mod ty;
pub mod value;
pub use crate::{
bytearray::TVMByteArray,
common::{
errors as common_errors,
ty::TVMTypeCode,
value::{TVMArgValue, TVMRetValue, TVMValue},
},
context::{TVMContext, TVMDeviceType},
errors::*,
function::Function,
module::Module,
ndarray::NDArray,
ty::TVMType,
};
/// Outputs the current TVM version.
pub fn version() -> &'static str {
match str::from_utf8(ts::TVM_VERSION) {
match str::from_utf8(ffi::TVM_VERSION) {
Ok(s) => s,
Err(_) => "Invalid UTF-8 string",
}
......@@ -108,8 +96,8 @@ mod tests {
#[test]
fn set_error() {
let err = ErrorKind::EmptyArray;
let err = errors::EmptyArrayError;
set_last_error(&err.into());
assert_eq!(get_last_error().trim(), ErrorKind::EmptyArray.to_string());
assert_eq!(get_last_error().trim(), errors::EmptyArrayError.to_string());
}
}
......@@ -8,30 +8,27 @@ use std::{
ptr,
};
use crate::ts;
use failure::Error;
use tvm_common::ffi;
use crate::{function::Function, ErrorKind, Result};
use crate::{errors, function::Function};
const ENTRY_FUNC: &'static str = "__tvm_main__";
/// Wrapper around TVM module handle which contains an entry function.
/// The entry function can be applied to an imported module through [`entry_func`].
/// Also [`is_released`] shows whether the module is dropped or not.
///
/// [`entry_func`]:struct.Module.html#method.entry_func
/// [`is_released`]:struct.Module.html#method.is_released
#[derive(Debug, Clone)]
pub struct Module {
pub(crate) handle: ts::TVMModuleHandle,
is_released: bool,
pub(crate) handle: ffi::TVMModuleHandle,
entry_func: Option<Function>,
}
impl Module {
pub(crate) fn new(handle: ts::TVMModuleHandle, is_released: bool) -> Self {
pub(crate) fn new(handle: ffi::TVMModuleHandle) -> Self {
Self {
handle,
is_released,
entry_func: None,
}
}
......@@ -44,62 +41,67 @@ impl Module {
}
/// Gets a function by name from a registered module.
pub fn get_function(&self, name: &str, query_import: bool) -> Result<Function> {
pub fn get_function(&self, name: &str, query_import: bool) -> Result<Function, Error> {
let name = CString::new(name)?;
let mut fhandle = ptr::null_mut() as ts::TVMFunctionHandle;
check_call!(ts::TVMModGetFunction(
let mut fhandle = ptr::null_mut() as ffi::TVMFunctionHandle;
check_call!(ffi::TVMModGetFunction(
self.handle,
name.as_ptr() as *const c_char,
query_import as c_int,
&mut fhandle as *mut _
));
if fhandle.is_null() {
bail!(ErrorKind::NullHandle(format!("{}", name.into_string()?)))
} else {
Ok(Function::new(fhandle, false, false))
ensure!(
!fhandle.is_null(),
errors::NullHandleError {
name: format!("{}", name.into_string()?)
}
);
Ok(Function::new(fhandle))
}
/// Imports a dependent module such as `.ptx` for gpu.
pub fn import_module(&self, dependent_module: Module) {
check_call!(ts::TVMModImport(self.handle, dependent_module.handle))
check_call!(ffi::TVMModImport(self.handle, dependent_module.handle))
}
/// Loads a module shared library from path.
pub fn load<P: AsRef<Path>>(path: &P) -> Result<Module> {
let ext = path.as_ref().extension()?.to_str()?;
let func = Function::get("module._LoadFromFile", true /* is_global */)
.expect("API function always exists");
let ret: Module = call_packed!(func, path.as_ref().to_str()?, ext)?.try_into()?;
pub fn load<P: AsRef<Path>>(path: &P) -> Result<Module, Error> {
let ext = CString::new(
path.as_ref()
.extension()
.unwrap_or(std::ffi::OsStr::new(""))
.to_str()
.ok_or_else(|| {
format_err!("Bad module load path: `{}`.", path.as_ref().display())
})?,
)?;
let func = Function::get("module._LoadFromFile").expect("API function always exists");
let cpath =
CString::new(path.as_ref().to_str().ok_or_else(|| {
format_err!("Bad module load path: `{}`.", path.as_ref().display())
})?)?;
let ret: Module = call_packed!(func, &cpath, &ext)?.try_into()?;
Ok(ret)
}
/// Checks if a target device is enabled for a module.
pub fn enabled(&self, target: &str) -> bool {
let func = Function::get("module._Enabled", true /* is_global */)
.expect("API function always exists");
let func = Function::get("module._Enabled").expect("API function always exists");
// `unwrap` is safe here because if there is any error during the
// function call, it would occur in `call_packed!`.
let ret: i64 = call_packed!(func, target).unwrap().try_into().unwrap();
let tgt = CString::new(target).unwrap();
let ret: i64 = call_packed!(func, &tgt).unwrap().try_into().unwrap();
ret != 0
}
/// Returns the underlying module handle.
pub fn handle(&self) -> ts::TVMModuleHandle {
pub fn handle(&self) -> ffi::TVMModuleHandle {
self.handle
}
/// Returns true if the underlying module has been dropped and false otherwise.
pub fn is_released(&self) -> bool {
self.is_released
}
}
impl Drop for Module {
fn drop(&mut self) {
if !self.is_released {
check_call!(ts::TVMModFree(self.handle));
self.is_released = true;
}
check_call!(ffi::TVMModFree(self.handle));
}
}
......@@ -23,34 +23,34 @@
//! [`copy_from_buffer`]:struct.NDArray.html#method.copy_from_buffer
//! [`copy_to_ctx`]:struct.NDArray.html#method.copy_to_ctx
use std::{convert::TryFrom, mem, os::raw::c_int, ptr, slice};
use std::{convert::TryFrom, mem, os::raw::c_int, ptr, slice, str::FromStr};
use crate::rust_ndarray::{Array, ArrayD};
use failure::Error;
use num_traits::Num;
use rust_ndarray::{Array, ArrayD};
use tvm_common::{ffi, TVMType};
use crate::ts;
use crate::{Error, ErrorKind, Result, TVMByteArray, TVMContext, TVMType};
use crate::{errors, TVMByteArray, TVMContext};
/// See the [`module-level documentation`](../ndarray/index.html) for more details.
///
/// Wrapper around TVM array handle.
#[derive(Debug)]
pub struct NDArray {
pub(crate) handle: ts::TVMArrayHandle,
pub(crate) handle: ffi::TVMArrayHandle,
is_view: bool,
}
impl NDArray {
pub(crate) fn new(handle: ts::TVMArrayHandle, is_view: bool) -> Self {
pub(crate) fn new(handle: ffi::TVMArrayHandle) -> Self {
NDArray {
handle: handle,
is_view: is_view,
is_view: true,
}
}
/// Returns the underlying array handle.
pub fn handle(&self) -> ts::TVMArrayHandle {
pub fn handle(&self) -> ffi::TVMArrayHandle {
self.handle
}
......@@ -99,12 +99,13 @@ impl NDArray {
}
/// Shows whether the underlying ndarray is contiguous in memory or not.
pub fn is_contiguous(&self) -> Result<bool> {
pub fn is_contiguous(&self) -> Result<bool, Error> {
Ok(match self.strides() {
None => true,
Some(strides) => {
// MissingShapeError in case shape is not determined
self.shape()?
// errors::MissingShapeError in case shape is not determined
self.shape()
.ok_or(errors::MissingShapeError)?
.iter()
.zip(strides)
.rfold(
......@@ -138,14 +139,16 @@ impl NDArray {
/// assert_eq!(ndarray.shape(), Some(shape));
/// assert_eq!(ndarray.to_vec::<i32>().unwrap(), data);
/// ```
pub fn to_vec<T>(&self) -> Result<Vec<T>> {
if self.shape().is_none() {
bail!("{}", ErrorKind::EmptyArray);
}
let earr = NDArray::empty(self.shape()?, TVMContext::cpu(0), self.dtype());
pub fn to_vec<T>(&self) -> Result<Vec<T>, Error> {
ensure!(self.shape().is_some(), errors::EmptyArrayError);
let earr = NDArray::empty(
self.shape().ok_or(errors::MissingShapeError)?,
TVMContext::cpu(0),
self.dtype(),
);
let target = self.copy_to_ndarray(earr)?;
let arr = unsafe { *(target.handle) };
let sz = self.size()? as usize;
let sz = self.size().ok_or(errors::MissingShapeError)?;
let mut v: Vec<T> = Vec::with_capacity(sz * mem::size_of::<T>());
unsafe {
v.as_mut_ptr()
......@@ -156,7 +159,7 @@ impl NDArray {
}
/// Converts the NDArray to [`TVMByteArray`].
pub fn to_bytearray(&self) -> Result<TVMByteArray> {
pub fn to_bytearray(&self) -> Result<TVMByteArray, Error> {
let v = self.to_vec::<u8>()?;
Ok(TVMByteArray::from(&v))
}
......@@ -176,7 +179,7 @@ impl NDArray {
/// *Note*: if something goes wrong during the copy, it will panic
/// from TVM side. See `TVMArrayCopyFromBytes` in `include/tvm/runtime/c_runtime_api.h`.
pub fn copy_from_buffer<T: Num32>(&mut self, data: &mut [T]) {
check_call!(ts::TVMArrayCopyFromBytes(
check_call!(ffi::TVMArrayCopyFromBytes(
self.handle,
data.as_ptr() as *mut _,
data.len() * mem::size_of::<T>()
......@@ -184,27 +187,31 @@ impl NDArray {
}
/// Copies the NDArray to another target NDArray.
pub fn copy_to_ndarray(&self, target: NDArray) -> Result<NDArray> {
pub fn copy_to_ndarray(&self, target: NDArray) -> Result<NDArray, Error> {
if self.dtype() != target.dtype() {
bail!(
"{}",
ErrorKind::TypeMismatch(
format!("{}", self.dtype().to_string()),
format!("{}", target.dtype().to_string()),
)
errors::TypeMismatchError {
expected: format!("{}", self.dtype().to_string()),
actual: format!("{}", target.dtype().to_string()),
}
);
}
check_call!(ts::TVMArrayCopyFromTo(
check_call!(ffi::TVMArrayCopyFromTo(
self.handle,
target.handle,
ptr::null_mut() as ts::TVMStreamHandle
ptr::null_mut() as ffi::TVMStreamHandle
));
Ok(target)
}
/// Copies the NDArray to a target context.
pub fn copy_to_ctx(&self, target: &TVMContext) -> Result<NDArray> {
let tmp = NDArray::empty(self.shape()?, target.clone(), self.dtype());
pub fn copy_to_ctx(&self, target: &TVMContext) -> Result<NDArray, Error> {
let tmp = NDArray::empty(
self.shape().ok_or(errors::MissingShapeError)?,
target.clone(),
self.dtype(),
);
let copy = self.copy_to_ndarray(tmp)?;
Ok(copy)
}
......@@ -214,28 +221,34 @@ impl NDArray {
rnd: &ArrayD<T>,
ctx: TVMContext,
dtype: TVMType,
) -> Result<Self> {
) -> Result<Self, Error> {
let mut shape = rnd.shape().to_vec();
let mut nd = NDArray::empty(&mut shape, ctx, dtype);
let mut buf = Array::from_iter(rnd.into_iter().map(|&v| v as T));
nd.copy_from_buffer(buf.as_slice_mut()?);
nd.copy_from_buffer(
buf.as_slice_mut()
.expect("Array from iter must be contiguous."),
);
Ok(nd)
}
/// Allocates and creates an empty NDArray given the shape, context and dtype.
pub fn empty(shape: &[usize], ctx: TVMContext, dtype: TVMType) -> NDArray {
let mut handle = ptr::null_mut() as ts::TVMArrayHandle;
check_call!(ts::TVMArrayAlloc(
let mut handle = ptr::null_mut() as ffi::TVMArrayHandle;
check_call!(ffi::TVMArrayAlloc(
shape.as_ptr() as *const i64,
shape.len() as c_int,
dtype.inner.code as c_int,
dtype.inner.bits as c_int,
dtype.inner.lanes as c_int,
dtype.code as c_int,
dtype.bits as c_int,
dtype.lanes as c_int,
ctx.device_type.0 as c_int,
ctx.device_id as c_int,
&mut handle as *mut _,
));
NDArray::new(handle, false)
NDArray {
handle,
is_view: false,
}
}
}
......@@ -243,23 +256,25 @@ macro_rules! impl_from_ndarray_rustndarray {
($type:ty, $type_name:tt) => {
impl<'a> TryFrom<&'a NDArray> for ArrayD<$type> {
type Error = Error;
fn try_from(nd: &NDArray) -> Result<ArrayD<$type>> {
if nd.shape().is_none() {
bail!("{}", ErrorKind::EmptyArray);
}
assert_eq!(nd.dtype(), TVMType::from($type_name), "Type mismatch");
Ok(Array::from_shape_vec(&*nd.shape()?, nd.to_vec::<$type>()?)?)
fn try_from(nd: &NDArray) -> Result<ArrayD<$type>, Self::Error> {
ensure!(nd.shape().is_some(), errors::MissingShapeError);
assert_eq!(nd.dtype(), TVMType::from_str($type_name)?, "Type mismatch");
Ok(Array::from_shape_vec(
&*nd.shape().ok_or(errors::MissingShapeError)?,
nd.to_vec::<$type>()?,
)?)
}
}
impl<'a> TryFrom<&'a mut NDArray> for ArrayD<$type> {
type Error = Error;
fn try_from(nd: &mut NDArray) -> Result<ArrayD<$type>> {
if nd.shape().is_none() {
bail!("{}", ErrorKind::EmptyArray);
}
assert_eq!(nd.dtype(), TVMType::from($type_name), "Type mismatch");
Ok(Array::from_shape_vec(&*nd.shape()?, nd.to_vec::<$type>()?)?)
fn try_from(nd: &mut NDArray) -> Result<ArrayD<$type>, Self::Error> {
ensure!(nd.shape().is_some(), errors::MissingShapeError);
assert_eq!(nd.dtype(), TVMType::from_str($type_name)?, "Type mismatch");
Ok(Array::from_shape_vec(
&*nd.shape().ok_or(errors::MissingShapeError)?,
nd.to_vec::<$type>()?,
)?)
}
}
};
......@@ -272,7 +287,7 @@ impl_from_ndarray_rustndarray!(f32, "float");
impl Drop for NDArray {
fn drop(&mut self) {
if !self.is_view {
check_call!(ts::TVMArrayFree(self.handle));
check_call!(ffi::TVMArrayFree(self.handle));
}
}
}
......@@ -306,7 +321,7 @@ mod tests {
fn basics() {
let shape = &mut [1, 2, 3];
let ctx = TVMContext::cpu(0);
let ndarray = NDArray::empty(shape, ctx, TVMType::from("int32"));
let ndarray = NDArray::empty(shape, ctx, TVMType::from_str("int32").unwrap());
assert_eq!(ndarray.shape().unwrap(), shape);
assert_eq!(
ndarray.size().unwrap(),
......@@ -322,7 +337,7 @@ mod tests {
let shape = &mut [4];
let mut data = vec![1i32, 2, 3, 4];
let ctx = TVMContext::cpu(0);
let mut ndarray = NDArray::empty(shape, ctx, TVMType::from("int32"));
let mut ndarray = NDArray::empty(shape, ctx, TVMType::from_str("int32").unwrap());
assert!(ndarray.to_vec::<i32>().is_ok());
ndarray.copy_from_buffer(&mut data);
assert_eq!(ndarray.shape().unwrap(), shape);
......@@ -331,7 +346,11 @@ mod tests {
assert!(ndarray.is_contiguous().is_ok());
assert_eq!(ndarray.byte_offset(), 0);
let mut shape = vec![4];
let e = NDArray::empty(&mut shape, TVMContext::cpu(0), TVMType::from("int32"));
let e = NDArray::empty(
&mut shape,
TVMContext::cpu(0),
TVMType::from_str("int32").unwrap(),
);
let nd = ndarray.copy_to_ndarray(e);
assert!(nd.is_ok());
assert_eq!(nd.unwrap().to_vec::<i32>().unwrap(), data);
......@@ -343,9 +362,13 @@ mod tests {
let mut shape = vec![4];
let mut data = vec![1f32, 2., 3., 4.];
let ctx = TVMContext::cpu(0);
let mut nd_float = NDArray::empty(&mut shape, ctx.clone(), TVMType::from("float32"));
let mut nd_float = NDArray::empty(
&mut shape,
ctx.clone(),
TVMType::from_str("float32").unwrap(),
);
nd_float.copy_from_buffer(&mut data);
let empty_int = NDArray::empty(&mut shape, ctx, TVMType::from("int32"));
let empty_int = NDArray::empty(&mut shape, ctx, TVMType::from_str("int32").unwrap());
nd_float.copy_to_ndarray(empty_int).unwrap();
}
......@@ -354,8 +377,12 @@ mod tests {
let a = Array::from_shape_vec((2, 2), vec![1f32, 2., 3., 4.])
.unwrap()
.into_dyn();
let nd =
NDArray::from_rust_ndarray(&a, TVMContext::cpu(0), TVMType::from("float32")).unwrap();
let nd = NDArray::from_rust_ndarray(
&a,
TVMContext::cpu(0),
TVMType::from_str("float32").unwrap(),
)
.unwrap();
assert_eq!(nd.shape().unwrap(), &mut [2, 2]);
let rnd: ArrayD<f32> = ArrayD::try_from(&nd).unwrap();
assert!(rnd.all_close(&a, 1e-8f32));
......
//! This module implements the required conversions from Rust types to TVM types.
//!
//! In TVM frontend only conversions from Rust's 32-bits (POD) numeric types (i32, u32, f32)
//! and 64-bits pointers are supported.
use std::{
fmt::{self, Display, Formatter},
ops::{Deref, DerefMut},
};
use crate::ts;
use crate::{Function, Module, NDArray, TVMByteArray, TVMContext, TVMDeviceType, TVMTypeCode};
macro_rules! impl_prim_type {
($type:ty, $variant:ident) => {
impl From<$type> for TVMTypeCode {
fn from(_arg: $type) -> Self {
TVMTypeCode::$variant
}
}
impl<'a> From<&'a $type> for TVMTypeCode {
fn from(_arg: &$type) -> Self {
TVMTypeCode::$variant
}
}
impl<'a> From<&'a mut $type> for TVMTypeCode {
fn from(_arg: &mut $type) -> Self {
TVMTypeCode::$variant
}
}
};
}
impl_prim_type!(TVMDeviceType, kDLInt);
impl_prim_type!(TVMContext, kTVMContext);
impl_prim_type!(TVMType, kTVMType);
impl_prim_type!(Function, kFuncHandle);
impl_prim_type!(Module, kModuleHandle);
impl_prim_type!(NDArray, kArrayHandle);
impl_prim_type!(TVMByteArray, kBytes);
/// See the [module-level documentation](../ty/index.html) for more details.
///
/// Wrapper around underlying TVMType
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub struct TVMType {
// inner fields are (code: u8, bits: u8, lanes: u16)
pub inner: ts::TVMType,
}
impl TVMType {
pub(crate) fn new(type_code: u8, bits: u8, lanes: u16) -> Self {
TVMType {
inner: ts::TVMType {
code: type_code,
bits: bits,
lanes: lanes,
},
}
}
}
/// Implements TVMType conversion from `&str` of general format `{dtype}{bits}x{lanes}`
/// such as "int32", "float32" or with lane "float32x1".
impl<'a> From<&'a str> for TVMType {
fn from(type_str: &'a str) -> Self {
if type_str == "bool" {
return TVMType::new(1, 1, 1);
}
let mut type_lanes = type_str.split("x");
let typ = type_lanes.next().expect("Missing dtype");
let lanes = type_lanes
.next()
.map(|l| u16::from_str_radix(l, 10).expect(&format!("Bad dtype lanes: {}", l)))
.unwrap_or(1);
let (type_name, bits) = match typ.find(char::is_numeric) {
Some(idx) => {
let (name, bits_str) = typ.split_at(idx);
(
name,
u8::from_str_radix(bits_str, 10)
.expect(&format!("Bad dtype bits: {}", bits_str)),
)
}
None => (typ, 32),
};
let type_code = match type_name {
"int" => 0,
"uint" => 1,
"float" => 2,
"handle" => 3,
_ => unimplemented!(),
};
TVMType::new(type_code, bits, lanes)
}
}
impl Display for TVMType {
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
let ts::TVMType { code, bits, lanes } = self.inner;
if bits == 1 && lanes == 1 {
return write!(f, "bool");
}
let mut tcode_str = match code {
0 => "int",
1 => "uint",
2 => "float",
4 => "handle",
_ => "Unknown",
}
.to_string();
tcode_str += &bits.to_string();
if lanes > 1 {
tcode_str += &format!("x{}", lanes.to_string());
}
f.write_str(&tcode_str)
}
}
impl From<TVMType> for ts::DLDataType {
fn from(dtype: TVMType) -> Self {
dtype.inner
}
}
impl From<ts::DLDataType> for TVMType {
fn from(dtype: ts::DLDataType) -> Self {
Self::new(dtype.code, dtype.bits, dtype.lanes)
}
}
impl Deref for TVMType {
type Target = ts::TVMType;
fn deref(&self) -> &Self::Target {
&self.inner
}
}
impl DerefMut for TVMType {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.inner
}
}
......@@ -2,139 +2,87 @@
//! and their conversions needed for the types used in frontend crate.
//! `TVMRetValue` is the owned version of `TVMPODValue`.
use std::{convert::TryFrom, mem, os::raw::c_void};
use std::{convert::TryFrom, os::raw::c_void};
use failure::Error;
use tvm_common::{
ensure_type,
ffi::{self, TVMValue},
};
use crate::{
common_errors::*, ts, Function, Module, NDArray, TVMArgValue, TVMByteArray, TVMContext,
TVMDeviceType, TVMRetValue, TVMType, TVMTypeCode, TVMValue,
common_errors::*, context::TVMContext, Function, Module, NDArray, TVMArgValue, TVMByteArray,
TVMRetValue,
};
macro_rules! impl_tvm_val_from_handle {
($($ty:ty),+) => {
$(
impl<'a> From<&'a $ty> for TVMValue {
($ty:ident, $type_code:expr, $handle:ty) => {
impl<'a> From<&'a $ty> for TVMArgValue<'a> {
fn from(arg: &$ty) -> Self {
let inner = ts::TVMValue {
TVMArgValue {
value: TVMValue {
v_handle: arg.handle as *mut _ as *mut c_void,
};
Self::new(inner)
}
},
type_code: $type_code as i64,
_lifetime: std::marker::PhantomData,
}
)+
}
}
impl_tvm_val_from_handle!(Module, Function, NDArray);
impl<'a> From<&'a TVMType> for TVMValue {
fn from(ty: &TVMType) -> Self {
let inner = ts::TVMValue { v_type: ty.inner };
Self::new(inner)
}
}
impl<'a> From<&'a TVMContext> for TVMValue {
fn from(ctx: &TVMContext) -> Self {
let inner = ts::TVMValue {
v_ctx: ctx.clone().into(),
};
Self::new(inner)
}
}
impl<'a> From<&'a TVMDeviceType> for TVMValue {
fn from(dev: &TVMDeviceType) -> Self {
let inner = ts::TVMValue {
v_int64: dev.0 as i64,
};
Self::new(inner)
}
}
impl<'a> From<&'a TVMByteArray> for TVMValue {
fn from(barr: &TVMByteArray) -> Self {
let inner = ts::TVMValue {
v_handle: &barr.inner as *const ts::TVMByteArray as *mut c_void,
};
Self::new(inner)
impl<'a> From<&'a mut $ty> for TVMArgValue<'a> {
fn from(arg: &mut $ty) -> Self {
TVMArgValue {
value: TVMValue {
v_handle: arg.handle as *mut _ as *mut c_void,
},
type_code: $type_code as i64,
_lifetime: std::marker::PhantomData,
}
}
impl<'a, 'b> TryFrom<&'b TVMArgValue<'a>> for NDArray {
type Error = Error;
fn try_from(arg: &TVMArgValue<'a>) -> Result<Self> {
if arg.type_code == TVMTypeCode::kArrayHandle {
let handle = unsafe { arg.value.inner.v_handle };
let arr_handle = unsafe { mem::transmute::<*mut c_void, ts::TVMArrayHandle>(handle) };
Ok(Self::new(arr_handle, true))
} else {
bail!(ErrorKind::TryFromTVMArgValueError(
stringify!(NDArray).to_string(),
arg.type_code.to_string()
))
}
}
}
impl<'a, 'b> TryFrom<&'b TVMArgValue<'a>> for Module {
impl<'a, 'v> TryFrom<&'a TVMArgValue<'v>> for $ty {
type Error = Error;
fn try_from(arg: &TVMArgValue<'a>) -> Result<Self> {
if arg.type_code == TVMTypeCode::kModuleHandle {
let handle = unsafe { arg.value.inner.v_handle };
Ok(Self::new(handle, false))
} else {
bail!(ErrorKind::TryFromTVMArgValueError(
stringify!(Module).to_string(),
arg.type_code.to_string()
))
fn try_from(arg: &TVMArgValue<'v>) -> Result<$ty, Self::Error> {
ensure_type!(arg, $type_code);
Ok($ty::new(unsafe { arg.value.v_handle as $handle }))
}
}
}
impl<'a, 'b> TryFrom<&'b TVMArgValue<'a>> for TVMByteArray {
type Error = Error;
fn try_from(arg: &TVMArgValue<'a>) -> Result<Self> {
if arg.type_code == TVMTypeCode::kBytes {
unsafe {
let barr_ptr =
mem::transmute::<*mut c_void, *mut ts::TVMByteArray>(arg.value.inner.v_handle);
Ok(Self::new(*barr_ptr))
impl From<$ty> for TVMRetValue {
fn from(val: $ty) -> TVMRetValue {
TVMRetValue {
value: TVMValue {
v_handle: val.handle() as *mut c_void,
},
box_value: box val,
type_code: $type_code as i64,
}
} else {
bail!(ErrorKind::TryFromTVMArgValueError(
stringify!(TVMByteArray).to_string(),
arg.type_code.to_string()
))
}
}
}
impl<'a, 'b> TryFrom<&'b TVMArgValue<'a>> for TVMType {
impl TryFrom<TVMRetValue> for $ty {
type Error = Error;
fn try_from(arg: &TVMArgValue<'a>) -> Result<Self> {
if arg.type_code == TVMTypeCode::kTVMType {
let ty = unsafe { arg.value.inner.v_type };
Ok(TVMType::from(ty))
} else {
bail!(ErrorKind::TryFromTVMArgValueError(
stringify!(TVMType).to_string(),
arg.type_code.to_string()
))
fn try_from(ret: TVMRetValue) -> Result<$ty, Self::Error> {
ensure_type!(ret, $type_code);
Ok($ty::new(unsafe { ret.value.v_handle as $handle }))
}
}
};
}
impl<'a, 'b> TryFrom<&'b TVMArgValue<'a>> for TVMContext {
type Error = Error;
fn try_from(arg: &TVMArgValue<'a>) -> Result<Self> {
if arg.type_code == TVMTypeCode::kTVMContext {
let ty = unsafe { arg.value.inner.v_ctx };
Ok(TVMContext::from(ty))
} else {
bail!(ErrorKind::TryFromTVMArgValueError(
stringify!(TVMContext).to_string(),
arg.type_code.to_string()
))
impl_tvm_val_from_handle!(
Function,
ffi::TVMTypeCode_kFuncHandle,
ffi::TVMFunctionHandle
);
impl_tvm_val_from_handle!(Module, ffi::TVMTypeCode_kModuleHandle, ffi::TVMModuleHandle);
impl_tvm_val_from_handle!(NDArray, ffi::TVMTypeCode_kArrayHandle, ffi::TVMArrayHandle);
impl<'a> From<&'a TVMByteArray> for TVMValue {
fn from(barr: &TVMByteArray) -> Self {
TVMValue {
v_handle: &barr.inner as *const ffi::TVMByteArray as *mut c_void,
}
}
}
......@@ -144,78 +92,43 @@ macro_rules! impl_boxed_ret_value {
impl From<$type> for TVMRetValue {
fn from(val: $type) -> Self {
TVMRetValue {
prim_value: 0,
value: TVMValue { v_int64: 0 },
box_value: box val,
type_code: $code,
type_code: $code as i64,
}
}
}
impl TryFrom<TVMRetValue> for $type {
type Error = Error;
fn try_from(ret: TVMRetValue) -> Result<$type> {
fn try_from(ret: TVMRetValue) -> Result<$type, Self::Error> {
if let Ok(val) = ret.box_value.downcast::<$type>() {
Ok(*val)
} else {
bail!(ErrorKind::TryFromTVMRetValueError(
stringify!($type).to_string(),
ret.type_code.to_string()
))
bail!(ValueDowncastError::new($code as i64, ret.type_code as i64))
}
}
}
};
}
impl_boxed_ret_value!(TVMType, TVMTypeCode::kTVMType);
impl_boxed_ret_value!(TVMContext, TVMTypeCode::kTVMContext);
impl_boxed_ret_value!(TVMByteArray, TVMTypeCode::kBytes);
impl TryFrom<TVMRetValue> for Module {
type Error = Error;
fn try_from(ret: TVMRetValue) -> Result<Module> {
if let Ok(handle) = ret.box_value.downcast::<ts::TVMModuleHandle>() {
Ok(Module::new(*handle, false))
} else {
bail!(ErrorKind::TryFromTVMRetValueError(
stringify!(TVMTypeCode::kModuleHandle).to_string(),
ret.type_code.to_string()
))
}
}
}
impl TryFrom<TVMRetValue> for Function {
type Error = Error;
fn try_from(ret: TVMRetValue) -> Result<Function> {
if let Ok(handle) = ret.box_value.downcast::<ts::TVMFunctionHandle>() {
Ok(Function::new(*handle, false, false))
} else {
bail!(ErrorKind::TryFromTVMRetValueError(
stringify!(TVMTypeCode::kFuncHandle).to_string(),
ret.type_code.to_string()
))
}
}
}
impl_boxed_ret_value!(TVMContext, ffi::TVMTypeCode_kTVMContext);
impl_boxed_ret_value!(TVMByteArray, ffi::TVMTypeCode_kBytes);
impl TryFrom<TVMRetValue> for NDArray {
impl<'a, 'v> TryFrom<&'a TVMArgValue<'v>> for TVMByteArray {
type Error = Error;
fn try_from(ret: TVMRetValue) -> Result<NDArray> {
if let Ok(handle) = ret.box_value.downcast::<ts::TVMArrayHandle>() {
Ok(NDArray::new(*handle, false))
} else {
bail!(ErrorKind::TryFromTVMRetValueError(
stringify!(TVMTypeCode::kArrayHandle).to_string(),
ret.type_code.to_string()
))
}
fn try_from(arg: &TVMArgValue<'v>) -> Result<Self, Self::Error> {
ensure_type!(arg, ffi::TVMTypeCode_kBytes);
Ok(TVMByteArray::new(unsafe {
*(arg.value.v_handle as *mut ffi::TVMByteArray)
}))
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::convert::TryInto;
use std::{convert::TryInto, str::FromStr};
use tvm_common::ffi::TVMType;
#[test]
fn bytearray() {
......@@ -227,7 +140,7 @@ mod tests {
#[test]
fn ty() {
let t = TVMType::from("int32");
let t = TVMType::from_str("int32").unwrap();
let tvm: TVMType = TVMRetValue::from(t).try_into().unwrap();
assert_eq!(tvm, t);
}
......
extern crate ndarray as rust_ndarray;
extern crate tvm_frontend as tvm;
use std::str::FromStr;
use tvm::*;
fn main() {
......@@ -12,7 +14,7 @@ fn main() {
} else {
(TVMContext::gpu(0), "gpu")
};
let dtype = TVMType::from("float32");
let dtype = TVMType::from_str("float32").unwrap();
let mut arr = NDArray::empty(shape, ctx, dtype);
arr.copy_from_buffer(data.as_mut_slice());
let mut ret = NDArray::empty(shape, ctx, dtype);
......@@ -26,8 +28,7 @@ fn main() {
function::Builder::from(&mut fadd)
.arg(&arr)
.arg(&arr)
.set_output(&mut ret)
.unwrap()
.arg(&mut ret)
.invoke()
.unwrap();
......
#![feature(extern_crate_item_prelude, try_from)]
#![allow(unused_imports)]
extern crate ndarray as rust_ndarray;
......@@ -6,17 +5,23 @@ extern crate ndarray as rust_ndarray;
extern crate tvm_frontend as tvm;
use rust_ndarray::ArrayD;
use std::convert::{TryFrom, TryInto};
use std::{
convert::{TryFrom, TryInto},
str::FromStr,
};
use tvm::*;
use tvm::{errors::Error, *};
fn main() {
register_global_func! {
fn sum(args: &[TVMArgValue]) -> Result<TVMRetValue> {
fn sum(args: &[TVMArgValue]) -> Result<TVMRetValue, Error> {
let mut ret = 0f32;
let shape = &mut [2];
for arg in args.iter() {
let e = NDArray::empty(shape, TVMContext::cpu(0), TVMType::from("float32"));
let e = NDArray::empty(
shape, TVMContext::cpu(0),
TVMType::from_str("float32").unwrap()
);
let arg: NDArray = arg.try_into()?;
let arr = arg.copy_to_ndarray(e)?;
let rnd: ArrayD<f32> = ArrayD::try_from(&arr)?;
......@@ -28,12 +33,16 @@ fn main() {
let shape = &mut [2];
let mut data = vec![3f32, 4.0];
let mut arr = NDArray::empty(shape, TVMContext::cpu(0), TVMType::from("float32"));
let mut arr = NDArray::empty(
shape,
TVMContext::cpu(0),
TVMType::from_str("float32").unwrap(),
);
arr.copy_from_buffer(data.as_mut_slice());
let mut registered = function::Builder::default();
let ret: f32 = registered
.get_function("sum", true)
.get_function("sum")
.arg(&arr)
.arg(&arr)
.invoke()
......
#![feature(extern_crate_item_prelude, panic_info_message)]
#![feature(panic_info_message)]
#![allow(unused_imports)]
use std::panic;
......@@ -6,20 +6,20 @@ use std::panic;
#[macro_use]
extern crate tvm_frontend as tvm;
use tvm::*;
use tvm::{errors::Error, *};
fn main() {
register_global_func! {
fn error(_args: &[TVMArgValue]) -> Result<TVMRetValue> {
Err(ErrorKind::TypeMismatch(
format!("{}", "i64".to_string()),
format!("{}", "f64".to_string()),
).into())
fn error(_args: &[TVMArgValue]) -> Result<TVMRetValue, Error> {
Err(errors::TypeMismatchError{
expected: "i64".to_string(),
actual: "f64".to_string(),
}.into())
}
}
let mut registered = function::Builder::default();
registered.get_function("error", true);
registered.get_function("error");
assert!(registered.func.is_some());
registered.args(&[10, 20]);
......
#![feature(extern_crate_item_prelude, try_from)]
#![allow(unused_imports)]
#[macro_use]
extern crate tvm_frontend as tvm;
use std::convert::TryInto;
use tvm::*;
use tvm::{errors::Error, *};
fn main() {
register_global_func! {
fn sum(args: &[TVMArgValue]) -> Result<TVMRetValue> {
fn sum(args: &[TVMArgValue]) -> Result<TVMRetValue, Error> {
let mut ret = 0.0;
for arg in args.iter() {
for arg in args.into_iter() {
let val: f64 = arg.try_into()?;
ret += val;
}
Ok(TVMRetValue::from(&ret))
Ok(TVMRetValue::from(ret))
}
}
let mut registered = function::Builder::default();
registered.get_function("sum", true);
registered.get_function("sum");
assert!(registered.func.is_some());
let ret: f64 = registered
.args(&[10.0f64, 20.0, 30.0])
......
#![feature(extern_crate_item_prelude, try_from)]
#![allow(unused_imports)]
extern crate tvm_frontend as tvm;
use std::convert::TryInto;
use tvm::*;
use tvm::{errors::Error, *};
fn main() {
fn sum(args: &[TVMArgValue]) -> Result<TVMRetValue> {
fn sum(args: &[TVMArgValue]) -> Result<TVMRetValue, Error> {
let mut ret = 0i64;
for arg in args.iter() {
let val: i64 = arg.try_into()?;
ret += val;
}
Ok(TVMRetValue::from(&ret))
Ok(TVMRetValue::from(ret))
}
tvm::function::register(sum, "mysum".to_owned(), false).unwrap();
let mut registered = function::Builder::default();
registered.get_function("mysum", true);
registered.get_function("mysum");
assert!(registered.func.is_some());
let ret: i64 = registered
.args(&[10, 20, 30])
......
#![feature(extern_crate_item_prelude, try_from)]
#![allow(unused_imports)]
#[macro_use]
extern crate tvm_frontend as tvm;
use std::convert::TryInto;
use tvm::*;
use tvm::{errors::Error, *};
// FIXME
fn main() {
register_global_func! {
fn concate_str(args: &[TVMArgValue]) -> Result<TVMRetValue> {
fn concate_str(args: &[TVMArgValue]) -> Result<TVMRetValue, Error> {
let mut ret = "".to_string();
for arg in args.iter() {
let val: String = arg.try_into()?;
ret += val.as_str();
let val: &str = arg.try_into()?;
ret += val;
}
Ok(TVMRetValue::from(ret))
}
}
let a = std::ffi::CString::new("a").unwrap();
let b = std::ffi::CString::new("b").unwrap();
let c = std::ffi::CString::new("c").unwrap();
let mut registered = function::Builder::default();
registered.get_function("concate_str", true);
registered.get_function("concate_str");
assert!(registered.func.is_some());
let a = "a".to_string();
let b = "b".to_string();
let c = "c".to_string();
let ret: String = registered
.args(&[a, b, c])
.arg(&a)
.arg(&b)
.arg(&c)
.invoke()
.unwrap()
.try_into()
......
......@@ -15,15 +15,15 @@ sgx = ["nom/alloc"]
[dependencies]
bounded-spsc-queue = "0.4.0"
error-chain = { version = "0.12.0", default-features = false }
failure = "0.1.5"
itertools = "0.7.8"
lazy_static = "1.1.0"
ndarray = "0.11.2"
ndarray="0.12.1"
nom = {version = "4.0.0", default-features = false }
serde = "1.0.59"
serde_derive = "1.0.79"
serde_json = "1.0.17"
tvm-common = { version = "0.1.0", path = "../common/", features = ["runtime"] }
tvm-common = { version = "0.1.0", path = "../common/" }
[target.'cfg(not(target_env = "sgx"))'.dependencies]
num_cpus = "1.8.0"
#[cfg(target_env = "sgx")]
use alloc::alloc::{self, Layout};
use alloc::alloc::{self, Layout, LayoutErr};
#[cfg(not(target_env = "sgx"))]
use std::alloc::{self, Layout};
use crate::errors::*;
use std::alloc::{self, Layout, LayoutErr};
const DEFAULT_ALIGN_BYTES: usize = 4;
......@@ -15,7 +13,7 @@ pub struct Allocation {
impl Allocation {
/// Allocates a chunk of memory of `size` bytes with optional alignment.
pub fn new(size: usize, align: Option<usize>) -> Result<Self> {
pub fn new(size: usize, align: Option<usize>) -> Result<Self, LayoutErr> {
let alignment = align.unwrap_or(DEFAULT_ALIGN_BYTES);
let layout = Layout::from_size_align(size, alignment)?;
let ptr = unsafe { alloc::alloc(layout.clone()) };
......
use std::{
any::TypeId,
convert::TryFrom,
mem,
ops::{Deref, DerefMut},
os::raw::{c_int, c_void},
ptr, slice,
};
use std::{convert::TryFrom, mem, os::raw::c_void, ptr, slice};
use failure::Error;
use ndarray;
use crate::{
allocator::Allocation,
errors::*,
ffi::runtime::{
use tvm_common::{
array::{DataType, TVMContext},
ffi::{
DLContext, DLDataType, DLDataTypeCode_kDLFloat, DLDataTypeCode_kDLInt,
DLDataTypeCode_kDLUInt, DLDeviceType_kDLCPU, DLTensor as _DLTensor,
DLDataTypeCode_kDLUInt, DLTensor,
},
};
use crate::allocator::Allocation;
/// A `Storage` is a container which holds `Tensor` data.
#[derive(PartialEq)]
pub enum Storage<'a> {
......@@ -29,7 +23,7 @@ pub enum Storage<'a> {
}
impl<'a> Storage<'a> {
pub fn new(size: usize, align: Option<usize>) -> Result<Storage<'static>> {
pub fn new(size: usize, align: Option<usize>) -> Result<Storage<'static>, Error> {
Ok(Storage::Owned(Allocation::new(size, align)?))
}
......@@ -237,6 +231,27 @@ impl<'a> Tensor<'a> {
byte_offset: 0,
}
}
pub(crate) fn as_dltensor(&self, flatten: bool) -> DLTensor {
assert!(!flatten || self.is_contiguous());
DLTensor {
data: unsafe { self.data.as_mut_ptr().offset(self.byte_offset) } as *mut c_void,
ctx: DLContext::from(&self.ctx),
ndim: if flatten { 1 } else { self.shape.len() } as i32,
dtype: DLDataType::from(&self.dtype),
shape: if flatten {
&self.size as *const _ as *mut i64
} else {
self.shape.as_ptr()
} as *mut i64,
strides: if flatten || self.is_contiguous() {
ptr::null_mut()
} else {
self.strides.as_ref().unwrap().as_ptr()
} as *mut i64,
byte_offset: 0,
}
}
}
/// Conversions to `ndarray::Array` from `Tensor`, if the types match.
......@@ -244,7 +259,7 @@ macro_rules! impl_ndarray_try_from_tensor {
($type:ty, $dtype:expr) => {
impl<'a, 't> TryFrom<&'a Tensor<'t>> for ndarray::ArrayD<$type> {
type Error = Error;
fn try_from(tensor: &'a Tensor) -> Result<ndarray::ArrayD<$type>> {
fn try_from(tensor: &'a Tensor) -> Result<ndarray::ArrayD<$type>, Error> {
ensure!(
tensor.dtype == $dtype,
"Cannot convert Tensor with dtype {:?} to ndarray",
......@@ -263,120 +278,9 @@ macro_rules! impl_ndarray_try_from_tensor {
};
}
impl_ndarray_try_from_tensor!(i32, DTYPE_INT32);
impl_ndarray_try_from_tensor!(u32, DTYPE_UINT32);
impl_ndarray_try_from_tensor!(f32, DTYPE_FLOAT32);
impl_ndarray_try_from_tensor!(f64, DTYPE_FLOAT64);
pub struct DLTensor {
pub(crate) inner: _DLTensor,
}
impl Deref for DLTensor {
type Target = _DLTensor;
fn deref(&self) -> &Self::Target {
&self.inner
}
}
impl DerefMut for DLTensor {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.inner
}
}
impl DLTensor {
pub(crate) fn new(raw: _DLTensor) -> Self {
Self { inner: raw }
}
pub(crate) fn from_tensor<'a>(tensor: &'a Tensor, flatten: bool) -> Self {
assert!(!flatten || tensor.is_contiguous());
Self {
inner: _DLTensor {
data: unsafe { tensor.data.as_mut_ptr().offset(tensor.byte_offset) } as *mut c_void,
ctx: DLContext::from(&tensor.ctx),
ndim: if flatten { 1 } else { tensor.shape.len() } as i32,
dtype: DLDataType::from(&tensor.dtype),
shape: if flatten {
&tensor.size as *const _ as *mut i64
} else {
tensor.shape.as_ptr()
} as *mut i64,
strides: if flatten || tensor.is_contiguous() {
ptr::null_mut()
} else {
tensor.strides.as_ref().unwrap().as_ptr()
} as *mut i64,
byte_offset: 0,
},
}
}
}
impl<'a, 't> From<&'a Tensor<'t>> for DLTensor {
fn from(tensor: &'a Tensor<'t>) -> Self {
DLTensor::from_tensor(tensor, false /* flatten */)
}
}
impl<'a, 't> From<&'a mut Tensor<'t>> for DLTensor {
fn from(tensor: &'a mut Tensor<'t>) -> Self {
DLTensor::from_tensor(tensor, false /* flatten */)
}
}
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct DataType {
pub(crate) code: usize,
pub(crate) bits: usize,
pub(crate) lanes: usize,
}
impl DataType {
/// Returns the number of bytes occupied by an element of this `DataType`.
pub fn itemsize(&self) -> usize {
(self.bits * self.lanes) >> 3
}
/// Returns whether this `DataType` represents primitive type `T`.
pub fn is_type<T: 'static>(&self) -> bool {
if self.lanes != 1 {
return false;
}
let typ = TypeId::of::<T>();
(typ == TypeId::of::<i32>() && self.code == 0 && self.bits == 32)
|| (typ == TypeId::of::<i64>() && self.code == 0 && self.bits == 64)
|| (typ == TypeId::of::<u32>() && self.code == 1 && self.bits == 32)
|| (typ == TypeId::of::<u64>() && self.code == 1 && self.bits == 64)
|| (typ == TypeId::of::<f32>() && self.code == 2 && self.bits == 32)
|| (typ == TypeId::of::<f64>() && self.code == 2 && self.bits == 64)
}
}
impl<'a> From<&'a DataType> for DLDataType {
fn from(dtype: &'a DataType) -> Self {
Self {
code: dtype.code as u8,
bits: dtype.bits as u8,
lanes: dtype.lanes as u16,
}
}
}
impl From<DLDataType> for DataType {
fn from(dtype: DLDataType) -> Self {
Self {
code: dtype.code as usize,
bits: dtype.bits as usize,
lanes: dtype.lanes as usize,
}
}
}
macro_rules! make_dtype_const {
($name: ident, $code: ident, $bits: expr, $lanes: expr) => {
const $name: DataType = DataType {
pub const $name: DataType = DataType {
code: $code as usize,
bits: $bits,
lanes: $lanes,
......@@ -389,28 +293,20 @@ make_dtype_const!(DTYPE_UINT32, DLDataTypeCode_kDLUInt, 32, 1);
// make_dtype_const!(DTYPE_FLOAT16, DLDataTypeCode_kDLFloat, 16, 1);
make_dtype_const!(DTYPE_FLOAT32, DLDataTypeCode_kDLFloat, 32, 1);
make_dtype_const!(DTYPE_FLOAT64, DLDataTypeCode_kDLFloat, 64, 1);
impl_ndarray_try_from_tensor!(i32, DTYPE_INT32);
impl_ndarray_try_from_tensor!(u32, DTYPE_UINT32);
impl_ndarray_try_from_tensor!(f32, DTYPE_FLOAT32);
impl_ndarray_try_from_tensor!(f64, DTYPE_FLOAT64);
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct TVMContext {
pub(crate) device_type: usize,
pub(crate) device_id: usize,
}
impl<'a> From<&'a TVMContext> for DLContext {
fn from(ctx: &'a TVMContext) -> Self {
Self {
device_type: ctx.device_type as u32,
device_id: ctx.device_id as i32,
}
impl<'a, 't> From<&'a Tensor<'t>> for DLTensor {
fn from(tensor: &'a Tensor<'t>) -> Self {
Tensor::as_dltensor(tensor, false /* flatten */)
}
}
impl Default for TVMContext {
fn default() -> Self {
Self {
device_type: DLDeviceType_kDLCPU as usize,
device_id: 0,
}
impl<'a, 't> From<&'a mut Tensor<'t>> for DLTensor {
fn from(tensor: &'a mut Tensor<'t>) -> Self {
Tensor::as_dltensor(tensor, false /* flatten */)
}
}
......@@ -463,42 +359,6 @@ macro_rules! impl_tensor_from_ndarray {
};
}
/// `From` conversions to `DLTensor` for `ndarray::Array`.
/// Takes a reference to the `ndarray` since `DLTensor` is not owned.
macro_rules! impl_dltensor_from_ndarray {
($type:ty, $typecode:expr) => {
impl<'a, D: ndarray::Dimension> From<&'a mut ndarray::Array<$type, D>> for DLTensor {
fn from(arr: &'a mut ndarray::Array<$type, D>) -> Self {
DLTensor {
inner: _DLTensor {
data: arr.as_mut_ptr() as *mut c_void,
ctx: DLContext {
device_type: DLDeviceType_kDLCPU,
device_id: 0,
},
ndim: arr.ndim() as c_int,
dtype: DLDataType {
code: $typecode as u8,
bits: 8 * mem::size_of::<$type>() as u8,
lanes: 1,
},
shape: arr.shape().as_ptr() as *const i64 as *mut i64,
strides: arr.strides().as_ptr() as *const isize as *mut i64,
byte_offset: 0,
},
}
}
}
};
}
impl_dltensor_from_ndarray!(f32, DLDataTypeCode_kDLFloat);
impl_dltensor_from_ndarray!(f64, DLDataTypeCode_kDLFloat);
impl_dltensor_from_ndarray!(i32, DLDataTypeCode_kDLInt);
impl_dltensor_from_ndarray!(i64, DLDataTypeCode_kDLInt);
impl_dltensor_from_ndarray!(u32, DLDataTypeCode_kDLUInt);
impl_dltensor_from_ndarray!(u64, DLDataTypeCode_kDLUInt);
impl_tensor_from_ndarray!(f32, DLDataTypeCode_kDLFloat);
impl_tensor_from_ndarray!(f64, DLDataTypeCode_kDLFloat);
impl_tensor_from_ndarray!(i32, DLDataTypeCode_kDLInt);
......
#[cfg(target_env = "sgx")]
use alloc::alloc;
#[cfg(not(target_env = "sgx"))]
use std::alloc;
use std::num;
use crate::common::errors as common_errors;
use ndarray;
use serde_json;
error_chain! {
errors {
GraphFormatError(msg: String) {
description("unable to load graph")
display("could not load graph json: {}", msg)
}
LoadGraphParamsError(msg: String) {
description("unable to load graph params")
display("could not load graph params: {}", msg)
}
}
foreign_links {
Alloc(alloc::AllocErr);
GraphDeserialize(serde_json::Error);
ParseInt(num::ParseIntError);
ShapeError(ndarray::ShapeError);
CommonError(common_errors::Error);
}
#[derive(Debug, Fail)]
pub enum GraphFormatError {
#[fail(display = "Could not parse graph json")]
Parse(#[fail(cause)] failure::Error),
#[fail(display = "Could not parse graph params")]
Params,
#[fail(display = "{} is missing attr: {}", 0, 1)]
MissingAttr(String, String),
#[fail(display = "Missing field: {}", 0)]
MissingField(&'static str),
#[fail(display = "Invalid DLType: {}", 0)]
InvalidDLType(String),
}
impl From<alloc::LayoutErr> for Error {
fn from(_err: alloc::LayoutErr) -> Error {
Error::from_kind(ErrorKind::Msg("Layout error".to_string()))
}
#[derive(Debug, Fail)]
#[fail(display = "SGX error: 0x{:x}", code)]
pub struct SgxError {
pub code: u32,
}
use std::{cmp, collections::HashMap, convert::TryFrom, iter::FromIterator, mem, str};
use failure::Error;
use nom::{alpha1, digit1, le_i32, le_i64, le_u16, le_u32, le_u64, le_u8, types::CompleteStr};
use serde;
use serde_json;
use super::{DLTensor, DataType, Module, Storage, TVMContext, Tensor};
use crate::{
common::value::TVMArgValue,
errors::{Error, ErrorKind, Result},
ffi::runtime::{DLDataTypeCode_kDLFloat, DLDataTypeCode_kDLInt, DLDataTypeCode_kDLUInt},
use tvm_common::{
array::{DataType, TVMContext},
ffi::{DLDataTypeCode_kDLFloat, DLDataTypeCode_kDLInt, DLDataTypeCode_kDLUInt, DLTensor},
TVMArgValue,
};
use crate::{errors::GraphFormatError, Module, Storage, Tensor};
// @see `kTVMNDArrayMagic` in `ndarray.h`
const _NDARRAY_MAGIC: u64 = 0xDD5E40F096B4A13F;
// @see `kTVMNDArrayListMagic` in `graph_runtime.h`
......@@ -41,28 +42,26 @@ pub struct Entry {
}
impl Graph {
fn entry_index(&self, entry: &Entry) -> Result<usize> {
fn entry_index(&self, entry: &Entry) -> Result<usize, GraphFormatError> {
self.node_row_ptr
.as_ref()
.map(|nrp| nrp[entry.id] + entry.index)
.ok_or("Missing node_row_ptr.".into())
.ok_or_else(|| GraphFormatError::MissingField("node_row_ptr"))
}
/// Attempt to deserialize a JSON attribute to a type `T`.
fn get_attr<T: serde::de::DeserializeOwned>(&self, attr: &str) -> Result<T> {
fn get_attr<T: serde::de::DeserializeOwned>(&self, attr: &str) -> Result<T, GraphFormatError> {
Ok(serde_json::from_value::<T>(
self.attrs
.as_ref()
.ok_or(ErrorKind::GraphFormatError(
"Missing graph attrs".to_string(),
))?
.ok_or(GraphFormatError::MissingField("attrs"))?
.get(attr)
.ok_or(ErrorKind::GraphFormatError(format!(
"Missing {} attr",
attr
)))?
.ok_or_else(|| {
GraphFormatError::MissingAttr("graph".to_string(), attr.to_string())
})?
.to_owned(),
)?)
)
.map_err(|err| GraphFormatError::Parse(err.into()))?)
}
}
......@@ -81,39 +80,31 @@ struct NodeAttrs {
flatten_data: bool,
}
macro_rules! get_node_attr {
($node:expr, $attrs:ident, $attr:literal) => {
$attrs
.get($attr)
.ok_or_else(|| GraphFormatError::MissingAttr($node.to_owned(), $attr.to_owned()))
};
}
impl Node {
fn parse_attrs(&self) -> Result<NodeAttrs> {
fn parse_attrs(&self) -> Result<NodeAttrs, Error> {
let attrs = self
.attrs
.as_ref()
.ok_or(format!("Missing node.attrs for `{}`", self.name))?;
let func_name = attrs
.get("func_name")
.ok_or(format!("Node `{}` is missing attrs.func_name", self.name))?
.to_string();
let num_outputs = attrs
.get("num_outputs")
.ok_or(format!("Node `{}` is missing attrs.num_outputs", self.name))?
.parse::<usize>()?;
let flatten_data = attrs
.get("flatten_data")
.ok_or(format!(
"Node `{}` is missing attrs.flatten_data",
self.name
))?
.parse::<u8>()?
== 1;
.ok_or_else(|| GraphFormatError::MissingAttr(self.name.clone(), "attrs".to_owned()))?;
Ok(NodeAttrs {
func_name,
num_outputs,
flatten_data,
func_name: get_node_attr!(self.name, attrs, "func_name")?.to_owned(),
num_outputs: get_node_attr!(self.name, attrs, "num_outputs")?.parse::<usize>()?,
flatten_data: get_node_attr!(self.name, attrs, "flatten_data")?.parse::<u8>()? == 1,
})
}
}
impl<'a> TryFrom<&'a String> for Graph {
type Error = Error;
fn try_from(graph_json: &String) -> Result<Self> {
fn try_from(graph_json: &String) -> Result<Self, self::Error> {
let graph = serde_json::from_str(graph_json)?;
Ok(graph)
}
......@@ -121,7 +112,7 @@ impl<'a> TryFrom<&'a String> for Graph {
impl<'a> TryFrom<&'a str> for Graph {
type Error = Error;
fn try_from(graph_json: &'a str) -> Result<Self> {
fn try_from(graph_json: &'a str) -> Result<Self, Self::Error> {
let graph = serde_json::from_str(graph_json)?;
Ok(graph)
}
......@@ -161,7 +152,7 @@ pub struct GraphExecutor<'m, 't> {
unsafe impl<'m, 't> Send for GraphExecutor<'m, 't> {}
impl<'m, 't> GraphExecutor<'m, 't> {
pub fn new<M: 'm + Module>(graph: Graph, lib: &'m M) -> Result<Self> {
pub fn new<M: 'm + Module>(graph: Graph, lib: &'m M) -> Result<Self, Error> {
let tensors = Self::setup_storages(&graph)?;
Ok(GraphExecutor {
op_execs: Self::setup_op_execs(&graph, lib, &tensors)?,
......@@ -178,7 +169,7 @@ impl<'m, 't> GraphExecutor<'m, 't> {
}
/// Allocates `Storages` for each `storage_id` and returns `Tensor`s to hold each output.
fn setup_storages<'a>(graph: &'a Graph) -> Result<Vec<Tensor<'t>>> {
fn setup_storages<'a>(graph: &'a Graph) -> Result<Vec<Tensor<'t>>, Error> {
let storage_ids = graph.get_attr::<(String, Vec<usize>)>("storage_id")?.1;
let shapes = graph.get_attr::<(String, Vec<Vec<i64>>)>("shape")?.1;
let dtypes = graph
......@@ -189,18 +180,15 @@ impl<'m, 't> GraphExecutor<'m, 't> {
if let Ok((_, dtype)) = tvm_str_to_type(CompleteStr(dltype)) {
Ok(dtype)
} else {
Err(ErrorKind::GraphFormatError(
format!("Invalid dltype: {}", dltype).to_string(),
)
.into())
Err(GraphFormatError::InvalidDLType(dltype.to_string()))
}
})
.collect::<Result<Vec<DataType>>>()?;
.collect::<Result<Vec<DataType>, GraphFormatError>>()?;
let align = dtypes.iter().map(|dtype| dtype.bits as usize).max();
let align = dtypes.iter().map(|dtype| dtype.bits() as usize).max();
let mut storage_num_bytes = vec![0usize; *storage_ids.iter().max().unwrap_or(&1) + 1];
for (i, &storage_id) in storage_ids.iter().enumerate() {
let dtype_size = dtypes[i].bits * dtypes[i].lanes >> 3;
let dtype_size = dtypes[i].bits() * dtypes[i].lanes() >> 3;
let nbytes = dtype_size * shapes[i].iter().product::<i64>() as usize;
storage_num_bytes[storage_id] = cmp::max(nbytes, storage_num_bytes[storage_id]);
}
......@@ -208,7 +196,7 @@ impl<'m, 't> GraphExecutor<'m, 't> {
let mut storages: Vec<Storage> = storage_num_bytes
.into_iter()
.map(|nbytes| Storage::new(nbytes, align))
.collect::<Result<Vec<Storage>>>()?;
.collect::<Result<Vec<Storage>, Error>>()?;
let tensors = izip!(storage_ids, shapes, dtypes)
.map(|(storage_id, shape, dtype)| {
......@@ -233,7 +221,7 @@ impl<'m, 't> GraphExecutor<'m, 't> {
graph: &Graph,
lib: &'m M,
tensors: &Vec<Tensor<'t>>,
) -> Result<Vec<Box<Fn() + 'm>>> {
) -> Result<Vec<Box<Fn() + 'm>>, Error> {
ensure!(graph.node_row_ptr.is_some(), "Missing node_row_ptr.");
let node_row_ptr = graph.node_row_ptr.as_ref().unwrap();
......@@ -251,9 +239,10 @@ impl<'m, 't> GraphExecutor<'m, 't> {
continue;
}
let func = lib
.get_function(&attrs.func_name)
.ok_or(format!("Missing function {}", attrs.func_name))?;
let func = lib.get_function(&attrs.func_name).ok_or(format_err!(
"Library is missing function {}",
attrs.func_name
))?;
let arg_indices = node
.inputs
.iter()
......@@ -264,19 +253,19 @@ impl<'m, 't> GraphExecutor<'m, 't> {
.map(|idx| {
let tensor = &tensors[idx?];
Ok(if attrs.flatten_data {
DLTensor::from_tensor(tensor, true /* flatten */)
Tensor::as_dltensor(tensor, true /* flatten */)
} else {
DLTensor::from(tensor)
})
})
.collect::<Result<Vec<DLTensor>>>()
.collect::<Result<Vec<DLTensor>, Error>>()
.unwrap();
let op: Box<Fn()> = box move || {
let args = dl_tensors
.iter()
.map(|t| t.into())
.collect::<Vec<TVMArgValue>>();
func(args.as_slice());
func(args.as_slice()).unwrap();
};
op_execs.push(op);
}
......@@ -344,7 +333,7 @@ impl<'m, 't> GraphExecutor<'m, 't> {
}
}
/// Converts a string to TVM DLDataTypeCode. @see `String2TVMType` in packed_func.h
// Converts a string to TVM DLDataTypeCode. @see `String2TVMType` in packed_func.h
named!(
tvm_str_to_type<CompleteStr, DataType>,
do_parse!(
......@@ -367,7 +356,7 @@ named!(
)
);
/// Converts a bytes to String.
// Converts a bytes to String.
named!(
name<String>,
map_res!(length_bytes!(le_u64), |b: &[u8]| String::from_utf8(
......@@ -375,7 +364,7 @@ named!(
))
);
/// Parses a TVMContext
// Parses a TVMContext
named!(
tvm_ctx<&[u8], TVMContext>,
do_parse!(
......@@ -385,7 +374,7 @@ named!(
)
);
/// Parses a DataType
// Parses a DataType
named!(
data_type<&[u8], DataType>,
do_parse!(
......@@ -396,7 +385,7 @@ named!(
)
);
/// Parses a Tensor from a TVM array file.
// Parses a Tensor from a TVM array file.
named!(
tensor<Tensor>,
do_parse!(
......@@ -420,7 +409,7 @@ named!(
)
);
/// Parses a graph params dict from a params binary file.
// Parses a graph params dict from a params binary file.
named!(
parse_param_dict<HashMap<String, Tensor>>,
do_parse!(
......@@ -433,17 +422,15 @@ named!(
);
/// Loads a param dict saved using `nnvm.compiler.save_param_dict`.
pub fn load_param_dict(bytes: &[u8]) -> Result<HashMap<String, Tensor>> {
pub fn load_param_dict(bytes: &[u8]) -> Result<HashMap<String, Tensor>, GraphFormatError> {
if let Ok((remaining_bytes, param_dict)) = parse_param_dict(bytes) {
if remaining_bytes.len() > 0 {
bail!(ErrorKind::LoadGraphParamsError("extra input".to_string()))
} else {
if remaining_bytes.len() == 0 {
Ok(param_dict)
} else {
Err(GraphFormatError::Params)
}
} else {
bail!(ErrorKind::LoadGraphParamsError(
"invalid parameters file".to_string()
))
Err(GraphFormatError::Params)
}
}
......
......@@ -14,7 +14,6 @@
allocator_api,
box_syntax,
fn_traits,
try_from,
unboxed_closures,
vec_remove_item
)]
......@@ -25,7 +24,7 @@ extern crate bounded_spsc_queue;
#[cfg(target_env = "sgx")]
extern crate core;
#[macro_use]
extern crate error_chain;
extern crate failure;
#[macro_use]
extern crate itertools;
#[macro_use]
......@@ -39,36 +38,45 @@ extern crate serde;
#[macro_use]
extern crate serde_derive;
extern crate serde_json;
extern crate tvm_common as common;
extern crate tvm_common;
mod allocator;
mod array;
pub mod errors;
mod module;
#[macro_use]
mod packed_func;
mod graph;
mod module;
#[cfg(target_env = "sgx")]
#[macro_use]
pub mod sgx;
mod threading;
mod workspace;
pub use crate::common::{errors::*, ffi, TVMArgValue, TVMRetValue};
pub use self::{
array::*, errors::*, graph::*, module::*, packed_func::*, threading::*, workspace::*,
pub use tvm_common::{
call_packed,
errors::*,
ffi::{self, DLTensor},
packed_func::{self, *},
TVMArgValue, TVMRetValue,
};
#[cfg(target_env = "sgx")]
use self::sgx::ocall_packed_func;
pub use self::{array::*, errors::*, graph::*, module::*, threading::*, workspace::*};
lazy_static! {
static ref LAST_ERROR: std::sync::RwLock<Option<&'static std::ffi::CStr>> =
std::sync::RwLock::new(None);
}
#[no_mangle]
pub extern "C" fn TVMAPISetLastError(cmsg: *const i8) {
#[cfg(not(target_env = "sgx"))]
unsafe {
panic!(std::ffi::CStr::from_ptr(cmsg).to_str().unwrap());
}
*LAST_ERROR.write().unwrap() = Some(unsafe { std::ffi::CStr::from_ptr(cmsg) });
#[cfg(target_env = "sgx")]
ocall_packed!("__sgx_set_last_error__", cmsg);
}
#[no_mangle]
pub extern "C" fn TVMGetLastError() -> *const std::os::raw::c_char {
match *LAST_ERROR.read().unwrap() {
Some(err) => err.as_ptr(),
None => std::ptr::null(),
}
}
......@@ -2,29 +2,29 @@ use std::{
collections::HashMap, convert::AsRef, ffi::CStr, os::raw::c_char, string::String, sync::Mutex,
};
use crate::{
ffi::runtime::BackendPackedCFunc,
packed_func::{wrap_backend_packed_func, PackedFunc},
use tvm_common::{
ffi::BackendPackedCFunc,
packed_func::{PackedFunc, TVMArgValue, TVMRetValue, TVMValue},
};
pub trait Module {
fn get_function<S: AsRef<str>>(&self, name: S) -> Option<PackedFunc>;
fn get_function<S: AsRef<str>>(&self, name: S) -> Option<&(dyn PackedFunc)>;
}
pub struct SystemLibModule;
lazy_static! {
static ref SYSTEM_LIB_FUNCTIONS: Mutex<HashMap<String, BackendPackedCFunc>> =
static ref SYSTEM_LIB_FUNCTIONS: Mutex<HashMap<String, &'static (dyn PackedFunc)>> =
Mutex::new(HashMap::new());
}
impl Module for SystemLibModule {
fn get_function<S: AsRef<str>>(&self, name: S) -> Option<PackedFunc> {
fn get_function<S: AsRef<str>>(&self, name: S) -> Option<&(dyn PackedFunc)> {
SYSTEM_LIB_FUNCTIONS
.lock()
.unwrap()
.get(name.as_ref())
.map(|func| wrap_backend_packed_func(func.to_owned()))
.map(|f| *f)
}
}
......@@ -34,15 +34,42 @@ impl Default for SystemLibModule {
}
}
// @see `WrapPackedFunc` in `llvm_module.cc`.
pub(super) fn wrap_backend_packed_func(
func_name: String,
func: BackendPackedCFunc,
) -> Box<dyn PackedFunc> {
box move |args: &[TVMArgValue]| {
let exit_code = func(
args.iter()
.map(|ref arg| arg.value)
.collect::<Vec<TVMValue>>()
.as_ptr(),
args.iter()
.map(|ref arg| arg.type_code as i32)
.collect::<Vec<i32>>()
.as_ptr() as *const i32,
args.len() as i32,
);
if exit_code == 0 {
Ok(TVMRetValue::default())
} else {
Err(tvm_common::errors::FuncCallError::get_with_context(
func_name.clone(),
))
}
}
}
#[no_mangle]
pub extern "C" fn TVMBackendRegisterSystemLibSymbol(
cname: *const c_char,
func: BackendPackedCFunc,
) -> i32 {
let name = unsafe { CStr::from_ptr(cname).to_str().unwrap() };
SYSTEM_LIB_FUNCTIONS
.lock()
.unwrap()
.insert(name.to_string(), func);
SYSTEM_LIB_FUNCTIONS.lock().unwrap().insert(
name.to_string(),
&*Box::leak(wrap_backend_packed_func(name.to_string(), func)),
);
return 0;
}
use std::{convert::TryFrom, marker::PhantomData, os::raw::c_void};
use super::Tensor;
use crate::ffi::runtime::{
BackendPackedCFunc, DLTensor as _DLTensor, TVMTypeCode_kArrayHandle,
TVMTypeCode_kNDArrayContainer, TVMValue as _TVMValue,
};
use super::DLTensor;
use crate::{
common::{TVMArgValue, TVMRetValue, TVMTypeCode, TVMValue},
errors::*,
};
pub type PackedFunc = Box<Fn(&[TVMArgValue]) -> TVMRetValue + Send + Sync>;
/// Calls a packed function and returns a `TVMRetValue`.
///
/// # Example
///
/// `call_packed!(my_tvm_func, &mut arg1, &mut arg2)`
#[macro_export]
macro_rules! call_packed {
($fn:expr, $($args:expr),+) => {
$fn(&[$($args.into(),)+])
};
($fn:expr) => {
$fn(&Vec::new())
};
}
impl<'a> From<&'a DLTensor> for TVMArgValue<'a> {
fn from(arr: &'a DLTensor) -> Self {
let raw = _TVMValue {
v_handle: arr as *const _ as *mut DLTensor as *mut c_void,
};
TVMArgValue {
value: TVMValue::new(raw),
type_code: TVMTypeCode::kArrayHandle,
lifetime: PhantomData,
}
}
}
impl<'a> From<&'a mut DLTensor> for TVMArgValue<'a> {
fn from(arr: &'a mut DLTensor) -> Self {
let raw = _TVMValue {
v_handle: arr as *mut _ as *mut c_void,
};
TVMArgValue {
value: TVMValue::new(raw),
type_code: TVMTypeCode::kArrayHandle,
lifetime: PhantomData,
}
}
}
impl<'a> TryFrom<TVMArgValue<'a>> for Tensor<'a> {
type Error = Error;
fn try_from(val: TVMArgValue<'a>) -> Result<Self> {
ensure!(
val.type_code == TVMTypeCode::kArrayHandle
|| val.type_code == TVMTypeCode::kNDArrayContainer,
"Could not downcast arg. Expected `{}` or `{}`, but got `{}`",
TVMTypeCode::kArrayHandle,
TVMTypeCode::kNDArrayContainer,
val.type_code,
);
let dlt = unsafe { *(val.value.v_handle as *mut _DLTensor as *const _DLTensor) };
Ok(DLTensor::new(dlt).into())
}
}
impl<'a, 't> From<&'t Tensor<'a>> for TVMRetValue {
fn from(val: &'t Tensor<'a>) -> Self {
TVMRetValue {
prim_value: 0,
box_value: box DLTensor::from(val),
type_code: TVMTypeCode::kNDArrayContainer,
}
}
}
impl<'a> TryFrom<TVMRetValue> for Tensor<'a> {
type Error = Error;
fn try_from(ret: TVMRetValue) -> Result<Self> {
ensure!(
ret.type_code == TVMTypeCode::kArrayHandle
|| ret.type_code == TVMTypeCode::kNDArrayContainer,
"Could not downcast arg. Expected `{}` or `{}`, but got `{}`",
TVMTypeCode_kArrayHandle,
TVMTypeCode_kNDArrayContainer,
ret.type_code,
);
let dlt = unsafe { *(ret.prim_value as *mut _DLTensor as *const _DLTensor) };
Ok(DLTensor::new(dlt).into())
}
}
// @see `WrapPackedFunc` in `llvm_module.cc`.
pub(crate) fn wrap_backend_packed_func(func: BackendPackedCFunc) -> PackedFunc {
box move |args: &[TVMArgValue]| {
func(
args.iter()
.map(|ref arg| arg.value.inner)
.collect::<Vec<_TVMValue>>()
.as_ptr(),
args.iter()
.map(|ref arg| arg.type_code as i32)
.collect::<Vec<i32>>()
.as_ptr() as *const i32,
args.len() as i32,
);
TVMRetValue::default()
}
}
......@@ -3,18 +3,17 @@ use std::{
os::raw::{c_char, c_int},
};
use errors::Result;
use ffi::runtime::TVMValue;
use runtime::{threading::sgx_join_threads, SystemLibModule, TVMArgValue, TVMRetValue};
pub use runtime::threading::tvm_run_worker as run_worker;
pub use crate::threading::tvm_run_worker as run_worker;
use crate::{threading::sgx_join_threads, SystemLibModule, TVMArgValue, TVMRetValue};
use errors::SgxError;
use ffi::TVMValue;
#[macro_export]
macro_rules! tvm_ocall {
($func: expr) => {
match $func {
0 => Ok(()),
err => Err(format!("SGX error: {}", err)),
code => Err(SgxError { code }),
}
};
}
......@@ -33,7 +32,10 @@ extern "C" {
) -> SgxStatus;
}
pub fn ocall_packed_func<S: AsRef<str>>(fn_name: S, args: &[TVMArgValue]) -> Result<TVMRetValue> {
pub fn ocall_packed_func<S: AsRef<str>>(
fn_name: S,
args: &[TVMArgValue],
) -> Result<TVMRetValue, SgxError> {
let mut ret_val = TVMValue { v_int64: 0 };
let ret_type_code = 0i64;
unsafe {
......@@ -58,11 +60,11 @@ pub fn ocall_packed_func<S: AsRef<str>>(fn_name: S, args: &[TVMArgValue]) -> Res
#[macro_export]
macro_rules! ocall_packed {
($fn_name:expr, $($args:expr),+) => {
ocall_packed_func($fn_name, &[$($args.into(),)+])
$crate::sgx::ocall_packed_func($fn_name, &[$($args.into(),)+])
.expect(concat!("Error calling `", $fn_name, "`"))
};
($fn_name:expr) => {
ocall_packed_func($fn_name, &Vec::new())
$crate::sgx::ocall_packed_func($fn_name, &Vec::new())
.expect(concat!("Error calling `", $fn_name, "`"))
}
}
......
use std::{
os::raw::{c_int, c_void},
sync::{
atomic::{AtomicUsize, Ordering, ATOMIC_USIZE_INIT},
atomic::{AtomicUsize, Ordering},
Arc, Barrier,
},
};
......@@ -18,11 +18,10 @@ use std::{
use std::{collections::VecDeque, ptr, sync::Mutex};
use bounded_spsc_queue::{self, Producer};
use crate::{errors::*, ffi::runtime::TVMParallelGroupEnv};
use tvm_common::ffi::TVMParallelGroupEnv;
#[cfg(target_env = "sgx")]
use super::{sgx::ocall_packed_func, TVMArgValue, TVMRetValue};
use super::{TVMArgValue, TVMRetValue};
type FTVMParallelLambda =
extern "C" fn(task_id: usize, penv: *const TVMParallelGroupEnv, cdata: *const c_void) -> i32;
......@@ -62,12 +61,11 @@ impl Job {
}
/// Waits for all tasks in this `Job` to be completed.
fn wait(&self) -> Result<()> {
fn wait(&self) {
while self.pending.load(Ordering::Acquire) > 0 {
#[cfg(not(target_env = "sgx"))]
thread::yield_now();
}
Ok(())
}
}
......@@ -161,7 +159,7 @@ impl ThreadPool {
}
tasks.pop().unwrap()();
job.wait().unwrap();
job.wait();
}
fn run_worker(queue: Consumer<Task>) {
......@@ -251,7 +249,7 @@ pub extern "C" fn TVMBackendParallelLaunch(
cb: cb,
cdata: cdata,
req_num_tasks: num_task,
pending: Arc::new(ATOMIC_USIZE_INIT),
pending: Arc::new(AtomicUsize::new(0)),
});
});
}
......@@ -273,7 +271,7 @@ pub(crate) fn sgx_join_threads() {
cb: poison_pill,
cdata: ptr::null(),
req_num_tasks: 0,
pending: Arc::new(ATOMIC_USIZE_INIT),
pending: Arc::new(AtomicUsize::new(0)),
});
});
ocall_packed!("__sgx_thread_group_join__", 0);
......@@ -322,8 +320,8 @@ mod tests {
#[test]
fn test_parallel_launch() {
TVMBackendParallelLaunch(flambda, ptr::null(), 6);
let counter = ATOMIC_USIZE_INIT;
let task_ids_sum = ATOMIC_USIZE_INIT;
let counter = AtomicUsize::new(0);
let task_ids_sum = AtomicUsize::new(0);
let cdata = (counter, task_ids_sum);
let num_tasks = 3;
TVMBackendParallelLaunch(flambda, &cdata as *const _ as *const c_void, num_tasks);
......
......@@ -4,8 +4,9 @@ use std::{
ptr,
};
use super::allocator::Allocation;
use crate::errors::*;
use failure::Error;
use crate::allocator::Allocation;
const WS_ALIGN: usize = 64; // taken from `kTempAllocaAlignment` in `device_api.h`
......@@ -24,13 +25,13 @@ impl WorkspacePool {
}
}
fn alloc_new(&mut self, size: usize) -> Result<*mut u8> {
fn alloc_new(&mut self, size: usize) -> Result<*mut u8, Error> {
self.workspaces.push(Allocation::new(size, Some(WS_ALIGN))?);
self.in_use.push(self.workspaces.len() - 1);
Ok(self.workspaces[self.workspaces.len() - 1].as_mut_ptr())
}
fn alloc(&mut self, size: usize) -> Result<*mut u8> {
fn alloc(&mut self, size: usize) -> Result<*mut u8, Error> {
if self.free.len() == 0 {
return self.alloc_new(size);
}
......@@ -60,7 +61,7 @@ impl WorkspacePool {
}
}
fn free(&mut self, ptr: *mut u8) -> Result<()> {
fn free(&mut self, ptr: *mut u8) -> Result<(), Error> {
let mut ws_idx = None;
for i in 0..self.in_use.len() {
let idx = self.in_use[i];
......@@ -72,7 +73,7 @@ impl WorkspacePool {
}
Ok(self
.free
.push(ws_idx.ok_or("Tried to free nonexistent workspace.")?))
.push(ws_idx.ok_or(format_err!("Tried to free nonexistent workspace."))?))
}
}
......
......@@ -5,7 +5,7 @@ license = "Apache-2.0"
authors = ["TVM Contributors"]
[dependencies]
ndarray = "0.11.2"
ndarray="0.12.1"
serde = "1.0.59"
serde_json = "1.0.17"
tvm-runtime = { path = "../../" }
......
......@@ -5,7 +5,7 @@ license = "Apache-2.0"
authors = ["TVM Contributors"]
[dependencies]
ndarray = "0.11.2"
ndarray="0.12.1"
tvm-runtime = { path = "../../" }
[build-dependencies]
......
......@@ -17,6 +17,6 @@ fn main() {
let mut a_dl: DLTensor = (&mut a).into();
let mut b_dl: DLTensor = (&mut b).into();
let mut c_dl: DLTensor = (&mut c).into();
call_packed!(add, &mut a_dl, &mut b_dl, &mut c_dl);
call_packed!(add, &mut a_dl, &mut b_dl, &mut c_dl).unwrap();
assert!(c.all_close(&e, 1e-8f32));
}
......@@ -14,11 +14,11 @@ cargo fmt -- --check
# test common
cd $RUST_DIR/common
cargo build --features runtime
cargo test --features runtime --tests
cargo build
cargo test --tests
cargo build --features frontend
cargo test --features frontend --tests
cargo build --features bindings
cargo test --features bindings --tests
# test runtime
cd $RUST_DIR/runtime
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment