|
| 1 | +#include <cuda.h> |
| 2 | +#include <stdint.h> |
| 3 | +#include <torch/torch.h> |
| 4 | +#include <unistd.h> |
| 5 | +#include <iostream> |
| 6 | +#include "cuda_runtime.h" |
| 7 | +#include "mpi.h" |
| 8 | +#include "nccl.h" |
| 9 | + |
| 10 | +std::map<at::ScalarType, MPI_Datatype> mpiDatatype = { |
| 11 | + {at::kByte, MPI_UNSIGNED_CHAR}, |
| 12 | + {at::kChar, MPI_CHAR}, |
| 13 | + {at::kDouble, MPI_DOUBLE}, |
| 14 | + {at::kFloat, MPI_FLOAT}, |
| 15 | + {at::kInt, MPI_INT}, |
| 16 | + {at::kLong, MPI_LONG}, |
| 17 | + {at::kShort, MPI_SHORT}, |
| 18 | +}; |
| 19 | + |
| 20 | +static uint64_t getHostHash(const char* string) { |
| 21 | + // Based on DJB2, result = result * 33 + char |
| 22 | + uint64_t result = 5381; |
| 23 | + for (int c = 0; string[c] != '\0'; c++) { |
| 24 | + result = ((result << 5) + result) + string[c]; |
| 25 | + } |
| 26 | + return result; |
| 27 | +} |
| 28 | + |
| 29 | +static void getHostName(char* hostname, int maxlen) { |
| 30 | + gethostname(hostname, maxlen); |
| 31 | + for (int i = 0; i < maxlen; i++) { |
| 32 | + if (hostname[i] == '.') { |
| 33 | + hostname[i] = '\0'; |
| 34 | + return; |
| 35 | + } |
| 36 | + } |
| 37 | +} |
| 38 | + |
| 39 | +// Define a Convolutional Module |
| 40 | +struct Model : torch::nn::Module { |
| 41 | + Model() |
| 42 | + : conv1(torch::nn::Conv2dOptions(1, 10, 5)), |
| 43 | + conv2(torch::nn::Conv2dOptions(10, 20, 5)), |
| 44 | + fc1(320, 50), |
| 45 | + fc2(50, 10) { |
| 46 | + register_module("conv1", conv1); |
| 47 | + register_module("conv2", conv2); |
| 48 | + register_module("conv2_drop", conv2_drop); |
| 49 | + register_module("fc1", fc1); |
| 50 | + register_module("fc2", fc2); |
| 51 | + } |
| 52 | + |
| 53 | + torch::Tensor forward(torch::Tensor x) { |
| 54 | + x = torch::relu(torch::max_pool2d(conv1->forward(x), 2)); |
| 55 | + x = torch::relu( |
| 56 | + torch::max_pool2d(conv2_drop->forward(conv2->forward(x)), 2)); |
| 57 | + x = x.view({-1, 320}); |
| 58 | + x = torch::relu(fc1->forward(x)); |
| 59 | + x = torch::dropout(x, 0.5, is_training()); |
| 60 | + x = fc2->forward(x); |
| 61 | + return torch::log_softmax(x, 1); |
| 62 | + } |
| 63 | + |
| 64 | + torch::nn::Conv2d conv1; |
| 65 | + torch::nn::Conv2d conv2; |
| 66 | + torch::nn::Dropout2d conv2_drop; |
| 67 | + torch::nn::Linear fc1; |
| 68 | + torch::nn::Linear fc2; |
| 69 | +}; |
| 70 | + |
| 71 | +int main(int argc, char* argv[]) { |
| 72 | + // MPI variables |
| 73 | + int rank, numranks, localRank; |
| 74 | + MPI_Init(&argc, &argv); |
| 75 | + MPI_Comm_size(MPI_COMM_WORLD, &numranks); |
| 76 | + MPI_Comm_rank(MPI_COMM_WORLD, &rank); |
| 77 | + // end MPI variables |
| 78 | + |
| 79 | + torch::Device device = (torch::kCPU); |
| 80 | + if (torch::cuda::is_available()) { |
| 81 | + std::cout << "CUDA is available! Training on GPU." << std::endl; |
| 82 | + device = torch::kCUDA; |
| 83 | + } else { |
| 84 | + std::cout << "CUDA not available. Training on CPU." << std::endl; |
| 85 | + } |
| 86 | + |
| 87 | + // Calculating localRank based on hostname which is used in |
| 88 | + // selecting a GPU |
| 89 | + uint64_t hostHashs[numranks]; |
| 90 | + char hostname[1024]; |
| 91 | + getHostName(hostname, 1024); |
| 92 | + hostHashs[rank] = getHostHash(hostname); |
| 93 | + MPI_Allgather( |
| 94 | + MPI_IN_PLACE, |
| 95 | + 0, |
| 96 | + MPI_DATATYPE_NULL, |
| 97 | + hostHashs, |
| 98 | + sizeof(uint64_t), |
| 99 | + MPI_BYTE, |
| 100 | + MPI_COMM_WORLD); |
| 101 | + for (int p = 0; p < numranks; p++) { |
| 102 | + if (p == rank) |
| 103 | + break; |
| 104 | + if (hostHashs[p] == hostHashs[rank]) |
| 105 | + localRank++; |
| 106 | + } |
| 107 | + |
| 108 | + ncclUniqueId id; |
| 109 | + ncclComm_t comm; |
| 110 | + float *sendbuff, *recvbuff; |
| 111 | + cudaStream_t s; |
| 112 | + |
| 113 | + // get NCCL unique ID at rank 0 and broadcast it to all others |
| 114 | + if (rank == 0) |
| 115 | + ncclGetUniqueId(&id); |
| 116 | + MPI_Bcast((void*)&id, sizeof(id), MPI_BYTE, 0, MPI_COMM_WORLD); |
| 117 | + |
| 118 | + // picking a GPU based on localRank, allocate device buffers |
| 119 | + cudaSetDevice(localRank); |
| 120 | + cudaStreamCreate(&s); |
| 121 | + |
| 122 | + // initializing NCCL |
| 123 | + ncclCommInitRank(&comm, numranks, id, rank); |
| 124 | + |
| 125 | + // Timer variables |
| 126 | + auto tstart = 0.0; |
| 127 | + auto tend = 0.0; |
| 128 | + |
| 129 | + // TRAINING |
| 130 | + // Read train dataset |
| 131 | + const char* kDataRoot = "../data"; |
| 132 | + auto train_dataset = |
| 133 | + torch::data::datasets::MNIST(kDataRoot) |
| 134 | + .map(torch::data::transforms::Normalize<>(0.1307, 0.3081)) |
| 135 | + .map(torch::data::transforms::Stack<>()); |
| 136 | + |
| 137 | + // Distributed Random Sampler |
| 138 | + auto data_sampler = torch::data::samplers::DistributedRandomSampler( |
| 139 | + train_dataset.size().value(), numranks, rank, false); |
| 140 | + |
| 141 | + auto num_train_samples_per_proc = train_dataset.size().value() / numranks; |
| 142 | + |
| 143 | + // Generate dataloader |
| 144 | + auto total_batch_size = 64; |
| 145 | + auto batch_size_per_proc = |
| 146 | + total_batch_size / numranks; // effective batch size in each processor |
| 147 | + auto data_loader = torch::data::make_data_loader( |
| 148 | + std::move(train_dataset), data_sampler, batch_size_per_proc); |
| 149 | + |
| 150 | + // setting manual seed |
| 151 | + torch::manual_seed(0); |
| 152 | + |
| 153 | + auto model = std::make_shared<Model>(); |
| 154 | + model->to(device); |
| 155 | + |
| 156 | + auto learning_rate = 1e-2; |
| 157 | + |
| 158 | + torch::optim::SGD optimizer(model->parameters(), learning_rate); |
| 159 | + |
| 160 | + // Number of epochs |
| 161 | + size_t num_epochs = 10; |
| 162 | + |
| 163 | + // start timer |
| 164 | + tstart = MPI_Wtime(); |
| 165 | + |
| 166 | + for (size_t epoch = 1; epoch <= num_epochs; ++epoch) { |
| 167 | + size_t num_correct = 0; |
| 168 | + |
| 169 | + for (auto& batch : *data_loader) { |
| 170 | + auto ip = batch.data.to(device); |
| 171 | + auto op = batch.target.squeeze().to(device); |
| 172 | + |
| 173 | + // convert to required formats |
| 174 | + ip = ip.to(torch::kF32); |
| 175 | + op = op.to(torch::kLong); |
| 176 | + |
| 177 | + // Reset gradients |
| 178 | + model->zero_grad(); |
| 179 | + |
| 180 | + // Execute forward pass |
| 181 | + auto prediction = model->forward(ip); |
| 182 | + |
| 183 | + auto loss = torch::nll_loss(torch::log_softmax(prediction, 1), op); |
| 184 | + |
| 185 | + // Backpropagation |
| 186 | + loss.backward(); |
| 187 | + |
| 188 | + // Averaging the gradients of the parameters in all the processors |
| 189 | + // Note: This may lag behind DistributedDataParallel (DDP) in performance |
| 190 | + // since this synchronizes parameters after backward pass while DDP |
| 191 | + // overlaps synchronizing parameters and computing gradients in backward |
| 192 | + // pass |
| 193 | + |
| 194 | + if (torch::cuda::is_available()) { |
| 195 | + for (auto& param : model->named_parameters()) { |
| 196 | + ncclAllReduce( |
| 197 | + param.value().grad().data_ptr(), |
| 198 | + param.value().grad().data_ptr(), |
| 199 | + param.value().grad().numel(), |
| 200 | + ncclFloat, |
| 201 | + ncclSum, |
| 202 | + comm, |
| 203 | + s); |
| 204 | + cudaStreamSynchronize(s); |
| 205 | + param.value().grad().data() = param.value().grad().data() / numranks; |
| 206 | + } |
| 207 | + } else { |
| 208 | + for (auto& param : model->named_parameters()) { |
| 209 | + MPI_Allreduce( |
| 210 | + MPI_IN_PLACE, |
| 211 | + param.value().grad().data_ptr(), |
| 212 | + param.value().grad().numel(), |
| 213 | + mpiDatatype.at(param.value().grad().scalar_type()), |
| 214 | + MPI_SUM, |
| 215 | + MPI_COMM_WORLD); |
| 216 | + |
| 217 | + param.value().grad().data() = param.value().grad().data() / numranks; |
| 218 | + } |
| 219 | + } |
| 220 | + |
| 221 | + // Update parameters |
| 222 | + optimizer.step(); |
| 223 | + |
| 224 | + auto guess = prediction.argmax(1); |
| 225 | + num_correct += torch::sum(guess.eq_(op)).item<int64_t>(); |
| 226 | + } // end batch loader |
| 227 | + |
| 228 | + auto accuracy = 100.0 * num_correct / num_train_samples_per_proc; |
| 229 | + |
| 230 | + std::cout << "Accuracy in rank " << rank << " in epoch " << epoch << " - " |
| 231 | + << accuracy << std::endl; |
| 232 | + |
| 233 | + } // end epoch |
| 234 | + |
| 235 | + // end timer |
| 236 | + tend = MPI_Wtime(); |
| 237 | + if (rank == 0) { |
| 238 | + std::cout << "Training time - " << (tend - tstart) << std::endl; |
| 239 | + } |
| 240 | + |
| 241 | + // TESTING ONLY IN RANK 0 |
| 242 | + if (rank == 0) { |
| 243 | + auto test_dataset = |
| 244 | + torch::data::datasets::MNIST( |
| 245 | + kDataRoot, torch::data::datasets::MNIST::Mode::kTest) |
| 246 | + .map(torch::data::transforms::Normalize<>(0.1307, 0.3081)) |
| 247 | + .map(torch::data::transforms::Stack<>()); |
| 248 | + |
| 249 | + auto num_test_samples = test_dataset.size().value(); |
| 250 | + auto test_loader = torch::data::make_data_loader( |
| 251 | + std::move(test_dataset), num_test_samples); |
| 252 | + |
| 253 | + model->eval(); // enable eval mode to prevent backprop |
| 254 | + |
| 255 | + size_t num_correct = 0; |
| 256 | + |
| 257 | + for (auto& batch : *test_loader) { |
| 258 | + auto ip = batch.data.to(device); |
| 259 | + auto op = batch.target.squeeze().to(device); |
| 260 | + |
| 261 | + // convert to required format |
| 262 | + ip = ip.to(torch::kF32); |
| 263 | + op = op.to(torch::kLong); |
| 264 | + |
| 265 | + auto prediction = model->forward(ip); |
| 266 | + |
| 267 | + auto loss = torch::nll_loss(torch::log_softmax(prediction, 1), op); |
| 268 | + |
| 269 | + std::cout << "Test loss - " << loss.item<float>() << std::endl; |
| 270 | + |
| 271 | + auto guess = prediction.argmax(1); |
| 272 | + |
| 273 | + num_correct += torch::sum(guess.eq_(op)).item<int64_t>(); |
| 274 | + |
| 275 | + } // end test loader |
| 276 | + |
| 277 | + std::cout << "Num correct - " << num_correct << std::endl; |
| 278 | + std::cout << "Test Accuracy - " << 100.0 * num_correct / num_test_samples |
| 279 | + << std::endl; |
| 280 | + } // end rank 0 |
| 281 | + |
| 282 | + ncclCommDestroy(comm); |
| 283 | + |
| 284 | + MPI_Finalize(); |
| 285 | +} |
0 commit comments