From 05d8d96a628b647d9f3123dd40b08c3976ab9b2b Mon Sep 17 00:00:00 2001 From: bmandracchia Date: Mon, 22 Jul 2024 20:19:18 +0200 Subject: [PATCH] update nets --- nbs/01_nets.ipynb | 119 ++++++++++++++++++++++++++-------------------- 1 file changed, 68 insertions(+), 51 deletions(-) diff --git a/nbs/01_nets.ipynb b/nbs/01_nets.ipynb index a64b474..ba5121c 100755 --- a/nbs/01_nets.ipynb +++ b/nbs/01_nets.ipynb @@ -66,7 +66,16 @@ "\n", "from fastai.vision.all import ConvLayer, Lambda, MaxPool, NormType, nn, np\n", "from torch import cat as torch_cat\n", + "from torch import Tensor as torch_Tensor\n", "import torch.nn.functional as F\n", + "import torch.nn as nn\n", + "from monai.networks.blocks import Convolution\n", + "from monai.networks.layers.factories import Act, Norm, Pool\n", + "from monai.networks.nets import resnet\n", + "\n", + "from dataclasses import dataclass, field\n", + "from typing import List, Tuple, Optional, Union\n", + "\n", "from bioMONAI.core import attributesFromDict" ] }, @@ -426,24 +435,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## DeepLab v3" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from dataclasses import dataclass, field\n", - "from typing import List, Tuple, Optional, Union\n", - "import torch\n", - "import torch.nn as nn\n", - "import torch.nn.functional as F\n", - "from monai.networks.blocks import Convolution\n", - "from monai.networks.layers.factories import Act, Norm, Pool\n", - "from torchvision.models.resnet import resnet50, resnet101\n", - "from monai.networks.nets import resnet\n" + "## DeepLab v3+" ] }, { @@ -455,11 +447,10 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 14, "metadata": {}, "outputs": [], "source": [ - "\n", "@dataclass\n", "class DeeplabConfig:\n", " dimensions: int\n", @@ -475,14 +466,14 @@ "\n", "def get_padding(kernel_size: int, dilation: int) -> int:\n", " return (kernel_size - 1) * dilation // 2\n", - "\n", - "def interpolate(x: torch.Tensor, size: Union[List[int], Tuple[int, ...]], mode: str) -> torch.Tensor:\n", - " if x.dim() == 4: # 2D\n", - " return F.interpolate(x, size=size, mode=mode, align_corners=True)\n", - " elif x.dim() == 5: # 3D\n", - " return F.interpolate(x, size=size, mode=mode, align_corners=True)\n", + " \n", + "def interpolate(x: torch_Tensor, size: Union[List[int], Tuple[int, ...]], dims: int) -> torch_Tensor:\n", + " if dims == 2:\n", + " return F.interpolate(x, size=size, mode='bilinear', align_corners=True)\n", + " elif dims == 3:\n", + " return F.interpolate(x, size=size, mode='trilinear', align_corners=True)\n", " else:\n", - " raise ValueError(f\"Unsupported input dimension: {x.dim()}\")" + " raise ValueError(f\"Unsupported number of dimensions: {dims}\")" ] }, { @@ -494,7 +485,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 15, "metadata": {}, "outputs": [], "source": [ @@ -503,7 +494,7 @@ " dilation: int = 1, bias: bool = False, norm: Optional[str] = None):\n", " super().__init__()\n", " self.conv1 = Convolution(\n", - " dimensions=config.dimensions, \n", + " spatial_dims=config.dimensions, \n", " in_channels=inplanes, \n", " out_channels=inplanes, \n", " kernel_size=kernel_size,\n", @@ -514,7 +505,7 @@ " strides=stride\n", " )\n", " self.pointwise = Convolution(\n", - " dimensions=config.dimensions, \n", + " spatial_dims=config.dimensions, \n", " in_channels=inplanes, \n", " out_channels=planes, \n", " kernel_size=1, \n", @@ -526,7 +517,7 @@ " norm=Norm.BATCH if norm else None\n", " )\n", "\n", - " def forward(self, x: torch.Tensor) -> torch.Tensor:\n", + " def forward(self, x: torch_Tensor) -> torch_Tensor:\n", " x = self.conv1(x)\n", " x = self.pointwise(x)\n", " return x\n", @@ -570,7 +561,7 @@ "\n", " self.rep = nn.Sequential(*rep)\n", "\n", - " def forward(self, inp: torch.Tensor) -> torch.Tensor:\n", + " def forward(self, inp: torch_Tensor) -> torch_Tensor:\n", " x = self.rep(inp)\n", " if self.skip is not None:\n", " skip = self.skip(inp)\n", @@ -589,7 +580,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 16, "metadata": {}, "outputs": [], "source": [ @@ -624,7 +615,7 @@ " self.conv4 = SeparableConv(config, 1536, 1536, 3, stride=1, dilation=config.exit_block_dilations[1], norm=Norm.BATCH)\n", " self.conv5 = SeparableConv(config, 1536, 2048, 3, stride=1, dilation=config.exit_block_dilations[1], norm=Norm.BATCH)\n", "\n", - " def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:\n", + " def forward(self, x: torch_Tensor) -> Tuple[torch_Tensor, torch_Tensor]:\n", " # Entry flow\n", " x = self.conv1(x)\n", " x = self.relu(x)\n", @@ -660,7 +651,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 17, "metadata": {}, "outputs": [], "source": [ @@ -674,7 +665,7 @@ " bias=False, norm=Norm.BATCH)\n", " self.relu = nn.ReLU()\n", "\n", - " def forward(self, x: torch.Tensor) -> torch.Tensor:\n", + " def forward(self, x: torch_Tensor) -> torch_Tensor:\n", " x = self.atrous_convolution(x)\n", " return self.relu(x)" ] @@ -688,7 +679,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 18, "metadata": {}, "outputs": [], "source": [ @@ -737,7 +728,7 @@ " Convolution(config.dimensions, 256, config.out_channels, kernel_size=1, strides=1)\n", " )\n", "\n", - " def forward(self, input: torch.Tensor) -> torch.Tensor:\n", + " def forward(self, input: torch.Tensor) -> torch.Tensor:\n", " if self.config.backbone == \"xception\":\n", " x, low_level_features = self.backbone(input)\n", " else:\n", @@ -753,17 +744,17 @@ "\n", " aspp_results = [module(x) for module in self.aspp_modules]\n", " x5 = self.global_avg_pool(x)\n", - " x5 = interpolate(x5, size=x.shape[2:], mode='linear' if self.config.dimensions == 2 else 'trilinear')\n", + " x5 = interpolate(x5, size=x.shape[2:], dims=self.config.dimensions)\n", " x = torch.cat(aspp_results + [x5], dim=1)\n", "\n", " x = self.conv1(x)\n", - " x = interpolate(x, size=low_level_features.shape[2:], mode='linear' if self.config.dimensions == 2 else 'trilinear')\n", + " x = interpolate(x, size=low_level_features.shape[2:], dims=self.config.dimensions)\n", "\n", " low_level_features = self.conv2(low_level_features)\n", "\n", " x = torch.cat((x, low_level_features), dim=1)\n", " x = self.last_conv(x)\n", - " x = interpolate(x, size=input.shape[2:], mode='linear' if self.config.dimensions == 2 else 'trilinear')\n", + " x = interpolate(x, size=input.shape[2:], dims=self.config.dimensions)\n", "\n", " return x" ] @@ -777,7 +768,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 19, "metadata": {}, "outputs": [], "source": [ @@ -786,20 +777,46 @@ " dimensions=2,\n", " in_channels=3, # For RGB images\n", " out_channels=4,\n", - " middle_flow_blocks=16,\n", + " backbone=\"xception\", # or whatever backbone you're using\n", " aspp_dilations=[1, 6, 12, 18]\n", ")\n", "model_2d = Deeplab(config_2d)\n", "\n", "# For 3D images\n", - "config_3d = DeeplabConfig(\n", - " dimensions=3,\n", - " in_channels=1, # For single-channel 3D medical images\n", - " out_channels=4,\n", - " middle_flow_blocks=16,\n", - " aspp_dilations=[1, 6, 12, 18]\n", - ")\n", - "model_3d = Deeplab(config_3d)" + "# config_3d = DeeplabConfig(\n", + "# dimensions=3,\n", + "# in_channels=1, # For single-channel 3D medical images\n", + "# out_channels=4,\n", + "# middle_flow_blocks=16,\n", + "# aspp_dilations=[1, 6, 12, 18]\n", + "# )\n", + "# model_3d = Deeplab(config_3d)" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [ + { + "ename": "NotImplementedError", + "evalue": "Got 4D input, but linear mode needs 3D input", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mNotImplementedError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[23], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m x \u001b[38;5;241m=\u001b[39m torch_randn(\u001b[38;5;241m16\u001b[39m,\u001b[38;5;241m3\u001b[39m, \u001b[38;5;241m64\u001b[39m, \u001b[38;5;241m64\u001b[39m)\n\u001b[0;32m----> 2\u001b[0m y \u001b[38;5;241m=\u001b[39m model_2d(x)\n", + "File \u001b[0;32m~/miniconda3/envs/bioMONAI-env/lib/python3.11/site-packages/torch/nn/modules/module.py:1501\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1496\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1497\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1498\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1499\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1500\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1501\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m forward_call(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[1;32m 1502\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1503\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n", + "Cell \u001b[0;32mIn[18], line 62\u001b[0m, in \u001b[0;36mDeeplab.forward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 60\u001b[0m aspp_results \u001b[38;5;241m=\u001b[39m [module(x) \u001b[38;5;28;01mfor\u001b[39;00m module \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39maspp_modules]\n\u001b[1;32m 61\u001b[0m x5 \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mglobal_avg_pool(x)\n\u001b[0;32m---> 62\u001b[0m x5 \u001b[38;5;241m=\u001b[39m interpolate(x5, size\u001b[38;5;241m=\u001b[39mx\u001b[38;5;241m.\u001b[39mshape[\u001b[38;5;241m2\u001b[39m:], mode\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mlinear\u001b[39m\u001b[38;5;124m'\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mconfig\u001b[38;5;241m.\u001b[39mdimensions \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m2\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mtrilinear\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[1;32m 63\u001b[0m x \u001b[38;5;241m=\u001b[39m torch_cat(aspp_results \u001b[38;5;241m+\u001b[39m [x5], dim\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1\u001b[39m)\n\u001b[1;32m 65\u001b[0m x \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mconv1(x)\n", + "Cell \u001b[0;32mIn[14], line 19\u001b[0m, in \u001b[0;36minterpolate\u001b[0;34m(x, size, mode)\u001b[0m\n\u001b[1;32m 17\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21minterpolate\u001b[39m(x: torch_Tensor, size: Union[List[\u001b[38;5;28mint\u001b[39m], Tuple[\u001b[38;5;28mint\u001b[39m, \u001b[38;5;241m.\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;241m.\u001b[39m]], mode: \u001b[38;5;28mstr\u001b[39m) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m torch_Tensor:\n\u001b[1;32m 18\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m x\u001b[38;5;241m.\u001b[39mdim() \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m4\u001b[39m: \u001b[38;5;66;03m# 2D\u001b[39;00m\n\u001b[0;32m---> 19\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m F\u001b[38;5;241m.\u001b[39minterpolate(x, size\u001b[38;5;241m=\u001b[39msize, mode\u001b[38;5;241m=\u001b[39mmode, align_corners\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n\u001b[1;32m 20\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m x\u001b[38;5;241m.\u001b[39mdim() \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m5\u001b[39m: \u001b[38;5;66;03m# 3D\u001b[39;00m\n\u001b[1;32m 21\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m F\u001b[38;5;241m.\u001b[39minterpolate(x, size\u001b[38;5;241m=\u001b[39msize, mode\u001b[38;5;241m=\u001b[39mmode, align_corners\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n", + "File \u001b[0;32m~/miniconda3/envs/bioMONAI-env/lib/python3.11/site-packages/torch/nn/functional.py:3974\u001b[0m, in \u001b[0;36minterpolate\u001b[0;34m(input, size, scale_factor, mode, align_corners, recompute_scale_factor, antialias)\u001b[0m\n\u001b[1;32m 3972\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mNotImplementedError\u001b[39;00m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mGot 3D input, but trilinear mode needs 5D input\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 3973\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28minput\u001b[39m\u001b[38;5;241m.\u001b[39mdim() \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m4\u001b[39m \u001b[38;5;129;01mand\u001b[39;00m mode \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mlinear\u001b[39m\u001b[38;5;124m\"\u001b[39m:\n\u001b[0;32m-> 3974\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mNotImplementedError\u001b[39;00m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mGot 4D input, but linear mode needs 3D input\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 3975\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28minput\u001b[39m\u001b[38;5;241m.\u001b[39mdim() \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m4\u001b[39m \u001b[38;5;129;01mand\u001b[39;00m mode \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtrilinear\u001b[39m\u001b[38;5;124m\"\u001b[39m:\n\u001b[1;32m 3976\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mNotImplementedError\u001b[39;00m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mGot 4D input, but trilinear mode needs 5D input\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", + "\u001b[0;31mNotImplementedError\u001b[0m: Got 4D input, but linear mode needs 3D input" + ] + } + ], + "source": [ + "x = torch_randn(16,3, 64, 64)\n", + "y = model_2d(x)" ] }, {