# %% [markdown]
# (Colab "view-in-github" badge cell.)

# %%
# Make sure PyTorch is importable; install the CUDA 11.8 wheels only if it is missing.
try:
    import torch
except ImportError:
    # `!cmd` in a notebook cell is shorthand for get_ipython().system("cmd").
    get_ipython().system(
        "pip3 install torch torchvision torchaudio "
        "--index-url https://download.pytorch.org/whl/cu118"
    )
    import torch

# %%
# Make sure PyTorch Geometric (PyG) is importable; install it only if it is missing.
try:
    import torch_geometric
except ImportError:
    get_ipython().system("pip3 install torch_geometric")
    import torch_geometric

# %%
# Put gca-rom on the import path: clone it when running on Colab,
# otherwise use the repository two directories above this notebook.
import sys

running_on_colab = 'google.colab' in str(get_ipython())
if not running_on_colab:
    sys.path.append('../..')
else:
    get_ipython().system("git clone https://github.com/fpichi/gca-rom.git")
    sys.path.append('gca-rom')

from gca_rom import (
    network,
    pde,
    loader,
    plotting,
    preprocessing,
    training,
    initialization,
    testing,
    error,
    gui,
)

# %%
import numpy as np
from itertools import product

# %% [markdown]
# # Define PDE problem
# For the description of the model and generation of the dataset look at:
# [RBniCS/tutorials/09_advection_dominated](https://github.com/RBniCS/RBniCS/blob/master/tutorials/09_advection_dominated/tutorial_advection_dominated_1_pod.ipynb)

# %%
# Select problem number 2 and build the hyper-parameter container from the GUI defaults.
(
    problem_name,
    variable,
    mu_space,
    n_param,
    dim_pde,
    n_comp,
) = pde.problem(2)
argv = gui.hyperparameters_selection(problem_name, variable, n_param, n_comp)
HyperParams = network.HyperParams(argv)
pool_rate = 0.7  # fraction of nodes to keep when pooling
# Last expression of the cell: displays the hyper-parameters.
HyperParams.__dict__

# %% [markdown]
# # Initialize device and set reproducibility

# %%
device = initialization.set_device()
initialization.set_reproducibility(HyperParams)
initialization.set_path(HyperParams)

# %% [markdown]
# # Load dataset

# %%
# The dataset lives inside the cloned repository on Colab, and two levels up locally.
dataset_root = (
    '/content/gca-rom/dataset/'
    if 'google.colab' in str(get_ipython())
    else '../../dataset/'
)
dataset_dir = dataset_root + problem_name + '_unstructured.mat'
dataset = loader.LoadDataset(dataset_dir, variable, dim_pde, n_comp)

(
    graph_loader,
    train_loader,
    test_loader,
    val_loader,
    scaler_all,
    scaler_test,
    xyz,
    VAR_all,
    VAR_test,
    train_trajectories,
    test_trajectories,
) = preprocessing.graphs_dataset(dataset, HyperParams)

# One row per parameter combination (Cartesian product of the per-parameter ranges).
params = torch.tensor(np.array(list(product(*mu_space)))).to(device)
# %% [markdown]
# # Define Pooling Network

# %%
# Draw a random subset of node indices to keep, then sort it so the retained
# nodes stay in their original relative order.
# NOTE(review): assumes VAR_all.shape[1] is the number of mesh nodes - confirm.
num_mesh_nodes = VAR_all.shape[1]
mask_size = int(num_mesh_nodes * pool_rate)
if not 0 < mask_size <= num_mesh_nodes:
    # Fail early: an empty or oversized mask only surfaces later as an
    # obscure shape error inside the network.
    raise ValueError(
        f"pool_rate={pool_rate} keeps {mask_size} of {num_mesh_nodes} nodes; "
        "it must keep at least one node and at most all of them"
    )
random_indices = torch.randperm(num_mesh_nodes)
mask = random_indices[0:mask_size]
mask, _ = torch.sort(mask)
print("Mask of size", mask.shape[0], "/", VAR_all.shape[1])

# %%
from torch_geometric.nn.unpool import knn_interpolate
from torch_geometric.utils import subgraph
from copy import deepcopy


