diff --git a/aiida_common_workflows/workflows/relax/bigdft/generator.py b/aiida_common_workflows/workflows/relax/bigdft/generator.py index 14536131..7422907c 100644 --- a/aiida_common_workflows/workflows/relax/bigdft/generator.py +++ b/aiida_common_workflows/workflows/relax/bigdft/generator.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- """Implementation of `aiida_common_workflows.common.relax.generator.CommonRelaxInputGenerator` for BigDFT.""" -from aiida import engine, orm, plugins +from aiida import engine, plugins from aiida_common_workflows.common import ElectronicType, RelaxType, SpinType from aiida_common_workflows.generators import ChoiceType, CodeType @@ -176,18 +176,10 @@ def _construct_builder(self, **kwargs) -> engine.ProcessBuilder: builder = self.process_class.get_builder() - if relax_type == RelaxType.POSITIONS: - relaxation_schema = 'relax' - elif relax_type == RelaxType.NONE: - relaxation_schema = 'relax' - builder.relax.perform = orm.Bool(False) - else: - raise ValueError(f'relaxation type `{relax_type.value}` is not supported') - - builder.structure = structure + builder.BigDFT.structure = structure # for now apply simple stupid heuristic : atoms < 200 -> cubic, else -> linear. - if len(builder.structure.sites) <= 200: + if len(builder.BigDFT.structure.sites) <= 200: inputdict = copy.deepcopy(self.get_protocol(protocol)['inputdict_cubic']) else: inputdict = copy.deepcopy(self.get_protocol(protocol)['inputdict_linear']) @@ -200,7 +192,7 @@ def _construct_builder(self, **kwargs) -> engine.ProcessBuilder: else: hgrids = logfile.get('dft').get('hgrids') first_hgrid = hgrids[0] if isinstance(hgrids, list) else hgrids - inputdict['dft']['hgrids'] = first_hgrid * builder.structure.cell_lengths[0] / \ + inputdict['dft']['hgrids'] = first_hgrid * builder.BigDFT.structure.cell_lengths[0] / \ reference_workchain.inputs.structure.cell_lengths[0] if electronic_type is ElectronicType.METAL: @@ -227,12 +219,18 @@ def _construct_builder(self, **kwargs) -> engine.ProcessBuilder: if self.get_protocol(protocol).get('kpoints_distance'): inputdict['kpt'] = {'method': 'auto', 'kptrlen': self.get_protocol(protocol).get('kpoints_distance')} - builder.parameters = BigDFTParameters(dict=inputdict) - builder.code = engines[relaxation_schema]['code'] - run_opts = {'options': engines[relaxation_schema]['options']} - builder.run_opts = orm.Dict(dict=run_opts) + if relax_type == RelaxType.POSITIONS: + inputdict['geopt'] = { + 'method': 'FIRE', + 'forcemax': threshold_forces or 0, + } + elif relax_type == RelaxType.NONE: + pass + else: + raise ValueError(f'relaxation type `{relax_type.value}` is not supported') - if threshold_forces is not None: - builder.relax.threshold_forces = orm.Float(threshold_forces) + builder.BigDFT.parameters = BigDFTParameters(dict=inputdict) + builder.BigDFT.code = engines['relax']['code'] + builder.BigDFT.metadata = {'options': engines['relax']['options']} return builder diff --git a/aiida_common_workflows/workflows/relax/bigdft/workchain.py b/aiida_common_workflows/workflows/relax/bigdft/workchain.py index 324a122c..5d1ad686 100644 --- a/aiida_common_workflows/workflows/relax/bigdft/workchain.py +++ b/aiida_common_workflows/workflows/relax/bigdft/workchain.py @@ -11,7 +11,7 @@ class BigDftCommonRelaxWorkChain(CommonRelaxWorkChain): """Implementation of `aiida_common_workflows.common.relax.workchain.CommonRelaxWorkChain` for BigDFT.""" - _process_class = WorkflowFactory('bigdft.relax') + _process_class = WorkflowFactory('bigdft') _generator_class = BigDftCommonRelaxInputGenerator @classmethod