Commit 91515322 by Yida Wang, committed by Tianqi Chen

[RUNTIME] Better scalability for multi-thread parallelization of CPUs (#971)

parent 537b70e4
@@ -17,6 +17,11 @@
 #include <cstring>
 #include <memory>
 #include <sstream>
+#if defined(__linux__)
+#include <sched.h>
+#endif
+
+const constexpr int kL1CacheBytes = 64;
 
 namespace tvm {
 namespace runtime {
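Note: kL1CacheBytes is used below to pad the queue's atomic fields onto separate cache lines, so the producer-owned and consumer-owned indices do not share a line and ping-pong it between cores (false sharing). A minimal standalone sketch of the same idea, not part of the patch (the name PaddedCounters and the use of alignas are illustrative; the patch itself uses explicit char-array pads):

#include <atomic>
#include <cstdint>

constexpr int kCacheLineBytes = 64;  // assumed L1 cache line size, as in the patch

// Illustrative only: each atomic gets its own cache line, so the producer
// updating `tail` never invalidates the line that holds `head`.
struct PaddedCounters {
  alignas(kCacheLineBytes) std::atomic<uint32_t> head{0};  // consumer side
  alignas(kCacheLineBytes) std::atomic<uint32_t> tail{0};  // producer side
};

static_assert(sizeof(PaddedCounters) >= 2 * kCacheLineBytes,
              "each counter occupies a separate cache line");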
@@ -127,99 +132,124 @@ class ParallelLauncher {
   std::vector<std::string> par_errors_;
 };
 
-/*! \brief Working queue for each thread */
-class ParallelTaskQueue {
+/*! \brief Lock-free single-producer-single-consumer queue for each thread */
+class SpscTaskQueue {
  public:
   /*! \brief The task entry */
   struct Task {
     ParallelLauncher* launcher;
     int32_t task_id;
   };
-  ParallelTaskQueue() {
-    ring_.resize(2);
+
+  SpscTaskQueue() :
+    buffer_(new Task[kRingSize]),
+    head_(0),
+    tail_(0) {
   }
-  /*!
-   * \brief Signal to kill the job.
-   */
-  void SignalForKill() {
-    std::lock_guard<std::mutex> lock(mutex_);
-    exit_now_.store(true);
-    cv_.notify_all();
+
+  ~SpscTaskQueue() {
+    delete[] buffer_;
   }
+
   /*!
-   * \brief Push task into the queue.
-   * \param task The task to be pushed.
+   * \brief Push a task into the queue and notify the consumer if it is waiting.
+   * \param input The task to be enqueued.
    */
-  void Push(Task task) {
-    std::unique_lock<std::mutex> lock(mutex_);
-    if (num_pending_ < ring_.size()) {
-      CHECK_NE(ring_.size(), 0U);
-      ring_[(head_ + num_pending_) % ring_.size()] = task;
-      ++num_pending_;
-    } else {
-      size_t old_size = ring_.size();
-      ring_.resize(old_size * 2);
-      if (head_ + num_pending_ > old_size) {
-        // copy the ring overflow part into the tail.
-        size_t ncopy = head_ + num_pending_ - old_size;
-        memcpy(&ring_[0] + old_size, &ring_[0], ncopy * sizeof(Task));
-      }
-      ring_[(head_ + num_pending_) % ring_.size()] = task;
-      ++num_pending_;
-    }
-    if (nwait_consumer_ != 0) {
-      lock.unlock();
+  void Push(const Task& input) {
+    while (!Enqueue(input)) {
+      std::this_thread::yield();
+    }
+    if (pending_.fetch_add(1) == -1) {
+      std::unique_lock<std::mutex> lock(mutex_);
       cv_.notify_one();
     }
   }
+
   /*!
-   * \brief Pop task from the queue.
-   * \param task The task to be popped.
-   * \param timeout The number of cycles to spin before sleep.
-   * \return Whether pop is successful or we need to exit now.
+   * \brief Pop a task out of the queue; condition-wait if there is no task.
+   * \param output The pointer to the task to be dequeued.
+   * \param spin_count The number of iterations to spin before sleep.
+   * \return Whether pop is successful (true) or we need to exit now (false).
    */
-  bool Pop(Task* task, int timeout = 10) {
-    std::unique_lock<std::mutex> lock(mutex_);
-    if (num_pending_ != 0) {
-      *task = ring_[head_];
-      head_ = (head_ + 1) % ring_.size();
-      --num_pending_;
-      if (exit_now_.load()) return false;
-    } else {
-      lock.unlock();
-      // do a bit of spinning and busy waiting before sleep.
-      for (int i = 0; i < timeout && num_pending_ == 0; ++i) {
-        std::this_thread::yield();
-      }
-      lock.lock();
-      ++nwait_consumer_;
-      cv_.wait(lock, [this] {
-          return num_pending_ != 0 || exit_now_.load();
-        });
-      --nwait_consumer_;
-      *task = ring_[head_];
-      head_ = (head_ + 1) % ring_.size();
-      --num_pending_;
-      if (exit_now_.load()) return false;
+  bool Pop(Task* output, uint32_t spin_count = 300000) {
+    // Busy-wait a bit while the queue is empty.
+    // If a new task arrives quickly, this keeps the worker from going to sleep.
+    // The default spin count follows the typical OpenMP convention.
+    for (uint32_t i = 0; i < spin_count && pending_.load() == 0; ++i) {
+      std::this_thread::yield();
+    }
+    if (pending_.fetch_sub(1) == 0) {
+      std::unique_lock<std::mutex> lock(mutex_);
+      cv_.wait(lock, [this] {
+          return pending_.load() >= 0 || exit_now_.load();
+        });
+    }
+    if (exit_now_.load(std::memory_order_relaxed)) {
+      return false;
     }
+    const uint32_t head = head_.load(std::memory_order_relaxed);
+    // sanity check: the queue must not be empty at this point
+    CHECK(tail_.load(std::memory_order_acquire) != head);
+    *output = buffer_[head];
+    head_.store((head + 1) % kRingSize, std::memory_order_release);
     return true;
   }
 
+  /*!
+   * \brief Signal to terminate the worker.
+   */
+  void SignalForKill() {
+    std::lock_guard<std::mutex> lock(mutex_);
+    exit_now_.store(true);
+    cv_.notify_all();
+  }
+
  private:
-  // Number of elements in the queue
-  uint32_t num_pending_{0};
-  // Queue head
-  uint32_t head_{0};
-  // Number of consumers to wait for.
-  uint32_t nwait_consumer_{0};
+  /*!
+   * \brief Lock-free enqueue.
+   * \param input The task to be enqueued.
+   * \return Whether the task is enqueued.
+   */
+  bool Enqueue(const Task& input) {
+    const uint32_t tail = tail_.load(std::memory_order_relaxed);
+    if ((tail + 1) % kRingSize != (head_.load(std::memory_order_acquire))) {
+      buffer_[tail] = input;
+      tail_.store((tail + 1) % kRingSize, std::memory_order_release);
+      return true;
+    }
+    return false;
+  }
+
+  // cache line padding is used to avoid false sharing between the atomic variables
+  typedef char cache_line_pad_t[kL1CacheBytes];
+  cache_line_pad_t pad0_;
+  // size of the ring buffer; the queue can hold at most kRingSize - 1 items.
+  // defined as a constant for better compiler optimization
+  static constexpr const int kRingSize = 2;
+  // pointer to the ring buffer storage
+  Task* const buffer_;
+
+  cache_line_pad_t pad1_;
+  // queue head, where a task is taken out of the queue
+  std::atomic<uint32_t> head_;
+
+  cache_line_pad_t pad2_;
+  // queue tail, where a task is put into the queue
+  std::atomic<uint32_t> tail_;
+
+  cache_line_pad_t pad3_;
+  // number of pending tasks in the queue
+  std::atomic<int8_t> pending_{0};
+
+  cache_line_pad_t pad4_;
+  // signal for exit now
+  std::atomic<bool> exit_now_{false};
+
   // internal mutex
   std::mutex mutex_;
   // cv for consumer
   std::condition_variable cv_;
-  // signal for exit now
-  std::atomic<bool> exit_now_{false};
-  // The internal ring.
-  std::vector<Task> ring_;
 };
 
 // The thread pool
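Note on the wake-up handshake: pending_ counts queued tasks and doubles as the sleep signal. When the consumer decrements it from 0 to -1 it goes to wait on cv_; the producer's fetch_add(1) then returns -1, which tells it the consumer is asleep and a notify_one() is needed. Each worker owns one queue and the launcher pushes at most one task into it per parallel region, so a ring of kRingSize = 2 (capacity kRingSize - 1 = 1) is enough. A minimal usage sketch of the single-producer/single-consumer contract, not part of the patch (the helper SpscExample and the done flag are illustrative; in TVM the completion wait is done by ParallelLauncher::WaitForJobs()):

#include <atomic>
#include <thread>

void SpscExample(SpscTaskQueue* queue, ParallelLauncher* launcher) {
  std::atomic<int> done{0};
  // Exactly one consumer thread calls Pop().
  std::thread worker([&] {
    SpscTaskQueue::Task task;
    // Pop() spins briefly, then sleeps on cv_; it returns false only
    // after SignalForKill() has been called.
    while (queue->Pop(&task)) {
      // ... execute the task here ...
      done.fetch_add(1);
    }
  });
  // Exactly one producer thread calls Push().
  SpscTaskQueue::Task tsk;
  tsk.launcher = launcher;
  tsk.task_id = 0;
  queue->Push(tsk);            // wakes the worker if it went to sleep
  while (done.load() == 0) {}  // wait for completion (illustrative busy-wait)
  queue->SignalForKill();      // let the worker leave its loop
  worker.join();
}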
@@ -244,7 +274,7 @@ class ThreadPool {
     this->Init();
   }
   ~ThreadPool() {
-    for (std::unique_ptr<ParallelTaskQueue>& q : queues_) {
+    for (std::unique_ptr<SpscTaskQueue>& q : queues_) {
       q->SignalForKill();
     }
     for (std::thread& t : threads_) {
@@ -267,13 +297,14 @@ class ThreadPool {
           << " workers=" << num_workers_ << " request=" << num_task;
     }
     launcher->Init(flambda, cdata, num_task, need_sync != 0);
-    ParallelTaskQueue::Task tsk;
+    SpscTaskQueue::Task tsk;
     tsk.launcher = launcher;
     for (int i = 0; i < num_task; ++i) {
       tsk.task_id = i;
       queues_[i]->Push(tsk);
     }
-    return launcher->WaitForJobs();
+    int res = launcher->WaitForJobs();
+    return res;
   }
 
   static ThreadPool* Global() {
@@ -285,8 +316,9 @@ class ThreadPool {
   // Initialize the pool.
   void Init() {
     for (int i = 0; i < num_workers_; ++i) {
+      // each SpscTaskQueue holds at most one task at a time
       queues_.emplace_back(
-          std::unique_ptr<ParallelTaskQueue>(new ParallelTaskQueue()));
+          std::unique_ptr<SpscTaskQueue>(new SpscTaskQueue()));
     }
     threads_.resize(num_workers_);
     for (int i = 0; i < num_workers_; ++i) {
@@ -294,10 +326,20 @@ class ThreadPool {
           this->RunWorker(queues_[i].get());
         });
     }
+    const char *val = getenv("TVM_BIND_THREADS");
+    if (val == nullptr || atoi(val) == 1) {
+      if (num_workers_ <= std::thread::hardware_concurrency()) {
+        SetThreadAffinity();
+      } else {
+        LOG(WARNING)
+          << "The thread affinity cannot be set when the number of workers is larger "
+          << "than the number of available cores in the system.";
+      }
+    }
   }
   // Internal worker function.
-  void RunWorker(ParallelTaskQueue* queue) {
-    ParallelTaskQueue::Task task;
+  void RunWorker(SpscTaskQueue* queue) {
+    SpscTaskQueue::Task task;
     ParallelLauncher::ThreadLocal()->is_worker = true;
     while (queue->Pop(&task)) {
       CHECK(task.launcher != nullptr);
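Thread binding is on by default; it is skipped when there are more workers than hardware threads, and it can be turned off through the TVM_BIND_THREADS environment variable before the pool is first created. A hypothetical host-side snippet (illustrative only; setenv is POSIX) showing how an embedding application could opt out:

#include <cstdlib>

int main() {
  // The singleton ThreadPool reads TVM_BIND_THREADS when it is constructed,
  // so this must run before the first TVMBackendParallelLaunch() call.
  setenv("TVM_BIND_THREADS", "0", /*overwrite=*/1);
  // ... load and run the TVM module as usual ...
  return 0;
}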
@@ -310,9 +352,33 @@ class ThreadPool {
       }
     }
   }
+  // bind worker threads to disjoint cores
+  void SetThreadAffinity() {
+#if defined(__ANDROID__)
+#define CPU_SETSIZE 1024
+#define __NCPUBITS (8 * sizeof(uint64_t))
+    typedef struct {
+      uint64_t __bits[CPU_SETSIZE / __NCPUBITS];
+    } cpu_set_t;
+
+#define CPU_SET(cpu, cpusetp) \
+  ((cpusetp)->__bits[(cpu) / __NCPUBITS] |= (1UL << ((cpu) % __NCPUBITS)))
+#define CPU_ZERO(cpusetp) \
+  memset((cpusetp), 0, sizeof(cpu_set_t))
+#endif
+    for (int i = 0; i < num_workers_; ++i) {
+#if defined(__linux__) || defined(__ANDROID__)
+      cpu_set_t cpuset;
+      CPU_ZERO(&cpuset);
+      CPU_SET(i, &cpuset);
+      pthread_setaffinity_np(threads_[i].native_handle(),
+                             sizeof(cpu_set_t), &cpuset);
+#endif
+    }
+  }
   // Number of workers
   int num_workers_;
-  std::vector<std::unique_ptr<ParallelTaskQueue> > queues_;
+  std::vector<std::unique_ptr<SpscTaskQueue> > queues_;
   std::vector<std::thread> threads_;
 };
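The binding above pins worker i to core i via pthread_setaffinity_np on the std::thread's native handle, with a small cpu_set_t shim defined for Android. A standalone sketch of the same call for a plain Linux build, with the error code checked for illustration (PinToCore is a made-up helper; on glibc the function requires _GNU_SOURCE):

#include <pthread.h>
#include <sched.h>
#include <cstdio>
#include <thread>

// Pin a std::thread to one core so the scheduler stops migrating it.
void PinToCore(std::thread* t, int core_id) {
  cpu_set_t cpuset;
  CPU_ZERO(&cpuset);
  CPU_SET(core_id, &cpuset);
  int err = pthread_setaffinity_np(t->native_handle(),
                                   sizeof(cpu_set_t), &cpuset);
  if (err != 0) {
    std::fprintf(stderr, "pthread_setaffinity_np(core %d) failed: %d\n",
                 core_id, err);
  }
}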
@@ -323,8 +389,9 @@ int TVMBackendParallelLaunch(
     FTVMParallelLambda flambda,
     void* cdata,
     int num_task) {
-  return tvm::runtime::ThreadPool::Global()->Launch(
-      flambda, cdata, num_task, 1);
+  int res = tvm::runtime::ThreadPool::Global()->Launch(
+      flambda, cdata, num_task, 1);
+  return res;
 }
 
 int TVMBackendParallelBarrier(int task_id, TVMParallelGroupEnv* penv) {