diff --git a/diffpass/_modidx.py b/diffpass/_modidx.py index 32cb05c..40c3d6a 100644 --- a/diffpass/_modidx.py +++ b/diffpass/_modidx.py @@ -161,6 +161,8 @@ 'diffpass.train.Information.fit': ('train.html#information.fit', 'diffpass/train.py'), 'diffpass.train.Information.forward': ('train.html#information.forward', 'diffpass/train.py'), 'diffpass.train.Information.hard_': ('train.html#information.hard_', 'diffpass/train.py'), + 'diffpass.train.Information.record_hard_losses_identity_perm': ( 'train.html#information.record_hard_losses_identity_perm', + 'diffpass/train.py'), 'diffpass.train.Information.soft_': ('train.html#information.soft_', 'diffpass/train.py'), 'diffpass.train.InformationAndBestHits': ('train.html#informationandbesthits', 'diffpass/train.py'), 'diffpass.train.InformationAndBestHits.__init__': ( 'train.html#informationandbesthits.__init__', diff --git a/diffpass/train.py b/diffpass/train.py index 75f91ea..e224ee8 100644 --- a/diffpass/train.py +++ b/diffpass/train.py @@ -198,6 +198,11 @@ def _prepare_fit(self, x: torch.Tensor, y: torch.Tensor) -> DiffPASSResults: }, ) + return results + + def record_hard_losses_identity_perm( + self, x: torch.Tensor, y: torch.Tensor, results: DiffPASSResults + ) -> DiffPASSResults: # Compute hard losses with identity permutation self.hard_() with torch.no_grad(): @@ -316,8 +321,12 @@ def fit( mean_centering: bool = True, show_pbar: bool = True, compute_final_soft: bool = True, + compute_hard_losses_identity_perm: bool = False, ) -> DiffPASSResults: results = self._prepare_fit(x, y) + if compute_hard_losses_identity_perm: + results = self.record_hard_losses_identity_perm(x, y, results) + try: self.check_can_optimize( self.permutation._total_number_fixed_matchings, len(x) diff --git a/nbs/train.ipynb b/nbs/train.ipynb index ba1eb8a..6c0f268 100644 --- a/nbs/train.ipynb +++ b/nbs/train.ipynb @@ -270,6 +270,11 @@ " },\n", " )\n", "\n", + " return results\n", + "\n", + " def record_hard_losses_identity_perm(\n", + " self, x: torch.Tensor, y: torch.Tensor, results: DiffPASSResults\n", + " ) -> DiffPASSResults:\n", " # Compute hard losses with identity permutation\n", " self.hard_()\n", " with torch.no_grad():\n", @@ -388,8 +393,12 @@ " mean_centering: bool = True,\n", " show_pbar: bool = True,\n", " compute_final_soft: bool = True,\n", + " compute_hard_losses_identity_perm: bool = False,\n", " ) -> DiffPASSResults:\n", " results = self._prepare_fit(x, y)\n", + " if compute_hard_losses_identity_perm:\n", + " results = self.record_hard_losses_identity_perm(x, y, results)\n", + "\n", " try:\n", " self.check_can_optimize(\n", " self.permutation._total_number_fixed_matchings, len(x)\n",