From 3f9b276f22d70d4750eae694f4aabc3ca9425b36 Mon Sep 17 00:00:00 2001
From: Edwards
Date: Wed, 9 Sep 2020 13:10:18 -0700
Subject: [PATCH] Changes to pt_3dresunet allowing for a softmax output when
 using one-hot encoding of labels, and simply applying sigmoid to map to
 [0,1] otherwise.

---
 fets/data/pytorch/ptbrainmagedata.py             |  5 +++++
 fets/models/pytorch/brainmage/seg_modules.py     | 13 +++++++++++--
 fets/models/pytorch/pt_3dresunet/pt_3dresunet.py |  2 +-
 3 files changed, 17 insertions(+), 3 deletions(-)

diff --git a/fets/data/pytorch/ptbrainmagedata.py b/fets/data/pytorch/ptbrainmagedata.py
index bf7270e..c7a8cb9 100644
--- a/fets/data/pytorch/ptbrainmagedata.py
+++ b/fets/data/pytorch/ptbrainmagedata.py
@@ -112,6 +112,11 @@ def __init__(self,
         # if we are performing binary classification per pixel, we will disable one_hot conversion of labels
         self.binary_classification = self.n_classes == 2
+
+        # there is an assumption that for binary classification, new classes are exactly 0 and 1
+        if self.binary_classification:
+            if set(class_label_map.values()) != set([0, 1]):
+                raise ValueError("When performing binary classification, the new labels should be 0 and 1")

diff --git a/fets/models/pytorch/brainmage/seg_modules.py b/fets/models/pytorch/brainmage/seg_modules.py
index ab5e9b0..10a2113 100644
--- a/fets/models/pytorch/brainmage/seg_modules.py
+++ b/fets/models/pytorch/brainmage/seg_modules.py
@@ -380,7 +380,7 @@ def forward(self, x1, x2):
         return x
 
 class out_conv(nn.Module):
-    def __init__(self, input_channels, output_channels, leakiness=1e-2, kernel_size=3,
+    def __init__(self, input_channels, output_channels, binary_classification=True, leakiness=1e-2, kernel_size=3,
                  conv_bias=True, inst_norm_affine=True, res=True, lrelu_inplace=True):
         """[The Out convolution module to learn the information and use later]
 
@@ -391,6 +391,10 @@ def __init__(self, input_channels, output_channels, leakiness=1e-2, kernel_size=
                                        the number of channels from downsample]
             output_channels {[int]} -- [the output number of channels, will det-
                                         -ermine the upcoming channels]
+            binary_classification {[bool]} -- signals that the per-pixel output is one
+                                              channel with values between 0 and 1, otherwise
+                                              is multi-channel over which a softmax should be
+                                              applied
 
         Keyword Arguments:
             kernel_size {number} -- [size of filter] (default: {3})
@@ -402,6 +406,7 @@ def __init__(self, input_channels, output_channels, leakiness=1e-2, kernel_size=
                                      (default: {True})
         """
         nn.Module.__init__(self)
+        self.binary_classification = binary_classification
         self.lrelu_inplace = lrelu_inplace
         self.inst_norm_affine = inst_norm_affine
         self.conv_bias = conv_bias
@@ -445,7 +450,11 @@ def forward(self, x1, x2):
         if self.res == True:
             x = x + skip
         x = F.leaky_relu(self.in_3(x))
-        x = F.softmax(self.conv3(x),dim=1)
+        x = self.conv3(x)
+        if self.binary_classification:
+            x = torch.sigmoid(x)
+        else:
+            x = F.softmax(x,dim=1)
         return x
 
 '''
diff --git a/fets/models/pytorch/pt_3dresunet/pt_3dresunet.py b/fets/models/pytorch/pt_3dresunet/pt_3dresunet.py
index f142911..0aa34a8 100644
--- a/fets/models/pytorch/pt_3dresunet/pt_3dresunet.py
+++ b/fets/models/pytorch/pt_3dresunet/pt_3dresunet.py
@@ -54,7 +54,7 @@ def init_network(self, device, print_model=False, **kwargs):
         self.us_1 = UpsamplingModule(self.base_filters*4, self.base_filters*2)
         self.de_1 = DecodingModule(self.base_filters*4, self.base_filters*2, res=True)
         self.us_0 = UpsamplingModule(self.base_filters*2, self.base_filters)
-        self.out = out_conv(self.base_filters*2, self.label_channels, res=True)
+        self.out = out_conv(self.base_filters*2, self.label_channels, self.binary_classification, res=True)
 
         if print_model:
             print(self)
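
For reference, the behavior this patch introduces can be illustrated with a minimal standalone sketch. This is not the actual out_conv module from seg_modules.py; the class name TinyOutConv, the 1x1x1 convolution, and the tensor shapes below are assumptions chosen only to demonstrate the sigmoid-vs-softmax selection on the final layer's output.

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class TinyOutConv(nn.Module):
        """Illustrative stand-in for out_conv: a single 3D conv followed by the
        activation selected by binary_classification (sigmoid vs. softmax)."""
        def __init__(self, input_channels, output_channels, binary_classification=True):
            super().__init__()
            self.binary_classification = binary_classification
            self.conv = nn.Conv3d(input_channels, output_channels, kernel_size=1)

        def forward(self, x):
            x = self.conv(x)
            if self.binary_classification:
                # single-channel output mapped to [0, 1] per voxel
                return torch.sigmoid(x)
            # one-hot / multi-channel output normalized over the channel dimension
            return F.softmax(x, dim=1)

    # toy input: batch of 1, 4 feature channels, 8x8x8 volume (shapes are arbitrary)
    x = torch.randn(1, 4, 8, 8, 8)
    print(TinyOutConv(4, 1, binary_classification=True)(x).shape)   # torch.Size([1, 1, 8, 8, 8])
    print(TinyOutConv(4, 3, binary_classification=False)(x).shape)  # torch.Size([1, 3, 8, 8, 8])

In the patch itself this choice is driven by the dataset: ptbrainmagedata.py sets binary_classification when n_classes == 2 (and requires the remapped labels to be exactly 0 and 1), and pt_3dresunet.py forwards that flag into out_conv so the network head matches the label encoding.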