From 1b1b545e42519d3978972a7dff164ff8d6de30b1 Mon Sep 17 00:00:00 2001 From: bmandracchia Date: Wed, 17 Jul 2024 18:31:11 +0200 Subject: [PATCH] update deeplab --- nbs/01_nets.ipynb | 541 ++++++++++++++++++++++------------------------ 1 file changed, 264 insertions(+), 277 deletions(-) diff --git a/nbs/01_nets.ipynb b/nbs/01_nets.ipynb index 64fab4b..a64b474 100755 --- a/nbs/01_nets.ipynb +++ b/nbs/01_nets.ipynb @@ -70,6 +70,26 @@ "from bioMONAI.core import attributesFromDict" ] }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "cpu\n" + ] + } + ], + "source": [ + "from torch import device as torch_device\n", + "from torch.cuda import is_available as cuda_is_available\n", + "device = torch_device(\"cuda\" if cuda_is_available() else \"cpu\")\n", + "print(device)" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -79,7 +99,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 7, "metadata": {}, "outputs": [], "source": [ @@ -107,7 +127,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 8, "metadata": {}, "outputs": [], "source": [ @@ -133,7 +153,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 9, "metadata": {}, "outputs": [], "source": [ @@ -171,7 +191,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 10, "metadata": {}, "outputs": [ { @@ -225,7 +245,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 11, "metadata": {}, "outputs": [], "source": [ @@ -300,7 +320,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 12, "metadata": {}, "outputs": [], "source": [ @@ -415,99 +435,22 @@ "metadata": {}, "outputs": [], "source": [ - "# Import necessary libraries\n", + "from dataclasses import dataclass, field\n", + "from typing import List, Tuple, Optional, Union\n", "import torch\n", "import torch.nn as nn\n", - "import torch.optim as optim\n", - "from torchvision import models\n", - "\n", - "class DeepLabV3Plus(nn.Module):\n", - " def __init__(self, num_classes=21, input_size=(256, 256)):\n", - " super(DeepLabV3Plus, self).__init__()\n", - " # Store the user-defined input size for later use\n", - " self.input_size = input_size\n", - " \n", - " # Load the pre-trained ResNet50 model\n", - " self.backbone = models.resnet50(pretrained=True)\n", - " \n", - " # Replace the last layer of the backbone to match the number of classes\n", - " num_features = self.backbone.fc.in_features\n", - " self.backbone.fc = nn.Identity()\n", - " \n", - " # ASPP (Atrous Spatial Pyramid Pooling)\n", - " self.aspp = nn.Sequential(\n", - " nn.Conv2d(num_features, 256, kernel_size=1, stride=1),\n", - " nn.ReLU(),\n", - " nn.BatchNorm2d(256),\n", - " ASPPPooling(num_features, 256),\n", - " nn.Conv2d(256, num_classes, kernel_size=1, stride=1)\n", - " )\n", - " \n", - " # Decoder (upsampling + convolution)\n", - " self.decoder = nn.Sequential(\n", - " nn.Conv2d(num_features + 256, 256, kernel_size=3, padding=1),\n", - " nn.ReLU(),\n", - " nn.BatchNorm2d(256),\n", - " nn.Conv2d(256, num_classes, kernel_size=3, padding=1)\n", - " )\n", - " \n", - " def forward(self, x):\n", - " # Ensure the input size matches the defined input size\n", - " if x.shape[2:] != self.input_size:\n", - " raise ValueError(\"Input size must be {}x{}\".format(*self.input_size))\n", - " \n", - " # Extract features from the backbone\n", - " feat = self.backbone.features(x)\n", - " \n", - " # Apply ASPP to get high-level feature maps\n", - " aspp_out = self.aspp(feat[-1])\n", - " \n", - " # Upsample and concatenate with low-level features\n", - " x = F.interpolate(x, size=feat[-1].size()[2:], mode='bilinear', align_corners=False)\n", - " concat_feat = torch.cat([x, aspp_out], dim=1)\n", - "\n", - " # Apply decoder to get the final segmentation output\n", - " out = self.decoder(concat_feat)\n", - " \n", - " return out\n", - "\n", - "class ASPPPooling(nn.Module):\n", - " def __init__(self, in_channels, out_channels):\n", - " super(ASPPPooling, self).__init__()\n", - " self.gap = nn.Sequential(\n", - " nn.AdaptiveAvgPool2d(1),\n", - " nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1)\n", - " )\n", - " \n", - " def forward(self, x):\n", - " pooled = self.gap(x)\n", - " return F.interpolate(pooled, size=x.size()[2:], mode='bilinear', align_corners=False)\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Example usage\n", - "# Create the model and move it to GPU if available\n", - "model = DeepLabV3Plus(num_classes=21, input_size=(512, 512)).cuda()\n", - "\n", - "# Define a dummy input tensor (e.g., batch size of 1, RGB images with defined resolution)\n", - "x = torch.randn(1, 3, 512, 512).cuda()\n", - "\n", - "# Forward pass through the model\n", - "output = model(x)\n", - "print(output.shape) # Should be (1, num_classes, defined resolution)\n", - "\n" + "import torch.nn.functional as F\n", + "from monai.networks.blocks import Convolution\n", + "from monai.networks.layers.factories import Act, Norm, Pool\n", + "from torchvision.models.resnet import resnet50, resnet101\n", + "from monai.networks.nets import resnet\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "---" + "### Config" ] }, { @@ -516,41 +459,37 @@ "metadata": {}, "outputs": [], "source": [ - "from typing import Sequence, Union\n", "\n", - "import torch\n", - "import torch.nn as nn\n", - "import torch.nn.functional as F\n", - "from monai.networks.blocks import Convolution, UpSample\n", - "from monai.networks.layers.factories import Act, Conv, Dropout, Norm, Pool\n", - "from monai.utils import ensure_tuple_rep\n", - "import math\n", - "from typing import Optional, Sequence, Type, Union\n" + "@dataclass\n", + "class DeeplabConfig:\n", + " dimensions: int\n", + " in_channels: int\n", + " out_channels: int\n", + " backbone: str = \"xception\" \n", + " pretrained: bool = False\n", + " middle_flow_blocks: int = 16\n", + " aspp_dilations: List[int] = field(default_factory=lambda: [1, 6, 12, 18])\n", + " entry_block3_stride: int = 2\n", + " middle_block_dilation: int = 1\n", + " exit_block_dilations: Tuple[int, int] = (1, 2)\n", + "\n", + "def get_padding(kernel_size: int, dilation: int) -> int:\n", + " return (kernel_size - 1) * dilation // 2\n", + "\n", + "def interpolate(x: torch.Tensor, size: Union[List[int], Tuple[int, ...]], mode: str) -> torch.Tensor:\n", + " if x.dim() == 4: # 2D\n", + " return F.interpolate(x, size=size, mode=mode, align_corners=True)\n", + " elif x.dim() == 5: # 3D\n", + " return F.interpolate(x, size=size, mode=mode, align_corners=True)\n", + " else:\n", + " raise ValueError(f\"Unsupported input dimension: {x.dim()}\")" ] }, { - "cell_type": "code", - "execution_count": null, + "cell_type": "markdown", "metadata": {}, - "outputs": [], "source": [ - "\n", - "# Function to apply fixed padding to inputs for convolutional layers\n", - "def fixed_padding(inputs, kernel_size, dilation):\n", - " # Calculate the effective kernel size considering dilation\n", - " kernel_size_effective = kernel_size + (kernel_size - 1) * (dilation - 1)\n", - " \n", - " # Calculate the total padding required on each side of the input\n", - " pad_total = kernel_size_effective - 1\n", - " \n", - " # Split the total padding equally between the beginning and end\n", - " pad_beg = pad_total // 2\n", - " pad_end = pad_total - pad_beg\n", - " \n", - " # Apply padding to the inputs using PyTorch's F.pad function\n", - " padded_inputs = F.pad(inputs, (pad_beg, pad_end, pad_beg, pad_end, pad_beg, pad_end))\n", - " \n", - " return padded_inputs\n" + "### Blocks" ] }, { @@ -559,83 +498,93 @@ "metadata": {}, "outputs": [], "source": [ - "class SeparableConv2d_same(nn.Module):\n", - " def __init__(self, inplanes, planes, kernel_size=3, stride=1, dilation=1, bias=False,norm=None):\n", - " super(SeparableConv2d_same, self).__init__()\n", - " dim = 3\n", - " self.kernel_size = kernel_size\n", - " self.dilation = dilation\n", - " self.conv1 = Convolution(dim, inplanes, inplanes, kernel_size=kernel_size,\n", - " groups=inplanes, padding=0, dilation=dilation, bias=bias, strides=stride)\n", - " if norm == None:\n", - " self.pointwise = Convolution(dim, inplanes, planes, kernel_size=1, strides=1,\n", - " padding=0, dilation=1, groups=1, bias=bias)\n", - " else:\n", - " self.pointwise = Convolution(dim, inplanes, planes, kernel_size=1, strides=1,\n", - " padding=0, dilation=1, groups=1, bias=bias,norm=Norm.BATCH)\n", + "class SeparableConv(nn.Module):\n", + " def __init__(self, config: DeeplabConfig, inplanes: int, planes: int, kernel_size: int = 3, stride: int = 1, \n", + " dilation: int = 1, bias: bool = False, norm: Optional[str] = None):\n", + " super().__init__()\n", + " self.conv1 = Convolution(\n", + " dimensions=config.dimensions, \n", + " in_channels=inplanes, \n", + " out_channels=inplanes, \n", + " kernel_size=kernel_size,\n", + " groups=inplanes, \n", + " padding=get_padding(kernel_size, dilation), \n", + " dilation=dilation, \n", + " bias=bias, \n", + " strides=stride\n", + " )\n", + " self.pointwise = Convolution(\n", + " dimensions=config.dimensions, \n", + " in_channels=inplanes, \n", + " out_channels=planes, \n", + " kernel_size=1, \n", + " strides=1,\n", + " padding=0, \n", + " dilation=1, \n", + " groups=1, \n", + " bias=bias,\n", + " norm=Norm.BATCH if norm else None\n", + " )\n", "\n", - " def forward(self, x):\n", - " x = fixed_padding(x, self.kernel_size, self.dilation)\n", + " def forward(self, x: torch.Tensor) -> torch.Tensor:\n", " x = self.conv1(x)\n", " x = self.pointwise(x)\n", - " return x\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ + " return x\n", + "\n", "class Block(nn.Module):\n", - " def __init__(self, inplanes, planes, reps, stride=1, dilation=1, start_with_relu=True, grow_first=True, is_last=False):\n", - " super(Block, self).__init__()\n", - " dim = 3\n", + " def __init__(self, config: DeeplabConfig, inplanes: int, planes: int, reps: int, stride: int = 1, \n", + " dilation: int = 1, start_with_relu: bool = True, grow_first: bool = True, \n", + " is_last: bool = False):\n", + " super().__init__()\n", " if planes != inplanes or stride != 1:\n", - " self.skip = Convolution(dim, inplanes, planes, kernel_size=1, bias=False, strides=stride,norm=Norm.BATCH)\n", + " self.skip = Convolution(config.dimensions, inplanes, planes, kernel_size=1, bias=False, \n", + " strides=stride, norm=Norm.BATCH)\n", " else:\n", " self.skip = None\n", "\n", - " self.relu = Act[Act.RELU](inplace=True)\n", + " self.relu = nn.ReLU(inplace=True)\n", " rep = []\n", "\n", " filters = inplanes\n", " if grow_first:\n", " rep.append(self.relu)\n", - " rep.append(SeparableConv2d_same(inplanes, planes, 3, stride=1, dilation=dilation,norm=Norm.BATCH))\n", + " rep.append(SeparableConv(config, inplanes, planes, 3, stride=1, dilation=dilation, norm=Norm.BATCH))\n", " filters = planes\n", "\n", - " for i in range(reps - 1):\n", + " for _ in range(reps - 1):\n", " rep.append(self.relu)\n", - " rep.append(SeparableConv2d_same(filters, filters, 3, stride=1, dilation=dilation,norm=Norm.BATCH))\n", + " rep.append(SeparableConv(config, filters, filters, 3, stride=1, dilation=dilation, norm=Norm.BATCH))\n", "\n", " if not grow_first:\n", " rep.append(self.relu)\n", - " rep.append(SeparableConv2d_same(inplanes, planes, 3, stride=1, dilation=dilation,norm=Norm.BATCH))\n", + " rep.append(SeparableConv(config, inplanes, planes, 3, stride=1, dilation=dilation, norm=Norm.BATCH))\n", "\n", " if not start_with_relu:\n", " rep = rep[1:]\n", "\n", " if stride != 1:\n", - " rep.append(SeparableConv2d_same(planes, planes, 3, stride=2))\n", + " rep.append(SeparableConv(config, planes, planes, 3, stride=2))\n", "\n", " if stride == 1 and is_last:\n", - " rep.append(SeparableConv2d_same(planes, planes, 3, stride=1))\n", + " rep.append(SeparableConv(config, planes, planes, 3, stride=1))\n", "\n", " self.rep = nn.Sequential(*rep)\n", "\n", - " def forward(self, inp):\n", + " def forward(self, inp: torch.Tensor) -> torch.Tensor:\n", " x = self.rep(inp)\n", - "\n", " if self.skip is not None:\n", " skip = self.skip(inp)\n", " else:\n", " skip = inp\n", - "\n", " x += skip\n", - "\n", - " return x\n", - "\n" + " return x" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Aligned Xception" ] }, { @@ -644,59 +593,41 @@ "metadata": {}, "outputs": [], "source": [ - "\n", "class Xception(nn.Module):\n", - " \"\"\"\n", - " Modified Aligned Xception\n", - " \"\"\"\n", - "\n", - " def __init__(\n", - " self,\n", - " dim: int = 3,\n", - " in_chns: int = 1,\n", - " out_chns: int = 4,\n", - " ):\n", - "\n", - " super(Xception, self).__init__()\n", - "\n", - " entry_block3_stride = 2\n", - " middle_block_dilation = 1\n", - " exit_block_dilations = (1, 2)\n", - "\n", - " # entry flow\n", - " self.conv1 = Convolution(dim, in_chns, 32, kernel_size=3,bias=False, strides=2, padding=1,norm=Norm.BATCH)\n", - " self.relu = Act[Act.RELU](inplace=True)\n", + " def __init__(self, config: DeeplabConfig):\n", + " super().__init__()\n", + " self.config = config\n", "\n", - " self.conv2 = Convolution(dim, 32, 64, kernel_size=3,bias=False, strides=1, padding=1,norm=Norm.BATCH)\n", + " self.conv1 = Convolution(config.dimensions, config.in_channels, 32, kernel_size=3,\n", + " bias=False, strides=2, padding=1, norm=Norm.BATCH)\n", + " self.relu = nn.ReLU(inplace=True)\n", + " self.conv2 = Convolution(config.dimensions, 32, 64, kernel_size=3,\n", + " bias=False, strides=1, padding=1, norm=Norm.BATCH)\n", "\n", - " self.block1 = Block(64, 128, reps=2, stride=2, start_with_relu=False)\n", - " self.block2 = Block(128, 256, reps=2, stride=2, start_with_relu=True, grow_first=True)\n", - " self.block3 = Block(256, 728, reps=2, stride=entry_block3_stride, start_with_relu=True, grow_first=True,\n", - " is_last=True)\n", + " self.block1 = Block(config, 64, 128, reps=2, stride=2, start_with_relu=False)\n", + " self.block2 = Block(config, 128, 256, reps=2, stride=2, start_with_relu=True, grow_first=True)\n", + " self.block3 = Block(config, 256, 728, reps=2, stride=config.entry_block3_stride, \n", + " start_with_relu=True, grow_first=True, is_last=True)\n", "\n", " # Middle flow\n", - " self.middle_flow = nn.Sequential(\n", - " *[Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, start_with_relu=True, grow_first=True)\n", - " for _ in range(16)]\n", - " )\n", + " self.middle_flow = nn.Sequential(*[\n", + " Block(config, 728, 728, reps=3, stride=1, dilation=config.middle_block_dilation,\n", + " start_with_relu=True, grow_first=True)\n", + " for _ in range(config.middle_flow_blocks)\n", + " ])\n", "\n", " # Exit flow\n", - " self.block20 = Block(728, 1024, reps=2, stride=1, dilation=exit_block_dilations[0],\n", + " self.exit_block = Block(config, 728, 1024, reps=2, stride=1, dilation=config.exit_block_dilations[0],\n", " start_with_relu=True, grow_first=False, is_last=True)\n", "\n", - " self.conv3 = SeparableConv2d_same(1024, 1536, 3, stride=1, dilation=exit_block_dilations[1],norm=Norm.BATCH)\n", + " self.conv3 = SeparableConv(config, 1024, 1536, 3, stride=1, dilation=config.exit_block_dilations[1], norm=Norm.BATCH)\n", + " self.conv4 = SeparableConv(config, 1536, 1536, 3, stride=1, dilation=config.exit_block_dilations[1], norm=Norm.BATCH)\n", + " self.conv5 = SeparableConv(config, 1536, 2048, 3, stride=1, dilation=config.exit_block_dilations[1], norm=Norm.BATCH)\n", "\n", - " self.conv4 = SeparableConv2d_same(1536, 1536, 3, stride=1, dilation=exit_block_dilations[1],norm=Norm.BATCH)\n", - "\n", - " self.conv5 = SeparableConv2d_same(1536, 2048, 3, stride=1, dilation=exit_block_dilations[1],norm=Norm.BATCH)\n", - "\n", - "\n", - "\n", - " def forward(self, x):\n", + " def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:\n", " # Entry flow\n", " x = self.conv1(x)\n", " x = self.relu(x)\n", - "\n", " x = self.conv2(x)\n", " x = self.relu(x)\n", "\n", @@ -709,18 +640,22 @@ " x = self.middle_flow(x)\n", "\n", " # Exit flow\n", - " x = self.block20(x)\n", + " x = self.exit_block(x)\n", " x = self.conv3(x)\n", " x = self.relu(x)\n", - "\n", " x = self.conv4(x)\n", " x = self.relu(x)\n", - "\n", " x = self.conv5(x)\n", " x = self.relu(x)\n", "\n", - " return x, low_level_feat\n", - "\n" + " return x, low_level_feat" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### ASPP" ] }, { @@ -729,28 +664,26 @@ "metadata": {}, "outputs": [], "source": [ - "\n", "class ASPP_module(nn.Module):\n", - " def __init__(self, inplanes, planes, dilation):\n", - " super(ASPP_module, self).__init__()\n", - " dim = 3\n", - " if dilation == 1:\n", - " kernel_size = 1\n", - " padding = 0\n", - " else:\n", - " kernel_size = 3\n", - " padding = dilation\n", - " self.atrous_convolution = Convolution(dim, inplanes, planes, kernel_size=kernel_size,\n", - " strides=1, padding=padding, dilation=dilation, bias=False,norm=Norm.BATCH)\n", - " self.relu = Act[Act.RELU]()\n", - "\n", - "\n", - " def forward(self, x):\n", + " def __init__(self, config: DeeplabConfig, inplanes: int, planes: int, dilation: int):\n", + " super().__init__()\n", + " kernel_size = 1 if dilation == 1 else 3\n", + " padding = 0 if dilation == 1 else dilation\n", + " self.atrous_convolution = Convolution(config.dimensions, inplanes, planes, kernel_size=kernel_size,\n", + " strides=1, padding=padding, dilation=dilation, \n", + " bias=False, norm=Norm.BATCH)\n", + " self.relu = nn.ReLU()\n", + "\n", + " def forward(self, x: torch.Tensor) -> torch.Tensor:\n", " x = self.atrous_convolution(x)\n", - "\n", - " return self.relu(x)\n", - "\n", - "\n" + " return self.relu(x)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### DeepLab" ] }, { @@ -759,72 +692,126 @@ "metadata": {}, "outputs": [], "source": [ - "\n", "class Deeplab(nn.Module):\n", - " def __init__(\n", - " self,\n", - " dimensions: int = 3,\n", - " in_channels: int = 1,\n", - " out_channels: int = 4,\n", - " ):\n", - "\n", - " super(Deeplab, self).__init__()\n", - " self.xception_features = Xception(dimensions, in_channels, out_channels)\n", + " def __init__(self, config: DeeplabConfig):\n", + " super().__init__()\n", + " self.config = config\n", + "\n", + " # Choose backbone based on configuration\n", + " if config.backbone == \"xception\":\n", + " self.backbone = Xception(config)\n", + " backbone_out_channels = 2048\n", + " elif config.backbone == \"resnet50\":\n", + " self.backbone = resnet.resnet50(pretrained=config.pretrained, spatial_dims=config.dimensions)\n", + " del self.backbone.fc\n", + " del self.backbone.avgpool\n", + " backbone_out_channels = 2048\n", + " elif config.backbone == \"resnet101\":\n", + " self.backbone = resnet.resnet101(pretrained=config.pretrained, spatial_dims=config.dimensions)\n", + " del self.backbone.fc\n", + " del self.backbone.avgpool\n", + " backbone_out_channels = 2048\n", + " else:\n", + " raise ValueError(f\"Unsupported backbone: {config.backbone}\")\n", "\n", " # ASPP\n", - " dilations = [1, 6, 12, 18]\n", - " self.aspp1 = ASPP_module(2048, 256, dilation=dilations[0])\n", - " self.aspp2 = ASPP_module(2048, 256, dilation=dilations[1])\n", - " self.aspp3 = ASPP_module(2048, 256, dilation=dilations[2])\n", - " self.aspp4 = ASPP_module(2048, 256, dilation=dilations[3])\n", - "\n", - " self.relu = Act[Act.RELU]()\n", - " pool_type: Type[Union[nn.MaxPool2d, nn.MaxPool3d]] = Pool[Pool.ADAPTIVEAVG, dimensions]\n", - " self.global_avg_pool = nn.Sequential(pool_type(1),\n", - " Convolution(dimensions, 2048, 256, kernel_size=1, strides=1, bias=False,norm=Norm.BATCH),\n", - " Act[Act.RELU]())\n", - "\n", - " self.conv1 = Convolution(dimensions, 1280, 256, kernel_size=1, bias=False,norm=Norm.BATCH)\n", - "\n", - " # adopt [1x1, 48] for channel reduction.\n", - " self.conv2 = Convolution(dimensions, 128, 48, kernel_size=1, bias=False,norm=Norm.BATCH)\n", - "\n", - " self.last_conv = nn.Sequential(Convolution(dimensions, 304, 256, kernel_size=3, strides=1, padding=1, bias=False,norm=Norm.BATCH),\n", - " Act[Act.RELU](),\n", - " Convolution(dimensions, 256, 256, kernel_size=3,\n", - " strides=1, padding=1, bias=False,norm=Norm.BATCH),\n", - " Act[Act.RELU](),\n", - " Convolution(dimensions, 256, out_channels, kernel_size=1, strides=1))\n", - "\n", - " def forward(self, input):\n", - "\n", - " x, low_level_features = self.xception_features(input)\n", - " x1 = self.aspp1(x)\n", - " x2 = self.aspp2(x)\n", - " x3 = self.aspp3(x)\n", - " x4 = self.aspp4(x)\n", + " self.aspp_modules = nn.ModuleList([\n", + " ASPP_module(config, backbone_out_channels, 256, dilation=dilation)\n", + " for dilation in config.aspp_dilations\n", + " ])\n", + "\n", + " self.global_avg_pool = nn.Sequential(\n", + " Pool[Pool.ADAPTIVEAVG, config.dimensions](1),\n", + " Convolution(config.dimensions, backbone_out_channels, 256, kernel_size=1, strides=1, bias=False, norm=Norm.BATCH),\n", + " nn.ReLU()\n", + " )\n", + "\n", + " self.conv1 = Convolution(config.dimensions, 1280, 256, kernel_size=1, bias=False, norm=Norm.BATCH)\n", + " self.conv2 = Convolution(config.dimensions, 256, 48, kernel_size=1, bias=False, norm=Norm.BATCH)\n", + "\n", + " self.last_conv = nn.Sequential(\n", + " Convolution(config.dimensions, 304, 256, kernel_size=3, strides=1, padding=1, bias=False, norm=Norm.BATCH),\n", + " nn.ReLU(),\n", + " Convolution(config.dimensions, 256, 256, kernel_size=3, strides=1, padding=1, bias=False, norm=Norm.BATCH),\n", + " nn.ReLU(),\n", + " Convolution(config.dimensions, 256, config.out_channels, kernel_size=1, strides=1)\n", + " )\n", + "\n", + " def forward(self, input: torch.Tensor) -> torch.Tensor:\n", + " if self.config.backbone == \"xception\":\n", + " x, low_level_features = self.backbone(input)\n", + " else:\n", + " x = self.backbone.conv1(input)\n", + " x = self.backbone.bn1(x)\n", + " x = self.backbone.relu(x)\n", + " x = self.backbone.maxpool(x)\n", + "\n", + " low_level_features = self.backbone.layer1(x)\n", + " x = self.backbone.layer2(low_level_features)\n", + " x = self.backbone.layer3(x)\n", + " x = self.backbone.layer4(x)\n", + "\n", + " aspp_results = [module(x) for module in self.aspp_modules]\n", " x5 = self.global_avg_pool(x)\n", - " x5 = F.interpolate(x5, size=(x4.size()[2],x4.size()[3],x4.size()[4]), mode='trilinear', align_corners=True)\n", - " x = torch.cat((x1, x2, x3, x4, x5), dim=1)\n", + " x5 = interpolate(x5, size=x.shape[2:], mode='linear' if self.config.dimensions == 2 else 'trilinear')\n", + " x = torch.cat(aspp_results + [x5], dim=1)\n", "\n", " x = self.conv1(x)\n", - " x = self.relu(x)\n", - " x = F.interpolate(x, size=(int(math.ceil(input.size()[-3] / 4)), int(math.ceil(input.size()[-2] / 4)),\n", - " int(math.ceil(input.size()[-1] / 4))), mode='trilinear', align_corners=True)\n", + " x = interpolate(x, size=low_level_features.shape[2:], mode='linear' if self.config.dimensions == 2 else 'trilinear')\n", "\n", " low_level_features = self.conv2(low_level_features)\n", - " low_level_features = self.relu(low_level_features)\n", "\n", " x = torch.cat((x, low_level_features), dim=1)\n", " x = self.last_conv(x)\n", - " x = F.interpolate(x, size=(input.size()[2],input.size()[3],input.size()[4]), mode='trilinear', align_corners=True)\n", + " x = interpolate(x, size=input.shape[2:], mode='linear' if self.config.dimensions == 2 else 'trilinear')\n", "\n", - " return x\n" + " return x" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Example" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# For 2D images\n", + "config_2d = DeeplabConfig(\n", + " dimensions=2,\n", + " in_channels=3, # For RGB images\n", + " out_channels=4,\n", + " middle_flow_blocks=16,\n", + " aspp_dilations=[1, 6, 12, 18]\n", + ")\n", + "model_2d = Deeplab(config_2d)\n", + "\n", + "# For 3D images\n", + "config_3d = DeeplabConfig(\n", + " dimensions=3,\n", + " in_channels=1, # For single-channel 3D medical images\n", + " out_channels=4,\n", + " middle_flow_blocks=16,\n", + " aspp_dilations=[1, 6, 12, 18]\n", + ")\n", + "model_3d = Deeplab(config_3d)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "---\n" ] }, { "cell_type": "code", - "execution_count": 14, + "execution_count": null, "metadata": {}, "outputs": [], "source": [