Skip to content

Commit

Permalink
fix: Fix mistakes in the distributed Ray Pool usage
Browse files Browse the repository at this point in the history
  • Loading branch information
adigitoleo committed Apr 22, 2024
1 parent 9df8a6e commit e5ae658
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 30 deletions.
12 changes: 8 additions & 4 deletions src/pydrex/diagnostics.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,11 +285,15 @@ def misorientation_indices(
texture snapshots and M is the number of grains.
Uses either Ray or the Python multiprocessing library to calculate texture indices
for multiple snapshots simultaneously. If `ncpus` is `None` the number of CPU cores
to use is chosen automatically based on the maximum number available to the Python
for multiple snapshots simultaneously. The arguments `ncpus` and `pool` are only
relevant the latter option: if `ncpus` is `None` the number of CPU cores to use is
chosen automatically based on the maximum number available to the Python
interpreter, otherwise the specified number of cores is requested. Alternatively, an
existing instance of `multiprocessing.Pool` or `ray.util.multiprocessing.Pool` can
be provided.
existing instance of `multiprocessing.Pool` can be provided.
If Ray is installed, it will be automatically preferred. In this case, the number of
processors (actually Ray “workers”) should be set upon initialisation of the Ray
cluster (which can be distributed over the network).
See `misorientation_index` for documentation of the remaining arguments.
Expand Down
4 changes: 2 additions & 2 deletions src/pydrex/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@

@ray.remote
def misorientation_indices(*args, **kwargs):
_diagnostics.misorientation_indices(*args, **kwargs)
return _diagnostics.misorientation_indices(*args, **kwargs)

@ray.remote
def misorientation_index(*args, **kwargs):
_diagnostics.misorientation_index(*args, **kwargs)
return _diagnostics.misorientation_index(*args, **kwargs)
32 changes: 8 additions & 24 deletions tests/test_simple_shear_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,8 @@ def test_direction_change(
)
with Pool(processes=ncpus) as pool:
for s, out in enumerate(pool.imap_unordered(_run, _seeds)):
# if HAS_RAY:
# assert len(pool._actor_pool) == ncpus
olivine, enstatite = out
_log.info("%s; # %d; postprocessing olivine...", _id, _seeds[s])
olA_resampled, _ = _stats.resample_orientations(
Expand All @@ -194,18 +196,9 @@ def test_direction_change(
olA_downsampled, _ = _stats.resample_orientations(
olivine.orientations, olivine.fractions, seed=_seeds[s], n_samples=1000
)
if HAS_RAY:
olA_strength[s, :] = ray.get(
_dstr.misorientation_indices.remote(
ray.put(olA_downsampled),
_geo.LatticeSystem.orthorhombic,
pool=pool,
)
)
else:
olA_strength[s, :] = _diagnostics.misorientation_indices(
olA_downsampled, _geo.LatticeSystem.orthorhombic, pool=pool
)
olA_strength[s, :] = _diagnostics.misorientation_indices(
olA_downsampled, _geo.LatticeSystem.orthorhombic, pool=pool
)

del olivine, olA_resampled, olA_mean_vectors

Expand Down Expand Up @@ -234,18 +227,9 @@ def test_direction_change(
ens_downsampled, _ = _stats.resample_orientations(
enstatite.orientations, enstatite.fractions, seed=_seeds[s], n_samples=1000
)
if HAS_RAY:
olA_strength[s, :] = ray.get(
_dstr.misorientation_indices.remote(
ray.put(ens_downsampled),
_geo.LatticeSystem.orthorhombic,
pool=pool,
)
)
else:
ens_strength[s, :] = _diagnostics.misorientation_indices(
ens_downsampled, _geo.LatticeSystem.orthorhombic, pool=pool
)
ens_strength[s, :] = _diagnostics.misorientation_indices(
ens_downsampled, _geo.LatticeSystem.orthorhombic, pool=pool
)
del enstatite, ens_resampled, ens_mean_vectors

_log.info("elapsed CPU time: %s", np.abs(process_time() - clock_start))
Expand Down

0 comments on commit e5ae658

Please sign in to comment.