Commit

trying CSRNet

Former-commit-id: c5b487c932881c1da211e8c95425962746eca0ed
Javi Ribera committed Nov 8, 2018
1 parent 3fd4ca4 commit 4558009
Showing 5 changed files with 89 additions and 7 deletions.
1 change: 1 addition & 0 deletions environment.yml
@@ -23,6 +23,7 @@ dependencies:
  - tqdm=4.23.1
  - xmltodict=0.11.0
  - pytorch=0.4.0
+ - h5py
  - pip:
    - ballpark==1.4.0
    - visdom==0.1.7
60 changes: 60 additions & 0 deletions object-locator/models/csrnet.py
@@ -0,0 +1,60 @@
import torch.nn as nn
import torch
from torchvision import models
from .utils import save_net, load_net
import copy


class CSRNet(nn.Module):
    def __init__(self, load_weights=False):
        super(CSRNet, self).__init__()
        self.seen = 0
        self.frontend_feat = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M',
                              512, 512, 512]
        self.backend_feat = [512, 512, 512, 256, 128, 64, 'U8']
        self.frontend = make_layers(self.frontend_feat)
        self.backend = make_layers(self.backend_feat,
                                   in_channels=512, dilation=True)
        self.output_layer = nn.Conv2d(64, 1, kernel_size=1)
        if not load_weights:
            mod = models.vgg16(pretrained=True)
            self._initialize_weights()
            # Copy the pretrained VGG-16 weights into the frontend,
            # matching parameters by position.
            for k_frontend, k_mod in zip(self.frontend.state_dict(),
                                         mod.state_dict()):
                self.frontend.state_dict()[k_frontend].data[:] = \
                    mod.state_dict()[k_mod].data[:]

    def forward(self, x):
        x = self.frontend(x)
        x = self.backend(x)
        x = self.output_layer(x)
        return x

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.normal_(m.weight, std=0.01)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)


def make_layers(cfg, in_channels=3, batch_norm=False, dilation=False):
    if dilation:
        d_rate = 2
    else:
        d_rate = 1
    layers = []
    for v in cfg:
        if v == 'U8':
            # Upsample the density map back to the input resolution.
            # 'bilinear' (not 'trilinear') is the valid mode for
            # 4-D (N, C, H, W) feature maps.
            layers += [nn.Upsample(scale_factor=8,
                                   mode='bilinear',
                                   align_corners=True)]
        elif v == 'M':
            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
        else:
            conv2d = nn.Conv2d(in_channels, v, kernel_size=3,
                               padding=d_rate, dilation=d_rate)
            if batch_norm:
                layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
            else:
                layers += [conv2d, nn.ReLU(inplace=True)]
            in_channels = v
    return nn.Sequential(*layers)
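
As a quick sanity check of the new module (a hypothetical sketch, not part of this commit): the three 'M' max-pool stages in the frontend shrink the feature maps by 8x and the final 'U8' stage scales them back up, so the output should be a single-channel map at the input resolution.

# Hypothetical shape check; run from inside the package, since csrnet.py
# uses a relative import. load_weights=True skips the VGG-16 download.
import torch
from .models import csrnet

model = csrnet.CSRNet(load_weights=True)
x = torch.rand(1, 3, 256, 256)   # dummy RGB batch
with torch.no_grad():
    y = model(x)
print(y.shape)                   # expected: torch.Size([1, 1, 256, 256])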
18 changes: 18 additions & 0 deletions object-locator/models/utils.py
@@ -0,0 +1,18 @@
import h5py
import numpy as np
import shutil
import torch


def save_net(fname, net):
    with h5py.File(fname, 'w') as h5f:
        for k, v in net.state_dict().items():
            h5f.create_dataset(k, data=v.cpu().numpy())


def load_net(fname, net):
    with h5py.File(fname, 'r') as h5f:
        for k, v in net.state_dict().items():
            param = torch.from_numpy(np.asarray(h5f[k]))
            v.copy_(param)


def save_checkpoint(state, is_best, task_id, filename='checkpoint.pth.tar'):
    torch.save(state, task_id + filename)
    if is_best:
        shutil.copyfile(task_id + filename, task_id + 'model_best.pth.tar')
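
A hypothetical round-trip of these helpers (not part of the commit): save_net writes one HDF5 dataset per state_dict entry, and load_net copies each dataset back into the corresponding tensor in place.

# Hypothetical usage sketch, assuming any nn.Module:
import torch.nn as nn
from .models.utils import save_net, load_net

net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Conv2d(8, 1, 1))
save_net('weights.h5', net)   # one dataset per parameter/buffer
load_net('weights.h5', net)   # copy each dataset back into the model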
15 changes: 9 additions & 6 deletions object-locator/train.py
@@ -43,6 +43,7 @@
from . import logger
from . import argparser
from . import utils
+from .models import csrnet


# Parse command line arguments
@@ -71,7 +72,7 @@
training_transforms = []
if not args.no_data_augm:
    training_transforms += [RandomHorizontalFlipImageAndLabel(p=0.5, seed=args.seed)]
-   training_transforms += [RandomVerticalFlipImageAndLabel(p=0.5, seed=args.seed)]
+   # training_transforms += [RandomVerticalFlipImageAndLabel(p=0.5, seed=args.seed)]
training_transforms += [ScaleImageAndLabel(size=(args.height, args.width))]
training_transforms += [transforms.ToTensor()]
training_transforms += [transforms.Normalize((0.5, 0.5, 0.5),
@@ -103,11 +104,13 @@

# Model
with peter('Building network'):
-   model = unet_model.UNet(3, 1,
-                           height=args.height,
-                           width=args.width,
-                           known_n_points=args.n_points,
-                           device=device)
+   model = csrnet.CSRNet()
+   import ipdb; ipdb.set_trace()  # BREAKPOINT
+   # model = unet_model.UNet(3, 1,
+   #                         height=args.height,
+   #                         width=args.width,
+   #                         known_n_points=args.n_points,
+   #                         device=device)
    num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f" with {ballpark(num_params)} trainable parameters. ", end='')
    model = nn.DataParallel(model)
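
Note that, unlike the UNet it replaces, CSRNet regresses a density map rather than an explicit point list and count, so the rest of the training loop would need adapting (the commit stops at the breakpoint above). A hedged sketch of how a count estimate is typically read off a density map; imgs and model are assumed to come from the surrounding loop:

# Hypothetical, not part of this commit:
with torch.no_grad():
    density_map = model(imgs)   # shape (N, 1, H, W)
# Integrate (sum) the map to get one estimated count per image.
est_counts = density_map.view(density_map.size(0), -1).sum(dim=1)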
2 changes: 1 addition & 1 deletion object-locator/utils.py
@@ -316,7 +316,7 @@ def paint_circles(img, points, color='red', crosshair=False):
    for y, x in points:
        img = cv2.drawMarker(img,
                             (x, y),
-                            color, cv2.MARKER_TILTED_CROSS, 9, 3, cv2.LINE_AA)
+                            color, cv2.MARKER_TILTED_CROSS, 7, 1, cv2.LINE_AA)
    img = np.moveaxis(img, 2, 0)

    return img
