From 60e62ff7a6b9194d6206f66345a870da80c9dfbd Mon Sep 17 00:00:00 2001 From: priyansi Date: Sat, 20 Nov 2021 10:12:54 +0530 Subject: [PATCH] added nb --- .../12-train-multi-output-model.ipynb | 410 ++++++++++++++++++ 1 file changed, 410 insertions(+) create mode 100644 how-to-guides/12-train-multi-output-model.ipynb diff --git a/how-to-guides/12-train-multi-output-model.ipynb b/how-to-guides/12-train-multi-output-model.ipynb new file mode 100644 index 0000000..da4ab0a --- /dev/null +++ b/how-to-guides/12-train-multi-output-model.ipynb @@ -0,0 +1,410 @@ +{ + "nbformat": 4, + "nbformat_minor": 5, + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.8" + }, + "colab": { + "name": "12-train-multi-output-model.ipynb", + "provenance": [], + "collapsed_sections": [] + }, + "accelerator": "GPU" + }, + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "HgljXEAJEcFq" + }, + "source": [ + "\n", + "# How to train a multi output model " + ], + "id": "HgljXEAJEcFq" + }, + { + "cell_type": "markdown", + "metadata": { + "id": "hQPlnvmfEfir" + }, + "source": [ + "## Required Dependencies" + ], + "id": "hQPlnvmfEfir" + }, + { + "cell_type": "code", + "metadata": { + "id": "3d09Jnzt_qmp" + }, + "source": [ + "!pip install pytorch-ignite -q" + ], + "id": "3d09Jnzt_qmp", + "execution_count": 32, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "DcnSr5sGEcFz" + }, + "source": [ + "## Imports" + ], + "id": "DcnSr5sGEcFz" + }, + { + "cell_type": "code", + "metadata": { + "pycharm": { + "is_executing": false + }, + "id": "Y0sJP9iFa1TB" + }, + "source": [ + "import torch\n", + "from torch import nn\n", + "from torch.utils.data import 
Dataset, DataLoader\n", + "\n", + "from ignite.engine import Engine, Events, create_supervised_trainer, create_supervised_evaluator\n", + "from ignite.metrics import MeanAbsoluteError" + ], + "id": "Y0sJP9iFa1TB", + "execution_count": 33, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "NFzQCnLMOSXX" + }, + "source": [ + "## Data\n", + "\n", + "For our dataset, we will create a custom dataset which takes a range and calculates two labels via the `power2()` and `mul2()` functions." + ], + "id": "NFzQCnLMOSXX" + }, + { + "cell_type": "code", + "metadata": { + "id": "RQARkoCcNO1o" + }, + "source": [ + "def power2(x):\n", + " return x**2\n", + "\n", + "def mul2(x):\n", + " return 2*x\n", + "\n", + "class RangeDataset(Dataset):\n", + " def __init__(self, start, end):\n", + " self.start = start\n", + " self.end = end\n", + "\n", + " def __len__(self):\n", + " return self.end - self.start + 1\n", + "\n", + " def __getitem__(self, idx):\n", + " if torch.is_tensor(idx):\n", + " idx = idx.tolist()\n", + "\n", + " item = self.start + idx\n", + " y1 = power2(item)\n", + " y2 = mul2(item)\n", + " sample = [torch.Tensor([item]), torch.Tensor([y1]), torch.Tensor([y2])]\n", + " return sample\n", + "\n", + "train_dataset = RangeDataset(1, 80)\n", + "val_dataset = RangeDataset(90, 100)\n", + "\n", + "train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)\n", + "val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False)" + ], + "id": "RQARkoCcNO1o", + "execution_count": 34, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Gynzy2VWOpRL" + }, + "source": [ + "## Model\n", + "\n", + "Let's create a dummy `Net()`. We have two linear layers for predicting the two labels which we can simply return from `forward()`." 
+ ], + "id": "Gynzy2VWOpRL" + }, + { + "cell_type": "code", + "metadata": { + "id": "iK_9cOP6a1TI" + }, + "source": [ + "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", + "\n", + "class Net(nn.Module):\n", + "\n", + " def __init__(self):\n", + " super(Net, self).__init__()\n", + " \n", + " self.model1 = nn.Linear(1, 1)\n", + " self.model2 = nn.Linear(1, 1)\n", + "\n", + " def forward(self, x):\n", + " return self.model1(x), self.model2(x)\n", + "\n", + "model = Net().to(device)\n", + "optimizer = torch.optim.Adam(model.parameters(), lr=0.005)\n", + "criterion = nn.MSELoss()" + ], + "id": "iK_9cOP6a1TI", + "execution_count": 35, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ODT-jag3O8B1" + }, + "source": [ + "## Trainer\n", + "\n", + "Now we will create a custom `trainer` which unpacks a batch, makes a prediction, calculates the two losses and returns their sum." + ], + "id": "ODT-jag3O8B1" + }, + { + "cell_type": "code", + "metadata": { + "id": "ItoswUK-23St" + }, + "source": [ + "def train_step(engine, batch):\n", + " model.train()\n", + " optimizer.zero_grad()\n", + " x, y1, y2 = batch[0].to(device), batch[1].to(device), batch[2].to(device)\n", + " y_pred1, y_pred2 = model(x)\n", + " loss1 = criterion(y_pred1, y1)\n", + " loss2 = criterion(y_pred2, y2)\n", + "\n", + " loss1.backward()\n", + " loss2.backward()\n", + " optimizer.step()\n", + " return loss1.item() + loss2.item()\n", + "\n", + "trainer = Engine(train_step)" + ], + "id": "ItoswUK-23St", + "execution_count": 42, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "iHo0ET3FPYN-" + }, + "source": [ + "## Evaluator\n", + "\n", + "For the evaluation step, we will leverage [`create_supervised_evaluator()`](https://pytorch.org/ignite/generated/ignite.engine.create_supervised_evaluator.html#create-supervised-evaluator) to adapt to our use case. \n", + "\n", + "At first, we need to prepare our batch since we have two labels. 
The `prepare_batch()` method receives the necessary parameters and must return a tuple of inputs and targets `(x, y)`. Here we have adapted `y` to be another tuple containing the two labels." + ], + "id": "iHo0ET3FPYN-" + }, + { + "cell_type": "code", + "metadata": { + "id": "MwE7mXSbNbro" + }, + "source": [ + "def prepare_batch(batch, device, non_blocking):\n", + " x, y1, y2 = batch[0].to(device), batch[1].to(device), batch[2].to(device)\n", + " return (x, (y1, y2))\n", + "\n", + "val_evaluator = create_supervised_evaluator(model, prepare_batch=prepare_batch, device=device)" + ], + "id": "MwE7mXSbNbro", + "execution_count": 43, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ZedFZp-BfKet" + }, + "source": [ + "The above evaluator translates into:\n", + "```python\n", + "def validation_step(engine, batch):\n", + " model.eval()\n", + " with torch.no_grad():\n", + " x, y = prepare_batch(...)\n", + " y_pred = model(x)\n", + " return y_pred, y\n", + "\n", + "val_evaluator = Engine(validation_step)\n", + "```" + ], + "id": "ZedFZp-BfKet" + }, + { + "cell_type": "markdown", + "metadata": { + "id": "NYfhTF5FoAv2" + }, + "source": [ + "You can also pass an `output_transform()` to `val_evaluator` for more complex needs as follows:\n", + "```python\n", + "def output_transform(x, y, y_pred):\n", + " return {'y_pred': y_pred, 'y': y}\n", + "```" + ], + "id": "NYfhTF5FoAv2" + }, + { + "cell_type": "markdown", + "metadata": { + "id": "S28hIDeUR-LR" + }, + "source": [ + "## Metrics\n", + "\n", + "Since our dummy dataset is better suited for a regression task, let's use [`MeanAbsoluteError()`](https://pytorch.org/ignite/generated/ignite.metrics.MeanAbsoluteError.html#meanabsoluteerror) as our metric. However, we cannot directly use it and need to pass a transform method to further preprocess the predictions and labels returned. 
Below, we are simply adding the predictions and labels respectively to return our final `y_pred` and `y` upon which the metric will be calculated." + ], + "id": "S28hIDeUR-LR" + }, + { + "cell_type": "code", + "metadata": { + "id": "LoL5OF8MNhW1" + }, + "source": [ + "def mae_output_transform(output):\n", + " y_pred1, y_pred2 = output[0]\n", + " y1, y2 = output[1]\n", + " y_pred = y_pred1 + y_pred2\n", + " y = y1 + y2\n", + " return y_pred, y\n", + "\n", + "val_metrics = {\n", + " 'mae': MeanAbsoluteError(mae_output_transform),\n", + "}\n", + "\n", + "for name, metric in val_metrics.items():\n", + " metric.attach(val_evaluator, name)" + ], + "id": "LoL5OF8MNhW1", + "execution_count": 44, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "2HEoHRkio6vS" + }, + "source": [ + "Finally, let's log our results and start training." + ], + "id": "2HEoHRkio6vS" + }, + { + "cell_type": "code", + "metadata": { + "id": "7_Y_U0cCEp4W" + }, + "source": [ + "@trainer.on(Events.EPOCH_COMPLETED)\n", + "def log_validation_results(trainer):\n", + " val_evaluator.run(val_loader)\n", + " metrics = val_evaluator.state.metrics\n", + " print(f\"Validation Results - Epoch[{trainer.state.epoch}] Avg MAE: {metrics['mae']:.2f}\")" + ], + "id": "7_Y_U0cCEp4W", + "execution_count": 45, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "0tkgJU9SE24S", + "outputId": "edc26afc-fab2-44e2-9c39-fb99bf751396" + }, + "source": [ + "trainer.run(train_loader, max_epochs=1)" + ], + "id": "0tkgJU9SE24S", + "execution_count": 46, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Validation Results - Epoch[1] Avg MAE: 9322.08\n" + ] + }, + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "State:\n", + "\titeration: 20\n", + "\tepoch: 1\n", + "\tepoch_length: 20\n", + "\tmax_epochs: 1\n", + "\toutput: 10597183.16796875\n", + "\tbatch: \n", + "\tmetrics: 
\n", + "\tdataloader: \n", + "\tseed: \n", + "\ttimes: " + ] + }, + "metadata": {}, + "execution_count": 46 + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "1N6ndAQ6de4E" + }, + "source": [ + "Although the MAE for our dummy model is really large, this code can easily be adapted to fit multi-output models and calculate their respective metrics!" + ], + "id": "1N6ndAQ6de4E" + } + ] +} \ No newline at end of file