Skip to content

Commit a0bfb9f

Browse files
adding a separate file for GPU
1 parent 31fb2fc commit a0bfb9f

1 file changed

Lines changed: 285 additions & 0 deletions

File tree

cpp/distributed/dist-mnist-gpu.cpp

Lines changed: 285 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,285 @@
1+
#include <cuda.h>
2+
#include <stdint.h>
3+
#include <torch/torch.h>
4+
#include <unistd.h>
5+
#include <iostream>
6+
#include "cuda_runtime.h"
7+
#include "mpi.h"
8+
#include "nccl.h"
9+
10+
std::map<at::ScalarType, MPI_Datatype> mpiDatatype = {
11+
{at::kByte, MPI_UNSIGNED_CHAR},
12+
{at::kChar, MPI_CHAR},
13+
{at::kDouble, MPI_DOUBLE},
14+
{at::kFloat, MPI_FLOAT},
15+
{at::kInt, MPI_INT},
16+
{at::kLong, MPI_LONG},
17+
{at::kShort, MPI_SHORT},
18+
};
19+
20+
static uint64_t getHostHash(const char* string) {
21+
// Based on DJB2, result = result * 33 + char
22+
uint64_t result = 5381;
23+
for (int c = 0; string[c] != '\0'; c++) {
24+
result = ((result << 5) + result) + string[c];
25+
}
26+
return result;
27+
}
28+
29+
static void getHostName(char* hostname, int maxlen) {
30+
gethostname(hostname, maxlen);
31+
for (int i = 0; i < maxlen; i++) {
32+
if (hostname[i] == '.') {
33+
hostname[i] = '\0';
34+
return;
35+
}
36+
}
37+
}
38+
39+
// Define a Convolutional Module
40+
struct Model : torch::nn::Module {
41+
Model()
42+
: conv1(torch::nn::Conv2dOptions(1, 10, 5)),
43+
conv2(torch::nn::Conv2dOptions(10, 20, 5)),
44+
fc1(320, 50),
45+
fc2(50, 10) {
46+
register_module("conv1", conv1);
47+
register_module("conv2", conv2);
48+
register_module("conv2_drop", conv2_drop);
49+
register_module("fc1", fc1);
50+
register_module("fc2", fc2);
51+
}
52+
53+
torch::Tensor forward(torch::Tensor x) {
54+
x = torch::relu(torch::max_pool2d(conv1->forward(x), 2));
55+
x = torch::relu(
56+
torch::max_pool2d(conv2_drop->forward(conv2->forward(x)), 2));
57+
x = x.view({-1, 320});
58+
x = torch::relu(fc1->forward(x));
59+
x = torch::dropout(x, 0.5, is_training());
60+
x = fc2->forward(x);
61+
return torch::log_softmax(x, 1);
62+
}
63+
64+
torch::nn::Conv2d conv1;
65+
torch::nn::Conv2d conv2;
66+
torch::nn::Dropout2d conv2_drop;
67+
torch::nn::Linear fc1;
68+
torch::nn::Linear fc2;
69+
};
70+
71+
int main(int argc, char* argv[]) {
72+
// MPI variables
73+
int rank, numranks, localRank;
74+
MPI_Init(&argc, &argv);
75+
MPI_Comm_size(MPI_COMM_WORLD, &numranks);
76+
MPI_Comm_rank(MPI_COMM_WORLD, &rank);
77+
// end MPI variables
78+
79+
torch::Device device = (torch::kCPU);
80+
if (torch::cuda::is_available()) {
81+
std::cout << "CUDA is available! Training on GPU." << std::endl;
82+
device = torch::kCUDA;
83+
} else {
84+
std::cout << "CUDA not available. Training on CPU." << std::endl;
85+
}
86+
87+
// Calculating localRank based on hostname which is used in
88+
// selecting a GPU
89+
uint64_t hostHashs[numranks];
90+
char hostname[1024];
91+
getHostName(hostname, 1024);
92+
hostHashs[rank] = getHostHash(hostname);
93+
MPI_Allgather(
94+
MPI_IN_PLACE,
95+
0,
96+
MPI_DATATYPE_NULL,
97+
hostHashs,
98+
sizeof(uint64_t),
99+
MPI_BYTE,
100+
MPI_COMM_WORLD);
101+
for (int p = 0; p < numranks; p++) {
102+
if (p == rank)
103+
break;
104+
if (hostHashs[p] == hostHashs[rank])
105+
localRank++;
106+
}
107+
108+
ncclUniqueId id;
109+
ncclComm_t comm;
110+
float *sendbuff, *recvbuff;
111+
cudaStream_t s;
112+
113+
// get NCCL unique ID at rank 0 and broadcast it to all others
114+
if (rank == 0)
115+
ncclGetUniqueId(&id);
116+
MPI_Bcast((void*)&id, sizeof(id), MPI_BYTE, 0, MPI_COMM_WORLD);
117+
118+
// picking a GPU based on localRank, allocate device buffers
119+
cudaSetDevice(localRank);
120+
cudaStreamCreate(&s);
121+
122+
// initializing NCCL
123+
ncclCommInitRank(&comm, numranks, id, rank);
124+
125+
// Timer variables
126+
auto tstart = 0.0;
127+
auto tend = 0.0;
128+
129+
// TRAINING
130+
// Read train dataset
131+
const char* kDataRoot = "../data";
132+
auto train_dataset =
133+
torch::data::datasets::MNIST(kDataRoot)
134+
.map(torch::data::transforms::Normalize<>(0.1307, 0.3081))
135+
.map(torch::data::transforms::Stack<>());
136+
137+
// Distributed Random Sampler
138+
auto data_sampler = torch::data::samplers::DistributedRandomSampler(
139+
train_dataset.size().value(), numranks, rank, false);
140+
141+
auto num_train_samples_per_proc = train_dataset.size().value() / numranks;
142+
143+
// Generate dataloader
144+
auto total_batch_size = 64;
145+
auto batch_size_per_proc =
146+
total_batch_size / numranks; // effective batch size in each processor
147+
auto data_loader = torch::data::make_data_loader(
148+
std::move(train_dataset), data_sampler, batch_size_per_proc);
149+
150+
// setting manual seed
151+
torch::manual_seed(0);
152+
153+
auto model = std::make_shared<Model>();
154+
model->to(device);
155+
156+
auto learning_rate = 1e-2;
157+
158+
torch::optim::SGD optimizer(model->parameters(), learning_rate);
159+
160+
// Number of epochs
161+
size_t num_epochs = 10;
162+
163+
// start timer
164+
tstart = MPI_Wtime();
165+
166+
for (size_t epoch = 1; epoch <= num_epochs; ++epoch) {
167+
size_t num_correct = 0;
168+
169+
for (auto& batch : *data_loader) {
170+
auto ip = batch.data.to(device);
171+
auto op = batch.target.squeeze().to(device);
172+
173+
// convert to required formats
174+
ip = ip.to(torch::kF32);
175+
op = op.to(torch::kLong);
176+
177+
// Reset gradients
178+
model->zero_grad();
179+
180+
// Execute forward pass
181+
auto prediction = model->forward(ip);
182+
183+
auto loss = torch::nll_loss(torch::log_softmax(prediction, 1), op);
184+
185+
// Backpropagation
186+
loss.backward();
187+
188+
// Averaging the gradients of the parameters in all the processors
189+
// Note: This may lag behind DistributedDataParallel (DDP) in performance
190+
// since this synchronizes parameters after backward pass while DDP
191+
// overlaps synchronizing parameters and computing gradients in backward
192+
// pass
193+
194+
if (torch::cuda::is_available()) {
195+
for (auto& param : model->named_parameters()) {
196+
ncclAllReduce(
197+
param.value().grad().data_ptr(),
198+
param.value().grad().data_ptr(),
199+
param.value().grad().numel(),
200+
ncclFloat,
201+
ncclSum,
202+
comm,
203+
s);
204+
cudaStreamSynchronize(s);
205+
param.value().grad().data() = param.value().grad().data() / numranks;
206+
}
207+
} else {
208+
for (auto& param : model->named_parameters()) {
209+
MPI_Allreduce(
210+
MPI_IN_PLACE,
211+
param.value().grad().data_ptr(),
212+
param.value().grad().numel(),
213+
mpiDatatype.at(param.value().grad().scalar_type()),
214+
MPI_SUM,
215+
MPI_COMM_WORLD);
216+
217+
param.value().grad().data() = param.value().grad().data() / numranks;
218+
}
219+
}
220+
221+
// Update parameters
222+
optimizer.step();
223+
224+
auto guess = prediction.argmax(1);
225+
num_correct += torch::sum(guess.eq_(op)).item<int64_t>();
226+
} // end batch loader
227+
228+
auto accuracy = 100.0 * num_correct / num_train_samples_per_proc;
229+
230+
std::cout << "Accuracy in rank " << rank << " in epoch " << epoch << " - "
231+
<< accuracy << std::endl;
232+
233+
} // end epoch
234+
235+
// end timer
236+
tend = MPI_Wtime();
237+
if (rank == 0) {
238+
std::cout << "Training time - " << (tend - tstart) << std::endl;
239+
}
240+
241+
// TESTING ONLY IN RANK 0
242+
if (rank == 0) {
243+
auto test_dataset =
244+
torch::data::datasets::MNIST(
245+
kDataRoot, torch::data::datasets::MNIST::Mode::kTest)
246+
.map(torch::data::transforms::Normalize<>(0.1307, 0.3081))
247+
.map(torch::data::transforms::Stack<>());
248+
249+
auto num_test_samples = test_dataset.size().value();
250+
auto test_loader = torch::data::make_data_loader(
251+
std::move(test_dataset), num_test_samples);
252+
253+
model->eval(); // enable eval mode to prevent backprop
254+
255+
size_t num_correct = 0;
256+
257+
for (auto& batch : *test_loader) {
258+
auto ip = batch.data.to(device);
259+
auto op = batch.target.squeeze().to(device);
260+
261+
// convert to required format
262+
ip = ip.to(torch::kF32);
263+
op = op.to(torch::kLong);
264+
265+
auto prediction = model->forward(ip);
266+
267+
auto loss = torch::nll_loss(torch::log_softmax(prediction, 1), op);
268+
269+
std::cout << "Test loss - " << loss.item<float>() << std::endl;
270+
271+
auto guess = prediction.argmax(1);
272+
273+
num_correct += torch::sum(guess.eq_(op)).item<int64_t>();
274+
275+
} // end test loader
276+
277+
std::cout << "Num correct - " << num_correct << std::endl;
278+
std::cout << "Test Accuracy - " << 100.0 * num_correct / num_test_samples
279+
<< std::endl;
280+
} // end rank 0
281+
282+
ncclCommDestroy(comm);
283+
284+
MPI_Finalize();
285+
}

0 commit comments

Comments
 (0)