diff --git a/pyannote/audio/pipelines/speaker_diarization.py b/pyannote/audio/pipelines/speaker_diarization.py
index 5ab8cbc6e..597127b41 100644
--- a/pyannote/audio/pipelines/speaker_diarization.py
+++ b/pyannote/audio/pipelines/speaker_diarization.py
@@ -547,7 +547,7 @@ def apply(
         # during counting, we could possibly overcount the number of instantaneous
         # speakers due to segmentation errors, so we cap the maximum instantaneous number
         # of speakers by the `max_speakers` value
-        count.data = np.minimum(count.data, max_speakers)
+        count.data = np.minimum(count.data, max_speakers).astype(np.int8)
 
         # reconstruct discrete diarization from raw hard clusters
 
@@ -604,6 +604,14 @@ def apply(
         if not return_embeddings:
             return diarization
 
+        # The number of centroids may be smaller than the number of speakers
+        # in the annotation. This can happen if the number of active speakers
+        # obtained from `speaker_count` for some frames is larger than the number
+        # of clusters obtained from `clustering`. In this case, we append zero embeddings
+        # for extra speakers (rows only: the embedding dimension must stay unpadded)
+        if len(diarization.labels()) > centroids.shape[0]:
+            centroids = np.pad(centroids, ((0, len(diarization.labels()) - centroids.shape[0]), (0, 0)))
+
         # re-order centroids so that they match
         # the order given by diarization.labels()
         inverse_mapping = {label: index for index, label in mapping.items()}
@@ -611,11 +619,6 @@ def apply(
             [inverse_mapping[label] for label in diarization.labels()]
         ]
 
-        # FIXME: the number of centroids may be smaller than the number of speakers
-        # in the annotation. This can happen if the number of active speakers
-        # obtained from `speaker_count` for some frames is larger than the number
-        # of clusters obtained from `clustering`. Will be fixed in the future
-
         return diarization, centroids
 
     def get_metric(self) -> GreedyDiarizationErrorRate: