Skip to content

Commit

Permalink
Make recording hard losses for identity permutation optional
Browse files Browse the repository at this point in the history
  • Loading branch information
ulupo committed Mar 7, 2024
1 parent 6a1fbec commit 2395bc9
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 0 deletions.
2 changes: 2 additions & 0 deletions diffpass/_modidx.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,8 @@
'diffpass.train.Information.fit': ('train.html#information.fit', 'diffpass/train.py'),
'diffpass.train.Information.forward': ('train.html#information.forward', 'diffpass/train.py'),
'diffpass.train.Information.hard_': ('train.html#information.hard_', 'diffpass/train.py'),
'diffpass.train.Information.record_hard_losses_identity_perm': ( 'train.html#information.record_hard_losses_identity_perm',
'diffpass/train.py'),
'diffpass.train.Information.soft_': ('train.html#information.soft_', 'diffpass/train.py'),
'diffpass.train.InformationAndBestHits': ('train.html#informationandbesthits', 'diffpass/train.py'),
'diffpass.train.InformationAndBestHits.__init__': ( 'train.html#informationandbesthits.__init__',
Expand Down
9 changes: 9 additions & 0 deletions diffpass/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,11 @@ def _prepare_fit(self, x: torch.Tensor, y: torch.Tensor) -> DiffPASSResults:
},
)

return results

def record_hard_losses_identity_perm(
self, x: torch.Tensor, y: torch.Tensor, results: DiffPASSResults
) -> DiffPASSResults:
# Compute hard losses with identity permutation
self.hard_()
with torch.no_grad():
Expand Down Expand Up @@ -316,8 +321,12 @@ def fit(
mean_centering: bool = True,
show_pbar: bool = True,
compute_final_soft: bool = True,
compute_hard_losses_identity_perm: bool = False,
) -> DiffPASSResults:
results = self._prepare_fit(x, y)
if compute_hard_losses_identity_perm:
results = self.record_hard_losses_identity_perm(x, y, results)

try:
self.check_can_optimize(
self.permutation._total_number_fixed_matchings, len(x)
Expand Down
9 changes: 9 additions & 0 deletions nbs/train.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,11 @@
" },\n",
" )\n",
"\n",
" return results\n",
"\n",
" def record_hard_losses_identity_perm(\n",
" self, x: torch.Tensor, y: torch.Tensor, results: DiffPASSResults\n",
" ) -> DiffPASSResults:\n",
" # Compute hard losses with identity permutation\n",
" self.hard_()\n",
" with torch.no_grad():\n",
Expand Down Expand Up @@ -388,8 +393,12 @@
" mean_centering: bool = True,\n",
" show_pbar: bool = True,\n",
" compute_final_soft: bool = True,\n",
" compute_hard_losses_identity_perm: bool = False,\n",
" ) -> DiffPASSResults:\n",
" results = self._prepare_fit(x, y)\n",
" if compute_hard_losses_identity_perm:\n",
" results = self.record_hard_losses_identity_perm(x, y, results)\n",
"\n",
" try:\n",
" self.check_can_optimize(\n",
" self.permutation._total_number_fixed_matchings, len(x)\n",
Expand Down

0 comments on commit 2395bc9

Please sign in to comment.