Skip to content

Commit 2b4b404

Browse files
committed
add wait_for_all to the ProactiveThreadpool
1 parent 27c2799 commit 2b4b404

2 files changed

Lines changed: 56 additions & 1 deletion

File tree

taskflow/threadpool/proactive_threadpool.hpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,11 @@ class ProactiveThreadpool {
9595
if(_task_queue.empty()){
9696
w.ready = false;
9797
_workers.push_back(&w);
98+
99+
if(_workers.size() == num_workers()){
100+
_complete.notify_one();
101+
}
102+
98103
w.cv.wait(lock, [&w](){ return w.ready; });
99104
t = std::move(w.task);
100105
}
@@ -215,6 +220,18 @@ class ProactiveThreadpool {
215220

216221
}
217222

223+
void wait_for_all(){
224+
225+
if(is_worker()){
226+
throw std::runtime_error("Worker thread cannot wait for all");
227+
}
228+
229+
std::unique_lock<std::mutex> lock(_mutex);
230+
_complete.wait(lock, [this](){ return _workers.size() == num_workers(); });
231+
232+
}
233+
234+
218235
private:
219236

220237
template <typename T>
@@ -238,6 +255,8 @@ class ProactiveThreadpool {
238255
mutable std::mutex _mutex;
239256

240257
std::condition_variable _empty;
258+
std::condition_variable _complete;
259+
241260
std::deque<UnitTask> _task_queue;
242261
std::vector<std::thread> _threads;
243262
std::vector<Worker*> _workers;

unittest/threadpool.cpp

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,38 @@ void test_threadpool_async(ThreadpoolType& tp, const size_t task_num){
5454
for(size_t i=0; i<int_future.size(); i++){
5555
CHECK(int_future[i].get() == int_result[i]);
5656
}
57-
57+
}
58+
59+
template <typename ThreadpoolType>
60+
void test_threadpool_wait_for_all(ThreadpoolType& tp){
61+
62+
const size_t worker_num = tp.num_workers();
63+
const size_t task_num = 20;
64+
std::atomic<size_t> counter{0};
65+
66+
for(size_t i=0; i<task_num; i++){
67+
tp.silent_async([&counter](){
68+
std::this_thread::sleep_for(std::chrono::milliseconds(200));
69+
counter++;
70+
});
71+
}
72+
CHECK(counter < task_num);
73+
//std::cout << "counter: " << counter << std::endl;
74+
75+
tp.shutdown();
76+
tp.spawn(worker_num);
77+
78+
counter = 0;
79+
for(size_t i=0; i<task_num; i++){
80+
tp.silent_async([&counter](){
81+
std::this_thread::sleep_for(std::chrono::milliseconds(200));
82+
counter++;
83+
});
84+
}
85+
tp.wait_for_all();
86+
//std::cout << "counter:" << counter << std::endl;
87+
CHECK(counter == task_num);
88+
5889
}
5990

6091
// --------------------------------------------------------
@@ -76,5 +107,10 @@ TEST_CASE("Threadpool.ProactiveThreadpool" * doctest::timeout(5)) {
76107
test_threadpool_silent_async(tp, task_num);
77108
}
78109

110+
SUBCASE("WaitForAll"){
111+
tf::ProactiveThreadpool tp(4);
112+
test_threadpool_wait_for_all(tp);
113+
}
114+
79115
}
80116

0 commit comments

Comments
 (0)