-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Replace itertools.product with np.indices
This significantly improves performance. Also factored out some MPI code into the parallel module.
- Loading branch information
Showing
3 changed files
with
95 additions
and
79 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,48 +1,66 @@ | ||
import itertools | ||
# import itertools | ||
|
||
import numpy as np | ||
|
||
ROOT_PROCESS = 0 # Root MPI process global | ||
JOB_SENTINEL = -99999 | ||
|
||
|
||
def distribute_load(n_proc, masses, times): | ||
def setup_mpi(): | ||
""" | ||
This function sets up the MPI environment and returns the communicator. | ||
""" | ||
from mpi4py.MPI import COMM_WORLD | ||
|
||
n_ranks = COMM_WORLD.Get_size() # Number of ranks | ||
rank = COMM_WORLD.Get_rank() # Processor rank | ||
|
||
if rank == ROOT_PROCESS: | ||
print(f"Done setting up MPI. Using {n_ranks} ranks.") | ||
return COMM_WORLD, n_ranks, rank | ||
|
||
|
||
def distribute_load(n_ranks, calc_list): | ||
""" | ||
This function distributes the load of jobs across the available processors. | ||
It attempts to balance the load as much as possible. | ||
""" | ||
# TODO: change to logging | ||
print("Distributing load among processors.") | ||
# Our jobs list is a list of tuples, where each tuple is a pair of (mass, time) indices | ||
total_job_list = np.array( | ||
list(itertools.product(range(len(masses)), range(len(times)))) | ||
) | ||
nm, nv = len(calc_list[0]), len(calc_list) | ||
# total_job_list = np.array(list(itertools.product(range(nm), range(nv)))) | ||
total_job_list = np.indices((nm, nv)).T.reshape(-1, 2) | ||
n_jobs = len(total_job_list) | ||
|
||
base_jobs_per_proc = n_jobs // n_proc # each processor has at least this many jobs | ||
extra_jobs = 1 if n_jobs % n_proc else 0 # might need 1 more job on some processors | ||
base_jobs_per_rank = n_jobs // n_ranks # each processor has at least this many jobs | ||
extra_jobs = ( | ||
1 if n_jobs % n_ranks else 0 | ||
) # might need 1 more job on some processors | ||
# Note a fan of using a sentinel value to indicate a "do nothing" job | ||
job_list = JOB_SENTINEL * np.ones( | ||
(n_proc, base_jobs_per_proc + extra_jobs, 2), dtype=int | ||
(n_ranks, base_jobs_per_rank + extra_jobs, 2), dtype=int | ||
) | ||
|
||
for i in range(n_jobs): | ||
proc_index = i % n_proc | ||
job_index = i // n_proc | ||
proc_index = i % n_ranks | ||
job_index = i // n_ranks | ||
job_list[proc_index, job_index] = total_job_list[i] | ||
|
||
if n_jobs > n_proc: | ||
print("Number of jobs exceeds the number of processors.") | ||
print(f"Baseline number of jobs per processor: {str(base_jobs_per_proc)}") | ||
print( | ||
"Remaining processors with one extra job: " | ||
+ str(n_jobs - n_proc * base_jobs_per_proc) | ||
) | ||
elif n_jobs < n_proc: | ||
# This is the number of ranks that will have on extra job | ||
n_unbalanced = n_jobs - n_ranks * base_jobs_per_rank | ||
if n_jobs > n_ranks and n_unbalanced != 0: | ||
print("Number of jobs exceeds the number of ranks.") | ||
print(f"Number of jobs per rank: {str(base_jobs_per_rank)}") | ||
print(f"{n_unbalanced} ranks will have on extra job each.") | ||
elif n_jobs < n_ranks: | ||
print("Number of jobs is fewer than the number of processors.") | ||
print("Consider reducing the number of processors for more efficiency.") | ||
print(f"Total number of jobs: {n_jobs}") | ||
else: | ||
print("Number of jobs matches the number of processors. Maximally parallized.") | ||
print( | ||
"Number of jobs is perfectly divisible by the number of ranks. " | ||
"Optimal load balancing." | ||
) | ||
|
||
return job_list |