Commit 5ff65749 by Yida Wang Committed by Tianqi Chen

[RUNTIME] better parallel launcher and task distribution (#1026)

parent 0ec7cabe
...@@ -37,12 +37,11 @@ class ParallelLauncher { ...@@ -37,12 +37,11 @@ class ParallelLauncher {
void* cdata, void* cdata,
int num_task, int num_task,
bool need_sync) { bool need_sync) {
std::lock_guard<std::mutex> lock(mutex_); num_pending_.store(num_task);
num_pending_ = num_task;
this->cdata = cdata; this->cdata = cdata;
this->flambda = flambda; this->flambda = flambda;
this->env.num_task = num_task; this->env.num_task = num_task;
has_error_ = false; has_error_.store(false);
// reshape // reshape
if (static_cast<size_t>(num_task) > par_errors_.size()) { if (static_cast<size_t>(num_task) > par_errors_.size()) {
par_errors_.resize(num_task + 1); par_errors_.resize(num_task + 1);
...@@ -66,11 +65,10 @@ class ParallelLauncher { ...@@ -66,11 +65,10 @@ class ParallelLauncher {
} }
// Wait n jobs to finish // Wait n jobs to finish
int WaitForJobs() { int WaitForJobs() {
std::unique_lock<std::mutex> lock(mutex_); while (num_pending_.load() != 0) {
cv_.wait(lock, [this] { tvm::runtime::threading::Yield();
return num_pending_ == 0; }
}); if (!has_error_.load()) return 0;
if (!has_error_) return 0;
std::string err(""); std::string err("");
for (size_t i = 0; i < par_errors_.size(); ++i) { for (size_t i = 0; i < par_errors_.size(); ++i) {
if (par_errors_[i].length() != 0) { if (par_errors_[i].length() != 0) {
...@@ -83,23 +81,13 @@ class ParallelLauncher { ...@@ -83,23 +81,13 @@ class ParallelLauncher {
} }
// Signal that one job has finished. // Signal that one job has finished.
void SignalJobError(int task_id) { void SignalJobError(int task_id) {
std::unique_lock<std::mutex> lock(mutex_); num_pending_.fetch_sub(1);
--num_pending_;
par_errors_[task_id] = TVMGetLastError(); par_errors_[task_id] = TVMGetLastError();
has_error_ = true; has_error_.store(true);
if (num_pending_ == 0) {
lock.unlock();
cv_.notify_one();
}
} }
// Signal that one job has finished. // Signal that one job has finished.
void SignalJobFinish() { void SignalJobFinish() {
std::unique_lock<std::mutex> lock(mutex_); num_pending_.fetch_sub(1);
--num_pending_;
if (num_pending_ == 0) {
lock.unlock();
cv_.notify_one();
}
} }
// Get thread local version of the store. // Get thread local version of the store.
static ParallelLauncher* ThreadLocal() { static ParallelLauncher* ThreadLocal() {
...@@ -116,14 +104,10 @@ class ParallelLauncher { ...@@ -116,14 +104,10 @@ class ParallelLauncher {
bool is_worker{false}; bool is_worker{false};
private: private:
// The mutex to access local env.
std::mutex mutex_;
// The conditional variable.
std::condition_variable cv_;
// The pending jobs. // The pending jobs.
uint32_t num_pending_; std::atomic<int32_t> num_pending_;
// Whether error has been countered. // Whether error has been countered.
bool has_error_; std::atomic<bool> has_error_;
// The counter page. // The counter page.
std::atomic<int32_t>* sync_counter_{nullptr}; std::atomic<int32_t>* sync_counter_{nullptr};
// The error message // The error message
...@@ -257,13 +241,13 @@ class ThreadPool { ...@@ -257,13 +241,13 @@ class ThreadPool {
public: public:
ThreadPool(): num_workers_(tvm::runtime::threading::MaxConcurrency()) { ThreadPool(): num_workers_(tvm::runtime::threading::MaxConcurrency()) {
for (int i = 0; i < num_workers_; ++i) { for (int i = 0; i < num_workers_; ++i) {
// The SpscTaskQueue only host ONE item at a time // The SpscTaskQueue only hosts ONE item at a time
queues_.emplace_back(std::unique_ptr<SpscTaskQueue>(new SpscTaskQueue())); queues_.emplace_back(std::unique_ptr<SpscTaskQueue>(new SpscTaskQueue()));
} }
threads_ = std::unique_ptr<tvm::runtime::threading::ThreadGroup>( threads_ = std::unique_ptr<tvm::runtime::threading::ThreadGroup>(
new tvm::runtime::threading::ThreadGroup( new tvm::runtime::threading::ThreadGroup(
num_workers_, [this](int worker_id) { this->RunWorker(worker_id); }, num_workers_, [this](int worker_id) { this->RunWorker(worker_id); },
false /* include_main_thread */)); exclude_worker0_ /* include_main_thread */));
} }
~ThreadPool() { ~ThreadPool() {
for (std::unique_ptr<SpscTaskQueue>& q : queues_) { for (std::unique_ptr<SpscTaskQueue>& q : queues_) {
...@@ -289,10 +273,20 @@ class ThreadPool { ...@@ -289,10 +273,20 @@ class ThreadPool {
launcher->Init(flambda, cdata, num_task, need_sync != 0); launcher->Init(flambda, cdata, num_task, need_sync != 0);
SpscTaskQueue::Task tsk; SpscTaskQueue::Task tsk;
tsk.launcher = launcher; tsk.launcher = launcher;
for (int i = 0; i < num_task; ++i) { // if worker0 is taken by the master, queues_[0] is abandoned
for (int i = exclude_worker0_; i < num_task; ++i) {
tsk.task_id = i; tsk.task_id = i;
queues_[i]->Push(tsk); queues_[i]->Push(tsk);
} }
// use the master thread to run task 0
if (exclude_worker0_) {
TVMParallelGroupEnv* penv = &(tsk.launcher->env);
if ((*tsk.launcher->flambda)(0, penv, cdata) == 0) {
tsk.launcher->SignalJobFinish();
} else {
tsk.launcher->SignalJobError(tsk.task_id);
}
}
int res = launcher->WaitForJobs(); int res = launcher->WaitForJobs();
return res; return res;
} }
...@@ -320,6 +314,8 @@ class ThreadPool { ...@@ -320,6 +314,8 @@ class ThreadPool {
} }
} }
int num_workers_; int num_workers_;
// if excluding worker 0 and using master to run task 0
bool exclude_worker0_{true};
std::vector<std::unique_ptr<SpscTaskQueue> > queues_; std::vector<std::unique_ptr<SpscTaskQueue> > queues_;
std::unique_ptr<tvm::runtime::threading::ThreadGroup> threads_; std::unique_ptr<tvm::runtime::threading::ThreadGroup> threads_;
}; };
......
...@@ -29,7 +29,7 @@ class ThreadGroup::Impl { ...@@ -29,7 +29,7 @@ class ThreadGroup::Impl {
const char *val = getenv("TVM_BIND_THREADS"); const char *val = getenv("TVM_BIND_THREADS");
if (val == nullptr || atoi(val) == 1) { if (val == nullptr || atoi(val) == 1) {
if (num_workers_ <= std::thread::hardware_concurrency()) { if (num_workers_ <= std::thread::hardware_concurrency()) {
SetAffinity(); SetAffinity(exclude_worker0);
} else { } else {
LOG(WARNING) LOG(WARNING)
<< "The thread affinity cannot be set when the number of workers" << "The thread affinity cannot be set when the number of workers"
...@@ -47,7 +47,9 @@ class ThreadGroup::Impl { ...@@ -47,7 +47,9 @@ class ThreadGroup::Impl {
private: private:
// bind worker threads to disjoint cores // bind worker threads to disjoint cores
void SetAffinity() { // if worker 0 is offloaded to master, i.e. exclude_worker0 is true,
// the master thread is bound to core 0.
void SetAffinity(bool exclude_worker0) {
#if defined(__ANDROID__) #if defined(__ANDROID__)
#ifndef CPU_SET #ifndef CPU_SET
#define CPU_SETSIZE 1024 #define CPU_SETSIZE 1024
...@@ -62,19 +64,27 @@ class ThreadGroup::Impl { ...@@ -62,19 +64,27 @@ class ThreadGroup::Impl {
memset((cpusetp), 0, sizeof(cpu_set_t)) memset((cpusetp), 0, sizeof(cpu_set_t))
#endif #endif
#endif #endif
for (unsigned i=0; i < threads_.size(); ++i) {
#if defined(__linux__) || defined(__ANDROID__) #if defined(__linux__) || defined(__ANDROID__)
for (unsigned i = 0; i < threads_.size(); ++i) {
unsigned core_id = i + exclude_worker0;
cpu_set_t cpuset; cpu_set_t cpuset;
CPU_ZERO(&cpuset); CPU_ZERO(&cpuset);
CPU_SET(i, &cpuset); CPU_SET(core_id, &cpuset);
#if defined(__ANDROID__) #if defined(__ANDROID__)
sched_setaffinity(threads_[i].native_handle(), sizeof(cpu_set_t), &cpuset); sched_setaffinity(threads_[i].native_handle(), sizeof(cpu_set_t), &cpuset);
#else #else
pthread_setaffinity_np(threads_[i].native_handle(), pthread_setaffinity_np(threads_[i].native_handle(),
sizeof(cpu_set_t), &cpuset); sizeof(cpu_set_t), &cpuset);
#endif #endif
#endif
} }
if (exclude_worker0) { // bind the master thread to core 0
cpu_set_t cpuset;
CPU_ZERO(&cpuset);
CPU_SET(0, &cpuset);
pthread_setaffinity_np(pthread_self(),
sizeof(cpu_set_t), &cpuset);
}
#endif
} }
int num_workers_; int num_workers_;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment