/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ #![feature(proc_macro_span)] extern crate proc_macro; use std::{fs::File, io::Read}; use proc_quote::quote; #[proc_macro] pub fn import_module(input: proc_macro::TokenStream) -> proc_macro::TokenStream { let obj_file_path = syn::parse_macro_input!(input as syn::LitStr); let mut path = obj_file_path.span().unwrap().source_file().path(); path.pop(); // remove the filename path.push(obj_file_path.value()); let mut fd = File::open(&path) .unwrap_or_else(|_| panic!("Unable to find TVM object file at `{}`", path.display())); let mut buffer = Vec::new(); fd.read_to_end(&mut buffer).unwrap(); let fn_names = match goblin::Object::parse(&buffer).unwrap() { goblin::Object::Elf(elf) => elf .syms .iter() .filter_map(|s| { if s.st_type() == 0 || goblin::elf::sym::type_to_str(s.st_type()) == "FILE" { return None; } match elf.strtab.get(s.st_name) { Some(Ok(name)) if name != "" => { Some(syn::Ident::new(name, proc_macro2::Span::call_site())) } _ => None, } }) .collect::<Vec<_>>(), goblin::Object::Mach(goblin::mach::Mach::Binary(obj)) => { obj.symbols() .filter_map(|s| match s { Ok((name, nlist)) if nlist.is_global() && nlist.n_sect != 0 && !name.ends_with("tvm_module_ctx") => { Some(syn::Ident::new( if name.starts_with('_') { // Mach objects prepend a _ to globals. &name[1..] } else { &name }, proc_macro2::Span::call_site(), )) } _ => None, }) .collect::<Vec<_>>() } _ => panic!("Unsupported object format."), }; let extern_fns = quote! { mod ext { extern "C" { #( pub(super) fn #fn_names( args: *const tvm_runtime::ffi::TVMValue, type_codes: *const std::os::raw::c_int, num_args: std::os::raw::c_int ) -> std::os::raw::c_int; )* } } }; let fns = quote! { use tvm_runtime::{ffi::TVMValue, TVMArgValue, TVMRetValue, FuncCallError}; #extern_fns #( pub fn #fn_names(args: &[TVMArgValue]) -> Result<TVMRetValue, FuncCallError> { let (values, type_codes): (Vec<TVMValue>, Vec<i32>) = args .into_iter() .map(|arg| { let (val, code) = arg.to_tvm_value(); (val, code as i32) }) .unzip(); let exit_code = unsafe { ext::#fn_names(values.as_ptr(), type_codes.as_ptr(), values.len() as i32) }; if exit_code == 0 { Ok(TVMRetValue::default()) } else { Err(FuncCallError::get_with_context(stringify!(#fn_names).to_string())) } } )* }; proc_macro::TokenStream::from(fns) }