Commit f2b30f9e by Nick Hynes Committed by Tianqi Chen

Update SGX example (#1933)

parent f9281241
......@@ -206,3 +206,6 @@ tvm_t.*
*.cer
*.crt
*.der
# patch sentinel
patched.txt
<EnclaveConfiguration>
<ProdID>0</ProdID>
<ISVSVN>0</ISVSVN>
<StackMaxSize>0x20000</StackMaxSize>
<HeapMaxSize>0x5000000</HeapMaxSize>
<StackMaxSize>0xf0000</StackMaxSize>
<HeapMaxSize>0xf000000</HeapMaxSize>
<TCSNum>NUM_THREADS</TCSNum>
<TCSPolicy>0</TCSPolicy> <!-- must be "bound" to use thread_local -->
<DisableDebug>0</DisableDebug>
......
......@@ -2,20 +2,32 @@
#[macro_use]
extern crate lazy_static;
#[macro_use]
extern crate tvm;
use std::{convert::TryFrom, sync::Mutex};
use std::{
convert::{TryFrom, TryInto},
sync::Mutex,
};
use tvm::runtime::{sgx, Graph, GraphExecutor, SystemLibModule, TVMArgValue, TVMRetValue};
use tvm::{
ffi::runtime::DLTensor,
runtime::{
load_param_dict, sgx, Graph, GraphExecutor, SystemLibModule, TVMArgValue, TVMRetValue, Tensor,
},
};
lazy_static! {
static ref SYSLIB: SystemLibModule = { SystemLibModule::default() };
static ref MODEL: Mutex<GraphExecutor<'static, 'static>> = {
let _params = include_bytes!(concat!("../", env!("BUILD_DIR"), "/params.bin"));
let graph_json = include_str!(concat!("../", env!("BUILD_DIR"), "/graph.json"));
let params_bytes = include_bytes!(concat!("../", env!("BUILD_DIR"), "/params.bin"));
let params = load_param_dict(params_bytes).unwrap();
let graph = Graph::try_from(graph_json).unwrap();
Mutex::new(GraphExecutor::new(graph, &*SYSLIB).unwrap())
let mut exec = GraphExecutor::new(graph, &*SYSLIB).unwrap();
exec.load_params(params);
Mutex::new(exec)
};
}
......@@ -24,13 +36,15 @@ fn ecall_init(_args: &[TVMArgValue]) -> TVMRetValue {
TVMRetValue::from(0)
}
fn ecall_main(_args: &[TVMArgValue]) -> TVMRetValue {
let model = MODEL.lock().unwrap();
// model.set_input("data", args[0]);
fn ecall_main(args: &[TVMArgValue<'static>]) -> TVMRetValue {
let mut model = MODEL.lock().unwrap();
let inp = args[0].try_into().unwrap();
let mut out: Tensor = args[1].try_into().unwrap();
model.set_input("data", inp);
model.run();
sgx::shutdown();
// model.get_output(0).into()
TVMRetValue::from(42)
out.copy(model.get_output(0).unwrap());
TVMRetValue::from(1)
}
pub mod ecalls {
......@@ -40,15 +54,16 @@ pub mod ecalls {
use std::{
ffi::CString,
os::raw::{c_char, c_int},
mem,
os::raw::{c_char, c_int, c_void},
slice,
};
use tvm::{
ffi::runtime::{TVMRetValueHandle, TVMValue},
runtime::{
sgx::{run_worker, SgxStatus},
PackedFunc,
sgx::{ocall_packed_func, run_worker, SgxStatus},
DataType, PackedFunc,
},
};
......@@ -63,8 +78,10 @@ pub mod ecalls {
const ECALLS: &'static [&'static str] = &["__tvm_run_worker__", "__tvm_main__", "init"];
pub type EcallPackedFunc = Box<Fn(&[TVMArgValue<'static>]) -> TVMRetValue + Send + Sync>;
lazy_static! {
static ref ECALL_FUNCS: Vec<PackedFunc> = {
static ref ECALL_FUNCS: Vec<EcallPackedFunc> = {
vec![
Box::new(run_worker),
Box::new(ecall_main),
......@@ -87,7 +104,8 @@ pub mod ecalls {
tvm_ocall!(tvm_ocall_register_export(
CString::new(*ecall).unwrap().as_ptr(),
i as i32
)).expect(&format!("Error registering `{}`", ecall));
))
.expect(&format!("Error registering `{}`", ecall));
});
}
}
......@@ -108,7 +126,7 @@ pub mod ecalls {
.into_iter()
.zip(type_codes.into_iter())
.map(|(v, t)| TVMArgValue::new(*v, *t as i64))
.collect::<Vec<TVMArgValue>>()
.collect::<Vec<TVMArgValue<'static>>>()
};
let (rv, tc) = ECALL_FUNCS[func_id as usize](&args).into_tvm_value();
unsafe {
......
......@@ -8,8 +8,10 @@ CWD = osp.abspath(osp.dirname(__file__))
def main():
ctx = tvm.context('cpu', 0)
model = tvm.module.load(osp.join(CWD, 'build', 'enclave.signed.so'))
out = model()
if out == 42:
inp = tvm.nd.array(np.ones((1, 3, 224, 224), dtype='float32'), ctx)
out = tvm.nd.array(np.empty((1, 1000), dtype='float32'), ctx)
model(inp, out)
if abs(out.asnumpy().sum() - 1) < 0.001:
print('It works!')
else:
print('It doesn\'t work!')
......
......@@ -3,7 +3,7 @@ apt-get update && apt-get install -y --no-install-recommends --force-yes curl
export RUSTUP_HOME=/opt/rust
export CARGO_HOME=/opt/rust
# this rustc is one supported by the installed version of rust-sgx-sdk
curl https://sh.rustup.rs -sSf | sh -s -- -y --no-modify-path --default-toolchain nightly-2018-09-25
curl https://sh.rustup.rs -sSf | sh -s -- -y --no-modify-path --default-toolchain nightly-2018-10-01
. $CARGO_HOME/env
rustup toolchain add nightly
rustup component add rust-src
......
......@@ -126,6 +126,7 @@ pub struct Tensor<'a> {
/// The `Tensor` strides. Can be `None` if the `Tensor` is contiguous.
pub(super) strides: Option<Vec<usize>>,
pub(super) byte_offset: isize,
/// The number of elements in the `Tensor`.
pub(super) size: usize,
}
......@@ -316,12 +317,12 @@ pub struct DataType {
impl DataType {
/// Returns the number of bytes occupied by an element of this `DataType`.
fn itemsize(&self) -> usize {
pub fn itemsize(&self) -> usize {
(self.bits * self.lanes) >> 3
}
/// Returns whether this `DataType` represents primitive type `T`.
fn is_type<T: 'static>(&self) -> bool {
pub fn is_type<T: 'static>(&self) -> bool {
if self.lanes != 1 {
return false;
}
......@@ -345,6 +346,16 @@ impl<'a> From<&'a DataType> for DLDataType {
}
}
impl From<DLDataType> for DataType {
fn from(dtype: DLDataType) -> Self {
Self {
code: dtype.code as usize,
bits: dtype.bits as usize,
lanes: dtype.lanes as usize,
}
}
}
macro_rules! make_dtype_const {
($name: ident, $code: ident, $bits: expr, $lanes: expr) => {
const $name: DataType = DataType {
......@@ -394,6 +405,33 @@ impl Default for TVMContext {
}
}
impl<'a> From<DLTensor> for Tensor<'a> {
fn from(dlt: DLTensor) -> Self {
unsafe {
let dtype = DataType::from(dlt.dtype);
let shape = slice::from_raw_parts(dlt.shape, dlt.ndim as usize).to_vec();
let size = shape.iter().map(|v| *v as usize).product::<usize>() as usize;
let storage = Storage::from(slice::from_raw_parts(
dlt.data as *const u8,
dtype.itemsize() * size,
));
Self {
data: storage,
ctx: TVMContext::default(),
dtype: dtype,
size: size,
shape: shape,
strides: if dlt.strides == ptr::null_mut() {
None
} else {
Some(slice::from_raw_parts_mut(dlt.strides as *mut usize, size).to_vec())
},
byte_offset: dlt.byte_offset as isize,
}
}
}
}
/// `From` conversions to `Tensor` for owned or borrowed `ndarray::Array`.
///
/// # Panics
......
......@@ -14,6 +14,9 @@ use std::os::raw::c_char;
pub use self::{array::*, graph::*, module::*, packed_func::*, threading::*, workspace::*};
#[cfg(target_env = "sgx")]
use self::sgx::ocall_packed_func;
#[no_mangle]
pub extern "C" fn TVMAPISetLastError(cmsg: *const c_char) {
#[cfg(not(target_env = "sgx"))]
......
use std::{any::Any, convert::TryFrom, marker::PhantomData, os::raw::c_void};
use super::Tensor;
use ffi::runtime::{
BackendPackedCFunc, DLDataTypeCode_kDLFloat, DLDataTypeCode_kDLInt, DLTensor,
TVMTypeCode_kArrayHandle, TVMTypeCode_kHandle, TVMValue,
TVMTypeCode_kArrayHandle, TVMTypeCode_kHandle, TVMTypeCode_kNDArrayContainer, TVMValue,
};
use errors::*;
......@@ -55,6 +56,18 @@ macro_rules! impl_prim_tvm_arg {
}
}
}
impl<'a> TryFrom<TVMArgValue<'a>> for $type {
type Error = Error;
fn try_from(val: TVMArgValue<'a>) -> Result<Self> {
ensure!(
val.type_code == $code as i64,
"Could not downcast arg. Expected `{}`, got `{}`",
$code,
val.type_code
);
Ok(unsafe { val.value.$field as $type })
}
}
};
($type:ty, $field:ident, $code:expr) => {
impl_prim_tvm_arg!($type, $field, $code, $type);
......@@ -75,7 +88,6 @@ impl_prim_tvm_arg!(i32, v_int64);
impl_prim_tvm_arg!(u32, v_int64);
impl_prim_tvm_arg!(i64, v_int64);
impl_prim_tvm_arg!(u64, v_int64);
impl_prim_tvm_arg!(bool, v_int64);
/// Creates a conversion to a `TVMArgValue` for an object handle.
impl<'a, T> From<*const T> for TVMArgValue<'a> {
......@@ -127,6 +139,23 @@ impl<'a> From<&'a DLTensor> for TVMArgValue<'a> {
}
}
impl<'a> TryFrom<TVMArgValue<'a>> for Tensor<'a> {
type Error = Error;
fn try_from(val: TVMArgValue<'a>) -> Result<Self> {
ensure!(
val.type_code == TVMTypeCode_kArrayHandle as i64
|| val.type_code == TVMTypeCode_kNDArrayContainer as i64,
"Could not downcast arg. Expected `{}` or `{}`, but got `{}`",
TVMTypeCode_kArrayHandle,
TVMTypeCode_kNDArrayContainer,
val.type_code,
);
let dlt = unsafe { *(val.value.v_handle as *mut DLTensor as *const DLTensor) };
Ok(dlt.into())
}
}
/// An owned TVMPODValue. Can be converted from a variety of primitive and object types.
/// Can be downcasted using `try_from` if it contains the desired type.
///
......@@ -175,7 +204,7 @@ impl TVMRetValue {
2 => TVMValue {
v_float64: self.prim_value.clone() as f64,
},
3 | 7 | 8 | 9 | 10 => TVMValue {
3 | 7 | 8 | 9 | 10 | 13 => TVMValue {
v_handle: Box::into_raw(self.box_value) as *mut c_void,
},
11 | 12 => TVMValue {
......@@ -265,6 +294,33 @@ impl_prim_ret_value!(isize, 0);
impl_prim_ret_value!(usize, 1);
impl_boxed_ret_value!(String, 11);
impl<'a, 't> From<&'t Tensor<'a>> for TVMRetValue {
fn from(val: &'t Tensor<'a>) -> Self {
TVMRetValue {
prim_value: 0,
box_value: box DLTensor::from(val),
type_code: TVMTypeCode_kNDArrayContainer as i64,
}
}
}
impl<'a> TryFrom<TVMRetValue> for Tensor<'a> {
type Error = Error;
fn try_from(ret: TVMRetValue) -> Result<Self> {
ensure!(
ret.type_code == TVMTypeCode_kArrayHandle as i64
|| ret.type_code == TVMTypeCode_kNDArrayContainer as i64,
"Could not downcast arg. Expected `{}` or `{}`, but got `{}`",
TVMTypeCode_kArrayHandle,
TVMTypeCode_kNDArrayContainer,
ret.type_code,
);
let dlt = unsafe { *(ret.prim_value as *mut DLTensor as *const DLTensor) };
Ok(dlt.into())
}
}
// @see `WrapPackedFunc` in `llvm_module.cc`.
pub(super) fn wrap_backend_packed_func(func: BackendPackedCFunc) -> PackedFunc {
box move |args: &[TVMArgValue]| {
......
......@@ -60,11 +60,11 @@ pub fn ocall_packed_func<S: AsRef<str>>(fn_name: S, args: &[TVMArgValue]) -> Res
#[macro_export]
macro_rules! ocall_packed {
($fn_name:expr, $($args:expr),+) => {
::runtime::sgx::ocall_packed_func($fn_name, &[$($args.into(),)+])
ocall_packed_func($fn_name, &[$($args.into(),)+])
.expect(concat!("Error calling `", $fn_name, "`"))
};
($fn_name:expr) => {
::runtime::sgx::ocall_packed_func($fn_name, &Vec::new())
ocall_packed_func($fn_name, &Vec::new())
.expect(concat!("Error calling `", $fn_name, "`"))
}
}
......
......@@ -23,7 +23,7 @@ use super::super::errors::*;
use ffi::runtime::TVMParallelGroupEnv;
#[cfg(target_env = "sgx")]
use super::{TVMArgValue, TVMRetValue};
use super::{sgx::ocall_packed_func, TVMArgValue, TVMRetValue};
type FTVMParallelLambda =
extern "C" fn(task_id: usize, penv: *const TVMParallelGroupEnv, cdata: *const c_void) -> i32;
......
......@@ -21,7 +21,6 @@ enclave {
[out] TVMValue* ret_val,
[out] int* ret_type_code);
void tvm_ocall_register_export([in, string] const char* name, int func_id);
void* tvm_ocall_reserve_space(size_t num_bytes, size_t alignment);
};
};
......@@ -202,21 +202,25 @@ void tvm_ocall_packed_func(const char* name,
// Allocates space for return values. The returned pointer is only valid between
// successive calls to `tvm_ocall_reserve_space`.
void* tvm_ocall_reserve_space(size_t num_bytes, size_t alignment) {
TVM_REGISTER_GLOBAL("__sgx_reserve_space__")
.set_body([](TVMArgs args, TVMRetValue* rv) {
size_t num_bytes = args[0];
size_t alignment = args[1];
static TVMContext ctx = { kDLCPU, 0 };
static thread_local void* buf = nullptr;
static thread_local size_t buf_size = 0;
static thread_local size_t buf_align = 0;
if (buf_size >= num_bytes && buf_align >= alignment) return buf;
if (buf_size >= num_bytes && buf_align >= alignment) *rv = nullptr;
DeviceAPI::Get(ctx)->FreeDataSpace(ctx, buf);
buf = DeviceAPI::Get(ctx)->AllocDataSpace(ctx, num_bytes, alignment, {});
buf_size = num_bytes;
buf_align = alignment;
return buf;
}
*rv = buf;
});
} // extern "C"
} // namespace sgx
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment