Skip to content

Commit

Permalink
Workaround apparent PyTorch-CUDA bug with empty tensors in ParameterL…
Browse files Browse the repository at this point in the history
…ist, add exception if fit not possible
  • Loading branch information
ulupo committed Dec 20, 2023
1 parent 73cf9b7 commit fc28e3d
Show file tree
Hide file tree
Showing 7 changed files with 76 additions and 16 deletions.
2 changes: 2 additions & 0 deletions diffpass/_modidx.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
'git_url': 'https://github.com/Bitbol-Lab/DiffPASS',
'lib_path': 'diffpass'},
'syms': { 'diffpass.base': { 'diffpass.base.DiffPASSMixin': ('base.html#diffpassmixin', 'diffpass/base.py'),
'diffpass.base.DiffPASSMixin.check_can_optimize': ( 'base.html#diffpassmixin.check_can_optimize',
'diffpass/base.py'),
'diffpass.base.DiffPASSMixin.reduce_num_tokens': ( 'base.html#diffpassmixin.reduce_num_tokens',
'diffpass/base.py'),
'diffpass.base.DiffPASSMixin.validate_best_hits_cfg': ( 'base.html#diffpassmixin.validate_best_hits_cfg',
Expand Down
28 changes: 22 additions & 6 deletions diffpass/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ def reduce_num_tokens(x: torch.Tensor) -> torch.Tensor:
def validate_permutation_cfg(self, permutation_cfg: dict) -> None:
if not set(permutation_cfg).issubset(self.allowed_permutation_cfg_keys):
raise ValueError(
f"Invalid keys in `permutation_cfg`: {set(permutation_cfg) - self.allowed_permutation_cfg_keys}"
f"Invalid keys in `permutation_cfg`: "
f"{set(permutation_cfg) - self.allowed_permutation_cfg_keys}"
)

def validate_information_measure(self, information_measure: str) -> None:
Expand All @@ -66,13 +67,15 @@ def validate_similarities_cfg(self, similarities_cfg: dict) -> None:
self.allowed_similarities_cfg_keys[self.similarity_kind]
):
raise ValueError(
f"Invalid keys in `similarities_cfg`: {set(similarities_cfg) - self.allowed_similarities_cfg_keys[self.similarity_kind]}"
f"Invalid keys in `similarities_cfg`: "
f"{set(similarities_cfg) - self.allowed_similarities_cfg_keys[self.similarity_kind]}"
)

def validate_best_hits_cfg(self, best_hits_cfg: dict) -> None:
if not set(best_hits_cfg).issubset(self.allowed_best_hits_cfg_keys):
raise ValueError(
f"Invalid keys in `best_hits_cfg`: {set(best_hits_cfg) - self.allowed_best_hits_cfg_keys}"
f"Invalid keys in `best_hits_cfg`: "
f"{set(best_hits_cfg) - self.allowed_best_hits_cfg_keys}"
)

def validate_inputs(
Expand All @@ -94,13 +97,26 @@ def validate_inputs(
f"size of {total_size}."
)

@staticmethod
def check_can_optimize(n_effectively_fixed: int, n_available: int) -> None:
if n_effectively_fixed == n_available:
raise ValueError(
"The number of effectively fixed matchings is equal to the number "
"of sequences. No optimization can be performed."
)
elif n_effectively_fixed > n_available:
raise ValueError(
"The number of effectively fixed matchings is greater than the number "
"of available sequences. Check your inputs."
)

# %% ../nbs/base.ipynb 4
def scalar_or_1d_tensor(
*, param: Any, param_name: str, dtype: torch.dtype = torch.float32
) -> torch.Tensor:
if not isinstance(param, (float, torch.Tensor)):
raise TypeError(f"`{param_name}` must be a float or a torch.Tensor.")
if isinstance(param, float):
if not isinstance(param, (int, float, torch.Tensor)):
raise TypeError(f"`{param_name}` must be a scalar or a torch.Tensor.")
if not isinstance(param, torch.Tensor):
param = torch.tensor(param, dtype=dtype)
elif param.ndim > 1:
raise ValueError(
Expand Down
14 changes: 12 additions & 2 deletions diffpass/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,10 @@ def __init__(
)
self.log_alphas = ParameterList(
[
Parameter(torch.zeros(*self.ensemble_shape, s, s, dtype=torch.float32))
Parameter(
torch.zeros(*self.ensemble_shape, s, s, dtype=torch.float32),
requires_grad=bool(s),
)
for s in self.nonfixed_group_sizes_
]
)
Expand Down Expand Up @@ -133,7 +136,7 @@ def _validate_fixed_matchings(
fm_zip = list(zip(*fm))
else:
num_fm = 0
fm_zip = ((), ())
fm_zip = [(), ()]
complement = s - num_fm # Effectively fully fixed when complement <= 1
is_fully_fixed = complement <= 1
num_efm = s - (s - num_fm) * (not is_fully_fixed)
Expand All @@ -151,6 +154,13 @@ def _validate_fixed_matchings(
mask[..., :, i] = False
self.register_buffer(f"_not_fixed_masks_{idx}", mask)
self._effective_fixed_matchings_zip.append(fm_zip)
self._total_number_fixed_matchings = sum(
self._effective_number_fixed_matchings
)
else:
self._effective_fixed_matchings_zip = [[(), ()] for _ in self.group_sizes]
self._effective_number_fixed_matchings = [0] * len(self.group_sizes)
self._total_number_fixed_matchings = 0

@property
def _not_fixed_masks(self) -> list[torch.Tensor]:
Expand Down
3 changes: 3 additions & 0 deletions diffpass/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,7 @@ def fit(
compute_final_soft: bool = True,
) -> DiffPASSResults:
results = self._prepare_fit(x, y)
self.check_can_optimize(self.permutation._total_number_fixed_matchings, len(x))
results = self._fit(
x,
y,
Expand Down Expand Up @@ -695,6 +696,7 @@ def fit(
compute_final_soft: bool = True,
) -> DiffPASSResults:
results = self._prepare_fit(x, y)
self.check_can_optimize(self.permutation._total_number_fixed_matchings, len(x))
results = self._fit(
x,
y,
Expand Down Expand Up @@ -1059,6 +1061,7 @@ def fit(
compute_final_soft: bool = True,
) -> DiffPASSResults:
results = self._prepare_fit(x, y)
self.check_can_optimize(self.permutation._total_number_fixed_matchings, len(x))
results = self._fit(
x,
y,
Expand Down
28 changes: 22 additions & 6 deletions nbs/base.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,8 @@
" def validate_permutation_cfg(self, permutation_cfg: dict) -> None:\n",
" if not set(permutation_cfg).issubset(self.allowed_permutation_cfg_keys):\n",
" raise ValueError(\n",
" f\"Invalid keys in `permutation_cfg`: {set(permutation_cfg) - self.allowed_permutation_cfg_keys}\"\n",
" f\"Invalid keys in `permutation_cfg`: \"\n",
" f\"{set(permutation_cfg) - self.allowed_permutation_cfg_keys}\"\n",
" )\n",
"\n",
" def validate_information_measure(self, information_measure: str) -> None:\n",
Expand All @@ -96,13 +97,15 @@
" self.allowed_similarities_cfg_keys[self.similarity_kind]\n",
" ):\n",
" raise ValueError(\n",
" f\"Invalid keys in `similarities_cfg`: {set(similarities_cfg) - self.allowed_similarities_cfg_keys[self.similarity_kind]}\"\n",
" f\"Invalid keys in `similarities_cfg`: \"\n",
" f\"{set(similarities_cfg) - self.allowed_similarities_cfg_keys[self.similarity_kind]}\"\n",
" )\n",
"\n",
" def validate_best_hits_cfg(self, best_hits_cfg: dict) -> None:\n",
" if not set(best_hits_cfg).issubset(self.allowed_best_hits_cfg_keys):\n",
" raise ValueError(\n",
" f\"Invalid keys in `best_hits_cfg`: {set(best_hits_cfg) - self.allowed_best_hits_cfg_keys}\"\n",
" f\"Invalid keys in `best_hits_cfg`: \"\n",
" f\"{set(best_hits_cfg) - self.allowed_best_hits_cfg_keys}\"\n",
" )\n",
"\n",
" def validate_inputs(\n",
Expand All @@ -122,6 +125,19 @@
" raise ValueError(\n",
" f\"Inputs have size {total_size} but `group_sizes` implies a total \"\n",
" f\"size of {total_size}.\"\n",
" )\n",
"\n",
" @staticmethod\n",
" def check_can_optimize(n_effectively_fixed: int, n_available: int) -> None:\n",
" if n_effectively_fixed == n_available:\n",
" raise ValueError(\n",
" \"The number of effectively fixed matchings is equal to the number \"\n",
" \"of sequences. No optimization can be performed.\"\n",
" )\n",
" elif n_effectively_fixed > n_available:\n",
" raise ValueError(\n",
" \"The number of effectively fixed matchings is greater than the number \"\n",
" \"of available sequences. Check your inputs.\"\n",
" )"
]
},
Expand All @@ -136,9 +152,9 @@
"def scalar_or_1d_tensor(\n",
" *, param: Any, param_name: str, dtype: torch.dtype = torch.float32\n",
") -> torch.Tensor:\n",
" if not isinstance(param, (float, torch.Tensor)):\n",
" raise TypeError(f\"`{param_name}` must be a float or a torch.Tensor.\")\n",
" if isinstance(param, float):\n",
" if not isinstance(param, (int, float, torch.Tensor)):\n",
" raise TypeError(f\"`{param_name}` must be a scalar or a torch.Tensor.\")\n",
" if not isinstance(param, torch.Tensor):\n",
" param = torch.tensor(param, dtype=dtype)\n",
" elif param.ndim > 1:\n",
" raise ValueError(\n",
Expand Down
14 changes: 12 additions & 2 deletions nbs/model.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,10 @@
" )\n",
" self.log_alphas = ParameterList(\n",
" [\n",
" Parameter(torch.zeros(*self.ensemble_shape, s, s, dtype=torch.float32))\n",
" Parameter(\n",
" torch.zeros(*self.ensemble_shape, s, s, dtype=torch.float32),\n",
" requires_grad=bool(s),\n",
" )\n",
" for s in self.nonfixed_group_sizes_\n",
" ]\n",
" )\n",
Expand Down Expand Up @@ -188,7 +191,7 @@
" fm_zip = list(zip(*fm))\n",
" else:\n",
" num_fm = 0\n",
" fm_zip = ((), ())\n",
" fm_zip = [(), ()]\n",
" complement = s - num_fm # Effectively fully fixed when complement <= 1\n",
" is_fully_fixed = complement <= 1\n",
" num_efm = s - (s - num_fm) * (not is_fully_fixed)\n",
Expand All @@ -206,6 +209,13 @@
" mask[..., :, i] = False\n",
" self.register_buffer(f\"_not_fixed_masks_{idx}\", mask)\n",
" self._effective_fixed_matchings_zip.append(fm_zip)\n",
" self._total_number_fixed_matchings = sum(\n",
" self._effective_number_fixed_matchings\n",
" )\n",
" else:\n",
" self._effective_fixed_matchings_zip = [[(), ()] for _ in self.group_sizes]\n",
" self._effective_number_fixed_matchings = [0] * len(self.group_sizes)\n",
" self._total_number_fixed_matchings = 0\n",
"\n",
" @property\n",
" def _not_fixed_masks(self) -> list[torch.Tensor]:\n",
Expand Down
3 changes: 3 additions & 0 deletions nbs/train.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,7 @@
" compute_final_soft: bool = True,\n",
" ) -> DiffPASSResults:\n",
" results = self._prepare_fit(x, y)\n",
" self.check_can_optimize(self.permutation._total_number_fixed_matchings, len(x))\n",
" results = self._fit(\n",
" x,\n",
" y,\n",
Expand Down Expand Up @@ -775,6 +776,7 @@
" compute_final_soft: bool = True,\n",
" ) -> DiffPASSResults:\n",
" results = self._prepare_fit(x, y)\n",
" self.check_can_optimize(self.permutation._total_number_fixed_matchings, len(x))\n",
" results = self._fit(\n",
" x,\n",
" y,\n",
Expand Down Expand Up @@ -1147,6 +1149,7 @@
" compute_final_soft: bool = True,\n",
" ) -> DiffPASSResults:\n",
" results = self._prepare_fit(x, y)\n",
" self.check_can_optimize(self.permutation._total_number_fixed_matchings, len(x))\n",
" results = self._fit(\n",
" x,\n",
" y,\n",
Expand Down

0 comments on commit fc28e3d

Please sign in to comment.