diff --git a/diffpass/_modidx.py b/diffpass/_modidx.py
index 1fd7a78..797d197 100644
--- a/diffpass/_modidx.py
+++ b/diffpass/_modidx.py
@@ -6,6 +6,8 @@
                 'git_url': 'https://github.com/Bitbol-Lab/DiffPASS',
                 'lib_path': 'diffpass'},
  'syms': { 'diffpass.base': { 'diffpass.base.DiffPASSMixin': ('base.html#diffpassmixin', 'diffpass/base.py'),
+                              'diffpass.base.DiffPASSMixin.check_can_optimize': ( 'base.html#diffpassmixin.check_can_optimize',
+                                                                                  'diffpass/base.py'),
                              'diffpass.base.DiffPASSMixin.reduce_num_tokens': ( 'base.html#diffpassmixin.reduce_num_tokens',
                                                                                 'diffpass/base.py'),
                              'diffpass.base.DiffPASSMixin.validate_best_hits_cfg': ( 'base.html#diffpassmixin.validate_best_hits_cfg',
diff --git a/diffpass/base.py b/diffpass/base.py
index 9859f60..578b75d 100644
--- a/diffpass/base.py
+++ b/diffpass/base.py
@@ -44,7 +44,8 @@ def reduce_num_tokens(x: torch.Tensor) -> torch.Tensor:
     def validate_permutation_cfg(self, permutation_cfg: dict) -> None:
         if not set(permutation_cfg).issubset(self.allowed_permutation_cfg_keys):
             raise ValueError(
-                f"Invalid keys in `permutation_cfg`: {set(permutation_cfg) - self.allowed_permutation_cfg_keys}"
+                f"Invalid keys in `permutation_cfg`: "
+                f"{set(permutation_cfg) - self.allowed_permutation_cfg_keys}"
             )
 
     def validate_information_measure(self, information_measure: str) -> None:
@@ -66,13 +67,15 @@ def validate_similarities_cfg(self, similarities_cfg: dict) -> None:
             self.allowed_similarities_cfg_keys[self.similarity_kind]
         ):
             raise ValueError(
-                f"Invalid keys in `similarities_cfg`: {set(similarities_cfg) - self.allowed_similarities_cfg_keys[self.similarity_kind]}"
+                f"Invalid keys in `similarities_cfg`: "
+                f"{set(similarities_cfg) - self.allowed_similarities_cfg_keys[self.similarity_kind]}"
             )
 
     def validate_best_hits_cfg(self, best_hits_cfg: dict) -> None:
         if not set(best_hits_cfg).issubset(self.allowed_best_hits_cfg_keys):
             raise ValueError(
-                f"Invalid keys in `best_hits_cfg`: {set(best_hits_cfg) - self.allowed_best_hits_cfg_keys}"
+                f"Invalid keys in `best_hits_cfg`: "
+                f"{set(best_hits_cfg) - self.allowed_best_hits_cfg_keys}"
             )
 
     def validate_inputs(
@@ -94,13 +97,26 @@ def validate_inputs(
                 f"size of {total_size}."
             )
 
+    @staticmethod
+    def check_can_optimize(n_effectively_fixed: int, n_available: int) -> None:
+        if n_effectively_fixed == n_available:
+            raise ValueError(
+                "The number of effectively fixed matchings is equal to the number "
+                "of sequences. No optimization can be performed."
+            )
+        elif n_effectively_fixed > n_available:
+            raise ValueError(
+                "The number of effectively fixed matchings is greater than the number "
+                "of available sequences. Check your inputs."
+            )
+
 # %% ../nbs/base.ipynb 4
 def scalar_or_1d_tensor(
     *, param: Any, param_name: str, dtype: torch.dtype = torch.float32
 ) -> torch.Tensor:
-    if not isinstance(param, (float, torch.Tensor)):
-        raise TypeError(f"`{param_name}` must be a float or a torch.Tensor.")
-    if isinstance(param, float):
+    if not isinstance(param, (int, float, torch.Tensor)):
+        raise TypeError(f"`{param_name}` must be a scalar or a torch.Tensor.")
+    if not isinstance(param, torch.Tensor):
         param = torch.tensor(param, dtype=dtype)
     elif param.ndim > 1:
         raise ValueError(
diff --git a/diffpass/model.py b/diffpass/model.py
index 575f25f..f7b7ebb 100644
--- a/diffpass/model.py
+++ b/diffpass/model.py
@@ -88,7 +88,10 @@ def __init__(
         )
         self.log_alphas = ParameterList(
             [
-                Parameter(torch.zeros(*self.ensemble_shape, s, s, dtype=torch.float32))
+                Parameter(
+                    torch.zeros(*self.ensemble_shape, s, s, dtype=torch.float32),
+                    requires_grad=bool(s),
+                )
                 for s in self.nonfixed_group_sizes_
             ]
         )
@@ -133,7 +136,7 @@ def _validate_fixed_matchings(
                     fm_zip = list(zip(*fm))
                 else:
                     num_fm = 0
-                    fm_zip = ((), ())
+                    fm_zip = [(), ()]
                 complement = s - num_fm  # Effectively fully fixed when complement <= 1
                 is_fully_fixed = complement <= 1
                 num_efm = s - (s - num_fm) * (not is_fully_fixed)
@@ -151,6 +154,13 @@
                     mask[..., :, i] = False
                 self.register_buffer(f"_not_fixed_masks_{idx}", mask)
                 self._effective_fixed_matchings_zip.append(fm_zip)
+            self._total_number_fixed_matchings = sum(
+                self._effective_number_fixed_matchings
+            )
+        else:
+            self._effective_fixed_matchings_zip = [[(), ()] for _ in self.group_sizes]
+            self._effective_number_fixed_matchings = [0] * len(self.group_sizes)
+            self._total_number_fixed_matchings = 0
 
     @property
     def _not_fixed_masks(self) -> list[torch.Tensor]:
diff --git a/diffpass/train.py b/diffpass/train.py
index 461037d..2d51379 100644
--- a/diffpass/train.py
+++ b/diffpass/train.py
@@ -289,6 +289,7 @@ def fit(
         compute_final_soft: bool = True,
     ) -> DiffPASSResults:
         results = self._prepare_fit(x, y)
+        self.check_can_optimize(self.permutation._total_number_fixed_matchings, len(x))
         results = self._fit(
             x,
             y,
@@ -695,6 +696,7 @@ def fit(
         compute_final_soft: bool = True,
     ) -> DiffPASSResults:
         results = self._prepare_fit(x, y)
+        self.check_can_optimize(self.permutation._total_number_fixed_matchings, len(x))
         results = self._fit(
             x,
             y,
@@ -1059,6 +1061,7 @@ def fit(
         compute_final_soft: bool = True,
     ) -> DiffPASSResults:
         results = self._prepare_fit(x, y)
+        self.check_can_optimize(self.permutation._total_number_fixed_matchings, len(x))
         results = self._fit(
             x,
             y,
diff --git a/nbs/base.ipynb b/nbs/base.ipynb
index 1e862aa..7ad1719 100644
--- a/nbs/base.ipynb
+++ b/nbs/base.ipynb
@@ -74,7 +74,8 @@
     "    def validate_permutation_cfg(self, permutation_cfg: dict) -> None:\n",
     "        if not set(permutation_cfg).issubset(self.allowed_permutation_cfg_keys):\n",
     "            raise ValueError(\n",
-    "                f\"Invalid keys in `permutation_cfg`: {set(permutation_cfg) - self.allowed_permutation_cfg_keys}\"\n",
+    "                f\"Invalid keys in `permutation_cfg`: \"\n",
+    "                f\"{set(permutation_cfg) - self.allowed_permutation_cfg_keys}\"\n",
     "            )\n",
     "\n",
     "    def validate_information_measure(self, information_measure: str) -> None:\n",
@@ -96,13 +97,15 @@
     "            self.allowed_similarities_cfg_keys[self.similarity_kind]\n",
     "        ):\n",
     "            raise ValueError(\n",
-    "                f\"Invalid keys in `similarities_cfg`: {set(similarities_cfg) - self.allowed_similarities_cfg_keys[self.similarity_kind]}\"\n",
+    "                f\"Invalid keys in `similarities_cfg`: \"\n",
+    "                f\"{set(similarities_cfg) - self.allowed_similarities_cfg_keys[self.similarity_kind]}\"\n",
    "            )\n",
     "\n",
     "    def validate_best_hits_cfg(self, best_hits_cfg: dict) -> None:\n",
     "        if not set(best_hits_cfg).issubset(self.allowed_best_hits_cfg_keys):\n",
     "            raise ValueError(\n",
-    "                f\"Invalid keys in `best_hits_cfg`: {set(best_hits_cfg) - self.allowed_best_hits_cfg_keys}\"\n",
+    "                f\"Invalid keys in `best_hits_cfg`: \"\n",
+    "                f\"{set(best_hits_cfg) - self.allowed_best_hits_cfg_keys}\"\n",
     "            )\n",
     "\n",
     "    def validate_inputs(\n",
@@ -122,6 +125,19 @@
     "            raise ValueError(\n",
     "                f\"Inputs have size {total_size} but `group_sizes` implies a total \"\n",
     "                f\"size of {total_size}.\"\n",
+    "            )\n",
+    "\n",
+    "    @staticmethod\n",
+    "    def check_can_optimize(n_effectively_fixed: int, n_available: int) -> None:\n",
+    "        if n_effectively_fixed == n_available:\n",
+    "            raise ValueError(\n",
+    "                \"The number of effectively fixed matchings is equal to the number \"\n",
+    "                \"of sequences. No optimization can be performed.\"\n",
+    "            )\n",
+    "        elif n_effectively_fixed > n_available:\n",
+    "            raise ValueError(\n",
+    "                \"The number of effectively fixed matchings is greater than the number \"\n",
+    "                \"of available sequences. Check your inputs.\"\n",
     "            )"
    ]
   },
@@ -136,9 +152,9 @@
     "def scalar_or_1d_tensor(\n",
     "    *, param: Any, param_name: str, dtype: torch.dtype = torch.float32\n",
     ") -> torch.Tensor:\n",
-    "    if not isinstance(param, (float, torch.Tensor)):\n",
-    "        raise TypeError(f\"`{param_name}` must be a float or a torch.Tensor.\")\n",
-    "    if isinstance(param, float):\n",
+    "    if not isinstance(param, (int, float, torch.Tensor)):\n",
+    "        raise TypeError(f\"`{param_name}` must be a scalar or a torch.Tensor.\")\n",
+    "    if not isinstance(param, torch.Tensor):\n",
     "        param = torch.tensor(param, dtype=dtype)\n",
     "    elif param.ndim > 1:\n",
     "        raise ValueError(\n",
diff --git a/nbs/model.ipynb b/nbs/model.ipynb
index 4617579..2f8465d 100644
--- a/nbs/model.ipynb
+++ b/nbs/model.ipynb
@@ -143,7 +143,10 @@
     "        )\n",
     "        self.log_alphas = ParameterList(\n",
     "            [\n",
-    "                Parameter(torch.zeros(*self.ensemble_shape, s, s, dtype=torch.float32))\n",
+    "                Parameter(\n",
+    "                    torch.zeros(*self.ensemble_shape, s, s, dtype=torch.float32),\n",
+    "                    requires_grad=bool(s),\n",
+    "                )\n",
     "                for s in self.nonfixed_group_sizes_\n",
     "            ]\n",
     "        )\n",
@@ -188,7 +191,7 @@
     "                    fm_zip = list(zip(*fm))\n",
     "                else:\n",
     "                    num_fm = 0\n",
-    "                    fm_zip = ((), ())\n",
+    "                    fm_zip = [(), ()]\n",
     "                complement = s - num_fm  # Effectively fully fixed when complement <= 1\n",
     "                is_fully_fixed = complement <= 1\n",
     "                num_efm = s - (s - num_fm) * (not is_fully_fixed)\n",
@@ -206,6 +209,13 @@
     "                    mask[..., :, i] = False\n",
     "                self.register_buffer(f\"_not_fixed_masks_{idx}\", mask)\n",
     "                self._effective_fixed_matchings_zip.append(fm_zip)\n",
+    "            self._total_number_fixed_matchings = sum(\n",
+    "                self._effective_number_fixed_matchings\n",
+    "            )\n",
+    "        else:\n",
+    "            self._effective_fixed_matchings_zip = [[(), ()] for _ in self.group_sizes]\n",
+    "            self._effective_number_fixed_matchings = [0] * len(self.group_sizes)\n",
+    "            self._total_number_fixed_matchings = 0\n",
     "\n",
     "    @property\n",
     "    def _not_fixed_masks(self) -> list[torch.Tensor]:\n",
diff --git a/nbs/train.ipynb b/nbs/train.ipynb
index 7ad86cd..3842a0b 100644
--- a/nbs/train.ipynb
+++ b/nbs/train.ipynb
@@ -361,6 +361,7 @@
     "        compute_final_soft: bool = True,\n",
     "    ) -> DiffPASSResults:\n",
     "        results = self._prepare_fit(x, y)\n",
+    "        self.check_can_optimize(self.permutation._total_number_fixed_matchings, len(x))\n",
     "        results = self._fit(\n",
     "            x,\n",
     "            y,\n",
@@ -775,6 +776,7 @@
     "        compute_final_soft: bool = True,\n",
     "    ) -> DiffPASSResults:\n",
     "        results = self._prepare_fit(x, y)\n",
+    "        self.check_can_optimize(self.permutation._total_number_fixed_matchings, len(x))\n",
    "        results = self._fit(\n",
     "            x,\n",
     "            y,\n",
@@ -1147,6 +1149,7 @@
     "        compute_final_soft: bool = True,\n",
     "    ) -> DiffPASSResults:\n",
     "        results = self._prepare_fit(x, y)\n",
+    "        self.check_can_optimize(self.permutation._total_number_fixed_matchings, len(x))\n",
     "        results = self._fit(\n",
     "            x,\n",
     "            y,\n",