From 8d25c45d2d181879a8f1f656c9230ea039922f44 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Herv=C3=A9=20BREDIN?= Date: Mon, 15 May 2023 10:46:03 +0200 Subject: [PATCH 01/55] wip: improve support for fixed duration tasks --- pyannote/audio/core/model.py | 77 ++++++++++++++----- pyannote/audio/core/task.py | 10 +-- pyannote/audio/tasks/embedding/mixins.py | 15 +--- .../audio/tasks/segmentation/multilabel.py | 5 +- .../overlapped_speech_detection.py | 2 +- .../tasks/segmentation/speaker_diarization.py | 12 +-- .../segmentation/voice_activity_detection.py | 2 +- 7 files changed, 70 insertions(+), 53 deletions(-) diff --git a/pyannote/audio/core/model.py b/pyannote/audio/core/model.py index 18b301086..da5c79fe8 100644 --- a/pyannote/audio/core/model.py +++ b/pyannote/audio/core/model.py @@ -65,13 +65,17 @@ class Introspection: Parameters ---------- min_num_samples: int - Minimum number of input samples + For fixed-duration models, expected number of input samples. + For variable-duration models, minimum number of input samples supported by + the model (i.e. model fails for smaller number of samples). min_num_frames: int - Corresponding minimum number of output frames + Corresponding number of output frames. inc_num_samples: int - Number of input samples leading to an increase of number of output frames + Number of input samples leading to an increase of number of output frames. + Has no meaning for fixed-duration models (set to 0). inc_num_frames: int Corresponding increase in number of output frames + Has no meaning for fixed-duration models (set to 0). dimension: int Output dimension sample_rate: int @@ -103,12 +107,46 @@ def __init__( self.sample_rate = sample_rate @classmethod - def from_model(cls, model: "Model", task: str = None) -> Introspection: + def from_model(cls, model: "Model") -> Introspection: + """ + + Parameters + ---------- + model : Model + """ specifications = model.specifications - if task is not None: - specifications = specifications[task] + duration = specifications.duration + min_duration = specifications.min_duration or duration + + # case 1: the model expects a fixed-duration chunk + if min_duration == duration: + num_samples = model.audio.get_num_samples(specifications.duration) + frames = model(model.example_input_array) + if specifications.resolution == Resolution.FRAME: + _, num_frames, dimension = frames.shape + return cls( + min_num_samples=num_samples, + min_num_frames=num_frames, + inc_num_samples=0, + inc_num_frames=0, + dimension=dimension, + sample_rate=model.hparams.sample_rate, + ) + elif specifications.resolution == Resolution.CHUNK: + _, dimension = frames.shape + return cls( + min_num_samples=num_samples, + min_num_frames=1, + inc_num_samples=0, + inc_num_frames=0, + dimension=dimension, + sample_rate=model.hparams.sample_rate, + ) + + # case 2: the model supports variable-duration chunks + # we use dichotomic search to find the minimum number of samples example_input_array = model.example_input_array batch_size, num_channels, num_samples = example_input_array.shape example_input_array = torch.randn( @@ -126,8 +164,6 @@ def from_model(cls, model: "Model", task: str = None) -> Introspection: try: with torch.no_grad(): frames = model(example_input_array[:, :, :num_samples]) - if task is not None: - frames = frames[task] except Exception: lower = num_samples else: @@ -175,8 +211,6 @@ def from_model(cls, model: "Model", task: str = None) -> Introspection: ) with torch.no_grad(): frames = model(example_input_array) - if task is not None: - frames = frames[task] 
num_frames = frames.shape[1] if num_frames > min_num_frames: break @@ -194,8 +228,6 @@ def from_model(cls, model: "Model", task: str = None) -> Introspection: ) with torch.no_grad(): frames = model(example_input_array) - if task is not None: - frames = frames[task] num_frames = frames.shape[1] if num_frames > min_num_frames: inc_num_frames = num_frames - min_num_frames @@ -232,6 +264,13 @@ def __call__(self, num_samples: int) -> Tuple[int, int]: Dimension of output frames """ + # case 1: the model expects a fixed-duration chunk + if self.inc_num_frames == 0: + assert num_samples == self.min_num_samples + return self.min_num_frames, self.dimension + + # case 2: the model supports variable-duration chunks + if num_samples < self.min_num_samples: return 0, self.dimension @@ -246,7 +285,14 @@ def __call__(self, num_samples: int) -> Tuple[int, int]: def frames(self) -> SlidingWindow: # HACK to support model trained before 'sample_rate' was an Introspection attribute sample_rate = getattr(self, "sample_rate", 16000) - step = (self.inc_num_samples / self.inc_num_frames) / sample_rate + + if self.inc_num_frames == 0: + step = (self.min_num_samples / self.min_num_frames) / sample_rate + else: + # FIXME: this is not 100% accurate, but it's good enough for now + # FIXME: it should probably be estimated from the maximum duration + step = (self.inc_num_samples / self.inc_num_frames) / sample_rate + return SlidingWindow(start=0.0, step=step, duration=step) @@ -368,7 +414,6 @@ def introspection(self): del self._introspection def setup(self, stage=None): - if stage == "fit": self.task.setup() @@ -421,7 +466,6 @@ def setup(self, stage=None): self.task_dependent = list(name for name, _ in after - before) def on_save_checkpoint(self, checkpoint): - # put everything pyannote.audio-specific under pyannote.audio # to avoid any future conflicts with pytorch-lightning updates checkpoint["pyannote.audio"] = { @@ -438,7 +482,6 @@ def on_save_checkpoint(self, checkpoint): } def on_load_checkpoint(self, checkpoint: Dict[str, Any]): - check_version( "pyannote.audio", checkpoint["pyannote.audio"]["versions"]["pyannote.audio"], @@ -636,7 +679,6 @@ def _helper_by_name( modules = [modules] for name, module in ModelSummary(self, max_depth=-1).named_modules: - if name not in modules: continue @@ -826,7 +868,6 @@ def from_pretrained( # HACK do not use it. Fails silently in case model does not # HACK have a config.yaml file. try: - _ = hf_hub_download( model_id, HF_LIGHTNING_CONFIG_NAME, diff --git a/pyannote/audio/core/task.py b/pyannote/audio/core/task.py index a46308f88..d02e643b4 100644 --- a/pyannote/audio/core/task.py +++ b/pyannote/audio/core/task.py @@ -72,9 +72,11 @@ class Specifications: problem: Problem resolution: Resolution - # chunk duration in seconds. - # use None for variable-length chunks - duration: Optional[float] = None + # (maximum) chunk duration in seconds + duration: float + + # (for variable-duration tasks only) minimum chunk duration in seconds + min_duration: Optional[float] = None # use that many seconds on the left- and rightmost parts of each chunk # to warm up the model. This is mostly useful for segmentation tasks. @@ -96,7 +98,6 @@ class Specifications: @cached_property def powerset(self): - if self.powerset_max_classes is None: return False @@ -302,7 +303,6 @@ def train_dataloader(self) -> DataLoader: @cached_property def logging_prefix(self): - prefix = f"{self.__class__.__name__}-" if hasattr(self.protocol, "name"): # "." 
has a special meaning for pytorch-lightning checkpointing diff --git a/pyannote/audio/tasks/embedding/mixins.py b/pyannote/audio/tasks/embedding/mixins.py index 00aa7e608..b02ae7f71 100644 --- a/pyannote/audio/tasks/embedding/mixins.py +++ b/pyannote/audio/tasks/embedding/mixins.py @@ -76,7 +76,6 @@ def batch_size(self, batch_size: int): self.batch_size_ = batch_size def setup(self, stage: Optional[str] = None): - # loop over the training set, remove annotated regions shorter than # chunk duration, and keep track of the reference annotations, per class. @@ -87,9 +86,7 @@ def setup(self, stage: Optional[str] = None): desc = f"Loading {self.protocol.name} training labels" for f in tqdm(iterable=self.protocol.train(), desc=desc, unit="file"): - for klass in f["annotation"].labels(): - # keep class's (long enough) speech turns... speech_turns = [ segment @@ -121,6 +118,7 @@ def setup(self, stage: Optional[str] = None): problem=Problem.REPRESENTATION, resolution=Resolution.CHUNK, duration=self.duration, + min_duration=self.min_duration, classes=sorted(self._train), ) @@ -133,7 +131,6 @@ def setup(self, stage: Optional[str] = None): def default_metric( self, ) -> Union[Metric, Sequence[Metric], Dict[str, Metric]]: - return [ EqualErrorRate(compute_on_cpu=True, distances=False), BinaryAUROC(compute_on_cpu=True), @@ -155,11 +152,11 @@ def train__iter__(self): classes = list(self.specifications.classes) + # select batch-wise duration at random batch_duration = rng.uniform(self.min_duration, self.duration) num_samples = 0 while True: - # shuffle classes so that we don't always have the same # groups of classes in a batch (which might be especially # problematic for contrast-based losses like contrastive @@ -167,13 +164,11 @@ def train__iter__(self): rng.shuffle(classes) for klass in classes: - # class index in original sorted order y = self.specifications.classes.index(klass) # multiple chunks per class for _ in range(self.num_chunks_per_class): - # select one file at random (with probability proportional to its class duration) file, *_ = rng.choices( self._train[klass], @@ -227,7 +222,6 @@ def train__len__(self): return max(self.batch_size, math.ceil(duration / avg_chunk_duration)) def collate_fn(self, batch, stage="train"): - collated = default_collate(batch) if stage == "train": @@ -241,7 +235,6 @@ def collate_fn(self, batch, stage="train"): return collated def training_step(self, batch, batch_idx: int): - X, y = batch["X"], batch["y"] loss = self.model.loss_func(self.model(X), y) @@ -261,7 +254,6 @@ def training_step(self, batch, batch_idx: int): return {"loss": loss} def val__getitem__(self, idx): - if isinstance(self.protocol, SpeakerVerificationProtocol): trial = self._validation[idx] @@ -291,7 +283,6 @@ def val__getitem__(self, idx): pass def val__len__(self): - if isinstance(self.protocol, SpeakerVerificationProtocol): return len(self._validation) @@ -299,9 +290,7 @@ def val__len__(self): return 0 def validation_step(self, batch, batch_idx: int): - if isinstance(self.protocol, SpeakerVerificationProtocol): - with torch.no_grad(): emb1 = self.model(batch["X1"]).detach() emb2 = self.model(batch["X2"]).detach() diff --git a/pyannote/audio/tasks/segmentation/multilabel.py b/pyannote/audio/tasks/segmentation/multilabel.py index 19270e26f..f27303e2a 100644 --- a/pyannote/audio/tasks/segmentation/multilabel.py +++ b/pyannote/audio/tasks/segmentation/multilabel.py @@ -95,7 +95,6 @@ def __init__( augmentation: BaseWaveformTransform = None, metric: Union[Metric, Sequence[Metric], Dict[str, Metric]] = 
None, ): - if not isinstance(protocol, SegmentationProtocol): raise ValueError( f"MultiLabelSegmentation task expects a SegmentationProtocol but you gave {type(protocol)}. " @@ -121,7 +120,6 @@ def __init__( # specifications to setup() def setup(self, stage: Optional[str] = None): - super().setup(stage=stage) self.specifications = Specifications( @@ -129,6 +127,7 @@ def setup(self, stage: Optional[str] = None): problem=Problem.MULTI_LABEL_CLASSIFICATION, resolution=Resolution.FRAME, duration=self.duration, + min_duration=self.min_duration, warm_up=self.warm_up, ) @@ -208,7 +207,6 @@ def prepare_chunk(self, file_id: int, start_time: float, duration: float): return sample def training_step(self, batch, batch_idx: int): - X = batch["X"] y_pred = self.model(X) y_true = batch["y"] @@ -238,7 +236,6 @@ def training_step(self, batch, batch_idx: int): return {"loss": loss} def validation_step(self, batch, batch_idx: int): - X = batch["X"] y_pred = self.model(X) y_true = batch["y"] diff --git a/pyannote/audio/tasks/segmentation/overlapped_speech_detection.py b/pyannote/audio/tasks/segmentation/overlapped_speech_detection.py index 658c350a7..8e6551447 100644 --- a/pyannote/audio/tasks/segmentation/overlapped_speech_detection.py +++ b/pyannote/audio/tasks/segmentation/overlapped_speech_detection.py @@ -106,7 +106,6 @@ def __init__( augmentation: BaseWaveformTransform = None, metric: Union[Metric, Sequence[Metric], Dict[str, Metric]] = None, ): - super().__init__( protocol, duration=duration, @@ -122,6 +121,7 @@ def __init__( problem=Problem.BINARY_CLASSIFICATION, resolution=Resolution.FRAME, duration=self.duration, + min_duration=self.min_duration, warm_up=self.warm_up, classes=[ "overlap", diff --git a/pyannote/audio/tasks/segmentation/speaker_diarization.py b/pyannote/audio/tasks/segmentation/speaker_diarization.py index 21f4416cc..3ef0b1a17 100644 --- a/pyannote/audio/tasks/segmentation/speaker_diarization.py +++ b/pyannote/audio/tasks/segmentation/speaker_diarization.py @@ -143,7 +143,6 @@ def __init__( max_num_speakers: int = None, # deprecated in favor of `max_speakers_per_chunk`` loss: Literal["bce", "mse"] = None, # deprecated ): - super().__init__( protocol, duration=duration, @@ -188,12 +187,10 @@ def __init__( self.vad_loss = vad_loss def setup(self, stage: Optional[str] = None): - super().setup(stage=stage) # estimate maximum number of speakers per chunk when not provided if self.max_speakers_per_chunk is None: - training = self.metadata["subset"] == Subsets.index("train") num_unique_speakers = [] @@ -201,7 +198,6 @@ def setup(self, stage: Optional[str] = None): for file_id in track( np.where(training)[0], description=progress_description ): - annotations = self.annotations[ np.where(self.annotations["file_id"] == file_id)[0] ] @@ -280,6 +276,7 @@ def setup(self, stage: Optional[str] = None): else Problem.MONO_LABEL_CLASSIFICATION, resolution=Resolution.FRAME, duration=self.duration, + min_duration=self.min_duration, warm_up=self.warm_up, classes=[f"speaker#{i+1}" for i in range(self.max_speakers_per_chunk)], powerset_max_classes=self.max_speakers_per_frame, @@ -448,7 +445,6 @@ def segmentation_loss( """ if self.specifications.powerset: - # `clamp_min` is needed to set non-speech weight to 1. 
class_weight = ( torch.clamp_min(self.model.powerset.cardinality, 1.0) @@ -569,7 +565,6 @@ def training_step(self, batch, batch_idx: int): weight[:, num_frames - warm_up_right :] = 0.0 if self.specifications.powerset: - powerset = torch.nn.functional.one_hot( torch.argmax(prediction, dim=-1), self.model.powerset.num_powerset_classes, @@ -602,7 +597,6 @@ def training_step(self, batch, batch_idx: int): vad_loss = 0.0 else: - # TODO: vad_loss probably does not make sense in powerset mode # because first class (empty set of labels) does exactly this... if self.specifications.powerset: @@ -704,7 +698,6 @@ def validation_step(self, batch, batch_idx: int): weight[:, num_frames - warm_up_right :] = 0.0 if self.specifications.powerset: - powerset = torch.nn.functional.one_hot( torch.argmax(prediction, dim=-1), self.model.powerset.num_powerset_classes, @@ -740,7 +733,6 @@ def validation_step(self, batch, batch_idx: int): vad_loss = 0.0 else: - # TODO: vad_loss probably does not make sense in powerset mode # because first class (empty set of labels) does exactly this... if self.specifications.powerset: @@ -833,7 +825,6 @@ def validation_step(self, batch, batch_idx: int): # plot each sample for sample_idx in range(num_samples): - # find where in the grid it should be plotted row_idx = sample_idx // nrows col_idx = sample_idx % ncols @@ -893,7 +884,6 @@ def main(protocol: str, subset: str = "test", model: str = "pyannote/segmentatio files = list(getattr(protocol, subset)()) with Progress() as progress: - main_task = progress.add_task(protocol.name, total=len(files)) file_task = progress.add_task("Processing", total=1.0) diff --git a/pyannote/audio/tasks/segmentation/voice_activity_detection.py b/pyannote/audio/tasks/segmentation/voice_activity_detection.py index 559ff24eb..4851b7455 100644 --- a/pyannote/audio/tasks/segmentation/voice_activity_detection.py +++ b/pyannote/audio/tasks/segmentation/voice_activity_detection.py @@ -89,7 +89,6 @@ def __init__( augmentation: BaseWaveformTransform = None, metric: Union[Metric, Sequence[Metric], Dict[str, Metric]] = None, ): - super().__init__( protocol, duration=duration, @@ -108,6 +107,7 @@ def __init__( problem=Problem.BINARY_CLASSIFICATION, resolution=Resolution.FRAME, duration=self.duration, + min_duration=self.min_duration, warm_up=self.warm_up, classes=[ "speech", From 840f0efb062fadd600e0d44b59e66954851cf419 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Herv=C3=A9=20BREDIN?= Date: Mon, 15 May 2023 11:37:46 +0200 Subject: [PATCH 02/55] fix: add missing get_num_samples --- pyannote/audio/core/io.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/pyannote/audio/core/io.py b/pyannote/audio/core/io.py index b2e8842b1..0a44e75ea 100644 --- a/pyannote/audio/core/io.py +++ b/pyannote/audio/core/io.py @@ -150,7 +150,6 @@ def validate_file(file: AudioFile) -> Mapping: raise ValueError(AudioFileDocString) if "waveform" in file: - waveform: Union[np.ndarray, Tensor] = file["waveform"] if len(waveform.shape) != 2 or waveform.shape[0] > waveform.shape[1]: raise ValueError( @@ -166,7 +165,6 @@ def validate_file(file: AudioFile) -> Mapping: file.setdefault("uri", "waveform") elif "audio" in file: - if isinstance(file["audio"], IOBase): return file @@ -177,7 +175,6 @@ def validate_file(file: AudioFile) -> Mapping: file.setdefault("uri", path.stem) else: - raise ValueError( "Neither 'waveform' nor 'audio' is available for this file." 
) @@ -185,7 +182,6 @@ def validate_file(file: AudioFile) -> Mapping: return file def __init__(self, sample_rate=None, mono=None): - super().__init__() self.sample_rate = sample_rate self.mono = mono @@ -257,6 +253,18 @@ def get_duration(self, file: AudioFile) -> float: return frames / sample_rate + def get_num_samples(self, duration: float, sample_rate: int = None) -> int: + """Deterministic number of samples from duration and sample rate""" + + sample_rate = sample_rate or self.sample_rate + + if sample_rate is None: + raise ValueError( + "`sample_rate` must be provided to compute number of samples." + ) + + return math.floor(duration * sample_rate) + def __call__(self, file: AudioFile) -> Tuple[Tensor, int]: """Obtain waveform @@ -359,7 +367,6 @@ def crop( num_frames = end_frame - start_frame if mode == "raise": - if num_frames > frames: raise ValueError( f"requested fixed duration ({duration:6f}s, or {num_frames:d} frames) is longer " @@ -400,7 +407,6 @@ def crop( if isinstance(file["audio"], IOBase): file["audio"].seek(0) except RuntimeError: - if isinstance(file["audio"], IOBase): msg = "torchaudio failed to seek-and-read in file-like object." raise RuntimeError(msg) From 7272caf888c1777f15999905727eb74533327792 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Herv=C3=A9=20BREDIN?= Date: Mon, 15 May 2023 13:27:27 +0200 Subject: [PATCH 03/55] fix: make window_size consistent with training duration --- pyannote/audio/core/inference.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/pyannote/audio/core/inference.py b/pyannote/audio/core/inference.py index 9243babc1..998a3df67 100644 --- a/pyannote/audio/core/inference.py +++ b/pyannote/audio/core/inference.py @@ -93,7 +93,6 @@ def __init__( batch_size: int = 32, use_auth_token: Union[Text, None] = None, ): - self.model = ( model if isinstance(model, Model) @@ -240,7 +239,7 @@ def slide( and (num_frames, dimension) for frame-level tasks. 
""" - window_size: int = round(self.duration * sample_rate) + window_size: int = self.model.audio.get_num_samples(self.duration) step_size: int = round(self.step * sample_rate) _, num_samples = waveform.shape @@ -284,7 +283,6 @@ def slide( # process orphan last chunk if has_last_chunk: - last_output = self.infer(last_chunk[None]) if specifications.resolution == Resolution.FRAME: @@ -409,7 +407,6 @@ def crop( """ if self.window == "sliding": - if not isinstance(chunk, Segment): start = min(c.start for c in chunk) end = max(c.end for c in chunk) @@ -427,7 +424,6 @@ def crop( return SlidingWindowFeature(output.data, shifted_frames) elif self.window == "whole": - if isinstance(chunk, Segment): waveform, sample_rate = self.model.audio.crop( file, chunk, duration=duration @@ -685,7 +681,6 @@ def always_match(this: np.ndarray, that: np.ndarray, cost: float): stitches = [] for C, (chunk, activation) in enumerate(activations): - local_stitch = np.NAN * np.zeros( (sum(lookahead) + 1, num_frames, num_classes) ) @@ -693,7 +688,6 @@ def always_match(this: np.ndarray, that: np.ndarray, cost: float): for c in range( max(0, C - lookahead[0]), min(num_chunks, C + lookahead[1] + 1) ): - # extract common temporal support shift = round((C - c) * num_frames * chunks.step / chunks.duration) @@ -714,7 +708,6 @@ def always_match(this: np.ndarray, that: np.ndarray, cost: float): ) for this, that in enumerate(permutation): - # only stitch under certain condiditions matching = (c == C) or ( match_func( From 5a0d1bb253f3afa363107534278a3b2b6cddc69d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Herv=C3=A9=20BREDIN?= Date: Mon, 15 May 2023 14:23:28 +0200 Subject: [PATCH 04/55] BREAKING: pad last audio chunk with zeros --- pyannote/audio/core/inference.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/pyannote/audio/core/inference.py b/pyannote/audio/core/inference.py index 998a3df67..692cf3814 100644 --- a/pyannote/audio/core/inference.py +++ b/pyannote/audio/core/inference.py @@ -27,6 +27,7 @@ import numpy as np import torch +import torch.nn.functional as F from einops import rearrange from pyannote.core import Segment, SlidingWindow, SlidingWindowFeature from pytorch_lightning.utilities.memory import is_oom_error @@ -246,11 +247,12 @@ def slide( specifications = self.model.specifications resolution = specifications.resolution introspection = self.model.introspection + if resolution == Resolution.CHUNK: frames = SlidingWindow(start=0.0, duration=self.duration, step=self.step) + elif resolution == Resolution.FRAME: frames = introspection.frames - num_frames_per_chunk, dimension = introspection(window_size) # prepare complete chunks if num_samples >= window_size: @@ -267,7 +269,11 @@ def slide( num_samples - window_size ) % step_size > 0 if has_last_chunk: + # pad last chunk with zeros last_chunk: torch.Tensor = waveform[:, num_chunks * step_size :] + _, last_window_size = last_chunk.shape + last_pad = window_size - last_window_size + last_chunk = F.pad(last_chunk, (0, last_pad)) outputs: Union[List[np.ndarray], np.ndarray] = list() @@ -284,11 +290,6 @@ def slide( # process orphan last chunk if has_last_chunk: last_output = self.infer(last_chunk[None]) - - if specifications.resolution == Resolution.FRAME: - pad = num_frames_per_chunk - last_output.shape[1] - last_output = np.pad(last_output, ((0, 0), (0, pad), (0, 0))) - outputs.append(last_output) if hook is not None: hook( @@ -326,9 +327,11 @@ def slide( missing=0.0, ) + # remove padding that was added to last chunk if 
has_last_chunk: - num_frames = aggregated.data.shape[0] - aggregated.data = aggregated.data[: num_frames - pad, :] + aggregated.data = aggregated.crop( + Segment(0.0, num_samples / sample_rate), mode="loose" + ) return aggregated From ad247501eb8f52bde73016c4b91ec73ceb329eca Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Herv=C3=A9=20BREDIN?= Date: Wed, 24 May 2023 14:46:20 +0200 Subject: [PATCH 05/55] feat(task): add support for multi-task models - BREAKING(model): get rid of (flaky) `Model.introspection` --- CHANGELOG.md | 11 +- pyannote/audio/core/inference.py | 309 ++++++++----- pyannote/audio/core/model.py | 432 +++++------------- pyannote/audio/core/task.py | 15 +- pyannote/audio/models/segmentation/PyanNet.py | 5 +- pyannote/audio/models/segmentation/debug.py | 12 +- .../pipelines/overlapped_speech_detection.py | 2 +- pyannote/audio/pipelines/resegmentation.py | 3 +- .../audio/pipelines/speaker_diarization.py | 2 +- .../audio/pipelines/speaker_verification.py | 51 ++- pyannote/audio/pipelines/utils/oracle.py | 2 +- pyannote/audio/tasks/segmentation/mixins.py | 26 +- .../audio/tasks/segmentation/multilabel.py | 18 +- .../overlapped_speech_detection.py | 17 +- .../tasks/segmentation/speaker_diarization.py | 29 +- .../segmentation/voice_activity_detection.py | 17 +- pyannote/audio/utils/multi_task.py | 59 +++ pyannote/audio/utils/powerset.py | 22 +- pyannote/audio/utils/preview.py | 4 +- 19 files changed, 474 insertions(+), 562 deletions(-) create mode 100644 pyannote/audio/utils/multi_task.py diff --git a/CHANGELOG.md b/CHANGELOG.md index cf525f958..79a2e93ab 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,10 +5,10 @@ ### Breaking changes - BREAKING(task): rename `Segmentation` task to `SpeakerDiarization` - - BREAKING(task): remove support for variable chunk duration + - BREAKING(task): remove support for variable chunk duration for segmentation tasks - BREAKING(pipeline): pipeline defaults to CPU (use `pipeline.to(device)`) - BREAKING(pipeline): remove `SpeakerSegmentation` pipeline (use `SpeakerDiarization` pipeline) - - BREAKING(pipeline): remove support `FINCHClustering` and `HiddenMarkovModelClustering` + - BREAKING(pipeline): remove support for `FINCHClustering` and `HiddenMarkovModelClustering` - BREAKING(pipeline): remove `segmentation_duration` parameter from `SpeakerDiarization` pipeline (defaults to `duration` of segmentation model) - BREAKING(setup): drop support for Python 3.7 - BREAKING(io): channels are now 0-indexed (used to be 1-indexed) @@ -17,9 +17,16 @@ * replace `Audio()` by `Audio(mono="downmix")`; * replace `Audio(mono=True)` by `Audio(mono="downmix")`; * replace `Audio(mono=False)` by `Audio()`. 
+ - BREAKING(model): get rid of (flaky) `Model.introspection` + If, for some weird reason, you wrote some custom code based on that, you should instead rely on: + * `Model.example_output(duration=...)` to get example output(s) + * `Model.output_frames` to get output frame resolution(s) + * `Model.output_dimension` to get output dimension(s) + ### Features and improvements + - feat(task): add support for multi-task models (for inference) - feat(pipeline): send pipeline to device with `pipeline.to(device)` - feat(pipeline): make `segmentation_batch_size` and `embedding_batch_size` mutable in `SpeakerDiarization` pipeline (they now default to `1`) - feat(task): add [powerset](https://arxiv.org/PLACEHOLDER) support to `SpeakerDiarization` task diff --git a/pyannote/audio/core/inference.py b/pyannote/audio/core/inference.py index 692cf3814..98f72f6e9 100644 --- a/pyannote/audio/core/inference.py +++ b/pyannote/audio/core/inference.py @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2020-2021 CNRS +# Copyright (c) 2020- CNRS # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -27,19 +27,19 @@ import numpy as np import torch +import torch.nn as nn import torch.nn.functional as F from einops import rearrange from pyannote.core import Segment, SlidingWindow, SlidingWindowFeature from pytorch_lightning.utilities.memory import is_oom_error from pyannote.audio.core.io import AudioFile -from pyannote.audio.core.model import Model +from pyannote.audio.core.model import Model, Specifications from pyannote.audio.core.task import Resolution +from pyannote.audio.utils.multi_task import map_with_specifications from pyannote.audio.utils.permutation import mae_cost_func, permutate from pyannote.audio.utils.powerset import Powerset -TaskName = Union[Text, None] - class BaseInference: pass @@ -68,10 +68,10 @@ class Inference(BaseInference): skip_aggregation : bool, optional Do not aggregate outputs when using "sliding" window. Defaults to False. skip_conversion: bool, optional - In case `model` has been trained with `powerset` mode, its output is automatically + In case a task has been trained with `powerset` mode, output is automatically converted to `multi-label`, unless `skip_conversion` is set to True. batch_size : int, optional - Batch size. Larger values make inference faster. Defaults to 32. + Batch size. Larger values (should) make inference faster. Defaults to 32. device : torch.device, optional Device used for inference. Defaults to `model.device`. In case `device` and `model.device` are different, model is sent to device. @@ -94,6 +94,8 @@ def __init__( batch_size: int = 32, use_auth_token: Union[Text, None] = None, ): + # ~~~~ model ~~~~~ + self.model = ( model if isinstance(model, Model) @@ -105,50 +107,70 @@ def __init__( ) ) - if window not in ["sliding", "whole"]: - raise ValueError('`window` must be "sliding" or "whole".') - - specifications = self.model.specifications - if specifications.resolution == Resolution.FRAME and window == "whole": - warnings.warn( - 'Using "whole" `window` inference with a frame-based model might lead to bad results ' - 'and huge memory consumption: it is recommended to set `window` to "sliding".' 
- ) - - self.window = window - self.skip_aggregation = skip_aggregation - if device is None: device = self.model.device self.device = device - self.pre_aggregation_hook = pre_aggregation_hook - self.model.eval() self.model.to(self.device) - # chunk duration used during training specifications = self.model.specifications - training_duration = specifications.duration - if duration is None: - duration = training_duration - elif training_duration != duration: + # ~~~~ sliding window ~~~~~ + + if window not in ["sliding", "whole"]: + raise ValueError('`window` must be "sliding" or "whole".') + + if window == "whole" and any( + s.resolution == Resolution.FRAME for s in specifications + ): + warnings.warn( + 'Using "whole" `window` inference with a frame-based model might lead to bad results ' + 'and huge memory consumption: it is recommended to set `window` to "sliding".' + ) + self.window = window + + training_duration = next(iter(specifications)).duration + duration = duration or training_duration + if training_duration != duration: warnings.warn( f"Model was trained with {training_duration:g}s chunks, and you requested " f"{duration:g}s chunks for inference: this might lead to suboptimal results." ) self.duration = duration - self.warm_up = specifications.warm_up + # ~~~~ powerset to multilabel conversion ~~~~ + + self.skip_conversion = skip_conversion + + conversion = list() + for s in specifications: + if s.powerset and not skip_conversion: + c = Powerset(len(s.classes), s.powerset_max_classes) + else: + c = nn.Identity() + conversion.append(c.to(self.device)) + + if isinstance(specifications, Specifications): + self.conversion = conversion[0] + else: + self.conversion = nn.ModuleList(conversion) + + # ~~~~ overlap-add aggregation ~~~~~ + + self.skip_aggregation = skip_aggregation + self.pre_aggregation_hook = pre_aggregation_hook + + self.warm_up = next(iter(specifications)).warm_up # Use that many seconds on the left- and rightmost parts of each chunk # to warm up the model. While the model does process those left- and right-most # parts, only the remaining central part of each chunk is used for aggregating # scores during inference. # step between consecutive chunks - if step is None: - step = 0.1 * self.duration if self.warm_up[0] == 0.0 else self.warm_up[0] + step = step or ( + 0.1 * self.duration if self.warm_up[0] == 0.0 else self.warm_up[0] + ) if step > self.duration: raise ValueError( @@ -159,23 +181,16 @@ def __init__( self.step = step self.batch_size = batch_size - self.skip_conversion = skip_conversion - if specifications.powerset and not self.skip_conversion: - self._powerset = Powerset( - len(specifications.classes), specifications.powerset_max_classes - ) - self._powerset.to(self.device) - def to(self, device: torch.device): + def to(self, device: torch.device) -> "Inference": """Send internal model to `device`""" self.model.to(device) - if self.model.specifications.powerset and not self.skip_conversion: - self._powerset.to(device) + self.conversion.to(device) self.device = device return self - def infer(self, chunks: torch.Tensor) -> np.ndarray: + def infer(self, chunks: torch.Tensor) -> Union[np.ndarray, Tuple[np.ndarray]]: """Forward pass Takes care of sending chunks to right device and outputs back to CPU @@ -187,11 +202,11 @@ def infer(self, chunks: torch.Tensor) -> np.ndarray: Returns ------- - outputs : (batch_size, ...) np.ndarray + outputs : (tuple of) (batch_size, ...) np.ndarray Model output. 
""" - with torch.no_grad(): + with torch.inference_mode(): try: outputs = self.model(chunks.to(self.device)) except RuntimeError as exception: @@ -203,22 +218,19 @@ def infer(self, chunks: torch.Tensor) -> np.ndarray: else: raise exception - # convert powerset to multi-label unless specifically requested not to - if self.model.specifications.powerset and not self.skip_conversion: - powerset = torch.nn.functional.one_hot( - torch.argmax(outputs, dim=-1), - self.model.specifications.num_powerset_classes, - ).float() - outputs = self._powerset.to_multilabel(powerset) + def __convert(output: torch.Tensor, conversion: nn.Module, **kwargs): + return conversion(output).cpu().numpy() - return outputs.cpu().numpy() + return map_with_specifications( + self.model.specifications, __convert, outputs, self.conversion + ) def slide( self, waveform: torch.Tensor, sample_rate: int, hook: Optional[Callable], - ) -> SlidingWindowFeature: + ) -> Union[SlidingWindowFeature, Tuple[SlidingWindowFeature]]: """Slide model on a waveform Parameters @@ -235,7 +247,7 @@ def slide( Returns ------- - output : SlidingWindowFeature + output : (tuple of) SlidingWindowFeature Model output. Shape is (num_chunks, dimension) for chunk-level tasks, and (num_frames, dimension) for frame-level tasks. """ @@ -244,15 +256,20 @@ def slide( step_size: int = round(self.step * sample_rate) _, num_samples = waveform.shape - specifications = self.model.specifications - resolution = specifications.resolution - introspection = self.model.introspection + frames = self.model.output_frames - if resolution == Resolution.CHUNK: - frames = SlidingWindow(start=0.0, duration=self.duration, step=self.step) + def __frames( + frames, specifications: Optional[Specifications] = None + ) -> SlidingWindow: + if specifications.resolution == Resolution.CHUNK: + return SlidingWindow(start=0.0, duration=self.duration, step=self.step) + return frames - elif resolution == Resolution.FRAME: - frames = introspection.frames + frames: Union[SlidingWindow, Tuple[SlidingWindow]] = map_with_specifications( + self.model.specifications, + __frames, + self.model.output_frames, + ) # prepare complete chunks if num_samples >= window_size: @@ -275,69 +292,107 @@ def slide( last_pad = window_size - last_window_size last_chunk = F.pad(last_chunk, (0, last_pad)) - outputs: Union[List[np.ndarray], np.ndarray] = list() + def __empty_list(**kwargs): + return list() + + outputs: Union[ + List[np.ndarray], Tuple[List[np.ndarray]] + ] = map_with_specifications(self.model.specifications, __empty_list) if hook is not None: hook(completed=0, total=num_chunks + has_last_chunk) + def __append_batch(output, batch_output, **kwargs) -> None: + output.append(batch_output) + return + # slide over audio chunks in batch for c in np.arange(0, num_chunks, self.batch_size): batch: torch.Tensor = chunks[c : c + self.batch_size] - outputs.append(self.infer(batch)) + + batch_outputs: Union[np.ndarray, Tuple[np.ndarray]] = self.infer(batch) + + _ = map_with_specifications( + self.model.specifications, __append_batch, outputs, batch_outputs + ) + if hook is not None: hook(completed=c + self.batch_size, total=num_chunks + has_last_chunk) # process orphan last chunk if has_last_chunk: - last_output = self.infer(last_chunk[None]) - outputs.append(last_output) + last_outputs = self.infer(last_chunk[None]) + + _ = map_with_specifications( + self.model.specifications, __append_batch, outputs, last_outputs + ) + if hook is not None: hook( completed=num_chunks + has_last_chunk, total=num_chunks + 
has_last_chunk, ) - outputs = np.vstack(outputs) - - # skip aggregation when requested, - # or when model outputs just one vector per chunk - # or when model is permutation-invariant (and not post-processed) - if ( - self.skip_aggregation - or specifications.resolution == Resolution.CHUNK - or ( - specifications.permutation_invariant - and self.pre_aggregation_hook is None - ) - ): - frames = SlidingWindow(start=0.0, duration=self.duration, step=self.step) - return SlidingWindowFeature(outputs, frames) - - if self.pre_aggregation_hook is not None: - outputs = self.pre_aggregation_hook(outputs) - - aggregated = self.aggregate( - SlidingWindowFeature( - outputs, - SlidingWindow(start=0.0, duration=self.duration, step=self.step), - ), - frames=frames, - warm_up=self.warm_up, - hamming=True, - missing=0.0, + def __vstack(output: List[np.ndarray], **kwargs) -> np.ndarray: + return np.vstack(output) + + outputs: Union[np.ndarray, Tuple[np.ndarray]] = map_with_specifications( + self.model.specifications, __vstack, outputs ) - # remove padding that was added to last chunk - if has_last_chunk: - aggregated.data = aggregated.crop( - Segment(0.0, num_samples / sample_rate), mode="loose" + def __aggregate( + outputs: np.ndarray, + frames: SlidingWindow, + specifications: Optional[Specifications] = None, + ) -> SlidingWindowFeature: + # skip aggregation when requested, + # or when model outputs just one vector per chunk + # or when model is permutation-invariant (and not post-processed) + if ( + self.skip_aggregation + or specifications.resolution == Resolution.CHUNK + or ( + specifications.permutation_invariant + and self.pre_aggregation_hook is None + ) + ): + frames = SlidingWindow( + start=0.0, duration=self.duration, step=self.step + ) + return SlidingWindowFeature(outputs, frames) + + if self.pre_aggregation_hook is not None: + outputs = self.pre_aggregation_hook(outputs) + + aggregated = self.aggregate( + SlidingWindowFeature( + outputs, + SlidingWindow(start=0.0, duration=self.duration, step=self.step), + ), + frames=frames, + warm_up=self.warm_up, + hamming=True, + missing=0.0, ) - return aggregated + # remove padding that was added to last chunk + if has_last_chunk: + aggregated.data = aggregated.crop( + Segment(0.0, num_samples / sample_rate), mode="loose" + ) + + return aggregated + + return map_with_specifications( + self.model.specifications, __aggregate, outputs, frames + ) def __call__( self, file: AudioFile, hook: Optional[Callable] = None - ) -> Union[SlidingWindowFeature, np.ndarray]: + ) -> Union[ + Tuple[Union[SlidingWindowFeature, np.ndarray]], + Union[SlidingWindowFeature, np.ndarray], + ]: """Run inference on a whole file Parameters @@ -352,7 +407,7 @@ def __call__( Returns ------- - output : SlidingWindowFeature or np.ndarray + output : (tuple of) SlidingWindowFeature or np.ndarray Model output, as `SlidingWindowFeature` if `window` is set to "sliding" and `np.ndarray` if is set to "whole". 
@@ -362,7 +417,14 @@ def __call__( if self.window == "sliding": return self.slide(waveform, sample_rate, hook=hook) - return self.infer(waveform[None])[0] + outputs: Union[np.ndarray, Tuple[np.ndarray]] = self.infer(waveform[None]) + + def __first_sample(outputs: np.ndarray, **kwargs) -> np.ndarray: + return outputs[0] + + return map_with_specifications( + self.model.specifications, __first_sample, outputs + ) def crop( self, @@ -370,7 +432,10 @@ def crop( chunk: Union[Segment, List[Segment]], duration: Optional[float] = None, hook: Optional[Callable] = None, - ) -> Union[SlidingWindowFeature, np.ndarray]: + ) -> Union[ + Tuple[Union[SlidingWindowFeature, np.ndarray]], + Union[SlidingWindowFeature, np.ndarray], + ]: """Run inference on a chunk or a list of chunks Parameters @@ -395,7 +460,7 @@ def crop( Returns ------- - output : SlidingWindowFeature or np.ndarray + output : (tuple of) SlidingWindowFeature or np.ndarray Model output, as `SlidingWindowFeature` if `window` is set to "sliding" and `np.ndarray` if is set to "whole". @@ -418,31 +483,37 @@ def crop( waveform, sample_rate = self.model.audio.crop( file, chunk, duration=duration ) - output = self.slide(waveform, sample_rate, hook=hook) - - frames = output.sliding_window - shifted_frames = SlidingWindow( - start=chunk.start, duration=frames.duration, step=frames.step - ) - return SlidingWindowFeature(output.data, shifted_frames) - - elif self.window == "whole": - if isinstance(chunk, Segment): - waveform, sample_rate = self.model.audio.crop( - file, chunk, duration=duration - ) - else: - waveform = torch.cat( - [self.model.audio.crop(file, c)[0] for c in chunk], dim=1 + outputs: Union[ + SlidingWindowFeature, Tuple[SlidingWindowFeature] + ] = self.slide(waveform, sample_rate, hook=hook) + + def __shift(output: SlidingWindowFeature, **kwargs) -> SlidingWindowFeature: + frames = output.sliding_window + shifted_frames = SlidingWindow( + start=chunk.start, duration=frames.duration, step=frames.step ) + return SlidingWindowFeature(output.data, shifted_frames) - return self.infer(waveform[None])[0] + return map_with_specifications(self.model.specifications, __shift, outputs) + if isinstance(chunk, Segment): + waveform, sample_rate = self.model.audio.crop( + file, chunk, duration=duration + ) else: - raise NotImplementedError( - f"Unsupported window type '{self.window}': should be 'sliding' or 'whole'." 
+ waveform = torch.cat( + [self.model.audio.crop(file, c)[0] for c in chunk], dim=1 ) + outputs: Union[np.ndarray, Tuple[np.ndarray]] = self.infer(waveform[None]) + + def __first_sample(outputs: np.ndarray, **kwargs) -> np.ndarray: + return outputs[0] + + return map_with_specifications( + self.model.specifications, __first_sample, outputs + ) + @staticmethod def aggregate( scores: SlidingWindowFeature, diff --git a/pyannote/audio/core/model.py b/pyannote/audio/core/model.py index da5c79fe8..199dcfb24 100644 --- a/pyannote/audio/core/model.py +++ b/pyannote/audio/core/model.py @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2020-2021 CNRS +# Copyright (c) 2020- CNRS # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -49,6 +49,7 @@ Task, UnknownSpecificationsError, ) +from pyannote.audio.utils.multi_task import map_with_specifications from pyannote.audio.utils.version import check_version CACHE_DIR = os.getenv( @@ -59,241 +60,9 @@ HF_LIGHTNING_CONFIG_NAME = "config.yaml" +# NOTE: needed to backward compatibility to load models trained before pyannote.audio 3.x class Introspection: - """Model introspection - - Parameters - ---------- - min_num_samples: int - For fixed-duration models, expected number of input samples. - For variable-duration models, minimum number of input samples supported by - the model (i.e. model fails for smaller number of samples). - min_num_frames: int - Corresponding number of output frames. - inc_num_samples: int - Number of input samples leading to an increase of number of output frames. - Has no meaning for fixed-duration models (set to 0). - inc_num_frames: int - Corresponding increase in number of output frames - Has no meaning for fixed-duration models (set to 0). 
- dimension: int - Output dimension - sample_rate: int - Expected input sample rate - - Usage - ----- - >>> introspection = Introspection.from_model(model) - >>> isinstance(introspection.frames, SlidingWindow) - >>> num_samples = 16000 # 1s at 16kHz - >>> num_frames, dimension = introspection(num_samples) - """ - - def __init__( - self, - min_num_samples: int, - min_num_frames: int, - inc_num_samples: int, - inc_num_frames: int, - dimension: int, - sample_rate: int, - ): - super().__init__() - self.min_num_samples = min_num_samples - self.min_num_frames = min_num_frames - self.inc_num_samples = inc_num_samples - self.inc_num_frames = inc_num_frames - self.dimension = dimension - self.sample_rate = sample_rate - - @classmethod - def from_model(cls, model: "Model") -> Introspection: - """ - - Parameters - ---------- - model : Model - """ - - specifications = model.specifications - duration = specifications.duration - min_duration = specifications.min_duration or duration - - # case 1: the model expects a fixed-duration chunk - if min_duration == duration: - num_samples = model.audio.get_num_samples(specifications.duration) - frames = model(model.example_input_array) - if specifications.resolution == Resolution.FRAME: - _, num_frames, dimension = frames.shape - return cls( - min_num_samples=num_samples, - min_num_frames=num_frames, - inc_num_samples=0, - inc_num_frames=0, - dimension=dimension, - sample_rate=model.hparams.sample_rate, - ) - - elif specifications.resolution == Resolution.CHUNK: - _, dimension = frames.shape - return cls( - min_num_samples=num_samples, - min_num_frames=1, - inc_num_samples=0, - inc_num_frames=0, - dimension=dimension, - sample_rate=model.hparams.sample_rate, - ) - - # case 2: the model supports variable-duration chunks - # we use dichotomic search to find the minimum number of samples - example_input_array = model.example_input_array - batch_size, num_channels, num_samples = example_input_array.shape - example_input_array = torch.randn( - (batch_size, num_channels, num_samples), - dtype=example_input_array.dtype, - layout=example_input_array.layout, - device=example_input_array.device, - requires_grad=False, - ) - - # dichotomic search of "min_num_samples" - lower, upper, min_num_samples = 1, num_samples, None - while True: - num_samples = (lower + upper) // 2 - try: - with torch.no_grad(): - frames = model(example_input_array[:, :, :num_samples]) - except Exception: - lower = num_samples - else: - min_num_samples = num_samples - if specifications.resolution == Resolution.FRAME: - _, min_num_frames, dimension = frames.shape - elif specifications.resolution == Resolution.CHUNK: - _, dimension = frames.shape - else: - # should never happen - pass - upper = num_samples - - if lower + 1 == upper: - break - - # if "min_num_samples" is still None at this point, it means that - # the forward pass always failed and raised an exception. most likely, - # it means that there is a problem with the model definition. 
- # we try again without catching the exception to help the end user debug - # their model - if min_num_samples is None: - frames = model(example_input_array) - - # corner case for chunk-level tasks - if specifications.resolution == Resolution.CHUNK: - return cls( - min_num_samples=min_num_samples, - min_num_frames=1, - inc_num_samples=0, - inc_num_frames=0, - dimension=dimension, - sample_rate=model.hparams.sample_rate, - ) - - # search reasonable upper bound for "inc_num_samples" - while True: - num_samples = 2 * min_num_samples - example_input_array = torch.randn( - (batch_size, num_channels, num_samples), - dtype=example_input_array.dtype, - layout=example_input_array.layout, - device=example_input_array.device, - requires_grad=False, - ) - with torch.no_grad(): - frames = model(example_input_array) - num_frames = frames.shape[1] - if num_frames > min_num_frames: - break - - # dichotomic search of "inc_num_samples" - lower, upper = min_num_samples, num_samples - while True: - num_samples = (lower + upper) // 2 - example_input_array = torch.randn( - (batch_size, num_channels, num_samples), - dtype=example_input_array.dtype, - layout=example_input_array.layout, - device=example_input_array.device, - requires_grad=False, - ) - with torch.no_grad(): - frames = model(example_input_array) - num_frames = frames.shape[1] - if num_frames > min_num_frames: - inc_num_frames = num_frames - min_num_frames - inc_num_samples = num_samples - min_num_samples - upper = num_samples - else: - lower = num_samples - - if lower + 1 == upper: - break - - return cls( - min_num_samples=min_num_samples, - min_num_frames=min_num_frames, - inc_num_samples=inc_num_samples, - inc_num_frames=inc_num_frames, - dimension=dimension, - sample_rate=model.hparams.sample_rate, - ) - - def __call__(self, num_samples: int) -> Tuple[int, int]: - """Predict output shape, given number of input samples - - Parameters - ---------- - num_samples : int - Number of input samples. 
- - Returns - ------- - num_frames : int - Number of output frames - dimension : int - Dimension of output frames - """ - - # case 1: the model expects a fixed-duration chunk - if self.inc_num_frames == 0: - assert num_samples == self.min_num_samples - return self.min_num_frames, self.dimension - - # case 2: the model supports variable-duration chunks - - if num_samples < self.min_num_samples: - return 0, self.dimension - - return ( - self.min_num_frames - + self.inc_num_frames - * ((num_samples - self.min_num_samples + 1) // self.inc_num_samples), - self.dimension, - ) - - @property - def frames(self) -> SlidingWindow: - # HACK to support model trained before 'sample_rate' was an Introspection attribute - sample_rate = getattr(self, "sample_rate", 16000) - - if self.inc_num_frames == 0: - step = (self.min_num_samples / self.min_num_frames) / sample_rate - else: - # FIXME: this is not 100% accurate, but it's good enough for now - # FIXME: it should probably be estimated from the maximum duration - step = (self.inc_num_samples / self.inc_num_frames) / sample_rate - - return SlidingWindow(start=0.0, step=step, duration=step) + pass class Model(pl.LightningModule): @@ -327,31 +96,21 @@ def __init__( self.audio = Audio(sample_rate=self.hparams.sample_rate, mono="downmix") @property - def example_input_array(self) -> torch.Tensor: - batch_size = 3 if self.task is None else self.task.batch_size - duration = 2.0 if self.task is None else self.task.duration - - return torch.randn( - ( - batch_size, - self.hparams.num_channels, - int(self.hparams.sample_rate * duration), - ), - device=self.device, - ) - - @property - def task(self): + def task(self) -> Task: return self._task @task.setter - def task(self, task): + def task(self, task: Task): self._task = task - del self.introspection del self.specifications + def build(self): + # use this method to add task-dependent layers to the model + # (e.g. the final classification and activation layers) + pass + @property - def specifications(self): + def specifications(self) -> Union[Specifications, Tuple[Specifications]]: if self.task is None: try: specifications = self._specifications @@ -376,7 +135,22 @@ def specifications(self): return specifications @specifications.setter - def specifications(self, specifications): + def specifications( + self, specifications: Union[Specifications, Tuple[Specifications]] + ): + if not isinstance(specifications, (Specifications, tuple)): + raise ValueError( + "Only regular specifications or tuple of specifications are supported." + ) + + durations = set(s.duration for s in specifications) + if len(durations) > 1: + raise ValueError("All tasks must share the same (maximum) duration.") + + min_durations = set(s.min_duration for s in specifications) + if len(min_durations) > 1: + raise ValueError("All tasks must share the same minimum duration.") + self._specifications = specifications @specifications.deleter @@ -384,34 +158,70 @@ def specifications(self): if hasattr(self, "_specifications"): del self._specifications - def build(self): - # use this method to add task-dependent layers to the model - # (e.g. 
the final classification and activation layers) - pass + def __example_input_array(self, duration: Optional[float] = None) -> torch.Tensor: + duration = duration or next(iter(self.specifications)).duration + return torch.randn( + ( + 1, + self.hparams.num_channels, + self.audio.get_num_samples(duration), + ), + device=self.device, + ) @property - def introspection(self) -> Introspection: - """Introspection + def example_input_array(self) -> torch.Tensor: + return self.__example_input_array() + + def example_output( + self, duration: Optional[float] = None + ) -> Union[torch.Tensor, Tuple[torch.Tensor]]: + """Example output""" + example_input_array = self.__example_input_array(duration=duration) + with torch.inference_mode(): + example_output = self(example_input_array) + + if not isinstance(example_output, (torch.Tensor, tuple)): + raise ValueError( + "Models must return either a torch.Tensor or a tuple of torch.Tensor" + ) - Returns - ------- - introspection: Introspection - Model introspection - """ + return example_output - if not hasattr(self, "_introspection"): - self._introspection = Introspection.from_model(self) + @property + def output_frames( + self, + ) -> Union[Optional[SlidingWindow], Tuple[Optional[SlidingWindow]]]: + """Output frames as (tuple of) SlidingWindow(s)""" + + def __output_frames( + example_output: torch.Tensor, + specifications: Specifications = None, + ) -> Optional[SlidingWindow]: + if specifications.resolution == Resolution.FRAME: + _, num_frames, _ = example_output.shape + frame_duration = specifications.duration / num_frames + return SlidingWindow(step=frame_duration, duration=frame_duration) + + return None + + return map_with_specifications( + self.specifications, __output_frames, self.example_output() + ) + + @property + def output_dimension(self) -> Union[int, Tuple[int]]: + """Output dimension as (tuple of) int(s)""" - return self._introspection + duration = next(iter(self.specifications)).duration + example_output = self.example_output(duration=duration) - @introspection.setter - def introspection(self, introspection): - self._introspection = introspection + def __output_dimension(example_output: torch.Tensor, **kwargs) -> int: + return example_output.shape[-1] - @introspection.deleter - def introspection(self): - if hasattr(self, "_introspection"): - del self._introspection + return map_with_specifications( + self.specifications, __output_dimension, example_output + ) def setup(self, stage=None): if stage == "fit": @@ -456,9 +266,6 @@ def setup(self, stage=None): # setup custom validation metrics self.task.setup_validation_metric() - # this is to make sure introspection is performed here, once and for all - _ = self.introspection - # list of layers after adding task-dependent layers after = set((name, id(module)) for name, module in self.named_modules()) @@ -477,7 +284,6 @@ def on_save_checkpoint(self, checkpoint): "module": self.__class__.__module__, "class": self.__class__.__name__, }, - "introspection": self.introspection, "specifications": self.specifications, } @@ -507,41 +313,14 @@ def on_load_checkpoint(self, checkpoint: Dict[str, Any]): self.setup() - self.introspection = checkpoint["pyannote.audio"]["introspection"] - - def forward(self, waveforms: torch.Tensor) -> torch.Tensor: + def forward( + self, waveforms: torch.Tensor, **kwargs + ) -> Union[torch.Tensor, Tuple[torch.Tensor]]: msg = "Class {self.__class__.__name__} should define a `forward` method." 
raise NotImplementedError(msg) - def helper_default_activation(self, specifications: Specifications) -> nn.Module: - """Helper function for default_activation - - Parameters - ---------- - specifications: Specifications - Task specification. - - Returns - ------- - activation : nn.Module - Default activation function. - """ - - if specifications.problem == Problem.BINARY_CLASSIFICATION: - return nn.Sigmoid() - - elif specifications.problem == Problem.MONO_LABEL_CLASSIFICATION: - return nn.LogSoftmax(dim=-1) - - elif specifications.problem == Problem.MULTI_LABEL_CLASSIFICATION: - return nn.Sigmoid() - - else: - msg = "TODO: implement default activation for other types of problems" - raise NotImplementedError(msg) - # convenience function to automate the choice of the final activation function - def default_activation(self) -> nn.Module: + def default_activation(self) -> Union[nn.Module, Tuple[nn.Module]]: """Guess default activation function according to task specification * sigmoid for binary classification @@ -550,10 +329,25 @@ def default_activation(self) -> nn.Module: Returns ------- - activation : nn.Module + activation : (tuple of) nn.Module Activation. """ - return self.helper_default_activation(self.specifications) + + def __default_activation(specifications: Specifications = None) -> nn.Module: + if specifications.problem == Problem.BINARY_CLASSIFICATION: + return nn.Sigmoid() + + elif specifications.problem == Problem.MONO_LABEL_CLASSIFICATION: + return nn.LogSoftmax(dim=-1) + + elif specifications.problem == Problem.MULTI_LABEL_CLASSIFICATION: + return nn.Sigmoid() + + else: + msg = "TODO: implement default activation for other types of problems" + raise NotImplementedError(msg) + + return map_with_specifications(self.specifications, __default_activation) # training data logic is delegated to the task because the # model does not really need to know how it is being used. @@ -578,9 +372,7 @@ def validation_step(self, batch, batch_idx): def configure_optimizers(self): return torch.optim.Adam(self.parameters(), lr=1e-3) - def _helper_up_to( - self, module_name: Text, requires_grad: bool = False - ) -> List[Text]: + def __up_to(self, module_name: Text, requires_grad: bool = False) -> List[Text]: """Helper function for freeze_up_to and unfreeze_up_to""" tokens = module_name.split(".") @@ -637,7 +429,7 @@ def freeze_up_to(self, module_name: Text) -> List[Text]: If your model does not follow a sequential structure, you might want to use freeze_by_name for more control. """ - return self._helper_up_to(module_name, requires_grad=False) + return self.__up_to(module_name, requires_grad=False) def unfreeze_up_to(self, module_name: Text) -> List[Text]: """Unfreeze model up to specific module @@ -662,9 +454,9 @@ def unfreeze_up_to(self, module_name: Text) -> List[Text]: If your model does not follow a sequential structure, you might want to use freeze_by_name for more control. """ - return self._helper_up_to(module_name, requires_grad=True) + return self.__up_to(module_name, requires_grad=True) - def _helper_by_name( + def __by_name( self, modules: Union[List[Text], Text], recurse: bool = True, @@ -720,7 +512,7 @@ def freeze_by_name( ValueError if at least one of `modules` does not exist. """ - return self._helper_by_name( + return self.__by_name( modules, recurse=recurse, requires_grad=False, @@ -751,7 +543,7 @@ def unfreeze_by_name( ValueError if at least one of `modules` does not exist. 
""" - return self._helper_by_name(modules, recurse=recurse, requires_grad=True) + return self.__by_name(modules, recurse=recurse, requires_grad=True) @classmethod def from_pretrained( diff --git a/pyannote/audio/core/task.py b/pyannote/audio/core/task.py index d02e643b4..e8baee29f 100644 --- a/pyannote/audio/core/task.py +++ b/pyannote/audio/core/task.py @@ -97,7 +97,7 @@ class Specifications: permutation_invariant: bool = False @cached_property - def powerset(self): + def powerset(self) -> bool: if self.powerset_max_classes is None: return False @@ -120,6 +120,12 @@ def num_powerset_classes(self) -> int: ) ) + def __len__(self): + return 1 + + def __iter__(self): + yield self + class TrainDataset(IterableDataset): def __init__(self, task: Task): @@ -193,7 +199,7 @@ class Task(pl.LightningDataModule): Attributes ---------- - specifications : Specifications or dict of Specifications + specifications : Specifications or tuple of Specifications Task specifications (available after `Task.setup` has been called.) """ @@ -375,6 +381,11 @@ def common_step(self, batch, batch_idx: int, stage: Literal["train", "val"]): {"loss": loss} """ + if isinstance(self.specifications, tuple): + raise NotImplementedError( + "Default training/validation step is not implemented for multi-task." + ) + # forward pass y_pred = self.model(batch["X"]) diff --git a/pyannote/audio/models/segmentation/PyanNet.py b/pyannote/audio/models/segmentation/PyanNet.py index 1b68a32a9..5af3734b1 100644 --- a/pyannote/audio/models/segmentation/PyanNet.py +++ b/pyannote/audio/models/segmentation/PyanNet.py @@ -80,7 +80,6 @@ def __init__( num_channels: int = 1, task: Optional[Task] = None, ): - super().__init__(sample_rate=sample_rate, num_channels=num_channels, task=task) sincnet = merge_dict(self.SINCNET_DEFAULTS, sincnet) @@ -140,7 +139,6 @@ def __init__( ) def build(self): - if self.hparams.linear["num_layers"] > 0: in_features = self.hparams.linear["hidden_size"] else: @@ -148,6 +146,9 @@ def build(self): 2 if self.hparams.lstm["bidirectional"] else 1 ) + if isinstance(self.specifications, tuple): + raise ValueError("PyanNet does not support multi-tasking.") + if self.specifications.powerset: out_features = self.specifications.num_powerset_classes else: diff --git a/pyannote/audio/models/segmentation/debug.py b/pyannote/audio/models/segmentation/debug.py index 498faee27..89512320c 100644 --- a/pyannote/audio/models/segmentation/debug.py +++ b/pyannote/audio/models/segmentation/debug.py @@ -39,7 +39,6 @@ def __init__( num_channels: int = 1, task: Optional[Task] = None, ): - super().__init__(sample_rate=sample_rate, num_channels=num_channels, task=task) self.mfcc = MFCC( @@ -60,7 +59,16 @@ def __init__( def build(self): # define task-dependent layers - self.classifier = nn.Linear(32 * 2, len(self.specifications.classes)) + + if isinstance(self.specifications, tuple): + raise ValueError("SimpleSegmentationModel does not support multi-tasking.") + + if self.specifications.powerset: + out_features = self.specifications.num_powerset_classes + else: + out_features = len(self.specifications.classes) + + self.classifier = nn.Linear(32 * 2, out_features) self.activation = self.default_activation() def forward(self, waveforms: torch.Tensor) -> torch.Tensor: diff --git a/pyannote/audio/pipelines/overlapped_speech_detection.py b/pyannote/audio/pipelines/overlapped_speech_detection.py index 9b14ee10f..a326b8786 100644 --- a/pyannote/audio/pipelines/overlapped_speech_detection.py +++ b/pyannote/audio/pipelines/overlapped_speech_detection.py @@ 
-128,7 +128,7 @@ def __init__( # load model model = get_model(segmentation, use_auth_token=use_auth_token) - if model.introspection.dimension > 1: + if model.output_dimension > 1: inference_kwargs["pre_aggregation_hook"] = lambda scores: np.partition( scores, -2, axis=-1 )[:, :, -2, np.newaxis] diff --git a/pyannote/audio/pipelines/resegmentation.py b/pyannote/audio/pipelines/resegmentation.py index 57cf9004b..468d5087d 100644 --- a/pyannote/audio/pipelines/resegmentation.py +++ b/pyannote/audio/pipelines/resegmentation.py @@ -88,7 +88,6 @@ def __init__( der_variant: dict = None, use_auth_token: Union[Text, None] = None, ): - super().__init__() self.segmentation = segmentation @@ -96,7 +95,7 @@ def __init__( model: Model = get_model(segmentation, use_auth_token=use_auth_token) self._segmentation = Inference(model) - self._frames = self._segmentation.model.introspection.frames + self._frames = self._segmentation.model.output_frames self._audio = model.audio diff --git a/pyannote/audio/pipelines/speaker_diarization.py b/pyannote/audio/pipelines/speaker_diarization.py index 6bc81f28a..038bd4676 100644 --- a/pyannote/audio/pipelines/speaker_diarization.py +++ b/pyannote/audio/pipelines/speaker_diarization.py @@ -136,7 +136,7 @@ def __init__( skip_aggregation=True, batch_size=segmentation_batch_size, ) - self._frames: SlidingWindow = self._segmentation.model.introspection.frames + self._frames: SlidingWindow = self._segmentation.model.output_frames if self._segmentation.model.specifications.powerset: self.segmentation = ParamDict( diff --git a/pyannote/audio/pipelines/speaker_verification.py b/pyannote/audio/pipelines/speaker_verification.py index 1a672d614..f928aabbb 100644 --- a/pyannote/audio/pipelines/speaker_verification.py +++ b/pyannote/audio/pipelines/speaker_verification.py @@ -64,7 +64,6 @@ def __init__( embedding: Text = "nvidia/speakerverification_en_titanet_large", device: torch.device = None, ): - if not NEMO_IS_AVAILABLE: raise ImportError( f"'NeMo' must be installed to use '{embedding}' embeddings. " @@ -90,7 +89,6 @@ def sample_rate(self) -> int: @cached_property def dimension(self) -> int: - input_signal = torch.rand(1, self.sample_rate).to(self.device) input_signal_length = torch.tensor([self.sample_rate]).to(self.device) _, embeddings = self.model_( @@ -105,7 +103,6 @@ def metric(self) -> str: @cached_property def min_num_samples(self) -> int: - lower, upper = 2, round(0.5 * self.sample_rate) middle = (lower + upper) // 2 while lower + 1 < upper: @@ -152,7 +149,6 @@ def __call__( wav_lens = signals.shape[1] * torch.ones(batch_size) else: - batch_size_masks, _ = masks.shape assert batch_size == batch_size_masks @@ -229,7 +225,6 @@ def __init__( device: torch.device = None, use_auth_token: Union[Text, None] = None, ): - if not SPEECHBRAIN_IS_AVAILABLE: raise ImportError( f"'speechbrain' must be installed to use '{embedding}' embeddings. 
" @@ -281,19 +276,19 @@ def metric(self) -> str: @cached_property def min_num_samples(self) -> int: - - lower, upper = 2, round(0.5 * self.sample_rate) - middle = (lower + upper) // 2 - while lower + 1 < upper: - try: - _ = self.classifier_.encode_batch( - torch.randn(1, middle).to(self.device) - ) - upper = middle - except RuntimeError: - lower = middle - + with torch.inference_mode(): + lower, upper = 2, round(0.5 * self.sample_rate) middle = (lower + upper) // 2 + while lower + 1 < upper: + try: + _ = self.classifier_.encode_batch( + torch.randn(1, middle).to(self.device) + ) + upper = middle + except RuntimeError: + lower = middle + + middle = (lower + upper) // 2 return upper @@ -324,7 +319,6 @@ def __call__( wav_lens = signals.shape[1] * torch.ones(batch_size) else: - batch_size_masks, _ = masks.shape assert batch_size == batch_size_masks @@ -425,7 +419,7 @@ def sample_rate(self) -> int: @cached_property def dimension(self) -> int: - return self.model_.introspection.dimension + return self.model_.output_dimension @cached_property def metric(self) -> str: @@ -433,12 +427,24 @@ def metric(self) -> str: @cached_property def min_num_samples(self) -> int: - return self.model_.introspection.min_num_samples + with torch.inference_mode(): + lower, upper = 2, round(0.5 * self.sample_rate) + middle = (lower + upper) // 2 + while lower + 1 < upper: + try: + _ = self.model_(torch.randn(1, 1, middle).to(self.device)) + upper = middle + except RuntimeError: + lower = middle + + middle = (lower + upper) // 2 + + return upper def __call__( self, waveforms: torch.Tensor, masks: torch.Tensor = None ) -> np.ndarray: - with torch.no_grad(): + with torch.inference_mode(): if masks is None: embeddings = self.model_(waveforms.to(self.device)) else: @@ -557,7 +563,6 @@ def __init__( ) def apply(self, file: AudioFile) -> np.ndarray: - device = self.embedding_model_.device # read audio file and send it to GPU @@ -583,7 +588,6 @@ def main( embedding: str = "pyannote/embedding", segmentation: str = None, ): - import typer from pyannote.database import FileFinder, get_protocol from pyannote.metrics.binary_classification import det_curve @@ -601,7 +605,6 @@ def main( trials = getattr(protocol, f"{subset}_trial")() for t, trial in enumerate(tqdm(trials)): - audio1 = trial["file1"]["audio"] if audio1 not in emb: emb[audio1] = pipeline(audio1) diff --git a/pyannote/audio/pipelines/utils/oracle.py b/pyannote/audio/pipelines/utils/oracle.py index 486b09274..0b6b58f85 100644 --- a/pyannote/audio/pipelines/utils/oracle.py +++ b/pyannote/audio/pipelines/utils/oracle.py @@ -39,7 +39,7 @@ def oracle_segmentation( Simulates inference based on an (imaginary) oracle segmentation model: >>> oracle = Model.from_pretrained("oracle") - >>> assert frames == oracle.introspection.frames + >>> assert frames == oracle.output_frames >>> inference = Inference(oracle, duration=window.duration, step=window.step, skip_aggregation=True) >>> oracle_segmentation = inference(file) diff --git a/pyannote/audio/tasks/segmentation/mixins.py b/pyannote/audio/tasks/segmentation/mixins.py index 1cdb9840d..8dbb04488 100644 --- a/pyannote/audio/tasks/segmentation/mixins.py +++ b/pyannote/audio/tasks/segmentation/mixins.py @@ -25,11 +25,13 @@ import random import warnings from collections import defaultdict +from functools import cached_property from typing import Dict, Optional, Sequence, Union import matplotlib.pyplot as plt import numpy as np import torch +from pyannote.core import SlidingWindow from pyannote.database.protocol import 
SegmentationProtocol, SpeakerDiarizationProtocol from pyannote.database.protocol.protocol import Scope, Subset from pytorch_lightning.loggers import MLFlowLogger, TensorBoardLogger @@ -49,7 +51,6 @@ class SegmentationTaskMixin: """Methods common to most segmentation tasks""" def get_file(self, file_id): - file = dict() file["audio"] = str(self.audios[file_id], encoding="utf-8") @@ -121,7 +122,6 @@ def setup(self, stage: Optional[str] = None): files_iter = self.protocol.train() for file_id, file in enumerate(files_iter): - # gather metadata and update metadata_unique_values so that each metadatum # (e.g. source database or label) is represented by an integer. metadatum = dict() @@ -142,7 +142,6 @@ def setup(self, stage: Optional[str] = None): # Different files may be annotated using a different set of classes # (e.g. one database for speech/music/noise, and another one for male/female/child) if isinstance(self.protocol, SegmentationProtocol): - if "classes" in file: local_classes = file["classes"] else: @@ -191,7 +190,6 @@ def setup(self, stage: Optional[str] = None): # keep track of any other (integer or string) metadata provided by the protocol # (e.g. a "domain" key for domain-adversarial training) for key in remaining_metadata_keys: - value = file[key] if isinstance(value, str): @@ -233,7 +231,6 @@ def setup(self, stage: Optional[str] = None): # annotated regions and duration _annotated_duration = 0.0 for segment in file["annotated"]: - # skip annotated regions that are shorter than training chunk duration if segment.duration < duration: continue @@ -255,13 +252,11 @@ def setup(self, stage: Optional[str] = None): # annotations for segment, _, label in file["annotation"].itertracks(yield_label=True): - # "scope" is provided by speaker diarization protocols to indicate # whether speaker labels are local to the file ('file'), consistent across # all files in a database ('database'), or globally consistent ('global') if "scope" in file: - # 0 = 'file' # 1 = 'database' # 2 = 'global' @@ -276,7 +271,6 @@ def setup(self, stage: Optional[str] = None): database_label_idx = global_label_idx = -1 if scope > 0: # 'database' or 'global' - # update list of database-scope labels if label not in database_unique_labels: database_unique_labels.append(label) @@ -285,7 +279,6 @@ def setup(self, stage: Optional[str] = None): database_label_idx = database_unique_labels.index(label) if scope > 1: # 'global' - # update list of global-scope labels if label not in unique_labels: unique_labels.append(label) @@ -381,7 +374,6 @@ def setup(self, stage: Optional[str] = None): # iterate over files in the validation subset for file_id in validation_file_ids: - # get annotated regions in file annotated_regions = self.annotated_regions[ self.annotated_regions["file_id"] == file_id @@ -389,7 +381,6 @@ def setup(self, stage: Optional[str] = None): # iterate over annotated regions for annotated_region in annotated_regions: - # number of chunks in annotated region num_chunks = round(annotated_region["duration"] // duration) @@ -401,6 +392,15 @@ def setup(self, stage: Optional[str] = None): dtype = [("file_id", "i"), ("start", "f"), ("duration", "f")] self.validation_chunks = np.array(validation_chunks, dtype=dtype) + @cached_property + def frames(self) -> SlidingWindow: + return self.model.output_frames + + @cached_property + def num_frames_per_chunk(self) -> int: + batch_size, num_frames, num_classes = self.model.example_output().shape + return num_frames + def default_metric( self, ) -> Union[Metric, Sequence[Metric], 
Dict[str, Metric]]: @@ -450,13 +450,11 @@ def train__iter__helper(self, rng: random.Random, **filters): num_chunks_per_file = getattr(self, "num_chunks_per_file", 1) while True: - # select one file at random (with probability proportional to its annotated duration) file_id = np.random.choice(file_ids, p=prob_annotated_duration) # generate `num_chunks_per_file` chunks from this file for _ in range(num_chunks_per_file): - # find indices of annotated regions in this file annotated_region_indices = np.where( self.annotated_regions["file_id"] == file_id @@ -510,7 +508,6 @@ def train__iter__(self): subchunks[product] = self.train__iter__helper(rng, **filters) while True: - # select one subchunk generator at random (with uniform probability) # so that it is balanced on average if balance is not None: @@ -686,7 +683,6 @@ def validation_step(self, batch, batch_idx: int): # plot each sample for sample_idx in range(num_samples): - # find where in the grid it should be plotted row_idx = sample_idx // nrows col_idx = sample_idx % ncols diff --git a/pyannote/audio/tasks/segmentation/multilabel.py b/pyannote/audio/tasks/segmentation/multilabel.py index f27303e2a..70b1e1fc8 100644 --- a/pyannote/audio/tasks/segmentation/multilabel.py +++ b/pyannote/audio/tasks/segmentation/multilabel.py @@ -25,7 +25,7 @@ import numpy as np import torch import torch.nn.functional as F -from pyannote.core import Segment, SlidingWindow, SlidingWindowFeature +from pyannote.core import Segment, SlidingWindowFeature from pyannote.database import Protocol from pyannote.database.protocol import SegmentationProtocol from torch_audiomentations.core.transforms_interface import BaseWaveformTransform @@ -168,14 +168,6 @@ def prepare_chunk(self, file_id: int, start_time: float, duration: float): sample = dict() sample["X"], _ = self.model.audio.crop(file, chunk, duration=duration) - - # TODO: this should be cached - # use model introspection to predict how many frames it will output - num_samples = sample["X"].shape[1] - num_frames, _ = self.model.introspection(num_samples) - resolution = duration / num_frames - frames = SlidingWindow(start=0.0, duration=resolution, step=resolution) - # gather all annotations of current file annotations = self.annotations[self.annotations["file_id"] == file_id] @@ -186,19 +178,19 @@ def prepare_chunk(self, file_id: int, start_time: float, duration: float): # discretize chunk annotations at model output resolution start = np.maximum(chunk_annotations["start"], chunk.start) - chunk.start - start_idx = np.floor(start / resolution).astype(int) + start_idx = np.floor(start / self.frames.step).astype(int) end = np.minimum(chunk_annotations["end"], chunk.end) - chunk.start - end_idx = np.ceil(end / resolution).astype(int) + end_idx = np.ceil(end / self.frames.step).astype(int) # frame-level targets (-1 for un-annotated classes) - y = -np.ones((num_frames, len(self.classes)), dtype=np.int8) + y = -np.ones((self.num_frames_per_chunk, len(self.classes)), dtype=np.int8) y[:, self.annotated_classes[file_id]] = 0 for start, end, label in zip( start_idx, end_idx, chunk_annotations["global_label_idx"] ): y[start:end, label] = 1 - sample["y"] = SlidingWindowFeature(y, frames, labels=self.classes) + sample["y"] = SlidingWindowFeature(y, self.frames, labels=self.classes) metadata = self.metadata[file_id] sample["meta"] = {key: metadata[key] for key in metadata.dtype.names} diff --git a/pyannote/audio/tasks/segmentation/overlapped_speech_detection.py b/pyannote/audio/tasks/segmentation/overlapped_speech_detection.py 
index 8e6551447..492190938 100644 --- a/pyannote/audio/tasks/segmentation/overlapped_speech_detection.py +++ b/pyannote/audio/tasks/segmentation/overlapped_speech_detection.py @@ -24,7 +24,7 @@ from typing import Dict, Sequence, Text, Tuple, Union import numpy as np -from pyannote.core import Segment, SlidingWindow, SlidingWindowFeature +from pyannote.core import Segment, SlidingWindowFeature from pyannote.database import Protocol from torch_audiomentations.core.transforms_interface import BaseWaveformTransform from torchmetrics import Metric @@ -162,13 +162,6 @@ def prepare_chunk(self, file_id: int, start_time: float, duration: float): sample = dict() sample["X"], _ = self.model.audio.crop(file, chunk, duration=duration) - # use model introspection to predict how many frames it will output - # TODO: this should be cached - num_samples = sample["X"].shape[1] - num_frames, _ = self.model.introspection(num_samples) - resolution = duration / num_frames - frames = SlidingWindow(start=0.0, duration=resolution, step=resolution) - # gather all annotations of current file annotations = self.annotations[self.annotations["file_id"] == file_id] @@ -179,17 +172,17 @@ def prepare_chunk(self, file_id: int, start_time: float, duration: float): # discretize chunk annotations at model output resolution start = np.maximum(chunk_annotations["start"], chunk.start) - chunk.start - start_idx = np.floor(start / resolution).astype(int) + start_idx = np.floor(start / self.frames.step).astype(int) end = np.minimum(chunk_annotations["end"], chunk.end) - chunk.start - end_idx = np.ceil(end / resolution).astype(int) + end_idx = np.ceil(end / self.frames.step).astype(int) # frame-level targets - y = np.zeros((num_frames, 1), dtype=np.uint8) + y = np.zeros((self.num_frames_per_chunk, 1), dtype=np.uint8) for start, end in zip(start_idx, end_idx): y[start:end, 0] += 1 y = 1 * (y > 1) - sample["y"] = SlidingWindowFeature(y, frames, labels=["speech"]) + sample["y"] = SlidingWindowFeature(y, self.frames, labels=["speech"]) metadata = self.metadata[file_id] sample["meta"] = {key: metadata[key] for key in metadata.dtype.names} diff --git a/pyannote/audio/tasks/segmentation/speaker_diarization.py b/pyannote/audio/tasks/segmentation/speaker_diarization.py index 3ef0b1a17..dc3a33025 100644 --- a/pyannote/audio/tasks/segmentation/speaker_diarization.py +++ b/pyannote/audio/tasks/segmentation/speaker_diarization.py @@ -29,7 +29,7 @@ import torch import torch.nn.functional from matplotlib import pyplot as plt -from pyannote.core import Segment, SlidingWindow, SlidingWindowFeature +from pyannote.core import Segment, SlidingWindowFeature from pyannote.database.protocol import SpeakerDiarizationProtocol from pyannote.database.protocol.protocol import Scope, Subset from pytorch_lightning.loggers import MLFlowLogger, TensorBoardLogger @@ -327,13 +327,6 @@ def prepare_chunk(self, file_id: int, start_time: float, duration: float): sample = dict() sample["X"], _ = self.model.audio.crop(file, chunk, duration=duration) - # use model introspection to predict how many frames it will output - # TODO: this should be cached - num_samples = sample["X"].shape[1] - num_frames, _ = self.model.introspection(num_samples) - resolution = duration / num_frames - frames = SlidingWindow(start=0.0, duration=resolution, step=resolution) - # gather all annotations of current file annotations = self.annotations[self.annotations["file_id"] == file_id] @@ -344,9 +337,9 @@ def prepare_chunk(self, file_id: int, start_time: float, duration: float): # discretize chunk 
annotations at model output resolution start = np.maximum(chunk_annotations["start"], chunk.start) - chunk.start - start_idx = np.floor(start / resolution).astype(int) + start_idx = np.floor(start / self.frames.step).astype(int) end = np.minimum(chunk_annotations["end"], chunk.end) - chunk.start - end_idx = np.ceil(end / resolution).astype(int) + end_idx = np.ceil(end / self.frames.step).astype(int) # get list and number of labels for current scope labels = list(np.unique(chunk_annotations[label_scope_key])) @@ -356,7 +349,7 @@ def prepare_chunk(self, file_id: int, start_time: float, duration: float): pass # initial frame-level targets - y = np.zeros((num_frames, num_labels), dtype=np.uint8) + y = np.zeros((self.num_frames_per_chunk, num_labels), dtype=np.uint8) # map labels to indices mapping = {label: idx for idx, label in enumerate(labels)} @@ -367,7 +360,7 @@ def prepare_chunk(self, file_id: int, start_time: float, duration: float): mapped_label = mapping[label] y[start:end, mapped_label] = 1 - sample["y"] = SlidingWindowFeature(y, frames, labels=labels) + sample["y"] = SlidingWindowFeature(y, self.frames, labels=labels) metadata = self.metadata[file_id] sample["meta"] = {key: metadata[key] for key in metadata.dtype.names} @@ -565,11 +558,7 @@ def training_step(self, batch, batch_idx: int): weight[:, num_frames - warm_up_right :] = 0.0 if self.specifications.powerset: - powerset = torch.nn.functional.one_hot( - torch.argmax(prediction, dim=-1), - self.model.powerset.num_powerset_classes, - ).float() - multilabel = self.model.powerset.to_multilabel(powerset) + multilabel = self.model.powerset.to_multilabel(prediction) permutated_target, _ = permutate(multilabel, target) permutated_target_powerset = self.model.powerset.to_powerset( permutated_target.float() @@ -698,11 +687,7 @@ def validation_step(self, batch, batch_idx: int): weight[:, num_frames - warm_up_right :] = 0.0 if self.specifications.powerset: - powerset = torch.nn.functional.one_hot( - torch.argmax(prediction, dim=-1), - self.model.powerset.num_powerset_classes, - ).float() - multilabel = self.model.powerset.to_multilabel(powerset) + multilabel = self.model.powerset.to_multilabel(prediction) permutated_target, _ = permutate(multilabel, target) # FIXME: handle case where target have too many speakers? 
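The two hunks above drop the explicit argmax/one-hot step from `training_step` and `validation_step`: soft powerset predictions now go straight to `self.model.powerset.to_multilabel(prediction)`, and the hardening happens inside the reworked `Powerset.to_multilabel` (see the `pyannote/audio/utils/powerset.py` hunk further down in this patch). The sketch below illustrates that conversion in isolation, under made-up assumptions (3 classes, at most 2 active at a time, hence 7 powerset classes); in the library itself the mapping matrix is the `Powerset` module's own `self.mapping`.

import torch

# mapping[k, c] == 1 iff powerset class k contains multi-label class c
# powerset classes: {}, {0}, {1}, {2}, {0,1}, {0,2}, {1,2}  (illustrative only)
mapping = torch.tensor(
    [
        [0.0, 0.0, 0.0],
        [1.0, 0.0, 0.0],
        [0.0, 1.0, 0.0],
        [0.0, 0.0, 1.0],
        [1.0, 1.0, 0.0],
        [1.0, 0.0, 1.0],
        [0.0, 1.0, 1.0],
    ]
)

# soft predictions in powerset space: (batch_size, num_frames, num_powerset_classes)
soft_powerset = torch.randn(2, 5, 7).softmax(dim=-1)

# keep the most likely powerset class per frame, then map back to multi-label space
hard_powerset = torch.nn.functional.one_hot(
    soft_powerset.argmax(dim=-1), num_classes=7
).float()
multilabel = hard_powerset @ mapping  # (batch_size, num_frames, num_classes)
assert multilabel.shape == (2, 5, 3)

Moving the hardening into `to_multilabel` keeps every caller (training, validation, inference) on the same code path instead of duplicating the one-hot conversion at each call site.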
diff --git a/pyannote/audio/tasks/segmentation/voice_activity_detection.py b/pyannote/audio/tasks/segmentation/voice_activity_detection.py index 4851b7455..d740427f0 100644 --- a/pyannote/audio/tasks/segmentation/voice_activity_detection.py +++ b/pyannote/audio/tasks/segmentation/voice_activity_detection.py @@ -23,7 +23,7 @@ from typing import Dict, Sequence, Text, Tuple, Union import numpy as np -from pyannote.core import Segment, SlidingWindow, SlidingWindowFeature +from pyannote.core import Segment, SlidingWindowFeature from pyannote.database import Protocol from torch_audiomentations.core.transforms_interface import BaseWaveformTransform from torchmetrics import Metric @@ -144,13 +144,6 @@ def prepare_chunk(self, file_id: int, start_time: float, duration: float): sample = dict() sample["X"], _ = self.model.audio.crop(file, chunk, duration=duration) - # use model introspection to predict how many frames it will output - # TODO: this should be cached - num_samples = sample["X"].shape[1] - num_frames, _ = self.model.introspection(num_samples) - resolution = duration / num_frames - frames = SlidingWindow(start=0.0, duration=resolution, step=resolution) - # gather all annotations of current file annotations = self.annotations[self.annotations["file_id"] == file_id] @@ -161,16 +154,16 @@ def prepare_chunk(self, file_id: int, start_time: float, duration: float): # discretize chunk annotations at model output resolution start = np.maximum(chunk_annotations["start"], chunk.start) - chunk.start - start_idx = np.floor(start / resolution).astype(int) + start_idx = np.floor(start / self.frames.step).astype(int) end = np.minimum(chunk_annotations["end"], chunk.end) - chunk.start - end_idx = np.ceil(end / resolution).astype(int) + end_idx = np.ceil(end / self.frames.step).astype(int) # frame-level targets - y = np.zeros((num_frames, 1), dtype=np.uint8) + y = np.zeros((self.num_frames_per_chunk, 1), dtype=np.uint8) for start, end in zip(start_idx, end_idx): y[start:end, 0] = 1 - sample["y"] = SlidingWindowFeature(y, frames, labels=["speech"]) + sample["y"] = SlidingWindowFeature(y, self.frames, labels=["speech"]) metadata = self.metadata[file_id] sample["meta"] = {key: metadata[key] for key in metadata.dtype.names} diff --git a/pyannote/audio/utils/multi_task.py b/pyannote/audio/utils/multi_task.py new file mode 100644 index 000000000..3886a0eeb --- /dev/null +++ b/pyannote/audio/utils/multi_task.py @@ -0,0 +1,59 @@ +# MIT License +# +# Copyright (c) 2023- CNRS +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + + +from typing import Any, Callable, Tuple, Union + +from pyannote.audio.core.model import Specifications + + +def map_with_specifications( + specifications: Union[Specifications, Tuple[Specifications]], + func: Callable, + *iterables, +) -> Union[Any, Tuple[Any]]: + """Compute the function using arguments from each of the iterables + + Returns a tuple if provided `specifications` is a tuple, + otherwise returns the function return value. + + Parameters + ---------- + specifications : (tuple of) Specifications + Specifications or tuple of specifications + func : callable + Function called for each specification with + `func(*iterables[i], specifications=specifications[i])` + *iterables : + List of iterables with same length as `specifications`. + + Returns + ------- + output : (tuple of) `func` return value(s) + """ + + if isinstance(specifications, Specifications): + return func(*iterables, specifications=specifications) + + return tuple( + func(*i, specifications=s) for s, *i in zip(specifications, *iterables) + ) diff --git a/pyannote/audio/utils/powerset.py b/pyannote/audio/utils/powerset.py index 215cb7946..0f5cfb5bc 100644 --- a/pyannote/audio/utils/powerset.py +++ b/pyannote/audio/utils/powerset.py @@ -85,25 +85,29 @@ def build_cardinality(self) -> torch.Tensor: return cardinality def to_multilabel(self, powerset: torch.Tensor) -> torch.Tensor: - """Convert (hard) predictions from powerset to multi-label + """Convert predictions from (soft) powerset to (hard) multi-label Parameter --------- powerset : (batch_size, num_frames, num_powerset_classes) torch.Tensor - Hard predictions in "powerset" space. + Soft predictions in "powerset" space. Returns ------- multi_label : (batch_size, num_frames, num_classes) torch.Tensor Hard predictions in "multi-label" space. - - Note - ---- - This method will not complain if `powerset` is provided a soft predictions - (e.g. the output of a softmax-ed classifier). However, in that particular - case, the resulting soft multi-label output will not make much sense. 
""" - return torch.matmul(powerset, self.mapping) + + hard_powerset = torch.nn.functional.one_hot( + torch.argmax(powerset, dim=-1), + self.num_powerset_classes, + ).float() + + return torch.matmul(hard_powerset, self.mapping) + + def forward(self, powerset: torch.Tensor) -> torch.Tensor: + """Alias for `to_multilabel`""" + return self.to_multilabel(powerset) def to_powerset(self, multilabel: torch.Tensor) -> torch.Tensor: """Convert (hard) predictions from multi-label to powerset diff --git a/pyannote/audio/utils/preview.py b/pyannote/audio/utils/preview.py index ac085b10f..0464b165f 100644 --- a/pyannote/audio/utils/preview.py +++ b/pyannote/audio/utils/preview.py @@ -196,7 +196,6 @@ def make_audio_frame(T: float): ylim = (-0.1, 1.1) def make_frame(T: float): - # make sure all subsequent calls to notebook.plot_* # will only display the region center on current time t = T + segment.start @@ -215,7 +214,6 @@ def make_frame(T: float): ax_wav.set_ylabel("waveform") for (name, view), ax_view in zip(views.items(), ax_views): - ax_view.clear() if isinstance(view, Timeline): @@ -258,7 +256,7 @@ def make_frame(T: float): return IPythonVideo(video_path, embed=True) -def preview_training_samples( +def BROKEN_preview_training_samples( model: Model, blank: float = 1.0, video_fps: int = 5, From c12077c2adb220a31dd6dfeb685ba3bcdb57878b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Herv=C3=A9=20BREDIN?= Date: Wed, 24 May 2023 16:14:39 +0200 Subject: [PATCH 06/55] chore(task): remove `stage` argument from Task.setup --- pyannote/audio/core/task.py | 2 +- pyannote/audio/tasks/embedding/mixins.py | 4 ++-- pyannote/audio/tasks/segmentation/mixins.py | 12 +++--------- pyannote/audio/tasks/segmentation/multilabel.py | 4 ++-- .../audio/tasks/segmentation/speaker_diarization.py | 6 +++--- 5 files changed, 11 insertions(+), 17 deletions(-) diff --git a/pyannote/audio/core/task.py b/pyannote/audio/core/task.py index e8baee29f..de441590b 100644 --- a/pyannote/audio/core/task.py +++ b/pyannote/audio/core/task.py @@ -268,7 +268,7 @@ def prepare_data(self): """ pass - def setup(self, stage: Optional[str] = None): + def setup(self): """Called at the beginning of training at the very beginning of Model.setup(stage="fit") Notes diff --git a/pyannote/audio/tasks/embedding/mixins.py b/pyannote/audio/tasks/embedding/mixins.py index b02ae7f71..d83009802 100644 --- a/pyannote/audio/tasks/embedding/mixins.py +++ b/pyannote/audio/tasks/embedding/mixins.py @@ -21,7 +21,7 @@ # SOFTWARE. import math -from typing import Dict, Optional, Sequence, Union +from typing import Dict, Sequence, Union import torch import torch.nn.functional as F @@ -75,7 +75,7 @@ def batch_size(self) -> int: def batch_size(self, batch_size: int): self.batch_size_ = batch_size - def setup(self, stage: Optional[str] = None): + def setup(self): # loop over the training set, remove annotated regions shorter than # chunk duration, and keep track of the reference annotations, per class. 
diff --git a/pyannote/audio/tasks/segmentation/mixins.py b/pyannote/audio/tasks/segmentation/mixins.py index 8dbb04488..016d2032d 100644 --- a/pyannote/audio/tasks/segmentation/mixins.py +++ b/pyannote/audio/tasks/segmentation/mixins.py @@ -26,7 +26,7 @@ import warnings from collections import defaultdict from functools import cached_property -from typing import Dict, Optional, Sequence, Union +from typing import Dict, Sequence, Union import matplotlib.pyplot as plt import numpy as np @@ -73,14 +73,8 @@ def get_file(self, file_id): return file - def setup(self, stage: Optional[str] = None): - """Setup method - - Parameters - ---------- - stage : {'fit', 'validate', 'test'}, optional - Setup stage. Defaults to 'fit'. - """ + def setup(self): + """Setup""" # duration of training chunks # TODO: handle variable duration case diff --git a/pyannote/audio/tasks/segmentation/multilabel.py b/pyannote/audio/tasks/segmentation/multilabel.py index 70b1e1fc8..3afdd6bda 100644 --- a/pyannote/audio/tasks/segmentation/multilabel.py +++ b/pyannote/audio/tasks/segmentation/multilabel.py @@ -119,8 +119,8 @@ def __init__( # classes should be detected. therefore, we postpone the definition of # specifications to setup() - def setup(self, stage: Optional[str] = None): - super().setup(stage=stage) + def setup(self): + super().setup() self.specifications = Specifications( classes=self.classes, diff --git a/pyannote/audio/tasks/segmentation/speaker_diarization.py b/pyannote/audio/tasks/segmentation/speaker_diarization.py index dc3a33025..b5ffe3da1 100644 --- a/pyannote/audio/tasks/segmentation/speaker_diarization.py +++ b/pyannote/audio/tasks/segmentation/speaker_diarization.py @@ -23,7 +23,7 @@ import math import warnings from collections import Counter -from typing import Dict, Literal, Optional, Sequence, Text, Tuple, Union +from typing import Dict, Literal, Sequence, Text, Tuple, Union import numpy as np import torch @@ -186,8 +186,8 @@ def __init__( self.weight = weight self.vad_loss = vad_loss - def setup(self, stage: Optional[str] = None): - super().setup(stage=stage) + def setup(self): + super().setup() # estimate maximum number of speakers per chunk when not provided if self.max_speakers_per_chunk is None: From d19a728a0c157491a50eedc31035957d75ef3d31 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Herv=C3=A9=20BREDIN?= Date: Wed, 24 May 2023 17:17:08 +0200 Subject: [PATCH 07/55] fix(test): force test training to run on CPU --- tests/inference_test.py | 9 +- tests/test_train.py | 44 ++++----- tutorials/add_your_own_task.ipynb | 100 ++++++++++---------- tutorials/overlapped_speech_detection.ipynb | 18 +++- 4 files changed, 92 insertions(+), 79 deletions(-) diff --git a/tests/inference_test.py b/tests/inference_test.py index 807f94cc1..bd5040394 100644 --- a/tests/inference_test.py +++ b/tests/inference_test.py @@ -1,13 +1,13 @@ import numpy as np import pytest import pytorch_lightning as pl +from pyannote.core import SlidingWindowFeature +from pyannote.database import FileFinder, get_protocol from pyannote.audio import Inference, Model from pyannote.audio.core.task import Resolution from pyannote.audio.models.segmentation.debug import SimpleSegmentationModel from pyannote.audio.tasks import VoiceActivityDetection -from pyannote.core import SlidingWindowFeature -from pyannote.database import FileFinder, get_protocol HF_SAMPLE_MODEL_ID = "pyannote/TestModelForContinuousIntegration" @@ -29,8 +29,8 @@ def trained(): ) vad = VoiceActivityDetection(protocol, duration=2.0, batch_size=16, num_workers=4) model = 
SimpleSegmentationModel(task=vad) - trainer = pl.Trainer(fast_dev_run=True) - trainer.fit(model, vad) + trainer = pl.Trainer(fast_dev_run=True, accelerator="cpu") + trainer.fit(model) return protocol, model @@ -91,7 +91,6 @@ def test_on_file_path(trained): def test_skip_aggregation(pretrained_model, dev_file): - inference = Inference(pretrained_model, skip_aggregation=True) scores = inference(dev_file) assert len(scores.data.shape) == 3 diff --git a/tests/test_train.py b/tests/test_train.py index 79e7f071a..7a7bfe338 100644 --- a/tests/test_train.py +++ b/tests/test_train.py @@ -20,125 +20,119 @@ def protocol(): def test_train_segmentation(protocol): segmentation = SpeakerDiarization(protocol) model = SimpleSegmentationModel(task=segmentation) - trainer = Trainer(fast_dev_run=True) + trainer = Trainer(fast_dev_run=True, accelerator="cpu") trainer.fit(model) def test_train_voice_activity_detection(protocol): voice_activity_detection = VoiceActivityDetection(protocol) model = SimpleSegmentationModel(task=voice_activity_detection) - trainer = Trainer(fast_dev_run=True) + trainer = Trainer(fast_dev_run=True, accelerator="cpu") trainer.fit(model) def test_train_overlapped_speech_detection(protocol): overlapped_speech_detection = OverlappedSpeechDetection(protocol) model = SimpleSegmentationModel(task=overlapped_speech_detection) - trainer = Trainer(fast_dev_run=True) + trainer = Trainer(fast_dev_run=True, accelerator="cpu") trainer.fit(model) def test_finetune_with_task_that_does_not_need_setup_for_specs(protocol): voice_activity_detection = VoiceActivityDetection(protocol) model = SimpleSegmentationModel(task=voice_activity_detection) - trainer = Trainer(fast_dev_run=True) + trainer = Trainer(fast_dev_run=True, accelerator="cpu") trainer.fit(model) voice_activity_detection = VoiceActivityDetection(protocol) model.task = voice_activity_detection - trainer = Trainer(fast_dev_run=True) + trainer = Trainer(fast_dev_run=True, accelerator="cpu") trainer.fit(model) def test_finetune_with_task_that_needs_setup_for_specs(protocol): segmentation = SpeakerDiarization(protocol) model = SimpleSegmentationModel(task=segmentation) - trainer = Trainer(fast_dev_run=True) + trainer = Trainer(fast_dev_run=True, accelerator="cpu") trainer.fit(model) segmentation = SpeakerDiarization(protocol) model.task = segmentation - trainer = Trainer(fast_dev_run=True) + trainer = Trainer(fast_dev_run=True, accelerator="cpu") trainer.fit(model) def test_transfer_with_task_that_does_not_need_setup_for_specs(protocol): - segmentation = SpeakerDiarization(protocol) model = SimpleSegmentationModel(task=segmentation) - trainer = Trainer(fast_dev_run=True) + trainer = Trainer(fast_dev_run=True, accelerator="cpu") trainer.fit(model) voice_activity_detection = VoiceActivityDetection(protocol) model.task = voice_activity_detection - trainer = Trainer(fast_dev_run=True) + trainer = Trainer(fast_dev_run=True, accelerator="cpu") trainer.fit(model) def test_transfer_with_task_that_needs_setup_for_specs(protocol): - voice_activity_detection = VoiceActivityDetection(protocol) model = SimpleSegmentationModel(task=voice_activity_detection) - trainer = Trainer(fast_dev_run=True) + trainer = Trainer(fast_dev_run=True, accelerator="cpu") trainer.fit(model) segmentation = SpeakerDiarization(protocol) model.task = segmentation - trainer = Trainer(fast_dev_run=True) + trainer = Trainer(fast_dev_run=True, accelerator="cpu") trainer.fit(model) def test_finetune_freeze_with_task_that_needs_setup_for_specs(protocol): - segmentation = 
SpeakerDiarization(protocol) model = SimpleSegmentationModel(task=segmentation) - trainer = Trainer(fast_dev_run=True) + trainer = Trainer(fast_dev_run=True, accelerator="cpu") trainer.fit(model) segmentation = SpeakerDiarization(protocol) model.task = segmentation model.freeze_up_to("mfcc") - trainer = Trainer(fast_dev_run=True) + trainer = Trainer(fast_dev_run=True, accelerator="cpu") trainer.fit(model) def test_finetune_freeze_with_task_that_does_not_need_setup_for_specs(protocol): - vad = VoiceActivityDetection(protocol) model = SimpleSegmentationModel(task=vad) - trainer = Trainer(fast_dev_run=True) + trainer = Trainer(fast_dev_run=True, accelerator="cpu") trainer.fit(model) vad = VoiceActivityDetection(protocol) model.task = vad model.freeze_up_to("mfcc") - trainer = Trainer(fast_dev_run=True) + trainer = Trainer(fast_dev_run=True, accelerator="cpu") trainer.fit(model) def test_transfer_freeze_with_task_that_does_not_need_setup_for_specs(protocol): - segmentation = SpeakerDiarization(protocol) model = SimpleSegmentationModel(task=segmentation) - trainer = Trainer(fast_dev_run=True) + trainer = Trainer(fast_dev_run=True, accelerator="cpu") trainer.fit(model) voice_activity_detection = VoiceActivityDetection(protocol) model.task = voice_activity_detection model.freeze_up_to("mfcc") - trainer = Trainer(fast_dev_run=True) + trainer = Trainer(fast_dev_run=True, accelerator="cpu") trainer.fit(model) def test_transfer_freeze_with_task_that_needs_setup_for_specs(protocol): - voice_activity_detection = VoiceActivityDetection(protocol) model = SimpleSegmentationModel(task=voice_activity_detection) - trainer = Trainer(fast_dev_run=True) + trainer = Trainer(fast_dev_run=True, accelerator="cpu") trainer.fit(model) segmentation = SpeakerDiarization(protocol) model.task = segmentation model.freeze_up_to("mfcc") - trainer = Trainer(fast_dev_run=True) + trainer = Trainer(fast_dev_run=True, accelerator="cpu") trainer.fit(model) diff --git a/tutorials/add_your_own_task.ipynb b/tutorials/add_your_own_task.ipynb index b2053f459..251846957 100644 --- a/tutorials/add_your_own_task.ipynb +++ b/tutorials/add_your_own_task.ipynb @@ -1,6 +1,7 @@ { "cells": [ { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -32,6 +33,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -48,6 +50,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -57,6 +60,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -82,6 +86,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -125,6 +130,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -176,54 +182,52 @@ " augmentation=augmentation,\n", " )\n", "\n", - " def setup(self, stage=None):\n", - "\n", - " if stage == \"fit\":\n", - "\n", - " # load metadata for training subset\n", - " self.train_metadata_ = list()\n", - " for training_file in self.protocol.train():\n", - " self.training_metadata_.append({\n", - " # path to audio file (str)\n", - " \"audio\": training_file[\"audio\"],\n", - " # duration of audio file (float)\n", - " \"duration\": training_file[\"duration\"],\n", - " # reference annotation (pyannote.core.Annotation)\n", - " \"annotation\": training_file[\"annotation\"],\n", - " })\n", - "\n", - " # gather the list of classes\n", - " classes = set()\n", - " for training_file in self.train_metadata_:\n", - " classes.update(training_file[\"reference\"].labels())\n", - 
" classes = sorted(classes)\n", - "\n", - " # specify the addressed problem\n", - " self.specifications = Specifications(\n", - " # it is a multi-label classification problem\n", - " problem=Problem.MULTI_LABEL_CLASSIFICATION,\n", - " # we expect the model to output one prediction \n", - " # for the whole chunk\n", - " resolution=Resolution.CHUNK,\n", - " # the model will ingest chunks with that duration (in seconds)\n", - " duration=self.duration,\n", - " # human-readable names of classes\n", - " classes=classes)\n", - "\n", - " # `has_validation` is True iff protocol defines a development set\n", - " if not self.has_validation:\n", - " return\n", - "\n", - " # load metadata for validation subset\n", - " self.validation_metadata_ = list()\n", - " for validation_file in self.protocol.development():\n", - " self.validation_metadata_.append({\n", - " \"audio\": validation_file[\"audio\"],\n", - " \"num_samples\": math.floor(validation_file[\"duration\"] / self.duration),\n", - " \"annotation\": validation_file[\"annotation\"],\n", - " })\n", - " \n", - " \n", + " def setup(self):\n", + "\n", + " # load metadata for training subset\n", + " self.train_metadata_ = list()\n", + " for training_file in self.protocol.train():\n", + " self.training_metadata_.append({\n", + " # path to audio file (str)\n", + " \"audio\": training_file[\"audio\"],\n", + " # duration of audio file (float)\n", + " \"duration\": training_file[\"duration\"],\n", + " # reference annotation (pyannote.core.Annotation)\n", + " \"annotation\": training_file[\"annotation\"],\n", + " })\n", + "\n", + " # gather the list of classes\n", + " classes = set()\n", + " for training_file in self.train_metadata_:\n", + " classes.update(training_file[\"reference\"].labels())\n", + " classes = sorted(classes)\n", + "\n", + " # specify the addressed problem\n", + " self.specifications = Specifications(\n", + " # it is a multi-label classification problem\n", + " problem=Problem.MULTI_LABEL_CLASSIFICATION,\n", + " # we expect the model to output one prediction \n", + " # for the whole chunk\n", + " resolution=Resolution.CHUNK,\n", + " # the model will ingest chunks with that duration (in seconds)\n", + " duration=self.duration,\n", + " # human-readable names of classes\n", + " classes=classes)\n", + "\n", + " # `has_validation` is True iff protocol defines a development set\n", + " if not self.has_validation:\n", + " return\n", + "\n", + " # load metadata for validation subset\n", + " self.validation_metadata_ = list()\n", + " for validation_file in self.protocol.development():\n", + " self.validation_metadata_.append({\n", + " \"audio\": validation_file[\"audio\"],\n", + " \"num_samples\": math.floor(validation_file[\"duration\"] / self.duration),\n", + " \"annotation\": validation_file[\"annotation\"],\n", + " })\n", + " \n", + " \n", "\n", " def train__iter__(self):\n", " # this method generates training samples, one at a time, \"ad infinitum\". 
each worker \n", diff --git a/tutorials/overlapped_speech_detection.ipynb b/tutorials/overlapped_speech_detection.ipynb index 78c6372cb..1ad5d4090 100644 --- a/tutorials/overlapped_speech_detection.ipynb +++ b/tutorials/overlapped_speech_detection.ipynb @@ -20,6 +20,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -39,6 +40,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -49,6 +51,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -84,6 +87,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -103,6 +107,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -110,6 +115,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -130,6 +136,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -147,6 +154,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -161,10 +169,11 @@ "source": [ "import pytorch_lightning as pl\n", "trainer = pl.Trainer(max_epochs=10)\n", - "trainer.fit(model, osd)" + "trainer.fit(model)" ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -185,6 +194,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -212,6 +222,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -219,6 +230,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -242,6 +254,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -258,6 +271,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -265,6 +279,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -297,6 +312,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ From 7ea9c9ac72e07f6afc3a103256e0fe7d640adffc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Herv=C3=A9=20BREDIN?= Date: Wed, 24 May 2023 17:24:31 +0200 Subject: [PATCH 08/55] fix(train): prevent metadata preparation to happen twice cc @clement-pages --- CHANGELOG.md | 1 + pyannote/audio/core/model.py | 5 +++-- pyannote/audio/core/task.py | 28 ++++++++++++++++++++++-- pyannote/audio/tasks/embedding/mixins.py | 3 --- 4 files changed, 30 insertions(+), 7 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 79a2e93ab..64eaea2ce 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -41,6 +41,7 @@ - fix(pipeline): fix support for IOBase audio - fix(pipeline): fix corner case with no speaker + - fix(train): prevent metadata preparation to happen twice ### Dependencies diff --git a/pyannote/audio/core/model.py b/pyannote/audio/core/model.py index 199dcfb24..2286f010f 100644 --- a/pyannote/audio/core/model.py +++ b/pyannote/audio/core/model.py @@ -101,8 +101,8 @@ def task(self) -> Task: @task.setter def task(self, task: Task): - self._task = task del self.specifications + self._task = task def build(self): # use this method to add task-dependent layers to the model @@ -225,7 +225,7 @@ def __output_dimension(example_output: torch.Tensor, **kwargs) -> int: def setup(self, stage=None): if stage == "fit": - self.task.setup() + self.task.setup_metadata() # list of layers before adding task-dependent layers before = set((name, id(module)) for name, module in self.named_modules()) @@ -311,6 +311,7 @@ def on_load_checkpoint(self, checkpoint: 
Dict[str, Any]): self.specifications = checkpoint["pyannote.audio"]["specifications"] + # add task-dependent (e.g. final classifier) layers self.setup() def forward( diff --git a/pyannote/audio/core/task.py b/pyannote/audio/core/task.py index de441590b..b7a626bd9 100644 --- a/pyannote/audio/core/task.py +++ b/pyannote/audio/core/task.py @@ -268,7 +268,28 @@ def prepare_data(self): """ pass - def setup(self): + @property + def specifications(self) -> Union[Specifications, Tuple[Specifications]]: + # setup metadata on-demand the first time specifications are requested and missing + if not hasattr(self, "_specifications"): + self.setup_metadata() + return self._specifications + + @specifications.setter + def specifications( + self, specifications: Union[Specifications, Tuple[Specifications]] + ): + self._specifications = specifications + + @property + def has_setup_metadata(self): + return getattr(self, "_has_setup_metadata", False) + + @has_setup_metadata.setter + def has_setup_metadata(self, value: bool): + self._has_setup_metadata = value + + def setup_metadata(self): """Called at the beginning of training at the very beginning of Model.setup(stage="fit") Notes @@ -278,7 +299,10 @@ def setup(self): If `specifications` attribute has not been set in `__init__`, `setup` is your last chance to set it. """ - pass + + if not self.has_setup_metadata: + self.setup() + self.has_setup_metadata = True def setup_loss_func(self): pass diff --git a/pyannote/audio/tasks/embedding/mixins.py b/pyannote/audio/tasks/embedding/mixins.py index d83009802..4bc51c9b5 100644 --- a/pyannote/audio/tasks/embedding/mixins.py +++ b/pyannote/audio/tasks/embedding/mixins.py @@ -79,9 +79,6 @@ def setup(self): # loop over the training set, remove annotated regions shorter than # chunk duration, and keep track of the reference annotations, per class. - # FIXME: it looks like this time consuming step is called multiple times. - # it should not be... 
- self._train = dict() desc = f"Loading {self.protocol.name} training labels" From 93ab70a33c228fd1bf8a50321ae8ae8b1cd8f0dd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Herv=C3=A9=20BREDIN?= Date: Wed, 24 May 2023 22:22:07 +0200 Subject: [PATCH 09/55] fix: tentative fix for CUDA error --- CHANGELOG.md | 4 +- pyannote/audio/core/inference.py | 4 +- pyannote/audio/core/model.py | 70 +++++++++---------- .../pipelines/overlapped_speech_detection.py | 2 +- pyannote/audio/pipelines/resegmentation.py | 2 +- .../audio/pipelines/speaker_diarization.py | 2 +- .../audio/pipelines/speaker_verification.py | 2 +- pyannote/audio/pipelines/utils/oracle.py | 2 +- pyannote/audio/tasks/segmentation/mixins.py | 11 --- .../audio/tasks/segmentation/multilabel.py | 12 ++-- .../overlapped_speech_detection.py | 10 +-- .../tasks/segmentation/speaker_diarization.py | 10 +-- .../segmentation/voice_activity_detection.py | 10 +-- 13 files changed, 68 insertions(+), 73 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 64eaea2ce..66a5a7f69 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,8 +20,8 @@ - BREAKING(model): get rid of (flaky) `Model.introspection` If, for some weird reason, you wrote some custom code based on that, you should instead rely on: * `Model.example_output(duration=...)` to get example output(s) - * `Model.output_frames` to get output frame resolution(s) - * `Model.output_dimension` to get output dimension(s) + * `Model.example_output.frames` to get output frame resolution(s) + * `Model.example_output.dimension` to get output dimension(s) ### Features and improvements diff --git a/pyannote/audio/core/inference.py b/pyannote/audio/core/inference.py index 98f72f6e9..a4847b83e 100644 --- a/pyannote/audio/core/inference.py +++ b/pyannote/audio/core/inference.py @@ -256,7 +256,7 @@ def slide( step_size: int = round(self.step * sample_rate) _, num_samples = waveform.shape - frames = self.model.output_frames + frames = self.model.example_output.frames def __frames( frames, specifications: Optional[Specifications] = None @@ -268,7 +268,7 @@ def __frames( frames: Union[SlidingWindow, Tuple[SlidingWindow]] = map_with_specifications( self.model.specifications, __frames, - self.model.output_frames, + self.model.example_output.frames, ) # prepare complete chunks diff --git a/pyannote/audio/core/model.py b/pyannote/audio/core/model.py index 2286f010f..5cb6c0e6b 100644 --- a/pyannote/audio/core/model.py +++ b/pyannote/audio/core/model.py @@ -24,6 +24,8 @@ import os import warnings +from dataclasses import dataclass +from functools import cached_property from importlib import import_module from pathlib import Path from typing import Any, Dict, List, Optional, Text, Tuple, Union @@ -65,6 +67,13 @@ class Introspection: pass +@dataclass +class Output: + num_frames: int + dimension: int + frames: SlidingWindow + + class Model(pl.LightningModule): """Base model @@ -101,7 +110,12 @@ def task(self) -> Task: @task.setter def task(self, task: Task): + # reset (cached) properties when task changes del self.specifications + try: + del self.example_output + except AttributeError: + pass self._task = task def build(self): @@ -173,54 +187,33 @@ def __example_input_array(self, duration: Optional[float] = None) -> torch.Tenso def example_input_array(self) -> torch.Tensor: return self.__example_input_array() - def example_output( - self, duration: Optional[float] = None - ) -> Union[torch.Tensor, Tuple[torch.Tensor]]: + @cached_property + def example_output(self) -> Union[Output, Tuple[Output]]: """Example output""" - 
example_input_array = self.__example_input_array(duration=duration) + example_input_array = self.__example_input_array() with torch.inference_mode(): example_output = self(example_input_array) - if not isinstance(example_output, (torch.Tensor, tuple)): - raise ValueError( - "Models must return either a torch.Tensor or a tuple of torch.Tensor" - ) - - return example_output - - @property - def output_frames( - self, - ) -> Union[Optional[SlidingWindow], Tuple[Optional[SlidingWindow]]]: - """Output frames as (tuple of) SlidingWindow(s)""" - - def __output_frames( + def __example_output( example_output: torch.Tensor, specifications: Specifications = None, - ) -> Optional[SlidingWindow]: + ) -> Output: + _, num_frames, dimension = example_output.shape + if specifications.resolution == Resolution.FRAME: - _, num_frames, _ = example_output.shape frame_duration = specifications.duration / num_frames - return SlidingWindow(step=frame_duration, duration=frame_duration) - - return None - - return map_with_specifications( - self.specifications, __output_frames, self.example_output() - ) - - @property - def output_dimension(self) -> Union[int, Tuple[int]]: - """Output dimension as (tuple of) int(s)""" - - duration = next(iter(self.specifications)).duration - example_output = self.example_output(duration=duration) + frames = SlidingWindow(step=frame_duration, duration=frame_duration) + else: + frames = None - def __output_dimension(example_output: torch.Tensor, **kwargs) -> int: - return example_output.shape[-1] + return Output( + num_frames=num_frames, + dimension=dimension, + frames=frames, + ) return map_with_specifications( - self.specifications, __output_dimension, example_output + self.specifications, __example_output, example_output ) def setup(self, stage=None): @@ -266,6 +259,9 @@ def setup(self, stage=None): # setup custom validation metrics self.task.setup_validation_metric() + # cache for later (and to avoid later CUDA error with multiprocessing) + _ = self.example_output + # list of layers after adding task-dependent layers after = set((name, id(module)) for name, module in self.named_modules()) diff --git a/pyannote/audio/pipelines/overlapped_speech_detection.py b/pyannote/audio/pipelines/overlapped_speech_detection.py index a326b8786..064cae1be 100644 --- a/pyannote/audio/pipelines/overlapped_speech_detection.py +++ b/pyannote/audio/pipelines/overlapped_speech_detection.py @@ -128,7 +128,7 @@ def __init__( # load model model = get_model(segmentation, use_auth_token=use_auth_token) - if model.output_dimension > 1: + if model.example_output.dimension > 1: inference_kwargs["pre_aggregation_hook"] = lambda scores: np.partition( scores, -2, axis=-1 )[:, :, -2, np.newaxis] diff --git a/pyannote/audio/pipelines/resegmentation.py b/pyannote/audio/pipelines/resegmentation.py index 468d5087d..bb71abf22 100644 --- a/pyannote/audio/pipelines/resegmentation.py +++ b/pyannote/audio/pipelines/resegmentation.py @@ -95,7 +95,7 @@ def __init__( model: Model = get_model(segmentation, use_auth_token=use_auth_token) self._segmentation = Inference(model) - self._frames = self._segmentation.model.output_frames + self._frames = self._segmentation.model.example_output.frames self._audio = model.audio diff --git a/pyannote/audio/pipelines/speaker_diarization.py b/pyannote/audio/pipelines/speaker_diarization.py index 038bd4676..8cf30f3b9 100644 --- a/pyannote/audio/pipelines/speaker_diarization.py +++ b/pyannote/audio/pipelines/speaker_diarization.py @@ -136,7 +136,7 @@ def __init__( skip_aggregation=True, 
batch_size=segmentation_batch_size, ) - self._frames: SlidingWindow = self._segmentation.model.output_frames + self._frames: SlidingWindow = self._segmentation.model.example_output.frames if self._segmentation.model.specifications.powerset: self.segmentation = ParamDict( diff --git a/pyannote/audio/pipelines/speaker_verification.py b/pyannote/audio/pipelines/speaker_verification.py index f928aabbb..a672e2017 100644 --- a/pyannote/audio/pipelines/speaker_verification.py +++ b/pyannote/audio/pipelines/speaker_verification.py @@ -419,7 +419,7 @@ def sample_rate(self) -> int: @cached_property def dimension(self) -> int: - return self.model_.output_dimension + return self.model_.example_output.dimension @cached_property def metric(self) -> str: diff --git a/pyannote/audio/pipelines/utils/oracle.py b/pyannote/audio/pipelines/utils/oracle.py index 0b6b58f85..44b4ded61 100644 --- a/pyannote/audio/pipelines/utils/oracle.py +++ b/pyannote/audio/pipelines/utils/oracle.py @@ -39,7 +39,7 @@ def oracle_segmentation( Simulates inference based on an (imaginary) oracle segmentation model: >>> oracle = Model.from_pretrained("oracle") - >>> assert frames == oracle.output_frames + >>> assert frames == oracle.example_output.frames >>> inference = Inference(oracle, duration=window.duration, step=window.step, skip_aggregation=True) >>> oracle_segmentation = inference(file) diff --git a/pyannote/audio/tasks/segmentation/mixins.py b/pyannote/audio/tasks/segmentation/mixins.py index 016d2032d..f071488ab 100644 --- a/pyannote/audio/tasks/segmentation/mixins.py +++ b/pyannote/audio/tasks/segmentation/mixins.py @@ -25,13 +25,11 @@ import random import warnings from collections import defaultdict -from functools import cached_property from typing import Dict, Sequence, Union import matplotlib.pyplot as plt import numpy as np import torch -from pyannote.core import SlidingWindow from pyannote.database.protocol import SegmentationProtocol, SpeakerDiarizationProtocol from pyannote.database.protocol.protocol import Scope, Subset from pytorch_lightning.loggers import MLFlowLogger, TensorBoardLogger @@ -386,15 +384,6 @@ def setup(self): dtype = [("file_id", "i"), ("start", "f"), ("duration", "f")] self.validation_chunks = np.array(validation_chunks, dtype=dtype) - @cached_property - def frames(self) -> SlidingWindow: - return self.model.output_frames - - @cached_property - def num_frames_per_chunk(self) -> int: - batch_size, num_frames, num_classes = self.model.example_output().shape - return num_frames - def default_metric( self, ) -> Union[Metric, Sequence[Metric], Dict[str, Metric]]: diff --git a/pyannote/audio/tasks/segmentation/multilabel.py b/pyannote/audio/tasks/segmentation/multilabel.py index 3afdd6bda..1917d2806 100644 --- a/pyannote/audio/tasks/segmentation/multilabel.py +++ b/pyannote/audio/tasks/segmentation/multilabel.py @@ -178,19 +178,23 @@ def prepare_chunk(self, file_id: int, start_time: float, duration: float): # discretize chunk annotations at model output resolution start = np.maximum(chunk_annotations["start"], chunk.start) - chunk.start - start_idx = np.floor(start / self.frames.step).astype(int) + start_idx = np.floor(start / self.model.example_output.frames.step).astype(int) end = np.minimum(chunk_annotations["end"], chunk.end) - chunk.start - end_idx = np.ceil(end / self.frames.step).astype(int) + end_idx = np.ceil(end / self.model.example_output.frames.step).astype(int) # frame-level targets (-1 for un-annotated classes) - y = -np.ones((self.num_frames_per_chunk, len(self.classes)), dtype=np.int8) 
+ y = -np.ones( + (self.model.example_output.num_frames, len(self.classes)), dtype=np.int8 + ) y[:, self.annotated_classes[file_id]] = 0 for start, end, label in zip( start_idx, end_idx, chunk_annotations["global_label_idx"] ): y[start:end, label] = 1 - sample["y"] = SlidingWindowFeature(y, self.frames, labels=self.classes) + sample["y"] = SlidingWindowFeature( + y, self.model.example_output.frames, labels=self.classes + ) metadata = self.metadata[file_id] sample["meta"] = {key: metadata[key] for key in metadata.dtype.names} diff --git a/pyannote/audio/tasks/segmentation/overlapped_speech_detection.py b/pyannote/audio/tasks/segmentation/overlapped_speech_detection.py index 492190938..cd3711d61 100644 --- a/pyannote/audio/tasks/segmentation/overlapped_speech_detection.py +++ b/pyannote/audio/tasks/segmentation/overlapped_speech_detection.py @@ -172,17 +172,19 @@ def prepare_chunk(self, file_id: int, start_time: float, duration: float): # discretize chunk annotations at model output resolution start = np.maximum(chunk_annotations["start"], chunk.start) - chunk.start - start_idx = np.floor(start / self.frames.step).astype(int) + start_idx = np.floor(start / self.model.example_output.frames.step).astype(int) end = np.minimum(chunk_annotations["end"], chunk.end) - chunk.start - end_idx = np.ceil(end / self.frames.step).astype(int) + end_idx = np.ceil(end / self.model.example_output.frames.step).astype(int) # frame-level targets - y = np.zeros((self.num_frames_per_chunk, 1), dtype=np.uint8) + y = np.zeros((self.model.example_output.num_frames, 1), dtype=np.uint8) for start, end in zip(start_idx, end_idx): y[start:end, 0] += 1 y = 1 * (y > 1) - sample["y"] = SlidingWindowFeature(y, self.frames, labels=["speech"]) + sample["y"] = SlidingWindowFeature( + y, self.model.example_output.frames, labels=["speech"] + ) metadata = self.metadata[file_id] sample["meta"] = {key: metadata[key] for key in metadata.dtype.names} diff --git a/pyannote/audio/tasks/segmentation/speaker_diarization.py b/pyannote/audio/tasks/segmentation/speaker_diarization.py index b5ffe3da1..27c4b2dc9 100644 --- a/pyannote/audio/tasks/segmentation/speaker_diarization.py +++ b/pyannote/audio/tasks/segmentation/speaker_diarization.py @@ -337,9 +337,9 @@ def prepare_chunk(self, file_id: int, start_time: float, duration: float): # discretize chunk annotations at model output resolution start = np.maximum(chunk_annotations["start"], chunk.start) - chunk.start - start_idx = np.floor(start / self.frames.step).astype(int) + start_idx = np.floor(start / self.model.example_output.frames.step).astype(int) end = np.minimum(chunk_annotations["end"], chunk.end) - chunk.start - end_idx = np.ceil(end / self.frames.step).astype(int) + end_idx = np.ceil(end / self.model.example_output.frames.step).astype(int) # get list and number of labels for current scope labels = list(np.unique(chunk_annotations[label_scope_key])) @@ -349,7 +349,7 @@ def prepare_chunk(self, file_id: int, start_time: float, duration: float): pass # initial frame-level targets - y = np.zeros((self.num_frames_per_chunk, num_labels), dtype=np.uint8) + y = np.zeros((self.model.example_output.num_frames, num_labels), dtype=np.uint8) # map labels to indices mapping = {label: idx for idx, label in enumerate(labels)} @@ -360,7 +360,9 @@ def prepare_chunk(self, file_id: int, start_time: float, duration: float): mapped_label = mapping[label] y[start:end, mapped_label] = 1 - sample["y"] = SlidingWindowFeature(y, self.frames, labels=labels) + sample["y"] = SlidingWindowFeature( + y, 
self.model.example_output.frames, labels=labels + ) metadata = self.metadata[file_id] sample["meta"] = {key: metadata[key] for key in metadata.dtype.names} diff --git a/pyannote/audio/tasks/segmentation/voice_activity_detection.py b/pyannote/audio/tasks/segmentation/voice_activity_detection.py index d740427f0..967ea1f9b 100644 --- a/pyannote/audio/tasks/segmentation/voice_activity_detection.py +++ b/pyannote/audio/tasks/segmentation/voice_activity_detection.py @@ -154,16 +154,18 @@ def prepare_chunk(self, file_id: int, start_time: float, duration: float): # discretize chunk annotations at model output resolution start = np.maximum(chunk_annotations["start"], chunk.start) - chunk.start - start_idx = np.floor(start / self.frames.step).astype(int) + start_idx = np.floor(start / self.model.example_output.frames.step).astype(int) end = np.minimum(chunk_annotations["end"], chunk.end) - chunk.start - end_idx = np.ceil(end / self.frames.step).astype(int) + end_idx = np.ceil(end / self.model.example_output.frames.step).astype(int) # frame-level targets - y = np.zeros((self.num_frames_per_chunk, 1), dtype=np.uint8) + y = np.zeros((self.model.example_output.num_frames, 1), dtype=np.uint8) for start, end in zip(start_idx, end_idx): y[start:end, 0] = 1 - sample["y"] = SlidingWindowFeature(y, self.frames, labels=["speech"]) + sample["y"] = SlidingWindowFeature( + y, self.model.example_output.frames, labels=["speech"] + ) metadata = self.metadata[file_id] sample["meta"] = {key: metadata[key] for key in metadata.dtype.names} From f4337641260fae097aea38409c11bf25123ceaed Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Herv=C3=A9=20BREDIN?= Date: Wed, 24 May 2023 22:39:10 +0200 Subject: [PATCH 10/55] doc: update changelog --- CHANGELOG.md | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 66a5a7f69..15e11405a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,15 +18,13 @@ * replace `Audio(mono=True)` by `Audio(mono="downmix")`; * replace `Audio(mono=False)` by `Audio()`. - BREAKING(model): get rid of (flaky) `Model.introspection` - If, for some weird reason, you wrote some custom code based on that, you should instead rely on: - * `Model.example_output(duration=...)` to get example output(s) - * `Model.example_output.frames` to get output frame resolution(s) - * `Model.example_output.dimension` to get output dimension(s) + If, for some weird reason, you wrote some custom code based on that, + you should instead rely on `Model.example_output`. 
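  (A minimal migration sketch in the repository's doctest style; the checkpoint
  name is a placeholder, and multi-task models return a tuple with one `Output`
  per task instead of a single object.)

  >>> from pyannote.audio import Model
  >>> model = Model.from_pretrained("...")        # placeholder checkpoint
  >>> model.example_output.dimension              # was: Model.output_dimension
  >>> model.example_output.frames                 # was: Model.output_frames
  >>> model.example_output.num_frames             # number of output frames per chunk
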
### Features and improvements - - feat(task): add support for multi-task models (for inference) + - feat(task): add support for multi-task models - feat(pipeline): send pipeline to device with `pipeline.to(device)` - feat(pipeline): make `segmentation_batch_size` and `embedding_batch_size` mutable in `SpeakerDiarization` pipeline (they now default to `1`) - feat(task): add [powerset](https://arxiv.org/PLACEHOLDER) support to `SpeakerDiarization` task From d76db0beb63e6a7ad86065ad5fa2f28511630ce2 Mon Sep 17 00:00:00 2001 From: Joonas Kalda Date: Fri, 10 Mar 2023 10:07:33 +0200 Subject: [PATCH 11/55] add convnet layers to PyanNet --- pyannote/audio/models/segmentation/PyanNet.py | 32 +++++++++++++++---- 1 file changed, 25 insertions(+), 7 deletions(-) diff --git a/pyannote/audio/models/segmentation/PyanNet.py b/pyannote/audio/models/segmentation/PyanNet.py index 5af3734b1..31809ef13 100644 --- a/pyannote/audio/models/segmentation/PyanNet.py +++ b/pyannote/audio/models/segmentation/PyanNet.py @@ -33,6 +33,7 @@ from pyannote.audio.core.task import Task from pyannote.audio.models.blocks.sincnet import SincNet from pyannote.audio.utils.params import merge_dict +from asteroid.masknn.convolutional import TDConvNet class PyanNet(Model): @@ -70,12 +71,24 @@ class PyanNet(Model): "dropout": 0.0, } LINEAR_DEFAULTS = {"hidden_size": 128, "num_layers": 2} + CONVNET_DEFAULTS = { + "n_src": 3, + "n_blocks":8, + "n_repeats":3, + "bn_chan":128, + "hid_chan":512, + "skip_chan":128, + "conv_kernel_size":3, + "norm_type":"gLN", + "mask_act":"relu", + } def __init__( self, sincnet: dict = None, lstm: dict = None, linear: dict = None, + convnet: dict = None, sample_rate: int = 16000, num_channels: int = 1, task: Optional[Task] = None, @@ -87,15 +100,18 @@ def __init__( lstm = merge_dict(self.LSTM_DEFAULTS, lstm) lstm["batch_first"] = True linear = merge_dict(self.LINEAR_DEFAULTS, linear) - self.save_hyperparameters("sincnet", "lstm", "linear") + convnet = merge_dict(self.CONVNET_DEFAULTS, convnet) + self.save_hyperparameters("sincnet", "lstm", "linear", "convnet") self.sincnet = SincNet(**self.hparams.sincnet) + self.convnet = TDConvNet(60, **self.hparams.convnet) + monolithic = lstm["monolithic"] if monolithic: multi_layer_lstm = dict(lstm) del multi_layer_lstm["monolithic"] - self.lstm = nn.LSTM(60, **multi_layer_lstm) + self.lstm = nn.LSTM(3*60, **multi_layer_lstm) else: num_layers = lstm["num_layers"] @@ -110,7 +126,7 @@ def __init__( self.lstm = nn.ModuleList( [ nn.LSTM( - 60 + 3*60 if i == 0 else lstm["hidden_size"] * (2 if lstm["bidirectional"] else 1), **one_layer_lstm @@ -170,13 +186,15 @@ def forward(self, waveforms: torch.Tensor) -> torch.Tensor: """ outputs = self.sincnet(waveforms) + outputs = self.convnet(outputs) + outputs = rearrange( + outputs, "batch nsrc nfilters nframes -> batch nframes nfilters nsrc" + ) + outputs = torch.flatten(outputs, start_dim=2, end_dim=3) if self.hparams.lstm["monolithic"]: - outputs, _ = self.lstm( - rearrange(outputs, "batch feature frame -> batch frame feature") - ) + outputs, _ = self.lstm(outputs) else: - outputs = rearrange(outputs, "batch feature frame -> batch frame feature") for i, lstm in enumerate(self.lstm): outputs, _ = lstm(outputs) if i + 1 < self.hparams.lstm["num_layers"]: From 3b2737ab6fd8b40b32b609762892c40dac013a27 Mon Sep 17 00:00:00 2001 From: Joonas Kalda Date: Wed, 15 Mar 2023 11:41:36 +0200 Subject: [PATCH 12/55] add stft and free encoders/decoder --- pyannote/audio/models/segmentation/PyanNet.py | 9 ++++++++- 1 file changed, 8 insertions(+), 
1 deletion(-) diff --git a/pyannote/audio/models/segmentation/PyanNet.py b/pyannote/audio/models/segmentation/PyanNet.py index 31809ef13..d07d5641b 100644 --- a/pyannote/audio/models/segmentation/PyanNet.py +++ b/pyannote/audio/models/segmentation/PyanNet.py @@ -104,6 +104,12 @@ def __init__( self.save_hyperparameters("sincnet", "lstm", "linear", "convnet") self.sincnet = SincNet(**self.hparams.sincnet) + # self.encoder, self.decoder = make_enc_dec( + # fb_name="free", kernel_size=16, n_filters=512, stride=8, sample_rate=16000 + # ) + self.encoder, self.decoder = make_enc_dec( + fb_name="stft", kernel_size=512, n_filters=512, stride=25, sample_rate=16000 + ) self.convnet = TDConvNet(60, **self.hparams.convnet) @@ -185,7 +191,8 @@ def forward(self, waveforms: torch.Tensor) -> torch.Tensor: scores : (batch, frame, classes) """ - outputs = self.sincnet(waveforms) + # outputs = self.sincnet(waveforms) + outputs = self.encoder(waveforms) outputs = self.convnet(outputs) outputs = rearrange( outputs, "batch nsrc nfilters nframes -> batch nframes nfilters nsrc" From f3a57906dff4075650dfd03c524bf5728233c015 Mon Sep 17 00:00:00 2001 From: Joonas Kalda Date: Wed, 15 Mar 2023 15:50:31 +0200 Subject: [PATCH 13/55] multitask learning first attempt --- pyannote/audio/models/segmentation/PyanNet.py | 69 +++++++++++-------- .../tasks/segmentation/speaker_diarization.py | 50 ++++++++++++-- 2 files changed, 86 insertions(+), 33 deletions(-) diff --git a/pyannote/audio/models/segmentation/PyanNet.py b/pyannote/audio/models/segmentation/PyanNet.py index d07d5641b..1b20c1de9 100644 --- a/pyannote/audio/models/segmentation/PyanNet.py +++ b/pyannote/audio/models/segmentation/PyanNet.py @@ -34,6 +34,8 @@ from pyannote.audio.models.blocks.sincnet import SincNet from pyannote.audio.utils.params import merge_dict from asteroid.masknn.convolutional import TDConvNet +from asteroid_filterbanks import make_enc_dec +from asteroid.utils.torch_utils import pad_x_to_y class PyanNet(Model): @@ -63,6 +65,12 @@ class PyanNet(Model): """ SINCNET_DEFAULTS = {"stride": 10} + ENCODER_DECODER_DEFAULTS = { + "fb_name": "stft", + "kernel_size": 512, + "n_filters": 512, + "stride": 256, + } LSTM_DEFAULTS = { "hidden_size": 128, "num_layers": 2, @@ -72,52 +80,55 @@ class PyanNet(Model): } LINEAR_DEFAULTS = {"hidden_size": 128, "num_layers": 2} CONVNET_DEFAULTS = { - "n_src": 3, - "n_blocks":8, - "n_repeats":3, - "bn_chan":128, - "hid_chan":512, - "skip_chan":128, - "conv_kernel_size":3, - "norm_type":"gLN", - "mask_act":"relu", + "n_src": 6, + "n_blocks": 8, + "n_repeats": 3, + "bn_chan": 128, + "hid_chan": 512, + "skip_chan": 128, + "conv_kernel_size": 3, + "norm_type": "gLN", + "mask_act": "relu", } def __init__( self, - sincnet: dict = None, + encoder_decoder: dict = None, lstm: dict = None, linear: dict = None, convnet: dict = None, + free_encoder: dict = None, + stft_encoder: dict = None, sample_rate: int = 16000, num_channels: int = 1, task: Optional[Task] = None, + encoder_type: str = None, ): super().__init__(sample_rate=sample_rate, num_channels=num_channels, task=task) - sincnet = merge_dict(self.SINCNET_DEFAULTS, sincnet) - sincnet["sample_rate"] = sample_rate lstm = merge_dict(self.LSTM_DEFAULTS, lstm) lstm["batch_first"] = True linear = merge_dict(self.LINEAR_DEFAULTS, linear) convnet = merge_dict(self.CONVNET_DEFAULTS, convnet) - self.save_hyperparameters("sincnet", "lstm", "linear", "convnet") + encoder_decoder = merge_dict(self.ENCODER_DECODER_DEFAULTS, encoder_decoder) + self.save_hyperparameters("encoder_decoder", 
"lstm", "linear", "convnet") - self.sincnet = SincNet(**self.hparams.sincnet) - # self.encoder, self.decoder = make_enc_dec( - # fb_name="free", kernel_size=16, n_filters=512, stride=8, sample_rate=16000 - # ) + if encoder_decoder["fb_name"] == "free": + n_feats_out = encoder_decoder["n_filters"] + elif encoder_decoder["fb_name"] == "stft": + n_feats_out = int(2 * (encoder_decoder["n_filters"] / 2 + 1)) + else: + raise ValueError("Filterbank type not recognized.") self.encoder, self.decoder = make_enc_dec( - fb_name="stft", kernel_size=512, n_filters=512, stride=25, sample_rate=16000 + sample_rate=sample_rate, **self.hparams.encoder_decoder ) - - self.convnet = TDConvNet(60, **self.hparams.convnet) + self.convnet = TDConvNet(n_feats_out, **self.hparams.convnet) monolithic = lstm["monolithic"] if monolithic: multi_layer_lstm = dict(lstm) del multi_layer_lstm["monolithic"] - self.lstm = nn.LSTM(3*60, **multi_layer_lstm) + self.lstm = nn.LSTM(6 * n_feats_out, **multi_layer_lstm) else: num_layers = lstm["num_layers"] @@ -132,7 +143,7 @@ def __init__( self.lstm = nn.ModuleList( [ nn.LSTM( - 3*60 + 6 * n_feats_out if i == 0 else lstm["hidden_size"] * (2 if lstm["bidirectional"] else 1), **one_layer_lstm @@ -191,11 +202,15 @@ def forward(self, waveforms: torch.Tensor) -> torch.Tensor: scores : (batch, frame, classes) """ - # outputs = self.sincnet(waveforms) - outputs = self.encoder(waveforms) - outputs = self.convnet(outputs) + tf_rep = self.encoder(waveforms) + masks = self.convnet(tf_rep) + + masked_tf_rep = masks * tf_rep.unsqueeze(1) + decoded_sources = self.decoder(masked_tf_rep) + decoded_sources = pad_x_to_y(decoded_sources, waveforms) + outputs = rearrange( - outputs, "batch nsrc nfilters nframes -> batch nframes nfilters nsrc" + masks, "batch nsrc nfilters nframes -> batch nframes nfilters nsrc" ) outputs = torch.flatten(outputs, start_dim=2, end_dim=3) @@ -211,4 +226,4 @@ def forward(self, waveforms: torch.Tensor) -> torch.Tensor: for linear in self.linear: outputs = F.leaky_relu(linear(outputs)) - return self.activation(self.classifier(outputs)) + return self.activation(self.classifier(outputs)), decoded_sources diff --git a/pyannote/audio/tasks/segmentation/speaker_diarization.py b/pyannote/audio/tasks/segmentation/speaker_diarization.py index eac795a47..0ab4a0bdf 100644 --- a/pyannote/audio/tasks/segmentation/speaker_diarization.py +++ b/pyannote/audio/tasks/segmentation/speaker_diarization.py @@ -13,7 +13,7 @@ # copies or substantial portions of the Software. # # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIESOF MERCHANTABILITY, # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, @@ -53,6 +53,7 @@ from pyannote.audio.utils.loss import binary_cross_entropy, mse_loss, nll_loss from pyannote.audio.utils.permutation import permutate from pyannote.audio.utils.powerset import Powerset +from asteroid.losses import MixITLossWrapper, multisrc_neg_sisdr Subsets = list(Subset.__args__) Scopes = list(Scope.__args__) @@ -185,6 +186,7 @@ def __init__( self.balance = balance self.weight = weight self.vad_loss = vad_loss + self.separation_loss = MixITLossWrapper(multisrc_neg_sisdr, generalized=True) def setup(self): super().setup() @@ -416,6 +418,11 @@ def collate_y(self, batch) -> torch.Tensor: return torch.from_numpy(np.stack(collated_y)) + # def separation_loss(self, prediction, target): + # mixit_loss = MixITLossWrapper(multisrc_neg_sisdr, generalized=True) + # return mixit_loss + + def segmentation_loss( self, permutated_prediction: torch.Tensor, @@ -528,9 +535,16 @@ def training_step(self, batch, batch_idx: int): # corner case if not keep.any(): return None - + # TODO: pair up waveforms for MIXIT + bsz = waveform.shape[0] + mix1 = waveform[bsz // 2 :].squeeze(1) + mix2 = waveform[: bsz // 2].squeeze(1) + moms = mix1 + mix2 # forward pass - prediction = self.model(waveform) + # TODO: model should output predictions for estimated sources as well + prediction, _ = self.model(waveform) + _, prediction_sources = self.model(moms) + batch_size, num_frames, _ = prediction.shape # (batch_size, num_frames, num_classes) @@ -564,6 +578,20 @@ def training_step(self, batch, batch_idx: int): permutated_prediction, target, weight=weight ) + # TODO: add also separation loss, warmup? 
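+        # MixIT: `prediction_sources` are the sources estimated from the
+        # mixture of mixtures (mix1 + mix2); the wrapper below searches over
+        # assignments of the estimated sources to {mix1, mix2}, sums each
+        # group, and scores the result with multi-source negative SI-SDR.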
+ mixit_loss = self.separation_loss( + prediction_sources, torch.stack((mix1, mix2)).transpose(0, 1) + ) + + self.model.log( + f"{self.logging_prefix}TrainSeparationLoss", + seg_loss, + on_step=False, + on_epoch=True, + prog_bar=False, + logger=True, + ) + self.model.log( "loss/train/segmentation", seg_loss, @@ -598,7 +626,7 @@ def training_step(self, batch, batch_idx: int): logger=True, ) - loss = seg_loss + vad_loss + loss = seg_loss + vad_loss + mixit_loss # skip batch if something went wrong for some reason if torch.isnan(loss): @@ -659,8 +687,14 @@ def validation_step(self, batch, batch_idx: int): # waveform = waveform[keep] # target = target[keep] + bsz = waveform.shape[0] + mix1 = waveform[bsz // 2 :].squeeze(1) + mix2 = waveform[: bsz // 2].squeeze(1) + moms = mix1 + mix2 + # forward pass - prediction = self.model(waveform) + prediction, _ = self.model(waveform) + _, prediction_sources = self.model(moms) batch_size, num_frames, _ = prediction.shape # frames weight @@ -696,6 +730,10 @@ def validation_step(self, batch, batch_idx: int): permutated_prediction, target, weight=weight ) + mixit_loss = self.separation_loss( + prediction_sources, torch.stack((mix1, mix2)).transpose(0, 1) + ) + self.model.log( "loss/val/segmentation", seg_loss, @@ -730,7 +768,7 @@ def validation_step(self, batch, batch_idx: int): logger=True, ) - loss = seg_loss + vad_loss + loss = seg_loss + vad_loss + mixit_loss self.model.log( "loss/val", From 1b03cf4da9398b04c7f85a046d17bd8e56b857c1 Mon Sep 17 00:00:00 2001 From: Joonas Kalda Date: Wed, 15 Mar 2023 17:24:17 +0200 Subject: [PATCH 14/55] properly logging mixit loss in train/valid --- .../tasks/segmentation/speaker_diarization.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/pyannote/audio/tasks/segmentation/speaker_diarization.py b/pyannote/audio/tasks/segmentation/speaker_diarization.py index 0ab4a0bdf..c20bea085 100644 --- a/pyannote/audio/tasks/segmentation/speaker_diarization.py +++ b/pyannote/audio/tasks/segmentation/speaker_diarization.py @@ -585,7 +585,7 @@ def training_step(self, batch, batch_idx: int): self.model.log( f"{self.logging_prefix}TrainSeparationLoss", - seg_loss, + mixit_loss, on_step=False, on_epoch=True, prog_bar=False, @@ -632,6 +632,7 @@ def training_step(self, batch, batch_idx: int): if torch.isnan(loss): return None + breakpoint() self.model.log( "loss/train", loss, @@ -735,7 +736,16 @@ def validation_step(self, batch, batch_idx: int): ) self.model.log( - "loss/val/segmentation", + f"{self.logging_prefix}ValSeparationLoss", + mixit_loss, + on_step=False, + on_epoch=True, + prog_bar=False, + logger=True, + ) + + self.model.log( + f"{self.logging_prefix}ValSegLoss", seg_loss, on_step=False, on_epoch=True, From 6294c4987bbc93cc5868f77f079a889068363eb1 Mon Sep 17 00:00:00 2001 From: Joonas Kalda Date: Thu, 16 Mar 2023 11:41:04 +0200 Subject: [PATCH 15/55] add a weight to mixit_loss --- .../audio/tasks/segmentation/speaker_diarization.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/pyannote/audio/tasks/segmentation/speaker_diarization.py b/pyannote/audio/tasks/segmentation/speaker_diarization.py index c20bea085..8b12b70ed 100644 --- a/pyannote/audio/tasks/segmentation/speaker_diarization.py +++ b/pyannote/audio/tasks/segmentation/speaker_diarization.py @@ -111,6 +111,8 @@ class SpeakerDiarization(SegmentationTaskMixin, Task): metric : optional Validation metric(s). Can be anything supported by torchmetrics.MetricCollection. 
Defaults to AUROC (area under the ROC curve). + mixit_loss_weight : float, optional + Factor that speaker separation loss is scaled by when calculating total loss. References ---------- @@ -143,6 +145,7 @@ def __init__( metric: Union[Metric, Sequence[Metric], Dict[str, Metric]] = None, max_num_speakers: int = None, # deprecated in favor of `max_speakers_per_chunk`` loss: Literal["bce", "mse"] = None, # deprecated + mixit_loss_weight: float = 0.2, ): super().__init__( protocol, @@ -187,6 +190,7 @@ def __init__( self.weight = weight self.vad_loss = vad_loss self.separation_loss = MixITLossWrapper(multisrc_neg_sisdr, generalized=True) + self.mixit_loss_weight = mixit_loss_weight def setup(self): super().setup() @@ -626,13 +630,12 @@ def training_step(self, batch, batch_idx: int): logger=True, ) - loss = seg_loss + vad_loss + mixit_loss + loss = seg_loss + vad_loss + self.mixit_loss_weight * mixit_loss # skip batch if something went wrong for some reason if torch.isnan(loss): return None - breakpoint() self.model.log( "loss/train", loss, @@ -743,7 +746,7 @@ def validation_step(self, batch, batch_idx: int): prog_bar=False, logger=True, ) - + self.model.log( f"{self.logging_prefix}ValSegLoss", seg_loss, @@ -778,7 +781,7 @@ def validation_step(self, batch, batch_idx: int): logger=True, ) - loss = seg_loss + vad_loss + mixit_loss + loss = seg_loss + vad_loss + self.mixit_loss_weight * mixit_loss self.model.log( "loss/val", From ed2930ecdf2c9b7de92ff1230f3d3fd41e70a24a Mon Sep 17 00:00:00 2001 From: joonaskalda Date: Mon, 27 Mar 2023 08:58:55 +0000 Subject: [PATCH 16/55] reformulate multitask loss --- pyannote/audio/tasks/segmentation/speaker_diarization.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyannote/audio/tasks/segmentation/speaker_diarization.py b/pyannote/audio/tasks/segmentation/speaker_diarization.py index 8b12b70ed..33fc1fab4 100644 --- a/pyannote/audio/tasks/segmentation/speaker_diarization.py +++ b/pyannote/audio/tasks/segmentation/speaker_diarization.py @@ -630,7 +630,7 @@ def training_step(self, batch, batch_idx: int): logger=True, ) - loss = seg_loss + vad_loss + self.mixit_loss_weight * mixit_loss + loss = (1 - self.mixit_loss_weight) * (seg_loss + vad_loss) + self.mixit_loss_weight * mixit_loss # skip batch if something went wrong for some reason if torch.isnan(loss): @@ -781,7 +781,7 @@ def validation_step(self, batch, batch_idx: int): logger=True, ) - loss = seg_loss + vad_loss + self.mixit_loss_weight * mixit_loss + loss = (1 - self.mixit_loss_weight) * (seg_loss + vad_loss) + self.mixit_loss_weight * mixit_loss self.model.log( "loss/val", From 54f2080d75c87ac53fe71ad8ede155211adc8ecd Mon Sep 17 00:00:00 2001 From: joonaskalda Date: Sun, 16 Apr 2023 02:07:43 +0000 Subject: [PATCH 17/55] add dprnn --- pyannote/audio/models/segmentation/PyanNet.py | 20 ++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/pyannote/audio/models/segmentation/PyanNet.py b/pyannote/audio/models/segmentation/PyanNet.py index 1b20c1de9..eab93f021 100644 --- a/pyannote/audio/models/segmentation/PyanNet.py +++ b/pyannote/audio/models/segmentation/PyanNet.py @@ -36,6 +36,7 @@ from asteroid.masknn.convolutional import TDConvNet from asteroid_filterbanks import make_enc_dec from asteroid.utils.torch_utils import pad_x_to_y +from asteroid.masknn import DPRNN class PyanNet(Model): @@ -90,6 +91,16 @@ class PyanNet(Model): "norm_type": "gLN", "mask_act": "relu", } + DPRNN_DEFAULTS = { + "n_src": 6, + "n_repeats": 6, + "bn_chan": 128, + "hid_size": 128, + 
"chunk_size": 100, + "norm_type": "gLN", + "mask_act": "relu", + "rnn_type": "LSTM", + } def __init__( self, @@ -97,6 +108,7 @@ def __init__( lstm: dict = None, linear: dict = None, convnet: dict = None, + dprnn: dict = None, free_encoder: dict = None, stft_encoder: dict = None, sample_rate: int = 16000, @@ -110,8 +122,9 @@ def __init__( lstm["batch_first"] = True linear = merge_dict(self.LINEAR_DEFAULTS, linear) convnet = merge_dict(self.CONVNET_DEFAULTS, convnet) + dprnn = merge_dict(self.DPRNN_DEFAULTS, dprnn) encoder_decoder = merge_dict(self.ENCODER_DECODER_DEFAULTS, encoder_decoder) - self.save_hyperparameters("encoder_decoder", "lstm", "linear", "convnet") + self.save_hyperparameters("encoder_decoder", "lstm", "linear", "convnet", "dprnn") if encoder_decoder["fb_name"] == "free": n_feats_out = encoder_decoder["n_filters"] @@ -122,7 +135,8 @@ def __init__( self.encoder, self.decoder = make_enc_dec( sample_rate=sample_rate, **self.hparams.encoder_decoder ) - self.convnet = TDConvNet(n_feats_out, **self.hparams.convnet) + self.masker = DPRNN(n_feats_out, **self.hparams.dprnn) + #self.convnet= TDConvNet(n_feats_out, **self.hparams.convnet) monolithic = lstm["monolithic"] if monolithic: @@ -203,7 +217,7 @@ def forward(self, waveforms: torch.Tensor) -> torch.Tensor: """ tf_rep = self.encoder(waveforms) - masks = self.convnet(tf_rep) + masks = self.masker(tf_rep) masked_tf_rep = masks * tf_rep.unsqueeze(1) decoded_sources = self.decoder(masked_tf_rep) From c79109a025a1f3e85d677f35d856bbebc2978140 Mon Sep 17 00:00:00 2001 From: joonaskalda Date: Fri, 12 May 2023 08:49:04 +0000 Subject: [PATCH 18/55] pair mixtures from same file with no overlapping speakers --- pyannote/audio/tasks/segmentation/mixins.py | 72 ++++++++++++++++++- .../tasks/segmentation/speaker_diarization.py | 4 +- 2 files changed, 73 insertions(+), 3 deletions(-) diff --git a/pyannote/audio/tasks/segmentation/mixins.py b/pyannote/audio/tasks/segmentation/mixins.py index 142245ae8..4f9c0f594 100644 --- a/pyannote/audio/tasks/segmentation/mixins.py +++ b/pyannote/audio/tasks/segmentation/mixins.py @@ -435,6 +435,7 @@ def train__iter__helper(self, rng: random.Random, **filters): while True: # select one file at random (with probability proportional to its annotated duration) file_id = np.random.choice(file_ids, p=prob_annotated_duration) + annotations = self.annotations[np.where(self.annotations["file_id"] == file_id)[0]] # generate `num_chunks_per_file` chunks from this file for _ in range(num_chunks_per_file): @@ -457,7 +458,76 @@ def train__iter__helper(self, rng: random.Random, **filters): _, _, start, end = self.annotated_regions[annotated_region_index] start_time = rng.uniform(start, end - duration) - yield self.prepare_chunk(file_id, start_time, duration) + # find speakers that already appeared and all annotations that contain them + chunk_annotations = annotations[ + (annotations["start"] < start_time+duration) & (annotations["end"] > start_time) + ] + previous_speaker_labels = list(np.unique(chunk_annotations["file_label_idx"])) + repeated_speaker_annotations = annotations[np.isin(annotations["file_label_idx"], previous_speaker_labels)] + + if repeated_speaker_annotations.size == 0: + # if previous chunk has 0 speakers then just sample from all annotated regions again + first_chunk = self.prepare_chunk(file_id, start_time, duration) + first_chunk["meta"]["mixture_type"]="first_mixture" + yield first_chunk + + # selected one annotated region at random (with probability proportional to its duration) + 
annotated_region_index = np.random.choice( + annotated_region_indices, p=prob_annotated_regions_duration + ) + + # select one chunk at random in this annotated region + _, _, start, end = self.annotated_regions[annotated_region_index] + start_time = rng.uniform(start, end - duration) + + second_chunk = self.prepare_chunk(file_id, start_time, duration) + second_chunk["meta"]["mixture_type"]="second_mixture" + yield second_chunk + + else: + # merge segments that contain repeated speakers + merged_repeated_segments = [[repeated_speaker_annotations["start"][0],repeated_speaker_annotations["end"][0]]] + for _, start, end, _, _, _ in repeated_speaker_annotations: + previous = merged_repeated_segments[-1] + if start <= previous[1]: + previous[1] = max(previous[1], end) + else: + merged_repeated_segments.append([start, end]) + + # find segments that don't contain repeated speakers + segments_without_repeat = [] + current_region_index = 0 + previous_time = self.annotated_regions["start"][annotated_region_indices[0]] + for segment in merged_repeated_segments: + if segment[0] > self.annotated_regions["end"][annotated_region_indices[current_region_index]]: + current_region_index+=1 + previous_time = self.annotated_regions["start"][annotated_region_indices[current_region_index]] + + if segment[0] - previous_time > duration: + segments_without_repeat.append((previous_time, segment[0], segment[0] - previous_time)) + previous_time = segment[1] + + dtype = [("start", "f"), ("end", "f"),("duration", "f")] + segments_without_repeat = np.array(segments_without_repeat,dtype=dtype) + + if np.sum(segments_without_repeat["duration"]) != 0: + + # only yield chunks if it is possible to choose the second chunk so that yielded chunks are always paired + + first_chunk = self.prepare_chunk(file_id, start_time, duration) + first_chunk["meta"]["mixture_type"]="first_mixture" + yield first_chunk + + prob_segments_duration = segments_without_repeat["duration"] / np.sum(segments_without_repeat["duration"]) + segment = np.random.choice( + segments_without_repeat, p=prob_segments_duration + ) + + start, end, _ = segment + new_start_time = rng.uniform(start, end - duration) + second_chunk = self.prepare_chunk(file_id, new_start_time, duration) + second_chunk["meta"]["mixture_type"]="second_mixture" + yield second_chunk def train__iter__(self): """Iterate over training samples diff --git a/pyannote/audio/tasks/segmentation/speaker_diarization.py b/pyannote/audio/tasks/segmentation/speaker_diarization.py index 33fc1fab4..bfee14845 100644 --- a/pyannote/audio/tasks/segmentation/speaker_diarization.py +++ b/pyannote/audio/tasks/segmentation/speaker_diarization.py @@ -541,8 +541,8 @@ def training_step(self, batch, batch_idx: int): return None # TODO: pair up waveforms for MIXIT bsz = waveform.shape[0] - mix1 = waveform[bsz // 2 :].squeeze(1) - mix2 = waveform[: bsz // 2].squeeze(1) + mix1 = waveform[0::2].squeeze(1) + mix2 = waveform[1::2].squeeze(1) moms = mix1 + mix2 # forward pass # TODO: model should output predictions for estimated sources as well From b8ceceb525387b3bed366446df7b86e8a798582c Mon Sep 17 00:00:00 2001 From: joonaskalda Date: Fri, 12 May 2023 08:50:50 +0000 Subject: [PATCH 19/55] fix mixit loss for odd batch size in validation --- pyannote/audio/tasks/segmentation/speaker_diarization.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyannote/audio/tasks/segmentation/speaker_diarization.py b/pyannote/audio/tasks/segmentation/speaker_diarization.py index bfee14845..fdbf38da7 100644 --- 
a/pyannote/audio/tasks/segmentation/speaker_diarization.py +++ b/pyannote/audio/tasks/segmentation/speaker_diarization.py @@ -692,7 +692,7 @@ def validation_step(self, batch, batch_idx: int): # target = target[keep] bsz = waveform.shape[0] - mix1 = waveform[bsz // 2 :].squeeze(1) + mix1 = waveform[bsz // 2 : 2 * (bsz // 2)].squeeze(1) mix2 = waveform[: bsz // 2].squeeze(1) moms = mix1 + mix2 From 8155f043ebee2919aed10546874d6c7c5af9d75a Mon Sep 17 00:00:00 2001 From: joonaskalda Date: Fri, 12 May 2023 08:54:02 +0000 Subject: [PATCH 20/55] make the MoM part of the original batch --- pyannote/audio/tasks/segmentation/mixins.py | 23 +++++++++++++++++++ .../tasks/segmentation/speaker_diarization.py | 4 ++-- 2 files changed, 25 insertions(+), 2 deletions(-) diff --git a/pyannote/audio/tasks/segmentation/mixins.py b/pyannote/audio/tasks/segmentation/mixins.py index 4f9c0f594..b7b55c6ab 100644 --- a/pyannote/audio/tasks/segmentation/mixins.py +++ b/pyannote/audio/tasks/segmentation/mixins.py @@ -30,6 +30,7 @@ import matplotlib.pyplot as plt import numpy as np import torch +from pyannote.core import Segment, SlidingWindow, SlidingWindowFeature from pyannote.database.protocol import SegmentationProtocol, SpeakerDiarizationProtocol from pyannote.database.protocol.protocol import Scope, Subset from pytorch_lightning.loggers import MLFlowLogger, TensorBoardLogger @@ -484,6 +485,17 @@ def train__iter__helper(self, rng: random.Random, **filters): second_chunk["meta"]["mixture_type"]="second_mixture" yield second_chunk + # add previous two chunks to get a third one + third_chunk = dict() + third_chunk["X"] = first_chunk["X"] + second_chunk["X"] + third_chunk["meta"] = first_chunk["meta"].copy() + y = np.concatenate((first_chunk["y"].data, second_chunk["y"].data), axis=1) + frames = first_chunk["y"].sliding_window + labels = first_chunk["y"].labels + second_chunk["y"].labels + third_chunk["y"] = SlidingWindowFeature(y, frames, labels=labels) + third_chunk["meta"]["mixture_type"]="mom" + yield third_chunk + else: # merge segments that contain repeated speakers merged_repeated_segments = [[repeated_speaker_annotations["start"][0],repeated_speaker_annotations["end"][0]]] @@ -529,6 +541,17 @@ def train__iter__helper(self, rng: random.Random, **filters): second_chunk["meta"]["mixture_type"]="second_mixture" yield second_chunk + #add previous two chunks to get a third one + third_chunk = dict() + third_chunk["X"] = first_chunk["X"] + second_chunk["X"] + third_chunk["meta"] = first_chunk["meta"].copy() + y = np.concatenate((first_chunk["y"].data, second_chunk["y"].data), axis=1) + frames = first_chunk["y"].sliding_window + labels = first_chunk["y"].labels + second_chunk["y"].labels + third_chunk["y"] = SlidingWindowFeature(y, frames, labels=labels) + third_chunk["meta"]["mixture_type"]="mom" + yield third_chunk + def train__iter__(self): """Iterate over training samples diff --git a/pyannote/audio/tasks/segmentation/speaker_diarization.py b/pyannote/audio/tasks/segmentation/speaker_diarization.py index fdbf38da7..7fd2aa767 100644 --- a/pyannote/audio/tasks/segmentation/speaker_diarization.py +++ b/pyannote/audio/tasks/segmentation/speaker_diarization.py @@ -541,8 +541,8 @@ def training_step(self, batch, batch_idx: int): return None # TODO: pair up waveforms for MIXIT bsz = waveform.shape[0] - mix1 = waveform[0::2].squeeze(1) - mix2 = waveform[1::2].squeeze(1) + mix1 = waveform[0::3].squeeze(1) + mix2 = waveform[1::3].squeeze(1) moms = mix1 + mix2 # forward pass # TODO: model should output predictions for 
estimated sources as well From 77099d0ce5a8476776e0fb26e4aa8e363f67ef99 Mon Sep 17 00:00:00 2001 From: joonaskalda Date: Fri, 12 May 2023 11:04:44 +0000 Subject: [PATCH 21/55] clean up --- .../audio/tasks/segmentation/speaker_diarization.py | 13 +++---------- 1 file changed, 3 insertions(+), 10 deletions(-) diff --git a/pyannote/audio/tasks/segmentation/speaker_diarization.py b/pyannote/audio/tasks/segmentation/speaker_diarization.py index 7fd2aa767..395838750 100644 --- a/pyannote/audio/tasks/segmentation/speaker_diarization.py +++ b/pyannote/audio/tasks/segmentation/speaker_diarization.py @@ -13,7 +13,7 @@ # copies or substantial portions of the Software. # # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIESOF MERCHANTABILITY, +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, @@ -422,11 +422,6 @@ def collate_y(self, batch) -> torch.Tensor: return torch.from_numpy(np.stack(collated_y)) - # def separation_loss(self, prediction, target): - # mixit_loss = MixITLossWrapper(multisrc_neg_sisdr, generalized=True) - # return mixit_loss - - def segmentation_loss( self, permutated_prediction: torch.Tensor, @@ -539,13 +534,12 @@ def training_step(self, batch, batch_idx: int): # corner case if not keep.any(): return None - # TODO: pair up waveforms for MIXIT + + # forward pass bsz = waveform.shape[0] mix1 = waveform[0::3].squeeze(1) mix2 = waveform[1::3].squeeze(1) moms = mix1 + mix2 - # forward pass - # TODO: model should output predictions for estimated sources as well prediction, _ = self.model(waveform) _, prediction_sources = self.model(moms) @@ -582,7 +576,6 @@ def training_step(self, batch, batch_idx: int): permutated_prediction, target, weight=weight ) - # TODO: add also separation loss, warmup? 
mixit_loss = self.separation_loss( prediction_sources, torch.stack((mix1, mix2)).transpose(0, 1) ) From 04c9a9ce1a97e7c36b081d18a1da8b0da8c100a5 Mon Sep 17 00:00:00 2001 From: joonaskalda Date: Sun, 14 May 2023 13:57:43 +0000 Subject: [PATCH 22/55] check that BS is divisible by 3 --- pyannote/audio/tasks/segmentation/speaker_diarization.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pyannote/audio/tasks/segmentation/speaker_diarization.py b/pyannote/audio/tasks/segmentation/speaker_diarization.py index 395838750..c79ec2386 100644 --- a/pyannote/audio/tasks/segmentation/speaker_diarization.py +++ b/pyannote/audio/tasks/segmentation/speaker_diarization.py @@ -183,6 +183,9 @@ def __init__( "`vad_loss` cannot be used jointly with `max_speakers_per_frame`" ) + if batch_size % 3 != 0: + raise ValueError("`batch_size` must be divisible by 3 for mixtures of mixtures training") + self.max_speakers_per_chunk = max_speakers_per_chunk self.max_speakers_per_frame = max_speakers_per_frame self.weigh_by_cardinality = weigh_by_cardinality From 418ba770ec67d000c46aa1e4880f8b31e3cf3240 Mon Sep 17 00:00:00 2001 From: joonaskalda Date: Sun, 14 May 2023 13:59:56 +0000 Subject: [PATCH 23/55] don't use MoMs with more than 3 speakers --- pyannote/audio/tasks/segmentation/speaker_diarization.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/pyannote/audio/tasks/segmentation/speaker_diarization.py b/pyannote/audio/tasks/segmentation/speaker_diarization.py index c79ec2386..2882c93fa 100644 --- a/pyannote/audio/tasks/segmentation/speaker_diarization.py +++ b/pyannote/audio/tasks/segmentation/speaker_diarization.py @@ -543,9 +543,16 @@ def training_step(self, batch, batch_idx: int): mix1 = waveform[0::3].squeeze(1) mix2 = waveform[1::3].squeeze(1) moms = mix1 + mix2 - prediction, _ = self.model(waveform) _, prediction_sources = self.model(moms) + # don't use moms with more than max_speakers_per_chunk speakers for training speaker diarization + num_speakers: torch.Tensor = torch.sum(torch.any(target, dim=1), dim=1) + num_speakers[2::3] = num_speakers[::3] + num_speakers[1::3] + keep: torch.Tensor = num_speakers <= self.max_speakers_per_chunk + target = target[keep] + waveform = waveform[keep] + prediction, _ = self.model(waveform) + batch_size, num_frames, _ = prediction.shape # (batch_size, num_frames, num_classes) From 36b12bb09f1eb626df46fce81fbdb0443905bc80 Mon Sep 17 00:00:00 2001 From: joonaskalda Date: Mon, 22 May 2023 09:39:37 +0000 Subject: [PATCH 24/55] include original mixtures in separation branch training --- pyannote/audio/tasks/segmentation/mixins.py | 12 +++++ .../tasks/segmentation/speaker_diarization.py | 44 ++++++++++++++++--- 2 files changed, 49 insertions(+), 7 deletions(-) diff --git a/pyannote/audio/tasks/segmentation/mixins.py b/pyannote/audio/tasks/segmentation/mixins.py index b7b55c6ab..49d13a4b3 100644 --- a/pyannote/audio/tasks/segmentation/mixins.py +++ b/pyannote/audio/tasks/segmentation/mixins.py @@ -494,6 +494,9 @@ def train__iter__helper(self, rng: random.Random, **filters): labels = first_chunk["y"].labels + second_chunk["y"].labels third_chunk["y"] = SlidingWindowFeature(y, frames, labels=labels) third_chunk["meta"]["mixture_type"]="mom" + + # the whole mom should be used in the separation branch training + third_chunk["X_separation_mask"] = torch.ones_like(first_chunk["X_separation_mask"]) yield third_chunk else: @@ -550,6 +553,9 @@ def train__iter__helper(self, rng: random.Random, **filters): labels = first_chunk["y"].labels + 
second_chunk["y"].labels third_chunk["y"] = SlidingWindowFeature(y, frames, labels=labels) third_chunk["meta"]["mixture_type"]="mom" + + # the whole mom should be used in the separation branch training + third_chunk["X_separation_mask"] = torch.ones_like(first_chunk["X_separation_mask"]) yield third_chunk def train__iter__(self): @@ -601,6 +607,9 @@ def collate_y(self, batch) -> torch.Tensor: def collate_meta(self, batch) -> torch.Tensor: return default_collate([b["meta"] for b in batch]) + def collate_X_separation_mask(self, batch) -> torch.Tensor: + return default_collate([b["X_separation_mask"] for b in batch]) + def collate_fn(self, batch, stage="train"): """Collate function used for most segmentation tasks @@ -630,6 +639,8 @@ def collate_fn(self, batch, stage="train"): # collate metadata collated_meta = self.collate_meta(batch) + collated_X_separation_mask = self.collate_X_separation_mask(batch) + # apply augmentation (only in "train" stage) self.augmentation.train(mode=(stage == "train")) augmented = self.augmentation( @@ -642,6 +653,7 @@ def collate_fn(self, batch, stage="train"): "X": augmented.samples, "y": augmented.targets.squeeze(1), "meta": collated_meta, + "X_separation_mask" : collated_X_separation_mask } def train__len__(self): diff --git a/pyannote/audio/tasks/segmentation/speaker_diarization.py b/pyannote/audio/tasks/segmentation/speaker_diarization.py index 2882c93fa..4c72098d9 100644 --- a/pyannote/audio/tasks/segmentation/speaker_diarization.py +++ b/pyannote/audio/tasks/segmentation/speaker_diarization.py @@ -336,6 +336,11 @@ def prepare_chunk(self, file_id: int, start_time: float, duration: float): sample = dict() sample["X"], _ = self.model.audio.crop(file, chunk, duration=duration) + # use model introspection to predict how many frames it will output + # TODO: this should be cached + num_samples = sample["X"].shape[1] + resolution_samples = self.model.example_output.frames.step * self.model.example_output.num_frames / num_samples + # gather all annotations of current file annotations = self.annotations[self.annotations["file_id"] == file_id] @@ -344,11 +349,13 @@ def prepare_chunk(self, file_id: int, start_time: float, duration: float): (annotations["start"] < chunk.end) & (annotations["end"] > chunk.start) ] - # discretize chunk annotations at model output resolution + # discretize chunk annotations at model output and input resolutions start = np.maximum(chunk_annotations["start"], chunk.start) - chunk.start start_idx = np.floor(start / self.model.example_output.frames.step).astype(int) + start_idx_samples = np.floor(start / resolution_samples).astype(int) end = np.minimum(chunk_annotations["end"], chunk.end) - chunk.start end_idx = np.ceil(end / self.model.example_output.frames.step).astype(int) + end_idx_samples = np.floor(end / resolution_samples).astype(int) # get list and number of labels for current scope labels = list(np.unique(chunk_annotations[label_scope_key])) @@ -359,6 +366,7 @@ def prepare_chunk(self, file_id: int, start_time: float, duration: float): # initial frame-level targets y = np.zeros((self.model.example_output.num_frames, num_labels), dtype=np.uint8) + sample_level_labels = np.zeros((num_samples, num_labels), dtype=np.uint8) # map labels to indices mapping = {label: idx for idx, label in enumerate(labels)} @@ -372,7 +380,15 @@ def prepare_chunk(self, file_id: int, start_time: float, duration: float): sample["y"] = SlidingWindowFeature( y, self.model.example_output.frames, labels=labels ) + + for start, end, label in zip( + 
start_idx_samples, end_idx_samples, chunk_annotations[label_scope_key] + ): + mapped_label = mapping[label] + sample_level_labels[start:end, mapped_label] = 1 + # only frames with a single label should be used for mixit training + sample["X_separation_mask"] = torch.from_numpy(sample_level_labels.sum(axis=1) == 1) metadata = self.metadata[file_id] sample["meta"] = {key: metadata[key] for key in metadata.dtype.names} sample["meta"]["file"] = file_id @@ -540,10 +556,20 @@ def training_step(self, batch, batch_idx: int): # forward pass bsz = waveform.shape[0] + num_samples = waveform.shape[2] mix1 = waveform[0::3].squeeze(1) mix2 = waveform[1::3].squeeze(1) + # extract parts with only one speaker from original mixtures + mix1_masks = batch["X_separation_mask"][0::3] + mix2_masks = batch["X_separation_mask"][1::3] + mix1_masked = mix1 * mix1_masks + mix2_masked = mix2 * mix2_masks + moms = mix1 + mix2 - _, prediction_sources = self.model(moms) + + _, predicted_sources_mom = self.model(moms) + _, predicted_sources_mix1 = self.model(mix1) + _, predicted_sources_mix2 = self.model(mix2) # don't use moms with more than max_speakers_per_chunk speakers for training speaker diarization num_speakers: torch.Tensor = torch.sum(torch.any(target, dim=1), dim=1) @@ -585,10 +611,14 @@ def training_step(self, batch, batch_idx: int): seg_loss = self.segmentation_loss( permutated_prediction, target, weight=weight ) - + # contributions from original mixtures is weighed by the proportion of remaining frames mixit_loss = self.separation_loss( - prediction_sources, torch.stack((mix1, mix2)).transpose(0, 1) - ) + predicted_sources_mom, torch.stack((mix1, mix2)).transpose(0, 1) + ) + self.separation_loss( + predicted_sources_mix1, torch.stack((mix1_masked, torch.zeros_like(mix1))).transpose(0, 1) + ) * mix1_masks.sum() / num_samples / bsz * 3 + self.separation_loss( + predicted_sources_mix2, torch.stack((mix2_masked, torch.zeros_like(mix2))).transpose(0, 1) + ) * mix2_masks.sum() / num_samples / bsz * 3 self.model.log( f"{self.logging_prefix}TrainSeparationLoss", @@ -701,7 +731,7 @@ def validation_step(self, batch, batch_idx: int): # forward pass prediction, _ = self.model(waveform) - _, prediction_sources = self.model(moms) + _, predicted_sources_mom = self.model(moms) batch_size, num_frames, _ = prediction.shape # frames weight @@ -738,7 +768,7 @@ def validation_step(self, batch, batch_idx: int): ) mixit_loss = self.separation_loss( - prediction_sources, torch.stack((mix1, mix2)).transpose(0, 1) + predicted_sources_mom, torch.stack((mix1, mix2)).transpose(0, 1) ) self.model.log( From e7f656946d27583636fbd8973d026f47844d43b9 Mon Sep 17 00:00:00 2001 From: joonaskalda Date: Wed, 24 May 2023 15:21:19 +0000 Subject: [PATCH 25/55] matching the order of dimensions of branch outputs --- pyannote/audio/models/segmentation/PyanNet.py | 1 + pyannote/audio/tasks/segmentation/speaker_diarization.py | 8 ++++---- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/pyannote/audio/models/segmentation/PyanNet.py b/pyannote/audio/models/segmentation/PyanNet.py index eab93f021..6306cc229 100644 --- a/pyannote/audio/models/segmentation/PyanNet.py +++ b/pyannote/audio/models/segmentation/PyanNet.py @@ -222,6 +222,7 @@ def forward(self, waveforms: torch.Tensor) -> torch.Tensor: masked_tf_rep = masks * tf_rep.unsqueeze(1) decoded_sources = self.decoder(masked_tf_rep) decoded_sources = pad_x_to_y(decoded_sources, waveforms) + decoded_sources = decoded_sources.transpose(1, 2) outputs = rearrange( masks, "batch nsrc nfilters 
nframes -> batch nframes nfilters nsrc" diff --git a/pyannote/audio/tasks/segmentation/speaker_diarization.py b/pyannote/audio/tasks/segmentation/speaker_diarization.py index 4c72098d9..017779c04 100644 --- a/pyannote/audio/tasks/segmentation/speaker_diarization.py +++ b/pyannote/audio/tasks/segmentation/speaker_diarization.py @@ -613,11 +613,11 @@ def training_step(self, batch, batch_idx: int): ) # contributions from original mixtures is weighed by the proportion of remaining frames mixit_loss = self.separation_loss( - predicted_sources_mom, torch.stack((mix1, mix2)).transpose(0, 1) + predicted_sources_mom.transpose(1, 2), torch.stack((mix1, mix2)).transpose(0, 1) ) + self.separation_loss( - predicted_sources_mix1, torch.stack((mix1_masked, torch.zeros_like(mix1))).transpose(0, 1) + predicted_sources_mix1.transpose(1, 2), torch.stack((mix1_masked, torch.zeros_like(mix1))).transpose(0, 1) ) * mix1_masks.sum() / num_samples / bsz * 3 + self.separation_loss( - predicted_sources_mix2, torch.stack((mix2_masked, torch.zeros_like(mix2))).transpose(0, 1) + predicted_sources_mix2.transpose(1, 2), torch.stack((mix2_masked, torch.zeros_like(mix2))).transpose(0, 1) ) * mix2_masks.sum() / num_samples / bsz * 3 self.model.log( @@ -768,7 +768,7 @@ def validation_step(self, batch, batch_idx: int): ) mixit_loss = self.separation_loss( - predicted_sources_mom, torch.stack((mix1, mix2)).transpose(0, 1) + predicted_sources_mom.transpose(1, 2), torch.stack((mix1, mix2)).transpose(0, 1) ) self.model.log( From aa15687b422f5cbc304022c5bbb2a5f66676876c Mon Sep 17 00:00:00 2001 From: joonaskalda Date: Thu, 25 May 2023 08:59:20 +0000 Subject: [PATCH 26/55] make n_sources an argument for model constructor --- pyannote/audio/models/segmentation/PyanNet.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/pyannote/audio/models/segmentation/PyanNet.py b/pyannote/audio/models/segmentation/PyanNet.py index 6306cc229..a790e71d1 100644 --- a/pyannote/audio/models/segmentation/PyanNet.py +++ b/pyannote/audio/models/segmentation/PyanNet.py @@ -81,7 +81,6 @@ class PyanNet(Model): } LINEAR_DEFAULTS = {"hidden_size": 128, "num_layers": 2} CONVNET_DEFAULTS = { - "n_src": 6, "n_blocks": 8, "n_repeats": 3, "bn_chan": 128, @@ -92,7 +91,6 @@ class PyanNet(Model): "mask_act": "relu", } DPRNN_DEFAULTS = { - "n_src": 6, "n_repeats": 6, "bn_chan": 128, "hid_size": 128, @@ -115,6 +113,7 @@ def __init__( num_channels: int = 1, task: Optional[Task] = None, encoder_type: str = None, + n_sources: int = 6, ): super().__init__(sample_rate=sample_rate, num_channels=num_channels, task=task) @@ -135,14 +134,14 @@ def __init__( self.encoder, self.decoder = make_enc_dec( sample_rate=sample_rate, **self.hparams.encoder_decoder ) - self.masker = DPRNN(n_feats_out, **self.hparams.dprnn) + self.masker = DPRNN(n_feats_out, n_src=n_sources, **self.hparams.dprnn) #self.convnet= TDConvNet(n_feats_out, **self.hparams.convnet) monolithic = lstm["monolithic"] if monolithic: multi_layer_lstm = dict(lstm) del multi_layer_lstm["monolithic"] - self.lstm = nn.LSTM(6 * n_feats_out, **multi_layer_lstm) + self.lstm = nn.LSTM(n_sources * n_feats_out, **multi_layer_lstm) else: num_layers = lstm["num_layers"] From cbdeecaa54d78cd00edbab9943b334cdef076ce7 Mon Sep 17 00:00:00 2001 From: joonaskalda Date: Thu, 25 May 2023 09:00:17 +0000 Subject: [PATCH 27/55] changing LSTM default num_layers to 4 --- pyannote/audio/models/segmentation/PyanNet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/pyannote/audio/models/segmentation/PyanNet.py b/pyannote/audio/models/segmentation/PyanNet.py index a790e71d1..bb553b755 100644 --- a/pyannote/audio/models/segmentation/PyanNet.py +++ b/pyannote/audio/models/segmentation/PyanNet.py @@ -74,7 +74,7 @@ class PyanNet(Model): } LSTM_DEFAULTS = { "hidden_size": 128, - "num_layers": 2, + "num_layers": 4, "bidirectional": True, "monolithic": True, "dropout": 0.0, From 0a45c04be64ad6489f4ded536a81e15b3083de0d Mon Sep 17 00:00:00 2001 From: Joonas Kalda Date: Fri, 9 Jun 2023 10:03:36 +0200 Subject: [PATCH 28/55] create separate tasks and models --- pyannote/audio/models/segmentation/PyanNet.py | 84 +- .../audio/models/segmentation/SepDiarNet.py | 243 ++++ .../audio/models/segmentation/__init__.py | 3 +- pyannote/audio/tasks/__init__.py | 2 + pyannote/audio/tasks/segmentation/mixins.py | 107 +- .../tasks/segmentation/speaker_diarization.py | 96 +- .../speaker_separation_diarization.py | 1179 +++++++++++++++++ 7 files changed, 1448 insertions(+), 266 deletions(-) create mode 100644 pyannote/audio/models/segmentation/SepDiarNet.py create mode 100644 pyannote/audio/tasks/segmentation/speaker_separation_diarization.py diff --git a/pyannote/audio/models/segmentation/PyanNet.py b/pyannote/audio/models/segmentation/PyanNet.py index bb553b755..faf92e8b5 100644 --- a/pyannote/audio/models/segmentation/PyanNet.py +++ b/pyannote/audio/models/segmentation/PyanNet.py @@ -33,10 +33,6 @@ from pyannote.audio.core.task import Task from pyannote.audio.models.blocks.sincnet import SincNet from pyannote.audio.utils.params import merge_dict -from asteroid.masknn.convolutional import TDConvNet -from asteroid_filterbanks import make_enc_dec -from asteroid.utils.torch_utils import pad_x_to_y -from asteroid.masknn import DPRNN class PyanNet(Model): @@ -66,82 +62,40 @@ class PyanNet(Model): """ SINCNET_DEFAULTS = {"stride": 10} - ENCODER_DECODER_DEFAULTS = { - "fb_name": "stft", - "kernel_size": 512, - "n_filters": 512, - "stride": 256, - } LSTM_DEFAULTS = { "hidden_size": 128, - "num_layers": 4, + "num_layers": 2, "bidirectional": True, "monolithic": True, "dropout": 0.0, } LINEAR_DEFAULTS = {"hidden_size": 128, "num_layers": 2} - CONVNET_DEFAULTS = { - "n_blocks": 8, - "n_repeats": 3, - "bn_chan": 128, - "hid_chan": 512, - "skip_chan": 128, - "conv_kernel_size": 3, - "norm_type": "gLN", - "mask_act": "relu", - } - DPRNN_DEFAULTS = { - "n_repeats": 6, - "bn_chan": 128, - "hid_size": 128, - "chunk_size": 100, - "norm_type": "gLN", - "mask_act": "relu", - "rnn_type": "LSTM", - } def __init__( self, - encoder_decoder: dict = None, + sincnet: dict = None, lstm: dict = None, linear: dict = None, - convnet: dict = None, - dprnn: dict = None, - free_encoder: dict = None, - stft_encoder: dict = None, sample_rate: int = 16000, num_channels: int = 1, task: Optional[Task] = None, - encoder_type: str = None, - n_sources: int = 6, ): super().__init__(sample_rate=sample_rate, num_channels=num_channels, task=task) + sincnet = merge_dict(self.SINCNET_DEFAULTS, sincnet) + sincnet["sample_rate"] = sample_rate lstm = merge_dict(self.LSTM_DEFAULTS, lstm) lstm["batch_first"] = True linear = merge_dict(self.LINEAR_DEFAULTS, linear) - convnet = merge_dict(self.CONVNET_DEFAULTS, convnet) - dprnn = merge_dict(self.DPRNN_DEFAULTS, dprnn) - encoder_decoder = merge_dict(self.ENCODER_DECODER_DEFAULTS, encoder_decoder) - self.save_hyperparameters("encoder_decoder", "lstm", "linear", "convnet", "dprnn") - - if encoder_decoder["fb_name"] == "free": - n_feats_out = encoder_decoder["n_filters"] - elif 
encoder_decoder["fb_name"] == "stft": - n_feats_out = int(2 * (encoder_decoder["n_filters"] / 2 + 1)) - else: - raise ValueError("Filterbank type not recognized.") - self.encoder, self.decoder = make_enc_dec( - sample_rate=sample_rate, **self.hparams.encoder_decoder - ) - self.masker = DPRNN(n_feats_out, n_src=n_sources, **self.hparams.dprnn) - #self.convnet= TDConvNet(n_feats_out, **self.hparams.convnet) + self.save_hyperparameters("sincnet", "lstm", "linear") + + self.sincnet = SincNet(**self.hparams.sincnet) monolithic = lstm["monolithic"] if monolithic: multi_layer_lstm = dict(lstm) del multi_layer_lstm["monolithic"] - self.lstm = nn.LSTM(n_sources * n_feats_out, **multi_layer_lstm) + self.lstm = nn.LSTM(60, **multi_layer_lstm) else: num_layers = lstm["num_layers"] @@ -156,7 +110,7 @@ def __init__( self.lstm = nn.ModuleList( [ nn.LSTM( - 6 * n_feats_out + 60 if i == 0 else lstm["hidden_size"] * (2 if lstm["bidirectional"] else 1), **one_layer_lstm @@ -215,22 +169,14 @@ def forward(self, waveforms: torch.Tensor) -> torch.Tensor: scores : (batch, frame, classes) """ - tf_rep = self.encoder(waveforms) - masks = self.masker(tf_rep) - - masked_tf_rep = masks * tf_rep.unsqueeze(1) - decoded_sources = self.decoder(masked_tf_rep) - decoded_sources = pad_x_to_y(decoded_sources, waveforms) - decoded_sources = decoded_sources.transpose(1, 2) - - outputs = rearrange( - masks, "batch nsrc nfilters nframes -> batch nframes nfilters nsrc" - ) - outputs = torch.flatten(outputs, start_dim=2, end_dim=3) + outputs = self.sincnet(waveforms) if self.hparams.lstm["monolithic"]: - outputs, _ = self.lstm(outputs) + outputs, _ = self.lstm( + rearrange(outputs, "batch feature frame -> batch frame feature") + ) else: + outputs = rearrange(outputs, "batch feature frame -> batch frame feature") for i, lstm in enumerate(self.lstm): outputs, _ = lstm(outputs) if i + 1 < self.hparams.lstm["num_layers"]: @@ -240,4 +186,4 @@ def forward(self, waveforms: torch.Tensor) -> torch.Tensor: for linear in self.linear: outputs = F.leaky_relu(linear(outputs)) - return self.activation(self.classifier(outputs)), decoded_sources + return self.activation(self.classifier(outputs)) \ No newline at end of file diff --git a/pyannote/audio/models/segmentation/SepDiarNet.py b/pyannote/audio/models/segmentation/SepDiarNet.py new file mode 100644 index 000000000..bc44dbcd2 --- /dev/null +++ b/pyannote/audio/models/segmentation/SepDiarNet.py @@ -0,0 +1,243 @@ +# MIT License +# +# Copyright (c) 2020 CNRS +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+
+from typing import Optional
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange
+from pyannote.core.utils.generators import pairwise
+
+from pyannote.audio.core.model import Model
+from pyannote.audio.core.task import Task
+from pyannote.audio.models.blocks.sincnet import SincNet
+from pyannote.audio.utils.params import merge_dict
+from asteroid.masknn.convolutional import TDConvNet
+from asteroid_filterbanks import make_enc_dec
+from asteroid.utils.torch_utils import pad_x_to_y
+from asteroid.masknn import DPRNN
+
+
+class SepDiarNet(Model):
+    """SepDiarNet joint speaker diarization and separation model
+
+    Filterbank encoder > DPRNN masker > Decoder (separation branch)
+    DPRNN masks > LSTM > Feed forward > Classifier (diarization branch)
+
+    Parameters
+    ----------
+    sample_rate : int, optional
+        Audio sample rate. Defaults to 16kHz (16000).
+    num_channels : int, optional
+        Number of channels. Defaults to mono (1).
+    encoder_decoder : dict, optional
+        Keyword arguments passed to the filterbank encoder/decoder.
+        Defaults to a 512-filter STFT filterbank with kernel size 512 and stride 256.
+    dprnn : dict, optional
+        Keyword arguments passed to the DPRNN masker.
+    lstm : dict, optional
+        Keyword arguments passed to the LSTM layer.
+        Defaults to {"hidden_size": 128, "num_layers": 4, "bidirectional": True},
+        i.e. four bidirectional layers with 128 units each.
+        Set "monolithic" to False to split monolithic multi-layer LSTM into multiple mono-layer LSTMs.
+        This may prove useful for probing LSTM internals.
+    linear : dict, optional
+        Keyword arguments used to initialize linear layers.
+        Defaults to {"hidden_size": 128, "num_layers": 2},
+        i.e. two linear layers with 128 units each.
+    n_sources : int, optional
+        Number of separated sources. Defaults to 6.
+    """
+
+    SINCNET_DEFAULTS = {"stride": 10}
+    ENCODER_DECODER_DEFAULTS = {
+        "fb_name": "stft",
+        "kernel_size": 512,
+        "n_filters": 512,
+        "stride": 256,
+    }
+    LSTM_DEFAULTS = {
+        "hidden_size": 128,
+        "num_layers": 4,
+        "bidirectional": True,
+        "monolithic": True,
+        "dropout": 0.0,
+    }
+    LINEAR_DEFAULTS = {"hidden_size": 128, "num_layers": 2}
+    CONVNET_DEFAULTS = {
+        "n_blocks": 8,
+        "n_repeats": 3,
+        "bn_chan": 128,
+        "hid_chan": 512,
+        "skip_chan": 128,
+        "conv_kernel_size": 3,
+        "norm_type": "gLN",
+        "mask_act": "relu",
+    }
+    DPRNN_DEFAULTS = {
+        "n_repeats": 6,
+        "bn_chan": 128,
+        "hid_size": 128,
+        "chunk_size": 100,
+        "norm_type": "gLN",
+        "mask_act": "relu",
+        "rnn_type": "LSTM",
+    }
+
+    def __init__(
+        self,
+        encoder_decoder: dict = None,
+        lstm: dict = None,
+        linear: dict = None,
+        convnet: dict = None,
+        dprnn: dict = None,
+        free_encoder: dict = None,
+        stft_encoder: dict = None,
+        sample_rate: int = 16000,
+        num_channels: int = 1,
+        task: Optional[Task] = None,
+        encoder_type: str = None,
+        n_sources: int = 6,
+    ):
+        super().__init__(sample_rate=sample_rate, num_channels=num_channels, task=task)
+
+        lstm = merge_dict(self.LSTM_DEFAULTS, lstm)
+        lstm["batch_first"] = True
+        linear = merge_dict(self.LINEAR_DEFAULTS, linear)
+        convnet = merge_dict(self.CONVNET_DEFAULTS, convnet)
+        dprnn = merge_dict(self.DPRNN_DEFAULTS, dprnn)
+        encoder_decoder = merge_dict(self.ENCODER_DECODER_DEFAULTS, encoder_decoder)
+        self.save_hyperparameters("encoder_decoder", "lstm", "linear", "convnet", "dprnn")
+
+        if encoder_decoder["fb_name"] == "free":
+            n_feats_out = encoder_decoder["n_filters"]
+        elif encoder_decoder["fb_name"] == "stft":
+            n_feats_out = int(2 * (encoder_decoder["n_filters"] / 2 + 1))
+        else:
+            raise ValueError("Filterbank type not recognized.")
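+        # With the default STFT filterbank (n_filters=512), this gives
+        # n_feats_out = 2 * (512 / 2 + 1) = 514 mask features per frame, so the
+        # monolithic diarization LSTM below takes n_sources * 514 = 3084 inputs.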
+ self.encoder, self.decoder = make_enc_dec( + sample_rate=sample_rate, **self.hparams.encoder_decoder + ) + self.masker = DPRNN(n_feats_out, n_src=n_sources, **self.hparams.dprnn) + #self.convnet= TDConvNet(n_feats_out, **self.hparams.convnet) + + monolithic = lstm["monolithic"] + if monolithic: + multi_layer_lstm = dict(lstm) + del multi_layer_lstm["monolithic"] + self.lstm = nn.LSTM(n_sources * n_feats_out, **multi_layer_lstm) + + else: + num_layers = lstm["num_layers"] + if num_layers > 1: + self.dropout = nn.Dropout(p=lstm["dropout"]) + + one_layer_lstm = dict(lstm) + one_layer_lstm["num_layers"] = 1 + one_layer_lstm["dropout"] = 0.0 + del one_layer_lstm["monolithic"] + + self.lstm = nn.ModuleList( + [ + nn.LSTM( + 6 * n_feats_out + if i == 0 + else lstm["hidden_size"] * (2 if lstm["bidirectional"] else 1), + **one_layer_lstm + ) + for i in range(num_layers) + ] + ) + + if linear["num_layers"] < 1: + return + + lstm_out_features: int = self.hparams.lstm["hidden_size"] * ( + 2 if self.hparams.lstm["bidirectional"] else 1 + ) + self.linear = nn.ModuleList( + [ + nn.Linear(in_features, out_features) + for in_features, out_features in pairwise( + [ + lstm_out_features, + ] + + [self.hparams.linear["hidden_size"]] + * self.hparams.linear["num_layers"] + ) + ] + ) + + def build(self): + if self.hparams.linear["num_layers"] > 0: + in_features = self.hparams.linear["hidden_size"] + else: + in_features = self.hparams.lstm["hidden_size"] * ( + 2 if self.hparams.lstm["bidirectional"] else 1 + ) + + # if isinstance(self.specifications, tuple): + # raise ValueError("PyanNet does not support multi-tasking.") + + # if self.specifications.powerset: + out_features = self.specifications[0].num_powerset_classes + # else: + # out_features = len(self.specifications.classes) + + self.classifier = nn.Linear(in_features, out_features) + self.activation = self.default_activation() + + def forward(self, waveforms: torch.Tensor) -> torch.Tensor: + """Pass forward + + Parameters + ---------- + waveforms : (batch, channel, sample) + + Returns + ------- + scores : (batch, frame, classes) + """ + + tf_rep = self.encoder(waveforms) + masks = self.masker(tf_rep) + + masked_tf_rep = masks * tf_rep.unsqueeze(1) + decoded_sources = self.decoder(masked_tf_rep) + decoded_sources = pad_x_to_y(decoded_sources, waveforms) + decoded_sources = decoded_sources.transpose(1, 2) + + outputs = rearrange( + masks, "batch nsrc nfilters nframes -> batch nframes nfilters nsrc" + ) + outputs = torch.flatten(outputs, start_dim=2, end_dim=3) + + if self.hparams.lstm["monolithic"]: + outputs, _ = self.lstm(outputs) + else: + for i, lstm in enumerate(self.lstm): + outputs, _ = lstm(outputs) + if i + 1 < self.hparams.lstm["num_layers"]: + outputs = self.dropout(outputs) + + if self.hparams.linear["num_layers"] > 0: + for linear in self.linear: + outputs = F.leaky_relu(linear(outputs)) + + return self.activation[0](self.classifier(outputs)), decoded_sources diff --git a/pyannote/audio/models/segmentation/__init__.py b/pyannote/audio/models/segmentation/__init__.py index 82e149853..aa336b1c6 100644 --- a/pyannote/audio/models/segmentation/__init__.py +++ b/pyannote/audio/models/segmentation/__init__.py @@ -21,5 +21,6 @@ # SOFTWARE. 
from .PyanNet import PyanNet +from .SepDiarNet import SepDiarNet -__all__ = ["PyanNet"] +__all__ = ["PyanNet", "SepDiarNet"] diff --git a/pyannote/audio/tasks/__init__.py b/pyannote/audio/tasks/__init__.py index 6cbba258f..23be6a547 100644 --- a/pyannote/audio/tasks/__init__.py +++ b/pyannote/audio/tasks/__init__.py @@ -22,6 +22,7 @@ from .segmentation.multilabel import MultiLabelSegmentation # isort:skip from .segmentation.speaker_diarization import SpeakerDiarization # isort:skip +from .segmentation.speaker_separation_diarization import JointSpeakerSeparationAndDiarization # isort:skip from .segmentation.voice_activity_detection import VoiceActivityDetection # isort:skip from .segmentation.overlapped_speech_detection import ( # isort:skip OverlappedSpeechDetection, @@ -41,4 +42,5 @@ "MultiLabelSegmentation", "SpeakerEmbedding", "Segmentation", + "JointSpeakerSeparationAndDiarization", ] diff --git a/pyannote/audio/tasks/segmentation/mixins.py b/pyannote/audio/tasks/segmentation/mixins.py index 49d13a4b3..142245ae8 100644 --- a/pyannote/audio/tasks/segmentation/mixins.py +++ b/pyannote/audio/tasks/segmentation/mixins.py @@ -30,7 +30,6 @@ import matplotlib.pyplot as plt import numpy as np import torch -from pyannote.core import Segment, SlidingWindow, SlidingWindowFeature from pyannote.database.protocol import SegmentationProtocol, SpeakerDiarizationProtocol from pyannote.database.protocol.protocol import Scope, Subset from pytorch_lightning.loggers import MLFlowLogger, TensorBoardLogger @@ -436,7 +435,6 @@ def train__iter__helper(self, rng: random.Random, **filters): while True: # select one file at random (with probability proportional to its annotated duration) file_id = np.random.choice(file_ids, p=prob_annotated_duration) - annotations = self.annotations[np.where(self.annotations["file_id"] == file_id)[0]] # generate `num_chunks_per_file` chunks from this file for _ in range(num_chunks_per_file): @@ -459,104 +457,7 @@ def train__iter__helper(self, rng: random.Random, **filters): _, _, start, end = self.annotated_regions[annotated_region_index] start_time = rng.uniform(start, end - duration) - # find speakers that already appeared and all annotations that contain them - chunk_annotations = annotations[ - (annotations["start"] < start_time+duration) & (annotations["end"] > start_time) - ] - previous_speaker_labels = list(np.unique(chunk_annotations["file_label_idx"])) - repeated_speaker_annotations = annotations[np.isin(annotations["file_label_idx"], previous_speaker_labels)] - - if repeated_speaker_annotations.size == 0: - # if previous chunk has 0 speakers then just sample from all annotated regions again - first_chunk = self.prepare_chunk(file_id, start_time, duration) - first_chunk["meta"]["mixture_type"]="first_mixture" - yield first_chunk - - # selected one annotated region at random (with probability proportional to its duration) - annotated_region_index = np.random.choice( - annotated_region_indices, p=prob_annotated_regions_duration - ) - - # select one chunk at random in this annotated region - _, _, start, end = self.annotated_regions[annotated_region_index] - start_time = rng.uniform(start, end - duration) - - second_chunk = self.prepare_chunk(file_id, start_time, duration) - second_chunk["meta"]["mixture_type"]="second_mixture" - yield second_chunk - - # add previous two chunks to get a third one - third_chunk = dict() - third_chunk["X"] = first_chunk["X"] + second_chunk["X"] - third_chunk["meta"] = first_chunk["meta"].copy() - y = np.concatenate((first_chunk["y"].data, 
second_chunk["y"].data), axis=1) - frames = first_chunk["y"].sliding_window - labels = first_chunk["y"].labels + second_chunk["y"].labels - third_chunk["y"] = SlidingWindowFeature(y, frames, labels=labels) - third_chunk["meta"]["mixture_type"]="mom" - - # the whole mom should be used in the separation branch training - third_chunk["X_separation_mask"] = torch.ones_like(first_chunk["X_separation_mask"]) - yield third_chunk - - else: - # merge segments that contain repeated speakers - merged_repeated_segments = [[repeated_speaker_annotations["start"][0],repeated_speaker_annotations["end"][0]]] - for _, start, end, _, _, _ in repeated_speaker_annotations: - previous = merged_repeated_segments[-1] - if start <= previous[1]: - previous[1] = max(previous[1], end) - else: - merged_repeated_segments.append([start, end]) - - # find segments that don't contain repeated speakers - segments_without_repeat = [] - current_region_index = 0 - previous_time = self.annotated_regions["start"][annotated_region_indices[0]] - for segment in merged_repeated_segments: - if segment[0] > self.annotated_regions["end"][annotated_region_indices[current_region_index]]: - current_region_index+=1 - previous_time = self.annotated_regions["start"][annotated_region_indices[current_region_index]] - - if segment[0] - previous_time > duration: - segments_without_repeat.append((previous_time, segment[0], segment[0] - previous_time)) - previous_time = segment[1] - - dtype = [("start", "f"), ("end", "f"),("duration", "f")] - segments_without_repeat = np.array(segments_without_repeat,dtype=dtype) - - if np.sum(segments_without_repeat["duration"]) != 0: - - # only yield chunks if it is possible to choose the second chunk so that yielded chunks are always paired - - first_chunk = self.prepare_chunk(file_id, start_time, duration) - first_chunk["meta"]["mixture_type"]="first_mixture" - yield first_chunk - - prob_segments_duration = segments_without_repeat["duration"] / np.sum(segments_without_repeat["duration"]) - segment = np.random.choice( - segments_without_repeat, p=prob_segments_duration - ) - - start, end, _ = segment - new_start_time = rng.uniform(start, end - duration) - second_chunk = self.prepare_chunk(file_id, new_start_time, duration) - second_chunk["meta"]["mixture_type"]="second_mixture" - yield second_chunk - - #add previous two chunks to get a third one - third_chunk = dict() - third_chunk["X"] = first_chunk["X"] + second_chunk["X"] - third_chunk["meta"] = first_chunk["meta"].copy() - y = np.concatenate((first_chunk["y"].data, second_chunk["y"].data), axis=1) - frames = first_chunk["y"].sliding_window - labels = first_chunk["y"].labels + second_chunk["y"].labels - third_chunk["y"] = SlidingWindowFeature(y, frames, labels=labels) - third_chunk["meta"]["mixture_type"]="mom" - - # the whole mom should be used in the separation branch training - third_chunk["X_separation_mask"] = torch.ones_like(first_chunk["X_separation_mask"]) - yield third_chunk + yield self.prepare_chunk(file_id, start_time, duration) def train__iter__(self): """Iterate over training samples @@ -607,9 +508,6 @@ def collate_y(self, batch) -> torch.Tensor: def collate_meta(self, batch) -> torch.Tensor: return default_collate([b["meta"] for b in batch]) - def collate_X_separation_mask(self, batch) -> torch.Tensor: - return default_collate([b["X_separation_mask"] for b in batch]) - def collate_fn(self, batch, stage="train"): """Collate function used for most segmentation tasks @@ -639,8 +537,6 @@ def collate_fn(self, batch, stage="train"): # collate 
metadata collated_meta = self.collate_meta(batch) - collated_X_separation_mask = self.collate_X_separation_mask(batch) - # apply augmentation (only in "train" stage) self.augmentation.train(mode=(stage == "train")) augmented = self.augmentation( @@ -653,7 +549,6 @@ def collate_fn(self, batch, stage="train"): "X": augmented.samples, "y": augmented.targets.squeeze(1), "meta": collated_meta, - "X_separation_mask" : collated_X_separation_mask } def train__len__(self): diff --git a/pyannote/audio/tasks/segmentation/speaker_diarization.py b/pyannote/audio/tasks/segmentation/speaker_diarization.py index 017779c04..eac795a47 100644 --- a/pyannote/audio/tasks/segmentation/speaker_diarization.py +++ b/pyannote/audio/tasks/segmentation/speaker_diarization.py @@ -53,7 +53,6 @@ from pyannote.audio.utils.loss import binary_cross_entropy, mse_loss, nll_loss from pyannote.audio.utils.permutation import permutate from pyannote.audio.utils.powerset import Powerset -from asteroid.losses import MixITLossWrapper, multisrc_neg_sisdr Subsets = list(Subset.__args__) Scopes = list(Scope.__args__) @@ -111,8 +110,6 @@ class SpeakerDiarization(SegmentationTaskMixin, Task): metric : optional Validation metric(s). Can be anything supported by torchmetrics.MetricCollection. Defaults to AUROC (area under the ROC curve). - mixit_loss_weight : float, optional - Factor that speaker separation loss is scaled by when calculating total loss. References ---------- @@ -145,7 +142,6 @@ def __init__( metric: Union[Metric, Sequence[Metric], Dict[str, Metric]] = None, max_num_speakers: int = None, # deprecated in favor of `max_speakers_per_chunk`` loss: Literal["bce", "mse"] = None, # deprecated - mixit_loss_weight: float = 0.2, ): super().__init__( protocol, @@ -183,17 +179,12 @@ def __init__( "`vad_loss` cannot be used jointly with `max_speakers_per_frame`" ) - if batch_size % 3 != 0: - raise ValueError("`batch_size` must be divisible by 3 for mixtures of mixtures training") - self.max_speakers_per_chunk = max_speakers_per_chunk self.max_speakers_per_frame = max_speakers_per_frame self.weigh_by_cardinality = weigh_by_cardinality self.balance = balance self.weight = weight self.vad_loss = vad_loss - self.separation_loss = MixITLossWrapper(multisrc_neg_sisdr, generalized=True) - self.mixit_loss_weight = mixit_loss_weight def setup(self): super().setup() @@ -336,11 +327,6 @@ def prepare_chunk(self, file_id: int, start_time: float, duration: float): sample = dict() sample["X"], _ = self.model.audio.crop(file, chunk, duration=duration) - # use model introspection to predict how many frames it will output - # TODO: this should be cached - num_samples = sample["X"].shape[1] - resolution_samples = self.model.example_output.frames.step * self.model.example_output.num_frames / num_samples - # gather all annotations of current file annotations = self.annotations[self.annotations["file_id"] == file_id] @@ -349,13 +335,11 @@ def prepare_chunk(self, file_id: int, start_time: float, duration: float): (annotations["start"] < chunk.end) & (annotations["end"] > chunk.start) ] - # discretize chunk annotations at model output and input resolutions + # discretize chunk annotations at model output resolution start = np.maximum(chunk_annotations["start"], chunk.start) - chunk.start start_idx = np.floor(start / self.model.example_output.frames.step).astype(int) - start_idx_samples = np.floor(start / resolution_samples).astype(int) end = np.minimum(chunk_annotations["end"], chunk.end) - chunk.start end_idx = np.ceil(end / 
self.model.example_output.frames.step).astype(int) - end_idx_samples = np.floor(end / resolution_samples).astype(int) # get list and number of labels for current scope labels = list(np.unique(chunk_annotations[label_scope_key])) @@ -366,7 +350,6 @@ def prepare_chunk(self, file_id: int, start_time: float, duration: float): # initial frame-level targets y = np.zeros((self.model.example_output.num_frames, num_labels), dtype=np.uint8) - sample_level_labels = np.zeros((num_samples, num_labels), dtype=np.uint8) # map labels to indices mapping = {label: idx for idx, label in enumerate(labels)} @@ -380,15 +363,7 @@ def prepare_chunk(self, file_id: int, start_time: float, duration: float): sample["y"] = SlidingWindowFeature( y, self.model.example_output.frames, labels=labels ) - - for start, end, label in zip( - start_idx_samples, end_idx_samples, chunk_annotations[label_scope_key] - ): - mapped_label = mapping[label] - sample_level_labels[start:end, mapped_label] = 1 - # only frames with a single label should be used for mixit training - sample["X_separation_mask"] = torch.from_numpy(sample_level_labels.sum(axis=1) == 1) metadata = self.metadata[file_id] sample["meta"] = {key: metadata[key] for key in metadata.dtype.names} sample["meta"]["file"] = file_id @@ -555,30 +530,7 @@ def training_step(self, batch, batch_idx: int): return None # forward pass - bsz = waveform.shape[0] - num_samples = waveform.shape[2] - mix1 = waveform[0::3].squeeze(1) - mix2 = waveform[1::3].squeeze(1) - # extract parts with only one speaker from original mixtures - mix1_masks = batch["X_separation_mask"][0::3] - mix2_masks = batch["X_separation_mask"][1::3] - mix1_masked = mix1 * mix1_masks - mix2_masked = mix2 * mix2_masks - - moms = mix1 + mix2 - - _, predicted_sources_mom = self.model(moms) - _, predicted_sources_mix1 = self.model(mix1) - _, predicted_sources_mix2 = self.model(mix2) - - # don't use moms with more than max_speakers_per_chunk speakers for training speaker diarization - num_speakers: torch.Tensor = torch.sum(torch.any(target, dim=1), dim=1) - num_speakers[2::3] = num_speakers[::3] + num_speakers[1::3] - keep: torch.Tensor = num_speakers <= self.max_speakers_per_chunk - target = target[keep] - waveform = waveform[keep] - prediction, _ = self.model(waveform) - + prediction = self.model(waveform) batch_size, num_frames, _ = prediction.shape # (batch_size, num_frames, num_classes) @@ -611,23 +563,6 @@ def training_step(self, batch, batch_idx: int): seg_loss = self.segmentation_loss( permutated_prediction, target, weight=weight ) - # contributions from original mixtures is weighed by the proportion of remaining frames - mixit_loss = self.separation_loss( - predicted_sources_mom.transpose(1, 2), torch.stack((mix1, mix2)).transpose(0, 1) - ) + self.separation_loss( - predicted_sources_mix1.transpose(1, 2), torch.stack((mix1_masked, torch.zeros_like(mix1))).transpose(0, 1) - ) * mix1_masks.sum() / num_samples / bsz * 3 + self.separation_loss( - predicted_sources_mix2.transpose(1, 2), torch.stack((mix2_masked, torch.zeros_like(mix2))).transpose(0, 1) - ) * mix2_masks.sum() / num_samples / bsz * 3 - - self.model.log( - f"{self.logging_prefix}TrainSeparationLoss", - mixit_loss, - on_step=False, - on_epoch=True, - prog_bar=False, - logger=True, - ) self.model.log( "loss/train/segmentation", @@ -663,7 +598,7 @@ def training_step(self, batch, batch_idx: int): logger=True, ) - loss = (1 - self.mixit_loss_weight) * (seg_loss + vad_loss) + self.mixit_loss_weight * mixit_loss + loss = seg_loss + vad_loss # skip batch if 
something went wrong for some reason if torch.isnan(loss): @@ -724,14 +659,8 @@ def validation_step(self, batch, batch_idx: int): # waveform = waveform[keep] # target = target[keep] - bsz = waveform.shape[0] - mix1 = waveform[bsz // 2 : 2 * (bsz // 2)].squeeze(1) - mix2 = waveform[: bsz // 2].squeeze(1) - moms = mix1 + mix2 - # forward pass - prediction, _ = self.model(waveform) - _, predicted_sources_mom = self.model(moms) + prediction = self.model(waveform) batch_size, num_frames, _ = prediction.shape # frames weight @@ -767,21 +696,8 @@ def validation_step(self, batch, batch_idx: int): permutated_prediction, target, weight=weight ) - mixit_loss = self.separation_loss( - predicted_sources_mom.transpose(1, 2), torch.stack((mix1, mix2)).transpose(0, 1) - ) - - self.model.log( - f"{self.logging_prefix}ValSeparationLoss", - mixit_loss, - on_step=False, - on_epoch=True, - prog_bar=False, - logger=True, - ) - self.model.log( - f"{self.logging_prefix}ValSegLoss", + "loss/val/segmentation", seg_loss, on_step=False, on_epoch=True, @@ -814,7 +730,7 @@ def validation_step(self, batch, batch_idx: int): logger=True, ) - loss = (1 - self.mixit_loss_weight) * (seg_loss + vad_loss) + self.mixit_loss_weight * mixit_loss + loss = seg_loss + vad_loss self.model.log( "loss/val", diff --git a/pyannote/audio/tasks/segmentation/speaker_separation_diarization.py b/pyannote/audio/tasks/segmentation/speaker_separation_diarization.py new file mode 100644 index 000000000..77a0573dd --- /dev/null +++ b/pyannote/audio/tasks/segmentation/speaker_separation_diarization.py @@ -0,0 +1,1179 @@ +# MIT License +# +# Copyright (c) 2020- CNRS +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
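+
+# This module implements joint speaker separation and diarization: the
+# diarization branch is trained with a permutation-invariant (optionally
+# powerset) segmentation loss, while the separation branch is trained with
+# MixIT on "mixtures of mixtures" obtained by summing pairs of chunks drawn
+# from the same file (see train__iter__helper below).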
+ +import math +import warnings +import random +from collections import Counter +from typing import Dict, Literal, Sequence, Text, Tuple, Union + +import numpy as np +import torch +import torch.nn.functional +from matplotlib import pyplot as plt +from pyannote.core import Segment, SlidingWindow, SlidingWindowFeature +from pyannote.database.protocol import SpeakerDiarizationProtocol +from pyannote.database.protocol.protocol import Scope, Subset +from pytorch_lightning.loggers import MLFlowLogger, TensorBoardLogger +from rich.progress import track +from torch_audiomentations.core.transforms_interface import BaseWaveformTransform +from torchmetrics import Metric + +from pyannote.audio.core.task import Problem, Resolution, Specifications, Task +from pyannote.audio.tasks.segmentation.mixins import SegmentationTaskMixin +from pyannote.audio.torchmetrics import ( + DiarizationErrorRate, + FalseAlarmRate, + MissedDetectionRate, + OptimalDiarizationErrorRate, + OptimalDiarizationErrorRateThreshold, + OptimalFalseAlarmRate, + OptimalMissedDetectionRate, + OptimalSpeakerConfusionRate, + SpeakerConfusionRate, +) +from pyannote.audio.utils.loss import binary_cross_entropy, mse_loss, nll_loss +from pyannote.audio.utils.permutation import permutate +from pyannote.audio.utils.powerset import Powerset +from asteroid.losses import MixITLossWrapper, multisrc_neg_sisdr +from torch.utils.data._utils.collate import default_collate + +Subsets = list(Subset.__args__) +Scopes = list(Scope.__args__) + + +class JointSpeakerSeparationAndDiarization(SegmentationTaskMixin, Task): + """Speaker diarization + + Parameters + ---------- + protocol : SpeakerDiarizationProtocol + pyannote.database protocol + duration : float, optional + Chunks duration. Defaults to 2s. + max_speakers_per_chunk : int, optional + Maximum number of speakers per chunk (must be at least 2). + Defaults to estimating it from the training set. + max_speakers_per_frame : int, optional + Maximum number of (overlapping) speakers per frame. + Setting this value to 1 or more enables `powerset multi-class` training. + Default behavior is to use `multi-label` training. + weigh_by_cardinality: bool, optional + Weigh each powerset classes by the size of the corresponding speaker set. + In other words, {0, 1} powerset class weight is 2x bigger than that of {0} + or {1} powerset classes. Note that empty (non-speech) powerset class is + assigned the same weight as mono-speaker classes. Defaults to False (i.e. use + same weight for every class). Has no effect with `multi-label` training. + warm_up : float or (float, float), optional + Use that many seconds on the left- and rightmost parts of each chunk + to warm up the model. While the model does process those left- and right-most + parts, only the remaining central part of each chunk is used for computing the + loss during training, and for aggregating scores during inference. + Defaults to 0. (i.e. no warm-up). + balance: str, optional + When provided, training samples are sampled uniformly with respect to that key. + For instance, setting `balance` to "database" will make sure that each database + will be equally represented in the training samples. + weight: str, optional + When provided, use this key as frame-wise weight in loss function. + batch_size : int, optional + Number of training samples per batch. Defaults to 32. + num_workers : int, optional + Number of workers used for generating training samples. + Defaults to multiprocessing.cpu_count() // 2. 
+ pin_memory : bool, optional + If True, data loaders will copy tensors into CUDA pinned + memory before returning them. See pytorch documentation + for more details. Defaults to False. + augmentation : BaseWaveformTransform, optional + torch_audiomentations waveform transform, used by dataloader + during training. + vad_loss : {"bce", "mse"}, optional + Add voice activity detection loss. + Cannot be used in conjunction with `max_speakers_per_frame`. + metric : optional + Validation metric(s). Can be anything supported by torchmetrics.MetricCollection. + Defaults to AUROC (area under the ROC curve). + mixit_loss_weight : float, optional + Factor that speaker separation loss is scaled by when calculating total loss. + + References + ---------- + Hervé Bredin and Antoine Laurent + "End-To-End Speaker Segmentation for Overlap-Aware Resegmentation." + Proc. Interspeech 2021 + + Zhihao Du, Shiliang Zhang, Siqi Zheng, and Zhijie Yan + "Speaker Embedding-aware Neural Diarization: an Efficient Framework for Overlapping + Speech Diarization in Meeting Scenarios" + https://arxiv.org/abs/2203.09767 + + """ + + def __init__( + self, + protocol: SpeakerDiarizationProtocol, + duration: float = 2.0, + max_speakers_per_chunk: int = None, + max_speakers_per_frame: int = None, + weigh_by_cardinality: bool = False, + warm_up: Union[float, Tuple[float, float]] = 0.0, + balance: Text = None, + weight: Text = None, + batch_size: int = 32, + num_workers: int = None, + pin_memory: bool = False, + augmentation: BaseWaveformTransform = None, + vad_loss: Literal["bce", "mse"] = None, + metric: Union[Metric, Sequence[Metric], Dict[str, Metric]] = None, + max_num_speakers: int = None, # deprecated in favor of `max_speakers_per_chunk`` + loss: Literal["bce", "mse"] = None, # deprecated + mixit_loss_weight: float = 0.2, + ): + super().__init__( + protocol, + duration=duration, + warm_up=warm_up, + batch_size=batch_size, + num_workers=num_workers, + pin_memory=pin_memory, + augmentation=augmentation, + metric=metric, + ) + + if not isinstance(protocol, SpeakerDiarizationProtocol): + raise ValueError( + "SpeakerDiarization task requires a SpeakerDiarizationProtocol." + ) + + # deprecation warnings + if max_speakers_per_chunk is None and max_num_speakers is not None: + max_speakers_per_chunk = max_num_speakers + warnings.warn( + "`max_num_speakers` has been deprecated in favor of `max_speakers_per_chunk`." + ) + if loss is not None: + warnings.warn("`loss` has been deprecated and has no effect.") + + # parameter validation + if max_speakers_per_frame is not None: + if max_speakers_per_frame < 1: + raise ValueError( + f"`max_speakers_per_frame` must be 1 or more (you used {max_speakers_per_frame})." 
+ ) + if vad_loss is not None: + raise ValueError( + "`vad_loss` cannot be used jointly with `max_speakers_per_frame`" + ) + + if batch_size % 3 != 0: + raise ValueError("`batch_size` must be divisible by 3 for mixtures of mixtures training") + + self.max_speakers_per_chunk = max_speakers_per_chunk + self.max_speakers_per_frame = max_speakers_per_frame + self.weigh_by_cardinality = weigh_by_cardinality + self.balance = balance + self.weight = weight + self.vad_loss = vad_loss + self.separation_loss = MixITLossWrapper(multisrc_neg_sisdr, generalized=True) + self.mixit_loss_weight = mixit_loss_weight + + def setup(self): + super().setup() + + # estimate maximum number of speakers per chunk when not provided + if self.max_speakers_per_chunk is None: + training = self.metadata["subset"] == Subsets.index("train") + + num_unique_speakers = [] + progress_description = f"Estimating maximum number of speakers per {self.duration:g}s chunk in the training set" + for file_id in track( + np.where(training)[0], description=progress_description + ): + annotations = self.annotations[ + np.where(self.annotations["file_id"] == file_id)[0] + ] + annotated_regions = self.annotated_regions[ + np.where(self.annotated_regions["file_id"] == file_id)[0] + ] + for region in annotated_regions: + # find annotations within current region + region_start = region["start"] + region_end = region["end"] + region_annotations = annotations[ + np.where( + (annotations["start"] >= region_start) + * (annotations["end"] <= region_end) + )[0] + ] + + for window_start in np.arange( + region_start, region_end - self.duration, 0.25 * self.duration + ): + window_end = window_start + self.duration + window_annotations = region_annotations[ + np.where( + (region_annotations["start"] <= window_end) + * (region_annotations["end"] >= window_start) + )[0] + ] + num_unique_speakers.append( + len(np.unique(window_annotations["file_label_idx"])) + ) + + # because there might a few outliers, estimate the upper bound for the + # number of speakers as the 97th percentile + + num_speakers, counts = zip(*list(Counter(num_unique_speakers).items())) + num_speakers, counts = np.array(num_speakers), np.array(counts) + + sorting_indices = np.argsort(num_speakers) + num_speakers = num_speakers[sorting_indices] + counts = counts[sorting_indices] + + ratios = np.cumsum(counts) / np.sum(counts) + + for k, ratio in zip(num_speakers, ratios): + if k == 0: + print(f" - {ratio:7.2%} of all chunks contain no speech at all.") + elif k == 1: + print(f" - {ratio:7.2%} contain 1 speaker or less") + else: + print(f" - {ratio:7.2%} contain {k} speakers or less") + + self.max_speakers_per_chunk = max( + 2, + num_speakers[np.where(ratios > 0.97)[0][0]], + ) + + print( + f"Setting `max_speakers_per_chunk` to {self.max_speakers_per_chunk}. " + f"You can override this value (or avoid this estimation step) by passing `max_speakers_per_chunk={self.max_speakers_per_chunk}` to the task constructor." 
+ ) + + if ( + self.max_speakers_per_frame is not None + and self.max_speakers_per_frame > self.max_speakers_per_chunk + ): + raise ValueError( + f"`max_speakers_per_frame` ({self.max_speakers_per_frame}) must be smaller " + f"than `max_speakers_per_chunk` ({self.max_speakers_per_chunk})" + ) + + # now that we know about the number of speakers upper bound + # we can set task specifications + speaker_diarization = Specifications( + duration=self.duration, + resolution=Resolution.FRAME, + problem=Problem.MONO_LABEL_CLASSIFICATION, + permutation_invariant=True, + classes=[f"speaker#{i+1}" for i in range(self.max_speakers_per_chunk)], + powerset_max_classes=self.max_speakers_per_frame, + ) + + speaker_separation = Specifications( + duration=self.duration, + resolution=Resolution.FRAME, + problem=Problem.MONO_LABEL_CLASSIFICATION, # Doesn't matter + classes=[f"speaker#{i+1}" for i in range(self.max_speakers_per_chunk)], + ) + + self.specifications = (speaker_diarization, speaker_separation) + + def setup_loss_func(self): + if self.specifications[0].powerset: + self.model.powerset = Powerset( + len(self.specifications[0].classes), + self.specifications[0].powerset_max_classes, + ) + + def prepare_chunk(self, file_id: int, start_time: float, duration: float): + """Prepare chunk + + Parameters + ---------- + file_id : int + File index + start_time : float + Chunk start time + duration : float + Chunk duration. + + Returns + ------- + sample : dict + Dictionary containing the chunk data with the following keys: + - `X`: waveform + - `y`: target as a SlidingWindowFeature instance where y.labels is + in meta.scope space. + - `meta`: + - `scope`: target scope (0: file, 1: database, 2: global) + - `database`: database index + - `file`: file index + """ + + file = self.get_file(file_id) + + # get label scope + label_scope = Scopes[self.metadata[file_id]["scope"]] + label_scope_key = f"{label_scope}_label_idx" + + # + chunk = Segment(start_time, start_time + duration) + + sample = dict() + sample["X"], _ = self.model.audio.crop(file, chunk, duration=duration) + + # use model introspection to predict how many frames it will output + # TODO: this should be cached + num_samples = sample["X"].shape[1] + #resolution_samples = self.model.example_output[0].frames.step * self.model.example_output[0].num_frames / num_samples + + # gather all annotations of current file + annotations = self.annotations[self.annotations["file_id"] == file_id] + + # gather all annotations with non-empty intersection with current chunk + chunk_annotations = annotations[ + (annotations["start"] < chunk.end) & (annotations["end"] > chunk.start) + ] + + # discretize chunk annotations at model output and input resolutions + start = np.maximum(chunk_annotations["start"], chunk.start) - chunk.start + start_idx = np.floor(start / self.model.example_output[0].frames.step).astype(int) + start_idx_samples = np.floor(start * 16000).astype(int) + end = np.minimum(chunk_annotations["end"], chunk.end) - chunk.start + end_idx = np.ceil(end / self.model.example_output[0].frames.step).astype(int) + end_idx_samples = np.floor(end * 16000).astype(int) + + # get list and number of labels for current scope + labels = list(np.unique(chunk_annotations[label_scope_key])) + num_labels = len(labels) + + if num_labels > self.max_speakers_per_chunk: + pass + + # initial frame-level targets + y = np.zeros((self.model.example_output[0].num_frames, num_labels), dtype=np.uint8) + sample_level_labels = np.zeros((num_samples, num_labels), dtype=np.uint8) + + # map 
labels to indices + mapping = {label: idx for idx, label in enumerate(labels)} + + for start, end, label in zip( + start_idx, end_idx, chunk_annotations[label_scope_key] + ): + mapped_label = mapping[label] + y[start:end, mapped_label] = 1 + + sample["y"] = SlidingWindowFeature( + y, self.model.example_output[0].frames, labels=labels + ) + + for start, end, label in zip( + start_idx_samples, end_idx_samples, chunk_annotations[label_scope_key] + ): + mapped_label = mapping[label] + sample_level_labels[start:end, mapped_label] = 1 + + # only frames with a single label should be used for mixit training + sample["X_separation_mask"] = torch.from_numpy(sample_level_labels.sum(axis=1) == 1) + metadata = self.metadata[file_id] + sample["meta"] = {key: metadata[key] for key in metadata.dtype.names} + sample["meta"]["file"] = file_id + + return sample + + def train__iter__helper(self, rng: random.Random, **filters): + """Iterate over training samples with optional domain filtering + + Parameters + ---------- + rng : random.Random + Random number generator + filters : dict, optional + When provided (as {key: value} dict), filter training files so that + only files such as file[key] == value are used for generating chunks. + + Yields + ------ + chunk : dict + Training chunks. + """ + + # indices of training files that matches domain filters + training = self.metadata["subset"] == Subsets.index("train") + for key, value in filters.items(): + training &= self.metadata[key] == value + file_ids = np.where(training)[0] + + # turn annotated duration into a probability distribution + annotated_duration = self.annotated_duration[file_ids] + prob_annotated_duration = annotated_duration / np.sum(annotated_duration) + + duration = self.duration + + num_chunks_per_file = getattr(self, "num_chunks_per_file", 1) + + while True: + # select one file at random (with probability proportional to its annotated duration) + file_id = np.random.choice(file_ids, p=prob_annotated_duration) + annotations = self.annotations[np.where(self.annotations["file_id"] == file_id)[0]] + + # generate `num_chunks_per_file` chunks from this file + for _ in range(num_chunks_per_file): + # find indices of annotated regions in this file + annotated_region_indices = np.where( + self.annotated_regions["file_id"] == file_id + )[0] + + # turn annotated regions duration into a probability distribution + prob_annotated_regions_duration = self.annotated_regions["duration"][ + annotated_region_indices + ] / np.sum(self.annotated_regions["duration"][annotated_region_indices]) + + # selected one annotated region at random (with probability proportional to its duration) + annotated_region_index = np.random.choice( + annotated_region_indices, p=prob_annotated_regions_duration + ) + + # select one chunk at random in this annotated region + _, _, start, end = self.annotated_regions[annotated_region_index] + start_time = rng.uniform(start, end - duration) + + # find speakers that already appeared and all annotations that contain them + chunk_annotations = annotations[ + (annotations["start"] < start_time+duration) & (annotations["end"] > start_time) + ] + previous_speaker_labels = list(np.unique(chunk_annotations["file_label_idx"])) + repeated_speaker_annotations = annotations[np.isin(annotations["file_label_idx"], previous_speaker_labels)] + + if repeated_speaker_annotations.size == 0: + # if previous chunk has 0 speakers then just sample from all annotated regions again + first_chunk = self.prepare_chunk(file_id, start_time, duration) + 
first_chunk["meta"]["mixture_type"]="first_mixture" + yield first_chunk + + # selected one annotated region at random (with probability proportional to its duration) + annotated_region_index = np.random.choice( + annotated_region_indices, p=prob_annotated_regions_duration + ) + + # select one chunk at random in this annotated region + _, _, start, end = self.annotated_regions[annotated_region_index] + start_time = rng.uniform(start, end - duration) + + second_chunk = self.prepare_chunk(file_id, start_time, duration) + second_chunk["meta"]["mixture_type"]="second_mixture" + yield second_chunk + + # add previous two chunks to get a third one + third_chunk = dict() + third_chunk["X"] = first_chunk["X"] + second_chunk["X"] + third_chunk["meta"] = first_chunk["meta"].copy() + y = np.concatenate((first_chunk["y"].data, second_chunk["y"].data), axis=1) + frames = first_chunk["y"].sliding_window + labels = first_chunk["y"].labels + second_chunk["y"].labels + third_chunk["y"] = SlidingWindowFeature(y, frames, labels=labels) + third_chunk["meta"]["mixture_type"]="mom" + + # the whole mom should be used in the separation branch training + third_chunk["X_separation_mask"] = torch.ones_like(first_chunk["X_separation_mask"]) + yield third_chunk + + else: + # merge segments that contain repeated speakers + merged_repeated_segments = [[repeated_speaker_annotations["start"][0],repeated_speaker_annotations["end"][0]]] + for _, start, end, _, _, _ in repeated_speaker_annotations: + previous = merged_repeated_segments[-1] + if start <= previous[1]: + previous[1] = max(previous[1], end) + else: + merged_repeated_segments.append([start, end]) + + # find segments that don't contain repeated speakers + segments_without_repeat = [] + current_region_index = 0 + previous_time = self.annotated_regions["start"][annotated_region_indices[0]] + for segment in merged_repeated_segments: + if segment[0] > self.annotated_regions["end"][annotated_region_indices[current_region_index]]: + current_region_index+=1 + previous_time = self.annotated_regions["start"][annotated_region_indices[current_region_index]] + + if segment[0] - previous_time > duration: + segments_without_repeat.append((previous_time, segment[0], segment[0] - previous_time)) + previous_time = segment[1] + + dtype = [("start", "f"), ("end", "f"),("duration", "f")] + segments_without_repeat = np.array(segments_without_repeat,dtype=dtype) + + if np.sum(segments_without_repeat["duration"]) != 0: + + # only yield chunks if it is possible to choose the second chunk so that yielded chunks are always paired + + first_chunk = self.prepare_chunk(file_id, start_time, duration) + first_chunk["meta"]["mixture_type"]="first_mixture" + yield first_chunk + + prob_segments_duration = segments_without_repeat["duration"] / np.sum(segments_without_repeat["duration"]) + segment = np.random.choice( + segments_without_repeat, p=prob_segments_duration + ) + + start, end, _ = segment + new_start_time = rng.uniform(start, end - duration) + second_chunk = self.prepare_chunk(file_id, new_start_time, duration) + second_chunk["meta"]["mixture_type"]="second_mixture" + yield second_chunk + + #add previous two chunks to get a third one + third_chunk = dict() + third_chunk["X"] = first_chunk["X"] + second_chunk["X"] + third_chunk["meta"] = first_chunk["meta"].copy() + y = np.concatenate((first_chunk["y"].data, second_chunk["y"].data), axis=1) + frames = first_chunk["y"].sliding_window + labels = first_chunk["y"].labels + second_chunk["y"].labels + third_chunk["y"] = SlidingWindowFeature(y, 
frames, labels=labels) + third_chunk["meta"]["mixture_type"]="mom" + + # the whole mom should be used in the separation branch training + third_chunk["X_separation_mask"] = torch.ones_like(first_chunk["X_separation_mask"]) + yield third_chunk + + def collate_X_separation_mask(self, batch) -> torch.Tensor: + return default_collate([b["X_separation_mask"] for b in batch]) + + def collate_fn(self, batch, stage="train"): + """Collate function used for most segmentation tasks + + This function does the following: + * stack waveforms into a (batch_size, num_channels, num_samples) tensor batch["X"]) + * apply augmentation when in "train" stage + * convert targets into a (batch_size, num_frames, num_classes) tensor batch["y"] + * collate any other keys that might be present in the batch using pytorch default_collate function + + Parameters + ---------- + batch : list of dict + List of training samples. + + Returns + ------- + batch : dict + Collated batch as {"X": torch.Tensor, "y": torch.Tensor} dict. + """ + + # collate X + collated_X = self.collate_X(batch) + + # collate y + collated_y = self.collate_y(batch) + + # collate metadata + collated_meta = self.collate_meta(batch) + + collated_X_separation_mask = self.collate_X_separation_mask(batch) + + # apply augmentation (only in "train" stage) + self.augmentation.train(mode=(stage == "train")) + augmented = self.augmentation( + samples=collated_X, + sample_rate=self.model.hparams.sample_rate, + targets=collated_y.unsqueeze(1), + ) + + return { + "X": augmented.samples, + "y": augmented.targets.squeeze(1), + "meta": collated_meta, + "X_separation_mask" : collated_X_separation_mask + } + + def collate_y(self, batch) -> torch.Tensor: + """ + + Parameters + ---------- + batch : list + List of samples to collate. + "y" field is expected to be a SlidingWindowFeature. + + Returns + ------- + y : torch.Tensor + Collated target tensor of shape (num_frames, self.max_speakers_per_chunk) + If one chunk has more than `self.max_speakers_per_chunk` speakers, we keep + the max_speakers_per_chunk most talkative ones. If it has less, we pad with + zeros (artificial inactive speakers). + """ + + collated_y = [] + for b in batch: + y = b["y"].data + num_speakers = len(b["y"].labels) + if num_speakers > self.max_speakers_per_chunk: + # sort speakers in descending talkativeness order + indices = np.argsort(-np.sum(y, axis=0), axis=0) + # keep only the most talkative speakers + y = y[:, indices[: self.max_speakers_per_chunk]] + + # TODO: we should also sort the speaker labels in the same way + + elif num_speakers < self.max_speakers_per_chunk: + # create inactive speakers by zero padding + y = np.pad( + y, + ((0, 0), (0, self.max_speakers_per_chunk - num_speakers)), + mode="constant", + ) + + else: + # we have exactly the right number of speakers + pass + + collated_y.append(y) + + return torch.from_numpy(np.stack(collated_y)) + + def segmentation_loss( + self, + permutated_prediction: torch.Tensor, + target: torch.Tensor, + weight: torch.Tensor = None, + ) -> torch.Tensor: + """Permutation-invariant segmentation loss + + Parameters + ---------- + permutated_prediction : (batch_size, num_frames, num_classes) torch.Tensor + Permutated speaker activity predictions. + target : (batch_size, num_frames, num_speakers) torch.Tensor + Speaker activity. + weight : (batch_size, num_frames, 1) torch.Tensor, optional + Frames weight. 
+ + Returns + ------- + seg_loss : torch.Tensor + Permutation-invariant segmentation loss + """ + + if self.specifications[0].powerset: + # `clamp_min` is needed to set non-speech weight to 1. + class_weight = ( + torch.clamp_min(self.model.powerset.cardinality, 1.0) + if self.weigh_by_cardinality + else None + ) + seg_loss = nll_loss( + permutated_prediction, + torch.argmax(target, dim=-1), + class_weight=class_weight, + weight=weight, + ) + else: + seg_loss = binary_cross_entropy( + permutated_prediction, target.float(), weight=weight + ) + + return seg_loss + + def voice_activity_detection_loss( + self, + permutated_prediction: torch.Tensor, + target: torch.Tensor, + weight: torch.Tensor = None, + ) -> torch.Tensor: + """Voice activity detection loss + + Parameters + ---------- + permutated_prediction : (batch_size, num_frames, num_classes) torch.Tensor + Speaker activity predictions. + target : (batch_size, num_frames, num_speakers) torch.Tensor + Speaker activity. + weight : (batch_size, num_frames, 1) torch.Tensor, optional + Frames weight. + + Returns + ------- + vad_loss : torch.Tensor + Voice activity detection loss. + """ + + vad_prediction, _ = torch.max(permutated_prediction, dim=2, keepdim=True) + # (batch_size, num_frames, 1) + + vad_target, _ = torch.max(target.float(), dim=2, keepdim=False) + # (batch_size, num_frames) + + if self.vad_loss == "bce": + loss = binary_cross_entropy(vad_prediction, vad_target, weight=weight) + + elif self.vad_loss == "mse": + loss = mse_loss(vad_prediction, vad_target, weight=weight) + + return loss + + def training_step(self, batch, batch_idx: int): + """Compute permutation-invariant segmentation loss + + Parameters + ---------- + batch : (usually) dict of torch.Tensor + Current batch. + batch_idx: int + Batch index. 
+ + Returns + ------- + loss : {str: torch.tensor} + {"loss": loss} + """ + + # target + target = batch["y"] + # (batch_size, num_frames, num_speakers) + + waveform = batch["X"] + # (batch_size, num_channels, num_samples) + + # drop samples that contain too many speakers + num_speakers: torch.Tensor = torch.sum(torch.any(target, dim=1), dim=1) + keep: torch.Tensor = num_speakers <= self.max_speakers_per_chunk + target = target[keep] + waveform = waveform[keep] + + # corner case + if not keep.any(): + return None + + # forward pass + bsz = waveform.shape[0] + num_samples = waveform.shape[2] + mix1 = waveform[0::3].squeeze(1) + mix2 = waveform[1::3].squeeze(1) + # extract parts with only one speaker from original mixtures + mix1_masks = batch["X_separation_mask"][0::3] + mix2_masks = batch["X_separation_mask"][1::3] + mix1_masked = mix1 * mix1_masks + mix2_masked = mix2 * mix2_masks + + moms = mix1 + mix2 + + _, predicted_sources_mom = self.model(moms) + _, predicted_sources_mix1 = self.model(mix1) + _, predicted_sources_mix2 = self.model(mix2) + + # don't use moms with more than max_speakers_per_chunk speakers for training speaker diarization + num_speakers: torch.Tensor = torch.sum(torch.any(target, dim=1), dim=1) + num_speakers[2::3] = num_speakers[::3] + num_speakers[1::3] + keep: torch.Tensor = num_speakers <= self.max_speakers_per_chunk + target = target[keep] + waveform = waveform[keep] + prediction, _ = self.model(waveform) + + batch_size, num_frames, _ = prediction.shape + # (batch_size, num_frames, num_classes) + + # frames weight + weight_key = getattr(self, "weight", None) + weight = batch.get( + weight_key, + torch.ones(batch_size, num_frames, 1, device=self.model.device), + ) + # (batch_size, num_frames, 1) + + # warm-up + warm_up_left = round(self.warm_up[0] / self.duration * num_frames) + weight[:, :warm_up_left] = 0.0 + warm_up_right = round(self.warm_up[1] / self.duration * num_frames) + weight[:, num_frames - warm_up_right :] = 0.0 + + if self.specifications[0].powerset: + multilabel = self.model.powerset.to_multilabel(prediction) + permutated_target, _ = permutate(multilabel, target) + permutated_target_powerset = self.model.powerset.to_powerset( + permutated_target.float() + ) + seg_loss = self.segmentation_loss( + prediction, permutated_target_powerset, weight=weight + ) + + else: + permutated_prediction, _ = permutate(target, prediction) + seg_loss = self.segmentation_loss( + permutated_prediction, target, weight=weight + ) + # contributions from original mixtures is weighed by the proportion of remaining frames + mixit_loss = self.separation_loss( + predicted_sources_mom.transpose(1, 2), torch.stack((mix1, mix2)).transpose(0, 1) + ) + self.separation_loss( + predicted_sources_mix1.transpose(1, 2), torch.stack((mix1_masked, torch.zeros_like(mix1))).transpose(0, 1) + ) * mix1_masks.sum() / num_samples / bsz * 3 + self.separation_loss( + predicted_sources_mix2.transpose(1, 2), torch.stack((mix2_masked, torch.zeros_like(mix2))).transpose(0, 1) + ) * mix2_masks.sum() / num_samples / bsz * 3 + + self.model.log( + "loss/train/separation", + mixit_loss, + on_step=False, + on_epoch=True, + prog_bar=False, + logger=True, + ) + + self.model.log( + "loss/train/segmentation", + seg_loss, + on_step=False, + on_epoch=True, + prog_bar=False, + logger=True, + ) + + if self.vad_loss is None: + vad_loss = 0.0 + + else: + # TODO: vad_loss probably does not make sense in powerset mode + # because first class (empty set of labels) does exactly this... 
+ if self.specifications[0].powerset: + vad_loss = self.voice_activity_detection_loss( + prediction, permutated_target_powerset, weight=weight + ) + + else: + vad_loss = self.voice_activity_detection_loss( + permutated_prediction, target, weight=weight + ) + + self.model.log( + "loss/train/vad", + vad_loss, + on_step=False, + on_epoch=True, + prog_bar=False, + logger=True, + ) + + loss = (1 - self.mixit_loss_weight) * (seg_loss + vad_loss) + self.mixit_loss_weight * mixit_loss + + # skip batch if something went wrong for some reason + if torch.isnan(loss): + return None + + self.model.log( + "loss/train", + loss, + on_step=False, + on_epoch=True, + prog_bar=False, + logger=True, + ) + + return {"loss": loss} + + def default_metric( + self, + ) -> Union[Metric, Sequence[Metric], Dict[str, Metric]]: + """Returns diarization error rate and its components""" + + if self.specifications[0].powerset: + return { + "DiarizationErrorRate": DiarizationErrorRate(0.5), + "DiarizationErrorRate/Confusion": SpeakerConfusionRate(0.5), + "DiarizationErrorRate/Miss": MissedDetectionRate(0.5), + "DiarizationErrorRate/FalseAlarm": FalseAlarmRate(0.5), + } + + return { + "DiarizationErrorRate": OptimalDiarizationErrorRate(), + "DiarizationErrorRate/Threshold": OptimalDiarizationErrorRateThreshold(), + "DiarizationErrorRate/Confusion": OptimalSpeakerConfusionRate(), + "DiarizationErrorRate/Miss": OptimalMissedDetectionRate(), + "DiarizationErrorRate/FalseAlarm": OptimalFalseAlarmRate(), + } + + # TODO: no need to compute gradient in this method + def validation_step(self, batch, batch_idx: int): + """Compute validation loss and metric + + Parameters + ---------- + batch : dict of torch.Tensor + Current batch. + batch_idx: int + Batch index. + """ + + # target + target = batch["y"] + # (batch_size, num_frames, num_speakers) + + waveform = batch["X"] + # (batch_size, num_channels, num_samples) + + # TODO: should we handle validation samples with too many speakers + # waveform = waveform[keep] + # target = target[keep] + + bsz = waveform.shape[0] + mix1 = waveform[bsz // 2 : 2 * (bsz // 2)].squeeze(1) + mix2 = waveform[: bsz // 2].squeeze(1) + moms = mix1 + mix2 + + # forward pass + prediction, _ = self.model(waveform) + _, predicted_sources_mom = self.model(moms) + batch_size, num_frames, _ = prediction.shape + + # frames weight + weight_key = getattr(self, "weight", None) + weight = batch.get( + weight_key, + torch.ones(batch_size, num_frames, 1, device=self.model.device), + ) + # (batch_size, num_frames, 1) + + # warm-up + warm_up_left = round(self.warm_up[0] / self.duration * num_frames) + weight[:, :warm_up_left] = 0.0 + warm_up_right = round(self.warm_up[1] / self.duration * num_frames) + weight[:, num_frames - warm_up_right :] = 0.0 + + if self.specifications[0].powerset: + multilabel = self.model.powerset.to_multilabel(prediction) + permutated_target, _ = permutate(multilabel, target) + + # FIXME: handle case where target have too many speakers? 
+ # since we don't need + permutated_target_powerset = self.model.powerset.to_powerset( + permutated_target.float() + ) + seg_loss = self.segmentation_loss( + prediction, permutated_target_powerset, weight=weight + ) + + else: + permutated_prediction, _ = permutate(target, prediction) + seg_loss = self.segmentation_loss( + permutated_prediction, target, weight=weight + ) + + mixit_loss = self.separation_loss( + predicted_sources_mom.transpose(1, 2), torch.stack((mix1, mix2)).transpose(0, 1) + ) + + self.model.log( + "loss/val/separation", + mixit_loss, + on_step=False, + on_epoch=True, + prog_bar=False, + logger=True, + ) + + self.model.log( + "loss/val/segmentation", + seg_loss, + on_step=False, + on_epoch=True, + prog_bar=False, + logger=True, + ) + + if self.vad_loss is None: + vad_loss = 0.0 + + else: + # TODO: vad_loss probably does not make sense in powerset mode + # because first class (empty set of labels) does exactly this... + if self.specifications[0].powerset: + vad_loss = self.voice_activity_detection_loss( + prediction, permutated_target_powerset, weight=weight + ) + + else: + vad_loss = self.voice_activity_detection_loss( + permutated_prediction, target, weight=weight + ) + + self.model.log( + "loss/val/vad", + vad_loss, + on_step=False, + on_epoch=True, + prog_bar=False, + logger=True, + ) + + loss = (1 - self.mixit_loss_weight) * (seg_loss + vad_loss) + self.mixit_loss_weight * mixit_loss + + self.model.log( + "loss/val", + loss, + on_step=False, + on_epoch=True, + prog_bar=False, + logger=True, + ) + + if self.specifications[0].powerset: + self.model.validation_metric( + torch.transpose( + multilabel[:, warm_up_left : num_frames - warm_up_right], 1, 2 + ), + torch.transpose( + target[:, warm_up_left : num_frames - warm_up_right], 1, 2 + ), + ) + else: + self.model.validation_metric( + torch.transpose( + prediction[:, warm_up_left : num_frames - warm_up_right], 1, 2 + ), + torch.transpose( + target[:, warm_up_left : num_frames - warm_up_right], 1, 2 + ), + ) + + self.model.log_dict( + self.model.validation_metric, + on_step=False, + on_epoch=True, + prog_bar=True, + logger=True, + ) + + # log first batch visualization every 2^n epochs. 
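+        # (math.log2(epoch) % 1 == 0 only when the epoch is a power of two, so the
+        # check below skips every epoch that is not 1, 2, 4, 8, ... as well as every
+        # batch after the first one.)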
+ if ( + self.model.current_epoch == 0 + or math.log2(self.model.current_epoch) % 1 > 0 + or batch_idx > 0 + ): + return + + # visualize first 9 validation samples of first batch in Tensorboard/MLflow + + if self.specifications[0].powerset: + y = permutated_target.float().cpu().numpy() + y_pred = multilabel.cpu().numpy() + else: + y = target.float().cpu().numpy() + y_pred = permutated_prediction.cpu().numpy() + + # prepare 3 x 3 grid (or smaller if batch size is smaller) + num_samples = min(self.batch_size, 9) + nrows = math.ceil(math.sqrt(num_samples)) + ncols = math.ceil(num_samples / nrows) + fig, axes = plt.subplots( + nrows=2 * nrows, ncols=ncols, figsize=(8, 5), squeeze=False + ) + + # reshape target so that there is one line per class when plotting it + y[y == 0] = np.NaN + if len(y.shape) == 2: + y = y[:, :, np.newaxis] + y *= np.arange(y.shape[2]) + + # plot each sample + for sample_idx in range(num_samples): + # find where in the grid it should be plotted + row_idx = sample_idx // nrows + col_idx = sample_idx % ncols + + # plot target + ax_ref = axes[row_idx * 2 + 0, col_idx] + sample_y = y[sample_idx] + ax_ref.plot(sample_y) + ax_ref.set_xlim(0, len(sample_y)) + ax_ref.set_ylim(-1, sample_y.shape[1]) + ax_ref.get_xaxis().set_visible(False) + ax_ref.get_yaxis().set_visible(False) + + # plot predictions + ax_hyp = axes[row_idx * 2 + 1, col_idx] + sample_y_pred = y_pred[sample_idx] + ax_hyp.axvspan(0, warm_up_left, color="k", alpha=0.5, lw=0) + ax_hyp.axvspan( + num_frames - warm_up_right, num_frames, color="k", alpha=0.5, lw=0 + ) + ax_hyp.plot(sample_y_pred) + ax_hyp.set_ylim(-0.1, 1.1) + ax_hyp.set_xlim(0, len(sample_y)) + ax_hyp.get_xaxis().set_visible(False) + + plt.tight_layout() + + for logger in self.model.loggers: + if isinstance(logger, TensorBoardLogger): + logger.experiment.add_figure("samples", fig, self.model.current_epoch) + elif isinstance(logger, MLFlowLogger): + logger.experiment.log_figure( + run_id=logger.run_id, + figure=fig, + artifact_file=f"samples_epoch{self.model.current_epoch}.png", + ) + + plt.close(fig) + + +def main(protocol: str, subset: str = "test", model: str = "pyannote/segmentation"): + """Evaluate a segmentation model""" + + from pyannote.database import FileFinder, get_protocol + from rich.progress import Progress + + from pyannote.audio import Inference + from pyannote.audio.pipelines.utils import get_devices + from pyannote.audio.utils.metric import DiscreteDiarizationErrorRate + from pyannote.audio.utils.signal import binarize + + (device,) = get_devices(needs=1) + metric = DiscreteDiarizationErrorRate() + protocol = get_protocol(protocol, preprocessors={"audio": FileFinder()}) + files = list(getattr(protocol, subset)()) + + with Progress() as progress: + main_task = progress.add_task(protocol.name, total=len(files)) + file_task = progress.add_task("Processing", total=1.0) + + def progress_hook(completed: int = None, total: int = None): + progress.update(file_task, completed=completed / total) + + inference = Inference(model, device=device) + + for file in files: + progress.update(file_task, description=file["uri"]) + reference = file["annotation"] + hypothesis = binarize(inference(file, hook=progress_hook)) + uem = file["annotated"] + _ = metric(reference, hypothesis, uem=uem) + progress.advance(main_task) + + _ = metric.report(display=True) + + +if __name__ == "__main__": + import typer + + typer.run(main) From edb0155085d8ac593a4f9b85ac37a00ed5ac52da Mon Sep 17 00:00:00 2001 From: Joonas Kalda Date: Mon, 12 Jun 2023 14:25:17 +0300 
Subject: [PATCH 29/55] Changing n_sources to 3 --- pyannote/audio/models/segmentation/SepDiarNet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyannote/audio/models/segmentation/SepDiarNet.py b/pyannote/audio/models/segmentation/SepDiarNet.py index bc44dbcd2..08efd869e 100644 --- a/pyannote/audio/models/segmentation/SepDiarNet.py +++ b/pyannote/audio/models/segmentation/SepDiarNet.py @@ -113,7 +113,7 @@ def __init__( num_channels: int = 1, task: Optional[Task] = None, encoder_type: str = None, - n_sources: int = 6, + n_sources: int = 3, ): super().__init__(sample_rate=sample_rate, num_channels=num_channels, task=task) From a81da406b4c432b903ac6a1ed3174600b5d15046 Mon Sep 17 00:00:00 2001 From: joonaskalda Date: Fri, 16 Jun 2023 13:48:22 +0200 Subject: [PATCH 30/55] forcing alignment between separation and diarization --- .../speaker_separation_diarization.py | 145 ++++++++++++++++-- 1 file changed, 131 insertions(+), 14 deletions(-) diff --git a/pyannote/audio/tasks/segmentation/speaker_separation_diarization.py b/pyannote/audio/tasks/segmentation/speaker_separation_diarization.py index 77a0573dd..65ad00356 100644 --- a/pyannote/audio/tasks/segmentation/speaker_separation_diarization.py +++ b/pyannote/audio/tasks/segmentation/speaker_separation_diarization.py @@ -54,12 +54,107 @@ from pyannote.audio.utils.loss import binary_cross_entropy, mse_loss, nll_loss from pyannote.audio.utils.permutation import permutate from pyannote.audio.utils.powerset import Powerset -from asteroid.losses import MixITLossWrapper, multisrc_neg_sisdr +from asteroid.losses import multisrc_neg_sisdr from torch.utils.data._utils.collate import default_collate Subsets = list(Subset.__args__) Scopes = list(Scope.__args__) +from itertools import combinations +from torch import nn + +class ModifiedMixITLossWrapper(nn.Module): + r"""Mixture invariant loss wrapper modifed to force alignment between separation and diarization. + + Args: + loss_func: function with signature (est_targets, targets, **kwargs). + generalized (bool): Determines how MixIT is applied. If False , + apply MixIT for any number of mixtures as soon as they contain + the same number of sources (:meth:`~MixITLossWrapper.best_part_mixit`.) + If True (default), apply MixIT for two mixtures, but those mixtures do not + necessarly have to contain the same number of sources. + See :meth:`~MixITLossWrapper.best_part_mixit_generalized`. + reduction (string, optional): Specifies the reduction to apply to + the output: + ``'none'`` | ``'mean'``. ``'none'``: no reduction will be applied, + ``'mean'``: the sum of the output will be divided by the number of + elements in the output. + + For each of these modes, the best partition and reordering will be + automatically computed. + + Examples: + >>> import torch + >>> from asteroid.losses import multisrc_mse + >>> mixtures = torch.randn(10, 2, 16000) + >>> est_sources = torch.randn(10, 4, 16000) + >>> # Compute MixIT loss based on pairwise losses + >>> loss_func = MixITLossWrapper(multisrc_mse) + >>> loss_val = loss_func(est_sources, mixtures) + + References + [1] Scott Wisdom et al. "Unsupervised sound separation using + mixtures of mixtures." 
arXiv:2006.12701 (2020) + """ + + def __init__(self, loss_func, generalized=True, reduction="mean"): + super().__init__() + self.loss_func = loss_func + self.generalized = generalized + self.reduction = reduction + + def forward(self, est_targets, targets, part_from_mix1, part_from_mix2, return_est=False, **kwargs): + r"""Find the best partition and return the loss. + + Args: + est_targets: torch.Tensor. Expected shape :math:`(batch, nsrc, *)`. + The batch of target estimates. + targets: torch.Tensor. Expected shape :math:`(batch, nmix, ...)`. + The batch of training targets + return_est: Boolean. Whether to return the estimated mixtures + estimates (To compute metrics or to save example). + **kwargs: additional keyword argument that will be passed to the + loss function. + + Returns: + - Best partition loss for each batch sample, average over + the batch. torch.Tensor(loss_value) + - The estimated mixtures (estimated sources summed according to the partition) + if return_est is True. torch.Tensor of shape :math:`(batch, nmix, ...)`. + """ + # Check input dimensions + assert est_targets.shape[0] == targets.shape[0] + assert est_targets.shape[2] == targets.shape[2] + + # if not self.generalized: + # min_loss, min_loss_idx, parts = self.best_part_mixit( + # self.loss_func, est_targets, targets, **kwargs + # ) + # else: + # min_loss, min_loss_idx, parts = self.best_part_mixit_generalized( + # self.loss_func, est_targets, targets, **kwargs + # ) + est_mixes = [] + for i in range(est_targets.shape[0]): + # sum the sources according to the given partition + est_mix1 = est_targets[i, part_from_mix1[i], :].sum(0) + est_mix2 = est_targets[i, part_from_mix2[i], :].sum(0) + # get loss for the given partition + + est_mixes.append(torch.stack((est_mix1, est_mix2))) + est_mixes = torch.stack(est_mixes) + loss_partition = self.loss_func(est_mixes, targets, **kwargs) + if loss_partition.ndim != 1: + raise ValueError("Loss function return value should be of size (batch,).") + + # Apply any reductions over the batch axis + returned_loss = loss_partition.mean() if self.reduction == "mean" else loss_partition + if not return_est: + return returned_loss + + # Order and sum on the best partition to get the estimated mixtures + # reordered = self.reorder_source(est_targets, targets, min_loss_idx, parts) + return returned_loss, est_mixes class JointSpeakerSeparationAndDiarization(SegmentationTaskMixin, Task): """Speaker diarization @@ -194,7 +289,7 @@ def __init__( self.balance = balance self.weight = weight self.vad_loss = vad_loss - self.separation_loss = MixITLossWrapper(multisrc_neg_sisdr, generalized=True) + self.separation_loss = ModifiedMixITLossWrapper(multisrc_neg_sisdr, generalized=True) self.mixit_loss_weight = mixit_loss_weight def setup(self): @@ -470,6 +565,9 @@ def train__iter__helper(self, rng: random.Random, **filters): # if previous chunk has 0 speakers then just sample from all annotated regions again first_chunk = self.prepare_chunk(file_id, start_time, duration) first_chunk["meta"]["mixture_type"]="first_mixture" + # in order to align separation and diarization branches we need to know which mixtures do speakers/sources originate from + first_chunk["meta"]["sources_from_first_mixture"] = len(first_chunk["y"].labels) + first_chunk["meta"]["sources_from_second_mixture"] = 0 yield first_chunk # selected one annotated region at random (with probability proportional to its duration) @@ -483,6 +581,8 @@ def train__iter__helper(self, rng: random.Random, **filters): second_chunk = 
self.prepare_chunk(file_id, start_time, duration) second_chunk["meta"]["mixture_type"]="second_mixture" + second_chunk["meta"]["sources_from_first_mixture"] = 0 + second_chunk["meta"]["sources_from_second_mixture"] = len(second_chunk["y"].labels) yield second_chunk # add previous two chunks to get a third one @@ -494,6 +594,8 @@ def train__iter__helper(self, rng: random.Random, **filters): labels = first_chunk["y"].labels + second_chunk["y"].labels third_chunk["y"] = SlidingWindowFeature(y, frames, labels=labels) third_chunk["meta"]["mixture_type"]="mom" + third_chunk["meta"]["sources_from_first_mixture"] = len(first_chunk["y"].labels) + third_chunk["meta"]["sources_from_second_mixture"] = len(second_chunk["y"].labels) # the whole mom should be used in the separation branch training third_chunk["X_separation_mask"] = torch.ones_like(first_chunk["X_separation_mask"]) @@ -531,7 +633,9 @@ def train__iter__helper(self, rng: random.Random, **filters): first_chunk = self.prepare_chunk(file_id, start_time, duration) first_chunk["meta"]["mixture_type"]="first_mixture" - yield first_chunk + first_chunk["meta"]["sources_from_first_mixture"] = len(first_chunk["y"].labels) + first_chunk["meta"]["sources_from_second_mixture"] = 0 + #yield first_chunk prob_segments_duration = segments_without_repeat["duration"] / np.sum(segments_without_repeat["duration"]) segment = np.random.choice( @@ -542,7 +646,9 @@ def train__iter__helper(self, rng: random.Random, **filters): new_start_time = rng.uniform(start, end - duration) second_chunk = self.prepare_chunk(file_id, new_start_time, duration) second_chunk["meta"]["mixture_type"]="second_mixture" - yield second_chunk + second_chunk["meta"]["sources_from_first_mixture"] = 0 + second_chunk["meta"]["sources_from_second_mixture"] = len(second_chunk["y"].labels) + #yield second_chunk #add previous two chunks to get a third one third_chunk = dict() @@ -556,7 +662,13 @@ def train__iter__helper(self, rng: random.Random, **filters): # the whole mom should be used in the separation branch training third_chunk["X_separation_mask"] = torch.ones_like(first_chunk["X_separation_mask"]) - yield third_chunk + third_chunk["meta"]["sources_from_first_mixture"] = len(first_chunk["y"].labels) + third_chunk["meta"]["sources_from_second_mixture"] = len(second_chunk["y"].labels) + #third_chunk["sources_from_first_mixture"] = len(first_chunk["y"].labels) + if len(labels) < 4: + yield first_chunk + yield second_chunk + yield third_chunk def collate_X_separation_mask(self, batch) -> torch.Tensor: return default_collate([b["X_separation_mask"] for b in batch]) @@ -810,7 +922,7 @@ def training_step(self, batch, batch_idx: int): if self.specifications[0].powerset: multilabel = self.model.powerset.to_multilabel(prediction) - permutated_target, _ = permutate(multilabel, target) + permutated_target, permutations = permutate(multilabel, target) permutated_target_powerset = self.model.powerset.to_powerset( permutated_target.float() ) @@ -819,17 +931,21 @@ def training_step(self, batch, batch_idx: int): ) else: - permutated_prediction, _ = permutate(target, prediction) + permutated_prediction, permutations = permutate(target, prediction) seg_loss = self.segmentation_loss( permutated_prediction, target, weight=weight ) + # to find which predicted sources correspond to which mixtures, we need to invert the permutations + permutations_inverse = torch.argsort(torch.tensor(permutations)) + predicted_sources_idx_mix1 = [[permutations_inverse[i][j] for j in 
range(batch["meta"]["sources_from_first_mixture"][i])] for i in range(batch_size)] + predicted_sources_idx_mix2 = [[permutations_inverse[i][j] for j in range(batch["meta"]["sources_from_first_mixture"][i],batch["meta"]["sources_from_second_mixture"][i])] for i in range(batch_size)] # contributions from original mixtures is weighed by the proportion of remaining frames mixit_loss = self.separation_loss( - predicted_sources_mom.transpose(1, 2), torch.stack((mix1, mix2)).transpose(0, 1) + predicted_sources_mom.transpose(1, 2), torch.stack((mix1, mix2)).transpose(0, 1), predicted_sources_idx_mix1[2::3], predicted_sources_idx_mix2[2::3] ) + self.separation_loss( - predicted_sources_mix1.transpose(1, 2), torch.stack((mix1_masked, torch.zeros_like(mix1))).transpose(0, 1) + predicted_sources_mix1.transpose(1, 2), torch.stack((mix1_masked, torch.zeros_like(mix1))).transpose(0, 1), predicted_sources_idx_mix1[0::3], predicted_sources_idx_mix2[0::3] ) * mix1_masks.sum() / num_samples / bsz * 3 + self.separation_loss( - predicted_sources_mix2.transpose(1, 2), torch.stack((mix2_masked, torch.zeros_like(mix2))).transpose(0, 1) + predicted_sources_mix2.transpose(1, 2), torch.stack((mix2_masked, torch.zeros_like(mix2))).transpose(0, 1), predicted_sources_idx_mix1[1::3], predicted_sources_idx_mix2[1::3] ) * mix2_masks.sum() / num_samples / bsz * 3 self.model.log( @@ -978,10 +1094,11 @@ def validation_step(self, batch, batch_idx: int): seg_loss = self.segmentation_loss( permutated_prediction, target, weight=weight ) - - mixit_loss = self.separation_loss( - predicted_sources_mom.transpose(1, 2), torch.stack((mix1, mix2)).transpose(0, 1) - ) + # forced alignment mixit can't be implemented for validation because since data loading is different + mixit_loss = 0 + # mixit_loss = self.separation_loss( + # predicted_sources_mom.transpose(1, 2), torch.stack((mix1, mix2)).transpose(0, 1) + # ) self.model.log( "loss/val/separation", From b5a3517a8cae1322d21c777e540c13695d049b0b Mon Sep 17 00:00:00 2001 From: joonaskalda Date: Sat, 17 Jun 2023 09:20:05 +0000 Subject: [PATCH 31/55] fixing edge case of 4 speakers in a second chunk --- .../segmentation/speaker_separation_diarization.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/pyannote/audio/tasks/segmentation/speaker_separation_diarization.py b/pyannote/audio/tasks/segmentation/speaker_separation_diarization.py index 65ad00356..49884ac0c 100644 --- a/pyannote/audio/tasks/segmentation/speaker_separation_diarization.py +++ b/pyannote/audio/tasks/segmentation/speaker_separation_diarization.py @@ -568,7 +568,7 @@ def train__iter__helper(self, rng: random.Random, **filters): # in order to align separation and diarization branches we need to know which mixtures do speakers/sources originate from first_chunk["meta"]["sources_from_first_mixture"] = len(first_chunk["y"].labels) first_chunk["meta"]["sources_from_second_mixture"] = 0 - yield first_chunk + # yield first_chunk # selected one annotated region at random (with probability proportional to its duration) annotated_region_index = np.random.choice( @@ -583,7 +583,7 @@ def train__iter__helper(self, rng: random.Random, **filters): second_chunk["meta"]["mixture_type"]="second_mixture" second_chunk["meta"]["sources_from_first_mixture"] = 0 second_chunk["meta"]["sources_from_second_mixture"] = len(second_chunk["y"].labels) - yield second_chunk + # yield second_chunk # add previous two chunks to get a third one third_chunk = dict() @@ -599,7 +599,11 @@ def train__iter__helper(self, rng: 
random.Random, **filters): # the whole mom should be used in the separation branch training third_chunk["X_separation_mask"] = torch.ones_like(first_chunk["X_separation_mask"]) - yield third_chunk + + if len(labels) < 4: + yield first_chunk + yield second_chunk + yield third_chunk else: # merge segments that contain repeated speakers From 124f3e34692ef602c773417d5c6f4dd864ee133d Mon Sep 17 00:00:00 2001 From: joonaskalda Date: Sun, 18 Jun 2023 18:20:30 +0200 Subject: [PATCH 32/55] adding a VAD-like forced alignment loss --- .../tasks/segmentation/speaker_separation_diarization.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/pyannote/audio/tasks/segmentation/speaker_separation_diarization.py b/pyannote/audio/tasks/segmentation/speaker_separation_diarization.py index 49884ac0c..3b65f5686 100644 --- a/pyannote/audio/tasks/segmentation/speaker_separation_diarization.py +++ b/pyannote/audio/tasks/segmentation/speaker_separation_diarization.py @@ -952,6 +952,12 @@ def training_step(self, batch, batch_idx: int): predicted_sources_mix2.transpose(1, 2), torch.stack((mix2_masked, torch.zeros_like(mix2))).transpose(0, 1), predicted_sources_idx_mix1[1::3], predicted_sources_idx_mix2[1::3] ) * mix2_masks.sum() / num_samples / bsz * 3 + upscaled_permutated_target = torch.nn.functional.interpolate(permutated_target.transpose(1, 2), size=(80000)).transpose(1, 2) + forced_alignment_loss = (1 - 2 * upscaled_permutated_target[::3]) * predicted_sources_mix1 ** 2 +\ + (1 - 2 * upscaled_permutated_target[1::3]) * predicted_sources_mix2 ** 2 +\ + (1 - 2 * upscaled_permutated_target[2::3]) * predicted_sources_mom ** 2 + forced_alignment_loss = forced_alignment_loss.mean() / 3 + self.model.log( "loss/train/separation", mixit_loss, @@ -995,7 +1001,7 @@ def training_step(self, batch, batch_idx: int): logger=True, ) - loss = (1 - self.mixit_loss_weight) * (seg_loss + vad_loss) + self.mixit_loss_weight * mixit_loss + loss = (1 - self.mixit_loss_weight) * (seg_loss + vad_loss) + self.mixit_loss_weight * mixit_loss + forced_alignment_loss # skip batch if something went wrong for some reason if torch.isnan(loss): From 5e7a0af8ad71882926a77faa7bf64eabcd36ab1d Mon Sep 17 00:00:00 2001 From: joonaskalda Date: Tue, 20 Jun 2023 19:30:16 +0200 Subject: [PATCH 33/55] refactor: remove vad_loss and warm_up, assume powerset everywhere --- .../speaker_separation_diarization.py | 262 ++++-------------- 1 file changed, 51 insertions(+), 211 deletions(-) diff --git a/pyannote/audio/tasks/segmentation/speaker_separation_diarization.py b/pyannote/audio/tasks/segmentation/speaker_separation_diarization.py index 3b65f5686..fa0e8e8f7 100644 --- a/pyannote/audio/tasks/segmentation/speaker_separation_diarization.py +++ b/pyannote/audio/tasks/segmentation/speaker_separation_diarization.py @@ -178,12 +178,6 @@ class JointSpeakerSeparationAndDiarization(SegmentationTaskMixin, Task): or {1} powerset classes. Note that empty (non-speech) powerset class is assigned the same weight as mono-speaker classes. Defaults to False (i.e. use same weight for every class). Has no effect with `multi-label` training. - warm_up : float or (float, float), optional - Use that many seconds on the left- and rightmost parts of each chunk - to warm up the model. While the model does process those left- and right-most - parts, only the remaining central part of each chunk is used for computing the - loss during training, and for aggregating scores during inference. - Defaults to 0. (i.e. no warm-up). 
balance: str, optional When provided, training samples are sampled uniformly with respect to that key. For instance, setting `balance` to "database" will make sure that each database @@ -202,9 +196,6 @@ class JointSpeakerSeparationAndDiarization(SegmentationTaskMixin, Task): augmentation : BaseWaveformTransform, optional torch_audiomentations waveform transform, used by dataloader during training. - vad_loss : {"bce", "mse"}, optional - Add voice activity detection loss. - Cannot be used in conjunction with `max_speakers_per_frame`. metric : optional Validation metric(s). Can be anything supported by torchmetrics.MetricCollection. Defaults to AUROC (area under the ROC curve). @@ -231,14 +222,12 @@ def __init__( max_speakers_per_chunk: int = None, max_speakers_per_frame: int = None, weigh_by_cardinality: bool = False, - warm_up: Union[float, Tuple[float, float]] = 0.0, balance: Text = None, weight: Text = None, batch_size: int = 32, num_workers: int = None, pin_memory: bool = False, augmentation: BaseWaveformTransform = None, - vad_loss: Literal["bce", "mse"] = None, metric: Union[Metric, Sequence[Metric], Dict[str, Metric]] = None, max_num_speakers: int = None, # deprecated in favor of `max_speakers_per_chunk`` loss: Literal["bce", "mse"] = None, # deprecated @@ -247,7 +236,6 @@ def __init__( super().__init__( protocol, duration=duration, - warm_up=warm_up, batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory, @@ -275,10 +263,6 @@ def __init__( raise ValueError( f"`max_speakers_per_frame` must be 1 or more (you used {max_speakers_per_frame})." ) - if vad_loss is not None: - raise ValueError( - "`vad_loss` cannot be used jointly with `max_speakers_per_frame`" - ) if batch_size % 3 != 0: raise ValueError("`batch_size` must be divisible by 3 for mixtures of mixtures training") @@ -288,7 +272,6 @@ def __init__( self.weigh_by_cardinality = weigh_by_cardinality self.balance = balance self.weight = weight - self.vad_loss = vad_loss self.separation_loss = ModifiedMixITLossWrapper(multisrc_neg_sisdr, generalized=True) self.mixit_loss_weight = mixit_loss_weight @@ -395,11 +378,10 @@ def setup(self): self.specifications = (speaker_diarization, speaker_separation) def setup_loss_func(self): - if self.specifications[0].powerset: - self.model.powerset = Powerset( - len(self.specifications[0].classes), - self.specifications[0].powerset_max_classes, - ) + self.model.powerset = Powerset( + len(self.specifications[0].classes), + self.specifications[0].powerset_max_classes, + ) def prepare_chunk(self, file_id: int, start_time: float, duration: float): """Prepare chunk @@ -792,63 +774,21 @@ def segmentation_loss( Permutation-invariant segmentation loss """ - if self.specifications[0].powerset: - # `clamp_min` is needed to set non-speech weight to 1. - class_weight = ( - torch.clamp_min(self.model.powerset.cardinality, 1.0) - if self.weigh_by_cardinality - else None - ) - seg_loss = nll_loss( - permutated_prediction, - torch.argmax(target, dim=-1), - class_weight=class_weight, - weight=weight, - ) - else: - seg_loss = binary_cross_entropy( - permutated_prediction, target.float(), weight=weight - ) + # `clamp_min` is needed to set non-speech weight to 1. 
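+        # (the permutated powerset target is one-hot over powerset classes, so the
+        # argmax below recovers the class indices expected by nll_loss.)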
+ class_weight = ( + torch.clamp_min(self.model.powerset.cardinality, 1.0) + if self.weigh_by_cardinality + else None + ) + seg_loss = nll_loss( + permutated_prediction, + torch.argmax(target, dim=-1), + class_weight=class_weight, + weight=weight, + ) return seg_loss - def voice_activity_detection_loss( - self, - permutated_prediction: torch.Tensor, - target: torch.Tensor, - weight: torch.Tensor = None, - ) -> torch.Tensor: - """Voice activity detection loss - - Parameters - ---------- - permutated_prediction : (batch_size, num_frames, num_classes) torch.Tensor - Speaker activity predictions. - target : (batch_size, num_frames, num_speakers) torch.Tensor - Speaker activity. - weight : (batch_size, num_frames, 1) torch.Tensor, optional - Frames weight. - - Returns - ------- - vad_loss : torch.Tensor - Voice activity detection loss. - """ - - vad_prediction, _ = torch.max(permutated_prediction, dim=2, keepdim=True) - # (batch_size, num_frames, 1) - - vad_target, _ = torch.max(target.float(), dim=2, keepdim=False) - # (batch_size, num_frames) - - if self.vad_loss == "bce": - loss = binary_cross_entropy(vad_prediction, vad_target, weight=weight) - - elif self.vad_loss == "mse": - loss = mse_loss(vad_prediction, vad_target, weight=weight) - - return loss - def training_step(self, batch, batch_idx: int): """Compute permutation-invariant segmentation loss @@ -918,27 +858,15 @@ def training_step(self, batch, batch_idx: int): ) # (batch_size, num_frames, 1) - # warm-up - warm_up_left = round(self.warm_up[0] / self.duration * num_frames) - weight[:, :warm_up_left] = 0.0 - warm_up_right = round(self.warm_up[1] / self.duration * num_frames) - weight[:, num_frames - warm_up_right :] = 0.0 - - if self.specifications[0].powerset: - multilabel = self.model.powerset.to_multilabel(prediction) - permutated_target, permutations = permutate(multilabel, target) - permutated_target_powerset = self.model.powerset.to_powerset( - permutated_target.float() - ) - seg_loss = self.segmentation_loss( - prediction, permutated_target_powerset, weight=weight - ) + multilabel = self.model.powerset.to_multilabel(prediction) + permutated_target, permutations = permutate(multilabel, target) + permutated_target_powerset = self.model.powerset.to_powerset( + permutated_target.float() + ) + seg_loss = self.segmentation_loss( + prediction, permutated_target_powerset, weight=weight + ) - else: - permutated_prediction, permutations = permutate(target, prediction) - seg_loss = self.segmentation_loss( - permutated_prediction, target, weight=weight - ) # to find which predicted sources correspond to which mixtures, we need to invert the permutations permutations_inverse = torch.argsort(torch.tensor(permutations)) predicted_sources_idx_mix1 = [[permutations_inverse[i][j] for j in range(batch["meta"]["sources_from_first_mixture"][i])] for i in range(batch_size)] @@ -976,32 +904,7 @@ def training_step(self, batch, batch_idx: int): logger=True, ) - if self.vad_loss is None: - vad_loss = 0.0 - - else: - # TODO: vad_loss probably does not make sense in powerset mode - # because first class (empty set of labels) does exactly this... 
- if self.specifications[0].powerset: - vad_loss = self.voice_activity_detection_loss( - prediction, permutated_target_powerset, weight=weight - ) - - else: - vad_loss = self.voice_activity_detection_loss( - permutated_prediction, target, weight=weight - ) - - self.model.log( - "loss/train/vad", - vad_loss, - on_step=False, - on_epoch=True, - prog_bar=False, - logger=True, - ) - - loss = (1 - self.mixit_loss_weight) * (seg_loss + vad_loss) + self.mixit_loss_weight * mixit_loss + forced_alignment_loss + loss = (1 - self.mixit_loss_weight) * seg_loss + self.mixit_loss_weight * mixit_loss + forced_alignment_loss # skip batch if something went wrong for some reason if torch.isnan(loss): @@ -1023,20 +926,11 @@ def default_metric( ) -> Union[Metric, Sequence[Metric], Dict[str, Metric]]: """Returns diarization error rate and its components""" - if self.specifications[0].powerset: - return { - "DiarizationErrorRate": DiarizationErrorRate(0.5), - "DiarizationErrorRate/Confusion": SpeakerConfusionRate(0.5), - "DiarizationErrorRate/Miss": MissedDetectionRate(0.5), - "DiarizationErrorRate/FalseAlarm": FalseAlarmRate(0.5), - } - return { - "DiarizationErrorRate": OptimalDiarizationErrorRate(), - "DiarizationErrorRate/Threshold": OptimalDiarizationErrorRateThreshold(), - "DiarizationErrorRate/Confusion": OptimalSpeakerConfusionRate(), - "DiarizationErrorRate/Miss": OptimalMissedDetectionRate(), - "DiarizationErrorRate/FalseAlarm": OptimalFalseAlarmRate(), + "DiarizationErrorRate": DiarizationErrorRate(0.5), + "DiarizationErrorRate/Confusion": SpeakerConfusionRate(0.5), + "DiarizationErrorRate/Miss": MissedDetectionRate(0.5), + "DiarizationErrorRate/FalseAlarm": FalseAlarmRate(0.5), } # TODO: no need to compute gradient in this method @@ -1080,30 +974,18 @@ def validation_step(self, batch, batch_idx: int): ) # (batch_size, num_frames, 1) - # warm-up - warm_up_left = round(self.warm_up[0] / self.duration * num_frames) - weight[:, :warm_up_left] = 0.0 - warm_up_right = round(self.warm_up[1] / self.duration * num_frames) - weight[:, num_frames - warm_up_right :] = 0.0 - - if self.specifications[0].powerset: - multilabel = self.model.powerset.to_multilabel(prediction) - permutated_target, _ = permutate(multilabel, target) + multilabel = self.model.powerset.to_multilabel(prediction) + permutated_target, _ = permutate(multilabel, target) - # FIXME: handle case where target have too many speakers? - # since we don't need - permutated_target_powerset = self.model.powerset.to_powerset( - permutated_target.float() - ) - seg_loss = self.segmentation_loss( - prediction, permutated_target_powerset, weight=weight - ) + # FIXME: handle case where target have too many speakers? + # since we don't need + permutated_target_powerset = self.model.powerset.to_powerset( + permutated_target.float() + ) + seg_loss = self.segmentation_loss( + prediction, permutated_target_powerset, weight=weight + ) - else: - permutated_prediction, _ = permutate(target, prediction) - seg_loss = self.segmentation_loss( - permutated_prediction, target, weight=weight - ) # forced alignment mixit can't be implemented for validation because since data loading is different mixit_loss = 0 # mixit_loss = self.separation_loss( @@ -1128,32 +1010,7 @@ def validation_step(self, batch, batch_idx: int): logger=True, ) - if self.vad_loss is None: - vad_loss = 0.0 - - else: - # TODO: vad_loss probably does not make sense in powerset mode - # because first class (empty set of labels) does exactly this... 
- if self.specifications[0].powerset: - vad_loss = self.voice_activity_detection_loss( - prediction, permutated_target_powerset, weight=weight - ) - - else: - vad_loss = self.voice_activity_detection_loss( - permutated_prediction, target, weight=weight - ) - - self.model.log( - "loss/val/vad", - vad_loss, - on_step=False, - on_epoch=True, - prog_bar=False, - logger=True, - ) - - loss = (1 - self.mixit_loss_weight) * (seg_loss + vad_loss) + self.mixit_loss_weight * mixit_loss + loss = (1 - self.mixit_loss_weight) * seg_loss + self.mixit_loss_weight * mixit_loss self.model.log( "loss/val", @@ -1164,24 +1021,15 @@ def validation_step(self, batch, batch_idx: int): logger=True, ) - if self.specifications[0].powerset: - self.model.validation_metric( - torch.transpose( - multilabel[:, warm_up_left : num_frames - warm_up_right], 1, 2 - ), - torch.transpose( - target[:, warm_up_left : num_frames - warm_up_right], 1, 2 - ), - ) - else: - self.model.validation_metric( - torch.transpose( - prediction[:, warm_up_left : num_frames - warm_up_right], 1, 2 - ), - torch.transpose( - target[:, warm_up_left : num_frames - warm_up_right], 1, 2 - ), - ) + self.model.validation_metric( + torch.transpose( + multilabel, 1, 2 + ), + torch.transpose( + target, 1, 2 + ), + ) + self.model.log_dict( self.model.validation_metric, @@ -1201,12 +1049,8 @@ def validation_step(self, batch, batch_idx: int): # visualize first 9 validation samples of first batch in Tensorboard/MLflow - if self.specifications[0].powerset: - y = permutated_target.float().cpu().numpy() - y_pred = multilabel.cpu().numpy() - else: - y = target.float().cpu().numpy() - y_pred = permutated_prediction.cpu().numpy() + y = permutated_target.float().cpu().numpy() + y_pred = multilabel.cpu().numpy() # prepare 3 x 3 grid (or smaller if batch size is smaller) num_samples = min(self.batch_size, 9) @@ -1240,10 +1084,6 @@ def validation_step(self, batch, batch_idx: int): # plot predictions ax_hyp = axes[row_idx * 2 + 1, col_idx] sample_y_pred = y_pred[sample_idx] - ax_hyp.axvspan(0, warm_up_left, color="k", alpha=0.5, lw=0) - ax_hyp.axvspan( - num_frames - warm_up_right, num_frames, color="k", alpha=0.5, lw=0 - ) ax_hyp.plot(sample_y_pred) ax_hyp.set_ylim(-0.1, 1.1) ax_hyp.set_xlim(0, len(sample_y)) From 5037031398e191d6f2346324de5ab2c41afb01a4 Mon Sep 17 00:00:00 2001 From: joonaskalda Date: Tue, 20 Jun 2023 19:41:29 +0200 Subject: [PATCH 34/55] remove double check of num_speakers --- .../tasks/segmentation/speaker_separation_diarization.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/pyannote/audio/tasks/segmentation/speaker_separation_diarization.py b/pyannote/audio/tasks/segmentation/speaker_separation_diarization.py index fa0e8e8f7..56284b501 100644 --- a/pyannote/audio/tasks/segmentation/speaker_separation_diarization.py +++ b/pyannote/audio/tasks/segmentation/speaker_separation_diarization.py @@ -814,9 +814,6 @@ def training_step(self, batch, batch_idx: int): # drop samples that contain too many speakers num_speakers: torch.Tensor = torch.sum(torch.any(target, dim=1), dim=1) - keep: torch.Tensor = num_speakers <= self.max_speakers_per_chunk - target = target[keep] - waveform = waveform[keep] # corner case if not keep.any(): @@ -839,12 +836,6 @@ def training_step(self, batch, batch_idx: int): _, predicted_sources_mix1 = self.model(mix1) _, predicted_sources_mix2 = self.model(mix2) - # don't use moms with more than max_speakers_per_chunk speakers for training speaker diarization - num_speakers: torch.Tensor = torch.sum(torch.any(target, dim=1), 
dim=1) - num_speakers[2::3] = num_speakers[::3] + num_speakers[1::3] - keep: torch.Tensor = num_speakers <= self.max_speakers_per_chunk - target = target[keep] - waveform = waveform[keep] prediction, _ = self.model(waveform) batch_size, num_frames, _ = prediction.shape From 9f20916e53479957dc075b71591def7e2bb0c8ce Mon Sep 17 00:00:00 2001 From: joonaskalda Date: Wed, 21 Jun 2023 14:50:04 +0200 Subject: [PATCH 35/55] refactor: moved mom constrcution --- .../speaker_separation_diarization.py | 207 +++++++++--------- 1 file changed, 99 insertions(+), 108 deletions(-) diff --git a/pyannote/audio/tasks/segmentation/speaker_separation_diarization.py b/pyannote/audio/tasks/segmentation/speaker_separation_diarization.py index 56284b501..7b3102169 100644 --- a/pyannote/audio/tasks/segmentation/speaker_separation_diarization.py +++ b/pyannote/audio/tasks/segmentation/speaker_separation_diarization.py @@ -231,7 +231,9 @@ def __init__( metric: Union[Metric, Sequence[Metric], Dict[str, Metric]] = None, max_num_speakers: int = None, # deprecated in favor of `max_speakers_per_chunk`` loss: Literal["bce", "mse"] = None, # deprecated - mixit_loss_weight: float = 0.2, + mixit_loss_weight: float = 0.5, + original_mixtures_for_separation: bool = False, + forced_alignment_weight: float = 0.0, ): super().__init__( protocol, @@ -264,8 +266,8 @@ def __init__( f"`max_speakers_per_frame` must be 1 or more (you used {max_speakers_per_frame})." ) - if batch_size % 3 != 0: - raise ValueError("`batch_size` must be divisible by 3 for mixtures of mixtures training") + if batch_size % 2 != 0: + raise ValueError("`batch_size` must be divisible by 2 for mixtures of mixtures training") self.max_speakers_per_chunk = max_speakers_per_chunk self.max_speakers_per_frame = max_speakers_per_frame @@ -274,6 +276,8 @@ def __init__( self.weight = weight self.separation_loss = ModifiedMixITLossWrapper(multisrc_neg_sisdr, generalized=True) self.mixit_loss_weight = mixit_loss_weight + self.original_mixtures_for_separation = original_mixtures_for_separation + self.forced_alignment_weight = forced_alignment_weight def setup(self): super().setup() @@ -423,7 +427,6 @@ def prepare_chunk(self, file_id: int, start_time: float, duration: float): # use model introspection to predict how many frames it will output # TODO: this should be cached num_samples = sample["X"].shape[1] - #resolution_samples = self.model.example_output[0].frames.step * self.model.example_output[0].num_frames / num_samples # gather all annotations of current file annotations = self.annotations[self.annotations["file_id"] == file_id] @@ -436,10 +439,8 @@ def prepare_chunk(self, file_id: int, start_time: float, duration: float): # discretize chunk annotations at model output and input resolutions start = np.maximum(chunk_annotations["start"], chunk.start) - chunk.start start_idx = np.floor(start / self.model.example_output[0].frames.step).astype(int) - start_idx_samples = np.floor(start * 16000).astype(int) end = np.minimum(chunk_annotations["end"], chunk.end) - chunk.start end_idx = np.ceil(end / self.model.example_output[0].frames.step).astype(int) - end_idx_samples = np.floor(end * 16000).astype(int) # get list and number of labels for current scope labels = list(np.unique(chunk_annotations[label_scope_key])) @@ -450,7 +451,6 @@ def prepare_chunk(self, file_id: int, start_time: float, duration: float): # initial frame-level targets y = np.zeros((self.model.example_output[0].num_frames, num_labels), dtype=np.uint8) - sample_level_labels = np.zeros((num_samples, num_labels), 
dtype=np.uint8) # map labels to indices mapping = {label: idx for idx, label in enumerate(labels)} @@ -464,15 +464,20 @@ def prepare_chunk(self, file_id: int, start_time: float, duration: float): sample["y"] = SlidingWindowFeature( y, self.model.example_output[0].frames, labels=labels ) - - for start, end, label in zip( - start_idx_samples, end_idx_samples, chunk_annotations[label_scope_key] - ): - mapped_label = mapping[label] - sample_level_labels[start:end, mapped_label] = 1 - # only frames with a single label should be used for mixit training - sample["X_separation_mask"] = torch.from_numpy(sample_level_labels.sum(axis=1) == 1) + if self.original_mixtures_for_separation: + start_idx_samples = np.floor(start * 16000).astype(int) + end_idx_samples = np.floor(end * 16000).astype(int) + sample_level_labels = np.zeros((num_samples, num_labels), dtype=np.uint8) + for start, end, label in zip( + start_idx_samples, end_idx_samples, chunk_annotations[label_scope_key] + ): + mapped_label = mapping[label] + sample_level_labels[start:end, mapped_label] = 1 + + # only frames with a single label should be used for mixit training + sample["X_separation_mask"] = torch.from_numpy(sample_level_labels.sum(axis=1) == 1) + metadata = self.metadata[file_id] sample["meta"] = {key: metadata[key] for key in metadata.dtype.names} sample["meta"]["file"] = file_id @@ -546,11 +551,6 @@ def train__iter__helper(self, rng: random.Random, **filters): if repeated_speaker_annotations.size == 0: # if previous chunk has 0 speakers then just sample from all annotated regions again first_chunk = self.prepare_chunk(file_id, start_time, duration) - first_chunk["meta"]["mixture_type"]="first_mixture" - # in order to align separation and diarization branches we need to know which mixtures do speakers/sources originate from - first_chunk["meta"]["sources_from_first_mixture"] = len(first_chunk["y"].labels) - first_chunk["meta"]["sources_from_second_mixture"] = 0 - # yield first_chunk # selected one annotated region at random (with probability proportional to its duration) annotated_region_index = np.random.choice( @@ -562,30 +562,12 @@ def train__iter__helper(self, rng: random.Random, **filters): start_time = rng.uniform(start, end - duration) second_chunk = self.prepare_chunk(file_id, start_time, duration) - second_chunk["meta"]["mixture_type"]="second_mixture" - second_chunk["meta"]["sources_from_first_mixture"] = 0 - second_chunk["meta"]["sources_from_second_mixture"] = len(second_chunk["y"].labels) - # yield second_chunk - - # add previous two chunks to get a third one - third_chunk = dict() - third_chunk["X"] = first_chunk["X"] + second_chunk["X"] - third_chunk["meta"] = first_chunk["meta"].copy() - y = np.concatenate((first_chunk["y"].data, second_chunk["y"].data), axis=1) - frames = first_chunk["y"].sliding_window - labels = first_chunk["y"].labels + second_chunk["y"].labels - third_chunk["y"] = SlidingWindowFeature(y, frames, labels=labels) - third_chunk["meta"]["mixture_type"]="mom" - third_chunk["meta"]["sources_from_first_mixture"] = len(first_chunk["y"].labels) - third_chunk["meta"]["sources_from_second_mixture"] = len(second_chunk["y"].labels) - # the whole mom should be used in the separation branch training - third_chunk["X_separation_mask"] = torch.ones_like(first_chunk["X_separation_mask"]) + labels = first_chunk["y"].labels + second_chunk["y"].labels - if len(labels) < 4: + if len(labels) <= self.max_speakers_per_chunk: yield first_chunk yield second_chunk - yield third_chunk else: # merge segments that contain 
repeated speakers @@ -614,14 +596,8 @@ def train__iter__helper(self, rng: random.Random, **filters): segments_without_repeat = np.array(segments_without_repeat,dtype=dtype) if np.sum(segments_without_repeat["duration"]) != 0: - # only yield chunks if it is possible to choose the second chunk so that yielded chunks are always paired - first_chunk = self.prepare_chunk(file_id, start_time, duration) - first_chunk["meta"]["mixture_type"]="first_mixture" - first_chunk["meta"]["sources_from_first_mixture"] = len(first_chunk["y"].labels) - first_chunk["meta"]["sources_from_second_mixture"] = 0 - #yield first_chunk prob_segments_duration = segments_without_repeat["duration"] / np.sum(segments_without_repeat["duration"]) segment = np.random.choice( @@ -631,30 +607,11 @@ def train__iter__helper(self, rng: random.Random, **filters): start, end, _ = segment new_start_time = rng.uniform(start, end - duration) second_chunk = self.prepare_chunk(file_id, new_start_time, duration) - second_chunk["meta"]["mixture_type"]="second_mixture" - second_chunk["meta"]["sources_from_first_mixture"] = 0 - second_chunk["meta"]["sources_from_second_mixture"] = len(second_chunk["y"].labels) - #yield second_chunk - - #add previous two chunks to get a third one - third_chunk = dict() - third_chunk["X"] = first_chunk["X"] + second_chunk["X"] - third_chunk["meta"] = first_chunk["meta"].copy() - y = np.concatenate((first_chunk["y"].data, second_chunk["y"].data), axis=1) - frames = first_chunk["y"].sliding_window + labels = first_chunk["y"].labels + second_chunk["y"].labels - third_chunk["y"] = SlidingWindowFeature(y, frames, labels=labels) - third_chunk["meta"]["mixture_type"]="mom" - - # the whole mom should be used in the separation branch training - third_chunk["X_separation_mask"] = torch.ones_like(first_chunk["X_separation_mask"]) - third_chunk["meta"]["sources_from_first_mixture"] = len(first_chunk["y"].labels) - third_chunk["meta"]["sources_from_second_mixture"] = len(second_chunk["y"].labels) - #third_chunk["sources_from_first_mixture"] = len(first_chunk["y"].labels) - if len(labels) < 4: + if len(labels) <= self.max_speakers_per_chunk: yield first_chunk yield second_chunk - yield third_chunk def collate_X_separation_mask(self, batch) -> torch.Tensor: return default_collate([b["X_separation_mask"] for b in batch]) @@ -688,7 +645,8 @@ def collate_fn(self, batch, stage="train"): # collate metadata collated_meta = self.collate_meta(batch) - collated_X_separation_mask = self.collate_X_separation_mask(batch) + if self.original_mixtures_for_separation: + collated_X_separation_mask = self.collate_X_separation_mask(batch) # apply augmentation (only in "train" stage) self.augmentation.train(mode=(stage == "train")) @@ -698,11 +656,17 @@ def collate_fn(self, batch, stage="train"): targets=collated_y.unsqueeze(1), ) + if self.original_mixtures_for_separation: + return { + "X": augmented.samples, + "y": augmented.targets.squeeze(1), + "meta": collated_meta, + "X_separation_mask" : collated_X_separation_mask + } return { "X": augmented.samples, "y": augmented.targets.squeeze(1), - "meta": collated_meta, - "X_separation_mask" : collated_X_separation_mask + "meta": collated_meta } def collate_y(self, batch) -> torch.Tensor: @@ -789,6 +753,28 @@ def segmentation_loss( return seg_loss + def create_mixtures_of_mixtures(self, mix1, mix2, target1, target2): + """ + Creates mixtures of mixtures from two mixtures and their targets.""" + # mapping1[i1] = i means that i1-th speaker of mix1/tgt1 has been mapped to i-th speaker of mom/tgt + # 
mapping2[i2] = i means that i2-th speaker of mix2/tgt2 has been mapped to i-th speaker of mom/tgt + # mapping1[i1] = None means that i1-th (inactive) speaker of mix1/tgt1 does not exist in mom/tgt + batch_size = mix1.shape[0] + mom = mix1 + mix2 + # (batch_size, num_speakers, num_frames) + num_active_speakers_mix1 = (target1.sum(dim=1) != 0).sum(dim=1) + num_active_speakers_mix2 = (target2.sum(dim=1) != 0).sum(dim=1) + targets = [] + for i in range(batch_size): + target = torch.cat((target1[i][:, target1[i].sum(dim=0) != 0], target2[i][:, target2[i].sum(dim=0) != 0]), dim=1) + padding_dim = target1.shape[2] - num_active_speakers_mix1[i] - num_active_speakers_mix2[i] + padding_tensor = torch.zeros((target1.shape[1], padding_dim), device=target.device) + target = torch.cat((target, padding_tensor), dim=1) + targets.append(target) + targets = torch.stack(targets) + + return mom, targets, num_active_speakers_mix1, num_active_speakers_mix2 + def training_step(self, batch, batch_idx: int): """Compute permutation-invariant segmentation loss @@ -815,30 +801,27 @@ def training_step(self, batch, batch_idx: int): # drop samples that contain too many speakers num_speakers: torch.Tensor = torch.sum(torch.any(target, dim=1), dim=1) - # corner case - if not keep.any(): - return None - # forward pass bsz = waveform.shape[0] num_samples = waveform.shape[2] - mix1 = waveform[0::3].squeeze(1) - mix2 = waveform[1::3].squeeze(1) - # extract parts with only one speaker from original mixtures - mix1_masks = batch["X_separation_mask"][0::3] - mix2_masks = batch["X_separation_mask"][1::3] - mix1_masked = mix1 * mix1_masks - mix2_masked = mix2 * mix2_masks - - moms = mix1 + mix2 - - _, predicted_sources_mom = self.model(moms) - _, predicted_sources_mix1 = self.model(mix1) - _, predicted_sources_mix2 = self.model(mix2) - - prediction, _ = self.model(waveform) - - batch_size, num_frames, _ = prediction.shape + mix1 = waveform[0::2].squeeze(1) + mix2 = waveform[1::2].squeeze(1) + if self.original_mixtures_for_separation: + # extract parts with only one speaker from original mixtures + mix1_masks = batch["X_separation_mask"][0::2] + mix2_masks = batch["X_separation_mask"][1::2] + mix1_masked = mix1 * mix1_masks + mix2_masked = mix2 * mix2_masks + + mom, mom_target, num_active_speakers_mix1, num_active_speakers_mix2 = self.create_mixtures_of_mixtures(mix1, mix2, target[0::2], target[1::2]) + target = torch.cat((target[0::2], target[1::2], mom_target), dim=0) + + diarization, sources = self.model(torch.cat((mix1, mix2, mom), dim=0)) + mix1_sources = sources[:bsz // 2] + mix2_sources = sources[bsz // 2:bsz] + mom_sources = sources[bsz:] + + batch_size, num_frames, _ = diarization.shape # (batch_size, num_frames, num_classes) # frames weight @@ -849,32 +832,40 @@ def training_step(self, batch, batch_idx: int): ) # (batch_size, num_frames, 1) - multilabel = self.model.powerset.to_multilabel(prediction) + multilabel = self.model.powerset.to_multilabel(diarization) permutated_target, permutations = permutate(multilabel, target) permutated_target_powerset = self.model.powerset.to_powerset( permutated_target.float() ) seg_loss = self.segmentation_loss( - prediction, permutated_target_powerset, weight=weight + diarization, permutated_target_powerset, weight=weight ) # to find which predicted sources correspond to which mixtures, we need to invert the permutations permutations_inverse = torch.argsort(torch.tensor(permutations)) - predicted_sources_idx_mix1 = [[permutations_inverse[i][j] for j in 
range(batch["meta"]["sources_from_first_mixture"][i])] for i in range(batch_size)] - predicted_sources_idx_mix2 = [[permutations_inverse[i][j] for j in range(batch["meta"]["sources_from_first_mixture"][i],batch["meta"]["sources_from_second_mixture"][i])] for i in range(batch_size)] + speaker_idx_mix1 = [[permutations_inverse[i][j] for j in range(num_active_speakers_mix1[i])] for i in range(bsz//2)] + speaker_idx_mix2 = [[permutations_inverse[i][j] for j in range(num_active_speakers_mix1[i], num_active_speakers_mix2[i])] for i in range(bsz//2)] # contributions from original mixtures is weighed by the proportion of remaining frames - mixit_loss = self.separation_loss( - predicted_sources_mom.transpose(1, 2), torch.stack((mix1, mix2)).transpose(0, 1), predicted_sources_idx_mix1[2::3], predicted_sources_idx_mix2[2::3] - ) + self.separation_loss( - predicted_sources_mix1.transpose(1, 2), torch.stack((mix1_masked, torch.zeros_like(mix1))).transpose(0, 1), predicted_sources_idx_mix1[0::3], predicted_sources_idx_mix2[0::3] - ) * mix1_masks.sum() / num_samples / bsz * 3 + self.separation_loss( - predicted_sources_mix2.transpose(1, 2), torch.stack((mix2_masked, torch.zeros_like(mix2))).transpose(0, 1), predicted_sources_idx_mix1[1::3], predicted_sources_idx_mix2[1::3] - ) * mix2_masks.sum() / num_samples / bsz * 3 + est_mixes = [] + for i in range(bsz//2): + est_mix1 = mom_sources[i, :, speaker_idx_mix1[i]].sum(1) + est_mix2 = mom_sources[i, :, speaker_idx_mix2[i]].sum(1) + est_mixes.append(torch.stack((est_mix1, est_mix2))) + est_mixes = torch.stack(est_mixes) + mixit_loss = multisrc_neg_sisdr(est_mixes, torch.stack((mix1, mix2)).transpose(0, 1)).mean() + + if self.original_mixtures_for_separation: + raise NotImplementedError + # mixit_loss += self.separation_loss( + # predicted_sources_mix1.transpose(1, 2), torch.stack((mix1_masked, torch.zeros_like(mix1))).transpose(0, 1), speaker_idx_mix1[0::3], speaker_idx_mix2[0::3] + # ) * mix1_masks.sum() / num_samples / bsz * 3 + self.separation_loss( + # predicted_sources_mix2.transpose(1, 2), torch.stack((mix2_masked, torch.zeros_like(mix2))).transpose(0, 1), speaker_idx_mix1[1::3], speaker_idx_mix2[1::3] + # ) * mix2_masks.sum() / num_samples / bsz * 3 upscaled_permutated_target = torch.nn.functional.interpolate(permutated_target.transpose(1, 2), size=(80000)).transpose(1, 2) - forced_alignment_loss = (1 - 2 * upscaled_permutated_target[::3]) * predicted_sources_mix1 ** 2 +\ - (1 - 2 * upscaled_permutated_target[1::3]) * predicted_sources_mix2 ** 2 +\ - (1 - 2 * upscaled_permutated_target[2::3]) * predicted_sources_mom ** 2 + forced_alignment_loss = (1 - 2 * upscaled_permutated_target[:bsz//2]) * mix1_sources ** 2 +\ + (1 - 2 * upscaled_permutated_target[bsz//2:bsz]) * mix2_sources ** 2 +\ + (1 - 2 * upscaled_permutated_target[bsz:]) * mom_sources ** 2 forced_alignment_loss = forced_alignment_loss.mean() / 3 self.model.log( @@ -895,7 +886,7 @@ def training_step(self, batch, batch_idx: int): logger=True, ) - loss = (1 - self.mixit_loss_weight) * seg_loss + self.mixit_loss_weight * mixit_loss + forced_alignment_loss + loss = (1 - self.mixit_loss_weight) * seg_loss + self.mixit_loss_weight * mixit_loss + forced_alignment_loss * self.forced_alignment_weight # skip batch if something went wrong for some reason if torch.isnan(loss): @@ -954,7 +945,7 @@ def validation_step(self, batch, batch_idx: int): # forward pass prediction, _ = self.model(waveform) - _, predicted_sources_mom = self.model(moms) + _, mom_sources = self.model(moms) batch_size, num_frames, _ = 
prediction.shape # frames weight @@ -980,7 +971,7 @@ def validation_step(self, batch, batch_idx: int): # forced alignment mixit can't be implemented for validation because since data loading is different mixit_loss = 0 # mixit_loss = self.separation_loss( - # predicted_sources_mom.transpose(1, 2), torch.stack((mix1, mix2)).transpose(0, 1) + # mom_sources.transpose(1, 2), torch.stack((mix1, mix2)).transpose(0, 1) # ) self.model.log( From d09e828b4535ad9245d0b094d31b37c796eb6ad2 Mon Sep 17 00:00:00 2001 From: joonaskalda Date: Wed, 21 Jun 2023 14:52:28 +0200 Subject: [PATCH 36/55] remove unused mixit wrapper --- .../speaker_separation_diarization.py | 93 ------------------- 1 file changed, 93 deletions(-) diff --git a/pyannote/audio/tasks/segmentation/speaker_separation_diarization.py b/pyannote/audio/tasks/segmentation/speaker_separation_diarization.py index 7b3102169..e2b7cbebf 100644 --- a/pyannote/audio/tasks/segmentation/speaker_separation_diarization.py +++ b/pyannote/audio/tasks/segmentation/speaker_separation_diarization.py @@ -63,98 +63,6 @@ from itertools import combinations from torch import nn -class ModifiedMixITLossWrapper(nn.Module): - r"""Mixture invariant loss wrapper modifed to force alignment between separation and diarization. - - Args: - loss_func: function with signature (est_targets, targets, **kwargs). - generalized (bool): Determines how MixIT is applied. If False , - apply MixIT for any number of mixtures as soon as they contain - the same number of sources (:meth:`~MixITLossWrapper.best_part_mixit`.) - If True (default), apply MixIT for two mixtures, but those mixtures do not - necessarly have to contain the same number of sources. - See :meth:`~MixITLossWrapper.best_part_mixit_generalized`. - reduction (string, optional): Specifies the reduction to apply to - the output: - ``'none'`` | ``'mean'``. ``'none'``: no reduction will be applied, - ``'mean'``: the sum of the output will be divided by the number of - elements in the output. - - For each of these modes, the best partition and reordering will be - automatically computed. - - Examples: - >>> import torch - >>> from asteroid.losses import multisrc_mse - >>> mixtures = torch.randn(10, 2, 16000) - >>> est_sources = torch.randn(10, 4, 16000) - >>> # Compute MixIT loss based on pairwise losses - >>> loss_func = MixITLossWrapper(multisrc_mse) - >>> loss_val = loss_func(est_sources, mixtures) - - References - [1] Scott Wisdom et al. "Unsupervised sound separation using - mixtures of mixtures." arXiv:2006.12701 (2020) - """ - - def __init__(self, loss_func, generalized=True, reduction="mean"): - super().__init__() - self.loss_func = loss_func - self.generalized = generalized - self.reduction = reduction - - def forward(self, est_targets, targets, part_from_mix1, part_from_mix2, return_est=False, **kwargs): - r"""Find the best partition and return the loss. - - Args: - est_targets: torch.Tensor. Expected shape :math:`(batch, nsrc, *)`. - The batch of target estimates. - targets: torch.Tensor. Expected shape :math:`(batch, nmix, ...)`. - The batch of training targets - return_est: Boolean. Whether to return the estimated mixtures - estimates (To compute metrics or to save example). - **kwargs: additional keyword argument that will be passed to the - loss function. - - Returns: - - Best partition loss for each batch sample, average over - the batch. torch.Tensor(loss_value) - - The estimated mixtures (estimated sources summed according to the partition) - if return_est is True. 
torch.Tensor of shape :math:`(batch, nmix, ...)`. - """ - # Check input dimensions - assert est_targets.shape[0] == targets.shape[0] - assert est_targets.shape[2] == targets.shape[2] - - # if not self.generalized: - # min_loss, min_loss_idx, parts = self.best_part_mixit( - # self.loss_func, est_targets, targets, **kwargs - # ) - # else: - # min_loss, min_loss_idx, parts = self.best_part_mixit_generalized( - # self.loss_func, est_targets, targets, **kwargs - # ) - est_mixes = [] - for i in range(est_targets.shape[0]): - # sum the sources according to the given partition - est_mix1 = est_targets[i, part_from_mix1[i], :].sum(0) - est_mix2 = est_targets[i, part_from_mix2[i], :].sum(0) - # get loss for the given partition - - est_mixes.append(torch.stack((est_mix1, est_mix2))) - est_mixes = torch.stack(est_mixes) - loss_partition = self.loss_func(est_mixes, targets, **kwargs) - if loss_partition.ndim != 1: - raise ValueError("Loss function return value should be of size (batch,).") - - # Apply any reductions over the batch axis - returned_loss = loss_partition.mean() if self.reduction == "mean" else loss_partition - if not return_est: - return returned_loss - - # Order and sum on the best partition to get the estimated mixtures - # reordered = self.reorder_source(est_targets, targets, min_loss_idx, parts) - return returned_loss, est_mixes class JointSpeakerSeparationAndDiarization(SegmentationTaskMixin, Task): """Speaker diarization @@ -274,7 +182,6 @@ def __init__( self.weigh_by_cardinality = weigh_by_cardinality self.balance = balance self.weight = weight - self.separation_loss = ModifiedMixITLossWrapper(multisrc_neg_sisdr, generalized=True) self.mixit_loss_weight = mixit_loss_weight self.original_mixtures_for_separation = original_mixtures_for_separation self.forced_alignment_weight = forced_alignment_weight From 9605155443ad93bb152dfa3a5404ada691c081c3 Mon Sep 17 00:00:00 2001 From: joonaskalda Date: Wed, 21 Jun 2023 14:59:41 +0200 Subject: [PATCH 37/55] format with black --- .../speaker_separation_diarization.py | 168 +++++++++++++----- 1 file changed, 119 insertions(+), 49 deletions(-) diff --git a/pyannote/audio/tasks/segmentation/speaker_separation_diarization.py b/pyannote/audio/tasks/segmentation/speaker_separation_diarization.py index e2b7cbebf..7a11d5d09 100644 --- a/pyannote/audio/tasks/segmentation/speaker_separation_diarization.py +++ b/pyannote/audio/tasks/segmentation/speaker_separation_diarization.py @@ -175,7 +175,9 @@ def __init__( ) if batch_size % 2 != 0: - raise ValueError("`batch_size` must be divisible by 2 for mixtures of mixtures training") + raise ValueError( + "`batch_size` must be divisible by 2 for mixtures of mixtures training" + ) self.max_speakers_per_chunk = max_speakers_per_chunk self.max_speakers_per_frame = max_speakers_per_frame @@ -282,7 +284,7 @@ def setup(self): speaker_separation = Specifications( duration=self.duration, resolution=Resolution.FRAME, - problem=Problem.MONO_LABEL_CLASSIFICATION, # Doesn't matter + problem=Problem.MONO_LABEL_CLASSIFICATION, # Doesn't matter classes=[f"speaker#{i+1}" for i in range(self.max_speakers_per_chunk)], ) @@ -345,7 +347,9 @@ def prepare_chunk(self, file_id: int, start_time: float, duration: float): # discretize chunk annotations at model output and input resolutions start = np.maximum(chunk_annotations["start"], chunk.start) - chunk.start - start_idx = np.floor(start / self.model.example_output[0].frames.step).astype(int) + start_idx = np.floor(start / self.model.example_output[0].frames.step).astype( + int + ) 
end = np.minimum(chunk_annotations["end"], chunk.end) - chunk.start end_idx = np.ceil(end / self.model.example_output[0].frames.step).astype(int) @@ -357,7 +361,9 @@ def prepare_chunk(self, file_id: int, start_time: float, duration: float): pass # initial frame-level targets - y = np.zeros((self.model.example_output[0].num_frames, num_labels), dtype=np.uint8) + y = np.zeros( + (self.model.example_output[0].num_frames, num_labels), dtype=np.uint8 + ) # map labels to indices mapping = {label: idx for idx, label in enumerate(labels)} @@ -383,7 +389,9 @@ def prepare_chunk(self, file_id: int, start_time: float, duration: float): sample_level_labels[start:end, mapped_label] = 1 # only frames with a single label should be used for mixit training - sample["X_separation_mask"] = torch.from_numpy(sample_level_labels.sum(axis=1) == 1) + sample["X_separation_mask"] = torch.from_numpy( + sample_level_labels.sum(axis=1) == 1 + ) metadata = self.metadata[file_id] sample["meta"] = {key: metadata[key] for key in metadata.dtype.names} @@ -425,7 +433,9 @@ def train__iter__helper(self, rng: random.Random, **filters): while True: # select one file at random (with probability proportional to its annotated duration) file_id = np.random.choice(file_ids, p=prob_annotated_duration) - annotations = self.annotations[np.where(self.annotations["file_id"] == file_id)[0]] + annotations = self.annotations[ + np.where(self.annotations["file_id"] == file_id)[0] + ] # generate `num_chunks_per_file` chunks from this file for _ in range(num_chunks_per_file): @@ -450,11 +460,16 @@ def train__iter__helper(self, rng: random.Random, **filters): # find speakers that already appeared and all annotations that contain them chunk_annotations = annotations[ - (annotations["start"] < start_time+duration) & (annotations["end"] > start_time) + (annotations["start"] < start_time + duration) + & (annotations["end"] > start_time) + ] + previous_speaker_labels = list( + np.unique(chunk_annotations["file_label_idx"]) + ) + repeated_speaker_annotations = annotations[ + np.isin(annotations["file_label_idx"], previous_speaker_labels) ] - previous_speaker_labels = list(np.unique(chunk_annotations["file_label_idx"])) - repeated_speaker_annotations = annotations[np.isin(annotations["file_label_idx"], previous_speaker_labels)] - + if repeated_speaker_annotations.size == 0: # if previous chunk has 0 speakers then just sample from all annotated regions again first_chunk = self.prepare_chunk(file_id, start_time, duration) @@ -475,45 +490,67 @@ def train__iter__helper(self, rng: random.Random, **filters): if len(labels) <= self.max_speakers_per_chunk: yield first_chunk yield second_chunk - + else: # merge segments that contain repeated speakers - merged_repeated_segments = [[repeated_speaker_annotations["start"][0],repeated_speaker_annotations["end"][0]]] + merged_repeated_segments = [ + [ + repeated_speaker_annotations["start"][0], + repeated_speaker_annotations["end"][0], + ] + ] for _, start, end, _, _, _ in repeated_speaker_annotations: previous = merged_repeated_segments[-1] if start <= previous[1]: previous[1] = max(previous[1], end) else: merged_repeated_segments.append([start, end]) - + # find segments that don't contain repeated speakers segments_without_repeat = [] current_region_index = 0 - previous_time = self.annotated_regions["start"][annotated_region_indices[0]] + previous_time = self.annotated_regions["start"][ + annotated_region_indices[0] + ] for segment in merged_repeated_segments: - if segment[0] > 
self.annotated_regions["end"][annotated_region_indices[current_region_index]]: - current_region_index+=1 - previous_time = self.annotated_regions["start"][annotated_region_indices[current_region_index]] - + if ( + segment[0] + > self.annotated_regions["end"][ + annotated_region_indices[current_region_index] + ] + ): + current_region_index += 1 + previous_time = self.annotated_regions["start"][ + annotated_region_indices[current_region_index] + ] + if segment[0] - previous_time > duration: - segments_without_repeat.append((previous_time, segment[0], segment[0] - previous_time)) + segments_without_repeat.append( + (previous_time, segment[0], segment[0] - previous_time) + ) previous_time = segment[1] - - dtype = [("start", "f"), ("end", "f"),("duration", "f")] - segments_without_repeat = np.array(segments_without_repeat,dtype=dtype) + + dtype = [("start", "f"), ("end", "f"), ("duration", "f")] + segments_without_repeat = np.array( + segments_without_repeat, dtype=dtype + ) if np.sum(segments_without_repeat["duration"]) != 0: # only yield chunks if it is possible to choose the second chunk so that yielded chunks are always paired first_chunk = self.prepare_chunk(file_id, start_time, duration) - prob_segments_duration = segments_without_repeat["duration"] / np.sum(segments_without_repeat["duration"]) + prob_segments_duration = segments_without_repeat[ + "duration" + ] / np.sum(segments_without_repeat["duration"]) segment = np.random.choice( segments_without_repeat, p=prob_segments_duration ) start, end, _ = segment new_start_time = rng.uniform(start, end - duration) - second_chunk = self.prepare_chunk(file_id, new_start_time, duration) + second_chunk = self.prepare_chunk( + file_id, new_start_time, duration + ) labels = first_chunk["y"].labels + second_chunk["y"].labels if len(labels) <= self.max_speakers_per_chunk: @@ -568,12 +605,12 @@ def collate_fn(self, batch, stage="train"): "X": augmented.samples, "y": augmented.targets.squeeze(1), "meta": collated_meta, - "X_separation_mask" : collated_X_separation_mask + "X_separation_mask": collated_X_separation_mask, } return { "X": augmented.samples, "y": augmented.targets.squeeze(1), - "meta": collated_meta + "meta": collated_meta, } def collate_y(self, batch) -> torch.Tensor: @@ -673,9 +710,21 @@ def create_mixtures_of_mixtures(self, mix1, mix2, target1, target2): num_active_speakers_mix2 = (target2.sum(dim=1) != 0).sum(dim=1) targets = [] for i in range(batch_size): - target = torch.cat((target1[i][:, target1[i].sum(dim=0) != 0], target2[i][:, target2[i].sum(dim=0) != 0]), dim=1) - padding_dim = target1.shape[2] - num_active_speakers_mix1[i] - num_active_speakers_mix2[i] - padding_tensor = torch.zeros((target1.shape[1], padding_dim), device=target.device) + target = torch.cat( + ( + target1[i][:, target1[i].sum(dim=0) != 0], + target2[i][:, target2[i].sum(dim=0) != 0], + ), + dim=1, + ) + padding_dim = ( + target1.shape[2] + - num_active_speakers_mix1[i] + - num_active_speakers_mix2[i] + ) + padding_tensor = torch.zeros( + (target1.shape[1], padding_dim), device=target.device + ) target = torch.cat((target, padding_tensor), dim=1) targets.append(target) targets = torch.stack(targets) @@ -720,12 +769,17 @@ def training_step(self, batch, batch_idx: int): mix1_masked = mix1 * mix1_masks mix2_masked = mix2 * mix2_masks - mom, mom_target, num_active_speakers_mix1, num_active_speakers_mix2 = self.create_mixtures_of_mixtures(mix1, mix2, target[0::2], target[1::2]) + ( + mom, + mom_target, + num_active_speakers_mix1, + num_active_speakers_mix2, + ) = 
self.create_mixtures_of_mixtures(mix1, mix2, target[0::2], target[1::2]) target = torch.cat((target[0::2], target[1::2], mom_target), dim=0) diarization, sources = self.model(torch.cat((mix1, mix2, mom), dim=0)) - mix1_sources = sources[:bsz // 2] - mix2_sources = sources[bsz // 2:bsz] + mix1_sources = sources[: bsz // 2] + mix2_sources = sources[bsz // 2 : bsz] mom_sources = sources[bsz:] batch_size, num_frames, _ = diarization.shape @@ -750,16 +804,27 @@ def training_step(self, batch, batch_idx: int): # to find which predicted sources correspond to which mixtures, we need to invert the permutations permutations_inverse = torch.argsort(torch.tensor(permutations)) - speaker_idx_mix1 = [[permutations_inverse[i][j] for j in range(num_active_speakers_mix1[i])] for i in range(bsz//2)] - speaker_idx_mix2 = [[permutations_inverse[i][j] for j in range(num_active_speakers_mix1[i], num_active_speakers_mix2[i])] for i in range(bsz//2)] + speaker_idx_mix1 = [ + [permutations_inverse[i][j] for j in range(num_active_speakers_mix1[i])] + for i in range(bsz // 2) + ] + speaker_idx_mix2 = [ + [ + permutations_inverse[i][j] + for j in range(num_active_speakers_mix1[i], num_active_speakers_mix2[i]) + ] + for i in range(bsz // 2) + ] # contributions from original mixtures is weighed by the proportion of remaining frames est_mixes = [] - for i in range(bsz//2): + for i in range(bsz // 2): est_mix1 = mom_sources[i, :, speaker_idx_mix1[i]].sum(1) est_mix2 = mom_sources[i, :, speaker_idx_mix2[i]].sum(1) est_mixes.append(torch.stack((est_mix1, est_mix2))) est_mixes = torch.stack(est_mixes) - mixit_loss = multisrc_neg_sisdr(est_mixes, torch.stack((mix1, mix2)).transpose(0, 1)).mean() + mixit_loss = multisrc_neg_sisdr( + est_mixes, torch.stack((mix1, mix2)).transpose(0, 1) + ).mean() if self.original_mixtures_for_separation: raise NotImplementedError @@ -769,10 +834,14 @@ def training_step(self, batch, batch_idx: int): # predicted_sources_mix2.transpose(1, 2), torch.stack((mix2_masked, torch.zeros_like(mix2))).transpose(0, 1), speaker_idx_mix1[1::3], speaker_idx_mix2[1::3] # ) * mix2_masks.sum() / num_samples / bsz * 3 - upscaled_permutated_target = torch.nn.functional.interpolate(permutated_target.transpose(1, 2), size=(80000)).transpose(1, 2) - forced_alignment_loss = (1 - 2 * upscaled_permutated_target[:bsz//2]) * mix1_sources ** 2 +\ - (1 - 2 * upscaled_permutated_target[bsz//2:bsz]) * mix2_sources ** 2 +\ - (1 - 2 * upscaled_permutated_target[bsz:]) * mom_sources ** 2 + upscaled_permutated_target = torch.nn.functional.interpolate( + permutated_target.transpose(1, 2), size=(80000) + ).transpose(1, 2) + forced_alignment_loss = ( + (1 - 2 * upscaled_permutated_target[: bsz // 2]) * mix1_sources**2 + + (1 - 2 * upscaled_permutated_target[bsz // 2 : bsz]) * mix2_sources**2 + + (1 - 2 * upscaled_permutated_target[bsz:]) * mom_sources**2 + ) forced_alignment_loss = forced_alignment_loss.mean() / 3 self.model.log( @@ -793,7 +862,11 @@ def training_step(self, batch, batch_idx: int): logger=True, ) - loss = (1 - self.mixit_loss_weight) * seg_loss + self.mixit_loss_weight * mixit_loss + forced_alignment_loss * self.forced_alignment_weight + loss = ( + (1 - self.mixit_loss_weight) * seg_loss + + self.mixit_loss_weight * mixit_loss + + forced_alignment_loss * self.forced_alignment_weight + ) # skip batch if something went wrong for some reason if torch.isnan(loss): @@ -899,7 +972,9 @@ def validation_step(self, batch, batch_idx: int): logger=True, ) - loss = (1 - self.mixit_loss_weight) * seg_loss + self.mixit_loss_weight * 
mixit_loss + loss = ( + 1 - self.mixit_loss_weight + ) * seg_loss + self.mixit_loss_weight * mixit_loss self.model.log( "loss/val", @@ -911,14 +986,9 @@ def validation_step(self, batch, batch_idx: int): ) self.model.validation_metric( - torch.transpose( - multilabel, 1, 2 - ), - torch.transpose( - target, 1, 2 - ), + torch.transpose(multilabel, 1, 2), + torch.transpose(target, 1, 2), ) - self.model.log_dict( self.model.validation_metric, From 1e22a128b3179cb0f8c46b59511ae558cfca38b5 Mon Sep 17 00:00:00 2001 From: joonaskalda Date: Wed, 21 Jun 2023 15:45:17 +0200 Subject: [PATCH 38/55] fix for last batch in validation having size 1 --- .../tasks/segmentation/speaker_separation_diarization.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pyannote/audio/tasks/segmentation/speaker_separation_diarization.py b/pyannote/audio/tasks/segmentation/speaker_separation_diarization.py index 7a11d5d09..7c78772c2 100644 --- a/pyannote/audio/tasks/segmentation/speaker_separation_diarization.py +++ b/pyannote/audio/tasks/segmentation/speaker_separation_diarization.py @@ -919,6 +919,10 @@ def validation_step(self, batch, batch_idx: int): # target = target[keep] bsz = waveform.shape[0] + # MoMs can't be created for batch size < 2 + if bsz < 2: + return None + mix1 = waveform[bsz // 2 : 2 * (bsz // 2)].squeeze(1) mix2 = waveform[: bsz // 2].squeeze(1) moms = mix1 + mix2 From d8bb87af5583fa478fd5bb9b684da479c5c58f47 Mon Sep 17 00:00:00 2001 From: joonaskalda Date: Wed, 21 Jun 2023 15:56:52 +0200 Subject: [PATCH 39/55] adding documentation --- .../speaker_separation_diarization.py | 34 +++++++++++++++---- 1 file changed, 28 insertions(+), 6 deletions(-) diff --git a/pyannote/audio/tasks/segmentation/speaker_separation_diarization.py b/pyannote/audio/tasks/segmentation/speaker_separation_diarization.py index 7c78772c2..f63b58cd5 100644 --- a/pyannote/audio/tasks/segmentation/speaker_separation_diarization.py +++ b/pyannote/audio/tasks/segmentation/speaker_separation_diarization.py @@ -699,13 +699,34 @@ def segmentation_loss( def create_mixtures_of_mixtures(self, mix1, mix2, target1, target2): """ - Creates mixtures of mixtures from two mixtures and their targets.""" - # mapping1[i1] = i means that i1-th speaker of mix1/tgt1 has been mapped to i-th speaker of mom/tgt - # mapping2[i2] = i means that i2-th speaker of mix2/tgt2 has been mapped to i-th speaker of mom/tgt - # mapping1[i1] = None means that i1-th (inactive) speaker of mix1/tgt1 does not exist in mom/tgt + Creates mixtures of mixtures and corresponding diarization targets. + Keeps track of how many speakers came from each mixture in order to + reconstruct the original mixtures. + + Parameters + ---------- + mix1 : torch.Tensor + First mixture. + mix2 : torch.Tensor + Second mixture. + target1 : torch.Tensor + First mixture diarization targets. + target2 : torch.Tensor + Second mixture diarization targets. + + Returns + ------- + mom : torch.Tensor + Mixtures of mixtures. + targets : torch.Tensor + Diarization targets for mixtures of mixtures. + num_active_speakers_mix1 : torch.Tensor + Number of active speakers in the first mixture. + num_active_speakers_mix2 : torch.Tensor + Number of active speakers in the second mixture. 
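+
+        Example
+        -------
+        A minimal illustrative sketch (toy shapes, no batch dimension) of how
+        the per-chunk targets are combined: active speaker columns from both
+        chunks are stacked, then zero-padded back to the original number of
+        speaker columns, while the mixture of mixtures is simply the sum of
+        the two waveforms.
+
+        >>> import torch
+        >>> t1 = torch.zeros(100, 3); t1[:, 0] = 1  # one active speaker
+        >>> t2 = torch.zeros(100, 3); t2[:, 1] = 1  # one active speaker
+        >>> active = torch.cat((t1[:, t1.sum(0) != 0], t2[:, t2.sum(0) != 0]), dim=1)
+        >>> torch.cat((active, torch.zeros(100, 1)), dim=1).shape  # pad back to 3
+        torch.Size([100, 3])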
+ """ batch_size = mix1.shape[0] mom = mix1 + mix2 - # (batch_size, num_speakers, num_frames) num_active_speakers_mix1 = (target1.sum(dim=1) != 0).sum(dim=1) num_active_speakers_mix2 = (target2.sum(dim=1) != 0).sum(dim=1) targets = [] @@ -922,7 +943,8 @@ def validation_step(self, batch, batch_idx: int): # MoMs can't be created for batch size < 2 if bsz < 2: return None - + + # if bsz not even, then leave out last sample mix1 = waveform[bsz // 2 : 2 * (bsz // 2)].squeeze(1) mix2 = waveform[: bsz // 2].squeeze(1) moms = mix1 + mix2 From db1ae59ea6ad26aa7a57a2f197aee505d47308c3 Mon Sep 17 00:00:00 2001 From: joonaskalda Date: Mon, 26 Jun 2023 11:32:09 +0200 Subject: [PATCH 40/55] diarization on sources separately and back to multilabel --- .../audio/models/segmentation/SepDiarNet.py | 18 ++- .../speaker_separation_diarization.py | 131 ++++++++++++------ 2 files changed, 98 insertions(+), 51 deletions(-) diff --git a/pyannote/audio/models/segmentation/SepDiarNet.py b/pyannote/audio/models/segmentation/SepDiarNet.py index 08efd869e..22684aae0 100644 --- a/pyannote/audio/models/segmentation/SepDiarNet.py +++ b/pyannote/audio/models/segmentation/SepDiarNet.py @@ -123,6 +123,7 @@ def __init__( convnet = merge_dict(self.CONVNET_DEFAULTS, convnet) dprnn = merge_dict(self.DPRNN_DEFAULTS, dprnn) encoder_decoder = merge_dict(self.ENCODER_DECODER_DEFAULTS, encoder_decoder) + self.n_src = n_sources self.save_hyperparameters("encoder_decoder", "lstm", "linear", "convnet", "dprnn") if encoder_decoder["fb_name"] == "free": @@ -141,7 +142,7 @@ def __init__( if monolithic: multi_layer_lstm = dict(lstm) del multi_layer_lstm["monolithic"] - self.lstm = nn.LSTM(n_sources * n_feats_out, **multi_layer_lstm) + self.lstm = nn.LSTM(n_feats_out, **multi_layer_lstm) else: num_layers = lstm["num_layers"] @@ -156,7 +157,7 @@ def __init__( self.lstm = nn.ModuleList( [ nn.LSTM( - 6 * n_feats_out + n_feats_out if i == 0 else lstm["hidden_size"] * (2 if lstm["bidirectional"] else 1), **one_layer_lstm @@ -196,7 +197,7 @@ def build(self): # raise ValueError("PyanNet does not support multi-tasking.") # if self.specifications.powerset: - out_features = self.specifications[0].num_powerset_classes + out_features = 1 # else: # out_features = len(self.specifications.classes) @@ -214,7 +215,7 @@ def forward(self, waveforms: torch.Tensor) -> torch.Tensor: ------- scores : (batch, frame, classes) """ - + bsz = waveforms.shape[0] tf_rep = self.encoder(waveforms) masks = self.masker(tf_rep) @@ -224,9 +225,9 @@ def forward(self, waveforms: torch.Tensor) -> torch.Tensor: decoded_sources = decoded_sources.transpose(1, 2) outputs = rearrange( - masks, "batch nsrc nfilters nframes -> batch nframes nfilters nsrc" + masks, "batch nsrc nfilters nframes -> batch nsrc nframes nfilters" ) - outputs = torch.flatten(outputs, start_dim=2, end_dim=3) + outputs = torch.flatten(outputs, start_dim=0, end_dim=1) if self.hparams.lstm["monolithic"]: outputs, _ = self.lstm(outputs) @@ -239,5 +240,8 @@ def forward(self, waveforms: torch.Tensor) -> torch.Tensor: if self.hparams.linear["num_layers"] > 0: for linear in self.linear: outputs = F.leaky_relu(linear(outputs)) + outputs = self.classifier(outputs) + outputs = outputs.reshape(bsz, 3, -1) + outputs = outputs.transpose(1, 2) - return self.activation[0](self.classifier(outputs)), decoded_sources + return self.activation[0](outputs), decoded_sources diff --git a/pyannote/audio/tasks/segmentation/speaker_separation_diarization.py b/pyannote/audio/tasks/segmentation/speaker_separation_diarization.py index 
f63b58cd5..bda51ee1d 100644 --- a/pyannote/audio/tasks/segmentation/speaker_separation_diarization.py +++ b/pyannote/audio/tasks/segmentation/speaker_separation_diarization.py @@ -275,7 +275,9 @@ def setup(self): speaker_diarization = Specifications( duration=self.duration, resolution=Resolution.FRAME, - problem=Problem.MONO_LABEL_CLASSIFICATION, + problem=Problem.MULTI_LABEL_CLASSIFICATION + if self.max_speakers_per_frame is None + else Problem.MONO_LABEL_CLASSIFICATION, permutation_invariant=True, classes=[f"speaker#{i+1}" for i in range(self.max_speakers_per_chunk)], powerset_max_classes=self.max_speakers_per_frame, @@ -291,10 +293,11 @@ def setup(self): self.specifications = (speaker_diarization, speaker_separation) def setup_loss_func(self): - self.model.powerset = Powerset( - len(self.specifications[0].classes), - self.specifications[0].powerset_max_classes, - ) + if self.specifications[0].powerset: + self.model.powerset = Powerset( + len(self.specifications.classes), + self.specifications.powerset_max_classes, + ) def prepare_chunk(self, file_id: int, start_time: float, duration: float): """Prepare chunk @@ -683,17 +686,23 @@ def segmentation_loss( """ # `clamp_min` is needed to set non-speech weight to 1. - class_weight = ( - torch.clamp_min(self.model.powerset.cardinality, 1.0) - if self.weigh_by_cardinality - else None - ) - seg_loss = nll_loss( - permutated_prediction, - torch.argmax(target, dim=-1), - class_weight=class_weight, - weight=weight, - ) + if self.specifications[0].powerset: + # `clamp_min` is needed to set non-speech weight to 1. + class_weight = ( + torch.clamp_min(self.model.powerset.cardinality, 1.0) + if self.weigh_by_cardinality + else None + ) + seg_loss = nll_loss( + permutated_prediction, + torch.argmax(target, dim=-1), + class_weight=class_weight, + weight=weight, + ) + else: + seg_loss = binary_cross_entropy( + permutated_prediction, target.float(), weight=weight + ) return seg_loss @@ -814,14 +823,21 @@ def training_step(self, batch, batch_idx: int): ) # (batch_size, num_frames, 1) - multilabel = self.model.powerset.to_multilabel(diarization) - permutated_target, permutations = permutate(multilabel, target) - permutated_target_powerset = self.model.powerset.to_powerset( - permutated_target.float() - ) - seg_loss = self.segmentation_loss( - diarization, permutated_target_powerset, weight=weight - ) + if self.specifications[0].powerset: + multilabel = self.model.powerset.to_multilabel(diarization) + permutated_target, permutations = permutate(multilabel, target) + permutated_target_powerset = self.model.powerset.to_powerset( + permutated_target.float() + ) + seg_loss = self.segmentation_loss( + diarization, permutated_target_powerset, weight=weight + ) + + else: + permutated_target, permutations = permutate(target, diarization) + seg_loss = self.segmentation_loss( + permutated_target, target, weight=weight + ) # to find which predicted sources correspond to which mixtures, we need to invert the permutations permutations_inverse = torch.argsort(torch.tensor(permutations)) @@ -909,11 +925,20 @@ def default_metric( ) -> Union[Metric, Sequence[Metric], Dict[str, Metric]]: """Returns diarization error rate and its components""" + if self.specifications[0].powerset: + return { + "DiarizationErrorRate": DiarizationErrorRate(0.5), + "DiarizationErrorRate/Confusion": SpeakerConfusionRate(0.5), + "DiarizationErrorRate/Miss": MissedDetectionRate(0.5), + "DiarizationErrorRate/FalseAlarm": FalseAlarmRate(0.5), + } + return { - "DiarizationErrorRate": 
DiarizationErrorRate(0.5), - "DiarizationErrorRate/Confusion": SpeakerConfusionRate(0.5), - "DiarizationErrorRate/Miss": MissedDetectionRate(0.5), - "DiarizationErrorRate/FalseAlarm": FalseAlarmRate(0.5), + "DiarizationErrorRate": OptimalDiarizationErrorRate(), + "DiarizationErrorRate/Threshold": OptimalDiarizationErrorRateThreshold(), + "DiarizationErrorRate/Confusion": OptimalSpeakerConfusionRate(), + "DiarizationErrorRate/Miss": OptimalMissedDetectionRate(), + "DiarizationErrorRate/FalseAlarm": OptimalFalseAlarmRate(), } # TODO: no need to compute gradient in this method @@ -962,17 +987,24 @@ def validation_step(self, batch, batch_idx: int): ) # (batch_size, num_frames, 1) - multilabel = self.model.powerset.to_multilabel(prediction) - permutated_target, _ = permutate(multilabel, target) + if self.specifications[0].powerset: + multilabel = self.model.powerset.to_multilabel(prediction) + permutated_target, _ = permutate(multilabel, target) - # FIXME: handle case where target have too many speakers? - # since we don't need - permutated_target_powerset = self.model.powerset.to_powerset( - permutated_target.float() - ) - seg_loss = self.segmentation_loss( - prediction, permutated_target_powerset, weight=weight - ) + # FIXME: handle case where target have too many speakers? + # since we don't need + permutated_target_powerset = self.model.powerset.to_powerset( + permutated_target.float() + ) + seg_loss = self.segmentation_loss( + prediction, permutated_target_powerset, weight=weight + ) + + else: + permutated_prediction, _ = permutate(target, prediction) + seg_loss = self.segmentation_loss( + permutated_prediction, target, weight=weight + ) # forced alignment mixit can't be implemented for validation because since data loading is different mixit_loss = 0 @@ -1011,10 +1043,17 @@ def validation_step(self, batch, batch_idx: int): logger=True, ) - self.model.validation_metric( - torch.transpose(multilabel, 1, 2), - torch.transpose(target, 1, 2), - ) + if self.specifications[0].powerset: + self.model.validation_metric( + torch.transpose(multilabel, 1, 2), + torch.transpose(target, 1, 2), + ) + else: + self.model.validation_metric( + torch.transpose(prediction, 1, 2), + torch.transpose(target, 1, 2), + ) + self.model.log_dict( self.model.validation_metric, @@ -1034,8 +1073,12 @@ def validation_step(self, batch, batch_idx: int): # visualize first 9 validation samples of first batch in Tensorboard/MLflow - y = permutated_target.float().cpu().numpy() - y_pred = multilabel.cpu().numpy() + if self.specifications[0].powerset: + y = permutated_target.float().cpu().numpy() + y_pred = multilabel.cpu().numpy() + else: + y = target.float().cpu().numpy() + y_pred = permutated_prediction.cpu().numpy() # prepare 3 x 3 grid (or smaller if batch size is smaller) num_samples = min(self.batch_size, 9) From 54229df0f441db074a6eb49c0dcca0416d836842 Mon Sep 17 00:00:00 2001 From: joonaskalda Date: Mon, 26 Jun 2023 21:46:29 +0200 Subject: [PATCH 41/55] make lstm use optional --- .../audio/models/segmentation/SepDiarNet.py | 86 ++++++++++--------- 1 file changed, 47 insertions(+), 39 deletions(-) diff --git a/pyannote/audio/models/segmentation/SepDiarNet.py b/pyannote/audio/models/segmentation/SepDiarNet.py index 22684aae0..4261e6712 100644 --- a/pyannote/audio/models/segmentation/SepDiarNet.py +++ b/pyannote/audio/models/segmentation/SepDiarNet.py @@ -79,7 +79,7 @@ class SepDiarNet(Model): "monolithic": True, "dropout": 0.0, } - LINEAR_DEFAULTS = {"hidden_size": 128, "num_layers": 2} + LINEAR_DEFAULTS = {"hidden_size": 
64, "num_layers": 2} CONVNET_DEFAULTS = { "n_blocks": 8, "n_repeats": 3, @@ -114,6 +114,7 @@ def __init__( task: Optional[Task] = None, encoder_type: str = None, n_sources: int = 3, + use_lstm: bool = False, ): super().__init__(sample_rate=sample_rate, num_channels=num_channels, task=task) @@ -124,6 +125,7 @@ def __init__( dprnn = merge_dict(self.DPRNN_DEFAULTS, dprnn) encoder_decoder = merge_dict(self.ENCODER_DECODER_DEFAULTS, encoder_decoder) self.n_src = n_sources + self.use_lstm = use_lstm self.save_hyperparameters("encoder_decoder", "lstm", "linear", "convnet", "dprnn") if encoder_decoder["fb_name"] == "free": @@ -138,40 +140,45 @@ def __init__( self.masker = DPRNN(n_feats_out, n_src=n_sources, **self.hparams.dprnn) #self.convnet= TDConvNet(n_feats_out, **self.hparams.convnet) - monolithic = lstm["monolithic"] - if monolithic: - multi_layer_lstm = dict(lstm) - del multi_layer_lstm["monolithic"] - self.lstm = nn.LSTM(n_feats_out, **multi_layer_lstm) + if use_lstm: + monolithic = lstm["monolithic"] + if monolithic: + multi_layer_lstm = dict(lstm) + del multi_layer_lstm["monolithic"] + self.lstm = nn.LSTM(n_feats_out, **multi_layer_lstm) - else: - num_layers = lstm["num_layers"] - if num_layers > 1: - self.dropout = nn.Dropout(p=lstm["dropout"]) - - one_layer_lstm = dict(lstm) - one_layer_lstm["num_layers"] = 1 - one_layer_lstm["dropout"] = 0.0 - del one_layer_lstm["monolithic"] - - self.lstm = nn.ModuleList( - [ - nn.LSTM( - n_feats_out - if i == 0 - else lstm["hidden_size"] * (2 if lstm["bidirectional"] else 1), - **one_layer_lstm - ) - for i in range(num_layers) - ] - ) + else: + num_layers = lstm["num_layers"] + if num_layers > 1: + self.dropout = nn.Dropout(p=lstm["dropout"]) + + one_layer_lstm = dict(lstm) + one_layer_lstm["num_layers"] = 1 + one_layer_lstm["dropout"] = 0.0 + del one_layer_lstm["monolithic"] + + self.lstm = nn.ModuleList( + [ + nn.LSTM( + n_feats_out + if i == 0 + else lstm["hidden_size"] * (2 if lstm["bidirectional"] else 1), + **one_layer_lstm + ) + for i in range(num_layers) + ] + ) if linear["num_layers"] < 1: return - - lstm_out_features: int = self.hparams.lstm["hidden_size"] * ( - 2 if self.hparams.lstm["bidirectional"] else 1 - ) + + if use_lstm: + lstm_out_features: int = self.hparams.lstm["hidden_size"] * ( + 2 if self.hparams.lstm["bidirectional"] else 1 + ) + else: + lstm_out_features = 64 + self.linear = nn.ModuleList( [ nn.Linear(in_features, out_features) @@ -228,14 +235,15 @@ def forward(self, waveforms: torch.Tensor) -> torch.Tensor: masks, "batch nsrc nfilters nframes -> batch nsrc nframes nfilters" ) outputs = torch.flatten(outputs, start_dim=0, end_dim=1) - - if self.hparams.lstm["monolithic"]: - outputs, _ = self.lstm(outputs) - else: - for i, lstm in enumerate(self.lstm): - outputs, _ = lstm(outputs) - if i + 1 < self.hparams.lstm["num_layers"]: - outputs = self.dropout(outputs) + + if self.use_lstm: + if self.hparams.lstm["monolithic"]: + outputs, _ = self.lstm(outputs) + else: + for i, lstm in enumerate(self.lstm): + outputs, _ = lstm(outputs) + if i + 1 < self.hparams.lstm["num_layers"]: + outputs = self.dropout(outputs) if self.hparams.linear["num_layers"] > 0: for linear in self.linear: From c05a5297e462a83af92d2a247625fd1d4f62d498 Mon Sep 17 00:00:00 2001 From: joonaskalda Date: Mon, 26 Jun 2023 22:02:54 +0200 Subject: [PATCH 42/55] make alignment forcing optional --- .../speaker_separation_diarization.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git 
a/pyannote/audio/tasks/segmentation/speaker_separation_diarization.py b/pyannote/audio/tasks/segmentation/speaker_separation_diarization.py index bda51ee1d..8b7e453a2 100644 --- a/pyannote/audio/tasks/segmentation/speaker_separation_diarization.py +++ b/pyannote/audio/tasks/segmentation/speaker_separation_diarization.py @@ -54,7 +54,7 @@ from pyannote.audio.utils.loss import binary_cross_entropy, mse_loss, nll_loss from pyannote.audio.utils.permutation import permutate from pyannote.audio.utils.powerset import Powerset -from asteroid.losses import multisrc_neg_sisdr +from asteroid.losses import MixITLossWrapper, multisrc_neg_sisdr from torch.utils.data._utils.collate import default_collate Subsets = list(Subset.__args__) @@ -142,6 +142,7 @@ def __init__( mixit_loss_weight: float = 0.5, original_mixtures_for_separation: bool = False, forced_alignment_weight: float = 0.0, + force_alignment = False, ): super().__init__( protocol, @@ -184,7 +185,10 @@ def __init__( self.weigh_by_cardinality = weigh_by_cardinality self.balance = balance self.weight = weight + if not force_alignment: + self.separation_loss = MixITLossWrapper(multisrc_neg_sisdr, generalized=True) self.mixit_loss_weight = mixit_loss_weight + self.force_alignment = force_alignment self.original_mixtures_for_separation = original_mixtures_for_separation self.forced_alignment_weight = forced_alignment_weight @@ -859,10 +863,12 @@ def training_step(self, batch, batch_idx: int): est_mix2 = mom_sources[i, :, speaker_idx_mix2[i]].sum(1) est_mixes.append(torch.stack((est_mix1, est_mix2))) est_mixes = torch.stack(est_mixes) - mixit_loss = multisrc_neg_sisdr( - est_mixes, torch.stack((mix1, mix2)).transpose(0, 1) - ).mean() - + if self.force_alignment: + mixit_loss = multisrc_neg_sisdr( + est_mixes, torch.stack((mix1, mix2)).transpose(0, 1) + ).mean() + else: + mixit_loss = self.separation_loss(mom_sources.transpose(1, 2), torch.stack((mix1, mix2)).transpose(0, 1)) if self.original_mixtures_for_separation: raise NotImplementedError # mixit_loss += self.separation_loss( From cef2dbd0b56908f873000e9a8f396cd629daeeaa Mon Sep 17 00:00:00 2001 From: joonaskalda Date: Tue, 27 Jun 2023 13:26:21 +0200 Subject: [PATCH 43/55] bug fix --- .../audio/tasks/segmentation/speaker_separation_diarization.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyannote/audio/tasks/segmentation/speaker_separation_diarization.py b/pyannote/audio/tasks/segmentation/speaker_separation_diarization.py index 8b7e453a2..7845fd5f8 100644 --- a/pyannote/audio/tasks/segmentation/speaker_separation_diarization.py +++ b/pyannote/audio/tasks/segmentation/speaker_separation_diarization.py @@ -852,7 +852,7 @@ def training_step(self, batch, batch_idx: int): speaker_idx_mix2 = [ [ permutations_inverse[i][j] - for j in range(num_active_speakers_mix1[i], num_active_speakers_mix2[i]) + for j in range(num_active_speakers_mix1[i], num_active_speakers_mix1[i] + num_active_speakers_mix2[i]) ] for i in range(bsz // 2) ] From 6b8f8a05e132d898a93ca150a9e979b65081f587 Mon Sep 17 00:00:00 2001 From: joonaskalda Date: Sun, 2 Jul 2023 13:47:36 +0100 Subject: [PATCH 44/55] rename mixit_loss to separation_loss for clarity --- .../speaker_separation_diarization.py | 28 +++++++++---------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/pyannote/audio/tasks/segmentation/speaker_separation_diarization.py b/pyannote/audio/tasks/segmentation/speaker_separation_diarization.py index 7845fd5f8..54348dee9 100644 --- 
a/pyannote/audio/tasks/segmentation/speaker_separation_diarization.py +++ b/pyannote/audio/tasks/segmentation/speaker_separation_diarization.py @@ -107,7 +107,7 @@ class JointSpeakerSeparationAndDiarization(SegmentationTaskMixin, Task): metric : optional Validation metric(s). Can be anything supported by torchmetrics.MetricCollection. Defaults to AUROC (area under the ROC curve). - mixit_loss_weight : float, optional + separation_loss_weight : float, optional Factor that speaker separation loss is scaled by when calculating total loss. References @@ -139,7 +139,7 @@ def __init__( metric: Union[Metric, Sequence[Metric], Dict[str, Metric]] = None, max_num_speakers: int = None, # deprecated in favor of `max_speakers_per_chunk`` loss: Literal["bce", "mse"] = None, # deprecated - mixit_loss_weight: float = 0.5, + separation_loss_weight: float = 0.5, original_mixtures_for_separation: bool = False, forced_alignment_weight: float = 0.0, force_alignment = False, @@ -187,7 +187,7 @@ def __init__( self.weight = weight if not force_alignment: self.separation_loss = MixITLossWrapper(multisrc_neg_sisdr, generalized=True) - self.mixit_loss_weight = mixit_loss_weight + self.separation_loss_weight = separation_loss_weight self.force_alignment = force_alignment self.original_mixtures_for_separation = original_mixtures_for_separation self.forced_alignment_weight = forced_alignment_weight @@ -864,14 +864,14 @@ def training_step(self, batch, batch_idx: int): est_mixes.append(torch.stack((est_mix1, est_mix2))) est_mixes = torch.stack(est_mixes) if self.force_alignment: - mixit_loss = multisrc_neg_sisdr( + separation_loss = multisrc_neg_sisdr( est_mixes, torch.stack((mix1, mix2)).transpose(0, 1) ).mean() else: - mixit_loss = self.separation_loss(mom_sources.transpose(1, 2), torch.stack((mix1, mix2)).transpose(0, 1)) + separation_loss = self.separation_loss(mom_sources.transpose(1, 2), torch.stack((mix1, mix2)).transpose(0, 1)) if self.original_mixtures_for_separation: raise NotImplementedError - # mixit_loss += self.separation_loss( + # separation_loss += self.separation_loss( # predicted_sources_mix1.transpose(1, 2), torch.stack((mix1_masked, torch.zeros_like(mix1))).transpose(0, 1), speaker_idx_mix1[0::3], speaker_idx_mix2[0::3] # ) * mix1_masks.sum() / num_samples / bsz * 3 + self.separation_loss( # predicted_sources_mix2.transpose(1, 2), torch.stack((mix2_masked, torch.zeros_like(mix2))).transpose(0, 1), speaker_idx_mix1[1::3], speaker_idx_mix2[1::3] @@ -889,7 +889,7 @@ def training_step(self, batch, batch_idx: int): self.model.log( "loss/train/separation", - mixit_loss, + separation_loss, on_step=False, on_epoch=True, prog_bar=False, @@ -906,8 +906,8 @@ def training_step(self, batch, batch_idx: int): ) loss = ( - (1 - self.mixit_loss_weight) * seg_loss - + self.mixit_loss_weight * mixit_loss + (1 - self.separation_loss_weight) * seg_loss + + self.separation_loss_weight * separation_loss + forced_alignment_loss * self.forced_alignment_weight ) @@ -1013,14 +1013,14 @@ def validation_step(self, batch, batch_idx: int): ) # forced alignment mixit can't be implemented for validation because since data loading is different - mixit_loss = 0 - # mixit_loss = self.separation_loss( + separation_loss = 0 + # separation_loss = self.separation_loss( # mom_sources.transpose(1, 2), torch.stack((mix1, mix2)).transpose(0, 1) # ) self.model.log( "loss/val/separation", - mixit_loss, + separation_loss, on_step=False, on_epoch=True, prog_bar=False, @@ -1037,8 +1037,8 @@ def validation_step(self, batch, batch_idx: int): ) loss = 
( - 1 - self.mixit_loss_weight - ) * seg_loss + self.mixit_loss_weight * mixit_loss + 1 - self.separation_loss_weight + ) * seg_loss + self.separation_loss_weight * separation_loss self.model.log( "loss/val", From ba66fb235c0b6245a03170523b937462510e3e84 Mon Sep 17 00:00:00 2001 From: joonaskalda Date: Tue, 25 Jul 2023 22:58:42 +0200 Subject: [PATCH 45/55] add 2 sources for noise and alignement accuracy measure --- .../audio/models/segmentation/SepDiarNet.py | 10 +- .../speaker_separation_diarization.py | 377 ++++++++++++++++-- 2 files changed, 353 insertions(+), 34 deletions(-) diff --git a/pyannote/audio/models/segmentation/SepDiarNet.py b/pyannote/audio/models/segmentation/SepDiarNet.py index 4261e6712..99e034023 100644 --- a/pyannote/audio/models/segmentation/SepDiarNet.py +++ b/pyannote/audio/models/segmentation/SepDiarNet.py @@ -113,8 +113,9 @@ def __init__( num_channels: int = 1, task: Optional[Task] = None, encoder_type: str = None, - n_sources: int = 3, + n_sources: int = 5, use_lstm: bool = False, + lr: float = 1e-3, ): super().__init__(sample_rate=sample_rate, num_channels=num_channels, task=task) @@ -127,6 +128,8 @@ def __init__( self.n_src = n_sources self.use_lstm = use_lstm self.save_hyperparameters("encoder_decoder", "lstm", "linear", "convnet", "dprnn") + self.learning_rate = lr + self.n_sources = n_sources if encoder_decoder["fb_name"] == "free": n_feats_out = encoder_decoder["n_filters"] @@ -211,6 +214,9 @@ def build(self): self.classifier = nn.Linear(in_features, out_features) self.activation = self.default_activation() + def configure_optimizers(self): + return torch.optim.Adam(self.parameters(), lr=self.learning_rate) + def forward(self, waveforms: torch.Tensor) -> torch.Tensor: """Pass forward @@ -249,7 +255,7 @@ def forward(self, waveforms: torch.Tensor) -> torch.Tensor: for linear in self.linear: outputs = F.leaky_relu(linear(outputs)) outputs = self.classifier(outputs) - outputs = outputs.reshape(bsz, 3, -1) + outputs = outputs.reshape(bsz, self.n_sources, -1) outputs = outputs.transpose(1, 2) return self.activation[0](outputs), decoded_sources diff --git a/pyannote/audio/tasks/segmentation/speaker_separation_diarization.py b/pyannote/audio/tasks/segmentation/speaker_separation_diarization.py index 54348dee9..1ca945703 100644 --- a/pyannote/audio/tasks/segmentation/speaker_separation_diarization.py +++ b/pyannote/audio/tasks/segmentation/speaker_separation_diarization.py @@ -25,6 +25,7 @@ import random from collections import Counter from typing import Dict, Literal, Sequence, Text, Tuple, Union +import lightning.pytorch as pl import numpy as np import torch @@ -62,7 +63,257 @@ from itertools import combinations from torch import nn +from pytorch_lightning.callbacks import Callback + +class CountingCallback(Callback): + def on_train_epoch_start(self, trainer, pl_module) -> None: + "reset counters" + if pl_module.task.log_alignment_accuracy and pl_module.task.force_alignment: + pl_module.task.num_correct = 0 + pl_module.task.num_total = 0 + pl_module.task.num_correct30 = 0 + pl_module.task.num_correct21 = 0 + pl_module.task.num_correct20 = 0 + pl_module.task.num_correct11 = 0 + pl_module.task.num_correct10 = 0 + pl_module.task.num_total30 = 0 + pl_module.task.num_total21 = 0 + pl_module.task.num_total20 = 0 + pl_module.task.num_total11 = 0 + pl_module.task.num_total10 = 0 + +class CustomMixITLossWrapper(nn.Module): + r"""Custom mixture invariant loss wrapper that returns the best partition + so that it can be checked against the partition determined by forced + 
alignment. + + Args: + loss_func: function with signature (est_targets, targets, **kwargs). + generalized (bool): Determines how MixIT is applied. If False , + apply MixIT for any number of mixtures as soon as they contain + the same number of sources (:meth:`~MixITLossWrapper.best_part_mixit`.) + If True (default), apply MixIT for two mixtures, but those mixtures do not + necessarly have to contain the same number of sources. + See :meth:`~MixITLossWrapper.best_part_mixit_generalized`. + reduction (string, optional): Specifies the reduction to apply to + the output: + ``'none'`` | ``'mean'``. ``'none'``: no reduction will be applied, + ``'mean'``: the sum of the output will be divided by the number of + elements in the output. + + For each of these modes, the best partition and reordering will be + automatically computed. + + Examples: + >>> import torch + >>> from asteroid.losses import multisrc_mse + >>> mixtures = torch.randn(10, 2, 16000) + >>> est_sources = torch.randn(10, 4, 16000) + >>> # Compute MixIT loss based on pairwise losses + >>> loss_func = MixITLossWrapper(multisrc_mse) + >>> loss_val = loss_func(est_sources, mixtures) + References + [1] Scott Wisdom et al. "Unsupervised sound separation using + mixtures of mixtures." arXiv:2006.12701 (2020) + """ + + def __init__(self, loss_func, generalized=True, reduction="mean"): + super().__init__() + self.loss_func = loss_func + self.generalized = generalized + self.reduction = reduction + + def forward(self, est_targets, targets, return_est=False, **kwargs): + r"""Find the best partition and return the loss. + + Args: + est_targets: torch.Tensor. Expected shape :math:`(batch, nsrc, *)`. + The batch of target estimates. + targets: torch.Tensor. Expected shape :math:`(batch, nmix, ...)`. + The batch of training targets + return_est: Boolean. Whether to return the estimated mixtures + estimates (To compute metrics or to save example). + **kwargs: additional keyword argument that will be passed to the + loss function. + + Returns: + - Best partition loss for each batch sample, average over + the batch. torch.Tensor(loss_value) + - The estimated mixtures (estimated sources summed according to the partition) + if return_est is True. torch.Tensor of shape :math:`(batch, nmix, ...)`. + """ + # Check input dimensions + assert est_targets.shape[0] == targets.shape[0] + assert est_targets.shape[2] == targets.shape[2] + + if not self.generalized: + min_loss, min_loss_idx, parts = self.best_part_mixit( + self.loss_func, est_targets, targets, **kwargs + ) + else: + min_loss, min_loss_idx, parts = self.best_part_mixit_generalized( + self.loss_func, est_targets, targets, **kwargs + ) + + # Apply any reductions over the batch axis + returned_loss = min_loss.mean() if self.reduction == "mean" else min_loss + if not return_est: + return returned_loss, [parts[i] for i in min_loss_idx] + + # Order and sum on the best partition to get the estimated mixtures + reordered = self.reorder_source(est_targets, targets, min_loss_idx, parts) + return returned_loss, reordered + + @staticmethod + def best_part_mixit(loss_func, est_targets, targets, **kwargs): + r"""Find best partition of the estimated sources that gives the minimum + loss for the MixIT training paradigm in [1]. Valid for any number of + mixtures as soon as they contain the same number of sources. + + Args: + loss_func: function with signature ``(est_targets, targets, **kwargs)`` + The loss function to get batch losses from. + est_targets: torch.Tensor. Expected shape :math:`(batch, nsrc, ...)`. 
+ The batch of target estimates. + targets: torch.Tensor. Expected shape :math:`(batch, nmix, ...)`. + The batch of training targets (mixtures). + **kwargs: additional keyword argument that will be passed to the + loss function. + + Returns: + - :class:`torch.Tensor`: + The loss corresponding to the best permutation of size (batch,). + + - :class:`torch.LongTensor`: + The indices of the best partition. + + - :class:`list`: + list of the possible partitions of the sources. + + """ + nmix = targets.shape[1] + nsrc = est_targets.shape[1] + if nsrc % nmix != 0: + raise ValueError("The mixtures are assumed to contain the same number of sources") + nsrcmix = nsrc // nmix + + # Generate all unique partitions of size k from a list lst of + # length n, where l = n // k is the number of parts. The total + # number of such partitions is: NPK(n,k) = n! / ((k!)^l * l!) + # Algorithm recursively distributes items over parts + def parts_mixit(lst, k, l): + if l == 0: + yield [] + else: + for c in combinations(lst, k): + rest = [x for x in lst if x not in c] + for r in parts_mixit(rest, k, l - 1): + yield [list(c), *r] + + # Generate all the possible partitions + parts = list(parts_mixit(range(nsrc), nsrcmix, nmix)) + # Compute the loss corresponding to each partition + loss_set = CustomMixITLossWrapper.loss_set_from_parts( + loss_func, est_targets=est_targets, targets=targets, parts=parts, **kwargs + ) + # Indexes and values of min losses for each batch element + min_loss, min_loss_indexes = torch.min(loss_set, dim=1, keepdim=True) + return min_loss, min_loss_indexes, parts + + @staticmethod + def best_part_mixit_generalized(loss_func, est_targets, targets, **kwargs): + r"""Find best partition of the estimated sources that gives the minimum + loss for the MixIT training paradigm in [1]. Valid only for two mixtures, + but those mixtures do not necessarly have to contain the same number of + sources e.g the case where one mixture is silent is allowed.. + + Args: + loss_func: function with signature ``(est_targets, targets, **kwargs)`` + The loss function to get batch losses from. + est_targets: torch.Tensor. Expected shape :math:`(batch, nsrc, ...)`. + The batch of target estimates. + targets: torch.Tensor. Expected shape :math:`(batch, nmix, ...)`. + The batch of training targets (mixtures). + **kwargs: additional keyword argument that will be passed to the + loss function. + + Returns: + - :class:`torch.Tensor`: + The loss corresponding to the best permutation of size (batch,). + + - :class:`torch.LongTensor`: + The indexes of the best permutations. + + - :class:`list`: + list of the possible partitions of the sources. + """ + nmix = targets.shape[1] # number of mixtures + nsrc = est_targets.shape[1] # number of estimated sources + if nmix != 2: + raise ValueError("Works only with two mixtures") + + # Generate all unique partitions of any size from a list lst of + # length n. 
Algorithm recursively distributes items over parts + def parts_mixit_gen(lst): + partitions = [] + for k in range(len(lst) + 1): + for c in combinations(lst, k): + rest = [x for x in lst if x not in c] + partitions.append([list(c), rest]) + return partitions + + # Generate all the possible partitions + parts = parts_mixit_gen(range(nsrc)) + # Compute the loss corresponding to each partition + loss_set = CustomMixITLossWrapper.loss_set_from_parts( + loss_func, est_targets=est_targets, targets=targets, parts=parts, **kwargs + ) + # Indexes and values of min losses for each batch element + min_loss, min_loss_indexes = torch.min(loss_set, dim=1, keepdim=True) + return min_loss, min_loss_indexes, parts + + @staticmethod + def loss_set_from_parts(loss_func, est_targets, targets, parts, **kwargs): + """Common loop between both best_part_mixit""" + loss_set = [] + for partition in parts: + # sum the sources according to the given partition + est_mixes = torch.stack([est_targets[:, idx, :].sum(1) for idx in partition], dim=1) + # get loss for the given partition + loss_partition = loss_func(est_mixes, targets, **kwargs) + if loss_partition.ndim != 1: + raise ValueError("Loss function return value should be of size (batch,).") + loss_set.append(loss_partition[:, None]) + loss_set = torch.cat(loss_set, dim=1) + return loss_set + + @staticmethod + def reorder_source(est_targets, targets, min_loss_idx, parts): + """Reorder sources according to the best partition. + + Args: + est_targets: torch.Tensor. Expected shape :math:`(batch, nsrc, ...)`. + The batch of target estimates. + targets: torch.Tensor. Expected shape :math:`(batch, nmix, ...)`. + The batch of training targets. + min_loss_idx: torch.LongTensor. The indexes of the best permutations. + parts: list of the possible partitions of the sources. + + Returns: + :class:`torch.Tensor`: Reordered sources of shape :math:`(batch, nmix, time)`. 
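+
+        Example:
+            A minimal illustrative sketch of the per-item summation this
+            method performs, assuming four estimated sources and the best
+            partition ``[[0, 2], [1, 3]]``:
+
+            >>> import torch
+            >>> est = torch.randn(1, 4, 16000)
+            >>> est_mixes = torch.stack(
+            ...     [est[:, [0, 2]].sum(1), est[:, [1, 3]].sum(1)], dim=1
+            ... )
+            >>> est_mixes.shape
+            torch.Size([1, 2, 16000])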
+ + """ + # For each batch there is a different min_loss_idx + ordered = torch.zeros_like(targets) + for b, idx in enumerate(min_loss_idx): + right_partition = parts[idx] + # Sum the estimated sources to get the estimated mixtures + ordered[b, :, :] = torch.stack( + [est_targets[b, idx, :][None, :, :].sum(1) for idx in right_partition], dim=1 + ) + + return ordered class JointSpeakerSeparationAndDiarization(SegmentationTaskMixin, Task): """Speaker diarization @@ -142,7 +393,8 @@ def __init__( separation_loss_weight: float = 0.5, original_mixtures_for_separation: bool = False, forced_alignment_weight: float = 0.0, - force_alignment = False, + force_alignment: bool = False, + log_alignment_accuracy: bool = False, ): super().__init__( protocol, @@ -185,12 +437,12 @@ def __init__( self.weigh_by_cardinality = weigh_by_cardinality self.balance = balance self.weight = weight - if not force_alignment: - self.separation_loss = MixITLossWrapper(multisrc_neg_sisdr, generalized=True) + self.separation_loss = CustomMixITLossWrapper(multisrc_neg_sisdr, generalized=True) self.separation_loss_weight = separation_loss_weight self.force_alignment = force_alignment self.original_mixtures_for_separation = original_mixtures_for_separation self.forced_alignment_weight = forced_alignment_weight + self.log_alignment_accuracy = log_alignment_accuracy def setup(self): super().setup() @@ -838,37 +1090,54 @@ def training_step(self, batch, batch_idx: int): ) else: - permutated_target, permutations = permutate(target, diarization) + # last 2 sources should only contain noise so we force diarization outputs to 0 + permutated_target, permutations = permutate(target, diarization[:, :, :3]) + permutated_target = torch.cat((permutated_target, diarization[:, :, 3:]), dim=2) + target = torch.cat((target, torch.zeros(batch_size, num_frames, 2, device=target.device)), dim=2) + seg_loss = self.segmentation_loss( permutated_target, target, weight=weight ) - # to find which predicted sources correspond to which mixtures, we need to invert the permutations - permutations_inverse = torch.argsort(torch.tensor(permutations)) - speaker_idx_mix1 = [ - [permutations_inverse[i][j] for j in range(num_active_speakers_mix1[i])] - for i in range(bsz // 2) - ] - speaker_idx_mix2 = [ - [ - permutations_inverse[i][j] - for j in range(num_active_speakers_mix1[i], num_active_speakers_mix1[i] + num_active_speakers_mix2[i]) - ] - for i in range(bsz // 2) - ] - # contributions from original mixtures is weighed by the proportion of remaining frames - est_mixes = [] - for i in range(bsz // 2): - est_mix1 = mom_sources[i, :, speaker_idx_mix1[i]].sum(1) - est_mix2 = mom_sources[i, :, speaker_idx_mix2[i]].sum(1) - est_mixes.append(torch.stack((est_mix1, est_mix2))) - est_mixes = torch.stack(est_mixes) if self.force_alignment: + # to find which predicted sources correspond to which mixtures, we need to invert the permutations + permutations_inverse = torch.argsort(torch.tensor(permutations)) + speaker_idx_mix1 = [ + [permutations_inverse[i][j] for j in range(num_active_speakers_mix1[i])] + for i in range(bsz // 2) + ] + speaker_idx_mix2 = [ + [ + permutations_inverse[i][j] + for j in range(num_active_speakers_mix1[i], num_active_speakers_mix1[i] + num_active_speakers_mix2[i]) + ] + for i in range(bsz // 2) + ] + + est_mixes = [] + for i in range(bsz // 2): + est_mix1 = mom_sources[i, :, speaker_idx_mix1[i]].sum(1) + mom_sources[i,:,3] + est_mix2 = mom_sources[i, :, speaker_idx_mix2[i]].sum(1) + mom_sources[i,:,4] + est_mix3 = mom_sources[i, :, 
speaker_idx_mix1[i]].sum(1) + mom_sources[i,:,4] + est_mix4 = mom_sources[i, :, speaker_idx_mix2[i]].sum(1) + mom_sources[i,:,3] + sep_loss_first_part = multisrc_neg_sisdr( + torch.stack((est_mix1, est_mix2)).unsqueeze(0), torch.stack((mix1[i], mix2[i])).unsqueeze(0) + ) + sep_loss_second_part = multisrc_neg_sisdr( + torch.stack((est_mix3, est_mix4)).unsqueeze(0), torch.stack((mix1[i], mix2[i])).unsqueeze(0) + ) + if sep_loss_first_part < sep_loss_second_part: + est_mixes.append(torch.stack((est_mix1, est_mix2))) + else: + est_mixes.append(torch.stack((est_mix3, est_mix4))) + est_mixes = torch.stack(est_mixes) separation_loss = multisrc_neg_sisdr( est_mixes, torch.stack((mix1, mix2)).transpose(0, 1) ).mean() + _, mixit_partitions = self.separation_loss(mom_sources[:,:,:3].transpose(1, 2), torch.stack((mix1, mix2)).transpose(0, 1)) else: - separation_loss = self.separation_loss(mom_sources.transpose(1, 2), torch.stack((mix1, mix2)).transpose(0, 1)) + separation_loss, _ = self.separation_loss(mom_sources[:,:,:3].transpose(1, 2), torch.stack((mix1, mix2)).transpose(0, 1)) + if self.original_mixtures_for_separation: raise NotImplementedError # separation_loss += self.separation_loss( @@ -880,13 +1149,13 @@ def training_step(self, batch, batch_idx: int): upscaled_permutated_target = torch.nn.functional.interpolate( permutated_target.transpose(1, 2), size=(80000) ).transpose(1, 2) - forced_alignment_loss = ( - (1 - 2 * upscaled_permutated_target[: bsz // 2]) * mix1_sources**2 - + (1 - 2 * upscaled_permutated_target[bsz // 2 : bsz]) * mix2_sources**2 - + (1 - 2 * upscaled_permutated_target[bsz:]) * mom_sources**2 - ) - forced_alignment_loss = forced_alignment_loss.mean() / 3 - + # forced_alignment_loss = ( + # (1 - 2 * upscaled_permutated_target[: bsz // 2]) * mix1_sources**2 + # + (1 - 2 * upscaled_permutated_target[bsz // 2 : bsz]) * mix2_sources**2 + # + (1 - 2 * upscaled_permutated_target[bsz:]) * mom_sources**2 + # ) + # forced_alignment_loss = forced_alignment_loss.mean() / 3 + forced_alignment_loss = 0 self.model.log( "loss/train/separation", separation_loss, @@ -923,6 +1192,50 @@ def training_step(self, batch, batch_idx: int): prog_bar=False, logger=True, ) + if self.log_alignment_accuracy and self.force_alignment: + for i in range(bsz // 2): + inverse_mixit_partition = permutations_inverse[i][mixit_partitions[i][0]], permutations_inverse[i][mixit_partitions[i][1]] + if set([int(j) for j in speaker_idx_mix1[i]]) <= set(inverse_mixit_partition[0].tolist()) and set([int(j) for j in speaker_idx_mix2[i]]) <= set(inverse_mixit_partition[1].tolist()): + self.num_correct += 1 + if (num_active_speakers_mix1[i], num_active_speakers_mix2[i]) in [(1, 0), (0, 1)]: + self.num_correct10 += 1 + if (num_active_speakers_mix1[i], num_active_speakers_mix2[i]) in [(2, 0), (0, 2)]: + self.num_correct20 += 1 + if (num_active_speakers_mix1[i], num_active_speakers_mix2[i]) in [(3, 0), (0, 3)]: + self.num_correct30 += 1 + if num_active_speakers_mix1[i] == 1 and num_active_speakers_mix2[i] == 1: + self.num_correct11 += 1 + if (num_active_speakers_mix1[i], num_active_speakers_mix2[i]) in [(1, 2), (2, 1)]: + self.num_correct21 += 1 + if (num_active_speakers_mix1[i], num_active_speakers_mix2[i]) in [(1, 0), (0, 1)]: + self.num_total10 += 1 + if (num_active_speakers_mix1[i], num_active_speakers_mix2[i]) in [(2, 0), (0, 2)]: + self.num_total20 += 1 + if (num_active_speakers_mix1[i], num_active_speakers_mix2[i]) in [(3, 0), (0, 3)]: + self.num_total30 += 1 + if num_active_speakers_mix1[i] == 1 and 
num_active_speakers_mix2[i] == 1: + self.num_total11 += 1 + if (num_active_speakers_mix1[i], num_active_speakers_mix2[i]) in [(1, 2), (2, 1)]: + self.num_total21 += 1 + self.num_total+=1 + if self.num_total30 > 0: + self.model.log("accuracy/3_0", self.num_correct30/self.num_total30, on_step=False, on_epoch=True, prog_bar=False, logger=True) + if self.num_total20 > 0: + self.model.log("accuracy/2_0", self.num_correct20/self.num_total20, on_step=False, on_epoch=True, prog_bar=False, logger=True) + if self.num_total10 > 0: + self.model.log("accuracy/1_0", self.num_correct10/self.num_total10, on_step=False, on_epoch=True, prog_bar=False, logger=True) + if self.num_total11 > 0: + self.model.log("accuracy/1_1", self.num_correct11/self.num_total11, on_step=False, on_epoch=True, prog_bar=False, logger=True) + if self.num_total21 > 0: + self.model.log("accuracy/2_1", self.num_correct21/self.num_total21, on_step=False, on_epoch=True, prog_bar=False, logger=True) + if self.num_total > 0: + self.model.log("accuracy/total", self.num_correct/self.num_total, on_step=False, on_epoch=True, prog_bar=False, logger=True) + self.model.log("counts/3_0", self.num_total30, on_step=False, on_epoch=True, prog_bar=False, logger=True) + self.model.log("counts/2_0", self.num_total20, on_step=False, on_epoch=True, prog_bar=False, logger=True) + self.model.log("counts/1_0", self.num_total10, on_step=False, on_epoch=True, prog_bar=False, logger=True) + self.model.log("counts/1_1", self.num_total11, on_step=False, on_epoch=True, prog_bar=False, logger=True) + self.model.log("counts/2_1", self.num_total21, on_step=False, on_epoch=True, prog_bar=False, logger=True) + self.model.log("counts/total", self.num_total, on_step=False, on_epoch=True, prog_bar=False, logger=True) return {"loss": loss} From dc8f2f38931217c8455e53d301b6cd26dc44c36a Mon Sep 17 00:00:00 2001 From: joonaskalda Date: Tue, 25 Jul 2023 23:02:12 +0200 Subject: [PATCH 46/55] bug regarding specifications being a tuple --- .../tasks/segmentation/speaker_separation_diarization.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyannote/audio/tasks/segmentation/speaker_separation_diarization.py b/pyannote/audio/tasks/segmentation/speaker_separation_diarization.py index 1ca945703..69193286b 100644 --- a/pyannote/audio/tasks/segmentation/speaker_separation_diarization.py +++ b/pyannote/audio/tasks/segmentation/speaker_separation_diarization.py @@ -551,8 +551,8 @@ def setup(self): def setup_loss_func(self): if self.specifications[0].powerset: self.model.powerset = Powerset( - len(self.specifications.classes), - self.specifications.powerset_max_classes, + len(self.specifications[0].classes), + self.specifications[0].powerset_max_classes, ) def prepare_chunk(self, file_id: int, start_time: float, duration: float): From 5e18060174f63e2f8d1b55aaccc64f5380093b86 Mon Sep 17 00:00:00 2001 From: joonaskalda Date: Mon, 28 Aug 2023 13:48:48 +0300 Subject: [PATCH 47/55] clean up --- .../speaker_separation_diarization.py | 25 +++++-------------- 1 file changed, 6 insertions(+), 19 deletions(-) diff --git a/pyannote/audio/tasks/segmentation/speaker_separation_diarization.py b/pyannote/audio/tasks/segmentation/speaker_separation_diarization.py index 69193286b..00be48e19 100644 --- a/pyannote/audio/tasks/segmentation/speaker_separation_diarization.py +++ b/pyannote/audio/tasks/segmentation/speaker_separation_diarization.py @@ -1080,35 +1080,25 @@ def training_step(self, batch, batch_idx: int): # (batch_size, num_frames, 1) if self.specifications[0].powerset: 
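The `[0]` indexing above is the consequence of the earlier "specifications being a tuple" fix: the task now exposes a pair of specifications (diarization, separation), so attribute access has to go through one entry of the tuple. A minimal sketch of the failure mode, with a hypothetical namedtuple standing in for the real Specifications class:

    from collections import namedtuple

    # Hypothetical stand-in for pyannote's Specifications, for illustration only.
    Spec = namedtuple("Spec", ["classes", "powerset_max_classes"])

    speaker_diarization = Spec(classes=["spk1", "spk2", "spk3"], powerset_max_classes=2)
    speaker_separation = Spec(classes=["src1", "src2", "src3"], powerset_max_classes=None)

    # Multi-task setup stores both specifications as a tuple.
    specifications = (speaker_diarization, speaker_separation)

    # specifications.classes          # AttributeError: 'tuple' object has no attribute 'classes'
    num_classes = len(specifications[0].classes)   # index the diarization entry first
    print(num_classes)                # 3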
- multilabel = self.model.powerset.to_multilabel(diarization) - permutated_target, permutations = permutate(multilabel, target) - permutated_target_powerset = self.model.powerset.to_powerset( - permutated_target.float() - ) - seg_loss = self.segmentation_loss( - diarization, permutated_target_powerset, weight=weight - ) + raise NotImplementedError("Forced alignment requires multilabel diarization") else: # last 2 sources should only contain noise so we force diarization outputs to 0 - permutated_target, permutations = permutate(target, diarization[:, :, :3]) - permutated_target = torch.cat((permutated_target, diarization[:, :, 3:]), dim=2) + permutated_diarization, permutations = permutate(target, diarization[:, :, :3]) target = torch.cat((target, torch.zeros(batch_size, num_frames, 2, device=target.device)), dim=2) - + permutated_diarization = torch.cat((permutated_diarization, diarization[:, :, 3:]), dim=2) seg_loss = self.segmentation_loss( - permutated_target, target, weight=weight + permutated_diarization, target, weight=weight ) if self.force_alignment: - # to find which predicted sources correspond to which mixtures, we need to invert the permutations - permutations_inverse = torch.argsort(torch.tensor(permutations)) speaker_idx_mix1 = [ - [permutations_inverse[i][j] for j in range(num_active_speakers_mix1[i])] + [permutations[i][j] for j in range(num_active_speakers_mix1[i])] for i in range(bsz // 2) ] speaker_idx_mix2 = [ [ - permutations_inverse[i][j] + permutations[i][j] for j in range(num_active_speakers_mix1[i], num_active_speakers_mix1[i] + num_active_speakers_mix2[i]) ] for i in range(bsz // 2) @@ -1146,9 +1136,6 @@ def training_step(self, batch, batch_idx: int): # predicted_sources_mix2.transpose(1, 2), torch.stack((mix2_masked, torch.zeros_like(mix2))).transpose(0, 1), speaker_idx_mix1[1::3], speaker_idx_mix2[1::3] # ) * mix2_masks.sum() / num_samples / bsz * 3 - upscaled_permutated_target = torch.nn.functional.interpolate( - permutated_target.transpose(1, 2), size=(80000) - ).transpose(1, 2) # forced_alignment_loss = ( # (1 - 2 * upscaled_permutated_target[: bsz // 2]) * mix1_sources**2 # + (1 - 2 * upscaled_permutated_target[bsz // 2 : bsz]) * mix2_sources**2 From c2fb4821a2b2d11aef869ff98f18c1c0791db26c Mon Sep 17 00:00:00 2001 From: joonaskalda Date: Thu, 7 Sep 2023 22:36:35 +0300 Subject: [PATCH 48/55] add avg pooling to diarization branch for smaller kernel sizes --- pyannote/audio/models/segmentation/SepDiarNet.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/pyannote/audio/models/segmentation/SepDiarNet.py b/pyannote/audio/models/segmentation/SepDiarNet.py index 99e034023..6db75f4cb 100644 --- a/pyannote/audio/models/segmentation/SepDiarNet.py +++ b/pyannote/audio/models/segmentation/SepDiarNet.py @@ -142,6 +142,10 @@ def __init__( ) self.masker = DPRNN(n_feats_out, n_src=n_sources, **self.hparams.dprnn) #self.convnet= TDConvNet(n_feats_out, **self.hparams.convnet) + + # diarization can use a lower resolution than separation + diarization_scaling = int(256 / encoder_decoder["kernel_size"]) + self.average_pool = nn.AvgPool1d(diarization_scaling, stride=diarization_scaling) if use_lstm: monolithic = lstm["monolithic"] @@ -231,17 +235,15 @@ def forward(self, waveforms: torch.Tensor) -> torch.Tensor: bsz = waveforms.shape[0] tf_rep = self.encoder(waveforms) masks = self.masker(tf_rep) - + # shape: (batch, nsrc, nfilters, nframes) masked_tf_rep = masks * tf_rep.unsqueeze(1) decoded_sources = self.decoder(masked_tf_rep) decoded_sources 
= pad_x_to_y(decoded_sources, waveforms) decoded_sources = decoded_sources.transpose(1, 2) - outputs = rearrange( - masks, "batch nsrc nfilters nframes -> batch nsrc nframes nfilters" - ) - outputs = torch.flatten(outputs, start_dim=0, end_dim=1) - + outputs = torch.flatten(masks, start_dim=0, end_dim=1) + outputs = self.average_pool(outputs) + outputs = outputs.transpose(1, 2) if self.use_lstm: if self.hparams.lstm["monolithic"]: outputs, _ = self.lstm(outputs) From ad8850c05d45755a6920a2662486659013198693 Mon Sep 17 00:00:00 2001 From: joonaskalda Date: Sat, 9 Sep 2023 12:07:56 +0300 Subject: [PATCH 49/55] fix validation loss --- .../speaker_separation_diarization.py | 101 ++++++++++++------ 1 file changed, 70 insertions(+), 31 deletions(-) diff --git a/pyannote/audio/tasks/segmentation/speaker_separation_diarization.py b/pyannote/audio/tasks/segmentation/speaker_separation_diarization.py index 00be48e19..1b59f8b7b 100644 --- a/pyannote/audio/tasks/segmentation/speaker_separation_diarization.py +++ b/pyannote/audio/tasks/segmentation/speaker_separation_diarization.py @@ -1271,19 +1271,35 @@ def validation_step(self, batch, batch_idx: int): # target = target[keep] bsz = waveform.shape[0] + num_samples = waveform.shape[2] # MoMs can't be created for batch size < 2 if bsz < 2: return None + # if bsz not even, then leave out last sample + if bsz % 2 != 0: + waveform = waveform[:-1] # if bsz not even, then leave out last sample - mix1 = waveform[bsz // 2 : 2 * (bsz // 2)].squeeze(1) - mix2 = waveform[: bsz // 2].squeeze(1) - moms = mix1 + mix2 + mix1 = waveform[0::2].squeeze(1) + mix2 = waveform[1::2].squeeze(1) + if self.original_mixtures_for_separation: + # extract parts with only one speaker from original mixtures + mix1_masks = batch["X_separation_mask"][0::2] + mix2_masks = batch["X_separation_mask"][1::2] + mix1_masked = mix1 * mix1_masks + mix2_masked = mix2 * mix2_masks + + ( + mom, + mom_target, + num_active_speakers_mix1, + num_active_speakers_mix2, + ) = self.create_mixtures_of_mixtures(mix1, mix2, target[0::2], target[1::2]) # forward pass - prediction, _ = self.model(waveform) - _, mom_sources = self.model(moms) - batch_size, num_frames, _ = prediction.shape + diarization, _ = self.model(waveform) + _, mom_sources = self.model(mom) + batch_size, num_frames, _ = diarization.shape # frames weight weight_key = getattr(self, "weight", None) @@ -1294,29 +1310,56 @@ def validation_step(self, batch, batch_idx: int): # (batch_size, num_frames, 1) if self.specifications[0].powerset: - multilabel = self.model.powerset.to_multilabel(prediction) - permutated_target, _ = permutate(multilabel, target) + raise NotImplementedError("Forced alignment requires multilabel diarization") - # FIXME: handle case where target have too many speakers? 
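For reference, the avg-pooling patch above lets the diarization branch run on coarser frames than the separation branch, which stays at the encoder's native resolution. A rough shape sketch of that downsampling, using the 256 / kernel_size rule from the patch; the kernel size and frame counts below are made-up values chosen only to show the scaling:

    import torch
    import torch.nn as nn

    batch, n_src, n_filters, n_frames = 2, 3, 64, 2560
    kernel_size = 16                      # free-filterbank kernel size (illustrative)
    scaling = int(256 / kernel_size)      # -> 16, as in the patch above

    masks = torch.randn(batch, n_src, n_filters, n_frames)
    pool = nn.AvgPool1d(scaling, stride=scaling)

    outputs = torch.flatten(masks, start_dim=0, end_dim=1)   # (batch * n_src, n_filters, n_frames)
    outputs = pool(outputs)                                  # (batch * n_src, n_filters, n_frames // 16)
    outputs = outputs.transpose(1, 2)                        # (batch * n_src, 160, n_filters)
    print(outputs.shape)                                     # torch.Size([6, 160, 64])

Pooling with kernel equal to stride keeps the two branches time-aligned while dividing the diarization frame rate by the same factor.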
- # since we don't need - permutated_target_powerset = self.model.powerset.to_powerset( - permutated_target.float() - ) + else: + # last 2 sources should only contain noise so we force diarization outputs to 0 + permutated_diarization, permutations = permutate(target, diarization[:, :, :3]) + target = torch.cat((target, torch.zeros(batch_size, num_frames, 2, device=target.device)), dim=2) + permutated_diarization = torch.cat((permutated_diarization, diarization[:, :, 3:]), dim=2) seg_loss = self.segmentation_loss( - prediction, permutated_target_powerset, weight=weight + permutated_diarization, target, weight=weight ) + if self.force_alignment: + speaker_idx_mix1 = [ + [permutations[i][j] for j in range(num_active_speakers_mix1[i])] + for i in range(bsz // 2) + ] + speaker_idx_mix2 = [ + [ + permutations[i][j] + for j in range(num_active_speakers_mix1[i], num_active_speakers_mix1[i] + num_active_speakers_mix2[i]) + ] + for i in range(bsz // 2) + ] + + est_mixes = [] + for i in range(bsz // 2): + est_mix1 = mom_sources[i, :, speaker_idx_mix1[i]].sum(1) + mom_sources[i,:,3] + est_mix2 = mom_sources[i, :, speaker_idx_mix2[i]].sum(1) + mom_sources[i,:,4] + est_mix3 = mom_sources[i, :, speaker_idx_mix1[i]].sum(1) + mom_sources[i,:,4] + est_mix4 = mom_sources[i, :, speaker_idx_mix2[i]].sum(1) + mom_sources[i,:,3] + sep_loss_first_part = multisrc_neg_sisdr( + torch.stack((est_mix1, est_mix2)).unsqueeze(0), torch.stack((mix1[i], mix2[i])).unsqueeze(0) + ) + sep_loss_second_part = multisrc_neg_sisdr( + torch.stack((est_mix3, est_mix4)).unsqueeze(0), torch.stack((mix1[i], mix2[i])).unsqueeze(0) + ) + if sep_loss_first_part < sep_loss_second_part: + est_mixes.append(torch.stack((est_mix1, est_mix2))) + else: + est_mixes.append(torch.stack((est_mix3, est_mix4))) + est_mixes = torch.stack(est_mixes) + separation_loss = multisrc_neg_sisdr( + est_mixes, torch.stack((mix1, mix2)).transpose(0, 1) + ).mean() + _, mixit_partitions = self.separation_loss(mom_sources[:,:,:3].transpose(1, 2), torch.stack((mix1, mix2)).transpose(0, 1)) else: - permutated_prediction, _ = permutate(target, prediction) - seg_loss = self.segmentation_loss( - permutated_prediction, target, weight=weight - ) + separation_loss, _ = self.separation_loss(mom_sources[:,:,:3].transpose(1, 2), torch.stack((mix1, mix2)).transpose(0, 1)) - # forced alignment mixit can't be implemented for validation because since data loading is different - separation_loss = 0 - # separation_loss = self.separation_loss( - # mom_sources.transpose(1, 2), torch.stack((mix1, mix2)).transpose(0, 1) - # ) + if self.original_mixtures_for_separation: + raise NotImplementedError self.model.log( "loss/val/separation", @@ -1350,13 +1393,10 @@ def validation_step(self, batch, batch_idx: int): ) if self.specifications[0].powerset: - self.model.validation_metric( - torch.transpose(multilabel, 1, 2), - torch.transpose(target, 1, 2), - ) + raise NotImplementedError("Forced alignment requires multilabel diarization") else: self.model.validation_metric( - torch.transpose(prediction, 1, 2), + torch.transpose(diarization, 1, 2), torch.transpose(target, 1, 2), ) @@ -1380,11 +1420,10 @@ def validation_step(self, batch, batch_idx: int): # visualize first 9 validation samples of first batch in Tensorboard/MLflow if self.specifications[0].powerset: - y = permutated_target.float().cpu().numpy() - y_pred = multilabel.cpu().numpy() + raise NotImplementedError("Forced alignment requires multilabel diarization") else: y = target.float().cpu().numpy() - y_pred = 
permutated_prediction.cpu().numpy() + y_pred = permutated_diarization.cpu().numpy() # prepare 3 x 3 grid (or smaller if batch size is smaller) num_samples = min(self.batch_size, 9) From 8557f4939f2e0e0792d08835fd07446e110383d4 Mon Sep 17 00:00:00 2001 From: joonaskalda Date: Thu, 14 Sep 2023 21:58:13 +0300 Subject: [PATCH 50/55] changing to pit_loss --- .../speaker_separation_diarization.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/pyannote/audio/tasks/segmentation/speaker_separation_diarization.py b/pyannote/audio/tasks/segmentation/speaker_separation_diarization.py index 1b59f8b7b..a1ff59ed4 100644 --- a/pyannote/audio/tasks/segmentation/speaker_separation_diarization.py +++ b/pyannote/audio/tasks/segmentation/speaker_separation_diarization.py @@ -55,7 +55,7 @@ from pyannote.audio.utils.loss import binary_cross_entropy, mse_loss, nll_loss from pyannote.audio.utils.permutation import permutate from pyannote.audio.utils.powerset import Powerset -from asteroid.losses import MixITLossWrapper, multisrc_neg_sisdr +from asteroid.losses import MixITLossWrapper, multisrc_neg_sisdr, PITLossWrapper, pairwise_neg_sisdr from torch.utils.data._utils.collate import default_collate Subsets = list(Subset.__args__) @@ -438,6 +438,7 @@ def __init__( self.balance = balance self.weight = weight self.separation_loss = CustomMixITLossWrapper(multisrc_neg_sisdr, generalized=True) + self.pit_sep_loss = PITLossWrapper(pairwise_neg_sisdr, pit_from="pw_mtx") self.separation_loss_weight = separation_loss_weight self.force_alignment = force_alignment self.original_mixtures_for_separation = original_mixtures_for_separation @@ -1110,10 +1111,10 @@ def training_step(self, batch, batch_idx: int): est_mix2 = mom_sources[i, :, speaker_idx_mix2[i]].sum(1) + mom_sources[i,:,4] est_mix3 = mom_sources[i, :, speaker_idx_mix1[i]].sum(1) + mom_sources[i,:,4] est_mix4 = mom_sources[i, :, speaker_idx_mix2[i]].sum(1) + mom_sources[i,:,3] - sep_loss_first_part = multisrc_neg_sisdr( + sep_loss_first_part = self.pit_sep_loss( torch.stack((est_mix1, est_mix2)).unsqueeze(0), torch.stack((mix1[i], mix2[i])).unsqueeze(0) ) - sep_loss_second_part = multisrc_neg_sisdr( + sep_loss_second_part = self.pit_sep_loss( torch.stack((est_mix3, est_mix4)).unsqueeze(0), torch.stack((mix1[i], mix2[i])).unsqueeze(0) ) if sep_loss_first_part < sep_loss_second_part: @@ -1121,7 +1122,7 @@ def training_step(self, batch, batch_idx: int): else: est_mixes.append(torch.stack((est_mix3, est_mix4))) est_mixes = torch.stack(est_mixes) - separation_loss = multisrc_neg_sisdr( + separation_loss = self.pit_sep_loss( est_mixes, torch.stack((mix1, mix2)).transpose(0, 1) ).mean() _, mixit_partitions = self.separation_loss(mom_sources[:,:,:3].transpose(1, 2), torch.stack((mix1, mix2)).transpose(0, 1)) @@ -1340,10 +1341,10 @@ def validation_step(self, batch, batch_idx: int): est_mix2 = mom_sources[i, :, speaker_idx_mix2[i]].sum(1) + mom_sources[i,:,4] est_mix3 = mom_sources[i, :, speaker_idx_mix1[i]].sum(1) + mom_sources[i,:,4] est_mix4 = mom_sources[i, :, speaker_idx_mix2[i]].sum(1) + mom_sources[i,:,3] - sep_loss_first_part = multisrc_neg_sisdr( + sep_loss_first_part = self.pit_sep_loss( torch.stack((est_mix1, est_mix2)).unsqueeze(0), torch.stack((mix1[i], mix2[i])).unsqueeze(0) ) - sep_loss_second_part = multisrc_neg_sisdr( + sep_loss_second_part = self.pit_sep_loss( torch.stack((est_mix3, est_mix4)).unsqueeze(0), torch.stack((mix1[i], mix2[i])).unsqueeze(0) ) if sep_loss_first_part < sep_loss_second_part: @@ -1351,7 +1352,7 @@ def 
validation_step(self, batch, batch_idx: int): else: est_mixes.append(torch.stack((est_mix3, est_mix4))) est_mixes = torch.stack(est_mixes) - separation_loss = multisrc_neg_sisdr( + separation_loss = self.pit_sep_loss( est_mixes, torch.stack((mix1, mix2)).transpose(0, 1) ).mean() _, mixit_partitions = self.separation_loss(mom_sources[:,:,:3].transpose(1, 2), torch.stack((mix1, mix2)).transpose(0, 1)) From 712d76567e31276b832bf4eb4cf8bde4031bf689 Mon Sep 17 00:00:00 2001 From: joonaskalda Date: Thu, 14 Sep 2023 22:41:21 +0300 Subject: [PATCH 51/55] changing validation dataloader --- .../speaker_separation_diarization.py | 230 ++++++++++++++++++ 1 file changed, 230 insertions(+) diff --git a/pyannote/audio/tasks/segmentation/speaker_separation_diarization.py b/pyannote/audio/tasks/segmentation/speaker_separation_diarization.py index a1ff59ed4..3170680d2 100644 --- a/pyannote/audio/tasks/segmentation/speaker_separation_diarization.py +++ b/pyannote/audio/tasks/segmentation/speaker_separation_diarization.py @@ -20,6 +20,7 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. +import itertools import math import warnings import random @@ -64,6 +65,10 @@ from itertools import combinations from torch import nn from pytorch_lightning.callbacks import Callback +from pyannote.audio.core.task import TrainDataset +from functools import cached_property, partial +from torch.utils.data import DataLoader, Dataset, IterableDataset +from pyannote.audio.utils.random import create_rng_for_worker class CountingCallback(Callback): def on_train_epoch_start(self, trainer, pl_module) -> None: @@ -315,6 +320,17 @@ def reorder_source(est_targets, targets, min_loss_idx, parts): return ordered +class ValDataset(IterableDataset): + def __init__(self, task: Task): + super().__init__() + self.task = task + + def __iter__(self): + return self.task.val__iter__() + + def __len__(self): + return self.task.val__len__() + class JointSpeakerSeparationAndDiarization(SegmentationTaskMixin, Task): """Speaker diarization @@ -658,6 +674,220 @@ def prepare_chunk(self, file_id: int, start_time: float, duration: float): sample["meta"]["file"] = file_id return sample + + def val_dataloader(self) -> DataLoader: + return DataLoader( + ValDataset(self), + batch_size=self.batch_size, + num_workers=self.num_workers, + pin_memory=self.pin_memory, + drop_last=True, + collate_fn=partial(self.collate_fn, stage="train"), + ) + + def val__iter__(self): + """Iterate over training samples + + Yields + ------ + dict: + X: (time, channel) + Audio chunks. + y: (frame, ) + Frame-level targets. Note that frame < time. + `frame` is infered automagically from the + example model output. + ... 
+ """ + + # create worker-specific random number generator + rng = create_rng_for_worker(0) + + balance = getattr(self, "balance", None) + if balance is None: + chunks = self.val__iter__helper(rng) + + else: + # create a subchunk generator for each combination of "balance" keys + subchunks = dict() + for product in itertools.product( + [self.metadata_unique_values[key] for key in balance] + ): + filters = {key: value for key, value in zip(balance, product)} + subchunks[product] = self.train__iter__helper(rng, **filters) + + while True: + # select one subchunk generator at random (with uniform probability) + # so that it is balanced on average + if balance is not None: + chunks = subchunks[rng.choice(subchunks)] + + # generate random chunk + yield next(chunks) + + def train__len__(self): + # Number of training samples in one epoch + + duration = np.sum(self.annotated_duration) + return max(self.batch_size, math.ceil(duration / self.duration)) + + def val__iter__helper(self, rng: random.Random, **filters): + """Iterate over training samples with optional domain filtering + + Parameters + ---------- + rng : random.Random + Random number generator + filters : dict, optional + When provided (as {key: value} dict), filter training files so that + only files such as file[key] == value are used for generating chunks. + + Yields + ------ + chunk : dict + Training chunks. + """ + + # indices of training files that matches domain filters + validating = self.metadata["subset"] == Subsets.index("train") + for key, value in filters.items(): + validating &= self.metadata[key] == value + file_ids = np.where(validating)[0] + + # turn annotated duration into a probability distribution + annotated_duration = self.annotated_duration[file_ids] + prob_annotated_duration = annotated_duration / np.sum(annotated_duration) + + duration = self.duration + + num_chunks_per_file = getattr(self, "num_chunks_per_file", 1) + + while True: + # select one file at random (with probability proportional to its annotated duration) + file_id = np.random.choice(file_ids, p=prob_annotated_duration) + annotations = self.annotations[ + np.where(self.annotations["file_id"] == file_id)[0] + ] + + # generate `num_chunks_per_file` chunks from this file + for _ in range(num_chunks_per_file): + # find indices of annotated regions in this file + annotated_region_indices = np.where( + self.annotated_regions["file_id"] == file_id + )[0] + + # turn annotated regions duration into a probability distribution + prob_annotated_regions_duration = self.annotated_regions["duration"][ + annotated_region_indices + ] / np.sum(self.annotated_regions["duration"][annotated_region_indices]) + + # selected one annotated region at random (with probability proportional to its duration) + annotated_region_index = np.random.choice( + annotated_region_indices, p=prob_annotated_regions_duration + ) + + # select one chunk at random in this annotated region + _, _, start, end = self.annotated_regions[annotated_region_index] + start_time = rng.uniform(start, end - duration) + + # find speakers that already appeared and all annotations that contain them + chunk_annotations = annotations[ + (annotations["start"] < start_time + duration) + & (annotations["end"] > start_time) + ] + previous_speaker_labels = list( + np.unique(chunk_annotations["file_label_idx"]) + ) + repeated_speaker_annotations = annotations[ + np.isin(annotations["file_label_idx"], previous_speaker_labels) + ] + + if repeated_speaker_annotations.size == 0: + # if previous chunk has 0 speakers then 
just sample from all annotated regions again + first_chunk = self.prepare_chunk(file_id, start_time, duration) + + # selected one annotated region at random (with probability proportional to its duration) + annotated_region_index = np.random.choice( + annotated_region_indices, p=prob_annotated_regions_duration + ) + + # select one chunk at random in this annotated region + _, _, start, end = self.annotated_regions[annotated_region_index] + start_time = rng.uniform(start, end - duration) + + second_chunk = self.prepare_chunk(file_id, start_time, duration) + + labels = first_chunk["y"].labels + second_chunk["y"].labels + + if len(labels) <= self.max_speakers_per_chunk: + yield first_chunk + yield second_chunk + + else: + # merge segments that contain repeated speakers + merged_repeated_segments = [ + [ + repeated_speaker_annotations["start"][0], + repeated_speaker_annotations["end"][0], + ] + ] + for _, start, end, _, _, _ in repeated_speaker_annotations: + previous = merged_repeated_segments[-1] + if start <= previous[1]: + previous[1] = max(previous[1], end) + else: + merged_repeated_segments.append([start, end]) + + # find segments that don't contain repeated speakers + segments_without_repeat = [] + current_region_index = 0 + previous_time = self.annotated_regions["start"][ + annotated_region_indices[0] + ] + for segment in merged_repeated_segments: + if ( + segment[0] + > self.annotated_regions["end"][ + annotated_region_indices[current_region_index] + ] + ): + current_region_index += 1 + previous_time = self.annotated_regions["start"][ + annotated_region_indices[current_region_index] + ] + + if segment[0] - previous_time > duration: + segments_without_repeat.append( + (previous_time, segment[0], segment[0] - previous_time) + ) + previous_time = segment[1] + + dtype = [("start", "f"), ("end", "f"), ("duration", "f")] + segments_without_repeat = np.array( + segments_without_repeat, dtype=dtype + ) + + if np.sum(segments_without_repeat["duration"]) != 0: + # only yield chunks if it is possible to choose the second chunk so that yielded chunks are always paired + first_chunk = self.prepare_chunk(file_id, start_time, duration) + + prob_segments_duration = segments_without_repeat[ + "duration" + ] / np.sum(segments_without_repeat["duration"]) + segment = np.random.choice( + segments_without_repeat, p=prob_segments_duration + ) + + start, end, _ = segment + new_start_time = rng.uniform(start, end - duration) + second_chunk = self.prepare_chunk( + file_id, new_start_time, duration + ) + + labels = first_chunk["y"].labels + second_chunk["y"].labels + if len(labels) <= self.max_speakers_per_chunk: + yield first_chunk + yield second_chunk def train__iter__helper(self, rng: random.Random, **filters): """Iterate over training samples with optional domain filtering From 69ce717260ae77dbf6de5a8e5adb469488164f22 Mon Sep 17 00:00:00 2001 From: joonaskalda Date: Fri, 15 Sep 2023 01:18:09 +0300 Subject: [PATCH 52/55] make the additional 2 noise sources optional --- .../speaker_separation_diarization.py | 45 ++++++++++++------- 1 file changed, 28 insertions(+), 17 deletions(-) diff --git a/pyannote/audio/tasks/segmentation/speaker_separation_diarization.py b/pyannote/audio/tasks/segmentation/speaker_separation_diarization.py index 3170680d2..1fbff3dbe 100644 --- a/pyannote/audio/tasks/segmentation/speaker_separation_diarization.py +++ b/pyannote/audio/tasks/segmentation/speaker_separation_diarization.py @@ -411,6 +411,7 @@ def __init__( forced_alignment_weight: float = 0.0, force_alignment: bool = False, 
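The aligned separation loss used throughout these patches follows one recipe: the diarization permutation decides which separated sources belong to which original mixture, the two extra noise sources are attached to the two mixtures in both possible ways, and the assignment with the lower SI-SDR loss is kept. A minimal single-item sketch of that selection; the neg_sisdr helper is a simplified stand-in for asteroid's SI-SDR losses, and all shapes and indices are made up for illustration:

    import torch

    def neg_sisdr(est, ref, eps=1e-8):
        # Negative scale-invariant SDR, a simplified stand-in for asteroid's losses.
        ref = ref - ref.mean(-1, keepdim=True)
        est = est - est.mean(-1, keepdim=True)
        proj = (est * ref).sum(-1, keepdim=True) * ref / (ref.pow(2).sum(-1, keepdim=True) + eps)
        noise = est - proj
        return -10 * torch.log10(proj.pow(2).sum(-1) / (noise.pow(2).sum(-1) + eps) + eps)

    num_samples = 80000
    mix1, mix2 = torch.randn(num_samples), torch.randn(num_samples)
    mom_sources = torch.randn(num_samples, 5)        # 3 speaker sources + 2 noise sources

    # source indices assigned to each original mixture by the diarization permutation
    speaker_idx_mix1, speaker_idx_mix2 = [0], [1, 2]

    # the two possible ways of attaching the noise sources (3 and 4) to the mixtures
    cand_a = torch.stack((mom_sources[:, speaker_idx_mix1].sum(1) + mom_sources[:, 3],
                          mom_sources[:, speaker_idx_mix2].sum(1) + mom_sources[:, 4]))
    cand_b = torch.stack((mom_sources[:, speaker_idx_mix1].sum(1) + mom_sources[:, 4],
                          mom_sources[:, speaker_idx_mix2].sum(1) + mom_sources[:, 3]))

    refs = torch.stack((mix1, mix2))
    loss_a = neg_sisdr(cand_a, refs).mean()
    loss_b = neg_sisdr(cand_b, refs).mean()
    est_mixes = cand_a if loss_a < loss_b else cand_b
    separation_loss = torch.minimum(loss_a, loss_b)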
log_alignment_accuracy: bool = False, + add_noise_sources: bool = False, ): super().__init__( protocol, @@ -460,6 +461,7 @@ def __init__( self.original_mixtures_for_separation = original_mixtures_for_separation self.forced_alignment_weight = forced_alignment_weight self.log_alignment_accuracy = log_alignment_accuracy + self.add_noise_sources = add_noise_sources def setup(self): super().setup() @@ -1314,10 +1316,14 @@ def training_step(self, batch, batch_idx: int): raise NotImplementedError("Forced alignment requires multilabel diarization") else: - # last 2 sources should only contain noise so we force diarization outputs to 0 - permutated_diarization, permutations = permutate(target, diarization[:, :, :3]) - target = torch.cat((target, torch.zeros(batch_size, num_frames, 2, device=target.device)), dim=2) - permutated_diarization = torch.cat((permutated_diarization, diarization[:, :, 3:]), dim=2) + if self.add_noise_sources: + # last 2 sources should only contain noise so we force diarization outputs to 0 + permutated_diarization, permutations = permutate(target, diarization[:, :, :3]) + target = torch.cat((target, torch.zeros(batch_size, num_frames, 2, device=target.device)), dim=2) + permutated_diarization = torch.cat((permutated_diarization, diarization[:, :, 3:]), dim=2) + else: + permutated_diarization, permutations = permutate(target, diarization) + seg_loss = self.segmentation_loss( permutated_diarization, target, weight=weight ) @@ -1337,20 +1343,25 @@ def training_step(self, batch, batch_idx: int): est_mixes = [] for i in range(bsz // 2): - est_mix1 = mom_sources[i, :, speaker_idx_mix1[i]].sum(1) + mom_sources[i,:,3] - est_mix2 = mom_sources[i, :, speaker_idx_mix2[i]].sum(1) + mom_sources[i,:,4] - est_mix3 = mom_sources[i, :, speaker_idx_mix1[i]].sum(1) + mom_sources[i,:,4] - est_mix4 = mom_sources[i, :, speaker_idx_mix2[i]].sum(1) + mom_sources[i,:,3] - sep_loss_first_part = self.pit_sep_loss( - torch.stack((est_mix1, est_mix2)).unsqueeze(0), torch.stack((mix1[i], mix2[i])).unsqueeze(0) - ) - sep_loss_second_part = self.pit_sep_loss( - torch.stack((est_mix3, est_mix4)).unsqueeze(0), torch.stack((mix1[i], mix2[i])).unsqueeze(0) - ) - if sep_loss_first_part < sep_loss_second_part: - est_mixes.append(torch.stack((est_mix1, est_mix2))) + if self.add_noise_sources: + est_mix1 = mom_sources[i, :, speaker_idx_mix1[i]].sum(1) + mom_sources[i,:,3] + est_mix2 = mom_sources[i, :, speaker_idx_mix2[i]].sum(1) + mom_sources[i,:,4] + est_mix3 = mom_sources[i, :, speaker_idx_mix1[i]].sum(1) + mom_sources[i,:,4] + est_mix4 = mom_sources[i, :, speaker_idx_mix2[i]].sum(1) + mom_sources[i,:,3] + sep_loss_first_part = self.pit_sep_loss( + torch.stack((est_mix1, est_mix2)).unsqueeze(0), torch.stack((mix1[i], mix2[i])).unsqueeze(0) + ) + sep_loss_second_part = self.pit_sep_loss( + torch.stack((est_mix3, est_mix4)).unsqueeze(0), torch.stack((mix1[i], mix2[i])).unsqueeze(0) + ) + if sep_loss_first_part < sep_loss_second_part: + est_mixes.append(torch.stack((est_mix1, est_mix2))) + else: + est_mixes.append(torch.stack((est_mix3, est_mix4))) else: - est_mixes.append(torch.stack((est_mix3, est_mix4))) + est_mix1 = mom_sources[i, :, speaker_idx_mix1[i]].sum(1) + est_mix2 = mom_sources[i, :, speaker_idx_mix2[i]].sum(1) + est_mixes.append(torch.stack((est_mix1, est_mix2))) est_mixes = torch.stack(est_mixes) separation_loss = self.pit_sep_loss( est_mixes, torch.stack((mix1, mix2)).transpose(0, 1) From 3fd95c6314c503eb2aed8cf42c3addd647662a7d Mon Sep 17 00:00:00 2001 From: joonaskalda Date: Fri, 15 Sep 2023 
13:07:49 +0300 Subject: [PATCH 53/55] make aligned training the only supported behavior --- .../speaker_separation_diarization.py | 249 +++++++----------- 1 file changed, 97 insertions(+), 152 deletions(-) diff --git a/pyannote/audio/tasks/segmentation/speaker_separation_diarization.py b/pyannote/audio/tasks/segmentation/speaker_separation_diarization.py index 1fbff3dbe..d3872b9c8 100644 --- a/pyannote/audio/tasks/segmentation/speaker_separation_diarization.py +++ b/pyannote/audio/tasks/segmentation/speaker_separation_diarization.py @@ -73,7 +73,7 @@ class CountingCallback(Callback): def on_train_epoch_start(self, trainer, pl_module) -> None: "reset counters" - if pl_module.task.log_alignment_accuracy and pl_module.task.force_alignment: + if pl_module.task.log_alignment_accuracy: pl_module.task.num_correct = 0 pl_module.task.num_total = 0 pl_module.task.num_correct30 = 0 @@ -409,7 +409,6 @@ def __init__( separation_loss_weight: float = 0.5, original_mixtures_for_separation: bool = False, forced_alignment_weight: float = 0.0, - force_alignment: bool = False, log_alignment_accuracy: bool = False, add_noise_sources: bool = False, ): @@ -439,10 +438,7 @@ def __init__( # parameter validation if max_speakers_per_frame is not None: - if max_speakers_per_frame < 1: - raise ValueError( - f"`max_speakers_per_frame` must be 1 or more (you used {max_speakers_per_frame})." - ) + raise NotImplementedError("Powerset multi-class training is not implemented") if batch_size % 2 != 0: raise ValueError( @@ -457,7 +453,6 @@ def __init__( self.separation_loss = CustomMixITLossWrapper(multisrc_neg_sisdr, generalized=True) self.pit_sep_loss = PITLossWrapper(pairwise_neg_sisdr, pit_from="pw_mtx") self.separation_loss_weight = separation_loss_weight - self.force_alignment = force_alignment self.original_mixtures_for_separation = original_mixtures_for_separation self.forced_alignment_weight = forced_alignment_weight self.log_alignment_accuracy = log_alignment_accuracy @@ -567,13 +562,6 @@ def setup(self): self.specifications = (speaker_diarization, speaker_separation) - def setup_loss_func(self): - if self.specifications[0].powerset: - self.model.powerset = Powerset( - len(self.specifications[0].classes), - self.specifications[0].powerset_max_classes, - ) - def prepare_chunk(self, file_id: int, start_time: float, duration: float): """Prepare chunk @@ -1174,24 +1162,9 @@ def segmentation_loss( Permutation-invariant segmentation loss """ - # `clamp_min` is needed to set non-speech weight to 1. - if self.specifications[0].powerset: - # `clamp_min` is needed to set non-speech weight to 1. 
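With the powerset branch removed, the segmentation loss reduces to permutation-invariant binary cross-entropy. A self-contained sketch of that idea using a brute-force permutation search; the real task instead obtains the permutation from permutate() and reuses it for the separation alignment:

    import itertools
    import torch
    import torch.nn.functional as F

    def pit_bce(prediction, target):
        # Permutation-invariant BCE over the speaker dimension (illustrative stand-in
        # for permutate() followed by binary_cross_entropy).
        batch_size, num_frames, num_speakers = prediction.shape
        losses = []
        for perm in itertools.permutations(range(num_speakers)):
            permuted = prediction[:, :, list(perm)]
            losses.append(F.binary_cross_entropy(permuted, target, reduction="none").mean(dim=(1, 2)))
        # keep, for every batch item, the permutation with the lowest loss
        return torch.stack(losses, dim=0).min(dim=0).values.mean()

    prediction = torch.rand(4, 160, 3)               # diarization scores in [0, 1]
    target = (torch.rand(4, 160, 3) > 0.7).float()   # binary speaker activity targets
    print(pit_bce(prediction, target))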
- class_weight = ( - torch.clamp_min(self.model.powerset.cardinality, 1.0) - if self.weigh_by_cardinality - else None - ) - seg_loss = nll_loss( - permutated_prediction, - torch.argmax(target, dim=-1), - class_weight=class_weight, - weight=weight, - ) - else: - seg_loss = binary_cross_entropy( - permutated_prediction, target.float(), weight=weight - ) + seg_loss = binary_cross_entropy( + permutated_prediction, target.float(), weight=weight + ) return seg_loss @@ -1312,64 +1285,57 @@ def training_step(self, batch, batch_idx: int): ) # (batch_size, num_frames, 1) - if self.specifications[0].powerset: - raise NotImplementedError("Forced alignment requires multilabel diarization") - + if self.add_noise_sources: + # last 2 sources should only contain noise so we force diarization outputs to 0 + permutated_diarization, permutations = permutate(target, diarization[:, :, :3]) + target = torch.cat((target, torch.zeros(batch_size, num_frames, 2, device=target.device)), dim=2) + permutated_diarization = torch.cat((permutated_diarization, diarization[:, :, 3:]), dim=2) else: - if self.add_noise_sources: - # last 2 sources should only contain noise so we force diarization outputs to 0 - permutated_diarization, permutations = permutate(target, diarization[:, :, :3]) - target = torch.cat((target, torch.zeros(batch_size, num_frames, 2, device=target.device)), dim=2) - permutated_diarization = torch.cat((permutated_diarization, diarization[:, :, 3:]), dim=2) - else: - permutated_diarization, permutations = permutate(target, diarization) + permutated_diarization, permutations = permutate(target, diarization) - seg_loss = self.segmentation_loss( - permutated_diarization, target, weight=weight - ) + seg_loss = self.segmentation_loss( + permutated_diarization, target, weight=weight + ) - if self.force_alignment: - speaker_idx_mix1 = [ - [permutations[i][j] for j in range(num_active_speakers_mix1[i])] - for i in range(bsz // 2) - ] - speaker_idx_mix2 = [ - [ - permutations[i][j] - for j in range(num_active_speakers_mix1[i], num_active_speakers_mix1[i] + num_active_speakers_mix2[i]) - ] - for i in range(bsz // 2) + speaker_idx_mix1 = [ + [permutations[i][j] for j in range(num_active_speakers_mix1[i])] + for i in range(bsz // 2) + ] + speaker_idx_mix2 = [ + [ + permutations[i][j] + for j in range(num_active_speakers_mix1[i], num_active_speakers_mix1[i] + num_active_speakers_mix2[i]) ] - - est_mixes = [] - for i in range(bsz // 2): - if self.add_noise_sources: - est_mix1 = mom_sources[i, :, speaker_idx_mix1[i]].sum(1) + mom_sources[i,:,3] - est_mix2 = mom_sources[i, :, speaker_idx_mix2[i]].sum(1) + mom_sources[i,:,4] - est_mix3 = mom_sources[i, :, speaker_idx_mix1[i]].sum(1) + mom_sources[i,:,4] - est_mix4 = mom_sources[i, :, speaker_idx_mix2[i]].sum(1) + mom_sources[i,:,3] - sep_loss_first_part = self.pit_sep_loss( - torch.stack((est_mix1, est_mix2)).unsqueeze(0), torch.stack((mix1[i], mix2[i])).unsqueeze(0) - ) - sep_loss_second_part = self.pit_sep_loss( - torch.stack((est_mix3, est_mix4)).unsqueeze(0), torch.stack((mix1[i], mix2[i])).unsqueeze(0) - ) - if sep_loss_first_part < sep_loss_second_part: - est_mixes.append(torch.stack((est_mix1, est_mix2))) - else: - est_mixes.append(torch.stack((est_mix3, est_mix4))) - else: - est_mix1 = mom_sources[i, :, speaker_idx_mix1[i]].sum(1) - est_mix2 = mom_sources[i, :, speaker_idx_mix2[i]].sum(1) + for i in range(bsz // 2) + ] + + est_mixes = [] + for i in range(bsz // 2): + if self.add_noise_sources: + est_mix1 = mom_sources[i, :, speaker_idx_mix1[i]].sum(1) + 
mom_sources[i,:,3] + est_mix2 = mom_sources[i, :, speaker_idx_mix2[i]].sum(1) + mom_sources[i,:,4] + est_mix3 = mom_sources[i, :, speaker_idx_mix1[i]].sum(1) + mom_sources[i,:,4] + est_mix4 = mom_sources[i, :, speaker_idx_mix2[i]].sum(1) + mom_sources[i,:,3] + sep_loss_first_part = self.pit_sep_loss( + torch.stack((est_mix1, est_mix2)).unsqueeze(0), torch.stack((mix1[i], mix2[i])).unsqueeze(0) + ) + sep_loss_second_part = self.pit_sep_loss( + torch.stack((est_mix3, est_mix4)).unsqueeze(0), torch.stack((mix1[i], mix2[i])).unsqueeze(0) + ) + if sep_loss_first_part < sep_loss_second_part: est_mixes.append(torch.stack((est_mix1, est_mix2))) - est_mixes = torch.stack(est_mixes) - separation_loss = self.pit_sep_loss( - est_mixes, torch.stack((mix1, mix2)).transpose(0, 1) - ).mean() - _, mixit_partitions = self.separation_loss(mom_sources[:,:,:3].transpose(1, 2), torch.stack((mix1, mix2)).transpose(0, 1)) - else: - separation_loss, _ = self.separation_loss(mom_sources[:,:,:3].transpose(1, 2), torch.stack((mix1, mix2)).transpose(0, 1)) - + else: + est_mixes.append(torch.stack((est_mix3, est_mix4))) + else: + est_mix1 = mom_sources[i, :, speaker_idx_mix1[i]].sum(1) + est_mix2 = mom_sources[i, :, speaker_idx_mix2[i]].sum(1) + est_mixes.append(torch.stack((est_mix1, est_mix2))) + est_mixes = torch.stack(est_mixes) + separation_loss = self.pit_sep_loss( + est_mixes, torch.stack((mix1, mix2)).transpose(0, 1) + ).mean() + _, mixit_partitions = self.separation_loss(mom_sources[:,:,:3].transpose(1, 2), torch.stack((mix1, mix2)).transpose(0, 1)) + if self.original_mixtures_for_separation: raise NotImplementedError # separation_loss += self.separation_loss( @@ -1421,7 +1387,7 @@ def training_step(self, batch, batch_idx: int): prog_bar=False, logger=True, ) - if self.log_alignment_accuracy and self.force_alignment: + if self.log_alignment_accuracy: for i in range(bsz // 2): inverse_mixit_partition = permutations_inverse[i][mixit_partitions[i][0]], permutations_inverse[i][mixit_partitions[i][1]] if set([int(j) for j in speaker_idx_mix1[i]]) <= set(inverse_mixit_partition[0].tolist()) and set([int(j) for j in speaker_idx_mix2[i]]) <= set(inverse_mixit_partition[1].tolist()): @@ -1473,14 +1439,6 @@ def default_metric( ) -> Union[Metric, Sequence[Metric], Dict[str, Metric]]: """Returns diarization error rate and its components""" - if self.specifications[0].powerset: - return { - "DiarizationErrorRate": DiarizationErrorRate(0.5), - "DiarizationErrorRate/Confusion": SpeakerConfusionRate(0.5), - "DiarizationErrorRate/Miss": MissedDetectionRate(0.5), - "DiarizationErrorRate/FalseAlarm": FalseAlarmRate(0.5), - } - return { "DiarizationErrorRate": OptimalDiarizationErrorRate(), "DiarizationErrorRate/Threshold": OptimalDiarizationErrorRateThreshold(), @@ -1551,54 +1509,47 @@ def validation_step(self, batch, batch_idx: int): ) # (batch_size, num_frames, 1) - if self.specifications[0].powerset: - raise NotImplementedError("Forced alignment requires multilabel diarization") - - else: - # last 2 sources should only contain noise so we force diarization outputs to 0 - permutated_diarization, permutations = permutate(target, diarization[:, :, :3]) - target = torch.cat((target, torch.zeros(batch_size, num_frames, 2, device=target.device)), dim=2) - permutated_diarization = torch.cat((permutated_diarization, diarization[:, :, 3:]), dim=2) - seg_loss = self.segmentation_loss( - permutated_diarization, target, weight=weight - ) + # last 2 sources should only contain noise so we force diarization outputs to 0 + 
permutated_diarization, permutations = permutate(target, diarization[:, :, :3]) + target = torch.cat((target, torch.zeros(batch_size, num_frames, 2, device=target.device)), dim=2) + permutated_diarization = torch.cat((permutated_diarization, diarization[:, :, 3:]), dim=2) + seg_loss = self.segmentation_loss( + permutated_diarization, target, weight=weight + ) - if self.force_alignment: - speaker_idx_mix1 = [ - [permutations[i][j] for j in range(num_active_speakers_mix1[i])] - for i in range(bsz // 2) - ] - speaker_idx_mix2 = [ - [ - permutations[i][j] - for j in range(num_active_speakers_mix1[i], num_active_speakers_mix1[i] + num_active_speakers_mix2[i]) - ] - for i in range(bsz // 2) + speaker_idx_mix1 = [ + [permutations[i][j] for j in range(num_active_speakers_mix1[i])] + for i in range(bsz // 2) + ] + speaker_idx_mix2 = [ + [ + permutations[i][j] + for j in range(num_active_speakers_mix1[i], num_active_speakers_mix1[i] + num_active_speakers_mix2[i]) ] - - est_mixes = [] - for i in range(bsz // 2): - est_mix1 = mom_sources[i, :, speaker_idx_mix1[i]].sum(1) + mom_sources[i,:,3] - est_mix2 = mom_sources[i, :, speaker_idx_mix2[i]].sum(1) + mom_sources[i,:,4] - est_mix3 = mom_sources[i, :, speaker_idx_mix1[i]].sum(1) + mom_sources[i,:,4] - est_mix4 = mom_sources[i, :, speaker_idx_mix2[i]].sum(1) + mom_sources[i,:,3] - sep_loss_first_part = self.pit_sep_loss( - torch.stack((est_mix1, est_mix2)).unsqueeze(0), torch.stack((mix1[i], mix2[i])).unsqueeze(0) - ) - sep_loss_second_part = self.pit_sep_loss( - torch.stack((est_mix3, est_mix4)).unsqueeze(0), torch.stack((mix1[i], mix2[i])).unsqueeze(0) - ) - if sep_loss_first_part < sep_loss_second_part: - est_mixes.append(torch.stack((est_mix1, est_mix2))) - else: - est_mixes.append(torch.stack((est_mix3, est_mix4))) - est_mixes = torch.stack(est_mixes) - separation_loss = self.pit_sep_loss( - est_mixes, torch.stack((mix1, mix2)).transpose(0, 1) - ).mean() - _, mixit_partitions = self.separation_loss(mom_sources[:,:,:3].transpose(1, 2), torch.stack((mix1, mix2)).transpose(0, 1)) - else: - separation_loss, _ = self.separation_loss(mom_sources[:,:,:3].transpose(1, 2), torch.stack((mix1, mix2)).transpose(0, 1)) + for i in range(bsz // 2) + ] + + est_mixes = [] + for i in range(bsz // 2): + est_mix1 = mom_sources[i, :, speaker_idx_mix1[i]].sum(1) + mom_sources[i,:,3] + est_mix2 = mom_sources[i, :, speaker_idx_mix2[i]].sum(1) + mom_sources[i,:,4] + est_mix3 = mom_sources[i, :, speaker_idx_mix1[i]].sum(1) + mom_sources[i,:,4] + est_mix4 = mom_sources[i, :, speaker_idx_mix2[i]].sum(1) + mom_sources[i,:,3] + sep_loss_first_part = self.pit_sep_loss( + torch.stack((est_mix1, est_mix2)).unsqueeze(0), torch.stack((mix1[i], mix2[i])).unsqueeze(0) + ) + sep_loss_second_part = self.pit_sep_loss( + torch.stack((est_mix3, est_mix4)).unsqueeze(0), torch.stack((mix1[i], mix2[i])).unsqueeze(0) + ) + if sep_loss_first_part < sep_loss_second_part: + est_mixes.append(torch.stack((est_mix1, est_mix2))) + else: + est_mixes.append(torch.stack((est_mix3, est_mix4))) + est_mixes = torch.stack(est_mixes) + separation_loss = self.pit_sep_loss( + est_mixes, torch.stack((mix1, mix2)).transpose(0, 1) + ).mean() + _, mixit_partitions = self.separation_loss(mom_sources[:,:,:3].transpose(1, 2), torch.stack((mix1, mix2)).transpose(0, 1)) if self.original_mixtures_for_separation: raise NotImplementedError @@ -1634,13 +1585,10 @@ def validation_step(self, batch, batch_idx: int): logger=True, ) - if self.specifications[0].powerset: - raise NotImplementedError("Forced alignment requires 
multilabel diarization") - else: - self.model.validation_metric( - torch.transpose(diarization, 1, 2), - torch.transpose(target, 1, 2), - ) + self.model.validation_metric( + torch.transpose(diarization, 1, 2), + torch.transpose(target, 1, 2), + ) self.model.log_dict( @@ -1661,11 +1609,8 @@ def validation_step(self, batch, batch_idx: int): # visualize first 9 validation samples of first batch in Tensorboard/MLflow - if self.specifications[0].powerset: - raise NotImplementedError("Forced alignment requires multilabel diarization") - else: - y = target.float().cpu().numpy() - y_pred = permutated_diarization.cpu().numpy() + y = target.float().cpu().numpy() + y_pred = permutated_diarization.cpu().numpy() # prepare 3 x 3 grid (or smaller if batch size is smaller) num_samples = min(self.batch_size, 9) From c2b3fcb7c409b89cca11700cb27aba7ae8c7baba Mon Sep 17 00:00:00 2001 From: joonaskalda Date: Wed, 20 Sep 2023 13:53:23 +0300 Subject: [PATCH 54/55] clean up --- .../audio/models/segmentation/SepDiarNet.py | 29 +- .../speaker_separation_diarization.py | 537 ++++-------------- 2 files changed, 116 insertions(+), 450 deletions(-) diff --git a/pyannote/audio/models/segmentation/SepDiarNet.py b/pyannote/audio/models/segmentation/SepDiarNet.py index 6db75f4cb..a0e7069ef 100644 --- a/pyannote/audio/models/segmentation/SepDiarNet.py +++ b/pyannote/audio/models/segmentation/SepDiarNet.py @@ -127,7 +127,9 @@ def __init__( encoder_decoder = merge_dict(self.ENCODER_DECODER_DEFAULTS, encoder_decoder) self.n_src = n_sources self.use_lstm = use_lstm - self.save_hyperparameters("encoder_decoder", "lstm", "linear", "convnet", "dprnn") + self.save_hyperparameters( + "encoder_decoder", "lstm", "linear", "convnet", "dprnn" + ) self.learning_rate = lr self.n_sources = n_sources @@ -141,11 +143,12 @@ def __init__( sample_rate=sample_rate, **self.hparams.encoder_decoder ) self.masker = DPRNN(n_feats_out, n_src=n_sources, **self.hparams.dprnn) - #self.convnet= TDConvNet(n_feats_out, **self.hparams.convnet) - - # diarization can use a lower resolution than separation - diarization_scaling = int(256 / encoder_decoder["kernel_size"]) - self.average_pool = nn.AvgPool1d(diarization_scaling, stride=diarization_scaling) + + # diarization can use a lower resolution than separation, use 128x downsampling + diarization_scaling = int(128 / encoder_decoder["stride"]) + self.average_pool = nn.AvgPool1d( + diarization_scaling, stride=diarization_scaling + ) if use_lstm: monolithic = lstm["monolithic"] @@ -169,7 +172,8 @@ def __init__( nn.LSTM( n_feats_out if i == 0 - else lstm["hidden_size"] * (2 if lstm["bidirectional"] else 1), + else lstm["hidden_size"] + * (2 if lstm["bidirectional"] else 1), **one_layer_lstm ) for i in range(num_layers) @@ -178,14 +182,14 @@ def __init__( if linear["num_layers"] < 1: return - + if use_lstm: lstm_out_features: int = self.hparams.lstm["hidden_size"] * ( 2 if self.hparams.lstm["bidirectional"] else 1 ) else: lstm_out_features = 64 - + self.linear = nn.ModuleList( [ nn.Linear(in_features, out_features) @@ -207,14 +211,7 @@ def build(self): 2 if self.hparams.lstm["bidirectional"] else 1 ) - # if isinstance(self.specifications, tuple): - # raise ValueError("PyanNet does not support multi-tasking.") - - # if self.specifications.powerset: out_features = 1 - # else: - # out_features = len(self.specifications.classes) - self.classifier = nn.Linear(in_features, out_features) self.activation = self.default_activation() diff --git a/pyannote/audio/tasks/segmentation/speaker_separation_diarization.py 
b/pyannote/audio/tasks/segmentation/speaker_separation_diarization.py index d3872b9c8..5e6fe0897 100644 --- a/pyannote/audio/tasks/segmentation/speaker_separation_diarization.py +++ b/pyannote/audio/tasks/segmentation/speaker_separation_diarization.py @@ -56,7 +56,12 @@ from pyannote.audio.utils.loss import binary_cross_entropy, mse_loss, nll_loss from pyannote.audio.utils.permutation import permutate from pyannote.audio.utils.powerset import Powerset -from asteroid.losses import MixITLossWrapper, multisrc_neg_sisdr, PITLossWrapper, pairwise_neg_sisdr +from asteroid.losses import ( + MixITLossWrapper, + multisrc_neg_sisdr, + PITLossWrapper, + pairwise_neg_sisdr, +) from torch.utils.data._utils.collate import default_collate Subsets = list(Subset.__args__) @@ -70,255 +75,6 @@ from torch.utils.data import DataLoader, Dataset, IterableDataset from pyannote.audio.utils.random import create_rng_for_worker -class CountingCallback(Callback): - def on_train_epoch_start(self, trainer, pl_module) -> None: - "reset counters" - if pl_module.task.log_alignment_accuracy: - pl_module.task.num_correct = 0 - pl_module.task.num_total = 0 - pl_module.task.num_correct30 = 0 - pl_module.task.num_correct21 = 0 - pl_module.task.num_correct20 = 0 - pl_module.task.num_correct11 = 0 - pl_module.task.num_correct10 = 0 - pl_module.task.num_total30 = 0 - pl_module.task.num_total21 = 0 - pl_module.task.num_total20 = 0 - pl_module.task.num_total11 = 0 - pl_module.task.num_total10 = 0 - -class CustomMixITLossWrapper(nn.Module): - r"""Custom mixture invariant loss wrapper that returns the best partition - so that it can be checked against the partition determined by forced - alignment. - - Args: - loss_func: function with signature (est_targets, targets, **kwargs). - generalized (bool): Determines how MixIT is applied. If False , - apply MixIT for any number of mixtures as soon as they contain - the same number of sources (:meth:`~MixITLossWrapper.best_part_mixit`.) - If True (default), apply MixIT for two mixtures, but those mixtures do not - necessarly have to contain the same number of sources. - See :meth:`~MixITLossWrapper.best_part_mixit_generalized`. - reduction (string, optional): Specifies the reduction to apply to - the output: - ``'none'`` | ``'mean'``. ``'none'``: no reduction will be applied, - ``'mean'``: the sum of the output will be divided by the number of - elements in the output. - - For each of these modes, the best partition and reordering will be - automatically computed. - - Examples: - >>> import torch - >>> from asteroid.losses import multisrc_mse - >>> mixtures = torch.randn(10, 2, 16000) - >>> est_sources = torch.randn(10, 4, 16000) - >>> # Compute MixIT loss based on pairwise losses - >>> loss_func = MixITLossWrapper(multisrc_mse) - >>> loss_val = loss_func(est_sources, mixtures) - - References - [1] Scott Wisdom et al. "Unsupervised sound separation using - mixtures of mixtures." arXiv:2006.12701 (2020) - """ - - def __init__(self, loss_func, generalized=True, reduction="mean"): - super().__init__() - self.loss_func = loss_func - self.generalized = generalized - self.reduction = reduction - - def forward(self, est_targets, targets, return_est=False, **kwargs): - r"""Find the best partition and return the loss. - - Args: - est_targets: torch.Tensor. Expected shape :math:`(batch, nsrc, *)`. - The batch of target estimates. - targets: torch.Tensor. Expected shape :math:`(batch, nmix, ...)`. - The batch of training targets - return_est: Boolean. 
Whether to return the estimated mixtures - estimates (To compute metrics or to save example). - **kwargs: additional keyword argument that will be passed to the - loss function. - - Returns: - - Best partition loss for each batch sample, average over - the batch. torch.Tensor(loss_value) - - The estimated mixtures (estimated sources summed according to the partition) - if return_est is True. torch.Tensor of shape :math:`(batch, nmix, ...)`. - """ - # Check input dimensions - assert est_targets.shape[0] == targets.shape[0] - assert est_targets.shape[2] == targets.shape[2] - - if not self.generalized: - min_loss, min_loss_idx, parts = self.best_part_mixit( - self.loss_func, est_targets, targets, **kwargs - ) - else: - min_loss, min_loss_idx, parts = self.best_part_mixit_generalized( - self.loss_func, est_targets, targets, **kwargs - ) - - # Apply any reductions over the batch axis - returned_loss = min_loss.mean() if self.reduction == "mean" else min_loss - if not return_est: - return returned_loss, [parts[i] for i in min_loss_idx] - - # Order and sum on the best partition to get the estimated mixtures - reordered = self.reorder_source(est_targets, targets, min_loss_idx, parts) - return returned_loss, reordered - - @staticmethod - def best_part_mixit(loss_func, est_targets, targets, **kwargs): - r"""Find best partition of the estimated sources that gives the minimum - loss for the MixIT training paradigm in [1]. Valid for any number of - mixtures as soon as they contain the same number of sources. - - Args: - loss_func: function with signature ``(est_targets, targets, **kwargs)`` - The loss function to get batch losses from. - est_targets: torch.Tensor. Expected shape :math:`(batch, nsrc, ...)`. - The batch of target estimates. - targets: torch.Tensor. Expected shape :math:`(batch, nmix, ...)`. - The batch of training targets (mixtures). - **kwargs: additional keyword argument that will be passed to the - loss function. - - Returns: - - :class:`torch.Tensor`: - The loss corresponding to the best permutation of size (batch,). - - - :class:`torch.LongTensor`: - The indices of the best partition. - - - :class:`list`: - list of the possible partitions of the sources. - - """ - nmix = targets.shape[1] - nsrc = est_targets.shape[1] - if nsrc % nmix != 0: - raise ValueError("The mixtures are assumed to contain the same number of sources") - nsrcmix = nsrc // nmix - - # Generate all unique partitions of size k from a list lst of - # length n, where l = n // k is the number of parts. The total - # number of such partitions is: NPK(n,k) = n! / ((k!)^l * l!) - # Algorithm recursively distributes items over parts - def parts_mixit(lst, k, l): - if l == 0: - yield [] - else: - for c in combinations(lst, k): - rest = [x for x in lst if x not in c] - for r in parts_mixit(rest, k, l - 1): - yield [list(c), *r] - - # Generate all the possible partitions - parts = list(parts_mixit(range(nsrc), nsrcmix, nmix)) - # Compute the loss corresponding to each partition - loss_set = CustomMixITLossWrapper.loss_set_from_parts( - loss_func, est_targets=est_targets, targets=targets, parts=parts, **kwargs - ) - # Indexes and values of min losses for each batch element - min_loss, min_loss_indexes = torch.min(loss_set, dim=1, keepdim=True) - return min_loss, min_loss_indexes, parts - - @staticmethod - def best_part_mixit_generalized(loss_func, est_targets, targets, **kwargs): - r"""Find best partition of the estimated sources that gives the minimum - loss for the MixIT training paradigm in [1]. 
Valid only for two mixtures, - but those mixtures do not necessarly have to contain the same number of - sources e.g the case where one mixture is silent is allowed.. - - Args: - loss_func: function with signature ``(est_targets, targets, **kwargs)`` - The loss function to get batch losses from. - est_targets: torch.Tensor. Expected shape :math:`(batch, nsrc, ...)`. - The batch of target estimates. - targets: torch.Tensor. Expected shape :math:`(batch, nmix, ...)`. - The batch of training targets (mixtures). - **kwargs: additional keyword argument that will be passed to the - loss function. - - Returns: - - :class:`torch.Tensor`: - The loss corresponding to the best permutation of size (batch,). - - - :class:`torch.LongTensor`: - The indexes of the best permutations. - - - :class:`list`: - list of the possible partitions of the sources. - """ - nmix = targets.shape[1] # number of mixtures - nsrc = est_targets.shape[1] # number of estimated sources - if nmix != 2: - raise ValueError("Works only with two mixtures") - - # Generate all unique partitions of any size from a list lst of - # length n. Algorithm recursively distributes items over parts - def parts_mixit_gen(lst): - partitions = [] - for k in range(len(lst) + 1): - for c in combinations(lst, k): - rest = [x for x in lst if x not in c] - partitions.append([list(c), rest]) - return partitions - - # Generate all the possible partitions - parts = parts_mixit_gen(range(nsrc)) - # Compute the loss corresponding to each partition - loss_set = CustomMixITLossWrapper.loss_set_from_parts( - loss_func, est_targets=est_targets, targets=targets, parts=parts, **kwargs - ) - # Indexes and values of min losses for each batch element - min_loss, min_loss_indexes = torch.min(loss_set, dim=1, keepdim=True) - return min_loss, min_loss_indexes, parts - - @staticmethod - def loss_set_from_parts(loss_func, est_targets, targets, parts, **kwargs): - """Common loop between both best_part_mixit""" - loss_set = [] - for partition in parts: - # sum the sources according to the given partition - est_mixes = torch.stack([est_targets[:, idx, :].sum(1) for idx in partition], dim=1) - # get loss for the given partition - loss_partition = loss_func(est_mixes, targets, **kwargs) - if loss_partition.ndim != 1: - raise ValueError("Loss function return value should be of size (batch,).") - loss_set.append(loss_partition[:, None]) - loss_set = torch.cat(loss_set, dim=1) - return loss_set - - @staticmethod - def reorder_source(est_targets, targets, min_loss_idx, parts): - """Reorder sources according to the best partition. - - Args: - est_targets: torch.Tensor. Expected shape :math:`(batch, nsrc, ...)`. - The batch of target estimates. - targets: torch.Tensor. Expected shape :math:`(batch, nmix, ...)`. - The batch of training targets. - min_loss_idx: torch.LongTensor. The indexes of the best permutations. - parts: list of the possible partitions of the sources. - - Returns: - :class:`torch.Tensor`: Reordered sources of shape :math:`(batch, nmix, time)`. 
- - """ - # For each batch there is a different min_loss_idx - ordered = torch.zeros_like(targets) - for b, idx in enumerate(min_loss_idx): - right_partition = parts[idx] - # Sum the estimated sources to get the estimated mixtures - ordered[b, :, :] = torch.stack( - [est_targets[b, idx, :][None, :, :].sum(1) for idx in right_partition], dim=1 - ) - - return ordered class ValDataset(IterableDataset): def __init__(self, task: Task): @@ -330,7 +86,8 @@ def __iter__(self): def __len__(self): return self.task.val__len__() - + + class JointSpeakerSeparationAndDiarization(SegmentationTaskMixin, Task): """Speaker diarization @@ -409,7 +166,6 @@ def __init__( separation_loss_weight: float = 0.5, original_mixtures_for_separation: bool = False, forced_alignment_weight: float = 0.0, - log_alignment_accuracy: bool = False, add_noise_sources: bool = False, ): super().__init__( @@ -438,7 +194,9 @@ def __init__( # parameter validation if max_speakers_per_frame is not None: - raise NotImplementedError("Powerset multi-class training is not implemented") + raise NotImplementedError( + "Diarization is done on masks separately which is incompatible powerset training" + ) if batch_size % 2 != 0: raise ValueError( @@ -450,12 +208,10 @@ def __init__( self.weigh_by_cardinality = weigh_by_cardinality self.balance = balance self.weight = weight - self.separation_loss = CustomMixITLossWrapper(multisrc_neg_sisdr, generalized=True) self.pit_sep_loss = PITLossWrapper(pairwise_neg_sisdr, pit_from="pw_mtx") self.separation_loss_weight = separation_loss_weight self.original_mixtures_for_separation = original_mixtures_for_separation self.forced_alignment_weight = forced_alignment_weight - self.log_alignment_accuracy = log_alignment_accuracy self.add_noise_sources = add_noise_sources def setup(self): @@ -664,7 +420,7 @@ def prepare_chunk(self, file_id: int, start_time: float, duration: float): sample["meta"]["file"] = file_id return sample - + def val_dataloader(self) -> DataLoader: return DataLoader( ValDataset(self), @@ -674,7 +430,7 @@ def val_dataloader(self) -> DataLoader: drop_last=True, collate_fn=partial(self.collate_fn, stage="train"), ) - + def val__iter__(self): """Iterate over training samples @@ -714,7 +470,7 @@ def val__iter__(self): # generate random chunk yield next(chunks) - + def train__len__(self): # Number of training samples in one epoch @@ -1170,8 +926,8 @@ def segmentation_loss( def create_mixtures_of_mixtures(self, mix1, mix2, target1, target2): """ - Creates mixtures of mixtures and corresponding diarization targets. - Keeps track of how many speakers came from each mixture in order to + Creates mixtures of mixtures and corresponding diarization targets. + Keeps track of how many speakers came from each mixture in order to reconstruct the original mixtures. Parameters @@ -1223,23 +979,7 @@ def create_mixtures_of_mixtures(self, mix1, mix2, target1, target2): return mom, targets, num_active_speakers_mix1, num_active_speakers_mix2 - def training_step(self, batch, batch_idx: int): - """Compute permutation-invariant segmentation loss - - Parameters - ---------- - batch : (usually) dict of torch.Tensor - Current batch. - batch_idx: int - Batch index. 
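# Toy sketch of how the mixture-of-mixtures batch above is built (assumption:
# synthetic tensors and mono chunks that are already squeezed; the actual
# create_mixtures_of_mixtures additionally counts how many speakers are active
# in each chunk so the original mixtures can be reconstructed later).
import torch

batch, samples, frames, max_spk = 4, 32000, 100, 3
waveform = torch.randn(batch, samples)
target = torch.randint(0, 2, (batch, frames, max_spk)).float()

mix1, mix2 = waveform[0::2], waveform[1::2]   # pair even/odd chunks
mom = mix1 + mix2                             # (batch // 2, samples)
mom_target = torch.cat((target[0::2], target[1::2]), dim=2)
assert mom.shape == (batch // 2, samples)
assert mom_target.shape == (batch // 2, frames, 2 * max_spk)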
- - Returns - ------- - loss : {str: torch.tensor} - {"loss": loss} - """ - - # target + def common_step(self, batch): target = batch["y"] # (batch_size, num_frames, num_speakers) @@ -1251,6 +991,14 @@ def training_step(self, batch, batch_idx: int): # forward pass bsz = waveform.shape[0] + + # MoMs can't be created for batch size < 2 + if bsz < 2: + return None + # if bsz not even, then leave out last sample + if bsz % 2 != 0: + waveform = waveform[:-1] + num_samples = waveform.shape[2] mix1 = waveform[0::2].squeeze(1) mix2 = waveform[1::2].squeeze(1) @@ -1287,15 +1035,20 @@ def training_step(self, batch, batch_idx: int): if self.add_noise_sources: # last 2 sources should only contain noise so we force diarization outputs to 0 - permutated_diarization, permutations = permutate(target, diarization[:, :, :3]) - target = torch.cat((target, torch.zeros(batch_size, num_frames, 2, device=target.device)), dim=2) - permutated_diarization = torch.cat((permutated_diarization, diarization[:, :, 3:]), dim=2) + permutated_diarization, permutations = permutate( + target, diarization[:, :, :3] + ) + target = torch.cat( + (target, torch.zeros(batch_size, num_frames, 2, device=target.device)), + dim=2, + ) + permutated_diarization = torch.cat( + (permutated_diarization, diarization[:, :, 3:]), dim=2 + ) else: permutated_diarization, permutations = permutate(target, diarization) - seg_loss = self.segmentation_loss( - permutated_diarization, target, weight=weight - ) + seg_loss = self.segmentation_loss(permutated_diarization, target, weight=weight) speaker_idx_mix1 = [ [permutations[i][j] for j in range(num_active_speakers_mix1[i])] @@ -1304,23 +1057,36 @@ def training_step(self, batch, batch_idx: int): speaker_idx_mix2 = [ [ permutations[i][j] - for j in range(num_active_speakers_mix1[i], num_active_speakers_mix1[i] + num_active_speakers_mix2[i]) + for j in range( + num_active_speakers_mix1[i], + num_active_speakers_mix1[i] + num_active_speakers_mix2[i], + ) ] for i in range(bsz // 2) ] - + est_mixes = [] for i in range(bsz // 2): if self.add_noise_sources: - est_mix1 = mom_sources[i, :, speaker_idx_mix1[i]].sum(1) + mom_sources[i,:,3] - est_mix2 = mom_sources[i, :, speaker_idx_mix2[i]].sum(1) + mom_sources[i,:,4] - est_mix3 = mom_sources[i, :, speaker_idx_mix1[i]].sum(1) + mom_sources[i,:,4] - est_mix4 = mom_sources[i, :, speaker_idx_mix2[i]].sum(1) + mom_sources[i,:,3] + est_mix1 = ( + mom_sources[i, :, speaker_idx_mix1[i]].sum(1) + mom_sources[i, :, 3] + ) + est_mix2 = ( + mom_sources[i, :, speaker_idx_mix2[i]].sum(1) + mom_sources[i, :, 4] + ) + est_mix3 = ( + mom_sources[i, :, speaker_idx_mix1[i]].sum(1) + mom_sources[i, :, 4] + ) + est_mix4 = ( + mom_sources[i, :, speaker_idx_mix2[i]].sum(1) + mom_sources[i, :, 3] + ) sep_loss_first_part = self.pit_sep_loss( - torch.stack((est_mix1, est_mix2)).unsqueeze(0), torch.stack((mix1[i], mix2[i])).unsqueeze(0) + torch.stack((est_mix1, est_mix2)).unsqueeze(0), + torch.stack((mix1[i], mix2[i])).unsqueeze(0), ) - sep_loss_second_part = self.pit_sep_loss( - torch.stack((est_mix3, est_mix4)).unsqueeze(0), torch.stack((mix1[i], mix2[i])).unsqueeze(0) + sep_loss_second_part = self.pit_sep_loss( + torch.stack((est_mix3, est_mix4)).unsqueeze(0), + torch.stack((mix1[i], mix2[i])).unsqueeze(0), ) if sep_loss_first_part < sep_loss_second_part: est_mixes.append(torch.stack((est_mix1, est_mix2))) @@ -1334,8 +1100,7 @@ def training_step(self, batch, batch_idx: int): separation_loss = self.pit_sep_loss( est_mixes, torch.stack((mix1, mix2)).transpose(0, 1) ).mean() - _, 
mixit_partitions = self.separation_loss(mom_sources[:,:,:3].transpose(1, 2), torch.stack((mix1, mix2)).transpose(0, 1)) - + if self.original_mixtures_for_separation: raise NotImplementedError # separation_loss += self.separation_loss( @@ -1351,6 +1116,39 @@ def training_step(self, batch, batch_idx: int): # ) # forced_alignment_loss = forced_alignment_loss.mean() / 3 forced_alignment_loss = 0 + return ( + seg_loss, + separation_loss, + forced_alignment_loss, + diarization, + permutated_diarization, + target, + ) + + def training_step(self, batch, batch_idx: int): + """Compute permutation-invariant segmentation loss + + Parameters + ---------- + batch : (usually) dict of torch.Tensor + Current batch. + batch_idx: int + Batch index. + + Returns + ------- + loss : {str: torch.tensor} + {"loss": loss} + """ + + ( + seg_loss, + separation_loss, + forced_alignment_loss, + diarization, + permutated_diarization, + target, + ) = self.common_step(batch) self.model.log( "loss/train/separation", separation_loss, @@ -1387,50 +1185,6 @@ def training_step(self, batch, batch_idx: int): prog_bar=False, logger=True, ) - if self.log_alignment_accuracy: - for i in range(bsz // 2): - inverse_mixit_partition = permutations_inverse[i][mixit_partitions[i][0]], permutations_inverse[i][mixit_partitions[i][1]] - if set([int(j) for j in speaker_idx_mix1[i]]) <= set(inverse_mixit_partition[0].tolist()) and set([int(j) for j in speaker_idx_mix2[i]]) <= set(inverse_mixit_partition[1].tolist()): - self.num_correct += 1 - if (num_active_speakers_mix1[i], num_active_speakers_mix2[i]) in [(1, 0), (0, 1)]: - self.num_correct10 += 1 - if (num_active_speakers_mix1[i], num_active_speakers_mix2[i]) in [(2, 0), (0, 2)]: - self.num_correct20 += 1 - if (num_active_speakers_mix1[i], num_active_speakers_mix2[i]) in [(3, 0), (0, 3)]: - self.num_correct30 += 1 - if num_active_speakers_mix1[i] == 1 and num_active_speakers_mix2[i] == 1: - self.num_correct11 += 1 - if (num_active_speakers_mix1[i], num_active_speakers_mix2[i]) in [(1, 2), (2, 1)]: - self.num_correct21 += 1 - if (num_active_speakers_mix1[i], num_active_speakers_mix2[i]) in [(1, 0), (0, 1)]: - self.num_total10 += 1 - if (num_active_speakers_mix1[i], num_active_speakers_mix2[i]) in [(2, 0), (0, 2)]: - self.num_total20 += 1 - if (num_active_speakers_mix1[i], num_active_speakers_mix2[i]) in [(3, 0), (0, 3)]: - self.num_total30 += 1 - if num_active_speakers_mix1[i] == 1 and num_active_speakers_mix2[i] == 1: - self.num_total11 += 1 - if (num_active_speakers_mix1[i], num_active_speakers_mix2[i]) in [(1, 2), (2, 1)]: - self.num_total21 += 1 - self.num_total+=1 - if self.num_total30 > 0: - self.model.log("accuracy/3_0", self.num_correct30/self.num_total30, on_step=False, on_epoch=True, prog_bar=False, logger=True) - if self.num_total20 > 0: - self.model.log("accuracy/2_0", self.num_correct20/self.num_total20, on_step=False, on_epoch=True, prog_bar=False, logger=True) - if self.num_total10 > 0: - self.model.log("accuracy/1_0", self.num_correct10/self.num_total10, on_step=False, on_epoch=True, prog_bar=False, logger=True) - if self.num_total11 > 0: - self.model.log("accuracy/1_1", self.num_correct11/self.num_total11, on_step=False, on_epoch=True, prog_bar=False, logger=True) - if self.num_total21 > 0: - self.model.log("accuracy/2_1", self.num_correct21/self.num_total21, on_step=False, on_epoch=True, prog_bar=False, logger=True) - if self.num_total > 0: - self.model.log("accuracy/total", self.num_correct/self.num_total, on_step=False, on_epoch=True, prog_bar=False, logger=True) - 
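# Worked example (hypothetical numbers) of the speaker-index bookkeeping in
# common_step above: if the best diarization permutation for a given MoM is
# (2, 0, 1) and its two constituent chunks contributed 2 and 1 active
# speakers, then separated sources 2 and 0 are summed to rebuild the first
# chunk and source 1 to rebuild the second one; any remaining channels carry
# noise or silence.
permutation = (2, 0, 1)
n_active_mix1, n_active_mix2 = 2, 1
idx_mix1 = [permutation[j] for j in range(n_active_mix1)]
idx_mix2 = [
    permutation[j] for j in range(n_active_mix1, n_active_mix1 + n_active_mix2)
]
assert idx_mix1 == [2, 0] and idx_mix2 == [1]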
self.model.log("counts/3_0", self.num_total30, on_step=False, on_epoch=True, prog_bar=False, logger=True) - self.model.log("counts/2_0", self.num_total20, on_step=False, on_epoch=True, prog_bar=False, logger=True) - self.model.log("counts/1_0", self.num_total10, on_step=False, on_epoch=True, prog_bar=False, logger=True) - self.model.log("counts/1_1", self.num_total11, on_step=False, on_epoch=True, prog_bar=False, logger=True) - self.model.log("counts/2_1", self.num_total21, on_step=False, on_epoch=True, prog_bar=False, logger=True) - self.model.log("counts/total", self.num_total, on_step=False, on_epoch=True, prog_bar=False, logger=True) return {"loss": loss} @@ -1459,100 +1213,14 @@ def validation_step(self, batch, batch_idx: int): Batch index. """ - # target - target = batch["y"] - # (batch_size, num_frames, num_speakers) - - waveform = batch["X"] - # (batch_size, num_channels, num_samples) - - # TODO: should we handle validation samples with too many speakers - # waveform = waveform[keep] - # target = target[keep] - - bsz = waveform.shape[0] - num_samples = waveform.shape[2] - # MoMs can't be created for batch size < 2 - if bsz < 2: - return None - # if bsz not even, then leave out last sample - if bsz % 2 != 0: - waveform = waveform[:-1] - - # if bsz not even, then leave out last sample - mix1 = waveform[0::2].squeeze(1) - mix2 = waveform[1::2].squeeze(1) - if self.original_mixtures_for_separation: - # extract parts with only one speaker from original mixtures - mix1_masks = batch["X_separation_mask"][0::2] - mix2_masks = batch["X_separation_mask"][1::2] - mix1_masked = mix1 * mix1_masks - mix2_masked = mix2 * mix2_masks - ( - mom, - mom_target, - num_active_speakers_mix1, - num_active_speakers_mix2, - ) = self.create_mixtures_of_mixtures(mix1, mix2, target[0::2], target[1::2]) - - # forward pass - diarization, _ = self.model(waveform) - _, mom_sources = self.model(mom) - batch_size, num_frames, _ = diarization.shape - - # frames weight - weight_key = getattr(self, "weight", None) - weight = batch.get( - weight_key, - torch.ones(batch_size, num_frames, 1, device=self.model.device), - ) - # (batch_size, num_frames, 1) - - # last 2 sources should only contain noise so we force diarization outputs to 0 - permutated_diarization, permutations = permutate(target, diarization[:, :, :3]) - target = torch.cat((target, torch.zeros(batch_size, num_frames, 2, device=target.device)), dim=2) - permutated_diarization = torch.cat((permutated_diarization, diarization[:, :, 3:]), dim=2) - seg_loss = self.segmentation_loss( - permutated_diarization, target, weight=weight - ) - - speaker_idx_mix1 = [ - [permutations[i][j] for j in range(num_active_speakers_mix1[i])] - for i in range(bsz // 2) - ] - speaker_idx_mix2 = [ - [ - permutations[i][j] - for j in range(num_active_speakers_mix1[i], num_active_speakers_mix1[i] + num_active_speakers_mix2[i]) - ] - for i in range(bsz // 2) - ] - - est_mixes = [] - for i in range(bsz // 2): - est_mix1 = mom_sources[i, :, speaker_idx_mix1[i]].sum(1) + mom_sources[i,:,3] - est_mix2 = mom_sources[i, :, speaker_idx_mix2[i]].sum(1) + mom_sources[i,:,4] - est_mix3 = mom_sources[i, :, speaker_idx_mix1[i]].sum(1) + mom_sources[i,:,4] - est_mix4 = mom_sources[i, :, speaker_idx_mix2[i]].sum(1) + mom_sources[i,:,3] - sep_loss_first_part = self.pit_sep_loss( - torch.stack((est_mix1, est_mix2)).unsqueeze(0), torch.stack((mix1[i], mix2[i])).unsqueeze(0) - ) - sep_loss_second_part = self.pit_sep_loss( - torch.stack((est_mix3, est_mix4)).unsqueeze(0), torch.stack((mix1[i], 
mix2[i])).unsqueeze(0) - ) - if sep_loss_first_part < sep_loss_second_part: - est_mixes.append(torch.stack((est_mix1, est_mix2))) - else: - est_mixes.append(torch.stack((est_mix3, est_mix4))) - est_mixes = torch.stack(est_mixes) - separation_loss = self.pit_sep_loss( - est_mixes, torch.stack((mix1, mix2)).transpose(0, 1) - ).mean() - _, mixit_partitions = self.separation_loss(mom_sources[:,:,:3].transpose(1, 2), torch.stack((mix1, mix2)).transpose(0, 1)) - - if self.original_mixtures_for_separation: - raise NotImplementedError + seg_loss, + separation_loss, + forced_alignment_loss, + diarization, + permutated_diarization, + target, + ) = self.common_step(batch) self.model.log( "loss/val/separation", @@ -1573,8 +1241,10 @@ def validation_step(self, batch, batch_idx: int): ) loss = ( - 1 - self.separation_loss_weight - ) * seg_loss + self.separation_loss_weight * separation_loss + (1 - self.separation_loss_weight) * seg_loss + + self.separation_loss_weight * separation_loss + + forced_alignment_loss * self.forced_alignment_weight + ) self.model.log( "loss/val", @@ -1590,7 +1260,6 @@ def validation_step(self, batch, batch_idx: int): torch.transpose(target, 1, 2), ) - self.model.log_dict( self.model.validation_metric, on_step=False, From da8c8da8c1b5d79c3a1abb438ac4dbaa5b60b9b3 Mon Sep 17 00:00:00 2001 From: joonaskalda Date: Wed, 20 Sep 2023 14:59:27 +0300 Subject: [PATCH 55/55] change default model parameters --- pyannote/audio/models/segmentation/SepDiarNet.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/pyannote/audio/models/segmentation/SepDiarNet.py b/pyannote/audio/models/segmentation/SepDiarNet.py index a0e7069ef..c9f03e0a1 100644 --- a/pyannote/audio/models/segmentation/SepDiarNet.py +++ b/pyannote/audio/models/segmentation/SepDiarNet.py @@ -65,12 +65,11 @@ class SepDiarNet(Model): i.e. two linear layers with 128 units each. """ - SINCNET_DEFAULTS = {"stride": 10} ENCODER_DECODER_DEFAULTS = { "fb_name": "stft", "kernel_size": 512, - "n_filters": 512, - "stride": 256, + "n_filters": 64, + "stride": 32, } LSTM_DEFAULTS = { "hidden_size": 128, @@ -113,7 +112,7 @@ def __init__( num_channels: int = 1, task: Optional[Task] = None, encoder_type: str = None, - n_sources: int = 5, + n_sources: int = 3, use_lstm: bool = False, lr: float = 1e-3, ):
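# Rough implication of the new encoder defaults, as a back-of-the-envelope
# sketch (assuming 16 kHz input, a 5 s chunk, and no padding, none of which is
# spelled out in this hunk): an encoder with kernel_size=512 and stride=32
# hops every 2 ms, so chunks are represented at a much finer frame rate than
# with the previous stride of 256.
sample_rate, kernel_size, stride = 16000, 512, 32
chunk_samples = 5 * sample_rate                       # a 5 s training chunk
num_frames = (chunk_samples - kernel_size) // stride + 1
print(num_frames)                                     # 2485 encoder frames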