diff --git a/diffpass/model.py b/diffpass/model.py index 2704f51..de8391c 100644 --- a/diffpass/model.py +++ b/diffpass/model.py @@ -58,7 +58,7 @@ class GeneralizedPermutation(Module): def __init__( self, *, - group_sizes: Iterable[int], + group_sizes: Sequence[int], fixed_pairings: Optional[IndexPairsInGroups] = None, tau: float = 1.0, n_iter: int = 1, @@ -256,7 +256,7 @@ class MatrixApply(Module): """Apply matrices to chunks of a tensor of shape (n_samples, length, alphabet_size) and collate the results.""" - def __init__(self, group_sizes: Iterable[int]) -> None: + def __init__(self, group_sizes: Sequence[int]) -> None: super().__init__() self.group_sizes = tuple(s for s in group_sizes) self._group_slices = _consecutive_slices_from_sizes(self.group_sizes) @@ -275,7 +275,7 @@ class PermutationConjugate(Module): """Conjugate blocks of a square 2D tensor of shape (n_samples, n_samples) by permutation matrices.""" - def __init__(self, group_sizes: Iterable[int]) -> None: + def __init__(self, group_sizes: Sequence[int]) -> None: super().__init__() self.group_sizes = tuple(s for s in group_sizes) self._group_slices = _consecutive_slices_from_sizes(self.group_sizes) @@ -366,7 +366,7 @@ class HammingSimilarities(Module): def __init__( self, *, - group_sizes: Optional[Iterable[int]] = None, + group_sizes: Optional[Sequence[int]] = None, use_dot: bool = True, p: Optional[float] = None, ) -> None: @@ -415,7 +415,7 @@ class Blosum62Similarities(Module): def __init__( self, *, - group_sizes: Optional[Iterable[int]] = None, + group_sizes: Optional[Sequence[int]] = None, use_dot: bool = True, p: Optional[float] = None, use_scoredist: bool = False, @@ -485,7 +485,7 @@ def __init__( self, *, reciprocal: bool = True, - group_sizes: Optional[Iterable[int]], + group_sizes: Optional[Sequence[int]], tau: float = 0.1, mode: Literal["soft", "hard"] = "soft", ) -> None: @@ -550,7 +550,7 @@ def __init__( *, # Number of entries in each group (e.g. species). Groups are assumed to be # contiguous in the input similarity matrices - group_sizes: Iterable[int], + group_sizes: Sequence[int], # If not ``None``, custom callable to compute the differentiable score between # the flattened and concatenated inter-group blocks of the similarity matrices. # Default: dot product @@ -604,7 +604,7 @@ def __init__( *, # Number of entries in each group (e.g. species). Groups are assumed to be # contiguous in the input similarity matrices - group_sizes: Optional[Iterable[int]] = None, + group_sizes: Optional[Sequence[int]] = None, # If not ``None``, custom callable to compute the differentiable score between # the flattened and concatenated intra-group blocks of the similarity matrices # Default: dot product diff --git a/nbs/model.ipynb b/nbs/model.ipynb index a434c64..baadcc0 100644 --- a/nbs/model.ipynb +++ b/nbs/model.ipynb @@ -135,7 +135,7 @@ " def __init__(\n", " self,\n", " *,\n", - " group_sizes: Iterable[int],\n", + " group_sizes: Sequence[int],\n", " fixed_pairings: Optional[IndexPairsInGroups] = None,\n", " tau: float = 1.0,\n", " n_iter: int = 1,\n", @@ -333,7 +333,7 @@ " \"\"\"Apply matrices to chunks of a tensor of shape (n_samples, length, alphabet_size)\n", " and collate the results.\"\"\"\n", "\n", - " def __init__(self, group_sizes: Iterable[int]) -> None:\n", + " def __init__(self, group_sizes: Sequence[int]) -> None:\n", " super().__init__()\n", " self.group_sizes = tuple(s for s in group_sizes)\n", " self._group_slices = _consecutive_slices_from_sizes(self.group_sizes)\n", @@ -352,7 +352,7 @@ " \"\"\"Conjugate blocks of a square 2D tensor of shape (n_samples, n_samples) by\n", " permutation matrices.\"\"\"\n", "\n", - " def __init__(self, group_sizes: Iterable[int]) -> None:\n", + " def __init__(self, group_sizes: Sequence[int]) -> None:\n", " super().__init__()\n", " self.group_sizes = tuple(s for s in group_sizes)\n", " self._group_slices = _consecutive_slices_from_sizes(self.group_sizes)\n", @@ -416,7 +416,7 @@ "outputs": [ { "data": { - "text/markdown": "---\n\n[source](https://github.com/Bitbol-Lab/DiffPaSS/blob/main/diffpass/model.py#L49){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n\n### GeneralizedPermutation\n\n> GeneralizedPermutation (group_sizes:collections.abc.Iterable[int], fixed_\n> pairings:Optional[collections.abc.Sequence[collec\n> tions.abc.Sequence[tuple[int,int]]]]=None,\n> tau:float=1.0, n_iter:int=1, noise:bool=False,\n> noise_factor:float=1.0, noise_std:bool=False,\n> mode:Literal['soft','hard']='soft')\n\n*Generalized permutation layer implementing both soft and hard permutations.*", + "text/markdown": "---\n\n[source](https://github.com/Bitbol-Lab/DiffPaSS/blob/main/diffpass/model.py#L49){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n\n### GeneralizedPermutation\n\n> GeneralizedPermutation (group_sizes:collections.abc.Sequence[int], fixed_\n> pairings:Optional[collections.abc.Sequence[collec\n> tions.abc.Sequence[tuple[int,int]]]]=None,\n> tau:float=1.0, n_iter:int=1, noise:bool=False,\n> noise_factor:float=1.0, noise_std:bool=False,\n> mode:Literal['soft','hard']='soft')\n\n*Generalized permutation layer implementing both soft and hard permutations.*", "text/plain": [ "---\n", "\n", @@ -424,7 +424,7 @@ "\n", "### GeneralizedPermutation\n", "\n", - "> GeneralizedPermutation (group_sizes:collections.abc.Iterable[int], fixed_\n", + "> GeneralizedPermutation (group_sizes:collections.abc.Sequence[int], fixed_\n", "> pairings:Optional[collections.abc.Sequence[collec\n", "> tions.abc.Sequence[tuple[int,int]]]]=None,\n", "> tau:float=1.0, n_iter:int=1, noise:bool=False,\n", @@ -668,7 +668,7 @@ " def __init__(\n", " self,\n", " *,\n", - " group_sizes: Optional[Iterable[int]] = None,\n", + " group_sizes: Optional[Sequence[int]] = None,\n", " use_dot: bool = True,\n", " p: Optional[float] = None,\n", " ) -> None:\n", @@ -717,7 +717,7 @@ " def __init__(\n", " self,\n", " *,\n", - " group_sizes: Optional[Iterable[int]] = None,\n", + " group_sizes: Optional[Sequence[int]] = None,\n", " use_dot: bool = True,\n", " p: Optional[float] = None,\n", " use_scoredist: bool = False,\n", @@ -781,7 +781,7 @@ "outputs": [ { "data": { - "text/markdown": "---\n\n[source](https://github.com/Bitbol-Lab/DiffPaSS/blob/main/diffpass/model.py#L351){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n\n### HammingSimilarities\n\n> HammingSimilarities\n> (group_sizes:Optional[collections.abc.Iterable[int]]\n> =None, use_dot:bool=True, p:Optional[float]=None)\n\nCompute Hamming similarities between sequences using differentiable\noperations.\n\nOptionally, if the sequences are arranged in groups, the computation of\nsimilarities can be restricted to within groups.\nDifferentiable operations are used to compute the similarities, which can be\neither dot products or an L^p distance function.", + "text/markdown": "---\n\n[source](https://github.com/Bitbol-Lab/DiffPaSS/blob/main/diffpass/model.py#L351){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n\n### HammingSimilarities\n\n> HammingSimilarities\n> (group_sizes:Optional[collections.abc.Sequence[int]]\n> =None, use_dot:bool=True, p:Optional[float]=None)\n\nCompute Hamming similarities between sequences using differentiable\noperations.\n\nOptionally, if the sequences are arranged in groups, the computation of\nsimilarities can be restricted to within groups.\nDifferentiable operations are used to compute the similarities, which can be\neither dot products or an L^p distance function.", "text/plain": [ "---\n", "\n", @@ -790,7 +790,7 @@ "### HammingSimilarities\n", "\n", "> HammingSimilarities\n", - "> (group_sizes:Optional[collections.abc.Iterable[int]]\n", + "> (group_sizes:Optional[collections.abc.Sequence[int]]\n", "> =None, use_dot:bool=True, p:Optional[float]=None)\n", "\n", "Compute Hamming similarities between sequences using differentiable\n", @@ -818,7 +818,7 @@ "outputs": [ { "data": { - "text/markdown": "---\n\n[source](https://github.com/Bitbol-Lab/DiffPaSS/blob/main/diffpass/model.py#L400){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n\n### Blosum62Similarities\n\n> Blosum62Similarities\n> (group_sizes:Optional[collections.abc.Iterable[int]\n> ]=None, use_dot:bool=True, p:Optional[float]=None,\n> use_scoredist:bool=False,\n> aa_to_int:Optional[dict[str,int]]=None,\n> gaps_as_stars:bool=True)\n\nCompute Blosum62-based similarities between sequences using differentiable\noperations.\n\nOptionally, if the sequences are arranged in groups, the computation of\nsimilarities can be restricted to within groups.\nDifferentiable operations are used to compute the similarities, which can be\neither dot products or an L^p distance function.", + "text/markdown": "---\n\n[source](https://github.com/Bitbol-Lab/DiffPaSS/blob/main/diffpass/model.py#L400){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n\n### Blosum62Similarities\n\n> Blosum62Similarities\n> (group_sizes:Optional[collections.abc.Sequence[int]\n> ]=None, use_dot:bool=True, p:Optional[float]=None,\n> use_scoredist:bool=False,\n> aa_to_int:Optional[dict[str,int]]=None,\n> gaps_as_stars:bool=True)\n\nCompute Blosum62-based similarities between sequences using differentiable\noperations.\n\nOptionally, if the sequences are arranged in groups, the computation of\nsimilarities can be restricted to within groups.\nDifferentiable operations are used to compute the similarities, which can be\neither dot products or an L^p distance function.", "text/plain": [ "---\n", "\n", @@ -827,7 +827,7 @@ "### Blosum62Similarities\n", "\n", "> Blosum62Similarities\n", - "> (group_sizes:Optional[collections.abc.Iterable[int]\n", + "> (group_sizes:Optional[collections.abc.Sequence[int]\n", "> ]=None, use_dot:bool=True, p:Optional[float]=None,\n", "> use_scoredist:bool=False,\n", "> aa_to_int:Optional[dict[str,int]]=None,\n", @@ -933,7 +933,7 @@ " self,\n", " *,\n", " reciprocal: bool = True,\n", - " group_sizes: Optional[Iterable[int]],\n", + " group_sizes: Optional[Sequence[int]],\n", " tau: float = 0.1,\n", " mode: Literal[\"soft\", \"hard\"] = \"soft\",\n", " ) -> None:\n", @@ -992,7 +992,7 @@ "outputs": [ { "data": { - "text/markdown": "---\n\n[source](https://github.com/Bitbol-Lab/DiffPaSS/blob/main/diffpass/model.py#L469){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n\n### BestHits\n\n> BestHits (reciprocal:bool=True,\n> group_sizes:Optional[collections.abc.Iterable[int]],\n> tau:float=0.1, mode:Literal['soft','hard']='soft')\n\nCompute (reciprocal) best hits within and between groups of sequences,\nstarting from a similarity matrix.\n\nBest hits can be either 'hard', in which cases they are computed using the\nargmax, or 'soft', in which case they are computed using the softmax with a\ntemperature parameter `tau`. In both cases, the main diagonal in the similarity\nmatrix is excluded by setting its entries to minus infinity.", + "text/markdown": "---\n\n[source](https://github.com/Bitbol-Lab/DiffPaSS/blob/main/diffpass/model.py#L469){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n\n### BestHits\n\n> BestHits (reciprocal:bool=True,\n> group_sizes:Optional[collections.abc.Sequence[int]],\n> tau:float=0.1, mode:Literal['soft','hard']='soft')\n\nCompute (reciprocal) best hits within and between groups of sequences,\nstarting from a similarity matrix.\n\nBest hits can be either 'hard', in which cases they are computed using the\nargmax, or 'soft', in which case they are computed using the softmax with a\ntemperature parameter `tau`. In both cases, the main diagonal in the similarity\nmatrix is excluded by setting its entries to minus infinity.", "text/plain": [ "---\n", "\n", @@ -1001,7 +1001,7 @@ "### BestHits\n", "\n", "> BestHits (reciprocal:bool=True,\n", - "> group_sizes:Optional[collections.abc.Iterable[int]],\n", + "> group_sizes:Optional[collections.abc.Sequence[int]],\n", "> tau:float=0.1, mode:Literal['soft','hard']='soft')\n", "\n", "Compute (reciprocal) best hits within and between groups of sequences,\n", @@ -1050,7 +1050,7 @@ " *,\n", " # Number of entries in each group (e.g. species). Groups are assumed to be\n", " # contiguous in the input similarity matrices\n", - " group_sizes: Iterable[int],\n", + " group_sizes: Sequence[int],\n", " # If not ``None``, custom callable to compute the differentiable score between\n", " # the flattened and concatenated inter-group blocks of the similarity matrices.\n", " # Default: dot product\n", @@ -1104,7 +1104,7 @@ " *,\n", " # Number of entries in each group (e.g. species). Groups are assumed to be\n", " # contiguous in the input similarity matrices\n", - " group_sizes: Optional[Iterable[int]] = None,\n", + " group_sizes: Optional[Sequence[int]] = None,\n", " # If not ``None``, custom callable to compute the differentiable score between\n", " # the flattened and concatenated intra-group blocks of the similarity matrices\n", " # Default: dot product\n", @@ -1170,7 +1170,7 @@ "outputs": [ { "data": { - "text/markdown": "---\n\n[source](https://github.com/Bitbol-Lab/DiffPaSS/blob/main/diffpass/model.py#L534){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n\n### InterGroupSimilarityLoss\n\n> InterGroupSimilarityLoss (group_sizes:collections.abc.Iterable[int],\n> score_fn:Optional[ infunctioncallable>]=None)\n\nCompute a loss that compares similarity matrices restricted to inter-group\nrelationships.\n\nSimilarity matrices are expected to be square and symmetric. The loss is computed\nby comparing the (unrolled and concatenated) upper triangular blocks containing\ninter-group similarities.", + "text/markdown": "---\n\n[source](https://github.com/Bitbol-Lab/DiffPaSS/blob/main/diffpass/model.py#L534){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n\n### InterGroupSimilarityLoss\n\n> InterGroupSimilarityLoss (group_sizes:collections.abc.Sequence[int],\n> score_fn:Optional[ infunctioncallable>]=None)\n\nCompute a loss that compares similarity matrices restricted to inter-group\nrelationships.\n\nSimilarity matrices are expected to be square and symmetric. The loss is computed\nby comparing the (unrolled and concatenated) upper triangular blocks containing\ninter-group similarities.", "text/plain": [ "---\n", "\n", @@ -1178,7 +1178,7 @@ "\n", "### InterGroupSimilarityLoss\n", "\n", - "> InterGroupSimilarityLoss (group_sizes:collections.abc.Iterable[int],\n", + "> InterGroupSimilarityLoss (group_sizes:collections.abc.Sequence[int],\n", "> score_fn:Optional[ infunctioncallable>]=None)\n", "\n",