diff --git a/test.py b/test.py
index f522568..c30f74c 100644
--- a/test.py
+++ b/test.py
@@ -123,16 +123,6 @@ def __getitem__(self, idx):
                                               batch_size=args.eval_batch_size,
                                               num_workers=args.nThreads)
 
-# Model
-print('Building network... ', end='')
-# model = unet.UnetGenerator(input_nc=3, output_nc=1, num_downs=8)
-model = unet_model.UNet(3, 1, known_n_points=args.n_points)
-print('DONE')
-print(model)
-model = nn.DataParallel(model)
-if args.cuda:
-    model.cuda()
-
 # Loss function
 l1_loss = nn.L1Loss()
 criterion_training = losses.ModifiedChamferLoss(256, 256, return_2_terms=True)
@@ -142,9 +132,28 @@ def __getitem__(self, idx):
 if os.path.isfile(args.checkpoint):
     checkpoint = torch.load(args.checkpoint)
     start_epoch = checkpoint['epoch']
+    # Model
+    if args.n_points is None:
+        if 'n_points' not in checkpoint:
+            # Model will also estimate # of points
+            model = unet_model.UNet(3, 1, None)
+        else:
+            # The checkpoint tells us the # of points to estimate
+            model = unet_model.UNet(3, 1, checkpoint['n_points'])
+    else:
+        # The user tells us the # of points to estimate
+        model = unet_model.UNet(3, 1, known_n_points=args.n_points)
+
+    # Parallelize
+    model = nn.DataParallel(model)
+    if args.cuda:
+        model.cuda()
+
+    # Load model in checkpoint
     model.load_state_dict(checkpoint['model'])
     print("╰─ loaded checkpoint '{}' (now on epoch {})"
           .format(args.checkpoint, checkpoint['epoch']))
+    print(model)
 else:
     print("╰─ E: no checkpoint found at '{}'".format(args.checkpoint))
     exit(-1)
diff --git a/train_and_validate.py b/train_and_validate.py
index cf0daa6..d85cb5e 100644
--- a/train_and_validate.py
+++ b/train_and_validate.py
@@ -454,6 +454,7 @@ def __getitem__(self, idx):
                 'model': model.state_dict(),
                 'lowest_avg_ahd_val': avg_ahd_val_float,
                 'optimizer': optimizer.state_dict(),
+                'n_points': args.n_points,
             }, args.save)
             print("Saved best checkpoint so far in %s " % args.save)
 
diff --git a/unet_model.py b/unet_model.py
index a20519b..d74a754 100644
--- a/unet_model.py
+++ b/unet_model.py
@@ -7,7 +7,8 @@
 
 
 class UNet(nn.Module):
-    def __init__(self, n_channels, n_classes, known_n_points=None):
+    def __init__(self, n_channels, n_classes,
+                 known_n_points=None):
         super(UNet, self).__init__()
         self.inc = inconv(n_channels, 64)
         self.down1 = down(64, 128)
@@ -29,12 +30,10 @@ def __init__(self, n_channels, n_classes, known_n_points=None):
         self.outc = outconv(64, n_classes)
         self.out_nonlin = nn.Sigmoid()
 
-        self.regressor = nn.Linear(256*256, 1)
-        self.regressor_nonlin = nn.Softplus()
-
-        self.lin = nn.Linear(1, 1, bias=False)
-
         self.known_n_points = known_n_points
+        if known_n_points is None:
+            self.regressor = nn.Linear(288*384, 1)
+            self.regressor_nonlin = nn.Softplus()
 
     def forward(self, x):
         x1 = self.inc(x)
@@ -57,12 +56,14 @@ def forward(self, x):
         x = self.outc(x)
         x = self.out_nonlin(x)
 
-        x_flat = x.view(1, -1)
-
-        regression = self.regressor(x_flat)
-        regression = self.regressor_nonlin(regression)
-
-        return x, regression
+        if self.known_n_points is None:
+            x_flat = x.view(1, -1)
+            regression = self.regressor(x_flat)
+            regression = self.regressor_nonlin(regression)
+            return x, regression
+        else:
+            n_pts = Variable(torch.cuda.FloatTensor([self.known_n_points]))
+            return x, n_pts
 
 # summ = torch.sum(x)
 # count = self.lin(summ)
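
Note on the core pattern above: unet_model.py now builds its counting head conditionally. When known_n_points is None, the network learns to regress the point count from the flattened output map (hence the Linear over 288*384, the new map resolution); when the count is known a priori, forward simply returns it as a constant tensor and no regressor parameters are created at all. The sketch below is a minimal, self-contained illustration of that pattern, not the repo's API: CountingHead and prob_map are hypothetical names, and it assumes a modern PyTorch where torch.tensor(..., device=...) replaces the patch's Variable/torch.cuda.FloatTensor (which would fail on a CPU-only machine).

import torch
import torch.nn as nn

class CountingHead(nn.Module):
    # Illustrative sketch of the conditional counting head, assuming a
    # 288x384 probability map as in the patched nn.Linear(288*384, 1).
    def __init__(self, height=288, width=384, known_n_points=None):
        super().__init__()
        self.known_n_points = known_n_points
        if known_n_points is None:
            # No count supplied: learn to estimate it from the flattened map.
            self.regressor = nn.Linear(height * width, 1)
            self.regressor_nonlin = nn.Softplus()  # keeps the estimate positive

    def forward(self, prob_map):
        if self.known_n_points is None:
            flat = prob_map.view(1, -1)
            return self.regressor_nonlin(self.regressor(flat))
        # Count known a priori: return it as a constant on the same device
        # as the input, instead of hard-coding torch.cuda.FloatTensor.
        return torch.tensor([float(self.known_n_points)],
                            device=prob_map.device)

head = CountingHead(known_n_points=None)
estimate = head(torch.rand(1, 1, 288, 384))  # learned count, shape (1, 1)

Because the regressor only exists when known_n_points is None, checkpoints saved in the two modes have different state_dict keys. That is why test.py must decide the mode (args.n_points first, then checkpoint['n_points'], else None) before calling load_state_dict, and why train_and_validate.py now stores 'n_points' in the checkpoint.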