Unverified Commit c7a16d89 by MORITA Kazutaka Committed by GitHub

[Rust] Fixes for wasm32 target (#5489)

* [Rust] Fixes for wasm32 target

* [Rust] Add test for wasm32

* allow cargo config to be into repo

* Disable wasm tests in CI
parent 8599f7c6
...@@ -22,6 +22,7 @@ members = [ ...@@ -22,6 +22,7 @@ members = [
"runtime", "runtime",
"runtime/tests/test_tvm_basic", "runtime/tests/test_tvm_basic",
"runtime/tests/test_tvm_dso", "runtime/tests/test_tvm_dso",
"runtime/tests/test_wasm32",
"runtime/tests/test_nn", "runtime/tests/test_nn",
"frontend", "frontend",
"frontend/tests/basics", "frontend/tests/basics",
......
...@@ -51,6 +51,7 @@ fn main() { ...@@ -51,6 +51,7 @@ fn main() {
.layout_tests(false) .layout_tests(false)
.derive_partialeq(true) .derive_partialeq(true)
.derive_eq(true) .derive_eq(true)
.derive_default(true)
.generate() .generate()
.expect("unable to generate bindings") .expect("unable to generate bindings")
.write_to_file(PathBuf::from("src/c_runtime_api.rs")) .write_to_file(PathBuf::from("src/c_runtime_api.rs"))
......
...@@ -133,6 +133,7 @@ macro_rules! impl_dltensor_from_ndarray { ...@@ -133,6 +133,7 @@ macro_rules! impl_dltensor_from_ndarray {
shape: arr.shape().as_ptr() as *const i64 as *mut i64, shape: arr.shape().as_ptr() as *const i64 as *mut i64,
strides: arr.strides().as_ptr() as *const isize as *mut i64, strides: arr.strides().as_ptr() as *const isize as *mut i64,
byte_offset: 0, byte_offset: 0,
..Default::default()
} }
} }
} }
......
...@@ -31,8 +31,13 @@ pub mod ffi { ...@@ -31,8 +31,13 @@ pub mod ffi {
include!(concat!(env!("CARGO_MANIFEST_DIR"), "/src/c_runtime_api.rs")); include!(concat!(env!("CARGO_MANIFEST_DIR"), "/src/c_runtime_api.rs"));
pub type BackendPackedCFunc = pub type BackendPackedCFunc = extern "C" fn(
extern "C" fn(args: *const TVMValue, type_codes: *const c_int, num_args: c_int) -> c_int; args: *const TVMValue,
type_codes: *const c_int,
num_args: c_int,
out_ret_value: *mut TVMValue,
out_ret_tcode: *mut u32,
) -> c_int;
} }
pub mod array; pub mod array;
......
...@@ -297,6 +297,7 @@ impl<'a> Tensor<'a> { ...@@ -297,6 +297,7 @@ impl<'a> Tensor<'a> {
self.strides.as_ref().unwrap().as_ptr() self.strides.as_ref().unwrap().as_ptr()
} as *mut i64, } as *mut i64,
byte_offset: 0, byte_offset: 0,
..Default::default()
} }
} }
} }
......
...@@ -382,7 +382,18 @@ named! { ...@@ -382,7 +382,18 @@ named! {
// Converts a bytes to String. // Converts a bytes to String.
named! { named! {
name<String>, name<String>,
map_res!(length_data!(le_u64), |b: &[u8]| String::from_utf8(b.to_vec())) do_parse!(
len_l: le_u32 >>
len_h: le_u32 >>
data: take!(len_l) >>
(
if len_h == 0 {
String::from_utf8(data.to_vec()).unwrap()
} else {
panic!("Too long string")
}
)
)
} }
// Parses a TVMContext // Parses a TVMContext
......
...@@ -44,9 +44,17 @@ fn wrap_backend_packed_func(func_name: String, func: BackendPackedCFunc) -> Box< ...@@ -44,9 +44,17 @@ fn wrap_backend_packed_func(func_name: String, func: BackendPackedCFunc) -> Box<
(val, code as i32) (val, code as i32)
}) })
.unzip(); .unzip();
let exit_code = func(values.as_ptr(), type_codes.as_ptr(), values.len() as i32); let ret: TVMRetValue = TVMRetValue::default();
let (mut ret_val, mut ret_type_code) = ret.to_tvm_value();
let exit_code = func(
values.as_ptr(),
type_codes.as_ptr(),
values.len() as i32,
&mut ret_val,
&mut ret_type_code,
);
if exit_code == 0 { if exit_code == 0 {
Ok(TVMRetValue::default()) Ok(TVMRetValue::from_tvm_value(ret_val, ret_type_code))
} else { } else {
Err(tvm_common::errors::FuncCallError::get_with_context( Err(tvm_common::errors::FuncCallError::get_with_context(
func_name.clone(), func_name.clone(),
......
...@@ -18,7 +18,6 @@ ...@@ -18,7 +18,6 @@
*/ */
use std::{ use std::{
env,
os::raw::{c_int, c_void}, os::raw::{c_int, c_void},
sync::{ sync::{
atomic::{AtomicUsize, Ordering}, atomic::{AtomicUsize, Ordering},
...@@ -27,6 +26,9 @@ use std::{ ...@@ -27,6 +26,9 @@ use std::{
thread::{self, JoinHandle}, thread::{self, JoinHandle},
}; };
#[cfg(not(target_arch = "wasm32"))]
use std::env;
use crossbeam::channel::{bounded, Receiver, Sender}; use crossbeam::channel::{bounded, Receiver, Sender};
use tvm_common::ffi::TVMParallelGroupEnv; use tvm_common::ffi::TVMParallelGroupEnv;
...@@ -147,7 +149,10 @@ impl ThreadPool { ...@@ -147,7 +149,10 @@ impl ThreadPool {
fn run_worker(queue: Receiver<Task>) { fn run_worker(queue: Receiver<Task>) {
loop { loop {
let task = queue.recv().expect("should recv"); let task = match queue.recv() {
Ok(v) => v,
Err(_) => break,
};
let result = task.run(); let result = task.run();
if result == <i32>::min_value() { if result == <i32>::min_value() {
break; break;
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
[package]
name = "test-wasm32"
version = "0.0.0"
license = "Apache-2.0"
authors = ["TVM Contributors"]
[dependencies]
ndarray="0.12"
tvm-runtime = { path = "../../" }
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
use std::{path::PathBuf, process::Command};
fn main() {
let mut out_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
out_dir.push("lib");
if !out_dir.is_dir() {
std::fs::create_dir(&out_dir).unwrap();
}
let obj_file = out_dir.join("test.o");
let lib_file = out_dir.join("libtest_wasm32.a");
let output = Command::new(concat!(
env!("CARGO_MANIFEST_DIR"),
"/src/build_test_lib.py"
))
.arg(&out_dir)
.output()
.expect("Failed to execute command");
assert!(
obj_file.exists(),
"Could not build tvm lib: {}",
String::from_utf8(output.stderr)
.unwrap()
.trim()
.split("\n")
.last()
.unwrap_or("")
);
let ar = option_env!("LLVM_AR").unwrap_or("llvm-ar-8");
let output = Command::new(ar)
.arg("rcs")
.arg(&lib_file)
.arg(&obj_file)
.output()
.expect("Failed to execute command");
assert!(
lib_file.exists(),
"Could not create archive: {}",
String::from_utf8(output.stderr)
.unwrap()
.trim()
.split("\n")
.last()
.unwrap_or("")
);
println!("cargo:rustc-link-lib=static=test_wasm32");
println!("cargo:rustc-link-search=native={}", out_dir.display());
}
#!/usr/bin/env python3
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Prepares a simple TVM library for testing."""
from os import path as osp
import sys
import tvm
from tvm import te
def main():
n = te.var('n')
A = te.placeholder((n,), name='A')
B = te.placeholder((n,), name='B')
C = te.compute(A.shape, lambda *i: A(*i) + B(*i), name='C')
s = tvm.te.create_schedule(C.op)
s[C].parallel(s[C].op.axis[0])
print(tvm.lower(s, [A, B, C], simple_mode=True))
tvm.build(s, [A, B, C], 'llvm -target=wasm32-unknown-unknown --system-lib').save(osp.join(sys.argv[1], 'test.o'))
if __name__ == '__main__':
main()
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
extern "C" {
static __tvm_module_ctx: i32;
}
#[no_mangle]
unsafe fn __get_tvm_module_ctx() -> i32 {
// Refer a symbol in the libtest_wasm32.a to make sure that the link of the
// library is not optimized out.
__tvm_module_ctx
}
extern crate ndarray;
#[macro_use]
extern crate tvm_runtime;
use ndarray::Array;
use tvm_runtime::{DLTensor, Module as _, SystemLibModule};
fn main() {
// try static
let mut a = Array::from_vec(vec![1f32, 2., 3., 4.]);
let mut b = Array::from_vec(vec![1f32, 0., 1., 0.]);
let mut c = Array::from_vec(vec![0f32; 4]);
let e = Array::from_vec(vec![2f32, 2., 4., 4.]);
let mut a_dl: DLTensor = (&mut a).into();
let mut b_dl: DLTensor = (&mut b).into();
let mut c_dl: DLTensor = (&mut c).into();
let syslib = SystemLibModule::default();
let add = syslib
.get_function("default_function")
.expect("main function not found");
call_packed!(add, &mut a_dl, &mut b_dl, &mut c_dl).unwrap();
assert!(c.all_close(&e, 1e-8f32));
}
...@@ -103,7 +103,8 @@ ALLOW_SPECIFIC_FILE = { ...@@ -103,7 +103,8 @@ ALLOW_SPECIFIC_FILE = {
"KEYS", "KEYS",
"DISCLAIMER", "DISCLAIMER",
"Jenkinsfile", "Jenkinsfile",
# sgx config # cargo config
"rust/runtime/tests/test_wasm32/.cargo/config",
"apps/sgx/.cargo/config", "apps/sgx/.cargo/config",
# html for demo purposes # html for demo purposes
"tests/webgl/test_static_webgl_library.html", "tests/webgl/test_static_webgl_library.html",
......
...@@ -54,6 +54,12 @@ cd tests/test_tvm_dso ...@@ -54,6 +54,12 @@ cd tests/test_tvm_dso
cargo run cargo run
cd - cd -
# # run wasm32 test
# cd tests/test_wasm32
# cargo build
# wasmtime $RUST_DIR/target/wasm32-wasi/debug/test-wasm32.wasm
# cd -
# run nn graph test # run nn graph test
cd tests/test_nn cd tests/test_nn
cargo run cargo run
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment