From 88e903147b90b34004c920a592edb03860f86729 Mon Sep 17 00:00:00 2001
From: Omar Ashour
Date: Sun, 5 May 2024 15:10:50 -0700
Subject: [PATCH] Replace itertools.product with np.indices

This significantly improves performance. Also factored out some MPI
code into the parallel module.
---
 darkmagic/calculator.py | 99 ++++++++++++++++++++++-------------
 darkmagic/io.py         | 21 +++------
 darkmagic/parallel.py   | 56 ++++++++++++++++--------
 3 files changed, 96 insertions(+), 80 deletions(-)

diff --git a/darkmagic/calculator.py b/darkmagic/calculator.py
index f53c3d9..5bfbd75 100644
--- a/darkmagic/calculator.py
+++ b/darkmagic/calculator.py
@@ -1,6 +1,6 @@
 import math
 import warnings
-import itertools
+# import itertools
 
 import numpy as np
 from numpy.typing import ArrayLike
@@ -12,7 +12,7 @@
 from darkmagic.model import Model
 from darkmagic.numerics import Numerics
 from darkmagic.rate import RATE_CALC_CLASSES
-from darkmagic.parallel import JOB_SENTINEL, ROOT_PROCESS, distribute_load
+from darkmagic.parallel import JOB_SENTINEL, ROOT_PROCESS, distribute_load, setup_mpi
 
 
 class Calculator:
@@ -117,10 +117,10 @@ def evaluate(self, mpi: bool = False):
         self.diff_rate = np.zeros((nv, nm, max_bin_num))
         self.binned_rate = np.zeros((nv, nm, self.material.n_modes))
         self.total_rate = np.zeros((nv, nm))
-        if mpi:
-            self._evaluate_mpi()
-        else:
-            self._evaluate_serial()
+        eval_method = self._evaluate_mpi if mpi else self._evaluate_serial
+        self.comm = eval_method(
+            self.diff_rate, self.binned_rate, self.total_rate, self.calc_list
+        )
 
     def to_file(self, filename: str | None = None, format="darkmagic"):
         """
@@ -236,76 +236,83 @@ def compute_reach(
 
         return sigma
 
-    def _evaluate_serial(self):
+    # @njit
+    @staticmethod
+    def _evaluate_serial(
+        diff_rate: np.ndarray,
+        binned_rate: np.ndarray,
+        total_rate: np.ndarray,
+        calc_list: list,
+    ):
         """
         Evaluate the rates without MPI
         """
-        for im, iv in itertools.product(
-            range(nm := len(self.m_chi)), range(nv := len(self.v_e))
-        ):
+        nm, nv = len(calc_list[0]), len(calc_list)
+        for task in np.indices((nm, nv)).T.reshape(-1, 2):
+            im, iv = task[0], task[1]
             print(
-                f"Rate calculation: {im * nv + iv + 1}/{(nv*nm)}.",
+                f"Rate calculation: {iv * nm + im + 1}/{(nv*nm)}.",
             )
             (
-                self.diff_rate[iv, im],
-                self.binned_rate[iv, im],
-                self.total_rate[iv, im],
-            ) = self.calc_list[iv][im].calculate_rate()
+                diff_rate[iv, im],
+                binned_rate[iv, im],
+                total_rate[iv, im],
+            ) = calc_list[iv][im].calculate_rate()
 
-    def _evaluate_mpi(self):
+    @staticmethod
+    def _evaluate_mpi(
+        diff_rate: np.ndarray,
+        binned_rate: np.ndarray,
+        total_rate: np.ndarray,
+        calc_list: list,
+    ):
         """
         Evaluate the rates in parallel using MPI
         """
-        from mpi4py.MPI import COMM_WORLD
-
-        self.comm = COMM_WORLD
-        n_ranks = self.comm.Get_size()  # Number of ranks
-        rank = self.comm.Get_rank()  # Processor rank
-
-        if rank == ROOT_PROCESS:
-            print("Done setting up MPI")
+        comm, n_ranks, rank = setup_mpi()
+        # comm, n_ranks, rank = None, 1, ROOT_PROCESS
 
         all_tasks = None
         if rank == ROOT_PROCESS:
-            all_tasks = distribute_load(n_ranks, self.m_chi, self.v_e)
+            all_tasks = distribute_load(n_ranks, calc_list)
 
-        task_list = self.comm.scatter(all_tasks, root=ROOT_PROCESS)
+        # task_list = all_tasks[0]
+        task_list = comm.scatter(all_tasks, root=ROOT_PROCESS)
 
         if rank == ROOT_PROCESS:
-            print("Done configuring calculation")
+            print("Done distributing tasks.")
 
         diff_rates, binned_rates, total_rates = [], [], []
         # Loop over the tasks and calculate the rates
-        for job in task_list:
+        for j, job in enumerate(task_list):
             if job[0] == JOB_SENTINEL and job[1] == JOB_SENTINEL:
                 continue
             im, iv = job[0], job[1]
-            d, b, t = self.calc_list[iv][im].calculate_rate()
+            print(f"Rank {rank} working on task {j+1}/{len(task_list)}")
+            d, b, t = calc_list[iv][im].calculate_rate()
             diff_rates.append([job, d.real])
             binned_rates.append([job, b.real])
             total_rates.append([job, t.real])
 
         print(f"Rank {rank} done calculating rates.")
+        # comm.Barrier()
+        if rank == ROOT_PROCESS:
+            print("All ranks done calculating rates. Gathering results.")
 
         # TODO: this is hideous....
         # Might be desirable to save these for parallel IO reimplementation
-        diff_rate = self.comm.gather(diff_rates, root=ROOT_PROCESS)
-        binned_rate = self.comm.gather(binned_rates, root=ROOT_PROCESS)
-        total_rate = self.comm.gather(total_rates, root=ROOT_PROCESS)
+        diff_rate_list = comm.gather(diff_rates, root=ROOT_PROCESS)
+        binned_rate_list = comm.gather(binned_rates, root=ROOT_PROCESS)
+        total_rate_list = comm.gather(total_rates, root=ROOT_PROCESS)
 
         if rank == ROOT_PROCESS:
-            self._reshape_rates(diff_rate, binned_rate, total_rate)
-
-    def _reshape_rates(self, diff_rate, binned_rate, total_rate):
-        """
-        Reshape the rates into the correct format.
-        """
-
-        for rate_array, rate_list in zip(
-            [self.diff_rate, self.binned_rate, self.total_rate],
-            [diff_rate, binned_rate, total_rate],
-        ):
-            for rates in rate_list:
-                for job, r in rates:
-                    rate_array[job[1], job[0]] = r
+            for rate_array, rate_list in zip(
+                [diff_rate, binned_rate, total_rate],
+                [diff_rate_list, binned_rate_list, total_rate_list],
+            ):
+                for rates in rate_list:
+                    for job, r in rates:
+                        rate_array[job[1], job[0]] = r
+
+        return comm
 
diff --git a/darkmagic/io.py b/darkmagic/io.py
index 14fbd84..f7dd416 100644
--- a/darkmagic/io.py
+++ b/darkmagic/io.py
@@ -114,21 +114,9 @@ def write_h5(
     if rank == ROOT_PROCESS:
         print(f"Writing to file {out_filename}{' in parallel' if parallel else ''}...")
     writer = write_phonodark if format == "phonodark" else write_darkmagic
-    if not parallel and rank == ROOT_PROCESS:
-        writer(
-            out_filename,
-            material,
-            model,
-            numerics,
-            masses,
-            times,
-            v_e,
-            all_total_rate_list,
-            all_diff_rate_list,
-            all_binned_rate_list,
-            comm=None,
-        )
-    elif parallel:
+    if not parallel:
+        comm = None
+    if (not parallel and rank == ROOT_PROCESS) or parallel:
         writer(
             out_filename,
             material,
@@ -142,6 +130,9 @@
             all_binned_rate_list,
             comm=comm,
         )
+    if rank == ROOT_PROCESS and not parallel:
+        # TODO: should have a proper message for parallel IO when reimplemented
+        print("Done writing.")
 
 
 def write_darkmagic(
diff --git a/darkmagic/parallel.py b/darkmagic/parallel.py
index efb9ebc..7205e09 100644
--- a/darkmagic/parallel.py
+++ b/darkmagic/parallel.py
@@ -1,4 +1,4 @@
-import itertools
+# import itertools
 
 import numpy as np
 
@@ -6,7 +6,21 @@
 JOB_SENTINEL = -99999
 
 
-def distribute_load(n_proc, masses, times):
+def setup_mpi():
+    """
+    This function sets up MPI and returns the communicator, the number of ranks, and this rank.
+    """
+    from mpi4py.MPI import COMM_WORLD
+
+    n_ranks = COMM_WORLD.Get_size()  # Number of ranks
+    rank = COMM_WORLD.Get_rank()  # Processor rank
+
+    if rank == ROOT_PROCESS:
+        print(f"Done setting up MPI. Using {n_ranks} ranks.")
+    return COMM_WORLD, n_ranks, rank
+
+
+def distribute_load(n_ranks, calc_list):
     """
     This function distributes the load of jobs across the available processors.
     It attempts to balance the load as much as possible.
@@ -14,35 +28,39 @@ def distribute_load(n_proc, masses, times):
     # TODO: change to logging
     print("Distributing load among processors.")
     # Our jobs list is a list of tuples, where each tuple is a pair of (mass, time) indices
-    total_job_list = np.array(
-        list(itertools.product(range(len(masses)), range(len(times))))
-    )
+    nm, nv = len(calc_list[0]), len(calc_list)
+    # total_job_list = np.array(list(itertools.product(range(nm), range(nv))))
+    total_job_list = np.indices((nm, nv)).T.reshape(-1, 2)
 
     n_jobs = len(total_job_list)
-    base_jobs_per_proc = n_jobs // n_proc  # each processor has at least this many jobs
-    extra_jobs = 1 if n_jobs % n_proc else 0  # might need 1 more job on some processors
+    base_jobs_per_rank = n_jobs // n_ranks  # each processor has at least this many jobs
+    extra_jobs = (
+        1 if n_jobs % n_ranks else 0
+    )  # might need 1 more job on some processors
     # Note a fan of using a sentinel value to indicate a "do nothing" job
     job_list = JOB_SENTINEL * np.ones(
-        (n_proc, base_jobs_per_proc + extra_jobs, 2), dtype=int
+        (n_ranks, base_jobs_per_rank + extra_jobs, 2), dtype=int
     )
 
     for i in range(n_jobs):
-        proc_index = i % n_proc
-        job_index = i // n_proc
+        proc_index = i % n_ranks
+        job_index = i // n_ranks
         job_list[proc_index, job_index] = total_job_list[i]
 
-    if n_jobs > n_proc:
-        print("Number of jobs exceeds the number of processors.")
-        print(f"Baseline number of jobs per processor: {str(base_jobs_per_proc)}")
-        print(
-            "Remaining processors with one extra job: "
-            + str(n_jobs - n_proc * base_jobs_per_proc)
-        )
-    elif n_jobs < n_proc:
+    # This is the number of ranks that will have one extra job
+    n_unbalanced = n_jobs - n_ranks * base_jobs_per_rank
+    if n_jobs > n_ranks and n_unbalanced != 0:
+        print("Number of jobs exceeds the number of ranks.")
+        print(f"Number of jobs per rank: {str(base_jobs_per_rank)}")
+        print(f"{n_unbalanced} ranks will have one extra job each.")
+    elif n_jobs < n_ranks:
        print("Number of jobs is fewer than the number of processors.")
        print("Consider reducing the number of processors for more efficiency.")
        print(f"Total number of jobs: {n_jobs}")
     else:
-        print("Number of jobs matches the number of processors. Maximally parallized.")
+        print(
+            "Number of jobs is perfectly divisible by the number of ranks. "
+            "Optimal load balancing."
+        )
 
     return job_list
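
Note (editor's illustration, not part of the patch): a minimal sketch, assuming only NumPy and the standard library, of the indexing change this patch makes. It shows that np.indices((nm, nv)).T.reshape(-1, 2) enumerates the same (im, iv) job pairs that itertools.product(range(nm), range(nv)) produced, but as a single vectorized integer array instead of a Python-level iterator; only the traversal order differs (the mass index now varies fastest within each time index).

    # Sketch: old vs. new enumeration of (mass index, time index) jobs.
    import itertools

    import numpy as np

    nm, nv = 2, 3  # e.g. 2 dark matter masses, 3 times of day

    # Old: pure-Python iterator of index tuples; the time index iv varies fastest.
    old_pairs = list(itertools.product(range(nm), range(nv)))

    # New: one vectorized array of index pairs; the mass index im varies fastest.
    new_pairs = np.indices((nm, nv)).T.reshape(-1, 2)

    # Same set of jobs, different traversal order.
    assert sorted(map(tuple, new_pairs.tolist())) == sorted(old_pairs)
    print(old_pairs)           # [(0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2)]
    print(new_pairs.tolist())  # [[0, 0], [1, 0], [0, 1], [1, 1], [0, 2], [1, 2]]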