Skip to content

Commit

Permalink
fix StressForceNode
Browse files Browse the repository at this point in the history
  • Loading branch information
shinkle-lanl committed Oct 17, 2024
1 parent 10bded4 commit 524a4c0
Show file tree
Hide file tree
Showing 4 changed files with 7 additions and 8 deletions.
2 changes: 1 addition & 1 deletion examples/ase_example_multilayer.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
# Load the files
try:
with active_directory("TEST_ALUMINUM_MODEL_MULTILAYER", create=False):
bundle = load_checkpoint_from_cwd(map_location='cpu',e)
bundle = load_checkpoint_from_cwd(map_location='cpu')
except FileNotFoundError:
raise FileNotFoundError("Model not found, run ani_aluminum_example_multilayer.py first!")

Expand Down
6 changes: 3 additions & 3 deletions hippynn/interfaces/ase_interface/calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,9 +339,9 @@ def calculate(self, atoms=None, properties=None, system_changes=True):
cell = torch.as_tensor(self.atoms.cell.array) # ExternalNieghbors doesn't take batch index
# Get pair first and second from neighbors list

pair_first = torch.as_tensor(self.nl.nl.pair_first,dtype=torch.long)
pair_second = torch.as_tensor(self.nl.nl.pair_second,dtype=torch.long)
pair_shiftvecs = torch.as_tensor(self.nl.nl.offset_vec,dtype=torch.long)
pair_first = torch.as_tensor(self.nl.nl.pair_first,dtype=torch.long).unsqueeze(0)
pair_second = torch.as_tensor(self.nl.nl.pair_second,dtype=torch.long).unsqueeze(0)
pair_shiftvecs = torch.as_tensor(self.nl.nl.offset_vec,dtype=torch.long).unsqueeze(0)

# This order must be synchronized with function setup_ase_graph above
inputs = species, positions, cell, pair_first, pair_second, pair_shiftvecs
Expand Down
5 changes: 2 additions & 3 deletions hippynn/layers/indexers.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,10 +172,9 @@ def __init__(self, *args, **kwargs):
def forward(self, coordinates, cell):
strain = torch.eye(
coordinates.shape[2], dtype=coordinates.dtype, device=coordinates.device, requires_grad=True
).unsqueeze(0)
).tile(coordinates.shape[0],1,1)
strained_coordinates = torch.bmm(coordinates, strain)
if cell.dim() == 2:
strained_cell = torch.mm(cell, strain.squeeze(0))
strained_cell = torch.matmul(cell, strain)
return strained_coordinates, strained_cell, strain


Expand Down
2 changes: 1 addition & 1 deletion hippynn/layers/pairs/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def forward(self, coordinates, real_atoms, shifts, cell, pair_first, pair_second
n_molecules, n_atoms, _ = coordinates.shape
atom_coordinates = coordinates.reshape(n_molecules * n_atoms, 3)[real_atoms]
paircoord = atom_coordinates[pair_second] - atom_coordinates[pair_first] + shifts.to(cell.dtype) @ cell
distflat = paircoord.norm(dim=1)
distflat = paircoord.norm(dim=-1)

# We filter the lists to only send forward relevant pairs (those with distance under cutoff), improving performance.
return filter_pairs(self.hard_dist_cutoff, distflat, pair_first, pair_second, paircoord)
Expand Down

0 comments on commit 524a4c0

Please sign in to comment.