diff --git a/pyannote/audio/models/blocks/conformer.py b/pyannote/audio/models/blocks/conformer.py
new file mode 100644
index 000000000..011a410a2
--- /dev/null
+++ b/pyannote/audio/models/blocks/conformer.py
@@ -0,0 +1,345 @@
+# MIT License
+
+# Copyright (c) 2024 BUT Speech@FIT
+
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+# Copied from https://github.com/BUTSpeechFIT/DiariZen/blob/e747106e753bb17799602b24d396f60b13da81b4/diarizen/models/module/conformer.py
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class RelativePositionalEncoding(nn.Module):
+    def __init__(self, d_model, maxlen=1000, embed_v=False):
+        super(RelativePositionalEncoding, self).__init__()
+
+        self.d_model = d_model
+        self.maxlen = maxlen
+        self.pe_k = nn.Embedding(2 * maxlen, d_model)
+        if embed_v:
+            self.pe_v = nn.Embedding(2 * maxlen, d_model)
+        self.embed_v = embed_v
+
+    def forward(self, pos_seq):
+        pos_seq.clamp_(-self.maxlen, self.maxlen - 1)
+        pos_seq = pos_seq + self.maxlen
+        if self.embed_v:
+            return self.pe_k(pos_seq), self.pe_v(pos_seq)
+        else:
+            return self.pe_k(pos_seq), None
+
+
+class MultiHeadSelfAttention(nn.Module):
+    """Multi-head self-attention layer"""
+
+    def __init__(self, n_units: int, h: int, dropout: float) -> None:
+        super().__init__()
+        self.linearQ = nn.Linear(n_units, n_units)
+        self.linearK = nn.Linear(n_units, n_units)
+        self.linearV = nn.Linear(n_units, n_units)
+        self.linearO = nn.Linear(n_units, n_units)
+
+        self.d_k = n_units // h
+        self.h = h
+        self.dropout = nn.Dropout(p=dropout)
+        self.att = None  # attention for plot
+
+    def __call__(self, x: torch.Tensor, batch_size: int, pos_k=None) -> torch.Tensor:
+        # x: (BT, F)
+        q = self.linearQ(x).reshape(batch_size, -1, self.h, self.d_k)
+        k = self.linearK(x).reshape(batch_size, -1, self.h, self.d_k)
+        v = self.linearV(x).reshape(batch_size, -1, self.h, self.d_k)
+
+        q = q.transpose(1, 2)  # (batch, head, time, d_k)
+        k = k.transpose(1, 2)  # (batch, head, time, d_k)
+        v = v.transpose(1, 2)  # (batch, head, time, d_k)
+        att_score = torch.matmul(q, k.transpose(-2, -1))
+
+        if pos_k is not None:
+            reshape_q = q.reshape(batch_size * self.h, -1, self.d_k).transpose(0, 1)
+            att_score_pos = torch.matmul(reshape_q, pos_k.transpose(-2, -1))
+            att_score_pos = att_score_pos.transpose(0, 1).reshape(
+                batch_size, self.h, pos_k.size(0), pos_k.size(1)
+            )
+            scores = (att_score + att_score_pos) / np.sqrt(self.d_k)
+        else:
+            scores = att_score / np.sqrt(self.d_k)
+
+        # scores: (B, h, T, T)
+        self.att = F.softmax(scores, dim=3)
+        p_att = self.dropout(self.att)
+        x = torch.matmul(p_att, v)
+        x = x.permute(0, 2, 1, 3).reshape(-1, self.h * self.d_k)
+        return self.linearO(x)
+
+
+class Swish(nn.Module):
+    """
+    Swish is a smooth, non-monotonic function that consistently matches or outperforms ReLU on deep networks applied
+    to a variety of challenging domains such as image classification and machine translation.
+    """
+
+    def __init__(self):
+        super(Swish, self).__init__()
+
+    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
+        return inputs * inputs.sigmoid()
+
+
+class ConformerMHA(nn.Module):
+    """
+    Conformer MultiHeadedAttention (RelMHA) module with residual connection and dropout.
+    """
+
+    def __init__(
+        self,
+        in_size: int = 256,
+        num_head: int = 4,
+        dropout: float = 0.1,
+    ) -> None:
+        super().__init__()
+        self.ln_norm = nn.LayerNorm(in_size)
+        self.mha = MultiHeadSelfAttention(n_units=in_size, h=num_head, dropout=dropout)
+        self.dropout = nn.Dropout(dropout)
+
+    def forward(self, x: torch.Tensor, pos_k=None) -> torch.Tensor:
+        """
+        x: B, T, N
+        """
+        bs, time, idim = x.shape
+        x = x.reshape(-1, idim)
+        res = x
+        x = self.ln_norm(x)
+        x = self.mha(x, bs, pos_k)
+        x = self.dropout(x)
+        x = res + x
+        x = x.reshape(bs, time, -1)
+        return x
+
+
+class PositionwiseFeedForward(nn.Module):
+    """Positionwise feed-forward layer
+    with scaled residual connection and dropout.
+    Args:
+        in_size (int): Input dimension.
+        ffn_hidden (int): The number of hidden units.
+        dropout (float): Dropout rate.
+
+    """
+
+    def __init__(self, in_size, ffn_hidden, dropout=0.1, swish=Swish()):
+        """Construct a PositionwiseFeedForward object."""
+        super(PositionwiseFeedForward, self).__init__()
+        self.ln_norm = nn.LayerNorm(in_size)
+        self.w_1 = nn.Linear(in_size, ffn_hidden)
+        self.swish = swish
+        self.dropout1 = nn.Dropout(dropout)
+        self.w_2 = nn.Linear(ffn_hidden, in_size)
+        self.dropout2 = nn.Dropout(dropout)
+
+    def forward(self, x):
+        """Forward function."""
+        res = x
+        x = self.ln_norm(x)
+        x = self.swish(self.w_1(x))
+        x = self.dropout1(x)
+        x = self.dropout2(self.w_2(x))
+
+        return res + 0.5 * x
+
+
+class ConvolutionModule(nn.Module):
+    """ConvolutionModule in Conformer model
+    with residual connection and dropout.
+
+    Args:
+        channels (int): The number of channels of conv layers.
+        kernel_size (int): Kernel size of conv layers.
+
+    """
+
+    def __init__(
+        self, channels, kernel_size=31, dropout_rate=0.1, swish=Swish(), bias=True
+    ):
+        """Construct a ConvolutionModule object."""
+        super(ConvolutionModule, self).__init__()
+        # kernel_size should be an odd number for 'SAME' padding
+        assert (kernel_size - 1) % 2 == 0
+        self.ln_norm = nn.LayerNorm(channels)
+        self.pointwise_conv1 = nn.Conv1d(
+            channels,
+            2 * channels,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+            bias=bias,
+        )
+        self.glu = nn.GLU(dim=1)
+        self.depthwise_conv = nn.Conv1d(
+            channels,
+            channels,
+            kernel_size,
+            stride=1,
+            padding=(kernel_size - 1) // 2,
+            groups=channels,
+            bias=bias,
+        )
+        self.bn_norm = nn.BatchNorm1d(channels)
+        self.swish = swish
+        self.pointwise_conv2 = nn.Conv1d(
+            channels,
+            channels,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+            bias=bias,
+        )
+        self.dropout = nn.Dropout(dropout_rate)
+
+    def forward(self, x):
+        """Compute convolution module.
+
+        Args:
+            x (torch.Tensor): Input tensor (#batch, time, channels).
+
+        Returns:
+            torch.Tensor: Output tensor (#batch, time, channels).
+ + """ + # exchange the temporal dimension and the feature dimension + res = x + x = self.ln_norm(x) + x = x.transpose(1, 2) # B, N, T + + x = self.pointwise_conv1(x) # (batch, 2*channel, dim) + x = self.glu(x) # (batch, channel, dim) + + x = self.depthwise_conv(x) + x = self.swish(self.bn_norm(x)) + x = self.dropout(self.pointwise_conv2(x)) + + return res + x.transpose(1, 2) + + +class ConformerBlock(nn.Module): + def __init__( + self, + in_size: int = 256, + ffn_hidden: int = 1024, + num_head: int = 2, + kernel_size: int = 31, + dropout: float = 0.1, + ) -> None: + super().__init__() + self.ffn1 = PositionwiseFeedForward( + in_size=in_size, ffn_hidden=ffn_hidden, dropout=dropout + ) + self.mha = ConformerMHA(in_size=in_size, num_head=num_head, dropout=dropout) + self.conv = ConvolutionModule(channels=in_size, kernel_size=kernel_size) + self.ffn2 = PositionwiseFeedForward( + in_size=in_size, ffn_hidden=ffn_hidden, dropout=dropout + ) + self.ln_norm = nn.LayerNorm(in_size) + + def forward(self, x: torch.Tensor, pos_k=None) -> torch.Tensor: + """ + Args: + x (torch.Tensor): Input tensor (#batch, time, channels). + """ + x = self.ffn1(x) + x = self.mha(x, pos_k) + x = self.conv(x) + x = self.ffn2(x) + + return self.ln_norm(x) + + +class ConformerEncoder(nn.Module): + def __init__( + self, + attention_in: int = 256, + ffn_hidden: int = 1024, + num_head: int = 4, + num_layer: int = 4, + kernel_size: int = 31, + dropout: float = 0.1, + use_posi: bool = False, + output_activate_function="ReLU", + ) -> None: + super().__init__() + + if not use_posi: + self.pos_emb = None + else: + self.pos_emb = RelativePositionalEncoding(attention_in // num_head) + + self.conformer_layer = nn.ModuleList( + [ + ConformerBlock( + in_size=attention_in, + ffn_hidden=ffn_hidden, + num_head=num_head, + kernel_size=kernel_size, + dropout=dropout, + ) + for _ in range(num_layer) + ] + ) + + # Activation function layer + if output_activate_function: + if output_activate_function == "Tanh": + self.activate_function = nn.Tanh() + elif output_activate_function == "ReLU": + self.activate_function = nn.ReLU() + elif output_activate_function == "ReLU6": + self.activate_function = nn.ReLU6() + elif output_activate_function == "LeakyReLU": + self.activate_function = nn.LeakyReLU() + elif output_activate_function == "PReLU": + self.activate_function = nn.PReLU() + elif output_activate_function == "Sigmoid": + self.activate_function = nn.Sigmoid() + else: + raise NotImplementedError( + f"Not implemented activation function {self.activate_function}" + ) + self.output_activate_function = output_activate_function + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + x (torch.Tensor): Input tensor (#batch, time, idim). 
+ """ + if self.pos_emb is not None: + x_len = x.shape[1] + pos_seq = torch.arange(0, x_len).long().to(x.device) + pos_seq = pos_seq[:, None] - pos_seq[None, :] + pos_k, _ = self.pos_emb(pos_seq) + else: + pos_k = None + + for layer in self.conformer_layer: + x = layer(x, pos_k) + if self.output_activate_function: + x = self.activate_function(x) + return x diff --git a/pyannote/audio/models/segmentation/DiariZen.py b/pyannote/audio/models/segmentation/DiariZen.py new file mode 100644 index 000000000..cf219cc2a --- /dev/null +++ b/pyannote/audio/models/segmentation/DiariZen.py @@ -0,0 +1,279 @@ +# MIT License +# +# Copyright 2024 CNRS (author: Herve Bredin, herve.bredin@irit.fr) +# Copyright 2024 Brno University of Technology (author: Jiangyu Han, ihan@fit.vut.cz) +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# Initially copied from https://github.com/BUTSpeechFIT/DiariZen/blob/e747106e753bb17799602b24d396f60b13da81b4/diarizen/models/eend/model_wavlm_conformer.py + + +import contextlib +from functools import lru_cache +from typing import Optional, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchaudio + +from pyannote.audio.core.model import Model +from pyannote.audio.core.task import Task +from pyannote.audio.models.blocks.conformer import ConformerEncoder +from pyannote.audio.utils.params import merge_dict +from pyannote.audio.utils.receptive_field import ( + conv1d_num_frames, + conv1d_receptive_field_center, + conv1d_receptive_field_size, +) + + +class DiariZen(Model): + """Architecture used in Leveraging Self-Supervised Learning for Speaker Diarization + + Parameters + ---------- + sample_rate : int, optional + Audio sample rate. Defaults to 16kHz (16000). + num_channels : int, optional + Number of channels. Defaults to mono (1). + wav2vec: dict or str, optional + Defaults to "WAVLM_BASE". + wav2vec_frozen: bool, optional + Whether to freeze wav2vec weights. Defaults to False. + wav2vec_layer: int, optional + Index of layer to use as input to the LSTM. + Defaults (-1) to use average of all layers (with learnable weights). + conformer : dict, optional + Keyword arguments passed to the Conformer layer. 
+
+    Reference
+    ---------
+    Jiangyu Han, Federico Landini, Johan Rohdin, Anna Silnova, Mireia Diez, and Lukas Burget
+    "Leveraging Self-Supervised Learning for Speaker Diarization"
+    https://arxiv.org/abs/2409.09408
+    """
+
+    WAV2VEC_DEFAULTS = "WAVLM_BASE"
+
+    CONFORMER_DEFAULTS = {
+        "attention_in": 256,
+        "ffn_hidden": 1024,
+        "num_head": 4,
+        "num_layer": 4,
+        "kernel_size": 31,
+        "dropout": 0.1,
+        "use_posi": False,
+        "output_activate_function": False,
+    }
+
+    def __init__(
+        self,
+        wav2vec: Union[dict, str] = None,
+        wav2vec_frozen: bool = False,
+        wav2vec_layer: int = -1,
+        conformer: Optional[dict] = None,
+        sample_rate: int = 16000,
+        num_channels: int = 1,
+        task: Optional[Task] = None,
+    ):
+        super().__init__(sample_rate=sample_rate, num_channels=num_channels, task=task)
+
+        wav2vec_dim: int
+        wav2vec_num_layers: int
+
+        if isinstance(wav2vec, str):
+            # `wav2vec` is one of the supported pipelines from torchaudio (e.g. "WAVLM_BASE")
+            if hasattr(torchaudio.pipelines, wav2vec):
+                bundle = getattr(torchaudio.pipelines, wav2vec)
+                if sample_rate != bundle._sample_rate:
+                    raise ValueError(
+                        f"Expected {bundle._sample_rate}Hz, found {sample_rate}Hz."
+                    )
+                wav2vec_dim = bundle._params["encoder_embed_dim"]
+                wav2vec_num_layers = bundle._params["encoder_num_layers"]
+                self.wav2vec = bundle.get_model()
+
+            # `wav2vec` is a path to a self-supervised representation checkpoint
+            else:
+                _checkpoint = torch.load(wav2vec)
+                wav2vec = _checkpoint.pop("config")
+                self.wav2vec = torchaudio.models.wav2vec2_model(**wav2vec)
+                state_dict = _checkpoint.pop("state_dict")
+                self.wav2vec.load_state_dict(state_dict)
+                wav2vec_dim = wav2vec["encoder_embed_dim"]
+                wav2vec_num_layers = wav2vec["encoder_num_layers"]
+
+        # `wav2vec` is a config dictionary understood by `wav2vec2_model`
+        # this branch is typically used by Model.from_pretrained(...)
+        elif isinstance(wav2vec, dict):
+            self.wav2vec = torchaudio.models.wav2vec2_model(**wav2vec)
+            wav2vec_dim = wav2vec["encoder_embed_dim"]
+            wav2vec_num_layers = wav2vec["encoder_num_layers"]
+
+        if wav2vec_layer < 0:
+            self.wav2vec_weights = nn.Parameter(
+                data=torch.ones(wav2vec_num_layers), requires_grad=True
+            )
+
+        conformer = merge_dict(self.CONFORMER_DEFAULTS, conformer)
+
+        self.save_hyperparameters(
+            "wav2vec", "wav2vec_frozen", "wav2vec_layer", "conformer"
+        )
+
+        self.conformer = ConformerEncoder(**conformer)
+        self.proj = nn.Linear(wav2vec_dim, conformer["attention_in"])
+        self.lnorm = nn.LayerNorm(conformer["attention_in"])
+
+    @property
+    def dimension(self) -> int:
+        """Dimension of output"""
+        if isinstance(self.specifications, tuple):
+            raise ValueError("DiariZen does not support multi-tasking.")
+
+        if self.specifications.powerset:
+            return self.specifications.num_powerset_classes
+        else:
+            return len(self.specifications.classes)
+
+    def build(self):
+        self.classifier = nn.Linear(
+            self.hparams.conformer["attention_in"], self.dimension
+        )
+        self.activation = self.default_activation()
+
+    @lru_cache
+    def num_frames(self, num_samples: int) -> int:
+        """Compute number of output frames
+
+        Parameters
+        ----------
+        num_samples : int
+            Number of input samples.
+
+        Returns
+        -------
+        num_frames : int
+            Number of output frames.
+ """ + + num_frames = num_samples + for conv_layer in self.wav2vec.feature_extractor.conv_layers: + num_frames = conv1d_num_frames( + num_frames, + kernel_size=conv_layer.kernel_size, + stride=conv_layer.stride, + padding=conv_layer.conv.padding[0], + dilation=conv_layer.conv.dilation[0], + ) + + # TODO: look at conformer.num_frames + + return num_frames + + def receptive_field_size(self, num_frames: int = 1) -> int: + """Compute size of receptive field + + Parameters + ---------- + num_frames : int, optional + Number of frames in the output signal + + Returns + ------- + receptive_field_size : int + Receptive field size. + """ + + # TODO: look at conformer receptive field size + + receptive_field_size = num_frames + for conv_layer in reversed(self.wav2vec.feature_extractor.conv_layers): + receptive_field_size = conv1d_receptive_field_size( + num_frames=receptive_field_size, + kernel_size=conv_layer.kernel_size, + stride=conv_layer.stride, + padding=conv_layer.conv.padding[0], + dilation=conv_layer.conv.dilation[0], + ) + return receptive_field_size + + def receptive_field_center(self, frame: int = 0) -> int: + """Compute center of receptive field + + Parameters + ---------- + frame : int, optional + Frame index + + Returns + ------- + receptive_field_center : int + Index of receptive field center. + """ + + # TODO: look at conformer receptive field center + + receptive_field_center = frame + for conv_layer in reversed(self.wav2vec.feature_extractor.conv_layers): + receptive_field_center = conv1d_receptive_field_center( + receptive_field_center, + kernel_size=conv_layer.kernel_size, + stride=conv_layer.stride, + padding=conv_layer.conv.padding[0], + dilation=conv_layer.conv.dilation[0], + ) + return receptive_field_center + + def forward(self, waveforms: torch.Tensor) -> torch.Tensor: + """Pass forward + + Parameters + ---------- + waveforms : (batch, channel, sample) + + Returns + ------- + scores : (batch, frame, classes) + """ + + num_layers = ( + None if self.hparams.wav2vec_layer < 0 else self.hparams.wav2vec_layer + ) + + context = ( + torch.no_grad() if self.hparams.wav2vec_frozen else contextlib.nullcontext() + ) + with context: + outputs, _ = self.wav2vec.extract_features( + waveforms.squeeze(1), num_layers=num_layers + ) + + if num_layers is None: + outputs = torch.stack(outputs, dim=-1) @ F.softmax( + self.wav2vec_weights, dim=0 + ) + else: + outputs = outputs[-1] + + outputs = self.proj(outputs) + outputs = self.lnorm(outputs) + outputs = self.conformer(outputs) + return self.activation(self.classifier(outputs)) diff --git a/pyannote/audio/models/segmentation/SSeRiouSS.py b/pyannote/audio/models/segmentation/SSeRiouSS.py index b96464ab3..a403a1df9 100644 --- a/pyannote/audio/models/segmentation/SSeRiouSS.py +++ b/pyannote/audio/models/segmentation/SSeRiouSS.py @@ -20,6 +20,7 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. +import contextlib from functools import lru_cache from typing import Optional, Union @@ -52,6 +53,8 @@ class SSeRiouSS(Model): Number of channels. Defaults to mono (1). wav2vec: dict or str, optional Defaults to "WAVLM_BASE". + wav2vec_frozen: bool, optional + Whether to freeze wav2vec weights. Defaults to False. wav2vec_layer: int, optional Index of layer to use as input to the LSTM. Defaults (-1) to use average of all layers (with learnable weights). 
@@ -81,6 +84,7 @@ class SSeRiouSS(Model):
     def __init__(
         self,
         wav2vec: Union[dict, str] = None,
+        wav2vec_frozen: bool = False,
         wav2vec_layer: int = -1,
         lstm: Optional[dict] = None,
         linear: Optional[dict] = None,
@@ -128,7 +132,9 @@ def __init__(
         lstm["batch_first"] = True
         linear = merge_dict(self.LINEAR_DEFAULTS, linear)
 
-        self.save_hyperparameters("wav2vec", "wav2vec_layer", "lstm", "linear")
+        self.save_hyperparameters(
+            "wav2vec", "wav2vec_frozen", "wav2vec_layer", "lstm", "linear"
+        )
 
         monolithic = lstm["monolithic"]
         if monolithic:
@@ -294,7 +300,10 @@ def forward(self, waveforms: torch.Tensor) -> torch.Tensor:
             None if self.hparams.wav2vec_layer < 0 else self.hparams.wav2vec_layer
         )
 
-        with torch.no_grad():
+        context = (
+            torch.no_grad() if self.hparams.wav2vec_frozen else contextlib.nullcontext()
+        )
+        with context:
             outputs, _ = self.wav2vec.extract_features(
                 waveforms.squeeze(1), num_layers=num_layers
             )
diff --git a/pyannote/audio/models/segmentation/__init__.py b/pyannote/audio/models/segmentation/__init__.py
index 9f6f5f6e3..01883f1e3 100644
--- a/pyannote/audio/models/segmentation/__init__.py
+++ b/pyannote/audio/models/segmentation/__init__.py
@@ -20,7 +20,8 @@
 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 # SOFTWARE.
 
+from .DiariZen import DiariZen
 from .PyanNet import PyanNet
 from .SSeRiouSS import SSeRiouSS
 
-__all__ = ["PyanNet", "SSeRiouSS"]
+__all__ = ["PyanNet", "SSeRiouSS", "DiariZen"]
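
Below the patch, a minimal usage sketch of the two new entry points, placed after the diff so the patch itself stays intact. The ConformerEncoder check relies only on code added above and is a plain shape/smoke test. The DiariZen lines are left as comments because they assume a pyannote database protocol and the existing SpeakerDiarization task, neither of which is defined in this diff; treat them as an illustration under those assumptions, not a tested recipe.

# --- usage sketch (reviewer illustration, not part of the patch) ---
import torch

from pyannote.audio.models.blocks.conformer import ConformerEncoder

# ConformerEncoder is shape-preserving: (batch, time, attention_in) -> (batch, time, attention_in).
encoder = ConformerEncoder(
    attention_in=256,
    ffn_hidden=1024,
    num_head=4,
    num_layer=4,
    kernel_size=31,
    use_posi=True,  # exercise the relative positional encoding path
    output_activate_function=False,
)
frames = torch.randn(2, 200, 256)  # e.g. 200 frames of WavLM features projected to 256
assert encoder(frames).shape == (2, 200, 256)

# Hypothetical end-to-end training setup (assumes a pyannote database `protocol`
# and the existing SpeakerDiarization task; values below are illustrative only):
#
#     from pyannote.audio.models.segmentation import DiariZen
#     from pyannote.audio.tasks import SpeakerDiarization
#
#     task = SpeakerDiarization(protocol, duration=8.0, max_speakers_per_chunk=4)
#     model = DiariZen(wav2vec="WAVLM_BASE", wav2vec_frozen=True, task=task)
#
# wav2vec_frozen=True wraps WavLM feature extraction in torch.no_grad(), mirroring
# the same new flag added to SSeRiouSS in this diff.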