Commit bfceafc7 by nhynes Committed by Tianqi Chen

Pluggable Thread Launching Mechanism (#991)

parent cbde86f9
......@@ -25,6 +25,7 @@
#include "../../src/runtime/module.cc"
#include "../../src/runtime/registry.cc"
#include "../../src/runtime/file_util.cc"
#include "../../src/runtime/threading_backend.cc"
#include "../../src/runtime/thread_pool.cc"
// NOTE: all the files after this are optional modules
......
......@@ -26,6 +26,7 @@ pkg_cflags := -std=c++11 -O2 -fPIC\
-I${TVM_ROOT}/dlpack/include\
-I.\
-DDMLC_LOG_STACK_TRACE=0\
-fmax-errors=4
pkg_ldflags := -L${TVM_ROOT}/lib
......@@ -40,7 +41,7 @@ enclave_cflags := -static -nostdinc\
-DDMLC_CXX11_THREAD_LOCAL=0\
$(enclave_include_paths)\
enclave_cxxflags := -nostdinc++ $(enclave_cflags)
enclave_cxxflags := -nostdinc++ $(enclave_cflags) -DTVM_SGX_MAX_CONCURRENCY=4
enclave_ldflags :=\
-Wl,--no-undefined -nostdlib -nodefaultlibs -nostartfiles -L$(SGX_SDK)/lib64\
......@@ -62,7 +63,7 @@ app_ldflags := -L$(SGX_SDK)/lib64\
all: lib/test_addone.signed.so bin/test_addone
# Build rule for all-in-one TVM package library
lib/tvm_runtime_pack.o: tvm_runtime_pack.cc
lib/tvm_runtime_pack.o: tvm_runtime_pack.cc lib/test_addone_t.o
@mkdir -p $(@D)
$(CXX) -c $< -o $@ $(pkg_cflags) $(pkg_ldflags) $(enclave_cxxflags) -g
......@@ -94,7 +95,7 @@ lib/test_addone.signed.so: lib/test_addone.so enclave_config.xml
# An app that runs the enclave
bin/test_addone: app.cc lib/test_addone_u.o
@mkdir -p $(@D)
$(CXX) $^ -o $@ $(app_cflags) $(app_ldflags)
$(CXX) $^ -o $@ $(app_cflags) $(app_ldflags) $(pkg_cflags) -g
# Debugging runtime pack built without SGX (c.f. howto_deploy/tvm_runtime_pack.cc)
lib/tvm_runtime_pack_nosgx.o: tvm_runtime_pack.cc
......@@ -104,7 +105,7 @@ lib/tvm_runtime_pack_nosgx.o: tvm_runtime_pack.cc
# Debugging binary that runs TVM without SGX
bin/addone_nosgx: enclave.cc lib/tvm_runtime_pack_nosgx.o lib/test_addone_sys.o
@mkdir -p $(@D)
$(CXX) $^ -o $@ $(pkg_cflags) $(pkg_ldflags) -g
$(CXX) $^ -o $@ $(pkg_cflags) $(pkg_ldflags) -g -lpthread
clean:
rm -rf lib bin
#include <cstdio>
#include <iostream>
#include "sgx_urts.h"
#include "sgx_eid.h"
#include "test_addone_u.h"
#include "../../sgx/runtime_u.cc"
#define TOKEN_FILENAME "bin/test_addone.token"
#define ENCLAVE_FILENAME "lib/test_addone.signed.so"
sgx_enclave_id_t global_eid = 0; // global EID shared by multiple threads
sgx_enclave_id_t tvm_sgx_eid;
typedef struct _sgx_errlist_t {
sgx_status_t err;
......@@ -80,7 +82,7 @@ int initialize_enclave(void)
/* Step 2: call sgx_create_enclave to initialize an enclave instance */
/* Debug Support: set 2nd parameter to 1 */
sgx_status = sgx_create_enclave(ENCLAVE_FILENAME, SGX_DEBUG_FLAG, &token, &updated, &global_eid, NULL);
sgx_status = sgx_create_enclave(ENCLAVE_FILENAME, SGX_DEBUG_FLAG, &token, &updated, &tvm_sgx_eid, NULL);
if (sgx_status != SGX_SUCCESS) {
print_error_message(sgx_status);
if (fp != NULL) fclose(fp);
......@@ -105,7 +107,7 @@ int initialize_enclave(void)
}
int SGX_CDECL main(int argc, char *argv[]) {
if(initialize_enclave() < 0){
if(initialize_enclave() < 0) {
printf("Failed to initialize enclave.\n");
return -1;
}
......@@ -113,12 +115,13 @@ int SGX_CDECL main(int argc, char *argv[]) {
/* Run TVM within the enclave */
int addone_status;
sgx_status_t sgx_status = SGX_ERROR_UNEXPECTED;
sgx_status = enclave_main(global_eid, &addone_status);
sgx_status = tvm_ecall_run_module(tvm_sgx_eid, nullptr, &addone_status);
if (sgx_status != SGX_SUCCESS) {
print_error_message(sgx_status);
}
sgx_destroy_enclave(global_eid);
tvm_ecall_shutdown(tvm_sgx_eid);
tvm::runtime::sgx::Shutdown();
sgx_destroy_enclave(tvm_sgx_eid);
if (addone_status == 1) {
printf("It works!");
......@@ -127,3 +130,9 @@ int SGX_CDECL main(int argc, char *argv[]) {
printf("It doesn't work.");
return -1;
}
extern "C" {
void ocall_println(const char* str) {
std::cout << "Enclave says: " << str << std::endl;
}
}
......@@ -6,6 +6,8 @@
#include <iostream>
#endif
extern void Shutdown();
/* This function mirrors the one in howto_deploy except without the iostream */
int Verify(tvm::runtime::Module mod, std::string fname) {
// Get the function from the module.
......@@ -43,9 +45,9 @@ int Verify(tvm::runtime::Module mod, std::string fname) {
extern "C" {
int enclave_main() {
void tvm_ecall_run_module(const void* tvm_args, void* tvm_return_value) {
tvm::runtime::Module mod_syslib = (*tvm::runtime::Registry::Get("module._GetSystemLib"))();
return Verify(mod_syslib, "addonesys");
*(int*)tvm_return_value = Verify(mod_syslib, "addonesys");
}
}
......
<EnclaveConfiguration>
<ProdID>0</ProdID>
<ISVSVN>0</ISVSVN>
<StackMaxSize>0x2000</StackMaxSize>
<HeapMaxSize>0x1000</HeapMaxSize>
<TCSNum>1</TCSNum>
<StackMaxSize>0x100000</StackMaxSize>
<HeapMaxSize>0x100000</HeapMaxSize>
<TCSNum>5</TCSNum>
<TCSPolicy>1</TCSPolicy>
<DisableDebug>0</DisableDebug>
<MiscSelect>0</MiscSelect>
......
......@@ -11,6 +11,8 @@ def prepare_test_libs(base_path):
A = tvm.placeholder((n,), name='A')
B = tvm.compute(A.shape, lambda *i: A(*i) + 1, name='B')
s = tvm.create_schedule(B.op)
s[B].parallel(s[B].op.axis[0])
print(tvm.lower(s, [A, B], simple_mode=True))
# Compile library in system library mode
fadd_syslib = tvm.build(s, [A, B], 'llvm --system-lib', name='addonesys')
......
enclave {
from "sgx_tstdc.edl" import sgx_thread_wait_untrusted_event_ocall, sgx_thread_set_untrusted_event_ocall, sgx_thread_setwait_untrusted_events_ocall, sgx_thread_set_multiple_untrusted_events_ocall;
from "../../sgx/tvm.edl" import *;
trusted {
public int enclave_main();
untrusted {
void ocall_println([in, string] const char *str);
};
};
......@@ -5,7 +5,11 @@
* Please refer to the Makefile (rule lib/tvm_runtime_pack.o) for how to build.
*
*/
#include "../../sgx/sgx_runtime.cc"
#ifdef _LIBCPP_SGX_CONFIG
#include "lib/test_addone_t.h"
#endif
#include "../../sgx/runtime_t.cc"
#ifndef _LIBCPP_SGX_CONFIG
#include "../../src/runtime/file_util.cc"
#endif
/*!
* Copyright (c) 2018 by Contributors
* \file threading_backend.h
* \brief Utilities for manipulating thread pool threads.
*/
#ifndef TVM_RUNTIME_THREADING_BACKEND_H_
#define TVM_RUNTIME_THREADING_BACKEND_H_
#include <functional>
#include <memory>
#include <vector>
namespace tvm {
namespace runtime {
namespace threading {
/*!
* \brief A platform-agnostic abstraction for managing a collection of
* thread pool threads.
*/
class ThreadGroup {
public:
class Impl;
/*!
* \brief Creates a collection of threads which run a provided function.
*
* \param num_workers The total number of worker threads in this group.
Includes main thread if `exclude_worker0 = true`
* \param worker_callback A callback which is run in its own thread.
Receives the worker_id as an argument.
* \param exclude_worker0 Whether to use the main thread as a worker.
* If `true`, worker0 will not be launched in a new thread and
* `worker_callback` will only be called for values >= 1. This
* allows use of the main thread as a worker.
*/
ThreadGroup(int num_workers,
std::function<void(int)> worker_callback,
bool exclude_worker0 = false);
~ThreadGroup();
/*!
* \brief Blocks until all non-main threads in the pool finish.
*/
void Join();
private:
Impl* impl_;
};
/*!
* \brief Platform-agnostic no-op.
*/
void Yield();
/*!
* \return the maximum number of effective workers for this system.
*/
int MaxConcurrency();
} // namespace threading
} // namespace runtime
} // namespace tvm
#endif // TVM_RUNTIME_THREADING_BACKEND_H_
......@@ -9,17 +9,15 @@
#include "../../src/runtime/module.cc"
#include "../../src/runtime/registry.cc"
#include "../../src/runtime/system_lib_module.cc"
#ifndef _LIBCPP_SGX_CONFIG
#include "../../src/runtime/threading_backend.cc"
#else
#include "threading_backend.cc"
#endif
#include "../../src/runtime/thread_pool.cc"
// dummy parallel runtime (for now)
int TVMBackendParallelLaunch(
FTVMParallelLambda flambda,
void* cdata,
int num_task) {
TVMParallelGroupEnv env = { nullptr /* sync_handle */, 1 /* num_task */ };
return flambda(0 /* task_id */, &env, cdata);
extern "C" {
void tvm_ecall_shutdown() {
tvm::runtime::ThreadPool::Global()->Shutdown();
}
int TVMBackendParallelBarrier(int task_id, TVMParallelGroupEnv* penv) {
return 0;
}
#include <tvm/runtime/threading_backend.h>
#include "../../src/runtime/threading_backend.cc"
#include <iostream>
extern sgx_enclave_id_t tvm_sgx_eid;
extern "C" {
sgx_status_t tvm_ecall_run_worker(sgx_enclave_id_t eid, const void* cb);
}
namespace tvm {
namespace runtime {
namespace sgx {
static std::unique_ptr<tvm::runtime::threading::ThreadGroup> sgx_thread_group;
extern "C" {
void tvm_ocall_thread_pool_launch(int num_tasks, void* cb) {
std::function<void(int)> runner = [cb](int _worker_id) {
sgx_status_t sgx_status = SGX_ERROR_UNEXPECTED;
sgx_status = tvm_ecall_run_worker(tvm_sgx_eid, cb);
CHECK(sgx_status == SGX_SUCCESS) << "SGX Error: " << sgx_status;
};
sgx_thread_group.reset(new tvm::runtime::threading::ThreadGroup(
num_tasks, runner, false /* include_main_thread */));
}
}
void Shutdown() {
sgx_thread_group->Join();
}
} // namespace sgx
} // namespace runtime
} // namespace tvm
/*!
* Copyright (c) 2018 by Contributors
* \file sgx/threading_backend.cc
* \brief SGX threading backend
*/
#include <tvm/runtime/threading_backend.h>
#include <dmlc/logging.h>
#include <sgx_edger8r.h>
#include <sgx_trts.h>
#include <atomic>
extern "C" {
sgx_status_t SGX_CDECL tvm_ocall_thread_pool_launch(int num_workers, void* cb);
}
#ifndef TVM_SGX_MAX_CONCURRENCY
#define TVM_SGX_MAX_CONCURRENCY 1
#endif
namespace tvm {
namespace runtime {
namespace threading {
class ThreadGroup::Impl {
public:
Impl(int num_workers, std::function<void(int)> worker_callback,
bool exclude_worker0)
: num_workers_(num_workers),
worker_callback_(worker_callback),
next_task_id_(exclude_worker0) {
CHECK(num_workers <= TVM_SGX_MAX_CONCURRENCY)
<< "Tried spawning more threads than allowed by TVM_SGX_MAX_CONCURRENCY.";
sgx_status_t sgx_status = SGX_ERROR_UNEXPECTED;
sgx_status = tvm_ocall_thread_pool_launch(num_workers, this);
CHECK(sgx_status == SGX_SUCCESS) << "SGX Error: " << sgx_status;
}
void RunTask() {
int task_id = next_task_id_++;
CHECK(task_id < num_workers_)
<< "More workers entered enclave than allowed by TVM_SGX_MAX_CONCURRENCY";
worker_callback_(task_id);
}
private:
int num_workers_;
std::function<void(int)> worker_callback_;
std::atomic<int> next_task_id_;
};
ThreadGroup::ThreadGroup(int num_workers,
std::function<void(int)> worker_callback,
bool exclude_worker0)
: impl_(new ThreadGroup::Impl(num_workers, worker_callback, exclude_worker0)) {}
void ThreadGroup::Join() {}
ThreadGroup::~ThreadGroup() { delete impl_; }
void Yield() {}
int MaxConcurrency() { return TVM_SGX_MAX_CONCURRENCY; }
extern "C" {
void tvm_ecall_run_worker(const void* impl) {
if (!sgx_is_within_enclave(impl, sizeof(ThreadGroup::Impl))) return;
((ThreadGroup::Impl*)impl)->RunTask();
}
}
} // namespace threading
} // namespace runtime
} // namespace tvm
enclave {
from "sgx_tstdc.edl" import *;
trusted {
public void tvm_ecall_run_module([user_check] const void* tvm_args,
[user_check] void* tvm_ret_value);
public void tvm_ecall_run_worker([user_check] const void* cb);
public void tvm_ecall_shutdown();
};
untrusted {
void tvm_ocall_thread_pool_launch(int num_workers, [user_check] void* cb);
};
};
......@@ -5,6 +5,7 @@
*/
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/c_backend_api.h>
#include <tvm/runtime/threading_backend.h>
#include <dmlc/thread_local.h>
#include <dmlc/logging.h>
#include <thread>
......@@ -17,9 +18,6 @@
#include <cstring>
#include <memory>
#include <sstream>
#if defined(__linux__)
#include <sched.h>
#endif
const constexpr int kL1CacheBytes = 64;
......@@ -73,14 +71,14 @@ class ParallelLauncher {
return num_pending_ == 0;
});
if (!has_error_) return 0;
std::ostringstream os;
std::string err("");
for (size_t i = 0; i < par_errors_.size(); ++i) {
if (par_errors_[i].length() != 0) {
os << "Task " << i << " error: " << par_errors_[i] << '\n';
err += "Task " + std::to_string(i) + " error: " + par_errors_[i] + '\n';
par_errors_[i].clear();
}
}
TVMAPISetLastError(os.str().c_str());
TVMAPISetLastError(err.c_str());
return -1;
}
// Signal that one job has finished.
......@@ -157,7 +155,7 @@ class SpscTaskQueue {
*/
void Push(const Task& input) {
while (!Enqueue(input)) {
std::this_thread::yield();
tvm::runtime::threading::Yield();
}
if (pending_.fetch_add(1) == -1) {
std::unique_lock<std::mutex> lock(mutex_);
......@@ -176,8 +174,8 @@ class SpscTaskQueue {
// If a new task comes to the queue quickly, this wait avoid the worker from sleeping.
// The default spin count is set by following the typical omp convention
for (uint32_t i = 0; i < spin_count && pending_.load() == 0; ++i) {
std::this_thread::yield();
}
tvm::runtime::threading::Yield();
}
if (pending_.fetch_sub(1) == 0) {
std::unique_lock<std::mutex> lock(mutex_);
cv_.wait(lock, [this] {
......@@ -211,6 +209,8 @@ class SpscTaskQueue {
* \return Whether the task is enqueued.
*/
bool Enqueue(const Task& input) {
if (exit_now_.load(std::memory_order_relaxed)) return false;
const uint32_t tail = tail_.load(std::memory_order_relaxed);
if ((tail + 1) % kRingSize != (head_.load(std::memory_order_acquire))) {
......@@ -255,32 +255,17 @@ class SpscTaskQueue {
// The thread pool
class ThreadPool {
public:
ThreadPool() {
const char *val = getenv("TVM_NUM_THREADS");
if (val == nullptr) {
val = getenv("OMP_NUM_THREADS");
}
if (val != nullptr) {
num_workers_ = atoi(val);
} else {
#if defined(_M_X64) || defined(__x86_64__)
// Half to not count hyper threading.
num_workers_ = std::thread::hardware_concurrency() / 2;
#else
num_workers_ = std::thread::hardware_concurrency();
#endif
}
num_workers_ = std::max(num_workers_, 1);
this->Init();
}
~ThreadPool() {
for (std::unique_ptr<SpscTaskQueue>& q : queues_) {
q->SignalForKill();
}
for (std::thread& t : threads_) {
t.join();
ThreadPool(): num_workers_(tvm::runtime::threading::MaxConcurrency()) {
for (int i = 0; i < num_workers_; ++i) {
// The SpscTaskQueue only host ONE item at a time
queues_.emplace_back(std::unique_ptr<SpscTaskQueue>(new SpscTaskQueue()));
}
threads_ = std::unique_ptr<tvm::runtime::threading::ThreadGroup>(
new tvm::runtime::threading::ThreadGroup(
num_workers_, [this](int worker_id) { this->RunWorker(worker_id); },
false /* include_main_thread */));
}
~ThreadPool() { Shutdown(); }
int Launch(FTVMParallelLambda flambda,
void* cdata,
int num_task,
......@@ -307,38 +292,22 @@ class ThreadPool {
return res;
}
void Shutdown() {
for (std::unique_ptr<SpscTaskQueue>& q : queues_) {
q->SignalForKill();
}
threads_.reset();
}
static ThreadPool* Global() {
static ThreadPool inst;
return &inst;
}
private:
// Initialize the pool.
void Init() {
for (int i = 0; i < num_workers_; ++i) {
// The SpscTaskQueue only host ONE item at a time
queues_.emplace_back(
std::unique_ptr<SpscTaskQueue>(new SpscTaskQueue()));
}
threads_.resize(num_workers_);
for (int i = 0; i < num_workers_; ++i) {
threads_[i] = std::thread([this, i] {
this->RunWorker(queues_[i].get());
});
}
const char *val = getenv("TVM_BIND_THREADS");
if (val == nullptr || atoi(val) == 1) {
if (num_workers_ <= std::thread::hardware_concurrency()) {
SetThreadAffinity();
} else {
LOG(WARNING)
<< "The thread affinity cannot be set when the number of workers is larger "
<< "than the number of available cores in the system.";
}
}
}
// Internal worker function.
void RunWorker(SpscTaskQueue* queue) {
void RunWorker(int worker_id) {
SpscTaskQueue* queue = queues_[worker_id].get();
SpscTaskQueue::Task task;
ParallelLauncher::ThreadLocal()->is_worker = true;
while (queue->Pop(&task)) {
......@@ -352,40 +321,9 @@ class ThreadPool {
}
}
}
// bind worker threads to disjoint cores
void SetThreadAffinity() {
#if defined(__ANDROID__)
#ifndef CPU_SET
#define CPU_SETSIZE 1024
#define __NCPUBITS (8 * sizeof (uint64_t))
typedef struct {
uint64_t __bits[CPU_SETSIZE / __NCPUBITS];
} cpu_set_t;
#define CPU_SET(cpu, cpusetp) \
((cpusetp)->__bits[(cpu)/__NCPUBITS] |= (1UL << ((cpu) % __NCPUBITS)))
#define CPU_ZERO(cpusetp) \
memset((cpusetp), 0, sizeof(cpu_set_t))
#endif
#endif
for (int i=0; i < num_workers_; ++i) {
#if defined(__linux__) || defined(__ANDROID__)
cpu_set_t cpuset;
CPU_ZERO(&cpuset);
CPU_SET(i, &cpuset);
#if defined(__ANDROID__)
sched_setaffinity(threads_[i].native_handle(), sizeof(cpu_set_t), &cpuset);
#else
pthread_setaffinity_np(threads_[i].native_handle(),
sizeof(cpu_set_t), &cpuset);
#endif
#endif
}
}
// Number of workers
int num_workers_;
std::vector<std::unique_ptr<SpscTaskQueue> > queues_;
std::vector<std::thread> threads_;
std::unique_ptr<tvm::runtime::threading::ThreadGroup> threads_;
};
} // namespace runtime
......@@ -411,7 +349,7 @@ int TVMBackendParallelBarrier(int task_id, TVMParallelGroupEnv* penv) {
if (i != task_id) {
while (sync_counter[i * kSyncStride].load(
std::memory_order_relaxed) <= old_counter) {
std::this_thread::yield();
tvm::runtime::threading::Yield();
}
}
}
......
/*!
* Copyright (c) 2018 by Contributors
* \file threading_backend.cc
* \brief Native threading backend
*/
#include <tvm/runtime/threading_backend.h>
#include <dmlc/logging.h>
#include <thread>
#if defined(__linux__)
#include <sched.h>
#endif
namespace tvm {
namespace runtime {
namespace threading {
class ThreadGroup::Impl {
public:
Impl(int num_workers,
std::function<void(int)> worker_callback,
bool exclude_worker0)
: num_workers_(num_workers) {
CHECK_GE(num_workers, 1)
<< "Requested a non-positive number of worker threads.";
for (int i = exclude_worker0; i < num_workers_; ++i) {
threads_.emplace_back([worker_callback, i] { worker_callback(i); });
}
const char *val = getenv("TVM_BIND_THREADS");
if (val == nullptr || atoi(val) == 1) {
if (num_workers_ <= std::thread::hardware_concurrency()) {
SetAffinity();
} else {
LOG(WARNING)
<< "The thread affinity cannot be set when the number of workers"
<< "is larger than the number of available cores in the system.";
}
}
}
~Impl() { Join(); }
void Join() {
for (auto& t : threads_) {
if (t.joinable()) t.join();
}
}
private:
// bind worker threads to disjoint cores
void SetAffinity() {
#if defined(__ANDROID__)
#ifndef CPU_SET
#define CPU_SETSIZE 1024
#define __NCPUBITS (8 * sizeof (uint64_t))
typedef struct {
uint64_t __bits[CPU_SETSIZE / __NCPUBITS];
} cpu_set_t;
#define CPU_SET(cpu, cpusetp) \
((cpusetp)->__bits[(cpu)/__NCPUBITS] |= (1UL << ((cpu) % __NCPUBITS)))
#define CPU_ZERO(cpusetp) \
memset((cpusetp), 0, sizeof(cpu_set_t))
#endif
#endif
for (unsigned i=0; i < threads_.size(); ++i) {
#if defined(__linux__) || defined(__ANDROID__)
cpu_set_t cpuset;
CPU_ZERO(&cpuset);
CPU_SET(i, &cpuset);
#if defined(__ANDROID__)
sched_setaffinity(threads_[i].native_handle(), sizeof(cpu_set_t), &cpuset);
#else
pthread_setaffinity_np(threads_[i].native_handle(),
sizeof(cpu_set_t), &cpuset);
#endif
#endif
}
}
int num_workers_;
std::vector<std::thread> threads_;
};
ThreadGroup::ThreadGroup(int num_workers,
std::function<void(int)> worker_callback,
bool exclude_worker0)
: impl_(new ThreadGroup::Impl(num_workers, worker_callback, exclude_worker0)) {}
ThreadGroup::~ThreadGroup() { delete impl_; }
void ThreadGroup::Join() { impl_->Join(); }
void Yield() {
std::this_thread::yield();
}
int MaxConcurrency() {
int max_concurrency = 1;
const char *val = getenv("TVM_NUM_THREADS");
if (val == nullptr) {
val = getenv("OMP_NUM_THREADS");
}
if (val != nullptr) {
max_concurrency = atoi(val);
} else {
max_concurrency = std::thread::hardware_concurrency();
#if defined(_M_X64) || defined(__x86_64__)
max_concurrency /= 2; // ignore hyper-threading
#endif
}
return std::max(max_concurrency, 1);
}
} // namespace threading
} // namespace runtime
} // namespace tvm
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment