Skip to content

Commit

Permalink
Add calcjob, parser and base workchain plugin for bands.x (#1033)
Browse files Browse the repository at this point in the history
An implementation that wraps the binary `bands.x` was missing.
This program is used to find the symmetries of wavefunctions,
to re-order the bands, and to perform basic post-processings,
such as the calculation of the momentum operator.

The current parser only performs basic parsing, with no specialized
outputs for k-point wavefunction symmetries, nor momentum operator,
nor bands. This can be, for instance, implemented over time depending
on user request.
  • Loading branch information
bastonero authored Jul 10, 2024
1 parent 91c3e1d commit 651fd01
Show file tree
Hide file tree
Showing 13 changed files with 287 additions and 1 deletion.
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ aiida-quantumespresso = 'aiida_quantumespresso.cli:cmd_root'
'quantumespresso.seekpath_structure_analysis' = 'aiida_quantumespresso.calculations.functions.seekpath_structure_analysis:seekpath_structure_analysis'
'quantumespresso.xspectra' = 'aiida_quantumespresso.calculations.xspectra:XspectraCalculation'
'quantumespresso.open_grid' = 'aiida_quantumespresso.calculations.open_grid:OpenGridCalculation'
'quantumespresso.bands' = 'aiida_quantumespresso.calculations.bands:BandsCalculation'

[project.entry-points.'aiida.data']
'quantumespresso.force_constants' = 'aiida_quantumespresso.data.force_constants:ForceConstantsData'
Expand All @@ -105,6 +106,7 @@ aiida-quantumespresso = 'aiida_quantumespresso.cli:cmd_root'
'quantumespresso.pw2wannier90' = 'aiida_quantumespresso.parsers.pw2wannier90:Pw2wannier90Parser'
'quantumespresso.xspectra' = 'aiida_quantumespresso.parsers.xspectra:XspectraParser'
'quantumespresso.open_grid' = 'aiida_quantumespresso.parsers.open_grid:OpenGridParser'
'quantumespresso.bands' = 'aiida_quantumespresso.parsers.bands:BandsParser'

[project.entry-points.'aiida.tools.calculations']
'quantumespresso.pw' = 'aiida_quantumespresso.tools.calculations.pw:PwCalculationTools'
Expand All @@ -125,6 +127,7 @@ aiida-quantumespresso = 'aiida_quantumespresso.cli:cmd_root'
'quantumespresso.xps' = 'aiida_quantumespresso.workflows.xps:XpsWorkChain'
'quantumespresso.xspectra.core' = 'aiida_quantumespresso.workflows.xspectra.core:XspectraCoreWorkChain'
'quantumespresso.xspectra.crystal' = 'aiida_quantumespresso.workflows.xspectra.crystal:XspectraCrystalWorkChain'
'quantumespresso.bands.base' = 'aiida_quantumespresso.workflows.bands.base:BandsBaseWorkChain'

[tool.flit.module]
name = 'aiida_quantumespresso'
Expand Down
39 changes: 39 additions & 0 deletions src/aiida_quantumespresso/calculations/bands.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# -*- coding: utf-8 -*-
"""`CalcJob` implementation for the bands.x code of Quantum ESPRESSO."""

from aiida import orm

from aiida_quantumespresso.calculations.namelists import NamelistsCalculation


class BandsCalculation(NamelistsCalculation):
"""`CalcJob` implementation for the bands.x code of Quantum ESPRESSO.
bands.x code of the Quantum ESPRESSO distribution, re-orders bands, and computes band-related properties.
It computes for instance the expectation value of the momentum operator:
<Psi(n,k) | i * m * [H, x] | Psi(m,k)>. For more information, refer to http://www.quantum-espresso.org/
"""

_MOMENTUM_OPERATOR_NAME = 'momentum_operator.dat'
_BANDS_NAME = 'bands.dat'

_default_namelists = ['BANDS']
_blocked_keywords = [
('BANDS', 'outdir', NamelistsCalculation._OUTPUT_SUBFOLDER), # pylint: disable=protected-access
('BANDS', 'prefix', NamelistsCalculation._PREFIX), # pylint: disable=protected-access
('BANDS', 'filband', _BANDS_NAME),
('BANDS', 'filp', _MOMENTUM_OPERATOR_NAME), # Momentum operator
]

_internal_retrieve_list = []
_default_parser = 'quantumespresso.bands'

@classmethod
def define(cls, spec):
"""Define the process specification."""
# yapf: disable
super().define(spec)
spec.input('parent_folder', valid_type=(orm.RemoteData, orm.FolderData), required=True)
spec.output('output_parameters', valid_type=orm.Dict)
# yapf: enable
27 changes: 27 additions & 0 deletions src/aiida_quantumespresso/parsers/bands.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# -*- coding: utf-8 -*-
from aiida.orm import Dict

from aiida_quantumespresso.utils.mapping import get_logging_container

from .base import BaseParser


class BandsParser(BaseParser):
"""``Parser`` implementation for the ``BandsCalculation`` calculation job class."""

def parse(self, **kwargs):
"""Parse the retrieved files of a ``BandsCalculation`` into output nodes."""
logs = get_logging_container()

_, parsed_data, logs = self.parse_stdout_from_retrieved(logs)

base_exit_code = self.check_base_errors(logs)
if base_exit_code:
return self.exit(base_exit_code, logs)

self.out('output_parameters', Dict(parsed_data))

if 'ERROR_OUTPUT_STDOUT_INCOMPLETE'in logs.error:
return self.exit(self.exit_codes.ERROR_OUTPUT_STDOUT_INCOMPLETE, logs)

return self.exit(logs=logs)
Empty file.
61 changes: 61 additions & 0 deletions src/aiida_quantumespresso/workflows/bands/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# -*- coding: utf-8 -*-
"""Workchain to run a Quantum ESPRESSO bands.x calculation with automated error handling and restarts."""
from aiida.common import AttributeDict
from aiida.engine import BaseRestartWorkChain, ProcessHandlerReport, process_handler, while_
from aiida.plugins import CalculationFactory

BandsCalculation = CalculationFactory('quantumespresso.bands')


class BandsBaseWorkChain(BaseRestartWorkChain):
"""Workchain to run a Quantum ESPRESSO bands.x calculation with automated error handling and restarts."""

_process_class = BandsCalculation

@classmethod
def define(cls, spec):
"""Define the process specification."""
# yapf: disable
super().define(spec)
spec.expose_inputs(BandsCalculation, namespace='bands')
spec.expose_outputs(BandsCalculation)
spec.outline(
cls.setup,
while_(cls.should_run_process)(
cls.run_process,
cls.inspect_process,
),
cls.results,
)
spec.exit_code(300, 'ERROR_UNRECOVERABLE_FAILURE',
message='The calculation failed with an unrecoverable error.')
# yapf: enable

def setup(self):
"""Call the `setup` of the `BaseRestartWorkChain` and then create the inputs dictionary in `self.ctx.inputs`.
This `self.ctx.inputs` dictionary will be used by the `BaseRestartWorkChain` to submit the calculations in the
internal loop.
"""
super().setup()
self.ctx.restart_calc = None
self.ctx.inputs = AttributeDict(self.exposed_inputs(BandsCalculation, 'bands'))

def report_error_handled(self, calculation, action):
"""Report an action taken for a calculation that has failed.
This should be called in a registered error handler if its condition is met and an action was taken.
:param calculation: the failed calculation node
:param action: a string message with the action taken
"""
arguments = [calculation.process_label, calculation.pk, calculation.exit_status, calculation.exit_message]
self.report('{}<{}> failed with exit status {}: {}'.format(*arguments))
self.report(f'Action taken: {action}')

@process_handler(priority=600)
def handle_unrecoverable_failure(self, node):
"""Handle calculations with an exit status below 400 which are unrecoverable, so abort the work chain."""
if node.is_failed and node.exit_status < 400:
self.report_error_handled(node, 'unrecoverable error, aborting...')
return ProcessHandlerReport(True, self.exit_codes.ERROR_UNRECOVERABLE_FAILURE)
27 changes: 27 additions & 0 deletions tests/calculations/test_bands.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# -*- coding: utf-8 -*-
"""Tests for the `BandsCalculation` class."""
# pylint: disable=protected-access
from aiida.common import datastructures

from aiida_quantumespresso.calculations.bands import BandsCalculation


def test_bands_default(fixture_sandbox, generate_calc_job, generate_inputs_bands, file_regression):
"""Test a default `BandsCalculation`."""
entry_point_name = 'quantumespresso.bands'

inputs = generate_inputs_bands()
calc_info = generate_calc_job(fixture_sandbox, entry_point_name, inputs)

retrieve_list = [BandsCalculation._DEFAULT_OUTPUT_FILE] + BandsCalculation._internal_retrieve_list

# Check the attributes of the returned `CalcInfo`
assert isinstance(calc_info, datastructures.CalcInfo)
assert sorted(calc_info.retrieve_list) == sorted(retrieve_list)

with fixture_sandbox.open('aiida.in') as handle:
input_written = handle.read()

# Checks on the files written to the sandbox folder as raw input
assert sorted(fixture_sandbox.get_content_list()) == sorted(['aiida.in'])
file_regression.check(input_written, encoding='utf-8', extension='.in')
6 changes: 6 additions & 0 deletions tests/calculations/test_bands/test_bands_default.in
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
&BANDS
filband = 'bands.dat'
filp = 'momentum_operator.dat'
outdir = './out/'
prefix = 'aiida'
/
23 changes: 22 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
# pylint: disable=redefined-outer-name,too-many-statements
# pylint: disable=redefined-outer-name,too-many-statements,too-many-lines
"""Initialise a text database and profile for pytest."""
from collections.abc import Mapping
import io
Expand Down Expand Up @@ -594,6 +594,27 @@ def _generate_inputs_q2r():
return _generate_inputs_q2r


@pytest.fixture
def generate_inputs_bands(fixture_sandbox, fixture_localhost, fixture_code, generate_remote_data):
"""Generate default inputs for a `BandsCalculation."""

def _generate_inputs_bands():
"""Generate default inputs for a `BandsCalculation."""
from aiida_quantumespresso.utils.resources import get_default_options

inputs = {
'code': fixture_code('quantumespresso.bands'),
'parent_folder': generate_remote_data(fixture_localhost, fixture_sandbox.abspath, 'quantumespresso.pw'),
'metadata': {
'options': get_default_options()
}
}

return inputs

return _generate_inputs_bands


@pytest.fixture
def generate_inputs_ph(
generate_calc_job_node, generate_structure, fixture_localhost, fixture_code, generate_kpoints_mesh
Expand Down
22 changes: 22 additions & 0 deletions tests/parsers/fixtures/bands/default/aiida.out
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@

Program BANDS v.7.1 starts on 12Jun2024 at 18:53:54

This program is part of the open-source Quantum ESPRESSO suite
for quantum simulation of materials; please cite
"P. Giannozzi et al., J. Phys.:Condens. Matter 21 395502 (2009);
"P. Giannozzi et al., J. Phys.:Condens. Matter 29 465901 (2017);
"P. Giannozzi et al., J. Chem. Phys. 152 154105 (2020);
URL http://www.quantum-espresso.org",
in publications or presentations arising from this work. More details at
http://www.quantum-espresso.org/quote

# many lines deleted

BANDS : 0.68s CPU 0.72s WALL


This run was terminated on: 18:53:55 12Jun2024

=------------------------------------------------------------------------------=
JOB DONE.
=------------------------------------------------------------------------------=
24 changes: 24 additions & 0 deletions tests/parsers/test_bands.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# -*- coding: utf-8 -*-
"""Tests for the `BandsParser`."""
from aiida import orm


def generate_inputs():
"""Return only those inputs that the parser will expect to be there."""
return {}


def test_bands_default(fixture_localhost, generate_calc_job_node, generate_parser, data_regression):
"""Test a default `bands.x` calculation."""
entry_point_calc_job = 'quantumespresso.bands'
entry_point_parser = 'quantumespresso.bands'

node = generate_calc_job_node(entry_point_calc_job, fixture_localhost, 'default', generate_inputs())
parser = generate_parser(entry_point_parser)
results, calcfunction = parser.parse_from_node(node, store_provenance=False)

assert calcfunction.is_finished, calcfunction.exception
assert calcfunction.is_finished_ok, calcfunction.exit_message
assert not orm.Log.collection.get_logs_for(node)
assert 'output_parameters' in results
data_regression.check(results['output_parameters'].get_dict())
2 changes: 2 additions & 0 deletions tests/parsers/test_bands/test_bands_default.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
code_version: '7.1'
wall_time_seconds: 0.72
Empty file.
54 changes: 54 additions & 0 deletions tests/workflows/bands/test_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# -*- coding: utf-8 -*-
# pylint: disable=no-member,redefined-outer-name
"""Tests for the `BandsBaseWorkChain` class."""
from aiida.common import AttributeDict
from aiida.engine import ProcessHandlerReport
from plumpy import ProcessState
import pytest

from aiida_quantumespresso.calculations.bands import BandsCalculation
from aiida_quantumespresso.workflows.bands.base import BandsBaseWorkChain


@pytest.fixture
def generate_workchain_bands(generate_workchain, generate_inputs_bands, generate_calc_job_node):
"""Generate an instance of a `BandsBaseWorkChain`."""

def _generate_workchain_bands(exit_code=None):
entry_point = 'quantumespresso.bands.base'
process = generate_workchain(entry_point, {'bands': generate_inputs_bands()})

if exit_code is not None:
node = generate_calc_job_node()
node.set_process_state(ProcessState.FINISHED)
node.set_exit_status(exit_code.status)

process.ctx.iteration = 1
process.ctx.children = [node]

return process

return _generate_workchain_bands


def test_setup(generate_workchain_bands):
"""Test `BandsBaseWorkChain.setup`."""
process = generate_workchain_bands()
process.setup()

assert process.ctx.restart_calc is None
assert isinstance(process.ctx.inputs, AttributeDict)


def test_handle_unrecoverable_failure(generate_workchain_bands):
"""Test `BandsBaseWorkChain.handle_unrecoverable_failure`."""
process = generate_workchain_bands(exit_code=BandsCalculation.exit_codes.ERROR_NO_RETRIEVED_FOLDER)
process.setup()

result = process.handle_unrecoverable_failure(process.ctx.children[-1])
assert isinstance(result, ProcessHandlerReport)
assert result.do_break
assert result.exit_code == BandsBaseWorkChain.exit_codes.ERROR_UNRECOVERABLE_FAILURE

result = process.inspect_process()
assert result == BandsBaseWorkChain.exit_codes.ERROR_UNRECOVERABLE_FAILURE

0 comments on commit 651fd01

Please sign in to comment.