/*!
 *  Copyright (c) 2017 by Contributors
 * \file thread_pool.cc
 * \brief Threadpool for multi-threading runtime.
 */
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/c_backend_api.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/threading_backend.h>
#include <dmlc/thread_local.h>
#include <dmlc/logging.h>
#include <thread>
#include <condition_variable>
#include <mutex>
#include <atomic>
#include <algorithm>
#include <vector>
#include <string>
#include <cstring>
#include <memory>
#include <sstream>

constexpr int kL1CacheBytes = 64;

namespace tvm {
namespace runtime {

// stride between per-task sync counters, chosen so each counter sits in its own cache line.
constexpr int kSyncStride = 64 / sizeof(std::atomic<int>);

/*!
 * \brief Thread local master environment.
 */
class ParallelLauncher {
 public:
  // Reset the task request.
  void Init(FTVMParallelLambda flambda,
            void* cdata,
            int num_task,
            bool need_sync) {
    num_pending_.store(num_task);
    this->cdata = cdata;
    this->flambda = flambda;
    this->env.num_task = num_task;
    has_error_.store(false);
    // grow the per-task error and sync-counter storage if more tasks are requested
    if (static_cast<size_t>(num_task) > par_errors_.size()) {
      par_errors_.resize(num_task + 1);
      if (need_sync) {
        delete[] sync_counter_;
        sync_counter_ = new std::atomic<int>[num_task * kSyncStride];
      }
    }
    if (need_sync) {
      for (int i = 0; i < num_task; ++i) {
        sync_counter_[i * kSyncStride].store(
            0, std::memory_order_relaxed);
      }
      this->env.sync_handle = sync_counter_;
    } else {
      this->env.sync_handle = nullptr;
    }
  }
  ~ParallelLauncher() {
    delete[] sync_counter_;
  }
  // Wait n jobs to finish
  int WaitForJobs() {
    while (num_pending_.load() != 0) {
      tvm::runtime::threading::Yield();
    }
    if (!has_error_.load()) return 0;
    // the following intentionally uses std::string due to a
    // security issue raised in the SGX backend
    std::string err("");
    for (size_t i = 0; i < par_errors_.size(); ++i) {
      if (par_errors_[i].length() != 0) {
        err += "Task " + std::to_string(i) + " error: " + par_errors_[i] + '\n';
        par_errors_[i].clear();
      }
    }
    TVMAPISetLastError(err.c_str());
    return -1;
  }
  // Signal that one job has finished with an error.
  void SignalJobError(int task_id) {
    num_pending_.fetch_sub(1);
    par_errors_[task_id] = TVMGetLastError();
    has_error_.store(true);
  }
  // Signal that one job has finished.
  void SignalJobFinish() {
    num_pending_.fetch_sub(1);
  }
  // Get thread local version of the store.
  static ParallelLauncher* ThreadLocal() {
    return dmlc::ThreadLocalStore<ParallelLauncher>::Get();
  }
  // The parallel lambda
  FTVMParallelLambda flambda;
  // The closure data
  void* cdata;
  // Local env
  TVMParallelGroupEnv env;
  // Whether this thread is a worker of the pool.
  // Used to prevent recursive launch.
  bool is_worker{false};

 private:
  // The pending jobs.
  std::atomic<int32_t> num_pending_;
  // Whether an error has been encountered.
  std::atomic<bool> has_error_;
  // The counter page.
  std::atomic<int32_t>* sync_counter_{nullptr};
  // The error message
  std::vector<std::string> par_errors_;
};
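
// Typical lifecycle of the launcher above, as driven by ThreadPool::Launch below:
// Init() arms num_task pending jobs, each worker runs flambda exactly once and
// reports back via SignalJobFinish() or SignalJobError(), and the launching
// thread spins in WaitForJobs() until num_pending_ reaches zero, folding any
// per-task error strings into a single TVMAPISetLastError message.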

/*! \brief Lock-free single-producer-single-consumer queue for each thread */
class SpscTaskQueue {
 public:
  /*! \brief The task entry */
  struct Task {
    ParallelLauncher* launcher;
    int32_t task_id;
  };

  SpscTaskQueue() :
    buffer_(new Task[kRingSize]),
    head_(0),
    tail_(0) {
  }

  ~SpscTaskQueue() {
    delete[] buffer_;
  }

  /*!
   * \brief Push a task into the queue and notify the consumer if it is waiting.
   * \param input The task to be enqueued.
   */
  void Push(const Task& input) {
    while (!Enqueue(input)) {
      tvm::runtime::threading::Yield();
    }
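    // pending_ can only be -1 when the consumer decremented it from 0 in Pop()
    // and is about to block on cv_, so seeing -1 here means the consumer may be
    // sleeping and must be woken up under the mutex.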
    if (pending_.fetch_add(1) == -1) {
      std::unique_lock<std::mutex> lock(mutex_);
      cv_.notify_one();
    }
  }

  /*!
   * \brief Pop a task out of the queue; condition-wait if no task is available.
   * \param output The pointer to the task to be dequeued.
   * \param spin_count The number of iterations to spin before sleep.
   * \return Whether pop is successful (true) or we need to exit now (false).
   */
  bool Pop(Task* output, uint32_t spin_count = 300000) {
    // Busy wait a bit when the queue is empty.
    // If a new task arrives quickly, this wait keeps the worker from going to sleep.
    // The default spin count follows the typical OpenMP convention.
    for (uint32_t i = 0; i < spin_count && pending_.load() == 0; ++i) {
      tvm::runtime::threading::Yield();
    }
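    // If the queue is still empty after spinning (pending_ was 0), the counter
    // drops to -1 and this worker blocks on cv_ until Push() or SignalForKill()
    // wakes it up.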
    if (pending_.fetch_sub(1) == 0) {
      std::unique_lock<std::mutex> lock(mutex_);
      cv_.wait(lock, [this] {
          return pending_.load() >= 0 || exit_now_.load();
        });
    }
    if (exit_now_.load(std::memory_order_relaxed)) {
      return false;
    }
    const uint32_t head = head_.load(std::memory_order_relaxed);
    // sanity check if the queue is empty
    CHECK(tail_.load(std::memory_order_acquire) != head);
    *output = buffer_[head];
    head_.store((head + 1) % kRingSize, std::memory_order_release);
    return true;
  }

  /*!
   * \brief Signal to terminate the worker.
   */
  void SignalForKill() {
    std::lock_guard<std::mutex> lock(mutex_);
    exit_now_.store(true);
    cv_.notify_all();
  }

 protected:
  /*!
   * \brief Lock-free enqueue.
   * \param input The task to be enqueued.
   * \return Whether the task is enqueued.
   */
  bool Enqueue(const Task& input) {
    if (exit_now_.load(std::memory_order_relaxed)) return false;

    const uint32_t tail = tail_.load(std::memory_order_relaxed);
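    // The ring is full when advancing tail would collide with head; with
    // kRingSize == 2 each worker queue holds at most one task in flight.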

    if ((tail + 1) % kRingSize != (head_.load(std::memory_order_acquire))) {
      buffer_[tail] = input;
      tail_.store((tail + 1) % kRingSize, std::memory_order_release);
      return true;
    }
    return false;
  }

  // the cache line paddings are used to avoid false sharing between atomic variables
  typedef char cache_line_pad_t[kL1CacheBytes];
  cache_line_pad_t pad0_;
  // size of the ring buffer; the queue can hold at most kRingSize - 1 items
  // define it as a constant for better compiler optimization
  static constexpr const int kRingSize = 2;
  // pointer to access the item
  Task* const buffer_;

  cache_line_pad_t pad1_;
  // queue head, where one gets a task from the queue
  std::atomic<uint32_t> head_;

  cache_line_pad_t pad2_;
  // queue tail, where one puts a task into the queue
  std::atomic<uint32_t> tail_;

  cache_line_pad_t pad3_;
  // pending tasks in the queue
  std::atomic<int8_t> pending_{0};

  cache_line_pad_t pad4_;
  // signal for exit now
  std::atomic<bool> exit_now_{false};

  // internal mutex
  std::mutex mutex_;
  // cv for consumer
  std::condition_variable cv_;
};

// The thread pool
class ThreadPool {
 public:
  ThreadPool(): num_workers_(tvm::runtime::threading::MaxConcurrency()) {
    for (int i = 0; i < num_workers_; ++i) {
      // The SpscTaskQueue only hosts ONE item at a time
      queues_.emplace_back(std::unique_ptr<SpscTaskQueue>(new SpscTaskQueue()));
    }
    threads_ = std::unique_ptr<tvm::runtime::threading::ThreadGroup>(
        new tvm::runtime::threading::ThreadGroup(
          num_workers_, [this](int worker_id) { this->RunWorker(worker_id); },
          exclude_worker0_ /* include_main_thread */));
    num_workers_used_ = threads_->Configure(threading::ThreadGroup::kBig, 0, exclude_worker0_);
  }
  ~ThreadPool() {
    for (std::unique_ptr<SpscTaskQueue>& q : queues_) {
      q->SignalForKill();
    }
    threads_.reset();
  }
  int Launch(FTVMParallelLambda flambda,
             void* cdata,
             int num_task,
             int need_sync) {
    ParallelLauncher* launcher = ParallelLauncher::ThreadLocal();
    CHECK(!launcher->is_worker)
        << "Cannot launch parallel job inside worker, consider fuse then parallel";
    if (num_task == 0) {
      num_task = num_workers_used_;
    }
    if (need_sync != 0) {
      CHECK_LE(num_task, num_workers_used_)
          << "Requested number of sync tasks exceeds the number of workers in use:"
          << " workers=" << num_workers_used_ << " request=" << num_task;
    }
    launcher->Init(flambda, cdata, num_task, need_sync != 0);
    SpscTaskQueue::Task tsk;
    tsk.launcher = launcher;
    // if worker 0 is run by the master thread, queues_[0] is left unused
    for (int i = exclude_worker0_; i < num_task; ++i) {
      tsk.task_id = i;
      queues_[i]->Push(tsk);
    }
    // use the master thread to run task 0
    if (exclude_worker0_) {
      TVMParallelGroupEnv* penv = &(tsk.launcher->env);
      if ((*tsk.launcher->flambda)(0, penv, cdata) == 0) {
        tsk.launcher->SignalJobFinish();
      } else {
        tsk.launcher->SignalJobError(tsk.task_id);
      }
    }
    int res = launcher->WaitForJobs();
    return res;
  }

  static ThreadPool* ThreadLocal() {
    return dmlc::ThreadLocalStore<ThreadPool>::Get();
  }

  void UpdateWorkerConfiguration(threading::ThreadGroup::AffinityMode mode, int nthreads) {
    // this will also reset the affinity of the ThreadGroup
    // may use fewer than the MaxConcurrency number of workers
    num_workers_used_ = threads_->Configure(mode, nthreads,
                                            exclude_worker0_);
    // if MaxConcurrency restricted the number of workers (e.g., due to
    // hyperthreading), respect the restriction
    num_workers_used_ = std::min(num_workers_, num_workers_used_);
  }

 private:
  // Internal worker function.
  void RunWorker(int worker_id) {
    SpscTaskQueue* queue = queues_[worker_id].get();
    SpscTaskQueue::Task task;
    ParallelLauncher::ThreadLocal()->is_worker = true;
    while (queue->Pop(&task)) {
      CHECK(task.launcher != nullptr);
      TVMParallelGroupEnv* penv = &(task.launcher->env);
      void* cdata = task.launcher->cdata;
      if ((*task.launcher->flambda)(task.task_id, penv, cdata) == 0) {
        task.launcher->SignalJobFinish();
      } else {
        task.launcher->SignalJobError(task.task_id);
      }
    }
  }
  int num_workers_;
  // number of workers used (can be restricted with affinity pref)
  int num_workers_used_;
  // whether to exclude worker 0 and use the master thread to run task 0
#ifndef _LIBCPP_SGX_CONFIG
  bool exclude_worker0_{true};
#else
  bool exclude_worker0_{false};
#endif
  std::vector<std::unique_ptr<SpscTaskQueue> > queues_;
  std::unique_ptr<tvm::runtime::threading::ThreadGroup> threads_;
};

TVM_REGISTER_GLOBAL("runtime.config_threadpool")
.set_body([](TVMArgs args, TVMRetValue* rv) {
    threading::ThreadGroup::AffinityMode mode =
        static_cast<threading::ThreadGroup::AffinityMode>(static_cast<int>(args[0]));
    int nthreads = args[1];
    ThreadPool::ThreadLocal()->UpdateWorkerConfiguration(mode, nthreads);
});
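
// A minimal sketch of how the global registered above can be invoked from C++
// through the PackedFunc registry (tvm/runtime/registry.h is already included);
// the affinity mode and thread count used here are illustrative only:
//
//   const tvm::runtime::PackedFunc* cfg =
//       tvm::runtime::Registry::Get("runtime.config_threadpool");
//   if (cfg != nullptr) {
//     (*cfg)(static_cast<int>(tvm::runtime::threading::ThreadGroup::kBig), 4);
//   }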


}  // namespace runtime
}  // namespace tvm

int TVMBackendParallelLaunch(
    FTVMParallelLambda flambda,
    void* cdata,
    int num_task) {
  int res = tvm::runtime::ThreadPool::ThreadLocal()->Launch(
      flambda, cdata, num_task, 1);
  return res;
}
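
// A minimal sketch of the calling convention expected by TVMBackendParallelLaunch,
// as generated code would use it. `WorkloadData` and `parallel_fill` are
// hypothetical names used only for illustration:
//
//   struct WorkloadData { float* out; int n; };
//
//   static int parallel_fill(int task_id, TVMParallelGroupEnv* penv, void* cdata) {
//     WorkloadData* d = static_cast<WorkloadData*>(cdata);
//     // split [0, n) into penv->num_task roughly equal chunks
//     int chunk = (d->n + penv->num_task - 1) / penv->num_task;
//     int begin = task_id * chunk;
//     int end = begin + chunk > d->n ? d->n : begin + chunk;
//     for (int i = begin; i < end; ++i) d->out[i] = 1.0f;
//     return 0;  // a non-zero return is reported through SignalJobError
//   }
//
//   // num_task == 0 lets the pool pick the number of workers in use.
//   WorkloadData data = {out, n};
//   int err = TVMBackendParallelLaunch(parallel_fill, &data, 0);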

int TVMBackendParallelBarrier(int task_id, TVMParallelGroupEnv* penv) {
  using tvm::runtime::kSyncStride;
  int num_task = penv->num_task;
  std::atomic<int>* sync_counter =
      reinterpret_cast<std::atomic<int>*>(penv->sync_handle);
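  // Each task owns one counter, spaced kSyncStride elements apart so the
  // counters sit on separate cache lines. A task announces its arrival by
  // incrementing its own counter, then spins until every other task's counter
  // has advanced past the value its own counter held before this barrier.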
  int old_counter = sync_counter[task_id * kSyncStride].fetch_add(
      1, std::memory_order_release);
  for (int i = 0; i < num_task; ++i) {
    if (i != task_id) {
      while (sync_counter[i * kSyncStride].load(
                 std::memory_order_relaxed) <= old_counter) {
        tvm::runtime::threading::Yield();
      }
    }
  }
  std::atomic_thread_fence(std::memory_order_acquire);
  return 0;
}