diff --git a/pyproject.toml b/pyproject.toml index c6264f7..c97cd4c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -54,8 +54,6 @@ fluidfft-bench = "fluidfft.bench:run" fluidfft-bench-analysis = "fluidfft.bench_analysis:run" [project.entry-points."fluidfft.plugins"] -"fft2d.fake_mod_fft2d_for_doc" = "fluidfft.fft2d.fake_mod_fft2d_for_doc" -"fft3d.fake_mod_fft3d_for_doc" = "fluidfft.fft3d.fake_mod_fft3d_for_doc" "fft2d.with_pyfftw" = "fluidfft.fft2d.with_pyfftw" "fft3d.with_pyfftw" = "fluidfft.fft3d.with_pyfftw" "fft2d.with_dask" = "fluidfft.fft2d.with_dask" diff --git a/src/fluidfft/__init__.py b/src/fluidfft/__init__.py index 8da0287..d1e966c 100644 --- a/src/fluidfft/__init__.py +++ b/src/fluidfft/__init__.py @@ -35,9 +35,9 @@ import logging if sys.version_info < (3, 10): - from importlib_metadata import entry_points + from importlib_metadata import entry_points, EntryPoint else: - from importlib.metadata import entry_points + from importlib.metadata import entry_points, EntryPoint from fluiddyn.util import mpi @@ -95,7 +95,7 @@ def byte_align(values, *args): _plugins = None -def get_plugins(reload=False): +def get_plugins(reload=False, ndim=None, sequential=None): """Discover the fluidfft plugins installed""" global _plugins if _plugins is None or reload: @@ -104,7 +104,26 @@ def get_plugins(reload=False): if not _plugins: raise RuntimeError("No Fluidfft plugins were found.") - return _plugins + if ndim is None and sequential is None: + return _plugins + + if ndim is None: + index = 6 + prefix = "" + elif ndim in (2, 3): + index = 0 + prefix = f"fft{ndim}d." + else: + raise ValueError(f"Unsupported value for {ndim = }") + + if sequential is not None and not sequential: + prefix += "mpi_" + elif sequential: + prefix += "with_" + + return tuple( + plugin for plugin in _plugins if plugin.name[index:].startswith(prefix) + ) def get_module_fullname_from_method(method): @@ -140,12 +159,10 @@ def _normalize_method_name(method): return method -def _check_failure(module_fullname): +def _check_failure(method): """Check if a tiny fft maker can be created""" - if not any( - module_fullname.endswith(postfix) for postfix in ("pfft", "p3dfft") - ): + if not any(method.endswith(postfix) for postfix in ("pfft", "p3dfft")): return False # for few methods, try before real import because importing can lead to @@ -162,11 +179,12 @@ def _check_failure(module_fullname): else: env = os.environ try: + # TODO: capture stdout and stderr and include last line in case of failure subprocess.check_call( [ sys.executable, "-c", - f"from fluidfft import create_fft_object as c; c({module_fullname}, 2, 2, 2, check=0)", + f"from fluidfft import create_fft_object as c; c('{method}', 2, 2, 2, check=False)", ], env=env, shell=False, @@ -207,11 +225,15 @@ def import_fft_class(method, raise_import_error=True, check=True): """ - method = _normalize_method_name(method) - module_fullname = get_module_fullname_from_method(method) + if isinstance(method, EntryPoint): + module_fullname = method.value + method = method.name + else: + method = _normalize_method_name(method) + module_fullname = get_module_fullname_from_method(method) if check: - failure = _check_failure(module_fullname) + failure = _check_failure(method) if failure: if not raise_import_error: mpi.printby0("ImportError during check:", module_fullname) @@ -231,6 +253,14 @@ def import_fft_class(method, raise_import_error=True, check=True): return mod.FFTclass +def _get_classes(ndim, sequential): + plugins = get_plugins(ndim=ndim, sequential=sequential) + return { + plugin.name: import_fft_class(plugin, raise_import_error=False) + for plugin in plugins + } + + def create_fft_object(method, n0, n1, n2=None, check=True): """Helper for creating fft objects. @@ -253,7 +283,7 @@ def create_fft_object(method, n0, n1, n2=None, check=True): """ - cls = import_fft_class(method, check) + cls = import_fft_class(method, check=check) str_module = cls.__module__ diff --git a/src/fluidfft/fft3d/__init__.py b/src/fluidfft/fft3d/__init__.py index bdac4fc..d42adb4 100644 --- a/src/fluidfft/fft3d/__init__.py +++ b/src/fluidfft/fft3d/__init__.py @@ -37,7 +37,7 @@ class :class:`fluidfft.fft3d.operators.OperatorsPseudoSpectral3D` defined in import sys -from .. import import_fft_class +from .. import _get_classes __all__ = [ "FFT3dFakeForDoc", @@ -52,32 +52,15 @@ class :class:`fluidfft.fft3d.operators.OperatorsPseudoSpectral3D` defined in except ImportError: pass -methods_seq = ["fftw3d", "pyfftw"] -methods_seq = ["fft3d.with_" + method for method in methods_seq] - -methods_mpi = [ - "fftw1d", - "fftwmpi3d", - "p3dfft", - "pfft", -] -methods_mpi = ["fft3d.mpi_with_" + method for method in methods_mpi] - def get_classes_seq(): """Return all sequential 3d classes.""" - return { - method: import_fft_class(method, raise_import_error=False) - for method in methods_seq - } + return _get_classes(3, sequential=True) def get_classes_mpi(): """Return all parallel 3d classes.""" - return { - method: import_fft_class(method, raise_import_error=False) - for method in methods_mpi - } + return _get_classes(3, sequential=False) if any("pytest" in part for part in sys.argv): diff --git a/tests/test_3d.py b/tests/test_3d.py index a12d8ae..5579693 100644 --- a/tests/test_3d.py +++ b/tests/test_3d.py @@ -5,6 +5,7 @@ from fluiddyn.util import mpi +from fluidfft import import_fft_class from fluidfft.fft3d import get_classes_seq, get_classes_mpi from fluidfft.fft3d.testing import make_testop_functions @@ -15,19 +16,30 @@ traceback.print_exc() -n = 8 +def test_get_classes(): + get_classes_seq() + get_classes_mpi() -rank = mpi.rank -nb_proc = mpi.nb_proc -classes_seq = get_classes_seq() +methods_seq = ["fftw3d", "pyfftw"] +methods_seq = ["fft3d.with_" + method for method in methods_seq] +classes_seq = { + method: import_fft_class(method, raise_import_error=False) + for method in methods_seq +} classes_seq = {name: cls for name, cls in classes_seq.items() if cls is not None} - if not classes_seq: raise ImportError("Not sequential 2d classes working!") +methods_mpi = ["fftw1d", "fftwmpi3d", "p3dfft", "pfft"] +methods_mpi = ["fft3d.mpi_with_" + method for method in methods_mpi] + +nb_proc = mpi.nb_proc if nb_proc > 1: - classes_mpi = get_classes_mpi() + classes_mpi = { + method: import_fft_class(method, raise_import_error=False) + for method in methods_mpi + } classes_mpi = { name: cls for name, cls in classes_mpi.items() if cls is not None } diff --git a/tests/test_plugins.py b/tests/test_plugins.py index 854d8bf..21235a0 100644 --- a/tests/test_plugins.py +++ b/tests/test_plugins.py @@ -3,7 +3,57 @@ from fluidfft import get_plugins -def test_toto(): +methodss = { + (2, True): set( + [ + "fft2d.with_fftw1d", + "fft2d.with_fftw2d", + "fft2d.with_cufft", + "fft2d.with_pyfftw", + "fft2d.with_dask", + ] + ), + (2, False): set( + [ + "fft2d.mpi_with_fftw1d", + "fft2d.mpi_with_fftwmpi2d", + ] + ), + (3, True): set( + [ + "fft3d.with_fftw3d", + "fft3d.with_pyfftw", + "fft3d.with_cufft", + ] + ), + (3, False): set( + [ + "fft3d.mpi_with_fftw1d", + "fft3d.mpi_with_fftwmpi3d", + "fft3d.mpi_with_p3dfft", + "fft3d.mpi_with_pfft", + ] + ), +} + +def _methods_from_plugins(plugins): + return set(plug.name for plug in plugins) + + +def _get_methods(ndim=None, sequential=None): + return _methods_from_plugins(get_plugins(ndim=ndim, sequential=sequential)) + + +def test_plugins(): plugins = get_plugins() assert plugins + + for ndim in (2, 3): + assert _get_methods(ndim=ndim) == methodss[(ndim, True)].union( + methodss[(ndim, False)] + ) + for sequential in (True, False): + assert methodss[(ndim, sequential)] == _get_methods( + ndim=ndim, sequential=sequential + )