From 8bcdc243175c8c374532745e7bac0e0985290cf1 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 16 Feb 2024 15:56:53 +0000 Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../workflows/cp2k/reftraj_md_workchain.py | 34 ++++++++++++------- examples/workflows/example_cp2k_md_reftraj.py | 5 +-- 2 files changed, 23 insertions(+), 16 deletions(-) diff --git a/aiida_nanotech_empa/workflows/cp2k/reftraj_md_workchain.py b/aiida_nanotech_empa/workflows/cp2k/reftraj_md_workchain.py index b6e0da8..3cc21a7 100644 --- a/aiida_nanotech_empa/workflows/cp2k/reftraj_md_workchain.py +++ b/aiida_nanotech_empa/workflows/cp2k/reftraj_md_workchain.py @@ -10,14 +10,15 @@ # Cp2kRefTrajWorkChain = plugins.WorkflowFactory("cp2k.reftraj") TrajectoryData = plugins.DataFactory("array.trajectory") + @engine.calcfunction def merge_trajectories(*trajectories): """Merge a list of trajectories into a single one.""" - positions=[] - cells=[] - forces=[] + positions = [] + cells = [] + forces = [] for trajectory in trajectories: - positions.append(trajectory.get_array("positions") ) + positions.append(trajectory.get_array("positions")) try: cells.append(trajectory.get_array("cells")) except KeyError: @@ -31,11 +32,12 @@ def merge_trajectories(*trajectories): if len(cells) == 0: merged_trajectory.set_trajectory(symbols, positions) else: - merged_trajectory.set_trajectory(symbols, positions, cells=cells) + merged_trajectory.set_trajectory(symbols, positions, cells=cells) merged_trajectory.set_array("forces", forces) - + return merged_trajectory + @engine.calcfunction def create_batches(trajectory, num_batches, steps_completed): """Create lists of consecutive integers. Counting start from 1 for CP2K input.""" @@ -56,7 +58,7 @@ def create_batches(trajectory, num_batches, steps_completed): current_list = [] if current_list: consecutive_lists.append(current_list) - + # [[1,2,3],[4,5,6]] --> [[1],[2,3],[4,5,6]] batches = [[consecutive_lists[0].pop(0)]] for batch in consecutive_lists: @@ -187,8 +189,10 @@ def run_reftraj_batches(self): builder.cp2k.metadata.label = f"structures_{batch[0]}_to_{batch[-1]}" builder.cp2k.metadata.options.parser_name = "cp2k_advanced_parser" builder.cp2k.parameters = orm.Dict(dict=input_dict) - builder.cp2k.parent_calc_folder = getattr(self.ctx, key0).outputs.remote_folder - + builder.cp2k.parent_calc_folder = getattr( + self.ctx, key0 + ).outputs.remote_folder + future = self.submit(builder) key = f"reftraj_batch_{batch[0]}_to_{batch[-1]}" @@ -203,14 +207,20 @@ def merge_batches_output(self): # for i_batch in range(self.ctx.n_batches): # merged_traj.extend(self.ctx[f"reftraj_batch_{i_batch}"].outputs.trajectory) - - trajectories_to_merge=[getattr(self.ctx, f"reftraj_batch_{self.ctx.batches[0][0]}_to_{self.ctx.batches[0][0]}").outputs.output_trajectory] + trajectories_to_merge = [ + getattr( + self.ctx, + f"reftraj_batch_{self.ctx.batches[0][0]}_to_{self.ctx.batches[0][0]}", + ).outputs.output_trajectory + ] for batch in self.ctx.batches[1:]: key = f"reftraj_batch_{batch[0]}_to_{batch[-1]}" if not getattr(self.ctx, key).is_finished_ok: self.report(f"Batch {key} failed") return self.exit_codes.ERROR_TERMINATION - trajectories_to_merge.append(getattr(self.ctx, key).outputs.output_trajectory) + trajectories_to_merge.append( + getattr(self.ctx, key).outputs.output_trajectory + ) merged_trajectory = merge_trajectories(*trajectories_to_merge) self.out("output_trajectory", merged_trajectory) diff --git a/examples/workflows/example_cp2k_md_reftraj.py b/examples/workflows/example_cp2k_md_reftraj.py index ec9b17f..ec4a34e 100644 --- a/examples/workflows/example_cp2k_md_reftraj.py +++ b/examples/workflows/example_cp2k_md_reftraj.py @@ -28,10 +28,7 @@ def _example_cp2k_reftraj(cp2k_code): ] ) cells = np.array( - [ - [[5, 0, 0], [0, 5, 0], [0, 0, 5 + 0.0001 * i]] - for i in range(steps) - ] + [[[5, 0, 0], [0, 5, 0], [0, 0, 5 + 0.0001 * i]] for i in range(steps)] ) symbols = ["H", "H"] trajectory = TrajectoryData()