Skip to content

Commit

Permalink
Merge branch 'develop' into check-task-hyperparameters-when-loading-d…
Browse files Browse the repository at this point in the history
…ata-from-cache
  • Loading branch information
hbredin authored Oct 23, 2024
2 parents 4bccb40 + a39991b commit 6af4856
Show file tree
Hide file tree
Showing 21 changed files with 3,865 additions and 49 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ jobs:
strategy:
matrix:
os: [ubuntu-latest]
python-version: [3.8, 3.9, "3.10"]
python-version: ["3.9", "3.10", "3.11"]
steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
Expand Down
31 changes: 31 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,37 @@

## develop

### New features

- feat: add support for `k-means` clustering
- feat: add `"hidden"` option to `ProgressHook`
- feat: add `FilterByNumberOfSpeakers` protocol files filter

### Fixes

- fix: fix clipping issue in speech separation pipeline ([@joonaskalda](https://github.com/joonaskalda/))

## Version 3.3.2 (2024-09-11)

### Fixes

- fix: (really) fix support for `numpy==2.x` ([@metal3d](https://github.com/metal3d/))
- doc: fix `Pipeline` docstring ([@huisman](https://github.com/huisman/))

## Version 3.3.1 (2024-06-19)

### Breaking changes

- setup: drop support for Python 3.8

### Fixes

- fix: fix support for `numpy==2.x` ([@ibevers](https://github.com/ibevers/))
- fix: fix support for `speechbrain==1.x` ([@Adel-Moumen](https://github.com/Adel-Moumen/))


## Version 3.3.0 (2024-06-14)

### TL;DR

`pyannote.audio` does [speech separation](https://hf.co/pyannote/speech-separation-ami-1.0): multi-speaker audio in, one audio channel per speaker out!
Expand Down
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
Using `pyannote.audio` open-source toolkit in production?
Using `pyannote.audio` open-source toolkit in production?
Consider switching to [pyannoteAI](https://www.pyannote.ai) for better and faster options.

# `pyannote.audio` speaker diarization toolkit
Expand Down Expand Up @@ -73,6 +73,7 @@ for turn, _, speaker in diarization.itertracks(yield_label=True):
- [First release of pyannote.audio](https://www.youtube.com/watch?v=37R_R82lfwA) / ICASSP 2020 / 8 min
- Community contributions (not maintained by the core team)
- 2024-04-05 > [Offline speaker diarization (speaker-diarization-3.1)](tutorials/community/offline_usage_speaker_diarization.ipynb) by [Simon Ottenhaus](https://github.com/simonottenhauskenbun)
- 2024-09-24 > [Evaluating `pyannote` pretrained speech separation pipelines](tutorials/community/eval_separation_pipeline.ipynb) by [Clément Pagés](https://github.com/)

## Benchmark

Expand Down
2 changes: 1 addition & 1 deletion pyannote/audio/core/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -526,7 +526,7 @@ def aggregate(
warm_up: Tuple[float, float] = (0.0, 0.0),
epsilon: float = 1e-12,
hamming: bool = False,
missing: float = np.NaN,
missing: float = np.nan,
skip_average: bool = False,
) -> SlidingWindowFeature:
"""Aggregation
Expand Down
3 changes: 2 additions & 1 deletion pyannote/audio/core/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,8 @@ def from_pretrained(
to True or to a string containing your hugginface.co authentication
token that can be obtained by running `huggingface-cli login`
cache_dir: Path or str, optional
Path to model cache directory. Defauorch/pyannote" when unset.
Path to model cache directory. Defaults to content of PYANNOTE_CACHE
environment variable, or "~/.cache/torch/pyannote" when unset.
"""

checkpoint_path = str(checkpoint_path)
Expand Down
82 changes: 77 additions & 5 deletions pyannote/audio/pipelines/clustering.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@

"""Clustering pipelines"""


import random
from enum import Enum
from typing import Optional, Tuple
Expand All @@ -35,6 +34,7 @@
from scipy.cluster.hierarchy import fcluster, linkage
from scipy.optimize import linear_sum_assignment
from scipy.spatial.distance import cdist
from sklearn.cluster import KMeans

from pyannote.audio.core.io import AudioFile
from pyannote.audio.pipelines.utils import oracle_segmentation
Expand Down Expand Up @@ -264,8 +264,8 @@ def __call__(

train_clusters = self.cluster(
train_embeddings,
min_clusters,
max_clusters,
min_clusters=min_clusters,
max_clusters=max_clusters,
num_clusters=num_clusters,
)

Expand Down Expand Up @@ -298,6 +298,8 @@ class AgglomerativeClustering(BaseClustering):
Minimum cluster size
"""

expects_num_clusters: bool = False

def __init__(
self,
metric: str = "cosine",
Expand All @@ -321,8 +323,8 @@ def __init__(
def cluster(
self,
embeddings: np.ndarray,
min_clusters: int,
max_clusters: int,
min_clusters: Optional[int] = None,
max_clusters: Optional[int] = None,
num_clusters: Optional[int] = None,
):
"""
Expand Down Expand Up @@ -471,9 +473,78 @@ def cluster(
return clusters


class KMeansClustering(BaseClustering):
"""KMeans clustering
Parameters
----------
metric : {"cosine", "euclidean"}, optional
Distance metric to use. Defaults to "cosine".
Hyper-parameters
----------------
None
"""

expects_num_clusters: bool = True

def __init__(
self,
metric: str = "cosine",
):
if metric not in ["cosine", "euclidean"]:
raise ValueError(
f"Unsupported metric: {metric}. Must be 'cosine' or 'euclidean'."
)

super().__init__(metric=metric)

def cluster(
self,
embeddings: np.ndarray,
min_clusters: Optional[int] = None,
max_clusters: Optional[int] = None,
num_clusters: Optional[int] = None,
):
"""Perform KMeans clustering
Parameters
----------
embeddings : (num_embeddings, dimension) array
Embeddings
num_clusters : int, optional
Expected number of clusters.
Returns
-------
clusters : (num_embeddings, ) array
0-indexed cluster indices.
"""

if num_clusters is None:
raise ValueError("`num_clusters` must be provided.")

num_embeddings, _ = embeddings.shape
if num_embeddings < num_clusters:
# one cluster per embedding as int
return np.arange(num_embeddings, dtype=np.int32)

# unit-normalize embeddings to use 'euclidean' distance
if self.metric == "cosine":
with np.errstate(divide="ignore", invalid="ignore"):
embeddings /= np.linalg.norm(embeddings, axis=-1, keepdims=True)

# perform Kmeans clustering
return KMeans(
n_clusters=num_clusters, n_init=3, random_state=42, copy_x=False
).fit_predict(embeddings)


class OracleClustering(BaseClustering):
"""Oracle clustering"""

expects_num_clusters: bool = True

def __call__(
self,
embeddings: Optional[np.ndarray] = None,
Expand Down Expand Up @@ -558,4 +629,5 @@ def __call__(

class Clustering(Enum):
AgglomerativeClustering = AgglomerativeClustering
KMeansClustering = KMeansClustering
OracleClustering = OracleClustering
23 changes: 19 additions & 4 deletions pyannote/audio/pipelines/speaker_diarization.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
import math
import textwrap
import warnings
from typing import Callable, Optional, Text, Union
from typing import Callable, Mapping, Optional, Text, Union

import numpy as np
import torch
Expand All @@ -45,6 +45,7 @@
SpeakerDiarizationMixin,
get_model,
)
from pyannote.audio.pipelines.utils.diarization import set_num_speakers
from pyannote.audio.utils.signal import binarize


Expand Down Expand Up @@ -177,6 +178,8 @@ def __init__(
)
self.clustering = Klustering.value(metric=metric)

self._expects_num_speakers = self.clustering.expects_num_clusters

@property
def segmentation_batch_size(self) -> int:
return self._segmentation.batch_size
Expand Down Expand Up @@ -400,7 +403,7 @@ def reconstruct(
num_chunks, num_frames, local_num_speakers = segmentations.data.shape

num_clusters = np.max(hard_clusters) + 1
clustered_segmentations = np.NAN * np.zeros(
clustered_segmentations = np.nan * np.zeros(
(num_chunks, num_frames, num_clusters)
)

Expand Down Expand Up @@ -469,12 +472,25 @@ def apply(
# setup hook (e.g. for debugging purposes)
hook = self.setup_hook(file, hook=hook)

num_speakers, min_speakers, max_speakers = self.set_num_speakers(
num_speakers, min_speakers, max_speakers = set_num_speakers(
num_speakers=num_speakers,
min_speakers=min_speakers,
max_speakers=max_speakers,
)

# when using KMeans clustering (or equivalent), the number of speakers must
# be provided alongside the audio file. also, during pipeline training, we
# infer the number of speakers from the reference annotation to avoid the
# pipeline complaining about missing number of speakers.
if self._expects_num_speakers and num_speakers is None:
if isinstance(file, Mapping) and "annotation" in file:
num_speakers = len(file["annotation"].labels())

else:
raise ValueError(
f"num_speakers must be provided when using {self.klustering} clustering"
)

segmentations = self.get_segmentations(file, hook=hook)
hook("segmentation", segmentations)
# shape: (num_chunks, num_frames, local_num_speakers)
Expand Down Expand Up @@ -515,7 +531,6 @@ def apply(
centroids = None

else:

# skip speaker embedding extraction with oracle clustering
if self.klustering == "OracleClustering" and not return_embeddings:
embeddings = None
Expand Down
12 changes: 6 additions & 6 deletions pyannote/audio/pipelines/speaker_verification.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
from pyannote.audio.pipelines.utils import PipelineModel, get_model

try:
from speechbrain.pretrained import (
from speechbrain.inference import (
EncoderClassifier as SpeechBrain_EncoderClassifier,
)

Expand Down Expand Up @@ -186,7 +186,7 @@ def __call__(

# corner case: every signal is too short
if max_len < self.min_num_samples:
return np.NAN * np.zeros((batch_size, self.dimension))
return np.nan * np.zeros((batch_size, self.dimension))

too_short = wav_lens < self.min_num_samples
wav_lens[too_short] = max_len
Expand All @@ -197,7 +197,7 @@ def __call__(
)

embeddings = embeddings.cpu().numpy()
embeddings[too_short.cpu().numpy()] = np.NAN
embeddings[too_short.cpu().numpy()] = np.nan

return embeddings

Expand Down Expand Up @@ -364,7 +364,7 @@ def __call__(

# corner case: every signal is too short
if max_len < self.min_num_samples:
return np.NAN * np.zeros((batch_size, self.dimension))
return np.nan * np.zeros((batch_size, self.dimension))

too_short = wav_lens < self.min_num_samples
wav_lens = wav_lens / max_len
Expand All @@ -377,7 +377,7 @@ def __call__(
.numpy()
)

embeddings[too_short.cpu().numpy()] = np.NAN
embeddings[too_short.cpu().numpy()] = np.nan

return embeddings

Expand Down Expand Up @@ -594,7 +594,7 @@ def __call__(

imasks = imasks > 0.5

embeddings = np.NAN * np.zeros((batch_size, self.dimension))
embeddings = np.nan * np.zeros((batch_size, self.dimension))

for f, (feature, imask) in enumerate(zip(features, imasks)):
masked_feature = feature[imask]
Expand Down
12 changes: 9 additions & 3 deletions pyannote/audio/pipelines/speech_separation.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
SpeakerDiarizationMixin,
get_model,
)
from pyannote.audio.pipelines.utils.diarization import set_num_speakers
from pyannote.audio.utils.signal import binarize


Expand Down Expand Up @@ -419,7 +420,7 @@ def reconstruct(
num_chunks, num_frames, local_num_speakers = segmentations.data.shape

num_clusters = np.max(hard_clusters) + 1
clustered_segmentations = np.NAN * np.zeros(
clustered_segmentations = np.nan * np.zeros(
(num_chunks, num_frames, num_clusters)
)

Expand All @@ -441,7 +442,6 @@ def reconstruct(
clustered_segmentations, segmentations.sliding_window
)
return clustered_segmentations
return self.to_diarization(clustered_segmentations, count)

def apply(
self,
Expand Down Expand Up @@ -490,7 +490,7 @@ def apply(
# setup hook (e.g. for debugging purposes)
hook = self.setup_hook(file, hook=hook)

num_speakers, min_speakers, max_speakers = self.set_num_speakers(
num_speakers, min_speakers, max_speakers = set_num_speakers(
num_speakers=num_speakers,
min_speakers=min_speakers,
max_speakers=max_speakers,
Expand Down Expand Up @@ -654,6 +654,12 @@ def apply(
sources.data * discrete_diarization.align(sources).data[:, :num_sources]
)

# separated sources might be scaled up/down due to SI-SDR loss used when training
# so we peak-normalize them
sources.data = sources.data / np.max(
np.abs(sources.data), axis=0, keepdims=True
)

# convert to continuous diarization
diarization = self.to_annotation(
discrete_diarization,
Expand Down
Loading

0 comments on commit 6af4856

Please sign in to comment.