utils.py

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset


class IndexToImageDataset(Dataset):
    """Wrap a dataset to map indices to images.

    In other words, instead of producing (X, y) it produces (idx, X). The label
    y is not relevant for our task.
    """

    def __init__(self, base_dataset):
        self.base = base_dataset

    def __len__(self):
        return len(self.base)

    def __getitem__(self, idx):
        img, _ = self.base[idx]
        return (idx, img)
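
# Usage sketch (illustrative, not part of the original module): wrap any
# (image, label) dataset, e.g. torchvision's MNIST, so that a DataLoader
# yields (index, image) pairs instead of (image, label) pairs:
#
#     from torchvision import datasets, transforms
#     from torch.utils.data import DataLoader
#     mnist = datasets.MNIST(".", download=True, transform=transforms.ToTensor())
#     loader = DataLoader(IndexToImageDataset(mnist), batch_size=32)
#     for idx, img in loader:
#         ...  # idx indexes into the base dataset, img is the image tensor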


def gaussian(x, sigma=1.0):
    return np.exp(-(x**2) / (2*(sigma**2)))


def build_gauss_kernel(size=5, sigma=1.0, n_channels=1, device=None):
    """Construct the convolution kernel for a Gaussian blur.

    See https://en.wikipedia.org/wiki/Gaussian_blur for a definition.
    Overall I first generate a 2xNxN array of indices, and then use those to
    evaluate the Gaussian function on each element. The two-dimensional
    Gaussian kernel is then the product along axis=0.
    Also, in_channels == out_channels == n_channels.
    """
    if size % 2 != 1:
        raise ValueError("kernel size must be odd")
    grid = np.mgrid[range(size), range(size)] - size//2
    kernel = np.prod(gaussian(grid, sigma), axis=0)
    kernel /= np.sum(kernel)
    # Repeat the same kernel for every channel. A conv weight has shape
    # (out_channels, in_channels/groups, h, w), i.e. (n_channels, 1, size, size)
    # here, because the blur is applied depthwise (groups=n_channels).
    kernel = np.tile(kernel, (n_channels, 1, 1, 1))
    kernel = torch.from_numpy(kernel).to(torch.float).to(device)
    return kernel
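
# Quick example (illustrative): build_gauss_kernel(size=5, sigma=1.0, n_channels=3)
# returns a tensor of shape (3, 1, 5, 5) that can be used directly as the weight
# of a depthwise F.conv2d with groups=3, as blur_images does below.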


def blur_images(images, kernel):
    """Convolve the gaussian kernel with the given stack of images."""
    _, n_channels, _, _ = images.shape
    _, _, kw, kh = kernel.shape
    # Replication padding keeps the spatial size unchanged after the convolution.
    imgs_padded = F.pad(images, (kw//2, kh//2, kw//2, kh//2), mode='replicate')
    return F.conv2d(imgs_padded, kernel, groups=n_channels)


def laplacian_pyramid(images, kernel, max_levels=5):
    """Laplacian pyramid of each image.

    https://en.wikipedia.org/wiki/Pyramid_(image_processing)#Laplacian_pyramid
    """
    current = images
    pyramid = []
    for level in range(max_levels):
        filtered = blur_images(current, kernel)
        diff = current - filtered
        pyramid.append(diff)
        current = F.avg_pool2d(filtered, 2)
    pyramid.append(current)
    return pyramid
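
# Structure sketch (illustrative): for a (N, C, 64, 64) batch and max_levels=5,
# the returned list holds five band-pass images at resolutions 64, 32, 16, 8 and 4,
# followed by one final 2x2 low-pass residual, i.e. max_levels + 1 tensors in total.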


class LapLoss(nn.Module):
    """Laplacian pyramid loss.

    L1 distance between corresponding Laplacian pyramid levels of output and
    target, with level j weighted by 2**(-2*j).
    """

    def __init__(self, max_levels=5, kernel_size=5, sigma=1.0):
        super(LapLoss, self).__init__()
        self.max_levels = max_levels
        self.kernel_size = kernel_size
        self.sigma = sigma
        self._gauss_kernel = None

    def forward(self, output, target):
        # (Re)build the blur kernel lazily, once the number of channels is known.
        if (self._gauss_kernel is None
                or self._gauss_kernel.shape[0] != output.shape[1]):
            self._gauss_kernel = build_gauss_kernel(
                size=self.kernel_size, sigma=self.sigma,
                n_channels=output.shape[1], device=output.device)
        output_pyramid = laplacian_pyramid(
            output, self._gauss_kernel, max_levels=self.max_levels)
        target_pyramid = laplacian_pyramid(
            target, self._gauss_kernel, max_levels=self.max_levels)
        diff_levels = [F.l1_loss(o, t)
                       for o, t in zip(output_pyramid, target_pyramid)]
        loss = sum(2**(-2*j) * diff_levels[j] for j in range(self.max_levels))
        return loss
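

# Minimal usage sketch (illustrative, not part of the original file): LapLoss is
# used like any other nn.Module criterion; the tensor shapes below are arbitrary.
if __name__ == "__main__":
    criterion = LapLoss(max_levels=5, kernel_size=5, sigma=1.0)
    output = torch.rand(4, 3, 64, 64, requires_grad=True)  # e.g. a generator output
    target = torch.rand(4, 3, 64, 64)                      # ground-truth images
    loss = criterion(output, target)
    loss.backward()
    print(loss.item())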