From bfceafc7b46fbf13ecd3a5421fdea29848276c44 Mon Sep 17 00:00:00 2001 From: nhynes <nhynes@berkeley.edu> Date: Thu, 15 Mar 2018 19:06:46 -0700 Subject: [PATCH] Pluggable Thread Launching Mechanism (#991) --- apps/howto_deploy/tvm_runtime_pack.cc | 1 + apps/sgx/Makefile | 9 +- apps/sgx/app.cc | 21 +++-- apps/sgx/enclave.cc | 6 +- apps/sgx/enclave_config.xml | 6 +- apps/sgx/prepare_test_libs.py | 2 + apps/sgx/test_addone.edl | 6 +- apps/sgx/tvm_runtime_pack.cc | 6 +- include/tvm/runtime/threading_backend.h | 65 +++++++++++++ sgx/{sgx_runtime.cc => runtime_t.cc} | 20 ++-- sgx/runtime_u.cc | 34 +++++++ sgx/threading_backend.cc | 71 ++++++++++++++ sgx/tvm.edl | 15 +++ src/runtime/thread_pool.cc | 120 ++++++------------------ src/runtime/threading_backend.cc | 113 ++++++++++++++++++++++ 15 files changed, 374 insertions(+), 121 deletions(-) create mode 100644 include/tvm/runtime/threading_backend.h rename sgx/{sgx_runtime.cc => runtime_t.cc} (52%) create mode 100644 sgx/runtime_u.cc create mode 100644 sgx/threading_backend.cc create mode 100644 sgx/tvm.edl create mode 100644 src/runtime/threading_backend.cc diff --git a/apps/howto_deploy/tvm_runtime_pack.cc b/apps/howto_deploy/tvm_runtime_pack.cc index 9a090d863..445768128 100644 --- a/apps/howto_deploy/tvm_runtime_pack.cc +++ b/apps/howto_deploy/tvm_runtime_pack.cc @@ -25,6 +25,7 @@ #include "../../src/runtime/module.cc" #include "../../src/runtime/registry.cc" #include "../../src/runtime/file_util.cc" +#include "../../src/runtime/threading_backend.cc" #include "../../src/runtime/thread_pool.cc" // NOTE: all the files after this are optional modules diff --git a/apps/sgx/Makefile b/apps/sgx/Makefile index 6a1eeb5b8..fd1d0cc8f 100644 --- a/apps/sgx/Makefile +++ b/apps/sgx/Makefile @@ -26,6 +26,7 @@ pkg_cflags := -std=c++11 -O2 -fPIC\ -I${TVM_ROOT}/dlpack/include\ -I.\ -DDMLC_LOG_STACK_TRACE=0\ + -fmax-errors=4 pkg_ldflags := -L${TVM_ROOT}/lib @@ -40,7 +41,7 @@ enclave_cflags := -static -nostdinc\ -DDMLC_CXX11_THREAD_LOCAL=0\ $(enclave_include_paths)\ -enclave_cxxflags := -nostdinc++ $(enclave_cflags) +enclave_cxxflags := -nostdinc++ $(enclave_cflags) -DTVM_SGX_MAX_CONCURRENCY=4 enclave_ldflags :=\ -Wl,--no-undefined -nostdlib -nodefaultlibs -nostartfiles -L$(SGX_SDK)/lib64\ @@ -62,7 +63,7 @@ app_ldflags := -L$(SGX_SDK)/lib64\ all: lib/test_addone.signed.so bin/test_addone # Build rule for all-in-one TVM package library -lib/tvm_runtime_pack.o: tvm_runtime_pack.cc +lib/tvm_runtime_pack.o: tvm_runtime_pack.cc lib/test_addone_t.o @mkdir -p $(@D) $(CXX) -c $< -o $@ $(pkg_cflags) $(pkg_ldflags) $(enclave_cxxflags) -g @@ -94,7 +95,7 @@ lib/test_addone.signed.so: lib/test_addone.so enclave_config.xml # An app that runs the enclave bin/test_addone: app.cc lib/test_addone_u.o @mkdir -p $(@D) - $(CXX) $^ -o $@ $(app_cflags) $(app_ldflags) + $(CXX) $^ -o $@ $(app_cflags) $(app_ldflags) $(pkg_cflags) -g # Debugging runtime pack built without SGX (c.f. 
howto_deploy/tvm_runtime_pack.cc) lib/tvm_runtime_pack_nosgx.o: tvm_runtime_pack.cc @@ -104,7 +105,7 @@ lib/tvm_runtime_pack_nosgx.o: tvm_runtime_pack.cc # Debugging binary that runs TVM without SGX bin/addone_nosgx: enclave.cc lib/tvm_runtime_pack_nosgx.o lib/test_addone_sys.o @mkdir -p $(@D) - $(CXX) $^ -o $@ $(pkg_cflags) $(pkg_ldflags) -g + $(CXX) $^ -o $@ $(pkg_cflags) $(pkg_ldflags) -g -lpthread clean: rm -rf lib bin diff --git a/apps/sgx/app.cc b/apps/sgx/app.cc index 1516e8b4e..d008bfb37 100644 --- a/apps/sgx/app.cc +++ b/apps/sgx/app.cc @@ -1,13 +1,15 @@ #include <cstdio> +#include <iostream> #include "sgx_urts.h" #include "sgx_eid.h" #include "test_addone_u.h" +#include "../../sgx/runtime_u.cc" #define TOKEN_FILENAME "bin/test_addone.token" #define ENCLAVE_FILENAME "lib/test_addone.signed.so" -sgx_enclave_id_t global_eid = 0; // global EID shared by multiple threads +sgx_enclave_id_t tvm_sgx_eid; typedef struct _sgx_errlist_t { sgx_status_t err; @@ -80,7 +82,7 @@ int initialize_enclave(void) /* Step 2: call sgx_create_enclave to initialize an enclave instance */ /* Debug Support: set 2nd parameter to 1 */ - sgx_status = sgx_create_enclave(ENCLAVE_FILENAME, SGX_DEBUG_FLAG, &token, &updated, &global_eid, NULL); + sgx_status = sgx_create_enclave(ENCLAVE_FILENAME, SGX_DEBUG_FLAG, &token, &updated, &tvm_sgx_eid, NULL); if (sgx_status != SGX_SUCCESS) { print_error_message(sgx_status); if (fp != NULL) fclose(fp); @@ -105,7 +107,7 @@ int initialize_enclave(void) } int SGX_CDECL main(int argc, char *argv[]) { - if(initialize_enclave() < 0){ + if(initialize_enclave() < 0) { printf("Failed to initialize enclave.\n"); return -1; } @@ -113,12 +115,13 @@ int SGX_CDECL main(int argc, char *argv[]) { /* Run TVM within the enclave */ int addone_status; sgx_status_t sgx_status = SGX_ERROR_UNEXPECTED; - sgx_status = enclave_main(global_eid, &addone_status); + sgx_status = tvm_ecall_run_module(tvm_sgx_eid, nullptr, &addone_status); if (sgx_status != SGX_SUCCESS) { print_error_message(sgx_status); } - - sgx_destroy_enclave(global_eid); + tvm_ecall_shutdown(tvm_sgx_eid); + tvm::runtime::sgx::Shutdown(); + sgx_destroy_enclave(tvm_sgx_eid); if (addone_status == 1) { printf("It works!"); @@ -127,3 +130,9 @@ int SGX_CDECL main(int argc, char *argv[]) { printf("It doesn't work."); return -1; } + +extern "C" { +void ocall_println(const char* str) { + std::cout << "Enclave says: " << str << std::endl; +} +} diff --git a/apps/sgx/enclave.cc b/apps/sgx/enclave.cc index 758845554..d43107288 100644 --- a/apps/sgx/enclave.cc +++ b/apps/sgx/enclave.cc @@ -6,6 +6,8 @@ #include <iostream> #endif +extern void Shutdown(); + /* This function mirrors the one in howto_deploy except without the iostream */ int Verify(tvm::runtime::Module mod, std::string fname) { // Get the function from the module. 
@@ -43,9 +45,9 @@ int Verify(tvm::runtime::Module mod, std::string fname) { extern "C" { -int enclave_main() { +void tvm_ecall_run_module(const void* tvm_args, void* tvm_return_value) { tvm::runtime::Module mod_syslib = (*tvm::runtime::Registry::Get("module._GetSystemLib"))(); - return Verify(mod_syslib, "addonesys"); + *(int*)tvm_return_value = Verify(mod_syslib, "addonesys"); } } diff --git a/apps/sgx/enclave_config.xml b/apps/sgx/enclave_config.xml index d24da1882..f7fc129d6 100644 --- a/apps/sgx/enclave_config.xml +++ b/apps/sgx/enclave_config.xml @@ -1,9 +1,9 @@ <EnclaveConfiguration> <ProdID>0</ProdID> <ISVSVN>0</ISVSVN> - <StackMaxSize>0x2000</StackMaxSize> - <HeapMaxSize>0x1000</HeapMaxSize> - <TCSNum>1</TCSNum> + <StackMaxSize>0x100000</StackMaxSize> + <HeapMaxSize>0x100000</HeapMaxSize> + <TCSNum>5</TCSNum> <TCSPolicy>1</TCSPolicy> <DisableDebug>0</DisableDebug> <MiscSelect>0</MiscSelect> diff --git a/apps/sgx/prepare_test_libs.py b/apps/sgx/prepare_test_libs.py index 1fa9d74ef..715880e61 100644 --- a/apps/sgx/prepare_test_libs.py +++ b/apps/sgx/prepare_test_libs.py @@ -11,6 +11,8 @@ def prepare_test_libs(base_path): A = tvm.placeholder((n,), name='A') B = tvm.compute(A.shape, lambda *i: A(*i) + 1, name='B') s = tvm.create_schedule(B.op) + s[B].parallel(s[B].op.axis[0]) + print(tvm.lower(s, [A, B], simple_mode=True)) # Compile library in system library mode fadd_syslib = tvm.build(s, [A, B], 'llvm --system-lib', name='addonesys') diff --git a/apps/sgx/test_addone.edl b/apps/sgx/test_addone.edl index 58341a727..0127a5818 100644 --- a/apps/sgx/test_addone.edl +++ b/apps/sgx/test_addone.edl @@ -1,7 +1,7 @@ enclave { - from "sgx_tstdc.edl" import sgx_thread_wait_untrusted_event_ocall, sgx_thread_set_untrusted_event_ocall, sgx_thread_setwait_untrusted_events_ocall, sgx_thread_set_multiple_untrusted_events_ocall; + from "../../sgx/tvm.edl" import *; - trusted { - public int enclave_main(); + untrusted { + void ocall_println([in, string] const char *str); }; }; diff --git a/apps/sgx/tvm_runtime_pack.cc b/apps/sgx/tvm_runtime_pack.cc index 709386b78..0d88af03a 100644 --- a/apps/sgx/tvm_runtime_pack.cc +++ b/apps/sgx/tvm_runtime_pack.cc @@ -5,7 +5,11 @@ * Please refer to the Makefile (rule lib/tvm_runtime_pack.o) for how to build. * */ -#include "../../sgx/sgx_runtime.cc" +#ifdef _LIBCPP_SGX_CONFIG +#include "lib/test_addone_t.h" +#endif +#include "../../sgx/runtime_t.cc" + #ifndef _LIBCPP_SGX_CONFIG #include "../../src/runtime/file_util.cc" #endif diff --git a/include/tvm/runtime/threading_backend.h b/include/tvm/runtime/threading_backend.h new file mode 100644 index 000000000..6c8c4f5eb --- /dev/null +++ b/include/tvm/runtime/threading_backend.h @@ -0,0 +1,65 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file threading_backend.h + * \brief Utilities for manipulating thread pool threads. + */ +#ifndef TVM_RUNTIME_THREADING_BACKEND_H_ +#define TVM_RUNTIME_THREADING_BACKEND_H_ + +#include <functional> +#include <memory> +#include <vector> + +namespace tvm { +namespace runtime { +namespace threading { + +/*! + * \brief A platform-agnostic abstraction for managing a collection of + * thread pool threads. + */ +class ThreadGroup { + public: + class Impl; + + /*! + * \brief Creates a collection of threads which run a provided function. + * + * \param num_workers The total number of worker threads in this group. + Includes main thread if `exclude_worker0 = true` + * \param worker_callback A callback which is run in its own thread. + Receives the worker_id as an argument. 
+ * \param exclude_worker0 Whether to use the main thread as a worker. + * If `true`, worker0 will not be launched in a new thread and + * `worker_callback` will only be called for values >= 1. This + * allows use of the main thread as a worker. + */ + ThreadGroup(int num_workers, + std::function<void(int)> worker_callback, + bool exclude_worker0 = false); + ~ThreadGroup(); + + /*! + * \brief Blocks until all non-main threads in the pool finish. + */ + void Join(); + + private: + Impl* impl_; +}; + +/*! + * \brief Platform-agnostic no-op. + */ +void Yield(); + +/*! + * \return the maximum number of effective workers for this system. + */ +int MaxConcurrency(); + +} // namespace threading +} // namespace runtime +} // namespace tvm + +#endif // TVM_RUNTIME_THREADING_BACKEND_H_ diff --git a/sgx/sgx_runtime.cc b/sgx/runtime_t.cc similarity index 52% rename from sgx/sgx_runtime.cc rename to sgx/runtime_t.cc index 6a0d0dfb2..5f280ffce 100644 --- a/sgx/sgx_runtime.cc +++ b/sgx/runtime_t.cc @@ -9,17 +9,15 @@ #include "../../src/runtime/module.cc" #include "../../src/runtime/registry.cc" #include "../../src/runtime/system_lib_module.cc" +#ifndef _LIBCPP_SGX_CONFIG +#include "../../src/runtime/threading_backend.cc" +#else +#include "threading_backend.cc" +#endif +#include "../../src/runtime/thread_pool.cc" -// dummy parallel runtime (for now) -int TVMBackendParallelLaunch( - FTVMParallelLambda flambda, - void* cdata, - int num_task) { - TVMParallelGroupEnv env = { nullptr /* sync_handle */, 1 /* num_task */ }; - return flambda(0 /* task_id */, &env, cdata); +extern "C" { +void tvm_ecall_shutdown() { + tvm::runtime::ThreadPool::Global()->Shutdown(); } - -int TVMBackendParallelBarrier(int task_id, TVMParallelGroupEnv* penv) { - return 0; } - diff --git a/sgx/runtime_u.cc b/sgx/runtime_u.cc new file mode 100644 index 000000000..0acccf614 --- /dev/null +++ b/sgx/runtime_u.cc @@ -0,0 +1,34 @@ +#include <tvm/runtime/threading_backend.h> +#include "../../src/runtime/threading_backend.cc" +#include <iostream> + +extern sgx_enclave_id_t tvm_sgx_eid; +extern "C" { +sgx_status_t tvm_ecall_run_worker(sgx_enclave_id_t eid, const void* cb); +} + +namespace tvm { +namespace runtime { +namespace sgx { + +static std::unique_ptr<tvm::runtime::threading::ThreadGroup> sgx_thread_group; + +extern "C" { +void tvm_ocall_thread_pool_launch(int num_tasks, void* cb) { + std::function<void(int)> runner = [cb](int _worker_id) { + sgx_status_t sgx_status = SGX_ERROR_UNEXPECTED; + sgx_status = tvm_ecall_run_worker(tvm_sgx_eid, cb); + CHECK(sgx_status == SGX_SUCCESS) << "SGX Error: " << sgx_status; + }; + sgx_thread_group.reset(new tvm::runtime::threading::ThreadGroup( + num_tasks, runner, false /* include_main_thread */)); +} +} + +void Shutdown() { + sgx_thread_group->Join(); +} + +} // namespace sgx +} // namespace runtime +} // namespace tvm diff --git a/sgx/threading_backend.cc b/sgx/threading_backend.cc new file mode 100644 index 000000000..7f820ab51 --- /dev/null +++ b/sgx/threading_backend.cc @@ -0,0 +1,71 @@ +/*! 
+ * Copyright (c) 2018 by Contributors + * \file sgx/threading_backend.cc + * \brief SGX threading backend + */ +#include <tvm/runtime/threading_backend.h> +#include <dmlc/logging.h> +#include <sgx_edger8r.h> +#include <sgx_trts.h> +#include <atomic> + +extern "C" { +sgx_status_t SGX_CDECL tvm_ocall_thread_pool_launch(int num_workers, void* cb); +} + +#ifndef TVM_SGX_MAX_CONCURRENCY +#define TVM_SGX_MAX_CONCURRENCY 1 +#endif + +namespace tvm { +namespace runtime { +namespace threading { + +class ThreadGroup::Impl { + public: + Impl(int num_workers, std::function<void(int)> worker_callback, + bool exclude_worker0) + : num_workers_(num_workers), + worker_callback_(worker_callback), + next_task_id_(exclude_worker0) { + CHECK(num_workers <= TVM_SGX_MAX_CONCURRENCY) + << "Tried spawning more threads than allowed by TVM_SGX_MAX_CONCURRENCY."; + sgx_status_t sgx_status = SGX_ERROR_UNEXPECTED; + sgx_status = tvm_ocall_thread_pool_launch(num_workers, this); + CHECK(sgx_status == SGX_SUCCESS) << "SGX Error: " << sgx_status; + } + + void RunTask() { + int task_id = next_task_id_++; + CHECK(task_id < num_workers_) + << "More workers entered enclave than allowed by TVM_SGX_MAX_CONCURRENCY"; + worker_callback_(task_id); + } + + private: + int num_workers_; + std::function<void(int)> worker_callback_; + std::atomic<int> next_task_id_; +}; + +ThreadGroup::ThreadGroup(int num_workers, + std::function<void(int)> worker_callback, + bool exclude_worker0) + : impl_(new ThreadGroup::Impl(num_workers, worker_callback, exclude_worker0)) {} +void ThreadGroup::Join() {} +ThreadGroup::~ThreadGroup() { delete impl_; } + +void Yield() {} + +int MaxConcurrency() { return TVM_SGX_MAX_CONCURRENCY; } + +extern "C" { +void tvm_ecall_run_worker(const void* impl) { + if (!sgx_is_within_enclave(impl, sizeof(ThreadGroup::Impl))) return; + ((ThreadGroup::Impl*)impl)->RunTask(); +} +} + +} // namespace threading +} // namespace runtime +} // namespace tvm diff --git a/sgx/tvm.edl b/sgx/tvm.edl new file mode 100644 index 000000000..e88ac0ac7 --- /dev/null +++ b/sgx/tvm.edl @@ -0,0 +1,15 @@ +enclave { + from "sgx_tstdc.edl" import *; + + trusted { + public void tvm_ecall_run_module([user_check] const void* tvm_args, + [user_check] void* tvm_ret_value); + public void tvm_ecall_run_worker([user_check] const void* cb); + public void tvm_ecall_shutdown(); + }; + + untrusted { + void tvm_ocall_thread_pool_launch(int num_workers, [user_check] void* cb); + }; +}; + diff --git a/src/runtime/thread_pool.cc b/src/runtime/thread_pool.cc index 4e13fdd14..ac42a0c03 100644 --- a/src/runtime/thread_pool.cc +++ b/src/runtime/thread_pool.cc @@ -5,6 +5,7 @@ */ #include <tvm/runtime/c_runtime_api.h> #include <tvm/runtime/c_backend_api.h> +#include <tvm/runtime/threading_backend.h> #include <dmlc/thread_local.h> #include <dmlc/logging.h> #include <thread> @@ -17,9 +18,6 @@ #include <cstring> #include <memory> #include <sstream> -#if defined(__linux__) -#include <sched.h> -#endif const constexpr int kL1CacheBytes = 64; @@ -73,14 +71,14 @@ class ParallelLauncher { return num_pending_ == 0; }); if (!has_error_) return 0; - std::ostringstream os; + std::string err(""); for (size_t i = 0; i < par_errors_.size(); ++i) { if (par_errors_[i].length() != 0) { - os << "Task " << i << " error: " << par_errors_[i] << '\n'; + err += "Task " + std::to_string(i) + " error: " + par_errors_[i] + '\n'; par_errors_[i].clear(); } } - TVMAPISetLastError(os.str().c_str()); + TVMAPISetLastError(err.c_str()); return -1; } // Signal that one job has finished. 
@@ -157,7 +155,7 @@ class SpscTaskQueue { */ void Push(const Task& input) { while (!Enqueue(input)) { - std::this_thread::yield(); + tvm::runtime::threading::Yield(); } if (pending_.fetch_add(1) == -1) { std::unique_lock<std::mutex> lock(mutex_); @@ -176,8 +174,8 @@ class SpscTaskQueue { // If a new task comes to the queue quickly, this wait avoid the worker from sleeping. // The default spin count is set by following the typical omp convention for (uint32_t i = 0; i < spin_count && pending_.load() == 0; ++i) { - std::this_thread::yield(); - } + tvm::runtime::threading::Yield(); + } if (pending_.fetch_sub(1) == 0) { std::unique_lock<std::mutex> lock(mutex_); cv_.wait(lock, [this] { @@ -211,6 +209,8 @@ class SpscTaskQueue { * \return Whether the task is enqueued. */ bool Enqueue(const Task& input) { + if (exit_now_.load(std::memory_order_relaxed)) return false; + const uint32_t tail = tail_.load(std::memory_order_relaxed); if ((tail + 1) % kRingSize != (head_.load(std::memory_order_acquire))) { @@ -255,32 +255,17 @@ class SpscTaskQueue { // The thread pool class ThreadPool { public: - ThreadPool() { - const char *val = getenv("TVM_NUM_THREADS"); - if (val == nullptr) { - val = getenv("OMP_NUM_THREADS"); - } - if (val != nullptr) { - num_workers_ = atoi(val); - } else { -#if defined(_M_X64) || defined(__x86_64__) - // Half to not count hyper threading. - num_workers_ = std::thread::hardware_concurrency() / 2; -#else - num_workers_ = std::thread::hardware_concurrency(); -#endif - } - num_workers_ = std::max(num_workers_, 1); - this->Init(); - } - ~ThreadPool() { - for (std::unique_ptr<SpscTaskQueue>& q : queues_) { - q->SignalForKill(); - } - for (std::thread& t : threads_) { - t.join(); + ThreadPool(): num_workers_(tvm::runtime::threading::MaxConcurrency()) { + for (int i = 0; i < num_workers_; ++i) { + // The SpscTaskQueue only host ONE item at a time + queues_.emplace_back(std::unique_ptr<SpscTaskQueue>(new SpscTaskQueue())); } + threads_ = std::unique_ptr<tvm::runtime::threading::ThreadGroup>( + new tvm::runtime::threading::ThreadGroup( + num_workers_, [this](int worker_id) { this->RunWorker(worker_id); }, + false /* include_main_thread */)); } + ~ThreadPool() { Shutdown(); } int Launch(FTVMParallelLambda flambda, void* cdata, int num_task, @@ -307,38 +292,22 @@ class ThreadPool { return res; } + void Shutdown() { + for (std::unique_ptr<SpscTaskQueue>& q : queues_) { + q->SignalForKill(); + } + threads_.reset(); + } + static ThreadPool* Global() { static ThreadPool inst; return &inst; } private: - // Initialize the pool. - void Init() { - for (int i = 0; i < num_workers_; ++i) { - // The SpscTaskQueue only host ONE item at a time - queues_.emplace_back( - std::unique_ptr<SpscTaskQueue>(new SpscTaskQueue())); - } - threads_.resize(num_workers_); - for (int i = 0; i < num_workers_; ++i) { - threads_[i] = std::thread([this, i] { - this->RunWorker(queues_[i].get()); - }); - } - const char *val = getenv("TVM_BIND_THREADS"); - if (val == nullptr || atoi(val) == 1) { - if (num_workers_ <= std::thread::hardware_concurrency()) { - SetThreadAffinity(); - } else { - LOG(WARNING) - << "The thread affinity cannot be set when the number of workers is larger " - << "than the number of available cores in the system."; - } - } - } // Internal worker function. 
- void RunWorker(SpscTaskQueue* queue) { + void RunWorker(int worker_id) { + SpscTaskQueue* queue = queues_[worker_id].get(); SpscTaskQueue::Task task; ParallelLauncher::ThreadLocal()->is_worker = true; while (queue->Pop(&task)) { @@ -352,40 +321,9 @@ class ThreadPool { } } } - // bind worker threads to disjoint cores - void SetThreadAffinity() { -#if defined(__ANDROID__) -#ifndef CPU_SET - #define CPU_SETSIZE 1024 - #define __NCPUBITS (8 * sizeof (uint64_t)) - typedef struct { - uint64_t __bits[CPU_SETSIZE / __NCPUBITS]; - } cpu_set_t; - - #define CPU_SET(cpu, cpusetp) \ - ((cpusetp)->__bits[(cpu)/__NCPUBITS] |= (1UL << ((cpu) % __NCPUBITS))) - #define CPU_ZERO(cpusetp) \ - memset((cpusetp), 0, sizeof(cpu_set_t)) -#endif -#endif - for (int i=0; i < num_workers_; ++i) { -#if defined(__linux__) || defined(__ANDROID__) - cpu_set_t cpuset; - CPU_ZERO(&cpuset); - CPU_SET(i, &cpuset); -#if defined(__ANDROID__) - sched_setaffinity(threads_[i].native_handle(), sizeof(cpu_set_t), &cpuset); -#else - pthread_setaffinity_np(threads_[i].native_handle(), - sizeof(cpu_set_t), &cpuset); -#endif -#endif - } - } - // Number of workers int num_workers_; std::vector<std::unique_ptr<SpscTaskQueue> > queues_; - std::vector<std::thread> threads_; + std::unique_ptr<tvm::runtime::threading::ThreadGroup> threads_; }; } // namespace runtime @@ -411,7 +349,7 @@ int TVMBackendParallelBarrier(int task_id, TVMParallelGroupEnv* penv) { if (i != task_id) { while (sync_counter[i * kSyncStride].load( std::memory_order_relaxed) <= old_counter) { - std::this_thread::yield(); + tvm::runtime::threading::Yield(); } } } diff --git a/src/runtime/threading_backend.cc b/src/runtime/threading_backend.cc new file mode 100644 index 000000000..19ba9bf2d --- /dev/null +++ b/src/runtime/threading_backend.cc @@ -0,0 +1,113 @@ +/*! 
+ * Copyright (c) 2018 by Contributors + * \file threading_backend.cc + * \brief Native threading backend + */ +#include <tvm/runtime/threading_backend.h> +#include <dmlc/logging.h> +#include <thread> +#if defined(__linux__) +#include <sched.h> +#endif + +namespace tvm { +namespace runtime { +namespace threading { + +class ThreadGroup::Impl { + public: + Impl(int num_workers, + std::function<void(int)> worker_callback, + bool exclude_worker0) + : num_workers_(num_workers) { + CHECK_GE(num_workers, 1) + << "Requested a non-positive number of worker threads."; + for (int i = exclude_worker0; i < num_workers_; ++i) { + threads_.emplace_back([worker_callback, i] { worker_callback(i); }); + } + const char *val = getenv("TVM_BIND_THREADS"); + if (val == nullptr || atoi(val) == 1) { + if (num_workers_ <= std::thread::hardware_concurrency()) { + SetAffinity(); + } else { + LOG(WARNING) + << "The thread affinity cannot be set when the number of workers" + << "is larger than the number of available cores in the system."; + } + } + } + ~Impl() { Join(); } + + void Join() { + for (auto& t : threads_) { + if (t.joinable()) t.join(); + } + } + + private: + // bind worker threads to disjoint cores + void SetAffinity() { +#if defined(__ANDROID__) +#ifndef CPU_SET +#define CPU_SETSIZE 1024 +#define __NCPUBITS (8 * sizeof (uint64_t)) + typedef struct { + uint64_t __bits[CPU_SETSIZE / __NCPUBITS]; + } cpu_set_t; + +#define CPU_SET(cpu, cpusetp) \ + ((cpusetp)->__bits[(cpu)/__NCPUBITS] |= (1UL << ((cpu) % __NCPUBITS))) +#define CPU_ZERO(cpusetp) \ + memset((cpusetp), 0, sizeof(cpu_set_t)) +#endif +#endif + for (unsigned i=0; i < threads_.size(); ++i) { +#if defined(__linux__) || defined(__ANDROID__) + cpu_set_t cpuset; + CPU_ZERO(&cpuset); + CPU_SET(i, &cpuset); +#if defined(__ANDROID__) + sched_setaffinity(threads_[i].native_handle(), sizeof(cpu_set_t), &cpuset); +#else + pthread_setaffinity_np(threads_[i].native_handle(), + sizeof(cpu_set_t), &cpuset); +#endif +#endif + } + } + + int num_workers_; + std::vector<std::thread> threads_; +}; + +ThreadGroup::ThreadGroup(int num_workers, + std::function<void(int)> worker_callback, + bool exclude_worker0) + : impl_(new ThreadGroup::Impl(num_workers, worker_callback, exclude_worker0)) {} +ThreadGroup::~ThreadGroup() { delete impl_; } +void ThreadGroup::Join() { impl_->Join(); } + +void Yield() { + std::this_thread::yield(); +} + +int MaxConcurrency() { + int max_concurrency = 1; + const char *val = getenv("TVM_NUM_THREADS"); + if (val == nullptr) { + val = getenv("OMP_NUM_THREADS"); + } + if (val != nullptr) { + max_concurrency = atoi(val); + } else { + max_concurrency = std::thread::hardware_concurrency(); +#if defined(_M_X64) || defined(__x86_64__) + max_concurrency /= 2; // ignore hyper-threading +#endif + } + return std::max(max_concurrency, 1); +} + +} // namespace threading +} // namespace runtime +} // namespace tvm -- GitLab
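
For reference, below is a minimal, self-contained sketch of how a consumer is expected to drive the tvm::runtime::threading API introduced by this patch (ThreadGroup, Yield, MaxConcurrency). It mirrors what the refactored ThreadPool does, but the main() harness, the results vector, and the per-worker lambda are illustrative only and are not part of the patch; it assumes the file is compiled and linked together with src/runtime/threading_backend.cc (e.g. the same way apps/howto_deploy/tvm_runtime_pack.cc pulls that file in).

// usage_sketch.cc (hypothetical) -- build together with
// src/runtime/threading_backend.cc.
#include <tvm/runtime/threading_backend.h>

#include <iostream>
#include <vector>

int main() {
  namespace th = tvm::runtime::threading;

  // MaxConcurrency() now owns the TVM_NUM_THREADS / OMP_NUM_THREADS /
  // hardware_concurrency() logic that used to live in ThreadPool's ctor.
  const int num_workers = th::MaxConcurrency();
  std::vector<int> results(num_workers, 0);

  // Each callback runs on its own thread and receives its worker_id.
  // exclude_worker0 = false: the main thread is not used as a worker,
  // matching how the refactored ThreadPool constructs its ThreadGroup.
  th::ThreadGroup group(
      num_workers,
      [&results](int worker_id) {
        // A real worker (ThreadPool::RunWorker) would pop tasks from its
        // per-worker SpscTaskQueue here; we just record that we ran.
        results[worker_id] = worker_id + 1;
        th::Yield();  // std::this_thread::yield() natively, a no-op under SGX
      },
      /*exclude_worker0=*/false);

  // Join() blocks until every launched worker thread has exited, so the
  // writes to `results` are visible afterwards.
  group.Join();

  for (int i = 0; i < num_workers; ++i) {
    std::cout << "worker " << i << " -> " << results[i] << std::endl;
  }
  return 0;
}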
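
The pluggable part of the mechanism is purely link-time: a platform supplies its own definitions of ThreadGroup::Impl, Yield() and MaxConcurrency() and compiles that translation unit into the runtime instead of src/runtime/threading_backend.cc, exactly as sgx/runtime_t.cc swaps in sgx/threading_backend.cc. The skeleton below shows the minimum such a replacement backend must define; the file name, the plain std::thread launch strategy, and the hard-coded worker count are assumptions for illustration, not part of the patch.

// threading_backend_custom.cc (hypothetical replacement backend).
// Compile this file into the runtime *instead of*
// src/runtime/threading_backend.cc to supply a different launch strategy.
#include <tvm/runtime/threading_backend.h>

#include <thread>
#include <vector>

namespace tvm {
namespace runtime {
namespace threading {

class ThreadGroup::Impl {
 public:
  Impl(int num_workers, std::function<void(int)> worker_callback,
       bool exclude_worker0) {
    // Launch one plain std::thread per worker; an embedded or RTOS backend
    // would hand `worker_callback` to its own task-spawning primitive here.
    for (int i = exclude_worker0 ? 1 : 0; i < num_workers; ++i) {
      threads_.emplace_back([worker_callback, i] { worker_callback(i); });
    }
  }
  ~Impl() { Join(); }

  void Join() {
    for (std::thread& t : threads_) {
      if (t.joinable()) t.join();
    }
  }

 private:
  std::vector<std::thread> threads_;
};

ThreadGroup::ThreadGroup(int num_workers,
                         std::function<void(int)> worker_callback,
                         bool exclude_worker0)
    : impl_(new ThreadGroup::Impl(num_workers, worker_callback, exclude_worker0)) {}
ThreadGroup::~ThreadGroup() { delete impl_; }
void ThreadGroup::Join() { impl_->Join(); }

void Yield() { std::this_thread::yield(); }

// A fixed cap for illustration; the stock backend derives this from
// TVM_NUM_THREADS / OMP_NUM_THREADS / std::thread::hardware_concurrency().
int MaxConcurrency() { return 4; }

}  // namespace threading
}  // namespace runtime
}  // namespace tvm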