Commit 7e7154f1 by eqy Committed by Tianqi Chen

[RUNTIME] Support setting CPU affinity (#1403)

parent 0be4384e
...@@ -44,6 +44,25 @@ class ThreadGroup { ...@@ -44,6 +44,25 @@ class ThreadGroup {
*/ */
void Join(); void Join();
enum AffinityMode : int {
kBig = 1,
kLittle = -1,
};
/*!
 * \brief Configure the CPU core affinity of the thread group.
 *
 * \param mode The preferred CPU type (kBig = 1 for big cores,
 *        kLittle = -1 for little cores).
 * \param nthreads The number of threads to use. 0 means no explicit count
 *        is requested and the backend chooses one itself (for example the
 *        number of cores matching \p mode, or the maximum concurrency).
 * \param exclude_worker0 Whether to use the main thread as a worker.
 *        If `true`, worker0 will not be launched in a new thread and
 *        `worker_callback` will only be called for values >= 1. This
 *        allows use of the main thread as a worker.
 *
 * \return The number of workers to use.
 */
int Configure(AffinityMode mode, int nthreads, bool exclude_worker0);
private: private:
Impl* impl_; Impl* impl_;
}; };
...@@ -58,6 +77,7 @@ void Yield(); ...@@ -58,6 +77,7 @@ void Yield();
*/ */
int MaxConcurrency(); int MaxConcurrency();
} // namespace threading } // namespace threading
} // namespace runtime } // namespace runtime
} // namespace tvm } // namespace tvm
......
...@@ -53,6 +53,13 @@ ThreadGroup::ThreadGroup(int num_workers, ...@@ -53,6 +53,13 @@ ThreadGroup::ThreadGroup(int num_workers,
bool exclude_worker0) bool exclude_worker0)
: impl_(new ThreadGroup::Impl(num_workers, worker_callback, exclude_worker0)) {} : impl_(new ThreadGroup::Impl(num_workers, worker_callback, exclude_worker0)) {}
void ThreadGroup::Join() {} void ThreadGroup::Join() {}
/*!
 * \brief Choose the number of workers for this backend.
 *
 * This implementation does not bind threads to cores: \p mode and
 * \p exclude_worker0 are accepted for interface compatibility only and
 * are not read.
 *
 * \param mode Unused here.
 * \param nthreads Requested number of threads; 0 requests the maximum.
 * \param exclude_worker0 Unused here.
 *
 * \return \p nthreads, or MaxConcurrency() when \p nthreads is 0 or
 *         larger than MaxConcurrency().
 */
int ThreadGroup::Configure(AffinityMode mode, int nthreads, bool exclude_worker0) {
  const int max_conc = MaxConcurrency();
  // Both "no preference" (0) and an over-sized request fall back to the
  // maximum concurrency reported by the backend.
  if (nthreads == 0 || nthreads > max_conc) {
    return max_conc;
  }
  return nthreads;
}
ThreadGroup::~ThreadGroup() { delete impl_; } ThreadGroup::~ThreadGroup() { delete impl_; }
void Yield() {} void Yield() {}
......
...@@ -5,6 +5,8 @@ ...@@ -5,6 +5,8 @@
*/ */
#include <tvm/runtime/c_runtime_api.h> #include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/c_backend_api.h> #include <tvm/runtime/c_backend_api.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/threading_backend.h> #include <tvm/runtime/threading_backend.h>
#include <dmlc/thread_local.h> #include <dmlc/thread_local.h>
#include <dmlc/logging.h> #include <dmlc/logging.h>
...@@ -250,6 +252,10 @@ class ThreadPool { ...@@ -250,6 +252,10 @@ class ThreadPool {
new tvm::runtime::threading::ThreadGroup( new tvm::runtime::threading::ThreadGroup(
num_workers_, [this](int worker_id) { this->RunWorker(worker_id); }, num_workers_, [this](int worker_id) { this->RunWorker(worker_id); },
exclude_worker0_ /* include_main_thread */)); exclude_worker0_ /* include_main_thread */));
num_workers_used_ = threads_->Configure(threading::ThreadGroup::kBig, 0, exclude_worker0_);
// if MaxConcurrency restricted the number of workers (e.g., due to
// hyperthreading), respect the restriction
num_workers_used_ = std::min(num_workers_, num_workers_used_);
} }
~ThreadPool() { ~ThreadPool() {
for (std::unique_ptr<SpscTaskQueue>& q : queues_) { for (std::unique_ptr<SpscTaskQueue>& q : queues_) {
...@@ -265,12 +271,12 @@ class ThreadPool { ...@@ -265,12 +271,12 @@ class ThreadPool {
CHECK(!launcher->is_worker) CHECK(!launcher->is_worker)
<< "Cannot launch parallel job inside worker, consider fuse then parallel"; << "Cannot launch parallel job inside worker, consider fuse then parallel";
if (num_task == 0) { if (num_task == 0) {
num_task = num_workers_; num_task = num_workers_used_;
} }
if (need_sync != 0) { if (need_sync != 0) {
CHECK_LE(num_task, num_workers_) CHECK_LE(num_task, num_workers_used_)
<< "Request parallel sync task larger than number of threads available " << "Request parallel sync task larger than number of threads used "
<< " workers=" << num_workers_ << " request=" << num_task; << " workers=" << num_workers_used_ << " request=" << num_task;
} }
launcher->Init(flambda, cdata, num_task, need_sync != 0); launcher->Init(flambda, cdata, num_task, need_sync != 0);
SpscTaskQueue::Task tsk; SpscTaskQueue::Task tsk;
...@@ -297,6 +303,16 @@ class ThreadPool { ...@@ -297,6 +303,16 @@ class ThreadPool {
return dmlc::ThreadLocalStore<ThreadPool>::Get(); return dmlc::ThreadLocalStore<ThreadPool>::Get();
} }
/*!
 * \brief Re-configure the thread group and refresh the number of workers used.
 *
 * \param mode Preferred CPU type, forwarded to the thread group.
 * \param nthreads Requested number of threads, forwarded to the thread group.
 */
void UpdateWorkerConfiguration(threading::ThreadGroup::AffinityMode mode, int nthreads) {
  // Configure() also resets the affinity of the ThreadGroup; the count it
  // reports may be lower than the MaxConcurrency number of workers.
  const int configured = threads_->Configure(mode, nthreads, exclude_worker0_);
  // If MaxConcurrency restricted the number of workers (e.g., due to
  // hyperthreading), respect that restriction by capping at num_workers_.
  num_workers_used_ = std::min(num_workers_, configured);
}
private: private:
// Internal worker function. // Internal worker function.
void RunWorker(int worker_id) { void RunWorker(int worker_id) {
...@@ -315,6 +331,8 @@ class ThreadPool { ...@@ -315,6 +331,8 @@ class ThreadPool {
} }
} }
int num_workers_; int num_workers_;
// number of workers used (can be restricted with affinity pref)
int num_workers_used_;
// if excluding worker 0 and using master to run task 0 // if excluding worker 0 and using master to run task 0
#ifndef _LIBCPP_SGX_CONFIG #ifndef _LIBCPP_SGX_CONFIG
bool exclude_worker0_{true}; bool exclude_worker0_{true};
...@@ -325,9 +343,20 @@ class ThreadPool { ...@@ -325,9 +343,20 @@ class ThreadPool {
std::unique_ptr<tvm::runtime::threading::ThreadGroup> threads_; std::unique_ptr<tvm::runtime::threading::ThreadGroup> threads_;
}; };
// Registers the global function "runtime.config_threadpool".
// args[0]: affinity mode as an integer (cast to ThreadGroup::AffinityMode).
// args[1]: requested number of threads.
// The calling thread's ThreadPool is re-configured with these values.
TVM_REGISTER_GLOBAL("runtime.config_threadpool")
.set_body([](TVMArgs args, TVMRetValue* rv) {
    using AffinityMode = threading::ThreadGroup::AffinityMode;
    const int mode_value = args[0];
    const int nthreads = args[1];
    const AffinityMode mode = static_cast<AffinityMode>(mode_value);
    ThreadPool::ThreadLocal()->UpdateWorkerConfiguration(mode, nthreads);
  });
} // namespace runtime } // namespace runtime
} // namespace tvm } // namespace tvm
int TVMBackendParallelLaunch( int TVMBackendParallelLaunch(
FTVMParallelLambda flambda, FTVMParallelLambda flambda,
void* cdata, void* cdata,
......
...@@ -7,6 +7,10 @@ ...@@ -7,6 +7,10 @@
#include <dmlc/logging.h> #include <dmlc/logging.h>
#include <thread> #include <thread>
#include <algorithm> #include <algorithm>
#if defined(__linux__) || defined(__ANDROID__)
#include <fstream>
#else
#endif
#if defined(__linux__) #if defined(__linux__)
#include <sched.h> #include <sched.h>
#endif #endif
...@@ -26,30 +30,49 @@ class ThreadGroup::Impl { ...@@ -26,30 +30,49 @@ class ThreadGroup::Impl {
for (int i = exclude_worker0; i < num_workers_; ++i) { for (int i = exclude_worker0; i < num_workers_; ++i) {
threads_.emplace_back([worker_callback, i] { worker_callback(i); }); threads_.emplace_back([worker_callback, i] { worker_callback(i); });
} }
InitSortedOrder();
}
// Destructor: delegates to Join() so that joinable worker threads are joined.
~Impl() { Join(); }
/*! \brief Join every worker thread that is still joinable. */
void Join() {
  for (std::thread& worker : threads_) {
    if (!worker.joinable()) {
      continue;  // already joined (or never started); nothing to wait for
    }
    worker.join();
  }
}
int Configure(AffinityMode mode, int nthreads, bool exclude_worker0) {
int num_workers_used = 0;
if (mode == kLittle) {
num_workers_used = little_count_;
} else if (mode == kBig) {
num_workers_used = big_count_;
} else {
// use default
num_workers_used = threading::MaxConcurrency();
}
// if a specific number was given, use that
if (nthreads) {
num_workers_used = nthreads;
}
const char *val = getenv("TVM_BIND_THREADS"); const char *val = getenv("TVM_BIND_THREADS");
if (val == nullptr || atoi(val) == 1) { if (val == nullptr || atoi(val) == 1) {
if (static_cast<size_t>(num_workers_) <= std::thread::hardware_concurrency()) { // Skip if sorted_order.size() is bigger than the number of workers (threads_)
SetAffinity(exclude_worker0); if (!(sorted_order_.size() > static_cast<unsigned int>(num_workers_))) {
SetAffinity(exclude_worker0, mode == kLittle);
} else { } else {
LOG(WARNING) LOG(WARNING)
<< "The thread affinity cannot be set when the number of workers" << "The thread affinity cannot be set when the number of workers"
<< "is larger than the number of available cores in the system."; << "is larger than the number of available cores in the system.";
} }
} }
} return num_workers_used;
~Impl() { Join(); }
void Join() {
for (auto& t : threads_) {
if (t.joinable()) t.join();
}
} }
private: private:
// bind worker threads to disjoint cores // bind worker threads to disjoint cores
// if worker 0 is offloaded to master, i.e. exclude_worker0 is true, // if worker 0 is offloaded to master, i.e. exclude_worker0 is true,
// the master thread is bound to core 0. // the master thread is bound to core 0.
void SetAffinity(bool exclude_worker0) { void SetAffinity(bool exclude_worker0, bool reverse = false) {
#if defined(__ANDROID__) #if defined(__ANDROID__)
#ifndef CPU_SET #ifndef CPU_SET
#define CPU_SETSIZE 1024 #define CPU_SETSIZE 1024
...@@ -65,8 +88,15 @@ class ThreadGroup::Impl { ...@@ -65,8 +88,15 @@ class ThreadGroup::Impl {
#endif #endif
#endif #endif
#if defined(__linux__) || defined(__ANDROID__) #if defined(__linux__) || defined(__ANDROID__)
CHECK_GE(sorted_order_.size(), num_workers_);
for (unsigned i = 0; i < threads_.size(); ++i) { for (unsigned i = 0; i < threads_.size(); ++i) {
unsigned core_id = i + exclude_worker0; unsigned core_id;
if (reverse) {
core_id = sorted_order_[sorted_order_.size() - (i + exclude_worker0) - 1];
} else {
core_id = sorted_order_[i + exclude_worker0];
}
cpu_set_t cpuset; cpu_set_t cpuset;
CPU_ZERO(&cpuset); CPU_ZERO(&cpuset);
CPU_SET(core_id, &cpuset); CPU_SET(core_id, &cpuset);
...@@ -80,7 +110,11 @@ class ThreadGroup::Impl { ...@@ -80,7 +110,11 @@ class ThreadGroup::Impl {
if (exclude_worker0) { // bind the master thread to core 0 if (exclude_worker0) { // bind the master thread to core 0
cpu_set_t cpuset; cpu_set_t cpuset;
CPU_ZERO(&cpuset); CPU_ZERO(&cpuset);
CPU_SET(0, &cpuset); if (reverse) {
CPU_SET(sorted_order_[sorted_order_.size() - 1], &cpuset);
} else {
CPU_SET(sorted_order_[0], &cpuset);
}
#if defined(__ANDROID__) #if defined(__ANDROID__)
sched_setaffinity(pthread_self(), sched_setaffinity(pthread_self(),
sizeof(cpu_set_t), &cpuset); sizeof(cpu_set_t), &cpuset);
...@@ -92,8 +126,52 @@ class ThreadGroup::Impl { ...@@ -92,8 +126,52 @@ class ThreadGroup::Impl {
#endif #endif
} }
/*!
 * \brief Fill sorted_order_ with core ids ordered by maximum frequency
 *        (highest first, ties broken by lower core id) and count the
 *        cores running at the highest (big_count_) and lowest
 *        (little_count_) frequency.
 *
 * On Linux/Android the maximum frequency of core i is read from
 * /sys/devices/system/cpu/cpu<i>/cpufreq/cpuinfo_max_freq. A core whose
 * file cannot be opened keeps frequency 0; a file that opens but cannot be
 * parsed yields -1. On other platforms every core gets frequency 0.
 *
 * When all cores report the same frequency they are all counted as big and
 * little_count_ stays 0. A warning is logged when some cores are neither
 * the fastest nor the slowest.
 */
void InitSortedOrder() {
  // hardware_concurrency() is allowed to return 0 when the value cannot be
  // determined. Assume a single core in that case so the frequency table is
  // never empty; reading its first/last element below would otherwise be
  // undefined behaviour.
  const unsigned int num_cores = std::max(1u, std::thread::hardware_concurrency());
  std::vector<std::pair<unsigned int, int64_t> > max_freqs;
  max_freqs.reserve(num_cores);
  for (unsigned int i = 0; i < num_cores; ++i) {
    int64_t cur_freq = 0;
#if defined(__linux__) || defined(__ANDROID__)
    std::ostringstream filepath;
    filepath << "/sys/devices/system/cpu/cpu" << i << "/cpufreq/cpuinfo_max_freq";
    std::ifstream ifs(filepath.str());
    if (!ifs.fail()) {
      if (!(ifs >> cur_freq)) {
        cur_freq = -1;  // file exists but holds no readable number
      }
      ifs.close();
    }
#endif
    max_freqs.emplace_back(i, cur_freq);
  }

  // Highest frequency first; equal frequencies keep ascending core id.
  auto fcmpbyfreq = [](const std::pair<unsigned int, int64_t>& a,
                       const std::pair<unsigned int, int64_t>& b) {
    return a.second == b.second ? a.first < b.first : a.second > b.second;
  };
  std::sort(max_freqs.begin(), max_freqs.end(), fcmpbyfreq);

  const int64_t big_freq = max_freqs.front().second;
  const int64_t little_freq = max_freqs.back().second;
  for (const auto& entry : max_freqs) {
    sorted_order_.push_back(entry.first);
    if (entry.second == big_freq) {
      big_count_++;
    }
    // Only count little cores when there really are two distinct classes.
    if (big_freq != little_freq && entry.second == little_freq) {
      little_count_++;
    }
  }
  if (big_count_ + little_count_ != static_cast<int>(sorted_order_.size())) {
    LOG(WARNING) << "more than two frequencies detected!";
  }
}
int num_workers_; int num_workers_;
std::vector<std::thread> threads_; std::vector<std::thread> threads_;
std::vector<unsigned int> sorted_order_;
int big_count_ = 0;
int little_count_ = 0;
}; };
ThreadGroup::ThreadGroup(int num_workers, ThreadGroup::ThreadGroup(int num_workers,
...@@ -103,6 +181,10 @@ ThreadGroup::ThreadGroup(int num_workers, ...@@ -103,6 +181,10 @@ ThreadGroup::ThreadGroup(int num_workers,
ThreadGroup::~ThreadGroup() { delete impl_; } ThreadGroup::~ThreadGroup() { delete impl_; }
void ThreadGroup::Join() { impl_->Join(); } void ThreadGroup::Join() { impl_->Join(); }
/*!
 * \brief Forward the affinity configuration request to the implementation.
 * \return The number of workers reported by the implementation.
 */
int ThreadGroup::Configure(AffinityMode mode, int nthreads, bool exclude_worker0) {
  const int num_workers_used = impl_->Configure(mode, nthreads, exclude_worker0);
  return num_workers_used;
}
void Yield() { void Yield() {
std::this_thread::yield(); std::this_thread::yield();
} }
...@@ -124,6 +206,7 @@ int MaxConcurrency() { ...@@ -124,6 +206,7 @@ int MaxConcurrency() {
return std::max(max_concurrency, 1); return std::max(max_concurrency, 1);
} }
} // namespace threading } // namespace threading
} // namespace runtime } // namespace runtime
} // namespace tvm } // namespace tvm
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment