Commit bfceafc7 by nhynes Committed by Tianqi Chen

Pluggable Thread Launching Mechanism (#991)

parent cbde86f9
...@@ -25,6 +25,7 @@ ...@@ -25,6 +25,7 @@
#include "../../src/runtime/module.cc" #include "../../src/runtime/module.cc"
#include "../../src/runtime/registry.cc" #include "../../src/runtime/registry.cc"
#include "../../src/runtime/file_util.cc" #include "../../src/runtime/file_util.cc"
#include "../../src/runtime/threading_backend.cc"
#include "../../src/runtime/thread_pool.cc" #include "../../src/runtime/thread_pool.cc"
// NOTE: all the files after this are optional modules // NOTE: all the files after this are optional modules
......
...@@ -26,6 +26,7 @@ pkg_cflags := -std=c++11 -O2 -fPIC\ ...@@ -26,6 +26,7 @@ pkg_cflags := -std=c++11 -O2 -fPIC\
-I${TVM_ROOT}/dlpack/include\ -I${TVM_ROOT}/dlpack/include\
-I.\ -I.\
-DDMLC_LOG_STACK_TRACE=0\ -DDMLC_LOG_STACK_TRACE=0\
-fmax-errors=4
pkg_ldflags := -L${TVM_ROOT}/lib pkg_ldflags := -L${TVM_ROOT}/lib
...@@ -40,7 +41,7 @@ enclave_cflags := -static -nostdinc\ ...@@ -40,7 +41,7 @@ enclave_cflags := -static -nostdinc\
-DDMLC_CXX11_THREAD_LOCAL=0\ -DDMLC_CXX11_THREAD_LOCAL=0\
$(enclave_include_paths)\ $(enclave_include_paths)\
enclave_cxxflags := -nostdinc++ $(enclave_cflags) enclave_cxxflags := -nostdinc++ $(enclave_cflags) -DTVM_SGX_MAX_CONCURRENCY=4
enclave_ldflags :=\ enclave_ldflags :=\
-Wl,--no-undefined -nostdlib -nodefaultlibs -nostartfiles -L$(SGX_SDK)/lib64\ -Wl,--no-undefined -nostdlib -nodefaultlibs -nostartfiles -L$(SGX_SDK)/lib64\
...@@ -62,7 +63,7 @@ app_ldflags := -L$(SGX_SDK)/lib64\ ...@@ -62,7 +63,7 @@ app_ldflags := -L$(SGX_SDK)/lib64\
all: lib/test_addone.signed.so bin/test_addone all: lib/test_addone.signed.so bin/test_addone
# Build rule for all-in-one TVM package library # Build rule for all-in-one TVM package library
lib/tvm_runtime_pack.o: tvm_runtime_pack.cc lib/tvm_runtime_pack.o: tvm_runtime_pack.cc lib/test_addone_t.o
@mkdir -p $(@D) @mkdir -p $(@D)
$(CXX) -c $< -o $@ $(pkg_cflags) $(pkg_ldflags) $(enclave_cxxflags) -g $(CXX) -c $< -o $@ $(pkg_cflags) $(pkg_ldflags) $(enclave_cxxflags) -g
...@@ -94,7 +95,7 @@ lib/test_addone.signed.so: lib/test_addone.so enclave_config.xml ...@@ -94,7 +95,7 @@ lib/test_addone.signed.so: lib/test_addone.so enclave_config.xml
# An app that runs the enclave # An app that runs the enclave
bin/test_addone: app.cc lib/test_addone_u.o bin/test_addone: app.cc lib/test_addone_u.o
@mkdir -p $(@D) @mkdir -p $(@D)
$(CXX) $^ -o $@ $(app_cflags) $(app_ldflags) $(CXX) $^ -o $@ $(app_cflags) $(app_ldflags) $(pkg_cflags) -g
# Debugging runtime pack built without SGX (c.f. howto_deploy/tvm_runtime_pack.cc) # Debugging runtime pack built without SGX (c.f. howto_deploy/tvm_runtime_pack.cc)
lib/tvm_runtime_pack_nosgx.o: tvm_runtime_pack.cc lib/tvm_runtime_pack_nosgx.o: tvm_runtime_pack.cc
...@@ -104,7 +105,7 @@ lib/tvm_runtime_pack_nosgx.o: tvm_runtime_pack.cc ...@@ -104,7 +105,7 @@ lib/tvm_runtime_pack_nosgx.o: tvm_runtime_pack.cc
# Debugging binary that runs TVM without SGX # Debugging binary that runs TVM without SGX
bin/addone_nosgx: enclave.cc lib/tvm_runtime_pack_nosgx.o lib/test_addone_sys.o bin/addone_nosgx: enclave.cc lib/tvm_runtime_pack_nosgx.o lib/test_addone_sys.o
@mkdir -p $(@D) @mkdir -p $(@D)
$(CXX) $^ -o $@ $(pkg_cflags) $(pkg_ldflags) -g $(CXX) $^ -o $@ $(pkg_cflags) $(pkg_ldflags) -g -lpthread
clean: clean:
rm -rf lib bin rm -rf lib bin
#include <cstdio> #include <cstdio>
#include <iostream>
#include "sgx_urts.h" #include "sgx_urts.h"
#include "sgx_eid.h" #include "sgx_eid.h"
#include "test_addone_u.h" #include "test_addone_u.h"
#include "../../sgx/runtime_u.cc"
#define TOKEN_FILENAME "bin/test_addone.token" #define TOKEN_FILENAME "bin/test_addone.token"
#define ENCLAVE_FILENAME "lib/test_addone.signed.so" #define ENCLAVE_FILENAME "lib/test_addone.signed.so"
sgx_enclave_id_t global_eid = 0; // global EID shared by multiple threads sgx_enclave_id_t tvm_sgx_eid;
typedef struct _sgx_errlist_t { typedef struct _sgx_errlist_t {
sgx_status_t err; sgx_status_t err;
...@@ -80,7 +82,7 @@ int initialize_enclave(void) ...@@ -80,7 +82,7 @@ int initialize_enclave(void)
/* Step 2: call sgx_create_enclave to initialize an enclave instance */ /* Step 2: call sgx_create_enclave to initialize an enclave instance */
/* Debug Support: set 2nd parameter to 1 */ /* Debug Support: set 2nd parameter to 1 */
sgx_status = sgx_create_enclave(ENCLAVE_FILENAME, SGX_DEBUG_FLAG, &token, &updated, &global_eid, NULL); sgx_status = sgx_create_enclave(ENCLAVE_FILENAME, SGX_DEBUG_FLAG, &token, &updated, &tvm_sgx_eid, NULL);
if (sgx_status != SGX_SUCCESS) { if (sgx_status != SGX_SUCCESS) {
print_error_message(sgx_status); print_error_message(sgx_status);
if (fp != NULL) fclose(fp); if (fp != NULL) fclose(fp);
...@@ -105,7 +107,7 @@ int initialize_enclave(void) ...@@ -105,7 +107,7 @@ int initialize_enclave(void)
} }
int SGX_CDECL main(int argc, char *argv[]) { int SGX_CDECL main(int argc, char *argv[]) {
if(initialize_enclave() < 0){ if(initialize_enclave() < 0) {
printf("Failed to initialize enclave.\n"); printf("Failed to initialize enclave.\n");
return -1; return -1;
} }
...@@ -113,12 +115,13 @@ int SGX_CDECL main(int argc, char *argv[]) { ...@@ -113,12 +115,13 @@ int SGX_CDECL main(int argc, char *argv[]) {
/* Run TVM within the enclave */ /* Run TVM within the enclave */
int addone_status; int addone_status;
sgx_status_t sgx_status = SGX_ERROR_UNEXPECTED; sgx_status_t sgx_status = SGX_ERROR_UNEXPECTED;
sgx_status = enclave_main(global_eid, &addone_status); sgx_status = tvm_ecall_run_module(tvm_sgx_eid, nullptr, &addone_status);
if (sgx_status != SGX_SUCCESS) { if (sgx_status != SGX_SUCCESS) {
print_error_message(sgx_status); print_error_message(sgx_status);
} }
tvm_ecall_shutdown(tvm_sgx_eid);
sgx_destroy_enclave(global_eid); tvm::runtime::sgx::Shutdown();
sgx_destroy_enclave(tvm_sgx_eid);
if (addone_status == 1) { if (addone_status == 1) {
printf("It works!"); printf("It works!");
...@@ -127,3 +130,9 @@ int SGX_CDECL main(int argc, char *argv[]) { ...@@ -127,3 +130,9 @@ int SGX_CDECL main(int argc, char *argv[]) {
printf("It doesn't work."); printf("It doesn't work.");
return -1; return -1;
} }
extern "C" {
void ocall_println(const char* str) {
std::cout << "Enclave says: " << str << std::endl;
}
}
...@@ -6,6 +6,8 @@ ...@@ -6,6 +6,8 @@
#include <iostream> #include <iostream>
#endif #endif
extern void Shutdown();
/* This function mirrors the one in howto_deploy except without the iostream */ /* This function mirrors the one in howto_deploy except without the iostream */
int Verify(tvm::runtime::Module mod, std::string fname) { int Verify(tvm::runtime::Module mod, std::string fname) {
// Get the function from the module. // Get the function from the module.
...@@ -43,9 +45,9 @@ int Verify(tvm::runtime::Module mod, std::string fname) { ...@@ -43,9 +45,9 @@ int Verify(tvm::runtime::Module mod, std::string fname) {
extern "C" { extern "C" {
int enclave_main() { void tvm_ecall_run_module(const void* tvm_args, void* tvm_return_value) {
tvm::runtime::Module mod_syslib = (*tvm::runtime::Registry::Get("module._GetSystemLib"))(); tvm::runtime::Module mod_syslib = (*tvm::runtime::Registry::Get("module._GetSystemLib"))();
return Verify(mod_syslib, "addonesys"); *(int*)tvm_return_value = Verify(mod_syslib, "addonesys");
} }
} }
......
<EnclaveConfiguration> <EnclaveConfiguration>
<ProdID>0</ProdID> <ProdID>0</ProdID>
<ISVSVN>0</ISVSVN> <ISVSVN>0</ISVSVN>
<StackMaxSize>0x2000</StackMaxSize> <StackMaxSize>0x100000</StackMaxSize>
<HeapMaxSize>0x1000</HeapMaxSize> <HeapMaxSize>0x100000</HeapMaxSize>
<TCSNum>1</TCSNum> <TCSNum>5</TCSNum>
<TCSPolicy>1</TCSPolicy> <TCSPolicy>1</TCSPolicy>
<DisableDebug>0</DisableDebug> <DisableDebug>0</DisableDebug>
<MiscSelect>0</MiscSelect> <MiscSelect>0</MiscSelect>
......
...@@ -11,6 +11,8 @@ def prepare_test_libs(base_path): ...@@ -11,6 +11,8 @@ def prepare_test_libs(base_path):
A = tvm.placeholder((n,), name='A') A = tvm.placeholder((n,), name='A')
B = tvm.compute(A.shape, lambda *i: A(*i) + 1, name='B') B = tvm.compute(A.shape, lambda *i: A(*i) + 1, name='B')
s = tvm.create_schedule(B.op) s = tvm.create_schedule(B.op)
s[B].parallel(s[B].op.axis[0])
print(tvm.lower(s, [A, B], simple_mode=True))
# Compile library in system library mode # Compile library in system library mode
fadd_syslib = tvm.build(s, [A, B], 'llvm --system-lib', name='addonesys') fadd_syslib = tvm.build(s, [A, B], 'llvm --system-lib', name='addonesys')
......
enclave { enclave {
from "sgx_tstdc.edl" import sgx_thread_wait_untrusted_event_ocall, sgx_thread_set_untrusted_event_ocall, sgx_thread_setwait_untrusted_events_ocall, sgx_thread_set_multiple_untrusted_events_ocall; from "../../sgx/tvm.edl" import *;
trusted { untrusted {
public int enclave_main(); void ocall_println([in, string] const char *str);
}; };
}; };
...@@ -5,7 +5,11 @@ ...@@ -5,7 +5,11 @@
* Please refer to the Makefile (rule lib/tvm_runtime_pack.o) for how to build. * Please refer to the Makefile (rule lib/tvm_runtime_pack.o) for how to build.
* *
*/ */
#include "../../sgx/sgx_runtime.cc" #ifdef _LIBCPP_SGX_CONFIG
#include "lib/test_addone_t.h"
#endif
#include "../../sgx/runtime_t.cc"
#ifndef _LIBCPP_SGX_CONFIG #ifndef _LIBCPP_SGX_CONFIG
#include "../../src/runtime/file_util.cc" #include "../../src/runtime/file_util.cc"
#endif #endif
/*!
* Copyright (c) 2018 by Contributors
* \file threading_backend.h
* \brief Utilities for manipulating thread pool threads.
*/
#ifndef TVM_RUNTIME_THREADING_BACKEND_H_
#define TVM_RUNTIME_THREADING_BACKEND_H_
#include <functional>
#include <memory>
#include <vector>
namespace tvm {
namespace runtime {
namespace threading {
/*!
* \brief A platform-agnostic abstraction for managing a collection of
* thread pool threads.
*/
class ThreadGroup {
public:
class Impl;
/*!
* \brief Creates a collection of threads which run a provided function.
*
* \param num_workers The total number of worker threads in this group.
Includes main thread if `exclude_worker0 = true`
* \param worker_callback A callback which is run in its own thread.
Receives the worker_id as an argument.
* \param exclude_worker0 Whether to use the main thread as a worker.
* If `true`, worker0 will not be launched in a new thread and
* `worker_callback` will only be called for values >= 1. This
* allows use of the main thread as a worker.
*/
ThreadGroup(int num_workers,
std::function<void(int)> worker_callback,
bool exclude_worker0 = false);
~ThreadGroup();
/*!
* \brief Blocks until all non-main threads in the pool finish.
*/
void Join();
private:
Impl* impl_;
};
/*!
* \brief Platform-agnostic no-op.
*/
void Yield();
/*!
* \return the maximum number of effective workers for this system.
*/
int MaxConcurrency();
} // namespace threading
} // namespace runtime
} // namespace tvm
#endif // TVM_RUNTIME_THREADING_BACKEND_H_
...@@ -9,17 +9,15 @@ ...@@ -9,17 +9,15 @@
#include "../../src/runtime/module.cc" #include "../../src/runtime/module.cc"
#include "../../src/runtime/registry.cc" #include "../../src/runtime/registry.cc"
#include "../../src/runtime/system_lib_module.cc" #include "../../src/runtime/system_lib_module.cc"
#ifndef _LIBCPP_SGX_CONFIG
#include "../../src/runtime/threading_backend.cc"
#else
#include "threading_backend.cc"
#endif
#include "../../src/runtime/thread_pool.cc"
// dummy parallel runtime (for now) extern "C" {
int TVMBackendParallelLaunch( void tvm_ecall_shutdown() {
FTVMParallelLambda flambda, tvm::runtime::ThreadPool::Global()->Shutdown();
void* cdata,
int num_task) {
TVMParallelGroupEnv env = { nullptr /* sync_handle */, 1 /* num_task */ };
return flambda(0 /* task_id */, &env, cdata);
} }
int TVMBackendParallelBarrier(int task_id, TVMParallelGroupEnv* penv) {
return 0;
} }
#include <tvm/runtime/threading_backend.h>
#include "../../src/runtime/threading_backend.cc"
#include <iostream>
extern sgx_enclave_id_t tvm_sgx_eid;
extern "C" {
sgx_status_t tvm_ecall_run_worker(sgx_enclave_id_t eid, const void* cb);
}
namespace tvm {
namespace runtime {
namespace sgx {
static std::unique_ptr<tvm::runtime::threading::ThreadGroup> sgx_thread_group;
extern "C" {
void tvm_ocall_thread_pool_launch(int num_tasks, void* cb) {
std::function<void(int)> runner = [cb](int _worker_id) {
sgx_status_t sgx_status = SGX_ERROR_UNEXPECTED;
sgx_status = tvm_ecall_run_worker(tvm_sgx_eid, cb);
CHECK(sgx_status == SGX_SUCCESS) << "SGX Error: " << sgx_status;
};
sgx_thread_group.reset(new tvm::runtime::threading::ThreadGroup(
num_tasks, runner, false /* include_main_thread */));
}
}
void Shutdown() {
sgx_thread_group->Join();
}
} // namespace sgx
} // namespace runtime
} // namespace tvm
/*!
* Copyright (c) 2018 by Contributors
* \file sgx/threading_backend.cc
* \brief SGX threading backend
*/
#include <tvm/runtime/threading_backend.h>
#include <dmlc/logging.h>
#include <sgx_edger8r.h>
#include <sgx_trts.h>
#include <atomic>
extern "C" {
sgx_status_t SGX_CDECL tvm_ocall_thread_pool_launch(int num_workers, void* cb);
}
#ifndef TVM_SGX_MAX_CONCURRENCY
#define TVM_SGX_MAX_CONCURRENCY 1
#endif
namespace tvm {
namespace runtime {
namespace threading {
class ThreadGroup::Impl {
public:
Impl(int num_workers, std::function<void(int)> worker_callback,
bool exclude_worker0)
: num_workers_(num_workers),
worker_callback_(worker_callback),
next_task_id_(exclude_worker0) {
CHECK(num_workers <= TVM_SGX_MAX_CONCURRENCY)
<< "Tried spawning more threads than allowed by TVM_SGX_MAX_CONCURRENCY.";
sgx_status_t sgx_status = SGX_ERROR_UNEXPECTED;
sgx_status = tvm_ocall_thread_pool_launch(num_workers, this);
CHECK(sgx_status == SGX_SUCCESS) << "SGX Error: " << sgx_status;
}
void RunTask() {
int task_id = next_task_id_++;
CHECK(task_id < num_workers_)
<< "More workers entered enclave than allowed by TVM_SGX_MAX_CONCURRENCY";
worker_callback_(task_id);
}
private:
int num_workers_;
std::function<void(int)> worker_callback_;
std::atomic<int> next_task_id_;
};
ThreadGroup::ThreadGroup(int num_workers,
std::function<void(int)> worker_callback,
bool exclude_worker0)
: impl_(new ThreadGroup::Impl(num_workers, worker_callback, exclude_worker0)) {}
void ThreadGroup::Join() {}
ThreadGroup::~ThreadGroup() { delete impl_; }
void Yield() {}
int MaxConcurrency() { return TVM_SGX_MAX_CONCURRENCY; }
extern "C" {
void tvm_ecall_run_worker(const void* impl) {
if (!sgx_is_within_enclave(impl, sizeof(ThreadGroup::Impl))) return;
((ThreadGroup::Impl*)impl)->RunTask();
}
}
} // namespace threading
} // namespace runtime
} // namespace tvm
enclave {
from "sgx_tstdc.edl" import *;
trusted {
public void tvm_ecall_run_module([user_check] const void* tvm_args,
[user_check] void* tvm_ret_value);
public void tvm_ecall_run_worker([user_check] const void* cb);
public void tvm_ecall_shutdown();
};
untrusted {
void tvm_ocall_thread_pool_launch(int num_workers, [user_check] void* cb);
};
};
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
*/ */
#include <tvm/runtime/c_runtime_api.h> #include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/c_backend_api.h> #include <tvm/runtime/c_backend_api.h>
#include <tvm/runtime/threading_backend.h>
#include <dmlc/thread_local.h> #include <dmlc/thread_local.h>
#include <dmlc/logging.h> #include <dmlc/logging.h>
#include <thread> #include <thread>
...@@ -17,9 +18,6 @@ ...@@ -17,9 +18,6 @@
#include <cstring> #include <cstring>
#include <memory> #include <memory>
#include <sstream> #include <sstream>
#if defined(__linux__)
#include <sched.h>
#endif
const constexpr int kL1CacheBytes = 64; const constexpr int kL1CacheBytes = 64;
...@@ -73,14 +71,14 @@ class ParallelLauncher { ...@@ -73,14 +71,14 @@ class ParallelLauncher {
return num_pending_ == 0; return num_pending_ == 0;
}); });
if (!has_error_) return 0; if (!has_error_) return 0;
std::ostringstream os; std::string err("");
for (size_t i = 0; i < par_errors_.size(); ++i) { for (size_t i = 0; i < par_errors_.size(); ++i) {
if (par_errors_[i].length() != 0) { if (par_errors_[i].length() != 0) {
os << "Task " << i << " error: " << par_errors_[i] << '\n'; err += "Task " + std::to_string(i) + " error: " + par_errors_[i] + '\n';
par_errors_[i].clear(); par_errors_[i].clear();
} }
} }
TVMAPISetLastError(os.str().c_str()); TVMAPISetLastError(err.c_str());
return -1; return -1;
} }
// Signal that one job has finished. // Signal that one job has finished.
...@@ -157,7 +155,7 @@ class SpscTaskQueue { ...@@ -157,7 +155,7 @@ class SpscTaskQueue {
*/ */
void Push(const Task& input) { void Push(const Task& input) {
while (!Enqueue(input)) { while (!Enqueue(input)) {
std::this_thread::yield(); tvm::runtime::threading::Yield();
} }
if (pending_.fetch_add(1) == -1) { if (pending_.fetch_add(1) == -1) {
std::unique_lock<std::mutex> lock(mutex_); std::unique_lock<std::mutex> lock(mutex_);
...@@ -176,7 +174,7 @@ class SpscTaskQueue { ...@@ -176,7 +174,7 @@ class SpscTaskQueue {
// If a new task comes to the queue quickly, this wait avoid the worker from sleeping. // If a new task comes to the queue quickly, this wait avoid the worker from sleeping.
// The default spin count is set by following the typical omp convention // The default spin count is set by following the typical omp convention
for (uint32_t i = 0; i < spin_count && pending_.load() == 0; ++i) { for (uint32_t i = 0; i < spin_count && pending_.load() == 0; ++i) {
std::this_thread::yield(); tvm::runtime::threading::Yield();
} }
if (pending_.fetch_sub(1) == 0) { if (pending_.fetch_sub(1) == 0) {
std::unique_lock<std::mutex> lock(mutex_); std::unique_lock<std::mutex> lock(mutex_);
...@@ -211,6 +209,8 @@ class SpscTaskQueue { ...@@ -211,6 +209,8 @@ class SpscTaskQueue {
* \return Whether the task is enqueued. * \return Whether the task is enqueued.
*/ */
bool Enqueue(const Task& input) { bool Enqueue(const Task& input) {
if (exit_now_.load(std::memory_order_relaxed)) return false;
const uint32_t tail = tail_.load(std::memory_order_relaxed); const uint32_t tail = tail_.load(std::memory_order_relaxed);
if ((tail + 1) % kRingSize != (head_.load(std::memory_order_acquire))) { if ((tail + 1) % kRingSize != (head_.load(std::memory_order_acquire))) {
...@@ -255,32 +255,17 @@ class SpscTaskQueue { ...@@ -255,32 +255,17 @@ class SpscTaskQueue {
// The thread pool // The thread pool
class ThreadPool { class ThreadPool {
public: public:
ThreadPool() { ThreadPool(): num_workers_(tvm::runtime::threading::MaxConcurrency()) {
const char *val = getenv("TVM_NUM_THREADS"); for (int i = 0; i < num_workers_; ++i) {
if (val == nullptr) { // The SpscTaskQueue only host ONE item at a time
val = getenv("OMP_NUM_THREADS"); queues_.emplace_back(std::unique_ptr<SpscTaskQueue>(new SpscTaskQueue()));
}
if (val != nullptr) {
num_workers_ = atoi(val);
} else {
#if defined(_M_X64) || defined(__x86_64__)
// Half to not count hyper threading.
num_workers_ = std::thread::hardware_concurrency() / 2;
#else
num_workers_ = std::thread::hardware_concurrency();
#endif
}
num_workers_ = std::max(num_workers_, 1);
this->Init();
}
~ThreadPool() {
for (std::unique_ptr<SpscTaskQueue>& q : queues_) {
q->SignalForKill();
}
for (std::thread& t : threads_) {
t.join();
} }
threads_ = std::unique_ptr<tvm::runtime::threading::ThreadGroup>(
new tvm::runtime::threading::ThreadGroup(
num_workers_, [this](int worker_id) { this->RunWorker(worker_id); },
false /* include_main_thread */));
} }
~ThreadPool() { Shutdown(); }
int Launch(FTVMParallelLambda flambda, int Launch(FTVMParallelLambda flambda,
void* cdata, void* cdata,
int num_task, int num_task,
...@@ -307,38 +292,22 @@ class ThreadPool { ...@@ -307,38 +292,22 @@ class ThreadPool {
return res; return res;
} }
void Shutdown() {
for (std::unique_ptr<SpscTaskQueue>& q : queues_) {
q->SignalForKill();
}
threads_.reset();
}
static ThreadPool* Global() { static ThreadPool* Global() {
static ThreadPool inst; static ThreadPool inst;
return &inst; return &inst;
} }
private: private:
// Initialize the pool.
void Init() {
for (int i = 0; i < num_workers_; ++i) {
// The SpscTaskQueue only host ONE item at a time
queues_.emplace_back(
std::unique_ptr<SpscTaskQueue>(new SpscTaskQueue()));
}
threads_.resize(num_workers_);
for (int i = 0; i < num_workers_; ++i) {
threads_[i] = std::thread([this, i] {
this->RunWorker(queues_[i].get());
});
}
const char *val = getenv("TVM_BIND_THREADS");
if (val == nullptr || atoi(val) == 1) {
if (num_workers_ <= std::thread::hardware_concurrency()) {
SetThreadAffinity();
} else {
LOG(WARNING)
<< "The thread affinity cannot be set when the number of workers is larger "
<< "than the number of available cores in the system.";
}
}
}
// Internal worker function. // Internal worker function.
void RunWorker(SpscTaskQueue* queue) { void RunWorker(int worker_id) {
SpscTaskQueue* queue = queues_[worker_id].get();
SpscTaskQueue::Task task; SpscTaskQueue::Task task;
ParallelLauncher::ThreadLocal()->is_worker = true; ParallelLauncher::ThreadLocal()->is_worker = true;
while (queue->Pop(&task)) { while (queue->Pop(&task)) {
...@@ -352,40 +321,9 @@ class ThreadPool { ...@@ -352,40 +321,9 @@ class ThreadPool {
} }
} }
} }
// bind worker threads to disjoint cores
void SetThreadAffinity() {
#if defined(__ANDROID__)
#ifndef CPU_SET
#define CPU_SETSIZE 1024
#define __NCPUBITS (8 * sizeof (uint64_t))
typedef struct {
uint64_t __bits[CPU_SETSIZE / __NCPUBITS];
} cpu_set_t;
#define CPU_SET(cpu, cpusetp) \
((cpusetp)->__bits[(cpu)/__NCPUBITS] |= (1UL << ((cpu) % __NCPUBITS)))
#define CPU_ZERO(cpusetp) \
memset((cpusetp), 0, sizeof(cpu_set_t))
#endif
#endif
for (int i=0; i < num_workers_; ++i) {
#if defined(__linux__) || defined(__ANDROID__)
cpu_set_t cpuset;
CPU_ZERO(&cpuset);
CPU_SET(i, &cpuset);
#if defined(__ANDROID__)
sched_setaffinity(threads_[i].native_handle(), sizeof(cpu_set_t), &cpuset);
#else
pthread_setaffinity_np(threads_[i].native_handle(),
sizeof(cpu_set_t), &cpuset);
#endif
#endif
}
}
// Number of workers
int num_workers_; int num_workers_;
std::vector<std::unique_ptr<SpscTaskQueue> > queues_; std::vector<std::unique_ptr<SpscTaskQueue> > queues_;
std::vector<std::thread> threads_; std::unique_ptr<tvm::runtime::threading::ThreadGroup> threads_;
}; };
} // namespace runtime } // namespace runtime
...@@ -411,7 +349,7 @@ int TVMBackendParallelBarrier(int task_id, TVMParallelGroupEnv* penv) { ...@@ -411,7 +349,7 @@ int TVMBackendParallelBarrier(int task_id, TVMParallelGroupEnv* penv) {
if (i != task_id) { if (i != task_id) {
while (sync_counter[i * kSyncStride].load( while (sync_counter[i * kSyncStride].load(
std::memory_order_relaxed) <= old_counter) { std::memory_order_relaxed) <= old_counter) {
std::this_thread::yield(); tvm::runtime::threading::Yield();
} }
} }
} }
......
/*!
* Copyright (c) 2018 by Contributors
* \file threading_backend.cc
* \brief Native threading backend
*/
#include <tvm/runtime/threading_backend.h>
#include <dmlc/logging.h>
#include <thread>
#if defined(__linux__)
#include <sched.h>
#endif
namespace tvm {
namespace runtime {
namespace threading {
class ThreadGroup::Impl {
public:
Impl(int num_workers,
std::function<void(int)> worker_callback,
bool exclude_worker0)
: num_workers_(num_workers) {
CHECK_GE(num_workers, 1)
<< "Requested a non-positive number of worker threads.";
for (int i = exclude_worker0; i < num_workers_; ++i) {
threads_.emplace_back([worker_callback, i] { worker_callback(i); });
}
const char *val = getenv("TVM_BIND_THREADS");
if (val == nullptr || atoi(val) == 1) {
if (num_workers_ <= std::thread::hardware_concurrency()) {
SetAffinity();
} else {
LOG(WARNING)
<< "The thread affinity cannot be set when the number of workers"
<< "is larger than the number of available cores in the system.";
}
}
}
~Impl() { Join(); }
void Join() {
for (auto& t : threads_) {
if (t.joinable()) t.join();
}
}
private:
// bind worker threads to disjoint cores
void SetAffinity() {
#if defined(__ANDROID__)
#ifndef CPU_SET
#define CPU_SETSIZE 1024
#define __NCPUBITS (8 * sizeof (uint64_t))
typedef struct {
uint64_t __bits[CPU_SETSIZE / __NCPUBITS];
} cpu_set_t;
#define CPU_SET(cpu, cpusetp) \
((cpusetp)->__bits[(cpu)/__NCPUBITS] |= (1UL << ((cpu) % __NCPUBITS)))
#define CPU_ZERO(cpusetp) \
memset((cpusetp), 0, sizeof(cpu_set_t))
#endif
#endif
for (unsigned i=0; i < threads_.size(); ++i) {
#if defined(__linux__) || defined(__ANDROID__)
cpu_set_t cpuset;
CPU_ZERO(&cpuset);
CPU_SET(i, &cpuset);
#if defined(__ANDROID__)
sched_setaffinity(threads_[i].native_handle(), sizeof(cpu_set_t), &cpuset);
#else
pthread_setaffinity_np(threads_[i].native_handle(),
sizeof(cpu_set_t), &cpuset);
#endif
#endif
}
}
int num_workers_;
std::vector<std::thread> threads_;
};
ThreadGroup::ThreadGroup(int num_workers,
std::function<void(int)> worker_callback,
bool exclude_worker0)
: impl_(new ThreadGroup::Impl(num_workers, worker_callback, exclude_worker0)) {}
ThreadGroup::~ThreadGroup() { delete impl_; }
void ThreadGroup::Join() { impl_->Join(); }
void Yield() {
std::this_thread::yield();
}
int MaxConcurrency() {
int max_concurrency = 1;
const char *val = getenv("TVM_NUM_THREADS");
if (val == nullptr) {
val = getenv("OMP_NUM_THREADS");
}
if (val != nullptr) {
max_concurrency = atoi(val);
} else {
max_concurrency = std::thread::hardware_concurrency();
#if defined(_M_X64) || defined(__x86_64__)
max_concurrency /= 2; // ignore hyper-threading
#endif
}
return std::max(max_concurrency, 1);
}
} // namespace threading
} // namespace runtime
} // namespace tvm
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment