Make similarity_gradient_bypass init parameter of InformationAndBestHits and add compare_soft_best_hits_to_hard parameter
ulupo committed Mar 7, 2024
1 parent 2395bc9 commit 0184260
Showing 3 changed files with 48 additions and 34 deletions.
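In short: `similarity_gradient_bypass` moves from a `fit()`/`_fit()` keyword argument to a constructor argument on both model classes touched here, and `InformationAndBestHits` additionally gains a `compare_soft_best_hits_to_hard` flag choosing whether soft best hits of `x` are scored against hard or soft best hits of `y`. A minimal sketch of the resulting call pattern; the tensor shapes, the `group_sizes` layout, and `fit` taking the two aligned inputs `x` and `y` are assumptions for illustration, while the keyword arguments shown do appear in this diff:

```python
import torch
from diffpass.train import InformationAndBestHits

# Hypothetical paired inputs; shapes and encoding are illustrative only.
x = torch.rand(60, 100, 20)  # e.g. 60 sequences from family A
y = torch.rand(60, 120, 20)  # e.g. 60 sequences from family B

model = InformationAndBestHits(
    group_sizes=(25, 35),                 # assumed grouping of the 60 sequences
    similarity_gradient_bypass=True,      # was a fit()/_fit() kwarg before this commit
    compare_soft_best_hits_to_hard=True,  # new: soft x best hits vs hard y best hits
)

# fit() no longer accepts similarity_gradient_bypass.
results = model.fit(x, y, optimizer_name="SGD", show_pbar=True)
```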
2 changes: 2 additions & 0 deletions diffpass/_modidx.py
@@ -171,6 +171,8 @@
'diffpass/train.py'),
'diffpass.train.InformationAndBestHits._adjust_loss_weights_and_ensemble_shape': ( 'train.html#informationandbesthits._adjust_loss_weights_and_ensemble_shape',
'diffpass/train.py'),
+ 'diffpass.train.InformationAndBestHits._bh_y_for_soft_x': ( 'train.html#informationandbesthits._bh_y_for_soft_x',
+ 'diffpass/train.py'),
'diffpass.train.InformationAndBestHits._fit': ( 'train.html#informationandbesthits._fit',
'diffpass/train.py'),
'diffpass.train.InformationAndBestHits._precompute_bh': ( 'train.html#informationandbesthits._precompute_bh',
40 changes: 23 additions & 17 deletions diffpass/train.py
@@ -364,6 +364,8 @@ def __init__(
similarities_cfg: Optional[dict[str, Any]] = None,
best_hits_cfg: Optional[dict[str, Any]] = None,
inter_group_loss_score_fn: Optional[callable] = None,
+ similarity_gradient_bypass: bool = False,
+ compare_soft_best_hits_to_hard: bool = True,
):
super().__init__()
self.group_sizes = tuple(s for s in group_sizes)
@@ -375,6 +377,8 @@
self.similarities_cfg = similarities_cfg
self.best_hits_cfg = best_hits_cfg
self.inter_group_loss_score_fn = inter_group_loss_score_fn
+ self.similarity_gradient_bypass = similarity_gradient_bypass
+ self.compare_soft_best_hits_to_hard = compare_soft_best_hits_to_hard

ensemble_shape = []
_dim_in_ensemble = -1
@@ -561,6 +565,12 @@ def _precompute_bh(self, x: torch.Tensor, y: torch.Tensor) -> None:
similarities_y = self.best_hits.prepare_fixed(similarities_y)
self.register_buffer("_bh_soft_y", self.best_hits(similarities_y))

+ @property
+ def _bh_y_for_soft_x(self):
+ if self.compare_soft_best_hits_to_hard:
+ return self._bh_hard_y
+ return self._bh_soft_y
+
def forward(
self,
x: torch.Tensor,
@@ -580,16 +590,18 @@ def forward(

# Best hits portion of the loss, with shortcut for hard permutations
if mode == "soft":
- if x_perm_hard is not None:
+ if self.similarity_gradient_bypass:
x_perm = (x_perm_hard - x_perm).detach() + x_perm
similarities_x = self.similarities(x_perm)
bh_x = self.best_hits(similarities_x)
+ # Ensure comparisons are soft_x-{soft,hard}_y, depending on
+ # self.compare_soft_best_hits_to_hard
+ loss_bh = self.inter_group_loss(bh_x, self._bh_y_for_soft_x)
else:
bh_x = apply_hard_permutation_batch_to_similarity(
x=self._bh_hard_x, perms=perms
)
- # Ensure comparisons are only hard-hard or soft-soft
- loss_bh = self.inter_group_loss(bh_x, getattr(self, f"_bh_{mode}_y"))
+ loss_bh = self.inter_group_loss(bh_x, self._bh_hard_y)

return {
"perms": perms,
@@ -643,7 +655,7 @@ def _prepare_fit(self, x: torch.Tensor, y: torch.Tensor) -> DiffPASSResults:
self.information_measure
] = results.hard_losses_identity_perm[self.information_measure]
results.soft_losses_identity_perm["BestHits"] = self.inter_group_loss(
- self._bh_soft_x, self._bh_soft_y
+ self._bh_soft_x, self._bh_y_for_soft_x
).item()

return results
@@ -658,7 +670,6 @@ def _fit(
optimizer_name: Optional[str] = "SGD",
optimizer_kwargs: Optional[dict[str, Any]] = None,
mean_centering: bool = True,
- similarity_gradient_bypass: bool = False,
show_pbar: bool = True,
compute_final_soft: bool = True,
) -> DiffPASSResults:
@@ -686,7 +697,7 @@
epoch_results = self(x, y)
loss_info = epoch_results["loss_info"]
loss_bh = epoch_results["loss_bh"]
- if similarity_gradient_bypass:
+ if self.similarity_gradient_bypass:
x_perm_hard = epoch_results["x_perm"]
perms = epoch_results["perms"]
results.log_alphas.append(
@@ -738,7 +749,6 @@ def fit(
optimizer_name: Optional[str] = "SGD",
optimizer_kwargs: Optional[dict[str, Any]] = None,
mean_centering: bool = True,
- similarity_gradient_bypass: bool = False,
show_pbar: bool = True,
compute_final_soft: bool = True,
) -> DiffPASSResults:
@@ -752,7 +762,6 @@
optimizer_name=optimizer_name,
optimizer_kwargs=optimizer_kwargs,
mean_centering=mean_centering,
- similarity_gradient_bypass=similarity_gradient_bypass,
show_pbar=show_pbar,
compute_final_soft=compute_final_soft,
)
@@ -771,6 +780,7 @@ def __init__(
similarity_kind: Literal["Hamming", "Blosum62"] = "Hamming",
similarities_cfg: Optional[dict[str, Any]] = None,
intra_group_loss_score_fn: Optional[callable] = None,
+ similarity_gradient_bypass: bool = False,
):
super().__init__()
self.group_sizes = tuple(s for s in group_sizes)
@@ -781,6 +791,7 @@
self.similarity_kind = similarity_kind
self.similarities_cfg = similarities_cfg
self.intra_group_loss_score_fn = intra_group_loss_score_fn
+ self.similarity_gradient_bypass = similarity_gradient_bypass

ensemble_shape = []
_dim_in_ensemble = -1
@@ -956,17 +967,15 @@ def forward(

# Mirrortree portion of the loss, with shortcut for hard permutations
if mode == "soft":
- if x_perm_hard is not None:
+ if self.similarity_gradient_bypass:
x_perm = (x_perm_hard - x_perm).detach() + x_perm
similarities_x = self.similarities(x_perm)
else:
similarities_x = apply_hard_permutation_batch_to_similarity(
x=self._similarities_hard_x, perms=perms
)
- # Ensure comparisons are only hard-hard or soft-soft
- loss_mt = self.intra_group_loss(
- similarities_x, getattr(self, f"_similarities_hard_y")
- )
+
+ loss_mt = self.intra_group_loss(similarities_x, self._similarities_hard_y)

return {
"perms": perms,
@@ -1023,7 +1032,6 @@ def _fit(
optimizer_name: Optional[str] = "SGD",
optimizer_kwargs: Optional[dict[str, Any]] = None,
mean_centering: bool = True,
- similarity_gradient_bypass: bool = False,
show_pbar: bool = True,
compute_final_soft: bool = True,
) -> DiffPASSResults:
@@ -1051,7 +1059,7 @@
epoch_results = self(x, y)
loss_info = epoch_results["loss_info"]
loss_mt = epoch_results["loss_mt"]
- if similarity_gradient_bypass:
+ if self.similarity_gradient_bypass:
x_perm_hard = epoch_results["x_perm"]
perms = epoch_results["perms"]
results.log_alphas.append(
@@ -1103,7 +1111,6 @@ def fit(
optimizer_name: Optional[str] = "SGD",
optimizer_kwargs: Optional[dict[str, Any]] = None,
mean_centering: bool = True,
- similarity_gradient_bypass: bool = False,
show_pbar: bool = True,
compute_final_soft: bool = True,
) -> DiffPASSResults:
@@ -1117,7 +1124,6 @@
optimizer_name=optimizer_name,
optimizer_kwargs=optimizer_kwargs,
mean_centering=mean_centering,
- similarity_gradient_bypass=similarity_gradient_bypass,
show_pbar=show_pbar,
compute_final_soft=compute_final_soft,
)
40 changes: 23 additions & 17 deletions nbs/train.ipynb
@@ -444,6 +444,8 @@
" similarities_cfg: Optional[dict[str, Any]] = None,\n",
" best_hits_cfg: Optional[dict[str, Any]] = None,\n",
" inter_group_loss_score_fn: Optional[callable] = None,\n",
" similarity_gradient_bypass: bool = False,\n",
" compare_soft_best_hits_to_hard: bool = True,\n",
" ):\n",
" super().__init__()\n",
" self.group_sizes = tuple(s for s in group_sizes)\n",
@@ -455,6 +457,8 @@
" self.similarities_cfg = similarities_cfg\n",
" self.best_hits_cfg = best_hits_cfg\n",
" self.inter_group_loss_score_fn = inter_group_loss_score_fn\n",
" self.similarity_gradient_bypass = similarity_gradient_bypass\n",
" self.compare_soft_best_hits_to_hard = compare_soft_best_hits_to_hard\n",
"\n",
" ensemble_shape = []\n",
" _dim_in_ensemble = -1\n",
@@ -641,6 +645,12 @@
" similarities_y = self.best_hits.prepare_fixed(similarities_y)\n",
" self.register_buffer(\"_bh_soft_y\", self.best_hits(similarities_y))\n",
"\n",
" @property\n",
" def _bh_y_for_soft_x(self):\n",
" if self.compare_soft_best_hits_to_hard:\n",
" return self._bh_hard_y\n",
" return self._bh_soft_y\n",
"\n",
" def forward(\n",
" self,\n",
" x: torch.Tensor,\n",
@@ -660,16 +670,18 @@
"\n",
" # Best hits portion of the loss, with shortcut for hard permutations\n",
" if mode == \"soft\":\n",
" if x_perm_hard is not None:\n",
" if self.similarity_gradient_bypass:\n",
" x_perm = (x_perm_hard - x_perm).detach() + x_perm\n",
" similarities_x = self.similarities(x_perm)\n",
" bh_x = self.best_hits(similarities_x)\n",
" # Ensure comparisons are soft_x-{soft,hard}_y, depending on\n",
" # self.compare_soft_best_hits_to_hard\n",
" loss_bh = self.inter_group_loss(bh_x, self._bh_y_for_soft_x)\n",
" else:\n",
" bh_x = apply_hard_permutation_batch_to_similarity(\n",
" x=self._bh_hard_x, perms=perms\n",
" )\n",
" # Ensure comparisons are only hard-hard or soft-soft\n",
" loss_bh = self.inter_group_loss(bh_x, getattr(self, f\"_bh_{mode}_y\"))\n",
" loss_bh = self.inter_group_loss(bh_x, self._bh_hard_y)\n",
"\n",
" return {\n",
" \"perms\": perms,\n",
@@ -723,7 +735,7 @@
" self.information_measure\n",
" ] = results.hard_losses_identity_perm[self.information_measure]\n",
" results.soft_losses_identity_perm[\"BestHits\"] = self.inter_group_loss(\n",
" self._bh_soft_x, self._bh_soft_y\n",
" self._bh_soft_x, self._bh_y_for_soft_x\n",
" ).item()\n",
"\n",
" return results\n",
@@ -738,7 +750,6 @@
" optimizer_name: Optional[str] = \"SGD\",\n",
" optimizer_kwargs: Optional[dict[str, Any]] = None,\n",
" mean_centering: bool = True,\n",
" similarity_gradient_bypass: bool = False,\n",
" show_pbar: bool = True,\n",
" compute_final_soft: bool = True,\n",
" ) -> DiffPASSResults:\n",
@@ -766,7 +777,7 @@
" epoch_results = self(x, y)\n",
" loss_info = epoch_results[\"loss_info\"]\n",
" loss_bh = epoch_results[\"loss_bh\"]\n",
" if similarity_gradient_bypass:\n",
" if self.similarity_gradient_bypass:\n",
" x_perm_hard = epoch_results[\"x_perm\"]\n",
" perms = epoch_results[\"perms\"]\n",
" results.log_alphas.append(\n",
@@ -818,7 +829,6 @@
" optimizer_name: Optional[str] = \"SGD\",\n",
" optimizer_kwargs: Optional[dict[str, Any]] = None,\n",
" mean_centering: bool = True,\n",
" similarity_gradient_bypass: bool = False,\n",
" show_pbar: bool = True,\n",
" compute_final_soft: bool = True,\n",
" ) -> DiffPASSResults:\n",
@@ -832,7 +842,6 @@
" optimizer_name=optimizer_name,\n",
" optimizer_kwargs=optimizer_kwargs,\n",
" mean_centering=mean_centering,\n",
" similarity_gradient_bypass=similarity_gradient_bypass,\n",
" show_pbar=show_pbar,\n",
" compute_final_soft=compute_final_soft,\n",
" )\n",
@@ -859,6 +868,7 @@
" similarity_kind: Literal[\"Hamming\", \"Blosum62\"] = \"Hamming\",\n",
" similarities_cfg: Optional[dict[str, Any]] = None,\n",
" intra_group_loss_score_fn: Optional[callable] = None,\n",
" similarity_gradient_bypass: bool = False,\n",
" ):\n",
" super().__init__()\n",
" self.group_sizes = tuple(s for s in group_sizes)\n",
@@ -869,6 +879,7 @@
" self.similarity_kind = similarity_kind\n",
" self.similarities_cfg = similarities_cfg\n",
" self.intra_group_loss_score_fn = intra_group_loss_score_fn\n",
" self.similarity_gradient_bypass = similarity_gradient_bypass\n",
"\n",
" ensemble_shape = []\n",
" _dim_in_ensemble = -1\n",
@@ -1044,17 +1055,15 @@
"\n",
" # Mirrortree portion of the loss, with shortcut for hard permutations\n",
" if mode == \"soft\":\n",
" if x_perm_hard is not None:\n",
" if self.similarity_gradient_bypass:\n",
" x_perm = (x_perm_hard - x_perm).detach() + x_perm\n",
" similarities_x = self.similarities(x_perm)\n",
" else:\n",
" similarities_x = apply_hard_permutation_batch_to_similarity(\n",
" x=self._similarities_hard_x, perms=perms\n",
" )\n",
" # Ensure comparisons are only hard-hard or soft-soft\n",
" loss_mt = self.intra_group_loss(\n",
" similarities_x, getattr(self, f\"_similarities_hard_y\")\n",
" )\n",
"\n",
" loss_mt = self.intra_group_loss(similarities_x, self._similarities_hard_y)\n",
"\n",
" return {\n",
" \"perms\": perms,\n",
@@ -1111,7 +1120,6 @@
" optimizer_name: Optional[str] = \"SGD\",\n",
" optimizer_kwargs: Optional[dict[str, Any]] = None,\n",
" mean_centering: bool = True,\n",
" similarity_gradient_bypass: bool = False,\n",
" show_pbar: bool = True,\n",
" compute_final_soft: bool = True,\n",
" ) -> DiffPASSResults:\n",
@@ -1139,7 +1147,7 @@
" epoch_results = self(x, y)\n",
" loss_info = epoch_results[\"loss_info\"]\n",
" loss_mt = epoch_results[\"loss_mt\"]\n",
" if similarity_gradient_bypass:\n",
" if self.similarity_gradient_bypass:\n",
" x_perm_hard = epoch_results[\"x_perm\"]\n",
" perms = epoch_results[\"perms\"]\n",
" results.log_alphas.append(\n",
@@ -1191,7 +1199,6 @@
" optimizer_name: Optional[str] = \"SGD\",\n",
" optimizer_kwargs: Optional[dict[str, Any]] = None,\n",
" mean_centering: bool = True,\n",
" similarity_gradient_bypass: bool = False,\n",
" show_pbar: bool = True,\n",
" compute_final_soft: bool = True,\n",
" ) -> DiffPASSResults:\n",
@@ -1205,7 +1212,6 @@
" optimizer_name=optimizer_name,\n",
" optimizer_kwargs=optimizer_kwargs,\n",
" mean_centering=mean_centering,\n",
" similarity_gradient_bypass=similarity_gradient_bypass,\n",
" show_pbar=show_pbar,\n",
" compute_final_soft=compute_final_soft,\n",
" )\n",
