/*!
 *  Copyright (c) 2018 by Contributors
 * \file sgx/threading_backend.cc
 * \brief SGX threading backend
 */
#include <tvm/runtime/threading_backend.h>
#include <dmlc/logging.h>
#include <sgx_edger8r.h>
#include <sgx_trts.h>
#include <atomic>
#include <functional>
#include <utility>
#include "runtime.h"

#ifndef TVM_SGX_MAX_CONCURRENCY
#define TVM_SGX_MAX_CONCURRENCY 1
#endif

namespace tvm {
namespace runtime {
namespace threading {

/*!
 * \brief Enclave-side state of a ThreadGroup.
 *
 * The constructor issues the "__sgx_thread_group_launch__" OCALL, passing the
 * worker count and this object's address to the untrusted side. The
 * "__tvm_run_worker__" enclave function registered at the bottom of this file
 * takes such an address and calls RunTask() on it.
 * NOTE(review): presumably each host thread started by the launch OCALL
 * re-enters the enclave through "__tvm_run_worker__" -- confirm against the
 * untrusted runtime.
 */
class ThreadGroup::Impl {
 public:
  /*!
   * \param num_workers Number of task ids to hand out; must not exceed
   *        TVM_SGX_MAX_CONCURRENCY.
   * \param worker_callback Invoked once per claimed task id.
   * \param exclude_worker0 If true, task ids start at 1 instead of 0
   *        (id 0 is never handed out by RunTask()).
   */
  Impl(int num_workers, std::function<void(int)> worker_callback,
       bool exclude_worker0)
      : num_workers_(num_workers),
        worker_callback_(std::move(worker_callback)),
        // bool -> int: the first id handed out is 1 when worker 0 is excluded.
        next_task_id_(exclude_worker0) {
    CHECK(num_workers <= TVM_SGX_MAX_CONCURRENCY)
        << "Tried spawning more threads than allowed by TVM_SGX_MAX_CONCURRENCY.";
    sgx::OCallPackedFunc("__sgx_thread_group_launch__", num_workers_,
                         reinterpret_cast<void*>(this));
  }

  /*! \brief Asks the untrusted side to join the launched threads. */
  ~Impl() { sgx::OCallPackedFunc("__sgx_thread_group_join__"); }

  /*!
   * \brief Atomically claims the next task id and runs the worker callback.
   *
   * Fails a CHECK if more callers arrive than there are task ids.
   */
  void RunTask() {
    const int task_id = next_task_id_++;
    CHECK(task_id < num_workers_)
        << "More workers entered enclave than allowed by TVM_SGX_MAX_CONCURRENCY";
    worker_callback_(task_id);
  }

 private:
  int num_workers_;
  std::function<void(int)> worker_callback_;
  // Incremented concurrently by every thread that enters RunTask().
  std::atomic<int> next_task_id_;
};

ThreadGroup::ThreadGroup(int num_workers,
                         std::function<void(int)> worker_callback,
                         bool exclude_worker0)
    : impl_(new ThreadGroup::Impl(num_workers, std::move(worker_callback),
                                  exclude_worker0)) {}

// No-op: joining happens in ~Impl (via the "__sgx_thread_group_join__" OCALL),
// which runs when ~ThreadGroup deletes impl_.
void ThreadGroup::Join() {}

/*!
 * \brief Chooses the number of workers to use inside the enclave.
 *
 * \p mode and \p exclude_worker0 are ignored; thread affinity is not
 * configured here.
 * \return MaxConcurrency() when \p nthreads is non-positive (0 means "use
 *         the default") or exceeds it; otherwise \p nthreads.
 */
int ThreadGroup::Configure(AffinityMode mode, int nthreads,
                           bool exclude_worker0) {
  const int max_conc = MaxConcurrency();
  if (nthreads <= 0 || nthreads > max_conc) {
    return max_conc;
  }
  return nthreads;
}

ThreadGroup::~ThreadGroup() { delete impl_; }

// No-op in the enclave backend.
void Yield() {}

/*! \brief Compile-time concurrency cap (TVM_SGX_MAX_CONCURRENCY, default 1). */
int MaxConcurrency() { return TVM_SGX_MAX_CONCURRENCY; }

// Worker entry point: args[0] is the address of a ThreadGroup::Impl.
TVM_REGISTER_ENCLAVE_FUNC("__tvm_run_worker__")
.set_body([](TVMArgs args, TVMRetValue* rv) {
    void* tg = args[0];
    // The pointer arrives from outside the enclave; refuse to dereference it
    // unless the whole Impl-sized range lies inside enclave memory.
    if (!sgx_is_within_enclave(tg, sizeof(ThreadGroup::Impl))) return;
    reinterpret_cast<ThreadGroup::Impl*>(tg)->RunTask();
  });

}  // namespace threading
}  // namespace runtime
}  // namespace tvm