Skip to content

Commit

Permalink
fix: input to the CNN could only be 256x256
Browse files Browse the repository at this point in the history
  • Loading branch information
Javier Ribera committed Aug 10, 2019
1 parent 8316f72 commit f860f7c
Showing 1 changed file with 7 additions and 2 deletions.
9 changes: 7 additions & 2 deletions object-locator/models/unet_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,15 @@ def __init__(self, n_channels, n_classes,

self.known_n_points = known_n_points
if known_n_points is None:
self.branch_1 = nn.Sequential(nn.Linear(512, 64),
height_mid_features = height//(2**8)
width_mid_features = width//(2**8)
self.branch_1 = nn.Sequential(nn.Linear(height_mid_features*\
width_mid_features*\
512,
64),
nn.ReLU(inplace=True),
nn.Dropout(p=0.5))
self.branch_2 = nn.Sequential(nn.Linear(256*256, 64),
self.branch_2 = nn.Sequential(nn.Linear(height*width, 64),
nn.ReLU(inplace=True),
nn.Dropout(p=0.5))
self.regressor = nn.Sequential(nn.Linear(64 + 64, 1),
Expand Down

0 comments on commit f860f7c

Please sign in to comment.