Use default floating dtype instead of imposing float32
ulupo committed Mar 5, 2024
1 parent 40eb009 commit b1f12f9
Showing 6 changed files with 10 additions and 221 deletions.
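The commit swaps hard-coded `torch.float32` for `torch.get_default_dtype()`, so tensors created by the library follow whatever global floating dtype the user has configured. A minimal sketch of the PyTorch behavior being relied on (not part of this repository):

```python
import torch

# Out of the box, PyTorch's default floating dtype is float32.
assert torch.get_default_dtype() == torch.float32

# A user who wants double precision can change the global default...
torch.set_default_dtype(torch.float64)

# ...and floating-point tensors created from Python scalars without an
# explicit dtype now follow it.
print(torch.tensor(1.5).dtype)                                   # torch.float64
print(torch.tensor(1.5, dtype=torch.get_default_dtype()).dtype)  # torch.float64

# Restore the stock default so later code is unaffected.
torch.set_default_dtype(torch.float32)
```

Hard-coding `dtype=torch.float32`, as the old code did, silently overrides this choice.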
8 changes: 3 additions & 5 deletions diffpass/base.py
@@ -111,13 +111,11 @@ def check_can_optimize(n_effectively_fixed: int, n_available: int) -> None:
)

# %% ../nbs/base.ipynb 4
def scalar_or_1d_tensor(
*, param: Any, param_name: str, dtype: torch.dtype = torch.float32
) -> torch.Tensor:
def scalar_or_1d_tensor(*, param: Any, param_name: str) -> torch.Tensor:
if not isinstance(param, (int, float, torch.Tensor)):
raise TypeError(f"`{param_name}` must be a scalar or a torch.Tensor.")
if not isinstance(param, torch.Tensor):
param = torch.tensor(param, dtype=dtype)
param = torch.tensor(param, dtype=torch.get_default_dtype())
elif param.ndim > 1:
raise ValueError(
f"`{param_name}` must be a scalar or a tensor of dimension <= 1."
@@ -163,7 +161,7 @@ def _reshape_ensemble_param(
raise ValueError(
f"`dim_in_ensemble` cannot be None if {param_name} is 1D."
)
param = param.to(torch.float32)
param = param.to(torch.get_default_dtype())
# If param is not a scalar, broadcast it along the `ensemble_dim`-th ensemble dimension
if dim_in_ensemble >= n_ensemble_dims or dim_in_ensemble < -n_ensemble_dims:
raise ValueError(
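Under the new signature of `scalar_or_1d_tensor`, scalar inputs are promoted to the active default dtype instead of always to `float32`. A hedged usage sketch, assuming `diffpass` is importable, that the helper returns the converted tensor as the hunk suggests, and using a made-up `param_name`:

```python
import torch
from diffpass.base import scalar_or_1d_tensor  # module path as in the diff above

torch.set_default_dtype(torch.float64)
t = scalar_or_1d_tensor(param=0.1, param_name="tau")  # "tau" is an illustrative name
print(t.dtype)  # torch.float64 under a float64 default; was always float32 before this commit
torch.set_default_dtype(torch.float32)
```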
2 changes: 1 addition & 1 deletion diffpass/constants.py
@@ -93,6 +93,6 @@ def get_blosum62_data(

return TokenizedSubstitutionMatrix(
name=BLOSUM62.name,
mat=torch.tensor(mat.to_numpy(), dtype=torch.float32),
mat=torch.tensor(mat.to_numpy(), dtype=torch.get_default_dtype()),
expected_value=BLOSUM62.expected_value,
)
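Note that the explicit `dtype` argument is kept here rather than dropped: `torch.tensor` on a NumPy array preserves the array's dtype (typically `float64`), so converting the BLOSUM62 matrix to the default dtype has to be requested explicitly. A small demonstration of that general PyTorch behavior:

```python
import numpy as np
import torch

a = np.ones((3, 3))  # NumPy arrays default to float64
print(torch.tensor(a).dtype)                                   # torch.float64 (NumPy dtype preserved)
print(torch.tensor(a, dtype=torch.get_default_dtype()).dtype)  # torch.float32 under the stock default
```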
2 changes: 1 addition & 1 deletion diffpass/model.py
@@ -89,7 +89,7 @@ def __init__(
self.log_alphas = ParameterList(
[
Parameter(
torch.zeros(*self.ensemble_shape, s, s, dtype=torch.float32),
torch.zeros(*self.ensemble_shape, s, s),
requires_grad=bool(s),
)
for s in self.nonfixed_group_sizes_
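For factory functions such as `torch.zeros`, omitting `dtype` is enough: they already create tensors in the default floating dtype, which is why the argument is simply removed here. An illustrative sketch (the shape below merely stands in for `(*ensemble_shape, s, s)`):

```python
import torch

torch.set_default_dtype(torch.float64)
log_alpha = torch.zeros(2, 4, 4)  # no dtype given, so the global default applies
print(log_alpha.dtype)            # torch.float64
torch.set_default_dtype(torch.float32)
```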
215 changes: 3 additions & 212 deletions nbs/base.ipynb
@@ -149,13 +149,11 @@
"source": [
"#| export\n",
"\n",
"def scalar_or_1d_tensor(\n",
" *, param: Any, param_name: str, dtype: torch.dtype = torch.float32\n",
") -> torch.Tensor:\n",
"def scalar_or_1d_tensor(*, param: Any, param_name: str) -> torch.Tensor:\n",
" if not isinstance(param, (int, float, torch.Tensor)):\n",
" raise TypeError(f\"`{param_name}` must be a scalar or a torch.Tensor.\")\n",
" if not isinstance(param, torch.Tensor):\n",
" param = torch.tensor(param, dtype=dtype)\n",
" param = torch.tensor(param, dtype=torch.get_default_dtype())\n",
" elif param.ndim > 1:\n",
" raise ValueError(\n",
" f\"`{param_name}` must be a scalar or a tensor of dimension <= 1.\"\n",
@@ -201,7 +199,7 @@
" raise ValueError(\n",
" f\"`dim_in_ensemble` cannot be None if {param_name} is 1D.\"\n",
" )\n",
" param = param.to(torch.float32)\n",
" param = param.to(torch.get_default_dtype())\n",
" # If param is not a scalar, broadcast it along the `ensemble_dim`-th ensemble dimension\n",
" if dim_in_ensemble >= n_ensemble_dims or dim_in_ensemble < -n_ensemble_dims:\n",
" raise ValueError(\n",
@@ -224,213 +222,6 @@
"\n",
" return param"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# class PermutationsMixin:\n",
"# \"\"\"Mixin class for validating input and plotting the results of the optimization.\"\"\"\n",
"# \n",
"# std_init: float\n",
"# device: torch.device\n",
"# group_sizes: list[int]\n",
"# \n",
"# def _init_log_alpha(self, skip=False):\n",
"# \"\"\"Initialize log_alpha as a list of matrices of shape (s, s) where d is the\n",
"# size of the species MSA. The matrices are initialized with standard normal entries.\n",
"# \"\"\"\n",
"# if not skip:\n",
"# # Permutations restricted to species\n",
"# self.log_alpha = [\n",
"# (self.std_init * torch.randn(s, s, device=self.device)).requires_grad_()\n",
"# for s in self._effective_sizes_not_fixed\n",
"# ]\n",
"# \n",
"# def _validator(self, input_1, input_2, fixed_pairings=None):\n",
"# \"\"\"Validate input MSAs and check fixed pairings.\"\"\"\n",
"# # Validate input MSAs\n",
"# size_1, length_1, alphabet_size_1 = input_1.shape[1:]\n",
"# size_2, length_2, alphabet_size_2 = input_2.shape[1:]\n",
"# length_1 -= 1\n",
"# length_2 -= 1\n",
"# if size_1 != size_2:\n",
"# raise ValueError(\n",
"# f\"Size mismatch between MSA 1 ({size_1}) and MSA 2 \" f\"({size_2})\"\n",
"# )\n",
"# if alphabet_size_1 != alphabet_size_2:\n",
"# raise ValueError(\"Input MSAs must have the same alphabet size/\")\n",
"# self._alphabet_size = alphabet_size_1\n",
"# \n",
"# # Validate size attribute\n",
"# self._total_size = sum(self.group_sizes)\n",
"# if size_1 != self._total_size:\n",
"# raise ValueError(\n",
"# f\"Input MSAs have size {size_1} but model expects a total \"\n",
"# f\"size of {self._total_size}\"\n",
"# )\n",
"# self._length_1, self._length_2 = length_1, length_2\n",
"# \n",
"# self._effective_non_fixed_pairs = torch.ones(\n",
"# self._total_size, self._total_size, dtype=torch.bool, device=self.device\n",
"# )\n",
"# \n",
"# if fixed_pairings is not None:\n",
"# if len(fixed_pairings) != len(self.group_sizes):\n",
"# raise ValueError(\n",
"# f\"`fixed_pairings` has length {len(fixed_pairings)} but \"\n",
"# f\"there are {self.group_sizes} species.\"\n",
"# )\n",
"# _fixed_pairings = fixed_pairings\n",
"# \n",
"# start = 0\n",
"# self._effective_sizes_not_fixed = []\n",
"# self._effective_fixed_pairings_zip = []\n",
"# for species_idx, (species_size, species_fixed_pairings) in enumerate(\n",
"# zip(self.group_sizes, _fixed_pairings)\n",
"# ):\n",
"# # Check uniqueness of pairs (i, j)\n",
"# n_fixed = len(set(species_fixed_pairings))\n",
"# if len(species_fixed_pairings) > n_fixed:\n",
"# raise ValueError(\n",
"# \"Repeated indices for fixed pairings at species \"\n",
"# f\"{species_idx}: {species_fixed_pairings}\"\n",
"# )\n",
"# fixed_pairings_arr = np.zeros((species_size, species_size), dtype=int)\n",
"# if species_fixed_pairings:\n",
"# species_fixed_pairings_zip = tuple(zip(*species_fixed_pairings))\n",
"# else:\n",
"# # species_fixed_pairings is an empty list\n",
"# species_fixed_pairings_zip = (tuple(), tuple())\n",
"# try:\n",
"# fixed_pairings_arr[species_fixed_pairings_zip] = 1\n",
"# except IndexError:\n",
"# raise ValueError(\n",
"# f\"Fixed pairings indices out of bounds: passed {species_fixed_pairings} \"\n",
"# f\"for species {species_idx} with size {species_size}.\"\n",
"# )\n",
"# partial_sum_0 = fixed_pairings_arr.sum(axis=0)\n",
"# partial_sum_1 = fixed_pairings_arr.sum(axis=1)\n",
"# if (partial_sum_0 > 1).any() or (partial_sum_1 > 1).any():\n",
"# raise ValueError(\n",
"# f\"Passed fixed pairings for species {species_idx} are either not one-one \"\n",
"# \"or a multiply-defined mapping from row to column indices: \"\n",
"# f\"{species_fixed_pairings}\"\n",
"# )\n",
"# for i, j in species_fixed_pairings:\n",
"# self._effective_non_fixed_pairs[start + i, :] = False\n",
"# self._effective_non_fixed_pairs[:, start + j] = False\n",
"# total_minus_fixed = species_size - n_fixed\n",
"# # If species_size - n_fixed <= 1 then actually everything is fixed\n",
"# self._effective_sizes_not_fixed.append(\n",
"# int(total_minus_fixed > 1) * total_minus_fixed\n",
"# )\n",
"# if total_minus_fixed == 1:\n",
"# # Deduce implicitly fixed pair\n",
"# i_implicit, j_implicit = np.argmin(partial_sum_1), np.argmin(\n",
"# partial_sum_0\n",
"# )\n",
"# self._effective_non_fixed_pairs[start + i_implicit, :] = False\n",
"# self._effective_non_fixed_pairs[:, start + j_implicit] = False\n",
"# species_fixed_pairings_zip = (\n",
"# species_fixed_pairings_zip[0] + (i_implicit,),\n",
"# species_fixed_pairings_zip[1] + (j_implicit,),\n",
"# )\n",
"# self._effective_fixed_pairings_zip.append(species_fixed_pairings_zip)\n",
"# start += species_size\n",
"# else:\n",
"# self._effective_sizes_not_fixed = self.group_sizes\n",
"# self._effective_fixed_pairings_zip = None\n",
"# \n",
"# self._default_target_idx = torch.arange(\n",
"# self._total_size, dtype=torch.int64, device=self.device\n",
"# )\n",
"# \n",
"# def plot_real_time(\n",
"# self,\n",
"# it,\n",
"# gs_matching_mat_np,\n",
"# gs_mat_np,\n",
"# list_idx,\n",
"# target_idx,\n",
"# list_log_alpha,\n",
"# losses,\n",
"# batch_size,\n",
"# epochs,\n",
"# lr,\n",
"# tar_loss,\n",
"# new_noise_factor,\n",
"# output_dir,\n",
"# only_loss_plot,\n",
"# ):\n",
"# \"\"\"Plot the results of the optimization in real time.\"\"\"\n",
"# n_correct = [sum(idx == target_idx) for idx in list_idx[::batch_size]]\n",
"# \n",
"# cmap = cm.get_cmap(\"Blues\")\n",
"# normalizer = None\n",
"# fig, axes = plt.subplots(figsize=(30, 5), ncols=5, constrained_layout=True)\n",
"# \n",
"# null_model = len(self.group_sizes)\n",
"# _size = [0] + list(np.cumsum(self.group_sizes))\n",
"# for k in range(1, len(_size)):\n",
"# for ii in range(2):\n",
"# elem, elem1 = _size[k - 1], _size[k]\n",
"# axes[ii].plot(\n",
"# [elem - 0.5, elem1 - 0.5, elem1 - 0.5, elem - 0.5],\n",
"# [elem - 0.5, elem - 0.5, elem1 - 0.5, elem1 - 0.5],\n",
"# color=\"r\",\n",
"# )\n",
"# axes[ii].plot(\n",
"# [elem - 0.5, elem - 0.5, elem1 - 0.5, elem1 - 0.5],\n",
"# [elem - 0.5, elem1 - 0.5, elem1 - 0.5, elem - 0.5],\n",
"# color=\"r\",\n",
"# )\n",
"# \n",
"# ims_soft = axes[0].imshow(gs_mat_np, cmap=cmap, norm=normalizer)\n",
"# axes[0].set_title(f\"Soft {it // batch_size}\")\n",
"# axes[1].imshow(gs_matching_mat_np, cmap=cmap, norm=normalizer)\n",
"# axes[1].set_title(\"Hard\")\n",
"# \n",
"# curr_log_alpha = list_log_alpha[-1]\n",
"# ims_log_alpha = axes[2].imshow(curr_log_alpha, norm=CenteredNorm(), cmap=cm.bwr)\n",
"# axes[2].set_title(\"Log-alpha\")\n",
"# \n",
"# prev_log_alpha = (\n",
"# list_log_alpha[-2] if len(list_log_alpha) > 1 else curr_log_alpha\n",
"# )\n",
"# diff_log_alpha = curr_log_alpha - prev_log_alpha\n",
"# if np.nansum(np.abs(diff_log_alpha)):\n",
"# ims_log_alpha_diff = axes[3].imshow(\n",
"# diff_log_alpha, norm=CenteredNorm(), cmap=cm.bwr\n",
"# )\n",
"# fig.colorbar(ims_log_alpha_diff, ax=axes[3], shrink=0.8)\n",
"# else:\n",
"# axes[3].imshow(np.zeros_like(diff_log_alpha), cmap=cm.bwr)\n",
"# axes[3].set_title(\"Log-alpha diff\")\n",
"# \n",
"# avg_loss = np.mean(np.array(losses).reshape(-1, batch_size), axis=1)\n",
"# axes[4].plot(avg_loss, color=\"b\", linewidth=1)\n",
"# if not only_loss_plot:\n",
"# if tar_loss is not None:\n",
"# axes[4].axhline(tar_loss, color=\"b\", linewidth=2)\n",
"# diff = avg_loss[0] - tar_loss\n",
"# axes[4].set_ylim([tar_loss - 0.6 * diff, avg_loss[0] + 0.5 * diff])\n",
"# ax3_2 = axes[4].twinx()\n",
"# ax3_2.set_ylabel(\"No. of correct pairs\", color=\"red\")\n",
"# ax3_2.plot(n_correct, color=\"red\", linewidth=1)\n",
"# ax3_2.axhline(null_model, color=\"red\", linewidth=2)\n",
"# ax3_2.tick_params(axis=\"y\", labelcolor=\"red\")\n",
"# axes[4].set_ylabel(\"Loss\")\n",
"# axes[4].set_xlim([0, epochs])\n",
"# axes[4].set_title(f\"lr: {lr:.3g}, noise:{new_noise_factor:.3g}\")\n",
"# fig.colorbar(ims_soft, ax=axes[0], shrink=0.8)\n",
"# fig.colorbar(ims_log_alpha, ax=axes[2], shrink=0.8)\n",
"# if output_dir is not None:\n",
"# fig.savefig(output_dir / \"Iterations\" / f\"Epoch={it // batch_size}.svg\")\n",
"# plt.show()"
]
}
],
"metadata": {
2 changes: 1 addition & 1 deletion nbs/constants.ipynb
@@ -114,7 +114,7 @@
"\n",
" return TokenizedSubstitutionMatrix(\n",
" name=BLOSUM62.name,\n",
" mat=torch.tensor(mat.to_numpy(), dtype=torch.float32),\n",
" mat=torch.tensor(mat.to_numpy(), dtype=torch.get_default_dtype()),\n",
" expected_value=BLOSUM62.expected_value,\n",
" )"
]
2 changes: 1 addition & 1 deletion nbs/model.ipynb
@@ -144,7 +144,7 @@
" self.log_alphas = ParameterList(\n",
" [\n",
" Parameter(\n",
" torch.zeros(*self.ensemble_shape, s, s, dtype=torch.float32),\n",
" torch.zeros(*self.ensemble_shape, s, s),\n",
" requires_grad=bool(s),\n",
" )\n",
" for s in self.nonfixed_group_sizes_\n",
