Replace itertools.product with np.indices
This significantly improves performance. Also
factored out some MPI code into the parallel
module.
oashour committed May 5, 2024
1 parent add150f commit 88e9031
Showing 3 changed files with 95 additions and 79 deletions.
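
For context, a minimal sketch (not part of the commit; `nm` and `nv` are placeholder dimensions) contrasting the two ways of enumerating the (mass, velocity) index pairs. The iteration orders differ, but the set of pairs is the same:

```python
import itertools

import numpy as np

nm, nv = 3, 2  # placeholder mass/velocity grid sizes

# Old approach: Python-level Cartesian product, last index varying fastest.
pairs_product = np.array(list(itertools.product(range(nm), range(nv))))

# New approach: vectorized index grid; after the transpose and reshape,
# the first index varies fastest.
pairs_indices = np.indices((nm, nv)).T.reshape(-1, 2)

# Same nm * nv pairs, different iteration order.
assert sorted(map(tuple, pairs_product)) == sorted(map(tuple, pairs_indices))
```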
97 changes: 52 additions & 45 deletions darkmagic/calculator.py
@@ -1,6 +1,6 @@
import math
import warnings
import itertools
# import itertools

import numpy as np
from numpy.typing import ArrayLike
@@ -12,7 +12,7 @@
from darkmagic.model import Model
from darkmagic.numerics import Numerics
from darkmagic.rate import RATE_CALC_CLASSES
from darkmagic.parallel import JOB_SENTINEL, ROOT_PROCESS, distribute_load
from darkmagic.parallel import JOB_SENTINEL, ROOT_PROCESS, distribute_load, setup_mpi


class Calculator:
@@ -117,10 +117,10 @@ def evaluate(self, mpi: bool = False):
self.diff_rate = np.zeros((nv, nm, max_bin_num))
self.binned_rate = np.zeros((nv, nm, self.material.n_modes))
self.total_rate = np.zeros((nv, nm))
if mpi:
self._evaluate_mpi()
else:
self._evaluate_serial()
eval_method = self._evaluate_mpi if mpi else self._evaluate_serial
self.comm = eval_method(
self.diff_rate, self.binned_rate, self.total_rate, self.calc_list
)

def to_file(self, filename: str | None = None, format="darkmagic"):
"""
@@ -236,76 +236,83 @@ def compute_reach(

return sigma

def _evaluate_serial(self):
# @njit
@staticmethod
def _evaluate_serial(
diff_rate: np.ndarray,
binned_rate: np.ndarray,
total_rate: np.ndarray,
calc_list: list,
):
"""
Evaluate the rates without MPI
"""

for im, iv in itertools.product(
range(nm := len(self.m_chi)), range(nv := len(self.v_e))
):
nm, nv = len(calc_list[0]), len(calc_list)
for task in np.indices((nm, nv)).T.reshape(-1, 2):
im, iv = task[0], task[1]
print(
f"Rate calculation: {im * nv + iv + 1}/{(nv*nm)}.",
)
(
self.diff_rate[iv, im],
self.binned_rate[iv, im],
self.total_rate[iv, im],
) = self.calc_list[iv][im].calculate_rate()
diff_rate[iv, im],
binned_rate[iv, im],
total_rate[iv, im],
) = calc_list[iv][im].calculate_rate()

def _evaluate_mpi(self):
@staticmethod
def _evaluate_mpi(
diff_rate: np.ndarray,
binned_rate: np.ndarray,
total_rate: np.ndarray,
calc_list: list,
):
"""
Evaluate the rates in parallel using MPI
"""
from mpi4py.MPI import COMM_WORLD

self.comm = COMM_WORLD
n_ranks = self.comm.Get_size() # Number of ranks
rank = self.comm.Get_rank() # Processor rank

if rank == ROOT_PROCESS:
print("Done setting up MPI")
comm, n_ranks, rank = setup_mpi()

# comm, n_ranks, rank = None, 1, ROOT_PROCESS
all_tasks = None
if rank == ROOT_PROCESS:
all_tasks = distribute_load(n_ranks, self.m_chi, self.v_e)
all_tasks = distribute_load(n_ranks, calc_list)

task_list = self.comm.scatter(all_tasks, root=ROOT_PROCESS)
# task_list = all_tasks[0]
task_list = comm.scatter(all_tasks, root=ROOT_PROCESS)

if rank == ROOT_PROCESS:
print("Done configuring calculation")
print("Done distributing tasks.")

diff_rates, binned_rates, total_rates = [], [], []
# Loop over the tasks and calculate the rates
for job in task_list:
for j, job in enumerate(task_list):
if job[0] == JOB_SENTINEL and job[1] == JOB_SENTINEL:
continue
im, iv = job[0], job[1]
d, b, t = self.calc_list[iv][im].calculate_rate()
print(f"Rank {rank} working on task {j+1}/{len(task_list)}")
d, b, t = calc_list[iv][im].calculate_rate()
diff_rates.append([job, d.real])
binned_rates.append([job, b.real])
total_rates.append([job, t.real])

print(f"Rank {rank} done calculating rates.")

# comm.Barrier()
if rank == ROOT_PROCESS:
print("All ranks done calculating rates. Gathering results.")
# TODO: this is hideous....
# Might be desirable to save these for parallel IO reimplementation
diff_rate = self.comm.gather(diff_rates, root=ROOT_PROCESS)
binned_rate = self.comm.gather(binned_rates, root=ROOT_PROCESS)
total_rate = self.comm.gather(total_rates, root=ROOT_PROCESS)
diff_rate_list = comm.gather(diff_rates, root=ROOT_PROCESS)
binned_rate_list = comm.gather(binned_rates, root=ROOT_PROCESS)
total_rate_list = comm.gather(total_rates, root=ROOT_PROCESS)

if rank == ROOT_PROCESS:
self._reshape_rates(diff_rate, binned_rate, total_rate)

def _reshape_rates(self, diff_rate, binned_rate, total_rate):
"""
Reshape the rates into the correct format.
"""

for rate_array, rate_list in zip(
[self.diff_rate, self.binned_rate, self.total_rate],
[diff_rate, binned_rate, total_rate],
):
for rates in rate_list:
for job, r in rates:
rate_array[job[1], job[0]] = r
for rate_array, rate_list in zip(
[diff_rate, binned_rate, total_rate],
[diff_rate_list, binned_rate_list, total_rate_list],
):
for rates in rate_list:
for job, r in rates:
rate_array[job[1], job[0]] = r

return comm
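
As a reading aid, a hedged sketch (not from the commit) of the scatter/gather round trip that `_evaluate_mpi` performs; the per-task rate calculation is stubbed out, and the job layout mirrors `distribute_load`'s padded round-robin sublists. Run with, e.g., `mpirun -n 4 python sketch.py`:

```python
from mpi4py.MPI import COMM_WORLD

ROOT_PROCESS = 0
JOB_SENTINEL = -99999

comm = COMM_WORLD
rank = comm.Get_rank()
n_ranks = comm.Get_size()

all_tasks = None
if rank == ROOT_PROCESS:
    # Six (im, iv) jobs, dealt round-robin into one sublist per rank;
    # ranks left with nothing get a sentinel "do nothing" job.
    jobs = [(im, iv) for iv in range(2) for im in range(3)]
    all_tasks = [
        jobs[r::n_ranks] or [(JOB_SENTINEL, JOB_SENTINEL)] for r in range(n_ranks)
    ]

# Each rank receives exactly its own sublist.
my_tasks = comm.scatter(all_tasks, root=ROOT_PROCESS)

# Stand-in for calculate_rate(): pair each job with a dummy "rate".
results = [(job, 0.0) for job in my_tasks if job[0] != JOB_SENTINEL]

# Root collects one result list per rank, as in _evaluate_mpi's gather calls.
gathered = comm.gather(results, root=ROOT_PROCESS)
if rank == ROOT_PROCESS:
    print(gathered)
```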
21 changes: 6 additions & 15 deletions darkmagic/io.py
@@ -114,21 +114,9 @@ def write_h5(
if rank == ROOT_PROCESS:
print(f"Writing to file {out_filename}{' in parallel' if parallel else ''}...")
writer = write_phonodark if format == "phonodark" else write_darkmagic
if not parallel and rank == ROOT_PROCESS:
writer(
out_filename,
material,
model,
numerics,
masses,
times,
v_e,
all_total_rate_list,
all_diff_rate_list,
all_binned_rate_list,
comm=None,
)
elif parallel:
if not parallel:
comm = None
if (not parallel and rank == ROOT_PROCESS) or parallel:
writer(
out_filename,
material,
@@ -142,6 +130,9 @@
all_binned_rate_list,
comm=comm,
)
if rank == ROOT_PROCESS and not parallel:
# TODO: should have a proper message for parallel IO when reimplemented
print("Done writing.")


def write_darkmagic(
56 changes: 37 additions & 19 deletions darkmagic/parallel.py
@@ -1,48 +1,66 @@
import itertools
# import itertools

import numpy as np

ROOT_PROCESS = 0 # Root MPI process global
JOB_SENTINEL = -99999


def distribute_load(n_proc, masses, times):
def setup_mpi():
"""
This function sets up the MPI environment and returns the communicator.
"""
from mpi4py.MPI import COMM_WORLD

n_ranks = COMM_WORLD.Get_size() # Number of ranks
rank = COMM_WORLD.Get_rank() # Processor rank

if rank == ROOT_PROCESS:
print(f"Done setting up MPI. Using {n_ranks} ranks.")
return COMM_WORLD, n_ranks, rank


def distribute_load(n_ranks, calc_list):
"""
This function distributes the load of jobs across the available processors.
It attempts to balance the load as much as possible.
"""
# TODO: change to logging
print("Distributing load among processors.")
# Our jobs list is a list of tuples, where each tuple is a pair of (mass, time) indices
total_job_list = np.array(
list(itertools.product(range(len(masses)), range(len(times))))
)
nm, nv = len(calc_list[0]), len(calc_list)
# total_job_list = np.array(list(itertools.product(range(nm), range(nv))))
total_job_list = np.indices((nm, nv)).T.reshape(-1, 2)
n_jobs = len(total_job_list)

base_jobs_per_proc = n_jobs // n_proc # each processor has at least this many jobs
extra_jobs = 1 if n_jobs % n_proc else 0 # might need 1 more job on some processors
base_jobs_per_rank = n_jobs // n_ranks # each processor has at least this many jobs
extra_jobs = (
1 if n_jobs % n_ranks else 0
) # might need 1 more job on some processors
# Not a fan of using a sentinel value to indicate a "do nothing" job
job_list = JOB_SENTINEL * np.ones(
(n_proc, base_jobs_per_proc + extra_jobs, 2), dtype=int
(n_ranks, base_jobs_per_rank + extra_jobs, 2), dtype=int
)

for i in range(n_jobs):
proc_index = i % n_proc
job_index = i // n_proc
proc_index = i % n_ranks
job_index = i // n_ranks
job_list[proc_index, job_index] = total_job_list[i]

if n_jobs > n_proc:
print("Number of jobs exceeds the number of processors.")
print(f"Baseline number of jobs per processor: {str(base_jobs_per_proc)}")
print(
"Remaining processors with one extra job: "
+ str(n_jobs - n_proc * base_jobs_per_proc)
)
elif n_jobs < n_proc:
# This is the number of ranks that will have one extra job
n_unbalanced = n_jobs - n_ranks * base_jobs_per_rank
if n_jobs > n_ranks and n_unbalanced != 0:
print("Number of jobs exceeds the number of ranks.")
print(f"Number of jobs per rank: {str(base_jobs_per_rank)}")
print(f"{n_unbalanced} ranks will have on extra job each.")
elif n_jobs < n_ranks:
print("Number of jobs is fewer than the number of processors.")
print("Consider reducing the number of processors for more efficiency.")
print(f"Total number of jobs: {n_jobs}")
else:
print("Number of jobs matches the number of processors. Maximally parallized.")
print(
"Number of jobs is perfectly divisible by the number of ranks. "
"Optimal load balancing."
)

return job_list
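
For concreteness, a worked sketch of the padded round-robin layout `distribute_load` produces, assuming 5 jobs over 2 ranks (values chosen for illustration only):

```python
import numpy as np

JOB_SENTINEL = -99999

nm, nv = 5, 1  # 5 (mass, velocity) jobs in total
n_ranks = 2
jobs = np.indices((nm, nv)).T.reshape(-1, 2)  # [[0,0], [1,0], [2,0], [3,0], [4,0]]

base = len(jobs) // n_ranks              # every rank gets at least 2 jobs
extra = 1 if len(jobs) % n_ranks else 0  # one rank needs a third slot
job_list = JOB_SENTINEL * np.ones((n_ranks, base + extra, 2), dtype=int)
for i, job in enumerate(jobs):
    job_list[i % n_ranks, i // n_ranks] = job

print(job_list)
# Rank 0 gets jobs 0, 2, 4; rank 1 gets jobs 1, 3 plus one sentinel slot.
```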
