Skip to content

Commit a91043b

Browse files
updated privatized_threadpool
1 parent d917030 commit a91043b

1 file changed

Lines changed: 82 additions & 74 deletions

File tree

taskflow/threadpool/privatized_threadpool.hpp

Lines changed: 82 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@
2929

3030
namespace tf {
3131

32+
// ----------------------------------------------------------------------------
33+
3234
template <typename T, unsigned N>
3335
class RunQueue {
3436

@@ -203,7 +205,7 @@ bool RunQueue<T, N>::empty() const {
203205
}
204206

205207
// Class: BasicPrivatizedThreadpool
206-
/*template < template<typename...> class Func >
208+
template < template<typename...> class Func >
207209
class BasicPrivatizedThreadpool {
208210

209211
using TaskType = Func<void()>;
@@ -236,34 +238,28 @@ class BasicPrivatizedThreadpool {
236238

237239
private:
238240

241+
const std::thread::id _owner {std::this_thread::get_id()};
242+
239243
mutable std::mutex _mutex;
240244

241245
std::condition_variable _empty_cv;
242246

243247
std::deque<TaskType> _task_queue;
248+
244249
std::vector<std::thread> _threads;
250+
std::vector<std::unique_ptr<Worker>> _workers;
251+
std::vector<size_t> _coprimes;
245252

246-
// TODO: do we need atomic variable here?
247-
std::atomic<size_t> _idle_workers {0};
253+
size_t _num_idlers {0};
254+
size_t _next_queue {0};
248255

249256
std::unordered_map<std::thread::id, size_t> _worker_maps;
250257

251-
const std::thread::id _owner {std::this_thread::get_id()};
252-
253258
bool _exiting {false};
254259
bool _wait_for_all {false};
255260

256-
std::vector<std::unique_ptr<Worker>> _works;
257-
258-
// TODO: can we just use some hacky method to replace atomic
259-
// or make it relaxed
260-
std::atomic<size_t> _next_queue {0};
261-
262261
size_t _nonempty_worker_queue() const;
263262

264-
bool _sync {false};
265-
266-
std::vector<size_t> _coprimes;
267263
void _xorshift32(uint32_t&);
268264
bool _steal(TaskType&, uint32_t&);
269265

@@ -273,12 +269,12 @@ class BasicPrivatizedThreadpool {
273269
// Function: _nonempty_worker_queue
274270
template < template<typename...> class Func >
275271
size_t BasicPrivatizedThreadpool<Func>::_nonempty_worker_queue() const {
276-
for(size_t i=0;i <_works.size(); ++i){
277-
if(!_works[i]->queue.empty()){
272+
for(size_t i=0;i <_workers.size(); ++i){
273+
if(!_workers[i]->queue.empty()){
278274
return i;
279275
}
280276
}
281-
return _works.size();
277+
return _workers.size();
282278
}
283279

284280
// Function: _xorshift32
@@ -299,7 +295,7 @@ bool BasicPrivatizedThreadpool<Func>::_steal(TaskType& w, uint32_t& dice){
299295
const auto queue_num = num_workers();
300296
auto victim = dice % queue_num;
301297
for(size_t i=0; i<queue_num; i++){
302-
if(_works[victim]->queue.pop_back(w)){
298+
if(_workers[victim]->queue.pop_back(w)){
303299
return true;
304300
}
305301
victim += inc;
@@ -359,18 +355,17 @@ void BasicPrivatizedThreadpool<Func>::shutdown(){
359355
_wait_for_all = true;
360356

361357
// Wake up all workers in case they are already idle
362-
for(const auto& w : _works){
358+
for(const auto& w : _workers){
363359
w->cv.notify_one();
364360
}
365361

366-
//while(_idle_workers != num_workers()) {
367-
while(!_sync){
362+
//while(_num_idlers != num_workers()) {
363+
while(_wait_for_all){
368364
_empty_cv.wait(lock);
369365
}
370-
_sync = false;
371366
_exiting = true;
372367

373-
for(auto& w : _works){
368+
for(auto& w : _workers){
374369
// TODO: can we replace this dummy task with state?
375370
w->queue.push_back([](){});
376371
w->cv.notify_one();
@@ -382,7 +377,7 @@ void BasicPrivatizedThreadpool<Func>::shutdown(){
382377
}
383378
_threads.clear();
384379

385-
_works.clear();
380+
_workers.clear();
386381
_worker_maps.clear();
387382

388383
_wait_for_all = false;
@@ -421,53 +416,53 @@ void BasicPrivatizedThreadpool<Func>::spawn(unsigned N) {
421416
}
422417

423418
for(size_t i=0; i<N; ++i){
424-
_works.push_back(std::make_unique<Worker>());
419+
_workers.push_back(std::make_unique<Worker>());
425420
}
426421

427422
for(size_t i=0; i<N; ++i){
428423
_threads.emplace_back([this, i=i+sz]() -> void {
429424

430425
TaskType t {nullptr};
431-
Worker& w = *(_works[i]);
426+
Worker& w = *(_workers[i]);
432427
uint32_t dice = i+1;
433428
std::unique_lock<std::mutex> lock(_mutex);
434429

435430
while(!_exiting){
436431

437432
if(!w.queue.pop_front(t)){
438-
if(_steal(t, dice)){}
433+
if(_steal(t, dice)) {
434+
}
439435
else if(!_task_queue.empty()) {
440436
t = std::move(_task_queue.front());
441437
_task_queue.pop_front();
442438
}
443439
else {
444-
// TODO: do we need another while loop here?
445-
//while(!w.queue.pop_front(t) && _task_queue.empty()){
446-
if(++_idle_workers == num_workers() && _wait_for_all){
447-
// Last active thread checks if all queues are empty
448-
if(auto ret = _nonempty_worker_queue(); ret == num_workers()){
449-
// TODO: here only one thread will do so
450-
_sync = true;
451-
_empty_cv.notify_one();
452-
}
453-
else{
454-
if(ret == i){
455-
-- _idle_workers;
456-
continue;
457-
}
458-
_works[ret]->cv.notify_one();
440+
if(++_num_idlers == num_workers() && _wait_for_all){
441+
// Last active thread checks if all queues are empty
442+
if(auto ret = _nonempty_worker_queue(); ret == num_workers()){
443+
_wait_for_all = false;
444+
_empty_cv.notify_one();
445+
}
446+
else{
447+
if(ret == i){
448+
--_num_idlers;
449+
continue;
459450
}
460-
}
461-
w.cv.wait(lock);
462-
--_idle_workers;
463-
//}
451+
_workers[ret]->cv.notify_one();
452+
}
453+
}
454+
w.cv.wait(lock);
455+
--_num_idlers;
464456
}
465457
} // End of first if
466458

467459
if(t){
468460
_mutex.unlock();
469-
t();
470-
t = nullptr;
461+
// speculation
462+
do {
463+
t();
464+
t = nullptr;
465+
} while(w.queue.pop_front(t));
471466
_mutex.lock();
472467
}
473468
} // End of while ------------------------------------------------------
@@ -535,7 +530,7 @@ void BasicPrivatizedThreadpool<Func>::silent_async(C&& c){
535530
if(std::this_thread::get_id() != _owner){
536531
auto tid = std::this_thread::get_id();
537532
if(_worker_maps.find(tid) != _worker_maps.end()){
538-
if(!_works[_worker_maps.at(tid)]->queue.push_front(t)){
533+
if(!_workers[_worker_maps.at(tid)]->queue.push_front(t)){
539534
std::scoped_lock<std::mutex> lock(_mutex);
540535
_task_queue.push_back(std::move(t));
541536
}
@@ -545,14 +540,14 @@ void BasicPrivatizedThreadpool<Func>::silent_async(C&& c){
545540

546541
// owner thread or other threads
547542
// TODO: use random for load balancing?
548-
auto id = (++_next_queue)%_works.size();
549-
if(!_works[id]->queue.push_back(t)){
543+
auto id = (++_next_queue)%_workers.size();
544+
if(!_workers[id]->queue.push_back(t)){
550545
std::scoped_lock<std::mutex> lock(_mutex);
551546
_task_queue.push_back(std::move(t));
552547
}
553548

554549
// Make sure at least one worker will handle the task
555-
_works[id]->cv.notify_one();
550+
_workers[id]->cv.notify_one();
556551
}
557552

558553

@@ -567,23 +562,20 @@ void BasicPrivatizedThreadpool<Func>::wait_for_all() {
567562
if(num_workers() == 0) return ;
568563

569564
std::unique_lock<std::mutex> lock(_mutex);
565+
570566
_wait_for_all = true;
567+
571568
// Wake up all workers in case they are already idle
572-
for(const auto& w : _works){
569+
for(const auto& w : _workers){
573570
w->cv.notify_one();
574571
}
575572

576-
// TODO: can we use a single wait_for_all?
577-
while(!_sync){
573+
while(_wait_for_all) {
578574
_empty_cv.wait(lock);
579575
}
576+
}
580577

581-
_sync = false;
582-
_wait_for_all = false;
583-
} */
584-
585-
586-
578+
/*
587579
template < template<typename...> class Func >
588580
class BasicPrivatizedThreadpool {
589581
@@ -638,7 +630,7 @@ class BasicPrivatizedThreadpool {
638630
// TODO: do we need atomic variable here?
639631
std::atomic<bool> _allow_steal {true};
640632
641-
size_t _idle_workers {0};
633+
size_t _num_idlers {0};
642634
size_t _next_queue {0};
643635
644636
bool _wait_for_all {false};
@@ -649,8 +641,6 @@ class BasicPrivatizedThreadpool {
649641
void _xorshift32(uint32_t&);
650642
bool _steal(TaskType&, uint32_t&);
651643
652-
653-
654644
}; // class BasicPrivatizedThreadpool. --------------------------------------
655645
656646
@@ -678,19 +668,37 @@ void BasicPrivatizedThreadpool<Func>::_xorshift32(uint32_t& x){
678668
// Function: _steal
679669
template < template<typename...> class Func >
680670
bool BasicPrivatizedThreadpool<Func>::_steal(TaskType& w, uint32_t& dice){
671+
681672
_xorshift32(dice);
682673
const auto inc = _coprimes[dice % _coprimes.size()];
683674
const auto queue_num = _workers.size();
684675
auto victim = dice % queue_num;
676+
//for(size_t i=0; i<queue_num; i++){
677+
// if(_workers[victim]->queue.pop_back(w)){
678+
// return true;
679+
// }
680+
// victim += inc;
681+
// if(victim >= queue_num){
682+
// victim -= queue_num;
683+
// }
684+
//}
685+
//return false;
686+
687+
static std::atomic_flag locked {ATOMIC_FLAG_INIT};
688+
while (locked.test_and_set(std::memory_order_acquire));
689+
685690
for(size_t i=0; i<queue_num; i++){
686691
if(_workers[victim]->queue.pop_back(w)){
692+
locked.clear(std::memory_order_release);
687693
return true;
688-
}
689-
victim += inc;
694+
}
695+
victim += inc;
690696
if(victim >= queue_num){
691697
victim -= queue_num;
692-
}
698+
}
693699
}
700+
701+
locked.clear(std::memory_order_release);
694702
return false;
695703
}
696704
@@ -742,7 +750,7 @@ void BasicPrivatizedThreadpool<Func>::shutdown(){
742750
std::unique_lock<std::mutex> lock(_mutex);
743751
// If all workers are idle && all queues are empty, then master
744752
// can directly wake up workers without waiting for notified
745-
if(_idle_workers != num_workers() || _nonempty_worker_queue().has_value()){
753+
if(_num_idlers != num_workers() || _nonempty_worker_queue().has_value()){
746754
_wait_for_all = true;
747755
748756
// Wake up all workers in case their queues are not empty
@@ -767,7 +775,6 @@ void BasicPrivatizedThreadpool<Func>::shutdown(){
767775
}
768776
769777
_threads.clear();
770-
771778
_workers.clear();
772779
_worker_maps.clear();
773780
}
@@ -821,12 +828,12 @@ void BasicPrivatizedThreadpool<Func>::spawn(unsigned N) {
821828
_task_queue.pop_front();
822829
}
823830
else{
824-
if(++_idle_workers == num_workers()){
831+
if(++_num_idlers == num_workers()){
825832
// Last active thread checks if all queues are empty
826833
if(auto ret = _nonempty_worker_queue(); ret.has_value()){
827834
// if the nonempty queue is mine, continue to process tasks in queue
828835
if(*ret == i){
829-
--_idle_workers;
836+
--_num_idlers;
830837
lock.unlock();
831838
continue;
832839
}
@@ -847,7 +854,7 @@ void BasicPrivatizedThreadpool<Func>::spawn(unsigned N) {
847854
while(w.state == Worker::ALIVE && w.queue.empty()){
848855
w.cv.wait(lock);
849856
}
850-
--_idle_workers;
857+
--_num_idlers;
851858
}
852859
lock.unlock();
853860
} // End of Steal
@@ -862,6 +869,7 @@ void BasicPrivatizedThreadpool<Func>::spawn(unsigned N) {
862869
863870
_worker_maps.insert({_threads.back().get_id(), i+sz});
864871
} // End of For ---------------------------------------------------------------------------------
872+
865873
_allow_steal = true;
866874
}
867875
@@ -959,7 +967,7 @@ void BasicPrivatizedThreadpool<Func>::wait_for_all() {
959967
std::unique_lock lock(_mutex);
960968
// If all workers are idle && all queues are empty,
961969
// then wait_for_all is done.
962-
if(_idle_workers == num_workers() && !_nonempty_worker_queue()){
970+
if(_num_idlers == num_workers() && !_nonempty_worker_queue()){
963971
return ;
964972
}
965973
@@ -973,7 +981,7 @@ void BasicPrivatizedThreadpool<Func>::wait_for_all() {
973981
_empty_cv.wait(lock);
974982
}
975983
}
976-
984+
*/
977985

978986
}; // namespace tf -----------------------------------------------------------
979987

0 commit comments

Comments
 (0)