Skip to content

Commit

Permalink
Clean up mpi-driver
Browse files Browse the repository at this point in the history
  • Loading branch information
thomasgibson committed Aug 12, 2022
1 parent 738af0f commit ffe986d
Showing 1 changed file with 27 additions and 49 deletions.
76 changes: 27 additions & 49 deletions src/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,20 +87,22 @@ int main(int argc, char *argv[])
{
#if USE_MPI
int provided;
int localRank;

MPI_Init_thread(&argc, &argv, MPI_THREAD_FUNNELED, &provided);
if (provided < MPI_THREAD_FUNNELED) {

if (provided < MPI_THREAD_FUNNELED)
MPI_Abort(MPI_COMM_WORLD, provided);
}

MPI_Comm_rank(MPI_COMM_WORLD, &rank);
MPI_Comm_size(MPI_COMM_WORLD, &procs);

// Each local rank on a given node will own a single device/GCD
MPI_Comm shmcomm;
// Each rank will run the benchmark on a single device
MPI_Comm shared_comm;
MPI_Comm_split_type(MPI_COMM_WORLD, MPI_COMM_TYPE_SHARED, 0,
MPI_INFO_NULL, &shmcomm);
int localRank;
MPI_Comm_rank(shmcomm, &localRank);
MPI_INFO_NULL, &shared_comm);
MPI_Comm_rank(shared_comm, &localRank);

// Set device index to be the local MPI rank
deviceIndex = localRank;
#endif
Expand All @@ -110,16 +112,17 @@ int main(int argc, char *argv[])
if (!output_as_csv)
{
#if USE_MPI
if (rank == 0) {
if (rank == 0)
#endif
{
std::cout
<< "BabelStream" << std::endl
<< "Version: " << VERSION_STRING << std::endl
<< "Implementation: " << IMPLEMENTATION_STRING << std::endl;
#if USE_MPI
std::cout << "Number of MPI ranks: " << procs << std::endl;
}
#endif
}
}

if (use_float)
Expand All @@ -145,54 +148,48 @@ std::vector<std::vector<double>> run_all(Stream<T> *stream, T& sum)
// Declare timers
std::chrono::high_resolution_clock::time_point t1, t2;

#if USE_MPI
// Set MPI data type for the dot-product reduction
MPI_Datatype MPI_DTYPE = use_float ? MPI_FLOAT : MPI_DOUBLE;
#endif

// Main loop
for (unsigned int k = 0; k < num_times; k++)
{
#if USE_MPI
MPI_Barrier(MPI_COMM_WORLD);
#endif

// Execute Copy
t1 = std::chrono::high_resolution_clock::now();
stream->copy();
#if USE_MPI
MPI_Barrier(MPI_COMM_WORLD);
#endif
t2 = std::chrono::high_resolution_clock::now();
timings[0].push_back(std::chrono::duration_cast<std::chrono::duration<double> >(t2 - t1).count());

// Execute Mul
t1 = std::chrono::high_resolution_clock::now();
stream->mul();
#if USE_MPI
MPI_Barrier(MPI_COMM_WORLD);
#endif
t2 = std::chrono::high_resolution_clock::now();
timings[1].push_back(std::chrono::duration_cast<std::chrono::duration<double> >(t2 - t1).count());

// Execute Add
t1 = std::chrono::high_resolution_clock::now();
stream->add();
#if USE_MPI
MPI_Barrier(MPI_COMM_WORLD);
#endif
t2 = std::chrono::high_resolution_clock::now();
timings[2].push_back(std::chrono::duration_cast<std::chrono::duration<double> >(t2 - t1).count());

// Execute Triad
t1 = std::chrono::high_resolution_clock::now();
stream->triad();
#if USE_MPI
MPI_Barrier(MPI_COMM_WORLD);
#endif
t2 = std::chrono::high_resolution_clock::now();
timings[3].push_back(std::chrono::duration_cast<std::chrono::duration<double> >(t2 - t1).count());

// Execute Dot
#if USE_MPI
// Synchronize ranks before computing dot-product
MPI_Barrier(MPI_COMM_WORLD);
#endif
t1 = std::chrono::high_resolution_clock::now();
sum = stream->dot();
#if USE_MPI
MPI_Allreduce(MPI_IN_PLACE, &sum, 1, MPI_DOUBLE, MPI_SUM, MPI_COMM_WORLD);
MPI_Allreduce(MPI_IN_PLACE, &sum, 1, MPI_DTYPE, MPI_SUM, MPI_COMM_WORLD);
#endif
t2 = std::chrono::high_resolution_clock::now();
timings[4].push_back(std::chrono::duration_cast<std::chrono::duration<double> >(t2 - t1).count());
Expand All @@ -217,9 +214,6 @@ std::vector<std::vector<double>> run_triad(Stream<T> *stream)
t1 = std::chrono::high_resolution_clock::now();
for (unsigned int k = 0; k < num_times; k++)
{
#if USE_MPI
MPI_Barrier(MPI_COMM_WORLD);
#endif
stream->triad();
}
t2 = std::chrono::high_resolution_clock::now();
Expand All @@ -241,14 +235,8 @@ std::vector<std::vector<double>> run_nstream(Stream<T> *stream)

// Run nstream in loop
for (int k = 0; k < num_times; k++) {
#if USE_MPI
MPI_Barrier(MPI_COMM_WORLD);
#endif
t1 = std::chrono::high_resolution_clock::now();
stream->nstream();
#if USE_MPI
MPI_Barrier(MPI_COMM_WORLD);
#endif
t2 = std::chrono::high_resolution_clock::now();
timings[0].push_back(std::chrono::duration_cast<std::chrono::duration<double> >(t2 - t1).count());
}
Expand Down Expand Up @@ -416,10 +404,6 @@ void run()


stream->read_arrays(a, b, c);
#if USE_MPI
// Only check solutions on rank 0 in case verificaiton fails
if (rank == 0)
#endif
check_solution<T>(num_times, a, b, c, sum);

// Display timing results
Expand Down Expand Up @@ -485,17 +469,11 @@ void run()
double max = *minmax.second;

#if USE_MPI
// Collate timings
if (rank == 0)
{
MPI_Reduce(MPI_IN_PLACE, &min, 1, MPI_DOUBLE, MPI_MIN, 0, MPI_COMM_WORLD);
MPI_Reduce(MPI_IN_PLACE, &max, 1, MPI_DOUBLE, MPI_MAX, 0, MPI_COMM_WORLD);
}
else
{
MPI_Reduce(&min, NULL, 1, MPI_DOUBLE, MPI_MIN, 0, MPI_COMM_WORLD);
MPI_Reduce(&max, NULL, 1, MPI_DOUBLE, MPI_MAX, 0, MPI_COMM_WORLD);
}
MPI_Datatype MPI_DTYPE = use_float ? MPI_FLOAT : MPI_DOUBLE;

// Collect global min/max timings
MPI_Allreduce(MPI_IN_PLACE, &min, 1, MPI_DTYPE, MPI_MIN, MPI_COMM_WORLD);
MPI_Allreduce(MPI_IN_PLACE, &max, 1, MPI_DTYPE, MPI_MAX, MPI_COMM_WORLD);
sizes[i] *= procs;
#endif

Expand Down

0 comments on commit ffe986d

Please sign in to comment.