Skip to content

Commit

Permalink
clean up
Browse files (browse the repository at this point in the history)
  • Loading branch information
joonaskalda committed Sep 20, 2023
1 parent 3fd95c6 commit c2b3fcb
Show file tree
Hide file tree
Showing 2 changed files with 116 additions and 450 deletions.
29 changes: 13 additions & 16 deletions pyannote/audio/models/segmentation/SepDiarNet.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -127,7 +127,9 @@ def __init__(
encoder_decoder = merge_dict(self.ENCODER_DECODER_DEFAULTS, encoder_decoder)
self.n_src = n_sources
self.use_lstm = use_lstm
self.save_hyperparameters("encoder_decoder", "lstm", "linear", "convnet", "dprnn")
self.save_hyperparameters(
"encoder_decoder", "lstm", "linear", "convnet", "dprnn"
)
self.learning_rate = lr
self.n_sources = n_sources

Expand All @@ -141,11 +143,12 @@ def __init__(
sample_rate=sample_rate, **self.hparams.encoder_decoder
)
self.masker = DPRNN(n_feats_out, n_src=n_sources, **self.hparams.dprnn)
#self.convnet= TDConvNet(n_feats_out, **self.hparams.convnet)

# diarization can use a lower resolution than separation
diarization_scaling = int(256 / encoder_decoder["kernel_size"])
self.average_pool = nn.AvgPool1d(diarization_scaling, stride=diarization_scaling)

# diarization can use a lower resolution than separation; use 128x downsampling
diarization_scaling = int(128 / encoder_decoder["stride"])
self.average_pool = nn.AvgPool1d(
diarization_scaling, stride=diarization_scaling
)

if use_lstm:
monolithic = lstm["monolithic"]
Expand All @@ -169,7 +172,8 @@ def __init__(
nn.LSTM(
n_feats_out
if i == 0
else lstm["hidden_size"] * (2 if lstm["bidirectional"] else 1),
else lstm["hidden_size"]
* (2 if lstm["bidirectional"] else 1),
**one_layer_lstm
)
for i in range(num_layers)
Expand All @@ -178,14 +182,14 @@ def __init__(

if linear["num_layers"] < 1:
return

if use_lstm:
lstm_out_features: int = self.hparams.lstm["hidden_size"] * (
2 if self.hparams.lstm["bidirectional"] else 1
)
else:
lstm_out_features = 64

self.linear = nn.ModuleList(
[
nn.Linear(in_features, out_features)
Expand All @@ -207,14 +211,7 @@ def build(self):
2 if self.hparams.lstm["bidirectional"] else 1
)

# if isinstance(self.specifications, tuple):
# raise ValueError("PyanNet does not support multi-tasking.")

# if self.specifications.powerset:
out_features = 1
# else:
# out_features = len(self.specifications.classes)

self.classifier = nn.Linear(in_features, out_features)
self.activation = self.default_activation()

Expand Down
Loading

0 comments on commit c2b3fcb

Please sign in to comment.