From 524a4c09715746e186d59e1462fe9fce5882972b Mon Sep 17 00:00:00 2001 From: Emily Shinkle Date: Thu, 17 Oct 2024 13:21:24 -0600 Subject: [PATCH] fix StressForceNode --- examples/ase_example_multilayer.py | 2 +- hippynn/interfaces/ase_interface/calculator.py | 6 +++--- hippynn/layers/indexers.py | 5 ++--- hippynn/layers/pairs/indexing.py | 2 +- 4 files changed, 7 insertions(+), 8 deletions(-) diff --git a/examples/ase_example_multilayer.py b/examples/ase_example_multilayer.py index c873745e..d71c0755 100644 --- a/examples/ase_example_multilayer.py +++ b/examples/ase_example_multilayer.py @@ -30,7 +30,7 @@ # Load the files try: with active_directory("TEST_ALUMINUM_MODEL_MULTILAYER", create=False): - bundle = load_checkpoint_from_cwd(map_location='cpu',e) + bundle = load_checkpoint_from_cwd(map_location='cpu') except FileNotFoundError: raise FileNotFoundError("Model not found, run ani_aluminum_example_multilayer.py first!") diff --git a/hippynn/interfaces/ase_interface/calculator.py b/hippynn/interfaces/ase_interface/calculator.py index a9dd7231..17fd4663 100644 --- a/hippynn/interfaces/ase_interface/calculator.py +++ b/hippynn/interfaces/ase_interface/calculator.py @@ -339,9 +339,9 @@ def calculate(self, atoms=None, properties=None, system_changes=True): cell = torch.as_tensor(self.atoms.cell.array) # ExternalNieghbors doesn't take batch index # Get pair first and second from neighbors list - pair_first = torch.as_tensor(self.nl.nl.pair_first,dtype=torch.long) - pair_second = torch.as_tensor(self.nl.nl.pair_second,dtype=torch.long) - pair_shiftvecs = torch.as_tensor(self.nl.nl.offset_vec,dtype=torch.long) + pair_first = torch.as_tensor(self.nl.nl.pair_first,dtype=torch.long).unsqueeze(0) + pair_second = torch.as_tensor(self.nl.nl.pair_second,dtype=torch.long).unsqueeze(0) + pair_shiftvecs = torch.as_tensor(self.nl.nl.offset_vec,dtype=torch.long).unsqueeze(0) # This order must be synchronized with function setup_ase_graph above inputs = species, positions, cell, pair_first, pair_second, pair_shiftvecs diff --git a/hippynn/layers/indexers.py b/hippynn/layers/indexers.py index ebd691a4..d5604753 100644 --- a/hippynn/layers/indexers.py +++ b/hippynn/layers/indexers.py @@ -172,10 +172,9 @@ def __init__(self, *args, **kwargs): def forward(self, coordinates, cell): strain = torch.eye( coordinates.shape[2], dtype=coordinates.dtype, device=coordinates.device, requires_grad=True - ).unsqueeze(0) + ).tile(coordinates.shape[0],1,1) strained_coordinates = torch.bmm(coordinates, strain) - if cell.dim() == 2: - strained_cell = torch.mm(cell, strain.squeeze(0)) + strained_cell = torch.matmul(cell, strain) return strained_coordinates, strained_cell, strain diff --git a/hippynn/layers/pairs/indexing.py b/hippynn/layers/pairs/indexing.py index a0bbddeb..f59a0deb 100644 --- a/hippynn/layers/pairs/indexing.py +++ b/hippynn/layers/pairs/indexing.py @@ -17,7 +17,7 @@ def forward(self, coordinates, real_atoms, shifts, cell, pair_first, pair_second n_molecules, n_atoms, _ = coordinates.shape atom_coordinates = coordinates.reshape(n_molecules * n_atoms, 3)[real_atoms] paircoord = atom_coordinates[pair_second] - atom_coordinates[pair_first] + shifts.to(cell.dtype) @ cell - distflat = paircoord.norm(dim=1) + distflat = paircoord.norm(dim=-1) # We filter the lists to only send forward relevant pairs (those with distance under cutoff), improving performance. return filter_pairs(self.hard_dist_cutoff, distflat, pair_first, pair_second, paircoord)