@@ -247,6 +247,7 @@ class BasicTaskflow : public FlowBuilder {
247247
248248 void _schedule (Node&);
249249 void _schedule (PassiveVector<Node*>&);
250+ void _set_module_node (Node*, Framework*);
250251};
251252
252253// ============================================================================
@@ -391,8 +392,17 @@ void BasicTaskflow<E>::Closure::operator () () const {
391392 // regular node type
392393 // The default node work type. We only need to execute the callback if any.
393394 if (auto index=node->_work .index (); index == 0 ) {
394- if (auto &f = std::get<StaticWork>(node->_work ); f != nullptr ){
395- std::invoke (f);
395+ if (node->is_module ()) {
396+ bool first_time = !node->is_spawned ();
397+ std::invoke (std::get<StaticWork>(node->_work ));
398+ if (first_time) {
399+ return ;
400+ }
401+ }
402+ else {
403+ if (auto &f = std::get<StaticWork>(node->_work ); f != nullptr ){
404+ std::invoke (f);
405+ }
396406 }
397407 }
398408 // subflow node type
@@ -451,7 +461,7 @@ void BasicTaskflow<E>::Closure::operator () () const {
451461 }
452462 }
453463 node->_num_dependents = node->_dependents .size ();
454- node->clear_status ();
464+ node->unset_spawned ();
455465 }
456466
457467 // At this point, the node storage might be destructed.
@@ -621,6 +631,10 @@ void BasicTaskflow<E>::wait_for_topologies() {
621631// Each task node has two types of tasks - regular and subflow.
622632template <template <typename ...> typename E>
623633void BasicTaskflow<E>::_schedule(Node& node) {
634+ if (node.is_module () && !node.is_spawned ()) {
635+ _set_module_node (&node, node._module );
636+ assert (node._work .index () == 0 );
637+ }
624638 _executor->emplace (*this , node);
625639}
626640
@@ -633,11 +647,51 @@ void BasicTaskflow<E>::_schedule(PassiveVector<Node*>& nodes) {
633647 std::vector<Closure> closures;
634648 closures.reserve (nodes.size ());
635649 for (auto src : nodes) {
650+ if (src->is_module () && !src->is_spawned ()) {
651+ assert (src->_module != nullptr );
652+ _set_module_node (src, src->_module );
653+ assert (src->_work .index () == 0 );
654+ }
636655 closures.emplace_back (*this , *src);
637656 }
638657 _executor->batch (closures);
639658}
640659
660+
661+ template <template <typename ...> typename E>
662+ void BasicTaskflow<E>::_set_module_node(Node* n, Framework* f) {
663+
664+ n->_work = [node=n, this , tgt {PassiveVector<Node*>()}] () mutable {
665+
666+ // second time to enter this context
667+ if (node->is_spawned ()) {
668+ node->_dependents .resize (node->_dependents .size ()-tgt.size ());
669+ for (auto & t: tgt) {
670+ t->_successors .clear ();
671+ }
672+ return ;
673+ }
674+ // first time to enter this context
675+ node->set_spawned ();
676+
677+ PassiveVector<Node*> src;
678+
679+ for (auto &n: node->_module ->_graph ) {
680+ n._topology = node->_topology ;
681+ if (n.num_dependents () == 0 ) {
682+ src.push_back (&n);
683+ }
684+ if (n.num_successors () == 0 ) {
685+ n.precede (*node);
686+ tgt.push_back (&n);
687+ }
688+ }
689+
690+ _schedule (src);
691+ };
692+ }
693+
694+
641695// Function: dump_topologies
642696template <template <typename ...> typename E>
643697std::string BasicTaskflow<E>::dump_topologies() const {
0 commit comments