Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/joint diarization and embedding with prepared data #1583

Draft
wants to merge 101 commits into
base: develop
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
101 commits
Select commit Hold shift + click to select a range
0551070
fix: raise TypeError on wrong device type in Pipeline.to and Inferenc…
chai3 Jun 8, 2023
30ddb0b
feat(task): add support for multi-task models (#1374)
hbredin Jun 12, 2023
4eb7190
fix(inference): fix multi-task inference
hbredin Jun 12, 2023
dcdfc15
feat: update FAQtory default answer
hbredin Jun 15, 2023
87f49f9
add draft version of the joint diarization and embedding tasks
clement-pages Jun 17, 2023
6025a80
Merge branch 'develop' of github.com:clement-pages/pyannote-audio int…
clement-pages Jun 17, 2023
58599c9
update `train__iter__helper` method of the joint task
clement-pages Jun 19, 2023
04de82f
fix `StopIteration` error
clement-pages Jun 19, 2023
d8cb598
add missing collate methods
clement-pages Jun 19, 2023
d2d6e14
remove support for non-powerset mode
clement-pages Jun 19, 2023
e58943b
remove computing of vad loss
clement-pages Jun 19, 2023
bc989cd
remove unused imports
clement-pages Jun 19, 2023
b4d0a78
fix probabilities do not sum to 1 error
clement-pages Jun 19, 2023
78718b1
attempt to fix file duration error
clement-pages Jun 20, 2023
dfdd8f3
attempt to fix negative `start_time` in embedding part
clement-pages Jun 20, 2023
1888360
add end-to-end diarization and embedding model
clement-pages Jun 20, 2023
6216d1f
update end-to-end model
clement-pages Jun 21, 2023
b42cc33
clean multi-task source code
clement-pages Jun 21, 2023
3d295dd
remove support for `SegmentationProtocol` in the multi-tasks
clement-pages Jun 21, 2023
3363be6
improve(test): use pyannote.database.registry (#1413)
hbredin Jun 22, 2023
99a7762
Set `alpha` coefficient as attribute
clement-pages Jun 23, 2023
f2a4e34
remove `diarization_database_files` attribute
clement-pages Jun 23, 2023
017c910
feat(pipeline): add `return_embeddings` option to `SpeakerDiarization…
flyingleafe Jun 23, 2023
cf0e3b3
fix: fix missed speech at the very beginning/end
hbredin Jun 27, 2023
f48b74f
add losses computation in `training_step` method
clement-pages Jun 27, 2023
f393546
doc: add note to self regarding cluster reassignment (#1419)
hbredin Jun 28, 2023
5718593
remove for loops in embedding loss computation
clement-pages Jun 28, 2023
8036572
add validation part into the multi-task
clement-pages Jul 3, 2023
aa36d7b
remove `subtask` parameter from `prepare_chunk`
clement-pages Jul 4, 2023
6617c9c
fix bugs in validation part
clement-pages Jul 4, 2023
60d5543
simplify the way embedding loss is calculated
clement-pages Jul 5, 2023
2834d3e
handle case where there is no files from diarization dataset
clement-pages Jul 6, 2023
35be745
fix(doc): fix typo in diarization docstring
DiaaAj Jul 9, 2023
5628b48
fix size issue in `collate_y` when building embedding ref
clement-pages Jul 11, 2023
c4988f4
fix condition to compute `emb_loss` in `training_step`
clement-pages Jul 12, 2023
75467f0
Merge branch 'develop' into feat/joint-diarization-and-embedding
hbredin Jul 12, 2023
78b5b04
add missing docstrings
clement-pages Jul 12, 2023
bdf3567
remove redefinitions of `collate_X` and `collate_meta`
clement-pages Jul 12, 2023
aae90a0
add missing `dia_loss` assignment
clement-pages Jul 12, 2023
d3b3efc
filter out the speaker in ref not found by diarization
clement-pages Jul 18, 2023
4289ea9
modifiy `start_time` possible values interval in `draw_embedding_chunk`
clement-pages Jul 18, 2023
e9f40a3
add V2 of `SpeakerEndToEndDiarization`
clement-pages Jul 19, 2023
3f7cb8a
Add `padding="same"` in model `Conv1d` layers
clement-pages Jul 26, 2023
0f1577d
update LSTM encoder in SPEED V2
clement-pages Jul 26, 2023
933a660
add `prepare_data` method in `Task` class
Oct 13, 2023
5257145
Merge branch 'develop' into feat/data_preparation
hbredin Oct 26, 2023
8829574
modify organisation of `pyannote` segmentation tasks
Nov 2, 2023
fa63c8a
Merge branch 'feat/data_preparation' of github.com:clement-pages/pyan…
Nov 2, 2023
be6f7ec
add two training tests
Nov 7, 2023
f447bb6
assign data directly to task in main process, in `prepare_data`
Nov 7, 2023
930deda
Merge branch 'develop' into feat/data_preparation
hbredin Nov 7, 2023
05ccc30
handle call to `Task.prepare_data` and `Task.setup` under different s…
Nov 8, 2023
44a01fe
Merge branch 'feat/data_preparation' of github.com:clement-pages/pyan…
Nov 8, 2023
4b8e8a2
add training tests using task caches
Nov 9, 2023
45918bd
update `cache_path` type and docstrings
Nov 9, 2023
980414e
fix `classes` variable used before assigment
Nov 9, 2023
a9ea07f
Merge branch 'develop' into feat/data_preparation
hbredin Nov 14, 2023
51a36f9
Merge branch 'feat/joint-diarization-and-embedding' into feat/joint-d…
Nov 15, 2023
c1fbb81
fix: fix residual merge problems
Nov 15, 2023
797a8a4
Merge branch 'pyannote:develop' into feat/data_preparation
clement-pages Nov 20, 2023
a17c2d0
Merge branch 'pyannote:develop' into feat/joint-diarization-and-embed…
clement-pages Nov 20, 2023
987e702
improve code readability
Nov 21, 2023
042dc43
improve: use `numpy` method for w/r task cache instead `pickle` (#1)
clement-pages Nov 27, 2023
5358986
Merge branch 'pyannote:develop' into feat/data_preparation
clement-pages Nov 27, 2023
0011870
improve: remove complete redefinition of `setup` in joint task
Nov 29, 2023
68763dc
Merge branch 'feat/joint-diarization-and-embedding-with-prepared-data…
Nov 29, 2023
7d78548
Merge branch 'feat/data-preparation' into feat/joint-diarization-and-…
Nov 29, 2023
6e6b62d
improve: remove duplicated attributes in `JointSpeakerDiarizationAndE…
Nov 29, 2023
e60873c
update: replace old `Task` attributes with prepared_data in joint task
Nov 29, 2023
40cc903
improve: handle multi-speaker embeddings in `example_output`
Nov 29, 2023
30ae9fb
feat: add new end-to-end model for joint speaker diarization and embe…
Nov 30, 2023
72f9916
fix: fix empty dict issue for `metadata_unique_values` in `prepared_d…
Nov 30, 2023
ecd2cb4
improve: add dynamic typing for np array in `prepare_data`
Nov 30, 2023
5e1abad
Merge branch 'feat/data-preparation' into feat/joint-diarization-and-…
Nov 30, 2023
fb6d540
improve: check matching bewteen task current protocol and cached prot…
Dec 4, 2023
3810308
remove: remove unused argument `stage` in `Task.setup`
Dec 4, 2023
f916db5
Merge branch 'feat/data-preparation' into feat/joint-diarization-and-…
Dec 4, 2023
e7da160
update: change name of attribute `database_ratio` to `dia_task_rate`
Dec 8, 2023
77ac89f
wip: attempt to fix issues encountered during training
Dec 8, 2023
ea6d06d
update: use all the `pyannet` pretrained model
Dec 8, 2023
185798d
fix: fix diarization loss calculation condition in `training_step`
Dec 8, 2023
3fef4f5
Merge branch 'develop' into feat/joint-diarization-and-embedding-with…
clement-pages May 14, 2024
9d13697
update joint task with last modifications on data preparation
clement-pages May 14, 2024
6c67fc6
update the way batches are generated in the joint task
clement-pages May 14, 2024
519db89
fix random generators
clement-pages May 14, 2024
106bfc5
delete remaining call to `example_output`
clement-pages May 14, 2024
d3326b1
update joint task `training_step`
clement-pages May 15, 2024
a36420d
fix(task): fiw wrong call to `receptive_field` in `prepare_chunk`
May 27, 2024
101f1d3
Merge branch 'develop' into feat/joint-diarization-and-embedding-with…
May 27, 2024
62fad78
update(joint task): filter out inactive speaker embeddings from loss …
May 28, 2024
8349818
allow to only compute mean or std in `StatsPool`
Jun 21, 2024
0858227
update diarization + embeddings joint task
Jun 21, 2024
ad9e435
wip: update joint model
Jun 21, 2024
aeb147f
Merge branch 'develop' into feat/joint-diarization-and-embedding-with…
clement-pages Jun 21, 2024
f484033
Merge branch 'pyannote:develop' into feat/joint-diarization-and-embed…
clement-pages Jul 1, 2024
8608a1c
wip: add pipeline working with joint model
hbredin Jul 8, 2024
1132cfc
Merge branch 'develop' into feat/joint-diarization-and-embedding-with…
clement-pages Oct 18, 2024
446c17c
Merge branch 'develop' into feat/joint-diarization-and-embedding-with…
Oct 18, 2024
e6a00b9
Merge branch 'feat/joint-diarization-and-embedding-with-prepared-data…
Oct 18, 2024
b91df8c
wip: add validation pipeline
Oct 25, 2024
5e54108
clean validation pipeline code
Oct 25, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions pyannote/audio/core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,9 @@ def receptive_field(self) -> SlidingWindow:
def prepare_data(self):
self.task.prepare_data()

def prepare_data(self):
self.task.prepare_data()

def setup(self, stage=None):
if stage == "fit":
# let the task know about the trainer (e.g for broadcasting
Expand Down
45 changes: 37 additions & 8 deletions pyannote/audio/models/blocks/pooling.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@
import torch.nn.functional as F


def _pool(sequences: torch.Tensor, weights: torch.Tensor) -> torch.Tensor:
def _pool(
sequences: torch.Tensor, weights: torch.Tensor, compute_mean: bool, compute_std:bool
) -> torch.Tensor:
"""Helper function to compute statistics pooling

Assumes that weights are already interpolated to match the number of frames
Expand All @@ -50,16 +52,24 @@ def _pool(sequences: torch.Tensor, weights: torch.Tensor) -> torch.Tensor:
weights = weights.unsqueeze(dim=1)
# (batch, 1, frames)

stats = []

v1 = weights.sum(dim=2) + 1e-8
mean = torch.sum(sequences * weights, dim=2) / v1

dx2 = torch.square(sequences - mean.unsqueeze(2))
v2 = torch.square(weights).sum(dim=2)
if compute_mean:
stats.append(mean)

if compute_std:
dx2 = torch.square(sequences - mean.unsqueeze(2))
v2 = torch.square(weights).sum(dim=2)

var = torch.sum(dx2 * weights, dim=2) / (v1 - v2 / v1 + 1e-8)
std = torch.sqrt(var)
var = torch.sum(dx2 * weights, dim=2) / (v1 - v2 / v1 + 1e-8)
std = torch.sqrt(var)

return torch.cat([mean, std], dim=1)
stats.append(std)

return torch.cat(stats, dim=1)


class StatsPool(nn.Module):
Expand All @@ -68,14 +78,33 @@ class StatsPool(nn.Module):
Compute temporal mean and (unbiased) standard deviation
and returns their concatenation.

Parameters
----------

compute_mean: bool, optional
whether to compute (and return) temporal mean.
Default to True
compute_std: bool, optional
whether to compute (and return) temporal standard deviation.
Default to True

Reference
---------
https://en.wikipedia.org/wiki/Weighted_arithmetic_mean

"""

def __init__(
self,
compute_mean: Optional[bool] = True,
computde_std: Optional[bool] = True,
):
super().__init__()
self.compute_mean = compute_mean
self.compute_std = computde_std

def forward(
self, sequences: torch.Tensor, weights: Optional[torch.Tensor] = None
self, sequences: torch.Tensor, weights: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass

Expand Down Expand Up @@ -122,7 +151,7 @@ def forward(

output = torch.stack(
[
_pool(sequences, weights[:, speaker, :])
_pool(sequences, weights[:, speaker, :], self.compute_mean, self.compute_std)
for speaker in range(num_speakers)
],
dim=1,
Expand Down
27 changes: 27 additions & 0 deletions pyannote/audio/models/joint/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# MIT License
#
# Copyright (c) 2020 CNRS
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

from .end_to_end_diarization import (
WavLMEnd2EndDiarization, WavLMEnd2EndDiarizationv2, WavLMEnd2EndDiarizationv3
)

__all__ = ["WavLMEnd2EndDiarization", "WavLMEnd2EndDiarizationv2", "WavLMEnd2EndDiarizationv3"]
Loading
Loading