Skip to content

Commit

Permalink
update nets
Browse files Browse the repository at this point in the history
  • Loading branch information
bmandracchia committed Jul 22, 2024
1 parent 1b1b545 commit 05d8d96
Showing 1 changed file with 68 additions and 51 deletions.
119 changes: 68 additions & 51 deletions nbs/01_nets.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,16 @@
"\n",
"from fastai.vision.all import ConvLayer, Lambda, MaxPool, NormType, nn, np\n",
"from torch import cat as torch_cat\n",
"from torch import Tensor as torch_Tensor\n",
"import torch.nn.functional as F\n",
"import torch.nn as nn\n",
"from monai.networks.blocks import Convolution\n",
"from monai.networks.layers.factories import Act, Norm, Pool\n",
"from monai.networks.nets import resnet\n",
"\n",
"from dataclasses import dataclass, field\n",
"from typing import List, Tuple, Optional, Union\n",
"\n",
"from bioMONAI.core import attributesFromDict"
]
},
Expand Down Expand Up @@ -426,24 +435,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"## DeepLab v3"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from dataclasses import dataclass, field\n",
"from typing import List, Tuple, Optional, Union\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"from monai.networks.blocks import Convolution\n",
"from monai.networks.layers.factories import Act, Norm, Pool\n",
"from torchvision.models.resnet import resnet50, resnet101\n",
"from monai.networks.nets import resnet\n"
"## DeepLab v3+"
]
},
{
Expand All @@ -455,11 +447,10 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"\n",
"@dataclass\n",
"class DeeplabConfig:\n",
" dimensions: int\n",
Expand All @@ -475,14 +466,14 @@
"\n",
"def get_padding(kernel_size: int, dilation: int) -> int:\n",
" return (kernel_size - 1) * dilation // 2\n",
"\n",
"def interpolate(x: torch.Tensor, size: Union[List[int], Tuple[int, ...]], mode: str) -> torch.Tensor:\n",
" if x.dim() == 4: # 2D\n",
" return F.interpolate(x, size=size, mode=mode, align_corners=True)\n",
" elif x.dim() == 5: # 3D\n",
" return F.interpolate(x, size=size, mode=mode, align_corners=True)\n",
" \n",
"def interpolate(x: torch_Tensor, size: Union[List[int], Tuple[int, ...]], dims: int) -> torch_Tensor:\n",
" if dims == 2:\n",
" return F.interpolate(x, size=size, mode='bilinear', align_corners=True)\n",
" elif dims == 3:\n",
" return F.interpolate(x, size=size, mode='trilinear', align_corners=True)\n",
" else:\n",
" raise ValueError(f\"Unsupported input dimension: {x.dim()}\")"
" raise ValueError(f\"Unsupported number of dimensions: {dims}\")"
]
},
{
Expand All @@ -494,7 +485,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -503,7 +494,7 @@
" dilation: int = 1, bias: bool = False, norm: Optional[str] = None):\n",
" super().__init__()\n",
" self.conv1 = Convolution(\n",
" dimensions=config.dimensions, \n",
" spatial_dims=config.dimensions, \n",
" in_channels=inplanes, \n",
" out_channels=inplanes, \n",
" kernel_size=kernel_size,\n",
Expand All @@ -514,7 +505,7 @@
" strides=stride\n",
" )\n",
" self.pointwise = Convolution(\n",
" dimensions=config.dimensions, \n",
" spatial_dims=config.dimensions, \n",
" in_channels=inplanes, \n",
" out_channels=planes, \n",
" kernel_size=1, \n",
Expand All @@ -526,7 +517,7 @@
" norm=Norm.BATCH if norm else None\n",
" )\n",
"\n",
" def forward(self, x: torch.Tensor) -> torch.Tensor:\n",
" def forward(self, x: torch_Tensor) -> torch_Tensor:\n",
" x = self.conv1(x)\n",
" x = self.pointwise(x)\n",
" return x\n",
Expand Down Expand Up @@ -570,7 +561,7 @@
"\n",
" self.rep = nn.Sequential(*rep)\n",
"\n",
" def forward(self, inp: torch.Tensor) -> torch.Tensor:\n",
" def forward(self, inp: torch_Tensor) -> torch_Tensor:\n",
" x = self.rep(inp)\n",
" if self.skip is not None:\n",
" skip = self.skip(inp)\n",
Expand All @@ -589,7 +580,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -624,7 +615,7 @@
" self.conv4 = SeparableConv(config, 1536, 1536, 3, stride=1, dilation=config.exit_block_dilations[1], norm=Norm.BATCH)\n",
" self.conv5 = SeparableConv(config, 1536, 2048, 3, stride=1, dilation=config.exit_block_dilations[1], norm=Norm.BATCH)\n",
"\n",
" def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:\n",
" def forward(self, x: torch_Tensor) -> Tuple[torch_Tensor, torch_Tensor]:\n",
" # Entry flow\n",
" x = self.conv1(x)\n",
" x = self.relu(x)\n",
Expand Down Expand Up @@ -660,7 +651,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -674,7 +665,7 @@
" bias=False, norm=Norm.BATCH)\n",
" self.relu = nn.ReLU()\n",
"\n",
" def forward(self, x: torch.Tensor) -> torch.Tensor:\n",
" def forward(self, x: torch_Tensor) -> torch_Tensor:\n",
" x = self.atrous_convolution(x)\n",
" return self.relu(x)"
]
Expand All @@ -688,7 +679,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -737,7 +728,7 @@
" Convolution(config.dimensions, 256, config.out_channels, kernel_size=1, strides=1)\n",
" )\n",
"\n",
" def forward(self, input: torch.Tensor) -> torch.Tensor:\n",
" def forward(self, input: torch.Tensor) -> torch.Tensor:\n",
" if self.config.backbone == \"xception\":\n",
" x, low_level_features = self.backbone(input)\n",
" else:\n",
Expand All @@ -753,17 +744,17 @@
"\n",
" aspp_results = [module(x) for module in self.aspp_modules]\n",
" x5 = self.global_avg_pool(x)\n",
" x5 = interpolate(x5, size=x.shape[2:], mode='linear' if self.config.dimensions == 2 else 'trilinear')\n",
" x5 = interpolate(x5, size=x.shape[2:], dims=self.config.dimensions)\n",
" x = torch.cat(aspp_results + [x5], dim=1)\n",
"\n",
" x = self.conv1(x)\n",
" x = interpolate(x, size=low_level_features.shape[2:], mode='linear' if self.config.dimensions == 2 else 'trilinear')\n",
" x = interpolate(x, size=low_level_features.shape[2:], dims=self.config.dimensions)\n",
"\n",
" low_level_features = self.conv2(low_level_features)\n",
"\n",
" x = torch.cat((x, low_level_features), dim=1)\n",
" x = self.last_conv(x)\n",
" x = interpolate(x, size=input.shape[2:], mode='linear' if self.config.dimensions == 2 else 'trilinear')\n",
" x = interpolate(x, size=input.shape[2:], dims=self.config.dimensions)\n",
"\n",
" return x"
]
Expand All @@ -777,7 +768,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -786,20 +777,46 @@
" dimensions=2,\n",
" in_channels=3, # For RGB images\n",
" out_channels=4,\n",
" middle_flow_blocks=16,\n",
" backbone=\"xception\", # or whatever backbone you're using\n",
" aspp_dilations=[1, 6, 12, 18]\n",
")\n",
"model_2d = Deeplab(config_2d)\n",
"\n",
"# For 3D images\n",
"config_3d = DeeplabConfig(\n",
" dimensions=3,\n",
" in_channels=1, # For single-channel 3D medical images\n",
" out_channels=4,\n",
" middle_flow_blocks=16,\n",
" aspp_dilations=[1, 6, 12, 18]\n",
")\n",
"model_3d = Deeplab(config_3d)"
"# config_3d = DeeplabConfig(\n",
"# dimensions=3,\n",
"# in_channels=1, # For single-channel 3D medical images\n",
"# out_channels=4,\n",
"# middle_flow_blocks=16,\n",
"# aspp_dilations=[1, 6, 12, 18]\n",
"# )\n",
"# model_3d = Deeplab(config_3d)"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"ename": "NotImplementedError",
"evalue": "Got 4D input, but linear mode needs 3D input",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mNotImplementedError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[23], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m x \u001b[38;5;241m=\u001b[39m torch_randn(\u001b[38;5;241m16\u001b[39m,\u001b[38;5;241m3\u001b[39m, \u001b[38;5;241m64\u001b[39m, \u001b[38;5;241m64\u001b[39m)\n\u001b[0;32m----> 2\u001b[0m y \u001b[38;5;241m=\u001b[39m model_2d(x)\n",
"File \u001b[0;32m~/miniconda3/envs/bioMONAI-env/lib/python3.11/site-packages/torch/nn/modules/module.py:1501\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1496\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1497\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1498\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1499\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1500\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1501\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m forward_call(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[1;32m 1502\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1503\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
"Cell \u001b[0;32mIn[18], line 62\u001b[0m, in \u001b[0;36mDeeplab.forward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 60\u001b[0m aspp_results \u001b[38;5;241m=\u001b[39m [module(x) \u001b[38;5;28;01mfor\u001b[39;00m module \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39maspp_modules]\n\u001b[1;32m 61\u001b[0m x5 \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mglobal_avg_pool(x)\n\u001b[0;32m---> 62\u001b[0m x5 \u001b[38;5;241m=\u001b[39m interpolate(x5, size\u001b[38;5;241m=\u001b[39mx\u001b[38;5;241m.\u001b[39mshape[\u001b[38;5;241m2\u001b[39m:], mode\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mlinear\u001b[39m\u001b[38;5;124m'\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mconfig\u001b[38;5;241m.\u001b[39mdimensions \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m2\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mtrilinear\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[1;32m 63\u001b[0m x \u001b[38;5;241m=\u001b[39m torch_cat(aspp_results \u001b[38;5;241m+\u001b[39m [x5], dim\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1\u001b[39m)\n\u001b[1;32m 65\u001b[0m x \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mconv1(x)\n",
"Cell \u001b[0;32mIn[14], line 19\u001b[0m, in \u001b[0;36minterpolate\u001b[0;34m(x, size, mode)\u001b[0m\n\u001b[1;32m 17\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21minterpolate\u001b[39m(x: torch_Tensor, size: Union[List[\u001b[38;5;28mint\u001b[39m], Tuple[\u001b[38;5;28mint\u001b[39m, \u001b[38;5;241m.\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;241m.\u001b[39m]], mode: \u001b[38;5;28mstr\u001b[39m) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m torch_Tensor:\n\u001b[1;32m 18\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m x\u001b[38;5;241m.\u001b[39mdim() \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m4\u001b[39m: \u001b[38;5;66;03m# 2D\u001b[39;00m\n\u001b[0;32m---> 19\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m F\u001b[38;5;241m.\u001b[39minterpolate(x, size\u001b[38;5;241m=\u001b[39msize, mode\u001b[38;5;241m=\u001b[39mmode, align_corners\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n\u001b[1;32m 20\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m x\u001b[38;5;241m.\u001b[39mdim() \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m5\u001b[39m: \u001b[38;5;66;03m# 3D\u001b[39;00m\n\u001b[1;32m 21\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m F\u001b[38;5;241m.\u001b[39minterpolate(x, size\u001b[38;5;241m=\u001b[39msize, mode\u001b[38;5;241m=\u001b[39mmode, align_corners\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n",
"File \u001b[0;32m~/miniconda3/envs/bioMONAI-env/lib/python3.11/site-packages/torch/nn/functional.py:3974\u001b[0m, in \u001b[0;36minterpolate\u001b[0;34m(input, size, scale_factor, mode, align_corners, recompute_scale_factor, antialias)\u001b[0m\n\u001b[1;32m 3972\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mNotImplementedError\u001b[39;00m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mGot 3D input, but trilinear mode needs 5D input\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 3973\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28minput\u001b[39m\u001b[38;5;241m.\u001b[39mdim() \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m4\u001b[39m \u001b[38;5;129;01mand\u001b[39;00m mode \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mlinear\u001b[39m\u001b[38;5;124m\"\u001b[39m:\n\u001b[0;32m-> 3974\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mNotImplementedError\u001b[39;00m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mGot 4D input, but linear mode needs 3D input\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 3975\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28minput\u001b[39m\u001b[38;5;241m.\u001b[39mdim() \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m4\u001b[39m \u001b[38;5;129;01mand\u001b[39;00m mode \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtrilinear\u001b[39m\u001b[38;5;124m\"\u001b[39m:\n\u001b[1;32m 3976\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mNotImplementedError\u001b[39;00m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mGot 4D input, but trilinear mode needs 5D input\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n",
"\u001b[0;31mNotImplementedError\u001b[0m: Got 4D input, but linear mode needs 3D input"
]
}
],
"source": [
"x = torch_randn(16,3, 64, 64)\n",
"y = model_2d(x)"
]
},
{
Expand Down

0 comments on commit 05d8d96

Please sign in to comment.