Skip to content

Commit

Permalink
use default types of fixed # of objects
Browse files Browse the repository at this point in the history
Former-commit-id: 21d4462e96cafa82eef74098a216c4a44714e602
  • Loading branch information
Javi Ribera committed Nov 5, 2018
1 parent 918c438 commit dcb57d3
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 3 deletions.
9 changes: 7 additions & 2 deletions object-locator/models/unet_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,12 @@
class UNet(nn.Module):
def __init__(self, n_channels, n_classes,
height, width,
known_n_points=None):
known_n_points=None,
device=torch.device('cuda')):
super(UNet, self).__init__()

self.device = device

# With this network depth, there is a minimum image size
if height < 256 or width < 256:
raise ValueError('Minimum input image size is 256x256, got {}x{}'.\
Expand Down Expand Up @@ -92,7 +95,9 @@ def forward(self, x):

return x, regression
else:
n_pts = torch.tensor([self.known_n_points]*batch_size)
n_pts = torch.tensor([self.known_n_points]*batch_size,
dtype=torch.get_default_dtype())
n_pts = n_pts.to(self.device)
return x, n_pts
# summ = torch.sum(x)
# count = self.lin(summ)
Expand Down
5 changes: 4 additions & 1 deletion object-locator/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,8 @@
model = unet_model.UNet(3, 1,
height=args.height,
width=args.width,
known_n_points=args.n_points)
known_n_points=args.n_points,
device=device)
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f" with {ballpark(num_params)} trainable parameters. ", end='')
model = nn.DataParallel(model)
Expand Down Expand Up @@ -344,6 +345,8 @@

# The 3 terms
with torch.no_grad():
est_counts = est_counts.view(-1)
target_counts = target_counts.view(-1)
term1, term2 = loss_loc.forward(est_maps,
target_locations,
target_orig_sizes)
Expand Down

0 comments on commit dcb57d3

Please sign in to comment.