Skip to content

Commit

Permalink
Changes to pt_3dresnet allowing for a softmax output when using one-h…
Browse files Browse the repository at this point in the history
…ot encoding of labels, and simply applying sigmoid to map to [0,1] otherwise.
  • Loading branch information
Edwards committed Sep 9, 2020
1 parent b0edf6f commit 3f9b276
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 3 deletions.
5 changes: 5 additions & 0 deletions fets/data/pytorch/ptbrainmagedata.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,11 @@ def __init__(self,

# if we are performing binary classification per pixel, we will disable one_hot conversion of labels
self.binary_classification = self.n_classes == 2

# there is an assumption that for binary classification, new classes are exactly 0 and 1
if self.binary_classification:
if set(class_label_map.values()) != set([0, 1]):
raise ValueError("When performing binary classification, the new labels should be 0 and 1")



Expand Down
13 changes: 11 additions & 2 deletions fets/models/pytorch/brainmage/seg_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,7 +380,7 @@ def forward(self, x1, x2):
return x

class out_conv(nn.Module):
def __init__(self, input_channels, output_channels, leakiness=1e-2, kernel_size=3,
def __init__(self, input_channels, output_channels, binary_classification=True, leakiness=1e-2, kernel_size=3,
conv_bias=True, inst_norm_affine=True, res=True, lrelu_inplace=True):
"""[The Out convolution module to learn the information and use later]
Expand All @@ -391,6 +391,10 @@ def __init__(self, input_channels, output_channels, leakiness=1e-2, kernel_size=
the number of channels from downsample]
output_channels {[int]} -- [the output number of channels, will det-
-ermine the upcoming channels]
binary_classification {[bool]} -- signals that the per-pixel output is one
channel with values between 0 and 1, otherwise
is multi-channel over which a softmax should be
applied
Keyword Arguments:
kernel_size {number} -- [size of filter] (default: {3})
Expand All @@ -402,6 +406,7 @@ def __init__(self, input_channels, output_channels, leakiness=1e-2, kernel_size=
(default: {True})
"""
nn.Module.__init__(self)
self.binary_classification = binary_classification
self.lrelu_inplace = lrelu_inplace
self.inst_norm_affine = inst_norm_affine
self.conv_bias = conv_bias
Expand Down Expand Up @@ -445,7 +450,11 @@ def forward(self, x1, x2):
if self.res == True:
x = x + skip
x = F.leaky_relu(self.in_3(x))
x = F.softmax(self.conv3(x),dim=1)
x = self.conv3(x)
if self.binary_classification:
x = torch.sigmoid(x)
else:
x = F.softmax(x,dim=1)
return x

'''
Expand Down
2 changes: 1 addition & 1 deletion fets/models/pytorch/pt_3dresunet/pt_3dresunet.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def init_network(self, device, print_model=False, **kwargs):
self.us_1 = UpsamplingModule(self.base_filters*4, self.base_filters*2)
self.de_1 = DecodingModule(self.base_filters*4, self.base_filters*2, res=True)
self.us_0 = UpsamplingModule(self.base_filters*2, self.base_filters)
self.out = out_conv(self.base_filters*2, self.label_channels, res=True)
self.out = out_conv(self.base_filters*2, self.label_channels, self.binary_classification, res=True)

if print_model:
print(self)
Expand Down

0 comments on commit 3f9b276

Please sign in to comment.