Commit f2b30f9e by Nick Hynes Committed by Tianqi Chen

Update SGX example (#1933)

parent f9281241
...@@ -206,3 +206,6 @@ tvm_t.* ...@@ -206,3 +206,6 @@ tvm_t.*
*.cer *.cer
*.crt *.crt
*.der *.der
# patch sentinel
patched.txt
<EnclaveConfiguration> <EnclaveConfiguration>
<ProdID>0</ProdID> <ProdID>0</ProdID>
<ISVSVN>0</ISVSVN> <ISVSVN>0</ISVSVN>
<StackMaxSize>0x20000</StackMaxSize> <StackMaxSize>0xf0000</StackMaxSize>
<HeapMaxSize>0x5000000</HeapMaxSize> <HeapMaxSize>0xf000000</HeapMaxSize>
<TCSNum>NUM_THREADS</TCSNum> <TCSNum>NUM_THREADS</TCSNum>
<TCSPolicy>0</TCSPolicy> <!-- must be "bound" to use thread_local --> <TCSPolicy>0</TCSPolicy> <!-- must be "bound" to use thread_local -->
<DisableDebug>0</DisableDebug> <DisableDebug>0</DisableDebug>
......
...@@ -2,20 +2,32 @@ ...@@ -2,20 +2,32 @@
#[macro_use] #[macro_use]
extern crate lazy_static; extern crate lazy_static;
#[macro_use]
extern crate tvm; extern crate tvm;
use std::{convert::TryFrom, sync::Mutex}; use std::{
convert::{TryFrom, TryInto},
sync::Mutex,
};
use tvm::runtime::{sgx, Graph, GraphExecutor, SystemLibModule, TVMArgValue, TVMRetValue}; use tvm::{
ffi::runtime::DLTensor,
runtime::{
load_param_dict, sgx, Graph, GraphExecutor, SystemLibModule, TVMArgValue, TVMRetValue, Tensor,
},
};
lazy_static! { lazy_static! {
static ref SYSLIB: SystemLibModule = { SystemLibModule::default() }; static ref SYSLIB: SystemLibModule = { SystemLibModule::default() };
static ref MODEL: Mutex<GraphExecutor<'static, 'static>> = { static ref MODEL: Mutex<GraphExecutor<'static, 'static>> = {
let _params = include_bytes!(concat!("../", env!("BUILD_DIR"), "/params.bin"));
let graph_json = include_str!(concat!("../", env!("BUILD_DIR"), "/graph.json")); let graph_json = include_str!(concat!("../", env!("BUILD_DIR"), "/graph.json"));
let params_bytes = include_bytes!(concat!("../", env!("BUILD_DIR"), "/params.bin"));
let params = load_param_dict(params_bytes).unwrap();
let graph = Graph::try_from(graph_json).unwrap(); let graph = Graph::try_from(graph_json).unwrap();
Mutex::new(GraphExecutor::new(graph, &*SYSLIB).unwrap()) let mut exec = GraphExecutor::new(graph, &*SYSLIB).unwrap();
exec.load_params(params);
Mutex::new(exec)
}; };
} }
...@@ -24,13 +36,15 @@ fn ecall_init(_args: &[TVMArgValue]) -> TVMRetValue { ...@@ -24,13 +36,15 @@ fn ecall_init(_args: &[TVMArgValue]) -> TVMRetValue {
TVMRetValue::from(0) TVMRetValue::from(0)
} }
fn ecall_main(_args: &[TVMArgValue]) -> TVMRetValue { fn ecall_main(args: &[TVMArgValue<'static>]) -> TVMRetValue {
let model = MODEL.lock().unwrap(); let mut model = MODEL.lock().unwrap();
// model.set_input("data", args[0]); let inp = args[0].try_into().unwrap();
let mut out: Tensor = args[1].try_into().unwrap();
model.set_input("data", inp);
model.run(); model.run();
sgx::shutdown(); sgx::shutdown();
// model.get_output(0).into() out.copy(model.get_output(0).unwrap());
TVMRetValue::from(42) TVMRetValue::from(1)
} }
pub mod ecalls { pub mod ecalls {
...@@ -40,15 +54,16 @@ pub mod ecalls { ...@@ -40,15 +54,16 @@ pub mod ecalls {
use std::{ use std::{
ffi::CString, ffi::CString,
os::raw::{c_char, c_int}, mem,
os::raw::{c_char, c_int, c_void},
slice, slice,
}; };
use tvm::{ use tvm::{
ffi::runtime::{TVMRetValueHandle, TVMValue}, ffi::runtime::{TVMRetValueHandle, TVMValue},
runtime::{ runtime::{
sgx::{run_worker, SgxStatus}, sgx::{ocall_packed_func, run_worker, SgxStatus},
PackedFunc, DataType, PackedFunc,
}, },
}; };
...@@ -63,8 +78,10 @@ pub mod ecalls { ...@@ -63,8 +78,10 @@ pub mod ecalls {
const ECALLS: &'static [&'static str] = &["__tvm_run_worker__", "__tvm_main__", "init"]; const ECALLS: &'static [&'static str] = &["__tvm_run_worker__", "__tvm_main__", "init"];
pub type EcallPackedFunc = Box<Fn(&[TVMArgValue<'static>]) -> TVMRetValue + Send + Sync>;
lazy_static! { lazy_static! {
static ref ECALL_FUNCS: Vec<PackedFunc> = { static ref ECALL_FUNCS: Vec<EcallPackedFunc> = {
vec![ vec![
Box::new(run_worker), Box::new(run_worker),
Box::new(ecall_main), Box::new(ecall_main),
...@@ -87,7 +104,8 @@ pub mod ecalls { ...@@ -87,7 +104,8 @@ pub mod ecalls {
tvm_ocall!(tvm_ocall_register_export( tvm_ocall!(tvm_ocall_register_export(
CString::new(*ecall).unwrap().as_ptr(), CString::new(*ecall).unwrap().as_ptr(),
i as i32 i as i32
)).expect(&format!("Error registering `{}`", ecall)); ))
.expect(&format!("Error registering `{}`", ecall));
}); });
} }
} }
...@@ -108,7 +126,7 @@ pub mod ecalls { ...@@ -108,7 +126,7 @@ pub mod ecalls {
.into_iter() .into_iter()
.zip(type_codes.into_iter()) .zip(type_codes.into_iter())
.map(|(v, t)| TVMArgValue::new(*v, *t as i64)) .map(|(v, t)| TVMArgValue::new(*v, *t as i64))
.collect::<Vec<TVMArgValue>>() .collect::<Vec<TVMArgValue<'static>>>()
}; };
let (rv, tc) = ECALL_FUNCS[func_id as usize](&args).into_tvm_value(); let (rv, tc) = ECALL_FUNCS[func_id as usize](&args).into_tvm_value();
unsafe { unsafe {
......
...@@ -8,8 +8,10 @@ CWD = osp.abspath(osp.dirname(__file__)) ...@@ -8,8 +8,10 @@ CWD = osp.abspath(osp.dirname(__file__))
def main(): def main():
ctx = tvm.context('cpu', 0) ctx = tvm.context('cpu', 0)
model = tvm.module.load(osp.join(CWD, 'build', 'enclave.signed.so')) model = tvm.module.load(osp.join(CWD, 'build', 'enclave.signed.so'))
out = model() inp = tvm.nd.array(np.ones((1, 3, 224, 224), dtype='float32'), ctx)
if out == 42: out = tvm.nd.array(np.empty((1, 1000), dtype='float32'), ctx)
model(inp, out)
if abs(out.asnumpy().sum() - 1) < 0.001:
print('It works!') print('It works!')
else: else:
print('It doesn\'t work!') print('It doesn\'t work!')
......
...@@ -3,7 +3,7 @@ apt-get update && apt-get install -y --no-install-recommends --force-yes curl ...@@ -3,7 +3,7 @@ apt-get update && apt-get install -y --no-install-recommends --force-yes curl
export RUSTUP_HOME=/opt/rust export RUSTUP_HOME=/opt/rust
export CARGO_HOME=/opt/rust export CARGO_HOME=/opt/rust
# this rustc is one supported by the installed version of rust-sgx-sdk # this rustc is one supported by the installed version of rust-sgx-sdk
curl https://sh.rustup.rs -sSf | sh -s -- -y --no-modify-path --default-toolchain nightly-2018-09-25 curl https://sh.rustup.rs -sSf | sh -s -- -y --no-modify-path --default-toolchain nightly-2018-10-01
. $CARGO_HOME/env . $CARGO_HOME/env
rustup toolchain add nightly rustup toolchain add nightly
rustup component add rust-src rustup component add rust-src
......
...@@ -126,6 +126,7 @@ pub struct Tensor<'a> { ...@@ -126,6 +126,7 @@ pub struct Tensor<'a> {
/// The `Tensor` strides. Can be `None` if the `Tensor` is contiguous. /// The `Tensor` strides. Can be `None` if the `Tensor` is contiguous.
pub(super) strides: Option<Vec<usize>>, pub(super) strides: Option<Vec<usize>>,
pub(super) byte_offset: isize, pub(super) byte_offset: isize,
/// The number of elements in the `Tensor`.
pub(super) size: usize, pub(super) size: usize,
} }
...@@ -316,12 +317,12 @@ pub struct DataType { ...@@ -316,12 +317,12 @@ pub struct DataType {
impl DataType { impl DataType {
/// Returns the number of bytes occupied by an element of this `DataType`. /// Returns the number of bytes occupied by an element of this `DataType`.
fn itemsize(&self) -> usize { pub fn itemsize(&self) -> usize {
(self.bits * self.lanes) >> 3 (self.bits * self.lanes) >> 3
} }
/// Returns whether this `DataType` represents primitive type `T`. /// Returns whether this `DataType` represents primitive type `T`.
fn is_type<T: 'static>(&self) -> bool { pub fn is_type<T: 'static>(&self) -> bool {
if self.lanes != 1 { if self.lanes != 1 {
return false; return false;
} }
...@@ -345,6 +346,16 @@ impl<'a> From<&'a DataType> for DLDataType { ...@@ -345,6 +346,16 @@ impl<'a> From<&'a DataType> for DLDataType {
} }
} }
impl From<DLDataType> for DataType {
fn from(dtype: DLDataType) -> Self {
Self {
code: dtype.code as usize,
bits: dtype.bits as usize,
lanes: dtype.lanes as usize,
}
}
}
macro_rules! make_dtype_const { macro_rules! make_dtype_const {
($name: ident, $code: ident, $bits: expr, $lanes: expr) => { ($name: ident, $code: ident, $bits: expr, $lanes: expr) => {
const $name: DataType = DataType { const $name: DataType = DataType {
...@@ -394,6 +405,33 @@ impl Default for TVMContext { ...@@ -394,6 +405,33 @@ impl Default for TVMContext {
} }
} }
impl<'a> From<DLTensor> for Tensor<'a> {
fn from(dlt: DLTensor) -> Self {
unsafe {
let dtype = DataType::from(dlt.dtype);
let shape = slice::from_raw_parts(dlt.shape, dlt.ndim as usize).to_vec();
let size = shape.iter().map(|v| *v as usize).product::<usize>() as usize;
let storage = Storage::from(slice::from_raw_parts(
dlt.data as *const u8,
dtype.itemsize() * size,
));
Self {
data: storage,
ctx: TVMContext::default(),
dtype: dtype,
size: size,
shape: shape,
strides: if dlt.strides == ptr::null_mut() {
None
} else {
Some(slice::from_raw_parts_mut(dlt.strides as *mut usize, size).to_vec())
},
byte_offset: dlt.byte_offset as isize,
}
}
}
}
/// `From` conversions to `Tensor` for owned or borrowed `ndarray::Array`. /// `From` conversions to `Tensor` for owned or borrowed `ndarray::Array`.
/// ///
/// # Panics /// # Panics
......
...@@ -14,6 +14,9 @@ use std::os::raw::c_char; ...@@ -14,6 +14,9 @@ use std::os::raw::c_char;
pub use self::{array::*, graph::*, module::*, packed_func::*, threading::*, workspace::*}; pub use self::{array::*, graph::*, module::*, packed_func::*, threading::*, workspace::*};
#[cfg(target_env = "sgx")]
use self::sgx::ocall_packed_func;
#[no_mangle] #[no_mangle]
pub extern "C" fn TVMAPISetLastError(cmsg: *const c_char) { pub extern "C" fn TVMAPISetLastError(cmsg: *const c_char) {
#[cfg(not(target_env = "sgx"))] #[cfg(not(target_env = "sgx"))]
......
use std::{any::Any, convert::TryFrom, marker::PhantomData, os::raw::c_void}; use std::{any::Any, convert::TryFrom, marker::PhantomData, os::raw::c_void};
use super::Tensor;
use ffi::runtime::{ use ffi::runtime::{
BackendPackedCFunc, DLDataTypeCode_kDLFloat, DLDataTypeCode_kDLInt, DLTensor, BackendPackedCFunc, DLDataTypeCode_kDLFloat, DLDataTypeCode_kDLInt, DLTensor,
TVMTypeCode_kArrayHandle, TVMTypeCode_kHandle, TVMValue, TVMTypeCode_kArrayHandle, TVMTypeCode_kHandle, TVMTypeCode_kNDArrayContainer, TVMValue,
}; };
use errors::*; use errors::*;
...@@ -55,6 +56,18 @@ macro_rules! impl_prim_tvm_arg { ...@@ -55,6 +56,18 @@ macro_rules! impl_prim_tvm_arg {
} }
} }
} }
impl<'a> TryFrom<TVMArgValue<'a>> for $type {
type Error = Error;
fn try_from(val: TVMArgValue<'a>) -> Result<Self> {
ensure!(
val.type_code == $code as i64,
"Could not downcast arg. Expected `{}`, got `{}`",
$code,
val.type_code
);
Ok(unsafe { val.value.$field as $type })
}
}
}; };
($type:ty, $field:ident, $code:expr) => { ($type:ty, $field:ident, $code:expr) => {
impl_prim_tvm_arg!($type, $field, $code, $type); impl_prim_tvm_arg!($type, $field, $code, $type);
...@@ -75,7 +88,6 @@ impl_prim_tvm_arg!(i32, v_int64); ...@@ -75,7 +88,6 @@ impl_prim_tvm_arg!(i32, v_int64);
impl_prim_tvm_arg!(u32, v_int64); impl_prim_tvm_arg!(u32, v_int64);
impl_prim_tvm_arg!(i64, v_int64); impl_prim_tvm_arg!(i64, v_int64);
impl_prim_tvm_arg!(u64, v_int64); impl_prim_tvm_arg!(u64, v_int64);
impl_prim_tvm_arg!(bool, v_int64);
/// Creates a conversion to a `TVMArgValue` for an object handle. /// Creates a conversion to a `TVMArgValue` for an object handle.
impl<'a, T> From<*const T> for TVMArgValue<'a> { impl<'a, T> From<*const T> for TVMArgValue<'a> {
...@@ -127,6 +139,23 @@ impl<'a> From<&'a DLTensor> for TVMArgValue<'a> { ...@@ -127,6 +139,23 @@ impl<'a> From<&'a DLTensor> for TVMArgValue<'a> {
} }
} }
impl<'a> TryFrom<TVMArgValue<'a>> for Tensor<'a> {
type Error = Error;
fn try_from(val: TVMArgValue<'a>) -> Result<Self> {
ensure!(
val.type_code == TVMTypeCode_kArrayHandle as i64
|| val.type_code == TVMTypeCode_kNDArrayContainer as i64,
"Could not downcast arg. Expected `{}` or `{}`, but got `{}`",
TVMTypeCode_kArrayHandle,
TVMTypeCode_kNDArrayContainer,
val.type_code,
);
let dlt = unsafe { *(val.value.v_handle as *mut DLTensor as *const DLTensor) };
Ok(dlt.into())
}
}
/// An owned TVMPODValue. Can be converted from a variety of primitive and object types. /// An owned TVMPODValue. Can be converted from a variety of primitive and object types.
/// Can be downcasted using `try_from` if it contains the desired type. /// Can be downcasted using `try_from` if it contains the desired type.
/// ///
...@@ -175,7 +204,7 @@ impl TVMRetValue { ...@@ -175,7 +204,7 @@ impl TVMRetValue {
2 => TVMValue { 2 => TVMValue {
v_float64: self.prim_value.clone() as f64, v_float64: self.prim_value.clone() as f64,
}, },
3 | 7 | 8 | 9 | 10 => TVMValue { 3 | 7 | 8 | 9 | 10 | 13 => TVMValue {
v_handle: Box::into_raw(self.box_value) as *mut c_void, v_handle: Box::into_raw(self.box_value) as *mut c_void,
}, },
11 | 12 => TVMValue { 11 | 12 => TVMValue {
...@@ -265,6 +294,33 @@ impl_prim_ret_value!(isize, 0); ...@@ -265,6 +294,33 @@ impl_prim_ret_value!(isize, 0);
impl_prim_ret_value!(usize, 1); impl_prim_ret_value!(usize, 1);
impl_boxed_ret_value!(String, 11); impl_boxed_ret_value!(String, 11);
impl<'a, 't> From<&'t Tensor<'a>> for TVMRetValue {
fn from(val: &'t Tensor<'a>) -> Self {
TVMRetValue {
prim_value: 0,
box_value: box DLTensor::from(val),
type_code: TVMTypeCode_kNDArrayContainer as i64,
}
}
}
impl<'a> TryFrom<TVMRetValue> for Tensor<'a> {
type Error = Error;
fn try_from(ret: TVMRetValue) -> Result<Self> {
ensure!(
ret.type_code == TVMTypeCode_kArrayHandle as i64
|| ret.type_code == TVMTypeCode_kNDArrayContainer as i64,
"Could not downcast arg. Expected `{}` or `{}`, but got `{}`",
TVMTypeCode_kArrayHandle,
TVMTypeCode_kNDArrayContainer,
ret.type_code,
);
let dlt = unsafe { *(ret.prim_value as *mut DLTensor as *const DLTensor) };
Ok(dlt.into())
}
}
// @see `WrapPackedFunc` in `llvm_module.cc`. // @see `WrapPackedFunc` in `llvm_module.cc`.
pub(super) fn wrap_backend_packed_func(func: BackendPackedCFunc) -> PackedFunc { pub(super) fn wrap_backend_packed_func(func: BackendPackedCFunc) -> PackedFunc {
box move |args: &[TVMArgValue]| { box move |args: &[TVMArgValue]| {
......
...@@ -60,11 +60,11 @@ pub fn ocall_packed_func<S: AsRef<str>>(fn_name: S, args: &[TVMArgValue]) -> Res ...@@ -60,11 +60,11 @@ pub fn ocall_packed_func<S: AsRef<str>>(fn_name: S, args: &[TVMArgValue]) -> Res
#[macro_export] #[macro_export]
macro_rules! ocall_packed { macro_rules! ocall_packed {
($fn_name:expr, $($args:expr),+) => { ($fn_name:expr, $($args:expr),+) => {
::runtime::sgx::ocall_packed_func($fn_name, &[$($args.into(),)+]) ocall_packed_func($fn_name, &[$($args.into(),)+])
.expect(concat!("Error calling `", $fn_name, "`")) .expect(concat!("Error calling `", $fn_name, "`"))
}; };
($fn_name:expr) => { ($fn_name:expr) => {
::runtime::sgx::ocall_packed_func($fn_name, &Vec::new()) ocall_packed_func($fn_name, &Vec::new())
.expect(concat!("Error calling `", $fn_name, "`")) .expect(concat!("Error calling `", $fn_name, "`"))
} }
} }
......
...@@ -23,7 +23,7 @@ use super::super::errors::*; ...@@ -23,7 +23,7 @@ use super::super::errors::*;
use ffi::runtime::TVMParallelGroupEnv; use ffi::runtime::TVMParallelGroupEnv;
#[cfg(target_env = "sgx")] #[cfg(target_env = "sgx")]
use super::{TVMArgValue, TVMRetValue}; use super::{sgx::ocall_packed_func, TVMArgValue, TVMRetValue};
type FTVMParallelLambda = type FTVMParallelLambda =
extern "C" fn(task_id: usize, penv: *const TVMParallelGroupEnv, cdata: *const c_void) -> i32; extern "C" fn(task_id: usize, penv: *const TVMParallelGroupEnv, cdata: *const c_void) -> i32;
......
...@@ -21,7 +21,6 @@ enclave { ...@@ -21,7 +21,6 @@ enclave {
[out] TVMValue* ret_val, [out] TVMValue* ret_val,
[out] int* ret_type_code); [out] int* ret_type_code);
void tvm_ocall_register_export([in, string] const char* name, int func_id); void tvm_ocall_register_export([in, string] const char* name, int func_id);
void* tvm_ocall_reserve_space(size_t num_bytes, size_t alignment);
}; };
}; };
...@@ -202,21 +202,25 @@ void tvm_ocall_packed_func(const char* name, ...@@ -202,21 +202,25 @@ void tvm_ocall_packed_func(const char* name,
// Allocates space for return values. The returned pointer is only valid between // Allocates space for return values. The returned pointer is only valid between
// successive calls to `tvm_ocall_reserve_space`. // successive calls to `tvm_ocall_reserve_space`.
void* tvm_ocall_reserve_space(size_t num_bytes, size_t alignment) { TVM_REGISTER_GLOBAL("__sgx_reserve_space__")
.set_body([](TVMArgs args, TVMRetValue* rv) {
size_t num_bytes = args[0];
size_t alignment = args[1];
static TVMContext ctx = { kDLCPU, 0 }; static TVMContext ctx = { kDLCPU, 0 };
static thread_local void* buf = nullptr; static thread_local void* buf = nullptr;
static thread_local size_t buf_size = 0; static thread_local size_t buf_size = 0;
static thread_local size_t buf_align = 0; static thread_local size_t buf_align = 0;
if (buf_size >= num_bytes && buf_align >= alignment) return buf; if (buf_size >= num_bytes && buf_align >= alignment) *rv = nullptr;
DeviceAPI::Get(ctx)->FreeDataSpace(ctx, buf); DeviceAPI::Get(ctx)->FreeDataSpace(ctx, buf);
buf = DeviceAPI::Get(ctx)->AllocDataSpace(ctx, num_bytes, alignment, {}); buf = DeviceAPI::Get(ctx)->AllocDataSpace(ctx, num_bytes, alignment, {});
buf_size = num_bytes; buf_size = num_bytes;
buf_align = alignment; buf_align = alignment;
return buf; *rv = buf;
} });
} // extern "C" } // extern "C"
} // namespace sgx } // namespace sgx
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment