DPC++ Runtime
Runtime libraries for oneAPI DPC++
thread_pool.hpp
Go to the documentation of this file.
1 //===-- thread_pool.hpp - Simple thread pool --------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #pragma once
10 
#include <algorithm>
#include <atomic>
#include <condition_variable>
#include <functional>
#include <mutex>
#include <queue>
#include <thread>
#include <utility>
#include <vector>

#include <sycl/detail/defines.hpp>
20 
21 namespace sycl {
22 inline namespace _V1 {
23 namespace detail {
24 
25 class ThreadPool {
26  std::vector<std::thread> MLaunchedThreads;
27 
28  size_t MThreadCount;
29  std::queue<std::function<void()>> MJobQueue;
30  std::mutex MJobQueueMutex;
31  std::condition_variable MDoSmthOrStop;
32  std::atomic_bool MStop;
33  std::atomic_uint MJobsInPool;
34 
35  void worker() {
36  GlobalHandler::instance().registerSchedulerUsage(/*ModifyCounter*/ false);
37  std::unique_lock<std::mutex> Lock(MJobQueueMutex);
38  while (true) {
39  MDoSmthOrStop.wait(
40  Lock, [this]() { return !MJobQueue.empty() || MStop.load(); });
41 
42  if (MStop.load())
43  break;
44 
45  std::function<void()> Job = std::move(MJobQueue.front());
46  MJobQueue.pop();
47  Lock.unlock();
48 
49  Job();
50 
51  Lock.lock();
52 
53  MJobsInPool--;
54  }
55  }
56 
57  void start() {
58  MLaunchedThreads.reserve(MThreadCount);
59 
60  MStop.store(false);
61  MJobsInPool.store(0);
62 
63  for (size_t Idx = 0; Idx < MThreadCount; ++Idx)
64  MLaunchedThreads.emplace_back([this] { worker(); });
65  }
66 
67 public:
68  void drain() {
69  while (MJobsInPool != 0)
70  std::this_thread::yield();
71  }
72 
73  ThreadPool(unsigned int ThreadCount = 1) : MThreadCount(ThreadCount) {
74  start();
75  }
76 
78  try {
79  finishAndWait();
80  } catch (std::exception &e) {
81  __SYCL_REPORT_EXCEPTION_TO_STREAM("exception in ~ThreadPool", e);
82  }
83  }
84 
85  void finishAndWait() {
86  MStop.store(true);
87 
88  MDoSmthOrStop.notify_all();
89 
90  for (std::thread &Thread : MLaunchedThreads)
91  if (Thread.joinable())
92  Thread.join();
93  }
94 
95  template <typename T> void submit(T &&Func) {
96  {
97  std::lock_guard<std::mutex> Lock(MJobQueueMutex);
98  MJobQueue.emplace([F = std::move(Func)]() { F(); });
99  }
100  MJobsInPool++;
101  MDoSmthOrStop.notify_one();
102  }
103 
104  void submit(std::function<void()> &&Func) {
105  {
106  std::lock_guard<std::mutex> Lock(MJobQueueMutex);
107  MJobQueue.emplace(Func);
108  }
109  MJobsInPool++;
110  MDoSmthOrStop.notify_one();
111  }
112 };
113 
114 } // namespace detail
115 } // namespace _V1
116 } // namespace sycl
void registerSchedulerUsage(bool ModifyCounter=true)
static GlobalHandler & instance()
void submit(std::function< void()> &&Func)
ThreadPool(unsigned int ThreadCount=1)
Definition: thread_pool.hpp:73
#define __SYCL_REPORT_EXCEPTION_TO_STREAM(str, e)
Definition: common.hpp:373
Definition: access.hpp:18