-
Notifications
You must be signed in to change notification settings - Fork 1
/
mdvrnet.py
89 lines (74 loc) · 2.68 KB
/
mdvrnet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
"""
FastDVDnet denoising algorithm
@author: Matias Tassano <mtassano@parisdescartes.fr>
"""
import torch
import torch.nn.functional as F
def temp_denoise(model, noisyframe, sigma_noise, multiple=4):
    '''Encapsulates call to denoising model and handles padding.

    The last two (spatial) dims of ``noisyframe`` and ``sigma_noise`` are
    padded up to a multiple of ``multiple`` (the denoiser works at two
    scales, hence the default of 4). The model is run, its output is
    clamped to [0., 1.], and the padding is cropped away again.

    Args:
        model: callable invoked as ``model(noisyframe, sigma_noise)``.
        noisyframe: Tensor whose last two dims are (H, W). Expected to be
            normalized in [0., 1.].
        sigma_noise: Tensor with the same spatial size as ``noisyframe``.
            It is padded exactly like ``noisyframe``.
        multiple: required granularity of the spatial size (default 4).
    Returns:
        The clamped model output, cropped back to (H, W) in its last two dims.
    '''
    h, w = noisyframe.shape[-2:]
    # Smallest non-negative amount that makes each spatial dim divisible by `multiple`.
    pad_h = -h % multiple
    pad_w = -w % multiple
    padexp = (0, pad_w, 0, pad_h)
    # 'reflect' padding requires pad < dim. For tiny frames fall back to
    # 'replicate' instead of raising.
    mode = 'reflect' if pad_h < h and pad_w < w else 'replicate'
    noisyframe = F.pad(input=noisyframe, pad=padexp, mode=mode)
    sigma_noise = F.pad(input=sigma_noise, pad=padexp, mode=mode)
    # denoise and decompress
    out = torch.clamp(model(noisyframe, sigma_noise), 0., 1.)
    # Drop the padded rows/columns (a no-op when nothing was padded).
    return out[..., :h, :w]
def _reflect_index(idx, length):
    '''Map a frame index into [0, length) by mirroring about the first and last frame.'''
    if length == 1:
        return 0
    period = 2 * (length - 1)
    idx = abs(idx) % period
    return idx if idx < length else period - idx


def denoise_decompress_seq_mdvrnet(seq, noise_std, temp_psz, model_temporal, q):
    r"""Denoises and decompresses a sequence of frames with MdVRNet.

    Each output frame is computed from a temporal window of ``temp_psz``
    input frames around it. Window positions that fall outside the sequence
    are filled by mirroring about the first/last frame.

    Args:
        seq: Tensor. [numframes, C, H, W] array containing the noisy input frames
        noise_std: either a list with one noise standard deviation per frame,
            or an indexable (e.g. a 1-element Tensor) whose element ``[0]``
            is used for every frame
        temp_psz: size of the temporal patch (number of frames per window)
        model_temporal: instance of the PyTorch model of the temporal denoiser
        q: when ``noise_std`` is a list, a per-frame sequence of compression
            values; otherwise a single value used for every frame
            (presumably a compression quality factor -- confirm with caller)
    Returns:
        denframes: Tensor, [numframes, C, H, W], allocated on ``seq.device``
    """
    numframes, C, H, W = seq.shape
    ctrlfr_idx = int((temp_psz - 1) // 2)
    denframes = torch.empty((numframes, C, H, W), device=seq.device)
    if numframes == 0:
        return denframes
    # Per-frame 2-channel map: channel 0 = noise std, channel 1 = q.
    # Built on seq.device so it matches the frames passed to the model.
    noise_map = torch.zeros((numframes, 2, H, W), device=seq.device)
    if isinstance(noise_std, list):
        for i, current_std in enumerate(noise_std):
            noise_map[i, 0].fill_(current_std)
        for i, current_q in enumerate(q):
            noise_map[i, 1].fill_(current_q)
    else:
        noise_map[:, 0].fill_(noise_std[0])
        noise_map[:, 1].fill_(q)
    for fridx in range(numframes):
        # Window of temp_psz frames starting ctrlfr_idx before the current frame.
        window = [
            seq[_reflect_index(fridx + k - ctrlfr_idx, numframes)]
            for k in range(temp_psz)
        ]
        inframes_t = torch.stack(window, dim=0).reshape((1, temp_psz * C, H, W))
        denframes[fridx] = temp_denoise(
            model_temporal, inframes_t, noise_map[fridx].unsqueeze(0))[0]
    # free memory up
    del window
    del inframes_t
    torch.cuda.empty_cache()
    return denframes