Unverified Commit 41e1d5f9 by Jared Roesch Committed by GitHub

Revive the Rust + SGX refactor (#4976)

* Add Nick's changes's squashed

* Fix frontend compilation

* Re-enable Rust CI

* Add changes with conflicted badly

* Restructure import_module! macro in order to avoid unstable features

* Kill old unstable feature enablement

* Refactor common to use new APIs

* Move the code to stable

* Fix warning

Co-authored-by: Nick Hynes <nhynes@oasislabs.com>
parent 93dff448
......@@ -36,7 +36,6 @@ tvm_option(USE_GRAPH_RUNTIME "Build with tiny graph runtime" ON)
tvm_option(USE_GRAPH_RUNTIME_DEBUG "Build with tiny graph runtime debug mode" OFF)
tvm_option(USE_OPENMP "Build with OpenMP thread pool implementation" OFF)
tvm_option(USE_RELAY_DEBUG "Building Relay in debug mode..." OFF)
tvm_option(USE_SGX "Build with SGX" OFF)
tvm_option(USE_RTTI "Build with RTTI" ON)
tvm_option(USE_MSVC_MT "Build with MT" OFF)
tvm_option(USE_MICRO "Build with Micro" OFF)
......@@ -243,7 +242,6 @@ include(cmake/modules/OpenMP.cmake)
include(cmake/modules/Vulkan.cmake)
include(cmake/modules/Metal.cmake)
include(cmake/modules/ROCM.cmake)
include(cmake/modules/SGX.cmake)
include(cmake/modules/LLVM.cmake)
include(cmake/modules/Micro.cmake)
include(cmake/modules/ANTLR.cmake)
......@@ -283,12 +281,6 @@ else()
set_target_properties(tvm PROPERTIES COMPILE_DEFINITIONS "NDEBUG")
endif(USE_RELAY_DEBUG)
if(NOT USE_SGX STREQUAL "OFF")
add_dependencies(tvm sgx_edl)
add_dependencies(tvm_runtime sgx_edl tvm_t)
install(TARGETS tvm_t ARCHIVE DESTINATION lib${LIB_SUFFIX})
endif()
if(USE_THREADS)
message(STATUS "Build with thread support...")
set(CMAKE_THREAD_PREFER_PTHREAD TRUE)
......
......@@ -220,6 +220,7 @@ stage('Build') {
sh "${docker_run} ${ci_cpu} ./tests/scripts/task_python_vta_fsim.sh"
sh "${docker_run} ${ci_cpu} ./tests/scripts/task_python_vta_tsim.sh"
sh "${docker_run} ${ci_cpu} ./tests/scripts/task_golang.sh"
sh "${docker_run} ${ci_cpu} ./tests/scripts/task_rust.sh"
}
}
}
......
[build]
target = "x86_64-fortanix-unknown-sgx"
[target.x86_64-fortanix-unknown-sgx]
runner = "ftxsgx-runner-cargo"
../../rust/.rustfmt.toml
\ No newline at end of file
......@@ -16,17 +16,13 @@
# under the License.
[package]
name = "model-enclave"
name = "sgx-demo"
version = "0.1.0"
authors = ["Nick Hynes <nhynes@berkeley.edu>"]
[lib]
crate-type = ["staticlib"]
authors = ["Nick Hynes <nhynes@nhynes.com>"]
edition = "2018"
[dependencies]
lazy_static = "1.1.0"
tvm = { path = "../../../rust", default-features = false, features = ["sgx"] }
tvm-runtime = { path = "../../rust/runtime" }
[profile.release]
lto = true
opt-level = 3
[patch.crates-io]
"backtrace" = { git = "https://github.com/nhynes/backtrace-rs", branch = "fix-sgx" }
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
SGX_SDK ?= /opt/sgxsdk
RUST_SGX_SDK ?= /opt/rust-sgx-sdk
SGX_MODE ?= SIM
DEBUG ?= true
NUM_THREADS ?= 4
TVM_DIR ?= $(shell git rev-parse --show-toplevel)
export
sgx_edger8r := $(SGX_SDK)/bin/x64/sgx_edger8r
sgx_enclave_signer := $(SGX_SDK)/bin/x64/sgx_sign
ifneq ($(SGX_MODE), HW)
sgx_sim := _sim
endif
urts_library_name := sgx_urts$(sgx_sim)
trts_library_name := sgx_trts$(sgx_sim)
tservice_library_name := sgx_tservice$(sgx_sim)
uservice_library_name := sgx_uae_service$(sgx_sim)
pkg_cflags := -std=c++11 -fPIC \
-I$(SGX_SDK)/include \
-I$(TVM_DIR)/include \
-I$(TVM_DIR)/dlpack/include \
-I$(TVM_DIR)/dmlc-core/include
pkg_ldflags := -L$(TVM_DIR)/build -ltvm_runtime
ifneq ($(DEBUG), false)
debug := debug
enclave_cflags += -Og -g
pkg_cflags += -Og -g
else
debug := release
enclave_cflags += -O2
pkg_cflags += -O2
endif
build_dir := build
enclave_cflags := \
-I$(SGX_SDK)/include \
-I$(SGX_SDK)/include/tlibc \
-I$(SGX_SDK)/include/stdport \
-I$(SGX_SDK)/include/epid \
-I$(TVM_DIR)/include \
-I$(TVM_DIR)/dlpack/include \
-I$(TVM_DIR)/dmlc-core/include
enclave_ldflags :=\
-L$(build_dir) -L$(TVM_DIR)/build \
-Wl,--no-undefined -nostdlib -nodefaultlibs -nostartfiles -L$(SGX_SDK)/lib64\
-Wl,--whole-archive -l$(trts_library_name) -Wl,--no-whole-archive\
-Wl,--start-group\
-lsgx_tstdc -lsgx_tstdcxx -lsgx_tcxx -lsgx_tcrypto -lsgx_tkey_exchange -l$(tservice_library_name)\
-lenclave -ltvm_t\
-Wl,--end-group\
-Wl,-Bstatic -Wl,-Bsymbolic -Wl,--no-undefined\
-Wl,-pie,-eenclave_entry -Wl,--export-dynamic\
-Wl,--defsym,__ImageBase=0 -Wl,--gc-sections\
-Wl,--version-script=enclave/enclave.lds
.PHONY: enclave clean
enclave: $(build_dir)/enclave.signed.so
$(build_dir)/enclave.signed.so: $(build_dir)/enclave.so build/enclave_config.xml enclave/enclave.pem
$(sgx_enclave_signer) sign -key enclave/enclave.pem -enclave $< -out $@ -config build/enclave_config.xml
enclave/enclave.pem:
curl -sSo $@ 'https://gist.githubusercontent.com/nhynes/8a2d80068a92e672f8b0b7d710ceb404/raw/2d5ae5fbe83198ede49465fdc6535065e093543b/tvm_sgx_demo.pem'
build/enclave_config.xml: enclave/enclave_config.xml.in
cpp $^ -P -o $@ -DNUM_THREADS=$$(( $(NUM_THREADS) + 1 ))
$(build_dir)/enclave.so: $(build_dir)/libenclave.a $(TVM_DIR)/build/libtvm_t.a
$(CXX) $< -o $@ $(enclave_ldflags) $(enclave_cflags) -ltvm_t
$(build_dir)/libenclave.a: enclave/target/x86_64-unknown-linux-sgx/$(debug)/libmodel_enclave.a
@mkdir -p $(@D)
@cp $< $@
enclave/target/x86_64-unknown-linux-sgx/$(debug)/libmodel_enclave.a: enclave/**/*
$(MAKE) -C enclave
clean:
$(MAKE) -s -C enclave clean
rm -rf build
......@@ -15,64 +15,10 @@
<!--- specific language governing permissions and limitations -->
<!--- under the License. -->
# TVM in Intel SGX Example
## Setup
This application demonstrates the use of a simple TVM model in the [Intel SGX](https://software.intel.com/en-us/blogs/2013/09/26/protecting-application-secrets-with-intel-sgx) trusted computing environment.
## Prerequisites
1. The TVM premade Docker image
or
1. A GNU/Linux environment
2. TVM compiled with LLVM and SGX; and the `tvm` Python module
3. The [Linux SGX SDK](https://github.com/intel/linux-sgx) [link to pre-built libraries](https://01.org/intel-software-guard-extensions/downloads)
4. [Rust](https://rustup.sh)
5. The [rust-sgx-sdk](https://github.com/baidu/rust-sgx-sdk)
6. [xargo](https://github.com/japaric/xargo)
Check out the `/tvm/install/ubuntu_install_sgx.sh` for the commands to get these dependencies.
## Running the example
If using Docker, start by running
```
git clone --recursive https://github.com/apache/incubator-tvm.git tvm
docker run --rm -it -v $(pwd)/tvm:/mnt tvmai/ci-cpu /bin/bash
```
then, in the container
```
cd /mnt
mkdir build && cd build
cmake .. -DUSE_LLVM=ON -DUSE_SGX=/opt/sgxsdk -DRUST_SGX_SDK=/opt/rust-sgx-sdk
make -j4
cd ..
pip install -e python -e topi/python
cd apps/sgx
```
Once TVM is build and installed, just
`./run_example.sh`
If everything goes well, you should see a lot of build messages and below them
the text `It works!`.
## High-level overview
First of all, it helps to think of an SGX enclave as a library that can be called
to perform trusted computation.
In this library, one can use other libraries like TVM.
Building this example performs the following steps:
1. Creates a simple TVM module that computes `x + 1` and save it as a system library.
2. Builds a TVM runtime that links the module and allows running it using the TVM Python runtime.
3. Packages the bundle into an SGX enclave
4. Runs the enclave using the usual TVM Python `module` API
For more information on building, please refer to the `Makefile`.
For more information on the TVM module, please refer to `../howto_deploy`.
For more in formation on SGX enclaves, please refer to the [SGX Enclave Demo](https://github.com/intel/linux-sgx/tree/master/SampleCode/SampleEnclave/)
1. [Install the Fortanix Enclave Development Platform](https://edp.fortanix.com/docs/installation/guide/)
2. `rustup component add llvm-tools-preview` to get `llvm-ar` and `llvm-objcopy`
3. `pip install numpy decorator psutil`
4. `cargo run` to start the enclave TCP server
5. Send a 28x28 "image" to the enclave model server using `head -c $((28*28*4)) /dev/urandom | nc 127.0.0.1 4242 | python read_results.py`
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
use std::process::Command;
macro_rules! mf_dir {
($p:literal) => {
concat!(env!("CARGO_MANIFEST_DIR"), $p)
};
}
fn main() {
let out_dir = std::env::var("OUT_DIR").unwrap();
let build_output = Command::new(mf_dir!("/src/build_model.py"))
.arg(&out_dir)
.env(
"PYTHONPATH",
concat!(
mf_dir!("/../../python"),
":",
mf_dir!("/../../nnvm/python"),
":",
mf_dir!("/../../topi/python")
),
)
.output()
.expect("Failed to build model");
assert!(
["model.o", "graph.json", "params.bin"]
.iter()
.all(|f| { std::path::Path::new(&format!("{}/{}", out_dir, f)).exists() }),
"Could not build tvm lib: STDOUT:\n\n{}\n\nSTDERR\n\n{}",
String::from_utf8(build_output.stdout).unwrap().trim(),
String::from_utf8(build_output.stderr).unwrap().trim()
);
let sysroot_output = Command::new("rustc")
.args(&["--print", "sysroot"])
.output()
.expect("Failed to get sysroot");
let sysroot = String::from_utf8(sysroot_output.stdout).unwrap();
let sysroot = sysroot.trim();
let mut llvm_tools_path = std::path::PathBuf::from(&sysroot);
llvm_tools_path.push("lib/rustlib/x86_64-unknown-linux-gnu/bin");
Command::new("rustup")
.args(&["component", "add", "llvm-tools-preview"])
.output()
.expect("failed to install llvm tools");
std::process::Command::new(llvm_tools_path.join("llvm-objcopy"))
.arg("--globalize-symbol=__tvm_module_startup")
.arg("--remove-section=.ctors")
.arg(&format!("{}/model.o", out_dir))
.output()
.expect("gould not gloablize startup function");
std::process::Command::new(llvm_tools_path.join("llvm-ar"))
.arg("rcs")
.arg(&format!("{}/libmodel.a", out_dir))
.arg(&format!("{}/model.o", out_dir))
.output()
.expect("failed to package model archive");
println!("cargo:rustc-link-lib=static=model");
println!("cargo:rustc-link-search=native={}", out_dir);
}
../../../rust/.rustfmt.toml
\ No newline at end of file
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
MODEL ?= resnet
NUM_THREADS ?= 4
BATCH_SIZE ?= 64
TRAINING ?= true
DEBUG ?= false
build_dir := ../build
ifeq ($(DEBUG), false)
debug := release
xargo_args := --release
else
debug := debug
endif
target=target/x86_64-unknown-linux-sgx/$(debug)/libmodel-enclave.a
$(target): $(build_dir)/libmodel.a **/* $(TVM_DIR)/rust/patched.txt
RUST_TARGET_PATH=$(shell pwd) \
RUST_TARGET_DIR=$(shell pwd)/target \
RUSTFLAGS="-Z force-unstable-if-unmarked" \
TVM_NUM_THREADS=$(NUM_THREADS) \
BUILD_DIR=../build \
xargo build --target x86_64-unknown-linux-sgx $(xargo_args) -q
$(TVM_DIR)/rust/patched.txt: $(shell pwd)/sgx-deps.diff
echo $(TVM_DIR)
cd $(TVM_DIR) && git apply $<
touch $@
$(build_dir)/libmodel.a: $(build_dir)/model.o
$(AR) cr $@ $^
$(build_dir)/model.o: $(build_dir)/model.bc
$(CC) -c $< -o $@ -fPIC -O3
objcopy --globalize-symbol __tvm_module_startup $@
$(build_dir)/model.bc: src/build_model.py
python3 $< -o $(build_dir)
clean:
xargo clean
enclave.so
{
global:
g_global_data_sim;
g_global_data;
enclave_entry;
local:
*;
};
<EnclaveConfiguration>
<ProdID>0</ProdID>
<ISVSVN>0</ISVSVN>
<StackMaxSize>0xf0000</StackMaxSize>
<HeapMaxSize>0xf000000</HeapMaxSize>
<TCSNum>NUM_THREADS</TCSNum>
<TCSPolicy>0</TCSPolicy> <!-- must be "bound" to use thread_local -->
<DisableDebug>0</DisableDebug>
<MiscSelect>0</MiscSelect>
<MiscMask>0xFFFFFFFF</MiscMask>
</EnclaveConfiguration>
diff --git a/rust/Cargo.toml b/rust/Cargo.toml
index 0819e0c7..e56f4ef2 100644
--- a/rust/Cargo.toml
+++ b/rust/Cargo.toml
@@ -14,7 +14,7 @@ default = ["nom/std"]
sgx = ["nom/alloc"]
[dependencies]
-bounded-spsc-queue = "0.4.0"
+bounded-spsc-queue = { git = "https://github.com/nhynes/bounded-spsc-queue", branch = "sgx" }
error-chain = { version = "0.12.0", default-features = false }
itertools = "0.7.8"
lazy_static = "1.1.0"
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#![feature(try_from)]
#[macro_use]
extern crate lazy_static;
#[macro_use]
extern crate tvm;
use std::{
convert::{TryFrom, TryInto},
sync::Mutex,
};
use tvm::{
ffi::runtime::DLTensor,
runtime::{
load_param_dict, sgx, Graph, GraphExecutor, SystemLibModule, TVMArgValue, TVMRetValue, Tensor,
},
};
lazy_static! {
static ref SYSLIB: SystemLibModule = { SystemLibModule::default() };
static ref MODEL: Mutex<GraphExecutor<'static, 'static>> = {
let graph_json = include_str!(concat!("../", env!("BUILD_DIR"), "/graph.json"));
let params_bytes = include_bytes!(concat!("../", env!("BUILD_DIR"), "/params.bin"));
let params = load_param_dict(params_bytes).unwrap();
let graph = Graph::try_from(graph_json).unwrap();
let mut exec = GraphExecutor::new(graph, &*SYSLIB).unwrap();
exec.load_params(params);
Mutex::new(exec)
};
}
fn ecall_init(_args: &[TVMArgValue]) -> TVMRetValue {
lazy_static::initialize(&MODEL);
TVMRetValue::from(0)
}
fn ecall_main(args: &[TVMArgValue<'static>]) -> TVMRetValue {
let mut model = MODEL.lock().unwrap();
let inp = args[0].try_into().unwrap();
let mut out: Tensor = args[1].try_into().unwrap();
model.set_input("data", inp);
model.run();
sgx::shutdown();
out.copy(model.get_output(0).unwrap());
TVMRetValue::from(1)
}
pub mod ecalls {
//! todo: generate this using proc_macros
use super::*;
use std::{
ffi::CString,
mem,
os::raw::{c_char, c_int, c_void},
slice,
};
use tvm::{
ffi::runtime::{TVMRetValueHandle, TVMValue},
runtime::{
sgx::{ocall_packed_func, run_worker, SgxStatus},
DataType, PackedFunc,
},
};
macro_rules! tvm_ocall {
($func: expr) => {
match $func {
0 => Ok(()),
err => Err(err),
}
};
}
const ECALLS: &'static [&'static str] = &["__tvm_run_worker__", "__tvm_main__", "init"];
pub type EcallPackedFunc = Box<Fn(&[TVMArgValue<'static>]) -> TVMRetValue + Send + Sync>;
lazy_static! {
static ref ECALL_FUNCS: Vec<EcallPackedFunc> = {
vec![
Box::new(run_worker),
Box::new(ecall_main),
Box::new(ecall_init),
]
};
}
extern "C" {
fn __tvm_module_startup() -> ();
fn tvm_ocall_register_export(name: *const c_char, func_id: c_int) -> SgxStatus;
}
#[no_mangle]
pub extern "C" fn tvm_ecall_init(_ret: TVMRetValueHandle) {
unsafe {
__tvm_module_startup();
ECALLS.into_iter().enumerate().for_each(|(i, ecall)| {
tvm_ocall!(tvm_ocall_register_export(
CString::new(*ecall).unwrap().as_ptr(),
i as i32
))
.expect(&format!("Error registering `{}`", ecall));
});
}
}
#[no_mangle]
pub extern "C" fn tvm_ecall_packed_func(
func_id: c_int,
arg_values: *const TVMValue,
type_codes: *const c_int,
num_args: c_int,
ret_val: *mut TVMValue,
ret_type_code: *mut i64,
) {
let args = unsafe {
let values = slice::from_raw_parts(arg_values, num_args as usize);
let type_codes = slice::from_raw_parts(type_codes, num_args as usize);
values
.into_iter()
.zip(type_codes.into_iter())
.map(|(v, t)| TVMArgValue::new(*v, *t as i64))
.collect::<Vec<TVMArgValue<'static>>>()
};
let (rv, tc) = ECALL_FUNCS[func_id as usize](&args).into_tvm_value();
unsafe {
*ret_val = rv;
*ret_type_code = tc;
}
}
}
{
"arch": "x86_64",
"cpu": "x86-64",
"data-layout": "e-m:e-i64:64-f80:128-n8:16:32:64-S128",
"dynamic-linking": true,
"env": "sgx",
"exe-allocation-crate": "alloc_system",
"executables": true,
"has-elf-tls": true,
"has-rpath": true,
"linker-flavor": "gcc",
"linker-is-gnu": true,
"llvm-target": "x86_64-unknown-linux-gnu",
"max-atomic-width": 64,
"os": "linux",
"position-independent-executables": true,
"pre-link-args": {
"gcc": [
"-Wl,--as-needed",
"-Wl,-z,noexecstack",
"-m64"
]
},
"relro-level": "full",
"stack-probes": true,
"target-c-int-width": "32",
"target-endian": "little",
"target-family": "unix",
"target-pointer-width": "64",
"vendor": "unknown"
}
......@@ -15,16 +15,14 @@
# specific language governing permissions and limitations
# under the License.
[dependencies]
alloc = {}
panic_unwind = {}
panic_abort = {}
import struct
import sys
[dependencies.std]
path = "/opt/rust-sgx-sdk/xargo/sgx_tstd"
features = ["backtrace", "stdio", "untrusted_time"]
stage = 2
import numpy as np
[dependencies.xargo_sgx_rand]
path = "/opt/rust-sgx-sdk/xargo/sgx_rand"
stage = 3
def float_bytes(l):
for i in range(0, len(l), 4):
yield l[i:i + 4]
floats = [struct.unpack('f', f)[0] for f in float_bytes(sys.stdin.buffer.read())]
print(np.array(floats))
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import os.path as osp
import numpy as np
import tvm
from tvm import te
CWD = osp.abspath(osp.dirname(__file__))
def main():
ctx = tvm.context('cpu', 0)
model = tvm.runtime.load_module(osp.join(CWD, 'build', 'enclave.signed.so'))
inp = tvm.nd.array(np.ones((1, 3, 224, 224), dtype='float32'), ctx)
out = tvm.nd.array(np.empty((1, 1000), dtype='float32'), ctx)
model(inp, out)
if abs(out.asnumpy().sum() - 1) < 0.001:
print('It works!')
else:
print('It doesn\'t work!')
exit(1)
if __name__ == '__main__':
main()
#!/usr/bin/python3
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
......@@ -14,11 +16,12 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Creates a simple TVM modules."""
import argparse
import os
from os import path as osp
import sys
from tvm import relay
from tvm.relay import testing
......@@ -27,9 +30,8 @@ from tvm import te
def main():
parser = argparse.ArgumentParser()
parser.add_argument('-o', '--out-dir', default='.')
opts = parser.parse_args()
dshape = (1, 28, 28)
net, params = relay.testing.mlp.get_workload(batch_size=dshape[0], dtype='float32')
dshape = (1, 3, 224, 224)
net, params = relay.testing.resnet.get_workload(
......@@ -39,11 +41,11 @@ def main():
graph, lib, params = relay.build(
net, 'llvm --system-lib', params=params)
build_dir = osp.abspath(opts.out_dir)
build_dir = osp.abspath(sys.argv[1])
if not osp.isdir(build_dir):
os.makedirs(build_dir, exist_ok=True)
lib.save(osp.join(build_dir, 'model.bc'))
lib.save(osp.join(build_dir, 'model.o'))
with open(osp.join(build_dir, 'graph.json'), 'w') as f_graph_json:
f_graph_json.write(graph)
with open(osp.join(build_dir, 'params.bin'), 'wb') as f_params:
......
......@@ -17,12 +17,35 @@
* under the License.
*/
use std::env;
extern crate tvm_runtime;
use std::{
convert::TryFrom as _,
io::{Read as _, Write as _},
};
fn main() {
println!(
"cargo:rustc-link-search=native={}",
env::var("BUILD_DIR").unwrap()
);
println!("cargo:rustc-link-lib=static=model");
let syslib = tvm_runtime::SystemLibModule::default();
let graph_json = include_str!(concat!(env!("OUT_DIR"), "/graph.json"));
let params_bytes = include_bytes!(concat!(env!("OUT_DIR"), "/params.bin"));
let params = tvm_runtime::load_param_dict(params_bytes).unwrap();
let graph = tvm_runtime::Graph::try_from(graph_json).unwrap();
let mut exec = tvm_runtime::GraphExecutor::new(graph, &syslib).unwrap();
exec.load_params(params);
let listener = std::net::TcpListener::bind("127.0.0.1:4242").unwrap();
for stream in listener.incoming() {
let mut stream = stream.unwrap();
if let Err(_) =
stream.read_exact(exec.get_input("data").unwrap().data().view().as_mut_slice())
{
continue;
}
exec.run();
if let Err(_) = stream.write_all(exec.get_output(0).unwrap().data().as_slice()) {
continue;
}
}
}
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
if(NOT USE_SGX STREQUAL "OFF")
set(_sgx_src ${CMAKE_CURRENT_SOURCE_DIR}/src/runtime/sgx)
set(_tvm_u_h ${_sgx_src}/untrusted/tvm_u.h)
set(_tvm_t_h ${_sgx_src}/trusted/tvm_t.h)
set(_tvm_t_c ${_sgx_src}/trusted/tvm_t.c)
set(_tvm_edl ${_sgx_src}/tvm.edl)
set(_sgx_ustdc ${RUST_SGX_SDK}/sgx_ustdc)
set(_urts_lib "sgx_urts")
if(NOT SGX_MODE STREQUAL "HW")
message(STATUS "Build with SGX support (SIM)")
set(_urts_lib "${_urts_lib}_sim")
else()
message(STATUS "Build with SGX support (HW)")
endif()
# build edge routines
add_custom_command(
OUTPUT ${_tvm_u_h}
COMMAND ${USE_SGX}/bin/x64/sgx_edger8r --untrusted
--untrusted --untrusted-dir ${_sgx_src}/untrusted
--trusted --trusted-dir ${_sgx_src}/trusted
--search-path ${USE_SGX}/include --search-path ${RUST_SGX_SDK}/edl
${_tvm_edl}
COMMAND sed -i "4i '#include <tvm/runtime/c_runtime_api.h>'" ${_tvm_u_h}
COMMAND sed -i "4i '#include <tvm/runtime/c_runtime_api.h>'" ${_tvm_t_h}
DEPENDS ${_tvm_edl}
)
add_custom_command(
OUTPUT ${_sgx_ustdc}/libsgx_ustdc.a
COMMAND make
WORKING_DIRECTORY ${_sgx_ustdc}
)
add_custom_target(sgx_edl DEPENDS ${_tvm_u_h} ${_sgx_ustdc}/libsgx_ustdc.a)
# build trusted library
set_source_files_properties(${_tvm_t_c} PROPERTIES GENERATED TRUE)
add_library(tvm_t STATIC ${_tvm_t_c})
add_dependencies(tvm_t sgx_edl)
target_include_directories(tvm_t PUBLIC ${USE_SGX}/include ${USE_SGX}/include/tlibc)
# add untrusted runtime files
include_directories(${USE_SGX}/include)
file(GLOB RUNTIME_SGX_SRCS ${_sgx_src}/untrusted/*.c*)
list(APPEND TVM_RUNTIME_LINKER_LIBS
-lpthread
-L${USE_SGX}/lib64 -l${_urts_lib}
-L${RUST_SGX_SDK}/sgx_ustdc -lsgx_ustdc)
list(APPEND RUNTIME_SRCS ${RUNTIME_SGX_SRCS})
include_directories(${RUST_SGX_SDK}/edl ${RUST_SGX_SDK}/common)
endif()
......@@ -31,8 +31,6 @@ echo set\(USE_RPC ON\) >> config.cmake
echo set\(USE_SORT ON\) >> config.cmake
echo set\(USE_GRAPH_RUNTIME ON\) >> config.cmake
echo set\(USE_BLAS openblas\) >> config.cmake
echo set\(USE_SGX /opt/sgxsdk\) >> config.cmake
echo set\(RUST_SGX_SDK /opt/rust-sgx-sdk\) >> config.cmake
mkdir -p build
cd build
cmake ..
......
#!/bin/bash
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
set -e
set -u
set -o pipefail
apt-get update && apt-get install -y --no-install-recommends \
build-essential git cmake \
wget python pkg-config software-properties-common \
autoconf automake libtool ocaml \
protobuf-compiler libprotobuf-dev \
libssl-dev libcurl4-openssl-dev curl
git clone --branch=sgx_2.2 --depth=1 https://github.com/intel/linux-sgx.git
cd linux-sgx
curl -s -S -L 'https://gist.githubusercontent.com/nhynes/c770b0e91610f8c020a8d1a803a1e7cb/raw/8f5372d9cb88929b3cc49a384943bb363bc06827/intel-sgx.patch' | git apply
./download_prebuilt.sh
make -j4 sdk && make -j4 sdk_install_pkg
./linux/installer/bin/sgx_linux_x64_sdk*.bin --prefix /opt
cd -
git clone --branch=v1.0.5 --depth=1 https://github.com/baidu/rust-sgx-sdk.git /opt/rust-sgx-sdk
cd /opt/rust-sgx-sdk
curl -s -S -L 'https://gist.githubusercontent.com/nhynes/37164039c5d3f33aa4f123e4ba720036/raw/b0de575fe937231799930764e76c664b92975163/rust-sgx-sdk.diff' | git apply
cd -
......@@ -221,7 +221,6 @@ inline const char* DeviceName(int type) {
}
}
#ifndef _LIBCPP_SGX_NO_IOSTREAMS
inline std::ostream& operator<<(std::ostream& os, DLContext ctx) { // NOLINT(*)
int device_type = static_cast<int>(ctx.device_type);
if (device_type > kRPCSessMask) {
......@@ -231,8 +230,6 @@ inline std::ostream& operator<<(std::ostream& os, DLContext ctx) { // NOLINT(*)
os << runtime::DeviceName(device_type) << "(" << ctx.device_id << ")";
return os;
}
#endif
} // namespace runtime
} // namespace tvm
#endif // TVM_RUNTIME_DEVICE_API_H_
......@@ -24,9 +24,6 @@
#ifndef TVM_RUNTIME_PACKED_FUNC_H_
#define TVM_RUNTIME_PACKED_FUNC_H_
#ifndef _LIBCPP_SGX_NO_IOSTREAMS
#include <sstream>
#endif
#include <dmlc/logging.h>
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/module.h>
......@@ -1019,7 +1016,6 @@ inline const char* TypeCode2Str(int type_code) {
}
}
#ifndef _LIBCPP_SGX_NO_IOSTREAMS
inline std::ostream& operator<<(std::ostream& os, DLDataType t) { // NOLINT(*)
if (t.bits == 1 && t.lanes == 1 && t.code == kDLUInt) {
os << "bool"; return os;
......@@ -1041,30 +1037,11 @@ inline std::ostream& operator<<(std::ostream& os, const DataType& dtype) { // NO
return os << dtype.operator DLDataType();
}
#endif
inline std::string DLDataType2String(DLDataType t) {
if (t.bits == 0) return "";
#ifndef _LIBCPP_SGX_NO_IOSTREAMS
std::ostringstream os;
os << t;
return os.str();
#else
if (t.bits == 1 && t.lanes == 1 && t.code == kDLUInt) {
return "bool";
}
if (t.code < kTVMCustomBegin) {
repr += TypeCode2Str(t.code);
} else {
repr += "custom[" + GetCustomTypeName(t.code) + "]";
}
if (t.code == kTVMOpaqueHandle) return repr;
repr += std::to_string(static_cast<int>(t.bits));
if (t.lanes != 1) {
repr += "x" + std::to_string(static_cast<int>(t.lanes));
}
return repr;
#endif
}
inline DLDataType String2DLDataType(std::string s) {
......
......@@ -22,8 +22,10 @@ newline_style = "Auto"
use_small_heuristics = "Default"
indent_style = "Block"
wrap_comments = false
format_code_in_doc_comments = false
comment_width = 80
normalize_comments = false
normalize_doc_attributes = false
format_strings = false
format_macro_matchers = false
format_macro_bodies = true
......@@ -44,10 +46,12 @@ spaces_around_ranges = false
binop_separator = "Front"
remove_nested_parens = true
combine_control_expr = true
overflow_delimited_expr = false
struct_field_align_threshold = 0
enum_discrim_align_threshold = 0
match_arm_blocks = true
force_multiline_blocks = false
fn_args_density = "Tall"
fn_args_layout = "Tall"
brace_style = "SameLineWhere"
control_brace_style = "AlwaysSameLine"
trailing_semicolon = true
......@@ -56,8 +60,10 @@ match_block_trailing_comma = false
blank_lines_upper_bound = 1
blank_lines_lower_bound = 0
edition = "2018"
version = "One"
inline_attribute_width = 0
merge_derives = true
use_try_shorthand = true
use_try_shorthand = false
use_field_init_shorthand = false
force_explicit_abi = true
condense_wildcard_suffixes = false
......@@ -66,8 +72,8 @@ unstable_features = false
disable_all_formatting = false
skip_children = false
hide_parse_errors = false
error_on_line_overflow = true
error_on_unformatted = true
error_on_line_overflow = false
error_on_unformatted = false
report_todo = "Never"
report_fixme = "Never"
ignore = []
......
......@@ -19,6 +19,7 @@
members = [
"common",
"macros",
"macros_raw",
"runtime",
"runtime/tests/test_tvm_basic",
"runtime/tests/test_tvm_dso",
......
......@@ -26,8 +26,8 @@ edition = "2018"
bindings = []
[dependencies]
failure = "0.1.5"
ndarray = "0.12.1"
failure = { version = "0.1", default-features = false, features = ["derive"] }
ndarray = "0.12"
[build-dependencies]
bindgen = "0.37.4"
bindgen = "0.51"
......@@ -23,10 +23,10 @@ use std::path::PathBuf;
fn main() {
let tvm_home = option_env!("TVM_HOME").map(str::to_string).unwrap_or({
let tvm_home = PathBuf::from(env!("CARGO_MANIFEST_DIR"))
let crate_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR"))
.canonicalize()
.unwrap();
tvm_home
crate_dir
.parent()
.unwrap()
.parent()
......@@ -46,6 +46,7 @@ fn main() {
.header(format!("{}/include/tvm/runtime/c_runtime_api.h", tvm_home))
.header(format!("{}/include/tvm/runtime/c_backend_api.h", tvm_home))
.clang_arg(format!("-I{}/3rdparty/dlpack/include/", tvm_home))
.clang_arg(format!("-I{}/include/", tvm_home))
.blacklist_type("max_align_t")
.layout_tests(false)
.derive_partialeq(true)
......
......@@ -20,8 +20,6 @@
//! This crate contains the refactored basic components required
//! for `runtime` and `frontend` TVM crates.
#![feature(box_syntax, trait_alias)]
#[macro_use]
extern crate failure;
......@@ -44,5 +42,5 @@ pub mod packed_func;
pub mod value;
pub use errors::*;
pub use ffi::{TVMByteArray, TVMContext, TVMType};
pub use ffi::{TVMByteArray, TVMContext, DLDataType as TVMType};
pub use packed_func::{TVMArgValue, TVMRetValue};
......@@ -26,8 +26,10 @@ use std::{
pub use crate::ffi::TVMValue;
use crate::{errors::ValueDowncastError, ffi::*};
pub trait PackedFunc =
Fn(&[TVMArgValue]) -> Result<TVMRetValue, crate::errors::FuncCallError> + Send + Sync;
pub trait PackedFunc : Fn(&[TVMArgValue]) -> Result<TVMRetValue, crate::errors::FuncCallError> + Send + Sync {}
impl<T> PackedFunc for T
where T : Fn(&[TVMArgValue]) -> Result<TVMRetValue, crate::errors::FuncCallError> + Send + Sync {}
/// Calls a packed function and returns a `TVMRetValue`.
///
......@@ -66,7 +68,7 @@ macro_rules! TVMPODValue {
UInt(i64),
Float(f64),
Null,
Type(TVMType),
DataType(DLDataType),
String(CString),
Context(TVMContext),
Handle(*mut c_void),
......@@ -87,15 +89,15 @@ macro_rules! TVMPODValue {
DLDataTypeCode_kDLInt => Int($value.v_int64),
DLDataTypeCode_kDLUInt => UInt($value.v_int64),
DLDataTypeCode_kDLFloat => Float($value.v_float64),
TVMTypeCode_kNull => Null,
TVMTypeCode_kTVMType => Type($value.v_type),
TVMTypeCode_kTVMNullptr => Null,
TVMTypeCode_kTVMDataType => DataType($value.v_type),
TVMTypeCode_kTVMContext => Context($value.v_ctx),
TVMTypeCode_kHandle => Handle($value.v_handle),
TVMTypeCode_kArrayHandle => ArrayHandle($value.v_handle as TVMArrayHandle),
TVMTypeCode_kObjectHandle => ObjectHandle($value.v_handle),
TVMTypeCode_kModuleHandle => ModuleHandle($value.v_handle),
TVMTypeCode_kFuncHandle => FuncHandle($value.v_handle),
TVMTypeCode_kNDArrayContainer => NDArrayContainer($value.v_handle),
TVMTypeCode_kTVMOpaqueHandle => Handle($value.v_handle),
TVMTypeCode_kTVMDLTensorHandle => ArrayHandle($value.v_handle as TVMArrayHandle),
TVMTypeCode_kTVMObjectHandle => ObjectHandle($value.v_handle),
TVMTypeCode_kTVMModuleHandle => ModuleHandle($value.v_handle),
TVMTypeCode_kTVMPackedFuncHandle => FuncHandle($value.v_handle),
TVMTypeCode_kTVMNDArrayHandle => NDArrayContainer($value.v_handle),
$( $tvm_type => { $from_tvm_type } ),+
_ => unimplemented!("{}", type_code),
}
......@@ -108,31 +110,31 @@ macro_rules! TVMPODValue {
Int(val) => (TVMValue { v_int64: *val }, DLDataTypeCode_kDLInt),
UInt(val) => (TVMValue { v_int64: *val as i64 }, DLDataTypeCode_kDLUInt),
Float(val) => (TVMValue { v_float64: *val }, DLDataTypeCode_kDLFloat),
Null => (TVMValue{ v_int64: 0 },TVMTypeCode_kNull),
Type(val) => (TVMValue { v_type: *val }, TVMTypeCode_kTVMType),
Null => (TVMValue{ v_int64: 0 },TVMTypeCode_kTVMNullptr),
DataType(val) => (TVMValue { v_type: *val }, TVMTypeCode_kTVMDataType),
Context(val) => (TVMValue { v_ctx: val.clone() }, TVMTypeCode_kTVMContext),
String(val) => {
(
TVMValue { v_handle: val.as_ptr() as *mut c_void },
TVMTypeCode_kStr,
TVMTypeCode_kTVMStr,
)
}
Handle(val) => (TVMValue { v_handle: *val }, TVMTypeCode_kHandle),
Handle(val) => (TVMValue { v_handle: *val }, TVMTypeCode_kTVMOpaqueHandle),
ArrayHandle(val) => {
(
TVMValue { v_handle: *val as *const _ as *mut c_void },
TVMTypeCode_kArrayHandle,
TVMTypeCode_kTVMNDArrayHandle,
)
},
ObjectHandle(val) => (TVMValue { v_handle: *val }, TVMTypeCode_kObjectHandle),
ObjectHandle(val) => (TVMValue { v_handle: *val }, TVMTypeCode_kTVMObjectHandle),
ModuleHandle(val) =>
(TVMValue { v_handle: *val }, TVMTypeCode_kModuleHandle),
(TVMValue { v_handle: *val }, TVMTypeCode_kTVMModuleHandle),
FuncHandle(val) => (
TVMValue { v_handle: *val },
TVMTypeCode_kFuncHandle
TVMTypeCode_kTVMPackedFuncHandle
),
NDArrayContainer(val) =>
(TVMValue { v_handle: *val }, TVMTypeCode_kNDArrayContainer),
(TVMValue { v_handle: *val }, TVMTypeCode_kTVMNDArrayHandle),
$( $self_type($val) => { $from_self_type } ),+
}
}
......@@ -148,14 +150,14 @@ TVMPODValue! {
Str(&'a CStr),
},
match value {
TVMTypeCode_kBytes => { Bytes(&*(value.v_handle as *const TVMByteArray)) }
TVMTypeCode_kStr => { Str(CStr::from_ptr(value.v_handle as *const i8)) }
TVMTypeCode_kTVMBytes => { Bytes(&*(value.v_handle as *const TVMByteArray)) }
TVMTypeCode_kTVMStr => { Str(CStr::from_ptr(value.v_handle as *const i8)) }
},
match &self {
Bytes(val) => {
(TVMValue { v_handle: val.clone() as *const _ as *mut c_void }, TVMTypeCode_kBytes)
(TVMValue { v_handle: val as *const _ as *mut c_void }, TVMTypeCode_kTVMBytes)
}
Str(val) => { (TVMValue { v_handle: val.as_ptr() as *mut c_void }, TVMTypeCode_kStr) }
Str(val) => { (TVMValue { v_handle: val.as_ptr() as *mut c_void }, TVMTypeCode_kTVMStr) }
}
}
......@@ -166,11 +168,14 @@ TVMPODValue! {
/// # Example
///
/// ```
/// use std::convert::{TryFrom, TryInto};
/// use tvm_common::TVMRetValue;
///
/// let a = 42u32;
/// let b: i64 = TVMRetValue::from(a).try_into().unwrap();
/// let b: u32 = tvm_common::TVMRetValue::from(a).try_into().unwrap();
///
/// let s = "hello, world!";
/// let t: TVMRetValue = s.into();
/// let t: TVMRetValue = s.to_string().into();
/// assert_eq!(String::try_from(t).unwrap(), s);
/// ```
TVMRetValue {
......@@ -178,14 +183,14 @@ TVMPODValue! {
Str(&'static CStr),
},
match value {
TVMTypeCode_kBytes => { Bytes(*(value.v_handle as *const TVMByteArray)) }
TVMTypeCode_kStr => { Str(CStr::from_ptr(value.v_handle as *mut i8)) }
TVMTypeCode_kTVMBytes => { Bytes(*(value.v_handle as *const TVMByteArray)) }
TVMTypeCode_kTVMStr => { Str(CStr::from_ptr(value.v_handle as *mut i8)) }
},
match &self {
Bytes(val) =>
{ (TVMValue { v_handle: val as *const _ as *mut c_void }, TVMTypeCode_kBytes ) }
{ (TVMValue { v_handle: val as *const _ as *mut c_void }, TVMTypeCode_kTVMBytes ) }
Str(val) =>
{ (TVMValue { v_str: val.as_ptr() }, TVMTypeCode_kStr ) }
{ (TVMValue { v_str: val.as_ptr() }, TVMTypeCode_kTVMStr ) }
}
}
......@@ -251,7 +256,7 @@ macro_rules! impl_pod_value {
impl_pod_value!(Int, i64, [i8, i16, i32, i64, isize]);
impl_pod_value!(UInt, i64, [u8, u16, u32, u64, usize]);
impl_pod_value!(Float, f64, [f32, f64]);
impl_pod_value!(Type, TVMType, [TVMType]);
impl_pod_value!(DataType, DLDataType, [DLDataType]);
impl_pod_value!(Context, TVMContext, [TVMContext]);
impl<'a> From<&'a str> for TVMArgValue<'a> {
......
......@@ -19,11 +19,9 @@
use std::{os::raw::c_char, str::FromStr};
use failure::Error;
use crate::ffi::*;
impl TVMType {
impl DLDataType {
fn new(type_code: u8, bits: u8, lanes: u16) -> Self {
Self {
code: type_code,
......@@ -33,25 +31,37 @@ impl TVMType {
}
}
#[derive(Debug, Fail)]
pub enum ParseTvmTypeError {
#[fail(display = "invalid number: {}", _0)]
InvalidNumber(std::num::ParseIntError),
#[fail(display = "unknown type: {}", _0)]
UnknownType(String),
}
/// Implements TVMType conversion from `&str` of general format `{dtype}{bits}x{lanes}`
/// such as "int32", "float32" or with lane "float32x1".
impl FromStr for TVMType {
type Err = Error;
impl FromStr for DLDataType {
type Err = ParseTvmTypeError;
fn from_str(type_str: &str) -> Result<Self, Self::Err> {
if type_str == "bool" {
return Ok(TVMType::new(1, 1, 1));
return Ok(DLDataType::new(1, 1, 1));
}
let mut type_lanes = type_str.split("x");
let mut type_lanes = type_str.split('x');
let typ = type_lanes.next().expect("Missing dtype");
let lanes = type_lanes
.next()
.map(|l| <u16>::from_str_radix(l, 10))
.unwrap_or(Ok(1))?;
.unwrap_or(Ok(1))
.map_err(ParseTvmTypeError::InvalidNumber)?;
let (type_name, bits) = match typ.find(char::is_numeric) {
Some(idx) => {
let (name, bits_str) = typ.split_at(idx);
(name, u8::from_str_radix(bits_str, 10)?)
(
name,
u8::from_str_radix(bits_str, 10).map_err(ParseTvmTypeError::InvalidNumber)?,
)
}
None => (typ, 32),
};
......@@ -61,14 +71,14 @@ impl FromStr for TVMType {
"uint" => 1,
"float" => 2,
"handle" => 3,
_ => return Err(format_err!("Unknown type {}", type_name)),
_ => return Err(ParseTvmTypeError::UnknownType(type_name.to_string())),
};
Ok(TVMType::new(type_code, bits, lanes))
Ok(DLDataType::new(type_code, bits, lanes))
}
}
impl std::fmt::Display for TVMType {
impl std::fmt::Display for DLDataType {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
if self.bits == 1 && self.lanes == 1 {
return write!(f, "bool");
......@@ -113,19 +123,23 @@ macro_rules! impl_pod_tvm_value {
impl_pod_tvm_value!(v_int64, i64, i8, u8, i16, u16, i32, u32, i64, u64, isize, usize);
impl_pod_tvm_value!(v_float64, f64, f32, f64);
impl_pod_tvm_value!(v_type, TVMType);
impl_pod_tvm_value!(v_type, DLDataType);
impl_pod_tvm_value!(v_ctx, TVMContext);
#[derive(Debug, Fail)]
#[fail(display = "unsupported device: {}", _0)]
pub struct UnsupportedDeviceError(String);
macro_rules! impl_tvm_context {
( $( $dev_type:ident : [ $( $dev_name:ident ),+ ] ),+ ) => {
/// Creates a TVMContext from a string (e.g., "cpu", "gpu", "ext_dev")
impl FromStr for TVMContext {
type Err = Error;
type Err = UnsupportedDeviceError;
fn from_str(type_str: &str) -> Result<Self, Self::Err> {
Ok(Self {
device_type: match type_str {
$( $( stringify!($dev_name) )|+ => $dev_type ),+,
_ => return Err(format_err!("device {} not supported", type_str).into()),
_ => return Err(UnsupportedDeviceError(type_str.to_string())),
},
device_id: 0,
})
......@@ -163,7 +177,7 @@ impl_tvm_context!(
///
/// ```
/// let v = b"hello";
/// let barr = TVMByteArray::from(&v);
/// let barr = tvm_common::TVMByteArray::from(&v);
/// assert_eq!(barr.len(), v.len());
/// assert_eq!(barr.data(), &[104u8, 101, 108, 108, 111]);
/// ```
......@@ -182,6 +196,10 @@ impl TVMByteArray {
pub fn to_vec(&self) -> Vec<u8> {
self.data().to_vec()
}
pub fn is_empty(&self) -> bool {
self.len() == 0
}
}
// Needs AsRef for Vec
......
......@@ -28,16 +28,12 @@ categories = ["api-bindings", "science"]
authors = ["TVM Contributors"]
edition = "2018"
[lib]
name = "tvm_frontend"
crate-type = ["dylib"]
[dependencies]
failure = "0.1.5"
lazy_static = "1.1.0"
ndarray = "0.12.1"
failure = "0.1"
lazy_static = "1.1"
ndarray = "0.12"
num-traits = "0.2"
tvm-common = { version = "0.1.0", path = "../common/", features = ["bindings"] }
tvm-common = { version = "0.1", path = "../common/", features = ["bindings"] }
[features]
blas = ["ndarray/blas"]
......@@ -23,7 +23,7 @@ license = "Apache-2.0"
build = "build.rs"
[dependencies]
ndarray = "0.12.1"
ndarray = "0.12"
tvm-frontend = { path = "../../" }
image = "0.20.1"
csv = "1"
image = "0.20"
csv = "1.1"
......@@ -65,7 +65,7 @@ fn main() {
let input = NDArray::from_rust_ndarray(
&arr,
TVMContext::cpu(0),
TVMType::from_str("float32").unwrap(),
DLDataType::from_str("float32").unwrap(),
)
.unwrap();
println!(
......@@ -117,7 +117,7 @@ fn main() {
let output = NDArray::empty(
output_shape,
TVMContext::cpu(0),
TVMType::from_str("float32").unwrap(),
DLDataType::from_str("float32").unwrap(),
);
// get the `get_output` function from runtime module
let ref get_output_fn = graph_runtime_module
......
......@@ -28,7 +28,7 @@
use std::{
collections::BTreeMap,
ffi::{CStr, CString},
mem,
mem::{self, MaybeUninit},
os::raw::{c_char, c_int, c_void},
ptr, slice, str,
sync::Mutex,
......@@ -36,25 +36,20 @@ use std::{
use failure::Error;
use crate::{
errors,
ffi::{self, TVMValue},
Module, TVMArgValue, TVMRetValue,
};
use crate::{errors, ffi, Module, TVMArgValue, TVMRetValue};
lazy_static! {
static ref GLOBAL_FUNCTIONS: Mutex<BTreeMap<&'static str, Option<Function>>> = {
let mut out_size = 0 as c_int;
let name = ptr::null_mut() as *mut c_char;
let mut out_array = name as *mut _;
let mut names_ptr = ptr::null_mut() as *mut *const c_char;
check_call!(ffi::TVMFuncListGlobalNames(
&mut out_size as *mut _,
&mut out_array
&mut names_ptr as *mut _,
));
let names_list = unsafe { slice::from_raw_parts(out_array, out_size as usize) };
let names_list = unsafe { slice::from_raw_parts(names_ptr, out_size as usize) };
Mutex::new(
names_list
.into_iter()
.iter()
.map(|&p| (unsafe { CStr::from_ptr(p).to_str().unwrap() }, None))
.collect(),
)
......@@ -80,7 +75,7 @@ unsafe impl Sync for Function {}
impl Function {
pub(crate) fn new(handle: ffi::TVMFunctionHandle) -> Self {
Function {
handle: handle,
handle,
is_global: false,
is_cloned: false,
}
......@@ -98,15 +93,13 @@ impl Function {
&mut handle as *mut _
));
maybe_func.replace(Function {
handle: handle,
handle,
is_global: true,
is_cloned: false,
});
}
unsafe {
std::mem::transmute::<Option<&Function>, Option<&'static Function>>(
maybe_func.as_ref(),
)
mem::transmute::<Option<&Function>, Option<&'static Function>>(maybe_func.as_ref())
}
})
}
......@@ -214,7 +207,7 @@ impl<'a, 'm> Builder<'a, 'm> {
let (mut values, mut type_codes): (Vec<ffi::TVMValue>, Vec<ffi::TVMTypeCode>) =
self.arg_buf.iter().map(|arg| arg.to_tvm_value()).unzip();
let mut ret_val = unsafe { std::mem::uninitialized::<TVMValue>() };
let mut ret_val = unsafe { MaybeUninit::uninit().assume_init() };
let mut ret_type_code = 0i32;
check_call!(ffi::TVMFuncCall(
self.func.ok_or(errors::FunctionNotFoundError)?.handle,
......@@ -257,20 +250,20 @@ unsafe extern "C" fn tvm_callback(
let args_list = slice::from_raw_parts_mut(args, len);
let type_codes_list = slice::from_raw_parts_mut(type_codes, len);
let mut local_args: Vec<TVMArgValue> = Vec::new();
let mut value = mem::uninitialized::<ffi::TVMValue>();
let mut tcode = mem::uninitialized::<c_int>();
let mut value = MaybeUninit::uninit().assume_init();
let mut tcode = MaybeUninit::uninit().assume_init();
let rust_fn =
mem::transmute::<*mut c_void, fn(&[TVMArgValue]) -> Result<TVMRetValue, Error>>(fhandle);
for i in 0..len {
value = args_list[i];
tcode = type_codes_list[i];
if tcode == ffi::TVMTypeCode_kObjectHandle as c_int
|| tcode == ffi::TVMTypeCode_kFuncHandle as c_int
|| tcode == ffi::TVMTypeCode_kModuleHandle as c_int
if tcode == ffi::TVMTypeCode_kTVMObjectHandle as c_int
|| tcode == ffi::TVMTypeCode_kTVMPackedFuncHandle as c_int
|| tcode == ffi::TVMTypeCode_kTVMModuleHandle as c_int
{
check_call!(ffi::TVMCbArgToReturn(&mut value as *mut _, tcode));
}
local_args.push(TVMArgValue::from_tvm_value(value.into(), tcode as u32));
local_args.push(TVMArgValue::from_tvm_value(value, tcode as u32));
}
let rv = match rust_fn(local_args.as_slice()) {
......@@ -293,9 +286,9 @@ unsafe extern "C" fn tvm_callback(
}
unsafe extern "C" fn tvm_callback_finalizer(fhandle: *mut c_void) {
let rust_fn =
let _rust_fn =
mem::transmute::<*mut c_void, fn(&[TVMArgValue]) -> Result<TVMRetValue, Error>>(fhandle);
mem::drop(rust_fn);
// XXX: give converted functions lifetimes so they're not called after use
}
fn convert_to_tvm_func(f: fn(&[TVMArgValue]) -> Result<TVMRetValue, Error>) -> Function {
......
......@@ -30,8 +30,6 @@
//!
//! Checkout the `examples` repository for more details.
#![feature(box_syntax)]
#[macro_use]
extern crate failure;
#[macro_use]
......@@ -55,7 +53,7 @@ pub use crate::{
ndarray::NDArray,
tvm_common::{
errors as common_errors,
ffi::{self, TVMByteArray, TVMType},
ffi::{self, TVMByteArray, DLDataType},
packed_func::{TVMArgValue, TVMRetValue},
},
};
......
......@@ -32,7 +32,7 @@ use tvm_common::ffi;
use crate::{errors, function::Function};
const ENTRY_FUNC: &'static str = "__tvm_main__";
const ENTRY_FUNC: &str = "__tvm_main__";
/// Wrapper around TVM module handle which contains an entry function.
/// The entry function can be applied to an imported module through [`entry_func`].
......@@ -72,7 +72,7 @@ impl Module {
ensure!(
!fhandle.is_null(),
errors::NullHandleError {
name: format!("{}", name.into_string()?)
name: name.into_string()?.to_string()
}
);
Ok(Function::new(fhandle))
......@@ -88,7 +88,7 @@ impl Module {
let ext = CString::new(
path.as_ref()
.extension()
.unwrap_or(std::ffi::OsStr::new(""))
.unwrap_or_else(|| std::ffi::OsStr::new(""))
.to_str()
.ok_or_else(|| {
format_err!("Bad module load path: `{}`.", path.as_ref().display())
......
......@@ -63,7 +63,7 @@ pub struct NDArray {
impl NDArray {
pub(crate) fn new(handle: ffi::TVMArrayHandle) -> Self {
NDArray {
handle: handle,
handle,
is_view: true,
}
}
......@@ -89,8 +89,7 @@ impl NDArray {
/// Returns the total number of entries of the NDArray.
pub fn size(&self) -> Option<usize> {
self.shape()
.map(|v| v.into_iter().fold(1, |acc, &mut e| acc * e))
self.shape().map(|v| v.iter().product())
}
/// Returns the context which the NDArray was defined.
......@@ -100,7 +99,7 @@ impl NDArray {
/// Returns the type of the entries of the NDArray.
pub fn dtype(&self) -> TVMType {
unsafe { (*self.handle).dtype.into() }
unsafe { (*self.handle).dtype }
}
/// Returns the number of dimensions of the NDArray.
......@@ -211,8 +210,8 @@ impl NDArray {
bail!(
"{}",
errors::TypeMismatchError {
expected: format!("{}", self.dtype().to_string()),
actual: format!("{}", target.dtype().to_string()),
expected: self.dtype().to_string(),
actual: target.dtype().to_string(),
}
);
}
......@@ -228,7 +227,7 @@ impl NDArray {
pub fn copy_to_ctx(&self, target: &TVMContext) -> Result<NDArray, Error> {
let tmp = NDArray::empty(
self.shape().ok_or(errors::MissingShapeError)?,
target.clone(),
*target,
self.dtype(),
);
let copy = self.copy_to_ndarray(tmp)?;
......@@ -241,8 +240,8 @@ impl NDArray {
ctx: TVMContext,
dtype: TVMType,
) -> Result<Self, Error> {
let mut shape = rnd.shape().to_vec();
let mut nd = NDArray::empty(&mut shape, ctx, dtype);
let shape = rnd.shape().to_vec();
let mut nd = NDArray::empty(&shape, ctx, dtype);
let mut buf = Array::from_iter(rnd.into_iter().map(|&v| v as T));
nd.copy_from_buffer(
buf.as_slice_mut()
......@@ -257,9 +256,9 @@ impl NDArray {
check_call!(ffi::TVMArrayAlloc(
shape.as_ptr() as *const i64,
shape.len() as c_int,
dtype.code as c_int,
dtype.bits as c_int,
dtype.lanes as c_int,
i32::from(dtype.code) as c_int,
i32::from(dtype.bits) as c_int,
i32::from(dtype.lanes) as c_int,
ctx.device_type.0 as c_int,
ctx.device_id as c_int,
&mut handle as *mut _,
......@@ -364,9 +363,9 @@ mod tests {
assert_eq!(ndarray.ndim(), 1);
assert!(ndarray.is_contiguous().is_ok());
assert_eq!(ndarray.byte_offset(), 0);
let mut shape = vec![4];
let shape = vec![4];
let e = NDArray::empty(
&mut shape,
&shape,
TVMContext::cpu(0),
TVMType::from_str("int32").unwrap(),
);
......@@ -378,16 +377,12 @@ mod tests {
#[test]
#[should_panic(expected = "called `Result::unwrap()` on an `Err`")]
fn copy_wrong_dtype() {
let mut shape = vec![4];
let shape = vec![4];
let mut data = vec![1f32, 2., 3., 4.];
let ctx = TVMContext::cpu(0);
let mut nd_float = NDArray::empty(
&mut shape,
ctx.clone(),
TVMType::from_str("float32").unwrap(),
);
let mut nd_float = NDArray::empty(&shape, ctx, TVMType::from_str("float32").unwrap());
nd_float.copy_from_buffer(&mut data);
let empty_int = NDArray::empty(&mut shape, ctx, TVMType::from_str("int32").unwrap());
let empty_int = NDArray::empty(&shape, ctx, TVMType::from_str("int32").unwrap());
nd_float.copy_to_ndarray(empty_int).unwrap();
}
......
......@@ -93,7 +93,7 @@ mod tests {
let tvm: TVMByteArray = TVMRetValue::from(v).try_into().unwrap();
assert_eq!(
tvm.data(),
w.iter().map(|e| *e).collect::<Vec<u8>>().as_slice()
w.iter().copied().collect::<Vec<u8>>().as_slice()
);
}
......
......@@ -23,7 +23,7 @@ license = "Apache-2.0"
build = "build.rs"
[dependencies]
ndarray = "0.12.1"
ndarray = "0.12"
tvm-frontend = { path = "../../" }
[features]
......
......@@ -33,7 +33,7 @@ fn main() {
} else {
(TVMContext::gpu(0), "gpu")
};
let dtype = TVMType::from_str("float32").unwrap();
let dtype = DLDataType::from_str("float32").unwrap();
let mut arr = NDArray::empty(shape, ctx, dtype);
arr.copy_from_buffer(data.as_mut_slice());
let mut ret = NDArray::empty(shape, ctx, dtype);
......
......@@ -21,5 +21,5 @@ version = "0.0.0"
authors = ["TVM Contributors"]
[dependencies]
ndarray = "0.12.1"
ndarray = "0.12"
tvm-frontend = { path = "../../" }
......@@ -39,7 +39,7 @@ fn main() {
for arg in args.iter() {
let e = NDArray::empty(
shape, TVMContext::cpu(0),
TVMType::from_str("float32").unwrap()
DLDataType::from_str("float32").unwrap()
);
let arg: NDArray = arg.try_into()?;
let arr = arg.copy_to_ndarray(e)?;
......@@ -55,7 +55,7 @@ fn main() {
let mut arr = NDArray::empty(
shape,
TVMContext::cpu(0),
TVMType::from_str("float32").unwrap(),
DLDataType::from_str("float32").unwrap(),
);
arr.copy_from_buffer(data.as_mut_slice());
......
......@@ -17,9 +17,6 @@
* under the License.
*/
#![feature(panic_info_message)]
#![allow(unused_imports)]
use std::panic;
#[macro_use]
......@@ -44,9 +41,9 @@ fn main() {
println!("expected error message is:");
panic::set_hook(Box::new(|panic_info| {
if let Some(msg) = panic_info.message() {
println!("{:?}", msg);
}
// if let Some(msg) = panic_info.message() {
// println!("{:?}", msg);
// }
if let Some(location) = panic_info.location() {
println!(
"panic occurred in file '{}' at line {}",
......
......@@ -17,7 +17,7 @@
[package]
name = "tvm-macros"
version = "0.1.0"
version = "0.1.1"
license = "Apache-2.0"
description = "Proc macros used by the TVM crates."
repository = "https://github.com/apache/incubator-tvm"
......@@ -26,11 +26,6 @@ keywords = ["tvm"]
authors = ["TVM Contributors"]
edition = "2018"
[lib]
proc-macro = true
[dependencies]
goblin = "0.0.22"
proc-macro2 = "0.4"
proc-quote = "0.2"
syn = "0.15"
tvm-macros-raw = { path = "../macros_raw" }
......@@ -17,106 +17,12 @@
* under the License.
*/
#![feature(proc_macro_span)]
#[macro_use]
extern crate tvm_macros_raw;
extern crate proc_macro;
use std::{fs::File, io::Read};
use proc_quote::quote;
#[proc_macro]
pub fn import_module(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
let obj_file_path = syn::parse_macro_input!(input as syn::LitStr);
let mut path = obj_file_path.span().unwrap().source_file().path();
path.pop(); // remove the filename
path.push(obj_file_path.value());
let mut fd = File::open(&path)
.unwrap_or_else(|_| panic!("Unable to find TVM object file at `{}`", path.display()));
let mut buffer = Vec::new();
fd.read_to_end(&mut buffer).unwrap();
let fn_names = match goblin::Object::parse(&buffer).unwrap() {
goblin::Object::Elf(elf) => elf
.syms
.iter()
.filter_map(|s| {
if s.st_type() == 0 || goblin::elf::sym::type_to_str(s.st_type()) == "FILE" {
return None;
}
match elf.strtab.get(s.st_name) {
Some(Ok(name)) if name != "" => {
Some(syn::Ident::new(name, proc_macro2::Span::call_site()))
}
_ => None,
}
})
.collect::<Vec<_>>(),
goblin::Object::Mach(goblin::mach::Mach::Binary(obj)) => {
obj.symbols()
.filter_map(|s| match s {
Ok((name, nlist))
if nlist.is_global()
&& nlist.n_sect != 0
&& !name.ends_with("tvm_module_ctx") =>
{
Some(syn::Ident::new(
if name.starts_with('_') {
// Mach objects prepend a _ to globals.
&name[1..]
} else {
&name
},
proc_macro2::Span::call_site(),
))
}
_ => None,
})
.collect::<Vec<_>>()
}
_ => panic!("Unsupported object format."),
#[macro_export]
macro_rules! import_module {
($module_path:literal) => {
$crate::import_module_raw!(file!(), $module_path);
};
let extern_fns = quote! {
mod ext {
extern "C" {
#(
pub(super) fn #fn_names(
args: *const tvm_runtime::ffi::TVMValue,
type_codes: *const std::os::raw::c_int,
num_args: std::os::raw::c_int
) -> std::os::raw::c_int;
)*
}
}
};
let fns = quote! {
use tvm_runtime::{ffi::TVMValue, TVMArgValue, TVMRetValue, FuncCallError};
#extern_fns
#(
pub fn #fn_names(args: &[TVMArgValue]) -> Result<TVMRetValue, FuncCallError> {
let (values, type_codes): (Vec<TVMValue>, Vec<i32>) = args
.into_iter()
.map(|arg| {
let (val, code) = arg.to_tvm_value();
(val, code as i32)
})
.unzip();
let exit_code = unsafe {
ext::#fn_names(values.as_ptr(), type_codes.as_ptr(), values.len() as i32)
};
if exit_code == 0 {
Ok(TVMRetValue::default())
} else {
Err(FuncCallError::get_with_context(stringify!(#fn_names).to_string()))
}
}
)*
};
proc_macro::TokenStream::from(fns)
}
#!/bin/bash
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
......@@ -6,9 +5,9 @@
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......@@ -16,11 +15,22 @@
# specific language governing permissions and limitations
# under the License.
sgx_sdk=${SGX_SDK:=/opt/sgxsdk}
[package]
name = "tvm-macros-raw"
version = "0.1.1"
license = "Apache-2.0"
description = "Proc macros used by the TVM crates."
repository = "https://github.com/apache/incubator-tvm"
readme = "README.md"
keywords = ["tvm"]
authors = ["TVM Contributors"]
edition = "2018"
export LD_LIBRARY_PATH="$sgx_sdk/lib64":${LD_LIBRARY_PATH}
export CC=clang-6.0
export AR=llvm-ar-6.0
export TVM_CACHE_DIR=/tmp
[lib]
proc-macro = true
make && printf "\n" && python3 run_model.py
[dependencies]
goblin = "0.0.24"
proc-macro2 = "^1.0"
quote = "1.0"
syn = "1.0"
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
extern crate proc_macro;
use std::{fs::File, io::Read};
use syn::parse::{Parse, ParseStream, Result};
use syn::{Token, LitStr};
use quote::quote;
use std::path::PathBuf;
struct ImportModule {
importing_file: LitStr,
module_path: LitStr,
}
impl Parse for ImportModule {
fn parse(input: ParseStream) -> Result<Self> {
let importing_file: LitStr = input.parse()?;
input.parse::<Token![,]>()?;
let module_path: LitStr = input.parse()?;
Ok(ImportModule {
importing_file,
module_path,
})
}
}
#[proc_macro]
pub fn import_module_raw(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
let import_module_args = syn::parse_macro_input!(input as ImportModule);
let mut path = PathBuf::new();
path = path.join(import_module_args.importing_file.value());
path.pop(); // remove the filename
path.push(import_module_args.module_path.value());
let mut fd = File::open(&path)
.unwrap_or_else(|_| panic!("Unable to find TVM object file at `{}`", path.display()));
let mut buffer = Vec::new();
fd.read_to_end(&mut buffer).unwrap();
let fn_names = match goblin::Object::parse(&buffer).unwrap() {
goblin::Object::Elf(elf) => elf
.syms
.iter()
.filter_map(|s| {
if s.st_type() == 0 || goblin::elf::sym::type_to_str(s.st_type()) == "FILE" {
return None;
}
match elf.strtab.get(s.st_name) {
Some(Ok(name)) if name != "" => {
Some(syn::Ident::new(name, proc_macro2::Span::call_site()))
}
_ => None,
}
})
.collect::<Vec<_>>(),
goblin::Object::Mach(goblin::mach::Mach::Binary(obj)) => {
obj.symbols()
.filter_map(|s| match s {
Ok((name, ref nlist))
if nlist.is_global()
&& nlist.n_sect != 0
&& !name.ends_with("tvm_module_ctx") =>
{
Some(syn::Ident::new(
if name.starts_with('_') {
// Mach objects prepend a _ to globals.
&name[1..]
} else {
&name
},
proc_macro2::Span::call_site(),
))
}
_ => None,
})
.collect::<Vec<_>>()
}
_ => panic!("Unsupported object format."),
};
let extern_fns = quote! {
mod ext {
extern "C" {
#(
pub(super) fn #fn_names(
args: *const tvm_runtime::ffi::TVMValue,
type_codes: *const std::os::raw::c_int,
num_args: std::os::raw::c_int
) -> std::os::raw::c_int;
)*
}
}
};
let fns = quote! {
use tvm_runtime::{ffi::TVMValue, TVMArgValue, TVMRetValue, FuncCallError};
#extern_fns
#(
pub fn #fn_names(args: &[TVMArgValue]) -> Result<TVMRetValue, FuncCallError> {
let (values, type_codes): (Vec<TVMValue>, Vec<i32>) = args
.into_iter()
.map(|arg| {
let (val, code) = arg.to_tvm_value();
(val, code as i32)
})
.unzip();
let exit_code = unsafe {
ext::#fn_names(values.as_ptr(), type_codes.as_ptr(), values.len() as i32)
};
if exit_code == 0 {
Ok(TVMRetValue::default())
} else {
Err(FuncCallError::get_with_context(stringify!(#fn_names).to_string()))
}
}
)*
};
proc_macro::TokenStream::from(fns)
}
......@@ -27,25 +27,19 @@ categories = ["api-bindings", "science"]
authors = ["TVM Contributors"]
edition = "2018"
[features]
default = ["nom/std"]
sgx = ["nom/alloc"]
[dependencies]
bounded-spsc-queue = "0.4.0"
failure = "0.1.5"
itertools = "0.7.8"
lazy_static = "1.1.0"
ndarray="0.12.1"
nom = {version = "4.0.0", default-features = false }
serde = "1.0.59"
serde_derive = "1.0.79"
serde_json = "1.0.17"
crossbeam = "0.7.3"
failure = "0.1"
itertools = "0.8"
lazy_static = "1.4"
ndarray="0.12"
nom = "5.0"
num_cpus = "1.10"
serde = "1.0"
serde_derive = "1.0"
serde_json = "1.0"
tvm-common = { version = "0.1", path = "../common" }
tvm-macros = { version = "0.1", path = "../macros" }
[target.'cfg(not(target_env = "sgx"))'.dependencies]
num_cpus = "1.8.0"
[target.'cfg(not(any(target_arch = "wasm32", target_env = "sgx")))'.dependencies]
libloading = "0.5"
......@@ -17,9 +17,6 @@
* under the License.
*/
#[cfg(target_env = "sgx")]
use alloc::alloc::{self, Layout, LayoutErr};
#[cfg(not(target_env = "sgx"))]
use std::alloc::{self, Layout, LayoutErr};
const DEFAULT_ALIGN_BYTES: usize = 4;
......@@ -35,14 +32,11 @@ impl Allocation {
pub fn new(size: usize, align: Option<usize>) -> Result<Self, LayoutErr> {
let alignment = align.unwrap_or(DEFAULT_ALIGN_BYTES);
let layout = Layout::from_size_align(size, alignment)?;
let ptr = unsafe { alloc::alloc(layout.clone()) };
let ptr = unsafe { alloc::alloc(layout) };
if ptr.is_null() {
alloc::handle_alloc_error(layout);
}
Ok(Self {
ptr: ptr,
layout: layout,
})
Ok(Self { ptr, layout })
}
pub fn as_mut_ptr(&self) -> *mut u8 {
......@@ -58,12 +52,22 @@ impl Allocation {
pub fn align(&self) -> usize {
self.layout.align()
}
/// Returns a view of the Allocation.
pub fn as_slice(&self) -> &[u8] {
unsafe { std::slice::from_raw_parts(self.as_mut_ptr(), self.size()) }
}
/// Returns a mutable view of the Allocation.
pub fn as_mut_slice(&mut self) -> &mut [u8] {
unsafe { std::slice::from_raw_parts_mut(self.as_mut_ptr(), self.size()) }
}
}
impl Drop for Allocation {
fn drop(&mut self) {
unsafe {
alloc::dealloc(self.ptr, self.layout.clone());
alloc::dealloc(self.ptr, self.layout);
}
}
}
......@@ -101,6 +101,22 @@ impl<'a> Storage<'a> {
}
s
}
/// Returns a view of the stored data.
pub fn as_slice(&self) -> &[u8] {
match self {
Storage::Owned(alloc) => alloc.as_slice(),
Storage::View(slice, _) => &*slice,
}
}
/// Returns a mutable view of the stored data.
pub fn as_mut_slice(&mut self) -> &mut [u8] {
match self {
Storage::Owned(alloc) => alloc.as_mut_slice(),
Storage::View(slice, _) => slice,
}
}
}
impl<'d, 's, T> From<&'d [T]> for Storage<'s> {
......@@ -123,14 +139,18 @@ impl<'d, 's, T> From<&'d [T]> for Storage<'s> {
///
/// ```
/// extern crate ndarray;
/// use std::convert::TryInto;
/// use tvm_runtime::{call_packed, DLTensor, TVMArgValue, TVMRetValue, Tensor};
///
/// let mut a_nd: ndarray::Array = ndarray::Array::from_vec(vec![1f32, 2., 3., 4.]);
/// let mut a_nd: ndarray::Array1<f32> = ndarray::Array::from_vec(vec![1f32, 2., 3., 4.]);
/// let mut a: Tensor = a_nd.into();
/// let mut a_dl: DLTensor = (&mut t).into();
/// let mut a_dl: DLTensor = (&mut a).into();
///
/// let tvm_fn = |args: &[TVMArgValue]| -> Result<TVMRetValue, ()> { Ok(TVMRetValue::default()) };
/// call_packed!(tvm_fn, &mut a_dl);
///
/// // Array -> Tensor is mostly useful when post-processing TVM graph outputs.
/// let mut a_nd = ndarray::Array::try_from(&a).unwrap();
/// let mut a_nd: ndarray::ArrayD<f32> = a.try_into().unwrap();
/// ```
#[derive(PartialEq)]
pub struct Tensor<'a> {
......@@ -154,6 +174,14 @@ impl<'a> Tensor<'a> {
self.shape.clone()
}
pub fn data(&self) -> &Storage {
&self.data
}
pub fn data_mut(&mut self) -> &'a mut Storage {
&mut self.data
}
/// Returns the data of this `Tensor` as a `Vec`.
///
/// # Panics
......@@ -220,9 +248,9 @@ impl<'a> Tensor<'a> {
pub fn to_owned(&self) -> Tensor<'static> {
let t = Tensor {
data: self.data.to_owned(),
ctx: self.ctx.clone(),
dtype: self.dtype.clone(),
size: self.size.clone(),
ctx: self.ctx,
dtype: self.dtype,
size: self.size,
shape: self.shape.clone(),
strides: None,
byte_offset: 0,
......@@ -246,7 +274,7 @@ impl<'a> Tensor<'a> {
},
size: arr.len(),
shape: arr.shape().iter().map(|&v| v as i64).collect(),
strides: Some(arr.strides().into_iter().map(|&v| v as usize).collect()),
strides: Some(arr.strides().iter().map(|&v| v as usize).collect()),
byte_offset: 0,
}
}
......@@ -276,9 +304,9 @@ impl<'a> Tensor<'a> {
/// Conversions to `ndarray::Array` from `Tensor`, if the types match.
macro_rules! impl_ndarray_try_from_tensor {
($type:ty, $dtype:expr) => {
impl<'a, 't> TryFrom<&'a Tensor<'t>> for ndarray::ArrayD<$type> {
impl<'t> TryFrom<Tensor<'t>> for ndarray::ArrayD<$type> {
type Error = Error;
fn try_from(tensor: &'a Tensor) -> Result<ndarray::ArrayD<$type>, Error> {
fn try_from(tensor: Tensor) -> Result<ndarray::ArrayD<$type>, Error> {
ensure!(
tensor.dtype == $dtype,
"Cannot convert Tensor with dtype {:?} to ndarray",
......@@ -342,10 +370,10 @@ impl<'a> From<DLTensor> for Tensor<'a> {
Self {
data: storage,
ctx: TVMContext::default(),
dtype: dtype,
size: size,
shape: shape,
strides: if dlt.strides == ptr::null_mut() {
dtype,
size,
shape,
strides: if dlt.strides.is_null() {
None
} else {
Some(slice::from_raw_parts_mut(dlt.strides as *mut usize, size).to_vec())
......
......@@ -30,9 +30,3 @@ pub enum GraphFormatError {
#[fail(display = "Invalid DLType: {}", 0)]
InvalidDLType(String),
}
#[derive(Debug, Fail)]
#[fail(display = "SGX error: 0x{:x}", code)]
pub struct SgxError {
pub code: u32,
}
......@@ -28,19 +28,6 @@
//! The main entrypoints to this crate are `GraphExecutor`
//! For examples of use, please refer to the multi-file tests in the `tests` directory.
#![feature(
allocator_api,
box_syntax,
fn_traits,
unboxed_closures,
vec_remove_item
)]
#[cfg(target_env = "sgx")]
extern crate alloc;
extern crate bounded_spsc_queue;
#[cfg(target_env = "sgx")]
extern crate core;
#[macro_use]
extern crate failure;
#[macro_use]
......@@ -50,7 +37,6 @@ extern crate lazy_static;
extern crate ndarray;
#[macro_use]
extern crate nom;
#[cfg(not(target_env = "sgx"))]
extern crate num_cpus;
extern crate serde;
#[macro_use]
......@@ -63,9 +49,6 @@ mod array;
pub mod errors;
mod graph;
mod module;
#[cfg(target_env = "sgx")]
#[macro_use]
pub mod sgx;
mod threading;
mod workspace;
......@@ -86,10 +69,8 @@ lazy_static! {
}
#[no_mangle]
pub extern "C" fn TVMAPISetLastError(cmsg: *const i8) {
*LAST_ERROR.write().unwrap() = Some(unsafe { std::ffi::CStr::from_ptr(cmsg) });
#[cfg(target_env = "sgx")]
ocall_packed!("__sgx_set_last_error__", cmsg);
pub unsafe extern "C" fn TVMAPISetLastError(cmsg: *const i8) {
*LAST_ERROR.write().unwrap() = Some(std::ffi::CStr::from_ptr(cmsg));
}
#[no_mangle]
......
......@@ -35,8 +35,8 @@ use crate::{
use super::Module;
const TVM_MAIN: &'static [u8] = b"__tvm_main__";
const TVM_MODULE_CTX: &'static [u8] = b"__tvm_module_ctx";
const TVM_MAIN: &[u8] = b"__tvm_main__";
const TVM_MODULE_CTX: &[u8] = b"__tvm_module_ctx";
/// A module backed by a Dynamic Shared Object (dylib).
pub struct DsoModule<'a> {
......@@ -64,22 +64,26 @@ impl<'a> DsoModule<'a> {
init_context_func!(
lib,
(TVMAPISetLastError, extern "C" fn(*const i8)),
(TVMAPISetLastError, unsafe extern "C" fn(*const i8)),
(
TVMBackendAllocWorkspace,
extern "C" fn(c_int, c_int, u64, c_int, c_int) -> *mut c_void
unsafe extern "C" fn(c_int, c_int, u64, c_int, c_int) -> *mut c_void
),
(
TVMBackendFreeWorkspace,
extern "C" fn(c_int, c_int, *mut c_void) -> c_int
unsafe extern "C" fn(c_int, c_int, *mut c_void) -> c_int
),
(
TVMBackendParallelLaunch,
extern "C" fn(crate::threading::FTVMParallelLambda, *const c_void, usize) -> c_int
unsafe extern "C" fn(
crate::threading::FTVMParallelLambda,
*const c_void,
usize,
) -> c_int
),
(
TVMBackendParallelBarrier,
extern "C" fn(usize, *const tvm_common::ffi::TVMParallelGroupEnv)
unsafe extern "C" fn(usize, *const tvm_common::ffi::TVMParallelGroupEnv)
),
);
......@@ -129,7 +133,7 @@ impl<'a> Module for DsoModule<'a> {
&*Box::leak(super::wrap_backend_packed_func(name.to_string(), *func)),
);
self.packed_funcs.borrow().get(name).map(|f| *f)
self.packed_funcs.borrow().get(name).copied()
}
}
......
......@@ -36,9 +36,9 @@ pub trait Module {
// @see `WrapPackedFunc` in `llvm_module.cc`.
fn wrap_backend_packed_func(func_name: String, func: BackendPackedCFunc) -> Box<dyn PackedFunc> {
box move |args: &[TVMArgValue]| {
Box::new(move |args: &[TVMArgValue]| {
let (values, type_codes): (Vec<TVMValue>, Vec<i32>) = args
.into_iter()
.iter()
.map(|arg| {
let (val, code) = arg.to_tvm_value();
(val, code as i32)
......@@ -52,5 +52,5 @@ fn wrap_backend_packed_func(func_name: String, func: BackendPackedCFunc) -> Box<
func_name.clone(),
))
}
}
})
}
......@@ -27,6 +27,11 @@ use super::Module;
pub struct SystemLibModule;
#[cfg(target_env = "sgx")]
extern "C" {
fn __tvm_module_startup();
}
lazy_static! {
static ref SYSTEM_LIB_FUNCTIONS: Mutex<HashMap<String, &'static (dyn PackedFunc)>> =
Mutex::new(HashMap::new());
......@@ -37,13 +42,16 @@ impl Module for SystemLibModule {
SYSTEM_LIB_FUNCTIONS
.lock()
.unwrap()
.get(name.as_ref())
.map(|f| *f)
.get(name.as_ref()).copied()
}
}
impl Default for SystemLibModule {
fn default() -> Self {
#[cfg(target_env = "sgx")]
unsafe {
__tvm_module_startup();
}
SystemLibModule {}
}
}
......@@ -58,5 +66,5 @@ pub extern "C" fn TVMBackendRegisterSystemLibSymbol(
name.to_string(),
&*Box::leak(super::wrap_backend_packed_func(name.to_string(), func)),
);
return 0;
0
}
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
use std::{
ffi::CString,
os::raw::{c_char, c_int},
};
pub use crate::threading::tvm_run_worker as run_worker;
use crate::{threading::sgx_join_threads, SystemLibModule, TVMArgValue, TVMRetValue};
use errors::SgxError;
use ffi::TVMValue;
#[macro_export]
macro_rules! tvm_ocall {
($func: expr) => {
match $func {
0 => Ok(()),
code => Err(SgxError { code }),
}
};
}
pub type SgxStatus = u32;
#[cfg(target_env = "sgx")]
extern "C" {
fn tvm_ocall_packed_func(
name: *const c_char,
arg_values: *const TVMValue,
type_codes: *const c_int,
num_args: c_int,
ret_val: *mut TVMValue,
ret_type_code: *mut c_int,
) -> SgxStatus;
}
pub fn ocall_packed_func<S: AsRef<str>>(
fn_name: S,
args: &[TVMArgValue],
) -> Result<TVMRetValue, SgxError> {
let mut ret_val = TVMValue { v_int64: 0 };
let ret_type_code = 0i64;
unsafe {
tvm_ocall!(tvm_ocall_packed_func(
CString::new(fn_name.as_ref()).unwrap().as_ptr(),
args.iter()
.map(|ref arg| arg.value)
.collect::<Vec<TVMValue>>()
.as_ptr(),
args.iter()
.map(|ref arg| arg.type_code as i32)
.collect::<Vec<i32>>()
.as_ptr() as *const i32,
args.len() as i32,
&mut ret_val as *mut TVMValue,
&mut (ret_type_code as i32) as *mut c_int,
))?;
}
Ok(TVMRetValue::from_tvm_value(ret_val, ret_type_code as i64))
}
#[macro_export]
macro_rules! ocall_packed {
($fn_name:expr, $($args:expr),+) => {
$crate::sgx::ocall_packed_func($fn_name, &[$($args.into(),)+])
.expect(concat!("Error calling `", $fn_name, "`"))
};
($fn_name:expr) => {
$crate::sgx::ocall_packed_func($fn_name, &Vec::new())
.expect(concat!("Error calling `", $fn_name, "`"))
}
}
pub fn shutdown() {
if env!("TVM_NUM_THREADS") != "0" {
sgx_join_threads()
}
}
impl Drop for SystemLibModule {
fn drop(&mut self) {
shutdown()
}
}
......@@ -18,30 +18,18 @@
*/
use std::{
env,
os::raw::{c_int, c_void},
sync::{
atomic::{AtomicUsize, Ordering},
Arc, Barrier,
},
};
#[cfg(not(target_env = "sgx"))]
use num_cpus;
#[cfg(not(target_env = "sgx"))]
use std::{
env,
thread::{self, JoinHandle},
};
#[cfg(target_env = "sgx")]
use std::{collections::VecDeque, ptr, sync::Mutex};
use bounded_spsc_queue::{self, Producer};
use crossbeam::channel::{Sender, Receiver, bounded};
use tvm_common::ffi::TVMParallelGroupEnv;
#[cfg(target_env = "sgx")]
use super::{TVMArgValue, TVMRetValue};
pub(crate) type FTVMParallelLambda =
extern "C" fn(task_id: usize, penv: *const TVMParallelGroupEnv, cdata: *const c_void) -> i32;
......@@ -82,7 +70,6 @@ impl Job {
/// Waits for all tasks in this `Job` to be completed.
fn wait(&self) {
while self.pending.load(Ordering::Acquire) > 0 {
#[cfg(not(target_env = "sgx"))]
thread::yield_now();
}
}
......@@ -99,9 +86,8 @@ struct Task {
unsafe impl Send for Task {}
unsafe impl Sync for Task {}
impl FnOnce<()> for Task {
type Output = i32;
extern "rust-call" fn call_once(self, _args: ()) -> Self::Output {
impl Task {
fn run(self) -> i32 {
let status = (self.flambda)(self.id, &self.penv as *const _, self.cdata);
self.pending.fetch_sub(1, Ordering::AcqRel);
status
......@@ -111,45 +97,23 @@ impl FnOnce<()> for Task {
#[derive(Default)]
struct Threads {
#[allow(unused)]
#[cfg(not(target_env = "sgx"))]
handles: Vec<JoinHandle<()>>,
queues: Vec<Producer<Task>>,
queues: Vec<Sender<Task>>,
}
impl<'a> Threads {
#[cfg(not(target_env = "sgx"))]
fn launch<F: Sync + Send + FnOnce(Consumer<Task>) + 'static + Copy>(
fn launch<F: Sync + Send + FnOnce(Receiver<Task>) + 'static + Copy>(
num_threads: usize,
cb: F,
) -> Self {
let (handles, queues) = (0..num_threads)
.map(|_| {
let (p, c) = bounded_spsc_queue::make(2);
let (p, c) = bounded(2);
let handle = thread::spawn(move || cb(c.into()));
(handle, p)
})
.unzip();
Threads {
handles: handles,
queues: queues,
}
}
#[cfg(target_env = "sgx")]
fn launch<F: Sync + Send + FnOnce(Consumer<Task>) + 'static + Copy>(
num_threads: usize,
_cb: F,
) -> Self {
let mut consumer_queues = SGX_QUEUES.lock().unwrap();
let queues = (0..num_threads)
.map(|_| {
let (p, c) = bounded_spsc_queue::make(2);
consumer_queues.push_back(c.into());
p
})
.collect();
ocall_packed!("__sgx_thread_group_launch__", num_threads as u64);
Threads { queues: queues }
Threads { handles, queues }
}
}
......@@ -165,7 +129,7 @@ impl ThreadPool {
fn new() -> Self {
let num_workers = max_concurrency();
ThreadPool {
num_workers: num_workers,
num_workers,
threads: Threads::launch(num_workers, ThreadPool::run_worker),
}
}
......@@ -174,17 +138,18 @@ impl ThreadPool {
let mut tasks = job.tasks(self.num_workers + 1);
for (i, task) in tasks.split_off(1).into_iter().enumerate() {
self.threads.queues[i].push(task);
self.threads.queues[i].send(task)
.expect("should send");
}
tasks.pop().unwrap()();
tasks.pop().unwrap().run();
job.wait();
}
fn run_worker(queue: Consumer<Task>) {
fn run_worker(queue: Receiver<Task>) {
loop {
let task = queue.pop();
let result = task();
let task = queue.recv().expect("should recv");
let result = task.run();
if result == <i32>::min_value() {
break;
} else if result != 0 {
......@@ -194,42 +159,14 @@ impl ThreadPool {
}
}
// Send + Sync wrapper for bounded_spsc_queue::Consumer
struct Consumer<T> {
consumer: bounded_spsc_queue::Consumer<T>,
}
impl<T> From<bounded_spsc_queue::Consumer<T>> for Consumer<T> {
fn from(c: bounded_spsc_queue::Consumer<T>) -> Self {
Consumer { consumer: c }
}
}
impl<T> Consumer<T> {
fn pop(&self) -> T {
self.consumer.pop()
}
}
unsafe impl<T> Send for Consumer<T> {}
unsafe impl<T> Sync for Consumer<T> {}
#[cfg(target_env = "sgx")]
lazy_static! {
/// Holds tasks for untrusted threads which re-enter the enclave to execute.
static ref SGX_QUEUES: Mutex<VecDeque<Consumer<Task>>> = Mutex::new(VecDeque::new());
}
#[cfg(all(not(target_arch = "wasm32"), not(target_env = "sgx")))]
#[cfg(not(target_arch = "wasm32"))]
fn max_concurrency() -> usize {
if let Ok(threads_str) = env::var("TVM_NUM_THREADS").or(env::var("OMP_NUM_THREADS")) {
if let Ok(threads_str) = env::var("TVM_NUM_THREADS").or_else(|_| env::var("OMP_NUM_THREADS")) {
if let Ok(threads) = usize::from_str_radix(&threads_str, 10) {
return threads;
}
}
num_cpus::get_physical()
}
#[cfg(target_env = "sgx")]
fn max_concurrency() -> usize {
usize::from_str_radix(env!("TVM_NUM_THREADS"), 10).unwrap_or(1)
num_cpus::get()
}
#[cfg(target_arch = "wasm32")]
......@@ -237,69 +174,38 @@ fn max_concurrency() -> usize {
0 // wasm doesn't support threads yet
}
#[cfg(target_env = "sgx")]
pub fn tvm_run_worker(_args: &[TVMArgValue]) -> TVMRetValue {
let q = {
let mut qs = SGX_QUEUES.lock().unwrap();
qs.pop_front()
// `qs: MutexGuard` needs to be dropped here since `run_worker` won't return
};
if let Some(q) = q {
ThreadPool::run_worker(q);
}
TVMRetValue::default()
}
#[no_mangle]
pub extern "C" fn TVMBackendParallelLaunch(
cb: FTVMParallelLambda,
cdata: *const c_void,
num_task: usize,
) -> c_int {
if max_concurrency() == 0 {
if max_concurrency() < 2 {
let penv = TVMParallelGroupEnv {
sync_handle: 0 as *mut c_void,
sync_handle: std::ptr::null_mut(),
num_task: 1,
};
cb(0, &penv as *const _, cdata);
} else {
THREAD_POOL.with(|pool| {
pool.launch(Job {
cb: cb,
cdata: cdata,
cb,
cdata,
req_num_tasks: num_task,
pending: Arc::new(AtomicUsize::new(0)),
});
});
}
return 0;
}
#[cfg(target_env = "sgx")]
pub(crate) fn sgx_join_threads() {
extern "C" fn poison_pill(
_task_id: usize,
_penv: *const TVMParallelGroupEnv,
_cdata: *const c_void,
) -> i32 {
<i32>::min_value()
}
THREAD_POOL.with(|pool| {
pool.launch(Job {
cb: poison_pill,
cdata: ptr::null(),
req_num_tasks: 0,
pending: Arc::new(AtomicUsize::new(0)),
});
});
ocall_packed!("__sgx_thread_group_join__", 0);
0
}
// @see issue 988 for information on why this function is used.
#[no_mangle]
pub extern "C" fn TVMBackendParallelBarrier(_task_id: usize, penv: *const TVMParallelGroupEnv) {
let barrier: &Arc<Barrier> = unsafe { &*((*penv).sync_handle as *const Arc<Barrier>) };
pub unsafe extern "C" fn TVMBackendParallelBarrier(
_task_id: usize,
penv: *const TVMParallelGroupEnv,
) {
let barrier: &Arc<Barrier> = &*((*penv).sync_handle as *const Arc<Barrier>);
barrier.wait();
}
......@@ -323,7 +229,7 @@ mod tests {
penv: *const TVMParallelGroupEnv,
cdata: *const c_void,
) -> i32 {
if cdata == ptr::null() {
if cdata.is_null() {
return 0;
}
unsafe {
......
......@@ -29,6 +29,11 @@ use crate::allocator::Allocation;
const WS_ALIGN: usize = 64; // taken from `kTempAllocaAlignment` in `device_api.h`
pub fn remove_item<T: PartialEq>(vec: &mut Vec<T>, item: &T) -> Option<T> {
let pos = vec.iter().position(|x| *x == *item)?;
Some(vec.remove(pos))
}
struct WorkspacePool {
workspaces: Vec<Allocation>,
free: Vec<usize>,
......@@ -51,7 +56,7 @@ impl WorkspacePool {
}
fn alloc(&mut self, size: usize) -> Result<*mut u8, Error> {
if self.free.len() == 0 {
if self.free.is_empty() {
return self.alloc_new(size);
}
let idx = self
......@@ -64,15 +69,12 @@ impl WorkspacePool {
}
cur_ws_idx.or(Some(idx)).and_then(|cur_idx| {
let cur_size = self.workspaces[cur_idx].size();
Some(match ws_size <= cur_size {
true => idx,
false => cur_idx,
})
Some(if ws_size <= cur_size { idx } else { cur_idx })
})
});
match idx {
Some(idx) => {
self.free.remove_item(&idx).unwrap();
remove_item(&mut self.free, &idx).unwrap();
self.in_use.push(idx);
Ok(self.workspaces[idx].as_mut_ptr())
}
......@@ -90,9 +92,10 @@ impl WorkspacePool {
break;
}
}
Ok(self
.free
.push(ws_idx.ok_or(format_err!("Tried to free nonexistent workspace."))?))
if let Some(ws_idx) = ws_idx {
self.free.push(ws_idx);
}
Ok(())
}
}
......@@ -133,5 +136,5 @@ pub extern "C" fn TVMBackendFreeWorkspace(
Err(_) => -1,
}) as c_int
});
return 0;
0
}
......@@ -26,11 +26,38 @@ use std::{convert::TryFrom, fs, io::Read};
use tvm_runtime::Graph;
macro_rules! mf_dir {
($p:literal) => {
concat!(env!("CARGO_MANIFEST_DIR"), $p)
};
}
static PARAMS_FIXTURE_PATH: &str = mf_dir!("/tests/graph.params");
#[test]
fn test_load_graph() {
let output = std::process::Command::new(mf_dir!("/tests/build_model.py"))
.env(
"PYTHONPATH",
concat!(
mf_dir!("/../../python"),
":",
mf_dir!("/../../nnvm/python"),
":",
mf_dir!("/../../topi/python")
),
)
.output()
.expect("Failed to build test model");
assert!(
std::path::Path::new(PARAMS_FIXTURE_PATH).exists(),
"Could not build test graph fixture: STDOUT:\n\n{}\nSTDERR: {}\n\n",
String::from_utf8(output.stdout).unwrap(),
String::from_utf8(output.stderr).unwrap()
);
let mut params_bytes = Vec::new();
fs::File::open(concat!(env!("CARGO_MANIFEST_DIR"), "/tests/graph.params"))
.expect("Could not find TVM graph. Did you run `tests/build_model.py`?")
fs::File::open(PARAMS_FIXTURE_PATH)
.unwrap()
.read_to_end(&mut params_bytes)
.unwrap();
let _params = tvm_runtime::load_param_dict(&params_bytes);
......
......@@ -22,10 +22,10 @@ license = "Apache-2.0"
authors = ["TVM Contributors"]
[dependencies]
ndarray="0.12.1"
serde = "1.0.59"
serde_json = "1.0.17"
ndarray="0.12"
serde = "1.0"
serde_json = "1.0"
tvm-runtime = { path = "../../" }
[build-dependencies]
ar = "0.6.0"
ar = "0.6"
......@@ -33,11 +33,11 @@ const IN_DIM: usize = 8;
macro_rules! check_sum {
($e:expr, $a:ident, $b:ident) => {
let a = Array::try_from($e.get_input(stringify!($a)).unwrap()).unwrap();
let a = Array::try_from($e.get_input(stringify!($a)).unwrap().to_owned()).unwrap();
check_sum!(a, $b);
};
($e:expr, $a:expr, $b:ident) => {
let a = Array::try_from($e.get_output($a).unwrap()).unwrap();
let a = Array::try_from($e.get_output($a).unwrap().to_owned()).unwrap();
check_sum!(a, $b);
};
($a:ident, $b:ident) => {
......@@ -73,11 +73,11 @@ fn main() {
.collect::<Vec<f32>>(),
)
.unwrap();
let w = Array::try_from(params.get("dense0_weight").unwrap())
let w = Array::try_from(params.get("dense0_weight").unwrap().to_owned())
.unwrap()
.into_shape((IN_DIM * 2, IN_DIM))
.unwrap();
let b = Array::try_from(params.get("dense0_bias").unwrap()).unwrap();
let b = Array::try_from(params.get("dense0_bias").unwrap().to_owned()).unwrap();
let dense = x.dot(&w.t()) + &b;
let left = dense.slice(s![.., 0..IN_DIM]);
let right = dense.slice(s![.., IN_DIM..]);
......
......@@ -22,8 +22,8 @@ license = "Apache-2.0"
authors = ["TVM Contributors"]
[dependencies]
ndarray="0.12.1"
ndarray="0.12"
tvm-runtime = { path = "../../" }
[build-dependencies]
ar = "0.6.0"
ar = "0.6"
......@@ -28,12 +28,7 @@
#include <tvm/runtime/module.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/device_api.h>
#ifdef _LIBCPP_SGX_CONFIG
#include "sgx/trusted/runtime.h"
#endif
#ifndef _LIBCPP_SGX_NO_IOSTREAMS
#include <sstream>
#endif
#include <array>
#include <algorithm>
#include <string>
......@@ -174,7 +169,6 @@ void DeviceAPI::SyncStreamFromTo(TVMContext ctx,
LOG(FATAL) << "Device does not support stream api.";
}
#ifndef _LIBCPP_SGX_NO_IOSTREAMS
//--------------------------------------------------------
// Error handling mechanism
// -------------------------------------------------------
......@@ -338,11 +332,6 @@ std::string NormalizeError(std::string err_msg) {
return os.str();
}
#else
std::string NormalizeError(std::string err_msg) {
return err_msg;
}
#endif
} // namespace runtime
} // namespace tvm
......@@ -366,11 +355,7 @@ int TVMAPIHandleException(const std::runtime_error &e) {
}
void TVMAPISetLastError(const char* msg) {
#ifndef _LIBCPP_SGX_CONFIG
TVMAPIRuntimeStore::Get()->last_error = msg;
#else
sgx::OCallPackedFunc("__sgx_set_last_error__", msg);
#endif
}
int TVMModLoadFromFile(const char* file_name,
......
......@@ -25,11 +25,7 @@
#include <dmlc/logging.h>
#include <dmlc/thread_local.h>
#include <algorithm>
#ifndef _LIBCPP_SGX_CONFIG
#include "mt_random_engine.cc"
#else
#include "sgx_random_engine.cc"
#endif
#define DLPACK_INTEGER_TYPE_SWITCH(type, DType, ...) \
if (type.code == kDLInt && type.bits == 32) { \
......
......@@ -50,7 +50,7 @@ class CPUDeviceAPI final : public DeviceAPI {
#if _MSC_VER
ptr = _aligned_malloc(nbytes, alignment);
if (ptr == nullptr) throw std::bad_alloc();
#elif defined(_LIBCPP_SGX_CONFIG) || (defined(__ANDROID__) && __ANDROID_API__ < 17)
#elif defined(__ANDROID__) && __ANDROID_API__ < 17
ptr = memalign(alignment, nbytes);
if (ptr == nullptr) throw std::bad_alloc();
#else
......
......@@ -73,7 +73,6 @@ std::string GetFileFormat(const std::string& file_name,
const std::string& format) {
std::string fmt = format;
if (fmt.length() == 0) {
if (file_name.find(".signed.so") != std::string::npos) return "sgx";
size_t pos = file_name.find_last_of(".");
if (pos != std::string::npos) {
return file_name.substr(pos + 1, file_name.length() - pos - 1);
......
......@@ -67,11 +67,7 @@ void GraphRuntime::Run() {
void GraphRuntime::Init(const std::string& graph_json,
tvm::runtime::Module module,
const std::vector<TVMContext>& ctxs) {
#ifndef _LIBCPP_SGX_NO_IOSTREAMS
std::istringstream is(graph_json);
#else
std::string is = graph_json;
#endif
dmlc::JSONReader reader(&is);
this->Load(&reader);
module_ = module;
......
......@@ -21,9 +21,7 @@
* \file module_util.cc
* \brief Utilities for module.
*/
#ifndef _LIBCPP_SGX_CONFIG
#include <dmlc/memory_io.h>
#endif
#include <tvm/runtime/module.h>
#include <tvm/runtime/registry.h>
#include <string>
......@@ -121,7 +119,6 @@ void InitContextFunctions(std::function<void*(const char*)> fgetsymbol) {
* \return Root Module.
*/
runtime::Module ProcessModuleBlob(const char* mblob, ObjectPtr<Library> lib) {
#ifndef _LIBCPP_SGX_CONFIG
CHECK(mblob != nullptr);
uint64_t nbytes = 0;
for (size_t i = 0; i < sizeof(nbytes); ++i) {
......@@ -180,10 +177,6 @@ runtime::Module ProcessModuleBlob(const char* mblob, ObjectPtr<Library> lib) {
// invariance: root module is always at location 0.
// The module order is collected via DFS
return modules[0];
#else
LOG(FATAL) << "SGX does not support ImportModuleBlob";
return Module();
#endif
}
Module CreateModuleFromLibrary(ObjectPtr<Library> lib) {
......
......@@ -26,9 +26,7 @@
#include <tvm/runtime/packed_func.h>
#include <unordered_set>
#include <cstring>
#ifndef _LIBCPP_SGX_CONFIG
#include "file_util.h"
#endif
namespace tvm {
namespace runtime {
......@@ -77,7 +75,6 @@ PackedFunc ModuleNode::GetFunction(const std::string& name, bool query_imports)
Module Module::LoadFromFile(const std::string& file_name,
const std::string& format) {
#ifndef _LIBCPP_SGX_CONFIG
std::string fmt = GetFileFormat(file_name, format);
CHECK(fmt.length() != 0)
<< "Cannot deduce format of file " << file_name;
......@@ -91,9 +88,6 @@ Module Module::LoadFromFile(const std::string& file_name,
<< load_f_name << ") is not presented.";
Module m = (*f)(file_name, format);
return m;
#else
LOG(FATAL) << "SGX does not support LoadFromFile";
#endif
}
void ModuleNode::SaveToFile(const std::string& file_name,
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file common.h
* \brief TVM SGX common API.
*/
#ifndef TVM_RUNTIME_SGX_COMMON_H_
#define TVM_RUNTIME_SGX_COMMON_H_
#include <sgx_error.h>
namespace tvm {
namespace runtime {
namespace sgx {
#define TVM_SGX_CHECKED_CALL(Function) \
sgx_status_t TVM_STR_CONCAT(__sgx_status_, __LINE__) = SGX_ERROR_UNEXPECTED; \
TVM_STR_CONCAT(__sgx_status_, __LINE__) = Function; \
CHECK_EQ(TVM_STR_CONCAT(__sgx_status_, __LINE__), SGX_SUCCESS) \
<< "SGX Error: " << TVM_STR_CONCAT(__sgx_status_, __LINE__);
} // namespace sgx
} // namespace runtime
} // namespace tvm
#endif // TVM_RUNTIME_SGX_COMMON_H_
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file ecall_registry.h
* \brief The global registry of packed functions available via ecall_packed_func.
*/
#ifndef TVM_RUNTIME_SGX_TRUSTED_ECALL_REGISTRY_H_
#define TVM_RUNTIME_SGX_TRUSTED_ECALL_REGISTRY_H_
#include <dmlc/logging.h>
#include <tvm/runtime/registry.h>
#include <string>
#include <algorithm>
#include <vector>
namespace tvm {
namespace runtime {
namespace sgx {
class ECallRegistry: public Registry {
public:
explicit ECallRegistry(std::string name) {
name_ = name;
}
Registry& set_body(PackedFunc f) {
func_ = f;
return *this;
}
Registry& set_body(PackedFunc::FType f) { // NOLINT(*)
return set_body(PackedFunc(f));
}
static Registry& Register(const std::string& name, bool override = false) {
for (auto& r : exports_) {
if (r.name_ == name) {
CHECK(override) << "ecall " << name << " is already registered";
return r;
}
}
TVM_SGX_CHECKED_CALL(
tvm_ocall_register_export(name.c_str(), exports_.size()));
exports_.emplace_back(name);
return exports_.back();
}
static bool Remove(const std::string& name) {
LOG(FATAL) << "Removing enclave exports is not supported.";
}
static const PackedFunc* Get(const std::string& name) {
for (const auto& r : exports_) {
if (r.name_ == name) return &r.func_;
}
return nullptr;
}
static const PackedFunc* Get(unsigned func_id) {
return func_id >= exports_.size() ? nullptr : &exports_[func_id].func_;
}
static std::vector<std::string> ListNames() {
std::vector<std::string> names;
names.resize(exports_.size());
std::transform(exports_.begin(), exports_.end(), names.begin(),
[](ECallRegistry r) { return r.name_; });
return names;
}
static std::vector<ECallRegistry> exports_;
};
std::vector<ECallRegistry> ECallRegistry::exports_;
/*!
* \brief Register a function callable via ecall_packed_func
* \code
* TVM_REGISTER_ENCLAVE_FUNC("DoThing")
* .set_body([](TVMArgs args, TVMRetValue* rv) {
* });
* \endcode
*/
#define TVM_REGISTER_ENCLAVE_FUNC(OpName) \
TVM_STR_CONCAT(TVM_FUNC_REG_VAR_DEF, __COUNTER__) = \
::tvm::runtime::sgx::ECallRegistry::Register(OpName, true)
} // namespace sgx
} // namespace runtime
} // namespace tvm
#endif // TVM_RUNTIME_SGX_TRUSTED_ECALL_REGISTRY_H_
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file runtime_t.cc
*/
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/packed_func.h>
#include "../../c_runtime_api.cc"
#include "../../cpu_device_api.cc"
#include "../../module.cc"
#include "../../module_util.cc"
#include "../../registry.cc"
#include "../../system_lib_module.cc"
#include "../../thread_pool.cc"
#include "../../workspace_pool.cc"
#include "ecall_registry.h"
#include "runtime.h"
#include "threading_backend.cc"
namespace tvm {
namespace runtime {
namespace sgx {
extern "C" {
void tvm_ecall_init(TVMRetValueHandle ret) {}
void tvm_ecall_packed_func(int func_id,
const TVMValue* arg_values,
const int* type_codes,
int num_args,
TVMRetValueHandle ret) {
const PackedFunc* f = ECallRegistry::Get(func_id);
CHECK(f != nullptr) << "ecall function not found.";
TVMRetValue rv;
f->CallPacked(TVMArgs(arg_values, type_codes, num_args), &rv);
int ret_type_code = rv.type_code();
if (ret_type_code == kTVMNullptr) return;
TVMValue ret_value;
if (ret_type_code == kTVMBytes || ret_type_code == kTVMStr) {
// allocate a buffer in untrusted, copy the values in
std::string bytes = rv;
void* ret_buf;
TVM_SGX_CHECKED_CALL(tvm_ocall_reserve_space(
&ret_buf, bytes.size() + sizeof(TVMByteArray), sizeof(uint64_t)));
char* data_buf = static_cast<char*>(ret_buf) + sizeof(TVMByteArray);
memcpy(data_buf, bytes.data(), bytes.size());
TVMByteArray* arr = static_cast<TVMByteArray*>(ret_buf);
arr->data = data_buf;
arr->size = bytes.size();
ret_value = TVMValue{.v_handle = arr};
ret_type_code = kTVMBytes;
} else {
rv.MoveToCHost(&ret_value, &ret_type_code);
}
TVM_SGX_CHECKED_CALL(tvm_ocall_set_return(ret, &ret_value, &ret_type_code, 1));
}
} // extern "C"
TVM_REGISTER_ENCLAVE_FUNC("__tvm_main__")
.set_body([](TVMArgs args, TVMRetValue* rv) {
Module mod = (*Registry::Get("runtime.SystemLib"))();
mod.GetFunction("default_function").CallPacked(args, rv);
});
} // namespace sgx
} // namespace runtime
} // namespace tvm
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file trusted/runtime.h
* \brief TVM SGX trusted API.
*/
#ifndef TVM_RUNTIME_SGX_TRUSTED_RUNTIME_H_
#define TVM_RUNTIME_SGX_TRUSTED_RUNTIME_H_
#include <tvm/runtime/packed_func.h>
#include <string>
#include <utility>
#include "../common.h"
namespace tvm {
namespace runtime {
namespace sgx {
template<typename... Args>
inline TVMRetValue OCallPackedFunc(std::string name, Args&& ...args) {
const int kNumArgs = sizeof...(Args);
const int kArraySize = kNumArgs > 0 ? kNumArgs : 1;
TVMValue values[kArraySize];
int type_codes[kArraySize];
detail::for_each(TVMArgsSetter(values, type_codes),
std::forward<Args>(args)...);
TVMValue ret_val;
int ret_type_code;
TVM_SGX_CHECKED_CALL(tvm_ocall_packed_func(name.c_str(),
values,
type_codes,
kNumArgs,
&ret_val,
&ret_type_code));
TVMRetValue* rv = new TVMRetValue();
*rv = TVMArgValue(ret_val, ret_type_code);
return *rv;
}
} // namespace sgx
} // namespace runtime
} // namespace tvm
#endif // TVM_RUNTIME_SGX_TRUSTED_RUNTIME_H_
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file sgx/threading_backend.cc
* \brief SGX threading backend
*/
#include <tvm/runtime/threading_backend.h>
#include <dmlc/logging.h>
#include <sgx_edger8r.h>
#include <sgx_trts.h>
#include <atomic>
#include "runtime.h"
#ifndef TVM_SGX_MAX_CONCURRENCY
#define TVM_SGX_MAX_CONCURRENCY 1
#endif
namespace tvm {
namespace runtime {
namespace threading {
class ThreadGroup::Impl {
public:
Impl(int num_workers, std::function<void(int)> worker_callback,
bool exclude_worker0)
: num_workers_(num_workers),
worker_callback_(worker_callback),
next_task_id_(exclude_worker0) {
CHECK(num_workers <= TVM_SGX_MAX_CONCURRENCY)
<< "Tried spawning more threads than allowed by TVM_SGX_MAX_CONCURRENCY.";
sgx::OCallPackedFunc("__sgx_thread_group_launch__",
num_workers_, reinterpret_cast<void*>(this));
}
~Impl() {
sgx::OCallPackedFunc("__sgx_thread_group_join__");
}
void RunTask() {
int task_id = next_task_id_++;
CHECK(task_id < num_workers_)
<< "More workers entered enclave than allowed by TVM_SGX_MAX_CONCURRENCY";
worker_callback_(task_id);
}
private:
int num_workers_;
std::function<void(int)> worker_callback_;
std::atomic<int> next_task_id_;
};
ThreadGroup::ThreadGroup(int num_workers,
std::function<void(int)> worker_callback,
bool exclude_worker0)
: impl_(new ThreadGroup::Impl(num_workers, worker_callback, exclude_worker0)) {}
void ThreadGroup::Join() {}
int ThreadGroup::Configure(AffinityMode mode, int nthreads, bool exclude_worker0) {
int max_conc = MaxConcurrency();
if (!nthreads || ntheads > max_conc) {
return max_conc;
}
return nthreads;
}
ThreadGroup::~ThreadGroup() { delete impl_; }
void Yield() {}
int MaxConcurrency() { return TVM_SGX_MAX_CONCURRENCY; }
TVM_REGISTER_ENCLAVE_FUNC("__tvm_run_worker__")
.set_body([](TVMArgs args, TVMRetValue* rv) {
void* tg = args[0];
if (!sgx_is_within_enclave(tg, sizeof(ThreadGroup::Impl))) return;
reinterpret_cast<ThreadGroup::Impl*>(tg)->RunTask();
});
} // namespace threading
} // namespace runtime
} // namespace tvm
enclave {
from "sgx_tstdc.edl" import *;
from "sgx_stdio.edl" import *;
from "sgx_backtrace.edl" import *;
trusted {
public void tvm_ecall_init([isptr, user_check] TVMRetValueHandle ret);
public void tvm_ecall_packed_func(int func_id,
[in, count=num_args] const TVMValue* arg_values,
[in, count=num_args] const int* type_codes,
int num_args,
[out] TVMValue* ret_val,
[out] int* ret_type_code);
};
untrusted {
void tvm_ocall_packed_func([in, string] const char* name,
[in, count=num_args] const TVMValue* arg_values,
[in, count=num_args] const int* type_codes,
int num_args,
[out] TVMValue* ret_val,
[out] int* ret_type_code);
void tvm_ocall_register_export([in, string] const char* name, int func_id);
};
};
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file sgx_module.cc
* \brief SGX enclave module.
*/
#include <dmlc/logging.h>
#include <sgx_urts.h>
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/threading_backend.h>
#include <algorithm>
#include <fstream>
#include <iostream>
#include <iterator>
#include <sstream>
#include <string>
#include <unordered_map>
#include "../common.h"
#include "../../file_util.h"
#include "./tvm_u.h"
namespace tvm {
namespace runtime {
class SGXModuleNode;
namespace sgx {
class EnclaveContext {
public:
explicit EnclaveContext(SGXModuleNode* mod) {
CHECK(Context()->mod_ == nullptr)
<< "Tried overriding existing enclave context.";
CHECK(mod != nullptr) << "Tried setting null enclave context.";
Context()->mod_ = mod;
}
~EnclaveContext() {
Context()->mod_ = nullptr;
}
static SGXModuleNode* GetModule() {
SGXModuleNode* ctx = Context()->mod_;
CHECK(ctx != nullptr) << "No current enclave context";
return ctx;
}
private:
EnclaveContext() {}
SGXModuleNode* mod_;
static EnclaveContext* Context() {
static thread_local EnclaveContext inst;
return &inst;
}
};
} // namespace sgx
class SGXModuleNode : public ModuleNode {
public:
~SGXModuleNode() {
if (eid_) {
sgx::EnclaveContext ctx(this);
sgx_destroy_enclave(eid_);
}
}
void Init(const std::string& enclave_file) {
std::string token_file = GetCacheDir() + "/" +
GetFileBasename(enclave_file) + ".token";
sgx_launch_token_t token = {0};
int token_updated = 0;
try {
std::ifstream ifs(token_file, std::fstream::in | std::fstream::binary);
ifs.exceptions(std::ifstream::failbit | std::ifstream::badbit);
ifs >> token;
} catch (std::ifstream::failure e) {
memset(&token, 0x0, sizeof(sgx_launch_token_t));
}
TVM_SGX_CHECKED_CALL(sgx_create_enclave(
enclave_file.c_str(), SGX_DEBUG_FLAG, &token, &token_updated, &eid_, NULL));
sgx::EnclaveContext ctx(this);
TVMRetValue rv;
TVM_SGX_CHECKED_CALL(tvm_ecall_init(eid_, &rv));
if (!token_updated) return;
try {
std::ofstream ofs(token_file, std::fstream::trunc | std::fstream::binary);
ofs.exceptions(std::ifstream::failbit | std::ifstream::badbit);
ofs << token;
} catch (std::ifstream::failure e) {
LOG(INFO) << "Could not save SGX launch token to " << token_file;
}
}
const char* type_key() const final {
return "sgx";
}
PackedFunc GetFunction(
const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) final {
auto exported = exports_.find(name);
if (exported == exports_.end()) return PackedFunc();
int func_id = exported->second;
return PackedFunc([this, func_id](TVMArgs args, TVMRetValue* rv) {
sgx::EnclaveContext ctx(this);
TVMValue ret_value;
int ret_type_code;
TVM_SGX_CHECKED_CALL(tvm_ecall_packed_func(eid_, func_id,
args.values, args.type_codes, args.num_args, &ret_value, &ret_type_code));
*rv = TVMArgValue(ret_value, ret_type_code);
});
}
void RunWorkers(int num_tasks) {
std::function<void(int)> runner = [this](int _worker_id) {
this->GetFunction("__tvm_run_worker__",
std::shared_ptr<SGXModuleNode>(nullptr))();
};
thread_group_.reset(new tvm::runtime::threading::ThreadGroup(
num_tasks, runner, false /* include_main_thread */));
}
void JoinThreads() {
thread_group_->Join();
}
void RegisterExport(std::string name, int func_id) {
exports_[name] = func_id;
}
private:
// ID of the loaded enclave
sgx_enclave_id_t eid_;
// Names and IDs of functions exported by the enclave module
std::unordered_map<std::string, int> exports_;
std::unique_ptr<tvm::runtime::threading::ThreadGroup> thread_group_;
};
namespace sgx {
TVM_REGISTER_GLOBAL("__sgx_thread_group_launch__")
.set_body([](TVMArgs args, TVMRetValue* rv) {
EnclaveContext::GetModule()->RunWorkers(args[0]);
});
TVM_REGISTER_GLOBAL("__sgx_thread_group_join__")
.set_body([](TVMArgs args, TVMRetValue* rv) {
EnclaveContext::GetModule()->JoinThreads();
});
TVM_REGISTER_GLOBAL("__sgx_set_last_error__")
.set_body([](TVMArgs args, TVMRetValue* rv) {
std::string err = args[0];
TVMAPISetLastError(err.c_str());
});
TVM_REGISTER_GLOBAL("__sgx_println__")
.set_body([](TVMArgs args, TVMRetValue* rv) {
std::ostringstream msg;
for (int i = 0; i < args.num_args; ++i) {
switch (args.type_codes[i]) {
case kDLInt: msg << static_cast<int64_t>(args[i]); break;
case kDLUInt: msg << static_cast<uint64_t>(args[i]); break;
case kDLFloat: msg << static_cast<double>(args[i]); break;
case kTVMStr:
case kTVMBytes: {
std::string val = args[i];
msg << val;
}
break;
}
msg << " ";
}
LOG(INFO) << msg.str();
});
extern "C" {
void tvm_ocall_register_export(const char* name, int func_id) {
EnclaveContext::GetModule()->RegisterExport(name, func_id);
}
void tvm_ocall_packed_func(const char* name,
const TVMValue* arg_values,
const int* type_codes,
int num_args,
TVMValue* ret_val,
int* ret_type_code) {
const PackedFunc* f = Registry::Get(name);
CHECK(f != nullptr) << "ocall to nonexistent function \"" << name << "\"";
TVMRetValue rv;
f->CallPacked(TVMArgs(arg_values, type_codes, num_args), &rv);
rv.MoveToCHost(ret_val, ret_type_code);
}
// Allocates space for return values. The returned pointer is only valid between
// successive calls to `tvm_ocall_reserve_space`.
TVM_REGISTER_GLOBAL("__sgx_reserve_space__")
.set_body([](TVMArgs args, TVMRetValue* rv) {
size_t num_bytes = args[0];
size_t alignment = args[1];
static TVMContext ctx = { kDLCPU, 0 };
static thread_local void* buf = nullptr;
static thread_local size_t buf_size = 0;
static thread_local size_t buf_align = 0;
if (buf_size >= num_bytes && buf_align >= alignment) *rv = nullptr;
DeviceAPI::Get(ctx)->FreeDataSpace(ctx, buf);
buf = DeviceAPI::Get(ctx)->AllocDataSpace(ctx, num_bytes, alignment, {});
buf_size = num_bytes;
buf_align = alignment;
*rv = buf;
});
} // extern "C"
} // namespace sgx
TVM_REGISTER_GLOBAL("runtime.module.loadfile_sgx")
.set_body([](TVMArgs args, TVMRetValue* rv) {
std::shared_ptr<SGXModuleNode> node = std::make_shared<SGXModuleNode>();
node->Init(args[0]);
*rv = runtime::Module(node);
});
} // namespace runtime
} // namespace tvm
......@@ -105,16 +105,14 @@ class ParallelLauncher {
tvm::runtime::threading::Yield();
}
if (!has_error_.load()) return 0;
// the following is intended to use string due to
// security issue raised in SGX backend
std::string err("");
std::ostringstream os;
for (size_t i = 0; i < par_errors_.size(); ++i) {
if (par_errors_[i].length() != 0) {
err += "Task " + std::to_string(i) + " error: " + par_errors_[i] + '\n';
os << "Task " << i << " error: " << par_errors_[i] << '\n';
par_errors_[i].clear();
}
}
TVMAPISetLastError(err.c_str());
TVMAPISetLastError(os.str().c_str());
return -1;
}
// Signal that one job has finished.
......@@ -373,11 +371,7 @@ class ThreadPool {
// number of workers used (can be restricted with affinity pref)
int num_workers_used_;
// if or not to exclude worker 0 and use master to run task 0
#ifndef _LIBCPP_SGX_CONFIG
bool exclude_worker0_{true};
#else
bool exclude_worker0_{false};
#endif
std::vector<std::unique_ptr<SpscTaskQueue> > queues_;
std::unique_ptr<tvm::runtime::threading::ThreadGroup> threads_;
};
......
......@@ -89,6 +89,7 @@ ALLOW_FILE_NAME = {
".gitmodules",
"CODEOWNERS",
".scalafmt.conf",
"Cargo.lock"
}
# List of specific files allowed in relpath to <proj_root>
......@@ -99,8 +100,8 @@ ALLOW_SPECIFIC_FILE = {
"KEYS",
"DISCLAIMER",
"Jenkinsfile",
# sgx file
"apps/sgx/enclave/sgx-deps.diff",
# sgx config
"apps/sgx/.cargo/config",
# html for demo purposes
"tests/webgl/test_static_webgl_library.html",
"web/example_rpc.html",
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment