Commit 21f71f9e by nhynes Committed by Tianqi Chen

Simplify enclave lifecycle management (#1013)

parent 1587038e
...@@ -119,8 +119,7 @@ int SGX_CDECL main(int argc, char *argv[]) { ...@@ -119,8 +119,7 @@ int SGX_CDECL main(int argc, char *argv[]) {
if (sgx_status != SGX_SUCCESS) { if (sgx_status != SGX_SUCCESS) {
print_error_message(sgx_status); print_error_message(sgx_status);
} }
tvm_ecall_shutdown(tvm_sgx_eid);
tvm::runtime::sgx::Shutdown();
sgx_destroy_enclave(tvm_sgx_eid); sgx_destroy_enclave(tvm_sgx_eid);
if (addone_status == 1) { if (addone_status == 1) {
......
...@@ -6,8 +6,6 @@ ...@@ -6,8 +6,6 @@
#include <iostream> #include <iostream>
#endif #endif
extern void Shutdown();
/* This function mirrors the one in howto_deploy except without the iostream */ /* This function mirrors the one in howto_deploy except without the iostream */
int Verify(tvm::runtime::Module mod, std::string fname) { int Verify(tvm::runtime::Module mod, std::string fname) {
// Get the function from the module. // Get the function from the module.
......
...@@ -15,9 +15,3 @@ ...@@ -15,9 +15,3 @@
#include "threading_backend.cc" #include "threading_backend.cc"
#endif #endif
#include "../../src/runtime/thread_pool.cc" #include "../../src/runtime/thread_pool.cc"
extern "C" {
void tvm_ecall_shutdown() {
tvm::runtime::ThreadPool::Global()->Shutdown();
}
}
...@@ -14,7 +14,8 @@ namespace sgx { ...@@ -14,7 +14,8 @@ namespace sgx {
static std::unique_ptr<tvm::runtime::threading::ThreadGroup> sgx_thread_group; static std::unique_ptr<tvm::runtime::threading::ThreadGroup> sgx_thread_group;
extern "C" { extern "C" {
void tvm_ocall_thread_pool_launch(int num_tasks, void* cb) {
void tvm_ocall_thread_group_launch(int num_tasks, void* cb) {
std::function<void(int)> runner = [cb](int _worker_id) { std::function<void(int)> runner = [cb](int _worker_id) {
sgx_status_t sgx_status = SGX_ERROR_UNEXPECTED; sgx_status_t sgx_status = SGX_ERROR_UNEXPECTED;
sgx_status = tvm_ecall_run_worker(tvm_sgx_eid, cb); sgx_status = tvm_ecall_run_worker(tvm_sgx_eid, cb);
...@@ -23,12 +24,13 @@ void tvm_ocall_thread_pool_launch(int num_tasks, void* cb) { ...@@ -23,12 +24,13 @@ void tvm_ocall_thread_pool_launch(int num_tasks, void* cb) {
sgx_thread_group.reset(new tvm::runtime::threading::ThreadGroup( sgx_thread_group.reset(new tvm::runtime::threading::ThreadGroup(
num_tasks, runner, false /* include_main_thread */)); num_tasks, runner, false /* include_main_thread */));
} }
}
void Shutdown() { void tvm_ocall_thread_group_join() {
sgx_thread_group->Join(); sgx_thread_group->Join();
} }
}
} // namespace sgx } // namespace sgx
} // namespace runtime } // namespace runtime
} // namespace tvm } // namespace tvm
...@@ -10,7 +10,8 @@ ...@@ -10,7 +10,8 @@
#include <atomic> #include <atomic>
extern "C" { extern "C" {
sgx_status_t SGX_CDECL tvm_ocall_thread_pool_launch(int num_workers, void* cb); sgx_status_t SGX_CDECL tvm_ocall_thread_group_launch(int num_workers, void* cb);
sgx_status_t SGX_CDECL tvm_ocall_thread_group_join();
} }
#ifndef TVM_SGX_MAX_CONCURRENCY #ifndef TVM_SGX_MAX_CONCURRENCY
...@@ -31,10 +32,14 @@ class ThreadGroup::Impl { ...@@ -31,10 +32,14 @@ class ThreadGroup::Impl {
CHECK(num_workers <= TVM_SGX_MAX_CONCURRENCY) CHECK(num_workers <= TVM_SGX_MAX_CONCURRENCY)
<< "Tried spawning more threads than allowed by TVM_SGX_MAX_CONCURRENCY."; << "Tried spawning more threads than allowed by TVM_SGX_MAX_CONCURRENCY.";
sgx_status_t sgx_status = SGX_ERROR_UNEXPECTED; sgx_status_t sgx_status = SGX_ERROR_UNEXPECTED;
sgx_status = tvm_ocall_thread_pool_launch(num_workers, this); sgx_status = tvm_ocall_thread_group_launch(num_workers, this);
CHECK(sgx_status == SGX_SUCCESS) << "SGX Error: " << sgx_status; CHECK(sgx_status == SGX_SUCCESS) << "SGX Error: " << sgx_status;
} }
~Impl() {
tvm_ocall_thread_group_join();
}
void RunTask() { void RunTask() {
int task_id = next_task_id_++; int task_id = next_task_id_++;
CHECK(task_id < num_workers_) CHECK(task_id < num_workers_)
......
...@@ -5,11 +5,11 @@ enclave { ...@@ -5,11 +5,11 @@ enclave {
public void tvm_ecall_run_module([user_check] const void* tvm_args, public void tvm_ecall_run_module([user_check] const void* tvm_args,
[user_check] void* tvm_ret_value); [user_check] void* tvm_ret_value);
public void tvm_ecall_run_worker([user_check] const void* cb); public void tvm_ecall_run_worker([user_check] const void* cb);
public void tvm_ecall_shutdown();
}; };
untrusted { untrusted {
void tvm_ocall_thread_pool_launch(int num_workers, [user_check] void* cb); void tvm_ocall_thread_group_launch(int num_workers, [user_check] void* cb);
void tvm_ocall_thread_group_join();
}; };
}; };
...@@ -265,7 +265,12 @@ class ThreadPool { ...@@ -265,7 +265,12 @@ class ThreadPool {
num_workers_, [this](int worker_id) { this->RunWorker(worker_id); }, num_workers_, [this](int worker_id) { this->RunWorker(worker_id); },
false /* include_main_thread */)); false /* include_main_thread */));
} }
~ThreadPool() { Shutdown(); } ~ThreadPool() {
for (std::unique_ptr<SpscTaskQueue>& q : queues_) {
q->SignalForKill();
}
threads_.reset();
}
int Launch(FTVMParallelLambda flambda, int Launch(FTVMParallelLambda flambda,
void* cdata, void* cdata,
int num_task, int num_task,
...@@ -292,13 +297,6 @@ class ThreadPool { ...@@ -292,13 +297,6 @@ class ThreadPool {
return res; return res;
} }
void Shutdown() {
for (std::unique_ptr<SpscTaskQueue>& q : queues_) {
q->SignalForKill();
}
threads_.reset();
}
static ThreadPool* Global() { static ThreadPool* Global() {
static ThreadPool inst; static ThreadPool inst;
return &inst; return &inst;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment