Skip to content

Commit

Permalink
Various fixes to improve robustness of workchain inputs validation (#31)
Browse files Browse the repository at this point in the history
* Small fix to improve robustness

* Print less warnings

* Add warning for bands_kpoints
  • Loading branch information
qiaojunfeng authored Nov 8, 2023
1 parent 4a46dc6 commit 48489f7
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 13 deletions.
34 changes: 21 additions & 13 deletions src/aiida_wannier90_workflows/workflows/optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,12 @@ def validate_inputs(inputs, ctx=None): # pylint: disable=unused-argument

parameters = inputs["wannier90"]["wannier90"]["parameters"].get_dict()

if inputs["optimize_disproj"]:
optimize_disproj = inputs.get("optimize_disproj", True)
if optimize_disproj:
if all(_ not in parameters for _ in ("dis_proj_min", "dis_proj_max")):
return "Trying to optimize dis_proj_min/max but no dis_proj_min/max in wannier90 parameters?"

if "optimize_reference_bands" in inputs and not inputs["optimize_disproj"]:
if "optimize_reference_bands" in inputs and not optimize_disproj:
warnings.warn(
"`optimize_reference_bands` is provided but `optimize_disproj = False`?"
)
Expand All @@ -46,22 +47,24 @@ def validate_inputs(inputs, ctx=None): # pylint: disable=unused-argument
):
return "No `optimize_reference_bands` but `optimize_bands_distance_threshold` is set?"

if inputs["separate_plotting"]:
plot_inputs = [
parameters.get(_, False)
for _ in Wannier90OptimizeWorkChain._WANNIER90_PLOT_INPUTS # pylint: disable=protected-access
]
separate_plotting = inputs.get("separate_plotting", False)
plot_inputs = [
parameters.get(_, False)
# for _ in Wannier90OptimizeWorkChain._WANNIER90_PLOT_INPUTS # pylint: disable=protected-access
for _ in ["wannier_plot"]
]
if separate_plotting:
if not any(plot_inputs):
return (
"Trying to separate plotting routines but no "
f"{'/'.join(Wannier90OptimizeWorkChain._WANNIER90_PLOT_INPUTS)} in wannier90 parameters?" # pylint: disable=protected-access
)

if inputs["optimize_disproj"] and not inputs["separate_plotting"]:
warnings.warn(
"`optimize_disproj = True` but `separate_plotting = False`. For optimizing projectability "
"disentanglement, it is highly recommended to run the plotting mode in a separate step."
)
else:
if optimize_disproj and any(plot_inputs):
warnings.warn(
"`optimize_disproj = True` but `separate_plotting = False`. For optimizing projectability "
"disentanglement, it is highly recommended to run the plotting mode in a separate step."
)

return None

Expand Down Expand Up @@ -262,6 +265,11 @@ def get_builder_from_protocol( # pylint: disable=arguments-differ
)

kwargs.setdefault("projection_type", WannierProjectionType.ATOMIC_PROJECTORS_QE)
if reference_bands and kwargs.get("bands_kpoints", None) is None:
warnings.warn(
"It is recommended to provide both `reference_bands` and `bands_kpoints` so that"
" the seekpath step can be skipped."
)
parent_builder = super().get_builder_from_protocol(codes, structure, **kwargs)

if reference_bands is not None:
Expand Down
1 change: 1 addition & 0 deletions src/aiida_wannier90_workflows/workflows/wannier90.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,6 +396,7 @@ def get_builder_from_protocol( # pylint: disable=unused-argument
projection_type=projection_type,
disentanglement_type=disentanglement_type,
frozen_type=frozen_type,
pseudo_family=pseudo_family,
)
# Remove workchain excluded inputs
wannier_builder["wannier90"].pop("structure", None)
Expand Down

0 comments on commit 48489f7

Please sign in to comment.