Commit dd23bb6f by nhynes Committed by Tianqi Chen

[SGX] Improve edgeroutines (#1775)

parent e986f87e
...@@ -183,6 +183,7 @@ docs.tgz ...@@ -183,6 +183,7 @@ docs.tgz
cat.png cat.png
*.mlmodel *.mlmodel
tvm_u.* tvm_u.*
tvm_t.*
# Mac OS X # Mac OS X
.DS_Store .DS_Store
build* build*
......
...@@ -183,7 +183,9 @@ add_library(tvm SHARED ${COMPILER_SRCS} ${RUNTIME_SRCS}) ...@@ -183,7 +183,9 @@ add_library(tvm SHARED ${COMPILER_SRCS} ${RUNTIME_SRCS})
add_library(tvm_topi SHARED ${TOPI_SRCS}) add_library(tvm_topi SHARED ${TOPI_SRCS})
add_library(tvm_runtime SHARED ${RUNTIME_SRCS}) add_library(tvm_runtime SHARED ${RUNTIME_SRCS})
if(NOT USE_SGX STREQUAL "OFF") if(NOT USE_SGX STREQUAL "OFF")
add_dependencies(tvm_runtime sgx_edl) add_dependencies(tvm sgx_edl)
add_dependencies(tvm_runtime sgx_edl tvm_t)
install(TARGETS tvm_t ARCHIVE DESTINATION lib${LIB_SUFFIX})
endif() endif()
add_library(nnvm_compiler SHARED ${NNVM_COMPILER_SRCS}) add_library(nnvm_compiler SHARED ${NNVM_COMPILER_SRCS})
......
...@@ -3,6 +3,8 @@ if(NOT USE_SGX STREQUAL "OFF") ...@@ -3,6 +3,8 @@ if(NOT USE_SGX STREQUAL "OFF")
set(_sgx_src ${CMAKE_CURRENT_SOURCE_DIR}/src/runtime/sgx) set(_sgx_src ${CMAKE_CURRENT_SOURCE_DIR}/src/runtime/sgx)
set(_tvm_u_h ${_sgx_src}/untrusted/tvm_u.h) set(_tvm_u_h ${_sgx_src}/untrusted/tvm_u.h)
set(_tvm_t_h ${_sgx_src}/trusted/tvm_t.h)
set(_tvm_t_c ${_sgx_src}/trusted/tvm_t.c)
set(_tvm_edl ${_sgx_src}/tvm.edl) set(_tvm_edl ${_sgx_src}/tvm.edl)
set(_sgx_ustdc ${RUST_SGX_SDK}/sgx_ustdc) set(_sgx_ustdc ${RUST_SGX_SDK}/sgx_ustdc)
...@@ -11,13 +13,16 @@ if(NOT USE_SGX STREQUAL "OFF") ...@@ -11,13 +13,16 @@ if(NOT USE_SGX STREQUAL "OFF")
set(_urts_lib "${_urts_lib}_sim") set(_urts_lib "${_urts_lib}_sim")
endif() endif()
# build edge routines
add_custom_command( add_custom_command(
OUTPUT ${_tvm_u_h} OUTPUT ${_tvm_u_h}
COMMAND ${USE_SGX}/bin/x64/sgx_edger8r --untrusted COMMAND ${USE_SGX}/bin/x64/sgx_edger8r --untrusted
--untrusted-dir ${_sgx_src}/untrusted --untrusted --untrusted-dir ${_sgx_src}/untrusted
--trusted --trusted-dir ${_sgx_src}/trusted
--search-path ${USE_SGX}/include --search-path ${RUST_SGX_SDK}/edl --search-path ${USE_SGX}/include --search-path ${RUST_SGX_SDK}/edl
${_tvm_edl} ${_tvm_edl}
COMMAND sed -i "4i '#include <tvm/runtime/c_runtime_api.h>'" ${_tvm_u_h} COMMAND sed -i "4i '#include <tvm/runtime/c_runtime_api.h>'" ${_tvm_u_h}
COMMAND sed -i "4i '#include <tvm/runtime/c_runtime_api.h>'" ${_tvm_t_h}
DEPENDS ${_tvm_edl} DEPENDS ${_tvm_edl}
) )
add_custom_command( add_custom_command(
...@@ -27,6 +32,13 @@ if(NOT USE_SGX STREQUAL "OFF") ...@@ -27,6 +32,13 @@ if(NOT USE_SGX STREQUAL "OFF")
) )
add_custom_target(sgx_edl DEPENDS ${_tvm_u_h} ${_sgx_ustdc}/libsgx_ustdc.a) add_custom_target(sgx_edl DEPENDS ${_tvm_u_h} ${_sgx_ustdc}/libsgx_ustdc.a)
# build trusted library
set_source_files_properties(${_tvm_t_c} PROPERTIES GENERATED TRUE)
add_library(tvm_t STATIC ${_tvm_t_c})
add_dependencies(tvm_t sgx_edl)
target_include_directories(tvm_t PUBLIC ${USE_SGX}/include ${USE_SGX}/include/tlibc)
# add untrusted runtime files
include_directories(${USE_SGX}/include) include_directories(${USE_SGX}/include)
file(GLOB RUNTIME_SGX_SRCS ${_sgx_src}/untrusted/*.c*) file(GLOB RUNTIME_SGX_SRCS ${_sgx_src}/untrusted/*.c*)
list(APPEND TVM_RUNTIME_LINKER_LIBS list(APPEND TVM_RUNTIME_LINKER_LIBS
......
...@@ -9,7 +9,8 @@ enclave { ...@@ -9,7 +9,8 @@ enclave {
[in, count=num_args] const TVMValue* arg_values, [in, count=num_args] const TVMValue* arg_values,
[in, count=num_args] const int* type_codes, [in, count=num_args] const int* type_codes,
int num_args, int num_args,
[isptr, user_check] TVMRetValueHandle ret); [out] TVMValue* ret_val,
[out] int* ret_type_code);
}; };
untrusted { untrusted {
...@@ -19,10 +20,6 @@ enclave { ...@@ -19,10 +20,6 @@ enclave {
int num_args, int num_args,
[out] TVMValue* ret_val, [out] TVMValue* ret_val,
[out] int* ret_type_code); [out] int* ret_type_code);
void tvm_ocall_set_return([isptr, user_check] TVMRetValueHandle ret,
[in, count=num_ret] const TVMValue* value,
[in, count=num_ret] const int* type_code,
int num_ret);
void tvm_ocall_register_export([in, string] const char* name, int func_id); void tvm_ocall_register_export([in, string] const char* name, int func_id);
void* tvm_ocall_reserve_space(size_t num_bytes, size_t alignment); void* tvm_ocall_reserve_space(size_t num_bytes, size_t alignment);
}; };
......
...@@ -110,15 +110,18 @@ class SGXModuleNode : public ModuleNode { ...@@ -110,15 +110,18 @@ class SGXModuleNode : public ModuleNode {
int func_id = exported->second; int func_id = exported->second;
return PackedFunc([this, func_id](TVMArgs args, TVMRetValue* rv) { return PackedFunc([this, func_id](TVMArgs args, TVMRetValue* rv) {
sgx::EnclaveContext ctx(this); sgx::EnclaveContext ctx(this);
TVMValue ret_value;
int ret_type_code;
TVM_SGX_CHECKED_CALL(tvm_ecall_packed_func(eid_, func_id, TVM_SGX_CHECKED_CALL(tvm_ecall_packed_func(eid_, func_id,
args.values, args.type_codes, args.num_args, rv)); args.values, args.type_codes, args.num_args, &ret_value, &ret_type_code));
*rv = TVMArgValue(ret_value, ret_type_code);
}); });
} }
void RunWorkers(int num_tasks, void* tg) { void RunWorkers(int num_tasks) {
std::function<void(int)> runner = [this, tg](int _worker_id) { std::function<void(int)> runner = [this](int _worker_id) {
this->GetFunction("__tvm_run_worker__", this->GetFunction("__tvm_run_worker__",
std::shared_ptr<SGXModuleNode>(nullptr))(tg); std::shared_ptr<SGXModuleNode>(nullptr))();
}; };
thread_group_.reset(new tvm::runtime::threading::ThreadGroup( thread_group_.reset(new tvm::runtime::threading::ThreadGroup(
num_tasks, runner, false /* include_main_thread */)); num_tasks, runner, false /* include_main_thread */));
...@@ -144,7 +147,7 @@ namespace sgx { ...@@ -144,7 +147,7 @@ namespace sgx {
TVM_REGISTER_GLOBAL("__sgx_thread_group_launch__") TVM_REGISTER_GLOBAL("__sgx_thread_group_launch__")
.set_body([](TVMArgs args, TVMRetValue* rv) { .set_body([](TVMArgs args, TVMRetValue* rv) {
EnclaveContext::GetModule()->RunWorkers(args[0], args[1]); EnclaveContext::GetModule()->RunWorkers(args[0]);
}); });
TVM_REGISTER_GLOBAL("__sgx_thread_group_join__") TVM_REGISTER_GLOBAL("__sgx_thread_group_join__")
...@@ -215,16 +218,6 @@ void* tvm_ocall_reserve_space(size_t num_bytes, size_t alignment) { ...@@ -215,16 +218,6 @@ void* tvm_ocall_reserve_space(size_t num_bytes, size_t alignment) {
return buf; return buf;
} }
void tvm_ocall_set_return(TVMRetValueHandle ret,
const TVMValue* value,
const int* type_code,
int num_ret) {
CHECK_EQ(num_ret, 1) << "Only one return value is currently supported.";
CHECK(type_code[0] != kStr) << "Return kBytes, not kStr.";
TVMRetValue* rv = static_cast<TVMRetValue*>(ret);
*rv = TVMArgValue(value[0], type_code[0]);
}
} // extern "C" } // extern "C"
} // namespace sgx } // namespace sgx
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment