From bfceafc7b46fbf13ecd3a5421fdea29848276c44 Mon Sep 17 00:00:00 2001
From: nhynes <nhynes@berkeley.edu>
Date: Thu, 15 Mar 2018 19:06:46 -0700
Subject: [PATCH] Pluggable Thread Launching Mechanism (#991)

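This change factors thread creation out of the runtime thread pool and into a
pluggable threading backend. A new `tvm::runtime::threading::ThreadGroup`
abstraction (include/tvm/runtime/threading_backend.h) has one implementation
for native platforms (src/runtime/threading_backend.cc, std::thread plus
optional CPU affinity controlled by TVM_BIND_THREADS) and one for SGX
(sgx/threading_backend.cc). Inside the enclave, ThreadGroup ocalls
tvm_ocall_thread_pool_launch, which spawns untrusted threads that re-enter the
enclave through tvm_ecall_run_worker. The thread pool now only schedules work:
MaxConcurrency() reads TVM_NUM_THREADS or OMP_NUM_THREADS on native platforms
and is capped by TVM_SGX_MAX_CONCURRENCY inside SGX.

A minimal usage sketch of the new interface follows. It is illustrative only:
the main() driver and the printf worker body are not part of this patch, and
it assumes the native backend object (src/runtime/threading_backend.cc) is
linked into the binary.

    #include <tvm/runtime/threading_backend.h>
    #include <cstdio>

    int main() {
      namespace th = tvm::runtime::threading;
      // Pick a worker count from the environment / hardware.
      int n = th::MaxConcurrency();
      // Launch n workers; each callback receives its worker id.
      th::ThreadGroup group(
          n, [](int worker_id) { std::printf("worker %d running\n", worker_id); },
          /*exclude_worker0=*/false);
      group.Join();  // block until all spawned workers return
      return 0;
    }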
---
 apps/howto_deploy/tvm_runtime_pack.cc   |   1 +
 apps/sgx/Makefile                       |   9 +-
 apps/sgx/app.cc                         |  21 +++--
 apps/sgx/enclave.cc                     |   6 +-
 apps/sgx/enclave_config.xml             |   6 +-
 apps/sgx/prepare_test_libs.py           |   2 +
 apps/sgx/test_addone.edl                |   6 +-
 apps/sgx/tvm_runtime_pack.cc            |   6 +-
 include/tvm/runtime/threading_backend.h |  65 +++++++++++++
 sgx/{sgx_runtime.cc => runtime_t.cc}    |  20 ++--
 sgx/runtime_u.cc                        |  34 +++++++
 sgx/threading_backend.cc                |  71 ++++++++++++++
 sgx/tvm.edl                             |  15 +++
 src/runtime/thread_pool.cc              | 120 ++++++------------------
 src/runtime/threading_backend.cc        | 113 ++++++++++++++++++++++
 15 files changed, 374 insertions(+), 121 deletions(-)
 create mode 100644 include/tvm/runtime/threading_backend.h
 rename sgx/{sgx_runtime.cc => runtime_t.cc} (52%)
 create mode 100644 sgx/runtime_u.cc
 create mode 100644 sgx/threading_backend.cc
 create mode 100644 sgx/tvm.edl
 create mode 100644 src/runtime/threading_backend.cc

diff --git a/apps/howto_deploy/tvm_runtime_pack.cc b/apps/howto_deploy/tvm_runtime_pack.cc
index 9a090d863..445768128 100644
--- a/apps/howto_deploy/tvm_runtime_pack.cc
+++ b/apps/howto_deploy/tvm_runtime_pack.cc
@@ -25,6 +25,7 @@
 #include "../../src/runtime/module.cc"
 #include "../../src/runtime/registry.cc"
 #include "../../src/runtime/file_util.cc"
+#include "../../src/runtime/threading_backend.cc"
 #include "../../src/runtime/thread_pool.cc"
 
 // NOTE: all the files after this are optional modules
diff --git a/apps/sgx/Makefile b/apps/sgx/Makefile
index 6a1eeb5b8..fd1d0cc8f 100644
--- a/apps/sgx/Makefile
+++ b/apps/sgx/Makefile
@@ -26,6 +26,7 @@ pkg_cflags := -std=c++11 -O2 -fPIC\
 	-I${TVM_ROOT}/dlpack/include\
 	-I.\
 	-DDMLC_LOG_STACK_TRACE=0\
+	-fmax-errors=4
 
 pkg_ldflags := -L${TVM_ROOT}/lib
 
@@ -40,7 +41,7 @@ enclave_cflags := -static -nostdinc\
 	-DDMLC_CXX11_THREAD_LOCAL=0\
 	$(enclave_include_paths)\
 
-enclave_cxxflags := -nostdinc++ $(enclave_cflags)
+enclave_cxxflags := -nostdinc++ $(enclave_cflags) -DTVM_SGX_MAX_CONCURRENCY=4
 
 enclave_ldflags :=\
 	-Wl,--no-undefined -nostdlib -nodefaultlibs -nostartfiles -L$(SGX_SDK)/lib64\
@@ -62,7 +63,7 @@ app_ldflags := -L$(SGX_SDK)/lib64\
 all: lib/test_addone.signed.so bin/test_addone
 
 # Build rule for all-in-one TVM package library
-lib/tvm_runtime_pack.o: tvm_runtime_pack.cc
+lib/tvm_runtime_pack.o: tvm_runtime_pack.cc lib/test_addone_t.o
 	@mkdir -p $(@D)
 	$(CXX) -c $< -o $@ $(pkg_cflags) $(pkg_ldflags) $(enclave_cxxflags) -g
 
@@ -94,7 +95,7 @@ lib/test_addone.signed.so: lib/test_addone.so enclave_config.xml
 # An app that runs the enclave
 bin/test_addone: app.cc lib/test_addone_u.o
 	@mkdir -p $(@D)
-	$(CXX) $^ -o $@ $(app_cflags) $(app_ldflags)
+	$(CXX) $^ -o $@ $(app_cflags) $(app_ldflags) $(pkg_cflags) -g
 
 # Debugging runtime pack built without SGX (c.f. howto_deploy/tvm_runtime_pack.cc)
 lib/tvm_runtime_pack_nosgx.o: tvm_runtime_pack.cc
@@ -104,7 +105,7 @@ lib/tvm_runtime_pack_nosgx.o: tvm_runtime_pack.cc
 # Debugging binary that runs TVM without SGX
 bin/addone_nosgx: enclave.cc lib/tvm_runtime_pack_nosgx.o lib/test_addone_sys.o
 	@mkdir -p $(@D)
-	$(CXX) $^ -o $@ $(pkg_cflags) $(pkg_ldflags) -g
+	$(CXX) $^ -o $@ $(pkg_cflags) $(pkg_ldflags) -g -lpthread
 
 clean:
 	rm -rf lib bin
diff --git a/apps/sgx/app.cc b/apps/sgx/app.cc
index 1516e8b4e..d008bfb37 100644
--- a/apps/sgx/app.cc
+++ b/apps/sgx/app.cc
@@ -1,13 +1,15 @@
 #include <cstdio>
+#include <iostream>
 
 #include "sgx_urts.h"
 #include "sgx_eid.h"
 #include "test_addone_u.h"
+#include "../../sgx/runtime_u.cc"
 
 #define TOKEN_FILENAME   "bin/test_addone.token"
 #define ENCLAVE_FILENAME "lib/test_addone.signed.so"
 
-sgx_enclave_id_t global_eid = 0;  // global EID shared by multiple threads
+sgx_enclave_id_t tvm_sgx_eid;
 
 typedef struct _sgx_errlist_t {
   sgx_status_t err;
@@ -80,7 +82,7 @@ int initialize_enclave(void)
 
   /* Step 2: call sgx_create_enclave to initialize an enclave instance */
   /* Debug Support: set 2nd parameter to 1 */
-  sgx_status = sgx_create_enclave(ENCLAVE_FILENAME, SGX_DEBUG_FLAG, &token, &updated, &global_eid, NULL);
+  sgx_status = sgx_create_enclave(ENCLAVE_FILENAME, SGX_DEBUG_FLAG, &token, &updated, &tvm_sgx_eid, NULL);
   if (sgx_status != SGX_SUCCESS) {
     print_error_message(sgx_status);
     if (fp != NULL) fclose(fp);
@@ -105,7 +107,7 @@ int initialize_enclave(void)
 }
 
 int SGX_CDECL main(int argc, char *argv[]) {
-  if(initialize_enclave() < 0){
+  if(initialize_enclave() < 0) {
     printf("Failed to initialize enclave.\n");
     return -1;
   }
@@ -113,12 +115,13 @@ int SGX_CDECL main(int argc, char *argv[]) {
   /* Run TVM within the enclave */
   int addone_status;
   sgx_status_t sgx_status = SGX_ERROR_UNEXPECTED;
-  sgx_status = enclave_main(global_eid, &addone_status);
+  sgx_status = tvm_ecall_run_module(tvm_sgx_eid, nullptr, &addone_status);
   if (sgx_status != SGX_SUCCESS) {
     print_error_message(sgx_status);
   }
-
-  sgx_destroy_enclave(global_eid);
+  tvm_ecall_shutdown(tvm_sgx_eid);
+  tvm::runtime::sgx::Shutdown();
+  sgx_destroy_enclave(tvm_sgx_eid);
 
   if (addone_status == 1) {
     printf("It works!");
@@ -127,3 +130,9 @@ int SGX_CDECL main(int argc, char *argv[]) {
   printf("It doesn't work.");
   return -1;
 }
+
+extern "C" {
+void ocall_println(const char* str) {
+  std::cout << "Enclave says: " << str << std::endl;
+}
+}
diff --git a/apps/sgx/enclave.cc b/apps/sgx/enclave.cc
index 758845554..d43107288 100644
--- a/apps/sgx/enclave.cc
+++ b/apps/sgx/enclave.cc
@@ -6,6 +6,8 @@
 #include <iostream>
 #endif
 
+extern void Shutdown();
+
 /* This function mirrors the one in howto_deploy except without the iostream */
 int Verify(tvm::runtime::Module mod, std::string fname) {
   // Get the function from the module.
@@ -43,9 +45,9 @@ int Verify(tvm::runtime::Module mod, std::string fname) {
 
 
 extern "C" {
-int enclave_main() {
+void tvm_ecall_run_module(const void* tvm_args, void* tvm_return_value) {
   tvm::runtime::Module mod_syslib = (*tvm::runtime::Registry::Get("module._GetSystemLib"))();
-  return Verify(mod_syslib, "addonesys");
+  *(int*)tvm_return_value = Verify(mod_syslib, "addonesys");
 }
 }
 
diff --git a/apps/sgx/enclave_config.xml b/apps/sgx/enclave_config.xml
index d24da1882..f7fc129d6 100644
--- a/apps/sgx/enclave_config.xml
+++ b/apps/sgx/enclave_config.xml
@@ -1,9 +1,9 @@
 <EnclaveConfiguration>
   <ProdID>0</ProdID>
   <ISVSVN>0</ISVSVN>
-  <StackMaxSize>0x2000</StackMaxSize>
-  <HeapMaxSize>0x1000</HeapMaxSize>
-  <TCSNum>1</TCSNum>
+  <StackMaxSize>0x100000</StackMaxSize>
+  <HeapMaxSize>0x100000</HeapMaxSize>
+  <TCSNum>5</TCSNum>
   <TCSPolicy>1</TCSPolicy>
   <DisableDebug>0</DisableDebug>
   <MiscSelect>0</MiscSelect>
diff --git a/apps/sgx/prepare_test_libs.py b/apps/sgx/prepare_test_libs.py
index 1fa9d74ef..715880e61 100644
--- a/apps/sgx/prepare_test_libs.py
+++ b/apps/sgx/prepare_test_libs.py
@@ -11,6 +11,8 @@ def prepare_test_libs(base_path):
     A = tvm.placeholder((n,), name='A')
     B = tvm.compute(A.shape, lambda *i: A(*i) + 1, name='B')
     s = tvm.create_schedule(B.op)
+    s[B].parallel(s[B].op.axis[0])
+    print(tvm.lower(s, [A, B], simple_mode=True))
 
     # Compile library in system library mode
     fadd_syslib = tvm.build(s, [A, B], 'llvm --system-lib', name='addonesys')
diff --git a/apps/sgx/test_addone.edl b/apps/sgx/test_addone.edl
index 58341a727..0127a5818 100644
--- a/apps/sgx/test_addone.edl
+++ b/apps/sgx/test_addone.edl
@@ -1,7 +1,7 @@
 enclave {
-    from "sgx_tstdc.edl" import sgx_thread_wait_untrusted_event_ocall, sgx_thread_set_untrusted_event_ocall, sgx_thread_setwait_untrusted_events_ocall, sgx_thread_set_multiple_untrusted_events_ocall;
+    from "../../sgx/tvm.edl" import *;
 
-    trusted {
-        public int enclave_main();
+    untrusted {
+        void ocall_println([in, string] const char *str);
     };
 };
diff --git a/apps/sgx/tvm_runtime_pack.cc b/apps/sgx/tvm_runtime_pack.cc
index 709386b78..0d88af03a 100644
--- a/apps/sgx/tvm_runtime_pack.cc
+++ b/apps/sgx/tvm_runtime_pack.cc
@@ -5,7 +5,11 @@
  *   Please refer to the Makefile (rule lib/tvm_runtime_pack.o) for how to build.
  *
  */
-#include "../../sgx/sgx_runtime.cc"
+#ifdef _LIBCPP_SGX_CONFIG
+#include "lib/test_addone_t.h"
+#endif
+#include "../../sgx/runtime_t.cc"
+
 #ifndef _LIBCPP_SGX_CONFIG
 #include "../../src/runtime/file_util.cc"
 #endif
diff --git a/include/tvm/runtime/threading_backend.h b/include/tvm/runtime/threading_backend.h
new file mode 100644
index 000000000..6c8c4f5eb
--- /dev/null
+++ b/include/tvm/runtime/threading_backend.h
@@ -0,0 +1,65 @@
+/*!
+ *  Copyright (c) 2018 by Contributors
+ * \file threading_backend.h
+ * \brief Utilities for manipulating thread pool threads.
+ */
+#ifndef TVM_RUNTIME_THREADING_BACKEND_H_
+#define TVM_RUNTIME_THREADING_BACKEND_H_
+
+#include <functional>
+#include <memory>
+#include <vector>
+
+namespace tvm {
+namespace runtime {
+namespace threading {
+
+/*!
+ * \brief A platform-agnostic abstraction for managing a collection of
+ *        thread pool threads.
+ */
+class ThreadGroup {
+ public:
+  class Impl;
+
+   /*!
+    * \brief Creates a collection of threads which run a provided function.
+    *
+    * \param num_workers The total number of worker threads in this group.
+    *        Includes the main thread if `exclude_worker0 = true`.
+    * \param worker_callback A callback which is run in its own thread.
+    *        Receives the worker_id as an argument.
+    * \param exclude_worker0 Whether to use the main thread as a worker.
+    *        If `true`, worker0 will not be launched in a new thread and
+    *        `worker_callback` will only be called for values >= 1. This
+    *        allows use of the main thread as a worker.
+    */
+  ThreadGroup(int num_workers,
+              std::function<void(int)> worker_callback,
+              bool exclude_worker0 = false);
+  ~ThreadGroup();
+
+   /*!
+    * \brief Blocks until all non-main threads in the group finish.
+    */
+  void Join();
+
+ private:
+  Impl* impl_;
+};
+
+/*!
+ * \brief Yield the current thread; a no-op on backends without a scheduler (e.g. SGX).
+ */
+void Yield();
+
+/*!
+ * \return the maximum number of effective workers for this system.
+ */
+int MaxConcurrency();
+
+}  // namespace threading
+}  // namespace runtime
+}  // namespace tvm
+
+#endif  // TVM_RUNTIME_THREADING_BACKEND_H_
diff --git a/sgx/sgx_runtime.cc b/sgx/runtime_t.cc
similarity index 52%
rename from sgx/sgx_runtime.cc
rename to sgx/runtime_t.cc
index 6a0d0dfb2..5f280ffce 100644
--- a/sgx/sgx_runtime.cc
+++ b/sgx/runtime_t.cc
@@ -9,17 +9,15 @@
 #include "../../src/runtime/module.cc"
 #include "../../src/runtime/registry.cc"
 #include "../../src/runtime/system_lib_module.cc"
+#ifndef _LIBCPP_SGX_CONFIG
+#include "../../src/runtime/threading_backend.cc"
+#else
+#include "threading_backend.cc"
+#endif
+#include "../../src/runtime/thread_pool.cc"
 
-// dummy parallel runtime (for now)
-int TVMBackendParallelLaunch(
-    FTVMParallelLambda flambda,
-    void* cdata,
-    int num_task) {
-  TVMParallelGroupEnv env = { nullptr /* sync_handle */, 1 /* num_task */ };
-  return flambda(0 /* task_id */, &env, cdata);
+extern "C" {
+void tvm_ecall_shutdown() {
+  tvm::runtime::ThreadPool::Global()->Shutdown();
 }
-
-int TVMBackendParallelBarrier(int task_id, TVMParallelGroupEnv* penv) {
-  return 0;
 }
-
diff --git a/sgx/runtime_u.cc b/sgx/runtime_u.cc
new file mode 100644
index 000000000..0acccf614
--- /dev/null
+++ b/sgx/runtime_u.cc
@@ -0,0 +1,34 @@
+#include <tvm/runtime/threading_backend.h>
+#include "../../src/runtime/threading_backend.cc"
+#include <iostream>
+
+extern sgx_enclave_id_t tvm_sgx_eid;
+extern "C" {
+sgx_status_t tvm_ecall_run_worker(sgx_enclave_id_t eid, const void* cb);
+}
+
+namespace tvm {
+namespace runtime {
+namespace sgx {
+
+static std::unique_ptr<tvm::runtime::threading::ThreadGroup> sgx_thread_group;
+
+extern "C" {
+void tvm_ocall_thread_pool_launch(int num_tasks, void* cb) {
+  std::function<void(int)> runner = [cb](int _worker_id) {
+    sgx_status_t sgx_status = SGX_ERROR_UNEXPECTED;
+    sgx_status = tvm_ecall_run_worker(tvm_sgx_eid, cb);
+    CHECK(sgx_status == SGX_SUCCESS) << "SGX Error: " << sgx_status;
+  };
+  sgx_thread_group.reset(new tvm::runtime::threading::ThreadGroup(
+        num_tasks, runner, false /* exclude_worker0 */));
+}
+}
+
+void Shutdown() {
+  sgx_thread_group->Join();
+}
+
+}  // namespace sgx
+}  // namespace runtime
+}  // namespace tvm
diff --git a/sgx/threading_backend.cc b/sgx/threading_backend.cc
new file mode 100644
index 000000000..7f820ab51
--- /dev/null
+++ b/sgx/threading_backend.cc
@@ -0,0 +1,71 @@
+/*!
+ *  Copyright (c) 2018 by Contributors
+ * \file sgx/threading_backend.cc
+ * \brief SGX threading backend
+ */
+#include <tvm/runtime/threading_backend.h>
+#include <dmlc/logging.h>
+#include <sgx_edger8r.h>
+#include <sgx_trts.h>
+#include <atomic>
+
+extern "C" {
+sgx_status_t SGX_CDECL tvm_ocall_thread_pool_launch(int num_workers, void* cb);
+}
+
+#ifndef TVM_SGX_MAX_CONCURRENCY
+#define TVM_SGX_MAX_CONCURRENCY 1
+#endif
+
+namespace tvm {
+namespace runtime {
+namespace threading {
+
+class ThreadGroup::Impl {
+ public:
+  Impl(int num_workers, std::function<void(int)> worker_callback,
+       bool exclude_worker0)
+      : num_workers_(num_workers),
+        worker_callback_(worker_callback),
+        next_task_id_(exclude_worker0) {
+    CHECK(num_workers <= TVM_SGX_MAX_CONCURRENCY)
+      << "Tried spawning more threads than allowed by TVM_SGX_MAX_CONCURRENCY.";
+    sgx_status_t sgx_status = SGX_ERROR_UNEXPECTED;
+    sgx_status = tvm_ocall_thread_pool_launch(num_workers, this);
+    CHECK(sgx_status == SGX_SUCCESS) << "SGX Error: " << sgx_status;
+  }
+
+  void RunTask() {
+    int task_id = next_task_id_++;
+    CHECK(task_id < num_workers_)
+      << "More workers entered enclave than allowed by TVM_SGX_MAX_CONCURRENCY";
+    worker_callback_(task_id);
+  }
+
+ private:
+  int num_workers_;
+  std::function<void(int)> worker_callback_;
+  std::atomic<int> next_task_id_;
+};
+
+ThreadGroup::ThreadGroup(int num_workers,
+                         std::function<void(int)> worker_callback,
+                         bool exclude_worker0)
+  : impl_(new ThreadGroup::Impl(num_workers, worker_callback, exclude_worker0)) {}
+void ThreadGroup::Join() {}
+ThreadGroup::~ThreadGroup() { delete impl_; }
+
+void Yield() {}
+
+int MaxConcurrency() { return TVM_SGX_MAX_CONCURRENCY; }
+
+extern "C" {
+void tvm_ecall_run_worker(const void* impl) {
+  if (!sgx_is_within_enclave(impl, sizeof(ThreadGroup::Impl))) return;
+  ((ThreadGroup::Impl*)impl)->RunTask();
+}
+}
+
+}  // namespace threading
+}  // namespace runtime
+}  // namespace tvm
diff --git a/sgx/tvm.edl b/sgx/tvm.edl
new file mode 100644
index 000000000..e88ac0ac7
--- /dev/null
+++ b/sgx/tvm.edl
@@ -0,0 +1,15 @@
+enclave {
+    from "sgx_tstdc.edl" import *;
+
+    trusted {
+        public void tvm_ecall_run_module([user_check] const void* tvm_args,
+                                         [user_check] void* tvm_ret_value);
+        public void tvm_ecall_run_worker([user_check] const void* cb);
+        public void tvm_ecall_shutdown();
+    };
+
+    untrusted {
+        void tvm_ocall_thread_pool_launch(int num_workers, [user_check] void* cb);
+    };
+};
+
diff --git a/src/runtime/thread_pool.cc b/src/runtime/thread_pool.cc
index 4e13fdd14..ac42a0c03 100644
--- a/src/runtime/thread_pool.cc
+++ b/src/runtime/thread_pool.cc
@@ -5,6 +5,7 @@
  */
 #include <tvm/runtime/c_runtime_api.h>
 #include <tvm/runtime/c_backend_api.h>
+#include <tvm/runtime/threading_backend.h>
 #include <dmlc/thread_local.h>
 #include <dmlc/logging.h>
 #include <thread>
@@ -17,9 +18,6 @@
 #include <cstring>
 #include <memory>
 #include <sstream>
-#if defined(__linux__)
-#include <sched.h>
-#endif
 
 const constexpr int kL1CacheBytes = 64;
 
@@ -73,14 +71,14 @@ class ParallelLauncher {
         return num_pending_ == 0;
       });
     if (!has_error_) return 0;
-    std::ostringstream os;
+    std::string err("");
     for (size_t i = 0; i < par_errors_.size(); ++i) {
       if (par_errors_[i].length() != 0) {
-        os << "Task " << i << " error: " << par_errors_[i] << '\n';
+        err += "Task " + std::to_string(i) + " error: " + par_errors_[i] + '\n';
         par_errors_[i].clear();
       }
     }
-    TVMAPISetLastError(os.str().c_str());
+    TVMAPISetLastError(err.c_str());
     return -1;
   }
   // Signal that one job has finished.
@@ -157,7 +155,7 @@ class SpscTaskQueue {
    */
   void Push(const Task& input) {
     while (!Enqueue(input)) {
-      std::this_thread::yield();
+      tvm::runtime::threading::Yield();
     }
     if (pending_.fetch_add(1) == -1) {
       std::unique_lock<std::mutex> lock(mutex_);
@@ -176,8 +174,8 @@ class SpscTaskQueue {
     // If a new task comes to the queue quickly, this wait avoid the worker from sleeping.
     // The default spin count is set by following the typical omp convention
     for (uint32_t i = 0; i < spin_count && pending_.load() == 0; ++i) {
-        std::this_thread::yield();
-      }
+      tvm::runtime::threading::Yield();
+    }
     if (pending_.fetch_sub(1) == 0) {
       std::unique_lock<std::mutex> lock(mutex_);
       cv_.wait(lock, [this] {
@@ -211,6 +209,8 @@ class SpscTaskQueue {
    * \return Whether the task is enqueued.
    */
   bool Enqueue(const Task& input) {
+    if (exit_now_.load(std::memory_order_relaxed)) return false;
+
     const uint32_t tail = tail_.load(std::memory_order_relaxed);
 
     if ((tail + 1) % kRingSize != (head_.load(std::memory_order_acquire))) {
@@ -255,32 +255,17 @@ class SpscTaskQueue {
 // The thread pool
 class ThreadPool {
  public:
-  ThreadPool() {
-    const char *val = getenv("TVM_NUM_THREADS");
-    if (val == nullptr) {
-      val = getenv("OMP_NUM_THREADS");
-    }
-    if (val != nullptr) {
-      num_workers_ = atoi(val);
-    } else {
-#if defined(_M_X64) || defined(__x86_64__)
-      // Half to not count hyper threading.
-      num_workers_ = std::thread::hardware_concurrency() / 2;
-#else
-      num_workers_ = std::thread::hardware_concurrency();
-#endif
-    }
-    num_workers_ = std::max(num_workers_, 1);
-    this->Init();
-  }
-  ~ThreadPool() {
-    for (std::unique_ptr<SpscTaskQueue>& q : queues_) {
-      q->SignalForKill();
-    }
-    for (std::thread& t : threads_) {
-      t.join();
+  ThreadPool(): num_workers_(tvm::runtime::threading::MaxConcurrency()) {
+    for (int i = 0; i < num_workers_; ++i) {
+      // The SpscTaskQueue only hosts ONE item at a time
+      queues_.emplace_back(std::unique_ptr<SpscTaskQueue>(new SpscTaskQueue()));
     }
+    threads_ = std::unique_ptr<tvm::runtime::threading::ThreadGroup>(
+        new tvm::runtime::threading::ThreadGroup(
+          num_workers_, [this](int worker_id) { this->RunWorker(worker_id); },
+          false /* exclude_worker0 */));
   }
+  ~ThreadPool() { Shutdown(); }
   int Launch(FTVMParallelLambda flambda,
              void* cdata,
              int num_task,
@@ -307,38 +292,22 @@ class ThreadPool {
     return res;
   }
 
+  void Shutdown() {
+    for (std::unique_ptr<SpscTaskQueue>& q : queues_) {
+      q->SignalForKill();
+    }
+    threads_.reset();
+  }
+
   static ThreadPool* Global() {
     static ThreadPool inst;
     return &inst;
   }
 
  private:
-  // Initialize the pool.
-  void Init() {
-    for (int i = 0; i < num_workers_; ++i) {
-      // The SpscTaskQueue only host ONE item at a time
-      queues_.emplace_back(
-          std::unique_ptr<SpscTaskQueue>(new SpscTaskQueue()));
-    }
-    threads_.resize(num_workers_);
-    for (int i = 0; i < num_workers_; ++i) {
-      threads_[i] = std::thread([this, i] {
-          this->RunWorker(queues_[i].get());
-        });
-    }
-    const char *val = getenv("TVM_BIND_THREADS");
-    if (val == nullptr || atoi(val) == 1) {
-      if (num_workers_ <= std::thread::hardware_concurrency()) {
-        SetThreadAffinity();
-      } else {
-        LOG(WARNING)
-          << "The thread affinity cannot be set when the number of workers is larger "
-          << "than the number of available cores in the system.";
-      }
-    }
-  }
   // Internal worker function.
-  void RunWorker(SpscTaskQueue* queue) {
+  void RunWorker(int worker_id) {
+    SpscTaskQueue* queue = queues_[worker_id].get();
     SpscTaskQueue::Task task;
     ParallelLauncher::ThreadLocal()->is_worker = true;
     while (queue->Pop(&task)) {
@@ -352,40 +321,9 @@ class ThreadPool {
       }
     }
   }
-  // bind worker threads to disjoint cores
-  void SetThreadAffinity() {
-#if defined(__ANDROID__)
-#ifndef CPU_SET
-  #define CPU_SETSIZE 1024
-  #define __NCPUBITS (8 * sizeof (uint64_t))
-  typedef struct {
-    uint64_t __bits[CPU_SETSIZE / __NCPUBITS];
-  } cpu_set_t;
-
-  #define CPU_SET(cpu, cpusetp) \
-    ((cpusetp)->__bits[(cpu)/__NCPUBITS] |= (1UL << ((cpu) % __NCPUBITS)))
-  #define CPU_ZERO(cpusetp) \
-    memset((cpusetp), 0, sizeof(cpu_set_t))
-#endif
-#endif
-    for (int i=0; i < num_workers_; ++i) {
-#if defined(__linux__) || defined(__ANDROID__)
-      cpu_set_t cpuset;
-      CPU_ZERO(&cpuset);
-      CPU_SET(i, &cpuset);
-#if defined(__ANDROID__)
-      sched_setaffinity(threads_[i].native_handle(), sizeof(cpu_set_t), &cpuset);
-#else
-      pthread_setaffinity_np(threads_[i].native_handle(),
-        sizeof(cpu_set_t), &cpuset);
-#endif
-#endif
-    }
-  }
-  // Number of workers
   int num_workers_;
   std::vector<std::unique_ptr<SpscTaskQueue> > queues_;
-  std::vector<std::thread> threads_;
+  std::unique_ptr<tvm::runtime::threading::ThreadGroup> threads_;
 };
 
 }  // namespace runtime
@@ -411,7 +349,7 @@ int TVMBackendParallelBarrier(int task_id, TVMParallelGroupEnv* penv) {
     if (i != task_id) {
       while (sync_counter[i * kSyncStride].load(
                  std::memory_order_relaxed) <= old_counter) {
-        std::this_thread::yield();
+        tvm::runtime::threading::Yield();
       }
     }
   }
diff --git a/src/runtime/threading_backend.cc b/src/runtime/threading_backend.cc
new file mode 100644
index 000000000..19ba9bf2d
--- /dev/null
+++ b/src/runtime/threading_backend.cc
@@ -0,0 +1,113 @@
+/*!
+ *  Copyright (c) 2018 by Contributors
+ * \file threading_backend.cc
+ * \brief Native threading backend
+ */
+#include <tvm/runtime/threading_backend.h>
+#include <dmlc/logging.h>
+#include <thread>
+#if defined(__linux__)
+#include <sched.h>
+#endif
+
+namespace tvm {
+namespace runtime {
+namespace threading {
+
+class ThreadGroup::Impl {
+ public:
+  Impl(int num_workers,
+       std::function<void(int)> worker_callback,
+       bool exclude_worker0)
+      : num_workers_(num_workers) {
+    CHECK_GE(num_workers, 1)
+      << "Requested a non-positive number of worker threads.";
+    for (int i = exclude_worker0; i < num_workers_; ++i) {
+      threads_.emplace_back([worker_callback, i] { worker_callback(i); });
+    }
+    const char *val = getenv("TVM_BIND_THREADS");
+    if (val == nullptr || atoi(val) == 1) {
+      if (num_workers_ <= std::thread::hardware_concurrency()) {
+        SetAffinity();
+      } else {
+        LOG(WARNING)
+          << "The thread affinity cannot be set when the number of workers"
+          << "is larger than the number of available cores in the system.";
+      }
+    }
+  }
+  ~Impl() { Join(); }
+
+  void Join() {
+    for (auto& t : threads_) {
+      if (t.joinable()) t.join();
+    }
+  }
+
+ private:
+  // bind worker threads to disjoint cores
+  void SetAffinity() {
+#if defined(__ANDROID__)
+#ifndef CPU_SET
+#define CPU_SETSIZE 1024
+#define __NCPUBITS (8 * sizeof (uint64_t))
+    typedef struct {
+      uint64_t __bits[CPU_SETSIZE / __NCPUBITS];
+    } cpu_set_t;
+
+#define CPU_SET(cpu, cpusetp) \
+    ((cpusetp)->__bits[(cpu)/__NCPUBITS] |= (1UL << ((cpu) % __NCPUBITS)))
+#define CPU_ZERO(cpusetp) \
+    memset((cpusetp), 0, sizeof(cpu_set_t))
+#endif
+#endif
+    for (unsigned i = 0; i < threads_.size(); ++i) {
+#if defined(__linux__) || defined(__ANDROID__)
+      cpu_set_t cpuset;
+      CPU_ZERO(&cpuset);
+      CPU_SET(i, &cpuset);
+#if defined(__ANDROID__)
+      sched_setaffinity(threads_[i].native_handle(), sizeof(cpu_set_t), &cpuset);
+#else
+      pthread_setaffinity_np(threads_[i].native_handle(),
+          sizeof(cpu_set_t), &cpuset);
+#endif
+#endif
+    }
+  }
+
+  int num_workers_;
+  std::vector<std::thread> threads_;
+};
+
+ThreadGroup::ThreadGroup(int num_workers,
+                         std::function<void(int)> worker_callback,
+                         bool exclude_worker0)
+  : impl_(new ThreadGroup::Impl(num_workers, worker_callback, exclude_worker0)) {}
+ThreadGroup::~ThreadGroup() { delete impl_; }
+void ThreadGroup::Join() { impl_->Join(); }
+
+void Yield() {
+  std::this_thread::yield();
+}
+
+int MaxConcurrency() {
+  int max_concurrency = 1;
+  const char *val = getenv("TVM_NUM_THREADS");
+  if (val == nullptr) {
+    val = getenv("OMP_NUM_THREADS");
+  }
+  if (val != nullptr) {
+    max_concurrency = atoi(val);
+  } else {
+    max_concurrency = std::thread::hardware_concurrency();
+#if defined(_M_X64) || defined(__x86_64__)
+    max_concurrency /= 2;  // ignore hyper-threading
+#endif
+  }
+  return std::max(max_concurrency, 1);
+}
+
+}  // namespace threading
+}  // namespace runtime
+}  // namespace tvm
-- 
GitLab