22#include < taskflow/taskflow.hpp>
33
44
5- class TF_DNNTrainingPattern : public tf ::Framework {
6-
7- public:
5+ struct TF_DNNTrainingPattern : public tf ::Framework {
86
97 TF_DNNTrainingPattern () {
10- init_dnn (_dnn );
11- _build_task_graph ();
8+ init_dnn (dnn );
9+ build_task_graph ();
1210 };
1311
1412 void validate (Eigen::MatrixXf &mat, Eigen::VectorXi &vec) {
15- _dnn .validate (mat, vec);
13+ dnn .validate (mat, vec);
1614 }
1715
18- private:
1916
20- MNIST_DNN _dnn;
2117
22- void _build_task_graph () {
18+ void build_task_graph () {
2319 auto f_task = emplace (
24- [&]() { forward_task (_dnn , IMAGES, LABELS); }
20+ [&]() { forward_task (dnn , IMAGES, LABELS); }
2521 );
2622
2723 std::vector<tf::Task> backward_tasks;
2824 std::vector<tf::Task> update_tasks;
2925
30- for (int j=_dnn .acts .size ()-1 ; j>=0 ; j--) {
26+ for (int j=dnn .acts .size ()-1 ; j>=0 ; j--) {
3127 // backward propagation
3228 auto & b_task = backward_tasks.emplace_back (emplace (
33- [&, i=j] () { backward_task (_dnn , i, IMAGES); }
29+ [&, i=j] () { backward_task (dnn , i, IMAGES); }
3430 ));
3531
3632 // update weight
3733 auto & u_task = update_tasks.emplace_back (
38- emplace ([&, i=j] () { _dnn .update (i); })
34+ emplace ([&, i=j] () { dnn .update (i); })
3935 );
4036
41- if (j + 1u == _dnn .acts .size ()) {
37+ if (j + 1u == dnn .acts .size ()) {
4238 f_task.precede (b_task);
4339 }
4440 else {
@@ -47,35 +43,29 @@ class TF_DNNTrainingPattern : public tf::Framework {
4743 b_task.precede (u_task);
4844 }
4945 }
50- };
51-
52-
5346
47+ MNIST_DNN dnn;
48+ };
5449
55- class DNN : public tf ::Framework {
56- public:
50+ struct TF_DNNTrainingEpoch : public tf ::Framework {
5751
58- DNN (TF_DNNTrainingPattern &dnn_pattern) : _dnn_pattern(dnn_pattern) {
52+ TF_DNNTrainingEpoch (TF_DNNTrainingPattern &dnn_pattern) {
5953 std::vector<tf::Task> tasks;
6054 for (auto i=0u ; i<NUM_ITERATIONS; i++) {
61- tasks.emplace_back (composed_of (_dnn_pattern ));
55+ tasks.emplace_back (composed_of (dnn_pattern ));
6256 }
6357 linearize (tasks);
6458 }
65-
66- private:
67-
68- TF_DNNTrainingPattern& _dnn_pattern;
6959};
7060
7161void run_taskflow (unsigned num_epochs, unsigned num_threads) {
7262
73- tf::Taskflow tf {4 };
63+ tf::Taskflow tf {num_threads };
7464 auto dnn_patterns = std::make_unique<TF_DNNTrainingPattern[]>(NUM_DNNS);
75- auto dnns = std::make_unique<std::unique_ptr<DNN >[]>(NUM_DNNS);
65+ auto dnns = std::make_unique<std::unique_ptr<TF_DNNTrainingEpoch >[]>(NUM_DNNS);
7666
7767 for (size_t i=0 ; i<NUM_DNNS; i++) {
78- dnns[i] = std::make_unique<DNN >(dnn_patterns[i]);
68+ dnns[i] = std::make_unique<TF_DNNTrainingEpoch >(dnn_patterns[i]);
7969 }
8070
8171 std::vector<tf::Task> tasks;
@@ -84,17 +74,18 @@ void run_taskflow(unsigned num_epochs, unsigned num_threads) {
8474 parallel_dnn.composed_of (*(dnns[i]));
8575 }
8676
87- auto t1 = std::chrono::high_resolution_clock::now ();
77+ // auto t1 = std::chrono::high_resolution_clock::now();
8878 parallel_dnn.emplace ([&](){
8979 for (size_t i=0 ; i<NUM_DNNS; i++) {
90- std::cout << " Validate " << i << " th NN: " ;
80+ // std::cout << "Validate " << i << "th NN: ";
9181 dnn_patterns[i].validate (TEST_IMAGES, TEST_LABELS);
9282 }
9383 shuffle (IMAGES, LABELS);
94- report_runtime (t1);
84+ // report_runtime(t1);
9585 }).gather (tasks);
9686
97- tf.run_n (parallel_dnn, 1 ).get ();
87+ tf.run_n (parallel_dnn, num_epochs).get ();
88+ // tf.run_n(parallel_dnn, 4).get();
9889 // std::cout << parallel_dnn.dump() << std::endl;
9990}
10091
0 commit comments