From 303f3df989b9b09b5c7ff63abd1c56250f925f04 Mon Sep 17 00:00:00 2001 From: Sebastiaan Huber Date: Thu, 29 Feb 2024 22:53:51 +0100 Subject: [PATCH] Use custom `WorkflowFactory` to provide plugin install instructions The `WorkflowFactory` from `aiida-core` is replaced with a custom version in the `aiida_common_workflows.plugins.factories` module. This function will call the factory from `aiida-core` but catch the `MissingEntryPointError` exception. In this case, if the entry point corresponds to a plugin implementation of one of the common workflows the exception is reraised but with a useful message that provides the user with the install command to install the necessary plugin package. While this should catch all cases of users trying to load a workflow for a plugin that is not installed through its entry point, it won't catch import errors that are raised when a module is imported directly from that plugin package. Therefore, these imports should not be placed at the top of modules, but placed inside functions/methods of the implementation as much as possible. --- .../plugins/__init__.py | 8 +++- .../plugins/entry_point.py | 6 ++- .../plugins/factories.py | 47 +++++++++++++++++++ .../workflows/relax/abinit/workchain.py | 4 +- .../workflows/relax/castep/generator.py | 17 +++---- .../relax/quantum_espresso/generator.py | 3 +- tests/test_minimal_install.py | 28 +++++++++++ 7 files changed, 98 insertions(+), 15 deletions(-) create mode 100644 src/aiida_common_workflows/plugins/factories.py diff --git a/src/aiida_common_workflows/plugins/__init__.py b/src/aiida_common_workflows/plugins/__init__.py index 2936a136..567bb19c 100644 --- a/src/aiida_common_workflows/plugins/__init__.py +++ b/src/aiida_common_workflows/plugins/__init__.py @@ -1,4 +1,10 @@ """Module with utilities for working with the plugins provided by this plugin package.""" from .entry_point import get_entry_point_name_from_class, get_workflow_entry_point_names, load_workflow_entry_point +from .factories import WorkflowFactory -__all__ = ('get_workflow_entry_point_names', 'get_entry_point_name_from_class', 'load_workflow_entry_point') +__all__ = ( + 'WorkflowFactory', + 'get_workflow_entry_point_names', + 'get_entry_point_name_from_class', + 'load_workflow_entry_point', +) diff --git a/src/aiida_common_workflows/plugins/entry_point.py b/src/aiida_common_workflows/plugins/entry_point.py index 56337038..4dd4f8dc 100644 --- a/src/aiida_common_workflows/plugins/entry_point.py +++ b/src/aiida_common_workflows/plugins/entry_point.py @@ -3,6 +3,8 @@ from aiida.plugins import entry_point +from .factories import WorkflowFactory + PACKAGE_PREFIX = 'common_workflows' __all__ = ('get_workflow_entry_point_names', 'get_entry_point_name_from_class', 'load_workflow_entry_point') @@ -38,5 +40,5 @@ def load_workflow_entry_point(workflow: str, plugin_name: str): :param plugin_name: name of the plugin implementation. :return: the workchain class of the plugin implementation of the common workflow. """ - prefix = f'{PACKAGE_PREFIX}.{workflow}.{plugin_name}' - return entry_point.load_entry_point('aiida.workflows', prefix) + entry_point_name = f'{PACKAGE_PREFIX}.{workflow}.{plugin_name}' + return WorkflowFactory(entry_point_name) diff --git a/src/aiida_common_workflows/plugins/factories.py b/src/aiida_common_workflows/plugins/factories.py new file mode 100644 index 00000000..b8a48435 --- /dev/null +++ b/src/aiida_common_workflows/plugins/factories.py @@ -0,0 +1,47 @@ +"""Factories to load entry points.""" +import typing as t + +from aiida import plugins +from aiida.common import exceptions + +if t.TYPE_CHECKING: + from aiida.engine import WorkChain + from importlib_metadata import EntryPoint + +__all__ = ('WorkflowFactory',) + + +@t.overload +def WorkflowFactory(entry_point_name: str, load: t.Literal[True] = True) -> t.Union[t.Type['WorkChain'], t.Callable]: + ... + + +@t.overload +def WorkflowFactory(entry_point_name: str, load: t.Literal[False]) -> 'EntryPoint': + ... + + +def WorkflowFactory(entry_point_name: str, load: bool = True) -> t.Union['EntryPoint', t.Type['WorkChain'], t.Callable]: # noqa: N802 + """Return the `WorkChain` sub class registered under the given entry point. + + :param entry_point_name: the entry point name. + :param load: if True, load the matched entry point and return the loaded resource instead of the entry point itself. + :return: sub class of :py:class:`~aiida.engine.processes.workchains.workchain.WorkChain` or a `workfunction` + :raises aiida.common.MissingEntryPointError: entry point was not registered + :raises aiida.common.MultipleEntryPointError: entry point could not be uniquely resolved + :raises aiida.common.LoadingEntryPointError: entry point could not be loaded + :raises aiida.common.InvalidEntryPointTypeError: if the type of the loaded entry point is invalid. + """ + common_workflow_prefixes = ('common_workflows.relax.', 'common_workflows.bands.') + try: + return plugins.WorkflowFactory(entry_point_name, load) + except exceptions.MissingEntryPointError as exception: + for prefix in common_workflow_prefixes: + if entry_point_name.startswith(prefix): + plugin_name = entry_point_name.removeprefix(prefix) + raise exceptions.MissingEntryPointError( + f'Could not load the entry point `{entry_point_name}`, probably because the plugin package is not ' + f'installed. Please install it with `pip install aiida-common-workflows[{plugin_name}]`.' + ) from exception + else: # noqa: PLW0120 + raise diff --git a/src/aiida_common_workflows/workflows/relax/abinit/workchain.py b/src/aiida_common_workflows/workflows/relax/abinit/workchain.py index 52083421..64b0a79b 100644 --- a/src/aiida_common_workflows/workflows/relax/abinit/workchain.py +++ b/src/aiida_common_workflows/workflows/relax/abinit/workchain.py @@ -3,7 +3,7 @@ from aiida import orm from aiida.common import exceptions from aiida.engine import calcfunction -from aiida_abinit.workflows.base import AbinitBaseWorkChain +from aiida.plugins import WorkflowFactory from ..workchain import CommonRelaxWorkChain from .generator import AbinitCommonRelaxInputGenerator @@ -44,7 +44,7 @@ def get_total_magnetization(parameters): class AbinitCommonRelaxWorkChain(CommonRelaxWorkChain): """Implementation of `aiida_common_workflows.common.relax.workchain.CommonRelaxWorkChain` for Abinit.""" - _process_class = AbinitBaseWorkChain + _process_class = WorkflowFactory('abinit.base') _generator_class = AbinitCommonRelaxInputGenerator def convert_outputs(self): diff --git a/src/aiida_common_workflows/workflows/relax/castep/generator.py b/src/aiida_common_workflows/workflows/relax/castep/generator.py index 1f380356..4c6e1cbd 100644 --- a/src/aiida_common_workflows/workflows/relax/castep/generator.py +++ b/src/aiida_common_workflows/workflows/relax/castep/generator.py @@ -8,14 +8,15 @@ import yaml from aiida import engine, orm, plugins from aiida.common import exceptions -from aiida_castep.data import get_pseudos_from_structure -from aiida_castep.data.otfg import OTFGGroup from aiida_common_workflows.common import ElectronicType, RelaxType, SpinType from aiida_common_workflows.generators import ChoiceType, CodeType from ..generator import CommonRelaxInputGenerator +if t.TYPE_CHECKING: + from aiida_castep.data.otfg import OTFGGroup + KNOWN_BUILTIN_FAMILIES = ('C19', 'NCP19', 'QC5', 'C17', 'C9') __all__ = ('CastepCommonRelaxInputGenerator',) @@ -247,8 +248,8 @@ def generate_inputs( :param override: a dictionary to override specific inputs :return: input dictionary """ - from aiida.common.lang import type_check + from aiida_castep.data.otfg import OTFGGroup family_name = protocol['relax']['base']['pseudos_family'] if isinstance(family_name, orm.Str): @@ -285,7 +286,7 @@ def generate_inputs_relax( protocol: t.Dict, code: orm.Code, structure: orm.StructureData, - otfg_family: OTFGGroup, + otfg_family: 'OTFGGroup', override: t.Optional[t.Dict[str, t.Any]] = None, ) -> t.Dict[str, t.Any]: """Generate the inputs for the `CastepCommonRelaxWorkChain` for a given code, structure and pseudo potential family. @@ -321,7 +322,7 @@ def generate_inputs_base( protocol: t.Dict, code: orm.Code, structure: orm.StructureData, - otfg_family: OTFGGroup, + otfg_family: 'OTFGGroup', override: t.Optional[t.Dict[str, t.Any]] = None, ) -> t.Dict[str, t.Any]: """Generate the inputs for the `CastepBaseWorkChain` for a given code, structure and pseudo potential family. @@ -359,7 +360,7 @@ def generate_inputs_calculation( protocol: t.Dict, code: orm.Code, structure: orm.StructureData, - otfg_family: OTFGGroup, + otfg_family: 'OTFGGroup', override: t.Optional[t.Dict[str, t.Any]] = None, ) -> t.Dict[str, t.Any]: """Generate the inputs for the `CastepCalculation` for a given code, structure and pseudo potential family. @@ -372,6 +373,7 @@ def generate_inputs_calculation( :return: the fully defined input dictionary. """ from aiida_castep.calculations.helper import CastepHelper + from aiida_castep.data import get_pseudos_from_structure override = {} if not override else override.get('calc', {}) # This merge perserves the merged `parameters` in the override @@ -415,9 +417,8 @@ def ensure_otfg_family(family_name, force_update=False): NOTE: CASTEP also supports UPF families, but it is not enabled here, since no UPS based protocol has been implemented. """ - from aiida.common import NotExistent - from aiida_castep.data.otfg import upload_otfg_family + from aiida_castep.data.otfg import OTFGGroup, upload_otfg_family # Ensure family name is a str if isinstance(family_name, orm.Str): diff --git a/src/aiida_common_workflows/workflows/relax/quantum_espresso/generator.py b/src/aiida_common_workflows/workflows/relax/quantum_espresso/generator.py index 3d1cdfdf..64ab66bf 100644 --- a/src/aiida_common_workflows/workflows/relax/quantum_espresso/generator.py +++ b/src/aiida_common_workflows/workflows/relax/quantum_espresso/generator.py @@ -3,7 +3,6 @@ import yaml from aiida import engine, orm, plugins -from aiida_quantumespresso.workflows.protocols.utils import recursive_merge from aiida_common_workflows.common import ElectronicType, RelaxType, SpinType from aiida_common_workflows.generators import ChoiceType, CodeType @@ -108,8 +107,8 @@ def _construct_builder(self, **kwargs) -> engine.ProcessBuilder: # noqa: PLR091 The keyword arguments will have been validated against the input generator specification. """ - from aiida_quantumespresso.common import types + from aiida_quantumespresso.workflows.protocols.utils import recursive_merge from qe_tools import CONSTANTS structure = kwargs['structure'] diff --git a/tests/test_minimal_install.py b/tests/test_minimal_install.py index 9ca3ec22..e1c682cb 100644 --- a/tests/test_minimal_install.py +++ b/tests/test_minimal_install.py @@ -4,6 +4,8 @@ installed. This guarantees that most of the code can be imported without any plugin packages being installed. """ import pytest +from aiida.common import exceptions +from aiida_common_workflows.plugins import WorkflowFactory, get_workflow_entry_point_names @pytest.mark.minimal_install @@ -18,3 +20,29 @@ def test_imports(): import aiida_common_workflows.workflows import aiida_common_workflows.workflows.dissociation import aiida_common_workflows.workflows.eos # noqa: F401 + + +@pytest.mark.minimal_install +@pytest.mark.parametrize('entry_point_name', get_workflow_entry_point_names('relax')) +def test_workflow_factory_relax(entry_point_name): + """Test that trying to load common relax workflow implementations will raise if not installed. + + The exception message should provide the pip command to install the require plugin package. + """ + plugin_name = entry_point_name.removeprefix('common_workflows.relax.') + match = rf'.*plugin package is not installed.*`pip install aiida-common-workflows\[{plugin_name}\]`.*' + with pytest.raises(exceptions.MissingEntryPointError, match=match): + WorkflowFactory(entry_point_name) + + +@pytest.mark.minimal_install +@pytest.mark.parametrize('entry_point_name', get_workflow_entry_point_names('bands')) +def test_workflow_factory_bands(entry_point_name): + """Test that trying to load common bands workflow implementations will raise if not installed. + + The exception message should provide the pip command to install the require plugin package. + """ + plugin_name = entry_point_name.removeprefix('common_workflows.bands.') + match = rf'.*plugin package is not installed.*`pip install aiida-common-workflows\[{plugin_name}\]`.*' + with pytest.raises(exceptions.MissingEntryPointError, match=match): + WorkflowFactory(entry_point_name)