Skip to content

Instantly share code, notes, and snippets.

@yzchen
Created April 3, 2019 14:10
Show Gist options
  • Save yzchen/abc67bfe0bfd75674951f6eccc6909c1 to your computer and use it in GitHub Desktop.
Save yzchen/abc67bfe0bfd75674951f6eccc6909c1 to your computer and use it in GitHub Desktop.
Benchmark allreduce in MPI
#include <cerrno>
#include <chrono>
#include <climits>
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <ctime>
#include <iostream>
#include <vector>

#include <mpi.h>
using namespace std;
// Number of timed MPI_Allreduce calls whose durations are averaged in main().
static constexpr int NUMSTATS{100};
/**
 * Benchmarks MPI_Allreduce (MPI_SUM over MPI_FLOAT) on a buffer of `nelems`
 * floats and prints the mean elapsed time per call on rank 0.
 *
 * Usage: avg nelems
 *   nelems  non-negative decimal element count that fits in an int
 *
 * Exit status is 1 on a usage error and 0 otherwise.
 */
int main(int argc, char** argv) {
    if (argc != 2) {
        std::fprintf(stderr, "Usage: avg nelems\n");
        return 1;
    }

    // Parse strictly: atoi() silently maps garbage to 0 and lets negative
    // counts through to the vector constructor, which then throws.
    errno = 0;
    char* parse_end = nullptr;
    const long requested = std::strtol(argv[1], &parse_end, 10);
    const bool not_a_number = (parse_end == argv[1]) || (*parse_end != '\0');
    const bool out_of_range = (errno == ERANGE) || requested < 0 || requested > INT_MAX;
    if (not_a_number || out_of_range) {
        std::fprintf(stderr, "Usage: avg nelems\n");
        return 1;
    }
    const int nelems = static_cast<int>(requested);

    MPI_Init(nullptr, nullptr);

    int rank = 0;
    MPI_Comm_rank(MPI_COMM_WORLD, &rank);

    // Local contribution: values in [0, 1]. rand() is never seeded, so the
    // sequence is the default one on every run.
    std::vector<float> rand_nums(nelems);
    for (float& value : rand_nums) {
        value = static_cast<float>(std::rand()) / static_cast<float>(RAND_MAX);
    }

    // Receives the element-wise sum across all ranks.
    std::vector<float> all_sums(nelems, 0.f);

    // One untimed call so any one-time setup cost is not averaged in, then a
    // barrier so every rank enters the timed loop together.
    MPI_Allreduce(rand_nums.data(), all_sums.data(), nelems, MPI_FLOAT, MPI_SUM, MPI_COMM_WORLD);
    MPI_Barrier(MPI_COMM_WORLD);

    // steady_clock measures elapsed (wall) time; clock() reports processor
    // time, which is not what a communication benchmark should report.
    using Clock = std::chrono::steady_clock;
    Clock::duration total_elapsed = Clock::duration::zero();
    for (int iter = 0; iter < NUMSTATS; ++iter) {
        const Clock::time_point begin = Clock::now();
        MPI_Allreduce(rand_nums.data(), all_sums.data(), nelems, MPI_FLOAT, MPI_SUM, MPI_COMM_WORLD);
        total_elapsed += Clock::now() - begin;
    }

    // Only rank 0 reports, and it reports its own locally measured time.
    if (rank == 0) {
        const double total_seconds = std::chrono::duration<double>(total_elapsed).count();
        std::printf("average allreduce time : %f s\n", total_seconds / NUMSTATS);
    }

    MPI_Barrier(MPI_COMM_WORLD);
    MPI_Finalize();
    return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment