From c2b3fcb7c409b89cca11700cb27aba7ae8c7baba Mon Sep 17 00:00:00 2001 From: joonaskalda Date: Wed, 20 Sep 2023 13:53:23 +0300 Subject: [PATCH] clean up --- .../audio/models/segmentation/SepDiarNet.py | 29 +- .../speaker_separation_diarization.py | 537 ++++-------------- 2 files changed, 116 insertions(+), 450 deletions(-) diff --git a/pyannote/audio/models/segmentation/SepDiarNet.py b/pyannote/audio/models/segmentation/SepDiarNet.py index 6db75f4cb..a0e7069ef 100644 --- a/pyannote/audio/models/segmentation/SepDiarNet.py +++ b/pyannote/audio/models/segmentation/SepDiarNet.py @@ -127,7 +127,9 @@ def __init__( encoder_decoder = merge_dict(self.ENCODER_DECODER_DEFAULTS, encoder_decoder) self.n_src = n_sources self.use_lstm = use_lstm - self.save_hyperparameters("encoder_decoder", "lstm", "linear", "convnet", "dprnn") + self.save_hyperparameters( + "encoder_decoder", "lstm", "linear", "convnet", "dprnn" + ) self.learning_rate = lr self.n_sources = n_sources @@ -141,11 +143,12 @@ def __init__( sample_rate=sample_rate, **self.hparams.encoder_decoder ) self.masker = DPRNN(n_feats_out, n_src=n_sources, **self.hparams.dprnn) - #self.convnet= TDConvNet(n_feats_out, **self.hparams.convnet) - - # diarization can use a lower resolution than separation - diarization_scaling = int(256 / encoder_decoder["kernel_size"]) - self.average_pool = nn.AvgPool1d(diarization_scaling, stride=diarization_scaling) + + # diarization can use a lower resolution than separation, use 128x downsampling + diarization_scaling = int(128 / encoder_decoder["stride"]) + self.average_pool = nn.AvgPool1d( + diarization_scaling, stride=diarization_scaling + ) if use_lstm: monolithic = lstm["monolithic"] @@ -169,7 +172,8 @@ def __init__( nn.LSTM( n_feats_out if i == 0 - else lstm["hidden_size"] * (2 if lstm["bidirectional"] else 1), + else lstm["hidden_size"] + * (2 if lstm["bidirectional"] else 1), **one_layer_lstm ) for i in range(num_layers) @@ -178,14 +182,14 @@ def __init__( if linear["num_layers"] < 1: return - + if use_lstm: lstm_out_features: int = self.hparams.lstm["hidden_size"] * ( 2 if self.hparams.lstm["bidirectional"] else 1 ) else: lstm_out_features = 64 - + self.linear = nn.ModuleList( [ nn.Linear(in_features, out_features) @@ -207,14 +211,7 @@ def build(self): 2 if self.hparams.lstm["bidirectional"] else 1 ) - # if isinstance(self.specifications, tuple): - # raise ValueError("PyanNet does not support multi-tasking.") - - # if self.specifications.powerset: out_features = 1 - # else: - # out_features = len(self.specifications.classes) - self.classifier = nn.Linear(in_features, out_features) self.activation = self.default_activation() diff --git a/pyannote/audio/tasks/segmentation/speaker_separation_diarization.py b/pyannote/audio/tasks/segmentation/speaker_separation_diarization.py index d3872b9c8..5e6fe0897 100644 --- a/pyannote/audio/tasks/segmentation/speaker_separation_diarization.py +++ b/pyannote/audio/tasks/segmentation/speaker_separation_diarization.py @@ -56,7 +56,12 @@ from pyannote.audio.utils.loss import binary_cross_entropy, mse_loss, nll_loss from pyannote.audio.utils.permutation import permutate from pyannote.audio.utils.powerset import Powerset -from asteroid.losses import MixITLossWrapper, multisrc_neg_sisdr, PITLossWrapper, pairwise_neg_sisdr +from asteroid.losses import ( + MixITLossWrapper, + multisrc_neg_sisdr, + PITLossWrapper, + pairwise_neg_sisdr, +) from torch.utils.data._utils.collate import default_collate Subsets = list(Subset.__args__) @@ -70,255 +75,6 @@ from torch.utils.data import DataLoader, Dataset, IterableDataset from pyannote.audio.utils.random import create_rng_for_worker -class CountingCallback(Callback): - def on_train_epoch_start(self, trainer, pl_module) -> None: - "reset counters" - if pl_module.task.log_alignment_accuracy: - pl_module.task.num_correct = 0 - pl_module.task.num_total = 0 - pl_module.task.num_correct30 = 0 - pl_module.task.num_correct21 = 0 - pl_module.task.num_correct20 = 0 - pl_module.task.num_correct11 = 0 - pl_module.task.num_correct10 = 0 - pl_module.task.num_total30 = 0 - pl_module.task.num_total21 = 0 - pl_module.task.num_total20 = 0 - pl_module.task.num_total11 = 0 - pl_module.task.num_total10 = 0 - -class CustomMixITLossWrapper(nn.Module): - r"""Custom mixture invariant loss wrapper that returns the best partition - so that it can be checked against the partition determined by forced - alignment. - - Args: - loss_func: function with signature (est_targets, targets, **kwargs). - generalized (bool): Determines how MixIT is applied. If False , - apply MixIT for any number of mixtures as soon as they contain - the same number of sources (:meth:`~MixITLossWrapper.best_part_mixit`.) - If True (default), apply MixIT for two mixtures, but those mixtures do not - necessarly have to contain the same number of sources. - See :meth:`~MixITLossWrapper.best_part_mixit_generalized`. - reduction (string, optional): Specifies the reduction to apply to - the output: - ``'none'`` | ``'mean'``. ``'none'``: no reduction will be applied, - ``'mean'``: the sum of the output will be divided by the number of - elements in the output. - - For each of these modes, the best partition and reordering will be - automatically computed. - - Examples: - >>> import torch - >>> from asteroid.losses import multisrc_mse - >>> mixtures = torch.randn(10, 2, 16000) - >>> est_sources = torch.randn(10, 4, 16000) - >>> # Compute MixIT loss based on pairwise losses - >>> loss_func = MixITLossWrapper(multisrc_mse) - >>> loss_val = loss_func(est_sources, mixtures) - - References - [1] Scott Wisdom et al. "Unsupervised sound separation using - mixtures of mixtures." arXiv:2006.12701 (2020) - """ - - def __init__(self, loss_func, generalized=True, reduction="mean"): - super().__init__() - self.loss_func = loss_func - self.generalized = generalized - self.reduction = reduction - - def forward(self, est_targets, targets, return_est=False, **kwargs): - r"""Find the best partition and return the loss. - - Args: - est_targets: torch.Tensor. Expected shape :math:`(batch, nsrc, *)`. - The batch of target estimates. - targets: torch.Tensor. Expected shape :math:`(batch, nmix, ...)`. - The batch of training targets - return_est: Boolean. Whether to return the estimated mixtures - estimates (To compute metrics or to save example). - **kwargs: additional keyword argument that will be passed to the - loss function. - - Returns: - - Best partition loss for each batch sample, average over - the batch. torch.Tensor(loss_value) - - The estimated mixtures (estimated sources summed according to the partition) - if return_est is True. torch.Tensor of shape :math:`(batch, nmix, ...)`. - """ - # Check input dimensions - assert est_targets.shape[0] == targets.shape[0] - assert est_targets.shape[2] == targets.shape[2] - - if not self.generalized: - min_loss, min_loss_idx, parts = self.best_part_mixit( - self.loss_func, est_targets, targets, **kwargs - ) - else: - min_loss, min_loss_idx, parts = self.best_part_mixit_generalized( - self.loss_func, est_targets, targets, **kwargs - ) - - # Apply any reductions over the batch axis - returned_loss = min_loss.mean() if self.reduction == "mean" else min_loss - if not return_est: - return returned_loss, [parts[i] for i in min_loss_idx] - - # Order and sum on the best partition to get the estimated mixtures - reordered = self.reorder_source(est_targets, targets, min_loss_idx, parts) - return returned_loss, reordered - - @staticmethod - def best_part_mixit(loss_func, est_targets, targets, **kwargs): - r"""Find best partition of the estimated sources that gives the minimum - loss for the MixIT training paradigm in [1]. Valid for any number of - mixtures as soon as they contain the same number of sources. - - Args: - loss_func: function with signature ``(est_targets, targets, **kwargs)`` - The loss function to get batch losses from. - est_targets: torch.Tensor. Expected shape :math:`(batch, nsrc, ...)`. - The batch of target estimates. - targets: torch.Tensor. Expected shape :math:`(batch, nmix, ...)`. - The batch of training targets (mixtures). - **kwargs: additional keyword argument that will be passed to the - loss function. - - Returns: - - :class:`torch.Tensor`: - The loss corresponding to the best permutation of size (batch,). - - - :class:`torch.LongTensor`: - The indices of the best partition. - - - :class:`list`: - list of the possible partitions of the sources. - - """ - nmix = targets.shape[1] - nsrc = est_targets.shape[1] - if nsrc % nmix != 0: - raise ValueError("The mixtures are assumed to contain the same number of sources") - nsrcmix = nsrc // nmix - - # Generate all unique partitions of size k from a list lst of - # length n, where l = n // k is the number of parts. The total - # number of such partitions is: NPK(n,k) = n! / ((k!)^l * l!) - # Algorithm recursively distributes items over parts - def parts_mixit(lst, k, l): - if l == 0: - yield [] - else: - for c in combinations(lst, k): - rest = [x for x in lst if x not in c] - for r in parts_mixit(rest, k, l - 1): - yield [list(c), *r] - - # Generate all the possible partitions - parts = list(parts_mixit(range(nsrc), nsrcmix, nmix)) - # Compute the loss corresponding to each partition - loss_set = CustomMixITLossWrapper.loss_set_from_parts( - loss_func, est_targets=est_targets, targets=targets, parts=parts, **kwargs - ) - # Indexes and values of min losses for each batch element - min_loss, min_loss_indexes = torch.min(loss_set, dim=1, keepdim=True) - return min_loss, min_loss_indexes, parts - - @staticmethod - def best_part_mixit_generalized(loss_func, est_targets, targets, **kwargs): - r"""Find best partition of the estimated sources that gives the minimum - loss for the MixIT training paradigm in [1]. Valid only for two mixtures, - but those mixtures do not necessarly have to contain the same number of - sources e.g the case where one mixture is silent is allowed.. - - Args: - loss_func: function with signature ``(est_targets, targets, **kwargs)`` - The loss function to get batch losses from. - est_targets: torch.Tensor. Expected shape :math:`(batch, nsrc, ...)`. - The batch of target estimates. - targets: torch.Tensor. Expected shape :math:`(batch, nmix, ...)`. - The batch of training targets (mixtures). - **kwargs: additional keyword argument that will be passed to the - loss function. - - Returns: - - :class:`torch.Tensor`: - The loss corresponding to the best permutation of size (batch,). - - - :class:`torch.LongTensor`: - The indexes of the best permutations. - - - :class:`list`: - list of the possible partitions of the sources. - """ - nmix = targets.shape[1] # number of mixtures - nsrc = est_targets.shape[1] # number of estimated sources - if nmix != 2: - raise ValueError("Works only with two mixtures") - - # Generate all unique partitions of any size from a list lst of - # length n. Algorithm recursively distributes items over parts - def parts_mixit_gen(lst): - partitions = [] - for k in range(len(lst) + 1): - for c in combinations(lst, k): - rest = [x for x in lst if x not in c] - partitions.append([list(c), rest]) - return partitions - - # Generate all the possible partitions - parts = parts_mixit_gen(range(nsrc)) - # Compute the loss corresponding to each partition - loss_set = CustomMixITLossWrapper.loss_set_from_parts( - loss_func, est_targets=est_targets, targets=targets, parts=parts, **kwargs - ) - # Indexes and values of min losses for each batch element - min_loss, min_loss_indexes = torch.min(loss_set, dim=1, keepdim=True) - return min_loss, min_loss_indexes, parts - - @staticmethod - def loss_set_from_parts(loss_func, est_targets, targets, parts, **kwargs): - """Common loop between both best_part_mixit""" - loss_set = [] - for partition in parts: - # sum the sources according to the given partition - est_mixes = torch.stack([est_targets[:, idx, :].sum(1) for idx in partition], dim=1) - # get loss for the given partition - loss_partition = loss_func(est_mixes, targets, **kwargs) - if loss_partition.ndim != 1: - raise ValueError("Loss function return value should be of size (batch,).") - loss_set.append(loss_partition[:, None]) - loss_set = torch.cat(loss_set, dim=1) - return loss_set - - @staticmethod - def reorder_source(est_targets, targets, min_loss_idx, parts): - """Reorder sources according to the best partition. - - Args: - est_targets: torch.Tensor. Expected shape :math:`(batch, nsrc, ...)`. - The batch of target estimates. - targets: torch.Tensor. Expected shape :math:`(batch, nmix, ...)`. - The batch of training targets. - min_loss_idx: torch.LongTensor. The indexes of the best permutations. - parts: list of the possible partitions of the sources. - - Returns: - :class:`torch.Tensor`: Reordered sources of shape :math:`(batch, nmix, time)`. - - """ - # For each batch there is a different min_loss_idx - ordered = torch.zeros_like(targets) - for b, idx in enumerate(min_loss_idx): - right_partition = parts[idx] - # Sum the estimated sources to get the estimated mixtures - ordered[b, :, :] = torch.stack( - [est_targets[b, idx, :][None, :, :].sum(1) for idx in right_partition], dim=1 - ) - - return ordered class ValDataset(IterableDataset): def __init__(self, task: Task): @@ -330,7 +86,8 @@ def __iter__(self): def __len__(self): return self.task.val__len__() - + + class JointSpeakerSeparationAndDiarization(SegmentationTaskMixin, Task): """Speaker diarization @@ -409,7 +166,6 @@ def __init__( separation_loss_weight: float = 0.5, original_mixtures_for_separation: bool = False, forced_alignment_weight: float = 0.0, - log_alignment_accuracy: bool = False, add_noise_sources: bool = False, ): super().__init__( @@ -438,7 +194,9 @@ def __init__( # parameter validation if max_speakers_per_frame is not None: - raise NotImplementedError("Powerset multi-class training is not implemented") + raise NotImplementedError( + "Diarization is done on masks separately which is incompatible powerset training" + ) if batch_size % 2 != 0: raise ValueError( @@ -450,12 +208,10 @@ def __init__( self.weigh_by_cardinality = weigh_by_cardinality self.balance = balance self.weight = weight - self.separation_loss = CustomMixITLossWrapper(multisrc_neg_sisdr, generalized=True) self.pit_sep_loss = PITLossWrapper(pairwise_neg_sisdr, pit_from="pw_mtx") self.separation_loss_weight = separation_loss_weight self.original_mixtures_for_separation = original_mixtures_for_separation self.forced_alignment_weight = forced_alignment_weight - self.log_alignment_accuracy = log_alignment_accuracy self.add_noise_sources = add_noise_sources def setup(self): @@ -664,7 +420,7 @@ def prepare_chunk(self, file_id: int, start_time: float, duration: float): sample["meta"]["file"] = file_id return sample - + def val_dataloader(self) -> DataLoader: return DataLoader( ValDataset(self), @@ -674,7 +430,7 @@ def val_dataloader(self) -> DataLoader: drop_last=True, collate_fn=partial(self.collate_fn, stage="train"), ) - + def val__iter__(self): """Iterate over training samples @@ -714,7 +470,7 @@ def val__iter__(self): # generate random chunk yield next(chunks) - + def train__len__(self): # Number of training samples in one epoch @@ -1170,8 +926,8 @@ def segmentation_loss( def create_mixtures_of_mixtures(self, mix1, mix2, target1, target2): """ - Creates mixtures of mixtures and corresponding diarization targets. - Keeps track of how many speakers came from each mixture in order to + Creates mixtures of mixtures and corresponding diarization targets. + Keeps track of how many speakers came from each mixture in order to reconstruct the original mixtures. Parameters @@ -1223,23 +979,7 @@ def create_mixtures_of_mixtures(self, mix1, mix2, target1, target2): return mom, targets, num_active_speakers_mix1, num_active_speakers_mix2 - def training_step(self, batch, batch_idx: int): - """Compute permutation-invariant segmentation loss - - Parameters - ---------- - batch : (usually) dict of torch.Tensor - Current batch. - batch_idx: int - Batch index. - - Returns - ------- - loss : {str: torch.tensor} - {"loss": loss} - """ - - # target + def common_step(self, batch): target = batch["y"] # (batch_size, num_frames, num_speakers) @@ -1251,6 +991,14 @@ def training_step(self, batch, batch_idx: int): # forward pass bsz = waveform.shape[0] + + # MoMs can't be created for batch size < 2 + if bsz < 2: + return None + # if bsz not even, then leave out last sample + if bsz % 2 != 0: + waveform = waveform[:-1] + num_samples = waveform.shape[2] mix1 = waveform[0::2].squeeze(1) mix2 = waveform[1::2].squeeze(1) @@ -1287,15 +1035,20 @@ def training_step(self, batch, batch_idx: int): if self.add_noise_sources: # last 2 sources should only contain noise so we force diarization outputs to 0 - permutated_diarization, permutations = permutate(target, diarization[:, :, :3]) - target = torch.cat((target, torch.zeros(batch_size, num_frames, 2, device=target.device)), dim=2) - permutated_diarization = torch.cat((permutated_diarization, diarization[:, :, 3:]), dim=2) + permutated_diarization, permutations = permutate( + target, diarization[:, :, :3] + ) + target = torch.cat( + (target, torch.zeros(batch_size, num_frames, 2, device=target.device)), + dim=2, + ) + permutated_diarization = torch.cat( + (permutated_diarization, diarization[:, :, 3:]), dim=2 + ) else: permutated_diarization, permutations = permutate(target, diarization) - seg_loss = self.segmentation_loss( - permutated_diarization, target, weight=weight - ) + seg_loss = self.segmentation_loss(permutated_diarization, target, weight=weight) speaker_idx_mix1 = [ [permutations[i][j] for j in range(num_active_speakers_mix1[i])] @@ -1304,23 +1057,36 @@ def training_step(self, batch, batch_idx: int): speaker_idx_mix2 = [ [ permutations[i][j] - for j in range(num_active_speakers_mix1[i], num_active_speakers_mix1[i] + num_active_speakers_mix2[i]) + for j in range( + num_active_speakers_mix1[i], + num_active_speakers_mix1[i] + num_active_speakers_mix2[i], + ) ] for i in range(bsz // 2) ] - + est_mixes = [] for i in range(bsz // 2): if self.add_noise_sources: - est_mix1 = mom_sources[i, :, speaker_idx_mix1[i]].sum(1) + mom_sources[i,:,3] - est_mix2 = mom_sources[i, :, speaker_idx_mix2[i]].sum(1) + mom_sources[i,:,4] - est_mix3 = mom_sources[i, :, speaker_idx_mix1[i]].sum(1) + mom_sources[i,:,4] - est_mix4 = mom_sources[i, :, speaker_idx_mix2[i]].sum(1) + mom_sources[i,:,3] + est_mix1 = ( + mom_sources[i, :, speaker_idx_mix1[i]].sum(1) + mom_sources[i, :, 3] + ) + est_mix2 = ( + mom_sources[i, :, speaker_idx_mix2[i]].sum(1) + mom_sources[i, :, 4] + ) + est_mix3 = ( + mom_sources[i, :, speaker_idx_mix1[i]].sum(1) + mom_sources[i, :, 4] + ) + est_mix4 = ( + mom_sources[i, :, speaker_idx_mix2[i]].sum(1) + mom_sources[i, :, 3] + ) sep_loss_first_part = self.pit_sep_loss( - torch.stack((est_mix1, est_mix2)).unsqueeze(0), torch.stack((mix1[i], mix2[i])).unsqueeze(0) + torch.stack((est_mix1, est_mix2)).unsqueeze(0), + torch.stack((mix1[i], mix2[i])).unsqueeze(0), ) - sep_loss_second_part = self.pit_sep_loss( - torch.stack((est_mix3, est_mix4)).unsqueeze(0), torch.stack((mix1[i], mix2[i])).unsqueeze(0) + sep_loss_second_part = self.pit_sep_loss( + torch.stack((est_mix3, est_mix4)).unsqueeze(0), + torch.stack((mix1[i], mix2[i])).unsqueeze(0), ) if sep_loss_first_part < sep_loss_second_part: est_mixes.append(torch.stack((est_mix1, est_mix2))) @@ -1334,8 +1100,7 @@ def training_step(self, batch, batch_idx: int): separation_loss = self.pit_sep_loss( est_mixes, torch.stack((mix1, mix2)).transpose(0, 1) ).mean() - _, mixit_partitions = self.separation_loss(mom_sources[:,:,:3].transpose(1, 2), torch.stack((mix1, mix2)).transpose(0, 1)) - + if self.original_mixtures_for_separation: raise NotImplementedError # separation_loss += self.separation_loss( @@ -1351,6 +1116,39 @@ def training_step(self, batch, batch_idx: int): # ) # forced_alignment_loss = forced_alignment_loss.mean() / 3 forced_alignment_loss = 0 + return ( + seg_loss, + separation_loss, + forced_alignment_loss, + diarization, + permutated_diarization, + target, + ) + + def training_step(self, batch, batch_idx: int): + """Compute permutation-invariant segmentation loss + + Parameters + ---------- + batch : (usually) dict of torch.Tensor + Current batch. + batch_idx: int + Batch index. + + Returns + ------- + loss : {str: torch.tensor} + {"loss": loss} + """ + + ( + seg_loss, + separation_loss, + forced_alignment_loss, + diarization, + permutated_diarization, + target, + ) = self.common_step(batch) self.model.log( "loss/train/separation", separation_loss, @@ -1387,50 +1185,6 @@ def training_step(self, batch, batch_idx: int): prog_bar=False, logger=True, ) - if self.log_alignment_accuracy: - for i in range(bsz // 2): - inverse_mixit_partition = permutations_inverse[i][mixit_partitions[i][0]], permutations_inverse[i][mixit_partitions[i][1]] - if set([int(j) for j in speaker_idx_mix1[i]]) <= set(inverse_mixit_partition[0].tolist()) and set([int(j) for j in speaker_idx_mix2[i]]) <= set(inverse_mixit_partition[1].tolist()): - self.num_correct += 1 - if (num_active_speakers_mix1[i], num_active_speakers_mix2[i]) in [(1, 0), (0, 1)]: - self.num_correct10 += 1 - if (num_active_speakers_mix1[i], num_active_speakers_mix2[i]) in [(2, 0), (0, 2)]: - self.num_correct20 += 1 - if (num_active_speakers_mix1[i], num_active_speakers_mix2[i]) in [(3, 0), (0, 3)]: - self.num_correct30 += 1 - if num_active_speakers_mix1[i] == 1 and num_active_speakers_mix2[i] == 1: - self.num_correct11 += 1 - if (num_active_speakers_mix1[i], num_active_speakers_mix2[i]) in [(1, 2), (2, 1)]: - self.num_correct21 += 1 - if (num_active_speakers_mix1[i], num_active_speakers_mix2[i]) in [(1, 0), (0, 1)]: - self.num_total10 += 1 - if (num_active_speakers_mix1[i], num_active_speakers_mix2[i]) in [(2, 0), (0, 2)]: - self.num_total20 += 1 - if (num_active_speakers_mix1[i], num_active_speakers_mix2[i]) in [(3, 0), (0, 3)]: - self.num_total30 += 1 - if num_active_speakers_mix1[i] == 1 and num_active_speakers_mix2[i] == 1: - self.num_total11 += 1 - if (num_active_speakers_mix1[i], num_active_speakers_mix2[i]) in [(1, 2), (2, 1)]: - self.num_total21 += 1 - self.num_total+=1 - if self.num_total30 > 0: - self.model.log("accuracy/3_0", self.num_correct30/self.num_total30, on_step=False, on_epoch=True, prog_bar=False, logger=True) - if self.num_total20 > 0: - self.model.log("accuracy/2_0", self.num_correct20/self.num_total20, on_step=False, on_epoch=True, prog_bar=False, logger=True) - if self.num_total10 > 0: - self.model.log("accuracy/1_0", self.num_correct10/self.num_total10, on_step=False, on_epoch=True, prog_bar=False, logger=True) - if self.num_total11 > 0: - self.model.log("accuracy/1_1", self.num_correct11/self.num_total11, on_step=False, on_epoch=True, prog_bar=False, logger=True) - if self.num_total21 > 0: - self.model.log("accuracy/2_1", self.num_correct21/self.num_total21, on_step=False, on_epoch=True, prog_bar=False, logger=True) - if self.num_total > 0: - self.model.log("accuracy/total", self.num_correct/self.num_total, on_step=False, on_epoch=True, prog_bar=False, logger=True) - self.model.log("counts/3_0", self.num_total30, on_step=False, on_epoch=True, prog_bar=False, logger=True) - self.model.log("counts/2_0", self.num_total20, on_step=False, on_epoch=True, prog_bar=False, logger=True) - self.model.log("counts/1_0", self.num_total10, on_step=False, on_epoch=True, prog_bar=False, logger=True) - self.model.log("counts/1_1", self.num_total11, on_step=False, on_epoch=True, prog_bar=False, logger=True) - self.model.log("counts/2_1", self.num_total21, on_step=False, on_epoch=True, prog_bar=False, logger=True) - self.model.log("counts/total", self.num_total, on_step=False, on_epoch=True, prog_bar=False, logger=True) return {"loss": loss} @@ -1459,100 +1213,14 @@ def validation_step(self, batch, batch_idx: int): Batch index. """ - # target - target = batch["y"] - # (batch_size, num_frames, num_speakers) - - waveform = batch["X"] - # (batch_size, num_channels, num_samples) - - # TODO: should we handle validation samples with too many speakers - # waveform = waveform[keep] - # target = target[keep] - - bsz = waveform.shape[0] - num_samples = waveform.shape[2] - # MoMs can't be created for batch size < 2 - if bsz < 2: - return None - # if bsz not even, then leave out last sample - if bsz % 2 != 0: - waveform = waveform[:-1] - - # if bsz not even, then leave out last sample - mix1 = waveform[0::2].squeeze(1) - mix2 = waveform[1::2].squeeze(1) - if self.original_mixtures_for_separation: - # extract parts with only one speaker from original mixtures - mix1_masks = batch["X_separation_mask"][0::2] - mix2_masks = batch["X_separation_mask"][1::2] - mix1_masked = mix1 * mix1_masks - mix2_masked = mix2 * mix2_masks - ( - mom, - mom_target, - num_active_speakers_mix1, - num_active_speakers_mix2, - ) = self.create_mixtures_of_mixtures(mix1, mix2, target[0::2], target[1::2]) - - # forward pass - diarization, _ = self.model(waveform) - _, mom_sources = self.model(mom) - batch_size, num_frames, _ = diarization.shape - - # frames weight - weight_key = getattr(self, "weight", None) - weight = batch.get( - weight_key, - torch.ones(batch_size, num_frames, 1, device=self.model.device), - ) - # (batch_size, num_frames, 1) - - # last 2 sources should only contain noise so we force diarization outputs to 0 - permutated_diarization, permutations = permutate(target, diarization[:, :, :3]) - target = torch.cat((target, torch.zeros(batch_size, num_frames, 2, device=target.device)), dim=2) - permutated_diarization = torch.cat((permutated_diarization, diarization[:, :, 3:]), dim=2) - seg_loss = self.segmentation_loss( - permutated_diarization, target, weight=weight - ) - - speaker_idx_mix1 = [ - [permutations[i][j] for j in range(num_active_speakers_mix1[i])] - for i in range(bsz // 2) - ] - speaker_idx_mix2 = [ - [ - permutations[i][j] - for j in range(num_active_speakers_mix1[i], num_active_speakers_mix1[i] + num_active_speakers_mix2[i]) - ] - for i in range(bsz // 2) - ] - - est_mixes = [] - for i in range(bsz // 2): - est_mix1 = mom_sources[i, :, speaker_idx_mix1[i]].sum(1) + mom_sources[i,:,3] - est_mix2 = mom_sources[i, :, speaker_idx_mix2[i]].sum(1) + mom_sources[i,:,4] - est_mix3 = mom_sources[i, :, speaker_idx_mix1[i]].sum(1) + mom_sources[i,:,4] - est_mix4 = mom_sources[i, :, speaker_idx_mix2[i]].sum(1) + mom_sources[i,:,3] - sep_loss_first_part = self.pit_sep_loss( - torch.stack((est_mix1, est_mix2)).unsqueeze(0), torch.stack((mix1[i], mix2[i])).unsqueeze(0) - ) - sep_loss_second_part = self.pit_sep_loss( - torch.stack((est_mix3, est_mix4)).unsqueeze(0), torch.stack((mix1[i], mix2[i])).unsqueeze(0) - ) - if sep_loss_first_part < sep_loss_second_part: - est_mixes.append(torch.stack((est_mix1, est_mix2))) - else: - est_mixes.append(torch.stack((est_mix3, est_mix4))) - est_mixes = torch.stack(est_mixes) - separation_loss = self.pit_sep_loss( - est_mixes, torch.stack((mix1, mix2)).transpose(0, 1) - ).mean() - _, mixit_partitions = self.separation_loss(mom_sources[:,:,:3].transpose(1, 2), torch.stack((mix1, mix2)).transpose(0, 1)) - - if self.original_mixtures_for_separation: - raise NotImplementedError + seg_loss, + separation_loss, + forced_alignment_loss, + diarization, + permutated_diarization, + target, + ) = self.common_step(batch) self.model.log( "loss/val/separation", @@ -1573,8 +1241,10 @@ def validation_step(self, batch, batch_idx: int): ) loss = ( - 1 - self.separation_loss_weight - ) * seg_loss + self.separation_loss_weight * separation_loss + (1 - self.separation_loss_weight) * seg_loss + + self.separation_loss_weight * separation_loss + + forced_alignment_loss * self.forced_alignment_weight + ) self.model.log( "loss/val", @@ -1590,7 +1260,6 @@ def validation_step(self, batch, batch_idx: int): torch.transpose(target, 1, 2), ) - self.model.log_dict( self.model.validation_metric, on_step=False,