//! Provides the [`Module`] type and methods for working with runtime TVM modules. use std::{ convert::TryInto, ffi::CString, os::raw::{c_char, c_int}, path::Path, ptr, }; use crate::ts; use crate::{function::Function, ErrorKind, Result}; const ENTRY_FUNC: &'static str = "__tvm_main__"; /// Wrapper around TVM module handle which contains an entry function. /// The entry function can be applied to an imported module through [`entry_func`]. /// Also [`is_released`] shows whether the module is dropped or not. /// /// [`entry_func`]:struct.Module.html#method.entry_func /// [`is_released`]:struct.Module.html#method.is_released #[derive(Debug, Clone)] pub struct Module { pub(crate) handle: ts::TVMModuleHandle, is_released: bool, entry_func: Option<Function>, } impl Module { pub(crate) fn new(handle: ts::TVMModuleHandle, is_released: bool) -> Self { Self { handle, is_released, entry_func: None, } } pub fn entry(&mut self) -> Option<&Function> { if self.entry_func.is_none() { self.entry_func = self.get_function(ENTRY_FUNC, false).ok(); } self.entry_func.as_ref() } /// Gets a function by name from a registered module. pub fn get_function(&self, name: &str, query_import: bool) -> Result<Function> { let name = CString::new(name)?; let mut fhandle = ptr::null_mut() as ts::TVMFunctionHandle; check_call!(ts::TVMModGetFunction( self.handle, name.as_ptr() as *const c_char, query_import as c_int, &mut fhandle as *mut _ )); if fhandle.is_null() { bail!(ErrorKind::NullHandle(format!("{}", name.into_string()?))) } else { Ok(Function::new(fhandle, false, false)) } } /// Imports a dependent module such as `.ptx` for gpu. pub fn import_module(&self, dependent_module: Module) { check_call!(ts::TVMModImport(self.handle, dependent_module.handle)) } /// Loads a module shared library from path. pub fn load<P: AsRef<Path>>(path: &P) -> Result<Module> { let ext = path.as_ref().extension()?.to_str()?; let func = Function::get("module._LoadFromFile", true /* is_global */) .expect("API function always exists"); let ret: Module = call_packed!(func, path.as_ref().to_str()?, ext)?.try_into()?; Ok(ret) } /// Checks if a target device is enabled for a module. pub fn enabled(&self, target: &str) -> bool { let func = Function::get("module._Enabled", true /* is_global */) .expect("API function always exists"); // `unwrap` is safe here because if there is any error during the // function call, it would occur in `call_packed!`. let ret: i64 = call_packed!(func, target).unwrap().try_into().unwrap(); ret != 0 } /// Returns the underlying module handle. pub fn handle(&self) -> ts::TVMModuleHandle { self.handle } /// Returns true if the underlying module has been dropped and false otherwise. pub fn is_released(&self) -> bool { self.is_released } } impl Drop for Module { fn drop(&mut self) { if !self.is_released { check_call!(ts::TVMModFree(self.handle)); self.is_released = true; } } }