model.py
# The SELDnet architecture
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from ngcc.model import NGCCPHAT


class ConvBlock(nn.Module):
    """3x3 convolution -> batch norm -> ReLU; stride (1, 1) and padding (1, 1) preserve resolution."""

    def __init__(self, in_channels, out_channels, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)):
        super().__init__()
        self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                              kernel_size=kernel_size, stride=stride, padding=padding)
        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        x = F.relu(self.bn(self.conv(x)))
        return x


class NGCCModel(torch.nn.Module):
    def __init__(self, in_feat_shape, out_shape, params, in_vid_feat_shape=None):
        super().__init__()
        self.ngcc_channels = params['ngcc_channels']
        self.ngcc_out_channels = params['ngcc_out_channels']
        self.mel_bins = params['nb_mel_bins']
        self.fs = params['fs']
        self.sig_len = int(self.fs * params['hop_len_s'])  # 480 samples
        self.predict_tdoa = params['predict_tdoa']
        self.pool_len = int(params['t_pool_size'][0])

        # Channels fed to the 2D CNN: ngcc_out_channels feature maps for each of the
        # n_mics * (n_mics - 1) / 2 microphone pairs, plus n_mics per-microphone
        # channels when use_mel is set (or ngcc_out_channels per microphone otherwise).
        if params['use_mel']:
            self.in_channels = int(self.ngcc_out_channels * params['n_mics'] * (params['n_mics'] - 1) / 2 + params['n_mics'])
        else:
            self.in_channels = int(self.ngcc_out_channels * params['n_mics'] * (1 + (params['n_mics'] - 1) / 2))
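        # Worked example (illustrative values, not from any shipped config): with
        # n_mics = 4 there are 4 * 3 / 2 = 6 microphone pairs, so with use_mel and
        # ngcc_out_channels = 16 this gives 6 * 16 + 4 = 100 input channels.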
        self.ngcc = NGCCPHAT(max_tau=params['max_tau'], n_mel_bins=self.mel_bins, use_sinc=True,
                             sig_len=self.sig_len, num_channels=self.ngcc_channels,
                             num_out_channels=self.ngcc_out_channels, fs=self.fs,
                             normalize_input=False, normalize_output=False, pool_len=1,
                             use_mel=params['use_mel'], use_mfcc=params['use_mfcc'],
                             predict_tdoa=params['predict_tdoa'], tracks=params['tracks'],
                             fixed_tdoa=params['fixed_tdoa'])

        self.nb_classes = params['unique_classes']
        self.params = params

        # CNN feature extractor: one ConvBlock + max pooling + dropout per entry in f_pool_size.
        self.conv_block_list = nn.ModuleList()
        if len(params['f_pool_size']):
            for conv_cnt in range(len(params['f_pool_size'])):
                self.conv_block_list.append(ConvBlock(in_channels=params['nb_cnn2d_filt'] if conv_cnt else self.in_channels,
                                                      out_channels=params['nb_cnn2d_filt']))
                self.conv_block_list.append(nn.MaxPool2d((params['t_pool_size'][conv_cnt], params['f_pool_size'][conv_cnt])))
                self.conv_block_list.append(nn.Dropout2d(p=params['dropout_rate']))

        # Bidirectional GRU over the time dimension.
        self.gru_input_dim = params['nb_cnn2d_filt'] * int(np.floor(self.mel_bins / np.prod(params['f_pool_size'])))
        self.gru = torch.nn.GRU(input_size=self.gru_input_dim, hidden_size=params['rnn_size'],
                                num_layers=params['nb_rnn_layers'], batch_first=True,
                                dropout=params['dropout_rate'], bidirectional=True)

        # Self-attention blocks with residual connections and layer normalization.
        self.mhsa_block_list = nn.ModuleList()
        self.layer_norm_list = nn.ModuleList()
        for mhsa_cnt in range(params['nb_self_attn_layers']):
            self.mhsa_block_list.append(nn.MultiheadAttention(embed_dim=self.params['rnn_size'],
                                                              num_heads=self.params['nb_heads'],
                                                              dropout=self.params['dropout_rate'],
                                                              batch_first=True))
            self.layer_norm_list.append(nn.LayerNorm(self.params['rnn_size']))

        # Audio-visual fusion layers (only built when video features are provided).
        if in_vid_feat_shape is not None:
            self.visual_embed_to_d_model = nn.Linear(in_features=int(in_vid_feat_shape[2] * in_vid_feat_shape[3]),
                                                     out_features=self.params['rnn_size'])
            self.transformer_decoder_layer = nn.TransformerDecoderLayer(d_model=self.params['rnn_size'],
                                                                        nhead=self.params['nb_heads'],
                                                                        batch_first=True)
            self.transformer_decoder = nn.TransformerDecoder(self.transformer_decoder_layer,
                                                             num_layers=self.params['nb_transformer_layers'])

        # Fully connected output head.
        self.fnn_list = torch.nn.ModuleList()
        if params['nb_fnn_layers']:
            for fc_cnt in range(params['nb_fnn_layers']):
                self.fnn_list.append(nn.Linear(params['fnn_size'] if fc_cnt else self.params['rnn_size'],
                                               params['fnn_size'], bias=True))
        self.fnn_list.append(nn.Linear(params['fnn_size'] if params['nb_fnn_layers'] else self.params['rnn_size'],
                                       out_shape[-1], bias=True))

        self.doa_act = nn.Tanh()
        self.dist_act = nn.ReLU()
    def forward(self, x, vid_feat=None):
        """input: (batch_size, mic_channels, time_steps, sig_len)"""
        # NGCC-PHAT front end: raw multichannel audio -> 2D feature maps
        # (and optionally per-frame TDOA estimates).
        if self.predict_tdoa:
            x, tdoa = self.ngcc(x)
        else:
            x = self.ngcc(x)

        # CNN feature extractor.
        for conv_cnt in range(len(self.conv_block_list)):
            x = self.conv_block_list[conv_cnt](x)

        # (batch, channels, time, freq) -> (batch, time, channels * freq) for the GRU.
        x = x.transpose(1, 2).contiguous()
        x = x.view(x.shape[0], x.shape[1], -1).contiguous()

        (x, _) = self.gru(x)
        x = torch.tanh(x)
        # Gate the two GRU directions against each other: the backward half of the
        # hidden state multiplies the forward half, halving the feature dimension.
        x = x[:, :, x.shape[-1] // 2:] * x[:, :, :x.shape[-1] // 2]

        # Self-attention blocks with residual connections and layer norm.
        for mhsa_cnt in range(len(self.mhsa_block_list)):
            x_attn_in = x
            x, _ = self.mhsa_block_list[mhsa_cnt](x_attn_in, x_attn_in, x_attn_in)
            x = x + x_attn_in
            x = self.layer_norm_list[mhsa_cnt](x)

        # Audio-visual fusion: decode the audio sequence against the embedded video features.
        if vid_feat is not None:
            vid_feat = vid_feat.view(vid_feat.shape[0], vid_feat.shape[1], -1)  # b x 50 x 49
            vid_feat = self.visual_embed_to_d_model(vid_feat)
            x = self.transformer_decoder(x, vid_feat)

        # Fully connected output head.
        for fnn_cnt in range(len(self.fnn_list) - 1):
            x = self.fnn_list[fnn_cnt](x)
        doa = self.fnn_list[-1](x)

        if self.predict_tdoa:
            return doa, tdoa.mean(dim=1, keepdim=True)  # tdoa[:, ::self.pool_len]  # pool tdoas to get correct resolution
        else:
            return doa
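

# Minimal smoke test (a sketch, not part of the original training pipeline).
# It only exercises the self-contained pieces -- ConvBlock and the bidirectional
# GRU gating used in NGCCModel.forward() -- because building a full NGCCModel
# needs the external NGCCPHAT module and a complete params dict. All sizes below
# (batch, time_steps, mel_bins, channel counts, rnn_size) are illustrative
# assumptions, not values from the repository's config.
if __name__ == "__main__":
    batch, time_steps, mel_bins = 2, 50, 64

    # ConvBlock keeps time/frequency resolution (3x3 kernel, stride 1, padding 1).
    block = ConvBlock(in_channels=8, out_channels=16)
    feat = torch.rand(batch, 8, time_steps, mel_bins)
    assert block(feat).shape == (batch, 16, time_steps, mel_bins)

    # Bidirectional GRU followed by the forward/backward gating from forward():
    # the two direction halves of the output are multiplied elementwise,
    # bringing the feature dimension back down to rnn_size.
    rnn_size = 32
    gru = nn.GRU(input_size=16 * mel_bins, hidden_size=rnn_size,
                 num_layers=2, batch_first=True, bidirectional=True)
    x = block(feat).transpose(1, 2).contiguous().view(batch, time_steps, -1)
    x, _ = gru(x)
    x = torch.tanh(x)
    x = x[:, :, x.shape[-1] // 2:] * x[:, :, :x.shape[-1] // 2]
    assert x.shape == (batch, time_steps, rnn_size)
    print("smoke test passed:", tuple(x.shape))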