function.rs 13.9 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36
//! This module provides an idiomatic Rust API for creating and working with TVM functions.
//!
//! To call an already registered TVM function, use [`function::Builder`].
//! To register a TVM packed function from the Rust side, use either
//! [`function::register`] or the macro [`register_global_func`].
//!
//! See the tests and examples repository for more examples.

use std::{
    collections::BTreeMap,
    ffi::{CStr, CString},
    mem,
    os::raw::{c_char, c_int, c_void},
    ptr, slice, str,
    sync::Mutex,
};

37 38 39 40 41 42 43
use failure::Error;

use crate::{
    errors,
    ffi::{self, TVMValue},
    Module, TVMArgValue, TVMRetValue,
};
44 45 46 47 48 49

lazy_static! {
    // Names of all global TVM functions known when this registry is first accessed,
    // each mapped to a lazily fetched handle (`None` until looked up by
    // `Function::get`).
    //
    // The names are copied into owned `String`s: `TVMFuncListGlobalNames` hands back
    // pointers into a buffer owned by the TVM runtime, and nothing here guarantees
    // that buffer outlives this process-wide map.
    static ref GLOBAL_FUNCTIONS: Mutex<BTreeMap<String, Option<Function>>> = {
        let mut out_size = 0 as c_int;
        let name = ptr::null_mut() as *mut c_char;
        let mut out_array = name as *mut _;
        check_call!(ffi::TVMFuncListGlobalNames(
            &mut out_size as *mut _,
            &mut out_array
        ));
        // `slice::from_raw_parts` requires a non-null pointer even for an empty slice.
        let names_list = if out_size <= 0 || out_array.is_null() {
            &[][..]
        } else {
            unsafe { slice::from_raw_parts(out_array, out_size as usize) }
        };
        Mutex::new(
            names_list
                .iter()
                .map(|&p| {
                    // Lossy conversion: a non-UTF-8 name must not panic (and poison)
                    // the whole registry during lazy initialization.
                    let name = unsafe { CStr::from_ptr(p) }.to_string_lossy().into_owned();
                    (name, None)
                })
                .collect(),
        )
    };
}

/// Wrapper around TVM function handle which includes `is_global`
65
/// indicating whether the function is global or not, and `is_cloned` showing
66 67 68 69
/// not to drop a cloned function from Rust side.
/// The value of these fields can be accessed through their respective methods.
#[derive(Debug, Hash)]
pub struct Function {
70
    pub(crate) handle: ffi::TVMFunctionHandle,
71 72 73 74 75 76 77 78 79 80
    // whether the registered function is global or not.
    is_global: bool,
    // whether the function has been cloned from frontend or not.
    is_cloned: bool,
}

// SAFETY: `Function` holds only an opaque TVM handle and two `bool` flags.
// `Send` is needed so `Function` values can live in the `Mutex` behind the
// process-wide `GLOBAL_FUNCTIONS` static, and `Sync` so the `&'static Function`
// references returned by `Function::get` can be shared between threads.
// NOTE(review): soundness depends on TVM allowing its function handles to be
// called and freed from any thread — confirm against the TVM runtime docs.
unsafe impl Send for Function {}
unsafe impl Sync for Function {}

impl Function {
81
    pub(crate) fn new(handle: ffi::TVMFunctionHandle) -> Self {
82 83
        Function {
            handle: handle,
84
            is_global: false,
85 86 87 88 89
            is_cloned: false,
        }
    }

    /// For a given function, it returns a function by name.
90
    pub fn get<S: AsRef<str>>(name: S) -> Option<&'static Function> {
91 92 93 94
        let mut globals = GLOBAL_FUNCTIONS.lock().unwrap();
        globals.get_mut(name.as_ref()).and_then(|maybe_func| {
            if maybe_func.is_none() {
                let name = CString::new(name.as_ref()).unwrap();
95 96
                let mut handle = ptr::null_mut() as ffi::TVMFunctionHandle;
                check_call!(ffi::TVMFuncGetGlobal(
97 98 99
                    name.as_ptr() as *const c_char,
                    &mut handle as *mut _
                ));
100 101 102 103 104
                maybe_func.replace(Function {
                    handle: handle,
                    is_global: true,
                    is_cloned: false,
                });
105 106 107 108 109 110 111 112 113 114
            }
            unsafe {
                std::mem::transmute::<Option<&Function>, Option<&'static Function>>(
                    maybe_func.as_ref(),
                )
            }
        })
    }

    /// Returns the underlying TVM function handle.
