Unverified Commit a4bc50eb by Nick Hynes Committed by GitHub

[Rust] Static syslib (#3274)

parent 73358be5
......@@ -18,6 +18,7 @@
[workspace]
members = [
"common",
"macros",
"runtime",
"runtime/tests/test_tvm_basic",
"runtime/tests/test_tvm_dso",
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
[package]
name = "tvm-macros"
version = "0.1.0"
license = "Apache-2.0"
description = "Proc macros used by the TVM crates."
repository = "https://github.com/dmlc/tvm"
readme = "README.md"
keywords = ["tvm"]
authors = ["TVM Contributors"]
edition = "2018"
[lib]
proc-macro = true
[dependencies]
goblin = "0.0.22"
proc-macro2 = "0.4"
proc-quote = "0.2"
syn = "0.15"
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#![feature(bind_by_move_pattern_guards, proc_macro_span)]
extern crate proc_macro;
use std::{fs::File, io::Read};
use proc_quote::quote;
#[proc_macro]
pub fn import_module(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
let obj_file_path = syn::parse_macro_input!(input as syn::LitStr);
let mut path = obj_file_path.span().unwrap().source_file().path();
path.pop(); // remove the filename
path.push(obj_file_path.value());
let mut fd = File::open(&path)
.unwrap_or_else(|_| panic!("Unable to find TVM object file at `{}`", path.display()));
let mut buffer = Vec::new();
fd.read_to_end(&mut buffer).unwrap();
let fn_names = match goblin::Object::parse(&buffer).unwrap() {
goblin::Object::Elf(elf) => elf
.syms
.iter()
.filter_map(|s| {
if s.st_type() == 0 || goblin::elf::sym::type_to_str(s.st_type()) == "FILE" {
return None;
}
match elf.strtab.get(s.st_name) {
Some(Ok(name)) if name != "" => {
Some(syn::Ident::new(name, proc_macro2::Span::call_site()))
}
_ => None,
}
})
.collect::<Vec<_>>(),
goblin::Object::Mach(goblin::mach::Mach::Binary(obj)) => {
obj.symbols()
.filter_map(|s| match s {
Ok((name, nlist))
if nlist.is_global()
&& nlist.n_sect != 0
&& !name.ends_with("tvm_module_ctx") =>
{
Some(syn::Ident::new(
if name.starts_with('_') {
// Mach objects prepend a _ to globals.
&name[1..]
} else {
&name
},
proc_macro2::Span::call_site(),
))
}
_ => None,
})
.collect::<Vec<_>>()
}
_ => panic!("Unsupported object format."),
};
let extern_fns = quote! {
mod ext {
extern "C" {
#(
pub(super) fn #fn_names(
args: *const tvm_runtime::ffi::TVMValue,
type_codes: *const std::os::raw::c_int,
num_args: std::os::raw::c_int
) -> std::os::raw::c_int;
)*
}
}
};
let fns = quote! {
use tvm_runtime::{ffi::TVMValue, TVMArgValue, TVMRetValue, FuncCallError};
#extern_fns
#(
pub fn #fn_names(args: &[TVMArgValue]) -> Result<TVMRetValue, FuncCallError> {
let (values, type_codes): (Vec<TVMValue>, Vec<i32>) = args
.into_iter()
.map(|arg| {
let (val, code) = arg.to_tvm_value();
(val, code as i32)
})
.unzip();
let exit_code = unsafe {
ext::#fn_names(values.as_ptr(), type_codes.as_ptr(), values.len() as i32)
};
if exit_code == 0 {
Ok(TVMRetValue::default())
} else {
Err(FuncCallError::get_with_context(stringify!(#fn_names).to_string()))
}
}
)*
};
proc_macro::TokenStream::from(fns)
}
......@@ -41,7 +41,8 @@ nom = {version = "4.0.0", default-features = false }
serde = "1.0.59"
serde_derive = "1.0.79"
serde_json = "1.0.17"
tvm-common = { version = "0.1.0", path = "../common/" }
tvm-common = { version = "0.1", path = "../common" }
tvm-macros = { version = "0.1", path = "../macros" }
[target.'cfg(not(target_env = "sgx"))'.dependencies]
num_cpus = "1.8.0"
......
......@@ -164,7 +164,7 @@ impl<'a> TryFrom<&'a str> for Graph {
/// ```
pub struct GraphExecutor<'m, 't> {
graph: Graph,
op_execs: Vec<Box<Fn() + 'm>>,
op_execs: Vec<Box<dyn Fn() + 'm>>,
tensors: Vec<Tensor<'t>>,
}
......@@ -240,7 +240,7 @@ impl<'m, 't> GraphExecutor<'m, 't> {
graph: &Graph,
lib: &'m M,
tensors: &Vec<Tensor<'t>>,
) -> Result<Vec<Box<Fn() + 'm>>, Error> {
) -> Result<Vec<Box<dyn Fn() + 'm>>, Error> {
ensure!(graph.node_row_ptr.is_some(), "Missing node_row_ptr.");
let node_row_ptr = graph.node_row_ptr.as_ref().unwrap();
......@@ -279,7 +279,7 @@ impl<'m, 't> GraphExecutor<'m, 't> {
})
.collect::<Result<Vec<DLTensor>, Error>>()
.unwrap();
let op: Box<Fn()> = box move || {
let op: Box<dyn Fn()> = box move || {
let args = dl_tensors
.iter()
.map(|t| t.into())
......
......@@ -29,7 +29,6 @@
//! For examples of use, please refer to the multi-file tests in the `tests` directory.
#![feature(
alloc,
allocator_api,
box_syntax,
fn_traits,
......@@ -77,6 +76,7 @@ pub use tvm_common::{
packed_func::{self, *},
TVMArgValue, TVMRetValue,
};
pub use tvm_macros::import_module;
pub use self::{array::*, errors::*, graph::*, module::*, threading::*, workspace::*};
......
......@@ -19,13 +19,21 @@
extern crate ar;
use std::{env, path::Path, process::Command};
use std::{path::PathBuf, process::Command};
use ar::Builder;
use std::fs::File;
fn main() {
let out_dir = env::var("OUT_DIR").unwrap();
let mut out_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
out_dir.push("lib");
if !out_dir.is_dir() {
std::fs::create_dir(&out_dir).unwrap();
}
let obj_file = out_dir.join("test.o");
let lib_file = out_dir.join("libtest.a");
let output = Command::new(concat!(
env!("CARGO_MANIFEST_DIR"),
......@@ -35,7 +43,7 @@ fn main() {
.output()
.expect("Failed to execute command");
assert!(
Path::new(&format!("{}/test.o", out_dir)).exists(),
obj_file.exists(),
"Could not build tvm lib: {}",
String::from_utf8(output.stderr)
.unwrap()
......@@ -45,9 +53,9 @@ fn main() {
.unwrap_or("")
);
let mut builder = Builder::new(File::create(format!("{}/libtest.a", out_dir)).unwrap());
builder.append_path(format!("{}/test.o", out_dir)).unwrap();
let mut builder = Builder::new(File::create(lib_file).unwrap());
builder.append_path(obj_file).unwrap();
println!("cargo:rustc-link-lib=static=test");
println!("cargo:rustc-link-search=native={}", out_dir);
println!("cargo:rustc-link-search=native={}", out_dir.display());
}
......@@ -22,13 +22,14 @@ extern crate ndarray;
extern crate tvm_runtime;
use ndarray::Array;
use tvm_runtime::{DLTensor, Module, SystemLibModule};
use tvm_runtime::{DLTensor, Module as _, SystemLibModule};
mod tvm_mod {
import_module!("../lib/test.o");
}
fn main() {
let syslib = SystemLibModule::default();
let add = syslib
.get_function("default_function")
.expect("main function not found");
// try static
let mut a = Array::from_vec(vec![1f32, 2., 3., 4.]);
let mut b = Array::from_vec(vec![1f32, 0., 1., 0.]);
let mut c = Array::from_vec(vec![0f32; 4]);
......@@ -36,6 +37,14 @@ fn main() {
let mut a_dl: DLTensor = (&mut a).into();
let mut b_dl: DLTensor = (&mut b).into();
let mut c_dl: DLTensor = (&mut c).into();
call_packed!(tvm_mod::default_function, &mut a_dl, &mut b_dl, &mut c_dl).unwrap();
assert!(c.all_close(&e, 1e-8f32));
// try runtime
let syslib = SystemLibModule::default();
let add = syslib
.get_function("default_function")
.expect("main function not found");
call_packed!(add, &mut a_dl, &mut b_dl, &mut c_dl).unwrap();
assert!(c.all_close(&e, 1e-8f32));
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment