diff --git a/pyannote/audio/core/model.py b/pyannote/audio/core/model.py index 8af802293..e6a8842fd 100644 --- a/pyannote/audio/core/model.py +++ b/pyannote/audio/core/model.py @@ -194,6 +194,9 @@ def receptive_field(self) -> SlidingWindow: def prepare_data(self): self.task.prepare_data() + def prepare_data(self): + self.task.prepare_data() + def setup(self, stage=None): if stage == "fit": # let the task know about the trainer (e.g for broadcasting diff --git a/pyannote/audio/models/blocks/pooling.py b/pyannote/audio/models/blocks/pooling.py index dc31bea8e..17c2f9030 100644 --- a/pyannote/audio/models/blocks/pooling.py +++ b/pyannote/audio/models/blocks/pooling.py @@ -28,7 +28,9 @@ import torch.nn.functional as F -def _pool(sequences: torch.Tensor, weights: torch.Tensor) -> torch.Tensor: +def _pool( + sequences: torch.Tensor, weights: torch.Tensor, compute_mean: bool, compute_std:bool + ) -> torch.Tensor: """Helper function to compute statistics pooling Assumes that weights are already interpolated to match the number of frames @@ -50,16 +52,24 @@ def _pool(sequences: torch.Tensor, weights: torch.Tensor) -> torch.Tensor: weights = weights.unsqueeze(dim=1) # (batch, 1, frames) + stats = [] + v1 = weights.sum(dim=2) + 1e-8 mean = torch.sum(sequences * weights, dim=2) / v1 - dx2 = torch.square(sequences - mean.unsqueeze(2)) - v2 = torch.square(weights).sum(dim=2) + if compute_mean: + stats.append(mean) + + if compute_std: + dx2 = torch.square(sequences - mean.unsqueeze(2)) + v2 = torch.square(weights).sum(dim=2) - var = torch.sum(dx2 * weights, dim=2) / (v1 - v2 / v1 + 1e-8) - std = torch.sqrt(var) + var = torch.sum(dx2 * weights, dim=2) / (v1 - v2 / v1 + 1e-8) + std = torch.sqrt(var) - return torch.cat([mean, std], dim=1) + stats.append(std) + + return torch.cat(stats, dim=1) class StatsPool(nn.Module): @@ -68,14 +78,33 @@ class StatsPool(nn.Module): Compute temporal mean and (unbiased) standard deviation and returns their concatenation. + Parameters + ---------- + + compute_mean: bool, optional + whether to compute (and return) temporal mean. + Default to True + compute_std: bool, optional + whether to compute (and return) temporal standard deviation. 
+ Default to True + Reference --------- https://en.wikipedia.org/wiki/Weighted_arithmetic_mean """ + def __init__( + self, + compute_mean: Optional[bool] = True, + computde_std: Optional[bool] = True, + ): + super().__init__() + self.compute_mean = compute_mean + self.compute_std = computde_std + def forward( - self, sequences: torch.Tensor, weights: Optional[torch.Tensor] = None + self, sequences: torch.Tensor, weights: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass @@ -122,7 +151,7 @@ def forward( output = torch.stack( [ - _pool(sequences, weights[:, speaker, :]) + _pool(sequences, weights[:, speaker, :], self.compute_mean, self.compute_std) for speaker in range(num_speakers) ], dim=1, diff --git a/pyannote/audio/models/joint/__init__.py b/pyannote/audio/models/joint/__init__.py new file mode 100644 index 000000000..97c1481d8 --- /dev/null +++ b/pyannote/audio/models/joint/__init__.py @@ -0,0 +1,27 @@ +# MIT License +# +# Copyright (c) 2020 CNRS +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +from .end_to_end_diarization import ( + WavLMEnd2EndDiarization, WavLMEnd2EndDiarizationv2, WavLMEnd2EndDiarizationv3 +) + +__all__ = ["WavLMEnd2EndDiarization", "WavLMEnd2EndDiarizationv2", "WavLMEnd2EndDiarizationv3"] diff --git a/pyannote/audio/models/joint/end_to_end_diarization.py b/pyannote/audio/models/joint/end_to_end_diarization.py new file mode 100644 index 000000000..1be9abd05 --- /dev/null +++ b/pyannote/audio/models/joint/end_to_end_diarization.py @@ -0,0 +1,974 @@ +# MIT License +# +# Copyright (c) 2023 CNRS +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +from functools import lru_cache +from typing import Literal, Optional, Union + +import torch +import torch.nn.functional as F +from einops import rearrange +from pyannote.core.utils.generators import pairwise +from torch import nn + +from pyannote.audio.core.model import Model +from pyannote.audio.core.task import Task +from pyannote.audio.models.blocks.pooling import StatsPool +from pyannote.audio.utils.params import merge_dict +from pyannote.audio.utils.powerset import Powerset + +from pyannote.audio.utils.receptive_field import ( + conv1d_num_frames, + conv1d_receptive_field_center, + conv1d_receptive_field_size, +) + +import torchaudio + +# TODO deplace these two lines into uitls/multi_task +Subtask = Literal["diarization", "embedding"] +Subtasks = list(Subtask.__args__) + + +class WavLMEnd2EndDiarization(Model): + """Self-Supervised representation for joint speaker diarization + and speaker embeddings extraction + + + Parameters + ---------- + sample_rate : int, optional + Audio sample rate. Defaults to 16kHz (16000). + num_channels : int, optional + Number of channels. Defaults to mono (1). + wav2vec: dict or str, optional + Defaults to "WAVLM_BASE". + wav2vec_layer: int, optional + Index of layer to use as input to the LSTM. + Defaults (-1) to use average of all layers (with learnable weights). + freeze_wav2vec: bool, optional + Whether to freeze wa2vec. Default to true + emb_dim: int, optional + Dimension of the speaker embedding in output + """ + + WAV2VEC_DEFAULTS = "WAVLM_BASE" + + LSTM_DEFAULTS = { + "hidden_size": 128, + "num_layers": 4, + "bidirectional": True, + "monolithic": True, + "dropout": 0.0, + } + + LINEAR_DEFAULT = {"hidden_size": 128, "num_layers": 2} + + def __init__( + self, + sample_rate: int = 16000, + num_channels: int = 1, + wav2vec: Union[dict, str] = None, + wav2vec_layer: int = -1, + freeze_wav2vec: bool = True, + lstm: Optional[dict] = None, + linear: Optional[dict] = None, + embedding_dim: Optional[int] = 192, + task: Optional[Task] = None, + ): + super().__init__(sample_rate=sample_rate, num_channels=num_channels, task=task) + + if isinstance(wav2vec, str): + # `wav2vec` is one of the supported pipelines from torchaudio (e.g. "WAVLM_BASE") + if hasattr(torchaudio.pipelines, wav2vec): + bundle = getattr(torchaudio.pipelines, wav2vec) + if sample_rate != bundle.sample_rate: + raise ValueError( + f"Expected {bundle.sample_rate}Hz, found {sample_rate}Hz." + ) + wav2vec_dim = bundle._params["encoder_embed_dim"] + wav2vec_num_layers = bundle._params["encoder_num_layers"] + self.wav2vec = bundle.get_model() + + # `wav2vec` is a path to a self-supervised representation checkpoint + else: + _checkpoint = torch.load(wav2vec) + wav2vec = _checkpoint.pop("config") + self.wav2vec = torchaudio.models.wav2vec2_model(**wav2vec) + state_dict = _checkpoint.pop("state_dict") + self.wav2vec.load_state_dict(state_dict) + wav2vec_dim = wav2vec["encoder_embed_dim"] + wav2vec_num_layers = wav2vec["encoder_num_layers"] + + # `wav2vec` is a config dictionary understood by `wav2vec2_model` + # this branch is typically used by Model.from_pretrained(...) 
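        # Illustrative summary of the three accepted forms of `wav2vec`
        # (the checkpoint path and dict values below are hypothetical, not from this patch):
        #
        #   WavLMEnd2EndDiarization(wav2vec="WAVLM_BASE")           # torchaudio.pipelines bundle name
        #   WavLMEnd2EndDiarization(wav2vec="/path/to/wavlm.ckpt")  # torch.load()-able dict with "config" and "state_dict"
        #   WavLMEnd2EndDiarization(wav2vec={"encoder_embed_dim": 768, "encoder_num_layers": 12, ...})  # kwargs for torchaudio.models.wav2vec2_model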
+ elif isinstance(wav2vec, dict): + self.wav2vec = torchaudio.models.wav2vec2_model(**wav2vec) + wav2vec_dim = wav2vec["encoder_embed_dim"] + wav2vec_num_layers = wav2vec["encoder_num_layers"] + + if wav2vec_layer < 0: + # weighting parameters for the diarization branch + self.dia_wav2vec_weights = nn.Parameter( + data=torch.ones(wav2vec_num_layers), requires_grad=True + ) + # weighting parameters for the embedding branch + self.emb_wav2vec_weights = nn.Parameter( + data=torch.ones(wav2vec_num_layers), requires_grad=True + ) + self.save_hyperparameters("wav2vec", "wav2vec_layer", "freeze_wav2vec") + + lstm = merge_dict(self.LSTM_DEFAULTS, lstm) + lstm["batch_first"] = True + linear = merge_dict(self.LINEAR_DEFAULT, linear) + self.save_hyperparameters("lstm", "linear") + monolithic = lstm["monolithic"] + if monolithic: + multi_layer_lstm = dict(lstm) + del multi_layer_lstm["monolithic"] + self.lstm = nn.LSTM(wav2vec_dim, **multi_layer_lstm) + else: + num_layers = lstm["num_layers"] + if num_layers > 1: + self.dropout = nn.Dropout(p=lstm["dropout"]) + + one_layer_lstm = dict(lstm) + one_layer_lstm["num_layers"] = 1 + one_layer_lstm["dropout"] = 0.0 + del one_layer_lstm["monolithic"] + + self.lstm = nn.ModuleList( + [ + nn.LSTM( + ( + wav2vec_dim + if i == 0 + else lstm["hidden_size"] + * (2 if lstm["bidirectional"] else 1) + ), + **one_layer_lstm, + ) + for i in range(num_layers) + ] + ) + + if linear["num_layers"] < 1: + return + lstm_out_features = self.hparams.lstm["hidden_size"] * ( + 2 if self.hparams.lstm["bidirectional"] else 1 + ) + self.linear = nn.ModuleList( + [ + nn.Linear(in_features, out_features) + for in_features, out_features in pairwise( + [ + lstm_out_features, + ] + + [self.hparams.linear["hidden_size"]] + * self.hparams.linear["num_layers"] + ) + ] + ) + + self.pooling = StatsPool(computde_std=False) + self.embeddings = nn.Linear(wav2vec_dim, embedding_dim) + + self.save_hyperparameters("embedding_dim") + + @property + def dimension(self) -> int: + """Dimension of output""" + return self.specifications[Subtasks.index("diarization")].num_powerset_classes + + @lru_cache + def num_frames(self, num_samples: int) -> int: + """Compute number of output frames + + Parameters + ---------- + num_samples : int + Number of input samples. + + Returns + ------- + num_frames : int + Number of output frames. + """ + + num_frames = num_samples + for conv_layer in self.wav2vec.feature_extractor.conv_layers: + num_frames = conv1d_num_frames( + num_frames, + kernel_size=conv_layer.kernel_size, + stride=conv_layer.stride, + padding=conv_layer.conv.padding[0], + dilation=conv_layer.conv.dilation[0], + ) + + return num_frames + + def receptive_field_size(self, num_frames: int = 1) -> int: + """Compute size of receptive field + + Parameters + ---------- + num_frames : int, optional + Number of frames in the output signal + + Returns + ------- + receptive_field_size : int + Receptive field size. + """ + + receptive_field_size = num_frames + for conv_layer in reversed(self.wav2vec.feature_extractor.conv_layers): + receptive_field_size = conv1d_receptive_field_size( + num_frames=receptive_field_size, + kernel_size=conv_layer.kernel_size, + stride=conv_layer.stride, + dilation=conv_layer.conv.dilation[0], + ) + return receptive_field_size + + def receptive_field_center(self, frame: int = 0) -> int: + """Compute center of receptive field + + Parameters + ---------- + frame : int, optional + Frame index + + Returns + ------- + receptive_field_center : int + Index of receptive field center. 
+ """ + receptive_field_center = frame + for conv_layer in reversed(self.wav2vec.feature_extractor.conv_layers): + receptive_field_center = conv1d_receptive_field_center( + receptive_field_center, + kernel_size=conv_layer.kernel_size, + stride=conv_layer.stride, + padding=conv_layer.conv.padding[0], + dilation=conv_layer.conv.dilation[0], + ) + return receptive_field_center + + def build(self): + """""" + max_num_speaker_per_chunk = len(self.specifications[Subtasks.index("diarization")].classes) + max_num_speaker_per_frame = self.specifications[Subtasks.index("diarization")].powerset_max_classes + self.powerset = Powerset( + max_num_speaker_per_chunk, + max_num_speaker_per_frame + ) + + if self.hparams.linear["num_layers"] > 0: + in_features = self.hparams.linear["hidden_size"] + else: + lstm = self.hparams.lstm + in_features = lstm["hidden_size"] * (2 if lstm["bidirectional"] else 1) + + self.classifier = nn.Linear(in_features, self.dimension) + + def forward(self, waveforms: torch.Tensor) -> torch.Tensor: + """Pass forward + + Parameters + ---------- + waveforms : (batch, channel, sample) + + Returns + ------- + diarization, embeddings : (batch, frames, classes), (batch, num_speaker, embed_dim) + """ + + num_layers = ( + None if self.hparams.wav2vec_layer < 0 else self.hparams.wav2vec_layer + ) + + if self.hparams.freeze_wav2vec: + with torch.no_grad(): + outputs, _ = self.wav2vec.extract_features( + waveforms.squeeze(1), num_layers=num_layers + ) + else: + outputs, _ = self.wav2vec.extract_features( + waveforms.squeeze(1), num_layers=num_layers + ) + + if num_layers is None: + dia_outputs = torch.stack(outputs, dim=-1) @ F.softmax( + self.dia_wav2vec_weights, dim=0 + ) + emb_outputs = torch.stack(outputs, dim=-1) @ F.softmax( + self.emb_wav2vec_weights, dim=0 + ) + else: + dia_outputs = emb_outputs = outputs[-1] + + if self.hparams.lstm["monolithic"]: + dia_outputs, _ = self.lstm(dia_outputs) + else: + for i, lstm in enumerate(self.lstm): + dia_outputs, _ = lstm(dia_outputs) + if i + 1 < self.hparams.lstm["num_layers"]: + dia_outputs = self.dropout(dia_outputs) + + if self.hparams.linear["num_layers"] > 0: + for linear in self.linear: + dia_outputs = F.leaky_relu(linear(dia_outputs)) + dia_outputs = self.classifier(dia_outputs) + dia_outputs = F.log_softmax(dia_outputs, dim=-1) + + weights = self.powerset.to_multilabel(dia_outputs, soft=True) + weights = rearrange(weights, "b f s -> b s f") + emb_outputs = rearrange(emb_outputs, "b f w -> b w f") + emb_outputs = self.pooling(emb_outputs, weights) + emb_outputs = self.embeddings(emb_outputs) + + return (dia_outputs, emb_outputs) + + +class WavLMEnd2EndDiarizationv2(Model): + """Self-Supervised representation for joint speaker diarization + and speaker embeddings extraction + + + Parameters + ---------- + sample_rate : int, optional + Audio sample rate. Defaults to 16kHz (16000). + num_channels : int, optional + Number of channels. Defaults to mono (1). + wav2vec: dict or str, optional + Defaults to "WAVLM_BASE". + wav2vec_layer: int, optional + Index of layer to use as input to the LSTM. + Defaults (-1) to use average of all layers (with learnable weights). + freeze_wav2vec: bool, optional + Whether to freeze wa2vec. 
Default to true + emb_dim: int, optional + Dimension of the speaker embedding in output + """ + + WAV2VEC_DEFAULTS = "WAVLM_BASE" + + LSTM_DEFAULTS = { + "hidden_size": 128, + "num_layers": 4, + "bidirectional": True, + "monolithic": True, + "dropout": 0.0, + } + + LINEAR_DEFAULT = {"hidden_size": 128, "num_layers": 2} + + def __init__( + self, + sample_rate: int = 16000, + num_channels: int = 1, + wav2vec: Union[dict, str] = None, + wav2vec_layer: int = -1, + freeze_wav2vec: bool = True, + lstm: Optional[dict] = None, + linear: Optional[dict] = None, + embedding_dim: Optional[int] = 192, + task: Optional[Task] = None, + ): + super().__init__(sample_rate=sample_rate, num_channels=num_channels, task=task) + + if isinstance(wav2vec, str): + # `wav2vec` is one of the supported pipelines from torchaudio (e.g. "WAVLM_BASE") + if hasattr(torchaudio.pipelines, wav2vec): + bundle = getattr(torchaudio.pipelines, wav2vec) + if sample_rate != bundle.sample_rate: + raise ValueError( + f"Expected {bundle.sample_rate}Hz, found {sample_rate}Hz." + ) + wav2vec_dim = bundle._params["encoder_embed_dim"] + wav2vec_num_layers = bundle._params["encoder_num_layers"] + self.wav2vec = bundle.get_model() + + # `wav2vec` is a path to a self-supervised representation checkpoint + else: + _checkpoint = torch.load(wav2vec) + wav2vec = _checkpoint.pop("config") + self.wav2vec = torchaudio.models.wav2vec2_model(**wav2vec) + state_dict = _checkpoint.pop("state_dict") + self.wav2vec.load_state_dict(state_dict) + wav2vec_dim = wav2vec["encoder_embed_dim"] + wav2vec_num_layers = wav2vec["encoder_num_layers"] + + # `wav2vec` is a config dictionary understood by `wav2vec2_model` + # this branch is typically used by Model.from_pretrained(...) + elif isinstance(wav2vec, dict): + self.wav2vec = torchaudio.models.wav2vec2_model(**wav2vec) + wav2vec_dim = wav2vec["encoder_embed_dim"] + wav2vec_num_layers = wav2vec["encoder_num_layers"] + + if wav2vec_layer < 0: + # weighting parameters for the diarization branch + self.dia_wav2vec_weights = nn.Parameter( + data=torch.ones(wav2vec_num_layers), requires_grad=True + ) + # weighting parameters for the embedding branch + self.emb_wav2vec_weights = nn.Parameter( + data=torch.ones(wav2vec_num_layers), requires_grad=True + ) + self.save_hyperparameters("wav2vec", "wav2vec_layer", "freeze_wav2vec") + + lstm = merge_dict(self.LSTM_DEFAULTS, lstm) + lstm["batch_first"] = True + linear = merge_dict(self.LINEAR_DEFAULT, linear) + self.save_hyperparameters("lstm", "linear") + monolithic = lstm["monolithic"] + if monolithic: + multi_layer_lstm = dict(lstm) + del multi_layer_lstm["monolithic"] + self.lstm = nn.LSTM(wav2vec_dim, **multi_layer_lstm) + else: + num_layers = lstm["num_layers"] + if num_layers > 1: + self.dropout = nn.Dropout(p=lstm["dropout"]) + + one_layer_lstm = dict(lstm) + one_layer_lstm["num_layers"] = 1 + one_layer_lstm["dropout"] = 0.0 + del one_layer_lstm["monolithic"] + + self.lstm = nn.ModuleList( + [ + nn.LSTM( + ( + wav2vec_dim + if i == 0 + else lstm["hidden_size"] + * (2 if lstm["bidirectional"] else 1) + ), + **one_layer_lstm, + ) + for i in range(num_layers) + ] + ) + + if linear["num_layers"] < 1: + return + lstm_out_features = self.hparams.lstm["hidden_size"] * ( + 2 if self.hparams.lstm["bidirectional"] else 1 + ) + self.linear = nn.ModuleList( + [ + nn.Linear(in_features, out_features) + for in_features, out_features in pairwise( + [ + lstm_out_features, + ] + + [self.hparams.linear["hidden_size"]] + * self.hparams.linear["num_layers"] + ) + ] + ) + + self.pooling 
= StatsPool(computde_std=False) + self.embeddings = nn.Sequential( + nn.Linear(in_features=wav2vec_dim, out_features=1024), + nn.LeakyReLU(), + nn.Linear(in_features=1024, out_features=embedding_dim), + ) + + self.save_hyperparameters("embedding_dim") + + @property + def dimension(self) -> int: + """Dimension of output""" + return self.specifications[Subtasks.index("diarization")].num_powerset_classes + + @lru_cache + def num_frames(self, num_samples: int) -> int: + """Compute number of output frames + + Parameters + ---------- + num_samples : int + Number of input samples. + + Returns + ------- + num_frames : int + Number of output frames. + """ + + num_frames = num_samples + for conv_layer in self.wav2vec.feature_extractor.conv_layers: + num_frames = conv1d_num_frames( + num_frames, + kernel_size=conv_layer.kernel_size, + stride=conv_layer.stride, + padding=conv_layer.conv.padding[0], + dilation=conv_layer.conv.dilation[0], + ) + + return num_frames + + def receptive_field_size(self, num_frames: int = 1) -> int: + """Compute size of receptive field + + Parameters + ---------- + num_frames : int, optional + Number of frames in the output signal + + Returns + ------- + receptive_field_size : int + Receptive field size. + """ + + receptive_field_size = num_frames + for conv_layer in reversed(self.wav2vec.feature_extractor.conv_layers): + receptive_field_size = conv1d_receptive_field_size( + num_frames=receptive_field_size, + kernel_size=conv_layer.kernel_size, + stride=conv_layer.stride, + dilation=conv_layer.conv.dilation[0], + ) + return receptive_field_size + + def receptive_field_center(self, frame: int = 0) -> int: + """Compute center of receptive field + + Parameters + ---------- + frame : int, optional + Frame index + + Returns + ------- + receptive_field_center : int + Index of receptive field center. 
+ """ + receptive_field_center = frame + for conv_layer in reversed(self.wav2vec.feature_extractor.conv_layers): + receptive_field_center = conv1d_receptive_field_center( + receptive_field_center, + kernel_size=conv_layer.kernel_size, + stride=conv_layer.stride, + padding=conv_layer.conv.padding[0], + dilation=conv_layer.conv.dilation[0], + ) + return receptive_field_center + + def build(self): + """""" + max_num_speaker_per_chunk = len(self.specifications[Subtasks.index("diarization")].classes) + max_num_speaker_per_frame = self.specifications[Subtasks.index("diarization")].powerset_max_classes + self.powerset = Powerset( + max_num_speaker_per_chunk, + max_num_speaker_per_frame + ) + + if self.hparams.linear["num_layers"] > 0: + in_features = self.hparams.linear["hidden_size"] + else: + lstm = self.hparams.lstm + in_features = lstm["hidden_size"] * (2 if lstm["bidirectional"] else 1) + + self.classifier = nn.Linear(in_features, self.dimension) + + def forward(self, waveforms: torch.Tensor) -> torch.Tensor: + """Pass forward + + Parameters + ---------- + waveforms : (batch, channel, sample) + + Returns + ------- + diarization, embeddings : (batch, frames, classes), (batch, num_speaker, embed_dim) + """ + + num_layers = ( + None if self.hparams.wav2vec_layer < 0 else self.hparams.wav2vec_layer + ) + + if self.hparams.freeze_wav2vec: + with torch.no_grad(): + outputs, _ = self.wav2vec.extract_features( + waveforms.squeeze(1), num_layers=num_layers + ) + else: + outputs, _ = self.wav2vec.extract_features( + waveforms.squeeze(1), num_layers=num_layers + ) + + if num_layers is None: + dia_outputs = torch.stack(outputs, dim=-1) @ F.softmax( + self.dia_wav2vec_weights, dim=0 + ) + emb_outputs = torch.stack(outputs, dim=-1) @ F.softmax( + self.emb_wav2vec_weights, dim=0 + ) + else: + dia_outputs = emb_outputs = outputs[-1] + + if self.hparams.lstm["monolithic"]: + dia_outputs, _ = self.lstm(dia_outputs) + else: + for i, lstm in enumerate(self.lstm): + dia_outputs, _ = lstm(dia_outputs) + if i + 1 < self.hparams.lstm["num_layers"]: + dia_outputs = self.dropout(dia_outputs) + + if self.hparams.linear["num_layers"] > 0: + for linear in self.linear: + dia_outputs = F.leaky_relu(linear(dia_outputs)) + dia_outputs = self.classifier(dia_outputs) + dia_outputs = F.log_softmax(dia_outputs, dim=-1) + + weights = self.powerset.to_multilabel(dia_outputs, soft=True) + weights = rearrange(weights, "b f s -> b s f") + emb_outputs = rearrange(emb_outputs, "b f w -> b w f") + emb_outputs = self.pooling(emb_outputs, weights) + emb_outputs = self.embeddings(emb_outputs) + + return (dia_outputs, emb_outputs) + + +class WavLMEnd2EndDiarizationv3(Model): + """With modified weights + + Parameters + ---------- + sample_rate : int, optional + Audio sample rate. Defaults to 16kHz (16000). + num_channels : int, optional + Number of channels. Defaults to mono (1). + wav2vec: dict or str, optional + Defaults to "WAVLM_BASE". + wav2vec_layer: int, optional + Index of layer to use as input to the LSTM. + Defaults (-1) to use average of all layers (with learnable weights). + freeze_wav2vec: bool, optional + Whether to freeze wa2vec. 
Default to true + emb_dim: int, optional + Dimension of the speaker embedding in output + """ + + WAV2VEC_DEFAULTS = "WAVLM_BASE" + + LSTM_DEFAULTS = { + "hidden_size": 128, + "num_layers": 4, + "bidirectional": True, + "monolithic": True, + "dropout": 0.0, + } + + LINEAR_DEFAULT = {"hidden_size": 128, "num_layers": 2} + + def __init__( + self, + sample_rate: int = 16000, + num_channels: int = 1, + wav2vec: Union[dict, str] = None, + wav2vec_layer: int = -1, + freeze_wav2vec: bool = True, + lstm: Optional[dict] = None, + linear: Optional[dict] = None, + embedding_dim: Optional[int] = 192, + task: Optional[Task] = None, + ): + super().__init__(sample_rate=sample_rate, num_channels=num_channels, task=task) + + if isinstance(wav2vec, str): + # `wav2vec` is one of the supported pipelines from torchaudio (e.g. "WAVLM_BASE") + if hasattr(torchaudio.pipelines, wav2vec): + bundle = getattr(torchaudio.pipelines, wav2vec) + if sample_rate != bundle.sample_rate: + raise ValueError( + f"Expected {bundle.sample_rate}Hz, found {sample_rate}Hz." + ) + wav2vec_dim = bundle._params["encoder_embed_dim"] + wav2vec_num_layers = bundle._params["encoder_num_layers"] + self.wav2vec = bundle.get_model() + + # `wav2vec` is a path to a self-supervised representation checkpoint + else: + _checkpoint = torch.load(wav2vec) + wav2vec = _checkpoint.pop("config") + self.wav2vec = torchaudio.models.wav2vec2_model(**wav2vec) + state_dict = _checkpoint.pop("state_dict") + self.wav2vec.load_state_dict(state_dict) + wav2vec_dim = wav2vec["encoder_embed_dim"] + wav2vec_num_layers = wav2vec["encoder_num_layers"] + + # `wav2vec` is a config dictionary understood by `wav2vec2_model` + # this branch is typically used by Model.from_pretrained(...) + elif isinstance(wav2vec, dict): + self.wav2vec = torchaudio.models.wav2vec2_model(**wav2vec) + wav2vec_dim = wav2vec["encoder_embed_dim"] + wav2vec_num_layers = wav2vec["encoder_num_layers"] + + if wav2vec_layer < 0: + # weighting parameters for the diarization branch + self.dia_wav2vec_weights = nn.Parameter( + data=torch.ones(wav2vec_num_layers), requires_grad=True + ) + # weighting parameters for the embedding branch + self.emb_wav2vec_weights = nn.Parameter( + data=torch.ones(wav2vec_num_layers), requires_grad=True + ) + self.save_hyperparameters("wav2vec", "wav2vec_layer", "freeze_wav2vec") + + lstm = merge_dict(self.LSTM_DEFAULTS, lstm) + lstm["batch_first"] = True + linear = merge_dict(self.LINEAR_DEFAULT, linear) + self.save_hyperparameters("lstm", "linear") + monolithic = lstm["monolithic"] + if monolithic: + multi_layer_lstm = dict(lstm) + del multi_layer_lstm["monolithic"] + self.lstm = nn.LSTM(wav2vec_dim, **multi_layer_lstm) + else: + num_layers = lstm["num_layers"] + if num_layers > 1: + self.dropout = nn.Dropout(p=lstm["dropout"]) + + one_layer_lstm = dict(lstm) + one_layer_lstm["num_layers"] = 1 + one_layer_lstm["dropout"] = 0.0 + del one_layer_lstm["monolithic"] + + self.lstm = nn.ModuleList( + [ + nn.LSTM( + ( + wav2vec_dim + if i == 0 + else lstm["hidden_size"] + * (2 if lstm["bidirectional"] else 1) + ), + **one_layer_lstm, + ) + for i in range(num_layers) + ] + ) + + if linear["num_layers"] < 1: + return + lstm_out_features = self.hparams.lstm["hidden_size"] * ( + 2 if self.hparams.lstm["bidirectional"] else 1 + ) + self.linear = nn.ModuleList( + [ + nn.Linear(in_features, out_features) + for in_features, out_features in pairwise( + [ + lstm_out_features, + ] + + [self.hparams.linear["hidden_size"]] + * self.hparams.linear["num_layers"] + ) + ] + ) + + self.pooling 
= StatsPool(computde_std=False) + self.embeddings = nn.Sequential( + nn.Linear(in_features=wav2vec_dim, out_features=1024), + nn.LeakyReLU(), + nn.Linear(in_features=1024, out_features=embedding_dim), + ) + + self.save_hyperparameters("embedding_dim") + + @property + def dimension(self) -> int: + """Dimension of output""" + return self.specifications[Subtasks.index("diarization")].num_powerset_classes + + @lru_cache + def num_frames(self, num_samples: int) -> int: + """Compute number of output frames + + Parameters + ---------- + num_samples : int + Number of input samples. + + Returns + ------- + num_frames : int + Number of output frames. + """ + + num_frames = num_samples + for conv_layer in self.wav2vec.feature_extractor.conv_layers: + num_frames = conv1d_num_frames( + num_frames, + kernel_size=conv_layer.kernel_size, + stride=conv_layer.stride, + padding=conv_layer.conv.padding[0], + dilation=conv_layer.conv.dilation[0], + ) + + return num_frames + + def receptive_field_size(self, num_frames: int = 1) -> int: + """Compute size of receptive field + + Parameters + ---------- + num_frames : int, optional + Number of frames in the output signal + + Returns + ------- + receptive_field_size : int + Receptive field size. + """ + + receptive_field_size = num_frames + for conv_layer in reversed(self.wav2vec.feature_extractor.conv_layers): + receptive_field_size = conv1d_receptive_field_size( + num_frames=receptive_field_size, + kernel_size=conv_layer.kernel_size, + stride=conv_layer.stride, + dilation=conv_layer.conv.dilation[0], + ) + return receptive_field_size + + def receptive_field_center(self, frame: int = 0) -> int: + """Compute center of receptive field + + Parameters + ---------- + frame : int, optional + Frame index + + Returns + ------- + receptive_field_center : int + Index of receptive field center. 
+ """ + receptive_field_center = frame + for conv_layer in reversed(self.wav2vec.feature_extractor.conv_layers): + receptive_field_center = conv1d_receptive_field_center( + receptive_field_center, + kernel_size=conv_layer.kernel_size, + stride=conv_layer.stride, + padding=conv_layer.conv.padding[0], + dilation=conv_layer.conv.dilation[0], + ) + return receptive_field_center + + def build(self): + """""" + max_num_speaker_per_chunk = len(self.specifications[Subtasks.index("diarization")].classes) + max_num_speaker_per_frame = self.specifications[Subtasks.index("diarization")].powerset_max_classes + self.powerset = Powerset( + max_num_speaker_per_chunk, + max_num_speaker_per_frame + ) + + if self.hparams.linear["num_layers"] > 0: + in_features = self.hparams.linear["hidden_size"] + else: + lstm = self.hparams.lstm + in_features = lstm["hidden_size"] * (2 if lstm["bidirectional"] else 1) + + self.classifier = nn.Linear(in_features, self.dimension) + + def forward(self, waveforms: torch.Tensor) -> torch.Tensor: + """Pass forward + + Parameters + ---------- + waveforms : (batch, channel, sample) + + Returns + ------- + diarization, embeddings : (batch, frames, classes), (batch, num_speaker, embed_dim) + """ + + num_layers = ( + None if self.hparams.wav2vec_layer < 0 else self.hparams.wav2vec_layer + ) + + if self.hparams.freeze_wav2vec: + with torch.no_grad(): + outputs, _ = self.wav2vec.extract_features( + waveforms.squeeze(1), num_layers=num_layers + ) + else: + outputs, _ = self.wav2vec.extract_features( + waveforms.squeeze(1), num_layers=num_layers + ) + + if num_layers is None: + dia_outputs = torch.stack(outputs, dim=-1) @ F.softmax( + self.dia_wav2vec_weights, dim=0 + ) + emb_outputs = torch.stack(outputs, dim=-1) @ F.softmax( + self.emb_wav2vec_weights, dim=0 + ) + else: + dia_outputs = emb_outputs = outputs[-1] + + if self.hparams.lstm["monolithic"]: + dia_outputs, _ = self.lstm(dia_outputs) + else: + for i, lstm in enumerate(self.lstm): + dia_outputs, _ = lstm(dia_outputs) + if i + 1 < self.hparams.lstm["num_layers"]: + dia_outputs = self.dropout(dia_outputs) + + if self.hparams.linear["num_layers"] > 0: + for linear in self.linear: + dia_outputs = F.leaky_relu(linear(dia_outputs)) + dia_outputs = self.classifier(dia_outputs) + dia_outputs = F.log_softmax(dia_outputs, dim=-1) + + # hard-segmentation in multilabel space + multilabel_segmentations: torch.Tensor = self.powerset.to_multilabel(dia_outputs) + # (batch_size, num_frames, max_speakers_per_chunk), {0, 1} + + weights = ( + ( + F.one_hot( + torch.argmax(dia_outputs, dim=2), + num_classes=self.powerset.num_powerset_classes, + )[:, :, 1 : 1 + self.powerset.num_classes] + + 1e-2 + ) + * multilabel_segmentations + ).transpose(2, 1) + # (batch_size, max_speakers_per_chunk, num_frames) + # 0.000 if speaker is inactive + # 0.001 if speaker is active but not alone + # 1.001 if speaker is active and alone + + emb_outputs = rearrange(emb_outputs, "b f w -> b w f") + emb_outputs = self.pooling(emb_outputs, weights) + emb_outputs = self.embeddings(emb_outputs) + + return (dia_outputs, emb_outputs) diff --git a/pyannote/audio/pipelines/speaker_diarization.py b/pyannote/audio/pipelines/speaker_diarization.py index e0d43e30c..7e89a70fc 100644 --- a/pyannote/audio/pipelines/speaker_diarization.py +++ b/pyannote/audio/pipelines/speaker_diarization.py @@ -35,10 +35,12 @@ from pyannote.core import Annotation, SlidingWindowFeature from pyannote.metrics.diarization import GreedyDiarizationErrorRate from pyannote.pipeline.parameter import ParamDict, 
Uniform +from sklearn.cluster import AgglomerativeClustering from pyannote.audio import Audio, Inference, Model, Pipeline from pyannote.audio.core.io import AudioFile -from pyannote.audio.pipelines.clustering import Clustering +from pyannote.audio.core.task import Problem, Resolution +from pyannote.audio.pipelines.clustering import AgglomerativeClustering, Clustering from pyannote.audio.pipelines.speaker_verification import PretrainedSpeakerEmbedding from pyannote.audio.pipelines.utils import ( PipelineModel, @@ -663,3 +665,374 @@ def apply( def get_metric(self) -> GreedyDiarizationErrorRate: return GreedyDiarizationErrorRate(**self.der_variant) + + +class SpeakerDiarizationV2(SpeakerDiarizationMixin, Pipeline): + """Speaker diarization pipeline with joint segmentation + embedding model + + Parameters + ---------- + model : Model, str, or dict, optional + Pretrained (segmentation + embedding) model. + See pyannote.audio.pipelines.utils.get_model for supported format. + step: float, optional + The model is applied on a window sliding over the whole audio file. + `step` controls the step of this window, provided as a ratio of its + duration. Defaults to 0.1 (i.e. 90% overlap between two consecuive windows). + clustering : str, optional + Clustering algorithm. See pyannote.audio.pipelines.clustering.Clustering + for available options. Defaults to "AgglomerativeClustering". + batch_size : int, optional + Batch size used for inference. Defaults to 1. + use_auth_token : str, optional + When loading private huggingface.co models, set `use_auth_token` + to True or to a string containing your hugginface.co authentication + token that can be obtained by running `huggingface-cli login` + + Usage + ----- + # perform (unconstrained) diarization + >>> diarization = pipeline("/path/to/audio.wav") + + # perform diarization, targetting exactly 4 speakers + >>> diarization = pipeline("/path/to/audio.wav", num_speakers=4) + + # perform diarization, with at least 2 speakers and at most 10 speakers + >>> diarization = pipeline("/path/to/audio.wav", min_speakers=2, max_speakers=10) + + # perform diarization and get one representative embedding per speaker + >>> diarization, embeddings = pipeline("/path/to/audio.wav", return_embeddings=True) + >>> for s, speaker in enumerate(diarization.labels()): + ... 
# embeddings[s] is the embedding of speaker `speaker` + + """ + + def __init__( + self, + model: PipelineModel = None, + step: float = 0.1, + clustering: str = "AgglomerativeClustering", + batch_size: int = 1, + use_auth_token: Union[Text, None] = None, + ): + super().__init__() + + self.model = model + model: Model = get_model(model, use_auth_token=use_auth_token) + + assert len(model.specifications) == 2 + segmentation_specifications, embedding_specifications = model.specifications + # TODO: check that specs are correct + assert segmentation_specifications.problem == Problem.MONO_LABEL_CLASSIFICATION + assert segmentation_specifications.resolution == Resolution.FRAME + assert embedding_specifications.problem == Problem.REPRESENTATION + assert embedding_specifications.resolution == Resolution.CHUNK + + self.step = step + self.klustering = clustering + + duration: float = segmentation_specifications.duration + self._inference = Inference( + model, + duration=duration, + step=self.step * duration, + skip_aggregation=True, + skip_conversion=False, # <-- output multilabel segmentation + batch_size=batch_size, + ) + + self.clustering = AgglomerativeClustering(metric="cosine") + + @property + def batch_size(self) -> int: + return self._inference.batch_size + + @batch_size.setter + def batch_size(self, batch_size: int): + self._inference.batch_size = batch_size + + def default_parameters(self): + raise NotImplementedError() + + def classes(self): + speaker = 0 + while True: + yield f"SPEAKER_{speaker:02d}" + speaker += 1 + + @property + def CACHED_INFERENCE(self): + return "training_cache/inference" + + def get_inference(self, file, hook=None) -> Tuple[SlidingWindowFeature]: + """Apply joint model + + Parameter + --------- + file : AudioFile + hook : Optional[Callable] + + Returns + ------- + segmentations : (num_chunks, num_frames, num_speakers) SlidingWindowFeature + embeddings : (num_chunks, num_speakers, dimension) SlidingWindowFeature + """ + + if hook is not None: + hook = functools.partial(hook, "inference", None) + + if self.training: + if self.CACHED_INFERENCE in file: + inference = file[self.CACHED_INFERENCE] + else: + inference = self._inference(file, hook=hook) + file[self.CACHED_INFERENCE] = inference + else: + inference = self._inference(file, hook=hook) + + return inference + + def reconstruct( + self, + segmentations: SlidingWindowFeature, + hard_clusters: np.ndarray, + count: SlidingWindowFeature, + ) -> SlidingWindowFeature: + """Build final discrete diarization out of clustered segmentation + + Parameters + ---------- + segmentations : (num_chunks, num_frames, num_speakers) SlidingWindowFeature + Raw speaker segmentation. + hard_clusters : (num_chunks, num_speakers) array + Output of clustering step. + count : (total_num_frames, 1) SlidingWindowFeature + Instantaneous number of active speakers. + + Returns + ------- + discrete_diarization : SlidingWindowFeature + Discrete (0s and 1s) diarization. + """ + + num_chunks, num_frames, local_num_speakers = segmentations.data.shape + + num_clusters = np.max(hard_clusters) + 1 + clustered_segmentations = np.nan * np.zeros( + (num_chunks, num_frames, num_clusters) + ) + + for c, (cluster, (chunk, segmentation)) in enumerate( + zip(hard_clusters, segmentations) + ): + # cluster is (local_num_speakers, )-shaped + # segmentation is (num_frames, local_num_speakers)-shaped + for k in np.unique(cluster): + if k == -2: + continue + + # TODO: can we do better than this max here? 
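                # Illustrative toy values (hypothetical) for the frame-wise max below:
                # all local speakers assigned to the same cluster k are merged by
                # keeping the larger activation at each frame, e.g.
                #   segmentation[:, cluster == k] = [[0.9, 0.1],
                #                                    [0.2, 0.8],
                #                                    [0.0, 0.0]]
                #   np.max(..., axis=1)           = [0.9, 0.8, 0.0]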
+ clustered_segmentations[c, :, k] = np.max( + segmentation[:, cluster == k], axis=1 + ) + + clustered_segmentations = SlidingWindowFeature( + clustered_segmentations, segmentations.sliding_window + ) + + return self.to_diarization(clustered_segmentations, count) + + def apply( + self, + file: AudioFile, + num_speakers: Optional[int] = None, + min_speakers: Optional[int] = None, + max_speakers: Optional[int] = None, + return_embeddings: bool = False, + hook: Optional[Callable] = None, + ) -> Union[Annotation, Tuple[Annotation, np.ndarray]]: + """Apply speaker diarization + + Parameters + ---------- + file : AudioFile + Processed file. + num_speakers : int, optional + Number of speakers, when known. + min_speakers : int, optional + Minimum number of speakers. Has no effect when `num_speakers` is provided. + max_speakers : int, optional + Maximum number of speakers. Has no effect when `num_speakers` is provided. + return_embeddings : bool, optional + Return representative speaker embeddings. + hook : callable, optional + Callback called after each major steps of the pipeline as follows: + hook(step_name, # human-readable name of current step + step_artefact, # artifact generated by current step + file=file) # file being processed + Time-consuming steps call `hook` multiple times with the same `step_name` + and additional `completed` and `total` keyword arguments usable to track + progress of current step. + + Returns + ------- + diarization : Annotation + Speaker diarization + embeddings : np.array, optional + Representative speaker embeddings such that `embeddings[i]` is the + speaker embedding for i-th speaker in diarization.labels(). + Only returned when `return_embeddings` is True. + """ + + # setup hook (e.g. for debugging purposes) + hook = self.setup_hook(file, hook=hook) + + num_speakers, min_speakers, max_speakers = self.set_num_speakers( + num_speakers=num_speakers, + min_speakers=min_speakers, + max_speakers=max_speakers, + ) + + inference = self.get_inference(file, hook=hook) + hook("inference", inference) + binarized_segmentations, embeddings = inference + # shape: (num_chunks, num_frames, local_num_speakers) + num_chunks, num_frames, local_num_speakers = binarized_segmentations.data.shape + _, _, dimension = embeddings.data.shape + + # estimate frame-level number of instantaneous speakers + count = self.speaker_count( + binarized_segmentations, + self._inference.model.receptive_field, + warm_up=(0.0, 0.0), + ) + hook("speaker_counting", count) + # shape: (num_frames, 1) + # dtype: int + + # exit early when no speaker is ever active + if np.nanmax(count.data) == 0.0: + diarization = Annotation(uri=file["uri"]) + if return_embeddings: + return diarization, np.zeros((0, dimension)) + + return diarization + + hard_clusters, _, centroids = self.clustering( + embeddings=embeddings.data, + segmentations=binarized_segmentations, + num_clusters=num_speakers, + min_clusters=min_speakers, + max_clusters=max_speakers, + file=file, # <== for oracle clustering + frames=self._inference.model.receptive_field, # <== for oracle clustering + ) + # hard_clusters: (num_chunks, num_speakers) + # centroids: (num_speakers, dimension) + + # number of detected clusters is the number of different speakers + num_different_speakers = np.max(hard_clusters) + 1 + + # detected number of speakers can still be out of bounds + # (specifically, lower than `min_speakers`), since there could be too few embeddings + # to make enough clusters with a given minimum cluster size. 
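        # e.g. (illustrative): if every chunk-level embedding ends up in a single
        # cluster, num_different_speakers = np.max(hard_clusters) + 1 == 1, which
        # violates a requested min_speakers of 2 and triggers the warning below.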
+ if ( + num_different_speakers < min_speakers + or num_different_speakers > max_speakers + ): + warnings.warn( + textwrap.dedent( + f""" + The detected number of speakers ({num_different_speakers}) is outside + the given bounds [{min_speakers}, {max_speakers}]. This can happen if the + given audio file is too short to contain {min_speakers} or more speakers. + Try to lower the desired minimal number of speakers. + """ + ) + ) + + # during counting, we could possibly overcount the number of instantaneous + # speakers due to segmentation errors, so we cap the maximum instantaneous number + # of speakers by the `max_speakers` value + count.data = np.minimum(count.data, max_speakers).astype(np.int8) + + # reconstruct discrete diarization from raw hard clusters + + # keep track of inactive speakers + inactive_speakers = np.sum(binarized_segmentations.data, axis=1) == 0 + # shape: (num_chunks, num_speakers) + + hard_clusters[inactive_speakers] = -2 + discrete_diarization = self.reconstruct( + binarized_segmentations, + hard_clusters, + count, + ) + hook("discrete_diarization", discrete_diarization) + + # convert to continuous diarization + diarization = self.to_annotation( + discrete_diarization, + min_duration_on=0.0, + min_duration_off=0.0, + ) + diarization.uri = file["uri"] + + # at this point, `diarization` speaker labels are integers + # from 0 to `num_speakers - 1`, aligned with `centroids` rows. + + if "annotation" in file and file["annotation"]: + # when reference is available, use it to map hypothesized speakers + # to reference speakers (this makes later error analysis easier + # but does not modify the actual output of the diarization pipeline) + _, mapping = self.optimal_mapping( + file["annotation"], diarization, return_mapping=True + ) + + # in case there are more speakers in the hypothesis than in + # the reference, those extra speakers are missing from `mapping`. + # we add them back here + mapping = {key: mapping.get(key, key) for key in diarization.labels()} + + else: + # when reference is not available, rename hypothesized speakers + # to human-readable SPEAKER_00, SPEAKER_01, ... + mapping = { + label: expected_label + for label, expected_label in zip(diarization.labels(), self.classes()) + } + + diarization = diarization.rename_labels(mapping=mapping) + + # at this point, `diarization` speaker labels are strings (or mix of + # strings and integers when reference is available and some hypothesis + # speakers are not present in the reference) + + if not return_embeddings: + return diarization + + # this can happen when we use OracleClustering + if centroids is None: + return diarization, None + + # The number of centroids may be smaller than the number of speakers + # in the annotation. This can happen if the number of active speakers + # obtained from `speaker_count` for some frames is larger than the number + # of clusters obtained from `clustering`. 
In this case, we append zero embeddings + # for extra speakers + if len(diarization.labels()) > centroids.shape[0]: + centroids = np.pad( + centroids, ((0, len(diarization.labels()) - centroids.shape[0]), (0, 0)) + ) + + # re-order centroids so that they match + # the order given by diarization.labels() + inverse_mapping = {label: index for index, label in mapping.items()} + centroids = centroids[ + [inverse_mapping[label] for label in diarization.labels()] + ] + + return diarization, centroids + + def get_metric(self) -> GreedyDiarizationErrorRate: + return GreedyDiarizationErrorRate(**self.der_variant) diff --git a/pyannote/audio/tasks/__init__.py b/pyannote/audio/tasks/__init__.py index 517c6dd55..e23a4751c 100644 --- a/pyannote/audio/tasks/__init__.py +++ b/pyannote/audio/tasks/__init__.py @@ -29,6 +29,8 @@ ) from .embedding.arcface import SupervisedRepresentationLearningWithArcFace # isort:skip +from .joint_task.speaker_diarization_and_embedding import JointSpeakerDiarizationAndEmbedding + # Segmentation has been renamed to SpeakerDiarization but we keep Segmentation here for backward compatibility Segmentation = SpeakerDiarization @@ -42,5 +44,6 @@ "MultiLabelSegmentation", "SpeakerEmbedding", "Segmentation", + "JointSpeakerDiarizationAndEmbedding", "PixIT", ] diff --git a/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py b/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py new file mode 100644 index 000000000..578da632f --- /dev/null +++ b/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py @@ -0,0 +1,1399 @@ +# MIT License +# +# Copyright (c) 2023- CNRS +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
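# Hypothetical usage sketch for the SpeakerDiarizationV2 pipeline added above
# (checkpoint path and parameter values are illustrative, not taken from this patch):
#
#   from pyannote.audio.pipelines.speaker_diarization import SpeakerDiarizationV2
#
#   pipeline = SpeakerDiarizationV2(
#       model="/path/to/joint_segmentation_embedding.ckpt",  # joint (diarization + embedding) checkpoint
#       step=0.1,        # 90% overlap between consecutive sliding windows
#       batch_size=32,
#   )
#   diarization, embeddings = pipeline(
#       "/path/to/audio.wav", min_speakers=2, max_speakers=5, return_embeddings=True
#   )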
+ +from collections import defaultdict +import itertools +from pathlib import Path +import random +import warnings +from tempfile import mkstemp +from typing import Dict, Literal, Optional, Sequence, Tuple, Union + +import numpy as np +import torch +from einops import rearrange +from matplotlib import pyplot as plt +from pyannote.core import ( + Annotation, + Segment, + SlidingWindow, + SlidingWindowFeature, + Timeline, +) +from pyannote.database.protocol.protocol import Scope, Subset +from pytorch_metric_learning.losses import ArcFaceLoss +from torch_audiomentations.core.transforms_interface import BaseWaveformTransform +from torchmetrics import Metric + +from scipy.spatial.distance import cdist + +from pyannote.audio.core.task import Problem, Resolution, Specifications, get_dtype +from pyannote.audio.tasks import SpeakerDiarization +from pyannote.audio.torchmetrics import ( + DiarizationErrorRate, + FalseAlarmRate, + MissedDetectionRate, + SpeakerConfusionRate, +) +from pyannote.audio.utils.loss import nll_loss +from pyannote.audio.utils.permutation import permutate +from pyannote.audio.utils.random import create_rng_for_worker +from pyannote.audio.pipelines.clustering import KMeansClustering, OracleClustering +from pyannote.audio.pipelines.utils import SpeakerDiarizationMixin +from pyannote.audio.core.io import Audio + +from pyannote.metrics.diarization import ( + DiarizationErrorRate as GlobalDiarizationErrorRate, +) + +Subtask = Literal["diarization", "embedding"] + +Subsets = list(Subset.__args__) +Scopes = list(Scope.__args__) +Subtasks = list(Subtask.__args__) + + +class JointSpeakerDiarizationAndEmbedding(SpeakerDiarization): + """Joint speaker diarization and embedding task + + Usage + ----- + load a meta protocol containing both diarization (e.g. X.SpeakerDiarization.Pretraining) + and verification (e.g. VoxCeleb.SpeakerVerification.VoxCeleb) datasets + >>> from pyannote.database import registry + >>> protocol = registry.get_protocol(...) + + instantiate task + >>> task = JointSpeakerDiarizationAndEmbedding(protocol) + + instantiate multi-task model + >>> model = JointSpeakerDiarizationAndEmbeddingModel() + >>> model.task = task + + train as usual... + + """ + + def __init__( + self, + protocol, + duration: float = 5.0, + max_speakers_per_chunk: int = 3, + max_speakers_per_frame: int = 2, + weigh_by_cardinality: bool = False, + batch_size: int = 32, + dia_task_rate: float = 0.5, + num_workers: int = None, + pin_memory: bool = False, + margin: float = 28.6, + scale: float = 64.0, + alpha: float = 0.5, + augmentation: BaseWaveformTransform = None, + cache: Optional[Union[str, None]] = None, + ) -> None: + """TODO Add docstring""" + super().__init__( + protocol, + duration=duration, + max_speakers_per_chunk=max_speakers_per_chunk, + max_speakers_per_frame=max_speakers_per_frame, + weigh_by_cardinality=weigh_by_cardinality, + batch_size=batch_size, + num_workers=num_workers, + pin_memory=pin_memory, + augmentation=augmentation, + cache=cache, + ) + + self.num_dia_samples = int(batch_size * dia_task_rate) + self.margin = margin + self.scale = scale + self.alpha = alpha + # keep track of the use of database available in the meta protocol + # * embedding databases are those with global speaker label scope + # * diarization databases are those with file or database speaker label scope + self.embedding_files_id = [] + + def prepare_data(self): + """Use this to prepare data from task protocol + + Notes + ----- + Called only once on the main process (and only on it), for global_rank 0. 
+ + After this method is called, the task should have a `prepared_data` attribute + with the following dictionary structure: + + prepared_data = { + 'protocol': name of the protocol + 'audio-path': array of N paths to audio + 'audio-metadata': array of N audio infos such as audio subset, scope and database + 'audio-info': array of N audio torchaudio.info struct + 'audio-encoding': array of N audio encodings + 'audio-annotated': array of N annotated duration (usually equals file duration but might be shorter if file is not fully annotated) + 'annotations-regions': array of M annotated regions + 'annotations-segments': array of M' annotated segments + 'metadata-values': dict of lists of values for subset, scope and database + 'metadata-`database-name`-labels': array of `database-name` labels. Each database with "database" scope labels has it own array. + 'metadata-labels': array of global scope labels + } + + """ + + if self.cache: + # check if cache exists and is not empty: + if self.cache.exists() and self.cache.stat().st_size > 0: + # data was already created, nothing to do + return + # create parent directory if needed + self.cache.parent.mkdir(parents=True, exist_ok=True) + else: + # if no cache was provided by user, create a temporary file + # in system directory used for temp files + self.cache = Path(mkstemp()[1]) + + # list of possible values for each metadata key + # (will become .prepared_data[""]) + metadata_unique_values = defaultdict(list) + metadata_unique_values["subset"] = Subsets + metadata_unique_values["scope"] = Scopes + + audios = list() # list of path to audio files + audio_infos = list() + audio_encodings = list() + metadata = list() # list of metadata + + annotated_duration = list() # total duration of annotated regions (per file) + annotated_regions = list() # annotated regions + annotations = list() # actual annotations + unique_labels = list() + database_unique_labels = {} + + if self.has_validation: + files_iter = itertools.chain( + zip(itertools.repeat("train"), self.protocol.train()), + zip(itertools.repeat("development"), self.protocol.development()), + ) + else: + files_iter = zip(itertools.repeat("train"), self.protocol.train()) + + for file_id, (subset, file) in enumerate(files_iter): + # gather metadata and update metadata_unique_values so that each metadatum + # (e.g. source database or label) is represented by an integer. + metadatum = dict() + + # keep track of source database and subset (train, development, or test) + if file["database"] not in metadata_unique_values["database"]: + metadata_unique_values["database"].append(file["database"]) + metadatum["database"] = metadata_unique_values["database"].index( + file["database"] + ) + + metadatum["subset"] = Subsets.index(subset) + + # keep track of label scope (file, database, or global) + metadatum["scope"] = Scopes.index(file["scope"]) + + remaining_metadata_keys = set(file) - set( + [ + "uri", + "database", + "subset", + "audio", + "torchaudio.info", + "scope", + "classes", + "annotation", + "annotated", + ] + ) + + # keep track of any other (integer or string) metadata provided by the protocol + # (e.g. 
a "domain" key for domain-adversarial training) + for key in remaining_metadata_keys: + value = file[key] + + if isinstance(value, str): + if value not in metadata_unique_values[key]: + metadata_unique_values[key].append(value) + metadatum[key] = metadata_unique_values[key].index(value) + + elif isinstance(value, int): + metadatum[key] = value + + else: + warnings.warn( + f"Ignoring '{key}' metadata because of its type ({type(value)}). Only str and int are supported for now.", + category=UserWarning, + ) + + metadata.append(metadatum) + + # reset list of file-scoped labels + file_unique_labels = list() + + # path to audio file + audios.append(str(file["audio"])) + + # audio info + audio_info = file["torchaudio.info"] + audio_infos.append( + ( + audio_info.sample_rate, # sample rate + audio_info.num_frames, # number of frames + audio_info.num_channels, # number of channels + audio_info.bits_per_sample, # bits per sample + ) + ) + audio_encodings.append(audio_info.encoding) # encoding + + # annotated regions and duration + _annotated_duration = 0.0 + for segment in file["annotated"]: + # skip annotated regions that are shorter than training chunk duration + # if segment.duration < self.duration: + # continue + + # append annotated region + annotated_region = ( + file_id, + segment.duration, + segment.start, + ) + annotated_regions.append(annotated_region) + + # increment annotated duration + _annotated_duration += segment.duration + + # append annotated duration + annotated_duration.append(_annotated_duration) + + # annotations + for segment, _, label in file["annotation"].itertracks(yield_label=True): + # "scope" is provided by speaker diarization protocols to indicate + # whether speaker labels are local to the file ('file'), consistent across + # all files in a database ('database'), or globally consistent ('global') + + # 0 = 'file' / 1 = 'database' / 2 = 'global' + scope = Scopes.index(file["scope"]) + + # update list of file-scope labels + if label not in file_unique_labels: + file_unique_labels.append(label) + # and convert label to its (file-scope) index + file_label_idx = file_unique_labels.index(label) + + database_label_idx = global_label_idx = -1 + + if scope > 0: # 'database' or 'global' + # update list of database-scope labels + database = file["database"] + if database not in database_unique_labels: + database_unique_labels[database] = [] + if label not in database_unique_labels[database]: + database_unique_labels[database].append(label) + + # and convert label to its (database-scope) index + database_label_idx = database_unique_labels[database].index(label) + + if scope > 1: # 'global' + # update list of global-scope labels + if label not in unique_labels: + unique_labels.append(label) + # and convert label to its (global-scope) index + global_label_idx = unique_labels.index(label) + + annotations.append( + ( + file_id, # index of file + segment.start, # start time + segment.end, # end time + file_label_idx, # file-scope label index + database_label_idx, # database-scope label index + global_label_idx, # global-scope index + ) + ) + + # since not all metadata keys are present in all files, fallback to -1 when a key is missing + metadata = [ + tuple(metadatum.get(key, -1) for key in metadata_unique_values) + for metadatum in metadata + ] + metadata_dtype = [ + (key, get_dtype(max(m[i] for m in metadata))) + for i, key in enumerate(metadata_unique_values) + ] + + # turn list of files metadata into a single numpy array + # TODO: improve using 
https://github.com/pytorch/pytorch/issues/13246#issuecomment-617140519 + info_dtype = [ + ( + "sample_rate", + get_dtype(max(ai[0] for ai in audio_infos)), + ), + ( + "num_frames", + get_dtype(max(ai[1] for ai in audio_infos)), + ), + ("num_channels", "B"), + ("bits_per_sample", "B"), + ] + + # turn list of annotated regions into a single numpy array + region_dtype = [ + ( + "file_id", + get_dtype(max(ar[0] for ar in annotated_regions)), + ), + ("duration", "f"), + ("start", "f"), + ] + + # turn list of annotations into a single numpy array + segment_dtype = [ + ( + "file_id", + get_dtype(max(a[0] for a in annotations)), + ), + ("start", "f"), + ("end", "f"), + ("file_label_idx", get_dtype(max(a[3] for a in annotations))), + ("database_label_idx", get_dtype(max(a[4] for a in annotations))), + ("global_label_idx", get_dtype(max(a[5] for a in annotations))), + ] + + # save all protocol data in a dict + prepared_data = {} + + # keep track of protocol name + prepared_data["protocol"] = self.protocol.name + + prepared_data["audio-path"] = np.array(audios, dtype=np.str_) + audios.clear() + + prepared_data["audio-metadata"] = np.array(metadata, dtype=metadata_dtype) + metadata.clear() + + prepared_data["audio-info"] = np.array(audio_infos, dtype=info_dtype) + audio_infos.clear() + + prepared_data["audio-encoding"] = np.array(audio_encodings, dtype=np.str_) + audio_encodings.clear() + + prepared_data["audio-annotated"] = np.array(annotated_duration) + annotated_duration.clear() + + prepared_data["annotations-regions"] = np.array( + annotated_regions, dtype=region_dtype + ) + annotated_regions.clear() + + prepared_data["annotations-segments"] = np.array( + annotations, dtype=segment_dtype + ) + annotations.clear() + + prepared_data["metadata-values"] = metadata_unique_values + + for database, labels in database_unique_labels.items(): + prepared_data[f"metadata-{database}-labels"] = np.array( + labels, dtype=np.str_ + ) + database_unique_labels.clear() + + prepared_data["metadata-labels"] = np.array(unique_labels, dtype=np.str_) + unique_labels.clear() + + if self.has_validation: + self.prepare_validation(prepared_data) + + self.post_prepare_data(prepared_data) + + # save prepared data on the disk + with open(self.cache, "wb") as cache_file: + np.savez_compressed(cache_file, **prepared_data) + + def prepare_validation(self, prepared_data: Dict) -> None: + """Each validation batch correspond to a part of a validation file""" + validation_mask = prepared_data["audio-metadata"]["subset"] == Subsets.index( + "development" + ) + prepared_data["validation-files"] = np.argwhere(validation_mask).reshape((-1,)) + + def setup(self, stage="fit"): + """Setup method + + Parameters + ---------- + stage : {'fit', 'validate', 'test'}, optional + Setup stage. Defaults to 'fit'. + """ + + super().setup() + + global_scope_mask = ( + self.prepared_data["annotations-segments"]["global_label_idx"] > -1 + ) + self.embedding_files_id = np.unique( + self.prepared_data["annotations-segments"]["file_id"][global_scope_mask] + ) + embedding_classes = np.unique( + self.prepared_data["annotations-segments"]["global_label_idx"][ + global_scope_mask + ] + ) + + # if there is no file dedicated to the embedding task + if self.alpha != 1.0 and len(embedding_classes) == 0: + self.num_dia_samples = self.batch_size + self.alpha = 1.0 + warnings.warn( + "No class found for the speaker embedding task. Model will be trained on the speaker diarization task only." 
+ ) + + if self.alpha != 0.0 and np.sum(global_scope_mask) == len( + self.prepared_data["annotations-segments"] + ): + self.num_dia_samples = 0 + self.alpha = 0.0 + warnings.warn( + "No segment found for the speaker diarization task. Model will be trained on the speaker embedding task only." + ) + + speaker_diarization = Specifications( + duration=self.duration, + resolution=Resolution.FRAME, + problem=Problem.MONO_LABEL_CLASSIFICATION, + permutation_invariant=True, + classes=[f"speaker{i+1}" for i in range(self.max_speakers_per_chunk)], + powerset_max_classes=self.max_speakers_per_frame, + ) + speaker_embedding = Specifications( + duration=self.duration, + resolution=Resolution.CHUNK, + problem=Problem.REPRESENTATION, + classes=embedding_classes, + ) + self.specifications = (speaker_diarization, speaker_embedding) + + def prepare_chunk(self, file_id: int, start_time: float, duration: float): + """Prepare chunk + + Parameters + ---------- + file_id : int + File index + start_time : float + Chunk start time + duration : float + Chunk duration. + + Returns + ------- + sample : dict + Dictionary containing the chunk data with the following keys: + - `X`: waveform + - `y`: target as a SlidingWindowFeature instance where y.labels is + in meta.scope space. + - `meta`: + - `scope`: target scope (0: file, 1: database, 2: global) + - `database`: database index + - `file`: file index + """ + + file = self.get_file(file_id) + + # get label scope + label_scope = Scopes[self.prepared_data["audio-metadata"][file_id]["scope"]] + label_scope_key = f"{label_scope}_label_idx" + + chunk = Segment(start_time, start_time + duration) + + sample = dict() + sample["X"], _ = self.model.audio.crop( + file, chunk, duration=duration, mode="pad" + ) + + # gather all annotations of current file + annotations = self.prepared_data["annotations-segments"][ + self.prepared_data["annotations-segments"]["file_id"] == file_id + ] + + # gather all annotations with non-empty intersection with current chunk + chunk_annotations = annotations[ + (annotations["start"] < chunk.end) & (annotations["end"] > chunk.start) + ] + + # discretize chunk annotations at model output resolution + step = self.model.receptive_field.step + half = 0.5 * self.model.receptive_field.duration + + start = np.maximum(chunk_annotations["start"], chunk.start) - chunk.start - half + start_idx = np.maximum(0, np.round(start / step)).astype(int) + + end = np.minimum(chunk_annotations["end"], chunk.end) - chunk.start - half + end_idx = np.round(end / step).astype(int) + + # get list and number of labels for current scope + labels = list(np.unique(chunk_annotations[label_scope_key])) + num_labels = len(labels) + + if num_labels > self.max_speakers_per_chunk: + pass + + # initial frame-level targets + num_frames = self.model.num_frames( + round(duration * self.model.hparams.sample_rate) + ) + y = np.zeros((num_frames, num_labels), dtype=np.uint8) + + # map labels to indices + mapping = {label: idx for idx, label in enumerate(labels)} + + for start, end, label in zip( + start_idx, end_idx, chunk_annotations[label_scope_key] + ): + mapped_label = mapping[label] + y[start : end + 1, mapped_label] = 1 + + sample["y"] = SlidingWindowFeature(y, self.model.receptive_field, labels=labels) + + metadata = self.prepared_data["audio-metadata"][file_id] + sample["meta"] = {key: metadata[key] for key in metadata.dtype.names} + sample["meta"]["file"] = file_id + + return sample + + def draw_diarization_chunk( + self, + file_ids: np.ndarray, + cum_prob_annotated_duration: 
np.ndarray,
+        rng: random.Random,
+        duration: float,
+    ) -> tuple:
+        """Sample one chunk for the diarization task
+
+        Parameters
+        ----------
+        file_ids: np.ndarray
+            array containing file ids
+        cum_prob_annotated_duration: np.ndarray
+            array of the same size as file_ids, containing the probability
+            for the corresponding file to be drawn
+        rng : random.Random
+            Random number generator
+        duration: float
+            duration of the chunk to draw
+        """
+        # select one file at random (with probability proportional to its annotated duration)
+        file_id = file_ids[cum_prob_annotated_duration.searchsorted(rng.random())]
+        # find indices of annotated regions in this file
+        annotated_region_indices = np.where(
+            self.prepared_data["annotations-regions"]["file_id"] == file_id
+        )[0]
+
+        # turn annotated regions duration into a probability distribution
+        cum_prob_annotated_regions_duration = np.cumsum(
+            self.prepared_data["annotations-regions"]["duration"][
+                annotated_region_indices
+            ]
+            / np.sum(
+                self.prepared_data["annotations-regions"]["duration"][
+                    annotated_region_indices
+                ]
+            )
+        )
+
+        # select one annotated region at random (with probability proportional to its duration)
+        annotated_region_index = annotated_region_indices[
+            cum_prob_annotated_regions_duration.searchsorted(rng.random())
+        ]
+
+        # select one chunk at random in this annotated region
+        _, region_duration, start = self.prepared_data["annotations-regions"][
+            annotated_region_index
+        ]
+        start_time = rng.uniform(start, start + region_duration - duration)
+
+        return (file_id, start_time)
+
+    def draw_embedding_chunk(self, class_id: int, duration: float) -> tuple:
+        """Sample one chunk for the embedding task
+
+        Parameters
+        ----------
+        class_id : int
+            class ID in the task specifications
+        duration: float
+            duration of the chunk to draw
+
+        Returns
+        -------
+        tuple:
+            file_id:
+                the file id to which the sampled chunk belongs
+            start_time:
+                start time of the sampled chunk
+        """
+        # get segments for current class
+        class_segments_idx = (
+            self.prepared_data["annotations-segments"]["global_label_idx"] == class_id
+        )
+        class_segments = self.prepared_data["annotations-segments"][class_segments_idx]
+
+        # sample one segment from all the class segments:
+        segments_duration = class_segments["end"] - class_segments["start"]
+        segments_total_duration = np.sum(segments_duration)
+        prob_segments = segments_duration / segments_total_duration
+        segment = np.random.choice(class_segments, p=prob_segments)
+
+        # sample chunk start time in order to intersect it with the sampled segment
+        start_time = np.random.uniform(
+            max(segment["start"] - duration, 0), segment["end"]
+        )
+
+        return (segment["file_id"], start_time)
+
+    def train__iter__helper(self, rng: random.Random, **filters):
+        """Iterate over training samples with optional domain filtering
+
+        Parameters
+        ----------
+        rng : random.Random
+            Random number generator
+        filters : dict, optional
+            When provided (as a {key: value} dict), filter training files so that
+            only files such that file[key] == value are used for generating chunks
+
+        Yields
+        ------
+        chunk : dict
+            Training chunks
+        """
+
+        # indices of training files that match domain filters
+        training = self.prepared_data["audio-metadata"]["subset"] == Subsets.index(
+            "train"
+        )
+        for key, value in filters.items():
+            training &= self.prepared_data["audio-metadata"][key] == self.prepared_data[
+                "metadata"
+            ][key].index(value)
+        file_ids = np.where(training)[0]
+        # get the subset of embedding database files from training files
+        embedding_files_ids = file_ids[np.isin(file_ids, self.embedding_files_id)]
+
+        if self.num_dia_samples > 0:
+            annotated_duration = self.prepared_data["audio-annotated"][file_ids]
+            # set duration of files for the embedding part to zero, in order not to
+            # draw them for the diarization part
+            annotated_duration[embedding_files_ids] = 0.0
+
+            cum_prob_annotated_duration = np.cumsum(
+                annotated_duration / np.sum(annotated_duration)
+            )
+
+        duration = self.duration
+        batch_size = self.batch_size
+
+        # use original order for the first run on the shuffled classes list:
+        emb_task_classes = self.specifications[Subtasks.index("embedding")].classes[:]
+
+        sample_idx = 0
+        embedding_class_idx = 0
+        while True:
+            if sample_idx < self.num_dia_samples:
+                file_id, start_time = self.draw_diarization_chunk(
+                    file_ids, cum_prob_annotated_duration, rng, duration
+                )
+            else:
+                # shuffle embedding classes list and go through this shuffled list
+                # to make sure to see all the speakers during training
+                if embedding_class_idx == len(emb_task_classes):
+                    rng.shuffle(emb_task_classes)
+                    embedding_class_idx = 0
+                klass = emb_task_classes[embedding_class_idx]
+                embedding_class_idx += 1
+                file_id, start_time = self.draw_embedding_chunk(klass, duration)
+
+            sample = self.prepare_chunk(file_id, start_time, duration)
+            sample_idx = (sample_idx + 1) % batch_size
+
+            yield sample
+
+    def train__iter__(self):
+        """Iterate over training samples
+
+        Yields
+        ------
+        dict:
+            x: (time, channel)
+                Audio chunks.
+            task: "diarization" or "embedding"
+            y: target speaker label for speaker embedding task,
+                (frame, ) frame-level targets for speaker diarization task.
+                Note that frame < time.
+                `frame` is inferred automatically from the example model output.
+        """
+
+        # create worker-specific random number generator
+        rng = create_rng_for_worker(self.model)
+
+        balance = getattr(self, "balance", None)
+        if balance is None:
+            chunks = self.train__iter__helper(rng)
+        else:
+            # create one chunk generator per combination of balanced metadata values
+            subchunks = dict()
+            for product in itertools.product(
+                *[self.prepared_data["metadata-values"][key] for key in balance]
+            ):
+                filters = {key: value for key, value in zip(balance, product)}
+                subchunks[product] = self.train__iter__helper(rng, **filters)
+
+        while True:
+            # select one subchunk generator at random (with uniform probability)
+            # so that it is balanced on average
+            if balance is not None:
+                chunks = subchunks[rng.choice(list(subchunks))]
+
+            # generate random chunk
+            yield next(chunks)
+
+    def val__getitem__(self, idx) -> Dict:
+        """Validation items are generated so that all samples in a batch come from the same
+        validation file. These samples are created by sliding a window over the first seconds of
+        the validation file, with a step (for now arbitrarily) set to 0.2 (20% of the task duration,
+        e.g. 
1 second for a duration of 5 seconds)""" + + file_idx = idx // self.batch_size + chunk_idx = idx % self.batch_size + + file_id = self.prepared_data["validation-files"][file_idx] + file = next( + itertools.islice(self.protocol.development(), file_idx, file_idx + 1) + ) + + file_duration = file.get( + "duration", Audio("downmix").get_duration(file["audio"]) + ) + start_time = chunk_idx * ( + (file_duration - self.duration) / (self.batch_size - 1) + ) + + chunk = self.prepare_chunk(file_id, start_time, self.duration) + + if chunk_idx == 0: + chunk["annotation"] = file["annotation"] + + chunk["start_time"] = start_time + + return chunk + + def val__len__(self): + return len(self.prepared_data["validation-files"]) * self.batch_size + + def collate_y(self, batch) -> torch.Tensor: + """ + Parameters + ---------- + batch : list + List of samples to collate. + "y" field is expected to be a SlidingWindowFeature. + + Returns + ------- + y : torch.Tensor + Collated target tensor of shape (num_frames, self.max_speakers_per_chunk) + If one chunk has more than `self.max_speakers_per_chunk` speakers, we keep + the max_speakers_per_chunk most talkative ones. If it has less, we pad with + zeros (artificial inactive speakers). + """ + + collated_y_dia = [] + collate_y_emb = [] + + for b in batch: + # diarization reference + y_dia = b["y"].data + labels = b["y"].labels + num_speakers = len(labels) + # embedding reference + y_emb = np.full((self.max_speakers_per_chunk,), -1, dtype=int) + + if num_speakers > self.max_speakers_per_chunk: + # sort speakers in descending talkativeness order + indices = np.argsort(-np.sum(y_dia, axis=0), axis=0) + # keep only the most talkative speakers + y_dia = y_dia[:, indices[: self.max_speakers_per_chunk]] + # TODO: we should also sort the speaker labels in the same way + + # if current chunck is for the embedding subtask + if b["meta"]["scope"] > 1: + labels = np.array(labels) + y_emb = labels[indices[: self.max_speakers_per_chunk]] + + elif num_speakers < self.max_speakers_per_chunk: + # create inactive speakers by zero padding + y_dia = np.pad( + y_dia, + ((0, 0), (0, self.max_speakers_per_chunk - num_speakers)), + mode="constant", + ) + if b["meta"]["scope"] > 1: + y_emb[:num_speakers] = labels[:] + + else: + if b["meta"]["scope"] > 1: + y_emb[:num_speakers] = labels[:] + + collated_y_dia.append(y_dia) + collate_y_emb.append(y_emb) + + return ( + torch.from_numpy(np.stack(collated_y_dia)), + torch.from_numpy(np.stack(collate_y_emb)).squeeze(1), + ) + + def collate_fn(self, batch, stage="train"): + """Collate function used for most segmentation tasks + + This function does the following: + * stack waveforms into a (batch_size, num_channels, num_samples) tensor batch["X"]) + * apply augmentation when in "train" stage + * convert targets into a (batch_size, num_frames, num_classes) tensor batch["y"] + * collate any other keys that might be present in the batch using pytorch default_collate function + + Parameters + ---------- + batch : list of dict + List of training samples. + + Returns + ------- + batch : dict + Collated batch as {"X": torch.Tensor, "y": torch.Tensor} dict. 
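+
+        Note
+        ----
+        For this joint task, the collated batch built below actually exposes
+        "y_dia" and "y_emb" keys (rather than a single "y" key), plus "meta"
+        and, in the "val" stage, "annotation" and "start_times". A rough sketch
+        of the expected structure (shapes are illustrative, assuming mono audio,
+        batch_size=32 and max_speakers_per_chunk=3):
+
+            {
+                "X": ...,      # (32, 1, num_samples) waveforms
+                "y_dia": ...,  # (32, num_frames, 3) diarization targets
+                "y_emb": ...,  # (32, 3) global-scope label indices (-1 when invalid)
+                "meta": ...,   # collated metadata
+            }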
+ """ + + # collate X + collated_X = self.collate_X(batch) + # collate y + collated_y_dia, collate_y_emb = self.collate_y(batch) + + # collate metadata + collated_meta = self.collate_meta(batch) + + # apply augmentation (only in "train" stage) + self.augmentation.train(mode=(stage == "train")) + augmented = self.augmentation( + samples=collated_X, + sample_rate=self.model.hparams.sample_rate, + targets=collated_y_dia.unsqueeze(1), + ) + collated_batch = { + "X": augmented.samples, + "y_dia": augmented.targets.squeeze(1), + "y_emb": collate_y_emb, + "meta": collated_meta, + } + + if stage == "val": + collated_batch["annotation"] = batch[0]["annotation"] + collated_batch["start_times"] = [b["start_time"] for b in batch] + + return collated_batch + + def setup_loss_func(self): + self.model.arc_face_loss = ArcFaceLoss( + len(self.specifications[Subtasks.index("embedding")].classes), + self.model.hparams["embedding_dim"], + margin=self.margin, + scale=self.scale, + ) + + def segmentation_loss( + self, + permutated_prediction: torch.Tensor, + target: torch.Tensor, + weight: torch.Tensor = None, + ) -> torch.Tensor: + """Permutation-invariant segmentation loss + + Parameters + ---------- + permutated_prediction : (batch_size, num_frames, num_classes) torch.Tensor + Permutated speaker activity predictions. + target : (batch_size, num_frames, num_speakers) torch.Tensor + Speaker activity. + weight : (batch_size, num_frames, 1) torch.Tensor, optional + Frames weight. + + Returns + ------- + seg_loss : torch.Tensor + Permutation-invariant segmentation loss + """ + + # `clamp_min` is needed to set non-speech weight to 1. + class_weight = ( + torch.clamp_min(self.model.powerset.cardinality, 1.0) + if self.weigh_by_cardinality + else None + ) + seg_loss = nll_loss( + permutated_prediction, + torch.argmax(target, dim=-1), + class_weight=class_weight, + weight=weight, + ) + + return seg_loss + + def compute_diarization_loss(self, prediction, permutated_target): + """Compute loss for the speaker diarization subtask + + Parameters + ---------- + prediction : torch.Tensor + speaker diarization output predicted by the model for the current batch. + Shape of (batch_size, num_spk, num_frames) + permutated_target: torch.Tensor + permutated target for the current batch. Shape of (batch_size, num_spk, num_frames) + + Returns + ------- + dia_loss : torch.Tensor + Permutation-invariant diarization loss + """ + + # Compute segmentation loss + dia_loss = self.segmentation_loss(prediction, permutated_target) + self.model.log( + "loss/train/dia", + dia_loss, + on_step=False, + on_epoch=True, + prog_bar=True, + logger=True, + ) + return dia_loss + + def compute_embedding_loss(self, emb_prediction, target_emb, valid_embs): + """Compute loss for the speaker embeddings extraction subtask + + Parameters + ---------- + emb_prediction : torch.Tensor + speaker embeddings predicted by the model for the current batch. 
+ Shape of (batch_size * num_spk, embedding_dim) + target_emb : torch.Tensor + target embeddings for the current batch + Shape of (batch_size * num_spk,) + Returns + ------- + emb_loss : torch.Tensor + arcface loss for the current batch + """ + + # Get speaker representations from the embedding subtask + embeddings = rearrange(emb_prediction, "b s e -> (b s) e") + # Get corresponding target label + targets = rearrange(target_emb, "b s -> (b s)") + # compute loss only on global scope speaker embedding + valid_embs = rearrange(valid_embs, "b s -> (b s)") + # compute the loss + emb_loss = self.model.arc_face_loss( + embeddings[valid_embs, :], targets[valid_embs] + ) + + if torch.any(valid_embs): + emb_loss = (1.0 / torch.sum(valid_embs)) * emb_loss + + # skip batch if something went wrong for some reason + if torch.isnan(emb_loss): + return None + + self.model.log( + "loss/train/arcface", + emb_loss, + on_step=False, + on_epoch=True, + prog_bar=True, + logger=True, + ) + return emb_loss + + def training_step(self, batch, batch_idx: int): + """Compute loss for the joint task + + Parameters + ---------- + batch : (usually) dict of torch.Tensor + current batch. + batch_idx: int + Batch index. + + Returns + ------- + loss : {str: torch.tensor} + {"loss": loss} + """ + + # batch waveforms (batch_size, num_channels, num_samples) + waveform = batch["X"] + # batch diarization references (batch_size, num_channels, num_speakers) + target_dia = batch["y_dia"] + # batch embedding references (batch, num_speakers) + target_emb = batch["y_emb"] + + # drop samples that contain too many speakers + num_speakers = torch.sum(torch.any(target_dia, dim=1), dim=1) + keep = num_speakers <= self.max_speakers_per_chunk + + target_dia = target_dia[keep] + target_emb = target_emb[keep] + waveform = waveform[keep] + + num_remaining_dia_samples = torch.sum(keep[: self.num_dia_samples]) + + # corner case + if not keep.any(): + return None + + # forward pass + dia_prediction, emb_prediction = self.model(waveform) + # (batch_size, num_frames, num_cls), (batch_size, num_spk, emb_size) + + # get the best permutation + dia_multilabel = self.model.powerset.to_multilabel(dia_prediction) + permutated_target_dia, permut_map = permutate(dia_multilabel, target_dia) + permutated_target_emb = target_emb[ + torch.arange(target_emb.shape[0]).unsqueeze(1), permut_map + ] + + # an embedding is valid only if corresponding speaker is active in the diarization prediction and reference + active_speaker_pred = torch.any(dia_multilabel > 0, dim=1) + active_speaker_ref = torch.any(permutated_target_dia == 1, dim=1) + valid_embs = torch.logical_and(active_speaker_pred, active_speaker_ref)[ + num_remaining_dia_samples: + ] + + permutated_target_powerset = self.model.powerset.to_powerset( + permutated_target_dia.float() + ) + + dia_prediction = dia_prediction[:num_remaining_dia_samples] + permutated_target_powerset = permutated_target_powerset[ + :num_remaining_dia_samples + ] + + dia_loss = torch.tensor(0) + # if batch contains diarization subtask chunks, then compute diarization loss on these chunks: + if self.alpha != 0.0 and torch.any(keep[: self.num_dia_samples]): + dia_loss = self.compute_diarization_loss( + dia_prediction, permutated_target_powerset + ) + + emb_loss = torch.tensor(0) + # if batch contains embedding subtask chunks, then compute embedding loss on these chunks: + if self.alpha != 1.0 and torch.any(valid_embs): + emb_prediction = emb_prediction[num_remaining_dia_samples:] + permutated_target_emb = 
permutated_target_emb[num_remaining_dia_samples:] + emb_loss = self.compute_embedding_loss( + emb_prediction, permutated_target_emb, valid_embs + ) + loss = self.alpha * dia_loss + (1 - self.alpha) * emb_loss + else: + loss = self.alpha * dia_loss + + return {"loss": loss} + + def reconstruct( + self, + segmentations: SlidingWindowFeature, + clusters: np.ndarray, + ) -> SlidingWindowFeature: + """Build final discrete diarization out of clustered segmentation + + Parameters + ---------- + segmentations : (num_chunks, num_frames, num_speakers) SlidingWindowFeature + Raw speaker segmentation. + hard_clusters : (num_chunks, num_speakers) array + Output of clustering step. + count : (total_num_frames, 1) SlidingWindowFeature + Instantaneous number of active speakers. + + Returns + ------- + discrete_diarization : SlidingWindowFeature + Discrete (0s and 1s) diarization. + """ + + num_chunks, num_frames, _ = segmentations.data.shape + num_clusters = np.max(clusters) + 1 + clustered_segmentations = np.nan * np.zeros( + (num_chunks, num_frames, num_clusters) + ) + + for c, (cluster, (chunk, segmentation)) in enumerate( + zip(clusters, segmentations) + ): + # cluster is (local_num_speakers, )-shaped + # segmentation is (num_frames, local_num_speakers)-shaped + for k in np.unique(cluster): + if k == -2: + continue + + # TODO: can we do better than this max here? + clustered_segmentations[c, :, k] = np.max( + segmentation[:, cluster == k], axis=1 + ) + + clustered_segmentations = SlidingWindowFeature( + clustered_segmentations, segmentations.sliding_window + ) + return clustered_segmentations + + def aggregate(self, segmentations: SlidingWindowFeature, pad_duration:float) -> SlidingWindowFeature: + num_chunks, num_frames, num_speakers = segmentations.data.shape + sliding_window = segmentations.sliding_window + frame_duration = sliding_window.duration / num_frames + + if num_chunks <= 1: + return segmentations[0] + + num_padding_frames = np.round( + pad_duration / frame_duration + ).astype(np.uint32) + aggregated_segmentation = segmentations[0] + + for chunk_segmentation in segmentations[1:]: + padding = np.zeros((num_padding_frames, num_speakers)) + aggregated_segmentation = np.concatenate( + (aggregated_segmentation, padding, chunk_segmentation) + ) + return SlidingWindowFeature(aggregated_segmentation.astype(np.int8), SlidingWindow(step=frame_duration, duration=frame_duration)) + + def to_diarization( + self, + segmentations: SlidingWindowFeature, + pad_duration: float = 0., + ) -> SlidingWindowFeature: + """Build diarization out of preprocessed segmentation and precomputed speaker count + + Parameters + ---------- + segmentations : SlidingWindowFeature + (num_chunks, num_frames, num_speakers)-shaped segmentations + count : SlidingWindow_feature + (num_frames, 1)-shaped speaker count + + Returns + ------- + discrete_diarization : SlidingWindowFeature + Discrete (0s and 1s) diarization. 
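+
+        Note
+        ----
+        A rough sketch of the binarization performed below (values are
+        illustrative): the per-frame speaker count is obtained by summing the
+        aggregated activations, and only the `count` highest activations are
+        kept per frame. For a frame with activations [0.9, 0.2, 0.7] and a
+        count of 2, the resulting binary row is [1, 0, 1].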
+ """ + + activations = self.aggregate(segmentations, pad_duration=pad_duration) + # shape: (num_frames, num_speakers) + _, num_speakers = activations.data.shape + + count = np.sum(activations, axis=1, keepdims=True) + # shape: (num_frames, 1) + + max_speakers_per_frame = np.max(count.data) + if num_speakers < max_speakers_per_frame: + activations.data = np.pad( + activations.data, ((0, 0), (0, max_speakers_per_frame - num_speakers)) + ) + + extent = activations.extent & count.extent + activations = activations.crop(extent, return_data=False) + count = count.crop(extent, return_data=False) + + sorted_speakers = np.argsort(-activations, axis=-1) + binary = np.zeros_like(activations.data) + + for t, ((_, c), speakers) in enumerate(zip(count, sorted_speakers)): + for i in range(c.item()): + binary[t, speakers[i]] = 1.0 + + return SlidingWindowFeature(binary, activations.sliding_window) + + def compute_metric( + self, + reference: Annotation, + hypothesis: Tuple[SlidingWindowFeature, np.ndarray], + pad_duration: float, + ): + """Compute diarization annotation from binarized segmentation and cluster (num_chunk, num_speaker)""" + frames = self.model.receptive_field + binarized_segmentations, clusters = hypothesis + + # keep track of inactive speakers + inactive_speakers = np.sum(binarized_segmentations.data, axis=1) == 0 + # shape: (num_chunks, num_speakers) + clusters[inactive_speakers] = -2 + + clustered_segmentations = self.reconstruct( + binarized_segmentations, clusters + ) + + binarized_diarization = self.to_diarization(clustered_segmentations, pad_duration=pad_duration) + diarization = SpeakerDiarizationMixin.to_annotation(binarized_diarization) + + metric = GlobalDiarizationErrorRate() + metric(reference, diarization, detailed=True) + + result = metric[:] + metric_dict = {"der": 0.} + for component in ["false alarm", "missed detection", "confusion"]: + metric_dict[component] = (result[component] / result["total"]) + metric_dict["der"] += metric_dict[component] + + return metric_dict + + # TODO: no need to compute gradient in this method + def validation_step(self, batch, batch_idx: int): + """Compute validation loss and metric + + Parameters + ---------- + batch : dict of torch.Tensor + current batch. All chunks come from the same + file and are in chronological order + batch_idx: int + Batch index. 
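+
+        Note
+        ----
+        Roughly, the validation loop below (i) crops the reference annotation
+        to the chunks actually present in the batch, (ii) runs the model to get
+        powerset segmentations and speaker embeddings, (iii) clusters the
+        embeddings with KMeansClustering (and, for comparison, with
+        OracleClustering), and (iv) logs the DER and oracle DER components
+        computed by `compute_metric`.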
+ """ + + # get reference + reference = batch["annotation"] + num_speakers = len(reference.labels()) + + frames = self.model.receptive_field + + start_times = batch["start_times"] + + file_id = batch["meta"]["file"][0] + file = self.get_file(file_id) + file["annotation"] = reference + + assert reference.uri in file["audio"] + + # build support timeline from chunk segments + support = Timeline() + for start_time in start_times: + support.add(Segment(start_time, start_time + self.duration)) + + # keep reference only on chunk segments: + reference = reference.crop(support) + # corner case where no reference segments intersects the timeline + if len(reference) == 0: + return None + + waveform = batch["X"] + #shape: (num_chunks, num_channels, local_num_samples) + + # segmentation + embeddings extraction step + segmentations, embeddings = self.model(waveform) + # shapes: (num_chunks, num_frames, powerset_classes), (num_chunks, local_num_speakers, embed_dim) + + if self.batch_size > 1: + step = batch["start_times"][1] - batch["start_times"][0] + else: + step = self.duration + + sliding_window = SlidingWindow( + start=batch["start_times"][0], duration=self.duration, step=step + ) + + binarized_segmentations = self.model.powerset.to_multilabel(segmentations) + + binarized_segmentations = binarized_segmentations.cpu().detach().numpy() + binarized_segmentations = SlidingWindowFeature( + binarized_segmentations, sliding_window + ) + + embeddings = embeddings.cpu().detach().numpy() + + # clustering step + clustering = KMeansClustering() + hard_clusters, _, _ = clustering( + embeddings=embeddings, + segmentations=binarized_segmentations, + num_clusters=num_speakers, + ) + oracle_clustering = OracleClustering() + oracle_hard_clusters, _, _ = oracle_clustering( + segmentations=binarized_segmentations, + file=file, + frames=self.model.receptive_field.step, + ) + + pad_duration = step - self.duration + der = self.compute_metric( + reference=reference, + hypothesis=(binarized_segmentations, hard_clusters), + pad_duration=pad_duration, + ) + + oder = self.compute_metric( + reference=reference, + hypothesis=(binarized_segmentations, oracle_hard_clusters), + pad_duration=pad_duration, + ) + + for key in der: + self.model.log( + f"BS={self.batch_size}-Duration={self.duration}s/DER/{key}", + der[key], + on_step=False, + on_epoch=True, + prog_bar=True, + logger=True, + ) + + self.model.log( + f"BS={self.batch_size}-Duration={self.duration}s/ODER/{key}", + oder[key], + on_step=False, + on_epoch=True, + prog_bar=True, + logger=True, + ) + + return None + + def default_metric( + self, + ) -> Union[Metric, Sequence[Metric], Dict[str, Metric]]: + """Returns diarization error rate and its components for diarization subtask, + and equal error rate for the embedding part + """ + return { + "DiarizationErrorRate": DiarizationErrorRate(0.5), + "DiarizationErrorRate/Confusion": SpeakerConfusionRate(0.5), + "DiarizationErrorRate/Miss": MissedDetectionRate(0.5), + "DiarizationErrorRate/FalseAlarm": FalseAlarmRate(0.5), + } diff --git a/pyannote/audio/tasks/segmentation/multilabel.py b/pyannote/audio/tasks/segmentation/multilabel.py index 9184121c4..03c242d4f 100644 --- a/pyannote/audio/tasks/segmentation/multilabel.py +++ b/pyannote/audio/tasks/segmentation/multilabel.py @@ -87,6 +87,8 @@ class MultiLabelSegmentation(SegmentationTask): metric : optional Validation metric(s). Can be anything supported by torchmetrics.MetricCollection. Defaults to AUROC (area under the ROC curve). 
+ cache_path : str, optional + path to file where to write or load task caches """ def __init__( @@ -103,6 +105,7 @@ def __init__( pin_memory: bool = False, augmentation: Optional[BaseWaveformTransform] = None, metric: Union[Metric, Sequence[Metric], Dict[str, Metric]] = None, + cache_path: Optional[Union[str, None]] = None, ): if not isinstance(protocol, SegmentationProtocol): raise ValueError( diff --git a/pyannote/audio/tasks/segmentation/overlapped_speech_detection.py b/pyannote/audio/tasks/segmentation/overlapped_speech_detection.py index 89d299a8d..d8d6a3365 100644 --- a/pyannote/audio/tasks/segmentation/overlapped_speech_detection.py +++ b/pyannote/audio/tasks/segmentation/overlapped_speech_detection.py @@ -96,6 +96,8 @@ class OverlappedSpeechDetection(SegmentationTask): metric : optional Validation metric(s). Can be anything supported by torchmetrics.MetricCollection. Defaults to AUROC (area under the ROC curve). + cache_path : str, optional + path to file where to write or load task caches """ OVERLAP_DEFAULTS = {"probability": 0.5, "snr_min": 0.0, "snr_max": 10.0} diff --git a/pyannote/audio/tasks/segmentation/speaker_diarization.py b/pyannote/audio/tasks/segmentation/speaker_diarization.py index fb0b9b979..2c736922f 100644 --- a/pyannote/audio/tasks/segmentation/speaker_diarization.py +++ b/pyannote/audio/tasks/segmentation/speaker_diarization.py @@ -117,6 +117,8 @@ class SpeakerDiarization(SegmentationTask): metric : optional Validation metric(s). Can be anything supported by torchmetrics.MetricCollection. Defaults to AUROC (area under the ROC curve). + cache_path : str, optional + path to file where to write or load task caches References ---------- diff --git a/pyannote/audio/tasks/segmentation/voice_activity_detection.py b/pyannote/audio/tasks/segmentation/voice_activity_detection.py index e52613aeb..4a5481426 100644 --- a/pyannote/audio/tasks/segmentation/voice_activity_detection.py +++ b/pyannote/audio/tasks/segmentation/voice_activity_detection.py @@ -81,6 +81,8 @@ class VoiceActivityDetection(SegmentationTask): metric : optional Validation metric(s). Can be anything supported by torchmetrics.MetricCollection. Defaults to AUROC (area under the ROC curve). + cache_path : str, optional + path to file where to write or load task caches """ def __init__( @@ -96,6 +98,7 @@ def __init__( pin_memory: bool = False, augmentation: Optional[BaseWaveformTransform] = None, metric: Union[Metric, Sequence[Metric], Dict[str, Metric]] = None, + cache_path: Optional[Union[str, None]] = None, ): super().__init__( protocol, diff --git a/tests/test_train.py b/tests/test_train.py index 6a7a6c69b..b69f5a1f3 100644 --- a/tests/test_train.py +++ b/tests/test_train.py @@ -21,6 +21,7 @@ # SOFTWARE. 
+from pathlib import Path
 from tempfile import mkstemp
 
 import pytest
@@ -37,6 +38,8 @@
     VoiceActivityDetection,
 )
 
+CACHE_FILE_PATH = "./cache/cache_file"
+
 
 @pytest.fixture()
 def protocol():
@@ -247,6 +250,21 @@ def test_finetune_freeze_with_task_that_needs_setup_for_specs_and_with_cache(
     trainer.fit(model)
 
 
+def test_finetune_freeze_with_task_that_needs_setup_for_specs_and_with_cache(protocol):
+    segmentation = SpeakerDiarization(protocol, cache_path=CACHE_FILE_PATH)
+    model = SimpleSegmentationModel(task=segmentation)
+    trainer = Trainer(fast_dev_run=True, accelerator="cpu")
+    trainer.fit(model)
+
+    segmentation = SpeakerDiarization(protocol)
+    model.task = segmentation
+    model.freeze_up_to("mfcc")
+    trainer = Trainer(fast_dev_run=True, accelerator="cpu")
+    trainer.fit(model)
+
+    Path(CACHE_FILE_PATH).unlink(missing_ok=True)
+
+
 def test_finetune_freeze_with_task_that_does_not_need_setup_for_specs(protocol):
     vad = VoiceActivityDetection(protocol)
     model = SimpleSegmentationModel(task=vad)
@@ -276,6 +294,23 @@ def test_finetune_freeze_with_task_that_does_not_need_setup_for_specs_and_with_c
     trainer.fit(model)
 
 
+def test_finetune_freeze_with_task_that_does_not_need_setup_for_specs_and_with_cache(
+    protocol,
+):
+    vad = VoiceActivityDetection(protocol, cache_path=CACHE_FILE_PATH)
+    model = SimpleSegmentationModel(task=vad)
+    trainer = Trainer(fast_dev_run=True, accelerator="cpu")
+    trainer.fit(model)
+
+    vad = VoiceActivityDetection(protocol, cache_path=CACHE_FILE_PATH)
+    model.task = vad
+    model.freeze_up_to("mfcc")
+    trainer = Trainer(fast_dev_run=True, accelerator="cpu")
+    trainer.fit(model)
+
+    Path(CACHE_FILE_PATH).unlink(missing_ok=True)
+
+
 def test_transfer_freeze_with_task_that_does_not_need_setup_for_specs(protocol):
     segmentation = SpeakerDiarization(protocol)
     model = SimpleSegmentationModel(task=segmentation)