Fix the task "balance" option #1436

Merged 28 commits on Sep 20, 2023
28 commits
0551070
fix: raise TypeError on wrong device type in Pipeline.to and Inferenc…
chai3 Jun 8, 2023
30ddb0b
feat(task): add support for multi-task models (#1374)
hbredin Jun 12, 2023
4eb7190
fix(inference): fix multi-task inference
hbredin Jun 12, 2023
dcdfc15
feat: update FAQtory default answer
hbredin Jun 15, 2023
3363be6
improve(test): use pyannote.database.registry (#1413)
hbredin Jun 22, 2023
017c910
feat(pipeline): add `return_embeddings` option to `SpeakerDiarization…
flyingleafe Jun 23, 2023
cf0e3b3
fix: fix missed speech at the very beginning/end
hbredin Jun 27, 2023
f393546
doc: add note to self regarding cluster reassignment (#1419)
hbredin Jun 28, 2023
35be745
fix(doc): fix typo in diarization docstring
DiaaAj Jul 9, 2023
bc0920f
ci: update suggest.md (#1435)
hbredin Jul 16, 2023
6e6e6e1
fix balance
FrenchKrab Jul 17, 2023
9724c1a
update docstring
FrenchKrab Jul 17, 2023
da744fe
add comments related to balance usage in mixins
FrenchKrab Jul 17, 2023
9f7c68f
Merge branch 'develop' into fix_balance
hbredin Jul 17, 2023
ea1a1d4
fix balance random choice of generator
FrenchKrab Jul 18, 2023
f39a33d
Merge branch 'fix_balance' of github.com:FrenchKrab/pyannote-audio in…
FrenchKrab Jul 18, 2023
2632cad
fix train__iter__helper filtering
FrenchKrab Jul 18, 2023
7194929
feat: add support for WeSpeaker embeddings (#1444)
hbredin Aug 2, 2023
37b39b0
fix: fix security issue in FAQtory bot
aashish-19 Aug 7, 2023
5a7df38
Update README.md
hbredin Aug 30, 2023
2af703d
Update README.md
hbredin Aug 30, 2023
b660b1e
fix(task): fix MultiLabelSegmentation.val_monitor
FrenchKrab Sep 15, 2023
93fb800
Merge branch 'develop' into fix_balance
hbredin Sep 15, 2023
9df6944
fix(core): fix Model.example_output for embedding models
hbredin Sep 16, 2023
fc4da7c
Merge branch 'develop' into fix_balance
hbredin Sep 18, 2023
34fb96b
Merge branch 'develop' into fix_balance
hbredin Sep 20, 2023
8baba48
doc: update CHANGELOG
hbredin Sep 20, 2023
8bfd924
Merge branch 'develop' into fix_balance
hbredin Sep 20, 2023
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -51,6 +51,7 @@
 - fix(pipeline): fix support for IOBase audio
 - fix(pipeline): fix corner case with no speaker
 - fix(train): prevent metadata preparation to happen twice
+- fix(task): fix support for "balance" option
 - improve(task): shorten and improve structure of Tensorboard tags

 ### Dependencies
9 changes: 6 additions & 3 deletions pyannote/audio/tasks/segmentation/mixins.py
@@ -421,7 +421,7 @@ def train__iter__helper(self, rng: random.Random, **filters):
         # indices of training files that matches domain filters
         training = self.metadata["subset"] == Subsets.index("train")
         for key, value in filters.items():
-            training &= self.metadata[key] == value
+            training &= self.metadata[key] == self.metadata_unique_values[key].index(value)
         file_ids = np.where(training)[0]

         # turn annotated duration into a probability distribution
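The filtering change above suggests that metadata columns hold integer codes (positions in `metadata_unique_values[key]`) rather than raw values, so comparing a column against the raw value never matches. The following is a minimal, standard-library sketch of that idea. The `metadata` and `metadata_unique_values` dictionaries and both helper functions are hypothetical stand-ins, not the actual pyannote.audio data structures.

```python
# Hypothetical, simplified stand-ins for the task's metadata tables.
# Each categorical column stores an integer index into its list of unique values.
metadata_unique_values = {
    "subset": ["train", "development", "test"],
    "database": ["DIHARD", "REPERE"],
}
metadata = {
    "subset": [0, 0, 1, 0],
    "database": [0, 1, 0, 1],
}


def naive_match(key, value):
    """Pre-fix style comparison: integer codes compared with the raw value."""
    return [i for i, code in enumerate(metadata[key]) if code == value]


def matching_file_ids(subset, **filters):
    """Return ids of files in `subset` whose metadata match every filter."""
    wanted = metadata_unique_values["subset"].index(subset)
    keep = [code == wanted for code in metadata["subset"]]
    for key, value in filters.items():
        # translate the raw value into its integer code before comparing
        code = metadata_unique_values[key].index(value)
        keep = [k and (c == code) for k, c in zip(keep, metadata[key])]
    return [i for i, k in enumerate(keep) if k]


print(naive_match("database", "REPERE"))                 # []
print(matching_file_ids("train", database="REPERE"))     # [1, 3]
```

In this sketch the naive comparison silently selects no file at all, which is why a filter of this kind can fail without raising an error.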
@@ -485,16 +485,19 @@ def train__iter__(self):
             # create a subchunk generator for each combination of "balance" keys
             subchunks = dict()
             for product in itertools.product(
-                [self.metadata_unique_values[key] for key in balance]
+                *[self.metadata_unique_values[key] for key in balance]
             ):
+                # we iterate on the cartesian product of the values in metadata_unique_values
+                # eg: for balance=["database", "split"], with 2 databases and 2 splits:
+                # ("DIHARD", "A"), ("DIHARD", "B"), ("REPERE", "A"), ("REPERE", "B")
                 filters = {key: value for key, value in zip(balance, product)}
                 subchunks[product] = self.train__iter__helper(rng, **filters)

         while True:
             # select one subchunk generator at random (with uniform probability)
             # so that it is balanced on average
             if balance is not None:
-                chunks = subchunks[rng.choice(subchunks)]
+                chunks = subchunks[rng.choice(list(subchunks))]

             # generate random chunk
             yield next(chunks)
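The two one-character fixes in `train__iter__` are easy to miss, so here is a small standalone sketch of what each one changes. It uses only the standard library. The `metadata_unique_values` dictionary reuses the values from the comment in the diff, and the string "generators" are placeholders for the real chunk generators.

```python
import itertools
import random

# Hypothetical unique values, mirroring the example in the diff comment.
metadata_unique_values = {"database": ["DIHARD", "REPERE"], "split": ["A", "B"]}
balance = ["database", "split"]

value_lists = [metadata_unique_values[key] for key in balance]

# Without unpacking, product() receives a single iterable (the list of lists)
# and yields 1-tuples that each wrap a whole list of values.
without_star = list(itertools.product(value_lists))
# With unpacking, each list is its own argument and we get the cartesian product.
with_star = list(itertools.product(*value_lists))

print(without_star)  # [(['DIHARD', 'REPERE'],), (['A', 'B'],)]
print(with_star)     # [('DIHARD', 'A'), ('DIHARD', 'B'), ('REPERE', 'A'), ('REPERE', 'B')]

# Placeholder values standing in for the per-combination chunk generators.
subchunks = {combo: f"generator for {combo}" for combo in with_star}
rng = random.Random(0)

# Random.choice() indexes its argument by integer position,
# which a dict keyed by tuples does not support.
try:
    rng.choice(subchunks)
    dict_choice_failed = False
except (KeyError, TypeError):
    dict_choice_failed = True

# Converting to a list of keys gives choice() a proper sequence to index.
key = rng.choice(list(subchunks))
print(dict_choice_failed, key in subchunks)  # True True
```

Note that the 1-tuples produced without unpacking contain lists, so they could not even be used as dictionary keys in this sketch.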
10 changes: 5 additions & 5 deletions pyannote/audio/tasks/segmentation/multilabel.py
@@ -58,10 +58,10 @@ class MultiLabelSegmentation(SegmentationTaskMixin, Task):
         parts, only the remaining central part of each chunk is used for computing the
         loss during training, and for aggregating scores during inference.
         Defaults to 0. (i.e. no warm-up).
-    balance: str, optional
-        When provided, training samples are sampled uniformly with respect to that key.
-        For instance, setting `balance` to "uri" will make sure that each file will be
-        equally represented in the training samples.
+    balance: Sequence[Text], optional
+        When provided, training samples are sampled uniformly with respect to these keys.
+        For instance, setting `balance` to ["database","subset"] will make sure that each
+        database & subset combination will be equally represented in the training samples.
     weight: str, optional
         When provided, use this key to as frame-wise weight in loss function.
     batch_size : int, optional
@@ -87,7 +87,7 @@ def __init__(
         classes: Optional[List[str]] = None,
         duration: float = 2.0,
         warm_up: Union[float, Tuple[float, float]] = 0.0,
-        balance: Text = None,
+        balance: Sequence[Text] = None,
         weight: Text = None,
         batch_size: int = 32,
         num_workers: int = None,
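Since `balance` is now typed as `Sequence[Text]` and iterated key by key, a caller who still passes a bare string (the previous type) would have it iterated character by character. The sketch below illustrates that pitfall and one possible guard. `normalize_balance` is a hypothetical helper written for this example; it is not part of pyannote.audio, and the PR does not add such a check.

```python
from typing import Optional, Sequence, Text, Union


def normalize_balance(
    balance: Union[Text, Sequence[Text], None]
) -> Optional[Sequence[Text]]:
    """Hypothetical helper (not in pyannote.audio): wrap a bare string into a list."""
    if balance is None:
        return None
    if isinstance(balance, str):
        # a str is itself a sequence of characters, so it must be handled first
        return [balance]
    return list(balance)


# Iterating over a bare string yields characters, not metadata keys.
print([key for key in "database"][:3])            # ['d', 'a', 't']
print(normalize_balance("database"))              # ['database']
print(normalize_balance(("database", "subset")))  # ['database', 'subset']
```

With the signature shown in the diff, the simplest way to avoid the issue is to always pass a list, for example `balance=["database"]`.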
10 changes: 5 additions & 5 deletions pyannote/audio/tasks/segmentation/overlapped_speech_detection.py
@@ -59,10 +59,10 @@ class OverlappedSpeechDetection(SegmentationTaskMixin, Task):
         parts, only the remaining central part of each chunk is used for computing the
         loss during training, and for aggregating scores during inference.
         Defaults to 0. (i.e. no warm-up).
-    balance: str, optional
-        When provided, training samples are sampled uniformly with respect to that key.
-        For instance, setting `balance` to "uri" will make sure that each file will be
-        equally represented in the training samples.
+    balance: Sequence[Text], optional
+        When provided, training samples are sampled uniformly with respect to these keys.
+        For instance, setting `balance` to ["database","subset"] will make sure that each
+        database & subset combination will be equally represented in the training samples.
     overlap: dict, optional
         Controls how artificial chunks with overlapping speech are generated:
         - "probability" key is the probability of artificial overlapping chunks. Setting
@@ -98,7 +98,7 @@ def __init__(
         duration: float = 2.0,
         warm_up: Union[float, Tuple[float, float]] = 0.0,
         overlap: dict = OVERLAP_DEFAULTS,
-        balance: Text = None,
+        balance: Sequence[Text] = None,
         weight: Text = None,
         batch_size: int = 32,
         num_workers: int = None,
10 changes: 5 additions & 5 deletions pyannote/audio/tasks/segmentation/speaker_diarization.py
@@ -86,10 +86,10 @@ class SpeakerDiarization(SegmentationTaskMixin, Task):
         parts, only the remaining central part of each chunk is used for computing the
         loss during training, and for aggregating scores during inference.
         Defaults to 0. (i.e. no warm-up).
-    balance: str, optional
-        When provided, training samples are sampled uniformly with respect to that key.
-        For instance, setting `balance` to "database" will make sure that each database
-        will be equally represented in the training samples.
+    balance: Sequence[Text], optional
+        When provided, training samples are sampled uniformly with respect to these keys.
+        For instance, setting `balance` to ["database","subset"] will make sure that each
+        database & subset combination will be equally represented in the training samples.
     weight: str, optional
         When provided, use this key as frame-wise weight in loss function.
     batch_size : int, optional
@@ -132,7 +132,7 @@ def __init__(
         max_speakers_per_frame: int = None,
         weigh_by_cardinality: bool = False,
         warm_up: Union[float, Tuple[float, float]] = 0.0,
-        balance: Text = None,
+        balance: Sequence[Text] = None,
         weight: Text = None,
         batch_size: int = 32,
         num_workers: int = None,
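To make "equally represented" concrete, the following toy simulation contrasts drawing directly from an imbalanced pool with the two-step scheme used in `train__iter__` (pick a combination uniformly, then draw from its generator). Everything here is an assumption made for illustration: the file counts are invented, `chunk_generator` is a stand-in for the real chunk generators, and the unbalanced case samples by file count, whereas the actual task weights files by annotated duration.

```python
import itertools
import random
from collections import Counter


def chunk_generator(label):
    """Stand-in for a per-combination chunk generator: yields its label forever."""
    while True:
        yield label


# Hypothetical training pool: REPERE has far more files than DIHARD.
files = ["DIHARD"] * 10 + ["REPERE"] * 90

rng = random.Random(42)
num_samples = 10_000

# Unbalanced: draw directly from the pool, so the larger database dominates.
unbalanced = Counter(rng.choice(files) for _ in range(num_samples))

# Balanced: pick a combination uniformly first, then draw from its generator.
# With a single balance key, each combination is a 1-tuple such as ("DIHARD",).
combinations = list(itertools.product(["DIHARD", "REPERE"]))
subchunks = {combo: chunk_generator(combo) for combo in combinations}
balanced = Counter(
    next(subchunks[rng.choice(list(subchunks))]) for _ in range(num_samples)
)

print("unbalanced:", dict(unbalanced))
print("balanced:  ", dict(balanced))
```

In the balanced case each combination ends up with roughly half of the samples regardless of how many files it contains, which matches the "balanced on average" comment in the diff.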
10 changes: 5 additions & 5 deletions pyannote/audio/tasks/segmentation/voice_activity_detection.py
@@ -53,10 +53,10 @@ class VoiceActivityDetection(SegmentationTaskMixin, Task):
         parts, only the remaining central part of each chunk is used for computing the
         loss during training, and for aggregating scores during inference.
         Defaults to 0. (i.e. no warm-up).
-    balance: str, optional
-        When provided, training samples are sampled uniformly with respect to that key.
-        For instance, setting `balance` to "uri" will make sure that each file will be
-        equally represented in the training samples.
+    balance: Sequence[Text], optional
+        When provided, training samples are sampled uniformly with respect to these keys.
+        For instance, setting `balance` to ["database","subset"] will make sure that each
+        database & subset combination will be equally represented in the training samples.
     weight: str, optional
         When provided, use this key to as frame-wise weight in loss function.
     batch_size : int, optional
@@ -81,7 +81,7 @@ def __init__(
         protocol: Protocol,
         duration: float = 2.0,
         warm_up: Union[float, Tuple[float, float]] = 0.0,
-        balance: Text = None,
+        balance: Sequence[Text] = None,
         weight: Text = None,
         batch_size: int = 32,
         num_workers: int = None,