Skip to content

Commit f66f475

Browse files
committed
Refactor parallel DNN
1 parent a1181fd commit f66f475

5 files changed

Lines changed: 40 additions & 56 deletions

File tree

benchmark/parallel_dnn/dnn.hpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -242,7 +242,7 @@ struct MNIST {
242242

243243
void validate() {
244244
Eigen::MatrixXf res = test_images;
245-
auto t1 = std::chrono::high_resolution_clock::now();
245+
//auto t1 = std::chrono::high_resolution_clock::now();
246246
for(size_t i=0; i<acts.size(); i++) {
247247
res = res * Ws[i] + Bs[i].replicate(res.rows(), 1);
248248
if(acts[i] == Activation::RELU) {
@@ -252,8 +252,8 @@ struct MNIST {
252252
sigmoid(res);
253253
}
254254
}
255-
auto t2 = std::chrono::high_resolution_clock::now();
256-
std::cout << "Infer runtime: " << time_diff(t1, t2) << " ms\n";
255+
//auto t2 = std::chrono::high_resolution_clock::now();
256+
//std::cout << "Infer runtime: " << time_diff(t1, t2) << " ms\n";
257257

258258
size_t correct_num {0};
259259
for(int k=0; k<res.rows(); k++) {
@@ -263,7 +263,7 @@ struct MNIST {
263263
correct_num ++;
264264
}
265265
}
266-
std::cout << "Accuracy: " << correct_num << '/' << res.rows() << '\n';
266+
//std::cout << "Accuracy: " << correct_num << '/' << res.rows() << '\n';
267267
}
268268

269269

@@ -406,7 +406,7 @@ struct MNIST_DNN {
406406
correct_num ++;
407407
}
408408
}
409-
std::cout << "Accuracy: " << correct_num << '/' << res.rows() << '\n';
409+
//std::cout << "Accuracy: " << correct_num << '/' << res.rows() << '\n';
410410
}
411411

412412
// Parameter functions ------------------------------------------------------

