Skip to content

Commit

Permalink
Make MPI optional
Browse files Browse the repository at this point in the history
  • Loading branch information
oashour committed May 5, 2024
1 parent 86b2011 commit add150f
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 52 deletions.
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,6 @@ VBT.vasp
VBTS_mdm.h5
hcp_He_lsm.h5
venv/**
dm.py
run.py
omni_model.py
./*.h5
129 changes: 81 additions & 48 deletions darkmagic/calculator.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import math
import warnings
import itertools

import numpy as np
from mpi4py.MPI import COMM_WORLD
from numpy.typing import ArrayLike

import darkmagic.constants as const
Expand All @@ -11,8 +11,8 @@
from darkmagic.material import Material
from darkmagic.model import Model
from darkmagic.numerics import Numerics
from darkmagic.parallel import JOB_SENTINEL, ROOT_PROCESS, distribute_load
from darkmagic.rate import RATE_CALC_CLASSES
from darkmagic.parallel import JOB_SENTINEL, ROOT_PROCESS, distribute_load


class Calculator:
Expand Down Expand Up @@ -60,6 +60,7 @@ def __init__(
self.material = material
self.m_chi = m_chi
self.model = model
self.comm = None # MPI communicator

# TODO: this is ugly
calc_class = RATE_CALC_CLASSES.get((calc_type, type(material)))
Expand Down Expand Up @@ -104,47 +105,22 @@ def from_file(cls, filename: str, format="darkmagic"):

return calc

def evaluate(self):
def evaluate(self, mpi: bool = False):
"""
Computes the differential rate for all masses and earth velocities
"""

n_ranks = COMM_WORLD.Get_size() # Number of ranks
rank = COMM_WORLD.Get_rank() # Processor rank

if rank == ROOT_PROCESS:
print("Done setting up MPI")

all_tasks = None
if rank == ROOT_PROCESS:
all_tasks = distribute_load(n_ranks, self.m_chi, self.v_e)

task_list = COMM_WORLD.scatter(all_tasks, root=ROOT_PROCESS)

if rank == ROOT_PROCESS:
print("Done configuring calculation")
diff_rates, binned_rates, total_rates = [], [], []
# Loop over the tasks and calculate the rates
for job in task_list:
if job[0] == JOB_SENTINEL and job[1] == JOB_SENTINEL:
continue
im, iv = job[0], job[1]
d, b, t = self.calc_list[iv][im].calculate_rate()
diff_rates.append([job, d.real])
binned_rates.append([job, b.real])
total_rates.append([job, t.real])

print(f"Rank {rank} done calculating rates.")

# TODO: this is hideous....
# Might be desirable to save these for parallel IO reimplementation
diff_rate = COMM_WORLD.gather(diff_rates, root=ROOT_PROCESS)
binned_rate = COMM_WORLD.gather(binned_rates, root=ROOT_PROCESS)
total_rate = COMM_WORLD.gather(total_rates, root=ROOT_PROCESS)

if rank == ROOT_PROCESS:
self._reshape_rates(diff_rate, binned_rate, total_rate)
Args:
mpi (bool, optional): Whether to run the calculation in parallel using MPI. Defaults to False.
"""
nv, nm = len(self.time), len(self.m_chi)
max_bin_num = math.ceil(self.material.max_dE / self.numerics.bin_width)
self.diff_rate = np.zeros((nv, nm, max_bin_num))
self.binned_rate = np.zeros((nv, nm, self.material.n_modes))
self.total_rate = np.zeros((nv, nm))
if mpi:
self._evaluate_mpi()
else:
self._evaluate_serial()

def to_file(self, filename: str | None = None, format="darkmagic"):
"""
Expand All @@ -157,11 +133,12 @@ def to_file(self, filename: str | None = None, format="darkmagic"):
if filename is None:
filename = f"{self.material.name}_{self.model.shortname}.h5"
# Make sure all the rates are note None
rank = self.comm.Get_rank() if self.comm is not None else 0
if (
self.diff_rate is None
or self.binned_rate is None
or self.total_rate is None
) and COMM_WORLD.Get_rank() == ROOT_PROCESS:
) and rank == ROOT_PROCESS:
raise ValueError(
"Rates are not computed yet. Please run the calculation first using the evaluate method."
)
Expand All @@ -178,8 +155,8 @@ def to_file(self, filename: str | None = None, format="darkmagic"):
self.total_rate,
self.diff_rate,
self.binned_rate,
COMM_WORLD.Get_rank(),
None,
rank,
None, # should be self.comm when parallel is reimplemented
parallel=False,
format=format,
)
Expand Down Expand Up @@ -259,15 +236,71 @@ def compute_reach(

return sigma

def _evaluate_serial(self):
"""
Evaluate the rates without MPI
"""

for im, iv in itertools.product(
range(nm := len(self.m_chi)), range(nv := len(self.v_e))
):
print(
f"Rate calculation: {im * nv + iv + 1}/{(nv*nm)}.",
)
(
self.diff_rate[iv, im],
self.binned_rate[iv, im],
self.total_rate[iv, im],
) = self.calc_list[iv][im].calculate_rate()

def _evaluate_mpi(self):
"""
Evaluate the rates in parallel using MPI
"""
from mpi4py.MPI import COMM_WORLD

self.comm = COMM_WORLD
n_ranks = self.comm.Get_size() # Number of ranks
rank = self.comm.Get_rank() # Processor rank

if rank == ROOT_PROCESS:
print("Done setting up MPI")

all_tasks = None
if rank == ROOT_PROCESS:
all_tasks = distribute_load(n_ranks, self.m_chi, self.v_e)

task_list = self.comm.scatter(all_tasks, root=ROOT_PROCESS)

if rank == ROOT_PROCESS:
print("Done configuring calculation")

diff_rates, binned_rates, total_rates = [], [], []
# Loop over the tasks and calculate the rates
for job in task_list:
if job[0] == JOB_SENTINEL and job[1] == JOB_SENTINEL:
continue
im, iv = job[0], job[1]
d, b, t = self.calc_list[iv][im].calculate_rate()
diff_rates.append([job, d.real])
binned_rates.append([job, b.real])
total_rates.append([job, t.real])

print(f"Rank {rank} done calculating rates.")

# TODO: this is hideous....
# Might be desirable to save these for parallel IO reimplementation
diff_rate = self.comm.gather(diff_rates, root=ROOT_PROCESS)
binned_rate = self.comm.gather(binned_rates, root=ROOT_PROCESS)
total_rate = self.comm.gather(total_rates, root=ROOT_PROCESS)

if rank == ROOT_PROCESS:
self._reshape_rates(diff_rate, binned_rate, total_rate)

def _reshape_rates(self, diff_rate, binned_rate, total_rate):
"""
Reshape the rates into the correct format.
"""
nt, nm = len(self.time), len(self.m_chi)
max_bin_num = math.ceil(self.material.max_dE / self.numerics.bin_width)
self.diff_rate = np.zeros((nt, nm, max_bin_num))
self.binned_rate = np.zeros((nt, nm, self.material.n_modes))
self.total_rate = np.zeros((nt, nm))

for rate_array, rate_list in zip(
[self.diff_rate, self.binned_rate, self.total_rate],
Expand Down
10 changes: 7 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@ dependencies = [
"pymatgen",
"rad-tools",
"h5py",
"mpi4py",
"setuptools; python_version=='3.12'",
]
classifiers = [
"Programming Language :: Python :: 3",
Expand All @@ -34,6 +32,7 @@ Homepage = "https://github.com/Griffin-Group/DarkMAGIC"
Issues = "https://github.com/Griffin-Group/DarkMAGIC/issues"

[project.optional-dependencies]
mpi = ["mpi4py"] # Deps for MPI support
docs = [ # Optional dependencies for building documentation
"mkdocs-material",
"mike",
Expand All @@ -43,7 +42,12 @@ docs = [ # Optional dependencies for building documentation
"mkdocs-gen-files",
"mkdocs-section-index",
]
dev = ["pytest", "ruff", "pytest_parametrize_cases"] # Deps for development
dev = [
"ruff",
"pytest",
"pytest_parametrize_cases",
"pre-commit",
] # Deps for development
build = ["build", "twine", "wheel"] # Deps for building and publishing

[tool.setuptools.packages]
Expand Down

0 comments on commit add150f

Please sign in to comment.