Skip to content

Commit

Permalink
allow fix number of estimated points
Browse files Browse the repository at this point in the history
  • Loading branch information
Javi Ribera committed Feb 3, 2018
1 parent 512e8a7 commit 2539326
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 22 deletions.
29 changes: 19 additions & 10 deletions test.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,16 +123,6 @@ def __getitem__(self, idx):
batch_size=args.eval_batch_size,
num_workers=args.nThreads)

# Model
print('Building network... ', end='')
# model = unet.UnetGenerator(input_nc=3, output_nc=1, num_downs=8)
model = unet_model.UNet(3, 1, known_n_points=args.n_points)
print('DONE')
print(model)
model = nn.DataParallel(model)
if args.cuda:
model.cuda()

# Loss function
l1_loss = nn.L1Loss()
criterion_training = losses.ModifiedChamferLoss(256, 256, return_2_terms=True)
Expand All @@ -142,9 +132,28 @@ def __getitem__(self, idx):
if os.path.isfile(args.checkpoint):
checkpoint = torch.load(args.checkpoint)
start_epoch = checkpoint['epoch']
# Model
if args.n_points is None:
if 'n_points' not in checkpoint:
# Model will also estimate # of points
model=unet_model.UNet(3, 1, None)
else:
# The checkpoint tells us the # of points to estimate
model=unet_model.UNet(3, 1, checkpoint['n_points'])
else:
# The user tells us the # of points to estimate
model=unet_model.UNet(3, 1, known_n_points=args.n_points)

# Parallelize
model = nn.DataParallel(model)
if args.cuda:
model.cuda()

# Load model in checkpoint
model.load_state_dict(checkpoint['model'])
print("╰─ loaded checkpoint '{}' (now on epoch {})"
.format(args.checkpoint, checkpoint['epoch']))
print(model)
else:
print("╰─ E: no checkpoint found at '{}'".format(args.checkpoint))
exit(-1)
Expand Down
1 change: 1 addition & 0 deletions train_and_validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,6 +454,7 @@ def __getitem__(self, idx):
'model': model.state_dict(),
'lowest_avg_ahd_val': avg_ahd_val_float,
'optimizer': optimizer.state_dict(),
'n_points': args.n_points,
}, args.save)
print("Saved best checkpoint so far in %s " % args.save)

Expand Down
25 changes: 13 additions & 12 deletions unet_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@


class UNet(nn.Module):
def __init__(self, n_channels, n_classes, known_n_points=None):
def __init__(self, n_channels, n_classes,
known_n_points=None):
super(UNet, self).__init__()
self.inc = inconv(n_channels, 64)
self.down1 = down(64, 128)
Expand All @@ -29,12 +30,10 @@ def __init__(self, n_channels, n_classes, known_n_points=None):
self.outc = outconv(64, n_classes)
self.out_nonlin = nn.Sigmoid()

self.regressor = nn.Linear(256*256, 1)
self.regressor_nonlin = nn.Softplus()

self.lin = nn.Linear(1, 1, bias=False)

self.known_n_points = known_n_points
if known_n_points is None:
self.regressor = nn.Linear(288*384, 1)
self.regressor_nonlin = nn.Softplus()

def forward(self, x):
x1 = self.inc(x)
Expand All @@ -57,12 +56,14 @@ def forward(self, x):
x = self.outc(x)
x = self.out_nonlin(x)

x_flat = x.view(1, -1)

regression = self.regressor(x_flat)
regression = self.regressor_nonlin(regression)

return x, regression
if self.known_n_points is None:
x_flat = x.view(1, -1)
regression = self.regressor(x_flat)
regression = self.regressor_nonlin(regression)
return x, regression
else:
n_pts = Variable(torch.cuda.FloatTensor([self.known_n_points]))
return x, n_pts
# summ = torch.sum(x)
# count = self.lin(summ)

Expand Down

0 comments on commit 2539326

Please sign in to comment.