class PooledNet(network.Net):
    """`network.Net` variant that runs on a fixed subset of the graph nodes.

    The encoder and decoder only see the nodes selected by `mask`; the decoder
    output is then interpolated back onto the full set of node positions with
    `knn_interpolate`.
    """

    def __init__(self, HyperParams, mask):
        """Build the underlying network sized for the pooled graph.

        Args:
            HyperParams: hyper-parameter container passed to `network.Net`.
                Its `num_nodes` is overridden only while the parent
                constructor runs and is restored afterwards.
            mask: 1-D integer tensor of the per-graph node indices to keep.
        """
        prev_size = HyperParams.num_nodes
        # Size the network from the mask itself rather than from the global
        # `pool_rate`, so the network and the mask can never disagree.
        HyperParams.num_nodes = int(mask.shape[0])
        try:
            network.Net.__init__(self, HyperParams)
        finally:
            # Restore the shared HyperParams even if the parent constructor fails.
            HyperParams.num_nodes = prev_size
        self.mask = mask

    def pool(self, data):
        """Return a copy of `data` restricted to the masked nodes of every graph.

        The input batch is not modified. Assumes every graph in the batch has
        the same node layout, since the same per-graph indices are offset by
        `data.ptr` for each graph.

        Args:
            data: batched graph exposing `x`, `pos`, `batch`, `ptr`,
                `edge_index`, `edge_attr` and `edge_weight`.

        Returns:
            A deep copy of `data` holding only the selected nodes and the
            edges between them, with node indices relabelled.
        """
        data = deepcopy(data)
        device = data.x.device
        nodes_per_graph = self.mask.shape[0]

        # `self.mask` is a plain attribute, so `model.to(device)` does not move
        # it; bring it to the batch's device before combining it with `ptr`.
        per_graph_mask = self.mask.to(device)
        # Broadcasting (K,) + (B, 1) -> (B, K): each graph's start offset is
        # added to the same per-graph indices.
        mask = (per_graph_mask + data.ptr[:-1].reshape(-1, 1)).flatten()

        # Pass num_nodes explicitly: otherwise it is inferred from edge_index
        # and can be too small when the highest-indexed node has no edges.
        edge_index, edge_attr, edge_mask = subgraph(
            mask,
            data.edge_index,
            data.edge_attr,
            relabel_nodes=True,
            num_nodes=data.num_nodes,
            return_edge_mask=True,
        )
        data.edge_index = edge_index
        data.edge_attr = edge_attr
        data.edge_weight = data.edge_weight[edge_mask]
        data.x = data.x[mask]
        data.batch = data.batch[mask]
        # Every pooled graph now has exactly `nodes_per_graph` nodes.
        data.ptr = torch.arange(0, mask.shape[0] + 1, nodes_per_graph, device=device)
        data.pos = data.pos[mask]
        return data

    def solo_encoder(self, data):
        """Encode the pooled version of `data` with the inherited encoder."""
        pooled_data = self.pool(data)
        x = self.encoder(pooled_data)
        return x

    def solo_decoder(self, x, data):
        """Decode `x` on the pooled graph and interpolate to all nodes of `data`.

        Args:
            x: input passed to the inherited decoder.
            data: the full (un-pooled) batch, which supplies the target
                positions and batch assignment for the interpolation.

        Returns:
            Decoder output interpolated from the pooled node positions to
            every node position in `data`.
        """
        pooled_data = self.pool(data)
        x = self.decoder(x, pooled_data)
        x = knn_interpolate(
            x=x,
            pos_x=pooled_data.pos,
            pos_y=data.pos,
            batch_x=pooled_data.batch,
            batch_y=data.batch,
        )
        return x

# %% [markdown]
# # Define the architecture

# %%
model = PooledNet(HyperParams, mask)
model = model.to(device)
if 'google.colab' in str(get_ipython()):
    # NOTE(review): this runs after the model is built, so it only affects
    # tensors created from here on - confirm that is intended.
    torch.set_default_dtype(torch.float32)
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=HyperParams.learning_rate,
    weight_decay=HyperParams.weight_decay,
)
scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer,
    milestones=HyperParams.miles,
    gamma=HyperParams.gamma,
)

# %% [markdown]
# # Train or load a pre-trained network

# %%
try:
    # map_location lets a checkpoint saved on a GPU be restored on a CPU-only
    # machine (and the other way round).
    model.load_state_dict(
        torch.load(
            HyperParams.net_dir + HyperParams.net_name + HyperParams.net_run + '.rpt',
            map_location=device,
        )
    )
    print('Loading saved network')
except FileNotFoundError:
    print('Training network')
    training.train(
        model,
        optimizer,
        device,
        scheduler,
        params,
        train_loader,
        test_loader,
        train_trajectories,
        test_trajectories,
        HyperParams,
    )
# %% [markdown]
# # Evaluate the model

# %%
# Evaluation runs on the CPU: move both the model and the parameter table there.
model.to("cpu")
params = params.to("cpu")
# Label passed to the error plotting / printing / saving helpers below.
# NOTE(review): this name shadows the builtin `vars` for the rest of the session.
vars = "GCA-ROM"
all_snapshot_ids = range(params.shape[0])
results, latents_map, latents_gca = testing.evaluate(
    VAR_all, model, graph_loader, params, HyperParams, all_snapshot_ids
)

# %% [markdown]
# # Plot the results

# %%
plotting.plot_sample(HyperParams, mu_space, params, train_trajectories, test_trajectories)
plotting.plot_loss(HyperParams)
plotting.plot_latent(HyperParams, latents_map, latents_gca)

# Both error plots take exactly the same arguments.
error_plot_args = (
    results,
    VAR_all,
    scaler_all,
    HyperParams,
    mu_space,
    params,
    train_trajectories,
    vars,
)
plotting.plot_error(*error_plot_args)
plotting.plot_error_2d(*error_plot_args)

# %%
# Show the predicted field and the error field for N randomly chosen snapshots.
N = 3
snapshots = np.arange(params.shape[0]).tolist()
np.random.shuffle(snapshots)
for SNAP in snapshots[:N]:
    plotting.plot_fields(SNAP, results, scaler_all, HyperParams, dataset, xyz, params)
    plotting.plot_error_fields(
        SNAP, results, VAR_all, scaler_all, HyperParams, dataset, xyz, params
    )

# %% [markdown]
# # Print the errors on the testing set

# %%
# Re-evaluate on the test trajectories only; the latent outputs are not needed here.
results_test, _, _ = testing.evaluate(
    VAR_test, model, val_loader, params, HyperParams, test_trajectories
)

error_abs, norm = error.compute_error(results_test, VAR_test, scaler_test, HyperParams)
error.print_error(error_abs, norm, vars)
error.save_error(error_abs, norm, HyperParams, vars)

plotting.plot_comparison_fields(
    results, VAR_all, scaler_all, HyperParams, dataset, xyz, params
)