Commit 6d2973a (1 parent: d2a8a56)
Showing 5 changed files with 382 additions and 0 deletions.
@@ -0,0 +1,19 @@
{
    "identifier": "resunet",
    "network_config": {},

    "train_config": {
        "loss": "TripletLoss",
        "tl_margin": 0.539,
        "miner": true,
        "miner_margin": 0.734,
        "learning_rate": 5.945e-05,
        "optimizer": "Adam",
        "weight_decay": 0,
        "batchsize": 35,
        "patience": 50,
        "aug_train_shift_distance": 2
    },

    "distance": "COSINE"
}
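For context, a training setup might consume this file roughly as follows. This is a minimal sketch: the file name and the `load_config` helper are hypothetical, and mapping `loss`/`miner` onto pytorch-metric-learning's TripletMarginLoss and TripletMarginMiner (and `optimizer` onto torch.optim.Adam) is an assumption based on the field names, not confirmed TomoTwin wiring.

```python
import json
import torch
from pytorch_metric_learning import losses, miners

def load_config(path: str) -> dict:
    """Hypothetical helper: read the JSON training config shown above."""
    with open(path) as f:
        return json.load(f)

cfg = load_config("resunet_config.json")  # assumed file name
tc = cfg["train_config"]

# Triplet loss with the tuned margin; the miner selects informative triplets.
loss_fn = losses.TripletMarginLoss(margin=tc["tl_margin"])
miner = miners.TripletMarginMiner(margin=tc["miner_margin"]) if tc["miner"] else None

model = torch.nn.Linear(8, 4)  # stand-in for the actual embedding network
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=tc["learning_rate"],
    weight_decay=tc["weight_decay"],
)
```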
@@ -0,0 +1,115 @@
from typing import Dict, Union
import torch
import torch.nn as nn
from tomotwin.modules.networks.torchmodel import TorchModel

class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(DoubleConv, self).__init__()
        self.double_conv = nn.Sequential(
            nn.Conv3d(in_channels, out_channels, kernel_size=3, padding='same'),
            nn.BatchNorm3d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv3d(out_channels, out_channels, kernel_size=3, padding='same'),
            nn.BatchNorm3d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)


class Down(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(Down, self).__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool3d(2),
            DoubleConv(in_channels, out_channels)
        )

    def forward(self, x):
        return self.maxpool_conv(x)

class Up(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(Up, self).__init__()
        self.up = nn.ConvTranspose3d(in_channels, in_channels, kernel_size=2, stride=2)
        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1):
        x1 = self.up(x1)
        return self.conv(x1)

class OutConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(OutConv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv3d(in_channels, out_channels, kernel_size=3, padding='same'),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.conv(x)

class UNet3D(nn.Module):
    def __init__(self, n_channels, out_channels):
        super(UNet3D, self).__init__()
        self.n_channels = n_channels
        self.out_channels = out_channels

        self.inc = DoubleConv(n_channels, 32)
        self.down1 = Down(32, 32)
        self.down2 = Down(32, 64)
        self.down3 = Down(64, 64)
        # decoder and output head are disabled: the network is used as an
        # embedding encoder, not as a segmentation U-Net
        #self.up1 = Up(64, 64)
        #self.up2 = Up(64, 32)
        #self.up3 = Up(32, 32)
        #self.outc = OutConv(32, out_channels)
        self.Flatten = nn.Flatten()
        self._initialize_weights()
        self._load_checkpoint()

    def forward(self, x, return_embedding=False):
        # return_embedding is currently unused; the flattened, L2-normalized
        # bottleneck is always returned
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x_flat = self.Flatten(x4)
        x_flat = torch.nn.functional.normalize(x_flat, p=2.0, dim=1)
        #x = x_flat.view(x4.size())
        #x = self.up1(x)
        #x = self.up2(x)
        #x = self.up3(x)
        #logits = self.outc(x)
        return x_flat

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, (nn.Conv3d, nn.ConvTranspose3d)):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm3d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _load_checkpoint(self):
        # hard-coded path to pretrained weights; the checkpoint is assumed to
        # be a raw state dict, possibly saved from a DataParallel model
        chk_pth = '/home/yousef.metwally/projects/no32_64_4_sphere/weights/model_weights_epoch_206.pt'
        checkpoint = torch.load(chk_pth)
        state_dict = {k.replace('module.', ''): v for k, v in checkpoint.items()}
        # keep only keys that exist in this (encoder-only) model
        model_state_dict = self.state_dict()
        filtered_state_dict = {k: v for k, v in state_dict.items() if k in model_state_dict}
        self.load_state_dict(filtered_state_dict)

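    # A more defensive variant of the loader above, as a sketch (the name
    # _load_checkpoint_safe is not in the original file): map_location avoids
    # CUDA deserialization errors on CPU-only machines, and strict=False
    # tolerates keys missing from the checkpoint, which the strict load above
    # would turn into a RuntimeError.
    def _load_checkpoint_safe(self, chk_pth: str, device: str = 'cpu') -> None:
        checkpoint = torch.load(chk_pth, map_location=device)
        state_dict = {k.replace('module.', '', 1): v for k, v in checkpoint.items()}
        missing, unexpected = self.load_state_dict(state_dict, strict=False)
        if missing or unexpected:
            print(f'checkpoint partially loaded: missing={missing}, unexpected={unexpected}')
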
class UNet(TorchModel):
    def __init__(self) -> None:
        super().__init__()
        self.model = UNet3D(1, 1)

    def init_weights(self):
        self.model._initialize_weights()
        self.model._load_checkpoint()

    def get_model(self) -> nn.Module:
        return self.model
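For a quick sanity check: three MaxPool3d(2) stages mean the input side must be divisible by 8, and a 32³ patch ends up as 64 channels at 4³, i.e. a 4096-dimensional unit-length embedding. A minimal sketch, assuming the hard-coded checkpoint path exists (otherwise instantiate UNet3D with the _load_checkpoint call commented out):

```python
import torch

model = UNet().get_model()   # __init__ loads the hard-coded checkpoint
model.eval()
with torch.no_grad():
    emb = model(torch.randn(2, 1, 32, 32, 32))

print(emb.shape)        # torch.Size([2, 4096]): 64 channels * 4**3 bottleneck
print(emb.norm(dim=1))  # tensor([1., 1.]): embeddings are L2-normalized
```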
@@ -0,0 +1,115 @@
from typing import Dict, Union
import torch
import torch.nn as nn
from tomotwin.modules.networks.torchmodel import TorchModel

class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels, num_groups):
        super(DoubleConv, self).__init__()
        self.double_conv = nn.Sequential(
            nn.Conv3d(in_channels, out_channels, kernel_size=3, padding='same'),
            nn.GroupNorm(num_groups, out_channels),
            nn.ReLU(inplace=True),
            nn.Conv3d(out_channels, out_channels, kernel_size=3, padding='same'),
            nn.GroupNorm(num_groups, out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)


class Down(nn.Module):
    def __init__(self, in_channels, out_channels, num_groups):
        super(Down, self).__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool3d(2),
            DoubleConv(in_channels, out_channels, num_groups)
        )

    def forward(self, x):
        return self.maxpool_conv(x)


class Up(nn.Module):
    def __init__(self, in_channels, out_channels, num_groups):
        super(Up, self).__init__()
        self.up = nn.ConvTranspose3d(in_channels, in_channels, kernel_size=2, stride=2)
        self.conv = DoubleConv(in_channels, out_channels, num_groups)

    def forward(self, x1):
        x1 = self.up(x1)
        return self.conv(x1)


class OutConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(OutConv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv3d(in_channels, out_channels, kernel_size=3, padding='same'),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.conv(x)

class UNet3D(nn.Module):
    def __init__(self, n_channels, out_channels, num_groups=8):
        super(UNet3D, self).__init__()
        self.n_channels = n_channels
        self.out_channels = out_channels

        self.inc = DoubleConv(n_channels, 32, num_groups)
        self.down1 = Down(32, 32, num_groups)
        self.down2 = Down(32, 64, num_groups)
        self.down3 = Down(64, 64, num_groups)
        # decoder and output head are disabled, as in the BatchNorm variant
        #self.up1 = Up(64, 64, num_groups)
        #self.up2 = Up(64, 32, num_groups)
        #self.up3 = Up(32, 32, num_groups)
        #self.outc = OutConv(32, out_channels)
        self.Flatten = nn.Flatten()
        self._initialize_weights()
        self._load_checkpoint()

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x_flat = self.Flatten(x4)
        x_flat = torch.nn.functional.normalize(x_flat, p=2.0, dim=1)
        #x = x_flat.view(x4.size())
        #x = self.up1(x)
        #x = self.up2(x)
        #x = self.up3(x)
        #logits = self.outc(x)
        return x_flat

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, (nn.Conv3d, nn.ConvTranspose3d)):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, (nn.BatchNorm3d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _load_checkpoint(self):
        # same hard-coded checkpoint as the BatchNorm variant
        chk_pth = '/home/yousef.metwally/projects/no32_64_4_sphere/weights/model_weights_epoch_206.pt'
        checkpoint = torch.load(chk_pth)
        state_dict = {k.replace('module.', ''): v for k, v in checkpoint.items()}
        model_state_dict = self.state_dict()
        filtered_state_dict = {k: v for k, v in state_dict.items() if k in model_state_dict}
        self.load_state_dict(filtered_state_dict)

class UNet_GN(TorchModel):
    def __init__(self) -> None:
        super().__init__()
        self.model = UNet3D(1, 1)

    def init_weights(self):
        self.model._initialize_weights()
        self.model._load_checkpoint()

    def get_model(self) -> nn.Module:
        return self.model
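The only change from the previous file is the normalizer: GroupNorm computes statistics per sample over channel groups, so it is independent of batch size and behaves identically at train and eval time, whereas BatchNorm3d's batch statistics become noisy for small 3D batches and diverge between the two modes. A small illustration of the num_groups=8 default (8 groups of 4 channels at 32 channels):

```python
import torch
import torch.nn as nn

x = torch.randn(1, 32, 8, 8, 8)    # batch size 1 is fine for GroupNorm

# 8 groups of 4 channels each; statistics are per-sample, per-group
gn = nn.GroupNorm(num_groups=8, num_channels=32)
y = gn(x)
print(y.shape)                      # torch.Size([1, 32, 8, 8, 8])

# each group (4 channels x 8**3 voxels) is normalized to mean ~0, std ~1
g = y.view(1, 8, 4 * 8 * 8 * 8)
print(g.mean(dim=-1).abs().max())   # ~0
print(g.std(dim=-1))                # ~1
```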
@@ -0,0 +1,127 @@
from typing import Dict, Union
import torch
import torch.nn as nn
from tomotwin.modules.networks.torchmodel import TorchModel

class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv3d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm3d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm3d(out_channels)

        # identity shortcut, replaced by a 1x1 projection whenever the input
        # shape differs from the residual branch's output shape
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv3d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm3d(out_channels)
            )

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = self.relu(out)
        return out

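# Shortcut behavior at a glance: with matching shapes the shortcut is the
# identity (an empty nn.Sequential); when stride or channel count changes it
# becomes a 1x1 Conv3d + BatchNorm3d projection so the addition type-checks.
# For example (shapes chosen for illustration, not from the original file):
#   ResidualBlock(64, 64).shortcut  -> Sequential() (identity)
#   ResidualBlock(1, 64).shortcut   -> Conv3d(1, 64, kernel_size=1) + BatchNorm3d(64)
#   ResidualBlock(1, 64)(torch.randn(2, 1, 16, 16, 16)).shape
#                                   -> torch.Size([2, 64, 16, 16, 16])
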
class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(DoubleConv, self).__init__()
        self.double_conv = nn.Sequential(
            ResidualBlock(in_channels, out_channels),
            ResidualBlock(out_channels, out_channels)
        )

    def forward(self, x):
        return self.double_conv(x)


class Down(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(Down, self).__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool3d(2),
            DoubleConv(in_channels, out_channels)
        )

    def forward(self, x):
        return self.maxpool_conv(x)


class Up(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(Up, self).__init__()
        self.up = nn.ConvTranspose3d(in_channels, in_channels, kernel_size=2, stride=2)
        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1):
        x1 = self.up(x1)
        return self.conv(x1)


class OutConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(OutConv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv3d(in_channels, out_channels, kernel_size=3, padding='same'),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.conv(x)

class UNet3D(nn.Module):
    def __init__(self, n_channels, out_channels):
        super(UNet3D, self).__init__()
        self.n_channels = n_channels
        self.out_channels = out_channels

        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 64)
        self.down2 = Down(64, 64)
        self.down3 = Down(64, 64)
        # extra conv head on the bottleneck before flattening
        self.ls = nn.Conv3d(64, 64, kernel_size=3, padding='same')

        self.Flatten = nn.Flatten()
        self._initialize_weights()

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        ls = self.ls(x4)
        x_flat = self.Flatten(ls)
        x_flat = torch.nn.functional.normalize(x_flat, p=2.0, dim=1)
        return x_flat

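    # Shape walkthrough, assuming a cubic input whose side is divisible by 8
    # (three MaxPool3d(2) stages); e.g. for x of shape (N, 1, 32, 32, 32):
    #   inc   -> (N, 64, 32, 32, 32)   'same'-padded convs keep spatial size
    #   down1 -> (N, 64, 16, 16, 16)
    #   down2 -> (N, 64,  8,  8,  8)
    #   down3 -> (N, 64,  4,  4,  4)
    #   ls    -> (N, 64,  4,  4,  4)
    #   flatten + normalize -> (N, 64 * 4**3) = (N, 4096) unit-length embedding
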
    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, (nn.Conv3d, nn.ConvTranspose3d)):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm3d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _load_checkpoint(self):
        # retained for parity with the other variants, but not called here:
        # __init__ only runs _initialize_weights, and the wrapper below keeps
        # the checkpoint load commented out
        chk_pth = '/home/yousef.metwally/projects/no32_64_4_sphere/weights/model_weights_epoch_206.pt'
        checkpoint = torch.load(chk_pth)
        state_dict = {k.replace('module.', ''): v for k, v in checkpoint.items()}
        model_state_dict = self.state_dict()
        filtered_state_dict = {k: v for k, v in state_dict.items() if k in model_state_dict}
        self.load_state_dict(filtered_state_dict)

class resunet(TorchModel):
    def __init__(self) -> None:
        super().__init__()
        self.model = UNet3D(1, 1)

    def init_weights(self):
        self.model._initialize_weights()
        #self.model._load_checkpoint()

    def get_model(self) -> nn.Module:
        return self.model
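Since every variant L2-normalizes its output, the "COSINE" distance named in the config reduces to a dot product: for unit vectors e1 and e2, cosine similarity is e1 · e2 and cosine distance is 1 - e1 · e2. A quick check with the resunet encoder, which runs without the checkpoint:

```python
import torch

model = resunet().get_model().eval()
with torch.no_grad():
    emb = model(torch.randn(2, 1, 32, 32, 32))   # two unit-length embeddings

cos_sim = emb[0] @ emb[1]            # dot product == cosine similarity here
cos_dist = 1.0 - cos_sim
print(float(cos_sim), float(cos_dist))

# equivalent, without relying on the normalization:
print(torch.nn.functional.cosine_similarity(emb[0], emb[1], dim=0))
```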