From e5ae658a7b015ce57d456781ab789d7b8c06795b Mon Sep 17 00:00:00 2001 From: adigitoleo Date: Fri, 19 Apr 2024 00:26:11 +1000 Subject: [PATCH] fix: Fix mistakes in the distributed Ray Pool usage --- src/pydrex/diagnostics.py | 12 ++++++++---- src/pydrex/distributed.py | 4 ++-- tests/test_simple_shear_3d.py | 32 ++++++++------------------------ 3 files changed, 18 insertions(+), 30 deletions(-) diff --git a/src/pydrex/diagnostics.py b/src/pydrex/diagnostics.py index 0d5d81ba..eeae8b85 100644 --- a/src/pydrex/diagnostics.py +++ b/src/pydrex/diagnostics.py @@ -285,11 +285,15 @@ def misorientation_indices( texture snapshots and M is the number of grains. Uses either Ray or the Python multiprocessing library to calculate texture indices - for multiple snapshots simultaneously. If `ncpus` is `None` the number of CPU cores - to use is chosen automatically based on the maximum number available to the Python + for multiple snapshots simultaneously. The arguments `ncpus` and `pool` are only + relevant the latter option: if `ncpus` is `None` the number of CPU cores to use is + chosen automatically based on the maximum number available to the Python interpreter, otherwise the specified number of cores is requested. Alternatively, an - existing instance of `multiprocessing.Pool` or `ray.util.multiprocessing.Pool` can - be provided. + existing instance of `multiprocessing.Pool` can be provided. + + If Ray is installed, it will be automatically preferred. In this case, the number of + processors (actually Ray “workers”) should be set upon initialisation of the Ray + cluster (which can be distributed over the network). See `misorientation_index` for documentation of the remaining arguments. diff --git a/src/pydrex/distributed.py b/src/pydrex/distributed.py index 5f9c5068..016e9867 100644 --- a/src/pydrex/distributed.py +++ b/src/pydrex/distributed.py @@ -5,8 +5,8 @@ @ray.remote def misorientation_indices(*args, **kwargs): - _diagnostics.misorientation_indices(*args, **kwargs) + return _diagnostics.misorientation_indices(*args, **kwargs) @ray.remote def misorientation_index(*args, **kwargs): - _diagnostics.misorientation_index(*args, **kwargs) + return _diagnostics.misorientation_index(*args, **kwargs) diff --git a/tests/test_simple_shear_3d.py b/tests/test_simple_shear_3d.py index f9d19884..b8762692 100644 --- a/tests/test_simple_shear_3d.py +++ b/tests/test_simple_shear_3d.py @@ -168,6 +168,8 @@ def test_direction_change( ) with Pool(processes=ncpus) as pool: for s, out in enumerate(pool.imap_unordered(_run, _seeds)): + # if HAS_RAY: + # assert len(pool._actor_pool) == ncpus olivine, enstatite = out _log.info("%s; # %d; postprocessing olivine...", _id, _seeds[s]) olA_resampled, _ = _stats.resample_orientations( @@ -194,18 +196,9 @@ def test_direction_change( olA_downsampled, _ = _stats.resample_orientations( olivine.orientations, olivine.fractions, seed=_seeds[s], n_samples=1000 ) - if HAS_RAY: - olA_strength[s, :] = ray.get( - _dstr.misorientation_indices.remote( - ray.put(olA_downsampled), - _geo.LatticeSystem.orthorhombic, - pool=pool, - ) - ) - else: - olA_strength[s, :] = _diagnostics.misorientation_indices( - olA_downsampled, _geo.LatticeSystem.orthorhombic, pool=pool - ) + olA_strength[s, :] = _diagnostics.misorientation_indices( + olA_downsampled, _geo.LatticeSystem.orthorhombic, pool=pool + ) del olivine, olA_resampled, olA_mean_vectors @@ -234,18 +227,9 @@ def test_direction_change( ens_downsampled, _ = _stats.resample_orientations( enstatite.orientations, enstatite.fractions, seed=_seeds[s], n_samples=1000 ) - if HAS_RAY: - olA_strength[s, :] = ray.get( - _dstr.misorientation_indices.remote( - ray.put(ens_downsampled), - _geo.LatticeSystem.orthorhombic, - pool=pool, - ) - ) - else: - ens_strength[s, :] = _diagnostics.misorientation_indices( - ens_downsampled, _geo.LatticeSystem.orthorhombic, pool=pool - ) + ens_strength[s, :] = _diagnostics.misorientation_indices( + ens_downsampled, _geo.LatticeSystem.orthorhombic, pool=pool + ) del enstatite, ens_resampled, ens_mean_vectors _log.info("elapsed CPU time: %s", np.abs(process_time() - clock_start))