From 20220800620bf0a6068ef95377748816524fd527 Mon Sep 17 00:00:00 2001 From: Umberto Lupo Date: Fri, 17 May 2024 10:27:44 +0200 Subject: [PATCH] [DOC] Small fixes in graph alignment tutorial --- nbs/tutorials/graph_alignment.ipynb | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/nbs/tutorials/graph_alignment.ipynb b/nbs/tutorials/graph_alignment.ipynb index eb931d9..3bd8e5c 100644 --- a/nbs/tutorials/graph_alignment.ipynb +++ b/nbs/tutorials/graph_alignment.ipynb @@ -121,6 +121,7 @@ "\n", "# Create the Laplacian matrix of the first graph\n", "x = create_test_sbm_laplacian(block_sizes, probs)\n", + "n_nodes = len(x) # In case the number of nodes has changed due to removal of isolates\n", "\n", "# Define the second graph by randomly permuting the nodes in the first\n", "P = torch.zeros_like(x)\n", @@ -146,7 +147,9 @@ "from diffpass.train import GraphAlignment, IntraGroupSimilarityLoss\n", "\n", "group_sizes = [n_nodes]\n", - "comparison_loss = IntraGroupSimilarityLoss(group_sizes=group_sizes, exclude_diagonal=False) # Default behaviour except for inclusion of the main diagonal\n", + "comparison_loss = IntraGroupSimilarityLoss(\n", + " group_sizes=group_sizes, exclude_diagonal=False\n", + ") # Default behaviour except for inclusion of the main diagonal\n", "\n", "# Initialize the GraphAlignment model\n", "graph_alignment = GraphAlignment(\n", @@ -203,8 +206,6 @@ "metadata": {}, "outputs": [], "source": [ - "from diffpass.data_utils import compute_num_correct_pairings\n", - "\n", "def plot_hard_losses(\n", " results,\n", " target_hard_loss: Optional[float] = None\n",