115
    pub fn handle(&self) -> ffi::TVMFunctionHandle {
116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132
        self.handle
    }

    /// Returns `true` if the underlying TVM function is global and `false` otherwise.
    pub fn is_global(&self) -> bool {
        self.is_global
    }

    /// Returns `true` if the underlying TVM function has been cloned
    /// from the frontend and `false` otherwise.
    pub fn is_cloned(&self) -> bool {
        self.is_cloned
    }
}

impl Clone for Function {
    /// Returns a copy that shares the same TVM handle and `is_global` flag but is
    /// marked as cloned, so dropping the copy never frees the handle.
    ///
    /// NOTE(review): the copy does not keep the handle alive. If the original is a
    /// non-global, non-cloned `Function`, dropping it frees the handle that this
    /// copy still holds.
    fn clone(&self) -> Function {
        Self {
            handle: self.handle,
            is_global: self.is_global,
            is_cloned: true,
        }
    }
}

impl Drop for Function {
    fn drop(&mut self) {
143 144
        if !self.is_global && !self.is_cloned {
            check_call!(ffi::TVMFuncFree(self.handle));
145 146 147 148 149 150 151
        }
    }
}

/// Function builder in order to create and call functions.
///
/// *Note:* Currently TVM functions accept *at most* one return value.
152
#[derive(Default)]
153 154
pub struct Builder<'a, 'm> {
    pub func: Option<&'m Function>,
155
    pub arg_buf: Vec<TVMArgValue<'a>>,
156 157 158 159 160 161
    pub ret_buf: Option<TVMRetValue>,
}

impl<'a, 'm> Builder<'a, 'm> {
    pub fn new(
        func: Option<&'m Function>,
162
        arg_buf: Vec<TVMArgValue<'a>>,
163 164 165 166 167 168 169 170 171
        ret_buf: Option<TVMRetValue>,
    ) -> Self {
        Self {
            func,
            arg_buf,
            ret_buf,
        }
    }

172 173
    pub fn get_function(&mut self, name: &'m str) -> &mut Self {
        self.func = Function::get(name);
174 175 176 177
        self
    }

    /// Pushes a [`TVMArgValue`] into the function argument buffer.
178
    pub fn arg<T: 'a>(&mut self, arg: T) -> &mut Self
179
    where
180
        TVMArgValue<'a>: From<T>,
181
    {
182
        self.arg_buf.push(arg.into());
183 184 185 186
        self
    }

    /// Pushes multiple [`TVMArgValue`]s into the function argument buffer.
187
    pub fn args<T: 'a, I>(&mut self, args: I) -> &mut Self
188
    where
189 190
        I: IntoIterator<Item = &'a T>,
        TVMArgValue<'a>: From<&'a T>,
191
    {
192
        args.into_iter().for_each(|arg| {
193
            self.arg(&arg);
194
        });
195 196 197 198 199
        self
    }

    /// Sets an output for a function that requirs a mutable output to be provided.
    /// See the `basics` in tests for an example.
200
    pub fn set_output<T>(&mut self, ret: T) -> &mut Self
201
    where
202
        TVMRetValue: From<T>,
203
    {
204 205
        self.ret_buf = Some(ret.into());
        self
206 207 208
    }

    /// Calls the function that created from `Builder`.
209 210 211 212 213
    pub fn invoke(&mut self) -> Result<TVMRetValue, Error> {
        #![allow(unused_unsafe)]
        ensure!(self.func.is_some(), errors::FunctionNotFoundError);

        let num_args = self.arg_buf.len();
214 215
        let (mut values, mut type_codes): (Vec<ffi::TVMValue>, Vec<ffi::TVMTypeCode>) =
            self.arg_buf.iter().map(|arg| arg.to_tvm_value()).unzip();
216 217

        let mut ret_val = unsafe { std::mem::uninitialized::<TVMValue>() };
218
        let mut ret_type_code = 0i32;
219 220 221 222 223 224 225 226
        check_call!(ffi::TVMFuncCall(
            self.func.ok_or(errors::FunctionNotFoundError)?.handle,
            values.as_mut_ptr(),
            type_codes.as_mut_ptr() as *mut i32,
            num_args as c_int,
            &mut ret_val as *mut _,
            &mut ret_type_code as *mut _
        ));
227

228
        Ok(unsafe { TVMRetValue::from_tvm_value(ret_val, ret_type_code as u32) })
229 230 231 232 233 234 235
    }
}

/// Converts a [`Function`] to builder. Currently, this is the best way to work with
/// TVM functions.
impl<'a, 'm> From<&'m Function> for Builder<'a, 'm> {
    fn from(func: &'m Function) -> Self {
236
        Builder::new(Some(func), Vec::new(), None)
237 238 239 240 241 242
    }
}

/// Converts a mutable reference of a [`Module`] to [`Builder`].
impl<'a, 'm> From<&'m mut Module> for Builder<'a, 'm> {
    fn from(module: &'m mut Module) -> Self {
243
        Builder::new(module.entry(), Vec::new(), None)
244 245 246 247
    }
}

unsafe extern "C" fn tvm_callback(
248
    args: *mut ffi::TVMValue,
249 250
    type_codes: *mut c_int,
    num_args: c_int,
251
    ret: ffi::TVMRetValueHandle,
252 253 254
    fhandle: *mut c_void,
) -> c_int {
    // turning off the incorrect linter complaints
255
    #![allow(unused_assignments, unused_unsafe)]
256 257 258 259
    let len = num_args as usize;
    let args_list = slice::from_raw_parts_mut(args, len);
    let type_codes_list = slice::from_raw_parts_mut(type_codes, len);
    let mut local_args: Vec<TVMArgValue> = Vec::new();
260
    let mut value = mem::uninitialized::<ffi::TVMValue>();
261
    let mut tcode = mem::uninitialized::<c_int>();
262 263
    let rust_fn =
        mem::transmute::<*mut c_void, fn(&[TVMArgValue]) -> Result<TVMRetValue, Error>>(fhandle);
264 265 266
    for i in 0..len {
        value = args_list[i];
        tcode = type_codes_list[i];
267 268 269
        if tcode == ffi::TVMTypeCode_kNodeHandle as c_int
            || tcode == ffi::TVMTypeCode_kFuncHandle as c_int
            || tcode == ffi::TVMTypeCode_kModuleHandle as c_int
270
        {
271
            check_call!(ffi::TVMCbArgToReturn(&mut value as *mut _, tcode));
272
        }
273
        local_args.push(TVMArgValue::from_tvm_value(value.into(), tcode as u32));
274 275 276 277 278 279 280 281 282 283
    }

    let rv = match rust_fn(local_args.as_slice()) {
        Ok(v) => v,
        Err(msg) => {
            crate::set_last_error(&msg);
            return -1;
        }
    };

284
    let (mut ret_val, ret_tcode) = rv.to_tvm_value();
285
    let mut ret_type_code = ret_tcode as c_int;
286
    check_call!(ffi::TVMCFuncSetReturn(
287 288 289 290 291 292 293 294 295
        ret,
        &mut ret_val as *mut _,
        &mut ret_type_code as *mut _,
        1 as c_int
    ));
    0
}

unsafe extern "C" fn tvm_callback_finalizer(fhandle: *mut c_void) {
296 297
    let rust_fn =
        mem::transmute::<*mut c_void, fn(&[TVMArgValue]) -> Result<TVMRetValue, Error>>(fhandle);
298 299 300
    mem::drop(rust_fn);
}

301 302 303 304
fn convert_to_tvm_func(f: fn(&[TVMArgValue]) -> Result<TVMRetValue, Error>) -> Function {
    let mut fhandle = ptr::null_mut() as ffi::TVMFunctionHandle;
    let resource_handle = f as *mut fn(&[TVMArgValue]) -> Result<TVMRetValue, Error>;
    check_call!(ffi::TVMFuncCreateFromCFunc(
305 306 307 308 309
        Some(tvm_callback),
        resource_handle as *mut c_void,
        Some(tvm_callback_finalizer),
        &mut fhandle as *mut _
    ));
310
    Function::new(fhandle)
311 312 313
}

/// Registers a Rust function with signature
314
/// `fn(&[TVMArgValue]) -> Result<TVMRetValue, Error>`
315 316 317 318 319 320 321 322 323 324
/// as a **global TVM packed function** from frontend to TVM backend.
///
/// Use [`register_global_func`] if overriding an existing global TVM function
/// is not required.
///
/// ## Example
///
/// ```
/// use std::convert::TryInto;
///
325
/// fn sum(args: &[TVMArgValue]) -> Result<TVMRetValue, Error> {
326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342
///     let mut ret = 0i64;
///     for arg in args.iter() {
///         let arg: i64 = arg.try_into()?;
///         ret += arg;
///     }
///     let ret_val = TVMRetValue::from(&ret);
///     Ok(ret_val)
/// }
///
/// tvm::function::register(sum, "mysum".to_owned(), false).unwrap();
/// let mut registered = function::Builder::default();
/// registered.get_function("mysum", true);
/// assert!(registered.func.is_some());
/// let ret: i64 = registered.args(&[10, 20, 30]).invoke().unwrap().try_into().unwrap();
/// assert_eq!(ret, 60);
/// ```
pub fn register<S: AsRef<str>>(
343
    f: fn(&[TVMArgValue]) -> Result<TVMRetValue, Error>,
344 345
    name: S,
    override_: bool,
346
) -> Result<(), Error> {
347 348
    let func = convert_to_tvm_func(f);
    let name = CString::new(name.as_ref())?;
349 350
    check_call!(ffi::TVMFuncRegisterGlobal(
        name.into_raw(),
351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366
        func.handle(),
        override_ as c_int
    ));
    Ok(())
}

/// Convenient macro for registering functions from frontend to backend as global
/// TVM packed functions without overriding. The function's identifier is used as
/// its global name. If overriding an existing function is needed, use the
/// [`function::register`] function instead.
///
/// ## Example
///
/// ```
/// use std::convert::TryInto;
///
/// register_global_func! {
///     fn sum(args: &[TVMArgValue]) -> Result<TVMRetValue, Error> {
///         let mut ret = 0f64;
///         for arg in args.iter() {
///             let arg: f64 = arg.try_into()?;
///             ret += arg;
///         }
///         let ret_val = TVMRetValue::from(&ret);
///         Ok(ret_val)
///     }
/// }
///
/// let mut registered = function::Builder::default();
/// registered.get_function("sum");
/// assert!(registered.func.is_some());
/// let ret: f64 = registered.args(&[10f64, 20f64, 30f64]).invoke().unwrap().try_into().unwrap();
/// assert_eq!(ret, 60f64);
/// ```
#[macro_export]
macro_rules! register_global_func {
    {
        $(#[$m:meta])*
        fn $fn_name:ident($args:ident : &[TVMArgValue]) -> Result<TVMRetValue, Error> {
            $($code:tt)*
        }
    } => {{
        $(#[$m])*
        fn $fn_name($args: &[TVMArgValue]) -> Result<TVMRetValue, Error> {
            $($code)*
        }

        $crate::function::register($fn_name, stringify!($fn_name).to_owned(), false).unwrap();
    }}
}

/// Convenient macro for calling TVM packed functions by providing a
/// function identifier and zero or more arguments. The macro evaluates to a
/// `Result`, letting the caller perform proper error handling.
///
/// **Note**: this macro does *not* expect an outside mutable output. To
/// set mutable output use [`set_output`] directly in the builder pattern.
///
/// [`set_output`]:function/struct.Builder.html#method.set_output
///
/// ## Example
///
/// Instead of
///
/// ```
/// function::Builder::from(func).arg(&a).arg(&b).invoke();
/// ```
///
/// one can use
///
/// ```
/// call_packed!(func, &a, &b);
/// ```
///
/// A function taking no arguments is called as `call_packed!(func)`; a trailing
/// comma after the last argument is also accepted.
#[macro_export]
macro_rules! call_packed {
    ($fn_name:expr $(, $arg:expr)* $(,)?) => {{
        let mut builder = $crate::function::Builder::from($fn_name);
        $(
            builder.arg($arg);
        )*
        builder.invoke()
    }}
}

#[cfg(test)]
mod tests {
    use super::*;

    // A global function name expected to be registered by the TVM runtime.
    static CANARY: &str = "module._LoadFromFile";

    #[test]
    fn list_global_func() {
        assert!(GLOBAL_FUNCTIONS.lock().unwrap().contains_key(CANARY));
    }

    #[test]
    fn get_fn() {
        assert!(Function::get(CANARY).is_some());
        assert!(Function::get("does not exists!").is_none());
    }

    #[test]
    fn provide_args() {
        let str_arg = CString::new("test").unwrap();
        let mut func = Builder::default();
        func.get_function("tvm.graph_runtime.remote_create")
            .arg(10)
            .arg(20)
            .arg(str_arg.as_c_str());
        assert_eq!(func.arg_buf.len(), 3);
    }

    #[test]
    fn invoke_without_function_fails() {
        assert!(Builder::default().invoke().is_err());
    }
}