feat(model): add segmentation model based on self-supervised representation #1362
Merged
Commits (25). The diff below shows changes from 4 of the 25 commits:
6d3af2e add WaVLM-Base model to PyanNet.py in replacement of SincNet (SevKod)
d03906b implement wavlm inside PyanNet class and add wavlm block (SevKod)
3fc2d37 add support of all Torchaudio self-supverised models to PyanNet, incl… (SevKod)
1e370fc add support of ssl models from huggingface to pyannote using PyanHugg… (SevKod)
e170eed remove support for sincnet block (SevKod)
9f81c30 Remove unnecessary computation for unused deeper layers. (SevKod)
e5330fc add support for fairseq pretrained ssl models (SevKod)
7a21fc9 fairseq dependency only used if needed (SevKod)
6243f91 Merge branch 'develop' into PyanNetWavLM (hbredin)
328505c Remove unnecessary computation for unused deeper layers (regarding a … (SevKod)
d4ddd53 Merge branch 'PyanNetWavLM' of github.com:SevKod/pyannote-audio into … (SevKod)
63a9e42 Merge branch 'develop' into PyanNetWavLM (hbredin)
cbd01a3 Remove HuggingFace and fairseq dependencies from self-sup (SevKod)
f608eb7 Merge branch 'PyanNetWavLM' of github.com:SevKod/pyannote-audio into … (SevKod)
d7e9203 add support for torchaudio self sup models (SevKod)
81aafdd fixed bug condition of wavlm_base and wavlm_large (SevKod)
b9c89b6 add layer-wise pooling and finetuning (still wip) (SevKod)
8aba20e Merge branch 'develop' into PyanNetWavLM (hbredin)
4a8bfe2 Merge branch 'develop' into PyanNetWavLM (hbredin)
2323105 Merge branch 'develop' into PyanNetWavLM (hbredin)
cedf042 feat: add SSeRiouSS architecture (hbredin)
06641bf chore: remove old PyanSup (hbredin)
31d08a4 chore: remove now replaced SelfSupModel block (hbredin)
421ba03 doc: update changelog (hbredin)
5f9211c Merge branch 'develop' into PyanNetWavLM (hbredin)
@@ -0,0 +1,70 @@
# MIT License
#
# Copyright (c) 2020 CNRS
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoConfig, AutoModel, Wav2Vec2FeatureExtractor


class SelfSupModel(nn.Module):
    """Frozen self-supervised feature extraction block (HuggingFace models)."""

    def __init__(self, model: str, layer: Optional[int] = None, cache: Optional[str] = None):
        super().__init__()
        self.model = model
        print(f"\nThe selected self-supervised model from HuggingFace is {model}.\n")

        # Overwrite the class name with that of the selected model
        SelfSupModel.__name__ = model.rsplit("/", 1)[-1]

        if cache is not None:
            print(f"Model and configuration file location is: {cache}")
            config = AutoConfig.from_pretrained(model, cache_dir=cache)
            config.cache_dir = cache
        else:
            config = AutoConfig.from_pretrained(model)

        # Expose the hidden states of every transformer layer
        config.output_hidden_states = True

        # Load the model and freeze it (weights are only used for inference)
        self.ssl_model = AutoModel.from_pretrained(model, config=config, cache_dir=cache)
        self.ssl_model.eval()

        self.feat_size = config.hidden_size  # encoder feature size
        self.processor = Wav2Vec2FeatureExtractor.from_pretrained(model, return_tensors="pt")

        if layer is None:
            print("\nLayer number not specified. Defaulting to the first one (layer 0).\n")
            self.layer = 0
        else:
            self.layer = layer
            print(f"\nSelected frozen layer is {layer}.\n")

    def forward(self, waveforms: torch.Tensor) -> torch.Tensor:
        # waveforms : (batch, channel, sample) -> (batch, sample)
        waveforms = torch.squeeze(waveforms, 1)
        if self.processor.do_normalize:
            waveforms = F.layer_norm(waveforms, waveforms.shape)

        # Compute the features with frozen weights
        with torch.no_grad():
            features = self.ssl_model(waveforms)

        # hidden_states[0] is the output of the convolutional front-end,
        # so transformer layer `layer` lives at index `layer + 1`
        return features.hidden_states[self.layer + 1]
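
A minimal usage sketch of the SelfSupModel block above (not part of the diff). The shapes assume the "microsoft/wavlm-base" checkpoint, whose encoder outputs 768-dimensional features at a ~20 ms frame step; running it downloads the checkpoint from HuggingFace:

import torch

# frozen feature extractor reading transformer layer 4
block = SelfSupModel("microsoft/wavlm-base", layer=4, cache=None)

# two mono clips of 2 seconds at 16 kHz : (batch, channel, sample)
waveforms = torch.randn(2, 1, 32000)
features = block(waveforms)

# WavLM-Base strides by 320 samples, so roughly 99 frames per 2-second clip
print(features.shape)  # expected: torch.Size([2, 99, 768])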
@@ -0,0 +1,211 @@
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from pyannote.core.utils.generators import pairwise

from pyannote.audio.core.model import Model
from pyannote.audio.core.task import Task
from pyannote.audio.models.blocks.selfsup import SelfSupModel
from pyannote.audio.models.blocks.sincnet import SincNet
from pyannote.audio.utils.params import merge_dict


class PyanHugg(Model):
    """PyanHugg segmentation model

    Self-supervised model (or SincNet if specified) > LSTM > Feed forward > Classifier

    All HuggingFace self-supervised models can be found at https://huggingface.co/models
    Tested (and currently working) models are:
    - "microsoft/wavlm-base"
    - "microsoft/wavlm-large"
    - "facebook/hubert-base-ls960"
    - "facebook/wav2vec2-base-960h"

    Parameters
    ----------
    sample_rate : int, optional
        Audio sample rate. Defaults to 16kHz (16000).
    num_channels : int, optional
        Number of channels. Defaults to mono (1).
    selfsupervised : dict, optional
        Keyword arguments passed to the self-supervised block.
        Defaults to {"model": "microsoft/wavlm-base", "layer": 4, "cache": None}.
        If "model" is set to "sincnet", the SincNet block is used instead.
    sincnet : dict, optional
        Keyword arguments passed to the SincNet block.
        Defaults to {"stride": 10}.
    lstm : dict, optional
        Keyword arguments passed to the LSTM layer.
        Defaults to {"hidden_size": 128, "num_layers": 2, "bidirectional": True},
        i.e. two bidirectional layers with 128 units each.
        Set "monolithic" to False to split monolithic multi-layer LSTM into
        multiple mono-layer LSTMs. This may prove useful for probing LSTM internals.
    linear : dict, optional
        Keyword arguments used to initialize linear layers.
        Defaults to {"hidden_size": 128, "num_layers": 2},
        i.e. two linear layers with 128 units each.
    """

    SINCNET_DEFAULTS = {"stride": 10}
    SSL_DEFAULTS = {
        "model": "microsoft/wavlm-base",
        "layer": 4,
        "cache": None,
    }
    LSTM_DEFAULTS = {
        "hidden_size": 128,
        "num_layers": 2,
        "bidirectional": True,
        "monolithic": True,
        "dropout": 0.0,
    }
    LINEAR_DEFAULTS = {"hidden_size": 128, "num_layers": 2}

    def __init__(
        self,
        selfsupervised: Optional[dict] = None,
        sincnet: Optional[dict] = None,
        lstm: Optional[dict] = None,
        linear: Optional[dict] = None,
        sample_rate: int = 16000,
        num_channels: int = 1,
        task: Optional[Task] = None,
    ):
        super().__init__(sample_rate=sample_rate, num_channels=num_channels, task=task)

        selfsupervised = merge_dict(self.SSL_DEFAULTS, selfsupervised)
        sincnet = merge_dict(self.SINCNET_DEFAULTS, sincnet)
        sincnet["sample_rate"] = sample_rate
        lstm = merge_dict(self.LSTM_DEFAULTS, lstm)
        lstm["batch_first"] = True
        linear = merge_dict(self.LINEAR_DEFAULTS, linear)

        if selfsupervised["model"] == "sincnet":
            self.save_hyperparameters("sincnet", "lstm", "linear")
        else:
            self.save_hyperparameters("selfsupervised", "lstm", "linear")

        self.model = selfsupervised["model"]

        # All HuggingFace self-supervised models can be found at https://huggingface.co/models
        print("\n##################################################################")
        if selfsupervised["model"] != "sincnet":
            print("### A self-supervised model is used for the feature extraction ###")
            print("##################################################################")
            self.selfsupervised = SelfSupModel(**self.hparams.selfsupervised)
            feat_size = self.selfsupervised.feat_size
        else:
            print("###  The SincNet module is used for the feature extraction  ###")
            self.sincnet = SincNet(**self.hparams.sincnet)
            feat_size = 60

        print("##################################################################\n")

        monolithic = lstm["monolithic"]
        if monolithic:
            multi_layer_lstm = dict(lstm)
            del multi_layer_lstm["monolithic"]
            self.lstm = nn.LSTM(feat_size, **multi_layer_lstm)
        else:
            num_layers = lstm["num_layers"]
            if num_layers > 1:
                self.dropout = nn.Dropout(p=lstm["dropout"])

            one_layer_lstm = dict(lstm)
            one_layer_lstm["num_layers"] = 1
            one_layer_lstm["dropout"] = 0.0
            del one_layer_lstm["monolithic"]

            self.lstm = nn.ModuleList(
                [
                    nn.LSTM(
                        feat_size
                        if i == 0
                        else lstm["hidden_size"] * (2 if lstm["bidirectional"] else 1),
                        **one_layer_lstm,
                    )
                    for i in range(num_layers)
                ]
            )

        if linear["num_layers"] < 1:
            return

        lstm_out_features: int = self.hparams.lstm["hidden_size"] * (
            2 if self.hparams.lstm["bidirectional"] else 1
        )
        self.linear = nn.ModuleList(
            [
                nn.Linear(in_features, out_features)
                for in_features, out_features in pairwise(
                    [lstm_out_features]
                    + [self.hparams.linear["hidden_size"]]
                    * self.hparams.linear["num_layers"]
                )
            ]
        )

    def build(self):
        if self.hparams.linear["num_layers"] > 0:
            in_features = self.hparams.linear["hidden_size"]
        else:
            in_features = self.hparams.lstm["hidden_size"] * (
                2 if self.hparams.lstm["bidirectional"] else 1
            )

        if self.specifications.powerset:
            out_features = self.specifications.num_powerset_classes
        else:
            out_features = len(self.specifications.classes)

        self.classifier = nn.Linear(in_features, out_features)
        self.activation = self.default_activation()

    def forward(self, waveforms: torch.Tensor) -> torch.Tensor:
        """Pass forward

        Parameters
        ----------
        waveforms : (batch, channel, sample)

        Returns
        -------
        scores : (batch, frame, classes)
        """
        if self.model != "sincnet":
            # self-supervised features already come as (batch, frame, feature)
            outputs = self.selfsupervised(waveforms)
        else:
            # SincNet outputs (batch, feature, frame)
            outputs = self.sincnet(waveforms)

        if self.hparams.lstm["monolithic"]:
            if self.model != "sincnet":
                outputs, _ = self.lstm(outputs)
            else:
                outputs, _ = self.lstm(
                    rearrange(outputs, "batch feature frame -> batch frame feature")
                )
        else:
            if self.model == "sincnet":
                outputs = rearrange(outputs, "batch feature frame -> batch frame feature")
            for i, lstm in enumerate(self.lstm):
                outputs, _ = lstm(outputs)
                if i + 1 < self.hparams.lstm["num_layers"]:
                    outputs = self.dropout(outputs)

        if self.hparams.linear["num_layers"] > 0:
            for linear in self.linear:
                outputs = F.leaky_relu(linear(outputs))

        return self.activation(self.classifier(outputs))
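
For reference, a hedged configuration sketch for PyanHugg (not part of the diff; the import path below is an assumption, since the file location is not shown in this view). Note that the classifier and activation are only created by build(), which pyannote.audio calls once a task with specifications is attached, so a full forward pass additionally requires a configured task:

# from pyannote.audio.models.segmentation import PyanHugg  # import path assumed

# WavLM-Base features from frozen transformer layer 4, then LSTM + linear head
model = PyanHugg(
    selfsupervised={"model": "microsoft/wavlm-base", "layer": 4, "cache": None},
    lstm={"hidden_size": 128, "num_layers": 2, "bidirectional": True},
    linear={"hidden_size": 128, "num_layers": 2},
)

# or fall back to the original SincNet front-end
sincnet_model = PyanHugg(selfsupervised={"model": "sincnet"}, sincnet={"stride": 10})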
Review comment: I would remove support for SincNet completely to avoid any confusion.