Skip to content

Commit

Permalink
Replace a use of _dccn with .item()
Browse files Browse the repository at this point in the history
  • Loading branch information
ulupo committed Dec 15, 2023
1 parent e061313 commit 655ad95
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 6 deletions.
6 changes: 3 additions & 3 deletions diffpass/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -594,9 +594,9 @@ def _prepare_fit(self, x: torch.Tensor, y: torch.Tensor) -> DiffPASSResults:
results.soft_losses_identity_perm[
self.information_measure
] = results.hard_losses_identity_perm[self.information_measure]
results.soft_losses_identity_perm["BestHits"] = _dccn(
self.inter_group_loss(self._bh_soft_x, self._bh_soft_y)
)
results.soft_losses_identity_perm["BestHits"] = self.inter_group_loss(
self._bh_soft_x, self._bh_soft_y
).item()

return results

Expand Down
6 changes: 3 additions & 3 deletions nbs/train.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -674,9 +674,9 @@
" results.soft_losses_identity_perm[\n",
" self.information_measure\n",
" ] = results.hard_losses_identity_perm[self.information_measure]\n",
" results.soft_losses_identity_perm[\"BestHits\"] = _dccn(\n",
" self.inter_group_loss(self._bh_soft_x, self._bh_soft_y)\n",
" )\n",
" results.soft_losses_identity_perm[\"BestHits\"] = self.inter_group_loss(\n",
" self._bh_soft_x, self._bh_soft_y\n",
" ).item()\n",
"\n",
" return results\n",
"\n",
Expand Down

0 comments on commit 655ad95

Please sign in to comment.