Commit dd23bb6f by nhynes Committed by Tianqi Chen

[SGX] Improve edgeroutines (#1775)

parent e986f87e
......@@ -183,6 +183,7 @@ docs.tgz
cat.png
*.mlmodel
tvm_u.*
tvm_t.*
# Mac OS X
.DS_Store
build*
......
......@@ -183,7 +183,9 @@ add_library(tvm SHARED ${COMPILER_SRCS} ${RUNTIME_SRCS})
add_library(tvm_topi SHARED ${TOPI_SRCS})
add_library(tvm_runtime SHARED ${RUNTIME_SRCS})
if(NOT USE_SGX STREQUAL "OFF")
add_dependencies(tvm_runtime sgx_edl)
add_dependencies(tvm sgx_edl)
add_dependencies(tvm_runtime sgx_edl tvm_t)
install(TARGETS tvm_t ARCHIVE DESTINATION lib${LIB_SUFFIX})
endif()
add_library(nnvm_compiler SHARED ${NNVM_COMPILER_SRCS})
......
......@@ -3,6 +3,8 @@ if(NOT USE_SGX STREQUAL "OFF")
set(_sgx_src ${CMAKE_CURRENT_SOURCE_DIR}/src/runtime/sgx)
set(_tvm_u_h ${_sgx_src}/untrusted/tvm_u.h)
set(_tvm_t_h ${_sgx_src}/trusted/tvm_t.h)
set(_tvm_t_c ${_sgx_src}/trusted/tvm_t.c)
set(_tvm_edl ${_sgx_src}/tvm.edl)
set(_sgx_ustdc ${RUST_SGX_SDK}/sgx_ustdc)
......@@ -11,13 +13,16 @@ if(NOT USE_SGX STREQUAL "OFF")
set(_urts_lib "${_urts_lib}_sim")
endif()
# build edge routines
add_custom_command(
OUTPUT ${_tvm_u_h}
COMMAND ${USE_SGX}/bin/x64/sgx_edger8r --untrusted
--untrusted-dir ${_sgx_src}/untrusted
--untrusted --untrusted-dir ${_sgx_src}/untrusted
--trusted --trusted-dir ${_sgx_src}/trusted
--search-path ${USE_SGX}/include --search-path ${RUST_SGX_SDK}/edl
${_tvm_edl}
COMMAND sed -i "4i '#include <tvm/runtime/c_runtime_api.h>'" ${_tvm_u_h}
COMMAND sed -i "4i '#include <tvm/runtime/c_runtime_api.h>'" ${_tvm_t_h}
DEPENDS ${_tvm_edl}
)
add_custom_command(
......@@ -27,6 +32,13 @@ if(NOT USE_SGX STREQUAL "OFF")
)
add_custom_target(sgx_edl DEPENDS ${_tvm_u_h} ${_sgx_ustdc}/libsgx_ustdc.a)
# build trusted library
set_source_files_properties(${_tvm_t_c} PROPERTIES GENERATED TRUE)
add_library(tvm_t STATIC ${_tvm_t_c})
add_dependencies(tvm_t sgx_edl)
target_include_directories(tvm_t PUBLIC ${USE_SGX}/include ${USE_SGX}/include/tlibc)
# add untrusted runtime files
include_directories(${USE_SGX}/include)
file(GLOB RUNTIME_SGX_SRCS ${_sgx_src}/untrusted/*.c*)
list(APPEND TVM_RUNTIME_LINKER_LIBS
......
......@@ -9,7 +9,8 @@ enclave {
[in, count=num_args] const TVMValue* arg_values,
[in, count=num_args] const int* type_codes,
int num_args,
[isptr, user_check] TVMRetValueHandle ret);
[out] TVMValue* ret_val,
[out] int* ret_type_code);
};
untrusted {
......@@ -19,10 +20,6 @@ enclave {
int num_args,
[out] TVMValue* ret_val,
[out] int* ret_type_code);
void tvm_ocall_set_return([isptr, user_check] TVMRetValueHandle ret,
[in, count=num_ret] const TVMValue* value,
[in, count=num_ret] const int* type_code,
int num_ret);
void tvm_ocall_register_export([in, string] const char* name, int func_id);
void* tvm_ocall_reserve_space(size_t num_bytes, size_t alignment);
};
......
......@@ -110,15 +110,18 @@ class SGXModuleNode : public ModuleNode {
int func_id = exported->second;
return PackedFunc([this, func_id](TVMArgs args, TVMRetValue* rv) {
sgx::EnclaveContext ctx(this);
TVMValue ret_value;
int ret_type_code;
TVM_SGX_CHECKED_CALL(tvm_ecall_packed_func(eid_, func_id,
args.values, args.type_codes, args.num_args, rv));
args.values, args.type_codes, args.num_args, &ret_value, &ret_type_code));
*rv = TVMArgValue(ret_value, ret_type_code);
});
}
void RunWorkers(int num_tasks, void* tg) {
std::function<void(int)> runner = [this, tg](int _worker_id) {
void RunWorkers(int num_tasks) {
std::function<void(int)> runner = [this](int _worker_id) {
this->GetFunction("__tvm_run_worker__",
std::shared_ptr<SGXModuleNode>(nullptr))(tg);
std::shared_ptr<SGXModuleNode>(nullptr))();
};
thread_group_.reset(new tvm::runtime::threading::ThreadGroup(
num_tasks, runner, false /* include_main_thread */));
......@@ -144,7 +147,7 @@ namespace sgx {
TVM_REGISTER_GLOBAL("__sgx_thread_group_launch__")
.set_body([](TVMArgs args, TVMRetValue* rv) {
EnclaveContext::GetModule()->RunWorkers(args[0], args[1]);
EnclaveContext::GetModule()->RunWorkers(args[0]);
});
TVM_REGISTER_GLOBAL("__sgx_thread_group_join__")
......@@ -215,16 +218,6 @@ void* tvm_ocall_reserve_space(size_t num_bytes, size_t alignment) {
return buf;
}
void tvm_ocall_set_return(TVMRetValueHandle ret,
const TVMValue* value,
const int* type_code,
int num_ret) {
CHECK_EQ(num_ret, 1) << "Only one return value is currently supported.";
CHECK(type_code[0] != kStr) << "Return kBytes, not kStr.";
TVMRetValue* rv = static_cast<TVMRetValue*>(ret);
*rv = TVMArgValue(value[0], type_code[0]);
}
} // extern "C"
} // namespace sgx
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment