
Commit

update to pytorch 0.3 and fix 1x1 normalization layer
Former-commit-id: 402604b46ef9d13bfcee81444859835d7e14e7b3
Javi Ribera committed Feb 27, 2018
1 parent a3a0c1c commit 4531d92
Showing 3 changed files with 21 additions and 18 deletions.
6 changes: 3 additions & 3 deletions environment.yml
@@ -95,16 +95,16 @@ dependencies:
 - zeromq=4.2.1=1
 - zlib=1.2.8=3
 - cudatoolkit=8.0=3
-- cudnn=6.0.21=cuda8.0_0
+- cudnn=7.0.5=cuda8.0_0
 - intel-openmp=2018.0.0=h15fc484_7
 - libgcc=7.2.0=h69d50b8_2
 - libgcc-ng=7.2.0=h7cc24e2_2
 - libgfortran=3.0.0=1
 - libstdcxx-ng=7.2.0=h7a57d05_2
 - mkl=2018.0.0=hb491cac_4
 - nccl=1.3.4=cuda8.0_1
-- pytorch=0.2.0=py36cuda8.0cudnn6.0_0
 - opencv3=3.1.0=py36_0
+- pytorch=0.3.1=py36_cuda8.0.61_cudnn7.0.5_2
 - cuda80=1.0=0
 - torchvision=0.1.9=py36h7584368_1
 - pip:
@@ -114,7 +114,7 @@ dependencies:
 - pycrayon==0.5
 - pyyaml==3.12
 - pyzmq==17.0.0
-- torch==0.2.0.post0
+- torch==0.3.1.post2
 - torchfile==0.1.0
 - visdom==0.1.7
 prefix: /home/jprat/.anaconda3/envs/plant-location-unet
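The pins above move the environment from PyTorch 0.2.0 with cuDNN 6.0.21 to PyTorch 0.3.1 with cuDNN 7.0.5, still on CUDA 8.0. A quick sanity check after recreating the environment (an illustrative snippet, not part of the commit) confirms the upgraded pair is active:

import torch

# Expected after this commit: torch 0.3.1.post2 and cuDNN 7005 (i.e. 7.0.5).
print(torch.__version__)               # e.g. 0.3.1.post2
print(torch.backends.cudnn.version())  # e.g. 7005
print(torch.cuda.is_available())       # True on a working CUDA 8.0 machine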
2 changes: 1 addition & 1 deletion plant-locator/models/unet_model.py
@@ -29,7 +29,7 @@ def __init__(self, n_channels, n_classes,
         self.down5 = down(512, 512)
         self.down6 = down(512, 512)
         self.down7 = down(512, 512)
-        self.down8 = down(512, 512)
+        self.down8 = down(512, 512, normaliz=False)
         self.up1 = up(1024, 512)
         self.up2 = up(1024, 512)
         self.up3 = up(1024, 512)
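This is the "1x1 normalization layer" fix from the commit message: after eight 2x2 max-pools, a 256x256 input reaches down8 as a 1x1 feature map, and with a batch size of 1, BatchNorm2d then sees exactly one value per channel, so batch statistics are undefined and PyTorch 0.3 crashes in training mode. Passing normaliz=False drops the BatchNorm layers from that block. A minimal sketch of the failure mode, written against the current PyTorch API (0.3 additionally required wrapping inputs in autograd.Variable):

import torch
import torch.nn as nn

# One sample with 1x1 spatial extent: a single value per channel, so
# batch statistics cannot be computed. Recent PyTorch raises ValueError;
# 0.3-era versions failed with a similar error.
bn = nn.BatchNorm2d(512).train()
x = torch.randn(1, 512, 1, 1)  # (batch, channels, height, width)
try:
    bn(x)
except ValueError as err:
    print(err)  # "Expected more than 1 value per channel when training, ..."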
31 changes: 17 additions & 14 deletions plant-locator/models/unet_parts.py
@@ -7,19 +7,22 @@


 class double_conv(nn.Module):
-    def __init__(self, in_ch, out_ch):
+    def __init__(self, in_ch, out_ch, normaliz=True):
         super(double_conv, self).__init__()
-        self.conv = nn.Sequential(
-            nn.Conv2d(in_ch, out_ch, 3, padding=1),
-            # nn.Dropout(p=0.1),
-            # TODO: BATCH NORM WITH A BATCH SIZE OF 1 crashes w/ pytorch 0.3
-            nn.BatchNorm2d(out_ch),
-            nn.ReLU(inplace=True),
-            nn.Conv2d(out_ch, out_ch, 3, padding=1),
-            # nn.Dropout(p=0.1),
-            nn.BatchNorm2d(out_ch),
-            nn.ReLU(inplace=True)
-        )
+
+        ops = []
+        ops += [nn.Conv2d(in_ch, out_ch, 3, padding=1)]
+        # ops += [nn.Dropout(p=0.1)]
+        if normaliz:
+            ops += [nn.BatchNorm2d(out_ch)]
+        ops += [nn.ReLU(inplace=True)]
+        ops += [nn.Conv2d(out_ch, out_ch, 3, padding=1)]
+        # ops += [nn.Dropout(p=0.1)]
+        if normaliz:
+            ops += [nn.BatchNorm2d(out_ch)]
+        ops += [nn.ReLU(inplace=True)]
+
+        self.conv = nn.Sequential(*ops)

     def forward(self, x):
         x = self.conv(x)
@@ -37,11 +40,11 @@ def forward(self, x):


 class down(nn.Module):
-    def __init__(self, in_ch, out_ch):
+    def __init__(self, in_ch, out_ch, normaliz=True):
         super(down, self).__init__()
         self.mpconv = nn.Sequential(
             nn.MaxPool2d(2),
-            double_conv(in_ch, out_ch)
+            double_conv(in_ch, out_ch, normaliz=normaliz)
         )

     def forward(self, x):
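For reference, a condensed, self-contained version of the refactored block (the commented-out dropout lines from the diff are omitted), showing that with normalization disabled a batch-of-one, 1x1 input passes through in training mode; the 3x3 convolutions with padding=1 keep a 1x1 map at 1x1:

import torch
import torch.nn as nn

class double_conv(nn.Module):
    """(conv 3x3 -> [BatchNorm] -> ReLU) twice; BatchNorm is optional."""
    def __init__(self, in_ch, out_ch, normaliz=True):
        super(double_conv, self).__init__()
        ops = []
        ops += [nn.Conv2d(in_ch, out_ch, 3, padding=1)]
        if normaliz:
            ops += [nn.BatchNorm2d(out_ch)]
        ops += [nn.ReLU(inplace=True)]
        ops += [nn.Conv2d(out_ch, out_ch, 3, padding=1)]
        if normaliz:
            ops += [nn.BatchNorm2d(out_ch)]
        ops += [nn.ReLU(inplace=True)]
        self.conv = nn.Sequential(*ops)

    def forward(self, x):
        return self.conv(x)

# The deepest encoder block skips BatchNorm, so the 1x1 case works:
block = double_conv(512, 512, normaliz=False).train()
out = block(torch.randn(1, 512, 1, 1))
print(out.shape)  # torch.Size([1, 512, 1, 1])

Building the layer list imperatively and unpacking it into nn.Sequential(*ops) is what allows BatchNorm to be included conditionally without duplicating the whole block.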
