diff --git a/pyannote/audio/pipelines/speaker_verification.py b/pyannote/audio/pipelines/speaker_verification.py index b30ea2b21..594a4823c 100644 --- a/pyannote/audio/pipelines/speaker_verification.py +++ b/pyannote/audio/pipelines/speaker_verification.py @@ -22,6 +22,7 @@ import warnings from functools import cached_property +from pathlib import Path from typing import Text, Union import numpy as np @@ -29,6 +30,8 @@ import torch.nn.functional as F import torchaudio import torchaudio.compliance.kaldi as kaldi +from huggingface_hub import hf_hub_download +from huggingface_hub.utils import RepositoryNotFoundError from torch.nn.utils.rnn import pad_sequence from pyannote.audio import Inference, Model, Pipeline @@ -395,7 +398,7 @@ class WeSpeakerPretrainedSpeakerEmbedding(BaseInference): Usage ----- - >>> get_embedding = WeSpeakerPretrainedSpeakerEmbedding("wespeaker.xxxx.onnx") + >>> get_embedding = WeSpeakerPretrainedSpeakerEmbedding("hbredin/wespeaker-voxceleb-resnet34-LM") >>> assert waveforms.ndim == 3 >>> batch_size, num_channels, num_samples = waveforms.shape >>> assert num_channels == 1 @@ -410,7 +413,7 @@ class WeSpeakerPretrainedSpeakerEmbedding(BaseInference): def __init__( self, - embedding: Text = "speechbrain/spkrec-ecapa-voxceleb", + embedding: Text = "hbredin/wespeaker-voxceleb-resnet34-LM", device: torch.device = None, ): if not ONNX_IS_AVAILABLE: @@ -420,6 +423,17 @@ def __init__( super().__init__() + if not Path(embedding).exists(): + try: + embedding = hf_hub_download( + repo_id=embedding, + filename="speaker-embedding.onnx", + ) + except RepositoryNotFoundError: + raise ValueError( + f"Could not find '{embedding}' on huggingface.co nor on local disk." + ) + self.embedding = embedding self.to(device or torch.device("cpu"))