DPC++ Runtime
Runtime libraries for oneAPI DPC++
thread_pool.hpp
Go to the documentation of this file.
1 //===-- thread_pool.hpp - Simple thread pool --------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #pragma once
10 
#include <algorithm>
#include <atomic>
#include <condition_variable>
#include <functional>
#include <mutex>
#include <queue>
#include <thread>
#include <utility>
#include <vector>

#include <sycl/detail/defines.hpp>
20 
21 namespace sycl {
22 inline namespace _V1 {
23 namespace detail {
24 
25 class ThreadPool {
26  std::vector<std::thread> MLaunchedThreads;
27 
28  size_t MThreadCount;
29  std::queue<std::function<void()>> MJobQueue;
30  std::mutex MJobQueueMutex;
31  std::condition_variable MDoSmthOrStop;
32  std::atomic_bool MStop;
33  std::atomic_uint MJobsInPool;
34 
35  void worker() {
36  GlobalHandler::instance().registerSchedulerUsage(/*ModifyCounter*/ false);
37  std::unique_lock<std::mutex> Lock(MJobQueueMutex);
38  while (true) {
39  MDoSmthOrStop.wait(
40  Lock, [this]() { return !MJobQueue.empty() || MStop.load(); });
41 
42  if (MStop.load())
43  break;
44 
45  std::function<void()> Job = std::move(MJobQueue.front());
46  MJobQueue.pop();
47  Lock.unlock();
48 
49  Job();
50 
51  Lock.lock();
52 
53  MJobsInPool--;
54  }
55  }
56 
57  void start() {
58  MLaunchedThreads.reserve(MThreadCount);
59 
60  MStop.store(false);
61  MJobsInPool.store(0);
62 
63  for (size_t Idx = 0; Idx < MThreadCount; ++Idx)
64  MLaunchedThreads.emplace_back([this] { worker(); });
65  }
66 
67 public:
68  void drain() {
69  while (MJobsInPool != 0)
70  std::this_thread::yield();
71  }
72 
73  ThreadPool(unsigned int ThreadCount = 1) : MThreadCount(ThreadCount) {
74  start();
75  }
76 
78 
79  void finishAndWait() {
80  MStop.store(true);
81 
82  MDoSmthOrStop.notify_all();
83 
84  for (std::thread &Thread : MLaunchedThreads)
85  if (Thread.joinable())
86  Thread.join();
87  }
88 
89  template <typename T> void submit(T &&Func) {
90  {
91  std::lock_guard<std::mutex> Lock(MJobQueueMutex);
92  MJobQueue.emplace([F = std::move(Func)]() { F(); });
93  }
94  MJobsInPool++;
95  MDoSmthOrStop.notify_one();
96  }
97 
98  void submit(std::function<void()> &&Func) {
99  {
100  std::lock_guard<std::mutex> Lock(MJobQueueMutex);
101  MJobQueue.emplace(Func);
102  }
103  MJobsInPool++;
104  MDoSmthOrStop.notify_one();
105  }
106 };
107 
108 } // namespace detail
109 } // namespace _V1
110 } // namespace sycl
void registerSchedulerUsage(bool ModifyCounter=true)
static GlobalHandler & instance()
void submit(std::function< void()> &&Func)
Definition: thread_pool.hpp:98
ThreadPool(unsigned int ThreadCount=1)
Definition: thread_pool.hpp:73
Definition: access.hpp:18