# NOTE: removed GitHub web-page scrape residue (site navigation text and a
# line-number gutter) that preceded the actual source of search_attenuation.py.
# -*- coding: utf-8 -*-
"""Search attenuation.ipynb
Automatically generated by Colaboratory.
Original file is located at
https://colab.research.google.com/drive/1u0g4CNA25CcO7i8VtJUwBYeCbdxIw7DA
"""
#@title Search Attention
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.nn.parameter import Parameter
import numpy as np
import scipy.stats as st
def _get_kernel(kernlen=16, nsig=3):
interval = (2*nsig+1.)/kernlen
x = np.linspace(-nsig-interval/2., nsig+interval/2., kernlen+1)
kern1d = np.diff(st.norm.cdf(x))
kernel_raw = np.sqrt(np.outer(kern1d, kern1d))
kernel = kernel_raw/kernel_raw.sum()
return kernel
def min_max_norm(in_):
    """Min-max normalize each (N, C) map of a 4-D tensor over its H x W plane.

    :param in_: tensor of shape (N, C, H, W)
    :return: tensor of the same shape with every spatial map rescaled to
             [0, 1]; the 1e-8 in the denominator guards a zero range
    """
    hi = in_.max(dim=3, keepdim=True)[0].max(dim=2, keepdim=True)[0]
    lo = in_.min(dim=3, keepdim=True)[0].min(dim=2, keepdim=True)[0]
    # (N, C, 1, 1) extrema broadcast against the full map
    return (in_ - lo).div(hi - lo + 1e-8)
class SA(nn.Module):
    """Holistic ("search") attention.

    Blurs an attention map with a fixed 31x31 Gaussian-like kernel,
    min-max normalizes the result, and gates the feature tensor with the
    element-wise maximum of the blurred and the original attention.
    """

    def __init__(self):
        super(SA, self).__init__()
        # (1, 1, 31, 31) kernel for F.conv2d; nsig=4 controls its spread.
        kernel = np.float32(_get_kernel(31, 4))[np.newaxis, np.newaxis, ...]
        self.gaussian_kernel = Parameter(torch.from_numpy(kernel))

    def forward(self, attention, x):
        # padding=15 keeps spatial size unchanged for the 31x31 kernel.
        blurred = F.conv2d(attention, self.gaussian_kernel, padding=15)
        blurred = min_max_norm(blurred)  # rescale each map to [0, 1]
        # element-wise max of blurred and raw attention gates the features
        return torch.mul(x, torch.max(blurred, attention))