From 6d2973a79df35737716b1b9964d83f34ea0aa1cb Mon Sep 17 00:00:00 2001 From: YousefMetwally Date: Wed, 31 Jul 2024 13:59:29 +0200 Subject: [PATCH] resunet network included --- resources/configs/config_resunet.json | 19 +++ tomotwin/modules/networks/Unet.py | 115 ++++++++++++++++++ tomotwin/modules/networks/Unet_GN.py | 115 ++++++++++++++++++ tomotwin/modules/networks/networkmanager.py | 6 + tomotwin/modules/networks/resunet.py | 127 ++++++++++++++++++++ 5 files changed, 382 insertions(+) create mode 100644 resources/configs/config_resunet.json create mode 100644 tomotwin/modules/networks/Unet.py create mode 100644 tomotwin/modules/networks/Unet_GN.py create mode 100644 tomotwin/modules/networks/resunet.py diff --git a/resources/configs/config_resunet.json b/resources/configs/config_resunet.json new file mode 100644 index 0000000..3a41314 --- /dev/null +++ b/resources/configs/config_resunet.json @@ -0,0 +1,19 @@ +{ + "identifier": "resunet", + "network_config": {}, + + "train_config":{ + "loss": "TripletLoss", + "tl_margin": 0.539, + "miner": true, + "miner_margin": 0.734, + "learning_rate": 5.945e-05, + "optimizer": "Adam", + "weight_decay": 0, + "batchsize": 35, + "patience": 50, + "aug_train_shift_distance": 2 + }, + + "distance": "COSINE" +} diff --git a/tomotwin/modules/networks/Unet.py b/tomotwin/modules/networks/Unet.py new file mode 100644 index 0000000..8082a2e --- /dev/null +++ b/tomotwin/modules/networks/Unet.py @@ -0,0 +1,115 @@ +from typing import Dict, Union +import torch +import torch.nn as nn +from tomotwin.modules.networks.torchmodel import TorchModel +from tomotwin.modules.networks.torchmodel import TorchModel + +class DoubleConv(nn.Module): + def __init__(self, in_channels, out_channels): + super(DoubleConv, self).__init__() + self.double_conv = nn.Sequential( + nn.Conv3d(in_channels, out_channels, kernel_size=3, padding='same'), + nn.BatchNorm3d(out_channels), + nn.ReLU(inplace=True), + nn.Conv3d(out_channels, out_channels, kernel_size=3, padding='same'), + nn.BatchNorm3d(out_channels), + nn.ReLU(inplace=True) + ) + + def forward(self, x): + return self.double_conv(x) + +class Down(nn.Module): + def __init__(self, in_channels, out_channels): + super(Down, self).__init__() + self.maxpool_conv = nn.Sequential( + nn.MaxPool3d(2), + DoubleConv(in_channels, out_channels) + ) + + def forward(self, x): + return self.maxpool_conv(x) + +class Up(nn.Module): + def __init__(self, in_channels, out_channels): + super(Up, self).__init__() + self.up = nn.ConvTranspose3d(in_channels , in_channels , kernel_size=2, stride=2) + self.conv = DoubleConv(in_channels, out_channels) + + def forward(self, x1): + x1 = self.up(x1) + return self.conv(x1) + +class OutConv(nn.Module): + def __init__(self, in_channels, out_channels): + super(OutConv, self).__init__() + self.conv = nn.Sequential( nn.Conv3d(in_channels, out_channels, kernel_size=3, padding ='same'), + nn.Sigmoid()) + + def forward(self, x): + return self.conv(x) + +class UNet3D(nn.Module): + def __init__(self, n_channels, out_channels): + super(UNet3D, self).__init__() + self.n_channels = n_channels + self.out_channels = out_channels + + self.inc = DoubleConv(n_channels, 32) + self.down1 = Down(32, 32) + self.down2 = Down(32, 64) + self.down3 = Down(64, 64) + #self.up1 = Up(64, 64) + #self.up2 = Up(64, 32) + #self.up3 = Up(32,32) + #self.outc = OutConv(32, out_channels) + self.Flatten = nn.Flatten() + self._initialize_weights() + self._load_checkpoint() + + + def forward(self, x, return_embedding=False): + x1 = self.inc(x) + x2 = self.down1(x1) + x3 = self.down2(x2) + x4 = self.down3(x3) + x_flat = self.Flatten(x4) + x_flat = torch.nn.functional.normalize(x_flat, p=2.0, dim=1) + #x = x_flat.view(x4.size()) + #x = self.up1(x) + #x = self.up2(x) + #x = self.up3(x) + #logits = self.outc(x) + return x_flat + + def _initialize_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv3d) or isinstance(m, nn.ConvTranspose3d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.BatchNorm3d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + def _load_checkpoint(self): + chk_pth = '/home/yousef.metwally/projects/no32_64_4_sphere/weights/model_weights_epoch_206.pt' + checkpoint = torch.load(chk_pth) + state_dict = {k.replace('module.', ''): v for k, v in checkpoint.items()} + model_state_dict = self.state_dict() + filtered_state_dict = {k: v for k, v in state_dict.items() if k in model_state_dict} + self.load_state_dict(filtered_state_dict) + + + +class UNet(TorchModel): + def __init__(self) -> None: + super().__init__() + self.model = UNet3D(1,1) + + def init_weights(self): + self.model._initialize_weights() + self.model._load_checkpoint() + + def get_model(self) -> nn.Module: + return self.model diff --git a/tomotwin/modules/networks/Unet_GN.py b/tomotwin/modules/networks/Unet_GN.py new file mode 100644 index 0000000..6917e21 --- /dev/null +++ b/tomotwin/modules/networks/Unet_GN.py @@ -0,0 +1,115 @@ +from typing import Dict, Union +import torch +import torch.nn as nn +from tomotwin.modules.networks.torchmodel import TorchModel + + +class DoubleConv(nn.Module): + def __init__(self, in_channels, out_channels, num_groups): + super(DoubleConv, self).__init__() + self.double_conv = nn.Sequential( + nn.Conv3d(in_channels, out_channels, kernel_size=3, padding='same'), + nn.GroupNorm(num_groups, out_channels), + nn.ReLU(inplace=True), + nn.Conv3d(out_channels, out_channels, kernel_size=3, padding='same'), + nn.GroupNorm(num_groups, out_channels), + nn.ReLU(inplace=True) + ) + + def forward(self, x): + return self.double_conv(x) + +class Down(nn.Module): + def __init__(self, in_channels, out_channels, num_groups): + super(Down, self).__init__() + self.maxpool_conv = nn.Sequential( + nn.MaxPool3d(2), + DoubleConv(in_channels, out_channels, num_groups) + ) + + def forward(self, x): + return self.maxpool_conv(x) + +class Up(nn.Module): + def __init__(self, in_channels, out_channels, num_groups): + super(Up, self).__init__() + self.up = nn.ConvTranspose3d(in_channels, in_channels, kernel_size=2, stride=2) + self.conv = DoubleConv(in_channels, out_channels, num_groups) + + def forward(self, x1): + x1 = self.up(x1) + return self.conv(x1) + +class OutConv(nn.Module): + def __init__(self, in_channels, out_channels): + super(OutConv, self).__init__() + self.conv = nn.Sequential(nn.Conv3d(in_channels, out_channels, kernel_size=3, padding='same'), + nn.Sigmoid()) + + def forward(self, x): + return self.conv(x) + +class UNet3D(nn.Module): + def __init__(self, n_channels, out_channels, num_groups=8): + super(UNet3D, self).__init__() + self.n_channels = n_channels + self.out_channels = out_channels + + self.inc = DoubleConv(n_channels, 32, num_groups) + self.down1 = Down(32, 32, num_groups) + self.down2 = Down(32, 64, num_groups) + self.down3 = Down(64, 64, num_groups) + #self.up1 = Up(64, 64, num_groups) + #self.up2 = Up(64, 32, num_groups) + #self.up3 = Up(32, 32, num_groups) + #self.outc = OutConv(32, out_channels) + self.Flatten = nn.Flatten() + self._initialize_weights() + self._load_checkpoint() + def forward(self, x): + x1 = self.inc(x) + x2 = self.down1(x1) + x3 = self.down2(x2) + x4 = self.down3(x3) + x_flat = self.Flatten(x4) + x_flat = torch.nn.functional.normalize(x_flat, p=2.0, dim=1) + #x = x_flat.view(x4.size()) + #x = self.up1(x) + #x = self.up2(x) + #x = self.up3(x) + #logits = self.outc(x) + return x_flat + + def _initialize_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv3d) or isinstance(m, nn.ConvTranspose3d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.BatchNorm3d) or isinstance(m, nn.GroupNorm): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + + def _load_checkpoint(self): + chk_pth = '/home/yousef.metwally/projects/no32_64_4_sphere/weights/model_weights_epoch_206.pt' + checkpoint = torch.load(chk_pth) + state_dict = {k.replace('module.', ''): v for k, v in checkpoint.items()} + model_state_dict = self.state_dict() + filtered_state_dict = {k: v for k, v in state_dict.items() if k in model_state_dict} + self.load_state_dict(filtered_state_dict) + + + + +class UNet_GN(TorchModel): + def __init__(self) -> None: + super().__init__() + self.model = UNet3D(1,1) + + def init_weights(self): + self.model._initialize_weights() + self.model._load_checkpoint() + + def get_model(self) -> nn.Module: + return self.model diff --git a/tomotwin/modules/networks/networkmanager.py b/tomotwin/modules/networks/networkmanager.py index 4599ac1..311729f 100644 --- a/tomotwin/modules/networks/networkmanager.py +++ b/tomotwin/modules/networks/networkmanager.py @@ -17,6 +17,9 @@ from tomotwin.modules.networks.SiameseNet3D import SiameseNet3D from tomotwin.modules.networks.resnet import Resnet from tomotwin.modules.networks.torchmodel import TorchModel +from tomotwin.modules.networks.Unet import UNet +from tomotwin.modules.networks.Unet_GN import UNet_GN +from tomotwin.modules.networks.resunet import resunet class NetworkNotExistError(Exception): @@ -34,6 +37,9 @@ class NetworkManager: network_identifier_map = { "SiameseNet".upper(): SiameseNet3D, "ResNet".upper(): Resnet, + "UNet".upper(): UNet, + "UNet_GN".upper(): UNet_GN, + "resunet".upper(): resunet } diff --git a/tomotwin/modules/networks/resunet.py b/tomotwin/modules/networks/resunet.py new file mode 100644 index 0000000..21e3141 --- /dev/null +++ b/tomotwin/modules/networks/resunet.py @@ -0,0 +1,127 @@ +from typing import Dict, Union +import torch +import torch.nn as nn +from tomotwin.modules.networks.torchmodel import TorchModel + + +class ResidualBlock(nn.Module): + def __init__(self, in_channels, out_channels, stride=1): + super(ResidualBlock, self).__init__() + self.conv1 = nn.Conv3d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False) + self.bn1 = nn.BatchNorm3d(out_channels) + self.relu = nn.ReLU(inplace=True) + self.conv2 = nn.Conv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False) + self.bn2 = nn.BatchNorm3d(out_channels) + + self.shortcut = nn.Sequential() + if stride != 1 or in_channels != out_channels: + self.shortcut = nn.Sequential( + nn.Conv3d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False), + nn.BatchNorm3d(out_channels) + ) + + def forward(self, x): + out = self.relu(self.bn1(self.conv1(x))) + out = self.bn2(self.conv2(out)) + out += self.shortcut(x) + out = self.relu(out) + return out + +class DoubleConv(nn.Module): + def __init__(self, in_channels, out_channels): + super(DoubleConv, self).__init__() + self.double_conv = nn.Sequential( + ResidualBlock(in_channels, out_channels), + ResidualBlock(out_channels, out_channels) + ) + + def forward(self, x): + return self.double_conv(x) + +class Down(nn.Module): + def __init__(self, in_channels, out_channels): + super(Down, self).__init__() + self.maxpool_conv = nn.Sequential( + nn.MaxPool3d(2), + DoubleConv(in_channels, out_channels) + ) + + def forward(self, x): + return self.maxpool_conv(x) + +class Up(nn.Module): + def __init__(self, in_channels, out_channels): + super(Up, self).__init__() + self.up = nn.ConvTranspose3d(in_channels , in_channels , kernel_size=2, stride=2) + self.conv = DoubleConv(in_channels, out_channels) + + def forward(self, x1): + x1 = self.up(x1) + return self.conv(x1) + +class OutConv(nn.Module): + def __init__(self, in_channels, out_channels): + super(OutConv, self).__init__() + self.conv = nn.Sequential( nn.Conv3d(in_channels, out_channels, kernel_size=3, padding ='same'), + nn.Sigmoid()) + + def forward(self, x): + return self.conv(x) + +class UNet3D(nn.Module): + def __init__(self, n_channels, out_channels): + super(UNet3D, self).__init__() + self.n_channels = n_channels + self.out_channels = out_channels + + self.inc = DoubleConv(n_channels, 64) + self.down1 = Down(64, 64) + self.down2 = Down(64, 64) + self.down3 = Down(64, 64) + self.ls = nn.Conv3d(64, 64, kernel_size=3, padding='same') + + self.Flatten = nn.Flatten() + self._initialize_weights() + + def forward(self, x): + x1 = self.inc(x) + x2 = self.down1(x1) + x3 = self.down2(x2) + x4 = self.down3(x3) + ls = self.ls(x4) + x_flat = self.Flatten(ls) + x_flat = torch.nn.functional.normalize(x_flat, p=2.0, dim=1) + return x_flat + + def _initialize_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv3d) or isinstance(m, nn.ConvTranspose3d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.BatchNorm3d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + + def _load_checkpoint(self): + chk_pth = '/home/yousef.metwally/projects/no32_64_4_sphere/weights/model_weights_epoch_206.pt' + checkpoint = torch.load(chk_pth) + state_dict = {k.replace('module.', ''): v for k, v in checkpoint.items()} + model_state_dict = self.state_dict() + filtered_state_dict = {k: v for k, v in state_dict.items() if k in model_state_dict} + self.load_state_dict(filtered_state_dict) + + + +class resunet(TorchModel): + def __init__(self) -> None: + super().__init__() + self.model = UNet3D(1,1) + + def init_weights(self): + self.model._initialize_weights() + #self.model._load_checkpoint() + + def get_model(self) -> nn.Module: + return self.model