From 651fd0142a965ca1b03cc52f0f2f8d960936a1cd Mon Sep 17 00:00:00 2001 From: Lorenzo <79980269+bastonero@users.noreply.github.com> Date: Wed, 10 Jul 2024 13:23:37 +0200 Subject: [PATCH 1/2] Add calcjob, parser and base workchain plugin for `bands.x` (#1033) An implementation that wraps the binary `bands.x` was missing. This program is used to find the symmetries of wavefunctions, to re-order the bands, and to perform basic post-processings, such as the calculation of the momentum operator. The current parser only performs basic parsing, with no specialized outputs for k-point wavefunction symmetries, nor momentum operator, nor bands. This can be, for instance, implemented over time depending on user request. --- pyproject.toml | 3 + .../calculations/bands.py | 39 ++++++++++++ src/aiida_quantumespresso/parsers/bands.py | 27 ++++++++ .../workflows/bands/__init__.py | 0 .../workflows/bands/base.py | 61 +++++++++++++++++++ tests/calculations/test_bands.py | 27 ++++++++ .../test_bands/test_bands_default.in | 6 ++ tests/conftest.py | 23 ++++++- .../parsers/fixtures/bands/default/aiida.out | 22 +++++++ tests/parsers/test_bands.py | 24 ++++++++ .../parsers/test_bands/test_bands_default.yml | 2 + tests/workflows/bands/__init__.py | 0 tests/workflows/bands/test_base.py | 54 ++++++++++++++++ 13 files changed, 287 insertions(+), 1 deletion(-) create mode 100644 src/aiida_quantumespresso/calculations/bands.py create mode 100644 src/aiida_quantumespresso/parsers/bands.py create mode 100644 src/aiida_quantumespresso/workflows/bands/__init__.py create mode 100644 src/aiida_quantumespresso/workflows/bands/base.py create mode 100644 tests/calculations/test_bands.py create mode 100644 tests/calculations/test_bands/test_bands_default.in create mode 100644 tests/parsers/fixtures/bands/default/aiida.out create mode 100644 tests/parsers/test_bands.py create mode 100644 tests/parsers/test_bands/test_bands_default.yml create mode 100644 tests/workflows/bands/__init__.py create mode 100644 tests/workflows/bands/test_base.py diff --git a/pyproject.toml b/pyproject.toml index bb18c6e70..687b0fc89 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -86,6 +86,7 @@ aiida-quantumespresso = 'aiida_quantumespresso.cli:cmd_root' 'quantumespresso.seekpath_structure_analysis' = 'aiida_quantumespresso.calculations.functions.seekpath_structure_analysis:seekpath_structure_analysis' 'quantumespresso.xspectra' = 'aiida_quantumespresso.calculations.xspectra:XspectraCalculation' 'quantumespresso.open_grid' = 'aiida_quantumespresso.calculations.open_grid:OpenGridCalculation' +'quantumespresso.bands' = 'aiida_quantumespresso.calculations.bands:BandsCalculation' [project.entry-points.'aiida.data'] 'quantumespresso.force_constants' = 'aiida_quantumespresso.data.force_constants:ForceConstantsData' @@ -105,6 +106,7 @@ aiida-quantumespresso = 'aiida_quantumespresso.cli:cmd_root' 'quantumespresso.pw2wannier90' = 'aiida_quantumespresso.parsers.pw2wannier90:Pw2wannier90Parser' 'quantumespresso.xspectra' = 'aiida_quantumespresso.parsers.xspectra:XspectraParser' 'quantumespresso.open_grid' = 'aiida_quantumespresso.parsers.open_grid:OpenGridParser' +'quantumespresso.bands' = 'aiida_quantumespresso.parsers.bands:BandsParser' [project.entry-points.'aiida.tools.calculations'] 'quantumespresso.pw' = 'aiida_quantumespresso.tools.calculations.pw:PwCalculationTools' @@ -125,6 +127,7 @@ aiida-quantumespresso = 'aiida_quantumespresso.cli:cmd_root' 'quantumespresso.xps' = 'aiida_quantumespresso.workflows.xps:XpsWorkChain' 'quantumespresso.xspectra.core' = 'aiida_quantumespresso.workflows.xspectra.core:XspectraCoreWorkChain' 'quantumespresso.xspectra.crystal' = 'aiida_quantumespresso.workflows.xspectra.crystal:XspectraCrystalWorkChain' +'quantumespresso.bands.base' = 'aiida_quantumespresso.workflows.bands.base:BandsBaseWorkChain' [tool.flit.module] name = 'aiida_quantumespresso' diff --git a/src/aiida_quantumespresso/calculations/bands.py b/src/aiida_quantumespresso/calculations/bands.py new file mode 100644 index 000000000..a29af6012 --- /dev/null +++ b/src/aiida_quantumespresso/calculations/bands.py @@ -0,0 +1,39 @@ +# -*- coding: utf-8 -*- +"""`CalcJob` implementation for the bands.x code of Quantum ESPRESSO.""" + +from aiida import orm + +from aiida_quantumespresso.calculations.namelists import NamelistsCalculation + + +class BandsCalculation(NamelistsCalculation): + """`CalcJob` implementation for the bands.x code of Quantum ESPRESSO. + + bands.x code of the Quantum ESPRESSO distribution, re-orders bands, and computes band-related properties. + + It computes for instance the expectation value of the momentum operator: + . For more information, refer to http://www.quantum-espresso.org/ + """ + + _MOMENTUM_OPERATOR_NAME = 'momentum_operator.dat' + _BANDS_NAME = 'bands.dat' + + _default_namelists = ['BANDS'] + _blocked_keywords = [ + ('BANDS', 'outdir', NamelistsCalculation._OUTPUT_SUBFOLDER), # pylint: disable=protected-access + ('BANDS', 'prefix', NamelistsCalculation._PREFIX), # pylint: disable=protected-access + ('BANDS', 'filband', _BANDS_NAME), + ('BANDS', 'filp', _MOMENTUM_OPERATOR_NAME), # Momentum operator + ] + + _internal_retrieve_list = [] + _default_parser = 'quantumespresso.bands' + + @classmethod + def define(cls, spec): + """Define the process specification.""" + # yapf: disable + super().define(spec) + spec.input('parent_folder', valid_type=(orm.RemoteData, orm.FolderData), required=True) + spec.output('output_parameters', valid_type=orm.Dict) + # yapf: enable diff --git a/src/aiida_quantumespresso/parsers/bands.py b/src/aiida_quantumespresso/parsers/bands.py new file mode 100644 index 000000000..6c2665a84 --- /dev/null +++ b/src/aiida_quantumespresso/parsers/bands.py @@ -0,0 +1,27 @@ +# -*- coding: utf-8 -*- +from aiida.orm import Dict + +from aiida_quantumespresso.utils.mapping import get_logging_container + +from .base import BaseParser + + +class BandsParser(BaseParser): + """``Parser`` implementation for the ``BandsCalculation`` calculation job class.""" + + def parse(self, **kwargs): + """Parse the retrieved files of a ``BandsCalculation`` into output nodes.""" + logs = get_logging_container() + + _, parsed_data, logs = self.parse_stdout_from_retrieved(logs) + + base_exit_code = self.check_base_errors(logs) + if base_exit_code: + return self.exit(base_exit_code, logs) + + self.out('output_parameters', Dict(parsed_data)) + + if 'ERROR_OUTPUT_STDOUT_INCOMPLETE'in logs.error: + return self.exit(self.exit_codes.ERROR_OUTPUT_STDOUT_INCOMPLETE, logs) + + return self.exit(logs=logs) diff --git a/src/aiida_quantumespresso/workflows/bands/__init__.py b/src/aiida_quantumespresso/workflows/bands/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/aiida_quantumespresso/workflows/bands/base.py b/src/aiida_quantumespresso/workflows/bands/base.py new file mode 100644 index 000000000..9e998ca4f --- /dev/null +++ b/src/aiida_quantumespresso/workflows/bands/base.py @@ -0,0 +1,61 @@ +# -*- coding: utf-8 -*- +"""Workchain to run a Quantum ESPRESSO bands.x calculation with automated error handling and restarts.""" +from aiida.common import AttributeDict +from aiida.engine import BaseRestartWorkChain, ProcessHandlerReport, process_handler, while_ +from aiida.plugins import CalculationFactory + +BandsCalculation = CalculationFactory('quantumespresso.bands') + + +class BandsBaseWorkChain(BaseRestartWorkChain): + """Workchain to run a Quantum ESPRESSO bands.x calculation with automated error handling and restarts.""" + + _process_class = BandsCalculation + + @classmethod + def define(cls, spec): + """Define the process specification.""" + # yapf: disable + super().define(spec) + spec.expose_inputs(BandsCalculation, namespace='bands') + spec.expose_outputs(BandsCalculation) + spec.outline( + cls.setup, + while_(cls.should_run_process)( + cls.run_process, + cls.inspect_process, + ), + cls.results, + ) + spec.exit_code(300, 'ERROR_UNRECOVERABLE_FAILURE', + message='The calculation failed with an unrecoverable error.') + # yapf: enable + + def setup(self): + """Call the `setup` of the `BaseRestartWorkChain` and then create the inputs dictionary in `self.ctx.inputs`. + + This `self.ctx.inputs` dictionary will be used by the `BaseRestartWorkChain` to submit the calculations in the + internal loop. + """ + super().setup() + self.ctx.restart_calc = None + self.ctx.inputs = AttributeDict(self.exposed_inputs(BandsCalculation, 'bands')) + + def report_error_handled(self, calculation, action): + """Report an action taken for a calculation that has failed. + + This should be called in a registered error handler if its condition is met and an action was taken. + + :param calculation: the failed calculation node + :param action: a string message with the action taken + """ + arguments = [calculation.process_label, calculation.pk, calculation.exit_status, calculation.exit_message] + self.report('{}<{}> failed with exit status {}: {}'.format(*arguments)) + self.report(f'Action taken: {action}') + + @process_handler(priority=600) + def handle_unrecoverable_failure(self, node): + """Handle calculations with an exit status below 400 which are unrecoverable, so abort the work chain.""" + if node.is_failed and node.exit_status < 400: + self.report_error_handled(node, 'unrecoverable error, aborting...') + return ProcessHandlerReport(True, self.exit_codes.ERROR_UNRECOVERABLE_FAILURE) diff --git a/tests/calculations/test_bands.py b/tests/calculations/test_bands.py new file mode 100644 index 000000000..e06b092f3 --- /dev/null +++ b/tests/calculations/test_bands.py @@ -0,0 +1,27 @@ +# -*- coding: utf-8 -*- +"""Tests for the `BandsCalculation` class.""" +# pylint: disable=protected-access +from aiida.common import datastructures + +from aiida_quantumespresso.calculations.bands import BandsCalculation + + +def test_bands_default(fixture_sandbox, generate_calc_job, generate_inputs_bands, file_regression): + """Test a default `BandsCalculation`.""" + entry_point_name = 'quantumespresso.bands' + + inputs = generate_inputs_bands() + calc_info = generate_calc_job(fixture_sandbox, entry_point_name, inputs) + + retrieve_list = [BandsCalculation._DEFAULT_OUTPUT_FILE] + BandsCalculation._internal_retrieve_list + + # Check the attributes of the returned `CalcInfo` + assert isinstance(calc_info, datastructures.CalcInfo) + assert sorted(calc_info.retrieve_list) == sorted(retrieve_list) + + with fixture_sandbox.open('aiida.in') as handle: + input_written = handle.read() + + # Checks on the files written to the sandbox folder as raw input + assert sorted(fixture_sandbox.get_content_list()) == sorted(['aiida.in']) + file_regression.check(input_written, encoding='utf-8', extension='.in') diff --git a/tests/calculations/test_bands/test_bands_default.in b/tests/calculations/test_bands/test_bands_default.in new file mode 100644 index 000000000..eec8daecd --- /dev/null +++ b/tests/calculations/test_bands/test_bands_default.in @@ -0,0 +1,6 @@ +&BANDS + filband = 'bands.dat' + filp = 'momentum_operator.dat' + outdir = './out/' + prefix = 'aiida' +/ diff --git a/tests/conftest.py b/tests/conftest.py index 8f9fffa32..ee465315d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# pylint: disable=redefined-outer-name,too-many-statements +# pylint: disable=redefined-outer-name,too-many-statements,too-many-lines """Initialise a text database and profile for pytest.""" from collections.abc import Mapping import io @@ -594,6 +594,27 @@ def _generate_inputs_q2r(): return _generate_inputs_q2r +@pytest.fixture +def generate_inputs_bands(fixture_sandbox, fixture_localhost, fixture_code, generate_remote_data): + """Generate default inputs for a `BandsCalculation.""" + + def _generate_inputs_bands(): + """Generate default inputs for a `BandsCalculation.""" + from aiida_quantumespresso.utils.resources import get_default_options + + inputs = { + 'code': fixture_code('quantumespresso.bands'), + 'parent_folder': generate_remote_data(fixture_localhost, fixture_sandbox.abspath, 'quantumespresso.pw'), + 'metadata': { + 'options': get_default_options() + } + } + + return inputs + + return _generate_inputs_bands + + @pytest.fixture def generate_inputs_ph( generate_calc_job_node, generate_structure, fixture_localhost, fixture_code, generate_kpoints_mesh diff --git a/tests/parsers/fixtures/bands/default/aiida.out b/tests/parsers/fixtures/bands/default/aiida.out new file mode 100644 index 000000000..21323f09d --- /dev/null +++ b/tests/parsers/fixtures/bands/default/aiida.out @@ -0,0 +1,22 @@ + + Program BANDS v.7.1 starts on 12Jun2024 at 18:53:54 + + This program is part of the open-source Quantum ESPRESSO suite + for quantum simulation of materials; please cite + "P. Giannozzi et al., J. Phys.:Condens. Matter 21 395502 (2009); + "P. Giannozzi et al., J. Phys.:Condens. Matter 29 465901 (2017); + "P. Giannozzi et al., J. Chem. Phys. 152 154105 (2020); + URL http://www.quantum-espresso.org", + in publications or presentations arising from this work. More details at + http://www.quantum-espresso.org/quote + + # many lines deleted + + BANDS : 0.68s CPU 0.72s WALL + + + This run was terminated on: 18:53:55 12Jun2024 + +=------------------------------------------------------------------------------= + JOB DONE. +=------------------------------------------------------------------------------= diff --git a/tests/parsers/test_bands.py b/tests/parsers/test_bands.py new file mode 100644 index 000000000..bdbecf565 --- /dev/null +++ b/tests/parsers/test_bands.py @@ -0,0 +1,24 @@ +# -*- coding: utf-8 -*- +"""Tests for the `BandsParser`.""" +from aiida import orm + + +def generate_inputs(): + """Return only those inputs that the parser will expect to be there.""" + return {} + + +def test_bands_default(fixture_localhost, generate_calc_job_node, generate_parser, data_regression): + """Test a default `bands.x` calculation.""" + entry_point_calc_job = 'quantumespresso.bands' + entry_point_parser = 'quantumespresso.bands' + + node = generate_calc_job_node(entry_point_calc_job, fixture_localhost, 'default', generate_inputs()) + parser = generate_parser(entry_point_parser) + results, calcfunction = parser.parse_from_node(node, store_provenance=False) + + assert calcfunction.is_finished, calcfunction.exception + assert calcfunction.is_finished_ok, calcfunction.exit_message + assert not orm.Log.collection.get_logs_for(node) + assert 'output_parameters' in results + data_regression.check(results['output_parameters'].get_dict()) diff --git a/tests/parsers/test_bands/test_bands_default.yml b/tests/parsers/test_bands/test_bands_default.yml new file mode 100644 index 000000000..fe5329099 --- /dev/null +++ b/tests/parsers/test_bands/test_bands_default.yml @@ -0,0 +1,2 @@ +code_version: '7.1' +wall_time_seconds: 0.72 diff --git a/tests/workflows/bands/__init__.py b/tests/workflows/bands/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/workflows/bands/test_base.py b/tests/workflows/bands/test_base.py new file mode 100644 index 000000000..2bd6d5528 --- /dev/null +++ b/tests/workflows/bands/test_base.py @@ -0,0 +1,54 @@ +# -*- coding: utf-8 -*- +# pylint: disable=no-member,redefined-outer-name +"""Tests for the `BandsBaseWorkChain` class.""" +from aiida.common import AttributeDict +from aiida.engine import ProcessHandlerReport +from plumpy import ProcessState +import pytest + +from aiida_quantumespresso.calculations.bands import BandsCalculation +from aiida_quantumespresso.workflows.bands.base import BandsBaseWorkChain + + +@pytest.fixture +def generate_workchain_bands(generate_workchain, generate_inputs_bands, generate_calc_job_node): + """Generate an instance of a `BandsBaseWorkChain`.""" + + def _generate_workchain_bands(exit_code=None): + entry_point = 'quantumespresso.bands.base' + process = generate_workchain(entry_point, {'bands': generate_inputs_bands()}) + + if exit_code is not None: + node = generate_calc_job_node() + node.set_process_state(ProcessState.FINISHED) + node.set_exit_status(exit_code.status) + + process.ctx.iteration = 1 + process.ctx.children = [node] + + return process + + return _generate_workchain_bands + + +def test_setup(generate_workchain_bands): + """Test `BandsBaseWorkChain.setup`.""" + process = generate_workchain_bands() + process.setup() + + assert process.ctx.restart_calc is None + assert isinstance(process.ctx.inputs, AttributeDict) + + +def test_handle_unrecoverable_failure(generate_workchain_bands): + """Test `BandsBaseWorkChain.handle_unrecoverable_failure`.""" + process = generate_workchain_bands(exit_code=BandsCalculation.exit_codes.ERROR_NO_RETRIEVED_FOLDER) + process.setup() + + result = process.handle_unrecoverable_failure(process.ctx.children[-1]) + assert isinstance(result, ProcessHandlerReport) + assert result.do_break + assert result.exit_code == BandsBaseWorkChain.exit_codes.ERROR_UNRECOVERABLE_FAILURE + + result = process.inspect_process() + assert result == BandsBaseWorkChain.exit_codes.ERROR_UNRECOVERABLE_FAILURE From b79189d7ce4756e846ab39c567ba4681474741ed Mon Sep 17 00:00:00 2001 From: Peter Gillespie <55498719+PNOGillespie@users.noreply.github.com> Date: Thu, 11 Jul 2024 10:17:42 +0100 Subject: [PATCH 2/2] `XspectraCrystalWorkChain`: Enable Symmetry Data Inputs (#1028) Adds an input namespace for the `XspectraCrystalWorkChain` which allows the user to define the spacegroup and equivalent sites data for the incoming structure, thus instructing the WorkChain to generate structures and run calculations for only the sites specified. Changes: * Adds the `symmetry_data` input namespace to `XspectraCrystalWorkChain`, which the `WorkChain` will use to generate structures and set the list of polarisation vectors to calculate. * Adds input validation steps for the symmetry data to check for required information and for entries which may cause a crash, though does not check for issues beyond this in order to maximise flexibility of use. * Fixes an oversight in `get_xspectra_structures` where the `supercell` entry was not returned to the outputs when external symmetry data were provided by the user. --- .../functions/get_xspectra_structures.py | 1 + .../workflows/xspectra/crystal.py | 105 ++++++++++++++---- 2 files changed, 82 insertions(+), 24 deletions(-) diff --git a/src/aiida_quantumespresso/workflows/functions/get_xspectra_structures.py b/src/aiida_quantumespresso/workflows/functions/get_xspectra_structures.py index 369bfa451..7dbcbaf06 100644 --- a/src/aiida_quantumespresso/workflows/functions/get_xspectra_structures.py +++ b/src/aiida_quantumespresso/workflows/functions/get_xspectra_structures.py @@ -360,6 +360,7 @@ def get_xspectra_structures(structure, **kwargs): # pylint: disable=too-many-st new_supercell = get_supercell_result['new_supercell'] output_params['supercell_factors'] = multiples + result['supercell'] = new_supercell output_params['supercell_num_sites'] = len(new_supercell.sites) output_params['supercell_cell_matrix'] = new_supercell.cell output_params['supercell_cell_lengths'] = new_supercell.cell_lengths diff --git a/src/aiida_quantumespresso/workflows/xspectra/crystal.py b/src/aiida_quantumespresso/workflows/xspectra/crystal.py index 1d56a50f9..76c866fca 100644 --- a/src/aiida_quantumespresso/workflows/xspectra/crystal.py +++ b/src/aiida_quantumespresso/workflows/xspectra/crystal.py @@ -4,7 +4,7 @@ Uses QuantumESPRESSO pw.x and xspectra.x. """ from aiida import orm -from aiida.common import AttributeDict, ValidationError +from aiida.common import AttributeDict from aiida.engine import ToContext, WorkChain, if_ from aiida.orm import UpfData as aiida_core_upf from aiida.plugins import CalculationFactory, DataFactory, WorkflowFactory @@ -173,6 +173,19 @@ def define(cls, spec): help=('Input namespace to provide core wavefunction inputs for each element. Must follow the format: ' '``core_wfc_data__{symbol} = {node}``') ) + spec.input_namespace( + 'symmetry_data', + valid_type=(orm.Dict, orm.Int), + dynamic=True, + required=False, + help=( + 'Input namespace to define equivalent sites and spacegroup number for the system. If defined, will ' + 'skip symmetry analysis and structure standardization. Use *only* if symmetry data are known ' + 'for certain. Requires ``spacegroup_number`` (Int) and ``equivalent_sites_data`` (Dict) to be ' + 'defined separately. All keys in `equivalent_sites_data` must be formatted as "site_". ' + 'See docstring of `get_xspectra_structures` for more information about inputs.' + ) + ) spec.inputs.validator = cls.validate_inputs spec.outline( cls.setup, @@ -370,7 +383,7 @@ def get_builder_from_protocol( # pylint: disable=too-many-statements @staticmethod - def validate_inputs(inputs, _): + def validate_inputs(inputs, _): # pylint: disable=too-many-return-statements """Validate the inputs before launching the WorkChain.""" structure = inputs['structure'] kinds_present = [kind.name for kind in structure.kinds] @@ -382,54 +395,92 @@ def validate_inputs(inputs, _): if element not in elements_present: extra_elements.append(element) if len(extra_elements) > 0: - raise ValidationError( + return ( f'Some elements in ``elements_list`` {extra_elements} do not exist in the' f' structure provided {elements_present}.' ) abs_atom_marker = inputs['abs_atom_marker'].value if abs_atom_marker in kinds_present: - raise ValidationError( + return ( f'The marker given for the absorbing atom ("{abs_atom_marker}") matches an existing Kind in the ' f'input structure ({kinds_present}).' ) if not inputs['core']['get_powder_spectrum'].value: - raise ValidationError( + return ( 'The ``get_powder_spectrum`` input for the XspectraCoreWorkChain namespace must be ``True``.' ) if 'upf2plotcore_code' not in inputs and 'core_wfc_data' not in inputs: - raise ValidationError( + return ( 'Neither a ``Code`` node for upf2plotcore.sh or a set of ``core_wfc_data`` were provided.' ) if 'core_wfc_data' in inputs: core_wfc_data_list = sorted(inputs['core_wfc_data'].keys()) if core_wfc_data_list != absorbing_elements_list: - raise ValidationError( + return ( f'The ``core_wfc_data`` provided ({core_wfc_data_list}) does not match the list of' f' absorbing elements ({absorbing_elements_list})' ) - else: - empty_core_wfc_data = [] - for key, value in inputs['core_wfc_data'].items(): - header_line = value.get_content()[:40] - try: - num_core_states = int(header_line.split(' ')[5]) - except Exception as exc: - raise ValidationError( - 'The core wavefunction data file is not of the correct format' - ) from exc - if num_core_states == 0: - empty_core_wfc_data.append(key) - if len(empty_core_wfc_data) > 0: - raise ValidationError( - f'The ``core_wfc_data`` provided for elements {empty_core_wfc_data} do not contain ' - 'any wavefunction data.' - ) + empty_core_wfc_data = [] + for key, value in inputs['core_wfc_data'].items(): + header_line = value.get_content()[:40] + try: + num_core_states = int(header_line.split(' ')[5]) + except: # pylint: disable=bare-except + return ( + 'The core wavefunction data file is not of the correct format' + ) # pylint: enable=bare-except + if num_core_states == 0: + empty_core_wfc_data.append(key) + if len(empty_core_wfc_data) > 0: + return ( + f'The ``core_wfc_data`` provided for elements {empty_core_wfc_data} do not contain ' + 'any wavefunction data.' + ) + + if 'symmetry_data' in inputs: + spacegroup_number = inputs['symmetry_data']['spacegroup_number'].value + equivalent_sites_data = inputs['symmetry_data']['equivalent_sites_data'].get_dict() + if spacegroup_number <= 0 or spacegroup_number >= 231: + return ( + f'Input spacegroup number ({spacegroup_number}) outside of valid range (1-230).' + ) + input_elements = [] + required_keys = sorted(['symbol', 'multiplicity', 'kind_name', 'site_index']) + invalid_entries = [] + # We check three things here: (1) are there any site indices which are outside of the possible + # range of site indices (2) do we have all the required keys for each entry, + # and (3) is there a mismatch between `absorbing_elements_list` and the elements specified + # in the entries of `equivalent_sites_data`. These checks are intended only to avoid a crash. + # We assume otherwise that the user knows what they're doing and has set everything else + # to their preferences correctly. + for site_label, value in equivalent_sites_data.items(): + if not set(required_keys).issubset(set(value.keys())) : + invalid_entries.append(site_label) + elif value['symbol'] not in input_elements: + input_elements.append(value['symbol']) + if value['site_index'] < 0 or value['site_index'] >= len(structure.sites): + return ( + f'The site index for {site_label} ({value["site_index"]}) is outside the range of ' + + f'sites within the structure (0-{len(structure.sites) -1}).' + ) + + if len(invalid_entries) != 0: + return ( + f'The required keys ({required_keys}) were not found in the following entries: {invalid_entries}' + ) + + sorted_input_elements = sorted(input_elements) + if sorted_input_elements != absorbing_elements_list: + return (f'Elements defined for sites in `equivalent_sites_data` ({sorted_input_elements}) ' + f'do not match the list of absorbing elements ({absorbing_elements_list})') + + # pylint: enable=too-many-return-statements def setup(self): """Set required context variables.""" if 'core_wfc_data' in self.inputs.keys(): @@ -489,6 +540,12 @@ def get_xspectra_structures(self): if 'spglib_settings' in self.inputs: inputs['spglib_settings'] = self.inputs.spglib_settings + if 'symmetry_data' in self.inputs: + inputs['parse_symmetry'] = orm.Bool(False) + input_sym_data = self.inputs.symmetry_data + inputs['equivalent_sites_data'] = input_sym_data['equivalent_sites_data'] + inputs['spacegroup_number'] = input_sym_data['spacegroup_number'] + if 'relax' in self.inputs: result = get_xspectra_structures(self.ctx.optimized_structure, **inputs) else: