Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Feb 16, 2024
1 parent 1701a2d commit 8bcdc24
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 16 deletions.
34 changes: 22 additions & 12 deletions aiida_nanotech_empa/workflows/cp2k/reftraj_md_workchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,15 @@
# Cp2kRefTrajWorkChain = plugins.WorkflowFactory("cp2k.reftraj")
TrajectoryData = plugins.DataFactory("array.trajectory")


@engine.calcfunction
def merge_trajectories(*trajectories):
"""Merge a list of trajectories into a single one."""
positions=[]
cells=[]
forces=[]
positions = []
cells = []
forces = []
for trajectory in trajectories:
positions.append(trajectory.get_array("positions") )
positions.append(trajectory.get_array("positions"))
try:
cells.append(trajectory.get_array("cells"))
except KeyError:
Expand All @@ -31,11 +32,12 @@ def merge_trajectories(*trajectories):
if len(cells) == 0:
merged_trajectory.set_trajectory(symbols, positions)
else:
merged_trajectory.set_trajectory(symbols, positions, cells=cells)
merged_trajectory.set_trajectory(symbols, positions, cells=cells)
merged_trajectory.set_array("forces", forces)

return merged_trajectory


@engine.calcfunction
def create_batches(trajectory, num_batches, steps_completed):
"""Create lists of consecutive integers. Counting start from 1 for CP2K input."""
Expand All @@ -56,7 +58,7 @@ def create_batches(trajectory, num_batches, steps_completed):
current_list = []
if current_list:
consecutive_lists.append(current_list)

# [[1,2,3],[4,5,6]] --> [[1],[2,3],[4,5,6]]
batches = [[consecutive_lists[0].pop(0)]]
for batch in consecutive_lists:
Expand Down Expand Up @@ -187,8 +189,10 @@ def run_reftraj_batches(self):
builder.cp2k.metadata.label = f"structures_{batch[0]}_to_{batch[-1]}"
builder.cp2k.metadata.options.parser_name = "cp2k_advanced_parser"
builder.cp2k.parameters = orm.Dict(dict=input_dict)
builder.cp2k.parent_calc_folder = getattr(self.ctx, key0).outputs.remote_folder

builder.cp2k.parent_calc_folder = getattr(
self.ctx, key0
).outputs.remote_folder

future = self.submit(builder)

key = f"reftraj_batch_{batch[0]}_to_{batch[-1]}"
Expand All @@ -203,14 +207,20 @@ def merge_batches_output(self):
# for i_batch in range(self.ctx.n_batches):
# merged_traj.extend(self.ctx[f"reftraj_batch_{i_batch}"].outputs.trajectory)


trajectories_to_merge=[getattr(self.ctx, f"reftraj_batch_{self.ctx.batches[0][0]}_to_{self.ctx.batches[0][0]}").outputs.output_trajectory]
trajectories_to_merge = [
getattr(
self.ctx,
f"reftraj_batch_{self.ctx.batches[0][0]}_to_{self.ctx.batches[0][0]}",
).outputs.output_trajectory
]
for batch in self.ctx.batches[1:]:
key = f"reftraj_batch_{batch[0]}_to_{batch[-1]}"
if not getattr(self.ctx, key).is_finished_ok:
self.report(f"Batch {key} failed")
return self.exit_codes.ERROR_TERMINATION
trajectories_to_merge.append(getattr(self.ctx, key).outputs.output_trajectory)
trajectories_to_merge.append(
getattr(self.ctx, key).outputs.output_trajectory
)
merged_trajectory = merge_trajectories(*trajectories_to_merge)

self.out("output_trajectory", merged_trajectory)
Expand Down
5 changes: 1 addition & 4 deletions examples/workflows/example_cp2k_md_reftraj.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,7 @@ def _example_cp2k_reftraj(cp2k_code):
]
)
cells = np.array(
[
[[5, 0, 0], [0, 5, 0], [0, 0, 5 + 0.0001 * i]]
for i in range(steps)
]
[[[5, 0, 0], [0, 5, 0], [0, 0, 5 + 0.0001 * i]] for i in range(steps)]
)
symbols = ["H", "H"]
trajectory = TrajectoryData()
Expand Down

0 comments on commit 8bcdc24

Please sign in to comment.