Skip to content

Commit 8578e50

Browse files
committed
fixed bug in nested executor
1 parent 852923e commit 8578e50

4 files changed

Lines changed: 120 additions & 49 deletions

File tree

CMakeLists.txt

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -344,14 +344,14 @@ add_test(creation ${TF_UTEST_DIR}/basics -tc=Creation)
344344
add_test(iterators ${TF_UTEST_DIR}/basics -tc=Iterators)
345345
add_test(std_function ${TF_UTEST_DIR}/basics -tc=STDFunction)
346346
add_test(hash ${TF_UTEST_DIR}/basics -tc=Hash)
347-
add_test(serialruns.1thread ${TF_UTEST_DIR}/basics -tc=SerialRuns.1thread)
348-
add_test(serialruns.2threads ${TF_UTEST_DIR}/basics -tc=SerialRuns.2threads)
349-
add_test(serialruns.3threads ${TF_UTEST_DIR}/basics -tc=SerialRuns.3threads)
350-
add_test(serialruns.4threads ${TF_UTEST_DIR}/basics -tc=SerialRuns.4threads)
351-
add_test(serialruns.5threads ${TF_UTEST_DIR}/basics -tc=SerialRuns.5threads)
352-
add_test(serialruns.6threads ${TF_UTEST_DIR}/basics -tc=SerialRuns.6threads)
353-
add_test(serialruns.7threads ${TF_UTEST_DIR}/basics -tc=SerialRuns.7threads)
354-
add_test(serialruns.8threads ${TF_UTEST_DIR}/basics -tc=SerialRuns.8threads)
347+
add_test(serial_runs.1thread ${TF_UTEST_DIR}/basics -tc=SerialRuns.1thread)
348+
add_test(serial_runs.2threads ${TF_UTEST_DIR}/basics -tc=SerialRuns.2threads)
349+
add_test(serial_runs.3threads ${TF_UTEST_DIR}/basics -tc=SerialRuns.3threads)
350+
add_test(serial_runs.4threads ${TF_UTEST_DIR}/basics -tc=SerialRuns.4threads)
351+
add_test(serial_runs.5threads ${TF_UTEST_DIR}/basics -tc=SerialRuns.5threads)
352+
add_test(serial_runs.6threads ${TF_UTEST_DIR}/basics -tc=SerialRuns.6threads)
353+
add_test(serial_runs.7threads ${TF_UTEST_DIR}/basics -tc=SerialRuns.7threads)
354+
add_test(serial_runs.8threads ${TF_UTEST_DIR}/basics -tc=SerialRuns.8threads)
355355
add_test(parallel_runs.1thread ${TF_UTEST_DIR}/basics -tc=ParallelRuns.1thread)
356356
add_test(parallel_runs.2threads ${TF_UTEST_DIR}/basics -tc=ParallelRuns.2threads)
357357
add_test(parallel_runs.3threads ${TF_UTEST_DIR}/basics -tc=ParallelRuns.3threads)
@@ -360,6 +360,12 @@ add_test(parallel_runs.5threads ${TF_UTEST_DIR}/basics -tc=ParallelRuns.5threads
360360
add_test(parallel_runs.6threads ${TF_UTEST_DIR}/basics -tc=ParallelRuns.6threads)
361361
add_test(parallel_runs.7threads ${TF_UTEST_DIR}/basics -tc=ParallelRuns.7threads)
362362
add_test(parallel_runs.8threads ${TF_UTEST_DIR}/basics -tc=ParallelRuns.8threads)
363+
add_test(nested_runs.1thread ${TF_UTEST_DIR}/basics -tc=NestedRuns.1thread)
364+
add_test(nested_runs.2threads ${TF_UTEST_DIR}/basics -tc=NestedRuns.2threads)
365+
add_test(nested_runs.3threads ${TF_UTEST_DIR}/basics -tc=NestedRuns.3threads)
366+
add_test(nested_runs.4threads ${TF_UTEST_DIR}/basics -tc=NestedRuns.4threads)
367+
add_test(nested_runs.8threads ${TF_UTEST_DIR}/basics -tc=NestedRuns.8threads)
368+
add_test(nested_runs.16threads ${TF_UTEST_DIR}/basics -tc=NestedRuns.16threads)
363369
add_test(parallel_for.1thread ${TF_UTEST_DIR}/basics -tc=ParallelFor.1thread)
364370
add_test(parallel_for.2threads ${TF_UTEST_DIR}/basics -tc=ParallelFor.2threads)
365371
add_test(parallel_for.3threads ${TF_UTEST_DIR}/basics -tc=ParallelFor.3threads)

examples/simple.cpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,5 +26,3 @@ int main(){
2626
return 0; // +---+
2727
}
2828

29-
30-

taskflow/core/executor.hpp

Lines changed: 6 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ class Executor {
3434
struct Worker {
3535
unsigned id;
3636
Domain domain;
37+
Executor* executor;
3738
Notifier::Waiter* waiter;
3839
std::mt19937 rdgen { std::random_device{}() };
3940
TaskQueue<Node*> wsq[HETEROGENEITY];
@@ -332,13 +333,14 @@ inline void Executor::_spawn(unsigned N, Domain d) {
332333

333334
_workers[id].id = id;
334335
_workers[id].domain = d;
336+
_workers[id].executor = this;
335337
_workers[id].waiter = &_notifier[d]._waiters[i];
336338

337339
_threads.emplace_back([this] (Worker& w) -> void {
338340

339341
PerThread& pt = _per_thread();
340342
pt.worker = &w;
341-
343+
342344
Node* t = nullptr;
343345

344346
// must use 1 as condition instead of !done
@@ -581,7 +583,7 @@ inline void Executor::_schedule(Node* node, bool bypass_hint) {
581583
// caller is a worker to this pool
582584
auto worker = _per_thread().worker;
583585

584-
if(worker != nullptr) {
586+
if(worker != nullptr && worker->executor == this) {
585587
if(bypass_hint) {
586588
assert(!worker->cache);
587589
worker->cache = node;
@@ -627,7 +629,7 @@ inline void Executor::_schedule(PassiveVector<Node*>& nodes) {
627629
// task counts
628630
size_t tcount[HETEROGENEITY] = {0};
629631

630-
if(worker != nullptr) {
632+
if(worker != nullptr && worker->executor == this) {
631633
for(size_t i=0; i<num_nodes; ++i) {
632634
const auto d = nodes[i]->domain();
633635
worker->wsq[d].push(nodes[i]);
@@ -991,7 +993,7 @@ std::future<void> Executor::run_until(Taskflow& f, P&& pred) {
991993

992994
// Function: _set_up_topology
993995
inline void Executor::_set_up_topology(Topology* tpg) {
994-
996+
995997
tpg->_sources.clear();
996998

997999
// scan each node in the graph and build up the links
@@ -1098,41 +1100,6 @@ std::future<void> Executor::run_until(Taskflow& f, P&& pred, C&& c) {
10981100
return promise.get_future();
10991101
}
11001102

1101-
1102-
1103-
//// Special case of zero workers requires:
1104-
//// - iterative execution to avoid stack overflow
1105-
//// - avoid execution of last_work
1106-
//if(_workers.size() == 0) {
1107-
//
1108-
// Topology tpg(f, std::forward<P>(pred), std::forward<C>(c));
1109-
1110-
// // Clear last execution data & Build precedence between nodes and target
1111-
// tpg._bind(f._graph);
1112-
1113-
// std::stack<Node*> stack;
1114-
1115-
// do {
1116-
// _schedule_unsync(tpg._sources, stack);
1117-
// while(!stack.empty()) {
1118-
// auto node = stack.top();
1119-
// stack.pop();
1120-
// _invoke_unsync(node, stack);
1121-
// }
1122-
// tpg._recover_num_sinks();
1123-
// } while(!std::invoke(tpg._pred));
1124-
1125-
// if(tpg._call != nullptr) {
1126-
// std::invoke(tpg._call);
1127-
// }
1128-
1129-
// tpg._promise.set_value();
1130-
//
1131-
// _decrement_topology_and_notify();
1132-
//
1133-
// return tpg._promise.get_future();
1134-
//}
1135-
11361103
// Multi-threaded execution.
11371104
bool run_now {false};
11381105
Topology* tpg;

unittests/basics.cpp

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -780,6 +780,106 @@ TEST_CASE("ParallelRuns.8threads" * doctest::timeout(300)) {
780780
parallel_runs(8);
781781
}
782782

783+
// --------------------------------------------------------
784+
// Testcase: NestedRuns
785+
// --------------------------------------------------------
786+
// Test driver for nested executors: three helper types (C -> B -> A), each
// owning its own tf::Executor built from `w` and its own tf::Taskflow, all
// incrementing one shared counter through a reference.
//
// @param w  value forwarded to every tf::Executor constructor
//           (presumably the number of worker threads -- confirm against
//           tf::Executor).
void nested_runs(unsigned w) {
787+
788+
// Shared tally; every task at every nesting level increments it.
int counter {0};
789+
790+
// Innermost level: two chained tasks, no further nesting.
struct A {
791+
792+
tf::Executor executor;
793+
tf::Taskflow taskflow;
794+
795+
int& counter;
796+
797+
A(unsigned w, int& c) : executor{w}, counter{c} { }
798+
799+
// Rebuilds the two-task graph (A1 before A2) and submits it via
// run_n(taskflow, 10), blocking until the returned future is ready.
// The final REQUIRE is consistent with this adding 10 * 2 = 20 per call.
void run()
800+
{
801+
taskflow.clear();
802+
auto A1 = taskflow.emplace([&]() { counter++; });
803+
auto A2 = taskflow.emplace([&]() { counter++; });
804+
A1.precede(A2);
805+
executor.run_n(taskflow, 10).wait();
806+
}
807+
808+
};
809+
810+
// Middle level: owns an A and drives it from inside one of its own tasks.
struct B {
811+
812+
tf::Taskflow taskflow;
813+
tf::Executor executor;
814+
815+
int& counter;
816+
817+
A a_sim;
818+
819+
B(unsigned w, int& c) : executor{w}, counter{c}, a_sim{w, c} { }
820+
821+
// Task B2 increments the counter and then calls a_sim.run(), so a task
// submitted to B's executor blocks on work submitted to A's executor.
// Consistent with the final REQUIRE: 100 * (2 + 20) = 2200 per call.
void run()
822+
{
823+
taskflow.clear();
824+
auto B1 = taskflow.emplace([&] () { ++counter; });
825+
auto B2 = taskflow.emplace([&] () { ++counter; a_sim.run(); });
826+
B1.precede(B2);
827+
executor.run_n(taskflow, 100).wait();
828+
}
829+
};
830+
831+
// Outermost level: owns a B and drives it from inside one of its own tasks.
struct C {
832+
833+
tf::Taskflow taskflow;
834+
tf::Executor executor;
835+
836+
int& counter;
837+
838+
B b_sim;
839+
840+
C(unsigned w, int& c) : executor{w}, counter{c}, b_sim{w, c} { }
841+
842+
// Same shape as B::run, one level up: task C2 calls b_sim.run().
// Consistent with the final REQUIRE: 100 * (2 + 2200) = 220200 per call.
void run()
843+
{
844+
taskflow.clear();
845+
auto C1 = taskflow.emplace([&] () { ++counter; });
846+
auto C2 = taskflow.emplace([&] () { ++counter; b_sim.run(); });
847+
C1.precede(C2);
848+
executor.run_n(taskflow, 100).wait();
849+
}
850+
};
851+
852+
C c(w, counter);
853+
c.run();
854+
855+
// 100 * (2 + 100 * (2 + 10 * 2)) increments in total.
REQUIRE(counter == 220200);
856+
}
857+
858+
// doctest cases that call nested_runs with 1, 2, 3, 4, 8 and 16 as the
// argument; the case names match the -tc= filters registered for
// nested_runs.* in CMakeLists.txt.
// NOTE(review): unlike ParallelRuns.8threads, these cases carry no
// doctest::timeout decorator -- confirm whether that is intentional.
TEST_CASE("NestedRuns.1thread") {
859+
nested_runs(1);
860+
}
861+
862+
TEST_CASE("NestedRuns.2threads") {
863+
nested_runs(2);
864+
}
865+
866+
TEST_CASE("NestedRuns.3threads") {
867+
nested_runs(3);
868+
}
869+
870+
TEST_CASE("NestedRuns.4threads") {
871+
nested_runs(4);
872+
}
873+
874+
TEST_CASE("NestedRuns.8threads") {
875+
nested_runs(8);
876+
}
877+
878+
TEST_CASE("NestedRuns.16threads") {
879+
nested_runs(16);
880+
}
881+
882+
783883
// --------------------------------------------------------
784884
// Testcase: ParallelFor
785885
// --------------------------------------------------------

0 commit comments

Comments
 (0)