benchmark/parallel_dnn/main.cpp

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@ std::chrono::milliseconds measure_time_taskflow(
88
unsigned num_threads
99
) {
1010
std::puts("Taskflow");
11-
//auto dnn {build_dnn(num_epochs)};
1211
auto t1 = std::chrono::high_resolution_clock::now();
1312
run_taskflow(num_epochs, num_threads);
1413
auto t2 = std::chrono::high_resolution_clock::now();
@@ -33,10 +32,7 @@ std::chrono::milliseconds measure_time_tbb(
3332
unsigned num_threads
3433
) {
3534
std::puts("TBB");
36-
//auto dnn {build_dnn(num_epochs)};
3735
auto t1 = std::chrono::high_resolution_clock::now();
38-
//run_tbb(dnn, num_threads);
39-
//run_tbb(num_epochs, num_threads);
4036
run_tbb(num_epochs, num_threads);
4137
auto t2 = std::chrono::high_resolution_clock::now();
4238
return std::chrono::duration_cast<std::chrono::milliseconds>(t2 - t1);
@@ -88,10 +84,10 @@ int main(int argc, char *argv[]){
8884
double tf_time {0.0};
8985

9086
for(int j=0; j<rounds; ++j) {
91-
//omp_time += measure_time_omp(epoch, num_threads).count();
92-
//tbb_time += measure_time_tbb(epoch, num_threads).count();
87+
omp_time += measure_time_omp(epoch, num_threads).count();
88+
tbb_time += measure_time_tbb(epoch, num_threads).count();
9389
tf_time += measure_time_taskflow(epoch, num_threads).count();
94-
exit(0);
90+
//exit(0);
9591
}
9692

9793
std::cout << std::setw(12) << epoch

benchmark/parallel_dnn/omp.cpp

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@ void omp_dnn(MNIST_DNN& D, unsigned num_iteration) {
1515
#pragma omp single
1616
{
1717
for(auto i=0u; i<num_iteration; i++) {
18-
//for(auto i=0u; i<2; i++) {
1918
// Forward Task
2019
if(i == 0) {
2120
#pragma omp task depend (out: dep_f[i]) shared(D, IMAGES, LABELS)
@@ -94,15 +93,14 @@ void run_omp(unsigned num_epochs, unsigned num_threads) {
9493
init_dnn(dnns[i]);
9594
}
9695

97-
//omp_set_num_threads(num_threads);
98-
omp_set_num_threads(4);
96+
omp_set_num_threads(num_threads);
9997

100-
auto t1 = std::chrono::high_resolution_clock::now();
98+
//auto t1 = std::chrono::high_resolution_clock::now();
10199
#pragma omp parallel
102100
{
103101
#pragma omp single
104102
{
105-
for(auto i=0u; i<100; i++) {
103+
for(auto i=0u; i<num_epochs; i++) {
106104
for(auto j=0u; j<NUM_DNNS; j++) {
107105
#pragma omp task firstprivate(j) shared(dnns)
108106
{
@@ -112,11 +110,11 @@ void run_omp(unsigned num_epochs, unsigned num_threads) {
112110
#pragma omp taskwait
113111

114112
for(auto j=0u; j<NUM_DNNS; j++) {
115-
std::cout << "Validate " << j << "th NN: ";
113+
//std::cout << "Validate " << j << "th NN: ";
116114
dnns[j].validate(TEST_IMAGES, TEST_LABELS);
117115
}
118116
shuffle(IMAGES, LABELS);
119-
report_runtime(t1);
117+
//report_runtime(t1);
120118
}
121119
}
122120
}

benchmark/parallel_dnn/taskflow.cpp

Lines changed: 23 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -2,43 +2,39 @@
22
#include <taskflow/taskflow.hpp>
33

44

5-
class TF_DNNTrainingPattern : public tf::Framework {
6-
7-
public:
5+
struct TF_DNNTrainingPattern : public tf::Framework {
86

97
TF_DNNTrainingPattern() {
10-
init_dnn(_dnn);
11-
_build_task_graph();
8+
init_dnn(dnn);
9+
build_task_graph();
1210
};
1311

1412
void validate(Eigen::MatrixXf &mat, Eigen::VectorXi &vec) {
15-
_dnn.validate(mat, vec);
13+
dnn.validate(mat, vec);
1614
}
1715

18-
private:
1916

20-
MNIST_DNN _dnn;
2117

22-
void _build_task_graph() {
18+
void build_task_graph() {
2319
auto f_task = emplace(
24-
[&]() { forward_task(_dnn, IMAGES, LABELS); }
20+
[&]() { forward_task(dnn, IMAGES, LABELS); }
2521
);
2622

2723
std::vector<tf::Task> backward_tasks;
2824
std::vector<tf::Task> update_tasks;
2925

30-
for(int j=_dnn.acts.size()-1; j>=0; j--) {
26+
for(int j=dnn.acts.size()-1; j>=0; j--) {
3127
// backward propagation
3228
auto& b_task = backward_tasks.emplace_back(emplace(
33-
[&, i=j] () { backward_task(_dnn, i, IMAGES); }
29+
[&, i=j] () { backward_task(dnn, i, IMAGES); }
3430
));
3531

3632
// update weight
3733
auto& u_task = update_tasks.emplace_back(
38-
emplace([&, i=j] () { _dnn.update(i); })
34+
emplace([&, i=j] () { dnn.update(i); })
3935
);
4036

41-
if(j + 1u == _dnn.acts.size()) {
37+
if(j + 1u == dnn.acts.size()) {
4238
f_task.precede(b_task);
4339
}
4440
else {
@@ -47,35 +43,29 @@ class TF_DNNTrainingPattern : public tf::Framework {
4743
b_task.precede(u_task);
4844
}
4945
}
50-
};
51-
52-
5346

47+
MNIST_DNN dnn;
48+
};
5449

55-
class DNN : public tf::Framework {
56-
public:
50+
struct TF_DNNTrainingEpoch : public tf::Framework {
5751

58-
DNN(TF_DNNTrainingPattern &dnn_pattern) : _dnn_pattern(dnn_pattern) {
52+
TF_DNNTrainingEpoch(TF_DNNTrainingPattern &dnn_pattern) {
5953
std::vector<tf::Task> tasks;
6054
for(auto i=0u; i<NUM_ITERATIONS; i++) {
61-
tasks.emplace_back(composed_of(_dnn_pattern));
55+
tasks.emplace_back(composed_of(dnn_pattern));
6256
}
6357
linearize(tasks);
6458
}
65-
66-
private:
67-
68-
TF_DNNTrainingPattern& _dnn_pattern;
6959
};
7060

7161
void run_taskflow(unsigned num_epochs, unsigned num_threads) {
7262

73-
tf::Taskflow tf {4};
63+
tf::Taskflow tf {num_threads};
7464
auto dnn_patterns = std::make_unique<TF_DNNTrainingPattern[]>(NUM_DNNS);
75-
auto dnns = std::make_unique<std::unique_ptr<DNN>[]>(NUM_DNNS);
65+
auto dnns = std::make_unique<std::unique_ptr<TF_DNNTrainingEpoch>[]>(NUM_DNNS);
7666

7767
for(size_t i=0; i<NUM_DNNS; i++) {
78-
dnns[i] = std::make_unique<DNN>(dnn_patterns[i]);
68+
dnns[i] = std::make_unique<TF_DNNTrainingEpoch>(dnn_patterns[i]);
7969
}
8070

8171
std::vector<tf::Task> tasks;
@@ -84,17 +74,18 @@ void run_taskflow(unsigned num_epochs, unsigned num_threads) {
8474
parallel_dnn.composed_of(*(dnns[i]));
8575
}
8676

87-
auto t1 = std::chrono::high_resolution_clock::now();
77+
//auto t1 = std::chrono::high_resolution_clock::now();
8878
parallel_dnn.emplace([&](){
8979
for(size_t i=0; i<NUM_DNNS; i++) {
90-
std::cout << "Validate " << i << "th NN: ";
80+
//std::cout << "Validate " << i << "th NN: ";
9181
dnn_patterns[i].validate(TEST_IMAGES, TEST_LABELS);
9282
}
9383
shuffle(IMAGES, LABELS);
94-
report_runtime(t1);
84+
//report_runtime(t1);
9585
}).gather(tasks);
9686

97-
tf.run_n(parallel_dnn, 1).get();
87+
tf.run_n(parallel_dnn, num_epochs).get();
88+
//tf.run_n(parallel_dnn, 4).get();
9889
//std::cout << parallel_dnn.dump() << std::endl;
9990
}
10091

benchmark/parallel_dnn/tbb.cpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ void run_tbb(unsigned num_epochs, unsigned num_threads) {
8181
auto sync_node = std::make_unique<continue_node<continue_msg>>(parallel_dnn,
8282
[&](const continue_msg&) {
8383
for(size_t i=0; i<NUM_DNNS; i++) {
84-
std::cout << "Validate " << i << "th NN: ";
84+
//std::cout << "Validate " << i << "th NN: ";
8585
dnn_patterns[i].dnn.validate(TEST_IMAGES, TEST_LABELS);
8686
}
8787
shuffle(IMAGES, LABELS);
@@ -92,13 +92,12 @@ void run_tbb(unsigned num_epochs, unsigned num_threads) {
9292
make_edge(*(dnns[i]), *sync_node);
9393
}
9494

95-
auto t1 = std::chrono::high_resolution_clock::now();
96-
for(auto i=0u; i<100; i++) {
95+
//auto t1 = std::chrono::high_resolution_clock::now();
96+
for(auto i=0u; i<num_epochs; i++) {
9797
for(auto i=0u; i<NUM_DNNS; i++) {
9898
dnns[i]->try_put(continue_msg());
9999
}
100-
101100
parallel_dnn.wait_for_all();
102-
report_runtime(t1);
101+
//report_runtime(t1);
103102
}
104103
}

0 commit comments

Comments
 (0)