Unverified Commit 41e1d5f9 by Jared Roesch Committed by GitHub

Revive the Rust + SGX refactor (#4976)

* Add Nick's changes's squashed

* Fix frontend compilation

* Re-enable Rust CI

* Add changes with conflicted badly

* Restructure import_module! macro in order to avoid unstable features

* Kill old unstable feature enablement

* Refactor common to use new APIs

* Move the code to stable

* Fix warning

Co-authored-by: Nick Hynes <nhynes@oasislabs.com>
parent 93dff448
...@@ -36,7 +36,6 @@ tvm_option(USE_GRAPH_RUNTIME "Build with tiny graph runtime" ON) ...@@ -36,7 +36,6 @@ tvm_option(USE_GRAPH_RUNTIME "Build with tiny graph runtime" ON)
tvm_option(USE_GRAPH_RUNTIME_DEBUG "Build with tiny graph runtime debug mode" OFF) tvm_option(USE_GRAPH_RUNTIME_DEBUG "Build with tiny graph runtime debug mode" OFF)
tvm_option(USE_OPENMP "Build with OpenMP thread pool implementation" OFF) tvm_option(USE_OPENMP "Build with OpenMP thread pool implementation" OFF)
tvm_option(USE_RELAY_DEBUG "Building Relay in debug mode..." OFF) tvm_option(USE_RELAY_DEBUG "Building Relay in debug mode..." OFF)
tvm_option(USE_SGX "Build with SGX" OFF)
tvm_option(USE_RTTI "Build with RTTI" ON) tvm_option(USE_RTTI "Build with RTTI" ON)
tvm_option(USE_MSVC_MT "Build with MT" OFF) tvm_option(USE_MSVC_MT "Build with MT" OFF)
tvm_option(USE_MICRO "Build with Micro" OFF) tvm_option(USE_MICRO "Build with Micro" OFF)
...@@ -243,7 +242,6 @@ include(cmake/modules/OpenMP.cmake) ...@@ -243,7 +242,6 @@ include(cmake/modules/OpenMP.cmake)
include(cmake/modules/Vulkan.cmake) include(cmake/modules/Vulkan.cmake)
include(cmake/modules/Metal.cmake) include(cmake/modules/Metal.cmake)
include(cmake/modules/ROCM.cmake) include(cmake/modules/ROCM.cmake)
include(cmake/modules/SGX.cmake)
include(cmake/modules/LLVM.cmake) include(cmake/modules/LLVM.cmake)
include(cmake/modules/Micro.cmake) include(cmake/modules/Micro.cmake)
include(cmake/modules/ANTLR.cmake) include(cmake/modules/ANTLR.cmake)
...@@ -283,12 +281,6 @@ else() ...@@ -283,12 +281,6 @@ else()
set_target_properties(tvm PROPERTIES COMPILE_DEFINITIONS "NDEBUG") set_target_properties(tvm PROPERTIES COMPILE_DEFINITIONS "NDEBUG")
endif(USE_RELAY_DEBUG) endif(USE_RELAY_DEBUG)
if(NOT USE_SGX STREQUAL "OFF")
add_dependencies(tvm sgx_edl)
add_dependencies(tvm_runtime sgx_edl tvm_t)
install(TARGETS tvm_t ARCHIVE DESTINATION lib${LIB_SUFFIX})
endif()
if(USE_THREADS) if(USE_THREADS)
message(STATUS "Build with thread support...") message(STATUS "Build with thread support...")
set(CMAKE_THREAD_PREFER_PTHREAD TRUE) set(CMAKE_THREAD_PREFER_PTHREAD TRUE)
......
...@@ -220,6 +220,7 @@ stage('Build') { ...@@ -220,6 +220,7 @@ stage('Build') {
sh "${docker_run} ${ci_cpu} ./tests/scripts/task_python_vta_fsim.sh" sh "${docker_run} ${ci_cpu} ./tests/scripts/task_python_vta_fsim.sh"
sh "${docker_run} ${ci_cpu} ./tests/scripts/task_python_vta_tsim.sh" sh "${docker_run} ${ci_cpu} ./tests/scripts/task_python_vta_tsim.sh"
sh "${docker_run} ${ci_cpu} ./tests/scripts/task_golang.sh" sh "${docker_run} ${ci_cpu} ./tests/scripts/task_golang.sh"
sh "${docker_run} ${ci_cpu} ./tests/scripts/task_rust.sh"
} }
} }
} }
......
[build]
target = "x86_64-fortanix-unknown-sgx"
[target.x86_64-fortanix-unknown-sgx]
runner = "ftxsgx-runner-cargo"
../../rust/.rustfmt.toml
\ No newline at end of file
...@@ -16,17 +16,13 @@ ...@@ -16,17 +16,13 @@
# under the License. # under the License.
[package] [package]
name = "model-enclave" name = "sgx-demo"
version = "0.1.0" version = "0.1.0"
authors = ["Nick Hynes <nhynes@berkeley.edu>"] authors = ["Nick Hynes <nhynes@nhynes.com>"]
edition = "2018"
[lib]
crate-type = ["staticlib"]
[dependencies] [dependencies]
lazy_static = "1.1.0" tvm-runtime = { path = "../../rust/runtime" }
tvm = { path = "../../../rust", default-features = false, features = ["sgx"] }
[profile.release] [patch.crates-io]
lto = true "backtrace" = { git = "https://github.com/nhynes/backtrace-rs", branch = "fix-sgx" }
opt-level = 3
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
SGX_SDK ?= /opt/sgxsdk
RUST_SGX_SDK ?= /opt/rust-sgx-sdk
SGX_MODE ?= SIM
DEBUG ?= true
NUM_THREADS ?= 4
TVM_DIR ?= $(shell git rev-parse --show-toplevel)
export
sgx_edger8r := $(SGX_SDK)/bin/x64/sgx_edger8r
sgx_enclave_signer := $(SGX_SDK)/bin/x64/sgx_sign
ifneq ($(SGX_MODE), HW)
sgx_sim := _sim
endif
urts_library_name := sgx_urts$(sgx_sim)
trts_library_name := sgx_trts$(sgx_sim)
tservice_library_name := sgx_tservice$(sgx_sim)
uservice_library_name := sgx_uae_service$(sgx_sim)
pkg_cflags := -std=c++11 -fPIC \
-I$(SGX_SDK)/include \
-I$(TVM_DIR)/include \
-I$(TVM_DIR)/dlpack/include \
-I$(TVM_DIR)/dmlc-core/include
pkg_ldflags := -L$(TVM_DIR)/build -ltvm_runtime
ifneq ($(DEBUG), false)
debug := debug
enclave_cflags += -Og -g
pkg_cflags += -Og -g
else
debug := release
enclave_cflags += -O2
pkg_cflags += -O2
endif
build_dir := build
enclave_cflags := \
-I$(SGX_SDK)/include \
-I$(SGX_SDK)/include/tlibc \
-I$(SGX_SDK)/include/stdport \
-I$(SGX_SDK)/include/epid \
-I$(TVM_DIR)/include \
-I$(TVM_DIR)/dlpack/include \
-I$(TVM_DIR)/dmlc-core/include
enclave_ldflags :=\
-L$(build_dir) -L$(TVM_DIR)/build \
-Wl,--no-undefined -nostdlib -nodefaultlibs -nostartfiles -L$(SGX_SDK)/lib64\
-Wl,--whole-archive -l$(trts_library_name) -Wl,--no-whole-archive\
-Wl,--start-group\
-lsgx_tstdc -lsgx_tstdcxx -lsgx_tcxx -lsgx_tcrypto -lsgx_tkey_exchange -l$(tservice_library_name)\
-lenclave -ltvm_t\
-Wl,--end-group\
-Wl,-Bstatic -Wl,-Bsymbolic -Wl,--no-undefined\
-Wl,-pie,-eenclave_entry -Wl,--export-dynamic\
-Wl,--defsym,__ImageBase=0 -Wl,--gc-sections\
-Wl,--version-script=enclave/enclave.lds
.PHONY: enclave clean
enclave: $(build_dir)/enclave.signed.so
$(build_dir)/enclave.signed.so: $(build_dir)/enclave.so build/enclave_config.xml enclave/enclave.pem
$(sgx_enclave_signer) sign -key enclave/enclave.pem -enclave $< -out $@ -config build/enclave_config.xml
enclave/enclave.pem:
curl -sSo $@ 'https://gist.githubusercontent.com/nhynes/8a2d80068a92e672f8b0b7d710ceb404/raw/2d5ae5fbe83198ede49465fdc6535065e093543b/tvm_sgx_demo.pem'
build/enclave_config.xml: enclave/enclave_config.xml.in
cpp $^ -P -o $@ -DNUM_THREADS=$$(( $(NUM_THREADS) + 1 ))
$(build_dir)/enclave.so: $(build_dir)/libenclave.a $(TVM_DIR)/build/libtvm_t.a
$(CXX) $< -o $@ $(enclave_ldflags) $(enclave_cflags) -ltvm_t
$(build_dir)/libenclave.a: enclave/target/x86_64-unknown-linux-sgx/$(debug)/libmodel_enclave.a
@mkdir -p $(@D)
@cp $< $@
enclave/target/x86_64-unknown-linux-sgx/$(debug)/libmodel_enclave.a: enclave/**/*
$(MAKE) -C enclave
clean:
$(MAKE) -s -C enclave clean
rm -rf build
...@@ -15,64 +15,10 @@ ...@@ -15,64 +15,10 @@
<!--- specific language governing permissions and limitations --> <!--- specific language governing permissions and limitations -->
<!--- under the License. --> <!--- under the License. -->
# TVM in Intel SGX Example ## Setup
This application demonstrates the use of a simple TVM model in the [Intel SGX](https://software.intel.com/en-us/blogs/2013/09/26/protecting-application-secrets-with-intel-sgx) trusted computing environment. 1. [Install the Fortanix Enclave Development Platform](https://edp.fortanix.com/docs/installation/guide/)
2. `rustup component add llvm-tools-preview` to get `llvm-ar` and `llvm-objcopy`
## Prerequisites 3. `pip install numpy decorator psutil`
4. `cargo run` to start the enclave TCP server
1. The TVM premade Docker image 5. Send a 28x28 "image" to the enclave model server using `head -c $((28*28*4)) /dev/urandom | nc 127.0.0.1 4242 | python read_results.py`
or
1. A GNU/Linux environment
2. TVM compiled with LLVM and SGX; and the `tvm` Python module
3. The [Linux SGX SDK](https://github.com/intel/linux-sgx) [link to pre-built libraries](https://01.org/intel-software-guard-extensions/downloads)
4. [Rust](https://rustup.sh)
5. The [rust-sgx-sdk](https://github.com/baidu/rust-sgx-sdk)
6. [xargo](https://github.com/japaric/xargo)
Check out the `/tvm/install/ubuntu_install_sgx.sh` for the commands to get these dependencies.
## Running the example
If using Docker, start by running
```
git clone --recursive https://github.com/apache/incubator-tvm.git tvm
docker run --rm -it -v $(pwd)/tvm:/mnt tvmai/ci-cpu /bin/bash
```
then, in the container
```
cd /mnt
mkdir build && cd build
cmake .. -DUSE_LLVM=ON -DUSE_SGX=/opt/sgxsdk -DRUST_SGX_SDK=/opt/rust-sgx-sdk
make -j4
cd ..
pip install -e python -e topi/python
cd apps/sgx
```
Once TVM is build and installed, just
`./run_example.sh`
If everything goes well, you should see a lot of build messages and below them
the text `It works!`.
## High-level overview
First of all, it helps to think of an SGX enclave as a library that can be called
to perform trusted computation.
In this library, one can use other libraries like TVM.
Building this example performs the following steps:
1. Creates a simple TVM module that computes `x + 1` and save it as a system library.
2. Builds a TVM runtime that links the module and allows running it using the TVM Python runtime.
3. Packages the bundle into an SGX enclave
4. Runs the enclave using the usual TVM Python `module` API
For more information on building, please refer to the `Makefile`.
For more information on the TVM module, please refer to `../howto_deploy`.
For more in formation on SGX enclaves, please refer to the [SGX Enclave Demo](https://github.com/intel/linux-sgx/tree/master/SampleCode/SampleEnclave/)
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
use std::process::Command;
macro_rules! mf_dir {
($p:literal) => {
concat!(env!("CARGO_MANIFEST_DIR"), $p)
};
}
fn main() {
let out_dir = std::env::var("OUT_DIR").unwrap();
let build_output = Command::new(mf_dir!("/src/build_model.py"))
.arg(&out_dir)
.env(
"PYTHONPATH",
concat!(
mf_dir!("/../../python"),
":",
mf_dir!("/../../nnvm/python"),
":",
mf_dir!("/../../topi/python")
),
)
.output()
.expect("Failed to build model");
assert!(
["model.o", "graph.json", "params.bin"]
.iter()
.all(|f| { std::path::Path::new(&format!("{}/{}", out_dir, f)).exists() }),
"Could not build tvm lib: STDOUT:\n\n{}\n\nSTDERR\n\n{}",
String::from_utf8(build_output.stdout).unwrap().trim(),
String::from_utf8(build_output.stderr).unwrap().trim()
);
let sysroot_output = Command::new("rustc")
.args(&["--print", "sysroot"])
.output()
.expect("Failed to get sysroot");
let sysroot = String::from_utf8(sysroot_output.stdout).unwrap();
let sysroot = sysroot.trim();
let mut llvm_tools_path = std::path::PathBuf::from(&sysroot);
llvm_tools_path.push("lib/rustlib/x86_64-unknown-linux-gnu/bin");
Command::new("rustup")
.args(&["component", "add", "llvm-tools-preview"])
.output()
.expect("failed to install llvm tools");
std::process::Command::new(llvm_tools_path.join("llvm-objcopy"))
.arg("--globalize-symbol=__tvm_module_startup")
.arg("--remove-section=.ctors")
.arg(&format!("{}/model.o", out_dir))
.output()
.expect("gould not gloablize startup function");
std::process::Command::new(llvm_tools_path.join("llvm-ar"))
.arg("rcs")
.arg(&format!("{}/libmodel.a", out_dir))
.arg(&format!("{}/model.o", out_dir))
.output()
.expect("failed to package model archive");
println!("cargo:rustc-link-lib=static=model");
println!("cargo:rustc-link-search=native={}", out_dir);
}
../../../rust/.rustfmt.toml
\ No newline at end of file
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
MODEL ?= resnet
NUM_THREADS ?= 4
BATCH_SIZE ?= 64
TRAINING ?= true
DEBUG ?= false
build_dir := ../build
ifeq ($(DEBUG), false)
debug := release
xargo_args := --release
else
debug := debug
endif
target=target/x86_64-unknown-linux-sgx/$(debug)/libmodel-enclave.a
$(target): $(build_dir)/libmodel.a **/* $(TVM_DIR)/rust/patched.txt
RUST_TARGET_PATH=$(shell pwd) \
RUST_TARGET_DIR=$(shell pwd)/target \
RUSTFLAGS="-Z force-unstable-if-unmarked" \
TVM_NUM_THREADS=$(NUM_THREADS) \
BUILD_DIR=../build \
xargo build --target x86_64-unknown-linux-sgx $(xargo_args) -q
$(TVM_DIR)/rust/patched.txt: $(shell pwd)/sgx-deps.diff
echo $(TVM_DIR)
cd $(TVM_DIR) && git apply $<
touch $@
$(build_dir)/libmodel.a: $(build_dir)/model.o
$(AR) cr $@ $^
$(build_dir)/model.o: $(build_dir)/model.bc
$(CC) -c $< -o $@ -fPIC -O3
objcopy --globalize-symbol __tvm_module_startup $@
$(build_dir)/model.bc: src/build_model.py
python3 $< -o $(build_dir)
clean:
xargo clean
enclave.so
{
global:
g_global_data_sim;
g_global_data;
enclave_entry;
local:
*;
};
<EnclaveConfiguration>
<ProdID>0</ProdID>
<ISVSVN>0</ISVSVN>
<StackMaxSize>0xf0000</StackMaxSize>
<HeapMaxSize>0xf000000</HeapMaxSize>
<TCSNum>NUM_THREADS</TCSNum>
<TCSPolicy>0</TCSPolicy> <!-- must be "bound" to use thread_local -->
<DisableDebug>0</DisableDebug>
<MiscSelect>0</MiscSelect>
<MiscMask>0xFFFFFFFF</MiscMask>
</EnclaveConfiguration>
diff --git a/rust/Cargo.toml b/rust/Cargo.toml
index 0819e0c7..e56f4ef2 100644
--- a/rust/Cargo.toml
+++ b/rust/Cargo.toml
@@ -14,7 +14,7 @@ default = ["nom/std"]
sgx = ["nom/alloc"]
[dependencies]
-bounded-spsc-queue = "0.4.0"
+bounded-spsc-queue = { git = "https://github.com/nhynes/bounded-spsc-queue", branch = "sgx" }
error-chain = { version = "0.12.0", default-features = false }
itertools = "0.7.8"
lazy_static = "1.1.0"
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#![feature(try_from)]
#[macro_use]
extern crate lazy_static;
#[macro_use]
extern crate tvm;
use std::{
convert::{TryFrom, TryInto},
sync::Mutex,
};
use tvm::{
ffi::runtime::DLTensor,
runtime::{
load_param_dict, sgx, Graph, GraphExecutor, SystemLibModule, TVMArgValue, TVMRetValue, Tensor,
},
};
lazy_static! {
static ref SYSLIB: SystemLibModule = { SystemLibModule::default() };
static ref MODEL: Mutex<GraphExecutor<'static, 'static>> = {
let graph_json = include_str!(concat!("../", env!("BUILD_DIR"), "/graph.json"));
let params_bytes = include_bytes!(concat!("../", env!("BUILD_DIR"), "/params.bin"));
let params = load_param_dict(params_bytes).unwrap();
let graph = Graph::try_from(graph_json).unwrap();
let mut exec = GraphExecutor::new(graph, &*SYSLIB).unwrap();
exec.load_params(params);
Mutex::new(exec)
};
}
fn ecall_init(_args: &[TVMArgValue]) -> TVMRetValue {
lazy_static::initialize(&MODEL);
TVMRetValue::from(0)
}
fn ecall_main(args: &[TVMArgValue<'static>]) -> TVMRetValue {
let mut model = MODEL.lock().unwrap();
let inp = args[0].try_into().unwrap();
let mut out: Tensor = args[1].try_into().unwrap();
model.set_input("data", inp);
model.run();
sgx::shutdown();
out.copy(model.get_output(0).unwrap());
TVMRetValue::from(1)
}
pub mod ecalls {
//! todo: generate this using proc_macros
use super::*;
use std::{
ffi::CString,
mem,
os::raw::{c_char, c_int, c_void},
slice,
};
use tvm::{
ffi::runtime::{TVMRetValueHandle, TVMValue},
runtime::{
sgx::{ocall_packed_func, run_worker, SgxStatus},
DataType, PackedFunc,
},
};
macro_rules! tvm_ocall {
($func: expr) => {
match $func {
0 => Ok(()),
err => Err(err),
}
};
}
const ECALLS: &'static [&'static str] = &["__tvm_run_worker__", "__tvm_main__", "init"];
pub type EcallPackedFunc = Box<Fn(&[TVMArgValue<'static>]) -> TVMRetValue + Send + Sync>;
lazy_static! {
static ref ECALL_FUNCS: Vec<EcallPackedFunc> = {
vec![
Box::new(run_worker),
Box::new(ecall_main),
Box::new(ecall_init),
]
};
}
extern "C" {
fn __tvm_module_startup() -> ();
fn tvm_ocall_register_export(name: *const c_char, func_id: c_int) -> SgxStatus;
}
#[no_mangle]
pub extern "C" fn tvm_ecall_init(_ret: TVMRetValueHandle) {
unsafe {
__tvm_module_startup();
ECALLS.into_iter().enumerate().for_each(|(i, ecall)| {
tvm_ocall!(tvm_ocall_register_export(
CString::new(*ecall).unwrap().as_ptr(),
i as i32
))
.expect(&format!("Error registering `{}`", ecall));
});
}
}
#[no_mangle]
pub extern "C" fn tvm_ecall_packed_func(
func_id: c_int,
arg_values: *const TVMValue,
type_codes: *const c_int,
num_args: c_int,
ret_val: *mut TVMValue,
ret_type_code: *mut i64,
) {
let args = unsafe {
let values = slice::from_raw_parts(arg_values, num_args as usize);
let type_codes = slice::from_raw_parts(type_codes, num_args as usize);
values
.into_iter()
.zip(type_codes.into_iter())
.map(|(v, t)| TVMArgValue::new(*v, *t as i64))
.collect::<Vec<TVMArgValue<'static>>>()
};
let (rv, tc) = ECALL_FUNCS[func_id as usize](&args).into_tvm_value();
unsafe {
*ret_val = rv;
*ret_type_code = tc;
}
}
}
{
"arch": "x86_64",
"cpu": "x86-64",
"data-layout": "e-m:e-i64:64-f80:128-n8:16:32:64-S128",
"dynamic-linking": true,
"env": "sgx",
"exe-allocation-crate": "alloc_system",
"executables": true,
"has-elf-tls": true,
"has-rpath": true,
"linker-flavor": "gcc",
"linker-is-gnu": true,
"llvm-target": "x86_64-unknown-linux-gnu",
"max-atomic-width": 64,
"os": "linux",
"position-independent-executables": true,
"pre-link-args": {
"gcc": [
"-Wl,--as-needed",
"-Wl,-z,noexecstack",
"-m64"
]
},
"relro-level": "full",
"stack-probes": true,
"target-c-int-width": "32",
"target-endian": "little",
"target-family": "unix",
"target-pointer-width": "64",
"vendor": "unknown"
}
...@@ -15,16 +15,14 @@ ...@@ -15,16 +15,14 @@
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
[dependencies] import struct
alloc = {} import sys
panic_unwind = {}
panic_abort = {}
[dependencies.std] import numpy as np
path = "/opt/rust-sgx-sdk/xargo/sgx_tstd"
features = ["backtrace", "stdio", "untrusted_time"]
stage = 2
[dependencies.xargo_sgx_rand] def float_bytes(l):
path = "/opt/rust-sgx-sdk/xargo/sgx_rand" for i in range(0, len(l), 4):
stage = 3 yield l[i:i + 4]
floats = [struct.unpack('f', f)[0] for f in float_bytes(sys.stdin.buffer.read())]
print(np.array(floats))
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import os.path as osp
import numpy as np
import tvm
from tvm import te
CWD = osp.abspath(osp.dirname(__file__))
def main():
ctx = tvm.context('cpu', 0)
model = tvm.runtime.load_module(osp.join(CWD, 'build', 'enclave.signed.so'))
inp = tvm.nd.array(np.ones((1, 3, 224, 224), dtype='float32'), ctx)
out = tvm.nd.array(np.empty((1, 1000), dtype='float32'), ctx)
model(inp, out)
if abs(out.asnumpy().sum() - 1) < 0.001:
print('It works!')
else:
print('It doesn\'t work!')
exit(1)
if __name__ == '__main__':
main()
#!/usr/bin/python3
# Licensed to the Apache Software Foundation (ASF) under one # Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file # or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information # distributed with this work for additional information
...@@ -14,11 +16,12 @@ ...@@ -14,11 +16,12 @@
# KIND, either express or implied. See the License for the # KIND, either express or implied. See the License for the
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
"""Creates a simple TVM modules.""" """Creates a simple TVM modules."""
import argparse
import os import os
from os import path as osp from os import path as osp
import sys
from tvm import relay from tvm import relay
from tvm.relay import testing from tvm.relay import testing
...@@ -27,9 +30,8 @@ from tvm import te ...@@ -27,9 +30,8 @@ from tvm import te
def main(): def main():
parser = argparse.ArgumentParser() dshape = (1, 28, 28)
parser.add_argument('-o', '--out-dir', default='.') net, params = relay.testing.mlp.get_workload(batch_size=dshape[0], dtype='float32')
opts = parser.parse_args()
dshape = (1, 3, 224, 224) dshape = (1, 3, 224, 224)
net, params = relay.testing.resnet.get_workload( net, params = relay.testing.resnet.get_workload(
...@@ -39,11 +41,11 @@ def main(): ...@@ -39,11 +41,11 @@ def main():
graph, lib, params = relay.build( graph, lib, params = relay.build(
net, 'llvm --system-lib', params=params) net, 'llvm --system-lib', params=params)
build_dir = osp.abspath(opts.out_dir) build_dir = osp.abspath(sys.argv[1])
if not osp.isdir(build_dir): if not osp.isdir(build_dir):
os.makedirs(build_dir, exist_ok=True) os.makedirs(build_dir, exist_ok=True)
lib.save(osp.join(build_dir, 'model.bc')) lib.save(osp.join(build_dir, 'model.o'))
with open(osp.join(build_dir, 'graph.json'), 'w') as f_graph_json: with open(osp.join(build_dir, 'graph.json'), 'w') as f_graph_json:
f_graph_json.write(graph) f_graph_json.write(graph)
with open(osp.join(build_dir, 'params.bin'), 'wb') as f_params: with open(osp.join(build_dir, 'params.bin'), 'wb') as f_params:
......
...@@ -17,12 +17,35 @@ ...@@ -17,12 +17,35 @@
* under the License. * under the License.
*/ */
use std::env; extern crate tvm_runtime;
use std::{
convert::TryFrom as _,
io::{Read as _, Write as _},
};
fn main() { fn main() {
println!( let syslib = tvm_runtime::SystemLibModule::default();
"cargo:rustc-link-search=native={}",
env::var("BUILD_DIR").unwrap() let graph_json = include_str!(concat!(env!("OUT_DIR"), "/graph.json"));
); let params_bytes = include_bytes!(concat!(env!("OUT_DIR"), "/params.bin"));
println!("cargo:rustc-link-lib=static=model"); let params = tvm_runtime::load_param_dict(params_bytes).unwrap();
let graph = tvm_runtime::Graph::try_from(graph_json).unwrap();
let mut exec = tvm_runtime::GraphExecutor::new(graph, &syslib).unwrap();
exec.load_params(params);
let listener = std::net::TcpListener::bind("127.0.0.1:4242").unwrap();
for stream in listener.incoming() {
let mut stream = stream.unwrap();
if let Err(_) =
stream.read_exact(exec.get_input("data").unwrap().data().view().as_mut_slice())
{
continue;
}
exec.run();
if let Err(_) = stream.write_all(exec.get_output(0).unwrap().data().as_slice()) {
continue;
}
}
} }
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
if(NOT USE_SGX STREQUAL "OFF")
set(_sgx_src ${CMAKE_CURRENT_SOURCE_DIR}/src/runtime/sgx)
set(_tvm_u_h ${_sgx_src}/untrusted/tvm_u.h)
set(_tvm_t_h ${_sgx_src}/trusted/tvm_t.h)
set(_tvm_t_c ${_sgx_src}/trusted/tvm_t.c)
set(_tvm_edl ${_sgx_src}/tvm.edl)
set(_sgx_ustdc ${RUST_SGX_SDK}/sgx_ustdc)
set(_urts_lib "sgx_urts")
if(NOT SGX_MODE STREQUAL "HW")
message(STATUS "Build with SGX support (SIM)")
set(_urts_lib "${_urts_lib}_sim")
else()
message(STATUS "Build with SGX support (HW)")
endif()
# build edge routines
add_custom_command(
OUTPUT ${_tvm_u_h}
COMMAND ${USE_SGX}/bin/x64/sgx_edger8r --untrusted
--untrusted --untrusted-dir ${_sgx_src}/untrusted
--trusted --trusted-dir ${_sgx_src}/trusted
--search-path ${USE_SGX}/include --search-path ${RUST_SGX_SDK}/edl
${_tvm_edl}
COMMAND sed -i "4i '#include <tvm/runtime/c_runtime_api.h>'" ${_tvm_u_h}
COMMAND sed -i "4i '#include <tvm/runtime/c_runtime_api.h>'" ${_tvm_t_h}
DEPENDS ${_tvm_edl}
)
add_custom_command(
OUTPUT ${_sgx_ustdc}/libsgx_ustdc.a
COMMAND make
WORKING_DIRECTORY ${_sgx_ustdc}
)
add_custom_target(sgx_edl DEPENDS ${_tvm_u_h} ${_sgx_ustdc}/libsgx_ustdc.a)
# build trusted library
set_source_files_properties(${_tvm_t_c} PROPERTIES GENERATED TRUE)
add_library(tvm_t STATIC ${_tvm_t_c})
add_dependencies(tvm_t sgx_edl)
target_include_directories(tvm_t PUBLIC ${USE_SGX}/include ${USE_SGX}/include/tlibc)
# add untrusted runtime files
include_directories(${USE_SGX}/include)
file(GLOB RUNTIME_SGX_SRCS ${_sgx_src}/untrusted/*.c*)
list(APPEND TVM_RUNTIME_LINKER_LIBS
-lpthread
-L${USE_SGX}/lib64 -l${_urts_lib}
-L${RUST_SGX_SDK}/sgx_ustdc -lsgx_ustdc)
list(APPEND RUNTIME_SRCS ${RUNTIME_SGX_SRCS})
include_directories(${RUST_SGX_SDK}/edl ${RUST_SGX_SDK}/common)
endif()
...@@ -31,8 +31,6 @@ echo set\(USE_RPC ON\) >> config.cmake ...@@ -31,8 +31,6 @@ echo set\(USE_RPC ON\) >> config.cmake
echo set\(USE_SORT ON\) >> config.cmake echo set\(USE_SORT ON\) >> config.cmake
echo set\(USE_GRAPH_RUNTIME ON\) >> config.cmake echo set\(USE_GRAPH_RUNTIME ON\) >> config.cmake
echo set\(USE_BLAS openblas\) >> config.cmake echo set\(USE_BLAS openblas\) >> config.cmake
echo set\(USE_SGX /opt/sgxsdk\) >> config.cmake
echo set\(RUST_SGX_SDK /opt/rust-sgx-sdk\) >> config.cmake
mkdir -p build mkdir -p build
cd build cd build
cmake .. cmake ..
......
#!/bin/bash
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
set -e
set -u
set -o pipefail
apt-get update && apt-get install -y --no-install-recommends \
build-essential git cmake \
wget python pkg-config software-properties-common \
autoconf automake libtool ocaml \
protobuf-compiler libprotobuf-dev \
libssl-dev libcurl4-openssl-dev curl
git clone --branch=sgx_2.2 --depth=1 https://github.com/intel/linux-sgx.git
cd linux-sgx
curl -s -S -L 'https://gist.githubusercontent.com/nhynes/c770b0e91610f8c020a8d1a803a1e7cb/raw/8f5372d9cb88929b3cc49a384943bb363bc06827/intel-sgx.patch' | git apply
./download_prebuilt.sh
make -j4 sdk && make -j4 sdk_install_pkg
./linux/installer/bin/sgx_linux_x64_sdk*.bin --prefix /opt
cd -
git clone --branch=v1.0.5 --depth=1 https://github.com/baidu/rust-sgx-sdk.git /opt/rust-sgx-sdk
cd /opt/rust-sgx-sdk
curl -s -S -L 'https://gist.githubusercontent.com/nhynes/37164039c5d3f33aa4f123e4ba720036/raw/b0de575fe937231799930764e76c664b92975163/rust-sgx-sdk.diff' | git apply
cd -
...@@ -221,7 +221,6 @@ inline const char* DeviceName(int type) { ...@@ -221,7 +221,6 @@ inline const char* DeviceName(int type) {
} }
} }
#ifndef _LIBCPP_SGX_NO_IOSTREAMS
inline std::ostream& operator<<(std::ostream& os, DLContext ctx) { // NOLINT(*) inline std::ostream& operator<<(std::ostream& os, DLContext ctx) { // NOLINT(*)
int device_type = static_cast<int>(ctx.device_type); int device_type = static_cast<int>(ctx.device_type);
if (device_type > kRPCSessMask) { if (device_type > kRPCSessMask) {
...@@ -231,8 +230,6 @@ inline std::ostream& operator<<(std::ostream& os, DLContext ctx) { // NOLINT(*) ...@@ -231,8 +230,6 @@ inline std::ostream& operator<<(std::ostream& os, DLContext ctx) { // NOLINT(*)
os << runtime::DeviceName(device_type) << "(" << ctx.device_id << ")"; os << runtime::DeviceName(device_type) << "(" << ctx.device_id << ")";
return os; return os;
} }
#endif
} // namespace runtime } // namespace runtime
} // namespace tvm } // namespace tvm
#endif // TVM_RUNTIME_DEVICE_API_H_ #endif // TVM_RUNTIME_DEVICE_API_H_
...@@ -24,9 +24,6 @@ ...@@ -24,9 +24,6 @@
#ifndef TVM_RUNTIME_PACKED_FUNC_H_ #ifndef TVM_RUNTIME_PACKED_FUNC_H_
#define TVM_RUNTIME_PACKED_FUNC_H_ #define TVM_RUNTIME_PACKED_FUNC_H_
#ifndef _LIBCPP_SGX_NO_IOSTREAMS
#include <sstream>
#endif
#include <dmlc/logging.h> #include <dmlc/logging.h>
#include <tvm/runtime/c_runtime_api.h> #include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/module.h> #include <tvm/runtime/module.h>
...@@ -1019,7 +1016,6 @@ inline const char* TypeCode2Str(int type_code) { ...@@ -1019,7 +1016,6 @@ inline const char* TypeCode2Str(int type_code) {
} }
} }
#ifndef _LIBCPP_SGX_NO_IOSTREAMS
inline std::ostream& operator<<(std::ostream& os, DLDataType t) { // NOLINT(*) inline std::ostream& operator<<(std::ostream& os, DLDataType t) { // NOLINT(*)
if (t.bits == 1 && t.lanes == 1 && t.code == kDLUInt) { if (t.bits == 1 && t.lanes == 1 && t.code == kDLUInt) {
os << "bool"; return os; os << "bool"; return os;
...@@ -1041,30 +1037,11 @@ inline std::ostream& operator<<(std::ostream& os, const DataType& dtype) { // NO ...@@ -1041,30 +1037,11 @@ inline std::ostream& operator<<(std::ostream& os, const DataType& dtype) { // NO
return os << dtype.operator DLDataType(); return os << dtype.operator DLDataType();
} }
#endif
inline std::string DLDataType2String(DLDataType t) { inline std::string DLDataType2String(DLDataType t) {
if (t.bits == 0) return ""; if (t.bits == 0) return "";
#ifndef _LIBCPP_SGX_NO_IOSTREAMS
std::ostringstream os; std::ostringstream os;
os << t; os << t;
return os.str(); return os.str();
#else
if (t.bits == 1 && t.lanes == 1 && t.code == kDLUInt) {
return "bool";
}
if (t.code < kTVMCustomBegin) {
repr += TypeCode2Str(t.code);
} else {
repr += "custom[" + GetCustomTypeName(t.code) + "]";
}
if (t.code == kTVMOpaqueHandle) return repr;
repr += std::to_string(static_cast<int>(t.bits));
if (t.lanes != 1) {
repr += "x" + std::to_string(static_cast<int>(t.lanes));
}
return repr;
#endif
} }
inline DLDataType String2DLDataType(std::string s) { inline DLDataType String2DLDataType(std::string s) {
......
...@@ -22,8 +22,10 @@ newline_style = "Auto" ...@@ -22,8 +22,10 @@ newline_style = "Auto"
use_small_heuristics = "Default" use_small_heuristics = "Default"
indent_style = "Block" indent_style = "Block"
wrap_comments = false wrap_comments = false
format_code_in_doc_comments = false
comment_width = 80 comment_width = 80
normalize_comments = false normalize_comments = false
normalize_doc_attributes = false
format_strings = false format_strings = false
format_macro_matchers = false format_macro_matchers = false
format_macro_bodies = true format_macro_bodies = true
...@@ -44,10 +46,12 @@ spaces_around_ranges = false ...@@ -44,10 +46,12 @@ spaces_around_ranges = false
binop_separator = "Front" binop_separator = "Front"
remove_nested_parens = true remove_nested_parens = true
combine_control_expr = true combine_control_expr = true
overflow_delimited_expr = false
struct_field_align_threshold = 0 struct_field_align_threshold = 0
enum_discrim_align_threshold = 0
match_arm_blocks = true match_arm_blocks = true
force_multiline_blocks = false force_multiline_blocks = false
fn_args_density = "Tall" fn_args_layout = "Tall"
brace_style = "SameLineWhere" brace_style = "SameLineWhere"
control_brace_style = "AlwaysSameLine" control_brace_style = "AlwaysSameLine"
trailing_semicolon = true trailing_semicolon = true
...@@ -56,8 +60,10 @@ match_block_trailing_comma = false ...@@ -56,8 +60,10 @@ match_block_trailing_comma = false
blank_lines_upper_bound = 1 blank_lines_upper_bound = 1
blank_lines_lower_bound = 0 blank_lines_lower_bound = 0
edition = "2018" edition = "2018"
version = "One"
inline_attribute_width = 0
merge_derives = true merge_derives = true
use_try_shorthand = true use_try_shorthand = false
use_field_init_shorthand = false use_field_init_shorthand = false
force_explicit_abi = true force_explicit_abi = true
condense_wildcard_suffixes = false condense_wildcard_suffixes = false
...@@ -66,8 +72,8 @@ unstable_features = false ...@@ -66,8 +72,8 @@ unstable_features = false
disable_all_formatting = false disable_all_formatting = false
skip_children = false skip_children = false
hide_parse_errors = false hide_parse_errors = false
error_on_line_overflow = true error_on_line_overflow = false
error_on_unformatted = true error_on_unformatted = false
report_todo = "Never" report_todo = "Never"
report_fixme = "Never" report_fixme = "Never"
ignore = [] ignore = []
......
...@@ -19,6 +19,7 @@ ...@@ -19,6 +19,7 @@
members = [ members = [
"common", "common",
"macros", "macros",
"macros_raw",
"runtime", "runtime",
"runtime/tests/test_tvm_basic", "runtime/tests/test_tvm_basic",
"runtime/tests/test_tvm_dso", "runtime/tests/test_tvm_dso",
......
...@@ -26,8 +26,8 @@ edition = "2018" ...@@ -26,8 +26,8 @@ edition = "2018"
bindings = [] bindings = []
[dependencies] [dependencies]
failure = "0.1.5" failure = { version = "0.1", default-features = false, features = ["derive"] }
ndarray = "0.12.1" ndarray = "0.12"
[build-dependencies] [build-dependencies]
bindgen = "0.37.4" bindgen = "0.51"
...@@ -23,10 +23,10 @@ use std::path::PathBuf; ...@@ -23,10 +23,10 @@ use std::path::PathBuf;
fn main() { fn main() {
let tvm_home = option_env!("TVM_HOME").map(str::to_string).unwrap_or({ let tvm_home = option_env!("TVM_HOME").map(str::to_string).unwrap_or({
let tvm_home = PathBuf::from(env!("CARGO_MANIFEST_DIR")) let crate_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR"))
.canonicalize() .canonicalize()
.unwrap(); .unwrap();
tvm_home crate_dir
.parent() .parent()
.unwrap() .unwrap()
.parent() .parent()
...@@ -46,6 +46,7 @@ fn main() { ...@@ -46,6 +46,7 @@ fn main() {
.header(format!("{}/include/tvm/runtime/c_runtime_api.h", tvm_home)) .header(format!("{}/include/tvm/runtime/c_runtime_api.h", tvm_home))
.header(format!("{}/include/tvm/runtime/c_backend_api.h", tvm_home)) .header(format!("{}/include/tvm/runtime/c_backend_api.h", tvm_home))
.clang_arg(format!("-I{}/3rdparty/dlpack/include/", tvm_home)) .clang_arg(format!("-I{}/3rdparty/dlpack/include/", tvm_home))
.clang_arg(format!("-I{}/include/", tvm_home))
.blacklist_type("max_align_t") .blacklist_type("max_align_t")
.layout_tests(false) .layout_tests(false)
.derive_partialeq(true) .derive_partialeq(true)
......
...@@ -20,8 +20,6 @@ ...@@ -20,8 +20,6 @@
//! This crate contains the refactored basic components required //! This crate contains the refactored basic components required
//! for `runtime` and `frontend` TVM crates. //! for `runtime` and `frontend` TVM crates.
#![feature(box_syntax, trait_alias)]
#[macro_use] #[macro_use]
extern crate failure; extern crate failure;
...@@ -44,5 +42,5 @@ pub mod packed_func; ...@@ -44,5 +42,5 @@ pub mod packed_func;
pub mod value; pub mod value;
pub use errors::*; pub use errors::*;
pub use ffi::{TVMByteArray, TVMContext, TVMType}; pub use ffi::{TVMByteArray, TVMContext, DLDataType as TVMType};
pub use packed_func::{TVMArgValue, TVMRetValue}; pub use packed_func::{TVMArgValue, TVMRetValue};
...@@ -26,8 +26,10 @@ use std::{ ...@@ -26,8 +26,10 @@ use std::{
pub use crate::ffi::TVMValue; pub use crate::ffi::TVMValue;
use crate::{errors::ValueDowncastError, ffi::*}; use crate::{errors::ValueDowncastError, ffi::*};
pub trait PackedFunc = pub trait PackedFunc : Fn(&[TVMArgValue]) -> Result<TVMRetValue, crate::errors::FuncCallError> + Send + Sync {}
Fn(&[TVMArgValue]) -> Result<TVMRetValue, crate::errors::FuncCallError> + Send + Sync;
impl<T> PackedFunc for T
where T : Fn(&[TVMArgValue]) -> Result<TVMRetValue, crate::errors::FuncCallError> + Send + Sync {}
/// Calls a packed function and returns a `TVMRetValue`. /// Calls a packed function and returns a `TVMRetValue`.
/// ///
...@@ -66,7 +68,7 @@ macro_rules! TVMPODValue { ...@@ -66,7 +68,7 @@ macro_rules! TVMPODValue {
UInt(i64), UInt(i64),
Float(f64), Float(f64),
Null, Null,
Type(TVMType), DataType(DLDataType),
String(CString), String(CString),
Context(TVMContext), Context(TVMContext),
Handle(*mut c_void), Handle(*mut c_void),
...@@ -87,15 +89,15 @@ macro_rules! TVMPODValue { ...@@ -87,15 +89,15 @@ macro_rules! TVMPODValue {
DLDataTypeCode_kDLInt => Int($value.v_int64), DLDataTypeCode_kDLInt => Int($value.v_int64),
DLDataTypeCode_kDLUInt => UInt($value.v_int64), DLDataTypeCode_kDLUInt => UInt($value.v_int64),
DLDataTypeCode_kDLFloat => Float($value.v_float64), DLDataTypeCode_kDLFloat => Float($value.v_float64),
TVMTypeCode_kNull => Null, TVMTypeCode_kTVMNullptr => Null,
TVMTypeCode_kTVMType => Type($value.v_type), TVMTypeCode_kTVMDataType => DataType($value.v_type),
TVMTypeCode_kTVMContext => Context($value.v_ctx), TVMTypeCode_kTVMContext => Context($value.v_ctx),
TVMTypeCode_kHandle => Handle($value.v_handle), TVMTypeCode_kTVMOpaqueHandle => Handle($value.v_handle),
TVMTypeCode_kArrayHandle => ArrayHandle($value.v_handle as TVMArrayHandle), TVMTypeCode_kTVMDLTensorHandle => ArrayHandle($value.v_handle as TVMArrayHandle),
TVMTypeCode_kObjectHandle => ObjectHandle($value.v_handle), TVMTypeCode_kTVMObjectHandle => ObjectHandle($value.v_handle),
TVMTypeCode_kModuleHandle => ModuleHandle($value.v_handle), TVMTypeCode_kTVMModuleHandle => ModuleHandle($value.v_handle),
TVMTypeCode_kFuncHandle => FuncHandle($value.v_handle), TVMTypeCode_kTVMPackedFuncHandle => FuncHandle($value.v_handle),
TVMTypeCode_kNDArrayContainer => NDArrayContainer($value.v_handle), TVMTypeCode_kTVMNDArrayHandle => NDArrayContainer($value.v_handle),
$( $tvm_type => { $from_tvm_type } ),+ $( $tvm_type => { $from_tvm_type } ),+
_ => unimplemented!("{}", type_code), _ => unimplemented!("{}", type_code),
} }
...@@ -108,31 +110,31 @@ macro_rules! TVMPODValue { ...@@ -108,31 +110,31 @@ macro_rules! TVMPODValue {
Int(val) => (TVMValue { v_int64: *val }, DLDataTypeCode_kDLInt), Int(val) => (TVMValue { v_int64: *val }, DLDataTypeCode_kDLInt),
UInt(val) => (TVMValue { v_int64: *val as i64 }, DLDataTypeCode_kDLUInt), UInt(val) => (TVMValue { v_int64: *val as i64 }, DLDataTypeCode_kDLUInt),
Float(val) => (TVMValue { v_float64: *val }, DLDataTypeCode_kDLFloat), Float(val) => (TVMValue { v_float64: *val }, DLDataTypeCode_kDLFloat),
Null => (TVMValue{ v_int64: 0 },TVMTypeCode_kNull), Null => (TVMValue{ v_int64: 0 },TVMTypeCode_kTVMNullptr),
Type(val) => (TVMValue { v_type: *val }, TVMTypeCode_kTVMType), DataType(val) => (TVMValue { v_type: *val }, TVMTypeCode_kTVMDataType),
Context(val) => (TVMValue { v_ctx: val.clone() }, TVMTypeCode_kTVMContext), Context(val) => (TVMValue { v_ctx: val.clone() }, TVMTypeCode_kTVMContext),
String(val) => { String(val) => {
( (
TVMValue { v_handle: val.as_ptr() as *mut c_void }, TVMValue { v_handle: val.as_ptr() as *mut c_void },
TVMTypeCode_kStr, TVMTypeCode_kTVMStr,
) )
} }
Handle(val) => (TVMValue { v_handle: *val }, TVMTypeCode_kHandle), Handle(val) => (TVMValue { v_handle: *val }, TVMTypeCode_kTVMOpaqueHandle),
ArrayHandle(val) => { ArrayHandle(val) => {
( (
TVMValue { v_handle: *val as *const _ as *mut c_void }, TVMValue { v_handle: *val as *const _ as *mut c_void },
TVMTypeCode_kArrayHandle, TVMTypeCode_kTVMNDArrayHandle,
) )
}, },
ObjectHandle(val) => (TVMValue { v_handle: *val }, TVMTypeCode_kObjectHandle), ObjectHandle(val) => (TVMValue { v_handle: *val }, TVMTypeCode_kTVMObjectHandle),
ModuleHandle(val) => ModuleHandle(val) =>
(TVMValue { v_handle: *val }, TVMTypeCode_kModuleHandle), (TVMValue { v_handle: *val }, TVMTypeCode_kTVMModuleHandle),
FuncHandle(val) => ( FuncHandle(val) => (
TVMValue { v_handle: *val }, TVMValue { v_handle: *val },
TVMTypeCode_kFuncHandle TVMTypeCode_kTVMPackedFuncHandle
), ),
NDArrayContainer(val) => NDArrayContainer(val) =>
(TVMValue { v_handle: *val }, TVMTypeCode_kNDArrayContainer), (TVMValue { v_handle: *val }, TVMTypeCode_kTVMNDArrayHandle),
$( $self_type($val) => { $from_self_type } ),+ $( $self_type($val) => { $from_self_type } ),+
} }
} }
...@@ -148,14 +150,14 @@ TVMPODValue! { ...@@ -148,14 +150,14 @@ TVMPODValue! {
Str(&'a CStr), Str(&'a CStr),
}, },
match value { match value {
TVMTypeCode_kBytes => { Bytes(&*(value.v_handle as *const TVMByteArray)) } TVMTypeCode_kTVMBytes => { Bytes(&*(value.v_handle as *const TVMByteArray)) }
TVMTypeCode_kStr => { Str(CStr::from_ptr(value.v_handle as *const i8)) } TVMTypeCode_kTVMStr => { Str(CStr::from_ptr(value.v_handle as *const i8)) }
}, },
match &self { match &self {
Bytes(val) => { Bytes(val) => {
(TVMValue { v_handle: val.clone() as *const _ as *mut c_void }, TVMTypeCode_kBytes) (TVMValue { v_handle: val as *const _ as *mut c_void }, TVMTypeCode_kTVMBytes)
} }
Str(val) => { (TVMValue { v_handle: val.as_ptr() as *mut c_void }, TVMTypeCode_kStr) } Str(val) => { (TVMValue { v_handle: val.as_ptr() as *mut c_void }, TVMTypeCode_kTVMStr) }
} }
} }
...@@ -166,11 +168,14 @@ TVMPODValue! { ...@@ -166,11 +168,14 @@ TVMPODValue! {
/// # Example /// # Example
/// ///
/// ``` /// ```
/// use std::convert::{TryFrom, TryInto};
/// use tvm_common::TVMRetValue;
///
/// let a = 42u32; /// let a = 42u32;
/// let b: i64 = TVMRetValue::from(a).try_into().unwrap(); /// let b: u32 = tvm_common::TVMRetValue::from(a).try_into().unwrap();
/// ///
/// let s = "hello, world!"; /// let s = "hello, world!";
/// let t: TVMRetValue = s.into(); /// let t: TVMRetValue = s.to_string().into();
/// assert_eq!(String::try_from(t).unwrap(), s); /// assert_eq!(String::try_from(t).unwrap(), s);
/// ``` /// ```
TVMRetValue { TVMRetValue {
...@@ -178,14 +183,14 @@ TVMPODValue! { ...@@ -178,14 +183,14 @@ TVMPODValue! {
Str(&'static CStr), Str(&'static CStr),
}, },
match value { match value {
TVMTypeCode_kBytes => { Bytes(*(value.v_handle as *const TVMByteArray)) } TVMTypeCode_kTVMBytes => { Bytes(*(value.v_handle as *const TVMByteArray)) }
TVMTypeCode_kStr => { Str(CStr::from_ptr(value.v_handle as *mut i8)) } TVMTypeCode_kTVMStr => { Str(CStr::from_ptr(value.v_handle as *mut i8)) }
}, },
match &self { match &self {
Bytes(val) => Bytes(val) =>
{ (TVMValue { v_handle: val as *const _ as *mut c_void }, TVMTypeCode_kBytes ) } { (TVMValue { v_handle: val as *const _ as *mut c_void }, TVMTypeCode_kTVMBytes ) }
Str(val) => Str(val) =>
{ (TVMValue { v_str: val.as_ptr() }, TVMTypeCode_kStr ) } { (TVMValue { v_str: val.as_ptr() }, TVMTypeCode_kTVMStr ) }
} }
} }
...@@ -251,7 +256,7 @@ macro_rules! impl_pod_value { ...@@ -251,7 +256,7 @@ macro_rules! impl_pod_value {
impl_pod_value!(Int, i64, [i8, i16, i32, i64, isize]); impl_pod_value!(Int, i64, [i8, i16, i32, i64, isize]);
impl_pod_value!(UInt, i64, [u8, u16, u32, u64, usize]); impl_pod_value!(UInt, i64, [u8, u16, u32, u64, usize]);
impl_pod_value!(Float, f64, [f32, f64]); impl_pod_value!(Float, f64, [f32, f64]);
impl_pod_value!(Type, TVMType, [TVMType]); impl_pod_value!(DataType, DLDataType, [DLDataType]);
impl_pod_value!(Context, TVMContext, [TVMContext]); impl_pod_value!(Context, TVMContext, [TVMContext]);
impl<'a> From<&'a str> for TVMArgValue<'a> { impl<'a> From<&'a str> for TVMArgValue<'a> {
......
...@@ -19,11 +19,9 @@ ...@@ -19,11 +19,9 @@
use std::{os::raw::c_char, str::FromStr}; use std::{os::raw::c_char, str::FromStr};
use failure::Error;
use crate::ffi::*; use crate::ffi::*;
impl TVMType { impl DLDataType {
fn new(type_code: u8, bits: u8, lanes: u16) -> Self { fn new(type_code: u8, bits: u8, lanes: u16) -> Self {
Self { Self {
code: type_code, code: type_code,
...@@ -33,25 +31,37 @@ impl TVMType { ...@@ -33,25 +31,37 @@ impl TVMType {
} }
} }
#[derive(Debug, Fail)]
pub enum ParseTvmTypeError {
#[fail(display = "invalid number: {}", _0)]
InvalidNumber(std::num::ParseIntError),
#[fail(display = "unknown type: {}", _0)]
UnknownType(String),
}
/// Implements TVMType conversion from `&str` of general format `{dtype}{bits}x{lanes}` /// Implements TVMType conversion from `&str` of general format `{dtype}{bits}x{lanes}`
/// such as "int32", "float32" or with lane "float32x1". /// such as "int32", "float32" or with lane "float32x1".
impl FromStr for TVMType { impl FromStr for DLDataType {
type Err = Error; type Err = ParseTvmTypeError;
fn from_str(type_str: &str) -> Result<Self, Self::Err> { fn from_str(type_str: &str) -> Result<Self, Self::Err> {
if type_str == "bool" { if type_str == "bool" {
return Ok(TVMType::new(1, 1, 1)); return Ok(DLDataType::new(1, 1, 1));
} }
let mut type_lanes = type_str.split("x"); let mut type_lanes = type_str.split('x');
let typ = type_lanes.next().expect("Missing dtype"); let typ = type_lanes.next().expect("Missing dtype");
let lanes = type_lanes let lanes = type_lanes
.next() .next()
.map(|l| <u16>::from_str_radix(l, 10)) .map(|l| <u16>::from_str_radix(l, 10))
.unwrap_or(Ok(1))?; .unwrap_or(Ok(1))
.map_err(ParseTvmTypeError::InvalidNumber)?;
let (type_name, bits) = match typ.find(char::is_numeric) { let (type_name, bits) = match typ.find(char::is_numeric) {
Some(idx) => { Some(idx) => {
let (name, bits_str) = typ.split_at(idx); let (name, bits_str) = typ.split_at(idx);
(name, u8::from_str_radix(bits_str, 10)?) (
name,
u8::from_str_radix(bits_str, 10).map_err(ParseTvmTypeError::InvalidNumber)?,
)
} }
None => (typ, 32), None => (typ, 32),
}; };
...@@ -61,14 +71,14 @@ impl FromStr for TVMType { ...@@ -61,14 +71,14 @@ impl FromStr for TVMType {
"uint" => 1, "uint" => 1,
"float" => 2, "float" => 2,
"handle" => 3, "handle" => 3,
_ => return Err(format_err!("Unknown type {}", type_name)), _ => return Err(ParseTvmTypeError::UnknownType(type_name.to_string())),
}; };
Ok(TVMType::new(type_code, bits, lanes)) Ok(DLDataType::new(type_code, bits, lanes))
} }
} }
impl std::fmt::Display for TVMType { impl std::fmt::Display for DLDataType {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
if self.bits == 1 && self.lanes == 1 { if self.bits == 1 && self.lanes == 1 {
return write!(f, "bool"); return write!(f, "bool");
...@@ -113,19 +123,23 @@ macro_rules! impl_pod_tvm_value { ...@@ -113,19 +123,23 @@ macro_rules! impl_pod_tvm_value {
impl_pod_tvm_value!(v_int64, i64, i8, u8, i16, u16, i32, u32, i64, u64, isize, usize); impl_pod_tvm_value!(v_int64, i64, i8, u8, i16, u16, i32, u32, i64, u64, isize, usize);
impl_pod_tvm_value!(v_float64, f64, f32, f64); impl_pod_tvm_value!(v_float64, f64, f32, f64);
impl_pod_tvm_value!(v_type, TVMType); impl_pod_tvm_value!(v_type, DLDataType);
impl_pod_tvm_value!(v_ctx, TVMContext); impl_pod_tvm_value!(v_ctx, TVMContext);
#[derive(Debug, Fail)]
#[fail(display = "unsupported device: {}", _0)]
pub struct UnsupportedDeviceError(String);
macro_rules! impl_tvm_context { macro_rules! impl_tvm_context {
( $( $dev_type:ident : [ $( $dev_name:ident ),+ ] ),+ ) => { ( $( $dev_type:ident : [ $( $dev_name:ident ),+ ] ),+ ) => {
/// Creates a TVMContext from a string (e.g., "cpu", "gpu", "ext_dev") /// Creates a TVMContext from a string (e.g., "cpu", "gpu", "ext_dev")
impl FromStr for TVMContext { impl FromStr for TVMContext {
type Err = Error; type Err = UnsupportedDeviceError;
fn from_str(type_str: &str) -> Result<Self, Self::Err> { fn from_str(type_str: &str) -> Result<Self, Self::Err> {
Ok(Self { Ok(Self {
device_type: match type_str { device_type: match type_str {
$( $( stringify!($dev_name) )|+ => $dev_type ),+, $( $( stringify!($dev_name) )|+ => $dev_type ),+,
_ => return Err(format_err!("device {} not supported", type_str).into()), _ => return Err(UnsupportedDeviceError(type_str.to_string())),
}, },
device_id: 0, device_id: 0,
}) })
...@@ -163,7 +177,7 @@ impl_tvm_context!( ...@@ -163,7 +177,7 @@ impl_tvm_context!(
/// ///
/// ``` /// ```
/// let v = b"hello"; /// let v = b"hello";
/// let barr = TVMByteArray::from(&v); /// let barr = tvm_common::TVMByteArray::from(&v);
/// assert_eq!(barr.len(), v.len()); /// assert_eq!(barr.len(), v.len());
/// assert_eq!(barr.data(), &[104u8, 101, 108, 108, 111]); /// assert_eq!(barr.data(), &[104u8, 101, 108, 108, 111]);
/// ``` /// ```
...@@ -182,6 +196,10 @@ impl TVMByteArray { ...@@ -182,6 +196,10 @@ impl TVMByteArray {
pub fn to_vec(&self) -> Vec<u8> { pub fn to_vec(&self) -> Vec<u8> {
self.data().to_vec() self.data().to_vec()
} }
pub fn is_empty(&self) -> bool {
self.len() == 0
}
} }
// Needs AsRef for Vec // Needs AsRef for Vec
......
...@@ -28,16 +28,12 @@ categories = ["api-bindings", "science"] ...@@ -28,16 +28,12 @@ categories = ["api-bindings", "science"]
authors = ["TVM Contributors"] authors = ["TVM Contributors"]
edition = "2018" edition = "2018"
[lib]
name = "tvm_frontend"
crate-type = ["dylib"]
[dependencies] [dependencies]
failure = "0.1.5" failure = "0.1"
lazy_static = "1.1.0" lazy_static = "1.1"
ndarray = "0.12.1" ndarray = "0.12"
num-traits = "0.2" num-traits = "0.2"
tvm-common = { version = "0.1.0", path = "../common/", features = ["bindings"] } tvm-common = { version = "0.1", path = "../common/", features = ["bindings"] }
[features] [features]
blas = ["ndarray/blas"] blas = ["ndarray/blas"]
...@@ -23,7 +23,7 @@ license = "Apache-2.0" ...@@ -23,7 +23,7 @@ license = "Apache-2.0"
build = "build.rs" build = "build.rs"
[dependencies] [dependencies]
ndarray = "0.12.1" ndarray = "0.12"
tvm-frontend = { path = "../../" } tvm-frontend = { path = "../../" }
image = "0.20.1" image = "0.20"
csv = "1" csv = "1.1"
...@@ -65,7 +65,7 @@ fn main() { ...@@ -65,7 +65,7 @@ fn main() {
let input = NDArray::from_rust_ndarray( let input = NDArray::from_rust_ndarray(
&arr, &arr,
TVMContext::cpu(0), TVMContext::cpu(0),
TVMType::from_str("float32").unwrap(), DLDataType::from_str("float32").unwrap(),
) )
.unwrap(); .unwrap();
println!( println!(
...@@ -117,7 +117,7 @@ fn main() { ...@@ -117,7 +117,7 @@ fn main() {
let output = NDArray::empty( let output = NDArray::empty(
output_shape, output_shape,
TVMContext::cpu(0), TVMContext::cpu(0),
TVMType::from_str("float32").unwrap(), DLDataType::from_str("float32").unwrap(),
); );
// get the `get_output` function from runtime module // get the `get_output` function from runtime module
let ref get_output_fn = graph_runtime_module let ref get_output_fn = graph_runtime_module
......
...@@ -28,7 +28,7 @@ ...@@ -28,7 +28,7 @@
use std::{ use std::{
collections::BTreeMap, collections::BTreeMap,
ffi::{CStr, CString}, ffi::{CStr, CString},
mem, mem::{self, MaybeUninit},
os::raw::{c_char, c_int, c_void}, os::raw::{c_char, c_int, c_void},
ptr, slice, str, ptr, slice, str,
sync::Mutex, sync::Mutex,
...@@ -36,25 +36,20 @@ use std::{ ...@@ -36,25 +36,20 @@ use std::{
use failure::Error; use failure::Error;
use crate::{ use crate::{errors, ffi, Module, TVMArgValue, TVMRetValue};
errors,
ffi::{self, TVMValue},
Module, TVMArgValue, TVMRetValue,
};
lazy_static! { lazy_static! {
static ref GLOBAL_FUNCTIONS: Mutex<BTreeMap<&'static str, Option<Function>>> = { static ref GLOBAL_FUNCTIONS: Mutex<BTreeMap<&'static str, Option<Function>>> = {
let mut out_size = 0 as c_int; let mut out_size = 0 as c_int;
let name = ptr::null_mut() as *mut c_char; let mut names_ptr = ptr::null_mut() as *mut *const c_char;
let mut out_array = name as *mut _;
check_call!(ffi::TVMFuncListGlobalNames( check_call!(ffi::TVMFuncListGlobalNames(
&mut out_size as *mut _, &mut out_size as *mut _,
&mut out_array &mut names_ptr as *mut _,
)); ));
let names_list = unsafe { slice::from_raw_parts(out_array, out_size as usize) }; let names_list = unsafe { slice::from_raw_parts(names_ptr, out_size as usize) };
Mutex::new( Mutex::new(
names_list names_list
.into_iter() .iter()
.map(|&p| (unsafe { CStr::from_ptr(p).to_str().unwrap() }, None)) .map(|&p| (unsafe { CStr::from_ptr(p).to_str().unwrap() }, None))
.collect(), .collect(),
) )
...@@ -80,7 +75,7 @@ unsafe impl Sync for Function {} ...@@ -80,7 +75,7 @@ unsafe impl Sync for Function {}
impl Function { impl Function {
pub(crate) fn new(handle: ffi::TVMFunctionHandle) -> Self { pub(crate) fn new(handle: ffi::TVMFunctionHandle) -> Self {
Function { Function {
handle: handle, handle,
is_global: false, is_global: false,
is_cloned: false, is_cloned: false,
} }
...@@ -98,15 +93,13 @@ impl Function { ...@@ -98,15 +93,13 @@ impl Function {
&mut handle as *mut _ &mut handle as *mut _
)); ));
maybe_func.replace(Function { maybe_func.replace(Function {
handle: handle, handle,
is_global: true, is_global: true,
is_cloned: false, is_cloned: false,
}); });
} }
unsafe { unsafe {
std::mem::transmute::<Option<&Function>, Option<&'static Function>>( mem::transmute::<Option<&Function>, Option<&'static Function>>(maybe_func.as_ref())
maybe_func.as_ref(),
)
} }
}) })
} }
...@@ -214,7 +207,7 @@ impl<'a, 'm> Builder<'a, 'm> { ...@@ -214,7 +207,7 @@ impl<'a, 'm> Builder<'a, 'm> {
let (mut values, mut type_codes): (Vec<ffi::TVMValue>, Vec<ffi::TVMTypeCode>) = let (mut values, mut type_codes): (Vec<ffi::TVMValue>, Vec<ffi::TVMTypeCode>) =
self.arg_buf.iter().map(|arg| arg.to_tvm_value()).unzip(); self.arg_buf.iter().map(|arg| arg.to_tvm_value()).unzip();
let mut ret_val = unsafe { std::mem::uninitialized::<TVMValue>() }; let mut ret_val = unsafe { MaybeUninit::uninit().assume_init() };
let mut ret_type_code = 0i32; let mut ret_type_code = 0i32;
check_call!(ffi::TVMFuncCall( check_call!(ffi::TVMFuncCall(
self.func.ok_or(errors::FunctionNotFoundError)?.handle, self.func.ok_or(errors::FunctionNotFoundError)?.handle,
...@@ -257,20 +250,20 @@ unsafe extern "C" fn tvm_callback( ...@@ -257,20 +250,20 @@ unsafe extern "C" fn tvm_callback(
let args_list = slice::from_raw_parts_mut(args, len); let args_list = slice::from_raw_parts_mut(args, len);
let type_codes_list = slice::from_raw_parts_mut(type_codes, len); let type_codes_list = slice::from_raw_parts_mut(type_codes, len);
let mut local_args: Vec<TVMArgValue> = Vec::new(); let mut local_args: Vec<TVMArgValue> = Vec::new();
let mut value = mem::uninitialized::<ffi::TVMValue>(); let mut value = MaybeUninit::uninit().assume_init();
let mut tcode = mem::uninitialized::<c_int>(); let mut tcode = MaybeUninit::uninit().assume_init();
let rust_fn = let rust_fn =
mem::transmute::<*mut c_void, fn(&[TVMArgValue]) -> Result<TVMRetValue, Error>>(fhandle); mem::transmute::<*mut c_void, fn(&[TVMArgValue]) -> Result<TVMRetValue, Error>>(fhandle);
for i in 0..len { for i in 0..len {
value = args_list[i]; value = args_list[i];
tcode = type_codes_list[i]; tcode = type_codes_list[i];
if tcode == ffi::TVMTypeCode_kObjectHandle as c_int if tcode == ffi::TVMTypeCode_kTVMObjectHandle as c_int
|| tcode == ffi::TVMTypeCode_kFuncHandle as c_int || tcode == ffi::TVMTypeCode_kTVMPackedFuncHandle as c_int
|| tcode == ffi::TVMTypeCode_kModuleHandle as c_int || tcode == ffi::TVMTypeCode_kTVMModuleHandle as c_int
{ {
check_call!(ffi::TVMCbArgToReturn(&mut value as *mut _, tcode)); check_call!(ffi::TVMCbArgToReturn(&mut value as *mut _, tcode));
} }
local_args.push(TVMArgValue::from_tvm_value(value.into(), tcode as u32)); local_args.push(TVMArgValue::from_tvm_value(value, tcode as u32));
} }
let rv = match rust_fn(local_args.as_slice()) { let rv = match rust_fn(local_args.as_slice()) {
...@@ -293,9 +286,9 @@ unsafe extern "C" fn tvm_callback( ...@@ -293,9 +286,9 @@ unsafe extern "C" fn tvm_callback(
} }
unsafe extern "C" fn tvm_callback_finalizer(fhandle: *mut c_void) { unsafe extern "C" fn tvm_callback_finalizer(fhandle: *mut c_void) {
let rust_fn = let _rust_fn =
mem::transmute::<*mut c_void, fn(&[TVMArgValue]) -> Result<TVMRetValue, Error>>(fhandle); mem::transmute::<*mut c_void, fn(&[TVMArgValue]) -> Result<TVMRetValue, Error>>(fhandle);
mem::drop(rust_fn); // XXX: give converted functions lifetimes so they're not called after use
} }
fn convert_to_tvm_func(f: fn(&[TVMArgValue]) -> Result<TVMRetValue, Error>) -> Function { fn convert_to_tvm_func(f: fn(&[TVMArgValue]) -> Result<TVMRetValue, Error>) -> Function {
......
...@@ -30,8 +30,6 @@ ...@@ -30,8 +30,6 @@
//! //!
//! Checkout the `examples` repository for more details. //! Checkout the `examples` repository for more details.
#![feature(box_syntax)]
#[macro_use] #[macro_use]
extern crate failure; extern crate failure;
#[macro_use] #[macro_use]
...@@ -55,7 +53,7 @@ pub use crate::{ ...@@ -55,7 +53,7 @@ pub use crate::{
ndarray::NDArray, ndarray::NDArray,
tvm_common::{ tvm_common::{
errors as common_errors, errors as common_errors,
ffi::{self, TVMByteArray, TVMType}, ffi::{self, TVMByteArray, DLDataType},
packed_func::{TVMArgValue, TVMRetValue}, packed_func::{TVMArgValue, TVMRetValue},
}, },
}; };
......
...@@ -32,7 +32,7 @@ use tvm_common::ffi; ...@@ -32,7 +32,7 @@ use tvm_common::ffi;
use crate::{errors, function::Function}; use crate::{errors, function::Function};
const ENTRY_FUNC: &'static str = "__tvm_main__"; const ENTRY_FUNC: &str = "__tvm_main__";
/// Wrapper around TVM module handle which contains an entry function. /// Wrapper around TVM module handle which contains an entry function.
/// The entry function can be applied to an imported module through [`entry_func`]. /// The entry function can be applied to an imported module through [`entry_func`].
...@@ -72,7 +72,7 @@ impl Module { ...@@ -72,7 +72,7 @@ impl Module {
ensure!( ensure!(
!fhandle.is_null(), !fhandle.is_null(),
errors::NullHandleError { errors::NullHandleError {
name: format!("{}", name.into_string()?) name: name.into_string()?.to_string()
} }
); );
Ok(Function::new(fhandle)) Ok(Function::new(fhandle))
...@@ -88,7 +88,7 @@ impl Module { ...@@ -88,7 +88,7 @@ impl Module {
let ext = CString::new( let ext = CString::new(
path.as_ref() path.as_ref()
.extension() .extension()
.unwrap_or(std::ffi::OsStr::new("")) .unwrap_or_else(|| std::ffi::OsStr::new(""))
.to_str() .to_str()
.ok_or_else(|| { .ok_or_else(|| {
format_err!("Bad module load path: `{}`.", path.as_ref().display()) format_err!("Bad module load path: `{}`.", path.as_ref().display())
......
...@@ -63,7 +63,7 @@ pub struct NDArray { ...@@ -63,7 +63,7 @@ pub struct NDArray {
impl NDArray { impl NDArray {
pub(crate) fn new(handle: ffi::TVMArrayHandle) -> Self { pub(crate) fn new(handle: ffi::TVMArrayHandle) -> Self {
NDArray { NDArray {
handle: handle, handle,
is_view: true, is_view: true,
} }
} }
...@@ -89,8 +89,7 @@ impl NDArray { ...@@ -89,8 +89,7 @@ impl NDArray {
/// Returns the total number of entries of the NDArray. /// Returns the total number of entries of the NDArray.
pub fn size(&self) -> Option<usize> { pub fn size(&self) -> Option<usize> {
self.shape() self.shape().map(|v| v.iter().product())
.map(|v| v.into_iter().fold(1, |acc, &mut e| acc * e))
} }
/// Returns the context which the NDArray was defined. /// Returns the context which the NDArray was defined.
...@@ -100,7 +99,7 @@ impl NDArray { ...@@ -100,7 +99,7 @@ impl NDArray {
/// Returns the type of the entries of the NDArray. /// Returns the type of the entries of the NDArray.
pub fn dtype(&self) -> TVMType { pub fn dtype(&self) -> TVMType {
unsafe { (*self.handle).dtype.into() } unsafe { (*self.handle).dtype }
} }
/// Returns the number of dimensions of the NDArray. /// Returns the number of dimensions of the NDArray.
...@@ -211,8 +210,8 @@ impl NDArray { ...@@ -211,8 +210,8 @@ impl NDArray {
bail!( bail!(
"{}", "{}",
errors::TypeMismatchError { errors::TypeMismatchError {
expected: format!("{}", self.dtype().to_string()), expected: self.dtype().to_string(),
actual: format!("{}", target.dtype().to_string()), actual: target.dtype().to_string(),
} }
); );
} }
...@@ -228,7 +227,7 @@ impl NDArray { ...@@ -228,7 +227,7 @@ impl NDArray {
pub fn copy_to_ctx(&self, target: &TVMContext) -> Result<NDArray, Error> { pub fn copy_to_ctx(&self, target: &TVMContext) -> Result<NDArray, Error> {
let tmp = NDArray::empty( let tmp = NDArray::empty(
self.shape().ok_or(errors::MissingShapeError)?, self.shape().ok_or(errors::MissingShapeError)?,
target.clone(), *target,
self.dtype(), self.dtype(),
); );
let copy = self.copy_to_ndarray(tmp)?; let copy = self.copy_to_ndarray(tmp)?;
...@@ -241,8 +240,8 @@ impl NDArray { ...@@ -241,8 +240,8 @@ impl NDArray {
ctx: TVMContext, ctx: TVMContext,
dtype: TVMType, dtype: TVMType,
) -> Result<Self, Error> { ) -> Result<Self, Error> {
let mut shape = rnd.shape().to_vec(); let shape = rnd.shape().to_vec();
let mut nd = NDArray::empty(&mut shape, ctx, dtype); let mut nd = NDArray::empty(&shape, ctx, dtype);
let mut buf = Array::from_iter(rnd.into_iter().map(|&v| v as T)); let mut buf = Array::from_iter(rnd.into_iter().map(|&v| v as T));
nd.copy_from_buffer( nd.copy_from_buffer(
buf.as_slice_mut() buf.as_slice_mut()
...@@ -257,9 +256,9 @@ impl NDArray { ...@@ -257,9 +256,9 @@ impl NDArray {
check_call!(ffi::TVMArrayAlloc( check_call!(ffi::TVMArrayAlloc(
shape.as_ptr() as *const i64, shape.as_ptr() as *const i64,
shape.len() as c_int, shape.len() as c_int,
dtype.code as c_int, i32::from(dtype.code) as c_int,
dtype.bits as c_int, i32::from(dtype.bits) as c_int,
dtype.lanes as c_int, i32::from(dtype.lanes) as c_int,
ctx.device_type.0 as c_int, ctx.device_type.0 as c_int,
ctx.device_id as c_int, ctx.device_id as c_int,
&mut handle as *mut _, &mut handle as *mut _,
...@@ -364,9 +363,9 @@ mod tests { ...@@ -364,9 +363,9 @@ mod tests {
assert_eq!(ndarray.ndim(), 1); assert_eq!(ndarray.ndim(), 1);
assert!(ndarray.is_contiguous().is_ok()); assert!(ndarray.is_contiguous().is_ok());
assert_eq!(ndarray.byte_offset(), 0); assert_eq!(ndarray.byte_offset(), 0);
let mut shape = vec![4]; let shape = vec![4];
let e = NDArray::empty( let e = NDArray::empty(
&mut shape, &shape,
TVMContext::cpu(0), TVMContext::cpu(0),
TVMType::from_str("int32").unwrap(), TVMType::from_str("int32").unwrap(),
); );
...@@ -378,16 +377,12 @@ mod tests { ...@@ -378,16 +377,12 @@ mod tests {
#[test] #[test]
#[should_panic(expected = "called `Result::unwrap()` on an `Err`")] #[should_panic(expected = "called `Result::unwrap()` on an `Err`")]
fn copy_wrong_dtype() { fn copy_wrong_dtype() {
let mut shape = vec![4]; let shape = vec![4];
let mut data = vec![1f32, 2., 3., 4.]; let mut data = vec![1f32, 2., 3., 4.];
let ctx = TVMContext::cpu(0); let ctx = TVMContext::cpu(0);
let mut nd_float = NDArray::empty( let mut nd_float = NDArray::empty(&shape, ctx, TVMType::from_str("float32").unwrap());
&mut shape,
ctx.clone(),
TVMType::from_str("float32").unwrap(),
);
nd_float.copy_from_buffer(&mut data); nd_float.copy_from_buffer(&mut data);
let empty_int = NDArray::empty(&mut shape, ctx, TVMType::from_str("int32").unwrap()); let empty_int = NDArray::empty(&shape, ctx, TVMType::from_str("int32").unwrap());
nd_float.copy_to_ndarray(empty_int).unwrap(); nd_float.copy_to_ndarray(empty_int).unwrap();
} }
......
...@@ -93,7 +93,7 @@ mod tests { ...@@ -93,7 +93,7 @@ mod tests {
let tvm: TVMByteArray = TVMRetValue::from(v).try_into().unwrap(); let tvm: TVMByteArray = TVMRetValue::from(v).try_into().unwrap();
assert_eq!( assert_eq!(
tvm.data(), tvm.data(),
w.iter().map(|e| *e).collect::<Vec<u8>>().as_slice() w.iter().copied().collect::<Vec<u8>>().as_slice()
); );
} }
......
...@@ -23,7 +23,7 @@ license = "Apache-2.0" ...@@ -23,7 +23,7 @@ license = "Apache-2.0"
build = "build.rs" build = "build.rs"
[dependencies] [dependencies]
ndarray = "0.12.1" ndarray = "0.12"
tvm-frontend = { path = "../../" } tvm-frontend = { path = "../../" }
[features] [features]
......
...@@ -33,7 +33,7 @@ fn main() { ...@@ -33,7 +33,7 @@ fn main() {
} else { } else {
(TVMContext::gpu(0), "gpu") (TVMContext::gpu(0), "gpu")
}; };
let dtype = TVMType::from_str("float32").unwrap(); let dtype = DLDataType::from_str("float32").unwrap();
let mut arr = NDArray::empty(shape, ctx, dtype); let mut arr = NDArray::empty(shape, ctx, dtype);
arr.copy_from_buffer(data.as_mut_slice()); arr.copy_from_buffer(data.as_mut_slice());
let mut ret = NDArray::empty(shape, ctx, dtype); let mut ret = NDArray::empty(shape, ctx, dtype);
......
...@@ -21,5 +21,5 @@ version = "0.0.0" ...@@ -21,5 +21,5 @@ version = "0.0.0"
authors = ["TVM Contributors"] authors = ["TVM Contributors"]
[dependencies] [dependencies]
ndarray = "0.12.1" ndarray = "0.12"
tvm-frontend = { path = "../../" } tvm-frontend = { path = "../../" }
...@@ -39,7 +39,7 @@ fn main() { ...@@ -39,7 +39,7 @@ fn main() {
for arg in args.iter() { for arg in args.iter() {
let e = NDArray::empty( let e = NDArray::empty(
shape, TVMContext::cpu(0), shape, TVMContext::cpu(0),
TVMType::from_str("float32").unwrap() DLDataType::from_str("float32").unwrap()
); );
let arg: NDArray = arg.try_into()?; let arg: NDArray = arg.try_into()?;
let arr = arg.copy_to_ndarray(e)?; let arr = arg.copy_to_ndarray(e)?;
...@@ -55,7 +55,7 @@ fn main() { ...@@ -55,7 +55,7 @@ fn main() {
let mut arr = NDArray::empty( let mut arr = NDArray::empty(
shape, shape,
TVMContext::cpu(0), TVMContext::cpu(0),
TVMType::from_str("float32").unwrap(), DLDataType::from_str("float32").unwrap(),
); );
arr.copy_from_buffer(data.as_mut_slice()); arr.copy_from_buffer(data.as_mut_slice());
......
...@@ -17,9 +17,6 @@ ...@@ -17,9 +17,6 @@
* under the License. * under the License.
*/ */
#![feature(panic_info_message)]
#![allow(unused_imports)]
use std::panic; use std::panic;
#[macro_use] #[macro_use]
...@@ -44,9 +41,9 @@ fn main() { ...@@ -44,9 +41,9 @@ fn main() {
println!("expected error message is:"); println!("expected error message is:");
panic::set_hook(Box::new(|panic_info| { panic::set_hook(Box::new(|panic_info| {
if let Some(msg) = panic_info.message() { // if let Some(msg) = panic_info.message() {
println!("{:?}", msg); // println!("{:?}", msg);
} // }
if let Some(location) = panic_info.location() { if let Some(location) = panic_info.location() {
println!( println!(
"panic occurred in file '{}' at line {}", "panic occurred in file '{}' at line {}",
......
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
[package] [package]
name = "tvm-macros" name = "tvm-macros"
version = "0.1.0" version = "0.1.1"
license = "Apache-2.0" license = "Apache-2.0"
description = "Proc macros used by the TVM crates." description = "Proc macros used by the TVM crates."
repository = "https://github.com/apache/incubator-tvm" repository = "https://github.com/apache/incubator-tvm"
...@@ -26,11 +26,6 @@ keywords = ["tvm"] ...@@ -26,11 +26,6 @@ keywords = ["tvm"]
authors = ["TVM Contributors"] authors = ["TVM Contributors"]
edition = "2018" edition = "2018"
[lib]
proc-macro = true
[dependencies] [dependencies]
goblin = "0.0.22" tvm-macros-raw = { path = "../macros_raw" }
proc-macro2 = "0.4"
proc-quote = "0.2"
syn = "0.15"
...@@ -17,106 +17,12 @@ ...@@ -17,106 +17,12 @@
* under the License. * under the License.
*/ */
#![feature(proc_macro_span)] #[macro_use]
extern crate tvm_macros_raw;
extern crate proc_macro; #[macro_export]
macro_rules! import_module {
use std::{fs::File, io::Read}; ($module_path:literal) => {
$crate::import_module_raw!(file!(), $module_path);
use proc_quote::quote;
#[proc_macro]
pub fn import_module(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
let obj_file_path = syn::parse_macro_input!(input as syn::LitStr);
let mut path = obj_file_path.span().unwrap().source_file().path();
path.pop(); // remove the filename
path.push(obj_file_path.value());
let mut fd = File::open(&path)
.unwrap_or_else(|_| panic!("Unable to find TVM object file at `{}`", path.display()));
let mut buffer = Vec::new();
fd.read_to_end(&mut buffer).unwrap();
let fn_names = match goblin::Object::parse(&buffer).unwrap() {
goblin::Object::Elf(elf) => elf
.syms
.iter()
.filter_map(|s| {
if s.st_type() == 0 || goblin::elf::sym::type_to_str(s.st_type()) == "FILE" {
return None;
}
match elf.strtab.get(s.st_name) {
Some(Ok(name)) if name != "" => {
Some(syn::Ident::new(name, proc_macro2::Span::call_site()))
}
_ => None,
}
})
.collect::<Vec<_>>(),
goblin::Object::Mach(goblin::mach::Mach::Binary(obj)) => {
obj.symbols()
.filter_map(|s| match s {
Ok((name, nlist))
if nlist.is_global()
&& nlist.n_sect != 0
&& !name.ends_with("tvm_module_ctx") =>
{
Some(syn::Ident::new(
if name.starts_with('_') {
// Mach objects prepend a _ to globals.
&name[1..]
} else {
&name
},
proc_macro2::Span::call_site(),
))
}
_ => None,
})
.collect::<Vec<_>>()
}
_ => panic!("Unsupported object format."),
}; };
let extern_fns = quote! {
mod ext {
extern "C" {
#(
pub(super) fn #fn_names(
args: *const tvm_runtime::ffi::TVMValue,
type_codes: *const std::os::raw::c_int,
num_args: std::os::raw::c_int
) -> std::os::raw::c_int;
)*
}
}
};
let fns = quote! {
use tvm_runtime::{ffi::TVMValue, TVMArgValue, TVMRetValue, FuncCallError};
#extern_fns
#(
pub fn #fn_names(args: &[TVMArgValue]) -> Result<TVMRetValue, FuncCallError> {
let (values, type_codes): (Vec<TVMValue>, Vec<i32>) = args
.into_iter()
.map(|arg| {
let (val, code) = arg.to_tvm_value();
(val, code as i32)
})
.unzip();
let exit_code = unsafe {
ext::#fn_names(values.as_ptr(), type_codes.as_ptr(), values.len() as i32)
};
if exit_code == 0 {
Ok(TVMRetValue::default())
} else {
Err(FuncCallError::get_with_context(stringify!(#fn_names).to_string()))
}
}
)*
};
proc_macro::TokenStream::from(fns)
} }
#!/bin/bash
# Licensed to the Apache Software Foundation (ASF) under one # Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file # or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information # distributed with this work for additional information
...@@ -6,9 +5,9 @@ ...@@ -6,9 +5,9 @@
# to you under the Apache License, Version 2.0 (the # to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance # "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at # with the License. You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0 # http://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, # Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an # software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...@@ -16,11 +15,22 @@ ...@@ -16,11 +15,22 @@
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
sgx_sdk=${SGX_SDK:=/opt/sgxsdk} [package]
name = "tvm-macros-raw"
version = "0.1.1"
license = "Apache-2.0"
description = "Proc macros used by the TVM crates."
repository = "https://github.com/apache/incubator-tvm"
readme = "README.md"
keywords = ["tvm"]
authors = ["TVM Contributors"]
edition = "2018"
export LD_LIBRARY_PATH="$sgx_sdk/lib64":${LD_LIBRARY_PATH} [lib]
export CC=clang-6.0 proc-macro = true
export AR=llvm-ar-6.0
export TVM_CACHE_DIR=/tmp
make && printf "\n" && python3 run_model.py [dependencies]
goblin = "0.0.24"
proc-macro2 = "^1.0"
quote = "1.0"
syn = "1.0"
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
extern crate proc_macro;
use std::{fs::File, io::Read};
use syn::parse::{Parse, ParseStream, Result};
use syn::{Token, LitStr};
use quote::quote;
use std::path::PathBuf;
struct ImportModule {
importing_file: LitStr,
module_path: LitStr,
}
impl Parse for ImportModule {
fn parse(input: ParseStream) -> Result<Self> {
let importing_file: LitStr = input.parse()?;
input.parse::<Token![,]>()?;
let module_path: LitStr = input.parse()?;
Ok(ImportModule {
importing_file,
module_path,
})
}
}
#[proc_macro]
pub fn import_module_raw(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
let import_module_args = syn::parse_macro_input!(input as ImportModule);
let mut path = PathBuf::new();
path = path.join(import_module_args.importing_file.value());
path.pop(); // remove the filename
path.push(import_module_args.module_path.value());
let mut fd = File::open(&path)
.unwrap_or_else(|_| panic!("Unable to find TVM object file at `{}`", path.display()));
let mut buffer = Vec::new();
fd.read_to_end(&mut buffer).unwrap();
let fn_names = match goblin::Object::parse(&buffer).unwrap() {
goblin::Object::Elf(elf) => elf
.syms
.iter()
.filter_map(|s| {
if s.st_type() == 0 || goblin::elf::sym::type_to_str(s.st_type()) == "FILE" {
return None;
}
match elf.strtab.get(s.st_name) {
Some(Ok(name)) if name != "" => {
Some(syn::Ident::new(name, proc_macro2::Span::call_site()))
}
_ => None,
}
})
.collect::<Vec<_>>(),
goblin::Object::Mach(goblin::mach::Mach::Binary(obj)) => {
obj.symbols()
.filter_map(|s| match s {
Ok((name, ref nlist))
if nlist.is_global()
&& nlist.n_sect != 0
&& !name.ends_with("tvm_module_ctx") =>
{
Some(syn::Ident::new(
if name.starts_with('_') {
// Mach objects prepend a _ to globals.
&name[1..]
} else {
&name
},
proc_macro2::Span::call_site(),
))
}
_ => None,
})
.collect::<Vec<_>>()
}
_ => panic!("Unsupported object format."),
};
let extern_fns = quote! {
mod ext {
extern "C" {
#(
pub(super) fn #fn_names(
args: *const tvm_runtime::ffi::TVMValue,
type_codes: *const std::os::raw::c_int,
num_args: std::os::raw::c_int
) -> std::os::raw::c_int;
)*
}
}
};
let fns = quote! {
use tvm_runtime::{ffi::TVMValue, TVMArgValue, TVMRetValue, FuncCallError};
#extern_fns
#(
pub fn #fn_names(args: &[TVMArgValue]) -> Result<TVMRetValue, FuncCallError> {
let (values, type_codes): (Vec<TVMValue>, Vec<i32>) = args
.into_iter()
.map(|arg| {
let (val, code) = arg.to_tvm_value();
(val, code as i32)
})
.unzip();
let exit_code = unsafe {
ext::#fn_names(values.as_ptr(), type_codes.as_ptr(), values.len() as i32)
};
if exit_code == 0 {
Ok(TVMRetValue::default())
} else {
Err(FuncCallError::get_with_context(stringify!(#fn_names).to_string()))
}
}
)*
};
proc_macro::TokenStream::from(fns)
}
...@@ -27,25 +27,19 @@ categories = ["api-bindings", "science"] ...@@ -27,25 +27,19 @@ categories = ["api-bindings", "science"]
authors = ["TVM Contributors"] authors = ["TVM Contributors"]
edition = "2018" edition = "2018"
[features]
default = ["nom/std"]
sgx = ["nom/alloc"]
[dependencies] [dependencies]
bounded-spsc-queue = "0.4.0" crossbeam = "0.7.3"
failure = "0.1.5" failure = "0.1"
itertools = "0.7.8" itertools = "0.8"
lazy_static = "1.1.0" lazy_static = "1.4"
ndarray="0.12.1" ndarray="0.12"
nom = {version = "4.0.0", default-features = false } nom = "5.0"
serde = "1.0.59" num_cpus = "1.10"
serde_derive = "1.0.79" serde = "1.0"
serde_json = "1.0.17" serde_derive = "1.0"
serde_json = "1.0"
tvm-common = { version = "0.1", path = "../common" } tvm-common = { version = "0.1", path = "../common" }
tvm-macros = { version = "0.1", path = "../macros" } tvm-macros = { version = "0.1", path = "../macros" }
[target.'cfg(not(target_env = "sgx"))'.dependencies]
num_cpus = "1.8.0"
[target.'cfg(not(any(target_arch = "wasm32", target_env = "sgx")))'.dependencies] [target.'cfg(not(any(target_arch = "wasm32", target_env = "sgx")))'.dependencies]
libloading = "0.5" libloading = "0.5"
...@@ -17,9 +17,6 @@ ...@@ -17,9 +17,6 @@
* under the License. * under the License.
*/ */
#[cfg(target_env = "sgx")]
use alloc::alloc::{self, Layout, LayoutErr};
#[cfg(not(target_env = "sgx"))]
use std::alloc::{self, Layout, LayoutErr}; use std::alloc::{self, Layout, LayoutErr};
const DEFAULT_ALIGN_BYTES: usize = 4; const DEFAULT_ALIGN_BYTES: usize = 4;
...@@ -35,14 +32,11 @@ impl Allocation { ...@@ -35,14 +32,11 @@ impl Allocation {
pub fn new(size: usize, align: Option<usize>) -> Result<Self, LayoutErr> { pub fn new(size: usize, align: Option<usize>) -> Result<Self, LayoutErr> {
let alignment = align.unwrap_or(DEFAULT_ALIGN_BYTES); let alignment = align.unwrap_or(DEFAULT_ALIGN_BYTES);
let layout = Layout::from_size_align(size, alignment)?; let layout = Layout::from_size_align(size, alignment)?;
let ptr = unsafe { alloc::alloc(layout.clone()) }; let ptr = unsafe { alloc::alloc(layout) };
if ptr.is_null() { if ptr.is_null() {
alloc::handle_alloc_error(layout); alloc::handle_alloc_error(layout);
} }
Ok(Self { Ok(Self { ptr, layout })
ptr: ptr,
layout: layout,
})
} }
pub fn as_mut_ptr(&self) -> *mut u8 { pub fn as_mut_ptr(&self) -> *mut u8 {
...@@ -58,12 +52,22 @@ impl Allocation { ...@@ -58,12 +52,22 @@ impl Allocation {
pub fn align(&self) -> usize { pub fn align(&self) -> usize {
self.layout.align() self.layout.align()
} }
/// Returns a view of the Allocation.
pub fn as_slice(&self) -> &[u8] {
unsafe { std::slice::from_raw_parts(self.as_mut_ptr(), self.size()) }
}
/// Returns a mutable view of the Allocation.
pub fn as_mut_slice(&mut self) -> &mut [u8] {
unsafe { std::slice::from_raw_parts_mut(self.as_mut_ptr(), self.size()) }
}
} }
impl Drop for Allocation { impl Drop for Allocation {
fn drop(&mut self) { fn drop(&mut self) {
unsafe { unsafe {
alloc::dealloc(self.ptr, self.layout.clone()); alloc::dealloc(self.ptr, self.layout);
} }
} }
} }
...@@ -101,6 +101,22 @@ impl<'a> Storage<'a> { ...@@ -101,6 +101,22 @@ impl<'a> Storage<'a> {
} }
s s
} }
/// Returns a view of the stored data.
pub fn as_slice(&self) -> &[u8] {
match self {
Storage::Owned(alloc) => alloc.as_slice(),
Storage::View(slice, _) => &*slice,
}
}
/// Returns a mutable view of the stored data.
pub fn as_mut_slice(&mut self) -> &mut [u8] {
match self {
Storage::Owned(alloc) => alloc.as_mut_slice(),
Storage::View(slice, _) => slice,
}
}
} }
impl<'d, 's, T> From<&'d [T]> for Storage<'s> { impl<'d, 's, T> From<&'d [T]> for Storage<'s> {
...@@ -123,14 +139,18 @@ impl<'d, 's, T> From<&'d [T]> for Storage<'s> { ...@@ -123,14 +139,18 @@ impl<'d, 's, T> From<&'d [T]> for Storage<'s> {
/// ///
/// ``` /// ```
/// extern crate ndarray; /// extern crate ndarray;
/// use std::convert::TryInto;
/// use tvm_runtime::{call_packed, DLTensor, TVMArgValue, TVMRetValue, Tensor};
/// ///
/// let mut a_nd: ndarray::Array = ndarray::Array::from_vec(vec![1f32, 2., 3., 4.]); /// let mut a_nd: ndarray::Array1<f32> = ndarray::Array::from_vec(vec![1f32, 2., 3., 4.]);
/// let mut a: Tensor = a_nd.into(); /// let mut a: Tensor = a_nd.into();
/// let mut a_dl: DLTensor = (&mut t).into(); /// let mut a_dl: DLTensor = (&mut a).into();
///
/// let tvm_fn = |args: &[TVMArgValue]| -> Result<TVMRetValue, ()> { Ok(TVMRetValue::default()) };
/// call_packed!(tvm_fn, &mut a_dl); /// call_packed!(tvm_fn, &mut a_dl);
/// ///
/// // Array -> Tensor is mostly useful when post-processing TVM graph outputs. /// // Array -> Tensor is mostly useful when post-processing TVM graph outputs.
/// let mut a_nd = ndarray::Array::try_from(&a).unwrap(); /// let mut a_nd: ndarray::ArrayD<f32> = a.try_into().unwrap();
/// ``` /// ```
#[derive(PartialEq)] #[derive(PartialEq)]
pub struct Tensor<'a> { pub struct Tensor<'a> {
...@@ -154,6 +174,14 @@ impl<'a> Tensor<'a> { ...@@ -154,6 +174,14 @@ impl<'a> Tensor<'a> {
self.shape.clone() self.shape.clone()
} }
pub fn data(&self) -> &Storage {
&self.data
}
pub fn data_mut(&mut self) -> &'a mut Storage {
&mut self.data
}
/// Returns the data of this `Tensor` as a `Vec`. /// Returns the data of this `Tensor` as a `Vec`.
/// ///
/// # Panics /// # Panics
...@@ -220,9 +248,9 @@ impl<'a> Tensor<'a> { ...@@ -220,9 +248,9 @@ impl<'a> Tensor<'a> {
pub fn to_owned(&self) -> Tensor<'static> { pub fn to_owned(&self) -> Tensor<'static> {
let t = Tensor { let t = Tensor {
data: self.data.to_owned(), data: self.data.to_owned(),
ctx: self.ctx.clone(), ctx: self.ctx,
dtype: self.dtype.clone(), dtype: self.dtype,
size: self.size.clone(), size: self.size,
shape: self.shape.clone(), shape: self.shape.clone(),
strides: None, strides: None,
byte_offset: 0, byte_offset: 0,
...@@ -246,7 +274,7 @@ impl<'a> Tensor<'a> { ...@@ -246,7 +274,7 @@ impl<'a> Tensor<'a> {
}, },
size: arr.len(), size: arr.len(),
shape: arr.shape().iter().map(|&v| v as i64).collect(), shape: arr.shape().iter().map(|&v| v as i64).collect(),
strides: Some(arr.strides().into_iter().map(|&v| v as usize).collect()), strides: Some(arr.strides().iter().map(|&v| v as usize).collect()),
byte_offset: 0, byte_offset: 0,
} }
} }
...@@ -276,9 +304,9 @@ impl<'a> Tensor<'a> { ...@@ -276,9 +304,9 @@ impl<'a> Tensor<'a> {
/// Conversions to `ndarray::Array` from `Tensor`, if the types match. /// Conversions to `ndarray::Array` from `Tensor`, if the types match.
macro_rules! impl_ndarray_try_from_tensor { macro_rules! impl_ndarray_try_from_tensor {
($type:ty, $dtype:expr) => { ($type:ty, $dtype:expr) => {
impl<'a, 't> TryFrom<&'a Tensor<'t>> for ndarray::ArrayD<$type> { impl<'t> TryFrom<Tensor<'t>> for ndarray::ArrayD<$type> {
type Error = Error; type Error = Error;
fn try_from(tensor: &'a Tensor) -> Result<ndarray::ArrayD<$type>, Error> { fn try_from(tensor: Tensor) -> Result<ndarray::ArrayD<$type>, Error> {
ensure!( ensure!(
tensor.dtype == $dtype, tensor.dtype == $dtype,
"Cannot convert Tensor with dtype {:?} to ndarray", "Cannot convert Tensor with dtype {:?} to ndarray",
...@@ -342,10 +370,10 @@ impl<'a> From<DLTensor> for Tensor<'a> { ...@@ -342,10 +370,10 @@ impl<'a> From<DLTensor> for Tensor<'a> {
Self { Self {
data: storage, data: storage,
ctx: TVMContext::default(), ctx: TVMContext::default(),
dtype: dtype, dtype,
size: size, size,
shape: shape, shape,
strides: if dlt.strides == ptr::null_mut() { strides: if dlt.strides.is_null() {
None None
} else { } else {
Some(slice::from_raw_parts_mut(dlt.strides as *mut usize, size).to_vec()) Some(slice::from_raw_parts_mut(dlt.strides as *mut usize, size).to_vec())
......
...@@ -30,9 +30,3 @@ pub enum GraphFormatError { ...@@ -30,9 +30,3 @@ pub enum GraphFormatError {
#[fail(display = "Invalid DLType: {}", 0)] #[fail(display = "Invalid DLType: {}", 0)]
InvalidDLType(String), InvalidDLType(String),
} }
#[derive(Debug, Fail)]
#[fail(display = "SGX error: 0x{:x}", code)]
pub struct SgxError {
pub code: u32,
}
...@@ -28,19 +28,6 @@ ...@@ -28,19 +28,6 @@
//! The main entrypoints to this crate are `GraphExecutor` //! The main entrypoints to this crate are `GraphExecutor`
//! For examples of use, please refer to the multi-file tests in the `tests` directory. //! For examples of use, please refer to the multi-file tests in the `tests` directory.
#![feature(
allocator_api,
box_syntax,
fn_traits,
unboxed_closures,
vec_remove_item
)]
#[cfg(target_env = "sgx")]
extern crate alloc;
extern crate bounded_spsc_queue;
#[cfg(target_env = "sgx")]
extern crate core;
#[macro_use] #[macro_use]
extern crate failure; extern crate failure;
#[macro_use] #[macro_use]
...@@ -50,7 +37,6 @@ extern crate lazy_static; ...@@ -50,7 +37,6 @@ extern crate lazy_static;
extern crate ndarray; extern crate ndarray;
#[macro_use] #[macro_use]
extern crate nom; extern crate nom;
#[cfg(not(target_env = "sgx"))]
extern crate num_cpus; extern crate num_cpus;
extern crate serde; extern crate serde;
#[macro_use] #[macro_use]
...@@ -63,9 +49,6 @@ mod array; ...@@ -63,9 +49,6 @@ mod array;
pub mod errors; pub mod errors;
mod graph; mod graph;
mod module; mod module;
#[cfg(target_env = "sgx")]
#[macro_use]
pub mod sgx;
mod threading; mod threading;
mod workspace; mod workspace;
...@@ -86,10 +69,8 @@ lazy_static! { ...@@ -86,10 +69,8 @@ lazy_static! {
} }
#[no_mangle] #[no_mangle]
pub extern "C" fn TVMAPISetLastError(cmsg: *const i8) { pub unsafe extern "C" fn TVMAPISetLastError(cmsg: *const i8) {
*LAST_ERROR.write().unwrap() = Some(unsafe { std::ffi::CStr::from_ptr(cmsg) }); *LAST_ERROR.write().unwrap() = Some(std::ffi::CStr::from_ptr(cmsg));
#[cfg(target_env = "sgx")]
ocall_packed!("__sgx_set_last_error__", cmsg);
} }
#[no_mangle] #[no_mangle]
......
...@@ -35,8 +35,8 @@ use crate::{ ...@@ -35,8 +35,8 @@ use crate::{
use super::Module; use super::Module;
const TVM_MAIN: &'static [u8] = b"__tvm_main__"; const TVM_MAIN: &[u8] = b"__tvm_main__";
const TVM_MODULE_CTX: &'static [u8] = b"__tvm_module_ctx"; const TVM_MODULE_CTX: &[u8] = b"__tvm_module_ctx";
/// A module backed by a Dynamic Shared Object (dylib). /// A module backed by a Dynamic Shared Object (dylib).
pub struct DsoModule<'a> { pub struct DsoModule<'a> {
...@@ -64,22 +64,26 @@ impl<'a> DsoModule<'a> { ...@@ -64,22 +64,26 @@ impl<'a> DsoModule<'a> {
init_context_func!( init_context_func!(
lib, lib,
(TVMAPISetLastError, extern "C" fn(*const i8)), (TVMAPISetLastError, unsafe extern "C" fn(*const i8)),
( (
TVMBackendAllocWorkspace, TVMBackendAllocWorkspace,
extern "C" fn(c_int, c_int, u64, c_int, c_int) -> *mut c_void unsafe extern "C" fn(c_int, c_int, u64, c_int, c_int) -> *mut c_void
), ),
( (
TVMBackendFreeWorkspace, TVMBackendFreeWorkspace,
extern "C" fn(c_int, c_int, *mut c_void) -> c_int unsafe extern "C" fn(c_int, c_int, *mut c_void) -> c_int
), ),
( (
TVMBackendParallelLaunch, TVMBackendParallelLaunch,
extern "C" fn(crate::threading::FTVMParallelLambda, *const c_void, usize) -> c_int unsafe extern "C" fn(
crate::threading::FTVMParallelLambda,
*const c_void,
usize,
) -> c_int
), ),
( (
TVMBackendParallelBarrier, TVMBackendParallelBarrier,
extern "C" fn(usize, *const tvm_common::ffi::TVMParallelGroupEnv) unsafe extern "C" fn(usize, *const tvm_common::ffi::TVMParallelGroupEnv)
), ),
); );
...@@ -129,7 +133,7 @@ impl<'a> Module for DsoModule<'a> { ...@@ -129,7 +133,7 @@ impl<'a> Module for DsoModule<'a> {
&*Box::leak(super::wrap_backend_packed_func(name.to_string(), *func)), &*Box::leak(super::wrap_backend_packed_func(name.to_string(), *func)),
); );
self.packed_funcs.borrow().get(name).map(|f| *f) self.packed_funcs.borrow().get(name).copied()
} }
} }
......
...@@ -36,9 +36,9 @@ pub trait Module { ...@@ -36,9 +36,9 @@ pub trait Module {
// @see `WrapPackedFunc` in `llvm_module.cc`. // @see `WrapPackedFunc` in `llvm_module.cc`.
fn wrap_backend_packed_func(func_name: String, func: BackendPackedCFunc) -> Box<dyn PackedFunc> { fn wrap_backend_packed_func(func_name: String, func: BackendPackedCFunc) -> Box<dyn PackedFunc> {
box move |args: &[TVMArgValue]| { Box::new(move |args: &[TVMArgValue]| {
let (values, type_codes): (Vec<TVMValue>, Vec<i32>) = args let (values, type_codes): (Vec<TVMValue>, Vec<i32>) = args
.into_iter() .iter()
.map(|arg| { .map(|arg| {
let (val, code) = arg.to_tvm_value(); let (val, code) = arg.to_tvm_value();
(val, code as i32) (val, code as i32)
...@@ -52,5 +52,5 @@ fn wrap_backend_packed_func(func_name: String, func: BackendPackedCFunc) -> Box< ...@@ -52,5 +52,5 @@ fn wrap_backend_packed_func(func_name: String, func: BackendPackedCFunc) -> Box<
func_name.clone(), func_name.clone(),
)) ))
} }
} })
} }
...@@ -27,6 +27,11 @@ use super::Module; ...@@ -27,6 +27,11 @@ use super::Module;
pub struct SystemLibModule; pub struct SystemLibModule;
#[cfg(target_env = "sgx")]
extern "C" {
fn __tvm_module_startup();
}
lazy_static! { lazy_static! {
static ref SYSTEM_LIB_FUNCTIONS: Mutex<HashMap<String, &'static (dyn PackedFunc)>> = static ref SYSTEM_LIB_FUNCTIONS: Mutex<HashMap<String, &'static (dyn PackedFunc)>> =
Mutex::new(HashMap::new()); Mutex::new(HashMap::new());
...@@ -37,13 +42,16 @@ impl Module for SystemLibModule { ...@@ -37,13 +42,16 @@ impl Module for SystemLibModule {
SYSTEM_LIB_FUNCTIONS SYSTEM_LIB_FUNCTIONS
.lock() .lock()
.unwrap() .unwrap()
.get(name.as_ref()) .get(name.as_ref()).copied()
.map(|f| *f)
} }
} }
impl Default for SystemLibModule { impl Default for SystemLibModule {
fn default() -> Self { fn default() -> Self {
#[cfg(target_env = "sgx")]
unsafe {
__tvm_module_startup();
}
SystemLibModule {} SystemLibModule {}
} }
} }
...@@ -58,5 +66,5 @@ pub extern "C" fn TVMBackendRegisterSystemLibSymbol( ...@@ -58,5 +66,5 @@ pub extern "C" fn TVMBackendRegisterSystemLibSymbol(
name.to_string(), name.to_string(),
&*Box::leak(super::wrap_backend_packed_func(name.to_string(), func)), &*Box::leak(super::wrap_backend_packed_func(name.to_string(), func)),
); );
return 0; 0
} }
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
use std::{
ffi::CString,
os::raw::{c_char, c_int},
};
pub use crate::threading::tvm_run_worker as run_worker;
use crate::{threading::sgx_join_threads, SystemLibModule, TVMArgValue, TVMRetValue};
use errors::SgxError;
use ffi::TVMValue;
#[macro_export]
macro_rules! tvm_ocall {
($func: expr) => {
match $func {
0 => Ok(()),
code => Err(SgxError { code }),
}
};
}
pub type SgxStatus = u32;
#[cfg(target_env = "sgx")]
extern "C" {
fn tvm_ocall_packed_func(
name: *const c_char,
arg_values: *const TVMValue,
type_codes: *const c_int,
num_args: c_int,
ret_val: *mut TVMValue,
ret_type_code: *mut c_int,
) -> SgxStatus;
}
pub fn ocall_packed_func<S: AsRef<str>>(
fn_name: S,
args: &[TVMArgValue],
) -> Result<TVMRetValue, SgxError> {
let mut ret_val = TVMValue { v_int64: 0 };
let ret_type_code = 0i64;
unsafe {
tvm_ocall!(tvm_ocall_packed_func(
CString::new(fn_name.as_ref()).unwrap().as_ptr(),
args.iter()
.map(|ref arg| arg.value)
.collect::<Vec<TVMValue>>()
.as_ptr(),
args.iter()
.map(|ref arg| arg.type_code as i32)
.collect::<Vec<i32>>()
.as_ptr() as *const i32,
args.len() as i32,
&mut ret_val as *mut TVMValue,
&mut (ret_type_code as i32) as *mut c_int,
))?;
}
Ok(TVMRetValue::from_tvm_value(ret_val, ret_type_code as i64))
}
#[macro_export]
macro_rules! ocall_packed {
($fn_name:expr, $($args:expr),+) => {
$crate::sgx::ocall_packed_func($fn_name, &[$($args.into(),)+])
.expect(concat!("Error calling `", $fn_name, "`"))
};
($fn_name:expr) => {
$crate::sgx::ocall_packed_func($fn_name, &Vec::new())
.expect(concat!("Error calling `", $fn_name, "`"))
}
}
pub fn shutdown() {
if env!("TVM_NUM_THREADS") != "0" {
sgx_join_threads()
}
}
impl Drop for SystemLibModule {
fn drop(&mut self) {
shutdown()
}
}
...@@ -18,30 +18,18 @@ ...@@ -18,30 +18,18 @@
*/ */
use std::{ use std::{
env,
os::raw::{c_int, c_void}, os::raw::{c_int, c_void},
sync::{ sync::{
atomic::{AtomicUsize, Ordering}, atomic::{AtomicUsize, Ordering},
Arc, Barrier, Arc, Barrier,
}, },
};
#[cfg(not(target_env = "sgx"))]
use num_cpus;
#[cfg(not(target_env = "sgx"))]
use std::{
env,
thread::{self, JoinHandle}, thread::{self, JoinHandle},
}; };
#[cfg(target_env = "sgx")] use crossbeam::channel::{Sender, Receiver, bounded};
use std::{collections::VecDeque, ptr, sync::Mutex};
use bounded_spsc_queue::{self, Producer};
use tvm_common::ffi::TVMParallelGroupEnv; use tvm_common::ffi::TVMParallelGroupEnv;
#[cfg(target_env = "sgx")]
use super::{TVMArgValue, TVMRetValue};
pub(crate) type FTVMParallelLambda = pub(crate) type FTVMParallelLambda =
extern "C" fn(task_id: usize, penv: *const TVMParallelGroupEnv, cdata: *const c_void) -> i32; extern "C" fn(task_id: usize, penv: *const TVMParallelGroupEnv, cdata: *const c_void) -> i32;
...@@ -82,7 +70,6 @@ impl Job { ...@@ -82,7 +70,6 @@ impl Job {
/// Waits for all tasks in this `Job` to be completed. /// Waits for all tasks in this `Job` to be completed.
fn wait(&self) { fn wait(&self) {
while self.pending.load(Ordering::Acquire) > 0 { while self.pending.load(Ordering::Acquire) > 0 {
#[cfg(not(target_env = "sgx"))]
thread::yield_now(); thread::yield_now();
} }
} }
...@@ -99,9 +86,8 @@ struct Task { ...@@ -99,9 +86,8 @@ struct Task {
unsafe impl Send for Task {} unsafe impl Send for Task {}
unsafe impl Sync for Task {} unsafe impl Sync for Task {}
impl FnOnce<()> for Task { impl Task {
type Output = i32; fn run(self) -> i32 {
extern "rust-call" fn call_once(self, _args: ()) -> Self::Output {
let status = (self.flambda)(self.id, &self.penv as *const _, self.cdata); let status = (self.flambda)(self.id, &self.penv as *const _, self.cdata);
self.pending.fetch_sub(1, Ordering::AcqRel); self.pending.fetch_sub(1, Ordering::AcqRel);
status status
...@@ -111,45 +97,23 @@ impl FnOnce<()> for Task { ...@@ -111,45 +97,23 @@ impl FnOnce<()> for Task {
#[derive(Default)] #[derive(Default)]
struct Threads { struct Threads {
#[allow(unused)] #[allow(unused)]
#[cfg(not(target_env = "sgx"))]
handles: Vec<JoinHandle<()>>, handles: Vec<JoinHandle<()>>,
queues: Vec<Producer<Task>>, queues: Vec<Sender<Task>>,
} }
impl<'a> Threads { impl<'a> Threads {
#[cfg(not(target_env = "sgx"))] fn launch<F: Sync + Send + FnOnce(Receiver<Task>) + 'static + Copy>(
fn launch<F: Sync + Send + FnOnce(Consumer<Task>) + 'static + Copy>(
num_threads: usize, num_threads: usize,
cb: F, cb: F,
) -> Self { ) -> Self {
let (handles, queues) = (0..num_threads) let (handles, queues) = (0..num_threads)
.map(|_| { .map(|_| {
let (p, c) = bounded_spsc_queue::make(2); let (p, c) = bounded(2);
let handle = thread::spawn(move || cb(c.into())); let handle = thread::spawn(move || cb(c.into()));
(handle, p) (handle, p)
}) })
.unzip(); .unzip();
Threads { Threads { handles, queues }
handles: handles,
queues: queues,
}
}
#[cfg(target_env = "sgx")]
fn launch<F: Sync + Send + FnOnce(Consumer<Task>) + 'static + Copy>(
num_threads: usize,
_cb: F,
) -> Self {
let mut consumer_queues = SGX_QUEUES.lock().unwrap();
let queues = (0..num_threads)
.map(|_| {
let (p, c) = bounded_spsc_queue::make(2);
consumer_queues.push_back(c.into());
p
})
.collect();
ocall_packed!("__sgx_thread_group_launch__", num_threads as u64);
Threads { queues: queues }
} }
} }
...@@ -165,7 +129,7 @@ impl ThreadPool { ...@@ -165,7 +129,7 @@ impl ThreadPool {
fn new() -> Self { fn new() -> Self {
let num_workers = max_concurrency(); let num_workers = max_concurrency();
ThreadPool { ThreadPool {
num_workers: num_workers, num_workers,
threads: Threads::launch(num_workers, ThreadPool::run_worker), threads: Threads::launch(num_workers, ThreadPool::run_worker),
} }
} }
...@@ -174,17 +138,18 @@ impl ThreadPool { ...@@ -174,17 +138,18 @@ impl ThreadPool {
let mut tasks = job.tasks(self.num_workers + 1); let mut tasks = job.tasks(self.num_workers + 1);
for (i, task) in tasks.split_off(1).into_iter().enumerate() { for (i, task) in tasks.split_off(1).into_iter().enumerate() {
self.threads.queues[i].push(task); self.threads.queues[i].send(task)
.expect("should send");
} }
tasks.pop().unwrap()(); tasks.pop().unwrap().run();
job.wait(); job.wait();
} }
fn run_worker(queue: Consumer<Task>) { fn run_worker(queue: Receiver<Task>) {
loop { loop {
let task = queue.pop(); let task = queue.recv().expect("should recv");
let result = task(); let result = task.run();
if result == <i32>::min_value() { if result == <i32>::min_value() {
break; break;
} else if result != 0 { } else if result != 0 {
...@@ -194,42 +159,14 @@ impl ThreadPool { ...@@ -194,42 +159,14 @@ impl ThreadPool {
} }
} }
// Send + Sync wrapper for bounded_spsc_queue::Consumer #[cfg(not(target_arch = "wasm32"))]
struct Consumer<T> {
consumer: bounded_spsc_queue::Consumer<T>,
}
impl<T> From<bounded_spsc_queue::Consumer<T>> for Consumer<T> {
fn from(c: bounded_spsc_queue::Consumer<T>) -> Self {
Consumer { consumer: c }
}
}
impl<T> Consumer<T> {
fn pop(&self) -> T {
self.consumer.pop()
}
}
unsafe impl<T> Send for Consumer<T> {}
unsafe impl<T> Sync for Consumer<T> {}
#[cfg(target_env = "sgx")]
lazy_static! {
/// Holds tasks for untrusted threads which re-enter the enclave to execute.
static ref SGX_QUEUES: Mutex<VecDeque<Consumer<Task>>> = Mutex::new(VecDeque::new());
}
#[cfg(all(not(target_arch = "wasm32"), not(target_env = "sgx")))]
fn max_concurrency() -> usize { fn max_concurrency() -> usize {
if let Ok(threads_str) = env::var("TVM_NUM_THREADS").or(env::var("OMP_NUM_THREADS")) { if let Ok(threads_str) = env::var("TVM_NUM_THREADS").or_else(|_| env::var("OMP_NUM_THREADS")) {
if let Ok(threads) = usize::from_str_radix(&threads_str, 10) { if let Ok(threads) = usize::from_str_radix(&threads_str, 10) {
return threads; return threads;
} }
} }
num_cpus::get_physical() num_cpus::get()
}
#[cfg(target_env = "sgx")]
fn max_concurrency() -> usize {
usize::from_str_radix(env!("TVM_NUM_THREADS"), 10).unwrap_or(1)
} }
#[cfg(target_arch = "wasm32")] #[cfg(target_arch = "wasm32")]
...@@ -237,69 +174,38 @@ fn max_concurrency() -> usize { ...@@ -237,69 +174,38 @@ fn max_concurrency() -> usize {
0 // wasm doesn't support threads yet 0 // wasm doesn't support threads yet
} }
#[cfg(target_env = "sgx")]
pub fn tvm_run_worker(_args: &[TVMArgValue]) -> TVMRetValue {
let q = {
let mut qs = SGX_QUEUES.lock().unwrap();
qs.pop_front()
// `qs: MutexGuard` needs to be dropped here since `run_worker` won't return
};
if let Some(q) = q {
ThreadPool::run_worker(q);
}
TVMRetValue::default()
}
#[no_mangle] #[no_mangle]
pub extern "C" fn TVMBackendParallelLaunch( pub extern "C" fn TVMBackendParallelLaunch(
cb: FTVMParallelLambda, cb: FTVMParallelLambda,
cdata: *const c_void, cdata: *const c_void,
num_task: usize, num_task: usize,
) -> c_int { ) -> c_int {
if max_concurrency() == 0 { if max_concurrency() < 2 {
let penv = TVMParallelGroupEnv { let penv = TVMParallelGroupEnv {
sync_handle: 0 as *mut c_void, sync_handle: std::ptr::null_mut(),
num_task: 1, num_task: 1,
}; };
cb(0, &penv as *const _, cdata); cb(0, &penv as *const _, cdata);
} else { } else {
THREAD_POOL.with(|pool| { THREAD_POOL.with(|pool| {
pool.launch(Job { pool.launch(Job {
cb: cb, cb,
cdata: cdata, cdata,
req_num_tasks: num_task, req_num_tasks: num_task,
pending: Arc::new(AtomicUsize::new(0)), pending: Arc::new(AtomicUsize::new(0)),
}); });
}); });
} }
return 0; 0
}
#[cfg(target_env = "sgx")]
pub(crate) fn sgx_join_threads() {
extern "C" fn poison_pill(
_task_id: usize,
_penv: *const TVMParallelGroupEnv,
_cdata: *const c_void,
) -> i32 {
<i32>::min_value()
}
THREAD_POOL.with(|pool| {
pool.launch(Job {
cb: poison_pill,
cdata: ptr::null(),
req_num_tasks: 0,
pending: Arc::new(AtomicUsize::new(0)),
});
});
ocall_packed!("__sgx_thread_group_join__", 0);
} }
// @see issue 988 for information on why this function is used. // @see issue 988 for information on why this function is used.
#[no_mangle] #[no_mangle]
pub extern "C" fn TVMBackendParallelBarrier(_task_id: usize, penv: *const TVMParallelGroupEnv) { pub unsafe extern "C" fn TVMBackendParallelBarrier(
let barrier: &Arc<Barrier> = unsafe { &*((*penv).sync_handle as *const Arc<Barrier>) }; _task_id: usize,
penv: *const TVMParallelGroupEnv,
) {
let barrier: &Arc<Barrier> = &*((*penv).sync_handle as *const Arc<Barrier>);
barrier.wait(); barrier.wait();
} }
...@@ -323,7 +229,7 @@ mod tests { ...@@ -323,7 +229,7 @@ mod tests {
penv: *const TVMParallelGroupEnv, penv: *const TVMParallelGroupEnv,
cdata: *const c_void, cdata: *const c_void,
) -> i32 { ) -> i32 {
if cdata == ptr::null() { if cdata.is_null() {
return 0; return 0;
} }
unsafe { unsafe {
......
...@@ -29,6 +29,11 @@ use crate::allocator::Allocation; ...@@ -29,6 +29,11 @@ use crate::allocator::Allocation;
const WS_ALIGN: usize = 64; // taken from `kTempAllocaAlignment` in `device_api.h` const WS_ALIGN: usize = 64; // taken from `kTempAllocaAlignment` in `device_api.h`
pub fn remove_item<T: PartialEq>(vec: &mut Vec<T>, item: &T) -> Option<T> {
let pos = vec.iter().position(|x| *x == *item)?;
Some(vec.remove(pos))
}
struct WorkspacePool { struct WorkspacePool {
workspaces: Vec<Allocation>, workspaces: Vec<Allocation>,
free: Vec<usize>, free: Vec<usize>,
...@@ -51,7 +56,7 @@ impl WorkspacePool { ...@@ -51,7 +56,7 @@ impl WorkspacePool {
} }
fn alloc(&mut self, size: usize) -> Result<*mut u8, Error> { fn alloc(&mut self, size: usize) -> Result<*mut u8, Error> {
if self.free.len() == 0 { if self.free.is_empty() {
return self.alloc_new(size); return self.alloc_new(size);
} }
let idx = self let idx = self
...@@ -64,15 +69,12 @@ impl WorkspacePool { ...@@ -64,15 +69,12 @@ impl WorkspacePool {
} }
cur_ws_idx.or(Some(idx)).and_then(|cur_idx| { cur_ws_idx.or(Some(idx)).and_then(|cur_idx| {
let cur_size = self.workspaces[cur_idx].size(); let cur_size = self.workspaces[cur_idx].size();
Some(match ws_size <= cur_size { Some(if ws_size <= cur_size { idx } else { cur_idx })
true => idx,
false => cur_idx,
})
}) })
}); });
match idx { match idx {
Some(idx) => { Some(idx) => {
self.free.remove_item(&idx).unwrap(); remove_item(&mut self.free, &idx).unwrap();
self.in_use.push(idx); self.in_use.push(idx);
Ok(self.workspaces[idx].as_mut_ptr()) Ok(self.workspaces[idx].as_mut_ptr())
} }
...@@ -90,9 +92,10 @@ impl WorkspacePool { ...@@ -90,9 +92,10 @@ impl WorkspacePool {
break; break;
} }
} }
Ok(self if let Some(ws_idx) = ws_idx {
.free self.free.push(ws_idx);
.push(ws_idx.ok_or(format_err!("Tried to free nonexistent workspace."))?)) }
Ok(())
} }
} }
...@@ -133,5 +136,5 @@ pub extern "C" fn TVMBackendFreeWorkspace( ...@@ -133,5 +136,5 @@ pub extern "C" fn TVMBackendFreeWorkspace(
Err(_) => -1, Err(_) => -1,
}) as c_int }) as c_int
}); });
return 0; 0
} }
...@@ -26,11 +26,38 @@ use std::{convert::TryFrom, fs, io::Read}; ...@@ -26,11 +26,38 @@ use std::{convert::TryFrom, fs, io::Read};
use tvm_runtime::Graph; use tvm_runtime::Graph;
macro_rules! mf_dir {
($p:literal) => {
concat!(env!("CARGO_MANIFEST_DIR"), $p)
};
}
static PARAMS_FIXTURE_PATH: &str = mf_dir!("/tests/graph.params");
#[test] #[test]
fn test_load_graph() { fn test_load_graph() {
let output = std::process::Command::new(mf_dir!("/tests/build_model.py"))
.env(
"PYTHONPATH",
concat!(
mf_dir!("/../../python"),
":",
mf_dir!("/../../nnvm/python"),
":",
mf_dir!("/../../topi/python")
),
)
.output()
.expect("Failed to build test model");
assert!(
std::path::Path::new(PARAMS_FIXTURE_PATH).exists(),
"Could not build test graph fixture: STDOUT:\n\n{}\nSTDERR: {}\n\n",
String::from_utf8(output.stdout).unwrap(),
String::from_utf8(output.stderr).unwrap()
);
let mut params_bytes = Vec::new(); let mut params_bytes = Vec::new();
fs::File::open(concat!(env!("CARGO_MANIFEST_DIR"), "/tests/graph.params")) fs::File::open(PARAMS_FIXTURE_PATH)
.expect("Could not find TVM graph. Did you run `tests/build_model.py`?") .unwrap()
.read_to_end(&mut params_bytes) .read_to_end(&mut params_bytes)
.unwrap(); .unwrap();
let _params = tvm_runtime::load_param_dict(&params_bytes); let _params = tvm_runtime::load_param_dict(&params_bytes);
......
...@@ -22,10 +22,10 @@ license = "Apache-2.0" ...@@ -22,10 +22,10 @@ license = "Apache-2.0"
authors = ["TVM Contributors"] authors = ["TVM Contributors"]
[dependencies] [dependencies]
ndarray="0.12.1" ndarray="0.12"
serde = "1.0.59" serde = "1.0"
serde_json = "1.0.17" serde_json = "1.0"
tvm-runtime = { path = "../../" } tvm-runtime = { path = "../../" }
[build-dependencies] [build-dependencies]
ar = "0.6.0" ar = "0.6"
...@@ -33,11 +33,11 @@ const IN_DIM: usize = 8; ...@@ -33,11 +33,11 @@ const IN_DIM: usize = 8;
macro_rules! check_sum { macro_rules! check_sum {
($e:expr, $a:ident, $b:ident) => { ($e:expr, $a:ident, $b:ident) => {
let a = Array::try_from($e.get_input(stringify!($a)).unwrap()).unwrap(); let a = Array::try_from($e.get_input(stringify!($a)).unwrap().to_owned()).unwrap();
check_sum!(a, $b); check_sum!(a, $b);
}; };
($e:expr, $a:expr, $b:ident) => { ($e:expr, $a:expr, $b:ident) => {
let a = Array::try_from($e.get_output($a).unwrap()).unwrap(); let a = Array::try_from($e.get_output($a).unwrap().to_owned()).unwrap();
check_sum!(a, $b); check_sum!(a, $b);
}; };
($a:ident, $b:ident) => { ($a:ident, $b:ident) => {
...@@ -73,11 +73,11 @@ fn main() { ...@@ -73,11 +73,11 @@ fn main() {
.collect::<Vec<f32>>(), .collect::<Vec<f32>>(),
) )
.unwrap(); .unwrap();
let w = Array::try_from(params.get("dense0_weight").unwrap()) let w = Array::try_from(params.get("dense0_weight").unwrap().to_owned())
.unwrap() .unwrap()
.into_shape((IN_DIM * 2, IN_DIM)) .into_shape((IN_DIM * 2, IN_DIM))
.unwrap(); .unwrap();
let b = Array::try_from(params.get("dense0_bias").unwrap()).unwrap(); let b = Array::try_from(params.get("dense0_bias").unwrap().to_owned()).unwrap();
let dense = x.dot(&w.t()) + &b; let dense = x.dot(&w.t()) + &b;
let left = dense.slice(s![.., 0..IN_DIM]); let left = dense.slice(s![.., 0..IN_DIM]);
let right = dense.slice(s![.., IN_DIM..]); let right = dense.slice(s![.., IN_DIM..]);
......
...@@ -22,8 +22,8 @@ license = "Apache-2.0" ...@@ -22,8 +22,8 @@ license = "Apache-2.0"
authors = ["TVM Contributors"] authors = ["TVM Contributors"]
[dependencies] [dependencies]
ndarray="0.12.1" ndarray="0.12"
tvm-runtime = { path = "../../" } tvm-runtime = { path = "../../" }
[build-dependencies] [build-dependencies]
ar = "0.6.0" ar = "0.6"
...@@ -28,12 +28,7 @@ ...@@ -28,12 +28,7 @@
#include <tvm/runtime/module.h> #include <tvm/runtime/module.h>
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <tvm/runtime/device_api.h> #include <tvm/runtime/device_api.h>
#ifdef _LIBCPP_SGX_CONFIG
#include "sgx/trusted/runtime.h"
#endif
#ifndef _LIBCPP_SGX_NO_IOSTREAMS
#include <sstream> #include <sstream>
#endif
#include <array> #include <array>
#include <algorithm> #include <algorithm>
#include <string> #include <string>
...@@ -174,7 +169,6 @@ void DeviceAPI::SyncStreamFromTo(TVMContext ctx, ...@@ -174,7 +169,6 @@ void DeviceAPI::SyncStreamFromTo(TVMContext ctx,
LOG(FATAL) << "Device does not support stream api."; LOG(FATAL) << "Device does not support stream api.";
} }
#ifndef _LIBCPP_SGX_NO_IOSTREAMS
//-------------------------------------------------------- //--------------------------------------------------------
// Error handling mechanism // Error handling mechanism
// ------------------------------------------------------- // -------------------------------------------------------
...@@ -338,11 +332,6 @@ std::string NormalizeError(std::string err_msg) { ...@@ -338,11 +332,6 @@ std::string NormalizeError(std::string err_msg) {
return os.str(); return os.str();
} }
#else
std::string NormalizeError(std::string err_msg) {
return err_msg;
}
#endif
} // namespace runtime } // namespace runtime
} // namespace tvm } // namespace tvm
...@@ -366,11 +355,7 @@ int TVMAPIHandleException(const std::runtime_error &e) { ...@@ -366,11 +355,7 @@ int TVMAPIHandleException(const std::runtime_error &e) {
} }
void TVMAPISetLastError(const char* msg) { void TVMAPISetLastError(const char* msg) {
#ifndef _LIBCPP_SGX_CONFIG
TVMAPIRuntimeStore::Get()->last_error = msg; TVMAPIRuntimeStore::Get()->last_error = msg;
#else
sgx::OCallPackedFunc("__sgx_set_last_error__", msg);
#endif
} }
int TVMModLoadFromFile(const char* file_name, int TVMModLoadFromFile(const char* file_name,
......
...@@ -25,11 +25,7 @@ ...@@ -25,11 +25,7 @@
#include <dmlc/logging.h> #include <dmlc/logging.h>
#include <dmlc/thread_local.h> #include <dmlc/thread_local.h>
#include <algorithm> #include <algorithm>
#ifndef _LIBCPP_SGX_CONFIG
#include "mt_random_engine.cc" #include "mt_random_engine.cc"
#else
#include "sgx_random_engine.cc"
#endif
#define DLPACK_INTEGER_TYPE_SWITCH(type, DType, ...) \ #define DLPACK_INTEGER_TYPE_SWITCH(type, DType, ...) \
if (type.code == kDLInt && type.bits == 32) { \ if (type.code == kDLInt && type.bits == 32) { \
......
...@@ -50,7 +50,7 @@ class CPUDeviceAPI final : public DeviceAPI { ...@@ -50,7 +50,7 @@ class CPUDeviceAPI final : public DeviceAPI {
#if _MSC_VER #if _MSC_VER
ptr = _aligned_malloc(nbytes, alignment); ptr = _aligned_malloc(nbytes, alignment);
if (ptr == nullptr) throw std::bad_alloc(); if (ptr == nullptr) throw std::bad_alloc();
#elif defined(_LIBCPP_SGX_CONFIG) || (defined(__ANDROID__) && __ANDROID_API__ < 17) #elif defined(__ANDROID__) && __ANDROID_API__ < 17
ptr = memalign(alignment, nbytes); ptr = memalign(alignment, nbytes);
if (ptr == nullptr) throw std::bad_alloc(); if (ptr == nullptr) throw std::bad_alloc();
#else #else
......
...@@ -73,7 +73,6 @@ std::string GetFileFormat(const std::string& file_name, ...@@ -73,7 +73,6 @@ std::string GetFileFormat(const std::string& file_name,
const std::string& format) { const std::string& format) {
std::string fmt = format; std::string fmt = format;
if (fmt.length() == 0) { if (fmt.length() == 0) {
if (file_name.find(".signed.so") != std::string::npos) return "sgx";
size_t pos = file_name.find_last_of("."); size_t pos = file_name.find_last_of(".");
if (pos != std::string::npos) { if (pos != std::string::npos) {
return file_name.substr(pos + 1, file_name.length() - pos - 1); return file_name.substr(pos + 1, file_name.length() - pos - 1);
......
...@@ -67,11 +67,7 @@ void GraphRuntime::Run() { ...@@ -67,11 +67,7 @@ void GraphRuntime::Run() {
void GraphRuntime::Init(const std::string& graph_json, void GraphRuntime::Init(const std::string& graph_json,
tvm::runtime::Module module, tvm::runtime::Module module,
const std::vector<TVMContext>& ctxs) { const std::vector<TVMContext>& ctxs) {
#ifndef _LIBCPP_SGX_NO_IOSTREAMS
std::istringstream is(graph_json); std::istringstream is(graph_json);
#else
std::string is = graph_json;
#endif
dmlc::JSONReader reader(&is); dmlc::JSONReader reader(&is);
this->Load(&reader); this->Load(&reader);
module_ = module; module_ = module;
......
...@@ -21,9 +21,7 @@ ...@@ -21,9 +21,7 @@
* \file module_util.cc * \file module_util.cc
* \brief Utilities for module. * \brief Utilities for module.
*/ */
#ifndef _LIBCPP_SGX_CONFIG
#include <dmlc/memory_io.h> #include <dmlc/memory_io.h>
#endif
#include <tvm/runtime/module.h> #include <tvm/runtime/module.h>
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <string> #include <string>
...@@ -121,7 +119,6 @@ void InitContextFunctions(std::function<void*(const char*)> fgetsymbol) { ...@@ -121,7 +119,6 @@ void InitContextFunctions(std::function<void*(const char*)> fgetsymbol) {
* \return Root Module. * \return Root Module.
*/ */
runtime::Module ProcessModuleBlob(const char* mblob, ObjectPtr<Library> lib) { runtime::Module ProcessModuleBlob(const char* mblob, ObjectPtr<Library> lib) {
#ifndef _LIBCPP_SGX_CONFIG
CHECK(mblob != nullptr); CHECK(mblob != nullptr);
uint64_t nbytes = 0; uint64_t nbytes = 0;
for (size_t i = 0; i < sizeof(nbytes); ++i) { for (size_t i = 0; i < sizeof(nbytes); ++i) {
...@@ -180,10 +177,6 @@ runtime::Module ProcessModuleBlob(const char* mblob, ObjectPtr<Library> lib) { ...@@ -180,10 +177,6 @@ runtime::Module ProcessModuleBlob(const char* mblob, ObjectPtr<Library> lib) {
// invariance: root module is always at location 0. // invariance: root module is always at location 0.
// The module order is collected via DFS // The module order is collected via DFS
return modules[0]; return modules[0];
#else
LOG(FATAL) << "SGX does not support ImportModuleBlob";
return Module();
#endif
} }
Module CreateModuleFromLibrary(ObjectPtr<Library> lib) { Module CreateModuleFromLibrary(ObjectPtr<Library> lib) {
......
...@@ -26,9 +26,7 @@ ...@@ -26,9 +26,7 @@
#include <tvm/runtime/packed_func.h> #include <tvm/runtime/packed_func.h>
#include <unordered_set> #include <unordered_set>
#include <cstring> #include <cstring>
#ifndef _LIBCPP_SGX_CONFIG
#include "file_util.h" #include "file_util.h"
#endif
namespace tvm { namespace tvm {
namespace runtime { namespace runtime {
...@@ -77,7 +75,6 @@ PackedFunc ModuleNode::GetFunction(const std::string& name, bool query_imports) ...@@ -77,7 +75,6 @@ PackedFunc ModuleNode::GetFunction(const std::string& name, bool query_imports)
Module Module::LoadFromFile(const std::string& file_name, Module Module::LoadFromFile(const std::string& file_name,
const std::string& format) { const std::string& format) {
#ifndef _LIBCPP_SGX_CONFIG
std::string fmt = GetFileFormat(file_name, format); std::string fmt = GetFileFormat(file_name, format);
CHECK(fmt.length() != 0) CHECK(fmt.length() != 0)
<< "Cannot deduce format of file " << file_name; << "Cannot deduce format of file " << file_name;
...@@ -91,9 +88,6 @@ Module Module::LoadFromFile(const std::string& file_name, ...@@ -91,9 +88,6 @@ Module Module::LoadFromFile(const std::string& file_name,
<< load_f_name << ") is not presented."; << load_f_name << ") is not presented.";
Module m = (*f)(file_name, format); Module m = (*f)(file_name, format);
return m; return m;
#else
LOG(FATAL) << "SGX does not support LoadFromFile";
#endif
} }
void ModuleNode::SaveToFile(const std::string& file_name, void ModuleNode::SaveToFile(const std::string& file_name,
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file common.h
* \brief TVM SGX common API.
*/
#ifndef TVM_RUNTIME_SGX_COMMON_H_
#define TVM_RUNTIME_SGX_COMMON_H_
#include <sgx_error.h>
namespace tvm {
namespace runtime {
namespace sgx {
#define TVM_SGX_CHECKED_CALL(Function) \
sgx_status_t TVM_STR_CONCAT(__sgx_status_, __LINE__) = SGX_ERROR_UNEXPECTED; \
TVM_STR_CONCAT(__sgx_status_, __LINE__) = Function; \
CHECK_EQ(TVM_STR_CONCAT(__sgx_status_, __LINE__), SGX_SUCCESS) \
<< "SGX Error: " << TVM_STR_CONCAT(__sgx_status_, __LINE__);
} // namespace sgx
} // namespace runtime
} // namespace tvm
#endif // TVM_RUNTIME_SGX_COMMON_H_
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file ecall_registry.h
* \brief The global registry of packed functions available via ecall_packed_func.
*/
#ifndef TVM_RUNTIME_SGX_TRUSTED_ECALL_REGISTRY_H_
#define TVM_RUNTIME_SGX_TRUSTED_ECALL_REGISTRY_H_
#include <dmlc/logging.h>
#include <tvm/runtime/registry.h>
#include <string>
#include <algorithm>
#include <vector>
namespace tvm {
namespace runtime {
namespace sgx {
class ECallRegistry: public Registry {
public:
explicit ECallRegistry(std::string name) {
name_ = name;
}
Registry& set_body(PackedFunc f) {
func_ = f;
return *this;
}
Registry& set_body(PackedFunc::FType f) { // NOLINT(*)
return set_body(PackedFunc(f));
}
static Registry& Register(const std::string& name, bool override = false) {
for (auto& r : exports_) {
if (r.name_ == name) {
CHECK(override) << "ecall " << name << " is already registered";
return r;
}
}
TVM_SGX_CHECKED_CALL(
tvm_ocall_register_export(name.c_str(), exports_.size()));
exports_.emplace_back(name);
return exports_.back();
}
static bool Remove(const std::string& name) {
LOG(FATAL) << "Removing enclave exports is not supported.";
}
static const PackedFunc* Get(const std::string& name) {
for (const auto& r : exports_) {
if (r.name_ == name) return &r.func_;
}
return nullptr;
}
static const PackedFunc* Get(unsigned func_id) {
return func_id >= exports_.size() ? nullptr : &exports_[func_id].func_;
}
static std::vector<std::string> ListNames() {
std::vector<std::string> names;
names.resize(exports_.size());
std::transform(exports_.begin(), exports_.end(), names.begin(),
[](ECallRegistry r) { return r.name_; });
return names;
}
static std::vector<ECallRegistry> exports_;
};
std::vector<ECallRegistry> ECallRegistry::exports_;
/*!
* \brief Register a function callable via ecall_packed_func
* \code
* TVM_REGISTER_ENCLAVE_FUNC("DoThing")
* .set_body([](TVMArgs args, TVMRetValue* rv) {
* });
* \endcode
*/
#define TVM_REGISTER_ENCLAVE_FUNC(OpName) \
TVM_STR_CONCAT(TVM_FUNC_REG_VAR_DEF, __COUNTER__) = \
::tvm::runtime::sgx::ECallRegistry::Register(OpName, true)
} // namespace sgx
} // namespace runtime
} // namespace tvm
#endif // TVM_RUNTIME_SGX_TRUSTED_ECALL_REGISTRY_H_
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file runtime_t.cc
*/
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/packed_func.h>
#include "../../c_runtime_api.cc"
#include "../../cpu_device_api.cc"
#include "../../module.cc"
#include "../../module_util.cc"
#include "../../registry.cc"
#include "../../system_lib_module.cc"
#include "../../thread_pool.cc"
#include "../../workspace_pool.cc"
#include "ecall_registry.h"
#include "runtime.h"
#include "threading_backend.cc"
namespace tvm {
namespace runtime {
namespace sgx {
extern "C" {
void tvm_ecall_init(TVMRetValueHandle ret) {}
void tvm_ecall_packed_func(int func_id,
const TVMValue* arg_values,
const int* type_codes,
int num_args,
TVMRetValueHandle ret) {
const PackedFunc* f = ECallRegistry::Get(func_id);
CHECK(f != nullptr) << "ecall function not found.";
TVMRetValue rv;
f->CallPacked(TVMArgs(arg_values, type_codes, num_args), &rv);
int ret_type_code = rv.type_code();
if (ret_type_code == kTVMNullptr) return;
TVMValue ret_value;
if (ret_type_code == kTVMBytes || ret_type_code == kTVMStr) {
// allocate a buffer in untrusted, copy the values in
std::string bytes = rv;
void* ret_buf;
TVM_SGX_CHECKED_CALL(tvm_ocall_reserve_space(
&ret_buf, bytes.size() + sizeof(TVMByteArray), sizeof(uint64_t)));
char* data_buf = static_cast<char*>(ret_buf) + sizeof(TVMByteArray);
memcpy(data_buf, bytes.data(), bytes.size());
TVMByteArray* arr = static_cast<TVMByteArray*>(ret_buf);
arr->data = data_buf;
arr->size = bytes.size();
ret_value = TVMValue{.v_handle = arr};
ret_type_code = kTVMBytes;
} else {
rv.MoveToCHost(&ret_value, &ret_type_code);
}
TVM_SGX_CHECKED_CALL(tvm_ocall_set_return(ret, &ret_value, &ret_type_code, 1));
}
} // extern "C"
TVM_REGISTER_ENCLAVE_FUNC("__tvm_main__")
.set_body([](TVMArgs args, TVMRetValue* rv) {
Module mod = (*Registry::Get("runtime.SystemLib"))();
mod.GetFunction("default_function").CallPacked(args, rv);
});
} // namespace sgx
} // namespace runtime
} // namespace tvm
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file trusted/runtime.h
* \brief TVM SGX trusted API.
*/
#ifndef TVM_RUNTIME_SGX_TRUSTED_RUNTIME_H_
#define TVM_RUNTIME_SGX_TRUSTED_RUNTIME_H_
#include <tvm/runtime/packed_func.h>
#include <string>
#include <utility>
#include "../common.h"
namespace tvm {
namespace runtime {
namespace sgx {
template<typename... Args>
inline TVMRetValue OCallPackedFunc(std::string name, Args&& ...args) {
const int kNumArgs = sizeof...(Args);
const int kArraySize = kNumArgs > 0 ? kNumArgs : 1;
TVMValue values[kArraySize];
int type_codes[kArraySize];
detail::for_each(TVMArgsSetter(values, type_codes),
std::forward<Args>(args)...);
TVMValue ret_val;
int ret_type_code;
TVM_SGX_CHECKED_CALL(tvm_ocall_packed_func(name.c_str(),
values,
type_codes,
kNumArgs,
&ret_val,
&ret_type_code));
TVMRetValue* rv = new TVMRetValue();
*rv = TVMArgValue(ret_val, ret_type_code);
return *rv;
}
} // namespace sgx
} // namespace runtime
} // namespace tvm
#endif // TVM_RUNTIME_SGX_TRUSTED_RUNTIME_H_
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file sgx/threading_backend.cc
* \brief SGX threading backend
*/
#include <tvm/runtime/threading_backend.h>
#include <dmlc/logging.h>
#include <sgx_edger8r.h>
#include <sgx_trts.h>
#include <atomic>
#include "runtime.h"
#ifndef TVM_SGX_MAX_CONCURRENCY
#define TVM_SGX_MAX_CONCURRENCY 1
#endif
namespace tvm {
namespace runtime {
namespace threading {
class ThreadGroup::Impl {
public:
Impl(int num_workers, std::function<void(int)> worker_callback,
bool exclude_worker0)
: num_workers_(num_workers),
worker_callback_(worker_callback),
next_task_id_(exclude_worker0) {
CHECK(num_workers <= TVM_SGX_MAX_CONCURRENCY)
<< "Tried spawning more threads than allowed by TVM_SGX_MAX_CONCURRENCY.";
sgx::OCallPackedFunc("__sgx_thread_group_launch__",
num_workers_, reinterpret_cast<void*>(this));
}
~Impl() {
sgx::OCallPackedFunc("__sgx_thread_group_join__");
}
void RunTask() {
int task_id = next_task_id_++;
CHECK(task_id < num_workers_)
<< "More workers entered enclave than allowed by TVM_SGX_MAX_CONCURRENCY";
worker_callback_(task_id);
}
private:
int num_workers_;
std::function<void(int)> worker_callback_;
std::atomic<int> next_task_id_;
};
ThreadGroup::ThreadGroup(int num_workers,
std::function<void(int)> worker_callback,
bool exclude_worker0)
: impl_(new ThreadGroup::Impl(num_workers, worker_callback, exclude_worker0)) {}
void ThreadGroup::Join() {}
int ThreadGroup::Configure(AffinityMode mode, int nthreads, bool exclude_worker0) {
int max_conc = MaxConcurrency();
if (!nthreads || ntheads > max_conc) {
return max_conc;
}
return nthreads;
}
ThreadGroup::~ThreadGroup() { delete impl_; }
void Yield() {}
int MaxConcurrency() { return TVM_SGX_MAX_CONCURRENCY; }
TVM_REGISTER_ENCLAVE_FUNC("__tvm_run_worker__")
.set_body([](TVMArgs args, TVMRetValue* rv) {
void* tg = args[0];
if (!sgx_is_within_enclave(tg, sizeof(ThreadGroup::Impl))) return;
reinterpret_cast<ThreadGroup::Impl*>(tg)->RunTask();
});
} // namespace threading
} // namespace runtime
} // namespace tvm
enclave {
from "sgx_tstdc.edl" import *;
from "sgx_stdio.edl" import *;
from "sgx_backtrace.edl" import *;
trusted {
public void tvm_ecall_init([isptr, user_check] TVMRetValueHandle ret);
public void tvm_ecall_packed_func(int func_id,
[in, count=num_args] const TVMValue* arg_values,
[in, count=num_args] const int* type_codes,
int num_args,
[out] TVMValue* ret_val,
[out] int* ret_type_code);
};
untrusted {
void tvm_ocall_packed_func([in, string] const char* name,
[in, count=num_args] const TVMValue* arg_values,
[in, count=num_args] const int* type_codes,
int num_args,
[out] TVMValue* ret_val,
[out] int* ret_type_code);
void tvm_ocall_register_export([in, string] const char* name, int func_id);
};
};
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file sgx_module.cc
* \brief SGX enclave module.
*/
#include <dmlc/logging.h>
#include <sgx_urts.h>
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/threading_backend.h>
#include <algorithm>
#include <fstream>
#include <iostream>
#include <iterator>
#include <sstream>
#include <string>
#include <unordered_map>
#include "../common.h"
#include "../../file_util.h"
#include "./tvm_u.h"
namespace tvm {
namespace runtime {
class SGXModuleNode;
namespace sgx {
class EnclaveContext {
public:
explicit EnclaveContext(SGXModuleNode* mod) {
CHECK(Context()->mod_ == nullptr)
<< "Tried overriding existing enclave context.";
CHECK(mod != nullptr) << "Tried setting null enclave context.";
Context()->mod_ = mod;
}
~EnclaveContext() {
Context()->mod_ = nullptr;
}
static SGXModuleNode* GetModule() {
SGXModuleNode* ctx = Context()->mod_;
CHECK(ctx != nullptr) << "No current enclave context";
return ctx;
}
private:
EnclaveContext() {}
SGXModuleNode* mod_;
static EnclaveContext* Context() {
static thread_local EnclaveContext inst;
return &inst;
}
};
} // namespace sgx
class SGXModuleNode : public ModuleNode {
public:
~SGXModuleNode() {
if (eid_) {
sgx::EnclaveContext ctx(this);
sgx_destroy_enclave(eid_);
}
}
void Init(const std::string& enclave_file) {
std::string token_file = GetCacheDir() + "/" +
GetFileBasename(enclave_file) + ".token";
sgx_launch_token_t token = {0};
int token_updated = 0;
try {
std::ifstream ifs(token_file, std::fstream::in | std::fstream::binary);
ifs.exceptions(std::ifstream::failbit | std::ifstream::badbit);
ifs >> token;
} catch (std::ifstream::failure e) {
memset(&token, 0x0, sizeof(sgx_launch_token_t));
}
TVM_SGX_CHECKED_CALL(sgx_create_enclave(
enclave_file.c_str(), SGX_DEBUG_FLAG, &token, &token_updated, &eid_, NULL));
sgx::EnclaveContext ctx(this);
TVMRetValue rv;
TVM_SGX_CHECKED_CALL(tvm_ecall_init(eid_, &rv));
if (!token_updated) return;
try {
std::ofstream ofs(token_file, std::fstream::trunc | std::fstream::binary);
ofs.exceptions(std::ifstream::failbit | std::ifstream::badbit);
ofs << token;
} catch (std::ifstream::failure e) {
LOG(INFO) << "Could not save SGX launch token to " << token_file;
}
}
const char* type_key() const final {
return "sgx";
}
PackedFunc GetFunction(
const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) final {
auto exported = exports_.find(name);
if (exported == exports_.end()) return PackedFunc();
int func_id = exported->second;
return PackedFunc([this, func_id](TVMArgs args, TVMRetValue* rv) {
sgx::EnclaveContext ctx(this);
TVMValue ret_value;
int ret_type_code;
TVM_SGX_CHECKED_CALL(tvm_ecall_packed_func(eid_, func_id,
args.values, args.type_codes, args.num_args, &ret_value, &ret_type_code));
*rv = TVMArgValue(ret_value, ret_type_code);
});
}
void RunWorkers(int num_tasks) {
std::function<void(int)> runner = [this](int _worker_id) {
this->GetFunction("__tvm_run_worker__",
std::shared_ptr<SGXModuleNode>(nullptr))();
};
thread_group_.reset(new tvm::runtime::threading::ThreadGroup(
num_tasks, runner, false /* include_main_thread */));
}
void JoinThreads() {
thread_group_->Join();
}
void RegisterExport(std::string name, int func_id) {
exports_[name] = func_id;
}
private:
// ID of the loaded enclave
sgx_enclave_id_t eid_;
// Names and IDs of functions exported by the enclave module
std::unordered_map<std::string, int> exports_;
std::unique_ptr<tvm::runtime::threading::ThreadGroup> thread_group_;
};
namespace sgx {
TVM_REGISTER_GLOBAL("__sgx_thread_group_launch__")
.set_body([](TVMArgs args, TVMRetValue* rv) {
EnclaveContext::GetModule()->RunWorkers(args[0]);
});
TVM_REGISTER_GLOBAL("__sgx_thread_group_join__")
.set_body([](TVMArgs args, TVMRetValue* rv) {
EnclaveContext::GetModule()->JoinThreads();
});
TVM_REGISTER_GLOBAL("__sgx_set_last_error__")
.set_body([](TVMArgs args, TVMRetValue* rv) {
std::string err = args[0];
TVMAPISetLastError(err.c_str());
});
TVM_REGISTER_GLOBAL("__sgx_println__")
.set_body([](TVMArgs args, TVMRetValue* rv) {
std::ostringstream msg;
for (int i = 0; i < args.num_args; ++i) {
switch (args.type_codes[i]) {
case kDLInt: msg << static_cast<int64_t>(args[i]); break;
case kDLUInt: msg << static_cast<uint64_t>(args[i]); break;
case kDLFloat: msg << static_cast<double>(args[i]); break;
case kTVMStr:
case kTVMBytes: {
std::string val = args[i];
msg << val;
}
break;
}
msg << " ";
}
LOG(INFO) << msg.str();
});
extern "C" {
void tvm_ocall_register_export(const char* name, int func_id) {
EnclaveContext::GetModule()->RegisterExport(name, func_id);
}
void tvm_ocall_packed_func(const char* name,
const TVMValue* arg_values,
const int* type_codes,
int num_args,
TVMValue* ret_val,
int* ret_type_code) {
const PackedFunc* f = Registry::Get(name);
CHECK(f != nullptr) << "ocall to nonexistent function \"" << name << "\"";
TVMRetValue rv;
f->CallPacked(TVMArgs(arg_values, type_codes, num_args), &rv);
rv.MoveToCHost(ret_val, ret_type_code);
}
// Allocates space for return values. The returned pointer is only valid between
// successive calls to `tvm_ocall_reserve_space`.
TVM_REGISTER_GLOBAL("__sgx_reserve_space__")
.set_body([](TVMArgs args, TVMRetValue* rv) {
size_t num_bytes = args[0];
size_t alignment = args[1];
static TVMContext ctx = { kDLCPU, 0 };
static thread_local void* buf = nullptr;
static thread_local size_t buf_size = 0;
static thread_local size_t buf_align = 0;
if (buf_size >= num_bytes && buf_align >= alignment) *rv = nullptr;
DeviceAPI::Get(ctx)->FreeDataSpace(ctx, buf);
buf = DeviceAPI::Get(ctx)->AllocDataSpace(ctx, num_bytes, alignment, {});
buf_size = num_bytes;
buf_align = alignment;
*rv = buf;
});
} // extern "C"
} // namespace sgx
TVM_REGISTER_GLOBAL("runtime.module.loadfile_sgx")
.set_body([](TVMArgs args, TVMRetValue* rv) {
std::shared_ptr<SGXModuleNode> node = std::make_shared<SGXModuleNode>();
node->Init(args[0]);
*rv = runtime::Module(node);
});
} // namespace runtime
} // namespace tvm
...@@ -105,16 +105,14 @@ class ParallelLauncher { ...@@ -105,16 +105,14 @@ class ParallelLauncher {
tvm::runtime::threading::Yield(); tvm::runtime::threading::Yield();
} }
if (!has_error_.load()) return 0; if (!has_error_.load()) return 0;
// the following is intended to use string due to std::ostringstream os;
// security issue raised in SGX backend
std::string err("");
for (size_t i = 0; i < par_errors_.size(); ++i) { for (size_t i = 0; i < par_errors_.size(); ++i) {
if (par_errors_[i].length() != 0) { if (par_errors_[i].length() != 0) {
err += "Task " + std::to_string(i) + " error: " + par_errors_[i] + '\n'; os << "Task " << i << " error: " << par_errors_[i] << '\n';
par_errors_[i].clear(); par_errors_[i].clear();
} }
} }
TVMAPISetLastError(err.c_str()); TVMAPISetLastError(os.str().c_str());
return -1; return -1;
} }
// Signal that one job has finished. // Signal that one job has finished.
...@@ -373,11 +371,7 @@ class ThreadPool { ...@@ -373,11 +371,7 @@ class ThreadPool {
// number of workers used (can be restricted with affinity pref) // number of workers used (can be restricted with affinity pref)
int num_workers_used_; int num_workers_used_;
// if or not to exclude worker 0 and use master to run task 0 // if or not to exclude worker 0 and use master to run task 0
#ifndef _LIBCPP_SGX_CONFIG
bool exclude_worker0_{true}; bool exclude_worker0_{true};
#else
bool exclude_worker0_{false};
#endif
std::vector<std::unique_ptr<SpscTaskQueue> > queues_; std::vector<std::unique_ptr<SpscTaskQueue> > queues_;
std::unique_ptr<tvm::runtime::threading::ThreadGroup> threads_; std::unique_ptr<tvm::runtime::threading::ThreadGroup> threads_;
}; };
......
...@@ -89,6 +89,7 @@ ALLOW_FILE_NAME = { ...@@ -89,6 +89,7 @@ ALLOW_FILE_NAME = {
".gitmodules", ".gitmodules",
"CODEOWNERS", "CODEOWNERS",
".scalafmt.conf", ".scalafmt.conf",
"Cargo.lock"
} }
# List of specific files allowed in relpath to <proj_root> # List of specific files allowed in relpath to <proj_root>
...@@ -99,8 +100,8 @@ ALLOW_SPECIFIC_FILE = { ...@@ -99,8 +100,8 @@ ALLOW_SPECIFIC_FILE = {
"KEYS", "KEYS",
"DISCLAIMER", "DISCLAIMER",
"Jenkinsfile", "Jenkinsfile",
# sgx file # sgx config
"apps/sgx/enclave/sgx-deps.diff", "apps/sgx/.cargo/config",
# html for demo purposes # html for demo purposes
"tests/webgl/test_static_webgl_library.html", "tests/webgl/test_static_webgl_library.html",
"web/example_rpc.html", "web/example_rpc.html",
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment