/*!
 *  Copyright (c) 2017 by Contributors
 * \file thread_pool.cc
 * \brief Threadpool for multi-threading runtime.
 */
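
// What this file defines (summary of the code below):
//  - ParallelLauncher: thread-local state for one parallel launch request
//    (pending-task count, per-task errors, sync-counter page).
//  - SpscTaskQueue: lock-free single-producer single-consumer queue, one per worker.
//  - ThreadPool: owns the worker threads and their task queues.
//  - TVMBackendParallelLaunch / TVMBackendParallelBarrier: the C entry points
//    called by generated code.
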
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/c_backend_api.h>
#include <tvm/runtime/threading_backend.h>
#include <dmlc/thread_local.h>
#include <dmlc/logging.h>
#include <thread>
#include <condition_variable>
#include <mutex>
#include <atomic>
#include <algorithm>
#include <vector>
#include <string>
#include <cstring>
#include <memory>
#include <sstream>

// L1 cache line size in bytes, used to pad and stride atomics apart.
constexpr int kL1CacheBytes = 64;

namespace tvm {
namespace runtime {

// stride between per-task sync counters, so each counter sits on its own cache line
constexpr int kSyncStride = 64 / sizeof(std::atomic<int>);
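// Sanity check of the arithmetic (assuming 4-byte std::atomic<int> and 64-byte
// cache lines): kSyncStride = 64 / 4 = 16 elements = 64 bytes, so the counters
// of adjacent tasks never share a cache line.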

/*!
 * \brief Thread-local master environment for one parallel launch.
 *
 *  The master thread calls Init() to set up a request, each task reports back
 *  through SignalJobFinish() or SignalJobError(), and the master blocks in
 *  WaitForJobs() until every task has reported.
 */
class ParallelLauncher {
 public:
  // Reset the task request.
  void Init(FTVMParallelLambda flambda,
            void* cdata,
            int num_task,
            bool need_sync) {
    num_pending_.store(num_task);
    this->cdata = cdata;
    this->flambda = flambda;
    this->env.num_task = num_task;
    has_error_.store(false);
    // grow the per-task error (and sync counter) storage if this request has more tasks than before
    if (static_cast<size_t>(num_task) > par_errors_.size()) {
      par_errors_.resize(num_task + 1);
      if (need_sync) {
        delete[] sync_counter_;
        sync_counter_ = new std::atomic<int>[num_task * kSyncStride];
      }
    }
    if (need_sync) {
      for (int i = 0; i < num_task; ++i) {
        sync_counter_[i * kSyncStride].store(
            0, std::memory_order_relaxed);
      }
      this->env.sync_handle = sync_counter_;
    } else {
      this->env.sync_handle = nullptr;
    }
  }
  ~ParallelLauncher() {
    delete[] sync_counter_;
  }
  // Wait for all n jobs to finish.
  int WaitForJobs() {
    while (num_pending_.load() != 0) {
      tvm::runtime::threading::Yield();
    }
    if (!has_error_.load()) return 0;
    std::string err("");
    for (size_t i = 0; i < par_errors_.size(); ++i) {
      if (par_errors_[i].length() != 0) {
        err += "Task " + std::to_string(i) + " error: " + par_errors_[i] + '\n';
        par_errors_[i].clear();
      }
    }
    TVMAPISetLastError(err.c_str());
    return -1;
  }
  // Signal that one job has failed with an error.
  void SignalJobError(int task_id) {
    // Record the error before decrementing the pending count, so the master
    // cannot observe num_pending_ == 0 while the message is still being written.
    par_errors_[task_id] = TVMGetLastError();
    has_error_.store(true);
    num_pending_.fetch_sub(1);
  }
  // Signal that one job has finished.
  void SignalJobFinish() {
    num_pending_.fetch_sub(1);
  }
  // Get thread local version of the store.
  static ParallelLauncher* ThreadLocal() {
    return dmlc::ThreadLocalStore<ParallelLauncher>::Get();
  }
  // The parallel lambda
  FTVMParallelLambda flambda;
  // The closure data
  void* cdata;
  // Local env
  TVMParallelGroupEnv env;
  // Whether this thread is a worker of the pool;
  // used to prevent recursive launches.
  bool is_worker{false};

 private:
  // The number of pending jobs.
  std::atomic<int32_t> num_pending_;
  // Whether an error has been encountered.
  std::atomic<bool> has_error_;
  // The sync counter page.
  std::atomic<int32_t>* sync_counter_{nullptr};
  // The per-task error messages.
  std::vector<std::string> par_errors_;
};

/*! \brief Lock-free single-producer-single-consumer queue for each thread */
class SpscTaskQueue {
 public:
  /*! \brief The task entry */
  struct Task {
    ParallelLauncher* launcher;
    int32_t task_id;
  };

  SpscTaskQueue() :
    buffer_(new Task[kRingSize]),
    head_(0),
    tail_(0) {
  }

  ~SpscTaskQueue() {
    delete[] buffer_;
  }

  /*!
   * \brief Push a task into the queue and notify the consumer if it is waiting.
   * \param input The task to be enqueued.
   */
  void Push(const Task& input) {
    while (!Enqueue(input)) {
      tvm::runtime::threading::Yield();
    }
    // fetch_add returning -1 means the consumer saw an empty queue and is
    // (or is about to be) sleeping on cv_, so wake it up.
    if (pending_.fetch_add(1) == -1) {
      std::unique_lock<std::mutex> lock(mutex_);
      cv_.notify_one();
    }
  }
  /*!
   * \brief Pop a task out of the queue; condition-wait if there are no tasks.
   * \param output The pointer to the task to be dequeued.
   * \param spin_count The number of iterations to spin before sleeping.
   * \return Whether the pop succeeded (true) or the worker should exit now (false).
   */
  bool Pop(Task* output, uint32_t spin_count = 300000) {
    // Busy-wait a bit when the queue is empty.
    // If a new task arrives quickly, this spin keeps the worker from going to sleep.
    // The default spin count follows the typical OpenMP convention.
    for (uint32_t i = 0; i < spin_count && pending_.load() == 0; ++i) {
      tvm::runtime::threading::Yield();
    }
    // fetch_sub returning 0 means the queue was empty; sleep until the producer
    // pushes a task (pending_ becomes >= 0 again) or a kill signal arrives.
    if (pending_.fetch_sub(1) == 0) {
      std::unique_lock<std::mutex> lock(mutex_);
      cv_.wait(lock, [this] {
          return pending_.load() >= 0 || exit_now_.load();
        });
    }
    if (exit_now_.load(std::memory_order_relaxed)) {
      return false;
    }
    const uint32_t head = head_.load(std::memory_order_relaxed);
    // sanity check that the queue is not empty
    CHECK(tail_.load(std::memory_order_acquire) != head);
    *output = buffer_[head];
    head_.store((head + 1) % kRingSize, std::memory_order_release);
    return true;
  }

  /*!
   * \brief Signal to terminate the worker.
   */
  void SignalForKill() {
    std::lock_guard<std::mutex> lock(mutex_);
    exit_now_.store(true);
    cv_.notify_all();
  }

 protected:
  /*!
   * \brief Lock-free enqueue.
   * \param input The task to be enqueued.
   * \return Whether the task was enqueued.
   */
  bool Enqueue(const Task& input) {
    if (exit_now_.load(std::memory_order_relaxed)) return false;
    const uint32_t tail = tail_.load(std::memory_order_relaxed);
    // One slot is always left empty so that head_ == tail_ unambiguously means "empty".
    if ((tail + 1) % kRingSize != (head_.load(std::memory_order_acquire))) {
      buffer_[tail] = input;
      tail_.store((tail + 1) % kRingSize, std::memory_order_release);
      return true;
    }
    return false;
  }

  // the cache line paddings are used to avoid false sharing between atomic variables
  typedef char cache_line_pad_t[kL1CacheBytes];
  cache_line_pad_t pad0_;
  // size of the ring buffer; the queue can hold at most kRingSize - 1 items
  // defined as a constant for better compiler optimization
  static constexpr const int kRingSize = 2;
  // pointer to access the item
  Task* const buffer_;

  cache_line_pad_t pad1_;
  // queue head, where one gets a task from the queue
  std::atomic<uint32_t> head_;

  cache_line_pad_t pad2_;
  // queue tail, where one puts a task into the queue
  std::atomic<uint32_t> tail_;

  cache_line_pad_t pad3_;
  // pending tasks in the queue; signed so it can legitimately drop to -1
  // when the consumer polls an empty queue in Pop()
  std::atomic<int8_t> pending_{0};

  cache_line_pad_t pad4_;
  // signal for exit now
  std::atomic<bool> exit_now_{false};

  // internal mutex
  std::mutex mutex_;
  // cv for consumer
  std::condition_variable cv_;
};

// The thread pool
class ThreadPool {
 public:
  ThreadPool(): num_workers_(tvm::runtime::threading::MaxConcurrency()) {
    for (int i = 0; i < num_workers_; ++i) {
      // The SpscTaskQueue only hosts ONE item at a time
      queues_.emplace_back(std::unique_ptr<SpscTaskQueue>(new SpscTaskQueue()));
    }
    threads_ = std::unique_ptr<tvm::runtime::threading::ThreadGroup>(
        new tvm::runtime::threading::ThreadGroup(
          num_workers_, [this](int worker_id) { this->RunWorker(worker_id); },
          exclude_worker0_ /* include_main_thread */));
  }
  ~ThreadPool() {
    for (std::unique_ptr<SpscTaskQueue>& q : queues_) {
      q->SignalForKill();
    }
    threads_.reset();
  }
  int Launch(FTVMParallelLambda flambda,
             void* cdata,
             int num_task,
             int need_sync) {
    ParallelLauncher* launcher = ParallelLauncher::ThreadLocal();
    CHECK(!launcher->is_worker)
        << "Cannot launch parallel job inside worker, consider fuse then parallel";
    if (num_task == 0) {
      num_task = num_workers_;
    }
    if (need_sync != 0) {
      CHECK_LE(num_task, num_workers_)
          << "Request parallel sync task larger than number of threads available "
          << " workers=" << num_workers_ << " request=" << num_task;
    }
    launcher->Init(flambda, cdata, num_task, need_sync != 0);
    SpscTaskQueue::Task tsk;
    tsk.launcher = launcher;
    // if worker0 is taken by the master, queues_[0] is abandoned
    for (int i = exclude_worker0_; i < num_task; ++i) {
      tsk.task_id = i;
      queues_[i]->Push(tsk);
    }
    // use the master thread to run task 0
    if (exclude_worker0_) {
      TVMParallelGroupEnv* penv = &(tsk.launcher->env);
      if ((*tsk.launcher->flambda)(0, penv, cdata) == 0) {
        tsk.launcher->SignalJobFinish();
      } else {
        // the master ran task 0, so report the error against task 0
        tsk.launcher->SignalJobError(0);
      }
    }
    int res = launcher->WaitForJobs();
    return res;
  }

  static ThreadPool* ThreadLocal() {
    return dmlc::ThreadLocalStore<ThreadPool>::Get();
  }

 private:
  // Internal worker function.
  void RunWorker(int worker_id) {
    SpscTaskQueue* queue = queues_[worker_id].get();
    SpscTaskQueue::Task task;
    ParallelLauncher::ThreadLocal()->is_worker = true;
    while (queue->Pop(&task)) {
      CHECK(task.launcher != nullptr);
      TVMParallelGroupEnv* penv = &(task.launcher->env);
      void* cdata = task.launcher->cdata;
      if ((*task.launcher->flambda)(task.task_id, penv, cdata) == 0) {
        task.launcher->SignalJobFinish();
      } else {
        task.launcher->SignalJobError(task.task_id);
      }
    }
  }
  int num_workers_;
  // whether to exclude worker 0 and use the master thread to run task 0
#ifndef _LIBCPP_SGX_CONFIG
  bool exclude_worker0_{true};
#else
  bool exclude_worker0_{false};
#endif
  std::vector<std::unique_ptr<SpscTaskQueue> > queues_;
  std::unique_ptr<tvm::runtime::threading::ThreadGroup> threads_;
};

}  // namespace runtime
}  // namespace tvm

int TVMBackendParallelLaunch(
    FTVMParallelLambda flambda,
    void* cdata,
    int num_task) {
  int res = tvm::runtime::ThreadPool::ThreadLocal()->Launch(
      flambda, cdata, num_task, 1);
  return res;
}

int TVMBackendParallelBarrier(int task_id, TVMParallelGroupEnv* penv) {
  using tvm::runtime::kSyncStride;
  int num_task = penv->num_task;
  std::atomic<int>* sync_counter =
      reinterpret_cast<std::atomic<int>*>(penv->sync_handle);
  // Each task increments its own cache-line-padded counter, then spins until
  // every other task's counter has also advanced past the value this task
  // observed before its own increment.
  int old_counter = sync_counter[task_id * kSyncStride].fetch_add(
      1, std::memory_order_release);
  for (int i = 0; i < num_task; ++i) {
    if (i != task_id) {
      while (sync_counter[i * kSyncStride].load(
                 std::memory_order_relaxed) <= old_counter) {
        tvm::runtime::threading::Yield();
      }
    }
  }
  std::atomic_thread_fence(std::memory_order_acquire);
  return 0;
}
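
// Illustrative usage sketch (not part of the runtime): generated code supplies
// a parallel lambda of type FTVMParallelLambda and launches it through
// TVMBackendParallelLaunch; tasks launched with need_sync may call
// TVMBackendParallelBarrier to synchronize. The names `MyClosure` and
// `my_lambda` below are hypothetical.
//
//   struct MyClosure { float* data; int n; };
//
//   static int my_lambda(int task_id, TVMParallelGroupEnv* penv, void* cdata) {
//     MyClosure* c = static_cast<MyClosure*>(cdata);
//     int chunk = (c->n + penv->num_task - 1) / penv->num_task;
//     int begin = task_id * chunk;
//     int end = std::min(c->n, begin + chunk);
//     for (int i = begin; i < end; ++i) c->data[i] *= 2.0f;
//     TVMBackendParallelBarrier(task_id, penv);  // wait for sibling tasks
//     return 0;  // a non-zero return value reports an error
//   }
//
//   // MyClosure closure = {buffer, length};
//   // TVMBackendParallelLaunch(my_lambda, &closure, 0 /* use all workers */);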