Unverified Commit d2bc94d9 by Jared Roesch Committed by GitHub

[Rust] Fix the existing test cases before refactoring. (#5122)

* Fix up the final pieces

* Tweak build.rs
parent f2b9ec4a
...@@ -19,7 +19,6 @@ ...@@ -19,7 +19,6 @@
members = [ members = [
"common", "common",
"macros", "macros",
"macros_raw",
"runtime", "runtime",
"runtime/tests/test_tvm_basic", "runtime/tests/test_tvm_basic",
"runtime/tests/test_tvm_dso", "runtime/tests/test_tvm_dso",
......
...@@ -19,6 +19,7 @@ ...@@ -19,6 +19,7 @@
name = "callback" name = "callback"
version = "0.0.0" version = "0.0.0"
authors = ["TVM Contributors"] authors = ["TVM Contributors"]
edition = "2018"
[dependencies] [dependencies]
ndarray = "0.12" ndarray = "0.12"
......
...@@ -19,10 +19,7 @@ ...@@ -19,10 +19,7 @@
use std::panic; use std::panic;
#[macro_use] use tvm_frontend::{errors::Error, *};
extern crate tvm_frontend as tvm;
use tvm::{errors::Error, *};
fn main() { fn main() {
register_global_func! { register_global_func! {
......
...@@ -19,13 +19,18 @@ ...@@ -19,13 +19,18 @@
name = "tvm-macros" name = "tvm-macros"
version = "0.1.1" version = "0.1.1"
license = "Apache-2.0" license = "Apache-2.0"
description = "Proc macros used by the TVM crates." description = "Procedural macros of the TVM crate."
repository = "https://github.com/apache/incubator-tvm" repository = "https://github.com/apache/incubator-tvm"
readme = "README.md" readme = "README.md"
keywords = ["tvm"] keywords = ["tvm"]
authors = ["TVM Contributors"] authors = ["TVM Contributors"]
edition = "2018" edition = "2018"
[lib]
proc-macro = true
[dependencies] [dependencies]
tvm-macros-raw = { path = "../macros_raw" } goblin = "0.0.24"
proc-macro2 = "^1.0"
quote = "1.0"
syn = "1.0"
...@@ -17,12 +17,123 @@ ...@@ -17,12 +17,123 @@
* under the License. * under the License.
*/ */
#[macro_use] extern crate proc_macro;
extern crate tvm_macros_raw;
#[macro_export] use std::{fs::File, io::Read};
macro_rules! import_module { use syn::parse::{Parse, ParseStream, Result};
($module_path:literal) => { use syn::{LitStr};
$crate::import_module_raw!(file!(), $module_path); use quote::quote;
use std::path::PathBuf;
struct ImportModule {
importing_file: LitStr,
}
impl Parse for ImportModule {
fn parse(input: ParseStream) -> Result<Self> {
let importing_file: LitStr = input.parse()?;
Ok(ImportModule {
importing_file,
})
}
}
#[proc_macro]
pub fn import_module(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
let import_module_args = syn::parse_macro_input!(input as ImportModule);
let manifest = std::env::var("CARGO_MANIFEST_DIR")
.expect("variable should always be set by Cargo.");
let mut path = PathBuf::new();
path.push(manifest);
path = path.join(import_module_args.importing_file.value());
let mut fd = File::open(&path)
.unwrap_or_else(|_| panic!("Unable to find TVM object file at `{}`", path.display()));
let mut buffer = Vec::new();
fd.read_to_end(&mut buffer).unwrap();
let fn_names = match goblin::Object::parse(&buffer).unwrap() {
goblin::Object::Elf(elf) => elf
.syms
.iter()
.filter_map(|s| {
if s.st_type() == 0 || goblin::elf::sym::type_to_str(s.st_type()) == "FILE" {
return None;
}
match elf.strtab.get(s.st_name) {
Some(Ok(name)) if name != "" => {
Some(syn::Ident::new(name, proc_macro2::Span::call_site()))
}
_ => None,
}
})
.collect::<Vec<_>>(),
goblin::Object::Mach(goblin::mach::Mach::Binary(obj)) => {
obj.symbols()
.filter_map(|s| match s {
Ok((name, ref nlist))
if nlist.is_global()
&& nlist.n_sect != 0
&& !name.ends_with("tvm_module_ctx") =>
{
Some(syn::Ident::new(
if name.starts_with('_') {
// Mach objects prepend a _ to globals.
&name[1..]
} else {
&name
},
proc_macro2::Span::call_site(),
))
}
_ => None,
})
.collect::<Vec<_>>()
}
_ => panic!("Unsupported object format."),
};
let extern_fns = quote! {
mod ext {
extern "C" {
#(
pub(super) fn #fn_names(
args: *const tvm_runtime::ffi::TVMValue,
type_codes: *const std::os::raw::c_int,
num_args: std::os::raw::c_int
) -> std::os::raw::c_int;
)*
}
}
}; };
let fns = quote! {
use tvm_runtime::{ffi::TVMValue, TVMArgValue, TVMRetValue, FuncCallError};
#extern_fns
#(
pub fn #fn_names(args: &[TVMArgValue]) -> Result<TVMRetValue, FuncCallError> {
let (values, type_codes): (Vec<TVMValue>, Vec<i32>) = args
.into_iter()
.map(|arg| {
let (val, code) = arg.to_tvm_value();
(val, code as i32)
})
.unzip();
let exit_code = unsafe {
ext::#fn_names(values.as_ptr(), type_codes.as_ptr(), values.len() as i32)
};
if exit_code == 0 {
Ok(TVMRetValue::default())
} else {
Err(FuncCallError::get_with_context(stringify!(#fn_names).to_string()))
}
}
)*
};
proc_macro::TokenStream::from(fns)
} }
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
[package]
name = "tvm-macros-raw"
version = "0.1.1"
license = "Apache-2.0"
description = "Proc macros used by the TVM crates."
repository = "https://github.com/apache/incubator-tvm"
readme = "README.md"
keywords = ["tvm"]
authors = ["TVM Contributors"]
edition = "2018"
[lib]
proc-macro = true
[dependencies]
goblin = "0.0.24"
proc-macro2 = "^1.0"
quote = "1.0"
syn = "1.0"
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
extern crate proc_macro;
use std::{fs::File, io::Read};
use syn::parse::{Parse, ParseStream, Result};
use syn::{Token, LitStr};
use quote::quote;
use std::path::PathBuf;
struct ImportModule {
importing_file: LitStr,
module_path: LitStr,
}
impl Parse for ImportModule {
fn parse(input: ParseStream) -> Result<Self> {
let importing_file: LitStr = input.parse()?;
input.parse::<Token![,]>()?;
let module_path: LitStr = input.parse()?;
Ok(ImportModule {
importing_file,
module_path,
})
}
}
#[proc_macro]
pub fn import_module_raw(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
let import_module_args = syn::parse_macro_input!(input as ImportModule);
let mut path = PathBuf::new();
path = path.join(import_module_args.importing_file.value());
path.pop(); // remove the filename
path.push(import_module_args.module_path.value());
let mut fd = File::open(&path)
.unwrap_or_else(|_| panic!("Unable to find TVM object file at `{}`", path.display()));
let mut buffer = Vec::new();
fd.read_to_end(&mut buffer).unwrap();
let fn_names = match goblin::Object::parse(&buffer).unwrap() {
goblin::Object::Elf(elf) => elf
.syms
.iter()
.filter_map(|s| {
if s.st_type() == 0 || goblin::elf::sym::type_to_str(s.st_type()) == "FILE" {
return None;
}
match elf.strtab.get(s.st_name) {
Some(Ok(name)) if name != "" => {
Some(syn::Ident::new(name, proc_macro2::Span::call_site()))
}
_ => None,
}
})
.collect::<Vec<_>>(),
goblin::Object::Mach(goblin::mach::Mach::Binary(obj)) => {
obj.symbols()
.filter_map(|s| match s {
Ok((name, ref nlist))
if nlist.is_global()
&& nlist.n_sect != 0
&& !name.ends_with("tvm_module_ctx") =>
{
Some(syn::Ident::new(
if name.starts_with('_') {
// Mach objects prepend a _ to globals.
&name[1..]
} else {
&name
},
proc_macro2::Span::call_site(),
))
}
_ => None,
})
.collect::<Vec<_>>()
}
_ => panic!("Unsupported object format."),
};
let extern_fns = quote! {
mod ext {
extern "C" {
#(
pub(super) fn #fn_names(
args: *const tvm_runtime::ffi::TVMValue,
type_codes: *const std::os::raw::c_int,
num_args: std::os::raw::c_int
) -> std::os::raw::c_int;
)*
}
}
};
let fns = quote! {
use tvm_runtime::{ffi::TVMValue, TVMArgValue, TVMRetValue, FuncCallError};
#extern_fns
#(
pub fn #fn_names(args: &[TVMArgValue]) -> Result<TVMRetValue, FuncCallError> {
let (values, type_codes): (Vec<TVMValue>, Vec<i32>) = args
.into_iter()
.map(|arg| {
let (val, code) = arg.to_tvm_value();
(val, code as i32)
})
.unzip();
let exit_code = unsafe {
ext::#fn_names(values.as_ptr(), type_codes.as_ptr(), values.len() as i32)
};
if exit_code == 0 {
Ok(TVMRetValue::default())
} else {
Err(FuncCallError::get_with_context(stringify!(#fn_names).to_string()))
}
}
)*
};
proc_macro::TokenStream::from(fns)
}
...@@ -44,9 +44,19 @@ fn main() { ...@@ -44,9 +44,19 @@ fn main() {
.unwrap_or("") .unwrap_or("")
); );
let mut builder = Builder::new(File::create(format!("{}/libgraph.a", out_dir)).unwrap()); let lib_file = format!("{}/libtestnn.a", out_dir);
let file = File::create(&lib_file).unwrap();
let mut builder = Builder::new(file);
builder.append_path(format!("{}/graph.o", out_dir)).unwrap(); builder.append_path(format!("{}/graph.o", out_dir)).unwrap();
println!("cargo:rustc-link-lib=static=graph"); let status = Command::new("ranlib")
.arg(&lib_file)
.status()
.expect("fdjlksafjdsa");
assert!(status.success());
println!("cargo:rustc-link-lib=static=testnn");
println!("cargo:rustc-link-search=native={}", out_dir); println!("cargo:rustc-link-search=native={}", out_dir);
} }
...@@ -33,7 +33,7 @@ fn main() { ...@@ -33,7 +33,7 @@ fn main() {
} }
let obj_file = out_dir.join("test.o"); let obj_file = out_dir.join("test.o");
let lib_file = out_dir.join("libtest.a"); let lib_file = out_dir.join("libtest_basic.a");
let output = Command::new(concat!( let output = Command::new(concat!(
env!("CARGO_MANIFEST_DIR"), env!("CARGO_MANIFEST_DIR"),
...@@ -53,9 +53,17 @@ fn main() { ...@@ -53,9 +53,17 @@ fn main() {
.unwrap_or("") .unwrap_or("")
); );
let mut builder = Builder::new(File::create(lib_file).unwrap()); let mut builder = Builder::new(File::create(&lib_file).unwrap());
builder.append_path(obj_file).unwrap(); builder.append_path(&obj_file).unwrap();
drop(builder);
println!("cargo:rustc-link-lib=static=test"); let status = Command::new("ranlib")
.arg(&lib_file)
.status()
.expect("fdjlksafjdsa");
assert!(status.success());
println!("cargo:rustc-link-lib=static=test_basic");
println!("cargo:rustc-link-search=native={}", out_dir.display()); println!("cargo:rustc-link-search=native={}", out_dir.display());
} }
...@@ -25,7 +25,7 @@ use ndarray::Array; ...@@ -25,7 +25,7 @@ use ndarray::Array;
use tvm_runtime::{DLTensor, Module as _, SystemLibModule}; use tvm_runtime::{DLTensor, Module as _, SystemLibModule};
mod tvm_mod { mod tvm_mod {
import_module!("../lib/test.o"); import_module!("lib/test.o");
} }
fn main() { fn main() {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment