//! packed_func.rs
//!
//! The `PackedFunc` type, the `call_packed!` macro, and conversions between
//! tensors and TVM argument/return values.
use std::{convert::TryFrom, marker::PhantomData, os::raw::c_void};

use super::Tensor;
use crate::ffi::runtime::{
    BackendPackedCFunc, DLTensor as _DLTensor, TVMTypeCode_kArrayHandle,
    TVMTypeCode_kNDArrayContainer, TVMValue as _TVMValue,
};

use super::DLTensor;
use crate::{
    common::{TVMArgValue, TVMRetValue, TVMTypeCode, TVMValue},
    errors::*,
};

/// A boxed `Send + Sync` callable that receives its arguments as a slice of
/// `TVMArgValue`s and produces a `TVMRetValue`.
pub type PackedFunc = Box<Fn(&[TVMArgValue]) -> TVMRetValue + Send + Sync>;

/// Invokes a packed function and evaluates to whatever that call returns
/// (a `TVMRetValue` when the callee is a `PackedFunc`).
///
/// Each argument expression is converted with `.into()` and the results are
/// passed to the function as one borrowed argument list. When only the
/// function is given, it is called with an empty argument list.
///
/// # Example
///
/// `call_packed!(my_tvm_func, &mut arg1, &mut arg2)`
#[macro_export]
macro_rules! call_packed {
    // No arguments: hand the callee an empty list.
    ($func:expr) => {
        $func(&Vec::new())
    };
    // One or more arguments: convert each one, then pass them as a slice.
    ($func:expr, $($arg:expr),+) => {
        $func(&[$($arg.into()),+])
    };
}

/// Wraps a shared `DLTensor` borrow as a `kArrayHandle` argument value.
///
/// The tensor's address is stored as an untyped handle, and the borrow's
/// lifetime `'a` is carried over to the resulting `TVMArgValue`.
impl<'a> From<&'a DLTensor> for TVMArgValue<'a> {
    fn from(arr: &'a DLTensor) -> Self {
        // The handle is stored as `*mut c_void`, so constness is cast away
        // here even though only a shared borrow is held.
        let handle = arr as *const DLTensor as *mut DLTensor as *mut c_void;
        TVMArgValue {
            value: TVMValue::new(_TVMValue { v_handle: handle }),
            type_code: TVMTypeCode::kArrayHandle,
            lifetime: PhantomData,
        }
    }
}

/// Wraps an exclusive `DLTensor` borrow as a `kArrayHandle` argument value.
///
/// The tensor's address is stored as an untyped handle, and the borrow's
/// lifetime `'a` is carried over to the resulting `TVMArgValue`.
impl<'a> From<&'a mut DLTensor> for TVMArgValue<'a> {
    fn from(arr: &'a mut DLTensor) -> Self {
        let handle = arr as *mut DLTensor as *mut c_void;
        TVMArgValue {
            value: TVMValue::new(_TVMValue { v_handle: handle }),
            type_code: TVMTypeCode::kArrayHandle,
            lifetime: PhantomData,
        }
    }
}

/// Attempts to turn an argument value into a `Tensor`.
///
/// # Errors
///
/// Returns an error when the argument's type code is neither `kArrayHandle`
/// nor `kNDArrayContainer`, or when the handle it carries is null.
impl<'a> TryFrom<TVMArgValue<'a>> for Tensor<'a> {
    type Error = Error;
    fn try_from(val: TVMArgValue<'a>) -> Result<Self> {
        ensure!(
            val.type_code == TVMTypeCode::kArrayHandle
                || val.type_code == TVMTypeCode::kNDArrayContainer,
            "Could not downcast arg. Expected `{}` or `{}`, but got `{}`",
            TVMTypeCode::kArrayHandle,
            TVMTypeCode::kNDArrayContainer,
            val.type_code,
        );

        // SAFETY: only the pointer value is read here; nothing is
        // dereferenced. The type-code check above is what indicates that the
        // value carries a handle.
        let handle = unsafe { val.value.v_handle } as *const _DLTensor;

        // Reject a null handle up front: dereferencing it below would be
        // undefined behaviour reachable from this safe conversion.
        ensure!(
            !handle.is_null(),
            "Could not downcast arg. The `DLTensor` handle is null"
        );

        // SAFETY: `handle` is non-null. It is assumed to point at a live
        // `_DLTensor`, as the type code claims — whoever built the
        // `TVMArgValue` must uphold that.
        let dlt = unsafe { *handle };
        Ok(DLTensor::new(dlt).into())
    }
}

impl<'a, 't> From<&'t Tensor<'a>> for TVMRetValue {
    fn from(val: &'t Tensor<'a>) -> Self {
        TVMRetValue {
            prim_value: 0,
            box_value: box DLTensor::from(val),
            type_code: TVMTypeCode::kNDArrayContainer,
        }
    }
}

/// Attempts to turn a return value into a `Tensor`.
///
/// The tensor is read through `prim_value`, which is interpreted as the
/// address of a `_DLTensor`.
///
/// # Errors
///
/// Returns an error when the type code is neither `kArrayHandle` nor
/// `kNDArrayContainer`, or when `prim_value` is a null address.
impl<'a> TryFrom<TVMRetValue> for Tensor<'a> {
    type Error = Error;
    fn try_from(ret: TVMRetValue) -> Result<Self> {
        ensure!(
            ret.type_code == TVMTypeCode::kArrayHandle
                || ret.type_code == TVMTypeCode::kNDArrayContainer,
            "Could not downcast arg. Expected `{}` or `{}`, but got `{}`",
            TVMTypeCode_kArrayHandle,
            TVMTypeCode_kNDArrayContainer,
            ret.type_code,
        );

        let handle = ret.prim_value as *const _DLTensor;

        // Reject a null address instead of dereferencing it.
        // NOTE(review): the `From<&Tensor>` impl in this file sets
        // `prim_value` to 0 and keeps the tensor in `box_value`, so values
        // built that way end up here — confirm whether `box_value` should be
        // consulted for them.
        ensure!(
            !handle.is_null(),
            "Could not downcast return value. The `DLTensor` pointer is null"
        );

        // SAFETY: `handle` is non-null. It is assumed to point at a live
        // `_DLTensor`, as the type code claims — whoever produced the
        // `TVMRetValue` must uphold that.
        let dlt = unsafe { *handle };
        Ok(DLTensor::new(dlt).into())
    }
}

// @see `WrapPackedFunc` in `llvm_module.cc`.
pub(crate) fn wrap_backend_packed_func(func: BackendPackedCFunc) -> PackedFunc {
    box move |args: &[TVMArgValue]| {
        func(
            args.iter()
                .map(|ref arg| arg.value.inner)
                .collect::<Vec<_TVMValue>>()
                .as_ptr(),
            args.iter()
                .map(|ref arg| arg.type_code as i32)
                .collect::<Vec<i32>>()
                .as_ptr() as *const i32,
            args.len() as i32,
        );
        TVMRetValue::default()
    }
}