Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify get_builder_from_protocol in ProjwfcBandsWorkChain #56

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 3 additions & 19 deletions src/aiida_wannier90_workflows/workflows/projwfcbands.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,6 @@ def get_builder_from_protocol( # pylint: disable=arguments-differ
"""
from aiida_wannier90_workflows.utils.workflows.builder.submit import (
recursive_merge_builder,
recursive_merge_container,
)

type_check(pw_code, (str, int, orm.Code))
Expand All @@ -119,10 +118,6 @@ def get_builder_from_protocol( # pylint: disable=arguments-differ
# Prepare workchain builder
builder = cls.get_builder()

protocol_inputs = cls.get_protocol_inputs(
protocol=protocol, overrides=overrides
)

Comment on lines -122 to -125
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @t-reents , thanks for the PR.

These lines should remain. They are used to read the protocol parameters from the file, projwfcbands.yaml. For the moment, the file only specifies clean_workdir. I think that's why you get identical results even after you remove them. However, more parameters may be added in the future.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @superstar54

I see the point that one might create a specific protocol for this WorkChain. I also mentioned this in the PR description. My point was that this WorkChain doesn't implement any additional inputs and basically inherits the PwBandsWorkChain only. I assumed that one would simply provide the individual changes via the overrides.

But yeah, it makes sense to have the possibility to specify a protocol to overwrite the defaults of the parents. Otherwise, one would always need to provide adjusted protocols for PwBandsWorkChain and ProjwfcBaseWorkChain.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@superstar54 Coming back to the point, that the ProjwfcBandsWorkChain doesn't have additional inputs, I'd suggest to load the protocol inputs and merge the relevant parts with the overrides for the subsequent calls of PwBandsWorkChain and ProjwfcBaseWorkChain. In this way, we still have the possibility to provide custom protocols but also fix the initial problem that the transformed values get overwritten by the Python base types.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fine to me. If you do this, please also remove the projwfcbands.yaml file, and then add a detailed comment in the code that the WorkChain uses the parent protocol setting.

It would be good that @qiaojunfeng could have a quick look at it.

projwfc_overrides = None
if overrides:
projwfc_overrides = overrides.pop("projwfc", None)
Expand All @@ -137,25 +132,14 @@ def get_builder_from_protocol( # pylint: disable=arguments-differ

# By default do not run relax
pwbands_builder.pop("relax", None)
inputs = pwbands_builder._inputs(prune=True) # pylint: disable=protected-access

projwfc_builder = ProjwfcBaseWorkChain.get_builder_from_protocol(
projwfc_code, protocol=protocol, overrides=projwfc_overrides
)
projwfc_builder.pop("clean_workdir", None)

inputs["projwfc"] = projwfc_builder._inputs( # pylint: disable=protected-access
prune=True
)
inputs["projwfc"].pop("clean_workdir", None)

# Need to convert `clean_workdir` to `orm.Bool`
if "clean_workdir" in protocol_inputs:
protocol_inputs["clean_workdir"] = orm.Bool(
protocol_inputs["clean_workdir"]
)

inputs = recursive_merge_container(inputs, protocol_inputs)
builder = recursive_merge_builder(builder, inputs)
builder.projwfc = projwfc_builder
builder = recursive_merge_builder(builder, pwbands_builder)

return builder

Expand Down
Loading