diff --git a/README.md b/README.md index 89c84fc..7453e61 100644 --- a/README.md +++ b/README.md @@ -120,7 +120,7 @@ python -m pip install -e . ## Quickstart -### Input data preprocessing +### Input data preprocessing (MSA pairing) First, parse your multiple sequence alignments (MSAs) in FASTA format into a list of tuples `(header, sequence)` using @@ -205,21 +205,28 @@ msa_B_oh = one_hot_encode_msa(msa_B_for_pairing, device=device) ### Pairing optimization -Finally, we can instantiate an +Finally, we can instantiate a class from `diffpass.train` to find an +optimal pairing between `x` and `y`. Here, `x` and `y` are MSAs, so we +can look for a pairing that optimizes the mutual information between `x` +and `y`. For this, we use [`InformationPairing`](https://Bitbol-Lab.github.io/DiffPaSS/train.html#informationpairing) -object and optimize the mutual information between the paired MSAs using -the DiffPaSS bootstrapped optimization algorithm. The results are stored -in a -[`DiffPaSSResults`](https://Bitbol-Lab.github.io/DiffPaSS/base.html#diffpassresults) -container. The lists of (hard) losses and permutations found during the -optimization can be accessed as attributes of the container. +and the DiffPaSS bootstrapped optimization algorithm. See the tutorials +below for other examples, including for graph alignment when `x` and `y` +are weighted adjacency matrices. ``` python from diffpass.train import InformationPairing information_pairing = InformationPairing(group_sizes=species_sizes).to(device) bootstrap_results = information_pairing.fit_bootstrap(x, y) +``` +The results are stored in a +[`DiffPaSSResults`](https://Bitbol-Lab.github.io/DiffPaSS/base.html#diffpassresults) +container. 
The lists of (hard) losses and permutations found during the +optimization can be accessed as attributes of the container: + +``` python print(f"Final hard loss: {bootstrap_results.hard_losses[-1].item()}") print(f"Final hard permutations (one permutation per species): {bootstrap_results.hard_perms[-1][-1].item()}") ``` @@ -229,11 +236,14 @@ the tutorials. ## Tutorials -See the -[`mutual_information_msa_pairing.ipynb`](https://github.com/Bitbol-Lab/DiffPaSS/blob/main/nbs/tutorials/mutual_information_msa_pairing.ipynb) -notebook for an example of paired MSA optimization in the case of -well-known prokaryotic datasets, for which ground truth pairings are -given by genome proximity. +- [`mutual_information_msa_pairing.ipynb`](https://github.com/Bitbol-Lab/DiffPaSS/blob/main/nbs/tutorials/mutual_information_msa_pairing.ipynb): + paired MSA optimization using mutual information in the case of + well-known prokaryotic datasets, for which ground truth pairings are + given by genome proximity. +- [`graph_alignment.ipynb`](https://github.com/Bitbol-Lab/DiffPaSS/blob/main/nbs/tutorials/graph_alignment.ipynb): + general graph alignment using + [`diffpass.train.GraphAlignment`](https://Bitbol-Lab.github.io/DiffPaSS/train.html#graphalignment), + with an example of aligning two weighted adjacency matrices. ## Documentation diff --git a/nbs/index.ipynb b/nbs/index.ipynb index 03590ac..8d8be68 100644 --- a/nbs/index.ipynb +++ b/nbs/index.ipynb @@ -91,7 +91,7 @@ "source": [ "## Quickstart\n", "\n", - "### Input data preprocessing\n", + "### Input data preprocessing (MSA pairing)\n", "\n", "First, parse your multiple sequence alignments (MSAs) in FASTA format into a list of tuples ``(header, sequence)`` using `read_msa`.\n", "\n", @@ -164,14 +164,17 @@ "\n", "### Pairing optimization\n", "\n", - "Finally, we can instantiate an `InformationPairing` object and optimize the mutual information between the paired MSAs using the DiffPaSS bootstrapped optimization algorithm. 
The results are stored in a `DiffPaSSResults` container. The lists of (hard) losses and permutations found during the optimization can be accessed as attributes of the container.\n", + "Finally, we can instantiate a class from `diffpass.train` to find an optimal pairing between `x` and `y`. Here, `x` and `y` are MSAs, so we can look for a pairing that optimizes the mutual information between `x` and `y`. For this, we use `InformationPairing` and the DiffPaSS bootstrapped optimization algorithm. See the tutorials below for other examples, including for graph alignment when `x` and `y` are weighted adjacency matrices.\n", "\n", "```python\n", "from diffpass.train import InformationPairing\n", "\n", "information_pairing = InformationPairing(group_sizes=species_sizes).to(device)\n", "bootstrap_results = information_pairing.fit_bootstrap(x, y)\n", + "```\n", "\n", + "The results are stored in a `DiffPaSSResults` container. The lists of (hard) losses and permutations found during the optimization can be accessed as attributes of the container:\n", + "```python\n", "print(f\"Final hard loss: {bootstrap_results.hard_losses[-1].item()}\")\n", "print(f\"Final hard permutations (one permutation per species): {bootstrap_results.hard_perms[-1][-1].item()}\")\n", "```\n", @@ -185,7 +188,8 @@ "source": [ "## Tutorials\n", "\n", - "See the [`mutual_information_msa_pairing.ipynb`](https://github.com/Bitbol-Lab/DiffPaSS/blob/main/nbs/tutorials/mutual_information_msa_pairing.ipynb) notebook for an example of paired MSA optimization in the case of well-known prokaryotic datasets, for which ground truth pairings are given by genome proximity." 
+ "- [`mutual_information_msa_pairing.ipynb`](https://github.com/Bitbol-Lab/DiffPaSS/blob/main/nbs/tutorials/mutual_information_msa_pairing.ipynb): paired MSA optimization using mutual information in the case of well-known prokaryotic datasets, for which ground truth pairings are given by genome proximity.\n", + "- [`graph_alignment.ipynb`](https://github.com/Bitbol-Lab/DiffPaSS/blob/main/nbs/tutorials/graph_alignment.ipynb): general graph alignment using `diffpass.train.GraphAlignment`, with an example of aligning two weighted adjacency matrices." ] }, { diff --git a/nbs/sidebar.yml b/nbs/sidebar.yml index 657fa01..89fba9b 100644 --- a/nbs/sidebar.yml +++ b/nbs/sidebar.yml @@ -14,4 +14,5 @@ website: - train.ipynb - section: tutorials contents: + - tutorials/graph_alignment.ipynb - tutorials/mutual_information_msa_pairing.ipynb diff --git a/nbs/tutorials/graph_alignment.ipynb b/nbs/tutorials/graph_alignment.ipynb new file mode 100644 index 0000000..eb931d9 --- /dev/null +++ b/nbs/tutorials/graph_alignment.ipynb @@ -0,0 +1,323 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#| hide\n", + "\n", + "%load_ext autoreload\n", + "%autoreload 2" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Graph alignment\n", + "\n", + "> Using DiffPaSS to align two graphs." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": "**Note**: This notebook requires `networkx`." 
+ }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Using device: cuda\n" + ] + } + ], + "source": [ + "# Stdlib imports\n", + "from typing import Optional\n", + "\n", + "# NumPy\n", + "import numpy as np\n", + "\n", + "# PyTorch\n", + "import torch\n", + "\n", + "# Plotting\n", + "from matplotlib import pyplot as plt\n", + "\n", + "# NetworkX\n", + "import networkx as nx\n", + "\n", + "# Set the number of threads for PyTorch\n", + "torch.set_num_threads(8)\n", + "\n", + "# Device\n", + "DEVICE = torch.device(\n", + " f\"cuda{(':' + input('Enter the CUDA device number:')) if torch.cuda.device_count() > 1 else ''}\"\n", + " if torch.cuda.is_available() else \"cpu\"\n", + ")\n", + "# DEVICE = torch.device(\"cpu\")\n", + "print(f\"Using device: {DEVICE}\")\n", + "\n", + "# Set the seeds for NumPy and PyTorch\n", + "NUMPY_SEED = 42\n", + "np.random.seed(NUMPY_SEED)\n", + "\n", + "TORCH_SEED = 42\n", + "torch.manual_seed(TORCH_SEED);\n", + "\n", + "NX_SEED = 42" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 1. Create the adjacency matrices of two SBM graphs\n", + "\n", + "We create two SBM graphs with a few communities of equal size. We then randomly permute the nodes in one of the graphs to create the second graph. The Laplacian matrices of these graphs are used as inputs for `diffpass.train.GraphAlignment`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def create_test_sbm_laplacian(\n", + " block_sizes: list ,\n", + " block_prob: list,\n", + "):\n", + " \"\"\"Create a stochastic block model graph with given block sizes and block probabilities. 
Return its dense Laplacian matrix.\"\"\"\n", + " g = nx.stochastic_block_model(block_sizes, block_prob, seed=NX_SEED)\n", + " g.remove_nodes_from(list(nx.isolates(g)))\n", + " n = len(g)\n", + " # Laplacian matrix\n", + " l = nx.laplacian_matrix(g, range(n))\n", + "\n", + " return torch.tensor(\n", + " l.todense(), dtype=torch.get_default_dtype(), device=DEVICE\n", + " )\n", + "\n", + "# Create an SBM graph with `n_blocks` communities of equal size\n", + "n_nodes = 40\n", + "n_blocks = 4\n", + "block_sizes = [int(n_nodes / n_blocks)] * n_blocks\n", + "inter = 0.3\n", + "intra = 0.7\n", + "probs = [\n", + " [intra, inter, inter, inter],\n", + " [inter, intra, inter, inter],\n", + " [inter, inter, intra, inter],\n", + " [inter, inter, inter, intra]\n", + "]\n", + "\n", + "# Create the Laplacian matrix of the first graph\n", + "x = create_test_sbm_laplacian(block_sizes, probs)\n", + "\n", + "# Define the second graph by randomly permuting the nodes in the first\n", + "P = torch.zeros_like(x)\n", + "P[torch.arange(n_nodes), torch.randperm(n_nodes)] = 1\n", + "y = P @ x @ P.T" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2. Align the graphs by maximising the dot product between the Laplacian matrices: `GraphAlignment`\n", + "\n", + "The default behaviour of `GraphAlignment` is to maximise the dot product between the upper triangles (including or excluding the main diagonal) of two input square matrices. 
The `GraphAlignment` class also allows for custom comparison losses to be used, but we do not do so in this example." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from diffpass.train import GraphAlignment, IntraGroupSimilarityLoss\n", + "\n", + "group_sizes = [n_nodes]\n", + "comparison_loss = IntraGroupSimilarityLoss(group_sizes=group_sizes, exclude_diagonal=False) # Default behaviour except for inclusion of the main diagonal\n", + "\n", + "# Initialize the GraphAlignment model\n", + "graph_alignment = GraphAlignment(\n", + " group_sizes=group_sizes,\n", + " comparison_loss=comparison_loss\n", + ").to(DEVICE)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 38/38 [00:00<00:00, 252.12it/s]\n" + ] + } + ], + "source": [ + "# Optimization parameters for DiffPaSS bootstrap\n", + "bootstrap_cfg = {\n", + " \"n_start\": 1,\n", + " \"n_end\": None,\n", + " \"step_size\": 1, # Increase to speed up if needed\n", + " \"show_pbar\": True,\n", + " \"single_fit_cfg\": None # Default\n", + "}\n", + "\n", + "# Run the DiffPaSS bootstrap\n", + "bootstrap_results = graph_alignment.fit_bootstrap(x, y, **bootstrap_cfg)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": "## 3. 
Plot the resulting hard losses" + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Ground truth hard loss\n", + "target_hard_loss = graph_alignment.compute_losses_identity_perm(x, x)[\"hard\"]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from diffpass.data_utils import compute_num_correct_pairings\n", + "\n", + "def plot_hard_losses(\n", + " results,\n", + " target_hard_loss: Optional[float] = None\n", + "):\n", + " hard_losses = [\n", + " min(hard_losses_this_step)\n", + " for hard_losses_this_step in results.hard_losses\n", + " ]\n", + "\n", + " plt.plot(hard_losses, \".-\", label=\"DiffPaSS hard permutation\")\n", + " plt.axhline(target_hard_loss, color=\"red\", label=f\"Ground truth loss (identity) = {target_hard_loss:.4f}\")\n", + " plt.ylabel(\"Hard loss\")\n", + " plt.xlabel(\"Bootstrap iteration\")\n", + " plt.title(f\"Minimum loss during optimization: {np.min(hard_losses):.4f}\")\n", + " plt.legend()\n", + "\n", + " plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": "plot_hard_losses(bootstrap_results, target_hard_loss=target_hard_loss)" + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The hard loss has been decreased, but is a far cry from the ground truth loss.\n", + "\n", + "Can we do better?" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 4. Increase the number of repeats for each bootstrap iteration\n", + "\n", + "We can use a \"greedier\" approach and further exploit the randomness in the bootstrap procedure. At each bootstrap iteration, instead of only performing a single gradient optimization, we can perform multiple optimizations -- each using a different, randomly sampled set of fixed pairs. We can then choose the one yielding the lowest hard loss.\n", + "\n", + "This is done by setting `n_repeats` to a value greater than 1 in the `fit_bootstrap` method." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 38/38 [00:13<00:00, 2.89it/s]\n" + ] + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "bootstrap_cfg = {\n", + " \"n_start\": 1,\n", + " \"n_end\": None,\n", + " \"step_size\": 1, # Increase to speed up if needed\n", + " \"show_pbar\": True,\n", + " \"single_fit_cfg\": None, # Default\n", + " \"n_repeats\": 100 # Number of repeats for each bootstrap iteration -- increasing this should lead to better results\n", + "}\n", + "\n", + "graph_alignment = GraphAlignment(\n", + " group_sizes=group_sizes,\n", + " comparison_loss=comparison_loss\n", + ").to(DEVICE)\n", + "\n", + "bootstrap_results = graph_alignment.fit_bootstrap(x, y, **bootstrap_cfg)\n", + "\n", + "plot_hard_losses(bootstrap_results, target_hard_loss=target_hard_loss)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": "Awesome! We have perfectly matched the hard loss given by the ground truth alignment!" + } + ], + "metadata": { + "kernelspec": { + "display_name": "python3", + "language": "python", + "name": "python3" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/nbs/tutorials/mutual_information_msa_pairing.ipynb b/nbs/tutorials/mutual_information_msa_pairing.ipynb index d5be9e8..21bdb97 100644 --- a/nbs/tutorials/mutual_information_msa_pairing.ipynb +++ b/nbs/tutorials/mutual_information_msa_pairing.ipynb @@ -16,7 +16,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "# DiffPaSS – Example usage on datasets of interacting protein systems\n", + "# Pairing two protein MSAs by maximising mutual information\n", "\n", "> DiffPaSS and DiffPaSS-IPA for pairing two interacting MSAs using mutual information." ] @@ -180,9 +180,7 @@ { "cell_type": "markdown", "metadata": {}, - "source": [ - "## 3. Optimize pairings by maximising mutual information between chains: ``InformationAlignment``" - ] + "source": "## 3. Optimize pairings by maximising mutual information between chains: `InformationAlignment`" }, { "cell_type": "code",