Skip to content

Commit

Permalink
Better code related to plugins + testing 3d
Browse files Browse the repository at this point in the history
  • Loading branch information
paugier committed Jan 26, 2024
1 parent 08e9282 commit 4a04154
Show file tree
Hide file tree
Showing 5 changed files with 115 additions and 42 deletions.
2 changes: 0 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,6 @@ fluidfft-bench = "fluidfft.bench:run"
fluidfft-bench-analysis = "fluidfft.bench_analysis:run"

[project.entry-points."fluidfft.plugins"]
"fft2d.fake_mod_fft2d_for_doc" = "fluidfft.fft2d.fake_mod_fft2d_for_doc"
"fft3d.fake_mod_fft3d_for_doc" = "fluidfft.fft3d.fake_mod_fft3d_for_doc"
"fft2d.with_pyfftw" = "fluidfft.fft2d.with_pyfftw"
"fft3d.with_pyfftw" = "fluidfft.fft3d.with_pyfftw"
"fft2d.with_dask" = "fluidfft.fft2d.with_dask"
Expand Down
56 changes: 43 additions & 13 deletions src/fluidfft/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,9 @@
import logging

if sys.version_info < (3, 10):
from importlib_metadata import entry_points
from importlib_metadata import entry_points, EntryPoint
else:
from importlib.metadata import entry_points
from importlib.metadata import entry_points, EntryPoint

from fluiddyn.util import mpi

Expand Down Expand Up @@ -95,7 +95,7 @@ def byte_align(values, *args):
_plugins = None


def get_plugins(reload=False):
def get_plugins(reload=False, ndim=None, sequential=None):
"""Discover the fluidfft plugins installed"""
global _plugins
if _plugins is None or reload:
Expand All @@ -104,7 +104,26 @@ def get_plugins(reload=False):
if not _plugins:
raise RuntimeError("No Fluidfft plugins were found.")

return _plugins
if ndim is None and sequential is None:
return _plugins

if ndim is None:
index = 6
prefix = ""
elif ndim in (2, 3):
index = 0
prefix = f"fft{ndim}d."
else:
raise ValueError(f"Unsupported value for {ndim = }")

if sequential is not None and not sequential:
prefix += "mpi_"
elif sequential:
prefix += "with_"

return tuple(
plugin for plugin in _plugins if plugin.name[index:].startswith(prefix)
)


def get_module_fullname_from_method(method):
Expand Down Expand Up @@ -140,12 +159,10 @@ def _normalize_method_name(method):
return method


def _check_failure(module_fullname):
def _check_failure(method):
"""Check if a tiny fft maker can be created"""

if not any(
module_fullname.endswith(postfix) for postfix in ("pfft", "p3dfft")
):
if not any(method.endswith(postfix) for postfix in ("pfft", "p3dfft")):
return False

# for few methods, try before real import because importing can lead to
Expand All @@ -162,11 +179,12 @@ def _check_failure(module_fullname):
else:
env = os.environ
try:
# TODO: capture stdout and stderr and include last line in case of failure
subprocess.check_call(
[
sys.executable,
"-c",
f"from fluidfft import create_fft_object as c; c({module_fullname}, 2, 2, 2, check=0)",
f"from fluidfft import create_fft_object as c; c('{method}', 2, 2, 2, check=False)",
],
env=env,
shell=False,
Expand Down Expand Up @@ -207,11 +225,15 @@ def import_fft_class(method, raise_import_error=True, check=True):
"""

method = _normalize_method_name(method)
module_fullname = get_module_fullname_from_method(method)
if isinstance(method, EntryPoint):
module_fullname = method.value
method = method.name
else:
method = _normalize_method_name(method)
module_fullname = get_module_fullname_from_method(method)

if check:
failure = _check_failure(module_fullname)
failure = _check_failure(method)
if failure:
if not raise_import_error:
mpi.printby0("ImportError during check:", module_fullname)
Expand All @@ -231,6 +253,14 @@ def import_fft_class(method, raise_import_error=True, check=True):
return mod.FFTclass


def _get_classes(ndim, sequential):
plugins = get_plugins(ndim=ndim, sequential=sequential)
return {
plugin.name: import_fft_class(plugin, raise_import_error=False)
for plugin in plugins
}


def create_fft_object(method, n0, n1, n2=None, check=True):
"""Helper for creating fft objects.
Expand All @@ -253,7 +283,7 @@ def create_fft_object(method, n0, n1, n2=None, check=True):
"""

cls = import_fft_class(method, check)
cls = import_fft_class(method, check=check)

str_module = cls.__module__

Expand Down
23 changes: 3 additions & 20 deletions src/fluidfft/fft3d/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class :class:`fluidfft.fft3d.operators.OperatorsPseudoSpectral3D` defined in

import sys

from .. import import_fft_class
from .. import _get_classes

__all__ = [
"FFT3dFakeForDoc",
Expand All @@ -52,32 +52,15 @@ class :class:`fluidfft.fft3d.operators.OperatorsPseudoSpectral3D` defined in
except ImportError:
pass

methods_seq = ["fftw3d", "pyfftw"]
methods_seq = ["fft3d.with_" + method for method in methods_seq]

methods_mpi = [
"fftw1d",
"fftwmpi3d",
"p3dfft",
"pfft",
]
methods_mpi = ["fft3d.mpi_with_" + method for method in methods_mpi]


def get_classes_seq():
"""Return all sequential 3d classes."""
return {
method: import_fft_class(method, raise_import_error=False)
for method in methods_seq
}
return _get_classes(3, sequential=True)


def get_classes_mpi():
"""Return all parallel 3d classes."""
return {
method: import_fft_class(method, raise_import_error=False)
for method in methods_mpi
}
return _get_classes(3, sequential=False)


if any("pytest" in part for part in sys.argv):
Expand Down
24 changes: 18 additions & 6 deletions tests/test_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from fluiddyn.util import mpi

from fluidfft import import_fft_class
from fluidfft.fft3d import get_classes_seq, get_classes_mpi
from fluidfft.fft3d.testing import make_testop_functions

Expand All @@ -15,19 +16,30 @@
traceback.print_exc()


n = 8
def test_get_classes():
get_classes_seq()
get_classes_mpi()

rank = mpi.rank
nb_proc = mpi.nb_proc

classes_seq = get_classes_seq()
methods_seq = ["fftw3d", "pyfftw"]
methods_seq = ["fft3d.with_" + method for method in methods_seq]
classes_seq = {
method: import_fft_class(method, raise_import_error=False)
for method in methods_seq
}
classes_seq = {name: cls for name, cls in classes_seq.items() if cls is not None}

if not classes_seq:
raise ImportError("Not sequential 2d classes working!")

methods_mpi = ["fftw1d", "fftwmpi3d", "p3dfft", "pfft"]
methods_mpi = ["fft3d.mpi_with_" + method for method in methods_mpi]

nb_proc = mpi.nb_proc
if nb_proc > 1:
classes_mpi = get_classes_mpi()
classes_mpi = {
method: import_fft_class(method, raise_import_error=False)
for method in methods_mpi
}
classes_mpi = {
name: cls for name, cls in classes_mpi.items() if cls is not None
}
Expand Down
52 changes: 51 additions & 1 deletion tests/test_plugins.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,57 @@
from fluidfft import get_plugins


def test_toto():
methodss = {
(2, True): set(
[
"fft2d.with_fftw1d",
"fft2d.with_fftw2d",
"fft2d.with_cufft",
"fft2d.with_pyfftw",
"fft2d.with_dask",
]
),
(2, False): set(
[
"fft2d.mpi_with_fftw1d",
"fft2d.mpi_with_fftwmpi2d",
]
),
(3, True): set(
[
"fft3d.with_fftw3d",
"fft3d.with_pyfftw",
"fft3d.with_cufft",
]
),
(3, False): set(
[
"fft3d.mpi_with_fftw1d",
"fft3d.mpi_with_fftwmpi3d",
"fft3d.mpi_with_p3dfft",
"fft3d.mpi_with_pfft",
]
),
}


def _methods_from_plugins(plugins):
return set(plug.name for plug in plugins)


def _get_methods(ndim=None, sequential=None):
return _methods_from_plugins(get_plugins(ndim=ndim, sequential=sequential))


def test_plugins():
plugins = get_plugins()
assert plugins

for ndim in (2, 3):
assert _get_methods(ndim=ndim) == methodss[(ndim, True)].union(
methodss[(ndim, False)]
)
for sequential in (True, False):
assert methodss[(ndim, sequential)] == _get_methods(
ndim=ndim, sequential=sequential
)

0 comments on commit 4a04154

Please sign in to comment.