diff --git a/ImmuneBuilder/models.py b/ImmuneBuilder/models.py index 9f58050..b4f4156 100644 --- a/ImmuneBuilder/models.py +++ b/ImmuneBuilder/models.py @@ -225,7 +225,7 @@ def forward(self, node_features, sequence): # Remove atoms of side chains with outrageous clashes ds = torch.linalg.norm(all_atoms[None,:,None] - all_atoms[:,None,:,None], axis = -1) - ds[torch.isnan(ds!=ds) | (ds==0.0)] = 10 + ds[torch.isnan(ds) | (ds==0.0)] = 10 min_ds = ds.min(dim=-1)[0].min(dim=-1)[0].min(dim=-1)[0] all_atoms[min_ds < 0.2, 5:, :] = float("Nan")