Commit 21f71f9e by nhynes Committed by Tianqi Chen

Simplify enclave lifecycle management (#1013)

parent 1587038e
......@@ -119,8 +119,7 @@ int SGX_CDECL main(int argc, char *argv[]) {
if (sgx_status != SGX_SUCCESS) {
print_error_message(sgx_status);
}
tvm_ecall_shutdown(tvm_sgx_eid);
tvm::runtime::sgx::Shutdown();
sgx_destroy_enclave(tvm_sgx_eid);
if (addone_status == 1) {
......
......@@ -6,8 +6,6 @@
#include <iostream>
#endif
extern void Shutdown();
/* This function mirrors the one in howto_deploy except without the iostream */
int Verify(tvm::runtime::Module mod, std::string fname) {
// Get the function from the module.
......
......@@ -15,9 +15,3 @@
#include "threading_backend.cc"
#endif
#include "../../src/runtime/thread_pool.cc"
extern "C" {
void tvm_ecall_shutdown() {
tvm::runtime::ThreadPool::Global()->Shutdown();
}
}
......@@ -14,7 +14,8 @@ namespace sgx {
static std::unique_ptr<tvm::runtime::threading::ThreadGroup> sgx_thread_group;
extern "C" {
void tvm_ocall_thread_pool_launch(int num_tasks, void* cb) {
void tvm_ocall_thread_group_launch(int num_tasks, void* cb) {
std::function<void(int)> runner = [cb](int _worker_id) {
sgx_status_t sgx_status = SGX_ERROR_UNEXPECTED;
sgx_status = tvm_ecall_run_worker(tvm_sgx_eid, cb);
......@@ -23,12 +24,13 @@ void tvm_ocall_thread_pool_launch(int num_tasks, void* cb) {
sgx_thread_group.reset(new tvm::runtime::threading::ThreadGroup(
num_tasks, runner, false /* include_main_thread */));
}
}
void Shutdown() {
void tvm_ocall_thread_group_join() {
sgx_thread_group->Join();
}
}
} // namespace sgx
} // namespace runtime
} // namespace tvm
......@@ -10,7 +10,8 @@
#include <atomic>
extern "C" {
sgx_status_t SGX_CDECL tvm_ocall_thread_pool_launch(int num_workers, void* cb);
sgx_status_t SGX_CDECL tvm_ocall_thread_group_launch(int num_workers, void* cb);
sgx_status_t SGX_CDECL tvm_ocall_thread_group_join();
}
#ifndef TVM_SGX_MAX_CONCURRENCY
......@@ -31,10 +32,14 @@ class ThreadGroup::Impl {
CHECK(num_workers <= TVM_SGX_MAX_CONCURRENCY)
<< "Tried spawning more threads than allowed by TVM_SGX_MAX_CONCURRENCY.";
sgx_status_t sgx_status = SGX_ERROR_UNEXPECTED;
sgx_status = tvm_ocall_thread_pool_launch(num_workers, this);
sgx_status = tvm_ocall_thread_group_launch(num_workers, this);
CHECK(sgx_status == SGX_SUCCESS) << "SGX Error: " << sgx_status;
}
~Impl() {
tvm_ocall_thread_group_join();
}
void RunTask() {
int task_id = next_task_id_++;
CHECK(task_id < num_workers_)
......
......@@ -5,11 +5,11 @@ enclave {
public void tvm_ecall_run_module([user_check] const void* tvm_args,
[user_check] void* tvm_ret_value);
public void tvm_ecall_run_worker([user_check] const void* cb);
public void tvm_ecall_shutdown();
};
untrusted {
void tvm_ocall_thread_pool_launch(int num_workers, [user_check] void* cb);
void tvm_ocall_thread_group_launch(int num_workers, [user_check] void* cb);
void tvm_ocall_thread_group_join();
};
};
......@@ -265,7 +265,12 @@ class ThreadPool {
num_workers_, [this](int worker_id) { this->RunWorker(worker_id); },
false /* include_main_thread */));
}
~ThreadPool() { Shutdown(); }
~ThreadPool() {
for (std::unique_ptr<SpscTaskQueue>& q : queues_) {
q->SignalForKill();
}
threads_.reset();
}
int Launch(FTVMParallelLambda flambda,
void* cdata,
int num_task,
......@@ -292,13 +297,6 @@ class ThreadPool {
return res;
}
void Shutdown() {
for (std::unique_ptr<SpscTaskQueue>& q : queues_) {
q->SignalForKill();
}
threads_.reset();
}
static ThreadPool* Global() {
static ThreadPool inst;
return &inst;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment