feat(model): add segmentation model based on self-supervised representation #1362
Conversation
Added the WavLM-Base model, which replaces the SincNet feature extraction model within the PyanNet architecture (loaded outside of the class from HuggingFace.co).
# Loading the model from HuggingFace (requires git lfs to load the .bin checkpoint)
# model = AutoModel.from_pretrained('/content/drive/MyDrive/PyanNet/wavlm-base')

model = AutoModel.from_pretrained('microsoft/wavlm-base')
This is definitely the reason why the training complains about a GPU/CPU mismatch. The WavLM module should be instantiated in __init__ and assigned as an attribute of the model. Read this carefully to understand why.
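A minimal sketch of that suggestion (the class and attribute names are illustrative, not the actual PyanNet code): instantiating WavLM inside __init__ registers it as a submodule, so model.to(device) and pytorch-lightning move it together with the rest of the network.

import torch.nn as nn
from transformers import AutoModel

class SegmentationModel(nn.Module):  # illustrative stand-in for PyanNet
    def __init__(self):
        super().__init__()
        # Instantiated in __init__ and assigned as an attribute: WavLM is now
        # a registered submodule and follows the model across devices.
        self.wavlm = AutoModel.from_pretrained("microsoft/wavlm-base")

    def forward(self, waveform):
        # (batch, samples) -> (batch, frames, 768) for wavlm-base
        return self.wavlm(waveform).last_hidden_state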
…including layer selection. Created a block (replacing the old WavLM one) called "selfsup.py" which loads and applies a specific SSL torchaudio model, depending on PyanNet's input parameter. The user can now also choose the specific layer to be used for feature extraction, as sketched below. Ex: seg_model = PyanNet(task=seg, model="HUBERT_BASE", layer=5) loads the "HUBERT_BASE" model and selects the 6th layer for feature extraction. If layer is not specified, the first one (layer 0) is used automatically. All available models are listed at: https://pytorch.org/audio/main/pipelines.html
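A hedged sketch of what such a block boils down to with the torchaudio pipelines API (the actual selfsup.py internals are not shown in this thread):

import torch
import torchaudio

# Resolve the bundle from its name, as in model="HUBERT_BASE" above
bundle = getattr(torchaudio.pipelines, "HUBERT_BASE")
ssl_model = bundle.get_model()

waveform = torch.randn(1, 16000)  # dummy 1-second mono signal at 16 kHz
with torch.no_grad():
    # extract_features returns one tensor per transformer layer
    features, _ = ssl_model.extract_features(waveform, num_layers=6)
layer_5 = features[5]  # 6th layer, matching layer=5 in the example above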
…class. Can use pre-trained SSL models from HuggingFace through the PyanHugg class. Tested (and working) models are:
- "microsoft/wavlm-base"
- "microsoft/wavlm-large"
- "facebook/hubert-base-ls960"
- "facebook/wav2vec2-base-960h"
The class supports model and layer selection (as well as a cache location for the downloaded model and configuration file), as sketched below. Ex: seg_model = PyanHugg(task=seg, selfsupervised={'model': 'microsoft/wavlm-base', 'layer': 2, 'cache': 'mod_location/'})
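An illustrative sketch of that HuggingFace path, under the assumption that PyanHugg selects one hidden layer from a transformers model (its actual internals are not shown here):

import torch
from transformers import AutoModel

model = AutoModel.from_pretrained("microsoft/wavlm-base", cache_dir="mod_location/")

waveform = torch.randn(1, 16000)  # dummy 1-second mono signal at 16 kHz
with torch.no_grad():
    outputs = model(waveform, output_hidden_states=True)
# hidden_states holds the transformer input followed by each encoder layer's
# output; index 2 roughly matches 'layer': 2 in the example above
features = outputs.hidden_states[2]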
lstm = merge_dict(self.LSTM_DEFAULTS, lstm)
lstm["batch_first"] = True
linear = merge_dict(self.LINEAR_DEFAULTS, linear)
if (selfsupervised["model"] == "sincnet"):
I would remove support for SincNet completely to avoid any confusion.
Can load a fairseq checkpoint from a pretrained model (which is converted to the torchaudio wav2vec2 format):
model = wav2vec2_model(**config)
model.load_state_dict(ordered_dict)  # assign the state dict to the model

if finetune:
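A self-contained sketch of that loading path, assuming the fairseq checkpoint has already been converted into a torchaudio-style config and state dict (the conversion step itself is not shown in this thread, and the file name and dictionary keys are hypothetical):

import torch
from torchaudio.models import wav2vec2_model

# Hypothetical file produced by a fairseq -> torchaudio conversion step
checkpoint = torch.load("converted_wav2vec2.pt", map_location="cpu")

model = wav2vec2_model(**checkpoint["config"])    # rebuild the architecture
model.load_state_dict(checkpoint["state_dict"])   # assign the converted weights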
Can you explain to me why this is necessary?
I am talking about switching to eval mode in the case where WavLM is frozen.
My understanding was that .eval() was relevant when certain modules, such as Dropout layers, are present in the model that we want to switch to inference mode (which is the case for WavLM). I had read this post, which recommended using both.
That said, from memory, I did not notice any change in the features between .eval() and no_grad. Apparently, .eval() also consumes more memory than no_grad... so why not remove it. It was also there to study the differences between .eval() and .train() when I wanted to see what was happening during finetuning.
model.eval() and torch.no_grad() play two very different roles.
torch.no_grad() disables gradient computation for the layers concerned, so it is useful when you want to freeze part of the network. model.eval() switches layers that behave differently during training (e.g. dropout, which randomly disables some weights, or batchnorm, which computes a running average of the data flowing through it) into inference mode to remove that randomness.
In short, you should use neither model.eval() nor model.train() to control whether or not you finetune the feature extraction part. Only use torch.no_grad() (to freeze) or not (to finetune). Switching between eval and train mode is handled automatically by pytorch-lightning during the validation and training phases.
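A minimal sketch of that recommendation (the wrapper and attribute names are illustrative, not pyannote code): the only switch between frozen and finetuned is torch.no_grad(); train/eval mode is left to pytorch-lightning.

import torch
import torch.nn as nn

class SSLFeatureExtractor(nn.Module):  # hypothetical wrapper, for illustration
    def __init__(self, ssl_model: nn.Module, finetune: bool = False):
        super().__init__()
        self.ssl_model = ssl_model
        self.finetune = finetune

    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
        if self.finetune:
            return self.ssl_model(waveform)  # gradients flow: finetuned
        with torch.no_grad():  # gradients off: frozen, no .eval()/.train() needed
            return self.ssl_model(waveform)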
if finetune:  # finetuning not working
    print("Self-supervised model is unfrozen.")
    # config['encoder_ff_interm_dropout'] = 0.3
    config['encoder_layer_norm_first'] = True
Can you explain?
The issue regarding the finetuning of WavLM seemed similar to a normalization issue that occurred during the feature extraction process. If the gradient is computed during training, validation extracts feature vectors that are almost identical across the frames of the input audio. I assume that this might be the reason why validation does not seem to improve (or change) during training (but I might be completely wrong on this...). Since this problem looked similar to the one I encountered with WavLM a few months back (which has since been fixed), where the features were also identical across frames, I tried applying a normalization step to see whether the behavior of the extracted features would change. That did not seem to be the case... It is one of the many things I tried but forgot to remove when pushing the code ^^
feat(model): add segmentation model based on self-supervised representation (#1362) Co-authored-by: Hervé BREDIN <hbredin@users.noreply.github.com>