Commit a479432d by Nick Hynes Committed by Tianqi Chen

[RUST] Rust DSO module (#2976)

parent 05f7fa9b
......@@ -20,6 +20,7 @@ members = [
"common",
"runtime",
"runtime/tests/test_tvm_basic",
"runtime/tests/test_tvm_dso",
"runtime/tests/test_nnvm",
"frontend",
"frontend/tests/basics",
......
......@@ -22,23 +22,30 @@ extern crate bindgen;
use std::path::PathBuf;
fn main() {
let tvm_home = option_env!("TVM_HOME").map(str::to_string).unwrap_or({
let tvm_home = PathBuf::from(env!("CARGO_MANIFEST_DIR"))
.canonicalize()
.unwrap();
tvm_home
.parent()
.unwrap()
.parent()
.unwrap()
.to_str()
.unwrap()
.to_string()
});
if cfg!(feature = "bindings") {
println!("cargo:rerun-if-env-changed=TVM_HOME");
println!("cargo:rustc-link-lib=dylib=tvm_runtime");
println!("cargo:rustc-link-search={}/build", env!("TVM_HOME"));
println!("cargo:rustc-link-search={}/build", tvm_home);
}
// @see rust-bindgen#550 for `blacklist_type`
bindgen::Builder::default()
.header(format!(
"{}/include/tvm/runtime/c_runtime_api.h",
env!("TVM_HOME")
))
.header(format!(
"{}/include/tvm/runtime/c_backend_api.h",
env!("TVM_HOME")
))
.clang_arg(format!("-I{}/3rdparty/dlpack/include/", env!("TVM_HOME")))
.header(format!("{}/include/tvm/runtime/c_runtime_api.h", tvm_home))
.header(format!("{}/include/tvm/runtime/c_backend_api.h", tvm_home))
.clang_arg(format!("-I{}/3rdparty/dlpack/include/", tvm_home))
.blacklist_type("max_align_t")
.layout_tests(false)
.derive_partialeq(true)
......
......@@ -45,3 +45,6 @@ tvm-common = { version = "0.1.0", path = "../common/" }
[target.'cfg(not(target_env = "sgx"))'.dependencies]
num_cpus = "1.8.0"
[target.'cfg(not(any(target_arch = "wasm32", target_env = "sgx")))'.dependencies]
libloading = "0.5"
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
use std::{
cell::RefCell,
collections::HashMap,
ffi::CStr,
os::raw::{c_char, c_int, c_void},
pin::Pin,
};
use tvm_common::{ffi::BackendPackedCFunc, packed_func::PackedFunc};
use crate::{
threading::{TVMBackendParallelBarrier, TVMBackendParallelLaunch},
workspace::{TVMBackendAllocWorkspace, TVMBackendFreeWorkspace},
TVMAPISetLastError,
};
use super::Module;
const TVM_MAIN: &'static [u8] = b"__tvm_main__";
const TVM_MODULE_CTX: &'static [u8] = b"__tvm_module_ctx";
/// A module backed by a Dynamic Shared Object (dylib).
pub struct DsoModule<'a> {
lib: libloading::Library,
packed_funcs: RefCell<HashMap<String, &'a (dyn PackedFunc)>>,
_pin: std::marker::PhantomPinned,
}
macro_rules! init_context_func {
($lib:ident, $( ($fn:ident, $sig:ty) ),+ $(,)?) => {
unsafe {
$(
let fn_ptr = $lib.get::<*mut $sig>(concat!("__", stringify!($fn)).as_bytes());
if let Ok(fn_ptr) = fn_ptr {
**fn_ptr = $fn;
}
)+
}
};
}
impl<'a> DsoModule<'a> {
pub fn new<P: AsRef<std::ffi::OsStr>>(filename: P) -> Result<Pin<Box<Self>>, failure::Error> {
let lib = libloading::Library::new(filename)?;
init_context_func!(
lib,
(TVMAPISetLastError, extern "C" fn(*const i8)),
(
TVMBackendAllocWorkspace,
extern "C" fn(c_int, c_int, u64, c_int, c_int) -> *mut c_void
),
(
TVMBackendFreeWorkspace,
extern "C" fn(c_int, c_int, *mut c_void) -> c_int
),
(
TVMBackendParallelLaunch,
extern "C" fn(crate::threading::FTVMParallelLambda, *const c_void, usize) -> c_int
),
(
TVMBackendParallelBarrier,
extern "C" fn(usize, *const tvm_common::ffi::TVMParallelGroupEnv)
),
);
// Pin the module in memory so that `ctx` pointer (below) is stable.
let dso_mod = Box::pin(Self {
lib,
packed_funcs: RefCell::new(HashMap::new()),
_pin: std::marker::PhantomPinned,
});
unsafe {
if let Ok(ctx) = dso_mod.lib.get::<*mut *const c_void>(TVM_MODULE_CTX) {
**ctx = &dso_mod as *const _ as *const c_void;
}
}
Ok(dso_mod)
}
}
impl<'a> Module for DsoModule<'a> {
fn get_function<S: AsRef<str>>(&self, name: S) -> Option<&(dyn PackedFunc)> {
let name = name.as_ref();
let func = match unsafe {
self.lib
.get::<BackendPackedCFunc>(if name.as_bytes() == TVM_MAIN {
// If __tvm_main__ is present, it contains the name of the
// actual main function.
match self
.lib
.get::<*const c_char>(TVM_MAIN)
.map(|p| CStr::from_ptr(*p))
{
Ok(m) => m.to_bytes(),
_ => return None,
}
} else {
name.as_bytes()
})
} {
Ok(func) => unsafe { func.into_raw() },
Err(_) => return None,
};
self.packed_funcs.borrow_mut().insert(
name.to_string(),
&*Box::leak(super::wrap_backend_packed_func(name.to_string(), *func)),
);
self.packed_funcs.borrow().get(name).map(|f| *f)
}
}
impl<'a> Drop for DsoModule<'a> {
fn drop(&mut self) {
self.packed_funcs
.replace(HashMap::new())
.into_iter()
.map(|(_name, f)| unsafe { Box::from_raw(f as *const _ as *mut (dyn PackedFunc)) })
.for_each(std::mem::drop);
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#[cfg(not(any(target_arch = "wasm32", target_env = "sgx")))]
mod dso;
mod syslib;
use tvm_common::{
ffi::BackendPackedCFunc,
packed_func::{PackedFunc, TVMArgValue, TVMRetValue, TVMValue},
};
#[cfg(not(any(target_arch = "wasm32", target_env = "sgx")))]
pub use dso::DsoModule;
pub use syslib::SystemLibModule;
pub trait Module {
fn get_function<S: AsRef<str>>(&self, name: S) -> Option<&(dyn PackedFunc)>;
}
// @see `WrapPackedFunc` in `llvm_module.cc`.
fn wrap_backend_packed_func(func_name: String, func: BackendPackedCFunc) -> Box<dyn PackedFunc> {
box move |args: &[TVMArgValue]| {
let (values, type_codes): (Vec<TVMValue>, Vec<i32>) = args
.into_iter()
.map(|arg| {
let (val, code) = arg.to_tvm_value();
(val, code as i32)
})
.unzip();
let exit_code = func(values.as_ptr(), type_codes.as_ptr(), values.len() as i32);
if exit_code == 0 {
Ok(TVMRetValue::default())
} else {
Err(tvm_common::errors::FuncCallError::get_with_context(
func_name.clone(),
))
}
}
}
......@@ -21,14 +21,9 @@ use std::{
collections::HashMap, convert::AsRef, ffi::CStr, os::raw::c_char, string::String, sync::Mutex,
};
use tvm_common::{
ffi::BackendPackedCFunc,
packed_func::{PackedFunc, TVMArgValue, TVMRetValue, TVMValue},
};
use tvm_common::{ffi::BackendPackedCFunc, packed_func::PackedFunc};
pub trait Module {
fn get_function<S: AsRef<str>>(&self, name: S) -> Option<&(dyn PackedFunc)>;
}
use super::Module;
pub struct SystemLibModule;
......@@ -53,30 +48,6 @@ impl Default for SystemLibModule {
}
}
// @see `WrapPackedFunc` in `llvm_module.cc`.
pub(super) fn wrap_backend_packed_func(
func_name: String,
func: BackendPackedCFunc,
) -> Box<dyn PackedFunc> {
box move |args: &[TVMArgValue]| {
let (values, type_codes): (Vec<TVMValue>, Vec<i32>) = args
.into_iter()
.map(|arg| {
let (val, code) = arg.to_tvm_value();
(val, code as i32)
})
.unzip();
let exit_code = func(values.as_ptr(), type_codes.as_ptr(), values.len() as i32);
if exit_code == 0 {
Ok(TVMRetValue::default())
} else {
Err(tvm_common::errors::FuncCallError::get_with_context(
func_name.clone(),
))
}
}
}
#[no_mangle]
pub extern "C" fn TVMBackendRegisterSystemLibSymbol(
cname: *const c_char,
......@@ -85,7 +56,7 @@ pub extern "C" fn TVMBackendRegisterSystemLibSymbol(
let name = unsafe { CStr::from_ptr(cname).to_str().unwrap() };
SYSTEM_LIB_FUNCTIONS.lock().unwrap().insert(
name.to_string(),
&*Box::leak(wrap_backend_packed_func(name.to_string(), func)),
&*Box::leak(super::wrap_backend_packed_func(name.to_string(), func)),
);
return 0;
}
......@@ -42,7 +42,7 @@ use tvm_common::ffi::TVMParallelGroupEnv;
#[cfg(target_env = "sgx")]
use super::{TVMArgValue, TVMRetValue};
type FTVMParallelLambda =
pub(crate) type FTVMParallelLambda =
extern "C" fn(task_id: usize, penv: *const TVMParallelGroupEnv, cdata: *const c_void) -> i32;
/// Holds a parallel job request made by a TVM library function.
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
[package]
name = "test-tvm-dso"
version = "0.0.0"
license = "Apache-2.0"
authors = ["TVM Contributors"]
[dependencies]
ndarray="0.12"
tvm-runtime = { path = "../../" }
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
use std::{env, path::Path, process::Command};
fn main() {
let out_dir = env::var("OUT_DIR").unwrap();
let output = Command::new(concat!(
env!("CARGO_MANIFEST_DIR"),
"/src/build_test_lib.py"
))
.arg(&out_dir)
.output()
.expect("Failed to execute command");
assert!(
Path::new(&format!("{}/test.so", out_dir)).exists(),
"Could not build tvm lib: {}",
String::from_utf8(output.stderr)
.unwrap()
.trim()
.split("\n")
.last()
.unwrap_or("")
);
}
#!/usr/bin/env python3
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Prepares a simple TVM library for testing."""
from os import path as osp
import sys
import tvm
from tvm.contrib import cc
def main():
n = tvm.var('n')
A = tvm.placeholder((n,), name='A')
B = tvm.placeholder((n,), name='B')
C = tvm.compute(A.shape, lambda *i: A(*i) + B(*i), name='C')
s = tvm.create_schedule(C.op)
s[C].parallel(s[C].op.axis[0])
print(tvm.lower(s, [A, B, C], simple_mode=True))
obj_file = osp.join(sys.argv[1], 'test.o')
tvm.build(s, [A, B, C], 'llvm').save(obj_file)
cc.create_shared(osp.join(sys.argv[1], 'test.so'), [obj_file])
if __name__ == '__main__':
main()
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
extern crate ndarray;
#[macro_use]
extern crate tvm_runtime;
use ndarray::Array;
use tvm_runtime::{DLTensor, DsoModule, Module};
fn main() {
tvm_runtime::TVMGetLastError();
let module = DsoModule::new(concat!(env!("OUT_DIR"), "/test.so")).unwrap();
let add = module
.get_function("__tvm_main__")
.expect("main function not found");
let mut a = Array::from_vec(vec![1f32, 2., 3., 4.]);
let mut b = Array::from_vec(vec![1f32, 0., 1., 0.]);
let mut c = Array::from_vec(vec![0f32; 4]);
let e = Array::from_vec(vec![2f32, 2., 4., 4.]);
let mut a_dl: DLTensor = (&mut a).into();
let mut b_dl: DLTensor = (&mut b).into();
let mut c_dl: DLTensor = (&mut c).into();
call_packed!(add, &mut a_dl, &mut b_dl, &mut c_dl).unwrap();
assert!(c.all_close(&e, 1e-8f32));
}
......@@ -48,6 +48,10 @@ cd tests/test_tvm_basic
cargo run
cd -
cd tests/test_tvm_dso
cargo run
cd -
# run NNVM graph test
cd tests/test_nnvm
cargo run
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment