From 018426069300e3a21e453940a537d8872fc9b68f Mon Sep 17 00:00:00 2001
From: Umberto Lupo
Date: Thu, 7 Mar 2024 15:21:01 +0100
Subject: [PATCH] Make similarity_gradient_bypass init parameter of
 InformationAndBestHits and add compare_soft_best_hits_to_hard parameter

`similarity_gradient_bypass` is now set at construction time instead of
being passed to `fit`/`_fit`, and it also gates the straight-through
gradient trick in `forward` (previously triggered whenever `x_perm_hard`
was not None). The new `compare_soft_best_hits_to_hard` parameter
(default True) makes soft-mode best-hits losses compare the soft best
hits of x against the hard best hits of y, rather than the previous
soft-soft comparison.
---
 diffpass/_modidx.py |  2 ++
 diffpass/train.py   | 40 +++++++++++++++++++++++-----------------
 nbs/train.ipynb     | 40 +++++++++++++++++++++++-----------------
 3 files changed, 48 insertions(+), 34 deletions(-)

diff --git a/diffpass/_modidx.py b/diffpass/_modidx.py
index 40c3d6a..2f7588b 100644
--- a/diffpass/_modidx.py
+++ b/diffpass/_modidx.py
@@ -171,6 +171,8 @@
                                                             'diffpass/train.py'),
     'diffpass.train.InformationAndBestHits._adjust_loss_weights_and_ensemble_shape': ( 'train.html#informationandbesthits._adjust_loss_weights_and_ensemble_shape',
                                                                                        'diffpass/train.py'),
+    'diffpass.train.InformationAndBestHits._bh_y_for_soft_x': ( 'train.html#informationandbesthits._bh_y_for_soft_x',
+                                                                'diffpass/train.py'),
     'diffpass.train.InformationAndBestHits._fit': ( 'train.html#informationandbesthits._fit',
                                                     'diffpass/train.py'),
     'diffpass.train.InformationAndBestHits._precompute_bh': ( 'train.html#informationandbesthits._precompute_bh',
diff --git a/diffpass/train.py b/diffpass/train.py
index e224ee8..8fd5be6 100644
--- a/diffpass/train.py
+++ b/diffpass/train.py
@@ -364,6 +364,8 @@ def __init__(
         similarities_cfg: Optional[dict[str, Any]] = None,
         best_hits_cfg: Optional[dict[str, Any]] = None,
         inter_group_loss_score_fn: Optional[callable] = None,
+        similarity_gradient_bypass: bool = False,
+        compare_soft_best_hits_to_hard: bool = True,
     ):
         super().__init__()
         self.group_sizes = tuple(s for s in group_sizes)
@@ -375,6 +377,8 @@ def __init__(
         self.similarities_cfg = similarities_cfg
         self.best_hits_cfg = best_hits_cfg
         self.inter_group_loss_score_fn = inter_group_loss_score_fn
+        self.similarity_gradient_bypass = similarity_gradient_bypass
+        self.compare_soft_best_hits_to_hard = compare_soft_best_hits_to_hard
 
         ensemble_shape = []
         _dim_in_ensemble = -1
@@ -561,6 +565,12 @@ def _precompute_bh(self, x: torch.Tensor, y: torch.Tensor) -> None:
             similarities_y = self.best_hits.prepare_fixed(similarities_y)
         self.register_buffer("_bh_soft_y", self.best_hits(similarities_y))
 
+    @property
+    def _bh_y_for_soft_x(self):
+        if self.compare_soft_best_hits_to_hard:
+            return self._bh_hard_y
+        return self._bh_soft_y
+
     def forward(
         self,
         x: torch.Tensor,
@@ -580,16 +590,18 @@ def forward(
 
         # Best hits portion of the loss, with shortcut for hard permutations
         if mode == "soft":
-            if x_perm_hard is not None:
+            if self.similarity_gradient_bypass:
                 x_perm = (x_perm_hard - x_perm).detach() + x_perm
             similarities_x = self.similarities(x_perm)
             bh_x = self.best_hits(similarities_x)
+            # Ensure comparisons are soft_x-{soft,hard}_y, depending on
+            # self.compare_soft_best_hits_to_hard
+            loss_bh = self.inter_group_loss(bh_x, self._bh_y_for_soft_x)
         else:
             bh_x = apply_hard_permutation_batch_to_similarity(
                 x=self._bh_hard_x, perms=perms
             )
-        # Ensure comparisons are only hard-hard or soft-soft
-        loss_bh = self.inter_group_loss(bh_x, getattr(self, f"_bh_{mode}_y"))
+            loss_bh = self.inter_group_loss(bh_x, self._bh_hard_y)
 
         return {
             "perms": perms,
@@ -643,7 +655,7 @@ def _prepare_fit(self, x: torch.Tensor, y: torch.Tensor) -> DiffPASSResults:
             self.information_measure
         ] = results.hard_losses_identity_perm[self.information_measure]
         results.soft_losses_identity_perm["BestHits"] = self.inter_group_loss(
-            self._bh_soft_x, self._bh_soft_y
+            self._bh_soft_x, self._bh_y_for_soft_x
         ).item()
 
         return results
@@ -658,7 +670,6 @@ def _fit(
         optimizer_name: Optional[str] = "SGD",
         optimizer_kwargs: Optional[dict[str, Any]] = None,
         mean_centering: bool = True,
-        similarity_gradient_bypass: bool = False,
         show_pbar: bool = True,
         compute_final_soft: bool = True,
     ) -> DiffPASSResults:
@@ -686,7 +697,7 @@ def _fit(
             epoch_results = self(x, y)
             loss_info = epoch_results["loss_info"]
             loss_bh = epoch_results["loss_bh"]
-            if similarity_gradient_bypass:
+            if self.similarity_gradient_bypass:
                 x_perm_hard = epoch_results["x_perm"]
             perms = epoch_results["perms"]
             results.log_alphas.append(
@@ -738,7 +749,6 @@ def fit(
         optimizer_name: Optional[str] = "SGD",
         optimizer_kwargs: Optional[dict[str, Any]] = None,
         mean_centering: bool = True,
-        similarity_gradient_bypass: bool = False,
         show_pbar: bool = True,
         compute_final_soft: bool = True,
     ) -> DiffPASSResults:
@@ -752,7 +762,6 @@ def fit(
             optimizer_name=optimizer_name,
             optimizer_kwargs=optimizer_kwargs,
             mean_centering=mean_centering,
-            similarity_gradient_bypass=similarity_gradient_bypass,
             show_pbar=show_pbar,
             compute_final_soft=compute_final_soft,
         )
@@ -771,6 +780,7 @@ def __init__(
         similarity_kind: Literal["Hamming", "Blosum62"] = "Hamming",
         similarities_cfg: Optional[dict[str, Any]] = None,
         intra_group_loss_score_fn: Optional[callable] = None,
+        similarity_gradient_bypass: bool = False,
     ):
         super().__init__()
         self.group_sizes = tuple(s for s in group_sizes)
@@ -781,6 +791,7 @@ def __init__(
         self.similarity_kind = similarity_kind
         self.similarities_cfg = similarities_cfg
         self.intra_group_loss_score_fn = intra_group_loss_score_fn
+        self.similarity_gradient_bypass = similarity_gradient_bypass
 
         ensemble_shape = []
         _dim_in_ensemble = -1
@@ -956,17 +967,15 @@ def forward(
 
         # Mirrortree portion of the loss, with shortcut for hard permutations
         if mode == "soft":
-            if x_perm_hard is not None:
+            if self.similarity_gradient_bypass:
                 x_perm = (x_perm_hard - x_perm).detach() + x_perm
             similarities_x = self.similarities(x_perm)
         else:
             similarities_x = apply_hard_permutation_batch_to_similarity(
                 x=self._similarities_hard_x, perms=perms
             )
-        # Ensure comparisons are only hard-hard or soft-soft
-        loss_mt = self.intra_group_loss(
-            similarities_x, getattr(self, f"_similarities_hard_y")
-        )
+
+        loss_mt = self.intra_group_loss(similarities_x, self._similarities_hard_y)
 
         return {
             "perms": perms,
@@ -1023,7 +1032,6 @@ def _fit(
         optimizer_name: Optional[str] = "SGD",
         optimizer_kwargs: Optional[dict[str, Any]] = None,
         mean_centering: bool = True,
-        similarity_gradient_bypass: bool = False,
         show_pbar: bool = True,
         compute_final_soft: bool = True,
     ) -> DiffPASSResults:
@@ -1051,7 +1059,7 @@ def _fit(
             epoch_results = self(x, y)
             loss_info = epoch_results["loss_info"]
             loss_mt = epoch_results["loss_mt"]
-            if similarity_gradient_bypass:
+            if self.similarity_gradient_bypass:
                 x_perm_hard = epoch_results["x_perm"]
             perms = epoch_results["perms"]
             results.log_alphas.append(
@@ -1103,7 +1111,6 @@ def fit(
         optimizer_name: Optional[str] = "SGD",
         optimizer_kwargs: Optional[dict[str, Any]] = None,
         mean_centering: bool = True,
-        similarity_gradient_bypass: bool = False,
         show_pbar: bool = True,
         compute_final_soft: bool = True,
     ) -> DiffPASSResults:
@@ -1117,7 +1124,6 @@ def fit(
             optimizer_name=optimizer_name,
             optimizer_kwargs=optimizer_kwargs,
             mean_centering=mean_centering,
-            similarity_gradient_bypass=similarity_gradient_bypass,
             show_pbar=show_pbar,
             compute_final_soft=compute_final_soft,
         )
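The nbs/train.ipynb hunks that follow replay the same edits in the notebook sources (the notebook is the nbdev source for diffpass/train.py). Both `forward` methods above rely on the identity `x_perm = (x_perm_hard - x_perm).detach() + x_perm`: the forward pass evaluates the hard permutation while gradients flow through the soft one. A minimal, self-contained sketch of the trick in isolation (illustrative tensors, not DiffPASS code):

    import torch

    soft = torch.tensor([0.25, 0.5, 0.25], requires_grad=True)  # relaxed choice
    hard = torch.tensor([0.0, 1.0, 0.0])                        # hard (discrete) choice

    # Forward value equals `hard`; the detached difference carries no gradient,
    # so backward differentiates through `soft` only.
    bypassed = (hard - soft).detach() + soft
    assert torch.equal(bypassed, hard)

    bypassed.sum().backward()
    assert torch.equal(soft.grad, torch.ones(3))  # gradients as if `soft` were used
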
diff --git a/nbs/train.ipynb b/nbs/train.ipynb
index 6c0f268..45185d3 100644
--- a/nbs/train.ipynb
+++ b/nbs/train.ipynb
@@ -444,6 +444,8 @@
     "        similarities_cfg: Optional[dict[str, Any]] = None,\n",
     "        best_hits_cfg: Optional[dict[str, Any]] = None,\n",
     "        inter_group_loss_score_fn: Optional[callable] = None,\n",
+    "        similarity_gradient_bypass: bool = False,\n",
+    "        compare_soft_best_hits_to_hard: bool = True,\n",
     "    ):\n",
     "        super().__init__()\n",
     "        self.group_sizes = tuple(s for s in group_sizes)\n",
@@ -455,6 +457,8 @@
     "        self.similarities_cfg = similarities_cfg\n",
     "        self.best_hits_cfg = best_hits_cfg\n",
     "        self.inter_group_loss_score_fn = inter_group_loss_score_fn\n",
+    "        self.similarity_gradient_bypass = similarity_gradient_bypass\n",
+    "        self.compare_soft_best_hits_to_hard = compare_soft_best_hits_to_hard\n",
     "\n",
     "        ensemble_shape = []\n",
     "        _dim_in_ensemble = -1\n",
@@ -641,6 +645,12 @@
     "            similarities_y = self.best_hits.prepare_fixed(similarities_y)\n",
     "        self.register_buffer(\"_bh_soft_y\", self.best_hits(similarities_y))\n",
     "\n",
+    "    @property\n",
+    "    def _bh_y_for_soft_x(self):\n",
+    "        if self.compare_soft_best_hits_to_hard:\n",
+    "            return self._bh_hard_y\n",
+    "        return self._bh_soft_y\n",
+    "\n",
     "    def forward(\n",
     "        self,\n",
     "        x: torch.Tensor,\n",
@@ -660,16 +670,18 @@
     "\n",
     "        # Best hits portion of the loss, with shortcut for hard permutations\n",
     "        if mode == \"soft\":\n",
-    "            if x_perm_hard is not None:\n",
+    "            if self.similarity_gradient_bypass:\n",
     "                x_perm = (x_perm_hard - x_perm).detach() + x_perm\n",
     "            similarities_x = self.similarities(x_perm)\n",
     "            bh_x = self.best_hits(similarities_x)\n",
+    "            # Ensure comparisons are soft_x-{soft,hard}_y, depending on\n",
+    "            # self.compare_soft_best_hits_to_hard\n",
+    "            loss_bh = self.inter_group_loss(bh_x, self._bh_y_for_soft_x)\n",
     "        else:\n",
     "            bh_x = apply_hard_permutation_batch_to_similarity(\n",
     "                x=self._bh_hard_x, perms=perms\n",
     "            )\n",
-    "        # Ensure comparisons are only hard-hard or soft-soft\n",
-    "        loss_bh = self.inter_group_loss(bh_x, getattr(self, f\"_bh_{mode}_y\"))\n",
+    "            loss_bh = self.inter_group_loss(bh_x, self._bh_hard_y)\n",
     "\n",
     "        return {\n",
     "            \"perms\": perms,\n",
@@ -723,7 +735,7 @@
     "            self.information_measure\n",
     "        ] = results.hard_losses_identity_perm[self.information_measure]\n",
     "        results.soft_losses_identity_perm[\"BestHits\"] = self.inter_group_loss(\n",
-    "            self._bh_soft_x, self._bh_soft_y\n",
+    "            self._bh_soft_x, self._bh_y_for_soft_x\n",
     "        ).item()\n",
     "\n",
     "        return results\n",
@@ -738,7 +750,6 @@
     "        optimizer_name: Optional[str] = \"SGD\",\n",
     "        optimizer_kwargs: Optional[dict[str, Any]] = None,\n",
     "        mean_centering: bool = True,\n",
-    "        similarity_gradient_bypass: bool = False,\n",
     "        show_pbar: bool = True,\n",
     "        compute_final_soft: bool = True,\n",
     "    ) -> DiffPASSResults:\n",
@@ -766,7 +777,7 @@
     "            epoch_results = self(x, y)\n",
     "            loss_info = epoch_results[\"loss_info\"]\n",
     "            loss_bh = epoch_results[\"loss_bh\"]\n",
-    "            if similarity_gradient_bypass:\n",
+    "            if self.similarity_gradient_bypass:\n",
     "                x_perm_hard = epoch_results[\"x_perm\"]\n",
     "            perms = epoch_results[\"perms\"]\n",
     "            results.log_alphas.append(\n",
@@ -818,7 +829,6 @@
     "        optimizer_name: Optional[str] = \"SGD\",\n",
     "        optimizer_kwargs: Optional[dict[str, Any]] = None,\n",
     "        mean_centering: bool = True,\n",
-    "        similarity_gradient_bypass: bool = False,\n",
     "        show_pbar: bool = True,\n",
     "        compute_final_soft: bool = True,\n",
     "    ) -> DiffPASSResults:\n",
@@ -832,7 +842,6 @@
     "            optimizer_name=optimizer_name,\n",
     "            optimizer_kwargs=optimizer_kwargs,\n",
     "            mean_centering=mean_centering,\n",
-    "            similarity_gradient_bypass=similarity_gradient_bypass,\n",
     "            show_pbar=show_pbar,\n",
     "            compute_final_soft=compute_final_soft,\n",
     "        )\n",
@@ -859,6 +868,7 @@
     "        similarity_kind: Literal[\"Hamming\", \"Blosum62\"] = \"Hamming\",\n",
     "        similarities_cfg: Optional[dict[str, Any]] = None,\n",
     "        intra_group_loss_score_fn: Optional[callable] = None,\n",
+    "        similarity_gradient_bypass: bool = False,\n",
     "    ):\n",
     "        super().__init__()\n",
     "        self.group_sizes = tuple(s for s in group_sizes)\n",
@@ -869,6 +879,7 @@
     "        self.similarity_kind = similarity_kind\n",
     "        self.similarities_cfg = similarities_cfg\n",
     "        self.intra_group_loss_score_fn = intra_group_loss_score_fn\n",
+    "        self.similarity_gradient_bypass = similarity_gradient_bypass\n",
     "\n",
     "        ensemble_shape = []\n",
     "        _dim_in_ensemble = -1\n",
@@ -1044,17 +1055,15 @@
     "\n",
     "        # Mirrortree portion of the loss, with shortcut for hard permutations\n",
     "        if mode == \"soft\":\n",
-    "            if x_perm_hard is not None:\n",
+    "            if self.similarity_gradient_bypass:\n",
     "                x_perm = (x_perm_hard - x_perm).detach() + x_perm\n",
     "            similarities_x = self.similarities(x_perm)\n",
     "        else:\n",
     "            similarities_x = apply_hard_permutation_batch_to_similarity(\n",
     "                x=self._similarities_hard_x, perms=perms\n",
     "            )\n",
-    "        # Ensure comparisons are only hard-hard or soft-soft\n",
-    "        loss_mt = self.intra_group_loss(\n",
-    "            similarities_x, getattr(self, f\"_similarities_hard_y\")\n",
-    "        )\n",
+    "\n",
+    "        loss_mt = self.intra_group_loss(similarities_x, self._similarities_hard_y)\n",
     "\n",
     "        return {\n",
     "            \"perms\": perms,\n",
@@ -1111,7 +1120,6 @@
     "        optimizer_name: Optional[str] = \"SGD\",\n",
     "        optimizer_kwargs: Optional[dict[str, Any]] = None,\n",
     "        mean_centering: bool = True,\n",
-    "        similarity_gradient_bypass: bool = False,\n",
     "        show_pbar: bool = True,\n",
     "        compute_final_soft: bool = True,\n",
     "    ) -> DiffPASSResults:\n",
@@ -1139,7 +1147,7 @@
     "            epoch_results = self(x, y)\n",
     "            loss_info = epoch_results[\"loss_info\"]\n",
     "            loss_mt = epoch_results[\"loss_mt\"]\n",
-    "            if similarity_gradient_bypass:\n",
+    "            if self.similarity_gradient_bypass:\n",
     "                x_perm_hard = epoch_results[\"x_perm\"]\n",
     "            perms = epoch_results[\"perms\"]\n",
     "            results.log_alphas.append(\n",
@@ -1191,7 +1199,6 @@
     "        optimizer_name: Optional[str] = \"SGD\",\n",
     "        optimizer_kwargs: Optional[dict[str, Any]] = None,\n",
     "        mean_centering: bool = True,\n",
-    "        similarity_gradient_bypass: bool = False,\n",
     "        show_pbar: bool = True,\n",
     "        compute_final_soft: bool = True,\n",
     "    ) -> DiffPASSResults:\n",
@@ -1205,7 +1212,6 @@
     "            optimizer_name=optimizer_name,\n",
     "            optimizer_kwargs=optimizer_kwargs,\n",
     "            mean_centering=mean_centering,\n",
-    "            similarity_gradient_bypass=similarity_gradient_bypass,\n",
     "            show_pbar=show_pbar,\n",
     "            compute_final_soft=compute_final_soft,\n",
     "        )\n",
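
For reference, a hypothetical usage sketch of the interface after this patch. The class, module path and parameter names come from the diff itself, but the group sizes, tensor shapes and values below are placeholders, not documented inputs:

    import torch
    from diffpass.train import InformationAndBestHits

    # Both new options are now fixed at construction time (placeholder group sizes):
    model = InformationAndBestHits(
        group_sizes=(4, 4),
        similarity_gradient_bypass=True,       # enable the straight-through trick
        compare_soft_best_hits_to_hard=False,  # opt back into soft-soft comparisons
    )

    # Placeholder tensors standing in for the sequence encodings the model expects:
    x = torch.rand(8, 16)
    y = torch.rand(8, 16)

    # `fit` no longer accepts `similarity_gradient_bypass`:
    results = model.fit(x, y, show_pbar=False)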