/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

//! Provides the [`Module`] type and methods for working with runtime TVM modules.

use std::{
    convert::TryInto,
    ffi::CString,
    os::raw::{c_char, c_int},
    path::Path,
    ptr,
};

use failure::Error;
use tvm_common::ffi;

use crate::{errors, function::Function};

/// Name of the function that [`Module::entry`] looks up in a module.
const ENTRY_FUNC: &str = "__tvm_main__";

/// Wrapper around TVM module handle which contains an entry function.
/// The entry function can be applied to an imported module through [`entry_func`].
///
/// [`entry_func`]:struct.Module.html#method.entry_func
#[derive(Debug, Clone)]
pub struct Module {
43
    pub(crate) handle: ffi::TVMModuleHandle,
44 45 46 47
    entry_func: Option<Function>,
}

impl Module {
48
    pub(crate) fn new(handle: ffi::TVMModuleHandle) -> Self {
49 50 51 52 53 54 55 56 57 58 59 60 61 62
        Self {
            handle,
            entry_func: None,
        }
    }

    pub fn entry(&mut self) -> Option<&Function> {
        if self.entry_func.is_none() {
            self.entry_func = self.get_function(ENTRY_FUNC, false).ok();
        }
        self.entry_func.as_ref()
    }

    /// Gets a function by name from a registered module.
63
    pub fn get_function(&self, name: &str, query_import: bool) -> Result<Function, Error> {
64
        let name = CString::new(name)?;
65 66
        let mut fhandle = ptr::null_mut() as ffi::TVMFunctionHandle;
        check_call!(ffi::TVMModGetFunction(
67 68 69 70 71
            self.handle,
            name.as_ptr() as *const c_char,
            query_import as c_int,
            &mut fhandle as *mut _
        ));
72 73 74 75 76 77 78
        ensure!(
            !fhandle.is_null(),
            errors::NullHandleError {
                name: format!("{}", name.into_string()?)
            }
        );
        Ok(Function::new(fhandle))
79 80 81 82
    }

    /// Imports a dependent module such as `.ptx` for gpu.
    pub fn import_module(&self, dependent_module: Module) {
83
        check_call!(ffi::TVMModImport(self.handle, dependent_module.handle))
84 85 86
    }

    /// Loads a module shared library from path.
87 88 89 90 91 92 93 94 95 96 97 98 99 100 101
    pub fn load<P: AsRef<Path>>(path: &P) -> Result<Module, Error> {
        let ext = CString::new(
            path.as_ref()
                .extension()
                .unwrap_or(std::ffi::OsStr::new(""))
                .to_str()
                .ok_or_else(|| {
                    format_err!("Bad module load path: `{}`.", path.as_ref().display())
                })?,
        )?;
        let func = Function::get("module._LoadFromFile").expect("API function always exists");
        let cpath =
            CString::new(path.as_ref().to_str().ok_or_else(|| {
                format_err!("Bad module load path: `{}`.", path.as_ref().display())
            })?)?;
102
        let ret: Module = call_packed!(func, cpath.as_c_str(), ext.as_c_str())?.try_into()?;
103 104 105 106 107
        Ok(ret)
    }

    /// Checks if a target device is enabled for a module.
    pub fn enabled(&self, target: &str) -> bool {
108
        let func = Function::get("module._Enabled").expect("API function always exists");
109 110
        // `unwrap` is safe here because if there is any error during the
        // function call, it would occur in `call_packed!`.
111
        let tgt = CString::new(target).unwrap();
112 113 114 115
        let ret: i64 = call_packed!(func, tgt.as_c_str())
            .unwrap()
            .try_into()
            .unwrap();
116 117 118 119
        ret != 0
    }

    /// Returns the underlying module handle.
120
    pub fn handle(&self) -> ffi::TVMModuleHandle {
121 122 123 124 125 126
        self.handle
    }
}

impl Drop for Module {
    fn drop(&mut self) {
127
        check_call!(ffi::TVMModFree(self.handle));
128 129
    }
}