forked from taskflow/taskflow
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtaskflow.cpp
More file actions
92 lines (70 loc) · 2.29 KB
/
taskflow.cpp
File metadata and controls
92 lines (70 loc) · 2.29 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
#include "dnn.hpp"
#include <taskflow/taskflow.hpp>
struct TF_DNNTrainingPattern : public tf::Taskflow {
TF_DNNTrainingPattern() {
init_dnn(dnn, rand_rate());
build_task_graph();
};
void validate(Eigen::MatrixXf &mat, Eigen::VectorXi &vec) {
dnn.validate(mat, vec);
}
void build_task_graph() {
auto f_task = emplace(
[&]() { forward_task(dnn, IMAGES, LABELS); }
);
std::vector<tf::Task> backward_tasks;
std::vector<tf::Task> update_tasks;
for(int j=dnn.acts.size()-1; j>=0; j--) {
// backward propagation
auto& b_task = backward_tasks.emplace_back(emplace(
[&, i=j] () { backward_task(dnn, i, IMAGES); }
));
// update weight
auto& u_task = update_tasks.emplace_back(
emplace([&, i=j] () { dnn.update(i); })
);
if(j + 1u == dnn.acts.size()) {
f_task.precede(b_task);
}
else {
backward_tasks[backward_tasks.size()-2].precede(b_task);
}
b_task.precede(u_task);
}
}
MNIST_DNN dnn;
};
struct TF_DNNTrainingEpoch : public tf::Taskflow {
TF_DNNTrainingEpoch(TF_DNNTrainingPattern &dnn_pattern) {
std::vector<tf::Task> tasks;
for(auto i=0u; i<NUM_ITERATIONS; i++) {
tasks.emplace_back(composed_of(dnn_pattern));
}
linearize(tasks);
}
};
void run_taskflow(const unsigned num_epochs, const unsigned num_threads) {
tf::Executor executor(num_threads);
auto dnn_patterns = std::make_unique<TF_DNNTrainingPattern[]>(NUM_DNNS);
auto dnns = std::make_unique<std::unique_ptr<TF_DNNTrainingEpoch>[]>(NUM_DNNS);
for(size_t i=0; i<NUM_DNNS; i++) {
dnns[i] = std::make_unique<TF_DNNTrainingEpoch>(dnn_patterns[i]);
}
std::vector<tf::Task> tasks;
tf::Taskflow parallel_dnn;
for(size_t i=0; i<NUM_DNNS; i++) {
tasks.emplace_back(parallel_dnn.composed_of(*(dnns[i])));
}
//auto t1 = std::chrono::high_resolution_clock::now();
parallel_dnn.emplace([&](){
for(size_t i=0; i<NUM_DNNS; i++) {
//std::cout << "Validate " << i << "th NN: ";
dnn_patterns[i].validate(TEST_IMAGES, TEST_LABELS);
}
shuffle(IMAGES, LABELS);
//report_runtime(t1);
}).gather(tasks);
//std::cout << parallel_dnn.dump() << std::endl;
//tf.run_n(parallel_dnn, 100).get();
executor.run_n(parallel_dnn, num_epochs).get();
}