Fix #7 (#8)
Type-annotate all `group_sizes` as `Sequence`
ulupo authored May 14, 2024
1 parent 6170564 commit 835a570
Showing 2 changed files with 26 additions and 26 deletions.
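The change is purely in type annotations. As an illustrative note (not part of the commit): `Iterable[int]` also accepts one-shot iterators such as generators, which are exhausted after a single pass, whereas `Sequence[int]` documents that callers should pass a list or tuple supporting `len()` and repeated traversal. A minimal sketch of the difference:

```python
from collections.abc import Iterable, Sequence

def two_passes(group_sizes: Iterable[int]) -> tuple[int, int]:
    # An Iterable admits generators, which are exhausted after one pass:
    # the second sum silently returns 0.
    return sum(group_sizes), sum(group_sizes)

print(two_passes(s for s in (3, 2, 4)))  # (9, 0) -- surprising
print(two_passes([3, 2, 4]))             # (9, 9)

def two_passes_seq(group_sizes: Sequence[int]) -> tuple[int, int]:
    # A Sequence (list, tuple, ...) supports len() and repeated iteration,
    # so both passes see the same data.
    return sum(group_sizes), sum(group_sizes)
```

In the modules below, each `group_sizes` is converted immediately with `tuple(s for s in group_sizes)`, so the new annotation mainly documents the intended argument type at call sites.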
diffpass/model.py (16 changes: 8 additions & 8 deletions)
@@ -58,7 +58,7 @@ class GeneralizedPermutation(Module):
def __init__(
self,
*,
group_sizes: Iterable[int],
group_sizes: Sequence[int],
fixed_pairings: Optional[IndexPairsInGroups] = None,
tau: float = 1.0,
n_iter: int = 1,
@@ -256,7 +256,7 @@ class MatrixApply(Module):
"""Apply matrices to chunks of a tensor of shape (n_samples, length, alphabet_size)
and collate the results."""

def __init__(self, group_sizes: Iterable[int]) -> None:
def __init__(self, group_sizes: Sequence[int]) -> None:
super().__init__()
self.group_sizes = tuple(s for s in group_sizes)
self._group_slices = _consecutive_slices_from_sizes(self.group_sizes)
@@ -275,7 +275,7 @@ class PermutationConjugate(Module):
"""Conjugate blocks of a square 2D tensor of shape (n_samples, n_samples) by
permutation matrices."""

def __init__(self, group_sizes: Iterable[int]) -> None:
def __init__(self, group_sizes: Sequence[int]) -> None:
super().__init__()
self.group_sizes = tuple(s for s in group_sizes)
self._group_slices = _consecutive_slices_from_sizes(self.group_sizes)
@@ -366,7 +366,7 @@ class HammingSimilarities(Module):
def __init__(
self,
*,
group_sizes: Optional[Iterable[int]] = None,
group_sizes: Optional[Sequence[int]] = None,
use_dot: bool = True,
p: Optional[float] = None,
) -> None:
@@ -415,7 +415,7 @@ class Blosum62Similarities(Module):
def __init__(
self,
*,
group_sizes: Optional[Iterable[int]] = None,
group_sizes: Optional[Sequence[int]] = None,
use_dot: bool = True,
p: Optional[float] = None,
use_scoredist: bool = False,
@@ -485,7 +485,7 @@ def __init__(
self,
*,
reciprocal: bool = True,
group_sizes: Optional[Iterable[int]],
group_sizes: Optional[Sequence[int]],
tau: float = 0.1,
mode: Literal["soft", "hard"] = "soft",
) -> None:
@@ -550,7 +550,7 @@ def __init__(
*,
# Number of entries in each group (e.g. species). Groups are assumed to be
# contiguous in the input similarity matrices
group_sizes: Iterable[int],
group_sizes: Sequence[int],
# If not ``None``, custom callable to compute the differentiable score between
# the flattened and concatenated inter-group blocks of the similarity matrices.
# Default: dot product
@@ -604,7 +604,7 @@ def __init__(
*,
# Number of entries in each group (e.g. species). Groups are assumed to be
# contiguous in the input similarity matrices
group_sizes: Optional[Iterable[int]] = None,
group_sizes: Optional[Sequence[int]] = None,
# If not ``None``, custom callable to compute the differentiable score between
# the flattened and concatenated intra-group blocks of the similarity matrices
# Default: dot product
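Several of the constructors above convert `group_sizes` to a tuple and then call `_consecutive_slices_from_sizes` on it. That helper is not part of this diff; the hypothetical sketch below only illustrates the behaviour its name and the "groups are assumed to be contiguous" comment suggest:

```python
from collections.abc import Sequence

def consecutive_slices_from_sizes(group_sizes: Sequence[int]) -> list[slice]:
    # Map contiguous group sizes to slices along the sample axis,
    # e.g. [3, 2] -> [slice(0, 3), slice(3, 5)].
    slices, start = [], 0
    for size in group_sizes:
        slices.append(slice(start, start + size))
        start += size
    return slices

print(consecutive_slices_from_sizes((3, 2)))  # [slice(0, 3, None), slice(3, 5, None)]
```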
nbs/model.ipynb (36 changes: 18 additions & 18 deletions)
@@ -135,7 +135,7 @@
" def __init__(\n",
" self,\n",
" *,\n",
" group_sizes: Iterable[int],\n",
" group_sizes: Sequence[int],\n",
" fixed_pairings: Optional[IndexPairsInGroups] = None,\n",
" tau: float = 1.0,\n",
" n_iter: int = 1,\n",
@@ -333,7 +333,7 @@
" \"\"\"Apply matrices to chunks of a tensor of shape (n_samples, length, alphabet_size)\n",
" and collate the results.\"\"\"\n",
"\n",
" def __init__(self, group_sizes: Iterable[int]) -> None:\n",
" def __init__(self, group_sizes: Sequence[int]) -> None:\n",
" super().__init__()\n",
" self.group_sizes = tuple(s for s in group_sizes)\n",
" self._group_slices = _consecutive_slices_from_sizes(self.group_sizes)\n",
@@ -352,7 +352,7 @@
" \"\"\"Conjugate blocks of a square 2D tensor of shape (n_samples, n_samples) by\n",
" permutation matrices.\"\"\"\n",
"\n",
" def __init__(self, group_sizes: Iterable[int]) -> None:\n",
" def __init__(self, group_sizes: Sequence[int]) -> None:\n",
" super().__init__()\n",
" self.group_sizes = tuple(s for s in group_sizes)\n",
" self._group_slices = _consecutive_slices_from_sizes(self.group_sizes)\n",
@@ -416,15 +416,15 @@
"outputs": [
{
"data": {
"text/markdown": "---\n\n[source](https://github.com/Bitbol-Lab/DiffPaSS/blob/main/diffpass/model.py#L49){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n\n### GeneralizedPermutation\n\n> GeneralizedPermutation (group_sizes:collections.abc.Iterable[int], fixed_\n> pairings:Optional[collections.abc.Sequence[collec\n> tions.abc.Sequence[tuple[int,int]]]]=None,\n> tau:float=1.0, n_iter:int=1, noise:bool=False,\n> noise_factor:float=1.0, noise_std:bool=False,\n> mode:Literal['soft','hard']='soft')\n\n*Generalized permutation layer implementing both soft and hard permutations.*",
"text/markdown": "---\n\n[source](https://github.com/Bitbol-Lab/DiffPaSS/blob/main/diffpass/model.py#L49){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n\n### GeneralizedPermutation\n\n> GeneralizedPermutation (group_sizes:collections.abc.Sequence[int], fixed_\n> pairings:Optional[collections.abc.Sequence[collec\n> tions.abc.Sequence[tuple[int,int]]]]=None,\n> tau:float=1.0, n_iter:int=1, noise:bool=False,\n> noise_factor:float=1.0, noise_std:bool=False,\n> mode:Literal['soft','hard']='soft')\n\n*Generalized permutation layer implementing both soft and hard permutations.*",
"text/plain": [
"---\n",
"\n",
"[source](https://github.com/Bitbol-Lab/DiffPaSS/blob/main/diffpass/model.py#L49){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### GeneralizedPermutation\n",
"\n",
"> GeneralizedPermutation (group_sizes:collections.abc.Iterable[int], fixed_\n",
"> GeneralizedPermutation (group_sizes:collections.abc.Sequence[int], fixed_\n",
"> pairings:Optional[collections.abc.Sequence[collec\n",
"> tions.abc.Sequence[tuple[int,int]]]]=None,\n",
"> tau:float=1.0, n_iter:int=1, noise:bool=False,\n",
@@ -668,7 +668,7 @@
" def __init__(\n",
" self,\n",
" *,\n",
" group_sizes: Optional[Iterable[int]] = None,\n",
" group_sizes: Optional[Sequence[int]] = None,\n",
" use_dot: bool = True,\n",
" p: Optional[float] = None,\n",
" ) -> None:\n",
@@ -717,7 +717,7 @@
" def __init__(\n",
" self,\n",
" *,\n",
" group_sizes: Optional[Iterable[int]] = None,\n",
" group_sizes: Optional[Sequence[int]] = None,\n",
" use_dot: bool = True,\n",
" p: Optional[float] = None,\n",
" use_scoredist: bool = False,\n",
@@ -781,7 +781,7 @@
"outputs": [
{
"data": {
"text/markdown": "---\n\n[source](https://github.com/Bitbol-Lab/DiffPaSS/blob/main/diffpass/model.py#L351){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n\n### HammingSimilarities\n\n> HammingSimilarities\n> (group_sizes:Optional[collections.abc.Iterable[int]]\n> =None, use_dot:bool=True, p:Optional[float]=None)\n\nCompute Hamming similarities between sequences using differentiable\noperations.\n\nOptionally, if the sequences are arranged in groups, the computation of\nsimilarities can be restricted to within groups.\nDifferentiable operations are used to compute the similarities, which can be\neither dot products or an L^p distance function.",
"text/markdown": "---\n\n[source](https://github.com/Bitbol-Lab/DiffPaSS/blob/main/diffpass/model.py#L351){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n\n### HammingSimilarities\n\n> HammingSimilarities\n> (group_sizes:Optional[collections.abc.Sequence[int]]\n> =None, use_dot:bool=True, p:Optional[float]=None)\n\nCompute Hamming similarities between sequences using differentiable\noperations.\n\nOptionally, if the sequences are arranged in groups, the computation of\nsimilarities can be restricted to within groups.\nDifferentiable operations are used to compute the similarities, which can be\neither dot products or an L^p distance function.",
"text/plain": [
"---\n",
"\n",
@@ -790,7 +790,7 @@
"### HammingSimilarities\n",
"\n",
"> HammingSimilarities\n",
"> (group_sizes:Optional[collections.abc.Iterable[int]]\n",
"> (group_sizes:Optional[collections.abc.Sequence[int]]\n",
"> =None, use_dot:bool=True, p:Optional[float]=None)\n",
"\n",
"Compute Hamming similarities between sequences using differentiable\n",
@@ -818,7 +818,7 @@
"outputs": [
{
"data": {
"text/markdown": "---\n\n[source](https://github.com/Bitbol-Lab/DiffPaSS/blob/main/diffpass/model.py#L400){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n\n### Blosum62Similarities\n\n> Blosum62Similarities\n> (group_sizes:Optional[collections.abc.Iterable[int]\n> ]=None, use_dot:bool=True, p:Optional[float]=None,\n> use_scoredist:bool=False,\n> aa_to_int:Optional[dict[str,int]]=None,\n> gaps_as_stars:bool=True)\n\nCompute Blosum62-based similarities between sequences using differentiable\noperations.\n\nOptionally, if the sequences are arranged in groups, the computation of\nsimilarities can be restricted to within groups.\nDifferentiable operations are used to compute the similarities, which can be\neither dot products or an L^p distance function.",
"text/markdown": "---\n\n[source](https://github.com/Bitbol-Lab/DiffPaSS/blob/main/diffpass/model.py#L400){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n\n### Blosum62Similarities\n\n> Blosum62Similarities\n> (group_sizes:Optional[collections.abc.Sequence[int]\n> ]=None, use_dot:bool=True, p:Optional[float]=None,\n> use_scoredist:bool=False,\n> aa_to_int:Optional[dict[str,int]]=None,\n> gaps_as_stars:bool=True)\n\nCompute Blosum62-based similarities between sequences using differentiable\noperations.\n\nOptionally, if the sequences are arranged in groups, the computation of\nsimilarities can be restricted to within groups.\nDifferentiable operations are used to compute the similarities, which can be\neither dot products or an L^p distance function.",
"text/plain": [
"---\n",
"\n",
@@ -827,7 +827,7 @@
"### Blosum62Similarities\n",
"\n",
"> Blosum62Similarities\n",
"> (group_sizes:Optional[collections.abc.Iterable[int]\n",
"> (group_sizes:Optional[collections.abc.Sequence[int]\n",
"> ]=None, use_dot:bool=True, p:Optional[float]=None,\n",
"> use_scoredist:bool=False,\n",
"> aa_to_int:Optional[dict[str,int]]=None,\n",
@@ -933,7 +933,7 @@
" self,\n",
" *,\n",
" reciprocal: bool = True,\n",
" group_sizes: Optional[Iterable[int]],\n",
" group_sizes: Optional[Sequence[int]],\n",
" tau: float = 0.1,\n",
" mode: Literal[\"soft\", \"hard\"] = \"soft\",\n",
" ) -> None:\n",
@@ -992,7 +992,7 @@
"outputs": [
{
"data": {
"text/markdown": "---\n\n[source](https://github.com/Bitbol-Lab/DiffPaSS/blob/main/diffpass/model.py#L469){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n\n### BestHits\n\n> BestHits (reciprocal:bool=True,\n> group_sizes:Optional[collections.abc.Iterable[int]],\n> tau:float=0.1, mode:Literal['soft','hard']='soft')\n\nCompute (reciprocal) best hits within and between groups of sequences,\nstarting from a similarity matrix.\n\nBest hits can be either 'hard', in which cases they are computed using the\nargmax, or 'soft', in which case they are computed using the softmax with a\ntemperature parameter `tau`. In both cases, the main diagonal in the similarity\nmatrix is excluded by setting its entries to minus infinity.",
"text/markdown": "---\n\n[source](https://github.com/Bitbol-Lab/DiffPaSS/blob/main/diffpass/model.py#L469){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n\n### BestHits\n\n> BestHits (reciprocal:bool=True,\n> group_sizes:Optional[collections.abc.Sequence[int]],\n> tau:float=0.1, mode:Literal['soft','hard']='soft')\n\nCompute (reciprocal) best hits within and between groups of sequences,\nstarting from a similarity matrix.\n\nBest hits can be either 'hard', in which cases they are computed using the\nargmax, or 'soft', in which case they are computed using the softmax with a\ntemperature parameter `tau`. In both cases, the main diagonal in the similarity\nmatrix is excluded by setting its entries to minus infinity.",
"text/plain": [
"---\n",
"\n",
@@ -1001,7 +1001,7 @@
"### BestHits\n",
"\n",
"> BestHits (reciprocal:bool=True,\n",
"> group_sizes:Optional[collections.abc.Iterable[int]],\n",
"> group_sizes:Optional[collections.abc.Sequence[int]],\n",
"> tau:float=0.1, mode:Literal['soft','hard']='soft')\n",
"\n",
"Compute (reciprocal) best hits within and between groups of sequences,\n",
@@ -1050,7 +1050,7 @@
" *,\n",
" # Number of entries in each group (e.g. species). Groups are assumed to be\n",
" # contiguous in the input similarity matrices\n",
" group_sizes: Iterable[int],\n",
" group_sizes: Sequence[int],\n",
" # If not ``None``, custom callable to compute the differentiable score between\n",
" # the flattened and concatenated inter-group blocks of the similarity matrices.\n",
" # Default: dot product\n",
@@ -1104,7 +1104,7 @@
" *,\n",
" # Number of entries in each group (e.g. species). Groups are assumed to be\n",
" # contiguous in the input similarity matrices\n",
" group_sizes: Optional[Iterable[int]] = None,\n",
" group_sizes: Optional[Sequence[int]] = None,\n",
" # If not ``None``, custom callable to compute the differentiable score between\n",
" # the flattened and concatenated intra-group blocks of the similarity matrices\n",
" # Default: dot product\n",
@@ -1170,15 +1170,15 @@
"outputs": [
{
"data": {
"text/markdown": "---\n\n[source](https://github.com/Bitbol-Lab/DiffPaSS/blob/main/diffpass/model.py#L534){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n\n### InterGroupSimilarityLoss\n\n> InterGroupSimilarityLoss (group_sizes:collections.abc.Iterable[int],\n> score_fn:Optional[<built-\n> infunctioncallable>]=None)\n\nCompute a loss that compares similarity matrices restricted to inter-group\nrelationships.\n\nSimilarity matrices are expected to be square and symmetric. The loss is computed\nby comparing the (unrolled and concatenated) upper triangular blocks containing\ninter-group similarities.",
"text/markdown": "---\n\n[source](https://github.com/Bitbol-Lab/DiffPaSS/blob/main/diffpass/model.py#L534){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n\n### InterGroupSimilarityLoss\n\n> InterGroupSimilarityLoss (group_sizes:collections.abc.Sequence[int],\n> score_fn:Optional[<built-\n> infunctioncallable>]=None)\n\nCompute a loss that compares similarity matrices restricted to inter-group\nrelationships.\n\nSimilarity matrices are expected to be square and symmetric. The loss is computed\nby comparing the (unrolled and concatenated) upper triangular blocks containing\ninter-group similarities.",
"text/plain": [
"---\n",
"\n",
"[source](https://github.com/Bitbol-Lab/DiffPaSS/blob/main/diffpass/model.py#L534){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### InterGroupSimilarityLoss\n",
"\n",
"> InterGroupSimilarityLoss (group_sizes:collections.abc.Iterable[int],\n",
"> InterGroupSimilarityLoss (group_sizes:collections.abc.Sequence[int],\n",
"> score_fn:Optional[<built-\n",
"> infunctioncallable>]=None)\n",
"\n",
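Finally, the `InterGroupSimilarityLoss` docstring above describes comparing the unrolled, concatenated upper-triangular inter-group blocks. The sketch below only extracts those blocks from a similarity matrix given contiguous `group_sizes`; it is an illustration of that description, not the library's code, and the default dot-product score mentioned in the inline comments would then be taken between two such flattened vectors:

```python
import torch
from collections.abc import Sequence

def inter_group_blocks(sim: torch.Tensor, group_sizes: Sequence[int]) -> torch.Tensor:
    # Flatten and concatenate the upper-triangular blocks of `sim` that relate
    # *different* contiguous groups.
    slices, start = [], 0
    for size in group_sizes:
        slices.append(slice(start, start + size))
        start += size
    blocks = [
        sim[slices[i], slices[j]].reshape(-1)
        for i in range(len(slices))
        for j in range(i + 1, len(slices))
    ]
    return torch.cat(blocks)

sim = torch.rand(5, 5)
print(inter_group_blocks(sim, [3, 2]).shape)  # torch.Size([6]): the single 3x2 inter-group block
```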
