/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ //! Provides the [`Module`] type and methods for working with runtime TVM modules. use std::{ convert::TryInto, ffi::CString, os::raw::{c_char, c_int}, path::Path, ptr, }; use failure::Error; use tvm_common::ffi; use crate::{errors, function::Function}; const ENTRY_FUNC: &'static str = "__tvm_main__"; /// Wrapper around TVM module handle which contains an entry function. /// The entry function can be applied to an imported module through [`entry_func`]. /// /// [`entry_func`]:struct.Module.html#method.entry_func #[derive(Debug, Clone)] pub struct Module { pub(crate) handle: ffi::TVMModuleHandle, entry_func: Option<Function>, } impl Module { pub(crate) fn new(handle: ffi::TVMModuleHandle) -> Self { Self { handle, entry_func: None, } } pub fn entry(&mut self) -> Option<&Function> { if self.entry_func.is_none() { self.entry_func = self.get_function(ENTRY_FUNC, false).ok(); } self.entry_func.as_ref() } /// Gets a function by name from a registered module. pub fn get_function(&self, name: &str, query_import: bool) -> Result<Function, Error> { let name = CString::new(name)?; let mut fhandle = ptr::null_mut() as ffi::TVMFunctionHandle; check_call!(ffi::TVMModGetFunction( self.handle, name.as_ptr() as *const c_char, query_import as c_int, &mut fhandle as *mut _ )); ensure!( !fhandle.is_null(), errors::NullHandleError { name: format!("{}", name.into_string()?) } ); Ok(Function::new(fhandle)) } /// Imports a dependent module such as `.ptx` for gpu. pub fn import_module(&self, dependent_module: Module) { check_call!(ffi::TVMModImport(self.handle, dependent_module.handle)) } /// Loads a module shared library from path. pub fn load<P: AsRef<Path>>(path: &P) -> Result<Module, Error> { let ext = CString::new( path.as_ref() .extension() .unwrap_or(std::ffi::OsStr::new("")) .to_str() .ok_or_else(|| { format_err!("Bad module load path: `{}`.", path.as_ref().display()) })?, )?; let func = Function::get("module._LoadFromFile").expect("API function always exists"); let cpath = CString::new(path.as_ref().to_str().ok_or_else(|| { format_err!("Bad module load path: `{}`.", path.as_ref().display()) })?)?; let ret: Module = call_packed!(func, cpath.as_c_str(), ext.as_c_str())?.try_into()?; Ok(ret) } /// Checks if a target device is enabled for a module. pub fn enabled(&self, target: &str) -> bool { let func = Function::get("module._Enabled").expect("API function always exists"); // `unwrap` is safe here because if there is any error during the // function call, it would occur in `call_packed!`. let tgt = CString::new(target).unwrap(); let ret: i64 = call_packed!(func, tgt.as_c_str()) .unwrap() .try_into() .unwrap(); ret != 0 } /// Returns the underlying module handle. pub fn handle(&self) -> ffi::TVMModuleHandle { self.handle } } impl Drop for Module { fn drop(&mut self) { check_call!(ffi::TVMModFree(self.handle)); } }