Skip to content

Commit

Permalink
Make two variants of substitution-matrix--based similarity functions
Browse files Browse the repository at this point in the history
Versions are 'dot' and 'cdist', with 'p' a new parameter for 'cdist'
  • Loading branch information
ulupo committed Mar 5, 2024
1 parent 72b1e67 commit 8e71df6
Show file tree
Hide file tree
Showing 7 changed files with 78 additions and 17 deletions.
6 changes: 4 additions & 2 deletions diffpass/_modidx.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,8 +140,10 @@
'diffpass/sequence_similarity_ops.py'),
'diffpass.sequence_similarity_ops.smooth_hamming_similarities_dot': ( 'sequence_similarity_ops.html#smooth_hamming_similarities_dot',
'diffpass/sequence_similarity_ops.py'),
'diffpass.sequence_similarity_ops.smooth_substitution_matrix_similarities': ( 'sequence_similarity_ops.html#smooth_substitution_matrix_similarities',
'diffpass/sequence_similarity_ops.py'),
'diffpass.sequence_similarity_ops.smooth_substitution_matrix_similarities_cdist': ( 'sequence_similarity_ops.html#smooth_substitution_matrix_similarities_cdist',
'diffpass/sequence_similarity_ops.py'),
'diffpass.sequence_similarity_ops.smooth_substitution_matrix_similarities_dot': ( 'sequence_similarity_ops.html#smooth_substitution_matrix_similarities_dot',
'diffpass/sequence_similarity_ops.py'),
'diffpass.sequence_similarity_ops.soft_best_hits': ( 'sequence_similarity_ops.html#soft_best_hits',
'diffpass/sequence_similarity_ops.py')},
'diffpass.train': { 'diffpass.train.DiffPASSResults': ('train.html#diffpassresults', 'diffpass/train.py'),
Expand Down
2 changes: 1 addition & 1 deletion diffpass/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class DiffPASSMixin:
allowed_similarity_kinds = {"Hamming", "Blosum62"}
allowed_similarities_cfg_keys = {
"Hamming": {"use_dot", "p"},
"Blosum62": {"use_scoredist", "aa_to_int", "gaps_as_stars"},
"Blosum62": {"use_dot", "p", "use_scoredist", "aa_to_int", "gaps_as_stars"},
}
allowed_best_hits_cfg_keys = {"tau", "reciprocal"}

Expand Down
29 changes: 24 additions & 5 deletions diffpass/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@
from diffpass.sequence_similarity_ops import (
smooth_hamming_similarities_dot,
smooth_hamming_similarities_cdist,
smooth_substitution_matrix_similarities,
smooth_substitution_matrix_similarities_dot,
smooth_substitution_matrix_similarities_cdist,
soft_best_hits,
hard_best_hits,
)
Expand Down Expand Up @@ -312,7 +313,7 @@ def __init__(
if self.p is None:
raise ValueError("If `use_dot` is False, `p` must be provided.")
self._similarities_fn = smooth_hamming_similarities_cdist
self._similarities_fn_kwargs = {"p": p}
self._similarities_fn_kwargs = {"p": self.p}

self._group_slices = _consecutive_slices_from_sizes(self.group_sizes)

Expand All @@ -334,6 +335,8 @@ def __init__(
self,
*,
group_sizes: Optional[Iterable[int]] = None,
use_dot: bool = True,
p: Optional[float] = None,
use_scoredist: bool = False,
aa_to_int: Optional[dict[str, int]] = None,
gaps_as_stars: bool = True,
Expand All @@ -342,6 +345,8 @@ def __init__(
self.group_sizes = (
tuple(s for s in group_sizes) if group_sizes is not None else None
)
self.use_dot = use_dot
self.p = p
self.use_scoredist = use_scoredist
self.aa_to_int = aa_to_int
self.gaps_as_stars = gaps_as_stars
Expand All @@ -352,6 +357,21 @@ def __init__(
self.register_buffer("subs_mat", blosum62_data.mat)
self.expected_value = blosum62_data.expected_value

self._similarities_fn_kwargs = {"subs_mat": self.subs_mat}
if self.use_dot:
if self.p is not None:
warn("Since a `p` was provided, `use_dot` will be ignored.")
self._similarities_fn = smooth_substitution_matrix_similarities_dot
self._similarities_fn_kwargs = {
"use_scoredist": self.use_scoredist,
"expected_value": self.expected_value,
}
else:
if self.p is None:
raise ValueError("If `use_dot` is False, `p` must be provided.")
self._similarities_fn = smooth_substitution_matrix_similarities_cdist
self._similarities_fn_kwargs = {"p": self.p}

self._group_slices = _consecutive_slices_from_sizes(self.group_sizes)

def forward(self, x: torch.Tensor) -> torch.Tensor:
Expand All @@ -361,11 +381,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
)
for sl in self._group_slices:
out[..., sl, sl].copy_(
smooth_substitution_matrix_similarities(
self._similarities_fn(
x[..., sl, :, :],
subs_mat=self.subs_mat,
expected_value=self.expected_value,
use_scoredist=self.use_scoredist,
**self._similarities_fn_kwargs,
)
)

Expand Down
15 changes: 13 additions & 2 deletions diffpass/sequence_similarity_ops.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/sequence_similarity_ops.ipynb.

# %% auto 0
__all__ = ['smooth_hamming_similarities_cdist', 'smooth_hamming_similarities_dot', 'smooth_substitution_matrix_similarities',
__all__ = ['smooth_hamming_similarities_cdist', 'smooth_hamming_similarities_dot',
'smooth_substitution_matrix_similarities_cdist', 'smooth_substitution_matrix_similarities_dot',
'soft_best_hits', 'hard_best_hits']

# %% ../nbs/sequence_similarity_ops.ipynb 3
Expand Down Expand Up @@ -33,7 +34,17 @@ def smooth_hamming_similarities_dot(x: torch.Tensor) -> torch.Tensor:
return norm_similarities


def smooth_substitution_matrix_similarities(
def smooth_substitution_matrix_similarities_cdist(
    x: torch.Tensor, subs_mat: torch.Tensor, p: float = 1.0
) -> torch.Tensor:
    """Score pairs of sequences by a negated p-norm distance after applying `subs_mat`.

    Each vector along the last axis of `x` is multiplied by `subs_mat`
    (``out[..., n, i, a] = sum_b subs_mat[a, b] * x[..., n, i, b]``). The last
    two axes of the result are then flattened into one feature vector per
    entry of the ``n`` axis, and all pairwise p-norm distances between those
    feature vectors are computed with `torch.cdist`.

    Parameters
    ----------
    x
        Tensor with at least three dimensions, indexed as ``(..., n, i, b)``.
        The size of the last axis must match the second axis of `subs_mat`.
    subs_mat
        Two-dimensional matrix indexed as ``(a, b)``.
    p
        Order of the norm passed to `torch.cdist`; also the exponent applied
        to the resulting distances.

    Returns
    -------
    torch.Tensor
        Tensor of shape ``(..., n, n)`` holding ``-(distance ** p)`` for every
        pair of entries along the ``n`` axis.
    """
    # Transform every position's vector by the substitution matrix.
    transformed = torch.einsum("ab,...nib->...nia", subs_mat, x)
    # Collapse the (position, transformed-channel) axes into one feature axis.
    features = transformed.flatten(start_dim=-2)
    pairwise_distances = torch.cdist(features, features, p=p)

    # Negate so that smaller distances yield larger (less negative) scores.
    return -(pairwise_distances**p)


def smooth_substitution_matrix_similarities_dot(
x: torch.Tensor,
subs_mat: torch.Tensor,
use_scoredist: bool = False,
Expand Down
2 changes: 1 addition & 1 deletion nbs/base.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@
" allowed_similarity_kinds = {\"Hamming\", \"Blosum62\"}\n",
" allowed_similarities_cfg_keys = {\n",
" \"Hamming\": {\"use_dot\", \"p\"},\n",
" \"Blosum62\": {\"use_scoredist\", \"aa_to_int\", \"gaps_as_stars\"},\n",
" \"Blosum62\": {\"use_dot\", \"p\", \"use_scoredist\", \"aa_to_int\", \"gaps_as_stars\"},\n",
" }\n",
" allowed_best_hits_cfg_keys = {\"tau\", \"reciprocal\"}\n",
"\n",
Expand Down
29 changes: 24 additions & 5 deletions nbs/model.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,8 @@
"from diffpass.sequence_similarity_ops import (\n",
" smooth_hamming_similarities_dot,\n",
" smooth_hamming_similarities_cdist,\n",
" smooth_substitution_matrix_similarities,\n",
" smooth_substitution_matrix_similarities_dot,\n",
" smooth_substitution_matrix_similarities_cdist,\n",
" soft_best_hits,\n",
" hard_best_hits,\n",
")"
Expand Down Expand Up @@ -490,7 +491,7 @@
" if self.p is None:\n",
" raise ValueError(\"If `use_dot` is False, `p` must be provided.\")\n",
" self._similarities_fn = smooth_hamming_similarities_cdist\n",
" self._similarities_fn_kwargs = {\"p\": p}\n",
" self._similarities_fn_kwargs = {\"p\": self.p}\n",
"\n",
" self._group_slices = _consecutive_slices_from_sizes(self.group_sizes)\n",
"\n",
Expand All @@ -512,6 +513,8 @@
" self,\n",
" *,\n",
" group_sizes: Optional[Iterable[int]] = None,\n",
" use_dot: bool = True,\n",
" p: Optional[float] = None,\n",
" use_scoredist: bool = False,\n",
" aa_to_int: Optional[dict[str, int]] = None,\n",
" gaps_as_stars: bool = True,\n",
Expand All @@ -520,6 +523,8 @@
" self.group_sizes = (\n",
" tuple(s for s in group_sizes) if group_sizes is not None else None\n",
" )\n",
" self.use_dot = use_dot\n",
" self.p = p\n",
" self.use_scoredist = use_scoredist\n",
" self.aa_to_int = aa_to_int\n",
" self.gaps_as_stars = gaps_as_stars\n",
Expand All @@ -530,6 +535,21 @@
" self.register_buffer(\"subs_mat\", blosum62_data.mat)\n",
" self.expected_value = blosum62_data.expected_value\n",
"\n",
" self._similarities_fn_kwargs = {\"subs_mat\": self.subs_mat}\n",
" if self.use_dot:\n",
" if self.p is not None:\n",
" warn(\"Since a `p` was provided, `use_dot` will be ignored.\")\n",
" self._similarities_fn = smooth_substitution_matrix_similarities_dot\n",
" self._similarities_fn_kwargs = {\n",
" \"use_scoredist\": self.use_scoredist,\n",
" \"expected_value\": self.expected_value,\n",
" }\n",
" else:\n",
" if self.p is None:\n",
" raise ValueError(\"If `use_dot` is False, `p` must be provided.\")\n",
" self._similarities_fn = smooth_substitution_matrix_similarities_cdist\n",
" self._similarities_fn_kwargs = {\"p\": self.p}\n",
"\n",
" self._group_slices = _consecutive_slices_from_sizes(self.group_sizes)\n",
"\n",
" def forward(self, x: torch.Tensor) -> torch.Tensor:\n",
Expand All @@ -539,11 +559,10 @@
" )\n",
" for sl in self._group_slices:\n",
" out[..., sl, sl].copy_(\n",
" smooth_substitution_matrix_similarities(\n",
" self._similarities_fn(\n",
" x[..., sl, :, :],\n",
" subs_mat=self.subs_mat,\n",
" expected_value=self.expected_value,\n",
" use_scoredist=self.use_scoredist,\n",
" **self._similarities_fn_kwargs,\n",
" )\n",
" )\n",
"\n",
Expand Down
12 changes: 11 additions & 1 deletion nbs/sequence_similarity_ops.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,17 @@
" return norm_similarities\n",
"\n",
"\n",
"def smooth_substitution_matrix_similarities(\n",
"def smooth_substitution_matrix_similarities_cdist(\n",
" x: torch.Tensor, subs_mat: torch.Tensor, p: float = 1.0\n",
") -> torch.Tensor:\n",
" \"\"\"TODO.\"\"\"\n",
" x = torch.einsum(\"ab,...nib->...nia\", subs_mat, x).flatten(start_dim=-2)\n",
" scores = -torch.cdist(x, x, p=p) ** p\n",
"\n",
" return scores\n",
"\n",
"\n",
"def smooth_substitution_matrix_similarities_dot(\n",
" x: torch.Tensor,\n",
" subs_mat: torch.Tensor,\n",
" use_scoredist: bool = False,\n",
Expand Down

0 comments on commit 8e71df6

Please sign in to comment.