Skip to content

Commit

Permalink
[Lint] Remove unused imports
Browse files Browse the repository at this point in the history
  • Loading branch information
ulupo committed May 20, 2024
1 parent 2022080 commit 091dd9c
Show file tree
Hide file tree
Showing 5 changed files with 37 additions and 21 deletions.
14 changes: 6 additions & 8 deletions diffpass/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
# Stdlib imports
from collections.abc import Iterable, Sequence
from typing import Optional, Union, Iterator, Literal
from copy import deepcopy
from warnings import warn
from functools import partial

Expand All @@ -20,7 +19,6 @@
# PyTorch
import torch
from torch.nn import Module, ParameterList, Parameter
from torch.nn.functional import softmax

# DiffPaSS imports
from .gumbel_sinkhorn_ops import gumbel_sinkhorn, gumbel_matching
Expand All @@ -43,15 +41,15 @@
IndexPairsInGroup = list[IndexPair] # Pairs of indices in a group of sequences
IndexPairsInGroups = list[IndexPairsInGroup] # Pairs of indices in groups of sequences

# %% ../nbs/model.ipynb 6
# %% ../nbs/model.ipynb 7
def _consecutive_slices_from_sizes(group_sizes: Optional[Sequence[int]]) -> list[slice]:
if group_sizes is None:
return [slice(None)]
cumsum = np.cumsum(group_sizes).tolist()

return [slice(start, end) for start, end in zip([0] + cumsum, cumsum)]

# %% ../nbs/model.ipynb 8
# %% ../nbs/model.ipynb 9
class GeneralizedPermutation(Module):
"""Generalized permutation layer implementing both soft and hard permutations."""

Expand Down Expand Up @@ -331,7 +329,7 @@ def apply_hard_permutation_batch_to_similarity(

return torch.gather(x_permuted_rows, -1, index)

# %% ../nbs/model.ipynb 12
# %% ../nbs/model.ipynb 13
class TwoBodyEntropyLoss(Module):
"""Differentiable extension of the mean of estimated two-body entropies between
all pairs of columns from two one-hot encoded tensors."""
Expand All @@ -353,7 +351,7 @@ def __init__(self):
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
return smooth_mean_two_body_entropy(x, y) - smooth_mean_one_body_entropy(x)

# %% ../nbs/model.ipynb 17
# %% ../nbs/model.ipynb 18
class HammingSimilarities(Module):
"""Compute Hamming similarities between sequences using differentiable
operations.
Expand Down Expand Up @@ -471,7 +469,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:

return out

# %% ../nbs/model.ipynb 22
# %% ../nbs/model.ipynb 23
class BestHits(Module):
"""Compute (reciprocal) best hits within and between groups of sequences,
starting from a similarity matrix.
Expand Down Expand Up @@ -536,7 +534,7 @@ def _hard_bh_fn(self, similarities: torch.Tensor) -> torch.Tensor:
def forward(self, similarities: torch.Tensor) -> torch.Tensor:
return self._bh_fn(similarities)

# %% ../nbs/model.ipynb 25
# %% ../nbs/model.ipynb 26
class InterGroupSimilarityLoss(Module):
"""Compute a loss that compares similarity matrices restricted to inter-group
relationships.
Expand Down
11 changes: 4 additions & 7 deletions diffpass/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,6 @@
from collections.abc import Sequence
from typing import Optional, Any, Literal

# NumPy
import numpy as np

# PyTorch
import torch

Expand All @@ -32,7 +29,7 @@
IndexPairsInGroup = list[IndexPair] # Pairs of indices in a group of sequences
IndexPairsInGroups = list[IndexPairsInGroup] # Pairs of indices in groups of sequences

# %% ../nbs/train.ipynb 6
# %% ../nbs/train.ipynb 7
class InformationPairing(DiffPaSSModel):
"""DiffPaSS model for information-theoretic pairing of multiple sequence alignments (MSAs)."""

Expand Down Expand Up @@ -97,7 +94,7 @@ def compute_losses_identity_perm(

return {"hard": hard_loss_identity_perm, "soft": soft_loss_identity_perm}

# %% ../nbs/train.ipynb 9
# %% ../nbs/train.ipynb 10
class BestHitsPairing(DiffPaSSModel):
"""DiffPaSS model for pairing of multiple sequence alignments (MSAs) by aligning their orthology networks, constructed using (reciprocal) best hits ."""

Expand Down Expand Up @@ -238,7 +235,7 @@ def compute_losses_identity_perm(

return {"hard": hard_loss_identity_perm, "soft": soft_loss_identity_perm}

# %% ../nbs/train.ipynb 12
# %% ../nbs/train.ipynb 13
class MirrortreePairing(DiffPaSSModel):
"""DiffPaSS model for pairing of multiple sequence alignments (MSAs) by aligning their sequence distance networks as in the Mirrortree method."""

Expand Down Expand Up @@ -339,7 +336,7 @@ def compute_losses_identity_perm(

return {"hard": hard_loss_identity_perm, "soft": soft_loss_identity_perm}

# %% ../nbs/train.ipynb 15
# %% ../nbs/train.ipynb 16
class GraphAlignment(DiffPaSSModel):
"""DiffPaSS model for general graph alignment starting from the weighted adjacency matrices of two graphs."""

Expand Down
15 changes: 13 additions & 2 deletions nbs/model.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@
"# Stdlib imports\n",
"from collections.abc import Iterable, Sequence\n",
"from typing import Optional, Union, Iterator, Literal\n",
"from copy import deepcopy\n",
"from warnings import warn\n",
"from functools import partial\n",
"\n",
Expand All @@ -62,7 +61,6 @@
"# PyTorch\n",
"import torch\n",
"from torch.nn import Module, ParameterList, Parameter\n",
"from torch.nn.functional import softmax\n",
"\n",
"# DiffPaSS imports\n",
"from diffpass.gumbel_sinkhorn_ops import gumbel_sinkhorn, gumbel_matching\n",
Expand Down Expand Up @@ -98,6 +96,19 @@
"```"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#| hide\n",
"\n",
"# Imports for tests\n",
"from copy import deepcopy\n",
"from torch.nn.functional import softmax"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand Down
15 changes: 12 additions & 3 deletions nbs/train.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,6 @@
"from collections.abc import Sequence\n",
"from typing import Optional, Any, Literal\n",
"\n",
"# NumPy\n",
"import numpy as np\n",
"\n",
"# PyTorch\n",
"import torch\n",
"\n",
Expand Down Expand Up @@ -90,6 +87,18 @@
"```"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#| hide\n",
"\n",
"# Imports for tests\n",
"import numpy as np"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand Down
3 changes: 2 additions & 1 deletion nbs/tutorials/graph_alignment.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,8 @@
"metadata": {},
"outputs": [],
"source": [
"from diffpass.train import GraphAlignment, IntraGroupSimilarityLoss\n",
"from diffpass.model import IntraGroupSimilarityLoss\n",
"from diffpass.train import GraphAlignment\n",
"\n",
"group_sizes = [n_nodes]\n",
"comparison_loss = IntraGroupSimilarityLoss(\n",
Expand Down

0 comments on commit 091dd9c

Please sign in to comment.