Skip to content

Commit

Permalink
+option to make UNet smaller for pupil dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
Javier Ribera committed Oct 2, 2019
1 parent 02cb134 commit 7432e0b
Show file tree
Hide file tree
Showing 4 changed files with 87 additions and 36 deletions.
17 changes: 17 additions & 0 deletions object-locator/argparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,15 @@ def parse_command_args(training_or_testing):
help="If you know the number of points "
"(e.g, just one pupil), then set it. "
"Otherwise it will be estimated.")
optional_args.add_argument('--ultrasmallnet',
default=False,
action="store_true",
help="If True, the 5 central layers are removed,"
"resulting in a much smaller UNet. "
"This is used for example for the pupil dataset."
"Make sure to enable this if your are restoring "
"a checkpoint that was trained using this option enabled.")

optional_args.add_argument('--lambdaa',
type=strictly_positive,
default=1,
Expand Down Expand Up @@ -322,6 +331,14 @@ def parse_command_args(training_or_testing):
type=int,
metavar='N',
help='Number of data loading threads.')
optional_args.add_argument('--ultrasmallnet',
default=False,
action="store_true",
help="If True, the 5 central layers are removed,"
"resulting in a much smaller UNet. "
"This is used for example for the pupil dataset."
"Make sure to enable this if your are restoring "
"a checkpoint that was trained using this option enabled.")
parser._action_groups.append(optional_args)
args = parser.parse_args()

Expand Down
10 changes: 7 additions & 3 deletions object-locator/locate.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,19 +121,23 @@
model = unet_model.UNet(3, 1,
known_n_points=None,
height=args.height,
width=args.width)
width=args.width,
ultrasmall=args.ultrasmallnet)

else:
# The checkpoint tells us the # of points to estimate
model = unet_model.UNet(3, 1,
known_n_points=checkpoint['n_points'],
height=args.height,
width=args.width)
width=args.width,
ultrasmall=args.ultrasmallnet)
else:
# The user tells us the # of points to estimate
model = unet_model.UNet(3, 1,
known_n_points=args.n_points,
height=args.height,
width=args.width)
width=args.width,
ultrasmall=args.ultrasmallnet)

# Parallelize
if args.cuda:
Expand Down
93 changes: 61 additions & 32 deletions object-locator/models/unet_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,25 @@ class UNet(nn.Module):
def __init__(self, n_channels, n_classes,
height, width,
known_n_points=None,
ultrasmall=False,
device=torch.device('cuda')):
"""
Instantiate a UNet network.
:param n_channels: Number of input channels (e.g, 3 for RGB)
:param n_classes: Number of output classes
:param height: Height of the input images
:param known_n_points: If you know the number of points,
(e.g, one pupil), then set it.
Otherwise it will be estimated by a lateral NN.
If provided, no lateral network will be build
and the resulting UNet will be a FCN.
:param ultrasmall: If True, the 5 central layers are removed,
resulting in a much smaller UNet.
:param device: Which torch device to use. Default: CUDA (GPU).
"""
super(UNet, self).__init__()

self.ultrasmall = ultrasmall
self.device = device

# With this network depth, there is a minimum image size
Expand All @@ -40,27 +56,34 @@ def __init__(self, n_channels, n_classes,
self.inc = inconv(n_channels, 64)
self.down1 = down(64, 128)
self.down2 = down(128, 256)
self.down3 = down(256, 512)
self.down4 = down(512, 512)
self.down5 = down(512, 512)
self.down6 = down(512, 512)
self.down7 = down(512, 512)
self.down8 = down(512, 512, normaliz=False)
self.up1 = up(1024, 512)
self.up2 = up(1024, 512)
self.up3 = up(1024, 512)
self.up4 = up(1024, 512)
self.up5 = up(1024, 256)
self.up6 = up(512, 128)
self.up7 = up(256, 64)
self.up8 = up(128, 64, activ=False)
if self.ultrasmall:
self.down3 = down(256, 512, normaliz=False)
self.up1 = up(768, 128)
self.up2 = up(256, 64)
self.up3 = up(128, 64, activ=False)
else:
self.down3 = down(256, 512)
self.down4 = down(512, 512)
self.down5 = down(512, 512)
self.down6 = down(512, 512)
self.down7 = down(512, 512)
self.down8 = down(512, 512, normaliz=False)
self.up1 = up(1024, 512)
self.up2 = up(1024, 512)
self.up3 = up(1024, 512)
self.up4 = up(1024, 512)
self.up5 = up(1024, 256)
self.up6 = up(512, 128)
self.up7 = up(256, 64)
self.up8 = up(128, 64, activ=False)
self.outc = outconv(64, n_classes)
self.out_nonlin = nn.Sigmoid()

self.known_n_points = known_n_points
if known_n_points is None:
height_mid_features = height//(2**8)
width_mid_features = width//(2**8)
steps = 3 if self.ultrasmall else 8
height_mid_features = height//(2**steps)
width_mid_features = width//(2**steps)
self.branch_1 = nn.Sequential(nn.Linear(height_mid_features*\
width_mid_features*\
512,
Expand All @@ -85,19 +108,24 @@ def forward(self, x):
x2 = self.down1(x1)
x3 = self.down2(x2)
x4 = self.down3(x3)
x5 = self.down4(x4)
x6 = self.down5(x5)
x7 = self.down6(x6)
x8 = self.down7(x7)
x9 = self.down8(x8)
x = self.up1(x9, x8)
x = self.up2(x, x7)
x = self.up3(x, x6)
x = self.up4(x, x5)
x = self.up5(x, x4)
x = self.up6(x, x3)
x = self.up7(x, x2)
x = self.up8(x, x1)
if self.ultrasmall:
x = self.up1(x4, x3)
x = self.up2(x, x2)
x = self.up3(x, x1)
else:
x5 = self.down4(x4)
x6 = self.down5(x5)
x7 = self.down6(x6)
x8 = self.down7(x7)
x9 = self.down8(x8)
x = self.up1(x9, x8)
x = self.up2(x, x7)
x = self.up3(x, x6)
x = self.up4(x, x5)
x = self.up5(x, x4)
x = self.up6(x, x3)
x = self.up7(x, x2)
x = self.up8(x, x1)
x = self.outc(x)
x = self.out_nonlin(x)

Expand All @@ -106,13 +134,14 @@ def forward(self, x):
x = x.squeeze(1)

if self.known_n_points is None:
x9_flat = x9.view(batch_size, -1)
last_layer = x4 if self.ultrasmall else x9
last_layer_flat = last_layer.view(batch_size, -1)
x_flat = x.view(batch_size, -1)

x10_flat = self.branch_1(x9_flat)
lateral_flat = self.branch_1(last_layer_flat)
x_flat = self.branch_2(x_flat)

regression_features = torch.cat((x_flat, x10_flat), dim=1)
regression_features = torch.cat((x_flat, lateral_flat), dim=1)
regression = self.regressor(regression_features)

return x, regression
Expand Down
3 changes: 2 additions & 1 deletion object-locator/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,8 @@
height=args.height,
width=args.width,
known_n_points=args.n_points,
device=device)
device=device,
ultrasmall=args.ultrasmallnet)
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f" with {ballpark(num_params)} trainable parameters. ", end='')
model = nn.DataParallel(model)
Expand Down

0 comments on commit 7432e0b

Please sign in to comment.